diff --git a/.ocamlformat b/.ocamlformat new file mode 100644 index 000000000..9e152fbde --- /dev/null +++ b/.ocamlformat @@ -0,0 +1,9 @@ +profile = default +version = 0.25.1 +margin = 120 +exp-grouping = preserve +parens-ite = true +space-around-lists = false +indicate-multiline-delimiters = closing-on-separate-line +module-item-spacing = preserve +doc-comments = before diff --git a/.ocamlformat-ignore b/.ocamlformat-ignore new file mode 100644 index 000000000..854a99d14 --- /dev/null +++ b/.ocamlformat-ignore @@ -0,0 +1,2 @@ +lib/** + diff --git a/sailcov/dune b/sailcov/dune index c69fec07c..69dd4adea 100644 --- a/sailcov/dune +++ b/sailcov/dune @@ -1,2 +1,2 @@ (executable - (name main)) + (name main)) diff --git a/src/bin/callgraph_commands.ml b/src/bin/callgraph_commands.ml index 52af85351..8690e1f63 100644 --- a/src/bin/callgraph_commands.ml +++ b/src/bin/callgraph_commands.ml @@ -77,7 +77,7 @@ let node_string n = node_id n |> string_of_id |> String.escaped let edge_color _from_node _to_node = "black" let node_color cuts = - let module NodeSet = Set.Make(Node) in + let module NodeSet = Set.Make (Node) in function | node when NodeSet.mem node cuts -> "red" | Register _ -> "lightpink" @@ -92,8 +92,8 @@ let node_color cuts = | Outcome _ -> "purple" let dot_of_ast out_chan ast = - let module G = Graph.Make(Node) in - let module NodeSet = Set.Make(Node) in + let module G = Graph.Make (Node) in + let module NodeSet = Set.Make (Node) in let g = graph_of_ast ast in G.make_dot (node_color NodeSet.empty) edge_color node_string out_chan g @@ -101,74 +101,94 @@ let node_of_id env = let lets = Type_check.Env.get_toplevel_lets env in let specs = Type_check.Env.get_defined_val_specs env in fun id -> - if IdSet.mem id lets then Letbind id - else if IdSet.mem id specs then Function id - else if Type_check.Env.bound_typ_id env id then Type id - else (prerr_endline ("Warning: unknown identifier " ^ string_of_id id); Function id) + if IdSet.mem id lets then Letbind id + else if IdSet.mem id specs then Function id + else if Type_check.Env.bound_typ_id env id then Type id + else ( + prerr_endline ("Warning: unknown identifier " ^ string_of_id id); + Function id + ) let () = let slice_roots = ref IdSet.empty in let slice_keep_std = ref false in let slice_cuts = ref IdSet.empty in - ArgString ("identifiers", fun arg -> ActionUnit (fun _ -> - let args = Str.split (Str.regexp " +") arg in - let ids = List.map mk_id args |> IdSet.of_list in - Specialize.add_initial_calls ids; - slice_roots := IdSet.union ids !slice_roots - )) |> register_command ~name:"slice_roots" ~help:"Set the roots for :slice"; + ArgString + ( "identifiers", + fun arg -> + ActionUnit + (fun _ -> + let args = Str.split (Str.regexp " +") arg in + let ids = List.map mk_id args |> IdSet.of_list in + Specialize.add_initial_calls ids; + slice_roots := IdSet.union ids !slice_roots + ) + ) + |> register_command ~name:"slice_roots" ~help:"Set the roots for :slice"; - ActionUnit (fun _ -> - slice_keep_std := true - ) |> register_command ~name:"slice_keep_std" ~help:"Keep standard library contents during :slice"; + ActionUnit (fun _ -> slice_keep_std := true) + |> register_command ~name:"slice_keep_std" ~help:"Keep standard library contents during :slice"; - ArgString ("identifiers", fun arg -> ActionUnit (fun _ -> - let args = Str.split (Str.regexp " +") arg in - let ids = List.map mk_id args |> IdSet.of_list in - slice_cuts := IdSet.union ids !slice_cuts - )) |> register_command ~name:"slice_cuts" ~help:"Set the cuts for :slice"; + ArgString + ( "identifiers", + fun arg -> + ActionUnit + (fun _ -> + let args = Str.split (Str.regexp " +") arg in + let ids = List.map mk_id args |> IdSet.of_list in + slice_cuts := IdSet.union ids !slice_cuts + ) + ) + |> register_command ~name:"slice_cuts" ~help:"Set the cuts for :slice"; - Action (fun istate -> - let module NodeSet = Set.Make(Node) in - let module G = Graph.Make(Node) in - let g = graph_of_ast istate.ast in - let roots = !slice_roots |> IdSet.elements |> List.map (node_of_id istate.env) |> NodeSet.of_list in - let cuts = !slice_cuts |> IdSet.elements |> List.map (node_of_id istate.env) |> NodeSet.of_list in - let g = G.prune roots cuts g in - { istate with ast = filter_ast_extra cuts g istate.ast !slice_keep_std } - ) |> register_command - ~name:"slice" - ~help:"Slice AST to the definitions which the functions given \ - by :slice_roots depend on, up to the functions given \ - by :slice_cuts"; + Action + (fun istate -> + let module NodeSet = Set.Make (Node) in + let module G = Graph.Make (Node) in + let g = graph_of_ast istate.ast in + let roots = !slice_roots |> IdSet.elements |> List.map (node_of_id istate.env) |> NodeSet.of_list in + let cuts = !slice_cuts |> IdSet.elements |> List.map (node_of_id istate.env) |> NodeSet.of_list in + let g = G.prune roots cuts g in + { istate with ast = filter_ast_extra cuts g istate.ast !slice_keep_std } + ) + |> register_command ~name:"slice" + ~help: + "Slice AST to the definitions which the functions given by :slice_roots depend on, up to the functions given \ + by :slice_cuts"; - Action (fun istate -> - let module NodeSet = Set.Make(Node) in - let module NodeMap = Map.Make(Node) in - let module G = Graph.Make(Node) in - let g = graph_of_ast istate.ast in - let roots = !slice_roots |> IdSet.elements |> List.map (node_of_id istate.env) |> NodeSet.of_list in - let keep = function - | (Function id,_) when IdSet.mem id (!slice_roots) -> None - | (Function id,_) -> Some (Function id) - | _ -> None - in - let cuts = NodeMap.bindings g |> List.filter_map keep |> NodeSet.of_list in - let g = G.prune roots cuts g in - { istate with ast = filter_ast_extra cuts g istate.ast !slice_keep_std } - ) |> register_command - ~name:"thin_slice" - ~help:(sprintf ":thin_slice - Slice AST to the function definitions given with %s" (command "slice_roots")); + Action + (fun istate -> + let module NodeSet = Set.Make (Node) in + let module NodeMap = Map.Make (Node) in + let module G = Graph.Make (Node) in + let g = graph_of_ast istate.ast in + let roots = !slice_roots |> IdSet.elements |> List.map (node_of_id istate.env) |> NodeSet.of_list in + let keep = function + | Function id, _ when IdSet.mem id !slice_roots -> None + | Function id, _ -> Some (Function id) + | _ -> None + in + let cuts = NodeMap.bindings g |> List.filter_map keep |> NodeSet.of_list in + let g = G.prune roots cuts g in + { istate with ast = filter_ast_extra cuts g istate.ast !slice_keep_std } + ) + |> register_command ~name:"thin_slice" + ~help:(sprintf ":thin_slice - Slice AST to the function definitions given with %s" (command "slice_roots")); - ArgString ("format", fun arg -> ActionUnit (fun istate -> - let format = if arg = "" then "svg" else arg in - let dotfile, out_chan = Filename.open_temp_file "sail_graph_" ".gz" in - let image = Filename.temp_file "sail_graph_" ("." ^ format) in - dot_of_ast out_chan istate.ast; - close_out out_chan; - let _ = Unix.system (Printf.sprintf "dot -T%s %s -o %s" format dotfile image) in - let _ = Unix.system (Printf.sprintf "xdg-open %s" image) in - () - )) |> register_command - ~name:"graph" - ~help:"Draw a callgraph using dot in :0 (e.g. svg), and open with xdg-open" + ArgString + ( "format", + fun arg -> + ActionUnit + (fun istate -> + let format = if arg = "" then "svg" else arg in + let dotfile, out_chan = Filename.open_temp_file "sail_graph_" ".gz" in + let image = Filename.temp_file "sail_graph_" ("." ^ format) in + dot_of_ast out_chan istate.ast; + close_out out_chan; + let _ = Unix.system (Printf.sprintf "dot -T%s %s -o %s" format dotfile image) in + let _ = Unix.system (Printf.sprintf "xdg-open %s" image) in + () + ) + ) + |> register_command ~name:"graph" ~help:"Draw a callgraph using dot in :0 (e.g. svg), and open with xdg-open" diff --git a/src/bin/dune b/src/bin/dune index 67dcb2fef..813ed0a41 100644 --- a/src/bin/dune +++ b/src/bin/dune @@ -1,122 +1,200 @@ - (rule - (target manifest.ml) - (mode fallback) - (action - (with-outputs-to %{target} - (chdir %{workspace_root} - (run sail_manifest -gen_manifest))))) + (target manifest.ml) + (mode fallback) + (action + (with-outputs-to + %{target} + (chdir + %{workspace_root} + (run sail_manifest -gen_manifest))))) (executable - (name sail) - (public_name sail) - (package sail) - (link_flags -linkall) - (libraries libsail linenoise dynlink)) + (name sail) + (public_name sail) + (package sail) + (link_flags -linkall) + (libraries libsail linenoise dynlink)) ; For legacy reasons install all the Sail files in lib as part of this package + (install - (section share) - (package sail) - (files - (%{workspace_root}/src/lib/util.ml as src/lib/util.ml) - (%{workspace_root}/src/lib/sail_lib.ml as src/lib/sail_lib.ml) - (%{workspace_root}/src/lib/elf_loader.ml as src/lib/elf_loader.ml) - (%{workspace_root}/lib/_tags as lib/_tags) - (%{workspace_root}/lib/_tags_coverage as lib/_tags_coverage) - (%{workspace_root}/lib/arith.sail as lib/arith.sail) - (%{workspace_root}/lib/concurrency_interface.sail as lib/concurrency_interface.sail) - (%{workspace_root}/lib/concurrency_interface/v1.sail as lib/concurrency_interface/v1.sail) - (%{workspace_root}/lib/concurrency_interface/emulator_memory.sail as lib/concurrency_interface/emulator_memory.sail) - (%{workspace_root}/lib/concurrency_interface/exception.sail as lib/concurrency_interface/exception.sail) - (%{workspace_root}/lib/concurrency_interface/tlbi.sail as lib/concurrency_interface/tlbi.sail) - (%{workspace_root}/lib/concurrency_interface/barrier.sail as lib/concurrency_interface/barrier.sail) - (%{workspace_root}/lib/concurrency_interface/cache_op.sail as lib/concurrency_interface/cache_op.sail) - (%{workspace_root}/lib/concurrency_interface/common.sail as lib/concurrency_interface/common.sail) - (%{workspace_root}/lib/concurrency_interface/read_write.sail as lib/concurrency_interface/read_write.sail) - (%{workspace_root}/lib/coverage/Cargo.toml as lib/coverage/Cargo.toml) - (%{workspace_root}/lib/coverage/Makefile as lib/coverage/Makefile) - (%{workspace_root}/lib/coverage/src/lib.rs as lib/coverage/src/lib.rs) - (%{workspace_root}/lib/elf.c as lib/elf.c) - (%{workspace_root}/lib/elf.h as lib/elf.h) - (%{workspace_root}/lib/elf.sail as lib/elf.sail) - (%{workspace_root}/lib/exception_basic.sail as lib/exception_basic.sail) - (%{workspace_root}/lib/exception_result.sail as lib/exception_result.sail) - (%{workspace_root}/lib/exception.sail as lib/exception.sail) - (%{workspace_root}/lib/flow.sail as lib/flow.sail) - (%{workspace_root}/lib/generic_equality.sail as lib/generic_equality.sail) - (%{workspace_root}/lib/hol/.gitignore as lib/hol/.gitignore) - (%{workspace_root}/lib/hol/Holmakefile as lib/hol/Holmakefile) - (%{workspace_root}/lib/hol/Makefile as lib/hol/Makefile) - (%{workspace_root}/lib/hol/sail2_prompt.lem as lib/hol/sail2_prompt.lem) - (%{workspace_root}/lib/hol/sail2_prompt_monad.lem as lib/hol/sail2_prompt_monad.lem) - (%{workspace_root}/lib/hol/sail2_stateAuxiliaryScript.sml as lib/hol/sail2_stateAuxiliaryScript.sml) - (%{workspace_root}/lib/hol/sail2_undefined.lem as lib/hol/sail2_undefined.lem) - (%{workspace_root}/lib/hol/sail2_valuesAuxiliaryScript.sml as lib/hol/sail2_valuesAuxiliaryScript.sml) - (%{workspace_root}/lib/instr_kinds.sail as lib/instr_kinds.sail) - (%{workspace_root}/lib/int128/rts.c as lib/int128/rts.c) - (%{workspace_root}/lib/int128/rts.h as lib/int128/rts.h) - (%{workspace_root}/lib/int128/sail.c as lib/int128/sail.c) - (%{workspace_root}/lib/int128/sail.h as lib/int128/sail.h) - (%{workspace_root}/lib/isabelle/.gitignore as lib/isabelle/.gitignore) - (%{workspace_root}/lib/isabelle/Add_Cancel_Distinct.thy as lib/isabelle/Add_Cancel_Distinct.thy) - (%{workspace_root}/lib/isabelle/Hoare.thy as lib/isabelle/Hoare.thy) - (%{workspace_root}/lib/isabelle/Makefile as lib/isabelle/Makefile) - (%{workspace_root}/lib/isabelle/ROOT as lib/isabelle/ROOT) - (%{workspace_root}/lib/isabelle/Sail2_operators_mwords_lemmas.thy as lib/isabelle/Sail2_operators_mwords_lemmas.thy) - (%{workspace_root}/lib/isabelle/Sail2_prompt_monad_lemmas.thy as lib/isabelle/Sail2_prompt_monad_lemmas.thy) - (%{workspace_root}/lib/isabelle/Sail2_state_lemmas.thy as lib/isabelle/Sail2_state_lemmas.thy) - (%{workspace_root}/lib/isabelle/Sail2_state_monad_lemmas.thy as lib/isabelle/Sail2_state_monad_lemmas.thy) - (%{workspace_root}/lib/isabelle/Sail2_values_lemmas.thy as lib/isabelle/Sail2_values_lemmas.thy) - (%{workspace_root}/lib/isabelle/document/root.tex as lib/isabelle/document/root.tex) - (%{workspace_root}/lib/isabelle/manual/Manual.thy as lib/isabelle/manual/Manual.thy) - (%{workspace_root}/lib/isabelle/manual/ROOT as lib/isabelle/manual/ROOT) - (%{workspace_root}/lib/isabelle/manual/document/root.tex as lib/isabelle/manual/document/root.tex) - (%{workspace_root}/lib/isla.sail as lib/isla.sail) - (%{workspace_root}/lib/main.ml as lib/main.ml) - (%{workspace_root}/lib/mapping.sail as lib/mapping.sail) - (%{workspace_root}/lib/mono_rewrites.sail as lib/mono_rewrites.sail) - (%{workspace_root}/lib/myocamlbuild_coverage.ml as lib/myocamlbuild_coverage.ml) - (%{workspace_root}/lib/nostd/sail.c as lib/nostd/sail.c) - (%{workspace_root}/lib/nostd/sail.h as lib/nostd/sail.h) - (%{workspace_root}/lib/nostd/sail_alloc.h as lib/nostd/sail_alloc.h) - (%{workspace_root}/lib/nostd/sail_arena.c as lib/nostd/sail_arena.c) - (%{workspace_root}/lib/nostd/sail_arena.h as lib/nostd/sail_arena.h) - (%{workspace_root}/lib/nostd/sail_failure.h as lib/nostd/sail_failure.h) - (%{workspace_root}/lib/nostd/sail_spinlock.h as lib/nostd/sail_spinlock.h) - (%{workspace_root}/lib/nostd/stubs/sail_failure.c as lib/nostd/stubs/sail_failure.c) - (%{workspace_root}/lib/nostd/test/test.c as lib/nostd/test/test.c) - (%{workspace_root}/lib/option.sail as lib/option.sail) - (%{workspace_root}/lib/prelude.sail as lib/prelude.sail) - (%{workspace_root}/lib/real.sail as lib/real.sail) - (%{workspace_root}/lib/regfp.sail as lib/regfp.sail) - (%{workspace_root}/lib/result.sail as lib/result.sail) - (%{workspace_root}/lib/float.sail as lib/float.sail) - (%{workspace_root}/lib/reverse_endianness.sail as lib/reverse_endianness.sail) - (%{workspace_root}/lib/rts.c as lib/rts.c) - (%{workspace_root}/lib/rts.h as lib/rts.h) - (%{workspace_root}/lib/sail.c as lib/sail.c) - (%{workspace_root}/lib/sail.h as lib/sail.h) - (%{workspace_root}/lib/sail.tex as lib/sail.tex) - (%{workspace_root}/lib/sail_coverage.h as lib/sail_coverage.h) - (%{workspace_root}/lib/sail_failure.c as lib/sail_failure.c) - (%{workspace_root}/lib/sail_failure.h as lib/sail_failure.h) - (%{workspace_root}/lib/sail_state.h as lib/sail_state.h) - (%{workspace_root}/lib/smt.sail as lib/smt.sail) - (%{workspace_root}/lib/string.sail as lib/string.sail) - (%{workspace_root}/lib/trace.sail as lib/trace.sail) - (%{workspace_root}/lib/vector_dec.sail as lib/vector_dec.sail) - (%{workspace_root}/lib/vector_inc.sail as lib/vector_inc.sail) - (%{workspace_root}/src/gen_lib/sail2_deep_shallow_convert.lem as src/gen_lib/sail2_deep_shallow_convert.lem) - (%{workspace_root}/src/gen_lib/sail2_instr_kinds.lem as src/gen_lib/sail2_instr_kinds.lem) - (%{workspace_root}/src/gen_lib/sail2_operators.lem as src/gen_lib/sail2_operators.lem) - (%{workspace_root}/src/gen_lib/sail2_operators_bitlists.lem as src/gen_lib/sail2_operators_bitlists.lem) - (%{workspace_root}/src/gen_lib/sail2_operators_mwords.lem as src/gen_lib/sail2_operators_mwords.lem) - (%{workspace_root}/src/gen_lib/sail2_prompt.lem as src/gen_lib/sail2_prompt.lem) - (%{workspace_root}/src/gen_lib/sail2_prompt_monad.lem as src/gen_lib/sail2_prompt_monad.lem) - (%{workspace_root}/src/gen_lib/sail2_state.lem as src/gen_lib/sail2_state.lem) - (%{workspace_root}/src/gen_lib/sail2_state_lifting.lem as src/gen_lib/sail2_state_lifting.lem) - (%{workspace_root}/src/gen_lib/sail2_state_monad.lem as src/gen_lib/sail2_state_monad.lem) - (%{workspace_root}/src/gen_lib/sail2_string.lem as src/gen_lib/sail2_string.lem) - (%{workspace_root}/src/gen_lib/sail2_undefined.lem as src/gen_lib/sail2_undefined.lem) - (%{workspace_root}/src/gen_lib/sail2_values.lem as src/gen_lib/sail2_values.lem))) + (section share) + (package sail) + (files + (%{workspace_root}/src/lib/util.ml as src/lib/util.ml) + (%{workspace_root}/src/lib/sail_lib.ml as src/lib/sail_lib.ml) + (%{workspace_root}/src/lib/elf_loader.ml as src/lib/elf_loader.ml) + (%{workspace_root}/lib/_tags as lib/_tags) + (%{workspace_root}/lib/_tags_coverage as lib/_tags_coverage) + (%{workspace_root}/lib/arith.sail as lib/arith.sail) + (%{workspace_root}/lib/concurrency_interface.sail + as + lib/concurrency_interface.sail) + (%{workspace_root}/lib/concurrency_interface/v1.sail + as + lib/concurrency_interface/v1.sail) + (%{workspace_root}/lib/concurrency_interface/emulator_memory.sail + as + lib/concurrency_interface/emulator_memory.sail) + (%{workspace_root}/lib/concurrency_interface/exception.sail + as + lib/concurrency_interface/exception.sail) + (%{workspace_root}/lib/concurrency_interface/tlbi.sail + as + lib/concurrency_interface/tlbi.sail) + (%{workspace_root}/lib/concurrency_interface/barrier.sail + as + lib/concurrency_interface/barrier.sail) + (%{workspace_root}/lib/concurrency_interface/cache_op.sail + as + lib/concurrency_interface/cache_op.sail) + (%{workspace_root}/lib/concurrency_interface/common.sail + as + lib/concurrency_interface/common.sail) + (%{workspace_root}/lib/concurrency_interface/read_write.sail + as + lib/concurrency_interface/read_write.sail) + (%{workspace_root}/lib/coverage/Cargo.toml as lib/coverage/Cargo.toml) + (%{workspace_root}/lib/coverage/Makefile as lib/coverage/Makefile) + (%{workspace_root}/lib/coverage/src/lib.rs as lib/coverage/src/lib.rs) + (%{workspace_root}/lib/elf.c as lib/elf.c) + (%{workspace_root}/lib/elf.h as lib/elf.h) + (%{workspace_root}/lib/elf.sail as lib/elf.sail) + (%{workspace_root}/lib/exception_basic.sail as lib/exception_basic.sail) + (%{workspace_root}/lib/exception_result.sail as lib/exception_result.sail) + (%{workspace_root}/lib/exception.sail as lib/exception.sail) + (%{workspace_root}/lib/flow.sail as lib/flow.sail) + (%{workspace_root}/lib/generic_equality.sail as lib/generic_equality.sail) + (%{workspace_root}/lib/hol/.gitignore as lib/hol/.gitignore) + (%{workspace_root}/lib/hol/Holmakefile as lib/hol/Holmakefile) + (%{workspace_root}/lib/hol/Makefile as lib/hol/Makefile) + (%{workspace_root}/lib/hol/sail2_prompt.lem as lib/hol/sail2_prompt.lem) + (%{workspace_root}/lib/hol/sail2_prompt_monad.lem + as + lib/hol/sail2_prompt_monad.lem) + (%{workspace_root}/lib/hol/sail2_stateAuxiliaryScript.sml + as + lib/hol/sail2_stateAuxiliaryScript.sml) + (%{workspace_root}/lib/hol/sail2_undefined.lem + as + lib/hol/sail2_undefined.lem) + (%{workspace_root}/lib/hol/sail2_valuesAuxiliaryScript.sml + as + lib/hol/sail2_valuesAuxiliaryScript.sml) + (%{workspace_root}/lib/instr_kinds.sail as lib/instr_kinds.sail) + (%{workspace_root}/lib/int128/rts.c as lib/int128/rts.c) + (%{workspace_root}/lib/int128/rts.h as lib/int128/rts.h) + (%{workspace_root}/lib/int128/sail.c as lib/int128/sail.c) + (%{workspace_root}/lib/int128/sail.h as lib/int128/sail.h) + (%{workspace_root}/lib/isabelle/.gitignore as lib/isabelle/.gitignore) + (%{workspace_root}/lib/isabelle/Add_Cancel_Distinct.thy + as + lib/isabelle/Add_Cancel_Distinct.thy) + (%{workspace_root}/lib/isabelle/Hoare.thy as lib/isabelle/Hoare.thy) + (%{workspace_root}/lib/isabelle/Makefile as lib/isabelle/Makefile) + (%{workspace_root}/lib/isabelle/ROOT as lib/isabelle/ROOT) + (%{workspace_root}/lib/isabelle/Sail2_operators_mwords_lemmas.thy + as + lib/isabelle/Sail2_operators_mwords_lemmas.thy) + (%{workspace_root}/lib/isabelle/Sail2_prompt_monad_lemmas.thy + as + lib/isabelle/Sail2_prompt_monad_lemmas.thy) + (%{workspace_root}/lib/isabelle/Sail2_state_lemmas.thy + as + lib/isabelle/Sail2_state_lemmas.thy) + (%{workspace_root}/lib/isabelle/Sail2_state_monad_lemmas.thy + as + lib/isabelle/Sail2_state_monad_lemmas.thy) + (%{workspace_root}/lib/isabelle/Sail2_values_lemmas.thy + as + lib/isabelle/Sail2_values_lemmas.thy) + (%{workspace_root}/lib/isabelle/document/root.tex + as + lib/isabelle/document/root.tex) + (%{workspace_root}/lib/isabelle/manual/Manual.thy + as + lib/isabelle/manual/Manual.thy) + (%{workspace_root}/lib/isabelle/manual/ROOT as lib/isabelle/manual/ROOT) + (%{workspace_root}/lib/isabelle/manual/document/root.tex + as + lib/isabelle/manual/document/root.tex) + (%{workspace_root}/lib/isla.sail as lib/isla.sail) + (%{workspace_root}/lib/main.ml as lib/main.ml) + (%{workspace_root}/lib/mapping.sail as lib/mapping.sail) + (%{workspace_root}/lib/mono_rewrites.sail as lib/mono_rewrites.sail) + (%{workspace_root}/lib/myocamlbuild_coverage.ml + as + lib/myocamlbuild_coverage.ml) + (%{workspace_root}/lib/nostd/sail.c as lib/nostd/sail.c) + (%{workspace_root}/lib/nostd/sail.h as lib/nostd/sail.h) + (%{workspace_root}/lib/nostd/sail_alloc.h as lib/nostd/sail_alloc.h) + (%{workspace_root}/lib/nostd/sail_arena.c as lib/nostd/sail_arena.c) + (%{workspace_root}/lib/nostd/sail_arena.h as lib/nostd/sail_arena.h) + (%{workspace_root}/lib/nostd/sail_failure.h as lib/nostd/sail_failure.h) + (%{workspace_root}/lib/nostd/sail_spinlock.h as lib/nostd/sail_spinlock.h) + (%{workspace_root}/lib/nostd/stubs/sail_failure.c + as + lib/nostd/stubs/sail_failure.c) + (%{workspace_root}/lib/nostd/test/test.c as lib/nostd/test/test.c) + (%{workspace_root}/lib/option.sail as lib/option.sail) + (%{workspace_root}/lib/prelude.sail as lib/prelude.sail) + (%{workspace_root}/lib/real.sail as lib/real.sail) + (%{workspace_root}/lib/regfp.sail as lib/regfp.sail) + (%{workspace_root}/lib/result.sail as lib/result.sail) + (%{workspace_root}/lib/float.sail as lib/float.sail) + (%{workspace_root}/lib/reverse_endianness.sail + as + lib/reverse_endianness.sail) + (%{workspace_root}/lib/rts.c as lib/rts.c) + (%{workspace_root}/lib/rts.h as lib/rts.h) + (%{workspace_root}/lib/sail.c as lib/sail.c) + (%{workspace_root}/lib/sail.h as lib/sail.h) + (%{workspace_root}/lib/sail.tex as lib/sail.tex) + (%{workspace_root}/lib/sail_coverage.h as lib/sail_coverage.h) + (%{workspace_root}/lib/sail_failure.c as lib/sail_failure.c) + (%{workspace_root}/lib/sail_failure.h as lib/sail_failure.h) + (%{workspace_root}/lib/sail_state.h as lib/sail_state.h) + (%{workspace_root}/lib/smt.sail as lib/smt.sail) + (%{workspace_root}/lib/string.sail as lib/string.sail) + (%{workspace_root}/lib/trace.sail as lib/trace.sail) + (%{workspace_root}/lib/vector_dec.sail as lib/vector_dec.sail) + (%{workspace_root}/lib/vector_inc.sail as lib/vector_inc.sail) + (%{workspace_root}/src/gen_lib/sail2_deep_shallow_convert.lem + as + src/gen_lib/sail2_deep_shallow_convert.lem) + (%{workspace_root}/src/gen_lib/sail2_instr_kinds.lem + as + src/gen_lib/sail2_instr_kinds.lem) + (%{workspace_root}/src/gen_lib/sail2_operators.lem + as + src/gen_lib/sail2_operators.lem) + (%{workspace_root}/src/gen_lib/sail2_operators_bitlists.lem + as + src/gen_lib/sail2_operators_bitlists.lem) + (%{workspace_root}/src/gen_lib/sail2_operators_mwords.lem + as + src/gen_lib/sail2_operators_mwords.lem) + (%{workspace_root}/src/gen_lib/sail2_prompt.lem + as + src/gen_lib/sail2_prompt.lem) + (%{workspace_root}/src/gen_lib/sail2_prompt_monad.lem + as + src/gen_lib/sail2_prompt_monad.lem) + (%{workspace_root}/src/gen_lib/sail2_state.lem + as + src/gen_lib/sail2_state.lem) + (%{workspace_root}/src/gen_lib/sail2_state_lifting.lem + as + src/gen_lib/sail2_state_lifting.lem) + (%{workspace_root}/src/gen_lib/sail2_state_monad.lem + as + src/gen_lib/sail2_state_monad.lem) + (%{workspace_root}/src/gen_lib/sail2_string.lem + as + src/gen_lib/sail2_string.lem) + (%{workspace_root}/src/gen_lib/sail2_undefined.lem + as + src/gen_lib/sail2_undefined.lem) + (%{workspace_root}/src/gen_lib/sail2_values.lem + as + src/gen_lib/sail2_values.lem))) diff --git a/src/bin/repl.ml b/src/bin/repl.ml index e654a00cb..176082483 100644 --- a/src/bin/repl.ml +++ b/src/bin/repl.ml @@ -78,37 +78,32 @@ module Callgraph_commands = Callgraph_commands module Gdbmi = Gdbmi *) -type mode = - | Evaluation of frame - | Normal +type mode = Evaluation of frame | Normal type istate = { - ast : Type_check.tannot ast; - effect_info : Effects.side_effect_info; - env : Type_check.Env.t; - ref_state : Interactive.istate ref; - vs_ids : IdSet.t ref; - options : (Arg.key * Arg.spec * Arg.doc) list; - mode : mode; - clear : bool; - state : Interpreter.lstate * Interpreter.gstate; - default_sail_dir : string; - } - -let shrink_istate istate = ({ - ast = istate.ast; - effect_info = istate.effect_info; - env = istate.env; - default_sail_dir = istate.default_sail_dir; - } : Interactive.istate) - -let initial_istate options env effect_info ast = { - ast = ast; - effect_info = effect_info; - env = env; + ast : Type_check.tannot ast; + effect_info : Effects.side_effect_info; + env : Type_check.Env.t; + ref_state : Interactive.istate ref; + vs_ids : IdSet.t ref; + options : (Arg.key * Arg.spec * Arg.doc) list; + mode : mode; + clear : bool; + state : Interpreter.lstate * Interpreter.gstate; + default_sail_dir : string; +} + +let shrink_istate istate : Interactive.istate = + { ast = istate.ast; effect_info = istate.effect_info; env = istate.env; default_sail_dir = istate.default_sail_dir } + +let initial_istate options env effect_info ast = + { + ast; + effect_info; + env; ref_state = ref (Interactive.initial_istate Manifest.dir); vs_ids = ref (val_spec_ids ast.defs); - options = options; + options; mode = Normal; clear = true; state = initial_state ~registers:false empty_ast Type_check.initial_env !Value.primops; @@ -116,45 +111,37 @@ let initial_istate options env effect_info ast = { } let setup_interpreter_state istate = - istate.ref_state := { - ast = istate.ast; - effect_info = istate.effect_info; - env = istate.env; - default_sail_dir = istate.default_sail_dir - }; + istate.ref_state := + { ast = istate.ast; effect_info = istate.effect_info; env = istate.env; default_sail_dir = istate.default_sail_dir }; { istate with state = initial_state istate.ast istate.env !Value.primops } -let prompt istate = - match istate.mode with - | Normal -> "sail> " - | Evaluation _ -> "eval> " +let prompt istate = match istate.mode with Normal -> "sail> " | Evaluation _ -> "eval> " let mode_clear istate = - match istate.mode with - | Normal -> () - | Evaluation _ -> if istate.clear then LNoise.clear_screen () else () - + match istate.mode with Normal -> () | Evaluation _ -> if istate.clear then LNoise.clear_screen () else () + let rec user_input istate callback = match LNoise.linenoise (prompt istate) with | None -> () | Some line -> - mode_clear istate; - user_input (callback istate line) callback + mode_clear istate; + user_input (callback istate line) callback let sail_logo = let banner str = str |> Util.bold |> Util.red |> Util.clear in let logo = - [ {| ___ ___ ___ ___ |}; + [ + {| ___ ___ ___ ___ |}; {| /\ \ /\ \ /\ \ /\__\|}; {| /::\ \ /::\ \ _\:\ \ /:/ /|}; {| /\:\:\__\ /::\:\__\ /\/::\__\ /:/__/ |}; {| \:\:\/__/ \/\::/ / \::/\/__/ \:\ \ |}; {| \::/ / /:/ / \:\__\ \:\__\|}; - {| \/__/ \/__/ \/__/ \/__/|} ] + {| \/__/ \/__/ \/__/ \/__/|}; + ] in let help = - [ "Type :commands for a list of commands, and :help for help."; - "Type expressions to evaluate them." ] + ["Type :commands for a list of commands, and :help for help."; "Type expressions to evaluate them."] in List.map banner logo @ [""] @ help @ [""] @@ -165,14 +152,23 @@ let sep = "-----------------------------------------------------" |> Util.blue | let () = let open Interactive in let open Elf_loader in - ArgString ("file", fun file -> ActionUnit (fun _ -> load_elf file)) |> register_command ~name:"elf" ~help:"Load an elf file"; - ArgString ("addr", fun addr_s -> ArgString ("file", fun filename -> ActionUnit (fun _ -> - let addr = Big_int.of_string addr_s in - load_binary addr filename - ))) |> register_command ~name:"bin" ~help:"Load a raw binary file at :0. Use :elf to load an ELF"; + ArgString + ( "addr", + fun addr_s -> + ArgString + ( "file", + fun filename -> + ActionUnit + (fun _ -> + let addr = Big_int.of_string addr_s in + load_binary addr filename + ) + ) + ) + |> register_command ~name:"bin" ~help:"Load a raw binary file at :0. Use :elf to load an ELF"; ActionUnit (fun istate -> print_endline (Reporting.get_sail_dir istate.default_sail_dir)) |> register_command ~name:"sail_dir" ~help:"print Sail directory location" @@ -184,237 +180,207 @@ let setup_sail_scripting istate = let sail_command_name cmd = "sail_" ^ String.sub cmd 1 (String.length cmd - 1) in let cmds = Interactive.all_commands () in - + let val_specs = - List.map (fun (cmd, (_, action)) -> + List.map + (fun (cmd, (_, action)) -> let name = sail_command_name cmd in let typschm = mk_typschm (mk_typquant []) (Interactive.reflect_typ action) in mk_val_spec (VS_val_spec (typschm, mk_id name, Some { pure = false; bindings = [("_", name)] }, false)) - ) cmds in + ) + cmds + in let val_specs, env = Type_check.check_defs istate.env val_specs in - - List.iter (fun (cmd, (help, action)) -> + + List.iter + (fun (cmd, (help, action)) -> let open Value in let name = sail_command_name cmd in let impl values = let rec call values action = - match values, action with - | (v :: vs), Interactive.ArgString (_, next) -> - call vs (next (coerce_string v)) - | (v :: vs), Interactive.ArgInt (_, next) -> - call vs (next (Big_int.to_int (coerce_int v))) + match (values, action) with + | v :: vs, Interactive.ArgString (_, next) -> call vs (next (coerce_string v)) + | v :: vs, Interactive.ArgInt (_, next) -> call vs (next (Big_int.to_int (coerce_int v))) | _, ActionUnit act -> - act !(istate.ref_state); V_unit + act !(istate.ref_state); + V_unit | _, Action act -> - istate.ref_state := act !(istate.ref_state); - V_unit - | _, _ -> - failwith help + istate.ref_state := act !(istate.ref_state); + V_unit + | _, _ -> failwith help in call values action in Value.add_primop name impl - ) cmds; + ) + cmds; - { istate with ast = append_ast_defs istate.ast val_specs; env = env } + { istate with ast = append_ast_defs istate.ast val_specs; env } let print_program istate = match istate.mode with | Normal -> () | Evaluation (Step (out, _, _, stack)) - | Evaluation (Effect_request(out, _, stack, _)) - | Evaluation (Fail (out, _, _, stack, _)) -> - List.map stack_string stack |> List.rev |> List.iter (fun code -> print_endline (Lazy.force code); print_endline sep); - print_endline (Lazy.force out) - | Evaluation (Done (_, v)) -> - print_endline (Value.string_of_value v |> Util.green |> Util.clear) + | Evaluation (Effect_request (out, _, stack, _)) + | Evaluation (Fail (out, _, _, stack, _)) -> + List.map stack_string stack |> List.rev + |> List.iter (fun code -> + print_endline (Lazy.force code); + print_endline sep + ); + print_endline (Lazy.force out) + | Evaluation (Done (_, v)) -> print_endline (Value.string_of_value v |> Util.green |> Util.clear) | Evaluation _ -> () let rec run istate = match istate.mode with | Normal -> istate - | Evaluation frame -> - begin match frame with - | Done (state, v) -> - print_endline ("Result = " ^ Value.string_of_value v); - { istate with mode = Normal; state = state } - | Fail (_, _, _, _, msg) -> - print_endline ("Error: " ^ msg); - { istate with mode = Normal } - | Step _ -> - let istate = - try - { istate with mode = Evaluation (eval_frame frame) } - with - | Failure str -> - print_endline str; - { istate with mode = Normal } - in - run istate - | Break frame -> - print_endline "Breakpoint"; - { istate with mode = Evaluation frame } - | Effect_request (_, state, _, eff) -> - let istate = - try - { istate with mode = Evaluation (!Interpreter.effect_interp state eff) } - with - | Failure str -> - print_endline str; - { istate with mode = Normal } - in - run istate - end + | Evaluation frame -> begin + match frame with + | Done (state, v) -> + print_endline ("Result = " ^ Value.string_of_value v); + { istate with mode = Normal; state } + | Fail (_, _, _, _, msg) -> + print_endline ("Error: " ^ msg); + { istate with mode = Normal } + | Step _ -> + let istate = + try { istate with mode = Evaluation (eval_frame frame) } + with Failure str -> + print_endline str; + { istate with mode = Normal } + in + run istate + | Break frame -> + print_endline "Breakpoint"; + { istate with mode = Evaluation frame } + | Effect_request (_, state, _, eff) -> + let istate = + try { istate with mode = Evaluation (!Interpreter.effect_interp state eff) } + with Failure str -> + print_endline str; + { istate with mode = Normal } + in + run istate + end let rec run_function istate depth = let run_function' istate stack = match depth with | None -> run_function istate (Some (List.length stack)) - | Some n -> - if List.compare_length_with stack n >= 0 then - run_function istate depth - else - istate + | Some n -> if List.compare_length_with stack n >= 0 then run_function istate depth else istate in match istate.mode with | Normal -> istate - | Evaluation frame -> - begin match frame with - | Done (state, v) -> - print_endline ("Result = " ^ Value.string_of_value v); - { istate with mode = Normal; state = state } - | Fail (_, _, _, _, msg) -> - print_endline ("Error: " ^ msg); - { istate with mode = Normal } - | Step (_, _, _, stack) -> - let istate = - try - { istate with mode = Evaluation (eval_frame frame) } - with - | Failure str -> - print_endline str; - { istate with mode = Normal } - in - run_function' istate stack - | Break frame -> - print_endline "Breakpoint"; - { istate with mode = Evaluation frame } - | Effect_request (_, state, stack, eff) -> - let istate = - try - { istate with mode = Evaluation (!Interpreter.effect_interp state eff) } - with - | Failure str -> - print_endline str; - { istate with mode = Normal } - in - run_function' istate stack - end + | Evaluation frame -> begin + match frame with + | Done (state, v) -> + print_endline ("Result = " ^ Value.string_of_value v); + { istate with mode = Normal; state } + | Fail (_, _, _, _, msg) -> + print_endline ("Error: " ^ msg); + { istate with mode = Normal } + | Step (_, _, _, stack) -> + let istate = + try { istate with mode = Evaluation (eval_frame frame) } + with Failure str -> + print_endline str; + { istate with mode = Normal } + in + run_function' istate stack + | Break frame -> + print_endline "Breakpoint"; + { istate with mode = Evaluation frame } + | Effect_request (_, state, stack, eff) -> + let istate = + try { istate with mode = Evaluation (!Interpreter.effect_interp state eff) } + with Failure str -> + print_endline str; + { istate with mode = Normal } + in + run_function' istate stack + end let rec run_steps istate n = match istate.mode with | _ when n <= 0 -> istate | Normal -> istate - | Evaluation frame -> - begin match frame with - | Done (state, v) -> - print_endline ("Result = " ^ Value.string_of_value v); - { istate with mode = Normal; state = state } - | Fail (_, _, _, _, msg) -> - print_endline ("Error: " ^ msg); - { istate with mode = Normal } - | Step (_, _, _, _) -> - let istate = - try - { istate with mode = Evaluation (eval_frame frame) } - with - | Failure str -> - print_endline str; - { istate with mode = Normal } - in - run_steps istate (n - 1) - | Break frame -> - print_endline "Breakpoint"; - { istate with mode = Evaluation frame } - | Effect_request (_, state, _, eff) -> - let istate = - try - { istate with mode = Evaluation (!Interpreter.effect_interp state eff) } - with - | Failure str -> - print_endline str; - { istate with mode = Normal } - in - run_steps istate (n - 1) - end - + | Evaluation frame -> begin + match frame with + | Done (state, v) -> + print_endline ("Result = " ^ Value.string_of_value v); + { istate with mode = Normal; state } + | Fail (_, _, _, _, msg) -> + print_endline ("Error: " ^ msg); + { istate with mode = Normal } + | Step (_, _, _, _) -> + let istate = + try { istate with mode = Evaluation (eval_frame frame) } + with Failure str -> + print_endline str; + { istate with mode = Normal } + in + run_steps istate (n - 1) + | Break frame -> + print_endline "Breakpoint"; + { istate with mode = Evaluation frame } + | Effect_request (_, state, _, eff) -> + let istate = + try { istate with mode = Evaluation (!Interpreter.effect_interp state eff) } + with Failure str -> + print_endline str; + { istate with mode = Normal } + in + run_steps istate (n - 1) + end + let help = let open Printf in let open Util in let color c str = str |> c |> clear in function - | ":t" | ":type" -> - sprintf "(:t | :type) %s - Print the type of a function." - (color yellow "") - | ":q" | ":quit" -> - "(:q | :quit) - Exit the interpreter." - | ":i" | ":infer" -> - sprintf "(:i | :infer) %s - Infer the type of an expression." - (color yellow "") - | ":v" | ":verbose" -> - "(:v | :verbose) - Increase the verbosity level, or reset to zero at max verbosity." + | ":t" | ":type" -> sprintf "(:t | :type) %s - Print the type of a function." (color yellow "") + | ":q" | ":quit" -> "(:q | :quit) - Exit the interpreter." + | ":i" | ":infer" -> sprintf "(:i | :infer) %s - Infer the type of an expression." (color yellow "") + | ":v" | ":verbose" -> "(:v | :verbose) - Increase the verbosity level, or reset to zero at max verbosity." | ":b" | ":bind" -> - sprintf "(:b | :bind) %s : %s - Declare a variable of a specific type" - (color yellow "") (color yellow "") - | ":let" -> - sprintf ":let %s = %s - Bind a variable to expression" - (color yellow "") (color yellow "") - | ":def" -> - sprintf ":def %s - Evaluate a top-level definition" - (color yellow "") + sprintf "(:b | :bind) %s : %s - Declare a variable of a specific type" (color yellow "") + (color yellow "") + | ":let" -> sprintf ":let %s = %s - Bind a variable to expression" (color yellow "") (color yellow "") + | ":def" -> sprintf ":def %s - Evaluate a top-level definition" (color yellow "") | ":prove" -> - sprintf ":prove %s - Try to prove a constraint in the top-level environment" - (color yellow "") - | ":assume" -> - sprintf ":assume %s - Add a constraint to the top-level environment" - (color yellow "") - | ":commands" -> - ":commands - List all available commands." + sprintf ":prove %s - Try to prove a constraint in the top-level environment" (color yellow "") + | ":assume" -> sprintf ":assume %s - Add a constraint to the top-level environment" (color yellow "") + | ":commands" -> ":commands - List all available commands." | ":help" -> - sprintf ":help %s - Get a description of . Commands are prefixed with a colon, e.g. %s." - (color yellow "") (color green ":help :type") - | ":r" | ":run" -> - "(:r | :run) - Completely evaluate the currently evaluating expression." - | ":s" | ":step" -> - sprintf "(:s | :step) %s - Perform a number of evaluation steps." - (color yellow "") + sprintf ":help %s - Get a description of . Commands are prefixed with a colon, e.g. %s." + (color yellow "") (color green ":help :type") + | ":r" | ":run" -> "(:r | :run) - Completely evaluate the currently evaluating expression." + | ":s" | ":step" -> sprintf "(:s | :step) %s - Perform a number of evaluation steps." (color yellow "") | ":f" | ":step_function" -> - sprintf "(:f | :step_function) - Perform evaluation steps until the currently evaulating function returns." - | ":n" | ":normal" -> - "(:n | :normal) - Exit evaluation mode back to normal mode." + sprintf "(:f | :step_function) - Perform evaluation steps until the currently evaulating function returns." + | ":n" | ":normal" -> "(:n | :normal) - Exit evaluation mode back to normal mode." | ":clear" -> - sprintf ":clear %s - Set whether to clear the screen or not in evaluation mode." - (color yellow "(on|off)") - | ":output" -> - sprintf ":output %s - Redirect evaluating expression output to a file." - (color yellow "") + sprintf ":clear %s - Set whether to clear the screen or not in evaluation mode." (color yellow "(on|off)") + | ":output" -> sprintf ":output %s - Redirect evaluating expression output to a file." (color yellow "") | ":option" -> - sprintf ":option %s - Parse string as if it was an option passed on the command line. e.g. :option -help." - (color yellow "") + sprintf ":option %s - Parse string as if it was an option passed on the command line. e.g. :option -help." + (color yellow "") | ":recheck" -> - sprintf ":recheck - Re type-check the Sail AST, and synchronize the interpreters internal state to that AST." + sprintf ":recheck - Re type-check the Sail AST, and synchronize the interpreters internal state to that AST." | ":rewrite" -> - sprintf ":rewrite %s - Apply a rewrite to the AST. %s shows all possible rewrites. See also %s" - (color yellow " ") (color green ":list_rewrites") (color green ":rewrites") + sprintf ":rewrite %s - Apply a rewrite to the AST. %s shows all possible rewrites. See also %s" + (color yellow " ") (color green ":list_rewrites") (color green ":rewrites") | "" -> - sprintf "Type %s for a list of commands, and %s %s for information about a specific command" - (color green ":commands") (color green ":help") (color yellow "") - | cmd -> - match Interactive.get_command cmd with - | Some (help_message, action) -> Interactive.generate_help cmd help_message action - | None -> - sprintf "Either invalid command passed to help, or no documentation for %s. Try %s." - (color green cmd) (color green ":help :help") + sprintf "Type %s for a list of commands, and %s %s for information about a specific command" + (color green ":commands") (color green ":help") (color yellow "") + | cmd -> ( + match Interactive.get_command cmd with + | Some (help_message, action) -> Interactive.generate_help cmd help_message action + | None -> + sprintf "Either invalid command passed to help, or no documentation for %s. Try %s." (color green cmd) + (color green ":help :help") + ) type input = Command of string * string | Expression of string | Empty @@ -425,18 +391,17 @@ let handle_input' istate input = (* Process the input and check if it's a command, a raw expression, or empty. *) let input = - if input <> "" && input.[0] = ':' then + if input <> "" && input.[0] = ':' then ( let n = try String.index input ' ' with Not_found -> String.length input in let cmd = Str.string_before input n in let arg = String.trim (Str.string_after input n) in Command (cmd, arg) + ) else if String.length input >= 2 && input.[0] = '/' && input.[1] = '/' then (* Treat anything starting with // as a comment *) Empty - else if input <> "" then - Expression input - else - Empty + else if input <> "" then Expression input + else Empty in let recognised = ref true in @@ -447,333 +412,341 @@ let handle_input' istate input = in (* First handle commands that are mode-independent *) - let istate = match input with - | Command (cmd, arg) -> - begin match cmd with - | ":n" | ":normal" -> - { istate with mode = Normal } - | ":t" | ":type" -> - let typq, typ = Type_check.Env.get_val_spec (mk_id arg) istate.env in - pretty_sail stdout (doc_binding (typq, typ)); - print_newline (); - istate - | ":q" | ":quit" -> - Value.output_close (); - exit 0 - | ":i" | ":infer" -> - let exp = Initial_check.exp_of_string arg in - let exp = Type_check.infer_exp istate.env exp in - pretty_sail stdout (doc_typ (Type_check.typ_of exp)); - print_newline (); - istate - | ":prove" -> - let nc = Initial_check.constraint_of_string arg in - print_endline (string_of_bool (Type_check.prove __POS__ istate.env nc)); - istate - | ":assume" -> - let nc = Initial_check.constraint_of_string arg in - { istate with env = Type_check.Env.add_constraint nc istate.env } - | ":v" | ":verbose" -> - Type_check.opt_tc_debug := (!Type_check.opt_tc_debug + 1) mod 3; - print_endline ("Verbosity: " ^ string_of_int !Type_check.opt_tc_debug); - istate - | ":clear" -> - if arg = "on" || arg = "true" then - { istate with clear = true } - else if arg = "off" || arg = "false" then - { istate with clear = false } - else ( - print_endline "Invalid argument for :clear, expected either :clear on or :clear off"; + let istate = + match input with + | Command (cmd, arg) -> begin + match cmd with + | ":n" | ":normal" -> { istate with mode = Normal } + | ":t" | ":type" -> + let typq, typ = Type_check.Env.get_val_spec (mk_id arg) istate.env in + pretty_sail stdout (doc_binding (typq, typ)); + print_newline (); istate - ) - | ":commands" -> - let more_commands = Util.string_of_list " " fst (Interactive.all_commands ()) in - let commands = - [ "Universal commands - :(t)ype :(i)nfer :(q)uit :(v)erbose :prove :assume :clear :commands :help :output :option"; - "Normal mode commands - :elf :(l)oad :(u)nload :let :def :(b)ind :recheck :compile " ^ more_commands; - "Evaluation mode commands - :(r)un :(s)tep :step_(f)unction :(n)ormal"; - ""; - ":(c)ommand can be called as either :c or :command." ] - in - List.iter print_endline commands; - istate - | ":option" -> - begin - try - let args = Str.split (Str.regexp " +") arg in - begin match args with - | opt :: args -> - Arg.parse_argv ~current:(ref 0) (Array.of_list ["sail"; opt; String.concat " " args]) istate.options (fun _ -> ()) ""; - | [] -> print_endline "Must provide a valid option" - end - with - | Arg.Bad message | Arg.Help message -> print_endline message - end; - istate - (* + | ":q" | ":quit" -> + Value.output_close (); + exit 0 + | ":i" | ":infer" -> + let exp = Initial_check.exp_of_string arg in + let exp = Type_check.infer_exp istate.env exp in + pretty_sail stdout (doc_typ (Type_check.typ_of exp)); + print_newline (); + istate + | ":prove" -> + let nc = Initial_check.constraint_of_string arg in + print_endline (string_of_bool (Type_check.prove __POS__ istate.env nc)); + istate + | ":assume" -> + let nc = Initial_check.constraint_of_string arg in + { istate with env = Type_check.Env.add_constraint nc istate.env } + | ":v" | ":verbose" -> + Type_check.opt_tc_debug := (!Type_check.opt_tc_debug + 1) mod 3; + print_endline ("Verbosity: " ^ string_of_int !Type_check.opt_tc_debug); + istate + | ":clear" -> + if arg = "on" || arg = "true" then { istate with clear = true } + else if arg = "off" || arg = "false" then { istate with clear = false } + else ( + print_endline "Invalid argument for :clear, expected either :clear on or :clear off"; + istate + ) + | ":commands" -> + let more_commands = Util.string_of_list " " fst (Interactive.all_commands ()) in + let commands = + [ + "Universal commands - :(t)ype :(i)nfer :(q)uit :(v)erbose :prove :assume :clear :commands :help \ + :output :option"; + "Normal mode commands - :elf :(l)oad :(u)nload :let :def :(b)ind :recheck :compile " ^ more_commands; + "Evaluation mode commands - :(r)un :(s)tep :step_(f)unction :(n)ormal"; + ""; + ":(c)ommand can be called as either :c or :command."; + ] + in + List.iter print_endline commands; + istate + | ":option" -> + begin + try + let args = Str.split (Str.regexp " +") arg in + begin + match args with + | opt :: args -> + Arg.parse_argv ~current:(ref 0) + (Array.of_list ["sail"; opt; String.concat " " args]) + istate.options + (fun _ -> ()) + "" + | [] -> print_endline "Must provide a valid option" + end + with Arg.Bad message | Arg.Help message -> print_endline message + end; + istate + (* | ":pretty" -> print_endline (Pretty_print_sail.to_string (Latex.defs istate.ast)); istate *) - | ":ast" -> - let chan = open_out arg in - Pretty_print_sail.pp_ast chan (Type_check.strip_ast istate.ast); - close_out chan; - istate - | ":output" -> - let chan = open_out arg in - Value.output_redirect chan; - istate - | ":help" -> - print_endline (help arg); - istate - | _ -> - recognised := false; - istate - end + | ":ast" -> + let chan = open_out arg in + Pretty_print_sail.pp_ast chan (Type_check.strip_ast istate.ast); + close_out chan; + istate + | ":output" -> + let chan = open_out arg in + Value.output_redirect chan; + istate + | ":help" -> + print_endline (help arg); + istate + | _ -> + recognised := false; + istate + end | _ -> istate in match istate.mode with - | Normal -> - begin match input with - | Command (cmd, arg) -> - (* Normal mode commands *) - begin match cmd with - | ":b" | ":bind" -> - let args = Str.split (Str.regexp " +") arg in - begin match args with - | v :: ":" :: args -> - let typ = Initial_check.typ_of_string (String.concat " " args) in - let _, env, _ = Type_check.bind_pat istate.env (mk_pat (P_id (mk_id v))) typ in - { istate with env = env } - | _ -> - failwith "Invalid arguments for :bind"; - end - | ":let" -> - let args = Str.split (Str.regexp " +") arg in - begin match args with - | v :: "=" :: args -> - let exp = Initial_check.exp_of_string (String.concat " " args) in - let defs, env = Type_check.check_defs istate.env [mk_def (DEF_let (mk_letbind (mk_pat (P_id (mk_id v))) exp))] in - { istate with ast = append_ast_defs istate.ast defs; env = env } - | _ -> - failwith "Invalid arguments for :let"; - end - | ":def" -> - let ast = Initial_check.ast_of_def_string_with __POS__ (Preprocess.preprocess istate.default_sail_dir None istate.options) arg in - let ast, env = Type_check.check istate.env ast in - { istate with ast = append_ast istate.ast ast; env = env } - | ":rewrite" -> - let open Rewrites in - let args = Str.split (Str.regexp " +") arg in - let rec parse_args rw args = - match rw, args with - | Base_rewriter rw, [] -> rw - | Bool_rewriter rw, arg :: args -> parse_args (rw (bool_of_string arg)) args - | String_rewriter rw, arg :: args -> parse_args (rw arg) args - | Literal_rewriter rw, arg :: args -> - begin match arg with - | "ocaml" -> parse_args (rw rewrite_lit_ocaml) args - | "lem" -> parse_args (rw rewrite_lit_lem) args - | "all" -> parse_args (rw (fun _ -> true)) args - | _ -> failwith "Target for literal rewrite must be one of ocaml/lem/all" - end - | _, _ -> failwith "Invalid arguments to rewrite" - in - begin match args with - | rw :: args -> - let rw = List.assoc rw Rewrites.all_rewriters in - let rw = parse_args rw args in - let ast', effect_info', env' = rw istate.effect_info istate.env istate.ast in - { istate with ast = ast'; effect_info = effect_info'; env = env' } - | [] -> - failwith "Must provide the name of a rewrite, use :list_rewrites for a list of possible rewrites" - end - | ":sync_script" -> - { istate with ast = !(istate.ref_state).ast; effect_info = !(istate.ref_state).effect_info; env = !(istate.ref_state).env } - | ":recheck" | ":recheck_types" -> - let ast, env = Type_check.check Type_check.initial_env (Type_check.strip_ast istate.ast) in - { istate with env = env; ast = ast } - | _ -> - match Interactive.get_command cmd with - | Some (_, action) -> - let res = Interactive.run_action (shrink_istate istate) cmd arg action in - { istate with ast = res.ast; effect_info = res.effect_info; env = res.env } - | None -> - unrecognised_command istate cmd + | Normal -> begin + match input with + | Command (cmd, arg) -> begin + (* Normal mode commands *) + match cmd with + | ":b" | ":bind" -> + let args = Str.split (Str.regexp " +") arg in + begin + match args with + | v :: ":" :: args -> + let typ = Initial_check.typ_of_string (String.concat " " args) in + let _, env, _ = Type_check.bind_pat istate.env (mk_pat (P_id (mk_id v))) typ in + { istate with env } + | _ -> failwith "Invalid arguments for :bind" + end + | ":let" -> + let args = Str.split (Str.regexp " +") arg in + begin + match args with + | v :: "=" :: args -> + let exp = Initial_check.exp_of_string (String.concat " " args) in + let defs, env = + Type_check.check_defs istate.env [mk_def (DEF_let (mk_letbind (mk_pat (P_id (mk_id v))) exp))] + in + { istate with ast = append_ast_defs istate.ast defs; env } + | _ -> failwith "Invalid arguments for :let" + end + | ":def" -> + let ast = + Initial_check.ast_of_def_string_with __POS__ + (Preprocess.preprocess istate.default_sail_dir None istate.options) + arg + in + let ast, env = Type_check.check istate.env ast in + { istate with ast = append_ast istate.ast ast; env } + | ":rewrite" -> + let open Rewrites in + let args = Str.split (Str.regexp " +") arg in + let rec parse_args rw args = + match (rw, args) with + | Base_rewriter rw, [] -> rw + | Bool_rewriter rw, arg :: args -> parse_args (rw (bool_of_string arg)) args + | String_rewriter rw, arg :: args -> parse_args (rw arg) args + | Literal_rewriter rw, arg :: args -> begin + match arg with + | "ocaml" -> parse_args (rw rewrite_lit_ocaml) args + | "lem" -> parse_args (rw rewrite_lit_lem) args + | "all" -> parse_args (rw (fun _ -> true)) args + | _ -> failwith "Target for literal rewrite must be one of ocaml/lem/all" + end + | _, _ -> failwith "Invalid arguments to rewrite" + in + begin + match args with + | rw :: args -> + let rw = List.assoc rw Rewrites.all_rewriters in + let rw = parse_args rw args in + let ast', effect_info', env' = rw istate.effect_info istate.env istate.ast in + { istate with ast = ast'; effect_info = effect_info'; env = env' } + | [] -> + failwith "Must provide the name of a rewrite, use :list_rewrites for a list of possible rewrites" + end + | ":sync_script" -> + { + istate with + ast = !(istate.ref_state).ast; + effect_info = !(istate.ref_state).effect_info; + env = !(istate.ref_state).env; + } + | ":recheck" | ":recheck_types" -> + let ast, env = Type_check.check Type_check.initial_env (Type_check.strip_ast istate.ast) in + { istate with env; ast } + | _ -> ( + match Interactive.get_command cmd with + | Some (_, action) -> + let res = Interactive.run_action (shrink_istate istate) cmd arg action in + { istate with ast = res.ast; effect_info = res.effect_info; env = res.env } + | None -> unrecognised_command istate cmd + ) end - | Expression str -> - (* An expression in normal mode is type checked, then puts - us in evaluation mode. *) - let exp = Type_check.infer_exp istate.env (Initial_check.exp_of_string str) in - let istate = setup_interpreter_state istate in - let istate = { istate with mode = Evaluation (eval_frame (Step (lazy "", istate.state, return exp, []))) } in - print_program istate; - istate - | Empty -> istate - end - - | Evaluation frame -> - begin match input with - | Command (cmd, arg) -> - (* Evaluation mode commands *) - begin match cmd with - | ":r" | ":run" -> - run istate - | ":s" | ":step" -> - let istate = run_steps istate (int_of_string arg) in - print_program istate; - istate - | ":f" | ":step_function" -> - let istate = run_function istate None in - print_program istate; - istate - | _ -> unrecognised_command istate cmd + | Expression str -> + (* An expression in normal mode is type checked, then puts + us in evaluation mode. *) + let exp = Type_check.infer_exp istate.env (Initial_check.exp_of_string str) in + let istate = setup_interpreter_state istate in + let istate = { istate with mode = Evaluation (eval_frame (Step (lazy "", istate.state, return exp, []))) } in + print_program istate; + istate + | Empty -> istate + end + | Evaluation frame -> begin + match input with + | Command (cmd, arg) -> begin + (* Evaluation mode commands *) + match cmd with + | ":r" | ":run" -> run istate + | ":s" | ":step" -> + let istate = run_steps istate (int_of_string arg) in + print_program istate; + istate + | ":f" | ":step_function" -> + let istate = run_function istate None in + print_program istate; + istate + | _ -> unrecognised_command istate cmd end - | Expression _ -> - print_endline "Already evaluating expression"; - istate - | Empty -> - (* Empty input will evaluate one step, or switch back to - normal mode when evaluation is completed. *) - begin match frame with - | Done (state, v) -> - print_endline ("Result = " ^ Value.string_of_value v); - { istate with mode = Normal; state = state } - | Fail (_, _, _, _, msg) -> - print_endline ("Error: " ^ msg); - { istate with mode = Normal } - | Step (_, state, _, _) -> - begin - try - let istate = { istate with mode = Evaluation (eval_frame frame); state = state } in - print_program istate; - istate - with - | Failure str -> + | Expression _ -> + print_endline "Already evaluating expression"; + istate + | Empty -> begin + (* Empty input will evaluate one step, or switch back to + normal mode when evaluation is completed. *) + match frame with + | Done (state, v) -> + print_endline ("Result = " ^ Value.string_of_value v); + { istate with mode = Normal; state } + | Fail (_, _, _, _, msg) -> + print_endline ("Error: " ^ msg); + { istate with mode = Normal } + | Step (_, state, _, _) -> begin + try + let istate = { istate with mode = Evaluation (eval_frame frame); state } in + print_program istate; + istate + with Failure str -> print_endline str; { istate with mode = Normal } - end - | Break frame -> - print_endline "Breakpoint"; - { istate with mode = Evaluation frame } - | Effect_request (_, state, _, eff) -> - begin - try - let istate = { istate with mode = Evaluation (!Interpreter.effect_interp state eff); state = state } in - print_program istate; - istate - with - | Failure str -> + end + | Break frame -> + print_endline "Breakpoint"; + { istate with mode = Evaluation frame } + | Effect_request (_, state, _, eff) -> begin + try + let istate = { istate with mode = Evaluation (!Interpreter.effect_interp state eff); state } in + print_program istate; + istate + with Failure str -> print_endline str; { istate with mode = Normal } - end + end end - end + end let handle_input istate input = try handle_input' istate input with | Failure str -> - print_endline ("Error: " ^ str); - istate + print_endline ("Error: " ^ str); + istate | Type_check.Type_error (env, _, err) -> - print_endline (Type_error.string_of_type_error err); - { istate with env = env } + print_endline (Type_error.string_of_type_error err); + { istate with env } | Reporting.Fatal_error err -> - Reporting.print_error ~interactive:true err; - istate + Reporting.print_error ~interactive:true err; + istate | exn -> - print_endline (Printexc.to_string exn); - istate + print_endline (Printexc.to_string exn); + istate -let start_repl ?banner:(banner = true) ?commands:(script = []) ?auto_rewrites:(rewrites = true) ~options:options env effect_info ast = +let start_repl ?(banner = true) ?commands:(script = []) ?auto_rewrites:(rewrites = true) ~options env effect_info ast = let istate = if rewrites then ( - let ast, effect_info, env = Rewrites.rewrite effect_info env (Rewrites.instantiate_rewrites Rewrites.rewrites_interpreter) ast in - initial_istate options env effect_info ast - ) else ( + let ast, effect_info, env = + Rewrites.rewrite effect_info env (Rewrites.instantiate_rewrites Rewrites.rewrites_interpreter) ast + in initial_istate options env effect_info ast ) + else initial_istate options env effect_info ast in - LNoise.set_completion_callback ( - fun line_so_far ln_completions -> + LNoise.set_completion_callback (fun line_so_far ln_completions -> let line_so_far, last_id = try let p = Str.search_backward (Str.regexp "[^a-zA-Z0-9_/-]") line_so_far (String.length line_so_far - 1) in - Str.string_before line_so_far (p + 1), Str.string_after line_so_far (p + 1) + (Str.string_before line_so_far (p + 1), Str.string_after line_so_far (p + 1)) with - | Not_found -> "", line_so_far - | Invalid_argument _ -> line_so_far, "" + | Not_found -> ("", line_so_far) + | Invalid_argument _ -> (line_so_far, "") in let n = try String.index line_so_far ' ' with Not_found -> String.length line_so_far in let cmd = Str.string_before line_so_far n in - if last_id <> "" then - begin match cmd with + if last_id <> "" then begin + match cmd with | ":rewrite" -> - List.map fst Rewrites.all_rewriters - |> List.filter (fun opt -> Str.string_match (Str.regexp_string last_id) opt 0) - |> List.map (fun completion -> line_so_far ^ completion) - |> List.iter (LNoise.add_completion ln_completions) + List.map fst Rewrites.all_rewriters + |> List.filter (fun opt -> Str.string_match (Str.regexp_string last_id) opt 0) + |> List.map (fun completion -> line_so_far ^ completion) + |> List.iter (LNoise.add_completion ln_completions) | ":option" -> - List.map (fun (opt, _, _) -> opt) options - |> List.filter (fun opt -> Str.string_match (Str.regexp_string last_id) opt 0) - |> List.map (fun completion -> line_so_far ^ completion) - |> List.iter (LNoise.add_completion ln_completions) + List.map (fun (opt, _, _) -> opt) options + |> List.filter (fun opt -> Str.string_match (Str.regexp_string last_id) opt 0) + |> List.map (fun completion -> line_so_far ^ completion) + |> List.iter (LNoise.add_completion ln_completions) | _ -> - IdSet.elements !(istate.vs_ids) - |> List.map string_of_id - |> List.filter (fun id -> Str.string_match (Str.regexp_string last_id) id 0) - |> List.map (fun completion -> line_so_far ^ completion) - |> List.iter (LNoise.add_completion ln_completions) - end + IdSet.elements !(istate.vs_ids) |> List.map string_of_id + |> List.filter (fun id -> Str.string_match (Str.regexp_string last_id) id 0) + |> List.map (fun completion -> line_so_far ^ completion) + |> List.iter (LNoise.add_completion ln_completions) + end else () - ); + ); - LNoise.set_hints_callback ( - fun line_so_far -> + LNoise.set_hints_callback (fun line_so_far -> let hint str = Some (" " ^ str, LNoise.Yellow, false) in match String.trim line_so_far with | ":clear" -> hint "(on|off)" - | ":bind" | ":b" -> hint " : " + | ":bind" | ":b" -> hint " : " | ":infer" | ":i" -> hint "" - | ":type" | ":t" -> hint "" + | ":type" | ":t" -> hint "" | ":let" -> hint " = " | ":def" -> hint "" | ":prove" -> hint "" | ":assume" -> hint "" | ":compile" -> hint "" | ":rewrites" -> hint "" - | str -> - let args = Str.split (Str.regexp " +") str in - match args with - | [":rewrite"] -> hint "" - | ":rewrite" :: rw :: args -> - begin match List.assoc_opt rw Rewrites.all_rewriters with - | Some rw -> - let hints = Rewrites.describe_rewriter rw in - let hints = Util.drop (List.length args) hints in - (match hints with [] -> None | _ -> hint (String.concat " " hints)) - | None -> None + | str -> ( + let args = Str.split (Str.regexp " +") str in + match args with + | [":rewrite"] -> hint "" + | ":rewrite" :: rw :: args -> begin + match List.assoc_opt rw Rewrites.all_rewriters with + | Some rw -> ( + let hints = Rewrites.describe_rewriter rw in + let hints = Util.drop (List.length args) hints in + match hints with [] -> None | _ -> hint (String.concat " " hints) + ) + | None -> None end - | [":option"] -> hint "" - | [":option"; flag] -> - begin match List.find_opt (fun (opt, _, _) -> flag = opt) options with - | Some (_, _, help) -> hint (Str.global_replace (Str.regexp " +") " " help) - | None -> None + | [":option"] -> hint "" + | [":option"; flag] -> begin + match List.find_opt (fun (opt, _, _) -> flag = opt) options with + | Some (_, _, help) -> hint (Str.global_replace (Str.regexp " +") " " help) + | None -> None end - | _ -> None - ); + | _ -> None + ) + ); let istate = List.fold_left handle_input istate script in LNoise.history_load ~filename:"sail_history" |> ignore; LNoise.history_set ~max_length:100 |> ignore; - if banner then ( - List.iter print_endline sail_logo - ); + if banner then List.iter print_endline sail_logo; let istate = setup_sail_scripting istate in user_input istate handle_input - diff --git a/src/bin/repl.mli b/src/bin/repl.mli index 0c95ae748..b0dbca9a4 100644 --- a/src/bin/repl.mli +++ b/src/bin/repl.mli @@ -81,9 +81,9 @@ open Type_check *) val start_repl : ?banner:bool -> - ?commands:(string list) -> + ?commands:string list -> ?auto_rewrites:bool -> - options:((Arg.key * Arg.spec * Arg.doc) list) -> + options:(Arg.key * Arg.spec * Arg.doc) list -> Env.t -> Effects.side_effect_info -> tannot ast -> diff --git a/src/bin/sail.ml b/src/bin/sail.ml index bc1aff01b..30ae00347 100644 --- a/src/bin/sail.ml +++ b/src/bin/sail.ml @@ -66,7 +66,7 @@ (****************************************************************************) open Libsail - + let opt_file_arguments : string list ref = ref [] let opt_file_out : string option ref = ref None let opt_just_check : bool ref = ref false @@ -85,237 +85,221 @@ let opt_format_skip : string list ref = ref [] (* Allow calling all options as either -foo_bar or -foo-bar *) let rec fix_options = function - | (flag, spec, doc) :: opts -> (flag, spec, doc) :: (String.map (function '_' -> '-' | c -> c) flag, spec, "") :: fix_options opts + | (flag, spec, doc) :: opts -> + (flag, spec, doc) :: (String.map (function '_' -> '-' | c -> c) flag, spec, "") :: fix_options opts | [] -> [] let load_plugin opts plugin = try Dynlink.loadfile_private plugin; opts := Arg.align (!opts @ fix_options (Target.extract_options ())) - with - | Dynlink.Error msg -> - prerr_endline ("Failed to load plugin " ^ plugin ^ ": " ^ Dynlink.error_message msg) + with Dynlink.Error msg -> prerr_endline ("Failed to load plugin " ^ plugin ^ ": " ^ Dynlink.error_message msg) let version = let open Manifest in let default = Printf.sprintf "Sail %s @ %s" branch commit in (* version is parsed from the output of git describe *) match String.split_on_char '-' version with - | (vnum :: _) -> - Printf.sprintf "Sail %s (%s @ %s)" vnum branch commit + | vnum :: _ -> Printf.sprintf "Sail %s (%s @ %s)" vnum branch commit | _ -> default - -let usage_msg = - version - ^ "\nusage: sail ... \n" + +let usage_msg = version ^ "\nusage: sail ... \n" let help options = raise (Arg.Help (Arg.usage_string options usage_msg)) -let rec options = ref ([ - ( "-o", - Arg.String (fun f -> opt_file_out := Some f), - " select output filename prefix"); - ( "-dir", - Arg.Set opt_show_sail_dir, - " show current Sail library directory"); - ( "-i", - Arg.Tuple [Arg.Set Interactive.opt_interactive; - Arg.Set opt_auto_interpreter_rewrites; - Arg.Set Initial_check.opt_undefined_gen], - " start interactive interpreter"); - ( "-is", - Arg.Tuple [Arg.Set Interactive.opt_interactive; - Arg.Set opt_auto_interpreter_rewrites; - Arg.Set Initial_check.opt_undefined_gen; - Arg.String (fun s -> opt_interactive_script := Some s)], - " start interactive interpreter and execute commands in script"); - ( "-iout", - Arg.String (fun file -> Value.output_redirect (open_out file)), - " print interpreter output to file"); - ( "-interact_custom", - Arg.Set Interactive.opt_interactive, - " drop to an interactive session after running Sail. Differs from \ - -i in that it does not set up the interpreter in the interactive \ - shell."); - ( "-config", - Arg.String (fun file -> opt_config_file := Some file), - " configuration file"); - ( "-fmt", - Arg.Set opt_format, - " format input source code"); - ( "-fmt_backup", - Arg.String (fun suffix -> opt_format_backup := Some suffix), - " Create backups of formated files as 'file.suffix'"); - ( "-fmt_only", - Arg.String (fun file -> opt_format_only := file :: !opt_format_only), - " Format only this file"); - ( "-fmt_skip", - Arg.String (fun file -> opt_format_skip := file :: !opt_format_skip), - " Skip formatting this file"); - ( "-D", - Arg.String (fun symbol -> Preprocess.add_symbol symbol), - " define a symbol for the preprocessor, as $define does in the source code"); - ( "-no_warn", - Arg.Clear Reporting.opt_warnings, - " do not print warnings"); - ( "-plugin", - Arg.String (fun plugin -> load_plugin options plugin), - " load a Sail plugin"); - ( "-just_check", - Arg.Set opt_just_check, - " terminate immediately after typechecking"); - ( "-memo_z3", - Arg.Set opt_memo_z3, - " memoize calls to z3, improving performance when typechecking repeatedly"); - ( "-no_memo_z3", - Arg.Clear opt_memo_z3, - " do not memoize calls to z3 (default)"); - ( "-have_feature", - Arg.String (fun symbol -> opt_have_feature := Some symbol), - " check if a feature symbol is set by default"); - ( "-no_color", - Arg.Clear Util.opt_colors, - " do not use terminal color codes in output"); - ( "-undefined_gen", - Arg.Set Initial_check.opt_undefined_gen, - " generate undefined_type functions for types in the specification"); - ( "-grouped_regstate", - Arg.Set State.opt_type_grouped_regstate, - " group registers with same type together in generated register state record"); - ( "-enum_casts", - Arg.Set Initial_check.opt_enum_casts, - " allow enumerations to be automatically casted to numeric range types"); - ( "-non_lexical_flow", - Arg.Set Nl_flow.opt_nl_flow, - " allow non-lexical flow typing"); - ( "-no_lexp_bounds_check", - Arg.Set Type_check.opt_no_lexp_bounds_check, - " turn off bounds checking for vector assignments in l-expressions"); - ( "-auto_mono", - Arg.Set Rewrites.opt_auto_mono, - " automatically infer how to monomorphise code"); - ( "-mono_rewrites", - Arg.Set Rewrites.opt_mono_rewrites, - " turn on rewrites for combining bitvector operations"); - ( "-mono_split", - Arg.String (fun s -> - let l = Util.split_on_char ':' s in - match l with - | [fname;line;var] -> - Rewrites.opt_mono_split := ((fname,int_of_string line),var)::!Rewrites.opt_mono_split - | _ -> raise (Arg.Bad (s ^ " not of form ::"))), - ":: manually gives a case split for monomorphisation"); - ( "-splice", - Arg.String (fun s -> opt_splice := s :: !opt_splice), - " add functions from file, replacing existing definitions where necessary"); - ( "-smt_solver", - Arg.String (fun s -> Constraint.set_solver (String.trim s)), - " choose SMT solver. Supported solvers are z3 (default), alt-ergo, cvc4, mathsat, vampire and yices."); - ( "-smt_linearize", - Arg.Set Type_check.opt_smt_linearize, - "(experimental) force linearization for constraints involving exponentials"); - ( "-Oconstant_fold", - Arg.Set Constant_fold.optimize_constant_fold, - " apply constant folding optimizations"); - ( "-Oaarch64_fast", - Arg.Set Jib_compile.optimize_aarch64_fast_struct, - " apply ARMv8.5 specific optimizations (potentially unsound in general)"); - ( "-Ofast_undefined", - Arg.Set Initial_check.opt_fast_undefined, - " turn on fast-undefined mode"); - ( "-const_prop_mutrec", - Arg.String (fun name -> Constant_propagation_mutrec.targets := Ast_util.mk_id name :: !Constant_propagation_mutrec.targets), - " unroll function in a set of mutually recursive functions"); - ( "-ddump_initial_ast", - Arg.Set Frontend.opt_ddump_initial_ast, - " (debug) dump the initial ast to stdout"); - ( "-ddump_tc_ast", - Arg.Set Frontend.opt_ddump_tc_ast, - " (debug) dump the typechecked ast to stdout"); - ( "-dtc_verbose", - Arg.Int (fun verbosity -> Type_check.opt_tc_debug := verbosity), - " (debug) verbose typechecker output: 0 is silent"); - ( "-dsmt_verbose", - Arg.Set Constraint.opt_smt_verbose, - " (debug) print SMTLIB constraints sent to SMT solver"); - ( "-dmagic_hash", - Arg.Set Initial_check.opt_magic_hash, - " (debug) allow special character # in identifiers"); - ( "-dprofile", - Arg.Set Profile.opt_profile, - " (debug) provide basic profiling information for rewriting passes within Sail"); - ( "-dno_cast", - Arg.Set Frontend.opt_dno_cast, - " (debug) typecheck without any implicit casting"); - ( "-dallow_cast", - Arg.Tuple [ - Arg.Unit (fun () -> Reporting.simple_warn "-dallow_cast option is deprecated"); - Arg.Clear Frontend.opt_dno_cast - ], - " (debug) typecheck allowing implicit casting (deprecated)"); - ( "-ddump_rewrite_ast", - Arg.String (fun l -> Rewrites.opt_ddump_rewrite_ast := Some (l, 0); Specialize.opt_ddump_spec_ast := Some (l, 0)), - " (debug) dump the ast after each rewriting step to _.lem"); - ( "-dmono_all_split_errors", - Arg.Set Rewrites.opt_dall_split_errors, - " (debug) display all case split errors from monomorphisation, rather than one"); - ( "-dmono_analysis", - Arg.Set_int Rewrites.opt_dmono_analysis, - " (debug) dump information about monomorphisation analysis: 0 silent, 3 max"); - ( "-dmono_continue", - Arg.Set Rewrites.opt_dmono_continue, - " (debug) continue despite monomorphisation errors"); - ( "-dpattern_warning_no_literals", - Arg.Set Pattern_completeness.opt_debug_no_literals, - ""); - ( "-infer_effects", - Arg.Unit (fun () -> Reporting.simple_warn "-infer_effects option is deprecated"), - " Ignored for compatibility with older versions; effects are always inferred now (deprecated)"); - ( "-dbacktrace", - Arg.Int (fun l -> Reporting.opt_backtrace_length := l), - " Length of backtrace to show when reporting unreachable code"); - ( "-v", - Arg.Set opt_print_version, - " print version"); - ( "-version", - Arg.Set opt_print_version, - " print version"); - ( "-verbose", - Arg.Int (fun verbosity -> Util.opt_verbosity := verbosity), - " produce verbose output"); - ( "-explain_all_variables", - Arg.Set Type_error.opt_explain_all_variables, - " Explain all type variables in type error messages"); - ( "-explain_constraints", - Arg.Set Type_error.opt_explain_constraints, - " Explain all type variables in type error messages"); - ( "-explain_verbose", - Arg.Tuple [ - Arg.Set Type_error.opt_explain_all_variables; - Arg.Set Type_error.opt_explain_constraints - ], - " Add the maximum amount of explanation to type errors"); - ( "-help", - Arg.Unit (fun () -> help !options), - " Display this list of options. Also available as -h or --help"); - ( "-h", Arg.Unit (fun () -> help !options), ""); - ( "--help", Arg.Unit (fun () -> help !options), ""); -]) - -let register_default_target () = - Target.register ~name:"default" (fun _ _ _ _ _ -> ()) - +let rec options = + ref + [ + ("-o", Arg.String (fun f -> opt_file_out := Some f), " select output filename prefix"); + ("-dir", Arg.Set opt_show_sail_dir, " show current Sail library directory"); + ( "-i", + Arg.Tuple + [ + Arg.Set Interactive.opt_interactive; + Arg.Set opt_auto_interpreter_rewrites; + Arg.Set Initial_check.opt_undefined_gen; + ], + " start interactive interpreter" + ); + ( "-is", + Arg.Tuple + [ + Arg.Set Interactive.opt_interactive; + Arg.Set opt_auto_interpreter_rewrites; + Arg.Set Initial_check.opt_undefined_gen; + Arg.String (fun s -> opt_interactive_script := Some s); + ], + " start interactive interpreter and execute commands in script" + ); + ( "-iout", + Arg.String (fun file -> Value.output_redirect (open_out file)), + " print interpreter output to file" + ); + ( "-interact_custom", + Arg.Set Interactive.opt_interactive, + " drop to an interactive session after running Sail. Differs from -i in that it does not set up the \ + interpreter in the interactive shell." + ); + ("-config", Arg.String (fun file -> opt_config_file := Some file), " configuration file"); + ("-fmt", Arg.Set opt_format, " format input source code"); + ( "-fmt_backup", + Arg.String (fun suffix -> opt_format_backup := Some suffix), + " Create backups of formated files as 'file.suffix'" + ); + ("-fmt_only", Arg.String (fun file -> opt_format_only := file :: !opt_format_only), " Format only this file"); + ( "-fmt_skip", + Arg.String (fun file -> opt_format_skip := file :: !opt_format_skip), + " Skip formatting this file" + ); + ( "-D", + Arg.String (fun symbol -> Preprocess.add_symbol symbol), + " define a symbol for the preprocessor, as $define does in the source code" + ); + ("-no_warn", Arg.Clear Reporting.opt_warnings, " do not print warnings"); + ("-plugin", Arg.String (fun plugin -> load_plugin options plugin), " load a Sail plugin"); + ("-just_check", Arg.Set opt_just_check, " terminate immediately after typechecking"); + ("-memo_z3", Arg.Set opt_memo_z3, " memoize calls to z3, improving performance when typechecking repeatedly"); + ("-no_memo_z3", Arg.Clear opt_memo_z3, " do not memoize calls to z3 (default)"); + ( "-have_feature", + Arg.String (fun symbol -> opt_have_feature := Some symbol), + " check if a feature symbol is set by default" + ); + ("-no_color", Arg.Clear Util.opt_colors, " do not use terminal color codes in output"); + ( "-undefined_gen", + Arg.Set Initial_check.opt_undefined_gen, + " generate undefined_type functions for types in the specification" + ); + ( "-grouped_regstate", + Arg.Set State.opt_type_grouped_regstate, + " group registers with same type together in generated register state record" + ); + ( "-enum_casts", + Arg.Set Initial_check.opt_enum_casts, + " allow enumerations to be automatically casted to numeric range types" + ); + ("-non_lexical_flow", Arg.Set Nl_flow.opt_nl_flow, " allow non-lexical flow typing"); + ( "-no_lexp_bounds_check", + Arg.Set Type_check.opt_no_lexp_bounds_check, + " turn off bounds checking for vector assignments in l-expressions" + ); + ("-auto_mono", Arg.Set Rewrites.opt_auto_mono, " automatically infer how to monomorphise code"); + ("-mono_rewrites", Arg.Set Rewrites.opt_mono_rewrites, " turn on rewrites for combining bitvector operations"); + ( "-mono_split", + Arg.String + (fun s -> + let l = Util.split_on_char ':' s in + match l with + | [fname; line; var] -> + Rewrites.opt_mono_split := ((fname, int_of_string line), var) :: !Rewrites.opt_mono_split + | _ -> raise (Arg.Bad (s ^ " not of form ::")) + ), + ":: manually gives a case split for monomorphisation" + ); + ( "-splice", + Arg.String (fun s -> opt_splice := s :: !opt_splice), + " add functions from file, replacing existing definitions where necessary" + ); + ( "-smt_solver", + Arg.String (fun s -> Constraint.set_solver (String.trim s)), + " choose SMT solver. Supported solvers are z3 (default), alt-ergo, cvc4, mathsat, vampire and yices." + ); + ( "-smt_linearize", + Arg.Set Type_check.opt_smt_linearize, + "(experimental) force linearization for constraints involving exponentials" + ); + ("-Oconstant_fold", Arg.Set Constant_fold.optimize_constant_fold, " apply constant folding optimizations"); + ( "-Oaarch64_fast", + Arg.Set Jib_compile.optimize_aarch64_fast_struct, + " apply ARMv8.5 specific optimizations (potentially unsound in general)" + ); + ("-Ofast_undefined", Arg.Set Initial_check.opt_fast_undefined, " turn on fast-undefined mode"); + ( "-const_prop_mutrec", + Arg.String + (fun name -> + Constant_propagation_mutrec.targets := Ast_util.mk_id name :: !Constant_propagation_mutrec.targets + ), + " unroll function in a set of mutually recursive functions" + ); + ("-ddump_initial_ast", Arg.Set Frontend.opt_ddump_initial_ast, " (debug) dump the initial ast to stdout"); + ("-ddump_tc_ast", Arg.Set Frontend.opt_ddump_tc_ast, " (debug) dump the typechecked ast to stdout"); + ( "-dtc_verbose", + Arg.Int (fun verbosity -> Type_check.opt_tc_debug := verbosity), + " (debug) verbose typechecker output: 0 is silent" + ); + ("-dsmt_verbose", Arg.Set Constraint.opt_smt_verbose, " (debug) print SMTLIB constraints sent to SMT solver"); + ("-dmagic_hash", Arg.Set Initial_check.opt_magic_hash, " (debug) allow special character # in identifiers"); + ( "-dprofile", + Arg.Set Profile.opt_profile, + " (debug) provide basic profiling information for rewriting passes within Sail" + ); + ("-dno_cast", Arg.Set Frontend.opt_dno_cast, " (debug) typecheck without any implicit casting"); + ( "-dallow_cast", + Arg.Tuple + [ + Arg.Unit (fun () -> Reporting.simple_warn "-dallow_cast option is deprecated"); + Arg.Clear Frontend.opt_dno_cast; + ], + " (debug) typecheck allowing implicit casting (deprecated)" + ); + ( "-ddump_rewrite_ast", + Arg.String + (fun l -> + Rewrites.opt_ddump_rewrite_ast := Some (l, 0); + Specialize.opt_ddump_spec_ast := Some (l, 0) + ), + " (debug) dump the ast after each rewriting step to _.lem" + ); + ( "-dmono_all_split_errors", + Arg.Set Rewrites.opt_dall_split_errors, + " (debug) display all case split errors from monomorphisation, rather than one" + ); + ( "-dmono_analysis", + Arg.Set_int Rewrites.opt_dmono_analysis, + " (debug) dump information about monomorphisation analysis: 0 silent, 3 max" + ); + ("-dmono_continue", Arg.Set Rewrites.opt_dmono_continue, " (debug) continue despite monomorphisation errors"); + ("-dpattern_warning_no_literals", Arg.Set Pattern_completeness.opt_debug_no_literals, ""); + ( "-infer_effects", + Arg.Unit (fun () -> Reporting.simple_warn "-infer_effects option is deprecated"), + " Ignored for compatibility with older versions; effects are always inferred now (deprecated)" + ); + ( "-dbacktrace", + Arg.Int (fun l -> Reporting.opt_backtrace_length := l), + " Length of backtrace to show when reporting unreachable code" + ); + ("-v", Arg.Set opt_print_version, " print version"); + ("-version", Arg.Set opt_print_version, " print version"); + ("-verbose", Arg.Int (fun verbosity -> Util.opt_verbosity := verbosity), " produce verbose output"); + ( "-explain_all_variables", + Arg.Set Type_error.opt_explain_all_variables, + " Explain all type variables in type error messages" + ); + ( "-explain_constraints", + Arg.Set Type_error.opt_explain_constraints, + " Explain all type variables in type error messages" + ); + ( "-explain_verbose", + Arg.Tuple [Arg.Set Type_error.opt_explain_all_variables; Arg.Set Type_error.opt_explain_constraints], + " Add the maximum amount of explanation to type errors" + ); + ("-help", Arg.Unit (fun () -> help !options), " Display this list of options. Also available as -h or --help"); + ("-h", Arg.Unit (fun () -> help !options), ""); + ("--help", Arg.Unit (fun () -> help !options), ""); + ] + +let register_default_target () = Target.register ~name:"default" (fun _ _ _ _ _ -> ()) + let run_sail tgt = Target.run_pre_parse_hook tgt (); - let ast, env, effect_info = Frontend.load_files ~target:tgt Manifest.dir !options Type_check.initial_env !opt_file_arguments in - let ast, env = Frontend.initial_rewrite effect_info env ast in - let ast, env = - List.fold_right (fun file (ast, _) -> Splice.splice ast file) - (!opt_splice) (ast, env) + let ast, env, effect_info = + Frontend.load_files ~target:tgt Manifest.dir !options Type_check.initial_env !opt_file_arguments in + let ast, env = Frontend.initial_rewrite effect_info env ast in + let ast, env = List.fold_right (fun file (ast, _) -> Splice.splice ast file) !opt_splice (ast, env) in let effect_info = Effects.infer_side_effects (Target.asserts_termination tgt) ast in - Reporting.opt_warnings := false; (* Don't show warnings during re-writing for now *) + Reporting.opt_warnings := false; + (* Don't show warnings during re-writing for now *) Target.run_pre_rewrites_hook tgt ast effect_info env; let ast, effect_info, env = Rewrites.rewrite effect_info env (Target.rewrites tgt) ast in @@ -339,52 +323,41 @@ let file_to_string filename = Buffer.contents buf let run_sail_format (config : Yojson.Basic.t option) = - let is_format_file f = match !opt_format_only with - | [] -> true - | files -> List.exists (fun f' -> f = f') files - in - let is_skipped_file f = match !opt_format_skip with - | [] -> false - | files -> List.exists (fun f' -> f = f') files - in + let is_format_file f = match !opt_format_only with [] -> true | files -> List.exists (fun f' -> f = f') files in + let is_skipped_file f = match !opt_format_skip with [] -> false | files -> List.exists (fun f' -> f = f') files in let module Config = struct - let config = match config with - | Some (`Assoc keys) -> - List.assoc_opt "fmt" keys - |> Option.map Format_sail.config_from_json - |> Option.value ~default:Format_sail.default_config - | Some _ -> - raise (Reporting.err_general Parse_ast.Unknown "Invalid configuration file (must be a json object)") - | None -> - Format_sail.default_config - end in - let module Formatter = Format_sail.Make(Config) in + let config = + match config with + | Some (`Assoc keys) -> + List.assoc_opt "fmt" keys |> Option.map Format_sail.config_from_json + |> Option.value ~default:Format_sail.default_config + | Some _ -> raise (Reporting.err_general Parse_ast.Unknown "Invalid configuration file (must be a json object)") + | None -> Format_sail.default_config + end in + let module Formatter = Format_sail.Make (Config) in let parsed_files = List.map (fun f -> (f, Initial_check.parse_file f)) !opt_file_arguments in - List.iter (fun (f, (comments, parse_ast)) -> + List.iter + (fun (f, (comments, parse_ast)) -> let source = file_to_string f in if is_format_file f && not (is_skipped_file f) then ( let formatted = Formatter.format_defs f source comments parse_ast in - begin match !opt_format_backup with - | Some backup_file -> - let out_chan = open_out backup_file in - output_string out_chan source; - close_out out_chan - | None -> () + begin + match !opt_format_backup with + | Some backup_file -> + let out_chan = open_out backup_file in + output_string out_chan source; + close_out out_chan + | None -> () end; let ((out_chan, _, _, _) as file_info) = Util.open_output_with_check_unformatted None f in output_string out_chan formatted; Util.close_output_with_check file_info ) - ) parsed_files - + ) + parsed_files + let feature_check () = - match !opt_have_feature with - | None -> () - | Some symbol -> - if Preprocess.have_symbol symbol then - exit 0 - else - exit 2 + match !opt_have_feature with None -> () | Some symbol -> if Preprocess.have_symbol symbol then exit 0 else exit 2 let get_plugin_dir () = match Sys.getenv_opt "SAIL_PLUGIN_DIR" with @@ -394,66 +367,59 @@ let get_plugin_dir () = let rec find_file_above ?prev_inode_opt dir file = try let inode = (Unix.stat dir).st_ino in - if Option.fold ~none:true ~some:((<>) inode) prev_inode_opt then ( + if Option.fold ~none:true ~some:(( <> ) inode) prev_inode_opt then ( let filepath = Filename.concat dir file in - if Sys.file_exists filepath then ( - Some filepath - ) else ( - find_file_above ~prev_inode_opt:inode (dir ^ Filename.dir_sep ^ Filename.parent_dir_name) file - ) - ) else ( - None + if Sys.file_exists filepath then Some filepath + else find_file_above ~prev_inode_opt:inode (dir ^ Filename.dir_sep ^ Filename.parent_dir_name) file ) - with - | Unix.Unix_error _ -> None + else None + with Unix.Unix_error _ -> None let get_config_file () = let check_exists file = - if Sys.file_exists file then ( - Some file - ) else ( + if Sys.file_exists file then Some file + else ( Reporting.warn "" Parse_ast.Unknown (Printf.sprintf "Configuration file %s does not exist" file); None - ) in + ) + in match !opt_config_file with | Some file -> check_exists file - | None -> - match Sys.getenv_opt "SAIL_CONFIG" with - | Some file -> check_exists file - | None -> find_file_above (Sys.getcwd ()) "sail_config.json" + | None -> ( + match Sys.getenv_opt "SAIL_CONFIG" with + | Some file -> check_exists file + | None -> find_file_above (Sys.getcwd ()) "sail_config.json" + ) let parse_config_file file = - try - Some (Yojson.Basic.from_file ~fname:file ~lnum:0 file) - with - | Yojson.Json_error message -> - Reporting.warn "" Parse_ast.Unknown (Printf.sprintf "Failed to parse configuration file: %s" message); - None - + try Some (Yojson.Basic.from_file ~fname:file ~lnum:0 file) + with Yojson.Json_error message -> + Reporting.warn "" Parse_ast.Unknown (Printf.sprintf "Failed to parse configuration file: %s" message); + None + let main () = - begin match Sys.getenv_opt "SAIL_NO_PLUGINS" with - | Some _ -> () - | None -> - match get_plugin_dir () with - | dir :: _ -> - List.iter - (fun plugin -> - let path = Filename.concat dir plugin in - if Filename.extension plugin = ".cmxs" then - load_plugin options path) - (Array.to_list (Sys.readdir dir)) - | [] -> () + begin + match Sys.getenv_opt "SAIL_NO_PLUGINS" with + | Some _ -> () + | None -> ( + match get_plugin_dir () with + | dir :: _ -> + List.iter + (fun plugin -> + let path = Filename.concat dir plugin in + if Filename.extension plugin = ".cmxs" then load_plugin options path + ) + (Array.to_list (Sys.readdir dir)) + | [] -> () + ) end; options := Arg.align !options; - - Arg.parse_dynamic options - (fun s -> - opt_file_arguments := (!opt_file_arguments) @ [s]) - usage_msg; + + Arg.parse_dynamic options (fun s -> opt_file_arguments := !opt_file_arguments @ [s]) usage_msg; let config = Option.bind (get_config_file ()) parse_config_file in - + feature_check (); if !opt_print_version then ( @@ -470,12 +436,13 @@ let main () = run_sail_format config; exit 0 ); - + let default_target = register_default_target () in if !opt_memo_z3 then Constraint.load_digests (); - - let ast, env, effect_info = match Target.get_the_target () with + + let ast, env, effect_info = + match Target.get_the_target () with | Some target when not !opt_just_check -> run_sail target | _ -> run_sail default_target in @@ -483,29 +450,27 @@ let main () = if !opt_memo_z3 then Constraint.save_digests (); if !Interactive.opt_interactive then ( - let script = match !opt_interactive_script with + let script = + match !opt_interactive_script with | None -> [] - | Some file -> - let chan = open_in file in - let lines = ref [] in - try - while true do - let line = input_line chan in - lines := line :: !lines - done; - [] - with - | End_of_file -> List.rev !lines + | Some file -> ( + let chan = open_in file in + let lines = ref [] in + try + while true do + let line = input_line chan in + lines := line :: !lines + done; + [] + with End_of_file -> List.rev !lines + ) in - Repl.start_repl ~commands:script ~auto_rewrites:(!opt_auto_interpreter_rewrites) ~options:!(options) env effect_info ast + Repl.start_repl ~commands:script ~auto_rewrites:!opt_auto_interpreter_rewrites ~options:!options env effect_info ast ) let () = - try ( - try main () - with Failure s -> raise (Reporting.err_general Parse_ast.Unknown s) - ) with - | Reporting.Fatal_error e -> - Reporting.print_error e; - if !opt_memo_z3 then Constraint.save_digests () else (); - exit 1 + try try main () with Failure s -> raise (Reporting.err_general Parse_ast.Unknown s) + with Reporting.Fatal_error e -> + Reporting.print_error e; + if !opt_memo_z3 then Constraint.save_digests () else (); + exit 1 diff --git a/src/lem_interp/pretty_interp.ml b/src/lem_interp/pretty_interp.ml index d524f4c82..692bb2e32 100644 --- a/src/lem_interp/pretty_interp.ml +++ b/src/lem_interp/pretty_interp.ml @@ -91,10 +91,7 @@ let ignore_casts = ref true let zero_big = of_int 0 let one_big = of_int 1 -let pp_format_id (Id_aux(i,_)) = - match i with - | Id(i) -> i - | DeIid(x) -> "(deinfix " ^ x ^ ")" +let pp_format_id (Id_aux (i, _)) = match i with Id i -> i | DeIid x -> "(deinfix " ^ x ^ ")" let lit_to_string = function | L_unit -> "unit" @@ -103,54 +100,55 @@ let lit_to_string = function | L_true -> "true" | L_false -> "false" | L_num n -> Nat_big_num.to_string n - | L_hex s -> "0x"^s - | L_bin s -> "0b"^s + | L_hex s -> "0x" ^ s + | L_bin s -> "0b" ^ s | L_undef -> "undefined" | L_string s -> "\"" ^ s ^ "\"" -;; -let id_to_string = function - | Id_aux(Id s,_) | Id_aux(DeIid s,_) -> s -;; +let id_to_string = function Id_aux (Id s, _) | Id_aux (DeIid s, _) -> s let rec loc_to_string = function | Unknown -> "location unknown" - | Int(s,_) -> s + | Int (s, _) -> s | Generated l -> "Generated near " ^ loc_to_string l - | Range(s,fline,fchar,tline,tchar) -> - if fline = tline - then sprintf "%s:%d:%d" s fline fchar - else sprintf "%s:%d:%d-%d:%d" s fline fchar tline tchar -;; - + | Range (s, fline, fchar, tline, tchar) -> + if fline = tline then sprintf "%s:%d:%d" s fline fchar else sprintf "%s:%d:%d-%d:%d" s fline fchar tline tchar + let collapse_leading s = - if String.length s <= 8 then s else - let first_bit = s.[0] in - let templ = sprintf "%c...%c" first_bit first_bit in - - let rec find_first_diff str cha pos = - if pos >= String.length str then None - else if str.[pos] != cha then Some pos - else find_first_diff str cha (pos+1) - in + if String.length s <= 8 then s + else ( + let first_bit = s.[0] in + let templ = sprintf "%c...%c" first_bit first_bit in - match find_first_diff s first_bit 0 with - | None -> templ - | Some pos when pos > 4 -> templ ^ (String.sub s pos ((String.length s)- pos)) - | _ -> s -;; + let rec find_first_diff str cha pos = + if pos >= String.length str then None else if str.[pos] != cha then Some pos else find_first_diff str cha (pos + 1) + in -(* pp the bytes of a Bytevector as a hex value *) + match find_first_diff s first_bit 0 with + | None -> templ + | Some pos when pos > 4 -> templ ^ String.sub s pos (String.length s - pos) + | _ -> s + ) -let bitvec_to_string l = "0b" ^ collapse_leading (String.concat "" (List.map (function - | Interp_ast.V_lit(L_aux(L_zero, _)) -> "0" - | Interp_ast.V_lit(L_aux(L_one, _)) -> "1" - | Interp_ast.V_lit(L_aux(L_undef, _)) -> "u" - | Interp_ast.V_unknown -> "?" - | v -> (Printf.printf "bitvec found a non bit %s%!\n" (Interp.string_of_value v));assert false) - (List.map Interp.detaint l))) - ;; +(* pp the bytes of a Bytevector as a hex value *) +let bitvec_to_string l = + "0b" + ^ collapse_leading + (String.concat "" + (List.map + (function + | Interp_ast.V_lit (L_aux (L_zero, _)) -> "0" + | Interp_ast.V_lit (L_aux (L_one, _)) -> "1" + | Interp_ast.V_lit (L_aux (L_undef, _)) -> "u" + | Interp_ast.V_unknown -> "?" + | v -> + Printf.printf "bitvec found a non bit %s%!\n" (Interp.string_of_value v); + assert false + ) + (List.map Interp.detaint l) + ) + ) (**************************************************************************** * PPrint-based source-to-source pretty printer @@ -158,7 +156,7 @@ let bitvec_to_string l = "0b" ^ collapse_leading (String.concat "" (List.map (fu open PPrint -let doc_id (Id_aux(i,_)) = +let doc_id (Id_aux (i, _)) = match i with | Id "0" -> string "\x1b[1;31m[_]\x1b[m" (* internal representation of a hole *) | Id i -> string i @@ -167,15 +165,11 @@ let doc_id (Id_aux(i,_)) = * token in case of x ending with star. *) parens (separate space [string "deinfix"; string x; empty]) -let doc_var (Kid_aux(Var v,_)) = string v +let doc_var (Kid_aux (Var v, _)) = string v let doc_int i = string (to_string i) -let doc_bkind (BK_aux(k,_)) = - string (match k with - | BK_type -> "Type" - | BK_int -> "Int" - | BK_order -> "Order") +let doc_bkind (BK_aux (k, _)) = string (match k with BK_type -> "Type" | BK_int -> "Int" | BK_order -> "Order") let doc_op symb a b = infix 2 1 symb a b let doc_unop symb a = prefix 2 1 symb a @@ -194,541 +188,591 @@ let semi_sp = semi ^^ space let comma_sp = comma ^^ space let colon_sp = spaces colon -let doc_kind (K_aux(K_kind(klst),_)) = - separate_map (spaces arrow) doc_bkind klst - -let doc_effect (BE_aux (e,_)) = - string (match e with - | BE_rreg -> "rreg" - | BE_wreg -> "wreg" - | BE_rmem -> "rmem" - | BE_wmem -> "wmem" - | BE_wmv -> "wmv" - | BE_eamem -> "eamem" - | BE_exmem -> "exmem" - | BE_barr -> "barr" - | BE_depend -> "depend" - | BE_undef -> "undef" - | BE_unspec -> "unspec" - | BE_escape -> "escape" - | BE_nondet -> "nondet" - | BE_lset -> "(*lset*)" - | BE_lret -> "(*lret*)") - -let doc_effects (Effect_aux(e,_)) = match e with +let doc_kind (K_aux (K_kind klst, _)) = separate_map (spaces arrow) doc_bkind klst + +let doc_effect (BE_aux (e, _)) = + string + ( match e with + | BE_rreg -> "rreg" + | BE_wreg -> "wreg" + | BE_rmem -> "rmem" + | BE_wmem -> "wmem" + | BE_wmv -> "wmv" + | BE_eamem -> "eamem" + | BE_exmem -> "exmem" + | BE_barr -> "barr" + | BE_depend -> "depend" + | BE_undef -> "undef" + | BE_unspec -> "unspec" + | BE_escape -> "escape" + | BE_nondet -> "nondet" + | BE_lset -> "(*lset*)" + | BE_lret -> "(*lret*)" + ) + +let doc_effects (Effect_aux (e, _)) = + match e with | Effect_var v -> doc_var v | Effect_set [] -> string "pure" | Effect_set s -> braces (separate_map comma_sp doc_effect s) -let doc_ord (Ord_aux(o,_)) = match o with - | Ord_var v -> doc_var v - | Ord_inc -> string "inc" - | Ord_dec -> string "dec" +let doc_ord (Ord_aux (o, _)) = match o with Ord_var v -> doc_var v | Ord_inc -> string "inc" | Ord_dec -> string "dec" let doc_typ, doc_atomic_typ, doc_nexp = (* following the structure of parser for precedence *) let rec typ ty = fn_typ ty - and fn_typ ((Typ_aux (t, _)) as ty) = match t with - | Typ_fn(arg,ret,efct) -> - separate space [tup_typ arg; arrow; fn_typ ret; string "effect"; doc_effects efct] - | _ -> tup_typ ty - and tup_typ ((Typ_aux (t, _)) as ty) = match t with - | Typ_tuple typs -> parens (separate_map comma_sp app_typ typs) - | _ -> app_typ ty - and app_typ ((Typ_aux (t, _)) as ty) = match t with - | Typ_app(Id_aux (Id "vector", _), [ - Typ_arg_aux(Typ_arg_nexp (Nexp_aux(Nexp_constant n, _)), _); - Typ_arg_aux(Typ_arg_nexp (Nexp_aux(Nexp_constant m, _)), _); - Typ_arg_aux (Typ_arg_order (Ord_aux (Ord_inc, _)), _); - Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id id, _)), _)]) -> - (doc_id id) ^^ - (brackets (if equal n zero_big then doc_int m - else doc_op colon (doc_int n) (doc_int (add n (sub m one_big))))) - | Typ_app(Id_aux (Id "range", _), [ - Typ_arg_aux(Typ_arg_nexp (Nexp_aux(Nexp_constant n, _)), _); - Typ_arg_aux(Typ_arg_nexp m, _);]) -> - (squarebars (if equal n zero_big then nexp m else doc_op colon (doc_int n) (nexp m))) - | Typ_app(id,args) -> - (* trailing space to avoid >> token in case of nested app types *) - (doc_id id) ^^ (angles (separate_map comma_sp doc_typ_arg args)) ^^ space - | _ -> atomic_typ ty - and atomic_typ ((Typ_aux (t, _)) as ty) = match t with - | Typ_id id -> doc_id id - | Typ_var v -> doc_var v - | Typ_app _ | Typ_tuple _ | Typ_fn _ -> - (* exhaustiveness matters here to avoid infinite loops - * if we add a new Typ constructor *) - group (parens (typ ty)) - and doc_typ_arg (Typ_arg_aux(t,_)) = match t with - (* Be careful here because typ_arg is implemented as nexp in the - * parser - in practice falling through app_typ after all the proper nexp - * cases; so Typ_arg_typ has the same precedence as a Typ_app *) - | Typ_arg_typ t -> app_typ t - | Typ_arg_nexp n -> nexp n - | Typ_arg_order o -> doc_ord o - + and fn_typ (Typ_aux (t, _) as ty) = + match t with + | Typ_fn (arg, ret, efct) -> separate space [tup_typ arg; arrow; fn_typ ret; string "effect"; doc_effects efct] + | _ -> tup_typ ty + and tup_typ (Typ_aux (t, _) as ty) = + match t with Typ_tuple typs -> parens (separate_map comma_sp app_typ typs) | _ -> app_typ ty + and app_typ (Typ_aux (t, _) as ty) = + match t with + | Typ_app + ( Id_aux (Id "vector", _), + [ + Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_constant n, _)), _); + Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_constant m, _)), _); + Typ_arg_aux (Typ_arg_order (Ord_aux (Ord_inc, _)), _); + Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id id, _)), _); + ] + ) -> + doc_id id + ^^ brackets (if equal n zero_big then doc_int m else doc_op colon (doc_int n) (doc_int (add n (sub m one_big)))) + | Typ_app + ( Id_aux (Id "range", _), + [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_constant n, _)), _); Typ_arg_aux (Typ_arg_nexp m, _)] + ) -> + squarebars (if equal n zero_big then nexp m else doc_op colon (doc_int n) (nexp m)) + | Typ_app (id, args) -> + (* trailing space to avoid >> token in case of nested app types *) + doc_id id ^^ angles (separate_map comma_sp doc_typ_arg args) ^^ space + | _ -> atomic_typ ty + and atomic_typ (Typ_aux (t, _) as ty) = + match t with + | Typ_id id -> doc_id id + | Typ_var v -> doc_var v + | Typ_app _ | Typ_tuple _ | Typ_fn _ -> + (* exhaustiveness matters here to avoid infinite loops + * if we add a new Typ constructor *) + group (parens (typ ty)) + and doc_typ_arg (Typ_arg_aux (t, _)) = + match t with + (* Be careful here because typ_arg is implemented as nexp in the + * parser - in practice falling through app_typ after all the proper nexp + * cases; so Typ_arg_typ has the same precedence as a Typ_app *) + | Typ_arg_typ t -> app_typ t + | Typ_arg_nexp n -> nexp n + | Typ_arg_order o -> doc_ord o (* same trick to handle precedence of nexp *) and nexp ne = sum_typ ne - and sum_typ ((Nexp_aux(n,_)) as ne) = match n with - | Nexp_sum(n1,n2) -> doc_op plus (sum_typ n1) (star_typ n2) - | Nexp_minus(n1,n2) -> doc_op minus (sum_typ n1) (star_typ n2) - | _ -> star_typ ne - and star_typ ((Nexp_aux(n,_)) as ne) = match n with - | Nexp_times(n1,n2) -> doc_op star (star_typ n1) (exp_typ n2) - | _ -> exp_typ ne - and exp_typ ((Nexp_aux(n,_)) as ne) = match n with - | Nexp_exp n1 -> doc_unop (string "2**") (atomic_nexp_typ n1) - | _ -> neg_typ ne - and neg_typ ((Nexp_aux(n,_)) as ne) = match n with - | Nexp_neg n1 -> - (* XXX this is not valid Sail, only an internal representation - - * work around by commenting it *) - let minus = concat [string "(*"; minus; string "*)"] in - minus ^^ (atomic_nexp_typ n1) - | _ -> atomic_nexp_typ ne - and atomic_nexp_typ ((Nexp_aux(n,_)) as ne) = match n with + and sum_typ (Nexp_aux (n, _) as ne) = + match n with + | Nexp_sum (n1, n2) -> doc_op plus (sum_typ n1) (star_typ n2) + | Nexp_minus (n1, n2) -> doc_op minus (sum_typ n1) (star_typ n2) + | _ -> star_typ ne + and star_typ (Nexp_aux (n, _) as ne) = + match n with Nexp_times (n1, n2) -> doc_op star (star_typ n1) (exp_typ n2) | _ -> exp_typ ne + and exp_typ (Nexp_aux (n, _) as ne) = + match n with Nexp_exp n1 -> doc_unop (string "2**") (atomic_nexp_typ n1) | _ -> neg_typ ne + and neg_typ (Nexp_aux (n, _) as ne) = + match n with + | Nexp_neg n1 -> + (* XXX this is not valid Sail, only an internal representation - + * work around by commenting it *) + let minus = concat [string "(*"; minus; string "*)"] in + minus ^^ atomic_nexp_typ n1 + | _ -> atomic_nexp_typ ne + and atomic_nexp_typ (Nexp_aux (n, _) as ne) = + match n with | Nexp_id id -> doc_id id | Nexp_var v -> doc_var v | Nexp_constant i -> doc_int i - | Nexp_neg _ | Nexp_exp _ | Nexp_times _ | Nexp_sum _ | Nexp_minus _ -> - group (parens (nexp ne)) + | Nexp_neg _ | Nexp_exp _ | Nexp_times _ | Nexp_sum _ | Nexp_minus _ -> group (parens (nexp ne)) + (* expose doc_typ, doc_atomic_typ and doc_nexp *) + in - (* expose doc_typ, doc_atomic_typ and doc_nexp *) - in typ, atomic_typ, nexp + (typ, atomic_typ, nexp) -let doc_nexp_constraint (NC_aux(nc,_)) = match nc with - | NC_equal(n1,n2) -> doc_op equals (doc_nexp n1) (doc_nexp n2) - | NC_bounded_ge(n1,n2) -> doc_op (string ">=") (doc_nexp n1) (doc_nexp n2) - | NC_bounded_le(n1,n2) -> doc_op (string "<=") (doc_nexp n1) (doc_nexp n2) - | NC_set(v,bounds) -> - doc_op (string "IN") (doc_var v) - (braces (separate_map comma_sp doc_int bounds)) +let doc_nexp_constraint (NC_aux (nc, _)) = + match nc with + | NC_equal (n1, n2) -> doc_op equals (doc_nexp n1) (doc_nexp n2) + | NC_bounded_ge (n1, n2) -> doc_op (string ">=") (doc_nexp n1) (doc_nexp n2) + | NC_bounded_le (n1, n2) -> doc_op (string "<=") (doc_nexp n1) (doc_nexp n2) + | NC_set (v, bounds) -> doc_op (string "IN") (doc_var v) (braces (separate_map comma_sp doc_int bounds)) -let doc_qi (QI_aux(qi,_)) = match qi with +let doc_qi (QI_aux (qi, _)) = + match qi with | QI_const n_const -> doc_nexp_constraint n_const - | QI_id(KOpt_aux(ki,_)) -> - match ki with - | KOpt_none v -> doc_var v - | KOpt_kind(k,v) -> separate space [doc_kind k; doc_var v] + | QI_id (KOpt_aux (ki, _)) -> ( + match ki with KOpt_none v -> doc_var v | KOpt_kind (k, v) -> separate space [doc_kind k; doc_var v] + ) (* typ_doc is the doc for the type being quantified *) -let doc_typquant (TypQ_aux(tq,_)) typ_doc = match tq with +let doc_typquant (TypQ_aux (tq, _)) typ_doc = + match tq with | TypQ_no_forall -> typ_doc | TypQ_tq [] -> failwith "TypQ_tq with empty list" | TypQ_tq qlist -> - (* include trailing break because the caller doesn't know if tq is empty *) - doc_op dot - (separate space [string "forall"; separate_map comma_sp doc_qi qlist]) - typ_doc - -let doc_typscm (TypSchm_aux(TypSchm_ts(tq,t),_)) = - (doc_typquant tq (doc_typ t)) - -let doc_typscm_atomic (TypSchm_aux(TypSchm_ts(tq,t),_)) = - (doc_typquant tq (doc_atomic_typ t)) - -let doc_lit (L_aux(l,_)) = - utf8string (match l with - | L_unit -> "()" - | L_zero -> "bitzero" - | L_one -> "bitone" - | L_true -> "true" - | L_false -> "false" - | L_num i -> to_string i - | L_hex n -> "0x" ^ n - | L_bin n -> "0b" ^ n - | L_undef -> "undefined" - | L_string s -> "\"" ^ s ^ "\"") + (* include trailing break because the caller doesn't know if tq is empty *) + doc_op dot (separate space [string "forall"; separate_map comma_sp doc_qi qlist]) typ_doc + +let doc_typscm (TypSchm_aux (TypSchm_ts (tq, t), _)) = doc_typquant tq (doc_typ t) + +let doc_typscm_atomic (TypSchm_aux (TypSchm_ts (tq, t), _)) = doc_typquant tq (doc_atomic_typ t) + +let doc_lit (L_aux (l, _)) = + utf8string + ( match l with + | L_unit -> "()" + | L_zero -> "bitzero" + | L_one -> "bitone" + | L_true -> "true" + | L_false -> "false" + | L_num i -> to_string i + | L_hex n -> "0x" ^ n + | L_bin n -> "0b" ^ n + | L_undef -> "undefined" + | L_string s -> "\"" ^ s ^ "\"" + ) let doc_pat, doc_atomic_pat = let rec pat pa = pat_colons pa - and pat_colons ((P_aux(p,l)) as pa) = match p with - | P_vector_concat pats -> separate_map colon_sp atomic_pat pats - | _ -> app_pat pa - and app_pat ((P_aux(p,l)) as pa) = match p with - | P_app(id, ((_ :: _) as pats)) -> doc_unop (doc_id id) (parens (separate_map comma_sp atomic_pat pats)) - | _ -> atomic_pat pa - and atomic_pat ((P_aux(p,l)) as pa) = match p with - | P_lit lit -> doc_lit lit - | P_wild -> underscore - | P_id id -> doc_id id - | P_as(p,id) -> parens (separate space [pat p; string "as"; doc_id id]) - | P_typ(typ,p) -> separate space [parens (doc_typ typ); atomic_pat p] - | P_app(id,[]) -> doc_id id - | P_record(fpats,_) -> braces (separate_map semi_sp fpat fpats) - | P_vector pats -> brackets (separate_map comma_sp atomic_pat pats) - | P_tuple pats -> parens (separate_map comma_sp atomic_pat pats) - | P_list pats -> squarebarbars (separate_map semi_sp atomic_pat pats) - | P_app(_, _ :: _) | P_vector_concat _ -> - group (parens (pat pa)) - and fpat (FP_aux(FP_Fpat(id,fpat),_)) = doc_op equals (doc_id id) (pat fpat) - and npat (i,p) = doc_op equals (doc_int i) (pat p) - - (* expose doc_pat and doc_atomic_pat *) - in pat, atomic_pat + and pat_colons (P_aux (p, l) as pa) = + match p with P_vector_concat pats -> separate_map colon_sp atomic_pat pats | _ -> app_pat pa + and app_pat (P_aux (p, l) as pa) = + match p with + | P_app (id, (_ :: _ as pats)) -> doc_unop (doc_id id) (parens (separate_map comma_sp atomic_pat pats)) + | _ -> atomic_pat pa + and atomic_pat (P_aux (p, l) as pa) = + match p with + | P_lit lit -> doc_lit lit + | P_wild -> underscore + | P_id id -> doc_id id + | P_as (p, id) -> parens (separate space [pat p; string "as"; doc_id id]) + | P_typ (typ, p) -> separate space [parens (doc_typ typ); atomic_pat p] + | P_app (id, []) -> doc_id id + | P_record (fpats, _) -> braces (separate_map semi_sp fpat fpats) + | P_vector pats -> brackets (separate_map comma_sp atomic_pat pats) + | P_tuple pats -> parens (separate_map comma_sp atomic_pat pats) + | P_list pats -> squarebarbars (separate_map semi_sp atomic_pat pats) + | P_app (_, _ :: _) | P_vector_concat _ -> group (parens (pat pa)) + and fpat (FP_aux (FP_Fpat (id, fpat), _)) = doc_op equals (doc_id id) (pat fpat) + and npat (i, p) = doc_op equals (doc_int i) (pat p) (* expose doc_pat and doc_atomic_pat *) in + + (pat, atomic_pat) let doc_exp, doc_let = let rec exp env mem add_red show_hole_contents e = group (or_exp env mem add_red show_hole_contents e) - and or_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id ("|" | "||"),_) as op),r) -> - doc_op (doc_id op) (and_exp env mem add_red show_hole_contents l) (or_exp env mem add_red show_hole_contents r) - | _ -> and_exp env mem add_red show_hole_contents expr - and and_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id ("&" | "&&"),_) as op),r) -> - doc_op (doc_id op) (eq_exp env mem add_red show_hole_contents l) (and_exp env mem add_red show_hole_contents r) - | _ -> eq_exp env mem add_red show_hole_contents expr - and eq_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id ( - (* XXX this is not very consistent - is the parser bogus here? *) - "=" | "==" | "!=" - | ">=" | ">=_s" | ">=_u" | ">" | ">_s" | ">_u" - | "<=" | "<=_s" | "<" | "<_s" | "<_si" | "<_u" - ),_) as op),r) -> - doc_op (doc_id op) (eq_exp env mem add_red show_hole_contents l) (at_exp env mem add_red show_hole_contents r) - (* XXX assignment should not have the same precedence as equal etc. *) - | E_assign(le,exp) -> - doc_op coloneq (doc_lexp env mem add_red show_hole_contents le) (at_exp env mem add_red show_hole_contents exp) - | _ -> at_exp env mem add_red show_hole_contents expr - and at_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id ("@" | "^^" | "^" | "~^"),_) as op),r) -> - doc_op (doc_id op) (cons_exp env mem add_red show_hole_contents l) (at_exp env mem add_red show_hole_contents r) - | _ -> cons_exp env mem add_red show_hole_contents expr - and cons_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_vector_append(l,r) -> - doc_op colon (shift_exp env mem add_red show_hole_contents l) (cons_exp env mem add_red show_hole_contents r) - | E_cons(l,r) -> - doc_op colon (shift_exp env mem add_red show_hole_contents l) (cons_exp env mem add_red show_hole_contents r) - | _ -> shift_exp env mem add_red show_hole_contents expr - and shift_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id (">>" | ">>>" | "<<" | "<<<"),_) as op),r) -> - doc_op (doc_id op) (shift_exp env mem add_red show_hole_contents l) (plus_exp env mem add_red show_hole_contents r) - | _ -> plus_exp env mem add_red show_hole_contents expr - and plus_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id ("+" | "-"| "+_s" | "-_s" ),_) as op),r) -> - doc_op (doc_id op) (plus_exp env mem add_red show_hole_contents l) (star_exp env mem add_red show_hole_contents r) - | _ -> star_exp env mem add_red show_hole_contents expr - and star_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id ( - "*" | "/" - | "div" | "quot" | "rem" | "mod" | "quot_s" | "mod_s" - | "*_s" | "*_si" | "*_u" | "*_ui"),_) as op),r) -> - doc_op (doc_id op) (star_exp env mem add_red show_hole_contents l) (starstar_exp env mem add_red show_hole_contents r) - | _ -> starstar_exp env mem add_red show_hole_contents expr - and starstar_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app_infix(l,(Id_aux(Id "**",_) as op),r) -> - doc_op (doc_id op) (starstar_exp env mem add_red show_hole_contents l) (app_exp env mem add_red show_hole_contents r) - | E_if _ | E_for _ | E_let _ -> right_atomic_exp env mem add_red show_hole_contents expr - | _ -> app_exp env mem add_red show_hole_contents expr - and right_atomic_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - (* Special case: omit "else ()" when the else branch is empty. *) - | E_if(c,t,E_aux(E_block [], _)) -> - string "if" ^^ space ^^ group (exp env mem add_red show_hole_contents c) ^/^ - string "then" ^^ space ^^ group (exp env mem add_red show_hole_contents t) - | E_if(c,t,e) -> - string "if" ^^ space ^^ group (exp env mem add_red show_hole_contents c) ^/^ - string "then" ^^ space ^^ group (exp env mem add_red show_hole_contents t) ^/^ - string "else" ^^ space ^^ group (exp env mem add_red show_hole_contents e) - | E_for(id,exp1,exp2,exp3,order,exp4) -> - string "foreach" ^^ space ^^ - group (parens ( - separate (break 1) [ - doc_id id; - string "from " ^^ (atomic_exp env mem add_red show_hole_contents exp1); - string "to " ^^ (atomic_exp env mem add_red show_hole_contents exp2); - string "by " ^^ (atomic_exp env mem add_red show_hole_contents exp3); - string "in " ^^ doc_ord order - ] - )) ^/^ - (exp env mem add_red show_hole_contents exp4) - | E_let(leb,e) -> doc_op (string "in") (let_exp env mem add_red show_hole_contents leb) (exp env mem add_red show_hole_contents e) - | _ -> group (parens (exp env mem add_red show_hole_contents expr)) - and app_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_app(f,args) -> - doc_unop (doc_id f) (parens (separate_map comma (exp env mem add_red show_hole_contents) args)) - | _ -> vaccess_exp env mem add_red show_hole_contents expr - and vaccess_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_vector_access(v,e) -> - (atomic_exp env mem add_red show_hole_contents v) ^^ brackets (exp env mem add_red show_hole_contents e) - | E_vector_subrange(v,e1,e2) -> - (atomic_exp env mem add_red show_hole_contents v) ^^ - brackets (doc_op dotdot (exp env mem add_red show_hole_contents e1) (exp env mem add_red show_hole_contents e2)) - | _ -> field_exp env mem add_red show_hole_contents expr - and field_exp env mem add_red show_hole_contents ((E_aux(e,_)) as expr) = match e with - | E_field(fexp,id) -> (atomic_exp env mem add_red show_hole_contents fexp) ^^ dot ^^ doc_id id - | _ -> atomic_exp env mem add_red show_hole_contents expr - and atomic_exp env mem add_red (show_hole_contents:bool) ((E_aux(e,annot)) as expr) = match e with - (* Special case: an empty block is equivalent to unit, but { } is a syntactic struct *) - | E_block [] -> string "()" - | E_block exps -> - let exps_doc = separate_map (semi ^^ hardline) (exp env mem add_red show_hole_contents) exps in - surround 2 1 lbrace exps_doc rbrace - | E_nondet exps -> - let exps_doc = separate_map (semi ^^ hardline) (exp env mem add_red show_hole_contents) exps in - string "nondet" ^^ space ^^ (surround 2 1 lbrace exps_doc rbrace) - | E_id id -> - (match id with - | Id_aux(Id("0"), _) -> - (match Interp.in_lenv env id with - | Interp_ast.V_unknown -> string (add_red "[_]") - | v -> - if show_hole_contents - then string (add_red (Interp.string_of_value v)) - else string (add_red "[_]")) - | _ -> doc_id id) - | E_lit lit -> doc_lit lit - | E_typ(typ,e) -> - if !ignore_casts then + and or_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix (l, (Id_aux (Id ("|" | "||"), _) as op), r) -> + doc_op (doc_id op) (and_exp env mem add_red show_hole_contents l) (or_exp env mem add_red show_hole_contents r) + | _ -> and_exp env mem add_red show_hole_contents expr + and and_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix (l, (Id_aux (Id ("&" | "&&"), _) as op), r) -> + doc_op (doc_id op) (eq_exp env mem add_red show_hole_contents l) (and_exp env mem add_red show_hole_contents r) + | _ -> eq_exp env mem add_red show_hole_contents expr + and eq_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix + ( l, + ( Id_aux + ( Id + (* XXX this is not very consistent - is the parser bogus here? *) + ( "=" | "==" | "!=" | ">=" | ">=_s" | ">=_u" | ">" | ">_s" | ">_u" | "<=" | "<=_s" | "<" | "<_s" + | "<_si" | "<_u" ), + _ + ) as op + ), + r + ) -> + doc_op (doc_id op) (eq_exp env mem add_red show_hole_contents l) (at_exp env mem add_red show_hole_contents r) + (* XXX assignment should not have the same precedence as equal etc. *) + | E_assign (le, exp) -> + doc_op coloneq (doc_lexp env mem add_red show_hole_contents le) (at_exp env mem add_red show_hole_contents exp) + | _ -> at_exp env mem add_red show_hole_contents expr + and at_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix (l, (Id_aux (Id ("@" | "^^" | "^" | "~^"), _) as op), r) -> + doc_op (doc_id op) (cons_exp env mem add_red show_hole_contents l) (at_exp env mem add_red show_hole_contents r) + | _ -> cons_exp env mem add_red show_hole_contents expr + and cons_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_vector_append (l, r) -> + doc_op colon (shift_exp env mem add_red show_hole_contents l) (cons_exp env mem add_red show_hole_contents r) + | E_cons (l, r) -> + doc_op colon (shift_exp env mem add_red show_hole_contents l) (cons_exp env mem add_red show_hole_contents r) + | _ -> shift_exp env mem add_red show_hole_contents expr + and shift_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix (l, (Id_aux (Id (">>" | ">>>" | "<<" | "<<<"), _) as op), r) -> + doc_op (doc_id op) + (shift_exp env mem add_red show_hole_contents l) + (plus_exp env mem add_red show_hole_contents r) + | _ -> plus_exp env mem add_red show_hole_contents expr + and plus_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix (l, (Id_aux (Id ("+" | "-" | "+_s" | "-_s"), _) as op), r) -> + doc_op (doc_id op) + (plus_exp env mem add_red show_hole_contents l) + (star_exp env mem add_red show_hole_contents r) + | _ -> star_exp env mem add_red show_hole_contents expr + and star_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix + ( l, + ( Id_aux + (Id ("*" | "/" | "div" | "quot" | "rem" | "mod" | "quot_s" | "mod_s" | "*_s" | "*_si" | "*_u" | "*_ui"), _) + as op + ), + r + ) -> + doc_op (doc_id op) + (star_exp env mem add_red show_hole_contents l) + (starstar_exp env mem add_red show_hole_contents r) + | _ -> starstar_exp env mem add_red show_hole_contents expr + and starstar_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app_infix (l, (Id_aux (Id "**", _) as op), r) -> + doc_op (doc_id op) + (starstar_exp env mem add_red show_hole_contents l) + (app_exp env mem add_red show_hole_contents r) + | E_if _ | E_for _ | E_let _ -> right_atomic_exp env mem add_red show_hole_contents expr + | _ -> app_exp env mem add_red show_hole_contents expr + and right_atomic_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + (* Special case: omit "else ()" when the else branch is empty. *) + | E_if (c, t, E_aux (E_block [], _)) -> + string "if" ^^ space + ^^ group (exp env mem add_red show_hole_contents c) + ^/^ string "then" ^^ space + ^^ group (exp env mem add_red show_hole_contents t) + | E_if (c, t, e) -> + string "if" ^^ space + ^^ group (exp env mem add_red show_hole_contents c) + ^/^ string "then" ^^ space + ^^ group (exp env mem add_red show_hole_contents t) + ^/^ string "else" ^^ space + ^^ group (exp env mem add_red show_hole_contents e) + | E_for (id, exp1, exp2, exp3, order, exp4) -> + string "foreach" ^^ space + ^^ group + (parens + (separate (break 1) + [ + doc_id id; + string "from " ^^ atomic_exp env mem add_red show_hole_contents exp1; + string "to " ^^ atomic_exp env mem add_red show_hole_contents exp2; + string "by " ^^ atomic_exp env mem add_red show_hole_contents exp3; + string "in " ^^ doc_ord order; + ] + ) + ) + ^/^ exp env mem add_red show_hole_contents exp4 + | E_let (leb, e) -> + doc_op (string "in") (let_exp env mem add_red show_hole_contents leb) (exp env mem add_red show_hole_contents e) + | _ -> group (parens (exp env mem add_red show_hole_contents expr)) + and app_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_app (f, args) -> doc_unop (doc_id f) (parens (separate_map comma (exp env mem add_red show_hole_contents) args)) + | _ -> vaccess_exp env mem add_red show_hole_contents expr + and vaccess_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_vector_access (v, e) -> + atomic_exp env mem add_red show_hole_contents v ^^ brackets (exp env mem add_red show_hole_contents e) + | E_vector_subrange (v, e1, e2) -> + atomic_exp env mem add_red show_hole_contents v + ^^ brackets + (doc_op dotdot (exp env mem add_red show_hole_contents e1) (exp env mem add_red show_hole_contents e2)) + | _ -> field_exp env mem add_red show_hole_contents expr + and field_exp env mem add_red show_hole_contents (E_aux (e, _) as expr) = + match e with + | E_field (fexp, id) -> atomic_exp env mem add_red show_hole_contents fexp ^^ dot ^^ doc_id id + | _ -> atomic_exp env mem add_red show_hole_contents expr + and atomic_exp env mem add_red (show_hole_contents : bool) (E_aux (e, annot) as expr) = + match e with + (* Special case: an empty block is equivalent to unit, but { } is a syntactic struct *) + | E_block [] -> string "()" + | E_block exps -> + let exps_doc = separate_map (semi ^^ hardline) (exp env mem add_red show_hole_contents) exps in + surround 2 1 lbrace exps_doc rbrace + | E_nondet exps -> + let exps_doc = separate_map (semi ^^ hardline) (exp env mem add_red show_hole_contents) exps in + string "nondet" ^^ space ^^ surround 2 1 lbrace exps_doc rbrace + | E_id id -> ( + match id with + | Id_aux (Id "0", _) -> ( + match Interp.in_lenv env id with + | Interp_ast.V_unknown -> string (add_red "[_]") + | v -> if show_hole_contents then string (add_red (Interp.string_of_value v)) else string (add_red "[_]") + ) + | _ -> doc_id id + ) + | E_lit lit -> doc_lit lit + | E_typ (typ, e) -> + if !ignore_casts then atomic_exp env mem add_red show_hole_contents e + else prefix 2 1 (parens (doc_typ typ)) (group (atomic_exp env mem add_red show_hole_contents e)) + | E_internal_cast (_, e) -> + (* XXX ignore internal casts in the interpreter *) atomic_exp env mem add_red show_hole_contents e - else - prefix 2 1 (parens (doc_typ typ)) (group (atomic_exp env mem add_red show_hole_contents e)) - | E_internal_cast(_,e) -> - (* XXX ignore internal casts in the interpreter *) - atomic_exp env mem add_red show_hole_contents e - | E_tuple exps -> - parens (separate_map comma (exp env mem add_red show_hole_contents) exps) - | E_struct(FES_aux(FES_fexps(fexps,_),_)) -> - braces (separate_map semi_sp (doc_fexp env mem add_red show_hole_contents) fexps) - | E_struct_update(e,(FES_aux(FES_fexps(fexps,_),_))) -> - braces (doc_op (string "with") - (exp env mem add_red show_hole_contents e) - (separate_map semi_sp (doc_fexp env mem add_red show_hole_contents) fexps)) - | E_vector exps -> - let default_print _ = brackets (separate_map comma (exp env mem add_red show_hole_contents) exps) in - (match exps with - | [] -> default_print () - | es -> - if (List.for_all - (fun e -> match e with - | (E_aux(E_lit(L_aux((L_one | L_zero | L_undef),_)),_)) -> true - | _ -> false) es) - then - utf8string - ("0b" ^ - (List.fold_right (fun (E_aux(e,_)) rst -> - match e with - | E_lit(L_aux(l, _)) -> (match l with | L_one -> "1"^rst - | L_zero -> "0"^rst - | L_undef -> "u"^rst - | _ -> failwith "bit vector not just bit values") - | _ -> failwith "bit vector not all lits") exps "")) - else default_print ()) - | E_vector_update(v,e1,e2) -> - brackets (doc_op (string "with") - (exp env mem add_red show_hole_contents v) - (doc_op equals (atomic_exp env mem add_red show_hole_contents e1) - (exp env mem add_red show_hole_contents e2))) - | E_vector_update_subrange(v,e1,e2,e3) -> - brackets ( - doc_op (string "with") (exp env mem add_red show_hole_contents v) - (doc_op equals ((atomic_exp env mem add_red show_hole_contents e1) ^^ colon - ^^ (atomic_exp env mem add_red show_hole_contents e2)) (exp env mem add_red show_hole_contents e3))) - | E_list exps -> - squarebarbars (separate_map comma (exp env mem add_red show_hole_contents) exps) - | E_match(e,pexps) -> - let opening = separate space [string "switch"; exp env mem add_red show_hole_contents e; lbrace] in - let cases = separate_map (break 1) (doc_case env mem add_red show_hole_contents) pexps in - surround 2 1 opening cases rbrace - | E_exit e -> separate space [string "exit"; exp env mem add_red show_hole_contents e;] - | E_return e -> separate space [string "return"; exp env mem add_red show_hole_contents e;] - | E_assert(e,msg) -> string "assert" ^^ parens (separate_map comma (exp env mem add_red show_hole_contents) [e; msg]) - (* adding parens and loop for lower precedence *) - | E_app (_, _)|E_vector_access (_, _)|E_vector_subrange (_, _, _) - | E_cons (_, _)|E_field (_, _)|E_assign (_, _) - | E_if _ | E_for _ | E_let _ - | E_vector_append _ - | E_app_infix (_, - (* for every app_infix operator caught at a higher precedence, - * we need to wrap around with parens *) - (Id_aux(Id("|" | "||" - | "&" | "&&" - | "=" | "==" | "!=" - | ">=" | ">=_s" | ">=_u" | ">" | ">_s" | ">_u" - | "<=" | "<=_s" | "<" | "<_s" | "<_si" | "<_u" - | "@" | "^^" | "^" | "~^" - | ">>" | ">>>" | "<<" | "<<<" - | "+" | "+_s" | "-" | "-_s" - | "*" | "/" - | "div" | "quot" | "quot_s" | "rem" | "mod" | "mod_s" - | "*_s" | "*_si" | "*_u" | "*_ui" - | "**"), _)) - , _) -> - group (parens (exp env mem add_red show_hole_contents expr)) - (* XXX fixup deinfix into infix ones *) - | E_app_infix(l, (Id_aux((DeIid op), annot')), r) -> - group (parens - (exp env mem add_red show_hole_contents (E_aux ((E_app_infix (l, (Id_aux(Id op, annot')), r)), annot)))) - (* XXX default precedence for app_infix? *) - | E_app_infix(l,op,r) -> - failwith ("unexpected app_infix operator " ^ (pp_format_id op)) - (* doc_op (doc_id op) (exp l) (exp r) *) - (* XXX missing case *) - | E_comment _ | E_comment_struc _ -> string "" - | E_internal_value v -> - string (Interp.string_of_value v) - | _-> failwith "internal expression escaped" - - and let_exp env mem add_red show_hole_contents (LB_aux(lb,_)) = match lb with - | LB_val(pat,e) -> - prefix 2 1 - (separate space [string "let"; doc_atomic_pat pat; equals]) - (exp env mem add_red show_hole_contents e) - - and doc_fexp env mem add_red show_hole_contents (FE_aux(FE_fexp(id,e),_)) = + | E_tuple exps -> parens (separate_map comma (exp env mem add_red show_hole_contents) exps) + | E_struct (FES_aux (FES_fexps (fexps, _), _)) -> + braces (separate_map semi_sp (doc_fexp env mem add_red show_hole_contents) fexps) + | E_struct_update (e, FES_aux (FES_fexps (fexps, _), _)) -> + braces + (doc_op (string "with") + (exp env mem add_red show_hole_contents e) + (separate_map semi_sp (doc_fexp env mem add_red show_hole_contents) fexps) + ) + | E_vector exps -> ( + let default_print _ = brackets (separate_map comma (exp env mem add_red show_hole_contents) exps) in + match exps with + | [] -> default_print () + | es -> + if + List.for_all + (fun e -> match e with E_aux (E_lit (L_aux ((L_one | L_zero | L_undef), _)), _) -> true | _ -> false) + es + then + utf8string + ("0b" + ^ List.fold_right + (fun (E_aux (e, _)) rst -> + match e with + | E_lit (L_aux (l, _)) -> ( + match l with + | L_one -> "1" ^ rst + | L_zero -> "0" ^ rst + | L_undef -> "u" ^ rst + | _ -> failwith "bit vector not just bit values" + ) + | _ -> failwith "bit vector not all lits" + ) + exps "" + ) + else default_print () + ) + | E_vector_update (v, e1, e2) -> + brackets + (doc_op (string "with") + (exp env mem add_red show_hole_contents v) + (doc_op equals + (atomic_exp env mem add_red show_hole_contents e1) + (exp env mem add_red show_hole_contents e2) + ) + ) + | E_vector_update_subrange (v, e1, e2, e3) -> + brackets + (doc_op (string "with") + (exp env mem add_red show_hole_contents v) + (doc_op equals + (atomic_exp env mem add_red show_hole_contents e1 + ^^ colon + ^^ atomic_exp env mem add_red show_hole_contents e2 + ) + (exp env mem add_red show_hole_contents e3) + ) + ) + | E_list exps -> squarebarbars (separate_map comma (exp env mem add_red show_hole_contents) exps) + | E_match (e, pexps) -> + let opening = separate space [string "switch"; exp env mem add_red show_hole_contents e; lbrace] in + let cases = separate_map (break 1) (doc_case env mem add_red show_hole_contents) pexps in + surround 2 1 opening cases rbrace + | E_exit e -> separate space [string "exit"; exp env mem add_red show_hole_contents e] + | E_return e -> separate space [string "return"; exp env mem add_red show_hole_contents e] + | E_assert (e, msg) -> + string "assert" ^^ parens (separate_map comma (exp env mem add_red show_hole_contents) [e; msg]) + (* adding parens and loop for lower precedence *) + | E_app (_, _) + | E_vector_access (_, _) + | E_vector_subrange (_, _, _) + | E_cons (_, _) + | E_field (_, _) + | E_assign (_, _) + | E_if _ | E_for _ | E_let _ | E_vector_append _ + | E_app_infix + ( _, + (* for every app_infix operator caught at a higher precedence, + * we need to wrap around with parens *) + Id_aux + ( Id + ( "|" | "||" | "&" | "&&" | "=" | "==" | "!=" | ">=" | ">=_s" | ">=_u" | ">" | ">_s" | ">_u" | "<=" + | "<=_s" | "<" | "<_s" | "<_si" | "<_u" | "@" | "^^" | "^" | "~^" | ">>" | ">>>" | "<<" | "<<<" | "+" + | "+_s" | "-" | "-_s" | "*" | "/" | "div" | "quot" | "quot_s" | "rem" | "mod" | "mod_s" | "*_s" | "*_si" + | "*_u" | "*_ui" | "**" ), + _ + ), + _ + ) -> + group (parens (exp env mem add_red show_hole_contents expr)) + (* XXX fixup deinfix into infix ones *) + | E_app_infix (l, Id_aux (DeIid op, annot'), r) -> + group + (parens (exp env mem add_red show_hole_contents (E_aux (E_app_infix (l, Id_aux (Id op, annot'), r), annot)))) + (* XXX default precedence for app_infix? *) + | E_app_infix (l, op, r) -> failwith ("unexpected app_infix operator " ^ pp_format_id op) + (* doc_op (doc_id op) (exp l) (exp r) *) + (* XXX missing case *) + | E_comment _ | E_comment_struc _ -> string "" + | E_internal_value v -> string (Interp.string_of_value v) + | _ -> failwith "internal expression escaped" + and let_exp env mem add_red show_hole_contents (LB_aux (lb, _)) = + match lb with + | LB_val (pat, e) -> + prefix 2 1 (separate space [string "let"; doc_atomic_pat pat; equals]) (exp env mem add_red show_hole_contents e) + and doc_fexp env mem add_red show_hole_contents (FE_aux (FE_fexp (id, e), _)) = doc_op equals (doc_id id) (exp env mem add_red show_hole_contents e) - - and doc_case env mem add_red show_hole_contents (Pat_aux(Pat_exp(pat,e),_)) = + and doc_case env mem add_red show_hole_contents (Pat_aux (Pat_exp (pat, e), _)) = doc_op arrow (separate space [string "case"; doc_atomic_pat pat]) (group (exp env mem add_red show_hole_contents e)) - (* lexps are parsed as eq_exp - we need to duplicate the precedence * structure for them *) and doc_lexp env mem add_red show_hole_contents le = app_lexp env mem add_red show_hole_contents le - and app_lexp env mem add_red show_hole_contents ((LE_aux(lexp,_)) as le) = match lexp with - | LE_app(id,args) -> doc_id id ^^ parens (separate_map comma (exp env mem add_red show_hole_contents) args) - | _ -> vaccess_lexp env mem add_red show_hole_contents le - and vaccess_lexp env mem add_red show_hole_contents ((LE_aux(lexp,_)) as le) = match lexp with - | LE_vector(v,e) -> - (atomic_lexp env mem add_red show_hole_contents v) ^^ brackets (exp env mem add_red show_hole_contents e) - | LE_vector_range(v,e1,e2) -> - (atomic_lexp env mem add_red show_hole_contents v) ^^ - brackets ((exp env mem add_red show_hole_contents e1) ^^ dotdot ^^ (exp env mem add_red show_hole_contents e2)) + and app_lexp env mem add_red show_hole_contents (LE_aux (lexp, _) as le) = + match lexp with + | LE_app (id, args) -> doc_id id ^^ parens (separate_map comma (exp env mem add_red show_hole_contents) args) + | _ -> vaccess_lexp env mem add_red show_hole_contents le + and vaccess_lexp env mem add_red show_hole_contents (LE_aux (lexp, _) as le) = + match lexp with + | LE_vector (v, e) -> + atomic_lexp env mem add_red show_hole_contents v ^^ brackets (exp env mem add_red show_hole_contents e) + | LE_vector_range (v, e1, e2) -> + atomic_lexp env mem add_red show_hole_contents v + ^^ brackets (exp env mem add_red show_hole_contents e1 ^^ dotdot ^^ exp env mem add_red show_hole_contents e2) | _ -> field_lexp env mem add_red show_hole_contents le - and field_lexp env mem add_red show_hole_contents ((LE_aux(lexp,_)) as le) = match lexp with - | LE_field(v,id) -> (atomic_lexp env mem add_red show_hole_contents v) ^^ dot ^^ doc_id id - | _ -> atomic_lexp env mem add_red show_hole_contents le - and atomic_lexp env mem add_red show_hole_contents ((LE_aux(lexp,_)) as le) = match lexp with - | LE_id id -> doc_id id - | LE_typ(typ,id) -> prefix 2 1 (parens (doc_typ typ)) (doc_id id) - | LE_tuple(lexps) -> group (parens (separate_map comma (doc_lexp env mem add_red show_hole_contents) lexps)) - | LE_app _ | LE_vector _ | LE_vector_range _ - | LE_field _ -> group (parens (doc_lexp env mem add_red show_hole_contents le)) - - (* expose doc_exp and doc_let *) - in exp, let_exp - -let doc_default (DT_aux(df,_)) = match df with - | DT_kind(bk,v) -> separate space [string "default"; doc_bkind bk; doc_var v] - | DT_typ(ts,id) -> separate space [string "default"; doc_typscm ts; doc_id id] + and field_lexp env mem add_red show_hole_contents (LE_aux (lexp, _) as le) = + match lexp with + | LE_field (v, id) -> atomic_lexp env mem add_red show_hole_contents v ^^ dot ^^ doc_id id + | _ -> atomic_lexp env mem add_red show_hole_contents le + and atomic_lexp env mem add_red show_hole_contents (LE_aux (lexp, _) as le) = + match lexp with + | LE_id id -> doc_id id + | LE_typ (typ, id) -> prefix 2 1 (parens (doc_typ typ)) (doc_id id) + | LE_tuple lexps -> group (parens (separate_map comma (doc_lexp env mem add_red show_hole_contents) lexps)) + | LE_app _ | LE_vector _ | LE_vector_range _ | LE_field _ -> + group (parens (doc_lexp env mem add_red show_hole_contents le)) + (* expose doc_exp and doc_let *) + in + + (exp, let_exp) + +let doc_default (DT_aux (df, _)) = + match df with + | DT_kind (bk, v) -> separate space [string "default"; doc_bkind bk; doc_var v] + | DT_typ (ts, id) -> separate space [string "default"; doc_typscm ts; doc_id id] | DT_order o -> separate space [string "default"; string "Order"; doc_ord o] -let doc_spec (VS_aux(v,_)) = match v with - | VS_val_spec(ts,id, _, _) -> - separate space [string "val"; doc_typscm ts; doc_id id] +let doc_spec (VS_aux (v, _)) = + match v with VS_val_spec (ts, id, _, _) -> separate space [string "val"; doc_typscm ts; doc_id id] -let doc_namescm (Name_sect_aux(ns,_)) = match ns with +let doc_namescm (Name_sect_aux (ns, _)) = + match ns with | Name_sect_none -> empty (* include leading space because the caller doesn't know if ns is * empty, and trailing break already added by the following equals *) | Name_sect_some s -> space ^^ brackets (doc_op equals (string "name") (dquotes (string s))) -let rec doc_range (BF_aux(r,_)) = match r with +let rec doc_range (BF_aux (r, _)) = + match r with | BF_single i -> doc_int i - | BF_range(i1,i2) -> doc_op dotdot (doc_int i1) (doc_int i2) - | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) + | BF_range (i1, i2) -> doc_op dotdot (doc_int i1) (doc_int i2) + | BF_concat (ir1, ir2) -> doc_range ir1 ^^ comma ^^ doc_range ir2 -let doc_type_union (Tu_aux(typ_u,_)) = match typ_u with - | Tu_ty_id(typ,id) -> separate space [doc_typ typ; doc_id id] - | Tu_id id -> doc_id id +let doc_type_union (Tu_aux (typ_u, _)) = + match typ_u with Tu_ty_id (typ, id) -> separate space [doc_typ typ; doc_id id] | Tu_id id -> doc_id id -let doc_typdef (TD_aux(td,_)) = match td with - | TD_abbrev(id,nm,typschm) -> +let doc_typdef (TD_aux (td, _)) = + match td with + | TD_abbrev (id, nm, typschm) -> doc_op equals (concat [string "typedef"; space; doc_id id; doc_namescm nm]) (doc_typscm typschm) - | TD_record(id,nm,typq,fs,_) -> - let f_pp (typ,id) = concat [doc_typ typ; space; doc_id id; semi] in + | TD_record (id, nm, typq, fs, _) -> + let f_pp (typ, id) = concat [doc_typ typ; space; doc_id id; semi] in let fs_doc = group (separate_map (break 1) f_pp fs) in doc_op equals (concat [string "typedef"; space; doc_id id; doc_namescm nm]) (string "const struct" ^^ space ^^ doc_typquant typq (braces fs_doc)) - | TD_variant(id,nm,typq,ar,_) -> + | TD_variant (id, nm, typq, ar, _) -> let ar_doc = group (separate_map (semi ^^ break 1) doc_type_union ar) in doc_op equals (concat [string "typedef"; space; doc_id id; doc_namescm nm]) (string "const union" ^^ space ^^ doc_typquant typq (braces ar_doc)) - | TD_enum(id,nm,enums,_) -> + | TD_enum (id, nm, enums, _) -> let enums_doc = group (separate_map (semi ^^ break 1) doc_id enums) in doc_op equals (concat [string "typedef"; space; doc_id id; doc_namescm nm]) (string "enumerate" ^^ space ^^ braces enums_doc) - | TD_register(id,n1,n2,rs) -> - let doc_rid (r,id) = separate space [doc_range r; colon; doc_id id] ^^ semi in + | TD_register (id, n1, n2, rs) -> + let doc_rid (r, id) = separate space [doc_range r; colon; doc_id id] ^^ semi in let doc_rids = group (separate_map (break 1) doc_rid rs) in doc_op equals (string "typedef" ^^ space ^^ doc_id id) - (separate space [ - string "register bits"; - brackets (doc_nexp n1 ^^ colon ^^ doc_nexp n2); - braces doc_rids; - ]) + (separate space [string "register bits"; brackets (doc_nexp n1 ^^ colon ^^ doc_nexp n2); braces doc_rids]) -let doc_rec (Rec_aux(r,_)) = match r with +let doc_rec (Rec_aux (r, _)) = + match r with | Rec_nonrec -> empty (* include trailing space because caller doesn't know if we return * empty *) | Rec_rec -> string "rec" ^^ space -let doc_tannot_opt (Typ_annot_opt_aux(t,_)) = match t with - | Typ_annot_opt_some(tq,typ) -> doc_typquant tq (doc_typ typ) +let doc_tannot_opt (Typ_annot_opt_aux (t, _)) = + match t with Typ_annot_opt_some (tq, typ) -> doc_typquant tq (doc_typ typ) -let doc_effects_opt (Effect_opt_aux(e,_)) = match e with - | Effect_opt_pure -> string "pure" - | Effect_opt_effect e -> doc_effects e +let doc_effects_opt (Effect_opt_aux (e, _)) = + match e with Effect_opt_pure -> string "pure" | Effect_opt_effect e -> doc_effects e -let doc_funcl env mem add_red (FCL_aux(FCL_funcl(id,Pat_aux (Pat_exp (pat, exp), _)),_)) = +let doc_funcl env mem add_red (FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)) = group (doc_op equals (separate space [doc_id id; doc_atomic_pat pat]) (doc_exp env mem add_red false exp)) -let doc_fundef env mem add_red (FD_aux(FD_function(r, typa, efa, fcls),_)) = +let doc_fundef env mem add_red (FD_aux (FD_function (r, typa, efa, fcls), _)) = match fcls with | [] -> failwith "FD_function with empty function list" | _ -> let sep = hardline ^^ string "and" ^^ space in let clauses = separate_map sep (doc_funcl env mem add_red) fcls in - separate space [string "function"; - doc_rec r ^^ doc_tannot_opt typa; - string "effect"; doc_effects_opt efa; - clauses] + separate space + [string "function"; doc_rec r ^^ doc_tannot_opt typa; string "effect"; doc_effects_opt efa; clauses] -let doc_dec (DEC_aux(d,_)) = match d with - | DEC_reg(typ,id) -> separate space [string "register"; doc_atomic_typ typ; doc_id id] +let doc_dec (DEC_aux (d, _)) = + match d with + | DEC_reg (typ, id) -> separate space [string "register"; doc_atomic_typ typ; doc_id id] | _ -> failwith "interpreter printing out declarations unexpectedly" -let doc_scattered env mem add_red (SD_aux (sdef, _)) = match sdef with - | SD_scattered_function (r, typa, efa, id) -> - separate space [ - string "scattered function"; - doc_rec r ^^ doc_tannot_opt typa; - string "effect"; doc_effects_opt efa; - doc_id id] - | SD_scattered_variant (id, ns, tq) -> - doc_op equals - (string "scattered typedef" ^^ space ^^ doc_id id ^^ doc_namescm ns) - (doc_typquant tq empty) - | SD_scattered_funcl funcl -> - string "function clause" ^^ space ^^ doc_funcl env mem add_red funcl - | SD_scattered_unioncl (id, tu) -> - separate space [string "union"; doc_id id; - string "member"; doc_type_union tu] - | SD_scattered_end id -> string "end" ^^ space ^^ doc_id id - -let rec doc_def env mem add_red def = group (match def with - | DEF_default df -> doc_default df - | DEF_val v_spec -> doc_spec v_spec - | DEF_type t_def -> doc_typdef t_def - | DEF_kind k_def -> failwith "interpreter unexpectedly printing kind def" - | DEF_fundef f_def -> doc_fundef env mem add_red f_def - | DEF_let lbind -> doc_let env mem add_red false lbind - | DEF_register dec -> doc_dec dec - | DEF_scattered sdef -> doc_scattered env mem add_red sdef - | DEF_comm comm_dec -> string "(*" ^^ doc_comm_dec env mem add_red comm_dec ^^ string "*)" - ) ^^ hardline - -and doc_comm_dec env mem add_red dec = match dec with - | DC_comm s -> string s - | DC_comm_struct d -> doc_def env mem add_red d - -let doc_defs env mem add_red (Defs(defs)) = - separate_map hardline (doc_def env mem add_red) defs - -let print ?(len=80) channel doc = ToChannel.pretty 1. len channel doc -let to_buf ?(len=80) buf doc = ToBuffer.pretty 1. len buf doc +let doc_scattered env mem add_red (SD_aux (sdef, _)) = + match sdef with + | SD_scattered_function (r, typa, efa, id) -> + separate space + [string "scattered function"; doc_rec r ^^ doc_tannot_opt typa; string "effect"; doc_effects_opt efa; doc_id id] + | SD_scattered_variant (id, ns, tq) -> + doc_op equals (string "scattered typedef" ^^ space ^^ doc_id id ^^ doc_namescm ns) (doc_typquant tq empty) + | SD_scattered_funcl funcl -> string "function clause" ^^ space ^^ doc_funcl env mem add_red funcl + | SD_scattered_unioncl (id, tu) -> separate space [string "union"; doc_id id; string "member"; doc_type_union tu] + | SD_scattered_end id -> string "end" ^^ space ^^ doc_id id + +let rec doc_def env mem add_red def = + group + ( match def with + | DEF_default df -> doc_default df + | DEF_val v_spec -> doc_spec v_spec + | DEF_type t_def -> doc_typdef t_def + | DEF_kind k_def -> failwith "interpreter unexpectedly printing kind def" + | DEF_fundef f_def -> doc_fundef env mem add_red f_def + | DEF_let lbind -> doc_let env mem add_red false lbind + | DEF_register dec -> doc_dec dec + | DEF_scattered sdef -> doc_scattered env mem add_red sdef + | DEF_comm comm_dec -> string "(*" ^^ doc_comm_dec env mem add_red comm_dec ^^ string "*)" + ) + ^^ hardline + +and doc_comm_dec env mem add_red dec = + match dec with DC_comm s -> string s | DC_comm_struct d -> doc_def env mem add_red d + +let doc_defs env mem add_red (Defs defs) = separate_map hardline (doc_def env mem add_red) defs + +let print ?(len = 80) channel doc = ToChannel.pretty 1. len channel doc +let to_buf ?(len = 80) buf doc = ToBuffer.pretty 1. len buf doc let pp_exp env mem add_red show_hole_contents e = let b = Buffer.create 20 in diff --git a/src/lem_interp/printing_functions.ml b/src/lem_interp/printing_functions.ml index 1d04314a9..f1df32cbb 100644 --- a/src/lem_interp/printing_functions.ml +++ b/src/lem_interp/printing_functions.ml @@ -65,45 +65,43 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Printf ;; -open Interp_ast ;; -open Sail_impl_base ;; -open Interp_utilities ;; -open Interp_interface ;; - - -open Nat_big_num ;; - -let lit_to_string = Pretty_interp.lit_to_string ;; -let id_to_string = Pretty_interp.id_to_string ;; -let loc_to_string = Pretty_interp.loc_to_string ;; -let bitvec_to_string = Pretty_interp.bitvec_to_string ;; -let collapse_leading = Pretty_interp.collapse_leading ;; - - -type bits_lifted_homogenous = - | Bitslh_concrete of bit list - | Bitslh_undef - | Bitslh_unknown - -let rec bits_lifted_homogenous_of_bit_lifteds' (bls:bit_lifted list) (acc:bits_lifted_homogenous) = - match (bls,acc) with - | ([], _) -> Some acc - | (Bitl_zero::bls', Bitslh_concrete bs) -> bits_lifted_homogenous_of_bit_lifteds' bls' (Bitslh_concrete (bs@[Bitc_zero])) - | (Bitl_one::bls', Bitslh_concrete bs) -> bits_lifted_homogenous_of_bit_lifteds' bls' (Bitslh_concrete (bs@[Bitc_one])) - | (Bitl_undef::bls', Bitslh_undef) -> bits_lifted_homogenous_of_bit_lifteds' bls' Bitslh_undef - | (Bitl_unknown::bls', Bitslh_unknown) -> bits_lifted_homogenous_of_bit_lifteds' bls' Bitslh_unknown - | (_,_) -> None - -let bits_lifted_homogenous_of_bit_lifteds (bls:bit_lifted list) : bits_lifted_homogenous option = - let bls',acc0 = +open Printf +open Interp_ast +open Sail_impl_base +open Interp_utilities +open Interp_interface + +open Nat_big_num + +let lit_to_string = Pretty_interp.lit_to_string +let id_to_string = Pretty_interp.id_to_string +let loc_to_string = Pretty_interp.loc_to_string +let bitvec_to_string = Pretty_interp.bitvec_to_string +let collapse_leading = Pretty_interp.collapse_leading + +type bits_lifted_homogenous = Bitslh_concrete of bit list | Bitslh_undef | Bitslh_unknown + +let rec bits_lifted_homogenous_of_bit_lifteds' (bls : bit_lifted list) (acc : bits_lifted_homogenous) = + match (bls, acc) with + | [], _ -> Some acc + | Bitl_zero :: bls', Bitslh_concrete bs -> + bits_lifted_homogenous_of_bit_lifteds' bls' (Bitslh_concrete (bs @ [Bitc_zero])) + | Bitl_one :: bls', Bitslh_concrete bs -> + bits_lifted_homogenous_of_bit_lifteds' bls' (Bitslh_concrete (bs @ [Bitc_one])) + | Bitl_undef :: bls', Bitslh_undef -> bits_lifted_homogenous_of_bit_lifteds' bls' Bitslh_undef + | Bitl_unknown :: bls', Bitslh_unknown -> bits_lifted_homogenous_of_bit_lifteds' bls' Bitslh_unknown + | _, _ -> None + +let bits_lifted_homogenous_of_bit_lifteds (bls : bit_lifted list) : bits_lifted_homogenous option = + let bls', acc0 = match bls with - | [] -> [], Bitslh_concrete [] - | Bitl_zero::bls' -> bls', Bitslh_concrete [Bitc_zero] - | Bitl_one::bls' -> bls', Bitslh_concrete [Bitc_one] - | Bitl_undef::bls' -> bls', Bitslh_undef - | Bitl_unknown::bls' -> bls', Bitslh_unknown in - bits_lifted_homogenous_of_bit_lifteds' bls' acc0 + | [] -> ([], Bitslh_concrete []) + | Bitl_zero :: bls' -> (bls', Bitslh_concrete [Bitc_zero]) + | Bitl_one :: bls' -> (bls', Bitslh_concrete [Bitc_one]) + | Bitl_undef :: bls' -> (bls', Bitslh_undef) + | Bitl_unknown :: bls' -> (bls', Bitslh_unknown) + in + bits_lifted_homogenous_of_bit_lifteds' bls' acc0 (*let byte_it_lifted_to_string = function | BL0 -> "0" @@ -112,86 +110,78 @@ let bits_lifted_homogenous_of_bit_lifteds (bls:bit_lifted list) : bits_lifted_ho | BLUnknown -> "?" *) -let bit_lifted_to_string = function - | Bitl_zero -> "0" - | Bitl_one -> "1" - | Bitl_undef -> "u" - | Bitl_unknown -> "?" +let bit_lifted_to_string = function Bitl_zero -> "0" | Bitl_one -> "1" | Bitl_undef -> "u" | Bitl_unknown -> "?" -let hex_int_to_string i = - let s = (Printf.sprintf "%x" i) in if (String.length s = 1) then "0"^s else s +let hex_int_to_string i = + let s = Printf.sprintf "%x" i in + if String.length s = 1 then "0" ^ s else s let bytes_lifted_homogenous_to_string = function - | Bitslh_concrete bs -> + | Bitslh_concrete bs -> let i = to_int (Sail_impl_base.integer_of_bit_list bs) in hex_int_to_string i | Bitslh_undef -> "uu" | Bitslh_unknown -> "??" -let simple_bit_lifteds_to_string ?(collapse=true) bls (show_length_and_start:bool) (starto: int option) = - let s = - String.concat "" (List.map bit_lifted_to_string bls) in - let s = - if collapse then collapse_leading s else s in - let len = string_of_int (List.length bls) in - if show_length_and_start then - match starto with - | None -> len ^ "b" ^s - | Some start -> len ^ "b" ^ "_" ^string_of_int start ^"'" ^ s - else - "0b"^s +let simple_bit_lifteds_to_string ?(collapse = true) bls (show_length_and_start : bool) (starto : int option) = + let s = String.concat "" (List.map bit_lifted_to_string bls) in + let s = if collapse then collapse_leading s else s in + let len = string_of_int (List.length bls) in + if show_length_and_start then ( + match starto with None -> len ^ "b" ^ s | Some start -> len ^ "b" ^ "_" ^ string_of_int start ^ "'" ^ s + ) + else "0b" ^ s (* if a multiple of 8 lifted bits and each chunk of 8 is homogenous, -print as lifted hex, otherwise print as lifted bits *) -let bit_lifteds_to_string ?(collapse=true) (bls: bit_lifted list) (show_length_and_start:bool) (starto: int option) (abbreviate_zero_to_nine: bool) = + print as lifted hex, otherwise print as lifted bits *) +let bit_lifteds_to_string ?(collapse = true) (bls : bit_lifted list) (show_length_and_start : bool) + (starto : int option) (abbreviate_zero_to_nine : bool) = let l = List.length bls in - if l mod 8 = 0 then (* if List.mem l [8;16;32;64;128] then *) + if l mod 8 = 0 then ( + (* if List.mem l [8;16;32;64;128] then *) let bytesl = List.map (fun (Byte_lifted bs) -> bs) (Sail_impl_base.byte_lifteds_of_bit_lifteds bls) in let byteslhos = List.map bits_lifted_homogenous_of_bit_lifteds bytesl in - match maybe_all byteslhos with - | None -> (* print as bitvector after all *) - simple_bit_lifteds_to_string ~collapse:collapse bls show_length_and_start starto - | Some (byteslhs: bits_lifted_homogenous list) -> - (* if abbreviate_zero_to_nine, all bytes are concrete, and the number is <=9, just print that *) + match maybe_all byteslhos with + | None -> + (* print as bitvector after all *) + simple_bit_lifteds_to_string ~collapse bls show_length_and_start starto + | Some (byteslhs : bits_lifted_homogenous list) -> ( + (* if abbreviate_zero_to_nine, all bytes are concrete, and the number is <=9, just print that *) (* (note that that doesn't print the length or start - it's appropriate only for memory values, where we typically have an explicit footprint also printed *) - let nos = List.rev_map (function Bitslh_concrete bs -> Some (Sail_impl_base.nat_of_bit_list bs) | Bitslh_undef -> None | Bitslh_unknown -> None) byteslhs in - let (lsb,msbs) = ((List.hd nos), List.tl nos) in - match (abbreviate_zero_to_nine, List.for_all (fun no -> no=Some 0) msbs, lsb) with - | (true, true, Some n) when n <= 9 -> - string_of_int n - | _ -> + let nos = + List.rev_map + (function + | Bitslh_concrete bs -> Some (Sail_impl_base.nat_of_bit_list bs) + | Bitslh_undef -> None + | Bitslh_unknown -> None + ) + byteslhs + in + let lsb, msbs = (List.hd nos, List.tl nos) in + match (abbreviate_zero_to_nine, List.for_all (fun no -> no = Some 0) msbs, lsb) with + | true, true, Some n when n <= 9 -> string_of_int n + | _ -> (* otherwise, print the bytes as hex *) let s = String.concat "" (List.map bytes_lifted_homogenous_to_string byteslhs) in - if show_length_and_start then - match starto with - | None -> "0x" ^ s - | Some start -> "0x" ^ "_" ^string_of_int start ^"'" ^ s - else - "0x"^s - else - simple_bit_lifteds_to_string ~collapse:collapse bls show_length_and_start starto - - -let register_value_to_string rv = - bit_lifteds_to_string rv.rv_bits true (Some rv.rv_start_internal) false - -let memory_value_to_string endian mv = - let bls = - Sail_impl_base.match_endianness endian mv - |> List.map (fun (Byte_lifted bs) -> bs) - |> List.concat - in + if show_length_and_start then ( + match starto with None -> "0x" ^ s | Some start -> "0x" ^ "_" ^ string_of_int start ^ "'" ^ s + ) + else "0x" ^ s + ) + ) + else simple_bit_lifteds_to_string ~collapse bls show_length_and_start starto + +let register_value_to_string rv = bit_lifteds_to_string rv.rv_bits true (Some rv.rv_start_internal) false + +let memory_value_to_string endian mv = + let bls = Sail_impl_base.match_endianness endian mv |> List.map (fun (Byte_lifted bs) -> bs) |> List.concat in bit_lifteds_to_string bls true None true -let logfile_register_value_to_string rv = - bit_lifteds_to_string ~collapse:false rv.rv_bits false (Some rv.rv_start) false +let logfile_register_value_to_string rv = + bit_lifteds_to_string ~collapse:false rv.rv_bits false (Some rv.rv_start) false -let logfile_memory_value_to_string endian mv = - let bls = - Sail_impl_base.match_endianness endian mv - |> List.map (fun (Byte_lifted bs) -> bs) - |> List.concat - in +let logfile_memory_value_to_string endian mv = + let bls = Sail_impl_base.match_endianness endian mv |> List.map (fun (Byte_lifted bs) -> bs) |> List.concat in bit_lifteds_to_string bls false None false let byte_list_to_string bs = @@ -201,235 +191,198 @@ let byte_list_to_string bs = let logfile_address_to_string a = let bs' = List.map byte_lifted_of_byte (byte_list_of_address a) in logfile_memory_value_to_string E_big_endian bs' - - -(*let bytes_to_string bytes = - (String.concat "" - (List.map (fun i -> hex_int_to_string i) - (List.map (fun (Byte_lifted bs) -> bits_to_word8 bs) bytes)))*) - - -let bit_to_string = function - | Bitc_zero -> "0" - | Bitc_one -> "1" +(*let bytes_to_string bytes = + (String.concat "" + (List.map (fun i -> hex_int_to_string i) + (List.map (fun (Byte_lifted bs) -> bits_to_word8 bs) bytes)))*) +let bit_to_string = function Bitc_zero -> "0" | Bitc_one -> "1" let reg_value_to_string v = "deprecated" -(* let l = List.length v.rv_bits in - let start = string_of_int v.rv_start in - if List.mem l [8;16;32;64;128] then - let bytes = Interp_inter_imp.to_bytes v.rv_bits in - "0x" ^ "_" ^ start ^ "'" ^ bytes_to_string bytes - else (string_of_int l) ^ "_" ^ start ^ "'b" ^ - collapse_leading (List.fold_right (^) (List.map bit_lifted_to_string v.rv_bits) "")*) +(* let l = List.length v.rv_bits in + let start = string_of_int v.rv_start in + if List.mem l [8;16;32;64;128] then + let bytes = Interp_inter_imp.to_bytes v.rv_bits in + "0x" ^ "_" ^ start ^ "'" ^ bytes_to_string bytes + else (string_of_int l) ^ "_" ^ start ^ "'b" ^ + collapse_leading (List.fold_right (^) (List.map bit_lifted_to_string v.rv_bits) "")*) -let ifield_to_string v = - "0b"^ collapse_leading (List.fold_right (^) (List.map bit_to_string v) "") +let ifield_to_string v = "0b" ^ collapse_leading (List.fold_right ( ^ ) (List.map bit_to_string v) "") (*let val_to_string v = match v with - | Bitvector(bools, inc, fst)-> - let l = List.length bools in - if List.mem l [8;16;32;64;128] then - let Bytevector bytes = Interp_inter_imp.coerce_Bytevector_of_Bitvector v in + | Bitvector(bools, inc, fst)-> + let l = List.length bools in + if List.mem l [8;16;32;64;128] then + let Bytevector bytes = Interp_inter_imp.coerce_Bytevector_of_Bitvector v in + "0x" ^ + "_" ^ (string_of_int (Big_int.int_of_big_int fst)) ^ "'" ^ + bytes_to_string bytes + else + (* (string_of_int l) ^ " bits -- 0b" ^ collapse_leading (String.concat "" (List.map (function | true -> "1" | _ -> "0") bools))*) + (string_of_int l) ^ "_" ^ (string_of_int (Big_int.int_of_big_int fst)) ^ "'b" ^ collapse_leading (String.concat "" (List.map (function | true -> "1" | _ -> "0") bools)) + | Bytevector bytes -> + (* let l = List.length words in *) + (*(string_of_int l) ^ " bytes -- " ^*) "0x" ^ - "_" ^ (string_of_int (Big_int.int_of_big_int fst)) ^ "'" ^ bytes_to_string bytes - else -(* (string_of_int l) ^ " bits -- 0b" ^ collapse_leading (String.concat "" (List.map (function | true -> "1" | _ -> "0") bools))*) - (string_of_int l) ^ "_" ^ (string_of_int (Big_int.int_of_big_int fst)) ^ "'b" ^ collapse_leading (String.concat "" (List.map (function | true -> "1" | _ -> "0") bools)) - | Bytevector bytes -> - (* let l = List.length words in *) - (*(string_of_int l) ^ " bytes -- " ^*) - "0x" ^ - bytes_to_string bytes - | Unknown0 -> "Unknown"*) - -let half_byte_to_hex v = + | Unknown0 -> "Unknown"*) + +let half_byte_to_hex v = match v with - | [false;false;false;false] -> "0" - | [false;false;false;true ] -> "1" - | [false;false;true ;false] -> "2" - | [false;false;true ;true ] -> "3" - | [false;true ;false;false] -> "4" - | [false;true ;false;true ] -> "5" - | [false;true ;true ;false] -> "6" - | [false;true ;true ;true ] -> "7" - | [true ;false;false;false] -> "8" - | [true ;false;false;true ] -> "9" - | [true ;false;true ;false] -> "a" - | [true ;false;true ;true ] -> "b" - | [true ;true ;false;false] -> "c" - | [true ;true ;false;true ] -> "d" - | [true ;true ;true ;false] -> "e" - | [true ;true ;true ;true ] -> "f" - | _ -> failwith "half_byte_to_hex given list of length longer than or shorter than 4" - -let rec bit_to_hex v = + | [false; false; false; false] -> "0" + | [false; false; false; true] -> "1" + | [false; false; true; false] -> "2" + | [false; false; true; true] -> "3" + | [false; true; false; false] -> "4" + | [false; true; false; true] -> "5" + | [false; true; true; false] -> "6" + | [false; true; true; true] -> "7" + | [true; false; false; false] -> "8" + | [true; false; false; true] -> "9" + | [true; false; true; false] -> "a" + | [true; false; true; true] -> "b" + | [true; true; false; false] -> "c" + | [true; true; false; true] -> "d" + | [true; true; true; false] -> "e" + | [true; true; true; true] -> "f" + | _ -> failwith "half_byte_to_hex given list of length longer than or shorter than 4" + +let rec bit_to_hex v = match v with - | [] -> "" - | a::b::c::d::vs -> half_byte_to_hex [a;b;c;d] ^ bit_to_hex vs - | _ -> failwith "bitstring given not divisible by 4" + | [] -> "" + | a :: b :: c :: d :: vs -> half_byte_to_hex [a; b; c; d] ^ bit_to_hex vs + | _ -> failwith "bitstring given not divisible by 4" (*let val_to_hex_string v = match v with - | Bitvector(bools, _, _) -> "0x" ^ bit_to_hex bools - | Bytevector words -> val_to_string v - | Unknown0 -> "Error: cannot turn Unknown into hex" -;;*) + | Bitvector(bools, _, _) -> "0x" ^ bit_to_hex bools + | Bytevector words -> val_to_string v + | Unknown0 -> "Error: cannot turn Unknown into hex" + ;;*) -let dir_to_string = function - | D_increasing -> "inc" - | D_decreasing -> "dec" +let dir_to_string = function D_increasing -> "inc" | D_decreasing -> "dec" let reg_name_to_string = function - | Reg(s,start,size,d) -> s (*^ "(" ^ (string_of_int start) ^ ", " ^ (string_of_int size) ^ ", " ^ (dir_to_string d) ^ ")"*) - | Reg_slice(s,start,dir,(first,second)) -> - let first,second = - match dir with - | D_increasing -> (first,second) - | D_decreasing -> (start - first, start - second) in - s ^ "[" ^ string_of_int first ^ (if (first = second) then "" else ".." ^ (string_of_int second)) ^ "]" - | Reg_field(s,_,_,f,_) -> s ^ "." ^ f - | Reg_f_slice(s,start,dir,f,_,(first,second)) -> - let first,second = - match dir with - | D_increasing -> (first,second) - | D_decreasing -> (start - first, start - second) in - s ^ "." ^ f ^ "]" ^ string_of_int first ^ (if (first = second) then "" else ".." ^ (string_of_int second)) ^ "]" + | Reg (s, start, size, d) -> + s (*^ "(" ^ (string_of_int start) ^ ", " ^ (string_of_int size) ^ ", " ^ (dir_to_string d) ^ ")"*) + | Reg_slice (s, start, dir, (first, second)) -> + let first, second = + match dir with D_increasing -> (first, second) | D_decreasing -> (start - first, start - second) + in + s ^ "[" ^ string_of_int first ^ (if first = second then "" else ".." ^ string_of_int second) ^ "]" + | Reg_field (s, _, _, f, _) -> s ^ "." ^ f + | Reg_f_slice (s, start, dir, f, _, (first, second)) -> + let first, second = + match dir with D_increasing -> (first, second) | D_decreasing -> (start - first, start - second) + in + s ^ "." ^ f ^ "]" ^ string_of_int first ^ (if first = second then "" else ".." ^ string_of_int second) ^ "]" let dependencies_to_string dependencies = String.concat ", " (List.map reg_name_to_string dependencies) let rec top_frame_exp_state = function | Interp.Top -> raise (Invalid_argument "top_frame_exp") - | Interp.Hole_frame(_, e, _, env, mem, Interp.Top) - | Interp.Thunk_frame(e, _, env, mem, Interp.Top) -> (e,(env,mem)) - | Interp.Thunk_frame(_, _, _, _, s) - | Interp.Hole_frame(_, _, _, _, _, s) -> top_frame_exp_state s + | Interp.Hole_frame (_, e, _, env, mem, Interp.Top) | Interp.Thunk_frame (e, _, env, mem, Interp.Top) -> + (e, (env, mem)) + | Interp.Thunk_frame (_, _, _, _, s) | Interp.Hole_frame (_, _, _, _, _, s) -> top_frame_exp_state s -let tunk = Unknown, None -let ldots = E_aux(E_id (Id_aux (Id "...", Unknown)), tunk) +let tunk = (Unknown, None) +let ldots = E_aux (E_id (Id_aux (Id "...", Unknown)), tunk) let rec compact_exp (E_aux (e, l)) = let wrap e = E_aux (e, l) in match e with - | E_block (e :: _) -> compact_exp e - | E_nondet (e :: _) -> compact_exp e - | E_if (e, _, _) -> - wrap(E_if(compact_exp e, ldots, E_aux(E_block [], tunk))) - | E_for (i, e1, e2, e3, o, e4) -> - wrap(E_for(i, compact_exp e1, compact_exp e2, compact_exp e3, o, ldots)) - | E_match (e, _) -> - wrap(E_match(compact_exp e, [])) - | E_let (bind, _) -> wrap(E_let(bind, ldots)) - | E_app (f, args) -> wrap(E_app(f, List.map compact_exp args)) - | E_app_infix (l, op, r) -> wrap(E_app_infix(compact_exp l, op, compact_exp r)) - | E_tuple exps -> wrap(E_tuple(List.map compact_exp exps)) - | E_vector exps -> wrap(E_vector(List.map compact_exp exps)) - | E_vector_access (e1, e2) -> - wrap(E_vector_access(compact_exp e1, compact_exp e2)) - | E_vector_subrange (e1, e2, e3) -> - wrap(E_vector_subrange(compact_exp e1, compact_exp e2, compact_exp e3)) - | E_vector_update (e1, e2, e3) -> - wrap(E_vector_update(compact_exp e1, compact_exp e2, compact_exp e3)) - | E_vector_update_subrange (e1, e2, e3, e4) -> - wrap(E_vector_update_subrange(compact_exp e1, compact_exp e2, compact_exp e3, compact_exp e4)) - | E_vector_append (e1, e2) -> - wrap(E_vector_append(compact_exp e1, compact_exp e2)) - | E_list exps -> wrap(E_list(List.map compact_exp exps)) - | E_cons (e1, e2) -> - wrap(E_cons(compact_exp e1, compact_exp e2)) - | E_struct_update (e, fexps) -> - wrap(E_struct_update (compact_exp e, fexps)) - | E_field (e, id) -> - wrap(E_field(compact_exp e, id)) - | E_assign (lexp, e) -> wrap(E_assign(lexp, compact_exp e)) - | E_block [] | E_nondet [] | E_typ (_, _) | E_internal_cast (_, _) - | _ -> wrap e + | E_block (e :: _) -> compact_exp e + | E_nondet (e :: _) -> compact_exp e + | E_if (e, _, _) -> wrap (E_if (compact_exp e, ldots, E_aux (E_block [], tunk))) + | E_for (i, e1, e2, e3, o, e4) -> wrap (E_for (i, compact_exp e1, compact_exp e2, compact_exp e3, o, ldots)) + | E_match (e, _) -> wrap (E_match (compact_exp e, [])) + | E_let (bind, _) -> wrap (E_let (bind, ldots)) + | E_app (f, args) -> wrap (E_app (f, List.map compact_exp args)) + | E_app_infix (l, op, r) -> wrap (E_app_infix (compact_exp l, op, compact_exp r)) + | E_tuple exps -> wrap (E_tuple (List.map compact_exp exps)) + | E_vector exps -> wrap (E_vector (List.map compact_exp exps)) + | E_vector_access (e1, e2) -> wrap (E_vector_access (compact_exp e1, compact_exp e2)) + | E_vector_subrange (e1, e2, e3) -> wrap (E_vector_subrange (compact_exp e1, compact_exp e2, compact_exp e3)) + | E_vector_update (e1, e2, e3) -> wrap (E_vector_update (compact_exp e1, compact_exp e2, compact_exp e3)) + | E_vector_update_subrange (e1, e2, e3, e4) -> + wrap (E_vector_update_subrange (compact_exp e1, compact_exp e2, compact_exp e3, compact_exp e4)) + | E_vector_append (e1, e2) -> wrap (E_vector_append (compact_exp e1, compact_exp e2)) + | E_list exps -> wrap (E_list (List.map compact_exp exps)) + | E_cons (e1, e2) -> wrap (E_cons (compact_exp e1, compact_exp e2)) + | E_struct_update (e, fexps) -> wrap (E_struct_update (compact_exp e, fexps)) + | E_field (e, id) -> wrap (E_field (compact_exp e, id)) + | E_assign (lexp, e) -> wrap (E_assign (lexp, compact_exp e)) + | E_block [] | E_nondet [] | E_typ (_, _) | E_internal_cast (_, _) | _ -> wrap e (* extract, compact and reverse expressions on the stack; * the top of the stack is the head of the returned list. *) -let rec compact_stack ?(acc=[]) = function +let rec compact_stack ?(acc = []) = function | Interp.Top -> acc - | Interp.Hole_frame(_,e,_,env,mem,s) - | Interp.Thunk_frame(e,_,env,mem,s) -> compact_stack ~acc:(((compact_exp e),(env,mem)) :: acc) s -;; + | Interp.Hole_frame (_, e, _, env, mem, s) | Interp.Thunk_frame (e, _, env, mem, s) -> + compact_stack ~acc:((compact_exp e, (env, mem)) :: acc) s -let sub_to_string = function None -> "" | Some (x, y) -> sprintf " (%s, %s)" - (to_string x) (to_string y) -;; +let sub_to_string = function None -> "" | Some (x, y) -> sprintf " (%s, %s)" (to_string x) (to_string y) -let format_tracking t = match t with - | Some rs -> "{ " ^ (dependencies_to_string rs) ^ "}" - | None -> "None" +let format_tracking t = match t with Some rs -> "{ " ^ dependencies_to_string rs ^ "}" | None -> "None" let rec format_events = function - | [] -> - " Done\n" - | [E_error s] -> - " Failed with message : " ^ s ^ "\n" - | (E_error s)::events -> - " Failed with message : " ^ s ^ " but continued on erroneously\n" - | (E_read_mem(read_kind, (Address_lifted(location, _)), length, tracking))::events -> - " Read_mem at " ^ (memory_value_to_string E_big_endian location) ^ " based on registers " ^ - format_tracking tracking ^ " for " ^ (string_of_int length) ^ " bytes \n" ^ - (format_events events) - | (E_read_memt(read_kind, (Address_lifted(location, _)), length, tracking))::events -> - " Read_memt at " ^ (memory_value_to_string E_big_endian location) ^ " based on registers " ^ - format_tracking tracking ^ " for " ^ (string_of_int length) ^ " bytes \n" ^ - (format_events events) - | (E_write_mem(write_kind,(Address_lifted (location,_)), length, tracking, value, v_tracking))::events -> - " Write_mem at " ^ (memory_value_to_string E_big_endian location) ^ ", based on registers " ^ - format_tracking tracking ^ ", writing " ^ (memory_value_to_string E_big_endian value) ^ - ", based on " ^ format_tracking v_tracking ^ " across " ^ (string_of_int length) ^ " bytes\n" ^ - (format_events events) - | (E_write_ea(write_kind,(Address_lifted (location,_)), length, tracking))::events -> - " Write_ea at " ^ (memory_value_to_string E_big_endian location) ^ ", based on registers " ^ - format_tracking tracking ^ " across " ^ (string_of_int length) ^ " bytes\n" ^ - (format_events events) - | E_excl_res::events -> - " Excl_res\n" ^ (format_events events) - | (E_write_memv(_, value, v_tracking))::events -> - " Write_memv of " ^ (memory_value_to_string E_big_endian value) ^ ", based on registers " ^ - format_tracking v_tracking ^ "\n" ^ - (format_events events) - | (E_write_memvt(_, (tag, value), v_tracking))::events -> - " Write_memvt of " ^ (memory_value_to_string E_big_endian value) ^ ", based on registers " ^ - format_tracking v_tracking ^ "\n" ^ - (format_events events) - | ((E_barrier b_kind)::events) -> - " Memory_barrier occurred\n" ^ - (format_events events) - | (E_read_reg reg_name)::events -> - " Read_reg of " ^ (reg_name_to_string reg_name) ^ "\n" ^ - (format_events events) - | (E_write_reg(reg_name, value))::events -> - " Write_reg of " ^ (reg_name_to_string reg_name) ^ " writing " ^ (register_value_to_string value) ^ "\n" ^ - (format_events events) - | E_escape::events -> - " Escape event\n"^ (format_events events) - | E_footprint::events -> - " Dynamic footprint calculation event\n" ^ (format_events events) -;; + | [] -> " Done\n" + | [E_error s] -> " Failed with message : " ^ s ^ "\n" + | E_error s :: events -> " Failed with message : " ^ s ^ " but continued on erroneously\n" + | E_read_mem (read_kind, Address_lifted (location, _), length, tracking) :: events -> + " Read_mem at " + ^ memory_value_to_string E_big_endian location + ^ " based on registers " ^ format_tracking tracking ^ " for " ^ string_of_int length ^ " bytes \n" + ^ format_events events + | E_read_memt (read_kind, Address_lifted (location, _), length, tracking) :: events -> + " Read_memt at " + ^ memory_value_to_string E_big_endian location + ^ " based on registers " ^ format_tracking tracking ^ " for " ^ string_of_int length ^ " bytes \n" + ^ format_events events + | E_write_mem (write_kind, Address_lifted (location, _), length, tracking, value, v_tracking) :: events -> + " Write_mem at " + ^ memory_value_to_string E_big_endian location + ^ ", based on registers " ^ format_tracking tracking ^ ", writing " + ^ memory_value_to_string E_big_endian value + ^ ", based on " ^ format_tracking v_tracking ^ " across " ^ string_of_int length ^ " bytes\n" + ^ format_events events + | E_write_ea (write_kind, Address_lifted (location, _), length, tracking) :: events -> + " Write_ea at " + ^ memory_value_to_string E_big_endian location + ^ ", based on registers " ^ format_tracking tracking ^ " across " ^ string_of_int length ^ " bytes\n" + ^ format_events events + | E_excl_res :: events -> " Excl_res\n" ^ format_events events + | E_write_memv (_, value, v_tracking) :: events -> + " Write_memv of " + ^ memory_value_to_string E_big_endian value + ^ ", based on registers " ^ format_tracking v_tracking ^ "\n" ^ format_events events + | E_write_memvt (_, (tag, value), v_tracking) :: events -> + " Write_memvt of " + ^ memory_value_to_string E_big_endian value + ^ ", based on registers " ^ format_tracking v_tracking ^ "\n" ^ format_events events + | E_barrier b_kind :: events -> " Memory_barrier occurred\n" ^ format_events events + | E_read_reg reg_name :: events -> " Read_reg of " ^ reg_name_to_string reg_name ^ "\n" ^ format_events events + | E_write_reg (reg_name, value) :: events -> + " Write_reg of " ^ reg_name_to_string reg_name ^ " writing " ^ register_value_to_string value ^ "\n" + ^ format_events events + | E_escape :: events -> " Escape event\n" ^ format_events events + | E_footprint :: events -> " Dynamic footprint calculation event\n" ^ format_events events (* ANSI/VT100 colors *) -type ppmode = - | Interp_latex - | Interp_ascii - | Interp_html +type ppmode = Interp_latex | Interp_ascii | Interp_html let ppmode = ref Interp_ascii -let set_interp_ppmode ppm = ppmode := ppm +let set_interp_ppmode ppm = ppmode := ppm let disable_color = ref false let set_color_enabled on = disable_color := not on let color bright code s = - if !disable_color then s - else sprintf "\x1b[%s3%dm%s\x1b[m" (if bright then "1;" else "") code s -let red s = - match !ppmode with - | Interp_html -> ""^ s ^"" - | Interp_latex -> "\\myred{" ^ s ^"}" - | Interp_ascii -> color true 1 s + if !disable_color then s else sprintf "\x1b[%s3%dm%s\x1b[m" (if bright then "1;" else "") code s +let red s = + match !ppmode with + | Interp_html -> "" ^ s ^ "" + | Interp_latex -> "\\myred{" ^ s ^ "}" + | Interp_ascii -> color true 1 s let green = color false 2 let yellow = color true 3 let blue = color true 4 @@ -437,77 +390,81 @@ let grey = color false 7 let exp_to_string env mem show_hole_value e = Pretty_interp.pp_exp env mem red show_hole_value e -let get_loc (E_aux(_, (l, (_ : tannot)))) = loc_to_string l +let get_loc (E_aux (_, (l, (_ : tannot)))) = loc_to_string l let print_exp printer env mem show_hole_value e = - printer ((get_loc e) ^ ": " ^ (Pretty_interp.pp_exp env mem red show_hole_value e) ^ "\n") + printer (get_loc e ^ ": " ^ Pretty_interp.pp_exp env mem red show_hole_value e ^ "\n") -let instruction_state_to_string (IState(stack, _)) = - List.fold_right (fun (e,(env,mem)) es -> (exp_to_string env mem true e) ^ "\n" ^ es) (compact_stack stack) "" +let instruction_state_to_string (IState (stack, _)) = + List.fold_right (fun (e, (env, mem)) es -> exp_to_string env mem true e ^ "\n" ^ es) (compact_stack stack) "" -let top_instruction_state_to_string (IState(stack,_)) = - let (exp,(env,mem)) = top_frame_exp_state stack in exp_to_string env mem true exp +let top_instruction_state_to_string (IState (stack, _)) = + let exp, (env, mem) = top_frame_exp_state stack in + exp_to_string env mem true exp -let instruction_stack_to_string (IState(stack,_)) = +let instruction_stack_to_string (IState (stack, _)) = let rec stack_to_string = function - Interp.Top -> "" - | Interp.Hole_frame(_,e,_,env,mem,Interp.Top) - | Interp.Thunk_frame(e,_,env,mem,Interp.Top) -> - exp_to_string env mem true e - | Interp.Hole_frame(_,e,_,env,mem,s) - | Interp.Thunk_frame(e,_,env,mem,s) -> - (exp_to_string env mem false e) ^ "\n----------------------------------------------------------\n" ^ - (stack_to_string s) + | Interp.Top -> "" + | Interp.Hole_frame (_, e, _, env, mem, Interp.Top) | Interp.Thunk_frame (e, _, env, mem, Interp.Top) -> + exp_to_string env mem true e + | Interp.Hole_frame (_, e, _, env, mem, s) | Interp.Thunk_frame (e, _, env, mem, s) -> + exp_to_string env mem false e ^ "\n----------------------------------------------------------\n" + ^ stack_to_string s in match stack with - | Interp.Hole_frame(_,(E_aux (E_id (Id_aux (Id "0",_)), _)),_,_,_,s) -> - stack_to_string s + | Interp.Hole_frame (_, E_aux (E_id (Id_aux (Id "0", _)), _), _, _, _, s) -> stack_to_string s | _ -> stack_to_string stack - -let rec option_map f xs = - match xs with - | [] -> [] - | x::xs -> - ( match f x with - | None -> option_map f xs - | Some x -> x :: (option_map f xs) ) - -let local_variables_to_string (IState(stack,_)) = - let (_,(env,mem)) = top_frame_exp_state stack in + +let rec option_map f xs = + match xs with + | [] -> [] + | x :: xs -> ( + match f x with None -> option_map f xs | Some x -> x :: option_map f xs + ) + +let local_variables_to_string (IState (stack, _)) = + let _, (env, mem) = top_frame_exp_state stack in match env with - | Interp.LEnv(_,env) -> - String.concat ", " (option_map (fun (id,value)-> - match id with - | "0" -> None (*Let's not print out the context hole again*) - | _ -> Some (id ^ "=" ^ Interp.string_of_value value)) (Pmap.bindings_list env)) - -let instr_parm_to_string (name, typ, value) = - name ^"="^ + | Interp.LEnv (_, env) -> + String.concat ", " + (option_map + (fun (id, value) -> + match id with + | "0" -> None (*Let's not print out the context hole again*) + | _ -> Some (id ^ "=" ^ Interp.string_of_value value) + ) + (Pmap.bindings_list env) + ) + +let instr_parm_to_string (name, typ, value) = + name ^ "=" + ^ match typ with - | Other -> "Unrepresentable external value" - | _ -> let intern_v = (Interp_inter_imp.intern_ifield_value D_increasing value) in - match Interp_lib.to_num Interp_lib.Unsigned intern_v with - | Interp_ast.V_lit (L_aux(L_num n, _)) -> to_string n - | _ -> ifield_to_string value - -let rec instr_parms_to_string ps = + | Other -> "Unrepresentable external value" + | _ -> ( + let intern_v = Interp_inter_imp.intern_ifield_value D_increasing value in + match Interp_lib.to_num Interp_lib.Unsigned intern_v with + | Interp_ast.V_lit (L_aux (L_num n, _)) -> to_string n + | _ -> ifield_to_string value + ) + +let rec instr_parms_to_string ps = match ps with - | [] -> "" - | [p] -> instr_parm_to_string p - | p::ps -> instr_parm_to_string p ^ " " ^ instr_parms_to_string ps + | [] -> "" + | [p] -> instr_parm_to_string p + | p :: ps -> instr_parm_to_string p ^ " " ^ instr_parms_to_string ps -let pad n s = if String.length s < n then s ^ String.make (n-String.length s) ' ' else s +let pad n s = if String.length s < n then s ^ String.make (n - String.length s) ' ' else s -let instruction_to_string (name, parms) = - ((*pad 5*) (String.lowercase name)) ^ " " ^ instr_parms_to_string parms +let instruction_to_string (name, parms) = (*pad 5*) String.lowercase name ^ " " ^ instr_parms_to_string parms -let print_backtrace_compact printer (IState(stack,_)) = - List.iter (fun (e,(env,mem)) -> print_exp printer env mem true e) (compact_stack stack) +let print_backtrace_compact printer (IState (stack, _)) = + List.iter (fun (e, (env, mem)) -> print_exp printer env mem true e) (compact_stack stack) let print_stack printer is = printer (instruction_stack_to_string is) -let print_continuation printer (IState(stack,_)) = - let (e,(env,mem)) = top_frame_exp_state stack in print_exp printer env mem true e +let print_continuation printer (IState (stack, _)) = + let e, (env, mem) = top_frame_exp_state stack in + print_exp printer env mem true e let print_instruction printer instr = printer (instruction_to_string instr) -let pp_instruction_state state () = - (instruction_stack_to_string state,local_variables_to_string state) +let pp_instruction_state state () = (instruction_stack_to_string state, local_variables_to_string state) diff --git a/src/lem_interp/printing_functions.mli b/src/lem_interp/printing_functions.mli index 908d101d6..78f067965 100644 --- a/src/lem_interp/printing_functions.mli +++ b/src/lem_interp/printing_functions.mli @@ -1,30 +1,31 @@ -open Interp_utilities;; -open Interp_ast ;; -open Sail_impl_base ;; -open Interp_interface ;; +open Interp_utilities +open Interp_ast +open Sail_impl_base +open Interp_interface (*Functions to translate values, registers, or locations strings *) (*Takes a location to a formatted string*) val loc_to_string : l -> string + (*Returns the result of above for the exp's location *) val get_loc : tannot exp -> string + (*interp_interface.value to string*) val reg_value_to_string : register_value -> string (*(*Force all representations to hex strings instead of a mixture of hex and binary strings*) -val val_to_hex_string : value0 -> string*) + val val_to_hex_string : value0 -> string*) (* format one register *) val reg_name_to_string : reg_name -> string + (* format the register dependencies *) val dependencies_to_string : reg_name list -> string + (* formats an expression, using interp_pretty *) val exp_to_string : Interp.lenv -> Interp.lmem -> bool -> tannot exp -> string (* Functions to set the color of parts of the output *) -type ppmode = - | Interp_latex - | Interp_ascii - | Interp_html +type ppmode = Interp_latex | Interp_ascii | Interp_html val set_interp_ppmode : ppmode -> unit val set_color_enabled : bool -> unit @@ -34,28 +35,28 @@ val green : string -> string val yellow : string -> string val grey : string -> string - (*Functions to modify the instruction state and expression used in printing and in run_model*) val compact_exp : tannot exp -> tannot exp -val top_frame_exp_state : interpreter_state -> (tannot exp * (Interp.lenv*Interp.lmem)) - +val top_frame_exp_state : interpreter_state -> tannot exp * (Interp.lenv * Interp.lmem) (*functions to format events and instruction_states to strings *) (*Create one large string of all of the events (indents automatically) *) val format_events : event list -> string + (*format a portion of the instruction state for easy viewing *) val instruction_state_to_string : instruction_state -> string + (*format a the cull instruction call stack*) val instruction_stack_to_string : instruction_state -> string + (*format just the top of the call stack*) val top_instruction_state_to_string : instruction_state -> string val local_variables_to_string : instruction_state -> string - val instruction_to_string : instruction -> string (*Functions to take a print function and cause a print event for the above functions *) -val print_exp : (string-> unit) -> Interp.lenv -> Interp.lmem -> bool -> tannot exp -> unit +val print_exp : (string -> unit) -> Interp.lenv -> Interp.lmem -> bool -> tannot exp -> unit val print_backtrace_compact : (string -> unit) -> instruction_state -> unit val print_continuation : (string -> unit) -> instruction_state -> unit val print_instruction : (string -> unit) -> instruction -> unit @@ -64,7 +65,6 @@ val print_stack : (string -> unit) -> instruction_state -> unit val register_value_to_string : register_value -> string val memory_value_to_string : end_flag -> memory_value -> string - val logfile_register_value_to_string : register_value -> string val logfile_memory_value_to_string : end_flag -> memory_value -> string val logfile_address_to_string : address -> string @@ -72,4 +72,4 @@ val logfile_address_to_string : address -> string val byte_list_to_string : byte list -> string val bit_lifted_to_string : bit_lifted -> string -val pp_instruction_state : instruction_state -> unit -> (string * string) +val pp_instruction_state : instruction_state -> unit -> string * string diff --git a/src/lem_interp/run_interp.ml b/src/lem_interp/run_interp.ml index a70a29e32..2ddc0062c 100644 --- a/src/lem_interp/run_interp.ml +++ b/src/lem_interp/run_interp.ml @@ -65,67 +65,63 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Printf ;; -open Interp_ast ;; -open Interp ;; -open Interp_lib ;; -open Interp_interface ;; -open Interp_inter_imp ;; +open Printf +open Interp_ast +open Interp +open Interp_lib +open Interp_interface +open Interp_inter_imp -open Big_int ;; +open Big_int let lit_to_string = function - | L_unit -> "unit" - | L_zero -> "0b0" - | L_one -> "0b1" - | L_true -> "true" - | L_false -> "false" - | L_num n -> string_of_big_int n - | L_hex s -> "0x"^s - | L_bin s -> "0b"^s - | L_undef -> "undefined" - | L_string s -> "\"" ^ s ^ "\"" -;; + | L_unit -> "unit" + | L_zero -> "0b0" + | L_one -> "0b1" + | L_true -> "true" + | L_false -> "false" + | L_num n -> string_of_big_int n + | L_hex s -> "0x" ^ s + | L_bin s -> "0b" ^ s + | L_undef -> "undefined" + | L_string s -> "\"" ^ s ^ "\"" -let id_to_string = function - | Id_aux(Id s,_) | Id_aux(DeIid s,_) -> s -;; +let id_to_string = function Id_aux (Id s, _) | Id_aux (DeIid s, _) -> s let loc_to_string = function | Unknown -> "location unknown" - | Int(s,_) -> s - | Range(s,fline,fchar,tline,tchar) -> - if fline = tline - then sprintf "%s:%d:%d" s fline fchar - else sprintf "%s:%d:%d-%d:%d" s fline fchar tline tchar -;; + | Int (s, _) -> s + | Range (s, fline, fchar, tline, tchar) -> + if fline = tline then sprintf "%s:%d:%d" s fline fchar else sprintf "%s:%d:%d-%d:%d" s fline fchar tline tchar let collapse_leading s = - if String.length s <= 8 then s else - let first_bit = s.[0] in - let templ = sprintf "%c...%c" first_bit first_bit in - let regexp = Str.regexp "^\\(000000*\\|111111*\\)" in - Str.replace_first regexp templ s -;; + if String.length s <= 8 then s + else ( + let first_bit = s.[0] in + let templ = sprintf "%c...%c" first_bit first_bit in + let regexp = Str.regexp "^\\(000000*\\|111111*\\)" in + Str.replace_first regexp templ s + ) -let bitvec_to_string l = "0b" ^ collapse_leading (String.concat "" (List.map (function - | V_lit(L_aux(L_zero, _)) -> "0" - | V_lit(L_aux(L_one, _)) -> "1" - | _ -> assert false) l)) -;; +let bitvec_to_string l = + "0b" + ^ collapse_leading + (String.concat "" + (List.map (function V_lit (L_aux (L_zero, _)) -> "0" | V_lit (L_aux (L_one, _)) -> "1" | _ -> assert false) l) + ) (*let val_to_string v = match v with | Bitvector(bools, _, _) -> "0b" ^ collapse_leading (String.concat "" (List.map (function | true -> "1" | _ -> "0") bools)) | Bytevector words-> "0x" ^ (String.concat "" - (List.map (function - | 10 -> "A" - | 11 -> "B" - | 12 -> "C" - | 13 -> "D" - | 14 -> "E" - | 15 -> "F" - | i -> string_of_int i) words)) + (List.map (function + | 10 -> "A" + | 11 -> "B" + | 12 -> "C" + | 13 -> "D" + | 14 -> "E" + | 15 -> "F" + | i -> string_of_int i) words)) | Unknown0 -> "Unknown"*) (*let reg_name_to_string = function @@ -135,221 +131,206 @@ let bitvec_to_string l = "0b" ^ collapse_leading (String.concat "" (List.map (fu | Reg_f_slice(s,f,_,(first,second)) -> s ^ "." ^ f*) let rec reg_to_string = function - | Reg (id,_) -> id_to_string id - | SubReg (id,r,_) -> sprintf "%s.%s" (reg_to_string r) (id_to_string id) -;; + | Reg (id, _) -> id_to_string id + | SubReg (id, r, _) -> sprintf "%s.%s" (reg_to_string r) (id_to_string id) let rec top_frame_exp_state = function | Top -> raise (Invalid_argument "top_frame_exp") - | Hole_frame(_, e, _, env, mem, Top) - | Thunk_frame(e, _, env, mem, Top) -> (e,(env,mem)) - | Thunk_frame(_, _, _, _, s) - | Hole_frame(_, _, _, _, _, s) -> top_frame_exp_state s + | Hole_frame (_, e, _, env, mem, Top) | Thunk_frame (e, _, env, mem, Top) -> (e, (env, mem)) + | Thunk_frame (_, _, _, _, s) | Hole_frame (_, _, _, _, _, s) -> top_frame_exp_state s -let tunk = Unknown, None -let ldots = E_aux(E_id (Id_aux (Id "...", Unknown)), tunk) +let tunk = (Unknown, None) +let ldots = E_aux (E_id (Id_aux (Id "...", Unknown)), tunk) let rec compact_exp (E_aux (e, l)) = let wrap e = E_aux (e, l) in match e with - | E_block (e :: _) -> compact_exp e - | E_nondet (e :: _) -> compact_exp e - | E_if (e, _, _) -> - wrap(E_if(compact_exp e, ldots, E_aux(E_block [], tunk))) - | E_for (i, e1, e2, e3, o, e4) -> - wrap(E_for(i, compact_exp e1, compact_exp e2, compact_exp e3, o, ldots)) - | E_match (e, _) -> - wrap(E_match(compact_exp e, [])) - | E_let (bind, _) -> wrap(E_let(bind, ldots)) - | E_app (f, args) -> wrap(E_app(f, List.map compact_exp args)) - | E_app_infix (l, op, r) -> wrap(E_app_infix(compact_exp l, op, compact_exp r)) - | E_tuple exps -> wrap(E_tuple(List.map compact_exp exps)) - | E_vector exps -> wrap(E_vector(List.map compact_exp exps)) - | E_vector_access (e1, e2) -> - wrap(E_vector_access(compact_exp e1, compact_exp e2)) - | E_vector_subrange (e1, e2, e3) -> - wrap(E_vector_subrange(compact_exp e1, compact_exp e2, compact_exp e3)) - | E_vector_update (e1, e2, e3) -> - wrap(E_vector_update(compact_exp e1, compact_exp e2, compact_exp e3)) - | E_vector_update_subrange (e1, e2, e3, e4) -> - wrap(E_vector_update_subrange(compact_exp e1, compact_exp e2, compact_exp e3, compact_exp e4)) - | E_vector_append (e1, e2) -> - wrap(E_vector_append(compact_exp e1, compact_exp e2)) - | E_list exps -> wrap(E_list(List.map compact_exp exps)) - | E_cons (e1, e2) -> - wrap(E_cons(compact_exp e1, compact_exp e2)) - | E_struct_update (e, fexps) -> - wrap(E_struct_update (compact_exp e, fexps)) - | E_field (e, id) -> - wrap(E_field(compact_exp e, id)) - | E_assign (lexp, e) -> wrap(E_assign(lexp, compact_exp e)) - | E_block [] | E_nondet [] | E_typ (_, _) | E_internal_cast (_, _) - | E_id _|E_lit _|E_vector_indexed (_, _)|E_struct _|E_internal_exp _ -> - wrap e + | E_block (e :: _) -> compact_exp e + | E_nondet (e :: _) -> compact_exp e + | E_if (e, _, _) -> wrap (E_if (compact_exp e, ldots, E_aux (E_block [], tunk))) + | E_for (i, e1, e2, e3, o, e4) -> wrap (E_for (i, compact_exp e1, compact_exp e2, compact_exp e3, o, ldots)) + | E_match (e, _) -> wrap (E_match (compact_exp e, [])) + | E_let (bind, _) -> wrap (E_let (bind, ldots)) + | E_app (f, args) -> wrap (E_app (f, List.map compact_exp args)) + | E_app_infix (l, op, r) -> wrap (E_app_infix (compact_exp l, op, compact_exp r)) + | E_tuple exps -> wrap (E_tuple (List.map compact_exp exps)) + | E_vector exps -> wrap (E_vector (List.map compact_exp exps)) + | E_vector_access (e1, e2) -> wrap (E_vector_access (compact_exp e1, compact_exp e2)) + | E_vector_subrange (e1, e2, e3) -> wrap (E_vector_subrange (compact_exp e1, compact_exp e2, compact_exp e3)) + | E_vector_update (e1, e2, e3) -> wrap (E_vector_update (compact_exp e1, compact_exp e2, compact_exp e3)) + | E_vector_update_subrange (e1, e2, e3, e4) -> + wrap (E_vector_update_subrange (compact_exp e1, compact_exp e2, compact_exp e3, compact_exp e4)) + | E_vector_append (e1, e2) -> wrap (E_vector_append (compact_exp e1, compact_exp e2)) + | E_list exps -> wrap (E_list (List.map compact_exp exps)) + | E_cons (e1, e2) -> wrap (E_cons (compact_exp e1, compact_exp e2)) + | E_struct_update (e, fexps) -> wrap (E_struct_update (compact_exp e, fexps)) + | E_field (e, id) -> wrap (E_field (compact_exp e, id)) + | E_assign (lexp, e) -> wrap (E_assign (lexp, compact_exp e)) + | E_block [] + | E_nondet [] + | E_typ (_, _) + | E_internal_cast (_, _) + | E_id _ | E_lit _ + | E_vector_indexed (_, _) + | E_struct _ | E_internal_exp _ -> + wrap e (* extract, compact and reverse expressions on the stack; * the top of the stack is the head of the returned list. *) -let rec compact_stack ?(acc=[]) = function +let rec compact_stack ?(acc = []) = function | Top -> acc - | Hole_frame(_,e,_,env,mem,s) - | Thunk_frame(e,_,env,mem,s) -> compact_stack ~acc:(((compact_exp e),(env,mem)) :: acc) s -;; + | Hole_frame (_, e, _, env, mem, s) | Thunk_frame (e, _, env, mem, s) -> + compact_stack ~acc:((compact_exp e, (env, mem)) :: acc) s -let sub_to_string = function None -> "" | Some (x, y) -> sprintf " (%s, %s)" - (string_of_big_int x) (string_of_big_int y) -;; +let sub_to_string = function + | None -> "" + | Some (x, y) -> sprintf " (%s, %s)" (string_of_big_int x) (string_of_big_int y) -let id_compare i1 i2 = - match (i1, i2) with - | (Id_aux(Id(i1),_),Id_aux(Id(i2),_)) - | (Id_aux(Id(i1),_),Id_aux(DeIid(i2),_)) - | (Id_aux(DeIid(i1),_),Id_aux(Id(i2),_)) - | (Id_aux(DeIid(i1),_),Id_aux(DeIid(i2),_)) -> compare i1 i2 +let id_compare i1 i2 = + match (i1, i2) with + | Id_aux (Id i1, _), Id_aux (Id i2, _) + | Id_aux (Id i1, _), Id_aux (DeIid i2, _) + | Id_aux (DeIid i1, _), Id_aux (Id i2, _) + | Id_aux (DeIid i1, _), Id_aux (DeIid i2, _) -> + compare i1 i2 module Reg = struct - include Map.Make(struct type t = id let compare = id_compare end) - let to_string id v = - sprintf "%s -> %s" (id_to_string id) (string_of_value v) + include Map.Make (struct + type t = id + let compare = id_compare + end) + let to_string id v = sprintf "%s -> %s" (id_to_string id) (string_of_value v) let find id m = -(* eprintf "reg_find called with %s\n" (id_to_string id);*) + (* eprintf "reg_find called with %s\n" (id_to_string id);*) let v = find id m in -(* eprintf "%s -> %s\n" (id_to_string id) (val_to_string v);*) + (* eprintf "%s -> %s\n" (id_to_string id) (val_to_string v);*) v -end ;; +end + (* Old Mem, that used the id to map as well as the int... which seems wrong -module Mem = struct - include Map.Make(struct - type t = (id * big_int) - let compare (i1, v1) (i2, v2) = - (* optimize for common case: different addresses, same id *) - match compare_big_int v1 v2 with - | 0 -> id_compare i1 i2 - | n -> n - end) - (* debugging memory accesses - let add (n, idx) v m = - eprintf "%s[%s] <- %s\n" (id_to_string n) (string_of_big_int idx) (val_to_string v); - add (n, idx) v m - let find (n, idx) m = - let v = find (n, idx) m in - eprintf "%s[%s] -> %s\n" (id_to_string n) (string_of_big_int idx) (val_to_string v); - v - *) - let to_string (n, idx) v = - sprintf "%s[%s] -> %s" (id_to_string n) (string_of_big_int idx) (val_to_string v) -end ;;*) + module Mem = struct + include Map.Make(struct + type t = (id * big_int) + let compare (i1, v1) (i2, v2) = + (* optimize for common case: different addresses, same id *) + match compare_big_int v1 v2 with + | 0 -> id_compare i1 i2 + | n -> n + end) + (* debugging memory accesses + let add (n, idx) v m = + eprintf "%s[%s] <- %s\n" (id_to_string n) (string_of_big_int idx) (val_to_string v); + add (n, idx) v m + let find (n, idx) m = + let v = find (n, idx) m in + eprintf "%s[%s] -> %s\n" (id_to_string n) (string_of_big_int idx) (val_to_string v); + v + *) + let to_string (n, idx) v = + sprintf "%s[%s] -> %s" (id_to_string n) (string_of_big_int idx) (val_to_string v) + end ;;*) module Mem = struct - include Map.Make(struct + include Map.Make (struct type t = big_int let compare v1 v2 = compare_big_int v1 v2 end) + (* debugging memory accesses - let add idx v m = - eprintf "[%s] <- %s\n" (string_of_big_int idx) (val_to_string v); - add idx v m - let find idx m = - let v = find idx m in - eprintf "[%s] -> %s\n" (string_of_big_int idx) (val_to_string v); - v + let add idx v m = + eprintf "[%s] <- %s\n" (string_of_big_int idx) (val_to_string v); + add idx v m + let find idx m = + let v = find idx m in + eprintf "[%s] -> %s\n" (string_of_big_int idx) (val_to_string v); + v *) - let to_string idx v = - sprintf "[%s] -> %s" (string_of_big_int idx) (string_of_value v) -end ;; - + let to_string idx v = sprintf "[%s] -> %s" (string_of_big_int idx) (string_of_value v) +end -let vconcat v v' = vec_concat (V_tuple [v; v']) ;; +let vconcat v v' = vec_concat (V_tuple [v; v']) -let slice v = function - | None -> v - | Some (n, m) -> slice_vector v n m -;; +let slice v = function None -> v | Some (n, m) -> slice_vector v n m let rec slice_ir v = function | BF_single n -> slice_vector v n n | BF_range (n, m) -> slice_vector v n m | BF_concat (BF_aux (ir, _), BF_aux (ir', _)) -> vconcat (slice_ir v ir) (slice_ir v ir') -;; -let unit_lit = V_lit (L_aux(L_unit,Interp_ast.Unknown)) +let unit_lit = V_lit (L_aux (L_unit, Interp_ast.Unknown)) let rec perform_action ((reg, mem) as env) = function - (* registers *) - | Read_reg (Reg (id, _), sub) -> - slice (Reg.find id reg) sub, env - | Write_reg (Reg (id, _), None, value) -> - unit_lit, (Reg.add id value reg, mem) - | Write_reg (Reg (id, _), Some (start, stop), (V_vector _ as value)) -> - let old_val = Reg.find id reg in - let new_val = fupdate_vector_slice old_val value start stop in - unit_lit, (Reg.add id new_val reg, mem) - (* subregisters *) - | Read_reg (SubReg (_, Reg (id, _), BF_aux (ir, _)), sub) -> - slice (slice_ir (Reg.find id reg) ir) sub, env - | Write_reg (SubReg (_, (Reg _ as r), BF_aux (ir, _)), None, value) -> - (match ir with - | BF_single n -> - perform_action env (Write_reg (r, Some(n, n), value)) - | BF_range (n, m) -> - perform_action env (Write_reg (r, Some(n, m), value)) - | BF_concat _ -> failwith "unimplemented: non-contiguous register write") - (* memory *) - | Read_mem (id, V_lit(L_aux((L_num n),_)), sub) -> - slice (Mem.find n mem) sub, env - | Write_mem (id, V_lit(L_aux(L_num n,_)), None, value) -> - unit_lit, (reg, Mem.add n value mem) - (* multi-byte accesses to memory *) - | Read_mem (id, V_tuple [V_lit(L_aux(L_num n,_)); V_lit(L_aux(L_num size,_))], sub) -> - let rec fetch k acc = - if eq_big_int k size then slice acc sub else - let slice = Mem.find (add_big_int n k) mem in - fetch (succ_big_int k) (vconcat acc slice) - in - fetch zero_big_int (V_vector (zero_big_int, true, [])), env - (* XXX no support for multi-byte slice write at the moment *) - | Write_mem (id, V_tuple [V_lit(L_aux(L_num n,_)); V_lit(L_aux(L_num size,_))], None, V_vector (m, inc, vs)) -> - (* normalize input vector so that it is indexed from 0 - for slices *) - let value = V_vector (zero_big_int, inc, vs) in - (* assumes smallest unit of memory is 8 bit *) - let byte_size = 8 in - let rec update k mem = - if eq_big_int k size then mem else - let n1 = mult_int_big_int byte_size k in - let n2 = sub_big_int (mult_int_big_int byte_size (succ_big_int k)) (big_int_of_int 1) in - let slice = slice_vector value n1 n2 in - let mem' = Mem.add (add_big_int n k) slice mem in - update (succ_big_int k) mem' - in unit_lit, (reg, update zero_big_int mem) - (* This case probably never happens in the POWER spec anyway *) - | Write_mem (id, V_lit(L_aux(L_num n,_)), Some (start, stop), (V_vector _ as value)) -> - let old_val = Mem.find n mem in - let new_val = fupdate_vector_slice old_val value start stop in - unit_lit, (reg, Mem.add n new_val mem) - (* special case for slices of size 1: wrap value in a vector *) - | Write_reg ((Reg (_, _) as r), (Some (start, stop) as slice), value) when eq_big_int start stop -> - perform_action env (Write_reg (r, slice, V_vector(zero_big_int, true, [value]))) - | Write_mem (id, (V_lit(L_aux(L_num _,_)) as n), (Some (start, stop) as slice), value) when eq_big_int start stop -> - perform_action env (Write_mem (id, n, slice, V_vector(zero_big_int, true, [value]))) - (* extern functions *) - | Call_extern (name, arg) -> eval_external name arg, env - | Interp.Step _ | Nondet _ | Exit _ -> unit_lit, env - | _ -> assert false -;; + (* registers *) + | Read_reg (Reg (id, _), sub) -> (slice (Reg.find id reg) sub, env) + | Write_reg (Reg (id, _), None, value) -> (unit_lit, (Reg.add id value reg, mem)) + | Write_reg (Reg (id, _), Some (start, stop), (V_vector _ as value)) -> + let old_val = Reg.find id reg in + let new_val = fupdate_vector_slice old_val value start stop in + (unit_lit, (Reg.add id new_val reg, mem)) + (* subregisters *) + | Read_reg (SubReg (_, Reg (id, _), BF_aux (ir, _)), sub) -> (slice (slice_ir (Reg.find id reg) ir) sub, env) + | Write_reg (SubReg (_, (Reg _ as r), BF_aux (ir, _)), None, value) -> ( + match ir with + | BF_single n -> perform_action env (Write_reg (r, Some (n, n), value)) + | BF_range (n, m) -> perform_action env (Write_reg (r, Some (n, m), value)) + | BF_concat _ -> failwith "unimplemented: non-contiguous register write" + ) + (* memory *) + | Read_mem (id, V_lit (L_aux (L_num n, _)), sub) -> (slice (Mem.find n mem) sub, env) + | Write_mem (id, V_lit (L_aux (L_num n, _)), None, value) -> (unit_lit, (reg, Mem.add n value mem)) + (* multi-byte accesses to memory *) + | Read_mem (id, V_tuple [V_lit (L_aux (L_num n, _)); V_lit (L_aux (L_num size, _))], sub) -> + let rec fetch k acc = + if eq_big_int k size then slice acc sub + else ( + let slice = Mem.find (add_big_int n k) mem in + fetch (succ_big_int k) (vconcat acc slice) + ) + in + (fetch zero_big_int (V_vector (zero_big_int, true, [])), env) + (* XXX no support for multi-byte slice write at the moment *) + | Write_mem (id, V_tuple [V_lit (L_aux (L_num n, _)); V_lit (L_aux (L_num size, _))], None, V_vector (m, inc, vs)) -> + (* normalize input vector so that it is indexed from 0 - for slices *) + let value = V_vector (zero_big_int, inc, vs) in + (* assumes smallest unit of memory is 8 bit *) + let byte_size = 8 in + let rec update k mem = + if eq_big_int k size then mem + else ( + let n1 = mult_int_big_int byte_size k in + let n2 = sub_big_int (mult_int_big_int byte_size (succ_big_int k)) (big_int_of_int 1) in + let slice = slice_vector value n1 n2 in + let mem' = Mem.add (add_big_int n k) slice mem in + update (succ_big_int k) mem' + ) + in + (unit_lit, (reg, update zero_big_int mem)) + (* This case probably never happens in the POWER spec anyway *) + | Write_mem (id, V_lit (L_aux (L_num n, _)), Some (start, stop), (V_vector _ as value)) -> + let old_val = Mem.find n mem in + let new_val = fupdate_vector_slice old_val value start stop in + (unit_lit, (reg, Mem.add n new_val mem)) + (* special case for slices of size 1: wrap value in a vector *) + | Write_reg ((Reg (_, _) as r), (Some (start, stop) as slice), value) when eq_big_int start stop -> + perform_action env (Write_reg (r, slice, V_vector (zero_big_int, true, [value]))) + | Write_mem (id, (V_lit (L_aux (L_num _, _)) as n), (Some (start, stop) as slice), value) when eq_big_int start stop + -> + perform_action env (Write_mem (id, n, slice, V_vector (zero_big_int, true, [value]))) + (* extern functions *) + | Call_extern (name, arg) -> (eval_external name arg, env) + | Interp.Step _ | Nondet _ | Exit _ -> (unit_lit, env) + | _ -> assert false let debug = ref true let debugf : ('a, out_channel, unit) format -> 'a = function f -> if !debug then eprintf f else ifprintf stderr f type interactive_mode = Step | Run | Next -let mode_to_string = function - | Step -> "step" - | Run -> "run" - | Next -> "next" +let mode_to_string = function Step -> "step" | Run -> "run" | Next -> "next" (* ANSI/VT100 colors *) let disable_color = ref false let color bright code s = - if !disable_color then s - else sprintf "\x1b[%s3%dm%s\x1b[m" (if bright then "1;" else "") code s + if !disable_color then s else sprintf "\x1b[%s3%dm%s\x1b[m" (if bright then "1;" else "") code s let red = color true 1 let green = color false 2 let yellow = color true 3 @@ -357,128 +338,142 @@ let blue = color true 4 let grey = color false 7 let run - ?(entry=E_aux(E_app(Id_aux((Id "main"),Unknown), [E_aux(E_lit (L_aux(L_unit,Unknown)),(Unknown,None))]),(Unknown,None))) - ?(reg=Reg.empty) - ?(mem=Mem.empty) - ?(eager_eval=true) - ?mode - (name, test) = - let get_loc (E_aux(_, (l, _))) = loc_to_string l in - let print_exp env e = - debugf "%s: %s\n" (get_loc e) (Pretty_interp.pp_exp env Printing_functions.red e) in + ?(entry = + E_aux + ( E_app (Id_aux (Id "main", Unknown), [E_aux (E_lit (L_aux (L_unit, Unknown)), (Unknown, None))]), + (Unknown, None) + )) ?(reg = Reg.empty) ?(mem = Mem.empty) ?(eager_eval = true) ?mode (name, test) = + let get_loc (E_aux (_, (l, _))) = loc_to_string l in + let print_exp env e = debugf "%s: %s\n" (get_loc e) (Pretty_interp.pp_exp env Printing_functions.red e) in (* interactive loop for step-by-step execution *) - let usage = "Usage: - step go to next action [default] - next go to next break point - run complete current execution, - bt print call stack - cont print continuation of the top stack frame - env print content of environment - mem print content of memory - quit exit interpreter" in + let usage = + "Usage:\n\ + \ step go to next action [default]\n\ + \ next go to next break point\n\ + \ run complete current execution,\n\ + \ bt print call stack\n\ + \ cont print continuation of the top stack frame\n\ + \ env print content of environment\n\ + \ mem print content of memory\n\ + \ quit exit interpreter" + in let rec interact mode ((reg, mem) as env) stack = - flush_all(); + flush_all (); let command = Pervasives.read_line () in let command' = if command = "" then mode_to_string mode else command in - begin match command' with - | "s" | "step" -> Step - | "n" | "next" -> Next - | "r" | "run" -> Run - | "e" | "env" | "environment" -> - Reg.iter (fun k v -> debugf "%s\n" (Reg.to_string k v)) reg; - interact mode env stack - | "m" | "mem" | "memory" -> - Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) mem; - interact mode env stack - | "bt" | "backtrace" | "stack" -> - List.iter (fun (e,(env,mem)) -> print_exp env e) (compact_stack stack); - interact mode env stack - | "c" | "cont" | "continuation" -> - (* print not-compacted continuation *) - let (e,(lenv,lmem)) = top_frame_exp_state stack in - print_exp lenv e; - interact mode env stack - | "show_casts" -> - Pretty_interp.ignore_casts := false; - interact mode env stack - | "hide_casts" -> - Pretty_interp.ignore_casts := true; - interact mode env stack - | "q" | "quit" | "exit" -> exit 0 - | _ -> debugf "%s\n" usage; interact mode env stack + begin + match command' with + | "s" | "step" -> Step + | "n" | "next" -> Next + | "r" | "run" -> Run + | "e" | "env" | "environment" -> + Reg.iter (fun k v -> debugf "%s\n" (Reg.to_string k v)) reg; + interact mode env stack + | "m" | "mem" | "memory" -> + Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) mem; + interact mode env stack + | "bt" | "backtrace" | "stack" -> + List.iter (fun (e, (env, mem)) -> print_exp env e) (compact_stack stack); + interact mode env stack + | "c" | "cont" | "continuation" -> + (* print not-compacted continuation *) + let e, (lenv, lmem) = top_frame_exp_state stack in + print_exp lenv e; + interact mode env stack + | "show_casts" -> + Pretty_interp.ignore_casts := false; + interact mode env stack + | "hide_casts" -> + Pretty_interp.ignore_casts := true; + interact mode env stack + | "q" | "quit" | "exit" -> exit 0 + | _ -> + debugf "%s\n" usage; + interact mode env stack end in let rec loop mode env = function - | Value v -> - debugf "%s: %s %s\n" (grey name) (blue "return") (string_of_value v); - true, mode, env - | Action (a, s) -> - let (top_exp,(top_env,top_mem)) = top_frame_exp_state s in - let loc = get_loc (compact_exp top_exp) in - let return, env' = perform_action env a in - let step ?(force=false) () = - if mode = Step || force then begin - debugf "%s\n" (Pretty_interp.pp_exp top_env Printing_functions.red top_exp); - interact mode env s - end else - mode in - let show act lhs arrow rhs = debugf "%s: %s: %s %s %s\n" - (grey loc) (green act) lhs (blue arrow) rhs in - let left = "<-" and right = "->" in - let (mode',env',s) = begin match a with - | Read_reg (reg, sub) -> - show "read_reg" (reg_to_string reg ^ sub_to_string sub) right (string_of_value return); - step (),env',s - | Write_reg (reg, sub, value) -> - assert (return = unit_lit); - show "write_reg" (reg_to_string reg ^ sub_to_string sub) left (string_of_value value); - step (),env',s - | Read_mem (id, args, sub) -> - show "read_mem" (id_to_string id ^ string_of_value args ^ sub_to_string sub) right (string_of_value return); - step (),env',s - | Write_mem (id, args, sub, value) -> - assert (return = unit_lit); - show "write_mem" (id_to_string id ^ string_of_value args ^ sub_to_string sub) left (string_of_value value); - step (),env',s - (* distinguish single argument for pretty-printing *) - | Call_extern (f, (V_tuple _ as args)) -> - show "call_lib" (f ^ string_of_value args) right (string_of_value return); - step (),env',s - | Call_extern (f, arg) -> - show "call_lib" (sprintf "%s(%s)" f (string_of_value arg)) right (string_of_value return); - step (),env',s - | Interp.Step _ -> - assert (return = unit_lit); - show "breakpoint" "" "" ""; - step ~force:true (),env',s - | Nondet exps -> - let stacks = List.sort (fun (_,i1) (_,i2) -> compare i1 i2) - (List.combine (List.map (set_in_context s) exps) - (List.map (fun _ -> Random.bits ()) exps)) in - show "nondeterministic evaluation begun" "" "" ""; - let (_,_,env') = List.fold_right (fun (stack,_) (_,_,env') -> loop mode env' (resume {eager_eval = (mode = Run); track_values = false;} stack None)) stacks (false,mode,env'); in - show "nondeterministic evaluation ended" "" "" ""; - step (),env',s - | Exit e -> - show "exiting current evaluation" "" "" ""; - step (),env', (set_in_context s e) - | Barrier (_, _) | Write_next_IA _ -> - failwith "unexpected action" - end in - loop mode' env' (resume {eager_eval = (mode' = Run);track_values = false} s (Some return)) - | Error(l, e) -> - debugf "%s: %s: %s\n" (grey (loc_to_string l)) (red "error") e; - false, mode, env in - debugf "%s: %s %s\n" (grey name) (blue "evaluate") - (Pretty_interp.pp_exp Interp.eenv Printing_functions.red entry); - let mode = match mode with - | None -> if eager_eval then Run else Step - | Some m -> m in + | Value v -> + debugf "%s: %s %s\n" (grey name) (blue "return") (string_of_value v); + (true, mode, env) + | Action (a, s) -> + let top_exp, (top_env, top_mem) = top_frame_exp_state s in + let loc = get_loc (compact_exp top_exp) in + let return, env' = perform_action env a in + let step ?(force = false) () = + if mode = Step || force then begin + debugf "%s\n" (Pretty_interp.pp_exp top_env Printing_functions.red top_exp); + interact mode env s + end + else mode + in + let show act lhs arrow rhs = debugf "%s: %s: %s %s %s\n" (grey loc) (green act) lhs (blue arrow) rhs in + let left = "<-" and right = "->" in + let mode', env', s = + begin + match a with + | Read_reg (reg, sub) -> + show "read_reg" (reg_to_string reg ^ sub_to_string sub) right (string_of_value return); + (step (), env', s) + | Write_reg (reg, sub, value) -> + assert (return = unit_lit); + show "write_reg" (reg_to_string reg ^ sub_to_string sub) left (string_of_value value); + (step (), env', s) + | Read_mem (id, args, sub) -> + show "read_mem" + (id_to_string id ^ string_of_value args ^ sub_to_string sub) + right (string_of_value return); + (step (), env', s) + | Write_mem (id, args, sub, value) -> + assert (return = unit_lit); + show "write_mem" + (id_to_string id ^ string_of_value args ^ sub_to_string sub) + left (string_of_value value); + (step (), env', s) + (* distinguish single argument for pretty-printing *) + | Call_extern (f, (V_tuple _ as args)) -> + show "call_lib" (f ^ string_of_value args) right (string_of_value return); + (step (), env', s) + | Call_extern (f, arg) -> + show "call_lib" (sprintf "%s(%s)" f (string_of_value arg)) right (string_of_value return); + (step (), env', s) + | Interp.Step _ -> + assert (return = unit_lit); + show "breakpoint" "" "" ""; + (step ~force:true (), env', s) + | Nondet exps -> + let stacks = + List.sort + (fun (_, i1) (_, i2) -> compare i1 i2) + (List.combine (List.map (set_in_context s) exps) (List.map (fun _ -> Random.bits ()) exps)) + in + show "nondeterministic evaluation begun" "" "" ""; + let _, _, env' = + List.fold_right + (fun (stack, _) (_, _, env') -> + loop mode env' (resume { eager_eval = mode = Run; track_values = false } stack None) + ) + stacks (false, mode, env') + in + show "nondeterministic evaluation ended" "" "" ""; + (step (), env', s) + | Exit e -> + show "exiting current evaluation" "" "" ""; + (step (), env', set_in_context s e) + | Barrier (_, _) | Write_next_IA _ -> failwith "unexpected action" + end + in + loop mode' env' (resume { eager_eval = mode' = Run; track_values = false } s (Some return)) + | Error (l, e) -> + debugf "%s: %s: %s\n" (grey (loc_to_string l)) (red "error") e; + (false, mode, env) + in + debugf "%s: %s %s\n" (grey name) (blue "evaluate") (Pretty_interp.pp_exp Interp.eenv Printing_functions.red entry); + let mode = match mode with None -> if eager_eval then Run else Step | Some m -> m in try Printexc.record_backtrace true; - loop mode (reg, mem) (interp {eager_eval = eager_eval; track_values = false} (fun id -> None) test entry) + loop mode (reg, mem) (interp { eager_eval; track_values = false } (fun id -> None) test entry) with e -> let trace = Printexc.get_backtrace () in debugf "%s: %s %s\n%s\n" (grey name) (red "interpretor error") (Printexc.to_string e) trace; - false, mode, (reg, mem) -;; + (false, mode, (reg, mem)) diff --git a/src/lem_interp/run_interp_model.ml b/src/lem_interp/run_interp_model.ml index 26da9eb98..5308d915c 100644 --- a/src/lem_interp/run_interp_model.ml +++ b/src/lem_interp/run_interp_model.ml @@ -75,432 +75,456 @@ open Printing_functions open Nat_big_num module Reg = struct - include Map.Make(struct type t = string let compare = String.compare end) - let to_string id v = - sprintf "%s -> %s" id (register_value_to_string v) - let find id m = -(* eprintf "reg_find called with %s\n" id; *) + include Map.Make (struct + type t = string + let compare = String.compare + end) + let to_string id v = sprintf "%s -> %s" id (register_value_to_string v) + let find id m = + (* eprintf "reg_find called with %s\n" id; *) let v = find id m in -(* eprintf "%s -> %s\n" id (val_to_string v);*) + (* eprintf "%s -> %s\n" id (val_to_string v);*) v -end ;; +end -let compare_addresses (Address_lifted(v1,n1)) (Address_lifted(v2,n2)) = - let rec comp v1s v2s = match (v1s,v2s) with - | ([],[]) -> 0 - | ([],_) -> -1 - | (_,[]) -> 1 - | (v1::v1s,v2::v2s) -> - match Pervasives.compare v1 v2 with - | 0 -> comp v1s v2s - | ans -> ans in - match n1,n2 with - | Some(n1),Some(n2) -> compare n1 n2 +let compare_addresses (Address_lifted (v1, n1)) (Address_lifted (v2, n2)) = + let rec comp v1s v2s = + match (v1s, v2s) with + | [], [] -> 0 + | [], _ -> -1 + | _, [] -> 1 + | v1 :: v1s, v2 :: v2s -> ( + match Pervasives.compare v1 v2 with 0 -> comp v1s v2s | ans -> ans + ) + in + match (n1, n2) with + | Some n1, Some n2 -> compare n1 n2 | _ -> - let l1 = List.length v1 in - let l2 = List.length v2 in - if l1 > l2 then 1 - else if l1 < l2 then -1 - else comp v1 v2 + let l1 = List.length v1 in + let l2 = List.length v2 in + if l1 > l2 then 1 else if l1 < l2 then -1 else comp v1 v2 let default_endian = ref E_big_endian let default_order = ref D_increasing module Mem = struct - include Map.Make(struct - type t = num - let compare v1 v2 = compare v1 v2 - end) - let find idx m = - if mem idx m - then find idx m - else List.hd(memory_value_undef 1) + include Map.Make (struct + type t = num + let compare v1 v2 = compare v1 v2 + end) + let find idx m = if mem idx m then find idx m else List.hd (memory_value_undef 1) - let to_string loc v = - sprintf "[%s] -> %s" (to_string loc) - (memory_value_to_string !default_endian [v]) + let to_string loc v = sprintf "[%s] -> %s" (to_string loc) (memory_value_to_string !default_endian [v]) end -let slice register_vector (start,stop) = - if register_vector.rv_dir = D_increasing - then slice_reg_value register_vector start stop - else - (*Interface turns start and stop into forms for +let slice register_vector (start, stop) = + if register_vector.rv_dir = D_increasing then slice_reg_value register_vector start stop + else ( + (*Interface turns start and stop into forms for increasing because ppcmem only speaks increasing, so here we turn it back *) - let startd = register_vector.rv_start_internal- start in + let startd = register_vector.rv_start_internal - start in let stopd = startd - (stop - start) in -(* let _ = Printf.eprintf "slice decreasing with %i, %i, %i\n" startd stopd register_vector.rv_start in*) + (* let _ = Printf.eprintf "slice decreasing with %i, %i, %i\n" startd stopd register_vector.rv_start in*) slice_reg_value register_vector start stop + ) let big_num_unit = of_int 1 -let rec list_update index start stop e vals = - match vals with - | [] -> [] - | x :: xs -> - if Nat_big_num.equal index stop - then e :: xs - else if Nat_big_num.greater_equal index start - then e :: (list_update (Nat_big_num.add index big_num_unit) start stop e xs) - else x :: (list_update (Nat_big_num.add index big_num_unit) start stop e xs) +let rec list_update index start stop e vals = + match vals with + | [] -> [] + | x :: xs -> + if Nat_big_num.equal index stop then e :: xs + else if Nat_big_num.greater_equal index start then + e :: list_update (Nat_big_num.add index big_num_unit) start stop e xs + else x :: list_update (Nat_big_num.add index big_num_unit) start stop e xs let rec list_update_list index start stop es vals = match vals with | [] -> [] - | x :: xs -> - match es with - | [] -> xs - | e::es -> - if Nat_big_num.equal index stop - then e::xs - else if Nat_big_num.greater_equal index start - then e :: (list_update_list (Nat_big_num.add index big_num_unit) start stop es xs) - else x :: (list_update_list (Nat_big_num.add index big_num_unit) start stop (e::es) xs) + | x :: xs -> ( + match es with + | [] -> xs + | e :: es -> + if Nat_big_num.equal index stop then e :: xs + else if Nat_big_num.greater_equal index start then + e :: list_update_list (Nat_big_num.add index big_num_unit) start stop es xs + else x :: list_update_list (Nat_big_num.add index big_num_unit) start stop (e :: es) xs + ) -let fupdate_slice reg_name original e (start,stop) = - if original.rv_dir = D_increasing - then update_reg_value_slice reg_name original start stop e - else - (*Interface turns start and stop into forms for +let fupdate_slice reg_name original e (start, stop) = + if original.rv_dir = D_increasing then update_reg_value_slice reg_name original start stop e + else ( + (*Interface turns start and stop into forms for increasing because ppcmem only speaks increasing, so here we turn it back *) - let startd = original.rv_start_internal- start in + let startd = original.rv_start_internal - start in let stopd = startd - (stop - start) in -(* let _ = Printf.eprintf "fupdate_slice: starts at %i, %i -> %i,%i -> %i\n" original.rv_start_internal start startd stop stopd in *) + (* let _ = Printf.eprintf "fupdate_slice: starts at %i, %i -> %i,%i -> %i\n" original.rv_start_internal start startd stop stopd in *) update_reg_value_slice reg_name original startd stopd e + ) -let combine_slices (start, stop) (inner_start,inner_stop) = - (start + inner_start, start + inner_stop) +let combine_slices (start, stop) (inner_start, inner_stop) = (start + inner_start, start + inner_stop) -let unit_lit = (L_aux(L_unit,Interp_ast.Unknown)) +let unit_lit = L_aux (L_unit, Interp_ast.Unknown) -let align_addr addr size = - sub addr (modulus addr size) +let align_addr addr size = sub addr (modulus addr size) -let rec perform_action ((reg, mem, tagmem) as env, cap_size) = function - (* registers *) - | Read_reg1(Reg(id,_,_,_), _) -> (Some(Reg.find id reg), env) - | Read_reg1(Reg_slice(id, _, _, range), _) - | Read_reg1(Reg_field(id, _, _, _, range), _) -> (Some(slice (Reg.find id reg) range), env) - | Read_reg1(Reg_f_slice(id, _,_,_, range, mini_range), _) -> - (Some(slice (slice (Reg.find id reg) range) mini_range),env) - | Write_reg1(Reg(id,_,_,_), value, _) -> (None, (Reg.add id value reg,mem,tagmem)) - | Write_reg1((Reg_slice(id,_,_,range) as reg_n),value, _) - | Write_reg1((Reg_field(id,_,_,_,range) as reg_n),value,_)-> - let old_val = Reg.find id reg in - let new_val = fupdate_slice reg_n old_val value range in - (None, (Reg.add id new_val reg, mem,tagmem)) - | Write_reg1((Reg_f_slice(id,_,_,_,range,mini_range) as reg_n),value,_) -> - let old_val = Reg.find id reg in - let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in - (None,(Reg.add id new_val reg,mem,tagmem)) - | Read_mem1(kind,location, length, _,_) -> - let address = match address_of_address_lifted location with - | Some a -> a - | None -> assert false (*TODO remember how to report an error *)in - let naddress = integer_of_address address in - let rec reading (n : num) length = - if length = 0 - then [] - else (Mem.find n mem)::(reading (add n big_num_unit) (length - 1)) in - (Some (register_value_of_memory_value (reading naddress length) !default_order), env) - | Read_mem_tagged0(kind,location, length, _,_) -> - let address = match address_of_address_lifted location with - | Some a -> a - | None -> assert false (*TODO remember how to report an error *)in - let naddress = integer_of_address address in - let tag = Mem.find (align_addr naddress cap_size) tagmem in - let rec reading (n : num) length = - if length = 0 - then [] - else (Mem.find n mem)::(reading (add n big_num_unit) (length - 1)) in - (Some (register_value_of_memory_value (tag::(reading naddress length)) !default_order), env) - | Write_mem0(kind,location, length, _, bytes,_,_) -> - let address = match address_of_address_lifted location with - | Some a -> a - | None -> assert false (*TODO remember how to report an error *)in - let naddress = integer_of_address address in - let rec writing location length bytes mem = - if length = 0 - then mem - else match bytes with - | [] -> mem - | b::bytes -> - writing (add location big_num_unit) (length - 1) bytes (Mem.add location b mem) in - (None,(reg,writing naddress length bytes mem,tagmem)) - | Write_memv1(Some location, bytes, _, _) -> - let address = match address_of_address_lifted location with - | Some a -> a - | _ -> failwith "Write address not known" in - let naddress = integer_of_address address in - let length = List.length bytes in - let rec writing location length bytes mem = - if length = 0 - then mem - else match bytes with - | [] -> mem - | b::bytes -> - writing (add location big_num_unit) (length - 1) bytes (Mem.add location b mem) in - (None, (reg,writing naddress length bytes mem, tagmem)) - | Write_memv_tagged0(Some location, (tag, bytes), _, _) -> - let address = match address_of_address_lifted location with - | Some a -> a - | _ -> failwith "Write address not known" in - let naddress = integer_of_address address in - let length = List.length bytes in - let rec writing location length bytes mem = - if length = 0 - then mem - else match bytes with - | [] -> mem - | b::bytes -> - writing (add location big_num_unit) (length - 1) bytes (Mem.add location b mem) in - let tagmem = Mem.add (align_addr naddress cap_size) (Byte_lifted ([Bitl_zero;Bitl_zero;Bitl_zero;Bitl_zero;Bitl_zero;Bitl_zero;Bitl_zero;tag])) tagmem in - (None, (reg,writing naddress length bytes mem, tagmem)) - | _ -> (None, env) -;; +let rec perform_action (((reg, mem, tagmem) as env), cap_size) = function + (* registers *) + | Read_reg1 (Reg (id, _, _, _), _) -> (Some (Reg.find id reg), env) + | Read_reg1 (Reg_slice (id, _, _, range), _) | Read_reg1 (Reg_field (id, _, _, _, range), _) -> + (Some (slice (Reg.find id reg) range), env) + | Read_reg1 (Reg_f_slice (id, _, _, _, range, mini_range), _) -> + (Some (slice (slice (Reg.find id reg) range) mini_range), env) + | Write_reg1 (Reg (id, _, _, _), value, _) -> (None, (Reg.add id value reg, mem, tagmem)) + | Write_reg1 ((Reg_slice (id, _, _, range) as reg_n), value, _) + | Write_reg1 ((Reg_field (id, _, _, _, range) as reg_n), value, _) -> + let old_val = Reg.find id reg in + let new_val = fupdate_slice reg_n old_val value range in + (None, (Reg.add id new_val reg, mem, tagmem)) + | Write_reg1 ((Reg_f_slice (id, _, _, _, range, mini_range) as reg_n), value, _) -> + let old_val = Reg.find id reg in + let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in + (None, (Reg.add id new_val reg, mem, tagmem)) + | Read_mem1 (kind, location, length, _, _) -> + let address = + match address_of_address_lifted location with + | Some a -> a + | None -> assert false (*TODO remember how to report an error *) + in + let naddress = integer_of_address address in + let rec reading (n : num) length = + if length = 0 then [] else Mem.find n mem :: reading (add n big_num_unit) (length - 1) + in + (Some (register_value_of_memory_value (reading naddress length) !default_order), env) + | Read_mem_tagged0 (kind, location, length, _, _) -> + let address = + match address_of_address_lifted location with + | Some a -> a + | None -> assert false (*TODO remember how to report an error *) + in + let naddress = integer_of_address address in + let tag = Mem.find (align_addr naddress cap_size) tagmem in + let rec reading (n : num) length = + if length = 0 then [] else Mem.find n mem :: reading (add n big_num_unit) (length - 1) + in + (Some (register_value_of_memory_value (tag :: reading naddress length) !default_order), env) + | Write_mem0 (kind, location, length, _, bytes, _, _) -> + let address = + match address_of_address_lifted location with + | Some a -> a + | None -> assert false (*TODO remember how to report an error *) + in + let naddress = integer_of_address address in + let rec writing location length bytes mem = + if length = 0 then mem + else ( + match bytes with + | [] -> mem + | b :: bytes -> writing (add location big_num_unit) (length - 1) bytes (Mem.add location b mem) + ) + in + (None, (reg, writing naddress length bytes mem, tagmem)) + | Write_memv1 (Some location, bytes, _, _) -> + let address = + match address_of_address_lifted location with Some a -> a | _ -> failwith "Write address not known" + in + let naddress = integer_of_address address in + let length = List.length bytes in + let rec writing location length bytes mem = + if length = 0 then mem + else ( + match bytes with + | [] -> mem + | b :: bytes -> writing (add location big_num_unit) (length - 1) bytes (Mem.add location b mem) + ) + in + (None, (reg, writing naddress length bytes mem, tagmem)) + | Write_memv_tagged0 (Some location, (tag, bytes), _, _) -> + let address = + match address_of_address_lifted location with Some a -> a | _ -> failwith "Write address not known" + in + let naddress = integer_of_address address in + let length = List.length bytes in + let rec writing location length bytes mem = + if length = 0 then mem + else ( + match bytes with + | [] -> mem + | b :: bytes -> writing (add location big_num_unit) (length - 1) bytes (Mem.add location b mem) + ) + in + let tagmem = + Mem.add (align_addr naddress cap_size) + (Byte_lifted [Bitl_zero; Bitl_zero; Bitl_zero; Bitl_zero; Bitl_zero; Bitl_zero; Bitl_zero; tag]) + tagmem + in + (None, (reg, writing naddress length bytes mem, tagmem)) + | _ -> (None, env) let interact_print = ref true let result_print = ref true let error_print = ref true -let interactf : ('a, out_channel, unit) format -> 'a = - function f -> if !interact_print then eprintf f else ifprintf stderr f -let errorf : ('a, out_channel, unit) format -> 'a = - function f -> if !error_print then eprintf f else ifprintf stderr f -let resultf : ('a, out_channel, unit) format -> 'a = - function f -> if !result_print then eprintf f else ifprintf stderr f +let interactf : ('a, out_channel, unit) format -> 'a = function + | f -> if !interact_print then eprintf f else ifprintf stderr f +let errorf : ('a, out_channel, unit) format -> 'a = function + | f -> if !error_print then eprintf f else ifprintf stderr f +let resultf : ('a, out_channel, unit) format -> 'a = function + | f -> if !result_print then eprintf f else ifprintf stderr f type interactive_mode = Step | Run | Next -let mode_to_string = function - | Step -> "step" - | Run -> "run" - | Next -> "next" +let mode_to_string = function Step -> "step" | Run -> "run" | Next -> "next" -let run - (istate : instruction_state) - reg - mem - tagmem - cap_size - eager_eval - track_dependencies - mode - name = +let run (istate : instruction_state) reg mem tagmem cap_size eager_eval track_dependencies mode name = (* interactive loop for step-by-step execution *) - let usage = "Usage: - step go to next action [default] - next go to next break point - run complete current execution - track begin/end tracking register dependencies - bt print call stack - cont print continuation of the top stack frame - reg print content of environment - mem print content of memory - exh run interpreter exhaustively with unknown and print events - quit exit interpreter" in + let usage = + "Usage:\n\ + \ step go to next action [default]\n\ + \ next go to next break point\n\ + \ run complete current execution\n\ + \ track begin/end tracking register dependencies\n\ + \ bt print call stack\n\ + \ cont print continuation of the top stack frame\n\ + \ reg print content of environment\n\ + \ mem print content of memory\n\ + \ exh run interpreter exhaustively with unknown and print events \n\ + \ quit exit interpreter" + in let rec interact mode ((reg, mem, tagmem) as env) stack = - flush_all(); + flush_all (); let command = Pervasives.read_line () in let command' = if command = "" then mode_to_string mode else command in - begin match command' with + begin + match command' with | "s" | "step" -> Step | "n" | "next" -> Next | "r" | "run" -> Run | "rg" | "reg" | "registers" -> - Reg.iter (fun k v -> interactf "%s\n" (Reg.to_string k v)) reg; - interact mode env stack + Reg.iter (fun k v -> interactf "%s\n" (Reg.to_string k v)) reg; + interact mode env stack | "m" | "mem" | "memory" -> - Mem.iter (fun k v -> interactf "%s\n" (Mem.to_string k v)) mem; - interact mode env stack + Mem.iter (fun k v -> interactf "%s\n" (Mem.to_string k v)) mem; + interact mode env stack | "bt" | "backtrace" | "stack" -> - print_backtrace_compact (fun s -> interactf "%s" s) stack; - interact mode env stack + print_backtrace_compact (fun s -> interactf "%s" s) stack; + interact mode env stack | "e" | "exh" | "exhaust" -> - interactf "interpreting exhaustively from current state\n"; - let events = interp_exhaustive false None stack in - interactf "%s" (format_events events); - interact mode env stack + interactf "interpreting exhaustively from current state\n"; + let events = interp_exhaustive false None stack in + interactf "%s" (format_events events); + interact mode env stack | "c" | "cont" | "continuation" -> - (* print not-compacted continuation *) - print_continuation (fun s -> interactf "%s" s) stack; - interact mode env stack + (* print not-compacted continuation *) + print_continuation (fun s -> interactf "%s" s) stack; + interact mode env stack | "track" | "t" -> - track_dependencies := not(!track_dependencies); - interact mode env stack + track_dependencies := not !track_dependencies; + interact mode env stack | "show_casts" -> - Pretty_interp.ignore_casts := false; - interact mode env stack + Pretty_interp.ignore_casts := false; + interact mode env stack | "hide_casts" -> - Pretty_interp.ignore_casts := true; - interact mode env stack + Pretty_interp.ignore_casts := true; + interact mode env stack | "q" | "quit" | "exit" -> exit 0 - | _ -> interactf "%s\n" usage; interact mode env stack + | _ -> + interactf "%s\n" usage; + interact mode env stack end in - let show act lhs arrow rhs = interactf "%s: %s %s %s\n" - (green act) lhs (blue arrow) rhs in + let show act lhs arrow rhs = interactf "%s: %s %s %s\n" (green act) lhs (blue arrow) rhs in let left = "<-" and right = "->" in let rec loop mode env = function | Done0 -> - interactf "%s: %s\n" (grey name) (blue "done"); - (true, mode, !track_dependencies, env) - | Error1 s -> - errorf "%s: %s: %s\n" (grey name) (red "error") s; - (false, mode, !track_dependencies, env) + interactf "%s: %s\n" (grey name) (blue "done"); + (true, mode, !track_dependencies, env) + | Error1 s -> + errorf "%s: %s: %s\n" (grey name) (red "error") s; + (false, mode, !track_dependencies, env) | Escape0 (None, _) -> - show "exiting current instruction" "" "" ""; - interactf "%s: %s\n" (grey name) (blue "done"); - (true, mode, !track_dependencies, env) + show "exiting current instruction" "" "" ""; + interactf "%s: %s\n" (grey name) (blue "done"); + (true, mode, !track_dependencies, env) | Fail1 (Some s) -> - errorf "%s: %s: %s\n" (grey name) (red "assertion failed") s; - (false, mode, !track_dependencies, env) + errorf "%s: %s: %s\n" (grey name) (red "assertion failed") s; + (false, mode, !track_dependencies, env) | Fail1 None -> - errorf "%s: %s: %s\n" (grey name) (red "assertion failed") "No message provided"; - (false, mode, !track_dependencies, env) + errorf "%s: %s: %s\n" (grey name) (red "assertion failed") "No message provided"; + (false, mode, !track_dependencies, env) | action -> - let (return,env') = perform_action (env, cap_size) action in - let step ?(force=false) (state: instruction_state) = - let stack = match state with IState(stack,_) -> stack in - let (top_exp,(top_env,top_mem)) = top_frame_exp_state stack in - let loc = get_loc (compact_exp top_exp) in - if mode = Step || force then begin - interactf "%s\n" (Pretty_interp.pp_exp top_env top_mem Printing_functions.red true top_exp); - interact mode env' state - end else - mode in - let (mode', env', next) = - (match action with - | Read_reg1(reg,next_thunk) -> - (match return with - | Some(value) -> - show "read_reg" (reg_name_to_string reg) right (register_value_to_string value); - let next = next_thunk value in - (step next, env', next) - | None -> assert false) - | Write_reg1(reg,value,next) -> - show "write_reg" (reg_name_to_string reg) left (register_value_to_string value); - (step next, env', next) - | Read_mem1(kind, (Address_lifted(location,_)), length, tracking, next_thunk) -> - (match return with - | Some(value) -> - show "read_mem" - (memory_value_to_string !default_endian location) right - (register_value_to_string value); - (match tracking with - | None -> () - | Some(deps) -> - show "read_mem address depended on" (dependencies_to_string deps) "" ""); - let next = next_thunk (memory_value_of_register_value value) in - (step next, env', next) - | None -> assert false) - | Read_mem_tagged0(kind, (Address_lifted(location,_)), length, tracking, next_thunk) -> - (match return with - | Some(value) -> - show "read_mem_tagged" - (memory_value_to_string !default_endian location) right - (register_value_to_string value); - (match tracking with - | None -> () - | Some(deps) -> - show "read_mem address depended on" (dependencies_to_string deps) "" ""); - let next = - (match (memory_value_of_register_value value) with - | (Byte_lifted tag)::bytes -> next_thunk ((List.nth tag 7), bytes) - | _ -> assert false) in - (step next, env', next) - | None -> assert false) - | Write_mem0(kind,(Address_lifted(location,_)), length, tracking, value, v_tracking, next_thunk) -> - show "write_mem" (memory_value_to_string !default_endian location) left - (memory_value_to_string !default_endian value); - (match (tracking,v_tracking) with - | (None,None) -> (); - | (Some(deps),None) -> - show "write_mem address depended on" (dependencies_to_string deps) "" ""; - | (None,Some(deps)) -> - show "write_mem value depended on" (dependencies_to_string deps) "" ""; - | (Some(deps),Some(vdeps)) -> - show "write_mem address depended on" (dependencies_to_string deps) "" ""; - show "write_mem value depended on" (dependencies_to_string vdeps) "" "";); - let next = next_thunk true in - (step next,env',next) - | Write_memv1(Some(Address_lifted(location,_)),value,_,next_thunk) -> - show "write_mem value" (memory_value_to_string !default_endian location) left (memory_value_to_string !default_endian value); - let next = next_thunk true in - (step next,env',next) - | Write_memv_tagged0(Some(Address_lifted(location,_)),(tag, value),_,next_thunk) -> - show "write_mem_tagged value" (memory_value_to_string !default_endian location) left (memory_value_to_string !default_endian value); - let next = next_thunk true in - (step next,env',next) - | Write_ea1(_,(Address_lifted(location,_)), size,_,next) -> - show "write_announce" (memory_value_to_string !default_endian location) left ((string_of_int size) ^ " bytes"); - (step next, env, next) - | Excl_res1(next_thunk) -> - show "exclusive_result" "" "" ""; - let next = next_thunk true in - (step next,env',next) - | Barrier1(bkind,next) -> - show "mem_barrier" "" "" ""; - (step next, env, next) - | Internal0(None,None, next) -> - show "stepped" "" "" ""; - (step ~force:true next,env',next) - | Internal0((Some fn),None,next) -> - show "evaluated" fn "" ""; - (step ~force:true next, env',next) - | Internal0(None,Some vdisp,next) -> - show "evaluated" (vdisp ()) "" ""; - (step ~force:true next,env', next) - | Internal0((Some fn),(Some vdisp),next) -> - show "evaluated" (fn ^ " " ^ (vdisp ())) "" ""; - (step ~force:true next, env', next) - | Nondet_choice(nondets, next) -> - let choose_order = List.sort (fun (_,i1) (_,i2) -> Pervasives.compare i1 i2) - (List.combine nondets (List.map (fun _ -> Random.bits ()) nondets)) in - show "nondeterministic evaluation begun" "" "" ""; - let (_,_,_,env') = List.fold_right (fun (next,_) (_,_,_,env') -> - loop mode env' (interp0 (make_mode (mode=Run) !track_dependencies true) next)) - choose_order (false,mode,!track_dependencies,env'); in - show "nondeterministic evaluation ended" "" "" ""; - (step next,env',next) - | Analysis_non_det (possible_istates, i_state) -> - let choose_order = List.sort (fun (_,i1) (_,i2) -> Pervasives.compare i1 i2) - (List.combine possible_istates (List.map (fun _ -> Random.bits ()) possible_istates)) in - if possible_istates = [] - then (step i_state,env',i_state) - else begin - show "undefined triggered a non_det" "" "" ""; - let (_,_,_,env') = List.fold_right (fun (next,_) (_,_,_,env') -> - loop mode env' (interp0 (make_mode (mode=Run) !track_dependencies true) next)) - choose_order (false,mode,!track_dependencies,env'); in - (step i_state,env',i_state) end - | Escape0(Some e,_) -> - show "exiting current evaluation" "" "" ""; - step e,env', e - | Escape0 _ -> assert false - | Error1 _ -> failwith "Internal error" - | Fail1 _ -> failwith "Assertion in program failed" - | Done0 -> - show "done evalution" "" "" ""; - assert false - | Footprint1 _ -> assert false - | Write_ea1 _ -> assert false - | Write_memv1 _ -> assert false) - (*| _ -> assert false*) - in - loop mode' env' (Interp_inter_imp.interp0 (make_mode (mode' = Run) !track_dependencies true) next) in - let mode = match mode with - | None -> if eager_eval then Run else Step - | Some m -> m in + let return, env' = perform_action (env, cap_size) action in + let step ?(force = false) (state : instruction_state) = + let stack = match state with IState (stack, _) -> stack in + let top_exp, (top_env, top_mem) = top_frame_exp_state stack in + let loc = get_loc (compact_exp top_exp) in + if mode = Step || force then begin + interactf "%s\n" (Pretty_interp.pp_exp top_env top_mem Printing_functions.red true top_exp); + interact mode env' state + end + else mode + in + let mode', env', next = + match action with + | Read_reg1 (reg, next_thunk) -> ( + match return with + | Some value -> + show "read_reg" (reg_name_to_string reg) right (register_value_to_string value); + let next = next_thunk value in + (step next, env', next) + | None -> assert false + ) + | Write_reg1 (reg, value, next) -> + show "write_reg" (reg_name_to_string reg) left (register_value_to_string value); + (step next, env', next) + | Read_mem1 (kind, Address_lifted (location, _), length, tracking, next_thunk) -> ( + match return with + | Some value -> + show "read_mem" + (memory_value_to_string !default_endian location) + right (register_value_to_string value); + ( match tracking with + | None -> () + | Some deps -> show "read_mem address depended on" (dependencies_to_string deps) "" "" + ); + let next = next_thunk (memory_value_of_register_value value) in + (step next, env', next) + | None -> assert false + ) + | Read_mem_tagged0 (kind, Address_lifted (location, _), length, tracking, next_thunk) -> ( + match return with + | Some value -> + show "read_mem_tagged" + (memory_value_to_string !default_endian location) + right (register_value_to_string value); + ( match tracking with + | None -> () + | Some deps -> show "read_mem address depended on" (dependencies_to_string deps) "" "" + ); + let next = + match memory_value_of_register_value value with + | Byte_lifted tag :: bytes -> next_thunk (List.nth tag 7, bytes) + | _ -> assert false + in + (step next, env', next) + | None -> assert false + ) + | Write_mem0 (kind, Address_lifted (location, _), length, tracking, value, v_tracking, next_thunk) -> + show "write_mem" + (memory_value_to_string !default_endian location) + left + (memory_value_to_string !default_endian value); + ( match (tracking, v_tracking) with + | None, None -> () + | Some deps, None -> show "write_mem address depended on" (dependencies_to_string deps) "" "" + | None, Some deps -> show "write_mem value depended on" (dependencies_to_string deps) "" "" + | Some deps, Some vdeps -> + show "write_mem address depended on" (dependencies_to_string deps) "" ""; + show "write_mem value depended on" (dependencies_to_string vdeps) "" "" + ); + let next = next_thunk true in + (step next, env', next) + | Write_memv1 (Some (Address_lifted (location, _)), value, _, next_thunk) -> + show "write_mem value" + (memory_value_to_string !default_endian location) + left + (memory_value_to_string !default_endian value); + let next = next_thunk true in + (step next, env', next) + | Write_memv_tagged0 (Some (Address_lifted (location, _)), (tag, value), _, next_thunk) -> + show "write_mem_tagged value" + (memory_value_to_string !default_endian location) + left + (memory_value_to_string !default_endian value); + let next = next_thunk true in + (step next, env', next) + | Write_ea1 (_, Address_lifted (location, _), size, _, next) -> + show "write_announce" + (memory_value_to_string !default_endian location) + left + (string_of_int size ^ " bytes"); + (step next, env, next) + | Excl_res1 next_thunk -> + show "exclusive_result" "" "" ""; + let next = next_thunk true in + (step next, env', next) + | Barrier1 (bkind, next) -> + show "mem_barrier" "" "" ""; + (step next, env, next) + | Internal0 (None, None, next) -> + show "stepped" "" "" ""; + (step ~force:true next, env', next) + | Internal0 (Some fn, None, next) -> + show "evaluated" fn "" ""; + (step ~force:true next, env', next) + | Internal0 (None, Some vdisp, next) -> + show "evaluated" (vdisp ()) "" ""; + (step ~force:true next, env', next) + | Internal0 (Some fn, Some vdisp, next) -> + show "evaluated" (fn ^ " " ^ vdisp ()) "" ""; + (step ~force:true next, env', next) + | Nondet_choice (nondets, next) -> + let choose_order = + List.sort + (fun (_, i1) (_, i2) -> Pervasives.compare i1 i2) + (List.combine nondets (List.map (fun _ -> Random.bits ()) nondets)) + in + show "nondeterministic evaluation begun" "" "" ""; + let _, _, _, env' = + List.fold_right + (fun (next, _) (_, _, _, env') -> + loop mode env' (interp0 (make_mode (mode = Run) !track_dependencies true) next) + ) + choose_order + (false, mode, !track_dependencies, env') + in + show "nondeterministic evaluation ended" "" "" ""; + (step next, env', next) + | Analysis_non_det (possible_istates, i_state) -> + let choose_order = + List.sort + (fun (_, i1) (_, i2) -> Pervasives.compare i1 i2) + (List.combine possible_istates (List.map (fun _ -> Random.bits ()) possible_istates)) + in + if possible_istates = [] then (step i_state, env', i_state) + else begin + show "undefined triggered a non_det" "" "" ""; + let _, _, _, env' = + List.fold_right + (fun (next, _) (_, _, _, env') -> + loop mode env' (interp0 (make_mode (mode = Run) !track_dependencies true) next) + ) + choose_order + (false, mode, !track_dependencies, env') + in + (step i_state, env', i_state) + end + | Escape0 (Some e, _) -> + show "exiting current evaluation" "" "" ""; + (step e, env', e) + | Escape0 _ -> assert false + | Error1 _ -> failwith "Internal error" + | Fail1 _ -> failwith "Assertion in program failed" + | Done0 -> + show "done evalution" "" "" ""; + assert false + | Footprint1 _ -> assert false + | Write_ea1 _ -> assert false + | Write_memv1 _ -> assert false + (*| _ -> assert false*) + in + loop mode' env' (Interp_inter_imp.interp0 (make_mode (mode' = Run) !track_dependencies true) next) + in + let mode = match mode with None -> if eager_eval then Run else Step | Some m -> m in let imode = make_mode eager_eval !track_dependencies true in - let (IState(instr_state,context)) = istate in - let (top_exp,(top_env,top_mem)) = top_frame_exp_state instr_state in - interactf "%s: %s %s\n" (grey name) (blue "evaluate") + let (IState (instr_state, context)) = istate in + let top_exp, (top_env, top_mem) = top_frame_exp_state instr_state in + interactf "%s: %s %s\n" (grey name) (blue "evaluate") (Pretty_interp.pp_exp top_env top_mem Printing_functions.red true top_exp); try Printexc.record_backtrace true; - loop mode (reg, mem,tagmem) (Interp_inter_imp.interp0 imode istate) + loop mode (reg, mem, tagmem) (Interp_inter_imp.interp0 imode istate) with e -> let trace = Printexc.get_backtrace () in interactf "%s: %s %s\n%s\n" (grey name) (red "interpretor error") (Printexc.to_string e) trace; (false, mode, !track_dependencies, (reg, mem, tagmem)) -;; diff --git a/src/lem_interp/run_with_elf.ml b/src/lem_interp/run_with_elf.ml index 45db41738..028ace029 100644 --- a/src/lem_interp/run_with_elf.ml +++ b/src/lem_interp/run_with_elf.ml @@ -65,38 +65,35 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Printf ;; -open Format ;; -open Big_int ;; -open Interp_ast ;; -open Interp_interface ;; -open Interp_inter_imp ;; -open Sail_impl_base ;; -open Run_interp_model ;; +open Printf +open Format +open Big_int +open Interp_ast +open Interp_interface +open Interp_inter_imp +open Sail_impl_base +open Run_interp_model -open Sail_interface ;; +open Sail_interface -module StringMap = Map.Make(String) +module StringMap = Map.Make (String) -let file = ref "" ;; +let file = ref "" -let rec foldli f acc ?(i=0) = function - | [] -> acc - | x::xs -> foldli f (f i acc x) ~i:(i+1) xs -;; +let rec foldli f acc ?(i = 0) = function [] -> acc | x :: xs -> foldli f (f i acc x) ~i:(i + 1) xs -let endian = ref E_big_endian ;; +let endian = ref E_big_endian -let hex_to_big_int s = big_int_of_int64 (Int64.of_string s) ;; +let hex_to_big_int s = big_int_of_int64 (Int64.of_string s) -let data_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref) ;; -let prog_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref) ;; -let tag_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref);; -let reg = ref Reg.empty ;; -let input_buf = (ref [] : int list ref);; +let data_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let prog_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let tag_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let reg = ref Reg.empty +let input_buf = (ref [] : int list ref) let add_mem byte addr mem = - assert(byte >= 0 && byte < 256); + assert (byte >= 0 && byte < 256); (*Printf.printf "add_mem %s: 0x%02x\n" (Uint64.to_string_hex (Uint64.of_string (Nat_big_num.to_string addr))) byte;*) let mem_byte = memory_byte_of_int byte in let zero_byte = memory_byte_of_int 0 in @@ -104,50 +101,49 @@ let add_mem byte addr mem = tag_mem := Mem.add addr zero_byte !tag_mem let get_reg reg name = - let reg_content = Reg.find name reg in reg_content + let reg_content = Reg.find name reg in + reg_content -let rec load_memory_segment' (bytes,addr) mem = +let rec load_memory_segment' (bytes, addr) mem = match bytes with | [] -> () - | byte::bytes' -> - let data_byte = Char.code byte in - let addr' = Nat_big_num.succ addr in - begin add_mem data_byte addr mem; - load_memory_segment' (bytes',addr') mem - end + | byte :: bytes' -> + let data_byte = Char.code byte in + let addr' = Nat_big_num.succ addr in + begin + add_mem data_byte addr mem; + load_memory_segment' (bytes', addr') mem + end -let rec load_memory_segment (segment: Elf_interpreted_segment.elf64_interpreted_segment) mem = +let rec load_memory_segment (segment : Elf_interpreted_segment.elf64_interpreted_segment) mem = let (Byte_sequence.Sequence bytes) = segment.Elf_interpreted_segment.elf64_segment_body in let addr = segment.Elf_interpreted_segment.elf64_segment_paddr in - load_memory_segment' (bytes,addr) mem - + load_memory_segment' (bytes, addr) mem let rec load_memory_segments segments = - begin match segments with + begin + match segments with | [] -> () - | segment::segments' -> - let (x,w,r) = segment.Elf_interpreted_segment.elf64_segment_flags in - begin - load_memory_segment segment prog_mem; - load_memory_segments segments' - end + | segment :: segments' -> + let x, w, r = segment.Elf_interpreted_segment.elf64_segment_flags in + begin + load_memory_segment segment prog_mem; + load_memory_segments segments' + end end - -let rec read_mem mem address length = - if length = 0 - then [] - else - let byte = - try Mem.find address mem with - | Not_found -> failwith "start address not found" - in - byte :: (read_mem mem (Nat_big_num.succ address) (length - 1)) + +let rec read_mem mem address length = + if length = 0 then [] + else ( + let byte = try Mem.find address mem with Not_found -> failwith "start address not found" in + byte :: read_mem mem (Nat_big_num.succ address) (length - 1) + ) let register_state_zero register_data rbn : register_value = - let (dir,width,start_index) = - try List.assoc rbn register_data with - | Not_found -> failwith ("register_state_zero lookup failed (" ^ rbn) - in register_value_zeros dir width start_index + let dir, width, start_index = + try List.assoc rbn register_data with Not_found -> failwith ("register_state_zero lookup failed (" ^ rbn) + in + register_value_zeros dir width start_index type model = PPC | AArch64 | MIPS (* @@ -514,399 +510,416 @@ let initial_stack_and_reg_data_of_AAarch64_elf_file e_entry all_data_memory = (initial_stack_data, initial_register_abi_data) *) -let mips_register_data_all = [ - (*Pseudo registers*) - ("PC", (D_decreasing, 64, 63)); - ("branchPending", (D_decreasing, 1, 0)); - ("inBranchDelay", (D_decreasing, 1, 0)); - ("delayedPC", (D_decreasing, 64, 63)); - ("nextPC", (D_decreasing, 64, 63)); - (* General purpose registers *) - ("GPR00", (D_decreasing, 64, 63)); - ("GPR01", (D_decreasing, 64, 63)); - ("GPR02", (D_decreasing, 64, 63)); - ("GPR03", (D_decreasing, 64, 63)); - ("GPR04", (D_decreasing, 64, 63)); - ("GPR05", (D_decreasing, 64, 63)); - ("GPR06", (D_decreasing, 64, 63)); - ("GPR07", (D_decreasing, 64, 63)); - ("GPR08", (D_decreasing, 64, 63)); - ("GPR09", (D_decreasing, 64, 63)); - ("GPR10", (D_decreasing, 64, 63)); - ("GPR11", (D_decreasing, 64, 63)); - ("GPR12", (D_decreasing, 64, 63)); - ("GPR13", (D_decreasing, 64, 63)); - ("GPR14", (D_decreasing, 64, 63)); - ("GPR15", (D_decreasing, 64, 63)); - ("GPR16", (D_decreasing, 64, 63)); - ("GPR17", (D_decreasing, 64, 63)); - ("GPR18", (D_decreasing, 64, 63)); - ("GPR19", (D_decreasing, 64, 63)); - ("GPR20", (D_decreasing, 64, 63)); - ("GPR21", (D_decreasing, 64, 63)); - ("GPR22", (D_decreasing, 64, 63)); - ("GPR23", (D_decreasing, 64, 63)); - ("GPR24", (D_decreasing, 64, 63)); - ("GPR25", (D_decreasing, 64, 63)); - ("GPR26", (D_decreasing, 64, 63)); - ("GPR27", (D_decreasing, 64, 63)); - ("GPR28", (D_decreasing, 64, 63)); - ("GPR29", (D_decreasing, 64, 63)); - ("GPR30", (D_decreasing, 64, 63)); - ("GPR31", (D_decreasing, 64, 63)); - (* special registers for mul/div *) - ("HI", (D_decreasing, 64, 63)); - ("LO", (D_decreasing, 64, 63)); - (* control registers *) - ("CP0Status", (D_decreasing, 32, 31)); - ("CP0Cause", (D_decreasing, 32, 31)); - ("CP0EPC", (D_decreasing, 64, 63)); - ("CP0LLAddr", (D_decreasing, 64, 63)); - ("CP0LLBit", (D_decreasing, 1, 0)); - ("CP0Count", (D_decreasing, 32, 31)); - ("CP0Compare", (D_decreasing, 32, 31)); - ("CP0HWREna", (D_decreasing, 32, 31)); - ("CP0UserLocal", (D_decreasing, 64, 63)); - ("CP0BadVAddr", (D_decreasing, 64, 63)); - ("TLBProbe" ,(D_decreasing, 1, 0)); - ("TLBIndex" ,(D_decreasing, 6, 5)); - ("TLBRandom" ,(D_decreasing, 6, 5)); - ("TLBEntryLo0",(D_decreasing, 64, 63)); - ("TLBEntryLo1",(D_decreasing, 64, 63)); - ("TLBContext" ,(D_decreasing, 64, 63)); - ("TLBPageMask",(D_decreasing, 16, 15)); - ("TLBWired" ,(D_decreasing, 6, 5)); - ("TLBEntryHi" ,(D_decreasing, 64, 63)); - ("TLBXContext",(D_decreasing, 64, 63)); - - ("TLBEntry00" ,(D_decreasing, 117, 116)); - ("TLBEntry01" ,(D_decreasing, 117, 116)); - ("TLBEntry02" ,(D_decreasing, 117, 116)); - ("TLBEntry03" ,(D_decreasing, 117, 116)); - ("TLBEntry04" ,(D_decreasing, 117, 116)); - ("TLBEntry05" ,(D_decreasing, 117, 116)); - ("TLBEntry06" ,(D_decreasing, 117, 116)); - ("TLBEntry07" ,(D_decreasing, 117, 116)); - ("TLBEntry08" ,(D_decreasing, 117, 116)); - ("TLBEntry09" ,(D_decreasing, 117, 116)); - ("TLBEntry10" ,(D_decreasing, 117, 116)); - ("TLBEntry11" ,(D_decreasing, 117, 116)); - ("TLBEntry12" ,(D_decreasing, 117, 116)); - ("TLBEntry13" ,(D_decreasing, 117, 116)); - ("TLBEntry14" ,(D_decreasing, 117, 116)); - ("TLBEntry15" ,(D_decreasing, 117, 116)); - ("TLBEntry16" ,(D_decreasing, 117, 116)); - ("TLBEntry17" ,(D_decreasing, 117, 116)); - ("TLBEntry18" ,(D_decreasing, 117, 116)); - ("TLBEntry19" ,(D_decreasing, 117, 116)); - ("TLBEntry20" ,(D_decreasing, 117, 116)); - ("TLBEntry21" ,(D_decreasing, 117, 116)); - ("TLBEntry22" ,(D_decreasing, 117, 116)); - ("TLBEntry23" ,(D_decreasing, 117, 116)); - ("TLBEntry24" ,(D_decreasing, 117, 116)); - ("TLBEntry25" ,(D_decreasing, 117, 116)); - ("TLBEntry26" ,(D_decreasing, 117, 116)); - ("TLBEntry27" ,(D_decreasing, 117, 116)); - ("TLBEntry28" ,(D_decreasing, 117, 116)); - ("TLBEntry29" ,(D_decreasing, 117, 116)); - ("TLBEntry30" ,(D_decreasing, 117, 116)); - ("TLBEntry31" ,(D_decreasing, 117, 116)); - ("TLBEntry32" ,(D_decreasing, 117, 116)); - ("TLBEntry33" ,(D_decreasing, 117, 116)); - ("TLBEntry34" ,(D_decreasing, 117, 116)); - ("TLBEntry35" ,(D_decreasing, 117, 116)); - ("TLBEntry36" ,(D_decreasing, 117, 116)); - ("TLBEntry37" ,(D_decreasing, 117, 116)); - ("TLBEntry38" ,(D_decreasing, 117, 116)); - ("TLBEntry39" ,(D_decreasing, 117, 116)); - ("TLBEntry40" ,(D_decreasing, 117, 116)); - ("TLBEntry41" ,(D_decreasing, 117, 116)); - ("TLBEntry42" ,(D_decreasing, 117, 116)); - ("TLBEntry43" ,(D_decreasing, 117, 116)); - ("TLBEntry44" ,(D_decreasing, 117, 116)); - ("TLBEntry45" ,(D_decreasing, 117, 116)); - ("TLBEntry46" ,(D_decreasing, 117, 116)); - ("TLBEntry47" ,(D_decreasing, 117, 116)); - ("TLBEntry48" ,(D_decreasing, 117, 116)); - ("TLBEntry49" ,(D_decreasing, 117, 116)); - ("TLBEntry50" ,(D_decreasing, 117, 116)); - ("TLBEntry51" ,(D_decreasing, 117, 116)); - ("TLBEntry52" ,(D_decreasing, 117, 116)); - ("TLBEntry53" ,(D_decreasing, 117, 116)); - ("TLBEntry54" ,(D_decreasing, 117, 116)); - ("TLBEntry55" ,(D_decreasing, 117, 116)); - ("TLBEntry56" ,(D_decreasing, 117, 116)); - ("TLBEntry57" ,(D_decreasing, 117, 116)); - ("TLBEntry58" ,(D_decreasing, 117, 116)); - ("TLBEntry59" ,(D_decreasing, 117, 116)); - ("TLBEntry60" ,(D_decreasing, 117, 116)); - ("TLBEntry61" ,(D_decreasing, 117, 116)); - ("TLBEntry62" ,(D_decreasing, 117, 116)); - ("TLBEntry63" ,(D_decreasing, 117, 116)); - - ("UART_WDATA" ,(D_decreasing, 8, 7)); - ("UART_RDATA" ,(D_decreasing, 8, 7)); - ("UART_WRITTEN" ,(D_decreasing, 1, 0)); - ("UART_RVALID" ,(D_decreasing, 1, 0)); -] +let mips_register_data_all = + [ + (*Pseudo registers*) + ("PC", (D_decreasing, 64, 63)); + ("branchPending", (D_decreasing, 1, 0)); + ("inBranchDelay", (D_decreasing, 1, 0)); + ("delayedPC", (D_decreasing, 64, 63)); + ("nextPC", (D_decreasing, 64, 63)); + (* General purpose registers *) + ("GPR00", (D_decreasing, 64, 63)); + ("GPR01", (D_decreasing, 64, 63)); + ("GPR02", (D_decreasing, 64, 63)); + ("GPR03", (D_decreasing, 64, 63)); + ("GPR04", (D_decreasing, 64, 63)); + ("GPR05", (D_decreasing, 64, 63)); + ("GPR06", (D_decreasing, 64, 63)); + ("GPR07", (D_decreasing, 64, 63)); + ("GPR08", (D_decreasing, 64, 63)); + ("GPR09", (D_decreasing, 64, 63)); + ("GPR10", (D_decreasing, 64, 63)); + ("GPR11", (D_decreasing, 64, 63)); + ("GPR12", (D_decreasing, 64, 63)); + ("GPR13", (D_decreasing, 64, 63)); + ("GPR14", (D_decreasing, 64, 63)); + ("GPR15", (D_decreasing, 64, 63)); + ("GPR16", (D_decreasing, 64, 63)); + ("GPR17", (D_decreasing, 64, 63)); + ("GPR18", (D_decreasing, 64, 63)); + ("GPR19", (D_decreasing, 64, 63)); + ("GPR20", (D_decreasing, 64, 63)); + ("GPR21", (D_decreasing, 64, 63)); + ("GPR22", (D_decreasing, 64, 63)); + ("GPR23", (D_decreasing, 64, 63)); + ("GPR24", (D_decreasing, 64, 63)); + ("GPR25", (D_decreasing, 64, 63)); + ("GPR26", (D_decreasing, 64, 63)); + ("GPR27", (D_decreasing, 64, 63)); + ("GPR28", (D_decreasing, 64, 63)); + ("GPR29", (D_decreasing, 64, 63)); + ("GPR30", (D_decreasing, 64, 63)); + ("GPR31", (D_decreasing, 64, 63)); + (* special registers for mul/div *) + ("HI", (D_decreasing, 64, 63)); + ("LO", (D_decreasing, 64, 63)); + (* control registers *) + ("CP0Status", (D_decreasing, 32, 31)); + ("CP0Cause", (D_decreasing, 32, 31)); + ("CP0EPC", (D_decreasing, 64, 63)); + ("CP0LLAddr", (D_decreasing, 64, 63)); + ("CP0LLBit", (D_decreasing, 1, 0)); + ("CP0Count", (D_decreasing, 32, 31)); + ("CP0Compare", (D_decreasing, 32, 31)); + ("CP0HWREna", (D_decreasing, 32, 31)); + ("CP0UserLocal", (D_decreasing, 64, 63)); + ("CP0BadVAddr", (D_decreasing, 64, 63)); + ("TLBProbe", (D_decreasing, 1, 0)); + ("TLBIndex", (D_decreasing, 6, 5)); + ("TLBRandom", (D_decreasing, 6, 5)); + ("TLBEntryLo0", (D_decreasing, 64, 63)); + ("TLBEntryLo1", (D_decreasing, 64, 63)); + ("TLBContext", (D_decreasing, 64, 63)); + ("TLBPageMask", (D_decreasing, 16, 15)); + ("TLBWired", (D_decreasing, 6, 5)); + ("TLBEntryHi", (D_decreasing, 64, 63)); + ("TLBXContext", (D_decreasing, 64, 63)); + ("TLBEntry00", (D_decreasing, 117, 116)); + ("TLBEntry01", (D_decreasing, 117, 116)); + ("TLBEntry02", (D_decreasing, 117, 116)); + ("TLBEntry03", (D_decreasing, 117, 116)); + ("TLBEntry04", (D_decreasing, 117, 116)); + ("TLBEntry05", (D_decreasing, 117, 116)); + ("TLBEntry06", (D_decreasing, 117, 116)); + ("TLBEntry07", (D_decreasing, 117, 116)); + ("TLBEntry08", (D_decreasing, 117, 116)); + ("TLBEntry09", (D_decreasing, 117, 116)); + ("TLBEntry10", (D_decreasing, 117, 116)); + ("TLBEntry11", (D_decreasing, 117, 116)); + ("TLBEntry12", (D_decreasing, 117, 116)); + ("TLBEntry13", (D_decreasing, 117, 116)); + ("TLBEntry14", (D_decreasing, 117, 116)); + ("TLBEntry15", (D_decreasing, 117, 116)); + ("TLBEntry16", (D_decreasing, 117, 116)); + ("TLBEntry17", (D_decreasing, 117, 116)); + ("TLBEntry18", (D_decreasing, 117, 116)); + ("TLBEntry19", (D_decreasing, 117, 116)); + ("TLBEntry20", (D_decreasing, 117, 116)); + ("TLBEntry21", (D_decreasing, 117, 116)); + ("TLBEntry22", (D_decreasing, 117, 116)); + ("TLBEntry23", (D_decreasing, 117, 116)); + ("TLBEntry24", (D_decreasing, 117, 116)); + ("TLBEntry25", (D_decreasing, 117, 116)); + ("TLBEntry26", (D_decreasing, 117, 116)); + ("TLBEntry27", (D_decreasing, 117, 116)); + ("TLBEntry28", (D_decreasing, 117, 116)); + ("TLBEntry29", (D_decreasing, 117, 116)); + ("TLBEntry30", (D_decreasing, 117, 116)); + ("TLBEntry31", (D_decreasing, 117, 116)); + ("TLBEntry32", (D_decreasing, 117, 116)); + ("TLBEntry33", (D_decreasing, 117, 116)); + ("TLBEntry34", (D_decreasing, 117, 116)); + ("TLBEntry35", (D_decreasing, 117, 116)); + ("TLBEntry36", (D_decreasing, 117, 116)); + ("TLBEntry37", (D_decreasing, 117, 116)); + ("TLBEntry38", (D_decreasing, 117, 116)); + ("TLBEntry39", (D_decreasing, 117, 116)); + ("TLBEntry40", (D_decreasing, 117, 116)); + ("TLBEntry41", (D_decreasing, 117, 116)); + ("TLBEntry42", (D_decreasing, 117, 116)); + ("TLBEntry43", (D_decreasing, 117, 116)); + ("TLBEntry44", (D_decreasing, 117, 116)); + ("TLBEntry45", (D_decreasing, 117, 116)); + ("TLBEntry46", (D_decreasing, 117, 116)); + ("TLBEntry47", (D_decreasing, 117, 116)); + ("TLBEntry48", (D_decreasing, 117, 116)); + ("TLBEntry49", (D_decreasing, 117, 116)); + ("TLBEntry50", (D_decreasing, 117, 116)); + ("TLBEntry51", (D_decreasing, 117, 116)); + ("TLBEntry52", (D_decreasing, 117, 116)); + ("TLBEntry53", (D_decreasing, 117, 116)); + ("TLBEntry54", (D_decreasing, 117, 116)); + ("TLBEntry55", (D_decreasing, 117, 116)); + ("TLBEntry56", (D_decreasing, 117, 116)); + ("TLBEntry57", (D_decreasing, 117, 116)); + ("TLBEntry58", (D_decreasing, 117, 116)); + ("TLBEntry59", (D_decreasing, 117, 116)); + ("TLBEntry60", (D_decreasing, 117, 116)); + ("TLBEntry61", (D_decreasing, 117, 116)); + ("TLBEntry62", (D_decreasing, 117, 116)); + ("TLBEntry63", (D_decreasing, 117, 116)); + ("UART_WDATA", (D_decreasing, 8, 7)); + ("UART_RDATA", (D_decreasing, 8, 7)); + ("UART_WRITTEN", (D_decreasing, 1, 0)); + ("UART_RVALID", (D_decreasing, 1, 0)); + ] let initial_stack_and_reg_data_of_MIPS_elf_file e_entry all_data_memory = - let initial_stack_data = [] in - let initial_register_abi_data : (string * Sail_impl_base.register_value) list = [ - ("CP0Status", Sail_impl_base.register_value_of_integer 32 31 D_decreasing (Nat_big_num.of_string "0x00400000")); - ] in + let initial_stack_data = [] in + let initial_register_abi_data : (string * Sail_impl_base.register_value) list = + [("CP0Status", Sail_impl_base.register_value_of_integer 32 31 D_decreasing (Nat_big_num.of_string "0x00400000"))] + in (initial_stack_data, initial_register_abi_data) let initial_reg_file reg_data init = List.iter (fun (reg_name, _) -> reg := Reg.add reg_name (init reg_name) !reg) reg_data -let initial_system_state_of_elf_file name = - +let initial_system_state_of_elf_file name = (* call ELF analyser on file *) match Sail_interface.populate_and_obtain_global_symbol_init_info name with | Error.Fail s -> failwith ("populate_and_obtain_global_symbol_init_info: " ^ s) - | Error.Success - (_, (elf_epi: Sail_interface.executable_process_image), - (symbol_map: Elf_file.global_symbol_init_info)) - -> - let (segments, e_entry, e_machine) = - begin match elf_epi with - | ELF_Class_32 _ -> failwith "cannot handle ELF_Class_32" - | ELF_Class_64 (segments,e_entry,e_machine) -> - (* remove all the auto generated segments (they contain only 0s) *) - let segments = - Lem_list.mapMaybe - (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) - segments - in - (segments,e_entry,e_machine) - end - in + | Error.Success + (_, (elf_epi : Sail_interface.executable_process_image), (symbol_map : Elf_file.global_symbol_init_info)) -> + let segments, e_entry, e_machine = + begin + match elf_epi with + | ELF_Class_32 _ -> failwith "cannot handle ELF_Class_32" + | ELF_Class_64 (segments, e_entry, e_machine) -> + (* remove all the auto generated segments (they contain only 0s) *) + let segments = + Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segments + in + (segments, e_entry, e_machine) + end + in - (* construct program memory and start address *) - begin - prog_mem := Mem.empty; - data_mem := Mem.empty; - tag_mem := Mem.empty; - load_memory_segments segments; - (* + (* construct program memory and start address *) + begin + prog_mem := Mem.empty; + data_mem := Mem.empty; + tag_mem := Mem.empty; + load_memory_segments segments; + (* debugf "prog_mem\n"; Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) !prog_mem; debugf "data_mem\n"; Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) !data_mem; *) - let (isa_defs, isa_memory_access, isa_externs, isa_model, model_reg_d, startaddr, - initial_stack_data, initial_register_abi_data, register_data_all) = - match Nat_big_num.to_int e_machine with -(* | 21 (* EM_PPC64 *) -> - let startaddr = - let e_entry = Uint64.of_int64 (Nat_big_num.to_int64 e_entry) in - match Abi_power64.abi_power64_compute_program_entry_point segments e_entry with - | Error.Fail s -> failwith "Failed computing entry point" - | Error.Success s -> Nat_big_num.of_int64 (Uint64.to_int64 s) - in - let (initial_stack_data, initial_register_abi_data) = - initial_stack_and_reg_data_of_PPC_elf_file e_entry !data_mem in - - (Power.defs, - (Power_extras.read_memory_functions,Power_extras.memory_writes,[],[],Power_extras.barrier_functions), - Power_extras.power_externs, - PPC, - D_increasing, - startaddr, - initial_stack_data, - initial_register_abi_data, - ppc_register_data_all) - - | 183 (* EM_AARCH64 *) -> - let startaddr = - let e_entry = Uint64.of_int64 (Nat_big_num.to_int64 e_entry) in - match Abi_aarch64_le.abi_aarch64_le_compute_program_entry_point segments e_entry with - | Error.Fail s -> failwith "Failed computing entry point" - | Error.Success s -> Nat_big_num.of_int64 (Uint64.to_int64 s) - in - - let (initial_stack_data, initial_register_abi_data) = - initial_stack_and_reg_data_of_AAarch64_elf_file e_entry !data_mem in - - (ArmV8.defs, - (ArmV8_extras.aArch64_read_memory_functions, - ArmV8_extras.aArch64_memory_writes, - ArmV8_extras.aArch64_memory_eas, - ArmV8_extras.aArch64_memory_vals, - ArmV8_extras.aArch64_barrier_functions), - [], - AArch64, - D_decreasing, - startaddr, - initial_stack_data, - initial_register_abi_data, - aarch64_register_data_all) *) - | 8 (* EM_MIPS *) -> - let startaddr = - let e_entry = Uint64_wrapper.of_bigint e_entry in - match Abi_mips64.abi_mips64_compute_program_entry_point segments e_entry with - | Error.Fail s -> failwith "Failed computing entry point" - | Error.Success s -> s + let ( isa_defs, + isa_memory_access, + isa_externs, + isa_model, + model_reg_d, + startaddr, + initial_stack_data, + initial_register_abi_data, + register_data_all ) = + match Nat_big_num.to_int e_machine with + (* | 21 (* EM_PPC64 *) -> + let startaddr = + let e_entry = Uint64.of_int64 (Nat_big_num.to_int64 e_entry) in + match Abi_power64.abi_power64_compute_program_entry_point segments e_entry with + | Error.Fail s -> failwith "Failed computing entry point" + | Error.Success s -> Nat_big_num.of_int64 (Uint64.to_int64 s) + in + let (initial_stack_data, initial_register_abi_data) = + initial_stack_and_reg_data_of_PPC_elf_file e_entry !data_mem in + + (Power.defs, + (Power_extras.read_memory_functions,Power_extras.memory_writes,[],[],Power_extras.barrier_functions), + Power_extras.power_externs, + PPC, + D_increasing, + startaddr, + initial_stack_data, + initial_register_abi_data, + ppc_register_data_all) + + | 183 (* EM_AARCH64 *) -> + let startaddr = + let e_entry = Uint64.of_int64 (Nat_big_num.to_int64 e_entry) in + match Abi_aarch64_le.abi_aarch64_le_compute_program_entry_point segments e_entry with + | Error.Fail s -> failwith "Failed computing entry point" + | Error.Success s -> Nat_big_num.of_int64 (Uint64.to_int64 s) + in + + let (initial_stack_data, initial_register_abi_data) = + initial_stack_and_reg_data_of_AAarch64_elf_file e_entry !data_mem in + + (ArmV8.defs, + (ArmV8_extras.aArch64_read_memory_functions, + ArmV8_extras.aArch64_memory_writes, + ArmV8_extras.aArch64_memory_eas, + ArmV8_extras.aArch64_memory_vals, + ArmV8_extras.aArch64_barrier_functions), + [], + AArch64, + D_decreasing, + startaddr, + initial_stack_data, + initial_register_abi_data, + aarch64_register_data_all) *) + | 8 (* EM_MIPS *) -> + let startaddr = + let e_entry = Uint64_wrapper.of_bigint e_entry in + match Abi_mips64.abi_mips64_compute_program_entry_point segments e_entry with + | Error.Fail s -> failwith "Failed computing entry point" + | Error.Success s -> s + in + let initial_stack_data, initial_register_abi_data = + initial_stack_and_reg_data_of_MIPS_elf_file e_entry !data_mem + in + + ( Mips.defs, + ( Mips_extras.mips_read_memory_functions, + Mips_extras.mips_memory_writes, + Mips_extras.mips_memory_eas, + Mips_extras.mips_memory_vals, + Mips_extras.mips_barrier_functions + ), + [], + MIPS, + D_decreasing, + startaddr, + initial_stack_data, + initial_register_abi_data, + mips_register_data_all + ) + | _ -> + failwith + (Printf.sprintf + "Sail sequential interpreter can't handle the e_machine value %s, only EM_PPC64, EM_AARCH64 and \ + EM_MIPS are supported." + (Nat_big_num.to_string e_machine) + ) + in + + (* pull the object symbols from the symbol table *) + let symbol_table : (string * Nat_big_num.num * int * word8 list (*their bytes*)) list = + let rec convert_symbol_table symbol_map = + begin + match symbol_map with + | [] -> [] + | ( (name : string), + ( (typ : Nat_big_num.num), + (size : Nat_big_num.num (*number of bytes*)), + (address : Nat_big_num.num), + (mb : Byte_sequence.byte_sequence option (*present iff type=stt_object*)), + (binding : Nat_big_num.num) + ) + ) (* (mb: Byte_sequence_wrapper.t option (*present iff type=stt_object*)) )) *) + :: symbol_map' -> + if + Nat_big_num.equal typ Elf_symbol_table.stt_object + && not (Nat_big_num.equal size (Nat_big_num.of_int 0)) + then ( + (* an object symbol - map *) + (*Printf.printf "*** size %d ***\n" (Nat_big_num.to_int size);*) + let bytes = + match mb with + | None -> raise (Failure "this cannot happen") + | Some (Sequence bytes) -> List.map (fun (c : char) -> Char.code c) bytes + in + (name, address, List.length bytes, bytes) :: convert_symbol_table symbol_map' + ) + else (* not an object symbol or of zero size - ignore *) + convert_symbol_table symbol_map' + end in - let (initial_stack_data, initial_register_abi_data) = - initial_stack_and_reg_data_of_MIPS_elf_file e_entry !data_mem in - - (Mips.defs, - (Mips_extras.mips_read_memory_functions, - Mips_extras.mips_memory_writes, - Mips_extras.mips_memory_eas, - Mips_extras.mips_memory_vals, - Mips_extras.mips_barrier_functions), - [], - MIPS, - D_decreasing, - startaddr, - initial_stack_data, - initial_register_abi_data, - mips_register_data_all) - - | _ -> failwith (Printf.sprintf "Sail sequential interpreter can't handle the e_machine value %s, only EM_PPC64, EM_AARCH64 and EM_MIPS are supported." (Nat_big_num.to_string e_machine)) - in - - (* pull the object symbols from the symbol table *) - let symbol_table : (string * Nat_big_num.num * int * word8 list (*their bytes*)) list = - let rec convert_symbol_table symbol_map = - begin match symbol_map with - | [] -> [] - | ((name: string), - ((typ: Nat_big_num.num), - (size: Nat_big_num.num (*number of bytes*)), - (address: Nat_big_num.num), - (mb: Byte_sequence.byte_sequence option (*present iff type=stt_object*)), - (binding: Nat_big_num.num))) - (* (mb: Byte_sequence_wrapper.t option (*present iff type=stt_object*)) )) *) - ::symbol_map' -> - if Nat_big_num.equal typ Elf_symbol_table.stt_object && not (Nat_big_num.equal size (Nat_big_num.of_int 0)) - then - ( - (* an object symbol - map *) - (*Printf.printf "*** size %d ***\n" (Nat_big_num.to_int size);*) - let bytes = - (match mb with - | None -> raise (Failure "this cannot happen") - | Some (Sequence bytes) -> - List.map (fun (c:char) -> Char.code c) bytes) in - (name, address, List.length bytes, bytes):: convert_symbol_table symbol_map' + List.map (fun (n, a, bs) -> (n, a, List.length bs, bs)) initial_stack_data @ convert_symbol_table symbol_map + in + + (* invert the symbol table to use for pp *) + let symbol_table_pp : ((Sail_impl_base.address * int) * string) list = + (* map symbol to (bindings, footprint), + if a symbol appears more then onece keep the one with higher + precedence (stb_global > stb_weak > stb_local) *) + let map = + List.fold_left + (fun map (name, (typ, size, address, mb, binding)) -> + if + String.length name <> 0 + && (if String.length name = 1 then Char.code (String.get name 0) <> 0 else true) + && not (Nat_big_num.equal address (Nat_big_num.of_int 0)) + then ( + try + let binding', _ = StringMap.find name map in + if + Nat_big_num.equal binding' Elf_symbol_table.stb_local + || Nat_big_num.equal binding Elf_symbol_table.stb_global + then + StringMap.add name + (binding, (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) + map + else map + with Not_found -> + StringMap.add name + (binding, (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) + map + ) + else map ) - else - (* not an object symbol or of zero size - ignore *) - convert_symbol_table symbol_map' - end + StringMap.empty symbol_map + in + + List.map (fun (name, (binding, fp)) -> (fp, name)) (StringMap.bindings map) in - (List.map (fun (n,a,bs) -> (n,a,List.length bs,bs)) initial_stack_data) @ convert_symbol_table symbol_map - in - (* invert the symbol table to use for pp *) - let symbol_table_pp : ((Sail_impl_base.address * int) * string) list = - (* map symbol to (bindings, footprint), - if a symbol appears more then onece keep the one with higher - precedence (stb_global > stb_weak > stb_local) *) - let map = - List.fold_left - (fun map (name, (typ, size, address, mb, binding)) -> - if String.length name <> 0 && - (if String.length name = 1 then Char.code (String.get name 0) <> 0 else true) && - not (Nat_big_num.equal address (Nat_big_num.of_int 0)) - then - try - let (binding', _) = StringMap.find name map in - if Nat_big_num.equal binding' Elf_symbol_table.stb_local || - Nat_big_num.equal binding Elf_symbol_table.stb_global - then - StringMap.add name (binding, - (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) map - else map - with Not_found -> - StringMap.add name (binding, - (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) map - - else map - ) - StringMap.empty - symbol_map + (* Now we examine the rest of the data memory, + removing the footprint of the symbols and chunking it into aligned chunks *) + let rec remove_symbols_from_data_memory data_mem symbols = + match symbols with + | [] -> data_mem + | (name, address, size, bs) :: symbols' -> + let data_mem' = + Mem.filter + (fun a v -> + not + (Nat_big_num.greater_equal a address + && Nat_big_num.less a (Nat_big_num.add (Nat_big_num.of_int (List.length bs)) address) + ) + ) + data_mem + in + remove_symbols_from_data_memory data_mem' symbols' in - List.map (fun (name, (binding, fp)) -> (fp, name)) (StringMap.bindings map) - in + let trimmed_data_memory : (Nat_big_num.num * memory_byte) list = + Mem.bindings (remove_symbols_from_data_memory !data_mem symbol_table) + in + (* make sure that's ordered increasingly.... *) + let trimmed_data_memory = List.sort (fun (a, b) (a', b') -> Nat_big_num.compare a a') trimmed_data_memory in - (* Now we examine the rest of the data memory, - removing the footprint of the symbols and chunking it into aligned chunks *) - - let rec remove_symbols_from_data_memory data_mem symbols = - match symbols with - | [] -> data_mem - | (name,address,size,bs)::symbols' -> - let data_mem' = - Mem.filter - (fun a v -> - not (Nat_big_num.greater_equal a address && - Nat_big_num.less a (Nat_big_num.add (Nat_big_num.of_int (List.length bs)) address))) - data_mem in - remove_symbols_from_data_memory data_mem' symbols' in - - let trimmed_data_memory : (Nat_big_num.num * memory_byte) list = - Mem.bindings (remove_symbols_from_data_memory !data_mem symbol_table) in - - (* make sure that's ordered increasingly.... *) - let trimmed_data_memory = - List.sort (fun (a,b) (a',b') -> Nat_big_num.compare a a') trimmed_data_memory in - - let aligned a n = (* a mod n = 0 *) - let n_big = Nat_big_num.of_int n in - Nat_big_num.equal (Nat_big_num.modulus a n_big) ((Nat_big_num.of_int 0)) in - - let isplus a' a n = (* a' = a+n *) - Nat_big_num.equal a' (Nat_big_num.add (Nat_big_num.of_int n) a) in - - let rec chunk_data_memory dm = - match dm with - | (a0,b0)::(a1,b1)::(a2,b2)::(a3,b3)::(a4,b4)::(a5,b5)::(a6,b6)::(a7,b7)::dm' when - (aligned a0 8 && isplus a1 a0 1 && isplus a2 a0 2 && isplus a3 a0 3 && - isplus a4 a0 4 && isplus a5 a0 5 && isplus a6 a0 6 && isplus a7 a0 7) -> - (a0,8,[b0;b1;b2;b3;b4;b5;b6;b7]) :: chunk_data_memory dm' - | (a0,b0)::(a1,b1)::(a2,b2)::(a3,b3)::dm' when - (aligned a0 4 && isplus a1 a0 1 && isplus a2 a0 2 && isplus a3 a0 3) -> - (a0,4,[b0;b1;b2;b3]) :: chunk_data_memory dm' - | (a0,b0)::(a1,b1)::dm' when - (aligned a0 2 && isplus a1 a0 1) -> - (a0,2,[b0;b1]) :: chunk_data_memory dm' - | (a0,b0)::dm' -> - (a0,1,[b0]):: chunk_data_memory dm' - | [] -> [] in - - let initial_register_state = - fun rbn -> - try - List.assoc rbn initial_register_abi_data - with - Not_found -> - (register_state_zero register_data_all) rbn - in + let aligned a n = + (* a mod n = 0 *) + let n_big = Nat_big_num.of_int n in + Nat_big_num.equal (Nat_big_num.modulus a n_big) (Nat_big_num.of_int 0) + in - begin - (initial_reg_file register_data_all initial_register_state); - - (* construct initial system state *) - let initial_system_state = - (isa_defs, - isa_memory_access, - isa_externs, - isa_model, - model_reg_d, - startaddr, - (Sail_impl_base.address_of_integer startaddr)) + let isplus a' a n = + (* a' = a+n *) + Nat_big_num.equal a' (Nat_big_num.add (Nat_big_num.of_int n) a) + in + + let rec chunk_data_memory dm = + match dm with + | (a0, b0) :: (a1, b1) :: (a2, b2) :: (a3, b3) :: (a4, b4) :: (a5, b5) :: (a6, b6) :: (a7, b7) :: dm' + when aligned a0 8 && isplus a1 a0 1 && isplus a2 a0 2 && isplus a3 a0 3 && isplus a4 a0 4 && isplus a5 a0 5 + && isplus a6 a0 6 && isplus a7 a0 7 -> + (a0, 8, [b0; b1; b2; b3; b4; b5; b6; b7]) :: chunk_data_memory dm' + | (a0, b0) :: (a1, b1) :: (a2, b2) :: (a3, b3) :: dm' + when aligned a0 4 && isplus a1 a0 1 && isplus a2 a0 2 && isplus a3 a0 3 -> + (a0, 4, [b0; b1; b2; b3]) :: chunk_data_memory dm' + | (a0, b0) :: (a1, b1) :: dm' when aligned a0 2 && isplus a1 a0 1 -> + (a0, 2, [b0; b1]) :: chunk_data_memory dm' + | (a0, b0) :: dm' -> (a0, 1, [b0]) :: chunk_data_memory dm' + | [] -> [] + in + + let initial_register_state rbn = + try List.assoc rbn initial_register_abi_data with Not_found -> (register_state_zero register_data_all) rbn in - - (initial_system_state, symbol_table_pp) + + begin + initial_reg_file register_data_all initial_register_state; + + (* construct initial system state *) + let initial_system_state = + ( isa_defs, + isa_memory_access, + isa_externs, + isa_model, + model_reg_d, + startaddr, + Sail_impl_base.address_of_integer startaddr + ) + in + + (initial_system_state, symbol_table_pp) + end end - end let eager_eval = ref true let break_point = ref false @@ -914,22 +927,42 @@ let break_instr = ref 0 let max_cut_off = ref false let max_instr = ref 0 let raw_file = ref "" -let raw_at = ref 0 - -let args = [ - ("--file", Arg.Set_string file, "filename of elf binary to load in memory"); - ("--quiet", Arg.Clear Run_interp_model.interact_print, "do not display per-instruction actions"); - ("--silent", Arg.Tuple [Arg.Clear Run_interp_model.error_print; - Arg.Clear Run_interp_model.interact_print; - Arg.Clear Run_interp_model.result_print], - "do not dispaly error messages, per-instruction actions, or results"); - ("--no_result", Arg.Clear Run_interp_model.result_print, "do not display final register values"); - ("--interactive", Arg.Clear eager_eval , "interactive execution"); - ("--breakpoint", Arg.Int (fun i -> break_point := true; break_instr:= i), "run to instruction number i, then run interactively"); - ("--max_instruction", Arg.Int (fun i -> max_cut_off := true; max_instr := i), "only run i instructions, then stop"); - ("--raw", Arg.Set_string raw_file, "filename of raw file to load in memory"); - ("--at", Arg.Set_int raw_at, "address to load raw file in memory"); -] +let raw_at = ref 0 + +let args = + [ + ("--file", Arg.Set_string file, "filename of elf binary to load in memory"); + ("--quiet", Arg.Clear Run_interp_model.interact_print, "do not display per-instruction actions"); + ( "--silent", + Arg.Tuple + [ + Arg.Clear Run_interp_model.error_print; + Arg.Clear Run_interp_model.interact_print; + Arg.Clear Run_interp_model.result_print; + ], + "do not dispaly error messages, per-instruction actions, or results" + ); + ("--no_result", Arg.Clear Run_interp_model.result_print, "do not display final register values"); + ("--interactive", Arg.Clear eager_eval, "interactive execution"); + ( "--breakpoint", + Arg.Int + (fun i -> + break_point := true; + break_instr := i + ), + "run to instruction number i, then run interactively" + ); + ( "--max_instruction", + Arg.Int + (fun i -> + max_cut_off := true; + max_instr := i + ), + "only run i instructions, then stop" + ); + ("--raw", Arg.Set_string raw_file, "filename of raw file to load in memory"); + ("--at", Arg.Set_int raw_at, "address to load raw file in memory"); + ] let time_it action arg = let start_time = Sys.time () in @@ -939,405 +972,456 @@ let time_it action arg = (*TODO MIPS specific, should print final register values under all models*) let rec debug_print_gprs start stop = - resultf "DEBUG MIPS REG %.2d %s\n" start (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "GPR%02d" start) !reg)); - if start < stop - then debug_print_gprs (start + 1) stop - else () + resultf "DEBUG MIPS REG %.2d %s\n" start + (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "GPR%02d" start) !reg)); + if start < stop then debug_print_gprs (start + 1) stop else () let stop_condition_met model instr = match model with - | PPC -> - (match instr with - | ("Sc", [("Lev", _, arg)]) -> - Nat_big_num.equal (integer_of_bit_list arg) (Nat_big_num.of_int 32) - | _ -> false) - | AArch64 -> (match instr with - | ("ImplementationDefinedStopFetching", _) -> true - | _ -> false) - | MIPS -> (match instr with - | ("HCF", _) -> - resultf "DEBUG MIPS PC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PC" !reg)); - debug_print_gprs 0 31; - true - | _ -> false) + | PPC -> ( + match instr with + | "Sc", [("Lev", _, arg)] -> Nat_big_num.equal (integer_of_bit_list arg) (Nat_big_num.of_int 32) + | _ -> false + ) + | AArch64 -> ( + match instr with "ImplementationDefinedStopFetching", _ -> true | _ -> false + ) + | MIPS -> ( + match instr with + | "HCF", _ -> + resultf "DEBUG MIPS PC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PC" !reg)); + debug_print_gprs 0 31; + true + | _ -> false + ) let is_branch model instruction = - let (name,_,_) = instruction in - match (model , name) with - | (PPC, "B") -> true - | (PPC, "Bc") -> true - | (PPC, "Bclr") -> true - | (PPC, "Bcctr") -> true - | (PPC, _) -> false - | (AArch64, "BranchImmediate") -> true - | (AArch64, "BranchConditional") -> true - | (AArch64, "CompareAndBranch") -> true - | (AArch64, "TestBitAndBranch") -> true - | (AArch64, "BranchRegister") -> true - | (AArch64, _) -> false - | (MIPS, _) -> false (*todo,fill this in*) - -let option_int_of_option_integer i = match i with - | Some i -> Some (Nat_big_num.to_int i) - | None -> None + let name, _, _ = instruction in + match (model, name) with + | PPC, "B" -> true + | PPC, "Bc" -> true + | PPC, "Bclr" -> true + | PPC, "Bcctr" -> true + | PPC, _ -> false + | AArch64, "BranchImmediate" -> true + | AArch64, "BranchConditional" -> true + | AArch64, "CompareAndBranch" -> true + | AArch64, "TestBitAndBranch" -> true + | AArch64, "BranchRegister" -> true + | AArch64, _ -> false + | MIPS, _ -> false (*todo,fill this in*) + +let option_int_of_option_integer i = match i with Some i -> Some (Nat_big_num.to_int i) | None -> None let set_next_instruction_address model = match model with - | PPC -> - let cia = Reg.find "CIA" !reg in - let cia_addr = address_of_register_value cia in - (match cia_addr with - | Some cia_addr -> - let nia_addr = add_address_nat cia_addr 4 in - let nia = register_value_of_address nia_addr Sail_impl_base.D_increasing in - reg := Reg.add "NIA" nia !reg - | _ -> failwith "CIA address contains unknown or undefined") - | AArch64 -> - let pc = Reg.find "_PC" !reg in - let pc_addr = address_of_register_value pc in - (match pc_addr with - | Some pc_addr -> - let n_addr = add_address_nat pc_addr 4 in - let n_pc = register_value_of_address n_addr D_decreasing in - reg := Reg.add "_PC" n_pc !reg - | _ -> failwith "_PC address contains unknown or undefined") - | MIPS -> - let pc_addr = address_of_register_value (Reg.find "PC" !reg) in - let branchPending = integer_of_register_value (Reg.find "branchPending" !reg) in - (match (pc_addr, option_int_of_option_integer branchPending) with - | (Some pc_val, Some 0) -> - (* normal -- increment PC *) - let n_addr = add_address_nat pc_val 4 in - let n_pc = register_value_of_address n_addr D_decreasing in - begin - reg := Reg.add "nextPC" n_pc !reg; - reg := Reg.add "inBranchDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - end - | (Some pc_val, Some 1) -> - (* delay slot -- branch to delayed PC and clear branchPending *) - begin - reg := Reg.add "nextPC" (Reg.find "delayedPC" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "inBranchDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) !reg; - end - | (_, _) -> errorf "PC address contains unknown or undefined"; exit 1) + | PPC -> ( + let cia = Reg.find "CIA" !reg in + let cia_addr = address_of_register_value cia in + match cia_addr with + | Some cia_addr -> + let nia_addr = add_address_nat cia_addr 4 in + let nia = register_value_of_address nia_addr Sail_impl_base.D_increasing in + reg := Reg.add "NIA" nia !reg + | _ -> failwith "CIA address contains unknown or undefined" + ) + | AArch64 -> ( + let pc = Reg.find "_PC" !reg in + let pc_addr = address_of_register_value pc in + match pc_addr with + | Some pc_addr -> + let n_addr = add_address_nat pc_addr 4 in + let n_pc = register_value_of_address n_addr D_decreasing in + reg := Reg.add "_PC" n_pc !reg + | _ -> failwith "_PC address contains unknown or undefined" + ) + | MIPS -> ( + let pc_addr = address_of_register_value (Reg.find "PC" !reg) in + let branchPending = integer_of_register_value (Reg.find "branchPending" !reg) in + match (pc_addr, option_int_of_option_integer branchPending) with + | Some pc_val, Some 0 -> + (* normal -- increment PC *) + let n_addr = add_address_nat pc_val 4 in + let n_pc = register_value_of_address n_addr D_decreasing in + begin + reg := Reg.add "nextPC" n_pc !reg; + reg := + Reg.add "inBranchDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg + end + | Some pc_val, Some 1 -> begin + (* delay slot -- branch to delayed PC and clear branchPending *) + reg := Reg.add "nextPC" (Reg.find "delayedPC" !reg) !reg; + reg := + Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + reg := + Reg.add "inBranchDelay" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) + !reg + end + | _, _ -> + errorf "PC address contains unknown or undefined"; + exit 1 + ) let add1 = Nat_big_num.add (Nat_big_num.of_int 1) let get_addr_trans_regs _ = - Some([ - (Sail_impl_base.Reg("PC", 63, 64, Sail_impl_base.D_decreasing), Reg.find "PC" !reg); - (Sail_impl_base.Reg("CP0Status", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Status" !reg); - (Sail_impl_base.Reg("CP0Cause", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Cause" !reg); - (Sail_impl_base.Reg("CP0Count", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Count" !reg); - (Sail_impl_base.Reg("CP0Compare", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Compare" !reg); - (Sail_impl_base.Reg("inBranchDelay", 0, 1, Sail_impl_base.D_decreasing), Reg.find "inBranchDelay" !reg); - (Sail_impl_base.Reg("TLBRandom", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBRandom" !reg); - (Sail_impl_base.Reg("TLBWired", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBWired" !reg); - (Sail_impl_base.Reg("TLBEntryHi", 63, 64, Sail_impl_base.D_decreasing), Reg.find "TLBEntryHi" !reg); - (Sail_impl_base.Reg("TLBEntry00", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry00" !reg); - (Sail_impl_base.Reg("TLBEntry01", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry01" !reg); - (Sail_impl_base.Reg("TLBEntry02", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry02" !reg); - (Sail_impl_base.Reg("TLBEntry03", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry03" !reg); - (Sail_impl_base.Reg("TLBEntry04", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry04" !reg); - (Sail_impl_base.Reg("TLBEntry05", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry05" !reg); - (Sail_impl_base.Reg("TLBEntry06", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry06" !reg); - (Sail_impl_base.Reg("TLBEntry07", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry07" !reg); - (Sail_impl_base.Reg("TLBEntry08", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry08" !reg); - (Sail_impl_base.Reg("TLBEntry09", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry09" !reg); - (Sail_impl_base.Reg("TLBEntry10", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry10" !reg); - (Sail_impl_base.Reg("TLBEntry11", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry11" !reg); - (Sail_impl_base.Reg("TLBEntry12", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry12" !reg); - (Sail_impl_base.Reg("TLBEntry13", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry13" !reg); - (Sail_impl_base.Reg("TLBEntry14", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry14" !reg); - (Sail_impl_base.Reg("TLBEntry15", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry15" !reg); - (Sail_impl_base.Reg("TLBEntry16", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry16" !reg); - (Sail_impl_base.Reg("TLBEntry17", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry17" !reg); - (Sail_impl_base.Reg("TLBEntry18", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry18" !reg); - (Sail_impl_base.Reg("TLBEntry19", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry19" !reg); - (Sail_impl_base.Reg("TLBEntry20", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry20" !reg); - (Sail_impl_base.Reg("TLBEntry21", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry21" !reg); - (Sail_impl_base.Reg("TLBEntry22", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry22" !reg); - (Sail_impl_base.Reg("TLBEntry23", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry23" !reg); - (Sail_impl_base.Reg("TLBEntry24", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry24" !reg); - (Sail_impl_base.Reg("TLBEntry25", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry25" !reg); - (Sail_impl_base.Reg("TLBEntry26", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry26" !reg); - (Sail_impl_base.Reg("TLBEntry27", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry27" !reg); - (Sail_impl_base.Reg("TLBEntry28", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry28" !reg); - (Sail_impl_base.Reg("TLBEntry29", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry29" !reg); - (Sail_impl_base.Reg("TLBEntry30", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry30" !reg); - (Sail_impl_base.Reg("TLBEntry31", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry31" !reg); - (Sail_impl_base.Reg("TLBEntry32", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry32" !reg); - (Sail_impl_base.Reg("TLBEntry33", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry33" !reg); - (Sail_impl_base.Reg("TLBEntry34", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry34" !reg); - (Sail_impl_base.Reg("TLBEntry35", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry35" !reg); - (Sail_impl_base.Reg("TLBEntry36", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry36" !reg); - (Sail_impl_base.Reg("TLBEntry37", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry37" !reg); - (Sail_impl_base.Reg("TLBEntry38", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry38" !reg); - (Sail_impl_base.Reg("TLBEntry39", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry39" !reg); - (Sail_impl_base.Reg("TLBEntry40", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry40" !reg); - (Sail_impl_base.Reg("TLBEntry41", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry41" !reg); - (Sail_impl_base.Reg("TLBEntry42", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry42" !reg); - (Sail_impl_base.Reg("TLBEntry43", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry43" !reg); - (Sail_impl_base.Reg("TLBEntry44", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry44" !reg); - (Sail_impl_base.Reg("TLBEntry45", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry45" !reg); - (Sail_impl_base.Reg("TLBEntry46", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry46" !reg); - (Sail_impl_base.Reg("TLBEntry47", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry47" !reg); - (Sail_impl_base.Reg("TLBEntry48", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry48" !reg); - (Sail_impl_base.Reg("TLBEntry49", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry49" !reg); - (Sail_impl_base.Reg("TLBEntry50", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry50" !reg); - (Sail_impl_base.Reg("TLBEntry51", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry51" !reg); - (Sail_impl_base.Reg("TLBEntry52", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry52" !reg); - (Sail_impl_base.Reg("TLBEntry53", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry53" !reg); - (Sail_impl_base.Reg("TLBEntry54", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry54" !reg); - (Sail_impl_base.Reg("TLBEntry55", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry55" !reg); - (Sail_impl_base.Reg("TLBEntry56", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry56" !reg); - (Sail_impl_base.Reg("TLBEntry57", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry57" !reg); - (Sail_impl_base.Reg("TLBEntry58", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry58" !reg); - (Sail_impl_base.Reg("TLBEntry59", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry59" !reg); - (Sail_impl_base.Reg("TLBEntry60", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry60" !reg); - (Sail_impl_base.Reg("TLBEntry61", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry61" !reg); - (Sail_impl_base.Reg("TLBEntry62", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry62" !reg); - (Sail_impl_base.Reg("TLBEntry63", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry63" !reg); - ]) + Some + [ + (Sail_impl_base.Reg ("PC", 63, 64, Sail_impl_base.D_decreasing), Reg.find "PC" !reg); + (Sail_impl_base.Reg ("CP0Status", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Status" !reg); + (Sail_impl_base.Reg ("CP0Cause", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Cause" !reg); + (Sail_impl_base.Reg ("CP0Count", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Count" !reg); + (Sail_impl_base.Reg ("CP0Compare", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Compare" !reg); + (Sail_impl_base.Reg ("inBranchDelay", 0, 1, Sail_impl_base.D_decreasing), Reg.find "inBranchDelay" !reg); + (Sail_impl_base.Reg ("TLBRandom", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBRandom" !reg); + (Sail_impl_base.Reg ("TLBWired", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBWired" !reg); + (Sail_impl_base.Reg ("TLBEntryHi", 63, 64, Sail_impl_base.D_decreasing), Reg.find "TLBEntryHi" !reg); + (Sail_impl_base.Reg ("TLBEntry00", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry00" !reg); + (Sail_impl_base.Reg ("TLBEntry01", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry01" !reg); + (Sail_impl_base.Reg ("TLBEntry02", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry02" !reg); + (Sail_impl_base.Reg ("TLBEntry03", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry03" !reg); + (Sail_impl_base.Reg ("TLBEntry04", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry04" !reg); + (Sail_impl_base.Reg ("TLBEntry05", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry05" !reg); + (Sail_impl_base.Reg ("TLBEntry06", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry06" !reg); + (Sail_impl_base.Reg ("TLBEntry07", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry07" !reg); + (Sail_impl_base.Reg ("TLBEntry08", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry08" !reg); + (Sail_impl_base.Reg ("TLBEntry09", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry09" !reg); + (Sail_impl_base.Reg ("TLBEntry10", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry10" !reg); + (Sail_impl_base.Reg ("TLBEntry11", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry11" !reg); + (Sail_impl_base.Reg ("TLBEntry12", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry12" !reg); + (Sail_impl_base.Reg ("TLBEntry13", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry13" !reg); + (Sail_impl_base.Reg ("TLBEntry14", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry14" !reg); + (Sail_impl_base.Reg ("TLBEntry15", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry15" !reg); + (Sail_impl_base.Reg ("TLBEntry16", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry16" !reg); + (Sail_impl_base.Reg ("TLBEntry17", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry17" !reg); + (Sail_impl_base.Reg ("TLBEntry18", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry18" !reg); + (Sail_impl_base.Reg ("TLBEntry19", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry19" !reg); + (Sail_impl_base.Reg ("TLBEntry20", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry20" !reg); + (Sail_impl_base.Reg ("TLBEntry21", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry21" !reg); + (Sail_impl_base.Reg ("TLBEntry22", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry22" !reg); + (Sail_impl_base.Reg ("TLBEntry23", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry23" !reg); + (Sail_impl_base.Reg ("TLBEntry24", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry24" !reg); + (Sail_impl_base.Reg ("TLBEntry25", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry25" !reg); + (Sail_impl_base.Reg ("TLBEntry26", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry26" !reg); + (Sail_impl_base.Reg ("TLBEntry27", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry27" !reg); + (Sail_impl_base.Reg ("TLBEntry28", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry28" !reg); + (Sail_impl_base.Reg ("TLBEntry29", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry29" !reg); + (Sail_impl_base.Reg ("TLBEntry30", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry30" !reg); + (Sail_impl_base.Reg ("TLBEntry31", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry31" !reg); + (Sail_impl_base.Reg ("TLBEntry32", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry32" !reg); + (Sail_impl_base.Reg ("TLBEntry33", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry33" !reg); + (Sail_impl_base.Reg ("TLBEntry34", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry34" !reg); + (Sail_impl_base.Reg ("TLBEntry35", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry35" !reg); + (Sail_impl_base.Reg ("TLBEntry36", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry36" !reg); + (Sail_impl_base.Reg ("TLBEntry37", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry37" !reg); + (Sail_impl_base.Reg ("TLBEntry38", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry38" !reg); + (Sail_impl_base.Reg ("TLBEntry39", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry39" !reg); + (Sail_impl_base.Reg ("TLBEntry40", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry40" !reg); + (Sail_impl_base.Reg ("TLBEntry41", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry41" !reg); + (Sail_impl_base.Reg ("TLBEntry42", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry42" !reg); + (Sail_impl_base.Reg ("TLBEntry43", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry43" !reg); + (Sail_impl_base.Reg ("TLBEntry44", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry44" !reg); + (Sail_impl_base.Reg ("TLBEntry45", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry45" !reg); + (Sail_impl_base.Reg ("TLBEntry46", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry46" !reg); + (Sail_impl_base.Reg ("TLBEntry47", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry47" !reg); + (Sail_impl_base.Reg ("TLBEntry48", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry48" !reg); + (Sail_impl_base.Reg ("TLBEntry49", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry49" !reg); + (Sail_impl_base.Reg ("TLBEntry50", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry50" !reg); + (Sail_impl_base.Reg ("TLBEntry51", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry51" !reg); + (Sail_impl_base.Reg ("TLBEntry52", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry52" !reg); + (Sail_impl_base.Reg ("TLBEntry53", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry53" !reg); + (Sail_impl_base.Reg ("TLBEntry54", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry54" !reg); + (Sail_impl_base.Reg ("TLBEntry55", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry55" !reg); + (Sail_impl_base.Reg ("TLBEntry56", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry56" !reg); + (Sail_impl_base.Reg ("TLBEntry57", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry57" !reg); + (Sail_impl_base.Reg ("TLBEntry58", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry58" !reg); + (Sail_impl_base.Reg ("TLBEntry59", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry59" !reg); + (Sail_impl_base.Reg ("TLBEntry60", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry60" !reg); + (Sail_impl_base.Reg ("TLBEntry61", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry61" !reg); + (Sail_impl_base.Reg ("TLBEntry62", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry62" !reg); + (Sail_impl_base.Reg ("TLBEntry63", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry63" !reg); + ] let get_opcode pc_a = - List.map (fun b -> match b with - | Some b -> b - | None -> failwith "A byte in opcode contained unknown or undef") + List.map + (fun b -> match b with Some b -> b | None -> failwith "A byte in opcode contained unknown or undef") (List.map byte_of_memory_byte - ([Mem.find pc_a !prog_mem; + [ + Mem.find pc_a !prog_mem; Mem.find (add1 pc_a) !prog_mem; Mem.find (add1 (add1 pc_a)) !prog_mem; - Mem.find (add1 (add1 (add1 pc_a))) !prog_mem])) + Mem.find (add1 (add1 (add1 pc_a))) !prog_mem; + ] + ) let rec write_events = function | [] -> () - | e::events -> - (match e with - | E_write_reg (Reg(id,_,_,_), value) -> reg := Reg.add id value !reg - | E_write_reg ((Reg_slice(id,_,_,range) as reg_n),value) - | E_write_reg ((Reg_field(id,_,_,_,range) as reg_n),value)-> - let old_val = Reg.find id !reg in - let new_val = fupdate_slice reg_n old_val value range in - reg := Reg.add id new_val !reg - | E_write_reg((Reg_f_slice(id,_,_,_,range,mini_range) as reg_n),value) -> - let old_val = Reg.find id !reg in - let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in - reg := Reg.add id new_val !reg - | _ -> failwith "Only register write events expected"); - write_events events + | e :: events -> + ( match e with + | E_write_reg (Reg (id, _, _, _), value) -> reg := Reg.add id value !reg + | E_write_reg ((Reg_slice (id, _, _, range) as reg_n), value) + | E_write_reg ((Reg_field (id, _, _, _, range) as reg_n), value) -> + let old_val = Reg.find id !reg in + let new_val = fupdate_slice reg_n old_val value range in + reg := Reg.add id new_val !reg + | E_write_reg ((Reg_f_slice (id, _, _, _, range, mini_range) as reg_n), value) -> + let old_val = Reg.find id !reg in + let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in + reg := Reg.add id new_val !reg + | _ -> failwith "Only register write events expected" + ); + write_events events let fetch_instruction_opcode_and_update_ia model addr_trans = match model with - | PPC -> - let cia = Reg.find "CIA" !reg in - let cia_addr = address_of_register_value cia in - (match cia_addr with - | Some cia_addr -> - let cia_a = integer_of_address cia_addr in - let opcode = (get_opcode cia_a) in - begin - reg := Reg.add "CIA" (Reg.find "NIA" !reg) !reg; - Opcode opcode - end - | None -> failwith "CIA address contains unknown or undefined") - | AArch64 -> - let pc = Reg.find "_PC" !reg in - let pc_addr = address_of_register_value pc in - (match pc_addr with - | Some pc_addr -> - let pc_a = integer_of_address pc_addr in - let opcode = (get_opcode pc_a) in - Opcode opcode - | None -> failwith "_PC address contains unknown or undefined") - | MIPS -> - begin - let nextPC = Reg.find "nextPC" !reg in - let pc_addr = address_of_register_value nextPC in - (*let unused = interactf "PC: %s\n" (Printing_functions.register_value_to_string nextPC) in*) - (match pc_addr with - | Some pc_addr -> - let pc_a = match addr_trans (get_addr_trans_regs ()) pc_addr with - | Some a, Some events -> write_events (List.rev events); integer_of_address a - | Some a, None -> integer_of_address a - | None, Some events -> - write_events (List.rev events); - let nextPC = Reg.find "nextPC" !reg in - let pc_addr = address_of_register_value nextPC in - (match pc_addr with - | Some pc_addr -> - (match addr_trans (get_addr_trans_regs ()) pc_addr with - | Some a, Some events -> write_events (List.rev events); integer_of_address a - | Some a, None -> integer_of_address a - | None, _ -> failwith "Address translation failed twice") - | None -> failwith "no nextPc address") - | _ -> failwith "No address and no events from translate address" - in - let opcode = (get_opcode pc_a) in - begin - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - Opcode opcode - end - | None -> errorf "nextPC contains unknown or undefined"; exit 1) + | PPC -> ( + let cia = Reg.find "CIA" !reg in + let cia_addr = address_of_register_value cia in + match cia_addr with + | Some cia_addr -> + let cia_a = integer_of_address cia_addr in + let opcode = get_opcode cia_a in + begin + reg := Reg.add "CIA" (Reg.find "NIA" !reg) !reg; + Opcode opcode + end + | None -> failwith "CIA address contains unknown or undefined" + ) + | AArch64 -> ( + let pc = Reg.find "_PC" !reg in + let pc_addr = address_of_register_value pc in + match pc_addr with + | Some pc_addr -> + let pc_a = integer_of_address pc_addr in + let opcode = get_opcode pc_a in + Opcode opcode + | None -> failwith "_PC address contains unknown or undefined" + ) + | MIPS -> begin + let nextPC = Reg.find "nextPC" !reg in + let pc_addr = address_of_register_value nextPC in + (*let unused = interactf "PC: %s\n" (Printing_functions.register_value_to_string nextPC) in*) + match pc_addr with + | Some pc_addr -> + let pc_a = + match addr_trans (get_addr_trans_regs ()) pc_addr with + | Some a, Some events -> + write_events (List.rev events); + integer_of_address a + | Some a, None -> integer_of_address a + | None, Some events -> ( + write_events (List.rev events); + let nextPC = Reg.find "nextPC" !reg in + let pc_addr = address_of_register_value nextPC in + match pc_addr with + | Some pc_addr -> ( + match addr_trans (get_addr_trans_regs ()) pc_addr with + | Some a, Some events -> + write_events (List.rev events); + integer_of_address a + | Some a, None -> integer_of_address a + | None, _ -> failwith "Address translation failed twice" + ) + | None -> failwith "no nextPc address" + ) + | _ -> failwith "No address and no events from translate address" + in + let opcode = get_opcode pc_a in + begin + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + Opcode opcode + end + | None -> + errorf "nextPC contains unknown or undefined"; + exit 1 end - | _ -> assert false + | _ -> assert false -let get_pc_address = function - | MIPS -> Reg.find "PC" !reg - | PPC -> Reg.find "CIA" !reg - | AArch64 -> Reg.find "_PC" !reg - +let get_pc_address = function MIPS -> Reg.find "PC" !reg | PPC -> Reg.find "CIA" !reg | AArch64 -> Reg.find "_PC" !reg -let option_int_of_reg str = - option_int_of_option_integer (integer_of_register_value (Reg.find str !reg)) +let option_int_of_reg str = option_int_of_option_integer (integer_of_register_value (Reg.find str !reg)) let rec fde_loop count context model mode track_dependencies addr_trans = - if !max_cut_off && count = !max_instr - then resultf "\nEnding evaluation due to reaching cut off point of %d instructions\n" count + if !max_cut_off && count = !max_instr then + resultf "\nEnding evaluation due to reaching cut off point of %d instructions\n" count else begin - if !break_point && count = !break_instr then begin break_point := false; eager_eval := false end; + if !break_point && count = !break_instr then begin + break_point := false; + eager_eval := false + end; let pc_regval = get_pc_address model in - interactf "\n**** instruction %d from address %s ****\n" - count (Printing_functions.register_value_to_string pc_regval); - let pc_addr = address_of_register_value pc_regval in - let pc_val = match pc_addr with - | Some v -> v - | None -> failwith "pc contains undef or unknown" in - let m_paddr_int = match addr_trans (get_addr_trans_regs ()) pc_val with - | Some a, Some events -> write_events (List.rev events); Some (integer_of_address a) - | Some a, None -> Some (integer_of_address a) - | None, Some events -> write_events (List.rev events); None - | None, None -> failwith "address translation failed and no writes" in + interactf "\n**** instruction %d from address %s ****\n" count + (Printing_functions.register_value_to_string pc_regval); + let pc_addr = address_of_register_value pc_regval in + let pc_val = match pc_addr with Some v -> v | None -> failwith "pc contains undef or unknown" in + let m_paddr_int = + match addr_trans (get_addr_trans_regs ()) pc_val with + | Some a, Some events -> + write_events (List.rev events); + Some (integer_of_address a) + | Some a, None -> Some (integer_of_address a) + | None, Some events -> + write_events (List.rev events); + None + | None, None -> failwith "address translation failed and no writes" + in match m_paddr_int with - | Some pc -> - let inBranchDelay = option_int_of_reg "inBranchDelay" in - (match inBranchDelay with - | Some 0 -> + | Some pc -> + let inBranchDelay = option_int_of_reg "inBranchDelay" in + ( match inBranchDelay with + | Some 0 -> let npc_addr = add_address_nat pc_val 4 in let npc_reg = register_value_of_address npc_addr Sail_impl_base.D_decreasing in - reg := Reg.add "nextPC" npc_reg !reg; - | Some 1 -> - reg := Reg.add "nextPC" (Reg.find "delayedPC" !reg) !reg; - | _ -> failwith "invalid value of inBranchDelay"); - let opcode = Opcode (get_opcode pc) in - let (instruction,istate) = match Interp_inter_imp.decode_to_istate context None opcode with - | Instr(instruction,istate) -> - let instruction = interp_value_to_instr_external context instruction in - interactf "\n**** Running: %s ****\n" - (Printing_functions.instruction_to_string instruction); - (instruction,istate) - | Decode_error d -> - (match d with + reg := Reg.add "nextPC" npc_reg !reg + | Some 1 -> reg := Reg.add "nextPC" (Reg.find "delayedPC" !reg) !reg + | _ -> failwith "invalid value of inBranchDelay" + ); + let opcode = Opcode (get_opcode pc) in + let instruction, istate = + match Interp_inter_imp.decode_to_istate context None opcode with + | Instr (instruction, istate) -> + let instruction = interp_value_to_instr_external context instruction in + interactf "\n**** Running: %s ****\n" (Printing_functions.instruction_to_string instruction); + (instruction, istate) + | Decode_error d -> + ( match d with | Interp_interface.Unsupported_instruction_error instruction -> - let instruction = interp_value_to_instr_external context instruction in - errorf "\n**** Encountered unsupported instruction %s ****\n" - (Printing_functions.instruction_to_string instruction) - | Interp_interface.Not_an_instruction_error op -> - (match op with - | Opcode bytes -> - errorf "\n**** Encountered non-decodeable opcode: %s ****\n" (Printing_functions.byte_list_to_string bytes)) - | Internal_error s -> errorf "\n**** Internal error on decode: %s ****\n" s); exit 1 - in - if stop_condition_met model instruction - then resultf "\nSUCCESS program terminated after %d instructions\n" count - else - begin - match Run_interp_model.run istate !reg !prog_mem !tag_mem (Nat_big_num.of_int 1) !eager_eval track_dependencies mode "execute" with - | false, _,_, _ -> errorf "FAILURE\n"; exit 1 - | true, mode, track_dependencies, (my_reg, my_mem, my_tags) -> - reg := my_reg; - prog_mem := my_mem; - tag_mem := my_tags; - - (try - let (pending, _, _) = (Unix.select [(Unix.stdin)] [] [] 0.0) in - (if (pending != []) then - let char = (input_byte stdin) in ( + let instruction = interp_value_to_instr_external context instruction in + errorf "\n**** Encountered unsupported instruction %s ****\n" + (Printing_functions.instruction_to_string instruction) + | Interp_interface.Not_an_instruction_error op -> ( + match op with + | Opcode bytes -> + errorf "\n**** Encountered non-decodeable opcode: %s ****\n" + (Printing_functions.byte_list_to_string bytes) + ) + | Internal_error s -> errorf "\n**** Internal error on decode: %s ****\n" s + ); + exit 1 + in + if stop_condition_met model instruction then + resultf "\nSUCCESS program terminated after %d instructions\n" count + else begin + match + Run_interp_model.run istate !reg !prog_mem !tag_mem (Nat_big_num.of_int 1) !eager_eval track_dependencies + mode "execute" + with + | false, _, _, _ -> + errorf "FAILURE\n"; + exit 1 + | true, mode, track_dependencies, (my_reg, my_mem, my_tags) -> + reg := my_reg; + prog_mem := my_mem; + tag_mem := my_tags; + + ( try + let pending, _, _ = Unix.select [Unix.stdin] [] [] 0.0 in + if pending != [] then ( + let char = input_byte stdin in errorf "Input %x\n" char; - input_buf := (!input_buf) @ [char])); - with - | _ -> ()); - - let uart_rvalid = option_int_of_reg "UART_RVALID" in - (match uart_rvalid with - | Some 0 -> - (match !input_buf with - | x :: xs -> ( - reg := Reg.add "UART_RDATA" (register_value_of_integer 8 7 Sail_impl_base.D_decreasing (Nat_big_num.of_int x)) !reg; - reg := Reg.add "UART_RVALID" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) !reg; - input_buf := xs; - ) - | [] -> ()) - | _-> ()); - - let uart_written = option_int_of_reg "UART_WRITTEN" in - (match uart_written with - | Some 1 -> - (let uart_data = option_int_of_reg "UART_WDATA" in + input_buf := !input_buf @ [char] + ) + with _ -> () + ); + + let uart_rvalid = option_int_of_reg "UART_RVALID" in + ( match uart_rvalid with + | Some 0 -> ( + match !input_buf with + | x :: xs -> + reg := + Reg.add "UART_RDATA" + (register_value_of_integer 8 7 Sail_impl_base.D_decreasing (Nat_big_num.of_int x)) + !reg; + reg := + Reg.add "UART_RVALID" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) + !reg; + input_buf := xs + | [] -> () + ) + | _ -> () + ); + + let uart_written = option_int_of_reg "UART_WRITTEN" in + ( match uart_written with + | Some 1 -> ( + let uart_data = option_int_of_reg "UART_WDATA" in match uart_data with - | Some b -> (printf "%c" (Char.chr b); printf "%!") - | None -> (errorf "UART_WDATA was undef" ; exit 1)) - | _ -> ()); - reg := Reg.add "UART_WRITTEN" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - - reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - fde_loop (count + 1) context model (Some mode) (ref track_dependencies) addr_trans - end - | None -> begin - reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - fde_loop (count + 1) context model mode track_dependencies addr_trans - end + | Some b -> + printf "%c" (Char.chr b); + printf "%!" + | None -> + errorf "UART_WDATA was undef"; + exit 1 + ) + | _ -> () + ); + reg := + Reg.add "UART_WRITTEN" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + + reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; + reg := + Reg.add "branchPending" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) + !reg; + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + fde_loop (count + 1) context model (Some mode) (ref track_dependencies) addr_trans + end + | None -> begin + reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; + reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + fde_loop (count + 1) context model mode track_dependencies addr_trans + end end let rec load_raw_file' mem addr chan = let byte = input_byte chan in - (add_mem byte addr mem; - load_raw_file' mem (Nat_big_num.succ addr) chan) + add_mem byte addr mem; + load_raw_file' mem (Nat_big_num.succ addr) chan -let rec load_raw_file mem addr chan = - try - load_raw_file' mem addr chan - with - | End_of_file -> () +let rec load_raw_file mem addr chan = try load_raw_file' mem addr chan with End_of_file -> () let run () = - Arg.parse args (fun _ -> raise (Arg.Bad "anonymous parameter")) "" ; + Arg.parse args (fun _ -> raise (Arg.Bad "anonymous parameter")) ""; if !file = "" then begin Arg.usage args ""; - exit 1; + exit 1 end; if !break_point then eager_eval := true; - let ((isa_defs, - (isa_m0, isa_m1, isa_m2, isa_m3,isa_m4), - isa_externs, - isa_model, - model_reg_d, - startaddr, - startaddr_internal), pp_symbol_map) = initial_system_state_of_elf_file !file in + let ( ( isa_defs, + (isa_m0, isa_m1, isa_m2, isa_m3, isa_m4), + isa_externs, + isa_model, + model_reg_d, + startaddr, + startaddr_internal + ), + pp_symbol_map ) = + initial_system_state_of_elf_file !file + in let context = build_context false isa_defs isa_m0 [] isa_m1 isa_m2 isa_m3 [] isa_m4 None isa_externs in - (*NOTE: this is likely MIPS specific, so should probably pull from initial_system_state info on to translate or not, - endian mode, and translate function name + (*NOTE: this is likely MIPS specific, so should probably pull from initial_system_state info on to translate or not, + endian mode, and translate function name *) let addr_trans = translate_address context E_little_endian "TranslatePC" in - if String.length(!raw_file) != 0 then - load_raw_file prog_mem (Nat_big_num.of_int !raw_at) (open_in_bin !raw_file); - reg := Reg.add "PC" (register_value_of_address startaddr_internal model_reg_d ) !reg; + if String.length !raw_file != 0 then load_raw_file prog_mem (Nat_big_num.of_int !raw_at) (open_in_bin !raw_file); + reg := Reg.add "PC" (register_value_of_address startaddr_internal model_reg_d) !reg; (* entry point: unit -> unit fde *) let name = Filename.basename !file in - let t = time_it (fun () -> fde_loop 0 context isa_model (Some Run) (ref false) addr_trans) () in - resultf "Execution time for file %s: %f seconds\n" name t;; + let t = time_it (fun () -> fde_loop 0 context isa_model (Some Run) (ref false) addr_trans) () in + resultf "Execution time for file %s: %f seconds\n" name t +;; (* Turn off line-buffering of standard input to allow responsive console input *) -if Unix.isatty (Unix.stdin) then begin - let tattrs = Unix.tcgetattr (Unix.stdin) in - Unix.tcsetattr (Unix.stdin) (Unix.TCSANOW) ({tattrs with c_icanon=false}) -end ;; +if Unix.isatty Unix.stdin then begin + let tattrs = Unix.tcgetattr Unix.stdin in + Unix.tcsetattr Unix.stdin Unix.TCSANOW { tattrs with c_icanon = false } +end +;; -run () ;; +run () diff --git a/src/lem_interp/run_with_elf_cheri.ml b/src/lem_interp/run_with_elf_cheri.ml index 2a8e04954..da791a66b 100644 --- a/src/lem_interp/run_with_elf_cheri.ml +++ b/src/lem_interp/run_with_elf_cheri.ml @@ -65,37 +65,34 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Printf ;; -open Format ;; -open Big_int ;; -open Interp_ast ;; -open Interp_interface ;; -open Interp_inter_imp ;; -open Run_interp_model ;; -open Sail_impl_base ;; -open Sail_interface ;; - -module StringMap = Map.Make(String) - -let file = ref "" ;; - -let rec foldli f acc ?(i=0) = function - | [] -> acc - | x::xs -> foldli f (f i acc x) ~i:(i+1) xs -;; +open Printf +open Format +open Big_int +open Interp_ast +open Interp_interface +open Interp_inter_imp +open Run_interp_model +open Sail_impl_base +open Sail_interface + +module StringMap = Map.Make (String) + +let file = ref "" -let endian = ref E_big_endian ;; +let rec foldli f acc ?(i = 0) = function [] -> acc | x :: xs -> foldli f (f i acc x) ~i:(i + 1) xs -let hex_to_big_int s = big_int_of_int64 (Int64.of_string s) ;; +let endian = ref E_big_endian -let data_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref) ;; -let prog_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref) ;; -let tag_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref);; -let reg = ref Reg.empty ;; -let input_buf = (ref [] : int list ref);; +let hex_to_big_int s = big_int_of_int64 (Int64.of_string s) + +let data_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let prog_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let tag_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let reg = ref Reg.empty +let input_buf = (ref [] : int list ref) let add_mem byte addr mem = - assert(byte >= 0 && byte < 256); + assert (byte >= 0 && byte < 256); (*Printf.printf "add_mem %s: 0x%02x\n" (Uint64.to_string_hex (Uint64.of_string (Nat_big_num.to_string addr))) byte;*) let mem_byte = memory_byte_of_int byte in let zero_byte = memory_byte_of_int 0 in @@ -103,437 +100,452 @@ let add_mem byte addr mem = tag_mem := Mem.add addr zero_byte !tag_mem let get_reg reg name = - let reg_content = Reg.find name reg in reg_content + let reg_content = Reg.find name reg in + reg_content -let rec load_memory_segment' (bytes,addr) mem = +let rec load_memory_segment' (bytes, addr) mem = match bytes with | [] -> () - | byte::bytes' -> - let data_byte = Char.code byte in - let addr' = Nat_big_num.succ addr in - begin add_mem data_byte addr mem; - load_memory_segment' (bytes',addr') mem - end - -let rec load_memory_segment (segment: Elf_interpreted_segment.elf64_interpreted_segment) mem = + | byte :: bytes' -> + let data_byte = Char.code byte in + let addr' = Nat_big_num.succ addr in + begin + add_mem data_byte addr mem; + load_memory_segment' (bytes', addr') mem + end + +let rec load_memory_segment (segment : Elf_interpreted_segment.elf64_interpreted_segment) mem = let (Byte_sequence.Sequence bytes) = segment.Elf_interpreted_segment.elf64_segment_body in let addr = segment.Elf_interpreted_segment.elf64_segment_paddr in - load_memory_segment' (bytes,addr) mem - + load_memory_segment' (bytes, addr) mem let rec load_memory_segments segments = - begin match segments with + begin + match segments with | [] -> () - | segment::segments' -> - let (x,w,r) = segment.Elf_interpreted_segment.elf64_segment_flags in - begin - load_memory_segment segment prog_mem; - load_memory_segments segments' - end + | segment :: segments' -> + let x, w, r = segment.Elf_interpreted_segment.elf64_segment_flags in + begin + load_memory_segment segment prog_mem; + load_memory_segments segments' + end end - -let rec read_mem mem address length = - if length = 0 - then [] - else - let byte = - try Mem.find address mem with - | Not_found -> failwith "start address not found" - in - byte :: (read_mem mem (Nat_big_num.succ address) (length - 1)) + +let rec read_mem mem address length = + if length = 0 then [] + else ( + let byte = try Mem.find address mem with Not_found -> failwith "start address not found" in + byte :: read_mem mem (Nat_big_num.succ address) (length - 1) + ) let register_state_zero register_data rbn : register_value = - let (dir,width,start_index) = - try List.assoc rbn register_data with - | Not_found -> failwith ("register_state_zero lookup failed (" ^ rbn) - in register_value_zeros dir width start_index + let dir, width, start_index = + try List.assoc rbn register_data with Not_found -> failwith ("register_state_zero lookup failed (" ^ rbn) + in + register_value_zeros dir width start_index type model = PPC | AArch64 | MIPS -let mips_register_data_all = [ - (*Pseudo registers*) - ("PC", (D_decreasing, 64, 63)); - ("branchPending", (D_decreasing, 1, 0)); - ("inBranchDelay", (D_decreasing, 1, 0)); - ("inCCallDelay", (D_decreasing, 1, 0)); - ("delayedPC", (D_decreasing, 64, 63)); - ("nextPC", (D_decreasing, 64, 63)); - (* General purpose registers *) - ("GPR00", (D_decreasing, 64, 63)); - ("GPR01", (D_decreasing, 64, 63)); - ("GPR02", (D_decreasing, 64, 63)); - ("GPR03", (D_decreasing, 64, 63)); - ("GPR04", (D_decreasing, 64, 63)); - ("GPR05", (D_decreasing, 64, 63)); - ("GPR06", (D_decreasing, 64, 63)); - ("GPR07", (D_decreasing, 64, 63)); - ("GPR08", (D_decreasing, 64, 63)); - ("GPR09", (D_decreasing, 64, 63)); - ("GPR10", (D_decreasing, 64, 63)); - ("GPR11", (D_decreasing, 64, 63)); - ("GPR12", (D_decreasing, 64, 63)); - ("GPR13", (D_decreasing, 64, 63)); - ("GPR14", (D_decreasing, 64, 63)); - ("GPR15", (D_decreasing, 64, 63)); - ("GPR16", (D_decreasing, 64, 63)); - ("GPR17", (D_decreasing, 64, 63)); - ("GPR18", (D_decreasing, 64, 63)); - ("GPR19", (D_decreasing, 64, 63)); - ("GPR20", (D_decreasing, 64, 63)); - ("GPR21", (D_decreasing, 64, 63)); - ("GPR22", (D_decreasing, 64, 63)); - ("GPR23", (D_decreasing, 64, 63)); - ("GPR24", (D_decreasing, 64, 63)); - ("GPR25", (D_decreasing, 64, 63)); - ("GPR26", (D_decreasing, 64, 63)); - ("GPR27", (D_decreasing, 64, 63)); - ("GPR28", (D_decreasing, 64, 63)); - ("GPR29", (D_decreasing, 64, 63)); - ("GPR30", (D_decreasing, 64, 63)); - ("GPR31", (D_decreasing, 64, 63)); - (* special registers for mul/div *) - ("HI", (D_decreasing, 64, 63)); - ("LO", (D_decreasing, 64, 63)); - (* control registers *) - ("CP0Status", (D_decreasing, 32, 31)); - ("CP0Cause", (D_decreasing, 32, 31)); - ("CP0EPC", (D_decreasing, 64, 63)); - ("CP0LLAddr", (D_decreasing, 64, 63)); - ("CP0LLBit", (D_decreasing, 1, 0)); - ("CP0Count", (D_decreasing, 32, 31)); - ("CP0Compare", (D_decreasing, 32, 31)); - ("CP0HWREna", (D_decreasing, 32, 31)); - ("CP0UserLocal", (D_decreasing, 64, 63)); - ("CP0BadVAddr", (D_decreasing, 64, 63)); - ("TLBProbe" ,(D_decreasing, 1, 0)); - ("TLBIndex" ,(D_decreasing, 6, 5)); - ("TLBRandom" ,(D_decreasing, 6, 5)); - ("TLBEntryLo0",(D_decreasing, 64, 63)); - ("TLBEntryLo1",(D_decreasing, 64, 63)); - ("TLBContext" ,(D_decreasing, 64, 63)); - ("TLBPageMask",(D_decreasing, 16, 15)); - ("TLBWired" ,(D_decreasing, 6, 5)); - ("TLBEntryHi" ,(D_decreasing, 64, 63)); - ("TLBXContext",(D_decreasing, 64, 63)); - - ("TLBEntry00" ,(D_decreasing, 117, 116)); - ("TLBEntry01" ,(D_decreasing, 117, 116)); - ("TLBEntry02" ,(D_decreasing, 117, 116)); - ("TLBEntry03" ,(D_decreasing, 117, 116)); - ("TLBEntry04" ,(D_decreasing, 117, 116)); - ("TLBEntry05" ,(D_decreasing, 117, 116)); - ("TLBEntry06" ,(D_decreasing, 117, 116)); - ("TLBEntry07" ,(D_decreasing, 117, 116)); - ("TLBEntry08" ,(D_decreasing, 117, 116)); - ("TLBEntry09" ,(D_decreasing, 117, 116)); - ("TLBEntry10" ,(D_decreasing, 117, 116)); - ("TLBEntry11" ,(D_decreasing, 117, 116)); - ("TLBEntry12" ,(D_decreasing, 117, 116)); - ("TLBEntry13" ,(D_decreasing, 117, 116)); - ("TLBEntry14" ,(D_decreasing, 117, 116)); - ("TLBEntry15" ,(D_decreasing, 117, 116)); - ("TLBEntry16" ,(D_decreasing, 117, 116)); - ("TLBEntry17" ,(D_decreasing, 117, 116)); - ("TLBEntry18" ,(D_decreasing, 117, 116)); - ("TLBEntry19" ,(D_decreasing, 117, 116)); - ("TLBEntry20" ,(D_decreasing, 117, 116)); - ("TLBEntry21" ,(D_decreasing, 117, 116)); - ("TLBEntry22" ,(D_decreasing, 117, 116)); - ("TLBEntry23" ,(D_decreasing, 117, 116)); - ("TLBEntry24" ,(D_decreasing, 117, 116)); - ("TLBEntry25" ,(D_decreasing, 117, 116)); - ("TLBEntry26" ,(D_decreasing, 117, 116)); - ("TLBEntry27" ,(D_decreasing, 117, 116)); - ("TLBEntry28" ,(D_decreasing, 117, 116)); - ("TLBEntry29" ,(D_decreasing, 117, 116)); - ("TLBEntry30" ,(D_decreasing, 117, 116)); - ("TLBEntry31" ,(D_decreasing, 117, 116)); - ("TLBEntry32" ,(D_decreasing, 117, 116)); - ("TLBEntry33" ,(D_decreasing, 117, 116)); - ("TLBEntry34" ,(D_decreasing, 117, 116)); - ("TLBEntry35" ,(D_decreasing, 117, 116)); - ("TLBEntry36" ,(D_decreasing, 117, 116)); - ("TLBEntry37" ,(D_decreasing, 117, 116)); - ("TLBEntry38" ,(D_decreasing, 117, 116)); - ("TLBEntry39" ,(D_decreasing, 117, 116)); - ("TLBEntry40" ,(D_decreasing, 117, 116)); - ("TLBEntry41" ,(D_decreasing, 117, 116)); - ("TLBEntry42" ,(D_decreasing, 117, 116)); - ("TLBEntry43" ,(D_decreasing, 117, 116)); - ("TLBEntry44" ,(D_decreasing, 117, 116)); - ("TLBEntry45" ,(D_decreasing, 117, 116)); - ("TLBEntry46" ,(D_decreasing, 117, 116)); - ("TLBEntry47" ,(D_decreasing, 117, 116)); - ("TLBEntry48" ,(D_decreasing, 117, 116)); - ("TLBEntry49" ,(D_decreasing, 117, 116)); - ("TLBEntry50" ,(D_decreasing, 117, 116)); - ("TLBEntry51" ,(D_decreasing, 117, 116)); - ("TLBEntry52" ,(D_decreasing, 117, 116)); - ("TLBEntry53" ,(D_decreasing, 117, 116)); - ("TLBEntry54" ,(D_decreasing, 117, 116)); - ("TLBEntry55" ,(D_decreasing, 117, 116)); - ("TLBEntry56" ,(D_decreasing, 117, 116)); - ("TLBEntry57" ,(D_decreasing, 117, 116)); - ("TLBEntry58" ,(D_decreasing, 117, 116)); - ("TLBEntry59" ,(D_decreasing, 117, 116)); - ("TLBEntry60" ,(D_decreasing, 117, 116)); - ("TLBEntry61" ,(D_decreasing, 117, 116)); - ("TLBEntry62" ,(D_decreasing, 117, 116)); - ("TLBEntry63" ,(D_decreasing, 117, 116)); - - ("UART_WDATA" ,(D_decreasing, 8, 7)); - ("UART_RDATA" ,(D_decreasing, 8, 7)); - ("UART_WRITTEN" ,(D_decreasing, 1, 0)); - ("UART_RVALID" ,(D_decreasing, 1, 0)); -] - -let cheri_register_data_all = mips_register_data_all @ [ - ("CapCause", (D_decreasing, 16, 15)); - ("PCC", (D_decreasing, 257, 256)); - ("nextPCC", (D_decreasing, 257, 256)); - ("delayedPCC", (D_decreasing, 257, 256)); - ("C00", (D_decreasing, 257, 256)); - ("C01", (D_decreasing, 257, 256)); - ("C02", (D_decreasing, 257, 256)); - ("C03", (D_decreasing, 257, 256)); - ("C04", (D_decreasing, 257, 256)); - ("C05", (D_decreasing, 257, 256)); - ("C06", (D_decreasing, 257, 256)); - ("C07", (D_decreasing, 257, 256)); - ("C08", (D_decreasing, 257, 256)); - ("C09", (D_decreasing, 257, 256)); - ("C10", (D_decreasing, 257, 256)); - ("C11", (D_decreasing, 257, 256)); - ("C12", (D_decreasing, 257, 256)); - ("C13", (D_decreasing, 257, 256)); - ("C14", (D_decreasing, 257, 256)); - ("C15", (D_decreasing, 257, 256)); - ("C16", (D_decreasing, 257, 256)); - ("C17", (D_decreasing, 257, 256)); - ("C18", (D_decreasing, 257, 256)); - ("C19", (D_decreasing, 257, 256)); - ("C20", (D_decreasing, 257, 256)); - ("C21", (D_decreasing, 257, 256)); - ("C22", (D_decreasing, 257, 256)); - ("C23", (D_decreasing, 257, 256)); - ("C24", (D_decreasing, 257, 256)); - ("C25", (D_decreasing, 257, 256)); - ("C26", (D_decreasing, 257, 256)); - ("C27", (D_decreasing, 257, 256)); - ("C28", (D_decreasing, 257, 256)); - ("C29", (D_decreasing, 257, 256)); - ("C30", (D_decreasing, 257, 256)); - ("C31", (D_decreasing, 257, 256)); -] +let mips_register_data_all = + [ + (*Pseudo registers*) + ("PC", (D_decreasing, 64, 63)); + ("branchPending", (D_decreasing, 1, 0)); + ("inBranchDelay", (D_decreasing, 1, 0)); + ("inCCallDelay", (D_decreasing, 1, 0)); + ("delayedPC", (D_decreasing, 64, 63)); + ("nextPC", (D_decreasing, 64, 63)); + (* General purpose registers *) + ("GPR00", (D_decreasing, 64, 63)); + ("GPR01", (D_decreasing, 64, 63)); + ("GPR02", (D_decreasing, 64, 63)); + ("GPR03", (D_decreasing, 64, 63)); + ("GPR04", (D_decreasing, 64, 63)); + ("GPR05", (D_decreasing, 64, 63)); + ("GPR06", (D_decreasing, 64, 63)); + ("GPR07", (D_decreasing, 64, 63)); + ("GPR08", (D_decreasing, 64, 63)); + ("GPR09", (D_decreasing, 64, 63)); + ("GPR10", (D_decreasing, 64, 63)); + ("GPR11", (D_decreasing, 64, 63)); + ("GPR12", (D_decreasing, 64, 63)); + ("GPR13", (D_decreasing, 64, 63)); + ("GPR14", (D_decreasing, 64, 63)); + ("GPR15", (D_decreasing, 64, 63)); + ("GPR16", (D_decreasing, 64, 63)); + ("GPR17", (D_decreasing, 64, 63)); + ("GPR18", (D_decreasing, 64, 63)); + ("GPR19", (D_decreasing, 64, 63)); + ("GPR20", (D_decreasing, 64, 63)); + ("GPR21", (D_decreasing, 64, 63)); + ("GPR22", (D_decreasing, 64, 63)); + ("GPR23", (D_decreasing, 64, 63)); + ("GPR24", (D_decreasing, 64, 63)); + ("GPR25", (D_decreasing, 64, 63)); + ("GPR26", (D_decreasing, 64, 63)); + ("GPR27", (D_decreasing, 64, 63)); + ("GPR28", (D_decreasing, 64, 63)); + ("GPR29", (D_decreasing, 64, 63)); + ("GPR30", (D_decreasing, 64, 63)); + ("GPR31", (D_decreasing, 64, 63)); + (* special registers for mul/div *) + ("HI", (D_decreasing, 64, 63)); + ("LO", (D_decreasing, 64, 63)); + (* control registers *) + ("CP0Status", (D_decreasing, 32, 31)); + ("CP0Cause", (D_decreasing, 32, 31)); + ("CP0EPC", (D_decreasing, 64, 63)); + ("CP0LLAddr", (D_decreasing, 64, 63)); + ("CP0LLBit", (D_decreasing, 1, 0)); + ("CP0Count", (D_decreasing, 32, 31)); + ("CP0Compare", (D_decreasing, 32, 31)); + ("CP0HWREna", (D_decreasing, 32, 31)); + ("CP0UserLocal", (D_decreasing, 64, 63)); + ("CP0BadVAddr", (D_decreasing, 64, 63)); + ("TLBProbe", (D_decreasing, 1, 0)); + ("TLBIndex", (D_decreasing, 6, 5)); + ("TLBRandom", (D_decreasing, 6, 5)); + ("TLBEntryLo0", (D_decreasing, 64, 63)); + ("TLBEntryLo1", (D_decreasing, 64, 63)); + ("TLBContext", (D_decreasing, 64, 63)); + ("TLBPageMask", (D_decreasing, 16, 15)); + ("TLBWired", (D_decreasing, 6, 5)); + ("TLBEntryHi", (D_decreasing, 64, 63)); + ("TLBXContext", (D_decreasing, 64, 63)); + ("TLBEntry00", (D_decreasing, 117, 116)); + ("TLBEntry01", (D_decreasing, 117, 116)); + ("TLBEntry02", (D_decreasing, 117, 116)); + ("TLBEntry03", (D_decreasing, 117, 116)); + ("TLBEntry04", (D_decreasing, 117, 116)); + ("TLBEntry05", (D_decreasing, 117, 116)); + ("TLBEntry06", (D_decreasing, 117, 116)); + ("TLBEntry07", (D_decreasing, 117, 116)); + ("TLBEntry08", (D_decreasing, 117, 116)); + ("TLBEntry09", (D_decreasing, 117, 116)); + ("TLBEntry10", (D_decreasing, 117, 116)); + ("TLBEntry11", (D_decreasing, 117, 116)); + ("TLBEntry12", (D_decreasing, 117, 116)); + ("TLBEntry13", (D_decreasing, 117, 116)); + ("TLBEntry14", (D_decreasing, 117, 116)); + ("TLBEntry15", (D_decreasing, 117, 116)); + ("TLBEntry16", (D_decreasing, 117, 116)); + ("TLBEntry17", (D_decreasing, 117, 116)); + ("TLBEntry18", (D_decreasing, 117, 116)); + ("TLBEntry19", (D_decreasing, 117, 116)); + ("TLBEntry20", (D_decreasing, 117, 116)); + ("TLBEntry21", (D_decreasing, 117, 116)); + ("TLBEntry22", (D_decreasing, 117, 116)); + ("TLBEntry23", (D_decreasing, 117, 116)); + ("TLBEntry24", (D_decreasing, 117, 116)); + ("TLBEntry25", (D_decreasing, 117, 116)); + ("TLBEntry26", (D_decreasing, 117, 116)); + ("TLBEntry27", (D_decreasing, 117, 116)); + ("TLBEntry28", (D_decreasing, 117, 116)); + ("TLBEntry29", (D_decreasing, 117, 116)); + ("TLBEntry30", (D_decreasing, 117, 116)); + ("TLBEntry31", (D_decreasing, 117, 116)); + ("TLBEntry32", (D_decreasing, 117, 116)); + ("TLBEntry33", (D_decreasing, 117, 116)); + ("TLBEntry34", (D_decreasing, 117, 116)); + ("TLBEntry35", (D_decreasing, 117, 116)); + ("TLBEntry36", (D_decreasing, 117, 116)); + ("TLBEntry37", (D_decreasing, 117, 116)); + ("TLBEntry38", (D_decreasing, 117, 116)); + ("TLBEntry39", (D_decreasing, 117, 116)); + ("TLBEntry40", (D_decreasing, 117, 116)); + ("TLBEntry41", (D_decreasing, 117, 116)); + ("TLBEntry42", (D_decreasing, 117, 116)); + ("TLBEntry43", (D_decreasing, 117, 116)); + ("TLBEntry44", (D_decreasing, 117, 116)); + ("TLBEntry45", (D_decreasing, 117, 116)); + ("TLBEntry46", (D_decreasing, 117, 116)); + ("TLBEntry47", (D_decreasing, 117, 116)); + ("TLBEntry48", (D_decreasing, 117, 116)); + ("TLBEntry49", (D_decreasing, 117, 116)); + ("TLBEntry50", (D_decreasing, 117, 116)); + ("TLBEntry51", (D_decreasing, 117, 116)); + ("TLBEntry52", (D_decreasing, 117, 116)); + ("TLBEntry53", (D_decreasing, 117, 116)); + ("TLBEntry54", (D_decreasing, 117, 116)); + ("TLBEntry55", (D_decreasing, 117, 116)); + ("TLBEntry56", (D_decreasing, 117, 116)); + ("TLBEntry57", (D_decreasing, 117, 116)); + ("TLBEntry58", (D_decreasing, 117, 116)); + ("TLBEntry59", (D_decreasing, 117, 116)); + ("TLBEntry60", (D_decreasing, 117, 116)); + ("TLBEntry61", (D_decreasing, 117, 116)); + ("TLBEntry62", (D_decreasing, 117, 116)); + ("TLBEntry63", (D_decreasing, 117, 116)); + ("UART_WDATA", (D_decreasing, 8, 7)); + ("UART_RDATA", (D_decreasing, 8, 7)); + ("UART_WRITTEN", (D_decreasing, 1, 0)); + ("UART_RVALID", (D_decreasing, 1, 0)); + ] + +let cheri_register_data_all = + mips_register_data_all + @ [ + ("CapCause", (D_decreasing, 16, 15)); + ("PCC", (D_decreasing, 257, 256)); + ("nextPCC", (D_decreasing, 257, 256)); + ("delayedPCC", (D_decreasing, 257, 256)); + ("C00", (D_decreasing, 257, 256)); + ("C01", (D_decreasing, 257, 256)); + ("C02", (D_decreasing, 257, 256)); + ("C03", (D_decreasing, 257, 256)); + ("C04", (D_decreasing, 257, 256)); + ("C05", (D_decreasing, 257, 256)); + ("C06", (D_decreasing, 257, 256)); + ("C07", (D_decreasing, 257, 256)); + ("C08", (D_decreasing, 257, 256)); + ("C09", (D_decreasing, 257, 256)); + ("C10", (D_decreasing, 257, 256)); + ("C11", (D_decreasing, 257, 256)); + ("C12", (D_decreasing, 257, 256)); + ("C13", (D_decreasing, 257, 256)); + ("C14", (D_decreasing, 257, 256)); + ("C15", (D_decreasing, 257, 256)); + ("C16", (D_decreasing, 257, 256)); + ("C17", (D_decreasing, 257, 256)); + ("C18", (D_decreasing, 257, 256)); + ("C19", (D_decreasing, 257, 256)); + ("C20", (D_decreasing, 257, 256)); + ("C21", (D_decreasing, 257, 256)); + ("C22", (D_decreasing, 257, 256)); + ("C23", (D_decreasing, 257, 256)); + ("C24", (D_decreasing, 257, 256)); + ("C25", (D_decreasing, 257, 256)); + ("C26", (D_decreasing, 257, 256)); + ("C27", (D_decreasing, 257, 256)); + ("C28", (D_decreasing, 257, 256)); + ("C29", (D_decreasing, 257, 256)); + ("C30", (D_decreasing, 257, 256)); + ("C31", (D_decreasing, 257, 256)); + ] let initial_stack_and_reg_data_of_MIPS_elf_file e_entry all_data_memory = - let initial_stack_data = [] in - let initial_cap_val_int = Nat_big_num.of_string "115792089264276142078167421332581561412618036492862375629811892344162380414975" (*"0x100000000fffffffe00000000000000000000000000000000ffffffffffffffff"*) in - let initial_cap_val_reg = Sail_impl_base.register_value_of_integer 257 256 D_decreasing initial_cap_val_int in - let initial_register_abi_data : (string * Sail_impl_base.register_value) list = [ - ("CP0Status", Sail_impl_base.register_value_of_integer 32 31 D_decreasing (Nat_big_num.of_string "0x00400000")); - ("PCC", initial_cap_val_reg); - ("nextPCC", initial_cap_val_reg); - ("delayedPCC", initial_cap_val_reg); - ("C00", initial_cap_val_reg); - ("C01", initial_cap_val_reg); - ("C02", initial_cap_val_reg); - ("C03", initial_cap_val_reg); - ("C04", initial_cap_val_reg); - ("C05", initial_cap_val_reg); - ("C06", initial_cap_val_reg); - ("C07", initial_cap_val_reg); - ("C08", initial_cap_val_reg); - ("C09", initial_cap_val_reg); - ("C10", initial_cap_val_reg); - ("C11", initial_cap_val_reg); - ("C12", initial_cap_val_reg); - ("C13", initial_cap_val_reg); - ("C14", initial_cap_val_reg); - ("C15", initial_cap_val_reg); - ("C16", initial_cap_val_reg); - ("C17", initial_cap_val_reg); - ("C18", initial_cap_val_reg); - ("C19", initial_cap_val_reg); - ("C20", initial_cap_val_reg); - ("C21", initial_cap_val_reg); - ("C22", initial_cap_val_reg); - ("C23", initial_cap_val_reg); - ("C24", initial_cap_val_reg); - ("C25", initial_cap_val_reg); - ("C26", initial_cap_val_reg); - ("C27", initial_cap_val_reg); - ("C28", initial_cap_val_reg); - ("C29", initial_cap_val_reg); - ("C30", initial_cap_val_reg); - ("C31", initial_cap_val_reg); - ] in + let initial_stack_data = [] in + let initial_cap_val_int = + Nat_big_num.of_string "115792089264276142078167421332581561412618036492862375629811892344162380414975" + (*"0x100000000fffffffe00000000000000000000000000000000ffffffffffffffff"*) + in + let initial_cap_val_reg = Sail_impl_base.register_value_of_integer 257 256 D_decreasing initial_cap_val_int in + let initial_register_abi_data : (string * Sail_impl_base.register_value) list = + [ + ("CP0Status", Sail_impl_base.register_value_of_integer 32 31 D_decreasing (Nat_big_num.of_string "0x00400000")); + ("PCC", initial_cap_val_reg); + ("nextPCC", initial_cap_val_reg); + ("delayedPCC", initial_cap_val_reg); + ("C00", initial_cap_val_reg); + ("C01", initial_cap_val_reg); + ("C02", initial_cap_val_reg); + ("C03", initial_cap_val_reg); + ("C04", initial_cap_val_reg); + ("C05", initial_cap_val_reg); + ("C06", initial_cap_val_reg); + ("C07", initial_cap_val_reg); + ("C08", initial_cap_val_reg); + ("C09", initial_cap_val_reg); + ("C10", initial_cap_val_reg); + ("C11", initial_cap_val_reg); + ("C12", initial_cap_val_reg); + ("C13", initial_cap_val_reg); + ("C14", initial_cap_val_reg); + ("C15", initial_cap_val_reg); + ("C16", initial_cap_val_reg); + ("C17", initial_cap_val_reg); + ("C18", initial_cap_val_reg); + ("C19", initial_cap_val_reg); + ("C20", initial_cap_val_reg); + ("C21", initial_cap_val_reg); + ("C22", initial_cap_val_reg); + ("C23", initial_cap_val_reg); + ("C24", initial_cap_val_reg); + ("C25", initial_cap_val_reg); + ("C26", initial_cap_val_reg); + ("C27", initial_cap_val_reg); + ("C28", initial_cap_val_reg); + ("C29", initial_cap_val_reg); + ("C30", initial_cap_val_reg); + ("C31", initial_cap_val_reg); + ] + in (initial_stack_data, initial_register_abi_data) let initial_reg_file reg_data init = List.iter (fun (reg_name, _) -> reg := Reg.add reg_name (init reg_name) !reg) reg_data -let initial_system_state_of_elf_file name = - +let initial_system_state_of_elf_file name = (* call ELF analyser on file *) match Sail_interface.populate_and_obtain_global_symbol_init_info name with | Error.Fail s -> failwith ("populate_and_obtain_global_symbol_init_info: " ^ s) - | Error.Success - (_, (elf_epi: Sail_interface.executable_process_image), - (symbol_map: Elf_file.global_symbol_init_info)) - -> - let (segments, e_entry, e_machine) = - begin match elf_epi with - | ELF_Class_32 _ -> failwith "cannot handle ELF_Class_32" - | ELF_Class_64 (segments,e_entry,e_machine) -> - (* remove all the auto generated segments (they contain only 0s) *) - let segments = - Lem_list.mapMaybe - (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) - segments - in - (segments,e_entry,e_machine) - end - in + | Error.Success + (_, (elf_epi : Sail_interface.executable_process_image), (symbol_map : Elf_file.global_symbol_init_info)) -> + let segments, e_entry, e_machine = + begin + match elf_epi with + | ELF_Class_32 _ -> failwith "cannot handle ELF_Class_32" + | ELF_Class_64 (segments, e_entry, e_machine) -> + (* remove all the auto generated segments (they contain only 0s) *) + let segments = + Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segments + in + (segments, e_entry, e_machine) + end + in - (* construct program memory and start address *) - begin - prog_mem := Mem.empty; - data_mem := Mem.empty; - tag_mem := Mem.empty; - load_memory_segments segments; - (* + (* construct program memory and start address *) + begin + prog_mem := Mem.empty; + data_mem := Mem.empty; + tag_mem := Mem.empty; + load_memory_segments segments; + (* debugf "prog_mem\n"; Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) !prog_mem; debugf "data_mem\n"; Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) !data_mem; *) - let (isa_defs, isa_memory_access, isa_externs, isa_model, model_reg_d, startaddr, - initial_stack_data, initial_register_abi_data, register_data_all) = - match Nat_big_num.to_int e_machine with - | 8 (* EM_MIPS *) -> - let startaddr = - let e_entry = Uint64_wrapper.of_bigint e_entry in - match Abi_mips64.abi_mips64_compute_program_entry_point segments e_entry with - | Error.Fail s -> failwith "Failed computing entry point" - | Error.Success s -> s - in - let (initial_stack_data, initial_register_abi_data) = - initial_stack_and_reg_data_of_MIPS_elf_file e_entry !data_mem in - - (Cheri.defs, - (Mips_extras.mips_read_memory_functions, - Mips_extras.mips_read_memory_tagged_functions, - Mips_extras.mips_memory_writes, - Mips_extras.mips_memory_eas, - Mips_extras.mips_memory_vals, - Mips_extras.mips_memory_vals_tagged, - Mips_extras.mips_barrier_functions), - [], - MIPS, - D_decreasing, - startaddr, - initial_stack_data, - initial_register_abi_data, - cheri_register_data_all) - - | _ -> failwith (Printf.sprintf "Sail sequential interpreter can't handle the e_machine value %s, only EM_PPC64, EM_AARCH64 and EM_MIPS are supported." (Nat_big_num.to_string e_machine)) - in - - (* pull the object symbols from the symbol table *) - let symbol_table : (string * Nat_big_num.num * int * word8 list (*their bytes*)) list = - let rec convert_symbol_table symbol_map = - begin match symbol_map with - | [] -> [] - | ((name: string), - ((typ: Nat_big_num.num), - (size: Nat_big_num.num (*number of bytes*)), - (address: Nat_big_num.num), - (mb: Byte_sequence.byte_sequence option (*present iff type=stt_object*)), - (binding: Nat_big_num.num))) - (* (mb: Byte_sequence_wrapper.t option (*present iff type=stt_object*)) )) *) - ::symbol_map' -> - if Nat_big_num.equal typ Elf_symbol_table.stt_object && not (Nat_big_num.equal size (Nat_big_num.of_int 0)) - then - ( - (* an object symbol - map *) - (*Printf.printf "*** size %d ***\n" (Nat_big_num.to_int size);*) - let bytes = - (match mb with - | None -> raise (Failure "this cannot happen") - | Some (Sequence bytes) -> - List.map (fun (c:char) -> Char.code c) bytes) in - (name, address, List.length bytes, bytes):: convert_symbol_table symbol_map' + let ( isa_defs, + isa_memory_access, + isa_externs, + isa_model, + model_reg_d, + startaddr, + initial_stack_data, + initial_register_abi_data, + register_data_all ) = + match Nat_big_num.to_int e_machine with + | 8 (* EM_MIPS *) -> + let startaddr = + let e_entry = Uint64_wrapper.of_bigint e_entry in + match Abi_mips64.abi_mips64_compute_program_entry_point segments e_entry with + | Error.Fail s -> failwith "Failed computing entry point" + | Error.Success s -> s + in + let initial_stack_data, initial_register_abi_data = + initial_stack_and_reg_data_of_MIPS_elf_file e_entry !data_mem + in + + ( Cheri.defs, + ( Mips_extras.mips_read_memory_functions, + Mips_extras.mips_read_memory_tagged_functions, + Mips_extras.mips_memory_writes, + Mips_extras.mips_memory_eas, + Mips_extras.mips_memory_vals, + Mips_extras.mips_memory_vals_tagged, + Mips_extras.mips_barrier_functions + ), + [], + MIPS, + D_decreasing, + startaddr, + initial_stack_data, + initial_register_abi_data, + cheri_register_data_all ) - else - (* not an object symbol or of zero size - ignore *) - convert_symbol_table symbol_map' - end + | _ -> + failwith + (Printf.sprintf + "Sail sequential interpreter can't handle the e_machine value %s, only EM_PPC64, EM_AARCH64 and \ + EM_MIPS are supported." + (Nat_big_num.to_string e_machine) + ) in - (List.map (fun (n,a,bs) -> (n,a,List.length bs,bs)) initial_stack_data) @ convert_symbol_table symbol_map - in - (* invert the symbol table to use for pp *) - let symbol_table_pp : ((Sail_impl_base.address * int) * string) list = - (* map symbol to (bindings, footprint), - if a symbol appears more then onece keep the one with higher - precedence (stb_global > stb_weak > stb_local) *) - let map = - List.fold_left - (fun map (name, (typ, size, address, mb, binding)) -> - if String.length name <> 0 && - (if String.length name = 1 then Char.code (String.get name 0) <> 0 else true) && - not (Nat_big_num.equal address (Nat_big_num.of_int 0)) - then - try - let (binding', _) = StringMap.find name map in - if Nat_big_num.equal binding' Elf_symbol_table.stb_local || - Nat_big_num.equal binding Elf_symbol_table.stb_global - then - StringMap.add name (binding, - (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) map - else map - with Not_found -> - StringMap.add name (binding, - (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) map - - else map - ) - StringMap.empty - symbol_map + (* pull the object symbols from the symbol table *) + let symbol_table : (string * Nat_big_num.num * int * word8 list (*their bytes*)) list = + let rec convert_symbol_table symbol_map = + begin + match symbol_map with + | [] -> [] + | ( (name : string), + ( (typ : Nat_big_num.num), + (size : Nat_big_num.num (*number of bytes*)), + (address : Nat_big_num.num), + (mb : Byte_sequence.byte_sequence option (*present iff type=stt_object*)), + (binding : Nat_big_num.num) + ) + ) (* (mb: Byte_sequence_wrapper.t option (*present iff type=stt_object*)) )) *) + :: symbol_map' -> + if + Nat_big_num.equal typ Elf_symbol_table.stt_object + && not (Nat_big_num.equal size (Nat_big_num.of_int 0)) + then ( + (* an object symbol - map *) + (*Printf.printf "*** size %d ***\n" (Nat_big_num.to_int size);*) + let bytes = + match mb with + | None -> raise (Failure "this cannot happen") + | Some (Sequence bytes) -> List.map (fun (c : char) -> Char.code c) bytes + in + (name, address, List.length bytes, bytes) :: convert_symbol_table symbol_map' + ) + else (* not an object symbol or of zero size - ignore *) + convert_symbol_table symbol_map' + end + in + List.map (fun (n, a, bs) -> (n, a, List.length bs, bs)) initial_stack_data @ convert_symbol_table symbol_map in - List.map (fun (name, (binding, fp)) -> (fp, name)) (StringMap.bindings map) - in + (* invert the symbol table to use for pp *) + let symbol_table_pp : ((Sail_impl_base.address * int) * string) list = + (* map symbol to (bindings, footprint), + if a symbol appears more then onece keep the one with higher + precedence (stb_global > stb_weak > stb_local) *) + let map = + List.fold_left + (fun map (name, (typ, size, address, mb, binding)) -> + if + String.length name <> 0 + && (if String.length name = 1 then Char.code (String.get name 0) <> 0 else true) + && not (Nat_big_num.equal address (Nat_big_num.of_int 0)) + then ( + try + let binding', _ = StringMap.find name map in + if + Nat_big_num.equal binding' Elf_symbol_table.stb_local + || Nat_big_num.equal binding Elf_symbol_table.stb_global + then + StringMap.add name + (binding, (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) + map + else map + with Not_found -> + StringMap.add name + (binding, (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) + map + ) + else map + ) + StringMap.empty symbol_map + in + List.map (fun (name, (binding, fp)) -> (fp, name)) (StringMap.bindings map) + in + let initial_register_state rbn = + try List.assoc rbn initial_register_abi_data with Not_found -> (register_state_zero register_data_all) rbn + in - let initial_register_state = - fun rbn -> - try - List.assoc rbn initial_register_abi_data - with - Not_found -> - (register_state_zero register_data_all) rbn - in + begin + initial_reg_file register_data_all initial_register_state; + + (* construct initial system state *) + let initial_system_state = + ( isa_defs, + isa_memory_access, + isa_externs, + isa_model, + model_reg_d, + startaddr, + Sail_impl_base.address_of_integer startaddr + ) + in - begin - (initial_reg_file register_data_all initial_register_state); - - (* construct initial system state *) - let initial_system_state = - (isa_defs, - isa_memory_access, - isa_externs, - isa_model, - model_reg_d, - startaddr, - (Sail_impl_base.address_of_integer startaddr)) - in - - (initial_system_state, symbol_table_pp) + (initial_system_state, symbol_table_pp) + end end - end let eager_eval = ref true let break_point = ref false @@ -541,22 +553,42 @@ let break_instr = ref 0 let max_cut_off = ref false let max_instr = ref 0 let raw_file = ref "" -let raw_at = ref 0 - -let args = [ - ("--file", Arg.Set_string file, "filename of elf binary to load in memory"); - ("--quiet", Arg.Clear Run_interp_model.interact_print, "do not display per-instruction actions"); - ("--silent", Arg.Tuple [Arg.Clear Run_interp_model.error_print; - Arg.Clear Run_interp_model.interact_print; - Arg.Clear Run_interp_model.result_print], - "do not dispaly error messages, per-instruction actions, or results"); - ("--no_result", Arg.Clear Run_interp_model.result_print, "do not display final register values"); - ("--interactive", Arg.Clear eager_eval , "interactive execution"); - ("--breakpoint", Arg.Int (fun i -> break_point := true; break_instr:= i), "run to instruction number i, then run interactively"); - ("--max_instruction", Arg.Int (fun i -> max_cut_off := true; max_instr := i), "only run i instructions, then stop"); - ("--raw", Arg.Set_string raw_file, "filename of raw file to load in memory"); - ("--at", Arg.Set_int raw_at, "address to load raw file in memory"); -] +let raw_at = ref 0 + +let args = + [ + ("--file", Arg.Set_string file, "filename of elf binary to load in memory"); + ("--quiet", Arg.Clear Run_interp_model.interact_print, "do not display per-instruction actions"); + ( "--silent", + Arg.Tuple + [ + Arg.Clear Run_interp_model.error_print; + Arg.Clear Run_interp_model.interact_print; + Arg.Clear Run_interp_model.result_print; + ], + "do not dispaly error messages, per-instruction actions, or results" + ); + ("--no_result", Arg.Clear Run_interp_model.result_print, "do not display final register values"); + ("--interactive", Arg.Clear eager_eval, "interactive execution"); + ( "--breakpoint", + Arg.Int + (fun i -> + break_point := true; + break_instr := i + ), + "run to instruction number i, then run interactively" + ); + ( "--max_instruction", + Arg.Int + (fun i -> + max_cut_off := true; + max_instr := i + ), + "only run i instructions, then stop" + ); + ("--raw", Arg.Set_string raw_file, "filename of raw file to load in memory"); + ("--at", Arg.Set_int raw_at, "address to load raw file in memory"); + ] let time_it action arg = let start_time = Sys.time () in @@ -566,305 +598,341 @@ let time_it action arg = (*TODO MIPS specific, should print final register values under all models*) let rec debug_print_gprs start stop = - resultf "DEBUG MIPS REG %.2d %s\n" start (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "GPR%02d" start) !reg)); - if start < stop - then debug_print_gprs (start + 1) stop - else () + resultf "DEBUG MIPS REG %.2d %s\n" start + (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "GPR%02d" start) !reg)); + if start < stop then debug_print_gprs (start + 1) stop else () let rec debug_print_capregs start stop = - resultf "DEBUG CAP REG %.2d %s\n" start (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "C%02d" start) !reg)); - if start < stop - then debug_print_capregs (start + 1) stop - else () + resultf "DEBUG CAP REG %.2d %s\n" start + (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "C%02d" start) !reg)); + if start < stop then debug_print_capregs (start + 1) stop else () let stop_condition_met model instr = match model with - | PPC -> - (match instr with - | ("Sc", [("Lev", _, arg)]) -> - Nat_big_num.equal (integer_of_bit_list arg) (Nat_big_num.of_int 32) - | _ -> false) - | AArch64 -> (match instr with - | ("ImplementationDefinedStopFetching", _) -> true - | _ -> false) - | MIPS -> (match instr with - | ("HCF", _) -> - resultf "DEBUG MIPS PC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PC" !reg)); - debug_print_gprs 0 31; - resultf "DEBUG CAP PCC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PCC" !reg)); - debug_print_capregs 0 31; - true - | _ -> false) - -let option_int_of_option_integer i = match i with - | Some i -> Some (Nat_big_num.to_int i) - | None -> None + | PPC -> ( + match instr with + | "Sc", [("Lev", _, arg)] -> Nat_big_num.equal (integer_of_bit_list arg) (Nat_big_num.of_int 32) + | _ -> false + ) + | AArch64 -> ( + match instr with "ImplementationDefinedStopFetching", _ -> true | _ -> false + ) + | MIPS -> ( + match instr with + | "HCF", _ -> + resultf "DEBUG MIPS PC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PC" !reg)); + debug_print_gprs 0 31; + resultf "DEBUG CAP PCC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PCC" !reg)); + debug_print_capregs 0 31; + true + | _ -> false + ) + +let option_int_of_option_integer i = match i with Some i -> Some (Nat_big_num.to_int i) | None -> None let add1 = Nat_big_num.add (Nat_big_num.of_int 1) let get_addr_trans_regs _ = (*resultf "PCC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PCC" !reg));*) - Some([ - (Sail_impl_base.Reg("PC", 63, 64, Sail_impl_base.D_decreasing), Reg.find "PC" !reg); - (Sail_impl_base.Reg("PCC", 256, 257, Sail_impl_base.D_decreasing), Reg.find "PCC" !reg); - (Sail_impl_base.Reg("C29", 256, 257, Sail_impl_base.D_decreasing), Reg.find "C29" !reg); - (Sail_impl_base.Reg("CP0Status", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Status" !reg); - (Sail_impl_base.Reg("CP0Cause", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Cause" !reg); - (Sail_impl_base.Reg("CP0Count", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Count" !reg); - (Sail_impl_base.Reg("CP0Compare", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Compare" !reg); - (Sail_impl_base.Reg("inBranchDelay", 0, 1, Sail_impl_base.D_decreasing), Reg.find "inBranchDelay" !reg); - (Sail_impl_base.Reg("TLBRandom", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBRandom" !reg); - (Sail_impl_base.Reg("TLBWired", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBWired" !reg); - (Sail_impl_base.Reg("TLBEntryHi", 63, 64, Sail_impl_base.D_decreasing), Reg.find "TLBEntryHi" !reg); - (Sail_impl_base.Reg("TLBEntry00", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry00" !reg); - (Sail_impl_base.Reg("TLBEntry01", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry01" !reg); - (Sail_impl_base.Reg("TLBEntry02", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry02" !reg); - (Sail_impl_base.Reg("TLBEntry03", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry03" !reg); - (Sail_impl_base.Reg("TLBEntry04", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry04" !reg); - (Sail_impl_base.Reg("TLBEntry05", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry05" !reg); - (Sail_impl_base.Reg("TLBEntry06", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry06" !reg); - (Sail_impl_base.Reg("TLBEntry07", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry07" !reg); - (Sail_impl_base.Reg("TLBEntry08", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry08" !reg); - (Sail_impl_base.Reg("TLBEntry09", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry09" !reg); - (Sail_impl_base.Reg("TLBEntry10", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry10" !reg); - (Sail_impl_base.Reg("TLBEntry11", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry11" !reg); - (Sail_impl_base.Reg("TLBEntry12", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry12" !reg); - (Sail_impl_base.Reg("TLBEntry13", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry13" !reg); - (Sail_impl_base.Reg("TLBEntry14", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry14" !reg); - (Sail_impl_base.Reg("TLBEntry15", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry15" !reg); - (Sail_impl_base.Reg("TLBEntry16", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry16" !reg); - (Sail_impl_base.Reg("TLBEntry17", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry17" !reg); - (Sail_impl_base.Reg("TLBEntry18", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry18" !reg); - (Sail_impl_base.Reg("TLBEntry19", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry19" !reg); - (Sail_impl_base.Reg("TLBEntry20", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry20" !reg); - (Sail_impl_base.Reg("TLBEntry21", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry21" !reg); - (Sail_impl_base.Reg("TLBEntry22", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry22" !reg); - (Sail_impl_base.Reg("TLBEntry23", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry23" !reg); - (Sail_impl_base.Reg("TLBEntry24", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry24" !reg); - (Sail_impl_base.Reg("TLBEntry25", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry25" !reg); - (Sail_impl_base.Reg("TLBEntry26", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry26" !reg); - (Sail_impl_base.Reg("TLBEntry27", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry27" !reg); - (Sail_impl_base.Reg("TLBEntry28", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry28" !reg); - (Sail_impl_base.Reg("TLBEntry29", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry29" !reg); - (Sail_impl_base.Reg("TLBEntry30", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry30" !reg); - (Sail_impl_base.Reg("TLBEntry31", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry31" !reg); - (Sail_impl_base.Reg("TLBEntry32", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry32" !reg); - (Sail_impl_base.Reg("TLBEntry33", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry33" !reg); - (Sail_impl_base.Reg("TLBEntry34", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry34" !reg); - (Sail_impl_base.Reg("TLBEntry35", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry35" !reg); - (Sail_impl_base.Reg("TLBEntry36", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry36" !reg); - (Sail_impl_base.Reg("TLBEntry37", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry37" !reg); - (Sail_impl_base.Reg("TLBEntry38", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry38" !reg); - (Sail_impl_base.Reg("TLBEntry39", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry39" !reg); - (Sail_impl_base.Reg("TLBEntry40", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry40" !reg); - (Sail_impl_base.Reg("TLBEntry41", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry41" !reg); - (Sail_impl_base.Reg("TLBEntry42", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry42" !reg); - (Sail_impl_base.Reg("TLBEntry43", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry43" !reg); - (Sail_impl_base.Reg("TLBEntry44", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry44" !reg); - (Sail_impl_base.Reg("TLBEntry45", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry45" !reg); - (Sail_impl_base.Reg("TLBEntry46", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry46" !reg); - (Sail_impl_base.Reg("TLBEntry47", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry47" !reg); - (Sail_impl_base.Reg("TLBEntry48", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry48" !reg); - (Sail_impl_base.Reg("TLBEntry49", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry49" !reg); - (Sail_impl_base.Reg("TLBEntry50", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry50" !reg); - (Sail_impl_base.Reg("TLBEntry51", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry51" !reg); - (Sail_impl_base.Reg("TLBEntry52", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry52" !reg); - (Sail_impl_base.Reg("TLBEntry53", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry53" !reg); - (Sail_impl_base.Reg("TLBEntry54", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry54" !reg); - (Sail_impl_base.Reg("TLBEntry55", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry55" !reg); - (Sail_impl_base.Reg("TLBEntry56", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry56" !reg); - (Sail_impl_base.Reg("TLBEntry57", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry57" !reg); - (Sail_impl_base.Reg("TLBEntry58", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry58" !reg); - (Sail_impl_base.Reg("TLBEntry59", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry59" !reg); - (Sail_impl_base.Reg("TLBEntry60", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry60" !reg); - (Sail_impl_base.Reg("TLBEntry61", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry61" !reg); - (Sail_impl_base.Reg("TLBEntry62", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry62" !reg); - (Sail_impl_base.Reg("TLBEntry63", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry63" !reg); - ]) + Some + [ + (Sail_impl_base.Reg ("PC", 63, 64, Sail_impl_base.D_decreasing), Reg.find "PC" !reg); + (Sail_impl_base.Reg ("PCC", 256, 257, Sail_impl_base.D_decreasing), Reg.find "PCC" !reg); + (Sail_impl_base.Reg ("C29", 256, 257, Sail_impl_base.D_decreasing), Reg.find "C29" !reg); + (Sail_impl_base.Reg ("CP0Status", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Status" !reg); + (Sail_impl_base.Reg ("CP0Cause", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Cause" !reg); + (Sail_impl_base.Reg ("CP0Count", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Count" !reg); + (Sail_impl_base.Reg ("CP0Compare", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Compare" !reg); + (Sail_impl_base.Reg ("inBranchDelay", 0, 1, Sail_impl_base.D_decreasing), Reg.find "inBranchDelay" !reg); + (Sail_impl_base.Reg ("TLBRandom", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBRandom" !reg); + (Sail_impl_base.Reg ("TLBWired", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBWired" !reg); + (Sail_impl_base.Reg ("TLBEntryHi", 63, 64, Sail_impl_base.D_decreasing), Reg.find "TLBEntryHi" !reg); + (Sail_impl_base.Reg ("TLBEntry00", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry00" !reg); + (Sail_impl_base.Reg ("TLBEntry01", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry01" !reg); + (Sail_impl_base.Reg ("TLBEntry02", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry02" !reg); + (Sail_impl_base.Reg ("TLBEntry03", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry03" !reg); + (Sail_impl_base.Reg ("TLBEntry04", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry04" !reg); + (Sail_impl_base.Reg ("TLBEntry05", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry05" !reg); + (Sail_impl_base.Reg ("TLBEntry06", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry06" !reg); + (Sail_impl_base.Reg ("TLBEntry07", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry07" !reg); + (Sail_impl_base.Reg ("TLBEntry08", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry08" !reg); + (Sail_impl_base.Reg ("TLBEntry09", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry09" !reg); + (Sail_impl_base.Reg ("TLBEntry10", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry10" !reg); + (Sail_impl_base.Reg ("TLBEntry11", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry11" !reg); + (Sail_impl_base.Reg ("TLBEntry12", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry12" !reg); + (Sail_impl_base.Reg ("TLBEntry13", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry13" !reg); + (Sail_impl_base.Reg ("TLBEntry14", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry14" !reg); + (Sail_impl_base.Reg ("TLBEntry15", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry15" !reg); + (Sail_impl_base.Reg ("TLBEntry16", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry16" !reg); + (Sail_impl_base.Reg ("TLBEntry17", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry17" !reg); + (Sail_impl_base.Reg ("TLBEntry18", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry18" !reg); + (Sail_impl_base.Reg ("TLBEntry19", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry19" !reg); + (Sail_impl_base.Reg ("TLBEntry20", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry20" !reg); + (Sail_impl_base.Reg ("TLBEntry21", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry21" !reg); + (Sail_impl_base.Reg ("TLBEntry22", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry22" !reg); + (Sail_impl_base.Reg ("TLBEntry23", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry23" !reg); + (Sail_impl_base.Reg ("TLBEntry24", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry24" !reg); + (Sail_impl_base.Reg ("TLBEntry25", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry25" !reg); + (Sail_impl_base.Reg ("TLBEntry26", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry26" !reg); + (Sail_impl_base.Reg ("TLBEntry27", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry27" !reg); + (Sail_impl_base.Reg ("TLBEntry28", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry28" !reg); + (Sail_impl_base.Reg ("TLBEntry29", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry29" !reg); + (Sail_impl_base.Reg ("TLBEntry30", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry30" !reg); + (Sail_impl_base.Reg ("TLBEntry31", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry31" !reg); + (Sail_impl_base.Reg ("TLBEntry32", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry32" !reg); + (Sail_impl_base.Reg ("TLBEntry33", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry33" !reg); + (Sail_impl_base.Reg ("TLBEntry34", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry34" !reg); + (Sail_impl_base.Reg ("TLBEntry35", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry35" !reg); + (Sail_impl_base.Reg ("TLBEntry36", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry36" !reg); + (Sail_impl_base.Reg ("TLBEntry37", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry37" !reg); + (Sail_impl_base.Reg ("TLBEntry38", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry38" !reg); + (Sail_impl_base.Reg ("TLBEntry39", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry39" !reg); + (Sail_impl_base.Reg ("TLBEntry40", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry40" !reg); + (Sail_impl_base.Reg ("TLBEntry41", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry41" !reg); + (Sail_impl_base.Reg ("TLBEntry42", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry42" !reg); + (Sail_impl_base.Reg ("TLBEntry43", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry43" !reg); + (Sail_impl_base.Reg ("TLBEntry44", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry44" !reg); + (Sail_impl_base.Reg ("TLBEntry45", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry45" !reg); + (Sail_impl_base.Reg ("TLBEntry46", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry46" !reg); + (Sail_impl_base.Reg ("TLBEntry47", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry47" !reg); + (Sail_impl_base.Reg ("TLBEntry48", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry48" !reg); + (Sail_impl_base.Reg ("TLBEntry49", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry49" !reg); + (Sail_impl_base.Reg ("TLBEntry50", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry50" !reg); + (Sail_impl_base.Reg ("TLBEntry51", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry51" !reg); + (Sail_impl_base.Reg ("TLBEntry52", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry52" !reg); + (Sail_impl_base.Reg ("TLBEntry53", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry53" !reg); + (Sail_impl_base.Reg ("TLBEntry54", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry54" !reg); + (Sail_impl_base.Reg ("TLBEntry55", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry55" !reg); + (Sail_impl_base.Reg ("TLBEntry56", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry56" !reg); + (Sail_impl_base.Reg ("TLBEntry57", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry57" !reg); + (Sail_impl_base.Reg ("TLBEntry58", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry58" !reg); + (Sail_impl_base.Reg ("TLBEntry59", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry59" !reg); + (Sail_impl_base.Reg ("TLBEntry60", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry60" !reg); + (Sail_impl_base.Reg ("TLBEntry61", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry61" !reg); + (Sail_impl_base.Reg ("TLBEntry62", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry62" !reg); + (Sail_impl_base.Reg ("TLBEntry63", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry63" !reg); + ] let get_opcode pc_a = - List.map (fun b -> match b with - | Some b -> b - | None -> failwith "A byte in opcode contained unknown or undef") + List.map + (fun b -> match b with Some b -> b | None -> failwith "A byte in opcode contained unknown or undef") (List.map byte_of_memory_byte - ([Mem.find pc_a !prog_mem; + [ + Mem.find pc_a !prog_mem; Mem.find (add1 pc_a) !prog_mem; Mem.find (add1 (add1 pc_a)) !prog_mem; - Mem.find (add1 (add1 (add1 pc_a))) !prog_mem])) + Mem.find (add1 (add1 (add1 pc_a))) !prog_mem; + ] + ) let rec write_events = function | [] -> () - | e::events -> - (match e with - | E_write_reg (Reg(id,_,_,_), value) -> reg := Reg.add id value !reg - | E_write_reg ((Reg_slice(id,_,_,range) as reg_n),value) - | E_write_reg ((Reg_field(id,_,_,_,range) as reg_n),value)-> - let old_val = Reg.find id !reg in - let new_val = fupdate_slice reg_n old_val value range in - reg := Reg.add id new_val !reg - | E_write_reg((Reg_f_slice(id,_,_,_,range,mini_range) as reg_n),value) -> - let old_val = Reg.find id !reg in - let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in - reg := Reg.add id new_val !reg - | _ -> failwith "Only register write events expected"); - write_events events - -let get_pc_address = function - | MIPS -> Reg.find "PC" !reg - | PPC -> Reg.find "CIA" !reg - | AArch64 -> Reg.find "_PC" !reg - -let option_int_of_reg str = - option_int_of_option_integer (integer_of_register_value (Reg.find str !reg)) + | e :: events -> + ( match e with + | E_write_reg (Reg (id, _, _, _), value) -> reg := Reg.add id value !reg + | E_write_reg ((Reg_slice (id, _, _, range) as reg_n), value) + | E_write_reg ((Reg_field (id, _, _, _, range) as reg_n), value) -> + let old_val = Reg.find id !reg in + let new_val = fupdate_slice reg_n old_val value range in + reg := Reg.add id new_val !reg + | E_write_reg ((Reg_f_slice (id, _, _, _, range, mini_range) as reg_n), value) -> + let old_val = Reg.find id !reg in + let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in + reg := Reg.add id new_val !reg + | _ -> failwith "Only register write events expected" + ); + write_events events + +let get_pc_address = function MIPS -> Reg.find "PC" !reg | PPC -> Reg.find "CIA" !reg | AArch64 -> Reg.find "_PC" !reg + +let option_int_of_reg str = option_int_of_option_integer (integer_of_register_value (Reg.find str !reg)) let rec fde_loop count context model mode track_dependencies addr_trans = - if !max_cut_off && count = !max_instr - then resultf "\nEnding evaluation due to reaching cut off point of %d instructions\n" count + if !max_cut_off && count = !max_instr then + resultf "\nEnding evaluation due to reaching cut off point of %d instructions\n" count else begin - if !break_point && count = !break_instr then begin break_point := false; eager_eval := false end; + if !break_point && count = !break_instr then begin + break_point := false; + eager_eval := false + end; let pc_regval = get_pc_address model in - interactf "\n**** instruction %d from address %s ****\n" - count (Printing_functions.register_value_to_string pc_regval); - let pc_addr = address_of_register_value pc_regval in - let pc_val = match pc_addr with - | Some v -> v - | None -> failwith "pc contains undef or unknown" in - let m_paddr_int = match addr_trans (get_addr_trans_regs ()) pc_val with - | Some a, Some events -> write_events (List.rev events); Some (integer_of_address a) - | Some a, None -> Some (integer_of_address a) - | None, Some events -> write_events (List.rev events); None - | None, None -> failwith "address translation failed and no writes" in + interactf "\n**** instruction %d from address %s ****\n" count + (Printing_functions.register_value_to_string pc_regval); + let pc_addr = address_of_register_value pc_regval in + let pc_val = match pc_addr with Some v -> v | None -> failwith "pc contains undef or unknown" in + let m_paddr_int = + match addr_trans (get_addr_trans_regs ()) pc_val with + | Some a, Some events -> + write_events (List.rev events); + Some (integer_of_address a) + | Some a, None -> Some (integer_of_address a) + | None, Some events -> + write_events (List.rev events); + None + | None, None -> failwith "address translation failed and no writes" + in match m_paddr_int with - | Some pc -> - let inBranchDelay = option_int_of_reg "inBranchDelay" in - (match inBranchDelay with - | Some 0 -> + | Some pc -> + let inBranchDelay = option_int_of_reg "inBranchDelay" in + ( match inBranchDelay with + | Some 0 -> let npc_addr = add_address_nat pc_val 4 in let npc_reg = register_value_of_address npc_addr Sail_impl_base.D_decreasing in reg := Reg.add "nextPC" npc_reg !reg; - reg := Reg.add "inCCallDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - | Some 1 -> + reg := + Reg.add "inCCallDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg + | Some 1 -> reg := Reg.add "nextPC" (Reg.find "delayedPC" !reg) !reg; - reg := Reg.add "nextPCC" (Reg.find "delayedPCC" !reg) !reg; - | _ -> failwith "invalid value of inBranchDelay"); - let opcode = Opcode (get_opcode pc) in - let (instruction,istate) = match Interp_inter_imp.decode_to_istate context None opcode with - | Instr(instruction,istate) -> - let instruction = interp_value_to_instr_external context instruction in + reg := Reg.add "nextPCC" (Reg.find "delayedPCC" !reg) !reg + | _ -> failwith "invalid value of inBranchDelay" + ); + let opcode = Opcode (get_opcode pc) in + let instruction, istate = + match Interp_inter_imp.decode_to_istate context None opcode with + | Instr (instruction, istate) -> + let instruction = interp_value_to_instr_external context instruction in interactf "\n**** Running: %s ****\n" (Printing_functions.instruction_to_string instruction); - (instruction,istate) - | Decode_error d -> - (match d with + (instruction, istate) + | Decode_error d -> + ( match d with | Interp_interface.Unsupported_instruction_error instruction -> - let instruction = interp_value_to_instr_external context instruction in - errorf "\n**** Encountered unsupported instruction %s ****\n" (Printing_functions.instruction_to_string instruction) - | Interp_interface.Not_an_instruction_error op -> - (match op with - | Opcode bytes -> - errorf "\n**** Encountered non-decodeable opcode: %s ****\n" (Printing_functions.byte_list_to_string bytes)) - | Internal_error s -> errorf "\n**** Internal error on decode: %s ****\n" s); exit 1 - in - if stop_condition_met model instruction - then resultf "\nSUCCESS program terminated after %d instructions\n" count - else - begin - match Run_interp_model.run istate !reg !prog_mem !tag_mem (Nat_big_num.of_int 32) !eager_eval track_dependencies mode "execute" with - | false, _,_, _ -> errorf "FAILURE\n"; exit 1 - | true, mode, track_dependencies, (my_reg, my_mem, my_tags) -> - reg := my_reg; - prog_mem := my_mem; - tag_mem := my_tags; - - (try - let (pending, _, _) = (Unix.select [(Unix.stdin)] [] [] 0.0) in - (if (pending != []) then - let char = (input_byte stdin) in ( + let instruction = interp_value_to_instr_external context instruction in + errorf "\n**** Encountered unsupported instruction %s ****\n" + (Printing_functions.instruction_to_string instruction) + | Interp_interface.Not_an_instruction_error op -> ( + match op with + | Opcode bytes -> + errorf "\n**** Encountered non-decodeable opcode: %s ****\n" + (Printing_functions.byte_list_to_string bytes) + ) + | Internal_error s -> errorf "\n**** Internal error on decode: %s ****\n" s + ); + exit 1 + in + if stop_condition_met model instruction then + resultf "\nSUCCESS program terminated after %d instructions\n" count + else begin + match + Run_interp_model.run istate !reg !prog_mem !tag_mem (Nat_big_num.of_int 32) !eager_eval track_dependencies + mode "execute" + with + | false, _, _, _ -> + errorf "FAILURE\n"; + exit 1 + | true, mode, track_dependencies, (my_reg, my_mem, my_tags) -> + reg := my_reg; + prog_mem := my_mem; + tag_mem := my_tags; + + ( try + let pending, _, _ = Unix.select [Unix.stdin] [] [] 0.0 in + if pending != [] then ( + let char = input_byte stdin in errorf "Input %x\n" char; - input_buf := (!input_buf) @ [char])); - with - | _ -> ()); - - let uart_rvalid = option_int_of_reg "UART_RVALID" in - (match uart_rvalid with - | Some 0 -> - (match !input_buf with - | x :: xs -> ( - reg := Reg.add "UART_RDATA" (register_value_of_integer 8 7 Sail_impl_base.D_decreasing (Nat_big_num.of_int x)) !reg; - reg := Reg.add "UART_RVALID" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) !reg; - input_buf := xs; - ) - | [] -> ()) - | _-> ()); - - let uart_written = option_int_of_reg "UART_WRITTEN" in - (match uart_written with - | Some 1 -> - (let uart_data = option_int_of_reg "UART_WDATA" in + input_buf := !input_buf @ [char] + ) + with _ -> () + ); + + let uart_rvalid = option_int_of_reg "UART_RVALID" in + ( match uart_rvalid with + | Some 0 -> ( + match !input_buf with + | x :: xs -> + reg := + Reg.add "UART_RDATA" + (register_value_of_integer 8 7 Sail_impl_base.D_decreasing (Nat_big_num.of_int x)) + !reg; + reg := + Reg.add "UART_RVALID" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) + !reg; + input_buf := xs + | [] -> () + ) + | _ -> () + ); + + let uart_written = option_int_of_reg "UART_WRITTEN" in + ( match uart_written with + | Some 1 -> ( + let uart_data = option_int_of_reg "UART_WDATA" in match uart_data with - | Some b -> (printf "%c" (Char.chr b); printf "%!") - | None -> (errorf "UART_WDATA was undef" ; exit 1)) - | _ -> ()); - reg := Reg.add "UART_WRITTEN" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - - reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; - fde_loop (count + 1) context model (Some mode) (ref track_dependencies) addr_trans - end - | None -> begin - reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; - fde_loop (count + 1) context model mode track_dependencies addr_trans - end + | Some b -> + printf "%c" (Char.chr b); + printf "%!" + | None -> + errorf "UART_WDATA was undef"; + exit 1 + ) + | _ -> () + ); + reg := + Reg.add "UART_WRITTEN" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + + reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; + reg := + Reg.add "branchPending" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) + !reg; + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; + fde_loop (count + 1) context model (Some mode) (ref track_dependencies) addr_trans + end + | None -> begin + reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; + reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; + fde_loop (count + 1) context model mode track_dependencies addr_trans + end end - + let rec load_raw_file' mem addr chan = let byte = input_byte chan in - (add_mem byte addr mem; - load_raw_file' mem (Nat_big_num.succ addr) chan) + add_mem byte addr mem; + load_raw_file' mem (Nat_big_num.succ addr) chan -let rec load_raw_file mem addr chan = - try - load_raw_file' mem addr chan - with - | End_of_file -> () +let rec load_raw_file mem addr chan = try load_raw_file' mem addr chan with End_of_file -> () let run () = - Arg.parse args (fun _ -> raise (Arg.Bad "anonymous parameter")) "" ; + Arg.parse args (fun _ -> raise (Arg.Bad "anonymous parameter")) ""; if !file = "" then begin Arg.usage args ""; - exit 1; + exit 1 end; if !break_point then eager_eval := true; - let ((isa_defs, - (isa_m0, isa_m1, isa_m2, isa_m3,isa_m4,isa_m5,isa_m6), - isa_externs, - isa_model, - model_reg_d, - startaddr, - startaddr_internal), pp_symbol_map) = initial_system_state_of_elf_file !file in + let ( ( isa_defs, + (isa_m0, isa_m1, isa_m2, isa_m3, isa_m4, isa_m5, isa_m6), + isa_externs, + isa_model, + model_reg_d, + startaddr, + startaddr_internal + ), + pp_symbol_map ) = + initial_system_state_of_elf_file !file + in let context = build_context false isa_defs isa_m0 isa_m1 isa_m2 isa_m3 isa_m4 isa_m5 isa_m6 None isa_externs in - (*NOTE: this is likely MIPS specific, so should probably pull from initial_system_state info on to translate or not, - endian mode, and translate function name + (*NOTE: this is likely MIPS specific, so should probably pull from initial_system_state info on to translate or not, + endian mode, and translate function name *) let addr_trans = translate_address context E_little_endian "TranslatePC" in - if String.length(!raw_file) != 0 then - load_raw_file prog_mem (Nat_big_num.of_int !raw_at) (open_in_bin !raw_file); - reg := Reg.add "PC" (register_value_of_address startaddr_internal model_reg_d ) !reg; + if String.length !raw_file != 0 then load_raw_file prog_mem (Nat_big_num.of_int !raw_at) (open_in_bin !raw_file); + reg := Reg.add "PC" (register_value_of_address startaddr_internal model_reg_d) !reg; (* entry point: unit -> unit fde *) let name = Filename.basename !file in - let t = time_it (fun () -> fde_loop 0 context isa_model (Some Run) (ref false) addr_trans) () in - resultf "Execution time for file %s: %f seconds\n" name t;; + let t = time_it (fun () -> fde_loop 0 context isa_model (Some Run) (ref false) addr_trans) () in + resultf "Execution time for file %s: %f seconds\n" name t +;; (* Turn off line-buffering of standard input to allow responsive console input *) -if Unix.isatty (Unix.stdin) then begin - let tattrs = Unix.tcgetattr (Unix.stdin) in - Unix.tcsetattr (Unix.stdin) (Unix.TCSANOW) ({tattrs with c_icanon=false}) -end ;; +if Unix.isatty Unix.stdin then begin + let tattrs = Unix.tcgetattr Unix.stdin in + Unix.tcsetattr Unix.stdin Unix.TCSANOW { tattrs with c_icanon = false } +end +;; -run () ;; +run () diff --git a/src/lem_interp/run_with_elf_cheri128.ml b/src/lem_interp/run_with_elf_cheri128.ml index 104417ffd..463567f27 100644 --- a/src/lem_interp/run_with_elf_cheri128.ml +++ b/src/lem_interp/run_with_elf_cheri128.ml @@ -65,37 +65,34 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Printf ;; -open Format ;; -open Big_int ;; -open Interp_ast ;; -open Interp_interface ;; -open Interp_inter_imp ;; -open Run_interp_model ;; -open Sail_impl_base ;; -open Sail_interface ;; - -module StringMap = Map.Make(String) - -let file = ref "" ;; - -let rec foldli f acc ?(i=0) = function - | [] -> acc - | x::xs -> foldli f (f i acc x) ~i:(i+1) xs -;; +open Printf +open Format +open Big_int +open Interp_ast +open Interp_interface +open Interp_inter_imp +open Run_interp_model +open Sail_impl_base +open Sail_interface + +module StringMap = Map.Make (String) + +let file = ref "" -let endian = ref E_big_endian ;; +let rec foldli f acc ?(i = 0) = function [] -> acc | x :: xs -> foldli f (f i acc x) ~i:(i + 1) xs -let hex_to_big_int s = big_int_of_int64 (Int64.of_string s) ;; +let endian = ref E_big_endian -let data_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref) ;; -let prog_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref) ;; -let tag_mem = (ref Mem.empty : (memory_byte Run_interp_model.Mem.t) ref);; -let reg = ref Reg.empty ;; -let input_buf = (ref [] : int list ref);; +let hex_to_big_int s = big_int_of_int64 (Int64.of_string s) + +let data_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let prog_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let tag_mem = (ref Mem.empty : memory_byte Run_interp_model.Mem.t ref) +let reg = ref Reg.empty +let input_buf = (ref [] : int list ref) let add_mem byte addr mem = - assert(byte >= 0 && byte < 256); + assert (byte >= 0 && byte < 256); (*Printf.printf "add_mem %s: 0x%02x\n" (Uint64.to_string_hex (Uint64.of_string (Nat_big_num.to_string addr))) byte;*) let mem_byte = memory_byte_of_int byte in let zero_byte = memory_byte_of_int 0 in @@ -103,435 +100,450 @@ let add_mem byte addr mem = tag_mem := Mem.add addr zero_byte !tag_mem let get_reg reg name = - let reg_content = Reg.find name reg in reg_content + let reg_content = Reg.find name reg in + reg_content -let rec load_memory_segment' (bytes,addr) mem = +let rec load_memory_segment' (bytes, addr) mem = match bytes with | [] -> () - | byte::bytes' -> - let data_byte = Char.code byte in - let addr' = Nat_big_num.succ addr in - begin add_mem data_byte addr mem; - load_memory_segment' (bytes',addr') mem - end - -let rec load_memory_segment (segment: Elf_interpreted_segment.elf64_interpreted_segment) mem = + | byte :: bytes' -> + let data_byte = Char.code byte in + let addr' = Nat_big_num.succ addr in + begin + add_mem data_byte addr mem; + load_memory_segment' (bytes', addr') mem + end + +let rec load_memory_segment (segment : Elf_interpreted_segment.elf64_interpreted_segment) mem = let (Byte_sequence.Sequence bytes) = segment.Elf_interpreted_segment.elf64_segment_body in let addr = segment.Elf_interpreted_segment.elf64_segment_paddr in - load_memory_segment' (bytes,addr) mem - + load_memory_segment' (bytes, addr) mem let rec load_memory_segments segments = - begin match segments with + begin + match segments with | [] -> () - | segment::segments' -> - let (x,w,r) = segment.Elf_interpreted_segment.elf64_segment_flags in - begin - load_memory_segment segment prog_mem; - load_memory_segments segments' - end + | segment :: segments' -> + let x, w, r = segment.Elf_interpreted_segment.elf64_segment_flags in + begin + load_memory_segment segment prog_mem; + load_memory_segments segments' + end end - -let rec read_mem mem address length = - if length = 0 - then [] - else - let byte = - try Mem.find address mem with - | Not_found -> failwith "start address not found" - in - byte :: (read_mem mem (Nat_big_num.succ address) (length - 1)) + +let rec read_mem mem address length = + if length = 0 then [] + else ( + let byte = try Mem.find address mem with Not_found -> failwith "start address not found" in + byte :: read_mem mem (Nat_big_num.succ address) (length - 1) + ) let register_state_zero register_data rbn : register_value = - let (dir,width,start_index) = - try List.assoc rbn register_data with - | Not_found -> failwith ("register_state_zero lookup failed (" ^ rbn) - in register_value_zeros dir width start_index + let dir, width, start_index = + try List.assoc rbn register_data with Not_found -> failwith ("register_state_zero lookup failed (" ^ rbn) + in + register_value_zeros dir width start_index type model = PPC | AArch64 | MIPS -let mips_register_data_all = [ - (*Pseudo registers*) - ("PC", (D_decreasing, 64, 63)); - ("branchPending", (D_decreasing, 1, 0)); - ("inBranchDelay", (D_decreasing, 1, 0)); - ("inCCallDelay", (D_decreasing, 1, 0)); - ("delayedPC", (D_decreasing, 64, 63)); - ("nextPC", (D_decreasing, 64, 63)); - (* General purpose registers *) - ("GPR00", (D_decreasing, 64, 63)); - ("GPR01", (D_decreasing, 64, 63)); - ("GPR02", (D_decreasing, 64, 63)); - ("GPR03", (D_decreasing, 64, 63)); - ("GPR04", (D_decreasing, 64, 63)); - ("GPR05", (D_decreasing, 64, 63)); - ("GPR06", (D_decreasing, 64, 63)); - ("GPR07", (D_decreasing, 64, 63)); - ("GPR08", (D_decreasing, 64, 63)); - ("GPR09", (D_decreasing, 64, 63)); - ("GPR10", (D_decreasing, 64, 63)); - ("GPR11", (D_decreasing, 64, 63)); - ("GPR12", (D_decreasing, 64, 63)); - ("GPR13", (D_decreasing, 64, 63)); - ("GPR14", (D_decreasing, 64, 63)); - ("GPR15", (D_decreasing, 64, 63)); - ("GPR16", (D_decreasing, 64, 63)); - ("GPR17", (D_decreasing, 64, 63)); - ("GPR18", (D_decreasing, 64, 63)); - ("GPR19", (D_decreasing, 64, 63)); - ("GPR20", (D_decreasing, 64, 63)); - ("GPR21", (D_decreasing, 64, 63)); - ("GPR22", (D_decreasing, 64, 63)); - ("GPR23", (D_decreasing, 64, 63)); - ("GPR24", (D_decreasing, 64, 63)); - ("GPR25", (D_decreasing, 64, 63)); - ("GPR26", (D_decreasing, 64, 63)); - ("GPR27", (D_decreasing, 64, 63)); - ("GPR28", (D_decreasing, 64, 63)); - ("GPR29", (D_decreasing, 64, 63)); - ("GPR30", (D_decreasing, 64, 63)); - ("GPR31", (D_decreasing, 64, 63)); - (* special registers for mul/div *) - ("HI", (D_decreasing, 64, 63)); - ("LO", (D_decreasing, 64, 63)); - (* control registers *) - ("CP0Status", (D_decreasing, 32, 31)); - ("CP0Cause", (D_decreasing, 32, 31)); - ("CP0EPC", (D_decreasing, 64, 63)); - ("CP0LLAddr", (D_decreasing, 64, 63)); - ("CP0LLBit", (D_decreasing, 1, 0)); - ("CP0Count", (D_decreasing, 32, 31)); - ("CP0Compare", (D_decreasing, 32, 31)); - ("CP0HWREna", (D_decreasing, 32, 31)); - ("CP0UserLocal", (D_decreasing, 64, 63)); - ("CP0BadVAddr", (D_decreasing, 64, 63)); - ("TLBProbe" ,(D_decreasing, 1, 0)); - ("TLBIndex" ,(D_decreasing, 6, 5)); - ("TLBRandom" ,(D_decreasing, 6, 5)); - ("TLBEntryLo0",(D_decreasing, 64, 63)); - ("TLBEntryLo1",(D_decreasing, 64, 63)); - ("TLBContext" ,(D_decreasing, 64, 63)); - ("TLBPageMask",(D_decreasing, 16, 15)); - ("TLBWired" ,(D_decreasing, 6, 5)); - ("TLBEntryHi" ,(D_decreasing, 64, 63)); - ("TLBXContext",(D_decreasing, 64, 63)); - - ("TLBEntry00" ,(D_decreasing, 117, 116)); - ("TLBEntry01" ,(D_decreasing, 117, 116)); - ("TLBEntry02" ,(D_decreasing, 117, 116)); - ("TLBEntry03" ,(D_decreasing, 117, 116)); - ("TLBEntry04" ,(D_decreasing, 117, 116)); - ("TLBEntry05" ,(D_decreasing, 117, 116)); - ("TLBEntry06" ,(D_decreasing, 117, 116)); - ("TLBEntry07" ,(D_decreasing, 117, 116)); - ("TLBEntry08" ,(D_decreasing, 117, 116)); - ("TLBEntry09" ,(D_decreasing, 117, 116)); - ("TLBEntry10" ,(D_decreasing, 117, 116)); - ("TLBEntry11" ,(D_decreasing, 117, 116)); - ("TLBEntry12" ,(D_decreasing, 117, 116)); - ("TLBEntry13" ,(D_decreasing, 117, 116)); - ("TLBEntry14" ,(D_decreasing, 117, 116)); - ("TLBEntry15" ,(D_decreasing, 117, 116)); - ("TLBEntry16" ,(D_decreasing, 117, 116)); - ("TLBEntry17" ,(D_decreasing, 117, 116)); - ("TLBEntry18" ,(D_decreasing, 117, 116)); - ("TLBEntry19" ,(D_decreasing, 117, 116)); - ("TLBEntry20" ,(D_decreasing, 117, 116)); - ("TLBEntry21" ,(D_decreasing, 117, 116)); - ("TLBEntry22" ,(D_decreasing, 117, 116)); - ("TLBEntry23" ,(D_decreasing, 117, 116)); - ("TLBEntry24" ,(D_decreasing, 117, 116)); - ("TLBEntry25" ,(D_decreasing, 117, 116)); - ("TLBEntry26" ,(D_decreasing, 117, 116)); - ("TLBEntry27" ,(D_decreasing, 117, 116)); - ("TLBEntry28" ,(D_decreasing, 117, 116)); - ("TLBEntry29" ,(D_decreasing, 117, 116)); - ("TLBEntry30" ,(D_decreasing, 117, 116)); - ("TLBEntry31" ,(D_decreasing, 117, 116)); - ("TLBEntry32" ,(D_decreasing, 117, 116)); - ("TLBEntry33" ,(D_decreasing, 117, 116)); - ("TLBEntry34" ,(D_decreasing, 117, 116)); - ("TLBEntry35" ,(D_decreasing, 117, 116)); - ("TLBEntry36" ,(D_decreasing, 117, 116)); - ("TLBEntry37" ,(D_decreasing, 117, 116)); - ("TLBEntry38" ,(D_decreasing, 117, 116)); - ("TLBEntry39" ,(D_decreasing, 117, 116)); - ("TLBEntry40" ,(D_decreasing, 117, 116)); - ("TLBEntry41" ,(D_decreasing, 117, 116)); - ("TLBEntry42" ,(D_decreasing, 117, 116)); - ("TLBEntry43" ,(D_decreasing, 117, 116)); - ("TLBEntry44" ,(D_decreasing, 117, 116)); - ("TLBEntry45" ,(D_decreasing, 117, 116)); - ("TLBEntry46" ,(D_decreasing, 117, 116)); - ("TLBEntry47" ,(D_decreasing, 117, 116)); - ("TLBEntry48" ,(D_decreasing, 117, 116)); - ("TLBEntry49" ,(D_decreasing, 117, 116)); - ("TLBEntry50" ,(D_decreasing, 117, 116)); - ("TLBEntry51" ,(D_decreasing, 117, 116)); - ("TLBEntry52" ,(D_decreasing, 117, 116)); - ("TLBEntry53" ,(D_decreasing, 117, 116)); - ("TLBEntry54" ,(D_decreasing, 117, 116)); - ("TLBEntry55" ,(D_decreasing, 117, 116)); - ("TLBEntry56" ,(D_decreasing, 117, 116)); - ("TLBEntry57" ,(D_decreasing, 117, 116)); - ("TLBEntry58" ,(D_decreasing, 117, 116)); - ("TLBEntry59" ,(D_decreasing, 117, 116)); - ("TLBEntry60" ,(D_decreasing, 117, 116)); - ("TLBEntry61" ,(D_decreasing, 117, 116)); - ("TLBEntry62" ,(D_decreasing, 117, 116)); - ("TLBEntry63" ,(D_decreasing, 117, 116)); - - ("UART_WDATA" ,(D_decreasing, 8, 7)); - ("UART_RDATA" ,(D_decreasing, 8, 7)); - ("UART_WRITTEN" ,(D_decreasing, 1, 0)); - ("UART_RVALID" ,(D_decreasing, 1, 0)); -] - -let cheri_register_data_all = mips_register_data_all @ [ - ("CapCause", (D_decreasing, 16, 15)); - ("PCC", (D_decreasing, 129, 128)); - ("nextPCC", (D_decreasing, 129, 128)); - ("delayedPCC", (D_decreasing, 129, 128)); - ("C00", (D_decreasing, 129, 128)); - ("C01", (D_decreasing, 129, 128)); - ("C02", (D_decreasing, 129, 128)); - ("C03", (D_decreasing, 129, 128)); - ("C04", (D_decreasing, 129, 128)); - ("C05", (D_decreasing, 129, 128)); - ("C06", (D_decreasing, 129, 128)); - ("C07", (D_decreasing, 129, 128)); - ("C08", (D_decreasing, 129, 128)); - ("C09", (D_decreasing, 129, 128)); - ("C10", (D_decreasing, 129, 128)); - ("C11", (D_decreasing, 129, 128)); - ("C12", (D_decreasing, 129, 128)); - ("C13", (D_decreasing, 129, 128)); - ("C14", (D_decreasing, 129, 128)); - ("C15", (D_decreasing, 129, 128)); - ("C16", (D_decreasing, 129, 128)); - ("C17", (D_decreasing, 129, 128)); - ("C18", (D_decreasing, 129, 128)); - ("C19", (D_decreasing, 129, 128)); - ("C20", (D_decreasing, 129, 128)); - ("C21", (D_decreasing, 129, 128)); - ("C22", (D_decreasing, 129, 128)); - ("C23", (D_decreasing, 129, 128)); - ("C24", (D_decreasing, 129, 128)); - ("C25", (D_decreasing, 129, 128)); - ("C26", (D_decreasing, 129, 128)); - ("C27", (D_decreasing, 129, 128)); - ("C28", (D_decreasing, 129, 128)); - ("C29", (D_decreasing, 129, 128)); - ("C30", (D_decreasing, 129, 128)); - ("C31", (D_decreasing, 129, 128)); -] +let mips_register_data_all = + [ + (*Pseudo registers*) + ("PC", (D_decreasing, 64, 63)); + ("branchPending", (D_decreasing, 1, 0)); + ("inBranchDelay", (D_decreasing, 1, 0)); + ("inCCallDelay", (D_decreasing, 1, 0)); + ("delayedPC", (D_decreasing, 64, 63)); + ("nextPC", (D_decreasing, 64, 63)); + (* General purpose registers *) + ("GPR00", (D_decreasing, 64, 63)); + ("GPR01", (D_decreasing, 64, 63)); + ("GPR02", (D_decreasing, 64, 63)); + ("GPR03", (D_decreasing, 64, 63)); + ("GPR04", (D_decreasing, 64, 63)); + ("GPR05", (D_decreasing, 64, 63)); + ("GPR06", (D_decreasing, 64, 63)); + ("GPR07", (D_decreasing, 64, 63)); + ("GPR08", (D_decreasing, 64, 63)); + ("GPR09", (D_decreasing, 64, 63)); + ("GPR10", (D_decreasing, 64, 63)); + ("GPR11", (D_decreasing, 64, 63)); + ("GPR12", (D_decreasing, 64, 63)); + ("GPR13", (D_decreasing, 64, 63)); + ("GPR14", (D_decreasing, 64, 63)); + ("GPR15", (D_decreasing, 64, 63)); + ("GPR16", (D_decreasing, 64, 63)); + ("GPR17", (D_decreasing, 64, 63)); + ("GPR18", (D_decreasing, 64, 63)); + ("GPR19", (D_decreasing, 64, 63)); + ("GPR20", (D_decreasing, 64, 63)); + ("GPR21", (D_decreasing, 64, 63)); + ("GPR22", (D_decreasing, 64, 63)); + ("GPR23", (D_decreasing, 64, 63)); + ("GPR24", (D_decreasing, 64, 63)); + ("GPR25", (D_decreasing, 64, 63)); + ("GPR26", (D_decreasing, 64, 63)); + ("GPR27", (D_decreasing, 64, 63)); + ("GPR28", (D_decreasing, 64, 63)); + ("GPR29", (D_decreasing, 64, 63)); + ("GPR30", (D_decreasing, 64, 63)); + ("GPR31", (D_decreasing, 64, 63)); + (* special registers for mul/div *) + ("HI", (D_decreasing, 64, 63)); + ("LO", (D_decreasing, 64, 63)); + (* control registers *) + ("CP0Status", (D_decreasing, 32, 31)); + ("CP0Cause", (D_decreasing, 32, 31)); + ("CP0EPC", (D_decreasing, 64, 63)); + ("CP0LLAddr", (D_decreasing, 64, 63)); + ("CP0LLBit", (D_decreasing, 1, 0)); + ("CP0Count", (D_decreasing, 32, 31)); + ("CP0Compare", (D_decreasing, 32, 31)); + ("CP0HWREna", (D_decreasing, 32, 31)); + ("CP0UserLocal", (D_decreasing, 64, 63)); + ("CP0BadVAddr", (D_decreasing, 64, 63)); + ("TLBProbe", (D_decreasing, 1, 0)); + ("TLBIndex", (D_decreasing, 6, 5)); + ("TLBRandom", (D_decreasing, 6, 5)); + ("TLBEntryLo0", (D_decreasing, 64, 63)); + ("TLBEntryLo1", (D_decreasing, 64, 63)); + ("TLBContext", (D_decreasing, 64, 63)); + ("TLBPageMask", (D_decreasing, 16, 15)); + ("TLBWired", (D_decreasing, 6, 5)); + ("TLBEntryHi", (D_decreasing, 64, 63)); + ("TLBXContext", (D_decreasing, 64, 63)); + ("TLBEntry00", (D_decreasing, 117, 116)); + ("TLBEntry01", (D_decreasing, 117, 116)); + ("TLBEntry02", (D_decreasing, 117, 116)); + ("TLBEntry03", (D_decreasing, 117, 116)); + ("TLBEntry04", (D_decreasing, 117, 116)); + ("TLBEntry05", (D_decreasing, 117, 116)); + ("TLBEntry06", (D_decreasing, 117, 116)); + ("TLBEntry07", (D_decreasing, 117, 116)); + ("TLBEntry08", (D_decreasing, 117, 116)); + ("TLBEntry09", (D_decreasing, 117, 116)); + ("TLBEntry10", (D_decreasing, 117, 116)); + ("TLBEntry11", (D_decreasing, 117, 116)); + ("TLBEntry12", (D_decreasing, 117, 116)); + ("TLBEntry13", (D_decreasing, 117, 116)); + ("TLBEntry14", (D_decreasing, 117, 116)); + ("TLBEntry15", (D_decreasing, 117, 116)); + ("TLBEntry16", (D_decreasing, 117, 116)); + ("TLBEntry17", (D_decreasing, 117, 116)); + ("TLBEntry18", (D_decreasing, 117, 116)); + ("TLBEntry19", (D_decreasing, 117, 116)); + ("TLBEntry20", (D_decreasing, 117, 116)); + ("TLBEntry21", (D_decreasing, 117, 116)); + ("TLBEntry22", (D_decreasing, 117, 116)); + ("TLBEntry23", (D_decreasing, 117, 116)); + ("TLBEntry24", (D_decreasing, 117, 116)); + ("TLBEntry25", (D_decreasing, 117, 116)); + ("TLBEntry26", (D_decreasing, 117, 116)); + ("TLBEntry27", (D_decreasing, 117, 116)); + ("TLBEntry28", (D_decreasing, 117, 116)); + ("TLBEntry29", (D_decreasing, 117, 116)); + ("TLBEntry30", (D_decreasing, 117, 116)); + ("TLBEntry31", (D_decreasing, 117, 116)); + ("TLBEntry32", (D_decreasing, 117, 116)); + ("TLBEntry33", (D_decreasing, 117, 116)); + ("TLBEntry34", (D_decreasing, 117, 116)); + ("TLBEntry35", (D_decreasing, 117, 116)); + ("TLBEntry36", (D_decreasing, 117, 116)); + ("TLBEntry37", (D_decreasing, 117, 116)); + ("TLBEntry38", (D_decreasing, 117, 116)); + ("TLBEntry39", (D_decreasing, 117, 116)); + ("TLBEntry40", (D_decreasing, 117, 116)); + ("TLBEntry41", (D_decreasing, 117, 116)); + ("TLBEntry42", (D_decreasing, 117, 116)); + ("TLBEntry43", (D_decreasing, 117, 116)); + ("TLBEntry44", (D_decreasing, 117, 116)); + ("TLBEntry45", (D_decreasing, 117, 116)); + ("TLBEntry46", (D_decreasing, 117, 116)); + ("TLBEntry47", (D_decreasing, 117, 116)); + ("TLBEntry48", (D_decreasing, 117, 116)); + ("TLBEntry49", (D_decreasing, 117, 116)); + ("TLBEntry50", (D_decreasing, 117, 116)); + ("TLBEntry51", (D_decreasing, 117, 116)); + ("TLBEntry52", (D_decreasing, 117, 116)); + ("TLBEntry53", (D_decreasing, 117, 116)); + ("TLBEntry54", (D_decreasing, 117, 116)); + ("TLBEntry55", (D_decreasing, 117, 116)); + ("TLBEntry56", (D_decreasing, 117, 116)); + ("TLBEntry57", (D_decreasing, 117, 116)); + ("TLBEntry58", (D_decreasing, 117, 116)); + ("TLBEntry59", (D_decreasing, 117, 116)); + ("TLBEntry60", (D_decreasing, 117, 116)); + ("TLBEntry61", (D_decreasing, 117, 116)); + ("TLBEntry62", (D_decreasing, 117, 116)); + ("TLBEntry63", (D_decreasing, 117, 116)); + ("UART_WDATA", (D_decreasing, 8, 7)); + ("UART_RDATA", (D_decreasing, 8, 7)); + ("UART_WRITTEN", (D_decreasing, 1, 0)); + ("UART_RVALID", (D_decreasing, 1, 0)); + ] + +let cheri_register_data_all = + mips_register_data_all + @ [ + ("CapCause", (D_decreasing, 16, 15)); + ("PCC", (D_decreasing, 129, 128)); + ("nextPCC", (D_decreasing, 129, 128)); + ("delayedPCC", (D_decreasing, 129, 128)); + ("C00", (D_decreasing, 129, 128)); + ("C01", (D_decreasing, 129, 128)); + ("C02", (D_decreasing, 129, 128)); + ("C03", (D_decreasing, 129, 128)); + ("C04", (D_decreasing, 129, 128)); + ("C05", (D_decreasing, 129, 128)); + ("C06", (D_decreasing, 129, 128)); + ("C07", (D_decreasing, 129, 128)); + ("C08", (D_decreasing, 129, 128)); + ("C09", (D_decreasing, 129, 128)); + ("C10", (D_decreasing, 129, 128)); + ("C11", (D_decreasing, 129, 128)); + ("C12", (D_decreasing, 129, 128)); + ("C13", (D_decreasing, 129, 128)); + ("C14", (D_decreasing, 129, 128)); + ("C15", (D_decreasing, 129, 128)); + ("C16", (D_decreasing, 129, 128)); + ("C17", (D_decreasing, 129, 128)); + ("C18", (D_decreasing, 129, 128)); + ("C19", (D_decreasing, 129, 128)); + ("C20", (D_decreasing, 129, 128)); + ("C21", (D_decreasing, 129, 128)); + ("C22", (D_decreasing, 129, 128)); + ("C23", (D_decreasing, 129, 128)); + ("C24", (D_decreasing, 129, 128)); + ("C25", (D_decreasing, 129, 128)); + ("C26", (D_decreasing, 129, 128)); + ("C27", (D_decreasing, 129, 128)); + ("C28", (D_decreasing, 129, 128)); + ("C29", (D_decreasing, 129, 128)); + ("C30", (D_decreasing, 129, 128)); + ("C31", (D_decreasing, 129, 128)); + ] let initial_stack_and_reg_data_of_MIPS_elf_file e_entry all_data_memory = - let initial_stack_data = [] in - let initial_cap_val_int = Nat_big_num.of_string "0x1fffe6000000100000000000000000000" in (* hex((0x10000 << 64) + (48 << 105) + (0x7fff << 113) + (1 << 128)) T=0x10000 E=48 perms=0x7fff tag=1 *) - let initial_cap_val_reg = Sail_impl_base.register_value_of_integer 129 128 D_decreasing initial_cap_val_int in - let initial_register_abi_data : (string * Sail_impl_base.register_value) list = [ - ("CP0Status", Sail_impl_base.register_value_of_integer 32 31 D_decreasing (Nat_big_num.of_string "0x00400000")); - ("PCC", initial_cap_val_reg); - ("nextPCC", initial_cap_val_reg); - ("delayedPCC", initial_cap_val_reg); - ("C00", initial_cap_val_reg); - ("C01", initial_cap_val_reg); - ("C02", initial_cap_val_reg); - ("C03", initial_cap_val_reg); - ("C04", initial_cap_val_reg); - ("C05", initial_cap_val_reg); - ("C06", initial_cap_val_reg); - ("C07", initial_cap_val_reg); - ("C08", initial_cap_val_reg); - ("C09", initial_cap_val_reg); - ("C10", initial_cap_val_reg); - ("C11", initial_cap_val_reg); - ("C12", initial_cap_val_reg); - ("C13", initial_cap_val_reg); - ("C14", initial_cap_val_reg); - ("C15", initial_cap_val_reg); - ("C16", initial_cap_val_reg); - ("C17", initial_cap_val_reg); - ("C18", initial_cap_val_reg); - ("C19", initial_cap_val_reg); - ("C20", initial_cap_val_reg); - ("C21", initial_cap_val_reg); - ("C22", initial_cap_val_reg); - ("C23", initial_cap_val_reg); - ("C24", initial_cap_val_reg); - ("C25", initial_cap_val_reg); - ("C26", initial_cap_val_reg); - ("C27", initial_cap_val_reg); - ("C28", initial_cap_val_reg); - ("C29", initial_cap_val_reg); - ("C30", initial_cap_val_reg); - ("C31", initial_cap_val_reg); - ] in + let initial_stack_data = [] in + let initial_cap_val_int = Nat_big_num.of_string "0x1fffe6000000100000000000000000000" in + (* hex((0x10000 << 64) + (48 << 105) + (0x7fff << 113) + (1 << 128)) T=0x10000 E=48 perms=0x7fff tag=1 *) + let initial_cap_val_reg = Sail_impl_base.register_value_of_integer 129 128 D_decreasing initial_cap_val_int in + let initial_register_abi_data : (string * Sail_impl_base.register_value) list = + [ + ("CP0Status", Sail_impl_base.register_value_of_integer 32 31 D_decreasing (Nat_big_num.of_string "0x00400000")); + ("PCC", initial_cap_val_reg); + ("nextPCC", initial_cap_val_reg); + ("delayedPCC", initial_cap_val_reg); + ("C00", initial_cap_val_reg); + ("C01", initial_cap_val_reg); + ("C02", initial_cap_val_reg); + ("C03", initial_cap_val_reg); + ("C04", initial_cap_val_reg); + ("C05", initial_cap_val_reg); + ("C06", initial_cap_val_reg); + ("C07", initial_cap_val_reg); + ("C08", initial_cap_val_reg); + ("C09", initial_cap_val_reg); + ("C10", initial_cap_val_reg); + ("C11", initial_cap_val_reg); + ("C12", initial_cap_val_reg); + ("C13", initial_cap_val_reg); + ("C14", initial_cap_val_reg); + ("C15", initial_cap_val_reg); + ("C16", initial_cap_val_reg); + ("C17", initial_cap_val_reg); + ("C18", initial_cap_val_reg); + ("C19", initial_cap_val_reg); + ("C20", initial_cap_val_reg); + ("C21", initial_cap_val_reg); + ("C22", initial_cap_val_reg); + ("C23", initial_cap_val_reg); + ("C24", initial_cap_val_reg); + ("C25", initial_cap_val_reg); + ("C26", initial_cap_val_reg); + ("C27", initial_cap_val_reg); + ("C28", initial_cap_val_reg); + ("C29", initial_cap_val_reg); + ("C30", initial_cap_val_reg); + ("C31", initial_cap_val_reg); + ] + in (initial_stack_data, initial_register_abi_data) let initial_reg_file reg_data init = List.iter (fun (reg_name, _) -> reg := Reg.add reg_name (init reg_name) !reg) reg_data -let initial_system_state_of_elf_file name = - +let initial_system_state_of_elf_file name = (* call ELF analyser on file *) match Sail_interface.populate_and_obtain_global_symbol_init_info name with | Error.Fail s -> failwith ("populate_and_obtain_global_symbol_init_info: " ^ s) - | Error.Success - (_, (elf_epi: Sail_interface.executable_process_image), - (symbol_map: Elf_file.global_symbol_init_info)) - -> - let (segments, e_entry, e_machine) = - begin match elf_epi with - | ELF_Class_32 _ -> failwith "cannot handle ELF_Class_32" - | ELF_Class_64 (segments,e_entry,e_machine) -> - (* remove all the auto generated segments (they contain only 0s) *) - let segments = - Lem_list.mapMaybe - (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) - segments - in - (segments,e_entry,e_machine) - end - in + | Error.Success + (_, (elf_epi : Sail_interface.executable_process_image), (symbol_map : Elf_file.global_symbol_init_info)) -> + let segments, e_entry, e_machine = + begin + match elf_epi with + | ELF_Class_32 _ -> failwith "cannot handle ELF_Class_32" + | ELF_Class_64 (segments, e_entry, e_machine) -> + (* remove all the auto generated segments (they contain only 0s) *) + let segments = + Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segments + in + (segments, e_entry, e_machine) + end + in - (* construct program memory and start address *) - begin - prog_mem := Mem.empty; - data_mem := Mem.empty; - tag_mem := Mem.empty; - load_memory_segments segments; - (* + (* construct program memory and start address *) + begin + prog_mem := Mem.empty; + data_mem := Mem.empty; + tag_mem := Mem.empty; + load_memory_segments segments; + (* debugf "prog_mem\n"; Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) !prog_mem; debugf "data_mem\n"; Mem.iter (fun k v -> debugf "%s\n" (Mem.to_string k v)) !data_mem; *) - let (isa_defs, isa_memory_access, isa_externs, isa_model, model_reg_d, startaddr, - initial_stack_data, initial_register_abi_data, register_data_all) = - match Nat_big_num.to_int e_machine with - | 8 (* EM_MIPS *) -> - let startaddr = - let e_entry = Uint64_wrapper.of_bigint e_entry in - match Abi_mips64.abi_mips64_compute_program_entry_point segments e_entry with - | Error.Fail s -> failwith "Failed computing entry point" - | Error.Success s -> s - in - let (initial_stack_data, initial_register_abi_data) = - initial_stack_and_reg_data_of_MIPS_elf_file e_entry !data_mem in - - (Cheri128.defs, - (Mips_extras.mips_read_memory_functions, - Mips_extras.mips_read_memory_tagged_functions, - Mips_extras.mips_memory_writes, - Mips_extras.mips_memory_eas, - Mips_extras.mips_memory_vals, - Mips_extras.mips_memory_vals_tagged, - Mips_extras.mips_barrier_functions), - [], - MIPS, - D_decreasing, - startaddr, - initial_stack_data, - initial_register_abi_data, - cheri_register_data_all) - - | _ -> failwith (Printf.sprintf "Sail sequential interpreter can't handle the e_machine value %s, only EM_PPC64, EM_AARCH64 and EM_MIPS are supported." (Nat_big_num.to_string e_machine)) - in - - (* pull the object symbols from the symbol table *) - let symbol_table : (string * Nat_big_num.num * int * word8 list (*their bytes*)) list = - let rec convert_symbol_table symbol_map = - begin match symbol_map with - | [] -> [] - | ((name: string), - ((typ: Nat_big_num.num), - (size: Nat_big_num.num (*number of bytes*)), - (address: Nat_big_num.num), - (mb: Byte_sequence.byte_sequence option (*present iff type=stt_object*)), - (binding: Nat_big_num.num))) - (* (mb: Byte_sequence_wrapper.t option (*present iff type=stt_object*)) )) *) - ::symbol_map' -> - if Nat_big_num.equal typ Elf_symbol_table.stt_object && not (Nat_big_num.equal size (Nat_big_num.of_int 0)) - then - ( - (* an object symbol - map *) - (*Printf.printf "*** size %d ***\n" (Nat_big_num.to_int size);*) - let bytes = - (match mb with - | None -> raise (Failure "this cannot happen") - | Some (Sequence bytes) -> - List.map (fun (c:char) -> Char.code c) bytes) in - (name, address, List.length bytes, bytes):: convert_symbol_table symbol_map' + let ( isa_defs, + isa_memory_access, + isa_externs, + isa_model, + model_reg_d, + startaddr, + initial_stack_data, + initial_register_abi_data, + register_data_all ) = + match Nat_big_num.to_int e_machine with + | 8 (* EM_MIPS *) -> + let startaddr = + let e_entry = Uint64_wrapper.of_bigint e_entry in + match Abi_mips64.abi_mips64_compute_program_entry_point segments e_entry with + | Error.Fail s -> failwith "Failed computing entry point" + | Error.Success s -> s + in + let initial_stack_data, initial_register_abi_data = + initial_stack_and_reg_data_of_MIPS_elf_file e_entry !data_mem + in + + ( Cheri128.defs, + ( Mips_extras.mips_read_memory_functions, + Mips_extras.mips_read_memory_tagged_functions, + Mips_extras.mips_memory_writes, + Mips_extras.mips_memory_eas, + Mips_extras.mips_memory_vals, + Mips_extras.mips_memory_vals_tagged, + Mips_extras.mips_barrier_functions + ), + [], + MIPS, + D_decreasing, + startaddr, + initial_stack_data, + initial_register_abi_data, + cheri_register_data_all ) - else - (* not an object symbol or of zero size - ignore *) - convert_symbol_table symbol_map' - end + | _ -> + failwith + (Printf.sprintf + "Sail sequential interpreter can't handle the e_machine value %s, only EM_PPC64, EM_AARCH64 and \ + EM_MIPS are supported." + (Nat_big_num.to_string e_machine) + ) in - (List.map (fun (n,a,bs) -> (n,a,List.length bs,bs)) initial_stack_data) @ convert_symbol_table symbol_map - in - (* invert the symbol table to use for pp *) - let symbol_table_pp : ((Sail_impl_base.address * int) * string) list = - (* map symbol to (bindings, footprint), - if a symbol appears more then onece keep the one with higher - precedence (stb_global > stb_weak > stb_local) *) - let map = - List.fold_left - (fun map (name, (typ, size, address, mb, binding)) -> - if String.length name <> 0 && - (if String.length name = 1 then Char.code (String.get name 0) <> 0 else true) && - not (Nat_big_num.equal address (Nat_big_num.of_int 0)) - then - try - let (binding', _) = StringMap.find name map in - if Nat_big_num.equal binding' Elf_symbol_table.stb_local || - Nat_big_num.equal binding Elf_symbol_table.stb_global - then - StringMap.add name (binding, - (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) map - else map - with Not_found -> - StringMap.add name (binding, - (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) map - - else map - ) - StringMap.empty - symbol_map + (* pull the object symbols from the symbol table *) + let symbol_table : (string * Nat_big_num.num * int * word8 list (*their bytes*)) list = + let rec convert_symbol_table symbol_map = + begin + match symbol_map with + | [] -> [] + | ( (name : string), + ( (typ : Nat_big_num.num), + (size : Nat_big_num.num (*number of bytes*)), + (address : Nat_big_num.num), + (mb : Byte_sequence.byte_sequence option (*present iff type=stt_object*)), + (binding : Nat_big_num.num) + ) + ) (* (mb: Byte_sequence_wrapper.t option (*present iff type=stt_object*)) )) *) + :: symbol_map' -> + if + Nat_big_num.equal typ Elf_symbol_table.stt_object + && not (Nat_big_num.equal size (Nat_big_num.of_int 0)) + then ( + (* an object symbol - map *) + (*Printf.printf "*** size %d ***\n" (Nat_big_num.to_int size);*) + let bytes = + match mb with + | None -> raise (Failure "this cannot happen") + | Some (Sequence bytes) -> List.map (fun (c : char) -> Char.code c) bytes + in + (name, address, List.length bytes, bytes) :: convert_symbol_table symbol_map' + ) + else (* not an object symbol or of zero size - ignore *) + convert_symbol_table symbol_map' + end + in + List.map (fun (n, a, bs) -> (n, a, List.length bs, bs)) initial_stack_data @ convert_symbol_table symbol_map in - List.map (fun (name, (binding, fp)) -> (fp, name)) (StringMap.bindings map) - in + (* invert the symbol table to use for pp *) + let symbol_table_pp : ((Sail_impl_base.address * int) * string) list = + (* map symbol to (bindings, footprint), + if a symbol appears more then onece keep the one with higher + precedence (stb_global > stb_weak > stb_local) *) + let map = + List.fold_left + (fun map (name, (typ, size, address, mb, binding)) -> + if + String.length name <> 0 + && (if String.length name = 1 then Char.code (String.get name 0) <> 0 else true) + && not (Nat_big_num.equal address (Nat_big_num.of_int 0)) + then ( + try + let binding', _ = StringMap.find name map in + if + Nat_big_num.equal binding' Elf_symbol_table.stb_local + || Nat_big_num.equal binding Elf_symbol_table.stb_global + then + StringMap.add name + (binding, (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) + map + else map + with Not_found -> + StringMap.add name + (binding, (Sail_impl_base.address_of_integer address, Nat_big_num.to_int size)) + map + ) + else map + ) + StringMap.empty symbol_map + in - let initial_register_state = - fun rbn -> - try - List.assoc rbn initial_register_abi_data - with - Not_found -> - (register_state_zero register_data_all) rbn - in + List.map (fun (name, (binding, fp)) -> (fp, name)) (StringMap.bindings map) + in - begin - (initial_reg_file register_data_all initial_register_state); - - (* construct initial system state *) - let initial_system_state = - (isa_defs, - isa_memory_access, - isa_externs, - isa_model, - model_reg_d, - startaddr, - (Sail_impl_base.address_of_integer startaddr)) + let initial_register_state rbn = + try List.assoc rbn initial_register_abi_data with Not_found -> (register_state_zero register_data_all) rbn in - - (initial_system_state, symbol_table_pp) + + begin + initial_reg_file register_data_all initial_register_state; + + (* construct initial system state *) + let initial_system_state = + ( isa_defs, + isa_memory_access, + isa_externs, + isa_model, + model_reg_d, + startaddr, + Sail_impl_base.address_of_integer startaddr + ) + in + + (initial_system_state, symbol_table_pp) + end end - end let eager_eval = ref true let break_point = ref false @@ -539,22 +551,42 @@ let break_instr = ref 0 let max_cut_off = ref false let max_instr = ref 0 let raw_file = ref "" -let raw_at = ref 0 - -let args = [ - ("--file", Arg.Set_string file, "filename of elf binary to load in memory"); - ("--quiet", Arg.Clear Run_interp_model.interact_print, "do not display per-instruction actions"); - ("--silent", Arg.Tuple [Arg.Clear Run_interp_model.error_print; - Arg.Clear Run_interp_model.interact_print; - Arg.Clear Run_interp_model.result_print], - "do not dispaly error messages, per-instruction actions, or results"); - ("--no_result", Arg.Clear Run_interp_model.result_print, "do not display final register values"); - ("--interactive", Arg.Clear eager_eval , "interactive execution"); - ("--breakpoint", Arg.Int (fun i -> break_point := true; break_instr:= i), "run to instruction number i, then run interactively"); - ("--max_instruction", Arg.Int (fun i -> max_cut_off := true; max_instr := i), "only run i instructions, then stop"); - ("--raw", Arg.Set_string raw_file, "filename of raw file to load in memory"); - ("--at", Arg.Set_int raw_at, "address to load raw file in memory"); -] +let raw_at = ref 0 + +let args = + [ + ("--file", Arg.Set_string file, "filename of elf binary to load in memory"); + ("--quiet", Arg.Clear Run_interp_model.interact_print, "do not display per-instruction actions"); + ( "--silent", + Arg.Tuple + [ + Arg.Clear Run_interp_model.error_print; + Arg.Clear Run_interp_model.interact_print; + Arg.Clear Run_interp_model.result_print; + ], + "do not dispaly error messages, per-instruction actions, or results" + ); + ("--no_result", Arg.Clear Run_interp_model.result_print, "do not display final register values"); + ("--interactive", Arg.Clear eager_eval, "interactive execution"); + ( "--breakpoint", + Arg.Int + (fun i -> + break_point := true; + break_instr := i + ), + "run to instruction number i, then run interactively" + ); + ( "--max_instruction", + Arg.Int + (fun i -> + max_cut_off := true; + max_instr := i + ), + "only run i instructions, then stop" + ); + ("--raw", Arg.Set_string raw_file, "filename of raw file to load in memory"); + ("--at", Arg.Set_int raw_at, "address to load raw file in memory"); + ] let time_it action arg = let start_time = Sys.time () in @@ -564,305 +596,341 @@ let time_it action arg = (*TODO MIPS specific, should print final register values under all models*) let rec debug_print_gprs start stop = - resultf "DEBUG MIPS REG %.2d %s\n" start (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "GPR%02d" start) !reg)); - if start < stop - then debug_print_gprs (start + 1) stop - else () + resultf "DEBUG MIPS REG %.2d %s\n" start + (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "GPR%02d" start) !reg)); + if start < stop then debug_print_gprs (start + 1) stop else () let rec debug_print_capregs start stop = - resultf "DEBUG CAP REG %.2d %s\n" start (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "C%02d" start) !reg)); - if start < stop - then debug_print_capregs (start + 1) stop - else () + resultf "DEBUG CAP REG %.2d %s\n" start + (Printing_functions.logfile_register_value_to_string (Reg.find (Format.sprintf "C%02d" start) !reg)); + if start < stop then debug_print_capregs (start + 1) stop else () let stop_condition_met model instr = match model with - | PPC -> - (match instr with - | ("Sc", [("Lev", _, arg)]) -> - Nat_big_num.equal (integer_of_bit_list arg) (Nat_big_num.of_int 32) - | _ -> false) - | AArch64 -> (match instr with - | ("ImplementationDefinedStopFetching", _) -> true - | _ -> false) - | MIPS -> (match instr with - | ("HCF", _) -> - resultf "DEBUG MIPS PC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PC" !reg)); - debug_print_gprs 0 31; - resultf "DEBUG CAP PCC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PCC" !reg)); - debug_print_capregs 0 31; - true - | _ -> false) - -let option_int_of_option_integer i = match i with - | Some i -> Some (Nat_big_num.to_int i) - | None -> None + | PPC -> ( + match instr with + | "Sc", [("Lev", _, arg)] -> Nat_big_num.equal (integer_of_bit_list arg) (Nat_big_num.of_int 32) + | _ -> false + ) + | AArch64 -> ( + match instr with "ImplementationDefinedStopFetching", _ -> true | _ -> false + ) + | MIPS -> ( + match instr with + | "HCF", _ -> + resultf "DEBUG MIPS PC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PC" !reg)); + debug_print_gprs 0 31; + resultf "DEBUG CAP PCC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PCC" !reg)); + debug_print_capregs 0 31; + true + | _ -> false + ) + +let option_int_of_option_integer i = match i with Some i -> Some (Nat_big_num.to_int i) | None -> None let add1 = Nat_big_num.add (Nat_big_num.of_int 1) let get_addr_trans_regs _ = (*resultf "PCC %s\n" (Printing_functions.logfile_register_value_to_string (Reg.find "PCC" !reg));*) - Some([ - (Sail_impl_base.Reg("PC", 63, 64, Sail_impl_base.D_decreasing), Reg.find "PC" !reg); - (Sail_impl_base.Reg("PCC", 128, 129, Sail_impl_base.D_decreasing), Reg.find "PCC" !reg); - (Sail_impl_base.Reg("C29", 128, 129, Sail_impl_base.D_decreasing), Reg.find "C29" !reg); - (Sail_impl_base.Reg("CP0Status", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Status" !reg); - (Sail_impl_base.Reg("CP0Cause", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Cause" !reg); - (Sail_impl_base.Reg("CP0Count", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Count" !reg); - (Sail_impl_base.Reg("CP0Compare", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Compare" !reg); - (Sail_impl_base.Reg("inBranchDelay", 0, 1, Sail_impl_base.D_decreasing), Reg.find "inBranchDelay" !reg); - (Sail_impl_base.Reg("TLBRandom", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBRandom" !reg); - (Sail_impl_base.Reg("TLBWired", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBWired" !reg); - (Sail_impl_base.Reg("TLBEntryHi", 63, 64, Sail_impl_base.D_decreasing), Reg.find "TLBEntryHi" !reg); - (Sail_impl_base.Reg("TLBEntry00", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry00" !reg); - (Sail_impl_base.Reg("TLBEntry01", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry01" !reg); - (Sail_impl_base.Reg("TLBEntry02", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry02" !reg); - (Sail_impl_base.Reg("TLBEntry03", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry03" !reg); - (Sail_impl_base.Reg("TLBEntry04", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry04" !reg); - (Sail_impl_base.Reg("TLBEntry05", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry05" !reg); - (Sail_impl_base.Reg("TLBEntry06", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry06" !reg); - (Sail_impl_base.Reg("TLBEntry07", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry07" !reg); - (Sail_impl_base.Reg("TLBEntry08", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry08" !reg); - (Sail_impl_base.Reg("TLBEntry09", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry09" !reg); - (Sail_impl_base.Reg("TLBEntry10", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry10" !reg); - (Sail_impl_base.Reg("TLBEntry11", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry11" !reg); - (Sail_impl_base.Reg("TLBEntry12", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry12" !reg); - (Sail_impl_base.Reg("TLBEntry13", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry13" !reg); - (Sail_impl_base.Reg("TLBEntry14", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry14" !reg); - (Sail_impl_base.Reg("TLBEntry15", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry15" !reg); - (Sail_impl_base.Reg("TLBEntry16", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry16" !reg); - (Sail_impl_base.Reg("TLBEntry17", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry17" !reg); - (Sail_impl_base.Reg("TLBEntry18", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry18" !reg); - (Sail_impl_base.Reg("TLBEntry19", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry19" !reg); - (Sail_impl_base.Reg("TLBEntry20", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry20" !reg); - (Sail_impl_base.Reg("TLBEntry21", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry21" !reg); - (Sail_impl_base.Reg("TLBEntry22", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry22" !reg); - (Sail_impl_base.Reg("TLBEntry23", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry23" !reg); - (Sail_impl_base.Reg("TLBEntry24", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry24" !reg); - (Sail_impl_base.Reg("TLBEntry25", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry25" !reg); - (Sail_impl_base.Reg("TLBEntry26", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry26" !reg); - (Sail_impl_base.Reg("TLBEntry27", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry27" !reg); - (Sail_impl_base.Reg("TLBEntry28", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry28" !reg); - (Sail_impl_base.Reg("TLBEntry29", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry29" !reg); - (Sail_impl_base.Reg("TLBEntry30", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry30" !reg); - (Sail_impl_base.Reg("TLBEntry31", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry31" !reg); - (Sail_impl_base.Reg("TLBEntry32", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry32" !reg); - (Sail_impl_base.Reg("TLBEntry33", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry33" !reg); - (Sail_impl_base.Reg("TLBEntry34", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry34" !reg); - (Sail_impl_base.Reg("TLBEntry35", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry35" !reg); - (Sail_impl_base.Reg("TLBEntry36", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry36" !reg); - (Sail_impl_base.Reg("TLBEntry37", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry37" !reg); - (Sail_impl_base.Reg("TLBEntry38", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry38" !reg); - (Sail_impl_base.Reg("TLBEntry39", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry39" !reg); - (Sail_impl_base.Reg("TLBEntry40", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry40" !reg); - (Sail_impl_base.Reg("TLBEntry41", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry41" !reg); - (Sail_impl_base.Reg("TLBEntry42", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry42" !reg); - (Sail_impl_base.Reg("TLBEntry43", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry43" !reg); - (Sail_impl_base.Reg("TLBEntry44", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry44" !reg); - (Sail_impl_base.Reg("TLBEntry45", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry45" !reg); - (Sail_impl_base.Reg("TLBEntry46", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry46" !reg); - (Sail_impl_base.Reg("TLBEntry47", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry47" !reg); - (Sail_impl_base.Reg("TLBEntry48", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry48" !reg); - (Sail_impl_base.Reg("TLBEntry49", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry49" !reg); - (Sail_impl_base.Reg("TLBEntry50", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry50" !reg); - (Sail_impl_base.Reg("TLBEntry51", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry51" !reg); - (Sail_impl_base.Reg("TLBEntry52", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry52" !reg); - (Sail_impl_base.Reg("TLBEntry53", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry53" !reg); - (Sail_impl_base.Reg("TLBEntry54", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry54" !reg); - (Sail_impl_base.Reg("TLBEntry55", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry55" !reg); - (Sail_impl_base.Reg("TLBEntry56", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry56" !reg); - (Sail_impl_base.Reg("TLBEntry57", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry57" !reg); - (Sail_impl_base.Reg("TLBEntry58", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry58" !reg); - (Sail_impl_base.Reg("TLBEntry59", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry59" !reg); - (Sail_impl_base.Reg("TLBEntry60", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry60" !reg); - (Sail_impl_base.Reg("TLBEntry61", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry61" !reg); - (Sail_impl_base.Reg("TLBEntry62", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry62" !reg); - (Sail_impl_base.Reg("TLBEntry63", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry63" !reg); - ]) + Some + [ + (Sail_impl_base.Reg ("PC", 63, 64, Sail_impl_base.D_decreasing), Reg.find "PC" !reg); + (Sail_impl_base.Reg ("PCC", 128, 129, Sail_impl_base.D_decreasing), Reg.find "PCC" !reg); + (Sail_impl_base.Reg ("C29", 128, 129, Sail_impl_base.D_decreasing), Reg.find "C29" !reg); + (Sail_impl_base.Reg ("CP0Status", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Status" !reg); + (Sail_impl_base.Reg ("CP0Cause", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Cause" !reg); + (Sail_impl_base.Reg ("CP0Count", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Count" !reg); + (Sail_impl_base.Reg ("CP0Compare", 31, 32, Sail_impl_base.D_decreasing), Reg.find "CP0Compare" !reg); + (Sail_impl_base.Reg ("inBranchDelay", 0, 1, Sail_impl_base.D_decreasing), Reg.find "inBranchDelay" !reg); + (Sail_impl_base.Reg ("TLBRandom", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBRandom" !reg); + (Sail_impl_base.Reg ("TLBWired", 5, 6, Sail_impl_base.D_decreasing), Reg.find "TLBWired" !reg); + (Sail_impl_base.Reg ("TLBEntryHi", 63, 64, Sail_impl_base.D_decreasing), Reg.find "TLBEntryHi" !reg); + (Sail_impl_base.Reg ("TLBEntry00", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry00" !reg); + (Sail_impl_base.Reg ("TLBEntry01", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry01" !reg); + (Sail_impl_base.Reg ("TLBEntry02", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry02" !reg); + (Sail_impl_base.Reg ("TLBEntry03", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry03" !reg); + (Sail_impl_base.Reg ("TLBEntry04", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry04" !reg); + (Sail_impl_base.Reg ("TLBEntry05", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry05" !reg); + (Sail_impl_base.Reg ("TLBEntry06", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry06" !reg); + (Sail_impl_base.Reg ("TLBEntry07", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry07" !reg); + (Sail_impl_base.Reg ("TLBEntry08", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry08" !reg); + (Sail_impl_base.Reg ("TLBEntry09", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry09" !reg); + (Sail_impl_base.Reg ("TLBEntry10", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry10" !reg); + (Sail_impl_base.Reg ("TLBEntry11", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry11" !reg); + (Sail_impl_base.Reg ("TLBEntry12", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry12" !reg); + (Sail_impl_base.Reg ("TLBEntry13", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry13" !reg); + (Sail_impl_base.Reg ("TLBEntry14", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry14" !reg); + (Sail_impl_base.Reg ("TLBEntry15", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry15" !reg); + (Sail_impl_base.Reg ("TLBEntry16", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry16" !reg); + (Sail_impl_base.Reg ("TLBEntry17", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry17" !reg); + (Sail_impl_base.Reg ("TLBEntry18", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry18" !reg); + (Sail_impl_base.Reg ("TLBEntry19", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry19" !reg); + (Sail_impl_base.Reg ("TLBEntry20", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry20" !reg); + (Sail_impl_base.Reg ("TLBEntry21", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry21" !reg); + (Sail_impl_base.Reg ("TLBEntry22", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry22" !reg); + (Sail_impl_base.Reg ("TLBEntry23", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry23" !reg); + (Sail_impl_base.Reg ("TLBEntry24", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry24" !reg); + (Sail_impl_base.Reg ("TLBEntry25", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry25" !reg); + (Sail_impl_base.Reg ("TLBEntry26", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry26" !reg); + (Sail_impl_base.Reg ("TLBEntry27", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry27" !reg); + (Sail_impl_base.Reg ("TLBEntry28", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry28" !reg); + (Sail_impl_base.Reg ("TLBEntry29", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry29" !reg); + (Sail_impl_base.Reg ("TLBEntry30", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry30" !reg); + (Sail_impl_base.Reg ("TLBEntry31", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry31" !reg); + (Sail_impl_base.Reg ("TLBEntry32", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry32" !reg); + (Sail_impl_base.Reg ("TLBEntry33", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry33" !reg); + (Sail_impl_base.Reg ("TLBEntry34", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry34" !reg); + (Sail_impl_base.Reg ("TLBEntry35", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry35" !reg); + (Sail_impl_base.Reg ("TLBEntry36", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry36" !reg); + (Sail_impl_base.Reg ("TLBEntry37", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry37" !reg); + (Sail_impl_base.Reg ("TLBEntry38", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry38" !reg); + (Sail_impl_base.Reg ("TLBEntry39", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry39" !reg); + (Sail_impl_base.Reg ("TLBEntry40", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry40" !reg); + (Sail_impl_base.Reg ("TLBEntry41", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry41" !reg); + (Sail_impl_base.Reg ("TLBEntry42", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry42" !reg); + (Sail_impl_base.Reg ("TLBEntry43", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry43" !reg); + (Sail_impl_base.Reg ("TLBEntry44", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry44" !reg); + (Sail_impl_base.Reg ("TLBEntry45", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry45" !reg); + (Sail_impl_base.Reg ("TLBEntry46", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry46" !reg); + (Sail_impl_base.Reg ("TLBEntry47", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry47" !reg); + (Sail_impl_base.Reg ("TLBEntry48", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry48" !reg); + (Sail_impl_base.Reg ("TLBEntry49", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry49" !reg); + (Sail_impl_base.Reg ("TLBEntry50", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry50" !reg); + (Sail_impl_base.Reg ("TLBEntry51", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry51" !reg); + (Sail_impl_base.Reg ("TLBEntry52", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry52" !reg); + (Sail_impl_base.Reg ("TLBEntry53", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry53" !reg); + (Sail_impl_base.Reg ("TLBEntry54", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry54" !reg); + (Sail_impl_base.Reg ("TLBEntry55", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry55" !reg); + (Sail_impl_base.Reg ("TLBEntry56", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry56" !reg); + (Sail_impl_base.Reg ("TLBEntry57", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry57" !reg); + (Sail_impl_base.Reg ("TLBEntry58", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry58" !reg); + (Sail_impl_base.Reg ("TLBEntry59", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry59" !reg); + (Sail_impl_base.Reg ("TLBEntry60", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry60" !reg); + (Sail_impl_base.Reg ("TLBEntry61", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry61" !reg); + (Sail_impl_base.Reg ("TLBEntry62", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry62" !reg); + (Sail_impl_base.Reg ("TLBEntry63", 116, 117, Sail_impl_base.D_decreasing), Reg.find "TLBEntry63" !reg); + ] let get_opcode pc_a = - List.map (fun b -> match b with - | Some b -> b - | None -> failwith "A byte in opcode contained unknown or undef") + List.map + (fun b -> match b with Some b -> b | None -> failwith "A byte in opcode contained unknown or undef") (List.map byte_of_memory_byte - ([Mem.find pc_a !prog_mem; + [ + Mem.find pc_a !prog_mem; Mem.find (add1 pc_a) !prog_mem; Mem.find (add1 (add1 pc_a)) !prog_mem; - Mem.find (add1 (add1 (add1 pc_a))) !prog_mem])) + Mem.find (add1 (add1 (add1 pc_a))) !prog_mem; + ] + ) let rec write_events = function | [] -> () - | e::events -> - (match e with - | E_write_reg (Reg(id,_,_,_), value) -> reg := Reg.add id value !reg - | E_write_reg ((Reg_slice(id,_,_,range) as reg_n),value) - | E_write_reg ((Reg_field(id,_,_,_,range) as reg_n),value)-> - let old_val = Reg.find id !reg in - let new_val = fupdate_slice reg_n old_val value range in - reg := Reg.add id new_val !reg - | E_write_reg((Reg_f_slice(id,_,_,_,range,mini_range) as reg_n),value) -> - let old_val = Reg.find id !reg in - let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in - reg := Reg.add id new_val !reg - | _ -> failwith "Only register write events expected"); - write_events events - -let get_pc_address = function - | MIPS -> Reg.find "PC" !reg - | PPC -> Reg.find "CIA" !reg - | AArch64 -> Reg.find "_PC" !reg - -let option_int_of_reg str = - option_int_of_option_integer (integer_of_register_value (Reg.find str !reg)) + | e :: events -> + ( match e with + | E_write_reg (Reg (id, _, _, _), value) -> reg := Reg.add id value !reg + | E_write_reg ((Reg_slice (id, _, _, range) as reg_n), value) + | E_write_reg ((Reg_field (id, _, _, _, range) as reg_n), value) -> + let old_val = Reg.find id !reg in + let new_val = fupdate_slice reg_n old_val value range in + reg := Reg.add id new_val !reg + | E_write_reg ((Reg_f_slice (id, _, _, _, range, mini_range) as reg_n), value) -> + let old_val = Reg.find id !reg in + let new_val = fupdate_slice reg_n old_val value (combine_slices range mini_range) in + reg := Reg.add id new_val !reg + | _ -> failwith "Only register write events expected" + ); + write_events events + +let get_pc_address = function MIPS -> Reg.find "PC" !reg | PPC -> Reg.find "CIA" !reg | AArch64 -> Reg.find "_PC" !reg + +let option_int_of_reg str = option_int_of_option_integer (integer_of_register_value (Reg.find str !reg)) let rec fde_loop count context model mode track_dependencies addr_trans = - if !max_cut_off && count = !max_instr - then resultf "\nEnding evaluation due to reaching cut off point of %d instructions\n" count + if !max_cut_off && count = !max_instr then + resultf "\nEnding evaluation due to reaching cut off point of %d instructions\n" count else begin - if !break_point && count = !break_instr then begin break_point := false; eager_eval := false end; + if !break_point && count = !break_instr then begin + break_point := false; + eager_eval := false + end; let pc_regval = get_pc_address model in - interactf "\n**** instruction %d from address %s ****\n" - count (Printing_functions.register_value_to_string pc_regval); - let pc_addr = address_of_register_value pc_regval in - let pc_val = match pc_addr with - | Some v -> v - | None -> failwith "pc contains undef or unknown" in - let m_paddr_int = match addr_trans (get_addr_trans_regs ()) pc_val with - | Some a, Some events -> write_events (List.rev events); Some (integer_of_address a) - | Some a, None -> Some (integer_of_address a) - | None, Some events -> write_events (List.rev events); None - | None, None -> failwith "address translation failed and no writes" in + interactf "\n**** instruction %d from address %s ****\n" count + (Printing_functions.register_value_to_string pc_regval); + let pc_addr = address_of_register_value pc_regval in + let pc_val = match pc_addr with Some v -> v | None -> failwith "pc contains undef or unknown" in + let m_paddr_int = + match addr_trans (get_addr_trans_regs ()) pc_val with + | Some a, Some events -> + write_events (List.rev events); + Some (integer_of_address a) + | Some a, None -> Some (integer_of_address a) + | None, Some events -> + write_events (List.rev events); + None + | None, None -> failwith "address translation failed and no writes" + in match m_paddr_int with - | Some pc -> - let inBranchDelay = option_int_of_reg "inBranchDelay" in - (match inBranchDelay with - | Some 0 -> + | Some pc -> + let inBranchDelay = option_int_of_reg "inBranchDelay" in + ( match inBranchDelay with + | Some 0 -> let npc_addr = add_address_nat pc_val 4 in let npc_reg = register_value_of_address npc_addr Sail_impl_base.D_decreasing in reg := Reg.add "nextPC" npc_reg !reg; - reg := Reg.add "inCCallDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - | Some 1 -> + reg := + Reg.add "inCCallDelay" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg + | Some 1 -> reg := Reg.add "nextPC" (Reg.find "delayedPC" !reg) !reg; - reg := Reg.add "nextPCC" (Reg.find "delayedPCC" !reg) !reg; - | _ -> failwith "invalid value of inBranchDelay"); - let opcode = Opcode (get_opcode pc) in - let (instruction,istate) = match Interp_inter_imp.decode_to_istate context None opcode with - | Instr(instruction,istate) -> - let instruction = interp_value_to_instr_external context instruction in + reg := Reg.add "nextPCC" (Reg.find "delayedPCC" !reg) !reg + | _ -> failwith "invalid value of inBranchDelay" + ); + let opcode = Opcode (get_opcode pc) in + let instruction, istate = + match Interp_inter_imp.decode_to_istate context None opcode with + | Instr (instruction, istate) -> + let instruction = interp_value_to_instr_external context instruction in interactf "\n**** Running: %s ****\n" (Printing_functions.instruction_to_string instruction); - (instruction,istate) - | Decode_error d -> - (match d with + (instruction, istate) + | Decode_error d -> + ( match d with | Interp_interface.Unsupported_instruction_error instruction -> - let instruction = interp_value_to_instr_external context instruction in - errorf "\n**** Encountered unsupported instruction %s ****\n" (Printing_functions.instruction_to_string instruction) - | Interp_interface.Not_an_instruction_error op -> - (match op with - | Opcode bytes -> - errorf "\n**** Encountered non-decodeable opcode: %s ****\n" (Printing_functions.byte_list_to_string bytes)) - | Internal_error s -> errorf "\n**** Internal error on decode: %s ****\n" s); exit 1 - in - if stop_condition_met model instruction - then resultf "\nSUCCESS program terminated after %d instructions\n" count - else - begin - match Run_interp_model.run istate !reg !prog_mem !tag_mem (Nat_big_num.of_int 16) !eager_eval track_dependencies mode "execute" with - | false, _,_, _ -> errorf "FAILURE\n"; exit 1 - | true, mode, track_dependencies, (my_reg, my_mem, my_tags) -> - reg := my_reg; - prog_mem := my_mem; - tag_mem := my_tags; - - (try - let (pending, _, _) = (Unix.select [(Unix.stdin)] [] [] 0.0) in - (if (pending != []) then - let char = (input_byte stdin) in ( + let instruction = interp_value_to_instr_external context instruction in + errorf "\n**** Encountered unsupported instruction %s ****\n" + (Printing_functions.instruction_to_string instruction) + | Interp_interface.Not_an_instruction_error op -> ( + match op with + | Opcode bytes -> + errorf "\n**** Encountered non-decodeable opcode: %s ****\n" + (Printing_functions.byte_list_to_string bytes) + ) + | Internal_error s -> errorf "\n**** Internal error on decode: %s ****\n" s + ); + exit 1 + in + if stop_condition_met model instruction then + resultf "\nSUCCESS program terminated after %d instructions\n" count + else begin + match + Run_interp_model.run istate !reg !prog_mem !tag_mem (Nat_big_num.of_int 16) !eager_eval track_dependencies + mode "execute" + with + | false, _, _, _ -> + errorf "FAILURE\n"; + exit 1 + | true, mode, track_dependencies, (my_reg, my_mem, my_tags) -> + reg := my_reg; + prog_mem := my_mem; + tag_mem := my_tags; + + ( try + let pending, _, _ = Unix.select [Unix.stdin] [] [] 0.0 in + if pending != [] then ( + let char = input_byte stdin in errorf "Input %x\n" char; - input_buf := (!input_buf) @ [char])); - with - | _ -> ()); - - let uart_rvalid = option_int_of_reg "UART_RVALID" in - (match uart_rvalid with - | Some 0 -> - (match !input_buf with - | x :: xs -> ( - reg := Reg.add "UART_RDATA" (register_value_of_integer 8 7 Sail_impl_base.D_decreasing (Nat_big_num.of_int x)) !reg; - reg := Reg.add "UART_RVALID" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) !reg; - input_buf := xs; - ) - | [] -> ()) - | _-> ()); - - let uart_written = option_int_of_reg "UART_WRITTEN" in - (match uart_written with - | Some 1 -> - (let uart_data = option_int_of_reg "UART_WDATA" in + input_buf := !input_buf @ [char] + ) + with _ -> () + ); + + let uart_rvalid = option_int_of_reg "UART_RVALID" in + ( match uart_rvalid with + | Some 0 -> ( + match !input_buf with + | x :: xs -> + reg := + Reg.add "UART_RDATA" + (register_value_of_integer 8 7 Sail_impl_base.D_decreasing (Nat_big_num.of_int x)) + !reg; + reg := + Reg.add "UART_RVALID" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing (Nat_big_num.of_int 1)) + !reg; + input_buf := xs + | [] -> () + ) + | _ -> () + ); + + let uart_written = option_int_of_reg "UART_WRITTEN" in + ( match uart_written with + | Some 1 -> ( + let uart_data = option_int_of_reg "UART_WDATA" in match uart_data with - | Some b -> (printf "%c" (Char.chr b); printf "%!") - | None -> (errorf "UART_WDATA was undef" ; exit 1)) - | _ -> ()); - reg := Reg.add "UART_WRITTEN" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - - reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; - fde_loop (count + 1) context model (Some mode) (ref track_dependencies) addr_trans - end - | None -> begin - reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; - reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; - reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; - reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; - fde_loop (count + 1) context model mode track_dependencies addr_trans - end + | Some b -> + printf "%c" (Char.chr b); + printf "%!" + | None -> + errorf "UART_WDATA was undef"; + exit 1 + ) + | _ -> () + ); + reg := + Reg.add "UART_WRITTEN" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + + reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; + reg := + Reg.add "branchPending" + (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) + !reg; + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; + fde_loop (count + 1) context model (Some mode) (ref track_dependencies) addr_trans + end + | None -> begin + reg := Reg.add "inBranchDelay" (Reg.find "branchPending" !reg) !reg; + reg := Reg.add "branchPending" (register_value_of_integer 1 0 Sail_impl_base.D_decreasing Nat_big_num.zero) !reg; + reg := Reg.add "PC" (Reg.find "nextPC" !reg) !reg; + reg := Reg.add "PCC" (Reg.find "nextPCC" !reg) !reg; + fde_loop (count + 1) context model mode track_dependencies addr_trans + end end - + let rec load_raw_file' mem addr chan = let byte = input_byte chan in - (add_mem byte addr mem; - load_raw_file' mem (Nat_big_num.succ addr) chan) + add_mem byte addr mem; + load_raw_file' mem (Nat_big_num.succ addr) chan -let rec load_raw_file mem addr chan = - try - load_raw_file' mem addr chan - with - | End_of_file -> () +let rec load_raw_file mem addr chan = try load_raw_file' mem addr chan with End_of_file -> () let run () = - Arg.parse args (fun _ -> raise (Arg.Bad "anonymous parameter")) "" ; + Arg.parse args (fun _ -> raise (Arg.Bad "anonymous parameter")) ""; if !file = "" then begin Arg.usage args ""; - exit 1; + exit 1 end; if !break_point then eager_eval := true; - let ((isa_defs, - (isa_m0, isa_m1, isa_m2, isa_m3, isa_m4, isa_m5, isa_m6), - isa_externs, - isa_model, - model_reg_d, - startaddr, - startaddr_internal), pp_symbol_map) = initial_system_state_of_elf_file !file in + let ( ( isa_defs, + (isa_m0, isa_m1, isa_m2, isa_m3, isa_m4, isa_m5, isa_m6), + isa_externs, + isa_model, + model_reg_d, + startaddr, + startaddr_internal + ), + pp_symbol_map ) = + initial_system_state_of_elf_file !file + in let context = build_context false isa_defs isa_m0 isa_m1 isa_m2 isa_m3 isa_m4 isa_m5 isa_m6 None isa_externs in - (*NOTE: this is likely MIPS specific, so should probably pull from initial_system_state info on to translate or not, - endian mode, and translate function name + (*NOTE: this is likely MIPS specific, so should probably pull from initial_system_state info on to translate or not, + endian mode, and translate function name *) let addr_trans = translate_address context E_little_endian "TranslatePC" in - if String.length(!raw_file) != 0 then - load_raw_file prog_mem (Nat_big_num.of_int !raw_at) (open_in_bin !raw_file); - reg := Reg.add "PC" (register_value_of_address startaddr_internal model_reg_d ) !reg; + if String.length !raw_file != 0 then load_raw_file prog_mem (Nat_big_num.of_int !raw_at) (open_in_bin !raw_file); + reg := Reg.add "PC" (register_value_of_address startaddr_internal model_reg_d) !reg; (* entry point: unit -> unit fde *) let name = Filename.basename !file in - let t = time_it (fun () -> fde_loop 0 context isa_model (Some Run) (ref false) addr_trans) () in - resultf "Execution time for file %s: %f seconds\n" name t;; + let t = time_it (fun () -> fde_loop 0 context isa_model (Some Run) (ref false) addr_trans) () in + resultf "Execution time for file %s: %f seconds\n" name t +;; (* Turn off line-buffering of standard input to allow responsive console input *) -if Unix.isatty (Unix.stdin) then begin - let tattrs = Unix.tcgetattr (Unix.stdin) in - Unix.tcsetattr (Unix.stdin) (Unix.TCSANOW) ({tattrs with c_icanon=false}) -end ;; +if Unix.isatty Unix.stdin then begin + let tattrs = Unix.tcgetattr Unix.stdin in + Unix.tcsetattr Unix.stdin Unix.TCSANOW { tattrs with c_icanon = false } +end +;; -run () ;; +run () diff --git a/src/lib/anf.ml b/src/lib/anf.ml index a996f37cf..66bb6d64e 100644 --- a/src/lib/anf.ml +++ b/src/lib/anf.ml @@ -82,11 +82,11 @@ type 'a aexp = AE_aux of 'a aexp_aux * Env.t * l and 'a aexp_aux = | AE_val of 'a aval - | AE_app of id * ('a aval) list * 'a + | AE_app of id * 'a aval list * 'a | AE_typ of 'a aexp * 'a | AE_assign of 'a alexp * 'a aexp | AE_let of mut * id * 'a * 'a aexp * 'a aexp * 'a - | AE_block of ('a aexp) list * 'a aexp * 'a + | AE_block of 'a aexp list * 'a aexp * 'a | AE_return of 'a aval * 'a | AE_exit of 'a aval * 'a | AE_throw of 'a aval * 'a @@ -94,7 +94,7 @@ and 'a aexp_aux = | AE_field of 'a aval * id * 'a | AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp) list * 'a | AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp) list * 'a - | AE_struct_update of 'a aval * ('a aval) Bindings.t * 'a + | AE_struct_update of 'a aval * 'a aval Bindings.t * 'a | AE_for of id * 'a aexp * 'a aexp * 'a aexp * order * 'a aexp | AE_loop of loop * 'a aexp * 'a aexp | AE_short_circuit of sc_op * 'a aval * 'a aexp @@ -104,7 +104,7 @@ and sc_op = SC_and | SC_or and 'a apat = AP_aux of 'a apat_aux * Env.t * l and 'a apat_aux = - | AP_tuple of ('a apat) list + | AP_tuple of 'a apat list | AP_id of id * 'a | AP_global of id * 'a | AP_app of id * 'a apat * 'a @@ -117,19 +117,16 @@ and 'a aval = | AV_lit of lit * 'a | AV_id of id * 'a lvar | AV_ref of id * 'a lvar - | AV_tuple of ('a aval) list - | AV_list of ('a aval) list * 'a - | AV_vector of ('a aval) list * 'a - | AV_record of ('a aval) Bindings.t * 'a + | AV_tuple of 'a aval list + | AV_list of 'a aval list * 'a + | AV_vector of 'a aval list * 'a + | AV_record of 'a aval Bindings.t * 'a | AV_cval of cval * 'a -and 'a alexp = - | AL_id of id * 'a - | AL_addr of id * 'a - | AL_field of 'a alexp * id - +and 'a alexp = AL_id of id * 'a | AL_addr of id * 'a | AL_field of 'a alexp * id + let aexp_loc (AE_aux (_, _, l)) = l - + (* Renaming variables in ANF expressions *) let rec apat_bindings (AP_aux (apat_aux, _, _)) = @@ -147,10 +144,10 @@ let rec apat_bindings (AP_aux (apat_aux, _, _)) = pattern. It ignores AP_global, apat_globals is used for that. *) let rec apat_types (AP_aux (apat_aux, env, _)) = let merge id b1 b2 = - match b1, b2 with - | None, None -> None - | Some v, None -> Some v - | None, Some v -> Some v + match (b1, b2) with + | None, None -> None + | Some v, None -> Some v + | None, Some v -> Some v | Some _, Some _ -> assert false in match apat_aux with @@ -165,7 +162,8 @@ let rec apat_types (AP_aux (apat_aux, env, _)) = | AP_wild _ -> Bindings.empty let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) = - let apat_aux = match apat_aux with + let apat_aux = + match apat_aux with | AP_tuple apats -> AP_tuple (List.map (apat_rename from_id to_id) apats) | AP_id (id, typ) when Id.compare id from_id = 0 -> AP_id (to_id, typ) | AP_id (id, typ) -> AP_id (id, typ) @@ -230,53 +228,63 @@ let rec alexp_rename from_id to_id = function let rec aexp_rename from_id to_id (AE_aux (aexp, env, l)) = let recur = aexp_rename from_id to_id in - let aexp = match aexp with + let aexp = + match aexp with | AE_val aval -> AE_val (aval_rename from_id to_id aval) | AE_app (id, avals, typ) -> AE_app (id, List.map (aval_rename from_id to_id) avals, typ) | AE_typ (aexp, typ) -> AE_typ (recur aexp, typ) | AE_assign (alexp, aexp) -> AE_assign (alexp_rename from_id to_id alexp, aexp_rename from_id to_id aexp) - | AE_let (mut, id, typ1, aexp1, aexp2, typ2) when Id.compare from_id id = 0 -> AE_let (mut, id, typ1, recur aexp1, aexp2, typ2) + | AE_let (mut, id, typ1, aexp1, aexp2, typ2) when Id.compare from_id id = 0 -> + AE_let (mut, id, typ1, recur aexp1, aexp2, typ2) | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, recur aexp1, recur aexp2, typ2) | AE_block (aexps, aexp, typ) -> AE_block (List.map recur aexps, recur aexp, typ) | AE_return (aval, typ) -> AE_return (aval_rename from_id to_id aval, typ) | AE_exit (aval, typ) -> AE_exit (aval_rename from_id to_id aval, typ) | AE_throw (aval, typ) -> AE_throw (aval_rename from_id to_id aval, typ) - | AE_if (aval, then_aexp, else_aexp, typ) -> AE_if (aval_rename from_id to_id aval, recur then_aexp, recur else_aexp, typ) + | AE_if (aval, then_aexp, else_aexp, typ) -> + AE_if (aval_rename from_id to_id aval, recur then_aexp, recur else_aexp, typ) | AE_field (aval, id, typ) -> AE_field (aval_rename from_id to_id aval, id, typ) - | AE_match (aval, apexps, typ) -> AE_match (aval_rename from_id to_id aval, List.map (apexp_rename from_id to_id) apexps, typ) - | AE_try (aexp, apexps, typ) -> AE_try (aexp_rename from_id to_id aexp, List.map (apexp_rename from_id to_id) apexps, typ) - | AE_struct_update (aval, avals, typ) -> AE_struct_update (aval_rename from_id to_id aval, Bindings.map (aval_rename from_id to_id) avals, typ) - | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) when Id.compare from_id to_id = 0 -> AE_for (id, aexp1, aexp2, aexp3, order, aexp4) - | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> AE_for (id, recur aexp1, recur aexp2, recur aexp3, order, recur aexp4) + | AE_match (aval, apexps, typ) -> + AE_match (aval_rename from_id to_id aval, List.map (apexp_rename from_id to_id) apexps, typ) + | AE_try (aexp, apexps, typ) -> + AE_try (aexp_rename from_id to_id aexp, List.map (apexp_rename from_id to_id) apexps, typ) + | AE_struct_update (aval, avals, typ) -> + AE_struct_update (aval_rename from_id to_id aval, Bindings.map (aval_rename from_id to_id) avals, typ) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) when Id.compare from_id to_id = 0 -> + AE_for (id, aexp1, aexp2, aexp3, order, aexp4) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + AE_for (id, recur aexp1, recur aexp2, recur aexp3, order, recur aexp4) | AE_loop (loop, aexp1, aexp2) -> AE_loop (loop, recur aexp1, recur aexp2) | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval_rename from_id to_id aval, recur aexp) in AE_aux (aexp, env, l) and apexp_rename from_id to_id (apat, aexp1, aexp2) = - if IdSet.mem from_id (apat_bindings apat) then - (apat, aexp1, aexp2) - else - (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2) + if IdSet.mem from_id (apat_bindings apat) then (apat, aexp1, aexp2) + else (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2) let rec fold_aexp f (AE_aux (aexp, env, l)) = - let aexp = match aexp with + let aexp = + match aexp with | AE_app (id, vs, typ) -> AE_app (id, vs, typ) | AE_typ (aexp, typ) -> AE_typ (fold_aexp f aexp, typ) | AE_assign (alexp, aexp) -> AE_assign (alexp, fold_aexp f aexp) | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, fold_aexp f aexp) | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, fold_aexp f aexp1, fold_aexp f aexp2, typ2) | AE_block (aexps, aexp, typ) -> AE_block (List.map (fold_aexp f) aexps, fold_aexp f aexp, typ) - | AE_if (aval, aexp1, aexp2, typ) -> - AE_if (aval, fold_aexp f aexp1, fold_aexp f aexp2, typ) + | AE_if (aval, aexp1, aexp2, typ) -> AE_if (aval, fold_aexp f aexp1, fold_aexp f aexp2, typ) | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, fold_aexp f aexp1, fold_aexp f aexp2) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4) + AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4) | AE_match (aval, cases, typ) -> - AE_match (aval, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ) + AE_match (aval, List.map (fun (pat, aexp1, aexp2) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2)) cases, typ) | AE_try (aexp, cases, typ) -> - AE_try (fold_aexp f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ) - | AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _ as v -> v + AE_try + ( fold_aexp f aexp, + List.map (fun (pat, aexp1, aexp2) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2)) cases, + typ + ) + | (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v in f (AE_aux (aexp, env, l)) @@ -284,8 +292,8 @@ let aexp_bindings aexp = let ids = ref IdSet.empty in let collect_lets = function | AE_aux (AE_let (_, id, _, _, _, _), _, _) as aexp -> - ids := IdSet.add id !ids; - aexp + ids := IdSet.add id !ids; + aexp | aexp -> aexp in ignore (fold_aexp collect_lets aexp); @@ -299,18 +307,19 @@ let new_shadow id = shadow_id let rec no_shadow ids (AE_aux (aexp, env, l)) = - let aexp = match aexp with + let aexp = + match aexp with | AE_val aval -> AE_val aval | AE_app (id, avals, typ) -> AE_app (id, avals, typ) | AE_typ (aexp, typ) -> AE_typ (no_shadow ids aexp, typ) | AE_assign (alexp, aexp) -> AE_assign (alexp, no_shadow ids aexp) | AE_let (mut, id, typ1, aexp1, aexp2, typ2) when IdSet.mem id ids -> - let shadow_id = new_shadow id in - let aexp1 = no_shadow ids aexp1 in - let ids = IdSet.add shadow_id ids in - AE_let (mut, shadow_id, typ1, aexp1, no_shadow ids (aexp_rename id shadow_id aexp2), typ2) + let shadow_id = new_shadow id in + let aexp1 = no_shadow ids aexp1 in + let ids = IdSet.add shadow_id ids in + AE_let (mut, shadow_id, typ1, aexp1, no_shadow ids (aexp_rename id shadow_id aexp2), typ2) | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> - AE_let (mut, id, typ1, no_shadow ids aexp1, no_shadow (IdSet.add id ids) aexp2, typ2) + AE_let (mut, id, typ1, no_shadow ids aexp1, no_shadow (IdSet.add id ids) aexp2, typ2) | AE_block (aexps, aexp, typ) -> AE_block (List.map (no_shadow ids) aexps, no_shadow ids aexp, typ) | AE_return (aval, typ) -> AE_return (aval, typ) | AE_exit (aval, typ) -> AE_exit (aval, typ) @@ -321,15 +330,15 @@ let rec no_shadow ids (AE_aux (aexp, env, l)) = | AE_try (aexp, apexps, typ) -> AE_try (no_shadow ids aexp, List.map (no_shadow_apexp ids) apexps, typ) | AE_struct_update (aval, avals, typ) -> AE_struct_update (aval, avals, typ) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) when IdSet.mem id ids -> - let shadow_id = new_shadow id in - let aexp1 = no_shadow ids aexp1 in - let aexp2 = no_shadow ids aexp2 in - let aexp3 = no_shadow ids aexp3 in - let ids = IdSet.add shadow_id ids in - AE_for (shadow_id, aexp1, aexp2, aexp3, order, no_shadow ids (aexp_rename id shadow_id aexp4)) + let shadow_id = new_shadow id in + let aexp1 = no_shadow ids aexp1 in + let aexp2 = no_shadow ids aexp2 in + let aexp3 = no_shadow ids aexp3 in + let ids = IdSet.add shadow_id ids in + AE_for (shadow_id, aexp1, aexp2, aexp3, order, no_shadow ids (aexp_rename id shadow_id aexp4)) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - let ids = IdSet.add id ids in - AE_for (id, no_shadow ids aexp1, no_shadow ids aexp2, no_shadow ids aexp3, order, no_shadow ids aexp4) + let ids = IdSet.add id ids in + AE_for (id, no_shadow ids aexp1, no_shadow ids aexp2, no_shadow ids aexp3, order, no_shadow ids aexp4) | AE_loop (loop, aexp1, aexp2) -> AE_loop (loop, no_shadow ids aexp1, no_shadow ids aexp2) | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, no_shadow ids aexp) in @@ -337,7 +346,7 @@ let rec no_shadow ids (AE_aux (aexp, env, l)) = and no_shadow_apexp ids (apat, aexp1, aexp2) = let shadows = IdSet.inter (apat_bindings apat) ids in - let shadows = List.map (fun id -> id, new_shadow id) (IdSet.elements shadows) in + let shadows = List.map (fun id -> (id, new_shadow id)) (IdSet.elements shadows) in let rename aexp = List.fold_left (fun aexp (from_id, to_id) -> aexp_rename from_id to_id aexp) aexp shadows in let rename_apat apat = List.fold_left (fun apat (from_id, to_id) -> apat_rename from_id to_id apat) apat shadows in let ids = IdSet.union (apat_bindings apat) (IdSet.union ids (IdSet.of_list (List.map snd shadows))) in @@ -346,53 +355,58 @@ and no_shadow_apexp ids (apat, aexp1, aexp2) = (* Map over all the avals in an aexp. *) let rec map_aval f (AE_aux (aexp, env, l)) = - let aexp = match aexp with + let aexp = + match aexp with | AE_val v -> AE_val (f env l v) | AE_typ (aexp, typ) -> AE_typ (map_aval f aexp, typ) | AE_assign (alexp, aexp) -> AE_assign (alexp, map_aval f aexp) | AE_app (id, vs, typ) -> AE_app (id, List.map (f env l) vs, typ) - | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> - AE_let (mut, id, typ1, map_aval f aexp1, map_aval f aexp2, typ2) + | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, map_aval f aexp1, map_aval f aexp2, typ2) | AE_block (aexps, aexp, typ) -> AE_block (List.map (map_aval f) aexps, map_aval f aexp, typ) | AE_return (aval, typ) -> AE_return (f env l aval, typ) | AE_exit (aval, typ) -> AE_exit (f env l aval, typ) | AE_throw (aval, typ) -> AE_throw (f env l aval, typ) - | AE_if (aval, aexp1, aexp2, typ2) -> - AE_if (f env l aval, map_aval f aexp1, map_aval f aexp2, typ2) + | AE_if (aval, aexp1, aexp2, typ2) -> AE_if (f env l aval, map_aval f aexp1, map_aval f aexp2, typ2) | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, map_aval f aexp1, map_aval f aexp2) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - AE_for (id, map_aval f aexp1, map_aval f aexp2, map_aval f aexp3, order, map_aval f aexp4) - | AE_struct_update (aval, updates, typ) -> - AE_struct_update (f env l aval, Bindings.map (f env l) updates, typ) - | AE_field (aval, field, typ) -> - AE_field (f env l aval, field, typ) + AE_for (id, map_aval f aexp1, map_aval f aexp2, map_aval f aexp3, order, map_aval f aexp4) + | AE_struct_update (aval, updates, typ) -> AE_struct_update (f env l aval, Bindings.map (f env l) updates, typ) + | AE_field (aval, field, typ) -> AE_field (f env l aval, field, typ) | AE_match (aval, cases, typ) -> - AE_match (f env l aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_aval f aexp1, map_aval f aexp2) cases, typ) + AE_match + (f env l aval, List.map (fun (pat, aexp1, aexp2) -> (pat, map_aval f aexp1, map_aval f aexp2)) cases, typ) | AE_try (aexp, cases, typ) -> - AE_try (map_aval f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, map_aval f aexp1, map_aval f aexp2) cases, typ) + AE_try + (map_aval f aexp, List.map (fun (pat, aexp1, aexp2) -> (pat, map_aval f aexp1, map_aval f aexp2)) cases, typ) | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, f env l aval, map_aval f aexp) in AE_aux (aexp, env, l) (* Map over all the functions in an aexp. *) let rec map_functions f (AE_aux (aexp, env, l)) = - let aexp = match aexp with + let aexp = + match aexp with | AE_app (id, vs, typ) -> f env l id vs typ | AE_typ (aexp, typ) -> AE_typ (map_functions f aexp, typ) | AE_assign (alexp, aexp) -> AE_assign (alexp, map_functions f aexp) | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, map_functions f aexp) - | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, map_functions f aexp1, map_functions f aexp2, typ2) + | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> + AE_let (mut, id, typ1, map_functions f aexp1, map_functions f aexp2, typ2) | AE_block (aexps, aexp, typ) -> AE_block (List.map (map_functions f) aexps, map_functions f aexp, typ) - | AE_if (aval, aexp1, aexp2, typ) -> - AE_if (aval, map_functions f aexp1, map_functions f aexp2, typ) + | AE_if (aval, aexp1, aexp2, typ) -> AE_if (aval, map_functions f aexp1, map_functions f aexp2, typ) | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, map_functions f aexp1, map_functions f aexp2) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4) + AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4) | AE_match (aval, cases, typ) -> - AE_match (aval, List.map (fun (pat, aexp1, aexp2) -> pat, map_functions f aexp1, map_functions f aexp2) cases, typ) + AE_match + (aval, List.map (fun (pat, aexp1, aexp2) -> (pat, map_functions f aexp1, map_functions f aexp2)) cases, typ) | AE_try (aexp, cases, typ) -> - AE_try (map_functions f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, map_functions f aexp1, map_functions f aexp2) cases, typ) - | AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _ as v -> v + AE_try + ( map_functions f aexp, + List.map (fun (pat, aexp1, aexp2) -> (pat, map_functions f aexp1, map_functions f aexp2)) cases, + typ + ) + | (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v in AE_aux (aexp, env, l) @@ -402,92 +416,85 @@ let rec map_functions f (AE_aux (aexp, env, l)) = let pp_lvar lvar doc = match lvar with - | Register typ -> - string "[R/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc - | Local (Mutable, typ) -> - string "[M/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc + | Register typ -> string "[R/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc + | Local (Mutable, typ) -> string "[M/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc | Local (Immutable, typ) -> - string "[I/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc - | Enum typ -> - string "[E/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc + string "[I/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc + | Enum typ -> string "[E/" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc | Unbound id -> string "[?" ^^ string (string_of_id id) ^^ string "]" ^^ doc -let pp_annot typ doc = - string "[" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc +let pp_annot typ doc = string "[" ^^ string (string_of_typ typ |> Util.yellow |> Util.clear) ^^ string "]" ^^ doc let pp_order = function | Ord_aux (Ord_inc, _) -> string "inc" | Ord_aux (Ord_dec, _) -> string "dec" | _ -> assert false (* Order types have been specialised, so no polymorphism in C backend. *) -let pp_id id = - string (string_of_id id) +let pp_id id = string (string_of_id id) let rec pp_alexp = function - | AL_id (id, typ) -> - pp_annot typ (pp_id id) - | AL_addr (id, typ) -> - string "*" ^^ parens (pp_annot typ (pp_id id)) - | AL_field (alexp, field) -> - pp_alexp alexp ^^ dot ^^ pp_id field + | AL_id (id, typ) -> pp_annot typ (pp_id id) + | AL_addr (id, typ) -> string "*" ^^ parens (pp_annot typ (pp_id id)) + | AL_field (alexp, field) -> pp_alexp alexp ^^ dot ^^ pp_id field let rec pp_aexp (AE_aux (aexp, _, _)) = match aexp with | AE_val v -> pp_aval v - | AE_typ (aexp, typ) -> - pp_annot typ (string "$" ^^ pp_aexp aexp) - | AE_assign (alexp, aexp) -> - pp_alexp alexp ^^ string " := " ^^ pp_aexp aexp - | AE_app (id, args, typ) -> - pp_annot typ (pp_id id ^^ parens (separate_map (comma ^^ space) pp_aval args)) - | AE_short_circuit (SC_or, aval, aexp) -> - pp_aval aval ^^ string " || " ^^ pp_aexp aexp - | AE_short_circuit (SC_and, aval, aexp) -> - pp_aval aval ^^ string " && " ^^ pp_aexp aexp - | AE_let (mut, id, id_typ, binding, body, typ) -> group - begin - match binding with - | AE_aux (AE_let _, _, _) -> - (pp_annot typ (separate space [string "let"; pp_annot id_typ (pp_id id); string "="]) - ^^ hardline ^^ nest 2 (pp_aexp binding)) - ^^ hardline ^^ string "in" ^^ space ^^ pp_aexp body - | _ -> - pp_annot typ (separate space [string "let"; pp_annot id_typ (pp_id id); string "="; pp_aexp binding; string "in"]) - ^^ hardline ^^ pp_aexp body - end + | AE_typ (aexp, typ) -> pp_annot typ (string "$" ^^ pp_aexp aexp) + | AE_assign (alexp, aexp) -> pp_alexp alexp ^^ string " := " ^^ pp_aexp aexp + | AE_app (id, args, typ) -> pp_annot typ (pp_id id ^^ parens (separate_map (comma ^^ space) pp_aval args)) + | AE_short_circuit (SC_or, aval, aexp) -> pp_aval aval ^^ string " || " ^^ pp_aexp aexp + | AE_short_circuit (SC_and, aval, aexp) -> pp_aval aval ^^ string " && " ^^ pp_aexp aexp + | AE_let (mut, id, id_typ, binding, body, typ) -> + group + begin + match binding with + | AE_aux (AE_let _, _, _) -> + (pp_annot typ (separate space [string "let"; pp_annot id_typ (pp_id id); string "="]) + ^^ hardline + ^^ nest 2 (pp_aexp binding) + ) + ^^ hardline ^^ string "in" ^^ space ^^ pp_aexp body + | _ -> + pp_annot typ + (separate space [string "let"; pp_annot id_typ (pp_id id); string "="; pp_aexp binding; string "in"]) + ^^ hardline ^^ pp_aexp body + end | AE_if (cond, then_aexp, else_aexp, typ) -> - pp_annot typ (separate space [ string "if"; pp_aval cond; - string "then"; pp_aexp then_aexp; - string "else"; pp_aexp else_aexp ]) - | AE_block (aexps, aexp, typ) -> - pp_annot typ (surround 2 0 lbrace (pp_block (aexps @ [aexp])) rbrace) + pp_annot typ + (separate space [string "if"; pp_aval cond; string "then"; pp_aexp then_aexp; string "else"; pp_aexp else_aexp]) + | AE_block (aexps, aexp, typ) -> pp_annot typ (surround 2 0 lbrace (pp_block (aexps @ [aexp])) rbrace) | AE_return (v, typ) -> pp_annot typ (string "return" ^^ parens (pp_aval v)) | AE_exit (v, typ) -> pp_annot typ (string "exit" ^^ parens (pp_aval v)) | AE_throw (v, typ) -> pp_annot typ (string "throw" ^^ parens (pp_aval v)) - | AE_loop (While, aexp1, aexp2) -> - separate space [string "while"; pp_aexp aexp1; string "do"; pp_aexp aexp2] - | AE_loop (Until, aexp1, aexp2) -> - separate space [string "repeat"; pp_aexp aexp2; string "until"; pp_aexp aexp1] + | AE_loop (While, aexp1, aexp2) -> separate space [string "while"; pp_aexp aexp1; string "do"; pp_aexp aexp2] + | AE_loop (Until, aexp1, aexp2) -> separate space [string "repeat"; pp_aexp aexp2; string "until"; pp_aexp aexp1] | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - let header = - string "foreach" ^^ space ^^ - group (parens (separate (break 1) - [ pp_id id; - string "from " ^^ pp_aexp aexp1; - string "to " ^^ pp_aexp aexp2; - string "by " ^^ pp_aexp aexp3; - string "in " ^^ pp_order order ])) - in - header ^//^ pp_aexp aexp4 + let header = + string "foreach" ^^ space + ^^ group + (parens + (separate (break 1) + [ + pp_id id; + string "from " ^^ pp_aexp aexp1; + string "to " ^^ pp_aexp aexp2; + string "by " ^^ pp_aexp aexp3; + string "in " ^^ pp_order order; + ] + ) + ) + in + header ^//^ pp_aexp aexp4 | AE_field (aval, field, typ) -> pp_annot typ (parens (pp_aval aval ^^ string "." ^^ pp_id field)) - | AE_match (aval, cases, typ) -> - pp_annot typ (separate space [string "match"; pp_aval aval; pp_cases cases]) - | AE_try (aexp, cases, typ) -> - pp_annot typ (separate space [string "try"; pp_aexp aexp; pp_cases cases]) + | AE_match (aval, cases, typ) -> pp_annot typ (separate space [string "match"; pp_aval aval; pp_cases cases]) + | AE_try (aexp, cases, typ) -> pp_annot typ (separate space [string "try"; pp_aexp aexp; pp_cases cases]) | AE_struct_update (aval, updates, typ) -> - braces (pp_aval aval ^^ string " with " - ^^ separate (string ", ") (List.map (fun (id, aval) -> pp_id id ^^ string " = " ^^ pp_aval aval) - (Bindings.bindings updates))) + braces + (pp_aval aval ^^ string " with " + ^^ separate (string ", ") + (List.map (fun (id, aval) -> pp_id id ^^ string " = " ^^ pp_aval aval) (Bindings.bindings updates)) + ) and pp_apat (AP_aux (apat_aux, _, _)) = match apat_aux with @@ -502,8 +509,7 @@ and pp_apat (AP_aux (apat_aux, _, _)) = and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace -and pp_case (apat, guard, body) = - separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body] +and pp_case (apat, guard, body) = separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body] and pp_block = function | [] -> string "()" @@ -515,16 +521,17 @@ and pp_aval = function | AV_id (id, lvar) -> pp_lvar lvar (pp_id id) | AV_tuple avals -> parens (separate_map (comma ^^ space) pp_aval avals) | AV_ref (id, lvar) -> string "ref" ^^ space ^^ pp_lvar lvar (pp_id id) - | AV_cval (cval, typ) -> - pp_annot typ (string (string_of_cval cval |> Util.cyan |> Util.clear)) - | AV_vector (avals, typ) -> - pp_annot typ (string "[" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "]") - | AV_list (avals, typ) -> - pp_annot typ (string "[|" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "|]") + | AV_cval (cval, typ) -> pp_annot typ (string (string_of_cval cval |> Util.cyan |> Util.clear)) + | AV_vector (avals, typ) -> pp_annot typ (string "[" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "]") + | AV_list (avals, typ) -> pp_annot typ (string "[|" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "|]") | AV_record (fields, typ) -> - pp_annot typ (string "struct {" - ^^ separate_map (comma ^^ space) (fun (id, field) -> pp_id id ^^ string " = " ^^ pp_aval field) (Bindings.bindings fields) - ^^ string "}") + pp_annot typ + (string "struct {" + ^^ separate_map (comma ^^ space) + (fun (id, field) -> pp_id id ^^ string " = " ^^ pp_aval field) + (Bindings.bindings fields) + ^^ string "}" + ) [@@@coverage on] @@ -532,34 +539,38 @@ let ae_lit lit typ = AE_val (AV_lit (lit, typ)) let is_dead_aexp (AE_aux (_, env, _)) = prove __POS__ env nc_false -let (gensym, reset_anf_counter) = symbol_generator "ga" +let gensym, reset_anf_counter = symbol_generator "ga" let rec split_block l = function - | [exp] -> [], exp + | [exp] -> ([], exp) | exp :: exps -> - let exps, last = split_block l exps in - exp :: exps, last - | [] -> - Reporting.unreachable l __POS__ "empty block found when converting to ANF" [@coverage off] + let exps, last = split_block l exps in + (exp :: exps, last) + | [] -> Reporting.unreachable l __POS__ "empty block found when converting to ANF" [@coverage off] -let rec anf_pat ?global:(global=false) (P_aux (p_aux, annot) as pat) = +let rec anf_pat ?(global = false) (P_aux (p_aux, annot) as pat) = let mk_apat aux = AP_aux (aux, env_of_annot annot, fst annot) in match p_aux with | P_id id when global -> mk_apat (AP_global (id, typ_of_pat pat)) | P_id id -> mk_apat (AP_id (id, typ_of_pat pat)) | P_wild -> mk_apat (AP_wild (typ_of_pat pat)) - | P_tuple pats -> mk_apat (AP_tuple (List.map (fun pat -> anf_pat ~global:global pat) pats)) - | P_app (id, [subpat]) -> mk_apat (AP_app (id, anf_pat ~global:global subpat, typ_of_pat pat)) - | P_app (id, pats) -> mk_apat (AP_app (id, mk_apat (AP_tuple (List.map (fun pat -> anf_pat ~global:global pat) pats)), typ_of_pat pat)) - | P_typ (_, pat) -> anf_pat ~global:global pat - | P_var (pat, _) -> anf_pat ~global:global pat - | P_cons (hd_pat, tl_pat) -> mk_apat (AP_cons (anf_pat ~global:global hd_pat, anf_pat ~global:global tl_pat)) - | P_list pats -> List.fold_right (fun pat apat -> mk_apat (AP_cons (anf_pat ~global:global pat, apat))) pats (mk_apat (AP_nil (typ_of_pat pat))) + | P_tuple pats -> mk_apat (AP_tuple (List.map (fun pat -> anf_pat ~global pat) pats)) + | P_app (id, [subpat]) -> mk_apat (AP_app (id, anf_pat ~global subpat, typ_of_pat pat)) + | P_app (id, pats) -> + mk_apat (AP_app (id, mk_apat (AP_tuple (List.map (fun pat -> anf_pat ~global pat) pats)), typ_of_pat pat)) + | P_typ (_, pat) -> anf_pat ~global pat + | P_var (pat, _) -> anf_pat ~global pat + | P_cons (hd_pat, tl_pat) -> mk_apat (AP_cons (anf_pat ~global hd_pat, anf_pat ~global tl_pat)) + | P_list pats -> + List.fold_right + (fun pat apat -> mk_apat (AP_cons (anf_pat ~global pat, apat))) + pats + (mk_apat (AP_nil (typ_of_pat pat))) | P_lit (L_aux (L_unit, _)) -> mk_apat (AP_wild (typ_of_pat pat)) - | P_as (pat, id) -> mk_apat (AP_as (anf_pat ~global:global pat, id, typ_of_pat pat)) + | P_as (pat, id) -> mk_apat (AP_as (anf_pat ~global pat, id, typ_of_pat pat)) | _ -> - Reporting.unreachable (fst annot) __POS__ - ("Could not convert pattern to ANF: " ^ string_of_pat pat) [@coverage off] + Reporting.unreachable (fst annot) __POS__ + ("Could not convert pattern to ANF: " ^ string_of_pat pat) [@coverage off] let rec apat_globals (AP_aux (aux, _, _)) = match aux with @@ -575,18 +586,16 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) = let rec anf_lexp env (LE_aux (aux, (l, _)) as lexp) = match aux with - | LE_id id | LE_typ (_, id) -> - (fun x -> x), AL_id (id, lvar_typ ~loc:l (Env.lookup_id id env)) + | LE_id id | LE_typ (_, id) -> ((fun x -> x), AL_id (id, lvar_typ ~loc:l (Env.lookup_id id env))) | LE_field (lexp, field_id) -> - let wrap, alexp = anf_lexp env lexp in - wrap, AL_field (alexp, field_id) + let wrap, alexp = anf_lexp env lexp in + (wrap, AL_field (alexp, field_id)) | LE_deref dexp -> - let gs = gensym () in - (fun x -> mk_aexp (AE_let (Mutable, gs, typ_of dexp, anf dexp, x, unit_typ))), - AL_addr (gs, typ_of dexp) + let gs = gensym () in + ((fun x -> mk_aexp (AE_let (Mutable, gs, typ_of dexp, anf dexp, x, unit_typ))), AL_addr (gs, typ_of dexp)) | _ -> - Reporting.unreachable l __POS__ - ("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF") [@coverage off] + Reporting.unreachable l __POS__ + ("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF") [@coverage off] in let to_aval (AE_aux (aexp_aux, env, _) as aexp) = @@ -594,221 +603,186 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) = match aexp_aux with | AE_val v -> (v, fun x -> x) | AE_short_circuit (_, _, _) -> - let id = gensym () in - (AV_id (id, Local (Immutable, bool_typ)), fun x -> mk_aexp x (AE_let (Immutable, id, bool_typ, aexp, x, typ_of exp))) + let id = gensym () in + ( AV_id (id, Local (Immutable, bool_typ)), + fun x -> mk_aexp x (AE_let (Immutable, id, bool_typ, aexp, x, typ_of exp)) + ) | AE_app (_, _, typ) - | AE_let (_, _, _, _, _, typ) - | AE_return (_, typ) - | AE_exit (_, typ) - | AE_throw (_, typ) - | AE_typ (_, typ) - | AE_if (_, _, _, typ) - | AE_field (_, _, typ) - | AE_match (_, _, typ) - | AE_try (_, _, typ) - | AE_struct_update (_, _, typ) - | AE_block (_, _, typ) -> - let id = gensym () in - (AV_id (id, Local (Immutable, typ)), fun x -> mk_aexp x (AE_let (Immutable, id, typ, aexp, x, typ_of exp))) + | AE_let (_, _, _, _, _, typ) + | AE_return (_, typ) + | AE_exit (_, typ) + | AE_throw (_, typ) + | AE_typ (_, typ) + | AE_if (_, _, _, typ) + | AE_field (_, _, typ) + | AE_match (_, _, typ) + | AE_try (_, _, typ) + | AE_struct_update (_, _, typ) + | AE_block (_, _, typ) -> + let id = gensym () in + (AV_id (id, Local (Immutable, typ)), fun x -> mk_aexp x (AE_let (Immutable, id, typ, aexp, x, typ_of exp))) | AE_assign _ | AE_for _ | AE_loop _ -> - let id = gensym () in - (AV_id (id, Local (Immutable, unit_typ)), fun x -> mk_aexp x (AE_let (Immutable, id, unit_typ, aexp, x, typ_of exp))) + let id = gensym () in + ( AV_id (id, Local (Immutable, unit_typ)), + fun x -> mk_aexp x (AE_let (Immutable, id, unit_typ, aexp, x, typ_of exp)) + ) in match e_aux with | E_lit lit -> mk_aexp (ae_lit lit (typ_of exp)) - | E_block [] -> - Reporting.warn "" l - "Translating empty block (possibly assigning to an uninitialized variable at the end of a block?)"; - mk_aexp (ae_lit (L_aux (L_unit, l)) (typ_of exp)) + Reporting.warn "" l + "Translating empty block (possibly assigning to an uninitialized variable at the end of a block?)"; + mk_aexp (ae_lit (L_aux (L_unit, l)) (typ_of exp)) | E_block exps -> - let exps, last = split_block l exps in - let aexps = List.map anf exps in - let alast = anf last in - mk_aexp (AE_block (aexps, alast, typ_of exp)) - + let exps, last = split_block l exps in + let aexps = List.map anf exps in + let alast = anf last in + mk_aexp (AE_block (aexps, alast, typ_of exp)) | E_assign (lexp, assign_exp) -> - let aexp = anf assign_exp in - let wrap, alexp = anf_lexp (env_of exp) lexp in - wrap (mk_aexp (AE_assign (alexp, aexp))) - + let aexp = anf assign_exp in + let wrap, alexp = anf_lexp (env_of exp) lexp in + wrap (mk_aexp (AE_assign (alexp, aexp))) | E_loop (loop_typ, _, cond, exp) -> - let acond = anf cond in - let aexp = anf exp in - mk_aexp (AE_loop (loop_typ, acond, aexp)) - + let acond = anf cond in + let aexp = anf exp in + mk_aexp (AE_loop (loop_typ, acond, aexp)) | E_for (id, exp1, exp2, exp3, order, body) -> - let aexp1, aexp2, aexp3, abody = anf exp1, anf exp2, anf exp3, anf body in - mk_aexp (AE_for (id, aexp1, aexp2, aexp3, order, abody)) - + let aexp1, aexp2, aexp3, abody = (anf exp1, anf exp2, anf exp3, anf body) in + mk_aexp (AE_for (id, aexp1, aexp2, aexp3, order, abody)) | E_if (cond, then_exp, else_exp) -> - let cond_val, wrap = to_aval (anf cond) in - let then_aexp = anf then_exp in - let else_aexp = anf else_exp in - wrap (mk_aexp (AE_if (cond_val, then_aexp, else_aexp, typ_of exp))) - - | E_app_infix (x, Id_aux (Id op, l), y) -> - anf (E_aux (E_app (Id_aux (Operator op, l), [x; y]), exp_annot)) - | E_app_infix (x, Id_aux (Operator op, l), y) -> - anf (E_aux (E_app (Id_aux (Id op, l), [x; y]), exp_annot)) - + let cond_val, wrap = to_aval (anf cond) in + let then_aexp = anf then_exp in + let else_aexp = anf else_exp in + wrap (mk_aexp (AE_if (cond_val, then_aexp, else_aexp, typ_of exp))) + | E_app_infix (x, Id_aux (Id op, l), y) -> anf (E_aux (E_app (Id_aux (Operator op, l), [x; y]), exp_annot)) + | E_app_infix (x, Id_aux (Operator op, l), y) -> anf (E_aux (E_app (Id_aux (Id op, l), [x; y]), exp_annot)) | E_vector exps -> - let aexps = List.map anf exps in - let avals = List.map to_aval aexps in - let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in - wrap (mk_aexp (AE_val (AV_vector (List.map fst avals, typ_of exp)))) - + let aexps = List.map anf exps in + let avals = List.map to_aval aexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in + wrap (mk_aexp (AE_val (AV_vector (List.map fst avals, typ_of exp)))) | E_list exps -> - let aexps = List.map anf exps in - let avals = List.map to_aval aexps in - let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in - wrap (mk_aexp (AE_val (AV_list (List.map fst avals, typ_of exp)))) - + let aexps = List.map anf exps in + let avals = List.map to_aval aexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in + wrap (mk_aexp (AE_val (AV_list (List.map fst avals, typ_of exp)))) | E_field (field_exp, id) -> - let aval, wrap = to_aval (anf field_exp) in - wrap (mk_aexp (AE_field (aval, id, typ_of exp))) - + let aval, wrap = to_aval (anf field_exp) in + wrap (mk_aexp (AE_field (aval, id, typ_of exp))) | E_struct_update (exp, fexps) -> - let anf_fexp (FE_aux (FE_fexp (id, exp), _)) = - let aval, wrap = to_aval (anf exp) in - (id, aval), wrap - in - let aval, exp_wrap = to_aval (anf exp) in - let fexps = List.map anf_fexp fexps in - let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd fexps) in - let record = List.fold_left (fun r (id, aval) -> Bindings.add id aval r) Bindings.empty (List.map fst fexps) in - exp_wrap (wrap (mk_aexp (AE_struct_update (aval, record, typ_of exp)))) - + let anf_fexp (FE_aux (FE_fexp (id, exp), _)) = + let aval, wrap = to_aval (anf exp) in + ((id, aval), wrap) + in + let aval, exp_wrap = to_aval (anf exp) in + let fexps = List.map anf_fexp fexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd fexps) in + let record = List.fold_left (fun r (id, aval) -> Bindings.add id aval r) Bindings.empty (List.map fst fexps) in + exp_wrap (wrap (mk_aexp (AE_struct_update (aval, record, typ_of exp)))) | E_app (id, [exp1; exp2]) when string_of_id id = "and_bool" -> - let aexp1 = anf exp1 in - let aexp2 = anf exp2 in - let aval1, wrap = to_aval aexp1 in - wrap (mk_aexp (AE_short_circuit (SC_and, aval1, aexp2))) - + let aexp1 = anf exp1 in + let aexp2 = anf exp2 in + let aval1, wrap = to_aval aexp1 in + wrap (mk_aexp (AE_short_circuit (SC_and, aval1, aexp2))) | E_app (id, [exp1; exp2]) when string_of_id id = "or_bool" -> - let aexp1 = anf exp1 in - let aexp2 = anf exp2 in - let aval1, wrap = to_aval aexp1 in - wrap (mk_aexp (AE_short_circuit (SC_or, aval1, aexp2))) - + let aexp1 = anf exp1 in + let aexp2 = anf exp2 in + let aval1, wrap = to_aval aexp1 in + wrap (mk_aexp (AE_short_circuit (SC_or, aval1, aexp2))) | E_app (id, exps) -> - let aexps = List.map anf exps in - let avals = List.map to_aval aexps in - let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in - wrap (mk_aexp (AE_app (id, List.map fst avals, typ_of exp))) - + let aexps = List.map anf exps in + let avals = List.map to_aval aexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in + wrap (mk_aexp (AE_app (id, List.map fst avals, typ_of exp))) | E_throw exn_exp -> - let aexp = anf exn_exp in - let aval, wrap = to_aval aexp in - wrap (mk_aexp (AE_throw (aval, typ_of exp))) - + let aexp = anf exn_exp in + let aval, wrap = to_aval aexp in + wrap (mk_aexp (AE_throw (aval, typ_of exp))) | E_exit exp -> - let aexp = anf exp in - let aval, wrap = to_aval aexp in - wrap (mk_aexp (AE_exit (aval, typ_of exp))) - + let aexp = anf exp in + let aval, wrap = to_aval aexp in + wrap (mk_aexp (AE_exit (aval, typ_of exp))) | E_return ret_exp -> - let aexp = anf ret_exp in - let aval, wrap = to_aval aexp in - wrap (mk_aexp (AE_return (aval, typ_of exp))) - + let aexp = anf ret_exp in + let aval, wrap = to_aval aexp in + wrap (mk_aexp (AE_return (aval, typ_of exp))) | E_assert (exp1, exp2) -> - let aexp1 = anf exp1 in - let aexp2 = anf exp2 in - let aval1, wrap1 = to_aval aexp1 in - let aval2, wrap2 = to_aval aexp2 in - wrap1 (wrap2 (mk_aexp (AE_app (mk_id "sail_assert", [aval1; aval2], unit_typ)))) - + let aexp1 = anf exp1 in + let aexp2 = anf exp2 in + let aval1, wrap1 = to_aval aexp1 in + let aval2, wrap2 = to_aval aexp2 in + wrap1 (wrap2 (mk_aexp (AE_app (mk_id "sail_assert", [aval1; aval2], unit_typ)))) | E_cons (exp1, exp2) -> - let aexp1 = anf exp1 in - let aexp2 = anf exp2 in - let aval1, wrap1 = to_aval aexp1 in - let aval2, wrap2 = to_aval aexp2 in - wrap1 (wrap2 (mk_aexp (AE_app (mk_id "sail_cons", [aval1; aval2], typ_of exp)))) - + let aexp1 = anf exp1 in + let aexp2 = anf exp2 in + let aval1, wrap1 = to_aval aexp1 in + let aval2, wrap2 = to_aval aexp2 in + wrap1 (wrap2 (mk_aexp (AE_app (mk_id "sail_cons", [aval1; aval2], typ_of exp)))) | E_id id -> - let lvar = Env.lookup_id id (env_of exp) in - begin match lvar with - | _ -> mk_aexp (AE_val (AV_id (id, lvar))) - end - + let lvar = Env.lookup_id id (env_of exp) in + begin + match lvar with _ -> mk_aexp (AE_val (AV_id (id, lvar))) + end | E_ref id -> - let lvar = Env.lookup_id id (env_of exp) in - mk_aexp (AE_val (AV_ref (id, lvar))) - + let lvar = Env.lookup_id id (env_of exp) in + mk_aexp (AE_val (AV_ref (id, lvar))) | E_match (match_exp, pexps) -> - let match_aval, match_wrap = to_aval (anf match_exp) in - let anf_pexp (Pat_aux (pat_aux, _)) = - match pat_aux with - | Pat_when (pat, guard, body) -> - (anf_pat pat, anf guard, anf body) - | Pat_exp (pat, body) -> - (anf_pat pat, mk_aexp (AE_val (AV_lit (mk_lit (L_true), bool_typ))), anf body) - in - match_wrap (mk_aexp (AE_match (match_aval, List.map anf_pexp pexps, typ_of exp))) - + let match_aval, match_wrap = to_aval (anf match_exp) in + let anf_pexp (Pat_aux (pat_aux, _)) = + match pat_aux with + | Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body) + | Pat_exp (pat, body) -> (anf_pat pat, mk_aexp (AE_val (AV_lit (mk_lit L_true, bool_typ))), anf body) + in + match_wrap (mk_aexp (AE_match (match_aval, List.map anf_pexp pexps, typ_of exp))) | E_try (match_exp, pexps) -> - let match_aexp = anf match_exp in - let anf_pexp (Pat_aux (pat_aux, _)) = - match pat_aux with - | Pat_when (pat, guard, body) -> - (anf_pat pat, anf guard, anf body) - | Pat_exp (pat, body) -> - (anf_pat pat, mk_aexp (AE_val (AV_lit (mk_lit (L_true), bool_typ))), anf body) - in - mk_aexp (AE_try (match_aexp, List.map anf_pexp pexps, typ_of exp)) - + let match_aexp = anf match_exp in + let anf_pexp (Pat_aux (pat_aux, _)) = + match pat_aux with + | Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body) + | Pat_exp (pat, body) -> (anf_pat pat, mk_aexp (AE_val (AV_lit (mk_lit L_true, bool_typ))), anf body) + in + mk_aexp (AE_try (match_aexp, List.map anf_pexp pexps, typ_of exp)) | E_var (LE_aux (LE_id id, _), binding, body) - | E_var (LE_aux (LE_typ (_, id), _), binding, body) - | E_let (LB_aux (LB_val (P_aux (P_id id, _), binding), _), body) - | E_let (LB_aux (LB_val (P_aux (P_typ (_, P_aux (P_id id, _)), _), binding), _), body) -> - let env = env_of body in - let lvar = Env.lookup_id id env in - mk_aexp (AE_let (Mutable, id, lvar_typ ~loc:l lvar, anf binding, anf body, typ_of exp)) - + | E_var (LE_aux (LE_typ (_, id), _), binding, body) + | E_let (LB_aux (LB_val (P_aux (P_id id, _), binding), _), body) + | E_let (LB_aux (LB_val (P_aux (P_typ (_, P_aux (P_id id, _)), _), binding), _), body) -> + let env = env_of body in + let lvar = Env.lookup_id id env in + mk_aexp (AE_let (Mutable, id, lvar_typ ~loc:l lvar, anf binding, anf body, typ_of exp)) | E_var (lexp, _, _) -> - Reporting.unreachable l __POS__ - ("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF") [@coverage off] - + Reporting.unreachable l __POS__ + ("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF") [@coverage off] | E_let (LB_aux (LB_val (pat, binding), _), body) -> - anf (E_aux (E_match (binding, [Pat_aux (Pat_exp (pat, body), (Parse_ast.Unknown, empty_tannot))]), exp_annot)) - + anf (E_aux (E_match (binding, [Pat_aux (Pat_exp (pat, body), (Parse_ast.Unknown, empty_tannot))]), exp_annot)) | E_tuple exps -> - let aexps = List.map anf exps in - let avals = List.map to_aval aexps in - let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in - wrap (mk_aexp (AE_val (AV_tuple (List.map fst avals)))) - + let aexps = List.map anf exps in + let avals = List.map to_aval aexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd avals) in + wrap (mk_aexp (AE_val (AV_tuple (List.map fst avals)))) | E_struct fexps -> - let anf_fexp (FE_aux (FE_fexp (id, exp), _)) = - let aval, wrap = to_aval (anf exp) in - (id, aval), wrap - in - let fexps = List.map anf_fexp fexps in - let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd fexps) in - let record = List.fold_left (fun r (id, aval) -> Bindings.add id aval r) Bindings.empty (List.map fst fexps) in - wrap (mk_aexp (AE_val (AV_record (record, typ_of exp)))) - + let anf_fexp (FE_aux (FE_fexp (id, exp), _)) = + let aval, wrap = to_aval (anf exp) in + ((id, aval), wrap) + in + let fexps = List.map anf_fexp fexps in + let wrap = List.fold_left (fun f g x -> f (g x)) (fun x -> x) (List.map snd fexps) in + let record = List.fold_left (fun r (id, aval) -> Bindings.add id aval r) Bindings.empty (List.map fst fexps) in + wrap (mk_aexp (AE_val (AV_record (record, typ_of exp)))) | E_typ (typ, exp) -> mk_aexp (AE_typ (anf exp, typ)) - | E_internal_assume (_nc, exp) -> anf exp - | E_vector_access _ | E_vector_subrange _ | E_vector_update _ | E_vector_update_subrange _ | E_vector_append _ -> - (* Should be re-written by type checker *) - Reporting.unreachable l __POS__ "encountered raw vector operation when converting to ANF" [@coverage off] - + (* Should be re-written by type checker *) + Reporting.unreachable l __POS__ "encountered raw vector operation when converting to ANF" [@coverage off] | E_internal_value _ -> - (* Interpreter specific *) - Reporting.unreachable l __POS__ "encountered E_internal_value when converting to ANF" [@coverage off] - + (* Interpreter specific *) + Reporting.unreachable l __POS__ "encountered E_internal_value when converting to ANF" [@coverage off] | E_sizeof nexp -> - (* Sizeof nodes removed by sizeof rewriting pass *) - Reporting.unreachable l __POS__ ("encountered E_sizeof node " ^ string_of_nexp nexp ^ " when converting to ANF") [@coverage off] - + (* Sizeof nodes removed by sizeof rewriting pass *) + Reporting.unreachable l __POS__ + ("encountered E_sizeof node " ^ string_of_nexp nexp ^ " when converting to ANF") [@coverage off] | E_constraint _ -> - (* Sizeof nodes removed by sizeof rewriting pass *) - Reporting.unreachable l __POS__ "encountered E_constraint node when converting to ANF" [@coverage off] - + (* Sizeof nodes removed by sizeof rewriting pass *) + Reporting.unreachable l __POS__ "encountered E_constraint node when converting to ANF" [@coverage off] | E_internal_return _ | E_internal_plet _ -> - Reporting.unreachable l __POS__ "encountered unexpected internal node when converting to ANF" [@coverage off] + Reporting.unreachable l __POS__ "encountered unexpected internal node when converting to ANF" [@coverage off] diff --git a/src/lib/anf.mli b/src/lib/anf.mli index 3912ad37e..80bb55d94 100644 --- a/src/lib/anf.mli +++ b/src/lib/anf.mli @@ -102,11 +102,11 @@ type 'a aexp = AE_aux of 'a aexp_aux * Env.t * l and 'a aexp_aux = | AE_val of 'a aval - | AE_app of id * ('a aval) list * 'a + | AE_app of id * 'a aval list * 'a | AE_typ of 'a aexp * 'a | AE_assign of 'a alexp * 'a aexp | AE_let of mut * id * 'a * 'a aexp * 'a aexp * 'a - | AE_block of ('a aexp) list * 'a aexp * 'a + | AE_block of 'a aexp list * 'a aexp * 'a | AE_return of 'a aval * 'a | AE_exit of 'a aval * 'a | AE_throw of 'a aval * 'a @@ -114,7 +114,7 @@ and 'a aexp_aux = | AE_field of 'a aval * id * 'a | AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp) list * 'a | AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp) list * 'a - | AE_struct_update of 'a aval * ('a aval) Bindings.t * 'a + | AE_struct_update of 'a aval * 'a aval Bindings.t * 'a | AE_for of id * 'a aexp * 'a aexp * 'a aexp * order * 'a aexp | AE_loop of loop * 'a aexp * 'a aexp | AE_short_circuit of sc_op * 'a aval * 'a aexp @@ -124,7 +124,7 @@ and sc_op = SC_and | SC_or and 'a apat = AP_aux of 'a apat_aux * Env.t * l and 'a apat_aux = - | AP_tuple of ('a apat) list + | AP_tuple of 'a apat list | AP_id of id * 'a | AP_global of id * 'a | AP_app of id * 'a apat * 'a @@ -140,25 +140,22 @@ and 'a aval = | AV_lit of lit * 'a | AV_id of id * 'a lvar | AV_ref of id * 'a lvar - | AV_tuple of ('a aval) list - | AV_list of ('a aval) list * 'a - | AV_vector of ('a aval) list * 'a - | AV_record of ('a aval) Bindings.t * 'a + | AV_tuple of 'a aval list + | AV_list of 'a aval list * 'a + | AV_vector of 'a aval list * 'a + | AV_record of 'a aval Bindings.t * 'a | AV_cval of cval * 'a -and 'a alexp = - | AL_id of id * 'a - | AL_addr of id * 'a - | AL_field of 'a alexp * id - +and 'a alexp = AL_id of id * 'a | AL_addr of id * 'a | AL_field of 'a alexp * id + (** When ANF translation has to introduce new bindings it uses a counter to ensure uniqueness. This function resets that counter. *) val reset_anf_counter : unit -> unit val aexp_loc : 'a aexp -> Parse_ast.l - + (** {2 Functions for transforming ANF expressions} *) - + val aval_typ : typ aval -> typ val aexp_typ : typ aexp -> typ @@ -166,12 +163,12 @@ val aexp_typ : typ aexp -> typ val map_aval : (Env.t -> Ast.l -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp (** Map over all function calls in an ANF expression *) -val map_functions : (Env.t -> Ast.l -> id -> ('a aval) list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp +val map_functions : (Env.t -> Ast.l -> id -> 'a aval list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp val fold_aexp : ('a aexp -> 'a aexp) -> 'a aexp -> 'a aexp val aexp_bindings : 'a aexp -> IdSet.t - + (** Remove all variable shadowing in an ANF expression *) val no_shadow : IdSet.t -> 'a aexp -> 'a aexp diff --git a/src/lib/ast_defs.ml b/src/lib/ast_defs.ml index 7843bc314..3656189c1 100644 --- a/src/lib/ast_defs.ml +++ b/src/lib/ast_defs.ml @@ -67,12 +67,6 @@ open Ast -type 'a ast = { - defs : 'a def list; - comments : (string * Lexer.comment list) list - } +type 'a ast = { defs : 'a def list; comments : (string * Lexer.comment list) list } -let empty_ast = { - defs = []; - comments = [] - } +let empty_ast = { defs = []; comments = [] } diff --git a/src/lib/ast_util.ml b/src/lib/ast_util.ml index 165b25b94..2ccbba138 100644 --- a/src/lib/ast_util.ml +++ b/src/lib/ast_util.ml @@ -71,59 +71,42 @@ open Util module Big_int = Nat_big_num (* The type of annotations for untyped AST nodes *) -type uannot = { - attrs : (l * string * string) list - } +type uannot = { attrs : (l * string * string) list } -let empty_uannot = { - attrs = [] - } +let empty_uannot = { attrs = [] } -let add_attribute l attr arg (annot : uannot) = - { attrs = (l, attr, arg) :: annot.attrs } +let add_attribute l attr arg (annot : uannot) = { attrs = (l, attr, arg) :: annot.attrs } let get_attribute attr annot = - List.find_opt (fun (l, attr', arg) -> attr = attr') annot.attrs - |> Option.map (fun (l, _, arg) -> (l, arg)) + List.find_opt (fun (l, attr', arg) -> attr = attr') annot.attrs |> Option.map (fun (l, _, arg) -> (l, arg)) let get_attributes annot = annot.attrs let find_attribute_opt attr1 attrs = - List.find_opt (fun (_, attr2, _) -> attr1 = attr2) attrs - |> Option.map (fun (_, _, arg) -> arg) - -let mk_def_annot l = { - doc_comment = None; - attrs = []; - loc = l; - } - + List.find_opt (fun (_, attr2, _) -> attr1 = attr2) attrs |> Option.map (fun (_, _, arg) -> arg) + +let mk_def_annot l = { doc_comment = None; attrs = []; loc = l } + let map_clause_annot f (def_annot, annot) = - let (l, annot') = f (def_annot.loc, annot) in + let l, annot' = f (def_annot.loc, annot) in ({ def_annot with loc = l }, annot') -let def_annot_map_loc f (annot : def_annot) = { annot with loc = f annot.loc } - -let add_def_attribute l attr arg (annot : def_annot) = - { annot with attrs = (l, attr, arg) :: annot.attrs } +let def_annot_map_loc f (annot : def_annot) = { annot with loc = f annot.loc } + +let add_def_attribute l attr arg (annot : def_annot) = { annot with attrs = (l, attr, arg) :: annot.attrs } let get_def_attribute attr (annot : def_annot) = - List.find_opt (fun (l, attr', arg) -> attr = attr') annot.attrs - |> Option.map (fun (l, arg, _) -> (l, arg)) + List.find_opt (fun (l, attr', arg) -> attr = attr') annot.attrs |> Option.map (fun (l, arg, _) -> (l, arg)) type mut = Immutable | Mutable type 'a lvar = Register of 'a | Enum of 'a | Local of mut * 'a | Unbound of id -let is_unbound = function - | Unbound _ -> true - | _ -> false - -let string_of_id = function - | Id_aux (Id v, _) -> v - | Id_aux (Operator v, _) -> "(operator " ^ v ^ ")" - -let lvar_typ ?loc:(l=Parse_ast.Unknown) = function +let is_unbound = function Unbound _ -> true | _ -> false + +let string_of_id = function Id_aux (Id v, _) -> v | Id_aux (Operator v, _) -> "(operator " ^ v ^ ")" + +let lvar_typ ?loc:(l = Parse_ast.Unknown) = function | Local (_, typ) -> typ | Register typ -> typ | Enum typ -> typ @@ -131,30 +114,21 @@ let lvar_typ ?loc:(l=Parse_ast.Unknown) = function let no_annot = (Parse_ast.Unknown, empty_uannot) -let id_loc = function - | Id_aux (_, l) -> l +let id_loc = function Id_aux (_, l) -> l + +let kid_loc = function Kid_aux (_, l) -> l -let kid_loc = function - | Kid_aux (_, l) -> l +let kopt_loc = function KOpt_aux (_, l) -> l -let kopt_loc = function - | KOpt_aux (_, l) -> l - -let typ_loc = function - | Typ_aux (_, l) -> l +let typ_loc = function Typ_aux (_, l) -> l -let pat_loc = function - | P_aux (_, (l, _)) -> l +let pat_loc = function P_aux (_, (l, _)) -> l -let exp_loc = function - | E_aux (_, (l, _)) -> l +let exp_loc = function E_aux (_, (l, _)) -> l -let nexp_loc = function - | Nexp_aux (_, l) -> l +let nexp_loc = function Nexp_aux (_, l) -> l -let gen_loc = function - | Parse_ast.Generated l -> Parse_ast.Generated l - | l -> Parse_ast.Generated l +let gen_loc = function Parse_ast.Generated l -> Parse_ast.Generated l | l -> Parse_ast.Generated l let rec is_gen_loc = function | Parse_ast.Unknown -> false @@ -172,21 +146,18 @@ let mk_nc nc_aux = NC_aux (nc_aux, Parse_ast.Unknown) let mk_nexp nexp_aux = Nexp_aux (nexp_aux, Parse_ast.Unknown) -let mk_exp ?loc:(l=Parse_ast.Unknown) exp_aux = E_aux (exp_aux, (l, empty_uannot)) +let mk_exp ?loc:(l = Parse_ast.Unknown) exp_aux = E_aux (exp_aux, (l, empty_uannot)) let unaux_exp (E_aux (exp_aux, _)) = exp_aux let uncast_exp = function - | E_aux (E_internal_return (E_aux (E_typ (typ, exp), _)), a) -> - E_aux (E_internal_return exp, a), Some typ - | E_aux (E_typ (typ, exp), _) -> exp, Some typ - | exp -> exp, None + | E_aux (E_internal_return (E_aux (E_typ (typ, exp), _)), a) -> (E_aux (E_internal_return exp, a), Some typ) + | E_aux (E_typ (typ, exp), _) -> (exp, Some typ) + | exp -> (exp, None) let mk_pat pat_aux = P_aux (pat_aux, no_annot) let unaux_pat (P_aux (pat_aux, _)) = pat_aux -let untyp_pat = function - | P_aux (P_typ (typ, pat), _) -> pat, Some typ - | pat -> pat, None +let untyp_pat = function P_aux (P_typ (typ, pat), _) -> (pat, Some typ) | pat -> (pat, None) -let mk_pexp ?loc:(l=Parse_ast.Unknown) pexp_aux = Pat_aux (pexp_aux, (l, empty_uannot)) +let mk_pexp ?loc:(l = Parse_ast.Unknown) pexp_aux = Pat_aux (pexp_aux, (l, empty_uannot)) let mk_mpat mpat_aux = MP_aux (mpat_aux, no_annot) let mk_mpexp mpexp_aux = MPat_aux (mpexp_aux, no_annot) @@ -199,15 +170,13 @@ let mk_lit lit_aux = L_aux (lit_aux, Parse_ast.Unknown) let mk_lit_exp lit_aux = mk_exp (E_lit (mk_lit lit_aux)) -let mk_funcl ?loc:(l=Parse_ast.Unknown) id pat body = +let mk_funcl ?loc:(l = Parse_ast.Unknown) id pat body = FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (pat, body), (l, empty_uannot))), (mk_def_annot l, empty_uannot)) let mk_qi_nc nc = QI_aux (QI_constraint nc, Parse_ast.Unknown) let mk_qi_id k kid = - let kopt = - KOpt_aux (KOpt_kind (K_aux (k, Parse_ast.Unknown), kid), Parse_ast.Unknown) - in + let kopt = KOpt_aux (KOpt_kind (K_aux (k, Parse_ast.Unknown), kid), Parse_ast.Unknown) in QI_aux (QI_id kopt, Parse_ast.Unknown) let mk_qi_kopt kopt = QI_aux (QI_id kopt, Parse_ast.Unknown) @@ -215,40 +184,26 @@ let mk_qi_kopt kopt = QI_aux (QI_id kopt, Parse_ast.Unknown) let mk_fundef funcls = let tannot_opt = Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown) in let rec_opt = Rec_aux (Rec_nonrec, Parse_ast.Unknown) in - DEF_aux ( - DEF_fundef - (FD_aux (FD_function (rec_opt, tannot_opt, funcls), no_annot)), - mk_def_annot Parse_ast.Unknown - ) + DEF_aux (DEF_fundef (FD_aux (FD_function (rec_opt, tannot_opt, funcls), no_annot)), mk_def_annot Parse_ast.Unknown) let mk_letbind pat exp = LB_aux (LB_val (pat, exp), no_annot) -let mk_val_spec vs_aux = - DEF_aux (DEF_val (VS_aux (vs_aux, no_annot)), mk_def_annot Parse_ast.Unknown) +let mk_val_spec vs_aux = DEF_aux (DEF_val (VS_aux (vs_aux, no_annot)), mk_def_annot Parse_ast.Unknown) let mk_def ?loc:(l = Parse_ast.Unknown) def = DEF_aux (def, mk_def_annot l) - + let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k -let is_int_kopt = function - | KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true - | _ -> false +let is_int_kopt = function KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true | _ -> false -let is_order_kopt = function - | KOpt_aux (KOpt_kind (K_aux (K_order, _), _), _) -> true - | _ -> false +let is_order_kopt = function KOpt_aux (KOpt_kind (K_aux (K_order, _), _), _) -> true | _ -> false -let is_typ_kopt = function - | KOpt_aux (KOpt_kind (K_aux (K_type, _), _), _) -> true - | _ -> false +let is_typ_kopt = function KOpt_aux (KOpt_kind (K_aux (K_type, _), _), _) -> true | _ -> false -let is_bool_kopt = function - | KOpt_aux (KOpt_kind (K_aux (K_bool, _), _), _) -> true - | _ -> false +let is_bool_kopt = function KOpt_aux (KOpt_kind (K_aux (K_bool, _), _), _) -> true | _ -> false -let string_of_kid = function - | Kid_aux (Var v, _) -> v +let string_of_kid = function Kid_aux (Var v, _) -> v module Kid = struct type t = kid @@ -258,22 +213,24 @@ end module Kind = struct type t = kind let compare (K_aux (aux1, _)) (K_aux (aux2, _)) = - match aux1, aux2 with + match (aux1, aux2) with | K_int, K_int -> 0 | K_type, K_type -> 0 | K_order, K_order -> 0 | K_bool, K_bool -> 0 - | K_int, _ -> 1 | _, K_int -> -1 - | K_type, _ -> 1 | _, K_type -> -1 - | K_order, _ -> 1 | _, K_order -> -1 + | K_int, _ -> 1 + | _, K_int -> -1 + | K_type, _ -> 1 + | _, K_type -> -1 + | K_order, _ -> 1 + | _, K_order -> -1 end module KOpt = struct type t = kinded_id let compare kopt1 kopt2 = let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in - lex_ord (Kid.compare (kopt_kid kopt1) (kopt_kid kopt2)) - (Kind.compare (kopt_kind kopt1) (kopt_kind kopt2)) + lex_ord (Kid.compare (kopt_kid kopt1) (kopt_kid kopt2)) (Kind.compare (kopt_kind kopt1) (kopt_kind kopt2)) end module Id = struct @@ -290,209 +247,197 @@ module Nexp = struct type t = nexp let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in - match nexp1, nexp2 with + match (nexp1, nexp2) with | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 | Nexp_constant c1, Nexp_constant c2 -> Big_int.compare c1 c2 | Nexp_app (op1, args1), Nexp_app (op2, args2) -> - let lex1 = Id.compare op1 op2 in - let lex2 = List.length args1 - List.length args2 in - let lex3 = - if lex2 = 0 then - List.fold_left2 (fun l n1 n2 -> lex_ord (l, compare n1 n2)) 0 args1 args2 - else 0 - in - lex_ord (lex1, lex_ord (lex2, lex3)) + let lex1 = Id.compare op1 op2 in + let lex2 = List.length args1 - List.length args2 in + let lex3 = if lex2 = 0 then List.fold_left2 (fun l n1 n2 -> lex_ord (l, compare n1 n2)) 0 args1 args2 else 0 in + lex_ord (lex1, lex_ord (lex2, lex3)) | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> - lex_ord (compare n1a n2a, compare n1b n2b) + lex_ord (compare n1a n2a, compare n1b n2b) | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2 | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2 - | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1 - | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1 - | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1 - | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1 - | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1 - | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1 - | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1 - | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1 + | Nexp_constant _, _ -> -1 + | _, Nexp_constant _ -> 1 + | Nexp_id _, _ -> -1 + | _, Nexp_id _ -> 1 + | Nexp_var _, _ -> -1 + | _, Nexp_var _ -> 1 + | Nexp_neg _, _ -> -1 + | _, Nexp_neg _ -> 1 + | Nexp_exp _, _ -> -1 + | _, Nexp_exp _ -> 1 + | Nexp_minus _, _ -> -1 + | _, Nexp_minus _ -> 1 + | Nexp_sum _, _ -> -1 + | _, Nexp_sum _ -> 1 + | Nexp_times _, _ -> -1 + | _, Nexp_times _ -> 1 end -module Bindings = Map.Make(Id) -module IdSet = Set.Make(Id) -module KBindings = Map.Make(Kid) -module KidSet = Set.Make(Kid) -module KOptSet = Set.Make(KOpt) -module KOptMap = Map.Make(KOpt) -module NexpSet = Set.Make(Nexp) -module NexpMap = Map.Make(Nexp) - -let nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0) - -let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with +module Bindings = Map.Make (Id) +module IdSet = Set.Make (Id) +module KBindings = Map.Make (Kid) +module KidSet = Set.Make (Kid) +module KOptSet = Set.Make (KOpt) +module KOptMap = Map.Make (KOpt) +module NexpSet = Set.Make (Nexp) +module NexpMap = Map.Make (Nexp) + +let nexp_identical nexp1 nexp2 = Nexp.compare nexp1 nexp2 = 0 + +let rec is_nexp_constant (Nexp_aux (nexp, _)) = + match nexp with | Nexp_id _ | Nexp_var _ -> false | Nexp_constant _ -> true - | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> - is_nexp_constant n1 && is_nexp_constant n2 + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> is_nexp_constant n1 && is_nexp_constant n2 | Nexp_exp n | Nexp_neg n -> is_nexp_constant n | Nexp_app (_, nexps) -> List.for_all is_nexp_constant nexps -let int_of_nexp_opt nexp = - match nexp with - | Nexp_aux(Nexp_constant i,_) -> Some i - | _ -> None +let int_of_nexp_opt nexp = match nexp with Nexp_aux (Nexp_constant i, _) -> Some i | _ -> None let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l) + and nexp_simp_aux = function (* (n - (n - m)) often appears in foreach loops *) - | Nexp_minus (nexp1, Nexp_aux (Nexp_minus (nexp2, Nexp_aux (n3,_)),_)) - when nexp_identical nexp1 nexp2 -> - nexp_simp_aux n3 - | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), nexp2), _), nexp3) - when nexp_identical nexp2 nexp3 -> - nexp_simp_aux n1 - | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), nexp2), _), nexp3) - when nexp_identical nexp2 nexp3 -> - nexp_simp_aux n1 - | Nexp_sum (n1, n2) -> - begin - let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in - let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in - match n1_simp, n2_simp with - | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.add c1 c2) - | _, Nexp_neg n2 -> Nexp_minus (n1, n2) - | _, _ -> Nexp_sum (n1, n2) - end - | Nexp_times (n1, n2) -> - begin - let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in - let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in - match n1_simp, n2_simp with - | Nexp_constant c, _ when Big_int.equal c (Big_int.of_int 1) -> n2_simp - | _, Nexp_constant c when Big_int.equal c (Big_int.of_int 1) -> n1_simp - | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.mul c1 c2) - | _, _ -> Nexp_times (n1, n2) - end - | Nexp_minus (n1, n2) -> - begin - let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in - let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in - match n1_simp, n2_simp with - | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.sub c1 c2) - (* A vector range x['n-1 .. 0] can result in the size "('n-1) - -1" *) - | Nexp_minus (Nexp_aux (n,_), Nexp_aux (Nexp_constant c1,_)), Nexp_constant c2 - when Big_int.equal c1 (Big_int.negate c2) -> n - | _, _ -> Nexp_minus (n1, n2) - end - | Nexp_neg n -> - begin - let (Nexp_aux (n_simp, _) as n) = nexp_simp n in - match n_simp with - | Nexp_constant c -> Nexp_constant (Big_int.negate c) - | _ -> Nexp_neg n - end - | Nexp_app (Id_aux (Id "div", _) as id, [n1; n2]) -> - begin - let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in - let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in - match n1_simp, n2_simp with - | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.div c1 c2) - | _, _ -> Nexp_app (id,[n1;n2]) - end + | Nexp_minus (nexp1, Nexp_aux (Nexp_minus (nexp2, Nexp_aux (n3, _)), _)) when nexp_identical nexp1 nexp2 -> + nexp_simp_aux n3 + | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), nexp2), _), nexp3) when nexp_identical nexp2 nexp3 -> + nexp_simp_aux n1 + | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), nexp2), _), nexp3) when nexp_identical nexp2 nexp3 -> + nexp_simp_aux n1 + | Nexp_sum (n1, n2) -> begin + let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in + let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in + match (n1_simp, n2_simp) with + | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.add c1 c2) + | _, Nexp_neg n2 -> Nexp_minus (n1, n2) + | _, _ -> Nexp_sum (n1, n2) + end + | Nexp_times (n1, n2) -> begin + let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in + let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in + match (n1_simp, n2_simp) with + | Nexp_constant c, _ when Big_int.equal c (Big_int.of_int 1) -> n2_simp + | _, Nexp_constant c when Big_int.equal c (Big_int.of_int 1) -> n1_simp + | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.mul c1 c2) + | _, _ -> Nexp_times (n1, n2) + end + | Nexp_minus (n1, n2) -> begin + let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in + let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in + match (n1_simp, n2_simp) with + | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.sub c1 c2) + (* A vector range x['n-1 .. 0] can result in the size "('n-1) - -1" *) + | Nexp_minus (Nexp_aux (n, _), Nexp_aux (Nexp_constant c1, _)), Nexp_constant c2 + when Big_int.equal c1 (Big_int.negate c2) -> + n + | _, _ -> Nexp_minus (n1, n2) + end + | Nexp_neg n -> begin + let (Nexp_aux (n_simp, _) as n) = nexp_simp n in + match n_simp with Nexp_constant c -> Nexp_constant (Big_int.negate c) | _ -> Nexp_neg n + end + | Nexp_app ((Id_aux (Id "div", _) as id), [n1; n2]) -> begin + let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in + let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in + match (n1_simp, n2_simp) with + | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.div c1 c2) + | _, _ -> Nexp_app (id, [n1; n2]) + end | Nexp_exp nexp -> - let nexp = nexp_simp nexp in - begin match nexp with - | Nexp_aux (Nexp_constant c, _) when Big_int.greater_equal c Big_int.zero && Big_int.less_equal c (Big_int.of_int 7) -> - Nexp_constant (Big_int.pow_int_positive 2 (Big_int.to_int c)) - | _ -> Nexp_exp nexp - end + let nexp = nexp_simp nexp in + begin + match nexp with + | Nexp_aux (Nexp_constant c, _) + when Big_int.greater_equal c Big_int.zero && Big_int.less_equal c (Big_int.of_int 7) -> + Nexp_constant (Big_int.pow_int_positive 2 (Big_int.to_int c)) + | _ -> Nexp_exp nexp + end | nexp -> nexp let rec constraint_simp (NC_aux (nc_aux, l)) = - let nc_aux = match nc_aux with + let nc_aux = + match nc_aux with | NC_equal (nexp1, nexp2) -> - let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in - if nexp_identical nexp1 nexp2 then - NC_true - else - NC_equal (nexp1, nexp2) - + let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in + if nexp_identical nexp1 nexp2 then NC_true else NC_equal (nexp1, nexp2) | NC_and (nc1, nc2) -> - let nc1, nc2 = constraint_simp nc1, constraint_simp nc2 in - begin match nc1, nc2 with - | NC_aux (NC_true, _), NC_aux (nc, _) -> nc - | NC_aux (nc, _), NC_aux (NC_true, _) -> nc - | NC_aux (NC_false, _), NC_aux (_, _) -> NC_false - | NC_aux (_, _), NC_aux (NC_false, _) -> NC_false - | _, _ -> NC_and (nc1, nc2) - end - + let nc1, nc2 = (constraint_simp nc1, constraint_simp nc2) in + begin + match (nc1, nc2) with + | NC_aux (NC_true, _), NC_aux (nc, _) -> nc + | NC_aux (nc, _), NC_aux (NC_true, _) -> nc + | NC_aux (NC_false, _), NC_aux (_, _) -> NC_false + | NC_aux (_, _), NC_aux (NC_false, _) -> NC_false + | _, _ -> NC_and (nc1, nc2) + end | NC_or (nc1, nc2) -> - let nc1, nc2 = constraint_simp nc1, constraint_simp nc2 in - begin match nc1, nc2 with - | NC_aux (NC_false, _), NC_aux (nc, _) -> nc - | NC_aux (nc, _), NC_aux (NC_false, _) -> nc - | NC_aux (NC_true, _), NC_aux (_, _) -> NC_true - | NC_aux (_, _), NC_aux (NC_true, _) -> NC_true - | _, _ -> NC_or (nc1, nc2) - end - + let nc1, nc2 = (constraint_simp nc1, constraint_simp nc2) in + begin + match (nc1, nc2) with + | NC_aux (NC_false, _), NC_aux (nc, _) -> nc + | NC_aux (nc, _), NC_aux (NC_false, _) -> nc + | NC_aux (NC_true, _), NC_aux (_, _) -> NC_true + | NC_aux (_, _), NC_aux (NC_true, _) -> NC_true + | _, _ -> NC_or (nc1, nc2) + end | NC_bounded_ge (nexp1, nexp2) -> - let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in - begin match nexp1, nexp2 with - | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> - if Big_int.greater_equal c1 c2 then NC_true else NC_false - | _, _ -> NC_bounded_ge (nexp1, nexp2) - end - + let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in + begin + match (nexp1, nexp2) with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if Big_int.greater_equal c1 c2 then NC_true else NC_false + | _, _ -> NC_bounded_ge (nexp1, nexp2) + end | NC_bounded_gt (nexp1, nexp2) -> - let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in - begin match nexp1, nexp2 with - | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> - if Big_int.greater c1 c2 then NC_true else NC_false - | _, _ -> NC_bounded_gt (nexp1, nexp2) - end - + let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in + begin + match (nexp1, nexp2) with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if Big_int.greater c1 c2 then NC_true else NC_false + | _, _ -> NC_bounded_gt (nexp1, nexp2) + end | NC_bounded_le (nexp1, nexp2) -> - let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in - begin match nexp1, nexp2 with - | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> - if Big_int.less_equal c1 c2 then NC_true else NC_false - | _, _ -> NC_bounded_le (nexp1, nexp2) - end - + let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in + begin + match (nexp1, nexp2) with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if Big_int.less_equal c1 c2 then NC_true else NC_false + | _, _ -> NC_bounded_le (nexp1, nexp2) + end | NC_bounded_lt (nexp1, nexp2) -> - let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in - begin match nexp1, nexp2 with - | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> - if Big_int.less c1 c2 then NC_true else NC_false - | _, _ -> NC_bounded_lt (nexp1, nexp2) - end - + let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in + begin + match (nexp1, nexp2) with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if Big_int.less c1 c2 then NC_true else NC_false + | _, _ -> NC_bounded_lt (nexp1, nexp2) + end | NC_app (id, [A_aux (A_bool nc, arg_l)]) when Id.compare (mk_id "not") id = 0 -> - let nc = constraint_simp nc in - begin match nc with - | NC_aux (NC_false, _) -> NC_true - | NC_aux (NC_true, _) -> NC_false - | NC_aux (NC_app (id, [A_aux (A_bool (NC_aux (nc_aux, _)), _)]), _) when Id.compare (mk_id "not") id = 0 -> - nc_aux - | _ -> NC_app (id, [A_aux (A_bool nc, arg_l)]) - end - + let nc = constraint_simp nc in + begin + match nc with + | NC_aux (NC_false, _) -> NC_true + | NC_aux (NC_true, _) -> NC_false + | NC_aux (NC_app (id, [A_aux (A_bool (NC_aux (nc_aux, _)), _)]), _) when Id.compare (mk_id "not") id = 0 -> + nc_aux + | _ -> NC_app (id, [A_aux (A_bool nc, arg_l)]) + end | _ -> nc_aux in NC_aux (nc_aux, l) let rec constraint_conj (NC_aux (nc_aux, _) as nc) = - match nc_aux with - | NC_and (nc1, nc2) -> constraint_conj nc1 @ constraint_conj nc2 - | _ -> [nc] + match nc_aux with NC_and (nc1, nc2) -> constraint_conj nc1 @ constraint_conj nc2 | _ -> [nc] let rec constraint_disj (NC_aux (nc_aux, _) as nc) = - match nc_aux with - | NC_or (nc1, nc2) -> constraint_disj nc1 @ constraint_disj nc2 - | _ -> [nc] + match nc_aux with NC_or (nc1, nc2) -> constraint_disj nc1 @ constraint_disj nc2 | _ -> [nc] let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown) let mk_typ_arg arg = A_aux (arg, Parse_ast.Unknown) @@ -501,9 +446,7 @@ let mk_kid str = Kid_aux (Var ("'" ^ str), Parse_ast.Unknown) let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown) let mk_kopt ?loc:(l = Parse_ast.Unknown) kind_aux v = - let l = match l with - | Parse_ast.Unknown -> kid_loc v - | l -> l in + let l = match l with Parse_ast.Unknown -> kid_loc v | l -> l in KOpt_aux (KOpt_kind (K_aux (kind_aux, l), v), l) let mk_ord ord_aux = Ord_aux (ord_aux, Parse_ast.Unknown) @@ -516,13 +459,10 @@ let bit_typ = mk_id_typ (mk_id "bit") let real_typ = mk_id_typ (mk_id "real") let app_typ id args = mk_typ (Typ_app (id, args)) let register_typ typ = mk_typ (Typ_app (mk_id "register", [mk_typ_arg (A_typ typ)])) -let atom_typ nexp = - mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (A_nexp (nexp_simp nexp))])) -let implicit_typ nexp = - mk_typ (Typ_app (mk_id "implicit", [mk_typ_arg (A_nexp (nexp_simp nexp))])) +let atom_typ nexp = mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (A_nexp (nexp_simp nexp))])) +let implicit_typ nexp = mk_typ (Typ_app (mk_id "implicit", [mk_typ_arg (A_nexp (nexp_simp nexp))])) let range_typ nexp1 nexp2 = - mk_typ (Typ_app (mk_id "range", [mk_typ_arg (A_nexp (nexp_simp nexp1)); - mk_typ_arg (A_nexp (nexp_simp nexp2))])) + mk_typ (Typ_app (mk_id "range", [mk_typ_arg (A_nexp (nexp_simp nexp1)); mk_typ_arg (A_nexp (nexp_simp nexp2))])) let bool_typ = mk_id_typ (mk_id "bool") let atom_bool_typ nc = mk_typ (Typ_app (mk_id "atom_bool", [mk_typ_arg (A_bool nc)])) let string_typ = mk_id_typ (mk_id "string") @@ -531,15 +471,11 @@ let tuple_typ typs = mk_typ (Typ_tuple typs) let function_typ arg_typs ret_typ = mk_typ (Typ_fn (arg_typs, ret_typ)) let vector_typ n ord typ = - mk_typ (Typ_app (mk_id "vector", - [mk_typ_arg (A_nexp (nexp_simp n)); - mk_typ_arg (A_order ord); - mk_typ_arg (A_typ typ)])) + mk_typ + (Typ_app (mk_id "vector", [mk_typ_arg (A_nexp (nexp_simp n)); mk_typ_arg (A_order ord); mk_typ_arg (A_typ typ)])) let bitvector_typ n ord = - mk_typ (Typ_app (mk_id "bitvector", - [mk_typ_arg (A_nexp (nexp_simp n)); - mk_typ_arg (A_order ord)])) + mk_typ (Typ_app (mk_id "bitvector", [mk_typ_arg (A_nexp (nexp_simp n)); mk_typ_arg (A_order ord)])) let exc_typ = mk_id_typ (mk_id "exception") @@ -566,21 +502,21 @@ let nc_true = mk_nc NC_true let nc_false = mk_nc NC_false let nc_or nc1 nc2 = - match nc1, nc2 with + match (nc1, nc2) with | _, NC_aux (NC_false, _) -> nc1 | NC_aux (NC_false, _), _ -> nc2 | _, _ -> mk_nc (NC_or (nc1, nc2)) let nc_and nc1 nc2 = - match nc1, nc2 with + match (nc1, nc2) with | _, NC_aux (NC_true, _) -> nc1 | NC_aux (NC_true, _), _ -> nc2 | _, _ -> mk_nc (NC_and (nc1, nc2)) -let arg_nexp ?loc:(l=Parse_ast.Unknown) n = A_aux (A_nexp n, l) -let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l) -let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l) -let arg_bool ?loc:(l=Parse_ast.Unknown) nc = A_aux (A_bool nc, l) +let arg_nexp ?loc:(l = Parse_ast.Unknown) n = A_aux (A_nexp n, l) +let arg_order ?loc:(l = Parse_ast.Unknown) ord = A_aux (A_order ord, l) +let arg_typ ?loc:(l = Parse_ast.Unknown) typ = A_aux (A_typ typ, l) +let arg_bool ?loc:(l = Parse_ast.Unknown) nc = A_aux (A_bool nc, l) let arg_kopt (KOpt_aux (KOpt_kind (K_aux (k, _), v), l)) = match k with @@ -598,12 +534,12 @@ let mk_typquant qis = TypQ_aux (TypQ_tq qis, Parse_ast.Unknown) let mk_fexp id exp = FE_aux (FE_fexp (id, exp), no_annot) type effect = bool - + let no_effect = false let monadic_effect = true let quant_add qi typq = - match qi, typq with + match (qi, typq) with | QI_aux (QI_constraint (NC_aux (NC_true, _)), _), _ -> typq | QI_aux (QI_id _, _), TypQ_aux (TypQ_tq qis, l) -> TypQ_aux (TypQ_tq (qi :: qis), l) | QI_aux (QI_constraint _, _), TypQ_aux (TypQ_tq qis, l) -> TypQ_aux (TypQ_tq (qis @ [qi]), l) @@ -614,35 +550,22 @@ let quant_items : typquant -> quant_item list = function | TypQ_aux (TypQ_no_forall, _) -> [] let quant_kopts typq = - let qi_kopt = function - | QI_aux (QI_id kopt, _) -> [kopt] - | QI_aux _ -> [] - in + let qi_kopt = function QI_aux (QI_id kopt, _) -> [kopt] | QI_aux _ -> [] in quant_items typq |> List.map qi_kopt |> List.concat let quant_split typq = - let qi_kopt = function - | QI_aux (QI_id kopt, _) -> [kopt] - | _ -> [] - in - let qi_nc = function - | QI_aux (QI_constraint nc, _) -> [nc] - | _ -> [] - in + let qi_kopt = function QI_aux (QI_id kopt, _) -> [kopt] | _ -> [] in + let qi_nc = function QI_aux (QI_constraint nc, _) -> [nc] | _ -> [] in let qis = quant_items typq in - List.concat (List.map qi_kopt qis), List.concat (List.map qi_nc qis) + (List.concat (List.map qi_kopt qis), List.concat (List.map qi_nc qis)) let quant_map_items f = function | TypQ_aux (TypQ_no_forall, l) -> TypQ_aux (TypQ_no_forall, l) | TypQ_aux (TypQ_tq qis, l) -> TypQ_aux (TypQ_tq (List.map f qis), l) -let is_quant_kopt = function - | QI_aux (QI_id _, _) -> true - | _ -> false +let is_quant_kopt = function QI_aux (QI_id _, _) -> true | _ -> false -let is_quant_constraint = function - | QI_aux (QI_constraint _, _) -> true - | _ -> false +let is_quant_constraint = function QI_aux (QI_constraint _, _) -> true | _ -> false let unaux_nexp (Nexp_aux (nexp, _)) = nexp let unaux_order (Ord_aux (ord, _)) = ord @@ -653,36 +576,33 @@ let unaux_constraint (NC_aux (nc, _)) = nc let rec insert_subrange ms (n1, n2) = match ms with | (m1, m2) :: ms -> - if Big_int.equal n2 (Big_int.succ m1) then - (n1, m2) :: ms - else if Big_int.greater n2 m1 then - (n1, n2) :: (m1, m2) :: ms - else if Big_int.equal m2 (Big_int.succ n1) then - insert_subrange ms (m1, n2) - else - (m1, m2) :: insert_subrange ms (n1, n2) + if Big_int.equal n2 (Big_int.succ m1) then (n1, m2) :: ms + else if Big_int.greater n2 m1 then (n1, n2) :: (m1, m2) :: ms + else if Big_int.equal m2 (Big_int.succ n1) then insert_subrange ms (m1, n2) + else (m1, m2) :: insert_subrange ms (n1, n2) | [] -> [(n1, n2)] let insert_subranges ns ms = List.fold_left insert_subrange ns ms let rec pattern_vector_subranges (P_aux (aux, (l, _))) = match aux with - | P_vector_subrange (id, n, m) when Big_int.greater n m -> - Bindings.singleton id [(n, m)] - | P_vector_subrange (id, n, m) -> - Bindings.singleton id [(m, n)] - | P_typ (_, pat) | P_var (pat, _) | P_as (pat, _) | P_not pat -> - pattern_vector_subranges pat + | P_vector_subrange (id, n, m) when Big_int.greater n m -> Bindings.singleton id [(n, m)] + | P_vector_subrange (id, n, m) -> Bindings.singleton id [(m, n)] + | P_typ (_, pat) | P_var (pat, _) | P_as (pat, _) | P_not pat -> pattern_vector_subranges pat | P_cons (pat1, pat2) | P_or (pat1, pat2) -> - Bindings.union (fun _ r1 r2 -> Some (insert_subranges r1 r2)) (pattern_vector_subranges pat1) (pattern_vector_subranges pat2) + Bindings.union + (fun _ r1 r2 -> Some (insert_subranges r1 r2)) + (pattern_vector_subranges pat1) (pattern_vector_subranges pat2) | P_tuple pats | P_vector_concat pats | P_app (_, pats) | P_list pats | P_string_append pats | P_vector pats -> - List.fold_left (fun ranges pat -> - Bindings.union (fun _ r1 r2 -> Some (insert_subranges r1 r2)) ranges (pattern_vector_subranges pat) - ) Bindings.empty pats - | P_id _ | P_lit _ | P_wild -> - Bindings.empty + List.fold_left + (fun ranges pat -> + Bindings.union (fun _ r1 r2 -> Some (insert_subranges r1 r2)) ranges (pattern_vector_subranges pat) + ) + Bindings.empty pats + | P_id _ | P_lit _ | P_wild -> Bindings.empty let rec map_exp_annot f (E_aux (exp, annot)) = E_aux (map_exp_annot_aux f exp, f annot) + and map_exp_annot_aux f = function | E_block xs -> E_block (List.map (map_exp_annot f) xs) | E_id id -> E_id id @@ -693,14 +613,18 @@ and map_exp_annot_aux f = function | E_app_infix (x, op, y) -> E_app_infix (map_exp_annot f x, op, map_exp_annot f y) | E_tuple xs -> E_tuple (List.map (map_exp_annot f) xs) | E_if (cond, t, e) -> E_if (map_exp_annot f cond, map_exp_annot f t, map_exp_annot f e) - | E_for (v, e1, e2, e3, o, e4) -> E_for (v, map_exp_annot f e1, map_exp_annot f e2, map_exp_annot f e3, o, map_exp_annot f e4) - | E_loop (loop_type, measure, e1, e2) -> E_loop (loop_type, map_measure_annot f measure, map_exp_annot f e1, map_exp_annot f e2) + | E_for (v, e1, e2, e3, o, e4) -> + E_for (v, map_exp_annot f e1, map_exp_annot f e2, map_exp_annot f e3, o, map_exp_annot f e4) + | E_loop (loop_type, measure, e1, e2) -> + E_loop (loop_type, map_measure_annot f measure, map_exp_annot f e1, map_exp_annot f e2) | E_vector exps -> E_vector (List.map (map_exp_annot f) exps) | E_vector_access (exp1, exp2) -> E_vector_access (map_exp_annot f exp1, map_exp_annot f exp2) - | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3) - | E_vector_update (exp1, exp2, exp3) -> E_vector_update (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3) + | E_vector_subrange (exp1, exp2, exp3) -> + E_vector_subrange (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3) + | E_vector_update (exp1, exp2, exp3) -> + E_vector_update (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3) | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> - E_vector_update_subrange (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3, map_exp_annot f exp4) + E_vector_update_subrange (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3, map_exp_annot f exp4) | E_vector_append (exp1, exp2) -> E_vector_append (map_exp_annot f exp1, map_exp_annot f exp2) | E_list xs -> E_list (List.map (map_exp_annot f) xs) | E_cons (exp1, exp2) -> E_cons (map_exp_annot f exp1, map_exp_annot f exp2) @@ -719,24 +643,32 @@ and map_exp_annot_aux f = function | E_assert (test, msg) -> E_assert (map_exp_annot f test, map_exp_annot f msg) | E_internal_value v -> E_internal_value v | E_var (lexp, exp1, exp2) -> E_var (map_lexp_annot f lexp, map_exp_annot f exp1, map_exp_annot f exp2) - | E_internal_plet (pat, exp1, exp2) -> E_internal_plet (map_pat_annot f pat, map_exp_annot f exp1, map_exp_annot f exp2) + | E_internal_plet (pat, exp1, exp2) -> + E_internal_plet (map_pat_annot f pat, map_exp_annot f exp1, map_exp_annot f exp2) | E_internal_return exp -> E_internal_return (map_exp_annot f exp) | E_internal_assume (nc, exp) -> E_internal_assume (nc, map_exp_annot f exp) + and map_measure_annot f (Measure_aux (m, l)) = Measure_aux (map_measure_annot_aux f m, l) + and map_measure_annot_aux f = function | Measure_none -> Measure_none | Measure_some exp -> Measure_some (map_exp_annot f exp) + and map_fexp_annot f (FE_aux (FE_fexp (id, exp), annot)) = FE_aux (FE_fexp (id, map_exp_annot f exp), f annot) + and map_pexp_annot f (Pat_aux (pexp, annot)) = Pat_aux (map_pexp_annot_aux f pexp, f annot) + and map_pexp_annot_aux f = function | Pat_exp (pat, exp) -> Pat_exp (map_pat_annot f pat, map_exp_annot f exp) | Pat_when (pat, guard, exp) -> Pat_when (map_pat_annot f pat, map_exp_annot f guard, map_exp_annot f exp) + and map_pat_annot f (P_aux (pat, annot)) = P_aux (map_pat_annot_aux f pat, f annot) + and map_pat_annot_aux f = function | P_lit lit -> P_lit lit | P_wild -> P_wild - | P_or (pat1, pat2) -> P_or (map_pat_annot f pat1, map_pat_annot f pat2) - | P_not pat -> P_not (map_pat_annot f pat) + | P_or (pat1, pat2) -> P_or (map_pat_annot f pat1, map_pat_annot f pat2) + | P_not pat -> P_not (map_pat_annot f pat) | P_as (pat, id) -> P_as (map_pat_annot f pat, id) | P_typ (typ, pat) -> P_typ (typ, map_pat_annot f pat) | P_id id -> P_id id @@ -751,20 +683,21 @@ and map_pat_annot_aux f = function | P_string_append pats -> P_string_append (List.map (map_pat_annot f) pats) and map_mpexp_annot f (MPat_aux (mpexp, annot)) = MPat_aux (map_mpexp_annot_aux f mpexp, f annot) + and map_mpexp_annot_aux f = function | MPat_pat mpat -> MPat_pat (map_mpat_annot f mpat) | MPat_when (mpat, guard) -> MPat_when (map_mpat_annot f mpat, map_exp_annot f guard) -and map_mapcl_annot f = - function - | (MCL_aux (MCL_bidir (mpexp1, mpexp2), annot)) -> - MCL_aux (MCL_bidir (map_mpexp_annot f mpexp1, map_mpexp_annot f mpexp2), map_clause_annot f annot) - | (MCL_aux (MCL_forwards (mpexp, exp), annot)) -> - MCL_aux (MCL_forwards (map_mpexp_annot f mpexp, map_exp_annot f exp), map_clause_annot f annot) - | (MCL_aux (MCL_backwards (mpexp, exp), annot)) -> - MCL_aux (MCL_backwards (map_mpexp_annot f mpexp, map_exp_annot f exp), map_clause_annot f annot) +and map_mapcl_annot f = function + | MCL_aux (MCL_bidir (mpexp1, mpexp2), annot) -> + MCL_aux (MCL_bidir (map_mpexp_annot f mpexp1, map_mpexp_annot f mpexp2), map_clause_annot f annot) + | MCL_aux (MCL_forwards (mpexp, exp), annot) -> + MCL_aux (MCL_forwards (map_mpexp_annot f mpexp, map_exp_annot f exp), map_clause_annot f annot) + | MCL_aux (MCL_backwards (mpexp, exp), annot) -> + MCL_aux (MCL_backwards (map_mpexp_annot f mpexp, map_exp_annot f exp), map_clause_annot f annot) and map_mpat_annot f (MP_aux (mpat, annot)) = MP_aux (map_mpat_annot_aux f mpat, f annot) + and map_mpat_annot_aux f = function | MP_lit lit -> MP_lit lit | MP_id id -> MP_id id @@ -780,9 +713,11 @@ and map_mpat_annot_aux f = function | MP_as (mpat, id) -> MP_as (map_mpat_annot f mpat, id) and map_letbind_annot f (LB_aux (lb, annot)) = LB_aux (map_letbind_annot_aux f lb, f annot) -and map_letbind_annot_aux f = function - | LB_val (pat, exp) -> LB_val (map_pat_annot f pat, map_exp_annot f exp) + +and map_letbind_annot_aux f = function LB_val (pat, exp) -> LB_val (map_pat_annot f pat, map_exp_annot f exp) + and map_lexp_annot f (LE_aux (lexp, annot)) = LE_aux (map_lexp_annot_aux f lexp, f annot) + and map_lexp_annot_aux f = function | LE_id id -> LE_id id | LE_deref exp -> LE_deref (map_exp_annot f exp) @@ -791,38 +726,38 @@ and map_lexp_annot_aux f = function | LE_tuple lexps -> LE_tuple (List.map (map_lexp_annot f) lexps) | LE_vector_concat lexps -> LE_vector_concat (List.map (map_lexp_annot f) lexps) | LE_vector (lexp, exp) -> LE_vector (map_lexp_annot f lexp, map_exp_annot f exp) - | LE_vector_range (lexp, exp1, exp2) -> LE_vector_range (map_lexp_annot f lexp, map_exp_annot f exp1, map_exp_annot f exp2) + | LE_vector_range (lexp, exp1, exp2) -> + LE_vector_range (map_lexp_annot f lexp, map_exp_annot f exp1, map_exp_annot f exp2) | LE_field (lexp, id) -> LE_field (map_lexp_annot f lexp, id) -and map_typedef_annot f = function - | TD_aux (td_aux, annot) -> TD_aux (td_aux, f annot) +and map_typedef_annot f = function TD_aux (td_aux, annot) -> TD_aux (td_aux, f annot) + +and map_fundef_annot f = function FD_aux (fd_aux, annot) -> FD_aux (map_fundef_annot_aux f fd_aux, f annot) -and map_fundef_annot f = function - | FD_aux (fd_aux, annot) -> FD_aux (map_fundef_annot_aux f fd_aux, f annot) and map_fundef_annot_aux f = function - | FD_function (rec_opt, tannot_opt, funcls) -> FD_function (map_recopt_annot f rec_opt, tannot_opt, - List.map (map_funcl_annot f) funcls) -and map_funcl_annot f = function - | FCL_aux (fcl, annot) -> FCL_aux (map_funcl_annot_aux f fcl, map_clause_annot f annot) -and map_funcl_annot_aux f = function - | FCL_funcl (id, pexp) -> FCL_funcl (id, map_pexp_annot f pexp) -and map_recopt_annot f = function - | Rec_aux (rec_aux, l) -> Rec_aux (map_recopt_annot_aux f rec_aux, l) + | FD_function (rec_opt, tannot_opt, funcls) -> + FD_function (map_recopt_annot f rec_opt, tannot_opt, List.map (map_funcl_annot f) funcls) + +and map_funcl_annot f = function FCL_aux (fcl, annot) -> FCL_aux (map_funcl_annot_aux f fcl, map_clause_annot f annot) + +and map_funcl_annot_aux f = function FCL_funcl (id, pexp) -> FCL_funcl (id, map_pexp_annot f pexp) + +and map_recopt_annot f = function Rec_aux (rec_aux, l) -> Rec_aux (map_recopt_annot_aux f rec_aux, l) + and map_recopt_annot_aux f = function | Rec_nonrec -> Rec_nonrec | Rec_rec -> Rec_rec | Rec_measure (pat, exp) -> Rec_measure (map_pat_annot f pat, map_exp_annot f exp) -and map_mapdef_annot f = function - | MD_aux (md_aux, annot) -> MD_aux (map_mapdef_annot_aux f md_aux, f annot) +and map_mapdef_annot f = function MD_aux (md_aux, annot) -> MD_aux (map_mapdef_annot_aux f md_aux, f annot) + and map_mapdef_annot_aux f = function | MD_mapping (id, tannot_opt, mapcls) -> MD_mapping (id, tannot_opt, List.map (map_mapcl_annot f) mapcls) -and map_valspec_annot f = function - | VS_aux (vs_aux, annot) -> VS_aux (vs_aux, f annot) +and map_valspec_annot f = function VS_aux (vs_aux, annot) -> VS_aux (vs_aux, f annot) + +and map_scattered_annot f = function SD_aux (sd_aux, annot) -> SD_aux (map_scattered_annot_aux f sd_aux, f annot) -and map_scattered_annot f = function - | SD_aux (sd_aux, annot) -> SD_aux (map_scattered_annot_aux f sd_aux, f annot) and map_scattered_annot_aux f = function | SD_function (rec_opt, tannot_opt, name) -> SD_function (map_recopt_annot f rec_opt, tannot_opt, name) | SD_funcl fcl -> SD_funcl (map_funcl_annot f fcl) @@ -832,14 +767,15 @@ and map_scattered_annot_aux f = function | SD_mapcl (id, mcl) -> SD_mapcl (id, map_mapcl_annot f mcl) | SD_end id -> SD_end id -and map_register_annot f = function - | DEC_aux (dec_aux, annot) -> DEC_aux (map_register_annot_aux f dec_aux, f annot) +and map_register_annot f = function DEC_aux (dec_aux, annot) -> DEC_aux (map_register_annot_aux f dec_aux, f annot) + and map_register_annot_aux f = function | DEC_reg (typ, id, None) -> DEC_reg (typ, id, None) | DEC_reg (typ, id, Some exp) -> DEC_reg (typ, id, Some (map_exp_annot f exp)) and map_def_annot f (DEF_aux (aux, annot)) = - let aux = match aux with + let aux = + match aux with | DEF_type td -> DEF_type (map_typedef_annot f td) | DEF_fundef fd -> DEF_fundef (map_fundef_annot f fd) | DEF_mapdef md -> DEF_mapdef (map_mapdef_annot f md) @@ -859,50 +795,40 @@ and map_def_annot f (DEF_aux (aux, annot)) = | DEF_pragma (name, arg, l) -> DEF_pragma (name, arg, l) in DEF_aux (aux, annot) + and map_ast_annot f ast = { ast with defs = List.map (map_def_annot f) ast.defs } -and map_loop_measure_annot f = function - | Loop (loop, exp) -> Loop (loop, map_exp_annot f exp) +and map_loop_measure_annot f = function Loop (loop, exp) -> Loop (loop, map_exp_annot f exp) let def_loc (DEF_aux (_, annot)) = annot.loc -let id_of_kid = function - | Kid_aux (Var v, l) -> Id_aux (Id (String.sub v 1 (String.length v - 1)), l) +let id_of_kid = function Kid_aux (Var v, l) -> Id_aux (Id (String.sub v 1 (String.length v - 1)), l) -let kid_of_id = function - | Id_aux (Id v, l) -> Kid_aux (Var ("'" ^ v), l) - | Id_aux (Operator _, _) -> assert false +let kid_of_id = function Id_aux (Id v, l) -> Kid_aux (Var ("'" ^ v), l) | Id_aux (Operator _, _) -> assert false let prepend_id str = function | Id_aux (Id v, l) -> Id_aux (Id (str ^ v), l) | Id_aux (Operator v, l) -> Id_aux (Operator (str ^ v), l) let append_id id str = - match id with - | Id_aux (Id v, l) -> Id_aux (Id (v ^ str), l) - | Id_aux (Operator v, l) -> Id_aux (Operator (v ^ str), l) + match id with Id_aux (Id v, l) -> Id_aux (Id (v ^ str), l) | Id_aux (Operator v, l) -> Id_aux (Operator (v ^ str), l) let prepend_kid str = function | Kid_aux (Var v, l) -> Kid_aux (Var ("'" ^ str ^ String.sub v 1 (String.length v - 1)), l) -let string_of_kind_aux = function - | K_type -> "Type" - | K_int -> "Int" - | K_order -> "Order" - | K_bool -> "Bool" +let string_of_kind_aux = function K_type -> "Type" | K_int -> "Int" | K_order -> "Order" | K_bool -> "Bool" let string_of_kind (K_aux (k, _)) = string_of_kind_aux k -let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) = - "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" +let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) = "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" let string_of_order = function | Ord_aux (Ord_var kid, _) -> string_of_kid kid | Ord_aux (Ord_inc, _) -> "inc" | Ord_aux (Ord_dec, _) -> "dec" -let rec string_of_nexp = function - | Nexp_aux (nexp, _) -> string_of_nexp_aux nexp +let rec string_of_nexp = function Nexp_aux (nexp, _) -> string_of_nexp_aux nexp + and string_of_nexp_aux = function | Nexp_id id -> string_of_id id | Nexp_var kid -> string_of_kid kid @@ -914,31 +840,34 @@ and string_of_nexp_aux = function | Nexp_exp n -> "2 ^ " ^ string_of_nexp n | Nexp_neg n -> "- " ^ string_of_nexp n -let rec string_of_typ = function - | Typ_aux (typ, _) -> string_of_typ_aux typ +let rec string_of_typ = function Typ_aux (typ, _) -> string_of_typ_aux typ + and string_of_typ_aux = function | Typ_internal_unknown -> "" | Typ_id id -> string_of_id id | Typ_var kid -> string_of_kid kid | Typ_tuple typs -> "(" ^ string_of_list ", " string_of_typ typs ^ ")" - | Typ_app (id, args) when Id.compare id (mk_id "atom") = 0 -> "int(" ^ string_of_list ", " string_of_typ_arg args ^ ")" - | Typ_app (id, args) when Id.compare id (mk_id "atom_bool") = 0 -> "bool(" ^ string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_app (id, args) when Id.compare id (mk_id "atom") = 0 -> + "int(" ^ string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_app (id, args) when Id.compare id (mk_id "atom_bool") = 0 -> + "bool(" ^ string_of_list ", " string_of_typ_arg args ^ ")" | Typ_app (id, args) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")" - | Typ_fn ([typ_arg], typ_ret) -> - string_of_typ typ_arg ^ " -> " ^ string_of_typ typ_ret - | Typ_fn (typ_args, typ_ret) -> - "(" ^ string_of_list ", " string_of_typ typ_args ^ ") -> " - ^ string_of_typ typ_ret + | Typ_fn ([typ_arg], typ_ret) -> string_of_typ typ_arg ^ " -> " ^ string_of_typ typ_ret + | Typ_fn (typ_args, typ_ret) -> "(" ^ string_of_list ", " string_of_typ typ_args ^ ") -> " ^ string_of_typ typ_ret | Typ_bidir (typ1, typ2) -> string_of_typ typ1 ^ " <-> " ^ string_of_typ typ2 | Typ_exist (kids, nc, typ) -> - "{" ^ string_of_list " " string_of_kinded_id kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" -and string_of_typ_arg = function - | A_aux (typ_arg, _) -> string_of_typ_arg_aux typ_arg + "{" + ^ string_of_list " " string_of_kinded_id kids + ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" + +and string_of_typ_arg = function A_aux (typ_arg, _) -> string_of_typ_arg_aux typ_arg + and string_of_typ_arg_aux = function | A_nexp n -> string_of_nexp n | A_typ typ -> string_of_typ typ | A_order o -> string_of_order o | A_bool nc -> string_of_n_constraint nc + and string_of_n_constraint = function | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 @@ -946,14 +875,11 @@ and string_of_n_constraint = function | NC_aux (NC_bounded_gt (n1, n2), _) -> string_of_nexp n1 ^ " > " ^ string_of_nexp n2 | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 | NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2 - | NC_aux (NC_or (nc1, nc2), _) -> - "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" - | NC_aux (NC_and (nc1, nc2), _) -> - "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" - | NC_aux (NC_set (kid, ns), _) -> - string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}" + | NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" + | NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" + | NC_aux (NC_set (kid, ns), _) -> string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_app (Id_aux (Operator op, _), [arg1; arg2]), _) -> - "(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")" + "(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")" | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")" | NC_aux (NC_var v, _) -> string_of_kid v | NC_aux (NC_true, _) -> "true" @@ -965,18 +891,15 @@ let string_of_quant_item_aux = function | QI_id kopt -> string_of_kinded_id kopt | QI_constraint constr -> string_of_n_constraint constr -let string_of_quant_item = function - | QI_aux (qi, _) -> string_of_quant_item_aux qi +let string_of_quant_item = function QI_aux (qi, _) -> string_of_quant_item_aux qi let string_of_typquant_aux = function | TypQ_tq quants -> "forall " ^ string_of_list ", " string_of_quant_item quants | TypQ_no_forall -> "" -let string_of_typquant = function - | TypQ_aux (quant, _) -> string_of_typquant_aux quant +let string_of_typquant = function TypQ_aux (quant, _) -> string_of_typquant_aux quant -let string_of_typschm (TypSchm_aux (TypSchm_ts (quant, typ), _)) = - string_of_typquant quant ^ ". " ^ string_of_typ typ +let string_of_typschm (TypSchm_aux (TypSchm_ts (quant, typ), _)) = string_of_typquant quant ^ ". " ^ string_of_typ typ let string_of_lit (L_aux (lit, _)) = match lit with | L_unit -> "()" @@ -1004,68 +927,67 @@ let rec string_of_exp (E_aux (exp, _)) = | E_app (f, args) -> string_of_id f ^ "(" ^ string_of_list ", " string_of_exp args ^ ")" | E_app_infix (x, op, y) -> "(" ^ string_of_exp x ^ " " ^ string_of_id op ^ " " ^ string_of_exp y ^ ")" | E_tuple exps -> "(" ^ string_of_list ", " string_of_exp exps ^ ")" - | E_match (exp, cases) -> - "match " ^ string_of_exp exp ^ " { " ^ string_of_list ", " string_of_pexp cases ^ " }" + | E_match (exp, cases) -> "match " ^ string_of_exp exp ^ " { " ^ string_of_list ", " string_of_pexp cases ^ " }" | E_try (exp, cases) -> - "try " ^ string_of_exp exp ^ " catch { case " ^ string_of_list " case " string_of_pexp cases ^ "}" + "try " ^ string_of_exp exp ^ " catch { case " ^ string_of_list " case " string_of_pexp cases ^ "}" | E_let (letbind, exp) -> "let " ^ string_of_letbind letbind ^ " in " ^ string_of_exp exp | E_assign (lexp, bind) -> string_of_lexp lexp ^ " = " ^ string_of_exp bind | E_typ (typ, exp) -> string_of_exp exp ^ " : " ^ string_of_typ typ | E_vector vec -> "[" ^ string_of_list ", " string_of_exp vec ^ "]" | E_vector_access (v, n) -> string_of_exp v ^ "[" ^ string_of_exp n ^ "]" | E_vector_update (v, n, exp) -> "[" ^ string_of_exp v ^ " with " ^ string_of_exp n ^ " = " ^ string_of_exp exp ^ "]" - | E_vector_update_subrange (v, n, m, exp) -> "[" ^ string_of_exp v ^ " with " ^ string_of_exp n ^ " .. " ^ string_of_exp m ^ " = " ^ string_of_exp exp ^ "]" + | E_vector_update_subrange (v, n, m, exp) -> + "[" ^ string_of_exp v ^ " with " ^ string_of_exp n ^ " .. " ^ string_of_exp m ^ " = " ^ string_of_exp exp ^ "]" | E_vector_subrange (v, n1, n2) -> string_of_exp v ^ "[" ^ string_of_exp n1 ^ " .. " ^ string_of_exp n2 ^ "]" | E_vector_append (v1, v2) -> string_of_exp v1 ^ " @ " ^ string_of_exp v2 | E_if (cond, then_branch, else_branch) -> - "if " ^ string_of_exp cond ^ " then " ^ string_of_exp then_branch ^ " else " ^ string_of_exp else_branch + "if " ^ string_of_exp cond ^ " then " ^ string_of_exp then_branch ^ " else " ^ string_of_exp else_branch | E_field (exp, id) -> string_of_exp exp ^ "." ^ string_of_id id | E_for (id, f, t, u, ord, body) -> - "foreach (" - ^ string_of_id id ^ " from " ^ string_of_exp f ^ " to " ^ string_of_exp t - ^ " by " ^ string_of_exp u ^ " order " ^ string_of_order ord - ^ ") { " - ^ string_of_exp body - | E_loop (While, measure, cond, body) -> "while " ^ string_of_measure measure ^ string_of_exp cond ^ " do " ^ string_of_exp body - | E_loop (Until, measure, cond, body) -> "repeat " ^ string_of_measure measure ^ string_of_exp body ^ " until " ^ string_of_exp cond + "foreach (" ^ string_of_id id ^ " from " ^ string_of_exp f ^ " to " ^ string_of_exp t ^ " by " ^ string_of_exp u + ^ " order " ^ string_of_order ord ^ ") { " ^ string_of_exp body + | E_loop (While, measure, cond, body) -> + "while " ^ string_of_measure measure ^ string_of_exp cond ^ " do " ^ string_of_exp body + | E_loop (Until, measure, cond, body) -> + "repeat " ^ string_of_measure measure ^ string_of_exp body ^ " until " ^ string_of_exp cond | E_assert (test, msg) -> "assert(" ^ string_of_exp test ^ ", " ^ string_of_exp msg ^ ")" | E_exit exp -> "exit " ^ string_of_exp exp | E_throw exp -> "throw " ^ string_of_exp exp | E_cons (x, xs) -> string_of_exp x ^ " :: " ^ string_of_exp xs | E_list xs -> "[|" ^ string_of_list ", " string_of_exp xs ^ "|]" | E_struct_update (exp, fexps) -> - "struct { " ^ string_of_exp exp ^ " with " ^ string_of_list "; " string_of_fexp fexps ^ " }" - | E_struct fexps -> - "struct { " ^ string_of_list "; " string_of_fexp fexps ^ " }" - | E_var (lexp, binding, exp) -> "var " ^ string_of_lexp lexp ^ " = " ^ string_of_exp binding ^ " in " ^ string_of_exp exp + "struct { " ^ string_of_exp exp ^ " with " ^ string_of_list "; " string_of_fexp fexps ^ " }" + | E_struct fexps -> "struct { " ^ string_of_list "; " string_of_fexp fexps ^ " }" + | E_var (lexp, binding, exp) -> + "var " ^ string_of_lexp lexp ^ " = " ^ string_of_exp binding ^ " in " ^ string_of_exp exp | E_internal_return exp -> "internal_return (" ^ string_of_exp exp ^ ")" - | E_internal_plet (pat, exp, body) -> "internal_plet " ^ string_of_pat pat ^ " = " ^ string_of_exp exp ^ " in " ^ string_of_exp body + | E_internal_plet (pat, exp, body) -> + "internal_plet " ^ string_of_pat pat ^ " = " ^ string_of_exp exp ^ " in " ^ string_of_exp body | E_internal_value v -> "INTERNAL_VALUE(" ^ Value.string_of_value v ^ ")" | E_internal_assume (nc, exp) -> "internal_assume " ^ string_of_n_constraint nc ^ " in " ^ string_of_exp exp -and string_of_measure (Measure_aux (m,_)) = - match m with - | Measure_none -> "" - | Measure_some exp -> "termination_measure { " ^ string_of_exp exp ^ "}" +and string_of_measure (Measure_aux (m, _)) = + match m with Measure_none -> "" | Measure_some exp -> "termination_measure { " ^ string_of_exp exp ^ "}" + +and string_of_fexp (FE_aux (FE_fexp (field, exp), _)) = string_of_id field ^ " = " ^ string_of_exp exp -and string_of_fexp (FE_aux (FE_fexp (field, exp), _)) = - string_of_id field ^ " = " ^ string_of_exp exp and string_of_pexp (Pat_aux (pexp, _)) = match pexp with | Pat_exp (pat, exp) -> string_of_pat pat ^ " -> " ^ string_of_exp exp | Pat_when (pat, guard, exp) -> string_of_pat pat ^ " when " ^ string_of_exp guard ^ " -> " ^ string_of_exp exp + and string_of_typ_pat (TP_aux (tpat_aux, _)) = match tpat_aux with | TP_wild -> "_" | TP_var kid -> string_of_kid kid | TP_app (f, tpats) -> string_of_id f ^ "(" ^ string_of_list ", " string_of_typ_pat tpats ^ ")" + and string_of_pat (P_aux (pat, _)) = match pat with | P_lit lit -> string_of_lit lit | P_wild -> "_" - | P_or (pat1, pat2) -> "(" ^ string_of_pat pat1 ^ " | " ^ string_of_pat pat2 - ^ ")" - | P_not pat -> "(!" ^ string_of_pat pat ^ ")" + | P_or (pat1, pat2) -> "(" ^ string_of_pat pat1 ^ " | " ^ string_of_pat pat2 ^ ")" + | P_not pat -> "(!" ^ string_of_pat pat ^ ")" | P_id v -> string_of_id v | P_var (pat, tpat) -> string_of_pat pat ^ " as " ^ string_of_typ_pat tpat | P_typ (typ, pat) -> string_of_pat pat ^ " : " ^ string_of_typ typ @@ -1075,10 +997,8 @@ and string_of_pat (P_aux (pat, _)) = | P_list pats -> "[||" ^ string_of_list "," string_of_pat pats ^ "||]" | P_vector_concat pats -> string_of_list " @ " string_of_pat pats | P_vector_subrange (id, n, m) -> - if Big_int.equal n m then - string_of_id id ^ "[" ^ Big_int.to_string n ^ "]" - else - string_of_id id ^ "[" ^ Big_int.to_string n ^ ".." ^ Big_int.to_string m ^ "]" + if Big_int.equal n m then string_of_id id ^ "[" ^ Big_int.to_string n ^ "]" + else string_of_id id ^ "[" ^ Big_int.to_string n ^ ".." ^ Big_int.to_string m ^ "]" | P_vector pats -> "[" ^ string_of_list ", " string_of_pat pats ^ "]" | P_as (pat, id) -> "(" ^ string_of_pat pat ^ " as " ^ string_of_id id ^ ")" | P_string_append [] -> "\"\"" @@ -1107,14 +1027,13 @@ and string_of_lexp (LE_aux (lexp, _)) = | LE_tuple lexps -> "(" ^ string_of_list ", " string_of_lexp lexps ^ ")" | LE_vector (lexp, exp) -> string_of_lexp lexp ^ "[" ^ string_of_exp exp ^ "]" | LE_vector_range (lexp, exp1, exp2) -> - string_of_lexp lexp ^ "[" ^ string_of_exp exp1 ^ " .. " ^ string_of_exp exp2 ^ "]" - | LE_vector_concat lexps -> - string_of_list " @ " string_of_lexp lexps + string_of_lexp lexp ^ "[" ^ string_of_exp exp1 ^ " .. " ^ string_of_exp exp2 ^ "]" + | LE_vector_concat lexps -> string_of_list " @ " string_of_lexp lexps | LE_field (lexp, id) -> string_of_lexp lexp ^ "." ^ string_of_id id | LE_app (f, xs) -> string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")" + and string_of_letbind (LB_aux (lb, _)) = - match lb with - | LB_val (pat, exp) -> string_of_pat pat ^ " = " ^ string_of_exp exp + match lb with LB_val (pat, exp) -> string_of_pat pat ^ " = " ^ string_of_exp exp let rec string_of_index_range (BF_aux (ir, _)) = match ir with @@ -1128,36 +1047,43 @@ let rec pat_ids (P_aux (pat_aux, _)) = | P_id id | P_vector_subrange (id, _, _) -> IdSet.singleton id | P_as (pat, id) -> IdSet.add id (pat_ids pat) | P_or (pat1, pat2) -> IdSet.union (pat_ids pat1) (pat_ids pat2) - | P_not (pat) -> pat_ids pat + | P_not pat -> pat_ids pat | P_var (pat, _) | P_typ (_, pat) -> pat_ids pat | P_app (_, pats) | P_tuple pats | P_vector pats | P_vector_concat pats | P_list pats -> - List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty - | P_cons (pat1, pat2) -> - IdSet.union (pat_ids pat1) (pat_ids pat2) - | P_string_append pats -> - List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty + List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty + | P_cons (pat1, pat2) -> IdSet.union (pat_ids pat1) (pat_ids pat2) + | P_string_append pats -> List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty let id_of_fundef (FD_aux (FD_function (_, _, funcls), (l, _))) = - match (List.fold_right - (fun (FCL_aux (FCL_funcl (id, _), _)) id' -> - match id' with - | Some id' -> if string_of_id id' = string_of_id id then Some id' - else raise (Reporting.err_typ l - ("Function declaration expects all definitions to have the same name, " - ^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id')) - | None -> Some id) funcls None) + match + List.fold_right + (fun (FCL_aux (FCL_funcl (id, _), _)) id' -> + match id' with + | Some id' -> + if string_of_id id' = string_of_id id then Some id' + else + raise + (Reporting.err_typ l + ("Function declaration expects all definitions to have the same name, " ^ string_of_id id + ^ " differs from other definitions of " ^ string_of_id id' + ) + ) + | None -> Some id + ) + funcls None with | Some id -> id | None -> raise (Reporting.err_typ l "Function clause list is empty") let id_of_mapdef (MD_aux (MD_mapping (id, _, _), _)) = id - + let id_of_type_def_aux = function | TD_abbrev (id, _, _) | TD_record (id, _, _, _) | TD_variant (id, _, _, _) | TD_enum (id, _, _) - | TD_bitfield (id, _, _) -> id + | TD_bitfield (id, _, _) -> + id let id_of_type_def (TD_aux (td_aux, _)) = id_of_type_def_aux td_aux let id_of_val_spec (VS_aux (VS_val_spec (_, id, _, _), _)) = id @@ -1166,9 +1092,14 @@ let id_of_dec_spec (DEC_aux (DEC_reg (_, id, _), _)) = id let id_of_scattered (SD_aux (sdef, _)) = match sdef with - | SD_function (_, _, id) | SD_funcl (FCL_aux (FCL_funcl (id, _), _)) | SD_end id - | SD_variant (id, _) | SD_unioncl (id, _) - | SD_mapping (id, _) | SD_mapcl (id, _) -> id + | SD_function (_, _, id) + | SD_funcl (FCL_aux (FCL_funcl (id, _), _)) + | SD_end id + | SD_variant (id, _) + | SD_unioncl (id, _) + | SD_mapping (id, _) + | SD_mapcl (id, _) -> + id let ids_of_def (DEF_aux (aux, _)) = match aux with @@ -1181,15 +1112,11 @@ let ids_of_def (DEF_aux (aux, _)) = | DEF_internal_mutrec fds -> IdSet.of_list (List.map id_of_fundef fds) | DEF_scattered sdef -> IdSet.singleton (id_of_scattered sdef) | _ -> IdSet.empty -let ids_of_defs defs = - List.fold_left IdSet.union IdSet.empty (List.map ids_of_def defs) +let ids_of_defs defs = List.fold_left IdSet.union IdSet.empty (List.map ids_of_def defs) let ids_of_ast ast = ids_of_defs ast.defs let val_spec_ids defs = - let val_spec_id (VS_aux (vs_aux, _)) = - match vs_aux with - | VS_val_spec (_, id, _, _) -> id - in + let val_spec_id (VS_aux (vs_aux, _)) = match vs_aux with VS_val_spec (_, id, _, _) -> id in let rec vs_ids = function | DEF_aux (DEF_val vs, _) :: defs -> val_spec_id vs :: vs_ids defs | _ :: defs -> vs_ids defs @@ -1207,130 +1134,134 @@ let record_ids defs = let rec get_scattered_union_clauses id = function | DEF_aux (DEF_scattered (SD_aux (SD_unioncl (uid, tu), _)), _) :: defs when Id.compare id uid = 0 -> - tu :: get_scattered_union_clauses id defs - | _ :: defs -> - get_scattered_union_clauses id defs + tu :: get_scattered_union_clauses id defs + | _ :: defs -> get_scattered_union_clauses id defs | [] -> [] -let order_compare (Ord_aux (o1,_)) (Ord_aux (o2,_)) = - match o1, o2 with +let order_compare (Ord_aux (o1, _)) (Ord_aux (o2, _)) = + match (o1, o2) with | Ord_var k1, Ord_var k2 -> Kid.compare k1 k2 | Ord_inc, Ord_inc -> 0 | Ord_dec, Ord_dec -> 0 - | Ord_var _, _ -> -1 | _, Ord_var _ -> 1 - | Ord_inc, _ -> -1 | _, Ord_inc -> 1 - -let lex_ord f g x1 x2 y1 y2 = - match f x1 x2 with - | 0 -> g y1 y2 - | n -> n - -let rec nc_compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = - match nc1, nc2 with - | NC_equal (n1,n2), NC_equal (n3,n4) - | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4) - | NC_bounded_gt (n1,n2), NC_bounded_gt (n3,n4) - | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4) - | NC_bounded_lt (n1,n2), NC_bounded_lt (n3,n4) - | NC_not_equal (n1,n2), NC_not_equal (n3,n4) - -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 - | NC_set (k1,s1), NC_set (k2,s2) -> - lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 - | NC_or (nc1,nc2), NC_or (nc3,nc4) - | NC_and (nc1,nc2), NC_and (nc3,nc4) - -> lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 - | NC_app (f1,args1), NC_app (f2,args2) - -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 - | NC_var v1, NC_var v2 - -> Kid.compare v1 v2 - | NC_true, NC_true - | NC_false, NC_false - -> 0 - | NC_equal _, _ -> -1 | _, NC_equal _ -> 1 - | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1 - | NC_bounded_gt _, _ -> -1 | _, NC_bounded_gt _ -> 1 - | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1 - | NC_bounded_lt _, _ -> -1 | _, NC_bounded_lt _ -> 1 - | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1 - | NC_set _, _ -> -1 | _, NC_set _ -> 1 - | NC_or _, _ -> -1 | _, NC_or _ -> 1 - | NC_and _, _ -> -1 | _, NC_and _ -> 1 - | NC_app _, _ -> -1 | _, NC_app _ -> 1 - | NC_var _, _ -> -1 | _, NC_var _ -> 1 - | NC_true, _ -> -1 | _, NC_true -> 1 - -and typ_compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) = - match t1,t2 with + | Ord_var _, _ -> -1 + | _, Ord_var _ -> 1 + | Ord_inc, _ -> -1 + | _, Ord_inc -> 1 + +let lex_ord f g x1 x2 y1 y2 = match f x1 x2 with 0 -> g y1 y2 | n -> n + +let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) = + match (nc1, nc2) with + | NC_equal (n1, n2), NC_equal (n3, n4) + | NC_bounded_ge (n1, n2), NC_bounded_ge (n3, n4) + | NC_bounded_gt (n1, n2), NC_bounded_gt (n3, n4) + | NC_bounded_le (n1, n2), NC_bounded_le (n3, n4) + | NC_bounded_lt (n1, n2), NC_bounded_lt (n3, n4) + | NC_not_equal (n1, n2), NC_not_equal (n3, n4) -> + lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 + | NC_set (k1, s1), NC_set (k2, s2) -> lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 + | NC_or (nc1, nc2), NC_or (nc3, nc4) | NC_and (nc1, nc2), NC_and (nc3, nc4) -> + lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 + | NC_app (f1, args1), NC_app (f2, args2) -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 + | NC_var v1, NC_var v2 -> Kid.compare v1 v2 + | NC_true, NC_true | NC_false, NC_false -> 0 + | NC_equal _, _ -> -1 + | _, NC_equal _ -> 1 + | NC_bounded_ge _, _ -> -1 + | _, NC_bounded_ge _ -> 1 + | NC_bounded_gt _, _ -> -1 + | _, NC_bounded_gt _ -> 1 + | NC_bounded_le _, _ -> -1 + | _, NC_bounded_le _ -> 1 + | NC_bounded_lt _, _ -> -1 + | _, NC_bounded_lt _ -> 1 + | NC_not_equal _, _ -> -1 + | _, NC_not_equal _ -> 1 + | NC_set _, _ -> -1 + | _, NC_set _ -> 1 + | NC_or _, _ -> -1 + | _, NC_or _ -> 1 + | NC_and _, _ -> -1 + | _, NC_and _ -> 1 + | NC_app _, _ -> -1 + | _, NC_app _ -> 1 + | NC_var _, _ -> -1 + | _, NC_var _ -> 1 + | NC_true, _ -> -1 + | _, NC_true -> 1 + +and typ_compare (Typ_aux (t1, _)) (Typ_aux (t2, _)) = + match (t1, t2) with | Typ_internal_unknown, Typ_internal_unknown -> 0 | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 - | Typ_fn (ts1,t2), Typ_fn (ts3,t4) -> - (match Util.compare_list typ_compare ts1 ts3 with - | 0 -> typ_compare t2 t4 - | n -> n) - | Typ_bidir (t1,t2), Typ_bidir (t3,t4) -> - (match typ_compare t1 t3 with - | 0 -> typ_compare t2 t4 - | n -> n) + | Typ_fn (ts1, t2), Typ_fn (ts3, t4) -> ( + match Util.compare_list typ_compare ts1 ts3 with 0 -> typ_compare t2 t4 | n -> n + ) + | Typ_bidir (t1, t2), Typ_bidir (t3, t4) -> ( + match typ_compare t1 t3 with 0 -> typ_compare t2 t4 | n -> n + ) | Typ_tuple ts1, Typ_tuple ts2 -> Util.compare_list typ_compare ts1 ts2 - | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> - (match Util.compare_list KOpt.compare ks1 ks2 with - | 0 -> (match nc_compare nc1 nc2 with - | 0 -> typ_compare t1 t2 - | n -> n) - | n -> n) - | Typ_app (id1,ts1), Typ_app (id2,ts2) -> - (match Id.compare id1 id2 with - | 0 -> Util.compare_list typ_arg_compare ts1 ts2 - | n -> n) - | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1 - | Typ_id _, _ -> -1 | _, Typ_id _ -> 1 - | Typ_var _, _ -> -1 | _, Typ_var _ -> 1 - | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1 - | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1 - | Typ_tuple _, _ -> -1 | _, Typ_tuple _ -> 1 - | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1 - -and typ_arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) = - match ta1, ta2 with - | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 - | A_typ t1, A_typ t2 -> typ_compare t1 t2 + | Typ_exist (ks1, nc1, t1), Typ_exist (ks2, nc2, t2) -> ( + match Util.compare_list KOpt.compare ks1 ks2 with + | 0 -> ( + match nc_compare nc1 nc2 with 0 -> typ_compare t1 t2 | n -> n + ) + | n -> n + ) + | Typ_app (id1, ts1), Typ_app (id2, ts2) -> ( + match Id.compare id1 id2 with 0 -> Util.compare_list typ_arg_compare ts1 ts2 | n -> n + ) + | Typ_internal_unknown, _ -> -1 + | _, Typ_internal_unknown -> 1 + | Typ_id _, _ -> -1 + | _, Typ_id _ -> 1 + | Typ_var _, _ -> -1 + | _, Typ_var _ -> 1 + | Typ_fn _, _ -> -1 + | _, Typ_fn _ -> 1 + | Typ_bidir _, _ -> -1 + | _, Typ_bidir _ -> 1 + | Typ_tuple _, _ -> -1 + | _, Typ_tuple _ -> 1 + | Typ_exist _, _ -> -1 + | _, Typ_exist _ -> 1 + +and typ_arg_compare (A_aux (ta1, _)) (A_aux (ta2, _)) = + match (ta1, ta2) with + | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 + | A_typ t1, A_typ t2 -> typ_compare t1 t2 | A_order o1, A_order o2 -> order_compare o1 o2 | A_bool nc1, A_bool nc2 -> nc_compare nc1 nc2 - | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 - | A_typ _, _ -> -1 | _, A_typ _ -> 1 - | A_order _, _ -> -1 | _, A_order _ -> 1 + | A_nexp _, _ -> -1 + | _, A_nexp _ -> 1 + | A_typ _, _ -> -1 + | _, A_typ _ -> 1 + | A_order _, _ -> -1 + | _, A_order _ -> 1 -let is_typ_arg_nexp = function - | (A_aux (A_typ _, _)) -> true - | _ -> false +let is_typ_arg_nexp = function A_aux (A_typ _, _) -> true | _ -> false -let is_typ_arg_typ = function - | (A_aux (A_typ _, _)) -> true - | _ -> false - -let is_typ_arg_order = function - | (A_aux (A_order _, _)) -> true - | _ -> false +let is_typ_arg_typ = function A_aux (A_typ _, _) -> true | _ -> false -let is_typ_arg_bool = function - | (A_aux (A_bool _, _)) -> true - | _ -> false +let is_typ_arg_order = function A_aux (A_order _, _) -> true | _ -> false + +let is_typ_arg_bool = function A_aux (A_bool _, _) -> true | _ -> false module NC = struct type t = n_constraint let compare = nc_compare end -module NCMap = Map.Make(NC) +module NCMap = Map.Make (NC) module Typ = struct type t = typ let compare = typ_compare end -module TypMap = Map.Make(Typ) - +module TypMap = Map.Make (Typ) + let rec nexp_frees (Nexp_aux (nexp, l)) = match nexp with | Nexp_id _ -> raise (Reporting.err_typ l "Unimplemented Nexp_id in nexp_frees") @@ -1348,63 +1279,54 @@ let rec lexp_to_exp (LE_aux (lexp_aux, annot)) = match lexp_aux with | LE_id id | LE_typ (_, id) -> rewrap (E_id id) | LE_tuple les -> - let get_id (LE_aux(lexp,((l,_) as annot)) as le) = match lexp with - | LE_id id | LE_typ (_, id) -> E_aux (E_id id, annot) - | _ -> - raise (Reporting.err_unreachable l __POS__ - ("Unsupported sub-lexp " ^ string_of_lexp le ^ " in tuple")) in - rewrap (E_tuple (List.map get_id les)) + let get_id (LE_aux (lexp, ((l, _) as annot)) as le) = + match lexp with + | LE_id id | LE_typ (_, id) -> E_aux (E_id id, annot) + | _ -> raise (Reporting.err_unreachable l __POS__ ("Unsupported sub-lexp " ^ string_of_lexp le ^ " in tuple")) + in + rewrap (E_tuple (List.map get_id les)) | LE_vector (lexp, e) -> rewrap (E_vector_access (lexp_to_exp lexp, e)) | LE_vector_range (lexp, e1, e2) -> rewrap (E_vector_subrange (lexp_to_exp lexp, e1, e2)) | LE_field (lexp, id) -> rewrap (E_field (lexp_to_exp lexp, id)) | LE_app (id, exps) -> rewrap (E_app (id, exps)) | LE_vector_concat [] -> rewrap (E_vector []) | LE_vector_concat (lexp :: lexps) -> - List.fold_left (fun exp lexp -> rewrap (E_vector_append (exp, lexp_to_exp lexp))) (lexp_to_exp lexp) lexps + List.fold_left (fun exp lexp -> rewrap (E_vector_append (exp, lexp_to_exp lexp))) (lexp_to_exp lexp) lexps | LE_deref exp -> rewrap (E_app (mk_id "__deref", [exp])) -let is_unit_typ = function - | Typ_aux (Typ_id u, _) -> string_of_id u = "unit" - | _ -> false +let is_unit_typ = function Typ_aux (Typ_id u, _) -> string_of_id u = "unit" | _ -> false -let is_number (Typ_aux (t,_)) = +let is_number (Typ_aux (t, _)) = match t with | Typ_id (Id_aux (Id "int", _)) | Typ_id (Id_aux (Id "nat", _)) - | Typ_app (Id_aux (Id "range", _),_) - | Typ_app (Id_aux (Id "implicit", _),_) - | Typ_app (Id_aux (Id "atom", _),_) -> true + | Typ_app (Id_aux (Id "range", _), _) + | Typ_app (Id_aux (Id "implicit", _), _) + | Typ_app (Id_aux (Id "atom", _), _) -> + true | _ -> false -let is_ref_typ (Typ_aux (typ_aux, _)) = match typ_aux with - | Typ_app (id, _) -> string_of_id id = "register" || string_of_id id = "reg" - | _ -> false +let is_ref_typ (Typ_aux (typ_aux, _)) = + match typ_aux with Typ_app (id, _) -> string_of_id id = "register" || string_of_id id = "reg" | _ -> false let rec is_vector_typ = function - | Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;_;_]), _) -> true - | Typ_aux (Typ_app (Id_aux (Id "register",_), [A_aux (A_typ rtyp,_)]), _) -> - is_vector_typ rtyp + | Typ_aux (Typ_app (Id_aux (Id "vector", _), [_; _; _]), _) -> true + | Typ_aux (Typ_app (Id_aux (Id "register", _), [A_aux (A_typ rtyp, _)]), _) -> is_vector_typ rtyp | _ -> false let typ_app_args_of = function - | Typ_aux (Typ_app (Id_aux (Id c,_), targs), l) -> - (c, List.map (fun (A_aux (a,_)) -> a) targs, l) - | Typ_aux (_, l) as typ -> - raise (Reporting.err_typ l - ("typ_app_args_of called on non-app type " ^ string_of_typ typ)) - -let rec vector_typ_args_of typ = match typ_app_args_of typ with - | ("vector", [A_nexp len; A_order ord; A_typ etyp], _) -> - (nexp_simp len, ord, etyp) - | ("bitvector", [A_nexp len; A_order ord], _) -> - (nexp_simp len, ord, bit_typ) - | ("register", [A_typ rtyp], _) -> vector_typ_args_of rtyp - | (_, _, l) -> - raise (Reporting.err_typ l - ("vector_typ_args_of called on non-vector type " ^ string_of_typ typ)) + | Typ_aux (Typ_app (Id_aux (Id c, _), targs), l) -> (c, List.map (fun (A_aux (a, _)) -> a) targs, l) + | Typ_aux (_, l) as typ -> raise (Reporting.err_typ l ("typ_app_args_of called on non-app type " ^ string_of_typ typ)) + +let rec vector_typ_args_of typ = + match typ_app_args_of typ with + | "vector", [A_nexp len; A_order ord; A_typ etyp], _ -> (nexp_simp len, ord, etyp) + | "bitvector", [A_nexp len; A_order ord], _ -> (nexp_simp len, ord, bit_typ) + | "register", [A_typ rtyp], _ -> vector_typ_args_of rtyp + | _, _, l -> raise (Reporting.err_typ l ("vector_typ_args_of called on non-vector type " ^ string_of_typ typ)) let vector_start_index typ = - let (len, ord, _) = vector_typ_args_of typ in + let len, ord, _ = vector_typ_args_of typ in match ord with | Ord_aux (Ord_inc, _) -> nint 0 | Ord_aux (Ord_dec, _) -> nexp_simp (nminus len (nint 1)) @@ -1414,48 +1336,35 @@ let is_order_inc = function | Ord_aux (Ord_inc, _) -> true | Ord_aux (Ord_dec, _) -> false | Ord_aux (Ord_var _, l) -> - raise (Reporting.err_unreachable l __POS__ "is_order_inc called on vector with variable ordering") + raise (Reporting.err_unreachable l __POS__ "is_order_inc called on vector with variable ordering") -let is_bit_typ = function - | Typ_aux (Typ_id (Id_aux (Id "bit", _)), _) -> true - | _ -> false +let is_bit_typ = function Typ_aux (Typ_id (Id_aux (Id "bit", _)), _) -> true | _ -> false let rec is_bitvector_typ = function - | Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [_;_]), _) -> true - | Typ_aux (Typ_app (Id_aux (Id "register",_), [A_aux (A_typ rtyp,_)]), _) -> - is_bitvector_typ rtyp + | Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [_; _]), _) -> true + | Typ_aux (Typ_app (Id_aux (Id "register", _), [A_aux (A_typ rtyp, _)]), _) -> is_bitvector_typ rtyp | _ -> false (* Utilities for constructing effect sets *) let effectful e = e - + let union_effects e1 e2 = e1 || e2 let equal_effects e1 e2 = e1 = e2 -let subseteq_effects e1 e2 = - match e1, e2 with - | false, _ -> true - | true, true -> true - | true, false -> false - -let rec kopts_of_nexp (Nexp_aux (nexp,_)) = +let subseteq_effects e1 e2 = match (e1, e2) with false, _ -> true | true, true -> true | true, false -> false + +let rec kopts_of_nexp (Nexp_aux (nexp, _)) = match nexp with - | Nexp_id _ - | Nexp_constant _ -> KOptSet.empty + | Nexp_id _ | Nexp_constant _ -> KOptSet.empty | Nexp_var kid -> KOptSet.singleton (mk_kopt K_int kid) - | Nexp_times (n1,n2) - | Nexp_sum (n1,n2) - | Nexp_minus (n1,n2) -> KOptSet.union (kopts_of_nexp n1) (kopts_of_nexp n2) - | Nexp_exp n - | Nexp_neg n -> kopts_of_nexp n + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> KOptSet.union (kopts_of_nexp n1) (kopts_of_nexp n2) + | Nexp_exp n | Nexp_neg n -> kopts_of_nexp n | Nexp_app (_, nexps) -> List.fold_left KOptSet.union KOptSet.empty (List.map kopts_of_nexp nexps) let kopts_of_order (Ord_aux (ord, _)) = - match ord with - | Ord_var kid -> KOptSet.singleton (mk_kopt K_order kid) - | Ord_inc | Ord_dec -> KOptSet.empty + match ord with Ord_var kid -> KOptSet.singleton (mk_kopt K_order kid) | Ord_inc | Ord_dec -> KOptSet.empty let rec kopts_of_constraint (NC_aux (nc, _)) = match nc with @@ -1465,54 +1374,42 @@ let rec kopts_of_constraint (NC_aux (nc, _)) = | NC_bounded_le (nexp1, nexp2) | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> - KOptSet.union (kopts_of_nexp nexp1) (kopts_of_nexp nexp2) + KOptSet.union (kopts_of_nexp nexp1) (kopts_of_nexp nexp2) | NC_set (kid, _) -> KOptSet.singleton (mk_kopt K_int kid) - | NC_or (nc1, nc2) - | NC_and (nc1, nc2) -> - KOptSet.union (kopts_of_constraint nc1) (kopts_of_constraint nc2) - | NC_app (_, args) -> - List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ_arg t)) KOptSet.empty args + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KOptSet.union (kopts_of_constraint nc1) (kopts_of_constraint nc2) + | NC_app (_, args) -> List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ_arg t)) KOptSet.empty args | NC_var kid -> KOptSet.singleton (mk_kopt K_bool kid) | NC_true | NC_false -> KOptSet.empty -and kopts_of_typ (Typ_aux (t,_)) = +and kopts_of_typ (Typ_aux (t, _)) = match t with | Typ_internal_unknown -> KOptSet.empty | Typ_id _ -> KOptSet.empty | Typ_var kid -> KOptSet.singleton (mk_kopt K_type kid) | Typ_fn (ts, t) -> List.fold_left KOptSet.union (kopts_of_typ t) (List.map kopts_of_typ ts) | Typ_bidir (t1, t2) -> KOptSet.union (kopts_of_typ t1) (kopts_of_typ t2) - | Typ_tuple ts -> - List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ t)) - KOptSet.empty ts - | Typ_app (_,tas) -> - List.fold_left (fun s ta -> KOptSet.union s (kopts_of_typ_arg ta)) - KOptSet.empty tas + | Typ_tuple ts -> List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ t)) KOptSet.empty ts + | Typ_app (_, tas) -> List.fold_left (fun s ta -> KOptSet.union s (kopts_of_typ_arg ta)) KOptSet.empty tas | Typ_exist (kopts, nc, t) -> - let s = KOptSet.union (kopts_of_typ t) (kopts_of_constraint nc) in - KOptSet.diff s (KOptSet.of_list kopts) -and kopts_of_typ_arg (A_aux (ta,_)) = + let s = KOptSet.union (kopts_of_typ t) (kopts_of_constraint nc) in + KOptSet.diff s (KOptSet.of_list kopts) + +and kopts_of_typ_arg (A_aux (ta, _)) = match ta with | A_nexp nexp -> kopts_of_nexp nexp | A_typ typ -> kopts_of_typ typ | A_order ord -> kopts_of_order ord | A_bool nc -> kopts_of_constraint nc -let kopts_of_quant_item (QI_aux (qi, _)) = match qi with - | QI_id kopt -> - KOptSet.singleton kopt - | QI_constraint nc -> kopts_of_constraint nc +let kopts_of_quant_item (QI_aux (qi, _)) = + match qi with QI_id kopt -> KOptSet.singleton kopt | QI_constraint nc -> kopts_of_constraint nc -let rec tyvars_of_nexp (Nexp_aux (nexp,_)) = +let rec tyvars_of_nexp (Nexp_aux (nexp, _)) = match nexp with - | Nexp_id _ - | Nexp_constant _ -> KidSet.empty + | Nexp_id _ | Nexp_constant _ -> KidSet.empty | Nexp_var kid -> KidSet.singleton kid - | Nexp_times (n1,n2) - | Nexp_sum (n1,n2) - | Nexp_minus (n1,n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) - | Nexp_exp n - | Nexp_neg n -> tyvars_of_nexp n + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) + | Nexp_exp n | Nexp_neg n -> tyvars_of_nexp n | Nexp_app (_, nexps) -> List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_nexp nexps) let rec tyvars_of_constraint (NC_aux (nc, _)) = @@ -1523,43 +1420,36 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) = | NC_bounded_le (nexp1, nexp2) | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> - KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2) + KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2) | NC_set (kid, _) -> KidSet.singleton kid - | NC_or (nc1, nc2) - | NC_and (nc1, nc2) -> - KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) - | NC_app (_, args) -> - List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) + | NC_app (_, args) -> List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args | NC_var kid -> KidSet.singleton kid - | NC_true - | NC_false -> KidSet.empty + | NC_true | NC_false -> KidSet.empty -and tyvars_of_typ (Typ_aux (t,_)) = +and tyvars_of_typ (Typ_aux (t, _)) = match t with | Typ_internal_unknown -> KidSet.empty | Typ_id _ -> KidSet.empty | Typ_var kid -> KidSet.singleton kid | Typ_fn (ts, t) -> List.fold_left KidSet.union (tyvars_of_typ t) (List.map tyvars_of_typ ts) | Typ_bidir (t1, t2) -> KidSet.union (tyvars_of_typ t1) (tyvars_of_typ t2) - | Typ_tuple ts -> - List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ t)) - KidSet.empty ts - | Typ_app (_,tas) -> - List.fold_left (fun s ta -> KidSet.union s (tyvars_of_typ_arg ta)) - KidSet.empty tas + | Typ_tuple ts -> List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ t)) KidSet.empty ts + | Typ_app (_, tas) -> List.fold_left (fun s ta -> KidSet.union s (tyvars_of_typ_arg ta)) KidSet.empty tas | Typ_exist (kids, nc, t) -> - let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in - List.fold_left (fun s k -> KidSet.remove k s) s (List.map kopt_kid kids) -and tyvars_of_typ_arg (A_aux (ta,_)) = + let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in + List.fold_left (fun s k -> KidSet.remove k s) s (List.map kopt_kid kids) + +and tyvars_of_typ_arg (A_aux (ta, _)) = match ta with | A_nexp nexp -> tyvars_of_nexp nexp | A_typ typ -> tyvars_of_typ typ | A_order _ -> KidSet.empty | A_bool nc -> tyvars_of_constraint nc -let tyvars_of_quant_item (QI_aux (qi, _)) = match qi with - | QI_id (KOpt_aux (KOpt_kind (_, kid), _)) -> - KidSet.singleton kid +let tyvars_of_quant_item (QI_aux (qi, _)) = + match qi with + | QI_id (KOpt_aux (KOpt_kind (_, kid), _)) -> KidSet.singleton kid | QI_constraint nc -> tyvars_of_constraint nc let is_kid_generated kid = String.contains (string_of_kid kid) '#' @@ -1567,29 +1457,25 @@ let is_kid_generated kid = String.contains (string_of_kid kid) '#' let rec undefined_of_typ mwords l annot (Typ_aux (typ_aux, _) as typ) = let wrap e_aux typ = E_aux (e_aux, (l, annot typ)) in match typ_aux with - | Typ_id id -> - wrap (E_app (prepend_id "undefined_" id, [wrap (E_lit (mk_lit L_unit)) unit_typ])) typ - | Typ_app (_,[size;_;_]) when mwords && is_bitvector_typ typ -> - wrap (E_app (mk_id "undefined_bitvector", - undefined_of_typ_args mwords l annot size)) typ - | Typ_app (atom, [A_aux (A_nexp i, _)]) when string_of_id atom = "atom" -> - wrap (E_sizeof i) typ + | Typ_id id -> wrap (E_app (prepend_id "undefined_" id, [wrap (E_lit (mk_lit L_unit)) unit_typ])) typ + | Typ_app (_, [size; _; _]) when mwords && is_bitvector_typ typ -> + wrap (E_app (mk_id "undefined_bitvector", undefined_of_typ_args mwords l annot size)) typ + | Typ_app (atom, [A_aux (A_nexp i, _)]) when string_of_id atom = "atom" -> wrap (E_sizeof i) typ | Typ_app (id, args) -> - wrap (E_app (prepend_id "undefined_" id, - List.concat (List.map (undefined_of_typ_args mwords l annot) args))) typ - | Typ_tuple typs -> - wrap (E_tuple (List.map (undefined_of_typ mwords l annot) typs)) typ + wrap (E_app (prepend_id "undefined_" id, List.concat (List.map (undefined_of_typ_args mwords l annot) args))) typ + | Typ_tuple typs -> wrap (E_tuple (List.map (undefined_of_typ mwords l annot) typs)) typ | Typ_var kid -> - (* Undefined monomorphism restriction in the type checker should - guarantee that the typ_(kid) parameter was always one created - in an undefined_(type) function created in - initial_check.ml. i.e. the rewriter should only encounter this - case when re-writing those functions. *) - wrap (E_id (prepend_id "typ_" (id_of_kid kid))) typ + (* Undefined monomorphism restriction in the type checker should + guarantee that the typ_(kid) parameter was always one created + in an undefined_(type) function created in + initial_check.ml. i.e. the rewriter should only encounter this + case when re-writing those functions. *) + wrap (E_id (prepend_id "typ_" (id_of_kid kid))) typ | Typ_internal_unknown -> assert false | Typ_bidir _ -> assert false | Typ_fn _ -> assert false | Typ_exist _ -> assert false (* Typ_exist should be re-written *) + and undefined_of_typ_args mwords l annot (A_aux (typ_arg_aux, _)) = match typ_arg_aux with | A_nexp n -> [E_aux (E_sizeof n, (l, annot (atom_typ n)))] @@ -1597,32 +1483,28 @@ and undefined_of_typ_args mwords l annot (A_aux (typ_arg_aux, _)) = | A_bool nc -> [E_aux (E_constraint nc, (l, annot (atom_bool_typ nc)))] | A_order _ -> [] -let destruct_pexp (Pat_aux (pexp,ann)) = +let destruct_pexp (Pat_aux (pexp, ann)) = match pexp with - | Pat_exp (pat,exp) -> pat,None,exp,ann - | Pat_when (pat,guard,exp) -> pat,Some guard,exp,ann + | Pat_exp (pat, exp) -> (pat, None, exp, ann) + | Pat_when (pat, guard, exp) -> (pat, Some guard, exp, ann) -let construct_pexp (pat,guard,exp,ann) = - match guard with - | None -> Pat_aux (Pat_exp (pat,exp),ann) - | Some guard -> Pat_aux (Pat_when (pat,guard,exp),ann) +let construct_pexp (pat, guard, exp, ann) = + match guard with None -> Pat_aux (Pat_exp (pat, exp), ann) | Some guard -> Pat_aux (Pat_when (pat, guard, exp), ann) -let destruct_mpexp (MPat_aux (mpexp,ann)) = - match mpexp with - | MPat_pat mpat -> mpat,None,ann - | MPat_when (mpat,guard) -> mpat,Some guard,ann +let destruct_mpexp (MPat_aux (mpexp, ann)) = + match mpexp with MPat_pat mpat -> (mpat, None, ann) | MPat_when (mpat, guard) -> (mpat, Some guard, ann) -let construct_mpexp (mpat,guard,ann) = - match guard with - | None -> MPat_aux (MPat_pat mpat,ann) - | Some guard -> MPat_aux (MPat_when (mpat,guard),ann) +let construct_mpexp (mpat, guard, ann) = + match guard with None -> MPat_aux (MPat_pat mpat, ann) | Some guard -> MPat_aux (MPat_when (mpat, guard), ann) let is_valspec id = function | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id', _, _), _)), _) when Id.compare id id' = 0 -> true | _ -> false let is_fundef id = function - | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, FCL_aux (FCL_funcl (id', _), _) :: _), _)), _) when Id.compare id' id = 0 -> true + | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, FCL_aux (FCL_funcl (id', _), _) :: _), _)), _) + when Id.compare id' id = 0 -> + true | _ -> false let rename_valspec id (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) = @@ -1642,8 +1524,7 @@ let rec split_defs' f defs acc = let split_defs f defs = match split_defs' f defs [] with | None -> None - | Some (pre_defs, def, post_defs) -> - Some (List.rev pre_defs, def, post_defs) + | Some (pre_defs, def, post_defs) -> Some (List.rev pre_defs, def, post_defs) let append_ast ast1 ast2 = { defs = ast1.defs @ ast2.defs; comments = ast1.comments @ ast2.comments } let append_ast_defs ast defs = { ast with defs = ast.defs @ defs } @@ -1653,73 +1534,55 @@ let type_union_id (Tu_aux (Tu_ty_id (_, id), _)) = id let rec subst id value (E_aux (e_aux, annot) as exp) = let wrap e_aux = E_aux (e_aux, annot) in - let e_aux = match e_aux with + let e_aux = + match e_aux with | E_block exps -> E_block (List.map (subst id value) exps) | E_id id' -> if Id.compare id id' = 0 then unaux_exp value else E_id id' | E_lit lit -> E_lit lit | E_typ (typ, exp) -> E_typ (typ, subst id value exp) - | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps) | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2) - | E_tuple exps -> E_tuple (List.map (subst id value) exps) - - | E_if (cond, then_exp, else_exp) -> - E_if (subst id value cond, subst id value then_exp, subst id value else_exp) - + | E_if (cond, then_exp, else_exp) -> E_if (subst id value cond, subst id value then_exp, subst id value else_exp) | E_loop (loop, measure, cond, body) -> - E_loop (loop, subst_measure id value measure, subst id value cond, subst id value body) - | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 -> - E_for (id', exp1, exp2, exp3, order, body) + E_loop (loop, subst_measure id value measure, subst id value cond, subst id value body) + | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 -> E_for (id', exp1, exp2, exp3, order, body) | E_for (id', exp1, exp2, exp3, order, body) -> - E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body) - + E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body) | E_vector exps -> E_vector (List.map (subst id value) exps) | E_vector_access (exp1, exp2) -> E_vector_access (subst id value exp1, subst id value exp2) - | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3) - | E_vector_update (exp1, exp2, exp3) -> E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3) + | E_vector_subrange (exp1, exp2, exp3) -> + E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3) + | E_vector_update (exp1, exp2, exp3) -> + E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3) | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> - E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4) + E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4) | E_vector_append (exp1, exp2) -> E_vector_append (subst id value exp1, subst id value exp2) - | E_list exps -> E_list (List.map (subst id value) exps) | E_cons (exp1, exp2) -> E_cons (subst id value exp1, subst id value exp2) - | E_struct fexps -> E_struct (List.map (subst_fexp id value) fexps) | E_struct_update (exp, fexps) -> E_struct_update (subst id value exp, List.map (subst_fexp id value) fexps) | E_field (exp, id') -> E_field (subst id value exp, id') - - | E_match (exp, pexps) -> - E_match (subst id value exp, List.map (subst_pexp id value) pexps) - + | E_match (exp, pexps) -> E_match (subst id value exp, List.map (subst_pexp id value) pexps) | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) -> - E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot), - if IdSet.mem id (pat_ids pat) then body else subst id value body) - + E_let + ( LB_aux (LB_val (pat, subst id value bind), lb_annot), + if IdSet.mem id (pat_ids pat) then body else subst id value body + ) | E_assign (lexp, exp) -> E_assign (subst_lexp id value lexp, subst id value exp) (* Shadowing... *) - (* Should be re-written *) | E_sizeof nexp -> E_sizeof nexp | E_constraint nc -> E_constraint nc - | E_return exp -> E_return (subst id value exp) | E_exit exp -> E_exit (subst id value exp) - (* id should always be immutable while id' must be mutable register name so should be ok to never substitute here *) | E_ref id' -> E_ref id' | E_throw exp -> E_throw (subst id value exp) - - | E_try (exp, pexps) -> - E_try (subst id value exp, List.map (subst_pexp id value) pexps) - + | E_try (exp, pexps) -> E_try (subst id value exp, List.map (subst_pexp id value) pexps) | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2) - | E_internal_value v -> E_internal_value v - | E_var (lexp, exp1, exp2) -> E_var (subst_lexp id value lexp, subst id value exp1, subst id value exp2) - | E_internal_assume (nc, exp) -> E_internal_assume (nc, subst id value exp) - | E_internal_plet _ | E_internal_return _ -> failwith ("subst " ^ string_of_exp exp) in wrap e_aux @@ -1730,7 +1593,8 @@ and subst_measure id value (Measure_aux (m_aux, l)) = | Measure_some exp -> Measure_aux (Measure_some (subst id value exp), l) and subst_pexp id value (Pat_aux (pexp_aux, annot)) = - let pexp_aux = match pexp_aux with + let pexp_aux = + match pexp_aux with | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp) | Pat_exp (pat, exp) -> Pat_exp (pat, subst id value exp) | Pat_when (pat, guard, exp) when IdSet.mem id (pat_ids pat) -> Pat_when (pat, guard, exp) @@ -1738,66 +1602,68 @@ and subst_pexp id value (Pat_aux (pexp_aux, annot)) = in Pat_aux (pexp_aux, annot) -and subst_fexp id value (FE_aux (FE_fexp (id', exp), annot)) = - FE_aux (FE_fexp (id', subst id value exp), annot) +and subst_fexp id value (FE_aux (FE_fexp (id', exp), annot)) = FE_aux (FE_fexp (id', subst id value exp), annot) and subst_lexp id value (LE_aux (lexp_aux, annot)) = let wrap lexp_aux = LE_aux (lexp_aux, annot) in - let lexp_aux = match lexp_aux with + let lexp_aux = + match lexp_aux with | LE_deref exp -> LE_deref (subst id value exp) | LE_id id' -> LE_id id' | LE_app (f, exps) -> LE_app (f, List.map (subst id value) exps) | LE_typ (typ, id') -> LE_typ (typ, id') | LE_tuple lexps -> LE_tuple (List.map (subst_lexp id value) lexps) | LE_vector (lexp, exp) -> LE_vector (subst_lexp id value lexp, subst id value exp) - | LE_vector_range (lexp, exp1, exp2) -> LE_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2) - | LE_vector_concat lexps -> - LE_vector_concat (List.map (subst_lexp id value) lexps) + | LE_vector_range (lexp, exp1, exp2) -> + LE_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2) + | LE_vector_concat lexps -> LE_vector_concat (List.map (subst_lexp id value) lexps) | LE_field (lexp, id') -> LE_field (subst_lexp id value lexp, id') in wrap lexp_aux let hex_to_bin hex = - Util.string_to_list hex - |> List.map Sail_lib.hex_char - |> List.concat - |> List.map Sail_lib.char_of_bit - |> (fun bits -> String.init (List.length bits) (List.nth bits)) + Util.string_to_list hex |> List.map Sail_lib.hex_char |> List.concat |> List.map Sail_lib.char_of_bit |> fun bits -> + String.init (List.length bits) (List.nth bits) let explode s = let rec exp i l = if i < 0 then l else exp (i - 1) (s.[i] :: l) in exp (String.length s - 1) [] let vector_string_to_bit_list (L_aux (lit, l)) = - let hexchar_to_binlist = function - | '0' -> ['0';'0';'0';'0'] - | '1' -> ['0';'0';'0';'1'] - | '2' -> ['0';'0';'1';'0'] - | '3' -> ['0';'0';'1';'1'] - | '4' -> ['0';'1';'0';'0'] - | '5' -> ['0';'1';'0';'1'] - | '6' -> ['0';'1';'1';'0'] - | '7' -> ['0';'1';'1';'1'] - | '8' -> ['1';'0';'0';'0'] - | '9' -> ['1';'0';'0';'1'] - | 'A' -> ['1';'0';'1';'0'] - | 'B' -> ['1';'0';'1';'1'] - | 'C' -> ['1';'1';'0';'0'] - | 'D' -> ['1';'1';'0';'1'] - | 'E' -> ['1';'1';'1';'0'] - | 'F' -> ['1';'1';'1';'1'] - | _ -> raise (Reporting.err_unreachable l __POS__ "hexchar_to_binlist given unrecognized character") in - - let s_bin = match lit with + | '0' -> ['0'; '0'; '0'; '0'] + | '1' -> ['0'; '0'; '0'; '1'] + | '2' -> ['0'; '0'; '1'; '0'] + | '3' -> ['0'; '0'; '1'; '1'] + | '4' -> ['0'; '1'; '0'; '0'] + | '5' -> ['0'; '1'; '0'; '1'] + | '6' -> ['0'; '1'; '1'; '0'] + | '7' -> ['0'; '1'; '1'; '1'] + | '8' -> ['1'; '0'; '0'; '0'] + | '9' -> ['1'; '0'; '0'; '1'] + | 'A' -> ['1'; '0'; '1'; '0'] + | 'B' -> ['1'; '0'; '1'; '1'] + | 'C' -> ['1'; '1'; '0'; '0'] + | 'D' -> ['1'; '1'; '0'; '1'] + | 'E' -> ['1'; '1'; '1'; '0'] + | 'F' -> ['1'; '1'; '1'; '1'] + | _ -> raise (Reporting.err_unreachable l __POS__ "hexchar_to_binlist given unrecognized character") + in + + let s_bin = + match lit with | L_hex s_hex -> List.flatten (List.map hexchar_to_binlist (explode (String.uppercase_ascii s_hex))) | L_bin s_bin -> explode s_bin - | _ -> raise (Reporting.err_unreachable l __POS__ "s_bin given non vector literal") in - - List.map (function '0' -> L_aux (L_zero, gen_loc l) - | '1' -> L_aux (L_one, gen_loc l) - | _ -> raise (Reporting.err_unreachable (gen_loc l) __POS__ "binary had non-zero or one")) s_bin + | _ -> raise (Reporting.err_unreachable l __POS__ "s_bin given non vector literal") + in + List.map + (function + | '0' -> L_aux (L_zero, gen_loc l) + | '1' -> L_aux (L_one, gen_loc l) + | _ -> raise (Reporting.err_unreachable (gen_loc l) __POS__ "binary had non-zero or one") + ) + s_bin (* Functions for working with locations *) @@ -1807,21 +1673,17 @@ let locate_kid f (Kid_aux (name, l)) = Kid_aux (name, f l) let locate_kind f (K_aux (kind, l)) = K_aux (kind, f l) -let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) = - KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l) +let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) = KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l) let locate_lit f (L_aux (lit, l)) = L_aux (lit, f l) let locate_order f (Ord_aux (ord_aux, l)) = - let ord_aux = match ord_aux with - | Ord_inc -> Ord_inc - | Ord_dec -> Ord_dec - | Ord_var v -> Ord_var (locate_kid f v) - in + let ord_aux = match ord_aux with Ord_inc -> Ord_inc | Ord_dec -> Ord_dec | Ord_var v -> Ord_var (locate_kid f v) in Ord_aux (ord_aux, f l) let rec locate_nexp f (Nexp_aux (nexp_aux, l)) = - let nexp_aux = match nexp_aux with + let nexp_aux = + match nexp_aux with | Nexp_id id -> Nexp_id (locate_id f id) | Nexp_var kid -> Nexp_var (locate_kid f kid) | Nexp_constant n -> Nexp_constant n @@ -1835,7 +1697,8 @@ let rec locate_nexp f (Nexp_aux (nexp_aux, l)) = Nexp_aux (nexp_aux, f l) let rec locate_nc f (NC_aux (nc_aux, l)) = - let nc_aux = match nc_aux with + let nc_aux = + match nc_aux with | NC_equal (nexp1, nexp2) -> NC_equal (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (locate_nexp f nexp1, locate_nexp f nexp2) @@ -1853,21 +1716,23 @@ let rec locate_nc f (NC_aux (nc_aux, l)) = NC_aux (nc_aux, f l) and locate_typ f (Typ_aux (typ_aux, l)) = - let typ_aux = match typ_aux with + let typ_aux = + match typ_aux with | Typ_internal_unknown -> Typ_internal_unknown | Typ_id id -> Typ_id (locate_id f id) | Typ_var kid -> Typ_var (locate_kid f kid) - | Typ_fn (arg_typs, ret_typ) -> - Typ_fn (List.map (locate_typ f) arg_typs, locate_typ f ret_typ) + | Typ_fn (arg_typs, ret_typ) -> Typ_fn (List.map (locate_typ f) arg_typs, locate_typ f ret_typ) | Typ_bidir (typ1, typ2) -> Typ_bidir (locate_typ f typ1, locate_typ f typ2) | Typ_tuple typs -> Typ_tuple (List.map (locate_typ f) typs) - | Typ_exist (kopts, constr, typ) -> Typ_exist (List.map (locate_kinded_id f) kopts, locate_nc f constr, locate_typ f typ) + | Typ_exist (kopts, constr, typ) -> + Typ_exist (List.map (locate_kinded_id f) kopts, locate_nc f constr, locate_typ f typ) | Typ_app (id, typ_args) -> Typ_app (locate_id f id, List.map (locate_typ_arg f) typ_args) in Typ_aux (typ_aux, f l) and locate_typ_arg f (A_aux (typ_arg_aux, l)) = - let typ_arg_aux = match typ_arg_aux with + let typ_arg_aux = + match typ_arg_aux with | A_nexp nexp -> A_nexp (locate_nexp f nexp) | A_typ typ -> A_typ (locate_typ f typ) | A_order ord -> A_order (locate_order f ord) @@ -1876,15 +1741,18 @@ and locate_typ_arg f (A_aux (typ_arg_aux, l)) = A_aux (typ_arg_aux, f l) let rec locate_typ_pat f (TP_aux (tp_aux, l)) = - let tp_aux = match tp_aux with + let tp_aux = + match tp_aux with | TP_wild -> TP_wild | TP_var kid -> TP_var (locate_kid f kid) | TP_app (id, tps) -> TP_app (locate_id f id, List.map (locate_typ_pat f) tps) in TP_aux (tp_aux, f l) -let rec locate_pat : 'a. (l -> l) -> 'a pat -> 'a pat = fun f (P_aux (p_aux, (l, annot))) -> - let p_aux = match p_aux with +let rec locate_pat : 'a. (l -> l) -> 'a pat -> 'a pat = + fun f (P_aux (p_aux, (l, annot))) -> + let p_aux = + match p_aux with | P_lit lit -> P_lit (locate_lit f lit) | P_wild -> P_wild | P_or (pat1, pat2) -> P_or (locate_pat f pat1, locate_pat f pat2) @@ -1904,8 +1772,10 @@ let rec locate_pat : 'a. (l -> l) -> 'a pat -> 'a pat = fun f (P_aux (p_aux, (l, in P_aux (p_aux, (f l, annot)) -let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, annot))) -> - let e_aux = match e_aux with +let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = + fun f (E_aux (e_aux, (l, annot))) -> + let e_aux = + match e_aux with | E_block exps -> E_block (List.map (locate f) exps) | E_id id -> E_id (locate_id f id) | E_lit lit -> E_lit (locate_lit f lit) @@ -1916,15 +1786,14 @@ let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, ann | E_if (cond_exp, then_exp, else_exp) -> E_if (locate f cond_exp, locate f then_exp, locate f else_exp) | E_loop (loop, measure, cond, body) -> E_loop (loop, locate_measure f measure, locate f cond, locate f body) | E_for (id, exp1, exp2, exp3, ord, exp4) -> - E_for (locate_id f id, locate f exp1, locate f exp2, locate f exp3, ord, locate f exp4) + E_for (locate_id f id, locate f exp1, locate f exp2, locate f exp3, ord, locate f exp4) | E_vector exps -> E_vector (List.map (locate f) exps) | E_vector_access (exp1, exp2) -> E_vector_access (locate f exp1, locate f exp2) | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (locate f exp1, locate f exp2, locate f exp3) | E_vector_update (exp1, exp2, exp3) -> E_vector_update (locate f exp1, locate f exp2, locate f exp3) | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> - E_vector_update_subrange (locate f exp1, locate f exp2, locate f exp3, locate f exp4) - | E_vector_append (exp1, exp2) -> - E_vector_append (locate f exp1, locate f exp2) + E_vector_update_subrange (locate f exp1, locate f exp2, locate f exp3, locate f exp4) + | E_vector_append (exp1, exp2) -> E_vector_append (locate f exp1, locate f exp2) | E_list exps -> E_list (List.map (locate f) exps) | E_cons (exp1, exp2) -> E_cons (locate f exp1, locate f exp2) | E_struct fexps -> E_struct (List.map (locate_fexp f) fexps) @@ -1949,24 +1818,27 @@ let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, ann in E_aux (e_aux, (f l, annot)) -and locate_measure : 'a. (l -> l) -> 'a internal_loop_measure -> 'a internal_loop_measure = fun f (Measure_aux (m, l)) -> - let m = match m with - | Measure_none -> Measure_none - | Measure_some exp -> Measure_some (locate f exp) - in Measure_aux (m, f l) +and locate_measure : 'a. (l -> l) -> 'a internal_loop_measure -> 'a internal_loop_measure = + fun f (Measure_aux (m, l)) -> + let m = match m with Measure_none -> Measure_none | Measure_some exp -> Measure_some (locate f exp) in + Measure_aux (m, f l) -and locate_letbind : 'a. (l -> l) -> 'a letbind -> 'a letbind = fun f (LB_aux (LB_val (pat, exp), (l, annot))) -> - LB_aux (LB_val (locate_pat f pat, locate f exp), (f l, annot)) +and locate_letbind : 'a. (l -> l) -> 'a letbind -> 'a letbind = + fun f (LB_aux (LB_val (pat, exp), (l, annot))) -> LB_aux (LB_val (locate_pat f pat, locate f exp), (f l, annot)) -and locate_pexp : 'a. (l -> l) -> 'a pexp -> 'a pexp = fun f (Pat_aux (pexp_aux, (l, annot))) -> - let pexp_aux = match pexp_aux with +and locate_pexp : 'a. (l -> l) -> 'a pexp -> 'a pexp = + fun f (Pat_aux (pexp_aux, (l, annot))) -> + let pexp_aux = + match pexp_aux with | Pat_exp (pat, exp) -> Pat_exp (locate_pat f pat, locate f exp) | Pat_when (pat, guard, exp) -> Pat_when (locate_pat f pat, locate f guard, locate f exp) in Pat_aux (pexp_aux, (f l, annot)) -and locate_lexp : 'a. (l -> l) -> 'a lexp -> 'a lexp = fun f (LE_aux (lexp_aux, (l, annot))) -> - let lexp_aux = match lexp_aux with +and locate_lexp : 'a. (l -> l) -> 'a lexp -> 'a lexp = + fun f (LE_aux (lexp_aux, (l, annot))) -> + let lexp_aux = + match lexp_aux with | LE_id id -> LE_id (locate_id f id) | LE_deref exp -> LE_deref (locate f exp) | LE_app (id, exps) -> LE_app (locate_id f id, List.map (locate f) exps) @@ -1979,8 +1851,8 @@ and locate_lexp : 'a. (l -> l) -> 'a lexp -> 'a lexp = fun f (LE_aux (lexp_aux, in LE_aux (lexp_aux, (f l, annot)) -and locate_fexp : 'a. (l -> l) -> 'a fexp -> 'a fexp = fun f (FE_aux (FE_fexp (id, exp), (l, annot))) -> - FE_aux (FE_fexp (locate_id f id, locate f exp), (f l, annot)) +and locate_fexp : 'a. (l -> l) -> 'a fexp -> 'a fexp = + fun f (FE_aux (FE_fexp (id, exp), (l, annot))) -> FE_aux (FE_fexp (locate_id f id, locate f exp), (f l, annot)) let unique_ref = ref 0 @@ -1992,40 +1864,33 @@ let unique l = let extern_assoc backend ext = match ext with | None -> None - | Some ext -> - match List.assoc_opt backend ext.bindings with - | Some f -> Some f - | None -> List.assoc_opt "_" ext.bindings + | Some ext -> ( + match List.assoc_opt backend ext.bindings with Some f -> Some f | None -> List.assoc_opt "_" ext.bindings + ) (**************************************************************************) (* 1. Substitutions *) (**************************************************************************) let order_subst_aux sv subst = function - | Ord_var kid -> - begin match subst with - | A_aux (A_order ord, _) when Kid.compare kid sv = 0 -> - unaux_order ord - | _ -> Ord_var kid - end + | Ord_var kid -> begin + match subst with A_aux (A_order ord, _) when Kid.compare kid sv = 0 -> unaux_order ord | _ -> Ord_var kid + end | Ord_inc -> Ord_inc | Ord_dec -> Ord_dec let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) let rec nexp_subst sv subst = function - | (Nexp_aux (Nexp_var kid, _)) as nexp -> - begin match subst with - | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> n - | _ -> nexp - end + | Nexp_aux (Nexp_var kid, _) as nexp -> begin + match subst with A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> n | _ -> nexp + end | Nexp_aux (nexp, l) -> Nexp_aux (nexp_subst_aux sv subst nexp, l) + and nexp_subst_aux sv subst = function - | Nexp_var kid -> - begin match subst with - | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> unaux_nexp n - | _ -> Nexp_var kid - end + | Nexp_var kid -> begin + match subst with A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> unaux_nexp n | _ -> Nexp_var kid + end | Nexp_id id -> Nexp_id id | Nexp_constant c -> Nexp_constant c | Nexp_times (nexp1, nexp2) -> Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) @@ -2038,9 +1903,10 @@ and nexp_subst_aux sv subst = function let rec nexp_set_to_or l subst = function | [] -> raise (Reporting.err_unreachable l __POS__ "Empty set in constraint") | [int] -> NC_equal (subst, nconstant int) - | (int :: ints) -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints)) + | int :: ints -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints)) let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l) + and constraint_subst_aux l sv subst = function | NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2) @@ -2048,46 +1914,39 @@ and constraint_subst_aux l sv subst = function | NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_lt (n1, n2) -> NC_bounded_lt (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) - | NC_set (kid, ints) as set_nc -> - begin match subst with - | A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)), _) when Kid.compare kid sv = 0 -> - NC_set (kid', ints) - | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> - nexp_set_to_or l n ints - | _ -> set_nc - end + | NC_set (kid, ints) as set_nc -> begin + match subst with + | A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _) when Kid.compare kid sv = 0 -> NC_set (kid', ints) + | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> nexp_set_to_or l n ints + | _ -> set_nc + end | NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2) | NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2) | NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args) - | NC_var kid -> - begin match subst with - | A_aux (A_bool nc, _) when Kid.compare kid sv = 0 -> - unaux_constraint nc - | _ -> NC_var kid - end + | NC_var kid -> begin + match subst with A_aux (A_bool nc, _) when Kid.compare kid sv = 0 -> unaux_constraint nc | _ -> NC_var kid + end | NC_false -> NC_false | NC_true -> NC_true and typ_subst sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_aux sv subst typ, l) + and typ_subst_aux sv subst = function | Typ_internal_unknown -> Typ_internal_unknown | Typ_id v -> Typ_id v - | Typ_var kid -> - begin match subst with - | A_aux (A_typ typ, _) when Kid.compare kid sv = 0 -> - unaux_typ typ - | _ -> Typ_var kid - end + | Typ_var kid -> begin + match subst with A_aux (A_typ typ, _) when Kid.compare kid sv = 0 -> unaux_typ typ | _ -> Typ_var kid + end | Typ_fn (arg_typs, ret_typ) -> Typ_fn (List.map (typ_subst sv subst) arg_typs, typ_subst sv subst ret_typ) | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2) | Typ_tuple typs -> Typ_tuple (List.map (typ_subst sv subst) typs) | Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args) | Typ_exist (kopts, nc, typ) when KidSet.mem sv (KidSet.of_list (List.map kopt_kid kopts)) -> - Typ_exist (kopts, nc, typ) - | Typ_exist (kopts, nc, typ) -> - Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ) + Typ_exist (kopts, nc, typ) + | Typ_exist (kopts, nc, typ) -> Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ) and typ_arg_subst sv subst (A_aux (arg, l)) = A_aux (typ_arg_subst_aux sv subst arg, l) + and typ_arg_subst_aux sv subst = function | A_nexp nexp -> A_nexp (nexp_subst sv subst nexp) | A_typ typ -> A_typ (typ_subst sv subst typ) @@ -2105,10 +1964,8 @@ let kopt_subst_kid sv subst (KOpt_aux (KOpt_kind (k, kid), l) as orig) = if Kid.compare kid sv = 0 then KOpt_aux (KOpt_kind (k, subst), l) else orig let quant_item_subst_kid_aux sv subst = function - | QI_id kopt -> - QI_id (kopt_subst_kid sv subst kopt) - | QI_constraint nc -> - QI_constraint (subst_kid constraint_subst sv subst nc) + | QI_id kopt -> QI_id (kopt_subst_kid sv subst kopt) + | QI_constraint nc -> QI_constraint (subst_kid constraint_subst sv subst nc) let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l) @@ -2119,201 +1976,166 @@ let typquant_subst_kid_aux sv subst = function let typquant_subst_kid sv subst (TypQ_aux (typq, l)) = TypQ_aux (typquant_subst_kid_aux sv subst typq, l) let subst_kids_nexp substs nexp = - let rec s_snexp substs (Nexp_aux (ne,l) as nexp) = - let re ne = Nexp_aux (ne,l) in + let rec s_snexp substs (Nexp_aux (ne, l) as nexp) = + let re ne = Nexp_aux (ne, l) in let s_snexp = s_snexp substs in match ne with - | Nexp_var v -> - (try KBindings.find v substs - with Not_found -> nexp) - | Nexp_id _ - | Nexp_constant _ -> nexp + | Nexp_var v -> ( + try KBindings.find v substs with Not_found -> nexp + ) + | Nexp_id _ | Nexp_constant _ -> nexp | Nexp_times (n1, n2) -> re (Nexp_times (s_snexp n1, s_snexp n2)) - | Nexp_sum (n1, n2) -> re (Nexp_sum (s_snexp n1, s_snexp n2)) + | Nexp_sum (n1, n2) -> re (Nexp_sum (s_snexp n1, s_snexp n2)) | Nexp_minus (n1, n2) -> re (Nexp_minus (s_snexp n1, s_snexp n2)) | Nexp_exp ne -> re (Nexp_exp (s_snexp ne)) | Nexp_neg ne -> re (Nexp_neg (s_snexp ne)) - | Nexp_app (id,args) -> re (Nexp_app (id,List.map s_snexp args)) - in s_snexp substs nexp + | Nexp_app (id, args) -> re (Nexp_app (id, List.map s_snexp args)) + in + s_snexp substs nexp let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg = - let rec subst_kids_nc substs (NC_aux (nc,l) as n_constraint) = + let rec subst_kids_nc substs (NC_aux (nc, l) as n_constraint) = let snexp nexp = subst_kids_nexp substs nexp in let snc nc = subst_kids_nc substs nc in - let re nc = NC_aux (nc,l) in + let re nc = NC_aux (nc, l) in match nc with - | NC_equal (n1,n2) -> re (NC_equal (snexp n1, snexp n2)) - | NC_bounded_ge (n1,n2) -> re (NC_bounded_ge (snexp n1, snexp n2)) - | NC_bounded_gt (n1,n2) -> re (NC_bounded_gt (snexp n1, snexp n2)) - | NC_bounded_le (n1,n2) -> re (NC_bounded_le (snexp n1, snexp n2)) - | NC_bounded_lt (n1,n2) -> re (NC_bounded_lt (snexp n1, snexp n2)) - | NC_not_equal (n1,n2) -> re (NC_not_equal (snexp n1, snexp n2)) - | NC_set (kid, is) -> - begin - match KBindings.find kid substs with - | Nexp_aux (Nexp_constant i,_) -> + | NC_equal (n1, n2) -> re (NC_equal (snexp n1, snexp n2)) + | NC_bounded_ge (n1, n2) -> re (NC_bounded_ge (snexp n1, snexp n2)) + | NC_bounded_gt (n1, n2) -> re (NC_bounded_gt (snexp n1, snexp n2)) + | NC_bounded_le (n1, n2) -> re (NC_bounded_le (snexp n1, snexp n2)) + | NC_bounded_lt (n1, n2) -> re (NC_bounded_lt (snexp n1, snexp n2)) + | NC_not_equal (n1, n2) -> re (NC_not_equal (snexp n1, snexp n2)) + | NC_set (kid, is) -> begin + match KBindings.find kid substs with + | Nexp_aux (Nexp_constant i, _) -> if List.exists (fun j -> Big_int.equal i j) is then re NC_true else re NC_false - | nexp -> - begin match List.rev is with + | nexp -> begin + match List.rev is with | i :: is -> - let equal_num i = re (NC_equal (nexp, nconstant i)) in - List.fold_left (fun nc i -> re (NC_or (equal_num i, nc))) (equal_num i) is + let equal_num i = re (NC_equal (nexp, nconstant i)) in + List.fold_left (fun nc i -> re (NC_or (equal_num i, nc))) (equal_num i) is | [] -> re NC_false - end - | exception Not_found -> n_constraint - end - | NC_or (nc1,nc2) -> re (NC_or (snc nc1, snc nc2)) - | NC_and (nc1,nc2) -> re (NC_and (snc nc1, snc nc2)) - | NC_true - | NC_false - -> n_constraint + end + | exception Not_found -> n_constraint + end + | NC_or (nc1, nc2) -> re (NC_or (snc nc1, snc nc2)) + | NC_and (nc1, nc2) -> re (NC_and (snc nc1, snc nc2)) + | NC_true | NC_false -> n_constraint | NC_var kid -> re (NC_var kid) - | NC_app (f, args) -> - re (NC_app (f, List.map (s_starg substs) args)) - and s_styp substs ((Typ_aux (t,l)) as ty) = - let re t = Typ_aux (t,l) in + | NC_app (f, args) -> re (NC_app (f, List.map (s_starg substs) args)) + and s_styp substs (Typ_aux (t, l) as ty) = + let re t = Typ_aux (t, l) in match t with - | Typ_id _ - | Typ_var _ - -> ty - | Typ_fn (t1,t2) -> re (Typ_fn (List.map (s_styp substs) t1, s_styp substs t2)) - | Typ_bidir (t1,t2) -> re (Typ_bidir (s_styp substs t1, s_styp substs t2)) + | Typ_id _ | Typ_var _ -> ty + | Typ_fn (t1, t2) -> re (Typ_fn (List.map (s_styp substs) t1, s_styp substs t2)) + | Typ_bidir (t1, t2) -> re (Typ_bidir (s_styp substs t1, s_styp substs t2)) | Typ_tuple ts -> re (Typ_tuple (List.map (s_styp substs) ts)) - | Typ_app (id,tas) -> re (Typ_app (id,List.map (s_starg substs) tas)) - | Typ_exist (kopts,nc,t) -> - let substs = List.fold_left (fun sub kopt -> KBindings.remove (kopt_kid kopt) sub) substs kopts in - re (Typ_exist (kopts,subst_kids_nc substs nc,s_styp substs t)) + | Typ_app (id, tas) -> re (Typ_app (id, List.map (s_starg substs) tas)) + | Typ_exist (kopts, nc, t) -> + let substs = List.fold_left (fun sub kopt -> KBindings.remove (kopt_kid kopt) sub) substs kopts in + re (Typ_exist (kopts, subst_kids_nc substs nc, s_styp substs t)) | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" - and s_starg substs (A_aux (ta,l) as targ) = + and s_starg substs (A_aux (ta, l) as targ) = match ta with - | A_nexp ne -> A_aux (A_nexp (subst_kids_nexp substs ne),l) - | A_typ t -> A_aux (A_typ (s_styp substs t),l) + | A_nexp ne -> A_aux (A_nexp (subst_kids_nexp substs ne), l) + | A_typ t -> A_aux (A_typ (s_styp substs t), l) | A_order _ -> targ | A_bool nc -> A_aux (A_bool (subst_kids_nc substs nc), l) - in subst_kids_nc, s_styp, s_starg + in + (subst_kids_nc, s_styp, s_starg) let before p1 p2 = let open Lexing in p1.pos_fname = p2.pos_fname && p1.pos_cnum <= p2.pos_cnum let subloc sl l = - match sl, Reporting.simp_loc l with + match (sl, Reporting.simp_loc l) with | _, None -> false | None, _ -> false - | Some (p1a, p1b), Some (p2a, p2b) -> - before p2a p1a && before p1b p2b + | Some (p1a, p1b), Some (p2a, p2b) -> before p2a p1a && before p1b p2b let rec option_mapm f = function | [] -> None - | x :: xs -> - begin match f x with - | Some y -> Some y - | None -> option_mapm f xs - end + | x :: xs -> begin match f x with Some y -> Some y | None -> option_mapm f xs end let option_chain opt1 opt2 = - begin match opt1 with - | None -> opt2 - | _ -> opt1 + begin + match opt1 with None -> opt2 | _ -> opt1 end let rec find_annot_exp sl (E_aux (aux, (l, annot))) = - if not (subloc sl l) then None else - let result = match aux with - | E_block exps | E_tuple exps -> - option_mapm (find_annot_exp sl) exps - | E_app (_, exps) -> - option_mapm (find_annot_exp sl) exps + if not (subloc sl l) then None + else ( + let result = + match aux with + | E_block exps | E_tuple exps -> option_mapm (find_annot_exp sl) exps + | E_app (_, exps) -> option_mapm (find_annot_exp sl) exps | E_let (LB_aux (LB_val (pat, exp), _), body) -> - option_chain (find_annot_pat sl pat) (option_mapm (find_annot_exp sl) [exp; body]) - | E_assign (lexp, exp) -> - option_chain (find_annot_lexp sl lexp) (find_annot_exp sl exp) - | E_var (lexp, exp1, exp2) -> - option_chain (find_annot_lexp sl lexp) (option_mapm (find_annot_exp sl) [exp1; exp2]) - | E_if (cond_exp, then_exp, else_exp) -> - option_mapm (find_annot_exp sl) [cond_exp; then_exp; else_exp] + option_chain (find_annot_pat sl pat) (option_mapm (find_annot_exp sl) [exp; body]) + | E_assign (lexp, exp) -> option_chain (find_annot_lexp sl lexp) (find_annot_exp sl exp) + | E_var (lexp, exp1, exp2) -> option_chain (find_annot_lexp sl lexp) (option_mapm (find_annot_exp sl) [exp1; exp2]) + | E_if (cond_exp, then_exp, else_exp) -> option_mapm (find_annot_exp sl) [cond_exp; then_exp; else_exp] | E_match (exp, cases) | E_try (exp, cases) -> - option_chain (find_annot_exp sl exp) (option_mapm (find_annot_pexp sl) cases) - | E_return exp | E_typ (_, exp) -> - find_annot_exp sl exp + option_chain (find_annot_exp sl exp) (option_mapm (find_annot_pexp sl) cases) + | E_return exp | E_typ (_, exp) -> find_annot_exp sl exp | _ -> None in - match result with - | None -> Some (l, annot) - | _ -> result + match result with None -> Some (l, annot) | _ -> result + ) and find_annot_lexp sl (LE_aux (aux, (l, annot))) = - if not (subloc sl l) then None else - let result = match aux with + if not (subloc sl l) then None + else ( + let result = + match aux with | LE_vector_range (lexp, exp1, exp2) -> - option_chain (find_annot_lexp sl lexp) (option_mapm (find_annot_exp sl) [exp1; exp2]) - | LE_deref exp -> - find_annot_exp sl exp - | LE_tuple lexps -> - option_mapm (find_annot_lexp sl) lexps - | LE_app (_, exps) -> - option_mapm (find_annot_exp sl) exps + option_chain (find_annot_lexp sl lexp) (option_mapm (find_annot_exp sl) [exp1; exp2]) + | LE_deref exp -> find_annot_exp sl exp + | LE_tuple lexps -> option_mapm (find_annot_lexp sl) lexps + | LE_app (_, exps) -> option_mapm (find_annot_exp sl) exps | _ -> None in - match result with - | None -> Some (l, annot) - | _ -> result + match result with None -> Some (l, annot) | _ -> result + ) and find_annot_pat sl (P_aux (aux, (l, annot))) = - if not (subloc sl l) then None else - let result = match aux with - | P_vector_concat pats -> - option_mapm (find_annot_pat sl) pats - | _ -> None - in - match result with - | None -> Some (l, annot) - | _ -> result + if not (subloc sl l) then None + else ( + let result = match aux with P_vector_concat pats -> option_mapm (find_annot_pat sl) pats | _ -> None in + match result with None -> Some (l, annot) | _ -> result + ) and find_annot_pexp sl (Pat_aux (aux, (l, _))) = - if not (subloc sl l) then None else + if not (subloc sl l) then None + else ( match aux with - | Pat_exp (pat, exp) -> - option_chain (find_annot_pat sl pat) (find_annot_exp sl exp) - | Pat_when (pat, guard, exp) -> - option_chain (find_annot_pat sl pat) (option_mapm (find_annot_exp sl) [guard; exp]) + | Pat_exp (pat, exp) -> option_chain (find_annot_pat sl pat) (find_annot_exp sl exp) + | Pat_when (pat, guard, exp) -> option_chain (find_annot_pat sl pat) (option_mapm (find_annot_exp sl) [guard; exp]) + ) let find_annot_funcl sl (FCL_aux (FCL_funcl (_, pexp), (def_annot, annot))) = let l = def_annot.loc in - if not (subloc sl l) then None else - match find_annot_pexp sl pexp with - | None -> Some (l, annot) - | result -> result + if not (subloc sl l) then None else (match find_annot_pexp sl pexp with None -> Some (l, annot) | result -> result) let find_annot_fundef sl (FD_aux (FD_function (_, _, funcls), (l, annot))) = - if not (subloc sl l) then None else - match option_mapm (find_annot_funcl sl) funcls with - | None -> Some (l, annot) - | result -> result + if not (subloc sl l) then None + else (match option_mapm (find_annot_funcl sl) funcls with None -> Some (l, annot) | result -> result) let find_annot_scattered sl (SD_aux (aux, (l, annot))) = - if not (subloc sl l) then None else - let result = match aux with - | SD_funcl fcl -> find_annot_funcl sl fcl - | _ -> None - in - match result with - | None -> Some (l, annot) - | _ -> result + if not (subloc sl l) then None + else ( + let result = match aux with SD_funcl fcl -> find_annot_funcl sl fcl | _ -> None in + match result with None -> Some (l, annot) | _ -> result + ) let rec find_annot_defs sl = function - | DEF_aux (DEF_fundef fdef, _) :: defs -> - begin match find_annot_fundef sl fdef with - | None -> find_annot_defs sl defs - | result -> result - end - | DEF_aux (DEF_scattered sdef, _) :: defs -> - begin match find_annot_scattered sl sdef with - | None -> find_annot_defs sl defs - | result -> result - end - | _ :: defs -> - find_annot_defs sl defs + | DEF_aux (DEF_fundef fdef, _) :: defs -> begin + match find_annot_fundef sl fdef with None -> find_annot_defs sl defs | result -> result + end + | DEF_aux (DEF_scattered sdef, _) :: defs -> begin + match find_annot_scattered sl sdef with None -> find_annot_defs sl defs | result -> result + end + | _ :: defs -> find_annot_defs sl defs | [] -> None let find_annot_ast sl { defs; _ } = find_annot_defs sl defs @@ -2327,5 +2149,4 @@ let rec simple_string_of_loc = function | Parse_ast.Unique (n, l) -> "Unique(" ^ string_of_int n ^ ", " ^ simple_string_of_loc l ^ ")" | Parse_ast.Generated l -> "Generated(" ^ simple_string_of_loc l ^ ")" | Parse_ast.Hint (_, l1, l2) -> "Hint(_," ^ simple_string_of_loc l1 ^ "," ^ simple_string_of_loc l2 ^ ")" - | Parse_ast.Range (lx1,lx2) -> "Range(" ^ string_of_lx lx1 ^ "->" ^ string_of_lx lx2 ^ ")" - + | Parse_ast.Range (lx1, lx2) -> "Range(" ^ string_of_lx lx1 ^ "->" ^ string_of_lx lx2 ^ ")" diff --git a/src/lib/ast_util.mli b/src/lib/ast_util.mli index 26390756a..4c983dc0e 100644 --- a/src/lib/ast_util.mli +++ b/src/lib/ast_util.mli @@ -102,7 +102,7 @@ val add_def_attribute : l -> string -> string -> def_annot -> def_annot val get_def_attribute : string -> def_annot -> (l * string) option val def_annot_map_loc : (l -> l) -> def_annot -> def_annot - + (** The empty annotation (as a location + uannot pair). Should be used carefully because it can result in unhelpful error messgaes. However a common pattern is generating code with [no_annot], then adding location @@ -117,7 +117,7 @@ val gen_loc : Parse_ast.l -> Parse_ast.l val is_gen_loc : Parse_ast.l -> bool (** {1 Variable information} *) - + type mut = Immutable | Mutable (** [lvar] is the type of variables - they can either be registers, @@ -126,14 +126,14 @@ type mut = Immutable | Mutable type 'a lvar = Register of 'a | Enum of 'a | Local of mut * 'a | Unbound of id val is_unbound : 'a lvar -> bool - + (** Note: Partial function -- fails for {!Unbound} lvars *) val lvar_typ : ?loc:l -> 'a lvar -> 'a - + (** {1 Functions for building and destructuring untyped AST elements} *) (** {2 Functions for building untyped AST elements} *) - + val mk_id : string -> id val mk_kid : string -> kid val mk_ord : order_aux -> order @@ -149,7 +149,7 @@ val mk_lit : lit_aux -> lit val mk_lit_exp : lit_aux -> uannot exp val mk_typ_pat : typ_pat_aux -> typ_pat val mk_funcl : ?loc:l -> id -> uannot pat -> uannot exp -> uannot funcl -val mk_fundef : (uannot funcl) list -> uannot def +val mk_fundef : uannot funcl list -> uannot def val mk_val_spec : val_spec_aux -> uannot def val mk_typschm : typquant -> typ -> typschm val mk_typquant : quant_item list -> typquant @@ -160,7 +160,7 @@ val mk_fexp : id -> uannot exp -> uannot fexp val mk_letbind : uannot pat -> uannot exp -> uannot letbind val mk_kopt : ?loc:l -> kind_aux -> kid -> kinded_id val mk_def : ?loc:l -> 'a def_aux -> 'a def - + val inc_ord : order val dec_ord : order @@ -203,7 +203,7 @@ val is_typ_arg_nexp : typ_arg -> bool val is_typ_arg_typ : typ_arg -> bool val is_typ_arg_order : typ_arg -> bool val is_typ_arg_bool : typ_arg -> bool - + (** {2 Sail built-in types} *) val unknown_typ : typ @@ -252,7 +252,7 @@ val constraint_conj : n_constraint -> n_constraint list val constraint_disj : n_constraint -> n_constraint list type effect - + val no_effect : effect val monadic_effect : effect @@ -261,7 +261,7 @@ val effectful : effect -> bool val equal_effects : effect -> effect -> bool val subseteq_effects : effect -> effect -> bool val union_effects : effect -> effect -> effect - + (** {2 Functions for building numeric expressions} *) val nconstant : Big_int.num -> nexp @@ -330,13 +330,13 @@ module NC : sig type t = n_constraint val compare : n_constraint -> n_constraint -> int end - + (* NB: the comparison function does not expand synonyms *) module Typ : sig type t = typ val compare : typ -> typ -> int end - + module IdSet : sig include Set.S with type elt = id end @@ -376,7 +376,7 @@ end module TypMap : sig include Map.S with type key = typ end - + (** {1 Functions for working with type quantifiers} *) val quant_add : quant_item -> typquant -> typquant @@ -412,6 +412,7 @@ val map_ast_annot : ('a annot -> 'b annot) -> 'a ast -> 'b ast (** {1 Extract locations from terms} *) val id_loc : id -> Parse_ast.l + val kid_loc : kid -> Parse_ast.l val kopt_loc : kinded_id -> Parse_ast.l val typ_loc : typ -> Parse_ast.l @@ -455,7 +456,7 @@ val id_of_mapdef : 'a mapdef -> id val id_of_type_def : 'a type_def -> id val id_of_val_spec : 'a val_spec -> id val id_of_dec_spec : 'a dec_spec -> id - + (** {2 Functions for manipulating identifiers} *) val id_of_kid : kid -> id @@ -497,11 +498,11 @@ val undefined_of_typ : bool -> Ast.l -> (typ -> 'annot) -> typ -> 'annot exp val pattern_vector_subranges : 'a pat -> (Big_int.num * Big_int.num) list Bindings.t -val destruct_pexp : 'a pexp -> 'a pat * ('a exp) option * 'a exp * (Ast.l * 'a) -val construct_pexp : 'a pat * ('a exp) option * 'a exp * (Ast.l * 'a) -> 'a pexp +val destruct_pexp : 'a pexp -> 'a pat * 'a exp option * 'a exp * (Ast.l * 'a) +val construct_pexp : 'a pat * 'a exp option * 'a exp * (Ast.l * 'a) -> 'a pexp -val destruct_mpexp : 'a mpexp -> 'a mpat * ('a exp) option * (Ast.l * 'a) -val construct_mpexp : 'a mpat * ('a exp) option * (Ast.l * 'a) -> 'a mpexp +val destruct_mpexp : 'a mpexp -> 'a mpat * 'a exp option * (Ast.l * 'a) +val construct_mpexp : 'a mpat * 'a exp option * (Ast.l * 'a) -> 'a mpexp val is_valspec : id -> 'a def -> bool val is_fundef : id -> 'a def -> bool diff --git a/src/lib/bitfield.ml b/src/lib/bitfield.ml index b071f4af1..4caef0f68 100644 --- a/src/lib/bitfield.ml +++ b/src/lib/bitfield.ml @@ -90,69 +90,66 @@ let rec indices_of_range = function | BF_aux (BF_concat (l, r), _) -> indices_of_range l @ indices_of_range r let slice_width (i, j) = Big_int.succ (Big_int.abs (Big_int.sub i j)) -let range_width r = - List.map slice_width (indices_of_range r) - |> List.fold_left Big_int.add Big_int.zero +let range_width r = List.map slice_width (indices_of_range r) |> List.fold_left Big_int.add Big_int.zero (* Generate a constructor function for a bitfield type *) let constructor name order size = let typschm = fun_typschm [bitvec_typ size order] (mk_id_typ name) in let constructor_val = mk_val_spec (VS_val_spec (typschm, prepend_id "Mk_" name, None, false)) in let constructor_fun = Printf.sprintf "function Mk_%s v = struct { bits = v }" (string_of_id name) in - (constructor_val :: defs_of_string __POS__ constructor_fun) + constructor_val :: defs_of_string __POS__ constructor_fun (* Helper functions to generate different kinds of field accessor exps and lexps *) let get_field_exp range inner_exp = let mk_slice (i, j) = mk_exp (E_vector_subrange (inner_exp, mk_num_exp i, mk_num_exp j)) in let rec aux = function | [e] -> e - | (e :: es) -> mk_exp (E_vector_append (e, aux es)) + | e :: es -> mk_exp (E_vector_append (e, aux es)) | [] -> assert false (* unreachable *) in aux (List.map mk_slice (indices_of_range range)) let set_field_lexp range inner_lexp = let mk_slice (i, j) = mk_lexp (LE_vector_range (inner_lexp, mk_num_exp i, mk_num_exp j)) in - match List.map mk_slice (indices_of_range range) with - | [e] -> e - | es -> mk_lexp (LE_vector_concat es) + match List.map mk_slice (indices_of_range range) with [e] -> e | es -> mk_lexp (LE_vector_concat es) let set_bits_field_lexp inner_lexp = mk_lexp (LE_field (inner_lexp, mk_id "bits")) let get_bits_field exp = mk_exp (E_field (exp, mk_id "bits")) -let set_bits_field exp value = - mk_exp (E_struct_update (exp, [mk_fexp (mk_id "bits") value])) +let set_bits_field exp value = mk_exp (E_struct_update (exp, [mk_fexp (mk_id "bits") value])) let update_field_exp range order inner_exp new_value = - let single = (List.length (indices_of_range range) == 1) in + let single = List.length (indices_of_range range) == 1 in let rec aux e vi = function | (i, j) :: is -> - let w = slice_width (i, j) in - let vi' = if is_order_inc order then Big_int.add vi w else Big_int.sub vi w in - let rhs = - if single then new_value else begin - let vj = if is_order_inc order then Big_int.pred vi' else Big_int.succ vi' in - mk_exp (E_vector_subrange (new_value, mk_num_exp vi, mk_num_exp vj)) - end - in - let update = mk_exp (E_vector_update_subrange (e, mk_num_exp i, mk_num_exp j, rhs)) in - aux update vi' is + let w = slice_width (i, j) in + let vi' = if is_order_inc order then Big_int.add vi w else Big_int.sub vi w in + let rhs = + if single then new_value + else begin + let vj = if is_order_inc order then Big_int.pred vi' else Big_int.succ vi' in + mk_exp (E_vector_subrange (new_value, mk_num_exp vi, mk_num_exp vj)) + end + in + let update = mk_exp (E_vector_update_subrange (e, mk_num_exp i, mk_num_exp j, rhs)) in + aux update vi' is | [] -> e in let vi = if is_order_inc order then Big_int.zero else Big_int.pred (range_width range) in aux inner_exp vi (indices_of_range range) (* For every field, create getter and setter functions *) -type field_accessor_ids = { get : id; set : id; update: id; overload: id } +type field_accessor_ids = { get : id; set : id; update : id; overload : id } let field_accessor_ids type_name field = let type_name = string_of_id type_name in let field = string_of_id field in - { get = mk_id (Printf.sprintf "_get_%s_%s" type_name field); + { + get = mk_id (Printf.sprintf "_get_%s_%s" type_name field); set = mk_id (Printf.sprintf "_set_%s_%s" type_name field); update = mk_id (Printf.sprintf "_update_%s_%s" type_name field); - overload = mk_id (Printf.sprintf "_mod_%s" field) + overload = mk_id (Printf.sprintf "_mod_%s" field); } let field_getter typ_name field order range = @@ -176,7 +173,9 @@ let field_updater typ_name field order range = let new_bits = update_field_exp range order bits_exp (mk_id_exp new_val_var) in let body = set_bits_field (mk_id_exp orig_var) new_bits in let funcl = mk_funcl fun_id (mk_pat (P_tuple [mk_id_pat orig_var; mk_id_pat new_val_var])) body in - let overload = defs_of_string __POS__ (Printf.sprintf "overload update_%s = {%s}" (string_of_id field) (string_of_id fun_id)) in + let overload = + defs_of_string __POS__ (Printf.sprintf "overload update_%s = {%s}" (string_of_id field) (string_of_id fun_id)) + in [spec; mk_fundef [funcl]] @ overload let register_field_setter typ_name field order range = @@ -187,12 +186,14 @@ let register_field_setter typ_name field order range = let field_typ = string_of_typ (bitvec_typ size order) in let rfs_val = Printf.sprintf "val %s : (register(%s), %s) -> unit" fun_id typ_name field_typ in (* Read-modify-write using an internal _reg_deref function without rreg effect *) - let rfs_function = String.concat "\n" - [ Printf.sprintf "function %s (r_ref, v) = {" fun_id; - " r = __deref(r_ref);"; - Printf.sprintf " (*r_ref) = %s(r, v)" update_fun_id; - "}" - ] + let rfs_function = + String.concat "\n" + [ + Printf.sprintf "function %s (r_ref, v) = {" fun_id; + " r = __deref(r_ref);"; + Printf.sprintf " (*r_ref) = %s(r, v)" update_fun_id; + "}"; + ] in List.concat [defs_of_string __POS__ rfs_val; defs_of_string __POS__ rfs_function] @@ -203,16 +204,17 @@ let field_overload name field = defs_of_string __POS__ (Printf.sprintf "overload %s = {%s, %s}" fun_id get_id set_id) let field_accessors typ_name field order range = - List.concat [ - field_getter typ_name field order range; - field_updater typ_name field order range; - register_field_setter typ_name field order range; - field_overload typ_name field - ] + List.concat + [ + field_getter typ_name field order range; + field_updater typ_name field order range; + register_field_setter typ_name field order range; + field_overload typ_name field; + ] (* Generate all accessor functions for a given bitfield type *) let macro id size order ranges = let full_range = BF_aux (BF_range (nconstant (Big_int.pred size), nconstant Big_int.zero), Parse_ast.Unknown) in - let ranges = (mk_id "bits", full_range) :: (Bindings.bindings ranges) in + let ranges = (mk_id "bits", full_range) :: Bindings.bindings ranges in let accessors = List.map (fun (field, range) -> field_accessors id field order range) ranges in List.concat ([constructor id order size] @ accessors) diff --git a/src/lib/bitfield.mli b/src/lib/bitfield.mli index 8a3a76e68..873875aa9 100644 --- a/src/lib/bitfield.mli +++ b/src/lib/bitfield.mli @@ -74,12 +74,7 @@ val set_field_lexp : index_range -> uannot lexp -> uannot lexp (** Create an L-expression for setting all the bits of a bitfield *) val set_bits_field_lexp : uannot lexp -> uannot lexp -type field_accessor_ids = { - get : id; - set : id; - update : id; - overload : id; - } +type field_accessor_ids = { get : id; set : id; update : id; overload : id } (** The [macro] function generates multiple definitions to get, set, and update fields, so we can use this function to find the names of diff --git a/src/lib/callgraph.ml b/src/lib/callgraph.ml index 129d6616a..4ed7d5536 100644 --- a/src/lib/callgraph.ml +++ b/src/lib/callgraph.ml @@ -81,7 +81,7 @@ type node = | FunctionMeasure of id | LoopMeasures of id | Outcome of id - + let node_id = function | Register id -> id | Function id -> id @@ -113,7 +113,7 @@ module Node = struct lex_ord (compare (node_kind n1) (node_kind n2)) (Id.compare (node_id n1) (node_id n2)) end -module G = Graph.Make(Node) +module G = Graph.Make (Node) let builtins = let open Type_check in @@ -121,36 +121,33 @@ let builtins = let rec constraint_ids' (NC_aux (aux, _)) = match aux with - | NC_equal (n1, n2) | NC_bounded_le (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_lt (n1, n2) | NC_bounded_gt (n1, n2) | NC_not_equal (n1, n2) -> - IdSet.union (nexp_ids' n1) (nexp_ids' n2) - | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> - IdSet.union (constraint_ids' nc1) (constraint_ids' nc2) + | NC_equal (n1, n2) + | NC_bounded_le (n1, n2) + | NC_bounded_ge (n1, n2) + | NC_bounded_lt (n1, n2) + | NC_bounded_gt (n1, n2) + | NC_not_equal (n1, n2) -> + IdSet.union (nexp_ids' n1) (nexp_ids' n2) + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> IdSet.union (constraint_ids' nc1) (constraint_ids' nc2) | NC_var _ | NC_true | NC_false | NC_set _ -> IdSet.empty - | NC_app (id, args) -> - IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args)) + | NC_app (id, args) -> IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args)) and nexp_ids' (Nexp_aux (aux, _)) = match aux with | Nexp_id id -> IdSet.singleton id - | Nexp_app (id, nexps) -> - IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map nexp_ids' nexps)) + | Nexp_app (id, nexps) -> IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map nexp_ids' nexps)) | Nexp_var _ | Nexp_constant _ -> IdSet.empty | Nexp_exp n | Nexp_neg n -> nexp_ids' n - | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> - IdSet.union (nexp_ids' n1) (nexp_ids' n2) + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> IdSet.union (nexp_ids' n1) (nexp_ids' n2) and typ_ids' (Typ_aux (aux, _)) = match aux with | Typ_var _ | Typ_internal_unknown -> IdSet.empty | Typ_id id -> IdSet.singleton id - | Typ_app (id, args) -> - IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args)) - | Typ_fn (typs, typ) -> - IdSet.union (typ_ids' typ) (List.fold_left IdSet.union IdSet.empty (List.map typ_ids' typs)) - | Typ_bidir (typ1, typ2) -> - IdSet.union (typ_ids' typ1) (typ_ids' typ2) - | Typ_tuple typs -> - List.fold_left IdSet.union IdSet.empty (List.map typ_ids' typs) + | Typ_app (id, args) -> IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args)) + | Typ_fn (typs, typ) -> IdSet.union (typ_ids' typ) (List.fold_left IdSet.union IdSet.empty (List.map typ_ids' typs)) + | Typ_bidir (typ1, typ2) -> IdSet.union (typ_ids' typ1) (typ_ids' typ2) + | Typ_tuple typs -> List.fold_left IdSet.union IdSet.empty (List.map typ_ids' typs) | Typ_exist (_, _, typ) -> typ_ids' typ and typ_arg_ids' (A_aux (aux, _)) = @@ -161,6 +158,7 @@ and typ_arg_ids' (A_aux (aux, _)) = | A_order _ -> IdSet.empty let constraint_ids nc = IdSet.diff (constraint_ids' nc) builtins + and typ_ids typ = IdSet.diff (typ_ids' typ) builtins let typ_arg_ids nc = IdSet.diff (typ_arg_ids' nc) builtins @@ -171,12 +169,11 @@ let add_def_to_graph graph (DEF_aux (def, _)) = let graph = ref graph in let scan_pat self p_aux annot = - begin match p_aux with - | P_app (id, _) -> - graph := G.add_edge self (Constructor id) !graph - | P_typ (typ, _) -> - IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (typ_ids typ) - | _ -> () + begin + match p_aux with + | P_app (id, _) -> graph := G.add_edge self (Constructor id) !graph + | P_typ (typ, _) -> IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (typ_ids typ) + | _ -> () end; P_aux (p_aux, annot) in @@ -184,66 +181,59 @@ let add_def_to_graph graph (DEF_aux (def, _)) = let scan_lexp self lexp_aux annot = let env = env_of_annot annot in - begin match lexp_aux with - | LE_typ (typ, id) -> - IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (typ_ids typ); - begin match Env.lookup_id id env with - | Register _ -> - graph := G.add_edge self (Register id) !graph - | Enum _ -> graph := G.add_edge self (Constructor id) !graph - | _ -> - if IdSet.mem id (Env.get_toplevel_lets env) then - graph := G.add_edge self (Letbind id) !graph - else () - end - | LE_app (id, _) -> - graph := G.add_edge self (Function id) !graph - | LE_id id -> - begin match Env.lookup_id id env with - | Register _ -> - graph := G.add_edge self (Register id) !graph - | Enum _ -> graph := G.add_edge self (Constructor id) !graph - | _ -> - if IdSet.mem id (Env.get_toplevel_lets env) then - graph := G.add_edge self (Letbind id) !graph - else () - end - | _ -> () + begin + match lexp_aux with + | LE_typ (typ, id) -> + IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (typ_ids typ); + begin + match Env.lookup_id id env with + | Register _ -> graph := G.add_edge self (Register id) !graph + | Enum _ -> graph := G.add_edge self (Constructor id) !graph + | _ -> if IdSet.mem id (Env.get_toplevel_lets env) then graph := G.add_edge self (Letbind id) !graph else () + end + | LE_app (id, _) -> graph := G.add_edge self (Function id) !graph + | LE_id id -> begin + match Env.lookup_id id env with + | Register _ -> graph := G.add_edge self (Register id) !graph + | Enum _ -> graph := G.add_edge self (Constructor id) !graph + | _ -> if IdSet.mem id (Env.get_toplevel_lets env) then graph := G.add_edge self (Letbind id) !graph else () + end + | _ -> () end; LE_aux (lexp_aux, annot) in let scan_exp self e_aux annot = let env = env_of_annot annot in - begin match e_aux with - | E_id id -> - begin match Env.lookup_id id env with - | Register _ -> graph := G.add_edge self (Register id) !graph - | Enum _ -> graph := G.add_edge self (Constructor id) !graph - | _ -> - if IdSet.mem id (Env.get_toplevel_lets env) then - graph := G.add_edge self (Letbind id) !graph - else () - end - | E_app (id, _) -> - if Env.is_union_constructor id env then - graph := G.add_edge self (Constructor id) !graph - else - graph := G.add_edge self (Function id) !graph - | E_ref id -> - graph := G.add_edge self (Register id) !graph - | E_typ (typ, _) -> - IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (typ_ids typ) - | _ -> () + begin + match e_aux with + | E_id id -> begin + match Env.lookup_id id env with + | Register _ -> graph := G.add_edge self (Register id) !graph + | Enum _ -> graph := G.add_edge self (Constructor id) !graph + | _ -> if IdSet.mem id (Env.get_toplevel_lets env) then graph := G.add_edge self (Letbind id) !graph else () + end + | E_app (id, _) -> + if Env.is_union_constructor id env then graph := G.add_edge self (Constructor id) !graph + else graph := G.add_edge self (Function id) !graph + | E_ref id -> graph := G.add_edge self (Register id) !graph + | E_typ (typ, _) -> IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (typ_ids typ) + | _ -> () end; E_aux (e_aux, annot) in - let rw_exp self = { id_exp_alg with e_aux = (fun (e_aux, annot) -> scan_exp self e_aux annot); - le_aux = (fun (l_aux, annot) -> scan_lexp self l_aux annot); - pat_alg = rw_pat self } in + let rw_exp self = + { + id_exp_alg with + e_aux = (fun (e_aux, annot) -> scan_exp self e_aux annot); + le_aux = (fun (l_aux, annot) -> scan_lexp self l_aux annot); + pat_alg = rw_pat self; + } + in let rewriters self = - { rewriters_base with + { + rewriters_base with rewrite_exp = (fun _ -> fold_exp (rw_exp self)); rewrite_pat = (fun _ -> fold_pat (rw_pat self)); rewrite_let = (fun _ -> fold_letbind (rw_exp self)); @@ -253,14 +243,11 @@ let add_def_to_graph graph (DEF_aux (def, _)) = let scan_quant_item self (QI_aux (aux, _)) = match aux with | QI_id _ -> () - | QI_constraint nc -> - IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (constraint_ids nc) + | QI_constraint nc -> IdSet.iter (fun id -> graph := G.add_edge self (Type id) !graph) (constraint_ids nc) in let scan_typquant self (TypQ_aux (aux, _)) = - match aux with - | TypQ_no_forall -> () - | TypQ_tq quants -> List.iter (scan_quant_item self) quants + match aux with TypQ_no_forall -> () | TypQ_tq quants -> List.iter (scan_quant_item self) quants in let scan_loop_measure self (Loop (_, exp)) = ignore (fold_exp (rw_exp self) exp) in @@ -268,126 +255,128 @@ let add_def_to_graph graph (DEF_aux (def, _)) = let add_type_def_to_graph (TD_aux (aux, (l, _))) = match aux with | TD_abbrev (id, typq, arg) -> - graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_arg_ids arg))) !graph; - scan_typquant (Type id) typq + graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_arg_ids arg))) !graph; + scan_typquant (Type id) typq | TD_record (id, typq, fields, _) -> - let field_nodes = - List.map (fun (typ, _) -> typ_ids typ) fields - |> List.fold_left IdSet.union IdSet.empty - |> IdSet.elements - |> List.map (fun id -> Type id) - in - graph := G.add_edges (Type id) field_nodes !graph; - scan_typquant (Type id) typq + let field_nodes = + List.map (fun (typ, _) -> typ_ids typ) fields + |> List.fold_left IdSet.union IdSet.empty |> IdSet.elements + |> List.map (fun id -> Type id) + in + graph := G.add_edges (Type id) field_nodes !graph; + scan_typquant (Type id) typq | TD_variant (id, typq, ctors, _) -> - let ctor_nodes = - List.map (fun (Tu_aux (Tu_ty_id (typ, id), _)) -> (typ_ids typ, id)) ctors - |> List.fold_left (fun (ids, ctors) (ids', ctor) -> (IdSet.union ids ids', IdSet.add ctor ctors)) (IdSet.empty, IdSet.empty) - in - IdSet.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) (snd ctor_nodes); - IdSet.iter (fun typ_id -> graph := G.add_edge (Type id) (Type typ_id) !graph) (fst ctor_nodes); - scan_typquant (Type id) typq + let ctor_nodes = + List.map (fun (Tu_aux (Tu_ty_id (typ, id), _)) -> (typ_ids typ, id)) ctors + |> List.fold_left + (fun (ids, ctors) (ids', ctor) -> (IdSet.union ids ids', IdSet.add ctor ctors)) + (IdSet.empty, IdSet.empty) + in + IdSet.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) (snd ctor_nodes); + IdSet.iter (fun typ_id -> graph := G.add_edge (Type id) (Type typ_id) !graph) (fst ctor_nodes); + scan_typquant (Type id) typq | TD_enum (id, ctors, _) -> - List.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) ctors + List.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) ctors | TD_bitfield (id, typ, ranges) -> - graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_ids typ))) !graph + graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_ids typ))) !graph in let scan_outcome_def l outcome (DEF_aux (aux, _)) = match aux with | DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), _), _, _, _), _)) -> - graph := G.add_edges outcome [] !graph; - scan_typquant outcome typq; - IdSet.iter (fun typ_id -> graph := G.add_edge outcome (Type typ_id) !graph) (typ_ids typ) - | DEF_impl (FCL_aux (FCL_funcl (_, pexp), _)) -> - ignore (rewrite_pexp (rewriters outcome) pexp) - | _ -> - Reporting.unreachable l __POS__ "Unexpected definition in outcome block" + graph := G.add_edges outcome [] !graph; + scan_typquant outcome typq; + IdSet.iter (fun typ_id -> graph := G.add_edge outcome (Type typ_id) !graph) (typ_ids typ) + | DEF_impl (FCL_aux (FCL_funcl (_, pexp), _)) -> ignore (rewrite_pexp (rewriters outcome) pexp) + | _ -> Reporting.unreachable l __POS__ "Unexpected definition in outcome block" in let scan_fundef_tannot self (FD_aux (FD_function (_, Typ_annot_opt_aux (tannotopt, _), _), _)) = match tannotopt with | Typ_annot_opt_none -> () | Typ_annot_opt_some (typq, typ) -> - scan_typquant self typq; - IdSet.iter (fun typ_id -> graph := G.add_edge self (Type typ_id) !graph) (typ_ids typ) + scan_typquant self typq; + IdSet.iter (fun typ_id -> graph := G.add_edge self (Type typ_id) !graph) (typ_ids typ) in - begin match def with - | DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, (Typ_aux (Typ_bidir _, _) as typ)), _), id, _, _), _)) -> - graph := G.add_edges (Mapping id) [] !graph; - List.iter (fun gen_id -> - graph := G.add_edges (Function gen_id) [Mapping id] !graph - ) [append_id id "_forwards"; append_id id "_forwards_matches"; append_id id "_backwards"; append_id id "_backwards_matches"]; - scan_typquant (Mapping id) typq; - IdSet.iter (fun typ_id -> graph := G.add_edge (Mapping id) (Type typ_id) !graph) (typ_ids typ) - | DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), _), id, _, _), _)) -> - graph := G.add_edges (Function id) [] !graph; - scan_typquant (Function id) typq; - IdSet.iter (fun typ_id -> graph := G.add_edge (Function id) (Type typ_id) !graph) (typ_ids typ) - | DEF_fundef fdef -> - let id = id_of_fundef fdef in - graph := G.add_edges (Function id) [] !graph; - scan_fundef_tannot (Function id) fdef; - ignore (rewrite_fun (rewriters (Function id)) fdef) - | DEF_mapdef mdef -> - let id = id_of_mapdef mdef in - graph := G.add_edges (Mapping id) [] !graph; - ignore (rewrite_mapdef (rewriters (Mapping id)) mdef) - | DEF_let (LB_aux (LB_val (pat, exp), _) as lb) -> - let ids = pat_ids pat in - IdSet.iter (fun id -> graph := G.add_edges (Letbind id) [] !graph) ids; - IdSet.iter (fun id -> ignore (rewrite_let (rewriters (Letbind id)) lb)) ids - | DEF_type tdef -> - add_type_def_to_graph tdef - | DEF_register (DEC_aux (DEC_reg (typ, id, opt_exp), _)) -> - begin match opt_exp with - | Some exp -> ignore (fold_exp (rw_exp (Register id)) exp); - | None -> () - end; - IdSet.iter (fun typ_id -> graph := G.add_edge (Register id) (Type typ_id) !graph) (typ_ids typ) - | DEF_measure (id, pat, exp) -> - graph := G.add_edges (FunctionMeasure id) [Function id] !graph; - ignore (fold_pat (rw_pat (FunctionMeasure id)) pat); - ignore (fold_exp (rw_exp (FunctionMeasure id)) exp) - | DEF_loop_measures (id, measures) -> - graph := G.add_edges (LoopMeasures id) [Function id] !graph; - List.iter (scan_loop_measure (LoopMeasures id)) measures - | DEF_outcome (OV_aux (OV_outcome (id, TypSchm_aux (TypSchm_ts (typq, typ), _), _), l), outcome_defs) -> - graph := G.add_edges (Outcome id) [] !graph; - scan_typquant (Outcome id) typq; - IdSet.iter (fun typ_id -> graph := G.add_edge (Function id) (Type typ_id) !graph) (typ_ids typ); - List.iter (scan_outcome_def l (Outcome id)) outcome_defs - | DEF_instantiation (IN_aux (IN_id id, _), substs) -> - graph := G.add_edges (Function id) [Outcome id] !graph; - List.iter (function - | IS_aux (IS_id (_, id_to), _) -> - graph := G.add_edges (Function id) [Function id_to] !graph - | IS_aux (IS_typ (_, typ), _) -> - IdSet.iter (fun typ_id -> graph := G.add_edge (Function id) (Type typ_id) !graph) (typ_ids typ) - ) substs - | DEF_scattered (SD_aux (sdef, _)) -> - begin match sdef with - | SD_funcl (FCL_aux (FCL_funcl (id, pexp), _)) -> - ignore (rewrite_pexp (rewriters (Function id)) pexp) - | _ -> () - end - | _ -> () + begin + match def with + | DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, (Typ_aux (Typ_bidir _, _) as typ)), _), id, _, _), _)) + -> + graph := G.add_edges (Mapping id) [] !graph; + List.iter + (fun gen_id -> graph := G.add_edges (Function gen_id) [Mapping id] !graph) + [ + append_id id "_forwards"; + append_id id "_forwards_matches"; + append_id id "_backwards"; + append_id id "_backwards_matches"; + ]; + scan_typquant (Mapping id) typq; + IdSet.iter (fun typ_id -> graph := G.add_edge (Mapping id) (Type typ_id) !graph) (typ_ids typ) + | DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), _), id, _, _), _)) -> + graph := G.add_edges (Function id) [] !graph; + scan_typquant (Function id) typq; + IdSet.iter (fun typ_id -> graph := G.add_edge (Function id) (Type typ_id) !graph) (typ_ids typ) + | DEF_fundef fdef -> + let id = id_of_fundef fdef in + graph := G.add_edges (Function id) [] !graph; + scan_fundef_tannot (Function id) fdef; + ignore (rewrite_fun (rewriters (Function id)) fdef) + | DEF_mapdef mdef -> + let id = id_of_mapdef mdef in + graph := G.add_edges (Mapping id) [] !graph; + ignore (rewrite_mapdef (rewriters (Mapping id)) mdef) + | DEF_let (LB_aux (LB_val (pat, exp), _) as lb) -> + let ids = pat_ids pat in + IdSet.iter (fun id -> graph := G.add_edges (Letbind id) [] !graph) ids; + IdSet.iter (fun id -> ignore (rewrite_let (rewriters (Letbind id)) lb)) ids + | DEF_type tdef -> add_type_def_to_graph tdef + | DEF_register (DEC_aux (DEC_reg (typ, id, opt_exp), _)) -> + begin + match opt_exp with Some exp -> ignore (fold_exp (rw_exp (Register id)) exp) | None -> () + end; + IdSet.iter (fun typ_id -> graph := G.add_edge (Register id) (Type typ_id) !graph) (typ_ids typ) + | DEF_measure (id, pat, exp) -> + graph := G.add_edges (FunctionMeasure id) [Function id] !graph; + ignore (fold_pat (rw_pat (FunctionMeasure id)) pat); + ignore (fold_exp (rw_exp (FunctionMeasure id)) exp) + | DEF_loop_measures (id, measures) -> + graph := G.add_edges (LoopMeasures id) [Function id] !graph; + List.iter (scan_loop_measure (LoopMeasures id)) measures + | DEF_outcome (OV_aux (OV_outcome (id, TypSchm_aux (TypSchm_ts (typq, typ), _), _), l), outcome_defs) -> + graph := G.add_edges (Outcome id) [] !graph; + scan_typquant (Outcome id) typq; + IdSet.iter (fun typ_id -> graph := G.add_edge (Function id) (Type typ_id) !graph) (typ_ids typ); + List.iter (scan_outcome_def l (Outcome id)) outcome_defs + | DEF_instantiation (IN_aux (IN_id id, _), substs) -> + graph := G.add_edges (Function id) [Outcome id] !graph; + List.iter + (function + | IS_aux (IS_id (_, id_to), _) -> graph := G.add_edges (Function id) [Function id_to] !graph + | IS_aux (IS_typ (_, typ), _) -> + IdSet.iter (fun typ_id -> graph := G.add_edge (Function id) (Type typ_id) !graph) (typ_ids typ) + ) + substs + | DEF_scattered (SD_aux (sdef, _)) -> begin + match sdef with + | SD_funcl (FCL_aux (FCL_funcl (id, pexp), _)) -> ignore (rewrite_pexp (rewriters (Function id)) pexp) + | _ -> () + end + | _ -> () end; !graph let rec graph_of_defs defs = - let module G = Graph.Make(Node) in - + let module G = Graph.Make (Node) in match defs with | def :: defs -> - let g = graph_of_defs defs in - add_def_to_graph g def - + let g = graph_of_defs defs in + add_def_to_graph g def | [] -> G.empty let graph_of_ast ast = graph_of_defs ast.defs - + let id_of_typedef (TD_aux (aux, _)) = match aux with | TD_abbrev (id, _, _) -> id @@ -402,66 +391,49 @@ let id_of_funcl (FCL_aux (FCL_funcl (id, _), _)) = id let filter_ast_extra cuts g ast keep_std = let rec filter_ast' g = - let module NS = Set.Make(Node) in - let module NM = Map.Make(Node) in + let module NS = Set.Make (Node) in + let module NM = Map.Make (Node) in function - | DEF_aux (DEF_fundef fdef, _) :: defs when NS.mem (Function (id_of_fundef fdef)) cuts -> - filter_ast' g defs + | DEF_aux (DEF_fundef fdef, _) :: defs when NS.mem (Function (id_of_fundef fdef)) cuts -> filter_ast' g defs | DEF_aux (DEF_fundef fdef, def_annot) :: defs when NM.mem (Function (id_of_fundef fdef)) g -> - DEF_aux (DEF_fundef fdef, def_annot) :: filter_ast' g defs - | DEF_aux (DEF_fundef _, _) :: defs -> - filter_ast' g defs - + DEF_aux (DEF_fundef fdef, def_annot) :: filter_ast' g defs + | DEF_aux (DEF_fundef _, _) :: defs -> filter_ast' g defs | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, _)), _) :: defs when NS.mem (Function (id_of_funcl funcl)) cuts -> - filter_ast' g defs - | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, a)), def_annot) :: defs when NM.mem (Function (id_of_funcl funcl)) g -> - DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, a)), def_annot) :: filter_ast' g defs - | DEF_aux (DEF_scattered (SD_aux (SD_funcl _, _)), _) :: defs -> - filter_ast' g defs - + filter_ast' g defs + | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, a)), def_annot) :: defs + when NM.mem (Function (id_of_funcl funcl)) g -> + DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, a)), def_annot) :: filter_ast' g defs + | DEF_aux (DEF_scattered (SD_aux (SD_funcl _, _)), _) :: defs -> filter_ast' g defs | DEF_aux (DEF_register rdec, def_annot) :: defs when NM.mem (Register (id_of_reg_dec rdec)) g -> - DEF_aux (DEF_register rdec, def_annot) :: filter_ast' g defs - | DEF_aux (DEF_register _, _) :: defs -> - filter_ast' g defs - + DEF_aux (DEF_register rdec, def_annot) :: filter_ast' g defs + | DEF_aux (DEF_register _, _) :: defs -> filter_ast' g defs | DEF_aux (DEF_val vs, def_annot) :: defs when NM.mem (Function (id_of_val_spec vs)) g -> - DEF_aux (DEF_val vs, def_annot) :: filter_ast' g defs - | DEF_aux (DEF_val _, _) :: defs -> - filter_ast' g defs - + DEF_aux (DEF_val vs, def_annot) :: filter_ast' g defs + | DEF_aux (DEF_val _, _) :: defs -> filter_ast' g defs | DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), _) as lb), def_annot) :: defs -> - let ids = pat_ids pat |> IdSet.elements in - if List.exists (fun id -> NM.mem (Letbind id) g) ids then - DEF_aux (DEF_let lb, def_annot) :: filter_ast' g defs - else - filter_ast' g defs - + let ids = pat_ids pat |> IdSet.elements in + if List.exists (fun id -> NM.mem (Letbind id) g) ids then DEF_aux (DEF_let lb, def_annot) :: filter_ast' g defs + else filter_ast' g defs | DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_typedef tdef)) g -> - DEF_aux (DEF_type tdef, def_annot) :: filter_ast' g defs - | DEF_aux (DEF_type _, _) :: defs -> - filter_ast' g defs - - | DEF_aux (DEF_measure (id, _, _), _) :: defs when NS.mem (Function id) cuts -> - filter_ast' g defs - | (DEF_aux (DEF_measure (id, _, _), _) as def) :: defs when NM.mem (Function id) g -> - def :: filter_ast' g defs - | DEF_aux (DEF_measure _, _) :: defs -> - filter_ast' g defs - + DEF_aux (DEF_type tdef, def_annot) :: filter_ast' g defs + | DEF_aux (DEF_type _, _) :: defs -> filter_ast' g defs + | DEF_aux (DEF_measure (id, _, _), _) :: defs when NS.mem (Function id) cuts -> filter_ast' g defs + | (DEF_aux (DEF_measure (id, _, _), _) as def) :: defs when NM.mem (Function id) g -> def :: filter_ast' g defs + | DEF_aux (DEF_measure _, _) :: defs -> filter_ast' g defs | (DEF_aux (DEF_pragma ("include_start", file_name, _), _) as def) :: defs when keep_std -> - (* TODO: proper check *) - let d = Filename.dirname file_name in - if Filename.basename d = "lib" && Filename.basename (Filename.dirname d) = "sail" then - let rec in_file = function - | [] -> [] - | DEF_aux (DEF_pragma ("include_end", file_name', _), _) as def :: defs when file_name = file_name' -> - def :: filter_ast' g defs - | def :: defs -> def :: in_file defs - in def :: in_file defs - else def :: filter_ast' g defs - + (* TODO: proper check *) + let d = Filename.dirname file_name in + if Filename.basename d = "lib" && Filename.basename (Filename.dirname d) = "sail" then ( + let rec in_file = function + | [] -> [] + | (DEF_aux (DEF_pragma ("include_end", file_name', _), _) as def) :: defs when file_name = file_name' -> + def :: filter_ast' g defs + | def :: defs -> def :: in_file defs + in + def :: in_file defs + ) + else def :: filter_ast' g defs | def :: defs -> def :: filter_ast' g defs - | [] -> [] in { ast with defs = filter_ast' g ast.defs } @@ -469,8 +441,8 @@ let filter_ast_extra cuts g ast keep_std = let filter_ast cuts g ast = filter_ast_extra cuts g ast false let filter_ast_ids roots cuts ast = - let module NodeSet = Set.Make(Node) in - let module G = Graph.Make(Node) in + let module NodeSet = Set.Make (Node) in + let module G = Graph.Make (Node) in let g = graph_of_ast ast in let roots = roots |> IdSet.elements |> List.map (fun id -> Function id) |> NodeSet.of_list in let cuts = cuts |> IdSet.elements |> List.map (fun id -> Function id) |> NodeSet.of_list in diff --git a/src/lib/callgraph.mli b/src/lib/callgraph.mli index e7e2ef252..a9e73373d 100644 --- a/src/lib/callgraph.mli +++ b/src/lib/callgraph.mli @@ -84,22 +84,20 @@ type node = | Outcome of id val node_id : node -> id - + module Node : sig type t = node val compare : node -> node -> int end module G : sig - include Graph.S with type node = Node.t - and type node_set = Set.Make(Node).t - and type graph = Graph.Make(Node).graph + include Graph.S with type node = Node.t and type node_set = Set.Make(Node).t and type graph = Graph.Make(Node).graph end - + type callgraph = G.graph val graph_of_ast : Type_check.tannot ast -> callgraph - + val filter_ast_ids : IdSet.t -> IdSet.t -> Type_check.tannot ast -> Type_check.tannot ast val filter_ast : Set.Make(Node).t -> callgraph -> 'a ast -> 'a ast diff --git a/src/lib/chunk_ast.ml b/src/lib/chunk_ast.ml index 1a5a592ef..46c53e196 100644 --- a/src/lib/chunk_ast.ml +++ b/src/lib/chunk_ast.ml @@ -67,47 +67,30 @@ open Parse_ast -let string_of_id_aux = function - | Id v -> v - | Operator v -> v +let string_of_id_aux = function Id v -> v | Operator v -> v let string_of_id (Id_aux (id, l)) = string_of_id_aux id let id_loc (Id_aux (_, l)) = l -let starting_line_num l = match Reporting.simp_loc l with - | Some (s, _) -> Some s.pos_lnum - | None -> None +let starting_line_num l = match Reporting.simp_loc l with Some (s, _) -> Some s.pos_lnum | None -> None -let starting_column_num l = match Reporting.simp_loc l with - | Some (s, _) -> Some (s.pos_cnum - s.pos_bol) - | None -> None +let starting_column_num l = + match Reporting.simp_loc l with Some (s, _) -> Some (s.pos_cnum - s.pos_bol) | None -> None -let ending_line_num l = match Reporting.simp_loc l with - | Some (_, e) -> Some e.pos_lnum - | None -> None +let ending_line_num l = match Reporting.simp_loc l with Some (_, e) -> Some e.pos_lnum | None -> None type binder = Var_binder | Let_binder | Internal_plet_binder -type if_format = { - then_brace : bool; - else_brace : bool - } +type if_format = { then_brace : bool; else_brace : bool } type match_kind = Try_match | Match_match -let match_keywords = function - | Try_match -> "try", Some "catch" - | Match_match -> "match", None +let match_keywords = function Try_match -> ("try", Some "catch") | Match_match -> ("match", None) -let binder_keyword = function - | Var_binder -> "var" - | Let_binder -> "let" - | Internal_plet_binder -> "internal_plet" +let binder_keyword = function Var_binder -> "var" | Let_binder -> "let" | Internal_plet_binder -> "internal_plet" -let comment_type_delimiters = function - | Lexer.Comment_line -> "//", "" - | Lexer.Comment_block -> "/*", "*/" +let comment_type_delimiters = function Lexer.Comment_line -> ("//", "") | Lexer.Comment_block -> ("/*", "*/") type chunk = | Comment of Lexer.comment_type * int * int * string @@ -118,34 +101,13 @@ type chunk = rec_opt : chunks option; typq_opt : chunks option; return_typ_opt : chunks option; - funcls : pexp_chunks list - } - | Val of { - is_cast : bool; - id : id; - extern_opt : extern option; - typq_opt : chunks option; - typ : chunks; - } - | Enum of { - id : id; - enum_functions : chunks list option; - members : chunks list - } - | Function_typ of { - mapping : bool; - lhs : chunks; - rhs : chunks; - } - | Exists of { - vars : chunks; - constr : chunks; - typ : chunks; - } - | Typ_quant of { - vars : chunks; - constr_opt : chunks option; + funcls : pexp_chunks list; } + | Val of { is_cast : bool; id : id; extern_opt : extern option; typq_opt : chunks option; typ : chunks } + | Enum of { id : id; enum_functions : chunks list option; members : chunks list } + | Function_typ of { mapping : bool; lhs : chunks; rhs : chunks } + | Exists of { vars : chunks; constr : chunks; typ : chunks } + | Typ_quant of { vars : chunks; constr_opt : chunks option } | App of id * chunks list | Field of chunks * id | Tuple of string * string * int * chunks list @@ -165,324 +127,334 @@ type chunk = | If_then of bool * chunks * chunks | If_then_else of if_format * chunks * chunks * chunks | Struct_update of chunks * chunks list - | Match of { - kind : match_kind; - exp : chunks; - aligned : bool; - cases : pexp_chunks list - } + | Match of { kind : match_kind; exp : chunks; aligned : bool; cases : pexp_chunks list } | Foreach of { var : chunks; decreasing : bool; from_index : chunks; to_index : chunks; step : chunks option; - body : chunks - } - | While of { - repeat_until : bool; - termination_measure : chunks option; - cond : chunks; - body : chunks + body : chunks; } + | While of { repeat_until : bool; termination_measure : chunks option; cond : chunks; body : chunks } | Vector_updates of chunks * chunk list | Chunks of chunks | Raw of string and chunks = chunk Queue.t -and pexp_chunks = { - funcl_space : bool; - pat : chunks; - guard : chunks option; - body : chunks - } +and pexp_chunks = { funcl_space : bool; pat : chunks; guard : chunks option; body : chunks } -let add_chunk q chunk = - Queue.add chunk q +let add_chunk q chunk = Queue.add chunk q [@@@coverage off] let rec prerr_chunk indent = function | Comment (comment_type, n, col, contents) -> - let s, e = comment_type_delimiters comment_type in - Printf.eprintf "%sComment: blank=%d col=%d %s%s%s\n" indent n col s contents e - | Spacer (line, w) -> - Printf.eprintf "%sSpacer:%b %d\n" indent line w; - | Atom str -> - Printf.eprintf "%sAtom:%s\n" indent str - | String_literal str -> - Printf.eprintf "%sString_literal:%s\n" indent str + let s, e = comment_type_delimiters comment_type in + Printf.eprintf "%sComment: blank=%d col=%d %s%s%s\n" indent n col s contents e + | Spacer (line, w) -> Printf.eprintf "%sSpacer:%b %d\n" indent line w + | Atom str -> Printf.eprintf "%sAtom:%s\n" indent str + | String_literal str -> Printf.eprintf "%sString_literal:%s\n" indent str | App (id, args) -> - Printf.eprintf "%sApp:%s\n" indent (string_of_id id); - List.iteri (fun i arg -> - Printf.eprintf "%s %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) args + Printf.eprintf "%sApp:%s\n" indent (string_of_id id); + List.iteri + (fun i arg -> + Printf.eprintf "%s %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + args | Tuple (s, e, n, args) -> - Printf.eprintf "%sTuple:%s %s %d\n" indent s e n; - List.iteri (fun i arg -> - Printf.eprintf "%s %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) args + Printf.eprintf "%sTuple:%s %s %d\n" indent s e n; + List.iteri + (fun i arg -> + Printf.eprintf "%s %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + args | Intersperse (str, args) -> - Printf.eprintf "%sIntersperse:%s\n" indent str; - List.iteri (fun i arg -> - Printf.eprintf "%s %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) args + Printf.eprintf "%sIntersperse:%s\n" indent str; + List.iteri + (fun i arg -> + Printf.eprintf "%s %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + args | Block (always_hardline, args) -> - Printf.eprintf "%sBlock: always_hardline=%b\n" indent always_hardline; - List.iteri (fun i arg -> - Printf.eprintf "%s %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) args + Printf.eprintf "%sBlock: always_hardline=%b\n" indent always_hardline; + List.iteri + (fun i arg -> + Printf.eprintf "%s %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + args | Function fn -> - Printf.eprintf "%sFunction:%s clause=%b\n" indent (string_of_id fn.id) fn.clause; - begin match fn.typq_opt with - | Some typq -> - Printf.eprintf "%s typq:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) typq - | None -> () - end; - begin match fn.return_typ_opt with - | Some return_typ -> - Printf.eprintf "%s return_typ:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) return_typ - | None -> () - end; - List.iteri (fun i funcl -> - Printf.eprintf "%s pat %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) funcl.pat; - begin match funcl.guard with - | Some guard -> - Printf.eprintf "%s guard %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) guard; - | None -> () - end; - Printf.eprintf "%s body %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) funcl.body; - ) fn.funcls + Printf.eprintf "%sFunction:%s clause=%b\n" indent (string_of_id fn.id) fn.clause; + begin + match fn.typq_opt with + | Some typq -> + Printf.eprintf "%s typq:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) typq + | None -> () + end; + begin + match fn.return_typ_opt with + | Some return_typ -> + Printf.eprintf "%s return_typ:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) return_typ + | None -> () + end; + List.iteri + (fun i funcl -> + Printf.eprintf "%s pat %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) funcl.pat; + begin + match funcl.guard with + | Some guard -> + Printf.eprintf "%s guard %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) guard + | None -> () + end; + Printf.eprintf "%s body %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) funcl.body + ) + fn.funcls | Val vs -> - Printf.eprintf "%sVal:%s is_cast=%b has_extern=%b\n" - indent (string_of_id vs.id) vs.is_cast (Option.is_some vs.extern_opt) + Printf.eprintf "%sVal:%s is_cast=%b has_extern=%b\n" indent (string_of_id vs.id) vs.is_cast + (Option.is_some vs.extern_opt) | Enum e -> - Printf.eprintf "%sEnum:%s\n" indent (string_of_id e.id); - begin match e.enum_functions with - | Some enum_functions -> - List.iter (fun chunks -> - Printf.eprintf "%s enum_function:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) chunks - ) enum_functions - | None -> () - end; - List.iter (fun chunks -> - Printf.eprintf "%s member:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) chunks - ) e.members + Printf.eprintf "%sEnum:%s\n" indent (string_of_id e.id); + begin + match e.enum_functions with + | Some enum_functions -> + List.iter + (fun chunks -> + Printf.eprintf "%s enum_function:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) chunks + ) + enum_functions + | None -> () + end; + List.iter + (fun chunks -> + Printf.eprintf "%s member:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) chunks + ) + e.members | Match m -> - Printf.eprintf "%sMatch:%s %b\n" indent (fst (match_keywords m.kind)) m.aligned; - Printf.eprintf "%s exp:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) m.exp; - List.iteri (fun i funcl -> - Printf.eprintf "%s pat %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) funcl.pat; - begin match funcl.guard with - | Some guard -> - Printf.eprintf "%s guard %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) guard; - | None -> () - end; - Printf.eprintf "%s body %d:\n" indent i; - Queue.iter (prerr_chunk (indent ^ " ")) funcl.body; - ) m.cases + Printf.eprintf "%sMatch:%s %b\n" indent (fst (match_keywords m.kind)) m.aligned; + Printf.eprintf "%s exp:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) m.exp; + List.iteri + (fun i funcl -> + Printf.eprintf "%s pat %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) funcl.pat; + begin + match funcl.guard with + | Some guard -> + Printf.eprintf "%s guard %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) guard + | None -> () + end; + Printf.eprintf "%s body %d:\n" indent i; + Queue.iter (prerr_chunk (indent ^ " ")) funcl.body + ) + m.cases | Function_typ fn_typ -> - Printf.eprintf "%sFunction_typ: is_mapping=%b\n" indent fn_typ.mapping; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("lhs", fn_typ.lhs); ("rhs", fn_typ.rhs)] + Printf.eprintf "%sFunction_typ: is_mapping=%b\n" indent fn_typ.mapping; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("lhs", fn_typ.lhs); ("rhs", fn_typ.rhs)] | Foreach loop -> - Printf.eprintf "%sForeach: downto=%b\n" indent loop.decreasing; - begin match loop.step with - | Some step -> - Printf.eprintf "%s step:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) step - | None -> () - end; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("var", loop.var); ("from", loop.from_index); ("to", loop.to_index); ("body", loop.body)] + Printf.eprintf "%sForeach: downto=%b\n" indent loop.decreasing; + begin + match loop.step with + | Some step -> + Printf.eprintf "%s step:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) step + | None -> () + end; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("var", loop.var); ("from", loop.from_index); ("to", loop.to_index); ("body", loop.body)] | While loop -> - Printf.eprintf "%sWhile: repeat_until=%b\n" indent loop.repeat_until; - begin match loop.termination_measure with - | Some measure -> - Printf.eprintf "%s step:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) measure - | None -> () - end; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("cond", loop.cond); ("body", loop.body)] + Printf.eprintf "%sWhile: repeat_until=%b\n" indent loop.repeat_until; + begin + match loop.termination_measure with + | Some measure -> + Printf.eprintf "%s step:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) measure + | None -> () + end; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("cond", loop.cond); ("body", loop.body)] | Typ_quant typq -> - Printf.eprintf "%sTyp_quant:\n" indent; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) (match typq.constr_opt with - | Some constr -> [("vars", typq.vars); ("constr", constr)] - | None -> [("vars", typq.vars)]) - | Pragma (pragma, arg) -> - Printf.eprintf "%sPragma:$%s %s\n" indent pragma arg + Printf.eprintf "%sTyp_quant:\n" indent; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + ( match typq.constr_opt with + | Some constr -> [("vars", typq.vars); ("constr", constr)] + | None -> [("vars", typq.vars)] + ) + | Pragma (pragma, arg) -> Printf.eprintf "%sPragma:$%s %s\n" indent pragma arg | Unary (op, arg) -> - Printf.eprintf "%sUnary:%s\n" indent op; - Queue.iter (prerr_chunk (indent ^ " ")) arg + Printf.eprintf "%sUnary:%s\n" indent op; + Queue.iter (prerr_chunk (indent ^ " ")) arg | Binary (lhs, op, rhs) -> - Printf.eprintf "%sBinary:%s\n" indent op; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("lhs", lhs); ("rhs", rhs)] + Printf.eprintf "%sBinary:%s\n" indent op; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("lhs", lhs); ("rhs", rhs)] | Ternary (x, op1, y, op2, z) -> - Printf.eprintf "%sTernary:%s %s\n" indent op1 op2; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("x", x); ("y", y); ("z", z)] - | Delim str -> - Printf.eprintf "%sDelim:%s\n" indent str - | Opt_delim str -> - Printf.eprintf "%sOpt_delim:%s\n" indent str + Printf.eprintf "%sTernary:%s %s\n" indent op1 op2; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("x", x); ("y", y); ("z", z)] + | Delim str -> Printf.eprintf "%sDelim:%s\n" indent str + | Opt_delim str -> Printf.eprintf "%sOpt_delim:%s\n" indent str | Exists ex -> - Printf.eprintf "%sExists:\n" indent; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("vars", ex.vars); ("constr", ex.constr); ("typ", ex.typ)] + Printf.eprintf "%sExists:\n" indent; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("vars", ex.vars); ("constr", ex.constr); ("typ", ex.typ)] | Binder _ -> () | Block_binder (binder, binding, exp) -> - Printf.eprintf "%sBlock_binder:%s\n" indent (binder_keyword binder); - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("binding", binding); ("exp", exp)] + Printf.eprintf "%sBlock_binder:%s\n" indent (binder_keyword binder); + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("binding", binding); ("exp", exp)] | If_then (_, i, t) -> - Printf.eprintf "%sIf_then:\n" indent; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("if", i); ("then", t)] + Printf.eprintf "%sIf_then:\n" indent; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("if", i); ("then", t)] | If_then_else (_, i, t, e) -> - Printf.eprintf "%sIf_then_else:\n" indent; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("if", i); ("then", t); ("else", e)] + Printf.eprintf "%sIf_then_else:\n" indent; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("if", i); ("then", t); ("else", e)] | Field (exp, id) -> - Printf.eprintf "%sField:%s\n" indent (string_of_id id); - Queue.iter (prerr_chunk (indent ^ " ")) exp + Printf.eprintf "%sField:%s\n" indent (string_of_id id); + Queue.iter (prerr_chunk (indent ^ " ")) exp | Struct_update (exp, exps) -> - Printf.eprintf "%sStruct_update:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) exp; - Printf.eprintf "%s with:" indent; - List.iter (fun exp -> - Queue.iter (prerr_chunk (indent ^ " ")) exp - ) exps - | Vector_updates (exp, updates) -> - Printf.eprintf "%sVector_updates:\n" indent + Printf.eprintf "%sStruct_update:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) exp; + Printf.eprintf "%s with:" indent; + List.iter (fun exp -> Queue.iter (prerr_chunk (indent ^ " ")) exp) exps + | Vector_updates (exp, updates) -> Printf.eprintf "%sVector_updates:\n" indent | Index (exp, ix) -> - Printf.eprintf "%sIndex:\n" indent; - List.iter (fun (name, arg) -> - Printf.eprintf "%s %s:\n" indent name; - Queue.iter (prerr_chunk (indent ^ " ")) arg - ) [("exp", exp); ("ix", ix)] + Printf.eprintf "%sIndex:\n" indent; + List.iter + (fun (name, arg) -> + Printf.eprintf "%s %s:\n" indent name; + Queue.iter (prerr_chunk (indent ^ " ")) arg + ) + [("exp", exp); ("ix", ix)] | Chunks chunks -> - Printf.eprintf "%sChunks:\n" indent; - Queue.iter (prerr_chunk (indent ^ " ")) chunks - | Raw _ -> - Printf.eprintf "%sRaw\n" indent + Printf.eprintf "%sChunks:\n" indent; + Queue.iter (prerr_chunk (indent ^ " ")) chunks + | Raw _ -> Printf.eprintf "%sRaw\n" indent [@@@coverage on] - + let string_of_var (Kid_aux (Var v, _)) = v let rec pop_header_comments comments chunks l lnum = match Stack.top_opt comments with | None -> () - | Some (Lexer.Comment (comment_type, comment_s, e, contents)) -> - begin match Reporting.simp_loc l with - | Some (s, _) when e.pos_cnum < s.pos_cnum && comment_s.pos_lnum = lnum -> - let _ = Stack.pop comments in - Queue.add (Comment (comment_type, 0, comment_s.pos_cnum - comment_s.pos_bol, contents)) chunks; - Queue.add (Spacer (true, 1)) chunks; - pop_header_comments comments chunks l (lnum + 1) - | _ -> () - end + | Some (Lexer.Comment (comment_type, comment_s, e, contents)) -> begin + match Reporting.simp_loc l with + | Some (s, _) when e.pos_cnum < s.pos_cnum && comment_s.pos_lnum = lnum -> + let _ = Stack.pop comments in + Queue.add (Comment (comment_type, 0, comment_s.pos_cnum - comment_s.pos_bol, contents)) chunks; + Queue.add (Spacer (true, 1)) chunks; + pop_header_comments comments chunks l (lnum + 1) + | _ -> () + end let chunk_header_comments comments chunks = function | [] -> () - | (DEF_aux (_, l)) :: _ -> - pop_header_comments comments chunks l 1 + | DEF_aux (_, l) :: _ -> pop_header_comments comments chunks l 1 (* Pop comments preceeding location into the chunkstream *) let rec pop_comments comments chunks l = match Stack.top_opt comments with | None -> () - | Some (Lexer.Comment (comment_type, comment_s, e, contents)) -> - begin match Reporting.simp_loc l with - | Some (s, _) when e.pos_cnum <= s.pos_cnum -> - let _ = Stack.pop comments in - Queue.add (Comment (comment_type, 0, comment_s.pos_cnum - comment_s.pos_bol, contents)) chunks; - if e.pos_lnum < s.pos_lnum then ( - Queue.add (Spacer (true, 1)) chunks - ); - pop_comments comments chunks l - | _ -> () - end + | Some (Lexer.Comment (comment_type, comment_s, e, contents)) -> begin + match Reporting.simp_loc l with + | Some (s, _) when e.pos_cnum <= s.pos_cnum -> + let _ = Stack.pop comments in + Queue.add (Comment (comment_type, 0, comment_s.pos_cnum - comment_s.pos_bol, contents)) chunks; + if e.pos_lnum < s.pos_lnum then Queue.add (Spacer (true, 1)) chunks; + pop_comments comments chunks l + | _ -> () + end let rec discard_comments comments (pos : Lexing.position) = match Stack.top_opt comments with | None -> () | Some (Lexer.Comment (_, _, e, _)) -> - if e.pos_cnum <= pos.pos_cnum then ( + if e.pos_cnum <= pos.pos_cnum then ( let _ = Stack.pop comments in discard_comments comments pos - ) + ) let pop_trailing_comment ?space:(n = 0) comments chunks line_num = match line_num with | None -> false - | Some lnum -> - begin match Stack.top_opt comments with - | Some (Lexer.Comment (comment_type, s, _, contents)) when s.pos_lnum = lnum -> - let _ = Stack.pop comments in - Queue.add (Comment (comment_type, n, s.pos_cnum - s.pos_bol, contents)) chunks; - begin match comment_type with - | Lexer.Comment_line -> true - | _ -> false - end - | _ -> false - end + | Some lnum -> begin + match Stack.top_opt comments with + | Some (Lexer.Comment (comment_type, s, _, contents)) when s.pos_lnum = lnum -> + let _ = Stack.pop comments in + Queue.add (Comment (comment_type, n, s.pos_cnum - s.pos_bol, contents)) chunks; + begin + match comment_type with Lexer.Comment_line -> true | _ -> false + end + | _ -> false + end let string_of_kind (K_aux (k, _)) = - match k with - | K_type -> "Type" - | K_int -> "Int" - | K_order -> "Order" - | K_bool -> "Bool" + match k with K_type -> "Type" | K_int -> "Int" | K_order -> "Order" | K_bool -> "Bool" (* Right now, let's just assume we never break up kinded-identifiers *) let chunk_of_kopt (KOpt_aux (KOpt_kind (special, vars, kind), l)) = - match special, kind with + match (special, kind) with | Some c, Some k -> - Atom (Printf.sprintf "(%s %s : %s)" c (Util.string_of_list " " string_of_var vars) (string_of_kind k)) - | None, Some k -> - Atom (Printf.sprintf "(%s : %s)" (Util.string_of_list " " string_of_var vars) (string_of_kind k)) - | None, None -> - Atom (Util.string_of_list " " string_of_var vars) + Atom (Printf.sprintf "(%s %s : %s)" c (Util.string_of_list " " string_of_var vars) (string_of_kind k)) + | None, Some k -> Atom (Printf.sprintf "(%s : %s)" (Util.string_of_list " " string_of_var vars) (string_of_kind k)) + | None, None -> Atom (Util.string_of_list " " string_of_var vars) | _, _ -> - (* No other KOpt should be parseable *) - Reporting.unreachable l __POS__ "Invalid KOpt in formatter" [@coverage off] + (* No other KOpt should be parseable *) + Reporting.unreachable l __POS__ "Invalid KOpt in formatter" [@coverage off] let chunk_of_lit (L_aux (aux, _)) = match aux with @@ -500,43 +472,36 @@ let chunk_of_lit (L_aux (aux, _)) = let rec map_peek f = function | x1 :: x2 :: xs -> - let x1 = f (Some x2) x1 in - x1 :: map_peek f (x2 ::xs) + let x1 = f (Some x2) x1 in + x1 :: map_peek f (x2 :: xs) | [x] -> [f None x] | [] -> [] let rec map_peek_acc f acc = function | x1 :: x2 :: xs -> - let x1, acc = f acc (Some x2) x1 in - x1 :: map_peek_acc f acc (x2 ::xs) + let x1, acc = f acc (Some x2) x1 in + x1 :: map_peek_acc f acc (x2 :: xs) | [x] -> [fst (f acc None x)] | [] -> [] -let have_linebreak line_num1 line_num2 = - match line_num1, line_num2 with - | Some p1, Some p2 -> p1 < p2 - | _, _ -> false +let have_linebreak line_num1 line_num2 = match (line_num1, line_num2) with Some p1, Some p2 -> p1 < p2 | _, _ -> false let have_blank_linebreak line_num1 line_num2 = - match line_num1, line_num2 with - | Some p1, Some p2 -> p1 + 1 < p2 - | _, _ -> false + match (line_num1, line_num2) with Some p1, Some p2 -> p1 + 1 < p2 | _, _ -> false let chunk_delimit ?delim ~get_loc ~chunk comments chunks xs = - map_peek (fun next x -> + map_peek + (fun next x -> let l = get_loc x in let chunks = Queue.create () in chunk comments chunks x; (* Add a delimiter, which is optional for the last element *) - begin match delim with - | Some delim -> - if Option.is_some next then ( - Queue.add (Delim delim) chunks - ) else ( - Queue.add (Opt_delim delim) chunks - ) - | None -> () + begin + match delim with + | Some delim -> + if Option.is_some next then Queue.add (Delim delim) chunks else Queue.add (Opt_delim delim) chunks + | None -> () end; (* If the next delimited expression is on a new line, @@ -551,12 +516,12 @@ let chunk_delimit ?delim ~get_loc ~chunk comments chunks xs = the line comment will be attached to arg2, and the block comment to arg1 *) let next_line_num = Option.bind next (fun x2 -> starting_line_num (get_loc x2)) in - if have_linebreak (ending_line_num l) next_line_num then ( - ignore (pop_trailing_comment comments chunks (ending_line_num l)) - ); + if have_linebreak (ending_line_num l) next_line_num then + ignore (pop_trailing_comment comments chunks (ending_line_num l)); chunks - ) xs + ) + xs let rec chunk_atyp comments chunks (ATyp_aux (aux, l)) = pop_comments comments chunks l; @@ -566,62 +531,69 @@ let rec chunk_atyp comments chunks (ATyp_aux (aux, l)) = chunks in match aux with - | ATyp_id id -> - Queue.add (Atom (string_of_id id)) chunks - | ATyp_var v -> - Queue.add (Atom (string_of_var v)) chunks - | ATyp_lit lit -> - Queue.add (chunk_of_lit lit) chunks + | ATyp_id id -> Queue.add (Atom (string_of_id id)) chunks + | ATyp_var v -> Queue.add (Atom (string_of_var v)) chunks + | ATyp_lit lit -> Queue.add (chunk_of_lit lit) chunks | ATyp_nset (n, set) -> - (* We would need more granular location information to do anything better here *) - Queue.add (Atom (Printf.sprintf "%s in {%s}" (string_of_var n) (Util.string_of_list ", " Big_int.to_string set))) chunks + (* We would need more granular location information to do anything better here *) + Queue.add + (Atom (Printf.sprintf "%s in {%s}" (string_of_var n) (Util.string_of_list ", " Big_int.to_string set))) + chunks | (ATyp_times (lhs, rhs) | ATyp_sum (lhs, rhs) | ATyp_minus (lhs, rhs)) as binop -> - let op_symbol = match binop with - | ATyp_times _ -> "*" | ATyp_sum _ -> "+" | ATyp_minus _ -> "-" | _ -> Reporting.unreachable l __POS__ "Invalid binary atyp" [@coverage off] in - let lhs_chunks = rec_chunk_atyp lhs in - let rhs_chunks = rec_chunk_atyp rhs in - Queue.add (Binary (lhs_chunks, op_symbol, rhs_chunks)) chunks + let op_symbol = + match binop with + | ATyp_times _ -> "*" + | ATyp_sum _ -> "+" + | ATyp_minus _ -> "-" + | _ -> Reporting.unreachable l __POS__ "Invalid binary atyp" [@coverage off] + in + let lhs_chunks = rec_chunk_atyp lhs in + let rhs_chunks = rec_chunk_atyp rhs in + Queue.add (Binary (lhs_chunks, op_symbol, rhs_chunks)) chunks | ATyp_exp arg -> - let lhs_chunks = Queue.create () in - Queue.add (Atom "2") lhs_chunks; - let rhs_chunks = rec_chunk_atyp arg in - Queue.add (Binary (lhs_chunks, "^", rhs_chunks)) chunks + let lhs_chunks = Queue.create () in + Queue.add (Atom "2") lhs_chunks; + let rhs_chunks = rec_chunk_atyp arg in + Queue.add (Binary (lhs_chunks, "^", rhs_chunks)) chunks | ATyp_neg arg -> - let arg_chunks = rec_chunk_atyp arg in - Queue.add (Unary ("-", arg_chunks)) chunks - | ATyp_inc -> - Queue.add (Atom "inc") chunks - | ATyp_dec -> - Queue.add (Atom "dec") chunks + let arg_chunks = rec_chunk_atyp arg in + Queue.add (Unary ("-", arg_chunks)) chunks + | ATyp_inc -> Queue.add (Atom "inc") chunks + | ATyp_dec -> Queue.add (Atom "dec") chunks | ATyp_fn (lhs, rhs, _) -> - let lhs_chunks = rec_chunk_atyp lhs in - let rhs_chunks = rec_chunk_atyp rhs in - Queue.add (Function_typ { mapping = false; lhs = lhs_chunks; rhs = rhs_chunks }) chunks + let lhs_chunks = rec_chunk_atyp lhs in + let rhs_chunks = rec_chunk_atyp rhs in + Queue.add (Function_typ { mapping = false; lhs = lhs_chunks; rhs = rhs_chunks }) chunks | ATyp_bidir (lhs, rhs, _) -> - let lhs_chunks = rec_chunk_atyp lhs in - let rhs_chunks = rec_chunk_atyp rhs in - Queue.add (Function_typ { mapping = true; lhs = lhs_chunks; rhs = rhs_chunks }) chunks + let lhs_chunks = rec_chunk_atyp lhs in + let rhs_chunks = rec_chunk_atyp rhs in + Queue.add (Function_typ { mapping = true; lhs = lhs_chunks; rhs = rhs_chunks }) chunks | ATyp_app (Id_aux (Operator op, _), [lhs; rhs]) -> - let lhs_chunks = rec_chunk_atyp lhs in - let rhs_chunks = rec_chunk_atyp rhs in - Queue.add (Binary (lhs_chunks, op, rhs_chunks)) chunks + let lhs_chunks = rec_chunk_atyp lhs in + let rhs_chunks = rec_chunk_atyp rhs in + Queue.add (Binary (lhs_chunks, op, rhs_chunks)) chunks | ATyp_app (id, ([_] as args)) when string_of_id id = "atom" -> - let args = chunk_delimit ~delim:"," ~get_loc:(fun (ATyp_aux (_, l)) -> l) ~chunk:chunk_atyp comments chunks args in - Queue.add (App (Id_aux (Id "int", id_loc id), args)) chunks + let args = + chunk_delimit ~delim:"," ~get_loc:(fun (ATyp_aux (_, l)) -> l) ~chunk:chunk_atyp comments chunks args + in + Queue.add (App (Id_aux (Id "int", id_loc id), args)) chunks | ATyp_app (id, args) -> - let args = chunk_delimit ~delim:"," ~get_loc:(fun (ATyp_aux (_, l)) -> l) ~chunk:chunk_atyp comments chunks args in - Queue.add (App (id, args)) chunks + let args = + chunk_delimit ~delim:"," ~get_loc:(fun (ATyp_aux (_, l)) -> l) ~chunk:chunk_atyp comments chunks args + in + Queue.add (App (id, args)) chunks | ATyp_tuple args -> - let args = chunk_delimit ~delim:"," ~get_loc:(fun (ATyp_aux (_, l)) -> l) ~chunk:chunk_atyp comments chunks args in - Queue.add (Tuple ("(", ")", 0, args)) chunks - | ATyp_wild -> - Queue.add (Atom "_") chunks + let args = + chunk_delimit ~delim:"," ~get_loc:(fun (ATyp_aux (_, l)) -> l) ~chunk:chunk_atyp comments chunks args + in + Queue.add (Tuple ("(", ")", 0, args)) chunks + | ATyp_wild -> Queue.add (Atom "_") chunks | ATyp_exist (vars, constr, typ) -> - let var_chunks = Queue.create () in - List.iter (fun kopt -> Queue.add (chunk_of_kopt kopt) var_chunks) vars; - let constr_chunks = rec_chunk_atyp constr in - let typ_chunks = rec_chunk_atyp typ in - Queue.add (Exists { vars = var_chunks; constr = constr_chunks; typ = typ_chunks }) chunks + let var_chunks = Queue.create () in + List.iter (fun kopt -> Queue.add (chunk_of_kopt kopt) var_chunks) vars; + let constr_chunks = rec_chunk_atyp constr in + let typ_chunks = rec_chunk_atyp typ in + Queue.add (Exists { vars = var_chunks; constr = constr_chunks; typ = typ_chunks }) chunks | ATyp_set _ -> () let rec chunk_pat comments chunks (P_aux (aux, l)) = @@ -632,92 +604,83 @@ let rec chunk_pat comments chunks (P_aux (aux, l)) = chunks in match aux with - | P_id id -> - Queue.add (Atom (string_of_id id)) chunks - | P_wild -> - Queue.add (Atom "_") chunks - | P_lit lit -> - Queue.add (chunk_of_lit lit) chunks - | P_app (id, [P_aux (P_lit (L_aux (L_unit, _)), _)]) -> - Queue.add (App (id, [])) chunks + | P_id id -> Queue.add (Atom (string_of_id id)) chunks + | P_wild -> Queue.add (Atom "_") chunks + | P_lit lit -> Queue.add (chunk_of_lit lit) chunks + | P_app (id, [P_aux (P_lit (L_aux (L_unit, _)), _)]) -> Queue.add (App (id, [])) chunks | P_app (id, pats) -> - let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in - Queue.add (App (id, pats)) chunks + let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in + Queue.add (App (id, pats)) chunks | P_tuple pats -> - let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in - Queue.add (Tuple ("(", ")", 0, pats)) chunks + let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in + Queue.add (Tuple ("(", ")", 0, pats)) chunks | P_vector pats -> - let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in - Queue.add (Tuple ("[", "]", 0, pats)) chunks + let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in + Queue.add (Tuple ("[", "]", 0, pats)) chunks | P_list pats -> - let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in - Queue.add (Tuple ("[|", "|]", 0, pats)) chunks + let pats = chunk_delimit ~delim:"," ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in + Queue.add (Tuple ("[|", "|]", 0, pats)) chunks | P_string_append pats -> - let pats = chunk_delimit ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in - Queue.add (Intersperse ("^", pats)) chunks + let pats = chunk_delimit ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in + Queue.add (Intersperse ("^", pats)) chunks | P_vector_concat pats -> - let pats = chunk_delimit ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in - Queue.add (Intersperse ("@", pats)) chunks + let pats = chunk_delimit ~get_loc:(fun (P_aux (_, l)) -> l) ~chunk:chunk_pat comments chunks pats in + Queue.add (Intersperse ("@", pats)) chunks | P_vector_subrange (id, n, m) -> - let id_chunks = Queue.create () in - Queue.add (Atom (string_of_id id)) id_chunks; - let ix_chunks = Queue.create () in - if Big_int.equal n m then ( - Queue.add (Atom (Big_int.to_string n)) ix_chunks - ) else ( - let n_chunks = Queue.create () in - Queue.add (Atom (Big_int.to_string n)) n_chunks; - let m_chunks = Queue.create () in - Queue.add (Atom (Big_int.to_string m)) m_chunks; - Queue.add (Binary (n_chunks, "..", m_chunks)) ix_chunks - ); - Queue.add (Index (id_chunks, ix_chunks)) chunks + let id_chunks = Queue.create () in + Queue.add (Atom (string_of_id id)) id_chunks; + let ix_chunks = Queue.create () in + if Big_int.equal n m then Queue.add (Atom (Big_int.to_string n)) ix_chunks + else ( + let n_chunks = Queue.create () in + Queue.add (Atom (Big_int.to_string n)) n_chunks; + let m_chunks = Queue.create () in + Queue.add (Atom (Big_int.to_string m)) m_chunks; + Queue.add (Binary (n_chunks, "..", m_chunks)) ix_chunks + ); + Queue.add (Index (id_chunks, ix_chunks)) chunks | P_typ (typ, pat) -> - let pat_chunks = rec_chunk_pat pat in - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks typ; - Queue.add (Binary (pat_chunks, ":", typ_chunks)) chunks + let pat_chunks = rec_chunk_pat pat in + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks typ; + Queue.add (Binary (pat_chunks, ":", typ_chunks)) chunks | P_var (pat, typ) -> - let pat_chunks = rec_chunk_pat pat in - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks typ; - Queue.add (Binary (pat_chunks, "as", typ_chunks)) chunks + let pat_chunks = rec_chunk_pat pat in + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks typ; + Queue.add (Binary (pat_chunks, "as", typ_chunks)) chunks | P_cons (hd_pat, tl_pat) -> - let hd_pat_chunks = rec_chunk_pat hd_pat in - let tl_pat_chunks = rec_chunk_pat tl_pat in - Queue.add (Binary (hd_pat_chunks, "::", tl_pat_chunks)) chunks + let hd_pat_chunks = rec_chunk_pat hd_pat in + let tl_pat_chunks = rec_chunk_pat tl_pat in + Queue.add (Binary (hd_pat_chunks, "::", tl_pat_chunks)) chunks | P_attribute (attr, arg, pat) -> - Queue.add (Atom (Printf.sprintf "$[%s %s]" attr arg)) chunks; - Queue.add (Spacer (false, 1)) chunks; - chunk_pat comments chunks pat; + Queue.add (Atom (Printf.sprintf "$[%s %s]" attr arg)) chunks; + Queue.add (Spacer (false, 1)) chunks; + chunk_pat comments chunks pat -type block_exp = - | Block_exp of exp - | Block_let of letbind - | Block_var of exp * exp +type block_exp = Block_exp of exp | Block_let of letbind | Block_var of exp * exp let block_exp_locs = function | Block_exp (E_aux (_, l)) -> (l, l) | Block_let (LB_aux (_, l)) -> (l, l) | Block_var (E_aux (_, s_l), E_aux (_, e_l)) -> (s_l, e_l) - + let flatten_block exps = let block_exps = Queue.create () in let rec go = function | [] -> () | [E_aux (E_let (letbind, E_aux (E_block more_exps, _)), _)] -> - Queue.add (Block_let letbind) block_exps; - go more_exps + Queue.add (Block_let letbind) block_exps; + go more_exps | [E_aux (E_var (lexp, exp, E_aux (E_block more_exps, _)), _)] -> - Queue.add (Block_var (lexp, exp)) block_exps; - go more_exps - | [E_aux (E_let (letbind, E_aux (E_lit (L_aux (L_unit, _)), _)), _)] -> - Queue.add (Block_let letbind) block_exps + Queue.add (Block_var (lexp, exp)) block_exps; + go more_exps + | [E_aux (E_let (letbind, E_aux (E_lit (L_aux (L_unit, _)), _)), _)] -> Queue.add (Block_let letbind) block_exps | [E_aux (E_var (lexp, exp, E_aux (E_lit (L_aux (L_unit, _)), _)), _)] -> - Queue.add (Block_var (lexp, exp)) block_exps + Queue.add (Block_var (lexp, exp)) block_exps | exp :: exps -> - Queue.add (Block_exp exp) block_exps; - go exps + Queue.add (Block_exp exp) block_exps; + go exps in go exps; List.of_seq (Queue.to_seq block_exps) @@ -728,270 +691,267 @@ let is_aligned pexps = | Pat_aux (Pat_exp (_, E_aux (_, l)), _) -> starting_column_num l | Pat_aux (Pat_when (_, _, E_aux (_, l)), _) -> starting_column_num l in - List.fold_left (fun (all_same, col) pexp -> - if not all_same then ( - (false, None) - ) else ( + List.fold_left + (fun (all_same, col) pexp -> + if not all_same then (false, None) + else ( let new_col = pexp_exp_column pexp in - match col, new_col with + match (col, new_col) with | _, None -> - (* If a column number is unknown, assume not aligned *) - (false, None) - | None, Some _ -> - (true, new_col) - | Some col, Some new_col -> - if col = new_col then ( - (true, Some col) - ) else ( - (false, None) - ) + (* If a column number is unknown, assume not aligned *) + (false, None) + | None, Some _ -> (true, new_col) + | Some col, Some new_col -> if col = new_col then (true, Some col) else (false, None) ) - ) (true, None) pexps + ) + (true, None) pexps |> fst let rec chunk_exp comments chunks (E_aux (aux, l)) = pop_comments comments chunks l; - + let rec_chunk_exp exp = let chunks = Queue.create () in chunk_exp comments chunks exp; chunks in match aux with - | E_id id -> - Queue.add (Atom (string_of_id id)) chunks - | E_ref id -> - Queue.add (Atom ("ref " ^ string_of_id id)) chunks - | E_lit lit -> - Queue.add (chunk_of_lit lit) chunks + | E_id id -> Queue.add (Atom (string_of_id id)) chunks + | E_ref id -> Queue.add (Atom ("ref " ^ string_of_id id)) chunks + | E_lit lit -> Queue.add (chunk_of_lit lit) chunks | E_attribute (attr, arg, exp) -> - Queue.add (Atom (Printf.sprintf "$[%s %s]" attr arg)) chunks; - Queue.add (Spacer (false, 1)) chunks; - chunk_exp comments chunks exp - | E_app (id, [E_aux (E_lit (L_aux (L_unit, _)), _)]) -> - Queue.add (App (id, [])) chunks + Queue.add (Atom (Printf.sprintf "$[%s %s]" attr arg)) chunks; + Queue.add (Spacer (false, 1)) chunks; + chunk_exp comments chunks exp + | E_app (id, [E_aux (E_lit (L_aux (L_unit, _)), _)]) -> Queue.add (App (id, [])) chunks | E_app (id, args) -> - let args = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks args in - Queue.add (App (id, args)) chunks + let args = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks args in + Queue.add (App (id, args)) chunks | (E_sizeof atyp | E_constraint atyp) as typ_app -> - let name = match typ_app with E_sizeof _ -> "sizeof" | E_constraint _ -> "constraint" | _ -> Reporting.unreachable l __POS__ "Invalid typ_app" in - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks atyp; - Queue.add (App (Id_aux (Id name, Unknown), [typ_chunks])) chunks + let name = + match typ_app with + | E_sizeof _ -> "sizeof" + | E_constraint _ -> "constraint" + | _ -> Reporting.unreachable l __POS__ "Invalid typ_app" + in + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks atyp; + Queue.add (App (Id_aux (Id name, Unknown), [typ_chunks])) chunks | E_assert (exp, E_aux (E_lit (L_aux (L_string "", _)), _)) -> - let exp_chunks = rec_chunk_exp exp in - Queue.add (App (Id_aux (Id "assert", Unknown), [exp_chunks])) chunks + let exp_chunks = rec_chunk_exp exp in + Queue.add (App (Id_aux (Id "assert", Unknown), [exp_chunks])) chunks | E_assert (exp, msg) -> - let exp_chunks = rec_chunk_exp exp in - Queue.add (Delim ",") exp_chunks; - let msg_chunks = rec_chunk_exp msg in - Queue.add (App (Id_aux (Id "assert", Unknown), [exp_chunks; msg_chunks])) chunks + let exp_chunks = rec_chunk_exp exp in + Queue.add (Delim ",") exp_chunks; + let msg_chunks = rec_chunk_exp msg in + Queue.add (App (Id_aux (Id "assert", Unknown), [exp_chunks; msg_chunks])) chunks | E_exit exp -> - let exp_chunks = rec_chunk_exp exp in - Queue.add (App (Id_aux (Id "exit", Unknown), [exp_chunks])) chunks + let exp_chunks = rec_chunk_exp exp in + Queue.add (App (Id_aux (Id "exit", Unknown), [exp_chunks])) chunks | E_app_infix (lhs, op, rhs) -> - let lhs_chunks = rec_chunk_exp lhs in - let rhs_chunks = rec_chunk_exp rhs in - Queue.add (Binary (lhs_chunks, string_of_id op, rhs_chunks)) chunks + let lhs_chunks = rec_chunk_exp lhs in + let rhs_chunks = rec_chunk_exp rhs in + Queue.add (Binary (lhs_chunks, string_of_id op, rhs_chunks)) chunks | E_cons (lhs, rhs) -> - let lhs_chunks = rec_chunk_exp lhs in - let rhs_chunks = rec_chunk_exp rhs in - Queue.add (Binary (lhs_chunks, "::", rhs_chunks)) chunks + let lhs_chunks = rec_chunk_exp lhs in + let rhs_chunks = rec_chunk_exp rhs in + Queue.add (Binary (lhs_chunks, "::", rhs_chunks)) chunks | E_vector_append (lhs, rhs) -> - let lhs_chunks = rec_chunk_exp lhs in - let rhs_chunks = rec_chunk_exp rhs in - Queue.add (Binary (lhs_chunks, "@", rhs_chunks)) chunks + let lhs_chunks = rec_chunk_exp lhs in + let rhs_chunks = rec_chunk_exp rhs in + Queue.add (Binary (lhs_chunks, "@", rhs_chunks)) chunks | E_typ (typ, exp) -> - let exp_chunks = rec_chunk_exp exp in - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks typ; - Queue.add (Binary (exp_chunks, ":", typ_chunks)) chunks - | E_tuple exps -> - let exps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks exps in - Queue.add (Tuple ("(", ")", 0, exps)) chunks - | E_vector [] -> - Queue.add (Atom "[]") chunks + let exp_chunks = rec_chunk_exp exp in + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks typ; + Queue.add (Binary (exp_chunks, ":", typ_chunks)) chunks + | E_tuple exps -> + let exps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks exps in + Queue.add (Tuple ("(", ")", 0, exps)) chunks + | E_vector [] -> Queue.add (Atom "[]") chunks | E_vector exps -> - let exps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks exps in - Queue.add (Tuple ("[", "]", 0, exps)) chunks - | E_list [] -> - Queue.add (Atom "[||]") chunks + let exps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks exps in + Queue.add (Tuple ("[", "]", 0, exps)) chunks + | E_list [] -> Queue.add (Atom "[||]") chunks | E_list exps -> - let exps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks exps in - Queue.add (Tuple ("[|", "|]", 0, exps)) chunks + let exps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks exps in + Queue.add (Tuple ("[|", "|]", 0, exps)) chunks | E_struct fexps -> - let fexps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks fexps in - Queue.add (Tuple ("struct {", "}", 1, fexps)) chunks + let fexps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks fexps in + Queue.add (Tuple ("struct {", "}", 1, fexps)) chunks | E_struct_update (exp, fexps) -> - let exp = rec_chunk_exp exp in - let fexps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks fexps in - Queue.add (Struct_update (exp, fexps)) chunks + let exp = rec_chunk_exp exp in + let fexps = chunk_delimit ~delim:"," ~get_loc:(fun (E_aux (_, l)) -> l) ~chunk:chunk_exp comments chunks fexps in + Queue.add (Struct_update (exp, fexps)) chunks | E_block exps -> - let block_exps = flatten_block exps in - let block_chunks = - map_peek_acc (fun need_spacer next block_exp -> - let s_l, e_l = block_exp_locs block_exp in - let chunks = Queue.create () in - - if need_spacer then ( - Queue.add (Spacer (true, 1)) chunks - ); - - begin match block_exp with - | Block_exp exp -> - chunk_exp comments chunks exp; - | Block_let (LB_aux (LB_val (pat, exp), _)) -> - pop_comments comments chunks s_l; - let pat_chunks = Queue.create () in - chunk_pat comments pat_chunks pat; - let exp_chunks = rec_chunk_exp exp in - Queue.add (Block_binder (Let_binder, pat_chunks, exp_chunks)) chunks - | Block_var (lexp, exp) -> - pop_comments comments chunks s_l; - let lexp_chunks = rec_chunk_exp lexp in - let exp_chunks = rec_chunk_exp exp in - Queue.add (Block_binder (Var_binder, lexp_chunks, exp_chunks)) chunks - end; - - (* TODO: Do we need to do something different for multiple trailing comments at end of a block? *) - let next_line_num = Option.bind next (fun bexp -> block_exp_locs bexp |> fst |> starting_line_num) in - if have_linebreak (ending_line_num e_l) next_line_num || Option.is_none next then ( - ignore (pop_trailing_comment comments chunks (ending_line_num e_l)) - ); - - (chunks, have_blank_linebreak (ending_line_num e_l) next_line_num) - ) false block_exps in - Queue.add (Block (true, block_chunks)) chunks + let block_exps = flatten_block exps in + let block_chunks = + map_peek_acc + (fun need_spacer next block_exp -> + let s_l, e_l = block_exp_locs block_exp in + let chunks = Queue.create () in + + if need_spacer then Queue.add (Spacer (true, 1)) chunks; + + begin + match block_exp with + | Block_exp exp -> chunk_exp comments chunks exp + | Block_let (LB_aux (LB_val (pat, exp), _)) -> + pop_comments comments chunks s_l; + let pat_chunks = Queue.create () in + chunk_pat comments pat_chunks pat; + let exp_chunks = rec_chunk_exp exp in + Queue.add (Block_binder (Let_binder, pat_chunks, exp_chunks)) chunks + | Block_var (lexp, exp) -> + pop_comments comments chunks s_l; + let lexp_chunks = rec_chunk_exp lexp in + let exp_chunks = rec_chunk_exp exp in + Queue.add (Block_binder (Var_binder, lexp_chunks, exp_chunks)) chunks + end; + + (* TODO: Do we need to do something different for multiple trailing comments at end of a block? *) + let next_line_num = Option.bind next (fun bexp -> block_exp_locs bexp |> fst |> starting_line_num) in + if have_linebreak (ending_line_num e_l) next_line_num || Option.is_none next then + ignore (pop_trailing_comment comments chunks (ending_line_num e_l)); + + (chunks, have_blank_linebreak (ending_line_num e_l) next_line_num) + ) + false block_exps + in + Queue.add (Block (true, block_chunks)) chunks | (E_let (LB_aux (LB_val (pat, exp), _), body) | E_internal_plet (pat, exp, body)) as binder -> - let binder = match binder with E_let _ -> Let_binder | E_internal_plet _ -> Internal_plet_binder | _ -> Reporting.unreachable l __POS__ "Unknown binder" in - let pat_chunks = Queue.create () in - chunk_pat comments pat_chunks pat; - let exp_chunks = rec_chunk_exp exp in - let body_chunks = rec_chunk_exp body in - Queue.add (Binder (binder, pat_chunks, exp_chunks, body_chunks)) chunks + let binder = + match binder with + | E_let _ -> Let_binder + | E_internal_plet _ -> Internal_plet_binder + | _ -> Reporting.unreachable l __POS__ "Unknown binder" + in + let pat_chunks = Queue.create () in + chunk_pat comments pat_chunks pat; + let exp_chunks = rec_chunk_exp exp in + let body_chunks = rec_chunk_exp body in + Queue.add (Binder (binder, pat_chunks, exp_chunks, body_chunks)) chunks | E_var (lexp, exp, body) -> - let lexp_chunks = rec_chunk_exp lexp in - let exp_chunks = rec_chunk_exp exp in - let body_chunks = rec_chunk_exp body in - Queue.add (Binder (Var_binder, lexp_chunks, exp_chunks, body_chunks)) chunks + let lexp_chunks = rec_chunk_exp lexp in + let exp_chunks = rec_chunk_exp exp in + let body_chunks = rec_chunk_exp body in + Queue.add (Binder (Var_binder, lexp_chunks, exp_chunks, body_chunks)) chunks | E_assign (lexp, exp) -> - let lexp_chunks = rec_chunk_exp lexp in - let exp_chunks = rec_chunk_exp exp in - Queue.add (Binary (lexp_chunks, "=", exp_chunks)) chunks + let lexp_chunks = rec_chunk_exp lexp in + let exp_chunks = rec_chunk_exp exp in + Queue.add (Binary (lexp_chunks, "=", exp_chunks)) chunks | E_if (i, t, E_aux (E_lit (L_aux (L_unit, _)), _)) -> - let then_brace = (match t with E_aux (E_block _, _) -> true | _ -> false) in - let i_chunks = rec_chunk_exp i in - let t_chunks = rec_chunk_exp t in - Queue.add (If_then (then_brace, i_chunks, t_chunks)) chunks + let then_brace = match t with E_aux (E_block _, _) -> true | _ -> false in + let i_chunks = rec_chunk_exp i in + let t_chunks = rec_chunk_exp t in + Queue.add (If_then (then_brace, i_chunks, t_chunks)) chunks | E_if (i, t, e) -> - let if_format = { - then_brace = (match t with E_aux (E_block _, _) -> true | _ -> false); - else_brace = (match e with E_aux (E_block _, _) -> true | _ -> false); - } in - let i_chunks = rec_chunk_exp i in - let t_chunks = rec_chunk_exp t in - let e_chunks = rec_chunk_exp e in - Queue.add (If_then_else (if_format, i_chunks, t_chunks, e_chunks)) chunks + let if_format = + { + then_brace = (match t with E_aux (E_block _, _) -> true | _ -> false); + else_brace = (match e with E_aux (E_block _, _) -> true | _ -> false); + } + in + let i_chunks = rec_chunk_exp i in + let t_chunks = rec_chunk_exp t in + let e_chunks = rec_chunk_exp e in + Queue.add (If_then_else (if_format, i_chunks, t_chunks, e_chunks)) chunks | (E_throw exp | E_return exp | E_deref exp | E_internal_return exp) as unop -> - let unop = match unop with - | E_throw _ -> "throw" - | E_return _ -> "return" - | E_internal_return _ -> "internal_return" - | E_deref _ -> "*" - | _ -> Reporting.unreachable l __POS__ "invalid unop" in - let e_chunks = rec_chunk_exp exp in - Queue.add (Unary (unop, e_chunks)) chunks + let unop = + match unop with + | E_throw _ -> "throw" + | E_return _ -> "return" + | E_internal_return _ -> "internal_return" + | E_deref _ -> "*" + | _ -> Reporting.unreachable l __POS__ "invalid unop" + in + let e_chunks = rec_chunk_exp exp in + Queue.add (Unary (unop, e_chunks)) chunks | E_field (exp, id) -> - let exp_chunks = rec_chunk_exp exp in - Queue.add (Field (exp_chunks, id)) chunks + let exp_chunks = rec_chunk_exp exp in + Queue.add (Field (exp_chunks, id)) chunks | (E_match (exp, cases) | E_try (exp, cases)) as match_exp -> - let kind = match match_exp with E_match _ -> Match_match | _ -> Try_match in - let exp_chunks = rec_chunk_exp exp in - let aligned = is_aligned cases in - let cases = List.map (chunk_pexp ~delim:"," comments) cases in - (Match { - kind = kind; - exp = exp_chunks; - aligned = aligned; - cases = cases - }) |> add_chunk chunks - | (E_vector_update _ | E_vector_update_subrange _) -> - let (vec_chunks, updates) = chunk_vector_update comments (E_aux (aux, l)) in - Queue.add (Vector_updates (vec_chunks, List.rev updates)) chunks + let kind = match match_exp with E_match _ -> Match_match | _ -> Try_match in + let exp_chunks = rec_chunk_exp exp in + let aligned = is_aligned cases in + let cases = List.map (chunk_pexp ~delim:"," comments) cases in + Match { kind; exp = exp_chunks; aligned; cases } |> add_chunk chunks + | E_vector_update _ | E_vector_update_subrange _ -> + let vec_chunks, updates = chunk_vector_update comments (E_aux (aux, l)) in + Queue.add (Vector_updates (vec_chunks, List.rev updates)) chunks | E_vector_access (exp, ix) -> - let exp_chunks = rec_chunk_exp exp in - let ix_chunks = rec_chunk_exp ix in - Queue.add (Index (exp_chunks, ix_chunks)) chunks + let exp_chunks = rec_chunk_exp exp in + let ix_chunks = rec_chunk_exp ix in + Queue.add (Index (exp_chunks, ix_chunks)) chunks | E_vector_subrange (exp, ix1, ix2) -> - let exp_chunks = rec_chunk_exp exp in - let ix1_chunks = rec_chunk_exp ix1 in - let ix2_chunks = rec_chunk_exp ix2 in - let ix_chunks = Queue.create () in - Queue.add (Binary (ix1_chunks, "..", ix2_chunks)) ix_chunks; - Queue.add (Index (exp_chunks, ix_chunks)) chunks + let exp_chunks = rec_chunk_exp exp in + let ix1_chunks = rec_chunk_exp ix1 in + let ix2_chunks = rec_chunk_exp ix2 in + let ix_chunks = Queue.create () in + Queue.add (Binary (ix1_chunks, "..", ix2_chunks)) ix_chunks; + Queue.add (Index (exp_chunks, ix_chunks)) chunks | E_for (var, from_index, to_index, step, order, body) -> - let decreasing = match order with - | ATyp_aux (ATyp_inc, _) -> false - | ATyp_aux (ATyp_dec, _) -> true - | _ -> Reporting.unreachable l __POS__ "Invalid foreach order" - in - let var_chunks = Queue.create () in - pop_comments comments var_chunks (id_loc var); - Queue.add (Atom (string_of_id var)) var_chunks; - let from_chunks = Queue.create () in - chunk_exp comments from_chunks from_index; - let to_chunks = Queue.create () in - chunk_exp comments to_chunks to_index; - let step_chunks_opt = match step with - | E_aux (E_lit (L_aux (L_num n, _)), _) when Big_int.equal n (Big_int.of_int 1) -> - None - | _ -> - let step_chunks = Queue.create () in - chunk_exp comments step_chunks step; - Some step_chunks - in - let body_chunks = Queue.create () in - chunk_exp comments body_chunks body; - (Foreach { - var = var_chunks; - decreasing = decreasing; - from_index = from_chunks; - to_index = to_chunks; - step = step_chunks_opt; - body = body_chunks - }) |> add_chunk chunks + let decreasing = + match order with + | ATyp_aux (ATyp_inc, _) -> false + | ATyp_aux (ATyp_dec, _) -> true + | _ -> Reporting.unreachable l __POS__ "Invalid foreach order" + in + let var_chunks = Queue.create () in + pop_comments comments var_chunks (id_loc var); + Queue.add (Atom (string_of_id var)) var_chunks; + let from_chunks = Queue.create () in + chunk_exp comments from_chunks from_index; + let to_chunks = Queue.create () in + chunk_exp comments to_chunks to_index; + let step_chunks_opt = + match step with + | E_aux (E_lit (L_aux (L_num n, _)), _) when Big_int.equal n (Big_int.of_int 1) -> None + | _ -> + let step_chunks = Queue.create () in + chunk_exp comments step_chunks step; + Some step_chunks + in + let body_chunks = Queue.create () in + chunk_exp comments body_chunks body; + Foreach + { + var = var_chunks; + decreasing; + from_index = from_chunks; + to_index = to_chunks; + step = step_chunks_opt; + body = body_chunks; + } + |> add_chunk chunks | E_loop (loop_type, measure, cond, body) -> - let measure_chunks_opt = match measure with - | Measure_aux (Measure_none, _) -> None - | Measure_aux (Measure_some exp, _) -> - let measure_chunks = Queue.create () in - chunk_exp comments measure_chunks exp; - Some measure_chunks - in - begin match loop_type with - | While -> - let cond_chunks = Queue.create () in - chunk_exp comments cond_chunks cond; - let body_chunks = Queue.create () in - chunk_exp comments body_chunks body; - (While { - repeat_until = false; - termination_measure = measure_chunks_opt; - cond = cond_chunks; - body = body_chunks - }) |> add_chunk chunks - | Until -> - let cond_chunks = Queue.create () in - chunk_exp comments cond_chunks cond; - let body_chunks = Queue.create () in - chunk_exp comments body_chunks body; - (While { - repeat_until = true; - termination_measure = measure_chunks_opt; - cond = cond_chunks; - body = body_chunks - }) |> add_chunk chunks - end + let measure_chunks_opt = + match measure with + | Measure_aux (Measure_none, _) -> None + | Measure_aux (Measure_some exp, _) -> + let measure_chunks = Queue.create () in + chunk_exp comments measure_chunks exp; + Some measure_chunks + in + begin + match loop_type with + | While -> + let cond_chunks = Queue.create () in + chunk_exp comments cond_chunks cond; + let body_chunks = Queue.create () in + chunk_exp comments body_chunks body; + While + { repeat_until = false; termination_measure = measure_chunks_opt; cond = cond_chunks; body = body_chunks } + |> add_chunk chunks + | Until -> + let cond_chunks = Queue.create () in + chunk_exp comments cond_chunks cond; + let body_chunks = Queue.create () in + chunk_exp comments body_chunks body; + While + { repeat_until = true; termination_measure = measure_chunks_opt; cond = cond_chunks; body = body_chunks } + |> add_chunk chunks + end and chunk_vector_update comments (E_aux (aux, l) as exp) = let rec_chunk_exp exp = @@ -1001,90 +961,81 @@ and chunk_vector_update comments (E_aux (aux, l) as exp) = in match aux with | E_vector_update (vec, ix, exp) -> - let (vec_chunks, update) = chunk_vector_update comments vec in - let ix = rec_chunk_exp ix in - let exp = rec_chunk_exp exp in - (vec_chunks, Binary (ix, "=", exp) :: update) + let vec_chunks, update = chunk_vector_update comments vec in + let ix = rec_chunk_exp ix in + let exp = rec_chunk_exp exp in + (vec_chunks, Binary (ix, "=", exp) :: update) | E_vector_update_subrange (vec, ix1, ix2, exp) -> - let (vec_chunks, update) = chunk_vector_update comments vec in - let ix1 = rec_chunk_exp ix1 in - let ix2 = rec_chunk_exp ix2 in - let exp = rec_chunk_exp exp in - (vec_chunks, Ternary (ix1, "..", ix2, "=", exp) :: update) + let vec_chunks, update = chunk_vector_update comments vec in + let ix1 = rec_chunk_exp ix1 in + let ix2 = rec_chunk_exp ix2 in + let exp = rec_chunk_exp exp in + (vec_chunks, Ternary (ix1, "..", ix2, "=", exp) :: update) | _ -> - let exp_chunks = Queue.create () in - chunk_exp comments exp_chunks exp; - (exp_chunks, []) + let exp_chunks = Queue.create () in + chunk_exp comments exp_chunks exp; + (exp_chunks, []) and chunk_pexp ?delim comments (Pat_aux (aux, l)) = match aux with | Pat_exp (pat, exp) -> - let funcl_space = match pat with P_aux (P_tuple _, _) -> false | _ -> true in - let pat_chunks = Queue.create () in - chunk_pat comments pat_chunks pat; - let exp_chunks = Queue.create () in - chunk_exp comments exp_chunks exp; - (match delim with Some d -> Queue.add (Delim d) exp_chunks | _ -> ()); - ignore (pop_trailing_comment comments exp_chunks (ending_line_num l)); - { funcl_space = funcl_space; pat = pat_chunks; guard = None; body = exp_chunks } + let funcl_space = match pat with P_aux (P_tuple _, _) -> false | _ -> true in + let pat_chunks = Queue.create () in + chunk_pat comments pat_chunks pat; + let exp_chunks = Queue.create () in + chunk_exp comments exp_chunks exp; + (match delim with Some d -> Queue.add (Delim d) exp_chunks | _ -> ()); + ignore (pop_trailing_comment comments exp_chunks (ending_line_num l)); + { funcl_space; pat = pat_chunks; guard = None; body = exp_chunks } | Pat_when (pat, guard, exp) -> - let pat_chunks = Queue.create () in - chunk_pat comments pat_chunks pat; - let guard_chunks = Queue.create () in - chunk_exp comments guard_chunks guard; - let exp_chunks = Queue.create () in - chunk_exp comments exp_chunks exp; - (match delim with Some d -> Queue.add (Delim d) exp_chunks | _ -> ()); - ignore (pop_trailing_comment comments exp_chunks (ending_line_num l)); - { funcl_space = true; pat = pat_chunks; guard = Some guard_chunks; body = exp_chunks } + let pat_chunks = Queue.create () in + chunk_pat comments pat_chunks pat; + let guard_chunks = Queue.create () in + chunk_exp comments guard_chunks guard; + let exp_chunks = Queue.create () in + chunk_exp comments exp_chunks exp; + (match delim with Some d -> Queue.add (Delim d) exp_chunks | _ -> ()); + ignore (pop_trailing_comment comments exp_chunks (ending_line_num l)); + { funcl_space = true; pat = pat_chunks; guard = Some guard_chunks; body = exp_chunks } let chunk_funcl comments (FCL_aux (FCL_funcl (_, pexp), _)) = chunk_pexp comments pexp let chunk_quant_item comments chunks last = function | QI_aux (QI_id kopt, l) -> - pop_comments comments chunks l; - Queue.add (chunk_of_kopt kopt) chunks; - if not last then ( - Queue.add (Spacer (false, 1)) chunks - ) - | QI_aux (QI_constraint atyp, _) -> - chunk_atyp comments chunks atyp + pop_comments comments chunks l; + Queue.add (chunk_of_kopt kopt) chunks; + if not last then Queue.add (Spacer (false, 1)) chunks + | QI_aux (QI_constraint atyp, _) -> chunk_atyp comments chunks atyp let chunk_quant_items l comments chunks quant_items = pop_comments comments chunks l; - let is_qi_id = function - | QI_aux (QI_id _, _) as qi -> Ok qi - | QI_aux (QI_constraint _, _) as qi -> Error qi - in + let is_qi_id = function QI_aux (QI_id _, _) as qi -> Ok qi | QI_aux (QI_constraint _, _) as qi -> Error qi in let kopts, constrs = Util.map_split is_qi_id quant_items in let kopt_chunks = Queue.create () in Util.iter_last (chunk_quant_item comments kopt_chunks) kopts; - let constr_chunks_opt = match constrs with + let constr_chunks_opt = + match constrs with | [] -> None | _ -> - let constr_chunks = Queue.create () in - Util.iter_last (chunk_quant_item comments constr_chunks) constrs; - Some constr_chunks + let constr_chunks = Queue.create () in + Util.iter_last (chunk_quant_item comments constr_chunks) constrs; + Some constr_chunks in - Typ_quant { - vars = kopt_chunks; - constr_opt = constr_chunks_opt; - } - |> add_chunk chunks + Typ_quant { vars = kopt_chunks; constr_opt = constr_chunks_opt } |> add_chunk chunks let chunk_tannot_opt comments (Typ_annot_opt_aux (aux, l)) = match aux with - | Typ_annot_opt_none -> None, None + | Typ_annot_opt_none -> (None, None) | Typ_annot_opt_some (TypQ_aux (TypQ_no_forall, _), typ) -> - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks typ; - None, Some typ_chunks + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks typ; + (None, Some typ_chunks) | Typ_annot_opt_some (TypQ_aux (TypQ_tq quant_items, _), typ) -> - let typq_chunks = Queue.create () in - chunk_quant_items l comments typq_chunks quant_items; - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks typ; - Some typq_chunks, Some typ_chunks + let typq_chunks = Queue.create () in + chunk_quant_items l comments typq_chunks quant_items; + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks typ; + (Some typq_chunks, Some typ_chunks) let chunk_default_typing_spec comments chunks (DT_aux (DT_order (kind, typ), l)) = pop_comments comments chunks l; @@ -1097,43 +1048,29 @@ let chunk_default_typing_spec comments chunks (DT_aux (DT_order (kind, typ), l)) let chunk_fundef comments chunks (FD_aux (FD_function (rec_opt, tannot_opt, _, funcls), l)) = pop_comments comments chunks l; - let fn_id = match funcls with + let fn_id = + match funcls with | FCL_aux (FCL_funcl (id, _), _) :: _ -> id | _ -> Reporting.unreachable l __POS__ "Empty funcl list in formatter" in let typq_opt, return_typ_opt = chunk_tannot_opt comments tannot_opt in let funcls = List.map (chunk_funcl comments) funcls in - (Function { - id = fn_id; - clause = false; - rec_opt = None; - typq_opt = typq_opt; - return_typ_opt = return_typ_opt; - funcls = funcls; - }) |> add_chunk chunks + Function { id = fn_id; clause = false; rec_opt = None; typq_opt; return_typ_opt; funcls } |> add_chunk chunks let chunk_val_spec comments chunks (VS_aux (VS_val_spec (typschm, id, extern_opt, is_cast), l)) = pop_comments comments chunks l; - let typq_chunks_opt, typ = match typschm with - | TypSchm_aux (TypSchm_ts (TypQ_aux (TypQ_no_forall, _), typ), _) -> - None, typ + let typq_chunks_opt, typ = + match typschm with + | TypSchm_aux (TypSchm_ts (TypQ_aux (TypQ_no_forall, _), typ), _) -> (None, typ) | TypSchm_aux (TypSchm_ts (TypQ_aux (TypQ_tq quant_items, _), typ), l) -> - let typq_chunks = Queue.create () in - chunk_quant_items l comments typq_chunks quant_items; - Some typq_chunks, typ + let typq_chunks = Queue.create () in + chunk_quant_items l comments typq_chunks quant_items; + (Some typq_chunks, typ) in let typ_chunks = Queue.create () in chunk_atyp comments typ_chunks typ; - add_chunk chunks (Val { - is_cast = is_cast; - id = id; - extern_opt = extern_opt; - typq_opt = typq_chunks_opt; - typ = typ_chunks; - }); - if not (pop_trailing_comment ~space:1 comments chunks (ending_line_num l)) then ( - Queue.push (Spacer (true, 1)) chunks - ) + add_chunk chunks (Val { is_cast; id; extern_opt; typq_opt = typq_chunks_opt; typ = typ_chunks }); + if not (pop_trailing_comment ~space:1 comments chunks (ending_line_num l)) then Queue.push (Spacer (true, 1)) chunks let chunk_register comments chunks (DEC_aux (DEC_reg ((ATyp_aux (_, typ_l) as typ), id, opt_exp), l)) = pop_comments comments chunks l; @@ -1147,19 +1084,19 @@ let chunk_register comments chunks (DEC_aux (DEC_reg ((ATyp_aux (_, typ_l) as ty let typ_chunks = Queue.create () in chunk_atyp comments typ_chunks typ; - let skip_spacer = match opt_exp with + let skip_spacer = + match opt_exp with | Some (E_aux (_, exp_l) as exp) -> - let exp_chunks = Queue.create () in - chunk_exp comments exp_chunks exp; - Queue.push (Ternary (id_chunks, ":", typ_chunks, "=", exp_chunks)) def_chunks; - pop_trailing_comment ~space:1 comments exp_chunks (ending_line_num exp_l) - | None -> - Queue.push (Binary (id_chunks, ":", typ_chunks)) def_chunks; - pop_trailing_comment ~space:1 comments typ_chunks (ending_line_num typ_l) in + let exp_chunks = Queue.create () in + chunk_exp comments exp_chunks exp; + Queue.push (Ternary (id_chunks, ":", typ_chunks, "=", exp_chunks)) def_chunks; + pop_trailing_comment ~space:1 comments exp_chunks (ending_line_num exp_l) + | None -> + Queue.push (Binary (id_chunks, ":", typ_chunks)) def_chunks; + pop_trailing_comment ~space:1 comments typ_chunks (ending_line_num typ_l) + in Queue.push (Chunks def_chunks) chunks; - if not skip_spacer then ( - Queue.push (Spacer (true, 1)) chunks - ) + if not skip_spacer then Queue.push (Spacer (true, 1)) chunks let chunk_toplevel_let l comments chunks (LB_aux (LB_val (pat, exp), _)) = pop_comments comments chunks l; @@ -1169,22 +1106,22 @@ let chunk_toplevel_let l comments chunks (LB_aux (LB_val (pat, exp), _)) = let pat_chunks = Queue.create () in let exp_chunks = Queue.create () in - begin match pat with - | P_aux (P_typ (typ, pat), pat_l) -> - chunk_pat comments pat_chunks pat; - let typ_chunks = Queue.create () in - chunk_atyp comments typ_chunks typ; - chunk_exp comments exp_chunks exp; - Queue.push (Ternary (pat_chunks, ":", typ_chunks, "=", exp_chunks)) def_chunks - | _ -> - chunk_pat comments pat_chunks pat; - chunk_exp comments exp_chunks exp; - Queue.push (Binary (pat_chunks, "=", exp_chunks)) def_chunks + begin + match pat with + | P_aux (P_typ (typ, pat), pat_l) -> + chunk_pat comments pat_chunks pat; + let typ_chunks = Queue.create () in + chunk_atyp comments typ_chunks typ; + chunk_exp comments exp_chunks exp; + Queue.push (Ternary (pat_chunks, ":", typ_chunks, "=", exp_chunks)) def_chunks + | _ -> + chunk_pat comments pat_chunks pat; + chunk_exp comments exp_chunks exp; + Queue.push (Binary (pat_chunks, "=", exp_chunks)) def_chunks end; Queue.push (Chunks def_chunks) chunks; - if not (pop_trailing_comment ~space:1 comments exp_chunks (ending_line_num l)) then ( + if not (pop_trailing_comment ~space:1 comments exp_chunks (ending_line_num l)) then Queue.push (Spacer (true, 1)) chunks - ) let chunk_keyword k chunks = Queue.push (Atom k) chunks; @@ -1194,29 +1131,27 @@ let chunk_id id comments chunks = pop_comments comments chunks (id_loc id); Queue.push (Atom (string_of_id id)) chunks; Queue.push (Spacer (false, 1)) chunks - + let finish_def def_chunks chunks = Queue.push (Chunks def_chunks) chunks; Queue.push (Spacer (true, 1)) chunks let build_def chunks fs = let def_chunks = Queue.create () in - List.iter (fun f -> - f def_chunks - ) fs; + List.iter (fun f -> f def_chunks) fs; finish_def def_chunks chunks let chunk_type_def comments chunks (TD_aux (aux, l)) = pop_comments comments chunks l; let chunk_enum_member comments chunks member = match member with - | (id, None) -> - pop_comments comments chunks (id_loc id); - Queue.push (Atom (string_of_id id)) chunks - | (id, Some exp) -> - chunk_id id comments chunks; - chunk_keyword "=>" chunks; - chunk_exp comments chunks exp + | id, None -> + pop_comments comments chunks (id_loc id); + Queue.push (Atom (string_of_id id)) chunks + | id, Some exp -> + chunk_id id comments chunks; + chunk_keyword "=>" chunks; + chunk_exp comments chunks exp in let chunk_enum_function comments chunks (id, typ) = chunk_id id comments chunks; @@ -1225,122 +1160,83 @@ let chunk_type_def comments chunks (TD_aux (aux, l)) = in match aux with | TD_enum (id, [], members, _) -> - let members = chunk_delimit ~delim:"," ~get_loc:(fun x -> id_loc (fst x)) ~chunk:chunk_enum_member comments chunks members in - Queue.add (Enum { - id = id; - enum_functions = None; - members = members - }) chunks; - Queue.add (Spacer (true, 1)) chunks + let members = + chunk_delimit ~delim:"," ~get_loc:(fun x -> id_loc (fst x)) ~chunk:chunk_enum_member comments chunks members + in + Queue.add (Enum { id; enum_functions = None; members }) chunks; + Queue.add (Spacer (true, 1)) chunks | TD_enum (id, enum_functions, members, _) -> - let enum_functions = chunk_delimit ~delim:"," ~get_loc:(fun x -> id_loc (fst x)) ~chunk:chunk_enum_function comments chunks enum_functions in - let members = chunk_delimit ~delim:"," ~get_loc:(fun x -> id_loc (fst x)) ~chunk:chunk_enum_member comments chunks members in - Queue.add (Enum { - id = id; - enum_functions = Some enum_functions; - members = members - }) chunks; - Queue.add (Spacer (true, 1)) chunks - | _ -> - Reporting.unreachable l __POS__ "unhandled type def" - + let enum_functions = + chunk_delimit ~delim:"," + ~get_loc:(fun x -> id_loc (fst x)) + ~chunk:chunk_enum_function comments chunks enum_functions + in + let members = + chunk_delimit ~delim:"," ~get_loc:(fun x -> id_loc (fst x)) ~chunk:chunk_enum_member comments chunks members + in + Queue.add (Enum { id; enum_functions = Some enum_functions; members }) chunks; + Queue.add (Spacer (true, 1)) chunks + | _ -> Reporting.unreachable l __POS__ "unhandled type def" + let chunk_scattered comments chunks (SD_aux (aux, l)) = pop_comments comments chunks l; match aux with | SD_funcl (FCL_aux (FCL_funcl (id, _), _) as funcl) -> - let funcl_chunks = chunk_funcl comments funcl in - Queue.push (Function { - id = id; - clause = true; - rec_opt = None; - typq_opt = None; - return_typ_opt = None; - funcls = [funcl_chunks] - }) chunks - | SD_end id -> - build_def chunks [ - chunk_keyword "end"; - chunk_id id comments - ] - | SD_function (_, _, _, id) -> - build_def chunks [ - chunk_keyword "scattered function"; - chunk_id id comments - ] - | _ -> - Reporting.unreachable l __POS__ "unhandled scattered def" - -let def_spacer (_, e) (s, _) = - match e, s with - | Some l_e, Some l_s -> - if l_s > l_e + 1 then 1 else 0 - | _, _ -> 1 + let funcl_chunks = chunk_funcl comments funcl in + Queue.push + (Function { id; clause = true; rec_opt = None; typq_opt = None; return_typ_opt = None; funcls = [funcl_chunks] }) + chunks + | SD_end id -> build_def chunks [chunk_keyword "end"; chunk_id id comments] + | SD_function (_, _, _, id) -> build_def chunks [chunk_keyword "scattered function"; chunk_id id comments] + | _ -> Reporting.unreachable l __POS__ "unhandled scattered def" + +let def_spacer (_, e) (s, _) = match (e, s) with Some l_e, Some l_s -> if l_s > l_e + 1 then 1 else 0 | _, _ -> 1 let read_source (p1 : Lexing.position) (p2 : Lexing.position) source = String.sub source p1.pos_cnum (p2.pos_cnum - p1.pos_cnum) -let can_handle_td (TD_aux (aux, _)) = - match aux with - | TD_enum _ -> true - | _ -> false +let can_handle_td (TD_aux (aux, _)) = match aux with TD_enum _ -> true | _ -> false -let can_handle_sd (SD_aux (aux, _)) = - match aux with - | (SD_funcl _ | SD_end _ | SD_function _) -> true - | _ -> false +let can_handle_sd (SD_aux (aux, _)) = match aux with SD_funcl _ | SD_end _ | SD_function _ -> true | _ -> false let chunk_def source last_line_span comments chunks (DEF_aux (def, l)) = - let line_span = starting_line_num l, ending_line_num l in + let line_span = (starting_line_num l, ending_line_num l) in let spacing = def_spacer last_line_span line_span in - if spacing > 0 then ( - Queue.add (Spacer (true, spacing)) chunks - ); + if spacing > 0 then Queue.add (Spacer (true, spacing)) chunks; let pragma_span = ref false in - begin match def with - | DEF_fundef fdef -> - chunk_fundef comments chunks fdef - | DEF_pragma (pragma, arg) -> - Queue.add (Pragma (pragma, arg)) chunks; - pragma_span := true - | DEF_default dts -> - chunk_default_typing_spec comments chunks dts - | DEF_fixity (prec, n, id) -> - pop_comments comments chunks (id_loc id); - let string_of_prec = function - | Infix -> "infix" - | InfixL -> "infixl" - | InfixR -> "infixr" in - Queue.add (Atom (Printf.sprintf "%s %s %s" (string_of_prec prec) (Big_int.to_string n) (string_of_id id))) chunks; - Queue.add (Spacer (true, 1)) chunks - | DEF_register reg -> - chunk_register comments chunks reg - | DEF_let lb -> - chunk_toplevel_let l comments chunks lb - | DEF_val vs -> - chunk_val_spec comments chunks vs - | DEF_scattered sd when can_handle_sd sd -> - chunk_scattered comments chunks sd - | DEF_type td when can_handle_td td -> - chunk_type_def comments chunks td - | _ -> - begin match Reporting.simp_loc l with - | Some (p1, p2) -> - pop_comments comments chunks l; - (* These comments are within the source we are about to include *) - discard_comments comments p2; - let source = read_source p1 p2 source in - Queue.add (Raw source) chunks; + begin + match def with + | DEF_fundef fdef -> chunk_fundef comments chunks fdef + | DEF_pragma (pragma, arg) -> + Queue.add (Pragma (pragma, arg)) chunks; + pragma_span := true + | DEF_default dts -> chunk_default_typing_spec comments chunks dts + | DEF_fixity (prec, n, id) -> + pop_comments comments chunks (id_loc id); + let string_of_prec = function Infix -> "infix" | InfixL -> "infixl" | InfixR -> "infixr" in + Queue.add + (Atom (Printf.sprintf "%s %s %s" (string_of_prec prec) (Big_int.to_string n) (string_of_id id))) + chunks; Queue.add (Spacer (true, 1)) chunks - | None -> - Reporting.unreachable l __POS__ "Could not format" - end + | DEF_register reg -> chunk_register comments chunks reg + | DEF_let lb -> chunk_toplevel_let l comments chunks lb + | DEF_val vs -> chunk_val_spec comments chunks vs + | DEF_scattered sd when can_handle_sd sd -> chunk_scattered comments chunks sd + | DEF_type td when can_handle_td td -> chunk_type_def comments chunks td + | _ -> begin + match Reporting.simp_loc l with + | Some (p1, p2) -> + pop_comments comments chunks l; + (* These comments are within the source we are about to include *) + discard_comments comments p2; + let source = read_source p1 p2 source in + Queue.add (Raw source) chunks; + Queue.add (Spacer (true, 1)) chunks + | None -> Reporting.unreachable l __POS__ "Could not format" + end end; (* Adjust the line span of a pragma to a single line so the spacing works out *) - if not !pragma_span then ( - line_span - ) else ( - fst line_span, fst line_span - ) + if not !pragma_span then line_span else (fst line_span, fst line_span) let chunk_defs source comments defs = let comments = Stack.of_seq (List.to_seq comments) in diff --git a/src/lib/chunk_ast.mli b/src/lib/chunk_ast.mli index e93c9df57..c3a147381 100644 --- a/src/lib/chunk_ast.mli +++ b/src/lib/chunk_ast.mli @@ -76,10 +76,7 @@ type binder = Var_binder | Let_binder | Internal_plet_binder val binder_keyword : binder -> string -type if_format = { - then_brace : bool; - else_brace : bool - } +type if_format = { then_brace : bool; else_brace : bool } type match_kind = Try_match | Match_match @@ -96,7 +93,7 @@ type chunk = rec_opt : chunks option; typq_opt : chunks option; return_typ_opt : chunks option; - funcls : pexp_chunks list + funcls : pexp_chunks list; } | Val of { is_cast : bool; @@ -105,25 +102,10 @@ type chunk = typq_opt : chunks option; typ : chunks; } - | Enum of { - id : Parse_ast.id; - enum_functions : chunks list option; - members : chunks list - } - | Function_typ of { - mapping : bool; - lhs : chunks; - rhs : chunks; - } - | Exists of { - vars: chunks; - constr: chunks; - typ: chunks; - } - | Typ_quant of { - vars : chunks; - constr_opt : chunks option; - } + | Enum of { id : Parse_ast.id; enum_functions : chunks list option; members : chunks list } + | Function_typ of { mapping : bool; lhs : chunks; rhs : chunks } + | Exists of { vars : chunks; constr : chunks; typ : chunks } + | Typ_quant of { vars : chunks; constr_opt : chunks option } | App of Parse_ast.id * chunks list | Field of chunks * Parse_ast.id | Tuple of string * string * int * chunks list @@ -143,38 +125,23 @@ type chunk = | If_then of bool * chunks * chunks | If_then_else of if_format * chunks * chunks * chunks | Struct_update of chunks * chunks list - | Match of { - kind : match_kind; - exp : chunks; - aligned : bool; - cases : pexp_chunks list - } + | Match of { kind : match_kind; exp : chunks; aligned : bool; cases : pexp_chunks list } | Foreach of { var : chunks; decreasing : bool; from_index : chunks; to_index : chunks; step : chunks option; - body : chunks - } - | While of { - repeat_until : bool; - termination_measure : chunks option; - cond : chunks; - body : chunks + body : chunks; } + | While of { repeat_until : bool; termination_measure : chunks option; cond : chunks; body : chunks } | Vector_updates of chunks * chunk list | Chunks of chunks | Raw of string and chunks = chunk Queue.t -and pexp_chunks = { - funcl_space : bool; - pat : chunks; - guard : chunks option; - body : chunks - } +and pexp_chunks = { funcl_space : bool; pat : chunks; guard : chunks option; body : chunks } val prerr_chunk : string -> chunk -> unit diff --git a/src/lib/constant_fold.ml b/src/lib/constant_fold.ml index 653b1e5a3..2bc0c99a3 100644 --- a/src/lib/constant_fold.ml +++ b/src/lib/constant_fold.ml @@ -70,18 +70,17 @@ open Ast_util open Type_check open Rewriter -module StringMap = Map.Make(String);; +module StringMap = Map.Make (String) (* Flag controls whether any constant folding will occur. false = no folding, true = perform constant folding. *) let optimize_constant_fold = ref false -let rec fexp_of_ctor (field, value) = - FE_aux (FE_fexp (mk_id field, exp_of_value value), no_annot) +let rec fexp_of_ctor (field, value) = FE_aux (FE_fexp (mk_id field, exp_of_value value), no_annot) (* The interpreter will return a value for each folded expression, so we must convert that back to expression to re-insert it in the AST - *) +*) and exp_of_value = let open Value in function @@ -91,15 +90,11 @@ and exp_of_value = | V_bool true -> mk_lit_exp L_true | V_bool false -> mk_lit_exp L_false | V_string str -> mk_lit_exp (L_string str) - | V_record ctors -> - mk_exp (E_struct (List.map fexp_of_ctor (StringMap.bindings ctors))) - | V_vector vs -> - mk_exp (E_vector (List.map exp_of_value vs)) - | V_tuple vs -> - mk_exp (E_tuple (List.map exp_of_value vs)) + | V_record ctors -> mk_exp (E_struct (List.map fexp_of_ctor (StringMap.bindings ctors))) + | V_vector vs -> mk_exp (E_vector (List.map exp_of_value vs)) + | V_tuple vs -> mk_exp (E_tuple (List.map exp_of_value vs)) | V_unit -> mk_lit_exp L_unit - | V_attempted_read str -> - mk_exp (E_id (mk_id str)) + | V_attempted_read str -> mk_exp (E_id (mk_id str)) | _ -> failwith "No expression for value" (* We want to avoid evaluating things like print statements at compile @@ -109,7 +104,8 @@ let safe_primops = List.fold_left (fun m k -> StringMap.remove k m) !Value.primops - [ "print_endline"; + [ + "print_endline"; "prerr_endline"; "putchar"; "print"; @@ -125,7 +121,7 @@ let safe_primops = "write_ram"; "get_time_ns"; "Elf_loader.elf_entry"; - "Elf_loader.elf_tohost" + "Elf_loader.elf_tohost"; ] (** We can specify a list of identifiers that we want to remove from @@ -147,11 +143,7 @@ let safe_primops = let opt_fold_to_unit = ref [] let fold_to_unit id = - let remove = - !opt_fold_to_unit - |> List.map mk_id - |> List.fold_left (fun m id -> IdSet.add id m) IdSet.empty - in + let remove = !opt_fold_to_unit |> List.map mk_id |> List.fold_left (fun m id -> IdSet.add id m) IdSet.empty in IdSet.mem id remove let rec is_constant (E_aux (e_aux, _) as exp) = @@ -161,10 +153,9 @@ let rec is_constant (E_aux (e_aux, _) as exp) = | E_struct fexps -> List.for_all is_constant_fexp fexps | E_typ (_, exp) -> is_constant exp | E_tuple exps -> List.for_all is_constant exps - | E_id id -> - (match Env.lookup_id id (env_of exp) with - | Enum _ -> true - | _ -> false) + | E_id id -> ( + match Env.lookup_id id (env_of exp) with Enum _ -> true | _ -> false + ) | _ -> false and is_constant_fexp (FE_aux (FE_fexp (_, exp), _)) = is_constant exp @@ -174,19 +165,16 @@ let rec run frame = match frame with | Interpreter.Done (state, v) -> v | Interpreter.Fail _ -> - (* something went wrong, raise exception to abort constant folding *) - assert false - | Interpreter.Step (lazy_str, _, _, _) -> - run (Interpreter.eval_frame frame) - | Interpreter.Break frame -> - run (Interpreter.eval_frame frame) + (* something went wrong, raise exception to abort constant folding *) + assert false + | Interpreter.Step (lazy_str, _, _, _) -> run (Interpreter.eval_frame frame) + | Interpreter.Break frame -> run (Interpreter.eval_frame frame) | Interpreter.Effect_request (out, st, stack, Interpreter.Read_reg (reg, cont)) -> - (* return a dummy value to read_reg requests which we handle above - if an expression finally evals to it, but the interpreter - will fail if it tries to actually use. See value.ml *) - run (cont (Value.V_attempted_read reg) st) - | Interpreter.Effect_request _ -> - assert false (* effectful, raise exception to abort constant folding *) + (* return a dummy value to read_reg requests which we handle above + if an expression finally evals to it, but the interpreter + will fail if it tries to actually use. See value.ml *) + run (cont (Value.V_attempted_read reg) st) + | Interpreter.Effect_request _ -> assert false (* effectful, raise exception to abort constant folding *) (** This rewriting pass looks for function applications (E_app) expressions where every argument is a literal. It passes these @@ -205,18 +193,11 @@ let rec run frame = - Throws an exception that isn't caught. *) -let initial_state ast env = - Interpreter.initial_state ~registers:false ast env safe_primops +let initial_state ast env = Interpreter.initial_state ~registers:false ast env safe_primops -type fixed = { - registers: tannot exp Bindings.t; - fields: tannot exp Bindings.t Bindings.t; - } +type fixed = { registers : tannot exp Bindings.t; fields : tannot exp Bindings.t Bindings.t } -let no_fixed = { - registers = Bindings.empty; - fields = Bindings.empty; - } +let no_fixed = { registers = Bindings.empty; fields = Bindings.empty } let rw_exp fixed target ok not_ok istate = let evaluate e_aux annot = @@ -225,87 +206,85 @@ let rw_exp fixed target ok not_ok istate = begin let v = run (Interpreter.Step (lazy "", istate, initial_monad, [])) in let exp = exp_of_value v in - try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with - | Type_error (env, l, err) -> - (* A type error here would be unexpected, so don't ignore it! *) - Reporting.warn "" l - ("Type error when folding constants in " - ^ string_of_exp (E_aux (e_aux, annot)) - ^ "\n" ^ Type_error.string_of_type_error err); - not_ok (); - E_aux (e_aux, annot) + try + ok (); + Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot) + with Type_error (env, l, err) -> + (* A type error here would be unexpected, so don't ignore it! *) + Reporting.warn "" l + ("Type error when folding constants in " + ^ string_of_exp (E_aux (e_aux, annot)) + ^ "\n" ^ Type_error.string_of_type_error err + ); + not_ok (); + E_aux (e_aux, annot) end with (* Otherwise if anything goes wrong when trying to constant fold, just continue without optimising. *) - | _ -> E_aux (e_aux, annot) + | _ -> + E_aux (e_aux, annot) in let rw_funcall e_aux annot = match e_aux with | E_app (id, args) when fold_to_unit id -> - ok (); E_aux (E_lit (L_aux (L_unit, fst annot)), annot) - - | E_id id -> - begin match Bindings.find_opt id fixed.registers with - | Some exp -> - ok (); exp - | None -> - E_aux (e_aux, annot) - end - - | E_field (E_aux (E_id id, _), field) -> - begin match Bindings.find_opt id fixed.fields with - | Some fields -> - begin match Bindings.find_opt field fields with - | Some exp -> - ok (); exp - | None -> - E_aux (e_aux, annot) + ok (); + E_aux (E_lit (L_aux (L_unit, fst annot)), annot) + | E_id id -> begin + match Bindings.find_opt id fixed.registers with + | Some exp -> + ok (); + exp + | None -> E_aux (e_aux, annot) + end + | E_field (E_aux (E_id id, _), field) -> begin + match Bindings.find_opt id fixed.fields with + | Some fields -> begin + match Bindings.find_opt field fields with + | Some exp -> + ok (); + exp + | None -> E_aux (e_aux, annot) end - | None -> - E_aux (e_aux, annot) - end - + | None -> E_aux (e_aux, annot) + end (* Short-circuit boolean operators with constants *) | E_app (id, [(E_aux (E_lit (L_aux (L_false, _)), _) as false_exp); _]) when string_of_id id = "and_bool" -> - ok (); false_exp - + ok (); + false_exp | E_app (id, [(E_aux (E_lit (L_aux (L_true, _)), _) as true_exp); _]) when string_of_id id = "or_bool" -> - ok (); true_exp - + ok (); + true_exp | E_app (id, args) when List.for_all is_constant args -> - let env = env_of_annot annot in - (* We want to fold all primitive operations, but avoid folding - non-primitives that are defined in target-specific way. *) - let is_primop = - Env.is_extern id env "interpreter" && StringMap.mem (Env.get_extern id env "interpreter") safe_primops - in - if not (Env.is_extern id env target) || is_primop then - evaluate e_aux annot - else - E_aux (e_aux, annot) - - | E_typ (typ, (E_aux (E_lit _, _) as lit)) -> ok (); lit - - | E_field (exp, id) when is_constant exp -> - evaluate e_aux annot - - | E_if (E_aux (E_lit (L_aux (L_true, _)), _), then_exp, _) -> ok (); then_exp - | E_if (E_aux (E_lit (L_aux (L_false, _)), _), _, else_exp) -> ok (); else_exp - + let env = env_of_annot annot in + (* We want to fold all primitive operations, but avoid folding + non-primitives that are defined in target-specific way. *) + let is_primop = + Env.is_extern id env "interpreter" && StringMap.mem (Env.get_extern id env "interpreter") safe_primops + in + if (not (Env.is_extern id env target)) || is_primop then evaluate e_aux annot else E_aux (e_aux, annot) + | E_typ (typ, (E_aux (E_lit _, _) as lit)) -> + ok (); + lit + | E_field (exp, id) when is_constant exp -> evaluate e_aux annot + | E_if (E_aux (E_lit (L_aux (L_true, _)), _), then_exp, _) -> + ok (); + then_exp + | E_if (E_aux (E_lit (L_aux (L_false, _)), _), _, else_exp) -> + ok (); + else_exp (* We only propagate lets in the simple case where we know that the id will have the inferred type of the argument. For more complex let bindings trying to propagate them may result in type errors due to how type variables are bound by let bindings - *) + *) | E_let (LB_aux (LB_val (P_aux (P_id id, _), bind), _), exp) when is_constant bind -> - ok (); - subst id bind exp - + ok (); + subst id bind exp | _ -> E_aux (e_aux, annot) in - fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)} + fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot) } let rewrite_exp_once target = rw_exp no_fixed target (fun _ -> ()) (fun _ -> ()) @@ -315,86 +294,96 @@ let rec rewrite_constant_function_calls' fixed target ast = let not_ok () = decr rewrite_count in let istate = initial_state ast Type_check.initial_env in - let rw_defs = { - rewriters_base with - rewrite_exp = (fun _ -> rw_exp fixed target ok not_ok istate) - } in + let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> rw_exp fixed target ok not_ok istate) } in let ast = rewrite_ast_base rw_defs ast in (* We keep iterating until we have no more re-writes to do *) - if !rewrite_count > 0 - then rewrite_constant_function_calls' fixed target ast - else ast + if !rewrite_count > 0 then rewrite_constant_function_calls' fixed target ast else ast let rewrite_constant_function_calls fixed target ast = - if !optimize_constant_fold then - rewrite_constant_function_calls' fixed target ast - else - ast + if !optimize_constant_fold then rewrite_constant_function_calls' fixed target ast else ast -type to_constant = - | Register of id * typ * tannot exp - | Register_field of id * id * typ * tannot exp +type to_constant = Register of id * typ * tannot exp | Register_field of id * id * typ * tannot exp let () = let open Interactive in let open Printf in - let update_fixed fixed = function - | Register (id, _, exp) -> - { fixed with registers = Bindings.add id exp fixed.registers } + | Register (id, _, exp) -> { fixed with registers = Bindings.add id exp fixed.registers } | Register_field (id, field, _, exp) -> - let prev_fields = match Bindings.find_opt id fixed.fields with Some f -> f | None -> Bindings.empty in - let updated_fields = Bindings.add field exp prev_fields in - { fixed with fields = Bindings.add id updated_fields fixed.fields } + let prev_fields = match Bindings.find_opt id fixed.fields with Some f -> f | None -> Bindings.empty in + let updated_fields = Bindings.add field exp prev_fields in + { fixed with fields = Bindings.add id updated_fields fixed.fields } in - ArgString ("target", fun target -> ArgString ("assignments", fun assignments -> Action (fun istate -> - let assignments = Str.split (Str.regexp " +") assignments in - let assignments = - List.map (fun assignment -> - match String.split_on_char '=' assignment with - | [reg; value] -> - begin match String.split_on_char '.' reg with - | [reg; field] -> - let reg = mk_id reg in - let field = mk_id field in - begin match Env.lookup_id reg istate.env with - | Register (Typ_aux (Typ_id rec_id, _)) -> - let (_, fields) = Env.get_record rec_id istate.env in - let typ = match List.find_opt (fun (typ, id) -> Id.compare id field = 0) fields with - | Some (typ, _) -> typ - | None -> failwith (sprintf "Register %s does not have a field %s" (string_of_id reg) (string_of_id field)) - in - let exp = Initial_check.exp_of_string value in - let exp = check_exp istate.env exp typ in - Register_field (reg, field, typ, exp) - | _ -> - failwith (sprintf "Register %s is not defined as a record in the current environment" (string_of_id reg)) - end - | _ -> - let reg = mk_id reg in - begin match Env.lookup_id reg istate.env with - | Register typ -> - let exp = Initial_check.exp_of_string value in - let exp = check_exp istate.env exp typ in - Register (reg, typ, exp) - | _ -> - failwith (sprintf "Register %s is not defined in the current environment" (string_of_id reg)) - end - end - | _ -> failwith (sprintf "Could not parse '%s' as an assignment =" assignment) - ) assignments in - let assignments = List.fold_left update_fixed no_fixed assignments in - - { istate with ast = rewrite_constant_function_calls' assignments target istate.ast}))) - |> register_command - ~name:"fix_registers" - ~help:"Fix the value of specified registers, specified as a \ - list of =. Can also fix a specific \ - register field as .=. Note that \ - this is not used to set registers normally, but instead \ - fixes their value such that the constant folding rewrite \ - (which is subsequently invoked by this command) will \ - replace register reads with the fixed values. Requires a \ - target (c, lem, etc.), as the set of functions that can \ - be constant folded can differ on a per-target basis." + ArgString + ( "target", + fun target -> + ArgString + ( "assignments", + fun assignments -> + Action + (fun istate -> + let assignments = Str.split (Str.regexp " +") assignments in + let assignments = + List.map + (fun assignment -> + match String.split_on_char '=' assignment with + | [reg; value] -> begin + match String.split_on_char '.' reg with + | [reg; field] -> + let reg = mk_id reg in + let field = mk_id field in + begin + match Env.lookup_id reg istate.env with + | Register (Typ_aux (Typ_id rec_id, _)) -> + let _, fields = Env.get_record rec_id istate.env in + let typ = + match List.find_opt (fun (typ, id) -> Id.compare id field = 0) fields with + | Some (typ, _) -> typ + | None -> + failwith + (sprintf "Register %s does not have a field %s" (string_of_id reg) + (string_of_id field) + ) + in + let exp = Initial_check.exp_of_string value in + let exp = check_exp istate.env exp typ in + Register_field (reg, field, typ, exp) + | _ -> + failwith + (sprintf "Register %s is not defined as a record in the current environment" + (string_of_id reg) + ) + end + | _ -> + let reg = mk_id reg in + begin + match Env.lookup_id reg istate.env with + | Register typ -> + let exp = Initial_check.exp_of_string value in + let exp = check_exp istate.env exp typ in + Register (reg, typ, exp) + | _ -> + failwith + (sprintf "Register %s is not defined in the current environment" + (string_of_id reg) + ) + end + end + | _ -> failwith (sprintf "Could not parse '%s' as an assignment =" assignment) + ) + assignments + in + let assignments = List.fold_left update_fixed no_fixed assignments in + + { istate with ast = rewrite_constant_function_calls' assignments target istate.ast } + ) + ) + ) + |> register_command ~name:"fix_registers" + ~help: + "Fix the value of specified registers, specified as a list of =. Can also fix a specific \ + register field as .=. Note that this is not used to set registers normally, but \ + instead fixes their value such that the constant folding rewrite (which is subsequently invoked by this \ + command) will replace register reads with the fixed values. Requires a target (c, lem, etc.), as the set of \ + functions that can be constant folded can differ on a per-target basis." diff --git a/src/lib/constant_propagation.ml b/src/lib/constant_propagation.ml index 3564a1823..e488a78cf 100644 --- a/src/lib/constant_propagation.ml +++ b/src/lib/constant_propagation.ml @@ -81,116 +81,122 @@ open Type_check subexpressions, dropping assignments rather than committing to any particular order *) +let kbindings_from_list = List.fold_left (fun s (v, i) -> KBindings.add v i s) KBindings.empty +let bindings_from_list = List.fold_left (fun s (v, i) -> Bindings.add v i s) Bindings.empty -let kbindings_from_list = List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty -let bindings_from_list = List.fold_left (fun s (v,i) -> Bindings.add v i s) Bindings.empty (* union was introduced in 4.03.0, a bit too recently *) let bindings_union s1 s2 = - Bindings.merge (fun _ x y -> match x,y with - | _, (Some x) -> Some x - | (Some x), _ -> Some x - | _, _ -> None) s1 s2 + Bindings.merge (fun _ x y -> match (x, y) with _, Some x -> Some x | Some x, _ -> Some x | _, _ -> None) s1 s2 let kbindings_union s1 s2 = - KBindings.merge (fun _ x y -> match x,y with - | _, (Some x) -> Some x - | (Some x), _ -> Some x - | _, _ -> None) s1 s2 + KBindings.merge (fun _ x y -> match (x, y) with _, Some x -> Some x | Some x, _ -> Some x | _, _ -> None) s1 s2 -let remove_bound (substs,ksubsts) pat = +let remove_bound (substs, ksubsts) pat = let bound = bindings_from_pat pat in - List.fold_left (fun sub v -> Bindings.remove v sub) substs bound, ksubsts + (List.fold_left (fun sub v -> Bindings.remove v sub) substs bound, ksubsts) -let rec is_value (E_aux (e,(l,annot))) = +let rec is_value (E_aux (e, (l, annot))) = let is_constructor id = match destruct_tannot annot with | None -> - (Reporting.print_err l "Monomorphisation" - ("Missing type information for identifier " ^ string_of_id id); - false) (* Be conservative if we have no info *) - | Some (env, _) -> - Env.is_union_constructor id env || - (match Env.lookup_id id env with - | Enum _ -> true - | Unbound _ | Local _ | Register _ -> false) + Reporting.print_err l "Monomorphisation" ("Missing type information for identifier " ^ string_of_id id); + false (* Be conservative if we have no info *) + | Some (env, _) -> ( + Env.is_union_constructor id env + || match Env.lookup_id id env with Enum _ -> true | Unbound _ | Local _ | Register _ -> false + ) in match e with | E_id id -> is_constructor id | E_lit _ -> true | E_tuple es | E_vector es -> List.for_all is_value es - | E_struct fes -> - List.for_all (fun (FE_aux (FE_fexp (_, e), _)) -> is_value e) fes - | E_app (id,es) -> is_constructor id && List.for_all is_value es + | E_struct fes -> List.for_all (fun (FE_aux (FE_fexp (_, e), _)) -> is_value e) fes + | E_app (id, es) -> is_constructor id && List.for_all is_value es (* We add casts to undefined to keep the type information in the AST *) - | E_typ (typ,E_aux (E_lit (L_aux (L_undef,_)),_)) -> true + | E_typ (typ, E_aux (E_lit (L_aux (L_undef, _)), _)) -> true (* Also keep casts around records, as type inference fails without *) | E_typ (_, (E_aux (E_struct _, _) as e')) -> is_value e' (* TODO: more? *) | _ -> false -let isubst_minus_set subst set = - IdSet.fold Bindings.remove set subst +let isubst_minus_set subst set = IdSet.fold Bindings.remove set subst let threaded_map f state l = - let l',state' = - List.fold_left (fun (tl,state) element -> let (el',state') = f state element in (el'::tl,state')) - ([],state) l - in List.rev l',state' - + let l', state' = + List.fold_left + (fun (tl, state) element -> + let el', state' = f state element in + (el' :: tl, state') + ) + ([], state) l + in + (List.rev l', state') (* Attempt simple pattern matches *) let lit_match = function | (L_zero | L_false), (L_zero | L_false) -> true - | (L_one | L_true ), (L_one | L_true ) -> true + | (L_one | L_true), (L_one | L_true) -> true | L_num i1, L_num i2 -> Big_int.equal i1 i2 - | l1,l2 -> l1 = l2 + | l1, l2 -> l1 = l2 (* There's no undefined nexp, so replace undefined sizes with a plausible size. 32 is used as a sensible default. *) let fabricate_nexp_exist env l typ kids nc typ' = - match kids,nc,Env.expand_synonyms env typ' with - | ([kid],NC_aux (NC_set (kid',i::_),_), - Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid'',_)),_)]),_)) - when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 -> - Nexp_aux (Nexp_constant i,Unknown) - | ([kid],NC_aux (NC_true,_), - Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid'',_)),_)]),_)) - when Kid.compare kid kid'' = 0 -> - nint 32 - | ([kid],NC_aux (NC_set (kid',i::_),_), - Typ_aux (Typ_app (Id_aux (Id "range",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid'',_)),_); - A_aux (A_nexp (Nexp_aux (Nexp_var kid''',_)),_)]),_)) - when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 && - Kid.compare kid kid''' = 0 -> - Nexp_aux (Nexp_constant i,Unknown) - | ([kid],NC_aux (NC_true,_), - Typ_aux (Typ_app (Id_aux (Id "range",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid'',_)),_); - A_aux (A_nexp (Nexp_aux (Nexp_var kid''',_)),_)]),_)) - when Kid.compare kid kid'' = 0 && - Kid.compare kid kid''' = 0 -> - nint 32 - | ([], _, typ) -> nint 32 - | (kids, nc, typ) -> - raise (Reporting.err_general l - ("Undefined value at unsupported type " ^ string_of_typ typ ^ " with " ^ Util.string_of_list ", " string_of_kid kids)) + match (kids, nc, Env.expand_synonyms env typ') with + | ( [kid], + NC_aux (NC_set (kid', i :: _), _), + Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _)]), _) ) + when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 -> + Nexp_aux (Nexp_constant i, Unknown) + | ( [kid], + NC_aux (NC_true, _), + Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _)]), _) ) + when Kid.compare kid kid'' = 0 -> + nint 32 + | ( [kid], + NC_aux (NC_set (kid', i :: _), _), + Typ_aux + ( Typ_app + ( Id_aux (Id "range", _), + [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _); A_aux (A_nexp (Nexp_aux (Nexp_var kid''', _)), _)] + ), + _ + ) ) + when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 && Kid.compare kid kid''' = 0 -> + Nexp_aux (Nexp_constant i, Unknown) + | ( [kid], + NC_aux (NC_true, _), + Typ_aux + ( Typ_app + ( Id_aux (Id "range", _), + [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _); A_aux (A_nexp (Nexp_aux (Nexp_var kid''', _)), _)] + ), + _ + ) ) + when Kid.compare kid kid'' = 0 && Kid.compare kid kid''' = 0 -> + nint 32 + | [], _, typ -> nint 32 + | kids, nc, typ -> + raise + (Reporting.err_general l + ("Undefined value at unsupported type " ^ string_of_typ typ ^ " with " + ^ Util.string_of_list ", " string_of_kid kids + ) + ) let fabricate_nexp l tannot = match destruct_tannot tannot with | None -> nint 32 - | Some (env,typ) -> - match Type_check.destruct_exist (Type_check.Env.expand_synonyms env typ) with - | None -> nint 32 - (* TODO: check this *) - | Some (kopts,nc,typ') -> fabricate_nexp_exist env l typ (List.map kopt_kid kopts) nc typ' + | Some (env, typ) -> ( + match Type_check.destruct_exist (Type_check.Env.expand_synonyms env typ) with + | None -> nint 32 + (* TODO: check this *) + | Some (kopts, nc, typ') -> fabricate_nexp_exist env l typ (List.map kopt_kid kopts) nc typ' + ) let atom_typ_kid kid = function - | Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_) -> - Kid.compare kid kid' = 0 + | Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _)]), _) -> + Kid.compare kid kid' = 0 | _ -> false (* We reduce casts in a few cases, in particular to ensure that where the @@ -198,49 +204,46 @@ let atom_typ_kid kid = function fill in the 'n. For undefined we fabricate a suitable value for 'n. *) let reduce_cast typ exp l annot = - let env = env_of_annot (l,annot) in + let env = env_of_annot (l, annot) in let typ' = Env.base_typ_of env typ in - match exp, destruct_exist (Env.expand_synonyms env typ') with - | E_aux (E_lit (L_aux (L_num n,_)),_), Some ([kopt],nc,typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> - let nc_env = Env.add_typ_var l kopt env in - let nc_env = Env.add_constraint (nc_eq (nvar (kopt_kid kopt)) (nconstant n)) nc_env in - if prove __POS__ nc_env nc - then exp - else raise (Reporting.err_unreachable l __POS__ - ("Constant propagation error: literal " ^ Big_int.to_string n ^ - " does not satisfy constraint " ^ string_of_n_constraint nc)) - | E_aux (E_lit (L_aux (L_undef,_)),_), Some ([kopt],nc,typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> - let nexp = fabricate_nexp_exist env Unknown typ [kopt_kid kopt] nc typ'' in - let newtyp = subst_kids_typ (KBindings.singleton (kopt_kid kopt) nexp) typ'' in - E_aux (E_typ (newtyp, exp), (Generated l,replace_typ newtyp annot)) - | E_aux (E_typ (_, - (E_aux (E_lit (L_aux (L_undef,_)),_) as exp)),_), - Some ([kopt],nc,typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> - let nexp = fabricate_nexp_exist env Unknown typ [kopt_kid kopt] nc typ'' in - let newtyp = subst_kids_typ (KBindings.singleton (kopt_kid kopt) nexp) typ'' in - E_aux (E_typ (newtyp, exp), (Generated l,replace_typ newtyp annot)) - | _ -> E_aux (E_typ (typ,exp),(l,annot)) + match (exp, destruct_exist (Env.expand_synonyms env typ')) with + | E_aux (E_lit (L_aux (L_num n, _)), _), Some ([kopt], nc, typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> + let nc_env = Env.add_typ_var l kopt env in + let nc_env = Env.add_constraint (nc_eq (nvar (kopt_kid kopt)) (nconstant n)) nc_env in + if prove __POS__ nc_env nc then exp + else + raise + (Reporting.err_unreachable l __POS__ + ("Constant propagation error: literal " ^ Big_int.to_string n ^ " does not satisfy constraint " + ^ string_of_n_constraint nc + ) + ) + | E_aux (E_lit (L_aux (L_undef, _)), _), Some ([kopt], nc, typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> + let nexp = fabricate_nexp_exist env Unknown typ [kopt_kid kopt] nc typ'' in + let newtyp = subst_kids_typ (KBindings.singleton (kopt_kid kopt) nexp) typ'' in + E_aux (E_typ (newtyp, exp), (Generated l, replace_typ newtyp annot)) + | E_aux (E_typ (_, (E_aux (E_lit (L_aux (L_undef, _)), _) as exp)), _), Some ([kopt], nc, typ'') + when atom_typ_kid (kopt_kid kopt) typ'' -> + let nexp = fabricate_nexp_exist env Unknown typ [kopt_kid kopt] nc typ'' in + let newtyp = subst_kids_typ (KBindings.singleton (kopt_kid kopt) nexp) typ'' in + E_aux (E_typ (newtyp, exp), (Generated l, replace_typ newtyp annot)) + | _ -> E_aux (E_typ (typ, exp), (l, annot)) (* Used for constant propagation in pattern matches *) -type 'a matchresult = - | DoesMatch of 'a - | DoesNotMatch - | GiveUp +type 'a matchresult = DoesMatch of 'a | DoesNotMatch | GiveUp (* Remove top-level casts from an expression. Useful when we need to look at subexpressions to reduce something, but could break type-checking if we used it everywhere. *) -let rec drop_casts = function - | E_aux (E_typ (_,e),_) -> drop_casts e - | exp -> exp +let rec drop_casts = function E_aux (E_typ (_, e), _) -> drop_casts e | exp -> exp let construct_lit_vector args = let rec aux l = function - | [] -> Some (L_aux (L_bin (String.concat "" (List.rev l)),Unknown)) - | E_aux (E_lit (L_aux ((L_zero | L_one) as lit,_)),_)::t -> - aux ((if lit = L_zero then "0" else "1")::l) t + | [] -> Some (L_aux (L_bin (String.concat "" (List.rev l)), Unknown)) + | E_aux (E_lit (L_aux (((L_zero | L_one) as lit), _)), _) :: t -> aux ((if lit = L_zero then "0" else "1") :: l) t | _ -> None - in aux [] args + in + aux [] args (* Add a cast to undefined so that it retains its type, otherwise it can't be substituted safely *) @@ -248,31 +251,29 @@ let keep_undef_typ value = let e_aux (e, ann) = match e with | E_lit (L_aux (L_undef, _)) -> - (* Add cast to undefined... *) - E_aux (E_typ (typ_of_annot ann, E_aux (e, ann)), ann) + (* Add cast to undefined... *) + E_aux (E_typ (typ_of_annot ann, E_aux (e, ann)), ann) | E_typ (typ, E_aux (E_typ (_, e), _)) -> - (* ... unless there was a cast already *) - E_aux (E_typ (typ, e), ann) + (* ... unless there was a cast already *) + E_aux (E_typ (typ, e), ann) | _ -> E_aux (e, ann) in let open Rewriter in - fold_exp { id_exp_alg with e_aux = e_aux } value + fold_exp { id_exp_alg with e_aux } value (* Check whether the current environment with the given kid assignments is inconsistent (and hence whether the code is dead) *) let is_env_inconsistent env ksubsts = - let env = KBindings.fold (fun k nexp env -> - Env.add_constraint (nc_eq (nvar k) nexp) env) ksubsts env in + let env = KBindings.fold (fun k nexp env -> Env.add_constraint (nc_eq (nvar k) nexp) env) ksubsts env in prove __POS__ env nc_false -module StringSet = Set.Make(String) -module StringMap = Map.Make(String) +module StringSet = Set.Make (String) +module StringMap = Map.Make (String) -(* This is set up so that a partially applied version can be used multiple +(* This is set up so that a partially applied version can be used multiple times, reducing start up time. *) let const_props target ast = - (* Constant-fold function applications with constant arguments *) let interpreter_istate = (* Do not interpret undefined_X functions *) @@ -280,7 +281,7 @@ let const_props target ast = let undefined_builtin_ids = ids_of_defs Initial_check.undefined_builtin_val_specs in let remove_primop id = StringMap.remove (string_of_id id) in let remove_undefined_primops = IdSet.fold remove_primop undefined_builtin_ids in - let (lstate, gstate) = Constant_fold.initial_state ast Type_check.initial_env in + let lstate, gstate = Constant_fold.initial_state ast Type_check.initial_env in (lstate, { gstate with primops = remove_undefined_primops gstate.primops }) in let const_fold exp = @@ -289,574 +290,557 @@ let const_props target ast = |> infer_exp (env_of exp) |> Constant_fold.rewrite_exp_once target interpreter_istate |> keep_undef_typ - with - | _ -> exp + with _ -> exp in fun ref_vars -> - - let constants = - let add m = function - | DEF_aux (DEF_let (LB_aux (LB_val (P_aux ((P_id id | P_typ (_,P_aux (P_id id,_))),_), exp),_)), _) - when Constant_fold.is_constant exp -> - Bindings.add id exp m - | _ -> m - in - List.fold_left add Bindings.empty ast.defs - in - let replace_constant (E_aux (e,annot) as exp) = - match e with - | E_id id -> - (match Bindings.find_opt id constants with - | Some e -> e - | None -> exp) - | _ -> exp - in - let rec const_prop_exp substs assigns ((E_aux (e,(l,annot))) as exp) = - (* Functions to treat lists and tuples of subexpressions as possibly - non-deterministic: that is, we stop making any assumptions about - variables that are assigned to in any of the subexpressions *) - let non_det_exp_list es = - let assigned_in = - List.fold_left (fun vs exp -> IdSet.union vs (assigned_vars exp)) - IdSet.empty es in - let assigns = isubst_minus_set assigns assigned_in in - let es' = List.map (fun e -> fst (const_prop_exp substs assigns e)) es in - es',assigns - in - let non_det_exp_2 e1 e2 = - let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in - let assigns = isubst_minus_set assigns assigned_in_e12 in - let e1',_ = const_prop_exp substs assigns e1 in - let e2',_ = const_prop_exp substs assigns e2 in - e1',e2',assigns - in - let non_det_exp_3 e1 e2 e3 = - let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in - let assigned_in_e123 = IdSet.union assigned_in_e12 (assigned_vars e3) in - let assigns = isubst_minus_set assigns assigned_in_e123 in - let e1',_ = const_prop_exp substs assigns e1 in - let e2',_ = const_prop_exp substs assigns e2 in - let e3',_ = const_prop_exp substs assigns e3 in - e1',e2',e3',assigns + let constants = + let add m = function + | DEF_aux (DEF_let (LB_aux (LB_val (P_aux ((P_id id | P_typ (_, P_aux (P_id id, _))), _), exp), _)), _) + when Constant_fold.is_constant exp -> + Bindings.add id exp m + | _ -> m + in + List.fold_left add Bindings.empty ast.defs in - let non_det_exp_4 e1 e2 e3 e4 = - let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in - let assigned_in_e123 = IdSet.union assigned_in_e12 (assigned_vars e3) in - let assigned_in_e1234 = IdSet.union assigned_in_e123 (assigned_vars e4) in - let assigns = isubst_minus_set assigns assigned_in_e1234 in - let e1',_ = const_prop_exp substs assigns e1 in - let e2',_ = const_prop_exp substs assigns e2 in - let e3',_ = const_prop_exp substs assigns e3 in - let e4',_ = const_prop_exp substs assigns e4 in - e1',e2',e3',e4',assigns + let replace_constant (E_aux (e, annot) as exp) = + match e with + | E_id id -> ( + match Bindings.find_opt id constants with Some e -> e | None -> exp + ) + | _ -> exp in - let rewrap e = E_aux (e,(l,annot)) in - let re e assigns = rewrap e,assigns in - match e with + let rec const_prop_exp substs assigns (E_aux (e, (l, annot)) as exp) = + (* Functions to treat lists and tuples of subexpressions as possibly + non-deterministic: that is, we stop making any assumptions about + variables that are assigned to in any of the subexpressions *) + let non_det_exp_list es = + let assigned_in = List.fold_left (fun vs exp -> IdSet.union vs (assigned_vars exp)) IdSet.empty es in + let assigns = isubst_minus_set assigns assigned_in in + let es' = List.map (fun e -> fst (const_prop_exp substs assigns e)) es in + (es', assigns) + in + let non_det_exp_2 e1 e2 = + let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in + let assigns = isubst_minus_set assigns assigned_in_e12 in + let e1', _ = const_prop_exp substs assigns e1 in + let e2', _ = const_prop_exp substs assigns e2 in + (e1', e2', assigns) + in + let non_det_exp_3 e1 e2 e3 = + let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in + let assigned_in_e123 = IdSet.union assigned_in_e12 (assigned_vars e3) in + let assigns = isubst_minus_set assigns assigned_in_e123 in + let e1', _ = const_prop_exp substs assigns e1 in + let e2', _ = const_prop_exp substs assigns e2 in + let e3', _ = const_prop_exp substs assigns e3 in + (e1', e2', e3', assigns) + in + let non_det_exp_4 e1 e2 e3 e4 = + let assigned_in_e12 = IdSet.union (assigned_vars e1) (assigned_vars e2) in + let assigned_in_e123 = IdSet.union assigned_in_e12 (assigned_vars e3) in + let assigned_in_e1234 = IdSet.union assigned_in_e123 (assigned_vars e4) in + let assigns = isubst_minus_set assigns assigned_in_e1234 in + let e1', _ = const_prop_exp substs assigns e1 in + let e2', _ = const_prop_exp substs assigns e2 in + let e3', _ = const_prop_exp substs assigns e3 in + let e4', _ = const_prop_exp substs assigns e4 in + (e1', e2', e3', e4', assigns) + in + let rewrap e = E_aux (e, (l, annot)) in + let re e assigns = (rewrap e, assigns) in + match e with (* TODO: are there more circumstances in which we should get rid of these? *) - | E_block [e] -> const_prop_exp substs assigns e - | E_block es -> - let es',assigns = threaded_map (const_prop_exp substs) assigns es in - re (E_block es') assigns - | E_id id -> - let env = Type_check.env_of_annot (l, annot) in - (try - match Env.lookup_id id env with - | Local (Immutable,_) -> Bindings.find id (fst substs) - | Local (Mutable,_) -> Bindings.find id assigns - | _ -> exp - with Not_found -> exp),assigns - | E_lit _ - | E_sizeof _ - | E_constraint _ - -> exp,assigns - | E_typ (t,e') -> - let e'',assigns = const_prop_exp substs assigns e' in - if is_value e'' - then reduce_cast t e'' l annot, assigns - else re (E_typ (t, e'')) assigns - | E_app (id,es) -> - let es',assigns = non_det_exp_list es in - let env = Type_check.env_of_annot (l, annot) in - const_prop_try_fn env (id, es') (l, annot), assigns - | E_tuple es -> - let es',assigns = non_det_exp_list es in - re (E_tuple es') assigns - | E_if (e1,e2,e3) -> - let e1',assigns = const_prop_exp substs assigns e1 in - let e1_no_casts = drop_casts e1' in - (match e1_no_casts with - | E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) -> - (match lit with - | L_true -> const_prop_exp substs assigns e2 - | _ -> const_prop_exp substs assigns e3) - | _ -> - (* If the guard is an equality check, propagate the value. *) - let env1 = env_of e1_no_casts in - let is_equal id = - List.exists (fun id' -> Id.compare id id' == 0) - (Env.get_overloads (Id_aux (Operator "==", Parse_ast.Unknown)) - env1) + | E_block [e] -> const_prop_exp substs assigns e + | E_block es -> + let es', assigns = threaded_map (const_prop_exp substs) assigns es in + re (E_block es') assigns + | E_id id -> + let env = Type_check.env_of_annot (l, annot) in + ( ( try + match Env.lookup_id id env with + | Local (Immutable, _) -> Bindings.find id (fst substs) + | Local (Mutable, _) -> Bindings.find id assigns + | _ -> exp + with Not_found -> exp + ), + assigns + ) + | E_lit _ | E_sizeof _ | E_constraint _ -> (exp, assigns) + | E_typ (t, e') -> + let e'', assigns = const_prop_exp substs assigns e' in + if is_value e'' then (reduce_cast t e'' l annot, assigns) else re (E_typ (t, e'')) assigns + | E_app (id, es) -> + let es', assigns = non_det_exp_list es in + let env = Type_check.env_of_annot (l, annot) in + (const_prop_try_fn env (id, es') (l, annot), assigns) + | E_tuple es -> + let es', assigns = non_det_exp_list es in + re (E_tuple es') assigns + | E_if (e1, e2, e3) -> ( + let e1', assigns = const_prop_exp substs assigns e1 in + let e1_no_casts = drop_casts e1' in + match e1_no_casts with + | E_aux (E_lit (L_aux (((L_true | L_false) as lit), _)), _) -> ( + match lit with L_true -> const_prop_exp substs assigns e2 | _ -> const_prop_exp substs assigns e3 + ) + | _ -> + (* If the guard is an equality check, propagate the value. *) + let env1 = env_of e1_no_casts in + let is_equal id = + List.exists + (fun id' -> Id.compare id id' == 0) + (Env.get_overloads (Id_aux (Operator "==", Parse_ast.Unknown)) env1) + in + let substs_true = + match e1_no_casts with + | (E_aux (E_app (id, [E_aux (E_id var, _); vl]), _) | E_aux (E_app (id, [vl; E_aux (E_id var, _)]), _)) + when is_equal id -> + if is_value vl then ( + match Env.lookup_id var env1 with + | Local (Immutable, _) -> (Bindings.add var vl (fst substs), snd substs) + | _ -> substs + ) + else substs + | _ -> substs + in + (* Discard impossible branches *) + if is_env_inconsistent (env_of e2) (snd substs) then const_prop_exp substs assigns e3 + else if is_env_inconsistent (env_of e3) (snd substs) then const_prop_exp substs_true assigns e2 + else ( + let e2', assigns2 = const_prop_exp substs_true assigns e2 in + let e3', assigns3 = const_prop_exp substs assigns e3 in + (* If one branch is a throw, use the assignments from the other *) + let assigns = + match (e2', e3') with + | E_aux (E_throw _, _), _ -> assigns3 + | _, E_aux (E_throw _, _) -> assigns2 + | _, _ -> + let assigns = isubst_minus_set assigns (assigned_vars e2) in + let assigns = isubst_minus_set assigns (assigned_vars e3) in + assigns + in + re (E_if (e1', e2', e3')) assigns + ) + ) + | E_for (id, e1, e2, e3, ord, e4) -> + (* Treat e1, e2 and e3 (from, to and by) as a non-det tuple *) + let e1', e2', e3', assigns = non_det_exp_3 e1 e2 e3 in + let assigns = isubst_minus_set assigns (assigned_vars e4) in + let e4', _ = const_prop_exp (Bindings.remove id (fst substs), snd substs) assigns e4 in + re (E_for (id, e1', e2', e3', ord, e4')) assigns + | E_loop (loop, m, e1, e2) -> + let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in + let m' = + match m with + | Measure_aux (Measure_none, _) -> m + | Measure_aux (Measure_some exp, l) -> + let exp', _ = const_prop_exp substs assigns exp in + Measure_aux (Measure_some exp', l) in - let substs_true = - match e1_no_casts with - | E_aux (E_app (id, [E_aux (E_id var,_); vl]),_) - | E_aux (E_app (id, [vl; E_aux (E_id var,_)]),_) - when is_equal id -> - if is_value vl then - (match Env.lookup_id var env1 with - | Local (Immutable,_) -> Bindings.add var vl (fst substs),snd substs - | _ -> substs) - else substs - | _ -> substs + let e1', _ = const_prop_exp substs assigns e1 in + let e2', _ = const_prop_exp substs assigns e2 in + re (E_loop (loop, m', e1', e2')) assigns + | E_vector es -> + let es', assigns = non_det_exp_list es in + begin + match construct_lit_vector es' with None -> re (E_vector es') assigns | Some lit -> re (E_lit lit) assigns + end + | E_vector_access (e1, e2) -> + let e1', e2', assigns = non_det_exp_2 e1 e2 in + re (E_vector_access (e1', e2')) assigns + | E_vector_subrange (e1, e2, e3) -> + let e1', e2', e3', assigns = non_det_exp_3 e1 e2 e3 in + re (E_vector_subrange (e1', e2', e3')) assigns + | E_vector_update (e1, e2, e3) -> + let e1', e2', e3', assigns = non_det_exp_3 e1 e2 e3 in + re (E_vector_update (e1', e2', e3')) assigns + | E_vector_update_subrange (e1, e2, e3, e4) -> + let e1', e2', e3', e4', assigns = non_det_exp_4 e1 e2 e3 e4 in + re (E_vector_update_subrange (e1', e2', e3', e4')) assigns + | E_vector_append (e1, e2) -> + let e1', e2', assigns = non_det_exp_2 e1 e2 in + re (E_vector_append (e1', e2')) assigns + | E_list es -> + let es', assigns = non_det_exp_list es in + re (E_list es') assigns + | E_cons (e1, e2) -> + let e1', e2', assigns = non_det_exp_2 e1 e2 in + re (E_cons (e1', e2')) assigns + | E_struct fes -> + let assigned_in_fes = assigned_vars_in_fexps fes in + let assigns = isubst_minus_set assigns assigned_in_fes in + re (E_struct (const_prop_fexps substs assigns fes)) assigns + | E_struct_update (e, fes) -> + let assigned_in = IdSet.union (assigned_vars_in_fexps fes) (assigned_vars e) in + let assigns = isubst_minus_set assigns assigned_in in + let e', _ = const_prop_exp substs assigns e in + let fes' = const_prop_fexps substs assigns fes in + begin + match unaux_exp (fst (uncast_exp e')) with + | E_struct fes0 -> + let apply_fexp (FE_aux (FE_fexp (id, e), _)) (FE_aux (FE_fexp (id', e'), ann)) = + if Id.compare id id' = 0 then FE_aux (FE_fexp (id', e), ann) else FE_aux (FE_fexp (id', e'), ann) + in + let update_fields fexp = List.map (apply_fexp fexp) in + let fes0' = List.fold_right update_fields fes' fes0 in + re (E_struct fes0') assigns + | _ -> re (E_struct_update (e', fes')) assigns + end + | E_field (e, id) -> + let e', assigns = const_prop_exp substs assigns e in + begin + let is_field (FE_aux (FE_fexp (id', _), _)) = Id.compare id id' = 0 in + match unaux_exp e' with + | E_struct fes0 when List.exists is_field fes0 -> + let (FE_aux (FE_fexp (_, e), _)) = List.find is_field fes0 in + re (unaux_exp e) assigns + | _ -> re (E_field (e', id)) assigns + end + | E_match (e, cases) -> ( + let e', assigns = const_prop_exp substs assigns e in + match can_match e' cases substs assigns with + | None -> + let assigned_in = + List.fold_left (fun vs pe -> IdSet.union vs (assigned_vars_in_pexp pe)) IdSet.empty cases + in + let assigns' = isubst_minus_set assigns assigned_in in + re (E_match (e', List.map (const_prop_pexp substs assigns) cases)) assigns' + | Some ((E_aux (_, (_, annot')) as exp), newbindings, kbindings) -> + let exp = nexp_subst_exp (kbindings_from_list kbindings) exp in + let newbindings_env = bindings_from_list newbindings in + let substs' = (bindings_union (fst substs) newbindings_env, snd substs) in + const_prop_exp substs' assigns exp + ) + | E_let (lb, e2) -> begin + match lb with + | LB_aux (LB_val (p, e), annot) -> + let e', assigns = const_prop_exp substs assigns e in + let substs' = remove_bound substs p in + let plain () = + let e2', assigns = const_prop_exp substs' assigns e2 in + re (E_let (LB_aux (LB_val (p, e'), annot), e2')) assigns + in + if is_value e' then ( + match can_match e' [Pat_aux (Pat_exp (p, e2), (Unknown, empty_tannot))] substs assigns with + | None -> plain () + | Some (e'', bindings, kbindings) -> + let e'' = nexp_subst_exp (kbindings_from_list kbindings) e'' in + let bindings = bindings_from_list bindings in + let substs'' = (bindings_union (fst substs') bindings, snd substs') in + const_prop_exp substs'' assigns e'' + ) + else plain () + end + (* TODO maybe - tuple assignments *) + | E_assign (le, e) -> + let env = Type_check.env_of_annot (l, annot) in + let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in + let assigns = isubst_minus_set assigns assigned_in in + let le', idopt = const_prop_lexp substs assigns le in + let e', _ = const_prop_exp substs assigns e in + let assigns = + match idopt with + | Some id -> begin + match Env.lookup_id id env with + | Local (Mutable, _) | Unbound _ -> + if is_value e' && not (IdSet.mem id ref_vars) then Bindings.add id (keep_undef_typ e') assigns + else Bindings.remove id assigns + | _ -> assigns + end + | None -> assigns + in + re (E_assign (le', e')) assigns + | E_var (le, e, e2) -> + let env = Type_check.env_of_annot (l, annot) in + let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in + let assigns = isubst_minus_set assigns assigned_in in + let le', idopt = const_prop_lexp substs assigns le in + let e', _ = const_prop_exp substs assigns e in + let assigns = + match idopt with + | Some id -> begin + match Env.lookup_id id env with + | Local (Mutable, _) | Unbound _ -> + if is_value e' && not (IdSet.mem id ref_vars) then Bindings.add id (keep_undef_typ e') assigns + else Bindings.remove id assigns + | _ -> assigns + end + | None -> assigns in - (* Discard impossible branches *) - if is_env_inconsistent (env_of e2) (snd substs) then - const_prop_exp substs assigns e3 - else if is_env_inconsistent (env_of e3) (snd substs) then - const_prop_exp substs_true assigns e2 - else - let e2',assigns2 = const_prop_exp substs_true assigns e2 in - let e3',assigns3 = const_prop_exp substs assigns e3 in - (* If one branch is a throw, use the assignments from the other *) - let assigns = - match e2', e3' with - | E_aux (E_throw _, _), _ -> assigns3 - | _, E_aux (E_throw _, _) -> assigns2 - | _, _ -> - let assigns = isubst_minus_set assigns (assigned_vars e2) in - let assigns = isubst_minus_set assigns (assigned_vars e3) in - assigns + let e2', _ = const_prop_exp substs assigns e2 in + re (E_var (le', e', e2')) assigns + | E_exit e -> + let e', _ = const_prop_exp substs assigns e in + re (E_exit e') Bindings.empty + | E_ref id -> re (E_ref id) Bindings.empty + | E_throw e -> + let e', _ = const_prop_exp substs assigns e in + re (E_throw e') Bindings.empty + | E_try (e, cases) -> + (* TODO: try and preserve *any* assignment info; note the special case in E_if if + one of the branches throws. *) + let e', _ = const_prop_exp substs assigns e in + re (E_match (e', List.map (const_prop_pexp substs Bindings.empty) cases)) Bindings.empty + | E_return e -> + let e', _ = const_prop_exp substs assigns e in + re (E_return e') Bindings.empty + | E_assert (e1, e2) -> + let e1', e2', assigns = non_det_exp_2 e1 e2 in + re (E_assert (e1', e2')) assigns + | E_internal_assume (nc, e) -> + let e', _ = const_prop_exp substs assigns e in + re (E_internal_assume (nc, e')) assigns + | E_app_infix _ | E_internal_plet _ | E_internal_return _ | E_internal_value _ -> + raise + (Reporting.err_unreachable l __POS__ + ("Unexpected expression encountered in monomorphisation: " ^ string_of_exp exp) + ) + and const_prop_fexps substs assigns fes = List.map (const_prop_fexp substs assigns) fes + and const_prop_fexp substs assigns (FE_aux (FE_fexp (id, e), annot)) = + FE_aux (FE_fexp (id, fst (const_prop_exp substs assigns e)), annot) + and const_prop_pexp substs assigns = function + | Pat_aux (Pat_exp (p, e), l) -> Pat_aux (Pat_exp (p, fst (const_prop_exp (remove_bound substs p) assigns e)), l) + | Pat_aux (Pat_when (p, e1, e2), l) -> + let substs' = remove_bound substs p in + let e1', assigns = const_prop_exp substs' assigns e1 in + Pat_aux (Pat_when (p, e1', fst (const_prop_exp substs' assigns e2)), l) + and const_prop_lexp substs assigns (LE_aux (e, annot) as le) = + let re e = (LE_aux (e, annot), None) in + match e with + | LE_id id (* shouldn't end up substituting here *) | LE_typ (_, id) -> (le, Some id) + | LE_app (id, es) -> re (LE_app (id, List.map (fun e -> fst (const_prop_exp substs assigns e)) es)) (* or here *) + | LE_tuple les -> re (LE_tuple (List.map (fun le -> fst (const_prop_lexp substs assigns le)) les)) + | LE_vector (le, e) -> + re (LE_vector (fst (const_prop_lexp substs assigns le), fst (const_prop_exp substs assigns e))) + | LE_vector_range (le, e1, e2) -> + re + (LE_vector_range + ( fst (const_prop_lexp substs assigns le), + fst (const_prop_exp substs assigns e1), + fst (const_prop_exp substs assigns e2) + ) + ) + | LE_vector_concat les -> re (LE_vector_concat (List.map (fun le -> fst (const_prop_lexp substs assigns le)) les)) + | LE_field (le, id) -> re (LE_field (fst (const_prop_lexp substs assigns le), id)) + | LE_deref e -> re (LE_deref (fst (const_prop_exp substs assigns e))) + (* Try to evaluate function calls with constant arguments via + (interpreter-based) constant folding. + Boolean connectives are special-cased to support short-circuiting when one + argument has a suitable value (even if the other one is not constant). + Moreover, calls to a __size function (in particular generated by sizeof + rewriting) with a known-constant return type are replaced by that constant; + e.g., (length(op : bits(32)) : int(32)) becomes 32 even if op is not constant. + *) + and const_prop_try_fn env (id, args) (l, annot) = + let exp_orig = E_aux (E_app (id, args), (l, annot)) in + let args = List.map replace_constant args in + let exp = E_aux (E_app (id, args), (l, annot)) in + let rec is_overload_of f = + Env.get_overloads f env |> List.exists (fun id' -> Id.compare id id' = 0 || is_overload_of id') + in + match (string_of_id id, args) with + | ( "and_bool", + ( [(E_aux (E_lit (L_aux (L_false, _)), _) as e_false); _] + | [_; (E_aux (E_lit (L_aux (L_false, _)), _) as e_false)] ) ) -> + e_false + | ( "or_bool", + ([(E_aux (E_lit (L_aux (L_true, _)), _) as e_true); _] | [_; (E_aux (E_lit (L_aux (L_true, _)), _) as e_true)]) + ) -> + e_true + | (_, [E_aux (E_vector [], _); e'] | _, [e'; E_aux (E_vector [], _)]) when is_overload_of (mk_id "append") -> e' + | _, _ when List.for_all Constant_fold.is_constant args -> const_fold exp + | _, [arg] when is_overload_of (mk_id "__size") -> ( + match destruct_atom_nexp env (typ_of exp) with + | Some (Nexp_aux (Nexp_constant i, _)) -> E_aux (E_lit (mk_lit (L_num i)), (l, annot)) + | _ -> exp_orig + ) + | _ -> exp_orig + and can_match_with_env env (E_aux (e, (l, annot)) as exp0) cases (substs, ksubsts) assigns = + let rec check_exp_pat (E_aux (e, (l, annot)) as exp) (P_aux (p, (l', _)) as pat) = + match (e, p) with + | _, P_wild -> DoesMatch ([], []) + | _, P_typ (_, p') -> check_exp_pat exp p' + | _, P_id id' when pat_id_is_variable env id' -> + let exp_typ = typ_of exp in + let pat_typ = typ_of_pat pat in + let goals = KidSet.diff (tyvars_of_typ pat_typ) (tyvars_of_typ exp_typ) in + let unifiers = try Type_check.unify l env goals pat_typ exp_typ with _ -> KBindings.empty in + let is_nexp (k, a) = match a with A_aux (A_nexp n, _) -> Some (k, n) | _ -> None in + let kbindings = List.filter_map is_nexp (KBindings.bindings unifiers) in + DoesMatch ([(id', exp)], kbindings) + | E_tuple es, P_tuple ps -> + let check = function + | DoesNotMatch -> fun _ -> DoesNotMatch + | GiveUp -> fun _ -> GiveUp + | DoesMatch (s, ns) -> ( + fun (e, p) -> + match check_exp_pat e p with DoesMatch (s', ns') -> DoesMatch (s @ s', ns @ ns') | x -> x + ) + in + List.fold_left check (DoesMatch ([], [])) (List.combine es ps) + | E_id id, _ -> ( + match Env.lookup_id id env with + | Enum _ -> begin + match p with + | P_id id' | P_app (id', []) -> if Id.compare id id' = 0 then DoesMatch ([], []) else DoesNotMatch + | _ -> + Reporting.print_err l' "Monomorphisation" + ("Unexpected kind of pattern for enumeration: " ^ string_of_pat pat); + GiveUp + end + | _ -> GiveUp + ) + | E_lit (L_aux (lit_e, lit_l)), P_lit (L_aux (lit_p, _)) -> + if lit_match (lit_e, lit_p) then DoesMatch ([], []) else DoesNotMatch + | E_lit (L_aux (lit_e, lit_l)), P_var (P_aux (P_id id, p_id_annot), TP_aux (TP_var kid, _)) -> begin + match lit_e with + | L_num i -> DoesMatch ([(id, E_aux (e, (l, annot)))], [(kid, Nexp_aux (Nexp_constant i, Unknown))]) + (* For undefined we fix the type-level size (because there's no good + way to construct an undefined size), but leave the term as undefined + to make the meaning clear. *) + | L_undef -> + let nexp = fabricate_nexp l annot in + let typ = subst_kids_typ (KBindings.singleton kid nexp) (typ_of_annot p_id_annot) in + DoesMatch ([(id, E_aux (E_typ (typ, E_aux (e, (l, empty_tannot))), (l, empty_tannot)))], [(kid, nexp)]) + | _ -> + Reporting.print_err lit_l "Monomorphisation" + ("Unexpected kind of literal for var match: " ^ string_of_lit (L_aux (lit_e, lit_l))); + GiveUp + end + | E_lit (L_aux ((L_bin _ | L_hex _), _) as lit), P_vector _ -> + let mk_bitlit lit = E_aux (E_lit lit, (Generated l, mk_tannot env bit_typ)) in + let lits' = List.map mk_bitlit (vector_string_to_bit_list lit) in + check_exp_pat (E_aux (E_vector lits', (l, annot))) pat + | E_lit _, _ -> + Reporting.print_err l' "Monomorphisation" ("Unexpected kind of pattern for literal: " ^ string_of_pat pat); + GiveUp + | E_vector es, P_vector ps when List.for_all (function E_aux (E_lit _, _) -> true | _ -> false) es -> ( + let matches = + List.map2 + (fun e p -> + let p = match p with P_aux (P_typ (_, p'), _) -> p' | _ -> p in + match (e, p) with + | E_aux (E_lit (L_aux (lit, _)), _), P_aux (P_lit (L_aux (lit', _)), _) -> + if lit_match (lit, lit') then DoesMatch ([], []) else DoesNotMatch + | E_aux (E_lit l, _), P_aux (P_id var, _) when pat_id_is_variable env var -> DoesMatch ([(var, e)], []) + | _, P_aux (P_wild, _) -> DoesMatch ([], []) + | _ -> GiveUp + ) + es ps in - re (E_if (e1',e2',e3')) assigns) - | E_for (id,e1,e2,e3,ord,e4) -> - (* Treat e1, e2 and e3 (from, to and by) as a non-det tuple *) - let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in - let assigns = isubst_minus_set assigns (assigned_vars e4) in - let e4',_ = const_prop_exp (Bindings.remove id (fst substs),snd substs) assigns e4 in - re (E_for (id,e1',e2',e3',ord,e4')) assigns - | E_loop (loop,m,e1,e2) -> - let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in - let m' = match m with - | Measure_aux (Measure_none,_) -> m - | Measure_aux (Measure_some exp,l) -> - let exp',_ = const_prop_exp substs assigns exp in - Measure_aux (Measure_some exp',l) - in - let e1',_ = const_prop_exp substs assigns e1 in - let e2',_ = const_prop_exp substs assigns e2 in - re (E_loop (loop,m',e1',e2')) assigns - | E_vector es -> - let es',assigns = non_det_exp_list es in - begin - match construct_lit_vector es' with - | None -> re (E_vector es') assigns - | Some lit -> re (E_lit lit) assigns - end - | E_vector_access (e1,e2) -> - let e1',e2',assigns = non_det_exp_2 e1 e2 in - re (E_vector_access (e1',e2')) assigns - | E_vector_subrange (e1,e2,e3) -> - let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in - re (E_vector_subrange (e1',e2',e3')) assigns - | E_vector_update (e1,e2,e3) -> - let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in - re (E_vector_update (e1',e2',e3')) assigns - | E_vector_update_subrange (e1,e2,e3,e4) -> - let e1',e2',e3',e4',assigns = non_det_exp_4 e1 e2 e3 e4 in - re (E_vector_update_subrange (e1',e2',e3',e4')) assigns - | E_vector_append (e1,e2) -> - let e1',e2',assigns = non_det_exp_2 e1 e2 in - re (E_vector_append (e1',e2')) assigns - | E_list es -> - let es',assigns = non_det_exp_list es in - re (E_list es') assigns - | E_cons (e1,e2) -> - let e1',e2',assigns = non_det_exp_2 e1 e2 in - re (E_cons (e1',e2')) assigns - | E_struct fes -> - let assigned_in_fes = assigned_vars_in_fexps fes in - let assigns = isubst_minus_set assigns assigned_in_fes in - re (E_struct (const_prop_fexps substs assigns fes)) assigns - | E_struct_update (e,fes) -> - let assigned_in = IdSet.union (assigned_vars_in_fexps fes) (assigned_vars e) in - let assigns = isubst_minus_set assigns assigned_in in - let e',_ = const_prop_exp substs assigns e in - let fes' = const_prop_fexps substs assigns fes in - begin - match unaux_exp (fst (uncast_exp e')) with - | E_struct (fes0) -> - let apply_fexp (FE_aux (FE_fexp (id, e), _)) (FE_aux (FE_fexp (id', e'), ann)) = - if Id.compare id id' = 0 then - FE_aux (FE_fexp (id', e), ann) - else - FE_aux (FE_fexp (id', e'), ann) + let final = + List.fold_left + (fun acc m -> + match (acc, m) with + | _, GiveUp -> GiveUp + | GiveUp, _ -> GiveUp + | DoesMatch (sub, ksub), DoesMatch (sub', ksub') -> DoesMatch (sub @ sub', ksub @ ksub') + | _ -> DoesNotMatch + ) + (DoesMatch ([], [])) + matches in - let update_fields fexp = List.map (apply_fexp fexp) in - let fes0' = List.fold_right update_fields fes' fes0 in - re (E_struct fes0') assigns - | _ -> - re (E_struct_update (e', fes')) assigns - end - | E_field (e,id) -> - let e',assigns = const_prop_exp substs assigns e in - begin - let is_field (FE_aux (FE_fexp (id', _), _)) = Id.compare id id' = 0 in - match unaux_exp e' with - | E_struct fes0 when List.exists is_field fes0 -> - let (FE_aux (FE_fexp (_, e), _)) = List.find is_field fes0 in - re (unaux_exp e) assigns - | _ -> - re (E_field (e',id)) assigns - end - | E_match (e,cases) -> - let e',assigns = const_prop_exp substs assigns e in - (match can_match e' cases substs assigns with - | None -> - let assigned_in = - List.fold_left (fun vs pe -> IdSet.union vs (assigned_vars_in_pexp pe)) - IdSet.empty cases + match final with + | GiveUp -> + Reporting.print_err l "Monomorphisation" + ("Unexpected kind of pattern for vector literal: " ^ string_of_pat pat); + GiveUp + | _ -> final + ) + | E_vector _, P_lit (L_aux ((L_bin _ | L_hex _), _) as lit) -> + let mk_bitlit lit = P_aux (P_lit lit, (Generated l, mk_tannot env bit_typ)) in + let lits' = List.map mk_bitlit (vector_string_to_bit_list lit) in + check_exp_pat exp (P_aux (P_vector lits', (l, annot))) + | E_vector _, _ -> + Reporting.print_err l "Monomorphisation" + ("Unexpected kind of pattern for vector literal: " ^ string_of_pat pat); + GiveUp + | E_typ (undef_typ, E_aux (E_lit (L_aux (L_undef, lit_l)), _)), P_lit (L_aux (lit_p, _)) -> DoesNotMatch + | ( E_typ (undef_typ, (E_aux (E_lit (L_aux (L_undef, lit_l)), _) as e_undef)), + P_var (P_aux (P_id id, p_id_annot), TP_aux (TP_var kid, _)) ) -> + (* For undefined we fix the type-level size (because there's no good + way to construct an undefined size), but leave the term as undefined + to make the meaning clear. *) + let nexp = fabricate_nexp l annot in + let kids = equal_kids (env_of_annot p_id_annot) kid in + let ksubst = KidSet.fold (fun k b -> KBindings.add k nexp b) kids KBindings.empty in + let typ = subst_kids_typ ksubst (typ_of_annot p_id_annot) in + DoesMatch ([(id, E_aux (E_typ (typ, e_undef), (l, empty_tannot)))], KBindings.bindings ksubst) + | E_typ (undef_typ, E_aux (E_lit (L_aux (L_undef, lit_l)), _)), _ -> + Reporting.print_err l' "Monomorphisation" ("Unexpected kind of pattern for literal: " ^ string_of_pat pat); + GiveUp + | E_struct _, _ | E_typ (_, E_aux (E_struct _, _)), _ -> DoesNotMatch + | _ -> GiveUp + in + let check_pat = check_exp_pat exp0 in + let add_ksubst_synonyms env' ksubst = + (* The type checker sometimes automatically generates kid synonyms, e.g. + in let 'datasize = ... in ... it binds both 'datasize and '_datasize. + If we subsitute one, we also want to substitute the other. + In order to find synonyms, we consult the environment after the + bind (see findpat_generic below). *) + let get_synonyms (kid, nexp) = + let rec synonyms_of_nc nc = + match unaux_constraint nc with + | NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var kid2, _)) when Kid.compare kid kid1 = 0 -> + [(kid2, nexp)] + | NC_and _ -> List.concat (List.map synonyms_of_nc (constraint_conj nc)) + | _ -> [] in - let assigns' = isubst_minus_set assigns assigned_in in - re (E_match (e', List.map (const_prop_pexp substs assigns) cases)) assigns' - | Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) -> - let exp = nexp_subst_exp (kbindings_from_list kbindings) exp in - let newbindings_env = bindings_from_list newbindings in - let substs' = bindings_union (fst substs) newbindings_env, snd substs in - const_prop_exp substs' assigns exp) - | E_let (lb,e2) -> - begin - match lb with - | LB_aux (LB_val (p,e), annot) -> - let e',assigns = const_prop_exp substs assigns e in - let substs' = remove_bound substs p in - let plain () = - let e2',assigns = const_prop_exp substs' assigns e2 in - re (E_let (LB_aux (LB_val (p,e'), annot), - e2')) assigns in - if is_value e' then - match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,empty_tannot))] substs assigns with - | None -> plain () - | Some (e'',bindings,kbindings) -> - let e'' = nexp_subst_exp (kbindings_from_list kbindings) e'' in - let bindings = bindings_from_list bindings in - let substs'' = bindings_union (fst substs') bindings, snd substs' in - const_prop_exp substs'' assigns e'' - else plain () - end - (* TODO maybe - tuple assignments *) - | E_assign (le,e) -> - let env = Type_check.env_of_annot (l, annot) in - let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in - let assigns = isubst_minus_set assigns assigned_in in - let le',idopt = const_prop_lexp substs assigns le in - let e',_ = const_prop_exp substs assigns e in - let assigns = - match idopt with - | Some id -> - begin - match Env.lookup_id id env with - | Local (Mutable,_) | Unbound _ -> - if is_value e' && not (IdSet.mem id ref_vars) - then Bindings.add id (keep_undef_typ e') assigns - else Bindings.remove id assigns - | _ -> assigns - end - | None -> assigns - in - re (E_assign (le', e')) assigns - | E_var (le, e, e2) -> - let env = Type_check.env_of_annot (l, annot) in - let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in - let assigns = isubst_minus_set assigns assigned_in in - let le',idopt = const_prop_lexp substs assigns le in - let e',_ = const_prop_exp substs assigns e in - let assigns = - match idopt with - | Some id -> - begin - match Env.lookup_id id env with - | Local (Mutable,_) | Unbound _ -> - if is_value e' && not (IdSet.mem id ref_vars) - then Bindings.add id (keep_undef_typ e') assigns - else Bindings.remove id assigns - | _ -> assigns - end - | None -> assigns - in - let e2', _ = const_prop_exp substs assigns e2 in - re (E_var (le', e', e2')) assigns - | E_exit e -> - let e',_ = const_prop_exp substs assigns e in - re (E_exit e') Bindings.empty - | E_ref id -> re (E_ref id) Bindings.empty - | E_throw e -> - let e',_ = const_prop_exp substs assigns e in - re (E_throw e') Bindings.empty - | E_try (e,cases) -> - (* TODO: try and preserve *any* assignment info; note the special case in E_if if - one of the branches throws. *) - let e',_ = const_prop_exp substs assigns e in - re (E_match (e', List.map (const_prop_pexp substs Bindings.empty) cases)) Bindings.empty - | E_return e -> - let e',_ = const_prop_exp substs assigns e in - re (E_return e') Bindings.empty - | E_assert (e1,e2) -> - let e1',e2',assigns = non_det_exp_2 e1 e2 in - re (E_assert (e1',e2')) assigns - | E_internal_assume (nc, e) -> - let e',_ = const_prop_exp substs assigns e in - re (E_internal_assume (nc, e')) assigns - - | E_app_infix _ - | E_internal_plet _ - | E_internal_return _ - | E_internal_value _ - -> raise (Reporting.err_unreachable l __POS__ - ("Unexpected expression encountered in monomorphisation: " ^ string_of_exp exp)) - and const_prop_fexps substs assigns fes = - List.map (const_prop_fexp substs assigns) fes - and const_prop_fexp substs assigns (FE_aux (FE_fexp (id,e), annot)) = - FE_aux (FE_fexp (id,fst (const_prop_exp substs assigns e)),annot) - and const_prop_pexp substs assigns = function - | (Pat_aux (Pat_exp (p,e),l)) -> - Pat_aux (Pat_exp (p,fst (const_prop_exp (remove_bound substs p) assigns e)),l) - | (Pat_aux (Pat_when (p,e1,e2),l)) -> - let substs' = remove_bound substs p in - let e1',assigns = const_prop_exp substs' assigns e1 in - Pat_aux (Pat_when (p, e1', fst (const_prop_exp substs' assigns e2)),l) - and const_prop_lexp substs assigns ((LE_aux (e,annot)) as le) = - let re e = LE_aux (e,annot), None in - match e with - | LE_id id (* shouldn't end up substituting here *) - | LE_typ (_,id) - -> le, Some id - | LE_app (id,es) -> - re (LE_app (id,List.map (fun e -> fst (const_prop_exp substs assigns e)) es)) (* or here *) - | LE_tuple les -> re (LE_tuple (List.map (fun le -> fst (const_prop_lexp substs assigns le)) les)) - | LE_vector (le,e) -> re (LE_vector (fst (const_prop_lexp substs assigns le), fst (const_prop_exp substs assigns e))) - | LE_vector_range (le,e1,e2) -> - re (LE_vector_range (fst (const_prop_lexp substs assigns le), - fst (const_prop_exp substs assigns e1), - fst (const_prop_exp substs assigns e2))) - | LE_vector_concat les -> re (LE_vector_concat (List.map (fun le -> fst (const_prop_lexp substs assigns le)) les)) - | LE_field (le,id) -> re (LE_field (fst (const_prop_lexp substs assigns le), id)) - | LE_deref e -> - re (LE_deref (fst (const_prop_exp substs assigns e))) - (* Try to evaluate function calls with constant arguments via - (interpreter-based) constant folding. - Boolean connectives are special-cased to support short-circuiting when one - argument has a suitable value (even if the other one is not constant). - Moreover, calls to a __size function (in particular generated by sizeof - rewriting) with a known-constant return type are replaced by that constant; - e.g., (length(op : bits(32)) : int(32)) becomes 32 even if op is not constant. - *) - and const_prop_try_fn env (id, args) (l, annot) = - let exp_orig = E_aux (E_app (id, args), (l, annot)) in - let args = List.map replace_constant args in - let exp = E_aux (E_app (id, args), (l, annot)) in - let rec is_overload_of f = - Env.get_overloads f env - |> List.exists (fun id' -> Id.compare id id' = 0 || is_overload_of id') - in - match (string_of_id id, args) with - | "and_bool", ([E_aux (E_lit (L_aux (L_false, _)), _) as e_false; _] | - [_; E_aux (E_lit (L_aux (L_false, _)), _) as e_false]) -> - e_false - | "or_bool", ([E_aux (E_lit (L_aux (L_true, _)), _) as e_true; _] | - [_; E_aux (E_lit (L_aux (L_true, _)), _) as e_true]) -> - e_true - | _, [E_aux (E_vector [], _); e'] - | _, [e'; E_aux (E_vector [], _)] - when is_overload_of (mk_id "append") -> - e' - | _, _ when List.for_all Constant_fold.is_constant args -> - const_fold exp - | _, [arg] when is_overload_of (mk_id "__size") -> - (match destruct_atom_nexp env (typ_of exp) with - | Some (Nexp_aux (Nexp_constant i, _)) -> - E_aux (E_lit (mk_lit (L_num i)), (l, annot)) - | _ -> exp_orig) - | _ -> exp_orig - - and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns = - let rec check_exp_pat (E_aux (e,(l,annot)) as exp) (P_aux (p,(l',_)) as pat) = - match e, p with - | _, P_wild -> DoesMatch ([],[]) - | _, P_typ (_,p') -> check_exp_pat exp p' - | _, P_id id' when pat_id_is_variable env id' -> - let exp_typ = typ_of exp in - let pat_typ = typ_of_pat pat in - let goals = KidSet.diff (tyvars_of_typ pat_typ) (tyvars_of_typ exp_typ) in - let unifiers = - try Type_check.unify l env goals pat_typ exp_typ - with _ -> KBindings.empty in - let is_nexp (k,a) = match a with - | A_aux (A_nexp n,_) -> Some (k,n) - | _ -> None - in - let kbindings = List.filter_map is_nexp (KBindings.bindings unifiers) in - DoesMatch ([id',exp],kbindings) - | E_tuple es, P_tuple ps -> - let check = function - | DoesNotMatch -> fun _ -> DoesNotMatch - | GiveUp -> fun _ -> GiveUp - | DoesMatch (s,ns) -> - fun (e,p) -> - match check_exp_pat e p with - | DoesMatch (s',ns') -> DoesMatch (s@s', ns@ns') - | x -> x - in List.fold_left check (DoesMatch ([],[])) (List.combine es ps) - | E_id id, _ -> - (match Env.lookup_id id env with - | Enum _ -> begin - match p with - | P_id id' - | P_app (id',[]) -> - if Id.compare id id' = 0 then DoesMatch ([],[]) else DoesNotMatch - | _ -> - (Reporting.print_err l' "Monomorphisation" - ("Unexpected kind of pattern for enumeration: " ^ string_of_pat pat); GiveUp) - end - | _ -> GiveUp) - | E_lit (L_aux (lit_e, lit_l)), P_lit (L_aux (lit_p, _)) -> - if lit_match (lit_e,lit_p) then DoesMatch ([],[]) else DoesNotMatch - | E_lit (L_aux (lit_e, lit_l)), - P_var (P_aux (P_id id,p_id_annot), TP_aux (TP_var kid, _)) -> - begin - match lit_e with - | L_num i -> - DoesMatch ([id, E_aux (e,(l,annot))], - [kid,Nexp_aux (Nexp_constant i,Unknown)]) - (* For undefined we fix the type-level size (because there's no good - way to construct an undefined size), but leave the term as undefined - to make the meaning clear. *) - | L_undef -> - let nexp = fabricate_nexp l annot in - let typ = subst_kids_typ (KBindings.singleton kid nexp) (typ_of_annot p_id_annot) in - DoesMatch ([id, E_aux (E_typ (typ,E_aux (e,(l,empty_tannot))),(l,empty_tannot))], - [kid,nexp]) - | _ -> - (Reporting.print_err lit_l "Monomorphisation" - ("Unexpected kind of literal for var match: " ^ string_of_lit (L_aux (lit_e, lit_l))); GiveUp) - end - | E_lit ((L_aux ((L_bin _ | L_hex _), _) as lit)), P_vector _ -> - let mk_bitlit lit = E_aux (E_lit lit, (Generated l, mk_tannot env bit_typ)) in - let lits' = List.map mk_bitlit (vector_string_to_bit_list lit) in - check_exp_pat (E_aux (E_vector lits', (l, annot))) pat - | E_lit _, _ -> - (Reporting.print_err l' "Monomorphisation" - ("Unexpected kind of pattern for literal: " ^ string_of_pat pat); GiveUp) - | E_vector es, P_vector ps - when List.for_all (function (E_aux (E_lit _,_)) -> true | _ -> false) es -> - let matches = List.map2 (fun e p -> - let p = match p with P_aux (P_typ (_,p'),_) -> p' | _ -> p in - match e, p with - | E_aux (E_lit (L_aux (lit,_)),_), P_aux (P_lit (L_aux (lit',_)),_) -> - if lit_match (lit,lit') then DoesMatch ([],[]) else DoesNotMatch - | E_aux (E_lit l,_), P_aux (P_id var,_) when pat_id_is_variable env var -> - DoesMatch ([var, e],[]) - | _, P_aux (P_wild, _) -> DoesMatch ([],[]) - | _ -> GiveUp) es ps in - let final = List.fold_left (fun acc m -> match acc, m with - | _, GiveUp -> GiveUp - | GiveUp, _ -> GiveUp - | DoesMatch (sub,ksub), DoesMatch(sub',ksub') -> DoesMatch(sub@sub',ksub@ksub') - | _ -> DoesNotMatch) (DoesMatch ([],[])) matches in - (match final with - | GiveUp -> - (Reporting.print_err l "Monomorphisation" - ("Unexpected kind of pattern for vector literal: " ^ string_of_pat pat); GiveUp) - | _ -> final) - | E_vector _, P_lit ((L_aux ((L_bin _ | L_hex _), _) as lit)) -> - let mk_bitlit lit = P_aux (P_lit lit, (Generated l, mk_tannot env bit_typ)) in - let lits' = List.map mk_bitlit (vector_string_to_bit_list lit) in - check_exp_pat exp (P_aux (P_vector lits', (l, annot))) - | E_vector _, _ -> - (Reporting.print_err l "Monomorphisation" - ("Unexpected kind of pattern for vector literal: " ^ string_of_pat pat); GiveUp) - | E_typ (undef_typ, (E_aux (E_lit (L_aux (L_undef, lit_l)),_))), - P_lit (L_aux (lit_p, _)) - -> DoesNotMatch - | E_typ (undef_typ, (E_aux (E_lit (L_aux (L_undef, lit_l)),_) as e_undef)), - P_var (P_aux (P_id id,p_id_annot), TP_aux (TP_var kid, _)) -> - (* For undefined we fix the type-level size (because there's no good - way to construct an undefined size), but leave the term as undefined - to make the meaning clear. *) - let nexp = fabricate_nexp l annot in - let kids = equal_kids (env_of_annot p_id_annot) kid in - let ksubst = KidSet.fold (fun k b -> KBindings.add k nexp b) kids KBindings.empty in - let typ = subst_kids_typ ksubst (typ_of_annot p_id_annot) in - DoesMatch ([id, E_aux (E_typ (typ,e_undef),(l,empty_tannot))], - KBindings.bindings ksubst) - | E_typ (undef_typ, (E_aux (E_lit (L_aux (L_undef, lit_l)),_))), _ -> - (Reporting.print_err l' "Monomorphisation" - ("Unexpected kind of pattern for literal: " ^ string_of_pat pat); GiveUp) - | E_struct _,_ | E_typ (_, E_aux (E_struct _, _)),_ -> DoesNotMatch - | _ -> GiveUp - in - let check_pat = check_exp_pat exp0 in - let add_ksubst_synonyms env' ksubst = - (* The type checker sometimes automatically generates kid synonyms, e.g. - in let 'datasize = ... in ... it binds both 'datasize and '_datasize. - If we subsitute one, we also want to substitute the other. - In order to find synonyms, we consult the environment after the - bind (see findpat_generic below). *) - let get_synonyms (kid, nexp) = - let rec synonyms_of_nc nc = match unaux_constraint nc with - | NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var (kid2), _)) - when Kid.compare kid kid1 = 0 -> - [(kid2, nexp)] - | NC_and _ -> List.concat (List.map synonyms_of_nc (constraint_conj nc)) - | _ -> [] + List.concat (List.map synonyms_of_nc (Env.get_constraints env')) in - List.concat (List.map synonyms_of_nc (Env.get_constraints env')) + ksubst @ List.concat (List.map get_synonyms ksubst) in - ksubst @ List.concat (List.map get_synonyms ksubst) + let rec findpat_generic description assigns = function + | [] -> + Reporting.print_err l "Monomorphisation" ("Failed to find a case for " ^ description); + None + | Pat_aux (Pat_when (p, guard, exp), _) :: tl -> begin + match check_pat p with + | DoesNotMatch -> findpat_generic description assigns tl + | DoesMatch (vsubst, ksubst) -> begin + let guard = nexp_subst_exp (kbindings_from_list ksubst) guard in + let substs = + ( bindings_union substs (bindings_from_list vsubst), + kbindings_union ksubsts (kbindings_from_list ksubst) + ) + in + let E_aux (guard, _), assigns = const_prop_exp substs assigns guard in + match guard with + | E_lit (L_aux (L_true, _)) -> + let ksubst = add_ksubst_synonyms (env_of exp) ksubst in + Some (exp, vsubst, ksubst) + | E_lit (L_aux (L_false, _)) -> findpat_generic description assigns tl + | _ -> None + end + | GiveUp -> None + end + | Pat_aux (Pat_exp (p, exp), _) :: tl -> ( + match check_pat p with + | DoesNotMatch -> findpat_generic description assigns tl + | DoesMatch (subst, ksubst) -> + let ksubst = add_ksubst_synonyms (env_of exp) ksubst in + Some (exp, subst, ksubst) + | GiveUp -> None + ) + in + findpat_generic (string_of_exp exp0) assigns cases + and can_match exp = + let env = Type_check.env_of exp in + can_match_with_env env exp in - let rec findpat_generic description assigns = function - | [] -> (Reporting.print_err l "Monomorphisation" - ("Failed to find a case for " ^ description); None) - | (Pat_aux (Pat_when (p,guard,exp),_))::tl -> begin - match check_pat p with - | DoesNotMatch -> findpat_generic description assigns tl - | DoesMatch (vsubst,ksubst) -> begin - let guard = nexp_subst_exp (kbindings_from_list ksubst) guard in - let substs = bindings_union substs (bindings_from_list vsubst), - kbindings_union ksubsts (kbindings_from_list ksubst) in - let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in - match guard with - | E_lit (L_aux (L_true,_)) -> - let ksubst = add_ksubst_synonyms (env_of exp) ksubst in - Some (exp,vsubst,ksubst) - | E_lit (L_aux (L_false,_)) -> findpat_generic description assigns tl - | _ -> None - end - | GiveUp -> None - end - | (Pat_aux (Pat_exp (p,exp),_))::tl -> - match check_pat p with - | DoesNotMatch -> findpat_generic description assigns tl - | DoesMatch (subst,ksubst) -> - let ksubst = add_ksubst_synonyms (env_of exp) ksubst in - Some (exp,subst,ksubst) - | GiveUp -> None - in findpat_generic (string_of_exp exp0) assigns cases - and can_match exp = - let env = Type_check.env_of exp in - can_match_with_env env exp + (const_prop_exp, const_prop_pexp) -in (const_prop_exp, const_prop_pexp) - -let const_prop target d = let f = const_props target d in fun r -> fst (f r) +let const_prop target d = + let f = const_props target d in + fun r -> fst (f r) let referenced_vars exp = let open Rewriter in - fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with - e_ref = (fun id -> IdSet.singleton id, E_ref id) } exp) + fst + (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with e_ref = (fun id -> (IdSet.singleton id, E_ref id)) } exp) (* This is intended to remove impossible cases when a type-level constant has been used to fix a property of the architecture. In particular, the current @@ -876,26 +860,24 @@ let referenced_vars exp = *) let remove_impossible_int_cases _ = - - let must_keep_case exp (Pat_aux ((Pat_exp (p,_) | Pat_when (p,_,_)),_)) = - let rec aux (E_aux (exp,_)) (P_aux (p,_)) = - match exp, p with + let must_keep_case exp (Pat_aux ((Pat_exp (p, _) | Pat_when (p, _, _)), _)) = + let rec aux (E_aux (exp, _)) (P_aux (p, _)) = + match (exp, p) with | E_tuple exps, P_tuple ps -> List.for_all2 aux exps ps - | E_lit (L_aux (lit,_)), P_lit (L_aux (lit',_)) -> lit_match (lit, lit') + | E_lit (L_aux (lit, _)), P_lit (L_aux (lit', _)) -> lit_match (lit, lit') | _ -> true - in aux exp p - in - let e_case (exp,cases) = - E_match (exp, List.filter (must_keep_case exp) cases) + in + aux exp p in + let e_case (exp, cases) = E_match (exp, List.filter (must_keep_case exp) cases) in let e_if (cond, e_then, e_else) = match destruct_atom_bool (env_of cond) (typ_of cond) with | Some nc -> - if prove __POS__ (env_of cond) nc then unaux_exp e_then else - if prove __POS__ (env_of cond) (nc_not nc) then unaux_exp e_else else - E_if (cond, e_then, e_else) + if prove __POS__ (env_of cond) nc then unaux_exp e_then + else if prove __POS__ (env_of cond) (nc_not nc) then unaux_exp e_else + else E_if (cond, e_then, e_else) | _ -> E_if (cond, e_then, e_else) in let open Rewriter in - let rewrite_exp _ = fold_exp { id_exp_alg with e_case = e_case; e_if = e_if } in - rewrite_ast_base { rewriters_base with rewrite_exp = rewrite_exp } + let rewrite_exp _ = fold_exp { id_exp_alg with e_case; e_if } in + rewrite_ast_base { rewriters_base with rewrite_exp } diff --git a/src/lib/constant_propagation_mutrec.ml b/src/lib/constant_propagation_mutrec.ml index 2d2117423..148950d14 100644 --- a/src/lib/constant_propagation_mutrec.ml +++ b/src/lib/constant_propagation_mutrec.ml @@ -79,35 +79,30 @@ open Rewriter let targets = ref ([] : id list) -let rec is_const_exp exp = match unaux_exp exp with +let rec is_const_exp exp = + match unaux_exp exp with | E_lit (L_aux ((L_true | L_false | L_one | L_zero | L_num _), _)) -> true | E_vector es -> List.for_all is_const_exp es && is_bitvector_typ (typ_of exp) | E_struct fes -> List.for_all is_const_fexp fes | _ -> false + and is_const_fexp (FE_aux (FE_fexp (_, e), _)) = is_const_exp e let recheck_exp exp = check_exp (env_of exp) (strip_exp exp) (typ_of exp) (* Name function copy by encoding values of constant arguments *) let generate_fun_id id args = - let rec suffix exp = match unaux_exp exp with + let rec suffix exp = + match unaux_exp exp with | E_lit (L_aux (L_one, _)) -> "1" | E_lit (L_aux (L_zero, _)) -> "0" | E_lit (L_aux (L_true, _)) -> "T" | E_lit (L_aux (L_false, _)) -> "F" | E_struct fes when is_const_exp exp -> - let fsuffix (FE_aux (FE_fexp (id, e), _)) = suffix e - in - "struct" ^ - Util.zencode_string (string_of_typ (typ_of exp)) ^ - "#" ^ - String.concat "" (List.map fsuffix fes) - | E_vector es when is_const_exp exp -> - String.concat "" (List.map suffix es) - | _ -> - if is_const_exp exp - then "#" ^ Util.zencode_string (string_of_exp exp) - else "v" + let fsuffix (FE_aux (FE_fexp (id, e), _)) = suffix e in + "struct" ^ Util.zencode_string (string_of_typ (typ_of exp)) ^ "#" ^ String.concat "" (List.map fsuffix fes) + | E_vector es when is_const_exp exp -> String.concat "" (List.map suffix es) + | _ -> if is_const_exp exp then "#" ^ Util.zencode_string (string_of_exp exp) else "v" in append_id id ("#mutrec_" ^ String.concat "" (List.map suffix args)) @@ -115,97 +110,87 @@ let generate_fun_id id args = that will be propagated in *) let generate_val_spec env id args l annot = match Env.get_val_spec_orig id env with - | tq, (Typ_aux (Typ_fn (arg_typs, ret_typ), _)) -> - (* Get instantiation of type variables at call site *) - let orig_ksubst (kid, typ_arg) = - match typ_arg with - | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg) - | _ -> raise (Reporting.err_todo l "Propagation of polymorphic arguments not implemented") - in - let ksubsts = - recheck_exp (E_aux (E_app (id, args), (l, annot))) - |> instantiation_of - |> KBindings.bindings - |> List.map orig_ksubst - |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty - in - (* Apply instantiation to original function type. Also collect the - type variables in the new type together their kinds for the new - val spec. *) - let kopts_of_typ env typ = - tyvars_of_typ typ |> KidSet.elements - |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid) - |> KOptSet.of_list - in - let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in - let (arg_typs', kopts') = - List.fold_right2 (fun arg typ (arg_typs', kopts') -> - if is_const_exp arg then - (arg_typs', kopts') - else - let typ' = KBindings.fold typ_subst ksubsts typ in - let arg_kopts = kopts_of_typ (env_of arg) typ' in - (typ' :: arg_typs', KOptSet.union arg_kopts kopts')) - args arg_typs ([], kopts_of_typ (env_of_tannot annot) ret_typ') - in - let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in - let typ' = mk_typ (Typ_fn (arg_typs', ret_typ')) in - (* Construct new val spec *) - let constraints' = - quant_split tq |> snd - |> List.map (KBindings.fold constraint_subst ksubsts) - |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ')) - in - let quant_items' = - List.map mk_qi_kopt (KOptSet.elements kopts') @ - List.map mk_qi_nc constraints' - in - let typschm = mk_typschm (mk_typquant quant_items') typ' in - mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, None, false)), - ksubsts - | _, Typ_aux (_, l) -> - raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type") + | tq, Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> + (* Get instantiation of type variables at call site *) + let orig_ksubst (kid, typ_arg) = + match typ_arg with + | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg) + | _ -> raise (Reporting.err_todo l "Propagation of polymorphic arguments not implemented") + in + let ksubsts = + recheck_exp (E_aux (E_app (id, args), (l, annot))) + |> instantiation_of |> KBindings.bindings |> List.map orig_ksubst + |> List.fold_left (fun s (v, i) -> KBindings.add v i s) KBindings.empty + in + (* Apply instantiation to original function type. Also collect the + type variables in the new type together their kinds for the new + val spec. *) + let kopts_of_typ env typ = + tyvars_of_typ typ |> KidSet.elements + |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid) + |> KOptSet.of_list + in + let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in + let arg_typs', kopts' = + List.fold_right2 + (fun arg typ (arg_typs', kopts') -> + if is_const_exp arg then (arg_typs', kopts') + else ( + let typ' = KBindings.fold typ_subst ksubsts typ in + let arg_kopts = kopts_of_typ (env_of arg) typ' in + (typ' :: arg_typs', KOptSet.union arg_kopts kopts') + ) + ) + args arg_typs + ([], kopts_of_typ (env_of_tannot annot) ret_typ') + in + let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in + let typ' = mk_typ (Typ_fn (arg_typs', ret_typ')) in + (* Construct new val spec *) + let constraints' = + quant_split tq |> snd + |> List.map (KBindings.fold constraint_subst ksubsts) + |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ')) + in + let quant_items' = List.map mk_qi_kopt (KOptSet.elements kopts') @ List.map mk_qi_nc constraints' in + let typschm = mk_typschm (mk_typquant quant_items') typ' in + (mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, None, false)), ksubsts) + | _, Typ_aux (_, l) -> raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type") let const_prop target defs substs ksubsts exp = (* Constant_propagation currently only supports nexps for kid substitutions *) let nexp_substs = KBindings.bindings ksubsts - |> List.map (function (kid, A_aux (A_nexp n, _)) -> [(kid, n)] | _ -> []) + |> List.map (function kid, A_aux (A_nexp n, _) -> [(kid, n)] | _ -> []) |> List.concat - |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty + |> List.fold_left (fun s (v, i) -> KBindings.add v i s) KBindings.empty in - Constant_propagation.const_prop - target - defs + Constant_propagation.const_prop target defs (Constant_propagation.referenced_vars exp) - (substs, nexp_substs) - Bindings.empty - exp + (substs, nexp_substs) Bindings.empty exp |> fst (* Propagate constant arguments into function clause pexp *) let prop_args_pexp target ast ksubsts args pexp = let pat, guard, exp, annot = destruct_pexp pexp in - let pats = match pat with - | P_aux (P_tuple pats, _) -> pats - | _ -> [pat] - in + let pats = match pat with P_aux (P_tuple pats, _) -> pats | _ -> [pat] in let match_arg (E_aux (_, (l, _)) as arg) pat (pats, substs) = - if is_const_exp arg then + if is_const_exp arg then ( match pat with | P_aux (P_id id, _) -> (pats, Bindings.add id arg substs) | _ -> - raise (Reporting.err_todo l - ("Unsupported pattern match in propagation of constant arguments: " ^ - string_of_exp arg ^ " and " ^ string_of_pat pat)) + raise + (Reporting.err_todo l + ("Unsupported pattern match in propagation of constant arguments: " ^ string_of_exp arg ^ " and " + ^ string_of_pat pat + ) + ) + ) else (pat :: pats, substs) in let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in let exp' = const_prop target ast substs ksubsts exp in - let pat' = match pats with - | [pat] -> pat - | _ -> P_aux (P_tuple pats, (Parse_ast.Unknown, empty_tannot)) - in + let pat' = match pats with [pat] -> pat | _ -> P_aux (P_tuple pats, (Parse_ast.Unknown, empty_tannot)) in construct_pexp (pat', guard, exp', annot) let rewrite_ast target effect_info env ({ defs; _ } as ast) = @@ -213,60 +198,58 @@ let rewrite_ast target effect_info env ({ defs; _ } as ast) = let rec rewrite = function | [] -> [] | DEF_aux (DEF_internal_mutrec mutrecs, def_annot) :: ds -> - let mutrec_ids = IdSet.of_list (List.map id_of_fundef mutrecs) in - let valspecs = ref ([] : uannot def list) in - let fundefs = ref ([] : uannot def list) in - (* Try to replace mutually recursive calls that have some constant arguments *) - let rec e_app (id, args) (l, annot) = - if IdSet.mem id mutrec_ids && List.exists is_const_exp args then - let id' = generate_fun_id id args in - effect_info := Effects.copy_function_effect id !effect_info id'; - let args' = match List.filter (fun e -> not (is_const_exp e)) args with - | [] -> [infer_exp env (mk_lit_exp L_unit)] - | args' -> args' - in - if not (IdSet.mem id' (ids_of_defs !valspecs)) then begin - (* Generate copy of function with constant arguments propagated in *) - let (FD_aux (FD_function (_, _, fcls), _)) = - List.find (fun fd -> Id.compare id (id_of_fundef fd) = 0) mutrecs - in - let valspec, ksubsts = generate_val_spec env id args l annot in - let const_prop_funcl (FCL_aux (FCL_funcl (_, pexp), (fcl_def_annot, _))) = - let pexp' = - prop_args_pexp target ast ksubsts args pexp - |> rewrite_pexp - |> strip_pexp - in - FCL_aux (FCL_funcl (id', pexp'), (def_annot_map_loc gen_loc fcl_def_annot, empty_uannot)) - in - valspecs := valspec :: !valspecs; - let fundef = mk_fundef (List.map const_prop_funcl fcls) in - fundefs := fundef :: !fundefs - end else (); - E_aux (E_app (id', args'), (l, annot)) - else E_aux (E_app (id, args), (l, annot)) - and e_aux (e, (l, annot)) = - match e with - | E_app (id, args) -> e_app (id, args) (l, annot) - | _ -> E_aux (e, (l, annot)) - and rewrite_pexp pexp = fold_pexp { id_exp_alg with e_aux = e_aux } pexp - and rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), a)) = - let pexp' = - if List.exists (fun id' -> Id.compare id id' = 0) !targets then - let pat, guard, body, annot = destruct_pexp pexp in - let body' = const_prop target ast Bindings.empty KBindings.empty body in - rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot)) - else pexp - in FCL_aux (FCL_funcl (id, pexp'), a) - and rewrite_fundef (FD_aux (FD_function (ropt, topt, fcls), a)) = - let fcls' = List.map rewrite_funcl fcls in - FD_aux (FD_function (ropt, topt, fcls'), a) - in - let mutrecs' = List.map (fun fd -> DEF_aux (DEF_fundef (rewrite_fundef fd), def_annot)) mutrecs in - let fdefs = fst (check_defs env (!valspecs @ !fundefs)) in - mutrecs' @ fdefs @ rewrite ds - | d :: ds -> - d :: rewrite ds + let mutrec_ids = IdSet.of_list (List.map id_of_fundef mutrecs) in + let valspecs = ref ([] : uannot def list) in + let fundefs = ref ([] : uannot def list) in + (* Try to replace mutually recursive calls that have some constant arguments *) + let rec e_app (id, args) (l, annot) = + if IdSet.mem id mutrec_ids && List.exists is_const_exp args then ( + let id' = generate_fun_id id args in + effect_info := Effects.copy_function_effect id !effect_info id'; + let args' = + match List.filter (fun e -> not (is_const_exp e)) args with + | [] -> [infer_exp env (mk_lit_exp L_unit)] + | args' -> args' + in + if not (IdSet.mem id' (ids_of_defs !valspecs)) then begin + (* Generate copy of function with constant arguments propagated in *) + let (FD_aux (FD_function (_, _, fcls), _)) = + List.find (fun fd -> Id.compare id (id_of_fundef fd) = 0) mutrecs + in + let valspec, ksubsts = generate_val_spec env id args l annot in + let const_prop_funcl (FCL_aux (FCL_funcl (_, pexp), (fcl_def_annot, _))) = + let pexp' = prop_args_pexp target ast ksubsts args pexp |> rewrite_pexp |> strip_pexp in + FCL_aux (FCL_funcl (id', pexp'), (def_annot_map_loc gen_loc fcl_def_annot, empty_uannot)) + in + valspecs := valspec :: !valspecs; + let fundef = mk_fundef (List.map const_prop_funcl fcls) in + fundefs := fundef :: !fundefs + end + else (); + E_aux (E_app (id', args'), (l, annot)) + ) + else E_aux (E_app (id, args), (l, annot)) + and e_aux (e, (l, annot)) = + match e with E_app (id, args) -> e_app (id, args) (l, annot) | _ -> E_aux (e, (l, annot)) + and rewrite_pexp pexp = fold_pexp { id_exp_alg with e_aux } pexp + and rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), a)) = + let pexp' = + if List.exists (fun id' -> Id.compare id id' = 0) !targets then ( + let pat, guard, body, annot = destruct_pexp pexp in + let body' = const_prop target ast Bindings.empty KBindings.empty body in + rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot)) + ) + else pexp + in + FCL_aux (FCL_funcl (id, pexp'), a) + and rewrite_fundef (FD_aux (FD_function (ropt, topt, fcls), a)) = + let fcls' = List.map rewrite_funcl fcls in + FD_aux (FD_function (ropt, topt, fcls'), a) + in + let mutrecs' = List.map (fun fd -> DEF_aux (DEF_fundef (rewrite_fundef fd), def_annot)) mutrecs in + let fdefs = fst (check_defs env (!valspecs @ !fundefs)) in + mutrecs' @ fdefs @ rewrite ds + | d :: ds -> d :: rewrite ds in let new_ast = Spec_analysis.top_sort_defs { ast with defs = rewrite defs } in - new_ast, !effect_info, env + (new_ast, !effect_info, env) diff --git a/src/lib/constraint.ml b/src/lib/constraint.ml index 7cfa0ebc8..fe3586f89 100644 --- a/src/lib/constraint.ml +++ b/src/lib/constraint.ml @@ -72,31 +72,28 @@ open Util let opt_smt_verbose = ref false -type solver = { - command : string; - header : string; - footer : string; - negative_literals : bool; - uninterpret_power : bool - } +type solver = { command : string; header : string; footer : string; negative_literals : bool; uninterpret_power : bool } -let cvc4_solver = { +let cvc4_solver = + { command = "cvc4 -L smtlib2 --tlimit=2000"; header = "(set-logic QF_UFNIA)\n"; footer = ""; negative_literals = false; - uninterpret_power = true + uninterpret_power = true; } -let mathsat_solver = { +let mathsat_solver = + { command = "mathsat"; header = "(set-logic QF_UFLIA)\n"; footer = ""; negative_literals = false; - uninterpret_power = true + uninterpret_power = true; } -let z3_solver = { +let z3_solver = + { command = "z3 -t:1000 -T:10"; (* Using push and pop is much faster, I believe because incremental mode uses a different solver. *) @@ -106,30 +103,27 @@ let z3_solver = { uninterpret_power = false; } -let yices_solver = { +let yices_solver = + { command = "yices-smt2 --timeout=2"; header = "(set-logic QF_UFLIA)\n"; footer = ""; negative_literals = false; - uninterpret_power = true + uninterpret_power = true; } -let vampire_solver = { +let vampire_solver = + { (* vampire sometimes likes to ignore its time limit *) command = "timeout -s SIGKILL 3s vampire --time_limit 2s --input_syntax smtlib2 --mode smtcomp"; header = ""; footer = ""; negative_literals = false; - uninterpret_power = true + uninterpret_power = true; } -let alt_ergo_solver ={ - command = "alt-ergo"; - header = ""; - footer = ""; - negative_literals = false; - uninterpret_power = true - } +let alt_ergo_solver = + { command = "alt-ergo"; header = ""; footer = ""; negative_literals = false; uninterpret_power = true } let opt_solver = ref z3_solver @@ -148,30 +142,27 @@ type sexpr = List of sexpr list | Atom of string let sfun (fn : string) (xs : sexpr list) : sexpr = List (Atom fn :: xs) -let rec pp_sexpr : sexpr -> string = function - | List xs -> "(" ^ string_of_list " " pp_sexpr xs ^ ")" - | Atom x -> x +let rec pp_sexpr : sexpr -> string = function List xs -> "(" ^ string_of_list " " pp_sexpr xs ^ ")" | Atom x -> x let rec add_sexpr buf = function | List xs -> - Buffer.add_char buf '('; - Util.iter_last (fun last x -> - add_sexpr buf x; - if not last then ( - Buffer.add_char buf ' ' - ) - ) xs; - Buffer.add_char buf ')' - | Atom x -> - Buffer.add_string buf x + Buffer.add_char buf '('; + Util.iter_last + (fun last x -> + add_sexpr buf x; + if not last then Buffer.add_char buf ' ' + ) + xs; + Buffer.add_char buf ')' + | Atom x -> Buffer.add_string buf x let rec add_list buf sep add_elem = function | [] -> () | [x] -> add_elem buf x | x :: xs -> - add_elem buf x; - Buffer.add_char buf sep; - add_list buf sep add_elem xs + add_elem buf x; + Buffer.add_char buf sep; + add_list buf sep add_elem xs (* Each non-Type/Order kind in Sail maps to a type in the SMT solver *) let smt_type l = function @@ -188,49 +179,46 @@ let to_smt l vars constr = let vnum = ref (-1) in let smt_var v = match KBindings.find_opt v !var_map with - | Some n -> Atom ("v" ^ string_of_int n), false + | Some n -> (Atom ("v" ^ string_of_int n), false) | None -> - let n = !vnum + 1 in - var_map := KBindings.add v n !var_map; - vnum := n; - Atom ("v" ^ string_of_int n), true + let n = !vnum + 1 in + var_map := KBindings.add v n !var_map; + vnum := n; + (Atom ("v" ^ string_of_int n), true) in let exponentials = ref [] in - + (* var_decs outputs the list of variables to be used by the SMT solver in SMTLIB v2.0 format. It takes a kind_aux KBindings, as returned by Type_check.get_typ_vars *) let var_decs l (vars : kind_aux KBindings.t) : sexpr list = - vars - |> KBindings.bindings - |> List.map (fun (v, k) -> sfun "declare-const" [fst (smt_var v); smt_type l k]) + vars |> KBindings.bindings |> List.map (fun (v, k) -> sfun "declare-const" [fst (smt_var v); smt_type l k]) in let rec smt_nexp (Nexp_aux (aux, _) : nexp) : sexpr = match aux with | Nexp_id id -> Atom (Util.zencode_string (string_of_id id)) | Nexp_var v -> fst (smt_var v) - | Nexp_constant c - when Big_int.less_equal c (Big_int.of_int (-1)) && not !opt_solver.negative_literals -> - sfun "-" [Atom "0"; Atom (Big_int.to_string (Big_int.abs c))] + | Nexp_constant c when Big_int.less_equal c (Big_int.of_int (-1)) && not !opt_solver.negative_literals -> + sfun "-" [Atom "0"; Atom (Big_int.to_string (Big_int.abs c))] | Nexp_constant c -> Atom (Big_int.to_string c) | Nexp_app (id, nexps) -> sfun (string_of_id id) (List.map smt_nexp nexps) | Nexp_times (nexp1, nexp2) -> sfun "*" [smt_nexp nexp1; smt_nexp nexp2] | Nexp_sum (nexp1, nexp2) -> sfun "+" [smt_nexp nexp1; smt_nexp nexp2] | Nexp_minus (nexp1, nexp2) -> sfun "-" [smt_nexp nexp1; smt_nexp nexp2] - | Nexp_exp nexp -> - begin match nexp_simp nexp with - | Nexp_aux (Nexp_constant c, _) when Big_int.greater_equal c Big_int.zero -> - Atom (Big_int.to_string (Big_int.pow_int_positive 2 (Big_int.to_int c))) - | nexp when !opt_solver.uninterpret_power -> - let exp = smt_nexp nexp in - exponentials := exp :: !exponentials; - sfun "sailexp" [exp] - | nexp -> - let exp = smt_nexp nexp in - exponentials := exp :: !exponentials; - sfun "to_int" [sfun "^" [Atom "2"; exp]] - end + | Nexp_exp nexp -> begin + match nexp_simp nexp with + | Nexp_aux (Nexp_constant c, _) when Big_int.greater_equal c Big_int.zero -> + Atom (Big_int.to_string (Big_int.pow_int_positive 2 (Big_int.to_int c))) + | nexp when !opt_solver.uninterpret_power -> + let exp = smt_nexp nexp in + exponentials := exp :: !exponentials; + sfun "sailexp" [exp] + | nexp -> + let exp = smt_nexp nexp in + exponentials := exp :: !exponentials; + sfun "to_int" [sfun "^" [Atom "2"; exp]] + end | Nexp_neg nexp -> sfun "-" [smt_nexp nexp] in let rec smt_constraint (NC_aux (aux, _) : n_constraint) : sexpr = @@ -241,12 +229,10 @@ let to_smt l vars constr = | NC_bounded_ge (nexp1, nexp2) -> sfun ">=" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_gt (nexp1, nexp2) -> sfun ">" [smt_nexp nexp1; smt_nexp nexp2] | NC_not_equal (nexp1, nexp2) -> sfun "not" [sfun "=" [smt_nexp nexp1; smt_nexp nexp2]] - | NC_set (v, ints) -> - sfun "or" (List.map (fun i -> sfun "=" [fst (smt_var v); Atom (Big_int.to_string i)]) ints) + | NC_set (v, ints) -> sfun "or" (List.map (fun i -> sfun "=" [fst (smt_var v); Atom (Big_int.to_string i)]) ints) | NC_or (nc1, nc2) -> sfun "or" [smt_constraint nc1; smt_constraint nc2] | NC_and (nc1, nc2) -> sfun "and" [smt_constraint nc1; smt_constraint nc2] - | NC_app (id, args) -> - sfun (string_of_id id) (List.map smt_typ_arg args) + | NC_app (id, args) -> sfun (string_of_id id) (List.map smt_typ_arg args) | NC_true -> Atom "true" | NC_false -> Atom "false" | NC_var v -> fst (smt_var v) @@ -254,39 +240,36 @@ let to_smt l vars constr = match aux with | A_nexp nexp -> smt_nexp nexp | A_bool nc -> smt_constraint nc - | _ -> - raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") + | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") in let smt_constr = smt_constraint constr in - var_decs l vars, smt_constr, smt_var, !exponentials - + (var_decs l vars, smt_constr, smt_var, !exponentials) + let sailexp_concrete n = - List.init (n + 1) (fun i -> sfun "=" [sfun "sailexp" [Atom (string_of_int i)]; Atom (Big_int.to_string (Big_int.pow_int_positive 2 i))]) - -let smtlib_of_constraints ?get_model:(get_model=false) l vars extra constr : string * (kid -> sexpr * bool) * sexpr list = + List.init (n + 1) (fun i -> + sfun "=" [sfun "sailexp" [Atom (string_of_int i)]; Atom (Big_int.to_string (Big_int.pow_int_positive 2 i))] + ) + +let smtlib_of_constraints ?(get_model = false) l vars extra constr : string * (kid -> sexpr * bool) * sexpr list = let open Buffer in let buf = create 512 in add_string buf !opt_solver.header; let variables, problem, var_map, exponentials = to_smt l vars constr in add_list buf '\n' add_sexpr variables; add_char buf '\n'; - if !opt_solver.uninterpret_power then ( - add_string buf "(declare-fun sailexp (Int) Int)\n" - ); + if !opt_solver.uninterpret_power then add_string buf "(declare-fun sailexp (Int) Int)\n"; add_list buf '\n' (fun buf sexpr -> add_sexpr buf (sfun "assert" [sexpr])) extra; add_char buf '\n'; add_sexpr buf (sfun "assert" [problem]); add_string buf "\n(check-sat)"; - if get_model then ( - add_string buf "\n(get-model)" - ); + if get_model then add_string buf "\n(get-model)"; add_char buf '\n'; add_string buf !opt_solver.footer; (Buffer.contents buf, var_map, exponentials) type smt_result = Unknown | Sat | Unsat -module DigestMap = Map.Make(Digest) +module DigestMap = Map.Make (Digest) let known_problems = ref DigestMap.empty let known_uniques = ref DigestMap.empty @@ -303,18 +286,15 @@ let load_digests_err () = | 2 -> known_problems := DigestMap.add digest Unsat !known_problems | 3 -> known_uniques := DigestMap.add digest None !known_uniques | 4 -> - let solution = input_binary_int in_chan in - known_uniques := DigestMap.add digest (Some solution) !known_uniques + let solution = input_binary_int in_chan in + known_uniques := DigestMap.add digest (Some solution) !known_uniques | _ -> assert false end; load () in - try load () with - | End_of_file -> close_in in_chan + try load () with End_of_file -> close_in in_chan -let load_digests () = - try load_digests_err () with - | Sys_error _ -> () +let load_digests () = try load_digests_err () with Sys_error _ -> () let save_digests () = let out_chan = open_out_bin "z3_problems" in @@ -330,7 +310,9 @@ let save_digests () = Digest.output out_chan digest; match result with | None -> output_byte out_chan 3 - | Some i -> output_byte out_chan 4; output_binary_int out_chan i + | Some i -> + output_byte out_chan 4; + output_binary_int out_chan i in DigestMap.iter output_solution !known_uniques; close_out out_chan @@ -341,133 +323,131 @@ let bound_exponential sexpr = sfun "and" [sfun "<=" [Atom "0"; sexpr]; sfun "<=" let constraint_to_smt l constr = let vars = - kopts_of_constraint constr - |> KOptSet.elements - |> List.map kopt_pair + kopts_of_constraint constr |> KOptSet.elements |> List.map kopt_pair |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty in let vars, sexpr, var_map, exponentials = to_smt l vars constr in let vars = string_of_list "\n" pp_sexpr vars in - (vars ^ "\n(assert " ^ pp_sexpr sexpr ^ ")"), - (fun v -> let sexpr, found = var_map v in pp_sexpr sexpr, found), - List.map pp_sexpr exponentials - + ( vars ^ "\n(assert " ^ pp_sexpr sexpr ^ ")", + (fun v -> + let sexpr, found = var_map v in + (pp_sexpr sexpr, found) + ), + List.map pp_sexpr exponentials + ) + let rec call_smt' l extra constraints = let vars = - kopts_of_constraint constraints - |> KOptSet.elements - |> List.map kopt_pair + kopts_of_constraint constraints |> KOptSet.elements |> List.map kopt_pair |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty in let problems = [constraints] in let smt_file, _, exponentials = smtlib_of_constraints l vars extra constraints in - if !opt_smt_verbose then ( - prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" smt_file) - ); + if !opt_smt_verbose then prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" smt_file); let rec input_lines chan = function | 0 -> [] | n -> - let l = input_line chan in - let ls = input_lines chan (n - 1) in - l :: ls + let l = input_line chan in + let ls = input_lines chan (n - 1) in + l :: ls in - let rec input_all chan = - match input_line chan with - | l -> l::(input_all chan) - | exception End_of_file -> [] - in + let rec input_all chan = match input_line chan with l -> l :: input_all chan | exception End_of_file -> [] in let digest = Digest.string smt_file in - let result = match DigestMap.find_opt digest !known_problems with + let result = + match DigestMap.find_opt digest !known_problems with | Some result -> result - | None -> - let (input_file, tmp_chan) = - try Filename.open_temp_file "constraint_" ".smt2" with - | Sys_error msg -> raise (Reporting.err_general l ("Could not open temp file when calling SMT: " ^ msg)) - in - output_string tmp_chan smt_file; - close_out tmp_chan; - let status, smt_output, smt_errors = - try - let smt_out, smt_in, smt_err = Unix.open_process_full (!opt_solver.command ^ " " ^ input_file) (Unix.environment ()) in - let smt_output = - try List.combine problems (input_lines smt_out (List.length problems)) with - | End_of_file -> List.combine problems ["unknown"] - in - let smt_errors = input_all smt_err in - let status = Unix.close_process_full (smt_out, smt_in, smt_err) in - status, smt_output, smt_errors - with - | exn -> - raise (Reporting.err_general l ("Error when calling smt: " ^ Printexc.to_string exn)) - in - let _ = match status with - | Unix.WEXITED 0 -> () - | Unix.WEXITED n -> - raise (Reporting.err_general l ("SMT solver returned unexpected status " ^ string_of_int n ^ "\n" ^ String.concat "\n" smt_errors)) - | Unix.WSIGNALED n | Unix.WSTOPPED n -> - raise (Reporting.err_general l ("SMT solver killed by signal " ^ string_of_int n)) - in - Sys.remove input_file; - try - let (_problem, _) = List.find (fun (_, result) -> result = "unsat") smt_output in - known_problems := DigestMap.add digest Unsat !known_problems; - Unsat - with - | Not_found -> + | None -> ( + let input_file, tmp_chan = + try Filename.open_temp_file "constraint_" ".smt2" + with Sys_error msg -> raise (Reporting.err_general l ("Could not open temp file when calling SMT: " ^ msg)) + in + output_string tmp_chan smt_file; + close_out tmp_chan; + let status, smt_output, smt_errors = + try + let smt_out, smt_in, smt_err = + Unix.open_process_full (!opt_solver.command ^ " " ^ input_file) (Unix.environment ()) + in + let smt_output = + try List.combine problems (input_lines smt_out (List.length problems)) + with End_of_file -> List.combine problems ["unknown"] + in + let smt_errors = input_all smt_err in + let status = Unix.close_process_full (smt_out, smt_in, smt_err) in + (status, smt_output, smt_errors) + with exn -> raise (Reporting.err_general l ("Error when calling smt: " ^ Printexc.to_string exn)) + in + let _ = + match status with + | Unix.WEXITED 0 -> () + | Unix.WEXITED n -> + raise + (Reporting.err_general l + ("SMT solver returned unexpected status " ^ string_of_int n ^ "\n" ^ String.concat "\n" smt_errors) + ) + | Unix.WSIGNALED n | Unix.WSTOPPED n -> + raise (Reporting.err_general l ("SMT solver killed by signal " ^ string_of_int n)) + in + Sys.remove input_file; + try + let _problem, _ = List.find (fun (_, result) -> result = "unsat") smt_output in + known_problems := DigestMap.add digest Unsat !known_problems; + Unsat + with Not_found -> let unsolved = List.filter (fun (_, result) -> result = "unknown") smt_output in if unsolved == [] then ( known_problems := DigestMap.add digest Sat !known_problems; Sat - ) else ( + ) + else ( known_problems := DigestMap.add digest Unknown !known_problems; Unknown ) + ) in - (match result with - | Unsat -> Unsat - | Sat -> Sat - | Unknown when exponentials <> [] && not !opt_solver.uninterpret_power -> - (* If we get an unknown result for a constraint involving `2^x`, - then try replacing `2^` with an uninterpreted function to see - if the problem would be unsat in that case. *) - opt_solver := { !opt_solver with uninterpret_power = true }; - let result = call_smt_uninterpret_power ~bound:64 l constraints in - opt_solver := { !opt_solver with uninterpret_power = false }; - result - | Unknown -> Unknown), - exponentials - -and call_smt_uninterpret_power ~bound:bound l constraints = + ( ( match result with + | Unsat -> Unsat + | Sat -> Sat + | Unknown when exponentials <> [] && not !opt_solver.uninterpret_power -> + (* If we get an unknown result for a constraint involving `2^x`, + then try replacing `2^` with an uninterpreted function to see + if the problem would be unsat in that case. *) + opt_solver := { !opt_solver with uninterpret_power = true }; + let result = call_smt_uninterpret_power ~bound:64 l constraints in + opt_solver := { !opt_solver with uninterpret_power = false }; + result + | Unknown -> Unknown + ), + exponentials + ) + +and call_smt_uninterpret_power ~bound l constraints = match call_smt' l (sailexp_concrete bound) constraints with - | (Unsat, _) -> Unsat - | (Sat, exponentials) -> - begin match call_smt' l (sailexp_concrete bound @ List.map bound_exponential exponentials) constraints with - | (Sat, _) -> Sat - | _ -> Unknown - end + | Unsat, _ -> Unsat + | Sat, exponentials -> begin + match call_smt' l (sailexp_concrete bound @ List.map bound_exponential exponentials) constraints with + | Sat, _ -> Sat + | _ -> Unknown + end | _ -> Unknown - + let call_smt l constraints = let t = Profile.start_smt () in let result = - if !opt_solver.uninterpret_power then ( - call_smt_uninterpret_power ~bound:64 l constraints - ) else ( - fst (call_smt' l [] constraints) - ) in + if !opt_solver.uninterpret_power then call_smt_uninterpret_power ~bound:64 l constraints + else fst (call_smt' l [] constraints) + in Profile.finish_smt t; result let solve_smt_file l extra constraints = let vars = - kopts_of_constraint constraints - |> KOptSet.elements - |> List.map kopt_pair + kopts_of_constraint constraints |> KOptSet.elements |> List.map kopt_pair |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty in smtlib_of_constraints ~get_model:true l vars extra constraints @@ -483,11 +463,10 @@ let call_smt_solve l smt_file smt_vars var = let l = input_line chan in let ls = input_all chan in l :: ls - with - End_of_file -> [] + with End_of_file -> [] in - let (input_file, tmp_chan) = Filename.open_temp_file "constraint_" ".smt2" in + let input_file, tmp_chan = Filename.open_temp_file "constraint_" ".smt2" in output_string tmp_chan smt_file; close_out tmp_chan; let smt_output = @@ -498,9 +477,7 @@ let call_smt_solve l smt_file smt_vars var = let _ = Unix.close_process_in smt_chan in Profile.finish_smt t; smt_output - with - | exn -> - raise (Reporting.err_general l ("Got error when calling smt: " ^ Printexc.to_string exn)) + with exn -> raise (Reporting.err_general l ("Got error when calling smt: " ^ Printexc.to_string exn)) in Sys.remove input_file; let regexp = {|(define-fun |} ^ smt_var ^ {| () Int[ ]+\([0-9]+\))|} in @@ -508,8 +485,7 @@ let call_smt_solve l smt_file smt_vars var = let _ = Str.search_forward (Str.regexp regexp) smt_output 0 in let result = Big_int.of_string (Str.matched_group 1 smt_output) in Some result - with - | Not_found -> None + with Not_found -> None let call_smt_solve_bitvector l smt_file smt_vars = let rec input_all chan = @@ -517,11 +493,10 @@ let call_smt_solve_bitvector l smt_file smt_vars = let l = input_line chan in let ls = input_all chan in l :: ls - with - End_of_file -> [] + with End_of_file -> [] in - let (input_file, tmp_chan) = Filename.open_temp_file "constraint_" ".smt2" in + let input_file, tmp_chan = Filename.open_temp_file "constraint_" ".smt2" in output_string tmp_chan smt_file; close_out tmp_chan; let smt_output = @@ -532,14 +507,13 @@ let call_smt_solve_bitvector l smt_file smt_vars = let _ = Unix.close_process_in smt_chan in Profile.finish_smt t; smt_output - with - | exn -> - raise (Reporting.err_general l ("Got error when calling smt: " ^ Printexc.to_string exn)) + with exn -> raise (Reporting.err_general l ("Got error when calling smt: " ^ Printexc.to_string exn)) in Sys.remove input_file; - List.map (fun (smt_var, smt_ty) -> + List.map + (fun (smt_var, smt_ty) -> let smt_var_str = "p" ^ string_of_int smt_var in - try ( + try if smt_ty = "Int" then ( let regexp = "(define-fun " ^ smt_var_str ^ {| () Int [ ]+\([0-9]+\|\((- [0-9]+)\)\))|} in let _ = Str.search_forward (Str.regexp regexp) smt_output 0 in @@ -547,10 +521,10 @@ let call_smt_solve_bitvector l smt_file smt_vars = if result.[0] = '(' then ( let n = Big_int.of_string (String.sub result 3 (String.length result - 4)) in Some (smt_var, mk_lit (L_num (Big_int.negate n))) - ) else ( - Some (smt_var, mk_lit (L_num (Big_int.of_string result))) ) - ) else ( + else Some (smt_var, mk_lit (L_num (Big_int.of_string result))) + ) + else ( let regexp = "(define-fun " ^ smt_var_str ^ " () " ^ smt_ty ^ {|[ ]+\(#[xb]\)\([0-9A-Fa-f]+\))|} in let _ = Str.search_forward (Str.regexp regexp) smt_output 0 in let prefix = Str.matched_group 1 smt_output in @@ -558,26 +532,25 @@ let call_smt_solve_bitvector l smt_file smt_vars = match prefix with | "#b" -> Some (smt_var, mk_lit (L_bin result)) | "#x" -> Some (smt_var, mk_lit (L_hex result)) - | _ -> - raise (Reporting.err_general l "Could not parse bitvector value from SMT solver") + | _ -> raise (Reporting.err_general l "Could not parse bitvector value from SMT solver") ) - ) with - | Not_found -> None - ) smt_vars |> Util.option_all - + with Not_found -> None + ) + smt_vars + |> Util.option_all + let solve_smt l constraints var = let smt_file, smt_vars, _ = solve_smt_file l [] constraints in call_smt_solve l smt_file smt_vars var let solve_all_smt l constraints var = let rec aux results = - let constraints = List.fold_left (fun ncs r -> (nc_and ncs (nc_neq (nconstant r) (nvar var)))) constraints results in + let constraints = List.fold_left (fun ncs r -> nc_and ncs (nc_neq (nconstant r) (nvar var))) constraints results in match solve_smt l constraints var with | Some result -> aux (result :: results) - | None -> - match call_smt l constraints with - | Unsat -> Some results - | _ -> None + | None -> ( + match call_smt l constraints with Unsat -> Some results | _ -> None + ) in aux [] @@ -587,27 +560,30 @@ let solve_unique_smt' l constraints exp_defn exp_bound var = let result = match DigestMap.find_opt digest !known_uniques with | Some (Some result) -> Some (Big_int.of_int result) - | Some (None) -> None - | None -> - match call_smt_solve l smt_file smt_vars var with - | Some result -> - let t = Profile.start_smt () in - let smt_result' = fst (call_smt' l exp_defn (nc_and constraints (nc_neq (nconstant result) (nvar var)))) in - Profile.finish_smt t; - begin match smt_result' with - | Unsat -> - if Big_int.less_equal Big_int.zero result && Big_int.less result (Big_int.pow_int_positive 2 30) then - known_uniques := DigestMap.add digest (Some (Big_int.to_int result)) !known_uniques - else (); - Some result - | _ -> - known_uniques := DigestMap.add digest None !known_uniques; - None - end - | None -> - known_uniques := DigestMap.add digest None !known_uniques; - None - in result, exponentials + | Some None -> None + | None -> ( + match call_smt_solve l smt_file smt_vars var with + | Some result -> + let t = Profile.start_smt () in + let smt_result' = fst (call_smt' l exp_defn (nc_and constraints (nc_neq (nconstant result) (nvar var)))) in + Profile.finish_smt t; + begin + match smt_result' with + | Unsat -> + if Big_int.less_equal Big_int.zero result && Big_int.less result (Big_int.pow_int_positive 2 30) then + known_uniques := DigestMap.add digest (Some (Big_int.to_int result)) !known_uniques + else (); + Some result + | _ -> + known_uniques := DigestMap.add digest None !known_uniques; + None + end + | None -> + known_uniques := DigestMap.add digest None !known_uniques; + None + ) + in + (result, exponentials) (* Follows the same approach as call_smt' for unknown results due to exponentials, retrying with a bounded spec. *) @@ -619,12 +595,12 @@ let solve_unique_smt l constraints var = | Some result, _ -> Some result | None, [] -> None | None, exponentials -> - opt_solver := { !opt_solver with uninterpret_power = true }; - let sailexp = sailexp_concrete 64 in - let exp_bound = List.map bound_exponential exponentials in - let result, _ = solve_unique_smt' l constraints sailexp exp_bound var in - opt_solver := { !opt_solver with uninterpret_power = false }; - result + opt_solver := { !opt_solver with uninterpret_power = true }; + let sailexp = sailexp_concrete 64 in + let exp_bound = List.map bound_exponential exponentials in + let result, _ = solve_unique_smt' l constraints sailexp exp_bound var in + opt_solver := { !opt_solver with uninterpret_power = false }; + result in Profile.finish_smt t; result diff --git a/src/lib/constraint.mli b/src/lib/constraint.mli index c00a3f3ae..333e1c0a9 100644 --- a/src/lib/constraint.mli +++ b/src/lib/constraint.mli @@ -86,7 +86,7 @@ val constraint_to_smt : l -> n_constraint -> string * (kid -> string * bool) * s val call_smt : l -> n_constraint -> smt_result val call_smt_solve_bitvector : l -> string -> (int * string) list -> (int * lit) list option - + val solve_smt : l -> n_constraint -> kid -> Big_int.num option val solve_all_smt : l -> n_constraint -> kid -> Big_int.num list option diff --git a/src/lib/dune b/src/lib/dune index f4e20f1c6..5950e74ad 100644 --- a/src/lib/dune +++ b/src/lib/dune @@ -1,65 +1,111 @@ (env - (dev - (flags (:standard -w -33 -w -27))) - (release - (flags (:standard -w -33 -w -27)))) + (dev + (flags + (:standard -w -33 -w -27))) + (release + (flags + (:standard -w -33 -w -27)))) (rule - (target ast.lem) - (deps (:sail_ott ../../language/sail.ott)) - (action (run ott -sort false -generate_aux_rules true -o %{target} -picky_multiple_parses true %{sail_ott}))) + (target ast.lem) + (deps + (:sail_ott ../../language/sail.ott)) + (action + (run + ott + -sort + false + -generate_aux_rules + true + -o + %{target} + -picky_multiple_parses + true + %{sail_ott}))) (rule - (target jib.lem) - (deps (:jib_ott ../../language/jib.ott) ast.lem) - (action (run ott -sort false -generate_aux_rules true -o %{target} -picky_multiple_parses true %{jib_ott}))) + (target jib.lem) + (deps + (:jib_ott ../../language/jib.ott) + ast.lem) + (action + (run + ott + -sort + false + -generate_aux_rules + true + -o + %{target} + -picky_multiple_parses + true + %{jib_ott}))) (rule - (target ast.ml) - (deps (:ast ast.lem) (:sed ast.sed)) - (action - (progn (run lem -ocaml %{ast}) - (run sed -i.bak -f %{sed} %{target})))) + (target ast.ml) + (deps + (:ast ast.lem) + (:sed ast.sed)) + (action + (progn + (run lem -ocaml %{ast}) + (run sed -i.bak -f %{sed} %{target})))) (copy_files ../gen_lib/*.lem) (rule - (targets - value2.ml - sail2_values.ml - sail2_prompt.ml - sail2_instr_kinds.ml - sail2_prompt_monad.ml - sail2_operators.ml - sail2_operators_bitlists.ml) - (deps - value2.lem - sail2_prompt.lem - sail2_values.lem - sail2_instr_kinds.lem - sail2_prompt_monad.lem - sail2_operators.lem - sail2_operators_bitlists.lem) - (action (run lem -wl_rename ign -wl_pat_comp ign -wl_comp_message ign -ocaml %{deps}))) + (targets + value2.ml + sail2_values.ml + sail2_prompt.ml + sail2_instr_kinds.ml + sail2_prompt_monad.ml + sail2_operators.ml + sail2_operators_bitlists.ml) + (deps + value2.lem + sail2_prompt.lem + sail2_values.lem + sail2_instr_kinds.lem + sail2_prompt_monad.lem + sail2_operators.lem + sail2_operators_bitlists.lem) + (action + (run + lem + -wl_rename + ign + -wl_pat_comp + ign + -wl_comp_message + ign + -ocaml + %{deps}))) (rule - (target jib.ml) - (deps (:jib jib.lem) (:sed ast.sed) value2.ml (glob_files lem/*.lem)) - (action - (progn (run lem -ocaml %{jib} -lib . -lib lem/) - (run sed -i.bak -f %{sed} %{target})))) + (target jib.ml) + (deps + (:jib jib.lem) + (:sed ast.sed) + value2.ml + (glob_files lem/*.lem)) + (action + (progn + (run lem -ocaml %{jib} -lib . -lib lem/) + (run sed -i.bak -f %{sed} %{target})))) (menhir - (modules parser)) + (modules parser)) (ocamllex lexer) (generate_sites_module - (module libsail_sites) - (sites libsail)) - + (module libsail_sites) + (sites libsail)) + (library - (name libsail) - (public_name libsail) - (libraries lem linksem pprint dune-site yojson) - (instrumentation (backend bisect_ppx))) + (name libsail) + (public_name libsail) + (libraries lem linksem pprint dune-site yojson) + (instrumentation + (backend bisect_ppx))) diff --git a/src/lib/effects.ml b/src/lib/effects.ml index 8ddb4f2bb..a1c6d86a3 100644 --- a/src/lib/effects.ml +++ b/src/lib/effects.ml @@ -91,7 +91,7 @@ let string_of_side_effect = function | Undefined -> "contains undefined literal" | Scattered -> "scattered function" | NonExec -> "not executable" - | Outcome id -> ("outcome " ^ string_of_id id) + | Outcome id -> "outcome " ^ string_of_id id module Effect = struct type t = side_effect @@ -106,17 +106,25 @@ module Effect = struct | Scattered, Scattered -> 0 | NonExec, NonExec -> 0 | Outcome id1, Outcome id2 -> Id.compare id1 id2 - | Throw, _ -> 1 | _, Throw -> -1 - | Exit, _ -> 1 | _, Exit -> -1 - | IncompleteMatch, _ -> 1 | _, IncompleteMatch -> -1 - | External, _ -> 1 | _, External -> -1 - | Undefined, _ -> 1 | _, Undefined -> -1 - | Scattered, _ -> 1 | _, Scattered -> -1 - | NonExec, _ -> 1 | _, NonExec -> -1 - | Outcome _, _ -> 1 | _, Outcome _ -> -1 + | Throw, _ -> 1 + | _, Throw -> -1 + | Exit, _ -> 1 + | _, Exit -> -1 + | IncompleteMatch, _ -> 1 + | _, IncompleteMatch -> -1 + | External, _ -> 1 + | _, External -> -1 + | Undefined, _ -> 1 + | _, Undefined -> -1 + | Scattered, _ -> 1 + | _, Scattered -> -1 + | NonExec, _ -> 1 + | _, NonExec -> -1 + | Outcome _, _ -> 1 + | _, Outcome _ -> -1 end -module EffectSet = Set.Make(Effect) +module EffectSet = Set.Make (Effect) let throws = EffectSet.mem Throw @@ -125,7 +133,7 @@ let non_exec = EffectSet.mem NonExec let pure = EffectSet.is_empty let effectful set = not (pure set) - + let has_outcome id = EffectSet.mem (Outcome id) module PC_config = struct @@ -134,7 +142,7 @@ module PC_config = struct let add_attribute l attr arg = Type_check.map_uannot (add_attribute l attr arg) end -module PC = Pattern_completeness.Make(PC_config) +module PC = Pattern_completeness.Make (PC_config) let funcls_info = function | FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (pat, _), _)), _) :: _ -> Some (id, typ_of_pat pat, env_of_pat pat) @@ -146,80 +154,75 @@ let infer_def_direct_effects asserts_termination def = let scan_lexp lexp_aux annot = let env = env_of_annot annot in - begin match lexp_aux with - | LE_typ (_, id) | LE_id id -> - begin match Env.lookup_id id env with - | Register _ -> - effects := EffectSet.add Register !effects - | _ -> () - end - | LE_deref _ -> effects := EffectSet.add Register !effects - | _ -> () + begin + match lexp_aux with + | LE_typ (_, id) | LE_id id -> begin + match Env.lookup_id id env with Register _ -> effects := EffectSet.add Register !effects | _ -> () + end + | LE_deref _ -> effects := EffectSet.add Register !effects + | _ -> () end; LE_aux (lexp_aux, annot) in let scan_exp e_aux annot = let env = env_of_annot annot in - begin match e_aux with - | E_id id -> - begin match Env.lookup_id id env with - | Register _ -> - effects := EffectSet.add Register !effects - | _ -> () - end - | E_lit (L_aux (L_undef, _)) -> effects := EffectSet.add Undefined !effects - | E_throw _ -> effects := EffectSet.add Throw !effects - | E_exit _ | E_assert _ -> effects := EffectSet.add Exit !effects - | E_app (f, _) when Id.compare f (mk_id "__deref") = 0 -> - effects := EffectSet.add Register !effects - | E_match (_, _) -> - if Option.is_some (snd annot |> untyped_annot |> get_attribute "incomplete") then ( - effects := EffectSet.add IncompleteMatch !effects - ) - | E_loop (_, Measure_aux (Measure_some _, _), _, _) when asserts_termination -> - effects := EffectSet.add Exit !effects - | _ -> () + begin + match e_aux with + | E_id id -> begin + match Env.lookup_id id env with Register _ -> effects := EffectSet.add Register !effects | _ -> () + end + | E_lit (L_aux (L_undef, _)) -> effects := EffectSet.add Undefined !effects + | E_throw _ -> effects := EffectSet.add Throw !effects + | E_exit _ | E_assert _ -> effects := EffectSet.add Exit !effects + | E_app (f, _) when Id.compare f (mk_id "__deref") = 0 -> effects := EffectSet.add Register !effects + | E_match (_, _) -> + if Option.is_some (snd annot |> untyped_annot |> get_attribute "incomplete") then + effects := EffectSet.add IncompleteMatch !effects + | E_loop (_, Measure_aux (Measure_some _, _), _, _) when asserts_termination -> + effects := EffectSet.add Exit !effects + | _ -> () end; E_aux (e_aux, annot) in let scan_pat p_aux annot = - begin match p_aux with - | P_string_append _ -> effects := EffectSet.add NonExec !effects - | _ -> () + begin + match p_aux with P_string_append _ -> effects := EffectSet.add NonExec !effects | _ -> () end; P_aux (p_aux, annot) in let pat_alg = { id_pat_alg with p_aux = (fun (p_aux, annot) -> scan_pat p_aux annot) } in - + let rw_exp _ exp = - fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> scan_exp e_aux annot); - le_aux = (fun (l_aux, annot) -> scan_lexp l_aux annot); - pat_alg = pat_alg } exp in - ignore (rewrite_ast_defs { rewriters_base with rewrite_exp = rw_exp; - rewrite_pat = (fun _ -> fold_pat pat_alg) } [def]); - - begin match def with - | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, Some { pure = false; _ }, _), _)), _) -> - effects := EffectSet.add External !effects - | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, funcls), (l, _))), def_annot) -> - begin match funcls_info funcls with - | Some (id, typ, env) -> - if Option.is_some (get_def_attribute "incomplete" def_annot) then ( - effects := EffectSet.add IncompleteMatch !effects - ) - | None -> - Reporting.unreachable l __POS__ "Empty funcls in infer_def_direct_effects" - end - | DEF_aux (DEF_mapdef _, _) -> - effects := EffectSet.add IncompleteMatch !effects - | DEF_aux (DEF_scattered _, _) -> - effects := EffectSet.add Scattered !effects - | _ -> () + fold_exp + { + id_exp_alg with + e_aux = (fun (e_aux, annot) -> scan_exp e_aux annot); + le_aux = (fun (l_aux, annot) -> scan_lexp l_aux annot); + pat_alg; + } + exp + in + ignore (rewrite_ast_defs { rewriters_base with rewrite_exp = rw_exp; rewrite_pat = (fun _ -> fold_pat pat_alg) } [def]); + + begin + match def with + | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, Some { pure = false; _ }, _), _)), _) -> + effects := EffectSet.add External !effects + | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, funcls), (l, _))), def_annot) -> begin + match funcls_info funcls with + | Some (id, typ, env) -> + if Option.is_some (get_def_attribute "incomplete" def_annot) then + effects := EffectSet.add IncompleteMatch !effects + | None -> Reporting.unreachable l __POS__ "Empty funcls in infer_def_direct_effects" + end + | DEF_aux (DEF_mapdef _, _) -> effects := EffectSet.add IncompleteMatch !effects + | DEF_aux (DEF_scattered _, _) -> effects := EffectSet.add Scattered !effects + | _ -> () end; - + !effects let infer_mapdef_extra_direct_effects def = @@ -229,35 +232,31 @@ let infer_mapdef_extra_direct_effects def = let scan_mpat set mp_aux annot = match mp_aux with | Some (MP_string_append _ as aux) -> - set := EffectSet.add NonExec !set; - Some (MP_aux (aux, annot)) + set := EffectSet.add NonExec !set; + Some (MP_aux (aux, annot)) | Some aux -> Some (MP_aux (aux, annot)) | None -> None in let rw_mpat set = fold_mpat { id_mpat_alg with p_aux = (fun (mp_aux, annot) -> scan_mpat set mp_aux annot) } in let scan_mpexp set (MPat_aux (aux, _)) = - match aux with - | MPat_pat mpat -> ignore (rw_mpat set mpat) - | MPat_when (mpat, _) -> ignore (rw_mpat set mpat) + match aux with MPat_pat mpat -> ignore (rw_mpat set mpat) | MPat_when (mpat, _) -> ignore (rw_mpat set mpat) in let scan_mapcl (MCL_aux (aux, _)) = match aux with | MCL_bidir (forward, backward) -> - scan_mpexp forward_effects forward; - scan_mpexp backward_effects backward - | MCL_forwards (forward, _) -> - scan_mpexp forward_effects forward - | MCL_backwards (backward, _) -> - scan_mpexp backward_effects backward + scan_mpexp forward_effects forward; + scan_mpexp backward_effects backward + | MCL_forwards (forward, _) -> scan_mpexp forward_effects forward + | MCL_backwards (backward, _) -> scan_mpexp backward_effects backward in - - begin match def with - | DEF_aux (DEF_mapdef (MD_aux (MD_mapping (_, _, mapcls), _)), _) -> - List.iter scan_mapcl mapcls - | _ -> () + + begin + match def with + | DEF_aux (DEF_mapdef (MD_aux (MD_mapping (_, _, mapcls), _)), _) -> List.iter scan_mapcl mapcls + | _ -> () end; - - !forward_effects, !backward_effects + + (!forward_effects, !backward_effects) (* A top-level definition can have a side effect if it contains an expression which could have some side effect *) @@ -265,7 +264,7 @@ let can_have_direct_side_effect (DEF_aux (aux, _)) = match aux with | DEF_type _ -> false | DEF_fundef _ -> true - | DEF_mapdef _ -> false + | DEF_mapdef _ -> false | DEF_impl _ -> true | DEF_let _ -> true | DEF_val _ -> true @@ -280,82 +279,73 @@ let can_have_direct_side_effect (DEF_aux (aux, _)) = | DEF_register _ -> true | DEF_internal_mutrec _ -> true | DEF_pragma _ -> false - + type side_effect_info = { - functions : EffectSet.t Bindings.t; - letbinds : EffectSet.t Bindings.t; - mappings : EffectSet.t Bindings.t - } - -let empty_side_effect_info = { - functions = Bindings.empty; - letbinds = Bindings.empty; - mappings = Bindings.empty - } - -let function_is_pure id info = - match Bindings.find_opt id info.functions with - | Some eff -> pure eff - | None -> true - -let is_function = function - | Callgraph.Function _ -> true - | _ -> false - -let is_mapping = function - | Callgraph.Mapping _ -> true - | _ -> false + functions : EffectSet.t Bindings.t; + letbinds : EffectSet.t Bindings.t; + mappings : EffectSet.t Bindings.t; +} + +let empty_side_effect_info = { functions = Bindings.empty; letbinds = Bindings.empty; mappings = Bindings.empty } + +let function_is_pure id info = match Bindings.find_opt id info.functions with Some eff -> pure eff | None -> true + +let is_function = function Callgraph.Function _ -> true | _ -> false + +let is_mapping = function Callgraph.Mapping _ -> true | _ -> false (* Normally we only add effects once, but occasionally we need to add more (e.g., if a function is external for some targets, but defined in Sail for others). *) let add_effects id effect_set effect_map = - Bindings.update id (function Some effects -> Some (EffectSet.union effects effect_set) | None -> Some effect_set) effect_map - + Bindings.update id + (function Some effects -> Some (EffectSet.union effects effect_set) | None -> Some effect_set) + effect_map + let infer_side_effects asserts_termination ast = - let module NodeSet = Set.Make(Callgraph.Node) in + let module NodeSet = Set.Make (Callgraph.Node) in let cg = Callgraph.graph_of_ast ast in let total = List.length ast.defs in let direct_effects = ref Bindings.empty in let fun_termination_asserts = ref IdSet.empty in let infer_fun_termination_assert def = - if asserts_termination then + if asserts_termination then ( match def with - | DEF_aux (DEF_measure (id, _, _), _) -> - fun_termination_asserts := IdSet.add id !fun_termination_asserts + | DEF_aux (DEF_measure (id, _, _), _) -> fun_termination_asserts := IdSet.add id !fun_termination_asserts | _ -> () + ) in - List.iteri (fun i def -> + List.iteri + (fun i def -> Util.progress "Effects (direct) " (string_of_int (i + 1) ^ "/" ^ string_of_int total) (i + 1) total; (* Handle mapping separately to allow different effects for both directions *) - begin match def with - | DEF_aux (DEF_mapdef mdef, _) -> - let effs = infer_def_direct_effects asserts_termination def in - let fw, bk = infer_mapdef_extra_direct_effects def in - let id = id_of_mapdef mdef in - direct_effects := add_effects id effs !direct_effects; - direct_effects := add_effects (append_id id "_forwards") fw !direct_effects; - direct_effects := add_effects (append_id id "_backwards") bk !direct_effects - | _ when can_have_direct_side_effect def -> - infer_fun_termination_assert def; - let effs = infer_def_direct_effects asserts_termination def in - let ids = ids_of_def def in - IdSet.iter (fun id -> - direct_effects := add_effects id effs !direct_effects - ) ids - | _ -> () + begin + match def with + | DEF_aux (DEF_mapdef mdef, _) -> + let effs = infer_def_direct_effects asserts_termination def in + let fw, bk = infer_mapdef_extra_direct_effects def in + let id = id_of_mapdef mdef in + direct_effects := add_effects id effs !direct_effects; + direct_effects := add_effects (append_id id "_forwards") fw !direct_effects; + direct_effects := add_effects (append_id id "_backwards") bk !direct_effects + | _ when can_have_direct_side_effect def -> + infer_fun_termination_assert def; + let effs = infer_def_direct_effects asserts_termination def in + let ids = ids_of_def def in + IdSet.iter (fun id -> direct_effects := add_effects id effs !direct_effects) ids + | _ -> () end - ) ast.defs; + ) + ast.defs; (* If asserts_termination is true then we will have a set of recursive functions where the target will assert that the termination measure is respected, so add suitable effects for the assert. While loops are handled in infer_def_direct_effects - above. *) - direct_effects := IdSet.fold - (fun id -> add_effects id (EffectSet.singleton Exit)) - !fun_termination_asserts !direct_effects; + above. *) + direct_effects := + IdSet.fold (fun id -> add_effects id (EffectSet.singleton Exit)) !fun_termination_asserts !direct_effects; let function_effects = ref Bindings.empty in let letbind_effects = ref Bindings.empty in @@ -363,76 +353,84 @@ let infer_side_effects asserts_termination ast = let all_nodes = Callgraph.G.nodes cg in let total = List.length all_nodes in - List.iteri (fun i node -> + List.iteri + (fun i node -> Util.progress "Effects (transitive) " (string_of_int (i + 1) ^ "/" ^ string_of_int total) (i + 1) total; match node with | Callgraph.Function id | Callgraph.Letbind id | Callgraph.Mapping id -> - let reachable = Callgraph.G.reachable (NodeSet.singleton node) NodeSet.empty cg in - (* First, a function has any side effects it directly causes *) - let side_effects = match Bindings.find_opt id !direct_effects with Some effs -> effs | None -> EffectSet.empty in - (* Second, a function has any side effects from any reachable callee function *) - let side_effects = - NodeSet.fold (fun node side_effects -> - match Bindings.find_opt (Callgraph.node_id node) !direct_effects with - | Some effs -> EffectSet.union effs side_effects - | None -> side_effects - ) reachable side_effects in - (* Third, if a function or any callee invokes an outcome, it has that effect *) - let side_effects = - NodeSet.filter (function Callgraph.Outcome _ -> true | _ -> false) reachable - |> NodeSet.elements - |> List.map (fun node -> Outcome (Callgraph.node_id node)) - |> EffectSet.of_list - |> EffectSet.union side_effects - in - if is_function node then - function_effects := Bindings.add id side_effects !function_effects - else if is_mapping node then - mapping_effects := Bindings.add id side_effects !mapping_effects - else ( - letbind_effects := Bindings.add id side_effects !letbind_effects - ) + let reachable = Callgraph.G.reachable (NodeSet.singleton node) NodeSet.empty cg in + (* First, a function has any side effects it directly causes *) + let side_effects = + match Bindings.find_opt id !direct_effects with Some effs -> effs | None -> EffectSet.empty + in + (* Second, a function has any side effects from any reachable callee function *) + let side_effects = + NodeSet.fold + (fun node side_effects -> + match Bindings.find_opt (Callgraph.node_id node) !direct_effects with + | Some effs -> EffectSet.union effs side_effects + | None -> side_effects + ) + reachable side_effects + in + (* Third, if a function or any callee invokes an outcome, it has that effect *) + let side_effects = + NodeSet.filter (function Callgraph.Outcome _ -> true | _ -> false) reachable + |> NodeSet.elements + |> List.map (fun node -> Outcome (Callgraph.node_id node)) + |> EffectSet.of_list |> EffectSet.union side_effects + in + if is_function node then function_effects := Bindings.add id side_effects !function_effects + else if is_mapping node then mapping_effects := Bindings.add id side_effects !mapping_effects + else letbind_effects := Bindings.add id side_effects !letbind_effects | _ -> () - ) all_nodes; + ) + all_nodes; - { - functions = !function_effects; - letbinds = !letbind_effects; - mappings = !mapping_effects - } + { functions = !function_effects; letbinds = !letbind_effects; mappings = !mapping_effects } let check_side_effects effect_info ast = let allowed_nonexec = ref IdSet.empty in - List.iter (fun (DEF_aux (aux, _) as def) -> + List.iter + (fun (DEF_aux (aux, _) as def) -> match aux with | DEF_pragma ("non_exec", name, _) -> allowed_nonexec := IdSet.add (mk_id name) !allowed_nonexec | DEF_let _ -> - IdSet.iter (fun id -> - match Bindings.find_opt id effect_info.letbinds with - | Some eff when not (pure eff) -> - raise (Reporting.err_general (id_loc id) - ("Top-level let statement must not have any side effects. Found side effects: " - ^ Util.string_of_list ", " string_of_side_effect (EffectSet.elements eff))) - | _ -> () - ) (ids_of_def def) + IdSet.iter + (fun id -> + match Bindings.find_opt id effect_info.letbinds with + | Some eff when not (pure eff) -> + raise + (Reporting.err_general (id_loc id) + ("Top-level let statement must not have any side effects. Found side effects: " + ^ Util.string_of_list ", " string_of_side_effect (EffectSet.elements eff) + ) + ) + | _ -> () + ) + (ids_of_def def) | DEF_fundef fdef -> - let id = id_of_fundef fdef in - let eff = Bindings.find_opt (id_of_fundef fdef) effect_info.functions |> Option.value ~default:EffectSet.empty in - if non_exec eff && not (IdSet.mem id !allowed_nonexec) then - raise (Reporting.err_general (id_loc id) ("Function " ^ string_of_id id ^ " calls function marked non-executable")) + let id = id_of_fundef fdef in + let eff = + Bindings.find_opt (id_of_fundef fdef) effect_info.functions |> Option.value ~default:EffectSet.empty + in + if non_exec eff && not (IdSet.mem id !allowed_nonexec) then + raise + (Reporting.err_general (id_loc id) + ("Function " ^ string_of_id id ^ " calls function marked non-executable") + ) | _ -> () - ) ast.defs + ) + ast.defs let copy_function_effect id_from effect_info id_to = match Bindings.find_opt id_from effect_info.functions with - | Some eff -> - { effect_info with functions = Bindings.add id_to eff effect_info.functions } + | Some eff -> { effect_info with functions = Bindings.add id_to eff effect_info.functions } | None -> effect_info let add_function_effect id_from effect_info id_to = match Bindings.find_opt id_from effect_info.functions with - | Some eff -> - { effect_info with functions = add_effects id_to eff effect_info.functions } + | Some eff -> { effect_info with functions = add_effects id_to eff effect_info.functions } | None -> effect_info let copy_mapping_to_function id_from effect_info id_to = @@ -441,74 +439,61 @@ let copy_mapping_to_function id_from effect_info id_to = exists - this likely means the function has been manually defined. *) | Some eff -> - let existing_effects = match Bindings.find_opt id_to effect_info.functions with - | Some existing_eff -> existing_eff - | None -> EffectSet.empty in - { effect_info with functions = Bindings.add id_to (EffectSet.union eff existing_effects) effect_info.functions } - | _ -> - effect_info + let existing_effects = + match Bindings.find_opt id_to effect_info.functions with + | Some existing_eff -> existing_eff + | None -> EffectSet.empty + in + { effect_info with functions = Bindings.add id_to (EffectSet.union eff existing_effects) effect_info.functions } + | _ -> effect_info let add_monadic_built_in id effect_info = - { effect_info with - functions = add_effects id (EffectSet.singleton External) effect_info.functions - } + { effect_info with functions = add_effects id (EffectSet.singleton External) effect_info.functions } let rewrite_attach_effects effect_info = let rewrite_lexp_aux ((child_eff, lexp_aux), (l, tannot)) = let env = env_of_tannot tannot in - let eff = match lexp_aux with - | LE_typ (_, id) | LE_id id -> - begin match Env.lookup_id id env with - | Register _ -> monadic_effect - | _ -> no_effect - end + let eff = + match lexp_aux with + | LE_typ (_, id) | LE_id id -> begin + match Env.lookup_id id env with Register _ -> monadic_effect | _ -> no_effect + end | LE_deref _ -> monadic_effect | _ -> no_effect in let eff = union_effects eff child_eff in - eff, LE_aux (lexp_aux, (l, add_effect_annot tannot eff)) + (eff, LE_aux (lexp_aux, (l, add_effect_annot tannot eff))) in let rewrite_e_aux ((child_eff, e_aux), (l, tannot)) = let env = env_of_tannot tannot in - let eff = match e_aux with + let eff = + match e_aux with | E_app (f, _) when string_of_id f = "early_return" -> monadic_effect - | E_app (f, _) -> - begin match Bindings.find_opt f effect_info.functions with - | Some side_effects -> if pure side_effects then no_effect else monadic_effect - | None -> no_effect - end + | E_app (f, _) -> begin + match Bindings.find_opt f effect_info.functions with + | Some side_effects -> if pure side_effects then no_effect else monadic_effect + | None -> no_effect + end | E_lit (L_aux (L_undef, _)) -> monadic_effect - | E_id id -> - begin match Env.lookup_id id env with - | Register _ -> monadic_effect - | _ -> no_effect - end + | E_id id -> begin match Env.lookup_id id env with Register _ -> monadic_effect | _ -> no_effect end | E_throw _ -> monadic_effect | E_exit _ | E_assert _ -> monadic_effect | _ -> no_effect in let eff = union_effects eff child_eff in - eff, E_aux (e_aux, (l, add_effect_annot tannot eff)) + (eff, E_aux (e_aux, (l, add_effect_annot tannot eff))) in - + let rw_exp = - fold_exp { - (compute_exp_alg no_effect union_effects) - with - e_aux = rewrite_e_aux; - le_aux = rewrite_lexp_aux; - } + fold_exp { (compute_exp_alg no_effect union_effects) with e_aux = rewrite_e_aux; le_aux = rewrite_lexp_aux } in rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ exp -> snd (rw_exp exp)) } -let string_of_effectset set = - String.concat ", " (List.map string_of_side_effect (EffectSet.elements set)) +let string_of_effectset set = String.concat ", " (List.map string_of_side_effect (EffectSet.elements set)) let dump_effect_bindings bindings = - Bindings.iter (fun id set -> - Printf.eprintf " %s: %s\n%!" (string_of_id id) (string_of_effectset set)) - bindings + Bindings.iter (fun id set -> Printf.eprintf " %s: %s\n%!" (string_of_id id) (string_of_effectset set)) bindings let dump_effects effect_info = prerr_endline "Function effects:"; diff --git a/src/lib/effects.mli b/src/lib/effects.mli index 363be308b..77839bf8c 100644 --- a/src/lib/effects.mli +++ b/src/lib/effects.mli @@ -90,7 +90,7 @@ end (* Note we intentionally keep the side effect type abstract, and expose some functions on effect sets based on what we actually need. *) - + val throws : EffectSet.t -> bool val pure : EffectSet.t -> bool @@ -108,10 +108,10 @@ val effectful : EffectSet.t -> bool val has_outcome : id -> EffectSet.t -> bool type side_effect_info = { - functions : EffectSet.t Bindings.t; - letbinds : EffectSet.t Bindings.t; - mappings : EffectSet.t Bindings.t - } + functions : EffectSet.t Bindings.t; + letbinds : EffectSet.t Bindings.t; + mappings : EffectSet.t Bindings.t; +} val empty_side_effect_info : side_effect_info @@ -133,6 +133,7 @@ val check_side_effects : side_effect_info -> Type_check.tannot ast -> unit information. The order of arguments is to make it convenient to use with List.fold_left. *) val copy_function_effect : id -> side_effect_info -> id -> side_effect_info + val copy_mapping_to_function : id -> side_effect_info -> id -> side_effect_info (** [add_function_effect id_from info id_to] adds the effect diff --git a/src/lib/elf_loader.ml b/src/lib/elf_loader.ml index e4c709579..03e0bd669 100644 --- a/src/lib/elf_loader.ml +++ b/src/lib/elf_loader.ml @@ -78,22 +78,20 @@ let opt_symbol_map = ref ([] : Elf_file.global_symbol_init_info) type word8 = int -let escape_char c = - if int_of_char c <= 31 then '.' - else if int_of_char c >= 127 then '.' - else c +let escape_char c = if int_of_char c <= 31 then '.' else if int_of_char c >= 127 then '.' else c let hex_line bs = - let hex_char i c = - (if i mod 2 == 0 && i <> 0 then " " else "") ^ Printf.sprintf "%02x" (int_of_char c) - in - String.concat "" (List.mapi hex_char bs) ^ " " ^ String.concat "" (List.map (fun c -> Printf.sprintf "%c" (escape_char c)) bs) + let hex_char i c = (if i mod 2 == 0 && i <> 0 then " " else "") ^ Printf.sprintf "%02x" (int_of_char c) in + String.concat "" (List.mapi hex_char bs) + ^ " " + ^ String.concat "" (List.map (fun c -> Printf.sprintf "%c" (escape_char c)) bs) let break n xs = - let rec helper acc =function + let rec helper acc = function | [] -> List.rev acc - | (_ :: _ as xs) -> helper ([Lem_list.take n xs] @ acc) (Lem_list.drop n xs) - in helper [] xs + | _ :: _ as xs -> helper ([Lem_list.take n xs] @ acc) (Lem_list.drop n xs) + in + helper [] xs let print_segment bs = prerr_endline "0011 2233 4455 6677 8899 aabb ccdd eeff 0123456789abcdef"; @@ -106,41 +104,44 @@ type elf_segs = let read name = let info = Sail_interface.populate_and_obtain_global_symbol_init_info name in prerr_endline "Elf read:"; - let (elf_file, elf_epi, symbol_map) = - begin match info with - | Error.Fail s -> failwith (Printf.sprintf "populate_and_obtain_global_symbol_init_info: %s" s) - | Error.Success ((elf_file: Elf_file.elf_file), - (elf_epi: Sail_interface.executable_process_image), - (symbol_map: Elf_file.global_symbol_init_info)) - -> - (* XXX disabled because it crashes if entry_point overflows an ocaml int :-( - prerr_endline (Sail_interface.string_of_executable_process_image elf_epi);*) - (elf_file, elf_epi, symbol_map) + let elf_file, elf_epi, symbol_map = + begin + match info with + | Error.Fail s -> failwith (Printf.sprintf "populate_and_obtain_global_symbol_init_info: %s" s) + | Error.Success + ( (elf_file : Elf_file.elf_file), + (elf_epi : Sail_interface.executable_process_image), + (symbol_map : Elf_file.global_symbol_init_info) + ) -> + (* XXX disabled because it crashes if entry_point overflows an ocaml int :-( + prerr_endline (Sail_interface.string_of_executable_process_image elf_epi);*) + (elf_file, elf_epi, symbol_map) end in prerr_endline "\nElf segments:"; (* remove all the auto generated segments (they contain only 0s) *) let prune_segments segs = - Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segs in - let (segments, e_entry, _e_machine) = - begin match elf_epi, elf_file with - | (Sail_interface.ELF_Class_32 (segments, e_entry, e_machine), Elf_file.ELF_File_32 _) -> + Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segs + in + let segments, e_entry, _e_machine = + begin + match (elf_epi, elf_file) with + | Sail_interface.ELF_Class_32 (segments, e_entry, e_machine), Elf_file.ELF_File_32 _ -> (ELF32 (prune_segments segments), e_entry, e_machine) - | (Sail_interface.ELF_Class_64 (segments, e_entry, e_machine), Elf_file.ELF_File_64 _) -> + | Sail_interface.ELF_Class_64 (segments, e_entry, e_machine), Elf_file.ELF_File_64 _ -> (ELF64 (prune_segments segments), e_entry, e_machine) - | (_, _) -> failwith "cannot handle ELF file" + | _, _ -> failwith "cannot handle ELF file" end in (segments, e_entry, symbol_map) -let write_sail_lib paddr i byte = - Sail_lib.wram (Big_int.add paddr (Big_int.of_int i)) byte +let write_sail_lib paddr i byte = Sail_lib.wram (Big_int.add paddr (Big_int.of_int i)) byte let write_mem_zeros start len = (* write in order for mem tracing logs *) let i = ref Big_int.zero in - while (Big_int.less !i len) do + while Big_int.less !i len do Sail_lib.wram (Big_int.add start !i) 0; i := Big_int.succ !i done @@ -151,62 +152,66 @@ let write_file chan paddr i byte = let print_seg_info offset base paddr size memsz = prerr_endline "\nLoading Segment"; - prerr_endline ("Segment offset: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 offset))); - prerr_endline ("Segment base address: " ^ (Big_int.to_string base)); + prerr_endline ("Segment offset: " ^ Printf.sprintf "0x%Lx" (Big_int.to_int64 offset)); + prerr_endline ("Segment base address: " ^ Big_int.to_string base); (* NB don't attempt to convert paddr to int64 because on MIPS it is quite likely to exceed signed 64-bit range e.g. addresses beginning 0x9.... Really need to_uint64 or to_string_hex but lem doesn't have them. *) - prerr_endline ("Segment physical address: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 paddr))); - prerr_endline ("Segment size: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 size))); - prerr_endline ("Segment memsz: " ^ (Printf.sprintf "0x%Lx" (Big_int.to_int64 memsz))) + prerr_endline ("Segment physical address: " ^ Printf.sprintf "0x%Lx" (Big_int.to_int64 paddr)); + prerr_endline ("Segment size: " ^ Printf.sprintf "0x%Lx" (Big_int.to_int64 size)); + prerr_endline ("Segment memsz: " ^ Printf.sprintf "0x%Lx" (Big_int.to_int64 memsz)) -let load_segment ?writer:(writer=write_sail_lib) bs paddr base offset size memsz = +let load_segment ?(writer = write_sail_lib) bs paddr base offset size memsz = print_seg_info offset base paddr size memsz; print_segment bs; List.iteri (writer paddr) (List.rev_map int_of_char (List.rev (Byte_sequence.char_list_of_byte_sequence bs))); write_mem_zeros (Big_int.add paddr size) (Big_int.sub memsz size) -let load_elf ?writer:(writer=write_sail_lib) name = +let load_elf ?(writer = write_sail_lib) name = let segments, e_entry, symbol_map = read name in opt_elf_entry := e_entry; opt_symbol_map := symbol_map; - (if List.mem_assoc "tohost" symbol_map then - let (_, _, tohost_addr, _, _) = List.assoc "tohost" symbol_map in - opt_elf_tohost := tohost_addr); - (match segments with - | ELF64 segs -> - List.iter (fun seg -> - let open Elf_interpreted_segment in - let bs = seg.elf64_segment_body in - let paddr = seg.elf64_segment_paddr in - let base = seg.elf64_segment_base in - let offset = seg.elf64_segment_offset in - let size = seg.elf64_segment_size in - let memsz = seg.elf64_segment_memsz in - load_segment ~writer:writer bs paddr base offset size memsz) - segs; - opt_elf_class := ELF_Class_64 - | ELF32 segs -> - List.iter (fun seg -> - let open Elf_interpreted_segment in - let bs = seg.elf32_segment_body in - let paddr = seg.elf32_segment_paddr in - let base = seg.elf32_segment_base in - let offset = seg.elf32_segment_offset in - let size = seg.elf32_segment_size in - let memsz = seg.elf32_segment_memsz in - load_segment ~writer:writer bs paddr base offset size memsz) - segs; - opt_elf_class := ELF_Class_32 - ) - -let load_binary ?writer:(writer=write_sail_lib) addr name = + if List.mem_assoc "tohost" symbol_map then ( + let _, _, tohost_addr, _, _ = List.assoc "tohost" symbol_map in + opt_elf_tohost := tohost_addr + ); + match segments with + | ELF64 segs -> + List.iter + (fun seg -> + let open Elf_interpreted_segment in + let bs = seg.elf64_segment_body in + let paddr = seg.elf64_segment_paddr in + let base = seg.elf64_segment_base in + let offset = seg.elf64_segment_offset in + let size = seg.elf64_segment_size in + let memsz = seg.elf64_segment_memsz in + load_segment ~writer bs paddr base offset size memsz + ) + segs; + opt_elf_class := ELF_Class_64 + | ELF32 segs -> + List.iter + (fun seg -> + let open Elf_interpreted_segment in + let bs = seg.elf32_segment_body in + let paddr = seg.elf32_segment_paddr in + let base = seg.elf32_segment_base in + let offset = seg.elf32_segment_offset in + let size = seg.elf32_segment_size in + let memsz = seg.elf32_segment_memsz in + load_segment ~writer bs paddr base offset size memsz + ) + segs; + opt_elf_class := ELF_Class_32 + +let load_binary ?(writer = write_sail_lib) addr name = let f = open_in_bin name in let buf = Buffer.create 1024 in try while true do let char = input_char f in - Buffer.add_char buf char; + Buffer.add_char buf char done; assert false with @@ -215,22 +220,26 @@ let load_binary ?writer:(writer=write_sail_lib) addr name = close_in f end | exc -> - close_in f; - raise exc + close_in f; + raise exc (* The sail model can access this by externing a unit -> int function as Elf_loader.elf_entry. *) let elf_entry () = !opt_elf_entry + (* Used by RISCV sail model test harness for exiting test *) let elf_tohost () = !opt_elf_tohost + (* Used to check last loaded elf class. *) let elf_class () = !opt_elf_class + (* Lookup the address for a symbol *) let elf_symbol symbol = - if List.mem_assoc symbol !opt_symbol_map then - let (_, _, addr, _, _) = List.assoc symbol !opt_symbol_map in + if List.mem_assoc symbol !opt_symbol_map then ( + let _, _, addr, _, _ = List.assoc symbol !opt_symbol_map in Some addr + ) else None + (* Get all symbols *) -let elf_symbols () = - !opt_symbol_map +let elf_symbols () = !opt_symbol_map diff --git a/src/lib/error_format.ml b/src/lib/error_format.ml index de06073d4..dbdfa62ff 100644 --- a/src/lib/error_format.ml +++ b/src/lib/error_format.ml @@ -67,14 +67,16 @@ let rec skip_lines in_chan = function | n when n <= 0 -> () - | n -> ignore (input_line in_chan); skip_lines in_chan (n - 1) + | n -> + ignore (input_line in_chan); + skip_lines in_chan (n - 1) let rec read_lines in_chan = function | n when n <= 0 -> [] | n -> - let l = input_line in_chan in - let ls = read_lines in_chan (n - 1) in - l :: ls + let l = input_line in_chan in + let ls = read_lines in_chan (n - 1) in + l :: ls (* Replace unprintable ASCII characters with an escape sequence. Optional color argument lets us change the color of the @@ -86,65 +88,45 @@ let rec read_lines in_chan = function let error_tabwidth = 4 -let unprintable_notation ?(color=(fun x -> x)) c = +let unprintable_notation ?(color = fun x -> x) c = let n = Char.code c in - if n = 9 then - Some (String.make error_tabwidth ' ') - else if n <= 30 || n = 127 then - Some (color (Char.escaped c)) - else - None + if n = 9 then Some (String.make error_tabwidth ' ') + else if n <= 30 || n = 127 then Some (color (Char.escaped c)) + else None -let unprintable_escape ?(color=(fun x -> x)) str = +let unprintable_escape ?(color = fun x -> x) str = let rec adjuster adjs cnum = - match adjs with - | (i, shift) :: rest -> - if cnum > i then cnum + shift else adjuster rest cnum - | [] -> - cnum + match adjs with (i, shift) :: rest -> if cnum > i then cnum + shift else adjuster rest cnum | [] -> cnum in let buf = Buffer.create (String.length str) in let shift = ref 0 in let adjusts = ref [] in - String.iteri (fun i c -> + String.iteri + (fun i c -> match unprintable_notation c with | Some escaped -> - shift := !shift + (String.length escaped - 1); - adjusts := (i, !shift) :: !adjusts; - Buffer.add_string buf (color escaped) + shift := !shift + (String.length escaped - 1); + adjusts := (i, !shift) :: !adjusts; + Buffer.add_string buf (color escaped) | None -> Buffer.add_char buf c - ) str; - Buffer.contents buf, adjuster !adjusts + ) + str; + (Buffer.contents buf, adjuster !adjusts) + +type formatter = { indent : string; endline : string -> unit; loc_color : string -> string } -type formatter = { - indent : string; - endline : string -> unit; - loc_color : string -> string - } +let err_formatter = { indent = ""; endline = prerr_endline; loc_color = Util.red } -let err_formatter = { - indent = ""; - endline = prerr_endline; - loc_color = Util.red - } +let buffer_formatter b = { indent = ""; endline = (fun str -> Buffer.add_string b (str ^ "\n")); loc_color = Util.red } -let buffer_formatter b = { - indent = ""; - endline = (fun str -> Buffer.add_string b (str ^ "\n")); - loc_color = Util.red - } +let format_endline str ppf = + ppf.endline (ppf.indent ^ Str.global_replace (Str.regexp_string "\n") ("\n" ^ ppf.indent) str) -let format_endline str ppf = ppf.endline (ppf.indent ^ (Str.global_replace (Str.regexp_string "\n") ("\n" ^ ppf.indent) str)) - let underline_single color cnum_from cnum_to = - if (cnum_from + 1) >= cnum_to then - Util.(String.make cnum_from ' ' ^ clear (color "^")) - else - Util.(String.make cnum_from ' ' ^ clear (color ("^" ^ String.make (cnum_to - cnum_from - 2) '-' ^ "^"))) + if cnum_from + 1 >= cnum_to then Util.(String.make cnum_from ' ' ^ clear (color "^")) + else Util.(String.make cnum_from ' ' ^ clear (color ("^" ^ String.make (cnum_to - cnum_from - 2) '-' ^ "^"))) -let format_hint color = function - | Some hint -> " " ^ Util.(hint |> color |> clear) - | None -> "" +let format_hint color = function Some hint -> " " ^ Util.(hint |> color |> clear) | None -> "" let format_code_single' prefix hint fname in_chan lnum cnum_from cnum_to contents ppf = skip_lines in_chan (lnum - 1); @@ -154,28 +136,31 @@ let format_code_single' prefix hint fname in_chan lnum cnum_from cnum_to content format_endline (Printf.sprintf "%s%s:%d.%d-%d:" prefix Util.(fname |> cyan |> clear) lnum cnum_from cnum_to) ppf; let line, adjust = unprintable_escape ~color:Util.(fun e -> e |> magenta |> clear) line in format_endline (line_prefix ^ line) ppf; - format_endline (blank_prefix ^ underline_single ppf.loc_color (adjust cnum_from) (adjust cnum_to) ^ format_hint ppf.loc_color hint) ppf; + format_endline + (blank_prefix ^ underline_single ppf.loc_color (adjust cnum_from) (adjust cnum_to) ^ format_hint ppf.loc_color hint) + ppf; contents { ppf with indent = blank_prefix ^ " " } let underline_double_from color cnum_from eol = Util.(String.make cnum_from ' ' ^ clear (color ("^" ^ String.make (eol - cnum_from - 1) '-'))) let underline_double_to color cnum_to = - if cnum_to = 0 then - Util.(clear (color "^")) - else - Util.(clear (color (String.make (cnum_to - 1) '-' ^ "^"))) - + if cnum_to = 0 then Util.(clear (color "^")) else Util.(clear (color (String.make (cnum_to - 1) '-' ^ "^"))) + let format_code_double' prefix fname in_chan lnum_from cnum_from lnum_to cnum_to contents ppf = skip_lines in_chan (lnum_from - 1); let line_from = input_line in_chan in skip_lines in_chan (lnum_to - lnum_from - 1); let line_to = input_line in_chan in let line_to_prefix = string_of_int lnum_to ^ Util.(clear (cyan " |")) in - let line_from_padding = String.make (String.length (string_of_int lnum_to) - String.length (string_of_int lnum_from)) ' ' in + let line_from_padding = + String.make (String.length (string_of_int lnum_to) - String.length (string_of_int lnum_from)) ' ' + in let line_from_prefix = string_of_int lnum_from ^ line_from_padding ^ Util.(clear (cyan " |")) in let blank_prefix = String.make (String.length (string_of_int lnum_to)) ' ' ^ Util.(clear (ppf.loc_color " |")) in - format_endline (Printf.sprintf "%s%s:%d.%d-%d.%d:" prefix Util.(fname |> cyan |> clear) lnum_from cnum_from lnum_to cnum_to) ppf; + format_endline + (Printf.sprintf "%s%s:%d.%d-%d.%d:" prefix Util.(fname |> cyan |> clear) lnum_from cnum_from lnum_to cnum_to) + ppf; let cnum_end = String.length line_from in let line_from, adjust = unprintable_escape ~color:Util.(fun e -> e |> magenta |> clear) line_from in format_endline (line_from_prefix ^ line_from) ppf; @@ -189,43 +174,48 @@ let format_code_single_fallback prefix fname lnum cnum_from cnum_to contents ppf let blank_prefix = String.make (String.length (string_of_int lnum)) ' ' ^ Util.(clear (ppf.loc_color " |")) in format_endline (Printf.sprintf "%s%s:%d.%d-%d:" prefix Util.(fname |> cyan |> clear) lnum cnum_from cnum_to) ppf; contents { ppf with indent = blank_prefix ^ " " } - + let format_code_single prefix hint fname lnum cnum_from cnum_to contents ppf = try let in_chan = open_in fname in begin - try format_code_single' prefix hint fname in_chan lnum cnum_from cnum_to contents ppf; close_in in_chan - with - | _ -> - format_code_single_fallback prefix fname lnum cnum_from cnum_to contents ppf; - close_in_noerr in_chan + try + format_code_single' prefix hint fname in_chan lnum cnum_from cnum_to contents ppf; + close_in in_chan + with _ -> + format_code_single_fallback prefix fname lnum cnum_from cnum_to contents ppf; + close_in_noerr in_chan end - with - | _ -> format_code_single_fallback prefix fname lnum cnum_from cnum_to contents ppf + with _ -> format_code_single_fallback prefix fname lnum cnum_from cnum_to contents ppf let format_code_double_fallback prefix fname lnum_from cnum_from lnum_to cnum_to contents ppf = let blank_prefix = String.make (String.length (string_of_int lnum_to)) ' ' ^ Util.(clear (ppf.loc_color " |")) in - format_endline (Printf.sprintf "%s%s:%d.%d-%d.%d:" prefix Util.(fname |> cyan |> clear) lnum_from cnum_from lnum_to cnum_to) ppf; + format_endline + (Printf.sprintf "%s%s:%d.%d-%d.%d:" prefix Util.(fname |> cyan |> clear) lnum_from cnum_from lnum_to cnum_to) + ppf; contents { ppf with indent = blank_prefix ^ " " } - + let format_code_double prefix fname lnum_from cnum_from lnum_to cnum_to contents ppf = try let in_chan = open_in fname in begin - try format_code_double' prefix fname in_chan lnum_from cnum_from lnum_to cnum_to contents ppf; close_in in_chan - with - | _ -> - format_code_double_fallback prefix fname lnum_from cnum_from lnum_to cnum_to contents ppf; - close_in_noerr in_chan + try + format_code_double' prefix fname in_chan lnum_from cnum_from lnum_to cnum_to contents ppf; + close_in in_chan + with _ -> + format_code_double_fallback prefix fname lnum_from cnum_from lnum_to cnum_to contents ppf; + close_in_noerr in_chan end - with - | _ -> format_code_double_fallback prefix fname lnum_from cnum_from lnum_to cnum_to contents ppf + with _ -> format_code_double_fallback prefix fname lnum_from cnum_from lnum_to cnum_to contents ppf let format_pos prefix hint p1 p2 contents ppf = let open Lexing in - if p1.pos_lnum == p2.pos_lnum - then format_code_single prefix hint p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) (p2.pos_cnum - p2.pos_bol) contents ppf - else format_code_double prefix p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum (p2.pos_cnum - p2.pos_bol) contents ppf + if p1.pos_lnum == p2.pos_lnum then + format_code_single prefix hint p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) (p2.pos_cnum - p2.pos_bol) + contents ppf + else + format_code_double prefix p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum (p2.pos_cnum - p2.pos_bol) + contents ppf let rec format_loc prefix hint l contents = match l with @@ -233,11 +223,13 @@ let rec format_loc prefix hint l contents = | Parse_ast.Range (p1, p2) -> format_pos prefix hint p1 p2 contents | Parse_ast.Unique (_, l) -> format_loc prefix hint l contents | Parse_ast.Hint (hint', l1, l2) -> - fun ppf -> format_loc prefix (Some hint') l1 (fun _ -> ()) { ppf with loc_color = Util.green }; format_loc prefix hint l2 contents ppf + fun ppf -> + format_loc prefix (Some hint') l1 (fun _ -> ()) { ppf with loc_color = Util.green }; + format_loc prefix hint l2 contents ppf | Parse_ast.Generated l -> - fun ppf -> - format_endline "Code generated nearby:" ppf; - format_loc prefix hint l contents ppf + fun ppf -> + format_endline "Code generated nearby:" ppf; + format_loc prefix hint l contents ppf type message = | Location of string * string option * Parse_ast.l * message @@ -250,16 +242,13 @@ let bullet = Util.(clear (blue "*")) let rec format_message msg ppf = match msg with - | Location (prefix, hint, l, msg) -> - format_loc prefix hint l (format_message msg) ppf - | Line str -> - format_endline str ppf - | Seq messages -> - List.iter (fun msg -> format_message msg ppf) messages + | Location (prefix, hint, l, msg) -> format_loc prefix hint l (format_message msg) ppf + | Line str -> format_endline str ppf + | Seq messages -> List.iter (fun msg -> format_message msg ppf) messages | List list -> - let format_list_item ppf (header, msg) = - format_endline (Util.(clear (blue "*")) ^ " " ^ header) ppf; - format_message msg { ppf with indent = ppf.indent ^ " " } - in - List.iter (format_list_item ppf) list + let format_list_item ppf (header, msg) = + format_endline (Util.(clear (blue "*")) ^ " " ^ header) ppf; + format_message msg { ppf with indent = ppf.indent ^ " " } + in + List.iter (format_list_item ppf) list | With (f, msg) -> format_message msg (f ppf) diff --git a/src/lib/format_sail.ml b/src/lib/format_sail.ml index 2139b9857..9309f6188 100644 --- a/src/lib/format_sail.ml +++ b/src/lib/format_sail.ml @@ -74,20 +74,17 @@ let rec map_last f = function | [] -> [] | [x] -> [f true x] | x :: xs -> - let x = f false x in - x :: map_last f xs + let x = f false x in + x :: map_last f xs -let line_comment_opt = function - | Comment (Lexer.Comment_line, _, _, contents) -> Some contents - | _ -> None +let line_comment_opt = function Comment (Lexer.Comment_line, _, _, contents) -> Some contents | _ -> None (** We implement a small wrapper around a subset of the PPrint API to track line breaks and dedents (points where the indentation level decreases), re-implementing a few core combinators. *) module PPrintWrapper = struct - type hardline_type = Required | Desired - + type document = | Empty | Char of char @@ -99,60 +96,39 @@ module PPrintWrapper = struct | Cat of document * document | Hardline of hardline_type | Ifflat of document * document - - type linebreak_info = { - hardlines : (int * int * hardline_type) Queue.t; - dedents: (int * int * int) Queue.t; - } - let empty_linebreak_info () = { - hardlines = Queue.create (); - dedents = Queue.create (); - } + type linebreak_info = { hardlines : (int * int * hardline_type) Queue.t; dedents : (int * int * int) Queue.t } + + let empty_linebreak_info () = { hardlines = Queue.create (); dedents = Queue.create () } let rec to_pprint lb_info = let open PPrint in function - | Empty -> - empty - | Char c -> - char c - | String s -> - string s - | Utf8string s -> - utf8string s - | Group doc -> - group (to_pprint lb_info doc) + | Empty -> empty + | Char c -> char c + | String s -> string s + | Utf8string s -> utf8string s + | Group doc -> group (to_pprint lb_info doc) | Nest (n, doc) -> - let doc = to_pprint lb_info doc in - ifflat (nest n doc) (range (fun (_, (l, c)) -> Queue.add (l, c, n) lb_info.dedents) (nest n doc)) + let doc = to_pprint lb_info doc in + ifflat (nest n doc) (range (fun (_, (l, c)) -> Queue.add (l, c, n) lb_info.dedents) (nest n doc)) | Align doc -> - let doc = to_pprint lb_info doc in - ifflat (align doc) (range (fun ((_, amount), (l, c)) -> Queue.add (l, c, amount) lb_info.dedents) (align doc)) + let doc = to_pprint lb_info doc in + ifflat (align doc) (range (fun ((_, amount), (l, c)) -> Queue.add (l, c, amount) lb_info.dedents) (align doc)) | Cat (doc1, doc2) -> - let doc1 = to_pprint lb_info doc1 in - let doc2 = to_pprint lb_info doc2 in - doc1 ^^ doc2 - | Hardline t -> - range (fun ((l, c), _) -> Queue.add (l, c, t) lb_info.hardlines) hardline + let doc1 = to_pprint lb_info doc1 in + let doc2 = to_pprint lb_info doc2 in + doc1 ^^ doc2 + | Hardline t -> range (fun ((l, c), _) -> Queue.add (l, c, t) lb_info.hardlines) hardline | Ifflat (doc1, doc2) -> - let doc1 = to_pprint lb_info doc1 in - let doc2 = to_pprint lb_info doc2 in - ifflat doc1 doc2 + let doc1 = to_pprint lb_info doc1 in + let doc2 = to_pprint lb_info doc2 in + ifflat doc1 doc2 - let (^^) doc1 doc2 = - match doc1, doc2 with - | Empty, _ -> doc2 - | _, Empty -> doc1 - | _, _ -> Cat (doc1, doc2) + let ( ^^ ) doc1 doc2 = match (doc1, doc2) with Empty, _ -> doc2 | _, Empty -> doc1 | _, _ -> Cat (doc1, doc2) let repeat n doc = - let rec go n acc = - if n = 0 then - acc - else - go (n - 1) (doc ^^ acc) - in + let rec go n acc = if n = 0 then acc else go (n - 1) (doc ^^ acc) in go n Empty let blank n = repeat n (Char ' ') @@ -185,81 +161,56 @@ module PPrintWrapper = struct let ifflat doc1 doc2 = Ifflat (doc1, doc2) - let separate_map sep f xs = - Util.fold_left_index (fun n acc x -> - if n = 0 then f x else acc ^^ sep ^^ f x - ) Empty xs - + let separate_map sep f xs = Util.fold_left_index (fun n acc x -> if n = 0 then f x else acc ^^ sep ^^ f x) Empty xs + let separate sep xs = separate_map sep (fun x -> x) xs let concat_map_last f xs = - Util.fold_left_index_last (fun n last acc x -> - if n = 0 then - f last x - else - acc ^^ f last x - ) Empty xs + Util.fold_left_index_last (fun n last acc x -> if n = 0 then f last x else acc ^^ f last x) Empty xs - let prefix n b x y = - Group (x ^^ Nest (n, break b ^^ y)) + let prefix n b x y = Group (x ^^ Nest (n, break b ^^ y)) - let infix n b op x y = - prefix n b (x ^^ blank b ^^ op) y + let infix n b op x y = prefix n b (x ^^ blank b ^^ op) y - let surround n b opening contents closing = - opening ^^ Nest (n, break b ^^ contents) ^^ break b ^^ closing + let surround n b opening contents closing = opening ^^ Nest (n, break b ^^ contents) ^^ break b ^^ closing let repeat n doc = - let rec go n acc = - if n = 0 then - acc - else - go (n - 1) (doc ^^ acc) - in + let rec go n acc = if n = 0 then acc else go (n - 1) (doc ^^ acc) in go n empty let lines s = List.map string (Util.split_on_char '\n' s) let block_comment_lines col s = let lines = Util.split_on_char '\n' s in - List.mapi (fun n line -> - if n = 0 || col > String.length line then ( - string line - ) else ( + List.mapi + (fun n line -> + if n = 0 || col > String.length line then string line + else ( (* Check we aren't deleting any content when adjusting the indentation of a block comment. *) let prefix = String.sub line 0 col in - if prefix = String.make col ' ' then ( - string (String.sub line col (String.length line - col)) - ) else ( - (* TODO: Maybe we should provide a warning here? *) + if prefix = String.make col ' ' then string (String.sub line col (String.length line - col)) + else (* TODO: Maybe we should provide a warning here? *) string line - ) ) - ) lines - + ) + lines end open PPrintWrapper -let doc_id (Id_aux (id_aux, _)) = - string (match id_aux with - | Id v -> v - | Operator op -> "operator " ^ op) +let doc_id (Id_aux (id_aux, _)) = string (match id_aux with Id v -> v | Operator op -> "operator " ^ op) type opts = { - (* Controls the bracketing of operators by underapproximating the - precedence level of the grammar as we print *) - precedence : int; - (* True if we are in a statement-like context. Controls how - if-then-else statements are formatted *) - statement : bool - } - -let default_opts = { - precedence = 10; - statement = true - } + (* Controls the bracketing of operators by underapproximating the + precedence level of the grammar as we print *) + precedence : int; + (* True if we are in a statement-like context. Controls how + if-then-else statements are formatted *) + statement : bool; +} + +let default_opts = { precedence = 10; statement = true } (* atomic lowers the allowed precedence of binary operators to zero, forcing them to always be bracketed *) @@ -274,7 +225,7 @@ let nonatomic opts = { opts with precedence = 10 } let subatomic opts = { opts with precedence = -1 } let precedence n opts = { opts with precedence = n } - + let atomic_parens opts doc = if opts.precedence <= 0 then parens doc else doc (* While everything in Sail is an expression, for formatting we @@ -298,35 +249,30 @@ let expression_like opts = { opts with statement = false } let statement_like opts = { opts with statement = true } let operator_precedence = function - | "=" -> 10, atomic, nonatomic, 1 - | ":" -> 0, subatomic, subatomic, 1 - | ".." -> 10, atomic, atomic, 0 - | "@" -> 6, precedence 5, precedence 6, 1 - | _ -> 10, subatomic, subatomic, 1 + | "=" -> (10, atomic, nonatomic, 1) + | ":" -> (0, subatomic, subatomic, 1) + | ".." -> (10, atomic, atomic, 0) + | "@" -> (6, precedence 5, precedence 6, 1) + | _ -> (10, subatomic, subatomic, 1) -let intersperse_operator_precedence = function - | "@" -> 6, precedence 5 - | _ -> 10, subatomic +let intersperse_operator_precedence = function "@" -> (6, precedence 5) | _ -> (10, subatomic) let ternary_operator_precedence = function - | ("..", "=") -> 0, atomic, atomic, nonatomic - | (":", "=") -> 0, atomic, nonatomic, nonatomic - | _ -> 10, subatomic, subatomic, subatomic + | "..", "=" -> (0, atomic, atomic, nonatomic) + | ":", "=" -> (0, atomic, nonatomic, nonatomic) + | _ -> (10, subatomic, subatomic, subatomic) let unary_operator_precedence = function - | "throw" -> 0, nonatomic, space - | "return" -> 0, nonatomic, space - | "internal_return" -> 0, nonatomic, space - | "*" -> 10, atomic, empty - | "-" -> 10, atomic, empty - | "2^" -> 10, atomic, empty - | _ -> 10, subatomic, empty - -let can_hang chunks = - match Queue.peek_opt chunks with - | Some (Comment _) -> false - | _ -> true - + | "throw" -> (0, nonatomic, space) + | "return" -> (0, nonatomic, space) + | "internal_return" -> (0, nonatomic, space) + | "*" -> (10, atomic, empty) + | "-" -> (10, atomic, empty) + | "2^" -> (10, atomic, empty) + | _ -> (10, subatomic, empty) + +let can_hang chunks = match Queue.peek_opt chunks with Some (Comment _) -> false | _ -> true + let opt_delim s = ifflat empty (string s) let softline = break 0 @@ -338,346 +284,345 @@ let surround_hardline h n b opening contents closing = let b = if h then hardline else break b in group (opening ^^ nest n (b ^^ contents) ^^ b ^^ closing) -type config = { - indent : int; - preserve_structure : bool; - line_width : int; - ribbon_width : float; - } - -let default_config = { - indent = 4; - preserve_structure = false; - line_width = 120; - ribbon_width = 1.; - } - -let known_key k = - k = "indent" - || k = "preserve_structure" - || k = "line_width" - || k = "ribbon_width" +type config = { indent : int; preserve_structure : bool; line_width : int; ribbon_width : float } + +let default_config = { indent = 4; preserve_structure = false; line_width = 120; ribbon_width = 1. } + +let known_key k = k = "indent" || k = "preserve_structure" || k = "line_width" || k = "ribbon_width" let int_option k = function | `Int n -> Some n | json -> - Reporting.simple_warn (Printf.sprintf "Argument for key %s must be an integer, got %s instead. Using default value." k (Yojson.Basic.to_string json)); - None + Reporting.simple_warn + (Printf.sprintf "Argument for key %s must be an integer, got %s instead. Using default value." k + (Yojson.Basic.to_string json) + ); + None let bool_option k = function | `Bool n -> Some n | json -> - Reporting.simple_warn (Printf.sprintf "Argument for key %s must be a boolean, got %s instead. Using default value." k (Yojson.Basic.to_string json)); - None + Reporting.simple_warn + (Printf.sprintf "Argument for key %s must be a boolean, got %s instead. Using default value." k + (Yojson.Basic.to_string json) + ); + None let float_option k = function | `Int n -> Some (float_of_int n) | `Float n -> Some n | json -> - Reporting.simple_warn (Printf.sprintf "Argument for key %s must be a number, got %s instead. Using default value." k (Yojson.Basic.to_string json)); - None - -let get_option ~key:k ~keys:ks ~read:read ~default:d = - List.assoc_opt k ks |> (fun opt -> Option.bind opt (read k)) |> Option.value ~default:d - + Reporting.simple_warn + (Printf.sprintf "Argument for key %s must be a number, got %s instead. Using default value." k + (Yojson.Basic.to_string json) + ); + None + +let get_option ~key:k ~keys:ks ~read ~default:d = + List.assoc_opt k ks |> (fun opt -> Option.bind opt (read k)) |> Option.value ~default:d + let config_from_json (json : Yojson.Basic.t) = match json with | `Assoc keys -> - begin match List.find_opt (fun (k, _) -> not (known_key k)) keys with - | Some (k, _) -> - Reporting.simple_warn (Printf.sprintf "Unknown key %s in formatting config" k) - | None -> () - end; - { indent = get_option ~key:"indent" ~keys:keys ~read:int_option ~default:default_config.indent; - preserve_structure = get_option ~key:"preserve_structure" ~keys:keys ~read:bool_option ~default:default_config.preserve_structure; - line_width = get_option ~key:"line_width" ~keys:keys ~read:int_option ~default:default_config.line_width; - ribbon_width = get_option ~key:"ribbon_width" ~keys:keys ~read:float_option ~default:default_config.ribbon_width; - } - | _ -> - raise (Reporting.err_general Parse_ast.Unknown "Invalid formatting configuration") + begin + match List.find_opt (fun (k, _) -> not (known_key k)) keys with + | Some (k, _) -> Reporting.simple_warn (Printf.sprintf "Unknown key %s in formatting config" k) + | None -> () + end; + { + indent = get_option ~key:"indent" ~keys ~read:int_option ~default:default_config.indent; + preserve_structure = + get_option ~key:"preserve_structure" ~keys ~read:bool_option ~default:default_config.preserve_structure; + line_width = get_option ~key:"line_width" ~keys ~read:int_option ~default:default_config.line_width; + ribbon_width = get_option ~key:"ribbon_width" ~keys ~read:float_option ~default:default_config.ribbon_width; + } + | _ -> raise (Reporting.err_general Parse_ast.Unknown "Invalid formatting configuration") module type CONFIG = sig val config : config end -module Make(Config : CONFIG) = struct +module Make (Config : CONFIG) = struct let indent = Config.config.indent let preserve_structure = Config.config.preserve_structure - let rec doc_chunk ?(ungroup_tuple=false) ?(toplevel=false) opts = function + let rec doc_chunk ?(ungroup_tuple = false) ?(toplevel = false) opts = function | Atom s -> string s | Chunks chunks -> doc_chunks opts chunks | Delim s -> string s ^^ space | Opt_delim s -> opt_delim s - | String_literal s -> - utf8string ("\"" ^ String.escaped s ^ "\"") + | String_literal s -> utf8string ("\"" ^ String.escaped s ^ "\"") | App (id, args) -> - doc_id id - ^^ group (surround indent 0 (char '(') (separate_map softline (doc_chunks (opts |> nonatomic |> expression_like)) args) (char ')')) + doc_id id + ^^ group + (surround indent 0 (char '(') + (separate_map softline (doc_chunks (opts |> nonatomic |> expression_like)) args) + (char ')') + ) | Tuple (l, r, spacing, args) -> - let group_fn = if ungroup_tuple then (fun x -> x) else group in - group_fn (surround indent spacing (string l) (separate_map softline (doc_chunks (nonatomic opts)) args) (string r)) + let group_fn = if ungroup_tuple then fun x -> x else group in + group_fn + (surround indent spacing (string l) (separate_map softline (doc_chunks (nonatomic opts)) args) (string r)) | Intersperse (op, args) -> - let outer_prec, prec = intersperse_operator_precedence op in - let doc = group (separate_map (space ^^ string op ^^ space) (doc_chunks (opts |> prec |> expression_like)) args) in - if outer_prec > opts.precedence then ( - parens doc - ) else ( - doc - ) - | Spacer (line, n) -> - if line then - repeat n hardline - else - repeat n space + let outer_prec, prec = intersperse_operator_precedence op in + let doc = + group (separate_map (space ^^ string op ^^ space) (doc_chunks (opts |> prec |> expression_like)) args) + in + if outer_prec > opts.precedence then parens doc else doc + | Spacer (line, n) -> if line then repeat n hardline else repeat n space | Unary (op, exp) -> - let outer_prec, inner_prec, spacing = unary_operator_precedence op in - let doc = string op ^^ spacing ^^ doc_chunks (opts |> inner_prec |> expression_like) exp in - if outer_prec > opts.precedence then ( - parens doc - ) else ( - doc - ) + let outer_prec, inner_prec, spacing = unary_operator_precedence op in + let doc = string op ^^ spacing ^^ doc_chunks (opts |> inner_prec |> expression_like) exp in + if outer_prec > opts.precedence then parens doc else doc | Binary (lhs, op, rhs) -> - let outer_prec, lhs_prec, rhs_prec, spacing = operator_precedence op in - let doc = - infix indent spacing (string op) - (doc_chunks (opts |> lhs_prec |> expression_like) lhs) - (doc_chunks (opts |> rhs_prec |> expression_like) rhs) - in - if outer_prec > opts.precedence then ( - parens doc - ) else ( - doc - ) + let outer_prec, lhs_prec, rhs_prec, spacing = operator_precedence op in + let doc = + infix indent spacing (string op) + (doc_chunks (opts |> lhs_prec |> expression_like) lhs) + (doc_chunks (opts |> rhs_prec |> expression_like) rhs) + in + if outer_prec > opts.precedence then parens doc else doc | Ternary (x, op1, y, op2, z) -> - let outer_prec, x_prec, y_prec, z_prec = ternary_operator_precedence (op1, op2) in - let doc = - prefix indent 1 (doc_chunks (opts |> x_prec |> expression_like) x - ^^ space ^^ string op1 ^^ space - ^^ doc_chunks (opts |> y_prec |> expression_like) y - ^^ space ^^ string op2) - (doc_chunks (opts |> z_prec |> expression_like) z) - in - if outer_prec > opts.precedence then ( - parens doc - ) else ( - doc - ) + let outer_prec, x_prec, y_prec, z_prec = ternary_operator_precedence (op1, op2) in + let doc = + prefix indent 1 + (doc_chunks (opts |> x_prec |> expression_like) x + ^^ space ^^ string op1 ^^ space + ^^ doc_chunks (opts |> y_prec |> expression_like) y + ^^ space ^^ string op2 + ) + (doc_chunks (opts |> z_prec |> expression_like) z) + in + if outer_prec > opts.precedence then parens doc else doc | If_then_else (bracing, i, t, e) -> - let insert_braces = opts.statement || bracing.then_brace || bracing.else_brace in - let i = doc_chunks (opts |> nonatomic |> expression_like) i in - let t = - if insert_braces && not preserve_structure && not bracing.then_brace then ( - doc_chunk opts (Block (true, [t])) - ) else ( - doc_chunks (opts |> nonatomic |> expression_like) t - ) in - let e = - if insert_braces && not preserve_structure && not bracing.else_brace then ( - doc_chunk opts (Block (true, [e])) - ) else ( - doc_chunks (opts |> nonatomic |> expression_like) e - ) in - separate space [string "if"; i; string "then"; t; string "else"; e] - |> atomic_parens opts + let insert_braces = opts.statement || bracing.then_brace || bracing.else_brace in + let i = doc_chunks (opts |> nonatomic |> expression_like) i in + let t = + if insert_braces && (not preserve_structure) && not bracing.then_brace then doc_chunk opts (Block (true, [t])) + else doc_chunks (opts |> nonatomic |> expression_like) t + in + let e = + if insert_braces && (not preserve_structure) && not bracing.else_brace then doc_chunk opts (Block (true, [e])) + else doc_chunks (opts |> nonatomic |> expression_like) e + in + separate space [string "if"; i; string "then"; t; string "else"; e] |> atomic_parens opts | If_then (bracing, i, t) -> - let i = doc_chunks (opts |> nonatomic |> expression_like) i in - let t = - if opts.statement && not preserve_structure && not bracing then ( - doc_chunk opts (Block (true, [t])) - ) else ( - doc_chunks (opts |> nonatomic |> expression_like) t - ) in - separate space [string "if"; i; string "then"; t] - |> atomic_parens opts + let i = doc_chunks (opts |> nonatomic |> expression_like) i in + let t = + if opts.statement && (not preserve_structure) && not bracing then doc_chunk opts (Block (true, [t])) + else doc_chunks (opts |> nonatomic |> expression_like) t + in + separate space [string "if"; i; string "then"; t] |> atomic_parens opts | Vector_updates (exp, updates) -> - let opts = opts |> nonatomic |> expression_like in - let exp_doc = doc_chunks opts exp in - surround indent 0 - (char '[' ^^ exp_doc ^^ space ^^ string "with" ^^ space) - (group (separate_map (char ',' ^^ break 1) (doc_chunk opts) updates)) - (char ']') - |> atomic_parens opts + let opts = opts |> nonatomic |> expression_like in + let exp_doc = doc_chunks opts exp in + surround indent 0 + (char '[' ^^ exp_doc ^^ space ^^ string "with" ^^ space) + (group (separate_map (char ',' ^^ break 1) (doc_chunk opts) updates)) + (char ']') + |> atomic_parens opts | Index (exp, ix) -> - let exp_doc = doc_chunks (opts |> atomic |> expression_like) exp in - let ix_doc = doc_chunks (opts |> nonatomic |> expression_like) ix in - exp_doc ^^ surround indent 0 (char '[') ix_doc (char ']') - |> atomic_parens opts + let exp_doc = doc_chunks (opts |> atomic |> expression_like) exp in + let ix_doc = doc_chunks (opts |> nonatomic |> expression_like) ix in + exp_doc ^^ surround indent 0 (char '[') ix_doc (char ']') |> atomic_parens opts | Exists ex -> - let ex_doc = - doc_chunks (atomic opts) ex.vars - ^^ char ',' ^^ break 1 - ^^ doc_chunks (nonatomic opts) ex.constr - ^^ char '.' ^^ break 1 - ^^ doc_chunks (nonatomic opts) ex.typ - in - enclose (char '{') (char '}') (align ex_doc) + let ex_doc = + doc_chunks (atomic opts) ex.vars + ^^ char ',' ^^ break 1 + ^^ doc_chunks (nonatomic opts) ex.constr + ^^ char '.' ^^ break 1 + ^^ doc_chunks (nonatomic opts) ex.typ + in + enclose (char '{') (char '}') (align ex_doc) | Function_typ ft -> - separate space [ - group (doc_chunks opts ft.lhs); - if ft.mapping then string "<->" else string "->"; - group (doc_chunks opts ft.rhs) - ] + separate space + [ + group (doc_chunks opts ft.lhs); + (if ft.mapping then string "<->" else string "->"); + group (doc_chunks opts ft.rhs); + ] | Typ_quant typq -> - group ( - align ( - string "forall" ^^ space - ^^ nest 2 ( - doc_chunks opts typq.vars - ^^ (match typq.constr_opt with - | None -> char '.' - | Some constr -> char ',' ^^ break 1 ^^ doc_chunks opts constr ^^ char '.') - ) + group + (align + (string "forall" ^^ space + ^^ nest 2 + (doc_chunks opts typq.vars + ^^ + match typq.constr_opt with + | None -> char '.' + | Some constr -> char ',' ^^ break 1 ^^ doc_chunks opts constr ^^ char '.' + ) ) - ) - ^^ break 1 + ) + ^^ break 1 | Struct_update (exp, fexps) -> - surround indent 1 (char '{') (doc_chunks opts exp ^^ space ^^ string "with" ^^ break 1 ^^ separate_map (break 1) (doc_chunks opts) fexps) (char '}') - | Comment (comment_type, n, col, contents) -> - begin match comment_type with - | Lexer.Comment_line -> - blank n ^^ string "//" ^^ string contents ^^ require_hardline - | Lexer.Comment_block -> - (* Allow a linebreak after a block comment with newlines. This prevents formatting like: - /* comment line 1 - comment line 2 */exp - by forcing exp on a newline if the comment contains linebreaks - *) - match block_comment_lines col contents with - | [l] -> blank n ^^ string "/*" ^^ l ^^ string "*/" ^^ space - | ls -> - blank n ^^ group (align (string "/*" ^^ separate hardline ls ^^ string "*/")) - ^^ require_hardline - end + surround indent 1 (char '{') + (doc_chunks opts exp ^^ space ^^ string "with" ^^ break 1 ^^ separate_map (break 1) (doc_chunks opts) fexps) + (char '}') + | Comment (comment_type, n, col, contents) -> begin + match comment_type with + | Lexer.Comment_line -> blank n ^^ string "//" ^^ string contents ^^ require_hardline + | Lexer.Comment_block -> ( + (* Allow a linebreak after a block comment with newlines. This prevents formatting like: + /* comment line 1 + comment line 2 */exp + by forcing exp on a newline if the comment contains linebreaks + *) + match block_comment_lines col contents with + | [l] -> blank n ^^ string "/*" ^^ l ^^ string "*/" ^^ space + | ls -> blank n ^^ group (align (string "/*" ^^ separate hardline ls ^^ string "*/")) ^^ require_hardline + ) + end | Function f -> - let sep = hardline ^^ string "and" ^^ space in - let clauses = match f.funcls with - | [] -> Reporting.unreachable (id_loc f.id) __POS__ "Function with no clauses found" - | [funcl] -> doc_funcl f.return_typ_opt opts funcl - | funcl :: funcls -> - doc_funcl f.return_typ_opt opts funcl ^^ sep ^^ separate_map sep (doc_funcl None opts) f.funcls in - string "function" ^^ (if f.clause then space ^^ string "clause" else empty) ^^ space ^^ doc_id f.id - ^^ (match f.typq_opt with Some typq -> space ^^ doc_chunks opts typq | None -> empty) - ^^ clauses ^^ hardline + let sep = hardline ^^ string "and" ^^ space in + let clauses = + match f.funcls with + | [] -> Reporting.unreachable (id_loc f.id) __POS__ "Function with no clauses found" + | [funcl] -> doc_funcl f.return_typ_opt opts funcl + | funcl :: funcls -> + doc_funcl f.return_typ_opt opts funcl ^^ sep ^^ separate_map sep (doc_funcl None opts) f.funcls + in + string "function" + ^^ (if f.clause then space ^^ string "clause" else empty) + ^^ space ^^ doc_id f.id + ^^ (match f.typq_opt with Some typq -> space ^^ doc_chunks opts typq | None -> empty) + ^^ clauses ^^ hardline | Val vs -> - let doc_binding (target, name) = string target ^^ char ':' ^^ space ^^ char '"' ^^ utf8string name ^^ char '"' in - string "val" ^^ space ^^ (if vs.is_cast then string "cast" ^^ space else empty) ^^ doc_id vs.id - ^^ group (match vs.extern_opt with - | Some extern -> - space ^^ char '=' ^^ space - ^^ string (if extern.pure then "pure" else "monadic") ^^ space - ^^ surround indent 1 (char '{') (separate_map (char ',' ^^ break 1) doc_binding extern.bindings) (char '}') - | None -> empty) - ^^ space ^^ char ':' - ^^ group (nest indent ((match vs.typq_opt with - | Some typq -> space ^^ doc_chunks opts typq - | None -> space) - ^^ doc_chunks opts vs.typ)) + let doc_binding (target, name) = + string target ^^ char ':' ^^ space ^^ char '"' ^^ utf8string name ^^ char '"' + in + string "val" ^^ space + ^^ (if vs.is_cast then string "cast" ^^ space else empty) + ^^ doc_id vs.id + ^^ group + ( match vs.extern_opt with + | Some extern -> + space ^^ char '=' ^^ space + ^^ string (if extern.pure then "pure" else "monadic") + ^^ space + ^^ surround indent 1 (char '{') + (separate_map (char ',' ^^ break 1) doc_binding extern.bindings) + (char '}') + | None -> empty + ) + ^^ space ^^ char ':' + ^^ group + (nest indent + ((match vs.typq_opt with Some typq -> space ^^ doc_chunks opts typq | None -> space) + ^^ doc_chunks opts vs.typ + ) + ) | Enum e -> - string "enum" ^^ space ^^ doc_id e.id - ^^ group ( - (match e.enum_functions with - | Some enum_functions -> + string "enum" ^^ space ^^ doc_id e.id + ^^ group + (( match e.enum_functions with + | Some enum_functions -> space ^^ string "with" ^^ space ^^ align (separate_map softline (doc_chunks opts) enum_functions) - | None -> - empty + | None -> empty ) - ^^ space ^^ char '=' ^^ space - ^^ surround indent 1 (char '{') (separate_map softline (doc_chunks opts) e.members) (char '}') - ) - | Pragma (pragma, arg) -> - char '$' ^^ string pragma ^^ space ^^ string arg ^^ hardline + ^^ space ^^ char '=' ^^ space + ^^ surround indent 1 (char '{') (separate_map softline (doc_chunks opts) e.members) (char '}') + ) + | Pragma (pragma, arg) -> char '$' ^^ string pragma ^^ space ^^ string arg ^^ hardline | Block (always_hardline, exps) -> - let exps = map_last (fun no_semi chunks -> doc_block_exp_chunks (opts |> nonatomic |> statement_like) no_semi chunks) exps in - let sep = if (always_hardline || List.exists snd exps) then hardline else break 1 in - let exps = List.map fst exps in - surround_hardline always_hardline indent 1 (char '{') (separate sep exps) (char '}') - |> atomic_parens opts + let exps = + map_last + (fun no_semi chunks -> doc_block_exp_chunks (opts |> nonatomic |> statement_like) no_semi chunks) + exps + in + let sep = if always_hardline || List.exists snd exps then hardline else break 1 in + let exps = List.map fst exps in + surround_hardline always_hardline indent 1 (char '{') (separate sep exps) (char '}') |> atomic_parens opts | Block_binder (binder, x, y) -> - if can_hang y then ( - separate space [string (binder_keyword binder); doc_chunks (atomic opts) x; char '='; doc_chunks (nonatomic opts) y] - ) else ( - separate space [string (binder_keyword binder); doc_chunks (atomic opts) x; char '='] - ^^ nest 4 (hardline ^^ doc_chunks (nonatomic opts) y) - ) + if can_hang y then + separate space + [string (binder_keyword binder); doc_chunks (atomic opts) x; char '='; doc_chunks (nonatomic opts) y] + else + separate space [string (binder_keyword binder); doc_chunks (atomic opts) x; char '='] + ^^ nest 4 (hardline ^^ doc_chunks (nonatomic opts) y) | Binder (binder, x, y, z) -> - prefix indent 1 - (separate space [string (binder_keyword binder); doc_chunks (atomic opts) x; char '='; doc_chunks (nonatomic opts) y; string "in"]) - (doc_chunks (nonatomic opts) z) + prefix indent 1 + (separate space + [ + string (binder_keyword binder); + doc_chunks (atomic opts) x; + char '='; + doc_chunks (nonatomic opts) y; + string "in"; + ] + ) + (doc_chunks (nonatomic opts) z) | Match m -> - let kw1, kw2 = match_keywords m.kind in - string kw1 ^^ space ^^ doc_chunks (nonatomic opts) m.exp - ^^ (Option.fold ~none:empty ~some:(fun k -> space ^^ string k) kw2) ^^ space - ^^ surround indent 1 (char '{') (separate_map hardline (doc_pexp_chunks opts) m.cases) (char '}') - |> atomic_parens opts + let kw1, kw2 = match_keywords m.kind in + string kw1 ^^ space + ^^ doc_chunks (nonatomic opts) m.exp + ^^ Option.fold ~none:empty ~some:(fun k -> space ^^ string k) kw2 + ^^ space + ^^ surround indent 1 (char '{') (separate_map hardline (doc_pexp_chunks opts) m.cases) (char '}') + |> atomic_parens opts | Foreach loop -> - let to_keyword = string (if loop.decreasing then "downto" else "to") in - string "foreach" ^^ space - ^^ group (surround indent 0 (char '(') - (separate (break 1) ([ - doc_chunks (opts |> atomic) loop.var; - string "from" ^^ space ^^ doc_chunks (opts |> atomic |> expression_like) loop.from_index; - to_keyword ^^ space ^^ doc_chunks (opts |> atomic |> expression_like) loop.to_index - ] - @ (match loop.step with - | Some step -> - [string "by" ^^ space ^^ doc_chunks (opts |> atomic |> expression_like) step] - | None -> - [] - ) - )) - (char ')')) - ^^ space ^^ group (doc_chunks (opts |> nonatomic |> statement_like) loop.body) + let to_keyword = string (if loop.decreasing then "downto" else "to") in + string "foreach" ^^ space + ^^ group + (surround indent 0 (char '(') + (separate (break 1) + ([ + doc_chunks (opts |> atomic) loop.var; + string "from" ^^ space ^^ doc_chunks (opts |> atomic |> expression_like) loop.from_index; + to_keyword ^^ space ^^ doc_chunks (opts |> atomic |> expression_like) loop.to_index; + ] + @ + match loop.step with + | Some step -> [string "by" ^^ space ^^ doc_chunks (opts |> atomic |> expression_like) step] + | None -> [] + ) + ) + (char ')') + ) + ^^ space + ^^ group (doc_chunks (opts |> nonatomic |> statement_like) loop.body) | While loop -> - let measure = match loop.termination_measure with - | Some chunks -> - string "termination_measure" ^^ space ^^ group (surround indent 1 (char '{') (doc_chunks opts chunks) (char '}')) ^^ space - | None -> empty - in - let cond = doc_chunks (opts |> nonatomic |> expression_like) loop.cond in - let body = doc_chunks (opts |> nonatomic |> statement_like) loop.body in - if loop.repeat_until then ( - string "repeat" ^^ space ^^ measure ^^ body ^^ space ^^ string "until" ^^ space ^^ cond - ) else ( - string "while" ^^ space ^^ measure ^^ cond ^^ space ^^ string "do" ^^ space ^^ body - ) - | Field (exp, id) -> - doc_chunks (subatomic opts) exp ^^ char '.' ^^ doc_id id - | Raw str -> - separate hardline (lines str) + let measure = + match loop.termination_measure with + | Some chunks -> + string "termination_measure" ^^ space + ^^ group (surround indent 1 (char '{') (doc_chunks opts chunks) (char '}')) + ^^ space + | None -> empty + in + let cond = doc_chunks (opts |> nonatomic |> expression_like) loop.cond in + let body = doc_chunks (opts |> nonatomic |> statement_like) loop.body in + if loop.repeat_until then + string "repeat" ^^ space ^^ measure ^^ body ^^ space ^^ string "until" ^^ space ^^ cond + else string "while" ^^ space ^^ measure ^^ cond ^^ space ^^ string "do" ^^ space ^^ body + | Field (exp, id) -> doc_chunks (subatomic opts) exp ^^ char '.' ^^ doc_id id + | Raw str -> separate hardline (lines str) and doc_pexp_chunks_pair opts pexp = let pat = doc_chunks opts pexp.pat in let body = doc_chunks opts pexp.body in match pexp.guard with - | None -> pat, body - | Some guard -> - separate space [pat; string "if"; doc_chunks opts guard], - body + | None -> (pat, body) + | Some guard -> (separate space [pat; string "if"; doc_chunks opts guard], body) and doc_pexp_chunks opts pexp = let guarded_pat, body = doc_pexp_chunks_pair opts pexp in separate space [guarded_pat; string "=>"; body] and doc_funcl return_typ_opt opts pexp = - let return_typ = match return_typ_opt with + let return_typ = + match return_typ_opt with | Some chunks -> space ^^ prefix_parens indent (string "->") (doc_chunks opts chunks) ^^ space - | None -> space in + | None -> space + in match pexp.guard with | None -> - (if pexp.funcl_space then space else empty) - ^^ group ( - doc_chunks ~ungroup_tuple:true opts pexp.pat - ^^ return_typ - ) - ^^ string "=" - ^^ space ^^ doc_chunks opts pexp.body + (if pexp.funcl_space then space else empty) + ^^ group (doc_chunks ~ungroup_tuple:true opts pexp.pat ^^ return_typ) + ^^ string "=" ^^ space ^^ doc_chunks opts pexp.body | Some guard -> - parens (separate space [doc_chunks opts pexp.pat; string "if"; doc_chunks opts guard]) - ^^ return_typ - ^^ string "=" - ^^ space ^^ doc_chunks opts pexp.body + parens (separate space [doc_chunks opts pexp.pat; string "if"; doc_chunks opts guard]) + ^^ return_typ ^^ string "=" ^^ space ^^ doc_chunks opts pexp.body (* Format an expression in a block, optionally terminating it with a semicolon. If the expression has a trailing comment then we will @@ -692,30 +637,31 @@ module Make(Config : CONFIG) = struct let requires_hardline = ref false in let terminator = if no_semi then empty else char ';' in let doc = - concat_map_last (fun last chunk -> - if last then + concat_map_last + (fun last chunk -> + if last then ( match line_comment_opt chunk with | Some contents -> - requires_hardline := true; - terminator ^^ space ^^ string "//" ^^ string contents ^^ require_hardline + requires_hardline := true; + terminator ^^ space ^^ string "//" ^^ string contents ^^ require_hardline | None -> doc_chunk opts chunk ^^ terminator - else - doc_chunk opts chunk - ) chunks in + ) + else doc_chunk opts chunk + ) + chunks + in (group doc, !requires_hardline) - and doc_chunks ?(ungroup_tuple=false) opts chunks = - Queue.fold (fun doc chunk -> - doc ^^ doc_chunk ~ungroup_tuple:ungroup_tuple opts chunk - ) empty chunks + and doc_chunks ?(ungroup_tuple = false) opts chunks = + Queue.fold (fun doc chunk -> doc ^^ doc_chunk ~ungroup_tuple opts chunk) empty chunks let to_string doc = let b = Buffer.create 1024 in let lb_info = empty_linebreak_info () in PPrint.ToBuffer.pretty Config.config.ribbon_width Config.config.line_width b (to_pprint lb_info doc); - Buffer.contents b, lb_info + (Buffer.contents b, lb_info) - let fixup ?(debug=false) lb_info s = + let fixup ?(debug = false) lb_info s = let buf = Buffer.create (String.length s) in let column = ref 0 in let line = ref 0 in @@ -729,90 +675,79 @@ module Make(Config : CONFIG) = struct hardline. Encountering a desired hardline means the requirement has been satisifed so we set it to false. *) let require_hardline = ref false in - String.iter (fun c -> + String.iter + (fun c -> let rec pop_dedents () = - begin match Queue.peek_opt lb_info.dedents with - | Some (l, c, amount) when l < !line || (l = !line && c = !column) -> - if !after_hardline then ( - pending_spaces := !pending_spaces - amount - ); - if debug then ( - Buffer.add_string buf Util.("D" ^ string_of_int amount |> green |> clear) - ); - ignore (Queue.take lb_info.dedents); - pop_dedents () - | _ -> (); + begin + match Queue.peek_opt lb_info.dedents with + | Some (l, c, amount) when l < !line || (l = !line && c = !column) -> + if !after_hardline then pending_spaces := !pending_spaces - amount; + if debug then Buffer.add_string buf Util.("D" ^ string_of_int amount |> green |> clear); + ignore (Queue.take lb_info.dedents); + pop_dedents () + | _ -> () end in pop_dedents (); if c = '\n' then ( - begin match Queue.take_opt lb_info.hardlines with - | Some (l, c, hardline_type) -> - begin match hardline_type with - | Desired -> - if debug then ( - Buffer.add_string buf Util.("H" |> red |> clear) - ); - Buffer.add_char buf '\n'; - pending_spaces := 0; - if !require_hardline then ( - require_hardline := false - ); - after_hardline := true - | Required -> - if debug then ( - Buffer.add_string buf Util.("R" |> red |> clear) - ); - require_hardline := true; - after_hardline := true - end - | None -> - Reporting.unreachable Parse_ast.Unknown __POS__ (Printf.sprintf "Missing hardline %d %d" !line !column) + begin + match Queue.take_opt lb_info.hardlines with + | Some (l, c, hardline_type) -> begin + match hardline_type with + | Desired -> + if debug then Buffer.add_string buf Util.("H" |> red |> clear); + Buffer.add_char buf '\n'; + pending_spaces := 0; + if !require_hardline then require_hardline := false; + after_hardline := true + | Required -> + if debug then Buffer.add_string buf Util.("R" |> red |> clear); + require_hardline := true; + after_hardline := true + end + | None -> + Reporting.unreachable Parse_ast.Unknown __POS__ (Printf.sprintf "Missing hardline %d %d" !line !column) end; incr line; - column := 0; - ) else ( - if c = ' ' then ( - incr pending_spaces - ) else ( + column := 0 + ) + else ( + if c = ' ' then incr pending_spaces + else ( if !require_hardline then ( Buffer.add_char buf '\n'; require_hardline := false ); - if !pending_spaces > 0 then ( - Buffer.add_string buf (String.make !pending_spaces ' '); - ); + if !pending_spaces > 0 then Buffer.add_string buf (String.make !pending_spaces ' '); Buffer.add_char buf c; after_hardline := false; pending_spaces := 0 ); - incr column; + incr column ) - ) s; + ) + s; Buffer.contents buf - let format_defs_once ?(debug=false) source comments defs = + let format_defs_once ?(debug = false) source comments defs = let chunks = chunk_defs source comments defs in - if debug then ( - Queue.iter (prerr_chunk "") chunks - ); + if debug then Queue.iter (prerr_chunk "") chunks; let doc = Queue.fold (fun doc chunk -> doc ^^ doc_chunk ~toplevel:true default_opts chunk) empty chunks in let formatted, lb_info = to_string (doc ^^ hardline) in fixup lb_info formatted - let format_defs ?(debug=false) filename source comments defs = + let format_defs ?(debug = false) filename source comments defs = let open Initial_check in - let f1 = format_defs_once ~debug:debug source comments defs in - let comments, defs = parse_file_from_string ~filename:filename ~contents:f1 in - let f2 = format_defs_once ~debug:debug f1 comments defs in - let comments, defs = parse_file_from_string ~filename:filename ~contents:f2 in - let f3 = format_defs_once ~debug:debug f2 comments defs in + let f1 = format_defs_once ~debug source comments defs in + let comments, defs = parse_file_from_string ~filename ~contents:f1 in + let f2 = format_defs_once ~debug f1 comments defs in + let comments, defs = parse_file_from_string ~filename ~contents:f2 in + let f3 = format_defs_once ~debug f2 comments defs in if f2 <> f3 then ( print_endline f2; print_endline f3; raise (Reporting.err_general Parse_ast.Unknown filename) ); f3 - end diff --git a/src/lib/format_sail.mli b/src/lib/format_sail.mli index 74d67d96a..493da5ddc 100644 --- a/src/lib/format_sail.mli +++ b/src/lib/format_sail.mli @@ -66,31 +66,29 @@ (****************************************************************************) type config = { - indent : int; - (** The default indentation depth (default 4) *) - preserve_structure : bool; - (** If true, the formatter preserves the structure of the AST as + indent : int; (** The default indentation depth (default 4) *) + preserve_structure : bool; + (** If true, the formatter preserves the structure of the AST as much as possible - it won't insert braces around if statements and so on where there weren't any and so on. (default false) *) - line_width : int; - (** The desired maximum line width. (default 120) *) - ribbon_width : float; - (** The fraction (between 0.0 and 1.0) of the maximum line width that + line_width : int; (** The desired maximum line width. (default 120) *) + ribbon_width : float; + (** The fraction (between 0.0 and 1.0) of the maximum line width that can be filled by non whitespace characters before we consider breaking. (default 1.0) *) - } +} (** Read the config struct from a json object. Raises err_general if the json is not an object, and warns about any invalid keys. *) val config_from_json : Yojson.Basic.t -> config - + val default_config : config module type CONFIG = sig val config : config end -module Make(Config : CONFIG) : sig +module Make (Config : CONFIG) : sig (** If debug is true, we print extra debugging information to stderr, and annotate the output with various information on linebreaking decisions. *) diff --git a/src/lib/frontend.ml b/src/lib/frontend.ml index d2fc639f6..986802ba8 100644 --- a/src/lib/frontend.ml +++ b/src/lib/frontend.ml @@ -72,8 +72,9 @@ let opt_ddump_initial_ast = ref false let opt_ddump_tc_ast = ref false let opt_dno_cast = ref true let opt_reformat : string option ref = ref None - -let check_ast (asserts_termination : bool) (env : Type_check.Env.t) (ast : uannot ast) : Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info = + +let check_ast (asserts_termination : bool) (env : Type_check.Env.t) (ast : uannot ast) : + Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info = let env = if !opt_dno_cast then Type_check.Env.no_casts env else env in let ast, env = Type_error.check env ast in let ast = Scattered.descatter ast in @@ -81,25 +82,32 @@ let check_ast (asserts_termination : bool) (env : Type_check.Env.t) (ast : uanno Effects.check_side_effects side_effects ast; let () = if !opt_ddump_tc_ast then Pretty_print_sail.pp_ast stdout (Type_check.strip_ast ast) else () in (ast, env, side_effects) - -let load_files ?target:target default_sail_dir options type_envs files = + +let load_files ?target default_sail_dir options type_envs files = let t = Profile.start () in let parsed_files = List.map (fun f -> (f, Initial_check.parse_file f)) files in let comments = List.map (fun (f, (comments, _)) -> (f, comments)) parsed_files in let target_name = Option.map Target.name target in - let ast = Parse_ast.Defs (List.map (fun (f, (_, file_ast)) -> (f, Preprocess.preprocess default_sail_dir target_name options file_ast)) parsed_files) in + let ast = + Parse_ast.Defs + (List.map + (fun (f, (_, file_ast)) -> (f, Preprocess.preprocess default_sail_dir target_name options file_ast)) + parsed_files + ) + in let ast = Initial_check.process_ast ~generate:true ast in - let ast = { ast with comments = comments } in - + let ast = { ast with comments } in + let () = if !opt_ddump_initial_ast then Pretty_print_sail.pp_ast stdout ast else () in - begin match !opt_reformat with - | Some dir -> - Pretty_print_sail.reformat dir ast; - exit 0 - | None -> () + begin + match !opt_reformat with + | Some dir -> + Pretty_print_sail.reformat dir ast; + exit 0 + | None -> () end; (* The separate loop measures declarations would be awkward to type check, so @@ -109,13 +117,15 @@ let load_files ?target:target default_sail_dir options type_envs files = let t = Profile.start () in let asserts_termination = Option.fold ~none:false ~some:Target.asserts_termination target in - let (ast, type_envs, side_effects) = check_ast asserts_termination type_envs ast in + let ast, type_envs, side_effects = check_ast asserts_termination type_envs ast in Profile.finish "type checking" t; (ast, type_envs, side_effects) -let rewrite_ast_initial effect_info env = Rewrites.rewrite effect_info env [("initial", fun effect_info env ast -> Rewriter.rewrite_ast ast, effect_info, env)] - +let rewrite_ast_initial effect_info env = + Rewrites.rewrite effect_info env + [("initial", fun effect_info env ast -> (Rewriter.rewrite_ast ast, effect_info, env))] + let initial_rewrite effect_info type_envs ast = let ast, _, type_envs = rewrite_ast_initial effect_info type_envs ast in (* Recheck after descattering so that the internal type environments diff --git a/src/lib/frontend.mli b/src/lib/frontend.mli index 2a35399d9..4653fcb42 100644 --- a/src/lib/frontend.mli +++ b/src/lib/frontend.mli @@ -73,14 +73,16 @@ val opt_reformat : string option ref open Ast_defs open Ast_util -val check_ast : bool -> Type_check.Env.t -> uannot ast -> Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info - +val check_ast : + bool -> Type_check.Env.t -> uannot ast -> Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info + val load_files : ?target:Target.target -> string -> (Arg.key * Arg.spec * Arg.doc) list -> Type_check.Env.t -> string list -> - (Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info) + Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info -val initial_rewrite : Effects.side_effect_info -> Type_check.Env.t -> Type_check.tannot ast -> Type_check.tannot ast * Type_check.Env.t +val initial_rewrite : + Effects.side_effect_info -> Type_check.Env.t -> Type_check.tannot ast -> Type_check.tannot ast * Type_check.Env.t diff --git a/src/lib/graph.ml b/src/lib/graph.ml index d23a3e3b5..62f8fd3b3 100644 --- a/src/lib/graph.ml +++ b/src/lib/graph.ml @@ -65,67 +65,66 @@ (* SUCH DAMAGE. *) (****************************************************************************) -module type OrderedType = - sig - type t - val compare : t -> t -> int - end +module type OrderedType = sig + type t + val compare : t -> t -> int +end -module type S = - sig - type node - type graph - type node_set +module type S = sig + type node + type graph + type node_set - val leaves : graph -> node_set + val leaves : graph -> node_set - val empty : graph + val empty : graph - (** Add an edge from the first node to the second node, creating + (** Add an edge from the first node to the second node, creating the nodes if they do not exist. *) - val add_edge : node -> node -> graph -> graph - val add_edges : node -> node list -> graph -> graph + val add_edge : node -> node -> graph -> graph + + val add_edges : node -> node list -> graph -> graph + + val children : graph -> node -> node list - val children : graph -> node -> node list + val nodes : graph -> node list - val nodes : graph -> node list - - (** Return the set of nodes that are reachable from the first set + (** Return the set of nodes that are reachable from the first set of nodes (roots), without passing through the second set of nodes (cuts). *) - val reachable : node_set -> node_set -> graph -> node_set + val reachable : node_set -> node_set -> graph -> node_set - (** Prune a graph from roots to cuts. *) - val prune : node_set -> node_set -> graph -> graph + (** Prune a graph from roots to cuts. *) + val prune : node_set -> node_set -> graph -> graph - val remove_self_loops : graph -> graph + val remove_self_loops : graph -> graph - val self_loops : graph -> node list + val self_loops : graph -> node list - val reverse : graph -> graph + val reverse : graph -> graph - exception Not_a_DAG of node * graph;; + exception Not_a_DAG of node * graph - (** Topologically sort a graph. Throws Not_a_DAG if the graph is + (** Topologically sort a graph. Throws Not_a_DAG if the graph is not directed acyclic. *) - val topsort : graph -> node list + val topsort : graph -> node list - (** Find strongly connected components using Tarjan's algorithm. + (** Find strongly connected components using Tarjan's algorithm. This algorithm also returns a topological sorting of the graph components. *) - val scc : ?original_order:(node list) -> graph -> node list list + val scc : ?original_order:node list -> graph -> node list list - val make_dot : (node -> string) -> (node -> node -> string) -> (node -> string) -> out_channel -> graph -> unit - end - -module Make(Ord: OrderedType) = struct + val make_dot : (node -> string) -> (node -> node -> string) -> (node -> string) -> out_channel -> graph -> unit +end +module Make (Ord : OrderedType) = struct type node = Ord.t (* Node set *) - module NS = Set.Make(Ord) + module NS = Set.Make (Ord) + (* Node map *) - module NM = Map.Make(Ord) + module NM = Map.Make (Ord) type graph = NS.t NM.t @@ -134,15 +133,13 @@ module Make(Ord: OrderedType) = struct let empty = NM.empty let leaves cg = - List.fold_left (fun acc (fn, callees) -> NS.filter (fun callee -> callee <> fn) (NS.union acc callees)) NS.empty (NM.bindings cg) + List.fold_left + (fun acc (fn, callees) -> NS.filter (fun callee -> callee <> fn) (NS.union acc callees)) + NS.empty (NM.bindings cg) let nodes cg = NM.bindings cg |> List.map fst - - let children cg caller = - try - NS.elements (NM.find caller cg) - with - | Not_found -> [] + + let children cg caller = try NS.elements (NM.find caller cg) with Not_found -> [] let fix_some_leaves cg nodes = NS.fold (fun leaf cg -> if NM.mem leaf cg then cg else NM.add leaf NS.empty cg) nodes cg @@ -151,18 +148,12 @@ module Make(Ord: OrderedType) = struct let add_edge caller callee cg = let cg = fix_some_leaves cg (NS.singleton callee) in - try - NM.add caller (NS.add callee (NM.find caller cg)) cg - with - | Not_found -> NM.add caller (NS.singleton callee) cg + try NM.add caller (NS.add callee (NM.find caller cg)) cg with Not_found -> NM.add caller (NS.singleton callee) cg let add_edges caller callees cg = let callees = List.fold_left (fun s c -> NS.add c s) NS.empty callees in let cg = fix_some_leaves cg callees in - try - NM.add caller (NS.union callees (NM.find caller cg)) cg - with - | Not_found -> NM.add caller callees cg + try NM.add caller (NS.union callees (NM.find caller cg)) cg with Not_found -> NM.add caller callees cg let reachable roots cuts cg = let visited = ref NS.empty in @@ -170,18 +161,15 @@ module Make(Ord: OrderedType) = struct let rec reachable' node = if NS.mem node !visited then () else if NS.mem node cuts then visited := NS.add node !visited - else - begin - visited := NS.add node !visited; - let children = - try NM.find node cg with - | Not_found -> NS.empty - in - NS.iter reachable' children - end + else begin + visited := NS.add node !visited; + let children = try NM.find node cg with Not_found -> NS.empty in + NS.iter reachable' children + end in - NS.iter reachable' roots; !visited + NS.iter reachable' roots; + !visited let prune roots cuts cg = let rs = reachable roots cuts cg in @@ -189,26 +177,20 @@ module Make(Ord: OrderedType) = struct let cg = NM.mapi (fun fn children -> if NS.mem fn cuts then NS.empty else children) cg in fix_leaves cg - let remove_self_loops cg = - NM.mapi (fun fn callees -> NS.remove fn callees) cg + let remove_self_loops cg = NM.mapi (fun fn callees -> NS.remove fn callees) cg - let self_loops cg = - NM.fold (fun fn callees nodes -> - if NS.mem fn callees then ( - fn :: nodes - ) else ( - nodes - ) - ) cg [] + let self_loops cg = NM.fold (fun fn callees nodes -> if NS.mem fn callees then fn :: nodes else nodes) cg [] let reverse cg = let rcg = ref NM.empty in let find_default fn cg = try NM.find fn cg with Not_found -> NS.empty in - NM.iter (fun caller -> NS.iter (fun callee -> rcg := NM.add callee (NS.add caller (find_default callee !rcg)) !rcg)) cg; + NM.iter + (fun caller -> NS.iter (fun callee -> rcg := NM.add callee (NS.add caller (find_default callee !rcg)) !rcg)) + cg; fix_leaves !rcg - exception Not_a_DAG of node * graph;; + exception Not_a_DAG of node * graph let prune_loop node cg = let down = prune (NS.singleton node) NS.empty cg in @@ -223,32 +205,32 @@ module Make(Ord: OrderedType) = struct let find_unmarked keys = List.find (fun node -> not (NS.mem node !marked)) keys in let rec visit node = - if NS.mem node !temp_marked - then raise (let lcg = prune_loop node cg in Not_a_DAG (node, lcg)) - else if NS.mem node !marked - then () - else - begin - let children = - try NM.find node cg with - | Not_found -> NS.empty - in - temp_marked := NS.add node !temp_marked; - NS.iter (fun child -> visit child) children; - marked := NS.add node !marked; - temp_marked := NS.remove node !temp_marked; - list := node :: !list - end + if NS.mem node !temp_marked then + raise + (let lcg = prune_loop node cg in + Not_a_DAG (node, lcg) + ) + else if NS.mem node !marked then () + else begin + let children = try NM.find node cg with Not_found -> NS.empty in + temp_marked := NS.add node !temp_marked; + NS.iter (fun child -> visit child) children; + marked := NS.add node !marked; + temp_marked := NS.remove node !temp_marked; + list := node :: !list + end in let rec topsort' () = try let unmarked = find_unmarked keys in - visit unmarked; topsort' () - with - | Not_found -> () + visit unmarked; + topsort' () + with Not_found -> () + in - in topsort' (); !list + topsort' (); + !list type node_idx = { index : int; root : int } @@ -257,7 +239,7 @@ module Make(Ord: OrderedType) = struct let index = ref 0 in let stack = ref [] in - let push v = (stack := v :: !stack) in + let push v = stack := v :: !stack in let pop () = begin let v = List.hd !stack in @@ -269,8 +251,7 @@ module Make(Ord: OrderedType) = struct let node_indices = Hashtbl.create (NM.cardinal cg) in let get_index v = (Hashtbl.find node_indices v).index in let get_root v = (Hashtbl.find node_indices v).root in - let set_root v r = - Hashtbl.replace node_indices v { (Hashtbl.find node_indices v) with root = r } in + let set_root v r = Hashtbl.replace node_indices v { (Hashtbl.find node_indices v) with root = r } in let rec visit_node v = begin @@ -278,32 +259,29 @@ module Make(Ord: OrderedType) = struct index := !index + 1; push v; if NM.mem v cg then NS.iter (visit_edge v) (NM.find v cg); - if get_root v = get_index v then (* v is the root of a SCC *) - begin - let component = ref [] in - let finished = ref false in - while not !finished do - let w = pop () in - component := w :: !component; - if Ord.compare v w = 0 then finished := true; - done; - components := !component :: !components; - end + if get_root v = get_index v then begin + (* v is the root of a SCC *) + let component = ref [] in + let finished = ref false in + while not !finished do + let w = pop () in + component := w :: !component; + if Ord.compare v w = 0 then finished := true + done; + components := !component :: !components + end end and visit_edge v w = - if not (Hashtbl.mem node_indices w) then - begin - visit_node w; - if Hashtbl.mem node_indices w then set_root v (min (get_root v) (get_root w)); - end else begin - if List.mem w !stack then set_root v (min (get_root v) (get_index w)) - end + if not (Hashtbl.mem node_indices w) then begin + visit_node w; + if Hashtbl.mem node_indices w then set_root v (min (get_root v) (get_root w)) + end + else begin + if List.mem w !stack then set_root v (min (get_root v) (get_index w)) + end in - let nodes = match original_order with - | Some nodes -> nodes - | None -> List.map fst (NM.bindings cg) - in + let nodes = match original_order with Some nodes -> nodes | None -> List.map fst (NM.bindings cg) in List.iter (fun v -> if not (Hashtbl.mem node_indices v) then visit_node v) nodes; List.rev !components @@ -312,14 +290,17 @@ module Make(Ord: OrderedType) = struct let to_string node = String.escaped (string_of_node node) in output_string out_chan "digraph DEPS {\n"; let make_node from_node = - output_string out_chan (Printf.sprintf " \"%s\" [fillcolor=%s;style=filled];\n" (to_string from_node) (node_color from_node)) + output_string out_chan + (Printf.sprintf " \"%s\" [fillcolor=%s;style=filled];\n" (to_string from_node) (node_color from_node)) in let make_line from_node to_node = - output_string out_chan (Printf.sprintf " \"%s\" -> \"%s\" [color=%s];\n" (to_string from_node) (to_string to_node) (edge_color from_node to_node)) + output_string out_chan + (Printf.sprintf " \"%s\" -> \"%s\" [color=%s];\n" (to_string from_node) (to_string to_node) + (edge_color from_node to_node) + ) in NM.bindings graph |> List.iter (fun (from_node, _) -> make_node from_node); NM.bindings graph |> List.iter (fun (from_node, to_nodes) -> NS.iter (make_line from_node) to_nodes); output_string out_chan "}\n"; Util.opt_colors := true - end diff --git a/src/lib/graph.mli b/src/lib/graph.mli index 85b069f74..1f3913dec 100644 --- a/src/lib/graph.mli +++ b/src/lib/graph.mli @@ -67,60 +67,57 @@ (** Generic graph type based on OCaml Set and Map *) -module type OrderedType = - sig - type t - val compare : t -> t -> int - end +module type OrderedType = sig + type t + val compare : t -> t -> int +end -module type S = - sig - type node - type graph - type node_set +module type S = sig + type node + type graph + type node_set - val leaves : graph -> node_set + val leaves : graph -> node_set - val empty : graph + val empty : graph - (** Add an edge from the first node to the second node, creating + (** Add an edge from the first node to the second node, creating the nodes if they do not exist. *) - val add_edge : node -> node -> graph -> graph - val add_edges : node -> node list -> graph -> graph + val add_edge : node -> node -> graph -> graph - val children : graph -> node -> node list + val add_edges : node -> node list -> graph -> graph - val nodes : graph -> node list - - (** Return the set of nodes that are reachable from the first set + val children : graph -> node -> node list + + val nodes : graph -> node list + + (** Return the set of nodes that are reachable from the first set of nodes (roots), without passing through the second set of nodes (cuts). *) - val reachable : node_set -> node_set -> graph -> node_set + val reachable : node_set -> node_set -> graph -> node_set - (** Prune a graph from roots to cuts. *) - val prune : node_set -> node_set -> graph -> graph + (** Prune a graph from roots to cuts. *) + val prune : node_set -> node_set -> graph -> graph - val remove_self_loops : graph -> graph + val remove_self_loops : graph -> graph - val self_loops : graph -> node list + val self_loops : graph -> node list - val reverse : graph -> graph + val reverse : graph -> graph - exception Not_a_DAG of node * graph;; + exception Not_a_DAG of node * graph - (** Topologically sort a graph. Throws Not_a_DAG if the graph is + (** Topologically sort a graph. Throws Not_a_DAG if the graph is not directed acyclic. *) - val topsort : graph -> node list + val topsort : graph -> node list - (** Find strongly connected components using Tarjan's algorithm. + (** Find strongly connected components using Tarjan's algorithm. This algorithm also returns a topological sorting of the graph components. *) - val scc : ?original_order:(node list) -> graph -> node list list + val scc : ?original_order:node list -> graph -> node list list - val make_dot : (node -> string) -> (node -> node -> string) -> (node -> string) -> out_channel -> graph -> unit - end + val make_dot : (node -> string) -> (node -> node -> string) -> (node -> string) -> out_channel -> graph -> unit +end -module Make(Ord: OrderedType) : S - with type node = Ord.t - and type node_set = Set.Make(Ord).t - and type graph = Set.Make(Ord).t Map.Make(Ord).t +module Make (Ord : OrderedType) : + S with type node = Ord.t and type node_set = Set.Make(Ord).t and type graph = Set.Make(Ord).t Map.Make(Ord).t diff --git a/src/lib/initial_check.ml b/src/lib/initial_check.ml index 07366026d..300aca24e 100644 --- a/src/lib/initial_check.ml +++ b/src/lib/initial_check.ml @@ -81,47 +81,43 @@ let opt_magic_hash = ref false let opt_enum_casts = ref false type ctx = { - kinds : kind_aux KBindings.t; - type_constructors : (kind_aux list) Bindings.t; - scattereds : ctx Bindings.t; - reserved_type_ids : id list; - internal_files : string list; - target_sets : (string * string list) list; - } + kinds : kind_aux KBindings.t; + type_constructors : kind_aux list Bindings.t; + scattereds : ctx Bindings.t; + reserved_type_ids : id list; + internal_files : string list; + target_sets : (string * string list) list; +} -let string_of_parse_id_aux = function - | P.Id v -> v - | P.Operator v -> v +let string_of_parse_id_aux = function P.Id v -> v | P.Operator v -> v let string_of_parse_id (P.Id_aux (id, l)) = string_of_parse_id_aux id let parse_id_loc (P.Id_aux (_, l)) = l let string_contains str char = - try (ignore (String.index str char); true) with - | Not_found -> false + try + ignore (String.index str char); + true + with Not_found -> false let to_ast_kind (P.K_aux (k, l)) = match k with - | P.K_type -> K_aux (K_type, l) - | P.K_int -> K_aux (K_int, l) + | P.K_type -> K_aux (K_type, l) + | P.K_int -> K_aux (K_int, l) | P.K_order -> K_aux (K_order, l) - | P.K_bool -> K_aux (K_bool, l) + | P.K_bool -> K_aux (K_bool, l) let to_ast_id ctx (P.Id_aux (id, l)) = - let to_ast_id' id = Id_aux ((match id with - | P.Id x -> Id x - | P.Operator x -> Operator x), - l) in - if string_contains (string_of_parse_id_aux id) '#' then - begin match Reporting.loc_file l with - | Some file when !opt_magic_hash || List.exists (fun internal_file -> file = internal_file) ctx.internal_files -> to_ast_id' id + let to_ast_id' id = Id_aux ((match id with P.Id x -> Id x | P.Operator x -> Operator x), l) in + if string_contains (string_of_parse_id_aux id) '#' then begin + match Reporting.loc_file l with + | Some file when !opt_magic_hash || List.exists (fun internal_file -> file = internal_file) ctx.internal_files -> + to_ast_id' id | None -> to_ast_id' id - | _ -> - raise (Reporting.err_general l "Identifier contains hash character and -dmagic_hash is unset") - end - else - to_ast_id' id + | _ -> raise (Reporting.err_general l "Identifier contains hash character and -dmagic_hash is unset") + end + else to_ast_id' id let to_ast_var (P.Kid_aux (P.Var v, l)) = Kid_aux (Var v, l) @@ -134,67 +130,86 @@ let to_ast_kopts ctx (P.KOpt_aux (aux, l)) = let mk_kopt v k = let v = to_ast_var v in let k = to_ast_kind k in - KOpt_aux (KOpt_kind (k, v), l), { ctx with kinds = KBindings.add v (unaux_kind k) ctx.kinds } + (KOpt_aux (KOpt_kind (k, v), l), { ctx with kinds = KBindings.add v (unaux_kind k) ctx.kinds }) in match aux with | P.KOpt_kind (attr, vs, None) -> - let k = P.K_aux (P.K_int, gen_loc l) in - List.fold_left (fun (kopts, ctx) v -> let kopt, ctx = mk_kopt v k in (kopt :: kopts, ctx)) ([], ctx) vs, attr + let k = P.K_aux (P.K_int, gen_loc l) in + ( List.fold_left + (fun (kopts, ctx) v -> + let kopt, ctx = mk_kopt v k in + (kopt :: kopts, ctx) + ) + ([], ctx) vs, + attr + ) | P.KOpt_kind (attr, vs, Some k) -> - List.fold_left (fun (kopts, ctx) v -> let kopt, ctx = mk_kopt v k in (kopt :: kopts, ctx)) ([], ctx) vs, attr + ( List.fold_left + (fun (kopts, ctx) v -> + let kopt, ctx = mk_kopt v k in + (kopt :: kopts, ctx) + ) + ([], ctx) vs, + attr + ) let rec to_ast_typ ctx (P.ATyp_aux (aux, l)) = - let aux = match aux with + let aux = + match aux with | P.ATyp_id id -> Typ_id (to_ast_id ctx id) | P.ATyp_var v -> Typ_var (to_ast_var v) | P.ATyp_fn (from_typ, to_typ, _) -> - let from_typs = match from_typ with - | P.ATyp_aux (P.ATyp_tuple typs, _) -> - List.map (to_ast_typ ctx) typs - | _ -> [to_ast_typ ctx from_typ] - in - Typ_fn (from_typs, to_ast_typ ctx to_typ) + let from_typs = + match from_typ with + | P.ATyp_aux (P.ATyp_tuple typs, _) -> List.map (to_ast_typ ctx) typs + | _ -> [to_ast_typ ctx from_typ] + in + Typ_fn (from_typs, to_ast_typ ctx to_typ) | P.ATyp_bidir (typ1, typ2, _) -> Typ_bidir (to_ast_typ ctx typ1, to_ast_typ ctx typ2) | P.ATyp_tuple typs -> Typ_tuple (List.map (to_ast_typ ctx) typs) - | P.ATyp_app (P.Id_aux (P.Id "int", il), [n]) -> - Typ_app (Id_aux (Id "atom", il), [to_ast_typ_arg ctx n K_int]) + | P.ATyp_app (P.Id_aux (P.Id "int", il), [n]) -> Typ_app (Id_aux (Id "atom", il), [to_ast_typ_arg ctx n K_int]) | P.ATyp_app (P.Id_aux (P.Id "bool", il), [n]) -> - Typ_app (Id_aux (Id "atom_bool", il), [to_ast_typ_arg ctx n K_bool]) + Typ_app (Id_aux (Id "atom_bool", il), [to_ast_typ_arg ctx n K_bool]) | P.ATyp_app (id, args) -> - let id = to_ast_id ctx id in - begin match Bindings.find_opt id ctx.type_constructors with - | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id))) - | Some kinds when List.length args <> List.length kinds -> - raise (Reporting.err_typ l (sprintf "%s : %s -> Type expected %d arguments, given %d" - (string_of_id id) (format_kind_aux_list kinds) - (List.length kinds) (List.length args))) - | Some kinds -> - Typ_app (id, List.map2 (to_ast_typ_arg ctx) args kinds) - end + let id = to_ast_id ctx id in + begin + match Bindings.find_opt id ctx.type_constructors with + | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id))) + | Some kinds when List.length args <> List.length kinds -> + raise + (Reporting.err_typ l + (sprintf "%s : %s -> Type expected %d arguments, given %d" (string_of_id id) + (format_kind_aux_list kinds) (List.length kinds) (List.length args) + ) + ) + | Some kinds -> Typ_app (id, List.map2 (to_ast_typ_arg ctx) args kinds) + end | P.ATyp_exist (kopts, nc, atyp) -> - let kopts, ctx = - List.fold_right (fun kopt (kopts, ctx) -> - let (kopts', ctx), attr = to_ast_kopts ctx kopt in - match attr with - | None -> - kopts' @ kopts, ctx - | Some attr -> - raise (Reporting.err_typ l (sprintf "Attribute %s cannot appear within an existential type" attr)) - ) kopts ([], ctx) - in - Typ_exist (kopts, to_ast_constraint ctx nc, to_ast_typ ctx atyp) + let kopts, ctx = + List.fold_right + (fun kopt (kopts, ctx) -> + let (kopts', ctx), attr = to_ast_kopts ctx kopt in + match attr with + | None -> (kopts' @ kopts, ctx) + | Some attr -> + raise (Reporting.err_typ l (sprintf "Attribute %s cannot appear within an existential type" attr)) + ) + kopts ([], ctx) + in + Typ_exist (kopts, to_ast_constraint ctx nc, to_ast_typ ctx atyp) | _ -> raise (Reporting.err_typ l "Invalid type") in Typ_aux (aux, l) and to_ast_typ_arg ctx (ATyp_aux (_, l) as atyp) = function - | K_type -> A_aux (A_typ (to_ast_typ ctx atyp), l) - | K_int -> A_aux (A_nexp (to_ast_nexp ctx atyp), l) + | K_type -> A_aux (A_typ (to_ast_typ ctx atyp), l) + | K_int -> A_aux (A_nexp (to_ast_nexp ctx atyp), l) | K_order -> A_aux (A_order (to_ast_order ctx atyp), l) - | K_bool -> A_aux (A_bool (to_ast_constraint ctx atyp), l) + | K_bool -> A_aux (A_bool (to_ast_constraint ctx atyp), l) and to_ast_nexp ctx (P.ATyp_aux (aux, l)) = - let aux = match aux with + let aux = + match aux with | P.ATyp_id id -> Nexp_id (to_ast_id ctx id) | P.ATyp_var v -> Nexp_var (to_ast_var v) | P.ATyp_lit (P.L_aux (P.L_num c, _)) -> Nexp_constant c @@ -209,7 +224,8 @@ and to_ast_nexp ctx (P.ATyp_aux (aux, l)) = Nexp_aux (aux, l) and to_ast_bitfield_index_nexp ctx (P.ATyp_aux (aux, l)) = - let aux = match aux with + let aux = + match aux with | P.ATyp_id id -> Nexp_id (to_ast_id ctx id) | P.ATyp_lit (P.L_aux (P.L_num c, _)) -> Nexp_constant c | P.ATyp_sum (t1, t2) -> Nexp_sum (to_ast_bitfield_index_nexp ctx t1, to_ast_bitfield_index_nexp ctx t2) @@ -230,37 +246,46 @@ and to_ast_order ctx (P.ATyp_aux (aux, l)) = | _ -> raise (Reporting.err_typ l "Invalid order in type") and to_ast_constraint ctx (P.ATyp_aux (aux, l)) = - let aux = match aux with - | P.ATyp_app (Id_aux (Operator op, _) as id, [t1; t2]) -> - begin match op with - | "==" -> NC_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | "!=" -> NC_not_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | ">=" -> NC_bounded_ge (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | "<=" -> NC_bounded_le (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | ">" -> NC_bounded_gt (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | "<" -> NC_bounded_lt (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | "&" -> NC_and (to_ast_constraint ctx t1, to_ast_constraint ctx t2) - | "|" -> NC_or (to_ast_constraint ctx t1, to_ast_constraint ctx t2) - | _ -> - let id = to_ast_id ctx id in + let aux = + match aux with + | P.ATyp_app ((Id_aux (Operator op, _) as id), [t1; t2]) -> begin + match op with + | "==" -> NC_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | "!=" -> NC_not_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | ">=" -> NC_bounded_ge (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | "<=" -> NC_bounded_le (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | ">" -> NC_bounded_gt (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | "<" -> NC_bounded_lt (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | "&" -> NC_and (to_ast_constraint ctx t1, to_ast_constraint ctx t2) + | "|" -> NC_or (to_ast_constraint ctx t1, to_ast_constraint ctx t2) + | _ -> ( + let id = to_ast_id ctx id in + match Bindings.find_opt id ctx.type_constructors with + | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id))) + | Some kinds when List.length kinds <> 2 -> + raise + (Reporting.err_typ l + (sprintf "%s : %s -> Bool expected %d arguments, given 2" (string_of_id id) + (format_kind_aux_list kinds) (List.length kinds) + ) + ) + | Some kinds -> NC_app (id, List.map2 (to_ast_typ_arg ctx) [t1; t2] kinds) + ) + end + | P.ATyp_app (id, args) -> + let id = to_ast_id ctx id in + begin match Bindings.find_opt id ctx.type_constructors with | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id))) - | Some kinds when List.length kinds <> 2 -> - raise (Reporting.err_typ l (sprintf "%s : %s -> Bool expected %d arguments, given 2" - (string_of_id id) (format_kind_aux_list kinds) - (List.length kinds))) - | Some kinds -> NC_app (id, List.map2 (to_ast_typ_arg ctx) [t1; t2] kinds) - end - | P.ATyp_app (id, args) -> - let id = to_ast_id ctx id in - begin match Bindings.find_opt id ctx.type_constructors with - | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id))) - | Some kinds when List.length args <> List.length kinds -> - raise (Reporting.err_typ l (sprintf "%s : %s -> Bool expected %d arguments, given %d" - (string_of_id id) (format_kind_aux_list kinds) - (List.length kinds) (List.length args))) - | Some kinds -> NC_app (id, List.map2 (to_ast_typ_arg ctx) args kinds) - end + | Some kinds when List.length args <> List.length kinds -> + raise + (Reporting.err_typ l + (sprintf "%s : %s -> Bool expected %d arguments, given %d" (string_of_id id) + (format_kind_aux_list kinds) (List.length kinds) (List.length args) + ) + ) + | Some kinds -> NC_app (id, List.map2 (to_ast_typ_arg ctx) args kinds) + end | P.ATyp_var v -> NC_var (to_ast_var v) | P.ATyp_lit (P.L_aux (P.L_true, _)) -> NC_true | P.ATyp_lit (P.L_aux (P.L_false, _)) -> NC_false @@ -271,331 +296,341 @@ and to_ast_constraint ctx (P.ATyp_aux (aux, l)) = let to_ast_quant_items ctx (P.QI_aux (aux, l)) = match aux with - | P.QI_constraint nc -> [QI_aux (QI_constraint (to_ast_constraint ctx nc), l)], ctx - | P.QI_id kopt -> - let (kopts, ctx), attr = to_ast_kopts ctx kopt in - match attr with - | Some "constant" -> - Reporting.warn "Deprecated" l "constant type variable attribute no longer used"; - List.map (fun kopt -> QI_aux (QI_id kopt, l)) kopts, ctx - | Some attr -> - raise (Reporting.err_typ l (sprintf "Unknown attribute %s" attr)) - | None -> - List.map (fun kopt -> QI_aux (QI_id kopt, l)) kopts, ctx + | P.QI_constraint nc -> ([QI_aux (QI_constraint (to_ast_constraint ctx nc), l)], ctx) + | P.QI_id kopt -> ( + let (kopts, ctx), attr = to_ast_kopts ctx kopt in + match attr with + | Some "constant" -> + Reporting.warn "Deprecated" l "constant type variable attribute no longer used"; + (List.map (fun kopt -> QI_aux (QI_id kopt, l)) kopts, ctx) + | Some attr -> raise (Reporting.err_typ l (sprintf "Unknown attribute %s" attr)) + | None -> (List.map (fun kopt -> QI_aux (QI_id kopt, l)) kopts, ctx) + ) let to_ast_typquant ctx (P.TypQ_aux (aux, l)) = match aux with - | P.TypQ_no_forall -> TypQ_aux (TypQ_no_forall, l), ctx + | P.TypQ_no_forall -> (TypQ_aux (TypQ_no_forall, l), ctx) | P.TypQ_tq quants -> - let quants, ctx = - List.fold_left (fun (qis, ctx) qi -> let qis', ctx = to_ast_quant_items ctx qi in qis' @ qis, ctx) ([], ctx) quants - in - TypQ_aux (TypQ_tq (List.rev quants), l), ctx + let quants, ctx = + List.fold_left + (fun (qis, ctx) qi -> + let qis', ctx = to_ast_quant_items ctx qi in + (qis' @ qis, ctx) + ) + ([], ctx) quants + in + (TypQ_aux (TypQ_tq (List.rev quants), l), ctx) let to_ast_typschm ctx (P.TypSchm_aux (P.TypSchm_ts (typq, typ), l)) = let typq, ctx = to_ast_typquant ctx typq in let typ = to_ast_typ ctx typ in - TypSchm_aux (TypSchm_ts (typq, typ), l), ctx + (TypSchm_aux (TypSchm_ts (typq, typ), l), ctx) let to_ast_lit (P.L_aux (lit, l)) = - L_aux ((match lit with - | P.L_unit -> L_unit - | P.L_zero -> L_zero - | P.L_one -> L_one - | P.L_true -> L_true - | P.L_false -> L_false - | P.L_undef -> L_undef - | P.L_num i -> L_num i - | P.L_hex h -> L_hex h - | P.L_bin b -> L_bin b - | P.L_real r -> L_real r - | P.L_string s -> L_string s) - ,l) + L_aux + ( ( match lit with + | P.L_unit -> L_unit + | P.L_zero -> L_zero + | P.L_one -> L_one + | P.L_true -> L_true + | P.L_false -> L_false + | P.L_undef -> L_undef + | P.L_num i -> L_num i + | P.L_hex h -> L_hex h + | P.L_bin b -> L_bin b + | P.L_real r -> L_real r + | P.L_string s -> L_string s + ), + l + ) let rec to_ast_typ_pat ctx (P.ATyp_aux (aux, l)) = match aux with | P.ATyp_wild -> TP_aux (TP_wild, l) | P.ATyp_var kid -> TP_aux (TP_var (to_ast_var kid), l) | P.ATyp_app (P.Id_aux (P.Id "int", il), typs) -> - TP_aux (TP_app (Id_aux (Id "atom", il), List.map (to_ast_typ_pat ctx) typs), l) - | P.ATyp_app (f, typs) -> - TP_aux (TP_app (to_ast_id ctx f, List.map (to_ast_typ_pat ctx) typs), l) + TP_aux (TP_app (Id_aux (Id "atom", il), List.map (to_ast_typ_pat ctx) typs), l) + | P.ATyp_app (f, typs) -> TP_aux (TP_app (to_ast_id ctx f, List.map (to_ast_typ_pat ctx) typs), l) | _ -> raise (Reporting.err_typ l "Unexpected type in type pattern") let rec to_ast_pat ctx (P.P_aux (aux, l)) = match aux with | P.P_attribute (attr, arg, pat) -> - let P_aux (aux, (pat_l, annot)) = to_ast_pat ctx pat in - (* The location of an E_attribute node is just the attribute by itself *) - let annot = add_attribute l attr arg annot in - P_aux (aux, (pat_l, annot)) + let (P_aux (aux, (pat_l, annot))) = to_ast_pat ctx pat in + (* The location of an E_attribute node is just the attribute by itself *) + let annot = add_attribute l attr arg annot in + P_aux (aux, (pat_l, annot)) | _ -> - let aux = match aux with - | P.P_attribute _ -> assert false - | P.P_lit lit -> P_lit (to_ast_lit lit) - | P.P_wild -> P_wild - | P.P_var (pat, P.ATyp_aux (P.ATyp_id id, _)) -> - P_as (to_ast_pat ctx pat, to_ast_id ctx id) - | P.P_typ (typ, pat) -> P_typ (to_ast_typ ctx typ, to_ast_pat ctx pat) - | P.P_id id -> P_id (to_ast_id ctx id) - | P.P_var (pat, typ) -> P_var (to_ast_pat ctx pat, to_ast_typ_pat ctx typ) - | P.P_app (id, []) -> P_id (to_ast_id ctx id) - | P.P_app (id, pats) -> - if List.length pats == 1 && string_of_parse_id id = "~" - then P_not (to_ast_pat ctx (List.hd pats)) - else P_app (to_ast_id ctx id, List.map (to_ast_pat ctx) pats) - | P.P_vector(pats) -> P_vector (List.map (to_ast_pat ctx) pats) - | P.P_vector_concat(pats) -> P_vector_concat (List.map (to_ast_pat ctx) pats) - | P.P_vector_subrange (id, n, m) -> P_vector_subrange (to_ast_id ctx id, n, m) - | P.P_tuple(pats) -> P_tuple (List.map (to_ast_pat ctx) pats) - | P.P_list(pats) -> P_list(List.map (to_ast_pat ctx) pats) - | P.P_cons(pat1, pat2) -> P_cons (to_ast_pat ctx pat1, to_ast_pat ctx pat2) - | P.P_string_append pats -> P_string_append (List.map (to_ast_pat ctx) pats) - in - P_aux (aux, (l, empty_uannot)) - -let rec to_ast_letbind ctx (P.LB_aux(lb,l) : P.letbind) : uannot letbind = - LB_aux( - (match lb with - | P.LB_val(pat,exp) -> - LB_val(to_ast_pat ctx pat, to_ast_exp ctx exp) - ), (l, empty_uannot)) + let aux = + match aux with + | P.P_attribute _ -> assert false + | P.P_lit lit -> P_lit (to_ast_lit lit) + | P.P_wild -> P_wild + | P.P_var (pat, P.ATyp_aux (P.ATyp_id id, _)) -> P_as (to_ast_pat ctx pat, to_ast_id ctx id) + | P.P_typ (typ, pat) -> P_typ (to_ast_typ ctx typ, to_ast_pat ctx pat) + | P.P_id id -> P_id (to_ast_id ctx id) + | P.P_var (pat, typ) -> P_var (to_ast_pat ctx pat, to_ast_typ_pat ctx typ) + | P.P_app (id, []) -> P_id (to_ast_id ctx id) + | P.P_app (id, pats) -> + if List.length pats == 1 && string_of_parse_id id = "~" then P_not (to_ast_pat ctx (List.hd pats)) + else P_app (to_ast_id ctx id, List.map (to_ast_pat ctx) pats) + | P.P_vector pats -> P_vector (List.map (to_ast_pat ctx) pats) + | P.P_vector_concat pats -> P_vector_concat (List.map (to_ast_pat ctx) pats) + | P.P_vector_subrange (id, n, m) -> P_vector_subrange (to_ast_id ctx id, n, m) + | P.P_tuple pats -> P_tuple (List.map (to_ast_pat ctx) pats) + | P.P_list pats -> P_list (List.map (to_ast_pat ctx) pats) + | P.P_cons (pat1, pat2) -> P_cons (to_ast_pat ctx pat1, to_ast_pat ctx pat2) + | P.P_string_append pats -> P_string_append (List.map (to_ast_pat ctx) pats) + in + P_aux (aux, (l, empty_uannot)) + +let rec to_ast_letbind ctx (P.LB_aux (lb, l) : P.letbind) : uannot letbind = + LB_aux ((match lb with P.LB_val (pat, exp) -> LB_val (to_ast_pat ctx pat, to_ast_exp ctx exp)), (l, empty_uannot)) and to_ast_exp ctx (P.E_aux (exp, l) : P.exp) = match exp with | P.E_attribute (attr, arg, exp) -> - let E_aux (exp, (exp_l, annot)) = to_ast_exp ctx exp in - (* The location of an E_attribute node is just the attribute itself *) - let annot = add_attribute l attr arg annot in - E_aux (exp, (exp_l, annot)) + let (E_aux (exp, (exp_l, annot))) = to_ast_exp ctx exp in + (* The location of an E_attribute node is just the attribute itself *) + let annot = add_attribute l attr arg annot in + E_aux (exp, (exp_l, annot)) | _ -> - let aux = match exp with - | P.E_attribute _ -> assert false - | P.E_block exps -> - (match to_ast_fexps false ctx exps with - | Some fexps -> E_struct fexps - | None -> E_block (List.map (to_ast_exp ctx) exps)) - | P.E_id id -> - (* We support identifiers the same as __LOC__, __FILE__ and - __LINE__ in the OCaml standard library, and similar - constructs in C *) - let id_str = string_of_parse_id id in - if id_str = "__LOC__" then ( - E_lit (L_aux (L_string (Reporting.short_loc_to_string l), l)) - ) else if id_str = "__FILE__" then ( - let file = match Reporting.simp_loc l with - | Some (p, _) -> p.pos_fname - | None -> "unknown file" in - E_lit (L_aux (L_string file, l)) - ) else if id_str = "__LINE__" then ( - let lnum = match Reporting.simp_loc l with - | Some (p, _) -> p.pos_lnum - | None -> -1 in - E_lit (L_aux (L_num (Big_int.of_int lnum), l)) - ) else ( - E_id (to_ast_id ctx id) + let aux = + match exp with + | P.E_attribute _ -> assert false + | P.E_block exps -> ( + match to_ast_fexps false ctx exps with + | Some fexps -> E_struct fexps + | None -> E_block (List.map (to_ast_exp ctx) exps) ) - | P.E_ref id -> E_ref (to_ast_id ctx id) - | P.E_lit lit -> E_lit (to_ast_lit lit) - | P.E_typ (typ, exp) -> E_typ (to_ast_typ ctx typ, to_ast_exp ctx exp) - | P.E_app (f, args) -> - (match List.map (to_ast_exp ctx) args with - | [] -> E_app (to_ast_id ctx f, []) - | exps -> E_app (to_ast_id ctx f, exps)) - | P.E_app_infix(left,op,right) -> - E_app_infix(to_ast_exp ctx left, to_ast_id ctx op, to_ast_exp ctx right) - | P.E_tuple(exps) -> E_tuple(List.map (to_ast_exp ctx) exps) - | P.E_if(e1,e2,e3) -> E_if(to_ast_exp ctx e1, to_ast_exp ctx e2, to_ast_exp ctx e3) - | P.E_for(id,e1,e2,e3,atyp,e4) -> - E_for(to_ast_id ctx id,to_ast_exp ctx e1, to_ast_exp ctx e2, - to_ast_exp ctx e3,to_ast_order ctx atyp, to_ast_exp ctx e4) - | P.E_loop (P.While, m, e1, e2) -> E_loop (While, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) - | P.E_loop (P.Until, m, e1, e2) -> E_loop (Until, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) - | P.E_vector(exps) -> E_vector(List.map (to_ast_exp ctx) exps) - | P.E_vector_access(vexp,exp) -> E_vector_access(to_ast_exp ctx vexp, to_ast_exp ctx exp) - | P.E_vector_subrange(vex,exp1,exp2) -> - E_vector_subrange(to_ast_exp ctx vex, to_ast_exp ctx exp1, to_ast_exp ctx exp2) - | P.E_vector_update(vex,exp1,exp2) -> - E_vector_update(to_ast_exp ctx vex, to_ast_exp ctx exp1, to_ast_exp ctx exp2) - | P.E_vector_update_subrange(vex,e1,e2,e3) -> - E_vector_update_subrange(to_ast_exp ctx vex, to_ast_exp ctx e1, - to_ast_exp ctx e2, to_ast_exp ctx e3) - | P.E_vector_append(e1,e2) -> E_vector_append(to_ast_exp ctx e1,to_ast_exp ctx e2) - | P.E_list(exps) -> E_list(List.map (to_ast_exp ctx) exps) - | P.E_cons(e1,e2) -> E_cons(to_ast_exp ctx e1, to_ast_exp ctx e2) - | P.E_struct fexps -> - (match to_ast_fexps true ctx fexps with - | Some fexps -> E_struct fexps - | None -> raise (Reporting.err_unreachable l __POS__ "to_ast_fexps with true returned none")) - | P.E_struct_update(exp,fexps) -> - (match to_ast_fexps true ctx fexps with - | Some(fexps) -> E_struct_update(to_ast_exp ctx exp, fexps) - | _ -> raise (Reporting.err_unreachable l __POS__ "to_ast_fexps with true returned none")) - | P.E_field(exp,id) -> E_field(to_ast_exp ctx exp, to_ast_id ctx id) - | P.E_match(exp,pexps) -> E_match(to_ast_exp ctx exp, List.map (to_ast_case ctx) pexps) - | P.E_try (exp, pexps) -> E_try (to_ast_exp ctx exp, List.map (to_ast_case ctx) pexps) - | P.E_let(leb,exp) -> E_let(to_ast_letbind ctx leb, to_ast_exp ctx exp) - | P.E_assign(lexp,exp) -> E_assign(to_ast_lexp ctx lexp, to_ast_exp ctx exp) - | P.E_var(lexp,exp1,exp2) -> E_var(to_ast_lexp ctx lexp, to_ast_exp ctx exp1, to_ast_exp ctx exp2) - | P.E_sizeof(nexp) -> E_sizeof(to_ast_nexp ctx nexp) - | P.E_constraint nc -> E_constraint (to_ast_constraint ctx nc) - | P.E_exit exp -> E_exit(to_ast_exp ctx exp) - | P.E_throw exp -> E_throw (to_ast_exp ctx exp) - | P.E_return exp -> E_return(to_ast_exp ctx exp) - | P.E_assert(cond,msg) -> E_assert(to_ast_exp ctx cond, to_ast_exp ctx msg) - | P.E_internal_plet(pat,exp1,exp2) -> - if !opt_magic_hash then - E_internal_plet(to_ast_pat ctx pat, to_ast_exp ctx exp1, to_ast_exp ctx exp2) - else - raise (Reporting.err_general l "Internal plet construct found without -dmagic_hash") - | P.E_internal_return(exp) -> - if !opt_magic_hash then - E_internal_return(to_ast_exp ctx exp) - else - raise (Reporting.err_general l "Internal return construct found without -dmagic_hash") - | P.E_deref exp -> - E_app (Id_aux (Id "__deref", l), [to_ast_exp ctx exp]) - in - E_aux (aux, (l, empty_uannot)) - -and to_ast_measure ctx (P.Measure_aux(m,l)) : uannot internal_loop_measure = - let m = match m with + | P.E_id id -> + (* We support identifiers the same as __LOC__, __FILE__ and + __LINE__ in the OCaml standard library, and similar + constructs in C *) + let id_str = string_of_parse_id id in + if id_str = "__LOC__" then E_lit (L_aux (L_string (Reporting.short_loc_to_string l), l)) + else if id_str = "__FILE__" then ( + let file = match Reporting.simp_loc l with Some (p, _) -> p.pos_fname | None -> "unknown file" in + E_lit (L_aux (L_string file, l)) + ) + else if id_str = "__LINE__" then ( + let lnum = match Reporting.simp_loc l with Some (p, _) -> p.pos_lnum | None -> -1 in + E_lit (L_aux (L_num (Big_int.of_int lnum), l)) + ) + else E_id (to_ast_id ctx id) + | P.E_ref id -> E_ref (to_ast_id ctx id) + | P.E_lit lit -> E_lit (to_ast_lit lit) + | P.E_typ (typ, exp) -> E_typ (to_ast_typ ctx typ, to_ast_exp ctx exp) + | P.E_app (f, args) -> ( + match List.map (to_ast_exp ctx) args with + | [] -> E_app (to_ast_id ctx f, []) + | exps -> E_app (to_ast_id ctx f, exps) + ) + | P.E_app_infix (left, op, right) -> E_app_infix (to_ast_exp ctx left, to_ast_id ctx op, to_ast_exp ctx right) + | P.E_tuple exps -> E_tuple (List.map (to_ast_exp ctx) exps) + | P.E_if (e1, e2, e3) -> E_if (to_ast_exp ctx e1, to_ast_exp ctx e2, to_ast_exp ctx e3) + | P.E_for (id, e1, e2, e3, atyp, e4) -> + E_for + ( to_ast_id ctx id, + to_ast_exp ctx e1, + to_ast_exp ctx e2, + to_ast_exp ctx e3, + to_ast_order ctx atyp, + to_ast_exp ctx e4 + ) + | P.E_loop (P.While, m, e1, e2) -> E_loop (While, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_loop (P.Until, m, e1, e2) -> E_loop (Until, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_vector exps -> E_vector (List.map (to_ast_exp ctx) exps) + | P.E_vector_access (vexp, exp) -> E_vector_access (to_ast_exp ctx vexp, to_ast_exp ctx exp) + | P.E_vector_subrange (vex, exp1, exp2) -> + E_vector_subrange (to_ast_exp ctx vex, to_ast_exp ctx exp1, to_ast_exp ctx exp2) + | P.E_vector_update (vex, exp1, exp2) -> + E_vector_update (to_ast_exp ctx vex, to_ast_exp ctx exp1, to_ast_exp ctx exp2) + | P.E_vector_update_subrange (vex, e1, e2, e3) -> + E_vector_update_subrange (to_ast_exp ctx vex, to_ast_exp ctx e1, to_ast_exp ctx e2, to_ast_exp ctx e3) + | P.E_vector_append (e1, e2) -> E_vector_append (to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_list exps -> E_list (List.map (to_ast_exp ctx) exps) + | P.E_cons (e1, e2) -> E_cons (to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_struct fexps -> ( + match to_ast_fexps true ctx fexps with + | Some fexps -> E_struct fexps + | None -> raise (Reporting.err_unreachable l __POS__ "to_ast_fexps with true returned none") + ) + | P.E_struct_update (exp, fexps) -> ( + match to_ast_fexps true ctx fexps with + | Some fexps -> E_struct_update (to_ast_exp ctx exp, fexps) + | _ -> raise (Reporting.err_unreachable l __POS__ "to_ast_fexps with true returned none") + ) + | P.E_field (exp, id) -> E_field (to_ast_exp ctx exp, to_ast_id ctx id) + | P.E_match (exp, pexps) -> E_match (to_ast_exp ctx exp, List.map (to_ast_case ctx) pexps) + | P.E_try (exp, pexps) -> E_try (to_ast_exp ctx exp, List.map (to_ast_case ctx) pexps) + | P.E_let (leb, exp) -> E_let (to_ast_letbind ctx leb, to_ast_exp ctx exp) + | P.E_assign (lexp, exp) -> E_assign (to_ast_lexp ctx lexp, to_ast_exp ctx exp) + | P.E_var (lexp, exp1, exp2) -> E_var (to_ast_lexp ctx lexp, to_ast_exp ctx exp1, to_ast_exp ctx exp2) + | P.E_sizeof nexp -> E_sizeof (to_ast_nexp ctx nexp) + | P.E_constraint nc -> E_constraint (to_ast_constraint ctx nc) + | P.E_exit exp -> E_exit (to_ast_exp ctx exp) + | P.E_throw exp -> E_throw (to_ast_exp ctx exp) + | P.E_return exp -> E_return (to_ast_exp ctx exp) + | P.E_assert (cond, msg) -> E_assert (to_ast_exp ctx cond, to_ast_exp ctx msg) + | P.E_internal_plet (pat, exp1, exp2) -> + if !opt_magic_hash then E_internal_plet (to_ast_pat ctx pat, to_ast_exp ctx exp1, to_ast_exp ctx exp2) + else raise (Reporting.err_general l "Internal plet construct found without -dmagic_hash") + | P.E_internal_return exp -> + if !opt_magic_hash then E_internal_return (to_ast_exp ctx exp) + else raise (Reporting.err_general l "Internal return construct found without -dmagic_hash") + | P.E_deref exp -> E_app (Id_aux (Id "__deref", l), [to_ast_exp ctx exp]) + in + E_aux (aux, (l, empty_uannot)) + +and to_ast_measure ctx (P.Measure_aux (m, l)) : uannot internal_loop_measure = + let m = + match m with | P.Measure_none -> Measure_none | P.Measure_some exp -> - if !opt_magic_hash then - Measure_some (to_ast_exp ctx exp) - else - raise (Reporting.err_general l "Internal loop termination measure found without -dmagic_hash") - in Measure_aux (m,l) - -and to_ast_lexp ctx (P.E_aux(exp,l) : P.exp) : uannot lexp = - let lexp = match exp with + if !opt_magic_hash then Measure_some (to_ast_exp ctx exp) + else raise (Reporting.err_general l "Internal loop termination measure found without -dmagic_hash") + in + Measure_aux (m, l) + +and to_ast_lexp ctx (P.E_aux (exp, l) : P.exp) : uannot lexp = + let lexp = + match exp with | P.E_id id -> LE_id (to_ast_id ctx id) | P.E_deref exp -> LE_deref (to_ast_exp ctx exp) - | P.E_typ (typ, P.E_aux (P.E_id id, l')) -> - LE_typ (to_ast_typ ctx typ, to_ast_id ctx id) + | P.E_typ (typ, P.E_aux (P.E_id id, l')) -> LE_typ (to_ast_typ ctx typ, to_ast_id ctx id) | P.E_tuple tups -> - let ltups = List.map (to_ast_lexp ctx) tups in - let is_ok_in_tup (LE_aux (le, (l, _))) = - match le with - | LE_id _ | LE_typ _ | LE_vector _ | LE_vector_concat _ | LE_field _ | LE_vector_range _ | LE_tuple _ -> () - | LE_app _ | LE_deref _ -> - raise (Reporting.err_typ l "only identifiers, fields, and vectors may be set in a tuple") - in - List.iter is_ok_in_tup ltups; - LE_tuple ltups - | P.E_app ((P.Id_aux (f, l') as f'), args) -> - begin match f with - | P.Id(id) -> - (match List.map (to_ast_exp ctx) args with - | [E_aux (E_lit (L_aux (L_unit, _)), _)] -> LE_app (to_ast_id ctx f', []) - | [E_aux (E_tuple exps,_)] -> LE_app (to_ast_id ctx f', exps) - | args -> LE_app(to_ast_id ctx f', args)) - | _ -> raise (Reporting.err_typ l' "memory call on lefthand side of assignment must begin with an id") - end - | P.E_vector_append (exp1, exp2) -> - LE_vector_concat (to_ast_lexp ctx exp1 :: to_ast_lexp_vector_concat ctx exp2) + let ltups = List.map (to_ast_lexp ctx) tups in + let is_ok_in_tup (LE_aux (le, (l, _))) = + match le with + | LE_id _ | LE_typ _ | LE_vector _ | LE_vector_concat _ | LE_field _ | LE_vector_range _ | LE_tuple _ -> () + | LE_app _ | LE_deref _ -> + raise (Reporting.err_typ l "only identifiers, fields, and vectors may be set in a tuple") + in + List.iter is_ok_in_tup ltups; + LE_tuple ltups + | P.E_app ((P.Id_aux (f, l') as f'), args) -> begin + match f with + | P.Id id -> ( + match List.map (to_ast_exp ctx) args with + | [E_aux (E_lit (L_aux (L_unit, _)), _)] -> LE_app (to_ast_id ctx f', []) + | [E_aux (E_tuple exps, _)] -> LE_app (to_ast_id ctx f', exps) + | args -> LE_app (to_ast_id ctx f', args) + ) + | _ -> raise (Reporting.err_typ l' "memory call on lefthand side of assignment must begin with an id") + end + | P.E_vector_append (exp1, exp2) -> LE_vector_concat (to_ast_lexp ctx exp1 :: to_ast_lexp_vector_concat ctx exp2) | P.E_vector_access (vexp, exp) -> LE_vector (to_ast_lexp ctx vexp, to_ast_exp ctx exp) | P.E_vector_subrange (vexp, exp1, exp2) -> - LE_vector_range (to_ast_lexp ctx vexp, to_ast_exp ctx exp1, to_ast_exp ctx exp2) + LE_vector_range (to_ast_lexp ctx vexp, to_ast_exp ctx exp1, to_ast_exp ctx exp2) | P.E_field (fexp, id) -> LE_field (to_ast_lexp ctx fexp, to_ast_id ctx id) - | _ -> raise (Reporting.err_typ l "Only identifiers, cast identifiers, vector accesses, vector slices, and fields can be on the lefthand side of an assignment") + | _ -> + raise + (Reporting.err_typ l + "Only identifiers, cast identifiers, vector accesses, vector slices, and fields can be on the lefthand \ + side of an assignment" + ) in LE_aux (lexp, (l, empty_uannot)) and to_ast_lexp_vector_concat ctx (P.E_aux (exp_aux, l) as exp) = match exp_aux with - | P.E_vector_append (exp1, exp2) -> - to_ast_lexp ctx exp1 :: to_ast_lexp_vector_concat ctx exp2 + | P.E_vector_append (exp1, exp2) -> to_ast_lexp ctx exp1 :: to_ast_lexp_vector_concat ctx exp2 | _ -> [to_ast_lexp ctx exp] -and to_ast_case ctx (P.Pat_aux(pex,l) : P.pexp) : uannot pexp = +and to_ast_case ctx (P.Pat_aux (pex, l) : P.pexp) : uannot pexp = match pex with - | P.Pat_exp (pat, exp) -> - Pat_aux (Pat_exp (to_ast_pat ctx pat, to_ast_exp ctx exp), (l, empty_uannot)) - | P.Pat_when(pat,guard,exp) -> - Pat_aux (Pat_when (to_ast_pat ctx pat, to_ast_exp ctx guard, to_ast_exp ctx exp), (l, empty_uannot)) + | P.Pat_exp (pat, exp) -> Pat_aux (Pat_exp (to_ast_pat ctx pat, to_ast_exp ctx exp), (l, empty_uannot)) + | P.Pat_when (pat, guard, exp) -> + Pat_aux (Pat_when (to_ast_pat ctx pat, to_ast_exp ctx guard, to_ast_exp ctx exp), (l, empty_uannot)) -and to_ast_fexps (fail_on_error:bool) ctx (exps : P.exp list) : uannot fexp list option = +and to_ast_fexps (fail_on_error : bool) ctx (exps : P.exp list) : uannot fexp list option = match exps with | [] -> Some [] - | fexp::exps -> let maybe_fexp,maybe_error = to_ast_record_try ctx fexp in - (match maybe_fexp,maybe_error with - | Some(fexp), None -> - (match (to_ast_fexps fail_on_error ctx exps) with - | Some(fexps) -> Some(fexp::fexps) - | _ -> None) - | None,Some(l,msg) -> - if fail_on_error - then raise (Reporting.err_typ l msg) - else None - | _ -> None) - -and to_ast_record_try ctx (P.E_aux(exp,l):P.exp): uannot fexp option * (l * string) option = + | fexp :: exps -> ( + let maybe_fexp, maybe_error = to_ast_record_try ctx fexp in + match (maybe_fexp, maybe_error) with + | Some fexp, None -> ( + match to_ast_fexps fail_on_error ctx exps with Some fexps -> Some (fexp :: fexps) | _ -> None + ) + | None, Some (l, msg) -> if fail_on_error then raise (Reporting.err_typ l msg) else None + | _ -> None + ) + +and to_ast_record_try ctx (P.E_aux (exp, l) : P.exp) : uannot fexp option * (l * string) option = match exp with - | P.E_app_infix(left,op,r) -> - (match left, op with - | P.E_aux(P.E_id(id),li), P.Id_aux(P.Id("="),leq) -> - Some(FE_aux(FE_fexp(to_ast_id ctx id, to_ast_exp ctx r), (l, empty_uannot))),None - | P.E_aux(_,li) , P.Id_aux(P.Id("="),leq) -> - None,Some(li,"Expected an identifier to begin this field assignment") - | P.E_aux(P.E_id(id),li), P.Id_aux(_,leq) -> - None,Some(leq,"Expected a field assignment to be identifier = expression") - | P.E_aux(_,li),P.Id_aux(_,leq) -> - None,Some(l,"Expected a field assignment to be identifier = expression")) - | _ -> - None,Some(l, "Expected a field assignment to be identifier = expression") + | P.E_app_infix (left, op, r) -> ( + match (left, op) with + | P.E_aux (P.E_id id, li), P.Id_aux (P.Id "=", leq) -> + (Some (FE_aux (FE_fexp (to_ast_id ctx id, to_ast_exp ctx r), (l, empty_uannot))), None) + | P.E_aux (_, li), P.Id_aux (P.Id "=", leq) -> + (None, Some (li, "Expected an identifier to begin this field assignment")) + | P.E_aux (P.E_id id, li), P.Id_aux (_, leq) -> + (None, Some (leq, "Expected a field assignment to be identifier = expression")) + | P.E_aux (_, li), P.Id_aux (_, leq) -> + (None, Some (l, "Expected a field assignment to be identifier = expression")) + ) + | _ -> (None, Some (l, "Expected a field assignment to be identifier = expression")) type 'a ctx_out = 'a * ctx let to_ast_default ctx (default : P.default_typing_spec) : default_spec ctx_out = match default with - | P.DT_aux(P.DT_order(k,o),l) -> - let k = to_ast_kind k in - match (k,o) with - | K_aux(K_order, _), P.ATyp_aux(P.ATyp_inc,lo) -> - let default_order = Ord_aux(Ord_inc,lo) in - DT_aux(DT_order default_order,l),ctx - | K_aux(K_order, _), P.ATyp_aux(P.ATyp_dec,lo) -> - let default_order = Ord_aux(Ord_dec,lo) in - DT_aux(DT_order default_order,l),ctx - | _ -> raise (Reporting.err_typ l "Inc and Dec must have kind Order") - -let to_ast_extern (ext : P.extern) : extern = - { pure = ext.pure; bindings = ext.bindings } - + | P.DT_aux (P.DT_order (k, o), l) -> ( + let k = to_ast_kind k in + match (k, o) with + | K_aux (K_order, _), P.ATyp_aux (P.ATyp_inc, lo) -> + let default_order = Ord_aux (Ord_inc, lo) in + (DT_aux (DT_order default_order, l), ctx) + | K_aux (K_order, _), P.ATyp_aux (P.ATyp_dec, lo) -> + let default_order = Ord_aux (Ord_dec, lo) in + (DT_aux (DT_order default_order, l), ctx) + | _ -> raise (Reporting.err_typ l "Inc and Dec must have kind Order") + ) + +let to_ast_extern (ext : P.extern) : extern = { pure = ext.pure; bindings = ext.bindings } + let to_ast_spec ctx (vs : P.val_spec) : uannot val_spec ctx_out = match vs with - | P.VS_aux (vs, l) -> - match vs with - | P.VS_val_spec (ts, id, ext, is_cast) -> - let typschm, _ = to_ast_typschm ctx ts in - let ext = Option.map to_ast_extern ext in - VS_aux (VS_val_spec (typschm,to_ast_id ctx id, ext, is_cast), (l, empty_uannot)), ctx - + | P.VS_aux (vs, l) -> ( + match vs with + | P.VS_val_spec (ts, id, ext, is_cast) -> + let typschm, _ = to_ast_typschm ctx ts in + let ext = Option.map to_ast_extern ext in + (VS_aux (VS_val_spec (typschm, to_ast_id ctx id, ext, is_cast), (l, empty_uannot)), ctx) + ) + let to_ast_outcome ctx (ev : P.outcome_spec) : outcome_spec ctx_out = match ev with | P.OV_aux (P.OV_outcome (id, typschm, outcome_args), l) -> - let outcome_args, inner_ctx = - List.fold_left (fun (args, ctx) arg -> let (arg, ctx), _ = to_ast_kopts ctx arg in (arg @ args, ctx)) ([], ctx) outcome_args - in - let typschm, _ = to_ast_typschm inner_ctx typschm in - OV_aux (OV_outcome (to_ast_id ctx id, typschm, List.rev outcome_args), l), inner_ctx - -let rec to_ast_range ctx (P.BF_aux(r,l)) = (* TODO add check that ranges are sensible for some definition of sensible *) - BF_aux( - (match r with - | P.BF_single(i) -> BF_single (to_ast_bitfield_index_nexp ctx i) - | P.BF_range(i1,i2) -> BF_range (to_ast_bitfield_index_nexp ctx i1, to_ast_bitfield_index_nexp ctx i2) - | P.BF_concat(ir1,ir2) -> BF_concat (to_ast_range ctx ir1, to_ast_range ctx ir2)), - l) + let outcome_args, inner_ctx = + List.fold_left + (fun (args, ctx) arg -> + let (arg, ctx), _ = to_ast_kopts ctx arg in + (arg @ args, ctx) + ) + ([], ctx) outcome_args + in + let typschm, _ = to_ast_typschm inner_ctx typschm in + (OV_aux (OV_outcome (to_ast_id ctx id, typschm, List.rev outcome_args), l), inner_ctx) + +let rec to_ast_range ctx (P.BF_aux (r, l)) = + (* TODO add check that ranges are sensible for some definition of sensible *) + BF_aux + ( ( match r with + | P.BF_single i -> BF_single (to_ast_bitfield_index_nexp ctx i) + | P.BF_range (i1, i2) -> BF_range (to_ast_bitfield_index_nexp ctx i1, to_ast_bitfield_index_nexp ctx i2) + | P.BF_concat (ir1, ir2) -> BF_concat (to_ast_range ctx ir1, to_ast_range ctx ir2) + ), + l + ) let to_ast_type_union ctx = function | P.Tu_aux (P.Tu_ty_id (atyp, id), l) -> - let typ = to_ast_typ ctx atyp in - Tu_aux (Tu_ty_id (typ, to_ast_id ctx id), l) + let typ = to_ast_typ ctx atyp in + Tu_aux (Tu_ty_id (typ, to_ast_id ctx id), l) | P.Tu_aux (_, l) -> - raise (Reporting.err_unreachable l __POS__ "Anonymous record type should have been rewritten by now") + raise (Reporting.err_unreachable l __POS__ "Anonymous record type should have been rewritten by now") let add_constructor id typq ctx = let kinds = List.map (fun kopt -> unaux_kind (kopt_kind kopt)) (quant_kopts typq) in @@ -603,287 +638,302 @@ let add_constructor id typq ctx = let anon_rec_constructor_typ record_id = function | P.TypQ_aux (P.TypQ_no_forall, l) -> P.ATyp_aux (P.ATyp_id record_id, Generated l) - | P.TypQ_aux (P.TypQ_tq quants, l) -> - let quant_arg = function - | P.QI_aux (P.QI_id (P.KOpt_aux (P.KOpt_kind (_, vs, _), l)), _) -> - List.map (fun v -> P.ATyp_aux (P.ATyp_var v, Generated l)) vs - | P.QI_aux (P.QI_constraint _, _) -> [] - in - match List.concat (List.map quant_arg quants) with - | [] -> P.ATyp_aux (P.ATyp_id record_id, Generated l) - | args -> P.ATyp_aux (P.ATyp_app (record_id, args), Generated l) + | P.TypQ_aux (P.TypQ_tq quants, l) -> ( + let quant_arg = function + | P.QI_aux (P.QI_id (P.KOpt_aux (P.KOpt_kind (_, vs, _), l)), _) -> + List.map (fun v -> P.ATyp_aux (P.ATyp_var v, Generated l)) vs + | P.QI_aux (P.QI_constraint _, _) -> [] + in + match List.concat (List.map quant_arg quants) with + | [] -> P.ATyp_aux (P.ATyp_id record_id, Generated l) + | args -> P.ATyp_aux (P.ATyp_app (record_id, args), Generated l) + ) let rec realise_union_anon_rec_types orig_union arms = match orig_union with - | P.TD_variant (union_id, typq, _, flag) -> - begin match arms with - | [] -> [] - | arm :: arms -> - match arm with - | (P.Tu_aux ((P.Tu_ty_id _), _)) -> (None, arm) :: realise_union_anon_rec_types orig_union arms - | (P.Tu_aux ((P.Tu_ty_anon_rec (fields, id)), l)) -> - let open Parse_ast in - let record_str = "_" ^ string_of_parse_id union_id ^ "_" ^ string_of_parse_id id ^ "_record" in - let record_id = Id_aux (Id record_str, Generated l) in - let new_arm = Tu_aux (Tu_ty_id (anon_rec_constructor_typ record_id typq, id), Generated l) in - let new_rec_def = TD_aux (TD_record (record_id, typq, fields, flag), Generated l) in - (Some new_rec_def, new_arm) :: (realise_union_anon_rec_types orig_union arms) - end + | P.TD_variant (union_id, typq, _, flag) -> begin + match arms with + | [] -> [] + | arm :: arms -> ( + match arm with + | P.Tu_aux (P.Tu_ty_id _, _) -> (None, arm) :: realise_union_anon_rec_types orig_union arms + | P.Tu_aux (P.Tu_ty_anon_rec (fields, id), l) -> + let open Parse_ast in + let record_str = "_" ^ string_of_parse_id union_id ^ "_" ^ string_of_parse_id id ^ "_record" in + let record_id = Id_aux (Id record_str, Generated l) in + let new_arm = Tu_aux (Tu_ty_id (anon_rec_constructor_typ record_id typq, id), Generated l) in + let new_rec_def = TD_aux (TD_record (record_id, typq, fields, flag), Generated l) in + (Some new_rec_def, new_arm) :: realise_union_anon_rec_types orig_union arms + ) + end | _ -> - raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Non union type-definition passed to realise_union_anon_rec_typs") + raise + (Reporting.err_unreachable Parse_ast.Unknown __POS__ + "Non union type-definition passed to realise_union_anon_rec_typs" + ) let generate_enum_functions l ctx enum_id fns exps = let get_exp i = function - | Some (P.E_aux (P.E_tuple exps, _)) -> List.nth exps i + | Some (P.E_aux (P.E_tuple exps, _)) -> List.nth exps i | Some exp -> exp | None -> Reporting.unreachable l __POS__ "get_exp called without expression" in - let num_exps = function - | Some (P.E_aux (P.E_tuple exps, _)) -> List.length exps - | Some _ -> 1 - | None -> 0 - in + let num_exps = function Some (P.E_aux (P.E_tuple exps, _)) -> List.length exps | Some _ -> 1 | None -> 0 in let num_fns = List.length fns in - List.iter (fun (id, exp) -> + List.iter + (fun (id, exp) -> let n = num_exps exp in if n <> num_fns then ( - let l = (match exp with Some (P.E_aux (_, l)) -> l | None -> parse_id_loc id) in - raise (Reporting.err_general l - (sprintf "Each enumeration clause for %s must define exactly %d expressions for the functions %s\n\ - %s expressions have been given here" - (string_of_id enum_id) - num_fns - (string_of_list ", " string_of_parse_id (List.map fst fns)) - (if n = 0 then "No" else if n > num_fns then "Too many" else "Too few"))) + let l = match exp with Some (P.E_aux (_, l)) -> l | None -> parse_id_loc id in + raise + (Reporting.err_general l + (sprintf + "Each enumeration clause for %s must define exactly %d expressions for the functions %s\n\ + %s expressions have been given here" (string_of_id enum_id) num_fns + (string_of_list ", " string_of_parse_id (List.map fst fns)) + (if n = 0 then "No" else if n > num_fns then "Too many" else "Too few") + ) + ) ) - ) exps; - List.mapi (fun i (id, typ) -> + ) + exps; + List.mapi + (fun i (id, typ) -> let typ = to_ast_typ ctx typ in let name = mk_id (string_of_id enum_id ^ "_" ^ string_of_parse_id id) in - [mk_fundef [ - mk_funcl name (mk_pat (P_id (mk_id "arg#"))) - (mk_exp (E_match (mk_exp (E_id (mk_id "arg#")), - List.map (fun (id, exps) -> - let id = to_ast_id ctx id in - let exp = to_ast_exp ctx (get_exp i exps) in - mk_pexp (Pat_exp (mk_pat (P_id id), exp)) - ) exps))) - ]; - mk_val_spec (VS_val_spec (mk_typschm (mk_typquant []) (function_typ [mk_id_typ enum_id] typ), - name, - None, - false))] - ) fns + [ + mk_fundef + [ + mk_funcl name + (mk_pat (P_id (mk_id "arg#"))) + (mk_exp + (E_match + ( mk_exp (E_id (mk_id "arg#")), + List.map + (fun (id, exps) -> + let id = to_ast_id ctx id in + let exp = to_ast_exp ctx (get_exp i exps) in + mk_pexp (Pat_exp (mk_pat (P_id id), exp)) + ) + exps + ) + ) + ); + ]; + mk_val_spec (VS_val_spec (mk_typschm (mk_typquant []) (function_typ [mk_id_typ enum_id] typ), name, None, false)); + ] + ) + fns |> List.concat (* When desugaring a type definition, we check that the type does not have a reserved name *) let to_ast_reserved_type_id ctx id = let id = to_ast_id ctx id in - if List.exists (fun reserved -> Id.compare reserved id = 0) ctx.reserved_type_ids then - begin match Reporting.loc_file (id_loc id) with + if List.exists (fun reserved -> Id.compare reserved id = 0) ctx.reserved_type_ids then begin + match Reporting.loc_file (id_loc id) with | Some file when !opt_magic_hash || List.exists (fun internal_file -> file = internal_file) ctx.internal_files -> id | None -> id - | Some file -> - raise (Reporting.err_general (id_loc id) (sprintf "The type name %s is reserved" (string_of_id id))) - end - else - id + | Some file -> raise (Reporting.err_general (id_loc id) (sprintf "The type name %s is reserved" (string_of_id id))) + end + else id let rec to_ast_typedef ctx def_annot (P.TD_aux (aux, l) : P.type_def) : uannot def list ctx_out = match aux with | P.TD_abbrev (id, typq, kind, typ_arg) -> - let id = to_ast_reserved_type_id ctx id in - let typq, typq_ctx = to_ast_typquant ctx typq in - let kind = to_ast_kind kind in - let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in - [DEF_aux (DEF_type (TD_aux (TD_abbrev (id, typq, typ_arg), (l, empty_uannot))), def_annot)], - add_constructor id typq ctx - + let id = to_ast_reserved_type_id ctx id in + let typq, typq_ctx = to_ast_typquant ctx typq in + let kind = to_ast_kind kind in + let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in + ( [DEF_aux (DEF_type (TD_aux (TD_abbrev (id, typq, typ_arg), (l, empty_uannot))), def_annot)], + add_constructor id typq ctx + ) | P.TD_record (id, typq, fields, _) -> - let id = to_ast_reserved_type_id ctx id in - let typq, typq_ctx = to_ast_typquant ctx typq in - let fields = List.map (fun (atyp, id) -> to_ast_typ typq_ctx atyp, to_ast_id ctx id) fields in - [DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields, false), (l, empty_uannot))), def_annot)], - add_constructor id typq ctx - + let id = to_ast_reserved_type_id ctx id in + let typq, typq_ctx = to_ast_typquant ctx typq in + let fields = List.map (fun (atyp, id) -> (to_ast_typ typq_ctx atyp, to_ast_id ctx id)) fields in + ( [DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields, false), (l, empty_uannot))), def_annot)], + add_constructor id typq ctx + ) | P.TD_variant (id, typq, arms, _) as union -> - (* First generate auxilliary record types for anonymous records in constructors *) - let records_and_arms = realise_union_anon_rec_types union arms in - let rec filter_records = function - | [] -> [] - | Some x :: xs -> x :: filter_records xs - | None :: xs -> filter_records xs - in - let generated_records = filter_records (List.map fst records_and_arms) in - let generated_records, ctx = - List.fold_left (fun (prev, ctx) td -> let td, ctx = to_ast_typedef ctx (mk_def_annot (gen_loc l)) td in prev @ td, ctx) - ([], ctx) - generated_records - in - let arms = List.map snd records_and_arms in - (* Now generate the AST union type *) - let id = to_ast_reserved_type_id ctx id in - let typq, typq_ctx = to_ast_typquant ctx typq in - let arms = List.map (to_ast_type_union (add_constructor id typq typq_ctx)) arms in - [DEF_aux (DEF_type (TD_aux (TD_variant (id, typq, arms, false), (l, empty_uannot))), def_annot)] @ generated_records, - add_constructor id typq ctx - + (* First generate auxilliary record types for anonymous records in constructors *) + let records_and_arms = realise_union_anon_rec_types union arms in + let rec filter_records = function + | [] -> [] + | Some x :: xs -> x :: filter_records xs + | None :: xs -> filter_records xs + in + let generated_records = filter_records (List.map fst records_and_arms) in + let generated_records, ctx = + List.fold_left + (fun (prev, ctx) td -> + let td, ctx = to_ast_typedef ctx (mk_def_annot (gen_loc l)) td in + (prev @ td, ctx) + ) + ([], ctx) generated_records + in + let arms = List.map snd records_and_arms in + (* Now generate the AST union type *) + let id = to_ast_reserved_type_id ctx id in + let typq, typq_ctx = to_ast_typquant ctx typq in + let arms = List.map (to_ast_type_union (add_constructor id typq typq_ctx)) arms in + ( [DEF_aux (DEF_type (TD_aux (TD_variant (id, typq, arms, false), (l, empty_uannot))), def_annot)] + @ generated_records, + add_constructor id typq ctx + ) | P.TD_enum (id, fns, enums, _) -> - let id = to_ast_reserved_type_id ctx id in - let fns = generate_enum_functions l ctx id fns enums in - let enums = List.map (fun e -> to_ast_id ctx (fst e)) enums in - fns @ [DEF_aux (DEF_type (TD_aux (TD_enum (id, enums, false), (l, empty_uannot))), def_annot)], - { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } - + let id = to_ast_reserved_type_id ctx id in + let fns = generate_enum_functions l ctx id fns enums in + let enums = List.map (fun e -> to_ast_id ctx (fst e)) enums in + ( fns @ [DEF_aux (DEF_type (TD_aux (TD_enum (id, enums, false), (l, empty_uannot))), def_annot)], + { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } + ) | P.TD_bitfield (id, typ, ranges) -> - let id = to_ast_reserved_type_id ctx id in - let typ = to_ast_typ ctx typ in - let ranges = List.map (fun (id, range) -> (to_ast_id ctx id, to_ast_range ctx range)) ranges in - [DEF_aux (DEF_type (TD_aux (TD_bitfield (id, typ, ranges), (l, empty_uannot))), def_annot)], - { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } - -let to_ast_rec ctx (P.Rec_aux(r,l): P.rec_opt) : uannot rec_opt = - Rec_aux((match r with - | P.Rec_none -> Rec_nonrec - | P.Rec_measure (p,e) -> - Rec_measure (to_ast_pat ctx p, to_ast_exp ctx e) - ),l) - -let to_ast_tannot_opt ctx (P.Typ_annot_opt_aux(tp,l)) : tannot_opt ctx_out = + let id = to_ast_reserved_type_id ctx id in + let typ = to_ast_typ ctx typ in + let ranges = List.map (fun (id, range) -> (to_ast_id ctx id, to_ast_range ctx range)) ranges in + ( [DEF_aux (DEF_type (TD_aux (TD_bitfield (id, typ, ranges), (l, empty_uannot))), def_annot)], + { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } + ) + +let to_ast_rec ctx (P.Rec_aux (r, l) : P.rec_opt) : uannot rec_opt = + Rec_aux + ( ( match r with + | P.Rec_none -> Rec_nonrec + | P.Rec_measure (p, e) -> Rec_measure (to_ast_pat ctx p, to_ast_exp ctx e) + ), + l + ) + +let to_ast_tannot_opt ctx (P.Typ_annot_opt_aux (tp, l)) : tannot_opt ctx_out = match tp with - | P.Typ_annot_opt_none -> - Typ_annot_opt_aux (Typ_annot_opt_none, l), ctx - | P.Typ_annot_opt_some(tq,typ) -> - let typq, ctx = to_ast_typquant ctx tq in - Typ_annot_opt_aux (Typ_annot_opt_some(typq,to_ast_typ ctx typ),l),ctx + | P.Typ_annot_opt_none -> (Typ_annot_opt_aux (Typ_annot_opt_none, l), ctx) + | P.Typ_annot_opt_some (tq, typ) -> + let typq, ctx = to_ast_typquant ctx tq in + (Typ_annot_opt_aux (Typ_annot_opt_some (typq, to_ast_typ ctx typ), l), ctx) -let to_ast_typschm_opt ctx (P.TypSchm_opt_aux(aux,l)) : tannot_opt ctx_out = +let to_ast_typschm_opt ctx (P.TypSchm_opt_aux (aux, l)) : tannot_opt ctx_out = match aux with - | P.TypSchm_opt_none -> - Typ_annot_opt_aux (Typ_annot_opt_none, l), ctx + | P.TypSchm_opt_none -> (Typ_annot_opt_aux (Typ_annot_opt_none, l), ctx) | P.TypSchm_opt_some (P.TypSchm_aux (P.TypSchm_ts (tq, typ), l)) -> - let typq, ctx = to_ast_typquant ctx tq in - Typ_annot_opt_aux (Typ_annot_opt_some (typq, to_ast_typ ctx typ), l), ctx + let typq, ctx = to_ast_typquant ctx tq in + (Typ_annot_opt_aux (Typ_annot_opt_some (typq, to_ast_typ ctx typ), l), ctx) let to_ast_funcl ctx (P.FCL_aux (fcl, l) : P.funcl) : uannot funcl = match fcl with | P.FCL_funcl (id, pexp) -> - FCL_aux (FCL_funcl (to_ast_id ctx id, to_ast_case ctx pexp), (mk_def_annot l, empty_uannot)) + FCL_aux (FCL_funcl (to_ast_id ctx id, to_ast_case ctx pexp), (mk_def_annot l, empty_uannot)) let to_ast_impl_funcls ctx (P.FCL_aux (fcl, l) : P.funcl) : uannot funcl list = match fcl with - | P.FCL_funcl (id, pexp) -> - match List.assoc_opt (string_of_parse_id id) ctx.target_sets with - | Some targets -> - List.map (fun target -> - FCL_aux (FCL_funcl (Id_aux (Id target, parse_id_loc id), to_ast_case ctx pexp), (mk_def_annot l, empty_uannot)) - ) targets - | None -> - [FCL_aux (FCL_funcl (to_ast_id ctx id, to_ast_case ctx pexp), (mk_def_annot l, empty_uannot))] - -let to_ast_fundef ctx (P.FD_aux(fd,l):P.fundef) : uannot fundef = + | P.FCL_funcl (id, pexp) -> ( + match List.assoc_opt (string_of_parse_id id) ctx.target_sets with + | Some targets -> + List.map + (fun target -> + FCL_aux + (FCL_funcl (Id_aux (Id target, parse_id_loc id), to_ast_case ctx pexp), (mk_def_annot l, empty_uannot)) + ) + targets + | None -> [FCL_aux (FCL_funcl (to_ast_id ctx id, to_ast_case ctx pexp), (mk_def_annot l, empty_uannot))] + ) + +let to_ast_fundef ctx (P.FD_aux (fd, l) : P.fundef) : uannot fundef = match fd with | P.FD_function (rec_opt, tannot_opt, _, funcls) -> - let tannot_opt, ctx = to_ast_tannot_opt ctx tannot_opt in - FD_aux(FD_function(to_ast_rec ctx rec_opt, tannot_opt, List.map (to_ast_funcl ctx) funcls), (l, empty_uannot)) - -let rec to_ast_mpat ctx (P.MP_aux(mpat,l)) = - MP_aux ( - (match mpat with - | P.MP_lit lit -> MP_lit (to_ast_lit lit) - | P.MP_id id -> MP_id (to_ast_id ctx id) - | P.MP_as (mpat, id) -> MP_as (to_ast_mpat ctx mpat, to_ast_id ctx id) - | P.MP_app (id, mpats) -> - if mpats = [] - then MP_id (to_ast_id ctx id) - else MP_app (to_ast_id ctx id, List.map (to_ast_mpat ctx) mpats) - | P.MP_vector mpats -> MP_vector (List.map (to_ast_mpat ctx) mpats) - | P.MP_vector_concat mpats -> MP_vector_concat (List.map (to_ast_mpat ctx) mpats) - | P.MP_vector_subrange (id, n, m) -> MP_vector_subrange (to_ast_id ctx id, n, m) - | P.MP_tuple mpats -> MP_tuple (List.map (to_ast_mpat ctx) mpats) - | P.MP_list mpats -> MP_list (List.map (to_ast_mpat ctx) mpats) - | P.MP_cons (pat1, pat2) -> MP_cons (to_ast_mpat ctx pat1, to_ast_mpat ctx pat2) - | P.MP_string_append pats -> MP_string_append (List.map (to_ast_mpat ctx) pats) - | P.MP_typ (mpat, typ) -> MP_typ (to_ast_mpat ctx mpat, to_ast_typ ctx typ) - ), (l, empty_uannot) + let tannot_opt, ctx = to_ast_tannot_opt ctx tannot_opt in + FD_aux (FD_function (to_ast_rec ctx rec_opt, tannot_opt, List.map (to_ast_funcl ctx) funcls), (l, empty_uannot)) + +let rec to_ast_mpat ctx (P.MP_aux (mpat, l)) = + MP_aux + ( ( match mpat with + | P.MP_lit lit -> MP_lit (to_ast_lit lit) + | P.MP_id id -> MP_id (to_ast_id ctx id) + | P.MP_as (mpat, id) -> MP_as (to_ast_mpat ctx mpat, to_ast_id ctx id) + | P.MP_app (id, mpats) -> + if mpats = [] then MP_id (to_ast_id ctx id) else MP_app (to_ast_id ctx id, List.map (to_ast_mpat ctx) mpats) + | P.MP_vector mpats -> MP_vector (List.map (to_ast_mpat ctx) mpats) + | P.MP_vector_concat mpats -> MP_vector_concat (List.map (to_ast_mpat ctx) mpats) + | P.MP_vector_subrange (id, n, m) -> MP_vector_subrange (to_ast_id ctx id, n, m) + | P.MP_tuple mpats -> MP_tuple (List.map (to_ast_mpat ctx) mpats) + | P.MP_list mpats -> MP_list (List.map (to_ast_mpat ctx) mpats) + | P.MP_cons (pat1, pat2) -> MP_cons (to_ast_mpat ctx pat1, to_ast_mpat ctx pat2) + | P.MP_string_append pats -> MP_string_append (List.map (to_ast_mpat ctx) pats) + | P.MP_typ (mpat, typ) -> MP_typ (to_ast_mpat ctx mpat, to_ast_typ ctx typ) + ), + (l, empty_uannot) ) -let to_ast_mpexp ctx (P.MPat_aux(mpexp, l)) = +let to_ast_mpexp ctx (P.MPat_aux (mpexp, l)) = match mpexp with | P.MPat_pat mpat -> MPat_aux (MPat_pat (to_ast_mpat ctx mpat), (l, empty_uannot)) | P.MPat_when (mpat, exp) -> MPat_aux (MPat_when (to_ast_mpat ctx mpat, to_ast_exp ctx exp), (l, empty_uannot)) -let to_ast_mapcl ctx (P.MCL_aux(mapcl, l)) = +let to_ast_mapcl ctx (P.MCL_aux (mapcl, l)) = let def_annot = mk_def_annot l in match mapcl with | P.MCL_bidir (mpexp1, mpexp2) -> - MCL_aux (MCL_bidir (to_ast_mpexp ctx mpexp1, to_ast_mpexp ctx mpexp2), (def_annot, empty_uannot)) + MCL_aux (MCL_bidir (to_ast_mpexp ctx mpexp1, to_ast_mpexp ctx mpexp2), (def_annot, empty_uannot)) | P.MCL_forwards (mpexp, exp) -> - MCL_aux (MCL_forwards (to_ast_mpexp ctx mpexp, to_ast_exp ctx exp), (def_annot, empty_uannot)) + MCL_aux (MCL_forwards (to_ast_mpexp ctx mpexp, to_ast_exp ctx exp), (def_annot, empty_uannot)) | P.MCL_backwards (mpexp, exp) -> - MCL_aux (MCL_backwards (to_ast_mpexp ctx mpexp, to_ast_exp ctx exp), (def_annot, empty_uannot)) + MCL_aux (MCL_backwards (to_ast_mpexp ctx mpexp, to_ast_exp ctx exp), (def_annot, empty_uannot)) -let to_ast_mapdef ctx (P.MD_aux(md,l):P.mapdef) : uannot mapdef = +let to_ast_mapdef ctx (P.MD_aux (md, l) : P.mapdef) : uannot mapdef = match md with - | P.MD_mapping(id, typschm_opt, mapcls) -> - let tannot_opt, ctx = to_ast_typschm_opt ctx typschm_opt in - MD_aux(MD_mapping(to_ast_id ctx id, tannot_opt, List.map (to_ast_mapcl ctx) mapcls), (l, empty_uannot)) - -let to_ast_dec ctx (P.DEC_aux(regdec,l)) = - DEC_aux ( - (match regdec with - | P.DEC_reg (typ, id, opt_exp) -> - let opt_exp = match opt_exp with - | None -> None - | Some exp -> Some (to_ast_exp ctx exp) - in + | P.MD_mapping (id, typschm_opt, mapcls) -> + let tannot_opt, ctx = to_ast_typschm_opt ctx typschm_opt in + MD_aux (MD_mapping (to_ast_id ctx id, tannot_opt, List.map (to_ast_mapcl ctx) mapcls), (l, empty_uannot)) + +let to_ast_dec ctx (P.DEC_aux (regdec, l)) = + DEC_aux + ( ( match regdec with + | P.DEC_reg (typ, id, opt_exp) -> + let opt_exp = match opt_exp with None -> None | Some exp -> Some (to_ast_exp ctx exp) in DEC_reg (to_ast_typ ctx typ, to_ast_id ctx id, opt_exp) - ), (l, empty_uannot) + ), + (l, empty_uannot) ) let to_ast_scattered ctx (P.SD_aux (aux, l)) = - let aux, ctx = match aux with + let aux, ctx = + match aux with | P.SD_function (rec_opt, tannot_opt, _, id) -> - let tannot_opt, _ = to_ast_tannot_opt ctx tannot_opt in - SD_function (to_ast_rec ctx rec_opt, tannot_opt, to_ast_id ctx id), ctx - | P.SD_funcl funcl -> - SD_funcl (to_ast_funcl ctx funcl), ctx + let tannot_opt, _ = to_ast_tannot_opt ctx tannot_opt in + (SD_function (to_ast_rec ctx rec_opt, tannot_opt, to_ast_id ctx id), ctx) + | P.SD_funcl funcl -> (SD_funcl (to_ast_funcl ctx funcl), ctx) | P.SD_variant (id, typq) -> - let id = to_ast_id ctx id in - let typq, typq_ctx = to_ast_typquant ctx typq in - SD_variant (id, typq), - add_constructor id typq { ctx with scattereds = Bindings.add id typq_ctx ctx.scattereds } + let id = to_ast_id ctx id in + let typq, typq_ctx = to_ast_typquant ctx typq in + ( SD_variant (id, typq), + add_constructor id typq { ctx with scattereds = Bindings.add id typq_ctx ctx.scattereds } + ) | P.SD_unioncl (id, tu) -> - let id = to_ast_id ctx id in - begin match Bindings.find_opt id ctx.scattereds with - | Some typq_ctx -> - let tu = to_ast_type_union typq_ctx tu in - SD_unioncl (id, tu), ctx - | None -> raise (Reporting.err_typ l ("No scattered union declaration found for " ^ string_of_id id)) - end - | P.SD_end id -> SD_end (to_ast_id ctx id), ctx + let id = to_ast_id ctx id in + begin + match Bindings.find_opt id ctx.scattereds with + | Some typq_ctx -> + let tu = to_ast_type_union typq_ctx tu in + (SD_unioncl (id, tu), ctx) + | None -> raise (Reporting.err_typ l ("No scattered union declaration found for " ^ string_of_id id)) + end + | P.SD_end id -> (SD_end (to_ast_id ctx id), ctx) | P.SD_mapping (id, tannot_opt) -> - let id = to_ast_id ctx id in - let tannot_opt, _ = to_ast_tannot_opt ctx tannot_opt in - SD_mapping (id, tannot_opt), ctx + let id = to_ast_id ctx id in + let tannot_opt, _ = to_ast_tannot_opt ctx tannot_opt in + (SD_mapping (id, tannot_opt), ctx) | P.SD_mapcl (id, mapcl) -> - let id = to_ast_id ctx id in - let mapcl = to_ast_mapcl ctx mapcl in - SD_mapcl (id, mapcl), ctx + let id = to_ast_id ctx id in + let mapcl = to_ast_mapcl ctx mapcl in + (SD_mapcl (id, mapcl), ctx) in - SD_aux (aux, (l, empty_uannot)), ctx + (SD_aux (aux, (l, empty_uannot)), ctx) -let to_ast_prec = function - | P.Infix -> Infix - | P.InfixL -> InfixL - | P.InfixR -> InfixR +let to_ast_prec = function P.Infix -> Infix | P.InfixL -> InfixL | P.InfixR -> InfixR let to_ast_subst ctx = function - | P.IS_aux (P.IS_id (id_from, id_to), l) -> - IS_aux (IS_id (to_ast_id ctx id_from, to_ast_id ctx id_to), l) - | P.IS_aux (P.IS_typ (kid, typ), l) -> - IS_aux (IS_typ (to_ast_var kid, to_ast_typ ctx typ), l) - + | P.IS_aux (P.IS_id (id_from, id_to), l) -> IS_aux (IS_id (to_ast_id ctx id_from, to_ast_id ctx id_to), l) + | P.IS_aux (P.IS_typ (kid, typ), l) -> IS_aux (IS_typ (to_ast_var kid, to_ast_typ ctx typ), l) + let to_ast_loop_measure ctx = function | P.Loop (P.While, exp) -> Loop (While, to_ast_exp ctx exp) | P.Loop (P.Until, exp) -> Loop (Until, to_ast_exp ctx exp) @@ -892,111 +942,125 @@ let rec to_ast_def doc attrs ctx (P.DEF_aux (def, l)) : uannot def list ctx_out let annot = List.fold_left (fun a (attr, arg, l) -> add_def_attribute l attr arg a) (mk_def_annot l) attrs in let annot = { annot with doc_comment = doc } in match def with - | P.DEF_attribute (attr, arg, def) -> - to_ast_def doc ((attr, arg, l) :: attrs) ctx def - | P.DEF_doc (doc_comment, def) -> - begin match doc with - | Some _ -> - raise (Reporting.err_general l "Toplevel definition has multiple documentation comments") - | None -> - to_ast_def (Some doc_comment) attrs ctx def - end - | P.DEF_overload (id, ids) -> - [DEF_aux (DEF_overload (to_ast_id ctx id, List.map (to_ast_id ctx) ids), annot)], ctx - | P.DEF_fixity (prec, n, op) -> - [DEF_aux (DEF_fixity (to_ast_prec prec, n, to_ast_id ctx op), annot)], ctx - | P.DEF_type t_def -> - to_ast_typedef ctx annot t_def + | P.DEF_attribute (attr, arg, def) -> to_ast_def doc ((attr, arg, l) :: attrs) ctx def + | P.DEF_doc (doc_comment, def) -> begin + match doc with + | Some _ -> raise (Reporting.err_general l "Toplevel definition has multiple documentation comments") + | None -> to_ast_def (Some doc_comment) attrs ctx def + end + | P.DEF_overload (id, ids) -> ([DEF_aux (DEF_overload (to_ast_id ctx id, List.map (to_ast_id ctx) ids), annot)], ctx) + | P.DEF_fixity (prec, n, op) -> ([DEF_aux (DEF_fixity (to_ast_prec prec, n, to_ast_id ctx op), annot)], ctx) + | P.DEF_type t_def -> to_ast_typedef ctx annot t_def | P.DEF_fundef f_def -> - let fd = to_ast_fundef ctx f_def in - [DEF_aux (DEF_fundef fd, annot)], ctx + let fd = to_ast_fundef ctx f_def in + ([DEF_aux (DEF_fundef fd, annot)], ctx) | P.DEF_mapdef m_def -> - let md = to_ast_mapdef ctx m_def in - [DEF_aux (DEF_mapdef md, annot)], ctx + let md = to_ast_mapdef ctx m_def in + ([DEF_aux (DEF_mapdef md, annot)], ctx) | P.DEF_impl funcl -> - let funcls = to_ast_impl_funcls ctx funcl in - List.map (fun funcl -> DEF_aux (DEF_impl funcl, annot)) funcls, ctx + let funcls = to_ast_impl_funcls ctx funcl in + (List.map (fun funcl -> DEF_aux (DEF_impl funcl, annot)) funcls, ctx) | P.DEF_let lb -> - let lb = to_ast_letbind ctx lb in - [DEF_aux (DEF_let lb, annot)], ctx + let lb = to_ast_letbind ctx lb in + ([DEF_aux (DEF_let lb, annot)], ctx) | P.DEF_val val_spec -> - let vs,ctx = to_ast_spec ctx val_spec in - [DEF_aux (DEF_val vs, annot)], ctx + let vs, ctx = to_ast_spec ctx val_spec in + ([DEF_aux (DEF_val vs, annot)], ctx) | P.DEF_outcome (outcome_spec, defs) -> - let outcome_spec, inner_ctx = to_ast_outcome ctx outcome_spec in - let defs, _ = - List.fold_left (fun (defs, ctx) def -> let def, ctx = to_ast_def doc attrs ctx def in (def @ defs, ctx)) ([], inner_ctx) defs - in - [DEF_aux (DEF_outcome (outcome_spec, List.rev defs), annot)], ctx + let outcome_spec, inner_ctx = to_ast_outcome ctx outcome_spec in + let defs, _ = + List.fold_left + (fun (defs, ctx) def -> + let def, ctx = to_ast_def doc attrs ctx def in + (def @ defs, ctx) + ) + ([], inner_ctx) defs + in + ([DEF_aux (DEF_outcome (outcome_spec, List.rev defs), annot)], ctx) | P.DEF_instantiation (id, substs) -> - let id = to_ast_id ctx id in - [DEF_aux (DEF_instantiation (IN_aux (IN_id id, (id_loc id, empty_uannot)), List.map (to_ast_subst ctx) substs), annot)], ctx + let id = to_ast_id ctx id in + ( [ + DEF_aux + (DEF_instantiation (IN_aux (IN_id id, (id_loc id, empty_uannot)), List.map (to_ast_subst ctx) substs), annot); + ], + ctx + ) | P.DEF_default typ_spec -> - let default,ctx = to_ast_default ctx typ_spec in - [DEF_aux (DEF_default default, annot)], ctx + let default, ctx = to_ast_default ctx typ_spec in + ([DEF_aux (DEF_default default, annot)], ctx) | P.DEF_register dec -> - let d = to_ast_dec ctx dec in - [DEF_aux (DEF_register d, annot)], ctx - | P.DEF_pragma ("sail_internal", arg) -> - begin match Reporting.loc_file l with - | Some file -> - [DEF_aux (DEF_pragma ("sail_internal", arg, l), annot)], { ctx with internal_files = file :: ctx.internal_files } - | None -> - [DEF_aux (DEF_pragma ("sail_internal", arg, l), annot)], ctx - end + let d = to_ast_dec ctx dec in + ([DEF_aux (DEF_register d, annot)], ctx) + | P.DEF_pragma ("sail_internal", arg) -> begin + match Reporting.loc_file l with + | Some file -> + ( [DEF_aux (DEF_pragma ("sail_internal", arg, l), annot)], + { ctx with internal_files = file :: ctx.internal_files } + ) + | None -> ([DEF_aux (DEF_pragma ("sail_internal", arg, l), annot)], ctx) + end | P.DEF_pragma ("target_set", arg) -> - let args = String.split_on_char ' ' arg |> List.filter (fun s -> String.length s > 0) in - begin match args with - | (set :: targets) -> - [DEF_aux (DEF_pragma ("target_set", arg, l), annot)], { ctx with target_sets = (set, targets) :: ctx.target_sets } - | [] -> - raise (Reporting.err_general l "No arguments provided to target set directive") - end - | P.DEF_pragma (pragma, arg) -> - [DEF_aux (DEF_pragma (pragma, arg, l), annot)], ctx + let args = String.split_on_char ' ' arg |> List.filter (fun s -> String.length s > 0) in + begin + match args with + | set :: targets -> + ( [DEF_aux (DEF_pragma ("target_set", arg, l), annot)], + { ctx with target_sets = (set, targets) :: ctx.target_sets } + ) + | [] -> raise (Reporting.err_general l "No arguments provided to target set directive") + end + | P.DEF_pragma (pragma, arg) -> ([DEF_aux (DEF_pragma (pragma, arg, l), annot)], ctx) | P.DEF_internal_mutrec _ -> - (* Should never occur because of remove_mutrec *) - raise (Reporting.err_unreachable l __POS__ - "Internal mutual block found when processing scattered defs") + (* Should never occur because of remove_mutrec *) + raise (Reporting.err_unreachable l __POS__ "Internal mutual block found when processing scattered defs") | P.DEF_scattered sdef -> - let sdef, ctx = to_ast_scattered ctx sdef in - [DEF_aux (DEF_scattered sdef, annot)], ctx + let sdef, ctx = to_ast_scattered ctx sdef in + ([DEF_aux (DEF_scattered sdef, annot)], ctx) | P.DEF_measure (id, pat, exp) -> - [DEF_aux (DEF_measure (to_ast_id ctx id, to_ast_pat ctx pat, to_ast_exp ctx exp), annot)], ctx + ([DEF_aux (DEF_measure (to_ast_id ctx id, to_ast_pat ctx pat, to_ast_exp ctx exp), annot)], ctx) | P.DEF_loop_measures (id, measures) -> - [DEF_aux (DEF_loop_measures (to_ast_id ctx id, List.map (to_ast_loop_measure ctx) measures), annot)], ctx + ([DEF_aux (DEF_loop_measures (to_ast_id ctx id, List.map (to_ast_loop_measure ctx) measures), annot)], ctx) let rec remove_mutrec = function | [] -> [] | P.DEF_aux (P.DEF_internal_mutrec fundefs, _) :: defs -> - List.map (fun (P.FD_aux (_, l) as fdef) -> P.DEF_aux (P.DEF_fundef fdef, l)) fundefs @ remove_mutrec defs - | def :: defs -> - def :: remove_mutrec defs + List.map (fun (P.FD_aux (_, l) as fdef) -> P.DEF_aux (P.DEF_fundef fdef, l)) fundefs @ remove_mutrec defs + | def :: defs -> def :: remove_mutrec defs let to_ast ctx (P.Defs files) = let to_ast_defs ctx (_, defs) = let defs = remove_mutrec defs in let defs, ctx = - List.fold_left (fun (defs, ctx) def -> let def, ctx = to_ast_def None [] ctx def in (def @ defs, ctx)) ([], ctx) defs + List.fold_left + (fun (defs, ctx) def -> + let def, ctx = to_ast_def None [] ctx def in + (def @ defs, ctx) + ) + ([], ctx) defs in - List.rev defs, ctx + (List.rev defs, ctx) in let wrap_file file defs = - [mk_def (DEF_pragma ("file_start", file, P.Unknown))] - @ defs - @ [mk_def (DEF_pragma ("file_end", file, P.Unknown))] + [mk_def (DEF_pragma ("file_start", file, P.Unknown))] @ defs @ [mk_def (DEF_pragma ("file_end", file, P.Unknown))] in let defs, ctx = - List.fold_left (fun (defs, ctx) file -> - let defs', ctx = to_ast_defs ctx file in (defs @ wrap_file (fst file) defs', ctx) - ) ([], ctx) files + List.fold_left + (fun (defs, ctx) file -> + let defs', ctx = to_ast_defs ctx file in + (defs @ wrap_file (fst file) defs', ctx) + ) + ([], ctx) files in - { defs = defs; comments = [] }, ctx + ({ defs; comments = [] }, ctx) -let initial_ctx = { +let initial_ctx = + { type_constructors = - List.fold_left (fun m (k, v) -> Bindings.add (mk_id k) v m) Bindings.empty - [ ("bool", []); + List.fold_left + (fun m (k, v) -> Bindings.add (mk_id k) v m) + Bindings.empty + [ + ("bool", []); ("nat", []); ("int", []); ("unit", []); @@ -1024,39 +1088,30 @@ let exp_of_string str = try let exp = Parser.exp_eof Lexer.token (Lexing.from_string str) in to_ast_exp initial_ctx exp - with - | Parser.Error -> - Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) + with Parser.Error -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) let typschm_of_string str = try let typschm = Parser.typschm_eof Lexer.token (Lexing.from_string str) in let typschm, _ = to_ast_typschm initial_ctx typschm in typschm - with - | Parser.Error -> - Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) + with Parser.Error -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) let typ_of_string str = try let typ = Parser.typ_eof Lexer.token (Lexing.from_string str) in let typ = to_ast_typ initial_ctx typ in typ - with - | Parser.Error -> - Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) + with Parser.Error -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) let constraint_of_string str = try let atyp = Parser.typ_eof Lexer.token (Lexing.from_string str) in to_ast_constraint initial_ctx atyp - with - | Parser.Error -> - Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) - + with Parser.Error -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Failed to parse " ^ str) + let extern_of_string ?(pure = false) id str = - VS_val_spec (typschm_of_string str, id, Some { pure = pure; bindings = [("_", string_of_id id)] }, false) - |> mk_val_spec + VS_val_spec (typschm_of_string str, id, Some { pure; bindings = [("_", string_of_id id)] }, false) |> mk_val_spec let val_spec_of_string id str = mk_val_spec (VS_val_spec (typschm_of_string str, id, None, false)) @@ -1074,119 +1129,134 @@ let quant_item_arg = function | _ -> [] let undefined_typschm id typq = let qis = quant_items typq in - if qis = [] then - mk_typschm typq (function_typ [unit_typ] (mk_typ (Typ_id id))) - else + if qis = [] then mk_typschm typq (function_typ [unit_typ] (mk_typ (Typ_id id))) + else ( let arg_typs = List.concat (List.map quant_item_typ qis) in let ret_typ = app_typ id (List.concat (List.map quant_item_arg qis)) in mk_typschm typq (function_typ arg_typs ret_typ) + ) let have_undefined_builtins = ref false let undefined_builtin_val_specs = - [extern_of_string (mk_id "internal_pick") "forall ('a:Type). list('a) -> 'a"; - extern_of_string (mk_id "undefined_bool") "unit -> bool"; - extern_of_string (mk_id "undefined_bit") "unit -> bit"; - extern_of_string (mk_id "undefined_int") "unit -> int"; - extern_of_string (mk_id "undefined_nat") "unit -> nat"; - extern_of_string (mk_id "undefined_real") "unit -> real"; - extern_of_string (mk_id "undefined_string") "unit -> string"; - extern_of_string (mk_id "undefined_list") "forall ('a:Type). 'a -> list('a)"; - extern_of_string (mk_id "undefined_range") "forall 'n 'm. (atom('n), atom('m)) -> range('n,'m)"; - extern_of_string (mk_id "undefined_vector") "forall 'n ('a:Type) ('ord : Order). (atom('n), 'a) -> vector('n, 'ord,'a)"; - extern_of_string (mk_id "undefined_bitvector") "forall 'n. atom('n) -> bitvector('n, dec)"; - extern_of_string (mk_id "undefined_unit") "unit -> unit"] + [ + extern_of_string (mk_id "internal_pick") "forall ('a:Type). list('a) -> 'a"; + extern_of_string (mk_id "undefined_bool") "unit -> bool"; + extern_of_string (mk_id "undefined_bit") "unit -> bit"; + extern_of_string (mk_id "undefined_int") "unit -> int"; + extern_of_string (mk_id "undefined_nat") "unit -> nat"; + extern_of_string (mk_id "undefined_real") "unit -> real"; + extern_of_string (mk_id "undefined_string") "unit -> string"; + extern_of_string (mk_id "undefined_list") "forall ('a:Type). 'a -> list('a)"; + extern_of_string (mk_id "undefined_range") "forall 'n 'm. (atom('n), atom('m)) -> range('n,'m)"; + extern_of_string (mk_id "undefined_vector") + "forall 'n ('a:Type) ('ord : Order). (atom('n), 'a) -> vector('n, 'ord,'a)"; + extern_of_string (mk_id "undefined_bitvector") "forall 'n. atom('n) -> bitvector('n, dec)"; + extern_of_string (mk_id "undefined_unit") "unit -> unit"; + ] let generate_undefineds vs_ids defs = let undefined_builtins = - if !have_undefined_builtins then - [] - else - begin - have_undefined_builtins := true; - List.filter - (fun def -> IdSet.is_empty (IdSet.inter vs_ids (ids_of_def def))) - undefined_builtin_val_specs - end + if !have_undefined_builtins then [] + else begin + have_undefined_builtins := true; + List.filter (fun def -> IdSet.is_empty (IdSet.inter vs_ids (ids_of_def def))) undefined_builtin_val_specs + end in let undefined_tu = function | Tu_aux (Tu_ty_id (Typ_aux (Typ_tuple typs, _), id), _) -> - mk_exp (E_app (id, List.map (fun typ -> mk_exp (E_typ (typ, mk_lit_exp L_undef))) typs)) + mk_exp (E_app (id, List.map (fun typ -> mk_exp (E_typ (typ, mk_lit_exp L_undef))) typs)) | Tu_aux (Tu_ty_id (typ, id), _) -> mk_exp (E_app (id, [mk_exp (E_typ (typ, mk_lit_exp L_undef))])) in - let p_tup = function - | [pat] -> pat - | pats -> mk_pat (P_tuple pats) - in + let p_tup = function [pat] -> pat | pats -> mk_pat (P_tuple pats) in let undefined_union id typq tus = - let pat = p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id))) in + let pat = + p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id))) + in let body = - if !opt_fast_undefined && List.length tus > 0 then - undefined_tu (List.hd tus) - else + if !opt_fast_undefined && List.length tus > 0 then undefined_tu (List.hd tus) + else ( (* Deduplicate arguments for each constructor to keep definitions manageable. *) let extract_tu = function | Tu_aux (Tu_ty_id (Typ_aux (Typ_tuple typs, _), id), _) -> (id, typs) | Tu_aux (Tu_ty_id (typ, id), _) -> (id, [typ]) in - let record_arg_typs m (_,typs) = + let record_arg_typs m (_, typs) = let m' = - List.fold_left (fun m typ -> - TypMap.add typ (1 + try TypMap.find typ m with Not_found -> 0) m) TypMap.empty typs in - TypMap.merge (fun _ x y -> match x,y with Some m, Some n -> Some (max m n) - | None, x -> x - | x, None -> x) m m' + List.fold_left + (fun m typ -> TypMap.add typ (1 + try TypMap.find typ m with Not_found -> 0) m) + TypMap.empty typs + in + TypMap.merge + (fun _ x y -> match (x, y) with Some m, Some n -> Some (max m n) | None, x -> x | x, None -> x) + m m' in - let make_undef_var typ n (i,lbs,m) = - let j = i+n in + let make_undef_var typ n (i, lbs, m) = + let j = i + n in let rec aux k = - if k = j then [] else + if k = j then [] + else ( let v = mk_id ("u_" ^ string_of_int k) in - (mk_letbind (mk_pat (P_typ (typ,mk_pat (P_id v)))) (mk_lit_exp L_undef)):: - (aux (k+1)) + mk_letbind (mk_pat (P_typ (typ, mk_pat (P_id v)))) (mk_lit_exp L_undef) :: aux (k + 1) + ) in (j, aux i @ lbs, TypMap.add typ i m) in - let make_constr m (id,typs) = - let args, _ = List.fold_right (fun typ (acc,m) -> - let i = TypMap.find typ m in - (mk_exp (E_id (mk_id ("u_" ^ string_of_int i)))::acc, - TypMap.add typ (i+1) m)) typs ([],m) in + let make_constr m (id, typs) = + let args, _ = + List.fold_right + (fun typ (acc, m) -> + let i = TypMap.find typ m in + (mk_exp (E_id (mk_id ("u_" ^ string_of_int i))) :: acc, TypMap.add typ (i + 1) m) + ) + typs ([], m) + in mk_exp (E_app (id, args)) in let constr_args = List.map extract_tu tus in let typs_needed = List.fold_left record_arg_typs TypMap.empty constr_args in - let (_,letbinds,typ_to_var) = TypMap.fold make_undef_var typs_needed (0,[],TypMap.empty) in - List.fold_left (fun e lb -> mk_exp (E_let (lb,e))) - (mk_exp (E_app (mk_id "internal_pick", - [mk_exp (E_list (List.map (make_constr typ_to_var) constr_args))]))) letbinds + let _, letbinds, typ_to_var = TypMap.fold make_undef_var typs_needed (0, [], TypMap.empty) in + List.fold_left + (fun e lb -> mk_exp (E_let (lb, e))) + (mk_exp (E_app (mk_id "internal_pick", [mk_exp (E_list (List.map (make_constr typ_to_var) constr_args))]))) + letbinds + ) in - (mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None, false)), - mk_fundef [mk_funcl (prepend_id "undefined_" id) - pat - body]) + ( mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None, false)), + mk_fundef [mk_funcl (prepend_id "undefined_" id) pat body] + ) in let undefined_td = function | TD_enum (id, ids, _) when not (IdSet.mem (prepend_id "undefined_" id) vs_ids) -> - let typschm = typschm_of_string ("unit -> " ^ string_of_id id) in - [mk_val_spec (VS_val_spec (typschm, prepend_id "undefined_" id, None, false)); - mk_fundef [mk_funcl (prepend_id "undefined_" id) - (mk_pat (P_lit (mk_lit L_unit))) - (if !opt_fast_undefined && List.length ids > 0 then - mk_exp (E_id (List.hd ids)) - else - mk_exp (E_app (mk_id "internal_pick", - [mk_exp (E_list (List.map (fun id -> mk_exp (E_id id)) ids))])))]] + let typschm = typschm_of_string ("unit -> " ^ string_of_id id) in + [ + mk_val_spec (VS_val_spec (typschm, prepend_id "undefined_" id, None, false)); + mk_fundef + [ + mk_funcl (prepend_id "undefined_" id) + (mk_pat (P_lit (mk_lit L_unit))) + ( if !opt_fast_undefined && List.length ids > 0 then mk_exp (E_id (List.hd ids)) + else + mk_exp (E_app (mk_id "internal_pick", [mk_exp (E_list (List.map (fun id -> mk_exp (E_id id)) ids))])) + ); + ]; + ] | TD_record (id, typq, fields, _) when not (IdSet.mem (prepend_id "undefined_" id) vs_ids) -> - let pat = p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id))) in - [mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None, false)); - mk_fundef [mk_funcl (prepend_id "undefined_" id) - pat - (mk_exp (E_struct (List.map (fun (_, id) -> mk_fexp id (mk_lit_exp L_undef)) fields)))]] + let pat = + p_tup (quant_items typq |> List.map quant_item_param |> List.concat |> List.map (fun id -> mk_pat (P_id id))) + in + [ + mk_val_spec (VS_val_spec (undefined_typschm id typq, prepend_id "undefined_" id, None, false)); + mk_fundef + [ + mk_funcl (prepend_id "undefined_" id) pat + (mk_exp (E_struct (List.map (fun (_, id) -> mk_fexp id (mk_lit_exp L_undef)) fields))); + ]; + ] | TD_variant (id, typq, tus, _) when not (IdSet.mem (prepend_id "undefined_" id) vs_ids) -> - let vs, def = undefined_union id typq tus in - [vs; def] + let vs, def = undefined_union id typq tus in + [vs; def] | _ -> [] in let undefined_scattered id typq = @@ -1194,20 +1264,19 @@ let generate_undefineds vs_ids defs = undefined_union id typq tus in let rec undefined_defs = function - | DEF_aux (DEF_type (TD_aux (td_aux, _)), _) as def :: defs -> - def :: undefined_td td_aux @ undefined_defs defs + | (DEF_aux (DEF_type (TD_aux (td_aux, _)), _) as def) :: defs -> (def :: undefined_td td_aux) @ undefined_defs defs (* The function definition must come after the scattered type definition is complete, so put it at the end. *) - | DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), _)), _) as def :: defs -> - let vs, fn = undefined_scattered id typq in - def :: vs :: undefined_defs defs @ [fn] - | def :: defs -> - def :: undefined_defs defs + | (DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), _)), _) as def) :: defs -> + let vs, fn = undefined_scattered id typq in + (def :: vs :: undefined_defs defs) @ [fn] + | def :: defs -> def :: undefined_defs defs | [] -> [] in undefined_builtins @ undefined_defs defs let rec get_uninitialized_registers = function - | DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, None), _)), _) :: defs -> (typ, id) :: get_uninitialized_registers defs + | DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, None), _)), _) :: defs -> + (typ, id) :: get_uninitialized_registers defs | _ :: defs -> get_uninitialized_registers defs | [] -> [] @@ -1216,70 +1285,76 @@ let generate_initialize_registers vs_ids defs = let initialize_registers = if IdSet.mem (mk_id "initialize_registers") vs_ids then [] else if regs = [] then - [val_spec_of_string (mk_id "initialize_registers") "unit -> unit"; - mk_fundef [mk_funcl (mk_id "initialize_registers") - (mk_pat (P_lit (mk_lit L_unit))) - (mk_exp (E_lit (mk_lit L_unit)))]] + [ + val_spec_of_string (mk_id "initialize_registers") "unit -> unit"; + mk_fundef + [mk_funcl (mk_id "initialize_registers") (mk_pat (P_lit (mk_lit L_unit))) (mk_exp (E_lit (mk_lit L_unit)))]; + ] else - [val_spec_of_string (mk_id "initialize_registers") "unit -> unit"; - mk_fundef [mk_funcl (mk_id "initialize_registers") - (mk_pat (P_lit (mk_lit L_unit))) - (mk_exp (E_block (List.map (fun (typ, id) -> mk_exp (E_assign (mk_lexp (LE_id id), mk_lit_exp L_undef))) regs)))]] + [ + val_spec_of_string (mk_id "initialize_registers") "unit -> unit"; + mk_fundef + [ + mk_funcl (mk_id "initialize_registers") + (mk_pat (P_lit (mk_lit L_unit))) + (mk_exp + (E_block (List.map (fun (typ, id) -> mk_exp (E_assign (mk_lexp (LE_id id), mk_lit_exp L_undef))) regs)) + ); + ]; + ] in defs @ initialize_registers let generate_enum_functions vs_ids defs = let rec gen_enums = function - | DEF_aux (DEF_type (TD_aux (TD_enum (id, elems, _), _)), _) as enum :: defs -> - let enum_val_spec name quants typ = - mk_val_spec (VS_val_spec (mk_typschm (mk_typquant quants) typ, name, None, !opt_enum_casts)) - in - let range_constraint kid = nc_and (nc_lteq (nint 0) (nvar kid)) (nc_lteq (nvar kid) (nint (List.length elems - 1))) in - - (* Create a function that converts a number to an enum. *) - let to_enum = - let kid = mk_kid "e" in - let name = append_id id "_of_num" in - let pexp n id = - let pat = - if n = List.length elems - 1 then - mk_pat (P_wild) - else - mk_pat (P_lit (mk_lit (L_num (Big_int.of_int n)))) - in - mk_pexp (Pat_exp (pat, mk_exp (E_id id))) - in - let funcl = - mk_funcl name - (mk_pat (P_id (mk_id "arg#"))) - (mk_exp (E_match (mk_exp (E_id (mk_id "arg#")), List.mapi pexp elems))) - in - if IdSet.mem name vs_ids then [] - else - [ enum_val_spec name - [mk_qi_id K_int kid; mk_qi_nc (range_constraint kid)] - (function_typ [atom_typ (nvar kid)] (mk_typ (Typ_id id))); - mk_fundef [funcl] ] - in - - (* Create a function that converts from an enum to a number. *) - let from_enum = - let kid = mk_kid "e" in - let to_typ = mk_typ (Typ_exist ([mk_kopt K_int kid], range_constraint kid, atom_typ (nvar kid))) in - let name = prepend_id "num_of_" id in - let pexp n id = mk_pexp (Pat_exp (mk_pat (P_id id), mk_lit_exp (L_num (Big_int.of_int n)))) in - let funcl = - mk_funcl name - (mk_pat (P_id (mk_id "arg#"))) - (mk_exp (E_match (mk_exp (E_id (mk_id "arg#")), List.mapi pexp elems))) - in - if IdSet.mem name vs_ids then [] - else - [ enum_val_spec name [] (function_typ [mk_typ (Typ_id id)] to_typ); - mk_fundef [funcl] ] - in - enum :: to_enum @ from_enum @ gen_enums defs + | (DEF_aux (DEF_type (TD_aux (TD_enum (id, elems, _), _)), _) as enum) :: defs -> + let enum_val_spec name quants typ = + mk_val_spec (VS_val_spec (mk_typschm (mk_typquant quants) typ, name, None, !opt_enum_casts)) + in + let range_constraint kid = + nc_and (nc_lteq (nint 0) (nvar kid)) (nc_lteq (nvar kid) (nint (List.length elems - 1))) + in + + (* Create a function that converts a number to an enum. *) + let to_enum = + let kid = mk_kid "e" in + let name = append_id id "_of_num" in + let pexp n id = + let pat = + if n = List.length elems - 1 then mk_pat P_wild else mk_pat (P_lit (mk_lit (L_num (Big_int.of_int n)))) + in + mk_pexp (Pat_exp (pat, mk_exp (E_id id))) + in + let funcl = + mk_funcl name + (mk_pat (P_id (mk_id "arg#"))) + (mk_exp (E_match (mk_exp (E_id (mk_id "arg#")), List.mapi pexp elems))) + in + if IdSet.mem name vs_ids then [] + else + [ + enum_val_spec name + [mk_qi_id K_int kid; mk_qi_nc (range_constraint kid)] + (function_typ [atom_typ (nvar kid)] (mk_typ (Typ_id id))); + mk_fundef [funcl]; + ] + in + (* Create a function that converts from an enum to a number. *) + let from_enum = + let kid = mk_kid "e" in + let to_typ = mk_typ (Typ_exist ([mk_kopt K_int kid], range_constraint kid, atom_typ (nvar kid))) in + let name = prepend_id "num_of_" id in + let pexp n id = mk_pexp (Pat_exp (mk_pat (P_id id), mk_lit_exp (L_num (Big_int.of_int n)))) in + let funcl = + mk_funcl name + (mk_pat (P_id (mk_id "arg#"))) + (mk_exp (E_match (mk_exp (E_id (mk_id "arg#")), List.mapi pexp elems))) + in + if IdSet.mem name vs_ids then [] + else [enum_val_spec name [] (function_typ [mk_typ (Typ_id id)] to_typ); mk_fundef [funcl]] + in + (enum :: to_enum) @ from_enum @ gen_enums defs | def :: defs -> def :: gen_enums defs | [] -> [] in @@ -1287,21 +1362,18 @@ let generate_enum_functions vs_ids defs = let incremental_ctx = ref initial_ctx -let process_ast ?generate:(generate=true) ast = +let process_ast ?(generate = true) ast = let ast, ctx = to_ast !incremental_ctx ast in incremental_ctx := ctx; let vs_ids = val_spec_ids ast.defs in - if not !opt_undefined_gen && generate then - { ast with defs = generate_enum_functions vs_ids ast.defs } + if (not !opt_undefined_gen) && generate then { ast with defs = generate_enum_functions vs_ids ast.defs } else if generate then - { ast with - defs = ast.defs - |> generate_undefineds vs_ids - |> generate_enum_functions vs_ids - |> generate_initialize_registers vs_ids + { + ast with + defs = + ast.defs |> generate_undefineds vs_ids |> generate_enum_functions vs_ids |> generate_initialize_registers vs_ids; } - else - ast + else ast let ast_of_def_string_with ocaml_pos f str = let lexbuf = Lexing.from_string str in @@ -1320,13 +1392,10 @@ let defs_of_string ocaml_pos str = (ast_of_def_string ocaml_pos str).defs let get_lexbuf f = let in_chan = open_in f in let lexbuf = Lexing.from_channel in_chan in - lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = f; - Lexing.pos_lnum = 1; - Lexing.pos_bol = 0; - Lexing.pos_cnum = 0; }; - lexbuf, in_chan + lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = f; Lexing.pos_lnum = 1; Lexing.pos_bol = 0; Lexing.pos_cnum = 0 }; + (lexbuf, in_chan) -let parse_file ?loc:(l=Parse_ast.Unknown) (f : string) : (Lexer.comment list * Parse_ast.def list) = +let parse_file ?loc:(l = Parse_ast.Unknown) (f : string) : Lexer.comment list * Parse_ast.def list = try let lexbuf, in_chan = get_lexbuf f in begin @@ -1335,31 +1404,25 @@ let parse_file ?loc:(l=Parse_ast.Unknown) (f : string) : (Lexer.comment list * P let defs = Parser.file Lexer.token lexbuf in close_in in_chan; (!Lexer.comments, defs) - with - | Parser.Error -> - let pos = Lexing.lexeme_start_p lexbuf in - let tok = Lexing.lexeme lexbuf in - raise (Reporting.err_syntax pos ("current token: " ^ tok)) + with Parser.Error -> + let pos = Lexing.lexeme_start_p lexbuf in + let tok = Lexing.lexeme lexbuf in + raise (Reporting.err_syntax pos ("current token: " ^ tok)) end - with - | Sys_error err -> raise (Reporting.err_general l err) + with Sys_error err -> raise (Reporting.err_general l err) let get_lexbuf_from_string f s = let lexbuf = Lexing.from_string s in - lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = f; - Lexing.pos_lnum = 1; - Lexing.pos_bol = 0; - Lexing.pos_cnum = 0; }; + lexbuf.Lexing.lex_curr_p <- { Lexing.pos_fname = f; Lexing.pos_lnum = 1; Lexing.pos_bol = 0; Lexing.pos_cnum = 0 }; lexbuf - + let parse_file_from_string ~filename:f ~contents:s = let lexbuf = get_lexbuf_from_string f s in try Lexer.comments := []; let defs = Parser.file Lexer.token lexbuf in (!Lexer.comments, defs) - with - | Parser.Error -> - let pos = Lexing.lexeme_start_p lexbuf in - let tok = Lexing.lexeme lexbuf in - raise (Reporting.err_syntax pos ("current token: " ^ tok)) + with Parser.Error -> + let pos = Lexing.lexeme_start_p lexbuf in + let tok = Lexing.lexeme lexbuf in + raise (Reporting.err_syntax pos ("current token: " ^ tok)) diff --git a/src/lib/initial_check.mli b/src/lib/initial_check.mli index d65d0d46c..450970271 100644 --- a/src/lib/initial_check.mli +++ b/src/lib/initial_check.mli @@ -114,7 +114,7 @@ val undefined_builtin_val_specs : uannot def list val generate_undefineds : IdSet.t -> uannot def list -> uannot def list val generate_enum_functions : IdSet.t -> uannot def list -> uannot def list - + (** If the generate flag is false, then we won't generate any auxilliary definitions, like the initialize_registers function *) val process_ast : ?generate:bool -> Parse_ast.defs -> uannot ast @@ -123,15 +123,16 @@ val process_ast : ?generate:bool -> Parse_ast.defs -> uannot ast val extern_of_string : ?pure:bool -> id -> string -> uannot def val val_spec_of_string : id -> string -> uannot def -val defs_of_string : (string * int * int * int) -> string -> uannot def list -val ast_of_def_string : (string * int * int * int) -> string -> uannot ast -val ast_of_def_string_with : (string * int * int * int) -> (Parse_ast.def list -> Parse_ast.def list) -> string -> uannot ast +val defs_of_string : string * int * int * int -> string -> uannot def list +val ast_of_def_string : string * int * int * int -> string -> uannot ast +val ast_of_def_string_with : + string * int * int * int -> (Parse_ast.def list -> Parse_ast.def list) -> string -> uannot ast val exp_of_string : string -> uannot exp val typ_of_string : string -> typ val constraint_of_string : string -> n_constraint (** {2 Parsing files } *) - + (** Parse a file into a sequence of comments and a parse AST @param ?loc If we get an error reading the file, report the error at this location *) diff --git a/src/lib/interactive.ml b/src/lib/interactive.ml index 784106a44..a75e28837 100644 --- a/src/lib/interactive.ml +++ b/src/lib/interactive.ml @@ -73,24 +73,18 @@ open Printf let opt_interactive = ref false type istate = { - ast : Type_check.tannot ast; - effect_info : Effects.side_effect_info; - env : Type_check.Env.t; - default_sail_dir : string; - } + ast : Type_check.tannot ast; + effect_info : Effects.side_effect_info; + env : Type_check.Env.t; + default_sail_dir : string; +} -let initial_istate default_sail_dir = { - ast = empty_ast; - effect_info = Effects.empty_side_effect_info; - env = Type_check.initial_env; - default_sail_dir = default_sail_dir; - } - -let arg str = - ("<" ^ str ^ ">") |> Util.yellow |> Util.clear +let initial_istate default_sail_dir = + { ast = empty_ast; effect_info = Effects.empty_side_effect_info; env = Type_check.initial_env; default_sail_dir } -let command str = - str |> Util.green |> Util.clear +let arg str = "<" ^ str ^ ">" |> Util.yellow |> Util.clear + +let command str = str |> Util.green |> Util.clear type action = | ArgString of string * (string -> action) @@ -103,7 +97,7 @@ let commands = ref [] let get_command cmd = List.assoc_opt cmd !commands let all_commands () = !commands - + let reflect_typ action = let open Type_check in let rec arg_typs = function @@ -112,9 +106,7 @@ let reflect_typ action = | Action _ -> [] | ActionUnit _ -> [] in - match action with - | Action _ -> function_typ [unit_typ] unit_typ - | _ -> function_typ (arg_typs action) unit_typ + match action with Action _ -> function_typ [unit_typ] unit_typ | _ -> function_typ (arg_typs action) unit_typ let generate_help name help action = let rec args = function @@ -124,47 +116,44 @@ let generate_help name help action = | ActionUnit _ -> [] in let args = args action in - let help = match String.split_on_char ':' help with + let help = + match String.split_on_char ':' help with | [] -> assert false - | (prefix :: splits) -> - List.map (fun split -> - match String.split_on_char ' ' split with - | [] -> assert false - | (subst :: rest) -> - if Str.string_match (Str.regexp "^[0-9]+") subst 0 then - let num_str = Str.matched_string subst in - let num_end = Str.match_end () in - let punct = String.sub subst num_end (String.length subst - num_end) in - List.nth args (int_of_string num_str) ^ punct ^ " " ^ String.concat " " rest - else - command (":" ^ subst) ^ " " ^ String.concat " " rest - ) splits - |> String.concat "" - |> (fun rest -> prefix ^ rest) + | prefix :: splits -> + List.map + (fun split -> + match String.split_on_char ' ' split with + | [] -> assert false + | subst :: rest -> + if Str.string_match (Str.regexp "^[0-9]+") subst 0 then ( + let num_str = Str.matched_string subst in + let num_end = Str.match_end () in + let punct = String.sub subst num_end (String.length subst - num_end) in + List.nth args (int_of_string num_str) ^ punct ^ " " ^ String.concat " " rest + ) + else command (":" ^ subst) ^ " " ^ String.concat " " rest + ) + splits + |> String.concat "" + |> fun rest -> prefix ^ rest in sprintf "%s %s - %s" Util.(name |> green |> clear) (String.concat ", " args) help let run_action istate cmd argument action = let args = String.split_on_char ',' argument in let rec call args action = - match args, action with - | (x :: xs), ArgString (hint, next) -> - call xs (next (String.trim x)) - | (x :: xs), ArgInt (hint, next) -> - let x = String.trim x in - if Str.string_match (Str.regexp "^[0-9]+$") x 0 then - call xs (next (int_of_string x)) - else - failwith (sprintf "%s argument %s must be an non-negative integer" (command cmd) (arg hint)) - | _, Action act -> - act istate + match (args, action) with + | x :: xs, ArgString (hint, next) -> call xs (next (String.trim x)) + | x :: xs, ArgInt (hint, next) -> + let x = String.trim x in + if Str.string_match (Str.regexp "^[0-9]+$") x 0 then call xs (next (int_of_string x)) + else failwith (sprintf "%s argument %s must be an non-negative integer" (command cmd) (arg hint)) + | _, Action act -> act istate | _, ActionUnit act -> - act istate; - istate - | _, _ -> - failwith (sprintf "Bad arguments for %s, see (%s %s)" (command cmd) (command ":help") (command cmd)) + act istate; + istate + | _, _ -> failwith (sprintf "Bad arguments for %s, see (%s %s)" (command cmd) (command ":help") (command cmd)) in call args action - -let register_command ~name:name ~help:help action = - commands := (":" ^ name, (help, action)) :: !commands + +let register_command ~name ~help action = commands := (":" ^ name, (help, action)) :: !commands diff --git a/src/lib/interactive.mli b/src/lib/interactive.mli index 40bbad24f..02226f45f 100644 --- a/src/lib/interactive.mli +++ b/src/lib/interactive.mli @@ -75,11 +75,11 @@ val opt_interactive : bool ref abstract syntax tree, effect info and the type-checking environment. Also contains the default Sail directory *) type istate = { - ast : Type_check.tannot ast; - effect_info : Effects.side_effect_info; - env : Type_check.Env.t; - default_sail_dir : string; - } + ast : Type_check.tannot ast; + effect_info : Effects.side_effect_info; + env : Type_check.Env.t; + default_sail_dir : string; +} val initial_istate : string -> istate @@ -97,7 +97,7 @@ val reflect_typ : action -> typ val get_command : string -> (string * action) option val all_commands : unit -> (string * (string * action)) list - + val generate_help : string -> string -> action -> string val run_action : istate -> string -> string -> action -> istate diff --git a/src/lib/interpreter.ml b/src/lib/interpreter.ml index f46d296ba..23e0304f7 100644 --- a/src/lib/interpreter.ml +++ b/src/lib/interpreter.ml @@ -70,18 +70,17 @@ open Ast_defs open Ast_util open Value -type gstate = - { registers : value Bindings.t; - allow_registers : bool; (* For some uses we want to forbid touching any registers. *) - primops : (value list -> value) StringMap.t; - letbinds : (Type_check.tannot letbind) list; - fundefs : (Type_check.tannot fundef) Bindings.t; - last_write_ea : (value * value * value) option; - typecheck_env : Type_check.Env.t; - } - -type lstate = - { locals : value Bindings.t } +type gstate = { + registers : value Bindings.t; + allow_registers : bool; (* For some uses we want to forbid touching any registers. *) + primops : (value list -> value) StringMap.t; + letbinds : Type_check.tannot letbind list; + fundefs : Type_check.tannot fundef Bindings.t; + last_write_ea : (value * value * value) option; + typecheck_env : Type_check.Env.t; +} + +type lstate = { locals : value Bindings.t } type state = lstate * gstate @@ -95,58 +94,44 @@ let value_of_lit (L_aux (l_aux, _)) = | L_string str -> V_string str | L_num n -> V_int n | L_hex str -> - Util.string_to_list str - |> List.map (fun c -> List.map (fun b -> V_bit b) (Sail_lib.hex_char c)) - |> List.concat - |> (fun v -> V_vector v) - | L_bin str -> - Util.string_to_list str - |> List.map (fun c -> V_bit (Sail_lib.bin_char c)) - |> (fun v -> V_vector v) - | L_real str -> - begin match Util.split_on_char '.' str with - | [whole; frac] -> - let whole = Rational.of_int (int_of_string whole) in - let frac = Rational.div (Rational.of_int (int_of_string frac)) (Rational.of_int (Util.power 10 (String.length frac))) in - V_real (Rational.add whole frac) - | _ -> failwith "could not parse real literal" - end + Util.string_to_list str |> List.map (fun c -> List.map (fun b -> V_bit b) (Sail_lib.hex_char c)) |> List.concat + |> fun v -> V_vector v + | L_bin str -> Util.string_to_list str |> List.map (fun c -> V_bit (Sail_lib.bin_char c)) |> fun v -> V_vector v + | L_real str -> begin + match Util.split_on_char '.' str with + | [whole; frac] -> + let whole = Rational.of_int (int_of_string whole) in + let frac = + Rational.div (Rational.of_int (int_of_string frac)) (Rational.of_int (Util.power 10 (String.length frac))) + in + V_real (Rational.add whole frac) + | _ -> failwith "could not parse real literal" + end | L_undef -> failwith "value_of_lit of undefined" +let is_value = function E_aux (E_internal_value _, _) -> true | _ -> false -let is_value = function - | (E_aux (E_internal_value _, _)) -> true - | _ -> false - -let is_true = function - | (E_aux (E_internal_value (V_bool b), annot)) -> b - | _ -> false +let is_true = function E_aux (E_internal_value (V_bool b), annot) -> b | _ -> false -let is_false = function - | (E_aux (E_internal_value (V_bool b), _)) -> not b - | _ -> false +let is_false = function E_aux (E_internal_value (V_bool b), _) -> not b | _ -> false -let exp_of_value v = (E_aux (E_internal_value v, (Parse_ast.Unknown, Type_check.empty_tannot))) -let value_of_exp = function - | (E_aux (E_internal_value v, _)) -> v - | _ -> failwith "value_of_exp coerction failed" +let exp_of_value v = E_aux (E_internal_value v, (Parse_ast.Unknown, Type_check.empty_tannot)) +let value_of_exp = function E_aux (E_internal_value v, _) -> v | _ -> failwith "value_of_exp coerction failed" let fallthrough = let open Type_check in try let env = initial_env |> Env.add_scattered_variant (mk_id "exception") (mk_typquant []) in - check_case env exc_typ (mk_pexp (Pat_exp (mk_pat (P_id (mk_id "exn")), mk_exp (E_throw (mk_exp (E_id (mk_id "exn"))))))) unit_typ - with - | Type_error (_, l, err) -> - Reporting.unreachable l __POS__ (Type_error.string_of_type_error err); + check_case env exc_typ + (mk_pexp (Pat_exp (mk_pat (P_id (mk_id "exn")), mk_exp (E_throw (mk_exp (E_id (mk_id "exn"))))))) + unit_typ + with Type_error (_, l, err) -> Reporting.unreachable l __POS__ (Type_error.string_of_type_error err) (**************************************************************************) (* 1. Interpreter Monad *) (**************************************************************************) -type return_value = - | Return_ok of value - | Return_exception of value +type return_value = Return_ok of value | Return_exception of value (* when changing effect arms remember to also update effect_request type below *) type 'a response = @@ -158,18 +143,17 @@ type 'a response = | Read_mem of (* read_kind : *) value * (* address : *) value * (* length : *) value * (value -> 'a) | Write_ea of (* write_kind : *) value * (* address : *) value * (* length : *) value * (unit -> 'a) | Excl_res of (bool -> 'a) - | Write_mem of (* write_kind : *) value * (* address : *) value * (* length : *) value * (* value : *) value * (bool -> 'a) + | Write_mem of + (* write_kind : *) value * (* address : *) value * (* length : *) value * (* value : *) value * (bool -> 'a) | Barrier of (* barrier_kind : *) value * (unit -> 'a) | Read_reg of string * (value -> 'a) | Write_reg of string * value * (unit -> 'a) | Get_primop of string * ((value list -> value) -> 'a) | Get_local of string * (value -> 'a) | Put_local of string * value * (unit -> 'a) - | Get_global_letbinds of ((Type_check.tannot letbind) list -> 'a) + | Get_global_letbinds of (Type_check.tannot letbind list -> 'a) -and 'a monad = - | Pure of 'a - | Yield of ('a monad response) +and 'a monad = Pure of 'a | Yield of 'a monad response let map_response f = function | Early_return v -> Early_return v @@ -189,24 +173,17 @@ let map_response f = function | Put_local (name, v, cont) -> Put_local (name, v, fun () -> f (cont ())) | Get_global_letbinds cont -> Get_global_letbinds (fun lbs -> f (cont lbs)) -let rec liftM f = function - | Pure x -> Pure (f x) - | Yield g -> Yield (map_response (liftM f) g) +let rec liftM f = function Pure x -> Pure (f x) | Yield g -> Yield (map_response (liftM f) g) let return x = Pure x -let rec bind m f = - match m with - | Pure x -> f x - | Yield m -> Yield (map_response (fun m -> bind m f) m) +let rec bind m f = match m with Pure x -> f x | Yield m -> Yield (map_response (fun m -> bind m f) m) let ( >>= ) m f = bind m f let ( >> ) m1 m2 = bind m1 (function () -> m2) -type ('a, 'b) either = - | Left of 'a - | Right of 'b +type ('a, 'b) either = Left of 'a | Right of 'b (* Support for interpreting exceptions *) @@ -218,60 +195,47 @@ let catch m = let throw v = Yield (Exception v) -let call (f : id) (args : value list) : return_value monad = - Yield (Call (f, args, fun v -> Pure v)) +let call (f : id) (args : value list) : return_value monad = Yield (Call (f, args, fun v -> Pure v)) -let read_mem rk addr len : value monad = - Yield (Read_mem (rk, addr, len, (fun v -> Pure v))) +let read_mem rk addr len : value monad = Yield (Read_mem (rk, addr, len, fun v -> Pure v)) -let write_ea wk addr len : unit monad = - Yield (Write_ea (wk, addr, len, (fun () -> Pure ()))) +let write_ea wk addr len : unit monad = Yield (Write_ea (wk, addr, len, fun () -> Pure ())) -let excl_res () : bool monad = - Yield (Excl_res (fun b -> Pure b)) +let excl_res () : bool monad = Yield (Excl_res (fun b -> Pure b)) -let write_mem wk addr len v : bool monad = - Yield (Write_mem (wk, addr, len, v, fun b -> Pure b)) +let write_mem wk addr len v : bool monad = Yield (Write_mem (wk, addr, len, v, fun b -> Pure b)) -let barrier bk : unit monad = - Yield (Barrier (bk, fun () -> Pure ())) +let barrier bk : unit monad = Yield (Barrier (bk, fun () -> Pure ())) -let read_reg name : value monad = - Yield (Read_reg (name, fun v -> Pure v)) +let read_reg name : value monad = Yield (Read_reg (name, fun v -> Pure v)) -let write_reg name v : unit monad = - Yield (Write_reg (name, v, fun () -> Pure ())) +let write_reg name v : unit monad = Yield (Write_reg (name, v, fun () -> Pure ())) -let fail s = - Yield (Fail s) +let fail s = Yield (Fail s) -let get_primop name : (value list -> value) monad = - Yield (Get_primop (name, fun op -> Pure op)) +let get_primop name : (value list -> value) monad = Yield (Get_primop (name, fun op -> Pure op)) -let get_local name : value monad = - Yield (Get_local (name, fun v -> Pure v)) +let get_local name : value monad = Yield (Get_local (name, fun v -> Pure v)) -let put_local name v : unit monad = - Yield (Put_local (name, v, fun () -> Pure ())) +let put_local name v : unit monad = Yield (Put_local (name, v, fun () -> Pure ())) -let get_global_letbinds () : (Type_check.tannot letbind) list monad = - Yield (Get_global_letbinds (fun lbs -> Pure lbs)) +let get_global_letbinds () : Type_check.tannot letbind list monad = Yield (Get_global_letbinds (fun lbs -> Pure lbs)) let early_return v = Yield (Early_return v) let assertion_failed msg = Yield (Assertion_failed msg) -let liftM2 f m1 m2 = m1 >>= fun x -> m2 >>= fun y -> return (f x y) +let liftM2 f m1 m2 = + m1 >>= fun x -> + m2 >>= fun y -> return (f x y) let letbind_pat_ids (LB_aux (LB_val (pat, _), _)) = pat_ids pat let subst id value exp = Ast_util.subst id (exp_of_value value) exp let local_variable id lstate gstate = - try - Bindings.find id lstate.locals |> exp_of_value - with - | Not_found -> failwith ("Could not find local variable " ^ string_of_id id) + try Bindings.find id lstate.locals |> exp_of_value + with Not_found -> failwith ("Could not find local variable " ^ string_of_id id) (**************************************************************************) (* 2. Expression Evaluation *) @@ -285,9 +249,9 @@ let value_of_fexp (FE_aux (FE_fexp (id, exp), _)) = (string_of_id id, value_of_e let rec build_letchain id lbs (E_aux (_, annot) as exp) = match lbs with | [] -> exp - | lb :: lbs when IdSet.mem id (letbind_pat_ids lb)-> - let exp = E_aux (E_let (lb, exp), annot) in - build_letchain id lbs exp + | lb :: lbs when IdSet.mem id (letbind_pat_ids lb) -> + let exp = E_aux (E_let (lb, exp), annot) in + build_letchain id lbs exp | _ :: lbs -> build_letchain id lbs exp let is_interpreter_extern id env = @@ -298,38 +262,34 @@ let get_interpreter_extern id env = let open Type_check in Env.get_extern id env "interpreter" -type partial_binding = - | Complete_binding of value - | Partial_binding of (value * Big_int.num * Big_int.num) list +type partial_binding = Complete_binding of value | Partial_binding of (value * Big_int.num * Big_int.num) list let combine _ v1 v2 = match (v1, v2) with | None, None -> None | Some v1, None -> Some v1 | None, Some v2 -> Some v2 - | Some (Partial_binding p1), Some (Partial_binding p2) -> - Some (Partial_binding (p1 @ p2)) - | Some (Complete_binding _), Some (Complete_binding _) -> - failwith "Tried to bind same identifier twice!" - | Some _, Some _ -> - failwith "Tried to mix partial and complete binding!" + | Some (Partial_binding p1), Some (Partial_binding p2) -> Some (Partial_binding (p1 @ p2)) + | Some (Complete_binding _), Some (Complete_binding _) -> failwith "Tried to bind same identifier twice!" + | Some _, Some _ -> failwith "Tried to mix partial and complete binding!" let complete_bindings = Bindings.map (function - | Complete_binding v -> v - | Partial_binding ((v1, n1, m1) :: partial_values) -> - let max, min = - List.fold_left (fun (max, min) (_, n, m) -> - (Big_int.max max (Big_int.max n m), - Big_int.min min (Big_int.min n m)) - ) (n1, m1) partial_values in - let len = Big_int.sub (Big_int.succ max) min in - List.fold_left (fun bv (slice, n, m) -> - prerr_endline (string_of_value slice); - value_update_subrange [bv; V_int n; V_int m; slice] - ) (value_zeros [V_int len]) ((v1, n1, m1) :: partial_values) - | Partial_binding [] -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Empty partial binding set" + | Complete_binding v -> v + | Partial_binding ((v1, n1, m1) :: partial_values) -> + let max, min = + List.fold_left + (fun (max, min) (_, n, m) -> (Big_int.max max (Big_int.max n m), Big_int.min min (Big_int.min n m))) + (n1, m1) partial_values + in + let len = Big_int.sub (Big_int.succ max) min in + List.fold_left + (fun bv (slice, n, m) -> + prerr_endline (string_of_value slice); + value_update_subrange [bv; V_int n; V_int m; slice] + ) + (value_zeros [V_int len]) ((v1, n1, m1) :: partial_values) + | Partial_binding [] -> Reporting.unreachable Parse_ast.Unknown __POS__ "Empty partial binding set" ) let rec step (E_aux (e_aux, annot) as orig_exp) = @@ -337,376 +297,319 @@ let rec step (E_aux (e_aux, annot) as orig_exp) = match e_aux with | E_block [] -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) | E_block [exp] when is_value exp -> return exp - | E_block [E_aux (E_block _, _) as exp] -> return exp + | E_block [(E_aux (E_block _, _) as exp)] -> return exp | E_block (exp :: exps) when is_value exp -> wrap (E_block exps) - | E_block (exp :: exps) -> - step exp >>= fun exp' -> wrap (E_block (exp' :: exps)) - - | E_lit (L_aux (L_undef, _)) -> - begin - let env = Type_check.env_of_annot annot in - let typ = Type_check.typ_of_annot annot in - let undef_exp = Ast_util.undefined_of_typ false Parse_ast.Unknown (fun _ -> empty_uannot) typ in - let undef_exp = Type_check.check_exp env undef_exp typ in - return undef_exp - end - - | E_lit lit -> - begin - try return (exp_of_value (value_of_lit lit)) - with Failure s -> fail ("Failure: " ^ s) - end - + | E_block (exp :: exps) -> step exp >>= fun exp' -> wrap (E_block (exp' :: exps)) + | E_lit (L_aux (L_undef, _)) -> begin + let env = Type_check.env_of_annot annot in + let typ = Type_check.typ_of_annot annot in + let undef_exp = Ast_util.undefined_of_typ false Parse_ast.Unknown (fun _ -> empty_uannot) typ in + let undef_exp = Type_check.check_exp env undef_exp typ in + return undef_exp + end + | E_lit lit -> begin try return (exp_of_value (value_of_lit lit)) with Failure s -> fail ("Failure: " ^ s) end | E_if (exp, then_exp, else_exp) when is_true exp -> return then_exp | E_if (exp, then_exp, else_exp) when is_false exp -> return else_exp - | E_if (exp, then_exp, else_exp) -> - step exp >>= fun exp' -> wrap (E_if (exp', then_exp, else_exp)) - + | E_if (exp, then_exp, else_exp) -> step exp >>= fun exp' -> wrap (E_if (exp', then_exp, else_exp)) | E_loop (While, _, exp, body) -> wrap (E_if (exp, E_aux (E_block [body; orig_exp], annot), exp_of_value V_unit)) | E_loop (Until, _, exp, body) -> wrap (E_block [body; E_aux (E_if (exp, exp_of_value V_unit, orig_exp), annot)]) - | E_assert (exp, msg) when is_true exp -> wrap unit_exp - | E_assert (exp, msg) when is_false exp && is_value msg -> - assertion_failed (coerce_string (value_of_exp msg)) - | E_assert (exp, msg) when is_false exp -> - step msg >>= fun msg' -> wrap (E_assert (exp, msg')) - | E_assert (exp, msg) -> - step exp >>= fun exp' -> wrap (E_assert (exp', msg)) - + | E_assert (exp, msg) when is_false exp && is_value msg -> assertion_failed (coerce_string (value_of_exp msg)) + | E_assert (exp, msg) when is_false exp -> step msg >>= fun msg' -> wrap (E_assert (exp, msg')) + | E_assert (exp, msg) -> step exp >>= fun exp' -> wrap (E_assert (exp', msg)) | E_vector exps -> - let evaluated, unevaluated = Util.take_drop is_value exps in - begin - match unevaluated with - | exp :: exps -> - step exp >>= fun exp' -> wrap (E_vector (evaluated @ exp' :: exps)) - | [] -> return (exp_of_value (V_vector (List.map value_of_exp evaluated))) - end - + let evaluated, unevaluated = Util.take_drop is_value exps in + begin + match unevaluated with + | exp :: exps -> step exp >>= fun exp' -> wrap (E_vector (evaluated @ (exp' :: exps))) + | [] -> return (exp_of_value (V_vector (List.map value_of_exp evaluated))) + end | E_list exps -> - let evaluated, unevaluated = Util.take_drop is_value exps in - begin - match unevaluated with - | exp :: exps -> - step exp >>= fun exp' -> wrap (E_list (evaluated @ exp' :: exps)) - | [] -> return (exp_of_value (V_list (List.map value_of_exp evaluated))) - end - + let evaluated, unevaluated = Util.take_drop is_value exps in + begin + match unevaluated with + | exp :: exps -> step exp >>= fun exp' -> wrap (E_list (evaluated @ (exp' :: exps))) + | [] -> return (exp_of_value (V_list (List.map value_of_exp evaluated))) + end (* Special rules for short circuting boolean operators *) | E_app (id, [x; y]) when (string_of_id id = "and_bool" || string_of_id id = "or_bool") && not (is_value x) -> - step x >>= fun x' -> wrap (E_app (id, [x'; y])) - | E_app (id, [x; y]) when string_of_id id = "and_bool" && is_false x -> - return (exp_of_value (V_bool false)) - | E_app (id, [x; y]) when string_of_id id = "or_bool" && is_true x -> - return (exp_of_value (V_bool true)) - + step x >>= fun x' -> wrap (E_app (id, [x'; y])) + | E_app (id, [x; y]) when string_of_id id = "and_bool" && is_false x -> return (exp_of_value (V_bool false)) + | E_app (id, [x; y]) when string_of_id id = "or_bool" && is_true x -> return (exp_of_value (V_bool true)) | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) when not (is_value bind) -> - step bind >>= fun bind' -> wrap (E_let (LB_aux (LB_val (pat, bind'), lb_annot), body)) + step bind >>= fun bind' -> wrap (E_let (LB_aux (LB_val (pat, bind'), lb_annot), body)) | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) -> - let matched, bindings = pattern_match (Type_check.env_of orig_exp) pat (value_of_exp bind) in - if matched then - return (List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings (complete_bindings bindings))) - else - fail "Match failure" - - | E_vector_subrange (vec, n, m) -> - wrap (E_app (mk_id "vector_subrange_dec", [vec; n; m])) - | E_vector_access (vec, n) -> - wrap (E_app (mk_id "vector_access_dec", [vec; n])) - - | E_vector_update (vec, n, x) -> - wrap (E_app (mk_id "vector_update", [vec; n; x])) + let matched, bindings = pattern_match (Type_check.env_of orig_exp) pat (value_of_exp bind) in + if matched then + return + (List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings (complete_bindings bindings))) + else fail "Match failure" + | E_vector_subrange (vec, n, m) -> wrap (E_app (mk_id "vector_subrange_dec", [vec; n; m])) + | E_vector_access (vec, n) -> wrap (E_app (mk_id "vector_access_dec", [vec; n])) + | E_vector_update (vec, n, x) -> wrap (E_app (mk_id "vector_update", [vec; n; x])) | E_vector_update_subrange (vec, n, m, x) -> - (* FIXME: Currently not general enough *) - wrap (E_app (mk_id "vector_update_subrange_dec", [vec; n; m; x])) - + (* FIXME: Currently not general enough *) + wrap (E_app (mk_id "vector_update_subrange_dec", [vec; n; m; x])) (* otherwise left-to-right evaluation order for function applications *) | E_app (id, exps) -> - let evaluated, unevaluated = Util.take_drop is_value exps in - begin - let open Type_check in - match unevaluated with - | exp :: exps -> - step exp >>= fun exp' -> wrap (E_app (id, evaluated @ exp' :: exps)) - | [] when Env.is_union_constructor id (env_of_annot annot) -> - return (exp_of_value (V_ctor (string_of_id id, List.map value_of_exp evaluated))) - | [] when is_interpreter_extern id (env_of_annot annot) -> - begin + let evaluated, unevaluated = Util.take_drop is_value exps in + begin + let open Type_check in + match unevaluated with + | exp :: exps -> step exp >>= fun exp' -> wrap (E_app (id, evaluated @ (exp' :: exps))) + | [] when Env.is_union_constructor id (env_of_annot annot) -> + return (exp_of_value (V_ctor (string_of_id id, List.map value_of_exp evaluated))) + | [] when is_interpreter_extern id (env_of_annot annot) -> begin let extern = get_interpreter_extern id (env_of_annot annot) in match extern with | "reg_deref" -> - let regname = List.hd evaluated |> value_of_exp |> coerce_ref in - read_reg regname >>= fun v -> return (exp_of_value v) - | "read_mem" -> - begin match evaluated with - | [rk; addrsize; addr; len] -> - read_mem (value_of_exp rk) (value_of_exp addr) (value_of_exp len) >>= fun v -> return (exp_of_value v) - | _ -> - fail "Wrong number of parameters to read_mem intrinsic" - end - | "write_mem_ea" -> - begin match evaluated with - | [wk; addrsize; addr; len] -> - write_ea (value_of_exp wk) (value_of_exp addr) (value_of_exp len) >> wrap unit_exp - | _ -> - fail "Wrong number of parameters to write_ea intrinsic" - end - | "excl_res" -> - begin match evaluated with - | [_] -> - excl_res () >>= fun b -> return (exp_of_value (V_bool b)) - | _ -> - fail "Wrong number of parameters to excl_res intrinsic" - end - | "write_mem" -> - begin match evaluated with - | [wk; addrsize; addr; len; v] -> - write_mem (value_of_exp wk) (value_of_exp addr) (value_of_exp len) (value_of_exp v) >>= fun b -> return (exp_of_value (V_bool b)) - | _ -> - fail "Wrong number of parameters to write_memv intrinsic" - end - | "barrier" -> - begin match evaluated with - | [bk] -> - barrier (value_of_exp bk) >> wrap unit_exp - | _ -> - fail "Wrong number of parameters to barrier intrinsic" - end - | _ -> - get_primop extern >>= - fun op -> try - return (exp_of_value (op (List.map value_of_exp evaluated))) - with _ as exc -> fail ("Exception calling primop '" ^ extern ^ "': " ^ Printexc.to_string exc) + let regname = List.hd evaluated |> value_of_exp |> coerce_ref in + read_reg regname >>= fun v -> return (exp_of_value v) + | "read_mem" -> begin + match evaluated with + | [rk; addrsize; addr; len] -> + read_mem (value_of_exp rk) (value_of_exp addr) (value_of_exp len) >>= fun v -> + return (exp_of_value v) + | _ -> fail "Wrong number of parameters to read_mem intrinsic" + end + | "write_mem_ea" -> begin + match evaluated with + | [wk; addrsize; addr; len] -> + write_ea (value_of_exp wk) (value_of_exp addr) (value_of_exp len) >> wrap unit_exp + | _ -> fail "Wrong number of parameters to write_ea intrinsic" + end + | "excl_res" -> begin + match evaluated with + | [_] -> excl_res () >>= fun b -> return (exp_of_value (V_bool b)) + | _ -> fail "Wrong number of parameters to excl_res intrinsic" + end + | "write_mem" -> begin + match evaluated with + | [wk; addrsize; addr; len; v] -> + write_mem (value_of_exp wk) (value_of_exp addr) (value_of_exp len) (value_of_exp v) >>= fun b -> + return (exp_of_value (V_bool b)) + | _ -> fail "Wrong number of parameters to write_memv intrinsic" + end + | "barrier" -> begin + match evaluated with + | [bk] -> barrier (value_of_exp bk) >> wrap unit_exp + | _ -> fail "Wrong number of parameters to barrier intrinsic" + end + | _ -> ( + get_primop extern >>= fun op -> + try return (exp_of_value (op (List.map value_of_exp evaluated))) + with _ as exc -> fail ("Exception calling primop '" ^ extern ^ "': " ^ Printexc.to_string exc) + ) end - | [] -> - call id (List.map value_of_exp evaluated) >>= - (function Return_ok v -> return (exp_of_value v) - | Return_exception v -> wrap (E_throw (exp_of_value v))) - end - | E_app_infix (x, id, y) when is_value x && is_value y -> - call id [value_of_exp x; value_of_exp y] >>= - (function Return_ok v -> return (exp_of_value v) - | Return_exception v -> wrap (E_throw (exp_of_value v))) - | E_app_infix (x, id, y) when is_value x -> - step y >>= fun y' -> wrap (E_app_infix (x, id, y')) - | E_app_infix (x, id, y) -> - step x >>= fun x' -> wrap (E_app_infix (x', id, y)) - + | [] -> ( + call id (List.map value_of_exp evaluated) >>= function + | Return_ok v -> return (exp_of_value v) + | Return_exception v -> wrap (E_throw (exp_of_value v)) + ) + end + | E_app_infix (x, id, y) when is_value x && is_value y -> ( + call id [value_of_exp x; value_of_exp y] >>= function + | Return_ok v -> return (exp_of_value v) + | Return_exception v -> wrap (E_throw (exp_of_value v)) + ) + | E_app_infix (x, id, y) when is_value x -> step y >>= fun y' -> wrap (E_app_infix (x, id, y')) + | E_app_infix (x, id, y) -> step x >>= fun x' -> wrap (E_app_infix (x', id, y)) | E_return exp when is_value exp -> early_return (value_of_exp exp) | E_return exp -> step exp >>= fun exp' -> wrap (E_return exp') - | E_tuple exps -> - let evaluated, unevaluated = Util.take_drop is_value exps in - begin - match unevaluated with - | exp :: exps -> - step exp >>= fun exp' -> wrap (E_tuple (evaluated @ exp' :: exps)) - | [] -> return (exp_of_value (tuple_value (List.map value_of_exp exps))) - end - - | E_match (exp, pexps) when not (is_value exp) -> - step exp >>= fun exp' -> wrap (E_match (exp', pexps)) + let evaluated, unevaluated = Util.take_drop is_value exps in + begin + match unevaluated with + | exp :: exps -> step exp >>= fun exp' -> wrap (E_tuple (evaluated @ (exp' :: exps))) + | [] -> return (exp_of_value (tuple_value (List.map value_of_exp exps))) + end + | E_match (exp, pexps) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_match (exp', pexps)) | E_match (_, []) -> fail "Pattern matching failed" - | E_match (exp, Pat_aux (Pat_exp (pat, body), _) :: pexps) -> - begin try - let matched, bindings = pattern_match (Type_check.env_of body) pat (value_of_exp exp) in - if matched then - return (List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings (complete_bindings bindings))) - else - wrap (E_match (exp, pexps)) - with Failure s -> fail ("Failure: " ^ s) - end - | E_match (exp, Pat_aux (Pat_when (pat, guard, body), pat_annot) :: pexps) when not (is_value guard) -> - begin try - let matched, bindings = pattern_match (Type_check.env_of body) pat (value_of_exp exp) in - let bindings = complete_bindings bindings in - if matched then - let guard = List.fold_left (fun guard (id, v) -> subst id v guard) guard (Bindings.bindings bindings) in - let body = List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings bindings) in - step guard >>= fun guard' -> - wrap (E_match (exp, Pat_aux (Pat_when (pat, guard', body), pat_annot) :: pexps)) - else - wrap (E_match (exp, pexps)) - with Failure s -> fail ("Failure: " ^ s) - end + | E_match (exp, Pat_aux (Pat_exp (pat, body), _) :: pexps) -> begin + try + let matched, bindings = pattern_match (Type_check.env_of body) pat (value_of_exp exp) in + if matched then + return + (List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings (complete_bindings bindings))) + else wrap (E_match (exp, pexps)) + with Failure s -> fail ("Failure: " ^ s) + end + | E_match (exp, Pat_aux (Pat_when (pat, guard, body), pat_annot) :: pexps) when not (is_value guard) -> begin + try + let matched, bindings = pattern_match (Type_check.env_of body) pat (value_of_exp exp) in + let bindings = complete_bindings bindings in + if matched then ( + let guard = List.fold_left (fun guard (id, v) -> subst id v guard) guard (Bindings.bindings bindings) in + let body = List.fold_left (fun body (id, v) -> subst id v body) body (Bindings.bindings bindings) in + step guard >>= fun guard' -> wrap (E_match (exp, Pat_aux (Pat_when (pat, guard', body), pat_annot) :: pexps)) + ) + else wrap (E_match (exp, pexps)) + with Failure s -> fail ("Failure: " ^ s) + end | E_match (exp, Pat_aux (Pat_when (pat, guard, body), pat_annot) :: pexps) when is_true guard -> return body - | E_match (exp, Pat_aux (Pat_when (pat, guard, body), pat_annot) :: pexps) when is_false guard -> wrap (E_match (exp, pexps)) - + | E_match (exp, Pat_aux (Pat_when (pat, guard, body), pat_annot) :: pexps) when is_false guard -> + wrap (E_match (exp, pexps)) | E_typ (typ, exp) -> return exp - | E_throw exp when is_value exp -> throw (value_of_exp exp) | E_throw exp -> step exp >>= fun exp' -> wrap (E_throw exp') | E_exit exp when is_value exp -> throw (V_ctor ("Exit", [value_of_exp exp])) | E_exit exp -> step exp >>= fun exp' -> wrap (E_exit exp') - - | E_ref id -> - return (exp_of_value (V_ref (string_of_id id))) - - | E_id id -> - begin - let open Type_check in - match Env.lookup_id id (env_of_annot annot) with - | Register _ -> - read_reg (string_of_id id) >>= fun v -> return (exp_of_value v) - | Local (Mutable, _) -> get_local (string_of_id id) >>= fun v -> return (exp_of_value v) - | Local (Immutable, _) -> + | E_ref id -> return (exp_of_value (V_ref (string_of_id id))) + | E_id id -> begin + let open Type_check in + match Env.lookup_id id (env_of_annot annot) with + | Register _ -> read_reg (string_of_id id) >>= fun v -> return (exp_of_value v) + | Local (Mutable, _) -> get_local (string_of_id id) >>= fun v -> return (exp_of_value v) + | Local (Immutable, _) -> (* if we get here without already having substituted, it must be a top-level letbind *) get_global_letbinds () >>= fun lbs -> let chain = build_letchain id lbs orig_exp in return chain - | Enum _ -> - return (exp_of_value (V_ctor (string_of_id id, []))) - | _ -> fail ("Couldn't find id " ^ string_of_id id) - end - + | Enum _ -> return (exp_of_value (V_ctor (string_of_id id, []))) + | _ -> fail ("Couldn't find id " ^ string_of_id id) + end | E_struct fexps -> - let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in - begin - match unevaluated with - | FE_aux (FE_fexp (id, exp), fe_annot) :: fexps -> - step exp >>= fun exp' -> - wrap (E_struct (evaluated @ FE_aux (FE_fexp (id, exp'), fe_annot) :: fexps)) - | [] -> - List.map value_of_fexp fexps - |> List.fold_left (fun record (field, v) -> StringMap.add field v record) StringMap.empty - |> (fun record -> V_record record) - |> exp_of_value - |> return - end - - | E_struct_update (exp, fexps) when not (is_value exp) -> - step exp >>= fun exp' -> wrap (E_struct_update (exp', fexps)) + let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in + begin + match unevaluated with + | FE_aux (FE_fexp (id, exp), fe_annot) :: fexps -> + step exp >>= fun exp' -> wrap (E_struct (evaluated @ (FE_aux (FE_fexp (id, exp'), fe_annot) :: fexps))) + | [] -> + List.map value_of_fexp fexps + |> List.fold_left (fun record (field, v) -> StringMap.add field v record) StringMap.empty + |> (fun record -> V_record record) + |> exp_of_value |> return + end + | E_struct_update (exp, fexps) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_struct_update (exp', fexps)) | E_struct_update (record, fexps) -> - let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in - begin - match unevaluated with - | FE_aux (FE_fexp (id, exp), fe_annot) :: fexps -> - step exp >>= fun exp' -> - wrap (E_struct_update (record, evaluated @ FE_aux (FE_fexp (id, exp'), fe_annot) :: fexps)) - | [] -> - List.map value_of_fexp fexps - |> List.fold_left (fun record (field, v) -> StringMap.add field v record) (coerce_record (value_of_exp record)) - |> (fun record -> V_record record) - |> exp_of_value - |> return - end - - | E_field (exp, id) when not (is_value exp) -> - step exp >>= fun exp' -> wrap (E_field (exp', id)) + let evaluated, unevaluated = Util.take_drop is_value_fexp fexps in + begin + match unevaluated with + | FE_aux (FE_fexp (id, exp), fe_annot) :: fexps -> + step exp >>= fun exp' -> + wrap (E_struct_update (record, evaluated @ (FE_aux (FE_fexp (id, exp'), fe_annot) :: fexps))) + | [] -> + List.map value_of_fexp fexps + |> List.fold_left + (fun record (field, v) -> StringMap.add field v record) + (coerce_record (value_of_exp record)) + |> (fun record -> V_record record) + |> exp_of_value |> return + end + | E_field (exp, id) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_field (exp', id)) | E_field (exp, id) -> - let record = coerce_record (value_of_exp exp) in - return (exp_of_value (StringMap.find (string_of_id id) record)) - - | E_var (lexp, exp, E_aux (E_block body, _)) -> - wrap (E_block (E_aux (E_assign (lexp, exp), annot) :: body)) - | E_var (lexp, exp, body) -> - wrap (E_block [E_aux (E_assign (lexp, exp), annot); body]) - + let record = coerce_record (value_of_exp exp) in + return (exp_of_value (StringMap.find (string_of_id id) record)) + | E_var (lexp, exp, E_aux (E_block body, _)) -> wrap (E_block (E_aux (E_assign (lexp, exp), annot) :: body)) + | E_var (lexp, exp, body) -> wrap (E_block [E_aux (E_assign (lexp, exp), annot); body]) | E_assign (lexp, exp) when not (is_value exp) -> step exp >>= fun exp' -> wrap (E_assign (lexp, exp')) | E_assign (LE_aux (LE_app (id, args), _), exp) -> wrap (E_app (id, args @ [exp])) - | E_assign (LE_aux (LE_field (lexp, id), ul), exp) -> - begin try - let open Type_check in - let lexp_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp lexp)) in - let exp' = E_aux (E_struct_update (lexp_exp, [FE_aux (FE_fexp (id, exp), ul)]), ul) in - wrap (E_assign (lexp, exp')) - with Failure s -> fail ("Failure: " ^ s) - end - | E_assign (LE_aux (LE_vector (vec, n), lexp_annot), exp) -> - begin try - let open Type_check in - let vec_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp vec)) in - let exp' = E_aux (E_vector_update (vec_exp, n, exp), lexp_annot) in - wrap (E_assign (vec, exp')) - with Failure s -> fail ("Failure: " ^ s) - end - | E_assign (LE_aux (LE_vector_range (vec, n, m), lexp_annot), exp) -> - begin try - let open Type_check in - let vec_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp vec)) in - (* FIXME: let the type checker check this *) - let exp' = E_aux (E_vector_update_subrange (vec_exp, n, m, exp), lexp_annot) in - wrap (E_assign (vec, exp')) - with Failure s -> fail ("Failure: " ^ s) - end - | E_assign (LE_aux (LE_id id, _), exp) | E_assign (LE_aux (LE_typ (_, id), _), exp) -> - begin - let open Type_check in - let name = string_of_id id in - match Env.lookup_id id (env_of_annot annot) with - | Register _ -> - write_reg name (value_of_exp exp) >> wrap unit_exp - | Local (Mutable, _) | Unbound _ -> - put_local name (value_of_exp exp) >> wrap unit_exp - | Local (Immutable, _) -> - fail ("Assignment to immutable local: " ^ name) - | Enum _ -> - fail ("Assignment to union constructor: " ^ name) - end + | E_assign (LE_aux (LE_field (lexp, id), ul), exp) -> begin + try + let open Type_check in + let lexp_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp lexp)) in + let exp' = E_aux (E_struct_update (lexp_exp, [FE_aux (FE_fexp (id, exp), ul)]), ul) in + wrap (E_assign (lexp, exp')) + with Failure s -> fail ("Failure: " ^ s) + end + | E_assign (LE_aux (LE_vector (vec, n), lexp_annot), exp) -> begin + try + let open Type_check in + let vec_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp vec)) in + let exp' = E_aux (E_vector_update (vec_exp, n, exp), lexp_annot) in + wrap (E_assign (vec, exp')) + with Failure s -> fail ("Failure: " ^ s) + end + | E_assign (LE_aux (LE_vector_range (vec, n, m), lexp_annot), exp) -> begin + try + let open Type_check in + let vec_exp = infer_exp (env_of_annot annot) (exp_of_lexp (strip_lexp vec)) in + (* FIXME: let the type checker check this *) + let exp' = E_aux (E_vector_update_subrange (vec_exp, n, m, exp), lexp_annot) in + wrap (E_assign (vec, exp')) + with Failure s -> fail ("Failure: " ^ s) + end + | E_assign (LE_aux (LE_id id, _), exp) | E_assign (LE_aux (LE_typ (_, id), _), exp) -> begin + let open Type_check in + let name = string_of_id id in + match Env.lookup_id id (env_of_annot annot) with + | Register _ -> write_reg name (value_of_exp exp) >> wrap unit_exp + | Local (Mutable, _) | Unbound _ -> put_local name (value_of_exp exp) >> wrap unit_exp + | Local (Immutable, _) -> fail ("Assignment to immutable local: " ^ name) + | Enum _ -> fail ("Assignment to union constructor: " ^ name) + end | E_assign (LE_aux (LE_deref reference, annot), exp) when not (is_value reference) -> - step reference >>= fun reference' -> wrap (E_assign (LE_aux (LE_deref reference', annot), exp)) + step reference >>= fun reference' -> wrap (E_assign (LE_aux (LE_deref reference', annot), exp)) | E_assign (LE_aux (LE_deref reference, annot), exp) -> - let name = coerce_ref (value_of_exp reference) in - write_reg name (value_of_exp exp) >> wrap unit_exp + let name = coerce_ref (value_of_exp reference) in + write_reg name (value_of_exp exp) >> wrap unit_exp | E_assign (LE_aux (LE_tuple lexps, annot), exp) -> fail "Tuple assignment" - | E_assign (LE_aux (LE_vector_concat lexps, annot), exp) -> fail "Vector concat assignment" - (* + | E_assign (LE_aux (LE_vector_concat lexps, annot), exp) -> + fail "Vector concat assignment" + (* let values = coerce_tuple (value_of_exp exp) in wrap (E_block (List.map2 (fun lexp v -> E_aux (E_assign (lexp, exp_of_value v), (Parse_ast.Unknown, None))) lexps values)) *) - | E_try (exp, pexps) when is_value exp -> return exp - | E_try (exp, pexps) -> - begin - catch (step exp) >>= fun exp' -> - match exp' with - | Left exn -> wrap (E_match (exp_of_value exn, pexps @ [fallthrough])) - | Right exp' -> wrap (E_try (exp', pexps)) - end - + | E_try (exp, pexps) -> begin + catch (step exp) >>= fun exp' -> + match exp' with + | Left exn -> wrap (E_match (exp_of_value exn, pexps @ [fallthrough])) + | Right exp' -> wrap (E_try (exp', pexps)) + end | E_for (id, exp_from, exp_to, exp_step, ord, body) when is_value exp_from && is_value exp_to && is_value exp_step -> - let v_from = value_of_exp exp_from in - let v_to = value_of_exp exp_to in - let v_step = value_of_exp exp_step in - begin match ord with - | Ord_aux (Ord_inc, _) -> - begin match value_gt [v_from; v_to] with - | V_bool true -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) - | V_bool false -> - wrap (E_block [subst id v_from body; E_aux (E_for (id, exp_of_value (value_add_int [v_from; v_step]), exp_to, exp_step, ord, body), annot)]) - | _ -> assert false - end - | Ord_aux (Ord_dec, _) -> - begin match value_lt [v_from; v_to] with - | V_bool true -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) - | V_bool false -> - wrap (E_block [subst id v_from body; E_aux (E_for (id, exp_of_value (value_sub_int [v_from; v_step]), exp_to, exp_step, ord, body), annot)]) - | _ -> assert false - end - | Ord_aux (Ord_var _, _) -> fail "Polymorphic order in foreach" - end + let v_from = value_of_exp exp_from in + let v_to = value_of_exp exp_to in + let v_step = value_of_exp exp_step in + begin + match ord with + | Ord_aux (Ord_inc, _) -> begin + match value_gt [v_from; v_to] with + | V_bool true -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) + | V_bool false -> + wrap + (E_block + [ + subst id v_from body; + E_aux + (E_for (id, exp_of_value (value_add_int [v_from; v_step]), exp_to, exp_step, ord, body), annot); + ] + ) + | _ -> assert false + end + | Ord_aux (Ord_dec, _) -> begin + match value_lt [v_from; v_to] with + | V_bool true -> wrap (E_lit (L_aux (L_unit, Parse_ast.Unknown))) + | V_bool false -> + wrap + (E_block + [ + subst id v_from body; + E_aux + (E_for (id, exp_of_value (value_sub_int [v_from; v_step]), exp_to, exp_step, ord, body), annot); + ] + ) + | _ -> assert false + end + | Ord_aux (Ord_var _, _) -> fail "Polymorphic order in foreach" + end | E_for (id, exp_from, exp_to, exp_step, ord, body) when is_value exp_to && is_value exp_step -> - step exp_from >>= fun exp_from' -> wrap (E_for (id, exp_from', exp_to, exp_step, ord, body)) + step exp_from >>= fun exp_from' -> wrap (E_for (id, exp_from', exp_to, exp_step, ord, body)) | E_for (id, exp_from, exp_to, exp_step, ord, body) when is_value exp_step -> - step exp_to >>= fun exp_to' -> wrap (E_for (id, exp_from, exp_to', exp_step, ord, body)) + step exp_to >>= fun exp_to' -> wrap (E_for (id, exp_from, exp_to', exp_step, ord, body)) | E_for (id, exp_from, exp_to, exp_step, ord, body) -> - step exp_step >>= fun exp_step' -> wrap (E_for (id, exp_from, exp_to, exp_step', ord, body)) - - | E_sizeof nexp -> - begin - match Type_check.big_int_of_nexp nexp with - | Some n -> return (exp_of_value (V_int n)) - | None -> fail "Sizeof unevaluable nexp" - end - + step exp_step >>= fun exp_step' -> wrap (E_for (id, exp_from, exp_to, exp_step', ord, body)) + | E_sizeof nexp -> begin + match Type_check.big_int_of_nexp nexp with + | Some n -> return (exp_of_value (V_int n)) + | None -> fail "Sizeof unevaluable nexp" + end | E_cons (hd, tl) when is_value hd && is_value tl -> - let hd = value_of_exp hd in - let tl = coerce_listlike (value_of_exp tl) in - return (exp_of_value (V_list (hd :: tl))) - | E_cons (hd, tl) when is_value hd -> - step tl >>= fun tl' -> wrap (E_cons (hd, tl')) - | E_cons (hd, tl) -> - step hd >>= fun hd' -> wrap (E_cons (hd', tl)) - + let hd = value_of_exp hd in + let tl = coerce_listlike (value_of_exp tl) in + return (exp_of_value (V_list (hd :: tl))) + | E_cons (hd, tl) when is_value hd -> step tl >>= fun tl' -> wrap (E_cons (hd, tl')) + | E_cons (hd, tl) -> step hd >>= fun hd' -> wrap (E_cons (hd', tl)) | _ -> raise (Invalid_argument ("Unimplemented " ^ string_of_exp orig_exp)) and exp_of_lexp (LE_aux (lexp_aux, _)) = @@ -720,88 +623,97 @@ and exp_of_lexp (LE_aux (lexp_aux, _)) = | LE_vector_range (lexp, exp1, exp2) -> mk_exp (E_vector_subrange (exp_of_lexp lexp, exp1, exp2)) | LE_vector_concat [] -> failwith "Empty LE_vector_concat node in exp_of_lexp" | LE_vector_concat [lexp] -> exp_of_lexp lexp - | LE_vector_concat (lexp :: lexps) -> mk_exp (E_vector_append (exp_of_lexp lexp, exp_of_lexp (mk_lexp (LE_vector_concat lexps)))) + | LE_vector_concat (lexp :: lexps) -> + mk_exp (E_vector_append (exp_of_lexp lexp, exp_of_lexp (mk_lexp (LE_vector_concat lexps)))) | LE_field (lexp, id) -> mk_exp (E_field (exp_of_lexp lexp, id)) and pattern_match env (P_aux (p_aux, (l, _))) value = match p_aux with - | P_lit lit -> eq_value (value_of_lit lit) value, Bindings.empty - | P_wild -> true, Bindings.empty - | P_or(pat1, pat2) -> - let (m1, b1) = pattern_match env pat1 value in - let (m2, b2) = pattern_match env pat2 value in - (* todo: maybe add assertion that bindings are consistent or empty? *) - (m1 || m2, Bindings.merge combine b1 b2) - | P_not(pat) -> - let (m, b) = pattern_match env pat value in - (* todo: maybe add assertion that binding is empty *) - (not m, b) + | P_lit lit -> (eq_value (value_of_lit lit) value, Bindings.empty) + | P_wild -> (true, Bindings.empty) + | P_or (pat1, pat2) -> + let m1, b1 = pattern_match env pat1 value in + let m2, b2 = pattern_match env pat2 value in + (* todo: maybe add assertion that bindings are consistent or empty? *) + (m1 || m2, Bindings.merge combine b1 b2) + | P_not pat -> + let m, b = pattern_match env pat value in + (* todo: maybe add assertion that binding is empty *) + (not m, b) | P_as (pat, id) -> - let matched, bindings = pattern_match env pat value in - matched, Bindings.add id (Complete_binding value) bindings + let matched, bindings = pattern_match env pat value in + (matched, Bindings.add id (Complete_binding value) bindings) | P_typ (_, pat) -> pattern_match env pat value | P_id id -> - let open Type_check in - begin - match Env.lookup_id id env with - | Enum _ -> - if is_ctor value && string_of_id id = fst (coerce_ctor value) - then true, Bindings.empty - else false, Bindings.empty - | _ -> true, Bindings.singleton id (Complete_binding value) - end - | P_vector_subrange (id, n, m) -> - true, Bindings.singleton id (Partial_binding [(value, n, m)]) + let open Type_check in + begin + match Env.lookup_id id env with + | Enum _ -> + if is_ctor value && string_of_id id = fst (coerce_ctor value) then (true, Bindings.empty) + else (false, Bindings.empty) + | _ -> (true, Bindings.singleton id (Complete_binding value)) + end + | P_vector_subrange (id, n, m) -> (true, Bindings.singleton id (Partial_binding [(value, n, m)])) | P_var (pat, _) -> pattern_match env pat value | P_app (id, pats) -> - let (ctor, vals) = coerce_ctor value in - if Id.compare id (mk_id ctor) = 0 then - let matches = List.map2 (pattern_match env) pats vals in - List.for_all fst matches, List.fold_left (Bindings.merge combine) Bindings.empty (List.map snd matches) - else - false, Bindings.empty + let ctor, vals = coerce_ctor value in + if Id.compare id (mk_id ctor) = 0 then ( + let matches = List.map2 (pattern_match env) pats vals in + (List.for_all fst matches, List.fold_left (Bindings.merge combine) Bindings.empty (List.map snd matches)) + ) + else (false, Bindings.empty) | P_vector pats -> - let matches = List.map2 (pattern_match env) pats (coerce_gv value) in - List.for_all fst matches, List.fold_left (Bindings.merge combine) Bindings.empty (List.map snd matches) - | P_vector_concat [] -> eq_value (V_vector []) value, Bindings.empty + let matches = List.map2 (pattern_match env) pats (coerce_gv value) in + (List.for_all fst matches, List.fold_left (Bindings.merge combine) Bindings.empty (List.map snd matches)) + | P_vector_concat [] -> (eq_value (V_vector []) value, Bindings.empty) | P_vector_concat (pat :: pats) -> - (* We have to use the annotation on each member of the - vector_concat pattern to figure out its length. Due to the - recursive call that has an empty_tannot we must not use the - annotation in the whole vector_concat pattern. *) - let open Type_check in - let vector_concat_match n = - let init, rest = Util.take (Big_int.to_int n) (coerce_gv value), Util.drop (Big_int.to_int n) (coerce_gv value) in - let init_match, init_bind = pattern_match env pat (V_vector init) in - let rest_match, rest_bind = pattern_match env (P_aux (P_vector_concat pats, (l, empty_tannot))) (V_vector rest) in - init_match && rest_match, Bindings.merge combine init_bind rest_bind - in - begin match destruct_vector (env_of_pat pat) (typ_of_pat pat) with - | Some (Nexp_aux (Nexp_constant n, _), _, _) -> vector_concat_match n - | None -> - begin match destruct_bitvector (env_of_pat pat) (typ_of_pat pat) with - | Some (Nexp_aux (Nexp_constant n, _), _) -> vector_concat_match n - | _ -> failwith ("Bad bitvector annotation for bitvector concatenation pattern " ^ string_of_typ (Type_check.typ_of_pat pat)) - end - | _ -> failwith ("Bad vector annotation for vector concatentation pattern " ^ string_of_typ (Type_check.typ_of_pat pat)) - end + (* We have to use the annotation on each member of the + vector_concat pattern to figure out its length. Due to the + recursive call that has an empty_tannot we must not use the + annotation in the whole vector_concat pattern. *) + let open Type_check in + let vector_concat_match n = + let init, rest = + (Util.take (Big_int.to_int n) (coerce_gv value), Util.drop (Big_int.to_int n) (coerce_gv value)) + in + let init_match, init_bind = pattern_match env pat (V_vector init) in + let rest_match, rest_bind = + pattern_match env (P_aux (P_vector_concat pats, (l, empty_tannot))) (V_vector rest) + in + (init_match && rest_match, Bindings.merge combine init_bind rest_bind) + in + begin + match destruct_vector (env_of_pat pat) (typ_of_pat pat) with + | Some (Nexp_aux (Nexp_constant n, _), _, _) -> vector_concat_match n + | None -> begin + match destruct_bitvector (env_of_pat pat) (typ_of_pat pat) with + | Some (Nexp_aux (Nexp_constant n, _), _) -> vector_concat_match n + | _ -> + failwith + ("Bad bitvector annotation for bitvector concatenation pattern " + ^ string_of_typ (Type_check.typ_of_pat pat) + ) + end + | _ -> + failwith + ("Bad vector annotation for vector concatentation pattern " ^ string_of_typ (Type_check.typ_of_pat pat)) + end | P_tuple [pat] -> pattern_match env pat value | P_tuple pats | P_list pats -> - let values = coerce_listlike value in - if List.compare_lengths pats values = 0 then ( - let matches = List.map2 (pattern_match env) pats values in - List.for_all fst matches, List.fold_left (Bindings.merge combine) Bindings.empty (List.map snd matches) - ) else ( - false, Bindings.empty - ) - | P_cons (hd_pat, tl_pat) -> - begin match coerce_cons value with - | Some (hd_value, tl_values) -> - let hd_match, hd_bind = pattern_match env hd_pat hd_value in - let tl_match, tl_bind = pattern_match env tl_pat (V_list tl_values) in - hd_match && tl_match, Bindings.merge combine hd_bind tl_bind - | None -> false, Bindings.empty - end + let values = coerce_listlike value in + if List.compare_lengths pats values = 0 then ( + let matches = List.map2 (pattern_match env) pats values in + (List.for_all fst matches, List.fold_left (Bindings.merge combine) Bindings.empty (List.map snd matches)) + ) + else (false, Bindings.empty) + | P_cons (hd_pat, tl_pat) -> begin + match coerce_cons value with + | Some (hd_value, tl_values) -> + let hd_match, hd_bind = pattern_match env hd_pat hd_value in + let tl_match, tl_bind = pattern_match env tl_pat (V_list tl_values) in + (hd_match && tl_match, Bindings.merge combine hd_bind tl_bind) + | None -> (false, Bindings.empty) + end | P_string_append _ -> assert false (* TODO *) let exp_of_fundef (FD_aux (FD_function (_, _, funcls), annot)) value = @@ -814,8 +726,7 @@ let rec defs_letbinds defs = | DEF_aux (DEF_let lb, _) :: defs -> lb :: defs_letbinds defs | _ :: defs -> defs_letbinds defs -let initial_lstate = - { locals = Bindings.empty } +let initial_lstate = { locals = Bindings.empty } let stack_cont (_, _, cont) = cont let stack_string (str, _, _) = str @@ -823,17 +734,35 @@ let stack_state (_, lstate, _) = lstate type frame = | Done of state * value - | Step of string Lazy.t * state * (Type_check.tannot exp) monad * (string Lazy.t * lstate * (return_value -> (Type_check.tannot exp) monad)) list + | Step of + string Lazy.t + * state + * Type_check.tannot exp monad + * (string Lazy.t * lstate * (return_value -> Type_check.tannot exp monad)) list | Break of frame - | Effect_request of string Lazy.t * state * (string Lazy.t * lstate * (return_value -> (Type_check.tannot exp) monad)) list * effect_request - | Fail of string Lazy.t * state * (Type_check.tannot exp) monad * (string Lazy.t * lstate * (return_value -> (Type_check.tannot exp) monad)) list * string + | Effect_request of + string Lazy.t + * state + * (string Lazy.t * lstate * (return_value -> Type_check.tannot exp monad)) list + * effect_request + | Fail of + string Lazy.t + * state + * Type_check.tannot exp monad + * (string Lazy.t * lstate * (return_value -> Type_check.tannot exp monad)) list + * string (* when changing effect_request remember to also update response type above *) and effect_request = | Read_mem of (* read_kind : *) value * (* address : *) value * (* length : *) value * (value -> state -> frame) | Write_ea of (* write_kind : *) value * (* address : *) value * (* length : *) value * (unit -> state -> frame) | Excl_res of (bool -> state -> frame) - | Write_mem of (* write_kind : *) value * (* address : *) value * (* length : *) value * (* value : *) value * (bool -> state -> frame) + | Write_mem of + (* write_kind : *) value + * (* address : *) value + * (* length : *) value + * (* value : *) value + * (bool -> state -> frame) | Barrier of (* barrier_kind : *) value * (unit -> state -> frame) | Read_reg of string * (value -> state -> frame) | Write_reg of string * value * (unit -> state -> frame) @@ -843,140 +772,138 @@ let rec eval_frame' = function | Fail (out, state, m, stack, msg) -> Fail (out, state, m, stack, msg) | Break frame -> Break frame | Effect_request (out, state, stack, eff) -> Effect_request (out, state, stack, eff) - | Step (out, state, m, stack) -> - let lstate, gstate = state in - match (m, stack) with - | Pure v, [] when is_value v -> Done (state, value_of_exp v) - | Pure v, (head :: stack') when is_value v -> - Step (stack_string head, (stack_state head, gstate), stack_cont head (Return_ok (value_of_exp v)), stack') - | Pure exp', _ -> - let out' = lazy (Pretty_print_sail.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp exp'))) in - Step (out', state, step exp', stack) - | Yield (Call(id, vals, cont)), _ when string_of_id id = "break" -> - begin + | Step (out, state, m, stack) -> ( + let lstate, gstate = state in + match (m, stack) with + | Pure v, [] when is_value v -> Done (state, value_of_exp v) + | Pure v, head :: stack' when is_value v -> + Step (stack_string head, (stack_state head, gstate), stack_cont head (Return_ok (value_of_exp v)), stack') + | Pure exp', _ -> + let out' = lazy (Pretty_print_sail.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp exp'))) in + Step (out', state, step exp', stack) + | Yield (Call (id, vals, cont)), _ when string_of_id id = "break" -> begin let arg = if List.length vals != 1 then tuple_value vals else List.hd vals in try let body = exp_of_fundef (Bindings.find id gstate.fundefs) arg in Break (Step (lazy "", (initial_lstate, gstate), return body, (out, lstate, cont) :: stack)) - with Not_found -> - Step (out, state, fail ("Fundef not found: " ^ string_of_id id), stack) + with Not_found -> Step (out, state, fail ("Fundef not found: " ^ string_of_id id), stack) end - | Yield (Call(id, vals, cont)), _ -> - begin + | Yield (Call (id, vals, cont)), _ -> begin let arg = if List.length vals != 1 then tuple_value vals else List.hd vals in try let body = exp_of_fundef (Bindings.find id gstate.fundefs) arg in Step (lazy "", (initial_lstate, gstate), return body, (out, lstate, cont) :: stack) - with Not_found -> - Step (out, state, fail ("Fundef not found: " ^ string_of_id id), stack) + with Not_found -> Step (out, state, fail ("Fundef not found: " ^ string_of_id id), stack) end - - | Yield (Read_reg (name, cont)), _ -> - Effect_request (out, state, stack, Read_reg (name, fun v state' -> eval_frame' (Step (out, state', cont v, stack)))) - | Yield (Write_reg (name, v, cont)), _ -> - Effect_request (out, state, stack, Write_reg (name, v, fun () state' -> eval_frame' (Step (out, state', cont (), stack)))) - | Yield (Get_primop (name, cont)), _ -> - begin + | Yield (Read_reg (name, cont)), _ -> + Effect_request + (out, state, stack, Read_reg (name, fun v state' -> eval_frame' (Step (out, state', cont v, stack)))) + | Yield (Write_reg (name, v, cont)), _ -> + Effect_request + (out, state, stack, Write_reg (name, v, fun () state' -> eval_frame' (Step (out, state', cont (), stack)))) + | Yield (Get_primop (name, cont)), _ -> begin try (* If we are in the toplevel interactive interpreter allow the set of primops to be changed dynamically *) let op = StringMap.find name (if !Interactive.opt_interactive then !Value.primops else gstate.primops) in eval_frame' (Step (out, state, cont op, stack)) - with Not_found -> - eval_frame' (Step (out, state, fail ("No such primop: " ^ name), stack)) + with Not_found -> eval_frame' (Step (out, state, fail ("No such primop: " ^ name), stack)) end - | Yield (Get_local (name, cont)), _ -> - begin - try - eval_frame' (Step (out, state, cont (Bindings.find (mk_id name) lstate.locals), stack)) - with Not_found -> - eval_frame' (Step (out, state, fail ("Local not found: " ^ name), stack)) + | Yield (Get_local (name, cont)), _ -> begin + try eval_frame' (Step (out, state, cont (Bindings.find (mk_id name) lstate.locals), stack)) + with Not_found -> eval_frame' (Step (out, state, fail ("Local not found: " ^ name), stack)) end - | Yield (Put_local (name, v, cont)), _ -> - let state' = ({ locals = Bindings.add (mk_id name) v lstate.locals }, gstate) in - eval_frame' (Step (out, state', cont (), stack)) - | Yield (Get_global_letbinds cont), _ -> - eval_frame' (Step (out, state, cont gstate.letbinds, stack)) - | Yield (Read_mem (rk, addr, len, cont)), _ -> - Effect_request (out, state, stack, Read_mem (rk, addr, len, fun result state' -> eval_frame' (Step (out, state', cont result, stack)))) - | Yield (Write_ea (wk, addr, len, cont)), _ -> - Effect_request (out, state, stack, Write_ea (wk, addr, len, fun () state' -> eval_frame' (Step (out, state', cont (), stack)))) - | Yield (Excl_res cont), _ -> - Effect_request (out, state, stack, Excl_res (fun b state' -> eval_frame' (Step (out, state', cont b, stack)))) - | Yield (Write_mem (wk, addr, len, v, cont)), _ -> - Effect_request (out, state, stack, Write_mem (wk, addr, len, v, fun b state' -> eval_frame' (Step (out, state', cont b, stack)))) - | Yield (Barrier (bk, cont)), _ -> - Effect_request (out, state, stack, Barrier (bk, fun () state' -> eval_frame' (Step (out, state', cont (), stack)))) - | Yield (Early_return v), [] -> Done (state, v) - | Yield (Early_return v), (head :: stack') -> - Step (stack_string head, (stack_state head, gstate), stack_cont head (Return_ok v), stack') - | Yield (Assertion_failed msg), _ | Yield (Fail msg), _ -> - Fail (out, state, m, stack, msg) - | Yield (Exception v), [] -> - Fail (out, state, m, stack, "Uncaught exception: " ^ string_of_value v) - | Yield (Exception v), (head :: stack') -> - Step (stack_string head, (stack_state head, gstate), stack_cont head (Return_exception v), stack') + | Yield (Put_local (name, v, cont)), _ -> + let state' = ({ locals = Bindings.add (mk_id name) v lstate.locals }, gstate) in + eval_frame' (Step (out, state', cont (), stack)) + | Yield (Get_global_letbinds cont), _ -> eval_frame' (Step (out, state, cont gstate.letbinds, stack)) + | Yield (Read_mem (rk, addr, len, cont)), _ -> + Effect_request + ( out, + state, + stack, + Read_mem (rk, addr, len, fun result state' -> eval_frame' (Step (out, state', cont result, stack))) + ) + | Yield (Write_ea (wk, addr, len, cont)), _ -> + Effect_request + ( out, + state, + stack, + Write_ea (wk, addr, len, fun () state' -> eval_frame' (Step (out, state', cont (), stack))) + ) + | Yield (Excl_res cont), _ -> + Effect_request (out, state, stack, Excl_res (fun b state' -> eval_frame' (Step (out, state', cont b, stack)))) + | Yield (Write_mem (wk, addr, len, v, cont)), _ -> + Effect_request + ( out, + state, + stack, + Write_mem (wk, addr, len, v, fun b state' -> eval_frame' (Step (out, state', cont b, stack))) + ) + | Yield (Barrier (bk, cont)), _ -> + Effect_request + (out, state, stack, Barrier (bk, fun () state' -> eval_frame' (Step (out, state', cont (), stack)))) + | Yield (Early_return v), [] -> Done (state, v) + | Yield (Early_return v), head :: stack' -> + Step (stack_string head, (stack_state head, gstate), stack_cont head (Return_ok v), stack') + | Yield (Assertion_failed msg), _ | Yield (Fail msg), _ -> Fail (out, state, m, stack, msg) + | Yield (Exception v), [] -> Fail (out, state, m, stack, "Uncaught exception: " ^ string_of_value v) + | Yield (Exception v), head :: stack' -> + Step (stack_string head, (stack_state head, gstate), stack_cont head (Return_exception v), stack') + ) let eval_frame frame = - try eval_frame' frame with - | Type_check.Type_error (env, l, err) -> - raise (Reporting.err_typ l (Type_error.string_of_type_error err)) + try eval_frame' frame + with Type_check.Type_error (env, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) let default_effect_interp state eff = let lstate, gstate = state in match eff with | Read_mem (rk, addr, len, cont) -> - (* all read-kinds treated the same in single-threaded interpreter *) - let addr' = coerce_bv addr in - let len' = coerce_int len in - let result = mk_vector (Sail_lib.read_ram (List.length addr', len', [], addr')) in - cont result state + (* all read-kinds treated the same in single-threaded interpreter *) + let addr' = coerce_bv addr in + let len' = coerce_int len in + let result = mk_vector (Sail_lib.read_ram (List.length addr', len', [], addr')) in + cont result state | Write_ea (wk, addr, len, cont) -> - (* just store the values for the next Write_memv *) - let state' = (lstate, { gstate with last_write_ea = Some (wk, addr, len) }) in - cont () state' + (* just store the values for the next Write_memv *) + let state' = (lstate, { gstate with last_write_ea = Some (wk, addr, len) }) in + cont () state' | Excl_res cont -> - (* always succeeds in single-threaded interpreter *) - cont true state - | Write_mem (wk, addr, len, v, cont) -> - begin - match gstate.last_write_ea with - | Some (wk', addr', len') -> + (* always succeeds in single-threaded interpreter *) + cont true state + | Write_mem (wk, addr, len, v, cont) -> begin + match gstate.last_write_ea with + | Some (wk', addr', len') -> let state' = (lstate, { gstate with last_write_ea = None }) in (* all write-kinds treated the same in single-threaded interpreter *) let addr' = coerce_bv addr in let len' = coerce_int len in let v' = coerce_bv v in - if Big_int.mul len' (Big_int.of_int 8) = Big_int.of_int (List.length v') then + if Big_int.mul len' (Big_int.of_int 8) = Big_int.of_int (List.length v') then ( let b = Sail_lib.write_ram (List.length addr', len', [], addr', v') in cont b state' - else - failwith "Write_memv with length mismatch to preceding Write_ea" - | None -> - failwith "Write_memv without preceding Write_ea" - end + ) + else failwith "Write_memv with length mismatch to preceding Write_ea" + | None -> failwith "Write_memv without preceding Write_ea" + end | Barrier (bk, cont) -> - (* no-op in single-threaded interpreter *) - cont () state + (* no-op in single-threaded interpreter *) + cont () state | Read_reg (name, cont) -> - if gstate.allow_registers then ( - try - cont (Bindings.find (mk_id name) gstate.registers) state - with Not_found -> - failwith ("Read of nonexistent register: " ^ name) - ) else ( - failwith ("Register read disallowed by allow_registers setting: " ^ name) - ) + if gstate.allow_registers then ( + try cont (Bindings.find (mk_id name) gstate.registers) state + with Not_found -> failwith ("Read of nonexistent register: " ^ name) + ) + else failwith ("Register read disallowed by allow_registers setting: " ^ name) | Write_reg (name, v, cont) -> - let id = mk_id name in - if gstate.allow_registers then ( - if Bindings.mem id gstate.registers then - let state' = (lstate, { gstate with registers = Bindings.add id v gstate.registers }) in - cont () state' - else - failwith ("Write of nonexistent register: " ^ name) - ) else ( - failwith ("Register write disallowed by allow_registers setting: " ^ name) - ) + let id = mk_id name in + if gstate.allow_registers then + if Bindings.mem id gstate.registers then ( + let state' = (lstate, { gstate with registers = Bindings.add id v gstate.registers }) in + cont () state' + ) + else failwith ("Write of nonexistent register: " ^ name) + else failwith ("Register write disallowed by allow_registers setting: " ^ name) let effect_interp = ref default_effect_interp @@ -986,20 +913,17 @@ let rec run_frame frame = match frame with | Done (state, v) -> v | Fail (_, _, _, _, msg) -> failwith ("run_frame got Fail: " ^ msg) - | Step (_, _, _, _) -> - run_frame (eval_frame frame) - | Break frame -> - run_frame (eval_frame frame) - | Effect_request (_, state, _, eff) -> - run_frame (!effect_interp state eff) + | Step (_, _, _, _) -> run_frame (eval_frame frame) + | Break frame -> run_frame (eval_frame frame) + | Effect_request (_, state, _, eff) -> run_frame (!effect_interp state eff) -let eval_exp state exp = - run_frame (Step (lazy "", state, return exp, [])) +let eval_exp state exp = run_frame (Step (lazy "", state, return exp, [])) let initial_gstate primops defs env = - { registers = Bindings.empty; + { + registers = Bindings.empty; allow_registers = true; - primops = primops; + primops; letbinds = defs_letbinds defs; fundefs = Bindings.empty; last_write_ea = None; @@ -1008,56 +932,47 @@ let initial_gstate primops defs env = let rec initialize_registers allow_registers gstate = let process_def = function - | DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, opt_exp), annot)), _) when allow_registers -> - begin match opt_exp with - | None -> - let env = Type_check.env_of_annot annot in - let typ = Type_check.Env.expand_synonyms env typ in - let exp = mk_exp (E_typ (typ, mk_exp (E_lit (mk_lit L_undef)))) in - let exp = Type_check.check_exp env exp typ in - { gstate with registers = Bindings.add id (eval_exp (initial_lstate, gstate) exp) gstate.registers } - | Some exp -> - { gstate with registers = Bindings.add id (eval_exp (initial_lstate, gstate) exp) gstate.registers } - end + | DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, opt_exp), annot)), _) when allow_registers -> begin + match opt_exp with + | None -> + let env = Type_check.env_of_annot annot in + let typ = Type_check.Env.expand_synonyms env typ in + let exp = mk_exp (E_typ (typ, mk_exp (E_lit (mk_lit L_undef)))) in + let exp = Type_check.check_exp env exp typ in + { gstate with registers = Bindings.add id (eval_exp (initial_lstate, gstate) exp) gstate.registers } + | Some exp -> + { gstate with registers = Bindings.add id (eval_exp (initial_lstate, gstate) exp) gstate.registers } + end | _ -> gstate in - function - | def :: defs -> - initialize_registers allow_registers (process_def def) defs - | [] -> gstate + function def :: defs -> initialize_registers allow_registers (process_def def) defs | [] -> gstate -let initial_state ?(registers=true) ast env primops = +let initial_state ?(registers = true) ast env primops = let gstate = initial_gstate primops ast.defs env in let add_function gstate = function - | DEF_aux (DEF_fundef fdef, _) -> - { gstate with fundefs = Bindings.add (id_of_fundef fdef) fdef gstate.fundefs } + | DEF_aux (DEF_fundef fdef, _) -> { gstate with fundefs = Bindings.add (id_of_fundef fdef) fdef gstate.fundefs } | _ -> gstate in let gstate = List.fold_left add_function gstate ast.defs in - let gstate = - { (initialize_registers registers gstate ast.defs) - with allow_registers = registers } - in - initial_lstate, gstate + let gstate = { (initialize_registers registers gstate ast.defs) with allow_registers = registers } in + (initial_lstate, gstate) -type value_result = - | Value_success of value - | Value_error of exn +type value_result = Value_success of value | Value_error of exn let decode_instruction state bv = try let env = (snd state).typecheck_env in - let untyped = mk_exp (E_app ((mk_id "decode"), [mk_exp (E_vector (List.map mk_lit_exp bv))])) in - let typed = Type_check.check_exp - env untyped (app_typ (mk_id "option") - [A_aux (A_typ (mk_typ (Typ_id (mk_id "ast"))), Parse_ast.Unknown)]) in + let untyped = mk_exp (E_app (mk_id "decode", [mk_exp (E_vector (List.map mk_lit_exp bv))])) in + let typed = + Type_check.check_exp env untyped + (app_typ (mk_id "option") [A_aux (A_typ (mk_typ (Typ_id (mk_id "ast"))), Parse_ast.Unknown)]) + in let evaled = eval_exp state typed in match evaled with | V_ctor ("Some", [v]) -> Value_success v | V_ctor ("None", _) -> failwith "decode returned None" | _ -> failwith "decode returned wrong value type" - with _ as exn -> - Value_error exn + with _ as exn -> Value_error exn let annot_exp_effect e_aux l env typ = E_aux (e_aux, (l, Type_check.mk_tannot env typ)) let annot_exp e_aux l env typ = annot_exp_effect e_aux l env typ @@ -1066,16 +981,30 @@ let id_typ id = mk_typ (Typ_id (mk_id id)) let analyse_instruction state ast = let env = (snd state).typecheck_env in let unk = Parse_ast.Unknown in - let typed = annot_exp - (E_app (mk_id "initial_analysis", [annot_exp (E_internal_value ast) unk env (id_typ "ast")])) unk env - (tuple_typ [id_typ "regfps"; id_typ "regfps"; id_typ "regfps"; id_typ "niafps"; id_typ "diafp"; id_typ "instruction_kind"]) + let typed = + annot_exp + (E_app (mk_id "initial_analysis", [annot_exp (E_internal_value ast) unk env (id_typ "ast")])) + unk env + (tuple_typ + [id_typ "regfps"; id_typ "regfps"; id_typ "regfps"; id_typ "niafps"; id_typ "diafp"; id_typ "instruction_kind"] + ) in - Step (lazy (Pretty_print_sail.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp typed))), state, return typed, []) + Step + ( lazy (Pretty_print_sail.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp typed))), + state, + return typed, + [] + ) let execute_instruction state ast = let env = (snd state).typecheck_env in let unk = Parse_ast.Unknown in - let typed = annot_exp - (E_app (mk_id "execute", [annot_exp (E_internal_value ast) unk env (id_typ "ast")])) unk env unit_typ + let typed = + annot_exp (E_app (mk_id "execute", [annot_exp (E_internal_value ast) unk env (id_typ "ast")])) unk env unit_typ in - Step (lazy (Pretty_print_sail.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp typed))), state, return typed, []) + Step + ( lazy (Pretty_print_sail.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp typed))), + state, + return typed, + [] + ) diff --git a/src/lib/jib_compile.ml b/src/lib/jib_compile.ml index b01e99316..2d5fe7927 100644 --- a/src/lib/jib_compile.ml +++ b/src/lib/jib_compile.ml @@ -79,7 +79,7 @@ let opt_memo_cache = ref false let optimize_aarch64_fast_struct = ref false -let (gensym, _) = symbol_generator "gs" +let gensym, _ = symbol_generator "gs" let ngensym () = name (gensym ()) (**************************************************************************) @@ -123,13 +123,9 @@ let value_of_aval_bit = function | AV_lit (L_aux (L_one, _), _) -> Sail2_values.B1 | _ -> assert false -let is_ct_enum = function - | CT_enum _ -> true - | _ -> false +let is_ct_enum = function CT_enum _ -> true | _ -> false -let iblock1 = function - | [instr] -> instr - | instrs -> iblock instrs +let iblock1 = function [instr] -> instr | instrs -> iblock instrs (** The context type contains two type-checking environments. ctx.local_env contains the closest typechecking @@ -138,19 +134,19 @@ let iblock1 = function type-checking the entire AST. We also keep track of local variables in ctx.locals, so we know when their type changes due to flow typing. *) -type ctx = - { records : (kid list * ctyp Bindings.t) Bindings.t; - enums : IdSet.t Bindings.t; - variants : (kid list * ctyp Bindings.t) Bindings.t; - valspecs : (string option * ctyp list * ctyp) Bindings.t; - quants : ctyp KBindings.t; - local_env : Env.t; - tc_env : Env.t; - effect_info : Effects.side_effect_info; - locals : (mut * ctyp) Bindings.t; - letbinds : int list; - no_raw : bool; - } +type ctx = { + records : (kid list * ctyp Bindings.t) Bindings.t; + enums : IdSet.t Bindings.t; + variants : (kid list * ctyp Bindings.t) Bindings.t; + valspecs : (string option * ctyp list * ctyp) Bindings.t; + quants : ctyp KBindings.t; + local_env : Env.t; + tc_env : Env.t; + effect_info : Effects.side_effect_info; + locals : (mut * ctyp) Bindings.t; + letbinds : int list; + no_raw : bool; +} let ctx_is_extern id ctx = match Bindings.find_opt id ctx.valspecs with @@ -162,19 +158,22 @@ let ctx_get_extern id ctx = match Bindings.find_opt id ctx.valspecs with | Some (Some extern, _, _) -> extern | Some (None, _, _) -> - Reporting.unreachable (id_loc id) __POS__ ("Tried to get extern information for non-extern function " ^ string_of_id id) + Reporting.unreachable (id_loc id) __POS__ + ("Tried to get extern information for non-extern function " ^ string_of_id id) | None -> Env.get_extern id ctx.tc_env "c" -let ctx_has_val_spec id ctx = - Bindings.mem id ctx.valspecs || Bindings.mem id (Env.get_val_specs ctx.tc_env) +let ctx_has_val_spec id ctx = Bindings.mem id ctx.valspecs || Bindings.mem id (Env.get_val_specs ctx.tc_env) let initial_ctx env effect_info = - let initial_valspecs = [ + let initial_valspecs = + [ (mk_id "size_itself_int", (Some "size_itself_int", [CT_lint], CT_lint)); - (mk_id "make_the_value", (Some "make_the_value", [CT_lint], CT_lint)) - ] |> List.to_seq |> Bindings.of_seq + (mk_id "make_the_value", (Some "make_the_value", [CT_lint], CT_lint)); + ] + |> List.to_seq |> Bindings.of_seq in - { records = Bindings.empty; + { + records = Bindings.empty; enums = Bindings.empty; variants = Bindings.empty; valspecs = initial_valspecs; @@ -203,29 +202,35 @@ let rec mangle_string_of_ctyp ctx = function | CT_rounding_mode -> "m" | CT_enum (id, _) -> "E" ^ string_of_id id ^ "%" | CT_ref ctyp -> "&" ^ mangle_string_of_ctyp ctx ctyp - | CT_tup ctyps -> "(" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) ctyps ^ ")" + | CT_tup ctyps -> "(" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) ctyps ^ ")" | CT_struct (id, fields) -> - let generic_fields = Bindings.find id ctx.records |> snd |> Bindings.bindings in - (* Note: It might be better to only do this if we actually have polymorphic fields *) - let unifiers = ctyp_unify (id_loc id) (CT_struct (id, generic_fields)) (CT_struct (id, fields)) |> KBindings.bindings |> List.map snd in - begin match unifiers with - | [] -> "R" ^ string_of_id id - | _ -> "R" ^ string_of_id id ^ "<" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) unifiers ^ ">" - end + let generic_fields = Bindings.find id ctx.records |> snd |> Bindings.bindings in + (* Note: It might be better to only do this if we actually have polymorphic fields *) + let unifiers = + ctyp_unify (id_loc id) (CT_struct (id, generic_fields)) (CT_struct (id, fields)) + |> KBindings.bindings |> List.map snd + in + begin + match unifiers with + | [] -> "R" ^ string_of_id id + | _ -> "R" ^ string_of_id id ^ "<" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) unifiers ^ ">" + end | CT_variant (id, ctors) -> - let generic_ctors = Bindings.find id ctx.variants |> snd |> Bindings.bindings in - let unifiers = ctyp_unify (id_loc id) (CT_variant (id, generic_ctors)) (CT_variant (id, ctors)) |> KBindings.bindings |> List.map snd in - let prefix = string_of_id id in - (if prefix = "option" then "O" else "U" ^ prefix) ^ "<" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) unifiers ^ ">" - | CT_vector (_, ctyp) -> - "V" ^ mangle_string_of_ctyp ctx ctyp - | CT_fvector (n, _, ctyp) -> - "F" ^ string_of_int n ^ mangle_string_of_ctyp ctx ctyp - | CT_list ctyp -> - "L" ^ mangle_string_of_ctyp ctx ctyp - | CT_poly kid -> - "P" ^ string_of_kid kid - + let generic_ctors = Bindings.find id ctx.variants |> snd |> Bindings.bindings in + let unifiers = + ctyp_unify (id_loc id) (CT_variant (id, generic_ctors)) (CT_variant (id, ctors)) + |> KBindings.bindings |> List.map snd + in + let prefix = string_of_id id in + (if prefix = "option" then "O" else "U" ^ prefix) + ^ "<" + ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) unifiers + ^ ">" + | CT_vector (_, ctyp) -> "V" ^ mangle_string_of_ctyp ctx ctyp + | CT_fvector (n, _, ctyp) -> "F" ^ string_of_int n ^ mangle_string_of_ctyp ctx ctyp + | CT_list ctyp -> "L" ^ mangle_string_of_ctyp ctx ctyp + | CT_poly kid -> "P" ^ string_of_kid kid + module type Config = sig val convert_typ : ctx -> typ -> ctyp val optimize_anf : ctx -> typ aexp -> typ aexp @@ -240,340 +245,329 @@ module type Config = sig end let name_or_global ctx id = - if Env.is_register id ctx.local_env || IdSet.mem id (Env.get_toplevel_lets ctx.local_env) then - global id - else - name id - + if Env.is_register id ctx.local_env || IdSet.mem id (Env.get_toplevel_lets ctx.local_env) then global id else name id -module IdGraph = Graph.Make(Id) -module IdGraphNS = Set.Make(Id) +module IdGraph = Graph.Make (Id) +module IdGraphNS = Set.Make (Id) let callgraph cdefs = - List.fold_left (fun graph cdef -> + List.fold_left + (fun graph cdef -> match cdef with | CDEF_fundef (id, _, _, body) -> - let graph = ref graph in - List.iter (iter_instr (function - | I_aux (I_funcall (_, _, (call, _), _), _) -> - graph := IdGraph.add_edge id call !graph - | _ -> () - )) body; - !graph + let graph = ref graph in + List.iter + (iter_instr (function + | I_aux (I_funcall (_, _, (call, _), _), _) -> graph := IdGraph.add_edge id call !graph + | _ -> () + ) + ) + body; + !graph | _ -> graph - ) IdGraph.empty cdefs - -module Make(C: Config) = struct - -let ctyp_of_typ ctx typ = C.convert_typ ctx typ - -let rec chunkify n xs = - match Util.take n xs, Util.drop n xs with - | xs, [] -> [xs] - | xs, ys -> xs :: chunkify n ys - -let coverage_branch_count = ref 0 - -let coverage_loc_args l = - match Reporting.simp_loc l with - | None -> None - | Some (p1, p2) -> - Some (Printf.sprintf "\"%s\", %d, %d, %d, %d" - (String.escaped p1.pos_fname) p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum (p2.pos_cnum - p2.pos_bol)) - -let coverage_branch_reached l = - let branch_id = !coverage_branch_count in - incr coverage_branch_count; - branch_id, - (match C.branch_coverage with - | Some _ -> - begin match coverage_loc_args l with - | None -> [] - | Some args -> - [iraw (Printf.sprintf "sail_branch_reached(%d, %s);" branch_id args)] + ) + IdGraph.empty cdefs + +module Make (C : Config) = struct + let ctyp_of_typ ctx typ = C.convert_typ ctx typ + + let rec chunkify n xs = match (Util.take n xs, Util.drop n xs) with xs, [] -> [xs] | xs, ys -> xs :: chunkify n ys + + let coverage_branch_count = ref 0 + + let coverage_loc_args l = + match Reporting.simp_loc l with + | None -> None + | Some (p1, p2) -> + Some + (Printf.sprintf "\"%s\", %d, %d, %d, %d" (String.escaped p1.pos_fname) p1.pos_lnum (p1.pos_cnum - p1.pos_bol) + p2.pos_lnum (p2.pos_cnum - p2.pos_bol) + ) + + let coverage_branch_reached l = + let branch_id = !coverage_branch_count in + incr coverage_branch_count; + ( branch_id, + match C.branch_coverage with + | Some _ -> begin + match coverage_loc_args l with + | None -> [] + | Some args -> [iraw (Printf.sprintf "sail_branch_reached(%d, %s);" branch_id args)] + end + | _ -> [] + ) + + let append_into_block instrs instr = match instrs with [] -> instr | _ -> iblock (instrs @ [instr]) + + let rec find_aexp_loc (AE_aux (e, _, l)) = + match Reporting.simp_loc l with + | Some _ -> l + | None -> ( + match e with AE_typ (e', _) -> find_aexp_loc e' | _ -> l + ) + + let coverage_branch_taken branch_id aexp = + match C.branch_coverage with + | None -> [] + | Some out -> begin + match coverage_loc_args (find_aexp_loc aexp) with + | None -> [] + | Some args -> + Printf.fprintf out "%s\n" ("B " ^ args); + [iraw (Printf.sprintf "sail_branch_taken(%d, %s);" branch_id args)] end - | _ -> [] - ) - -let append_into_block instrs instr = - match instrs with - | [] -> instr - | _ -> iblock (instrs @ [instr]) - -let rec find_aexp_loc (AE_aux (e, _, l)) = - match Reporting.simp_loc l with - | Some _ -> l - | None -> - match e with - | AE_typ (e',_) -> find_aexp_loc e' - | _ -> l - -let coverage_branch_taken branch_id aexp = - match C.branch_coverage with - | None -> [] - | Some out -> begin - match coverage_loc_args (find_aexp_loc aexp) with - | None -> [] - | Some args -> - Printf.fprintf out "%s\n" ("B " ^ args); - [iraw (Printf.sprintf "sail_branch_taken(%d, %s);" branch_id args)] - end -let coverage_function_entry id l = - match C.branch_coverage with - | None -> [] - | Some out -> begin - match coverage_loc_args l with - | None -> [] - | Some args -> - Printf.fprintf out "%s\n" ("F " ^ args); - [iraw (Printf.sprintf "sail_function_entry(\"%s\", %s);" (string_of_id id) args)] - end + let coverage_function_entry id l = + match C.branch_coverage with + | None -> [] + | Some out -> begin + match coverage_loc_args l with + | None -> [] + | Some args -> + Printf.fprintf out "%s\n" ("F " ^ args); + [iraw (Printf.sprintf "sail_function_entry(\"%s\", %s);" (string_of_id id) args)] + end -let rec compile_aval l ctx = function - | AV_cval (cval, typ) -> - let ctyp = cval_ctyp cval in - let ctyp' = ctyp_of_typ ctx typ in - if not (ctyp_equal ctyp ctyp') then - let gs = ngensym () in - [iinit l ctyp' gs cval], V_id (gs, ctyp'), [iclear ctyp' gs] - else - [], cval, [] - - | AV_id (id, typ) -> - begin match Bindings.find_opt id ctx.locals with - | Some (_, ctyp) -> - [], V_id (name id, ctyp), [] - | None -> - [], V_id (name_or_global ctx id, ctyp_of_typ ctx (lvar_typ typ)), [] - end - - | AV_ref (id, typ) -> - [], V_lit (VL_ref (string_of_id id), CT_ref (ctyp_of_typ ctx (lvar_typ typ))), [] - - | AV_lit (L_aux (L_string str, _), typ) -> - [], V_lit ((VL_string (String.escaped str)), ctyp_of_typ ctx typ), [] - - | AV_lit (L_aux (L_num n, _), typ) when C.ignore_64 -> - [], V_lit ((VL_int n), ctyp_of_typ ctx typ), [] - - | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> - let gs = ngensym () in - [iinit l CT_lint gs (V_lit (VL_int n, CT_fint 64))], - V_id (gs, CT_lint), - [iclear CT_lint gs] - - | AV_lit (L_aux (L_num n, _), typ) -> - let gs = ngensym () in - [iinit l CT_lint gs (V_lit (VL_string (Big_int.to_string n), CT_string))], - V_id (gs, CT_lint), - [iclear CT_lint gs] - - | AV_lit (L_aux (L_zero, _), _) -> [], V_lit (VL_bit Sail2_values.B0, CT_bit), [] - | AV_lit (L_aux (L_one, _), _) -> [], V_lit (VL_bit Sail2_values.B1, CT_bit), [] - - | AV_lit (L_aux (L_true, _), _) -> [], V_lit (VL_bool true, CT_bool), [] - | AV_lit (L_aux (L_false, _), _) -> [], V_lit (VL_bool false, CT_bool), [] - - | AV_lit (L_aux (L_real str, _), _) -> - if C.use_real then - [], V_lit (VL_real str, CT_real), [] - else - let gs = ngensym () in - [iinit l CT_real gs (V_lit (VL_string str, CT_string))], - V_id (gs, CT_real), - [iclear CT_real gs] - - | AV_lit (L_aux (L_unit, _), _) -> [], V_lit (VL_unit, CT_unit), [] - - | AV_lit (L_aux (L_undef, _), typ) -> - let ctyp = ctyp_of_typ ctx typ in - [], V_lit (VL_undefined, ctyp), [] - - | AV_lit (L_aux (_, l) as lit, _) -> - raise (Reporting.err_general l ("Encountered unexpected literal " ^ string_of_lit lit ^ " when converting ANF represention into IR")) - - | AV_tuple avals -> - let elements = List.map (compile_aval l ctx) avals in - let cvals = List.map (fun (_, cval, _) -> cval) elements in - let setup = List.concat (List.map (fun (setup, _, _) -> setup) elements) in - let cleanup = List.concat (List.rev (List.map (fun (_, _, cleanup) -> cleanup) elements)) in - let tup_ctyp = CT_tup (List.map cval_ctyp cvals) in - let gs = ngensym () in - if C.tuple_value then ( - setup, - V_tuple (cvals, tup_ctyp), - cleanup - ) else ( - setup - @ [idecl l tup_ctyp gs] - @ List.mapi (fun n cval -> icopy l (CL_tuple (CL_id (gs, tup_ctyp), n)) cval) cvals, - V_id (gs, CT_tup (List.map cval_ctyp cvals)), - [iclear tup_ctyp gs] - @ cleanup - ) - - | AV_record (fields, typ) when C.struct_value -> - let ctyp = ctyp_of_typ ctx typ in - let compile_fields (id, aval) = - let field_setup, cval, field_cleanup = compile_aval l ctx aval in - field_setup, - (id, cval), - field_cleanup - in - let field_triples = List.map compile_fields (Bindings.bindings fields) in - let setup = List.concat (List.map (fun (s, _, _) -> s) field_triples) in - let fields = List.map (fun (_, f, _) -> f) field_triples in - let cleanup = List.concat (List.map (fun (_, _, c) -> c) field_triples) in - setup, - V_struct (fields, ctyp), - cleanup - - | AV_record (fields, typ) -> - let ctyp = ctyp_of_typ ctx typ in - let gs = ngensym () in - let compile_fields (id, aval) = - let field_setup, cval, field_cleanup = compile_aval l ctx aval in - field_setup - @ [icopy l (CL_field (CL_id (gs, ctyp), id)) cval] - @ field_cleanup - in - [idecl l ctyp gs] - @ List.concat (List.map compile_fields (Bindings.bindings fields)), - V_id (gs, ctyp), - [iclear ctyp gs] - - | AV_vector ([], typ) -> - let vector_ctyp = ctyp_of_typ ctx typ in - begin match ctyp_of_typ ctx typ with - | CT_fbits (0, ord) -> - [], V_lit (VL_bits ([], ord), vector_ctyp), [] - | _ -> + let rec compile_aval l ctx = function + | AV_cval (cval, typ) -> + let ctyp = cval_ctyp cval in + let ctyp' = ctyp_of_typ ctx typ in + if not (ctyp_equal ctyp ctyp') then ( + let gs = ngensym () in + ([iinit l ctyp' gs cval], V_id (gs, ctyp'), [iclear ctyp' gs]) + ) + else ([], cval, []) + | AV_id (id, typ) -> begin + match Bindings.find_opt id ctx.locals with + | Some (_, ctyp) -> ([], V_id (name id, ctyp), []) + | None -> ([], V_id (name_or_global ctx id, ctyp_of_typ ctx (lvar_typ typ)), []) + end + | AV_ref (id, typ) -> ([], V_lit (VL_ref (string_of_id id), CT_ref (ctyp_of_typ ctx (lvar_typ typ))), []) + | AV_lit (L_aux (L_string str, _), typ) -> ([], V_lit (VL_string (String.escaped str), ctyp_of_typ ctx typ), []) + | AV_lit (L_aux (L_num n, _), typ) when C.ignore_64 -> ([], V_lit (VL_int n, ctyp_of_typ ctx typ), []) + | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> + let gs = ngensym () in + ([iinit l CT_lint gs (V_lit (VL_int n, CT_fint 64))], V_id (gs, CT_lint), [iclear CT_lint gs]) + | AV_lit (L_aux (L_num n, _), typ) -> let gs = ngensym () in - [idecl l vector_ctyp gs; - iextern l (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init", []) [V_lit (VL_int Big_int.zero, CT_fint 64)]], - V_id (gs, vector_ctyp), - [iclear vector_ctyp gs] - end - - (* Convert a small bitvector to a uint64_t literal. *) - | AV_vector (avals, typ) when is_bitvector avals && (List.length avals <= 64 || C.ignore_64) -> - begin - let bitstring = List.map value_of_aval_bit avals in - let len = List.length avals in - match destruct_bitvector ctx.tc_env typ with - | Some (_, Ord_aux (Ord_inc, _)) -> - [], V_lit (VL_bits (bitstring, false), CT_fbits (len, false)), [] - | Some (_, Ord_aux (Ord_dec, _)) -> - [], V_lit (VL_bits (bitstring, true), CT_fbits (len, true)), [] - | Some _ -> - raise (Reporting.err_general l "Encountered order polymorphic bitvector literal") - | None -> - raise (Reporting.err_general l "Encountered vector literal without vector type") - end - - (* Convert a bitvector literal that is larger than 64-bits to a - variable size bitvector, converting it in 64-bit chunks. *) - | AV_vector (avals, typ) when is_bitvector avals -> - let len = List.length avals in - let bitstring avals = VL_bits (List.map value_of_aval_bit avals, true) in - let first_chunk = bitstring (Util.take (len mod 64) avals) in - let chunks = Util.drop (len mod 64) avals |> chunkify 64 |> List.map bitstring in - let gs = ngensym () in - [iinit l (CT_lbits true) gs (V_lit (first_chunk, CT_fbits (len mod 64, true)))] - @ List.map (fun chunk -> ifuncall l (CL_id (gs, CT_lbits true)) - (mk_id "append_64", []) - [V_id (gs, CT_lbits true); V_lit (chunk, CT_fbits (64, true))]) chunks, - V_id (gs, CT_lbits true), - [iclear (CT_lbits true) gs] - - (* If we have a bitvector value, that isn't a literal then we need to set bits individually. *) - | AV_vector (avals, Typ_aux (Typ_app (id, [_; A_aux (A_order ord, _)]), _)) - when string_of_id id = "bitvector" && List.length avals <= 64 -> - let len = List.length avals in - let direction = match ord with - | Ord_aux (Ord_inc, _) -> false - | Ord_aux (Ord_dec, _) -> true - | Ord_aux (Ord_var _, _) -> raise (Reporting.err_general l "Polymorphic vector direction found") - in - let gs = ngensym () in - let ctyp = CT_fbits (len, direction) in - let mask i = VL_bits (Util.list_init (63 - i) (fun _ -> Sail2_values.B0) @ [Sail2_values.B1] @ Util.list_init i (fun _ -> Sail2_values.B0), direction) in - let aval_mask i aval = - let setup, cval, cleanup = compile_aval l ctx aval in - match cval with - | V_lit (VL_bit Sail2_values.B0, _) -> [] - | V_lit (VL_bit Sail2_values.B1, _) -> - [icopy l (CL_id (gs, ctyp)) (V_call (Bvor, [V_id (gs, ctyp); V_lit (mask i, ctyp)]))] - | _ -> + ( [iinit l CT_lint gs (V_lit (VL_string (Big_int.to_string n), CT_string))], + V_id (gs, CT_lint), + [iclear CT_lint gs] + ) + | AV_lit (L_aux (L_zero, _), _) -> ([], V_lit (VL_bit Sail2_values.B0, CT_bit), []) + | AV_lit (L_aux (L_one, _), _) -> ([], V_lit (VL_bit Sail2_values.B1, CT_bit), []) + | AV_lit (L_aux (L_true, _), _) -> ([], V_lit (VL_bool true, CT_bool), []) + | AV_lit (L_aux (L_false, _), _) -> ([], V_lit (VL_bool false, CT_bool), []) + | AV_lit (L_aux (L_real str, _), _) -> + if C.use_real then ([], V_lit (VL_real str, CT_real), []) + else ( + let gs = ngensym () in + ([iinit l CT_real gs (V_lit (VL_string str, CT_string))], V_id (gs, CT_real), [iclear CT_real gs]) + ) + | AV_lit (L_aux (L_unit, _), _) -> ([], V_lit (VL_unit, CT_unit), []) + | AV_lit (L_aux (L_undef, _), typ) -> + let ctyp = ctyp_of_typ ctx typ in + ([], V_lit (VL_undefined, ctyp), []) + | AV_lit ((L_aux (_, l) as lit), _) -> + raise + (Reporting.err_general l + ("Encountered unexpected literal " ^ string_of_lit lit ^ " when converting ANF represention into IR") + ) + | AV_tuple avals -> + let elements = List.map (compile_aval l ctx) avals in + let cvals = List.map (fun (_, cval, _) -> cval) elements in + let setup = List.concat (List.map (fun (setup, _, _) -> setup) elements) in + let cleanup = List.concat (List.rev (List.map (fun (_, _, cleanup) -> cleanup) elements)) in + let tup_ctyp = CT_tup (List.map cval_ctyp cvals) in + let gs = ngensym () in + if C.tuple_value then (setup, V_tuple (cvals, tup_ctyp), cleanup) + else + ( setup + @ [idecl l tup_ctyp gs] + @ List.mapi (fun n cval -> icopy l (CL_tuple (CL_id (gs, tup_ctyp), n)) cval) cvals, + V_id (gs, CT_tup (List.map cval_ctyp cvals)), + [iclear tup_ctyp gs] @ cleanup + ) + | AV_record (fields, typ) when C.struct_value -> + let ctyp = ctyp_of_typ ctx typ in + let compile_fields (id, aval) = + let field_setup, cval, field_cleanup = compile_aval l ctx aval in + (field_setup, (id, cval), field_cleanup) + in + let field_triples = List.map compile_fields (Bindings.bindings fields) in + let setup = List.concat (List.map (fun (s, _, _) -> s) field_triples) in + let fields = List.map (fun (_, f, _) -> f) field_triples in + let cleanup = List.concat (List.map (fun (_, _, c) -> c) field_triples) in + (setup, V_struct (fields, ctyp), cleanup) + | AV_record (fields, typ) -> + let ctyp = ctyp_of_typ ctx typ in + let gs = ngensym () in + let compile_fields (id, aval) = + let field_setup, cval, field_cleanup = compile_aval l ctx aval in + field_setup @ [icopy l (CL_field (CL_id (gs, ctyp), id)) cval] @ field_cleanup + in + ( [idecl l ctyp gs] @ List.concat (List.map compile_fields (Bindings.bindings fields)), + V_id (gs, ctyp), + [iclear ctyp gs] + ) + | AV_vector ([], typ) -> + let vector_ctyp = ctyp_of_typ ctx typ in + begin + match ctyp_of_typ ctx typ with + | CT_fbits (0, ord) -> ([], V_lit (VL_bits ([], ord), vector_ctyp), []) + | _ -> + let gs = ngensym () in + ( [ + idecl l vector_ctyp gs; + iextern l + (CL_id (gs, vector_ctyp)) + (mk_id "internal_vector_init", []) + [V_lit (VL_int Big_int.zero, CT_fint 64)]; + ], + V_id (gs, vector_ctyp), + [iclear vector_ctyp gs] + ) + end + (* Convert a small bitvector to a uint64_t literal. *) + | AV_vector (avals, typ) when is_bitvector avals && (List.length avals <= 64 || C.ignore_64) -> begin + let bitstring = List.map value_of_aval_bit avals in + let len = List.length avals in + match destruct_bitvector ctx.tc_env typ with + | Some (_, Ord_aux (Ord_inc, _)) -> ([], V_lit (VL_bits (bitstring, false), CT_fbits (len, false)), []) + | Some (_, Ord_aux (Ord_dec, _)) -> ([], V_lit (VL_bits (bitstring, true), CT_fbits (len, true)), []) + | Some _ -> raise (Reporting.err_general l "Encountered order polymorphic bitvector literal") + | None -> raise (Reporting.err_general l "Encountered vector literal without vector type") + end + (* Convert a bitvector literal that is larger than 64-bits to a + variable size bitvector, converting it in 64-bit chunks. *) + | AV_vector (avals, typ) when is_bitvector avals -> + let len = List.length avals in + let bitstring avals = VL_bits (List.map value_of_aval_bit avals, true) in + let first_chunk = bitstring (Util.take (len mod 64) avals) in + let chunks = Util.drop (len mod 64) avals |> chunkify 64 |> List.map bitstring in + let gs = ngensym () in + ( [iinit l (CT_lbits true) gs (V_lit (first_chunk, CT_fbits (len mod 64, true)))] + @ List.map + (fun chunk -> + ifuncall l + (CL_id (gs, CT_lbits true)) + (mk_id "append_64", []) + [V_id (gs, CT_lbits true); V_lit (chunk, CT_fbits (64, true))] + ) + chunks, + V_id (gs, CT_lbits true), + [iclear (CT_lbits true) gs] + ) + (* If we have a bitvector value, that isn't a literal then we need to set bits individually. *) + | AV_vector (avals, Typ_aux (Typ_app (id, [_; A_aux (A_order ord, _)]), _)) + when string_of_id id = "bitvector" && List.length avals <= 64 -> + let len = List.length avals in + let direction = + match ord with + | Ord_aux (Ord_inc, _) -> false + | Ord_aux (Ord_dec, _) -> true + | Ord_aux (Ord_var _, _) -> raise (Reporting.err_general l "Polymorphic vector direction found") + in + let gs = ngensym () in + let ctyp = CT_fbits (len, direction) in + let mask i = + VL_bits + ( Util.list_init (63 - i) (fun _ -> Sail2_values.B0) + @ [Sail2_values.B1] + @ Util.list_init i (fun _ -> Sail2_values.B0), + direction + ) + in + let aval_mask i aval = + let setup, cval, cleanup = compile_aval l ctx aval in + match cval with + | V_lit (VL_bit Sail2_values.B0, _) -> [] + | V_lit (VL_bit Sail2_values.B1, _) -> + [icopy l (CL_id (gs, ctyp)) (V_call (Bvor, [V_id (gs, ctyp); V_lit (mask i, ctyp)]))] + | _ -> + setup + @ [ + iextern l + (CL_id (gs, ctyp)) + (mk_id "update_fbits", []) + [V_id (gs, ctyp); V_lit (VL_int (Big_int.of_int i), CT_constant (Big_int.of_int i)); cval]; + ] + @ cleanup + in + ( [ + idecl l ctyp gs; + icopy l (CL_id (gs, ctyp)) (V_lit (VL_bits (Util.list_init len (fun _ -> Sail2_values.B0), direction), ctyp)); + ] + @ List.concat (List.mapi aval_mask (List.rev avals)), + V_id (gs, ctyp), + [] + ) + (* Compiling a vector literal that isn't a bitvector *) + | AV_vector (avals, Typ_aux (Typ_app (id, [_; A_aux (A_order ord, _); A_aux (A_typ typ, _)]), _)) + when string_of_id id = "vector" -> + let len = List.length avals in + let direction = + match ord with + | Ord_aux (Ord_inc, _) -> false + | Ord_aux (Ord_dec, _) -> true + | Ord_aux (Ord_var _, _) -> raise (Reporting.err_general l "Polymorphic vector direction found") + in + let elem_ctyp = ctyp_of_typ ctx typ in + let vector_ctyp = CT_vector (direction, elem_ctyp) in + let gs = ngensym () in + let aval_set i aval = + let setup, cval, cleanup = compile_aval l ctx aval in + let cval, conversion_setup, conversion_cleanup = + if ctyp_equal (cval_ctyp cval) elem_ctyp then (cval, [], []) + else ( + let gs = ngensym () in + (V_id (gs, elem_ctyp), [iinit l elem_ctyp gs cval], [iclear elem_ctyp gs]) + ) + in + setup @ conversion_setup + @ [ + iextern l + (CL_id (gs, vector_ctyp)) + (mk_id "internal_vector_update", []) + [V_id (gs, vector_ctyp); V_lit (VL_int (Big_int.of_int i), CT_fint 64); cval]; + ] + @ conversion_cleanup @ cleanup + in + ( [ + idecl l vector_ctyp gs; + iextern l + (CL_id (gs, vector_ctyp)) + (mk_id "internal_vector_init", []) + [V_lit (VL_int (Big_int.of_int len), CT_fint 64)]; + ] + @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)), + V_id (gs, vector_ctyp), + [iclear vector_ctyp gs] + ) + | AV_vector _ as aval -> + raise + (Reporting.err_general l + ("Have AVL_vector: " ^ Pretty_print_sail.to_string (pp_aval aval) ^ " which is not a vector type") + ) + | AV_list (avals, Typ_aux (typ, _)) -> + let ctyp = + match typ with + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> ctyp_suprema (ctyp_of_typ ctx typ) + | _ -> raise (Reporting.err_general l "Invalid list type") + in + let gs = ngensym () in + let mk_cons aval = + let setup, cval, cleanup = compile_aval l ctx aval in setup - @ [iextern l (CL_id (gs, ctyp)) - (mk_id "update_fbits", []) - [V_id (gs, ctyp); V_lit (VL_int (Big_int.of_int i), CT_constant (Big_int.of_int i)); cval]] + @ [iextern l (CL_id (gs, CT_list ctyp)) (mk_id "sail_cons", [ctyp]) [cval; V_id (gs, CT_list ctyp)]] @ cleanup - in - [idecl l ctyp gs; - icopy l (CL_id (gs, ctyp)) (V_lit (VL_bits (Util.list_init len (fun _ -> Sail2_values.B0), direction), ctyp))] - @ List.concat (List.mapi aval_mask (List.rev avals)), - V_id (gs, ctyp), - [] - - (* Compiling a vector literal that isn't a bitvector *) - | AV_vector (avals, Typ_aux (Typ_app (id, [_; A_aux (A_order ord, _); A_aux (A_typ typ, _)]), _)) - when string_of_id id = "vector" -> - let len = List.length avals in - let direction = match ord with - | Ord_aux (Ord_inc, _) -> false - | Ord_aux (Ord_dec, _) -> true - | Ord_aux (Ord_var _, _) -> raise (Reporting.err_general l "Polymorphic vector direction found") - in - let elem_ctyp = ctyp_of_typ ctx typ in - let vector_ctyp = CT_vector (direction, elem_ctyp) in - let gs = ngensym () in - let aval_set i aval = - let setup, cval, cleanup = compile_aval l ctx aval in - let cval, conversion_setup, conversion_cleanup = - if ctyp_equal (cval_ctyp cval) elem_ctyp then ( - cval, [], [] - ) else ( - let gs = ngensym () in - V_id (gs, elem_ctyp), - [iinit l elem_ctyp gs cval], - [iclear elem_ctyp gs] - ) in - setup - @ conversion_setup - @ [iextern l (CL_id (gs, vector_ctyp)) - (mk_id "internal_vector_update", []) - [V_id (gs, vector_ctyp); V_lit (VL_int (Big_int.of_int i), CT_fint 64); cval]] - @ conversion_cleanup - @ cleanup - in - [idecl l vector_ctyp gs; - iextern l (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init", []) [V_lit (VL_int (Big_int.of_int len), CT_fint 64)]] - @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)), - V_id (gs, vector_ctyp), - [iclear vector_ctyp gs] - - | AV_vector _ as aval -> - raise (Reporting.err_general l ("Have AVL_vector: " ^ Pretty_print_sail.to_string (pp_aval aval) ^ " which is not a vector type")) - - | AV_list (avals, Typ_aux (typ, _)) -> - let ctyp = match typ with - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> ctyp_suprema (ctyp_of_typ ctx typ) - | _ -> raise (Reporting.err_general l "Invalid list type") - in - let gs = ngensym () in - let mk_cons aval = - let setup, cval, cleanup = compile_aval l ctx aval in - setup @ [iextern l (CL_id (gs, CT_list ctyp)) (mk_id "sail_cons", [ctyp]) [cval; V_id (gs, CT_list ctyp)]] @ cleanup - in - [idecl l (CT_list ctyp) gs] - @ List.concat (List.map mk_cons (List.rev avals)), - V_id (gs, CT_list ctyp), - [iclear (CT_list ctyp) gs] - -(* + in + ( [idecl l (CT_list ctyp) gs] @ List.concat (List.map mk_cons (List.rev avals)), + V_id (gs, CT_list ctyp), + [iclear (CT_list ctyp) gs] + ) + + (* let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = let call () = let setup = ref [] in @@ -623,704 +617,723 @@ let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = else call () *) -let compile_funcall l ctx id args = - let setup = ref [] in - let cleanup = ref [] in + let compile_funcall l ctx id args = + let setup = ref [] in + let cleanup = ref [] in - let quant, Typ_aux (fn_typ, _) = - (* If we can't find a function in local_env, fall back to the - global env - this happens when representing assertions, exit, - etc as functions in the IR. *) - try Env.get_val_spec id ctx.local_env with Type_error _ -> Env.get_val_spec id ctx.tc_env - in - let arg_typs, ret_typ = match fn_typ with - | Typ_fn (arg_typs, ret_typ) -> arg_typs, ret_typ - | _ -> assert false - in - let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.tc_env } in - let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - - assert (List.length arg_ctyps = List.length args); - - let instantiation = ref KBindings.empty in - - let setup_arg ctyp aval = - let arg_setup, cval, arg_cleanup = compile_aval l ctx aval in - instantiation := KBindings.union merge_unifiers (ctyp_unify l ctyp (cval_ctyp cval)) !instantiation; - setup := List.rev arg_setup @ !setup; - cleanup := arg_cleanup @ !cleanup; - cval - in + let quant, Typ_aux (fn_typ, _) = + (* If we can't find a function in local_env, fall back to the + global env - this happens when representing assertions, exit, + etc as functions in the IR. *) + try Env.get_val_spec id ctx.local_env with Type_error _ -> Env.get_val_spec id ctx.tc_env + in + let arg_typs, ret_typ = match fn_typ with Typ_fn (arg_typs, ret_typ) -> (arg_typs, ret_typ) | _ -> assert false in + let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.tc_env } in + let arg_ctyps, ret_ctyp = (List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ) in + + assert (List.length arg_ctyps = List.length args); + + let instantiation = ref KBindings.empty in + + let setup_arg ctyp aval = + let arg_setup, cval, arg_cleanup = compile_aval l ctx aval in + instantiation := KBindings.union merge_unifiers (ctyp_unify l ctyp (cval_ctyp cval)) !instantiation; + setup := List.rev arg_setup @ !setup; + cleanup := arg_cleanup @ !cleanup; + cval + in + + let setup_args = List.map2 setup_arg arg_ctyps args in - let setup_args = List.map2 setup_arg arg_ctyps args in - - List.rev !setup, - begin fun clexp -> - let instantiation = KBindings.union merge_unifiers (ctyp_unify l ret_ctyp (clexp_ctyp clexp)) !instantiation in - ifuncall l clexp (id, KBindings.bindings instantiation |> List.map snd) setup_args - (* iblock1 (optimize_call l ctx clexp (id, KBindings.bindings unifiers |> List.map snd) setup_args arg_ctyps ret_ctyp) *) - end, - !cleanup - -let rec apat_ctyp ctx (AP_aux (apat, env, _)) = - let ctx = { ctx with local_env = env } in - match apat with - | AP_tuple apats -> CT_tup (List.map (apat_ctyp ctx) apats) - | AP_global (_, typ) -> ctyp_of_typ ctx typ - | AP_cons (apat, _) -> CT_list (ctyp_suprema (apat_ctyp ctx apat)) - | AP_wild typ | AP_nil typ | AP_id (_, typ) -> ctyp_of_typ ctx typ - | AP_app (_, _, typ) -> ctyp_of_typ ctx typ - | AP_as (_, _, typ) -> ctyp_of_typ ctx typ - -let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = - let ctx = { ctx with local_env = env } in - let ctyp = cval_ctyp cval in - match apat_aux with - | AP_global (pid, typ) -> - let global_ctyp = ctyp_of_typ ctx typ in - [], [icopy l (CL_id (global pid, global_ctyp)) cval], [], ctx - - | AP_id (pid, _) when is_ct_enum ctyp -> - begin match Env.lookup_id pid ctx.tc_env with - | Unbound _ -> [], [idecl l ctyp (name pid); icopy l (CL_id (name pid, ctyp)) cval], [], ctx - | _ -> [ijump l (V_call (Neq, [V_id (name pid, ctyp); cval])) case_label], [], [], ctx - end - - | AP_id (pid, typ) -> - let id_ctyp = ctyp_of_typ ctx typ in - let ctx = { ctx with locals = Bindings.add pid (Immutable, id_ctyp) ctx.locals } in - [], [idecl l id_ctyp (name pid); icopy l (CL_id (name pid, id_ctyp)) cval], [iclear id_ctyp (name pid)], ctx - - | AP_as (apat, id, typ) -> - let id_ctyp = ctyp_of_typ ctx typ in - let pre, instrs, cleanup, ctx = compile_match ctx apat cval case_label in - let ctx = { ctx with locals = Bindings.add id (Immutable, id_ctyp) ctx.locals } in - pre, instrs @ [idecl l id_ctyp (name id); icopy l (CL_id (name id, id_ctyp)) cval], iclear id_ctyp (name id) :: cleanup, ctx - - | AP_tuple apats -> - begin - let get_tup n = V_tuple_member (cval, List.length apats, n) in - let fold (pre, instrs, cleanup, n, ctx) apat ctyp = - let pre', instrs', cleanup', ctx = compile_match ctx apat (get_tup n) case_label in - pre @ pre', instrs @ instrs', cleanup' @ cleanup, n + 1, ctx - in - match ctyp with - | CT_tup ctyps -> - let pre, instrs, cleanup, _, ctx = List.fold_left2 fold ([], [], [], 0, ctx) apats ctyps in - pre, instrs, cleanup, ctx - | _ -> failwith ("AP_tuple with ctyp " ^ string_of_ctyp ctyp) - end - - | AP_app (ctor, apat, variant_typ) -> - begin match ctyp with - | CT_variant (var_id, ctors) -> - let pat_ctyp = apat_ctyp ctx apat in - (* These should really be the same, something has gone wrong if they are not. *) - if not (ctyp_equal (cval_ctyp cval) (ctyp_of_typ ctx variant_typ)) then - raise (Reporting.err_general l (Printf.sprintf "When compiling constructor pattern, %s should have the same type as %s" - (string_of_ctyp (cval_ctyp cval)) (string_of_ctyp (ctyp_of_typ ctx variant_typ)))) - else (); - let unifiers, ctor_ctyp = - let generic_ctors = Bindings.find var_id ctx.variants |> snd |> Bindings.bindings in - let unifiers = ctyp_unify l (CT_variant (var_id, generic_ctors)) (cval_ctyp cval) |> KBindings.bindings |> List.map snd in - let is_poly_ctor = List.exists (fun (id, ctyp) -> Id.compare id ctor = 0 && is_polymorphic ctyp) generic_ctors in - unifiers, if is_poly_ctor then ctyp_suprema pat_ctyp else pat_ctyp + ( List.rev !setup, + begin + fun clexp -> + let instantiation = + KBindings.union merge_unifiers (ctyp_unify l ret_ctyp (clexp_ctyp clexp)) !instantiation + in + ifuncall l clexp (id, KBindings.bindings instantiation |> List.map snd) setup_args + (* iblock1 (optimize_call l ctx clexp (id, KBindings.bindings unifiers |> List.map snd) setup_args arg_ctyps ret_ctyp) *) + end, + !cleanup + ) + + let rec apat_ctyp ctx (AP_aux (apat, env, _)) = + let ctx = { ctx with local_env = env } in + match apat with + | AP_tuple apats -> CT_tup (List.map (apat_ctyp ctx) apats) + | AP_global (_, typ) -> ctyp_of_typ ctx typ + | AP_cons (apat, _) -> CT_list (ctyp_suprema (apat_ctyp ctx apat)) + | AP_wild typ | AP_nil typ | AP_id (_, typ) -> ctyp_of_typ ctx typ + | AP_app (_, _, typ) -> ctyp_of_typ ctx typ + | AP_as (_, _, typ) -> ctyp_of_typ ctx typ + + let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = + let ctx = { ctx with local_env = env } in + let ctyp = cval_ctyp cval in + match apat_aux with + | AP_global (pid, typ) -> + let global_ctyp = ctyp_of_typ ctx typ in + ([], [icopy l (CL_id (global pid, global_ctyp)) cval], [], ctx) + | AP_id (pid, _) when is_ct_enum ctyp -> begin + match Env.lookup_id pid ctx.tc_env with + | Unbound _ -> ([], [idecl l ctyp (name pid); icopy l (CL_id (name pid, ctyp)) cval], [], ctx) + | _ -> ([ijump l (V_call (Neq, [V_id (name pid, ctyp); cval])) case_label], [], [], ctx) + end + | AP_id (pid, typ) -> + let id_ctyp = ctyp_of_typ ctx typ in + let ctx = { ctx with locals = Bindings.add pid (Immutable, id_ctyp) ctx.locals } in + ([], [idecl l id_ctyp (name pid); icopy l (CL_id (name pid, id_ctyp)) cval], [iclear id_ctyp (name pid)], ctx) + | AP_as (apat, id, typ) -> + let id_ctyp = ctyp_of_typ ctx typ in + let pre, instrs, cleanup, ctx = compile_match ctx apat cval case_label in + let ctx = { ctx with locals = Bindings.add id (Immutable, id_ctyp) ctx.locals } in + ( pre, + instrs @ [idecl l id_ctyp (name id); icopy l (CL_id (name id, id_ctyp)) cval], + iclear id_ctyp (name id) :: cleanup, + ctx + ) + | AP_tuple apats -> begin + let get_tup n = V_tuple_member (cval, List.length apats, n) in + let fold (pre, instrs, cleanup, n, ctx) apat ctyp = + let pre', instrs', cleanup', ctx = compile_match ctx apat (get_tup n) case_label in + (pre @ pre', instrs @ instrs', cleanup' @ cleanup, n + 1, ctx) in - let pre, instrs, cleanup, ctx = compile_match ctx apat (V_ctor_unwrap (cval, (ctor, unifiers), ctor_ctyp)) case_label in - [ijump l (V_ctor_kind (cval, (ctor, unifiers), pat_ctyp)) case_label] @ pre, - instrs, - cleanup, - ctx - | ctyp -> - raise (Reporting.err_general l (Printf.sprintf "Variant constructor %s : %s matching against non-variant type %s : %s" - (string_of_id ctor) - (string_of_typ variant_typ) - (string_of_cval cval) - (string_of_ctyp ctyp))) - end - - | AP_wild _ -> [], [], [], ctx - - | AP_cons (hd_apat, tl_apat) -> - begin match ctyp with - | CT_list ctyp -> - let hd_pre, hd_setup, hd_cleanup, ctx = compile_match ctx hd_apat (V_call (List_hd, [cval])) case_label in - let tl_pre, tl_setup, tl_cleanup, ctx = compile_match ctx tl_apat (V_call (List_tl, [cval])) case_label in - [ijump l (V_call (Eq, [cval; V_lit (VL_empty_list, CT_list ctyp)])) case_label] @ hd_pre @ tl_pre, hd_setup @ tl_setup, tl_cleanup @ hd_cleanup, ctx - | _ -> - raise (Reporting.err_general l "Tried to pattern match cons on non list type") - end - - | AP_nil _ -> [ijump l (V_call (Neq, [cval; V_lit (VL_empty_list, ctyp)])) case_label], [], [], ctx - -let unit_cval = V_lit (VL_unit, CT_unit) - -let rec compile_alexp ctx alexp = - match alexp with - | AL_id (id, typ) -> - let ctyp = match Bindings.find_opt id ctx.locals with - | Some (_, ctyp) -> ctyp - | None -> ctyp_of_typ ctx typ - in - CL_id (name_or_global ctx id, ctyp) - | AL_addr (id, typ) -> - let ctyp = match Bindings.find_opt id ctx.locals with - | Some (_, ctyp) -> ctyp - | None -> ctyp_of_typ ctx typ - in - CL_addr (CL_id (name_or_global ctx id, ctyp)) - | AL_field (alexp, field_id) -> - CL_field (compile_alexp ctx alexp, field_id) - -let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = - let ctx = { ctx with local_env = env } in - match aexp_aux with - | AE_let (mut, id, binding_typ, binding, (AE_aux (_, body_env, _) as body), body_typ) -> - let binding_ctyp = ctyp_of_typ { ctx with local_env = body_env } binding_typ in - let setup, call, cleanup = compile_aexp ctx binding in - let letb_setup, letb_cleanup = - [idecl l binding_ctyp (name id); - iblock1 (setup @ [call (CL_id (name id, binding_ctyp))] @ cleanup)], - [iclear binding_ctyp (name id)] - in - let ctx = { ctx with locals = Bindings.add id (mut, binding_ctyp) ctx.locals } in - let setup, call, cleanup = compile_aexp ctx body in - letb_setup @ setup, call, cleanup @ letb_cleanup - - | AE_app (id, vs, _) -> - compile_funcall l ctx id vs - - | AE_val aval -> - let setup, cval, cleanup = compile_aval l ctx aval in - setup, (fun clexp -> icopy l clexp cval), cleanup - - (* Compile case statements *) - | AE_match (aval, cases, typ) -> - let ctyp = ctyp_of_typ ctx typ in - let aval_setup, cval, aval_cleanup = compile_aval l ctx aval in - (* Get the number of cases, because we don't want to check branch - coverage for matches with only a single case. *) - let num_cases = List.length cases in - let branch_id, on_reached = coverage_branch_reached l in - let case_return_id = ngensym () in - let finish_match_label = label "finish_match_" in - let compile_case (apat, guard, body) = - let case_label = label "case_" in - if is_dead_aexp body then ( - [ilabel case_label] - ) else ( - let trivial_guard = match guard with - | AE_aux (AE_val (AV_lit (L_aux (L_true, _), _)), _, _) - | AE_aux (AE_val (AV_cval (V_lit (VL_bool true, CT_bool), _)), _, _) -> true - | _ -> false - in - let pre_destructure, destructure, destructure_cleanup, ctx = compile_match ctx apat cval case_label in - let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in - let body_setup, body_call, body_cleanup = compile_aexp ctx body in - let gs = ngensym () in - let case_instrs = - pre_destructure - @ destructure - @ (if not trivial_guard then - guard_setup @ [idecl l CT_bool gs; guard_call (CL_id (gs, CT_bool))] @ guard_cleanup - @ [iif l (V_call (Bnot, [V_id (gs, CT_bool)])) (destructure_cleanup @ [igoto case_label]) [] CT_unit] - else []) - @ (if num_cases > 1 then coverage_branch_taken branch_id body else []) - @ body_setup - @ [body_call (CL_id (case_return_id, ctyp))] @ body_cleanup @ destructure_cleanup - @ [igoto finish_match_label] - in - [iblock case_instrs; ilabel case_label] - ) - in - aval_setup - @ (if num_cases > 1 then on_reached else []) - @ [idecl l ctyp case_return_id] - @ List.concat (List.map compile_case cases) - @ [imatch_failure l] - @ [ilabel finish_match_label], - (fun clexp -> icopy l clexp (V_id (case_return_id, ctyp))), - [iclear ctyp case_return_id] - @ aval_cleanup - - (* Compile try statement *) - | AE_try (aexp, cases, typ) -> - let ctyp = ctyp_of_typ ctx typ in - let aexp_setup, aexp_call, aexp_cleanup = compile_aexp ctx aexp in - let try_return_id = ngensym () in - let post_exception_handlers_label = label "post_exception_handlers_" in - let compile_case (apat, guard, body) = - let trivial_guard = match guard with - | AE_aux (AE_val (AV_lit (L_aux (L_true, _), _)), _, _) - | AE_aux (AE_val (AV_cval (V_lit (VL_bool true, CT_bool), _)), _, _) -> true - | _ -> false - in - let try_label = label "try_" in - let exn_cval = V_id (current_exception, ctyp_of_typ ctx (mk_typ (Typ_id (mk_id "exception")))) in - let pre_destructure, destructure, destructure_cleanup, ctx = compile_match ctx apat exn_cval try_label in - let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in - let body_setup, body_call, body_cleanup = compile_aexp ctx body in - let gs = ngensym () in - let case_instrs = - pre_destructure - @ destructure - @ (if not trivial_guard then - guard_setup @ [idecl l CT_bool gs; guard_call (CL_id (gs, CT_bool))] @ guard_cleanup - @ [ijump l (V_call (Bnot, [V_id (gs, CT_bool)])) try_label] - else []) - @ body_setup @ [body_call (CL_id (try_return_id, ctyp))] @ body_cleanup @ destructure_cleanup - @ [igoto post_exception_handlers_label] - in - [iblock case_instrs; ilabel try_label] - in - assert (ctyp_equal ctyp (ctyp_of_typ ctx typ)); - [idecl l ctyp try_return_id; - itry_block l (aexp_setup @ [aexp_call (CL_id (try_return_id, ctyp))] @ aexp_cleanup); - ijump l (V_call (Bnot, [V_id (have_exception, CT_bool)])) post_exception_handlers_label; - icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool false, CT_bool))] - @ List.concat (List.map compile_case cases) - @ [(* fallthrough *) - icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool true, CT_bool)); - ilabel post_exception_handlers_label], - (fun clexp -> icopy l clexp (V_id (try_return_id, ctyp))), - [] - - | AE_if (aval, then_aexp, else_aexp, if_typ) -> - if is_dead_aexp then_aexp then - compile_aexp ctx else_aexp - else if is_dead_aexp else_aexp then - compile_aexp ctx then_aexp - else - let if_ctyp = ctyp_of_typ ctx if_typ in - let branch_id, on_reached = coverage_branch_reached l in - let compile_branch aexp = - let setup, call, cleanup = compile_aexp ctx aexp in - fun clexp -> coverage_branch_taken branch_id aexp @ setup @ [call clexp] @ cleanup - in - let setup, cval, cleanup = compile_aval l ctx aval in - setup, - (fun clexp -> - append_into_block on_reached - (iif l cval - (compile_branch then_aexp clexp) - (compile_branch else_aexp clexp) - if_ctyp)), - cleanup - - (* FIXME: AE_struct_update could be AV_record_update - would reduce some copying. *) - | AE_struct_update (aval, fields, typ) -> - let ctyp = ctyp_of_typ ctx typ in - let _ctors = match ctyp with - | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors - | _ -> raise (Reporting.err_general l "Cannot perform record update for non-record type") - in - let gs = ngensym () in - let compile_fields (id, aval) = - let field_setup, cval, field_cleanup = compile_aval l ctx aval in - field_setup - @ [icopy l (CL_field (CL_id (gs, ctyp), id)) cval] - @ field_cleanup - in - let setup, cval, cleanup = compile_aval l ctx aval in - [idecl l ctyp gs] - @ setup - @ [icopy l (CL_id (gs, ctyp)) cval] - @ cleanup - @ List.concat (List.map compile_fields (Bindings.bindings fields)), - (fun clexp -> icopy l clexp (V_id (gs, ctyp))), - [iclear ctyp gs] - - | AE_short_circuit (SC_and, aval, aexp) -> - let branch_id, on_reached = coverage_branch_reached l in - let left_setup, cval, left_cleanup = compile_aval l ctx aval in - let right_setup, call, right_cleanup = compile_aexp ctx aexp in - let right_coverage = coverage_branch_taken branch_id aexp in - let gs = ngensym () in - left_setup - @ on_reached - @ [ idecl l CT_bool gs; - iif l cval - (right_coverage @ right_setup @ [call (CL_id (gs, CT_bool))] @ right_cleanup) - [icopy l (CL_id (gs, CT_bool)) (V_lit (VL_bool false, CT_bool))] - CT_bool ] - @ left_cleanup, - (fun clexp -> icopy l clexp (V_id (gs, CT_bool))), - [] - | AE_short_circuit (SC_or, aval, aexp) -> - let branch_id, on_reached = coverage_branch_reached l in - let left_setup, cval, left_cleanup = compile_aval l ctx aval in - let right_setup, call, right_cleanup = compile_aexp ctx aexp in - let right_coverage = coverage_branch_taken branch_id aexp in - let gs = ngensym () in - left_setup - @ on_reached - @ [ idecl l CT_bool gs; - iif l cval - [icopy l (CL_id (gs, CT_bool)) (V_lit (VL_bool true, CT_bool))] - (right_coverage @ right_setup @ [call (CL_id (gs, CT_bool))] @ right_cleanup) - CT_bool ] - @ left_cleanup, - (fun clexp -> icopy l clexp (V_id (gs, CT_bool))), - [] - - (* This is a faster assignment rule for updating fields of a - struct. *) - | AE_assign (AL_id (id, assign_typ), AE_aux (AE_struct_update (AV_id (rid, _), fields, typ), _, _)) - when Id.compare id rid = 0 -> - let compile_fields (field_id, aval) = - let field_setup, cval, field_cleanup = compile_aval l ctx aval in - field_setup - @ [icopy l (CL_field (CL_id (name_or_global ctx id, ctyp_of_typ ctx typ), field_id)) cval] - @ field_cleanup - in - List.concat (List.map compile_fields (Bindings.bindings fields)), - (fun clexp -> icopy l clexp unit_cval), - [] - - | AE_assign (alexp, aexp) -> - let setup, call, cleanup = compile_aexp ctx aexp in - setup @ [call (compile_alexp ctx alexp)], (fun clexp -> icopy l clexp unit_cval), cleanup - - | AE_block (aexps, aexp, _) -> - let block = compile_block ctx aexps in - let setup, call, cleanup = compile_aexp ctx aexp in - block @ setup, call, cleanup - - | AE_loop (While, cond, body) -> - let loop_start_label = label "while_" in - let loop_end_label = label "wend_" in - let cond_setup, cond_call, cond_cleanup = compile_aexp ctx cond in - let body_setup, body_call, body_cleanup = compile_aexp ctx body in - let gs = ngensym () in - let unit_gs = ngensym () in - let loop_test = V_call (Bnot, [V_id (gs, CT_bool)]) in - [idecl l CT_bool gs; idecl l CT_unit unit_gs] - @ [ilabel loop_start_label] - @ [iblock (cond_setup + match ctyp with + | CT_tup ctyps -> + let pre, instrs, cleanup, _, ctx = List.fold_left2 fold ([], [], [], 0, ctx) apats ctyps in + (pre, instrs, cleanup, ctx) + | _ -> failwith ("AP_tuple with ctyp " ^ string_of_ctyp ctyp) + end + | AP_app (ctor, apat, variant_typ) -> begin + match ctyp with + | CT_variant (var_id, ctors) -> + let pat_ctyp = apat_ctyp ctx apat in + (* These should really be the same, something has gone wrong if they are not. *) + if not (ctyp_equal (cval_ctyp cval) (ctyp_of_typ ctx variant_typ)) then + raise + (Reporting.err_general l + (Printf.sprintf "When compiling constructor pattern, %s should have the same type as %s" + (string_of_ctyp (cval_ctyp cval)) + (string_of_ctyp (ctyp_of_typ ctx variant_typ)) + ) + ) + else (); + let unifiers, ctor_ctyp = + let generic_ctors = Bindings.find var_id ctx.variants |> snd |> Bindings.bindings in + let unifiers = + ctyp_unify l (CT_variant (var_id, generic_ctors)) (cval_ctyp cval) |> KBindings.bindings |> List.map snd + in + let is_poly_ctor = + List.exists (fun (id, ctyp) -> Id.compare id ctor = 0 && is_polymorphic ctyp) generic_ctors + in + (unifiers, if is_poly_ctor then ctyp_suprema pat_ctyp else pat_ctyp) + in + let pre, instrs, cleanup, ctx = + compile_match ctx apat (V_ctor_unwrap (cval, (ctor, unifiers), ctor_ctyp)) case_label + in + ([ijump l (V_ctor_kind (cval, (ctor, unifiers), pat_ctyp)) case_label] @ pre, instrs, cleanup, ctx) + | ctyp -> + raise + (Reporting.err_general l + (Printf.sprintf "Variant constructor %s : %s matching against non-variant type %s : %s" + (string_of_id ctor) (string_of_typ variant_typ) (string_of_cval cval) (string_of_ctyp ctyp) + ) + ) + end + | AP_wild _ -> ([], [], [], ctx) + | AP_cons (hd_apat, tl_apat) -> begin + match ctyp with + | CT_list ctyp -> + let hd_pre, hd_setup, hd_cleanup, ctx = compile_match ctx hd_apat (V_call (List_hd, [cval])) case_label in + let tl_pre, tl_setup, tl_cleanup, ctx = compile_match ctx tl_apat (V_call (List_tl, [cval])) case_label in + ( [ijump l (V_call (Eq, [cval; V_lit (VL_empty_list, CT_list ctyp)])) case_label] @ hd_pre @ tl_pre, + hd_setup @ tl_setup, + tl_cleanup @ hd_cleanup, + ctx + ) + | _ -> raise (Reporting.err_general l "Tried to pattern match cons on non list type") + end + | AP_nil _ -> ([ijump l (V_call (Neq, [cval; V_lit (VL_empty_list, ctyp)])) case_label], [], [], ctx) + + let unit_cval = V_lit (VL_unit, CT_unit) + + let rec compile_alexp ctx alexp = + match alexp with + | AL_id (id, typ) -> + let ctyp = match Bindings.find_opt id ctx.locals with Some (_, ctyp) -> ctyp | None -> ctyp_of_typ ctx typ in + CL_id (name_or_global ctx id, ctyp) + | AL_addr (id, typ) -> + let ctyp = match Bindings.find_opt id ctx.locals with Some (_, ctyp) -> ctyp | None -> ctyp_of_typ ctx typ in + CL_addr (CL_id (name_or_global ctx id, ctyp)) + | AL_field (alexp, field_id) -> CL_field (compile_alexp ctx alexp, field_id) + + let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = + let ctx = { ctx with local_env = env } in + match aexp_aux with + | AE_let (mut, id, binding_typ, binding, (AE_aux (_, body_env, _) as body), body_typ) -> + let binding_ctyp = ctyp_of_typ { ctx with local_env = body_env } binding_typ in + let setup, call, cleanup = compile_aexp ctx binding in + let letb_setup, letb_cleanup = + ( [idecl l binding_ctyp (name id); iblock1 (setup @ [call (CL_id (name id, binding_ctyp))] @ cleanup)], + [iclear binding_ctyp (name id)] + ) + in + let ctx = { ctx with locals = Bindings.add id (mut, binding_ctyp) ctx.locals } in + let setup, call, cleanup = compile_aexp ctx body in + (letb_setup @ setup, call, cleanup @ letb_cleanup) + | AE_app (id, vs, _) -> compile_funcall l ctx id vs + | AE_val aval -> + let setup, cval, cleanup = compile_aval l ctx aval in + (setup, (fun clexp -> icopy l clexp cval), cleanup) + (* Compile case statements *) + | AE_match (aval, cases, typ) -> + let ctyp = ctyp_of_typ ctx typ in + let aval_setup, cval, aval_cleanup = compile_aval l ctx aval in + (* Get the number of cases, because we don't want to check branch + coverage for matches with only a single case. *) + let num_cases = List.length cases in + let branch_id, on_reached = coverage_branch_reached l in + let case_return_id = ngensym () in + let finish_match_label = label "finish_match_" in + let compile_case (apat, guard, body) = + let case_label = label "case_" in + if is_dead_aexp body then [ilabel case_label] + else ( + let trivial_guard = + match guard with + | AE_aux (AE_val (AV_lit (L_aux (L_true, _), _)), _, _) + | AE_aux (AE_val (AV_cval (V_lit (VL_bool true, CT_bool), _)), _, _) -> + true + | _ -> false + in + let pre_destructure, destructure, destructure_cleanup, ctx = compile_match ctx apat cval case_label in + let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in + let body_setup, body_call, body_cleanup = compile_aexp ctx body in + let gs = ngensym () in + let case_instrs = + pre_destructure @ destructure + @ ( if not trivial_guard then + guard_setup + @ [idecl l CT_bool gs; guard_call (CL_id (gs, CT_bool))] + @ guard_cleanup + @ [ + iif l (V_call (Bnot, [V_id (gs, CT_bool)])) (destructure_cleanup @ [igoto case_label]) [] CT_unit; + ] + else [] + ) + @ (if num_cases > 1 then coverage_branch_taken branch_id body else []) + @ body_setup + @ [body_call (CL_id (case_return_id, ctyp))] + @ body_cleanup @ destructure_cleanup + @ [igoto finish_match_label] + in + [iblock case_instrs; ilabel case_label] + ) + in + ( aval_setup + @ (if num_cases > 1 then on_reached else []) + @ [idecl l ctyp case_return_id] + @ List.concat (List.map compile_case cases) + @ [imatch_failure l] + @ [ilabel finish_match_label], + (fun clexp -> icopy l clexp (V_id (case_return_id, ctyp))), + [iclear ctyp case_return_id] @ aval_cleanup + ) + (* Compile try statement *) + | AE_try (aexp, cases, typ) -> + let ctyp = ctyp_of_typ ctx typ in + let aexp_setup, aexp_call, aexp_cleanup = compile_aexp ctx aexp in + let try_return_id = ngensym () in + let post_exception_handlers_label = label "post_exception_handlers_" in + let compile_case (apat, guard, body) = + let trivial_guard = + match guard with + | AE_aux (AE_val (AV_lit (L_aux (L_true, _), _)), _, _) + | AE_aux (AE_val (AV_cval (V_lit (VL_bool true, CT_bool), _)), _, _) -> + true + | _ -> false + in + let try_label = label "try_" in + let exn_cval = V_id (current_exception, ctyp_of_typ ctx (mk_typ (Typ_id (mk_id "exception")))) in + let pre_destructure, destructure, destructure_cleanup, ctx = compile_match ctx apat exn_cval try_label in + let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in + let body_setup, body_call, body_cleanup = compile_aexp ctx body in + let gs = ngensym () in + let case_instrs = + pre_destructure @ destructure + @ ( if not trivial_guard then + guard_setup + @ [idecl l CT_bool gs; guard_call (CL_id (gs, CT_bool))] + @ guard_cleanup + @ [ijump l (V_call (Bnot, [V_id (gs, CT_bool)])) try_label] + else [] + ) + @ body_setup + @ [body_call (CL_id (try_return_id, ctyp))] + @ body_cleanup @ destructure_cleanup + @ [igoto post_exception_handlers_label] + in + [iblock case_instrs; ilabel try_label] + in + assert (ctyp_equal ctyp (ctyp_of_typ ctx typ)); + ( [ + idecl l ctyp try_return_id; + itry_block l (aexp_setup @ [aexp_call (CL_id (try_return_id, ctyp))] @ aexp_cleanup); + ijump l (V_call (Bnot, [V_id (have_exception, CT_bool)])) post_exception_handlers_label; + icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool false, CT_bool)); + ] + @ List.concat (List.map compile_case cases) + @ [ + (* fallthrough *) + icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool true, CT_bool)); + ilabel post_exception_handlers_label; + ], + (fun clexp -> icopy l clexp (V_id (try_return_id, ctyp))), + [] + ) + | AE_if (aval, then_aexp, else_aexp, if_typ) -> + if is_dead_aexp then_aexp then compile_aexp ctx else_aexp + else if is_dead_aexp else_aexp then compile_aexp ctx then_aexp + else ( + let if_ctyp = ctyp_of_typ ctx if_typ in + let branch_id, on_reached = coverage_branch_reached l in + let compile_branch aexp = + let setup, call, cleanup = compile_aexp ctx aexp in + fun clexp -> coverage_branch_taken branch_id aexp @ setup @ [call clexp] @ cleanup + in + let setup, cval, cleanup = compile_aval l ctx aval in + ( setup, + (fun clexp -> + append_into_block on_reached + (iif l cval (compile_branch then_aexp clexp) (compile_branch else_aexp clexp) if_ctyp) + ), + cleanup + ) + ) + (* FIXME: AE_struct_update could be AV_record_update - would reduce some copying. *) + | AE_struct_update (aval, fields, typ) -> + let ctyp = ctyp_of_typ ctx typ in + let _ctors = + match ctyp with + | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors + | _ -> raise (Reporting.err_general l "Cannot perform record update for non-record type") + in + let gs = ngensym () in + let compile_fields (id, aval) = + let field_setup, cval, field_cleanup = compile_aval l ctx aval in + field_setup @ [icopy l (CL_field (CL_id (gs, ctyp), id)) cval] @ field_cleanup + in + let setup, cval, cleanup = compile_aval l ctx aval in + ( [idecl l ctyp gs] + @ setup + @ [icopy l (CL_id (gs, ctyp)) cval] + @ cleanup + @ List.concat (List.map compile_fields (Bindings.bindings fields)), + (fun clexp -> icopy l clexp (V_id (gs, ctyp))), + [iclear ctyp gs] + ) + | AE_short_circuit (SC_and, aval, aexp) -> + let branch_id, on_reached = coverage_branch_reached l in + let left_setup, cval, left_cleanup = compile_aval l ctx aval in + let right_setup, call, right_cleanup = compile_aexp ctx aexp in + let right_coverage = coverage_branch_taken branch_id aexp in + let gs = ngensym () in + ( left_setup @ on_reached + @ [ + idecl l CT_bool gs; + iif l cval + (right_coverage @ right_setup @ [call (CL_id (gs, CT_bool))] @ right_cleanup) + [icopy l (CL_id (gs, CT_bool)) (V_lit (VL_bool false, CT_bool))] + CT_bool; + ] + @ left_cleanup, + (fun clexp -> icopy l clexp (V_id (gs, CT_bool))), + [] + ) + | AE_short_circuit (SC_or, aval, aexp) -> + let branch_id, on_reached = coverage_branch_reached l in + let left_setup, cval, left_cleanup = compile_aval l ctx aval in + let right_setup, call, right_cleanup = compile_aexp ctx aexp in + let right_coverage = coverage_branch_taken branch_id aexp in + let gs = ngensym () in + ( left_setup @ on_reached + @ [ + idecl l CT_bool gs; + iif l cval + [icopy l (CL_id (gs, CT_bool)) (V_lit (VL_bool true, CT_bool))] + (right_coverage @ right_setup @ [call (CL_id (gs, CT_bool))] @ right_cleanup) + CT_bool; + ] + @ left_cleanup, + (fun clexp -> icopy l clexp (V_id (gs, CT_bool))), + [] + ) + (* This is a faster assignment rule for updating fields of a + struct. *) + | AE_assign (AL_id (id, assign_typ), AE_aux (AE_struct_update (AV_id (rid, _), fields, typ), _, _)) + when Id.compare id rid = 0 -> + let compile_fields (field_id, aval) = + let field_setup, cval, field_cleanup = compile_aval l ctx aval in + field_setup + @ [icopy l (CL_field (CL_id (name_or_global ctx id, ctyp_of_typ ctx typ), field_id)) cval] + @ field_cleanup + in + (List.concat (List.map compile_fields (Bindings.bindings fields)), (fun clexp -> icopy l clexp unit_cval), []) + | AE_assign (alexp, aexp) -> + let setup, call, cleanup = compile_aexp ctx aexp in + (setup @ [call (compile_alexp ctx alexp)], (fun clexp -> icopy l clexp unit_cval), cleanup) + | AE_block (aexps, aexp, _) -> + let block = compile_block ctx aexps in + let setup, call, cleanup = compile_aexp ctx aexp in + (block @ setup, call, cleanup) + | AE_loop (While, cond, body) -> + let loop_start_label = label "while_" in + let loop_end_label = label "wend_" in + let cond_setup, cond_call, cond_cleanup = compile_aexp ctx cond in + let body_setup, body_call, body_cleanup = compile_aexp ctx body in + let gs = ngensym () in + let unit_gs = ngensym () in + let loop_test = V_call (Bnot, [V_id (gs, CT_bool)]) in + ( [idecl l CT_bool gs; idecl l CT_unit unit_gs] + @ [ilabel loop_start_label] + @ [ + iblock + (cond_setup @ [cond_call (CL_id (gs, CT_bool))] @ cond_cleanup @ [ijump l loop_test loop_end_label] @ body_setup @ [body_call (CL_id (unit_gs, CT_unit))] @ body_cleanup - @ [igoto loop_start_label])] - @ [ilabel loop_end_label], - (fun clexp -> icopy l clexp unit_cval), - [] - - | AE_loop (Until, cond, body) -> - let loop_start_label = label "repeat_" in - let loop_end_label = label "until_" in - let cond_setup, cond_call, cond_cleanup = compile_aexp ctx cond in - let body_setup, body_call, body_cleanup = compile_aexp ctx body in - let gs = ngensym () in - let unit_gs = ngensym () in - let loop_test = V_id (gs, CT_bool) in - [idecl l CT_bool gs; idecl l CT_unit unit_gs] - @ [ilabel loop_start_label] - @ [iblock (body_setup + @ [igoto loop_start_label] + ); + ] + @ [ilabel loop_end_label], + (fun clexp -> icopy l clexp unit_cval), + [] + ) + | AE_loop (Until, cond, body) -> + let loop_start_label = label "repeat_" in + let loop_end_label = label "until_" in + let cond_setup, cond_call, cond_cleanup = compile_aexp ctx cond in + let body_setup, body_call, body_cleanup = compile_aexp ctx body in + let gs = ngensym () in + let unit_gs = ngensym () in + let loop_test = V_id (gs, CT_bool) in + ( [idecl l CT_bool gs; idecl l CT_unit unit_gs] + @ [ilabel loop_start_label] + @ [ + iblock + (body_setup @ [body_call (CL_id (unit_gs, CT_unit))] - @ body_cleanup - @ cond_setup + @ body_cleanup @ cond_setup @ [cond_call (CL_id (gs, CT_bool))] @ cond_cleanup @ [ijump l loop_test loop_end_label] - @ [igoto loop_start_label])] - @ [ilabel loop_end_label], - (fun clexp -> icopy l clexp unit_cval), - [] - - | AE_typ (aexp, typ) -> compile_aexp ctx aexp - - | AE_return (aval, typ) -> - let fn_return_ctyp = match Env.get_ret_typ env with - | Some typ -> ctyp_of_typ ctx typ - | None -> raise (Reporting.err_general l "No function return type found when compiling return statement") - in - (* Cleanup info will be re-added by fix_early_(heap/stack)_return *) - let return_setup, cval, _ = compile_aval l ctx aval in - let creturn = - if ctyp_equal fn_return_ctyp (cval_ctyp cval) then - [ireturn cval] - else - let gs = ngensym () in - [idecl l fn_return_ctyp gs; - icopy l (CL_id (gs, fn_return_ctyp)) cval; - ireturn (V_id (gs, fn_return_ctyp))] - in - return_setup @ creturn, - (fun clexp -> icomment "unreachable after return"), - [] - - | AE_throw (aval, typ) -> - (* Cleanup info will be handled by fix_exceptions *) - let throw_setup, cval, _ = compile_aval l ctx aval in - throw_setup @ [ithrow l cval], - (fun clexp -> icomment "unreachable after throw"), - [] - - | AE_exit (aval, typ) -> - let exit_setup, cval, _ = compile_aval l ctx aval in - exit_setup @ [iexit l], - (fun clexp -> icomment "unreachable after exit"), - [] - - | AE_field (aval, id, typ) -> - let setup, cval, cleanup = compile_aval l ctx aval in - let _ctyp = match cval_ctyp cval with - | CT_struct (struct_id, fields) -> - begin match Util.assoc_compare_opt Id.compare id fields with - | Some ctyp -> ctyp - | None -> - raise (Reporting.err_unreachable l __POS__ - ("Struct " ^ string_of_id struct_id ^ " does not have expected field " ^ string_of_id id - ^ "?\nFields: " ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ ": " ^ string_of_ctyp ctyp) fields)) - end - | _ -> - raise (Reporting.err_unreachable l __POS__ "Field access on non-struct type in ANF representation!") - in - setup, - (fun clexp -> icopy l clexp (V_field (cval, id))), - cleanup - - | AE_for (loop_var, loop_from, loop_to, loop_step, Ord_aux (ord, _), body) -> - (* We assume that all loop indices are safe to put in a CT_fint. *) - let ctx = { ctx with locals = Bindings.add loop_var (Immutable, CT_fint 64) ctx.locals } in - - let is_inc = match ord with - | Ord_inc -> true - | Ord_dec -> false - | Ord_var _ -> raise (Reporting.err_general l "Polymorphic loop direction in C backend") - in - - (* Loop variables *) - let from_setup, from_call, from_cleanup = compile_aexp ctx loop_from in - let from_gs = ngensym () in - let to_setup, to_call, to_cleanup = compile_aexp ctx loop_to in - let to_gs = ngensym () in - let step_setup, step_call, step_cleanup = compile_aexp ctx loop_step in - let step_gs = ngensym () in - let variable_init gs setup call cleanup = - [idecl l (CT_fint 64) gs; - iblock (setup @ [call (CL_id (gs, CT_fint 64))] @ cleanup)] - in - - let loop_start_label = label "for_start_" in - let loop_end_label = label "for_end_" in - let body_setup, body_call, body_cleanup = compile_aexp ctx body in - let body_gs = ngensym () in - - let loop_var = name loop_var in - - let loop_body prefix continue = - prefix - @ [iblock ([ijump l (V_call ((if is_inc then Igt else Ilt), [V_id (loop_var, CT_fint 64); V_id (to_gs, CT_fint 64)])) loop_end_label] - @ body_setup - @ [body_call (CL_id (body_gs, CT_unit))] - @ body_cleanup - @ [icopy l (CL_id (loop_var, (CT_fint 64))) - (V_call ((if is_inc then Iadd else Isub), [V_id (loop_var, CT_fint 64); V_id (step_gs, CT_fint 64)]))] - @ continue ())] - in - (* We can either generate an actual loop body for C, or unroll the body for SMT *) - let actual = loop_body [ilabel loop_start_label] (fun () -> [igoto loop_start_label]) in - let rec unroll max n = loop_body [] (fun () -> if n < max then unroll max (n + 1) else [imatch_failure l]) in - let body = match C.unroll_loops with Some times -> unroll times 0 | None -> actual in - - variable_init from_gs from_setup from_call from_cleanup - @ variable_init to_gs to_setup to_call to_cleanup - @ variable_init step_gs step_setup step_call step_cleanup - @ [iblock ([idecl l (CT_fint 64) loop_var; - icopy l (CL_id (loop_var, (CT_fint 64))) (V_id (from_gs, CT_fint 64)); - idecl l CT_unit body_gs] + @ [igoto loop_start_label] + ); + ] + @ [ilabel loop_end_label], + (fun clexp -> icopy l clexp unit_cval), + [] + ) + | AE_typ (aexp, typ) -> compile_aexp ctx aexp + | AE_return (aval, typ) -> + let fn_return_ctyp = + match Env.get_ret_typ env with + | Some typ -> ctyp_of_typ ctx typ + | None -> raise (Reporting.err_general l "No function return type found when compiling return statement") + in + (* Cleanup info will be re-added by fix_early_(heap/stack)_return *) + let return_setup, cval, _ = compile_aval l ctx aval in + let creturn = + if ctyp_equal fn_return_ctyp (cval_ctyp cval) then [ireturn cval] + else ( + let gs = ngensym () in + [idecl l fn_return_ctyp gs; icopy l (CL_id (gs, fn_return_ctyp)) cval; ireturn (V_id (gs, fn_return_ctyp))] + ) + in + (return_setup @ creturn, (fun clexp -> icomment "unreachable after return"), []) + | AE_throw (aval, typ) -> + (* Cleanup info will be handled by fix_exceptions *) + let throw_setup, cval, _ = compile_aval l ctx aval in + (throw_setup @ [ithrow l cval], (fun clexp -> icomment "unreachable after throw"), []) + | AE_exit (aval, typ) -> + let exit_setup, cval, _ = compile_aval l ctx aval in + (exit_setup @ [iexit l], (fun clexp -> icomment "unreachable after exit"), []) + | AE_field (aval, id, typ) -> + let setup, cval, cleanup = compile_aval l ctx aval in + let _ctyp = + match cval_ctyp cval with + | CT_struct (struct_id, fields) -> begin + match Util.assoc_compare_opt Id.compare id fields with + | Some ctyp -> ctyp + | None -> + raise + (Reporting.err_unreachable l __POS__ + ("Struct " ^ string_of_id struct_id ^ " does not have expected field " ^ string_of_id id + ^ "?\nFields: " + ^ Util.string_of_list ", " + (fun (id, ctyp) -> string_of_id id ^ ": " ^ string_of_ctyp ctyp) + fields + ) + ) + end + | _ -> raise (Reporting.err_unreachable l __POS__ "Field access on non-struct type in ANF representation!") + in + (setup, (fun clexp -> icopy l clexp (V_field (cval, id))), cleanup) + | AE_for (loop_var, loop_from, loop_to, loop_step, Ord_aux (ord, _), body) -> + (* We assume that all loop indices are safe to put in a CT_fint. *) + let ctx = { ctx with locals = Bindings.add loop_var (Immutable, CT_fint 64) ctx.locals } in + + let is_inc = + match ord with + | Ord_inc -> true + | Ord_dec -> false + | Ord_var _ -> raise (Reporting.err_general l "Polymorphic loop direction in C backend") + in + + (* Loop variables *) + let from_setup, from_call, from_cleanup = compile_aexp ctx loop_from in + let from_gs = ngensym () in + let to_setup, to_call, to_cleanup = compile_aexp ctx loop_to in + let to_gs = ngensym () in + let step_setup, step_call, step_cleanup = compile_aexp ctx loop_step in + let step_gs = ngensym () in + let variable_init gs setup call cleanup = + [idecl l (CT_fint 64) gs; iblock (setup @ [call (CL_id (gs, CT_fint 64))] @ cleanup)] + in + + let loop_start_label = label "for_start_" in + let loop_end_label = label "for_end_" in + let body_setup, body_call, body_cleanup = compile_aexp ctx body in + let body_gs = ngensym () in + + let loop_var = name loop_var in + + let loop_body prefix continue = + prefix + @ [ + iblock + ([ + ijump l + (V_call ((if is_inc then Igt else Ilt), [V_id (loop_var, CT_fint 64); V_id (to_gs, CT_fint 64)])) + loop_end_label; + ] + @ body_setup + @ [body_call (CL_id (body_gs, CT_unit))] + @ body_cleanup + @ [ + icopy l + (CL_id (loop_var, CT_fint 64)) + (V_call + ((if is_inc then Iadd else Isub), [V_id (loop_var, CT_fint 64); V_id (step_gs, CT_fint 64)]) + ); + ] + @ continue () + ); + ] + in + (* We can either generate an actual loop body for C, or unroll the body for SMT *) + let actual = loop_body [ilabel loop_start_label] (fun () -> [igoto loop_start_label]) in + let rec unroll max n = loop_body [] (fun () -> if n < max then unroll max (n + 1) else [imatch_failure l]) in + let body = match C.unroll_loops with Some times -> unroll times 0 | None -> actual in + + ( variable_init from_gs from_setup from_call from_cleanup + @ variable_init to_gs to_setup to_call to_cleanup + @ variable_init step_gs step_setup step_call step_cleanup + @ [ + iblock + ([ + idecl l (CT_fint 64) loop_var; + icopy l (CL_id (loop_var, CT_fint 64)) (V_id (from_gs, CT_fint 64)); + idecl l CT_unit body_gs; + ] @ body - @ [ilabel loop_end_label])], - (fun clexp -> icopy l clexp unit_cval), - [] - -and compile_block ctx = function - | [] -> [] - | (AE_aux (_, _, l) as exp) :: exps -> - let setup, call, cleanup = compile_aexp ctx exp in - let rest = compile_block ctx exps in - let gs = ngensym () in - iblock (setup @ [idecl l CT_unit gs; call (CL_id (gs, CT_unit))] @ cleanup) :: rest - -let fast_int = function - | CT_lint when !optimize_aarch64_fast_struct -> CT_fint 64 - | ctyp -> ctyp - -(** Compile a sail type definition into a IR one. Most of the + @ [ilabel loop_end_label] + ); + ], + (fun clexp -> icopy l clexp unit_cval), + [] + ) + + and compile_block ctx = function + | [] -> [] + | (AE_aux (_, _, l) as exp) :: exps -> + let setup, call, cleanup = compile_aexp ctx exp in + let rest = compile_block ctx exps in + let gs = ngensym () in + iblock (setup @ [idecl l CT_unit gs; call (CL_id (gs, CT_unit))] @ cleanup) :: rest + + let fast_int = function CT_lint when !optimize_aarch64_fast_struct -> CT_fint 64 | ctyp -> ctyp + + (** Compile a sail type definition into a IR one. Most of the actual work of translating the typedefs into C is done by the code generator, as it's easy to keep track of structs, tuples and unions in their sail form at this level, and leave the fiddly details of how they get mapped to C in the next stage. This function also adds details of the types it compiles to the context, ctx, which is why it returns a ctypdef * ctx pair. **) -let compile_type_def ctx (TD_aux (type_def, (l, _))) = - match type_def with - | TD_enum (id, ids, _) -> - CTD_enum (id, ids), - { ctx with enums = Bindings.add id (IdSet.of_list ids) ctx.enums } - - | TD_record (id, typq, ctors, _) -> - let record_ctx = { ctx with local_env = Env.add_typquant l typq ctx.local_env } in - let ctors = - List.fold_left (fun ctors (typ, id) -> Bindings.add id (fast_int (ctyp_of_typ record_ctx typ)) ctors) Bindings.empty ctors - in - let params = quant_kopts typq |> List.filter is_typ_kopt |> List.map kopt_kid in - CTD_struct (id, Bindings.bindings ctors), - { ctx with records = Bindings.add id (params, ctors) ctx.records } - - | TD_variant (id, typq, tus, _) -> - let compile_tu = function - | Tu_aux (Tu_ty_id (typ, id), _) -> - let ctx = { ctx with local_env = Env.add_typquant (id_loc id) typq ctx.local_env } in - ctyp_of_typ ctx typ, id - in - let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in - let params = quant_kopts typq |> List.filter is_typ_kopt |> List.map kopt_kid in - CTD_variant (id, Bindings.bindings ctus), - { ctx with variants = Bindings.add id (params, ctus) ctx.variants } - - (* Will be re-written before here, see bitfield.ml *) - | TD_bitfield _ -> - Reporting.unreachable l __POS__ "Cannot compile TD_bitfield" - - (* All type abbreviations are filtered out in compile_def *) - | TD_abbrev _ -> - Reporting.unreachable l __POS__ "Found TD_abbrev in compile_type_def" - -let generate_cleanup instrs = - let generate_cleanup' (I_aux (instr, _)) = - match instr with - | I_init (ctyp, id, cval) -> [(id, iclear ctyp id)] - | I_decl (ctyp, id) -> [(id, iclear ctyp id)] - | instr -> [] - in - let is_clear ids = function - | I_aux (I_clear (_, id), _) -> NameSet.add id ids - | _ -> ids - in - let cleaned = List.fold_left is_clear NameSet.empty instrs in - instrs - |> List.map generate_cleanup' - |> List.concat - |> List.filter (fun (id, _) -> not (NameSet.mem id cleaned)) - |> List.map snd - -let fix_exception_block ?return:(return=None) ctx instrs = - let end_block_label = label "end_block_exception_" in - let is_exception_stop (I_aux (instr, _)) = - match instr with - | I_throw _ | I_if _ | I_block _ | I_funcall _ -> true - | _ -> false - in - (* In this function 'after' is instructions after the one we've - matched on, 'before is instructions before the instruction we've - matched with, but after the previous match, and 'historic' are - all the befores from previous matches. *) - let rec rewrite_exception historic instrs = - match instr_split_at is_exception_stop instrs with - | instrs, [] -> instrs - | before, I_aux (I_block instrs, _) :: after -> - before - @ [iblock (rewrite_exception (historic @ before) instrs)] - @ rewrite_exception (historic @ before) after - | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> - let historic = historic @ before in - before - @ [iif l cval (rewrite_exception historic then_instrs) (rewrite_exception historic else_instrs) ctyp] - @ rewrite_exception historic after - | before, I_aux (I_throw cval, (_, l)) :: after -> - before - @ [icopy l (CL_id (current_exception, cval_ctyp cval)) cval; - icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool true, CT_bool))] - @ (if C.track_throw then - let loc_string = Reporting.short_loc_to_string l in - [icopy l (CL_id (throw_location, CT_string)) (V_lit (VL_string loc_string, CT_string))] - else []) - @ generate_cleanup (historic @ before) - @ [igoto end_block_label] - @ rewrite_exception (historic @ before) after - | before, (I_aux (I_funcall (x, _, f, args), (_, l)) as funcall) :: after -> - let effects = - match Bindings.find_opt (fst f) ctx.effect_info.functions with - | Some effects -> effects - (* Constructors and back-end built-in value operations might not be present *) - | None -> Effects.EffectSet.empty - in - if Effects.throws effects then - before - @ [funcall; - iif l (V_id (have_exception, CT_bool)) (generate_cleanup (historic @ before) @ [igoto end_block_label]) [] CT_unit] - @ rewrite_exception (historic @ before) after - else - before @ funcall :: rewrite_exception (historic @ before) after - | _, _ -> assert false (* unreachable *) - in - match return with - | None -> - rewrite_exception [] instrs @ [ilabel end_block_label] - | Some ctyp -> - rewrite_exception [] instrs @ [ilabel end_block_label; iundefined ctyp] - -let rec map_try_block f (I_aux (instr, aux)) = - let instr = match instr with - | I_decl _ | I_reset _ | I_init _ | I_reinit _ -> instr - | I_if (cval, instrs1, instrs2, ctyp) -> - I_if (cval, List.map (map_try_block f) instrs1, List.map (map_try_block f) instrs2, ctyp) - | I_funcall _ | I_copy _ | I_clear _ | I_throw _ | I_return _ -> instr - | I_block instrs -> I_block (List.map (map_try_block f) instrs) - | I_try_block instrs -> I_try_block (f (List.map (map_try_block f) instrs)) - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_jump _ | I_exit _ | I_undefined _ | I_end _ -> instr - in - I_aux (instr, aux) - -let fix_exception ?return:(return=None) ctx instrs = - let instrs = List.map (map_try_block (fix_exception_block ctx)) instrs in - fix_exception_block ~return:return ctx instrs - -let rec compile_arg_pat ctx label (P_aux (p_aux, (l, _)) as pat) ctyp = - match p_aux with - | P_id id -> (id, ([], [])) - | P_wild -> let gs = gensym () in (gs, ([], [])) - | P_tuple [] | P_lit (L_aux (L_unit, _)) -> let gs = gensym () in (gs, ([], [])) - | P_var (pat, _) -> compile_arg_pat ctx label pat ctyp - | P_typ (_, pat) -> compile_arg_pat ctx label pat ctyp - | _ -> - let apat = anf_pat pat in - let gs = gensym () in - let pre_destructure, destructure, cleanup, _ = compile_match ctx apat (V_id (name gs, ctyp)) label in - (gs, (pre_destructure @ destructure, cleanup)) - -let rec compile_arg_pats ctx label (P_aux (p_aux, (l, _)) as pat) ctyps = - match p_aux with - | P_typ (_, pat) -> compile_arg_pats ctx label pat ctyps - | P_tuple pats when List.length pats = List.length ctyps -> - [], List.map2 (fun pat ctyp -> compile_arg_pat ctx label pat ctyp) pats ctyps, [] - | _ when List.length ctyps = 1 -> - [], [compile_arg_pat ctx label pat (List.nth ctyps 0)], [] - - | _ -> - let arg_id, (destructure, cleanup) = compile_arg_pat ctx label pat (CT_tup ctyps) in - let new_ids = List.map (fun ctyp -> gensym (), ctyp) ctyps in - destructure - @ [idecl l (CT_tup ctyps) (name arg_id)] - @ List.mapi (fun i (id, ctyp) -> icopy l (CL_tuple (CL_id (name arg_id, CT_tup ctyps), i)) (V_id (name id, ctyp))) new_ids, - List.map (fun (id, _) -> id, ([], [])) new_ids, - [iclear (CT_tup ctyps) (name arg_id)] - @ cleanup - -let combine_destructure_cleanup xs = List.concat (List.map fst xs), List.concat (List.rev (List.map snd xs)) - -let fix_destructure l fail_label = function - | ([], cleanup) -> ([], cleanup) - | destructure, cleanup -> - let body_label = label "fundef_body_" in - (destructure @ [igoto body_label; ilabel fail_label; imatch_failure l; ilabel body_label], cleanup) - -(** Functions that have heap-allocated return types are implemented by + let compile_type_def ctx (TD_aux (type_def, (l, _))) = + match type_def with + | TD_enum (id, ids, _) -> (CTD_enum (id, ids), { ctx with enums = Bindings.add id (IdSet.of_list ids) ctx.enums }) + | TD_record (id, typq, ctors, _) -> + let record_ctx = { ctx with local_env = Env.add_typquant l typq ctx.local_env } in + let ctors = + List.fold_left + (fun ctors (typ, id) -> Bindings.add id (fast_int (ctyp_of_typ record_ctx typ)) ctors) + Bindings.empty ctors + in + let params = quant_kopts typq |> List.filter is_typ_kopt |> List.map kopt_kid in + (CTD_struct (id, Bindings.bindings ctors), { ctx with records = Bindings.add id (params, ctors) ctx.records }) + | TD_variant (id, typq, tus, _) -> + let compile_tu = function + | Tu_aux (Tu_ty_id (typ, id), _) -> + let ctx = { ctx with local_env = Env.add_typquant (id_loc id) typq ctx.local_env } in + (ctyp_of_typ ctx typ, id) + in + let ctus = + List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) + in + let params = quant_kopts typq |> List.filter is_typ_kopt |> List.map kopt_kid in + (CTD_variant (id, Bindings.bindings ctus), { ctx with variants = Bindings.add id (params, ctus) ctx.variants }) + (* Will be re-written before here, see bitfield.ml *) + | TD_bitfield _ -> Reporting.unreachable l __POS__ "Cannot compile TD_bitfield" + (* All type abbreviations are filtered out in compile_def *) + | TD_abbrev _ -> Reporting.unreachable l __POS__ "Found TD_abbrev in compile_type_def" + + let generate_cleanup instrs = + let generate_cleanup' (I_aux (instr, _)) = + match instr with + | I_init (ctyp, id, cval) -> [(id, iclear ctyp id)] + | I_decl (ctyp, id) -> [(id, iclear ctyp id)] + | instr -> [] + in + let is_clear ids = function I_aux (I_clear (_, id), _) -> NameSet.add id ids | _ -> ids in + let cleaned = List.fold_left is_clear NameSet.empty instrs in + instrs |> List.map generate_cleanup' |> List.concat + |> List.filter (fun (id, _) -> not (NameSet.mem id cleaned)) + |> List.map snd + + let fix_exception_block ?(return = None) ctx instrs = + let end_block_label = label "end_block_exception_" in + let is_exception_stop (I_aux (instr, _)) = + match instr with I_throw _ | I_if _ | I_block _ | I_funcall _ -> true | _ -> false + in + (* In this function 'after' is instructions after the one we've + matched on, 'before is instructions before the instruction we've + matched with, but after the previous match, and 'historic' are + all the befores from previous matches. *) + let rec rewrite_exception historic instrs = + match instr_split_at is_exception_stop instrs with + | instrs, [] -> instrs + | before, I_aux (I_block instrs, _) :: after -> + before @ [iblock (rewrite_exception (historic @ before) instrs)] @ rewrite_exception (historic @ before) after + | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> + let historic = historic @ before in + before + @ [iif l cval (rewrite_exception historic then_instrs) (rewrite_exception historic else_instrs) ctyp] + @ rewrite_exception historic after + | before, I_aux (I_throw cval, (_, l)) :: after -> + before + @ [ + icopy l (CL_id (current_exception, cval_ctyp cval)) cval; + icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool true, CT_bool)); + ] + @ ( if C.track_throw then ( + let loc_string = Reporting.short_loc_to_string l in + [icopy l (CL_id (throw_location, CT_string)) (V_lit (VL_string loc_string, CT_string))] + ) + else [] + ) + @ generate_cleanup (historic @ before) + @ [igoto end_block_label] + @ rewrite_exception (historic @ before) after + | before, (I_aux (I_funcall (x, _, f, args), (_, l)) as funcall) :: after -> + let effects = + match Bindings.find_opt (fst f) ctx.effect_info.functions with + | Some effects -> effects + (* Constructors and back-end built-in value operations might not be present *) + | None -> Effects.EffectSet.empty + in + if Effects.throws effects then + before + @ [ + funcall; + iif l + (V_id (have_exception, CT_bool)) + (generate_cleanup (historic @ before) @ [igoto end_block_label]) + [] CT_unit; + ] + @ rewrite_exception (historic @ before) after + else before @ (funcall :: rewrite_exception (historic @ before) after) + | _, _ -> assert false (* unreachable *) + in + match return with + | None -> rewrite_exception [] instrs @ [ilabel end_block_label] + | Some ctyp -> rewrite_exception [] instrs @ [ilabel end_block_label; iundefined ctyp] + + let rec map_try_block f (I_aux (instr, aux)) = + let instr = + match instr with + | I_decl _ | I_reset _ | I_init _ | I_reinit _ -> instr + | I_if (cval, instrs1, instrs2, ctyp) -> + I_if (cval, List.map (map_try_block f) instrs1, List.map (map_try_block f) instrs2, ctyp) + | I_funcall _ | I_copy _ | I_clear _ | I_throw _ | I_return _ -> instr + | I_block instrs -> I_block (List.map (map_try_block f) instrs) + | I_try_block instrs -> I_try_block (f (List.map (map_try_block f) instrs)) + | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_jump _ | I_exit _ | I_undefined _ | I_end _ -> instr + in + I_aux (instr, aux) + + let fix_exception ?(return = None) ctx instrs = + let instrs = List.map (map_try_block (fix_exception_block ctx)) instrs in + fix_exception_block ~return ctx instrs + + let rec compile_arg_pat ctx label (P_aux (p_aux, (l, _)) as pat) ctyp = + match p_aux with + | P_id id -> (id, ([], [])) + | P_wild -> + let gs = gensym () in + (gs, ([], [])) + | P_tuple [] | P_lit (L_aux (L_unit, _)) -> + let gs = gensym () in + (gs, ([], [])) + | P_var (pat, _) -> compile_arg_pat ctx label pat ctyp + | P_typ (_, pat) -> compile_arg_pat ctx label pat ctyp + | _ -> + let apat = anf_pat pat in + let gs = gensym () in + let pre_destructure, destructure, cleanup, _ = compile_match ctx apat (V_id (name gs, ctyp)) label in + (gs, (pre_destructure @ destructure, cleanup)) + + let rec compile_arg_pats ctx label (P_aux (p_aux, (l, _)) as pat) ctyps = + match p_aux with + | P_typ (_, pat) -> compile_arg_pats ctx label pat ctyps + | P_tuple pats when List.length pats = List.length ctyps -> + ([], List.map2 (fun pat ctyp -> compile_arg_pat ctx label pat ctyp) pats ctyps, []) + | _ when List.length ctyps = 1 -> ([], [compile_arg_pat ctx label pat (List.nth ctyps 0)], []) + | _ -> + let arg_id, (destructure, cleanup) = compile_arg_pat ctx label pat (CT_tup ctyps) in + let new_ids = List.map (fun ctyp -> (gensym (), ctyp)) ctyps in + ( destructure + @ [idecl l (CT_tup ctyps) (name arg_id)] + @ List.mapi + (fun i (id, ctyp) -> icopy l (CL_tuple (CL_id (name arg_id, CT_tup ctyps), i)) (V_id (name id, ctyp))) + new_ids, + List.map (fun (id, _) -> (id, ([], []))) new_ids, + [iclear (CT_tup ctyps) (name arg_id)] @ cleanup + ) + + let combine_destructure_cleanup xs = (List.concat (List.map fst xs), List.concat (List.rev (List.map snd xs))) + + let fix_destructure l fail_label = function + | [], cleanup -> ([], cleanup) + | destructure, cleanup -> + let body_label = label "fundef_body_" in + (destructure @ [igoto body_label; ilabel fail_label; imatch_failure l; ilabel body_label], cleanup) + + (** Functions that have heap-allocated return types are implemented by passing a pointer a location where the return value should be stored. The ANF -> Sail IR pass for expressions simply outputs an I_return instruction for any return value, so this function walks @@ -1329,756 +1342,833 @@ let fix_destructure l fail_label = function flow to cleanup heap-allocated variables correctly when a function terminates early. See the generate_cleanup function for how this is done. *) -let fix_early_return l ret instrs = - let end_function_label = label "end_function_" in - let is_return_recur (I_aux (instr, _)) = - match instr with - | I_return _ | I_undefined _ | I_if _ | I_block _ | I_try_block _ -> true - | _ -> false - in - let rec rewrite_return historic instrs = - match instr_split_at is_return_recur instrs with - | instrs, [] -> instrs - | before, I_aux (I_try_block instrs, (_, l)) :: after -> - before - @ [itry_block l (rewrite_return (historic @ before) instrs)] - @ rewrite_return (historic @ before) after - | before, I_aux (I_block instrs, _) :: after -> - before - @ [iblock (rewrite_return (historic @ before) instrs)] - @ rewrite_return (historic @ before) after - | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> - let historic = historic @ before in - before - @ [iif l cval (rewrite_return historic then_instrs) (rewrite_return historic else_instrs) ctyp] - @ rewrite_return historic after - | before, I_aux (I_return cval, (_, l)) :: after -> - let cleanup_label = label "cleanup_" in - let end_cleanup_label = label "end_cleanup_" in - before - @ [icopy l ret cval; - igoto cleanup_label] - (* This is probably dead code until cleanup_label, but we cannot be sure there are no jumps into it. *) - @ rewrite_return (historic @ before) after - @ [igoto end_cleanup_label; - ilabel cleanup_label] - @ generate_cleanup (historic @ before) - @ [igoto end_function_label; - ilabel end_cleanup_label] - | before, I_aux (I_undefined _, (_, l)) :: after -> - let cleanup_label = label "cleanup_" in - let end_cleanup_label = label "end_cleanup_" in - before - @ [igoto cleanup_label] - @ rewrite_return (historic @ before) after - @ [igoto end_cleanup_label; - ilabel cleanup_label] - @ generate_cleanup (historic @ before) - @ [igoto end_function_label; - ilabel end_cleanup_label] - | _, _ -> assert false - in - rewrite_return [] instrs - @ [ilabel end_function_label; iend l] - -(** This pass ensures that all variables created by I_decl have unique names *) -let unique_names = - let unique_counter = ref 0 in - let unique_id () = - let id = mk_id ("u#" ^ string_of_int !unique_counter) in - incr unique_counter; - name id - in - - let rec opt seen = function - | I_aux (I_decl (ctyp, id), aux) :: instrs when NameSet.mem id seen -> - let id' = unique_id () in - let instrs', seen = opt seen instrs in - I_aux (I_decl (ctyp, id'), aux) :: instrs_rename id id' instrs', seen - - | I_aux (I_decl (ctyp, id), aux) :: instrs -> - let instrs', seen = opt (NameSet.add id seen) instrs in - I_aux (I_decl (ctyp, id), aux) :: instrs', seen - - | I_aux (I_block block, aux) :: instrs -> - let block', seen = opt seen block in - let instrs', seen = opt seen instrs in - I_aux (I_block block', aux) :: instrs', seen - - | I_aux (I_try_block block, aux) :: instrs -> - let block', seen = opt seen block in - let instrs', seen = opt seen instrs in - I_aux (I_try_block block', aux) :: instrs', seen - - | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs -> - let then_instrs', seen = opt seen then_instrs in - let else_instrs', seen = opt seen else_instrs in - let instrs', seen = opt seen instrs in - I_aux (I_if (cval, then_instrs', else_instrs', ctyp), aux) :: instrs', seen - - | instr :: instrs -> - let instrs', seen = opt seen instrs in - instr :: instrs', seen - - | [] -> [], seen - in - fun instrs -> fst (opt NameSet.empty instrs) - -let letdef_count = ref 0 - -let compile_funcl ctx id pat guard exp = - (* Find the function's type. *) - let quant, Typ_aux (fn_typ, _) = - try Env.get_val_spec id ctx.local_env with Type_error _ -> Env.get_val_spec id ctx.tc_env - in - let arg_typs, ret_typ = match fn_typ with - | Typ_fn (arg_typs, ret_typ) -> arg_typs, ret_typ - | _ -> assert false - in + let fix_early_return l ret instrs = + let end_function_label = label "end_function_" in + let is_return_recur (I_aux (instr, _)) = + match instr with I_return _ | I_undefined _ | I_if _ | I_block _ | I_try_block _ -> true | _ -> false + in + let rec rewrite_return historic instrs = + match instr_split_at is_return_recur instrs with + | instrs, [] -> instrs + | before, I_aux (I_try_block instrs, (_, l)) :: after -> + before @ [itry_block l (rewrite_return (historic @ before) instrs)] @ rewrite_return (historic @ before) after + | before, I_aux (I_block instrs, _) :: after -> + before @ [iblock (rewrite_return (historic @ before) instrs)] @ rewrite_return (historic @ before) after + | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> + let historic = historic @ before in + before + @ [iif l cval (rewrite_return historic then_instrs) (rewrite_return historic else_instrs) ctyp] + @ rewrite_return historic after + | before, I_aux (I_return cval, (_, l)) :: after -> + let cleanup_label = label "cleanup_" in + let end_cleanup_label = label "end_cleanup_" in + before + @ [icopy l ret cval; igoto cleanup_label] + (* This is probably dead code until cleanup_label, but we cannot be sure there are no jumps into it. *) + @ rewrite_return (historic @ before) after + @ [igoto end_cleanup_label; ilabel cleanup_label] + @ generate_cleanup (historic @ before) + @ [igoto end_function_label; ilabel end_cleanup_label] + | before, I_aux (I_undefined _, (_, l)) :: after -> + let cleanup_label = label "cleanup_" in + let end_cleanup_label = label "end_cleanup_" in + before + @ [igoto cleanup_label] + @ rewrite_return (historic @ before) after + @ [igoto end_cleanup_label; ilabel cleanup_label] + @ generate_cleanup (historic @ before) + @ [igoto end_function_label; ilabel end_cleanup_label] + | _, _ -> assert false + in + rewrite_return [] instrs @ [ilabel end_function_label; iend l] + + (** This pass ensures that all variables created by I_decl have unique names *) + let unique_names = + let unique_counter = ref 0 in + let unique_id () = + let id = mk_id ("u#" ^ string_of_int !unique_counter) in + incr unique_counter; + name id + in - (* Handle the argument pattern. *) - let fundef_label = label "fundef_fail_" in - let orig_ctx = ctx in - (* The context must be updated before we call ctyp_of_typ on the argument types. *) - let ctx = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.tc_env } in + let rec opt seen = function + | I_aux (I_decl (ctyp, id), aux) :: instrs when NameSet.mem id seen -> + let id' = unique_id () in + let instrs', seen = opt seen instrs in + (I_aux (I_decl (ctyp, id'), aux) :: instrs_rename id id' instrs', seen) + | I_aux (I_decl (ctyp, id), aux) :: instrs -> + let instrs', seen = opt (NameSet.add id seen) instrs in + (I_aux (I_decl (ctyp, id), aux) :: instrs', seen) + | I_aux (I_block block, aux) :: instrs -> + let block', seen = opt seen block in + let instrs', seen = opt seen instrs in + (I_aux (I_block block', aux) :: instrs', seen) + | I_aux (I_try_block block, aux) :: instrs -> + let block', seen = opt seen block in + let instrs', seen = opt seen instrs in + (I_aux (I_try_block block', aux) :: instrs', seen) + | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs -> + let then_instrs', seen = opt seen then_instrs in + let else_instrs', seen = opt seen else_instrs in + let instrs', seen = opt seen instrs in + (I_aux (I_if (cval, then_instrs', else_instrs', ctyp), aux) :: instrs', seen) + | instr :: instrs -> + let instrs', seen = opt seen instrs in + (instr :: instrs', seen) + | [] -> ([], seen) + in + fun instrs -> fst (opt NameSet.empty instrs) - let arg_ctyps = List.map (ctyp_of_typ ctx) arg_typs in - let ret_ctyp = ctyp_of_typ ctx ret_typ in + let letdef_count = ref 0 - (* Compile the function arguments as patterns. *) - let arg_setup, compiled_args, arg_cleanup = compile_arg_pats ctx fundef_label pat arg_ctyps in - let ctx = - (* We need the primop analyzer to be aware of the function argument types, so put them in ctx *) - List.fold_left2 (fun ctx (id, _) ctyp -> { ctx with locals = Bindings.add id (Immutable, ctyp) ctx.locals }) ctx compiled_args arg_ctyps - in + let compile_funcl ctx id pat guard exp = + (* Find the function's type. *) + let quant, Typ_aux (fn_typ, _) = + try Env.get_val_spec id ctx.local_env with Type_error _ -> Env.get_val_spec id ctx.tc_env + in + let arg_typs, ret_typ = match fn_typ with Typ_fn (arg_typs, ret_typ) -> (arg_typs, ret_typ) | _ -> assert false in + + (* Handle the argument pattern. *) + let fundef_label = label "fundef_fail_" in + let orig_ctx = ctx in + (* The context must be updated before we call ctyp_of_typ on the argument types. *) + let ctx = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.tc_env } in + + let arg_ctyps = List.map (ctyp_of_typ ctx) arg_typs in + let ret_ctyp = ctyp_of_typ ctx ret_typ in + + (* Compile the function arguments as patterns. *) + let arg_setup, compiled_args, arg_cleanup = compile_arg_pats ctx fundef_label pat arg_ctyps in + let ctx = + (* We need the primop analyzer to be aware of the function argument types, so put them in ctx *) + List.fold_left2 + (fun ctx (id, _) ctyp -> { ctx with locals = Bindings.add id (Immutable, ctyp) ctx.locals }) + ctx compiled_args arg_ctyps + in - let guard_bindings = ref IdSet.empty in - let guard_instrs = match guard with - | Some guard -> - let (AE_aux (_, _, l) as guard) = anf guard in - guard_bindings := aexp_bindings guard; - let guard_aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) guard) in - let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard_aexp in - let guard_label = label "guard_" in - let gs = ngensym () in - [iblock ( - [idecl l CT_bool gs] - @ guard_setup - @ [guard_call (CL_id (gs, CT_bool))] - @ guard_cleanup - @ [ijump (id_loc id) (V_id (gs, CT_bool)) guard_label; - imatch_failure l; - ilabel guard_label] - )] - | None -> [] - in + let guard_bindings = ref IdSet.empty in + let guard_instrs = + match guard with + | Some guard -> + let (AE_aux (_, _, l) as guard) = anf guard in + guard_bindings := aexp_bindings guard; + let guard_aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) guard) in + let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard_aexp in + let guard_label = label "guard_" in + let gs = ngensym () in + [ + iblock + ([idecl l CT_bool gs] + @ guard_setup + @ [guard_call (CL_id (gs, CT_bool))] + @ guard_cleanup + @ [ijump (id_loc id) (V_id (gs, CT_bool)) guard_label; imatch_failure l; ilabel guard_label] + ); + ] + | None -> [] + in - (* Optimize and compile the expression to ANF. *) - let aexp = C.optimize_anf ctx (no_shadow (IdSet.union (pat_ids pat) !guard_bindings) (anf exp)) in + (* Optimize and compile the expression to ANF. *) + let aexp = C.optimize_anf ctx (no_shadow (IdSet.union (pat_ids pat) !guard_bindings) (anf exp)) in - let setup, call, cleanup = compile_aexp ctx aexp in - let destructure, destructure_cleanup = - compiled_args |> List.map snd |> combine_destructure_cleanup |> fix_destructure (id_loc id) fundef_label - in + let setup, call, cleanup = compile_aexp ctx aexp in + let destructure, destructure_cleanup = + compiled_args |> List.map snd |> combine_destructure_cleanup |> fix_destructure (id_loc id) fundef_label + in - let instrs = arg_setup @ destructure @ guard_instrs @ setup @ [call (CL_id (return, ret_ctyp))] @ cleanup @ destructure_cleanup @ arg_cleanup in - let instrs = fix_early_return (exp_loc exp) (CL_id (return, ret_ctyp)) instrs in - let instrs = unique_names instrs in - let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in - let instrs = coverage_function_entry id (exp_loc exp) @ instrs in - - [CDEF_fundef (id, None, List.map fst compiled_args, instrs)], orig_ctx - -(** Compile a Sail toplevel definition into an IR definition **) -let rec compile_def n total ctx (DEF_aux (aux, _) as def) = - match aux with - | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, _), _)]), _)) - when !opt_memo_cache -> - let digest = - strip_def def |> Pretty_print_sail.doc_def |> Pretty_print_sail.to_string |> Digest.string - in - let cachefile = Filename.concat "_sbuild" ("ccache" ^ Digest.to_hex digest) in - let cached = - if Sys.file_exists cachefile then - let in_chan = open_in cachefile in - try - let compiled = Marshal.from_channel in_chan in - close_in in_chan; - Some (compiled, ctx) - with - | _ -> close_in in_chan; None - else - None - in - begin match cached with - | Some (compiled, ctx) -> + let instrs = + arg_setup @ destructure @ guard_instrs @ setup + @ [call (CL_id (return, ret_ctyp))] + @ cleanup @ destructure_cleanup @ arg_cleanup + in + let instrs = fix_early_return (exp_loc exp) (CL_id (return, ret_ctyp)) instrs in + let instrs = unique_names instrs in + let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in + let instrs = coverage_function_entry id (exp_loc exp) @ instrs in + + ([CDEF_fundef (id, None, List.map fst compiled_args, instrs)], orig_ctx) + + (** Compile a Sail toplevel definition into an IR definition **) + let rec compile_def n total ctx (DEF_aux (aux, _) as def) = + match aux with + | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, _), _)]), _)) when !opt_memo_cache -> + let digest = strip_def def |> Pretty_print_sail.doc_def |> Pretty_print_sail.to_string |> Digest.string in + let cachefile = Filename.concat "_sbuild" ("ccache" ^ Digest.to_hex digest) in + let cached = + if Sys.file_exists cachefile then ( + let in_chan = open_in cachefile in + try + let compiled = Marshal.from_channel in_chan in + close_in in_chan; + Some (compiled, ctx) + with _ -> + close_in in_chan; + None + ) + else None + in + begin + match cached with + | Some (compiled, ctx) -> + Util.progress "Compiling " (string_of_id id) n total; + (compiled, ctx) + | None -> + let compiled, ctx = compile_def' n total ctx def in + let out_chan = open_out cachefile in + Marshal.to_channel out_chan compiled [Marshal.Closures]; + close_out out_chan; + (compiled, ctx) + end + | _ -> compile_def' n total ctx def + + and compile_def' n total ctx (DEF_aux (aux, _) as def) = + match aux with + | DEF_register (DEC_aux (DEC_reg (typ, id, None), _)) -> ([CDEF_register (id, ctyp_of_typ ctx typ, [])], ctx) + | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), _)) -> + let aexp = C.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in + let setup, call, cleanup = compile_aexp ctx aexp in + let instrs = setup @ [call (CL_id (global id, ctyp_of_typ ctx typ))] @ cleanup in + let instrs = unique_names instrs in + ([CDEF_register (id, ctyp_of_typ ctx typ, instrs)], ctx) + | DEF_val (VS_aux (VS_val_spec (_, id, ext, _), _)) -> + let quant, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in + let extern = if Env.is_extern id ctx.tc_env "c" then Some (Env.get_extern id ctx.tc_env "c") else None in + let arg_typs, ret_typ = + match fn_typ with Typ_fn (arg_typs, ret_typ) -> (arg_typs, ret_typ) | _ -> assert false + in + let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.local_env } in + let arg_ctyps, ret_ctyp = (List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ) in + ( [CDEF_val (id, extern, arg_ctyps, ret_ctyp)], + { ctx with valspecs = Bindings.add id (extern, arg_ctyps, ret_ctyp) ctx.valspecs } + ) + | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) -> Util.progress "Compiling " (string_of_id id) n total; - compiled, ctx - | None -> - let compiled, ctx = compile_def' n total ctx def in - let out_chan = open_out cachefile in - Marshal.to_channel out_chan compiled [Marshal.Closures]; - close_out out_chan; - compiled, ctx - end - - | _ -> compile_def' n total ctx def - -and compile_def' n total ctx (DEF_aux (aux, _) as def) = - match aux with - | DEF_register (DEC_aux (DEC_reg (typ, id, None), _)) -> - [CDEF_register (id, ctyp_of_typ ctx typ, [])], ctx - | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), _)) -> - let aexp = C.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in - let setup, call, cleanup = compile_aexp ctx aexp in - let instrs = setup @ [call (CL_id (global id, ctyp_of_typ ctx typ))] @ cleanup in - let instrs = unique_names instrs in - [CDEF_register (id, ctyp_of_typ ctx typ, instrs)], ctx - - | DEF_val (VS_aux (VS_val_spec (_, id, ext, _), _)) -> - let quant, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in - let extern = - if Env.is_extern id ctx.tc_env "c" then - Some (Env.get_extern id ctx.tc_env "c") - else - None - in - let arg_typs, ret_typ = match fn_typ with - | Typ_fn (arg_typs, ret_typ) -> arg_typs, ret_typ - | _ -> assert false - in - let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.local_env } in - let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - [CDEF_val (id, extern, arg_ctyps, ret_ctyp)], - { ctx with valspecs = Bindings.add id (extern, arg_ctyps, ret_ctyp) ctx.valspecs } - - | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) -> - Util.progress "Compiling " (string_of_id id) n total; - compile_funcl ctx id pat None exp - - | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_when (pat, guard, exp), _)), _)]), _)) -> - Util.progress "Compiling " (string_of_id id) n total; - compile_funcl ctx id pat (Some guard) exp - - | DEF_fundef (FD_aux (FD_function (_, _, []), (l, _))) -> - raise (Reporting.err_general l "Encountered function with no clauses") - - | DEF_fundef (FD_aux (FD_function (_, _, _ :: _ :: _), (l, _))) -> - raise (Reporting.err_general l "Encountered function with multiple clauses") - - (* All abbreviations should expanded by the typechecker, so we don't - need to translate type abbreviations into C typedefs. *) - | DEF_type (TD_aux (TD_abbrev _, _)) -> [], ctx - - | DEF_type type_def -> - let tdef, ctx = compile_type_def ctx type_def in - [CDEF_type tdef], ctx - - | DEF_let (LB_aux (LB_val (pat, exp), _)) -> - let ctyp = ctyp_of_typ ctx (typ_of_pat pat) in - let aexp = C.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in - let setup, call, cleanup = compile_aexp ctx aexp in - let apat = anf_pat ~global:true pat in - let gs = ngensym () in - let end_label = label "let_end_" in - let pre_destructure, destructure, destructure_cleanup, _ = compile_match ctx apat (V_id (gs, ctyp)) end_label in - let gs_setup, gs_cleanup = - [idecl (exp_loc exp) ctyp gs], [iclear ctyp gs] - in - let bindings = List.map (fun (id, typ) -> id, ctyp_of_typ ctx typ) (apat_globals apat) in - let n = !letdef_count in - incr letdef_count; - let instrs = - gs_setup @ setup - @ [call (CL_id (gs, ctyp))] - @ cleanup - @ pre_destructure - @ destructure - @ destructure_cleanup @ gs_cleanup - @ [ilabel end_label] - in - let instrs = unique_names instrs in - [CDEF_let (n, bindings, instrs)], - { ctx with letbinds = n :: ctx.letbinds } - - (* Only DEF_default that matters is default Order, but all order - polymorphism is specialised by this point. *) - | DEF_default _ -> [], ctx - - (* Overloading resolved by type checker *) - | DEF_overload _ -> [], ctx - - (* Only the parser and sail pretty printer care about this. *) - | DEF_fixity _ -> [], ctx - - | DEF_pragma ("abstract", id_str, _) -> [CDEF_pragma ("abstract", id_str)], ctx - - (* We just ignore any pragmas we don't want to deal with. *) - | DEF_pragma _ -> [], ctx - - (* Termination measures only needed for Coq, and other theorem prover output *) - | DEF_measure _ -> [], ctx - | DEF_loop_measures _ -> [], ctx - - | DEF_internal_mutrec fundefs -> - let defs = List.map (fun fdef -> mk_def (DEF_fundef fdef)) fundefs in - List.fold_left (fun (cdefs, ctx) def -> let cdefs', ctx = compile_def n total ctx def in (cdefs @ cdefs', ctx)) ([], ctx) defs - - (* Scattereds, mapdefs, and event related definitions should be removed by this point *) - | DEF_scattered _ | DEF_mapdef _ | DEF_outcome _ | DEF_impl _ | DEF_instantiation _ -> - Reporting.unreachable (def_loc def) __POS__ - ("Could not compile:\n" ^ Pretty_print_sail.to_string (Pretty_print_sail.doc_def (strip_def def))) - -let mangle_mono_id id ctx ctyps = - append_id id ("<" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) ctyps ^ ">") - -(* The specialized calls argument keeps track of functions we have - already specialized, so we don't accidentally specialize them twice - in a future round of specialization *) -let rec specialize_functions ?(specialized_calls=ref IdSet.empty) ctx cdefs = - let polymorphic_functions = - List.filter_map (function - | CDEF_val (id, _, param_ctyps, ret_ctyp) -> - if List.exists is_polymorphic param_ctyps || is_polymorphic ret_ctyp then - Some id - else - None - | _ -> None - ) cdefs |> IdSet.of_list - in + compile_funcl ctx id pat None exp + | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_when (pat, guard, exp), _)), _)]), _)) + -> + Util.progress "Compiling " (string_of_id id) n total; + compile_funcl ctx id pat (Some guard) exp + | DEF_fundef (FD_aux (FD_function (_, _, []), (l, _))) -> + raise (Reporting.err_general l "Encountered function with no clauses") + | DEF_fundef (FD_aux (FD_function (_, _, _ :: _ :: _), (l, _))) -> + raise (Reporting.err_general l "Encountered function with multiple clauses") + (* All abbreviations should expanded by the typechecker, so we don't + need to translate type abbreviations into C typedefs. *) + | DEF_type (TD_aux (TD_abbrev _, _)) -> ([], ctx) + | DEF_type type_def -> + let tdef, ctx = compile_type_def ctx type_def in + ([CDEF_type tdef], ctx) + | DEF_let (LB_aux (LB_val (pat, exp), _)) -> + let ctyp = ctyp_of_typ ctx (typ_of_pat pat) in + let aexp = C.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in + let setup, call, cleanup = compile_aexp ctx aexp in + let apat = anf_pat ~global:true pat in + let gs = ngensym () in + let end_label = label "let_end_" in + let pre_destructure, destructure, destructure_cleanup, _ = compile_match ctx apat (V_id (gs, ctyp)) end_label in + let gs_setup, gs_cleanup = ([idecl (exp_loc exp) ctyp gs], [iclear ctyp gs]) in + let bindings = List.map (fun (id, typ) -> (id, ctyp_of_typ ctx typ)) (apat_globals apat) in + let n = !letdef_count in + incr letdef_count; + let instrs = + gs_setup @ setup + @ [call (CL_id (gs, ctyp))] + @ cleanup @ pre_destructure @ destructure @ destructure_cleanup @ gs_cleanup + @ [ilabel end_label] + in + let instrs = unique_names instrs in + ([CDEF_let (n, bindings, instrs)], { ctx with letbinds = n :: ctx.letbinds }) + (* Only DEF_default that matters is default Order, but all order + polymorphism is specialised by this point. *) + | DEF_default _ -> ([], ctx) + (* Overloading resolved by type checker *) + | DEF_overload _ -> ([], ctx) + (* Only the parser and sail pretty printer care about this. *) + | DEF_fixity _ -> ([], ctx) + | DEF_pragma ("abstract", id_str, _) -> ([CDEF_pragma ("abstract", id_str)], ctx) + (* We just ignore any pragmas we don't want to deal with. *) + | DEF_pragma _ -> ([], ctx) + (* Termination measures only needed for Coq, and other theorem prover output *) + | DEF_measure _ -> ([], ctx) + | DEF_loop_measures _ -> ([], ctx) + | DEF_internal_mutrec fundefs -> + let defs = List.map (fun fdef -> mk_def (DEF_fundef fdef)) fundefs in + List.fold_left + (fun (cdefs, ctx) def -> + let cdefs', ctx = compile_def n total ctx def in + (cdefs @ cdefs', ctx) + ) + ([], ctx) defs + (* Scattereds, mapdefs, and event related definitions should be removed by this point *) + | DEF_scattered _ | DEF_mapdef _ | DEF_outcome _ | DEF_impl _ | DEF_instantiation _ -> + Reporting.unreachable (def_loc def) __POS__ + ("Could not compile:\n" ^ Pretty_print_sail.to_string (Pretty_print_sail.doc_def (strip_def def))) + + let mangle_mono_id id ctx ctyps = append_id id ("<" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) ctyps ^ ">") + + (* The specialized calls argument keeps track of functions we have + already specialized, so we don't accidentally specialize them twice + in a future round of specialization *) + let rec specialize_functions ?(specialized_calls = ref IdSet.empty) ctx cdefs = + let polymorphic_functions = + List.filter_map + (function + | CDEF_val (id, _, param_ctyps, ret_ctyp) -> + if List.exists is_polymorphic param_ctyps || is_polymorphic ret_ctyp then Some id else None + | _ -> None + ) + cdefs + |> IdSet.of_list + in - (* First we find all the 'monomorphic calls', places where a - polymorphic function is applied to only concrete type arguments - - At each such location we remove the type arguments and mangle the - call name using them *) - let monomorphic_calls = ref Bindings.empty in - let collect_monomorphic_calls = function - | I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), aux) - when IdSet.mem id polymorphic_functions && not (List.exists is_polymorphic ctyp_args) -> - monomorphic_calls := Bindings.update id (function None -> Some (CTListSet.singleton ctyp_args) | Some calls -> Some (CTListSet.add ctyp_args calls)) !monomorphic_calls; - I_aux (I_funcall (clexp, extern, (mangle_mono_id id ctx ctyp_args, []), args), aux) - | instr -> instr - in - let cdefs = List.rev_map (cdef_map_instr collect_monomorphic_calls) cdefs |> List.rev in - - (* Now we duplicate function defintions and type declarations for - each of the monomorphic calls we just found. *) - let spec_tyargs = ref Bindings.empty in - let rec specialize_fundefs ctx prior = function - | (CDEF_val (id, extern, param_ctyps, ret_ctyp) as orig_cdef) :: cdefs when Bindings.mem id !monomorphic_calls -> - let tyargs = List.fold_left (fun set ctyp -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty (ret_ctyp :: param_ctyps) in - spec_tyargs := Bindings.add id tyargs !spec_tyargs; - let specialized_specs = - List.filter_map (fun instantiation -> - let specialized_id = mangle_mono_id id ctx instantiation in - if not (IdSet.mem specialized_id !specialized_calls) then ( - let substs = List.fold_left2 (fun substs tyarg ty -> KBindings.add tyarg ty substs) KBindings.empty (KidSet.elements tyargs) instantiation in - let param_ctyps = List.map (subst_poly substs) param_ctyps in - let ret_ctyp = subst_poly substs ret_ctyp in - Some (CDEF_val (specialized_id, extern, param_ctyps, ret_ctyp)) - ) else - None - ) (CTListSet.elements (Bindings.find id !monomorphic_calls)) - in - let ctx = - List.fold_left (fun ctx cdef -> - match cdef with - | CDEF_val (id, _, param_ctyps, ret_ctyp) -> { ctx with valspecs = Bindings.add id (extern, param_ctyps, ret_ctyp) ctx.valspecs } - | cdef -> ctx - ) ctx specialized_specs - in - specialize_fundefs ctx (orig_cdef :: specialized_specs @ prior) cdefs - - | (CDEF_fundef (id, heap_return, params, body) as orig_cdef) :: cdefs when Bindings.mem id !monomorphic_calls -> - let tyargs = Bindings.find id !spec_tyargs in - let specialized_fundefs = - List.filter_map (fun instantiation -> - let specialized_id = mangle_mono_id id ctx instantiation in - if not (IdSet.mem specialized_id !specialized_calls) then ( - specialized_calls := IdSet.add specialized_id !specialized_calls; - let substs = List.fold_left2 (fun substs tyarg ty -> KBindings.add tyarg ty substs) KBindings.empty (KidSet.elements tyargs) instantiation in - let body = List.map (map_instr_ctyp (subst_poly substs)) body in - Some (CDEF_fundef (specialized_id, heap_return, params, body)) - ) else - None - ) (CTListSet.elements (Bindings.find id !monomorphic_calls)) - in - specialize_fundefs ctx (orig_cdef :: specialized_fundefs @ prior) cdefs - - | cdef :: cdefs -> - specialize_fundefs ctx (cdef :: prior) cdefs - | [] -> - List.rev prior, ctx - in + (* First we find all the 'monomorphic calls', places where a + polymorphic function is applied to only concrete type arguments + + At each such location we remove the type arguments and mangle the + call name using them *) + let monomorphic_calls = ref Bindings.empty in + let collect_monomorphic_calls = function + | I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), aux) + when IdSet.mem id polymorphic_functions && not (List.exists is_polymorphic ctyp_args) -> + monomorphic_calls := + Bindings.update id + (function + | None -> Some (CTListSet.singleton ctyp_args) | Some calls -> Some (CTListSet.add ctyp_args calls) + ) + !monomorphic_calls; + I_aux (I_funcall (clexp, extern, (mangle_mono_id id ctx ctyp_args, []), args), aux) + | instr -> instr + in + let cdefs = List.rev_map (cdef_map_instr collect_monomorphic_calls) cdefs |> List.rev in + + (* Now we duplicate function defintions and type declarations for + each of the monomorphic calls we just found. *) + let spec_tyargs = ref Bindings.empty in + let rec specialize_fundefs ctx prior = function + | (CDEF_val (id, extern, param_ctyps, ret_ctyp) as orig_cdef) :: cdefs when Bindings.mem id !monomorphic_calls -> + let tyargs = + List.fold_left (fun set ctyp -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty (ret_ctyp :: param_ctyps) + in + spec_tyargs := Bindings.add id tyargs !spec_tyargs; + let specialized_specs = + List.filter_map + (fun instantiation -> + let specialized_id = mangle_mono_id id ctx instantiation in + if not (IdSet.mem specialized_id !specialized_calls) then ( + let substs = + List.fold_left2 + (fun substs tyarg ty -> KBindings.add tyarg ty substs) + KBindings.empty (KidSet.elements tyargs) instantiation + in + let param_ctyps = List.map (subst_poly substs) param_ctyps in + let ret_ctyp = subst_poly substs ret_ctyp in + Some (CDEF_val (specialized_id, extern, param_ctyps, ret_ctyp)) + ) + else None + ) + (CTListSet.elements (Bindings.find id !monomorphic_calls)) + in + let ctx = + List.fold_left + (fun ctx cdef -> + match cdef with + | CDEF_val (id, _, param_ctyps, ret_ctyp) -> + { ctx with valspecs = Bindings.add id (extern, param_ctyps, ret_ctyp) ctx.valspecs } + | cdef -> ctx + ) + ctx specialized_specs + in + specialize_fundefs ctx ((orig_cdef :: specialized_specs) @ prior) cdefs + | (CDEF_fundef (id, heap_return, params, body) as orig_cdef) :: cdefs when Bindings.mem id !monomorphic_calls -> + let tyargs = Bindings.find id !spec_tyargs in + let specialized_fundefs = + List.filter_map + (fun instantiation -> + let specialized_id = mangle_mono_id id ctx instantiation in + if not (IdSet.mem specialized_id !specialized_calls) then ( + specialized_calls := IdSet.add specialized_id !specialized_calls; + let substs = + List.fold_left2 + (fun substs tyarg ty -> KBindings.add tyarg ty substs) + KBindings.empty (KidSet.elements tyargs) instantiation + in + let body = List.map (map_instr_ctyp (subst_poly substs)) body in + Some (CDEF_fundef (specialized_id, heap_return, params, body)) + ) + else None + ) + (CTListSet.elements (Bindings.find id !monomorphic_calls)) + in + specialize_fundefs ctx ((orig_cdef :: specialized_fundefs) @ prior) cdefs + | cdef :: cdefs -> specialize_fundefs ctx (cdef :: prior) cdefs + | [] -> (List.rev prior, ctx) + in - let cdefs, ctx = specialize_fundefs ctx [] cdefs in - - (* Now we want to remove any polymorphic functions that are - unreachable from any monomorphic function *) - let graph = callgraph cdefs in - let monomorphic_roots = - List.filter_map (function - | CDEF_val (id, _, param_ctyps, ret_ctyp) -> - if List.exists is_polymorphic param_ctyps || is_polymorphic ret_ctyp then - None - else - Some id - | _ -> None - ) cdefs |> IdGraphNS.of_list - in - let monomorphic_reachable = IdGraph.reachable monomorphic_roots IdGraphNS.empty graph in - let unreachable_polymorphic_functions = - IdSet.filter (fun id -> not (IdGraphNS.mem id monomorphic_reachable)) polymorphic_functions - in - let cdefs = - List.filter_map (function - | CDEF_fundef (id, _, _, _) when IdSet.mem id unreachable_polymorphic_functions -> None - | CDEF_val (id, _, _, _) when IdSet.mem id unreachable_polymorphic_functions -> None - | cdef -> Some cdef - ) cdefs - in + let cdefs, ctx = specialize_fundefs ctx [] cdefs in + + (* Now we want to remove any polymorphic functions that are + unreachable from any monomorphic function *) + let graph = callgraph cdefs in + let monomorphic_roots = + List.filter_map + (function + | CDEF_val (id, _, param_ctyps, ret_ctyp) -> + if List.exists is_polymorphic param_ctyps || is_polymorphic ret_ctyp then None else Some id + | _ -> None + ) + cdefs + |> IdGraphNS.of_list + in + let monomorphic_reachable = IdGraph.reachable monomorphic_roots IdGraphNS.empty graph in + let unreachable_polymorphic_functions = + IdSet.filter (fun id -> not (IdGraphNS.mem id monomorphic_reachable)) polymorphic_functions + in + let cdefs = + List.filter_map + (function + | CDEF_fundef (id, _, _, _) when IdSet.mem id unreachable_polymorphic_functions -> None + | CDEF_val (id, _, _, _) when IdSet.mem id unreachable_polymorphic_functions -> None + | cdef -> Some cdef + ) + cdefs + in - (* If we have removed all the polymorphic functions we are done, otherwise go again *) - if IdSet.is_empty (IdSet.diff polymorphic_functions unreachable_polymorphic_functions) then - cdefs, ctx - else - specialize_functions ~specialized_calls:specialized_calls ctx cdefs - -let map_structs_and_variants f = function - | (CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit - | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _ | CT_float _ | CT_rounding_mode) as ctyp -> ctyp - | CT_tup ctyps -> CT_tup (List.map (map_ctyp f) ctyps) - | CT_ref ctyp -> CT_ref (map_ctyp f ctyp) - | CT_vector (direction, ctyp) -> CT_vector (direction, map_ctyp f ctyp) - | CT_fvector (n, direction, ctyp) -> CT_fvector (n, direction, map_ctyp f ctyp) - | CT_list ctyp -> CT_list (map_ctyp f ctyp) - | CT_struct (id, fields) -> - begin match f (CT_struct (id, fields)) with - | CT_struct (id, fields) -> CT_struct (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) fields) - | _ -> Reporting.unreachable (id_loc id) __POS__ "Struct mapped to non-struct" - end - | CT_variant (id, ctors) -> - begin match f (CT_variant (id, ctors)) with - | CT_variant (id, ctors) -> CT_variant (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) ctors) - | _ -> Reporting.unreachable (id_loc id) __POS__ "Variant mapped to non-variant" - end - -let rec specialize_variants ctx prior = - let instantiations = ref CTListSet.empty in - let fix_variants ctx var_id = - map_structs_and_variants (function + (* If we have removed all the polymorphic functions we are done, otherwise go again *) + if IdSet.is_empty (IdSet.diff polymorphic_functions unreachable_polymorphic_functions) then (cdefs, ctx) + else specialize_functions ~specialized_calls ctx cdefs + + let map_structs_and_variants f = function + | ( CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool + | CT_real | CT_string | CT_poly _ | CT_enum _ | CT_float _ | CT_rounding_mode ) as ctyp -> + ctyp + | CT_tup ctyps -> CT_tup (List.map (map_ctyp f) ctyps) + | CT_ref ctyp -> CT_ref (map_ctyp f ctyp) + | CT_vector (direction, ctyp) -> CT_vector (direction, map_ctyp f ctyp) + | CT_fvector (n, direction, ctyp) -> CT_fvector (n, direction, map_ctyp f ctyp) + | CT_list ctyp -> CT_list (map_ctyp f ctyp) + | CT_struct (id, fields) -> begin + match f (CT_struct (id, fields)) with + | CT_struct (id, fields) -> CT_struct (id, List.map (fun (id, ctyp) -> (id, map_ctyp f ctyp)) fields) + | _ -> Reporting.unreachable (id_loc id) __POS__ "Struct mapped to non-struct" + end + | CT_variant (id, ctors) -> begin + match f (CT_variant (id, ctors)) with + | CT_variant (id, ctors) -> CT_variant (id, List.map (fun (id, ctyp) -> (id, map_ctyp f ctyp)) ctors) + | _ -> Reporting.unreachable (id_loc id) __POS__ "Variant mapped to non-variant" + end + + let rec specialize_variants ctx prior = + let instantiations = ref CTListSet.empty in + let fix_variants ctx var_id = + map_structs_and_variants (function | CT_variant (id, ctors) when Id.compare var_id id = 0 -> - let generic_ctors = Bindings.find id ctx.variants |> snd |> Bindings.bindings in - let unifiers = ctyp_unify (id_loc id) (CT_variant (id, generic_ctors)) (CT_variant (id, ctors)) |> KBindings.bindings |> List.map snd in - CT_variant (mangle_mono_id id ctx unifiers, List.map (fun (ctor_id, ctyp) -> (mangle_mono_id ctor_id ctx unifiers, ctyp)) ctors) + let generic_ctors = Bindings.find id ctx.variants |> snd |> Bindings.bindings in + let unifiers = + ctyp_unify (id_loc id) (CT_variant (id, generic_ctors)) (CT_variant (id, ctors)) + |> KBindings.bindings |> List.map snd + in + CT_variant + ( mangle_mono_id id ctx unifiers, + List.map (fun (ctor_id, ctyp) -> (mangle_mono_id ctor_id ctx unifiers, ctyp)) ctors + ) | CT_struct (id, fields) when Id.compare var_id id = 0 -> - let generic_fields = Bindings.find id ctx.records |> snd |> Bindings.bindings in - let unifiers = ctyp_unify (id_loc id) (CT_struct (id, generic_fields)) (CT_struct (id, fields)) |> KBindings.bindings |> List.map snd in - CT_struct (mangle_mono_id id ctx unifiers, List.map (fun (field_id, ctyp) -> (field_id, ctyp)) fields) + let generic_fields = Bindings.find id ctx.records |> snd |> Bindings.bindings in + let unifiers = + ctyp_unify (id_loc id) (CT_struct (id, generic_fields)) (CT_struct (id, fields)) + |> KBindings.bindings |> List.map snd + in + CT_struct (mangle_mono_id id ctx unifiers, List.map (fun (field_id, ctyp) -> (field_id, ctyp)) fields) | ctyp -> ctyp - ) - in + ) + in - let specialize_cval ctx ctor_id = - function - | V_ctor_kind (cval, (id, unifiers), pat_ctyp) when Id.compare id ctor_id = 0 -> - V_ctor_kind (cval, (mangle_mono_id id ctx unifiers, []), pat_ctyp) - | V_ctor_unwrap (cval, (id, unifiers), ctor_ctyp) when Id.compare id ctor_id = 0 -> - V_ctor_unwrap (cval, (mangle_mono_id id ctx unifiers, []), ctor_ctyp) - | cval -> cval - in + let specialize_cval ctx ctor_id = function + | V_ctor_kind (cval, (id, unifiers), pat_ctyp) when Id.compare id ctor_id = 0 -> + V_ctor_kind (cval, (mangle_mono_id id ctx unifiers, []), pat_ctyp) + | V_ctor_unwrap (cval, (id, unifiers), ctor_ctyp) when Id.compare id ctor_id = 0 -> + V_ctor_unwrap (cval, (mangle_mono_id id ctx unifiers, []), ctor_ctyp) + | cval -> cval + in - let specialize_constructor ctx var_id ctor_id ctyp = - function - | I_aux (I_funcall (clexp, extern, (id, ctyp_args), [cval]), aux) when Id.compare id ctor_id = 0 -> - instantiations := CTListSet.add ctyp_args !instantiations; - I_aux (I_funcall (clexp, extern, (mangle_mono_id id ctx ctyp_args, []), [map_cval (specialize_cval ctx ctor_id) cval]), aux) + let specialize_constructor ctx var_id ctor_id ctyp = function + | I_aux (I_funcall (clexp, extern, (id, ctyp_args), [cval]), aux) when Id.compare id ctor_id = 0 -> + instantiations := CTListSet.add ctyp_args !instantiations; + I_aux + ( I_funcall + (clexp, extern, (mangle_mono_id id ctx ctyp_args, []), [map_cval (specialize_cval ctx ctor_id) cval]), + aux + ) + | instr -> map_instr_cval (map_cval (specialize_cval ctx ctor_id)) instr + in - | instr -> map_instr_cval (map_cval (specialize_cval ctx ctor_id)) instr - in + let specialize_field ctx struct_id = function + | I_aux (I_decl (CT_struct (struct_id', fields), _), (_, l)) as instr when Id.compare struct_id struct_id' = 0 -> + let generic_fields = Bindings.find struct_id ctx.records |> snd |> Bindings.bindings in + let unifiers = + ctyp_unify l (CT_struct (struct_id, generic_fields)) (CT_struct (struct_id, fields)) + |> KBindings.bindings |> List.map snd + in + instantiations := CTListSet.add unifiers !instantiations; + instr + | instr -> instr + in + + let mangled_pragma orig_id mangled_id = + CDEF_pragma + ("mangled", Util.zencode_string (string_of_id orig_id) ^ " " ^ Util.zencode_string (string_of_id mangled_id)) + in - let specialize_field ctx struct_id = function - | I_aux (I_decl (CT_struct (struct_id', fields), _), (_, l)) as instr when Id.compare struct_id struct_id' = 0 -> - let generic_fields = Bindings.find struct_id ctx.records |> snd |> Bindings.bindings in - let unifiers = ctyp_unify l (CT_struct (struct_id, generic_fields)) (CT_struct (struct_id, fields)) |> KBindings.bindings |> List.map snd in - instantiations := CTListSet.add unifiers !instantiations; - instr - | instr -> instr - in + | CDEF_type (CTD_variant (var_id, ctors)) :: cdefs when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors -> + let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty ctors in + + List.iter + (function + | CDEF_val (id, _, ctyps, ctyp) -> + let _ = + List.map + (map_ctyp (fun ctyp -> + match ctyp with + | CT_variant (var_id', ctors) when Id.compare var_id var_id' = 0 -> + let generic_ctors = Bindings.find var_id ctx.variants |> snd |> Bindings.bindings in + let unifiers = + ctyp_unify (id_loc var_id') + (CT_variant (var_id, generic_ctors)) + (CT_variant (var_id, ctors)) + |> KBindings.bindings |> List.map snd + in + instantiations := CTListSet.add unifiers !instantiations; + ctyp + | ctyp -> ctyp + ) + ) + (ctyp :: ctyps) + in + () + | _ -> () + ) + cdefs; + + let cdefs = + List.fold_left + (fun cdefs (ctor_id, ctyp) -> + List.map (cdef_map_instr (specialize_constructor ctx var_id ctor_id ctyp)) cdefs + ) + cdefs ctors + in - let mangled_pragma orig_id mangled_id = - CDEF_pragma ("mangled", Util.zencode_string (string_of_id orig_id) ^ " " ^ Util.zencode_string (string_of_id mangled_id)) in - - function - | CDEF_type (CTD_variant (var_id, ctors)) :: cdefs when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors -> - let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty ctors in - - List.iter (function - | CDEF_val (id, _, ctyps, ctyp) -> - let _ = List.map (map_ctyp (fun ctyp -> - match ctyp with - | CT_variant (var_id', ctors) when Id.compare var_id var_id' = 0 -> - let generic_ctors = Bindings.find var_id ctx.variants |> snd |> Bindings.bindings in - let unifiers = ctyp_unify (id_loc var_id') (CT_variant (var_id, generic_ctors)) (CT_variant (var_id, ctors)) |> KBindings.bindings |> List.map snd in - instantiations := CTListSet.add unifiers !instantiations; - ctyp - | ctyp -> ctyp - )) (ctyp :: ctyps) in - () - | _ -> () - ) cdefs; - - let cdefs = - List.fold_left - (fun cdefs (ctor_id, ctyp) -> List.map (cdef_map_instr (specialize_constructor ctx var_id ctor_id ctyp)) cdefs) - cdefs - ctors - in - - let monomorphized_variants = - List.map (fun inst -> - let substs = KBindings.of_seq (List.map2 (fun x y -> x, y) (KidSet.elements typ_params) inst |> List.to_seq) in - (mangle_mono_id var_id ctx inst, - List.map (fun (ctor_id, ctyp) -> mangle_mono_id ctor_id ctx inst, fix_variants ctx var_id (subst_poly substs ctyp)) ctors) - ) (CTListSet.elements !instantiations) - in - let ctx = - List.fold_left (fun ctx (id, ctors) -> - { ctx with variants = Bindings.add id ([], Bindings.of_seq (List.to_seq ctors)) ctx.variants }) - ctx monomorphized_variants - in - let mangled_ctors = - List.map (fun (_, monomorphized_ctors) -> - List.map2 (fun (ctor_id, _) (monomorphized_id, _) -> mangled_pragma ctor_id monomorphized_id) ctors monomorphized_ctors - ) monomorphized_variants - |> List.concat - in - - let prior = List.map (cdef_map_ctyp (fix_variants ctx var_id)) prior in - let cdefs = List.map (cdef_map_ctyp (fix_variants ctx var_id)) cdefs in - - let ctx = { ctx with valspecs = Bindings.map (fun (extern, param_ctyps, ret_ctyp) -> extern, List.map (fix_variants ctx var_id) param_ctyps, fix_variants ctx var_id ret_ctyp) ctx.valspecs } in - let ctx = { ctx with variants = Bindings.remove var_id ctx.variants } in - - specialize_variants ctx (List.concat (List.map (fun (id, ctors) -> [CDEF_type (CTD_variant (id, ctors)); mangled_pragma var_id id]) monomorphized_variants) @ mangled_ctors @ prior) cdefs - - | CDEF_type (CTD_struct (struct_id, fields)) :: cdefs when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) fields -> - let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty fields in - - let cdefs = List.map (cdef_map_instr (specialize_field ctx struct_id)) cdefs in - let monomorphized_structs = - List.map (fun inst -> - let substs = KBindings.of_seq (List.map2 (fun x y -> x, y) (KidSet.elements typ_params) inst |> List.to_seq) in - (mangle_mono_id struct_id ctx inst, List.map (fun (field_id, ctyp) -> field_id, fix_variants ctx struct_id (subst_poly substs ctyp)) fields) - ) (CTListSet.elements !instantiations) - in - let mangled_fields = - List.map (fun (_, monomorphized_fields) -> - List.map2 (fun (field_id, _) (monomorphized_id, _) -> mangled_pragma field_id monomorphized_id) fields monomorphized_fields - ) monomorphized_structs - |> List.concat - in - - let prior = List.map (cdef_map_ctyp (fix_variants ctx struct_id)) prior in - let cdefs = List.map (cdef_map_ctyp (fix_variants ctx struct_id)) cdefs in - let ctx = { ctx with valspecs = Bindings.map (fun (extern, param_ctyps, ret_ctyp) -> extern, List.map (fix_variants ctx struct_id) param_ctyps, fix_variants ctx struct_id ret_ctyp) ctx.valspecs } in - - let ctx = - List.fold_left (fun ctx (id, fields) -> - { ctx with records = Bindings.add id ([], Bindings.of_seq (List.to_seq fields)) ctx.records }) - ctx monomorphized_structs - in - let ctx = { ctx with records = Bindings.remove struct_id ctx.records } in - - specialize_variants ctx (List.concat (List.map (fun (id, fields) -> [CDEF_type (CTD_struct (id, fields)); mangled_pragma struct_id id]) monomorphized_structs) @ mangled_fields @ prior) cdefs - - | cdef :: cdefs -> - specialize_variants ctx (cdef :: prior) cdefs - - | [] -> - List.rev prior, ctx - -let make_calls_precise ctx cdefs = - let constructor_types = ref Bindings.empty in - - let get_function_typ id = match Bindings.find_opt id ctx.valspecs with - | None -> Bindings.find_opt id !constructor_types - | Some (_, param_ctyps, ret_ctyp) -> Some (param_ctyps, ret_ctyp) - in - - let precise_call call tail = - match call with - | I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), ((_, l) as aux)) as instr -> - begin match get_function_typ id with - | None when string_of_id id = "sail_cons" -> - begin match ctyp_args, args with - | ([ctyp_arg], [hd_arg; tl_arg]) -> - if not (ctyp_equal (cval_ctyp hd_arg) ctyp_arg) then - let gs = ngensym () in - let cast = [ - idecl l ctyp_arg gs; - icopy l (CL_id (gs, ctyp_arg)) hd_arg - ] in - let cleanup = [ - iclear ~loc:l ctyp_arg gs - ] in - [iblock (cast @ [I_aux (I_funcall (clexp, extern, (id, ctyp_args), [V_id (gs, ctyp_arg); tl_arg]), aux)] @ tail @ cleanup)] - else - instr::tail - | _ -> - (* cons must have a single type parameter and two arguments *) - Reporting.unreachable (id_loc id) __POS__ "Invalid cons call" - end - | None -> - instr::tail - | Some (param_ctyps, ret_ctyp) -> - if List.compare_lengths args param_ctyps <> 0 then ( - Reporting.unreachable (id_loc id) __POS__ ("Function call found with incorrect arity: " ^ string_of_id id) - ); - let casted_args = - List.map2 (fun arg param_ctyp -> - if not (ctyp_equal (cval_ctyp arg) param_ctyp) then ( + let monomorphized_variants = + List.map + (fun inst -> + let substs = + KBindings.of_seq (List.map2 (fun x y -> (x, y)) (KidSet.elements typ_params) inst |> List.to_seq) + in + ( mangle_mono_id var_id ctx inst, + List.map + (fun (ctor_id, ctyp) -> + (mangle_mono_id ctor_id ctx inst, fix_variants ctx var_id (subst_poly substs ctyp)) + ) + ctors + ) + ) + (CTListSet.elements !instantiations) + in + let ctx = + List.fold_left + (fun ctx (id, ctors) -> + { ctx with variants = Bindings.add id ([], Bindings.of_seq (List.to_seq ctors)) ctx.variants } + ) + ctx monomorphized_variants + in + let mangled_ctors = + List.map + (fun (_, monomorphized_ctors) -> + List.map2 + (fun (ctor_id, _) (monomorphized_id, _) -> mangled_pragma ctor_id monomorphized_id) + ctors monomorphized_ctors + ) + monomorphized_variants + |> List.concat + in + + let prior = List.map (cdef_map_ctyp (fix_variants ctx var_id)) prior in + let cdefs = List.map (cdef_map_ctyp (fix_variants ctx var_id)) cdefs in + + let ctx = + { + ctx with + valspecs = + Bindings.map + (fun (extern, param_ctyps, ret_ctyp) -> + (extern, List.map (fix_variants ctx var_id) param_ctyps, fix_variants ctx var_id ret_ctyp) + ) + ctx.valspecs; + } + in + let ctx = { ctx with variants = Bindings.remove var_id ctx.variants } in + + specialize_variants ctx + (List.concat + (List.map + (fun (id, ctors) -> [CDEF_type (CTD_variant (id, ctors)); mangled_pragma var_id id]) + monomorphized_variants + ) + @ mangled_ctors @ prior + ) + cdefs + | CDEF_type (CTD_struct (struct_id, fields)) :: cdefs when List.exists (fun (_, ctyp) -> is_polymorphic ctyp) fields + -> + let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty fields in + + let cdefs = List.map (cdef_map_instr (specialize_field ctx struct_id)) cdefs in + let monomorphized_structs = + List.map + (fun inst -> + let substs = + KBindings.of_seq (List.map2 (fun x y -> (x, y)) (KidSet.elements typ_params) inst |> List.to_seq) + in + ( mangle_mono_id struct_id ctx inst, + List.map + (fun (field_id, ctyp) -> (field_id, fix_variants ctx struct_id (subst_poly substs ctyp))) + fields + ) + ) + (CTListSet.elements !instantiations) + in + let mangled_fields = + List.map + (fun (_, monomorphized_fields) -> + List.map2 + (fun (field_id, _) (monomorphized_id, _) -> mangled_pragma field_id monomorphized_id) + fields monomorphized_fields + ) + monomorphized_structs + |> List.concat + in + + let prior = List.map (cdef_map_ctyp (fix_variants ctx struct_id)) prior in + let cdefs = List.map (cdef_map_ctyp (fix_variants ctx struct_id)) cdefs in + let ctx = + { + ctx with + valspecs = + Bindings.map + (fun (extern, param_ctyps, ret_ctyp) -> + (extern, List.map (fix_variants ctx struct_id) param_ctyps, fix_variants ctx struct_id ret_ctyp) + ) + ctx.valspecs; + } + in + + let ctx = + List.fold_left + (fun ctx (id, fields) -> + { ctx with records = Bindings.add id ([], Bindings.of_seq (List.to_seq fields)) ctx.records } + ) + ctx monomorphized_structs + in + let ctx = { ctx with records = Bindings.remove struct_id ctx.records } in + + specialize_variants ctx + (List.concat + (List.map + (fun (id, fields) -> [CDEF_type (CTD_struct (id, fields)); mangled_pragma struct_id id]) + monomorphized_structs + ) + @ mangled_fields @ prior + ) + cdefs + | cdef :: cdefs -> specialize_variants ctx (cdef :: prior) cdefs + | [] -> (List.rev prior, ctx) + + let make_calls_precise ctx cdefs = + let constructor_types = ref Bindings.empty in + + let get_function_typ id = + match Bindings.find_opt id ctx.valspecs with + | None -> Bindings.find_opt id !constructor_types + | Some (_, param_ctyps, ret_ctyp) -> Some (param_ctyps, ret_ctyp) + in + + let precise_call call tail = + match call with + | I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), ((_, l) as aux)) as instr -> begin + match get_function_typ id with + | None when string_of_id id = "sail_cons" -> begin + match (ctyp_args, args) with + | [ctyp_arg], [hd_arg; tl_arg] -> + if not (ctyp_equal (cval_ctyp hd_arg) ctyp_arg) then ( + let gs = ngensym () in + let cast = [idecl l ctyp_arg gs; icopy l (CL_id (gs, ctyp_arg)) hd_arg] in + let cleanup = [iclear ~loc:l ctyp_arg gs] in + [ + iblock + (cast + @ [I_aux (I_funcall (clexp, extern, (id, ctyp_args), [V_id (gs, ctyp_arg); tl_arg]), aux)] + @ tail @ cleanup + ); + ] + ) + else instr :: tail + | _ -> + (* cons must have a single type parameter and two arguments *) + Reporting.unreachable (id_loc id) __POS__ "Invalid cons call" + end + | None -> instr :: tail + | Some (param_ctyps, ret_ctyp) -> + if List.compare_lengths args param_ctyps <> 0 then + Reporting.unreachable (id_loc id) __POS__ + ("Function call found with incorrect arity: " ^ string_of_id id); + let casted_args = + List.map2 + (fun arg param_ctyp -> + if not (ctyp_equal (cval_ctyp arg) param_ctyp) then ( + let gs = ngensym () in + let cast = [idecl l param_ctyp gs; icopy l (CL_id (gs, param_ctyp)) arg] in + let cleanup = [iclear ~loc:l param_ctyp gs] in + (cast, V_id (gs, param_ctyp), cleanup) + ) + else ([], arg, []) + ) + args param_ctyps + in + let ret_setup, clexp, ret_cleanup = + if not (ctyp_equal (clexp_ctyp clexp) ret_ctyp) then ( let gs = ngensym () in - let cast = [ - idecl l param_ctyp gs; - icopy l (CL_id (gs, param_ctyp)) arg - ] in - let cleanup = [ - iclear ~loc:l param_ctyp gs - ] in - (cast, V_id (gs, param_ctyp), cleanup) - ) else ( - ([], arg, []) + ( [idecl l ret_ctyp gs], + CL_id (gs, ret_ctyp), + [icopy l clexp (V_id (gs, ret_ctyp)); iclear ~loc:l ret_ctyp gs] + ) ) - ) args param_ctyps - in - let ret_setup, clexp, ret_cleanup = - if not (ctyp_equal (clexp_ctyp clexp) ret_ctyp) then - let gs = ngensym () in - ([idecl l ret_ctyp gs], (CL_id (gs, ret_ctyp)), [icopy l clexp (V_id (gs, ret_ctyp)); iclear ~loc:l ret_ctyp gs]) - else - ([], clexp, []) - in - let casts = List.map (fun (x, _, _) -> x) casted_args |> List.concat in - let args = List.map (fun (_, y, _) -> y) casted_args in - let cleanup = List.rev_map (fun (_, _, z) -> z) casted_args |> List.concat in - [iblock1 (casts @ ret_setup @ [I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), aux)] @ tail @ ret_cleanup @ cleanup)] - end - - | instr -> instr::tail - in + else ([], clexp, []) + in + let casts = List.map (fun (x, _, _) -> x) casted_args |> List.concat in + let args = List.map (fun (_, y, _) -> y) casted_args in + let cleanup = List.rev_map (fun (_, _, z) -> z) casted_args |> List.concat in + [ + iblock1 + (casts @ ret_setup + @ [I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), aux)] + @ tail @ ret_cleanup @ cleanup + ); + ] + end + | instr -> instr :: tail + in - let rec precise_calls prior = function - | (CDEF_type (CTD_variant (var_id, ctors)) as cdef) :: cdefs -> - List.iter (fun (id, ctyp) -> - constructor_types := Bindings.add id ([ctyp], CT_variant (var_id, ctors)) !constructor_types - ) ctors; - precise_calls (cdef :: prior) cdefs - - | cdef :: cdefs -> - precise_calls (cdef_map_funcall precise_call cdef :: prior) cdefs - - | [] -> - List.rev prior - in - precise_calls [] cdefs - -(** Once we specialize variants, there may be additional type + let rec precise_calls prior = function + | (CDEF_type (CTD_variant (var_id, ctors)) as cdef) :: cdefs -> + List.iter + (fun (id, ctyp) -> + constructor_types := Bindings.add id ([ctyp], CT_variant (var_id, ctors)) !constructor_types + ) + ctors; + precise_calls (cdef :: prior) cdefs + | cdef :: cdefs -> precise_calls (cdef_map_funcall precise_call cdef :: prior) cdefs + | [] -> List.rev prior + in + precise_calls [] cdefs + + (** Once we specialize variants, there may be additional type dependencies which could be in the wrong order. As such we need to sort the type definitions in the list of cdefs. *) -let sort_ctype_defs reverse cdefs = - (* Split the cdefs into type definitions and non type definitions *) - let is_ctype_def = function CDEF_type _ -> true | _ -> false in - let unwrap = function CDEF_type ctdef -> ctdef | _ -> assert false in - let ctype_defs = List.map unwrap (List.filter is_ctype_def cdefs) in - let cdefs = List.filter (fun cdef -> not (is_ctype_def cdef)) cdefs in - - let ctdef_id = function - | CTD_enum (id, _) | CTD_struct (id, _) | CTD_variant (id, _) -> id - in - - let ctdef_ids = function - | CTD_enum _ -> IdSet.empty - | CTD_struct (_, ctors) | CTD_variant (_, ctors) -> - List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors - in - - (* Create a reverse (i.e. from types to the types that are dependent - upon them) id graph of dependencies between types *) - let module IdGraph = Graph.Make(Id) in - - let graph = - List.fold_left (fun g ctdef -> - List.fold_left (fun g id -> IdGraph.add_edge id (ctdef_id ctdef) g) - (IdGraph.add_edges (ctdef_id ctdef) [] g) (* Make sure even types with no dependencies are in graph *) - (IdSet.elements (ctdef_ids ctdef))) - IdGraph.empty - ctype_defs - in + let sort_ctype_defs reverse cdefs = + (* Split the cdefs into type definitions and non type definitions *) + let is_ctype_def = function CDEF_type _ -> true | _ -> false in + let unwrap = function CDEF_type ctdef -> ctdef | _ -> assert false in + let ctype_defs = List.map unwrap (List.filter is_ctype_def cdefs) in + let cdefs = List.filter (fun cdef -> not (is_ctype_def cdef)) cdefs in + + let ctdef_id = function CTD_enum (id, _) | CTD_struct (id, _) | CTD_variant (id, _) -> id in + + let ctdef_ids = function + | CTD_enum _ -> IdSet.empty + | CTD_struct (_, ctors) | CTD_variant (_, ctors) -> + List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors + in - (* Then select the ctypes in the correct order as given by the topsort *) - let ids = IdGraph.topsort graph in - let ctype_defs = - List.map (fun id -> CDEF_type (List.find (fun ctdef -> Id.compare (ctdef_id ctdef) id = 0) ctype_defs)) ids - in + (* Create a reverse (i.e. from types to the types that are dependent + upon them) id graph of dependencies between types *) + let module IdGraph = Graph.Make (Id) in + let graph = + List.fold_left + (fun g ctdef -> + List.fold_left + (fun g id -> IdGraph.add_edge id (ctdef_id ctdef) g) + (IdGraph.add_edges (ctdef_id ctdef) [] g) (* Make sure even types with no dependencies are in graph *) + (IdSet.elements (ctdef_ids ctdef)) + ) + IdGraph.empty ctype_defs + in - (if reverse then List.rev ctype_defs else ctype_defs) @ cdefs + (* Then select the ctypes in the correct order as given by the topsort *) + let ids = IdGraph.topsort graph in + let ctype_defs = + List.map (fun id -> CDEF_type (List.find (fun ctdef -> Id.compare (ctdef_id ctdef) id = 0) ctype_defs)) ids + in -let toplevel_lets_of_ast ast = - let toplevel_lets_of_def = function - | DEF_aux (DEF_let (LB_aux (LB_val (pat, _), _)), _) -> pat_ids pat - | _ -> IdSet.empty - in - let toplevel_lets_of_defs defs = - List.fold_left IdSet.union IdSet.empty (List.map toplevel_lets_of_def defs) - in - toplevel_lets_of_defs ast.defs |> IdSet.elements - -let compile_ast ctx ast = - let module G = Graph.Make(Callgraph.Node) in - let g = Callgraph.graph_of_ast ast in - let module NodeSet = Set.Make(Callgraph.Node) in - let roots = Specialize.get_initial_calls () |> List.map (fun id -> Callgraph.Function id) |> NodeSet.of_list in - let roots = NodeSet.add (Callgraph.Type (mk_id "exception")) roots in - let roots = Bindings.fold (fun typ_id _ roots -> NodeSet.add (Callgraph.Type typ_id) roots) (Env.get_enums ctx.tc_env) roots in - let roots = NodeSet.union (toplevel_lets_of_ast ast |> List.map (fun id -> Callgraph.Letbind id) |> NodeSet.of_list) roots in - let g = G.prune roots NodeSet.empty g in - let ast = Callgraph.filter_ast NodeSet.empty g ast in - - if !opt_memo_cache then - (try - if Sys.is_directory "_sbuild" then - () - else - raise (Reporting.err_general Parse_ast.Unknown "_sbuild exists, but is a file not a directory!") - with - | Sys_error _ -> Unix.mkdir "_sbuild" 0o775) - else (); - - let total = List.length ast.defs in - let _, chunks, ctx = - List.fold_left (fun (n, chunks, ctx) def -> let defs, ctx = compile_def n total ctx def in n + 1, defs :: chunks, ctx) (1, [], ctx) ast.defs - in - let cdefs = List.concat (List.rev chunks) in - - (* If we don't have an exception type, add a dummy one *) - let dummy_exn = mk_id "__dummy_exn#" in - let cdefs, ctx = - if not (Bindings.mem (mk_id "exception") ctx.variants) then - CDEF_type (CTD_variant (mk_id "exception", [(dummy_exn, CT_unit)])) :: cdefs, - { ctx with variants = Bindings.add (mk_id "exception") ([], Bindings.singleton dummy_exn CT_unit) ctx.variants } - else - cdefs, ctx - in - let cdefs, ctx = specialize_functions ctx cdefs in - let cdefs = sort_ctype_defs true cdefs in - let cdefs, ctx = specialize_variants ctx [] cdefs in - let cdefs = if C.specialize_calls then cdefs else make_calls_precise ctx cdefs in - let cdefs = sort_ctype_defs false cdefs in - cdefs, ctx + (if reverse then List.rev ctype_defs else ctype_defs) @ cdefs + let toplevel_lets_of_ast ast = + let toplevel_lets_of_def = function + | DEF_aux (DEF_let (LB_aux (LB_val (pat, _), _)), _) -> pat_ids pat + | _ -> IdSet.empty + in + let toplevel_lets_of_defs defs = List.fold_left IdSet.union IdSet.empty (List.map toplevel_lets_of_def defs) in + toplevel_lets_of_defs ast.defs |> IdSet.elements + + let compile_ast ctx ast = + let module G = Graph.Make (Callgraph.Node) in + let g = Callgraph.graph_of_ast ast in + let module NodeSet = Set.Make (Callgraph.Node) in + let roots = Specialize.get_initial_calls () |> List.map (fun id -> Callgraph.Function id) |> NodeSet.of_list in + let roots = NodeSet.add (Callgraph.Type (mk_id "exception")) roots in + let roots = + Bindings.fold (fun typ_id _ roots -> NodeSet.add (Callgraph.Type typ_id) roots) (Env.get_enums ctx.tc_env) roots + in + let roots = + NodeSet.union (toplevel_lets_of_ast ast |> List.map (fun id -> Callgraph.Letbind id) |> NodeSet.of_list) roots + in + let g = G.prune roots NodeSet.empty g in + let ast = Callgraph.filter_ast NodeSet.empty g ast in + + if !opt_memo_cache then ( + try + if Sys.is_directory "_sbuild" then () + else raise (Reporting.err_general Parse_ast.Unknown "_sbuild exists, but is a file not a directory!") + with Sys_error _ -> Unix.mkdir "_sbuild" 0o775 + ) + else (); + + let total = List.length ast.defs in + let _, chunks, ctx = + List.fold_left + (fun (n, chunks, ctx) def -> + let defs, ctx = compile_def n total ctx def in + (n + 1, defs :: chunks, ctx) + ) + (1, [], ctx) ast.defs + in + let cdefs = List.concat (List.rev chunks) in + + (* If we don't have an exception type, add a dummy one *) + let dummy_exn = mk_id "__dummy_exn#" in + let cdefs, ctx = + if not (Bindings.mem (mk_id "exception") ctx.variants) then + ( CDEF_type (CTD_variant (mk_id "exception", [(dummy_exn, CT_unit)])) :: cdefs, + { + ctx with + variants = Bindings.add (mk_id "exception") ([], Bindings.singleton dummy_exn CT_unit) ctx.variants; + } + ) + else (cdefs, ctx) + in + let cdefs, ctx = specialize_functions ctx cdefs in + let cdefs = sort_ctype_defs true cdefs in + let cdefs, ctx = specialize_variants ctx [] cdefs in + let cdefs = if C.specialize_calls then cdefs else make_calls_precise ctx cdefs in + let cdefs = sort_ctype_defs false cdefs in + (cdefs, ctx) end let add_special_functions env effect_info = @@ -2089,4 +2179,4 @@ let add_special_functions env effect_info = let effect_info = Effects.add_monadic_built_in (mk_id "sail_assert") effect_info in let effect_info = Effects.add_monadic_built_in (mk_id "sail_exit") effect_info in - snd (Type_error.check_defs env [assert_vs; exit_vs; cons_vs]), effect_info + (snd (Type_error.check_defs env [assert_vs; exit_vs; cons_vs]), effect_info) diff --git a/src/lib/jib_compile.mli b/src/lib/jib_compile.mli index 51880837b..6774b6ec3 100644 --- a/src/lib/jib_compile.mli +++ b/src/lib/jib_compile.mli @@ -93,19 +93,19 @@ val opt_memo_cache : bool ref (** Dynamic context for compiling Sail to Jib. We need to pass a (global) typechecking environment given by checking the full AST. *) -type ctx = - { records : (kid list * ctyp Bindings.t) Bindings.t; - enums : IdSet.t Bindings.t; - variants : (kid list * ctyp Bindings.t) Bindings.t; - valspecs : (string option * ctyp list * ctyp) Bindings.t; - quants : ctyp KBindings.t; - local_env : Env.t; - tc_env : Env.t; - effect_info : Effects.side_effect_info; - locals : (mut * ctyp) Bindings.t; - letbinds : int list; - no_raw : bool; - } +type ctx = { + records : (kid list * ctyp Bindings.t) Bindings.t; + enums : IdSet.t Bindings.t; + variants : (kid list * ctyp Bindings.t) Bindings.t; + valspecs : (string option * ctyp list * ctyp) Bindings.t; + quants : ctyp KBindings.t; + local_env : Env.t; + tc_env : Env.t; + effect_info : Effects.side_effect_info; + locals : (mut * ctyp) Bindings.t; + letbinds : int list; + no_raw : bool; +} val ctx_is_extern : id -> ctx -> bool @@ -160,10 +160,10 @@ end module IdGraph : sig include Graph.S with type node = id end - + val callgraph : cdef list -> IdGraph.graph -module Make(C: Config) : sig +module Make (C : Config) : sig (** Compile a Sail definition into a Jib definition. The first two arguments are is the current definition number and the total number of definitions, and can be used to drive a progress bar diff --git a/src/lib/jib_optimize.ml b/src/lib/jib_optimize.ml index e712c7ecb..0eb211845 100644 --- a/src/lib/jib_optimize.ml +++ b/src/lib/jib_optimize.ml @@ -71,31 +71,20 @@ open Jib_compile open Jib_util let optimize_unit instrs = - let unit_cval cval = - match cval_ctyp cval with - | CT_unit -> (V_lit (VL_unit, CT_unit)) - | _ -> cval - in + let unit_cval cval = match cval_ctyp cval with CT_unit -> V_lit (VL_unit, CT_unit) | _ -> cval in let unit_instr = function - | I_aux (I_funcall (clexp, extern, id, args), annot) as instr -> - begin match clexp_ctyp clexp with - | CT_unit -> - I_aux (I_funcall (CL_void, extern, id, List.map unit_cval args), annot) - | _ -> instr - end - | I_aux (I_copy (clexp, cval), annot) as instr -> - begin match clexp_ctyp clexp with - | CT_unit -> - I_aux (I_copy (CL_void, unit_cval cval), annot) - | _ -> instr - end + | I_aux (I_funcall (clexp, extern, id, args), annot) as instr -> begin + match clexp_ctyp clexp with + | CT_unit -> I_aux (I_funcall (CL_void, extern, id, List.map unit_cval args), annot) + | _ -> instr + end + | I_aux (I_copy (clexp, cval), annot) as instr -> begin + match clexp_ctyp clexp with CT_unit -> I_aux (I_copy (CL_void, unit_cval cval), annot) | _ -> instr + end | instr -> instr in let non_pointless_copy (I_aux (aux, annot)) = - match aux with - | I_decl (CT_unit, _) -> false - | I_copy (CL_void, _) -> false - | _ -> true + match aux with I_decl (CT_unit, _) -> false | I_copy (CL_void, _) -> false | _ -> true in filter_instrs non_pointless_copy (map_instr_list unit_instr instrs) @@ -110,42 +99,33 @@ let flat_id orig_id = let rec flatten_instrs = function | I_aux (I_decl (ctyp, decl_id), aux) :: instrs -> - let fid = flat_id decl_id in - I_aux (I_decl (ctyp, fid), aux) :: flatten_instrs (instrs_rename decl_id fid instrs) - + let fid = flat_id decl_id in + I_aux (I_decl (ctyp, fid), aux) :: flatten_instrs (instrs_rename decl_id fid instrs) | I_aux (I_init (ctyp, decl_id, cval), aux) :: instrs -> - let fid = flat_id decl_id in - I_aux (I_init (ctyp, fid, cval), aux) :: flatten_instrs (instrs_rename decl_id fid instrs) - - | I_aux ((I_block block | I_try_block block), _) :: instrs -> - flatten_instrs block @ flatten_instrs instrs - + let fid = flat_id decl_id in + I_aux (I_init (ctyp, fid, cval), aux) :: flatten_instrs (instrs_rename decl_id fid instrs) + | I_aux ((I_block block | I_try_block block), _) :: instrs -> flatten_instrs block @ flatten_instrs instrs | I_aux (I_if (cval, then_instrs, else_instrs, _), (_, l)) :: instrs -> - let then_label = label "then_" in - let endif_label = label "endif_" in - [ijump l cval then_label] - @ flatten_instrs else_instrs - @ [igoto endif_label] - @ [ilabel then_label] - @ flatten_instrs then_instrs - @ [ilabel endif_label] - @ flatten_instrs instrs - + let then_label = label "then_" in + let endif_label = label "endif_" in + [ijump l cval then_label] + @ flatten_instrs else_instrs + @ [igoto endif_label] + @ [ilabel then_label] + @ flatten_instrs then_instrs + @ [ilabel endif_label] + @ flatten_instrs instrs | I_aux (I_comment _, _) :: instrs -> flatten_instrs instrs - | instr :: instrs -> instr :: flatten_instrs instrs | [] -> [] -let flatten_cdef = - function +let flatten_cdef = function | CDEF_fundef (function_id, heap_return, args, body) -> - flat_counter := 0; - CDEF_fundef (function_id, heap_return, args, flatten_instrs body) - + flat_counter := 0; + CDEF_fundef (function_id, heap_return, args, flatten_instrs body) | CDEF_let (n, bindings, instrs) -> - flat_counter := 0; - CDEF_let (n, bindings, flatten_instrs instrs) - + flat_counter := 0; + CDEF_let (n, bindings, flatten_instrs instrs) | cdef -> cdef let unique_per_function_ids cdefs = @@ -155,15 +135,14 @@ let unique_per_function_ids cdefs = in let rec unique_instrs i = function | I_aux (I_decl (ctyp, id), aux) :: rest -> - I_aux (I_decl (ctyp, unique_id i id), aux) :: unique_instrs i (instrs_rename id (unique_id i id) rest) + I_aux (I_decl (ctyp, unique_id i id), aux) :: unique_instrs i (instrs_rename id (unique_id i id) rest) | I_aux (I_init (ctyp, id, cval), aux) :: rest -> - I_aux (I_init (ctyp, unique_id i id, cval), aux) :: unique_instrs i (instrs_rename id (unique_id i id) rest) - | I_aux (I_block instrs, aux) :: rest -> - I_aux (I_block (unique_instrs i instrs), aux) :: unique_instrs i rest + I_aux (I_init (ctyp, unique_id i id, cval), aux) :: unique_instrs i (instrs_rename id (unique_id i id) rest) + | I_aux (I_block instrs, aux) :: rest -> I_aux (I_block (unique_instrs i instrs), aux) :: unique_instrs i rest | I_aux (I_try_block instrs, aux) :: rest -> - I_aux (I_try_block (unique_instrs i instrs), aux) :: unique_instrs i rest + I_aux (I_try_block (unique_instrs i instrs), aux) :: unique_instrs i rest | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: rest -> - I_aux (I_if (cval, unique_instrs i then_instrs, unique_instrs i else_instrs, ctyp), aux) :: unique_instrs i rest + I_aux (I_if (cval, unique_instrs i then_instrs, unique_instrs i else_instrs, ctyp), aux) :: unique_instrs i rest | instr :: instrs -> instr :: unique_instrs i instrs | [] -> [] in @@ -187,7 +166,7 @@ let rec cval_subst id subst = function | V_tuple_member (cval, len, n) -> V_tuple_member (cval_subst id subst cval, len, n) | V_ctor_kind (cval, ctor, ctyp) -> V_ctor_kind (cval_subst id subst cval, ctor, ctyp) | V_ctor_unwrap (cval, ctor, ctyp) -> V_ctor_unwrap (cval_subst id subst cval, ctor, ctyp) - | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, cval_subst id subst cval) fields, ctyp) + | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> (field, cval_subst id subst cval)) fields, ctyp) | V_tuple (members, ctyp) -> V_tuple (List.map (cval_subst id subst) members, ctyp) let rec cval_map_id f = function @@ -198,52 +177,43 @@ let rec cval_map_id f = function | V_tuple_member (cval, len, n) -> V_tuple_member (cval_map_id f cval, len, n) | V_ctor_kind (cval, ctor, ctyp) -> V_ctor_kind (cval_map_id f cval, ctor, ctyp) | V_ctor_unwrap (cval, ctor, ctyp) -> V_ctor_unwrap (cval_map_id f cval, ctor, ctyp) - | V_struct (fields, ctyp) -> - V_struct (List.map (fun (field, cval) -> field, cval_map_id f cval) fields, ctyp) - | V_tuple (members, ctyp) -> - V_tuple (List.map (cval_map_id f) members, ctyp) - -let rec instrs_subst id subst = - function - | (I_aux (I_decl (_, id'), _) :: _) as instrs when Name.compare id id' = 0 -> - instrs + | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> (field, cval_map_id f cval)) fields, ctyp) + | V_tuple (members, ctyp) -> V_tuple (List.map (cval_map_id f) members, ctyp) +let rec instrs_subst id subst = function + | I_aux (I_decl (_, id'), _) :: _ as instrs when Name.compare id id' = 0 -> instrs | I_aux (I_init (ctyp, id', cval), aux) :: rest when Name.compare id id' = 0 -> - I_aux (I_init (ctyp, id', cval_subst id subst cval), aux) :: rest - - | (I_aux (I_reset (_, id'), _) :: _) as instrs when Name.compare id id' = 0 -> - instrs - + I_aux (I_init (ctyp, id', cval_subst id subst cval), aux) :: rest + | I_aux (I_reset (_, id'), _) :: _ as instrs when Name.compare id id' = 0 -> instrs | I_aux (I_reinit (ctyp, id', cval), aux) :: rest when Name.compare id id' = 0 -> - I_aux (I_reinit (ctyp, id', cval_subst id subst cval), aux) :: rest - + I_aux (I_reinit (ctyp, id', cval_subst id subst cval), aux) :: rest | I_aux (instr, aux) :: instrs -> - let instrs = instrs_subst id subst instrs in - let instr = match instr with - | I_decl (ctyp, id') -> I_decl (ctyp, id') - | I_init (ctyp, id', cval) -> I_init (ctyp, id', cval_subst id subst cval) - | I_jump (cval, label) -> I_jump (cval_subst id subst cval, label) - | I_goto label -> I_goto label - | I_label label -> I_label label - | I_funcall (clexp, extern, fid, args) -> I_funcall (clexp, extern, fid, List.map (cval_subst id subst) args) - | I_copy (clexp, cval) -> I_copy (clexp, cval_subst id subst cval) - | I_clear (clexp, id') -> I_clear (clexp, id') - | I_undefined ctyp -> I_undefined ctyp - | I_exit cause -> I_exit cause - | I_end id' -> I_end id' - | I_if (cval, then_instrs, else_instrs, ctyp) -> - I_if (cval_subst id subst cval, instrs_subst id subst then_instrs, instrs_subst id subst else_instrs, ctyp) - | I_block instrs -> I_block (instrs_subst id subst instrs) - | I_try_block instrs -> I_try_block (instrs_subst id subst instrs) - | I_throw cval -> I_throw (cval_subst id subst cval) - | I_comment str -> I_comment str - | I_raw str -> I_raw str - | I_return cval -> I_return (cval_subst id subst cval) - | I_reset (ctyp, id') -> I_reset (ctyp, id') - | I_reinit (ctyp, id', cval) -> I_reinit (ctyp, id', cval_subst id subst cval) - in - I_aux (instr, aux) :: instrs - + let instrs = instrs_subst id subst instrs in + let instr = + match instr with + | I_decl (ctyp, id') -> I_decl (ctyp, id') + | I_init (ctyp, id', cval) -> I_init (ctyp, id', cval_subst id subst cval) + | I_jump (cval, label) -> I_jump (cval_subst id subst cval, label) + | I_goto label -> I_goto label + | I_label label -> I_label label + | I_funcall (clexp, extern, fid, args) -> I_funcall (clexp, extern, fid, List.map (cval_subst id subst) args) + | I_copy (clexp, cval) -> I_copy (clexp, cval_subst id subst cval) + | I_clear (clexp, id') -> I_clear (clexp, id') + | I_undefined ctyp -> I_undefined ctyp + | I_exit cause -> I_exit cause + | I_end id' -> I_end id' + | I_if (cval, then_instrs, else_instrs, ctyp) -> + I_if (cval_subst id subst cval, instrs_subst id subst then_instrs, instrs_subst id subst else_instrs, ctyp) + | I_block instrs -> I_block (instrs_subst id subst instrs) + | I_try_block instrs -> I_try_block (instrs_subst id subst instrs) + | I_throw cval -> I_throw (cval_subst id subst cval) + | I_comment str -> I_comment str + | I_raw str -> I_raw str + | I_return cval -> I_return (cval_subst id subst cval) + | I_reset (ctyp, id') -> I_reset (ctyp, id') + | I_reinit (ctyp, id', cval) -> I_reinit (ctyp, id', cval_subst id subst cval) + in + I_aux (instr, aux) :: instrs | [] -> [] let rec clexp_subst id subst = function @@ -256,11 +226,8 @@ let rec clexp_subst id subst = function | CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot substitute into read-modify-write construct" let rec find_function fid = function - | CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 -> - Some (heap_return, args, body) - + | CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 -> Some (heap_return, args, body) | cdef :: cdefs -> find_function fid cdefs - | [] -> None let ssa_name i = function @@ -277,9 +244,8 @@ let inline cdefs should_inline instrs = let replace_return subst = function | I_aux (I_funcall (clexp, extern, fid, args), aux) -> - I_aux (I_funcall (clexp_subst return subst clexp, extern, fid, args), aux) - | I_aux (I_copy (clexp, cval), aux) -> - I_aux (I_copy (clexp_subst return subst clexp, cval), aux) + I_aux (I_funcall (clexp_subst return subst clexp, extern, fid, args), aux) + | I_aux (I_copy (clexp, cval), aux) -> I_aux (I_copy (clexp_subst return subst clexp, cval), aux) | instr -> instr in @@ -301,90 +267,76 @@ let inline cdefs should_inline instrs = let fix_substs = let f = cval_map_id (ssa_name (-1)) in function - | I_aux (I_init (ctyp, id, cval), aux) -> - I_aux (I_init (ctyp, id, f cval), aux) - | I_aux (I_jump (cval, label), aux) -> - I_aux (I_jump (f cval, label), aux) + | I_aux (I_init (ctyp, id, cval), aux) -> I_aux (I_init (ctyp, id, f cval), aux) + | I_aux (I_jump (cval, label), aux) -> I_aux (I_jump (f cval, label), aux) | I_aux (I_funcall (clexp, extern, function_id, args), aux) -> - I_aux (I_funcall (clexp, extern, function_id, List.map f args), aux) + I_aux (I_funcall (clexp, extern, function_id, List.map f args), aux) | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) -> - I_aux (I_if (f cval, then_instrs, else_instrs, ctyp), aux) - | I_aux (I_copy (clexp, cval), aux) -> - I_aux (I_copy (clexp, f cval), aux) - | I_aux (I_return cval, aux) -> - I_aux (I_return (f cval), aux) - | I_aux (I_throw cval, aux) -> - I_aux (I_throw (f cval), aux) + I_aux (I_if (f cval, then_instrs, else_instrs, ctyp), aux) + | I_aux (I_copy (clexp, cval), aux) -> I_aux (I_copy (clexp, f cval), aux) + | I_aux (I_return cval, aux) -> I_aux (I_return (f cval), aux) + | I_aux (I_throw cval, aux) -> I_aux (I_throw (f cval), aux) | instr -> instr in let inline_instr = function - | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline (fst function_id) -> - begin match find_function (fst function_id) cdefs with - | Some (None, ids, body) -> - incr inlines; - incr label_count; - let inline_label = label "end_inline_" in - (* For situations where we have e.g. x => x' and x' => y, we - use a dummy SSA number turning this into x => x'/-2 and - x' => y/-2, ensuring x's won't get turned into y's. This - is undone by fix_substs which removes the -2 SSA - numbers. *) - let args = List.map (cval_map_id (ssa_name (-2))) args in - let body = List.fold_right2 instrs_subst (List.map name ids) args body in - let body = List.map (map_instr fix_substs) body in - let body = List.map (map_instr fix_labels) body in - let body = List.map (map_instr (replace_end inline_label)) body in - let body = List.map (map_instr (replace_return clexp)) body in - I_aux (I_block (body @ [ilabel inline_label]), aux) - | Some (Some _, ids, body) -> - (* Some _ is only introduced by C backend, so we don't - expect it at this point. *) - raise (Reporting.err_general (snd aux) "Unexpected return method in IR") - | None -> instr - end + | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline (fst function_id) -> begin + match find_function (fst function_id) cdefs with + | Some (None, ids, body) -> + incr inlines; + incr label_count; + let inline_label = label "end_inline_" in + (* For situations where we have e.g. x => x' and x' => y, we + use a dummy SSA number turning this into x => x'/-2 and + x' => y/-2, ensuring x's won't get turned into y's. This + is undone by fix_substs which removes the -2 SSA + numbers. *) + let args = List.map (cval_map_id (ssa_name (-2))) args in + let body = List.fold_right2 instrs_subst (List.map name ids) args body in + let body = List.map (map_instr fix_substs) body in + let body = List.map (map_instr fix_labels) body in + let body = List.map (map_instr (replace_end inline_label)) body in + let body = List.map (map_instr (replace_return clexp)) body in + I_aux (I_block (body @ [ilabel inline_label]), aux) + | Some (Some _, ids, body) -> + (* Some _ is only introduced by C backend, so we don't + expect it at this point. *) + raise (Reporting.err_general (snd aux) "Unexpected return method in IR") + | None -> instr + end | instr -> instr in let rec go instrs = - if !inlines <> 0 then - begin - inlines := 0; - let instrs = List.map (map_instr inline_instr) instrs in - go instrs - end - else - instrs + if !inlines <> 0 then begin + inlines := 0; + let instrs = List.map (map_instr inline_instr) instrs in + go instrs + end + else instrs in go instrs let remove_pointless_goto instrs = let rec go acc = function | I_aux (I_goto label, _) :: I_aux (I_label label', aux) :: instrs when label = label' -> - go (I_aux (I_label label', aux) :: acc) instrs - | I_aux (I_goto label, aux) :: I_aux (I_goto _, _) :: instrs -> - go (I_aux (I_goto label, aux) :: acc) instrs - | instr :: instrs -> - go (instr :: acc) instrs - | [] -> - List.rev acc + go (I_aux (I_label label', aux) :: acc) instrs + | I_aux (I_goto label, aux) :: I_aux (I_goto _, _) :: instrs -> go (I_aux (I_goto label, aux) :: acc) instrs + | instr :: instrs -> go (instr :: acc) instrs + | [] -> List.rev acc in go [] instrs let remove_pointless_exit instrs = let rec go acc = function - | I_aux (I_end id, aux) :: I_aux (I_end _, _) :: instrs -> - go (I_aux (I_end id, aux) :: acc) instrs - | I_aux (I_end id, aux) :: I_aux (I_undefined _, _) :: instrs -> - go (I_aux (I_end id, aux) :: acc) instrs - | instr :: instrs -> - go (instr :: acc) instrs - | [] -> - List.rev acc + | I_aux (I_end id, aux) :: I_aux (I_end _, _) :: instrs -> go (I_aux (I_end id, aux) :: acc) instrs + | I_aux (I_end id, aux) :: I_aux (I_undefined _, _) :: instrs -> go (I_aux (I_end id, aux) :: acc) instrs + | instr :: instrs -> go (instr :: acc) instrs + | [] -> List.rev acc in go [] instrs - -module StringSet = Set.Make(String) + +module StringSet = Set.Make (String) let rec get_used_labels set = function | I_aux (I_goto label, _) :: instrs -> get_used_labels (StringSet.add label set) instrs @@ -417,10 +369,7 @@ let rec remove_dead_code instrs = let instrs' = instrs |> remove_unused_labels |> remove_pointless_goto |> remove_dead_after_goto |> remove_pointless_exit in - if List.length instrs' < List.length instrs then - remove_dead_code instrs' - else - instrs' + if List.length instrs' < List.length instrs then remove_dead_code instrs' else instrs' let rec remove_clear = function | I_aux (I_clear _, _) :: instrs -> remove_clear instrs @@ -430,93 +379,80 @@ let rec remove_clear = function let remove_tuples cdefs ctx = let already_removed = ref CTSet.empty in let rec all_tuples = function - | CT_tup ctyps as ctyp -> - CTSet.add ctyp (List.fold_left CTSet.union CTSet.empty (List.map all_tuples ctyps)) + | CT_tup ctyps as ctyp -> CTSet.add ctyp (List.fold_left CTSet.union CTSet.empty (List.map all_tuples ctyps)) | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> - List.fold_left (fun cts (_, ctyp) -> CTSet.union (all_tuples ctyp) cts) CTSet.empty id_ctyps - | CT_list ctyp | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_ref ctyp -> - all_tuples ctyp - | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_float _ - | CT_unit | CT_bool | CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode -> - CTSet.empty + List.fold_left (fun cts (_, ctyp) -> CTSet.union (all_tuples ctyp) cts) CTSet.empty id_ctyps + | CT_list ctyp | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_ref ctyp -> all_tuples ctyp + | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_float _ | CT_unit | CT_bool + | CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode -> + CTSet.empty in let rec tuple_depth = function - | CT_tup ctyps -> - 1 + List.fold_left (fun d ctyp -> max d (tuple_depth ctyp)) 0 ctyps + | CT_tup ctyps -> 1 + List.fold_left (fun d ctyp -> max d (tuple_depth ctyp)) 0 ctyps | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> - List.fold_left (fun d (_, ctyp) -> max (tuple_depth ctyp) d) 0 id_ctyps - | CT_list ctyp | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_ref ctyp -> - tuple_depth ctyp - | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_unit | CT_bool - | CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_float _ | CT_rounding_mode -> - 0 + List.fold_left (fun d (_, ctyp) -> max (tuple_depth ctyp) d) 0 id_ctyps + | CT_list ctyp | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_ref ctyp -> tuple_depth ctyp + | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_unit | CT_bool | CT_real | CT_bit + | CT_poly _ | CT_string | CT_enum _ | CT_float _ | CT_rounding_mode -> + 0 in let rec fix_tuples = function | CT_tup ctyps -> - let ctyps = List.map fix_tuples ctyps in - let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in - CT_struct (mk_id name, List.mapi (fun n ctyp -> mk_id (name ^ string_of_int n), ctyp) ctyps) - | CT_struct (id, id_ctyps) -> - CT_struct (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) - | CT_variant (id, id_ctyps) -> - CT_variant (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + CT_struct (mk_id name, List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), ctyp)) ctyps) + | CT_struct (id, id_ctyps) -> CT_struct (id, List.map (fun (id, ctyp) -> (id, fix_tuples ctyp)) id_ctyps) + | CT_variant (id, id_ctyps) -> CT_variant (id, List.map (fun (id, ctyp) -> (id, fix_tuples ctyp)) id_ctyps) | CT_list ctyp -> CT_list (fix_tuples ctyp) | CT_vector (d, ctyp) -> CT_vector (d, fix_tuples ctyp) | CT_fvector (n, d, ctyp) -> CT_fvector (n, d, fix_tuples ctyp) | CT_ref ctyp -> CT_ref (fix_tuples ctyp) - | (CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_float _ - | CT_unit | CT_bool | CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode) as ctyp -> - ctyp + | ( CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_float _ | CT_unit | CT_bool + | CT_real | CT_bit | CT_poly _ | CT_string | CT_enum _ | CT_rounding_mode ) as ctyp -> + ctyp and fix_cval = function | V_id (id, ctyp) -> V_id (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) - | V_ctor_kind (cval, ctor, ctyp) -> - V_ctor_kind (fix_cval cval, ctor, ctyp) - | V_ctor_unwrap (cval, ctor, ctyp) -> - V_ctor_unwrap (fix_cval cval, ctor, ctyp) + | V_ctor_kind (cval, ctor, ctyp) -> V_ctor_kind (fix_cval cval, ctor, ctyp) + | V_ctor_unwrap (cval, ctor, ctyp) -> V_ctor_unwrap (fix_cval cval, ctor, ctyp) | V_tuple_member (cval, _, n) -> - let ctyp = fix_tuples (cval_ctyp cval) in - let cval = fix_cval cval in - let field = match ctyp with - | CT_struct (id, _) -> - mk_id (string_of_id id ^ string_of_int n) - | _ -> assert false - in - V_field (cval, field) - | V_call (op, cvals) -> - V_call (op, List.map (fix_cval) cvals) - | V_field (cval, field) -> - V_field (fix_cval cval, field) - | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> id, fix_cval cval) fields, ctyp) - | V_tuple (members, ctyp) -> - begin match ctyp with - | CT_tup ctyps -> - let ctyps = List.map fix_tuples ctyps in - let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in - let struct_ctyp = CT_struct (mk_id name, List.mapi (fun n ctyp -> mk_id (name ^ string_of_int n), ctyp) ctyps) in - V_struct (List.mapi (fun n member -> mk_id (name ^ string_of_int n), fix_cval member) members, struct_ctyp) - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Tuple without tuple type" - end + let ctyp = fix_tuples (cval_ctyp cval) in + let cval = fix_cval cval in + let field = + match ctyp with CT_struct (id, _) -> mk_id (string_of_id id ^ string_of_int n) | _ -> assert false + in + V_field (cval, field) + | V_call (op, cvals) -> V_call (op, List.map fix_cval cvals) + | V_field (cval, field) -> V_field (fix_cval cval, field) + | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> (id, fix_cval cval)) fields, ctyp) + | V_tuple (members, ctyp) -> begin + match ctyp with + | CT_tup ctyps -> + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + let struct_ctyp = + CT_struct (mk_id name, List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), ctyp)) ctyps) + in + V_struct (List.mapi (fun n member -> (mk_id (name ^ string_of_int n), fix_cval member)) members, struct_ctyp) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Tuple without tuple type" + end in let rec fix_clexp = function | CL_id (id, ctyp) -> CL_id (id, ctyp) | CL_addr clexp -> CL_addr (fix_clexp clexp) | CL_tuple (clexp, n) -> - let ctyp = fix_tuples (clexp_ctyp clexp) in - let clexp = fix_clexp clexp in - let field = match ctyp with - | CT_struct (id, _) -> - mk_id (string_of_id id ^ string_of_int n) - | _ -> assert false - in - CL_field (clexp, field) + let ctyp = fix_tuples (clexp_ctyp clexp) in + let clexp = fix_clexp clexp in + let field = + match ctyp with CT_struct (id, _) -> mk_id (string_of_id id ^ string_of_int n) | _ -> assert false + in + CL_field (clexp, field) | CL_field (clexp, field) -> CL_field (fix_clexp clexp, field) | CL_void -> CL_void | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, ctyp) in let rec fix_instr_aux = function - | I_funcall (clexp, extern, id, args) -> - I_funcall (fix_clexp clexp, extern, id, List.map fix_cval args) + | I_funcall (clexp, extern, id, args) -> I_funcall (fix_clexp clexp, extern, id, List.map fix_cval args) | I_copy (clexp, cval) -> I_copy (fix_clexp clexp, fix_cval cval) | I_init (ctyp, id, cval) -> I_init (ctyp, id, fix_cval cval) | I_reinit (ctyp, id, cval) -> I_reinit (ctyp, id, fix_cval cval) @@ -524,66 +460,69 @@ let remove_tuples cdefs ctx = | I_throw cval -> I_throw (fix_cval cval) | I_return cval -> I_return (fix_cval cval) | I_if (cval, then_instrs, else_instrs, ctyp) -> - I_if (fix_cval cval, List.map fix_instr then_instrs, List.map fix_instr else_instrs, ctyp) + I_if (fix_cval cval, List.map fix_instr then_instrs, List.map fix_instr else_instrs, ctyp) | I_block instrs -> I_block (List.map fix_instr instrs) | I_try_block instrs -> I_try_block (List.map fix_instr instrs) - | (I_goto _ | I_label _ | I_decl _ | I_clear _ | I_end _ | I_comment _ - | I_reset _ | I_undefined _ | I_exit _ | I_raw _) as instr -> instr - and fix_instr (I_aux (instr, aux)) = I_aux (fix_instr_aux instr, aux) - in + | ( I_goto _ | I_label _ | I_decl _ | I_clear _ | I_end _ | I_comment _ | I_reset _ | I_undefined _ | I_exit _ + | I_raw _ ) as instr -> + instr + and fix_instr (I_aux (instr, aux)) = I_aux (fix_instr_aux instr, aux) in let fix_conversions = function - | I_aux (I_copy (clexp, cval), (_, l)) as instr -> - begin match clexp_ctyp clexp, cval_ctyp cval with - | CT_tup lhs_ctyps, CT_tup rhs_ctyps when List.length lhs_ctyps = List.length rhs_ctyps -> - let elems = List.length lhs_ctyps in - if List.for_all2 ctyp_equal lhs_ctyps rhs_ctyps then - [instr] - else - List.mapi (fun n _ -> icopy l (CL_tuple (clexp, n)) (V_tuple_member (cval, elems, n))) lhs_ctyps - | _ -> [instr] - end + | I_aux (I_copy (clexp, cval), (_, l)) as instr -> begin + match (clexp_ctyp clexp, cval_ctyp cval) with + | CT_tup lhs_ctyps, CT_tup rhs_ctyps when List.length lhs_ctyps = List.length rhs_ctyps -> + let elems = List.length lhs_ctyps in + if List.for_all2 ctyp_equal lhs_ctyps rhs_ctyps then [instr] + else List.mapi (fun n _ -> icopy l (CL_tuple (clexp, n)) (V_tuple_member (cval, elems, n))) lhs_ctyps + | _ -> [instr] + end | instr -> [instr] in let fix_ctx ctx = - { ctx with - records = Bindings.map (fun (params, fields) -> params, Bindings.map fix_tuples fields) ctx.records; - variants = Bindings.map (fun (params, ctors) -> params, Bindings.map fix_tuples ctors) ctx.variants; - valspecs = Bindings.map (fun (extern, ctyps, ctyp) -> extern, List.map fix_tuples ctyps, fix_tuples ctyp) ctx.valspecs; - locals = Bindings.map (fun (mut, ctyp) -> mut, fix_tuples ctyp) ctx.locals + { + ctx with + records = Bindings.map (fun (params, fields) -> (params, Bindings.map fix_tuples fields)) ctx.records; + variants = Bindings.map (fun (params, ctors) -> (params, Bindings.map fix_tuples ctors)) ctx.variants; + valspecs = + Bindings.map (fun (extern, ctyps, ctyp) -> (extern, List.map fix_tuples ctyps, fix_tuples ctyp)) ctx.valspecs; + locals = Bindings.map (fun (mut, ctyp) -> (mut, fix_tuples ctyp)) ctx.locals; } in let to_struct = function | CT_tup ctyps -> - let ctyps = List.map fix_tuples ctyps in - let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in - let fields = List.mapi (fun n ctyp -> mk_id (name ^ string_of_int n), ctyp) ctyps in - [CDEF_type (CTD_struct (mk_id name, fields)); - CDEF_pragma ("tuplestruct", Util.string_of_list " " (fun x -> x) (Util.zencode_string name :: List.map (fun (id, _) -> Util.zencode_string (string_of_id id)) fields))] + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + let fields = List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), ctyp)) ctyps in + [ + CDEF_type (CTD_struct (mk_id name, fields)); + CDEF_pragma + ( "tuplestruct", + Util.string_of_list " " + (fun x -> x) + (Util.zencode_string name :: List.map (fun (id, _) -> Util.zencode_string (string_of_id id)) fields) + ); + ] | _ -> assert false in let rec go acc = function | cdef :: cdefs -> - let tuples = CTSet.fold (fun ctyp -> CTSet.union (all_tuples ctyp)) (cdef_ctyps cdef) CTSet.empty in - let tuples = CTSet.diff tuples !already_removed in - (* In the case where we have ((x, y), z) and (x, y) we need to - generate (x, y) first, so we sort by the depth of nesting in - the tuples (note we build acc in reverse order) *) - let sorted_tuples = - CTSet.elements tuples - |> List.map (fun ctyp -> tuple_depth ctyp, ctyp) - |> List.sort (fun (d1, _) (d2, _) -> compare d2 d1) - |> List.map snd - in - let structs = List.concat (List.map to_struct sorted_tuples) in - already_removed := CTSet.union tuples !already_removed; - let cdef = - cdef - |> cdef_concatmap_instr fix_conversions - |> cdef_map_instr fix_instr - |> cdef_map_ctyp fix_tuples - in - go (cdef :: structs @ acc) cdefs + let tuples = CTSet.fold (fun ctyp -> CTSet.union (all_tuples ctyp)) (cdef_ctyps cdef) CTSet.empty in + let tuples = CTSet.diff tuples !already_removed in + (* In the case where we have ((x, y), z) and (x, y) we need to + generate (x, y) first, so we sort by the depth of nesting in + the tuples (note we build acc in reverse order) *) + let sorted_tuples = + CTSet.elements tuples + |> List.map (fun ctyp -> (tuple_depth ctyp, ctyp)) + |> List.sort (fun (d1, _) (d2, _) -> compare d2 d1) + |> List.map snd + in + let structs = List.concat (List.map to_struct sorted_tuples) in + already_removed := CTSet.union tuples !already_removed; + let cdef = + cdef |> cdef_concatmap_instr fix_conversions |> cdef_map_instr fix_instr |> cdef_map_ctyp fix_tuples + in + go ((cdef :: structs) @ acc) cdefs | [] -> List.rev acc in - go [] cdefs, - fix_ctx ctx + (go [] cdefs, fix_ctx ctx) diff --git a/src/lib/jib_optimize.mli b/src/lib/jib_optimize.mli index f7cab1265..de8c0d884 100644 --- a/src/lib/jib_optimize.mli +++ b/src/lib/jib_optimize.mli @@ -76,6 +76,7 @@ val optimize_unit : instr list -> instr list (** Remove all instructions that can contain other nested instructions, prodcing a flat list of instructions. *) val flatten_instrs : instr list -> instr list + val flatten_cdef : cdef -> cdef val reset_flat_counter : unit -> unit diff --git a/src/lib/jib_util.ml b/src/lib/jib_util.ml index fc12ec1d7..be6c58f9d 100644 --- a/src/lib/jib_util.ml +++ b/src/lib/jib_util.ml @@ -91,79 +91,58 @@ let instr_number () = incr instr_counter; n -let idecl l ctyp id = - I_aux (I_decl (ctyp, id), (instr_number (), l)) +let idecl l ctyp id = I_aux (I_decl (ctyp, id), (instr_number (), l)) -let ireset l ctyp id = - I_aux (I_reset (ctyp, id), (instr_number (), l)) +let ireset l ctyp id = I_aux (I_reset (ctyp, id), (instr_number (), l)) -let iinit l ctyp id cval = - I_aux (I_init (ctyp, id, cval), (instr_number (), l)) +let iinit l ctyp id cval = I_aux (I_init (ctyp, id, cval), (instr_number (), l)) -let iif l cval then_instrs else_instrs ctyp = - I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (instr_number (), l)) +let iif l cval then_instrs else_instrs ctyp = I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (instr_number (), l)) -let ifuncall l clexp id cvals = - I_aux (I_funcall (clexp, false, id, cvals), (instr_number (), l)) +let ifuncall l clexp id cvals = I_aux (I_funcall (clexp, false, id, cvals), (instr_number (), l)) -let iextern l clexp id cvals = - I_aux (I_funcall (clexp, true, id, cvals), (instr_number (), l)) +let iextern l clexp id cvals = I_aux (I_funcall (clexp, true, id, cvals), (instr_number (), l)) -let icopy l clexp cval = - I_aux (I_copy (clexp, cval), (instr_number (), l)) +let icopy l clexp cval = I_aux (I_copy (clexp, cval), (instr_number (), l)) -let iclear ?loc:(l=Parse_ast.Unknown) ctyp id = - I_aux (I_clear (ctyp, id), (instr_number (), l)) +let iclear ?loc:(l = Parse_ast.Unknown) ctyp id = I_aux (I_clear (ctyp, id), (instr_number (), l)) -let ireturn ?loc:(l=Parse_ast.Unknown) cval = - I_aux (I_return cval, (instr_number (), l)) +let ireturn ?loc:(l = Parse_ast.Unknown) cval = I_aux (I_return cval, (instr_number (), l)) -let iend l = - I_aux (I_end (Return (-1)), (instr_number (), l)) +let iend l = I_aux (I_end (Return (-1)), (instr_number (), l)) -let iblock ?loc:(l=Parse_ast.Unknown) instrs = - I_aux (I_block instrs, (instr_number (), l)) +let iblock ?loc:(l = Parse_ast.Unknown) instrs = I_aux (I_block instrs, (instr_number (), l)) -let itry_block l instrs = - I_aux (I_try_block instrs, (instr_number (), l)) +let itry_block l instrs = I_aux (I_try_block instrs, (instr_number (), l)) -let ithrow l cval = - I_aux (I_throw cval, (instr_number (), l)) +let ithrow l cval = I_aux (I_throw cval, (instr_number (), l)) -let icomment ?loc:(l=Parse_ast.Unknown) str = - I_aux (I_comment str, (instr_number (), l)) +let icomment ?loc:(l = Parse_ast.Unknown) str = I_aux (I_comment str, (instr_number (), l)) -let ilabel ?loc:(l=Parse_ast.Unknown) label = - I_aux (I_label label, (instr_number (), l)) +let ilabel ?loc:(l = Parse_ast.Unknown) label = I_aux (I_label label, (instr_number (), l)) -let igoto ?loc:(l=Parse_ast.Unknown) label = - I_aux (I_goto label, (instr_number (), l)) +let igoto ?loc:(l = Parse_ast.Unknown) label = I_aux (I_goto label, (instr_number (), l)) -let iundefined ?loc:(l=Parse_ast.Unknown) ctyp = - I_aux (I_undefined ctyp, (instr_number (), l)) +let iundefined ?loc:(l = Parse_ast.Unknown) ctyp = I_aux (I_undefined ctyp, (instr_number (), l)) -let imatch_failure l = - I_aux (I_exit "match", (instr_number (), l)) +let imatch_failure l = I_aux (I_exit "match", (instr_number (), l)) -let iexit l = - I_aux (I_exit "explicit", (instr_number (), l)) - -let iraw ?loc:(l=Parse_ast.Unknown) str = - I_aux (I_raw str, (instr_number (), l)) +let iexit l = I_aux (I_exit "explicit", (instr_number (), l)) -let ijump l cval label = - I_aux (I_jump (cval, label), (instr_number (), l)) +let iraw ?loc:(l = Parse_ast.Unknown) str = I_aux (I_raw str, (instr_number (), l)) + +let ijump l cval label = I_aux (I_jump (cval, label), (instr_number (), l)) module Name = struct type t = name let compare id1 id2 = - match id1, id2 with + match (id1, id2) with | Name (x, n), Name (y, m) -> - let c1 = Id.compare x y in - if c1 = 0 then compare n m else c1 + let c1 = Id.compare x y in + if c1 = 0 then compare n m else c1 | Global (x, n), Global (y, m) -> - let c1 = Id.compare x y in - if c1 = 0 then compare n m else c1 + let c1 = Id.compare x y in + if c1 = 0 then compare n m else c1 | Have_exception n, Have_exception m -> compare n m | Current_exception n, Current_exception m -> compare n m | Return n, Return m -> compare n m @@ -179,8 +158,8 @@ module Name = struct | _, Throw_location _ -> -1 end -module NameSet = Set.Make(Name) -module NameMap = Map.Make(Name) +module NameSet = Set.Make (Name) +module NameMap = Map.Make (Name) let current_exception = Current_exception (-1) let have_exception = Have_exception (-1) @@ -200,9 +179,8 @@ let rec cval_rename from_id to_id = function | V_ctor_kind (f, ctor, ctyp) -> V_ctor_kind (cval_rename from_id to_id f, ctor, ctyp) | V_ctor_unwrap (f, ctor, ctyp) -> V_ctor_unwrap (cval_rename from_id to_id f, ctor, ctyp) | V_struct (fields, ctyp) -> - V_struct (List.map (fun (field, cval) -> field, cval_rename from_id to_id cval) fields, ctyp) - | V_tuple (members, ctyp) -> - V_tuple (List.map (cval_rename from_id to_id) members, ctyp) + V_struct (List.map (fun (field, cval) -> (field, cval_rename from_id to_id cval)) fields, ctyp) + | V_tuple (members, ctyp) -> V_tuple (List.map (cval_rename from_id to_id) members, ctyp) let rec map_cval g = function | V_id (id, ctyp) -> g (V_id (id, ctyp)) @@ -212,82 +190,60 @@ let rec map_cval g = function | V_tuple_member (f, len, n) -> g (V_tuple_member (map_cval g f, len, n)) | V_ctor_kind (f, ctor, ctyp) -> g (V_ctor_kind (map_cval g f, ctor, ctyp)) | V_ctor_unwrap (f, ctor, ctyp) -> g (V_ctor_unwrap (map_cval g f, ctor, ctyp)) - | V_struct (fields, ctyp) -> - g (V_struct (List.map (fun (field, cval) -> field, map_cval g cval) fields, ctyp)) - | V_tuple (members, ctyp) -> - g (V_tuple (List.map (map_cval g) members, ctyp)) - + | V_struct (fields, ctyp) -> g (V_struct (List.map (fun (field, cval) -> (field, map_cval g cval)) fields, ctyp)) + | V_tuple (members, ctyp) -> g (V_tuple (List.map (map_cval g) members, ctyp)) + let rec clexp_rename from_id to_id = function | CL_id (id, ctyp) when Name.compare id from_id = 0 -> CL_id (to_id, ctyp) | CL_id (id, ctyp) -> CL_id (id, ctyp) | CL_rmw (read, write, ctyp) -> - CL_rmw ((if Name.compare read from_id = 0 then to_id else read), - (if Name.compare write from_id = 0 then to_id else write), - ctyp) - | CL_field (clexp, field) -> - CL_field (clexp_rename from_id to_id clexp, field) - | CL_addr clexp -> - CL_addr (clexp_rename from_id to_id clexp) - | CL_tuple (clexp, n) -> - CL_tuple (clexp_rename from_id to_id clexp, n) + CL_rmw + ( (if Name.compare read from_id = 0 then to_id else read), + (if Name.compare write from_id = 0 then to_id else write), + ctyp + ) + | CL_field (clexp, field) -> CL_field (clexp_rename from_id to_id clexp, field) + | CL_addr clexp -> CL_addr (clexp_rename from_id to_id clexp) + | CL_tuple (clexp, n) -> CL_tuple (clexp_rename from_id to_id clexp, n) | CL_void -> CL_void let rec instr_rename from_id to_id (I_aux (instr, aux)) = - let instr = match instr with + let instr = + match instr with | I_decl (ctyp, id) when Name.compare id from_id = 0 -> I_decl (ctyp, to_id) | I_decl (ctyp, id) -> I_decl (ctyp, id) - - | I_init (ctyp, id, cval) when Name.compare id from_id = 0 -> - I_init (ctyp, to_id, cval_rename from_id to_id cval) - | I_init (ctyp, id, cval) -> - I_init (ctyp, id, cval_rename from_id to_id cval) - + | I_init (ctyp, id, cval) when Name.compare id from_id = 0 -> I_init (ctyp, to_id, cval_rename from_id to_id cval) + | I_init (ctyp, id, cval) -> I_init (ctyp, id, cval_rename from_id to_id cval) | I_if (cval, then_instrs, else_instrs, ctyp2) -> - I_if (cval_rename from_id to_id cval, - List.map (instr_rename from_id to_id) then_instrs, - List.map (instr_rename from_id to_id) else_instrs, - ctyp2) - + I_if + ( cval_rename from_id to_id cval, + List.map (instr_rename from_id to_id) then_instrs, + List.map (instr_rename from_id to_id) else_instrs, + ctyp2 + ) | I_jump (cval, label) -> I_jump (cval_rename from_id to_id cval, label) - | I_funcall (clexp, extern, id, args) -> - I_funcall (clexp_rename from_id to_id clexp, extern, id, List.map (cval_rename from_id to_id) args) - + I_funcall (clexp_rename from_id to_id clexp, extern, id, List.map (cval_rename from_id to_id) args) | I_copy (clexp, cval) -> I_copy (clexp_rename from_id to_id clexp, cval_rename from_id to_id cval) - | I_clear (ctyp, id) when Name.compare id from_id = 0 -> I_clear (ctyp, to_id) | I_clear (ctyp, id) -> I_clear (ctyp, id) - | I_return cval -> I_return (cval_rename from_id to_id cval) - | I_block instrs -> I_block (List.map (instr_rename from_id to_id) instrs) - | I_try_block instrs -> I_try_block (List.map (instr_rename from_id to_id) instrs) - | I_throw cval -> I_throw (cval_rename from_id to_id cval) - | I_comment str -> I_comment str - | I_raw str -> I_raw str - | I_label label -> I_label label - | I_goto label -> I_goto label - | I_undefined ctyp -> I_undefined ctyp - | I_exit cause -> I_exit cause - | I_end id when Name.compare id from_id = 0 -> I_end to_id | I_end id -> I_end id - | I_reset (ctyp, id) when Name.compare id from_id = 0 -> I_reset (ctyp, to_id) | I_reset (ctyp, id) -> I_reset (ctyp, id) - | I_reinit (ctyp, id, cval) when Name.compare id from_id = 0 -> - I_reinit (ctyp, to_id, cval_rename from_id to_id cval) - | I_reinit (ctyp, id, cval) -> - I_reinit (ctyp, id, cval_rename from_id to_id cval) + I_reinit (ctyp, to_id, cval_rename from_id to_id cval) + | I_reinit (ctyp, id, cval) -> I_reinit (ctyp, id, cval_rename from_id to_id cval) in I_aux (instr, aux) @@ -295,21 +251,16 @@ let rec instr_rename from_id to_id (I_aux (instr, aux)) = (* 1. Instruction pretty printer *) (**************************************************************************) -let string_of_name ?deref_current_exception:(dce=false) ?zencode:(zencode=true) = - let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in +let string_of_name ?deref_current_exception:(dce = false) ?(zencode = true) = + let ssa_num n = if n = -1 then "" else "/" ^ string_of_int n in function | Name (id, n) | Global (id, n) -> - (if zencode then Util.zencode_string (string_of_id id) else string_of_id id) ^ ssa_num n - | Have_exception n -> - "have_exception" ^ ssa_num n - | Return n -> - "return" ^ ssa_num n - | Current_exception n when dce -> - "(*current_exception)" ^ ssa_num n - | Current_exception n -> - "current_exception" ^ ssa_num n - | Throw_location n -> - "throw_location" ^ ssa_num n + (if zencode then Util.zencode_string (string_of_id id) else string_of_id id) ^ ssa_num n + | Have_exception n -> "have_exception" ^ ssa_num n + | Return n -> "return" ^ ssa_num n + | Current_exception n when dce -> "(*current_exception)" ^ ssa_num n + | Current_exception n -> "current_exception" ^ ssa_num n + | Throw_location n -> "throw_location" ^ ssa_num n let string_of_op = function | Bnot -> "@not" @@ -360,16 +311,16 @@ let rec string_of_ctyp = function | CT_string -> "%string" | CT_tup ctyps -> "(" ^ Util.string_of_list ", " string_of_ctyp ctyps ^ ")" | CT_struct (id, _fields) -> - "%struct " ^ Util.zencode_string (string_of_id id) - (* + "%struct " ^ Util.zencode_string (string_of_id id) + (* ^ "{" ^ Util.string_of_list ", " (fun ((id, _), ctyp) -> Util.zencode_string (string_of_id id) ^ " : " ^ string_of_ctyp ctyp) fields ^ "}" *) | CT_enum (id, _) -> "%enum " ^ Util.zencode_string (string_of_id id) | CT_variant (id, _ctors) -> - "%union " ^ Util.zencode_string (string_of_id id) - (* + "%union " ^ Util.zencode_string (string_of_id id) + (* ^ "{" ^ Util.string_of_list ", " (fun ((id, _), ctyp) -> Util.zencode_string (string_of_id id) ^ " : " ^ string_of_ctyp ctyp) ctors ^ "}" @@ -379,7 +330,7 @@ let rec string_of_ctyp = function | CT_list ctyp -> "%list(" ^ string_of_ctyp ctyp ^ ")" | CT_ref ctyp -> "&(" ^ string_of_ctyp ctyp ^ ")" | CT_poly kid -> string_of_kid kid - + and string_of_uid (id, ctyps) = match ctyps with | [] -> Util.zencode_string (string_of_id id) @@ -390,15 +341,13 @@ and string_of_uid (id, ctyps) = and full_string_of_ctyp = function | CT_tup ctyps -> "(" ^ Util.string_of_list ", " full_string_of_ctyp ctyps ^ ")" | CT_struct (id, ctors) -> - "struct " ^ string_of_id id - ^ "{" - ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors - ^ "}" + "struct " ^ string_of_id id ^ "{" + ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors + ^ "}" | CT_variant (id, ctors) -> - "union " ^ string_of_id id - ^ "{" - ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors - ^ "}" + "union " ^ string_of_id id ^ "{" + ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors + ^ "}" | CT_vector (true, ctyp) -> "vector(dec, " ^ full_string_of_ctyp ctyp ^ ")" | CT_vector (false, ctyp) -> "vector(inc, " ^ full_string_of_ctyp ctyp ^ ")" | CT_list ctyp -> "list(" ^ full_string_of_ctyp ctyp ^ ")" @@ -427,43 +376,38 @@ let rec string_of_cval = function | V_id (id, _) -> string_of_name id | V_lit (VL_undefined, ctyp) -> string_of_value VL_undefined ^ " : " ^ string_of_ctyp ctyp | V_lit (vl, ctyp) -> string_of_value vl - | V_call (op, cvals) -> - Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) - | V_field (f, field) -> - Printf.sprintf "%s.%s" (string_of_cval f) (Util.zencode_string (string_of_id field)) - | V_tuple_member (f, _, n) -> - Printf.sprintf "%s.ztup%d" (string_of_cval f) n - | V_ctor_kind (f, ctor, _) -> - string_of_cval f ^ " is " ^ string_of_uid ctor - | V_ctor_unwrap (f, ctor, _) -> - string_of_cval f ^ " as " ^ string_of_uid ctor - | V_struct (fields, ctyp) -> - begin match ctyp with - | CT_struct (id, _) -> - Printf.sprintf "struct %s {%s}" - (Util.zencode_string (string_of_id id)) - (Util.string_of_list ", " (fun (field, cval) -> Util.zencode_string (string_of_id field) ^ " = " ^ string_of_cval cval) fields) - | _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Struct without struct type found" - end - | V_tuple (members, _) -> - "(" ^ Util.string_of_list ", " string_of_cval members ^ ")" + | V_call (op, cvals) -> Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) + | V_field (f, field) -> Printf.sprintf "%s.%s" (string_of_cval f) (Util.zencode_string (string_of_id field)) + | V_tuple_member (f, _, n) -> Printf.sprintf "%s.ztup%d" (string_of_cval f) n + | V_ctor_kind (f, ctor, _) -> string_of_cval f ^ " is " ^ string_of_uid ctor + | V_ctor_unwrap (f, ctor, _) -> string_of_cval f ^ " as " ^ string_of_uid ctor + | V_struct (fields, ctyp) -> begin + match ctyp with + | CT_struct (id, _) -> + Printf.sprintf "struct %s {%s}" + (Util.zencode_string (string_of_id id)) + (Util.string_of_list ", " + (fun (field, cval) -> Util.zencode_string (string_of_id field) ^ " = " ^ string_of_cval cval) + fields + ) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Struct without struct type found" + end + | V_tuple (members, _) -> "(" ^ Util.string_of_list ", " string_of_cval members ^ ")" let rec map_ctyp f = function - | (CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_float _ | CT_rounding_mode - | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _) as ctyp -> f ctyp + | ( CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_float _ | CT_rounding_mode + | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly _ | CT_enum _ ) as ctyp -> + f ctyp | CT_tup ctyps -> f (CT_tup (List.map (map_ctyp f) ctyps)) | CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp)) | CT_vector (direction, ctyp) -> f (CT_vector (direction, map_ctyp f ctyp)) | CT_fvector (n, direction, ctyp) -> f (CT_fvector (n, direction, map_ctyp f ctyp)) | CT_list ctyp -> f (CT_list (map_ctyp f ctyp)) - | CT_struct (id, fields) -> - f (CT_struct (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) fields)) - | CT_variant (id, ctors) -> - f (CT_variant (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) ctors)) + | CT_struct (id, fields) -> f (CT_struct (id, List.map (fun (id, ctyp) -> (id, map_ctyp f ctyp)) fields)) + | CT_variant (id, ctors) -> f (CT_variant (id, List.map (fun (id, ctyp) -> (id, map_ctyp f ctyp)) ctors)) let rec ctyp_equal ctyp1 ctyp2 = - match ctyp1, ctyp2 with + match (ctyp1, ctyp2) with | CT_lint, CT_lint -> true | CT_lbits d1, CT_lbits d2 -> d1 = d2 | CT_sbits (m1, d1), CT_sbits (m2, d2) -> m1 = m2 && d1 = d2 @@ -478,8 +422,7 @@ let rec ctyp_equal ctyp1 ctyp2 = | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0 | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0 - | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> - List.for_all2 ctyp_equal ctyps1 ctyps2 + | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> List.for_all2 ctyp_equal ctyps1 ctyps2 | CT_string, CT_string -> true | CT_real, CT_real -> true | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2 @@ -491,77 +434,59 @@ let rec ctyp_equal ctyp1 ctyp2 = let rec ctyp_compare ctyp1 ctyp2 = let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in - match ctyp1, ctyp2 with + match (ctyp1, ctyp2) with | CT_lint, CT_lint -> 0 | CT_lint, _ -> 1 | _, CT_lint -> -1 - | CT_fint n, CT_fint m -> compare n m | CT_fint _, _ -> 1 | _, CT_fint _ -> -1 - | CT_constant n, CT_constant m -> Big_int.compare n m | CT_constant _, _ -> 1 | _, CT_constant _ -> -1 - | CT_fbits (n, ord1), CT_fbits (m, ord2) -> lex_ord (compare n m) (compare ord1 ord2) | CT_fbits _, _ -> 1 | _, CT_fbits _ -> -1 - | CT_sbits (n, ord1), CT_sbits (m, ord2) -> lex_ord (compare n m) (compare ord1 ord2) | CT_sbits _, _ -> 1 | _, CT_sbits _ -> -1 - - | CT_lbits ord1 , CT_lbits ord2 -> compare ord1 ord2 + | CT_lbits ord1, CT_lbits ord2 -> compare ord1 ord2 | CT_lbits _, _ -> 1 | _, CT_lbits _ -> -1 - | CT_bit, CT_bit -> 0 | CT_bit, _ -> 1 | _, CT_bit -> -1 - | CT_unit, CT_unit -> 0 | CT_unit, _ -> 1 | _, CT_unit -> -1 - | CT_real, CT_real -> 0 | CT_real, _ -> 1 | _, CT_real -> -1 - | CT_float n, CT_float m -> compare n m | CT_float _, _ -> 1 | _, CT_float _ -> -1 - | CT_poly kid1, CT_poly kid2 -> Kid.compare kid1 kid2 | CT_poly _, _ -> 1 | _, CT_poly _ -> -1 - | CT_bool, CT_bool -> 0 | CT_bool, _ -> 1 | _, CT_bool -> -1 - | CT_string, CT_string -> 0 | CT_string, _ -> 1 | _, CT_string -> -1 - | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_compare ctyp1 ctyp2 | CT_ref _, _ -> 1 | _, CT_ref _ -> -1 - | CT_list ctyp1, CT_list ctyp2 -> ctyp_compare ctyp1 ctyp2 | CT_list _, _ -> 1 | _, CT_list _ -> -1 - - | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> - lex_ord (ctyp_compare ctyp1 ctyp2) (compare d1 d2) + | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> lex_ord (ctyp_compare ctyp1 ctyp2) (compare d1 d2) | CT_vector _, _ -> 1 | _, CT_vector _ -> -1 - | CT_fvector (n1, d1, ctyp1), CT_fvector (n2, d2, ctyp2) -> - lex_ord (compare n1 n2) (lex_ord (ctyp_compare ctyp1 ctyp2) (compare d1 d2)) + lex_ord (compare n1 n2) (lex_ord (ctyp_compare ctyp1 ctyp2) (compare d1 d2)) | CT_fvector _, _ -> 1 | _, CT_fvector _ -> -1 - | ctyp1, ctyp2 -> String.compare (full_string_of_ctyp ctyp1) (full_string_of_ctyp ctyp2) module CT = struct @@ -573,10 +498,10 @@ module CTList = struct type t = ctyp list let compare ctyps1 ctyps2 = Util.compare_list ctyp_compare ctyps1 ctyps2 end - -module CTSet = Set.Make(CT) -module CTMap = Map.Make(CT) -module CTListSet = Set.Make(CTList) + +module CTSet = Set.Make (CT) +module CTMap = Map.Make (CT) +module CTListSet = Set.Make (CTList) let rec ctyp_vars = function | CT_poly kid -> KidSet.singleton kid @@ -613,47 +538,31 @@ let rec ctyp_suprema = function | CT_list ctyp -> CT_list (ctyp_suprema ctyp) | CT_ref ctyp -> CT_ref (ctyp_suprema ctyp) | CT_poly kid -> CT_poly kid - + let merge_unifiers kid ctyp1 ctyp2 = - if ctyp_equal ctyp1 ctyp2 then - Some ctyp2 - else if ctyp_equal (ctyp_suprema ctyp1) (ctyp_suprema ctyp2) then - Some (ctyp_suprema ctyp2) + if ctyp_equal ctyp1 ctyp2 then Some ctyp2 + else if ctyp_equal (ctyp_suprema ctyp1) (ctyp_suprema ctyp2) then Some (ctyp_suprema ctyp2) else Reporting.unreachable (kid_loc kid) __POS__ ("Invalid unifiers in IR " ^ string_of_ctyp ctyp1 ^ " and " ^ string_of_ctyp ctyp2 ^ " for " ^ string_of_kid kid) let rec ctyp_unify l ctyp1 ctyp2 = - match ctyp1, ctyp2 with + match (ctyp1, ctyp2) with | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> - List.fold_left (KBindings.union merge_unifiers) KBindings.empty (List.map2 (ctyp_unify l) ctyps1 ctyps2) - - | CT_vector (b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 -> - ctyp_unify l ctyp1 ctyp2 - - | CT_vector (b1, ctyp1), CT_fvector (_, b2, ctyp2) when b1 = b2 -> - ctyp_unify l ctyp1 ctyp2 - - | CT_fvector (_, b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 -> - ctyp_unify l ctyp1 ctyp2 - - | CT_fvector (n1, b1, ctyp1), CT_fvector (n2, b2, ctyp2) when b1 = b2 -> - ctyp_unify l ctyp1 ctyp2 - + List.fold_left (KBindings.union merge_unifiers) KBindings.empty (List.map2 (ctyp_unify l) ctyps1 ctyps2) + | CT_vector (b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 -> ctyp_unify l ctyp1 ctyp2 + | CT_vector (b1, ctyp1), CT_fvector (_, b2, ctyp2) when b1 = b2 -> ctyp_unify l ctyp1 ctyp2 + | CT_fvector (_, b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 -> ctyp_unify l ctyp1 ctyp2 + | CT_fvector (n1, b1, ctyp1), CT_fvector (n2, b2, ctyp2) when b1 = b2 -> ctyp_unify l ctyp1 ctyp2 | CT_list ctyp1, CT_list ctyp2 -> ctyp_unify l ctyp1 ctyp2 - - | CT_struct (id1, fields1), CT_struct (id2, fields2) - when List.length fields1 == List.length fields2 -> - List.fold_left (KBindings.union merge_unifiers) KBindings.empty (List.map2 (ctyp_unify l) (List.map snd fields1) (List.map snd fields2)) - - | CT_variant (id1, ctors1), CT_variant (id2, ctors2) - when List.length ctors1 == List.length ctors2 -> - List.fold_left (KBindings.union merge_unifiers) KBindings.empty (List.map2 (ctyp_unify l) (List.map snd ctors1) (List.map snd ctors2)) - + | CT_struct (id1, fields1), CT_struct (id2, fields2) when List.length fields1 == List.length fields2 -> + List.fold_left (KBindings.union merge_unifiers) KBindings.empty + (List.map2 (ctyp_unify l) (List.map snd fields1) (List.map snd fields2)) + | CT_variant (id1, ctors1), CT_variant (id2, ctors2) when List.length ctors1 == List.length ctors2 -> + List.fold_left (KBindings.union merge_unifiers) KBindings.empty + (List.map2 (ctyp_unify l) (List.map snd ctors1) (List.map snd ctors2)) | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify l ctyp1 ctyp2 - | CT_poly kid, _ -> KBindings.singleton kid ctyp2 - | _, _ when ctyp_equal ctyp1 ctyp2 -> KBindings.empty | CT_lbits _, CT_fbits _ -> KBindings.empty | CT_lbits _, CT_sbits _ -> KBindings.empty @@ -668,38 +577,36 @@ let rec ctyp_unify l ctyp1 ctyp2 = | CT_fint _, CT_constant _ -> KBindings.empty | CT_constant _, CT_fint _ -> KBindings.empty | _, _ -> - Reporting.unreachable l __POS__ ("Invalid ctyp unifiers " ^ full_string_of_ctyp ctyp1 ^ " and " ^ full_string_of_ctyp ctyp2) - + Reporting.unreachable l __POS__ + ("Invalid ctyp unifiers " ^ full_string_of_ctyp ctyp1 ^ " and " ^ full_string_of_ctyp ctyp2) + let rec ctyp_ids = function | CT_enum (id, _) -> IdSet.singleton id | CT_struct (id, ctors) | CT_variant (id, ctors) -> - IdSet.add id (List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors) + IdSet.add id (List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors) | CT_tup ctyps -> List.fold_left (fun ids ctyp -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctyps | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp - | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit - | CT_bool | CT_real | CT_bit | CT_string | CT_poly _ | CT_float _ | CT_rounding_mode -> IdSet.empty + | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit + | CT_string | CT_poly _ | CT_float _ | CT_rounding_mode -> + IdSet.empty let rec subst_poly substs = function - | CT_poly kid -> - begin match KBindings.find_opt kid substs with - | Some ctyp -> ctyp - | None -> CT_poly kid - end + | CT_poly kid -> begin match KBindings.find_opt kid substs with Some ctyp -> ctyp | None -> CT_poly kid end | CT_tup ctyps -> CT_tup (List.map (subst_poly substs) ctyps) | CT_list ctyp -> CT_list (subst_poly substs ctyp) | CT_vector (direction, ctyp) -> CT_vector (direction, subst_poly substs ctyp) | CT_fvector (n, direction, ctyp) -> CT_fvector (n, direction, subst_poly substs ctyp) | CT_ref ctyp -> CT_ref (subst_poly substs ctyp) - | CT_variant (id, ctors) -> - CT_variant (id, List.map (fun (ctor_id, ctyp) -> ctor_id, subst_poly substs ctyp) ctors) - | CT_struct (id, fields) -> - CT_struct (id, List.map (fun (ctor_id, ctyp) -> ctor_id, subst_poly substs ctyp) fields) - | (CT_lint | CT_fint _ | CT_constant _ | CT_unit | CT_bool | CT_bit | CT_string | CT_real - | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_enum _ | CT_float _ | CT_rounding_mode as ctyp) -> ctyp - + | CT_variant (id, ctors) -> CT_variant (id, List.map (fun (ctor_id, ctyp) -> (ctor_id, subst_poly substs ctyp)) ctors) + | CT_struct (id, fields) -> CT_struct (id, List.map (fun (ctor_id, ctyp) -> (ctor_id, subst_poly substs ctyp)) fields) + | ( CT_lint | CT_fint _ | CT_constant _ | CT_unit | CT_bool | CT_bit | CT_string | CT_real | CT_lbits _ | CT_fbits _ + | CT_sbits _ | CT_enum _ | CT_float _ | CT_rounding_mode ) as ctyp -> + ctyp + let rec is_polymorphic = function - | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ - | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_float _ | CT_rounding_mode -> false + | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real + | CT_string | CT_float _ | CT_rounding_mode -> + false | CT_tup ctyps -> List.exists is_polymorphic ctyps | CT_enum _ -> false | CT_struct (_, ctors) | CT_variant (_, ctors) -> List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors @@ -716,35 +623,35 @@ let rec cval_deps = function | V_struct (fields, _) -> List.fold_left (fun ns (_, cval) -> NameSet.union ns (cval_deps cval)) NameSet.empty fields let rec clexp_deps = function - | CL_id (id, _) -> NameSet.empty, NameSet.singleton id - | CL_rmw (read, write, _) -> NameSet.singleton read, NameSet.singleton write + | CL_id (id, _) -> (NameSet.empty, NameSet.singleton id) + | CL_rmw (read, write, _) -> (NameSet.singleton read, NameSet.singleton write) | CL_field (clexp, _) -> clexp_deps clexp | CL_tuple (clexp, _) -> clexp_deps clexp | CL_addr clexp -> clexp_deps clexp - | CL_void -> NameSet.empty, NameSet.empty + | CL_void -> (NameSet.empty, NameSet.empty) (* Return the direct, read/write dependencies of a single instruction *) let instr_deps = function - | I_decl (_, id) -> NameSet.empty, NameSet.singleton id - | I_reset (_, id) -> NameSet.empty, NameSet.singleton id - | I_init (_, id, cval) | I_reinit (_, id, cval) -> cval_deps cval, NameSet.singleton id - | I_if (cval, _, _, _) -> cval_deps cval, NameSet.empty - | I_jump (cval, _) -> cval_deps cval, NameSet.empty + | I_decl (_, id) -> (NameSet.empty, NameSet.singleton id) + | I_reset (_, id) -> (NameSet.empty, NameSet.singleton id) + | I_init (_, id, cval) | I_reinit (_, id, cval) -> (cval_deps cval, NameSet.singleton id) + | I_if (cval, _, _, _) -> (cval_deps cval, NameSet.empty) + | I_jump (cval, _) -> (cval_deps cval, NameSet.empty) | I_funcall (clexp, _, _, cvals) -> - let reads, writes = clexp_deps clexp in - List.fold_left NameSet.union reads (List.map cval_deps cvals), writes + let reads, writes = clexp_deps clexp in + (List.fold_left NameSet.union reads (List.map cval_deps cvals), writes) | I_copy (clexp, cval) -> - let reads, writes = clexp_deps clexp in - NameSet.union reads (cval_deps cval), writes - | I_clear (_, id) -> NameSet.singleton id, NameSet.empty - | I_throw cval | I_return cval -> cval_deps cval, NameSet.empty - | I_block _ | I_try_block _ -> NameSet.empty, NameSet.empty - | I_comment _ | I_raw _ -> NameSet.empty, NameSet.empty - | I_label label -> NameSet.empty, NameSet.empty - | I_goto label -> NameSet.empty, NameSet.empty - | I_undefined _ -> NameSet.empty, NameSet.empty - | I_exit _ -> NameSet.empty, NameSet.empty - | I_end id -> NameSet.singleton id, NameSet.empty + let reads, writes = clexp_deps clexp in + (NameSet.union reads (cval_deps cval), writes) + | I_clear (_, id) -> (NameSet.singleton id, NameSet.empty) + | I_throw cval | I_return cval -> (cval_deps cval, NameSet.empty) + | I_block _ | I_try_block _ -> (NameSet.empty, NameSet.empty) + | I_comment _ | I_raw _ -> (NameSet.empty, NameSet.empty) + | I_label label -> (NameSet.empty, NameSet.empty) + | I_goto label -> (NameSet.empty, NameSet.empty) + | I_undefined _ -> (NameSet.empty, NameSet.empty) + | I_exit _ -> (NameSet.empty, NameSet.empty) + | I_end id -> (NameSet.singleton id, NameSet.empty) module NameCT = struct type t = name * ctyp @@ -753,8 +660,8 @@ module NameCT = struct if c = 0 then CT.compare ctyp1 ctyp2 else c end -module NameCTSet = Set.Make(NameCT) -module NameCTMap = Map.Make(NameCT) +module NameCTSet = Set.Make (NameCT) +module NameCTMap = Map.Make (NameCT) let rec clexp_typed_writes = function | CL_id (id, ctyp) -> NameCTSet.singleton (id, ctyp) @@ -782,30 +689,29 @@ let rec map_clexp_ctyp f = function let rec map_cval_ctyp f = function | V_id (id, ctyp) -> V_id (id, f ctyp) | V_lit (vl, ctyp) -> V_lit (vl, f ctyp) - | V_ctor_kind (cval, (id, unifiers), ctyp) -> - V_ctor_kind (map_cval_ctyp f cval, (id, List.map f unifiers), f ctyp) - | V_ctor_unwrap (cval, (id, unifiers), ctyp) -> - V_ctor_unwrap (map_cval_ctyp f cval, (id, List.map f unifiers), f ctyp) - | V_tuple_member (cval, i, j) -> - V_tuple_member (map_cval_ctyp f cval, i, j) - | V_call (op, cvals) -> - V_call (op, List.map (map_cval_ctyp f) cvals) - | V_field (cval, id) -> - V_field (map_cval_ctyp f cval, id) - | V_struct (fields, ctyp) -> - V_struct (List.map (fun (id, cval) -> id, map_cval_ctyp f cval) fields, f ctyp) - | V_tuple (members, ctyp) -> - V_tuple (List.map (map_cval_ctyp f) members, f ctyp) + | V_ctor_kind (cval, (id, unifiers), ctyp) -> V_ctor_kind (map_cval_ctyp f cval, (id, List.map f unifiers), f ctyp) + | V_ctor_unwrap (cval, (id, unifiers), ctyp) -> V_ctor_unwrap (map_cval_ctyp f cval, (id, List.map f unifiers), f ctyp) + | V_tuple_member (cval, i, j) -> V_tuple_member (map_cval_ctyp f cval, i, j) + | V_call (op, cvals) -> V_call (op, List.map (map_cval_ctyp f) cvals) + | V_field (cval, id) -> V_field (map_cval_ctyp f cval, id) + | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> (id, map_cval_ctyp f cval)) fields, f ctyp) + | V_tuple (members, ctyp) -> V_tuple (List.map (map_cval_ctyp f) members, f ctyp) let rec map_instr_ctyp f (I_aux (instr, aux)) = - let instr = match instr with + let instr = + match instr with | I_decl (ctyp, id) -> I_decl (f ctyp, id) | I_init (ctyp, id, cval) -> I_init (f ctyp, id, map_cval_ctyp f cval) | I_if (cval, then_instrs, else_instrs, ctyp) -> - I_if (map_cval_ctyp f cval, List.map (map_instr_ctyp f) then_instrs, List.map (map_instr_ctyp f) else_instrs, f ctyp) + I_if + ( map_cval_ctyp f cval, + List.map (map_instr_ctyp f) then_instrs, + List.map (map_instr_ctyp f) else_instrs, + f ctyp + ) | I_jump (cval, label) -> I_jump (map_cval_ctyp f cval, label) | I_funcall (clexp, extern, (id, ctyps), cvals) -> - I_funcall (map_clexp_ctyp f clexp, extern, (id, List.map f ctyps), List.map (map_cval_ctyp f) cvals) + I_funcall (map_clexp_ctyp f clexp, extern, (id, List.map f ctyps), List.map (map_cval_ctyp f) cvals) | I_copy (clexp, cval) -> I_copy (map_clexp_ctyp f clexp, map_cval_ctyp f cval) | I_clear (ctyp, id) -> I_clear (f ctyp, id) | I_return cval -> I_return (map_cval_ctyp f cval) @@ -821,13 +727,13 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) = I_aux (instr, aux) let rec map_instr_cval f (I_aux (instr, aux)) = - let instr = match instr with + let instr = + match instr with | I_init (ctyp, id, cval) -> I_init (ctyp, id, f cval) | I_if (cval, then_instrs, else_instrs, ctyp) -> - I_if (f cval, List.map (map_instr_cval f) then_instrs, List.map (map_instr_cval f) else_instrs, ctyp) + I_if (f cval, List.map (map_instr_cval f) then_instrs, List.map (map_instr_cval f) else_instrs, ctyp) | I_jump (cval, label) -> I_jump (f cval, label) - | I_funcall (clexp, extern, uid, cvals) -> - I_funcall (clexp, extern, uid, List.map f cvals) + | I_funcall (clexp, extern, uid, cvals) -> I_funcall (clexp, extern, uid, List.map f cvals) | I_copy (clexp, cval) -> I_copy (clexp, f cval) | I_return cval -> I_return (f cval) | I_block instrs -> I_block (List.map (map_instr_cval f) instrs) @@ -835,48 +741,52 @@ let rec map_instr_cval f (I_aux (instr, aux)) = | I_throw cval -> I_throw (f cval) | I_reinit (ctyp, id, cval) -> I_reinit (ctyp, id, f cval) | I_end id -> I_end id - | (I_undefined _ | I_reset _ | I_decl _ | I_clear _ | I_comment _ | I_raw _ | I_label _ | I_goto _ | I_exit _) as instr -> instr + | (I_undefined _ | I_reset _ | I_decl _ | I_clear _ | I_comment _ | I_raw _ | I_label _ | I_goto _ | I_exit _) as + instr -> + instr in I_aux (instr, aux) let rec map_instr f (I_aux (instr, aux)) = - let instr = match instr with - | I_decl _ | I_init _ | I_reset _ | I_reinit _ - | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> instr + let instr = + match instr with + | I_decl _ | I_init _ | I_reset _ | I_reinit _ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ + | I_return _ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> + instr | I_if (cval, instrs1, instrs2, ctyp) -> - I_if (cval, List.map (map_instr f) instrs1, List.map (map_instr f) instrs2, ctyp) - | I_block instrs -> - I_block (List.map (map_instr f) instrs) - | I_try_block instrs -> - I_try_block (List.map (map_instr f) instrs) + I_if (cval, List.map (map_instr f) instrs1, List.map (map_instr f) instrs2, ctyp) + | I_block instrs -> I_block (List.map (map_instr f) instrs) + | I_try_block instrs -> I_try_block (List.map (map_instr f) instrs) in f (I_aux (instr, aux)) let rec concatmap_instr f (I_aux (instr, aux)) = - let instr = match instr with - | I_decl _ | I_init _ | I_reset _ | I_reinit _ - | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> instr + let instr = + match instr with + | I_decl _ | I_init _ | I_reset _ | I_reinit _ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ + | I_return _ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> + instr | I_if (cval, instrs1, instrs2, ctyp) -> - I_if (cval, List.concat (List.map (concatmap_instr f) instrs1), List.concat (List.map (concatmap_instr f) instrs2), ctyp) - | I_block instrs -> - I_block (List.concat (List.map (concatmap_instr f) instrs)) - | I_try_block instrs -> - I_try_block (List.concat (List.map (concatmap_instr f) instrs)) + I_if + ( cval, + List.concat (List.map (concatmap_instr f) instrs1), + List.concat (List.map (concatmap_instr f) instrs2), + ctyp + ) + | I_block instrs -> I_block (List.concat (List.map (concatmap_instr f) instrs)) + | I_try_block instrs -> I_try_block (List.concat (List.map (concatmap_instr f) instrs)) in f (I_aux (instr, aux)) let rec iter_instr f (I_aux (instr, aux)) = match instr with - | I_decl _ | I_init _ | I_reset _ | I_reinit _ - | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> f (I_aux (instr, aux)) + | I_decl _ | I_init _ | I_reset _ | I_reinit _ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ + | I_return _ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> + f (I_aux (instr, aux)) | I_if (_, instrs1, instrs2, _) -> - List.iter (iter_instr f) instrs1; - List.iter (iter_instr f) instrs2 - | I_block instrs | I_try_block instrs -> - List.iter (iter_instr f) instrs + List.iter (iter_instr f) instrs1; + List.iter (iter_instr f) instrs2 + | I_block instrs | I_try_block instrs -> List.iter (iter_instr f) instrs let cdef_map_instr f = function | CDEF_register (id, ctyp, instrs) -> CDEF_register (id, ctyp, List.map (map_instr f) instrs) @@ -891,26 +801,24 @@ let cdef_map_instr f = function let rec map_funcall f instrs = match instrs with | [] -> [] - | (I_aux (I_funcall _, _) as funcall_instr)::tail -> begin + | (I_aux (I_funcall _, _) as funcall_instr) :: tail -> begin match tail with - | (I_aux (I_if (V_id (id, CT_bool), _, [], CT_unit), _) as exception_instr)::tail' - when Name.compare id have_exception == 0 -> - f funcall_instr [exception_instr] @ map_funcall f tail' - | _ -> - f funcall_instr [] @ map_funcall f tail + | (I_aux (I_if (V_id (id, CT_bool), _, [], CT_unit), _) as exception_instr) :: tail' + when Name.compare id have_exception == 0 -> + f funcall_instr [exception_instr] @ map_funcall f tail' + | _ -> f funcall_instr [] @ map_funcall f tail end - | (I_aux (instr, aux))::tail -> - let instr = match instr with - | I_decl _ | I_init _ | I_reset _ | I_reinit _ - | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> instr - | I_if (cval, instrs1, instrs2, ctyp) -> - I_if (cval, map_funcall f instrs1, map_funcall f instrs2, ctyp) - | I_block instrs -> - I_block (map_funcall f instrs) - | I_try_block instrs -> - I_try_block (map_funcall f instrs) - in (I_aux (instr, aux)) :: map_funcall f tail + | I_aux (instr, aux) :: tail -> + let instr = + match instr with + | I_decl _ | I_init _ | I_reset _ | I_reinit _ | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ + | I_return _ | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> + instr + | I_if (cval, instrs1, instrs2, ctyp) -> I_if (cval, map_funcall f instrs1, map_funcall f instrs2, ctyp) + | I_block instrs -> I_block (map_funcall f instrs) + | I_try_block instrs -> I_try_block (map_funcall f instrs) + in + I_aux (instr, aux) :: map_funcall f tail let cdef_map_funcall f = function | CDEF_register (id, ctyp, instrs) -> CDEF_register (id, ctyp, map_funcall f instrs) @@ -923,16 +831,12 @@ let cdef_map_funcall f = function | CDEF_pragma (name, str) -> CDEF_pragma (name, str) let cdef_concatmap_instr f = function - | CDEF_register (id, ctyp, instrs) -> - CDEF_register (id, ctyp, List.concat (List.map (concatmap_instr f) instrs)) - | CDEF_let (n, bindings, instrs) -> - CDEF_let (n, bindings, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_register (id, ctyp, instrs) -> CDEF_register (id, ctyp, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_let (n, bindings, instrs) -> CDEF_let (n, bindings, List.concat (List.map (concatmap_instr f) instrs)) | CDEF_fundef (id, heap_return, args, instrs) -> - CDEF_fundef (id, heap_return, args, List.concat (List.map (concatmap_instr f) instrs)) - | CDEF_startup (id, instrs) -> - CDEF_startup (id, List.concat (List.map (concatmap_instr f) instrs)) - | CDEF_finish (id, instrs) -> - CDEF_finish (id, List.concat (List.map (concatmap_instr f) instrs)) + CDEF_fundef (id, heap_return, args, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_startup (id, instrs) -> CDEF_startup (id, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_finish (id, instrs) -> CDEF_finish (id, List.concat (List.map (concatmap_instr f) instrs)) | CDEF_val (id, extern, ctyps, ctyp) -> CDEF_val (id, extern, ctyps, ctyp) | CDEF_type tdef -> CDEF_type tdef | CDEF_pragma (name, str) -> CDEF_pragma (name, str) @@ -946,7 +850,8 @@ let ctype_def_map_ctyp f = function let cdef_map_ctyp f = function | CDEF_register (id, ctyp, instrs) -> CDEF_register (id, f ctyp, List.map (map_instr_ctyp f) instrs) | CDEF_let (n, bindings, instrs) -> CDEF_let (n, bindings, List.map (map_instr_ctyp f) instrs) - | CDEF_fundef (id, heap_return, args, instrs) -> CDEF_fundef (id, heap_return, args, List.map (map_instr_ctyp f) instrs) + | CDEF_fundef (id, heap_return, args, instrs) -> + CDEF_fundef (id, heap_return, args, List.map (map_instr_ctyp f) instrs) | CDEF_startup (id, instrs) -> CDEF_startup (id, List.map (map_instr_ctyp f) instrs) | CDEF_finish (id, instrs) -> CDEF_finish (id, List.map (map_instr_ctyp f) instrs) | CDEF_val (id, extern, ctyps, ctyp) -> CDEF_val (id, extern, List.map f ctyps, f ctyp) @@ -954,39 +859,37 @@ let cdef_map_ctyp f = function | CDEF_pragma (name, str) -> CDEF_pragma (name, str) let cdef_map_cval f = cdef_map_instr (map_instr_cval f) - + (* Map over all sequences of instructions contained within an instruction *) let rec map_instrs f (I_aux (instr, aux)) = - let instr = match instr with + let instr = + match instr with | I_decl _ | I_init _ | I_reset _ | I_reinit _ -> instr | I_if (cval, instrs1, instrs2, ctyp) -> - I_if (cval, f (List.map (map_instrs f) instrs1), f (List.map (map_instrs f) instrs2), ctyp) + I_if (cval, f (List.map (map_instrs f) instrs1), f (List.map (map_instrs f) instrs2), ctyp) | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ -> instr | I_block instrs -> I_block (f (List.map (map_instrs f) instrs)) | I_try_block instrs -> I_try_block (f (List.map (map_instrs f) instrs)) - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> instr + | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_undefined _ | I_end _ -> instr in I_aux (instr, aux) -let map_instr_list f instrs = - List.map (map_instr f) instrs +let map_instr_list f instrs = List.map (map_instr f) instrs let instr_ids (I_aux (instr, _)) = let reads, writes = instr_deps instr in NameSet.union reads writes -let instr_reads (I_aux (instr, _)) = - fst (instr_deps instr) +let instr_reads (I_aux (instr, _)) = fst (instr_deps instr) -let instr_writes (I_aux (instr, _)) = - snd (instr_deps instr) +let instr_writes (I_aux (instr, _)) = snd (instr_deps instr) let rec filter_instrs f instrs = let filter_instrs' = function | I_aux (I_block instrs, aux) -> I_aux (I_block (filter_instrs f instrs), aux) | I_aux (I_try_block instrs, aux) -> I_aux (I_try_block (filter_instrs f instrs), aux) | I_aux (I_if (cval, instrs1, instrs2, ctyp), aux) -> - I_aux (I_if (cval, filter_instrs f instrs1, filter_instrs f instrs2, ctyp), aux) + I_aux (I_if (cval, filter_instrs f instrs1, filter_instrs f instrs2, ctyp), aux) | instr -> instr in List.filter f (List.map filter_instrs' instrs) @@ -1002,20 +905,20 @@ let label str = str let rec infer_call op vs = - match op, vs with + match (op, vs) with | Bnot, _ -> CT_bool | Band, _ -> CT_bool | Bor, _ -> CT_bool - | List_hd, [v] -> - begin match cval_ctyp v with - | CT_list ctyp -> ctyp - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid call to hd" - end - | List_tl, [v] -> - begin match cval_ctyp v with - | CT_list ctyp -> CT_list ctyp - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid call to tl" - end + | List_hd, [v] -> begin + match cval_ctyp v with + | CT_list ctyp -> ctyp + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid call to hd" + end + | List_tl, [v] -> begin + match cval_ctyp v with + | CT_list ctyp -> CT_list ctyp + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid call to tl" + end | (Eq | Neq), _ -> CT_bool | Bvnot, [v] -> cval_ctyp v | Bvaccess, _ -> CT_bit @@ -1023,65 +926,56 @@ let rec infer_call op vs = | (Ilt | Igt | Ilteq | Igteq), _ -> CT_bool | (Iadd | Isub), _ -> CT_fint 64 | (Unsigned n | Signed n), _ -> CT_fint n - | (Zero_extend n | Sign_extend n), [v] -> - begin match cval_ctyp v with - | CT_fbits (_, ord) | CT_sbits (_, ord) -> - CT_fbits (n, ord) - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for zero/sign_extend argument" - end - | Slice n, [vec; _] -> - begin match cval_ctyp vec with - | CT_fbits (_, ord) | CT_sbits (_, ord) -> - CT_fbits (n, ord) - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for extract argument" - end - | Sslice n, [vec; _; _] -> - begin match cval_ctyp vec with - | CT_fbits (_, ord) | CT_sbits (_, ord) -> - CT_sbits (n, ord) - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for extract argument" - end + | (Zero_extend n | Sign_extend n), [v] -> begin + match cval_ctyp v with + | CT_fbits (_, ord) | CT_sbits (_, ord) -> CT_fbits (n, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for zero/sign_extend argument" + end + | Slice n, [vec; _] -> begin + match cval_ctyp vec with + | CT_fbits (_, ord) | CT_sbits (_, ord) -> CT_fbits (n, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for extract argument" + end + | Sslice n, [vec; _; _] -> begin + match cval_ctyp vec with + | CT_fbits (_, ord) | CT_sbits (_, ord) -> CT_sbits (n, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for extract argument" + end | Set_slice, [vec; _; _] -> cval_ctyp vec - | Replicate n, [vec] -> - begin match cval_ctyp vec with - | CT_fbits (m, ord) -> CT_fbits (n * m, ord) - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for replicate argument" - end - | Concat, [v1; v2] -> - begin match cval_ctyp v1, cval_ctyp v2 with - | CT_fbits (n, ord), CT_fbits (m, _) -> - CT_fbits (n + m, ord) - | CT_fbits (n, ord), CT_sbits (m, _) -> - CT_sbits (m, ord) - | CT_sbits (n, ord), CT_fbits (m, _) -> - CT_sbits (n, ord) - | CT_sbits (n, ord), CT_sbits (m, _) -> - CT_sbits (max n m, ord) - | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for concat argument" - end - | _, _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ ("Invalid call to function " ^ string_of_op op) + | Replicate n, [vec] -> begin + match cval_ctyp vec with + | CT_fbits (m, ord) -> CT_fbits (n * m, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for replicate argument" + end + | Concat, [v1; v2] -> begin + match (cval_ctyp v1, cval_ctyp v2) with + | CT_fbits (n, ord), CT_fbits (m, _) -> CT_fbits (n + m, ord) + | CT_fbits (n, ord), CT_sbits (m, _) -> CT_sbits (m, ord) + | CT_sbits (n, ord), CT_fbits (m, _) -> CT_sbits (n, ord) + | CT_sbits (n, ord), CT_sbits (m, _) -> CT_sbits (max n m, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for concat argument" + end + | _, _ -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Invalid call to function " ^ string_of_op op) and cval_ctyp = function | V_id (_, ctyp) -> ctyp | V_lit (_, ctyp) -> ctyp | V_ctor_kind _ -> CT_bool | V_ctor_unwrap (_, _, ctyp) -> ctyp - | V_tuple_member (cval, _, n) -> - begin match cval_ctyp cval with - | CT_tup ctyps -> - List.nth ctyps n - | ctyp -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Invalid tuple type " ^ full_string_of_ctyp ctyp) - end - | V_field (cval, field) -> - begin match cval_ctyp cval with - | CT_struct (id, ctors) -> - begin - try snd (List.find (fun (id, ctyp) -> Id.compare id field = 0) ctors) with - | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ string_of_id field) + | V_tuple_member (cval, _, n) -> begin + match cval_ctyp cval with + | CT_tup ctyps -> List.nth ctyps n + | ctyp -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Invalid tuple type " ^ full_string_of_ctyp ctyp) + end + | V_field (cval, field) -> begin + match cval_ctyp cval with + | CT_struct (id, ctors) -> begin + try snd (List.find (fun (id, ctyp) -> Id.compare id field = 0) ctors) + with Not_found -> + failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ string_of_id field) end - | ctyp -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Inavlid type for V_field " ^ full_string_of_ctyp ctyp) - end + | ctyp -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Inavlid type for V_field " ^ full_string_of_ctyp ctyp) + end | V_struct (_, ctyp) -> ctyp | V_tuple (_, ctyp) -> ctyp | V_call (op, vs) -> infer_call op vs @@ -1089,53 +983,41 @@ and cval_ctyp = function let rec clexp_ctyp = function | CL_id (_, ctyp) -> ctyp | CL_rmw (_, _, ctyp) -> ctyp - | CL_field (clexp, field) -> - begin match clexp_ctyp clexp with - | CT_struct (id, ctors) -> - begin - try snd (List.find (fun (id, _) -> Id.compare id field = 0) ctors) with - | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a field " ^ string_of_id field) + | CL_field (clexp, field) -> begin + match clexp_ctyp clexp with + | CT_struct (id, ctors) -> begin + try snd (List.find (fun (id, _) -> Id.compare id field = 0) ctors) + with Not_found -> + failwith ("Struct type " ^ string_of_id id ^ " does not have a field " ^ string_of_id field) end - | ctyp -> failwith ("Bad ctyp for CL_field " ^ string_of_ctyp ctyp) - end - | CL_addr clexp -> - begin match clexp_ctyp clexp with - | CT_ref ctyp -> ctyp - | ctyp -> failwith ("Bad ctyp for CL_addr " ^ string_of_ctyp ctyp) - end - | CL_tuple (clexp, n) -> - begin match clexp_ctyp clexp with - | CT_tup typs -> - begin - try List.nth typs n with - | _ -> failwith "Tuple assignment index out of bounds" - end - | ctyp -> failwith ("Bad ctyp for CL_addr " ^ string_of_ctyp ctyp) - end + | ctyp -> failwith ("Bad ctyp for CL_field " ^ string_of_ctyp ctyp) + end + | CL_addr clexp -> begin + match clexp_ctyp clexp with + | CT_ref ctyp -> ctyp + | ctyp -> failwith ("Bad ctyp for CL_addr " ^ string_of_ctyp ctyp) + end + | CL_tuple (clexp, n) -> begin + match clexp_ctyp clexp with + | CT_tup typs -> begin try List.nth typs n with _ -> failwith "Tuple assignment index out of bounds" end + | ctyp -> failwith ("Bad ctyp for CL_addr " ^ string_of_ctyp ctyp) + end | CL_void -> CT_unit let rec instr_ctyps (I_aux (instr, aux)) = match instr with - | I_decl (ctyp, _) | I_reset (ctyp, _) | I_clear (ctyp, _) | I_undefined ctyp -> - CTSet.singleton ctyp - | I_init (ctyp, _, cval) | I_reinit (ctyp, _, cval) -> - CTSet.add ctyp (CTSet.singleton (cval_ctyp cval)) + | I_decl (ctyp, _) | I_reset (ctyp, _) | I_clear (ctyp, _) | I_undefined ctyp -> CTSet.singleton ctyp + | I_init (ctyp, _, cval) | I_reinit (ctyp, _, cval) -> CTSet.add ctyp (CTSet.singleton (cval_ctyp cval)) | I_if (cval, instrs1, instrs2, ctyp) -> - CTSet.union (instrs_ctyps instrs1) (instrs_ctyps instrs2) - |> CTSet.add (cval_ctyp cval) - |> CTSet.add ctyp + CTSet.union (instrs_ctyps instrs1) (instrs_ctyps instrs2) |> CTSet.add (cval_ctyp cval) |> CTSet.add ctyp | I_funcall (clexp, _, (_, ctyps), cvals) -> - List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map cval_ctyp cvals) - |> CTSet.union (CTSet.of_list ctyps) - |> CTSet.add (clexp_ctyp clexp) - | I_copy (clexp, cval) -> - CTSet.add (clexp_ctyp clexp) (CTSet.singleton (cval_ctyp cval)) - | I_block instrs | I_try_block instrs -> - instrs_ctyps instrs - | I_throw cval | I_jump (cval, _) | I_return cval -> - CTSet.singleton (cval_ctyp cval) - | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_end _ -> - CTSet.empty + List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map cval_ctyp cvals) + |> CTSet.union (CTSet.of_list ctyps) + |> CTSet.add (clexp_ctyp clexp) + | I_copy (clexp, cval) -> CTSet.add (clexp_ctyp clexp) (CTSet.singleton (cval_ctyp cval)) + | I_block instrs | I_try_block instrs -> instrs_ctyps instrs + | I_throw cval | I_jump (cval, _) | I_return cval -> CTSet.singleton (cval_ctyp cval) + | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_exit _ | I_end _ -> CTSet.empty and instrs_ctyps instrs = List.fold_left CTSet.union CTSet.empty (List.map instr_ctyps instrs) @@ -1145,17 +1027,13 @@ let ctype_def_ctyps = function | CTD_variant (_, ctors) -> List.map snd ctors let cdef_ctyps = function - | CDEF_register (_, ctyp, instrs) -> - CTSet.add ctyp (instrs_ctyps instrs) - | CDEF_val (_, _, ctyps, ctyp) -> - CTSet.add ctyp (List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty ctyps) - | CDEF_fundef (_, _, _, instrs) | CDEF_startup (_, instrs) | CDEF_finish (_, instrs) -> - instrs_ctyps instrs - | CDEF_type tdef -> - List.fold_right CTSet.add (ctype_def_ctyps tdef) CTSet.empty + | CDEF_register (_, ctyp, instrs) -> CTSet.add ctyp (instrs_ctyps instrs) + | CDEF_val (_, _, ctyps, ctyp) -> CTSet.add ctyp (List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty ctyps) + | CDEF_fundef (_, _, _, instrs) | CDEF_startup (_, instrs) | CDEF_finish (_, instrs) -> instrs_ctyps instrs + | CDEF_type tdef -> List.fold_right CTSet.add (ctype_def_ctyps tdef) CTSet.empty | CDEF_let (_, bindings, instrs) -> - List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map snd bindings) - |> CTSet.union (instrs_ctyps instrs) + List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map snd bindings) + |> CTSet.union (instrs_ctyps instrs) | CDEF_pragma (_, _) -> CTSet.empty let rec c_ast_registers = function @@ -1177,16 +1055,18 @@ let rec instrs_rename from_id to_id = let irename instrs = instrs_rename from_id to_id instrs in let lrename = clexp_rename from_id to_id in function - | (I_aux (I_decl (ctyp, new_id), _) :: _) as instrs when Name.compare from_id new_id = 0 -> instrs + | I_aux (I_decl (ctyp, new_id), _) :: _ as instrs when Name.compare from_id new_id = 0 -> instrs | I_aux (I_decl (ctyp, new_id), aux) :: instrs -> I_aux (I_decl (ctyp, new_id), aux) :: irename instrs | I_aux (I_reset (ctyp, id), aux) :: instrs -> I_aux (I_reset (ctyp, rename id), aux) :: irename instrs - | I_aux (I_init (ctyp, id, cval), aux) :: instrs -> I_aux (I_init (ctyp, rename id, crename cval), aux) :: irename instrs - | I_aux (I_reinit (ctyp, id, cval), aux) :: instrs -> I_aux (I_reinit (ctyp, rename id, crename cval), aux) :: irename instrs + | I_aux (I_init (ctyp, id, cval), aux) :: instrs -> + I_aux (I_init (ctyp, rename id, crename cval), aux) :: irename instrs + | I_aux (I_reinit (ctyp, id, cval), aux) :: instrs -> + I_aux (I_reinit (ctyp, rename id, crename cval), aux) :: irename instrs | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs -> - I_aux (I_if (crename cval, irename then_instrs, irename else_instrs, ctyp), aux) :: irename instrs + I_aux (I_if (crename cval, irename then_instrs, irename else_instrs, ctyp), aux) :: irename instrs | I_aux (I_jump (cval, label), aux) :: instrs -> I_aux (I_jump (crename cval, label), aux) :: irename instrs | I_aux (I_funcall (clexp, extern, function_id, cvals), aux) :: instrs -> - I_aux (I_funcall (lrename clexp, extern, function_id, List.map crename cvals), aux) :: irename instrs + I_aux (I_funcall (lrename clexp, extern, function_id, List.map crename cvals), aux) :: irename instrs | I_aux (I_copy (clexp, cval), aux) :: instrs -> I_aux (I_copy (lrename clexp, crename cval), aux) :: irename instrs | I_aux (I_clear (ctyp, id), aux) :: instrs -> I_aux (I_clear (ctyp, rename id), aux) :: irename instrs | I_aux (I_return cval, aux) :: instrs -> I_aux (I_return (crename cval), aux) :: irename instrs @@ -1194,5 +1074,6 @@ let rec instrs_rename from_id to_id = | I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (irename block), aux) :: irename instrs | I_aux (I_throw cval, aux) :: instrs -> I_aux (I_throw (crename cval), aux) :: irename instrs | I_aux (I_end id, aux) :: instrs -> I_aux (I_end (rename id), aux) :: irename instrs - | (I_aux ((I_comment _ | I_raw _ | I_label _ | I_goto _ | I_exit _ | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs + | (I_aux ((I_comment _ | I_raw _ | I_label _ | I_goto _ | I_exit _ | I_undefined _), _) as instr) :: instrs -> + instr :: irename instrs | [] -> [] diff --git a/src/lib/jib_util.mli b/src/lib/jib_util.mli index 183064145..dd4703960 100644 --- a/src/lib/jib_util.mli +++ b/src/lib/jib_util.mli @@ -72,7 +72,7 @@ open Ast_util open Jib (** {1 Instruction construction functions, and Jib names } *) - + (** Create a generator that produces fresh names, paired with a function that resets the generator (allowing it to regenerate the same name). *) @@ -102,7 +102,7 @@ val ijump : l -> cval -> string -> instr (** Create a new unique label by concatenating a string with a unique identifier *) val label : string -> string - + module Name : sig type t = name val compare : name -> name -> int @@ -123,7 +123,7 @@ val return : name val name : id -> name val global : id -> name - + val cval_rename : name -> name -> cval -> cval val clexp_rename : name -> name -> clexp -> clexp val instr_rename : name -> name -> instr -> instr @@ -137,7 +137,7 @@ val string_of_value : Value2.vl -> string val string_of_cval : cval -> string (** {1. Functions and modules for working with ctyps} *) - + val map_ctyp : (ctyp -> ctyp) -> ctyp -> ctyp val ctyp_equal : ctyp -> ctyp -> bool val ctyp_compare : ctyp -> ctyp -> int @@ -157,13 +157,13 @@ end module NameCTSet : sig include Set.S with type elt = name * ctyp end - + module NameCTMap : sig include Map.S with type key = name * ctyp end (** {2 Operations for polymorphic Jib ctyps} *) - + val ctyp_unify : l -> ctyp -> ctyp -> ctyp KBindings.t val merge_unifiers : kid -> ctyp -> ctyp -> ctyp option @@ -179,7 +179,7 @@ val is_polymorphic : ctyp -> bool val subst_poly : ctyp KBindings.t -> ctyp -> ctyp (** {2 Infer types} *) - + val cval_ctyp : cval -> ctyp val clexp_ctyp : clexp -> ctyp val cdef_ctyps : cdef -> CTSet.t @@ -189,11 +189,11 @@ val cdef_ctyps : cdef -> CTSet.t val instr_ids : instr -> NameSet.t val instr_reads : instr -> NameSet.t val instr_writes : instr -> NameSet.t - + val instr_typed_writes : instr -> NameCTSet.t - + val map_cval : (cval -> cval) -> cval -> cval - + (** Map over each instruction within an instruction, bottom-up *) val map_instr : (instr -> instr) -> instr -> instr @@ -201,7 +201,7 @@ val map_instrs : (instr list -> instr list) -> instr -> instr (** Concat-map over each instruction within an instruction, bottom-up *) val concatmap_instr : (instr -> instr list) -> instr -> instr list - + (** Iterate over each instruction within an instruction, bottom-up *) val iter_instr : (instr -> unit) -> instr -> unit @@ -216,19 +216,19 @@ val cdef_map_ctyp : (ctyp -> ctyp) -> cdef -> cdef val map_instr_cval : (cval -> cval) -> instr -> instr val map_instr_list : (instr -> instr) -> instr list -> instr list - + val filter_instrs : (instr -> bool) -> instr list -> instr list -val instr_split_at : (instr -> bool) -> instr list -> (instr list * instr list) +val instr_split_at : (instr -> bool) -> instr list -> instr list * instr list (** Map over function calls in an instruction sequence, including exception handler where present *) val map_funcall : (instr -> instr list -> instr list) -> instr list -> instr list - + (** Map over each function call in a cdef using map_funcall *) val cdef_map_funcall : (instr -> instr list -> instr list) -> cdef -> cdef val cdef_map_cval : (cval -> cval) -> cdef -> cdef - + (** Map over each instruction in a cdef using concatmap_instr *) val cdef_concatmap_instr : (instr -> instr list) -> cdef -> cdef diff --git a/src/lib/monomorphise.ml b/src/lib/monomorphise.ml index 5e8c321fa..8f27f639e 100644 --- a/src/lib/monomorphise.ml +++ b/src/lib/monomorphise.ml @@ -82,46 +82,32 @@ let opt_mwords = ref false (* From the command line we take vague file/line locations, but from the analysis we can use exact locations. *) -type split_loc = -| Line of string * int -| Exact of Parse_ast.l +type split_loc = Line of string * int | Exact of Parse_ast.l (* Returns the set of type variables that will appear in the Lem output, which may be smaller than those in the Sail type. May need to be updated with doc_typ_lem *) -let rec lem_nexps_of_typ (Typ_aux (t,l)) = +let rec lem_nexps_of_typ (Typ_aux (t, l)) = let trec = lem_nexps_of_typ in match t with | Typ_id _ -> NexpSet.empty | Typ_var kid -> NexpSet.singleton (orig_nexp (nvar kid)) - | Typ_fn (t1,t2) -> List.fold_left NexpSet.union (trec t2) (List.map trec t1) - | Typ_tuple ts -> - List.fold_left (fun s t -> NexpSet.union s (trec t)) - NexpSet.empty ts - | Typ_app(Id_aux (Id "bitvector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _)]) -> - let m = nexp_simp m in - if !opt_mwords && not (is_nexp_constant m) then - NexpSet.singleton (orig_nexp m) - else trec bit_typ - | Typ_app(Id_aux (Id "vector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _); - A_aux (A_typ elem_typ, _)]) -> - trec elem_typ - | Typ_app(Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> - trec etyp - | Typ_app(Id_aux (Id "range", _),_) - | Typ_app(Id_aux (Id "implicit", _),_) - | Typ_app(Id_aux (Id "atom", _), _) -> NexpSet.empty - | Typ_app (_,tas) -> - List.fold_left (fun s ta -> NexpSet.union s (lem_nexps_of_typ_arg ta)) - NexpSet.empty tas - | Typ_exist (kids,_,t) -> trec t + | Typ_fn (t1, t2) -> List.fold_left NexpSet.union (trec t2) (List.map trec t1) + | Typ_tuple ts -> List.fold_left (fun s t -> NexpSet.union s (trec t)) NexpSet.empty ts + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _)]) -> + let m = nexp_simp m in + if !opt_mwords && not (is_nexp_constant m) then NexpSet.singleton (orig_nexp m) else trec bit_typ + | Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) -> + trec elem_typ + | Typ_app (Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> trec etyp + | Typ_app (Id_aux (Id "range", _), _) | Typ_app (Id_aux (Id "implicit", _), _) | Typ_app (Id_aux (Id "atom", _), _) -> + NexpSet.empty + | Typ_app (_, tas) -> List.fold_left (fun s ta -> NexpSet.union s (lem_nexps_of_typ_arg ta)) NexpSet.empty tas + | Typ_exist (kids, _, t) -> trec t | Typ_bidir _ -> Reporting.unreachable l __POS__ "Lem doesn't support bidir types" | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" -and lem_nexps_of_typ_arg (A_aux (ta,_)) = + +and lem_nexps_of_typ_arg (A_aux (ta, _)) = match ta with | A_nexp nexp -> let nexp = nexp_simp (orig_nexp nexp) in @@ -130,204 +116,182 @@ and lem_nexps_of_typ_arg (A_aux (ta,_)) = | A_order _ -> NexpSet.empty | A_bool _ -> NexpSet.empty -let rec typeclass_nexps (Typ_aux(t,l)) = - if !opt_mwords then +let rec typeclass_nexps (Typ_aux (t, l)) = + if !opt_mwords then ( match t with - | Typ_id _ - | Typ_var _ - -> NexpSet.empty - | Typ_fn (ts,t) -> List.fold_left NexpSet.union (typeclass_nexps t) (List.map typeclass_nexps ts) + | Typ_id _ | Typ_var _ -> NexpSet.empty + | Typ_fn (ts, t) -> List.fold_left NexpSet.union (typeclass_nexps t) (List.map typeclass_nexps ts) | Typ_tuple ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) - | Typ_app (Id_aux (Id "bitvector",_), - [A_aux (A_nexp size_nexp,_); _]) - | Typ_app (Id_aux (Id "itself",_), - [A_aux (A_nexp size_nexp,_)]) -> - let size_nexp = nexp_simp size_nexp in - if is_nexp_constant size_nexp then NexpSet.empty else - NexpSet.singleton (orig_nexp size_nexp) + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp size_nexp, _); _]) + | Typ_app (Id_aux (Id "itself", _), [A_aux (A_nexp size_nexp, _)]) -> + let size_nexp = nexp_simp size_nexp in + if is_nexp_constant size_nexp then NexpSet.empty else NexpSet.singleton (orig_nexp size_nexp) | Typ_app (id, args) -> - let add_arg_nexps nexps = function - | A_aux (A_typ typ, _) -> - NexpSet.union nexps (typeclass_nexps typ) - | _ -> nexps - in - List.fold_left add_arg_nexps NexpSet.empty args - | Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) + let add_arg_nexps nexps = function + | A_aux (A_typ typ, _) -> NexpSet.union nexps (typeclass_nexps typ) + | _ -> nexps + in + List.fold_left add_arg_nexps NexpSet.empty args + | Typ_exist (kids, _, t) -> NexpSet.empty (* todo *) | Typ_bidir _ -> Reporting.unreachable l __POS__ "Lem doesn't support bidir types" | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" + ) else NexpSet.empty - + let size_set_limit = 64 -let optmap v f = - match v with - | None -> None - | Some v -> Some (f v) +let optmap v f = match v with None -> None | Some v -> Some (f v) -let kbindings_from_list = List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty -let bindings_from_list = List.fold_left (fun s (v,i) -> Bindings.add v i s) Bindings.empty +let kbindings_from_list = List.fold_left (fun s (v, i) -> KBindings.add v i s) KBindings.empty +let bindings_from_list = List.fold_left (fun s (v, i) -> Bindings.add v i s) Bindings.empty let ids_in_exp exp = let open Rewriter in - fold_exp { + fold_exp + { (pure_exp_alg IdSet.empty IdSet.union) with e_id = IdSet.singleton; le_id = IdSet.singleton; - le_app = (fun (id,s) -> List.fold_left IdSet.union (IdSet.singleton id) s); - le_typ = (fun (_,id) -> IdSet.singleton id) - } exp + le_app = (fun (id, s) -> List.fold_left IdSet.union (IdSet.singleton id) s); + le_typ = (fun (_, id) -> IdSet.singleton id); + } + exp let make_vector_lit sz i = - let f j = if Big_int.equal (Big_int.modulus (Big_int.shift_right i (sz-j-1)) (Big_int.of_int 2)) Big_int.zero then '0' else '1' in + let f j = + if Big_int.equal (Big_int.modulus (Big_int.shift_right i (sz - j - 1)) (Big_int.of_int 2)) Big_int.zero then '0' + else '1' + in let s = String.init sz f in - L_aux (L_bin s,Generated Unknown) + L_aux (L_bin s, Generated Unknown) let tabulate f n = let rec aux acc n = - let acc' = f n::acc in + let acc' = f n :: acc in if Big_int.equal n Big_int.zero then acc' else aux acc' (Big_int.sub n (Big_int.of_int 1)) - in if Big_int.equal n Big_int.zero then [] else aux [] (Big_int.sub n (Big_int.of_int 1)) + in + if Big_int.equal n Big_int.zero then [] else aux [] (Big_int.sub n (Big_int.of_int 1)) -let make_vectors sz = - tabulate (make_vector_lit sz) (Big_int.shift_left (Big_int.of_int 1) sz) +let make_vectors sz = tabulate (make_vector_lit sz) (Big_int.shift_left (Big_int.of_int 1) sz) let is_inc_vec typ = try - let (_, ord, _) = vector_typ_args_of typ in + let _, ord, _ = vector_typ_args_of typ in is_order_inc ord with _ -> false let rec cross' = function | [] -> [[]] - | (h::t) -> - let t' = cross' t in - List.concat (List.map (fun x -> List.map (fun l -> x::l) t') h) + | h :: t -> + let t' = cross' t in + List.concat (List.map (fun x -> List.map (fun l -> x :: l) t') h) let rec cross'' = function | [] -> [[]] - | (k,None)::t -> List.map (fun l -> (k,None)::l) (cross'' t) - | (k,Some h)::t -> - let t' = cross'' t in - List.concat (List.map (fun x -> List.map (fun l -> (k,Some x)::l) t') h) + | (k, None) :: t -> List.map (fun l -> (k, None) :: l) (cross'' t) + | (k, Some h) :: t -> + let t' = cross'' t in + List.concat (List.map (fun x -> List.map (fun l -> (k, Some x) :: l) t') h) -let kidset_bigunion = function - | [] -> KidSet.empty - | h::t -> List.fold_left KidSet.union h t +let kidset_bigunion = function [] -> KidSet.empty | h :: t -> List.fold_left KidSet.union h t (* TODO: deal with non-set constraints, intersections, etc somehow *) let extract_set_nc env l var nc = let vars = Spec_analysis.equal_kids_ncs var [nc] in - let rec aux_or (NC_aux (nc,l)) = + let rec aux_or (NC_aux (nc, l)) = match nc with - | NC_equal (Nexp_aux (Nexp_var id,_), Nexp_aux (Nexp_constant n,_)) - when KidSet.mem id vars -> - Some [n] - | NC_or (nc1,nc2) -> - (match aux_or nc1, aux_or nc2 with - | Some l1, Some l2 -> Some (l1 @ l2) - | _, _ -> None) + | NC_equal (Nexp_aux (Nexp_var id, _), Nexp_aux (Nexp_constant n, _)) when KidSet.mem id vars -> Some [n] + | NC_or (nc1, nc2) -> ( + match (aux_or nc1, aux_or nc2) with Some l1, Some l2 -> Some (l1 @ l2) | _, _ -> None + ) | _ -> None in (* Lazily expand constraints to keep close to the original form *) - let rec aux expanded (NC_aux (nc,l) as nc_full) = - let re nc = NC_aux (nc,l) in + let rec aux expanded (NC_aux (nc, l) as nc_full) = + let re nc = NC_aux (nc, l) in match nc with - | NC_set (id,is) when KidSet.mem id vars -> Some (is,re NC_true) - | NC_equal (Nexp_aux (Nexp_var id,_), Nexp_aux (Nexp_constant n,_)) - when KidSet.mem id vars -> - Some ([n], re NC_true) + | NC_set (id, is) when KidSet.mem id vars -> Some (is, re NC_true) + | NC_equal (Nexp_aux (Nexp_var id, _), Nexp_aux (Nexp_constant n, _)) when KidSet.mem id vars -> + Some ([n], re NC_true) (* Turn (i <= 'v & 'v <= j & ...) into set constraint ('v in {i..j}) *) - | NC_and (NC_aux (NC_bounded_le (Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_var kid, _)), _) as nc1, nc2) + | NC_and ((NC_aux (NC_bounded_le (Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_var kid, _)), _) as nc1), nc2) when KidSet.mem kid vars -> - let aux2 () = match aux expanded nc2 with - | Some (is, nc2') -> Some (is, re (NC_and (nc1, nc2'))) - | None -> None - in - begin match constraint_conj nc2 with - | NC_aux (NC_bounded_le (Nexp_aux (Nexp_var kid', _), Nexp_aux (Nexp_constant n', _)), _) :: ncs - when KidSet.mem kid' vars -> - let len = Big_int.succ (Big_int.sub n' n) in - if Big_int.less_equal Big_int.zero len && Big_int.less_equal len (Big_int.of_int size_set_limit) then - let elem i = Big_int.add n (Big_int.of_int i) in - let is = List.init (Big_int.to_int len) elem in - if aux expanded (List.fold_left nc_and nc_true ncs) <> None then - raise (Reporting.err_general l ("Multiple set constraints for " ^ string_of_kid var)) - else Some (is, nc_full) - else aux2 () - | _ -> aux2 () - end - | NC_and (nc1,nc2) -> - (match aux expanded nc1, aux expanded nc2 with - | None, None -> None - | None, Some (is,nc2') -> Some (is, re (NC_and (nc1,nc2'))) - | Some (is,nc1'), None -> Some (is, re (NC_and (nc1',nc2))) - | Some _, Some _ -> - raise (Reporting.err_general l ("Multiple set constraints for " ^ string_of_kid var))) - | NC_or _ -> - (match aux_or nc_full with - | Some is -> Some (is, re NC_true) - | None -> None) + let aux2 () = + match aux expanded nc2 with Some (is, nc2') -> Some (is, re (NC_and (nc1, nc2'))) | None -> None + in + begin + match constraint_conj nc2 with + | NC_aux (NC_bounded_le (Nexp_aux (Nexp_var kid', _), Nexp_aux (Nexp_constant n', _)), _) :: ncs + when KidSet.mem kid' vars -> + let len = Big_int.succ (Big_int.sub n' n) in + if Big_int.less_equal Big_int.zero len && Big_int.less_equal len (Big_int.of_int size_set_limit) then ( + let elem i = Big_int.add n (Big_int.of_int i) in + let is = List.init (Big_int.to_int len) elem in + if aux expanded (List.fold_left nc_and nc_true ncs) <> None then + raise (Reporting.err_general l ("Multiple set constraints for " ^ string_of_kid var)) + else Some (is, nc_full) + ) + else aux2 () + | _ -> aux2 () + end + | NC_and (nc1, nc2) -> ( + match (aux expanded nc1, aux expanded nc2) with + | None, None -> None + | None, Some (is, nc2') -> Some (is, re (NC_and (nc1, nc2'))) + | Some (is, nc1'), None -> Some (is, re (NC_and (nc1', nc2))) + | Some _, Some _ -> raise (Reporting.err_general l ("Multiple set constraints for " ^ string_of_kid var)) + ) + | NC_or _ -> ( + match aux_or nc_full with Some is -> Some (is, re NC_true) | None -> None + ) | _ -> if expanded then None else aux true (Env.expand_constraint_synonyms env nc_full) - in match aux false nc with + in + match aux false nc with | Some is -> is | None -> - raise (Reporting.err_general l ("No set constraint for " ^ string_of_kid var ^ - " in " ^ string_of_n_constraint nc)) + raise (Reporting.err_general l ("No set constraint for " ^ string_of_kid var ^ " in " ^ string_of_n_constraint nc)) let rec split_insts = function - | [] -> [],[] - | (k,None)::t -> let l1,l2 = split_insts t in l1,k::l2 - | (k,Some v)::t -> let l1,l2 = split_insts t in (k,v)::l1,l2 + | [] -> ([], []) + | (k, None) :: t -> + let l1, l2 = split_insts t in + (l1, k :: l2) + | (k, Some v) :: t -> + let l1, l2 = split_insts t in + ((k, v) :: l1, l2) let apply_kid_insts kid_insts nc t = let kid_insts, kids' = split_insts kid_insts in - let kid_insts = List.map - (fun (v,i) -> (kopt_kid v,Nexp_aux (Nexp_constant i,Generated Unknown))) - kid_insts in + let kid_insts = List.map (fun (v, i) -> (kopt_kid v, Nexp_aux (Nexp_constant i, Generated Unknown))) kid_insts in let subst = kbindings_from_list kid_insts in - kids', subst_kids_nc subst nc, subst_kids_typ subst t + (kids', subst_kids_nc subst nc, subst_kids_typ subst t) -let rec contains_exist (Typ_aux (ty,l)) = +let rec contains_exist (Typ_aux (ty, l)) = match ty with - | Typ_id _ - | Typ_var _ - -> false - | Typ_fn (t1,t2) -> List.exists contains_exist t1 || contains_exist t2 + | Typ_id _ | Typ_var _ -> false + | Typ_fn (t1, t2) -> List.exists contains_exist t1 || contains_exist t2 | Typ_bidir (t1, t2) -> contains_exist t1 || contains_exist t2 | Typ_tuple ts -> List.exists contains_exist ts - | Typ_app (_,args) -> List.exists contains_exist_arg args + | Typ_app (_, args) -> List.exists contains_exist_arg args | Typ_exist _ -> true | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" -and contains_exist_arg (A_aux (arg,_)) = - match arg with - | A_nexp _ - | A_order _ - | A_bool _ - -> false - | A_typ typ -> contains_exist typ - -let is_number typ = match destruct_numeric typ with - | Some _ -> true - | None -> false - -let rec size_nvars_nexp (Nexp_aux (ne,_)) = + +and contains_exist_arg (A_aux (arg, _)) = + match arg with A_nexp _ | A_order _ | A_bool _ -> false | A_typ typ -> contains_exist typ + +let is_number typ = match destruct_numeric typ with Some _ -> true | None -> false + +let rec size_nvars_nexp (Nexp_aux (ne, _)) = match ne with | Nexp_var v -> [v] - | Nexp_id _ - | Nexp_constant _ - -> [] - | Nexp_times (n1,n2) - | Nexp_sum (n1,n2) - | Nexp_minus (n1,n2) - -> size_nvars_nexp n1 @ size_nvars_nexp n2 - | Nexp_exp n - | Nexp_neg n - -> size_nvars_nexp n - | Nexp_app (_,args) -> List.concat (List.map size_nvars_nexp args) + | Nexp_id _ | Nexp_constant _ -> [] + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> size_nvars_nexp n1 @ size_nvars_nexp n2 + | Nexp_exp n | Nexp_neg n -> size_nvars_nexp n + | Nexp_app (_, args) -> List.concat (List.map size_nvars_nexp args) (* Given a type for a constructor, work out which refinements we ought to produce *) (* TODO collision avoidance *) -let split_src_type all_errors env id ty (TypQ_aux (q,ql)) = +let split_src_type all_errors env id ty (TypQ_aux (q, ql)) = let cannot l msg default = let open Reporting in match all_errors with @@ -341,155 +305,158 @@ let split_src_type all_errors env id ty (TypQ_aux (q,ql)) = let i = string_of_id id in (* This was originally written for the general case, but I cut it down to the more manageable prenex-form below *) - let rec size_nvars_ty (Typ_aux (_,l) as typ) = - let Typ_aux (ty,_l) = Env.expand_synonyms env typ in + let rec size_nvars_ty (Typ_aux (_, l) as typ) = + let (Typ_aux (ty, _l)) = Env.expand_synonyms env typ in match ty with - | Typ_id _ - | Typ_var _ - -> (KidSet.empty,[[],typ]) - | Typ_fn _ -> - cannot l ("Function type in constructor " ^ i) (KidSet.empty,[[],typ]) - | Typ_bidir _ -> - cannot l ("Mapping type in constructor " ^ i) (KidSet.empty,[[],typ]) + | Typ_id _ | Typ_var _ -> (KidSet.empty, [([], typ)]) + | Typ_fn _ -> cannot l ("Function type in constructor " ^ i) (KidSet.empty, [([], typ)]) + | Typ_bidir _ -> cannot l ("Mapping type in constructor " ^ i) (KidSet.empty, [([], typ)]) | Typ_tuple ts -> - let (vars,tys) = List.split (List.map size_nvars_ty ts) in - let insttys = List.map (fun x -> let (insts,tys) = List.split x in - List.concat insts, Typ_aux (Typ_tuple tys,l)) (cross' tys) in - (kidset_bigunion vars, insttys) - | Typ_app (Id_aux (Id "bitvector",_), - [A_aux (A_nexp sz,_);_]) -> - (KidSet.of_list (size_nvars_nexp sz), [[],typ]) + let vars, tys = List.split (List.map size_nvars_ty ts) in + let insttys = + List.map + (fun x -> + let insts, tys = List.split x in + (List.concat insts, Typ_aux (Typ_tuple tys, l)) + ) + (cross' tys) + in + (kidset_bigunion vars, insttys) + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp sz, _); _]) -> + (KidSet.of_list (size_nvars_nexp sz), [([], typ)]) | Typ_app (_, tas) -> - (KidSet.empty,[[],typ]) (* We only support sizes for bitvectors mentioned explicitly, not any buried - inside another type *) + (KidSet.empty, [([], typ)]) + (* We only support sizes for bitvectors mentioned explicitly, not any buried + inside another type *) | Typ_exist (kopts, nc, t) -> - let (vars,tys) = size_nvars_ty t in - let find_insts k (insts,nc) = - let inst,nc' = - if KidSet.mem (kopt_kid k) vars then - let is,nc' = extract_set_nc env l (kopt_kid k) nc in - Some is,nc' - else None,nc - in (k,inst)::insts,nc' - in - let (insts,nc') = List.fold_right find_insts kopts ([],nc) in - let insts = cross'' insts in - let ty_and_inst (inst0,ty) inst = - let kopts, nc', ty = apply_kid_insts inst nc' ty in - let ty = - (* Typ_exist is not allowed an empty list of kids *) - match kopts with - | [] -> ty - | _ -> Typ_aux (Typ_exist (kopts, nc', ty),l) - in inst@inst0, ty - in - let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in - let free = List.fold_left (fun vars k -> KidSet.remove (kopt_kid k) vars) vars kopts in - (free,tys) + let vars, tys = size_nvars_ty t in + let find_insts k (insts, nc) = + let inst, nc' = + if KidSet.mem (kopt_kid k) vars then ( + let is, nc' = extract_set_nc env l (kopt_kid k) nc in + (Some is, nc') + ) + else (None, nc) + in + ((k, inst) :: insts, nc') + in + let insts, nc' = List.fold_right find_insts kopts ([], nc) in + let insts = cross'' insts in + let ty_and_inst (inst0, ty) inst = + let kopts, nc', ty = apply_kid_insts inst nc' ty in + let ty = + (* Typ_exist is not allowed an empty list of kids *) + match kopts with [] -> ty | _ -> Typ_aux (Typ_exist (kopts, nc', ty), l) + in + (inst @ inst0, ty) + in + let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in + let free = List.fold_left (fun vars k -> KidSet.remove (kopt_kid k) vars) vars kopts in + (free, tys) | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" in - let size_nvars_ty (Typ_aux (ty,l) as typ) = + let size_nvars_ty (Typ_aux (ty, l) as typ) = match ty with - | Typ_exist (kids,_,t) -> - begin - match snd (size_nvars_ty typ) with - | [] -> [] - | [[],_] -> [] - | tys -> - if contains_exist t then - cannot l "Only prenex types in unions are supported by monomorphisation" [] + | Typ_exist (kids, _, t) -> begin + match snd (size_nvars_ty typ) with + | [] -> [] + | [([], _)] -> [] + | tys -> + if contains_exist t then cannot l "Only prenex types in unions are supported by monomorphisation" [] else tys - end + end | _ -> [] in (* TODO: reject universally quantification or monomorphise it *) let variants = size_nvars_ty ty in match variants with | [] -> None - | [l,_] when List.for_all (function (_,None) -> true | _ -> false) l -> None - | sample::_ -> - if List.length variants > size_set_limit then - cannot ql - (string_of_int (List.length variants) ^ "variants for constructor " ^ i ^ - "bigger than limit " ^ string_of_int size_set_limit) None - else - let wrap = match id with - | Id_aux (Id i,l) -> (fun f -> Id_aux (Id (f i),Generated l)) - | Id_aux (Operator i,l) -> (fun f -> Id_aux (Operator (f i),l)) - in - let name_seg = function - | (_,None) -> "" - | (k,Some i) -> "#" ^ string_of_kid (kopt_kid k) ^ Big_int.to_string i - in - let name l i = String.concat "" (i::(List.map name_seg l)) in - Some (List.map (fun (l,ty) -> (l, wrap (name l),ty)) variants) + | [(l, _)] when List.for_all (function _, None -> true | _ -> false) l -> None + | sample :: _ -> + if List.length variants > size_set_limit then + cannot ql + (string_of_int (List.length variants) + ^ "variants for constructor " ^ i ^ "bigger than limit " ^ string_of_int size_set_limit + ) + None + else ( + let wrap = + match id with + | Id_aux (Id i, l) -> fun f -> Id_aux (Id (f i), Generated l) + | Id_aux (Operator i, l) -> fun f -> Id_aux (Operator (f i), l) + in + let name_seg = function _, None -> "" | k, Some i -> "#" ^ string_of_kid (kopt_kid k) ^ Big_int.to_string i in + let name l i = String.concat "" (i :: List.map name_seg l) in + Some (List.map (fun (l, ty) -> (l, wrap (name l), ty)) variants) + ) let typ_of_args args = match args with - | [(E_aux (E_tuple args, (_, tannot)) as exp)] -> - begin match destruct_tannot tannot with - | Some (_,Typ_aux (Typ_exist _,_)) -> - let tys = List.map Type_check.typ_of args in - Typ_aux (Typ_tuple tys,Unknown) - | _ -> Type_check.typ_of exp - end - | [exp] -> - Type_check.typ_of exp + | [(E_aux (E_tuple args, (_, tannot)) as exp)] -> begin + match destruct_tannot tannot with + | Some (_, Typ_aux (Typ_exist _, _)) -> + let tys = List.map Type_check.typ_of args in + Typ_aux (Typ_tuple tys, Unknown) + | _ -> Type_check.typ_of exp + end + | [exp] -> Type_check.typ_of exp | _ -> - let tys = List.map Type_check.typ_of args in - Typ_aux (Typ_tuple tys,Unknown) + let tys = List.map Type_check.typ_of args in + Typ_aux (Typ_tuple tys, Unknown) (* Check to see if we need to monomorphise a use of a constructor. Currently assumes that bitvector sizes are always given as a variable; don't yet handle more general cases (e.g., 8 * var) *) let refine_constructor refinements l env id args = - match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with - | (_,irefinements) -> begin - let (_,constr_ty) = Env.get_union_id id env in - match constr_ty with - (* A constructor should always have a single argument. *) - | Typ_aux (Typ_fn ([constr_ty],_),_) -> begin - let arg_ty = typ_of_args args in - match Type_check.destruct_exist (Type_check.Env.expand_synonyms env constr_ty) with - | None -> None - | Some (kopts,nc,constr_ty) -> - (* Remove existentials in argument types to prevent unification failures *) - let unwrap (Typ_aux (t,_) as typ) = match t with - | Typ_exist (_,_,typ) -> typ - | _ -> typ - in - let arg_ty = match arg_ty with - | Typ_aux (Typ_tuple ts,annot) -> Typ_aux (Typ_tuple (List.map unwrap ts),annot) - | _ -> arg_ty - in - let bindings = Type_check.unify l env (tyvars_of_typ constr_ty) constr_ty arg_ty in - let find_kopt kopt = try Some (KBindings.find (kopt_kid kopt) bindings) with Not_found -> None in - let bindings = List.map find_kopt kopts in - let matches_refinement (mapping,_,_) = - List.for_all2 - (fun v (_,w) -> - match v,w with - | _,None -> true - | Some (A_aux (A_nexp (Nexp_aux (Nexp_constant n, _)), _)),Some m -> Big_int.equal n m - | _,_ -> false) bindings mapping - in - match List.find matches_refinement irefinements with - | (_,new_id,_) -> Some (E_app (new_id,args)) - | exception Not_found -> - let print_map kopt = function - | None -> string_of_kid (kopt_kid kopt) ^ " -> _" - | Some ta -> string_of_kid (kopt_kid kopt) ^ " -> " ^ string_of_typ_arg ta - in - (Reporting.print_err l "Monomorphisation" - ("Unable to refine constructor " ^ string_of_id id ^ " using mapping " ^ String.concat "," (List.map2 print_map kopts bindings)); - None) + match List.find (fun (id', _) -> Id.compare id id' = 0) refinements with + | _, irefinements -> begin + let _, constr_ty = Env.get_union_id id env in + match constr_ty with + (* A constructor should always have a single argument. *) + | Typ_aux (Typ_fn ([constr_ty], _), _) -> begin + let arg_ty = typ_of_args args in + match Type_check.destruct_exist (Type_check.Env.expand_synonyms env constr_ty) with + | None -> None + | Some (kopts, nc, constr_ty) -> ( + (* Remove existentials in argument types to prevent unification failures *) + let unwrap (Typ_aux (t, _) as typ) = match t with Typ_exist (_, _, typ) -> typ | _ -> typ in + let arg_ty = + match arg_ty with + | Typ_aux (Typ_tuple ts, annot) -> Typ_aux (Typ_tuple (List.map unwrap ts), annot) + | _ -> arg_ty + in + let bindings = Type_check.unify l env (tyvars_of_typ constr_ty) constr_ty arg_ty in + let find_kopt kopt = try Some (KBindings.find (kopt_kid kopt) bindings) with Not_found -> None in + let bindings = List.map find_kopt kopts in + let matches_refinement (mapping, _, _) = + List.for_all2 + (fun v (_, w) -> + match (v, w) with + | _, None -> true + | Some (A_aux (A_nexp (Nexp_aux (Nexp_constant n, _)), _)), Some m -> Big_int.equal n m + | _, _ -> false + ) + bindings mapping + in + match List.find matches_refinement irefinements with + | _, new_id, _ -> Some (E_app (new_id, args)) + | exception Not_found -> + let print_map kopt = function + | None -> string_of_kid (kopt_kid kopt) ^ " -> _" + | Some ta -> string_of_kid (kopt_kid kopt) ^ " -> " ^ string_of_typ_arg ta + in + Reporting.print_err l "Monomorphisation" + ("Unable to refine constructor " ^ string_of_id id ^ " using mapping " + ^ String.concat "," (List.map2 print_map kopts bindings) + ); + None + ) + end + | _ -> None end - | _ -> None - end | exception Not_found -> None - type pat_choice = Parse_ast.l * (int * int * (id * tannot exp) list) (* We may need to split up a pattern match if (1) we've been told to case split @@ -497,10 +464,16 @@ type pat_choice = Parse_ast.l * (int * int * (id * tannot exp) list) in the pattern. *) type split = | NoSplit - | VarSplit of (tannot pat * (* pattern for this case *) - (id * tannot Ast.exp) list * (* substitutions for arguments *) - pat_choice list * (* optional locations of constraints/case expressions to reduce *) - (nexp * bool) KBindings.t) (* substitutions for type variables; bool says whether to generate an assertion because we generated a wildcard to make the completeness checker happy *) + | VarSplit of + ( tannot pat + * (* pattern for this case *) + (id * tannot Ast.exp) list + * (* substitutions for arguments *) + pat_choice list + * (* optional locations of constraints/case expressions to reduce *) + (nexp * bool) KBindings.t + ) + (* substitutions for type variables; bool says whether to generate an assertion because we generated a wildcard to make the completeness checker happy *) list | ConstrSplit of (tannot pat * nexp KBindings.t) list @@ -510,52 +483,56 @@ let freshen_id = let n = !counter in let () = counter := n + 1 in match id with - | Id_aux (Id x, l) -> Id_aux (Id (x ^ "#m" ^ string_of_int n),Generated l) - | Id_aux (Operator x, l) -> Id_aux (Operator (x ^ "#m" ^ string_of_int n),Generated l) + | Id_aux (Id x, l) -> Id_aux (Id (x ^ "#m" ^ string_of_int n), Generated l) + | Id_aux (Operator x, l) -> Id_aux (Operator (x ^ "#m" ^ string_of_int n), Generated l) (* TODO: only freshen bindings that might be shadowed *) let freshen_pat_bindings p = - let rec aux (P_aux (p,(l,annot)) as pat) = - let mkp p = P_aux (p,(Generated l, annot)) in + let rec aux (P_aux (p, (l, annot)) as pat) = + let mkp p = P_aux (p, (Generated l, annot)) in match p with - | P_lit _ - | P_wild -> pat, [] + | P_lit _ | P_wild -> (pat, []) | P_or (p1, p2) -> - let (r1, vs1) = aux p1 in - let (r2, vs2) = aux p2 in - (mkp (P_or (r1, r2)), vs1 @ vs2) + let r1, vs1 = aux p1 in + let r2, vs2 = aux p2 in + (mkp (P_or (r1, r2)), vs1 @ vs2) | P_not p -> - let (r, vs) = aux p in - (mkp (P_not r), vs) - | P_as (p,_) -> aux p - | P_typ (typ,p) -> let p',vs = aux p in mkp (P_typ (typ,p')),vs - | P_id id -> let id' = freshen_id id in mkp (P_id id'),[id,E_aux (E_id id',(Generated Unknown,empty_tannot))] - | P_var (p,_) -> aux p - | P_app (id,args) -> - let args',vs = List.split (List.map aux args) in - mkp (P_app (id,args')),List.concat vs + let r, vs = aux p in + (mkp (P_not r), vs) + | P_as (p, _) -> aux p + | P_typ (typ, p) -> + let p', vs = aux p in + (mkp (P_typ (typ, p')), vs) + | P_id id -> + let id' = freshen_id id in + (mkp (P_id id'), [(id, E_aux (E_id id', (Generated Unknown, empty_tannot)))]) + | P_var (p, _) -> aux p + | P_app (id, args) -> + let args', vs = List.split (List.map aux args) in + (mkp (P_app (id, args')), List.concat vs) | P_vector ps -> - let ps,vs = List.split (List.map aux ps) in - mkp (P_vector ps),List.concat vs + let ps, vs = List.split (List.map aux ps) in + (mkp (P_vector ps), List.concat vs) | P_vector_concat ps -> - let ps,vs = List.split (List.map aux ps) in - mkp (P_vector_concat ps),List.concat vs + let ps, vs = List.split (List.map aux ps) in + (mkp (P_vector_concat ps), List.concat vs) | P_string_append ps -> - let ps,vs = List.split (List.map aux ps) in - mkp (P_string_append ps),List.concat vs + let ps, vs = List.split (List.map aux ps) in + (mkp (P_string_append ps), List.concat vs) | P_tuple ps -> - let ps,vs = List.split (List.map aux ps) in - mkp (P_tuple ps),List.concat vs + let ps, vs = List.split (List.map aux ps) in + (mkp (P_tuple ps), List.concat vs) | P_list ps -> - let ps,vs = List.split (List.map aux ps) in - mkp (P_list ps),List.concat vs - | P_cons (p1,p2) -> - let p1,vs1 = aux p1 in - let p2,vs2 = aux p2 in - mkp (P_cons (p1, p2)), vs1@vs2 + let ps, vs = List.split (List.map aux ps) in + (mkp (P_list ps), List.concat vs) + | P_cons (p1, p2) -> + let p1, vs1 = aux p1 in + let p2, vs2 = aux p2 in + (mkp (P_cons (p1, p2)), vs1 @ vs2) | P_vector_subrange _ -> - Reporting.unreachable l __POS__ "vector subrange pattern should be removed before monomorphisation" - in aux p + Reporting.unreachable l __POS__ "vector subrange pattern should be removed before monomorphisation" + in + aux p (* This cuts off function bodies at false assertions that we may have produced in a wildcard pattern match. It should handle the same assertions that @@ -563,118 +540,112 @@ let freshen_pat_bindings p = let stop_at_false_assertions e = let dummy_value_of_typ typ = let l = Generated Unknown in - E_aux (E_exit (E_aux (E_lit (L_aux (L_unit,l)),(l,empty_tannot))),(l,empty_tannot)) + E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, empty_tannot)) in - let rec nc_false (NC_aux (nc,_)) = - match nc with - | NC_false -> true - | NC_and (nc1,nc2) -> nc_false nc1 || nc_false nc2 - | _ -> false + let rec nc_false (NC_aux (nc, _)) = + match nc with NC_false -> true | NC_and (nc1, nc2) -> nc_false nc1 || nc_false nc2 | _ -> false in - let rec exp_false (E_aux (e,_)) = + let rec exp_false (E_aux (e, _)) = match e with | E_constraint nc -> nc_false nc - | E_lit (L_aux (L_false,_)) -> true - | E_app (Id_aux (Id "and_bool",_),[e1;e2]) -> - exp_false e1 || exp_false e2 + | E_lit (L_aux (L_false, _)) -> true + | E_app (Id_aux (Id "and_bool", _), [e1; e2]) -> exp_false e1 || exp_false e2 | _ -> false in - let rec exp (E_aux (e,ann) as ea) = + let rec exp (E_aux (e, ann) as ea) = match e with | E_block es -> - let rec aux = function - | [] -> [], None - | e::es -> let e,stop = exp e in - match stop with - | Some _ -> [e],stop - | None -> - let es',stop = aux es in - e::es',stop - in let es,stop = aux es in begin + let rec aux = function + | [] -> ([], None) + | e :: es -> ( + let e, stop = exp e in + match stop with + | Some _ -> ([e], stop) + | None -> + let es', stop = aux es in + (e :: es', stop) + ) + in + let es, stop = aux es in + begin match stop with - | None -> E_aux (E_block es,ann), stop + | None -> (E_aux (E_block es, ann), stop) | Some typ -> - let typ' = typ_of_annot ann in - if Type_check.alpha_equivalent (env_of_annot ann) typ typ' - then E_aux (E_block es,ann), stop - else E_aux (E_block (es@[dummy_value_of_typ typ']),ann), Some typ' - end - | E_typ (typ,e) -> let e,stop = exp e in - let stop = match stop with Some _ -> Some typ | None -> None in - E_aux (E_typ (typ,e),ann),stop - | E_let (LB_aux (LB_val (p,e1),lbann),e2) -> - let e1,stop = exp e1 in begin - match stop with - | Some _ -> e1,stop - | None -> - let e2,stop = exp e2 in - E_aux (E_let (LB_aux (LB_val (p,e1),lbann),e2),ann), stop - end - | E_assert (e1,_) when exp_false e1 -> - ea, Some (typ_of_annot ann) - | E_throw e -> - ea, Some (typ_of_annot ann) - | _ -> ea, None - in fst (exp e) + let typ' = typ_of_annot ann in + if Type_check.alpha_equivalent (env_of_annot ann) typ typ' then (E_aux (E_block es, ann), stop) + else (E_aux (E_block (es @ [dummy_value_of_typ typ']), ann), Some typ') + end + | E_typ (typ, e) -> + let e, stop = exp e in + let stop = match stop with Some _ -> Some typ | None -> None in + (E_aux (E_typ (typ, e), ann), stop) + | E_let (LB_aux (LB_val (p, e1), lbann), e2) -> + let e1, stop = exp e1 in + begin + match stop with + | Some _ -> (e1, stop) + | None -> + let e2, stop = exp e2 in + (E_aux (E_let (LB_aux (LB_val (p, e1), lbann), e2), ann), stop) + end + | E_assert (e1, _) when exp_false e1 -> (ea, Some (typ_of_annot ann)) + | E_throw e -> (ea, Some (typ_of_annot ann)) + | _ -> (ea, None) + in + fst (exp e) (* Use the location pairs in choices to reduce case expressions at the first location to the given case at the second. *) let apply_pat_choices choices = - let rec rewrite_ncs (NC_aux (nc,l) as nconstr) = + let rec rewrite_ncs (NC_aux (nc, l) as nconstr) = match nc with - | NC_set _ - | NC_or _ -> begin - match List.assoc l choices with - | choice,max,_ -> - NC_aux ((if choice < max then NC_true else NC_false), Generated l) - | exception Not_found -> nconstr + | NC_set _ | NC_or _ -> begin + match List.assoc l choices with + | choice, max, _ -> NC_aux ((if choice < max then NC_true else NC_false), Generated l) + | exception Not_found -> nconstr end - | NC_and (nc1,nc2) -> begin - match rewrite_ncs nc1, rewrite_ncs nc2 with - | NC_aux (NC_false,l), _ - | _, NC_aux (NC_false,l) -> NC_aux (NC_false,l) - | nc1,nc2 -> NC_aux (NC_and (nc1,nc2),l) + | NC_and (nc1, nc2) -> begin + match (rewrite_ncs nc1, rewrite_ncs nc2) with + | NC_aux (NC_false, l), _ | _, NC_aux (NC_false, l) -> NC_aux (NC_false, l) + | nc1, nc2 -> NC_aux (NC_and (nc1, nc2), l) end | _ -> nconstr in - let rec rewrite_assert_cond (E_aux (e,(l,ann)) as exp) = + let rec rewrite_assert_cond (E_aux (e, (l, ann)) as exp) = match List.assoc l choices with - | choice,max,_ -> - E_aux (E_lit (L_aux ((if choice < max then L_true else L_false (* wildcard *)), - Generated l)),(Generated l,ann)) - | exception Not_found -> - match e with - | E_constraint nc -> E_aux (E_constraint (rewrite_ncs nc),(l,ann)) - | E_app (Id_aux (Id "and_bool",andl), [e1;e2]) -> - E_aux (E_app (Id_aux (Id "and_bool",andl), - [rewrite_assert_cond e1; - rewrite_assert_cond e2]),(l,ann)) - | _ -> exp - in - let rewrite_assert (e1,e2) = - E_assert (rewrite_assert_cond e1, e2) + | choice, max, _ -> + E_aux + (E_lit (L_aux ((if choice < max then L_true else L_false (* wildcard *)), Generated l)), (Generated l, ann)) + | exception Not_found -> ( + match e with + | E_constraint nc -> E_aux (E_constraint (rewrite_ncs nc), (l, ann)) + | E_app (Id_aux (Id "and_bool", andl), [e1; e2]) -> + E_aux (E_app (Id_aux (Id "and_bool", andl), [rewrite_assert_cond e1; rewrite_assert_cond e2]), (l, ann)) + | _ -> exp + ) in - let rewrite_case (e,cases) = + let rewrite_assert (e1, e2) = E_assert (rewrite_assert_cond e1, e2) in + let rewrite_case (e, cases) = match List.assoc (exp_loc e) choices with - | choice,max,subst -> - (match List.nth cases choice with - | Pat_aux (Pat_exp (p,E_aux (e,_)),_) -> - let dummyannot = (Generated Unknown,empty_tannot) in - (* TODO: use a proper substitution *) - List.fold_left (fun e (id,e') -> - E_let (LB_aux (LB_val (P_aux (P_id id, dummyannot),e'),dummyannot),E_aux (e,dummyannot))) e subst - | Pat_aux (Pat_when _,(l,_)) -> - raise (Reporting.err_unreachable l __POS__ - "Pattern acquired a guard after analysis!") - | exception Not_found -> - raise (Reporting.err_unreachable (exp_loc e) __POS__ - "Unable to find case I found earlier!")) - | exception Not_found -> E_match (e,cases) + | choice, max, subst -> ( + match List.nth cases choice with + | Pat_aux (Pat_exp (p, E_aux (e, _)), _) -> + let dummyannot = (Generated Unknown, empty_tannot) in + (* TODO: use a proper substitution *) + List.fold_left + (fun e (id, e') -> + E_let (LB_aux (LB_val (P_aux (P_id id, dummyannot), e'), dummyannot), E_aux (e, dummyannot)) + ) + e subst + | Pat_aux (Pat_when _, (l, _)) -> + raise (Reporting.err_unreachable l __POS__ "Pattern acquired a guard after analysis!") + | exception Not_found -> + raise (Reporting.err_unreachable (exp_loc e) __POS__ "Unable to find case I found earlier!") + ) + | exception Not_found -> E_match (e, cases) in let open Rewriter in - fold_exp { id_exp_alg with - e_assert = rewrite_assert; - e_case = rewrite_case } + fold_exp { id_exp_alg with e_assert = rewrite_assert; e_case = rewrite_case } type split_req = split_loc * string * (tannot pat list * Parse_ast.l) option @@ -685,523 +656,547 @@ let split_defs target all_errors (splits : split_req list) env ast = let sc_type_union q (Tu_aux (Tu_ty_id (ty, id), l)) = let env = Env.add_typquant l q env in match split_src_type error_opt env id ty q with - | None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)]) + | None -> ([], [Tu_aux (Tu_ty_id (ty, id), l)]) | Some variants -> - ([(id,variants)], - List.map (fun (insts, id', ty) -> Tu_aux (Tu_ty_id (ty,id'),Generated l)) variants) + ([(id, variants)], List.map (fun (insts, id', ty) -> Tu_aux (Tu_ty_id (ty, id'), Generated l)) variants) in - let sc_type_def ((TD_aux (tda,annot)) as td) = + let sc_type_def (TD_aux (tda, annot) as td) = match tda with - | TD_variant (id,quant,tus,flag) -> - let (refinements, tus') = List.split (List.map (sc_type_union quant) tus) in - (List.concat refinements, TD_aux (TD_variant (id,quant,List.concat tus',flag),annot)) - | _ -> ([],td) + | TD_variant (id, quant, tus, flag) -> + let refinements, tus' = List.split (List.map (sc_type_union quant) tus) in + (List.concat refinements, TD_aux (TD_variant (id, quant, List.concat tus', flag), annot)) + | _ -> ([], td) in let sc_def d = match d with | DEF_aux (DEF_type td, def_annot) -> - let (refinements,td') = sc_type_def td in - (refinements, DEF_aux (DEF_type td', def_annot)) - | _ -> ([], d) + let refinements, td' = sc_type_def td in + (refinements, DEF_aux (DEF_type td', def_annot)) + | _ -> ([], d) in - let (refinements, defs') = List.split (List.map sc_def defs) - in (List.concat refinements, defs') + let refinements, defs' = List.split (List.map sc_def defs) in + (List.concat refinements, defs') in - let (refinements, defs') = split_constructors ast.defs in + let refinements, defs' = split_constructors ast.defs in (* This will perform the initialisation just once, and share it across all defs *) let const_prop = Constant_propagation.const_prop target ast in let subst_exp ref_vars substs ksubsts exp = - let substs = bindings_from_list substs, KBindings.map fst ksubsts in + let substs = (bindings_from_list substs, KBindings.map fst ksubsts) in let exp = fst (const_prop ref_vars substs Bindings.empty exp) in - KBindings.fold (fun kid (nexp, should_assert) exp -> - if should_assert && not (is_kid_generated kid) then + KBindings.fold + (fun kid (nexp, should_assert) exp -> + if should_assert && not (is_kid_generated kid) then ( let assert_nc = nc_eq (nvar kid) nexp in Type_check.tc_assume assert_nc exp - else - exp) ksubsts exp + ) + else exp + ) + ksubsts exp in (* Split a variable pattern into every possible value *) - let split var pat_l annot = let v = string_of_id var in let env = Type_check.env_of_annot (pat_l, annot) in let typ = Type_check.typ_of_annot (pat_l, annot) in let typ = Env.expand_synonyms env typ in - let Typ_aux (ty,l) = typ in + let (Typ_aux (ty, l)) = typ in let new_l = Generated l in - let renew_id (Id_aux (id,l)) = Id_aux (id,new_l) in + let renew_id (Id_aux (id, l)) = Id_aux (id, new_l) in let cannot msg = let open Reporting in let error_msg = "Cannot split type " ^ string_of_typ typ ^ " for variable " ^ v ^ ": " ^ msg in - if all_errors - then (no_errors_happened := false; - print_err pat_l "" error_msg; - [P_aux (P_id var,(pat_l,annot)),[],[],KBindings.empty]) - else raise (err_general pat_l error_msg) + if all_errors then ( + no_errors_happened := false; + print_err pat_l "" error_msg; + [(P_aux (P_id var, (pat_l, annot)), [], [], KBindings.empty)] + ) + else raise (err_general pat_l error_msg) in match ty with - | Typ_id (Id_aux (Id "bool",_)) | Typ_app (Id_aux (Id "atom_bool", _), [_]) -> - [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))],[],KBindings.empty; - P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))],[],KBindings.empty] - - | Typ_id id -> - (try - (* enumerations *) - let ns = Env.get_enum id env in - List.map (fun n -> (P_aux (P_id (renew_id n),(l,annot)), - [var,E_aux (E_id (renew_id n),(new_l,annot))],[],KBindings.empty)) ns - with Type_error _ -> - match id with - | Id_aux (Id "bit",_) -> - List.map (fun b -> - P_aux (P_lit (L_aux (b,new_l)),(l,annot)), - [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))],[],KBindings.empty) - [L_zero; L_one] - | _ -> cannot ("don't know about type " ^ string_of_id id)) - - | Typ_app (Id_aux (Id "bitvector",_), [A_aux (A_nexp len,_);_]) -> - (match len with - | Nexp_aux (Nexp_constant sz,_) when Big_int.greater_equal sz Big_int.zero -> - let sz = Big_int.to_int sz in - let num_lits = Big_int.pow_int (Big_int.of_int 2) sz in - (* Check that split size is within limits before generating the list of literals *) - if (Big_int.less_equal num_lits (Big_int.of_int size_set_limit)) then - let lits = make_vectors sz in - (* Some parts of Sail don't recognise complete bitvector - matches, so make the last one a wildcard. *) - let rec map_lits = function - | [] -> [] - | [lit] -> - [P_aux (P_wild, (l,annot)), - [var, E_aux (E_lit lit,(new_l,annot))],[],KBindings.empty] - | lit::tl -> - (P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))],[],KBindings.empty)::(map_lits tl) - in map_lits lits - else - cannot ("bitvector length outside limit, " ^ string_of_nexp len) - | _ -> - cannot ("length not constant and positive, " ^ string_of_nexp len) - ) + | Typ_id (Id_aux (Id "bool", _)) | Typ_app (Id_aux (Id "atom_bool", _), [_]) -> + [ + ( P_aux (P_lit (L_aux (L_true, new_l)), (l, annot)), + [(var, E_aux (E_lit (L_aux (L_true, new_l)), (new_l, annot)))], + [], + KBindings.empty + ); + ( P_aux (P_lit (L_aux (L_false, new_l)), (l, annot)), + [(var, E_aux (E_lit (L_aux (L_false, new_l)), (new_l, annot)))], + [], + KBindings.empty + ); + ] + | Typ_id id -> ( + try + (* enumerations *) + let ns = Env.get_enum id env in + List.map + (fun n -> + ( P_aux (P_id (renew_id n), (l, annot)), + [(var, E_aux (E_id (renew_id n), (new_l, annot)))], + [], + KBindings.empty + ) + ) + ns + with Type_error _ -> ( + match id with + | Id_aux (Id "bit", _) -> + List.map + (fun b -> + ( P_aux (P_lit (L_aux (b, new_l)), (l, annot)), + [(var, E_aux (E_lit (L_aux (b, new_l)), (new_l, annot)))], + [], + KBindings.empty + ) + ) + [L_zero; L_one] + | _ -> cannot ("don't know about type " ^ string_of_id id) + ) + ) + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp len, _); _]) -> ( + match len with + | Nexp_aux (Nexp_constant sz, _) when Big_int.greater_equal sz Big_int.zero -> + let sz = Big_int.to_int sz in + let num_lits = Big_int.pow_int (Big_int.of_int 2) sz in + (* Check that split size is within limits before generating the list of literals *) + if Big_int.less_equal num_lits (Big_int.of_int size_set_limit) then ( + let lits = make_vectors sz in + (* Some parts of Sail don't recognise complete bitvector + matches, so make the last one a wildcard. *) + let rec map_lits = function + | [] -> [] + | [lit] -> + [(P_aux (P_wild, (l, annot)), [(var, E_aux (E_lit lit, (new_l, annot)))], [], KBindings.empty)] + | lit :: tl -> + (P_aux (P_lit lit, (l, annot)), [(var, E_aux (E_lit lit, (new_l, annot)))], [], KBindings.empty) + :: map_lits tl + in + map_lits lits + ) + else cannot ("bitvector length outside limit, " ^ string_of_nexp len) + | _ -> cannot ("length not constant and positive, " ^ string_of_nexp len) + ) (* set constrained numbers *) - | Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (value,_) as nexp),_)]) -> - begin - (* Introduce a wilcard for the last pattern to ensure completeness is clear *) - let mk_lit kid wildcard i = - let lit = L_aux (L_num i,new_l) in - P_aux ((if wildcard then P_wild else P_lit lit), (l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))],[], - match kid with None -> KBindings.empty - | Some k -> KBindings.singleton k (nconstant i, wildcard) - in - match value with - | Nexp_constant i -> [mk_lit None false i] - | Nexp_var kvar -> - let ncs = Env.get_constraints env in - let nc = List.fold_left nc_and nc_true ncs in - (match extract_set_nc env l kvar nc with - | (is,_) -> Util.map_last (mk_lit (Some kvar)) is - | exception Reporting.Fatal_error (Reporting.Err_general (_,msg)) -> cannot msg) - | _ -> cannot ("unsupport atom nexp " ^ string_of_nexp nexp) - end + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (value, _) as nexp), _)]) -> begin + (* Introduce a wilcard for the last pattern to ensure completeness is clear *) + let mk_lit kid wildcard i = + let lit = L_aux (L_num i, new_l) in + ( P_aux ((if wildcard then P_wild else P_lit lit), (l, annot)), + [(var, E_aux (E_lit lit, (new_l, annot)))], + [], + match kid with None -> KBindings.empty | Some k -> KBindings.singleton k (nconstant i, wildcard) + ) + in + match value with + | Nexp_constant i -> [mk_lit None false i] + | Nexp_var kvar -> ( + let ncs = Env.get_constraints env in + let nc = List.fold_left nc_and nc_true ncs in + match extract_set_nc env l kvar nc with + | is, _ -> Util.map_last (mk_lit (Some kvar)) is + | exception Reporting.Fatal_error (Reporting.Err_general (_, msg)) -> cannot msg + ) + | _ -> cannot ("unsupport atom nexp " ^ string_of_nexp nexp) + end | _ -> cannot ("unsupported type " ^ string_of_typ typ) in - (* Split variable patterns at the given locations *) - let map_locs ls defs = let match_file_line filename line = let rec aux = function | Unknown -> false - | Unique (_,l) -> aux l + | Unique (_, l) -> aux l | Generated l -> false (* Could do match_l l, but only want to split user-written patterns *) - | Hint (_,_,l) -> aux l - | Range (p,q) -> - p.Lexing.pos_fname = filename && - p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum - in aux + | Hint (_, _, l) -> aux l + | Range (p, q) -> p.Lexing.pos_fname = filename && p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum + in + aux in let match_l l = let matches = - List.filter (function - | (Exact l',_,_) -> l = l' - | (Line (filename,line),_,_) -> match_file_line filename line l) + List.filter + (function Exact l', _, _ -> l = l' | Line (filename, line), _, _ -> match_file_line filename line l) ls - in List.map (fun (_,var,optpats) -> (var,optpats)) matches - in + in + List.map (fun (_, var, optpats) -> (var, optpats)) matches + in let split_pat vars p = let id_match = function - | Id_aux (Id x,_) -> (try Some (List.assoc x vars) with Not_found -> None) - | Id_aux (Operator x,_) -> (try Some (List.assoc x vars) with Not_found -> None) + | Id_aux (Id x, _) -> ( + try Some (List.assoc x vars) with Not_found -> None + ) + | Id_aux (Operator x, _) -> ( + try Some (List.assoc x vars) with Not_found -> None + ) in let rec list f = function | [] -> None - | h::t -> - let t' = - match list f t with - | None -> [t,[],[],KBindings.empty] - | Some t' -> t' - in - let h' = - match f h with - | None -> [h,[],[],KBindings.empty] - | Some ps -> ps - in - let merge (h,hsubs,hpchoices,hksubs) (t,tsubs,tpchoices,tksubs) = - if KBindings.for_all (fun kid (nexp, _) -> + | h :: t -> + let t' = match list f t with None -> [(t, [], [], KBindings.empty)] | Some t' -> t' in + let h' = match f h with None -> [(h, [], [], KBindings.empty)] | Some ps -> ps in + let merge (h, hsubs, hpchoices, hksubs) (t, tsubs, tpchoices, tksubs) = + if + KBindings.for_all + (fun kid (nexp, _) -> match KBindings.find_opt kid tksubs with | None -> true - | Some (nexp',_) -> Nexp.compare nexp nexp' == 0) hksubs - then Some (h::t, hsubs@tsubs, hpchoices@tpchoices, - KBindings.union (fun k a _ -> Some a) hksubs tksubs) - else None - in - Some (List.concat - (List.map (fun h -> List.filter_map (merge h) t') h')) + | Some (nexp', _) -> Nexp.compare nexp nexp' == 0 + ) + hksubs + then + Some (h :: t, hsubs @ tsubs, hpchoices @ tpchoices, KBindings.union (fun k a _ -> Some a) hksubs tksubs) + else None + in + Some (List.concat (List.map (fun h -> List.filter_map (merge h) t') h')) in - let rec spl (P_aux (p,(l,annot))) = + let rec spl (P_aux (p, (l, annot))) = let relist f ctx ps = - optmap (list f ps) - (fun ps -> - List.map (fun (ps,sub,pchoices,ksub) -> P_aux (ctx ps,(l,annot)),sub,pchoices,ksub) ps) + optmap (list f ps) (fun ps -> + List.map (fun (ps, sub, pchoices, ksub) -> (P_aux (ctx ps, (l, annot)), sub, pchoices, ksub)) ps + ) in let re f p = - optmap (spl p) - (fun ps -> List.map (fun (p,sub,pchoices,ksub) -> (P_aux (f p,(l,annot)), sub, pchoices, ksub)) ps) + optmap (spl p) (fun ps -> + List.map (fun (p, sub, pchoices, ksub) -> (P_aux (f p, (l, annot)), sub, pchoices, ksub)) ps + ) in let re2 f ctx p1 p2 = - (* Todo: I am not proud of this abuse of relist - but creating a special - * version of re just for two entries did not seem worth it - *) + (* Todo: I am not proud of this abuse of relist - but creating a special + * version of re just for two entries did not seem worth it + *) relist f (function [p1'; p2'] -> ctx p1' p2' | _ -> assert false) [p1; p2] in match p with - | P_lit _ - | P_wild - -> None - | P_or (p1, p2) -> - re2 spl (fun p1' p2' -> P_or (p1', p2')) p1 p2 + | P_lit _ | P_wild -> None + | P_or (p1, p2) -> re2 spl (fun p1' p2' -> P_or (p1', p2')) p1 p2 | P_not p -> - (* todo: not sure that I can't split - but can't figure out how at - * the moment *) - raise (Reporting.err_general l - ("Cannot split on 'not' pattern")) - | P_as (p',id) when id_match id <> None -> - raise (Reporting.err_general l - ("Cannot split " ^ string_of_id id ^ " on 'as' pattern")) - | P_as (p',id) -> - re (fun p -> P_as (p,id)) p' - | P_typ (t,p') -> re (fun p -> P_typ (t,p)) p' - | P_var (p', (TP_aux (TP_var kid,_) as tp)) -> - (match spl p' with - | None -> None - | Some ps -> - let kids = Spec_analysis.equal_kids (env_of_pat p') kid in - Some (List.map (fun (p,sub,pchoices,ksub) -> - P_aux (P_var (p,tp),(l,annot)), sub, pchoices, - match List.find_opt (fun k -> KBindings.mem k ksub) (KidSet.elements kids) with - | None -> ksub - | Some k -> KBindings.add kid (KBindings.find k ksub) ksub - ) ps)) - | P_var (p',tp) -> re (fun p -> P_var (p,tp)) p' - | P_id id -> - (match id_match id with - | None -> None - (* Total case split *) - | Some None -> Some (split id l annot) - (* Where the analysis proposed a specific case split, propagate a - literal as normal, but perform a more careful transformation - otherwise *) - | Some (Some (pats,l)) -> - let max = List.length pats - 1 in - let lit_like = function - | P_lit _ -> true - | P_vector ps -> List.for_all (function P_aux (P_lit _,_) -> true | _ -> false) ps - | _ -> false - in - let rec to_exp = function - | P_aux (P_lit lit,(l,ann)) -> E_aux (E_lit lit,(Generated l,ann)) - | P_aux (P_vector ps,(l,ann)) -> E_aux (E_vector (List.map to_exp ps),(Generated l,ann)) - | _ -> assert false - in - Some (List.mapi (fun i p -> - match p with - | P_aux (P_lit (L_aux (L_num j,_) as lit),(pl,pannot)) -> - let orig_typ = Env.base_typ_of (env_of_annot (l,annot)) (typ_of_annot (l,annot)) in - let kid_subst = match orig_typ with - | Typ_aux - (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp - (Nexp_aux (Nexp_var var,_)),_)]),_) -> - KBindings.singleton var (nconstant j, false) - | _ -> KBindings.empty - in - p,[id,E_aux (E_lit lit,(Generated pl,pannot))],[l,(i,max,[])],kid_subst - | P_aux (p',(pl,pannot)) when lit_like p' -> - p,[id,to_exp p],[l,(i,max,[])],KBindings.empty - | _ -> - let p',subst = freshen_pat_bindings p in - match p' with - | P_aux (P_wild,_) -> - P_aux (P_id id,(l,annot)),[],[l,(i,max,subst)],KBindings.empty - | _ -> - P_aux (P_as (p',id),(l,annot)),[],[l,(i,max,subst)],KBindings.empty) - pats) - ) - | P_app (id,ps) -> - relist spl (fun ps -> P_app (id,ps)) ps - | P_vector ps -> - relist spl (fun ps -> P_vector ps) ps - | P_vector_concat ps -> - relist spl (fun ps -> P_vector_concat ps) ps - | P_string_append ps -> - relist spl (fun ps -> P_string_append ps) ps - | P_tuple ps -> - relist spl (fun ps -> P_tuple ps) ps - | P_list ps -> - relist spl (fun ps -> P_list ps) ps - | P_cons (p1,p2) -> - re2 spl (fun p1' p2' -> P_cons (p1', p2')) p1 p2 + (* todo: not sure that I can't split - but can't figure out how at + * the moment *) + raise (Reporting.err_general l "Cannot split on 'not' pattern") + | P_as (p', id) when id_match id <> None -> + raise (Reporting.err_general l ("Cannot split " ^ string_of_id id ^ " on 'as' pattern")) + | P_as (p', id) -> re (fun p -> P_as (p, id)) p' + | P_typ (t, p') -> re (fun p -> P_typ (t, p)) p' + | P_var (p', (TP_aux (TP_var kid, _) as tp)) -> ( + match spl p' with + | None -> None + | Some ps -> + let kids = Spec_analysis.equal_kids (env_of_pat p') kid in + Some + (List.map + (fun (p, sub, pchoices, ksub) -> + ( P_aux (P_var (p, tp), (l, annot)), + sub, + pchoices, + match List.find_opt (fun k -> KBindings.mem k ksub) (KidSet.elements kids) with + | None -> ksub + | Some k -> KBindings.add kid (KBindings.find k ksub) ksub + ) + ) + ps + ) + ) + | P_var (p', tp) -> re (fun p -> P_var (p, tp)) p' + | P_id id -> ( + match id_match id with + | None -> None + (* Total case split *) + | Some None -> Some (split id l annot) + (* Where the analysis proposed a specific case split, propagate a + literal as normal, but perform a more careful transformation + otherwise *) + | Some (Some (pats, l)) -> + let max = List.length pats - 1 in + let lit_like = function + | P_lit _ -> true + | P_vector ps -> List.for_all (function P_aux (P_lit _, _) -> true | _ -> false) ps + | _ -> false + in + let rec to_exp = function + | P_aux (P_lit lit, (l, ann)) -> E_aux (E_lit lit, (Generated l, ann)) + | P_aux (P_vector ps, (l, ann)) -> E_aux (E_vector (List.map to_exp ps), (Generated l, ann)) + | _ -> assert false + in + Some + (List.mapi + (fun i p -> + match p with + | P_aux (P_lit (L_aux (L_num j, _) as lit), (pl, pannot)) -> + let orig_typ = Env.base_typ_of (env_of_annot (l, annot)) (typ_of_annot (l, annot)) in + let kid_subst = + match orig_typ with + | Typ_aux + (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var var, _)), _)]), _) + -> + KBindings.singleton var (nconstant j, false) + | _ -> KBindings.empty + in + (p, [(id, E_aux (E_lit lit, (Generated pl, pannot)))], [(l, (i, max, []))], kid_subst) + | P_aux (p', (pl, pannot)) when lit_like p' -> + (p, [(id, to_exp p)], [(l, (i, max, []))], KBindings.empty) + | _ -> ( + let p', subst = freshen_pat_bindings p in + match p' with + | P_aux (P_wild, _) -> + (P_aux (P_id id, (l, annot)), [], [(l, (i, max, subst))], KBindings.empty) + | _ -> (P_aux (P_as (p', id), (l, annot)), [], [(l, (i, max, subst))], KBindings.empty) + ) + ) + pats + ) + ) + | P_app (id, ps) -> relist spl (fun ps -> P_app (id, ps)) ps + | P_vector ps -> relist spl (fun ps -> P_vector ps) ps + | P_vector_concat ps -> relist spl (fun ps -> P_vector_concat ps) ps + | P_string_append ps -> relist spl (fun ps -> P_string_append ps) ps + | P_tuple ps -> relist spl (fun ps -> P_tuple ps) ps + | P_list ps -> relist spl (fun ps -> P_list ps) ps + | P_cons (p1, p2) -> re2 spl (fun p1' p2' -> P_cons (p1', p2')) p1 p2 | P_vector_subrange _ -> - Reporting.unreachable l __POS__ "vector subrange pattern should be removed before monomorphisation" - in spl p + Reporting.unreachable l __POS__ "vector subrange pattern should be removed before monomorphisation" + in + spl p in - let map_pat_by_loc (P_aux (p,(l,_)) as pat) = - match match_l l with - | [] -> None - | vars -> split_pat vars pat - in - let map_pat (P_aux (p,(l,tannot)) as pat) = - let try_by_location () = - match map_pat_by_loc pat with - | Some l -> VarSplit l - | None -> NoSplit - in + let map_pat_by_loc (P_aux (p, (l, _)) as pat) = match match_l l with [] -> None | vars -> split_pat vars pat in + let map_pat (P_aux (p, (l, tannot)) as pat) = + let try_by_location () = match map_pat_by_loc pat with Some l -> VarSplit l | None -> NoSplit in match p with - | P_app (id,args) -> - begin - match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with - | (_,variants) -> -(* TODO: at changes to the pattern and what substitutions do we need in general? - let kid,kid_annot = - match args with - | [P_aux (P_var (_, TP_aux (TP_var kid, _)),ann)] -> kid,ann - | _ -> - raise (Reporting.err_general l - ("Pattern match not currently supported by monomorphisation: " - ^ string_of_pat pat)) - in - let map_inst (insts,id',_) = - let insts = - match insts with [(v,Some i)] -> [(kid,Nexp_aux (Nexp_constant i, Generated l))] - | _ -> assert false - in -(* - let insts,_ = split_insts insts in - let insts = List.map (fun (v,i) -> - (??, - Nexp_aux (Nexp_constant i,Generated l))) - insts in - P_aux (app (id',args),(Generated l,tannot)), -*) - P_aux (P_app (id',[P_aux (P_id (id_of_kid kid),kid_annot)]),(Generated l,tannot)), - kbindings_from_list insts - in -*) - let map_inst (insts,id',_) = - P_aux (P_app (id',args),(Generated l,tannot)), - KBindings.empty - in - ConstrSplit (List.map map_inst variants) - | exception Not_found -> try_by_location () - end - | _ -> try_by_location () + | P_app (id, args) -> begin + match List.find (fun (id', _) -> Id.compare id id' = 0) refinements with + | _, variants -> + (* TODO: at changes to the pattern and what substitutions do we need in general? + let kid,kid_annot = + match args with + | [P_aux (P_var (_, TP_aux (TP_var kid, _)),ann)] -> kid,ann + | _ -> + raise (Reporting.err_general l + ("Pattern match not currently supported by monomorphisation: " + ^ string_of_pat pat)) + in + let map_inst (insts,id',_) = + let insts = + match insts with [(v,Some i)] -> [(kid,Nexp_aux (Nexp_constant i, Generated l))] + | _ -> assert false + in + (* + let insts,_ = split_insts insts in + let insts = List.map (fun (v,i) -> + (??, + Nexp_aux (Nexp_constant i,Generated l))) + insts in + P_aux (app (id',args),(Generated l,tannot)), + *) + P_aux (P_app (id',[P_aux (P_id (id_of_kid kid),kid_annot)]),(Generated l,tannot)), + kbindings_from_list insts + in + *) + let map_inst (insts, id', _) = (P_aux (P_app (id', args), (Generated l, tannot)), KBindings.empty) in + ConstrSplit (List.map map_inst variants) + | exception Not_found -> try_by_location () + end + | _ -> try_by_location () in - let check_single_pat (P_aux (_,(l,annot)) as p) = + let check_single_pat (P_aux (_, (l, annot)) as p) = match match_l l with | [] -> p | lvs -> - let pvs = Spec_analysis.bindings_from_pat p in - let pvs = List.map string_of_id pvs in - let overlap = List.exists (fun (v,_) -> List.mem v pvs) lvs in - let () = - if overlap then - Reporting.print_err l "Monomorphisation" - "Splitting a singleton pattern is not possible" - in p + let pvs = Spec_analysis.bindings_from_pat p in + let pvs = List.map string_of_id pvs in + let overlap = List.exists (fun (v, _) -> List.mem v pvs) lvs in + let () = + if overlap then Reporting.print_err l "Monomorphisation" "Splitting a singleton pattern is not possible" + in + p in let check_split_size lst l = let size = List.length lst in if size > size_set_limit then let open Reporting in - let error_msg = "Case split is too large (" ^ string_of_int size ^ " > limit " ^ string_of_int size_set_limit ^ ")" in - if all_errors - then (no_errors_happened := false; - print_err l "" error_msg; false) + let error_msg = + "Case split is too large (" ^ string_of_int size ^ " > limit " ^ string_of_int size_set_limit ^ ")" + in + if all_errors then ( + no_errors_happened := false; + print_err l "" error_msg; + false + ) else raise (err_general l error_msg) else true in let map_fns ref_vars = - let rec map_exp ((E_aux (e,annot)) as ea) = - let re e = E_aux (e,annot) in + let rec map_exp (E_aux (e, annot) as ea) = + let re e = E_aux (e, annot) in match e with | E_block es -> re (E_block (List.map map_exp es)) - | E_id _ - | E_lit _ - | E_sizeof _ - | E_constraint _ - | E_ref _ - | E_internal_value _ - -> ea - | E_typ (t,e') -> re (E_typ (t, map_exp e')) - | E_app (id,es) -> - let es' = List.map map_exp es in - let env = env_of_annot annot in - begin - match Env.is_union_constructor id env, refine_constructor refinements (fst annot) env id es' with - | true, Some exp -> re exp - | _,_ -> re (E_app (id,es')) - end - | E_app_infix (e1,id,e2) -> re (E_app_infix (map_exp e1,id,map_exp e2)) + | E_id _ | E_lit _ | E_sizeof _ | E_constraint _ | E_ref _ | E_internal_value _ -> ea + | E_typ (t, e') -> re (E_typ (t, map_exp e')) + | E_app (id, es) -> + let es' = List.map map_exp es in + let env = env_of_annot annot in + begin + match (Env.is_union_constructor id env, refine_constructor refinements (fst annot) env id es') with + | true, Some exp -> re exp + | _, _ -> re (E_app (id, es')) + end + | E_app_infix (e1, id, e2) -> re (E_app_infix (map_exp e1, id, map_exp e2)) | E_tuple es -> re (E_tuple (List.map map_exp es)) - | E_if (e1,e2,e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3)) - | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,map_exp e1,map_exp e2,map_exp e3,ord,map_exp e4)) - | E_loop (loop,m,e1,e2) -> re (E_loop (loop,m,map_exp e1,map_exp e2)) + | E_if (e1, e2, e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3)) + | E_for (id, e1, e2, e3, ord, e4) -> re (E_for (id, map_exp e1, map_exp e2, map_exp e3, ord, map_exp e4)) + | E_loop (loop, m, e1, e2) -> re (E_loop (loop, m, map_exp e1, map_exp e2)) | E_vector es -> re (E_vector (List.map map_exp es)) - | E_vector_access (e1,e2) -> re (E_vector_access (map_exp e1,map_exp e2)) - | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (map_exp e1,map_exp e2,map_exp e3)) - | E_vector_update (e1,e2,e3) -> re (E_vector_update (map_exp e1,map_exp e2,map_exp e3)) - | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (map_exp e1,map_exp e2,map_exp e3,map_exp e4)) - | E_vector_append (e1,e2) -> re (E_vector_append (map_exp e1,map_exp e2)) + | E_vector_access (e1, e2) -> re (E_vector_access (map_exp e1, map_exp e2)) + | E_vector_subrange (e1, e2, e3) -> re (E_vector_subrange (map_exp e1, map_exp e2, map_exp e3)) + | E_vector_update (e1, e2, e3) -> re (E_vector_update (map_exp e1, map_exp e2, map_exp e3)) + | E_vector_update_subrange (e1, e2, e3, e4) -> + re (E_vector_update_subrange (map_exp e1, map_exp e2, map_exp e3, map_exp e4)) + | E_vector_append (e1, e2) -> re (E_vector_append (map_exp e1, map_exp e2)) | E_list es -> re (E_list (List.map map_exp es)) - | E_cons (e1,e2) -> re (E_cons (map_exp e1,map_exp e2)) + | E_cons (e1, e2) -> re (E_cons (map_exp e1, map_exp e2)) | E_struct fes -> re (E_struct (List.map map_fexp fes)) - | E_struct_update (e,fes) -> re (E_struct_update (map_exp e, List.map map_fexp fes)) - | E_field (e,id) -> re (E_field (map_exp e,id)) - | E_match (e,cases) -> re (E_match (map_exp e, List.concat (List.map map_pexp cases))) - | E_let (lb,e) -> - let lb_l, binding_exp_annot = - match lb with LB_aux (LB_val (_, (E_aux (_, (_, a)))),(l,_)) -> l,a - in - let lb' = map_letbind lb in - let e' = map_exp e in - (* Add a case split in the right hand side, e.g. for let 'n = get_vector_size() in ... *) - let e' = match match_l lb_l with - | [] -> e' - | [(id_string,splits)] -> - let l' = Generated lb_l in - let id = mk_id id_string in - let match_exp = E_aux (E_id id, (l',binding_exp_annot)) in - let pat_to_split = P_aux (P_id id, (l',binding_exp_annot)) in - let patsubsts = split_pat [(id_string, splits)] pat_to_split in - let patsubsts = match patsubsts with Some x -> x | None -> assert false (* TODO *) in - let pexps = - List.map (fun (pat',substs,pchoices,ksubsts) -> - let plain_ksubsts = KBindings.map fst ksubsts in - let exp' = Spec_analysis.nexp_subst_exp plain_ksubsts e' in - let exp' = apply_pat_choices pchoices exp' in - let exp' = subst_exp ref_vars substs ksubsts exp' in - let exp' = stop_at_false_assertions exp' in - let annot = match e' with E_aux (_,(_,a)) -> a in - Pat_aux (Pat_exp (pat', map_exp exp'),(l',annot))) - patsubsts - in - E_aux (E_match (match_exp, pexps), annot) - | _ -> assert false (* TODO: should just have an error here...? *) - in - re (E_let (lb', e')) - | E_assign (le,e) -> re (E_assign (map_lexp le, map_exp e)) + | E_struct_update (e, fes) -> re (E_struct_update (map_exp e, List.map map_fexp fes)) + | E_field (e, id) -> re (E_field (map_exp e, id)) + | E_match (e, cases) -> re (E_match (map_exp e, List.concat (List.map map_pexp cases))) + | E_let (lb, e) -> + let lb_l, binding_exp_annot = match lb with LB_aux (LB_val (_, E_aux (_, (_, a))), (l, _)) -> (l, a) in + let lb' = map_letbind lb in + let e' = map_exp e in + (* Add a case split in the right hand side, e.g. for let 'n = get_vector_size() in ... *) + let e' = + match match_l lb_l with + | [] -> e' + | [(id_string, splits)] -> + let l' = Generated lb_l in + let id = mk_id id_string in + let match_exp = E_aux (E_id id, (l', binding_exp_annot)) in + let pat_to_split = P_aux (P_id id, (l', binding_exp_annot)) in + let patsubsts = split_pat [(id_string, splits)] pat_to_split in + let patsubsts = match patsubsts with Some x -> x | None -> assert false (* TODO *) in + let pexps = + List.map + (fun (pat', substs, pchoices, ksubsts) -> + let plain_ksubsts = KBindings.map fst ksubsts in + let exp' = Spec_analysis.nexp_subst_exp plain_ksubsts e' in + let exp' = apply_pat_choices pchoices exp' in + let exp' = subst_exp ref_vars substs ksubsts exp' in + let exp' = stop_at_false_assertions exp' in + let annot = match e' with E_aux (_, (_, a)) -> a in + Pat_aux (Pat_exp (pat', map_exp exp'), (l', annot)) + ) + patsubsts + in + E_aux (E_match (match_exp, pexps), annot) + | _ -> assert false (* TODO: should just have an error here...? *) + in + re (E_let (lb', e')) + | E_assign (le, e) -> re (E_assign (map_lexp le, map_exp e)) | E_exit e -> re (E_exit (map_exp e)) | E_throw e -> re (E_throw e) - | E_try (e,cases) -> re (E_try (map_exp e, List.concat (List.map map_pexp cases))) + | E_try (e, cases) -> re (E_try (map_exp e, List.concat (List.map map_pexp cases))) | E_return e -> re (E_return (map_exp e)) - | E_assert (e1,e2) -> re (E_assert (map_exp e1,map_exp e2)) - | E_var (le,e1,e2) -> re (E_var (map_lexp le, map_exp e1, map_exp e2)) - | E_internal_plet (p,e1,e2) -> re (E_internal_plet (check_single_pat p, map_exp e1, map_exp e2)) + | E_assert (e1, e2) -> re (E_assert (map_exp e1, map_exp e2)) + | E_var (le, e1, e2) -> re (E_var (map_lexp le, map_exp e1, map_exp e2)) + | E_internal_plet (p, e1, e2) -> re (E_internal_plet (check_single_pat p, map_exp e1, map_exp e2)) | E_internal_return e -> re (E_internal_return (map_exp e)) - | E_internal_assume (nc,e) -> re (E_internal_assume (nc, map_exp e)) - and map_fexp (FE_aux (FE_fexp (id,e), annot)) = - FE_aux (FE_fexp (id,map_exp e),annot) + | E_internal_assume (nc, e) -> re (E_internal_assume (nc, map_exp e)) + and map_fexp (FE_aux (FE_fexp (id, e), annot)) = FE_aux (FE_fexp (id, map_exp e), annot) and map_pexp = function - | Pat_aux (Pat_exp (p,e),l) -> - let nosplit = lazy [Pat_aux (Pat_exp (p,map_exp e),l)] in - (match map_pat p with - | NoSplit -> Lazy.force nosplit - | VarSplit patsubsts -> - if check_split_size patsubsts (pat_loc p) then - List.map (fun (pat',substs,pchoices,ksubsts) -> - let plain_ksubsts = KBindings.map fst ksubsts in - let exp' = Spec_analysis.nexp_subst_exp plain_ksubsts e in - let exp' = apply_pat_choices pchoices exp' in - let exp' = subst_exp ref_vars substs ksubsts exp' in - let exp' = stop_at_false_assertions exp' in - Pat_aux (Pat_exp (pat', map_exp exp'),l)) - patsubsts - else Lazy.force nosplit - | ConstrSplit patnsubsts -> - List.map (fun (pat',nsubst) -> - let pat' = Spec_analysis.nexp_subst_pat nsubst pat' in - let exp' = Spec_analysis.nexp_subst_exp nsubst e in - Pat_aux (Pat_exp (pat', map_exp exp'),l) - ) patnsubsts) - | Pat_aux (Pat_when (p,e1,e2),l) -> - let nosplit = lazy [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)] in - (match map_pat p with - | NoSplit -> Lazy.force nosplit - | VarSplit patsubsts -> - if check_split_size patsubsts (pat_loc p) then - List.map (fun (pat',substs,pchoices,ksubsts) -> - let plain_ksubsts = KBindings.map fst ksubsts in - let exp1' = Spec_analysis.nexp_subst_exp plain_ksubsts e1 in - let exp1' = apply_pat_choices pchoices exp1' in - let exp1' = subst_exp ref_vars substs ksubsts exp1' in - let plain_ksubsts = KBindings.map fst ksubsts in - let exp2' = Spec_analysis.nexp_subst_exp plain_ksubsts e2 in - let exp2' = apply_pat_choices pchoices exp2' in - let exp2' = subst_exp ref_vars substs ksubsts exp2' in - let exp2' = stop_at_false_assertions exp2' in - Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) - patsubsts - else Lazy.force nosplit - | ConstrSplit patnsubsts -> - List.map (fun (pat',nsubst) -> - let pat' = Spec_analysis.nexp_subst_pat nsubst pat' in - let exp1' = Spec_analysis.nexp_subst_exp nsubst e1 in - let exp2' = Spec_analysis.nexp_subst_exp nsubst e2 in - Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l) - ) patnsubsts) - and map_letbind (LB_aux (lb,annot)) = - match lb with - | LB_val (p,e) -> LB_aux (LB_val (check_single_pat p,map_exp e), annot) - and map_lexp ((LE_aux (e,annot)) as le) = - let re e = LE_aux (e,annot) in + | Pat_aux (Pat_exp (p, e), l) -> ( + let nosplit = lazy [Pat_aux (Pat_exp (p, map_exp e), l)] in + match map_pat p with + | NoSplit -> Lazy.force nosplit + | VarSplit patsubsts -> + if check_split_size patsubsts (pat_loc p) then + List.map + (fun (pat', substs, pchoices, ksubsts) -> + let plain_ksubsts = KBindings.map fst ksubsts in + let exp' = Spec_analysis.nexp_subst_exp plain_ksubsts e in + let exp' = apply_pat_choices pchoices exp' in + let exp' = subst_exp ref_vars substs ksubsts exp' in + let exp' = stop_at_false_assertions exp' in + Pat_aux (Pat_exp (pat', map_exp exp'), l) + ) + patsubsts + else Lazy.force nosplit + | ConstrSplit patnsubsts -> + List.map + (fun (pat', nsubst) -> + let pat' = Spec_analysis.nexp_subst_pat nsubst pat' in + let exp' = Spec_analysis.nexp_subst_exp nsubst e in + Pat_aux (Pat_exp (pat', map_exp exp'), l) + ) + patnsubsts + ) + | Pat_aux (Pat_when (p, e1, e2), l) -> ( + let nosplit = lazy [Pat_aux (Pat_when (p, map_exp e1, map_exp e2), l)] in + match map_pat p with + | NoSplit -> Lazy.force nosplit + | VarSplit patsubsts -> + if check_split_size patsubsts (pat_loc p) then + List.map + (fun (pat', substs, pchoices, ksubsts) -> + let plain_ksubsts = KBindings.map fst ksubsts in + let exp1' = Spec_analysis.nexp_subst_exp plain_ksubsts e1 in + let exp1' = apply_pat_choices pchoices exp1' in + let exp1' = subst_exp ref_vars substs ksubsts exp1' in + let plain_ksubsts = KBindings.map fst ksubsts in + let exp2' = Spec_analysis.nexp_subst_exp plain_ksubsts e2 in + let exp2' = apply_pat_choices pchoices exp2' in + let exp2' = subst_exp ref_vars substs ksubsts exp2' in + let exp2' = stop_at_false_assertions exp2' in + Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'), l) + ) + patsubsts + else Lazy.force nosplit + | ConstrSplit patnsubsts -> + List.map + (fun (pat', nsubst) -> + let pat' = Spec_analysis.nexp_subst_pat nsubst pat' in + let exp1' = Spec_analysis.nexp_subst_exp nsubst e1 in + let exp2' = Spec_analysis.nexp_subst_exp nsubst e2 in + Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'), l) + ) + patnsubsts + ) + and map_letbind (LB_aux (lb, annot)) = + match lb with LB_val (p, e) -> LB_aux (LB_val (check_single_pat p, map_exp e), annot) + and map_lexp (LE_aux (e, annot) as le) = + let re e = LE_aux (e, annot) in match e with - | LE_id _ - | LE_typ _ - -> le - | LE_app (id,es) -> re (LE_app (id,List.map map_exp es)) + | LE_id _ | LE_typ _ -> le + | LE_app (id, es) -> re (LE_app (id, List.map map_exp es)) | LE_tuple les -> re (LE_tuple (List.map map_lexp les)) - | LE_vector (le,e) -> re (LE_vector (map_lexp le, map_exp e)) - | LE_vector_range (le,e1,e2) -> re (LE_vector_range (map_lexp le, map_exp e1, map_exp e2)) + | LE_vector (le, e) -> re (LE_vector (map_lexp le, map_exp e)) + | LE_vector_range (le, e1, e2) -> re (LE_vector_range (map_lexp le, map_exp e1, map_exp e2)) | LE_vector_concat les -> re (LE_vector_concat (List.map map_lexp les)) - | LE_field (le,id) -> re (LE_field (map_lexp le, id)) + | LE_field (le, id) -> re (LE_field (map_lexp le, id)) | LE_deref e -> re (LE_deref (map_exp e)) - in map_exp, map_pexp, map_letbind + in + (map_exp, map_pexp, map_letbind) + in + let map_exp r = + let f, _, _ = map_fns r in + f + in + let map_pexp r = + let _, f, _ = map_fns r in + f + in + let map_letbind r = + let _, _, f = map_fns r in + f in - let map_exp r = let (f,_,_) = map_fns r in f in - let map_pexp r = let (_,f,_) = map_fns r in f in - let map_letbind r = let (_,_,f) = map_fns r in f in let map_exp exp = let ref_vars = Constant_propagation.referenced_vars exp in map_exp ref_vars exp @@ -1210,3270 +1205,3392 @@ let split_defs target all_errors (splits : split_req list) env ast = (* Construct the set of referenced variables so that we don't accidentally make false assumptions about them during constant propagation. Note that we assume there aren't any in the guard. *) - let (_,_,body,_) = destruct_pexp top_pexp in + let _, _, body, _ = destruct_pexp top_pexp in let ref_vars = Constant_propagation.referenced_vars body in map_pexp ref_vars top_pexp in - let map_letbind (LB_aux (LB_val (_,e),_) as lb) = + let map_letbind (LB_aux (LB_val (_, e), _) as lb) = let ref_vars = Constant_propagation.referenced_vars e in map_letbind ref_vars lb in - let map_funcl (FCL_aux (FCL_funcl (id,pexp),annot)) = - List.map (fun pexp -> FCL_aux (FCL_funcl (id,pexp),annot)) (map_pexp pexp) + let map_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = + List.map (fun pexp -> FCL_aux (FCL_funcl (id, pexp), annot)) (map_pexp pexp) in - let map_fundef (FD_aux (FD_function (r,t,fcls),annot)) = - FD_aux (FD_function (r,t,List.concat (List.map map_funcl fcls)),annot) + let map_fundef (FD_aux (FD_function (r, t, fcls), annot)) = + FD_aux (FD_function (r, t, List.concat (List.map map_funcl fcls)), annot) in let map_scattered_def sd = match sd with - | SD_aux (SD_funcl fcl, annot) -> - List.map (fun fcl' -> SD_aux (SD_funcl fcl', annot)) (map_funcl fcl) + | SD_aux (SD_funcl fcl, annot) -> List.map (fun fcl' -> SD_aux (SD_funcl fcl', annot)) (map_funcl fcl) | _ -> [sd] in let num_defs = List.length defs in let map_def idx (DEF_aux (aux, def_annot) as def) = Util.progress "Monomorphising " (string_of_int idx ^ "/" ^ string_of_int num_defs) idx num_defs; match aux with - | DEF_type _ - | DEF_val _ - | DEF_default _ - | DEF_register _ - | DEF_overload _ - | DEF_fixity _ - | DEF_pragma _ - | DEF_internal_mutrec _ - -> [def] + | DEF_type _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ | DEF_fixity _ | DEF_pragma _ + | DEF_internal_mutrec _ -> + [def] | DEF_fundef fd -> [DEF_aux (DEF_fundef (map_fundef fd), def_annot)] | DEF_let lb -> [DEF_aux (DEF_let (map_letbind lb), def_annot)] | DEF_scattered sd -> List.map (fun x -> DEF_aux (DEF_scattered x, def_annot)) (map_scattered_def sd) - | DEF_measure (id,pat,exp) -> [DEF_aux (DEF_measure (id,pat,map_exp exp), def_annot)] + | DEF_measure (id, pat, exp) -> [DEF_aux (DEF_measure (id, pat, map_exp exp), def_annot)] | DEF_impl _ | DEF_instantiation _ | DEF_outcome _ | DEF_mapdef _ | DEF_loop_measures _ -> - Reporting.unreachable (def_loc def) __POS__ - "Found definition that should have been rewritten previously during monomorphisation" + Reporting.unreachable (def_loc def) __POS__ + "Found definition that should have been rewritten previously during monomorphisation" in List.concat (List.mapi map_def defs) in let defs'' = map_locs splits defs' in Util.progress "Monomorphising " "done" (List.length defs'') (List.length defs''); - !no_errors_happened, { ast with defs = defs'' } - - + (!no_errors_happened, { ast with defs = defs'' }) (* The next section of code turns atom('n) types into itself('n) types, which survive into the Lem output, so can be used to parametrise functions over internal bitvector lengths (such as datasize and regsize in ARM specs *) -module AtomToItself = -struct - -let mapat f is xs = - let rec aux n = function - | [] -> [] - | h::t when Util.IntSet.mem n is -> - let h' = f h in - let t' = aux (n+1) t in - h'::t' - | h::t -> - let t' = aux (n+1) t in - h::t' - in aux 0 xs - -let mapat_extra f is xs = - let rec aux n = function - | [] -> [], [] - | h::t when Util.IntSet.mem n is -> - let h',x = f n h in - let t',xs = aux (n+1) t in - h'::t',x::xs - | h::t -> - let t',xs = aux (n+1) t in - h::t',xs - in aux 0 xs - -let change_parameter_pat i = function - | P_aux (P_id var, (l,_)) - | P_aux (P_typ (_,P_aux (P_id var, (l,_))),_) -> - P_aux (P_id var, (l,empty_tannot)), ([var],[]) - | P_aux (P_lit lit,(l,_)) -> - let var = mk_id ("p#" ^ string_of_int i) in - let annot = (Generated l, empty_tannot) in - let test : tannot exp = - E_aux (E_app_infix (E_aux (E_app (mk_id "size_itself_int",[E_aux (E_id var,annot)]),annot), - mk_id "==", - E_aux (E_lit lit,annot)), annot) in - P_aux (P_id var, (l,empty_tannot)), ([],[test]) - | P_aux (_,(l,_)) -> raise (Reporting.err_unreachable l __POS__ - "Expected variable pattern") - -(* TODO: make more precise, preferably with a proper free variables function - which deals with shadowing *) -let var_maybe_used_in_exp exp var = - let open Rewriter in - fst (fold_exp { - (compute_exp_alg false (||)) with - e_id = fun id -> (Id.compare id var == 0, E_id id) } exp) - -(* We add code to change the itself('n) parameter into the corresponding - integer. We always do this for the function body (otherwise we'd have to do - something clever with E_sizeof to avoid making things more complex), but - only for guards when they actually use the variable. *) -let add_var_rebind unconditional exp var = - if unconditional || var_maybe_used_in_exp exp var then - let l = Generated Unknown in - let annot = (l,empty_tannot) in - E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,annot), - E_aux (E_app (mk_id "size_itself_int",[E_aux (E_id var,annot)]),annot)),annot),exp),annot) - else exp - -(* atom('n) arguments to function calls need to be rewritten *) -let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) = - let env = env_of exp in - let typ, wrap = match typ_of exp with - | Typ_aux (Typ_exist (kids,nc,typ),l) -> typ, fun t -> Typ_aux (Typ_exist (kids,nc,t),l) - | typ -> typ, fun x -> x - in - let typ = Env.expand_synonyms env typ in - let replace_size size = - (* TODO: pick simpler nexp when there's a choice (also in pretty printer) *) - let is_equal nexp = - prove __POS__ env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown)) +module AtomToItself = struct + let mapat f is xs = + let rec aux n = function + | [] -> [] + | h :: t when Util.IntSet.mem n is -> + let h' = f h in + let t' = aux (n + 1) t in + h' :: t' + | h :: t -> + let t' = aux (n + 1) t in + h :: t' in - if is_nexp_constant size then size else - match solve_unique env size with - | Some n -> nconstant n - | None -> - match List.find is_equal bound_nexps with - | nexp -> nexp - | exception Not_found -> size - in - let mk_exp nexp l l' = - let nexp = replace_size nexp in - E_aux (E_typ (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown), - [A_aux (A_nexp nexp,l')]),Generated Unknown)), - E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,empty_tannot))), - (Generated l,empty_tannot)) - in - match destruct_numeric typ with - | Some ([], nc, nexp) when prove __POS__ env nc -> mk_exp nexp l l - | _ -> raise (Reporting.err_unreachable l __POS__ - ("replace_with_the_value: Unsupported type " ^ string_of_typ typ)) - -let replace_type env typ = - let Typ_aux (t,l) = Env.expand_synonyms env typ in - match destruct_numeric typ with - | Some ([], nc, nexp) when prove __POS__ env nc -> - Typ_aux (Typ_app (mk_id "itself", [A_aux (A_nexp nexp, Generated l)]), Generated l) - | _ -> raise (Reporting.err_unreachable l __POS__ - ("replace_type: Unsupported type " ^ string_of_typ typ)) - - -let rewrite_size_parameters target type_env ast = - let open Rewriter in - let open Util in - - let const_prop_exp exp = - let ref_vars = Constant_propagation.referenced_vars exp in - let substs = (Bindings.empty, KBindings.empty) in - let assigns = Bindings.empty in - fst (Constant_propagation.const_prop target ast ref_vars substs assigns exp) - in - let const_prop_pexp pexp = - let (pat, guard, exp, a) = destruct_pexp pexp in - construct_pexp (pat, guard, const_prop_exp exp, a) - in - let const_prop_funcl (FCL_aux (FCL_funcl (id, pexp), a)) = - FCL_aux (FCL_funcl (id, const_prop_pexp pexp), a) - in + aux 0 xs + + let mapat_extra f is xs = + let rec aux n = function + | [] -> ([], []) + | h :: t when Util.IntSet.mem n is -> + let h', x = f n h in + let t', xs = aux (n + 1) t in + (h' :: t', x :: xs) + | h :: t -> + let t', xs = aux (n + 1) t in + (h :: t', xs) + in + aux 0 xs + + let change_parameter_pat i = function + | P_aux (P_id var, (l, _)) | P_aux (P_typ (_, P_aux (P_id var, (l, _))), _) -> + (P_aux (P_id var, (l, empty_tannot)), ([var], [])) + | P_aux (P_lit lit, (l, _)) -> + let var = mk_id ("p#" ^ string_of_int i) in + let annot = (Generated l, empty_tannot) in + let test : tannot exp = + E_aux + ( E_app_infix + ( E_aux (E_app (mk_id "size_itself_int", [E_aux (E_id var, annot)]), annot), + mk_id "==", + E_aux (E_lit lit, annot) + ), + annot + ) + in + (P_aux (P_id var, (l, empty_tannot)), ([], [test])) + | P_aux (_, (l, _)) -> raise (Reporting.err_unreachable l __POS__ "Expected variable pattern") - let sizes_funcl fsizes (FCL_aux (FCL_funcl (id,pexp),(def_annot,ann))) = - let l = def_annot.loc in - let env = env_of_tannot ann in - let _, typ = Env.get_val_spec_orig id env in - let already_visible_nexps = - NexpSet.union - (lem_nexps_of_typ typ) - (typeclass_nexps typ) + (* TODO: make more precise, preferably with a proper free variables function + which deals with shadowing *) + let var_maybe_used_in_exp exp var = + let open Rewriter in + fst (fold_exp { (compute_exp_alg false ( || )) with e_id = (fun id -> (Id.compare id var == 0, E_id id)) } exp) + + (* We add code to change the itself('n) parameter into the corresponding + integer. We always do this for the function body (otherwise we'd have to do + something clever with E_sizeof to avoid making things more complex), but + only for guards when they actually use the variable. *) + let add_var_rebind unconditional exp var = + if unconditional || var_maybe_used_in_exp exp var then ( + let l = Generated Unknown in + let annot = (l, empty_tannot) in + E_aux + ( E_let + ( LB_aux + ( LB_val + (P_aux (P_id var, annot), E_aux (E_app (mk_id "size_itself_int", [E_aux (E_id var, annot)]), annot)), + annot + ), + exp + ), + annot + ) + ) + else exp + + (* atom('n) arguments to function calls need to be rewritten *) + let replace_with_the_value bound_nexps (E_aux (_, (l, _)) as exp) = + let env = env_of exp in + let typ, wrap = + match typ_of exp with + | Typ_aux (Typ_exist (kids, nc, typ), l) -> (typ, fun t -> Typ_aux (Typ_exist (kids, nc, t), l)) + | typ -> (typ, fun x -> x) in - let types = match typ with - | Typ_aux (Typ_fn (arg_typs,_),_) -> List.map (Env.expand_synonyms env) arg_typs - | _ -> raise (Reporting.err_unreachable l __POS__ "Function clause does not have a function type") + let typ = Env.expand_synonyms env typ in + let replace_size size = + (* TODO: pick simpler nexp when there's a choice (also in pretty printer) *) + let is_equal nexp = prove __POS__ env (NC_aux (NC_equal (size, nexp), Parse_ast.Unknown)) in + if is_nexp_constant size then size + else ( + match solve_unique env size with + | Some n -> nconstant n + | None -> ( + match List.find is_equal bound_nexps with nexp -> nexp | exception Not_found -> size + ) + ) in - let add_parameter (i,nmap) typ = - let nmap = - match Env.base_typ_of env typ with - Typ_aux (Typ_app(Id_aux (Id "range",_), - [A_aux (A_nexp nexp,_); - A_aux (A_nexp nexp',_)]),_) - when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) && - not (NexpSet.mem nexp already_visible_nexps) -> - (* Split integer variables if the nexp is not already available via a bitvector length *) - NexpMap.add nexp i nmap - | Typ_aux (Typ_app(Id_aux (Id "atom", _), - [A_aux (A_nexp nexp,_)]), _) - when not (NexpMap.mem nexp nmap) && - not (NexpSet.mem nexp already_visible_nexps) -> - NexpMap.add nexp i nmap - | _ -> nmap - in (i+1,nmap) + let mk_exp nexp l l' = + let nexp = replace_size nexp in + E_aux + ( E_typ + ( wrap + (Typ_aux + (Typ_app (Id_aux (Id "itself", Generated Unknown), [A_aux (A_nexp nexp, l')]), Generated Unknown) + ), + E_aux (E_app (Id_aux (Id "make_the_value", Generated Unknown), [exp]), (Generated l, empty_tannot)) + ), + (Generated l, empty_tannot) + ) in - let (_,nexp_map) = List.fold_left add_parameter (0,NexpMap.empty) types in - let nexp_list = NexpMap.bindings nexp_map in -(* let () = - print_endline ("Type of pattern for " ^ string_of_id id ^": " ^string_of_typ (typ_of_pat pat)); - print_endline ("Types : " ^ String.concat ", " (List.map string_of_typ types)); - print_endline ("Nexp map for " ^ string_of_id id); - List.iter (fun (nexp, i) -> print_endline (" " ^ string_of_nexp nexp ^ " -> " ^ string_of_int i)) nexp_list -in *) - let parameters_for e tannot = - let parameters_for_nexp env size = - match solve_unique env size with - | Some _ -> IntSet.empty - | None -> - match NexpMap.find size nexp_map with - | i -> IntSet.singleton i - | exception Not_found -> - (* Look for equivalent nexps, but only in consistent type env *) - if prove __POS__ env (NC_aux (NC_false,Unknown)) then IntSet.empty else - match List.find (fun (nexp,i) -> - prove __POS__ env (NC_aux (NC_equal (nexp,size),Unknown))) nexp_list with - | _, i -> IntSet.singleton i - | exception Not_found -> IntSet.empty + match destruct_numeric typ with + | Some ([], nc, nexp) when prove __POS__ env nc -> mk_exp nexp l l + | _ -> raise (Reporting.err_unreachable l __POS__ ("replace_with_the_value: Unsupported type " ^ string_of_typ typ)) + + let replace_type env typ = + let (Typ_aux (t, l)) = Env.expand_synonyms env typ in + match destruct_numeric typ with + | Some ([], nc, nexp) when prove __POS__ env nc -> + Typ_aux (Typ_app (mk_id "itself", [A_aux (A_nexp nexp, Generated l)]), Generated l) + | _ -> raise (Reporting.err_unreachable l __POS__ ("replace_type: Unsupported type " ^ string_of_typ typ)) + + let rewrite_size_parameters target type_env ast = + let open Rewriter in + let open Util in + let const_prop_exp exp = + let ref_vars = Constant_propagation.referenced_vars exp in + let substs = (Bindings.empty, KBindings.empty) in + let assigns = Bindings.empty in + fst (Constant_propagation.const_prop target ast ref_vars substs assigns exp) + in + let const_prop_pexp pexp = + let pat, guard, exp, a = destruct_pexp pexp in + construct_pexp (pat, guard, const_prop_exp exp, a) + in + let const_prop_funcl (FCL_aux (FCL_funcl (id, pexp), a)) = FCL_aux (FCL_funcl (id, const_prop_pexp pexp), a) in + + let sizes_funcl fsizes (FCL_aux (FCL_funcl (id, pexp), (def_annot, ann))) = + let l = def_annot.loc in + let env = env_of_tannot ann in + let _, typ = Env.get_val_spec_orig id env in + let already_visible_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) in + let types = + match typ with + | Typ_aux (Typ_fn (arg_typs, _), _) -> List.map (Env.expand_synonyms env) arg_typs + | _ -> raise (Reporting.err_unreachable l __POS__ "Function clause does not have a function type") in - let parameters_for_typ = - match destruct_tannot tannot with - | Some (env,typ) -> - begin match Env.base_typ_of env typ with - | Typ_aux (Typ_app (Id_aux (Id "bitvector",_), [A_aux (A_nexp size,_);_]),_) - when not (is_nexp_constant size) -> - parameters_for_nexp env size + let add_parameter (i, nmap) typ = + let nmap = + match Env.base_typ_of env typ with + | Typ_aux (Typ_app (Id_aux (Id "range", _), [A_aux (A_nexp nexp, _); A_aux (A_nexp nexp', _)]), _) + when Nexp.compare nexp nexp' = 0 + && (not (NexpMap.mem nexp nmap)) + && not (NexpSet.mem nexp already_visible_nexps) -> + (* Split integer variables if the nexp is not already available via a bitvector length *) + NexpMap.add nexp i nmap + | Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp nexp, _)]), _) + when (not (NexpMap.mem nexp nmap)) && not (NexpSet.mem nexp already_visible_nexps) -> + NexpMap.add nexp i nmap + | _ -> nmap + in + (i + 1, nmap) + in + let _, nexp_map = List.fold_left add_parameter (0, NexpMap.empty) types in + let nexp_list = NexpMap.bindings nexp_map in + (* let () = + print_endline ("Type of pattern for " ^ string_of_id id ^": " ^string_of_typ (typ_of_pat pat)); + print_endline ("Types : " ^ String.concat ", " (List.map string_of_typ types)); + print_endline ("Nexp map for " ^ string_of_id id); + List.iter (fun (nexp, i) -> print_endline (" " ^ string_of_nexp nexp ^ " -> " ^ string_of_int i)) nexp_list + in *) + let parameters_for e tannot = + let parameters_for_nexp env size = + match solve_unique env size with + | Some _ -> IntSet.empty + | None -> ( + match NexpMap.find size nexp_map with + | i -> IntSet.singleton i + | exception Not_found -> + (* Look for equivalent nexps, but only in consistent type env *) + if prove __POS__ env (NC_aux (NC_false, Unknown)) then IntSet.empty + else ( + match + List.find (fun (nexp, i) -> prove __POS__ env (NC_aux (NC_equal (nexp, size), Unknown))) nexp_list + with + | _, i -> IntSet.singleton i + | exception Not_found -> IntSet.empty + ) + ) + in + let parameters_for_typ = + match destruct_tannot tannot with + | Some (env, typ) -> begin + match Env.base_typ_of env typ with + | Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp size, _); _]), _) + when not (is_nexp_constant size) -> + parameters_for_nexp env size + | _ -> IntSet.empty + end + | None -> IntSet.empty + in + let parameters_for_exp = + match e with + | E_app (id, args) when Bindings.mem id fsizes -> + let add_arg (i, s) arg = + if IntSet.mem i (fst (Bindings.find id fsizes)) then ( + try + match destruct_numeric (typ_of arg) with + | Some ([], _, nexp) -> (i + 1, IntSet.union s (parameters_for_nexp env nexp)) + | _ -> (i + 1, s) + with _ -> (i + 1, s) + ) + else (i + 1, s) + in + snd (List.fold_left add_arg (0, IntSet.empty) args) | _ -> IntSet.empty - end - | None -> IntSet.empty + in + IntSet.union parameters_for_typ parameters_for_exp in - let parameters_for_exp = match e with - | E_app (id, args) when Bindings.mem id fsizes -> - let add_arg (i, s) arg = - if IntSet.mem i (fst (Bindings.find id fsizes)) then - try match destruct_numeric (typ_of arg) with - | Some ([], _, nexp) -> - (i + 1, IntSet.union s (parameters_for_nexp env nexp)) - | _ -> (i + 1, s) - with _ -> (i + 1, s) - else (i + 1, s) - in - snd (List.fold_left add_arg (0, IntSet.empty) args) - | _ -> IntSet.empty + let parameters_to_rewrite = + fst + (fold_pexp + { + (compute_exp_alg IntSet.empty IntSet.union) with + e_aux = (fun ((s, e), (l, annot)) -> (IntSet.union s (parameters_for e annot), E_aux (e, (l, annot)))); + } + pexp + ) in - IntSet.union parameters_for_typ parameters_for_exp + let new_nexps = + NexpSet.of_list (List.map fst (List.filter (fun (nexp, i) -> IntSet.mem i parameters_to_rewrite) nexp_list)) + in + match Bindings.find id fsizes with + | old, old_nexps -> + Bindings.add id (IntSet.union old parameters_to_rewrite, NexpSet.union old_nexps new_nexps) fsizes + | exception Not_found -> Bindings.add id (parameters_to_rewrite, new_nexps) fsizes in - let parameters_to_rewrite = - fst (fold_pexp - { (compute_exp_alg IntSet.empty IntSet.union) with - e_aux = (fun ((s,e),(l,annot)) -> IntSet.union s (parameters_for e annot),E_aux (e,(l,annot))) - } pexp) + let sizes_def fsizes = function + | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, funcls), _)), _) -> List.fold_left sizes_funcl fsizes funcls + | _ -> fsizes in - let new_nexps = NexpSet.of_list (List.map fst - (List.filter (fun (nexp,i) -> IntSet.mem i parameters_to_rewrite) nexp_list)) in - match Bindings.find id fsizes with - | old,old_nexps -> Bindings.add id (IntSet.union old parameters_to_rewrite, - NexpSet.union old_nexps new_nexps) fsizes - | exception Not_found -> Bindings.add id (parameters_to_rewrite, new_nexps) fsizes - in - let sizes_def fsizes = function - | DEF_aux (DEF_fundef (FD_aux (FD_function (_,_,funcls),_)),_) -> - List.fold_left sizes_funcl fsizes funcls - | _ -> fsizes - in - let fn_sizes = List.fold_left sizes_def Bindings.empty ast.defs in - - let rewrite_funcl (FCL_aux (FCL_funcl (id,pexp),(def_annot,annot))) = - let pat,guard,body,(pl,_) = destruct_pexp pexp in - let pat,guard,body, nexps = - (* Update pattern and add itself -> nat wrapper to body *) - match Bindings.find id fn_sizes with - | to_change,nexps -> - let pat, vars, new_guards = - match pat with - P_aux (P_tuple pats,(l,_)) -> - let pats, vars_guards = mapat_extra change_parameter_pat to_change pats in - let vars, new_guards = List.split vars_guards in - P_aux (P_tuple pats,(l,empty_tannot)), vars, new_guards - | P_aux (_,(l,_)) -> - begin - if IntSet.is_empty to_change then pat, [], [] - else - let pat, (var, newguard) = change_parameter_pat 0 pat in - pat, [var], [newguard] - end - in - let vars, new_guards = List.concat vars, List.concat new_guards in - let body = List.fold_left (add_var_rebind true) body vars in - let merge_guards g1 g2 : tannot exp = - E_aux (E_app_infix (g1, mk_id "&", g2),(Generated Unknown,empty_tannot)) in - let guard = match guard, new_guards with - | None, [] -> None - | None, (h::t) -> Some (List.fold_left merge_guards h t) - | Some exp, gs -> - let exp' = List.fold_left (add_var_rebind false) exp vars in - Some (List.fold_left merge_guards exp' gs) - in - pat,guard,body,nexps - | exception Not_found -> pat,guard,body,NexpSet.empty + let fn_sizes = List.fold_left sizes_def Bindings.empty ast.defs in + + let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), (def_annot, annot))) = + let pat, guard, body, (pl, _) = destruct_pexp pexp in + let pat, guard, body, nexps = + (* Update pattern and add itself -> nat wrapper to body *) + match Bindings.find id fn_sizes with + | to_change, nexps -> + let pat, vars, new_guards = + match pat with + | P_aux (P_tuple pats, (l, _)) -> + let pats, vars_guards = mapat_extra change_parameter_pat to_change pats in + let vars, new_guards = List.split vars_guards in + (P_aux (P_tuple pats, (l, empty_tannot)), vars, new_guards) + | P_aux (_, (l, _)) -> begin + if IntSet.is_empty to_change then (pat, [], []) + else ( + let pat, (var, newguard) = change_parameter_pat 0 pat in + (pat, [var], [newguard]) + ) + end + in + let vars, new_guards = (List.concat vars, List.concat new_guards) in + let body = List.fold_left (add_var_rebind true) body vars in + let merge_guards g1 g2 : tannot exp = + E_aux (E_app_infix (g1, mk_id "&", g2), (Generated Unknown, empty_tannot)) + in + let guard = + match (guard, new_guards) with + | None, [] -> None + | None, h :: t -> Some (List.fold_left merge_guards h t) + | Some exp, gs -> + let exp' = List.fold_left (add_var_rebind false) exp vars in + Some (List.fold_left merge_guards exp' gs) + in + (pat, guard, body, nexps) + | exception Not_found -> (pat, guard, body, NexpSet.empty) + in + (* Update function applications *) + let funcl_typ = typ_of_tannot annot in + let already_visible_nexps = NexpSet.union (lem_nexps_of_typ funcl_typ) (typeclass_nexps funcl_typ) in + let bound_nexps = NexpSet.elements (NexpSet.union nexps already_visible_nexps) in + let rewrite_e_app (id, args) = + match Bindings.find id fn_sizes with + | to_change, _ -> + let args' = mapat (replace_with_the_value bound_nexps) to_change args in + E_app (id, args') + | exception Not_found -> E_app (id, args) + in + let body = fold_exp { id_exp_alg with e_app = rewrite_e_app } body in + let guard = + match guard with None -> None | Some exp -> Some (fold_exp { id_exp_alg with e_app = rewrite_e_app } exp) + in + FCL_aux (FCL_funcl (id, construct_pexp (pat, guard, body, (pl, empty_tannot))), (def_annot, empty_tannot)) in - (* Update function applications *) - let funcl_typ = typ_of_tannot annot in - let already_visible_nexps = - NexpSet.union - (lem_nexps_of_typ funcl_typ) - (typeclass_nexps funcl_typ) + let rewrite_e_app (id, args) = + match Bindings.find id fn_sizes with + | to_change, _ -> + let args' = mapat (replace_with_the_value []) to_change args in + E_app (id, args') + | exception Not_found -> E_app (id, args) in - let bound_nexps = NexpSet.elements (NexpSet.union nexps already_visible_nexps) in - let rewrite_e_app (id,args) = + let rewrite_letbind = fold_letbind { id_exp_alg with e_app = rewrite_e_app } in + let rewrite_exp = fold_exp { id_exp_alg with e_app = rewrite_e_app } in + let replace_funtype id typ = match Bindings.find id fn_sizes with - | to_change,_ -> - let args' = mapat (replace_with_the_value bound_nexps) to_change args in - E_app (id,args') - | exception Not_found -> E_app (id,args) + | to_change, _ when not (IntSet.is_empty to_change) -> begin + match typ with + | Typ_aux (Typ_fn (ts, t2), l2) -> Typ_aux (Typ_fn (mapat (replace_type type_env) to_change ts, t2), l2) + | _ -> replace_type type_env typ + end + | _ -> typ + | exception Not_found -> typ in - let body = fold_exp { id_exp_alg with e_app = rewrite_e_app } body in - let guard = match guard with - | None -> None - | Some exp -> Some (fold_exp { id_exp_alg with e_app = rewrite_e_app } exp) in - FCL_aux (FCL_funcl (id,construct_pexp (pat,guard,body,(pl,empty_tannot))),(def_annot,empty_tannot)) - in - let rewrite_e_app (id,args) = - match Bindings.find id fn_sizes with - | to_change,_ -> - let args' = mapat (replace_with_the_value []) to_change args in - E_app (id,args') - | exception Not_found -> E_app (id,args) - in - let rewrite_letbind = fold_letbind { id_exp_alg with e_app = rewrite_e_app } in - let rewrite_exp = fold_exp { id_exp_alg with e_app = rewrite_e_app } in - let replace_funtype id typ = - match Bindings.find id fn_sizes with - | to_change,_ when not (IntSet.is_empty to_change) -> - begin match typ with - | Typ_aux (Typ_fn (ts,t2),l2) -> - Typ_aux (Typ_fn (mapat (replace_type type_env) to_change ts,t2),l2) - | _ -> replace_type type_env typ - end - | _ -> typ - | exception Not_found -> typ - in - let type_env' = - let update_val_spec id _ env = - let (tq, typ) = Env.get_val_spec_orig id env in - Env.update_val_spec id (tq, replace_funtype id typ) env + let type_env' = + let update_val_spec id _ env = + let tq, typ = Env.get_val_spec_orig id env in + Env.update_val_spec id (tq, replace_funtype id typ) env + in + Bindings.fold update_val_spec fn_sizes type_env in - Bindings.fold update_val_spec fn_sizes type_env - in - let rewrite_def (DEF_aux (aux, def_annot)) = - let aux = match aux with - | DEF_fundef (FD_aux (FD_function (recopt,tannopt,funcls),(l,_))) -> - let funcls = List.map rewrite_funcl funcls in - (* Check whether we have ended up with itself('n) expressions where 'n - is not constant. If so, try and see if constant propagation can - resolve those variable expressions. In many cases the monomorphisation - pass will already have performed constant propagation, but it does not - for functions where it does not perform splits.*) - let check_funcl (FCL_aux (FCL_funcl (id, pexp), (def_annot, _)) as funcl) = - let has_nonconst_sizes = - let check_cast (typ, _) = - match unaux_typ typ with - | Typ_app (itself, [A_aux (A_nexp nexp, _)]) - | Typ_exist (_, _, Typ_aux (Typ_app (itself, [A_aux (A_nexp nexp, _)]), _)) + let rewrite_def (DEF_aux (aux, def_annot)) = + let aux = + match aux with + | DEF_fundef (FD_aux (FD_function (recopt, tannopt, funcls), (l, _))) -> + let funcls = List.map rewrite_funcl funcls in + (* Check whether we have ended up with itself('n) expressions where 'n + is not constant. If so, try and see if constant propagation can + resolve those variable expressions. In many cases the monomorphisation + pass will already have performed constant propagation, but it does not + for functions where it does not perform splits.*) + let check_funcl (FCL_aux (FCL_funcl (id, pexp), (def_annot, _)) as funcl) = + let has_nonconst_sizes = + let check_cast (typ, _) = + match unaux_typ typ with + | Typ_app (itself, [A_aux (A_nexp nexp, _)]) + | Typ_exist (_, _, Typ_aux (Typ_app (itself, [A_aux (A_nexp nexp, _)]), _)) when string_of_id itself = "itself" -> - not (is_nexp_constant nexp) - | _ -> false - in - fold_pexp { (pure_exp_alg false (||)) with e_typ = check_cast } pexp - in - if has_nonconst_sizes then - (* Constant propagation requires a fully type-annotated AST, - so re-check the function clause *) - let (tq, typ) = Env.get_val_spec id type_env' in - let env = Env.add_typquant def_annot.loc tq type_env' in - const_prop_funcl (Type_check.check_funcl env (strip_funcl funcl) typ) - else funcl - in - let funcls = List.map check_funcl funcls in - (* TODO rewrite tannopt? *) - DEF_fundef (FD_aux (FD_function (recopt,tannopt,funcls),(l,empty_tannot))) - | DEF_let lb -> DEF_let (rewrite_letbind lb) - | DEF_val (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,annot))) -> - let typschm = match typschm with - | TypSchm_aux (TypSchm_ts (tq, typ),l) -> - TypSchm_aux (TypSchm_ts (tq, replace_funtype id typ), l) - in - DEF_val (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,annot))) - | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), a)) -> - DEF_register (DEC_aux (DEC_reg (typ, id, Some (rewrite_exp exp)), a)) - | _ -> aux + not (is_nexp_constant nexp) + | _ -> false + in + fold_pexp { (pure_exp_alg false ( || )) with e_typ = check_cast } pexp + in + if has_nonconst_sizes then ( + (* Constant propagation requires a fully type-annotated AST, + so re-check the function clause *) + let tq, typ = Env.get_val_spec id type_env' in + let env = Env.add_typquant def_annot.loc tq type_env' in + const_prop_funcl (Type_check.check_funcl env (strip_funcl funcl) typ) + ) + else funcl + in + let funcls = List.map check_funcl funcls in + (* TODO rewrite tannopt? *) + DEF_fundef (FD_aux (FD_function (recopt, tannopt, funcls), (l, empty_tannot))) + | DEF_let lb -> DEF_let (rewrite_letbind lb) + | DEF_val (VS_aux (VS_val_spec (typschm, id, extern, cast), (l, annot))) -> + let typschm = + match typschm with + | TypSchm_aux (TypSchm_ts (tq, typ), l) -> TypSchm_aux (TypSchm_ts (tq, replace_funtype id typ), l) + in + DEF_val (VS_aux (VS_val_spec (typschm, id, extern, cast), (l, annot))) + | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), a)) -> + DEF_register (DEC_aux (DEC_reg (typ, id, Some (rewrite_exp exp)), a)) + | _ -> aux + in + DEF_aux (aux, def_annot) in - DEF_aux (aux, def_annot) - in -(* + (* Bindings.iter (fun id args -> print_endline (string_of_id id ^ " needs " ^ String.concat ", " (List.map string_of_int args))) fn_sizes *) - { ast with defs = List.map rewrite_def ast.defs } - + { ast with defs = List.map rewrite_def ast.defs } end - let is_id env id = - let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in - let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in - fun (Id_aux (x,_)) -> List.mem x ids + let ids = Env.get_overloads (Id_aux (id, Parse_ast.Unknown)) env in + let ids = id :: List.map (fun (Id_aux (id, _)) -> id) ids in + fun (Id_aux (x, _)) -> List.mem x ids (* Type-agnostic pattern comparison for merging below *) -let lit_eq' (L_aux (l1,_)) (L_aux (l2,_)) = - match l1, l2 with - | L_num n1, L_num n2 -> Big_int.equal n1 n2 - | _,_ -> l1 = l2 +let lit_eq' (L_aux (l1, _)) (L_aux (l2, _)) = + match (l1, l2) with L_num n1, L_num n2 -> Big_int.equal n1 n2 | _, _ -> l1 = l2 -let forall2 p x y = - try List.for_all2 p x y with Invalid_argument _ -> false +let forall2 p x y = try List.for_all2 p x y with Invalid_argument _ -> false let rec typ_pat_eq (TP_aux (tp1, _)) (TP_aux (tp2, _)) = - match tp1, tp2 with + match (tp1, tp2) with | TP_wild, TP_wild -> true | TP_var kid1, TP_var kid2 -> Kid.compare kid1 kid2 = 0 | TP_app (f1, args1), TP_app (f2, args2) when List.length args1 = List.length args2 -> - Id.compare f1 f2 = 0 && List.for_all2 typ_pat_eq args1 args2 + Id.compare f1 f2 = 0 && List.for_all2 typ_pat_eq args1 args2 | _, _ -> false -let rec pat_eq (P_aux (p1,_)) (P_aux (p2,_)) = - match p1, p2 with +let rec pat_eq (P_aux (p1, _)) (P_aux (p2, _)) = + match (p1, p2) with | P_lit lit1, P_lit lit2 -> lit_eq' lit1 lit2 | P_wild, P_wild -> true | P_or (p1, q1), P_or (p2, q2) -> - (* ToDo: A case could be made for flattening trees of P_or nodes and - * comparing the lists so that we treat P_or as associative - *) - pat_eq p1 p2 && pat_eq q1 q2 - | P_not(p1), P_not(p2) -> pat_eq p1 p2 - | P_as (p1',id1), P_as (p2',id2) -> Id.compare id1 id2 == 0 && pat_eq p1' p2' - | P_typ (_,p1'), P_typ (_,p2') -> pat_eq p1' p2' + (* ToDo: A case could be made for flattening trees of P_or nodes and + * comparing the lists so that we treat P_or as associative + *) + pat_eq p1 p2 && pat_eq q1 q2 + | P_not p1, P_not p2 -> pat_eq p1 p2 + | P_as (p1', id1), P_as (p2', id2) -> Id.compare id1 id2 == 0 && pat_eq p1' p2' + | P_typ (_, p1'), P_typ (_, p2') -> pat_eq p1' p2' | P_id id1, P_id id2 -> Id.compare id1 id2 == 0 | P_var (p1', tpat1), P_var (p2', tpat2) -> typ_pat_eq tpat1 tpat2 && pat_eq p1' p2' - | P_app (id1,args1), P_app (id2,args2) -> - Id.compare id1 id2 == 0 && forall2 pat_eq args1 args2 + | P_app (id1, args1), P_app (id2, args2) -> Id.compare id1 id2 == 0 && forall2 pat_eq args1 args2 | P_vector ps1, P_vector ps2 | P_vector_concat ps1, P_vector_concat ps2 | P_tuple ps1, P_tuple ps2 - | P_list ps1, P_list ps2 -> List.for_all2 pat_eq ps1 ps2 - | P_cons (p1',p1''), P_cons (p2',p2'') -> pat_eq p1' p2' && pat_eq p1'' p2'' - | _,_ -> false - - -module Analysis = -struct - -(* Does a location contain enough information to identify the syntax again? *) -let rec useful_loc = function - | Unknown -> false - | Unique (_,l) -> useful_loc l - | Generated l -> useful_loc l - | Hint (_,_,l) -> useful_loc l - | Range (_,_) -> true - -(* Usually we do a full case split on an argument, but sometimes we find a - case expression in the function body that suggests a more compact case - splitting. *) -type match_detail = - | Total - | Partial of tannot pat list * Parse_ast.l - -module IdLocMap = Map.Make (struct - type t = id * Parse_ast.l - let compare (id,l) (id',l') = - let x = Id.compare id id' in - if x <> 0 then x else - compare l l' -end) - -(* Arguments that we might split on *) -module ArgSplits = IdLocMap -type arg_splits = match_detail ArgSplits.t - -(* Function id, funcl loc for adding splits on sizes in the body when - there's no corresponding argument *) -module ExtraSplits = IdLocMap -type extra_splits = (match_detail KBindings.t) ExtraSplits.t - -(* For a case split after a type variable is let-bound; in particular when - a function is called to provide a size via a side effect (e.g., reading - a vector size register). *) -module KidLocMap = Map.Make (struct - type t= kid * Parse_ast.l - let compare (kid,l) (kid',l') = - let x = Kid.compare kid kid' in - if x <> 0 then x else - compare l l' -end) -module LetSplits = IdLocMap -type let_binding_splits = match_detail LetSplits.t - -(* Arguments that we should look at in callers *) -module CallerArgSet = Set.Make (struct - type t = id * int - let compare (id,i) (id',i') = - let x= Id.compare id id' in - if x <> 0 then x else compare i i' -end) - -(* Type variables that we should look at in callers *) -module CallerKidSet = Set.Make (struct - type t = id * kid - let compare (id,kid) (id',kid') = - match Id.compare id id' with - | 0 -> Kid.compare kid kid' + | P_list ps1, P_list ps2 -> + List.for_all2 pat_eq ps1 ps2 + | P_cons (p1', p1''), P_cons (p2', p2'') -> pat_eq p1' p2' && pat_eq p1'' p2'' + | _, _ -> false + +module Analysis = struct + (* Does a location contain enough information to identify the syntax again? *) + let rec useful_loc = function + | Unknown -> false + | Unique (_, l) -> useful_loc l + | Generated l -> useful_loc l + | Hint (_, _, l) -> useful_loc l + | Range (_, _) -> true + + (* Usually we do a full case split on an argument, but sometimes we find a + case expression in the function body that suggests a more compact case + splitting. *) + type match_detail = Total | Partial of tannot pat list * Parse_ast.l + + module IdLocMap = Map.Make (struct + type t = id * Parse_ast.l + let compare (id, l) (id', l') = + let x = Id.compare id id' in + if x <> 0 then x else compare l l' + end) + + (* Arguments that we might split on *) + module ArgSplits = IdLocMap + type arg_splits = match_detail ArgSplits.t + + (* Function id, funcl loc for adding splits on sizes in the body when + there's no corresponding argument *) + module ExtraSplits = IdLocMap + type extra_splits = match_detail KBindings.t ExtraSplits.t + + (* For a case split after a type variable is let-bound; in particular when + a function is called to provide a size via a side effect (e.g., reading + a vector size register). *) + module KidLocMap = Map.Make (struct + type t = kid * Parse_ast.l + let compare (kid, l) (kid', l') = + let x = Kid.compare kid kid' in + if x <> 0 then x else compare l l' + end) + module LetSplits = IdLocMap + type let_binding_splits = match_detail LetSplits.t + + (* Arguments that we should look at in callers *) + module CallerArgSet = Set.Make (struct + type t = id * int + let compare (id, i) (id', i') = + let x = Id.compare id id' in + if x <> 0 then x else compare i i' + end) + + (* Type variables that we should look at in callers *) + module CallerKidSet = Set.Make (struct + type t = id * kid + let compare (id, kid) (id', kid') = match Id.compare id id' with 0 -> Kid.compare kid kid' | x -> x + end) + + (* Map from locations to string sets *) + module Failures = Map.Make (struct + type t = Parse_ast.l + let compare = compare + end) + module StringSet = Set.Make (struct + type t = string + let compare = compare + end) + + type dependencies = Have of arg_splits * extra_splits * let_binding_splits | Unknown of Parse_ast.l * string + + let string_of_match_detail = function + | Total -> "[total]" + | Partial (pats, _) -> "[" ^ String.concat " | " (List.map string_of_pat pats) ^ "]" + + let string_of_argsplits s = + String.concat ", " + (List.map + (fun ((id, l), detail) -> string_of_id id ^ "." ^ simple_string_of_loc l ^ string_of_match_detail detail) + (ArgSplits.bindings s) + ) + + let string_of_extra_splits s = + String.concat ", " + (List.map + (fun ((id, l), ks) -> + string_of_id id ^ "." ^ simple_string_of_loc l ^ ":" + ^ String.concat "," + (List.map + (fun (kid, detail) -> string_of_kid kid ^ "." ^ string_of_match_detail detail) + (KBindings.bindings ks) + ) + ) + (ExtraSplits.bindings s) + ) + + let string_of_let_binding_splits s = + String.concat ", " + (List.map + (fun ((id, l), detail) -> string_of_id id ^ "." ^ simple_string_of_loc l ^ "." ^ string_of_match_detail detail) + (LetSplits.bindings s) + ) + + let _string_of_callerset s = + String.concat ", " (List.map (fun (id, arg) -> string_of_id id ^ "." ^ string_of_int arg) (CallerArgSet.elements s)) + + let string_of_callerkidset s = + String.concat ", " (List.map (fun (id, kid) -> string_of_id id ^ "." ^ string_of_kid kid) (CallerKidSet.elements s)) + + let string_of_dep = function + | Have (args, extras, letbinds) -> + "Have (" ^ string_of_argsplits args ^ ";" ^ string_of_extra_splits extras ^ ";" + ^ string_of_let_binding_splits letbinds ^ ")" + | Unknown (l, msg) -> "Unknown " ^ msg ^ " at " ^ Reporting.loc_to_string l + + (* If a callee uses a type variable as a size, does it need to be split in the + current function, or is it also a parameter? (Note that there may be multiple + calls, so more than one parameter can be involved) *) + type call_dep = { in_fun : dependencies option; parents : CallerKidSet.t } + + let in_fun_call_dep deps = { in_fun = Some deps; parents = CallerKidSet.empty } + + let parents_call_dep cks = { in_fun = None; parents = cks } + + (* Result of analysing the body of a function. The split field gives + the arguments to split based on the body alone, the extra_splits + field where we want to case split on a size type variable but + there's no corresponding argument so we introduce a case + expression, and the failures field where we couldn't do anything. + The other fields are used at the end for the interprocedural + phase. *) + + type result = { + split : arg_splits; + extra_splits : extra_splits; + let_binding_splits : let_binding_splits; + failures : StringSet.t Failures.t; + (* Dependencies for type variables of each fn called, so that + if the fn uses one for a bitvector size we can track it back *) + split_on_call : call_dep KBindings.t Bindings.t; (* kids per fn *) + kid_in_caller : CallerKidSet.t; + } + + let empty = + { + split = ArgSplits.empty; + extra_splits = ExtraSplits.empty; + let_binding_splits = LetSplits.empty; + failures = Failures.empty; + split_on_call = Bindings.empty; + kid_in_caller = CallerKidSet.empty; + } + + let merge_detail _ x y = + match (x, y) with + | None, x -> x + | x, None -> x + | Some (Partial (ps1, l1)), Some (Partial (ps2, l2)) when l1 = l2 && forall2 pat_eq ps1 ps2 -> x + | _ -> Some Total + + let opt_merge f _ x y = match (x, y) with None, _ -> y | _, None -> x | Some x, Some y -> Some (f x y) + + let merge_extras = ExtraSplits.merge (opt_merge (KBindings.merge merge_detail)) + + let dmerge x y = + match (x, y) with + | Unknown (l, s), _ -> Unknown (l, s) + | _, Unknown (l, s) -> Unknown (l, s) + | Have (args, extras, lets), Have (args', extras', lets') -> + Have + (ArgSplits.merge merge_detail args args', merge_extras extras extras', LetSplits.merge merge_detail lets lets') + + let dempty = Have (ArgSplits.empty, ExtraSplits.empty, LetSplits.empty) + + let dep_bindings_merge a1 a2 = Bindings.merge (opt_merge dmerge) a1 a2 + + let dep_kbindings_merge a1 a2 = KBindings.merge (opt_merge dmerge) a1 a2 + + let call_dep_merge k d1 d2 = + { in_fun = opt_merge dmerge k d1.in_fun d2.in_fun; parents = CallerKidSet.union d1.parents d2.parents } + + let call_kid_merge k x y = + match (x, y) with None, x -> x | x, None -> x | Some d1, Some d2 -> Some (call_dep_merge k d1 d2) + + let call_arg_merge k args args' = + match (args, args') with + | None, x -> x + | x, None -> x + | Some kdep, Some kdep' -> Some (KBindings.merge call_kid_merge kdep kdep') + + let failure_merge _ x y = + match (x, y) with None, x -> x | x, None -> x | Some x, Some y -> Some (StringSet.union x y) + + let merge rs rs' = + { + split = ArgSplits.merge merge_detail rs.split rs'.split; + extra_splits = merge_extras rs.extra_splits rs'.extra_splits; + let_binding_splits = LetSplits.merge merge_detail rs.let_binding_splits rs'.let_binding_splits; + failures = Failures.merge failure_merge rs.failures rs'.failures; + split_on_call = Bindings.merge call_arg_merge rs.split_on_call rs'.split_on_call; + kid_in_caller = CallerKidSet.union rs.kid_in_caller rs'.kid_in_caller; + } + + type env = { + top_kids : kid list; (* Int kids bound by the function type *) + var_deps : dependencies Bindings.t; + kid_deps : dependencies KBindings.t; + referenced_vars : IdSet.t; + globals : bool Bindings.t (* is_value or not *); + } + + let rec split3 = function + | [] -> ([], [], []) + | (h1, h2, h3) :: t -> + let t1, t2, t3 = split3 t in + (h1 :: t1, h2 :: t2, h3 :: t3) + + let is_kid_in_env env kid = match Env.get_typ_var kid env with _ -> true | exception _ -> false + + let rec kids_bound_by_typ_pat (TP_aux (tp, _)) = + match tp with + | TP_wild -> KidSet.empty + | TP_var kid -> KidSet.singleton kid + | TP_app (_, pats) -> kidset_bigunion (List.map kids_bound_by_typ_pat pats) + + (* We need both the explicitly bound kids from the AST, and any freshly + generated kids from the typechecker. *) + let kids_bound_by_pat pat = + let open Rewriter in + fst + (fold_pat + { + (compute_pat_alg KidSet.empty KidSet.union) with + p_aux = + (function + | (s, (P_var (P_aux (_, annot'), tpat) as p)), annot when not (is_empty_tannot (snd annot')) -> + let kids = tyvars_of_typ (typ_of_annot annot') in + let new_kids = KidSet.filter (fun kid -> not (is_kid_in_env (env_of_annot annot) kid)) kids in + let tpat_kids = kids_bound_by_typ_pat tpat in + (KidSet.union s (KidSet.union new_kids tpat_kids), P_aux (p, annot)) + | (s, p), ann -> (s, P_aux (p, ann)) + ); + } + pat + ) + + (* Diff the type environment to find new type variables and record that they + depend on deps *) + + let update_env_new_kids env deps typ_env_pre typ_env_post = + let kbound = + KBindings.merge + (fun k x y -> match (x, y) with Some k, None -> Some k | _ -> None) + (Env.get_typ_vars typ_env_post) (Env.get_typ_vars typ_env_pre) + in + let kid_deps = KBindings.fold (fun v _ ds -> KBindings.add v deps ds) kbound env.kid_deps in + { env with kid_deps } + + (* Add bound variables from a pattern to the environment with the given dependency, + plus any new type variables. *) + + let update_env env deps pat typ_env_pre typ_env_post = + let bound = Spec_analysis.bindings_from_pat pat in + let var_deps = List.fold_left (fun ds v -> Bindings.add v deps ds) env.var_deps bound in + update_env_new_kids { env with var_deps } deps typ_env_pre typ_env_post + + (* A function argument may end up with fresh type variables due to coercing + unification (which will eventually be existentially bound in the type of + the function). Here we record the dependencies for these variables. *) + + let add_arg_only_kids env typ_env typ deps = + let all_vars = tyvars_of_typ typ in + let check_kid kid kid_deps = if KBindings.mem kid kid_deps then kid_deps else KBindings.add kid deps kid_deps in + let kid_deps = KidSet.fold check_kid all_vars env.kid_deps in + { env with kid_deps } + + let assigned_vars_exps es = + List.fold_left (fun vs exp -> IdSet.union vs (Spec_analysis.assigned_vars exp)) IdSet.empty es + + (* For adding control dependencies to mutable variables *) + + let add_dep_to_assigned dep assigns es = + let assigned = assigned_vars_exps es in + Bindings.mapi (fun id d -> if IdSet.mem id assigned then dmerge dep d else d) assigns + + (* Functions to give dependencies for type variables in nexps, constraints, types and + unification variables. For function calls we also supply a list of dependencies for + arguments so that we can find dependencies for existentially bound sizes. *) + + let deps_of_tyvars l kid_deps arg_deps kids = + let check kid deps = + match KBindings.find kid kid_deps with + | deps' -> dmerge deps deps' + | exception Not_found -> ( + match kid with + | Kid_aux (Var kidstr, _) -> + let unknown = Unknown (l, "Unknown type variable " ^ string_of_kid kid) in + (* Tyvars from existentials in arguments have a special format *) + if String.length kidstr > 5 && String.sub kidstr 0 4 = "'arg" then ( + try + let i = String.index kidstr '#' in + let n = String.sub kidstr 4 (i - 4) in + let arg = int_of_string n in + List.nth arg_deps arg + with Not_found | Failure _ -> unknown + ) + else unknown + ) + in + KidSet.fold check kids dempty + + let deps_of_nexp l kid_deps arg_deps nexp = + let kids = nexp_frees nexp in + deps_of_tyvars l kid_deps arg_deps kids + + let rec deps_of_nc kid_deps (NC_aux (nc, l)) = + match nc with + | NC_equal (nexp1, nexp2) + | NC_bounded_ge (nexp1, nexp2) + | NC_bounded_gt (nexp1, nexp2) + | NC_bounded_le (nexp1, nexp2) + | NC_bounded_lt (nexp1, nexp2) + | NC_not_equal (nexp1, nexp2) -> + dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) + | NC_set (kid, _) -> ( + match KBindings.find kid kid_deps with + | deps -> deps + | exception Not_found -> Unknown (l, "Unknown type variable in constraint " ^ string_of_kid kid) + ) + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> dmerge (deps_of_nc kid_deps nc1) (deps_of_nc kid_deps nc2) + | NC_true | NC_false -> dempty + | NC_app (Id_aux (Id "mod", _), [A_aux (A_nexp nexp1, _); A_aux (A_nexp nexp2, _)]) -> + dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) + | NC_var _ | NC_app _ -> dempty + + and deps_of_typ l kid_deps arg_deps typ = deps_of_tyvars l kid_deps arg_deps (tyvars_of_typ typ) + + and deps_of_typ_arg l fn_id env arg_deps (A_aux (aux, _)) = + match aux with + | A_nexp (Nexp_aux (Nexp_var kid, _)) when List.exists (fun k -> Kid.compare kid k == 0) env.top_kids -> + parents_call_dep (CallerKidSet.singleton (fn_id, kid)) + | A_nexp nexp -> in_fun_call_dep (deps_of_nexp l env.kid_deps arg_deps nexp) + | A_order _ -> in_fun_call_dep dempty + | A_typ typ -> in_fun_call_dep (deps_of_typ l env.kid_deps arg_deps typ) + | A_bool nc -> in_fun_call_dep (deps_of_nc env.kid_deps nc) + + let mk_subrange_pattern vannot vstart vend = + let len, ord, typ = vector_typ_args_of (Env.base_typ_of (env_of_annot vannot) (typ_of_annot vannot)) in + match ord with + | Ord_aux (Ord_var _, _) -> None + | Ord_aux (ord', _) -> ( + let vstart, vend = if ord' = Ord_inc then (vstart, vend) else (vend, vstart) in + let dummyl = Generated Unknown in + match len with + | Nexp_aux (Nexp_constant len, _) -> + Some + (fun pat -> + let end_len = Big_int.pred (Big_int.sub len vend) in + (* Wrap pat in its type; in particular the type checker won't + manage P_wild in the middle of a P_vector_concat *) + let pat = P_aux (P_typ (typ_of_pat pat, pat), (Generated (pat_loc pat), empty_tannot)) in + let pats = + if Big_int.greater end_len Big_int.zero then + [ + pat; + P_aux + ( P_typ (bitvector_typ (nconstant end_len) ord, P_aux (P_wild, (dummyl, empty_tannot))), + (dummyl, empty_tannot) + ); + ] + else [pat] + in + let pats = + if Big_int.greater vstart Big_int.zero then + P_aux + ( P_typ (bitvector_typ (nconstant vstart) ord, P_aux (P_wild, (dummyl, empty_tannot))), + (dummyl, empty_tannot) + ) + :: pats + else pats + in + let pats = if ord' = Ord_inc then pats else List.rev pats in + P_aux (P_vector_concat pats, (Generated (fst vannot), empty_tannot)) + ) + | _ -> None + ) + + (* If the expression matched on in a case expression is a function argument, + and has no other dependencies, we can try to use the pattern match directly + rather than doing a full case split. *) + let refine_dependency env (E_aux (e, (l, annot)) as exp) pexps = + let check_dep id ctx = + match Bindings.find id env.var_deps with + | Have (args, extras, lets) -> begin + match (ArgSplits.bindings args, ExtraSplits.is_empty extras, LetSplits.is_empty lets) with + | [((id', loc), Total)], true, true when Id.compare id id' == 0 -> ( + match + Util.map_all + (function Pat_aux (Pat_exp (pat, _), _) -> Some (ctx pat) | Pat_aux (Pat_when (_, _, _), _) -> None) + pexps + with + | Some pats -> + if l = Parse_ast.Unknown then ( + Reporting.print_err l "" ("No location for pattern match: " ^ string_of_exp exp); + None + ) + else + Some (Have (ArgSplits.singleton (id, loc) (Partial (pats, l)), ExtraSplits.empty, LetSplits.empty)) + | None -> None + ) + | _ -> None + end + | Unknown _ -> None + | exception Not_found -> None + in + match e with + | E_id id -> check_dep id (fun x -> x) + | E_app + ( fn_id, + [ + E_aux (E_id id, vannot); E_aux (E_lit (L_aux (L_num vstart, _)), _); E_aux (E_lit (L_aux (L_num vend, _)), _); + ] + ) + when is_id (env_of exp) (Id "vector_subrange") fn_id -> ( + match mk_subrange_pattern vannot vstart vend with Some mk_pat -> check_dep id mk_pat | None -> None + ) + (* TODO: Aborted attempt at considering bitvector concatenations when + refining dependencies. Needs corresponding support in constant + propagation to work. *) + (* | E_app (append, [vec1; vec2]) + when is_id (env_of exp) (Id "append") append -> + (* If the expression is a concatenation resulting in a small enough bitvector, + perform a (total) case split on the sub-vectors *) + let vec_len v = try Option.map Big_int.to_int (get_constant_vec_len (env_of exp) v) with _ -> None in + let pow2 n = Big_int.pow_int (Big_int.of_int 2) n in + let size_set len1 len2 = Big_int.mul (pow2 len1) (pow2 len2) in + begin match (vec_len (typ_of exp), vec_len (typ_of vec1), vec_len (typ_of vec2)) with + | (Some len, Some len1, Some len2) + when Big_int.less_equal (size_set len1 len2) (Big_int.of_int size_set_limit) -> + let recur = refine_dependency env in + (* Create pexps with dummy bodies (ignored by the recursive call) *) + let mk_pexps len = + let mk_pexp lit = + let (_, ord, _) = vector_typ_args_of (typ_of exp) in + let tannot = mk_tannot (env_of exp) (bitvector_typ (nint len) ord) no_effect in + let pat = P_aux (P_lit lit, (Generated l, tannot)) in + let exp = E_aux (E_lit (mk_lit L_unit), (Generated l, empty_tannot)) in + Pat_aux (Pat_exp (pat, exp), (Generated l, empty_tannot)) + in + List.map mk_pexp (make_vectors len) + in + begin match (recur vec1 (mk_pexps len1), recur vec2 (mk_pexps len2)) with + | (Some deps1, Some deps2) -> Some (dmerge deps1 deps2) + | _ -> None + end + | _ -> None + end *) + | _ -> None + + let simplify_size_nexp env typ_env (Nexp_aux (ne, l) as nexp) = + match solve_unique typ_env nexp with + | Some n -> nconstant n + | None -> ( + let is_equal kid = + try + if Env.get_typ_var kid typ_env = K_int then + prove __POS__ typ_env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid, Unknown), nexp), Unknown)) + else false + with _ -> false + in + match ne with + | Nexp_var _ | Nexp_constant _ -> nexp + | _ -> ( + match List.find is_equal env.top_kids with + | kid -> Nexp_aux (Nexp_var kid, Generated l) + | exception Not_found -> ( + match KBindings.find_first_opt is_equal (Env.get_typ_vars typ_env) with + | Some (kid, _) -> Nexp_aux (Nexp_var kid, Generated l) + | None -> nexp + ) + ) + ) + + let simplify_size_typ_arg env typ_env = function + | A_aux (A_nexp nexp, l) -> A_aux (A_nexp (simplify_size_nexp env typ_env nexp), l) | x -> x -end) - -(* Map from locations to string sets *) -module Failures = Map.Make (struct - type t = Parse_ast.l - let compare = compare -end) -module StringSet = Set.Make (struct - type t = string - let compare = compare -end) - -type dependencies = - | Have of arg_splits * extra_splits * let_binding_splits - | Unknown of Parse_ast.l * string - -let string_of_match_detail = function - | Total -> "[total]" - | Partial (pats,_) -> "[" ^ String.concat " | " (List.map string_of_pat pats) ^ "]" - -let string_of_argsplits s = - String.concat ", " - (List.map (fun ((id,l),detail) -> - string_of_id id ^ "." ^ simple_string_of_loc l ^ string_of_match_detail detail) - (ArgSplits.bindings s)) - -let string_of_extra_splits s = - String.concat ", " - (List.map (fun ((id,l),ks) -> - string_of_id id ^ "." ^ simple_string_of_loc l ^ ":" ^ - (String.concat "," (List.map (fun (kid,detail) -> - string_of_kid kid ^ "." ^ string_of_match_detail detail) - (KBindings.bindings ks)))) - (ExtraSplits.bindings s)) - -let string_of_let_binding_splits s = - String.concat ", " - (List.map (fun ((id,l),detail) -> - string_of_id id ^ "." ^ simple_string_of_loc l ^ "." ^ string_of_match_detail detail) - (LetSplits.bindings s)) - -let _string_of_callerset s = - String.concat ", " (List.map (fun (id,arg) -> string_of_id id ^ "." ^ string_of_int arg) - (CallerArgSet.elements s)) - -let string_of_callerkidset s = - String.concat ", " (List.map (fun (id,kid) -> string_of_id id ^ "." ^ string_of_kid kid) - (CallerKidSet.elements s)) - -let string_of_dep = function - | Have (args,extras,letbinds) -> - "Have (" ^ string_of_argsplits args ^ ";" ^ string_of_extra_splits extras ^ ";" ^ string_of_let_binding_splits letbinds ^ ")" - | Unknown (l,msg) -> "Unknown " ^ msg ^ " at " ^ Reporting.loc_to_string l - -(* If a callee uses a type variable as a size, does it need to be split in the - current function, or is it also a parameter? (Note that there may be multiple - calls, so more than one parameter can be involved) *) -type call_dep = { - in_fun : dependencies option; - parents : CallerKidSet.t; -} - -let in_fun_call_dep deps = { in_fun = Some deps; parents = CallerKidSet.empty } - -let parents_call_dep cks = { in_fun = None; parents = cks } - -(* Result of analysing the body of a function. The split field gives - the arguments to split based on the body alone, the extra_splits - field where we want to case split on a size type variable but - there's no corresponding argument so we introduce a case - expression, and the failures field where we couldn't do anything. - The other fields are used at the end for the interprocedural - phase. *) - -type result = { - split : arg_splits; - extra_splits : extra_splits; - let_binding_splits : let_binding_splits; - failures : StringSet.t Failures.t; - (* Dependencies for type variables of each fn called, so that - if the fn uses one for a bitvector size we can track it back *) - split_on_call : (call_dep KBindings.t) Bindings.t; (* kids per fn *) - kid_in_caller : CallerKidSet.t -} - -let empty = { - split = ArgSplits.empty; - extra_splits = ExtraSplits.empty; - let_binding_splits = LetSplits.empty; - failures = Failures.empty; - split_on_call = Bindings.empty; - kid_in_caller = CallerKidSet.empty -} - -let merge_detail _ x y = - match x,y with - | None, x -> x - | x, None -> x - | Some (Partial (ps1,l1)), Some (Partial (ps2,l2)) - when l1 = l2 && forall2 pat_eq ps1 ps2 -> x - | _ -> Some Total - -let opt_merge f _ x y = - match x,y with - | None, _ -> y - | _, None -> x - | Some x, Some y -> Some (f x y) - -let merge_extras = ExtraSplits.merge (opt_merge (KBindings.merge merge_detail)) - -let dmerge x y = - match x,y with - | Unknown (l,s), _ -> Unknown (l,s) - | _, Unknown (l,s) -> Unknown (l,s) - | Have (args,extras,lets), Have (args',extras',lets') -> - Have (ArgSplits.merge merge_detail args args', - merge_extras extras extras', - LetSplits.merge merge_detail lets lets') - -let dempty = Have (ArgSplits.empty, ExtraSplits.empty, LetSplits.empty) - -let dep_bindings_merge a1 a2 = - Bindings.merge (opt_merge dmerge) a1 a2 - -let dep_kbindings_merge a1 a2 = - KBindings.merge (opt_merge dmerge) a1 a2 - -let call_dep_merge k d1 d2 = { - in_fun = opt_merge dmerge k d1.in_fun d2.in_fun; - parents = CallerKidSet.union d1.parents d2.parents -} - -let call_kid_merge k x y = - match x, y with - | None, x -> x - | x, None -> x - | Some d1, Some d2 -> Some (call_dep_merge k d1 d2) - -let call_arg_merge k args args' = - match args, args' with - | None, x -> x - | x, None -> x - | Some kdep, Some kdep' - -> Some (KBindings.merge call_kid_merge kdep kdep') - -let failure_merge _ x y = - match x, y with - | None, x -> x - | x, None -> x - | Some x, Some y -> Some (StringSet.union x y) - -let merge rs rs' = { - split = ArgSplits.merge merge_detail rs.split rs'.split; - extra_splits = merge_extras rs.extra_splits rs'.extra_splits; - let_binding_splits = LetSplits.merge merge_detail rs.let_binding_splits rs'.let_binding_splits; - failures = Failures.merge failure_merge rs.failures rs'.failures; - split_on_call = Bindings.merge call_arg_merge rs.split_on_call rs'.split_on_call; - kid_in_caller = CallerKidSet.union rs.kid_in_caller rs'.kid_in_caller -} - -type env = { - top_kids : kid list; (* Int kids bound by the function type *) - var_deps : dependencies Bindings.t; - kid_deps : dependencies KBindings.t; - referenced_vars : IdSet.t; - globals : bool Bindings.t (* is_value or not *) -} - -let rec split3 = function - | [] -> [],[],[] - | ((h1,h2,h3)::t) -> - let t1,t2,t3 = split3 t in - (h1::t1,h2::t2,h3::t3) - -let is_kid_in_env env kid = - match Env.get_typ_var kid env with - | _ -> true - | exception _ -> false - -let rec kids_bound_by_typ_pat (TP_aux (tp,_)) = - match tp with - | TP_wild -> KidSet.empty - | TP_var kid -> KidSet.singleton kid - | TP_app (_,pats) -> - kidset_bigunion (List.map kids_bound_by_typ_pat pats) - -(* We need both the explicitly bound kids from the AST, and any freshly - generated kids from the typechecker. *) -let kids_bound_by_pat pat = - let open Rewriter in - fst (fold_pat ({ (compute_pat_alg KidSet.empty KidSet.union) - with p_aux = - (function ((s,(P_var (P_aux (_, annot'),tpat) as p)), annot) when not (is_empty_tannot (snd annot')) -> - let kids = tyvars_of_typ (typ_of_annot annot') in - let new_kids = KidSet.filter (fun kid -> not (is_kid_in_env (env_of_annot annot) kid)) kids in - let tpat_kids = kids_bound_by_typ_pat tpat in - KidSet.union s (KidSet.union new_kids tpat_kids), P_aux (p, annot) - | ((s,p),ann) -> s, P_aux (p,ann)) - }) pat) - -(* Diff the type environment to find new type variables and record that they - depend on deps *) - -let update_env_new_kids env deps typ_env_pre typ_env_post = - let kbound = - KBindings.merge (fun k x y -> - match x,y with - | Some k, None -> Some k - | _ -> None) - (Env.get_typ_vars typ_env_post) - (Env.get_typ_vars typ_env_pre) - in - let kid_deps = KBindings.fold (fun v _ ds -> KBindings.add v deps ds) kbound env.kid_deps in - { env with kid_deps = kid_deps } - -(* Add bound variables from a pattern to the environment with the given dependency, - plus any new type variables. *) - -let update_env env deps pat typ_env_pre typ_env_post = - let bound = Spec_analysis.bindings_from_pat pat in - let var_deps = List.fold_left (fun ds v -> Bindings.add v deps ds) env.var_deps bound in - update_env_new_kids { env with var_deps = var_deps } deps typ_env_pre typ_env_post - -(* A function argument may end up with fresh type variables due to coercing - unification (which will eventually be existentially bound in the type of - the function). Here we record the dependencies for these variables. *) - -let add_arg_only_kids env typ_env typ deps = - let all_vars = tyvars_of_typ typ in - let check_kid kid kid_deps = - if KBindings.mem kid kid_deps then kid_deps - else KBindings.add kid deps kid_deps - in - let kid_deps = KidSet.fold check_kid all_vars env.kid_deps in - { env with kid_deps } - -let assigned_vars_exps es = - List.fold_left (fun vs exp -> IdSet.union vs (Spec_analysis.assigned_vars exp)) - IdSet.empty es - -(* For adding control dependencies to mutable variables *) - -let add_dep_to_assigned dep assigns es = - let assigned = assigned_vars_exps es in - Bindings.mapi (fun id d -> if IdSet.mem id assigned then dmerge dep d else d) assigns - -(* Functions to give dependencies for type variables in nexps, constraints, types and - unification variables. For function calls we also supply a list of dependencies for - arguments so that we can find dependencies for existentially bound sizes. *) - -let deps_of_tyvars l kid_deps arg_deps kids = - let check kid deps = - match KBindings.find kid kid_deps with - | deps' -> dmerge deps deps' - | exception Not_found -> - match kid with - | Kid_aux (Var kidstr, _) -> - let unknown = Unknown (l, "Unknown type variable " ^ string_of_kid kid) in - (* Tyvars from existentials in arguments have a special format *) - if String.length kidstr > 5 && String.sub kidstr 0 4 = "'arg" then - try - let i = String.index kidstr '#' in - let n = String.sub kidstr 4 (i-4) in - let arg = int_of_string n in - List.nth arg_deps arg - with Not_found | Failure _ -> unknown - else unknown - in - KidSet.fold check kids dempty - -let deps_of_nexp l kid_deps arg_deps nexp = - let kids = nexp_frees nexp in - deps_of_tyvars l kid_deps arg_deps kids - -let rec deps_of_nc kid_deps (NC_aux (nc,l)) = - match nc with - | NC_equal (nexp1,nexp2) - | NC_bounded_ge (nexp1,nexp2) - | NC_bounded_gt (nexp1,nexp2) - | NC_bounded_le (nexp1,nexp2) - | NC_bounded_lt (nexp1,nexp2) - | NC_not_equal (nexp1,nexp2) - -> dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) - | NC_set (kid,_) -> - (match KBindings.find kid kid_deps with - | deps -> deps - | exception Not_found -> Unknown (l, "Unknown type variable in constraint " ^ string_of_kid kid)) - | NC_or (nc1,nc2) - | NC_and (nc1,nc2) - -> dmerge (deps_of_nc kid_deps nc1) (deps_of_nc kid_deps nc2) - | NC_true - | NC_false - -> dempty - | NC_app (Id_aux (Id "mod", _), [A_aux (A_nexp nexp1, _); A_aux (A_nexp nexp2, _)]) - -> dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) - | NC_var _ | NC_app _ - -> dempty - -and deps_of_typ l kid_deps arg_deps typ = - deps_of_tyvars l kid_deps arg_deps (tyvars_of_typ typ) - -and deps_of_typ_arg l fn_id env arg_deps (A_aux (aux, _)) = - match aux with - | A_nexp (Nexp_aux (Nexp_var kid,_)) - when List.exists (fun k -> Kid.compare kid k == 0) env.top_kids -> - parents_call_dep (CallerKidSet.singleton (fn_id,kid)) - | A_nexp nexp -> in_fun_call_dep (deps_of_nexp l env.kid_deps arg_deps nexp) - | A_order _ -> in_fun_call_dep dempty - | A_typ typ -> in_fun_call_dep (deps_of_typ l env.kid_deps arg_deps typ) - | A_bool nc -> in_fun_call_dep (deps_of_nc env.kid_deps nc) - -let mk_subrange_pattern vannot vstart vend = - let (len,ord,typ) = vector_typ_args_of (Env.base_typ_of (env_of_annot vannot) (typ_of_annot vannot)) in - match ord with - | Ord_aux (Ord_var _,_) -> None - | Ord_aux (ord',_) -> - let vstart,vend = if ord' = Ord_inc then vstart,vend else vend,vstart - in - let dummyl = Generated Unknown in - match len with - | Nexp_aux (Nexp_constant len,_) -> - Some (fun pat -> - let end_len = Big_int.pred (Big_int.sub len vend) in - (* Wrap pat in its type; in particular the type checker won't - manage P_wild in the middle of a P_vector_concat *) - let pat = P_aux (P_typ (typ_of_pat pat, pat),(Generated (pat_loc pat),empty_tannot)) in - let pats = if Big_int.greater end_len Big_int.zero then - [pat;P_aux (P_typ (bitvector_typ (nconstant end_len) ord, - P_aux (P_wild,(dummyl,empty_tannot))),(dummyl,empty_tannot))] - else [pat] + + (* Takes an environment of dependencies on vars, type vars, and flow control, + and dependencies on mutable variables. The latter are quite conservative, + we currently drop variables assigned inside loops, for example. *) + + let rec analyse_exp fn_id effect_info env assigns (E_aux (e, (l, annot)) as exp) = + let analyse_sub = analyse_exp fn_id effect_info in + let analyse_lexp = analyse_lexp fn_id effect_info in + let remove_assigns es message = + let assigned = assigned_vars_exps es in + IdSet.fold (fun id asn -> Bindings.add id (Unknown (l, string_of_id id ^ message)) asn) assigned assigns + in + let non_det es = + let assigns = remove_assigns es " assigned in non-deterministic expressions" in + let deps, _, rs = split3 (List.map (analyse_sub env assigns) es) in + (deps, assigns, List.fold_left merge empty rs) + in + (* We allow for arguments to functions being executed non-deterministically, but + follow the type checker in processing them in-order to detect the automatic + unpacking of existentials. When we spot a new type variable (using + update_env_new_kids) we set them to depend on the previous argument. *) + let non_det_args es typs = + let assigns = remove_assigns es " assigned in non-deterministic expressions" in + let rec aux env = function + | [], _ -> ([], empty, env) + | (E_aux (_, ann) as h) :: t, typ :: typs -> + let typ_env = env_of h in + let new_deps, _, new_r = analyse_sub env assigns h in + let env = add_arg_only_kids env typ_env typ new_deps in + let t_deps, t_r, t_env = aux env (t, typs) in + (new_deps :: t_deps, merge new_r t_r, t_env) + | _ :: _, [] -> Reporting.unreachable l __POS__ "Argument and type list in non_det_args had different lengths" + in + let deps, r, env = aux env (es, typs) in + (deps, assigns, r, env) + in + let is_toplevel_int tannot = + match destruct_atom_nexp (env_of_annot tannot) (typ_of_annot tannot) with + | Some (Nexp_aux (Nexp_var kid, _)) -> List.exists (fun k -> Kid.compare k kid == 0) env.top_kids + | _ -> false + in + let merge_deps deps = List.fold_left dmerge dempty deps in + let deps, assigns, r = + match e with + | E_block es -> + let rec aux env assigns = function + | [] -> (dempty, assigns, empty) + | [e] -> analyse_sub env assigns e + (* There's also a lone assignment case below where no env update is needed *) + | E_aux (E_assign (lexp, e1), ann) :: e2 :: es -> + let d1, assigns, r1 = analyse_sub env assigns e1 in + let assigns, r2 = analyse_lexp env assigns d1 lexp in + let env = update_env_new_kids env d1 (env_of_annot ann) (env_of e2) in + let d3, assigns, r3 = aux env assigns (e2 :: es) in + (d3, assigns, merge (merge r1 r2) r3) + | e :: es -> + let _, assigns, r' = analyse_sub env assigns e in + let d, assigns, r = aux env assigns es in + (d, assigns, merge r r') in - let pats = if Big_int.greater vstart Big_int.zero then - (P_aux (P_typ (bitvector_typ (nconstant vstart) ord, - P_aux (P_wild,(dummyl,empty_tannot))),(dummyl,empty_tannot)))::pats - else pats + aux env assigns es + | E_id id -> begin + match Bindings.find id env.var_deps with + | args -> (args, assigns, empty) + | exception Not_found -> ( + match Bindings.find id assigns with + | args -> (args, assigns, empty) + | exception Not_found -> ( + match Env.lookup_id id (Type_check.env_of_annot (l, annot)) with + | Enum _ -> (dempty, assigns, empty) + | Register _ -> (Unknown (l, string_of_id id ^ " is a register"), assigns, empty) + | _ -> + if IdSet.mem id env.referenced_vars then + (Unknown (l, string_of_id id ^ " may be modified via a reference"), assigns, empty) + else ( + match Bindings.find id env.globals with + | true -> (dempty, assigns, empty (* value *)) + | false -> (Unknown (l, string_of_id id ^ " is a global but not a value"), assigns, empty) + | exception Not_found -> + (Unknown (l, string_of_id id ^ " is not in the environment"), assigns, empty) + ) + ) + ) + end + | E_lit _ -> (dempty, assigns, empty) + | E_typ (_, e) -> analyse_sub env assigns e + | E_app (id, args) -> + let typ_env = env_of_annot (l, annot) in + let _, fn_typ = Env.get_val_spec_orig id typ_env in + let kid_inst = instantiation_of exp in + let kid_inst = KBindings.fold (fun kid -> KBindings.add (orig_kid kid)) kid_inst KBindings.empty in + let fn_typ = subst_unifiers kid_inst fn_typ in + let arg_typs = match fn_typ with Typ_aux (Typ_fn (args, _), _) -> args | _ -> [] in + (* We have to use the types from the val_spec here so that we can track + any type variables that are generated by the coercing unification that + the type checker applies after inferring the type of an argument, and + that only appear in the unifiers. *) + let deps, assigns, r, env = non_det_args args arg_typs in + let eff_dep = + (* For a pure function we can monomorphise the result by monomorphising + the arguments - but that's not guaranteed for an effectful function, + which may (e.g.) depend upn a register. *) + if Effects.function_is_pure id effect_info then dempty else Unknown (l, "Effects from function application") in - let pats = if ord' = Ord_inc then pats else List.rev pats + let kid_inst = KBindings.map (simplify_size_typ_arg env typ_env) kid_inst in + (* Change kids in instantiation to the canonical ones from the type signature *) + let kid_deps = KBindings.map (deps_of_typ_arg l fn_id env deps) kid_inst in + let rdep, r' = + if Id.compare fn_id id == 0 then ( + let bad = Unknown (l, "Recursive call of " ^ string_of_id id) in + let kid_deps = KBindings.map (fun _ -> in_fun_call_dep bad) kid_deps in + (bad, { empty with split_on_call = Bindings.singleton id kid_deps }) + ) + else (dempty, { empty with split_on_call = Bindings.singleton id kid_deps }) in - P_aux (P_vector_concat pats,(Generated (fst vannot),empty_tannot))) - | _ -> None - -(* If the expression matched on in a case expression is a function argument, - and has no other dependencies, we can try to use the pattern match directly - rather than doing a full case split. *) -let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps = - let check_dep id ctx = - match Bindings.find id env.var_deps with - | Have (args,extras,lets) -> begin - match ArgSplits.bindings args, ExtraSplits.is_empty extras, LetSplits.is_empty lets with - | [(id',loc),Total], true, true when Id.compare id id' == 0 -> - (match Util.map_all (function - | Pat_aux (Pat_exp (pat,_),_) -> Some (ctx pat) - | Pat_aux (Pat_when (_,_,_),_) -> None) pexps - with - | Some pats -> - if l = Parse_ast.Unknown then - (Reporting.print_err l "" ("No location for pattern match: " ^ string_of_exp exp); - None) - else - Some (Have (ArgSplits.singleton (id,loc) (Partial (pats,l)), - ExtraSplits.empty, - LetSplits.empty)) - | None -> None) - | _ -> None - end - | Unknown _ -> None - | exception Not_found -> None - in - match e with - | E_id id -> check_dep id (fun x -> x) - | E_app (fn_id, [E_aux (E_id id,vannot); - E_aux (E_lit (L_aux (L_num vstart,_)),_); - E_aux (E_lit (L_aux (L_num vend,_)),_)]) - when is_id (env_of exp) (Id "vector_subrange") fn_id -> - (match mk_subrange_pattern vannot vstart vend with - | Some mk_pat -> check_dep id mk_pat - | None -> None) - (* TODO: Aborted attempt at considering bitvector concatenations when - refining dependencies. Needs corresponding support in constant - propagation to work. *) - (* | E_app (append, [vec1; vec2]) - when is_id (env_of exp) (Id "append") append -> - (* If the expression is a concatenation resulting in a small enough bitvector, - perform a (total) case split on the sub-vectors *) - let vec_len v = try Option.map Big_int.to_int (get_constant_vec_len (env_of exp) v) with _ -> None in - let pow2 n = Big_int.pow_int (Big_int.of_int 2) n in - let size_set len1 len2 = Big_int.mul (pow2 len1) (pow2 len2) in - begin match (vec_len (typ_of exp), vec_len (typ_of vec1), vec_len (typ_of vec2)) with - | (Some len, Some len1, Some len2) - when Big_int.less_equal (size_set len1 len2) (Big_int.of_int size_set_limit) -> - let recur = refine_dependency env in - (* Create pexps with dummy bodies (ignored by the recursive call) *) - let mk_pexps len = - let mk_pexp lit = - let (_, ord, _) = vector_typ_args_of (typ_of exp) in - let tannot = mk_tannot (env_of exp) (bitvector_typ (nint len) ord) no_effect in - let pat = P_aux (P_lit lit, (Generated l, tannot)) in - let exp = E_aux (E_lit (mk_lit L_unit), (Generated l, empty_tannot)) in - Pat_aux (Pat_exp (pat, exp), (Generated l, empty_tannot)) - in - List.map mk_pexp (make_vectors len) + (merge_deps (rdep :: eff_dep :: deps), assigns, merge r r') + | E_tuple es | E_list es -> + let deps, assigns, r = non_det es in + (merge_deps deps, assigns, r) + | E_if (e1, e2, e3) -> + let d1, assigns, r1 = analyse_sub env assigns e1 in + let d2, a2, r2 = analyse_sub env assigns e2 in + let d3, a3, r3 = analyse_sub env assigns e3 in + let assigns = add_dep_to_assigned d1 (dep_bindings_merge a2 a3) [e2; e3] in + (dmerge d1 (dmerge d2 d3), assigns, merge r1 (merge r2 r3)) + | E_loop (_, _, e1, e2) -> + (* We remove all of the variables assigned in the loop, so we don't + need to add control dependencies *) + let assigns = remove_assigns [e1; e2] " assigned in a loop" in + let d1, a1, r1 = analyse_sub env assigns e1 in + let d2, a2, r2 = analyse_sub env assigns e2 in + (dempty, assigns, merge r1 r2) + | E_for (var, efrom, eto, eby, ord, body) -> + let d1, assigns, r1 = non_det [efrom; eto; eby] in + let assigns = remove_assigns [body] " assigned in a loop" in + let d = merge_deps d1 in + let loop_kid = mk_kid ("loop_" ^ string_of_id var) in + let env' = { env with kid_deps = KBindings.add loop_kid d env.kid_deps } in + let d2, a2, r2 = analyse_sub env' assigns body in + (dempty, assigns, merge r1 r2) + | E_vector es -> + let ds, assigns, r = non_det es in + (merge_deps ds, assigns, r) + | E_vector_access (e1, e2) | E_vector_append (e1, e2) | E_cons (e1, e2) -> + let ds, assigns, r = non_det [e1; e2] in + (merge_deps ds, assigns, r) + | E_vector_subrange (e1, e2, e3) | E_vector_update (e1, e2, e3) -> + let ds, assigns, r = non_det [e1; e2; e3] in + (merge_deps ds, assigns, r) + | E_vector_update_subrange (e1, e2, e3, e4) -> + let ds, assigns, r = non_det [e1; e2; e3; e4] in + (merge_deps ds, assigns, r) + | E_struct fexps -> + let es = List.map (function FE_aux (FE_fexp (_, e), _) -> e) fexps in + let ds, assigns, r = non_det es in + (merge_deps ds, assigns, r) + | E_struct_update (e, fexps) -> + let es = List.map (function FE_aux (FE_fexp (_, e), _) -> e) fexps in + let ds, assigns, r = non_det (e :: es) in + (merge_deps ds, assigns, r) + | E_field (e, _) -> analyse_sub env assigns e + | E_match (e, cases) -> + let deps, assigns, r = analyse_sub env assigns e in + let deps = match refine_dependency env e cases with Some deps -> deps | None -> deps in + let analyse_case (Pat_aux (pexp, _)) = + match pexp with + | Pat_exp (pat, e1) -> + let env = update_env env deps pat (env_of_annot (l, annot)) (env_of e1) in + let d, assigns, r = analyse_sub env assigns e1 in + let assigns = add_dep_to_assigned deps assigns [e1] in + (d, assigns, r) + | Pat_when (pat, e1, e2) -> + let env = update_env env deps pat (env_of_annot (l, annot)) (env_of e2) in + let d1, assigns, r1 = analyse_sub env assigns e1 in + let d2, assigns, r2 = analyse_sub env assigns e2 in + let assigns = add_dep_to_assigned deps assigns [e1; e2] in + (dmerge d1 d2, assigns, merge r1 r2) in - begin match (recur vec1 (mk_pexps len1), recur vec2 (mk_pexps len2)) with - | (Some deps1, Some deps2) -> Some (dmerge deps1 deps2) - | _ -> None - end - | _ -> None - end *) - | _ -> None + let ds, assigns, rs = split3 (List.map analyse_case cases) in + (merge_deps (deps :: ds), List.fold_left dep_bindings_merge Bindings.empty assigns, List.fold_left merge r rs) + | E_let (LB_aux (LB_val (pat, e1), (lb_l, _)), e2) -> + let d1, assigns, r1 = analyse_sub env assigns e1 in + let unknown_deps = match d1 with Unknown _ -> true | Have _ -> false in + let d = + (* As a special case, detect + let 'size = if ... then 'typaram1 else 'typaram2; + where we can reduce the dependencies of 'size to the guard. *) + match (pat, e1) with + | ( P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _), + E_aux (E_if (guard_exp, E_aux (E_id id1, annot1), E_aux (E_id id2, annot2)), _) ) + when is_toplevel_int annot1 && is_toplevel_int annot2 -> + let guard_deps, _, _ = analyse_sub env assigns guard_exp in + guard_deps + (* Add a new case split after the let if necessary *) + (* Potential improvements: match on more patterns (e.g. tuples); + allow disjunctions of equalities as well as set constraints; + allow set constraint to be part of a larger constraint. *) + | P_aux ((P_id id | P_var (P_aux (P_id id, _), _)), _), _ when unknown_deps && useful_loc lb_l -> + let l' = Generated l in + let split = + match typ_of e1 with + | Typ_aux (Typ_exist ([kdid], NC_aux (NC_set (kid, sizes), _), typ), _) + when Kid.compare (kopt_kid kdid) kid == 0 -> begin + match Type_check.destruct_atom_nexp (env_of e1) typ with + | Some nexp when Nexp.compare (nvar kid) nexp == 0 -> + let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n, l')), (l', annot))) sizes in + Partial (pats, l) + | _ -> Total + end + | Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _)]), _) -> + let typ_env = env_of_annot (l, annot) in + let constraints = Type_check.Env.get_constraints typ_env in + let vars = Spec_analysis.equal_kids_ncs kid' constraints in + begin + match + Util.find_map + (function NC_aux (NC_set (kid'', is), _) when KidSet.mem kid'' vars -> Some is | _ -> None) + constraints + with + | Some sizes -> + let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n, l')), (l', annot))) sizes in + Partial (pats, l) + | None -> Total + end + | _ -> Total + in + Have (ArgSplits.empty, ExtraSplits.empty, LetSplits.singleton (id, lb_l) split) + | _, _ -> d1 + in + let env = update_env env d pat (env_of_annot (l, annot)) (env_of e2) in + let d2, assigns, r2 = analyse_sub env assigns e2 in + (d2, assigns, merge r1 r2) + (* There's a more general assignment case above to update env inside a block. *) + | E_assign (lexp, e1) -> + let d1, assigns, r1 = analyse_sub env assigns e1 in + let assigns, r2 = analyse_lexp env assigns d1 lexp in + (dempty, assigns, merge r1 r2) + | E_sizeof nexp -> (deps_of_nexp l env.kid_deps [] nexp, assigns, empty) + | E_return e | E_exit e | E_throw e -> + let _, _, r = analyse_sub env assigns e in + (dempty, Bindings.empty, r) + | E_ref id -> (Unknown (l, "May be mutated via reference to " ^ string_of_id id), assigns, empty) + | E_try (e, cases) -> + let deps, _, r = analyse_sub env assigns e in + let assigns = remove_assigns [e] " assigned in try expression" in + let analyse_handler (Pat_aux (pexp, _)) = + match pexp with + | Pat_exp (pat, e1) -> + let env = update_env env (Unknown (l, "Exception")) pat (env_of_annot (l, annot)) (env_of e1) in + let d, assigns, r = analyse_sub env assigns e1 in + let assigns = add_dep_to_assigned deps assigns [e1] in + (d, assigns, r) + | Pat_when (pat, e1, e2) -> + let env = update_env env (Unknown (l, "Exception")) pat (env_of_annot (l, annot)) (env_of e2) in + let d1, assigns, r1 = analyse_sub env assigns e1 in + let d2, assigns, r2 = analyse_sub env assigns e2 in + let assigns = add_dep_to_assigned deps assigns [e1; e2] in + (dmerge d1 d2, assigns, merge r1 r2) + in + let ds, assigns, rs = split3 (List.map analyse_handler cases) in + (merge_deps (deps :: ds), List.fold_left dep_bindings_merge Bindings.empty assigns, List.fold_left merge r rs) + | E_assert (e1, _) -> analyse_sub env assigns e1 + | E_internal_assume (nc, e1) -> analyse_sub env assigns e1 + | E_app_infix _ | E_internal_plet _ | E_internal_return _ | E_internal_value _ -> + raise + (Reporting.err_unreachable l __POS__ + ("Unexpected expression encountered in monomorphisation: " ^ string_of_exp exp) + ) + | E_var (lexp, e1, e2) -> + (* Really we ought to remove the assignment after e2 *) + let d1, assigns, r1 = analyse_sub env assigns e1 in + let assigns, r' = analyse_lexp env assigns d1 lexp in + let d2, assigns, r2 = analyse_sub env assigns e2 in + (dempty, assigns, merge r1 (merge r' r2)) + | E_constraint nc -> (deps_of_nc env.kid_deps nc, assigns, empty) + in + let deps = + match destruct_atom_bool (env_of exp) (typ_of exp) with + | Some nc -> dmerge deps (deps_of_nc env.kid_deps nc) + | None -> deps + in + let r = + (* Check for bitvector types with parametrised sizes *) + match destruct_tannot annot with + | None -> r + | Some (tenv, typ) -> + let typ = Env.base_typ_of tenv typ in + let env, tenv, typ = + match destruct_exist (Env.expand_synonyms tenv typ) with + | None -> (env, tenv, typ) + | Some (kopts, nc, typ) -> + ( { + env with + kid_deps = + List.fold_left (fun kds kopt -> KBindings.add (kopt_kid kopt) deps kds) env.kid_deps kopts; + }, + Env.add_constraint nc (List.fold_left (fun tenv kopt -> Env.add_typ_var l kopt tenv) tenv kopts), + typ + ) + in + let rec check_typ typ = + if is_bitvector_typ typ then ( + let size, _, _ = vector_typ_args_of typ in + let (Nexp_aux (size, _) as size_nexp) = simplify_size_nexp env tenv size in + let is_tyvar_parameter v = List.exists (fun k -> Kid.compare k v == 0) env.top_kids in + match size with + | Nexp_constant _ -> r + | Nexp_var v when is_tyvar_parameter v -> + { r with kid_in_caller = CallerKidSet.add (fn_id, v) r.kid_in_caller } + | _ -> ( + match deps_of_nexp l env.kid_deps [] size_nexp with + | Have (args, extras, lets) -> + { + r with + split = ArgSplits.merge merge_detail r.split args; + extra_splits = merge_extras r.extra_splits extras; + let_binding_splits = LetSplits.merge merge_detail r.let_binding_splits lets; + } + | Unknown (l, msg) -> + { + r with + failures = + Failures.add l + (StringSet.singleton ("Unable to monomorphise " ^ string_of_nexp size_nexp ^ ": " ^ msg)) + r.failures; + } + ) + ) + else ( + match typ with + | Typ_aux (Typ_tuple typs, _) -> List.fold_left (fun r ty -> merge r (check_typ ty)) r typs + | _ -> r + ) + in + check_typ typ + in + (deps, assigns, r) -let simplify_size_nexp env typ_env (Nexp_aux (ne,l) as nexp) = - match solve_unique typ_env nexp with - | Some n -> nconstant n - | None -> - let is_equal kid = - try - if Env.get_typ_var kid typ_env = K_int then - prove __POS__ typ_env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) - else false - with _ -> false - in - match ne with - | Nexp_var _ - | Nexp_constant _ -> nexp - | _ -> - match List.find is_equal env.top_kids with - | kid -> Nexp_aux (Nexp_var kid, Generated l) - | exception Not_found -> - match KBindings.find_first_opt is_equal (Env.get_typ_vars typ_env) with - | Some (kid,_) -> Nexp_aux (Nexp_var kid, Generated l) - | None -> nexp - -let simplify_size_typ_arg env typ_env = function - | A_aux (A_nexp nexp, l) -> A_aux (A_nexp (simplify_size_nexp env typ_env nexp), l) - | x -> x - -(* Takes an environment of dependencies on vars, type vars, and flow control, - and dependencies on mutable variables. The latter are quite conservative, - we currently drop variables assigned inside loops, for example. *) - -let rec analyse_exp fn_id effect_info env assigns (E_aux (e,(l,annot)) as exp) = - let analyse_sub = analyse_exp fn_id effect_info in - let analyse_lexp = analyse_lexp fn_id effect_info in - let remove_assigns es message = - let assigned = assigned_vars_exps es in - IdSet.fold - (fun id asn -> - Bindings.add id (Unknown (l, string_of_id id ^ message)) asn) - assigned assigns - in - let non_det es = - let assigns = remove_assigns es " assigned in non-deterministic expressions" in - let deps, _, rs = split3 (List.map (analyse_sub env assigns) es) in - (deps, assigns, List.fold_left merge empty rs) - in - (* We allow for arguments to functions being executed non-deterministically, but - follow the type checker in processing them in-order to detect the automatic - unpacking of existentials. When we spot a new type variable (using - update_env_new_kids) we set them to depend on the previous argument. *) - let non_det_args es typs = - let assigns = remove_assigns es " assigned in non-deterministic expressions" in - let rec aux env = function - | [], _ -> [], empty, env - | (E_aux (_,ann) as h)::t, typ::typs -> - let typ_env = env_of h in - let new_deps, _, new_r = analyse_sub env assigns h in - let env = add_arg_only_kids env typ_env typ new_deps in - let t_deps, t_r, t_env = aux env (t,typs) in - new_deps::t_deps, merge new_r t_r, t_env - | _ :: _, [] -> - Reporting.unreachable l __POS__ "Argument and type list in non_det_args had different lengths" + and analyse_lexp fn_id effect_info env assigns deps (LE_aux (lexp, (l, _))) = + let analyse_sub = analyse_exp fn_id effect_info in + let analyse_lexp = analyse_lexp fn_id effect_info in + (* TODO: maybe subexps and sublexps should be non-det (and in const_prop_lexp, too?) *) + match lexp with + | LE_id id | LE_typ (_, id) -> + if IdSet.mem id env.referenced_vars then (assigns, empty) else (Bindings.add id deps assigns, empty) + | LE_app (id, es) -> + let _, assigns, r = analyse_sub env assigns (E_aux (E_tuple es, (Unknown, empty_tannot))) in + (assigns, r) + | LE_tuple lexps | LE_vector_concat lexps -> + List.fold_left + (fun (assigns, r) lexp -> + let assigns, r' = analyse_lexp env assigns deps lexp in + (assigns, merge r r') + ) + (assigns, empty) lexps + | LE_vector (lexp, e) -> + let _, assigns, r1 = analyse_sub env assigns e in + let assigns, r2 = analyse_lexp env assigns deps lexp in + (assigns, merge r1 r2) + | LE_vector_range (lexp, e1, e2) -> + let _, assigns, r1 = analyse_sub env assigns e1 in + let _, assigns, r2 = analyse_sub env assigns e2 in + let assigns, r3 = analyse_lexp env assigns deps lexp in + (assigns, merge r3 (merge r1 r2)) + | LE_field (lexp, _) -> analyse_lexp env assigns deps lexp + | LE_deref e -> + let _, assigns, r = analyse_sub env assigns e in + (assigns, r) + + let initial_env fn_id fn_l (TypQ_aux (tq, _)) pat body set_assertions globals = + (* The splitter always uses the outermost location *) + let top_pat_loc = pat_loc pat in + + let pats = match pat with P_aux (P_tuple pats, _) -> pats | _ -> [pat] in + (* For the type in an annotation, produce the corresponding tyvar (if any), + and a default case split (a set if there's one, a full case split if not). *) + let kids_of_annot annot = + let env = env_of_annot annot in + let (Typ_aux (typ, _)) = Env.base_typ_of env (typ_of_annot annot) in + match typ with + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)]) -> + Spec_analysis.equal_kids env kid + | _ -> KidSet.empty in - let deps, r, env = aux env (es,typs) in - (deps, assigns, r, env) - in - let is_toplevel_int tannot = - match destruct_atom_nexp (env_of_annot tannot) (typ_of_annot tannot) with - | Some (Nexp_aux (Nexp_var kid, _)) -> List.exists (fun k -> Kid.compare k kid == 0) env.top_kids - | _ -> false - in - let merge_deps deps = List.fold_left dmerge dempty deps in - let deps, assigns, r = - match e with - | E_block es -> - let rec aux env assigns = function - | [] -> (dempty, assigns, empty) - | [e] -> analyse_sub env assigns e - (* There's also a lone assignment case below where no env update is needed *) - | E_aux (E_assign (lexp,e1),ann)::e2::es -> - let d1,assigns,r1 = analyse_sub env assigns e1 in - let assigns,r2 = analyse_lexp env assigns d1 lexp in - let env = update_env_new_kids env d1 (env_of_annot ann) (env_of e2) in - let d3, assigns, r3 = aux env assigns (e2::es) in - (d3, assigns, merge (merge r1 r2) r3) - | e::es -> - let _, assigns, r' = analyse_sub env assigns e in - let d, assigns, r = aux env assigns es in - d, assigns, merge r r' - in - aux env assigns es - | E_id id -> - begin - match Bindings.find id env.var_deps with - | args -> (args,assigns,empty) - | exception Not_found -> - match Bindings.find id assigns with - | args -> (args,assigns,empty) - | exception Not_found -> - match Env.lookup_id id (Type_check.env_of_annot (l,annot)) with - | Enum _ -> dempty,assigns,empty - | Register _ -> Unknown (l, string_of_id id ^ " is a register"),assigns,empty - | _ -> - if IdSet.mem id env.referenced_vars then - Unknown (l, string_of_id id ^ " may be modified via a reference"),assigns,empty - else match Bindings.find id env.globals with - | true -> dempty,assigns,empty (* value *) - | false -> Unknown (l, string_of_id id ^ " is a global but not a value"),assigns,empty - | exception Not_found -> - Unknown (l, string_of_id id ^ " is not in the environment"),assigns,empty - end - | E_lit _ -> (dempty,assigns,empty) - | E_typ (_,e) -> analyse_sub env assigns e - | E_app (id,args) -> - let typ_env = env_of_annot (l,annot) in - let (_,fn_typ) = Env.get_val_spec_orig id typ_env in - let kid_inst = instantiation_of exp in - let kid_inst = KBindings.fold (fun kid -> KBindings.add (orig_kid kid)) kid_inst KBindings.empty in - let fn_typ = subst_unifiers kid_inst fn_typ in - let arg_typs = match fn_typ with - | Typ_aux (Typ_fn (args,_),_) -> args - | _ -> [] - in - (* We have to use the types from the val_spec here so that we can track - any type variables that are generated by the coercing unification that - the type checker applies after inferring the type of an argument, and - that only appear in the unifiers. *) - let deps, assigns, r, env = non_det_args args arg_typs in - let eff_dep = - (* For a pure function we can monomorphise the result by monomorphising - the arguments - but that's not guaranteed for an effectful function, - which may (e.g.) depend upn a register. *) - if Effects.function_is_pure id effect_info - then dempty - else Unknown (l, "Effects from function application") - in - let kid_inst = KBindings.map (simplify_size_typ_arg env typ_env) kid_inst in - (* Change kids in instantiation to the canonical ones from the type signature *) - let kid_deps = KBindings.map (deps_of_typ_arg l fn_id env deps) kid_inst in - let rdep,r' = - if Id.compare fn_id id == 0 then - let bad = Unknown (l,"Recursive call of " ^ string_of_id id) in - let kid_deps = KBindings.map (fun _ -> in_fun_call_dep bad) kid_deps in - bad, { empty with split_on_call = Bindings.singleton id kid_deps } - else - dempty, { empty with split_on_call = Bindings.singleton id kid_deps } in - (merge_deps (rdep::eff_dep::deps), assigns, merge r r') - | E_tuple es - | E_list es -> - let deps, assigns, r = non_det es in - (merge_deps deps, assigns, r) - | E_if (e1,e2,e3) -> - let d1,assigns,r1 = analyse_sub env assigns e1 in - let d2,a2,r2 = analyse_sub env assigns e2 in - let d3,a3,r3 = analyse_sub env assigns e3 in - let assigns = add_dep_to_assigned d1 (dep_bindings_merge a2 a3) [e2;e3] in - (dmerge d1 (dmerge d2 d3), assigns, merge r1 (merge r2 r3)) - | E_loop (_,_,e1,e2) -> - (* We remove all of the variables assigned in the loop, so we don't - need to add control dependencies *) - let assigns = remove_assigns [e1;e2] " assigned in a loop" in - let d1,a1,r1 = analyse_sub env assigns e1 in - let d2,a2,r2 = analyse_sub env assigns e2 in - (dempty, assigns, merge r1 r2) - | E_for (var,efrom,eto,eby,ord,body) -> - let d1,assigns,r1 = non_det [efrom;eto;eby] in - let assigns = remove_assigns [body] " assigned in a loop" in - let d = merge_deps d1 in - let loop_kid = mk_kid ("loop_" ^ string_of_id var) in - let env' = { env with - kid_deps = KBindings.add loop_kid d env.kid_deps} in - let d2,a2,r2 = analyse_sub env' assigns body in - (dempty, assigns, merge r1 r2) - | E_vector es -> - let ds, assigns, r = non_det es in - (merge_deps ds, assigns, r) - | E_vector_access (e1,e2) - | E_vector_append (e1,e2) - | E_cons (e1,e2) -> - let ds, assigns, r = non_det [e1;e2] in - (merge_deps ds, assigns, r) - | E_vector_subrange (e1,e2,e3) - | E_vector_update (e1,e2,e3) -> - let ds, assigns, r = non_det [e1;e2;e3] in - (merge_deps ds, assigns, r) - | E_vector_update_subrange (e1,e2,e3,e4) -> - let ds, assigns, r = non_det [e1;e2;e3;e4] in - (merge_deps ds, assigns, r) - | E_struct fexps -> - let es = List.map (function (FE_aux (FE_fexp (_,e),_)) -> e) fexps in - let ds, assigns, r = non_det es in - (merge_deps ds, assigns, r) - | E_struct_update (e,fexps) -> - let es = List.map (function (FE_aux (FE_fexp (_,e),_)) -> e) fexps in - let ds, assigns, r = non_det (e::es) in - (merge_deps ds, assigns, r) - | E_field (e,_) -> analyse_sub env assigns e - | E_match (e,cases) -> - let deps,assigns,r = analyse_sub env assigns e in - let deps = match refine_dependency env e cases with - | Some deps -> deps - | None -> deps - in - let analyse_case (Pat_aux (pexp,_)) = - match pexp with - | Pat_exp (pat,e1) -> - let env = update_env env deps pat (env_of_annot (l,annot)) (env_of e1) in - let d,assigns,r = analyse_sub env assigns e1 in - let assigns = add_dep_to_assigned deps assigns [e1] in - (d,assigns,r) - | Pat_when (pat,e1,e2) -> - let env = update_env env deps pat (env_of_annot (l,annot)) (env_of e2) in - let d1,assigns,r1 = analyse_sub env assigns e1 in - let d2,assigns,r2 = analyse_sub env assigns e2 in - let assigns = add_dep_to_assigned deps assigns [e1;e2] in - (dmerge d1 d2, assigns, merge r1 r2) - in - let ds,assigns,rs = split3 (List.map analyse_case cases) in - (merge_deps (deps::ds), - List.fold_left dep_bindings_merge Bindings.empty assigns, - List.fold_left merge r rs) - | E_let (LB_aux (LB_val (pat,e1),(lb_l,_)),e2) -> - let d1,assigns,r1 = analyse_sub env assigns e1 in - let unknown_deps = match d1 with Unknown _ -> true | Have _ -> false in - let d = - (* As a special case, detect - let 'size = if ... then 'typaram1 else 'typaram2; - where we can reduce the dependencies of 'size to the guard. *) - match pat, e1 with - | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)),_), - E_aux (E_if (guard_exp, - E_aux (E_id id1, annot1), - E_aux (E_id id2, annot2)), _) - when is_toplevel_int annot1 && is_toplevel_int annot2 -> - let guard_deps, _, _ = analyse_sub env assigns guard_exp in - guard_deps - - (* Add a new case split after the let if necessary *) - (* Potential improvements: match on more patterns (e.g. tuples); - allow disjunctions of equalities as well as set constraints; - allow set constraint to be part of a larger constraint. *) - | P_aux ((P_id id | P_var (P_aux (P_id id, _), _)), _), _ - when unknown_deps && useful_loc lb_l -> - let l' = Generated l in - let split = match typ_of e1 with - | Typ_aux (Typ_exist ([kdid], NC_aux (NC_set (kid, sizes), _), typ), _) - when Kid.compare (kopt_kid kdid) kid == 0 -> - begin match Type_check.destruct_atom_nexp (env_of e1) typ with - | Some nexp when Nexp.compare (nvar kid) nexp == 0 -> - let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',annot))) sizes in - Partial (pats,l) - | _ -> Total - end - - | Typ_aux (Typ_app (Id_aux (Id "atom", _), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_) -> - let typ_env = env_of_annot (l,annot) in - let constraints = Type_check.Env.get_constraints typ_env in - let vars = Spec_analysis.equal_kids_ncs kid' constraints in - begin match Util.find_map (function - | NC_aux (NC_set (kid'', is),_) when KidSet.mem kid'' vars -> Some is - | _ -> None) constraints with - | Some sizes -> - let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',annot))) sizes in - Partial (pats,l) - | None -> Total - end - - | _ -> Total + let default_split annot kids = + let kids = KidSet.elements kids in + let try_kid kid = try Some (KBindings.find kid set_assertions) with Not_found -> None in + match Util.option_first try_kid kids with + | Some (l, is) -> + let l' = Generated l in + let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n, l')), (l', annot))) is in + let pats = pats @ [P_aux (P_wild, (l', annot))] in + Partial (pats, l) + | None -> Total + in + let qs = match tq with TypQ_no_forall -> [] | TypQ_tq qs -> qs in + let eqn_instantiations = Type_check.instantiate_simple_equations qs in + let eqn_kid_deps = + KBindings.map (function A_aux (A_nexp nexp, _) -> Some (nexp_frees nexp) | _ -> None) eqn_instantiations + in + let arg i pat = + let rec aux (P_aux (p, (l, annot))) = + let of_list pats = + let ss, vs, ks = split3 (List.map aux pats) in + let s = List.fold_left (ArgSplits.merge merge_detail) ArgSplits.empty ss in + let v = List.fold_left dep_bindings_merge Bindings.empty vs in + let k = List.fold_left dep_kbindings_merge KBindings.empty ks in + (s, v, k) + in + match p with + | P_lit _ | P_wild -> (ArgSplits.empty, Bindings.empty, KBindings.empty) + | P_or (p1, p2) -> + let s1, v1, k1 = aux p1 in + let s2, v2, k2 = aux p2 in + (ArgSplits.merge merge_detail s1 s2, dep_bindings_merge v1 v2, dep_kbindings_merge k1 k2) + | P_not p -> aux p + | P_as (pat, id) -> + let s, v, k = aux pat in + if useful_loc top_pat_loc then + ( ArgSplits.add (id, top_pat_loc) Total s, + Bindings.add id + (Have (ArgSplits.singleton (id, top_pat_loc) Total, ExtraSplits.empty, LetSplits.empty)) + v, + k + ) + else (s, Bindings.add id (Unknown (l, "Unable to give location for " ^ string_of_id id)) v, k) + | P_typ (_, pat) -> aux pat + | P_id id -> + if useful_loc top_pat_loc then ( + let kids = kids_of_annot (l, annot) in + let split = default_split annot kids in + let s = ArgSplits.singleton (id, top_pat_loc) split in + ( s, + Bindings.singleton id (Have (s, ExtraSplits.empty, LetSplits.empty)), + KidSet.fold + (fun kid k -> KBindings.add kid (Have (s, ExtraSplits.empty, LetSplits.empty)) k) + kids KBindings.empty + ) + ) + else + ( ArgSplits.empty, + Bindings.singleton id (Unknown (l, "Unable to give location for " ^ string_of_id id)), + KBindings.empty + ) + | P_var (pat, tpat) -> + let s, v, k = aux pat in + let kids = kids_bound_by_typ_pat tpat in + let kids = + KidSet.fold + (fun kid s -> KidSet.union s (Spec_analysis.equal_kids (env_of_annot (l, annot)) kid)) + kids kids in - Have (ArgSplits.empty, ExtraSplits.empty, LetSplits.singleton (id,lb_l) split) - - | _, _ -> - d1 - in - let env = update_env env d pat (env_of_annot (l,annot)) (env_of e2) in - let d2,assigns,r2 = analyse_sub env assigns e2 in - (d2,assigns,merge r1 r2) - (* There's a more general assignment case above to update env inside a block. *) - | E_assign (lexp,e1) -> - let d1,assigns,r1 = analyse_sub env assigns e1 in - let assigns,r2 = analyse_lexp env assigns d1 lexp in - (dempty, assigns, merge r1 r2) - | E_sizeof nexp -> - (deps_of_nexp l env.kid_deps [] nexp, assigns, empty) - | E_return e - | E_exit e - | E_throw e -> - let _, _, r = analyse_sub env assigns e in - (dempty, Bindings.empty, r) - | E_ref id -> - (Unknown (l, "May be mutated via reference to " ^ string_of_id id), assigns, empty) - | E_try (e,cases) -> - let deps,_,r = analyse_sub env assigns e in - let assigns = remove_assigns [e] " assigned in try expression" in - let analyse_handler (Pat_aux (pexp,_)) = - match pexp with - | Pat_exp (pat,e1) -> - let env = update_env env (Unknown (l,"Exception")) pat (env_of_annot (l,annot)) (env_of e1) in - let d,assigns,r = analyse_sub env assigns e1 in - let assigns = add_dep_to_assigned deps assigns [e1] in - (d,assigns,r) - | Pat_when (pat,e1,e2) -> - let env = update_env env (Unknown (l,"Exception")) pat (env_of_annot (l,annot)) (env_of e2) in - let d1,assigns,r1 = analyse_sub env assigns e1 in - let d2,assigns,r2 = analyse_sub env assigns e2 in - let assigns = add_dep_to_assigned deps assigns [e1;e2] in - (dmerge d1 d2, assigns, merge r1 r2) - in - let ds,assigns,rs = split3 (List.map analyse_handler cases) in - (merge_deps (deps::ds), - List.fold_left dep_bindings_merge Bindings.empty assigns, - List.fold_left merge r rs) - | E_assert (e1,_) -> analyse_sub env assigns e1 - | E_internal_assume (nc,e1) -> analyse_sub env assigns e1 - - | E_app_infix _ - | E_internal_plet _ - | E_internal_return _ - | E_internal_value _ - -> raise (Reporting.err_unreachable l __POS__ - ("Unexpected expression encountered in monomorphisation: " ^ string_of_exp exp)) - - | E_var (lexp,e1,e2) -> - (* Really we ought to remove the assignment after e2 *) - let d1,assigns,r1 = analyse_sub env assigns e1 in - let assigns,r' = analyse_lexp env assigns d1 lexp in - let d2,assigns,r2 = analyse_sub env assigns e2 in - (dempty, assigns, merge r1 (merge r' r2)) - | E_constraint nc -> - (deps_of_nc env.kid_deps nc, assigns, empty) - in - let deps = - match destruct_atom_bool (env_of exp) (typ_of exp) with - | Some nc -> dmerge deps (deps_of_nc env.kid_deps nc) - | None -> deps - in - let r = - (* Check for bitvector types with parametrised sizes *) - match destruct_tannot annot with - | None -> r - | Some (tenv,typ) -> - let typ = Env.base_typ_of tenv typ in - let env, tenv, typ = - match destruct_exist (Env.expand_synonyms tenv typ) with - | None -> env, tenv, typ - | Some (kopts, nc, typ) -> - { env with kid_deps = - List.fold_left (fun kds kopt -> KBindings.add (kopt_kid kopt) deps kds) env.kid_deps kopts }, - Env.add_constraint nc - (List.fold_left (fun tenv kopt -> Env.add_typ_var l kopt tenv) tenv kopts), - typ - in - let rec check_typ typ = - if is_bitvector_typ typ then - let size,_,_ = vector_typ_args_of typ in - let Nexp_aux (size,_) as size_nexp = simplify_size_nexp env tenv size in - let is_tyvar_parameter v = - List.exists (fun k -> Kid.compare k v == 0) env.top_kids - in - match size with - | Nexp_constant _ -> r - | Nexp_var v when is_tyvar_parameter v -> - { r with kid_in_caller = CallerKidSet.add (fn_id,v) r.kid_in_caller } - | _ -> - match deps_of_nexp l env.kid_deps [] size_nexp with - | Have (args,extras,lets) -> - { r with - split = ArgSplits.merge merge_detail r.split args; - extra_splits = merge_extras r.extra_splits extras; - let_binding_splits = LetSplits.merge merge_detail r.let_binding_splits lets - } - | Unknown (l,msg) -> - { r with - failures = - Failures.add l (StringSet.singleton ("Unable to monomorphise " ^ string_of_nexp size_nexp ^ ": " ^ msg)) - r.failures } - else match typ with - | Typ_aux (Typ_tuple typs,_) -> - List.fold_left (fun r ty -> merge r (check_typ ty)) r typs - | _ -> r - in check_typ typ - in (deps, assigns, r) - - -and analyse_lexp fn_id effect_info env assigns deps (LE_aux (lexp,(l,_))) = - let analyse_sub = analyse_exp fn_id effect_info in - let analyse_lexp = analyse_lexp fn_id effect_info in - (* TODO: maybe subexps and sublexps should be non-det (and in const_prop_lexp, too?) *) - match lexp with - | LE_id id - | LE_typ (_,id) -> - if IdSet.mem id env.referenced_vars - then assigns, empty - else Bindings.add id deps assigns, empty - | LE_app (id,es) -> - let _, assigns, r = analyse_sub env assigns (E_aux (E_tuple es,(Unknown,empty_tannot))) in - assigns, r - | LE_tuple lexps - | LE_vector_concat lexps -> - List.fold_left (fun (assigns,r) lexp -> - let assigns,r' = analyse_lexp env assigns deps lexp - in assigns,merge r r') (assigns,empty) lexps - | LE_vector (lexp,e) -> - let _, assigns, r1 = analyse_sub env assigns e in - let assigns, r2 = analyse_lexp env assigns deps lexp in - assigns, merge r1 r2 - | LE_vector_range (lexp,e1,e2) -> - let _, assigns, r1 = analyse_sub env assigns e1 in - let _, assigns, r2 = analyse_sub env assigns e2 in - let assigns, r3 = analyse_lexp env assigns deps lexp in - assigns, merge r3 (merge r1 r2) - | LE_field (lexp,_) -> analyse_lexp env assigns deps lexp - | LE_deref e -> - let _, assigns, r = analyse_sub env assigns e in - assigns, r - - -let initial_env fn_id fn_l (TypQ_aux (tq,_)) pat body set_assertions globals = - (* The splitter always uses the outermost location *) - let top_pat_loc = pat_loc pat in - - let pats = - match pat with - | P_aux (P_tuple pats,_) -> pats - | _ -> [pat] - in - (* For the type in an annotation, produce the corresponding tyvar (if any), - and a default case split (a set if there's one, a full case split if not). *) - let kids_of_annot annot = - let env = env_of_annot annot in - let Typ_aux (typ,_) = Env.base_typ_of env (typ_of_annot annot) in - match typ with - | Typ_app (Id_aux (Id "atom",_),[A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]) -> - Spec_analysis.equal_kids env kid - | _ -> KidSet.empty - in - let default_split annot kids = - let kids = KidSet.elements kids in - let try_kid kid = try Some (KBindings.find kid set_assertions) with Not_found -> None in - match Util.option_first try_kid kids with - | Some (l,is) -> - let l' = Generated l in - let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',annot))) is in - let pats = pats @ [P_aux (P_wild,(l',annot))] in - Partial (pats,l) - | None -> Total - in - let qs = - match tq with - | TypQ_no_forall -> [] - | TypQ_tq qs -> qs - in - let eqn_instantiations = Type_check.instantiate_simple_equations qs in - let eqn_kid_deps = KBindings.map (function - | A_aux (A_nexp nexp, _) -> Some (nexp_frees nexp) - | _ -> None) eqn_instantiations - in - let arg i pat = - let rec aux (P_aux (p,(l,annot))) = - let of_list pats = - let ss,vs,ks = split3 (List.map aux pats) in - let s = List.fold_left (ArgSplits.merge merge_detail) ArgSplits.empty ss in - let v = List.fold_left dep_bindings_merge Bindings.empty vs in - let k = List.fold_left dep_kbindings_merge KBindings.empty ks in - s,v,k + (s, v, KidSet.fold (fun kid k -> KBindings.add kid (Have (s, ExtraSplits.empty, LetSplits.empty)) k) kids k) + | P_app (_, pats) -> of_list pats + | P_vector pats | P_vector_concat pats | P_string_append pats | P_tuple pats | P_list pats -> of_list pats + | P_cons (p1, p2) -> of_list [p1; p2] + | P_vector_subrange _ -> + Reporting.unreachable l __POS__ "vector subrange pattern should be removed before monomorphisation" in - match p with - | P_lit _ - | P_wild - -> ArgSplits.empty,Bindings.empty,KBindings.empty - | P_or (p1, p2) -> - let (s1, v1, k1) = aux p1 in - let (s2, v2, k2) = aux p2 in - (ArgSplits.merge merge_detail s1 s2, dep_bindings_merge v1 v2, dep_kbindings_merge k1 k2) - | P_not p -> aux p - | P_as (pat,id) -> - let s,v,k = aux pat in - if useful_loc top_pat_loc then - ArgSplits.add (id,top_pat_loc) Total s, - Bindings.add id (Have (ArgSplits.singleton (id,top_pat_loc) Total, ExtraSplits.empty, LetSplits.empty)) v, - k - else - s, - Bindings.add id (Unknown (l, ("Unable to give location for " ^ string_of_id id))) v, - k - | P_typ (_,pat) -> aux pat - | P_id id -> - if useful_loc top_pat_loc then - let kids = kids_of_annot (l,annot) in - let split = default_split annot kids in - let s = ArgSplits.singleton (id,top_pat_loc) split in - s, - Bindings.singleton id (Have (s, ExtraSplits.empty, LetSplits.empty)), - KidSet.fold (fun kid k -> KBindings.add kid (Have (s, ExtraSplits.empty, LetSplits.empty)) k) kids KBindings.empty - else - ArgSplits.empty, - Bindings.singleton id (Unknown (l, ("Unable to give location for " ^ string_of_id id))), - KBindings.empty - | P_var (pat, tpat) -> - let s,v,k = aux pat in - let kids = kids_bound_by_typ_pat tpat in - let kids = KidSet.fold (fun kid s -> - KidSet.union s (Spec_analysis.equal_kids (env_of_annot (l,annot)) kid)) - kids kids in - s,v,KidSet.fold (fun kid k -> KBindings.add kid (Have (s, ExtraSplits.empty, LetSplits.empty)) k) kids k - | P_app (_,pats) -> of_list pats - | P_vector pats - | P_vector_concat pats - | P_string_append pats - | P_tuple pats - | P_list pats - -> of_list pats - | P_cons (p1,p2) -> of_list [p1;p2] - | P_vector_subrange _ -> - Reporting.unreachable l __POS__ "vector subrange pattern should be removed before monomorphisation" - in aux pat - in - let int_quant = function - | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_int,_),kid),_)),_) -> Some kid - | _ -> None - in - let top_kids = List.filter_map int_quant qs in - let _,var_deps,kid_deps = split3 (List.mapi arg pats) in - let var_deps = List.fold_left dep_bindings_merge Bindings.empty var_deps in - let kid_deps = List.fold_left dep_kbindings_merge KBindings.empty kid_deps in - let note_no_arg kid_deps kid = - if KBindings.mem kid kid_deps then kid_deps - else - (* When there's no argument to case split on for a kid, we'll add a - case expression instead *) - let env = env_of_pat pat in - let split = default_split (mk_tannot env int_typ) (KidSet.singleton kid) in - let extra_splits = ExtraSplits.singleton (fn_id, fn_l) - (KBindings.singleton kid split) in - KBindings.add kid (Have (ArgSplits.empty, extra_splits, LetSplits.empty)) kid_deps - in - let kid_deps = List.fold_left note_no_arg kid_deps top_kids in - let merge_kid_deps_eqns k kdeps eqn_kids = - match kdeps, eqn_kids with - | _, Some (Some kids) -> Some (KidSet.fold (fun kid deps -> dmerge deps (KBindings.find kid kid_deps)) kids dempty) - | Some deps, _ -> Some deps - | _, _ -> None - in - let kid_deps = KBindings.merge merge_kid_deps_eqns kid_deps eqn_kid_deps in - let referenced_vars = Constant_propagation.referenced_vars body in - { top_kids; var_deps; kid_deps; referenced_vars; globals } - -(* When there's more than one pick the first *) -let merge_set_asserts _ x y = - match x, y with - | None, _ -> y - | _, _ -> x -let merge_set_asserts_by_kid sets1 sets2 = - KBindings.merge merge_set_asserts sets1 sets2 - -(* Set constraints in assertions don't always use the set syntax, so we also - handle assert('N == 1 | ...) style set constraints *) -let rec sets_from_assert e = - let set_from_or_exps (E_aux (_,(l,_)) as e) = - let mykid = ref None in - let check_kid kid = - match !mykid with - | None -> mykid := Some kid - | Some kid' -> if Kid.compare kid kid' == 0 then () - else raise Not_found + aux pat in - let rec aux (E_aux (e,_)) = - match e with - | E_app (Id_aux (Id "or_bool",_),[e1;e2]) -> - aux e1 @ aux e2 - | E_app (Id_aux (Id "eq_int",_), - [E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); - E_aux (E_lit (L_aux (L_num i,_)),_)]) -> - (check_kid kid; [i]) - (* TODO: Now that E_constraint is re-written by the typechecker, - we'll end up with the following for the above - some of this - function is probably redundant now *) - | E_app (Id_aux (Id "eq_int",_), - [E_aux (E_app (Id_aux (Id "__id", _), [E_aux (E_id id, annot)]), _); - E_aux (E_lit (L_aux (L_num i,_)),_)]) -> - begin match typ_of_annot annot with - | Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)]), _) -> - check_kid kid; [i] - | _ -> raise Not_found - end - | _ -> raise Not_found - in try - let is = aux e in - match !mykid with - | None -> KBindings.empty - | Some kid -> KBindings.singleton kid (l,is) + let int_quant = function + | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _)), _) -> Some kid + | _ -> None + in + let top_kids = List.filter_map int_quant qs in + let _, var_deps, kid_deps = split3 (List.mapi arg pats) in + let var_deps = List.fold_left dep_bindings_merge Bindings.empty var_deps in + let kid_deps = List.fold_left dep_kbindings_merge KBindings.empty kid_deps in + let note_no_arg kid_deps kid = + if KBindings.mem kid kid_deps then kid_deps + else ( + (* When there's no argument to case split on for a kid, we'll add a + case expression instead *) + let env = env_of_pat pat in + let split = default_split (mk_tannot env int_typ) (KidSet.singleton kid) in + let extra_splits = ExtraSplits.singleton (fn_id, fn_l) (KBindings.singleton kid split) in + KBindings.add kid (Have (ArgSplits.empty, extra_splits, LetSplits.empty)) kid_deps + ) + in + let kid_deps = List.fold_left note_no_arg kid_deps top_kids in + let merge_kid_deps_eqns k kdeps eqn_kids = + match (kdeps, eqn_kids) with + | _, Some (Some kids) -> Some (KidSet.fold (fun kid deps -> dmerge deps (KBindings.find kid kid_deps)) kids dempty) + | Some deps, _ -> Some deps + | _, _ -> None + in + let kid_deps = KBindings.merge merge_kid_deps_eqns kid_deps eqn_kid_deps in + let referenced_vars = Constant_propagation.referenced_vars body in + { top_kids; var_deps; kid_deps; referenced_vars; globals } + + (* When there's more than one pick the first *) + let merge_set_asserts _ x y = match (x, y) with None, _ -> y | _, _ -> x + let merge_set_asserts_by_kid sets1 sets2 = KBindings.merge merge_set_asserts sets1 sets2 + + (* Set constraints in assertions don't always use the set syntax, so we also + handle assert('N == 1 | ...) style set constraints *) + let rec sets_from_assert e = + let set_from_or_exps (E_aux (_, (l, _)) as e) = + let mykid = ref None in + let check_kid kid = + match !mykid with + | None -> mykid := Some kid + | Some kid' -> if Kid.compare kid kid' == 0 then () else raise Not_found + in + let rec aux (E_aux (e, _)) = + match e with + | E_app (Id_aux (Id "or_bool", _), [e1; e2]) -> aux e1 @ aux e2 + | E_app + ( Id_aux (Id "eq_int", _), + [E_aux (E_sizeof (Nexp_aux (Nexp_var kid, _)), _); E_aux (E_lit (L_aux (L_num i, _)), _)] + ) -> + check_kid kid; + [i] + (* TODO: Now that E_constraint is re-written by the typechecker, + we'll end up with the following for the above - some of this + function is probably redundant now *) + | E_app + ( Id_aux (Id "eq_int", _), + [ + E_aux (E_app (Id_aux (Id "__id", _), [E_aux (E_id id, annot)]), _); E_aux (E_lit (L_aux (L_num i, _)), _); + ] + ) -> begin + match typ_of_annot annot with + | Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)]), _) -> + check_kid kid; + [i] + | _ -> raise Not_found + end + | _ -> raise Not_found + in + try + let is = aux e in + match !mykid with None -> KBindings.empty | Some kid -> KBindings.singleton kid (l, is) with Not_found -> KBindings.empty - in - let rec set_from_nc_or (NC_aux (nc,_)) = - match nc with - | NC_equal (Nexp_aux (Nexp_var kid,_), Nexp_aux (Nexp_constant n,_)) -> - Some (kid,[n]) - | NC_or (nc1, nc2) -> - (match set_from_nc_or nc1, set_from_nc_or nc2 with - | Some (kid1,l1), Some (kid2,l2) when Kid.compare kid1 kid2 == 0 -> Some (kid1,l1 @ l2) - | _ -> None) - | _ -> None - in - let rec sets_from_nc (NC_aux (nc,l) as nc_full) = - match nc with - | NC_and (nc1,nc2) -> merge_set_asserts_by_kid (sets_from_nc nc1) (sets_from_nc nc2) - | NC_set (kid,is) -> KBindings.singleton kid (l,is) - | NC_equal (Nexp_aux (Nexp_var kid,_), Nexp_aux (Nexp_constant n,_)) -> - KBindings.singleton kid (l, [n]) - | NC_or _ -> - (match set_from_nc_or nc_full with - | Some (kid, is) -> KBindings.singleton kid (l,is) - | None -> KBindings.empty) - | _ -> KBindings.empty - in - match e with - | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) -> - merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2) - | E_aux (E_constraint nc,_) -> sets_from_nc nc - | _ -> set_from_or_exps e - -(* Find all the easily reached set assertions in a function body, to use as - case splits. Note that this should be mirrored in stop_at_false_assertions, - above. *) -let rec find_set_assertions (E_aux (e,_)) = - match e with - | E_block es -> - List.fold_left merge_set_asserts_by_kid KBindings.empty (List.map find_set_assertions es) - | E_typ (_,e) -> find_set_assertions e - | E_let (LB_aux (LB_val (p,e1),_),e2) -> - let sets1 = find_set_assertions e1 in - let sets2 = find_set_assertions e2 in - let kbound = kids_bound_by_pat p in - let sets2 = KBindings.filter (fun kid _ -> not (KidSet.mem kid kbound)) sets2 in - merge_set_asserts_by_kid sets1 sets2 - | E_assert (exp1,_) -> sets_from_assert exp1 - | _ -> KBindings.empty - -let print_set_assertions set_assertions = - if KBindings.is_empty set_assertions then - print_endline "No top-level set assertions found." - else begin - print_endline "Top-level set assertions found:"; - KBindings.iter (fun k (l,is) -> - print_endline (string_of_kid k ^ " @ " ^ simple_string_of_loc l ^ " " ^ - String.concat "," (List.map Big_int.to_string is))) - set_assertions - end - -let print_result r = - let _ = print_endline (" splits: " ^ string_of_argsplits r.split) in - let print_kbinding kid dep = - let s1 = match dep.in_fun with - | Some dep -> "InFun " ^ string_of_dep dep - | None -> "" in - let s2 = string_of_callerkidset dep.parents in - let _ = print_endline (" " ^ string_of_kid kid ^ ": " ^ s1 ^ "; " ^ s2) in - () - in - let print_binding id kdep = - let _ = print_endline (" " ^ string_of_id id ^ ":") in - let _ = KBindings.iter print_kbinding kdep in - () - in - let _ = print_endline " split_on_call: " in - let _ = Bindings.iter print_binding r.split_on_call in - let _ = print_endline (" kid_in_caller: " ^ string_of_callerkidset r.kid_in_caller) in - let _ = print_endline (" failures: \n " ^ - (String.concat "\n " - (List.map (fun (l,s) -> Reporting.loc_to_string l ^ ":\n " ^ - String.concat "\n " (StringSet.elements s)) - (Failures.bindings r.failures)))) in - () - -let analyse_funcl debug effect_info tenv constants (FCL_aux (FCL_funcl (id,pexp),(def_annot,_))) = - let l = def_annot.loc in - let _ = if debug > 2 then print_endline (string_of_id id) else () in - let pat,guard,body,_ = destruct_pexp pexp in - let (tq,_) = Env.get_val_spec_orig id tenv in - let set_assertions = find_set_assertions body in - let _ = if debug > 2 then print_set_assertions set_assertions in - let aenv = initial_env id l tq pat body set_assertions constants in - let _,_,r = analyse_exp id effect_info aenv Bindings.empty body in - let r = match guard with - | None -> r - | Some exp -> let _,_,r' = analyse_exp id effect_info aenv Bindings.empty exp in - let r' = - if ExtraSplits.is_empty r'.extra_splits - then r' - else merge r' { empty with failures = - Failures.singleton l (StringSet.singleton - "Case splitting size tyvars in guards not supported") } - in - merge r r' - in - let _ = if debug > 2 then print_result r else () - in r - -let analyse_def debug effect_info env globals (DEF_aux (aux, _)) = - match aux with - | DEF_fundef (FD_aux (FD_function (_,_,funcls),_)) -> - globals, List.fold_left (fun r f -> merge r (analyse_funcl debug effect_info env globals f)) empty funcls - - | DEF_let (LB_aux (LB_val (P_aux ((P_id id | P_typ (_,P_aux (P_id id,_))),_), exp),_)) -> - Bindings.add id (Constant_fold.is_constant exp) globals, empty + let rec set_from_nc_or (NC_aux (nc, _)) = + match nc with + | NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant n, _)) -> Some (kid, [n]) + | NC_or (nc1, nc2) -> ( + match (set_from_nc_or nc1, set_from_nc_or nc2) with + | Some (kid1, l1), Some (kid2, l2) when Kid.compare kid1 kid2 == 0 -> Some (kid1, l1 @ l2) + | _ -> None + ) + | _ -> None + in + let rec sets_from_nc (NC_aux (nc, l) as nc_full) = + match nc with + | NC_and (nc1, nc2) -> merge_set_asserts_by_kid (sets_from_nc nc1) (sets_from_nc nc2) + | NC_set (kid, is) -> KBindings.singleton kid (l, is) + | NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant n, _)) -> KBindings.singleton kid (l, [n]) + | NC_or _ -> ( + match set_from_nc_or nc_full with + | Some (kid, is) -> KBindings.singleton kid (l, is) + | None -> KBindings.empty + ) + | _ -> KBindings.empty + in + match e with + | E_aux (E_app (Id_aux (Id "and_bool", _), [e1; e2]), _) -> + merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2) + | E_aux (E_constraint nc, _) -> sets_from_nc nc + | _ -> set_from_or_exps e + + (* Find all the easily reached set assertions in a function body, to use as + case splits. Note that this should be mirrored in stop_at_false_assertions, + above. *) + let rec find_set_assertions (E_aux (e, _)) = + match e with + | E_block es -> List.fold_left merge_set_asserts_by_kid KBindings.empty (List.map find_set_assertions es) + | E_typ (_, e) -> find_set_assertions e + | E_let (LB_aux (LB_val (p, e1), _), e2) -> + let sets1 = find_set_assertions e1 in + let sets2 = find_set_assertions e2 in + let kbound = kids_bound_by_pat p in + let sets2 = KBindings.filter (fun kid _ -> not (KidSet.mem kid kbound)) sets2 in + merge_set_asserts_by_kid sets1 sets2 + | E_assert (exp1, _) -> sets_from_assert exp1 + | _ -> KBindings.empty - | _ -> globals, empty + let print_set_assertions set_assertions = + if KBindings.is_empty set_assertions then print_endline "No top-level set assertions found." + else begin + print_endline "Top-level set assertions found:"; + KBindings.iter + (fun k (l, is) -> + print_endline + (string_of_kid k ^ " @ " ^ simple_string_of_loc l ^ " " ^ String.concat "," (List.map Big_int.to_string is)) + ) + set_assertions + end -let detail_to_split = function - | Total -> None - | Partial (pats,l) -> Some (pats,l) + let print_result r = + let _ = print_endline (" splits: " ^ string_of_argsplits r.split) in + let print_kbinding kid dep = + let s1 = match dep.in_fun with Some dep -> "InFun " ^ string_of_dep dep | None -> "" in + let s2 = string_of_callerkidset dep.parents in + let _ = print_endline (" " ^ string_of_kid kid ^ ": " ^ s1 ^ "; " ^ s2) in + () + in + let print_binding id kdep = + let _ = print_endline (" " ^ string_of_id id ^ ":") in + let _ = KBindings.iter print_kbinding kdep in + () + in + let _ = print_endline " split_on_call: " in + let _ = Bindings.iter print_binding r.split_on_call in + let _ = print_endline (" kid_in_caller: " ^ string_of_callerkidset r.kid_in_caller) in + let _ = + print_endline + (" failures: \n " + ^ String.concat "\n " + (List.map + (fun (l, s) -> Reporting.loc_to_string l ^ ":\n " ^ String.concat "\n " (StringSet.elements s)) + (Failures.bindings r.failures) + ) + ) + in + () -let argset_to_list splits = - let l = ArgSplits.bindings splits in - let argelt = function - | ((id,loc),detail) -> (Exact loc,string_of_id id,detail_to_split detail) - in - List.map argelt l - -let let_binding_splits_to_list lets = - List.map (fun ((id,loc), detail) -> - (Exact loc, string_of_id id, detail_to_split detail)) - (LetSplits.bindings lets) - -let analyse_defs debug effect_info env ast = - let total_defs = List.length ast.defs in - let def (idx,globals,r) d = - begin match d with - | DEF_aux (DEF_fundef fd, _) -> - Util.progress "Analysing " (string_of_id (id_of_fundef fd)) idx total_defs - | _ -> () - end; - let globals,r' = analyse_def debug effect_info env globals d in - idx + 1, globals, merge r r' - in - let _,_,r = List.fold_left def (0,Bindings.empty,empty) ast.defs in - let _ = Util.progress "Analysing " "done" total_defs total_defs in - - (* Resolve the interprocedural dependencies *) - - let rec separate_deps = function - | Have (splits, extras, lets) -> - splits, extras, lets, Failures.empty - | Unknown (l,msg) -> - ArgSplits.empty, ExtraSplits.empty, LetSplits.empty, - Failures.singleton l (StringSet.singleton ("Unable to monomorphise dependency: " ^ msg)) - and chase_kid_caller (id,kid) = - match Bindings.find id r.split_on_call with - | kid_deps -> begin - match KBindings.find kid kid_deps with - | call_dep -> - let (splits, extras, lets, fails) = match call_dep.in_fun with - | Some deps -> separate_deps deps - | None -> (ArgSplits.empty, ExtraSplits.empty, LetSplits.empty, Failures.empty) - in - CallerKidSet.fold add_kid call_dep.parents (splits, extras, lets, fails) - | exception Not_found -> ArgSplits.empty,ExtraSplits.empty,LetSplits.empty,Failures.empty - end - | exception Not_found -> ArgSplits.empty,ExtraSplits.empty,LetSplits.empty,Failures.empty - and add_kid k (splits,extras,lets,fails) = - let splits',extras',lets',fails' = chase_kid_caller k in - ArgSplits.merge merge_detail splits splits', - merge_extras extras extras', - LetSplits.merge merge_detail lets lets', - Failures.merge failure_merge fails fails' - in - let _ = if debug > 1 then print_result r else () in - let splits,extras,lets,fails = - CallerKidSet.fold add_kid r.kid_in_caller (r.split,r.extra_splits,r.let_binding_splits,r.failures) in - let _ = - if debug > 0 then - (print_endline "Final splits:"; - print_endline (string_of_argsplits splits); - print_endline (string_of_extra_splits extras); - print_endline (string_of_let_binding_splits lets)) - else () - in - let splits = argset_to_list splits @ let_binding_splits_to_list lets in - if Failures.is_empty fails - then (true,splits,extras) else - begin - Failures.iter (fun l msgs -> - Reporting.print_err l "Monomorphisation" (String.concat "\n" (StringSet.elements msgs))) + let analyse_funcl debug effect_info tenv constants (FCL_aux (FCL_funcl (id, pexp), (def_annot, _))) = + let l = def_annot.loc in + let _ = if debug > 2 then print_endline (string_of_id id) else () in + let pat, guard, body, _ = destruct_pexp pexp in + let tq, _ = Env.get_val_spec_orig id tenv in + let set_assertions = find_set_assertions body in + let _ = if debug > 2 then print_set_assertions set_assertions in + let aenv = initial_env id l tq pat body set_assertions constants in + let _, _, r = analyse_exp id effect_info aenv Bindings.empty body in + let r = + match guard with + | None -> r + | Some exp -> + let _, _, r' = analyse_exp id effect_info aenv Bindings.empty exp in + let r' = + if ExtraSplits.is_empty r'.extra_splits then r' + else + merge r' + { + empty with + failures = + Failures.singleton l (StringSet.singleton "Case splitting size tyvars in guards not supported"); + } + in + merge r r' + in + let _ = if debug > 2 then print_result r else () in + r + + let analyse_def debug effect_info env globals (DEF_aux (aux, _)) = + match aux with + | DEF_fundef (FD_aux (FD_function (_, _, funcls), _)) -> + (globals, List.fold_left (fun r f -> merge r (analyse_funcl debug effect_info env globals f)) empty funcls) + | DEF_let (LB_aux (LB_val (P_aux ((P_id id | P_typ (_, P_aux (P_id id, _))), _), exp), _)) -> + (Bindings.add id (Constant_fold.is_constant exp) globals, empty) + | _ -> (globals, empty) + + let detail_to_split = function Total -> None | Partial (pats, l) -> Some (pats, l) + + let argset_to_list splits = + let l = ArgSplits.bindings splits in + let argelt = function (id, loc), detail -> (Exact loc, string_of_id id, detail_to_split detail) in + List.map argelt l + + let let_binding_splits_to_list lets = + List.map (fun ((id, loc), detail) -> (Exact loc, string_of_id id, detail_to_split detail)) (LetSplits.bindings lets) + + let analyse_defs debug effect_info env ast = + let total_defs = List.length ast.defs in + let def (idx, globals, r) d = + begin + match d with + | DEF_aux (DEF_fundef fd, _) -> Util.progress "Analysing " (string_of_id (id_of_fundef fd)) idx total_defs + | _ -> () + end; + let globals, r' = analyse_def debug effect_info env globals d in + (idx + 1, globals, merge r r') + in + let _, _, r = List.fold_left def (0, Bindings.empty, empty) ast.defs in + let _ = Util.progress "Analysing " "done" total_defs total_defs in + + (* Resolve the interprocedural dependencies *) + let rec separate_deps = function + | Have (splits, extras, lets) -> (splits, extras, lets, Failures.empty) + | Unknown (l, msg) -> + ( ArgSplits.empty, + ExtraSplits.empty, + LetSplits.empty, + Failures.singleton l (StringSet.singleton ("Unable to monomorphise dependency: " ^ msg)) + ) + and chase_kid_caller (id, kid) = + match Bindings.find id r.split_on_call with + | kid_deps -> begin + match KBindings.find kid kid_deps with + | call_dep -> + let splits, extras, lets, fails = + match call_dep.in_fun with + | Some deps -> separate_deps deps + | None -> (ArgSplits.empty, ExtraSplits.empty, LetSplits.empty, Failures.empty) + in + CallerKidSet.fold add_kid call_dep.parents (splits, extras, lets, fails) + | exception Not_found -> (ArgSplits.empty, ExtraSplits.empty, LetSplits.empty, Failures.empty) + end + | exception Not_found -> (ArgSplits.empty, ExtraSplits.empty, LetSplits.empty, Failures.empty) + and add_kid k (splits, extras, lets, fails) = + let splits', extras', lets', fails' = chase_kid_caller k in + ( ArgSplits.merge merge_detail splits splits', + merge_extras extras extras', + LetSplits.merge merge_detail lets lets', + Failures.merge failure_merge fails fails' + ) + in + let _ = if debug > 1 then print_result r else () in + let splits, extras, lets, fails = + CallerKidSet.fold add_kid r.kid_in_caller (r.split, r.extra_splits, r.let_binding_splits, r.failures) + in + let _ = + if debug > 0 then ( + print_endline "Final splits:"; + print_endline (string_of_argsplits splits); + print_endline (string_of_extra_splits extras); + print_endline (string_of_let_binding_splits lets) + ) + else () + in + let splits = argset_to_list splits @ let_binding_splits_to_list lets in + if Failures.is_empty fails then (true, splits, extras) + else begin + Failures.iter + (fun l msgs -> Reporting.print_err l "Monomorphisation" (String.concat "\n" (StringSet.elements msgs))) fails; - (false, splits,extras) + (false, splits, extras) end - end let fresh_sz_var = let counter = ref 0 in fun () -> let n = !counter in - let () = counter := n+1 in + let () = counter := n + 1 in mk_id ("sz#" ^ string_of_int n) let add_extra_splits extras defs = let success = ref true in - let add_to_body extras (E_aux (_,(l,annot)) as e) = + let add_to_body extras (E_aux (_, (l, annot)) as e) = let l' = unique (Generated l) in - KBindings.fold (fun kid detail (exp,split_list) -> - let nexp = Nexp_aux (Nexp_var kid,l) in - let var = fresh_sz_var () in - let size_annot = mk_tannot (env_of e) (atom_typ nexp) in - let pexps = [Pat_aux (Pat_exp (P_aux (P_id var,(l',size_annot)),exp),(l',annot))] in - E_aux (E_match (E_aux (E_sizeof nexp, (l',size_annot)), pexps),(l',annot)), - ((Exact l', string_of_id var, Analysis.detail_to_split detail)::split_list) - ) extras (e,[]) + KBindings.fold + (fun kid detail (exp, split_list) -> + let nexp = Nexp_aux (Nexp_var kid, l) in + let var = fresh_sz_var () in + let size_annot = mk_tannot (env_of e) (atom_typ nexp) in + let pexps = [Pat_aux (Pat_exp (P_aux (P_id var, (l', size_annot)), exp), (l', annot))] in + ( E_aux (E_match (E_aux (E_sizeof nexp, (l', size_annot)), pexps), (l', annot)), + (Exact l', string_of_id var, Analysis.detail_to_split detail) :: split_list + ) + ) + extras (e, []) in - let add_to_funcl (FCL_aux (FCL_funcl (id,Pat_aux (pexp,peannot)),(def_annot,annot))) = + let add_to_funcl (FCL_aux (FCL_funcl (id, Pat_aux (pexp, peannot)), (def_annot, annot))) = let l = def_annot.loc in - let pexp, splits = - match Analysis.ExtraSplits.find (id,l) extras with - | extras -> - (match pexp with - | Pat_exp (p,e) -> let e',sp = add_to_body extras e in Pat_exp (p,e'), sp - | Pat_when (p,g,e) -> let e',sp = add_to_body extras e in Pat_when (p,g,e'), sp) - | exception Not_found -> pexp, [] - in FCL_aux (FCL_funcl (id,Pat_aux (pexp,peannot)),(def_annot,annot)), splits + let pexp, splits = + match Analysis.ExtraSplits.find (id, l) extras with + | extras -> ( + match pexp with + | Pat_exp (p, e) -> + let e', sp = add_to_body extras e in + (Pat_exp (p, e'), sp) + | Pat_when (p, g, e) -> + let e', sp = add_to_body extras e in + (Pat_when (p, g, e'), sp) + ) + | exception Not_found -> (pexp, []) + in + (FCL_aux (FCL_funcl (id, Pat_aux (pexp, peannot)), (def_annot, annot)), splits) in let add_to_def = function - | DEF_aux (DEF_fundef (FD_aux (FD_function (re,ta,funcls),annot)),def_annot) -> - let funcls,splits = List.split (List.map add_to_funcl funcls) in - DEF_aux (DEF_fundef (FD_aux (FD_function (re,ta,funcls),annot)),def_annot), List.concat splits - | d -> d, [] - in - let defs, splits = List.split (List.map add_to_def defs) in - !success, defs, List.concat splits - -module MonoRewrites = -struct - -let is_constant_range = function - | E_aux (E_lit _,_), E_aux (E_lit _,_) -> true - | _ -> false - -let is_constant = function - | E_aux (E_lit _,_) -> true - | _ -> false - -let get_constant_vec_len ?solve:(solve=false) env typ = - let typ = Env.base_typ_of env typ in - match destruct_bitvector env typ with - | Some (size,_) -> - begin match nexp_simp size with - | Nexp_aux (Nexp_constant i,_) -> Some i - | nexp when solve -> solve_unique env nexp - | _ -> None - end - | _ -> None - -let is_constant_vec_typ env typ = (get_constant_vec_len env typ <> None) - -let is_zeros env id = - is_id env (Id "Zeros") id || is_id env (Id "zeros") id || - is_id env (Id "sail_zeros") id - -let is_zero_extend env id = - is_id env (Id "ZeroExtend") id || - is_id env (Id "zero_extend") id || is_id env (Id "sail_zero_extend") id || - is_id env (Id "mips_zero_extend") id || is_id env (Id "EXTZ") id - -let eq_exp_conservative (E_aux (exp1, _)) (E_aux (exp2, _)) = - match exp1, exp2 with - | E_id id1, E_id id2 -> true - | E_lit lit1, E_lit lit2 -> lit_eq' lit1 lit2 - | _ -> false - -(* We have to add casts in here with appropriate length information so that the - type checker knows the expected return types. *) - -let rec rewrite_app env typ (id,args) = - let is_append = is_id env (Id "append") in - let is_subrange = is_id env (Id "vector_subrange") in - let is_integer_subrange = is_id env (Id "integer_subrange") in - let is_slice = is_id env (Id "slice") in - let is_zeros id = is_zeros env id in - let is_ones id = is_id env (Id "Ones") id || is_id env (Id "ones") id || - is_id env (Id "sail_ones") id in - let is_zero_extend = is_zero_extend env id in - let is_sign_extend = - is_id env (Id "SignExtend") id || - is_id env (Id "sign_extend") id || is_id env (Id "sail_sign_extend") id || - is_id env (Id "mips_sign_extend") id || is_id env (Id "EXTS") id - in - let is_truncate = is_id env (Id "truncate") id in - let mk_exp e = E_aux (e, (Unknown, empty_tannot)) in - let rec is_zeros_exp e = match unaux_exp e with - | E_app (zeros, [_]) when is_zeros zeros -> true - | E_lit (L_aux ((L_bin s | L_hex s), _)) -> - List.for_all (fun c -> c = '0') (Util.string_to_list s) - | E_typ (_, e) -> is_zeros_exp e - | _ -> false + | DEF_aux (DEF_fundef (FD_aux (FD_function (re, ta, funcls), annot)), def_annot) -> + let funcls, splits = List.split (List.map add_to_funcl funcls) in + (DEF_aux (DEF_fundef (FD_aux (FD_function (re, ta, funcls), annot)), def_annot), List.concat splits) + | d -> (d, []) in - let rec get_zeros_exp_len e = match unaux_exp e with - | E_app (zeros, [len]) when is_zeros zeros -> Some len - | E_typ (_, e) -> get_zeros_exp_len e - | _ -> - match get_constant_vec_len (env_of e) (typ_of e) with - | Some i -> Some (mk_exp (E_lit (L_aux (L_num i, Unknown)))) - | None -> None - in - let try_cast_to_typ (E_aux (e,(l, _)) as exp) = - let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in - (* vector_typ_args_of might simplify size, so rebuild the type even if it's constant *) - match size with - | Nexp_aux (Nexp_constant c,_) -> E_typ (bitvector_typ (nconstant c) order, exp) - | _ -> match solve_unique env size with - | Some c -> E_typ (bitvector_typ (nconstant c) order, exp) - | None -> e - in - let rewrap e = E_aux (e, (Unknown, empty_tannot)) in - if is_append id then - match args with - (* (known-size-vector @ variable-vector) @ variable-vector *) - | [E_aux (E_app (append, - [e1; - E_aux (E_app (subrange1, - [vector1; start1; end1]),_) as sub1]),_); - E_aux (E_app (subrange2, - [vector2; start2; end2]),_) as sub2] - when is_append append && is_subrange subrange1 && is_subrange subrange2 && - is_constant_vec_typ env (typ_of e1) && - is_bitvector_typ (typ_of vector1) && is_bitvector_typ (typ_of vector2) && - not (is_constant_vec_typ env (typ_of sub1) || is_constant_vec_typ env (typ_of sub2)) -> - let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in - let (size1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in - let midsize = nminus size size1 in begin - match solve_unique env midsize with - | Some c -> - let midtyp = bitvector_typ (nconstant c) order in - E_app (append, - [e1; - E_aux (E_typ (midtyp, - E_aux (E_app (mk_id "subrange_subrange_concat", - [vector1; start1; end1; vector2; start2; end2]), - (Unknown,empty_tannot))),(Unknown,empty_tannot))]) - | _ -> - E_app (append, - [e1; - E_aux (E_app (mk_id "subrange_subrange_concat", - [vector1; start1; end1; vector2; start2; end2]), - (Unknown,empty_tannot))]) - end - | [E_aux (E_app (append, - [e1; - E_aux (E_app (slice1, - [vector1; start1; length1]),_)]),_); - E_aux (E_app (slice2, - [vector2; start2; length2]),_)] - when is_append append && is_slice slice1 && is_slice slice2 && - is_constant_vec_typ env (typ_of e1) && - is_bitvector_typ (typ_of vector1) && is_bitvector_typ (typ_of vector2) && - not (is_constant length1 || is_constant length2) -> - let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in - let (size1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in - let midsize = nminus size size1 in begin - match solve_unique env midsize with - | Some c -> - let midtyp = bitvector_typ (nconstant c) order in - E_app (append, - [e1; - E_aux (E_typ (midtyp, - E_aux (E_app (mk_id "slice_slice_concat", - [vector1; start1; length1; vector2; start2; length2]), - (Unknown,empty_tannot))),(Unknown,empty_tannot))]) - | _ -> - E_app (append, - [e1; - E_aux (E_app (mk_id "slice_slice_concat", - [vector1; start1; length1; vector2; start2; length2]), - (Unknown,empty_tannot))]) - end - - (* variable-slice @ zeros *) - | [E_aux (E_app (op, [vector1; start1; len1]),_) as exp1; zeros_exp] - when (is_slice op || is_subrange op) && is_zeros_exp zeros_exp - && is_bitvector_typ (typ_of vector1) - && not (is_constant_vec_typ env (typ_of exp1) && is_constant_vec_typ env (typ_of zeros_exp)) -> - let op' = if is_subrange op then "place_subrange" else "place_slice" in - begin match get_zeros_exp_len zeros_exp with - | Some zlen -> try_cast_to_typ (mk_exp (E_app (mk_id op', [vector1; start1; len1; zlen]))) - | None -> E_app (id, args) - end - - (* ones @ zeros *) - | [E_aux (E_app (ones1, [len1]), _); zeros_exp] - when is_ones ones1 && is_zeros_exp zeros_exp && - not (is_constant len1 && is_constant_vec_typ env (typ_of zeros_exp)) -> - begin match get_zeros_exp_len zeros_exp with - | Some zlen -> try_cast_to_typ (mk_exp (E_app (mk_id "slice_mask", [zlen; len1]))) - | None -> E_app (id, args) + let defs, splits = List.split (List.map add_to_def defs) in + (!success, defs, List.concat splits) + +module MonoRewrites = struct + let is_constant_range = function E_aux (E_lit _, _), E_aux (E_lit _, _) -> true | _ -> false + + let is_constant = function E_aux (E_lit _, _) -> true | _ -> false + + let get_constant_vec_len ?(solve = false) env typ = + let typ = Env.base_typ_of env typ in + match destruct_bitvector env typ with + | Some (size, _) -> begin + match nexp_simp size with + | Nexp_aux (Nexp_constant i, _) -> Some i + | nexp when solve -> solve_unique env nexp + | _ -> None end + | _ -> None + + let is_constant_vec_typ env typ = get_constant_vec_len env typ <> None + + let is_zeros env id = is_id env (Id "Zeros") id || is_id env (Id "zeros") id || is_id env (Id "sail_zeros") id + + let is_zero_extend env id = + is_id env (Id "ZeroExtend") id || is_id env (Id "zero_extend") id || is_id env (Id "sail_zero_extend") id + || is_id env (Id "mips_zero_extend") id || is_id env (Id "EXTZ") id - (* ones @ variable *) - | [E_aux (E_app (ones1, [len1]), _); (E_aux (E_id _,_) as vector2)] - when is_ones ones1 - && not (is_constant len1) -> - let one = mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))) in - let len2 = mk_exp (E_app (mk_id "length", [vector2])) in - let total = mk_exp (E_app_infix (len1, mk_id "+", len2)) in - try_cast_to_typ - (E_aux (E_app (mk_id "update_subrange_bits", - [E_aux (E_app (ones1, [total]), (Unknown,empty_tannot)); - mk_exp (E_app_infix (len2, mk_id "-", one)); - mk_exp (E_lit (mk_lit (L_num Big_int.zero))); - vector2]), - (Unknown, empty_tannot))) - - (* variable-range @ variable-range *) - | [E_aux (E_app (subrange1, - [vector1; start1; end1]),_) as exp1; - E_aux (E_app (subrange2, - [vector2; start2; end2]),_) as exp2] - when is_subrange subrange1 && is_subrange subrange2 && - is_bitvector_typ (typ_of vector1) && is_bitvector_typ (typ_of vector2) && - not (is_constant_vec_typ env (typ_of exp1) || is_constant_vec_typ env (typ_of exp2)) -> - try_cast_to_typ - (E_aux (E_app (mk_id "subrange_subrange_concat", - [vector1; start1; end1; vector2; start2; end2]), - (Unknown,empty_tannot))) - - (* variable-slice @ variable-slice *) - | [E_aux (E_app (slice1, - [vector1; start1; length1]),_); - E_aux (E_app (slice2, - [vector2; start2; length2]),_)] - when is_slice slice1 && is_slice slice2 && - is_bitvector_typ (typ_of vector1) && is_bitvector_typ (typ_of vector2) && - not (is_constant length1 || is_constant length2) -> - try_cast_to_typ - (E_aux (E_app (mk_id "slice_slice_concat", - [vector1; start1; length1; vector2; start2; length2]),(Unknown,empty_tannot))) - - (* variable-slice @ local-var *) - | [(E_aux (E_app (op, - [vector1; start1; length1]),_) as exp1); - (E_aux (E_id _,_) as vector2)] - when (is_slice op || is_subrange op) && is_bitvector_typ (typ_of vector1) && - not (is_constant_vec_typ env (typ_of exp1)) -> - let op' = if is_subrange op then "subrange_subrange_concat" else "slice_slice_concat" in - let zero = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in - let one = mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))) in - let length2 = mk_exp (E_app (mk_id "length", [vector2])) in - let indices2 = - if is_subrange op then - [mk_exp (E_app_infix (length2, mk_id "-", one)); zero] - else - [zero; length2] - in - try_cast_to_typ - (E_aux (E_app (mk_id op', - [vector1; start1; length1; vector2] @ indices2),(Unknown,empty_tannot))) - - | [E_aux (E_app (append1, - [e1; - (E_aux (E_app (op, [vector1; start1; length1]),_) as slice1)]),_); - E_aux (E_app (zeros1, [length2]),_)] - when is_append append1 && (is_slice op || is_subrange op) && is_zeros zeros1 && - is_constant_vec_typ env (typ_of e1) && is_bitvector_typ (typ_of vector1) && - not (is_constant_vec_typ env (typ_of slice1) || is_constant length2) -> - let op' = mk_id (if is_subrange op then "subrange_zeros_concat" else "slice_zeros_concat") in - let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in - let (size1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in - let midsize = nminus size size1 in begin - match solve_unique env midsize with - | Some c -> - let midtyp = bitvector_typ (nconstant c) order in - try_cast_to_typ - (E_aux (E_app (mk_id "append", - [e1; - E_aux (E_typ (midtyp, - E_aux (E_app (op', - [vector1; start1; length1; length2]),(Unknown,empty_tannot))),(Unknown,empty_tannot))]), - (Unknown,empty_tannot))) - | _ -> - try_cast_to_typ - (E_aux (E_app (mk_id "append", - [e1; - E_aux (E_app (op', - [vector1; start1; length1; length2]),(Unknown,empty_tannot))]), - (Unknown,empty_tannot))) - end - - (* known-length @ (known-length @ var-length) *) - | [e1; E_aux (E_app (append1, [e2; e3]), _)] - when is_append append1 && is_constant_vec_typ env (typ_of e1) && - is_constant_vec_typ env (typ_of e2) && - not (is_constant_vec_typ env (typ_of e3)) -> - let (size1,order,bittyp) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in - let (size2,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e2)) in - let size12 = nexp_simp (nsum size1 size2) in - let tannot12 = mk_tannot env (bitvector_typ size12 order) in - E_app (id, [E_aux (E_app (append1, [e1; e2]), (Unknown, tannot12)); e3]) - - | _ -> E_app (id,args) - - else if is_id env (Id "vector_update_subrange") id then - match args with - | [vec1; start1; end1; (E_aux (E_app (subrange2, [vec2; start2; end2]), _) as e2)] - when is_subrange subrange2 && not (is_constant_vec_typ env (typ_of e2)) -> - let op = - if is_number (typ_of vec2) then "vector_update_subrange_from_integer_subrange" else - "vector_update_subrange_from_subrange" - in - try_cast_to_typ (E_aux (E_app (mk_id op, [vec1; start1; end1; vec2; start2; end2]), (Unknown, empty_tannot))) - - | [vec1; start1; end1; (E_aux (E_app (zeros, _), _) as e2)] - when is_zeros zeros && not (is_constant_vec_typ env (typ_of e2)) -> - try_cast_to_typ (E_aux (E_app (mk_id "set_subrange_zeros", [vec1; start1; end1]), (Unknown, empty_tannot))) - - | _ -> E_app (id, args) - - else if is_id env (Id "eq_bits") id || is_id env (Id "neq_bits") id then - (* variable-range == variable_range *) - let wrap e = - if is_id env (Id "neq_bits") id - then E_app (mk_id "not_bool", [mk_exp e]) - else e + let eq_exp_conservative (E_aux (exp1, _)) (E_aux (exp2, _)) = + match (exp1, exp2) with E_id id1, E_id id2 -> true | E_lit lit1, E_lit lit2 -> lit_eq' lit1 lit2 | _ -> false + + (* We have to add casts in here with appropriate length information so that the + type checker knows the expected return types. *) + + let rec rewrite_app env typ (id, args) = + let is_append = is_id env (Id "append") in + let is_subrange = is_id env (Id "vector_subrange") in + let is_integer_subrange = is_id env (Id "integer_subrange") in + let is_slice = is_id env (Id "slice") in + let is_zeros id = is_zeros env id in + let is_ones id = is_id env (Id "Ones") id || is_id env (Id "ones") id || is_id env (Id "sail_ones") id in + let is_zero_extend = is_zero_extend env id in + let is_sign_extend = + is_id env (Id "SignExtend") id || is_id env (Id "sign_extend") id || is_id env (Id "sail_sign_extend") id + || is_id env (Id "mips_sign_extend") id || is_id env (Id "EXTS") id + in + let is_truncate = is_id env (Id "truncate") id in + let mk_exp e = E_aux (e, (Unknown, empty_tannot)) in + let rec is_zeros_exp e = + match unaux_exp e with + | E_app (zeros, [_]) when is_zeros zeros -> true + | E_lit (L_aux ((L_bin s | L_hex s), _)) -> List.for_all (fun c -> c = '0') (Util.string_to_list s) + | E_typ (_, e) -> is_zeros_exp e + | _ -> false + in + let rec get_zeros_exp_len e = + match unaux_exp e with + | E_app (zeros, [len]) when is_zeros zeros -> Some len + | E_typ (_, e) -> get_zeros_exp_len e + | _ -> ( + match get_constant_vec_len (env_of e) (typ_of e) with + | Some i -> Some (mk_exp (E_lit (L_aux (L_num i, Unknown)))) + | None -> None + ) in - match args with - | [E_aux (E_app (subrange1, - [vector1; start1; end1]),_); - E_aux (E_app (subrange2, - [vector2; start2; end2]),_)] - when is_subrange subrange1 && is_subrange subrange2 && - is_bitvector_typ (typ_of vector1) && is_bitvector_typ (typ_of vector2) && - not (is_constant_range (start1, end1) || is_constant_range (start2, end2)) -> - wrap (E_app (mk_id "subrange_subrange_eq", - [vector1; start1; end1; vector2; start2; end2])) - | [E_aux (E_app (slice1, - [vector1; len1; start1]),_); - E_aux (E_app (slice2, - [vector2; len2; start2]),_)] - when is_slice slice1 && is_slice slice2 && - is_bitvector_typ (typ_of vector1) && is_bitvector_typ (typ_of vector2) && - not (is_constant len1 && is_constant start1 && is_constant len2 && is_constant start2) -> - let upper start len = - mk_exp (E_app_infix (start, mk_id "+", - mk_exp (E_app_infix (len, mk_id "-", - mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))))))) - in - wrap (E_app (mk_id "subrange_subrange_eq", - [vector1; upper start1 len1; start1; vector2; upper start2 len2; start2])) - | [(E_aux (E_app (op, [vector1; start1; len1]), _) as e1); - E_aux (E_app (zeros2, _), _)] - when (is_slice op || is_subrange op) && is_zeros zeros2 - && not (is_constant_vec_typ env (typ_of e1)) && is_bitvector_typ (typ_of vector1) -> - let op' = if is_subrange op then "is_zero_subrange" else "is_zeros_slice" in - wrap (E_app (mk_id op', [vector1; start1; len1])) - - (* subrange == ones *) - | [E_aux (E_app (subrange1, [vector1; start1; end1]),_); - E_aux (E_app (ones2, [_]),_)] - when is_id env (Id "vector_subrange") subrange1 && is_bitvector_typ (typ_of vector1) && - not (is_constant_range (start1,end1)) -> - E_app (mk_id "is_ones_subrange", - [vector1; start1; end1]) - (* slice == ones *) - | [E_aux (E_app (slice1, [vector1; start1; len1]),_); - E_aux (E_app (ones2, [_]),_)] - when is_slice slice1 && not (is_constant len1) && is_bitvector_typ (typ_of vector1) -> - E_app (mk_id "is_ones_slice", [vector1; start1; len1]) - - (* Arm specs sometimes check for overflows on values that can be either 32 or 64 bits - by converting to unbounded integers and asking for the top slice. *) - | [E_aux (E_app (op1, [vector1; start1; end1]),_); - E_aux (E_app (op2, [vector2; start2; end2]),_)] - when is_integer_subrange op1 && is_integer_subrange op2 && - is_constant start1 && is_constant start2 && - not (is_constant end1) && not (is_constant end2) -> - let zero = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in - wrap (E_app (mk_id "subrange_subrange_eq", - [mk_exp (E_app (mk_id "integer_subrange", [vector1; start1; zero])); - start1; - end1; - mk_exp (E_app (mk_id "integer_subrange", [vector2; start2; zero])); - start2; - end2])) - - | _ -> E_app (id,args) - - else if is_id env (Id "IsZero") id then - match args with - | [E_aux (E_app (subrange1, [vector1; start1; end1]),_)] - when (is_id env (Id "vector_subrange") subrange1) && is_bitvector_typ (typ_of vector1) && - not (is_constant_range (start1,end1)) -> - E_app (mk_id "is_zero_subrange", [vector1; start1; end1]) - | [E_aux (E_app (slice1, [vector1; start1; len1]),_)] - when (is_slice slice1) && is_bitvector_typ (typ_of vector1) && - not (is_constant len1) -> - E_app (mk_id "is_zeros_slice", [vector1; start1; len1]) - | _ -> E_app (id,args) - - else if is_id env (Id "IsOnes") id then - match args with - | [E_aux (E_app (subrange1, [vector1; start1; end1]),_)] - when is_id env (Id "vector_subrange") subrange1 && is_bitvector_typ (typ_of vector1) && - not (is_constant_range (start1,end1)) -> - E_app (mk_id "is_ones_subrange", - [vector1; start1; end1]) - | [E_aux (E_app (slice1, [vector1; start1; len1]),_)] - when is_slice slice1 && not (is_constant len1) && is_bitvector_typ (typ_of vector1) -> - E_app (mk_id "is_ones_slice", [vector1; start1; len1]) - | _ -> E_app (id,args) - - else if is_zero_extend || is_truncate then - let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in - match List.filter (fun arg -> not (is_number (typ_of arg))) args with - | [E_aux (E_app (append1, [E_aux (E_app (subrange1, [vector1; start1; end1]), _); zeros_exp]),_)] - when is_subrange subrange1 && is_zeros_exp zeros_exp && is_append append1 && is_bitvector_typ (typ_of vector1) -> - begin match get_zeros_exp_len zeros_exp with - | Some zlen -> + let try_cast_to_typ (E_aux (e, (l, _)) as exp) = + let size, order, bittyp = vector_typ_args_of (Env.base_typ_of env typ) in + (* vector_typ_args_of might simplify size, so rebuild the type even if it's constant *) + match size with + | Nexp_aux (Nexp_constant c, _) -> E_typ (bitvector_typ (nconstant c) order, exp) + | _ -> ( + match solve_unique env size with Some c -> E_typ (bitvector_typ (nconstant c) order, exp) | None -> e + ) + in + let rewrap e = E_aux (e, (Unknown, empty_tannot)) in + if is_append id then ( + match args with + (* (known-size-vector @ variable-vector) @ variable-vector *) + | [ + E_aux (E_app (append, [e1; (E_aux (E_app (subrange1, [vector1; start1; end1]), _) as sub1)]), _); + (E_aux (E_app (subrange2, [vector2; start2; end2]), _) as sub2); + ] + when is_append append && is_subrange subrange1 && is_subrange subrange2 + && is_constant_vec_typ env (typ_of e1) + && is_bitvector_typ (typ_of vector1) + && is_bitvector_typ (typ_of vector2) + && not (is_constant_vec_typ env (typ_of sub1) || is_constant_vec_typ env (typ_of sub2)) -> + let size, order, bittyp = vector_typ_args_of (Env.base_typ_of env typ) in + let size1, _, _ = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in + let midsize = nminus size size1 in + begin + match solve_unique env midsize with + | Some c -> + let midtyp = bitvector_typ (nconstant c) order in + E_app + ( append, + [ + e1; + E_aux + ( E_typ + ( midtyp, + E_aux + ( E_app + (mk_id "subrange_subrange_concat", [vector1; start1; end1; vector2; start2; end2]), + (Unknown, empty_tannot) + ) + ), + (Unknown, empty_tannot) + ); + ] + ) + | _ -> + E_app + ( append, + [ + e1; + E_aux + ( E_app (mk_id "subrange_subrange_concat", [vector1; start1; end1; vector2; start2; end2]), + (Unknown, empty_tannot) + ); + ] + ) + end + | [ + E_aux (E_app (append, [e1; E_aux (E_app (slice1, [vector1; start1; length1]), _)]), _); + E_aux (E_app (slice2, [vector2; start2; length2]), _); + ] + when is_append append && is_slice slice1 && is_slice slice2 + && is_constant_vec_typ env (typ_of e1) + && is_bitvector_typ (typ_of vector1) + && is_bitvector_typ (typ_of vector2) + && not (is_constant length1 || is_constant length2) -> + let size, order, bittyp = vector_typ_args_of (Env.base_typ_of env typ) in + let size1, _, _ = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in + let midsize = nminus size size1 in + begin + match solve_unique env midsize with + | Some c -> + let midtyp = bitvector_typ (nconstant c) order in + E_app + ( append, + [ + e1; + E_aux + ( E_typ + ( midtyp, + E_aux + ( E_app + (mk_id "slice_slice_concat", [vector1; start1; length1; vector2; start2; length2]), + (Unknown, empty_tannot) + ) + ), + (Unknown, empty_tannot) + ); + ] + ) + | _ -> + E_app + ( append, + [ + e1; + E_aux + ( E_app (mk_id "slice_slice_concat", [vector1; start1; length1; vector2; start2; length2]), + (Unknown, empty_tannot) + ); + ] + ) + end + (* variable-slice @ zeros *) + | [(E_aux (E_app (op, [vector1; start1; len1]), _) as exp1); zeros_exp] + when (is_slice op || is_subrange op) + && is_zeros_exp zeros_exp + && is_bitvector_typ (typ_of vector1) + && not (is_constant_vec_typ env (typ_of exp1) && is_constant_vec_typ env (typ_of zeros_exp)) -> + let op' = if is_subrange op then "place_subrange" else "place_slice" in + begin + match get_zeros_exp_len zeros_exp with + | Some zlen -> try_cast_to_typ (mk_exp (E_app (mk_id op', [vector1; start1; len1; zlen]))) + | None -> E_app (id, args) + end + (* ones @ zeros *) + | [E_aux (E_app (ones1, [len1]), _); zeros_exp] + when is_ones ones1 && is_zeros_exp zeros_exp + && not (is_constant len1 && is_constant_vec_typ env (typ_of zeros_exp)) -> begin + match get_zeros_exp_len zeros_exp with + | Some zlen -> try_cast_to_typ (mk_exp (E_app (mk_id "slice_mask", [zlen; len1]))) + | None -> E_app (id, args) + end + (* ones @ variable *) + | [E_aux (E_app (ones1, [len1]), _); (E_aux (E_id _, _) as vector2)] when is_ones ones1 && not (is_constant len1) + -> + let one = mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))) in + let len2 = mk_exp (E_app (mk_id "length", [vector2])) in + let total = mk_exp (E_app_infix (len1, mk_id "+", len2)) in + try_cast_to_typ + (E_aux + ( E_app + ( mk_id "update_subrange_bits", + [ + E_aux (E_app (ones1, [total]), (Unknown, empty_tannot)); + mk_exp (E_app_infix (len2, mk_id "-", one)); + mk_exp (E_lit (mk_lit (L_num Big_int.zero))); + vector2; + ] + ), + (Unknown, empty_tannot) + ) + ) + (* variable-range @ variable-range *) + | [ + (E_aux (E_app (subrange1, [vector1; start1; end1]), _) as exp1); + (E_aux (E_app (subrange2, [vector2; start2; end2]), _) as exp2); + ] + when is_subrange subrange1 && is_subrange subrange2 + && is_bitvector_typ (typ_of vector1) + && is_bitvector_typ (typ_of vector2) + && not (is_constant_vec_typ env (typ_of exp1) || is_constant_vec_typ env (typ_of exp2)) -> + try_cast_to_typ + (E_aux + ( E_app (mk_id "subrange_subrange_concat", [vector1; start1; end1; vector2; start2; end2]), + (Unknown, empty_tannot) + ) + ) + (* variable-slice @ variable-slice *) + | [E_aux (E_app (slice1, [vector1; start1; length1]), _); E_aux (E_app (slice2, [vector2; start2; length2]), _)] + when is_slice slice1 && is_slice slice2 + && is_bitvector_typ (typ_of vector1) + && is_bitvector_typ (typ_of vector2) + && not (is_constant length1 || is_constant length2) -> + try_cast_to_typ + (E_aux + ( E_app (mk_id "slice_slice_concat", [vector1; start1; length1; vector2; start2; length2]), + (Unknown, empty_tannot) + ) + ) + (* variable-slice @ local-var *) + | [(E_aux (E_app (op, [vector1; start1; length1]), _) as exp1); (E_aux (E_id _, _) as vector2)] + when (is_slice op || is_subrange op) + && is_bitvector_typ (typ_of vector1) + && not (is_constant_vec_typ env (typ_of exp1)) -> + let op' = if is_subrange op then "subrange_subrange_concat" else "slice_slice_concat" in + let zero = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in + let one = mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))) in + let length2 = mk_exp (E_app (mk_id "length", [vector2])) in + let indices2 = + if is_subrange op then [mk_exp (E_app_infix (length2, mk_id "-", one)); zero] else [zero; length2] + in + try_cast_to_typ + (E_aux (E_app (mk_id op', [vector1; start1; length1; vector2] @ indices2), (Unknown, empty_tannot))) + | [ + E_aux (E_app (append1, [e1; (E_aux (E_app (op, [vector1; start1; length1]), _) as slice1)]), _); + E_aux (E_app (zeros1, [length2]), _); + ] + when is_append append1 + && (is_slice op || is_subrange op) + && is_zeros zeros1 + && is_constant_vec_typ env (typ_of e1) + && is_bitvector_typ (typ_of vector1) + && not (is_constant_vec_typ env (typ_of slice1) || is_constant length2) -> + let op' = mk_id (if is_subrange op then "subrange_zeros_concat" else "slice_zeros_concat") in + let size, order, bittyp = vector_typ_args_of (Env.base_typ_of env typ) in + let size1, _, _ = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in + let midsize = nminus size size1 in + begin + match solve_unique env midsize with + | Some c -> + let midtyp = bitvector_typ (nconstant c) order in + try_cast_to_typ + (E_aux + ( E_app + ( mk_id "append", + [ + e1; + E_aux + ( E_typ + ( midtyp, + E_aux (E_app (op', [vector1; start1; length1; length2]), (Unknown, empty_tannot)) + ), + (Unknown, empty_tannot) + ); + ] + ), + (Unknown, empty_tannot) + ) + ) + | _ -> + try_cast_to_typ + (E_aux + ( E_app + ( mk_id "append", + [e1; E_aux (E_app (op', [vector1; start1; length1; length2]), (Unknown, empty_tannot))] + ), + (Unknown, empty_tannot) + ) + ) + end + (* known-length @ (known-length @ var-length) *) + | [e1; E_aux (E_app (append1, [e2; e3]), _)] + when is_append append1 + && is_constant_vec_typ env (typ_of e1) + && is_constant_vec_typ env (typ_of e2) + && not (is_constant_vec_typ env (typ_of e3)) -> + let size1, order, bittyp = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in + let size2, _, _ = vector_typ_args_of (Env.base_typ_of env (typ_of e2)) in + let size12 = nexp_simp (nsum size1 size2) in + let tannot12 = mk_tannot env (bitvector_typ size12 order) in + E_app (id, [E_aux (E_app (append1, [e1; e2]), (Unknown, tannot12)); e3]) + | _ -> E_app (id, args) + ) + else if is_id env (Id "vector_update_subrange") id then ( + match args with + | [vec1; start1; end1; (E_aux (E_app (subrange2, [vec2; start2; end2]), _) as e2)] + when is_subrange subrange2 && not (is_constant_vec_typ env (typ_of e2)) -> + let op = + if is_number (typ_of vec2) then "vector_update_subrange_from_integer_subrange" + else "vector_update_subrange_from_subrange" + in + try_cast_to_typ (E_aux (E_app (mk_id op, [vec1; start1; end1; vec2; start2; end2]), (Unknown, empty_tannot))) + | [vec1; start1; end1; (E_aux (E_app (zeros, _), _) as e2)] + when is_zeros zeros && not (is_constant_vec_typ env (typ_of e2)) -> + try_cast_to_typ (E_aux (E_app (mk_id "set_subrange_zeros", [vec1; start1; end1]), (Unknown, empty_tannot))) + | _ -> E_app (id, args) + ) + else if is_id env (Id "eq_bits") id || is_id env (Id "neq_bits") id then ( + (* variable-range == variable_range *) + let wrap e = if is_id env (Id "neq_bits") id then E_app (mk_id "not_bool", [mk_exp e]) else e in + match args with + | [E_aux (E_app (subrange1, [vector1; start1; end1]), _); E_aux (E_app (subrange2, [vector2; start2; end2]), _)] + when is_subrange subrange1 && is_subrange subrange2 + && is_bitvector_typ (typ_of vector1) + && is_bitvector_typ (typ_of vector2) + && not (is_constant_range (start1, end1) || is_constant_range (start2, end2)) -> + wrap (E_app (mk_id "subrange_subrange_eq", [vector1; start1; end1; vector2; start2; end2])) + | [E_aux (E_app (slice1, [vector1; len1; start1]), _); E_aux (E_app (slice2, [vector2; len2; start2]), _)] + when is_slice slice1 && is_slice slice2 + && is_bitvector_typ (typ_of vector1) + && is_bitvector_typ (typ_of vector2) + && not (is_constant len1 && is_constant start1 && is_constant len2 && is_constant start2) -> + let upper start len = + mk_exp + (E_app_infix + ( start, + mk_id "+", + mk_exp (E_app_infix (len, mk_id "-", mk_exp (E_lit (mk_lit (L_num (Big_int.of_int 1)))))) + ) + ) + in + wrap + (E_app + (mk_id "subrange_subrange_eq", [vector1; upper start1 len1; start1; vector2; upper start2 len2; start2]) + ) + | [(E_aux (E_app (op, [vector1; start1; len1]), _) as e1); E_aux (E_app (zeros2, _), _)] + when (is_slice op || is_subrange op) + && is_zeros zeros2 + && (not (is_constant_vec_typ env (typ_of e1))) + && is_bitvector_typ (typ_of vector1) -> + let op' = if is_subrange op then "is_zero_subrange" else "is_zeros_slice" in + wrap (E_app (mk_id op', [vector1; start1; len1])) + (* subrange == ones *) + | [E_aux (E_app (subrange1, [vector1; start1; end1]), _); E_aux (E_app (ones2, [_]), _)] + when is_id env (Id "vector_subrange") subrange1 + && is_bitvector_typ (typ_of vector1) + && not (is_constant_range (start1, end1)) -> + E_app (mk_id "is_ones_subrange", [vector1; start1; end1]) + (* slice == ones *) + | [E_aux (E_app (slice1, [vector1; start1; len1]), _); E_aux (E_app (ones2, [_]), _)] + when is_slice slice1 && (not (is_constant len1)) && is_bitvector_typ (typ_of vector1) -> + E_app (mk_id "is_ones_slice", [vector1; start1; len1]) + (* Arm specs sometimes check for overflows on values that can be either 32 or 64 bits + by converting to unbounded integers and asking for the top slice. *) + | [E_aux (E_app (op1, [vector1; start1; end1]), _); E_aux (E_app (op2, [vector2; start2; end2]), _)] + when is_integer_subrange op1 && is_integer_subrange op2 && is_constant start1 && is_constant start2 + && (not (is_constant end1)) + && not (is_constant end2) -> + let zero = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in + wrap + (E_app + ( mk_id "subrange_subrange_eq", + [ + mk_exp (E_app (mk_id "integer_subrange", [vector1; start1; zero])); + start1; + end1; + mk_exp (E_app (mk_id "integer_subrange", [vector2; start2; zero])); + start2; + end2; + ] + ) + ) + | _ -> E_app (id, args) + ) + else if is_id env (Id "IsZero") id then ( + match args with + | [E_aux (E_app (subrange1, [vector1; start1; end1]), _)] + when is_id env (Id "vector_subrange") subrange1 + && is_bitvector_typ (typ_of vector1) + && not (is_constant_range (start1, end1)) -> + E_app (mk_id "is_zero_subrange", [vector1; start1; end1]) + | [E_aux (E_app (slice1, [vector1; start1; len1]), _)] + when is_slice slice1 && is_bitvector_typ (typ_of vector1) && not (is_constant len1) -> + E_app (mk_id "is_zeros_slice", [vector1; start1; len1]) + | _ -> E_app (id, args) + ) + else if is_id env (Id "IsOnes") id then ( + match args with + | [E_aux (E_app (subrange1, [vector1; start1; end1]), _)] + when is_id env (Id "vector_subrange") subrange1 + && is_bitvector_typ (typ_of vector1) + && not (is_constant_range (start1, end1)) -> + E_app (mk_id "is_ones_subrange", [vector1; start1; end1]) + | [E_aux (E_app (slice1, [vector1; start1; len1]), _)] + when is_slice slice1 && (not (is_constant len1)) && is_bitvector_typ (typ_of vector1) -> + E_app (mk_id "is_ones_slice", [vector1; start1; len1]) + | _ -> E_app (id, args) + ) + else if is_zero_extend || is_truncate then ( + let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in + match List.filter (fun arg -> not (is_number (typ_of arg))) args with + | [E_aux (E_app (append1, [E_aux (E_app (subrange1, [vector1; start1; end1]), _); zeros_exp]), _)] + when is_subrange subrange1 && is_zeros_exp zeros_exp && is_append append1 && is_bitvector_typ (typ_of vector1) + -> begin + match get_zeros_exp_len zeros_exp with + | Some zlen -> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange", length_arg @ [vector1; start1; end1; zlen]))) - | None -> E_app (id, args) - end - - | [E_aux (E_app (append1, [vector1; zeros_exp]),_)] - when is_constant_vec_typ env (typ_of vector1) && is_zeros_exp zeros_exp && is_append append1 -> - begin match get_zeros_exp_len zeros_exp with - | Some zlen -> - let (vector1, start1, length1) = + | None -> E_app (id, args) + end + | [E_aux (E_app (append1, [vector1; zeros_exp]), _)] + when is_constant_vec_typ env (typ_of vector1) && is_zeros_exp zeros_exp && is_append append1 -> begin + match get_zeros_exp_len zeros_exp with + | Some zlen -> + let vector1, start1, length1 = match vector1 with - | E_aux (E_app (slice1, [vector1; start1; length1]), _) -> - (vector1, start1, length1) + | E_aux (E_app (slice1, [vector1; start1; length1]), _) -> (vector1, start1, length1) | _ -> - let (length1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of vector1)) in - (vector1, mk_exp (E_lit (mk_lit (L_num (Big_int.zero)))), mk_exp (E_sizeof length1)) + let length1, _, _ = vector_typ_args_of (Env.base_typ_of env (typ_of vector1)) in + (vector1, mk_exp (E_lit (mk_lit (L_num Big_int.zero))), mk_exp (E_sizeof length1)) in try_cast_to_typ (rewrap (E_app (mk_id "place_slice", length_arg @ [vector1; start1; length1; zlen]))) - | None -> E_app (id, args) + | None -> E_app (id, args) end - - (* If we've already rewritten to slice_slice_concat or subrange_subrange_concat, - we can just drop the zero extension because those functions can do it - themselves *) - | [E_aux (E_typ (_, (E_aux (E_app (Id_aux ((Id "slice_slice_concat" | Id "subrange_subrange_concat" | Id "place_slice" | Id "place_subrange"),_) as op, args),_))),_)] - -> try_cast_to_typ (rewrap (E_app (op, length_arg @ args))) - - | [E_aux (E_app (Id_aux ((Id "slice_slice_concat" | Id "subrange_subrange_concat" | Id "place_slice" | Id "place_subrange"),_) as op, args),_)] - -> try_cast_to_typ (rewrap (E_app (op, length_arg @ args))) - - | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] - when is_slice slice1 && not (is_constant length1) && is_bitvector_typ (typ_of vector1) -> - try_cast_to_typ (rewrap (E_app (mk_id "zext_slice", length_arg @ [vector1; start1; length1]))) - - | [E_aux (E_app (subrange1, [vector1; hi1; lo1]),_)] - when is_subrange subrange1 && not (is_constant hi1 && is_constant lo1) && is_bitvector_typ (typ_of vector1) -> - try_cast_to_typ (rewrap (E_app (mk_id "zext_subrange", length_arg @ [vector1; hi1; lo1]))) - - | [E_aux (E_app (ones, [len1]),_)] when is_ones ones -> - try_cast_to_typ (rewrap (E_app (mk_id "zext_ones", length_arg @ [len1]))) - - | [E_aux (E_app (replicate_bits, [E_aux (E_lit (L_aux (L_bin "1", _)), _); len1]), _)] + (* If we've already rewritten to slice_slice_concat or subrange_subrange_concat, + we can just drop the zero extension because those functions can do it + themselves *) + | [ + E_aux + ( E_typ + ( _, + E_aux + ( E_app + ( ( Id_aux + ( ( Id "slice_slice_concat" + | Id "subrange_subrange_concat" + | Id "place_slice" + | Id "place_subrange" ), + _ + ) as op + ), + args + ), + _ + ) + ), + _ + ); + ] -> + try_cast_to_typ (rewrap (E_app (op, length_arg @ args))) + | [ + E_aux + ( E_app + ( ( Id_aux + ( (Id "slice_slice_concat" | Id "subrange_subrange_concat" | Id "place_slice" | Id "place_subrange"), + _ + ) as op + ), + args + ), + _ + ); + ] -> + try_cast_to_typ (rewrap (E_app (op, length_arg @ args))) + | [E_aux (E_app (slice1, [vector1; start1; length1]), _)] + when is_slice slice1 && (not (is_constant length1)) && is_bitvector_typ (typ_of vector1) -> + try_cast_to_typ (rewrap (E_app (mk_id "zext_slice", length_arg @ [vector1; start1; length1]))) + | [E_aux (E_app (subrange1, [vector1; hi1; lo1]), _)] + when is_subrange subrange1 && (not (is_constant hi1 && is_constant lo1)) && is_bitvector_typ (typ_of vector1) -> + try_cast_to_typ (rewrap (E_app (mk_id "zext_subrange", length_arg @ [vector1; hi1; lo1]))) + | [E_aux (E_app (ones, [len1]), _)] when is_ones ones -> + try_cast_to_typ (rewrap (E_app (mk_id "zext_ones", length_arg @ [len1]))) + | [E_aux (E_app (replicate_bits, [E_aux (E_lit (L_aux (L_bin "1", _)), _); len1]), _)] when is_id env (Id "replicate_bits") replicate_bits -> - let start1 = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in - try_cast_to_typ (rewrap (E_app (mk_id "slice_mask", length_arg @ [start1; len1]))) - - | [E_aux (E_app (zeros, [len1]),_)] - | [E_aux (E_typ (_, E_aux (E_app (zeros, [len1]),_)), _)] - when is_zeros zeros -> - try_cast_to_typ (rewrap (E_app (zeros, length_arg))) - - | _ -> E_app (id,args) - - else if is_sign_extend then - let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in - match List.filter (fun arg -> not (is_number (typ_of arg))) args with - | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] - when is_slice slice1 && not (is_constant length1) && is_bitvector_typ (typ_of vector1) -> - try_cast_to_typ (rewrap (E_app (mk_id "sext_slice", length_arg @ [vector1; start1; length1]))) - - | [E_aux (E_app (subrange1, [vector1; hi1; lo1]),_) as exp1] - when is_subrange subrange1 && not (is_constant_vec_typ env (typ_of exp1)) - && is_bitvector_typ (typ_of vector1) -> - try_cast_to_typ (rewrap (E_app (mk_id "sext_subrange", length_arg @ [vector1; hi1; lo1]))) - - | [E_aux (E_app (append, [E_aux (E_app (op, [vector1; start1; len1]), _); zeros_exp]), _)] - when is_append append && (is_slice op || is_subrange op) && is_zeros_exp zeros_exp - && is_bitvector_typ (typ_of vector1) - && not (is_constant len1 && is_constant_vec_typ env (typ_of zeros_exp)) -> - let op' = if is_subrange op then "place_subrange_signed" else "place_slice_signed" in - begin match get_zeros_exp_len zeros_exp with - | Some zlen -> E_app (mk_id op', length_arg @ [vector1; start1; len1; zlen]) - | None -> E_app (id, args) - end - - | [E_aux (E_typ (_, (E_aux (E_app (Id_aux ((Id "place_slice"),_), args),_))),_)] - | [E_aux (E_app (Id_aux ((Id "place_slice"),_), args),_)] - -> try_cast_to_typ (rewrap (E_app (mk_id "place_slice_signed", length_arg @ args))) - - | [E_aux (E_typ (_, (E_aux (E_app (Id_aux ((Id "place_subrange"),_), args),_))),_)] - | [E_aux (E_app (Id_aux ((Id "place_subrange"),_), args),_)] - -> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange_signed", length_arg @ args))) - + let start1 = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in + try_cast_to_typ (rewrap (E_app (mk_id "slice_mask", length_arg @ [start1; len1]))) + | ([E_aux (E_app (zeros, [len1]), _)] | [E_aux (E_typ (_, E_aux (E_app (zeros, [len1]), _)), _)]) + when is_zeros zeros -> + try_cast_to_typ (rewrap (E_app (zeros, length_arg))) + | _ -> E_app (id, args) + ) + else if is_sign_extend then ( + let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in + match List.filter (fun arg -> not (is_number (typ_of arg))) args with + | [E_aux (E_app (slice1, [vector1; start1; length1]), _)] + when is_slice slice1 && (not (is_constant length1)) && is_bitvector_typ (typ_of vector1) -> + try_cast_to_typ (rewrap (E_app (mk_id "sext_slice", length_arg @ [vector1; start1; length1]))) + | [(E_aux (E_app (subrange1, [vector1; hi1; lo1]), _) as exp1)] + when is_subrange subrange1 && (not (is_constant_vec_typ env (typ_of exp1))) && is_bitvector_typ (typ_of vector1) + -> + try_cast_to_typ (rewrap (E_app (mk_id "sext_subrange", length_arg @ [vector1; hi1; lo1]))) + | [E_aux (E_app (append, [E_aux (E_app (op, [vector1; start1; len1]), _); zeros_exp]), _)] + when is_append append + && (is_slice op || is_subrange op) + && is_zeros_exp zeros_exp + && is_bitvector_typ (typ_of vector1) + && not (is_constant len1 && is_constant_vec_typ env (typ_of zeros_exp)) -> + let op' = if is_subrange op then "place_subrange_signed" else "place_slice_signed" in + begin + match get_zeros_exp_len zeros_exp with + | Some zlen -> E_app (mk_id op', length_arg @ [vector1; start1; len1; zlen]) + | None -> E_app (id, args) + end + | [E_aux (E_typ (_, E_aux (E_app (Id_aux (Id "place_slice", _), args), _)), _)] + | [E_aux (E_app (Id_aux (Id "place_slice", _), args), _)] -> + try_cast_to_typ (rewrap (E_app (mk_id "place_slice_signed", length_arg @ args))) + | [E_aux (E_typ (_, E_aux (E_app (Id_aux (Id "place_subrange", _), args), _)), _)] + | [E_aux (E_app (Id_aux (Id "place_subrange", _), args), _)] -> + try_cast_to_typ (rewrap (E_app (mk_id "place_subrange_signed", length_arg @ args))) (* If the original had a length, keep it *) - (* | [E_aux (E_app (slice1, [vector1; start1; length1]),_);length2] - when is_slice slice1 && not (is_constant length1) -> - begin - match Type_check.destruct_atom_nexp (env_of length2) (typ_of length2) with - | None -> E_app (mk_id "sext_slice", [vector1; start1; length1]) - | Some nlen -> - let (_,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in - E_typ (vector_typ nlen order bittyp, - E_aux (E_app (mk_id "sext_slice", [vector1; start1; length1]), - (Unknown,empty_tannot))) - end *) - - | _ -> E_app (id,args) - - else if is_id env (Id "Extend") id then - match args with - | [vector; len; unsigned] -> - let extz = mk_exp (rewrite_app env typ (mk_id "ZeroExtend", [vector; len])) in - let exts = mk_exp (rewrite_app env typ (mk_id "SignExtend", [vector; len])) in - E_if (unsigned, extz, exts) - | _ -> E_app (id, args) - - else if is_id env (Id "UInt") id || is_id env (Id "unsigned") id then - match args with - | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] - when is_slice slice1 && not (is_constant length1) && is_bitvector_typ (typ_of vector1) -> - E_app (mk_id "unsigned_slice", [vector1; start1; length1]) - | [E_aux (E_app (subrange1, [vector1; start1; end1]),_)] - when is_subrange subrange1 && not (is_constant_range (start1,end1)) && is_bitvector_typ (typ_of vector1) -> - E_app (mk_id "unsigned_subrange", [vector1; start1; end1]) - - | [E_aux (E_app (append, [vector1; zeros2]), _)] - when is_append append && is_zeros_exp zeros2 && not (is_constant_vec_typ env (typ_of zeros2)) -> - begin match get_zeros_exp_len zeros2 with - | Some len -> - E_app (mk_id "shl_int", [E_aux (E_app (id, [vector1]), (Unknown, empty_tannot)); len]) - | None -> E_app (id, args) - end - - | _ -> E_app (id,args) - - else if is_id env (Id "__SetSlice_bits") id || is_id env (Id "SetSlice") id then - match args with - | [len; slice_len; vector; start; E_aux (E_app (zeros, _), _)] - when is_zeros zeros && is_bitvector_typ (typ_of vector) -> - E_app (mk_id "set_slice_zeros", [len; vector; start; slice_len]) - | _ -> E_app (id, args) - - else if is_id env (Id "Replicate") id then - let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in - match List.filter (fun arg -> not (is_number (typ_of arg))) args with - | [E_aux (E_lit (L_aux (L_bin "0", _)), _)] -> - E_app (mk_id "sail_zeros", length_arg) - | [E_aux (E_lit (L_aux (L_bin "1", _)), _)] -> - E_app (mk_id "sail_ones", length_arg) - | _ -> E_app (id, args) - - (* Turn constant-length subranges into slices, making the constant length more explicit, - e.g. turning x[i+1 .. i] into slice(x, i, 2) *) - else if is_subrange id then - match get_constant_vec_len ~solve:true env typ, args with - | Some i, [vector1; start1; end1] - when is_bitvector_typ (typ_of vector1) && not (is_constant start1 && is_constant end1) -> - let inc = is_inc_vec (typ_of vector1) in - let low = if inc then start1 else end1 in - let exp' = rewrap (E_app (mk_id "slice", [vector1; low; mk_exp (E_lit (mk_lit (L_num i)))])) in - E_typ (bitvector_typ (nconstant i) (if inc then inc_ord else dec_ord), exp') - | _, _ -> E_app (id, args) - - (* Rewrite (v[x .. y] + i) to (v + (i << y))[x .. y], which is more amenable to further rewriting *) - else if is_id env (Id "add_bits_int") id then - match args with - | [E_aux (E_app (subrange1, [vec1; start1; end1]), a) as exp1; exp2] - when is_subrange subrange1 && is_bitvector_typ (typ_of vec1) - && not (is_constant_vec_typ env (typ_of exp1)) -> - let low = if is_inc_vec (typ_of vec1) then start1 else end1 in - let exp2' = mk_exp (E_app (mk_id "shl_int", [exp2; low])) in - let vec1' = E_aux (E_app (id, [vec1; exp2']), a) in - E_app (subrange1, [vec1'; start1; end1]) - | _ -> E_app (id, args) - - (* Similarly for bitwise operations *) - else if is_id env (Id "and_vec") id || - is_id env (Id "or_vec") id || - is_id env (Id "xor_vec") id then - match args with - | [E_aux (E_app (subrange1, [vec1; start1; end1]), a1) as exp1; - E_aux (E_app (subrange2, [vec2; start2; end2]), a2)] - when is_subrange subrange1 && is_bitvector_typ (typ_of vec1) && - is_subrange subrange2 && is_bitvector_typ (typ_of vec2) && - not (is_constant_vec_typ env (typ_of exp1)) && - eq_exp_conservative start1 start2 && - eq_exp_conservative end1 end2 - -> - E_app (subrange1, [check_exp env (strip_exp (mk_exp (E_app (id, [vec1; vec2])))) (typ_of vec1); start1; end1]) - - | _ -> E_app (id, args) - - else if is_id env (Id "string_of_bits") id then - match args with - | [E_aux (E_app (subrange1, [vec1; start1; end1]), a1) as exp1] - when is_subrange subrange1 && is_bitvector_typ (typ_of vec1) && - not (is_constant_vec_typ env (typ_of exp1)) - -> - E_app (mk_id "string_of_bits_subrange", [vec1; start1; end1]) - | _ -> E_app (id, args) - - else E_app (id,args) - -(* A deeper rewrite may have removed the type information, so try reinferring it *) -let base_typ_of_with_infer env (E_aux (_, (l, tannot)) as exp) = - let typ = - match destruct_tannot tannot with - | Some (_, typ) -> typ - | None -> - typ_of (infer_exp env (strip_exp exp)) - in Env.base_typ_of env typ - -let rec rewrite_aux = function - | E_app (id,args), (l, tannot) -> - begin match destruct_tannot tannot with - | Some (env, ty) -> - E_aux (rewrite_app env ty (id,args), (l, tannot)) - | None -> E_aux (E_app (id, args), (l, tannot)) - end - | E_assign ( - LE_aux (LE_vector_range (LE_aux (LE_id id1,(l_id1,_)), start1, end1),_), - E_aux (E_app (subrange2, [vector2; start2; end2]),(l_assign,_))), - annot - when is_id (env_of_annot annot) (Id "vector_subrange") subrange2 && - not (is_constant_range (start1, end1)) -> - let typ2 = base_typ_of_with_infer (env_of_annot annot) vector2 in - let op = - if is_number typ2 then "vector_update_subrange_from_integer_subrange" else - "vector_update_subrange_from_subrange" - in - E_aux (E_assign (LE_aux (LE_id id1,(l_id1,empty_tannot)), - E_aux (E_app (mk_id op, [ - E_aux (E_id id1,(Generated l_id1,empty_tannot)); - start1; end1; - vector2; start2; end2]),(Unknown,empty_tannot))), - (l_assign, empty_tannot)) - | E_assign (LE_aux (LE_vector_range (LE_aux (LE_id id1, annot1), start1, end1), _), - E_aux (E_app (zeros, _), _)), annot - when is_zeros (env_of_annot annot) zeros -> - let lhs = LE_aux (LE_id id1, annot1) in - let rhs = E_aux (E_app (mk_id "set_subrange_zeros", [E_aux (E_id id1, annot1); start1; end1]), annot1) in - E_aux (E_assign (lhs, rhs), annot) - - | E_assign (LE_aux (LE_vector_range (lexp1, start1, end1), _), - E_aux (E_app (zero_extend, zero_extend_args), _)), (l, tannot) - when is_zero_extend (env_of_tannot tannot) zero_extend && not (is_constant_range (start1, end1)) -> - let new_annot = (Generated l, empty_tannot) in - let vector = List.find (fun exp -> is_bitvector_typ (typ_of exp)) zero_extend_args in - let len = E_aux (E_app (mk_id "length", [vector]), new_annot) in - let mid_point_high = E_aux (E_app_infix (end1, mk_id "+", len), new_annot) in - let mid_point_low = E_aux (E_app_infix ( - mid_point_high, - mk_id "-", - E_aux (E_lit (mk_lit (L_num (Big_int.of_int 1))),new_annot) - ), new_annot) - in - let with_zeros = E_aux (E_app (mk_id "set_subrange_zeros", [lexp_to_exp lexp1; start1; mid_point_high]), new_annot) in - E_aux (E_block [ + (* | [E_aux (E_app (slice1, [vector1; start1; length1]),_);length2] + when is_slice slice1 && not (is_constant length1) -> + begin + match Type_check.destruct_atom_nexp (env_of length2) (typ_of length2) with + | None -> E_app (mk_id "sext_slice", [vector1; start1; length1]) + | Some nlen -> + let (_,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in + E_typ (vector_typ nlen order bittyp, + E_aux (E_app (mk_id "sext_slice", [vector1; start1; length1]), + (Unknown,empty_tannot))) + end *) + | _ -> E_app (id, args) + ) + else if is_id env (Id "Extend") id then ( + match args with + | [vector; len; unsigned] -> + let extz = mk_exp (rewrite_app env typ (mk_id "ZeroExtend", [vector; len])) in + let exts = mk_exp (rewrite_app env typ (mk_id "SignExtend", [vector; len])) in + E_if (unsigned, extz, exts) + | _ -> E_app (id, args) + ) + else if is_id env (Id "UInt") id || is_id env (Id "unsigned") id then ( + match args with + | [E_aux (E_app (slice1, [vector1; start1; length1]), _)] + when is_slice slice1 && (not (is_constant length1)) && is_bitvector_typ (typ_of vector1) -> + E_app (mk_id "unsigned_slice", [vector1; start1; length1]) + | [E_aux (E_app (subrange1, [vector1; start1; end1]), _)] + when is_subrange subrange1 && (not (is_constant_range (start1, end1))) && is_bitvector_typ (typ_of vector1) -> + E_app (mk_id "unsigned_subrange", [vector1; start1; end1]) + | [E_aux (E_app (append, [vector1; zeros2]), _)] + when is_append append && is_zeros_exp zeros2 && not (is_constant_vec_typ env (typ_of zeros2)) -> begin + match get_zeros_exp_len zeros2 with + | Some len -> E_app (mk_id "shl_int", [E_aux (E_app (id, [vector1]), (Unknown, empty_tannot)); len]) + | None -> E_app (id, args) + end + | _ -> E_app (id, args) + ) + else if is_id env (Id "__SetSlice_bits") id || is_id env (Id "SetSlice") id then ( + match args with + | [len; slice_len; vector; start; E_aux (E_app (zeros, _), _)] + when is_zeros zeros && is_bitvector_typ (typ_of vector) -> + E_app (mk_id "set_slice_zeros", [len; vector; start; slice_len]) + | _ -> E_app (id, args) + ) + else if is_id env (Id "Replicate") id then ( + let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in + match List.filter (fun arg -> not (is_number (typ_of arg))) args with + | [E_aux (E_lit (L_aux (L_bin "0", _)), _)] -> E_app (mk_id "sail_zeros", length_arg) + | [E_aux (E_lit (L_aux (L_bin "1", _)), _)] -> E_app (mk_id "sail_ones", length_arg) + | _ -> E_app (id, args) + (* Turn constant-length subranges into slices, making the constant length more explicit, + e.g. turning x[i+1 .. i] into slice(x, i, 2) *) + ) + else if is_subrange id then ( + match (get_constant_vec_len ~solve:true env typ, args) with + | Some i, [vector1; start1; end1] + when is_bitvector_typ (typ_of vector1) && not (is_constant start1 && is_constant end1) -> + let inc = is_inc_vec (typ_of vector1) in + let low = if inc then start1 else end1 in + let exp' = rewrap (E_app (mk_id "slice", [vector1; low; mk_exp (E_lit (mk_lit (L_num i)))])) in + E_typ (bitvector_typ (nconstant i) (if inc then inc_ord else dec_ord), exp') + | _, _ -> + E_app (id, args) + (* Rewrite (v[x .. y] + i) to (v + (i << y))[x .. y], which is more amenable to further rewriting *) + ) + else if is_id env (Id "add_bits_int") id then ( + match args with + | [(E_aux (E_app (subrange1, [vec1; start1; end1]), a) as exp1); exp2] + when is_subrange subrange1 && is_bitvector_typ (typ_of vec1) && not (is_constant_vec_typ env (typ_of exp1)) -> + let low = if is_inc_vec (typ_of vec1) then start1 else end1 in + let exp2' = mk_exp (E_app (mk_id "shl_int", [exp2; low])) in + let vec1' = E_aux (E_app (id, [vec1; exp2']), a) in + E_app (subrange1, [vec1'; start1; end1]) + | _ -> E_app (id, args) (* Similarly for bitwise operations *) + ) + else if is_id env (Id "and_vec") id || is_id env (Id "or_vec") id || is_id env (Id "xor_vec") id then ( + match args with + | [ + (E_aux (E_app (subrange1, [vec1; start1; end1]), a1) as exp1); E_aux (E_app (subrange2, [vec2; start2; end2]), a2); + ] + when is_subrange subrange1 + && is_bitvector_typ (typ_of vec1) + && is_subrange subrange2 + && is_bitvector_typ (typ_of vec2) + && (not (is_constant_vec_typ env (typ_of exp1))) + && eq_exp_conservative start1 start2 && eq_exp_conservative end1 end2 -> + E_app (subrange1, [check_exp env (strip_exp (mk_exp (E_app (id, [vec1; vec2])))) (typ_of vec1); start1; end1]) + | _ -> E_app (id, args) + ) + else if is_id env (Id "string_of_bits") id then ( + match args with + | [(E_aux (E_app (subrange1, [vec1; start1; end1]), a1) as exp1)] + when is_subrange subrange1 && is_bitvector_typ (typ_of vec1) && not (is_constant_vec_typ env (typ_of exp1)) -> + E_app (mk_id "string_of_bits_subrange", [vec1; start1; end1]) + | _ -> E_app (id, args) + ) + else E_app (id, args) + + (* A deeper rewrite may have removed the type information, so try reinferring it *) + let base_typ_of_with_infer env (E_aux (_, (l, tannot)) as exp) = + let typ = + match destruct_tannot tannot with Some (_, typ) -> typ | None -> typ_of (infer_exp env (strip_exp exp)) + in + Env.base_typ_of env typ + + let rec rewrite_aux = function + | E_app (id, args), (l, tannot) -> begin + match destruct_tannot tannot with + | Some (env, ty) -> E_aux (rewrite_app env ty (id, args), (l, tannot)) + | None -> E_aux (E_app (id, args), (l, tannot)) + end + | ( E_assign + ( LE_aux (LE_vector_range (LE_aux (LE_id id1, (l_id1, _)), start1, end1), _), + E_aux (E_app (subrange2, [vector2; start2; end2]), (l_assign, _)) + ), + annot ) + when is_id (env_of_annot annot) (Id "vector_subrange") subrange2 && not (is_constant_range (start1, end1)) -> + let typ2 = base_typ_of_with_infer (env_of_annot annot) vector2 in + let op = + if is_number typ2 then "vector_update_subrange_from_integer_subrange" + else "vector_update_subrange_from_subrange" + in + E_aux + ( E_assign + ( LE_aux (LE_id id1, (l_id1, empty_tannot)), + E_aux + ( E_app + ( mk_id op, + [E_aux (E_id id1, (Generated l_id1, empty_tannot)); start1; end1; vector2; start2; end2] + ), + (Unknown, empty_tannot) + ) + ), + (l_assign, empty_tannot) + ) + | ( E_assign (LE_aux (LE_vector_range (LE_aux (LE_id id1, annot1), start1, end1), _), E_aux (E_app (zeros, _), _)), + annot ) + when is_zeros (env_of_annot annot) zeros -> + let lhs = LE_aux (LE_id id1, annot1) in + let rhs = E_aux (E_app (mk_id "set_subrange_zeros", [E_aux (E_id id1, annot1); start1; end1]), annot1) in + E_aux (E_assign (lhs, rhs), annot) + | ( E_assign (LE_aux (LE_vector_range (lexp1, start1, end1), _), E_aux (E_app (zero_extend, zero_extend_args), _)), + (l, tannot) ) + when is_zero_extend (env_of_tannot tannot) zero_extend && not (is_constant_range (start1, end1)) -> + let new_annot = (Generated l, empty_tannot) in + let vector = List.find (fun exp -> is_bitvector_typ (typ_of exp)) zero_extend_args in + let len = E_aux (E_app (mk_id "length", [vector]), new_annot) in + let mid_point_high = E_aux (E_app_infix (end1, mk_id "+", len), new_annot) in + let mid_point_low = + E_aux + ( E_app_infix (mid_point_high, mk_id "-", E_aux (E_lit (mk_lit (L_num (Big_int.of_int 1))), new_annot)), + new_annot + ) + in + let with_zeros = + E_aux (E_app (mk_id "set_subrange_zeros", [lexp_to_exp lexp1; start1; mid_point_high]), new_annot) + in + E_aux + ( E_block + [ E_aux (E_assign (lexp1, with_zeros), new_annot); - E_aux (E_assign (LE_aux (LE_vector_range (lexp1, mid_point_low, end1), new_annot), - vector), new_annot) - ], new_annot) - - | (E_let (LB_aux (LB_val (P_aux ((P_id id | P_typ (_, P_aux (P_id id, _))), _), - (E_aux (E_app (subrange1, [vec1; start1; end1]), _) as exp1)), _), - exp2) as e_aux), annot - when is_id (env_of_annot annot) (Id "vector_subrange") subrange1 - && not (is_constant_vec_typ (env_of_annot annot) (typ_of exp1))-> - let open Spec_analysis in - let depends1 = ids_in_exp exp1 in - let assigned2 = IdSet.union (assigned_vars exp2) (bound_vars exp2) in - if IdSet.is_empty (IdSet.inter depends1 assigned2) then rewrite_exp (subst id exp1 exp2) else - E_aux (e_aux, annot) - | e_aux, annot -> E_aux (e_aux, annot) - -and rewrite_exp exp = Rewriter.fold_exp { Rewriter.id_exp_alg with e_aux = rewrite_aux } exp - -let mono_rewrite defs = - let open Rewriter in - rewrite_ast_base - { rewriters_base with - rewrite_exp = fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux } } - defs + E_aux (E_assign (LE_aux (LE_vector_range (lexp1, mid_point_low, end1), new_annot), vector), new_annot); + ], + new_annot + ) + | ( ( E_let + ( LB_aux + ( LB_val + ( P_aux ((P_id id | P_typ (_, P_aux (P_id id, _))), _), + (E_aux (E_app (subrange1, [vec1; start1; end1]), _) as exp1) + ), + _ + ), + exp2 + ) as e_aux + ), + annot ) + when is_id (env_of_annot annot) (Id "vector_subrange") subrange1 + && not (is_constant_vec_typ (env_of_annot annot) (typ_of exp1)) -> + let open Spec_analysis in + let depends1 = ids_in_exp exp1 in + let assigned2 = IdSet.union (assigned_vars exp2) (bound_vars exp2) in + if IdSet.is_empty (IdSet.inter depends1 assigned2) then rewrite_exp (subst id exp1 exp2) + else E_aux (e_aux, annot) + | e_aux, annot -> E_aux (e_aux, annot) + + and rewrite_exp exp = Rewriter.fold_exp { Rewriter.id_exp_alg with e_aux = rewrite_aux } exp + + let mono_rewrite defs = + let open Rewriter in + rewrite_ast_base + { rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) } + defs end -module BitvectorSizeCasts = -struct - -let simplify_size_nexp env quant_kids nexp = - let rec aux (Nexp_aux (ne,l) as nexp) = - match solve_unique env nexp with - | Some n -> Some (nconstant n) - | None -> - let is_equal kid = - prove __POS__ env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) - in - match List.find is_equal quant_kids with - | kid -> Some (Nexp_aux (Nexp_var kid,Generated l)) - | exception Not_found -> - (* Normally rewriting of complex nexps in function signatures will - produce a simple constant or variable above, but occasionally it's - useful to work when that rewriting hasn't been applied. In - particular, that rewriting isn't fully working with RISC-V at the - moment. *) - let re f = function - | Some n1, Some n2 -> Some (Nexp_aux (f n1 n2,l)) - | _ -> None - in - match ne with - | Nexp_times(n1,n2) -> - re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2) - | Nexp_sum(n1,n2) -> - re (fun n1 n2 -> Nexp_sum(n1,n2)) (aux n1, aux n2) - | Nexp_minus(n1,n2) -> - re (fun n1 n2 -> Nexp_minus(n1,n2)) (aux n1, aux n2) - | Nexp_exp n -> - Option.map (fun n -> Nexp_aux (Nexp_exp n,l)) (aux n) - | Nexp_neg n -> - Option.map (fun n -> Nexp_aux (Nexp_neg n,l)) (aux n) - | _ -> None - in aux nexp - -let specs_required = ref IdSet.empty -let check_for_spec env name = - let id = mk_id name in - match Env.get_val_spec id env with - | _ -> () - | exception _ -> specs_required := IdSet.add id !specs_required - -(* These functions add cast functions across case splits, so that when a - bitvector size becomes known in sail, the generated Lem code contains a - function call to change mword 'n to (say) mword ty16, and vice versa. *) -let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ = - let genunk = Generated Unknown in - let fresh = - let counter = ref 0 in - fun () -> - let n = !counter in - let () = counter := n+1 in - mk_id ("cast#" ^ string_of_int n) - in - let at_least_one = ref None in - let rec aux (Typ_aux (src_t,src_l) as src_typ) (Typ_aux (tar_t,tar_l) as tar_typ) = - let src_ann = mk_tannot env src_typ in - let tar_ann = mk_tannot env tar_typ in - match src_t, tar_t with - | Typ_tuple typs, Typ_tuple typs' -> - let ps,es = List.split (List.map2 aux typs typs') in - P_aux (P_typ (src_typ, P_aux (P_tuple ps,(Generated src_l, src_ann))),(Generated src_l, src_ann)), - E_aux (E_tuple es,(Generated tar_l, tar_ann)) - | Typ_app (Id_aux (Id "bitvector",_), - [A_aux (A_nexp size,_); _]), - Typ_app (Id_aux (Id "bitvector",_) as t_id, - [A_aux (A_nexp size',l_size'); t_ord]) -> begin - match simplify_size_nexp env quant_kids size, simplify_size_nexp top_env quant_kids size' with - | Some size, Some size' when Nexp.compare size size' <> 0 -> - let var = fresh () in - let tar_typ' = Typ_aux (Typ_app (t_id, [A_aux (A_nexp size',l_size');t_ord]), - tar_l) in - let () = at_least_one := Some tar_typ' in - P_aux (P_id var,(Generated src_l,src_ann)), - E_aux - (E_typ (tar_typ', - E_aux (E_app (Id_aux (Id cast_name, genunk), - [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))), - (genunk, tar_ann)) - | _ -> +module BitvectorSizeCasts = struct + let simplify_size_nexp env quant_kids nexp = + let rec aux (Nexp_aux (ne, l) as nexp) = + match solve_unique env nexp with + | Some n -> Some (nconstant n) + | None -> ( + let is_equal kid = prove __POS__ env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid, Unknown), nexp), Unknown)) in + match List.find is_equal quant_kids with + | kid -> Some (Nexp_aux (Nexp_var kid, Generated l)) + | exception Not_found -> ( + (* Normally rewriting of complex nexps in function signatures will + produce a simple constant or variable above, but occasionally it's + useful to work when that rewriting hasn't been applied. In + particular, that rewriting isn't fully working with RISC-V at the + moment. *) + let re f = function Some n1, Some n2 -> Some (Nexp_aux (f n1 n2, l)) | _ -> None in + match ne with + | Nexp_times (n1, n2) -> re (fun n1 n2 -> Nexp_times (n1, n2)) (aux n1, aux n2) + | Nexp_sum (n1, n2) -> re (fun n1 n2 -> Nexp_sum (n1, n2)) (aux n1, aux n2) + | Nexp_minus (n1, n2) -> re (fun n1 n2 -> Nexp_minus (n1, n2)) (aux n1, aux n2) + | Nexp_exp n -> Option.map (fun n -> Nexp_aux (Nexp_exp n, l)) (aux n) + | Nexp_neg n -> Option.map (fun n -> Nexp_aux (Nexp_neg n, l)) (aux n) + | _ -> None + ) + ) + in + aux nexp + + let specs_required = ref IdSet.empty + let check_for_spec env name = + let id = mk_id name in + match Env.get_val_spec id env with _ -> () | exception _ -> specs_required := IdSet.add id !specs_required + + (* These functions add cast functions across case splits, so that when a + bitvector size becomes known in sail, the generated Lem code contains a + function call to change mword 'n to (say) mword ty16, and vice versa. *) + let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ = + let genunk = Generated Unknown in + let fresh = + let counter = ref 0 in + fun () -> + let n = !counter in + let () = counter := n + 1 in + mk_id ("cast#" ^ string_of_int n) + in + let at_least_one = ref None in + let rec aux (Typ_aux (src_t, src_l) as src_typ) (Typ_aux (tar_t, tar_l) as tar_typ) = + let src_ann = mk_tannot env src_typ in + let tar_ann = mk_tannot env tar_typ in + match (src_t, tar_t) with + | Typ_tuple typs, Typ_tuple typs' -> + let ps, es = List.split (List.map2 aux typs typs') in + ( P_aux (P_typ (src_typ, P_aux (P_tuple ps, (Generated src_l, src_ann))), (Generated src_l, src_ann)), + E_aux (E_tuple es, (Generated tar_l, tar_ann)) + ) + | ( Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp size, _); _]), + Typ_app ((Id_aux (Id "bitvector", _) as t_id), [A_aux (A_nexp size', l_size'); t_ord]) ) -> begin + match (simplify_size_nexp env quant_kids size, simplify_size_nexp top_env quant_kids size') with + | Some size, Some size' when Nexp.compare size size' <> 0 -> + let var = fresh () in + let tar_typ' = Typ_aux (Typ_app (t_id, [A_aux (A_nexp size', l_size'); t_ord]), tar_l) in + let () = at_least_one := Some tar_typ' in + ( P_aux (P_id var, (Generated src_l, src_ann)), + E_aux + ( E_typ + ( tar_typ', + E_aux + ( E_app (Id_aux (Id cast_name, genunk), [E_aux (E_id var, (genunk, src_ann))]), + (genunk, tar_ann) + ) + ), + (genunk, tar_ann) + ) + ) + | _ -> + let var = fresh () in + (P_aux (P_id var, (Generated src_l, src_ann)), E_aux (E_id var, (Generated src_l, tar_ann))) + end + | _ -> let var = fresh () in - P_aux (P_id var,(Generated src_l,src_ann)), - E_aux (E_id var,(Generated src_l,tar_ann)) + (P_aux (P_id var, (Generated src_l, src_ann)), E_aux (E_id var, (Generated src_l, tar_ann))) + in + let src_typ' = Env.base_typ_of env src_typ in + let target_typ' = Env.base_typ_of env target_typ in + let pat, e' = aux src_typ' target_typ' in + match !at_least_one with + | Some one_target_typ -> begin + check_for_spec env cast_name; + let src_ann = mk_tannot env src_typ in + let tar_ann = mk_tannot env target_typ in + let asg_ann = mk_tannot env unit_typ in + match src_typ' with + (* Simple case with just the bitvector; don't need to pull apart value *) + | Typ_aux (Typ_app _, _) -> + ( (fun var exp -> + let exp_ann = mk_tannot env (typ_of exp) in + E_aux + ( E_let + ( LB_aux + ( LB_val + ( P_aux (P_typ (one_target_typ, P_aux (P_id var, (genunk, tar_ann))), (genunk, tar_ann)), + E_aux + ( E_app (Id_aux (Id cast_name, genunk), [E_aux (E_id var, (genunk, src_ann))]), + (genunk, tar_ann) + ) + ), + (genunk, tar_ann) + ), + exp + ), + (genunk, exp_ann) + ) + ), + (fun var -> + [ + E_aux + ( E_assign + ( LE_aux (LE_typ (one_target_typ, var), (genunk, tar_ann)), + E_aux + ( E_app (Id_aux (Id cast_name, genunk), [E_aux (E_id var, (genunk, src_ann))]), + (genunk, tar_ann) + ) + ), + (genunk, asg_ann) + ); + ] + ), + fun (E_aux (_, (exp_l, exp_ann)) as exp) -> + E_aux + ( E_typ + (one_target_typ, E_aux (E_app (Id_aux (Id cast_name, genunk), [exp]), (Generated exp_l, tar_ann))), + (Generated exp_l, tar_ann) + ) + ) + | _ -> + ( (fun var exp -> + let exp_ann = mk_tannot env (typ_of exp) in + E_aux + ( E_let + ( LB_aux (LB_val (pat, E_aux (E_id var, (genunk, src_ann))), (genunk, src_ann)), + E_aux + ( E_let (LB_aux (LB_val (P_aux (P_id var, (genunk, tar_ann)), e'), (genunk, tar_ann)), exp), + (genunk, exp_ann) + ) + ), + (genunk, exp_ann) + ) + ), + (fun var -> + [ + E_aux + ( E_let + ( LB_aux (LB_val (pat, E_aux (E_id var, (genunk, src_ann))), (genunk, src_ann)), + E_aux + (E_assign (LE_aux (LE_typ (one_target_typ, var), (genunk, tar_ann)), e'), (genunk, asg_ann)) + ), + (genunk, asg_ann) + ); + ] + ), + fun (E_aux (_, (exp_l, exp_ann)) as exp) -> + E_aux (E_let (LB_aux (LB_val (pat, exp), (Generated exp_l, exp_ann)), e'), (Generated exp_l, tar_ann)) + ) end - | _ -> - let var = fresh () in - P_aux (P_id var,(Generated src_l,src_ann)), - E_aux (E_id var,(Generated src_l,tar_ann)) - in - let src_typ' = Env.base_typ_of env src_typ in - let target_typ' = Env.base_typ_of env target_typ in - let pat, e' = aux src_typ' target_typ' in - match !at_least_one with - | Some one_target_typ -> begin - check_for_spec env cast_name; - let src_ann = mk_tannot env src_typ in - let tar_ann = mk_tannot env target_typ in - let asg_ann = mk_tannot env unit_typ in - match src_typ' with - (* Simple case with just the bitvector; don't need to pull apart value *) - | Typ_aux (Typ_app _,_) -> - (fun var exp -> - let exp_ann = mk_tannot env (typ_of exp) in - E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (one_target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)), - E_aux (E_app (Id_aux (Id cast_name,genunk), - [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)), - exp),(genunk,exp_ann))), - (fun var -> - [E_aux (E_assign (LE_aux (LE_typ (one_target_typ, var),(genunk,tar_ann)), - E_aux (E_app (Id_aux (Id cast_name,genunk), - [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann) - )),(genunk,asg_ann))]), - (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> - E_aux (E_typ (one_target_typ, - E_aux (E_app (Id_aux (Id cast_name, genunk), [exp]), (Generated exp_l,tar_ann))), - - (Generated exp_l,tar_ann))) - | _ -> - (fun var exp -> - let exp_ann = mk_tannot env (typ_of exp) in - E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id var,(genunk,src_ann))),(genunk,src_ann)), - E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,(genunk,tar_ann)),e'),(genunk,tar_ann)), - exp),(genunk,exp_ann))),(genunk,exp_ann))), - (fun var -> - [E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id var,(genunk,src_ann))),(genunk,src_ann)), - E_aux (E_assign (LE_aux (LE_typ (one_target_typ, var),(genunk,tar_ann)), - e'),(genunk,asg_ann))),(genunk,asg_ann))]), - (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> - E_aux (E_let (LB_aux (LB_val (pat, exp),(Generated exp_l,exp_ann)), e'),(Generated exp_l,tar_ann))) - end - | None -> (fun _ e -> e),(fun _ -> []),(fun e -> e) -let make_bitvector_cast_let cast_name top_env env quant_kids src_typ target_typ = - let f,_,_ = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ - in f -let make_bitvector_cast_assign cast_name top_env env quant_kids src_typ target_typ = - let _,f,_ = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ - in f -let make_bitvector_cast_cast cast_name top_env env quant_kids src_typ target_typ = - let _,_,f = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ - in f - -let make_bitvector_env_casts top_env env quant_kids insts exp = - let mk_cast var typ exp = (make_bitvector_cast_let "bitvector_cast_in" env top_env quant_kids typ (subst_kids_typ insts typ)) var exp in - let mk_assign_in var typ = - make_bitvector_cast_assign "bitvector_cast_in" env top_env quant_kids typ - (subst_kids_typ insts typ) var - in - let mk_assign_out var typ = - make_bitvector_cast_assign "bitvector_cast_out" top_env env quant_kids - (subst_kids_typ insts typ) typ var - in - let locals = Env.get_locals env in - let used_ids = ids_in_exp exp in - let locals = Bindings.filter (fun id _ -> IdSet.mem id used_ids) locals in - let immutables,mutables = Bindings.partition (fun _ (mut,_) -> mut = Immutable) locals in - let assigns_in = Bindings.fold (fun var (_,typ) acc -> mk_assign_in var typ @ acc) mutables [] in - let assigns_out = Bindings.fold (fun var (_,typ) acc -> mk_assign_out var typ @ acc) mutables [] in - let exp = match assigns_in, exp with - | [], _ -> exp - | _::_, E_aux (E_block es,ann) -> E_aux (E_block (assigns_in @ es @ assigns_out),ann) - | _::_, E_aux (_,(l,ann)) -> - E_aux (E_block (assigns_in @ [exp] @ assigns_out), (Generated l,ann)) - in - let add_immutables exp = - Bindings.fold (fun var (mut,typ) exp -> - if mut = Immutable then mk_cast var typ exp else exp) immutables exp - in add_immutables exp - -let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp = - if alpha_equivalent (env_of exp) typ target_typ then exp else - let infer_arg_typ env f l typ = - let (typq, ctor_typ) = Env.get_union_id f env in - match Env.expand_synonyms env ctor_typ with - | Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> - begin - let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in - let unifiers = unify l env goals ret_typ typ in - let arg_typ' = subst_unifiers unifiers arg_typ in - arg_typ' - end - | _ -> typ_error env l ("Malformed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + | None -> ((fun _ e -> e), (fun _ -> []), fun e -> e) + let make_bitvector_cast_let cast_name top_env env quant_kids src_typ target_typ = + let f, _, _ = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ in + f + let make_bitvector_cast_assign cast_name top_env env quant_kids src_typ target_typ = + let _, f, _ = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ in + f + let make_bitvector_cast_cast cast_name top_env env quant_kids src_typ target_typ = + let _, _, f = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ in + f + + let make_bitvector_env_casts top_env env quant_kids insts exp = + let mk_cast var typ exp = + (make_bitvector_cast_let "bitvector_cast_in" env top_env quant_kids typ (subst_kids_typ insts typ)) var exp + in + let mk_assign_in var typ = + make_bitvector_cast_assign "bitvector_cast_in" env top_env quant_kids typ (subst_kids_typ insts typ) var + in + let mk_assign_out var typ = + make_bitvector_cast_assign "bitvector_cast_out" top_env env quant_kids (subst_kids_typ insts typ) typ var + in + let locals = Env.get_locals env in + let used_ids = ids_in_exp exp in + let locals = Bindings.filter (fun id _ -> IdSet.mem id used_ids) locals in + let immutables, mutables = Bindings.partition (fun _ (mut, _) -> mut = Immutable) locals in + let assigns_in = Bindings.fold (fun var (_, typ) acc -> mk_assign_in var typ @ acc) mutables [] in + let assigns_out = Bindings.fold (fun var (_, typ) acc -> mk_assign_out var typ @ acc) mutables [] in + let exp = + match (assigns_in, exp) with + | [], _ -> exp + | _ :: _, E_aux (E_block es, ann) -> E_aux (E_block (assigns_in @ es @ assigns_out), ann) + | _ :: _, E_aux (_, (l, ann)) -> E_aux (E_block (assigns_in @ [exp] @ assigns_out), (Generated l, ann)) + in + let add_immutables exp = + Bindings.fold (fun var (mut, typ) exp -> if mut = Immutable then mk_cast var typ exp else exp) immutables exp + in + add_immutables exp + + let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp = + if alpha_equivalent (env_of exp) typ target_typ then exp + else ( + let infer_arg_typ env f l typ = + let typq, ctor_typ = Env.get_union_id f env in + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> begin + let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in + let unifiers = unify l env goals ret_typ typ in + let arg_typ' = subst_unifiers unifiers arg_typ in + arg_typ' + end + | _ -> typ_error env l ("Malformed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + in - in - (* Push the cast down, including through constructors *) - let rec aux exp (typ, target_typ) = - if alpha_equivalent (env_of exp) typ target_typ then exp else - let exp_env = env_of exp in - match exp with - | E_aux (E_let (lb,exp'),ann) -> - E_aux (E_let (lb,aux exp' (typ, target_typ)),ann) - | E_aux (E_var (lexp, bind, exp'),ann) -> - E_aux (E_var (lexp, bind, aux exp' (typ, target_typ)), ann) - | E_aux (E_block exps,ann) -> - let exps' = match List.rev exps with - | [] -> [] - | final::l -> aux final (typ, target_typ)::l - in E_aux (E_block (List.rev exps'),ann) - | E_aux (E_tuple exps,(l,ann)) -> begin - match Env.expand_synonyms exp_env typ, Env.expand_synonyms exp_env target_typ with - | Typ_aux (Typ_tuple src_typs,_), Typ_aux (Typ_tuple tgt_typs,_) -> - E_aux (E_tuple (List.map2 aux exps (List.combine src_typs tgt_typs)),(l,ann)) - | _ -> raise (Reporting.err_unreachable l __POS__ - ("Attempted to insert cast on tuple on non-tuple type: " ^ - string_of_typ typ ^ " to " ^ string_of_typ target_typ)) - end - | E_aux (E_app (f,args),(l,ann)) when Env.is_union_constructor f (env_of exp) -> - let arg = match args with [arg] -> arg | _ -> E_aux (E_tuple args, (l,empty_tannot)) in - let src_arg_typ = infer_arg_typ (env_of exp) f l typ in - let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in - E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann)) - | E_aux (E_internal_assume (nc, exp'), ann) -> - E_aux (E_internal_assume (nc, aux exp' (typ, target_typ)), ann) - | _ -> - (make_bitvector_cast_cast cast_name cast_env (env_of exp) quant_kids typ target_typ) exp - in - aux exp (typ, target_typ) + (* Push the cast down, including through constructors *) + let rec aux exp (typ, target_typ) = + if alpha_equivalent (env_of exp) typ target_typ then exp + else ( + let exp_env = env_of exp in + match exp with + | E_aux (E_let (lb, exp'), ann) -> E_aux (E_let (lb, aux exp' (typ, target_typ)), ann) + | E_aux (E_var (lexp, bind, exp'), ann) -> E_aux (E_var (lexp, bind, aux exp' (typ, target_typ)), ann) + | E_aux (E_block exps, ann) -> + let exps' = match List.rev exps with [] -> [] | final :: l -> aux final (typ, target_typ) :: l in + E_aux (E_block (List.rev exps'), ann) + | E_aux (E_tuple exps, (l, ann)) -> begin + match (Env.expand_synonyms exp_env typ, Env.expand_synonyms exp_env target_typ) with + | Typ_aux (Typ_tuple src_typs, _), Typ_aux (Typ_tuple tgt_typs, _) -> + E_aux (E_tuple (List.map2 aux exps (List.combine src_typs tgt_typs)), (l, ann)) + | _ -> + raise + (Reporting.err_unreachable l __POS__ + ("Attempted to insert cast on tuple on non-tuple type: " ^ string_of_typ typ ^ " to " + ^ string_of_typ target_typ + ) + ) + end + | E_aux (E_app (f, args), (l, ann)) when Env.is_union_constructor f (env_of exp) -> + let arg = match args with [arg] -> arg | _ -> E_aux (E_tuple args, (l, empty_tannot)) in + let src_arg_typ = infer_arg_typ (env_of exp) f l typ in + let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in + E_aux (E_app (f, [aux arg (src_arg_typ, tgt_arg_typ)]), (l, ann)) + | E_aux (E_internal_assume (nc, exp'), ann) -> E_aux (E_internal_assume (nc, aux exp' (typ, target_typ)), ann) + | _ -> (make_bitvector_cast_cast cast_name cast_env (env_of exp) quant_kids typ target_typ) exp + ) + in + aux exp (typ, target_typ) + ) -let rec extract_value_from_guard var (E_aux (e,_)) = - match e with - | E_app (op, ([E_aux (E_id var',_); E_aux (E_lit (L_aux (L_num i,_)),_)] | - [E_aux (E_lit (L_aux (L_num i,_)),_); E_aux (E_id var',_)])) + let rec extract_value_from_guard var (E_aux (e, _)) = + match e with + | E_app + ( op, + ( [E_aux (E_id var', _); E_aux (E_lit (L_aux (L_num i, _)), _)] + | [E_aux (E_lit (L_aux (L_num i, _)), _); E_aux (E_id var', _)] ) + ) when string_of_id op = "eq_int" && Id.compare var var' == 0 -> - Some i - | E_app (op, [e1;e2]) when string_of_id op = "and_bool" -> - (match extract_value_from_guard var e1 with - | Some i -> Some i - | None -> extract_value_from_guard var e2) - | _ -> None - -let fill_in_type env typ = - let tyvars = tyvars_of_typ typ in - let subst = KidSet.fold (fun kid subst -> - match Env.get_typ_var kid env with - | K_type - | K_order - | K_bool -> subst - | K_int -> - (match solve_unique env (nvar kid) with - | None -> subst - | Some n -> KBindings.add kid (nconstant n) subst)) tyvars KBindings.empty in - subst_kids_typ subst typ - -(* Extract the instantiations of kids resulting from an if or assert guard *) -let rec extract (E_aux (e,_)) = - match e with - | E_app (op, - ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); y] | - [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_)])) - when string_of_id op = "eq_int" -> - (match destruct_atom_nexp (env_of y) (typ_of y) with - | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)] - | _ -> []) - | E_app (op,[x;y]) - when string_of_id op = "eq_int" -> - (match destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y) with - | Some (Nexp_aux (Nexp_var kid,_)), Some (Nexp_aux (Nexp_constant i,_)) - | Some (Nexp_aux (Nexp_constant i,_)), Some (Nexp_aux (Nexp_var kid,_)) - -> [(kid,i)] - | _ -> []) - | E_app (op, [x;y]) when string_of_id op = "and_bool" -> - extract x @ extract y - | _ -> [] - -(* TODO: top-level patterns *) -(* TODO: proper environment tracking for variables. Currently we pretend that - we can print the type of a variable in the top-level environment, but in - practice they might be below a case split. Note that we'd also need to - provide some way for the Lem pretty printer to know what to use; currently - we just use two names for the cast, bitvector_cast_in and bitvector_cast_out, - to let the pretty printer know whether to use the top-level environment. *) -let add_bitvector_casts global_env ({ defs; _ } as ast) = - let rewrite_body id quant_kids top_env defining_eqns ret_typ exp = - - (* Extract instantiations from a guard, then see if that fills in some equations *) - let extract env exp = - let direct_insts = extract exp in - let direct_insts = List.fold_left (fun insts (kid,i) -> - KBindings.add kid (nconstant i) insts) KBindings.empty direct_insts in - KBindings.fold (fun k nexp new_insts -> - let nexp_subst = subst_kids_nexp direct_insts nexp in - if Nexp.compare nexp nexp_subst <> 0 then - let nexp_simp = simplify_size_nexp env quant_kids nexp_subst in - match nexp_simp with - | Some (Nexp_aux (Nexp_constant i, _) as nexp') -> KBindings.add k nexp' new_insts - | _ -> new_insts - else new_insts) defining_eqns direct_insts - in + Some i + | E_app (op, [e1; e2]) when string_of_id op = "and_bool" -> ( + match extract_value_from_guard var e1 with Some i -> Some i | None -> extract_value_from_guard var e2 + ) + | _ -> None - let rewrite_aux (e,ann) = - match e with - | E_match (E_aux (e',ann') as exp',cases) -> begin - let env = env_of_annot ann in - let result_typ = Env.base_typ_of env (typ_of_annot ann) in - let matched_typ = Env.base_typ_of env (typ_of_annot ann') in - match e',matched_typ with - | E_sizeof (Nexp_aux (Nexp_var kid,_)), _ - | _, Typ_aux (Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) -> - let map_case pexp = - let pat,guard,body,ann = destruct_pexp pexp in - let body = match pat, guard with - | P_aux (P_lit (L_aux (L_num i,_)),_), _ -> - (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *) - let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in - make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ - (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body) - | P_aux (P_id var,_), Some guard -> - (match extract_value_from_guard var guard with - | Some i -> - let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in - make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ - (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body) - | None -> body) - | P_aux (P_wild, (_, annot)), None -> - (* Similar to the literal case *) - begin match body, untyped_annot annot |> get_attribute "int_wildcard" with - | _, Some (_, s) -> - let i = Big_int.of_string s in - let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in - make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ - (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body) - | E_aux (E_internal_assume (NC_aux (NC_equal (Nexp_aux (Nexp_var kid', _), nexp), _) as nc, body'), assume_ann), _ when Kid.compare kid kid' == 0 -> - let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) nexp) env) result_typ in - let body'' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ - (make_bitvector_env_casts env (env_of body') quant_kids (KBindings.singleton kid nexp) body') - in E_aux (E_internal_assume (nc, body''), assume_ann) - | _ -> body - end - | _ -> - body - in - construct_pexp (pat, guard, body, ann) - in - E_aux (E_match (exp', List.map map_case cases),ann) - | _ -> E_aux (e,ann) - end - | E_if (e1,e2,e3) -> - let env = env_of_annot ann in - let result_typ = Env.base_typ_of env (typ_of_annot ann) in - let insts = extract env e1 in - let e2' = make_bitvector_env_casts env (env_of e2) quant_kids insts e2 in - let src_typ = subst_kids_typ insts result_typ in - let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in - (* Ask the type checker if only one value remains for any of kids in - the else branch. *) - let env3 = env_of e3 in - let insts3 = KBindings.fold (fun kid _ i3 -> - match Type_check.solve_unique env3 (nvar kid) with - | None -> i3 - | Some c -> KBindings.add kid (nconstant c) i3) - insts KBindings.empty - in - let e3' = make_bitvector_env_casts env (env_of e3) quant_kids insts3 e3 in - let src_typ3 = subst_kids_typ insts3 result_typ in - let e3' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ3 result_typ e3' in - E_aux (E_if (e1,e2',e3'), ann) - | E_return e' -> - E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) - | E_block es -> - let env = env_of_annot ann in - let result_typ = Env.base_typ_of env (typ_of_annot ann) in - let rec aux = function - | [] -> [] - | (E_aux (E_assert (assert_exp,msg),ann) as h)::t -> - (* Check the assertion for constraints that instantiate kids *) - let is_known_kid kid = KBindings.mem kid (Env.get_typ_vars env) in - begin match Type_check.assert_constraint env true assert_exp with - | Some nc when KidSet.for_all is_known_kid (tyvars_of_constraint nc) -> - (* If the type checker can extract constraints from the assertion - for pre-existing kids (not for those that are bound by the - assertion itself), then look at the environment after the - assertion to extract kid instantiations. *) - let env_post = Env.add_constraint nc env in - let check_inst kid insts = - (* First check if the given kid already had a fixed value previously. *) - let rec nc_fixes_kid nc = match unaux_constraint nc with - | NC_equal (Nexp_aux (Nexp_var kid', _), Nexp_aux (Nexp_constant _, _)) -> - Kid.compare kid kid' = 0 - | NC_and (_, _) -> List.exists nc_fixes_kid (constraint_conj nc) - | _ -> false - in - if List.exists nc_fixes_kid (Env.get_constraints env) then - insts - else - (* Otherwise ask the solver for a new, unique value *) - match solve_unique env_post (nvar kid) with - | Some n -> KBindings.add kid (nconstant n) insts - | None -> insts - | exception _ -> insts - in - let insts = KidSet.fold check_inst (tyvars_of_constraint nc) KBindings.empty in - if KBindings.is_empty insts then h :: (aux t) else begin - (* Propagate new instantiations and insert casts *) - let t' = aux t in - let et = E_aux (E_block t',ann) in - let et = make_bitvector_env_casts env env_post quant_kids insts et in - let src_typ = subst_kids_typ insts result_typ in - let et = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ et in - [h; et] - end - | _ -> h :: (aux t) - end - | h::t -> h::(aux t) - in E_aux (E_block (aux es),ann) - | _ -> E_aux (e,ann) + let fill_in_type env typ = + let tyvars = tyvars_of_typ typ in + let subst = + KidSet.fold + (fun kid subst -> + match Env.get_typ_var kid env with + | K_type | K_order | K_bool -> subst + | K_int -> ( + match solve_unique env (nvar kid) with None -> subst | Some n -> KBindings.add kid (nconstant n) subst + ) + ) + tyvars KBindings.empty in - let open Rewriter in - fold_exp - { id_exp_alg with - e_aux = rewrite_aux } exp - in - let rewrite_funcl (FCL_aux (FCL_funcl (id,pexp),((def_annot,_) as fcl_ann))) = - let l = def_annot.loc in - let (tq,typ) = Env.get_val_spec_orig id global_env in - let fun_env = List.fold_right (Env.add_typ_var l) (quant_kopts tq) global_env in - let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in + subst_kids_typ subst typ - let ret_typ = - match typ with - | Typ_aux (Typ_fn (_,ret),_) -> ret - | Typ_aux (_,l) as typ -> - raise (Reporting.err_unreachable l __POS__ - ("Function clause must have function type: " ^ string_of_typ typ ^ - " is not a function type")) - in - let pat,guard,body,annot = destruct_pexp pexp in - let rec strip_assumes = function - | E_aux (E_internal_assume (nc, e), ann) -> - let e', k = strip_assumes e in - e', fun e -> E_aux (E_internal_assume (nc, k e), ann) - | e -> e, fun e -> e - in - let body, restore_assumes = strip_assumes body in + (* Extract the instantiations of kids resulting from an if or assert guard *) + let rec extract (E_aux (e, _)) = + match e with + | E_app + ( op, + ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid, _)), _); y] | [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid, _)), _)]) + ) + when string_of_id op = "eq_int" -> ( + match destruct_atom_nexp (env_of y) (typ_of y) with + | Some (Nexp_aux (Nexp_constant i, _)) -> [(kid, i)] + | _ -> [] + ) + | E_app (op, [x; y]) when string_of_id op = "eq_int" -> ( + match (destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y)) with + | Some (Nexp_aux (Nexp_var kid, _)), Some (Nexp_aux (Nexp_constant i, _)) + | Some (Nexp_aux (Nexp_constant i, _)), Some (Nexp_aux (Nexp_var kid, _)) -> + [(kid, i)] + | _ -> [] + ) + | E_app (op, [x; y]) when string_of_id op = "and_bool" -> extract x @ extract y + | _ -> [] + + (* TODO: top-level patterns *) + (* TODO: proper environment tracking for variables. Currently we pretend that + we can print the type of a variable in the top-level environment, but in + practice they might be below a case split. Note that we'd also need to + provide some way for the Lem pretty printer to know what to use; currently + we just use two names for the cast, bitvector_cast_in and bitvector_cast_out, + to let the pretty printer know whether to use the top-level environment. *) + let add_bitvector_casts global_env ({ defs; _ } as ast) = + let rewrite_body id quant_kids top_env defining_eqns ret_typ exp = + (* Extract instantiations from a guard, then see if that fills in some equations *) + let extract env exp = + let direct_insts = extract exp in + let direct_insts = + List.fold_left (fun insts (kid, i) -> KBindings.add kid (nconstant i) insts) KBindings.empty direct_insts + in + KBindings.fold + (fun k nexp new_insts -> + let nexp_subst = subst_kids_nexp direct_insts nexp in + if Nexp.compare nexp nexp_subst <> 0 then ( + let nexp_simp = simplify_size_nexp env quant_kids nexp_subst in + match nexp_simp with + | Some (Nexp_aux (Nexp_constant i, _) as nexp') -> KBindings.add k nexp' new_insts + | _ -> new_insts + ) + else new_insts + ) + defining_eqns direct_insts + in - let add_constraint insts = function - | NC_aux (NC_equal (Nexp_aux (Nexp_var kid,_), nexp), _) -> KBindings.add kid nexp insts - | _ -> insts + let rewrite_aux (e, ann) = + match e with + | E_match ((E_aux (e', ann') as exp'), cases) -> begin + let env = env_of_annot ann in + let result_typ = Env.base_typ_of env (typ_of_annot ann) in + let matched_typ = Env.base_typ_of env (typ_of_annot ann') in + match (e', matched_typ) with + | E_sizeof (Nexp_aux (Nexp_var kid, _)), _ + | _, Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)]), _) -> + let map_case pexp = + let pat, guard, body, ann = destruct_pexp pexp in + let body = + match (pat, guard) with + | P_aux (P_lit (L_aux (L_num i, _)), _), _ -> + (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *) + let src_typ = + fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ + in + make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ + (make_bitvector_env_casts env (env_of body) quant_kids + (KBindings.singleton kid (nconstant i)) + body + ) + | P_aux (P_id var, _), Some guard -> ( + match extract_value_from_guard var guard with + | Some i -> + let src_typ = + fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ + in + make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ + (make_bitvector_env_casts env (env_of body) quant_kids + (KBindings.singleton kid (nconstant i)) + body + ) + | None -> body + ) + | P_aux (P_wild, (_, annot)), None -> begin + (* Similar to the literal case *) + match (body, untyped_annot annot |> get_attribute "int_wildcard") with + | _, Some (_, s) -> + let i = Big_int.of_string s in + let src_typ = + fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ + in + make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ + (make_bitvector_env_casts env (env_of body) quant_kids + (KBindings.singleton kid (nconstant i)) + body + ) + | ( E_aux + ( E_internal_assume + ((NC_aux (NC_equal (Nexp_aux (Nexp_var kid', _), nexp), _) as nc), body'), + assume_ann + ), + _ ) + when Kid.compare kid kid' == 0 -> + let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) nexp) env) result_typ in + let body'' = + make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ + (make_bitvector_env_casts env (env_of body') quant_kids (KBindings.singleton kid nexp) + body' + ) + in + E_aux (E_internal_assume (nc, body''), assume_ann) + | _ -> body + end + | _ -> body + in + construct_pexp (pat, guard, body, ann) + in + E_aux (E_match (exp', List.map map_case cases), ann) + | _ -> E_aux (e, ann) + end + | E_if (e1, e2, e3) -> + let env = env_of_annot ann in + let result_typ = Env.base_typ_of env (typ_of_annot ann) in + let insts = extract env e1 in + let e2' = make_bitvector_env_casts env (env_of e2) quant_kids insts e2 in + let src_typ = subst_kids_typ insts result_typ in + let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in + (* Ask the type checker if only one value remains for any of kids in + the else branch. *) + let env3 = env_of e3 in + let insts3 = + KBindings.fold + (fun kid _ i3 -> + match Type_check.solve_unique env3 (nvar kid) with + | None -> i3 + | Some c -> KBindings.add kid (nconstant c) i3 + ) + insts KBindings.empty + in + let e3' = make_bitvector_env_casts env (env_of e3) quant_kids insts3 e3 in + let src_typ3 = subst_kids_typ insts3 result_typ in + let e3' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ3 result_typ e3' in + E_aux (E_if (e1, e2', e3'), ann) + | E_return e' -> + E_aux + ( E_return + (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids + (fill_in_type (env_of e') (typ_of e')) + ret_typ e' + ), + ann + ) + | E_block es -> + let env = env_of_annot ann in + let result_typ = Env.base_typ_of env (typ_of_annot ann) in + let rec aux = function + | [] -> [] + | (E_aux (E_assert (assert_exp, msg), ann) as h) :: t -> + (* Check the assertion for constraints that instantiate kids *) + let is_known_kid kid = KBindings.mem kid (Env.get_typ_vars env) in + begin + match Type_check.assert_constraint env true assert_exp with + | Some nc when KidSet.for_all is_known_kid (tyvars_of_constraint nc) -> + (* If the type checker can extract constraints from the assertion + for pre-existing kids (not for those that are bound by the + assertion itself), then look at the environment after the + assertion to extract kid instantiations. *) + let env_post = Env.add_constraint nc env in + let check_inst kid insts = + (* First check if the given kid already had a fixed value previously. *) + let rec nc_fixes_kid nc = + match unaux_constraint nc with + | NC_equal (Nexp_aux (Nexp_var kid', _), Nexp_aux (Nexp_constant _, _)) -> + Kid.compare kid kid' = 0 + | NC_and (_, _) -> List.exists nc_fixes_kid (constraint_conj nc) + | _ -> false + in + if List.exists nc_fixes_kid (Env.get_constraints env) then insts + else ( + (* Otherwise ask the solver for a new, unique value *) + match solve_unique env_post (nvar kid) with + | Some n -> KBindings.add kid (nconstant n) insts + | None -> insts + | exception _ -> insts + ) + in + let insts = KidSet.fold check_inst (tyvars_of_constraint nc) KBindings.empty in + if KBindings.is_empty insts then h :: aux t + else begin + (* Propagate new instantiations and insert casts *) + let t' = aux t in + let et = E_aux (E_block t', ann) in + let et = make_bitvector_env_casts env env_post quant_kids insts et in + let src_typ = subst_kids_typ insts result_typ in + let et = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ et in + [h; et] + end + | _ -> h :: aux t + end + | h :: t -> h :: aux t + in + E_aux (E_block (aux es), ann) + | _ -> E_aux (e, ann) + in + let open Rewriter in + fold_exp { id_exp_alg with e_aux = rewrite_aux } exp in - let defining_eqns = List.fold_left add_constraint KBindings.empty (Env.get_constraints (env_of body)) in + let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), ((def_annot, _) as fcl_ann))) = + let l = def_annot.loc in + let tq, typ = Env.get_val_spec_orig id global_env in + let fun_env = List.fold_right (Env.add_typ_var l) (quant_kopts tq) global_env in + let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in + + let ret_typ = + match typ with + | Typ_aux (Typ_fn (_, ret), _) -> ret + | Typ_aux (_, l) as typ -> + raise + (Reporting.err_unreachable l __POS__ + ("Function clause must have function type: " ^ string_of_typ typ ^ " is not a function type") + ) + in + let pat, guard, body, annot = destruct_pexp pexp in + let rec strip_assumes = function + | E_aux (E_internal_assume (nc, e), ann) -> + let e', k = strip_assumes e in + (e', fun e -> E_aux (E_internal_assume (nc, k e), ann)) + | e -> (e, fun e -> e) + in + let body, restore_assumes = strip_assumes body in - let body = rewrite_body id quant_kids fun_env defining_eqns ret_typ body in + let add_constraint insts = function + | NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), nexp), _) -> KBindings.add kid nexp insts + | _ -> insts + in + let defining_eqns = List.fold_left add_constraint KBindings.empty (Env.get_constraints (env_of body)) in + + let body = rewrite_body id quant_kids fun_env defining_eqns ret_typ body in - (* Cast function arguments, if necessary *) - let src_typ = fill_in_type (env_of body) (typ_of body) in - let body = make_bitvector_env_casts fun_env (env_of body) quant_kids defining_eqns body in + (* Cast function arguments, if necessary *) + let src_typ = fill_in_type (env_of body) (typ_of body) in + let body = make_bitvector_env_casts fun_env (env_of body) quant_kids defining_eqns body in - (* Also add a cast around the entire function clause body, if necessary *) - let body = - make_bitvector_cast_exp "bitvector_cast_out" fun_env quant_kids src_typ ret_typ body + (* Also add a cast around the entire function clause body, if necessary *) + let body = make_bitvector_cast_exp "bitvector_cast_out" fun_env quant_kids src_typ ret_typ body in + let body = restore_assumes body in + let pexp = construct_pexp (pat, guard, body, annot) in + FCL_aux (FCL_funcl (id, pexp), fcl_ann) in - let body = restore_assumes body in - let pexp = construct_pexp (pat,guard,body,annot) in - FCL_aux (FCL_funcl (id,pexp),fcl_ann) - in - let rewrite_def idx = function - | DEF_aux (DEF_fundef (FD_aux (FD_function (r,t,fcls),fd_ann) as fd), def_annot) -> - Util.progress "Adding casts " (string_of_id (id_of_fundef fd)) idx (List.length defs); - DEF_aux (DEF_fundef (FD_aux (FD_function (r,t,List.map rewrite_funcl fcls),fd_ann)), def_annot) - | d -> d - in - specs_required := IdSet.empty; - let defs = List.mapi rewrite_def defs in - let _ = Util.progress "Adding casts " "done" (List.length defs) (List.length defs) in - let cast_specs, _ = - (* TODO: use default/relevant order *) - let kid = mk_kid "n" in - let bitsn = bitvector_typ (nvar kid) dec_ord in - let ts = mk_typschm (mk_typquant [mk_qi_id K_int kid]) - (function_typ [bitsn] bitsn) in - let mkfn name = - mk_val_spec (VS_val_spec (ts,name,Some { pure = true; bindings = [("_", "zeroExtend")] },false)) + let rewrite_def idx = function + | DEF_aux (DEF_fundef (FD_aux (FD_function (r, t, fcls), fd_ann) as fd), def_annot) -> + Util.progress "Adding casts " (string_of_id (id_of_fundef fd)) idx (List.length defs); + DEF_aux (DEF_fundef (FD_aux (FD_function (r, t, List.map rewrite_funcl fcls), fd_ann)), def_annot) + | d -> d in - let defs = List.map mkfn (IdSet.elements !specs_required) in - check_defs initial_env defs - in { ast with defs = cast_specs @ defs } + specs_required := IdSet.empty; + let defs = List.mapi rewrite_def defs in + let _ = Util.progress "Adding casts " "done" (List.length defs) (List.length defs) in + let cast_specs, _ = + (* TODO: use default/relevant order *) + let kid = mk_kid "n" in + let bitsn = bitvector_typ (nvar kid) dec_ord in + let ts = mk_typschm (mk_typquant [mk_qi_id K_int kid]) (function_typ [bitsn] bitsn) in + let mkfn name = + mk_val_spec (VS_val_spec (ts, name, Some { pure = true; bindings = [("_", "zeroExtend")] }, false)) + in + let defs = List.map mkfn (IdSet.elements !specs_required) in + check_defs initial_env defs + in + { ast with defs = cast_specs @ defs } end module ToplevelNexpRewrites = struct + let replace_nexp_in_typ env typ orig new_nexp = + let rec aux (Typ_aux (t, l) as typ) = + match t with + | Typ_id _ | Typ_var _ -> (false, typ) + | Typ_fn (arg, res) -> + let arg' = List.map aux arg in + let f1 = List.exists fst arg' in + let f2, res = aux res in + (f1 || f2, Typ_aux (Typ_fn (List.map snd arg', res), l)) + | Typ_bidir (t1, t2) -> + let f1, t1 = aux t1 in + let f2, t2 = aux t2 in + (f1 || f2, Typ_aux (Typ_bidir (t1, t2), l)) + | Typ_tuple typs -> + let fs, typs = List.split (List.map aux typs) in + (List.exists (fun x -> x) fs, Typ_aux (Typ_tuple typs, l)) + | Typ_exist (kids, nc, typ') -> + (* TODO avoid capture *) + let f, typ' = aux typ' in + (f, Typ_aux (Typ_exist (kids, nc, typ'), l)) + | Typ_app (id, targs) -> + let fs, targs = List.split (List.map aux_targ targs) in + (List.exists (fun x -> x) fs, Typ_aux (Typ_app (id, targs), l)) + | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" + and aux_targ (A_aux (ta, l) as typ_arg) = + match ta with + | A_nexp nexp -> + if prove __POS__ env (nc_eq nexp orig) then (true, A_aux (A_nexp new_nexp, l)) else (false, typ_arg) + | A_typ typ -> + let f, typ = aux typ in + (f, A_aux (A_typ typ, l)) + | A_order _ | A_bool _ -> (false, typ_arg) + in + aux typ + + let fresh_nexp_kid nexp = + let rec mangle_nexp (Nexp_aux (nexp, _)) = + match nexp with + | Nexp_id id -> string_of_id id + | Nexp_var kid -> string_of_id (id_of_kid kid) + | Nexp_constant i -> + (if Big_int.greater_equal i Big_int.zero then "p" else "m") ^ Big_int.to_string (Big_int.abs i) + | Nexp_times (n1, n2) -> mangle_nexp n1 ^ "_times_" ^ mangle_nexp n2 + | Nexp_sum (n1, n2) -> mangle_nexp n1 ^ "_plus_" ^ mangle_nexp n2 + | Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2 + | Nexp_exp n -> "exp_" ^ mangle_nexp n + | Nexp_neg n -> "neg_" ^ mangle_nexp n + | Nexp_app (id, args) -> string_of_id id ^ "_" ^ String.concat "_" (List.map mangle_nexp args) + in + (* TODO: I'd like to add a # to distinguish it from user-provided names, but + the rewriter currently uses them as a hint that they're not printable in + types, which these are explicitly supposed to be. *) + mk_kid (mangle_nexp nexp (*^ "#"*)) -let replace_nexp_in_typ env typ orig new_nexp = - let rec aux (Typ_aux (t,l) as typ) = - match t with - | Typ_id _ - | Typ_var _ - -> false, typ - | Typ_fn (arg,res) -> - let arg' = List.map aux arg in - let f1 = List.exists fst arg' in - let f2, res = aux res in - f1 || f2, Typ_aux (Typ_fn (List.map snd arg', res),l) - | Typ_bidir (t1, t2) -> - let f1, t1 = aux t1 in - let f2, t2 = aux t2 in - f1 || f2, Typ_aux (Typ_bidir (t1, t2), l) - | Typ_tuple typs -> - let fs, typs = List.split (List.map aux typs) in - List.exists (fun x -> x) fs, Typ_aux (Typ_tuple typs,l) - | Typ_exist (kids,nc,typ') -> (* TODO avoid capture *) - let f, typ' = aux typ' in - f, Typ_aux (Typ_exist (kids,nc,typ'),l) - | Typ_app (id, targs) -> - let fs, targs = List.split (List.map aux_targ targs) in - List.exists (fun x -> x) fs, Typ_aux (Typ_app (id, targs),l) - | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" - and aux_targ (A_aux (ta,l) as typ_arg) = - match ta with - | A_nexp nexp -> - if prove __POS__ env (nc_eq nexp orig) - then true, A_aux (A_nexp new_nexp,l) - else false, typ_arg - | A_typ typ -> - let f, typ = aux typ in - f, A_aux (A_typ typ,l) - | A_order _ - | A_bool _ - -> false, typ_arg - in aux typ - -let fresh_nexp_kid nexp = - let rec mangle_nexp (Nexp_aux (nexp, _)) = - match nexp with - | Nexp_id id -> string_of_id id - | Nexp_var kid -> string_of_id (id_of_kid kid) - | Nexp_constant i -> - (if Big_int.greater_equal i Big_int.zero then "p" else "m") - ^ Big_int.to_string (Big_int.abs i) - | Nexp_times (n1, n2) -> mangle_nexp n1 ^ "_times_" ^ mangle_nexp n2 - | Nexp_sum (n1, n2) -> mangle_nexp n1 ^ "_plus_" ^ mangle_nexp n2 - | Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2 - | Nexp_exp n -> "exp_" ^ mangle_nexp n - | Nexp_neg n -> "neg_" ^ mangle_nexp n - | Nexp_app (id,args) -> string_of_id id ^ "_" ^ - String.concat "_" (List.map mangle_nexp args) - in - (* TODO: I'd like to add a # to distinguish it from user-provided names, but - the rewriter currently uses them as a hint that they're not printable in - types, which these are explicitly supposed to be. *) - mk_kid (mangle_nexp nexp (*^ "#"*)) - -let find_nexp env nexp_map nexp = - let is_equal (kid,nexp') = prove __POS__ env (nc_eq nexp nexp') in - List.find is_equal nexp_map + let find_nexp env nexp_map nexp = + let is_equal (kid, nexp') = prove __POS__ env (nc_eq nexp nexp') in + List.find is_equal nexp_map -let rec rewrite_typ_in_spec env nexp_map (Typ_aux (t,ann) as typ_full) = - match t with - | Typ_fn (args,res) -> - let args' = List.map (rewrite_typ_in_spec env nexp_map) args in - let nexp_map = List.concat (List.map fst args') in - let nexp_map, res = rewrite_typ_in_spec env nexp_map res in - nexp_map, Typ_aux (Typ_fn (List.map snd args',res),ann) - | Typ_tuple typs -> - let nexp_map, typs = - List.fold_right (fun typ (nexp_map,t) -> - let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in - (nexp_map, typ::t)) typs (nexp_map,[]) - in nexp_map, Typ_aux (Typ_tuple typs,ann) - | _ -> - let typ' = Env.base_typ_of env typ_full in - if Typ.compare typ_full typ' == 0 then - match t with - | Typ_app (f,args) -> - let in_arg nexp_map (A_aux (arg,l) as arg_full) = - match arg with - | A_typ typ -> - let nexp_map, typ' = rewrite_typ_in_spec env nexp_map typ in - nexp_map, A_aux (A_typ typ',l) - | A_nexp (Nexp_aux (Nexp_constant _,_)) - | A_nexp (Nexp_aux (Nexp_var _,_)) -> nexp_map, arg_full - | A_nexp nexp -> - let nexp_map, kid = - match find_nexp env nexp_map nexp with - | (kid,_) -> nexp_map, kid - | exception Not_found -> - let kid = fresh_nexp_kid nexp in - (kid, nexp)::nexp_map, kid - in - let new_nexp = nvar kid in - nexp_map, A_aux (A_nexp new_nexp, l) - | A_bool _ | A_order _ -> nexp_map, arg_full - in - let nexp_map, args = - List.fold_right (fun arg (nexp_map,args) -> - let nexp_map, arg = in_arg nexp_map arg in - (nexp_map, arg::args)) args (nexp_map,[]) - in nexp_map, Typ_aux (Typ_app (f,args),ann) - | _ -> nexp_map, typ_full - else rewrite_typ_in_spec env nexp_map typ' - -let rewrite_toplevel_nexps ({ defs; _ } as ast) = - let rewrite_valspec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann)) = - match tqs with - | TypQ_aux (TypQ_no_forall,_) -> None - | TypQ_aux (TypQ_tq qs, tq_l) -> - let env = env_of_annot ann in - let env = Env.add_typquant tq_l tqs env in - let nexp_map, typ = rewrite_typ_in_spec env [] typ in - match nexp_map with - | [] -> None - | _ -> - let new_vars = List.map (fun (kid,nexp) -> QI_aux (QI_id (mk_kopt K_int kid), Generated tq_l)) nexp_map in - let new_constraints = List.map (fun (kid,nexp) -> QI_aux (QI_constraint (nc_eq (nvar kid) nexp), Generated tq_l)) nexp_map in - let tqs = TypQ_aux (TypQ_tq (qs @ new_vars @ new_constraints),tq_l) in - let vs = - VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann) in - Some (id, nexp_map, vs) - in - (* Changing types in the body confuses simple sizeof rewriting, so turn it - off for now *) - let rewrite_typ_in_body env nexp_map typ = - let rec aux (Typ_aux (t,l) as typ_full) = + let rec rewrite_typ_in_spec env nexp_map (Typ_aux (t, ann) as typ_full) = match t with - | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map aux typs),l) - | Typ_exist (kids,nc,typ') -> (* TODO: avoid shadowing *) - Typ_aux (Typ_exist (kids,(* TODO? *) nc, aux typ'),l) - | Typ_app (id,targs) -> Typ_aux (Typ_app (id,List.map aux_targ targs),l) - | _ -> typ_full - and aux_targ (A_aux (ta,l) as ta_full) = - match ta with - | A_typ typ -> A_aux (A_typ (aux typ),l) - | A_order _ -> ta_full - | A_nexp nexp -> A_aux (A_nexp (aux_nexp nexp), l) - | A_bool nc -> A_aux (A_bool (aux_nconstraint nc), l) - and aux_nexp nexp = - match find_nexp env nexp_map nexp with - | (kid,_) -> nvar kid - | exception Not_found -> nexp - and aux_nconstraint (NC_aux (nc, l)) = - let rewrap nc = NC_aux (nc, l) in - match nc with - | NC_equal (n1, n2) -> rewrap (NC_equal (aux_nexp n1, aux_nexp n2)) - | NC_bounded_ge (n1, n2) -> rewrap (NC_bounded_ge (aux_nexp n1, aux_nexp n2)) - | NC_bounded_gt (n1, n2) -> rewrap (NC_bounded_gt (aux_nexp n1, aux_nexp n2)) - | NC_bounded_le (n1, n2) -> rewrap (NC_bounded_le (aux_nexp n1, aux_nexp n2)) - | NC_bounded_lt (n1, n2) -> rewrap (NC_bounded_lt (aux_nexp n1, aux_nexp n2)) - | NC_not_equal (n1, n2) -> rewrap (NC_not_equal (aux_nexp n1, aux_nexp n2)) - | NC_or (nc1, nc2) -> rewrap (NC_or (aux_nconstraint nc1, aux_nconstraint nc2)) - | NC_and (nc1, nc2) -> rewrap (NC_and (aux_nconstraint nc1, aux_nconstraint nc2)) - | NC_app (id, args) -> rewrap (NC_app (id, List.map aux_targ args)) - | _ -> rewrap nc - in aux typ - in - let rewrite_one_exp nexp_map (e,ann) = - match e with - | E_typ (typ,e') -> E_aux (E_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ,e'),ann) - | E_sizeof nexp -> - (match find_nexp (env_of_annot ann) nexp_map nexp with - | (kid,_) -> E_aux (E_sizeof (nvar kid),ann) - | exception Not_found -> E_aux (e,ann)) - | _ -> E_aux (e,ann) - in - let rewrite_one_pat nexp_map (p,ann) = - match p with - | P_typ (typ,p') -> P_aux (P_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ,p'),ann) - | _ -> P_aux (p,ann) - in - let rewrite_one_lexp nexp_map (lexp, ann) = - match lexp with - | LE_typ (typ, id) -> - LE_aux (LE_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ, id), ann) - | _ -> LE_aux (lexp, ann) - in - let rewrite_body nexp_map pexp = - let open Rewriter in - fold_pexp { id_exp_alg with - e_aux = rewrite_one_exp nexp_map; - le_aux = rewrite_one_lexp nexp_map; - pat_alg = { id_pat_alg with p_aux = rewrite_one_pat nexp_map } - } pexp - in - let rewrite_funcl spec_map (FCL_aux (FCL_funcl (id,pexp),ann) as funcl) = - match Bindings.find id spec_map with - | nexp_map -> FCL_aux (FCL_funcl (id,rewrite_body nexp_map pexp),ann) - | exception Not_found -> funcl - in - let rewrite_def spec_map def = - match def with - | DEF_aux (DEF_val vs, def_annot) -> (match rewrite_valspec vs with - | None -> spec_map, def - | Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_aux (DEF_val vs, def_annot)) - | DEF_aux (DEF_fundef (FD_aux (FD_function (recopt,_,funcls),ann)), def_annot) -> - (* Type annotations on function definitions will have been turned into - valspecs by type checking, so it should be safe to drop them rather - than updating them. *) - let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in - spec_map, - DEF_aux (DEF_fundef (FD_aux (FD_function (recopt,tann,List.map (rewrite_funcl spec_map) funcls),ann)), def_annot) - | _ -> spec_map, def - in - let _, defs = List.fold_left (fun (spec_map,t) def -> - let spec_map, def = rewrite_def spec_map def in - (spec_map, def::t)) (Bindings.empty, []) defs - in - { ast with defs = List.rev defs } - -(* Move complex sizes in record field types into the parameters. *) -let rewrite_complete_record_params env ast = - let lift_params (additions_map,tl) def = - match def with - | DEF_aux (DEF_type (TD_aux (TD_record (id, (TypQ_aux (TypQ_tq qs, tq_l ) as tyqs), fields, semi), annot)), def_annot) as def -> - (* TODO: replace with a local environment *) - let env = Env.add_typquant tq_l tyqs env in - let nexp_map, fields' = - List.fold_right (fun (typ,id) (nexp_map,t) -> - let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in - (nexp_map, (typ,id)::t)) fields ([],[]) - in begin - match nexp_map with - | [] -> additions_map, def::tl - | _ -> - let new_vars = List.map (fun (kid,nexp) -> QI_aux (QI_id (mk_kopt K_int kid), Generated tq_l)) nexp_map in - let new_constraints = List.map (fun (kid,nexp) -> QI_aux (QI_constraint (nc_eq (nvar kid) nexp), Generated tq_l)) nexp_map in - let tyqs' = TypQ_aux (TypQ_tq (qs @ new_vars @ new_constraints),tq_l) in - let additions_map' = Bindings.add id (tyqs, nexp_map) additions_map in - additions_map', DEF_aux (DEF_type (TD_aux (TD_record (id, tyqs', fields', semi), annot)), def_annot) :: tl - end - | def -> additions_map, def::tl - in - - let additions_map, rdefs = List.fold_left lift_params (Bindings.empty, []) ast.defs in - let ast = { ast with defs = List.rev rdefs } in - - let rec expand_type (Typ_aux (typ, l) as full_typ) = - match typ with - | Typ_fn (args, ret) -> Typ_aux (Typ_fn (List.map expand_type args, expand_type ret), l) - | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map expand_type typs), l) - (* TODO: another potential shadowing hazard *) - | Typ_exist (kids,nc,typ') -> - Typ_aux (Typ_exist (kids, nc, expand_type typ'), l) - | Typ_app (id, ty_args) -> - begin match Bindings.find_opt id additions_map with - | None -> full_typ - | Some (original_tyqs, nexp_map) -> - let instantiation = - List.fold_left2 - (fun m kopt ty_arg -> - match kopt, ty_arg with - | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _), A_aux (A_nexp nexp, _) -> - KBindings.add kid nexp m - | _ -> m) - KBindings.empty - (quant_kopts original_tyqs) ty_args + | Typ_fn (args, res) -> + let args' = List.map (rewrite_typ_in_spec env nexp_map) args in + let nexp_map = List.concat (List.map fst args') in + let nexp_map, res = rewrite_typ_in_spec env nexp_map res in + (nexp_map, Typ_aux (Typ_fn (List.map snd args', res), ann)) + | Typ_tuple typs -> + let nexp_map, typs = + List.fold_right + (fun typ (nexp_map, t) -> + let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in + (nexp_map, typ :: t) + ) + typs (nexp_map, []) + in + (nexp_map, Typ_aux (Typ_tuple typs, ann)) + | _ -> + let typ' = Env.base_typ_of env typ_full in + if Typ.compare typ_full typ' == 0 then ( + match t with + | Typ_app (f, args) -> + let in_arg nexp_map (A_aux (arg, l) as arg_full) = + match arg with + | A_typ typ -> + let nexp_map, typ' = rewrite_typ_in_spec env nexp_map typ in + (nexp_map, A_aux (A_typ typ', l)) + | A_nexp (Nexp_aux (Nexp_constant _, _)) | A_nexp (Nexp_aux (Nexp_var _, _)) -> (nexp_map, arg_full) + | A_nexp nexp -> + let nexp_map, kid = + match find_nexp env nexp_map nexp with + | kid, _ -> (nexp_map, kid) + | exception Not_found -> + let kid = fresh_nexp_kid nexp in + ((kid, nexp) :: nexp_map, kid) + in + let new_nexp = nvar kid in + (nexp_map, A_aux (A_nexp new_nexp, l)) + | A_bool _ | A_order _ -> (nexp_map, arg_full) + in + let nexp_map, args = + List.fold_right + (fun arg (nexp_map, args) -> + let nexp_map, arg = in_arg nexp_map arg in + (nexp_map, arg :: args) + ) + args (nexp_map, []) + in + (nexp_map, Typ_aux (Typ_app (f, args), ann)) + | _ -> (nexp_map, typ_full) + ) + else rewrite_typ_in_spec env nexp_map typ' + + let rewrite_toplevel_nexps ({ defs; _ } as ast) = + let rewrite_valspec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs, typ), ts_l), id, ext_opt, is_cast), ann)) = + match tqs with + | TypQ_aux (TypQ_no_forall, _) -> None + | TypQ_aux (TypQ_tq qs, tq_l) -> ( + let env = env_of_annot ann in + let env = Env.add_typquant tq_l tqs env in + let nexp_map, typ = rewrite_typ_in_spec env [] typ in + match nexp_map with + | [] -> None + | _ -> + let new_vars = + List.map (fun (kid, nexp) -> QI_aux (QI_id (mk_kopt K_int kid), Generated tq_l)) nexp_map + in + let new_constraints = + List.map (fun (kid, nexp) -> QI_aux (QI_constraint (nc_eq (nvar kid) nexp), Generated tq_l)) nexp_map + in + let tqs = TypQ_aux (TypQ_tq (qs @ new_vars @ new_constraints), tq_l) in + let vs = VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs, typ), ts_l), id, ext_opt, is_cast), ann) in + Some (id, nexp_map, vs) + ) + in + (* Changing types in the body confuses simple sizeof rewriting, so turn it + off for now *) + let rewrite_typ_in_body env nexp_map typ = + let rec aux (Typ_aux (t, l) as typ_full) = + match t with + | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map aux typs), l) + | Typ_exist (kids, nc, typ') -> + (* TODO: avoid shadowing *) + Typ_aux (Typ_exist (kids, (* TODO? *) nc, aux typ'), l) + | Typ_app (id, targs) -> Typ_aux (Typ_app (id, List.map aux_targ targs), l) + | _ -> typ_full + and aux_targ (A_aux (ta, l) as ta_full) = + match ta with + | A_typ typ -> A_aux (A_typ (aux typ), l) + | A_order _ -> ta_full + | A_nexp nexp -> A_aux (A_nexp (aux_nexp nexp), l) + | A_bool nc -> A_aux (A_bool (aux_nconstraint nc), l) + and aux_nexp nexp = match find_nexp env nexp_map nexp with kid, _ -> nvar kid | exception Not_found -> nexp + and aux_nconstraint (NC_aux (nc, l)) = + let rewrap nc = NC_aux (nc, l) in + match nc with + | NC_equal (n1, n2) -> rewrap (NC_equal (aux_nexp n1, aux_nexp n2)) + | NC_bounded_ge (n1, n2) -> rewrap (NC_bounded_ge (aux_nexp n1, aux_nexp n2)) + | NC_bounded_gt (n1, n2) -> rewrap (NC_bounded_gt (aux_nexp n1, aux_nexp n2)) + | NC_bounded_le (n1, n2) -> rewrap (NC_bounded_le (aux_nexp n1, aux_nexp n2)) + | NC_bounded_lt (n1, n2) -> rewrap (NC_bounded_lt (aux_nexp n1, aux_nexp n2)) + | NC_not_equal (n1, n2) -> rewrap (NC_not_equal (aux_nexp n1, aux_nexp n2)) + | NC_or (nc1, nc2) -> rewrap (NC_or (aux_nconstraint nc1, aux_nconstraint nc2)) + | NC_and (nc1, nc2) -> rewrap (NC_and (aux_nconstraint nc1, aux_nconstraint nc2)) + | NC_app (id, args) -> rewrap (NC_app (id, List.map aux_targ args)) + | _ -> rewrap nc + in + aux typ + in + let rewrite_one_exp nexp_map (e, ann) = + match e with + | E_typ (typ, e') -> E_aux (E_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ, e'), ann) + | E_sizeof nexp -> ( + match find_nexp (env_of_annot ann) nexp_map nexp with + | kid, _ -> E_aux (E_sizeof (nvar kid), ann) + | exception Not_found -> E_aux (e, ann) + ) + | _ -> E_aux (e, ann) + in + let rewrite_one_pat nexp_map (p, ann) = + match p with + | P_typ (typ, p') -> P_aux (P_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ, p'), ann) + | _ -> P_aux (p, ann) + in + let rewrite_one_lexp nexp_map (lexp, ann) = + match lexp with + | LE_typ (typ, id) -> LE_aux (LE_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ, id), ann) + | _ -> LE_aux (lexp, ann) + in + let rewrite_body nexp_map pexp = + let open Rewriter in + fold_pexp + { + id_exp_alg with + e_aux = rewrite_one_exp nexp_map; + le_aux = rewrite_one_lexp nexp_map; + pat_alg = { id_pat_alg with p_aux = rewrite_one_pat nexp_map }; + } + pexp + in + let rewrite_funcl spec_map (FCL_aux (FCL_funcl (id, pexp), ann) as funcl) = + match Bindings.find id spec_map with + | nexp_map -> FCL_aux (FCL_funcl (id, rewrite_body nexp_map pexp), ann) + | exception Not_found -> funcl + in + let rewrite_def spec_map def = + match def with + | DEF_aux (DEF_val vs, def_annot) -> ( + match rewrite_valspec vs with + | None -> (spec_map, def) + | Some (id, nexp_map, vs) -> (Bindings.add id nexp_map spec_map, DEF_aux (DEF_val vs, def_annot)) + ) + | DEF_aux (DEF_fundef (FD_aux (FD_function (recopt, _, funcls), ann)), def_annot) -> + (* Type annotations on function definitions will have been turned into + valspecs by type checking, so it should be safe to drop them rather + than updating them. *) + let tann = Typ_annot_opt_aux (Typ_annot_opt_none, Generated Unknown) in + ( spec_map, + DEF_aux + ( DEF_fundef (FD_aux (FD_function (recopt, tann, List.map (rewrite_funcl spec_map) funcls), ann)), + def_annot + ) + ) + | _ -> (spec_map, def) + in + let _, defs = + List.fold_left + (fun (spec_map, t) def -> + let spec_map, def = rewrite_def spec_map def in + (spec_map, def :: t) + ) + (Bindings.empty, []) defs + in + { ast with defs = List.rev defs } + + (* Move complex sizes in record field types into the parameters. *) + let rewrite_complete_record_params env ast = + let lift_params (additions_map, tl) def = + match def with + | DEF_aux + (DEF_type (TD_aux (TD_record (id, (TypQ_aux (TypQ_tq qs, tq_l) as tyqs), fields, semi), annot)), def_annot) as + def -> + (* TODO: replace with a local environment *) + let env = Env.add_typquant tq_l tyqs env in + let nexp_map, fields' = + List.fold_right + (fun (typ, id) (nexp_map, t) -> + let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in + (nexp_map, (typ, id) :: t) + ) + fields ([], []) in - let new_args = - List.map (fun (_, nexp) -> A_aux (A_nexp (subst_kids_nexp instantiation nexp), Generated l)) nexp_map - in Typ_aux (Typ_app (id, ty_args @ new_args), l) + begin + match nexp_map with + | [] -> (additions_map, def :: tl) + | _ -> + let new_vars = + List.map (fun (kid, nexp) -> QI_aux (QI_id (mk_kopt K_int kid), Generated tq_l)) nexp_map + in + let new_constraints = + List.map (fun (kid, nexp) -> QI_aux (QI_constraint (nc_eq (nvar kid) nexp), Generated tq_l)) nexp_map + in + let tyqs' = TypQ_aux (TypQ_tq (qs @ new_vars @ new_constraints), tq_l) in + let additions_map' = Bindings.add id (tyqs, nexp_map) additions_map in + ( additions_map', + DEF_aux (DEF_type (TD_aux (TD_record (id, tyqs', fields', semi), annot)), def_annot) :: tl + ) + end + | def -> (additions_map, def :: tl) + in - end - | _ -> full_typ - in + let additions_map, rdefs = List.fold_left lift_params (Bindings.empty, []) ast.defs in + let ast = { ast with defs = List.rev rdefs } in - let open Rewriter in - let rw_pat = { - id_pat_alg with - p_typ = (fun (typ, pat) -> P_typ (expand_type typ, pat)); - } - in - let rw_exp = { - id_exp_alg with - e_typ = (fun (typ, exp) -> E_typ (expand_type typ, exp)); - le_typ = (fun (typ, lexp) -> LE_typ (expand_type typ, lexp)); - pat_alg = rw_pat; - } - in - let rw_spec (VS_aux (VS_val_spec (typschm, id, ext, is_cast), annot)) = - match typschm with - | TypSchm_aux (TypSchm_ts (typq, typ), annot') -> - (* TODO: capture hazard *) - let typschm' = TypSchm_aux (TypSchm_ts (typq, expand_type typ), annot') in - VS_aux (VS_val_spec (typschm', id, ext, is_cast), annot) - in - let rw_typedef (TD_aux (td, annot)) = - let rw_union (Tu_aux (Tu_ty_id (typ, id), annot)) = - Tu_aux (Tu_ty_id (expand_type typ, id), annot) + let rec expand_type (Typ_aux (typ, l) as full_typ) = + match typ with + | Typ_fn (args, ret) -> Typ_aux (Typ_fn (List.map expand_type args, expand_type ret), l) + | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map expand_type typs), l) + (* TODO: another potential shadowing hazard *) + | Typ_exist (kids, nc, typ') -> Typ_aux (Typ_exist (kids, nc, expand_type typ'), l) + | Typ_app (id, ty_args) -> begin + match Bindings.find_opt id additions_map with + | None -> full_typ + | Some (original_tyqs, nexp_map) -> + let instantiation = + List.fold_left2 + (fun m kopt ty_arg -> + match (kopt, ty_arg) with + | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _), A_aux (A_nexp nexp, _) -> + KBindings.add kid nexp m + | _ -> m + ) + KBindings.empty (quant_kopts original_tyqs) ty_args + in + let new_args = + List.map (fun (_, nexp) -> A_aux (A_nexp (subst_kids_nexp instantiation nexp), Generated l)) nexp_map + in + Typ_aux (Typ_app (id, ty_args @ new_args), l) + end + | _ -> full_typ in - match td with - | TD_abbrev (id, typq, A_aux (A_typ typ, l)) -> - TD_aux (TD_abbrev (id, typq, A_aux (A_typ (expand_type typ), l)), annot) - | TD_abbrev (id, typq, typ_arg) -> - TD_aux (TD_abbrev (id, typq, typ_arg), annot) - | TD_record (id, typq, typ_ids, flag) -> - TD_aux (TD_record (id, typq, List.map (fun (typ, id) -> (expand_type typ, id)) typ_ids, flag), annot) - | TD_variant (id, typq, tus, flag) -> - TD_aux (TD_variant (id, typq, List.map rw_union tus, flag), annot) - | TD_enum (id, ids, flag) -> TD_aux (TD_enum (id, ids, flag), annot) - | TD_bitfield _ -> assert false (* Processed before re-writing *) - in - let rw_register (DEC_aux (DEC_reg (typ, id, init), annot)) = - DEC_aux (DEC_reg (expand_type typ, id, init), annot) - in - let rw_def rws (DEF_aux (aux, def_annot)) = - let aux' = match aux with - | DEF_val vs -> DEF_val (rw_spec vs) - | DEF_type td -> DEF_type (rw_typedef td) - | DEF_register reg -> DEF_register (rw_register reg) - | def -> def - in rewrite_def rws (DEF_aux (aux', def_annot)) - in - rewrite_ast_base - { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp); - rewrite_pat = (fun _ -> fold_pat rw_pat); - rewrite_def = rw_def - } ast -end (* ToplevelNexpRewrites *) + let open Rewriter in + let rw_pat = { id_pat_alg with p_typ = (fun (typ, pat) -> P_typ (expand_type typ, pat)) } in + let rw_exp = + { + id_exp_alg with + e_typ = (fun (typ, exp) -> E_typ (expand_type typ, exp)); + le_typ = (fun (typ, lexp) -> LE_typ (expand_type typ, lexp)); + pat_alg = rw_pat; + } + in + let rw_spec (VS_aux (VS_val_spec (typschm, id, ext, is_cast), annot)) = + match typschm with + | TypSchm_aux (TypSchm_ts (typq, typ), annot') -> + (* TODO: capture hazard *) + let typschm' = TypSchm_aux (TypSchm_ts (typq, expand_type typ), annot') in + VS_aux (VS_val_spec (typschm', id, ext, is_cast), annot) + in + let rw_typedef (TD_aux (td, annot)) = + let rw_union (Tu_aux (Tu_ty_id (typ, id), annot)) = Tu_aux (Tu_ty_id (expand_type typ, id), annot) in + match td with + | TD_abbrev (id, typq, A_aux (A_typ typ, l)) -> + TD_aux (TD_abbrev (id, typq, A_aux (A_typ (expand_type typ), l)), annot) + | TD_abbrev (id, typq, typ_arg) -> TD_aux (TD_abbrev (id, typq, typ_arg), annot) + | TD_record (id, typq, typ_ids, flag) -> + TD_aux (TD_record (id, typq, List.map (fun (typ, id) -> (expand_type typ, id)) typ_ids, flag), annot) + | TD_variant (id, typq, tus, flag) -> TD_aux (TD_variant (id, typq, List.map rw_union tus, flag), annot) + | TD_enum (id, ids, flag) -> TD_aux (TD_enum (id, ids, flag), annot) + | TD_bitfield _ -> assert false (* Processed before re-writing *) + in + let rw_register (DEC_aux (DEC_reg (typ, id, init), annot)) = DEC_aux (DEC_reg (expand_type typ, id, init), annot) in + let rw_def rws (DEF_aux (aux, def_annot)) = + let aux' = + match aux with + | DEF_val vs -> DEF_val (rw_spec vs) + | DEF_type td -> DEF_type (rw_typedef td) + | DEF_register reg -> DEF_register (rw_register reg) + | def -> def + in + rewrite_def rws (DEF_aux (aux', def_annot)) + in + rewrite_ast_base + { + rewriters_base with + rewrite_exp = (fun _ -> fold_exp rw_exp); + rewrite_pat = (fun _ -> fold_pat rw_pat); + rewrite_def = rw_def; + } + ast +end +(* ToplevelNexpRewrites *) let rewrite_toplevel_nexps = ToplevelNexpRewrites.rewrite_toplevel_nexps let rewrite_complete_record_params = ToplevelNexpRewrites.rewrite_complete_record_params -type options = { - auto : bool; - debug_analysis : int; - all_split_errors : bool; - continue_anyway : bool -} +type options = { auto : bool; debug_analysis : int; all_split_errors : bool; continue_anyway : bool } let mono_rewrites = MonoRewrites.mono_rewrite let monomorphise target effect_info opts splits ast = let ast, env = Type_check.check Type_check.initial_env (strip_ast ast) in let ok_analysis, new_splits, extra_splits = - if opts.auto - then - let f,r,ex = Analysis.analyse_defs opts.debug_analysis effect_info env ast in - if f || opts.all_split_errors || opts.continue_anyway - then f, r, ex + if opts.auto then ( + let f, r, ex = Analysis.analyse_defs opts.debug_analysis effect_info env ast in + if f || opts.all_split_errors || opts.continue_anyway then (f, r, ex) else raise (Reporting.err_general Unknown "Unable to monomorphise program") - else true, [], Analysis.ExtraSplits.empty in - let splits = new_splits @ (List.map (fun ((file,line),id) -> (Line (file,line),id,None)) splits) in + ) + else (true, [], Analysis.ExtraSplits.empty) + in + let splits = new_splits @ List.map (fun ((file, line), id) -> (Line (file, line), id, None)) splits in let ok_extras, defs, extra_splits = add_extra_splits extra_splits ast.defs in - let ast = { ast with defs = defs } in + let ast = { ast with defs } in let splits = splits @ extra_splits in - let () = if ok_extras || opts.all_split_errors || opts.continue_anyway - then () - else raise (Reporting.err_general Unknown "Unable to monomorphise program") + let () = + if ok_extras || opts.all_split_errors || opts.continue_anyway then () + else raise (Reporting.err_general Unknown "Unable to monomorphise program") in let ok_split, ast = split_defs target opts.all_split_errors splits env ast in - let () = if (ok_analysis && ok_extras && ok_split) || opts.continue_anyway - then () - else raise (Reporting.err_general Unknown "Unable to monomorphise program") - in ast + let () = + if (ok_analysis && ok_extras && ok_split) || opts.continue_anyway then () + else raise (Reporting.err_general Unknown "Unable to monomorphise program") + in + ast let add_bitvector_casts = BitvectorSizeCasts.add_bitvector_casts let rewrite_atoms_to_singletons target ast = diff --git a/src/lib/monomorphise.mli b/src/lib/monomorphise.mli index 99daa8b67..2ce45a873 100644 --- a/src/lib/monomorphise.mli +++ b/src/lib/monomorphise.mli @@ -68,19 +68,21 @@ open Ast_defs val opt_mwords : bool ref - + type options = { - auto : bool; (* Analyse ast to find splits for monomorphisation *) - debug_analysis : int; (* Debug output level for the automatic analysis *) + auto : bool; (* Analyse ast to find splits for monomorphisation *) + debug_analysis : int; (* Debug output level for the automatic analysis *) all_split_errors : bool; - continue_anyway : bool + continue_anyway : bool; } val monomorphise : - string -> (* Target backend *) + string -> + (* Target backend *) Effects.side_effect_info -> options -> - ((string * int) * string) list -> (* List of splits from the command line *) + ((string * int) * string) list -> + (* List of splits from the command line *) Type_check.tannot ast -> Type_check.tannot ast diff --git a/src/lib/nl_flow.ml b/src/lib/nl_flow.ml index bff81bb19..30fc1ceb5 100644 --- a/src/lib/nl_flow.ml +++ b/src/lib/nl_flow.ml @@ -77,10 +77,7 @@ let rec escapes (E_aux (aux, _)) = | E_block exps -> escapes (List.hd (List.rev exps)) | _ -> false -let is_bitvector_literal (L_aux (aux, _)) = - match aux with - | L_bin _ | L_hex _ -> true - | _ -> false +let is_bitvector_literal (L_aux (aux, _)) = match aux with L_bin _ | L_hex _ -> true | _ -> false let bitvector_unsigned (L_aux (aux, _)) = let open Sail_lib in @@ -90,11 +87,7 @@ let bitvector_unsigned (L_aux (aux, _)) = | _ -> assert false let rec pat_id (P_aux (aux, _)) = - match aux with - | P_id id -> Some id - | P_as (_, id) -> Some id - | P_var (pat, _) -> pat_id pat - | _ -> None + match aux with P_id id -> Some id | P_as (_, id) -> Some id | P_var (pat, _) -> pat_id pat | _ -> None let add_assert cond (E_aux (aux, (l, uannot)) as exp) = let msg = mk_lit_exp (L_string "") in @@ -107,29 +100,30 @@ let add_assert cond (E_aux (aux, (l, uannot)) as exp) = will also know that y != unsigned(bitv) *) let modify_unsigned id value (E_aux (aux, annot) as exp) = match aux with - | E_let (LB_aux (LB_val (pat, E_aux (E_app (f, [E_aux (E_id id', _)]), _)), _) as lb, exp') - when (string_of_id f = "unsigned" || string_of_id f = "UInt") && Id.compare id id' = 0 -> - begin match pat_id pat with - | None -> exp - | Some uid -> - E_aux (E_let (lb, - add_assert (mk_exp (E_app_infix (mk_exp (E_id uid), mk_id "!=", mk_lit_exp (L_num value)))) exp'), - annot) - end + | E_let ((LB_aux (LB_val (pat, E_aux (E_app (f, [E_aux (E_id id', _)]), _)), _) as lb), exp') + when (string_of_id f = "unsigned" || string_of_id f = "UInt") && Id.compare id id' = 0 -> begin + match pat_id pat with + | None -> exp + | Some uid -> + E_aux + ( E_let + (lb, add_assert (mk_exp (E_app_infix (mk_exp (E_id uid), mk_id "!=", mk_lit_exp (L_num value)))) exp'), + annot + ) + end | _ -> exp let analyze' exps = match exps with - | E_aux (E_if (cond, then_exp, _), _) :: _ when escapes then_exp -> - begin match cond with - | E_aux (E_app_infix (E_aux (E_id id, _), op, E_aux (E_lit lit, _)), _) - | E_aux (E_app_infix (E_aux (E_lit lit, _), op, E_aux (E_id id, _)), _) - when string_of_id op = "==" && is_bitvector_literal lit -> - let value = bitvector_unsigned lit in - List.map (modify_unsigned id value) exps - | _ -> exps - end + | E_aux (E_if (cond, then_exp, _), _) :: _ when escapes then_exp -> begin + match cond with + | E_aux (E_app_infix (E_aux (E_id id, _), op, E_aux (E_lit lit, _)), _) + | E_aux (E_app_infix (E_aux (E_lit lit, _), op, E_aux (E_id id, _)), _) + when string_of_id op = "==" && is_bitvector_literal lit -> + let value = bitvector_unsigned lit in + List.map (modify_unsigned id value) exps + | _ -> exps + end | _ -> exps -let analyze exps = - if !opt_nl_flow then analyze' exps else exps +let analyze exps = if !opt_nl_flow then analyze' exps else exps diff --git a/src/lib/optimize.ml b/src/lib/optimize.ml index 1841fac68..041649c05 100644 --- a/src/lib/optimize.ml +++ b/src/lib/optimize.ml @@ -79,58 +79,59 @@ let rec split_at_function' id defs acc = let split_at_function id defs = match split_at_function' id defs [] with | None -> None - | Some (pre_defs, def, post_defs) -> - Some (List.rev pre_defs, def, post_defs) + | Some (pre_defs, def, post_defs) -> Some (List.rev pre_defs, def, post_defs) -let rec last_env = function - | [] -> Type_check.initial_env - | [(_, env)] -> env - | _ :: xs -> last_env xs - -let recheck ({ defs; _} as ast) = +let rec last_env = function [] -> Type_check.initial_env | [(_, env)] -> env | _ :: xs -> last_env xs + +let recheck ({ defs; _ } as ast) = let defs = Type_check.check_with_envs Type_check.initial_env (List.map Type_check.strip_def defs) in let rec find_optimizations = function - | ([DEF_aux (DEF_pragma ("optimize", pragma, p_l), _)], env) :: ([DEF_aux (DEF_val vs, vs_annot) as def1], _) :: defs -> - let id = id_of_val_spec vs in - let args = Str.split (Str.regexp " +") (String.trim pragma) in - begin match args with - | ["unroll"; n]-> - let n = int_of_string n in - begin match split_at_function id defs with - | Some (intervening_defs, ((DEF_aux (DEF_fundef fdef, fdef_annot) as def2, _)), defs) -> - let rw_app subst (fn, args) = - if Id.compare id fn = 0 then E_app (subst, args) else E_app (fn, args) - in - let rw_exp subst = { id_exp_alg with e_app = rw_app subst } in - let rw_defs subst = { rewriters_base with rewrite_exp = (fun _ -> fold_exp (rw_exp subst)) } in - - let specs = ref [def1] in - let bodies = ref [rewrite_def (rw_defs (append_id id "_unroll_1")) def2] in + | ([DEF_aux (DEF_pragma ("optimize", pragma, p_l), _)], env) + :: ([(DEF_aux (DEF_val vs, vs_annot) as def1)], _) + :: defs -> + let id = id_of_val_spec vs in + let args = Str.split (Str.regexp " +") (String.trim pragma) in + begin + match args with + | ["unroll"; n] -> + let n = int_of_string n in + begin + match split_at_function id defs with + | Some (intervening_defs, ((DEF_aux (DEF_fundef fdef, fdef_annot) as def2), _), defs) -> + let rw_app subst (fn, args) = + if Id.compare id fn = 0 then E_app (subst, args) else E_app (fn, args) + in + let rw_exp subst = { id_exp_alg with e_app = rw_app subst } in + let rw_defs subst = { rewriters_base with rewrite_exp = (fun _ -> fold_exp (rw_exp subst)) } in - for i = 1 to n do - let current_id = append_id id ("_unroll_" ^ string_of_int i) in - let next_id = if i = n then current_id else append_id id ("_unroll_" ^ string_of_int (i + 1)) in - (* Create a valspec for the new unrolled function *) - specs := !specs @ [DEF_aux (DEF_val (rename_valspec current_id vs), vs_annot)]; - (* Then duplicate its function body and make it call the next unrolled function *) - bodies := !bodies @ [rewrite_def (rw_defs next_id) (DEF_aux (DEF_fundef (rename_fundef current_id fdef), fdef_annot))] - done; + let specs = ref [def1] in + let bodies = ref [rewrite_def (rw_defs (append_id id "_unroll_1")) def2] in - !specs @ List.concat (List.map fst intervening_defs) @ !bodies @ find_optimizations defs + for i = 1 to n do + let current_id = append_id id ("_unroll_" ^ string_of_int i) in + let next_id = if i = n then current_id else append_id id ("_unroll_" ^ string_of_int (i + 1)) in + (* Create a valspec for the new unrolled function *) + specs := !specs @ [DEF_aux (DEF_val (rename_valspec current_id vs), vs_annot)]; + (* Then duplicate its function body and make it call the next unrolled function *) + bodies := + !bodies + @ [ + rewrite_def (rw_defs next_id) + (DEF_aux (DEF_fundef (rename_fundef current_id fdef), fdef_annot)); + ] + done; + !specs @ List.concat (List.map fst intervening_defs) @ !bodies @ find_optimizations defs + | _ -> + Reporting.warn "Could not find function body for unroll pragma at " p_l ""; + def1 :: find_optimizations defs + end | _ -> - Reporting.warn "Could not find function body for unroll pragma at " p_l ""; - def1 :: find_optimizations defs - end - | _ -> - Reporting.warn "Unrecognised optimize pragma at" p_l ""; - def1 :: find_optimizations defs - end - - | (defs, _) :: defs' -> - defs @ find_optimizations defs' - + Reporting.warn "Unrecognised optimize pragma at" p_l ""; + def1 :: find_optimizations defs + end + | (defs, _) :: defs' -> defs @ find_optimizations defs' | [] -> [] in diff --git a/src/lib/outcome_rewrites.ml b/src/lib/outcome_rewrites.ml index dfc9a067e..a3762dd16 100644 --- a/src/lib/outcome_rewrites.ml +++ b/src/lib/outcome_rewrites.ml @@ -83,83 +83,102 @@ let rec instantiate_id id = function | IS_aux (IS_id (id_from, id_to), _) :: _ when Id.compare id id_from = 0 -> id_to | _ :: substs -> instantiate_id id substs | [] -> id - + let instantiate_typ substs typ = - List.fold_left (fun typ -> function - | (kid, (_, subst_typ)) -> typ_subst kid (mk_typ_arg (A_typ subst_typ)) typ - ) typ (KBindings.bindings substs) - + List.fold_left + (fun typ -> function kid, (_, subst_typ) -> typ_subst kid (mk_typ_arg (A_typ subst_typ)) typ) + typ (KBindings.bindings substs) + let instantiate_def target id substs = function - | DEF_aux (DEF_impl (FCL_aux (FCL_funcl (target_id, pexp), (fcl_def_annot, tannot))), def_annot) when string_of_id target_id = target -> - let l = gen_loc (id_loc id) in - Some (DEF_aux ( - DEF_fundef (FD_aux (FD_function (Rec_aux (Rec_nonrec, l), - Typ_annot_opt_aux (Typ_annot_opt_none, l), - [FCL_aux (FCL_funcl (id, pexp), (fcl_def_annot, tannot))]), - (l, tannot))), - def_annot)) + | DEF_aux (DEF_impl (FCL_aux (FCL_funcl (target_id, pexp), (fcl_def_annot, tannot))), def_annot) + when string_of_id target_id = target -> + let l = gen_loc (id_loc id) in + Some + (DEF_aux + ( DEF_fundef + (FD_aux + ( FD_function + ( Rec_aux (Rec_nonrec, l), + Typ_annot_opt_aux (Typ_annot_opt_none, l), + [FCL_aux (FCL_funcl (id, pexp), (fcl_def_annot, tannot))] + ), + (l, tannot) + ) + ), + def_annot + ) + ) | def -> None let rec instantiated_or_abstract l = function | [] -> None | None :: xs -> instantiated_or_abstract l xs | Some def :: xs -> - if List.for_all Option.is_none xs then - Some def - else - raise (Reporting.err_general l "Multiple instantiations found for target") - + if List.for_all Option.is_none xs then Some def + else raise (Reporting.err_general l "Multiple instantiations found for target") + let instantiate target ast = let process_def outcomes = function - | DEF_aux (DEF_outcome (OV_aux (OV_outcome (id, TypSchm_aux (TypSchm_ts (typq, typ), _), args), l), outcome_defs), _) -> - Bindings.add id (typq, typ, args, l, outcome_defs) outcomes, [] - + | DEF_aux (DEF_outcome (OV_aux (OV_outcome (id, TypSchm_aux (TypSchm_ts (typq, typ), _), args), l), outcome_defs), _) + -> + (Bindings.add id (typq, typ, args, l, outcome_defs) outcomes, []) | DEF_aux (DEF_instantiation (IN_aux (IN_id id, annot), id_substs), def_annot) -> - let l = gen_loc (id_loc id) in - let env = env_of_annot annot in - let substs = Env.get_outcome_instantiation env in - let (typq, typ, args, outcome_l, outcome_defs) = match Bindings.find_opt id outcomes with - | Some outcome -> outcome - | None -> Reporting.unreachable (id_loc id) __POS__ ("Outcome for instantiation " ^ string_of_id id ^ " does not exist") - in + let l = gen_loc (id_loc id) in + let env = env_of_annot annot in + let substs = Env.get_outcome_instantiation env in + let typq, typ, args, outcome_l, outcome_defs = + match Bindings.find_opt id outcomes with + | Some outcome -> outcome + | None -> + Reporting.unreachable (id_loc id) __POS__ + ("Outcome for instantiation " ^ string_of_id id ^ " does not exist") + in - let rewrite_p_aux (pat, annot) = - match pat with - | P_typ (typ, pat) -> P_aux (P_typ (instantiate_typ substs typ, pat), annot) - | pat -> P_aux (pat, annot) - in - let rewrite_e_aux (exp, annot) = - match exp with - | E_app (f, args) -> E_aux (E_app (instantiate_id f id_substs, args), annot) - | E_typ (typ, exp) -> E_aux (E_typ (instantiate_typ substs typ, exp), annot) - | _ -> E_aux (exp, annot) - in - let pat_alg = { id_pat_alg with p_aux = rewrite_p_aux } in - let rewrite_pat rw pat = - fold_pat pat_alg pat - in - let rewrite_exp _ exp = - fold_exp { id_exp_alg with e_aux = rewrite_e_aux; pat_alg = pat_alg } exp - in + let rewrite_p_aux (pat, annot) = + match pat with + | P_typ (typ, pat) -> P_aux (P_typ (instantiate_typ substs typ, pat), annot) + | pat -> P_aux (pat, annot) + in + let rewrite_e_aux (exp, annot) = + match exp with + | E_app (f, args) -> E_aux (E_app (instantiate_id f id_substs, args), annot) + | E_typ (typ, exp) -> E_aux (E_typ (instantiate_typ substs typ, exp), annot) + | _ -> E_aux (exp, annot) + in + let pat_alg = { id_pat_alg with p_aux = rewrite_p_aux } in + let rewrite_pat rw pat = fold_pat pat_alg pat in + let rewrite_exp _ exp = fold_exp { id_exp_alg with e_aux = rewrite_e_aux; pat_alg } exp in - let valspec is_extern = - let extern = if is_extern then Some { pure = false; bindings = [("_", string_of_id id)] } else None in - DEF_aux (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, instantiate_typ substs typ), l), id, extern, false), (l, empty_uannot))), def_annot) - in - let instantiated_def = - rewrite_ast_defs { rewriters_base with rewrite_pat = rewrite_pat; rewrite_exp = rewrite_exp } outcome_defs - |> List.map (instantiate_def target id id_substs) - |> instantiated_or_abstract (id_loc id) - in - let outcome_defs, _ = - (match instantiated_def with - | None -> [DEF_aux (DEF_pragma ("abstract", string_of_id id, gen_loc (id_loc id)), mk_def_annot (gen_loc (id_loc id))); valspec true] + let valspec is_extern = + let extern = if is_extern then Some { pure = false; bindings = [("_", string_of_id id)] } else None in + DEF_aux + ( DEF_val + (VS_aux + ( VS_val_spec (TypSchm_aux (TypSchm_ts (typq, instantiate_typ substs typ), l), id, extern, false), + (l, empty_uannot) + ) + ), + def_annot + ) + in + let instantiated_def = + rewrite_ast_defs { rewriters_base with rewrite_pat; rewrite_exp } outcome_defs + |> List.map (instantiate_def target id id_substs) + |> instantiated_or_abstract (id_loc id) + in + let outcome_defs, _ = + ( match instantiated_def with + | None -> + [ + DEF_aux + (DEF_pragma ("abstract", string_of_id id, gen_loc (id_loc id)), mk_def_annot (gen_loc (id_loc id))); + valspec true; + ] | Some def -> [valspec false; strip_def def] - ) |> Type_error.check_defs env - in - outcomes, outcome_defs - - | def -> - outcomes, [def] + ) + |> Type_error.check_defs env + in + (outcomes, outcome_defs) + | def -> (outcomes, [def]) in { ast with defs = snd (Util.fold_left_concat_map process_def Bindings.empty ast.defs) } diff --git a/src/lib/parse_ast.ml b/src/lib/parse_ast.ml index 7aa60dd50..605ecd2b8 100644 --- a/src/lib/parse_ast.ml +++ b/src/lib/parse_ast.ml @@ -87,470 +87,372 @@ exception Parse_error_locn of l * string type x = text (* identifier *) type ix = text (* infix identifier *) - -type -kind_aux = (* base kind *) - K_type (* kind of types *) - | K_int (* kind of natural number size expressions *) - | K_order (* kind of vector order specifications *) - | K_bool (* kind of constraints *) - - -type -kind = - K_aux of kind_aux * l - - -type -base_effect_aux = (* effect *) - BE_rreg (* read register *) - | BE_wreg (* write register *) - | BE_rmem (* read memory *) - | BE_wmem (* write memory *) - | BE_wmv (* write memory value *) - | BE_eamem (* address for write signaled *) - | BE_exmem (* determine if a store-exclusive (ARM) is going to succeed *) - | BE_barr (* memory barrier *) - | BE_depend (* dynamically dependent footprint *) - | BE_undef (* undefined-instruction exception *) - | BE_unspec (* unspecified values *) - | BE_nondet (* nondeterminism from intra-instruction parallelism *) - | BE_escape - | BE_config - -type -kid_aux = (* identifiers with kind, ticked to differentiate from program variables *) - Var of x - - -type -id_aux = (* Identifier *) - Id of x - | Operator of x (* remove infix status *) - -type -base_effect = - BE_aux of base_effect_aux * l - - -type -kid = - Kid_aux of kid_aux * l - - -type -id = - Id_aux of id_aux * l - -type -lit_aux = (* Literal constant *) - L_unit (* $() : _$ *) - | L_zero (* $_ : _$ *) - | L_one (* $_ : _$ *) - | L_true (* $_ : _$ *) - | L_false (* $_ : _$ *) - | L_num of Big_int.num (* natural number constant *) - | L_hex of string (* bit vector constant, C-style *) - | L_bin of string (* bit vector constant, C-style *) - | L_undef (* undefined value *) - | L_string of string (* string constant *) - | L_real of string - -type -lit = - L_aux of lit_aux * l - -type -atyp_aux = (* expressions of all kinds, to be translated to types, nats, orders, and effects after parsing *) - ATyp_id of id (* identifier *) - | ATyp_var of kid (* ticked variable *) - | ATyp_lit of lit (* literal *) - | ATyp_nset of kid * (Big_int.num) list (* set type *) - | ATyp_times of atyp * atyp (* product *) - | ATyp_sum of atyp * atyp (* sum *) - | ATyp_minus of atyp * atyp (* subtraction *) - | ATyp_exp of atyp (* exponential *) - | ATyp_neg of atyp (* Internal (but not M as I want a datatype constructor) negative nexp *) - | ATyp_inc (* increasing *) - | ATyp_dec (* decreasing *) - | ATyp_set of (base_effect) list (* effect set *) - | ATyp_fn of atyp * atyp * atyp (* Function type, last atyp is an effect *) - | ATyp_bidir of atyp * atyp * atyp (* Mapping type, last atyp is an effect *) - | ATyp_wild - | ATyp_tuple of (atyp) list (* Tuple type *) - | ATyp_app of id * (atyp) list (* type constructor application *) - | ATyp_exist of kinded_id list * atyp * atyp - -and atyp = - ATyp_aux of atyp_aux * l - - -and -kinded_id_aux = (* optionally kind-annotated identifier *) - KOpt_kind of string option * kid list * kind option (* kind-annotated variable *) - -and -kinded_id = - KOpt_aux of kinded_id_aux * l - -type -quant_item_aux = (* Either a kinded identifier or a nexp constraint for a typquant *) - QI_id of kinded_id (* An optionally kinded identifier *) - | QI_constraint of atyp (* A constraint for this type *) - - -type -quant_item = - QI_aux of quant_item_aux * l - - -type -typquant_aux = (* type quantifiers and constraints *) - TypQ_tq of (quant_item) list - | TypQ_no_forall (* sugar, omitting quantifier and constraints *) - - -type -typquant = - TypQ_aux of typquant_aux * l - -type -typschm_aux = (* type scheme *) - TypSchm_ts of typquant * atyp - - -type -typschm = - TypSchm_aux of typschm_aux * l - - -type -pat_aux = (* Pattern *) - P_lit of lit (* literal constant pattern *) - | P_wild (* wildcard - always matches *) - | P_typ of atyp * pat (* typed pattern *) - | P_id of id (* identifier *) - | P_var of pat * atyp (* bind pat to type variable *) - | P_app of id * pat list (* union constructor pattern *) - | P_vector of pat list (* vector pattern *) - | P_vector_concat of pat list (* concatenated vector pattern *) - | P_vector_subrange of id * Big_int.num * Big_int.num - | P_tuple of pat list (* tuple pattern *) - | P_list of pat list (* list pattern *) - | P_cons of pat * pat (* cons pattern *) - | P_string_append of pat list (* string append pattern, x ^^ y *) - | P_attribute of string * string * pat - -and pat = - P_aux of pat_aux * l - -and fpat_aux = (* Field pattern *) - FP_Fpat of id * pat - -and fpat = - FP_aux of fpat_aux * l +type kind_aux = + | (* base kind *) + K_type (* kind of types *) + | K_int (* kind of natural number size expressions *) + | K_order (* kind of vector order specifications *) + | K_bool (* kind of constraints *) + +type kind = K_aux of kind_aux * l + +type base_effect_aux = + | (* effect *) + BE_rreg (* read register *) + | BE_wreg (* write register *) + | BE_rmem (* read memory *) + | BE_wmem (* write memory *) + | BE_wmv (* write memory value *) + | BE_eamem (* address for write signaled *) + | BE_exmem (* determine if a store-exclusive (ARM) is going to succeed *) + | BE_barr (* memory barrier *) + | BE_depend (* dynamically dependent footprint *) + | BE_undef (* undefined-instruction exception *) + | BE_unspec (* unspecified values *) + | BE_nondet (* nondeterminism from intra-instruction parallelism *) + | BE_escape + | BE_config + +type kid_aux = (* identifiers with kind, ticked to differentiate from program variables *) + | Var of x + +type id_aux = (* Identifier *) + | Id of x | Operator of x (* remove infix status *) + +type base_effect = BE_aux of base_effect_aux * l + +type kid = Kid_aux of kid_aux * l + +type id = Id_aux of id_aux * l + +type lit_aux = + | (* Literal constant *) + L_unit (* $() : _$ *) + | L_zero (* $_ : _$ *) + | L_one (* $_ : _$ *) + | L_true (* $_ : _$ *) + | L_false (* $_ : _$ *) + | L_num of Big_int.num (* natural number constant *) + | L_hex of string (* bit vector constant, C-style *) + | L_bin of string (* bit vector constant, C-style *) + | L_undef (* undefined value *) + | L_string of string (* string constant *) + | L_real of string + +type lit = L_aux of lit_aux * l + +type atyp_aux = + (* expressions of all kinds, to be translated to types, nats, orders, and effects after parsing *) + | ATyp_id of id (* identifier *) + | ATyp_var of kid (* ticked variable *) + | ATyp_lit of lit (* literal *) + | ATyp_nset of kid * Big_int.num list (* set type *) + | ATyp_times of atyp * atyp (* product *) + | ATyp_sum of atyp * atyp (* sum *) + | ATyp_minus of atyp * atyp (* subtraction *) + | ATyp_exp of atyp (* exponential *) + | ATyp_neg of atyp (* Internal (but not M as I want a datatype constructor) negative nexp *) + | ATyp_inc (* increasing *) + | ATyp_dec (* decreasing *) + | ATyp_set of base_effect list (* effect set *) + | ATyp_fn of atyp * atyp * atyp (* Function type, last atyp is an effect *) + | ATyp_bidir of atyp * atyp * atyp (* Mapping type, last atyp is an effect *) + | ATyp_wild + | ATyp_tuple of atyp list (* Tuple type *) + | ATyp_app of id * atyp list (* type constructor application *) + | ATyp_exist of kinded_id list * atyp * atyp + +and atyp = ATyp_aux of atyp_aux * l + +and kinded_id_aux = + (* optionally kind-annotated identifier *) + | KOpt_kind of string option * kid list * kind option (* kind-annotated variable *) + +and kinded_id = KOpt_aux of kinded_id_aux * l + +type quant_item_aux = + (* Either a kinded identifier or a nexp constraint for a typquant *) + | QI_id of kinded_id (* An optionally kinded identifier *) + | QI_constraint of atyp (* A constraint for this type *) + +type quant_item = QI_aux of quant_item_aux * l + +type typquant_aux = + (* type quantifiers and constraints *) + | TypQ_tq of quant_item list + | TypQ_no_forall (* sugar, omitting quantifier and constraints *) + +type typquant = TypQ_aux of typquant_aux * l + +type typschm_aux = (* type scheme *) + | TypSchm_ts of typquant * atyp + +type typschm = TypSchm_aux of typschm_aux * l + +type pat_aux = + (* Pattern *) + | P_lit of lit (* literal constant pattern *) + | P_wild (* wildcard - always matches *) + | P_typ of atyp * pat (* typed pattern *) + | P_id of id (* identifier *) + | P_var of pat * atyp (* bind pat to type variable *) + | P_app of id * pat list (* union constructor pattern *) + | P_vector of pat list (* vector pattern *) + | P_vector_concat of pat list (* concatenated vector pattern *) + | P_vector_subrange of id * Big_int.num * Big_int.num + | P_tuple of pat list (* tuple pattern *) + | P_list of pat list (* list pattern *) + | P_cons of pat * pat (* cons pattern *) + | P_string_append of pat list (* string append pattern, x ^^ y *) + | P_attribute of string * string * pat + +and pat = P_aux of pat_aux * l + +and fpat_aux = (* Field pattern *) + | FP_Fpat of id * pat + +and fpat = FP_aux of fpat_aux * l type loop = While | Until -type measure_aux = (* optional termination measure for a loop *) - | Measure_none - | Measure_some of exp - -and measure = - | Measure_aux of measure_aux * l - -and -exp_aux = (* Expression *) - E_block of (exp) list (* block (parsing conflict with structs?) *) - | E_id of id (* identifier *) - | E_ref of id - | E_deref of exp - | E_lit of lit (* literal constant *) - | E_typ of atyp * exp (* cast *) - | E_app of id * exp list (* function application *) - | E_app_infix of exp * id * exp (* infix function application *) - | E_tuple of exp list (* tuple *) - | E_if of exp * exp * exp (* conditional *) - | E_loop of loop * measure * exp * exp - | E_for of id * exp * exp * exp * atyp * exp (* loop *) - | E_vector of (exp) list (* vector (indexed from 0) *) - | E_vector_access of exp * exp (* vector access *) - | E_vector_subrange of exp * exp * exp (* subvector extraction *) - | E_vector_update of exp * exp * exp (* vector functional update *) - | E_vector_update_subrange of exp * exp * exp * exp (* vector subrange update (with vector) *) - | E_vector_append of exp * exp (* vector concatenation *) - | E_list of (exp) list (* list *) - | E_cons of exp * exp (* cons *) - | E_struct of exp list (* struct *) - | E_struct_update of exp * (exp) list (* functional update of struct *) - | E_field of exp * id (* field projection from struct *) - | E_match of exp * (pexp) list (* pattern matching *) - | E_let of letbind * exp (* let expression *) - | E_assign of exp * exp (* imperative assignment *) - | E_sizeof of atyp - | E_constraint of atyp - | E_exit of exp - | E_throw of exp - | E_try of exp * pexp list - | E_return of exp - | E_assert of exp * exp - | E_var of exp * exp * exp - | E_attribute of string * string * exp - | E_internal_plet of pat * exp * exp - | E_internal_return of exp - -and exp = - E_aux of exp_aux * l - -and opt_default_aux = (* Optional default value for indexed vectors, to define a default value for any unspecified positions in a sparse map *) - Def_val_empty - | Def_val_dec of exp - -and opt_default = - Def_val_aux of opt_default_aux * l - -and pexp_aux = (* Pattern match *) - Pat_exp of pat * exp - | Pat_when of pat * exp * exp - -and pexp = - Pat_aux of pexp_aux * l - -and letbind_aux = (* Let binding *) - LB_val of pat * exp (* value binding, implicit type (pat must be total) *) - -and letbind = - LB_aux of letbind_aux * l - - -type -tannot_opt_aux = (* Optional type annotation for functions *) - Typ_annot_opt_none - | Typ_annot_opt_some of typquant * atyp - -type -typschm_opt_aux = - TypSchm_opt_none -| TypSchm_opt_some of typschm - -type -typschm_opt = - TypSchm_opt_aux of typschm_opt_aux * l - -type -effect_opt_aux = (* Optional effect annotation for functions *) - Effect_opt_none (* sugar for empty effect set *) - | Effect_opt_effect of atyp - - -type -rec_opt_aux = (* Optional recursive annotation for functions *) - Rec_none (* no termination measure *) - | Rec_measure of pat * exp (* recursive with termination measure *) - - -type -funcl_aux = (* Function clause *) - FCL_funcl of id * pexp - - -type -type_union_aux = (* Type union constructors *) - Tu_ty_id of atyp * id - | Tu_ty_anon_rec of (atyp * id) list * id - -type -tannot_opt = - Typ_annot_opt_aux of tannot_opt_aux * l - - -type -effect_opt = - Effect_opt_aux of effect_opt_aux * l - - -type -rec_opt = - Rec_aux of rec_opt_aux * l - - -type -funcl = - FCL_aux of funcl_aux * l - -type -type_union = - Tu_aux of type_union_aux * l - -type subst_aux = (* instantiation substitution *) - | IS_typ of kid * atyp (* instantiate a type variable with a type *) - | IS_id of id * id (* instantiate an identifier with another identifier *) - -type subst = - | IS_aux of subst_aux * l - -type -index_range_aux = (* index specification, for bitfields in register types *) - BF_single of atyp (* single index *) - | BF_range of atyp * atyp (* index range *) - | BF_concat of index_range * index_range (* concatenation of index ranges *) +type measure_aux = (* optional termination measure for a loop *) + | Measure_none | Measure_some of exp + +and measure = Measure_aux of measure_aux * l + +and exp_aux = + (* Expression *) + | E_block of exp list (* block (parsing conflict with structs?) *) + | E_id of id (* identifier *) + | E_ref of id + | E_deref of exp + | E_lit of lit (* literal constant *) + | E_typ of atyp * exp (* cast *) + | E_app of id * exp list (* function application *) + | E_app_infix of exp * id * exp (* infix function application *) + | E_tuple of exp list (* tuple *) + | E_if of exp * exp * exp (* conditional *) + | E_loop of loop * measure * exp * exp + | E_for of id * exp * exp * exp * atyp * exp (* loop *) + | E_vector of exp list (* vector (indexed from 0) *) + | E_vector_access of exp * exp (* vector access *) + | E_vector_subrange of exp * exp * exp (* subvector extraction *) + | E_vector_update of exp * exp * exp (* vector functional update *) + | E_vector_update_subrange of exp * exp * exp * exp (* vector subrange update (with vector) *) + | E_vector_append of exp * exp (* vector concatenation *) + | E_list of exp list (* list *) + | E_cons of exp * exp (* cons *) + | E_struct of exp list (* struct *) + | E_struct_update of exp * exp list (* functional update of struct *) + | E_field of exp * id (* field projection from struct *) + | E_match of exp * pexp list (* pattern matching *) + | E_let of letbind * exp (* let expression *) + | E_assign of exp * exp (* imperative assignment *) + | E_sizeof of atyp + | E_constraint of atyp + | E_exit of exp + | E_throw of exp + | E_try of exp * pexp list + | E_return of exp + | E_assert of exp * exp + | E_var of exp * exp * exp + | E_attribute of string * string * exp + | E_internal_plet of pat * exp * exp + | E_internal_return of exp + +and exp = E_aux of exp_aux * l + +and opt_default_aux = + | (* Optional default value for indexed vectors, to define a default value for any unspecified positions in a sparse map *) + Def_val_empty + | Def_val_dec of exp + +and opt_default = Def_val_aux of opt_default_aux * l + +and pexp_aux = (* Pattern match *) + | Pat_exp of pat * exp | Pat_when of pat * exp * exp + +and pexp = Pat_aux of pexp_aux * l + +and letbind_aux = (* Let binding *) + | LB_val of pat * exp (* value binding, implicit type (pat must be total) *) + +and letbind = LB_aux of letbind_aux * l + +type tannot_opt_aux = + | (* Optional type annotation for functions *) + Typ_annot_opt_none + | Typ_annot_opt_some of typquant * atyp + +type typschm_opt_aux = TypSchm_opt_none | TypSchm_opt_some of typschm + +type typschm_opt = TypSchm_opt_aux of typschm_opt_aux * l + +type effect_opt_aux = + | (* Optional effect annotation for functions *) + Effect_opt_none (* sugar for empty effect set *) + | Effect_opt_effect of atyp + +type rec_opt_aux = + | (* Optional recursive annotation for functions *) + Rec_none (* no termination measure *) + | Rec_measure of pat * exp (* recursive with termination measure *) + +type funcl_aux = (* Function clause *) + | FCL_funcl of id * pexp + +type type_union_aux = (* Type union constructors *) + | Tu_ty_id of atyp * id | Tu_ty_anon_rec of (atyp * id) list * id + +type tannot_opt = Typ_annot_opt_aux of tannot_opt_aux * l + +type effect_opt = Effect_opt_aux of effect_opt_aux * l + +type rec_opt = Rec_aux of rec_opt_aux * l + +type funcl = FCL_aux of funcl_aux * l + +type type_union = Tu_aux of type_union_aux * l + +type subst_aux = + (* instantiation substitution *) + | IS_typ of kid * atyp (* instantiate a type variable with a type *) + | IS_id of id * id (* instantiate an identifier with another identifier *) + +type subst = IS_aux of subst_aux * l + +type index_range_aux = + (* index specification, for bitfields in register types *) + | BF_single of atyp (* single index *) + | BF_range of atyp * atyp (* index range *) + | BF_concat of index_range * index_range (* concatenation of index ranges *) + +and index_range = BF_aux of index_range_aux * l + +type default_typing_spec_aux = + (* Default kinding or typing assumption, and default order for literal vectors and vector shorthands *) + | DT_order of kind * atyp + +type mpat_aux = + (* Mapping pattern. Mostly the same as normal patterns but only constructible parts *) + | MP_lit of lit + | MP_id of id + | MP_app of id * mpat list + | MP_vector of mpat list + | MP_vector_concat of mpat list + | MP_vector_subrange of id * Big_int.num * Big_int.num + | MP_tuple of mpat list + | MP_list of mpat list + | MP_cons of mpat * mpat + | MP_string_append of mpat list + | MP_typ of mpat * atyp + | MP_as of mpat * id -and index_range = - BF_aux of index_range_aux * l +and mpat = MP_aux of mpat_aux * l -type -default_typing_spec_aux = (* Default kinding or typing assumption, and default order for literal vectors and vector shorthands *) - DT_order of kind * atyp +type mpexp_aux = MPat_pat of mpat | MPat_when of mpat * exp +type mpexp = MPat_aux of mpexp_aux * l -type mpat_aux = (* Mapping pattern. Mostly the same as normal patterns but only constructible parts *) - | MP_lit of lit - | MP_id of id - | MP_app of id * ( mpat) list - | MP_vector of ( mpat) list - | MP_vector_concat of ( mpat) list - | MP_vector_subrange of id * Big_int.num * Big_int.num - | MP_tuple of ( mpat) list - | MP_list of ( mpat) list - | MP_cons of ( mpat) * ( mpat) - | MP_string_append of mpat list - | MP_typ of mpat * atyp - | MP_as of mpat * id - -and mpat = - | MP_aux of ( mpat_aux) * l - -type mpexp_aux = - | MPat_pat of ( mpat) - | MPat_when of ( mpat) * ( exp) - -type mpexp = - | MPat_aux of ( mpexp_aux) * l - -type mapcl_aux = (* mapping clause (bidirectional pattern-match) *) - | MCL_bidir of ( mpexp) * ( mpexp) +type mapcl_aux = + (* mapping clause (bidirectional pattern-match) *) + | MCL_bidir of mpexp * mpexp | MCL_forwards of mpexp * exp | MCL_backwards of mpexp * exp -type mapcl = - | MCL_aux of ( mapcl_aux) * l - -type mapdef_aux = (* mapping definition (bidirectional pattern-match function) *) - | MD_mapping of id * typschm_opt * ( mapcl) list - -type mapdef = - | MD_aux of ( mapdef_aux) * l +type mapcl = MCL_aux of mapcl_aux * l -type outcome_spec_aux = (* outcome declaration *) - | OV_outcome of id * typschm * kinded_id list +type mapdef_aux = + (* mapping definition (bidirectional pattern-match function) *) + | MD_mapping of id * typschm_opt * mapcl list -type outcome_spec = - | OV_aux of outcome_spec_aux * l +type mapdef = MD_aux of mapdef_aux * l -type -fundef_aux = (* Function definition *) - FD_function of rec_opt * tannot_opt * effect_opt * (funcl) list +type outcome_spec_aux = (* outcome declaration *) + | OV_outcome of id * typschm * kinded_id list -type -type_def_aux = (* Type definition body *) - TD_abbrev of id * typquant * kind * atyp (* type abbreviation *) - | TD_record of id * typquant * ((atyp * id)) list * bool (* struct type definition *) - | TD_variant of id * typquant * (type_union) list * bool (* union type definition *) - | TD_enum of id * (id * atyp) list * (id * exp option) list * bool (* enumeration type definition *) - | TD_bitfield of id * atyp * (id * index_range) list (* register mutable bitfield type definition *) +type outcome_spec = OV_aux of outcome_spec_aux * l -type -val_spec_aux = (* Value type specification *) - VS_val_spec of typschm * id * extern option * bool +type fundef_aux = (* Function definition *) + | FD_function of rec_opt * tannot_opt * effect_opt * funcl list -type -dec_spec_aux = (* Register declarations *) - DEC_reg of atyp * id * exp option +type type_def_aux = + (* Type definition body *) + | TD_abbrev of id * typquant * kind * atyp (* type abbreviation *) + | TD_record of id * typquant * (atyp * id) list * bool (* struct type definition *) + | TD_variant of id * typquant * type_union list * bool (* union type definition *) + | TD_enum of id * (id * atyp) list * (id * exp option) list * bool (* enumeration type definition *) + | TD_bitfield of id * atyp * (id * index_range) list (* register mutable bitfield type definition *) -type -scattered_def_aux = (* Function and type union definitions that can be spread across - a file. Each one must end in $_$ *) - SD_function of rec_opt * tannot_opt * effect_opt * id (* scattered function definition header *) - | SD_funcl of funcl (* scattered function definition clause *) - | SD_variant of id * typquant (* scattered union definition header *) - | SD_unioncl of id * type_union (* scattered union definition member *) - | SD_mapping of id * tannot_opt - | SD_mapcl of id * mapcl - | SD_end of id (* scattered definition end *) +type val_spec_aux = (* Value type specification *) + | VS_val_spec of typschm * id * extern option * bool +type dec_spec_aux = (* Register declarations *) + | DEC_reg of atyp * id * exp option -type -default_typing_spec = - DT_aux of default_typing_spec_aux * l +type scattered_def_aux = + (* Function and type union definitions that can be spread across + a file. Each one must end in $_$ *) + | SD_function of rec_opt * tannot_opt * effect_opt * id (* scattered function definition header *) + | SD_funcl of funcl (* scattered function definition clause *) + | SD_variant of id * typquant (* scattered union definition header *) + | SD_unioncl of id * type_union (* scattered union definition member *) + | SD_mapping of id * tannot_opt + | SD_mapcl of id * mapcl + | SD_end of id (* scattered definition end *) +type default_typing_spec = DT_aux of default_typing_spec_aux * l -type -fundef = - FD_aux of fundef_aux * l +type fundef = FD_aux of fundef_aux * l +type type_def = TD_aux of type_def_aux * l -type -type_def = - TD_aux of type_def_aux * l +type val_spec = VS_aux of val_spec_aux * l +type dec_spec = DEC_aux of dec_spec_aux * l -type -val_spec = - VS_aux of val_spec_aux * l +type loop_measure = Loop of loop * exp -type -dec_spec = - DEC_aux of dec_spec_aux * l - - -type loop_measure = - | Loop of loop * exp - - -type -scattered_def = - SD_aux of scattered_def_aux * l +type scattered_def = SD_aux of scattered_def_aux * l type prec = Infix | InfixL | InfixR -type fixity_token = (prec * Big_int.num * string) - -type def_aux = (* Top-level definition *) - DEF_type of type_def (* type definition *) - | DEF_fundef of fundef (* function definition *) - | DEF_mapdef of mapdef (* mapping definition *) - | DEF_impl of funcl (* impl definition *) - | DEF_let of letbind (* value definition *) - | DEF_overload of id * id list (* operator overload specifications *) - | DEF_fixity of prec * Big_int.num * id (* fixity declaration *) - | DEF_val of val_spec (* top-level type constraint *) - | DEF_outcome of outcome_spec * def list (* top-level outcome definition *) - | DEF_instantiation of id * subst list (* instantiation *) - | DEF_default of default_typing_spec (* default kind and type assumptions *) - | DEF_scattered of scattered_def (* scattered definition *) - | DEF_measure of id * pat * exp (* separate termination measure declaration *) - | DEF_loop_measures of id * loop_measure list (* separate termination measure declaration *) - | DEF_register of dec_spec (* register declaration *) - | DEF_pragma of string * string - | DEF_attribute of string * string * def - | DEF_doc of string * def - | DEF_internal_mutrec of fundef list +type fixity_token = prec * Big_int.num * string + +type def_aux = + (* Top-level definition *) + | DEF_type of type_def (* type definition *) + | DEF_fundef of fundef (* function definition *) + | DEF_mapdef of mapdef (* mapping definition *) + | DEF_impl of funcl (* impl definition *) + | DEF_let of letbind (* value definition *) + | DEF_overload of id * id list (* operator overload specifications *) + | DEF_fixity of prec * Big_int.num * id (* fixity declaration *) + | DEF_val of val_spec (* top-level type constraint *) + | DEF_outcome of outcome_spec * def list (* top-level outcome definition *) + | DEF_instantiation of id * subst list (* instantiation *) + | DEF_default of default_typing_spec (* default kind and type assumptions *) + | DEF_scattered of scattered_def (* scattered definition *) + | DEF_measure of id * pat * exp (* separate termination measure declaration *) + | DEF_loop_measures of id * loop_measure list (* separate termination measure declaration *) + | DEF_register of dec_spec (* register declaration *) + | DEF_pragma of string * string + | DEF_attribute of string * string * def + | DEF_doc of string * def + | DEF_internal_mutrec of fundef list and def = DEF_aux of def_aux * l -type -lexp_aux = (* lvalue expression, can't occur out of the parser *) - LE_id of id (* identifier *) - | LE_mem of id * (exp) list - | LE_vector of lexp * exp (* vector element *) - | LE_vector_range of lexp * exp * exp (* subvector *) - | LE_vector_concat of lexp list - | LE_field of lexp * id (* struct field *) - -and lexp = - LE_aux of lexp_aux * l +type lexp_aux = + (* lvalue expression, can't occur out of the parser *) + | LE_id of id (* identifier *) + | LE_mem of id * exp list + | LE_vector of lexp * exp (* vector element *) + | LE_vector_range of lexp * exp * exp (* subvector *) + | LE_vector_concat of lexp list + | LE_field of lexp * id (* struct field *) +and lexp = LE_aux of lexp_aux * l -type -defs = (* Definition sequence *) - Defs of (string * def list) list +type defs = (* Definition sequence *) + | Defs of (string * def list) list diff --git a/src/lib/parser_combinators.ml b/src/lib/parser_combinators.ml index 0facf9ac6..b321109b0 100644 --- a/src/lib/parser_combinators.ml +++ b/src/lib/parser_combinators.ml @@ -65,55 +65,29 @@ (* SUCH DAMAGE. *) (****************************************************************************) -type 'a parse_result = - | Ok of 'a * Str.split_result list - | Fail +type 'a parse_result = Ok of 'a * Str.split_result list | Fail type 'a parser = Str.split_result list -> 'a parse_result -let (>>=) (m : 'a parser) (f : 'a -> 'b parser) (toks : Str.split_result list) = - match m toks with - | Ok (r, toks) -> f r toks - | Fail -> Fail +let ( >>= ) (m : 'a parser) (f : 'a -> 'b parser) (toks : Str.split_result list) = + match m toks with Ok (r, toks) -> f r toks | Fail -> Fail -let pmap f m toks = - match m toks with - | Ok (r, toks) -> Ok (f r, toks) - | Fail -> Fail +let pmap f m toks = match m toks with Ok (r, toks) -> Ok (f r, toks) | Fail -> Fail -let token f = function - | tok :: toks -> - begin match f tok with - | Some x -> Ok (x, toks) - | None -> Fail - end - | [] -> Fail +let token f = function tok :: toks -> begin match f tok with Some x -> Ok (x, toks) | None -> Fail end | [] -> Fail let preturn x toks = Ok (x, toks) let rec plist m toks = match m toks with - | Ok (x, toks) -> - begin match plist m toks with - | Ok (xs, toks) -> Ok (x :: xs, toks) - | Fail -> Fail - end + | Ok (x, toks) -> begin match plist m toks with Ok (xs, toks) -> Ok (x :: xs, toks) | Fail -> Fail end | Fail -> Ok ([], toks) -let pchoose m n toks = - match m toks with - | Fail -> n toks - | Ok (x, toks) -> Ok (x, toks) +let pchoose m n toks = match m toks with Fail -> n toks | Ok (x, toks) -> Ok (x, toks) let parse p delim_regexp input = let delim = Str.regexp delim_regexp in let tokens = Str.full_split delim input in - let non_whitespace = function - | Str.Delim d when String.trim d = "" -> false - | _ -> true - in + let non_whitespace = function Str.Delim d when String.trim d = "" -> false | _ -> true in let tokens = List.filter non_whitespace tokens in - match p tokens with - | Ok (result, []) -> Some result - | Ok (_, _) -> None - | Fail -> None + match p tokens with Ok (result, []) -> Some result | Ok (_, _) -> None | Fail -> None diff --git a/src/lib/pattern_completeness.ml b/src/lib/pattern_completeness.ml index 9ec227089..a548ed38a 100644 --- a/src/lib/pattern_completeness.ml +++ b/src/lib/pattern_completeness.ml @@ -75,45 +75,52 @@ module IntIntSet = Util.IntIntSet let opt_debug_no_literals = ref false type ctx = { - variants : (typquant * type_union list) Bindings.t; - enums : IdSet.t Bindings.t; - constraints : n_constraint list; - } - -module type Config = - sig - type t - val typ_of_t : t -> typ - val add_attribute : l -> string -> string -> t -> t - end + variants : (typquant * type_union list) Bindings.t; + enums : IdSet.t Bindings.t; + constraints : n_constraint list; +} + +module type Config = sig + type t + val typ_of_t : t -> typ + val add_attribute : l -> string -> string -> t -> t +end -type row_index = { - loc: l; - num: int - } +type row_index = { loc : l; num : int } -type 'a rows = Rows of ((row_index * 'a) list) -type 'a columns = Columns of ('a list) +type 'a rows = Rows of (row_index * 'a) list +type 'a columns = Columns of 'a list -type column_type = Tuple_column of int | App_column of id | Bool_column | Enum_column of id | Lit_column | List_column | Unknown_column +type column_type = + | Tuple_column of int + | App_column of id + | Bool_column + | Enum_column of id + | Lit_column + | List_column + | Unknown_column type complete_info = { - (* As we check completeness, we check submatrices which correspond to a subset of rows in the overall case statement *) - rows: IntSet.t; - (* These literal patterns can be turned into wildcards, as row index number * pattern number pairs *) - wildcards: IntIntSet.t; - (* Wildcards we must keep because they cannot be removed in all submatrices *) - preserved_literals: IntIntSet.t; - (* These rows are redundant *) - redundant: IntSet.t; - } + (* As we check completeness, we check submatrices which correspond to a subset of rows in the overall case statement *) + rows : IntSet.t; + (* These literal patterns can be turned into wildcards, as row index number * pattern number pairs *) + wildcards : IntIntSet.t; + (* Wildcards we must keep because they cannot be removed in all submatrices *) + preserved_literals : IntIntSet.t; + (* These rows are redundant *) + redundant : IntSet.t; +} let union_complete lhs rhs = let all_wildcards = IntIntSet.union lhs.wildcards rhs.wildcards in let shared_wildcards = IntIntSet.inter lhs.wildcards rhs.wildcards in let wildcards_lhs = IntIntSet.filter (fun (r, _) -> IntSet.mem r (IntSet.diff lhs.rows rhs.rows)) all_wildcards in let wildcards_rhs = IntIntSet.filter (fun (r, _) -> IntSet.mem r (IntSet.diff rhs.rows lhs.rows)) all_wildcards in - let new_preserved = IntIntSet.diff (IntIntSet.filter (fun (r, _) -> IntSet.mem r (IntSet.inter rhs.rows lhs.rows)) all_wildcards) shared_wildcards in + let new_preserved = + IntIntSet.diff + (IntIntSet.filter (fun (r, _) -> IntSet.mem r (IntSet.inter rhs.rows lhs.rows)) all_wildcards) + shared_wildcards + in let only_in_lhs = IntSet.diff lhs.rows rhs.rows in let only_in_rhs = IntSet.diff rhs.rows lhs.rows in { @@ -124,19 +131,18 @@ let union_complete lhs rhs = } let get_wildcard_patterns (cinfo : complete_info) = IntIntSet.elements cinfo.wildcards |> List.map snd -let get_preserved_patterns (cinfo : complete_info) = IntIntSet.elements cinfo.preserved_literals |> List.map snd |> IntSet.of_list +let get_preserved_patterns (cinfo : complete_info) = + IntIntSet.elements cinfo.preserved_literals |> List.map snd |> IntSet.of_list -type 'a completeness = - | Incomplete of 'a - | Complete of complete_info - | Completeness_unknown +type 'a completeness = Incomplete of 'a | Complete of complete_info | Completeness_unknown -let mk_complete ?redundant:(redundant = []) rows wildcards = - Complete { +let mk_complete ?(redundant = []) rows wildcards = + Complete + { rows = IntSet.of_list rows; wildcards = IntIntSet.of_list wildcards; preserved_literals = IntIntSet.empty; - redundant = IntSet.of_list redundant + redundant = IntSet.of_list redundant; } let completeness_map f g = function @@ -147,8 +153,13 @@ let completeness_map f g = function (* turn a [t pat] into a [(t, int) pat] where each subpattern is uniquely identified *) let number_pat (from : int) (pat : 'a pat) : ('a * int) pat * int = let rec go counter (P_aux (aux, (l, t))) = - let count () = let c = !counter in (counter := c + 1; c) in - let aux = match aux with + let count () = + let c = !counter in + counter := c + 1; + c + in + let aux = + match aux with | P_or (p1, p2) -> P_or (go counter p1, go counter p2) | P_not p -> P_not (go counter p) | P_as (p, id) -> P_as (go counter p, id) @@ -170,12 +181,14 @@ let number_pat (from : int) (pat : 'a pat) : ('a * int) pat * int = in let counter = ref from in let pat = go counter pat in - pat, !counter + (pat, !counter) let preserved_explanation = "Sail cannot simplify the above pattern match:\n" - ^ "This bitvector pattern literal must be kept, as it is required for Sail to show that the surrounding pattern match is complete.\n" - ^ "When translated into prover targets (e.g. Lem, Coq) without native bitvector patterns, they may be unable to verify that the match covers all possible cases." + ^ "This bitvector pattern literal must be kept, as it is required for Sail to show that the surrounding pattern \ + match is complete.\n" + ^ "When translated into prover targets (e.g. Lem, Coq) without native bitvector patterns, they may be unable to \ + verify that the match covers all possible cases." let rows_to_list (Rows rs) = rs let columns_to_list (Columns cs) = cs @@ -185,20 +198,19 @@ type 'a cr_matrix = 'a rows columns let pop_column (matrix : 'a rc_matrix) : ((row_index * 'a) list * 'a rc_matrix) option = match rows_to_list matrix with - | ((l, Columns (_ :: _)) :: _) as matrix -> - Some (List.map (fun (l, row) -> (l, List.hd (columns_to_list row))) matrix, Rows (List.map (fun (l, row) -> (l, Columns (List.tl (columns_to_list row)))) matrix)) - | _ -> - None + | (l, Columns (_ :: _)) :: _ as matrix -> + Some + ( List.map (fun (l, row) -> (l, List.hd (columns_to_list row))) matrix, + Rows (List.map (fun (l, row) -> (l, Columns (List.tl (columns_to_list row)))) matrix) + ) + | _ -> None let rec transpose (matrix : 'a rc_matrix) : 'a cr_matrix = match pop_column matrix with | Some (col, matrix) -> Columns (Rows col :: columns_to_list (transpose matrix)) | None -> Columns [] -let row_matrix_empty (Rows rows) = - match rows with - | [] -> true - | _ -> false +let row_matrix_empty (Rows rows) = match rows with [] -> true | _ -> false let row_matrix_width l (Rows rows) = match rows with @@ -207,8 +219,7 @@ let row_matrix_width l (Rows rows) = let row_matrix_height (Rows rows) = List.length rows -module Make(C: Config) = struct - +module Make (C : Config) = struct type bv_constraint = | BVC_eq of bv_constraint * bv_constraint | BVC_and of bv_constraint * bv_constraint @@ -221,14 +232,12 @@ module Make(C: Config) = struct | BVC_eq (bvc1, bvc2) -> "(= " ^ string_of_bv_constraint bvc1 ^ " " ^ string_of_bv_constraint bvc2 ^ ")" | BVC_and (bvc1, bvc2) -> "(and " ^ string_of_bv_constraint bvc1 ^ " " ^ string_of_bv_constraint bvc2 ^ ")" | BVC_bvand (bvc1, bvc2) -> "(bvand " ^ string_of_bv_constraint bvc1 ^ " " ^ string_of_bv_constraint bvc2 ^ ")" - | BVC_extract (n, m, bvc) -> "((_ extract " ^ string_of_int n ^ " " ^ string_of_int m ^ ") " ^ string_of_bv_constraint bvc ^ ")" + | BVC_extract (n, m, bvc) -> + "((_ extract " ^ string_of_int n ^ " " ^ string_of_int m ^ ") " ^ string_of_bv_constraint bvc ^ ")" | BVC_true -> "true" | BVC_lit lit -> lit - let bvc_and x y = match (x, y) with - | BVC_true, _ -> y - | _, BVC_true -> x - | _, _ -> BVC_and (x, y) + let bvc_and x y = match (x, y) with BVC_true, _ -> y | _, BVC_true -> x | _, _ -> BVC_and (x, y) let typ_of_pat (P_aux (_, (_, (t, _)))) = C.typ_of_t t @@ -236,12 +245,11 @@ module Make(C: Config) = struct let preserved = get_preserved_patterns cinfo in let wildcards = get_wildcard_patterns cinfo in let rec go wild (P_aux (aux, (l, (t, n))) as full_pat) = - if IntSet.mem n preserved then ( - Reporting.warn "Required literal" l preserved_explanation - ); + if IntSet.mem n preserved then Reporting.warn "Required literal" l preserved_explanation; let wild = wild || List.exists (fun wildcard -> wildcard = n) wildcards in let t = ref t in - let aux = match aux with + let aux = + match aux with | P_or (p1, p2) -> P_or (go wild p1, go wild p2) | P_not p -> P_not (go wild p) | P_as (p, id) -> P_as (go wild p, id) @@ -257,11 +265,11 @@ module Make(C: Config) = struct | P_string_append ps -> P_string_append (List.map (go wild) ps) | P_id id -> P_id id | P_lit (L_aux (L_num n, _)) when wild -> - t := C.add_attribute (gen_loc l) "int_wildcard" (Big_int.to_string n) !t; - P_wild + t := C.add_attribute (gen_loc l) "int_wildcard" (Big_int.to_string n) !t; + P_wild | P_lit _ when wild -> - let typ = typ_of_pat full_pat in - P_typ (typ, P_aux (P_wild, (l, !t))) + let typ = typ_of_pat full_pat in + P_typ (typ, P_aux (P_wild, (l, !t))) | P_lit lit -> P_lit lit | P_wild -> P_wild in @@ -289,8 +297,7 @@ module Make(C: Config) = struct | GP_unknown -> "?" | GP_lit lit -> string_of_lit lit | GP_tuple gpats -> "(" ^ Util.string_of_list ", " _string_of_gpat gpats ^ ")" - | GP_app (_, ctor, gpats) -> - string_of_id ctor ^ "(" ^ Util.string_of_list ", " _string_of_gpat gpats ^ ")" + | GP_app (_, ctor, gpats) -> string_of_id ctor ^ "(" ^ Util.string_of_list ", " _string_of_gpat gpats ^ ")" | GP_bitvector (_, _, bvc) -> string_of_bv_constraint (bvc (BVC_lit "x")) | GP_num (_, n, _) -> Big_int.to_string n | GP_enum (_, id) -> string_of_id id @@ -301,117 +308,105 @@ module Make(C: Config) = struct let _debug_rc_matrix (Rows rs) = prerr_endline "=== MATRIX ==="; - List.iter (fun (_, Columns c) -> - prerr_endline (Util.string_of_list ", " _string_of_gpat c) - ) rs + List.iter (fun (_, Columns c) -> prerr_endline (Util.string_of_list ", " _string_of_gpat c)) rs [@@@coverage on] let rec generalize ctx head_exp_typ (P_aux (p_aux, (l, (_, pnum))) as pat) = let typ = typ_of_pat pat in match p_aux with | P_lit (L_aux (L_unit, _)) -> - (* Unit pattern always matches on unit, so generalize to wildcard *) - GP_wild - - | P_lit (L_aux (L_hex hex, _)) -> GP_bitvector (pnum, String.length hex * 4, fun x -> BVC_eq (x, BVC_lit ("#x" ^ hex))) + (* Unit pattern always matches on unit, so generalize to wildcard *) + GP_wild + | P_lit (L_aux (L_hex hex, _)) -> + GP_bitvector (pnum, String.length hex * 4, fun x -> BVC_eq (x, BVC_lit ("#x" ^ hex))) | P_lit (L_aux (L_bin bin, _)) -> GP_bitvector (pnum, String.length bin, fun x -> BVC_eq (x, BVC_lit ("#b" ^ bin))) - | P_vector pats when is_bitvector_typ typ -> - let mask, bits = - List.fold_left (fun (mask, bits) (P_aux (pat, _)) -> - let rec go pat = match pat with - | P_lit (L_aux (L_one, _)) -> (mask ^ "1", bits ^ "1") - | P_lit (L_aux (L_zero, _)) -> (mask ^ "1", bits ^ "0") - | P_wild | P_id _ -> (mask ^ "0", bits ^ "0") - | P_typ (_, P_aux (pat, _)) -> go pat - | _ -> - Reporting.warn "Unexpected pattern" l ""; - (mask ^ "0", bits ^ "0") - in - go pat - ) ("#b", "#b") pats - in - GP_bitvector (pnum, List.length pats, fun x -> BVC_eq (BVC_bvand (BVC_lit mask, x), BVC_lit bits)) - - | P_vector pats -> - GP_vector (List.map (generalize ctx None) pats) - - | P_vector_concat pats when is_bitvector_typ typ -> - let lengths = - List.fold_left (fun acc typ -> - match acc with - | None -> None - | Some lengths -> - let (nexp, _, _) = vector_typ_args_of typ in - match int_of_nexp_opt nexp with - | Some n -> Some (Big_int.to_int n :: lengths) - | None -> None - ) (Some []) (List.map typ_of_pat pats) in - let gpats = List.map (generalize ctx None) pats in - begin match lengths with - | Some lengths -> - let (total, slices) = List.fold_left (fun (total, acc) len -> (total + len, (total + len - 1, total) :: acc)) (0, []) lengths in - let bvc = fun x -> - List.fold_left2 (fun bvc (n, m) gpat -> - match gpat with - | GP_bitvector (_, _, bvc_subpat) -> - bvc_and bvc (bvc_subpat (BVC_extract (n, m, x))) - | GP_wild -> - bvc + let mask, bits = + List.fold_left + (fun (mask, bits) (P_aux (pat, _)) -> + let rec go pat = + match pat with + | P_lit (L_aux (L_one, _)) -> (mask ^ "1", bits ^ "1") + | P_lit (L_aux (L_zero, _)) -> (mask ^ "1", bits ^ "0") + | P_wild | P_id _ -> (mask ^ "0", bits ^ "0") + | P_typ (_, P_aux (pat, _)) -> go pat | _ -> - Reporting.unreachable l __POS__ "Invalid bitvector pattern" [@coverage off] - ) BVC_true slices gpats - in - GP_bitvector (pnum, total, bvc) - | None -> - GP_wild - end - - | P_tuple pats -> - begin match head_exp_typ with - | Some (Typ_aux (Typ_tuple typs, _)) when List.length pats = List.length typs -> - GP_tuple (List.map2 (fun pat typ -> generalize ctx (Some typ) pat) pats typs) - | _ -> - GP_tuple (List.map (generalize ctx None) pats) - end - + Reporting.warn "Unexpected pattern" l ""; + (mask ^ "0", bits ^ "0") + in + go pat + ) + ("#b", "#b") pats + in + GP_bitvector (pnum, List.length pats, fun x -> BVC_eq (BVC_bvand (BVC_lit mask, x), BVC_lit bits)) + | P_vector pats -> GP_vector (List.map (generalize ctx None) pats) + | P_vector_concat pats when is_bitvector_typ typ -> + let lengths = + List.fold_left + (fun acc typ -> + match acc with + | None -> None + | Some lengths -> ( + let nexp, _, _ = vector_typ_args_of typ in + match int_of_nexp_opt nexp with Some n -> Some (Big_int.to_int n :: lengths) | None -> None + ) + ) + (Some []) (List.map typ_of_pat pats) + in + let gpats = List.map (generalize ctx None) pats in + begin + match lengths with + | Some lengths -> + let total, slices = + List.fold_left (fun (total, acc) len -> (total + len, (total + len - 1, total) :: acc)) (0, []) lengths + in + let bvc x = + List.fold_left2 + (fun bvc (n, m) gpat -> + match gpat with + | GP_bitvector (_, _, bvc_subpat) -> bvc_and bvc (bvc_subpat (BVC_extract (n, m, x))) + | GP_wild -> bvc + | _ -> Reporting.unreachable l __POS__ "Invalid bitvector pattern" [@coverage off] + ) + BVC_true slices gpats + in + GP_bitvector (pnum, total, bvc) + | None -> GP_wild + end + | P_tuple pats -> begin + match head_exp_typ with + | Some (Typ_aux (Typ_tuple typs, _)) when List.length pats = List.length typs -> + GP_tuple (List.map2 (fun pat typ -> generalize ctx (Some typ) pat) pats typs) + | _ -> GP_tuple (List.map (generalize ctx None) pats) + end | P_app (id, pats) -> - let typ_id = match typ with - | Typ_aux (Typ_app (id, _), _) -> id - | Typ_aux (Typ_id id, _) -> id - | _ -> failwith "Bad type" - in - GP_app (typ_id, id, List.map (generalize ctx None) pats) - + let typ_id = + match typ with Typ_aux (Typ_app (id, _), _) -> id | Typ_aux (Typ_id id, _) -> id | _ -> failwith "Bad type" + in + GP_app (typ_id, id, List.map (generalize ctx None) pats) | P_lit (L_aux (L_true, _)) -> GP_bool true | P_lit (L_aux (L_false, _)) -> GP_bool false - | P_lit (L_aux (L_num n, _)) -> - begin match head_exp_typ with - | Some (Typ_aux (Typ_app (f, [A_aux (A_nexp (Nexp_aux (Nexp_var v, _)), _)]), _)) - when string_of_id f = "atom" || string_of_id f = "implicit" -> - GP_num (pnum, n, Some v) - | _ -> - GP_num (pnum, n, None) - end + | P_lit (L_aux (L_num n, _)) -> begin + match head_exp_typ with + | Some (Typ_aux (Typ_app (f, [A_aux (A_nexp (Nexp_aux (Nexp_var v, _)), _)]), _)) + when string_of_id f = "atom" || string_of_id f = "implicit" -> + GP_num (pnum, n, Some v) + | _ -> GP_num (pnum, n, None) + end | P_lit lit -> GP_lit lit | P_wild -> GP_wild | P_var (pat, _) -> generalize ctx head_exp_typ pat | P_as (pat, _) -> generalize ctx head_exp_typ pat | P_typ (_, pat) -> generalize ctx head_exp_typ pat - | P_vector_subrange _ -> GP_wild - - | P_id id -> - begin match List.find_opt (fun (enum, ctors) -> IdSet.mem id ctors) (Bindings.bindings ctx.enums) with - | Some (enum, _) -> GP_enum (enum, id) - | None -> GP_wild - end - - | P_cons (hd_pat, tl_pat) -> - GP_cons (generalize ctx head_exp_typ hd_pat, generalize ctx head_exp_typ tl_pat) + | P_id id -> begin + match List.find_opt (fun (enum, ctors) -> IdSet.mem id ctors) (Bindings.bindings ctx.enums) with + | Some (enum, _) -> GP_enum (enum, id) + | None -> GP_wild + end + | P_cons (hd_pat, tl_pat) -> GP_cons (generalize ctx head_exp_typ hd_pat, generalize ctx head_exp_typ tl_pat) | P_list xs -> - List.fold_right (fun pat tl_gpat -> GP_cons (generalize ctx head_exp_typ pat, tl_gpat)) xs GP_empty_list - + List.fold_right (fun pat tl_gpat -> GP_cons (generalize ctx head_exp_typ pat, tl_gpat)) xs GP_empty_list | _ -> GP_unknown let rec find_smtlib_type = function @@ -420,9 +415,7 @@ module Make(C: Config) = struct | _ :: rest -> find_smtlib_type rest | [] -> None - let is_simple_gpat = function - | GP_bitvector _ | GP_num _ | GP_wild -> true - | _ -> false + let is_simple_gpat = function GP_bitvector _ | GP_num _ | GP_wild -> true | _ -> false let rec column_type = function | (_, GP_tuple gpats) :: _ -> Tuple_column (List.length gpats) @@ -435,7 +428,8 @@ module Make(C: Config) = struct | [] -> Unknown_column let rec unmatched_string_literal max_length = function - | (_, GP_lit (L_aux (L_string str, _))) :: rest -> unmatched_string_literal (max (String.length str) max_length) rest + | (_, GP_lit (L_aux (L_string str, _))) :: rest -> + unmatched_string_literal (max (String.length str) max_length) rest | _ :: rest -> unmatched_string_literal max_length rest | [] -> L_string (String.make (max_length + 1) '?') @@ -452,82 +446,87 @@ module Make(C: Config) = struct let simple_matrix_is_complete ctx matrix = let vars = - List.mapi (fun i (Rows column) -> - match find_smtlib_type column with - | None -> None - | Some ty -> Some (i, ty) - ) (columns_to_list (transpose matrix)) + List.mapi + (fun i (Rows column) -> match find_smtlib_type column with None -> None | Some ty -> Some (i, ty)) + (columns_to_list (transpose matrix)) in let just_vars = vars |> Util.option_these in let all_rows = List.map (fun (idx, _) -> idx.num) (rows_to_list matrix) in match just_vars with - | [] when row_matrix_height matrix = 1 -> mk_complete all_rows [] (* The matrix is a single row of wildcard patterns *) - | _ -> - let head_exp_constraint, var_map, _ = - Constraint.constraint_to_smt Parse_ast.Unknown (List.fold_left nc_and nc_true ctx.constraints) in - let created_vars = ref KidSet.empty in - (* We set this true if we need to include the head expression constraint in the generated SMT problem *) - let require_head_exp_constraint = ref false in - let constrs = - List.map (fun (l, Columns row) -> - let row_constrs = - List.map2 (fun var gpat -> - match var, gpat with - | (Some (i, _), GP_bitvector (_, _, bvc)) -> Some (string_of_bv_constraint (bvc (BVC_lit ("p" ^ string_of_int i)))) - | (Some (i, _), GP_num (_, n, Some v)) -> - let smt_var, created = var_map v in - (* If the variable was not already in the map (and has therefore just been created), then it is unconstrained *) - if created then ( - created_vars := KidSet.add v !created_vars - ); - if not (KidSet.mem v !created_vars) then ( - require_head_exp_constraint := true; - Some (Printf.sprintf "(or (= p%d %s) (not (= p%d %s)))" i (Big_int.to_string n) i smt_var) - ) else ( - Some (Printf.sprintf "(= p%d %s)" i (Big_int.to_string n)) - ) - | (Some (i, _), GP_num (_, n, None)) -> - Some (Printf.sprintf "(= p%d %s)" i (Big_int.to_string n)) - | _ -> None - ) vars row - |> Util.option_these - in - match row_constrs with - | [] -> (l, None) - | [c] -> (l, Some ("(assert (not " ^ Util.string_of_list " " (fun x -> x) row_constrs ^ "))")) - | _ -> (l, Some ("(assert (not (and " ^ Util.string_of_list " " (fun x -> x) row_constrs ^ ")))")) - ) (rows_to_list matrix) - in - (* Check if we have any row containing only wildcards, hence matrix is trivially unsatisfiable *) - match Util.find_rest_opt (fun (_, constr) -> Option.is_none constr) constrs with - | Some (_, []) -> mk_complete all_rows [] - (* If there are any rows after the wildcard row, they are redundant *) - | Some (_, redundant) -> - mk_complete ~redundant:(List.map (fun (idx, _) -> idx.num) redundant) all_rows [] - | None -> - let smtlib = - (if !require_head_exp_constraint then head_exp_constraint ^ "\n" else "") - ^ Util.string_of_list "\n" (fun (v, ty) -> Printf.sprintf "(declare-const p%d %s)" v ty) just_vars ^ "\n" - ^ Util.string_of_list "\n" (fun x -> x) (Util.option_these (List.map snd constrs)) ^ "\n" - ^ "(check-sat)\n" - ^ "(get-model)\n" - in - match Constraint.call_smt_solve_bitvector Parse_ast.Unknown smtlib just_vars with - | Some lits -> - if !opt_debug_no_literals then ( - Incomplete (List.init (List.length vars) (fun _ -> mk_lit_exp L_undef)) - ) else ( - Incomplete (List.init (List.length vars) (fun i -> match List.assoc_opt i lits with - | Some lit -> mk_exp (E_lit lit) - | None -> mk_lit_exp L_undef)) - ) - | None -> - let to_wildcards = match Util.last_opt (rows_to_list matrix) with - | Some (idx, Columns row) -> - List.filter_map (function (GP_bitvector (pnum, _, _) | GP_num (pnum, _, _)) -> Some (idx.num, pnum) | _ -> None) row - | None -> [] - in - mk_complete all_rows to_wildcards + | [] when row_matrix_height matrix = 1 -> + mk_complete all_rows [] (* The matrix is a single row of wildcard patterns *) + | _ -> ( + let head_exp_constraint, var_map, _ = + Constraint.constraint_to_smt Parse_ast.Unknown (List.fold_left nc_and nc_true ctx.constraints) + in + let created_vars = ref KidSet.empty in + (* We set this true if we need to include the head expression constraint in the generated SMT problem *) + let require_head_exp_constraint = ref false in + let constrs = + List.map + (fun (l, Columns row) -> + let row_constrs = + List.map2 + (fun var gpat -> + match (var, gpat) with + | Some (i, _), GP_bitvector (_, _, bvc) -> + Some (string_of_bv_constraint (bvc (BVC_lit ("p" ^ string_of_int i)))) + | Some (i, _), GP_num (_, n, Some v) -> + let smt_var, created = var_map v in + (* If the variable was not already in the map (and has therefore just been created), then it is unconstrained *) + if created then created_vars := KidSet.add v !created_vars; + if not (KidSet.mem v !created_vars) then ( + require_head_exp_constraint := true; + Some (Printf.sprintf "(or (= p%d %s) (not (= p%d %s)))" i (Big_int.to_string n) i smt_var) + ) + else Some (Printf.sprintf "(= p%d %s)" i (Big_int.to_string n)) + | Some (i, _), GP_num (_, n, None) -> Some (Printf.sprintf "(= p%d %s)" i (Big_int.to_string n)) + | _ -> None + ) + vars row + |> Util.option_these + in + match row_constrs with + | [] -> (l, None) + | [c] -> (l, Some ("(assert (not " ^ Util.string_of_list " " (fun x -> x) row_constrs ^ "))")) + | _ -> (l, Some ("(assert (not (and " ^ Util.string_of_list " " (fun x -> x) row_constrs ^ ")))")) + ) + (rows_to_list matrix) + in + (* Check if we have any row containing only wildcards, hence matrix is trivially unsatisfiable *) + match Util.find_rest_opt (fun (_, constr) -> Option.is_none constr) constrs with + | Some (_, []) -> mk_complete all_rows [] + (* If there are any rows after the wildcard row, they are redundant *) + | Some (_, redundant) -> mk_complete ~redundant:(List.map (fun (idx, _) -> idx.num) redundant) all_rows [] + | None -> ( + let smtlib = + (if !require_head_exp_constraint then head_exp_constraint ^ "\n" else "") + ^ Util.string_of_list "\n" (fun (v, ty) -> Printf.sprintf "(declare-const p%d %s)" v ty) just_vars + ^ "\n" + ^ Util.string_of_list "\n" (fun x -> x) (Util.option_these (List.map snd constrs)) + ^ "\n" ^ "(check-sat)\n" ^ "(get-model)\n" + in + match Constraint.call_smt_solve_bitvector Parse_ast.Unknown smtlib just_vars with + | Some lits -> + if !opt_debug_no_literals then Incomplete (List.init (List.length vars) (fun _ -> mk_lit_exp L_undef)) + else + Incomplete + (List.init (List.length vars) (fun i -> + match List.assoc_opt i lits with Some lit -> mk_exp (E_lit lit) | None -> mk_lit_exp L_undef + ) + ) + | None -> + let to_wildcards = + match Util.last_opt (rows_to_list matrix) with + | Some (idx, Columns row) -> + List.filter_map + (function GP_bitvector (pnum, _, _) | GP_num (pnum, _, _) -> Some (idx.num, pnum) | _ -> None) + row + | None -> [] + in + mk_complete all_rows to_wildcards + ) + ) let find_complex_column matrix = let is_complex_column col = List.exists (fun (_, gpat) -> not (is_simple_gpat gpat)) col in @@ -541,19 +540,24 @@ module Make(C: Config) = struct let split_app_column l ctx col = let typ_id = column_typ_id l col in - let all_ctors = Bindings.find typ_id ctx.variants |> snd |> List.map (function Tu_aux (Tu_ty_id (_, id), _) -> id) in + let all_ctors = + Bindings.find typ_id ctx.variants |> snd |> List.map (function Tu_aux (Tu_ty_id (_, id), _) -> id) + in let all_ctors = List.fold_left (fun m ctor -> Bindings.add ctor [] m) Bindings.empty all_ctors in - List.fold_left (fun (i, acc) (_, gpat) -> - let acc = match gpat with + List.fold_left + (fun (i, acc) (_, gpat) -> + let acc = + match gpat with | GP_app (_, ctor, ctor_gpats) -> - Bindings.update ctor (function None -> Some [(i, Some ctor_gpats)] | Some xs -> Some ((i, Some ctor_gpats) :: xs)) acc - | GP_wild -> - Bindings.map (fun xs -> (i, None) :: xs) acc - | _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "App column contains invalid pattern" [@coverage off] + Bindings.update ctor + (function None -> Some [(i, Some ctor_gpats)] | Some xs -> Some ((i, Some ctor_gpats) :: xs)) + acc + | GP_wild -> Bindings.map (fun xs -> (i, None) :: xs) acc + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "App column contains invalid pattern" [@coverage off] in (i + 1, acc) - ) (0, all_ctors) col + ) + (0, all_ctors) col |> snd let flatten_tuple_column width i matrix = @@ -562,9 +566,16 @@ module Make(C: Config) = struct | GP_wild -> List.init width (fun _ -> GP_wild) | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Tuple column contains invalid pattern" [@coverage off] in - Rows (List.map (fun (l, row) -> - (l, Columns (List.mapi (fun j gpat -> if i = j then flatten gpat else [gpat]) (columns_to_list row) |> List.concat)) - ) (rows_to_list matrix)) + Rows + (List.map + (fun (l, row) -> + ( l, + Columns + (List.mapi (fun j gpat -> if i = j then flatten gpat else [gpat]) (columns_to_list row) |> List.concat) + ) + ) + (rows_to_list matrix) + ) let split_matrix_ctor ctx c ctor ctor_rows matrix = let row_indices = List.fold_left (fun set (r, _) -> IntSet.add r set) IntSet.empty ctor_rows in @@ -573,281 +584,253 @@ module Make(C: Config) = struct | GP_wild -> GP_wild | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "App column contains invalid pattern" [@coverage off] in - let remove_ctor row = Columns (List.mapi (fun i gpat -> if i = c then flatten gpat else gpat) (columns_to_list row)) in - Rows ( - rows_to_list matrix - |> List.mapi (fun r row -> (r, row)) - |> List.filter_map (fun (r, (l, row)) -> if IntSet.mem r row_indices then Some (l, remove_ctor row) else None) + let remove_ctor row = + Columns (List.mapi (fun i gpat -> if i = c then flatten gpat else gpat) (columns_to_list row)) + in + Rows + (rows_to_list matrix + |> List.mapi (fun r row -> (r, row)) + |> List.filter_map (fun (r, (l, row)) -> if IntSet.mem r row_indices then Some (l, remove_ctor row) else None) ) - let rec remove_index n = function - | x :: xs when n = 0 -> xs - | x :: xs -> x :: remove_index (n - 1) xs - | [] -> [] + let rec remove_index n = function x :: xs when n = 0 -> xs | x :: xs -> x :: remove_index (n - 1) xs | [] -> [] let split_matrix_bool b c matrix = - let is_bool_row = function - | GP_bool b' -> b = b' - | GP_wild -> true - | _ -> false - in - Rows ( - rows_to_list matrix - |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_bool_row) - |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) + let is_bool_row = function GP_bool b' -> b = b' | GP_wild -> true | _ -> false in + Rows + (rows_to_list matrix + |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_bool_row) + |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) ) let split_matrix_wild c matrix = - let is_wild_row = function - | GP_wild -> true - | _ -> false - in - Rows ( - rows_to_list matrix - |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_wild_row) - |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) + let is_wild_row = function GP_wild -> true | _ -> false in + Rows + (rows_to_list matrix + |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_wild_row) + |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) ) let split_matrix_cons c matrix = - let is_cons_row = function - | GP_wild | GP_cons _ -> true - | _ -> false - in - let is_empty_list_row = function - | GP_wild | GP_empty_list -> true - | _ -> false - in + let is_cons_row = function GP_wild | GP_cons _ -> true | _ -> false in + let is_empty_list_row = function GP_wild | GP_empty_list -> true | _ -> false in let uncons = function | GP_wild -> GP_tuple [GP_wild; GP_wild] | GP_cons (hd_gpat, tl_gpat) -> GP_tuple [hd_gpat; tl_gpat] | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cons row contains invalid pattern" [@coverage off] in - let remove_cons row = Columns (List.mapi (fun i gpat -> if i = c then uncons gpat else gpat) (columns_to_list row)) in - ( - Rows ( - rows_to_list matrix + let remove_cons row = + Columns (List.mapi (fun i gpat -> if i = c then uncons gpat else gpat) (columns_to_list row)) + in + ( Rows + (rows_to_list matrix |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_cons_row) |> List.map (fun (l, row) -> (l, remove_cons row)) - ), - Rows ( - rows_to_list matrix + ), + Rows + (rows_to_list matrix |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_empty_list_row) |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) - ) + ) ) let split_matrix_enum e c matrix = - let is_enum_row = function - | GP_enum (_, id) -> Id.compare e id = 0 - | GP_wild -> true - | _ -> false - in - Rows ( - rows_to_list matrix - |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_enum_row) - |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) + let is_enum_row = function GP_enum (_, id) -> Id.compare e id = 0 | GP_wild -> true | _ -> false in + Rows + (rows_to_list matrix + |> List.filter (fun (_, row) -> columns_to_list row |> (fun xs -> List.nth xs c) |> is_enum_row) + |> List.map (fun (l, row) -> (l, Columns (remove_index c (columns_to_list row)))) ) let retuple width i unmatcheds = - let (xs, ys) = Util.split_after i unmatcheds in + let xs, ys = Util.split_after i unmatcheds in let tuple_elems = Util.take width ys in let zs = Util.drop width ys in - xs @ mk_exp (E_tuple tuple_elems) :: zs + xs @ (mk_exp (E_tuple tuple_elems) :: zs) let rector ctor i unmatcheds = - let (xs, ys) = Util.split_after i unmatcheds in + let xs, ys = Util.split_after i unmatcheds in match ys with - | E_aux (E_tuple args, _) :: zs -> - xs @ mk_exp (E_app (ctor, args)) :: zs - | y :: zs -> - xs @ mk_exp (E_app (ctor, [y])) :: zs - | [] -> - xs @ [mk_exp (E_app (ctor, []))] + | E_aux (E_tuple args, _) :: zs -> xs @ (mk_exp (E_app (ctor, args)) :: zs) + | y :: zs -> xs @ (mk_exp (E_app (ctor, [y])) :: zs) + | [] -> xs @ [mk_exp (E_app (ctor, []))] let relit lit i unmatcheds = - let (xs, ys) = Util.split_after i unmatcheds in - xs @ mk_lit_exp lit :: ys + let xs, ys = Util.split_after i unmatcheds in + xs @ (mk_lit_exp lit :: ys) - let rebool b i unmatcheds = - relit (if b then L_true else L_false) i unmatcheds + let rebool b i unmatcheds = relit (if b then L_true else L_false) i unmatcheds let recons l i unmatcheds = - let (xs, ys) = Util.split_after i unmatcheds in + let xs, ys = Util.split_after i unmatcheds in match ys with - | E_aux (E_tuple [hd_arg; tl_arg], _) :: zs -> - xs @ mk_exp (E_cons (hd_arg, tl_arg)) :: zs - | _ -> - Reporting.unreachable l __POS__ "Cannot reconstruct cons pattern" [@coverage off] - + | E_aux (E_tuple [hd_arg; tl_arg], _) :: zs -> xs @ (mk_exp (E_cons (hd_arg, tl_arg)) :: zs) + | _ -> Reporting.unreachable l __POS__ "Cannot reconstruct cons pattern" [@coverage off] + let reempty_list i unmatcheds = - let (xs, ys) = Util.split_after i unmatcheds in - xs @ mk_exp (E_list []) :: ys - + let xs, ys = Util.split_after i unmatcheds in + xs @ (mk_exp (E_list []) :: ys) + let reenum e i unmatcheds = - let (xs, ys) = Util.split_after i unmatcheds in - xs @ mk_exp (E_id e) :: ys + let xs, ys = Util.split_after i unmatcheds in + xs @ (mk_exp (E_id e) :: ys) let rec undefs_except n c v len = - if n = len then - [] - else if n = c then - v :: undefs_except (n + 1) c v len - else - mk_lit_exp L_undef :: undefs_except (n + 1) c v len + if n = len then [] + else if n = c then v :: undefs_except (n + 1) c v len + else mk_lit_exp L_undef :: undefs_except (n + 1) c v len let rec matrix_is_complete l ctx matrix = match find_complex_column matrix with | None -> simple_matrix_is_complete ctx matrix - | Some (i, col) -> - begin match column_type col with - | Tuple_column width -> - matrix_is_complete l ctx (flatten_tuple_column width i matrix) - |> completeness_map (retuple width i) (fun w -> w) - - | Lit_column -> - let wild_matrix = split_matrix_wild i matrix in - begin match unmatched_literal col with - | None -> Completeness_unknown - | Some lit -> - if row_matrix_empty wild_matrix then - Incomplete (undefs_except 0 i (mk_lit_exp lit) (row_matrix_width l matrix)) - else - match matrix_is_complete l ctx wild_matrix with - | Incomplete unmatcheds -> Incomplete (relit lit i unmatcheds) - | Complete cinfo -> Complete cinfo - | Completeness_unknown -> Completeness_unknown - end - - | List_column -> - let cons_matrix, empty_list_matrix = split_matrix_cons i matrix in - let width = row_matrix_width l matrix in - if row_matrix_empty empty_list_matrix then - Incomplete (undefs_except 0 i (mk_exp (E_list [])) width) - else if row_matrix_empty cons_matrix then - Incomplete (undefs_except 0 i (mk_exp (E_cons (mk_lit_exp L_undef, mk_lit_exp L_undef))) width) - else - begin match matrix_is_complete l ctx cons_matrix with - | Incomplete unmatcheds -> - Incomplete (recons l i unmatcheds) - | Complete cinfo -> - matrix_is_complete l ctx empty_list_matrix |> completeness_map (reempty_list i) (union_complete cinfo) - | Completeness_unknown -> - Completeness_unknown + | Some (i, col) -> begin + match column_type col with + | Tuple_column width -> + matrix_is_complete l ctx (flatten_tuple_column width i matrix) + |> completeness_map (retuple width i) (fun w -> w) + | Lit_column -> + let wild_matrix = split_matrix_wild i matrix in + begin + match unmatched_literal col with + | None -> Completeness_unknown + | Some lit -> + if row_matrix_empty wild_matrix then + Incomplete (undefs_except 0 i (mk_lit_exp lit) (row_matrix_width l matrix)) + else ( + match matrix_is_complete l ctx wild_matrix with + | Incomplete unmatcheds -> Incomplete (relit lit i unmatcheds) + | Complete cinfo -> Complete cinfo + | Completeness_unknown -> Completeness_unknown + ) end - - | App_column typ_id -> - let ctors = split_app_column l ctx col in - Bindings.fold (fun ctor ctor_rows unmatcheds -> - match unmatcheds with - | Incomplete unmatcheds -> Incomplete unmatcheds - | Completeness_unknown -> Completeness_unknown + | List_column -> + let cons_matrix, empty_list_matrix = split_matrix_cons i matrix in + let width = row_matrix_width l matrix in + if row_matrix_empty empty_list_matrix then Incomplete (undefs_except 0 i (mk_exp (E_list [])) width) + else if row_matrix_empty cons_matrix then + Incomplete (undefs_except 0 i (mk_exp (E_cons (mk_lit_exp L_undef, mk_lit_exp L_undef))) width) + else begin + match matrix_is_complete l ctx cons_matrix with + | Incomplete unmatcheds -> Incomplete (recons l i unmatcheds) | Complete cinfo -> - let ctor_matrix = split_matrix_ctor ctx i ctor ctor_rows matrix in - if row_matrix_empty ctor_matrix then - let width = row_matrix_width l matrix in - Incomplete (undefs_except 0 i (mk_exp (E_app (ctor, [mk_lit_exp L_undef]))) width) - else - matrix_is_complete l ctx ctor_matrix |> completeness_map (rector ctor i) (union_complete cinfo) - ) ctors (mk_complete [] []) - - | Bool_column -> - let true_matrix = split_matrix_bool true i matrix in - let false_matrix = split_matrix_bool false i matrix in - let width = row_matrix_width l matrix in - if row_matrix_empty true_matrix then - Incomplete (undefs_except 0 i (mk_lit_exp L_true) width) - else if row_matrix_empty false_matrix then - Incomplete (undefs_except 0 i (mk_lit_exp L_false) width) - else - begin match matrix_is_complete l ctx true_matrix with - | Incomplete unmatcheds -> - Incomplete (rebool true i unmatcheds) - | Complete cinfo -> - matrix_is_complete l ctx false_matrix |> completeness_map (rebool false i) (union_complete cinfo) - | Completeness_unknown -> - Completeness_unknown - end - - | Enum_column typ_id -> - let members = Bindings.find typ_id ctx.enums in - IdSet.fold (fun member unmatcheds -> - match unmatcheds with - | Incomplete unmatcheds -> Incomplete unmatcheds + matrix_is_complete l ctx empty_list_matrix |> completeness_map (reempty_list i) (union_complete cinfo) | Completeness_unknown -> Completeness_unknown + end + | App_column typ_id -> + let ctors = split_app_column l ctx col in + Bindings.fold + (fun ctor ctor_rows unmatcheds -> + match unmatcheds with + | Incomplete unmatcheds -> Incomplete unmatcheds + | Completeness_unknown -> Completeness_unknown + | Complete cinfo -> + let ctor_matrix = split_matrix_ctor ctx i ctor ctor_rows matrix in + if row_matrix_empty ctor_matrix then ( + let width = row_matrix_width l matrix in + Incomplete (undefs_except 0 i (mk_exp (E_app (ctor, [mk_lit_exp L_undef]))) width) + ) + else matrix_is_complete l ctx ctor_matrix |> completeness_map (rector ctor i) (union_complete cinfo) + ) + ctors (mk_complete [] []) + | Bool_column -> + let true_matrix = split_matrix_bool true i matrix in + let false_matrix = split_matrix_bool false i matrix in + let width = row_matrix_width l matrix in + if row_matrix_empty true_matrix then Incomplete (undefs_except 0 i (mk_lit_exp L_true) width) + else if row_matrix_empty false_matrix then Incomplete (undefs_except 0 i (mk_lit_exp L_false) width) + else begin + match matrix_is_complete l ctx true_matrix with + | Incomplete unmatcheds -> Incomplete (rebool true i unmatcheds) | Complete cinfo -> - let enum_matrix = split_matrix_enum member i matrix in - if row_matrix_empty enum_matrix then - let width = row_matrix_width l matrix in - Incomplete (undefs_except 0 i (mk_exp (E_id member)) width) - else - matrix_is_complete l ctx enum_matrix |> completeness_map (reenum member i) (union_complete cinfo) - ) members (mk_complete [] []) - - | Unknown_column -> Completeness_unknown - end + matrix_is_complete l ctx false_matrix |> completeness_map (rebool false i) (union_complete cinfo) + | Completeness_unknown -> Completeness_unknown + end + | Enum_column typ_id -> + let members = Bindings.find typ_id ctx.enums in + IdSet.fold + (fun member unmatcheds -> + match unmatcheds with + | Incomplete unmatcheds -> Incomplete unmatcheds + | Completeness_unknown -> Completeness_unknown + | Complete cinfo -> + let enum_matrix = split_matrix_enum member i matrix in + if row_matrix_empty enum_matrix then ( + let width = row_matrix_width l matrix in + Incomplete (undefs_except 0 i (mk_exp (E_id member)) width) + ) + else + matrix_is_complete l ctx enum_matrix |> completeness_map (reenum member i) (union_complete cinfo) + ) + members (mk_complete [] []) + | Unknown_column -> Completeness_unknown + end (* Just highlight the match keyword and not the whole match block. *) let shrink_loc keyword = function - | Parse_ast.Range (n, m) -> - Lexing.(Parse_ast.Range (n, { n with pos_cnum = n.pos_cnum + String.length keyword })) + | Parse_ast.Range (n, m) -> Lexing.(Parse_ast.Range (n, { n with pos_cnum = n.pos_cnum + String.length keyword })) | l -> l let rec cases_to_pats from have_guard = function - | [] -> have_guard, [] + | [] -> (have_guard, []) | Pat_aux (Pat_exp ((P_aux (_, (l, _)) as pat), _), _) :: cases -> - let pat, from = number_pat from pat in - let have_guard, pats = cases_to_pats from have_guard cases in - have_guard, ((l, pat) :: pats) + let pat, from = number_pat from pat in + let have_guard, pats = cases_to_pats from have_guard cases in + (have_guard, (l, pat) :: pats) (* We don't consider guarded cases *) | Pat_aux (Pat_when _, _) :: cases -> cases_to_pats from true cases let rec update_cases l new_pats cases = - match new_pats, cases with + match (new_pats, cases) with | [], [] -> [] - | (new_pat :: new_pats), (Pat_aux (Pat_exp (_, exp), annot) :: cases) -> - Pat_aux (Pat_exp (new_pat, exp), annot) :: update_cases l new_pats cases - | _, ((Pat_aux (Pat_when _, _) as case) :: cases) -> - case :: update_cases l new_pats cases - | _, _ -> - Reporting.unreachable l __POS__ "Impossible case in update_cases" [@coverage off] - - let is_complete_wildcarded ?(keyword="match") l ctx cases head_exp_typ = + | new_pat :: new_pats, Pat_aux (Pat_exp (_, exp), annot) :: cases -> + Pat_aux (Pat_exp (new_pat, exp), annot) :: update_cases l new_pats cases + | _, (Pat_aux (Pat_when _, _) as case) :: cases -> case :: update_cases l new_pats cases + | _, _ -> Reporting.unreachable l __POS__ "Impossible case in update_cases" [@coverage off] + + let is_complete_wildcarded ?(keyword = "match") l ctx cases head_exp_typ = try match cases_to_pats 0 false cases with | _, [] -> None | have_guard, pats -> - let matrix = Rows (List.mapi (fun i (l, pat) -> ({ loc = l; num = i}, Columns [generalize ctx (Some head_exp_typ) pat])) pats) in - begin match matrix_is_complete l ctx matrix with - | Incomplete (unmatched :: _) -> - let guard_info = if have_guard then " by unguarded patterns" else "" in - Reporting.warn "Incomplete pattern match statement at" (shrink_loc keyword l) - ("The following expression is unmatched" ^ guard_info ^ ": " ^ (string_of_exp unmatched |> Util.yellow |> Util.clear)); - None - | Incomplete [] -> - Reporting.unreachable l __POS__ "Got unmatched pattern matrix without witness" [@coverage off] - | Complete cinfo -> - let wildcarded_pats = List.map (fun (_, pat) -> insert_wildcards cinfo pat) pats in - List.iter (fun (idx, _) -> - if IntSet.mem idx.num cinfo.redundant then - Reporting.warn "Redundant case" idx.loc "This match case is never used" - ) (rows_to_list matrix); - Some (update_cases l wildcarded_pats cases) - | Completeness_unknown -> - None - end - with - (* For now, if any error occurs just report the pattern match is incomplete *) + let matrix = + Rows + (List.mapi + (fun i (l, pat) -> ({ loc = l; num = i }, Columns [generalize ctx (Some head_exp_typ) pat])) + pats + ) + in + begin + match matrix_is_complete l ctx matrix with + | Incomplete (unmatched :: _) -> + let guard_info = if have_guard then " by unguarded patterns" else "" in + Reporting.warn "Incomplete pattern match statement at" (shrink_loc keyword l) + ("The following expression is unmatched" ^ guard_info ^ ": " + ^ (string_of_exp unmatched |> Util.yellow |> Util.clear) + ); + None + | Incomplete [] -> + Reporting.unreachable l __POS__ "Got unmatched pattern matrix without witness" [@coverage off] + | Complete cinfo -> + let wildcarded_pats = List.map (fun (_, pat) -> insert_wildcards cinfo pat) pats in + List.iter + (fun (idx, _) -> + if IntSet.mem idx.num cinfo.redundant then + Reporting.warn "Redundant case" idx.loc "This match case is never used" + ) + (rows_to_list matrix); + Some (update_cases l wildcarded_pats cases) + | Completeness_unknown -> None + end + with (* For now, if any error occurs just report the pattern match is incomplete *) | exn -> None - let is_complete_funcls_wildcarded ?(keyword="match") l ctx funcls head_exp_typ = + let is_complete_funcls_wildcarded ?(keyword = "match") l ctx funcls head_exp_typ = let destruct_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = ((id, annot), pexp) in let cases = List.map destruct_funcl funcls in - match is_complete_wildcarded ~keyword:keyword l ctx (List.map snd cases) head_exp_typ with - | Some pexps -> - Some (List.map2 (fun ((id, annot), _) pexp -> FCL_aux (FCL_funcl (id, pexp), annot)) cases pexps) - | None -> - None - - let is_complete ?(keyword="match") l ctx cases head_exp_typ = Option.is_some (is_complete_wildcarded ~keyword:keyword l ctx cases head_exp_typ) + match is_complete_wildcarded ~keyword l ctx (List.map snd cases) head_exp_typ with + | Some pexps -> Some (List.map2 (fun ((id, annot), _) pexp -> FCL_aux (FCL_funcl (id, pexp), annot)) cases pexps) + | None -> None + let is_complete ?(keyword = "match") l ctx cases head_exp_typ = + Option.is_some (is_complete_wildcarded ~keyword l ctx cases head_exp_typ) end diff --git a/src/lib/pattern_completeness.mli b/src/lib/pattern_completeness.mli index 71e4f726e..824a0a270 100644 --- a/src/lib/pattern_completeness.mli +++ b/src/lib/pattern_completeness.mli @@ -75,20 +75,20 @@ open Ast_util val opt_debug_no_literals : bool ref type ctx = { - variants : (typquant * type_union list) Bindings.t; - enums : IdSet.t Bindings.t; - constraints : n_constraint list; - } + variants : (typquant * type_union list) Bindings.t; + enums : IdSet.t Bindings.t; + constraints : n_constraint list; +} -module type Config = - sig - type t - val typ_of_t : t -> typ - val add_attribute : l -> string -> string -> t -> t - end +module type Config = sig + type t + val typ_of_t : t -> typ + val add_attribute : l -> string -> string -> t -> t +end -module Make(C: Config) : sig +module Make (C : Config) : sig val is_complete_wildcarded : ?keyword:string -> Parse_ast.l -> ctx -> C.t pexp list -> typ -> C.t pexp list option - val is_complete_funcls_wildcarded : ?keyword:string -> Parse_ast.l -> ctx -> C.t funcl list -> typ -> C.t funcl list option + val is_complete_funcls_wildcarded : + ?keyword:string -> Parse_ast.l -> ctx -> C.t funcl list -> typ -> C.t funcl list option val is_complete : ?keyword:string -> Parse_ast.l -> ctx -> C.t pexp list -> typ -> bool end diff --git a/src/lib/preprocess.ml b/src/lib/preprocess.ml index e053e6e45..034c84b16 100644 --- a/src/lib/preprocess.ml +++ b/src/lib/preprocess.ml @@ -68,49 +68,46 @@ open Parse_ast (* Simple preprocessor features for conditional file loading *) -module StringSet = Set.Make(String) +module StringSet = Set.Make (String) let default_symbols = - List.fold_left (fun set str -> StringSet.add str set) StringSet.empty - [ "FEATURE_IMPLICITS"; - "FEATURE_CONSTANT_TYPES"; - "FEATURE_BITVECTOR_TYPE"; - "FEATURE_UNION_BARRIER"; - ] + List.fold_left + (fun set str -> StringSet.add str set) + StringSet.empty + ["FEATURE_IMPLICITS"; "FEATURE_CONSTANT_TYPES"; "FEATURE_BITVECTOR_TYPE"; "FEATURE_UNION_BARRIER"] let symbols = ref default_symbols -let have_symbol symbol = - StringSet.mem symbol !symbols +let have_symbol symbol = StringSet.mem symbol !symbols let clear_symbols () = symbols := default_symbols let add_symbol str = symbols := StringSet.add str !symbols - + let cond_pragma l defs = let depth = ref 0 in let in_then = ref true in let then_defs = ref [] in let else_defs = ref [] in - let push_def def = - if !in_then then - then_defs := (def :: !then_defs) - else - else_defs := (def :: !else_defs) - in + let push_def def = if !in_then then then_defs := def :: !then_defs else else_defs := def :: !else_defs in let rec scan = function - | DEF_aux (DEF_pragma ("endif", _), _) :: defs when !depth = 0 -> - (List.rev !then_defs, List.rev !else_defs, defs) + | DEF_aux (DEF_pragma ("endif", _), _) :: defs when !depth = 0 -> (List.rev !then_defs, List.rev !else_defs, defs) | DEF_aux (DEF_pragma ("else", _), _) :: defs when !depth = 0 -> - in_then := false; scan defs + in_then := false; + scan defs | (DEF_aux (DEF_pragma (p, _), _) as def) :: defs when p = "ifdef" || p = "ifndef" || p = "iftarget" -> - incr depth; push_def def; scan defs - | (DEF_aux (DEF_pragma ("endif", _), _) as def) :: defs-> - decr depth; push_def def; scan defs + incr depth; + push_def def; + scan defs + | (DEF_aux (DEF_pragma ("endif", _), _) as def) :: defs -> + decr depth; + push_def def; + scan defs | def :: defs -> - push_def def; scan defs + push_def def; + scan defs | [] -> raise (Reporting.err_general l "$ifdef, $ifndef, or $iftarget never ended by $endif") in scan defs @@ -119,8 +116,11 @@ let cond_pragma l defs = just silently ignoring them, so we have a list here of all recognised pragmas. *) let all_pragmas = - List.fold_left (fun set str -> StringSet.add str set) StringSet.empty - [ "define"; + List.fold_left + (fun set str -> StringSet.add str set) + StringSet.empty + [ + "define"; "anchor"; "span"; "include"; @@ -139,112 +139,108 @@ let all_pragmas = "include_end"; "sail_internal"; "target_set"; - "non_exec" + "non_exec"; ] let wrap_include l file = function | [] -> [] - | defs -> - [DEF_aux (DEF_pragma ("include_start", file), l)] - @ defs - @ [DEF_aux (DEF_pragma ("include_end", file), l)] + | defs -> [DEF_aux (DEF_pragma ("include_start", file), l)] @ defs @ [DEF_aux (DEF_pragma ("include_end", file), l)] let rec preprocess dir target opts = let module P = Parse_ast in function | [] -> [] | DEF_aux (DEF_pragma ("define", symbol), _) :: defs -> - symbols := StringSet.add symbol !symbols; - preprocess dir target opts defs - + symbols := StringSet.add symbol !symbols; + preprocess dir target opts defs | (DEF_aux (DEF_pragma ("option", command), l) as opt_pragma) :: defs -> - begin - let first_line err_msg = match String.split_on_char '\n' err_msg with - | line :: _ -> "\n" ^ line - | [] -> "" [@coverage off] (* Don't expect this should ever happen, but we are fine if it does *) - in - try - let args = Str.split (Str.regexp " +") command in - let file_arg file = raise (Reporting.err_general l ("Anonymous argument '" ^ file ^ "' cannot be passed via $option directive")) in - Arg.parse_argv ~current:(ref 0) (Array.of_list ("sail" :: args)) opts file_arg ""; - with - | Arg.Help msg -> raise (Reporting.err_general l "-help flag passed to $option directive") - | Arg.Bad msg -> raise (Reporting.err_general l ("Invalid flag passed to $option directive" ^ first_line msg)) - end; - opt_pragma :: preprocess dir target opts defs - + begin + let first_line err_msg = + match String.split_on_char '\n' err_msg with line :: _ -> "\n" ^ line | [] -> ("" [@coverage off]) + (* Don't expect this should ever happen, but we are fine if it does *) + in + try + let args = Str.split (Str.regexp " +") command in + let file_arg file = + raise (Reporting.err_general l ("Anonymous argument '" ^ file ^ "' cannot be passed via $option directive")) + in + Arg.parse_argv ~current:(ref 0) (Array.of_list ("sail" :: args)) opts file_arg "" + with + | Arg.Help msg -> raise (Reporting.err_general l "-help flag passed to $option directive") + | Arg.Bad msg -> raise (Reporting.err_general l ("Invalid flag passed to $option directive" ^ first_line msg)) + end; + opt_pragma :: preprocess dir target opts defs | DEF_aux (DEF_pragma ("ifndef", symbol), l) :: defs -> - let then_defs, else_defs, defs = cond_pragma l defs in - if not (StringSet.mem symbol !symbols) then - preprocess dir target opts (then_defs @ defs) - else - preprocess dir target opts (else_defs @ defs) - + let then_defs, else_defs, defs = cond_pragma l defs in + if not (StringSet.mem symbol !symbols) then preprocess dir target opts (then_defs @ defs) + else preprocess dir target opts (else_defs @ defs) | DEF_aux (DEF_pragma ("ifdef", symbol), l) :: defs -> - let then_defs, else_defs, defs = cond_pragma l defs in - if StringSet.mem symbol !symbols then - preprocess dir target opts (then_defs @ defs) - else - preprocess dir target opts (else_defs @ defs) - + let then_defs, else_defs, defs = cond_pragma l defs in + if StringSet.mem symbol !symbols then preprocess dir target opts (then_defs @ defs) + else preprocess dir target opts (else_defs @ defs) | DEF_aux (DEF_pragma ("iftarget", t), l) :: defs -> - let then_defs, else_defs, defs = cond_pragma l defs in - begin match target with - | Some t' when t = t' -> - preprocess dir target opts (then_defs @ defs) - | _ -> - preprocess dir target opts (else_defs @ defs) - end - + let then_defs, else_defs, defs = cond_pragma l defs in + begin + match target with + | Some t' when t = t' -> preprocess dir target opts (then_defs @ defs) + | _ -> preprocess dir target opts (else_defs @ defs) + end | DEF_aux (DEF_pragma ("include", file), l) :: defs -> - let len = String.length file in - if len = 0 then - (Reporting.warn "" l "Skipping bad $include. No file argument."; preprocess dir target opts defs) - else if file.[0] = '"' && file.[len - 1] = '"' then - let relative = match l with - | Parse_ast.Range (pos, _) -> Filename.dirname (Lexing.(pos.pos_fname)) - | _ -> failwith "Couldn't figure out relative path for $include. This really shouldn't ever happen." - in - let file = String.sub file 1 (len - 2) in - let include_file = Filename.concat relative file in - let include_defs = Initial_check.parse_file ~loc:l (Filename.concat relative file) |> snd |> preprocess dir target opts in - wrap_include l include_file include_defs @ preprocess dir target opts defs - else if file.[0] = '<' && file.[len - 1] = '>' then - let file = String.sub file 1 (len - 2) in - let sail_dir = Reporting.get_sail_dir dir in - let file = Filename.concat sail_dir ("lib/" ^ file) in - let include_defs = Initial_check.parse_file ~loc:l file |> snd |> preprocess dir target opts in - wrap_include l file include_defs @ preprocess dir target opts defs - else - let help = "Make sure the filename is surrounded by quotes or angle brackets" in - (Reporting.warn "" l ("Skipping bad $include " ^ file ^ ". " ^ help); preprocess dir target opts defs) - + let len = String.length file in + if len = 0 then ( + Reporting.warn "" l "Skipping bad $include. No file argument."; + preprocess dir target opts defs + ) + else if file.[0] = '"' && file.[len - 1] = '"' then ( + let relative = + match l with + | Parse_ast.Range (pos, _) -> Filename.dirname Lexing.(pos.pos_fname) + | _ -> failwith "Couldn't figure out relative path for $include. This really shouldn't ever happen." + in + let file = String.sub file 1 (len - 2) in + let include_file = Filename.concat relative file in + let include_defs = + Initial_check.parse_file ~loc:l (Filename.concat relative file) |> snd |> preprocess dir target opts + in + wrap_include l include_file include_defs @ preprocess dir target opts defs + ) + else if file.[0] = '<' && file.[len - 1] = '>' then ( + let file = String.sub file 1 (len - 2) in + let sail_dir = Reporting.get_sail_dir dir in + let file = Filename.concat sail_dir ("lib/" ^ file) in + let include_defs = Initial_check.parse_file ~loc:l file |> snd |> preprocess dir target opts in + wrap_include l file include_defs @ preprocess dir target opts defs + ) + else ( + let help = "Make sure the filename is surrounded by quotes or angle brackets" in + Reporting.warn "" l ("Skipping bad $include " ^ file ^ ". " ^ help); + preprocess dir target opts defs + ) | DEF_aux (DEF_pragma ("suppress_warnings", _), l) :: defs -> - begin match Reporting.simp_loc l with - | None -> () (* This shouldn't happen, but if it does just continue *) - | Some (p, _) -> Reporting.suppress_warnings_for_file p.pos_fname - end; - preprocess dir target opts defs - + begin + match Reporting.simp_loc l with + | None -> () (* This shouldn't happen, but if it does just continue *) + | Some (p, _) -> Reporting.suppress_warnings_for_file p.pos_fname + end; + preprocess dir target opts defs (* Filter file_start and file_end out of the AST so when we round-trip files through the compiler we don't end up with incorrect start/end annotations *) | (DEF_aux (DEF_pragma ("file_start", _), _) | DEF_aux (DEF_pragma ("file_end", _), _)) :: defs -> - preprocess dir target opts defs - + preprocess dir target opts defs | DEF_aux (DEF_pragma (p, arg), l) :: defs -> - if not (StringSet.mem p all_pragmas) then - Reporting.warn "" l ("Unrecognised directive: " ^ p); - DEF_aux (DEF_pragma (p, arg), l) :: preprocess dir target opts defs - + if not (StringSet.mem p all_pragmas) then Reporting.warn "" l ("Unrecognised directive: " ^ p); + DEF_aux (DEF_pragma (p, arg), l) :: preprocess dir target opts defs | DEF_aux (DEF_outcome (outcome_spec, inner_defs), l) :: defs -> - DEF_aux (DEF_outcome (outcome_spec, preprocess dir target opts inner_defs), l) :: preprocess dir target opts defs - - | (DEF_aux (DEF_default (DT_aux (DT_order (_, ATyp_aux (atyp, _)), _)), l) as def) :: defs -> - begin match atyp with - | Parse_ast.ATyp_inc -> symbols := StringSet.add "_DEFAULT_INC" !symbols; def :: preprocess dir target opts defs - | Parse_ast.ATyp_dec -> symbols := StringSet.add "_DEFAULT_DEC" !symbols; def :: preprocess dir target opts defs - | _ -> def :: preprocess dir target opts defs - end - + DEF_aux (DEF_outcome (outcome_spec, preprocess dir target opts inner_defs), l) :: preprocess dir target opts defs + | (DEF_aux (DEF_default (DT_aux (DT_order (_, ATyp_aux (atyp, _)), _)), l) as def) :: defs -> begin + match atyp with + | Parse_ast.ATyp_inc -> + symbols := StringSet.add "_DEFAULT_INC" !symbols; + def :: preprocess dir target opts defs + | Parse_ast.ATyp_dec -> + symbols := StringSet.add "_DEFAULT_DEC" !symbols; + def :: preprocess dir target opts defs + | _ -> def :: preprocess dir target opts defs + end | def :: defs -> def :: preprocess dir target opts defs diff --git a/src/lib/preprocess.mli b/src/lib/preprocess.mli index 73717d310..47a42cca8 100644 --- a/src/lib/preprocess.mli +++ b/src/lib/preprocess.mli @@ -69,4 +69,5 @@ val clear_symbols : unit -> unit val have_symbol : string -> bool val add_symbol : string -> unit -val preprocess : string -> string option -> (Arg.key * Arg.spec * Arg.doc) list -> Parse_ast.def list -> Parse_ast.def list +val preprocess : + string -> string option -> (Arg.key * Arg.spec * Arg.doc) list -> Parse_ast.def list -> Parse_ast.def list diff --git a/src/lib/pretty_print_common.ml b/src/lib/pretty_print_common.ml index 344da64e8..b6be67eed 100644 --- a/src/lib/pretty_print_common.ml +++ b/src/lib/pretty_print_common.ml @@ -97,5 +97,5 @@ let doc_int i = string (Big_int.to_string i) let doc_op symb a b = infix 2 1 symb a b let doc_unop symb a = prefix 2 1 symb a -let print ?(len=100) channel doc = ToChannel.pretty 1. len channel doc -let to_buf ?(len=100) buf doc = ToBuffer.pretty 1. len buf doc +let print ?(len = 100) channel doc = ToChannel.pretty 1. len channel doc +let to_buf ?(len = 100) buf doc = ToBuffer.pretty 1. len buf doc diff --git a/src/lib/pretty_print_sail.ml b/src/lib/pretty_print_sail.ml index a125e896e..a3c8bf45d 100644 --- a/src/lib/pretty_print_sail.ml +++ b/src/lib/pretty_print_sail.ml @@ -77,35 +77,26 @@ module Big_int = Nat_big_num let doc_op symb a b = infix 2 1 symb a b -let doc_id (Id_aux (id_aux, _)) = - string (match id_aux with - | Id v -> v - | Operator op -> "operator " ^ op) +let doc_id (Id_aux (id_aux, _)) = string (match id_aux with Id v -> v | Operator op -> "operator " ^ op) let doc_kid kid = string (Ast_util.string_of_kid kid) let doc_attr attr arg = - if arg = "" then - Printf.ksprintf string "$[%s]" attr ^^ space - else - Printf.ksprintf string "$[%s %s]" attr arg ^^ space + if arg = "" then Printf.ksprintf string "$[%s]" attr ^^ space else Printf.ksprintf string "$[%s %s]" attr arg ^^ space let doc_kopt_no_parens = function | kopt when is_int_kopt kopt -> doc_kid (kopt_kid kopt) | kopt when is_typ_kopt kopt -> separate space [doc_kid (kopt_kid kopt); colon; string "Type"] | kopt when is_order_kopt kopt -> separate space [doc_kid (kopt_kid kopt); colon; string "Order"] | kopt -> separate space [doc_kid (kopt_kid kopt); colon; string "Bool"] - + let doc_kopt = function | kopt when is_int_kopt kopt -> doc_kopt_no_parens kopt | kopt -> parens (doc_kopt_no_parens kopt) let doc_int n = string (Big_int.to_string n) -let doc_ord (Ord_aux(o,_)) = match o with - | Ord_var v -> doc_kid v - | Ord_inc -> string "inc" - | Ord_dec -> string "dec" +let doc_ord (Ord_aux (o, _)) = match o with Ord_var v -> doc_kid v | Ord_inc -> string "inc" | Ord_dec -> string "dec" let rec doc_typ_pat (TP_aux (tpat_aux, _)) = match tpat_aux with @@ -117,27 +108,23 @@ let doc_nexp nexp = let rec atomic_nexp (Nexp_aux (n_aux, _) as nexp) = match n_aux with | Nexp_constant c -> string (Big_int.to_string c) - | Nexp_app (Id_aux (Operator op, _), [n1; n2]) -> - separate space [atomic_nexp n1; string op; atomic_nexp n2] + | Nexp_app (Id_aux (Operator op, _), [n1; n2]) -> separate space [atomic_nexp n1; string op; atomic_nexp n2] | Nexp_app (_id, _nexps) -> string (string_of_nexp nexp) (* This segfaults??!!!! doc_id id ^^ (parens (separate_map (comma ^^ space) doc_nexp nexps)) - *) + *) | Nexp_id id -> doc_id id | Nexp_var kid -> doc_kid kid | _ -> parens (nexp0 nexp) and nexp0 (Nexp_aux (n_aux, _) as nexp) = match n_aux with - | Nexp_sum (n1, Nexp_aux (Nexp_neg n2, _)) | Nexp_minus (n1, n2) -> - separate space [nexp0 n1; string "-"; nexp1 n2] + | Nexp_sum (n1, Nexp_aux (Nexp_neg n2, _)) | Nexp_minus (n1, n2) -> separate space [nexp0 n1; string "-"; nexp1 n2] | Nexp_sum (n1, Nexp_aux (Nexp_constant c, _)) when Big_int.less c Big_int.zero -> - separate space [nexp0 n1; string "-"; doc_int (Big_int.abs c)] + separate space [nexp0 n1; string "-"; doc_int (Big_int.abs c)] | Nexp_sum (n1, n2) -> separate space [nexp0 n1; string "+"; nexp1 n2] | _ -> nexp1 nexp and nexp1 (Nexp_aux (n_aux, _) as nexp) = - match n_aux with - | Nexp_times (n1, n2) -> separate space [nexp1 n1; string "*"; nexp2 n2] - | _ -> nexp2 nexp + match n_aux with Nexp_times (n1, n2) -> separate space [nexp1 n1; string "*"; nexp2 n2] | _ -> nexp2 nexp and nexp2 (Nexp_aux (n_aux, _) as nexp) = match n_aux with | Nexp_neg n -> separate space [string "-"; atomic_nexp n] @@ -159,28 +146,28 @@ let rec doc_nc nc = | NC_bounded_le (n1, n2) -> nc_op "<=" n1 n2 | NC_bounded_lt (n1, n2) -> nc_op "<" n1 n2 | NC_set (kid, ints) -> - separate space [doc_kid kid; string "in"; braces (separate_map (comma ^^ space) doc_int ints)] - | NC_app (id, args) -> - doc_id id ^^ parens (separate_map (comma ^^ space) doc_typ_arg args) + separate space [doc_kid kid; string "in"; braces (separate_map (comma ^^ space) doc_int ints)] + | NC_app (id, args) -> doc_id id ^^ parens (separate_map (comma ^^ space) doc_typ_arg args) | NC_var kid -> doc_kid kid | NC_or _ | NC_and _ -> nc0 ~parenthesize:true nc - and nc0 ?parenthesize:(parenthesize=false) nc = + and nc0 ?(parenthesize = false) nc = (* Rather than parens (nc0 x) we use nc0 ~parenthesize:true x, because if we rewrite a disjunction as a set constraint, then we can always omit the parens. *) - let parens' = if parenthesize then parens else (fun x -> x) in + let parens' = if parenthesize then parens else fun x -> x in let disjs = constraint_disj nc in let collect_constants kid = function - | NC_aux (NC_equal (Nexp_aux (Nexp_var kid', _), Nexp_aux (Nexp_constant c, _)), _) when Kid.compare kid kid' = 0 -> Some c + | NC_aux (NC_equal (Nexp_aux (Nexp_var kid', _), Nexp_aux (Nexp_constant c, _)), _) when Kid.compare kid kid' = 0 + -> + Some c | _ -> None in match disjs with - | NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant c, _)), _) :: ncs -> - begin match Util.option_all (List.map (collect_constants kid) ncs) with - | None | Some [] -> parens' (separate_map (space ^^ bar ^^ space) nc1 disjs) - | Some cs -> - separate space [doc_kid kid; string "in"; braces (separate_map (comma ^^ space) doc_int (c :: cs))] - end + | NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant c, _)), _) :: ncs -> begin + match Util.option_all (List.map (collect_constants kid) ncs) with + | None | Some [] -> parens' (separate_map (space ^^ bar ^^ space) nc1 disjs) + | Some cs -> separate space [doc_kid kid; string "in"; braces (separate_map (comma ^^ space) doc_int (c :: cs))] + end | _ -> parens' (separate_map (space ^^ bar ^^ space) nc1 disjs) and nc1 nc = let conjs = constraint_conj nc in @@ -188,53 +175,48 @@ let rec doc_nc nc = in atomic_nc (constraint_simp nc) -and doc_typ ?(simple=false) (Typ_aux (typ_aux, l)) = +and doc_typ ?(simple = false) (Typ_aux (typ_aux, l)) = match typ_aux with | Typ_id id -> doc_id id | Typ_app (id, []) -> doc_id id - | Typ_app (Id_aux (Operator str, _), [x; y]) -> - separate space [doc_typ_arg x; string str; doc_typ_arg y] + | Typ_app (Id_aux (Operator str, _), [x; y]) -> separate space [doc_typ_arg x; string str; doc_typ_arg y] | Typ_app (id, typs) when Id.compare id (mk_id "atom") = 0 -> - string "int" ^^ parens (separate_map (string ", ") doc_typ_arg typs) + string "int" ^^ parens (separate_map (string ", ") doc_typ_arg typs) | Typ_app (id, typs) when Id.compare id (mk_id "atom_bool") = 0 -> - string "bool" ^^ parens (separate_map (string ", ") doc_typ_arg typs) + string "bool" ^^ parens (separate_map (string ", ") doc_typ_arg typs) | Typ_app (id, typs) -> doc_id id ^^ parens (separate_map (string ", ") doc_typ_arg typs) | Typ_tuple typs -> parens (separate_map (string ", ") doc_typ typs) | Typ_var kid -> doc_kid kid (* Resugar set types like {|1, 2, 3|} *) - | Typ_exist ([kopt], - NC_aux (NC_set (kid1, ints), _), - Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var kid2, _)), _)]), _)) - when Kid.compare (kopt_kid kopt) kid1 == 0 && Kid.compare kid1 kid2 == 0 && Id.compare (mk_id "atom") id == 0 -> - enclose (string "{|") (string "|}") (separate_map (string ", ") doc_int ints) + | Typ_exist + ( [kopt], + NC_aux (NC_set (kid1, ints), _), + Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var kid2, _)), _)]), _) + ) + when Kid.compare (kopt_kid kopt) kid1 == 0 && Kid.compare kid1 kid2 == 0 && Id.compare (mk_id "atom") id == 0 -> + enclose (string "{|") (string "|}") (separate_map (string ", ") doc_int ints) | Typ_exist (kopts, nc, typ) -> - braces (separate_map space doc_kopt kopts ^^ comma ^^ space ^^ doc_nc nc ^^ dot ^^ space ^^ doc_typ typ) - | Typ_fn (typs, typ) -> - separate space [doc_arg_typs typs; string "->"; doc_typ ~simple:simple typ] - | Typ_bidir (typ1, typ2) -> - separate space [doc_typ typ1; string "<->"; doc_typ typ2] + braces (separate_map space doc_kopt kopts ^^ comma ^^ space ^^ doc_nc nc ^^ dot ^^ space ^^ doc_typ typ) + | Typ_fn (typs, typ) -> separate space [doc_arg_typs typs; string "->"; doc_typ ~simple typ] + | Typ_bidir (typ1, typ2) -> separate space [doc_typ typ1; string "<->"; doc_typ typ2] | Typ_internal_unknown -> raise (Reporting.err_unreachable l __POS__ "escaped Typ_internal_unknown") + and doc_typ_arg (A_aux (ta_aux, _)) = match ta_aux with | A_typ typ -> doc_typ typ | A_nexp nexp -> doc_nexp nexp | A_order o -> doc_ord o | A_bool nc -> doc_nc nc -and doc_arg_typs = function - | [typ] -> doc_typ typ - | typs -> parens (separate_map (comma ^^ space) doc_typ typs) + +and doc_arg_typs = function [typ] -> doc_typ typ | typs -> parens (separate_map (comma ^^ space) doc_typ typs) let doc_subst (IS_aux (subst_aux, _)) = match subst_aux with | IS_typ (kid, typ) -> doc_kid kid ^^ space ^^ equals ^^ space ^^ doc_typ typ | IS_id (id1, id2) -> doc_id id1 ^^ space ^^ equals ^^ space ^^ doc_id id2 - + let doc_kind (K_aux (k, _)) = - string (match k with - | K_int -> "Int" - | K_type -> "Type" - | K_bool -> "Bool" - | K_order -> "Order") + string (match k with K_int -> "Int" | K_type -> "Type" | K_bool -> "Bool" | K_order -> "Order") let doc_kopts = separate_map space doc_kopt @@ -244,11 +226,7 @@ let doc_quants quants = | _ :: qis -> get_kopts qis | [] -> [] in - let qi_nc (QI_aux (qi_aux, _)) = - match qi_aux with - | QI_constraint nc -> [nc] - | _ -> [] - in + let qi_nc (QI_aux (qi_aux, _)) = match qi_aux with QI_constraint nc -> [nc] | _ -> [] in let kdoc = doc_kopts (get_kopts quants) in let ncs = List.concat (List.map qi_nc quants) in match ncs with @@ -265,11 +243,7 @@ let doc_param_quants quants = | QI_id kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Order"] | QI_constraint _ -> [] in - let qi_nc (QI_aux (qi_aux, _)) = - match qi_aux with - | QI_constraint nc -> [nc] - | _ -> [] - in + let qi_nc (QI_aux (qi_aux, _)) = match qi_aux with QI_constraint nc -> [nc] | _ -> [] in let kdoc = separate (comma ^^ space) (List.concat (List.map doc_qi_kopt quants)) in let ncs = List.concat (List.map qi_nc quants) in match ncs with @@ -277,46 +251,43 @@ let doc_param_quants quants = | [nc] -> parens kdoc ^^ comma ^^ space ^^ doc_nc nc | nc :: ncs -> parens kdoc ^^ comma ^^ space ^^ doc_nc (List.fold_left nc_and nc ncs) -let doc_binding ?(simple=false) ((TypQ_aux (tq_aux, _) as typq), typ) = +let doc_binding ?(simple = false) ((TypQ_aux (tq_aux, _) as typq), typ) = match tq_aux with - | TypQ_no_forall -> doc_typ ~simple:simple typ - | TypQ_tq [] -> doc_typ ~simple:simple typ + | TypQ_no_forall -> doc_typ ~simple typ + | TypQ_tq [] -> doc_typ ~simple typ | TypQ_tq qs -> - if !opt_use_heuristics && String.length (string_of_typquant typq) > 60 then - let kopts, ncs = quant_split typq in - if ncs = [] then - string "forall" ^^ space ^^ separate_map space doc_kopt kopts ^^ dot - ^//^ doc_typ ~simple:simple typ - else - string "forall" ^^ space ^^ separate_map space doc_kopt kopts ^^ comma - ^//^ (separate_map (space ^^ string "&" ^^ space) doc_nc ncs ^^ dot - ^^ hardline ^^ doc_typ ~simple:simple typ) - else - string "forall" ^^ space ^^ doc_quants qs ^^ dot ^//^ doc_typ ~simple:simple typ - -let doc_typschm ?(simple=false) (TypSchm_aux (TypSchm_ts (typq, typ), _)) = doc_binding ~simple:simple (typq, typ) + if !opt_use_heuristics && String.length (string_of_typquant typq) > 60 then ( + let kopts, ncs = quant_split typq in + if ncs = [] then string "forall" ^^ space ^^ separate_map space doc_kopt kopts ^^ dot ^//^ doc_typ ~simple typ + else + string "forall" ^^ space ^^ separate_map space doc_kopt kopts ^^ comma + ^//^ separate_map (space ^^ string "&" ^^ space) doc_nc ncs + ^^ dot ^^ hardline ^^ doc_typ ~simple typ + ) + else string "forall" ^^ space ^^ doc_quants qs ^^ dot ^//^ doc_typ ~simple typ + +let doc_typschm ?(simple = false) (TypSchm_aux (TypSchm_ts (typq, typ), _)) = doc_binding ~simple (typq, typ) let doc_typschm_typ (TypSchm_aux (TypSchm_ts (_, typ), _)) = doc_typ typ let doc_typquant (TypQ_aux (tq_aux, _)) = - match tq_aux with - | TypQ_no_forall -> None - | TypQ_tq [] -> None - | TypQ_tq qs -> Some (doc_param_quants qs) - -let doc_lit (L_aux(l,_)) = - utf8string (match l with - | L_unit -> "()" - | L_zero -> "bitzero" - | L_one -> "bitone" - | L_true -> "true" - | L_false -> "false" - | L_num i -> Big_int.to_string i - | L_hex n -> "0x" ^ n - | L_bin n -> "0b" ^ n - | L_real r -> r - | L_undef -> "undefined" - | L_string s -> "\"" ^ String.escaped s ^ "\"") + match tq_aux with TypQ_no_forall -> None | TypQ_tq [] -> None | TypQ_tq qs -> Some (doc_param_quants qs) + +let doc_lit (L_aux (l, _)) = + utf8string + ( match l with + | L_unit -> "()" + | L_zero -> "bitzero" + | L_one -> "bitone" + | L_true -> "true" + | L_false -> "false" + | L_num i -> Big_int.to_string i + | L_hex n -> "0x" ^ n + | L_bin n -> "0b" ^ n + | L_real r -> r + | L_undef -> "undefined" + | L_string s -> "\"" ^ String.escaped s ^ "\"" + ) let rec doc_pat (P_aux (p_aux, (_, uannot))) = concat_map (fun (_, attr, arg) -> doc_attr attr arg) (get_attributes uannot) @@ -329,37 +300,27 @@ let rec doc_pat (P_aux (p_aux, (_, uannot))) = | P_typ (typ, pat) -> separate space [doc_pat pat; colon; doc_typ typ] | P_lit lit -> doc_lit lit (* P_var short form sugar *) - | P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)) when Id.compare (id_of_kid kid) id == 0 -> - doc_kid kid + | P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)) when Id.compare (id_of_kid kid) id == 0 -> doc_kid kid | P_var (pat, tpat) -> parens (separate space [doc_pat pat; string "as"; doc_typ_pat tpat]) | P_vector pats -> brackets (separate_map (comma ^^ space) doc_pat pats) | P_vector_concat pats -> parens (separate_map (space ^^ string "@" ^^ space) doc_pat pats) | P_vector_subrange (id, n, m) -> - if Big_int.equal n m then - doc_id id ^^ brackets (string (Big_int.to_string n)) - else - doc_id id ^^ brackets (string (Big_int.to_string n) ^^ string ".." ^^ string (Big_int.to_string m)) + if Big_int.equal n m then doc_id id ^^ brackets (string (Big_int.to_string n)) + else doc_id id ^^ brackets (string (Big_int.to_string n) ^^ string ".." ^^ string (Big_int.to_string m)) | P_wild -> string "_" | P_as (pat, id) -> parens (separate space [doc_pat pat; string "as"; doc_id id]) | P_app (id, pats) -> doc_id id ^^ parens (separate_map (comma ^^ space) doc_pat pats) | P_list pats -> string "[|" ^^ separate_map (comma ^^ space) doc_pat pats ^^ string "|]" | P_cons (hd_pat, tl_pat) -> parens (separate space [doc_pat hd_pat; string "::"; doc_pat tl_pat]) | P_string_append [] -> string "\"\"" - | P_string_append pats -> - parens (separate_map (string " ^ ") doc_pat pats) + | P_string_append pats -> parens (separate_map (string " ^ ") doc_pat pats) (* if_block_x is true if x should be printed like a block, i.e. with newlines. Blocks are automatically printed as blocks, so this returns false for them. *) -let if_block_then (E_aux (e_aux, _)) = - match e_aux with - | E_assign _ | E_if _ -> true - | _ -> false +let if_block_then (E_aux (e_aux, _)) = match e_aux with E_assign _ | E_if _ -> true | _ -> false -let if_block_else (E_aux (e_aux, _)) = - match e_aux with - | E_assign _ -> true - | _ -> false +let if_block_else (E_aux (e_aux, _)) = match e_aux with E_assign _ -> true | _ -> false let fixities = let fixities' = @@ -385,149 +346,143 @@ let fixities = in ref (fixities' : (prec * int) Bindings.t) -type 'a vector_update = - | VU_single of 'a exp * 'a exp - | VU_range of 'a exp * 'a exp * 'a exp +type 'a vector_update = VU_single of 'a exp * 'a exp | VU_range of 'a exp * 'a exp * 'a exp let rec get_vector_updates (E_aux (e_aux, _) as exp) = match e_aux with | E_vector_update (exp1, exp2, exp3) -> - let input, updates = get_vector_updates exp1 in - input, updates @ [VU_single (exp2, exp3)] + let input, updates = get_vector_updates exp1 in + (input, updates @ [VU_single (exp2, exp3)]) | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> - let input, updates = get_vector_updates exp1 in - input, updates @ [VU_range (exp2, exp3, exp4)] - | _ -> exp, [] + let input, updates = get_vector_updates exp1 in + (input, updates @ [VU_range (exp2, exp3, exp4)]) + | _ -> (exp, []) let rec doc_exp (E_aux (e_aux, (_, uannot)) as exp) = concat_map (fun (_, attr, arg) -> doc_attr attr arg) (get_attributes uannot) ^^ match e_aux with | E_block [] -> string "()" - | E_block exps -> - group (lbrace ^^ nest 4 (hardline ^^ doc_block exps) ^^ hardline ^^ rbrace) + | E_block exps -> group (lbrace ^^ nest 4 (hardline ^^ doc_block exps) ^^ hardline ^^ rbrace) (* This is mostly for the -convert option *) | E_app_infix (x, id, y) when Id.compare (mk_id "quot") id == 0 -> - separate space [doc_atomic_exp x; string "/"; doc_atomic_exp y] + separate space [doc_atomic_exp x; string "/"; doc_atomic_exp y] | E_app_infix _ -> doc_infix 0 exp | E_tuple exps -> parens (separate_map (comma ^^ space) doc_exp exps) - | E_if (if_exp, then_exp, (E_aux (E_if (_, _, _), _) as else_exp)) when !opt_insert_braces -> - separate space [string "if"; doc_exp if_exp; string "then"] ^^ space - ^^ doc_exp_as_block then_exp - ^^ space ^^ string "else" ^^ space - ^^ doc_exp else_exp + separate space [string "if"; doc_exp if_exp; string "then"] + ^^ space ^^ doc_exp_as_block then_exp ^^ space ^^ string "else" ^^ space ^^ doc_exp else_exp | E_if (if_exp, then_exp, else_exp) when !opt_insert_braces -> - separate space [string "if"; doc_exp if_exp; string "then"] ^^ space - ^^ doc_exp_as_block then_exp - ^^ space ^^ string "else" ^^ space - ^^ doc_exp_as_block else_exp - + separate space [string "if"; doc_exp if_exp; string "then"] + ^^ space ^^ doc_exp_as_block then_exp ^^ space ^^ string "else" ^^ space ^^ doc_exp_as_block else_exp (* Various rules to try to format if blocks nicely based on content. There's also an if rule in doc_block for { ... if . then .; ... } because it's unambiguous there. *) | E_if (if_exp, then_exp, else_exp) when if_block_then then_exp && if_block_else else_exp -> - (separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) - ^/^ (string "else" ^//^ doc_exp else_exp) + (separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) + ^/^ string "else" ^//^ doc_exp else_exp | E_if (if_exp, then_exp, (E_aux (E_if _, _) as else_exp)) when if_block_then then_exp -> - (separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) - ^/^ (string "else" ^^ space ^^ doc_exp else_exp) + (separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) + ^/^ string "else" ^^ space ^^ doc_exp else_exp | E_if (if_exp, then_exp, else_exp) when if_block_else else_exp -> - (separate space [string "if"; doc_exp if_exp; string "then"; doc_exp then_exp]) - ^^ space ^^ (string "else" ^//^ doc_exp else_exp) + separate space [string "if"; doc_exp if_exp; string "then"; doc_exp then_exp] + ^^ space ^^ string "else" ^//^ doc_exp else_exp | E_if (if_exp, then_exp, else_exp) when if_block_then then_exp -> - (separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) - ^/^ (string "else" ^^ space ^^ doc_exp else_exp) + (separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) + ^/^ string "else" ^^ space ^^ doc_exp else_exp | E_if (if_exp, E_aux (E_block (_ :: _ as then_exps), _), E_aux (E_block (_ :: _ as else_exps), _)) -> - separate space [string "if"; doc_exp if_exp; string "then {"] ^^ group (nest 4 (hardline ^^ doc_block then_exps) ^^ hardline) ^^ string "} else {" ^^ group (nest 4 (hardline ^^ doc_block else_exps) ^^ hardline ^^ rbrace) - | E_if (if_exp, E_aux (E_block (_ :: _ as then_exps), _), (E_aux (E_if _,_) as else_exp)) -> - separate space [string "if"; doc_exp if_exp; string "then {"] ^^ group (nest 4 (hardline ^^ doc_block then_exps) ^^ hardline) ^^ string "} else " ^^ doc_exp else_exp + separate space [string "if"; doc_exp if_exp; string "then {"] + ^^ group (nest 4 (hardline ^^ doc_block then_exps) ^^ hardline) + ^^ string "} else {" + ^^ group (nest 4 (hardline ^^ doc_block else_exps) ^^ hardline ^^ rbrace) + | E_if (if_exp, E_aux (E_block (_ :: _ as then_exps), _), (E_aux (E_if _, _) as else_exp)) -> + separate space [string "if"; doc_exp if_exp; string "then {"] + ^^ group (nest 4 (hardline ^^ doc_block then_exps) ^^ hardline) + ^^ string "} else " ^^ doc_exp else_exp | E_if (if_exp, E_aux (E_block (_ :: _ as then_exps), _), else_exp) -> - separate space [string "if"; doc_exp if_exp; string "then {"] ^^ group (nest 4 (hardline ^^ doc_block then_exps) ^^ hardline) ^^ string "} else" ^//^ doc_exp else_exp + separate space [string "if"; doc_exp if_exp; string "then {"] + ^^ group (nest 4 (hardline ^^ doc_block then_exps) ^^ hardline) + ^^ string "} else" ^//^ doc_exp else_exp | E_if (if_exp, then_exp, E_aux (E_block (_ :: _ as else_exps), _)) -> - group ((separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) ^/^ string "else {") ^^ group (nest 4 (hardline ^^ doc_block else_exps) ^^ hardline ^^ rbrace) + group ((separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) ^/^ string "else {") + ^^ group (nest 4 (hardline ^^ doc_block else_exps) ^^ hardline ^^ rbrace) | E_if (if_exp, then_exp, else_exp) -> - group ((separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) ^/^ string "else") ^//^ doc_exp else_exp - + group ((separate space [string "if"; doc_exp if_exp; string "then"] ^//^ doc_exp then_exp) ^/^ string "else") + ^//^ doc_exp else_exp | E_list exps -> string "[|" ^^ separate_map (comma ^^ space) doc_exp exps ^^ string "|]" | E_cons (exp1, exp2) -> doc_atomic_exp exp1 ^^ space ^^ string "::" ^^ space ^^ doc_exp exp2 | E_struct fexps -> separate space [string "struct"; string "{"; doc_fexps fexps; string "}"] | E_loop (While, measure, cond, exp) -> - separate space ([string "while"] @ doc_measure measure @ [doc_exp cond; string "do"; doc_exp exp]) + separate space ([string "while"] @ doc_measure measure @ [doc_exp cond; string "do"; doc_exp exp]) | E_loop (Until, measure, cond, exp) -> - separate space ([string "repeat"] @ doc_measure measure @ [doc_exp exp; string "until"; doc_exp cond]) - | E_struct_update (exp, fexps) -> - separate space [string "{"; doc_exp exp; string "with"; doc_fexps fexps; string "}"] + separate space ([string "repeat"] @ doc_measure measure @ [doc_exp exp; string "until"; doc_exp cond]) + | E_struct_update (exp, fexps) -> separate space [string "{"; doc_exp exp; string "with"; doc_fexps fexps; string "}"] | E_vector_append (exp1, exp2) -> separate space [doc_atomic_exp exp1; string "@"; doc_atomic_exp exp2] - | E_match (exp, pexps) -> - separate space [string "match"; doc_exp exp; doc_pexps pexps] - | E_let (LB_aux (LB_val (pat, binding), _), exp) -> - doc_let_style "let" (doc_pat pat) (doc_exp binding) exp - | E_internal_plet (pat, exp1, exp2) -> - doc_let_style "internal_plet" (doc_pat pat) (doc_exp exp1) exp2 - | E_var (lexp, binding, exp) -> - doc_let_style "var" (doc_lexp lexp) (doc_exp binding) exp - | E_assign (lexp, exp) -> - separate space [doc_lexp lexp; equals; doc_exp exp] + | E_match (exp, pexps) -> separate space [string "match"; doc_exp exp; doc_pexps pexps] + | E_let (LB_aux (LB_val (pat, binding), _), exp) -> doc_let_style "let" (doc_pat pat) (doc_exp binding) exp + | E_internal_plet (pat, exp1, exp2) -> doc_let_style "internal_plet" (doc_pat pat) (doc_exp exp1) exp2 + | E_var (lexp, binding, exp) -> doc_let_style "var" (doc_lexp lexp) (doc_exp binding) exp + | E_assign (lexp, exp) -> separate space [doc_lexp lexp; equals; doc_exp exp] | E_for (id, exp1, exp2, exp3, order, exp4) -> - let header = - string "foreach" ^^ space ^^ - group (parens (separate (break 1) - [ doc_id id; - string "from " ^^ doc_atomic_exp exp1; - string "to " ^^ doc_atomic_exp exp2; - string "by " ^^ doc_atomic_exp exp3; - string "in " ^^ doc_ord order ])) - in - header ^^ space ^^ doc_exp exp4 + let header = + string "foreach" ^^ space + ^^ group + (parens + (separate (break 1) + [ + doc_id id; + string "from " ^^ doc_atomic_exp exp1; + string "to " ^^ doc_atomic_exp exp2; + string "by " ^^ doc_atomic_exp exp3; + string "in " ^^ doc_ord order; + ] + ) + ) + in + header ^^ space ^^ doc_exp exp4 (* Resugar an assert with an empty message *) | E_throw exp -> string "throw" ^^ parens (doc_exp exp) - | E_try (exp, pexps) -> - separate space [string "try"; doc_exp exp; string "catch"; doc_pexps pexps] + | E_try (exp, pexps) -> separate space [string "try"; doc_exp exp; string "catch"; doc_pexps pexps] | E_return (E_aux (E_lit (L_aux (L_unit, _)), _)) -> string "return()" | E_return exp -> string "return" ^^ parens (doc_exp exp) | E_internal_return exp -> string "internal_return" ^^ parens (doc_exp exp) | E_app (id, [exp]) when Id.compare (mk_id "pow2") id == 0 -> - separate space [string "2"; string "^"; doc_atomic_exp exp] - | E_internal_assume (nc, exp) -> - doc_let_style_general "internal_assume" (doc_nc nc) None exp + separate space [string "2"; string "^"; doc_atomic_exp exp] + | E_internal_assume (nc, exp) -> doc_let_style_general "internal_assume" (doc_nc nc) None exp | _ -> doc_atomic_exp exp + and doc_let_style keyword lhs rhs body = doc_let_style_general keyword lhs (Some rhs) body + and doc_let_style_general keyword lhs rhs body = (* Avoid staircases *) - let (^///^) = - match unaux_exp body with - | E_let _ | E_var _ | E_internal_plet _ | E_internal_assume _ -> (^/^) - | _ -> (^//^) + let ( ^///^ ) = + match unaux_exp body with E_let _ | E_var _ | E_internal_plet _ | E_internal_assume _ -> ( ^/^ ) | _ -> ( ^//^ ) in match rhs with | Some rhs -> group ((separate space [string keyword; lhs; equals] ^//^ rhs) ^/^ string "in") ^///^ doc_exp body - | None -> group ((string keyword ^//^ lhs) ^/^ string "in") ^///^ doc_exp body + | None -> group ((string keyword ^//^ lhs) ^/^ string "in") ^///^ doc_exp body + and doc_measure (Measure_aux (m_aux, _)) = - match m_aux with - | Measure_none -> [] - | Measure_some exp -> [string "termination_measure"; braces (doc_exp exp)] + match m_aux with Measure_none -> [] | Measure_some exp -> [string "termination_measure"; braces (doc_exp exp)] + and doc_infix n (E_aux (e_aux, _) as exp) = match e_aux with - | E_app_infix (l, op, r) when n < 10 -> - begin - try - match Bindings.find op !fixities with - | (Infix, m) when m >= n -> separate space [doc_infix (m + 1) l; doc_id op; doc_infix (m + 1) r] - | (Infix, m) -> parens (separate space [doc_infix (m + 1) l; doc_id op; doc_infix (m + 1) r]) - | (InfixL, m) when m >= n -> separate space [doc_infix m l; doc_id op; doc_infix (m + 1) r] - | (InfixL, m) -> parens (separate space [doc_infix m l; doc_id op; doc_infix (m + 1) r]) - | (InfixR, m) when m >= n -> separate space [doc_infix (m + 1) l; doc_id op; doc_infix m r] - | (InfixR, m) -> parens (separate space [doc_infix (m + 1) l; doc_id op; doc_infix m r]) - with - | Not_found -> - parens (separate space [doc_atomic_exp l; doc_id op; doc_atomic_exp r]) - end + | E_app_infix (l, op, r) when n < 10 -> begin + try + match Bindings.find op !fixities with + | Infix, m when m >= n -> separate space [doc_infix (m + 1) l; doc_id op; doc_infix (m + 1) r] + | Infix, m -> parens (separate space [doc_infix (m + 1) l; doc_id op; doc_infix (m + 1) r]) + | InfixL, m when m >= n -> separate space [doc_infix m l; doc_id op; doc_infix (m + 1) r] + | InfixL, m -> parens (separate space [doc_infix m l; doc_id op; doc_infix (m + 1) r]) + | InfixR, m when m >= n -> separate space [doc_infix (m + 1) l; doc_id op; doc_infix m r] + | InfixR, m -> parens (separate space [doc_infix (m + 1) l; doc_id op; doc_infix m r]) + with Not_found -> parens (separate space [doc_atomic_exp l; doc_id op; doc_atomic_exp r]) + end | _ -> doc_atomic_exp exp + and doc_atomic_exp (E_aux (e_aux, _) as exp) = match e_aux with - | E_typ (typ, exp) -> - separate space [doc_atomic_exp exp; colon; doc_typ typ] + | E_typ (typ, exp) -> separate space [doc_atomic_exp exp; colon; doc_typ typ] | E_lit lit -> doc_lit lit | E_id id -> doc_id id | E_ref id -> string "ref" ^^ space ^^ doc_id id @@ -542,40 +497,39 @@ and doc_atomic_exp (E_aux (e_aux, _) as exp) = | E_assert (exp1, exp2) -> string "assert" ^^ parens (doc_exp exp1 ^^ comma ^^ space ^^ doc_exp exp2) | E_exit exp -> string "exit" ^^ parens (doc_exp exp) | E_vector_access (exp1, exp2) -> doc_atomic_exp exp1 ^^ brackets (doc_exp exp2) - | E_vector_subrange (exp1, exp2, exp3) -> doc_atomic_exp exp1 ^^ brackets (separate space [doc_exp exp2; string ".."; doc_exp exp3]) + | E_vector_subrange (exp1, exp2, exp3) -> + doc_atomic_exp exp1 ^^ brackets (separate space [doc_exp exp2; string ".."; doc_exp exp3]) | E_vector exps -> brackets (separate_map (comma ^^ space) doc_exp exps) - | E_vector_update _ - | E_vector_update_subrange _ -> - let input, updates = get_vector_updates exp in - let updates_doc = separate_map (comma ^^ space) doc_vector_update updates in - brackets (separate space ([doc_exp input; string "with"; updates_doc])) + | E_vector_update _ | E_vector_update_subrange _ -> + let input, updates = get_vector_updates exp in + let updates_doc = separate_map (comma ^^ space) doc_vector_update updates in + brackets (separate space [doc_exp input; string "with"; updates_doc]) | E_internal_value v -> - if !Interactive.opt_interactive then - string (Value.string_of_value v |> Util.green |> Util.clear) - else - string (Value.string_of_value v) + if !Interactive.opt_interactive then string (Value.string_of_value v |> Util.green |> Util.clear) + else string (Value.string_of_value v) | _ -> parens (doc_exp exp) -and doc_fexps fexps = - separate_map (comma ^^ space) doc_fexp fexps -and doc_fexp (FE_aux (FE_fexp (id, exp), _)) = - separate space [doc_id id; equals; doc_exp exp] + +and doc_fexps fexps = separate_map (comma ^^ space) doc_fexp fexps + +and doc_fexp (FE_aux (FE_fexp (id, exp), _)) = separate space [doc_id id; equals; doc_exp exp] + and doc_block = function | [] -> string "()" | [E_aux (E_let (LB_aux (LB_val (pat, binding), _), E_aux (E_block exps, _)), _)] -> - separate space [string "let"; doc_pat pat; equals; doc_exp binding] ^^ semi ^^ hardline ^^ doc_block exps + separate space [string "let"; doc_pat pat; equals; doc_exp binding] ^^ semi ^^ hardline ^^ doc_block exps | [E_aux (E_let (LB_aux (LB_val (pat, binding), _), exp), _)] -> - separate space [string "let"; doc_pat pat; equals; doc_exp binding] ^^ semi ^^ hardline ^^ doc_block [exp] + separate space [string "let"; doc_pat pat; equals; doc_exp binding] ^^ semi ^^ hardline ^^ doc_block [exp] | [E_aux (E_var (lexp, binding, E_aux (E_block exps, _)), _)] -> - separate space [string "var"; doc_lexp lexp; equals; doc_exp binding] ^^ semi ^^ hardline ^^ doc_block exps - | (E_aux (E_if (if_exp, then_exp, E_aux ((E_lit (L_aux (L_unit, _)) | E_block []), _)), _))::exps -> - group (separate space [string "if"; doc_exp if_exp; string "then"; doc_exp then_exp]) ^^ semi ^^ hardline ^^ - doc_block exps + separate space [string "var"; doc_lexp lexp; equals; doc_exp binding] ^^ semi ^^ hardline ^^ doc_block exps + | E_aux (E_if (if_exp, then_exp, E_aux ((E_lit (L_aux (L_unit, _)) | E_block []), _)), _) :: exps -> + group (separate space [string "if"; doc_exp if_exp; string "then"; doc_exp then_exp]) + ^^ semi ^^ hardline ^^ doc_block exps | [exp] -> doc_exp exp | exp :: exps -> doc_exp exp ^^ semi ^^ hardline ^^ doc_block exps + and doc_lexp (LE_aux (l_aux, _) as lexp) = - match l_aux with - | LE_typ (typ, id) -> separate space [doc_id id; colon; doc_typ typ] - | _ -> doc_atomic_lexp lexp + match l_aux with LE_typ (typ, id) -> separate space [doc_id id; colon; doc_typ typ] | _ -> doc_atomic_lexp lexp + and doc_atomic_lexp (LE_aux (l_aux, _) as lexp) = match l_aux with | LE_id id -> doc_id id @@ -583,59 +537,61 @@ and doc_atomic_lexp (LE_aux (l_aux, _) as lexp) = | LE_tuple lexps -> lparen ^^ separate_map (comma ^^ space) doc_lexp lexps ^^ rparen | LE_field (lexp, id) -> doc_atomic_lexp lexp ^^ dot ^^ doc_id id | LE_vector (lexp, exp) -> doc_atomic_lexp lexp ^^ brackets (doc_exp exp) - | LE_vector_range (lexp, exp1, exp2) -> doc_atomic_lexp lexp ^^ brackets (separate space [doc_exp exp1; string ".."; doc_exp exp2]) + | LE_vector_range (lexp, exp1, exp2) -> + doc_atomic_lexp lexp ^^ brackets (separate space [doc_exp exp1; string ".."; doc_exp exp2]) | LE_vector_concat lexps -> parens (separate_map (string " @ ") doc_lexp lexps) | LE_app (id, exps) -> doc_id id ^^ parens (separate_map (comma ^^ space) doc_exp exps) | _ -> parens (doc_lexp lexp) + and doc_pexps pexps = surround 2 0 lbrace (separate_map (comma ^^ hardline) doc_pexp pexps) rbrace + and doc_pexp (Pat_aux (pat_aux, _)) = match pat_aux with | Pat_exp (pat, exp) -> separate space [doc_pat pat; string "=>"; doc_exp exp] - | Pat_when (pat, wh, exp) -> - separate space [doc_pat pat; string "if"; doc_exp wh; string "=>"; doc_exp exp] + | Pat_when (pat, wh, exp) -> separate space [doc_pat pat; string "if"; doc_exp wh; string "=>"; doc_exp exp] + and doc_letbind (LB_aux (lb_aux, _)) = - match lb_aux with - | LB_val (pat, exp) -> - separate space [doc_pat pat; equals; doc_exp exp] + match lb_aux with LB_val (pat, exp) -> separate space [doc_pat pat; equals; doc_exp exp] + and doc_exp_as_block (E_aux (aux, _) as exp) = match aux with | E_block _ | E_lit _ -> doc_exp exp - | _ when !opt_insert_braces -> - group (lbrace ^^ nest 4 (hardline ^^ doc_block [exp]) ^^ hardline ^^ rbrace) + | _ when !opt_insert_braces -> group (lbrace ^^ nest 4 (hardline ^^ doc_block [exp]) ^^ hardline ^^ rbrace) | _ -> doc_exp exp + and doc_vector_update = function - | VU_single (idx, value) -> - begin match unaux_exp idx, unaux_exp value with - | E_id id, E_id id' when Id.compare id id' == 0 -> - doc_atomic_exp idx - | _, _ -> - separate space [doc_atomic_exp idx; equals; doc_exp value] - end - | VU_range (high, low, value) -> separate space [doc_atomic_exp high; string ".."; doc_atomic_exp low; equals; doc_exp value] - -let doc_funcl (FCL_aux (FCL_funcl (id, Pat_aux (pexp,_)), _)) = + | VU_single (idx, value) -> begin + match (unaux_exp idx, unaux_exp value) with + | E_id id, E_id id' when Id.compare id id' == 0 -> doc_atomic_exp idx + | _, _ -> separate space [doc_atomic_exp idx; equals; doc_exp value] + end + | VU_range (high, low, value) -> + separate space [doc_atomic_exp high; string ".."; doc_atomic_exp low; equals; doc_exp value] + +let doc_funcl (FCL_aux (FCL_funcl (id, Pat_aux (pexp, _)), _)) = match pexp with - | Pat_exp (pat,exp) -> - group (separate space [doc_id id; doc_pat pat; equals; doc_exp_as_block exp]) - | Pat_when (pat,wh,exp) -> - group (separate space [doc_id id; parens (separate space [doc_pat pat; string "if"; doc_exp wh]); string "="; doc_exp_as_block exp]) + | Pat_exp (pat, exp) -> group (separate space [doc_id id; doc_pat pat; equals; doc_exp_as_block exp]) + | Pat_when (pat, wh, exp) -> + group + (separate space + [doc_id id; parens (separate space [doc_pat pat; string "if"; doc_exp wh]); string "="; doc_exp_as_block exp] + ) let doc_default (DT_aux (DT_order ord, _)) = separate space [string "default"; string "Order"; doc_ord ord] -let doc_rec (Rec_aux (r,_)) = +let doc_rec (Rec_aux (r, _)) = match r with - | Rec_nonrec - | Rec_rec -> empty - | Rec_measure (pat,exp) -> braces (doc_pat pat ^^ string " => " ^^ doc_exp exp) ^^ space + | Rec_nonrec | Rec_rec -> empty + | Rec_measure (pat, exp) -> braces (doc_pat pat ^^ string " => " ^^ doc_exp exp) ^^ space let doc_fundef (FD_aux (FD_function (r, _, funcls), annot)) = match funcls with | [] -> failwith "Empty function list" | _ -> - let rec_pp = doc_rec r in - let sep = hardline ^^ string "and" ^^ space in - let clauses = separate_map sep doc_funcl funcls in - string "function" ^^ space ^^ rec_pp ^^ clauses + let rec_pp = doc_rec r in + let sep = hardline ^^ string "and" ^^ space in + let clauses = separate_map sep doc_funcl funcls in + string "function" ^^ space ^^ rec_pp ^^ clauses let rec doc_mpat (MP_aux (mp_aux, _) as mpat) = match mp_aux with @@ -648,7 +604,6 @@ let rec doc_mpat (MP_aux (mp_aux, _) as mpat) = | MP_list pats -> string "[|" ^^ separate_map (comma ^^ space) doc_mpat pats ^^ string "|]" | _ -> string (string_of_mpat mpat) - let doc_mpexp (MPat_aux (mpexp, _)) = match mpexp with | MPat_pat mpat -> doc_mpat mpat @@ -657,38 +612,35 @@ let doc_mpexp (MPat_aux (mpexp, _)) = let doc_mapcl (MCL_aux (cl, _)) = match cl with | MCL_bidir (mpexp1, mpexp2) -> - let left = doc_mpexp mpexp1 in - let right = doc_mpexp mpexp2 in - separate space [left; string "<->"; right] + let left = doc_mpexp mpexp1 in + let right = doc_mpexp mpexp2 in + separate space [left; string "<->"; right] | MCL_forwards (mpexp, exp) -> - let left = doc_mpexp mpexp in - let right = doc_exp exp in - separate space [left; string "=>"; right] + let left = doc_mpexp mpexp in + let right = doc_exp exp in + separate space [left; string "=>"; right] | MCL_backwards (mpexp, exp) -> - let left = doc_mpexp mpexp in - let right = doc_exp exp in - separate space [left; string "<-"; right] - + let left = doc_mpexp mpexp in + let right = doc_exp exp in + separate space [left; string "<-"; right] let doc_mapdef (MD_aux (MD_mapping (id, _, mapcls), _)) = match mapcls with | [] -> failwith "Empty mapping" | _ -> - let sep = string "," ^^ hardline in - let clauses = separate_map sep doc_mapcl mapcls in - string "mapping" ^^ space ^^ doc_id id ^^ space ^^ string "=" ^^ space ^^ (surround 2 0 lbrace clauses rbrace) + let sep = string "," ^^ hardline in + let clauses = separate_map sep doc_mapcl mapcls in + string "mapping" ^^ space ^^ doc_id id ^^ space ^^ string "=" ^^ space ^^ surround 2 0 lbrace clauses rbrace -let doc_dec (DEC_aux (reg,_)) = +let doc_dec (DEC_aux (reg, _)) = match reg with - | DEC_reg (typ, id, opt_exp) -> - match opt_exp with - | None -> - separate space [string "register"; doc_id id; colon; doc_typ typ] - | Some exp -> - separate space [string "register"; doc_id id; colon; doc_typ typ; equals; doc_exp exp] + | DEC_reg (typ, id, opt_exp) -> ( + match opt_exp with + | None -> separate space [string "register"; doc_id id; colon; doc_typ typ] + | Some exp -> separate space [string "register"; doc_id id; colon; doc_typ typ; equals; doc_exp exp] + ) -let doc_field (typ, id) = - separate space [doc_id id; colon; doc_typ typ] +let doc_field (typ, id) = separate space [doc_id id; colon; doc_typ typ] let doc_union (Tu_aux (Tu_ty_id (typ, id), _)) = separate space [doc_id id; colon; doc_typ typ] @@ -700,98 +652,115 @@ let rec doc_index_range (BF_aux (ir, _)) = let doc_typ_arg_kind sep (A_aux (aux, _)) = match aux with - | A_nexp _ -> space ^^ string sep ^^ space ^^string "Int" + | A_nexp _ -> space ^^ string sep ^^ space ^^ string "Int" | A_bool _ -> space ^^ string sep ^^ space ^^ string "Bool" - | A_order _ -> space ^^ string sep ^^ space ^^ string "Order" + | A_order _ -> space ^^ string sep ^^ space ^^ string "Order" | A_typ _ -> empty -let doc_typdef (TD_aux(td,_)) = match td with - | TD_abbrev (id, typq, typ_arg) -> - begin - match doc_typquant typq with - | Some qdoc -> - doc_op equals (concat [string "type"; space; doc_id id; qdoc; doc_typ_arg_kind "->" typ_arg]) (doc_typ_arg typ_arg) - | None -> +let doc_typdef (TD_aux (td, _)) = + match td with + | TD_abbrev (id, typq, typ_arg) -> begin + match doc_typquant typq with + | Some qdoc -> + doc_op equals + (concat [string "type"; space; doc_id id; qdoc; doc_typ_arg_kind "->" typ_arg]) + (doc_typ_arg typ_arg) + | None -> doc_op equals (concat [string "type"; space; doc_id id; doc_typ_arg_kind ":" typ_arg]) (doc_typ_arg typ_arg) - end + end | TD_enum (id, ids, _) -> - separate space [string "enum"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_id ids) rbrace] + separate space + [string "enum"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_id ids) rbrace] | TD_record (id, TypQ_aux (TypQ_no_forall, _), fields, _) | TD_record (id, TypQ_aux (TypQ_tq [], _), fields, _) -> - separate space [string "struct"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace] + separate space + [ + string "struct"; + doc_id id; + equals; + surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace; + ] | TD_record (id, TypQ_aux (TypQ_tq qs, _), fields, _) -> - separate space [string "struct"; doc_id id; doc_param_quants qs; equals; - surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace] + separate space + [ + string "struct"; + doc_id id; + doc_param_quants qs; + equals; + surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace; + ] | TD_variant (id, TypQ_aux (TypQ_no_forall, _), unions, _) | TD_variant (id, TypQ_aux (TypQ_tq [], _), unions, _) -> - separate space [string "union"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace] + separate space + [ + string "union"; + doc_id id; + equals; + surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace; + ] | TD_variant (id, TypQ_aux (TypQ_tq qs, _), unions, _) -> - separate space [string "union"; doc_id id; doc_param_quants qs; equals; - surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace] + separate space + [ + string "union"; + doc_id id; + doc_param_quants qs; + equals; + surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_union unions) rbrace; + ] | TD_bitfield (id, typ, fields) -> - let doc_field (id, range) = separate space [doc_id id; colon; doc_index_range range] in - doc_op equals (separate space [string "bitfield"; doc_id id; colon; doc_typ typ]) - (surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace) + let doc_field (id, range) = separate space [doc_id id; colon; doc_index_range range] in + doc_op equals + (separate space [string "bitfield"; doc_id id; colon; doc_typ typ]) + (surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_field fields) rbrace) let doc_spec (VS_aux (v, annot)) = let doc_extern ext = match ext with | Some ext -> - let purity = if ext.pure then string "pure" ^^ space else string "monadic" ^^ space in - let docs = List.map (fun (backend, rep) -> string (backend ^ ":") ^^ space ^^ utf8string ("\"" ^ String.escaped rep ^ "\"")) ext.bindings in - equals ^^ space ^^ purity ^^ braces (separate (comma ^^ space) docs) - | None -> - empty + let purity = if ext.pure then string "pure" ^^ space else string "monadic" ^^ space in + let docs = + List.map + (fun (backend, rep) -> string (backend ^ ":") ^^ space ^^ utf8string ("\"" ^ String.escaped rep ^ "\"")) + ext.bindings + in + equals ^^ space ^^ purity ^^ braces (separate (comma ^^ space) docs) + | None -> empty in match v with - | VS_val_spec(ts,id,ext,is_cast) -> - string "val" ^^ space - ^^ (if is_cast then (string "cast" ^^ space) else empty) - ^^ doc_id id ^^ space - ^^ doc_extern ext - ^^ colon ^^ space - ^^ doc_typschm ts - -let doc_prec = function - | Infix -> string "infix" - | InfixL -> string "infixl" - | InfixR -> string "infixr" + | VS_val_spec (ts, id, ext, is_cast) -> + string "val" ^^ space + ^^ (if is_cast then string "cast" ^^ space else empty) + ^^ doc_id id ^^ space ^^ doc_extern ext ^^ colon ^^ space ^^ doc_typschm ts + +let doc_prec = function Infix -> string "infix" | InfixL -> string "infixl" | InfixR -> string "infixr" let doc_loop_measures l = - separate_map (comma ^^ break 1) - (function (Loop (l,e)) -> - string (match l with While -> "while" | Until -> "until") ^^ - space ^^ doc_exp e) + separate_map + (comma ^^ break 1) + (function Loop (l, e) -> string (match l with While -> "while" | Until -> "until") ^^ space ^^ doc_exp e) l let doc_scattered (SD_aux (sd_aux, _)) = match sd_aux with - | SD_function (_, _, id) -> - string "scattered" ^^ space ^^ string "function" ^^ space ^^ doc_id id - | SD_funcl funcl -> - string "function" ^^ space ^^ string "clause" ^^ space ^^ doc_funcl funcl - | SD_end id -> - string "end" ^^ space ^^ doc_id id - | SD_variant (id, TypQ_aux (TypQ_no_forall, _)) -> - string "scattered" ^^ space ^^ string "union" ^^ space ^^ doc_id id + | SD_function (_, _, id) -> string "scattered" ^^ space ^^ string "function" ^^ space ^^ doc_id id + | SD_funcl funcl -> string "function" ^^ space ^^ string "clause" ^^ space ^^ doc_funcl funcl + | SD_end id -> string "end" ^^ space ^^ doc_id id + | SD_variant (id, TypQ_aux (TypQ_no_forall, _)) -> string "scattered" ^^ space ^^ string "union" ^^ space ^^ doc_id id | SD_variant (id, TypQ_aux (TypQ_tq quants, _)) -> - string "scattered" ^^ space ^^ string "union" ^^ space ^^ doc_id id ^^ doc_param_quants quants - | SD_mapcl (id, mapcl) -> - separate space [string "mapping clause"; doc_id id; equals; doc_mapcl mapcl] - | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_none, _)) -> - separate space [string "scattered mapping"; doc_id id] + string "scattered" ^^ space ^^ string "union" ^^ space ^^ doc_id id ^^ doc_param_quants quants + | SD_mapcl (id, mapcl) -> separate space [string "mapping clause"; doc_id id; equals; doc_mapcl mapcl] + | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_none, _)) -> separate space [string "scattered mapping"; doc_id id] | SD_mapping (id, Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), _)) -> - separate space [string "scattered mapping"; doc_id id; colon; doc_binding (typq, typ)] - | SD_unioncl (id, tu) -> - separate space [string "union clause"; doc_id id; equals; doc_union tu] + separate space [string "scattered mapping"; doc_id id; colon; doc_binding (typq, typ)] + | SD_unioncl (id, tu) -> separate space [string "union clause"; doc_id id; equals; doc_union tu] let doc_filter = function | DEF_aux ((DEF_pragma ("file_start", _, _) | DEF_pragma ("file_end", _, _)), _) -> false | _ -> true - -let rec doc_def_no_hardline ?comment:(comment=false) (DEF_aux (aux, def_annot)) = - (match def_annot.doc_comment with - | Some str when comment -> - string "/*! " ^^ string str ^^ string " */" ^^ hardline - | _ -> empty) + +let rec doc_def_no_hardline ?(comment = false) (DEF_aux (aux, def_annot)) = + ( match def_annot.doc_comment with + | Some str when comment -> string "/*! " ^^ string str ^^ string " */" ^^ hardline + | _ -> empty + ) ^^ match aux with | DEF_default df -> doc_default df @@ -799,42 +768,42 @@ let rec doc_def_no_hardline ?comment:(comment=false) (DEF_aux (aux, def_annot)) | DEF_type t_def -> doc_typdef t_def | DEF_fundef f_def -> doc_fundef f_def | DEF_mapdef m_def -> doc_mapdef m_def - | DEF_outcome (OV_aux (OV_outcome (id, typschm, args), _), defs) -> - string "outcome" ^^ space ^^ doc_id id ^^ space ^^ colon ^^ space ^^ doc_typschm typschm - ^^ break 1 ^^ (string "with" ^//^ separate_map (comma ^^ break 1) doc_kopt_no_parens args) - ^^ (match defs with - | [] -> empty - | _ -> break 1 ^^ ((string "= {" ^//^ separate_map (hardline ^^ hardline) doc_def_no_hardline defs) ^/^ string "}")) - | DEF_instantiation (IN_aux (IN_id id, _), substs) -> - string "instantiation" ^^ space ^^ doc_id id - ^^ (match substs with - | [] -> empty - | _ -> (space ^^ string "with") ^//^ separate_map (comma ^^ break 1) doc_subst substs) - | DEF_impl funcl -> - string "impl" ^^ space ^^ doc_funcl funcl + | DEF_outcome (OV_aux (OV_outcome (id, typschm, args), _), defs) -> ( + string "outcome" ^^ space ^^ doc_id id ^^ space ^^ colon ^^ space ^^ doc_typschm typschm ^^ break 1 + ^^ (string "with" ^//^ separate_map (comma ^^ break 1) doc_kopt_no_parens args) + ^^ + match defs with + | [] -> empty + | _ -> break 1 ^^ (string "= {" ^//^ separate_map (hardline ^^ hardline) doc_def_no_hardline defs) ^/^ string "}" + ) + | DEF_instantiation (IN_aux (IN_id id, _), substs) -> ( + string "instantiation" ^^ space ^^ doc_id id + ^^ + match substs with + | [] -> empty + | _ -> (space ^^ string "with") ^//^ separate_map (comma ^^ break 1) doc_subst substs + ) + | DEF_impl funcl -> string "impl" ^^ space ^^ doc_funcl funcl | DEF_let lbind -> string "let" ^^ space ^^ doc_letbind lbind | DEF_internal_mutrec fundefs -> - (string "mutual {" ^//^ separate_map (hardline ^^ hardline) doc_fundef fundefs) - ^^ hardline ^^ string "}" + (string "mutual {" ^//^ separate_map (hardline ^^ hardline) doc_fundef fundefs) ^^ hardline ^^ string "}" | DEF_register dec -> doc_dec dec | DEF_scattered sdef -> doc_scattered sdef - | DEF_measure (id,pat,exp) -> - string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_pat pat ^^ - space ^^ equals ^/^ doc_exp exp - | DEF_loop_measures (id,measures) -> - string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_loop_measures measures - | DEF_pragma (pragma, arg, _) -> - string ("$" ^ pragma ^ " " ^ arg) + | DEF_measure (id, pat, exp) -> + string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_pat pat ^^ space ^^ equals ^/^ doc_exp exp + | DEF_loop_measures (id, measures) -> + string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_loop_measures measures + | DEF_pragma (pragma, arg, _) -> string ("$" ^ pragma ^ " " ^ arg) | DEF_fixity (prec, n, id) -> - fixities := Bindings.add id (prec, Big_int.to_int n) !fixities; - separate space [doc_prec prec; doc_int n; doc_id id] - | DEF_overload (Id_aux (_, l) as id, ids) -> - separate space [string "overload"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_id ids) rbrace] + fixities := Bindings.add id (prec, Big_int.to_int n) !fixities; + separate space [doc_prec prec; doc_int n; doc_id id] + | DEF_overload ((Id_aux (_, l) as id), ids) -> + separate space + [string "overload"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_id ids) rbrace] -and doc_def ?comment:(comment=false) def = group (doc_def_no_hardline ~comment:comment def ^^ hardline) +and doc_def ?(comment = false) def = group (doc_def_no_hardline ~comment def ^^ hardline) -let doc_ast ?comment:(comment=false) { defs; _ } = - separate_map hardline (doc_def ~comment:comment) (List.filter doc_filter defs) +let doc_ast ?(comment = false) { defs; _ } = separate_map hardline (doc_def ~comment) (List.filter doc_filter defs) (* This function is intended to reformat machine-generated Sail into something a bit more readable, it is not intended to be used as a @@ -842,47 +811,44 @@ let doc_ast ?comment:(comment=false) { defs; _ } = let reformat dir { defs; _ } = let file_stack = ref [] in - let pop () = match !file_stack with + let pop () = + match !file_stack with | [] -> Reporting.unreachable Parse_ast.Unknown __POS__ "Unbalanced file structure" | Some chan :: chans -> - close_out chan; - file_stack := chans - | None :: chans -> - file_stack := chans + close_out chan; + file_stack := chans + | None :: chans -> file_stack := chans in let push = function - | Some path -> - file_stack := Some (open_out path) :: !file_stack - | None -> - file_stack := None :: !file_stack + | Some path -> file_stack := Some (open_out path) :: !file_stack + | None -> file_stack := None :: !file_stack in let adjust_path path = Filename.concat (Filename.concat (Filename.dirname path) dir) (Filename.basename path) in - let output_def def = match !file_stack with - | Some chan :: _ -> - ToChannel.pretty 1. 120 chan (hardline ^^ doc_def ~comment:true def) + let output_def def = + match !file_stack with + | Some chan :: _ -> ToChannel.pretty 1. 120 chan (hardline ^^ doc_def ~comment:true def) | None :: _ -> () | [] -> Reporting.unreachable Parse_ast.Unknown __POS__ "No file for definition" in - let output_include path = match !file_stack with + let output_include path = + match !file_stack with | Some chan :: _ -> - if Filename.is_relative path then - Printf.fprintf chan "$include \"%s\"\n" path - else - Printf.fprintf chan "$include <%s>\n" (Filename.basename path) + if Filename.is_relative path then Printf.fprintf chan "$include \"%s\"\n" path + else Printf.fprintf chan "$include <%s>\n" (Filename.basename path) | None :: _ -> () | [] -> Reporting.unreachable Parse_ast.Unknown __POS__ "No file for include" in - + let format_def = function | DEF_aux (DEF_pragma ("file_start", path, _), _) -> push (Some (adjust_path path)) | DEF_aux (DEF_pragma ("file_end", _, _), _) -> pop () | DEF_aux (DEF_pragma ("include_start", path, _), _) -> - output_include path; - if Filename.is_relative path then push (Some (adjust_path path)) else push None + output_include path; + if Filename.is_relative path then push (Some (adjust_path path)) else push None | DEF_aux (DEF_pragma ("include_end", _, _), _) -> pop () | def -> output_def def in diff --git a/src/lib/profile.ml b/src/lib/profile.ml index 0c112a1c0..51b9bf7f4 100644 --- a/src/lib/profile.ml +++ b/src/lib/profile.ml @@ -67,42 +67,31 @@ let opt_profile = ref false -type profile = { - smt_calls : int; - smt_time : float - } +type profile = { smt_calls : int; smt_time : float } -let new_profile = { - smt_calls = 0; - smt_time = 0.0 - } +let new_profile = { smt_calls = 0; smt_time = 0.0 } let profile_stack = ref [] -let update_profile f = - match !profile_stack with - | [] -> () - | (p :: ps) -> - profile_stack := f p :: ps +let update_profile f = match !profile_stack with [] -> () | p :: ps -> profile_stack := f p :: ps let start_smt () = update_profile (fun p -> { p with smt_calls = p.smt_calls + 1 }); Sys.time () -let finish_smt t = - update_profile (fun p -> { p with smt_time = p.smt_time +. (Sys.time () -. t) }) +let finish_smt t = update_profile (fun p -> { p with smt_time = p.smt_time +. (Sys.time () -. t) }) let start () = profile_stack := new_profile :: !profile_stack; Sys.time () let finish msg t = - if !opt_profile then - begin match !profile_stack with + if !opt_profile then begin + match !profile_stack with | p :: ps -> - prerr_endline (Printf.sprintf "%s %s: %fs" Util.("Profiled" |> magenta |> clear) msg (Sys.time () -. t)); - prerr_endline (Printf.sprintf " SMT calls: %d, SMT time: %fs" p.smt_calls p.smt_time); - profile_stack := ps + prerr_endline (Printf.sprintf "%s %s: %fs" Util.("Profiled" |> magenta |> clear) msg (Sys.time () -. t)); + prerr_endline (Printf.sprintf " SMT calls: %d, SMT time: %fs" p.smt_calls p.smt_time); + profile_stack := ps | [] -> () - end + end else () diff --git a/src/lib/property.ml b/src/lib/property.ml index cbf5628d6..2905de56f 100644 --- a/src/lib/property.ml +++ b/src/lib/property.ml @@ -72,13 +72,11 @@ open Parser_combinators let find_properties { defs; _ } = let rec find_prop acc = function - | DEF_aux (DEF_pragma (("property" | "counterexample") as prop_type, command, l), _) :: defs -> - begin match Util.find_next (function DEF_aux (DEF_val _, _) -> true | _ -> false) defs with - | _, Some (DEF_aux (DEF_val vs, _), defs) -> - find_prop ((prop_type, command, l, vs) :: acc) defs - | _, _ -> - raise (Reporting.err_general l "Property is not attached to any function signature") - end + | DEF_aux (DEF_pragma ((("property" | "counterexample") as prop_type), command, l), _) :: defs -> begin + match Util.find_next (function DEF_aux (DEF_val _, _) -> true | _ -> false) defs with + | _, Some (DEF_aux (DEF_val vs, _), defs) -> find_prop ((prop_type, command, l, vs) :: acc) defs + | _, _ -> raise (Reporting.err_general l "Property is not attached to any function signature") + end | def :: defs -> find_prop acc defs | [] -> acc in @@ -90,72 +88,80 @@ let add_property_guards props ast = let open Type_check in let rec add_property_guards' acc = function | (DEF_aux (DEF_fundef (FD_aux (FD_function (r_opt, t_opt, funcls), fd_aux) as fdef), def_annot) as def) :: defs -> - begin match Bindings.find_opt (id_of_fundef fdef) props with - | Some (_, _, pragma_l, VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (quant, _), _), _, _, _), _)) -> - begin match quant_split quant with - | _, [] -> add_property_guards' (def :: acc) defs - | _, constraints -> - let add_constraints_to_funcl (FCL_aux (FCL_funcl (id, Pat_aux (pexp, pexp_aux)), fcl_aux)) = - let add_guard exp = - (* FIXME: Use an assert *) - let exp' = mk_exp (E_block [mk_exp (E_app (mk_id "sail_assume", [mk_exp (E_constraint (List.fold_left nc_and nc_true constraints))])); - strip_exp exp]) - in - try Type_check.check_exp (env_of exp) exp' (typ_of exp) with - | Type_error (_, l, err) -> - let msg = - "\nType error when generating guard for a property.\n\ - When generating guards we convert type quantifiers from the function signature\n\ - into runtime checks so it must be possible to reconstruct the quantifier from\n\ - the function arguments. For example:\n\n\ - \ - function f : forall 'n, 'n <= 100. (x: int('n)) -> bool\n\n\ - \ - would cause the runtime check x <= 100 to be added to the function body.\n\ - To fix this error, ensure that all quantifiers have corresponding function arguments.\n" + begin + match Bindings.find_opt (id_of_fundef fdef) props with + | Some (_, _, pragma_l, VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (quant, _), _), _, _, _), _)) -> begin + match quant_split quant with + | _, [] -> add_property_guards' (def :: acc) defs + | _, constraints -> + let add_constraints_to_funcl (FCL_aux (FCL_funcl (id, Pat_aux (pexp, pexp_aux)), fcl_aux)) = + let add_guard exp = + (* FIXME: Use an assert *) + let exp' = + mk_exp + (E_block + [ + mk_exp + (E_app + ( mk_id "sail_assume", + [mk_exp (E_constraint (List.fold_left nc_and nc_true constraints))] + ) + ); + strip_exp exp; + ] + ) in - raise (Reporting.err_typ pragma_l (Type_error.string_of_type_error err ^ msg)) - in - let mk_funcl p = FCL_aux (FCL_funcl (id, Pat_aux (p, pexp_aux)), fcl_aux) in - match pexp with - | Pat_exp (pat, exp) -> - mk_funcl (Pat_exp (pat, add_guard exp)) - | Pat_when (pat, guard, exp) -> - mk_funcl (Pat_when (pat, guard, add_guard exp)) - in - - let funcls = List.map add_constraints_to_funcl funcls in - let fdef = FD_aux (FD_function (r_opt, t_opt, funcls), fd_aux) in - - add_property_guards' (DEF_aux (DEF_fundef fdef, def_annot) :: acc) defs - end - | None -> add_property_guards' (def :: acc) defs - end + try Type_check.check_exp (env_of exp) exp' (typ_of exp) + with Type_error (_, l, err) -> + let msg = + "\n\ + Type error when generating guard for a property.\n\ + When generating guards we convert type quantifiers from the function signature\n\ + into runtime checks so it must be possible to reconstruct the quantifier from\n\ + the function arguments. For example:\n\n\ + function f : forall 'n, 'n <= 100. (x: int('n)) -> bool\n\n\ + would cause the runtime check x <= 100 to be added to the function body.\n\ + To fix this error, ensure that all quantifiers have corresponding function arguments.\n" + in + raise (Reporting.err_typ pragma_l (Type_error.string_of_type_error err ^ msg)) + in + let mk_funcl p = FCL_aux (FCL_funcl (id, Pat_aux (p, pexp_aux)), fcl_aux) in + match pexp with + | Pat_exp (pat, exp) -> mk_funcl (Pat_exp (pat, add_guard exp)) + | Pat_when (pat, guard, exp) -> mk_funcl (Pat_when (pat, guard, add_guard exp)) + in + + let funcls = List.map add_constraints_to_funcl funcls in + let fdef = FD_aux (FD_function (r_opt, t_opt, funcls), fd_aux) in + add_property_guards' (DEF_aux (DEF_fundef fdef, def_annot) :: acc) defs + end + | None -> add_property_guards' (def :: acc) defs + end | def :: defs -> add_property_guards' (def :: acc) defs | [] -> List.rev acc in { ast with defs = add_property_guards' [] ast.defs } -let rewrite defs = - add_property_guards (find_properties defs) defs +let rewrite defs = add_property_guards (find_properties defs) defs type event = Overflow | Assertion | Assumption | Match | Return type query = - | Q_all of event (* All events of type are true *) - | Q_exist of event (* Some event of type is true *) - | Q_not of query - | Q_and of query list - | Q_or of query list + | Q_all of event (* All events of type are true *) + | Q_exist of event (* Some event of type is true *) + | Q_not of query + | Q_and of query list + | Q_or of query list let default_query = - Q_or [Q_and [Q_not (Q_exist Assertion); Q_all Return; Q_not (Q_exist Match)]; Q_exist Overflow; Q_not (Q_all Assumption)] + Q_or + [Q_and [Q_not (Q_exist Assertion); Q_all Return; Q_not (Q_exist Match)]; Q_exist Overflow; Q_not (Q_all Assumption)] module Event = struct type t = event let compare e1 e2 = - match e1, e2 with + match (e1, e2) with | Overflow, Overflow -> 0 | Assertion, Assertion -> 0 | Assumption, Assumption -> 0 @@ -190,57 +196,74 @@ let parse_query = let bar = token (function Str.Delim "|" -> Some () | _ -> None) in let lparen = token (function Str.Delim "(" -> Some () | _ -> None) in let rparen = token (function Str.Delim ")" -> Some () | _ -> None) in - let quant = token (function Str.Text ("A" | "all") -> Some (fun x -> Q_all x) - | Str.Text ("E" | "exist") -> Some (fun x -> Q_exist x) - | _ -> None) in - let event = token (function Str.Text "overflow" -> Some Overflow - | Str.Text "assertion" -> Some Assertion - | Str.Text "assumption" -> Some Assumption - | Str.Text "match_failure" -> Some Match - | Str.Text "return"-> Some Return - | _ -> None) in + let quant = + token (function + | Str.Text ("A" | "all") -> Some (fun x -> Q_all x) + | Str.Text ("E" | "exist") -> Some (fun x -> Q_exist x) + | _ -> None + ) + in + let event = + token (function + | Str.Text "overflow" -> Some Overflow + | Str.Text "assertion" -> Some Assertion + | Str.Text "assumption" -> Some Assumption + | Str.Text "match_failure" -> Some Match + | Str.Text "return" -> Some Return + | _ -> None + ) + in let tilde = token (function Str.Delim "~" -> Some () | _ -> None) in let rec exp0 () = - pchoose (exp1 () >>= fun x -> bar >>= fun _ -> exp0 () >>= fun y -> preturn (Q_or [x; y])) - (exp1 ()) + pchoose + ( exp1 () >>= fun x -> + bar >>= fun _ -> + exp0 () >>= fun y -> preturn (Q_or [x; y]) + ) + (exp1 ()) and exp1 () = - pchoose (exp2 () >>= fun x -> amp >>= fun _ -> exp1 () >>= fun y -> preturn (Q_and [x; y])) - (exp2 ()) + pchoose + ( exp2 () >>= fun x -> + amp >>= fun _ -> + exp1 () >>= fun y -> preturn (Q_and [x; y]) + ) + (exp2 ()) and exp2 () = - pchoose (tilde >>= fun _ -> exp3 () >>= fun x -> preturn (Q_not x)) - (exp3 ()) + pchoose + ( tilde >>= fun _ -> + exp3 () >>= fun x -> preturn (Q_not x) + ) + (exp3 ()) and exp3 () = - pchoose (lparen >>= fun _ -> exp0 () >>= fun x -> rparen >>= fun _ -> preturn x) - (quant >>= fun f -> event >>= fun ev -> preturn (f ev)) + pchoose + ( lparen >>= fun _ -> + exp0 () >>= fun x -> + rparen >>= fun _ -> preturn x + ) + ( quant >>= fun f -> + event >>= fun ev -> preturn (f ev) + ) in parse (exp0 ()) "[ \n\t]+\\|(\\|)\\|&\\||\\|~" -type pragma = { - query : query; - litmus : string list; - } +type pragma = { query : query; litmus : string list } -let default_pragma = { - query = default_query; - litmus = []; - } +let default_pragma = { query = default_query; litmus = [] } let parse_pragma l input = let key = Str.regexp ":[a-z]+" in let tokens = Str.full_split key input in let rec process_toks pragma = function - | Str.Delim ":query" :: Str.Text query :: rest -> - begin match parse_query query with - | Some q -> process_toks { pragma with query = q } rest - | None -> - raise (Reporting.err_general l ("Could not parse query " ^ String.trim query)) - end + | Str.Delim ":query" :: Str.Text query :: rest -> begin + match parse_query query with + | Some q -> process_toks { pragma with query = q } rest + | None -> raise (Reporting.err_general l ("Could not parse query " ^ String.trim query)) + end | Str.Delim ":litmus" :: rest -> - let args, rest = Util.take_drop (function Str.Text _ -> true | _ -> false) rest in - process_toks { pragma with litmus = List.map (function Str.Text t -> t | _ -> assert false) args } rest + let args, rest = Util.take_drop (function Str.Text _ -> true | _ -> false) rest in + process_toks { pragma with litmus = List.map (function Str.Text t -> t | _ -> assert false) args } rest | [] -> pragma - | _ -> - raise (Reporting.err_general l "Could not parse pragma") + | _ -> raise (Reporting.err_general l "Could not parse pragma") in process_toks default_pragma tokens diff --git a/src/lib/property.mli b/src/lib/property.mli index b6430ea72..9b40fc2ce 100644 --- a/src/lib/property.mli +++ b/src/lib/property.mli @@ -115,18 +115,10 @@ module Event : sig val compare : event -> event -> int end -type query = - | Q_all of event - | Q_exist of event - | Q_not of query - | Q_and of query list - | Q_or of query list +type query = Q_all of event | Q_exist of event | Q_not of query | Q_and of query list | Q_or of query list val default_query : query -type pragma = { - query : query; - litmus : string list; - } +type pragma = { query : query; litmus : string list } val parse_pragma : Parse_ast.l -> string -> pragma diff --git a/src/lib/reporting.ml b/src/lib/reporting.ml index c95abb454..eb1b752c0 100644 --- a/src/lib/reporting.ml +++ b/src/lib/reporting.ml @@ -65,7 +65,6 @@ (* SUCH DAMAGE. *) (****************************************************************************) - (**************************************************************************) (* Lem *) (* *) @@ -117,18 +116,15 @@ let opt_backtrace_length = ref 10 type pos_or_loc = Loc of Parse_ast.l | Pos of Lexing.position -let fix_endline str = - if str.[String.length str - 1] = '\n' then - String.sub str 0 (String.length str - 1) - else - str - +let fix_endline str = if str.[String.length str - 1] = '\n' then String.sub str 0 (String.length str - 1) else str + let print_err_internal p_l m1 m2 = let open Error_format in prerr_endline (m1 ^ ":"); - begin match p_l with - | Loc l -> format_message (Location ("", None, l, Line (fix_endline m2))) err_formatter - | Pos p -> format_message (Location ("", None, Parse_ast.Range (p, p), Line (fix_endline m2))) err_formatter + begin + match p_l with + | Loc l -> format_message (Location ("", None, l, Line (fix_endline m2))) err_formatter + | Pos p -> format_message (Location ("", None, Parse_ast.Range (p, p), Line (fix_endline m2))) err_formatter end let loc_to_string l = @@ -136,16 +132,12 @@ let loc_to_string l = let b = Buffer.create 160 in format_message (Location ("", None, l, Line "")) (buffer_formatter b); Buffer.contents b - + let rec simp_loc = function | Parse_ast.Unknown -> None | Parse_ast.Unique (_, l) -> simp_loc l | Parse_ast.Generated l -> simp_loc l - | Parse_ast.Hint (_, l1, l2) -> - begin match simp_loc l1 with - | None -> simp_loc l2 - | pos -> pos - end + | Parse_ast.Hint (_, l1, l2) -> begin match simp_loc l1 with None -> simp_loc l2 | pos -> pos end | Parse_ast.Range (p1, p2) -> Some (p1, p2) let rec loc_file = function @@ -161,16 +153,15 @@ let rec start_loc = function | Parse_ast.Generated l -> Parse_ast.Generated (start_loc l) | Parse_ast.Hint (hint, l1, l2) -> Parse_ast.Hint (hint, start_loc l1, start_loc l2) | Parse_ast.Range (p1, _) -> Parse_ast.Range (p1, p1) - + let short_loc_to_string l = match simp_loc l with | None -> "unknown location" | Some (p1, p2) -> - Printf.sprintf "%s:%d.%d-%d.%d" - p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum (p2.pos_cnum - p2.pos_bol) - -let print_err l m1 m2 = - print_err_internal (Loc l) m1 m2 + Printf.sprintf "%s:%d.%d-%d.%d" p1.pos_fname p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum + (p2.pos_cnum - p2.pos_bol) + +let print_err l m1 m2 = print_err_internal (Loc l) m1 m2 type error = | Err_general of Parse_ast.l * string @@ -186,10 +177,12 @@ let issues = "\nPlease report this as an issue on GitHub at https://github.com/r let dest_err ?(interactive = false) = function | Err_general (l, m) -> (Util.("Error" |> yellow |> clear), Loc l, m) | Err_unreachable (l, (file, line, _, _), backtrace, m) -> - if interactive then - ("Error", Loc l, m) - else - (Printf.sprintf "Internal error: Unreachable code (at \"%s\" line %d)" file line, Loc l, m ^ "\n\n" ^ Printexc.raw_backtrace_to_string backtrace ^ issues) + if interactive then ("Error", Loc l, m) + else + ( Printf.sprintf "Internal error: Unreachable code (at \"%s\" line %d)" file line, + Loc l, + m ^ "\n\n" ^ Printexc.raw_backtrace_to_string backtrace ^ issues + ) | Err_todo (l, m) -> ("Todo", Loc l, m) | Err_syntax (p, m) -> (Util.("Syntax error" |> yellow |> clear), Pos p, m) | Err_syntax_loc (l, m) -> (Util.("Syntax error" |> yellow |> clear), Loc l, m) @@ -209,8 +202,7 @@ let err_syntax p m = Fatal_error (Err_syntax (p, m)) let err_syntax_loc l m = Fatal_error (Err_syntax_loc (l, m)) let err_lex p m = Fatal_error (Err_lex (p, m)) -let unreachable l pos msg = - raise (err_unreachable l pos msg) +let unreachable l pos msg = raise (err_unreachable l pos msg) let forbid_errors ocaml_pos f x = try f x with @@ -220,29 +212,25 @@ let forbid_errors ocaml_pos f x = | Fatal_error (Err_syntax_loc (l, m)) -> raise (err_unreachable l ocaml_pos m) | Fatal_error (Err_lex (p, m)) -> raise (err_unreachable (Range (p, p)) ocaml_pos m) | Fatal_error (Err_type (l, _, m)) -> raise (err_unreachable l ocaml_pos m) - + let print_error ?(interactive = false) e = - let (m1, pos_l, m2) = dest_err ~interactive:interactive e in + let m1, pos_l, m2 = dest_err ~interactive e in print_err_internal pos_l m1 m2 (* Warnings *) -module StringSet = Set.Make(String) +module StringSet = Set.Make (String) let pos_compare p1 p2 = let open Lexing in match String.compare p1.pos_fname p2.pos_fname with - | 0 -> - begin match compare p1.pos_lnum p2.pos_lnum with - | 0 -> - begin match compare p1.pos_bol p2.pos_bol with - | 0 -> compare p1.pos_cnum p2.pos_cnum - | n -> n - end - | n -> n - end + | 0 -> begin + match compare p1.pos_lnum p2.pos_lnum with + | 0 -> begin match compare p1.pos_bol p2.pos_bol with 0 -> compare p1.pos_cnum p2.pos_cnum | n -> n end + | n -> n + end | n -> n - + module Range = struct type t = Lexing.position * Lexing.position let compare (p1, p2) (p3, p4) = @@ -250,62 +238,65 @@ module Range = struct if c = 0 then pos_compare p2 p4 else c end -module RangeMap = Map.Make(Range) - +module RangeMap = Map.Make (Range) + let ignored_files = ref StringSet.empty -let suppress_warnings_for_file f = - ignored_files := StringSet.add f !ignored_files +let suppress_warnings_for_file f = ignored_files := StringSet.add f !ignored_files let seen_warnings = ref RangeMap.empty let once_from_warnings = ref StringSet.empty let warn ?once_from short_str l explanation = - let already_shown = match once_from with + let already_shown = + match once_from with | Some (file, lnum, cnum, enum) -> - let key = Printf.sprintf "%d:%d:%d:%s" lnum cnum enum file in - if StringSet.mem key !once_from_warnings then ( - true - ) else ( - once_from_warnings := StringSet.add key !once_from_warnings; - false - ) + let key = Printf.sprintf "%d:%d:%d:%s" lnum cnum enum file in + if StringSet.mem key !once_from_warnings then true + else ( + once_from_warnings := StringSet.add key !once_from_warnings; + false + ) | None -> false in if !opt_warnings && not already_shown then ( match simp_loc l with | Some (p1, p2) when not (StringSet.mem p1.pos_fname !ignored_files) -> - let shorts = RangeMap.find_opt (p1, p2) !seen_warnings |> Option.value ~default:[] in - if not (List.exists (fun s -> s = short_str) shorts) then ( - prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " - ^ short_str ^ (if short_str <> "" then " " else "") ^ loc_to_string l ^ explanation ^ "\n"); - seen_warnings := RangeMap.add (p1, p2) (short_str :: shorts) !seen_warnings - ) + let shorts = RangeMap.find_opt (p1, p2) !seen_warnings |> Option.value ~default:[] in + if not (List.exists (fun s -> s = short_str) shorts) then ( + prerr_endline + (Util.("Warning" |> yellow |> clear) + ^ ": " ^ short_str + ^ (if short_str <> "" then " " else "") + ^ loc_to_string l ^ explanation ^ "\n" + ); + seen_warnings := RangeMap.add (p1, p2) (short_str :: shorts) !seen_warnings + ) | _ -> prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " ^ short_str ^ "\n" ^ explanation ^ "\n") ) let format_warn ?once_from short_str l explanation = - let already_shown = match once_from with + let already_shown = + match once_from with | Some (file, lnum, cnum, enum) -> - let key = Printf.sprintf "%d:%d:%d:%s" lnum cnum enum file in - if StringSet.mem key !once_from_warnings then ( - true - ) else ( - once_from_warnings := StringSet.add key !once_from_warnings; - false - ) + let key = Printf.sprintf "%d:%d:%d:%s" lnum cnum enum file in + if StringSet.mem key !once_from_warnings then true + else ( + once_from_warnings := StringSet.add key !once_from_warnings; + false + ) | None -> false in if !opt_warnings && not already_shown then ( match simp_loc l with | Some (p1, p2) when not (StringSet.mem p1.pos_fname !ignored_files) -> - let shorts = RangeMap.find_opt (p1, p2) !seen_warnings |> Option.value ~default:[] in - if not (List.exists (fun s -> s = short_str) shorts) then ( - let open Error_format in - prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " ^ short_str); - format_message (Location ("", None, l, explanation)) err_formatter; - seen_warnings := RangeMap.add (p1, p2) (short_str :: shorts) !seen_warnings - ) + let shorts = RangeMap.find_opt (p1, p2) !seen_warnings |> Option.value ~default:[] in + if not (List.exists (fun s -> s = short_str) shorts) then ( + let open Error_format in + prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " ^ short_str); + format_message (Location ("", None, l, explanation)) err_formatter; + seen_warnings := RangeMap.add (p1, p2) (short_str :: shorts) !seen_warnings + ) | _ -> prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " ^ short_str ^ "\n") ) @@ -315,10 +306,11 @@ let get_sail_dir default_sail_dir = match Sys.getenv_opt "SAIL_DIR" with | Some path -> path | None -> - if Sys.file_exists default_sail_dir then - default_sail_dir - else - raise (err_general Parse_ast.Unknown - ("Sail share directory " - ^ default_sail_dir - ^ " does not exist. Make sure Sail is installed correctly, or try setting the SAIL_DIR environment variable")) + if Sys.file_exists default_sail_dir then default_sail_dir + else + raise + (err_general Parse_ast.Unknown + ("Sail share directory " ^ default_sail_dir + ^ " does not exist. Make sure Sail is installed correctly, or try setting the SAIL_DIR environment variable" + ) + ) diff --git a/src/lib/reporting.mli b/src/lib/reporting.mli index 57de9b00d..2b46a36e9 100644 --- a/src/lib/reporting.mli +++ b/src/lib/reporting.mli @@ -90,7 +90,7 @@ val loc_to_string : Parse_ast.l -> string (** [loc_file] returns the file for a location *) val loc_file : Parse_ast.l -> string option - + (** Reduce a location to a pair of positions if possible *) val simp_loc : Ast.l -> (Lexing.position * Lexing.position) option @@ -102,7 +102,7 @@ val print_err : Parse_ast.l -> string -> string -> unit (** Reduce all spans in a location to just their starting characters *) val start_loc : Parse_ast.l -> Parse_ast.l - + (** {2 The error type } *) (** Errors stop execution and print a message; they typically have a location and message. @@ -110,30 +110,21 @@ val start_loc : Parse_ast.l -> Parse_ast.l Note that all these errors are intended to be fatal, so should not be caught other than by the top-level function. *) type error = private - | Err_general of Parse_ast.l * string - (** General errors, used for multi purpose. If you are unsure, use this one. *) - + | Err_general of Parse_ast.l * string (** General errors, used for multi purpose. If you are unsure, use this one. *) | Err_unreachable of Parse_ast.l * (string * int * int * int) * Printexc.raw_backtrace * string - (** Unreachable errors should never be thrown. They represent an internal Sail error. *) - - | Err_todo of Parse_ast.l * string - (** [Err_todo] indicates that some feature is unimplemented. *) - + (** Unreachable errors should never be thrown. They represent an internal Sail error. *) + | Err_todo of Parse_ast.l * string (** [Err_todo] indicates that some feature is unimplemented. *) | Err_syntax of Lexing.position * string | Err_syntax_loc of Parse_ast.l * string - (** [Err_syntax] and [Err_syntax_loc] are used for syntax errors by the parser. *) - - | Err_lex of Lexing.position * string - (** [Err_lex] is a lexical error generated by the lexer. *) - - | Err_type of Parse_ast.l * string option * string - (** [Err_type] is a type error. See the Type_error module. *) + (** [Err_syntax] and [Err_syntax_loc] are used for syntax errors by the parser. *) + | Err_lex of Lexing.position * string (** [Err_lex] is a lexical error generated by the lexer. *) + | Err_type of Parse_ast.l * string option * string (** [Err_type] is a type error. See the Type_error module. *) exception Fatal_error of error val err_todo : Parse_ast.l -> string -> exn val err_general : Parse_ast.l -> string -> exn -val err_unreachable : Parse_ast.l -> (string * int * int * int) -> string -> exn +val err_unreachable : Parse_ast.l -> string * int * int * int -> string -> exn val err_typ : ?hint:string -> Parse_ast.l -> string -> exn val err_syntax : Lexing.position -> string -> exn val err_syntax_loc : Parse_ast.l -> string -> exn @@ -142,7 +133,7 @@ val err_lex : Lexing.position -> string -> exn (** Raise an unreachable exception. This should always be used over an assert false or a generic OCaml failwith exception when appropriate *) -val unreachable : Parse_ast.l -> (string * int * int * int) -> string -> 'a +val unreachable : Parse_ast.l -> string * int * int * int -> string -> 'a (** Print an error to stdout. @@ -152,20 +143,19 @@ it's possible to excute code paths from the interactive mode that would otherwis val print_error : ?interactive:bool -> error -> unit (** This function transforms all errors raised by the provided function into internal [Err_unreachable] errors *) -val forbid_errors : (string * int * int * int) -> ('a -> 'b) -> 'a -> 'b - +val forbid_errors : string * int * int * int -> ('a -> 'b) -> 'a -> 'b + (** Print a warning message. The first string is printed before the location, the second after. *) -val warn : ?once_from:(string * int * int * int) -> string -> Parse_ast.l -> string -> unit +val warn : ?once_from:string * int * int * int -> string -> Parse_ast.l -> string -> unit -val format_warn : ?once_from:(string * int * int * int) -> string -> Parse_ast.l -> Error_format.message -> unit +val format_warn : ?once_from:string * int * int * int -> string -> Parse_ast.l -> Error_format.message -> unit (** Print a simple one-line warning without a location. *) -val simple_warn: string -> unit - +val simple_warn : string -> unit + (** Will suppress all warnings for a given (Sail) file name. Used by $suppress_warnings directive in process_file.ml *) val suppress_warnings_for_file : string -> unit val get_sail_dir : string -> string - diff --git a/src/lib/rewriter.ml b/src/lib/rewriter.ml index e9a465347..605415208 100644 --- a/src/lib/rewriter.ml +++ b/src/lib/rewriter.ml @@ -72,19 +72,20 @@ open Ast_util open Type_check type 'a rewriters = { - rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; - rewrite_lexp : 'a rewriters -> 'a lexp -> 'a lexp; - rewrite_pat : 'a rewriters -> 'a pat -> 'a pat; - rewrite_let : 'a rewriters -> 'a letbind -> 'a letbind; - rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; - rewrite_def : 'a rewriters -> 'a def -> 'a def; - rewrite_ast : 'a rewriters -> 'a ast -> 'a ast; - } + rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; + rewrite_lexp : 'a rewriters -> 'a lexp -> 'a lexp; + rewrite_pat : 'a rewriters -> 'a pat -> 'a pat; + rewrite_let : 'a rewriters -> 'a letbind -> 'a letbind; + rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; + rewrite_def : 'a rewriters -> 'a def -> 'a def; + rewrite_ast : 'a rewriters -> 'a ast -> 'a ast; +} let lookup_generated_kid env kid = let match_kid_nc kid = function | NC_aux (NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var kid2, _)), _) - when Kid.compare kid kid2 = 0 && not (is_kid_generated kid1) -> kid1 + when Kid.compare kid kid2 = 0 && not (is_kid_generated kid1) -> + kid1 | _ -> kid in List.fold_left match_kid_nc kid (Env.get_constraints env) @@ -94,29 +95,26 @@ let generated_kids typ = KidSet.filter is_kid_generated (tyvars_of_typ typ) let rec is_src_typ typ = match typ with | Typ_aux (Typ_tuple typs, l) -> List.for_all is_src_typ typs - | _ -> - match destruct_exist typ with - | Some (kopts, nc, typ') -> - let declared_kids = KidSet.of_list (List.map kopt_kid kopts) in - let unused_kids = KidSet.diff declared_kids (tyvars_of_typ typ') in - KidSet.is_empty unused_kids && KidSet.is_empty (generated_kids typ) - | None -> KidSet.is_empty (generated_kids typ) + | _ -> ( + match destruct_exist typ with + | Some (kopts, nc, typ') -> + let declared_kids = KidSet.of_list (List.map kopt_kid kopts) in + let unused_kids = KidSet.diff declared_kids (tyvars_of_typ typ') in + KidSet.is_empty unused_kids && KidSet.is_empty (generated_kids typ) + | None -> KidSet.is_empty (generated_kids typ) + ) let resolve_generated_kids env typ = let subst_kid kid typ = subst_kid typ_subst kid (lookup_generated_kid env kid) typ in KidSet.fold subst_kid (generated_kids typ) typ -let rec remove_p_typ = function - | P_aux (P_typ (typ, pat), _) -> remove_p_typ pat - | pat -> pat +let rec remove_p_typ = function P_aux (P_typ (typ, pat), _) -> remove_p_typ pat | pat -> pat let add_p_typ env typ (P_aux (paux, annot) as pat) = let typ' = resolve_generated_kids env typ in if is_src_typ typ' then P_aux (P_typ (typ', remove_p_typ pat), annot) else pat -let rec remove_e_typ = function - | E_aux (E_typ (_, exp), _) -> remove_e_typ exp - | exp -> exp +let rec remove_e_typ = function E_aux (E_typ (_, exp), _) -> remove_e_typ exp | exp -> exp let add_e_typ env typ (E_aux (eaux, annot) as exp) = let typ' = resolve_generated_kids env typ in @@ -129,138 +127,148 @@ let add_typs_let env ltyp rtyp exp = in match exp with | E_aux (E_let (LB_aux (LB_val (pat, lhs), lba), rhs), a) -> - let (pat', lhs', rhs') = aux pat lhs rhs in - E_aux (E_let (LB_aux (LB_val (pat', lhs'), lba), rhs'), a) + let pat', lhs', rhs' = aux pat lhs rhs in + E_aux (E_let (LB_aux (LB_val (pat', lhs'), lba), rhs'), a) | E_aux (E_internal_plet (pat, lhs, rhs), a) -> - let (pat', lhs', rhs') = aux pat lhs rhs in - E_aux (E_internal_plet (pat', lhs', rhs'), a) + let pat', lhs', rhs' = aux pat lhs rhs in + E_aux (E_internal_plet (pat', lhs', rhs'), a) | _ -> exp let rewrite_pexp rewriters = let rewrite = rewriters.rewrite_exp rewriters in function - | (Pat_aux (Pat_exp(p, e), pannot)) -> - Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters p, rewrite e), pannot) - | (Pat_aux (Pat_when(p, e, e'), pannot)) -> - Pat_aux (Pat_when(rewriters.rewrite_pat rewriters p, rewrite e, rewrite e'), pannot) + | Pat_aux (Pat_exp (p, e), pannot) -> Pat_aux (Pat_exp (rewriters.rewrite_pat rewriters p, rewrite e), pannot) + | Pat_aux (Pat_when (p, e, e'), pannot) -> + Pat_aux (Pat_when (rewriters.rewrite_pat rewriters p, rewrite e, rewrite e'), pannot) -let rewrite_pat rewriters (P_aux (pat,(l,annot))) = - let rewrap p = P_aux (p,(l,annot)) in +let rewrite_pat rewriters (P_aux (pat, (l, annot))) = + let rewrap p = P_aux (p, (l, annot)) in let rewrite = rewriters.rewrite_pat rewriters in match pat with | P_lit _ | P_wild | P_id _ | P_var _ | P_vector_subrange _ -> rewrap pat - | P_or(pat1, pat2) -> rewrap (P_or(rewrite pat1, rewrite pat2)) - | P_not(pat) -> rewrap (P_not(rewrite pat)) - | P_as(pat,id) -> rewrap (P_as(rewrite pat, id)) - | P_typ(typ,pat) -> rewrap (P_typ(typ, rewrite pat)) - | P_app(id ,pats) -> rewrap (P_app(id, List.map rewrite pats)) - | P_vector pats -> rewrap (P_vector(List.map rewrite pats)) + | P_or (pat1, pat2) -> rewrap (P_or (rewrite pat1, rewrite pat2)) + | P_not pat -> rewrap (P_not (rewrite pat)) + | P_as (pat, id) -> rewrap (P_as (rewrite pat, id)) + | P_typ (typ, pat) -> rewrap (P_typ (typ, rewrite pat)) + | P_app (id, pats) -> rewrap (P_app (id, List.map rewrite pats)) + | P_vector pats -> rewrap (P_vector (List.map rewrite pats)) | P_vector_concat pats -> rewrap (P_vector_concat (List.map rewrite pats)) | P_tuple pats -> rewrap (P_tuple (List.map rewrite pats)) | P_list pats -> rewrap (P_list (List.map rewrite pats)) | P_cons (pat1, pat2) -> rewrap (P_cons (rewrite pat1, rewrite pat2)) | P_string_append pats -> rewrap (P_string_append (List.map rewrite pats)) -let rewrite_exp rewriters (E_aux (exp,(l,annot))) = - let rewrap e = E_aux (e,(l,annot)) in +let rewrite_exp rewriters (E_aux (exp, (l, annot))) = + let rewrap e = E_aux (e, (l, annot)) in let rewrite = rewriters.rewrite_exp rewriters in match exp with | E_block exps -> rewrap (E_block (List.map rewrite exps)) - | E_id _ | E_lit _ -> rewrap exp + | E_id _ | E_lit _ -> rewrap exp | E_typ (typ, exp) -> rewrap (E_typ (typ, rewrite exp)) - | E_app (id,exps) -> rewrap (E_app (id,List.map rewrite exps)) - | E_app_infix(el,id,er) -> rewrap (E_app_infix(rewrite el,id,rewrite er)) + | E_app (id, exps) -> rewrap (E_app (id, List.map rewrite exps)) + | E_app_infix (el, id, er) -> rewrap (E_app_infix (rewrite el, id, rewrite er)) | E_tuple exps -> rewrap (E_tuple (List.map rewrite exps)) - | E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e)) - | E_for (id, e1, e2, e3, o, body) -> - rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body)) + | E_if (c, t, e) -> rewrap (E_if (rewrite c, rewrite t, rewrite e)) + | E_for (id, e1, e2, e3, o, body) -> rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body)) | E_loop (loop, m, e1, e2) -> - let m = match m with - | Measure_aux (Measure_none,_) -> m - | Measure_aux (Measure_some exp,l) -> Measure_aux (Measure_some (rewrite exp),l) - in - rewrap (E_loop (loop, m, rewrite e1, rewrite e2)) + let m = + match m with + | Measure_aux (Measure_none, _) -> m + | Measure_aux (Measure_some exp, l) -> Measure_aux (Measure_some (rewrite exp), l) + in + rewrap (E_loop (loop, m, rewrite e1, rewrite e2)) | E_vector exps -> rewrap (E_vector (List.map rewrite exps)) - | E_vector_access (vec,index) -> rewrap (E_vector_access (rewrite vec,rewrite index)) - | E_vector_subrange (vec,i1,i2) -> - rewrap (E_vector_subrange (rewrite vec,rewrite i1,rewrite i2)) - | E_vector_update (vec,index,new_v) -> - rewrap (E_vector_update (rewrite vec,rewrite index,rewrite new_v)) - | E_vector_update_subrange (vec,i1,i2,new_v) -> - rewrap (E_vector_update_subrange (rewrite vec,rewrite i1,rewrite i2,rewrite new_v)) - | E_vector_append (v1,v2) -> rewrap (E_vector_append (rewrite v1,rewrite v2)) + | E_vector_access (vec, index) -> rewrap (E_vector_access (rewrite vec, rewrite index)) + | E_vector_subrange (vec, i1, i2) -> rewrap (E_vector_subrange (rewrite vec, rewrite i1, rewrite i2)) + | E_vector_update (vec, index, new_v) -> rewrap (E_vector_update (rewrite vec, rewrite index, rewrite new_v)) + | E_vector_update_subrange (vec, i1, i2, new_v) -> + rewrap (E_vector_update_subrange (rewrite vec, rewrite i1, rewrite i2, rewrite new_v)) + | E_vector_append (v1, v2) -> rewrap (E_vector_append (rewrite v1, rewrite v2)) | E_list exps -> rewrap (E_list (List.map rewrite exps)) - | E_cons(h,t) -> rewrap (E_cons (rewrite h,rewrite t)) + | E_cons (h, t) -> rewrap (E_cons (rewrite h, rewrite t)) | E_struct fexps -> - rewrap (E_struct - (List.map (fun (FE_aux(FE_fexp(id,e),fannot)) -> - FE_aux(FE_fexp(id,rewrite e),fannot)) fexps)) + rewrap + (E_struct (List.map (fun (FE_aux (FE_fexp (id, e), fannot)) -> FE_aux (FE_fexp (id, rewrite e), fannot)) fexps)) | E_struct_update (re, fexps) -> - rewrap (E_struct_update ((rewrite re), - (List.map (fun (FE_aux(FE_fexp(id,e),fannot)) -> - FE_aux(FE_fexp(id,rewrite e),fannot)) fexps))) - | E_field(exp,id) -> rewrap (E_field(rewrite exp,id)) - | E_match (exp,pexps) -> - rewrap (E_match (rewrite exp, List.map (rewrite_pexp rewriters) pexps)) - | E_try (exp,pexps) -> - rewrap (E_try (rewrite exp, List.map (rewrite_pexp rewriters) pexps)) - | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters letbind,rewrite body)) - | E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters lexp,rewrite exp)) + rewrap + (E_struct_update + ( rewrite re, + List.map (fun (FE_aux (FE_fexp (id, e), fannot)) -> FE_aux (FE_fexp (id, rewrite e), fannot)) fexps + ) + ) + | E_field (exp, id) -> rewrap (E_field (rewrite exp, id)) + | E_match (exp, pexps) -> rewrap (E_match (rewrite exp, List.map (rewrite_pexp rewriters) pexps)) + | E_try (exp, pexps) -> rewrap (E_try (rewrite exp, List.map (rewrite_pexp rewriters) pexps)) + | E_let (letbind, body) -> rewrap (E_let (rewriters.rewrite_let rewriters letbind, rewrite body)) + | E_assign (lexp, exp) -> rewrap (E_assign (rewriters.rewrite_lexp rewriters lexp, rewrite exp)) | E_sizeof n -> rewrap (E_sizeof n) | E_exit e -> rewrap (E_exit (rewrite e)) | E_throw e -> rewrap (E_throw (rewrite e)) | E_return e -> rewrap (E_return (rewrite e)) - | E_assert(e1,e2) -> rewrap (E_assert(rewrite e1,rewrite e2)) + | E_assert (e1, e2) -> rewrap (E_assert (rewrite e1, rewrite e2)) | E_var (lexp, e1, e2) -> - rewrap (E_var (rewriters.rewrite_lexp rewriters lexp, rewriters.rewrite_exp rewriters e1, rewriters.rewrite_exp rewriters e2)) - | E_internal_return _ -> raise (Reporting.err_unreachable l __POS__ "Internal return found before it should have been introduced") - | E_internal_plet _ -> raise (Reporting.err_unreachable l __POS__ " Internal plet found before it should have been introduced") + rewrap + (E_var + ( rewriters.rewrite_lexp rewriters lexp, + rewriters.rewrite_exp rewriters e1, + rewriters.rewrite_exp rewriters e2 + ) + ) + | E_internal_return _ -> + raise (Reporting.err_unreachable l __POS__ "Internal return found before it should have been introduced") + | E_internal_plet _ -> + raise (Reporting.err_unreachable l __POS__ " Internal plet found before it should have been introduced") | _ -> rewrap exp -let rewrite_let rewriters (LB_aux(letbind,(l,annot))) = +let rewrite_let rewriters (LB_aux (letbind, (l, annot))) = match letbind with - | LB_val ( pat, exp) -> - LB_aux(LB_val (rewriters.rewrite_pat rewriters pat, - rewriters.rewrite_exp rewriters exp), - (l, annot)) + | LB_val (pat, exp) -> + LB_aux (LB_val (rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp), (l, annot)) -let rewrite_lexp rewriters (LE_aux(lexp,(l,annot))) = - let rewrap le = LE_aux(le,(l,annot)) in +let rewrite_lexp rewriters (LE_aux (lexp, (l, annot))) = + let rewrap le = LE_aux (le, (l, annot)) in match lexp with | LE_id _ | LE_typ _ -> rewrap lexp | LE_deref exp -> rewrap (LE_deref (rewriters.rewrite_exp rewriters exp)) | LE_tuple tupls -> rewrap (LE_tuple (List.map (rewriters.rewrite_lexp rewriters) tupls)) - | LE_app (id,exps) -> rewrap (LE_app(id,List.map (rewriters.rewrite_exp rewriters) exps)) - | LE_vector (lexp,exp) -> - rewrap (LE_vector (rewriters.rewrite_lexp rewriters lexp,rewriters.rewrite_exp rewriters exp)) - | LE_vector_range (lexp,exp1,exp2) -> - rewrap (LE_vector_range (rewriters.rewrite_lexp rewriters lexp, - rewriters.rewrite_exp rewriters exp1, - rewriters.rewrite_exp rewriters exp2)) + | LE_app (id, exps) -> rewrap (LE_app (id, List.map (rewriters.rewrite_exp rewriters) exps)) + | LE_vector (lexp, exp) -> + rewrap (LE_vector (rewriters.rewrite_lexp rewriters lexp, rewriters.rewrite_exp rewriters exp)) + | LE_vector_range (lexp, exp1, exp2) -> + rewrap + (LE_vector_range + ( rewriters.rewrite_lexp rewriters lexp, + rewriters.rewrite_exp rewriters exp1, + rewriters.rewrite_exp rewriters exp2 + ) + ) | LE_vector_concat lexps -> rewrap (LE_vector_concat (List.map (rewriters.rewrite_lexp rewriters) lexps)) - | LE_field (lexp,id) -> rewrap (LE_field (rewriters.rewrite_lexp rewriters lexp,id)) + | LE_field (lexp, id) -> rewrap (LE_field (rewriters.rewrite_lexp rewriters lexp, id)) -let rewrite_funcl rewriters (FCL_aux (FCL_funcl(id,pexp),(l,annot))) = - FCL_aux (FCL_funcl (id, rewrite_pexp rewriters pexp),(l,annot)) +let rewrite_funcl rewriters (FCL_aux (FCL_funcl (id, pexp), (l, annot))) = + FCL_aux (FCL_funcl (id, rewrite_pexp rewriters pexp), (l, annot)) -let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,funcls),(l,fdannot))) = - let recopt = match recopt with +let rewrite_fun rewriters (FD_aux (FD_function (recopt, tannotopt, funcls), (l, fdannot))) = + let recopt = + match recopt with | Rec_aux (Rec_nonrec, l) -> Rec_aux (Rec_nonrec, l) | Rec_aux (Rec_rec, l) -> Rec_aux (Rec_rec, l) - | Rec_aux (Rec_measure (pat,exp),l) -> - Rec_aux (Rec_measure (rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp),l) + | Rec_aux (Rec_measure (pat, exp), l) -> + Rec_aux (Rec_measure (rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp), l) in - FD_aux (FD_function(recopt,tannotopt,List.map (rewrite_funcl rewriters) funcls), (l,fdannot)) + FD_aux (FD_function (recopt, tannotopt, List.map (rewrite_funcl rewriters) funcls), (l, fdannot)) let rewrite_mpexp rewriters (MPat_aux (aux, (l, annot))) = - let aux = match aux with + let aux = + match aux with | MPat_pat mpat -> MPat_pat mpat | MPat_when (mpat, exp) -> MPat_when (mpat, rewriters.rewrite_exp rewriters exp) in MPat_aux (aux, (l, annot)) - + let rewrite_mapcl rewriters (MCL_aux (aux, def_annot)) = - let aux = match aux with + let aux = + match aux with | MCL_bidir (mpexp1, mpexp2) -> MCL_bidir (rewrite_mpexp rewriters mpexp1, mpexp2) | MCL_forwards (mpexp, exp) -> MCL_forwards (rewrite_mpexp rewriters mpexp, rewriters.rewrite_exp rewriters exp) | MCL_backwards (mpexp, exp) -> MCL_backwards (rewrite_mpexp rewriters mpexp, rewriters.rewrite_exp rewriters exp) @@ -269,9 +277,10 @@ let rewrite_mapcl rewriters (MCL_aux (aux, def_annot)) = let rewrite_mapdef rewriters (MD_aux (MD_mapping (id, tannot_opt, mapcls), annot)) = MD_aux (MD_mapping (id, tannot_opt, List.map (rewrite_mapcl rewriters) mapcls), annot) - + let rewrite_scattered rewriters (SD_aux (sd, (l, annot))) = - let sd = match sd with + let sd = + match sd with | SD_funcl funcl -> SD_funcl (rewrite_funcl rewriters funcl) | SD_mapcl (id, mapcl) -> SD_mapcl (id, rewrite_mapcl rewriters mapcl) | SD_variant _ | SD_unioncl _ | SD_mapping _ | SD_function _ | SD_end _ -> sd @@ -279,10 +288,13 @@ let rewrite_scattered rewriters (SD_aux (sd, (l, annot))) = SD_aux (sd, (l, annot)) let rec rewrite_def rewriters (DEF_aux (aux, def_annot)) = - let aux = match aux with + let aux = + match aux with | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), annot)) -> - DEF_register (DEC_aux (DEC_reg (typ, id, Some (rewriters.rewrite_exp rewriters exp)), annot)) - | DEF_type _ | DEF_mapdef _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ | DEF_fixity _ | DEF_instantiation _ -> aux + DEF_register (DEC_aux (DEC_reg (typ, id, Some (rewriters.rewrite_exp rewriters exp)), annot)) + | DEF_type _ | DEF_mapdef _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ | DEF_fixity _ + | DEF_instantiation _ -> + aux | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef) | DEF_impl funcl -> DEF_impl (rewrite_funcl rewriters funcl) | DEF_outcome (outcome_spec, defs) -> DEF_outcome (outcome_spec, List.map (rewrite_def rewriters) defs) @@ -290,22 +302,19 @@ let rec rewrite_def rewriters (DEF_aux (aux, def_annot)) = | DEF_let letbind -> DEF_let (rewriters.rewrite_let rewriters letbind) | DEF_pragma (pragma, arg, l) -> DEF_pragma (pragma, arg, l) | DEF_scattered sd -> DEF_scattered (rewrite_scattered rewriters sd) - | DEF_measure (id,pat,exp) -> DEF_measure (id,rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp) - | DEF_loop_measures (id,_) -> raise (Reporting.err_unreachable (id_loc id) __POS__ "DEF_loop_measures survived to rewriter") + | DEF_measure (id, pat, exp) -> + DEF_measure (id, rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp) + | DEF_loop_measures (id, _) -> + raise (Reporting.err_unreachable (id_loc id) __POS__ "DEF_loop_measures survived to rewriter") in DEF_aux (aux, def_annot) let rewrite_ast_defs rewriters defs = - let rec rewrite ds = match ds with - | [] -> [] - | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) - in + let rec rewrite ds = match ds with [] -> [] | d :: ds -> rewriters.rewrite_def rewriters d :: rewrite ds in rewrite defs let rewrite_ast_base rewriters ast = - let rec rewrite ds = match ds with - | [] -> [] - | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) in + let rec rewrite ds = match ds with [] -> [] | d :: ds -> rewriters.rewrite_def rewriters d :: rewrite ds in { ast with defs = rewrite ast.defs } let rewrite_ast_base_progress prefix rewriters ast = @@ -313,331 +322,342 @@ let rewrite_ast_base_progress prefix rewriters ast = let rec rewrite n = function | [] -> [] | d :: ds -> - Util.progress (prefix ^ " ") (string_of_int n ^ "/" ^ string_of_int total) n total; - let d = rewriters.rewrite_def rewriters d in - d :: rewrite (n + 1) ds + Util.progress (prefix ^ " ") (string_of_int n ^ "/" ^ string_of_int total) n total; + let d = rewriters.rewrite_def rewriters d in + d :: rewrite (n + 1) ds in { ast with defs = rewrite 1 ast.defs } let rewriters_base = - {rewrite_exp = rewrite_exp; - rewrite_pat = rewrite_pat; - rewrite_let = rewrite_let; - rewrite_lexp = rewrite_lexp; - rewrite_fun = rewrite_fun; - rewrite_def = rewrite_def; - rewrite_ast = rewrite_ast_base} + { rewrite_exp; rewrite_pat; rewrite_let; rewrite_lexp; rewrite_fun; rewrite_def; rewrite_ast = rewrite_ast_base } let rewrite_ast ast = rewrite_ast_base rewriters_base ast -type ('a,'pat,'pat_aux) pat_alg = - { p_lit : lit -> 'pat_aux - ; p_wild : 'pat_aux - ; p_or : 'pat * 'pat -> 'pat_aux - ; p_not : 'pat -> 'pat_aux - ; p_as : 'pat * id -> 'pat_aux - ; p_typ : Ast.typ * 'pat -> 'pat_aux - ; p_id : id -> 'pat_aux - ; p_var : 'pat * typ_pat -> 'pat_aux - ; p_app : id * 'pat list -> 'pat_aux - ; p_vector : 'pat list -> 'pat_aux - ; p_vector_concat : 'pat list -> 'pat_aux - ; p_vector_subrange : id * Big_int.num * Big_int.num -> 'pat_aux - ; p_tuple : 'pat list -> 'pat_aux - ; p_list : 'pat list -> 'pat_aux - ; p_cons : 'pat * 'pat -> 'pat_aux - ; p_string_append : 'pat list -> 'pat_aux - ; p_aux : 'pat_aux * 'a annot -> 'pat - } +type ('a, 'pat, 'pat_aux) pat_alg = { + p_lit : lit -> 'pat_aux; + p_wild : 'pat_aux; + p_or : 'pat * 'pat -> 'pat_aux; + p_not : 'pat -> 'pat_aux; + p_as : 'pat * id -> 'pat_aux; + p_typ : Ast.typ * 'pat -> 'pat_aux; + p_id : id -> 'pat_aux; + p_var : 'pat * typ_pat -> 'pat_aux; + p_app : id * 'pat list -> 'pat_aux; + p_vector : 'pat list -> 'pat_aux; + p_vector_concat : 'pat list -> 'pat_aux; + p_vector_subrange : id * Big_int.num * Big_int.num -> 'pat_aux; + p_tuple : 'pat list -> 'pat_aux; + p_list : 'pat list -> 'pat_aux; + p_cons : 'pat * 'pat -> 'pat_aux; + p_string_append : 'pat list -> 'pat_aux; + p_aux : 'pat_aux * 'a annot -> 'pat; +} -let rec fold_pat_aux (alg : ('a,'pat,'pat_aux) pat_alg) : 'a pat_aux -> 'pat_aux = - function - | P_lit lit -> alg.p_lit lit - | P_wild -> alg.p_wild - | P_or(p1, p2) -> alg.p_or (fold_pat alg p1, fold_pat alg p2) - | P_not(p) -> alg.p_not (fold_pat alg p) - | P_id id -> alg.p_id id - | P_var (p,tpat) -> alg.p_var (fold_pat alg p, tpat) - | P_as (p,id) -> alg.p_as (fold_pat alg p, id) - | P_typ (typ,p) -> alg.p_typ (typ,fold_pat alg p) - | P_app (id,ps) -> alg.p_app (id,List.map (fold_pat alg) ps) - | P_vector ps -> alg.p_vector (List.map (fold_pat alg) ps) - | P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps) +let rec fold_pat_aux (alg : ('a, 'pat, 'pat_aux) pat_alg) : 'a pat_aux -> 'pat_aux = function + | P_lit lit -> alg.p_lit lit + | P_wild -> alg.p_wild + | P_or (p1, p2) -> alg.p_or (fold_pat alg p1, fold_pat alg p2) + | P_not p -> alg.p_not (fold_pat alg p) + | P_id id -> alg.p_id id + | P_var (p, tpat) -> alg.p_var (fold_pat alg p, tpat) + | P_as (p, id) -> alg.p_as (fold_pat alg p, id) + | P_typ (typ, p) -> alg.p_typ (typ, fold_pat alg p) + | P_app (id, ps) -> alg.p_app (id, List.map (fold_pat alg) ps) + | P_vector ps -> alg.p_vector (List.map (fold_pat alg) ps) + | P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps) | P_vector_subrange (id, n, m) -> alg.p_vector_subrange (id, n, m) - | P_tuple ps -> alg.p_tuple (List.map (fold_pat alg) ps) - | P_list ps -> alg.p_list (List.map (fold_pat alg) ps) - | P_cons (ph,pt) -> alg.p_cons (fold_pat alg ph, fold_pat alg pt) - | P_string_append ps -> alg.p_string_append (List.map (fold_pat alg) ps) + | P_tuple ps -> alg.p_tuple (List.map (fold_pat alg) ps) + | P_list ps -> alg.p_list (List.map (fold_pat alg) ps) + | P_cons (ph, pt) -> alg.p_cons (fold_pat alg ph, fold_pat alg pt) + | P_string_append ps -> alg.p_string_append (List.map (fold_pat alg) ps) -and fold_pat (alg : ('a,'pat,'pat_aux) pat_alg) : 'a pat -> 'pat = - function +and fold_pat (alg : ('a, 'pat, 'pat_aux) pat_alg) : 'a pat -> 'pat = function | P_aux (pat, annot) -> alg.p_aux (fold_pat_aux alg pat, annot) -let rec fold_mpat_aux (alg : ('a,'mpat,'mpat_aux) pat_alg) : 'a mpat_aux -> 'mpat_aux = - function - | MP_lit lit -> alg.p_lit lit - | MP_id id -> alg.p_id id - | MP_as (p, id) -> alg.p_as (fold_mpat alg p, id) - | MP_typ (p, typ) -> alg.p_typ (typ,fold_mpat alg p) - | MP_app (id, ps) -> alg.p_app (id,List.map (fold_mpat alg) ps) - | MP_vector ps -> alg.p_vector (List.map (fold_mpat alg) ps) - | MP_vector_concat ps -> alg.p_vector_concat (List.map (fold_mpat alg) ps) +let rec fold_mpat_aux (alg : ('a, 'mpat, 'mpat_aux) pat_alg) : 'a mpat_aux -> 'mpat_aux = function + | MP_lit lit -> alg.p_lit lit + | MP_id id -> alg.p_id id + | MP_as (p, id) -> alg.p_as (fold_mpat alg p, id) + | MP_typ (p, typ) -> alg.p_typ (typ, fold_mpat alg p) + | MP_app (id, ps) -> alg.p_app (id, List.map (fold_mpat alg) ps) + | MP_vector ps -> alg.p_vector (List.map (fold_mpat alg) ps) + | MP_vector_concat ps -> alg.p_vector_concat (List.map (fold_mpat alg) ps) | MP_vector_subrange (id, n, m) -> alg.p_vector_subrange (id, n, m) - | MP_tuple ps -> alg.p_tuple (List.map (fold_mpat alg) ps) - | MP_list ps -> alg.p_list (List.map (fold_mpat alg) ps) - | MP_cons (ph, pt) -> alg.p_cons (fold_mpat alg ph, fold_mpat alg pt) - | MP_string_append ps -> alg.p_string_append (List.map (fold_mpat alg) ps) + | MP_tuple ps -> alg.p_tuple (List.map (fold_mpat alg) ps) + | MP_list ps -> alg.p_list (List.map (fold_mpat alg) ps) + | MP_cons (ph, pt) -> alg.p_cons (fold_mpat alg ph, fold_mpat alg pt) + | MP_string_append ps -> alg.p_string_append (List.map (fold_mpat alg) ps) -and fold_mpat (alg : ('a,'mpat,'mpat_aux) pat_alg) : 'a mpat -> 'mpat = - function +and fold_mpat (alg : ('a, 'mpat, 'mpat_aux) pat_alg) : 'a mpat -> 'mpat = function | MP_aux (mpat, annot) -> alg.p_aux (fold_mpat_aux alg mpat, annot) (* identity fold from term alg to term alg *) -let id_pat_alg : ('a,'a pat, 'a pat_aux) pat_alg = - { p_lit = (fun lit -> P_lit lit) - ; p_wild = P_wild - ; p_or = (fun (pat1, pat2) -> P_or(pat1, pat2)) - ; p_not = (fun pat -> P_not(pat)) - ; p_as = (fun (pat,id) -> P_as (pat,id)) - ; p_typ = (fun (typ,pat) -> P_typ (typ,pat)) - ; p_id = (fun id -> P_id id) - ; p_var = (fun (pat,tpat) -> P_var (pat,tpat)) - ; p_app = (fun (id,ps) -> P_app (id,ps)) - ; p_vector = (fun ps -> P_vector ps) - ; p_vector_concat = (fun ps -> P_vector_concat ps) - ; p_vector_subrange = (fun (id, n, m) -> P_vector_subrange (id, n, m)) - ; p_tuple = (fun ps -> P_tuple ps) - ; p_list = (fun ps -> P_list ps) - ; p_cons = (fun (ph,pt) -> P_cons (ph,pt)) - ; p_string_append = (fun ps -> P_string_append ps) - ; p_aux = (fun (pat,annot) -> P_aux (pat,annot)) +let id_pat_alg : ('a, 'a pat, 'a pat_aux) pat_alg = + { + p_lit = (fun lit -> P_lit lit); + p_wild = P_wild; + p_or = (fun (pat1, pat2) -> P_or (pat1, pat2)); + p_not = (fun pat -> P_not pat); + p_as = (fun (pat, id) -> P_as (pat, id)); + p_typ = (fun (typ, pat) -> P_typ (typ, pat)); + p_id = (fun id -> P_id id); + p_var = (fun (pat, tpat) -> P_var (pat, tpat)); + p_app = (fun (id, ps) -> P_app (id, ps)); + p_vector = (fun ps -> P_vector ps); + p_vector_concat = (fun ps -> P_vector_concat ps); + p_vector_subrange = (fun (id, n, m) -> P_vector_subrange (id, n, m)); + p_tuple = (fun ps -> P_tuple ps); + p_list = (fun ps -> P_list ps); + p_cons = (fun (ph, pt) -> P_cons (ph, pt)); + p_string_append = (fun ps -> P_string_append ps); + p_aux = (fun (pat, annot) -> P_aux (pat, annot)); } let id_mpat_alg : ('a, 'a mpat option, 'a mpat_aux option) pat_alg = - { p_lit = (fun lit -> Some (MP_lit lit)) - ; p_wild = None - ; p_or = (fun _ -> None) - ; p_not = (fun _ -> None) - ; p_as = (fun (pat, id) -> Option.map (fun pat -> MP_as (pat, id)) pat) - ; p_typ = (fun (typ, pat) -> Option.map (fun pat -> MP_typ (pat, typ)) pat) - ; p_id = (fun id -> Some (MP_id id)) - ; p_var = (fun _ -> None) - ; p_app = (fun (id, ps) -> Option.map (fun ps -> MP_app (id, ps)) (Util.option_all ps)) - ; p_vector = (fun ps -> Option.map (fun ps -> MP_vector ps) (Util.option_all ps)) - ; p_vector_concat = (fun ps -> Option.map (fun ps -> MP_vector_concat ps) (Util.option_all ps)) - ; p_vector_subrange = (fun (id, n, m) -> Some (MP_vector_subrange (id, n, m))) - ; p_tuple = (fun ps -> Option.map (fun ps -> MP_tuple ps) (Util.option_all ps)) - ; p_list = (fun ps -> Option.map (fun ps -> MP_list ps) (Util.option_all ps)) - ; p_cons = (fun (ph, pt) -> Option.bind ph (fun ph -> Option.map (fun pt -> MP_cons (ph, pt)) pt)) - ; p_string_append = (fun ps -> Option.map (fun ps -> MP_string_append ps) (Util.option_all ps)) - ; p_aux = (fun (pat, annot) -> Option.map (fun pat -> MP_aux (pat,annot)) pat) + { + p_lit = (fun lit -> Some (MP_lit lit)); + p_wild = None; + p_or = (fun _ -> None); + p_not = (fun _ -> None); + p_as = (fun (pat, id) -> Option.map (fun pat -> MP_as (pat, id)) pat); + p_typ = (fun (typ, pat) -> Option.map (fun pat -> MP_typ (pat, typ)) pat); + p_id = (fun id -> Some (MP_id id)); + p_var = (fun _ -> None); + p_app = (fun (id, ps) -> Option.map (fun ps -> MP_app (id, ps)) (Util.option_all ps)); + p_vector = (fun ps -> Option.map (fun ps -> MP_vector ps) (Util.option_all ps)); + p_vector_concat = (fun ps -> Option.map (fun ps -> MP_vector_concat ps) (Util.option_all ps)); + p_vector_subrange = (fun (id, n, m) -> Some (MP_vector_subrange (id, n, m))); + p_tuple = (fun ps -> Option.map (fun ps -> MP_tuple ps) (Util.option_all ps)); + p_list = (fun ps -> Option.map (fun ps -> MP_list ps) (Util.option_all ps)); + p_cons = (fun (ph, pt) -> Option.bind ph (fun ph -> Option.map (fun pt -> MP_cons (ph, pt)) pt)); + p_string_append = (fun ps -> Option.map (fun ps -> MP_string_append ps) (Util.option_all ps)); + p_aux = (fun (pat, annot) -> Option.map (fun pat -> MP_aux (pat, annot)) pat); } -type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg = - { e_block : 'exp list -> 'exp_aux - ; e_id : id -> 'exp_aux - ; e_ref : id -> 'exp_aux - ; e_lit : lit -> 'exp_aux - ; e_typ : Ast.typ * 'exp -> 'exp_aux - ; e_app : id * 'exp list -> 'exp_aux - ; e_app_infix : 'exp * id * 'exp -> 'exp_aux - ; e_tuple : 'exp list -> 'exp_aux - ; e_if : 'exp * 'exp * 'exp -> 'exp_aux - ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux - ; e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux - ; e_vector : 'exp list -> 'exp_aux - ; e_vector_access : 'exp * 'exp -> 'exp_aux - ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux - ; e_vector_update : 'exp * 'exp * 'exp -> 'exp_aux - ; e_vector_update_subrange : 'exp * 'exp * 'exp * 'exp -> 'exp_aux - ; e_vector_append : 'exp * 'exp -> 'exp_aux - ; e_list : 'exp list -> 'exp_aux - ; e_cons : 'exp * 'exp -> 'exp_aux - ; e_struct : 'fexp list -> 'exp_aux - ; e_struct_update : 'exp * 'fexp list -> 'exp_aux - ; e_field : 'exp * id -> 'exp_aux - ; e_case : 'exp * 'pexp list -> 'exp_aux - ; e_try : 'exp * 'pexp list -> 'exp_aux - ; e_let : 'letbind * 'exp -> 'exp_aux - ; e_assign : 'lexp * 'exp -> 'exp_aux - ; e_sizeof : nexp -> 'exp_aux - ; e_constraint : n_constraint -> 'exp_aux - ; e_exit : 'exp -> 'exp_aux - ; e_throw : 'exp -> 'exp_aux - ; e_return : 'exp -> 'exp_aux - ; e_assert : 'exp * 'exp -> 'exp_aux - ; e_var : 'lexp * 'exp * 'exp -> 'exp_aux - ; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux - ; e_internal_return : 'exp -> 'exp_aux - ; e_internal_value : Value.value -> 'exp_aux - ; e_internal_assume : n_constraint * 'exp -> 'exp_aux - ; e_aux : 'exp_aux * 'a annot -> 'exp - ; le_id : id -> 'lexp_aux - ; le_deref : 'exp -> 'lexp_aux - ; le_app : id * 'exp list -> 'lexp_aux - ; le_typ : Ast.typ * id -> 'lexp_aux - ; le_tuple : 'lexp list -> 'lexp_aux - ; le_vector : 'lexp * 'exp -> 'lexp_aux - ; le_vector_range : 'lexp * 'exp * 'exp -> 'lexp_aux - ; le_vector_concat : 'lexp list -> 'lexp_aux - ; le_field : 'lexp * id -> 'lexp_aux - ; le_aux : 'lexp_aux * 'a annot -> 'lexp - ; fe_fexp : id * 'exp -> 'fexp_aux - ; fe_aux : 'fexp_aux * 'a annot -> 'fexp - ; def_val_empty : 'opt_default_aux - ; def_val_dec : 'exp -> 'opt_default_aux - ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default - ; pat_exp : 'pat * 'exp -> 'pexp_aux - ; pat_when : 'pat * 'exp * 'exp -> 'pexp_aux - ; pat_aux : 'pexp_aux * 'a annot -> 'pexp - ; lb_val : 'pat * 'exp -> 'letbind_aux - ; lb_aux : 'letbind_aux * 'a annot -> 'letbind - ; pat_alg : ('a,'pat,'pat_aux) pat_alg - } +type ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg = { + e_block : 'exp list -> 'exp_aux; + e_id : id -> 'exp_aux; + e_ref : id -> 'exp_aux; + e_lit : lit -> 'exp_aux; + e_typ : Ast.typ * 'exp -> 'exp_aux; + e_app : id * 'exp list -> 'exp_aux; + e_app_infix : 'exp * id * 'exp -> 'exp_aux; + e_tuple : 'exp list -> 'exp_aux; + e_if : 'exp * 'exp * 'exp -> 'exp_aux; + e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux; + e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux; + e_vector : 'exp list -> 'exp_aux; + e_vector_access : 'exp * 'exp -> 'exp_aux; + e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux; + e_vector_update : 'exp * 'exp * 'exp -> 'exp_aux; + e_vector_update_subrange : 'exp * 'exp * 'exp * 'exp -> 'exp_aux; + e_vector_append : 'exp * 'exp -> 'exp_aux; + e_list : 'exp list -> 'exp_aux; + e_cons : 'exp * 'exp -> 'exp_aux; + e_struct : 'fexp list -> 'exp_aux; + e_struct_update : 'exp * 'fexp list -> 'exp_aux; + e_field : 'exp * id -> 'exp_aux; + e_case : 'exp * 'pexp list -> 'exp_aux; + e_try : 'exp * 'pexp list -> 'exp_aux; + e_let : 'letbind * 'exp -> 'exp_aux; + e_assign : 'lexp * 'exp -> 'exp_aux; + e_sizeof : nexp -> 'exp_aux; + e_constraint : n_constraint -> 'exp_aux; + e_exit : 'exp -> 'exp_aux; + e_throw : 'exp -> 'exp_aux; + e_return : 'exp -> 'exp_aux; + e_assert : 'exp * 'exp -> 'exp_aux; + e_var : 'lexp * 'exp * 'exp -> 'exp_aux; + e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux; + e_internal_return : 'exp -> 'exp_aux; + e_internal_value : Value.value -> 'exp_aux; + e_internal_assume : n_constraint * 'exp -> 'exp_aux; + e_aux : 'exp_aux * 'a annot -> 'exp; + le_id : id -> 'lexp_aux; + le_deref : 'exp -> 'lexp_aux; + le_app : id * 'exp list -> 'lexp_aux; + le_typ : Ast.typ * id -> 'lexp_aux; + le_tuple : 'lexp list -> 'lexp_aux; + le_vector : 'lexp * 'exp -> 'lexp_aux; + le_vector_range : 'lexp * 'exp * 'exp -> 'lexp_aux; + le_vector_concat : 'lexp list -> 'lexp_aux; + le_field : 'lexp * id -> 'lexp_aux; + le_aux : 'lexp_aux * 'a annot -> 'lexp; + fe_fexp : id * 'exp -> 'fexp_aux; + fe_aux : 'fexp_aux * 'a annot -> 'fexp; + def_val_empty : 'opt_default_aux; + def_val_dec : 'exp -> 'opt_default_aux; + def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default; + pat_exp : 'pat * 'exp -> 'pexp_aux; + pat_when : 'pat * 'exp * 'exp -> 'pexp_aux; + pat_aux : 'pexp_aux * 'a annot -> 'pexp; + lb_val : 'pat * 'exp -> 'letbind_aux; + lb_aux : 'letbind_aux * 'a annot -> 'letbind; + pat_alg : ('a, 'pat, 'pat_aux) pat_alg; +} let rec fold_exp_aux alg = function | E_block es -> alg.e_block (List.map (fold_exp alg) es) | E_id id -> alg.e_id id | E_ref id -> alg.e_ref id | E_lit lit -> alg.e_lit lit - | E_typ (typ,e) -> alg.e_typ (typ, fold_exp alg e) - | E_app (id,es) -> alg.e_app (id, List.map (fold_exp alg) es) - | E_app_infix (e1,id,e2) -> alg.e_app_infix (fold_exp alg e1, id, fold_exp alg e2) + | E_typ (typ, e) -> alg.e_typ (typ, fold_exp alg e) + | E_app (id, es) -> alg.e_app (id, List.map (fold_exp alg) es) + | E_app_infix (e1, id, e2) -> alg.e_app_infix (fold_exp alg e1, id, fold_exp alg e2) | E_tuple es -> alg.e_tuple (List.map (fold_exp alg) es) - | E_if (e1,e2,e3) -> alg.e_if (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) - | E_for (id,e1,e2,e3,order,e4) -> - alg.e_for (id,fold_exp alg e1, fold_exp alg e2, fold_exp alg e3, order, fold_exp alg e4) + | E_if (e1, e2, e3) -> alg.e_if (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) + | E_for (id, e1, e2, e3, order, e4) -> + alg.e_for (id, fold_exp alg e1, fold_exp alg e2, fold_exp alg e3, order, fold_exp alg e4) | E_loop (loop_type, m, e1, e2) -> - let m = match m with - | Measure_aux (Measure_none,l) -> None,l - | Measure_aux (Measure_some exp,l) -> Some (fold_exp alg exp),l - in - alg.e_loop (loop_type, m, fold_exp alg e1, fold_exp alg e2) + let m = + match m with + | Measure_aux (Measure_none, l) -> (None, l) + | Measure_aux (Measure_some exp, l) -> (Some (fold_exp alg exp), l) + in + alg.e_loop (loop_type, m, fold_exp alg e1, fold_exp alg e2) | E_vector es -> alg.e_vector (List.map (fold_exp alg) es) - | E_vector_access (e1,e2) -> alg.e_vector_access (fold_exp alg e1, fold_exp alg e2) - | E_vector_subrange (e1,e2,e3) -> - alg.e_vector_subrange (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) - | E_vector_update (e1,e2,e3) -> - alg.e_vector_update (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) - | E_vector_update_subrange (e1,e2,e3,e4) -> - alg.e_vector_update_subrange (fold_exp alg e1,fold_exp alg e2, fold_exp alg e3, fold_exp alg e4) - | E_vector_append (e1,e2) -> alg.e_vector_append (fold_exp alg e1, fold_exp alg e2) + | E_vector_access (e1, e2) -> alg.e_vector_access (fold_exp alg e1, fold_exp alg e2) + | E_vector_subrange (e1, e2, e3) -> alg.e_vector_subrange (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) + | E_vector_update (e1, e2, e3) -> alg.e_vector_update (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) + | E_vector_update_subrange (e1, e2, e3, e4) -> + alg.e_vector_update_subrange (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3, fold_exp alg e4) + | E_vector_append (e1, e2) -> alg.e_vector_append (fold_exp alg e1, fold_exp alg e2) | E_list es -> alg.e_list (List.map (fold_exp alg) es) - | E_cons (e1,e2) -> alg.e_cons (fold_exp alg e1, fold_exp alg e2) + | E_cons (e1, e2) -> alg.e_cons (fold_exp alg e1, fold_exp alg e2) | E_struct fexps -> alg.e_struct (List.map (fold_fexp alg) fexps) - | E_struct_update (e,fexps) -> alg.e_struct_update (fold_exp alg e, List.map (fold_fexp alg) fexps) - | E_field (e,id) -> alg.e_field (fold_exp alg e, id) - | E_match (e,pexps) -> alg.e_case (fold_exp alg e, List.map (fold_pexp alg) pexps) - | E_try (e,pexps) -> alg.e_try (fold_exp alg e, List.map (fold_pexp alg) pexps) - | E_let (letbind,e) -> alg.e_let (fold_letbind alg letbind, fold_exp alg e) - | E_assign (lexp,e) -> alg.e_assign (fold_lexp alg lexp, fold_exp alg e) + | E_struct_update (e, fexps) -> alg.e_struct_update (fold_exp alg e, List.map (fold_fexp alg) fexps) + | E_field (e, id) -> alg.e_field (fold_exp alg e, id) + | E_match (e, pexps) -> alg.e_case (fold_exp alg e, List.map (fold_pexp alg) pexps) + | E_try (e, pexps) -> alg.e_try (fold_exp alg e, List.map (fold_pexp alg) pexps) + | E_let (letbind, e) -> alg.e_let (fold_letbind alg letbind, fold_exp alg e) + | E_assign (lexp, e) -> alg.e_assign (fold_lexp alg lexp, fold_exp alg e) | E_sizeof nexp -> alg.e_sizeof nexp | E_constraint nc -> alg.e_constraint nc | E_exit e -> alg.e_exit (fold_exp alg e) | E_throw e -> alg.e_throw (fold_exp alg e) | E_return e -> alg.e_return (fold_exp alg e) - | E_assert(e1,e2) -> alg.e_assert (fold_exp alg e1, fold_exp alg e2) - | E_var (lexp,e1,e2) -> - alg.e_var (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) - | E_internal_plet (pat,e1,e2) -> - alg.e_internal_plet (fold_pat alg.pat_alg pat, fold_exp alg e1, fold_exp alg e2) + | E_assert (e1, e2) -> alg.e_assert (fold_exp alg e1, fold_exp alg e2) + | E_var (lexp, e1, e2) -> alg.e_var (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) + | E_internal_plet (pat, e1, e2) -> alg.e_internal_plet (fold_pat alg.pat_alg pat, fold_exp alg e1, fold_exp alg e2) | E_internal_return e -> alg.e_internal_return (fold_exp alg e) | E_internal_value v -> alg.e_internal_value v | E_internal_assume (nc, e) -> alg.e_internal_assume (nc, fold_exp alg e) -and fold_exp alg (E_aux (exp_aux,annot)) = alg.e_aux (fold_exp_aux alg exp_aux, annot) + +and fold_exp alg (E_aux (exp_aux, annot)) = alg.e_aux (fold_exp_aux alg exp_aux, annot) + and fold_lexp_aux alg = function | LE_id id -> alg.le_id id | LE_deref exp -> alg.le_deref (fold_exp alg exp) - | LE_app (id,es) -> alg.le_app (id, List.map (fold_exp alg) es) + | LE_app (id, es) -> alg.le_app (id, List.map (fold_exp alg) es) | LE_tuple les -> alg.le_tuple (List.map (fold_lexp alg) les) - | LE_typ (typ,id) -> alg.le_typ (typ,id) - | LE_vector (lexp,e) -> alg.le_vector (fold_lexp alg lexp, fold_exp alg e) - | LE_vector_range (lexp,e1,e2) -> - alg.le_vector_range (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) + | LE_typ (typ, id) -> alg.le_typ (typ, id) + | LE_vector (lexp, e) -> alg.le_vector (fold_lexp alg lexp, fold_exp alg e) + | LE_vector_range (lexp, e1, e2) -> alg.le_vector_range (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) | LE_vector_concat les -> alg.le_vector_concat (List.map (fold_lexp alg) les) - | LE_field (lexp,id) -> alg.le_field (fold_lexp alg lexp, id) -and fold_lexp alg (LE_aux (lexp_aux,annot)) = - alg.le_aux (fold_lexp_aux alg lexp_aux, annot) -and fold_fexp_aux alg (FE_fexp (id,e)) = alg.fe_fexp (id, fold_exp alg e) -and fold_fexp alg (FE_aux (fexp_aux,annot)) = alg.fe_aux (fold_fexp_aux alg fexp_aux,annot) + | LE_field (lexp, id) -> alg.le_field (fold_lexp alg lexp, id) + +and fold_lexp alg (LE_aux (lexp_aux, annot)) = alg.le_aux (fold_lexp_aux alg lexp_aux, annot) + +and fold_fexp_aux alg (FE_fexp (id, e)) = alg.fe_fexp (id, fold_exp alg e) + +and fold_fexp alg (FE_aux (fexp_aux, annot)) = alg.fe_aux (fold_fexp_aux alg fexp_aux, annot) + and fold_pexp_aux alg = function - | Pat_exp (pat,e) -> alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e) - | Pat_when (pat,e,e') -> alg.pat_when (fold_pat alg.pat_alg pat, fold_exp alg e, fold_exp alg e') -and fold_pexp alg (Pat_aux (pexp_aux,annot)) = alg.pat_aux (fold_pexp_aux alg pexp_aux, annot) -and fold_letbind_aux alg = function - | LB_val (pat,e) -> alg.lb_val (fold_pat alg.pat_alg pat, fold_exp alg e) -and fold_letbind alg (LB_aux (letbind_aux,annot)) = alg.lb_aux (fold_letbind_aux alg letbind_aux, annot) + | Pat_exp (pat, e) -> alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e) + | Pat_when (pat, e, e') -> alg.pat_when (fold_pat alg.pat_alg pat, fold_exp alg e, fold_exp alg e') + +and fold_pexp alg (Pat_aux (pexp_aux, annot)) = alg.pat_aux (fold_pexp_aux alg pexp_aux, annot) -let fold_funcl alg (FCL_aux (FCL_funcl (id, pexp), annot)) = - FCL_aux (FCL_funcl (id, fold_pexp alg pexp), annot) +and fold_letbind_aux alg = function LB_val (pat, e) -> alg.lb_val (fold_pat alg.pat_alg pat, fold_exp alg e) + +and fold_letbind alg (LB_aux (letbind_aux, annot)) = alg.lb_aux (fold_letbind_aux alg letbind_aux, annot) + +let fold_funcl alg (FCL_aux (FCL_funcl (id, pexp), annot)) = FCL_aux (FCL_funcl (id, fold_pexp alg pexp), annot) let fold_function alg (FD_aux (FD_function (rec_opt, tannot_opt, funcls), annot)) = FD_aux (FD_function (rec_opt, tannot_opt, List.map (fold_funcl alg) funcls), annot) let id_exp_alg = - { e_block = (fun es -> E_block es) - ; e_id = (fun id -> E_id id) - ; e_ref = (fun id -> E_ref id) - ; e_lit = (fun lit -> (E_lit lit)) - ; e_typ = (fun (typ,e) -> E_typ (typ,e)) - ; e_app = (fun (id,es) -> E_app (id,es)) - ; e_app_infix = (fun (e1,id,e2) -> E_app_infix (e1,id,e2)) - ; e_tuple = (fun es -> E_tuple es) - ; e_if = (fun (e1,e2,e3) -> E_if (e1,e2,e3)) - ; e_for = (fun (id,e1,e2,e3,order,e4) -> E_for (id,e1,e2,e3,order,e4)) - ; e_loop = (fun (lt, (m,l), e1, e2) -> - let m = match m with None -> Measure_none | Some e -> Measure_some e in - E_loop (lt, Measure_aux (m,l), e1, e2)) - ; e_vector = (fun es -> E_vector es) - ; e_vector_access = (fun (e1,e2) -> E_vector_access (e1,e2)) - ; e_vector_subrange = (fun (e1,e2,e3) -> E_vector_subrange (e1,e2,e3)) - ; e_vector_update = (fun (e1,e2,e3) -> E_vector_update (e1,e2,e3)) - ; e_vector_update_subrange = (fun (e1,e2,e3,e4) -> E_vector_update_subrange (e1,e2,e3,e4)) - ; e_vector_append = (fun (e1,e2) -> E_vector_append (e1,e2)) - ; e_list = (fun es -> E_list es) - ; e_cons = (fun (e1,e2) -> E_cons (e1,e2)) - ; e_struct = (fun fexps -> E_struct fexps) - ; e_struct_update = (fun (e1,fexp) -> E_struct_update (e1,fexp)) - ; e_field = (fun (e1,id) -> (E_field (e1,id))) - ; e_case = (fun (e1,pexps) -> E_match (e1,pexps)) - ; e_try = (fun (e1,pexps) -> E_try (e1,pexps)) - ; e_let = (fun (lb,e2) -> E_let (lb,e2)) - ; e_assign = (fun (lexp,e2) -> E_assign (lexp,e2)) - ; e_sizeof = (fun nexp -> E_sizeof nexp) - ; e_constraint = (fun nc -> E_constraint nc) - ; e_exit = (fun e1 -> E_exit (e1)) - ; e_throw = (fun e1 -> E_throw (e1)) - ; e_return = (fun e1 -> E_return e1) - ; e_assert = (fun (e1,e2) -> E_assert(e1,e2)) - ; e_var = (fun (lexp, e2, e3) -> E_var (lexp,e2,e3)) - ; e_internal_plet = (fun (pat, e1, e2) -> E_internal_plet (pat,e1,e2)) - ; e_internal_return = (fun e -> E_internal_return e) - ; e_internal_value = (fun v -> E_internal_value v) - ; e_internal_assume = (fun (nc,e) -> E_internal_assume (nc, e)) - ; e_aux = (fun (e,annot) -> E_aux (e,annot)) - ; le_id = (fun id -> LE_id id) - ; le_deref = (fun e -> LE_deref e) - ; le_app = (fun (id,es) -> LE_app (id,es)) - ; le_typ = (fun (typ,id) -> LE_typ (typ,id)) - ; le_tuple = (fun tups -> LE_tuple tups) - ; le_vector = (fun (lexp,e2) -> LE_vector (lexp,e2)) - ; le_vector_range = (fun (lexp,e2,e3) -> LE_vector_range (lexp,e2,e3)) - ; le_vector_concat = (fun lexps -> LE_vector_concat lexps) - ; le_field = (fun (lexp,id) -> LE_field (lexp,id)) - ; le_aux = (fun (lexp,annot) -> LE_aux (lexp,annot)) - ; fe_fexp = (fun (id,e) -> FE_fexp (id,e)) - ; fe_aux = (fun (fexp,annot) -> FE_aux (fexp,annot)) - ; def_val_empty = Def_val_empty - ; def_val_dec = (fun e -> Def_val_dec e) - ; def_val_aux = (fun (defval,aux) -> Def_val_aux (defval,aux)) - ; pat_exp = (fun (pat,e) -> (Pat_exp (pat,e))) - ; pat_when = (fun (pat,e,e') -> (Pat_when (pat,e,e'))) - ; pat_aux = (fun (pexp,a) -> (Pat_aux (pexp,a))) - ; lb_val = (fun (pat,e) -> LB_val (pat,e)) - ; lb_aux = (fun (lb,annot) -> LB_aux (lb,annot)) - ; pat_alg = id_pat_alg + { + e_block = (fun es -> E_block es); + e_id = (fun id -> E_id id); + e_ref = (fun id -> E_ref id); + e_lit = (fun lit -> E_lit lit); + e_typ = (fun (typ, e) -> E_typ (typ, e)); + e_app = (fun (id, es) -> E_app (id, es)); + e_app_infix = (fun (e1, id, e2) -> E_app_infix (e1, id, e2)); + e_tuple = (fun es -> E_tuple es); + e_if = (fun (e1, e2, e3) -> E_if (e1, e2, e3)); + e_for = (fun (id, e1, e2, e3, order, e4) -> E_for (id, e1, e2, e3, order, e4)); + e_loop = + (fun (lt, (m, l), e1, e2) -> + let m = match m with None -> Measure_none | Some e -> Measure_some e in + E_loop (lt, Measure_aux (m, l), e1, e2) + ); + e_vector = (fun es -> E_vector es); + e_vector_access = (fun (e1, e2) -> E_vector_access (e1, e2)); + e_vector_subrange = (fun (e1, e2, e3) -> E_vector_subrange (e1, e2, e3)); + e_vector_update = (fun (e1, e2, e3) -> E_vector_update (e1, e2, e3)); + e_vector_update_subrange = (fun (e1, e2, e3, e4) -> E_vector_update_subrange (e1, e2, e3, e4)); + e_vector_append = (fun (e1, e2) -> E_vector_append (e1, e2)); + e_list = (fun es -> E_list es); + e_cons = (fun (e1, e2) -> E_cons (e1, e2)); + e_struct = (fun fexps -> E_struct fexps); + e_struct_update = (fun (e1, fexp) -> E_struct_update (e1, fexp)); + e_field = (fun (e1, id) -> E_field (e1, id)); + e_case = (fun (e1, pexps) -> E_match (e1, pexps)); + e_try = (fun (e1, pexps) -> E_try (e1, pexps)); + e_let = (fun (lb, e2) -> E_let (lb, e2)); + e_assign = (fun (lexp, e2) -> E_assign (lexp, e2)); + e_sizeof = (fun nexp -> E_sizeof nexp); + e_constraint = (fun nc -> E_constraint nc); + e_exit = (fun e1 -> E_exit e1); + e_throw = (fun e1 -> E_throw e1); + e_return = (fun e1 -> E_return e1); + e_assert = (fun (e1, e2) -> E_assert (e1, e2)); + e_var = (fun (lexp, e2, e3) -> E_var (lexp, e2, e3)); + e_internal_plet = (fun (pat, e1, e2) -> E_internal_plet (pat, e1, e2)); + e_internal_return = (fun e -> E_internal_return e); + e_internal_value = (fun v -> E_internal_value v); + e_internal_assume = (fun (nc, e) -> E_internal_assume (nc, e)); + e_aux = (fun (e, annot) -> E_aux (e, annot)); + le_id = (fun id -> LE_id id); + le_deref = (fun e -> LE_deref e); + le_app = (fun (id, es) -> LE_app (id, es)); + le_typ = (fun (typ, id) -> LE_typ (typ, id)); + le_tuple = (fun tups -> LE_tuple tups); + le_vector = (fun (lexp, e2) -> LE_vector (lexp, e2)); + le_vector_range = (fun (lexp, e2, e3) -> LE_vector_range (lexp, e2, e3)); + le_vector_concat = (fun lexps -> LE_vector_concat lexps); + le_field = (fun (lexp, id) -> LE_field (lexp, id)); + le_aux = (fun (lexp, annot) -> LE_aux (lexp, annot)); + fe_fexp = (fun (id, e) -> FE_fexp (id, e)); + fe_aux = (fun (fexp, annot) -> FE_aux (fexp, annot)); + def_val_empty = Def_val_empty; + def_val_dec = (fun e -> Def_val_dec e); + def_val_aux = (fun (defval, aux) -> Def_val_aux (defval, aux)); + pat_exp = (fun (pat, e) -> Pat_exp (pat, e)); + pat_when = (fun (pat, e, e') -> Pat_when (pat, e, e')); + pat_aux = (fun (pexp, a) -> Pat_aux (pexp, a)); + lb_val = (fun (pat, e) -> LB_val (pat, e)); + lb_aux = (fun (lb, annot) -> LB_aux (lb, annot)); + pat_alg = id_pat_alg; } (* Folding algorithms for not only rewriting patterns/expressions, but also @@ -647,389 +667,488 @@ let id_exp_alg = See rewrite_sizeof for examples. *) let compute_pat_alg bot join = let join_list vs = List.fold_left join bot vs in - let split_join f ps = let (vs,ps) = List.split ps in (join_list vs, f ps) in - { p_lit = (fun lit -> (bot, P_lit lit)) - ; p_wild = (bot, P_wild) - (* todo: I have no idea how to combine v1 and v2 in the following *) - ; p_or = (fun ((v1, pat1), (v2, pat2)) -> (v1, P_or(pat1, pat2))) - ; p_not = (fun (v, pat) -> (v, P_not(pat))) - ; p_as = (fun ((v,pat),id) -> (v, P_as (pat,id))) - ; p_typ = (fun (typ,(v,pat)) -> (v, P_typ (typ,pat))) - ; p_id = (fun id -> (bot, P_id id)) - ; p_var = (fun ((v,pat),kid) -> (v, P_var (pat,kid))) - ; p_app = (fun (id,ps) -> split_join (fun ps -> P_app (id,ps)) ps) - ; p_vector = split_join (fun ps -> P_vector ps) - ; p_vector_concat = split_join (fun ps -> P_vector_concat ps) - ; p_vector_subrange = (fun (id, n, m) -> (bot, P_vector_subrange (id, n, m))) - ; p_tuple = split_join (fun ps -> P_tuple ps) - ; p_list = split_join (fun ps -> P_list ps) - ; p_cons = (fun ((vh,ph),(vt,pt)) -> (join vh vt, P_cons (ph,pt))) - ; p_string_append = split_join (fun ps -> P_string_append ps) - ; p_aux = (fun ((v,pat),annot) -> (v, P_aux (pat,annot))) + let split_join f ps = + let vs, ps = List.split ps in + (join_list vs, f ps) + in + { + p_lit = (fun lit -> (bot, P_lit lit)); + p_wild = (bot, P_wild) (* todo: I have no idea how to combine v1 and v2 in the following *); + p_or = (fun ((v1, pat1), (v2, pat2)) -> (v1, P_or (pat1, pat2))); + p_not = (fun (v, pat) -> (v, P_not pat)); + p_as = (fun ((v, pat), id) -> (v, P_as (pat, id))); + p_typ = (fun (typ, (v, pat)) -> (v, P_typ (typ, pat))); + p_id = (fun id -> (bot, P_id id)); + p_var = (fun ((v, pat), kid) -> (v, P_var (pat, kid))); + p_app = (fun (id, ps) -> split_join (fun ps -> P_app (id, ps)) ps); + p_vector = split_join (fun ps -> P_vector ps); + p_vector_concat = split_join (fun ps -> P_vector_concat ps); + p_vector_subrange = (fun (id, n, m) -> (bot, P_vector_subrange (id, n, m))); + p_tuple = split_join (fun ps -> P_tuple ps); + p_list = split_join (fun ps -> P_list ps); + p_cons = (fun ((vh, ph), (vt, pt)) -> (join vh vt, P_cons (ph, pt))); + p_string_append = split_join (fun ps -> P_string_append ps); + p_aux = (fun ((v, pat), annot) -> (v, P_aux (pat, annot))); } let compute_exp_alg bot join = let join_list vs = List.fold_left join bot vs in - let split_join f es = let (vs,es) = List.split es in (join_list vs, f es) in - { e_block = split_join (fun es -> E_block es) - ; e_id = (fun id -> (bot, E_id id)) - ; e_ref = (fun id -> (bot, E_ref id)) - ; e_lit = (fun lit -> (bot, E_lit lit)) - ; e_typ = (fun (typ,(v,e)) -> (v, E_typ (typ,e))) - ; e_app = (fun (id,es) -> split_join (fun es -> E_app (id,es)) es) - ; e_app_infix = (fun ((v1,e1),id,(v2,e2)) -> (join v1 v2, E_app_infix (e1,id,e2))) - ; e_tuple = split_join (fun es -> E_tuple es) - ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) - ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> - (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) - ; e_loop = (fun (lt, (m,l), (v1, e1), (v2, e2)) -> - let vs,m = match m with - | None -> [], Measure_none - | Some (v,e) -> [v], Measure_some e - in - (join_list (vs@[v1;v2]), E_loop (lt, Measure_aux (m,l), e1, e2))) - ; e_vector = split_join (fun es -> E_vector es) - ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) - ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) - ; e_vector_update = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_update (e1,e2,e3))) - ; e_vector_update_subrange = (fun ((v1,e1),(v2,e2),(v3,e3),(v4,e4)) -> (join_list [v1;v2;v3;v4], E_vector_update_subrange (e1,e2,e3,e4))) - ; e_vector_append = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_append (e1,e2))) - ; e_list = split_join (fun es -> E_list es) - ; e_cons = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_cons (e1,e2))) - ; e_struct = (fun fexps -> - let vs, fexps = List.split fexps in - (join_list vs, E_struct fexps)) - ; e_struct_update = (fun ((v1,e1),fexps) -> - let (vps,fexps) = List.split fexps in - (join_list (v1::vps), E_struct_update (e1,fexps))) - ; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id))) - ; e_case = (fun ((v1,e1),pexps) -> - let (vps,pexps) = List.split pexps in - (join_list (v1::vps), E_match (e1,pexps))) - ; e_try = (fun ((v1,e1),pexps) -> - let (vps,pexps) = List.split pexps in - (join_list (v1::vps), E_try (e1,pexps))) - ; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2))) - ; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2))) - ; e_sizeof = (fun nexp -> (bot, E_sizeof nexp)) - ; e_constraint = (fun nc -> (bot, E_constraint nc)) - ; e_exit = (fun (v1,e1) -> (v1, E_exit (e1))) - ; e_throw = (fun (v1,e1) -> (v1, E_throw (e1))) - ; e_return = (fun (v1,e1) -> (v1, E_return e1)) - ; e_assert = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_assert(e1,e2)) ) - ; e_var = (fun ((vl, lexp), (v2,e2), (v3,e3)) -> - (join_list [vl;v2;v3], E_var (lexp,e2,e3))) - ; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) -> - (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) - ; e_internal_return = (fun (v,e) -> (v, E_internal_return e)) - ; e_internal_value = (fun v -> (bot, E_internal_value v)) - ; e_internal_assume = (fun (nc,(v,e)) -> (v, E_internal_assume (nc,e))) - ; e_aux = (fun ((v,e),annot) -> (v, E_aux (e,annot))) - ; le_id = (fun id -> (bot, LE_id id)) - ; le_deref = (fun (v, e) -> (v, LE_deref e)) - ; le_app = (fun (id,es) -> split_join (fun es -> LE_app (id,es)) es) - ; le_typ = (fun (typ,id) -> (bot, LE_typ (typ,id))) - ; le_tuple = (fun ls -> - let (vs,ls) = List.split ls in - (join_list vs, LE_tuple ls)) - ; le_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LE_vector (lexp,e2))) - ; le_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) -> - (join_list [vl;v2;v3], LE_vector_range (lexp,e2,e3))) - ; le_vector_concat = (fun ls -> - let (vs,ls) = List.split ls in - (join_list vs, LE_vector_concat ls)) - ; le_field = (fun ((vl,lexp),id) -> (vl, LE_field (lexp,id))) - ; le_aux = (fun ((vl,lexp),annot) -> (vl, LE_aux (lexp,annot))) - ; fe_fexp = (fun (id,(v,e)) -> (v, FE_fexp (id,e))) - ; fe_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot))) - ; def_val_empty = (bot, Def_val_empty) - ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) - ; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux))) - ; pat_exp = (fun ((vp,pat),(v,e)) -> (join vp v, Pat_exp (pat,e))) - ; pat_when = (fun ((vp,pat),(v,e),(v',e')) -> (join_list [vp;v;v'], Pat_when (pat,e,e'))) - ; pat_aux = (fun ((v,pexp),a) -> (v, Pat_aux (pexp,a))) - ; lb_val = (fun ((vp,pat),(v,e)) -> (join vp v, LB_val (pat,e))) - ; lb_aux = (fun ((vl,lb),annot) -> (vl,LB_aux (lb,annot))) - ; pat_alg = compute_pat_alg bot join + let split_join f es = + let vs, es = List.split es in + (join_list vs, f es) + in + { + e_block = split_join (fun es -> E_block es); + e_id = (fun id -> (bot, E_id id)); + e_ref = (fun id -> (bot, E_ref id)); + e_lit = (fun lit -> (bot, E_lit lit)); + e_typ = (fun (typ, (v, e)) -> (v, E_typ (typ, e))); + e_app = (fun (id, es) -> split_join (fun es -> E_app (id, es)) es); + e_app_infix = (fun ((v1, e1), id, (v2, e2)) -> (join v1 v2, E_app_infix (e1, id, e2))); + e_tuple = split_join (fun es -> E_tuple es); + e_if = (fun ((v1, e1), (v2, e2), (v3, e3)) -> (join_list [v1; v2; v3], E_if (e1, e2, e3))); + e_for = + (fun (id, (v1, e1), (v2, e2), (v3, e3), order, (v4, e4)) -> + (join_list [v1; v2; v3; v4], E_for (id, e1, e2, e3, order, e4)) + ); + e_loop = + (fun (lt, (m, l), (v1, e1), (v2, e2)) -> + let vs, m = match m with None -> ([], Measure_none) | Some (v, e) -> ([v], Measure_some e) in + (join_list (vs @ [v1; v2]), E_loop (lt, Measure_aux (m, l), e1, e2)) + ); + e_vector = split_join (fun es -> E_vector es); + e_vector_access = (fun ((v1, e1), (v2, e2)) -> (join v1 v2, E_vector_access (e1, e2))); + e_vector_subrange = (fun ((v1, e1), (v2, e2), (v3, e3)) -> (join_list [v1; v2; v3], E_vector_subrange (e1, e2, e3))); + e_vector_update = (fun ((v1, e1), (v2, e2), (v3, e3)) -> (join_list [v1; v2; v3], E_vector_update (e1, e2, e3))); + e_vector_update_subrange = + (fun ((v1, e1), (v2, e2), (v3, e3), (v4, e4)) -> + (join_list [v1; v2; v3; v4], E_vector_update_subrange (e1, e2, e3, e4)) + ); + e_vector_append = (fun ((v1, e1), (v2, e2)) -> (join v1 v2, E_vector_append (e1, e2))); + e_list = split_join (fun es -> E_list es); + e_cons = (fun ((v1, e1), (v2, e2)) -> (join v1 v2, E_cons (e1, e2))); + e_struct = + (fun fexps -> + let vs, fexps = List.split fexps in + (join_list vs, E_struct fexps) + ); + e_struct_update = + (fun ((v1, e1), fexps) -> + let vps, fexps = List.split fexps in + (join_list (v1 :: vps), E_struct_update (e1, fexps)) + ); + e_field = (fun ((v1, e1), id) -> (v1, E_field (e1, id))); + e_case = + (fun ((v1, e1), pexps) -> + let vps, pexps = List.split pexps in + (join_list (v1 :: vps), E_match (e1, pexps)) + ); + e_try = + (fun ((v1, e1), pexps) -> + let vps, pexps = List.split pexps in + (join_list (v1 :: vps), E_try (e1, pexps)) + ); + e_let = (fun ((vl, lb), (v2, e2)) -> (join vl v2, E_let (lb, e2))); + e_assign = (fun ((vl, lexp), (v2, e2)) -> (join vl v2, E_assign (lexp, e2))); + e_sizeof = (fun nexp -> (bot, E_sizeof nexp)); + e_constraint = (fun nc -> (bot, E_constraint nc)); + e_exit = (fun (v1, e1) -> (v1, E_exit e1)); + e_throw = (fun (v1, e1) -> (v1, E_throw e1)); + e_return = (fun (v1, e1) -> (v1, E_return e1)); + e_assert = (fun ((v1, e1), (v2, e2)) -> (join v1 v2, E_assert (e1, e2))); + e_var = (fun ((vl, lexp), (v2, e2), (v3, e3)) -> (join_list [vl; v2; v3], E_var (lexp, e2, e3))); + e_internal_plet = (fun ((vp, pat), (v1, e1), (v2, e2)) -> (join_list [vp; v1; v2], E_internal_plet (pat, e1, e2))); + e_internal_return = (fun (v, e) -> (v, E_internal_return e)); + e_internal_value = (fun v -> (bot, E_internal_value v)); + e_internal_assume = (fun (nc, (v, e)) -> (v, E_internal_assume (nc, e))); + e_aux = (fun ((v, e), annot) -> (v, E_aux (e, annot))); + le_id = (fun id -> (bot, LE_id id)); + le_deref = (fun (v, e) -> (v, LE_deref e)); + le_app = (fun (id, es) -> split_join (fun es -> LE_app (id, es)) es); + le_typ = (fun (typ, id) -> (bot, LE_typ (typ, id))); + le_tuple = + (fun ls -> + let vs, ls = List.split ls in + (join_list vs, LE_tuple ls) + ); + le_vector = (fun ((vl, lexp), (v2, e2)) -> (join vl v2, LE_vector (lexp, e2))); + le_vector_range = (fun ((vl, lexp), (v2, e2), (v3, e3)) -> (join_list [vl; v2; v3], LE_vector_range (lexp, e2, e3))); + le_vector_concat = + (fun ls -> + let vs, ls = List.split ls in + (join_list vs, LE_vector_concat ls) + ); + le_field = (fun ((vl, lexp), id) -> (vl, LE_field (lexp, id))); + le_aux = (fun ((vl, lexp), annot) -> (vl, LE_aux (lexp, annot))); + fe_fexp = (fun (id, (v, e)) -> (v, FE_fexp (id, e))); + fe_aux = (fun ((vf, fexp), annot) -> (vf, FE_aux (fexp, annot))); + def_val_empty = (bot, Def_val_empty); + def_val_dec = (fun (v, e) -> (v, Def_val_dec e)); + def_val_aux = (fun ((v, defval), aux) -> (v, Def_val_aux (defval, aux))); + pat_exp = (fun ((vp, pat), (v, e)) -> (join vp v, Pat_exp (pat, e))); + pat_when = (fun ((vp, pat), (v, e), (v', e')) -> (join_list [vp; v; v'], Pat_when (pat, e, e'))); + pat_aux = (fun ((v, pexp), a) -> (v, Pat_aux (pexp, a))); + lb_val = (fun ((vp, pat), (v, e)) -> (join vp v, LB_val (pat, e))); + lb_aux = (fun ((vl, lb), annot) -> (vl, LB_aux (lb, annot))); + pat_alg = compute_pat_alg bot join; } let pure_pat_alg bot join = let join_list vs = List.fold_left join bot vs in - { p_lit = (fun _ -> bot) - ; p_wild = bot - ; p_or = (fun (pat1, pat2) -> bot) (* todo: this is wrong *) - ; p_not = (fun _ -> bot) (* todo: this is wrong *) - ; p_as = (fun (v, _) -> v) - ; p_typ = (fun (_, v) -> v) - ; p_id = (fun id -> bot) - ; p_var = (fun (v,kid) -> v) - ; p_app = (fun (id,ps) -> join_list ps) - ; p_vector = join_list - ; p_vector_concat = join_list - ; p_vector_subrange = (fun _ -> bot) - ; p_tuple = join_list - ; p_list = join_list - ; p_string_append = join_list - ; p_cons = (fun (vh,vt) -> join vh vt) - ; p_aux = (fun (v,annot) -> v) + { + p_lit = (fun _ -> bot); + p_wild = bot; + p_or = (fun (pat1, pat2) -> bot) (* todo: this is wrong *); + p_not = (fun _ -> bot) (* todo: this is wrong *); + p_as = (fun (v, _) -> v); + p_typ = (fun (_, v) -> v); + p_id = (fun id -> bot); + p_var = (fun (v, kid) -> v); + p_app = (fun (id, ps) -> join_list ps); + p_vector = join_list; + p_vector_concat = join_list; + p_vector_subrange = (fun _ -> bot); + p_tuple = join_list; + p_list = join_list; + p_string_append = join_list; + p_cons = (fun (vh, vt) -> join vh vt); + p_aux = (fun (v, annot) -> v); } let pure_exp_alg bot join = let join_list vs = List.fold_left join bot vs in - { e_block = join_list - ; e_id = (fun id -> bot) - ; e_ref = (fun id -> bot) - ; e_lit = (fun lit -> bot) - ; e_typ = (fun (typ,v) -> v) - ; e_app = (fun (id,es) -> join_list es) - ; e_app_infix = (fun (v1,id,v2) -> join v1 v2) - ; e_tuple = join_list - ; e_if = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) - ; e_for = (fun (id,v1,v2,v3,order,v4) -> join_list [v1;v2;v3;v4]) - ; e_loop = (fun (lt, (m,_), v1, v2) -> - let v = join v1 v2 in match m with None -> v | Some v' -> join v v') - ; e_vector = join_list - ; e_vector_access = (fun (v1,v2) -> join v1 v2) - ; e_vector_subrange = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) - ; e_vector_update = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) - ; e_vector_update_subrange = (fun (v1,v2,v3,v4) -> join_list [v1;v2;v3;v4]) - ; e_vector_append = (fun (v1,v2) -> join v1 v2) - ; e_list = join_list - ; e_cons = (fun (v1,v2) -> join v1 v2) - ; e_struct = (fun vs -> join_list vs) - ; e_struct_update = (fun (v1,vf) -> join_list (v1::vf)) - ; e_field = (fun (v1,id) -> v1) - ; e_case = (fun (v1,vps) -> join_list (v1::vps)) - ; e_try = (fun (v1,vps) -> join_list (v1::vps)) - ; e_let = (fun (vl,v2) -> join vl v2) - ; e_assign = (fun (vl,v2) -> join vl v2) - ; e_sizeof = (fun nexp -> bot) - ; e_constraint = (fun nc -> bot) - ; e_exit = (fun v1 -> v1) - ; e_throw = (fun v1 -> v1) - ; e_return = (fun v1 -> v1) - ; e_assert = (fun (v1,v2) -> join v1 v2) - ; e_var = (fun (vl, v2, v3) -> join_list [vl;v2;v3]) - ; e_internal_plet = (fun (vp, v1, v2) -> join_list [vp;v1;v2]) - ; e_internal_return = (fun v -> v) - ; e_internal_value = (fun v -> bot) - ; e_internal_assume = (fun (_nc,v) -> v) - ; e_aux = (fun (v,annot) -> v) - ; le_id = (fun id -> bot) - ; le_deref = (fun v -> v) - ; le_app = (fun (id,es) -> join_list es) - ; le_typ = (fun (typ,id) -> bot) - ; le_tuple = join_list - ; le_vector = (fun (vl,v2) -> join vl v2) - ; le_vector_range = (fun (vl,v2,v3) -> join_list [vl;v2;v3]) - ; le_vector_concat = join_list - ; le_field = (fun (vl,id) -> vl) - ; le_aux = (fun (vl,annot) -> vl) - ; fe_fexp = (fun (id,v) -> v) - ; fe_aux = (fun (vf,annot) -> vf) - ; def_val_empty = bot - ; def_val_dec = (fun v -> v) - ; def_val_aux = (fun (v,aux) -> v) - ; pat_exp = (fun (vp,v) -> join vp v) - ; pat_when = (fun (vp,v,v') -> join_list [vp;v;v']) - ; pat_aux = (fun (v,a) -> v) - ; lb_val = (fun (vp,v) -> join vp v) - ; lb_aux = (fun (vl,annot) -> vl) - ; pat_alg = pure_pat_alg bot join + { + e_block = join_list; + e_id = (fun id -> bot); + e_ref = (fun id -> bot); + e_lit = (fun lit -> bot); + e_typ = (fun (typ, v) -> v); + e_app = (fun (id, es) -> join_list es); + e_app_infix = (fun (v1, id, v2) -> join v1 v2); + e_tuple = join_list; + e_if = (fun (v1, v2, v3) -> join_list [v1; v2; v3]); + e_for = (fun (id, v1, v2, v3, order, v4) -> join_list [v1; v2; v3; v4]); + e_loop = + (fun (lt, (m, _), v1, v2) -> + let v = join v1 v2 in + match m with None -> v | Some v' -> join v v' + ); + e_vector = join_list; + e_vector_access = (fun (v1, v2) -> join v1 v2); + e_vector_subrange = (fun (v1, v2, v3) -> join_list [v1; v2; v3]); + e_vector_update = (fun (v1, v2, v3) -> join_list [v1; v2; v3]); + e_vector_update_subrange = (fun (v1, v2, v3, v4) -> join_list [v1; v2; v3; v4]); + e_vector_append = (fun (v1, v2) -> join v1 v2); + e_list = join_list; + e_cons = (fun (v1, v2) -> join v1 v2); + e_struct = (fun vs -> join_list vs); + e_struct_update = (fun (v1, vf) -> join_list (v1 :: vf)); + e_field = (fun (v1, id) -> v1); + e_case = (fun (v1, vps) -> join_list (v1 :: vps)); + e_try = (fun (v1, vps) -> join_list (v1 :: vps)); + e_let = (fun (vl, v2) -> join vl v2); + e_assign = (fun (vl, v2) -> join vl v2); + e_sizeof = (fun nexp -> bot); + e_constraint = (fun nc -> bot); + e_exit = (fun v1 -> v1); + e_throw = (fun v1 -> v1); + e_return = (fun v1 -> v1); + e_assert = (fun (v1, v2) -> join v1 v2); + e_var = (fun (vl, v2, v3) -> join_list [vl; v2; v3]); + e_internal_plet = (fun (vp, v1, v2) -> join_list [vp; v1; v2]); + e_internal_return = (fun v -> v); + e_internal_value = (fun v -> bot); + e_internal_assume = (fun (_nc, v) -> v); + e_aux = (fun (v, annot) -> v); + le_id = (fun id -> bot); + le_deref = (fun v -> v); + le_app = (fun (id, es) -> join_list es); + le_typ = (fun (typ, id) -> bot); + le_tuple = join_list; + le_vector = (fun (vl, v2) -> join vl v2); + le_vector_range = (fun (vl, v2, v3) -> join_list [vl; v2; v3]); + le_vector_concat = join_list; + le_field = (fun (vl, id) -> vl); + le_aux = (fun (vl, annot) -> vl); + fe_fexp = (fun (id, v) -> v); + fe_aux = (fun (vf, annot) -> vf); + def_val_empty = bot; + def_val_dec = (fun v -> v); + def_val_aux = (fun (v, aux) -> v); + pat_exp = (fun (vp, v) -> join vp v); + pat_when = (fun (vp, v, v') -> join_list [vp; v; v']); + pat_aux = (fun (v, a) -> v); + lb_val = (fun (vp, v) -> join vp v); + lb_aux = (fun (vl, annot) -> vl); + pat_alg = pure_pat_alg bot join; } -let default_fold_fexp f x (FE_aux (FE_fexp (id,e),annot)) = - let x,e = f x e in - x, FE_aux (FE_fexp (id,e),annot) +let default_fold_fexp f x (FE_aux (FE_fexp (id, e), annot)) = + let x, e = f x e in + (x, FE_aux (FE_fexp (id, e), annot)) -let default_fold_pexp f x (Pat_aux (pe,ann)) = - let x,pe = match pe with - | Pat_exp (p,e) -> - let x,e = f x e in - x,Pat_exp (p,e) - | Pat_when (p,e1,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x,Pat_when (p,e1,e2) - in x, Pat_aux (pe,ann) +let default_fold_pexp f x (Pat_aux (pe, ann)) = + let x, pe = + match pe with + | Pat_exp (p, e) -> + let x, e = f x e in + (x, Pat_exp (p, e)) + | Pat_when (p, e1, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, Pat_when (p, e1, e2)) + in + (x, Pat_aux (pe, ann)) -let default_fold_letbind f x (LB_aux (LB_val (p,e),ann)) = - let x,e = f x e in - x, LB_aux (LB_val (p,e),ann) +let default_fold_letbind f x (LB_aux (LB_val (p, e), ann)) = + let x, e = f x e in + (x, LB_aux (LB_val (p, e), ann)) -let rec default_fold_lexp f x (LE_aux (le,ann) as lexp) = - let re le = LE_aux (le,ann) in +let rec default_fold_lexp f x (LE_aux (le, ann) as lexp) = + let re le = LE_aux (le, ann) in match le with - | LE_id _ - | LE_typ _ - -> x, lexp + | LE_id _ | LE_typ _ -> (x, lexp) | LE_deref e -> - let x, e = f x e in - x, re (LE_deref e) - | LE_app (id,es) -> - let x,es = List.fold_left (fun (x,es) e -> - let x,e' = f x e in x,e'::es) (x,[]) es in - x, re (LE_app (id, List.rev es)) + let x, e = f x e in + (x, re (LE_deref e)) + | LE_app (id, es) -> + let x, es = + List.fold_left + (fun (x, es) e -> + let x, e' = f x e in + (x, e' :: es) + ) + (x, []) es + in + (x, re (LE_app (id, List.rev es))) | LE_tuple les -> - let x,les = List.fold_left (fun (x,les) le -> - let x,le' = default_fold_lexp f x le in x,le'::les) (x,[]) les in - x, re (LE_tuple (List.rev les)) + let x, les = + List.fold_left + (fun (x, les) le -> + let x, le' = default_fold_lexp f x le in + (x, le' :: les) + ) + (x, []) les + in + (x, re (LE_tuple (List.rev les))) | LE_vector_concat les -> - let x,les = List.fold_left (fun (x,les) le -> - let x,le' = default_fold_lexp f x le in x,le'::les) (x,[]) les in - x, re (LE_vector_concat (List.rev les)) - | LE_vector (le,e) -> - let x, le = default_fold_lexp f x le in - let x, e = f x e in - x, re (LE_vector (le,e)) - | LE_vector_range (le,e1,e2) -> - let x, le = default_fold_lexp f x le in - let x, e1 = f x e1 in - let x, e2 = f x e2 in - x, re (LE_vector_range (le,e1,e2)) - | LE_field (le,id) -> - let x, le = default_fold_lexp f x le in - x, re (LE_field (le,id)) + let x, les = + List.fold_left + (fun (x, les) le -> + let x, le' = default_fold_lexp f x le in + (x, le' :: les) + ) + (x, []) les + in + (x, re (LE_vector_concat (List.rev les))) + | LE_vector (le, e) -> + let x, le = default_fold_lexp f x le in + let x, e = f x e in + (x, re (LE_vector (le, e))) + | LE_vector_range (le, e1, e2) -> + let x, le = default_fold_lexp f x le in + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (LE_vector_range (le, e1, e2))) + | LE_field (le, id) -> + let x, le = default_fold_lexp f x le in + (x, re (LE_field (le, id))) -let default_fold_exp f x (E_aux (e,ann) as exp) = - let re e = E_aux (e,ann) in +let default_fold_exp f x (E_aux (e, ann) as exp) = + let re e = E_aux (e, ann) in match e with | E_block es -> - let x,es = List.fold_left (fun (x,es) e -> - let x,e' = f x e in x,e'::es) (x,[]) es in - x, re (E_block (List.rev es)) - | E_id _ - | E_ref _ - | E_lit _ -> x, exp - | E_typ (typ,e) -> - let x,e = f x e in - x, re (E_typ (typ,e)) - | E_app (id,es) -> - let x,es = List.fold_left (fun (x,es) e -> - let x,e' = f x e in x,e'::es) (x,[]) es in - x, re (E_app (id, List.rev es)) - | E_app_infix (e1,id,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_app_infix (e1,id,e2)) + let x, es = + List.fold_left + (fun (x, es) e -> + let x, e' = f x e in + (x, e' :: es) + ) + (x, []) es + in + (x, re (E_block (List.rev es))) + | E_id _ | E_ref _ | E_lit _ -> (x, exp) + | E_typ (typ, e) -> + let x, e = f x e in + (x, re (E_typ (typ, e))) + | E_app (id, es) -> + let x, es = + List.fold_left + (fun (x, es) e -> + let x, e' = f x e in + (x, e' :: es) + ) + (x, []) es + in + (x, re (E_app (id, List.rev es))) + | E_app_infix (e1, id, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_app_infix (e1, id, e2))) | E_tuple es -> - let x,es = List.fold_left (fun (x,es) e -> - let x,e' = f x e in x,e'::es) (x,[]) es in - x, re (E_tuple (List.rev es)) - | E_if (e1,e2,e3) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - let x,e3 = f x e3 in - x, re (E_if (e1,e2,e3)) - | E_for (id,e1,e2,e3,order,e4) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - let x,e3 = f x e3 in - let x,e4 = f x e4 in - x, re (E_for (id,e1,e2,e3,order,e4)) + let x, es = + List.fold_left + (fun (x, es) e -> + let x, e' = f x e in + (x, e' :: es) + ) + (x, []) es + in + (x, re (E_tuple (List.rev es))) + | E_if (e1, e2, e3) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + let x, e3 = f x e3 in + (x, re (E_if (e1, e2, e3))) + | E_for (id, e1, e2, e3, order, e4) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + let x, e3 = f x e3 in + let x, e4 = f x e4 in + (x, re (E_for (id, e1, e2, e3, order, e4))) | E_loop (loop_type, m, e1, e2) -> - let x,m = match m with - | Measure_aux (Measure_none,_) -> x,m - | Measure_aux (Measure_some exp,l) -> - let x, exp = f x exp in - x, Measure_aux (Measure_some exp,l) - in - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_loop (loop_type, m, e1, e2)) + let x, m = + match m with + | Measure_aux (Measure_none, _) -> (x, m) + | Measure_aux (Measure_some exp, l) -> + let x, exp = f x exp in + (x, Measure_aux (Measure_some exp, l)) + in + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_loop (loop_type, m, e1, e2))) | E_vector es -> - let x,es = List.fold_left (fun (x,es) e -> - let x,e' = f x e in x,e'::es) (x,[]) es in - x, re (E_vector (List.rev es)) - | E_vector_access (e1,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_vector_access (e1,e2)) - | E_vector_subrange (e1,e2,e3) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - let x,e3 = f x e3 in - x, re (E_vector_subrange (e1,e2,e3)) - | E_vector_update (e1,e2,e3) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - let x,e3 = f x e3 in - x, re (E_vector_update (e1,e2,e3)) - | E_vector_update_subrange (e1,e2,e3,e4) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - let x,e3 = f x e3 in - let x,e4 = f x e4 in - x, re (E_vector_update_subrange (e1,e2,e3,e4)) - | E_vector_append (e1,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_vector_append (e1,e2)) + let x, es = + List.fold_left + (fun (x, es) e -> + let x, e' = f x e in + (x, e' :: es) + ) + (x, []) es + in + (x, re (E_vector (List.rev es))) + | E_vector_access (e1, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_vector_access (e1, e2))) + | E_vector_subrange (e1, e2, e3) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + let x, e3 = f x e3 in + (x, re (E_vector_subrange (e1, e2, e3))) + | E_vector_update (e1, e2, e3) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + let x, e3 = f x e3 in + (x, re (E_vector_update (e1, e2, e3))) + | E_vector_update_subrange (e1, e2, e3, e4) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + let x, e3 = f x e3 in + let x, e4 = f x e4 in + (x, re (E_vector_update_subrange (e1, e2, e3, e4))) + | E_vector_append (e1, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_vector_append (e1, e2))) | E_list es -> - let x,es = List.fold_left (fun (x,es) e -> - let x,e' = f x e in x,e'::es) (x,[]) es in - x, re (E_list (List.rev es)) - | E_cons (e1,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_cons (e1,e2)) + let x, es = + List.fold_left + (fun (x, es) e -> + let x, e' = f x e in + (x, e' :: es) + ) + (x, []) es + in + (x, re (E_list (List.rev es))) + | E_cons (e1, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_cons (e1, e2))) | E_struct fexps -> - let x,fexps = List.fold_left (fun (x,fes) fe -> - let x,fe' = default_fold_fexp f x fe in x,fe'::fes) (x,[]) fexps in - x, re (E_struct (List.rev fexps)) - | E_struct_update (e,fexps) -> - let x,e = f x e in - let x,fexps = List.fold_left (fun (x,fes) fe -> - let x,fe' = default_fold_fexp f x fe in x,fe'::fes) (x,[]) fexps in - x, re (E_struct_update (e, List.rev fexps)) - | E_field (e,id) -> - let x,e = f x e in x, re (E_field (e,id)) - | E_match (e,pexps) -> - let x,e = f x e in - let x,pexps = List.fold_left (fun (x,pes) pe -> - let x,pe' = default_fold_pexp f x pe in x,pe'::pes) (x,[]) pexps in - x, re (E_match (e, List.rev pexps)) - | E_try (e,pexps) -> - let x,e = f x e in - let x,pexps = List.fold_left (fun (x,pes) pe -> - let x,pe' = default_fold_pexp f x pe in x,pe'::pes) (x,[]) pexps in - x, re (E_try (e, List.rev pexps)) - | E_let (letbind,e) -> - let x,letbind = default_fold_letbind f x letbind in - let x,e = f x e in - x, re (E_let (letbind,e)) - | E_assign (lexp,e) -> - let x,lexp = default_fold_lexp f x lexp in - let x,e = f x e in - x, re (E_assign (lexp,e)) - | E_sizeof _ - | E_constraint _ - -> x,exp + let x, fexps = + List.fold_left + (fun (x, fes) fe -> + let x, fe' = default_fold_fexp f x fe in + (x, fe' :: fes) + ) + (x, []) fexps + in + (x, re (E_struct (List.rev fexps))) + | E_struct_update (e, fexps) -> + let x, e = f x e in + let x, fexps = + List.fold_left + (fun (x, fes) fe -> + let x, fe' = default_fold_fexp f x fe in + (x, fe' :: fes) + ) + (x, []) fexps + in + (x, re (E_struct_update (e, List.rev fexps))) + | E_field (e, id) -> + let x, e = f x e in + (x, re (E_field (e, id))) + | E_match (e, pexps) -> + let x, e = f x e in + let x, pexps = + List.fold_left + (fun (x, pes) pe -> + let x, pe' = default_fold_pexp f x pe in + (x, pe' :: pes) + ) + (x, []) pexps + in + (x, re (E_match (e, List.rev pexps))) + | E_try (e, pexps) -> + let x, e = f x e in + let x, pexps = + List.fold_left + (fun (x, pes) pe -> + let x, pe' = default_fold_pexp f x pe in + (x, pe' :: pes) + ) + (x, []) pexps + in + (x, re (E_try (e, List.rev pexps))) + | E_let (letbind, e) -> + let x, letbind = default_fold_letbind f x letbind in + let x, e = f x e in + (x, re (E_let (letbind, e))) + | E_assign (lexp, e) -> + let x, lexp = default_fold_lexp f x lexp in + let x, e = f x e in + (x, re (E_assign (lexp, e))) + | E_sizeof _ | E_constraint _ -> (x, exp) | E_exit e -> - let x,e = f x e in x, re (E_exit e) + let x, e = f x e in + (x, re (E_exit e)) | E_throw e -> - let x,e = f x e in x, re (E_throw e) + let x, e = f x e in + (x, re (E_throw e)) | E_return e -> - let x,e = f x e in x, re (E_return e) - | E_assert(e1,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_assert (e1,e2)) - | E_var (lexp,e1,e2) -> - let x,lexp = default_fold_lexp f x lexp in - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_var (lexp,e1,e2)) - | E_internal_plet (pat,e1,e2) -> - let x,e1 = f x e1 in - let x,e2 = f x e2 in - x, re (E_internal_plet (pat,e1,e2)) + let x, e = f x e in + (x, re (E_return e)) + | E_assert (e1, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_assert (e1, e2))) + | E_var (lexp, e1, e2) -> + let x, lexp = default_fold_lexp f x lexp in + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_var (lexp, e1, e2))) + | E_internal_plet (pat, e1, e2) -> + let x, e1 = f x e1 in + let x, e2 = f x e2 in + (x, re (E_internal_plet (pat, e1, e2))) | E_internal_return e -> - let x,e = f x e in x, re (E_internal_return e) - | E_internal_value _ -> x,exp + let x, e = f x e in + (x, re (E_internal_return e)) + | E_internal_value _ -> (x, exp) | E_internal_assume (nc, e) -> - let x,e = f x e in - x, re (E_internal_assume (nc, e)) + let x, e = f x e in + (x, re (E_internal_assume (nc, e))) let rec foldin_exp f x e = f (default_fold_exp (foldin_exp f)) x e let foldin_pexp f x e = default_fold_pexp (foldin_exp f) x e diff --git a/src/lib/rewriter.mli b/src/lib/rewriter.mli index 38683f4c2..7dc3c6a5b 100644 --- a/src/lib/rewriter.mli +++ b/src/lib/rewriter.mli @@ -72,14 +72,15 @@ open Ast open Ast_defs open Type_check -type 'a rewriters = { rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; - rewrite_lexp : 'a rewriters -> 'a lexp -> 'a lexp; - rewrite_pat : 'a rewriters -> 'a pat -> 'a pat; - rewrite_let : 'a rewriters -> 'a letbind -> 'a letbind; - rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; - rewrite_def : 'a rewriters -> 'a def -> 'a def; - rewrite_ast : 'a rewriters -> 'a ast -> 'a ast; - } +type 'a rewriters = { + rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; + rewrite_lexp : 'a rewriters -> 'a lexp -> 'a lexp; + rewrite_pat : 'a rewriters -> 'a pat -> 'a pat; + rewrite_let : 'a rewriters -> 'a letbind -> 'a letbind; + rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; + rewrite_def : 'a rewriters -> 'a def -> 'a def; + rewrite_ast : 'a rewriters -> 'a ast -> 'a ast; +} val rewrite_exp : tannot rewriters -> tannot exp -> tannot exp @@ -89,7 +90,7 @@ val rewriters_base : tannot rewriters val rewrite_ast : tannot ast -> tannot ast val rewrite_ast_defs : tannot rewriters -> tannot def list -> tannot def list - + val rewrite_ast_base : tannot rewriters -> tannot ast -> tannot ast (** Same as rewrite_defs_base but display a progress bar when verbosity >= 1 *) @@ -108,148 +109,267 @@ val rewrite_def : tannot rewriters -> tannot def -> tannot def val rewrite_fun : tannot rewriters -> tannot fundef -> tannot fundef val rewrite_mapdef : tannot rewriters -> tannot mapdef -> tannot mapdef - + (** the type of interpretations of patterns *) -type ('a,'pat,'pat_aux) pat_alg = - { p_lit : lit -> 'pat_aux - ; p_wild : 'pat_aux - ; p_or : 'pat * 'pat -> 'pat_aux - ; p_not : 'pat -> 'pat_aux - ; p_as : 'pat * id -> 'pat_aux - ; p_typ : Ast.typ * 'pat -> 'pat_aux - ; p_id : id -> 'pat_aux - ; p_var : 'pat * typ_pat -> 'pat_aux - ; p_app : id * 'pat list -> 'pat_aux - ; p_vector : 'pat list -> 'pat_aux - ; p_vector_concat : 'pat list -> 'pat_aux - ; p_vector_subrange : id * Big_int.num * Big_int.num -> 'pat_aux - ; p_tuple : 'pat list -> 'pat_aux - ; p_list : 'pat list -> 'pat_aux - ; p_cons : 'pat * 'pat -> 'pat_aux - ; p_string_append : 'pat list -> 'pat_aux - ; p_aux : 'pat_aux * 'a annot -> 'pat - } +type ('a, 'pat, 'pat_aux) pat_alg = { + p_lit : lit -> 'pat_aux; + p_wild : 'pat_aux; + p_or : 'pat * 'pat -> 'pat_aux; + p_not : 'pat -> 'pat_aux; + p_as : 'pat * id -> 'pat_aux; + p_typ : Ast.typ * 'pat -> 'pat_aux; + p_id : id -> 'pat_aux; + p_var : 'pat * typ_pat -> 'pat_aux; + p_app : id * 'pat list -> 'pat_aux; + p_vector : 'pat list -> 'pat_aux; + p_vector_concat : 'pat list -> 'pat_aux; + p_vector_subrange : id * Big_int.num * Big_int.num -> 'pat_aux; + p_tuple : 'pat list -> 'pat_aux; + p_list : 'pat list -> 'pat_aux; + p_cons : 'pat * 'pat -> 'pat_aux; + p_string_append : 'pat list -> 'pat_aux; + p_aux : 'pat_aux * 'a annot -> 'pat; +} (** the type of interpretations of expressions *) -type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg = - { e_block : 'exp list -> 'exp_aux - ; e_id : id -> 'exp_aux - ; e_ref : id -> 'exp_aux - ; e_lit : lit -> 'exp_aux - ; e_typ : Ast.typ * 'exp -> 'exp_aux - ; e_app : id * 'exp list -> 'exp_aux - ; e_app_infix : 'exp * id * 'exp -> 'exp_aux - ; e_tuple : 'exp list -> 'exp_aux - ; e_if : 'exp * 'exp * 'exp -> 'exp_aux - ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux - ; e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux - ; e_vector : 'exp list -> 'exp_aux - ; e_vector_access : 'exp * 'exp -> 'exp_aux - ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux - ; e_vector_update : 'exp * 'exp * 'exp -> 'exp_aux - ; e_vector_update_subrange : 'exp * 'exp * 'exp * 'exp -> 'exp_aux - ; e_vector_append : 'exp * 'exp -> 'exp_aux - ; e_list : 'exp list -> 'exp_aux - ; e_cons : 'exp * 'exp -> 'exp_aux - ; e_struct : 'fexp list -> 'exp_aux - ; e_struct_update : 'exp * 'fexp list -> 'exp_aux - ; e_field : 'exp * id -> 'exp_aux - ; e_case : 'exp * 'pexp list -> 'exp_aux - ; e_try : 'exp * 'pexp list -> 'exp_aux - ; e_let : 'letbind * 'exp -> 'exp_aux - ; e_assign : 'lexp * 'exp -> 'exp_aux - ; e_sizeof : nexp -> 'exp_aux - ; e_constraint : n_constraint -> 'exp_aux - ; e_exit : 'exp -> 'exp_aux - ; e_throw : 'exp -> 'exp_aux - ; e_return : 'exp -> 'exp_aux - ; e_assert : 'exp * 'exp -> 'exp_aux - ; e_var : 'lexp * 'exp * 'exp -> 'exp_aux - ; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux - ; e_internal_return : 'exp -> 'exp_aux - ; e_internal_value : Value.value -> 'exp_aux - ; e_internal_assume : n_constraint * 'exp -> 'exp_aux - ; e_aux : 'exp_aux * 'a annot -> 'exp - ; le_id : id -> 'lexp_aux - ; le_deref : 'exp -> 'lexp_aux - ; le_app : id * 'exp list -> 'lexp_aux - ; le_typ : Ast.typ * id -> 'lexp_aux - ; le_tuple : 'lexp list -> 'lexp_aux - ; le_vector : 'lexp * 'exp -> 'lexp_aux - ; le_vector_range : 'lexp * 'exp * 'exp -> 'lexp_aux - ; le_vector_concat : 'lexp list -> 'lexp_aux - ; le_field : 'lexp * id -> 'lexp_aux - ; le_aux : 'lexp_aux * 'a annot -> 'lexp - ; fe_fexp : id * 'exp -> 'fexp_aux - ; fe_aux : 'fexp_aux * 'a annot -> 'fexp - ; def_val_empty : 'opt_default_aux - ; def_val_dec : 'exp -> 'opt_default_aux - ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default - ; pat_exp : 'pat * 'exp -> 'pexp_aux - ; pat_when : 'pat * 'exp * 'exp -> 'pexp_aux - ; pat_aux : 'pexp_aux * 'a annot -> 'pexp - ; lb_val : 'pat * 'exp -> 'letbind_aux - ; lb_aux : 'letbind_aux * 'a annot -> 'letbind - ; pat_alg : ('a,'pat,'pat_aux) pat_alg - } +type ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg = { + e_block : 'exp list -> 'exp_aux; + e_id : id -> 'exp_aux; + e_ref : id -> 'exp_aux; + e_lit : lit -> 'exp_aux; + e_typ : Ast.typ * 'exp -> 'exp_aux; + e_app : id * 'exp list -> 'exp_aux; + e_app_infix : 'exp * id * 'exp -> 'exp_aux; + e_tuple : 'exp list -> 'exp_aux; + e_if : 'exp * 'exp * 'exp -> 'exp_aux; + e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux; + e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux; + e_vector : 'exp list -> 'exp_aux; + e_vector_access : 'exp * 'exp -> 'exp_aux; + e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux; + e_vector_update : 'exp * 'exp * 'exp -> 'exp_aux; + e_vector_update_subrange : 'exp * 'exp * 'exp * 'exp -> 'exp_aux; + e_vector_append : 'exp * 'exp -> 'exp_aux; + e_list : 'exp list -> 'exp_aux; + e_cons : 'exp * 'exp -> 'exp_aux; + e_struct : 'fexp list -> 'exp_aux; + e_struct_update : 'exp * 'fexp list -> 'exp_aux; + e_field : 'exp * id -> 'exp_aux; + e_case : 'exp * 'pexp list -> 'exp_aux; + e_try : 'exp * 'pexp list -> 'exp_aux; + e_let : 'letbind * 'exp -> 'exp_aux; + e_assign : 'lexp * 'exp -> 'exp_aux; + e_sizeof : nexp -> 'exp_aux; + e_constraint : n_constraint -> 'exp_aux; + e_exit : 'exp -> 'exp_aux; + e_throw : 'exp -> 'exp_aux; + e_return : 'exp -> 'exp_aux; + e_assert : 'exp * 'exp -> 'exp_aux; + e_var : 'lexp * 'exp * 'exp -> 'exp_aux; + e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux; + e_internal_return : 'exp -> 'exp_aux; + e_internal_value : Value.value -> 'exp_aux; + e_internal_assume : n_constraint * 'exp -> 'exp_aux; + e_aux : 'exp_aux * 'a annot -> 'exp; + le_id : id -> 'lexp_aux; + le_deref : 'exp -> 'lexp_aux; + le_app : id * 'exp list -> 'lexp_aux; + le_typ : Ast.typ * id -> 'lexp_aux; + le_tuple : 'lexp list -> 'lexp_aux; + le_vector : 'lexp * 'exp -> 'lexp_aux; + le_vector_range : 'lexp * 'exp * 'exp -> 'lexp_aux; + le_vector_concat : 'lexp list -> 'lexp_aux; + le_field : 'lexp * id -> 'lexp_aux; + le_aux : 'lexp_aux * 'a annot -> 'lexp; + fe_fexp : id * 'exp -> 'fexp_aux; + fe_aux : 'fexp_aux * 'a annot -> 'fexp; + def_val_empty : 'opt_default_aux; + def_val_dec : 'exp -> 'opt_default_aux; + def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default; + pat_exp : 'pat * 'exp -> 'pexp_aux; + pat_when : 'pat * 'exp * 'exp -> 'pexp_aux; + pat_aux : 'pexp_aux * 'a annot -> 'pexp; + lb_val : 'pat * 'exp -> 'letbind_aux; + lb_aux : 'letbind_aux * 'a annot -> 'letbind; + pat_alg : ('a, 'pat, 'pat_aux) pat_alg; +} (* fold over patterns *) -val fold_pat : ('a,'pat,'pat_aux) pat_alg -> 'a pat -> 'pat - -val fold_mpat : ('a,'mpat,'mpat_aux) pat_alg -> 'a mpat -> 'mpat - -(* fold over expressions *) -val fold_exp : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg -> 'a exp -> 'exp +val fold_pat : ('a, 'pat, 'pat_aux) pat_alg -> 'a pat -> 'pat -val fold_letbind : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg -> 'a letbind -> 'letbind +val fold_mpat : ('a, 'mpat, 'mpat_aux) pat_alg -> 'a mpat -> 'mpat -val fold_pexp : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg -> 'a pexp -> 'pexp - -val fold_funcl : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default,'a pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg -> 'a funcl -> 'a funcl - -val fold_function : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, - 'opt_default_aux,'opt_default, 'a pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux) exp_alg -> 'a fundef -> 'a fundef +(* fold over expressions *) +val fold_exp : + ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg -> + 'a exp -> + 'exp + +val fold_letbind : + ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg -> + 'a letbind -> + 'letbind + +val fold_pexp : + ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg -> + 'a pexp -> + 'pexp + +val fold_funcl : + ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'a pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg -> + 'a funcl -> + 'a funcl + +val fold_function : + ( 'a, + 'exp, + 'exp_aux, + 'lexp, + 'lexp_aux, + 'fexp, + 'fexp_aux, + 'opt_default_aux, + 'opt_default, + 'a pexp, + 'pexp_aux, + 'letbind_aux, + 'letbind, + 'pat, + 'pat_aux + ) + exp_alg -> + 'a fundef -> + 'a fundef val id_pat_alg : ('a, 'a pat, 'a pat_aux) pat_alg val id_mpat_alg : ('a, 'a mpat option, 'a mpat_aux option) pat_alg val id_exp_alg : - ('a,'a exp,'a exp_aux,'a lexp,'a lexp_aux,'a fexp, - 'a fexp_aux, - 'a opt_default_aux,'a opt_default,'a pexp,'a pexp_aux, - 'a letbind_aux,'a letbind, - 'a pat,'a pat_aux) exp_alg - -val compute_pat_alg : 'b -> ('b -> 'b -> 'b) -> - ('a,('b * 'a pat),('b * 'a pat_aux)) pat_alg - -val compute_exp_alg : 'b -> ('b -> 'b -> 'b) -> - ('a,('b * 'a exp),('b * 'a exp_aux),('b * 'a lexp),('b * 'a lexp_aux),('b * 'a fexp), - ('b * 'a fexp_aux), - ('b * 'a opt_default_aux),('b * 'a opt_default),('b * 'a pexp),('b * 'a pexp_aux), - ('b * 'a letbind_aux),('b * 'a letbind), - ('b * 'a pat),('b * 'a pat_aux)) exp_alg - -val pure_pat_alg : 'b -> ('b -> 'b -> 'b) -> ('a,'b,'b) pat_alg - -val pure_exp_alg : 'b -> ('b -> 'b -> 'b) -> - ('a,'b,'b,'b,'b,'b, - 'b,'b,'b, - 'b,'b, - 'b,'b, - 'b,'b) exp_alg + ( 'a, + 'a exp, + 'a exp_aux, + 'a lexp, + 'a lexp_aux, + 'a fexp, + 'a fexp_aux, + 'a opt_default_aux, + 'a opt_default, + 'a pexp, + 'a pexp_aux, + 'a letbind_aux, + 'a letbind, + 'a pat, + 'a pat_aux + ) + exp_alg + +val compute_pat_alg : 'b -> ('b -> 'b -> 'b) -> ('a, 'b * 'a pat, 'b * 'a pat_aux) pat_alg + +val compute_exp_alg : + 'b -> + ('b -> 'b -> 'b) -> + ( 'a, + 'b * 'a exp, + 'b * 'a exp_aux, + 'b * 'a lexp, + 'b * 'a lexp_aux, + 'b * 'a fexp, + 'b * 'a fexp_aux, + 'b * 'a opt_default_aux, + 'b * 'a opt_default, + 'b * 'a pexp, + 'b * 'a pexp_aux, + 'b * 'a letbind_aux, + 'b * 'a letbind, + 'b * 'a pat, + 'b * 'a pat_aux + ) + exp_alg + +val pure_pat_alg : 'b -> ('b -> 'b -> 'b) -> ('a, 'b, 'b) pat_alg + +val pure_exp_alg : 'b -> ('b -> 'b -> 'b) -> ('a, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b, 'b) exp_alg val add_p_typ : Env.t -> typ -> 'a pat -> 'a pat diff --git a/src/lib/rewrites.ml b/src/lib/rewrites.ml index 35e0e8bf2..e4682eb77 100644 --- a/src/lib/rewrites.ml +++ b/src/lib/rewrites.ml @@ -77,148 +77,125 @@ let fresh_name_counter = ref 0 let fresh_name () = let current = !fresh_name_counter in - let () = fresh_name_counter := (current + 1) in + let () = fresh_name_counter := current + 1 in current -let reset_fresh_name_counter () = - fresh_name_counter := 0 +let reset_fresh_name_counter () = fresh_name_counter := 0 let fresh_id pre l = let current = fresh_name () in Id_aux (Id (pre ^ string_of_int current), gen_loc l) -let fresh_id_pat pre ((l,annot)) = +let fresh_id_pat pre (l, annot) = let id = fresh_id pre l in P_aux (P_id id, (gen_loc l, annot)) -let get_loc_exp (E_aux (_,(l,_))) = l +let get_loc_exp (E_aux (_, (l, _))) = l -let gen_vs ~pure (id, spec) = Initial_check.extern_of_string ~pure:pure (mk_id id) spec +let gen_vs ~pure (id, spec) = Initial_check.extern_of_string ~pure (mk_id id) spec let simple_annot l typ = (gen_loc l, mk_tannot initial_env typ) - + let annot_exp e_aux l env typ = E_aux (e_aux, (l, mk_tannot env typ)) let annot_pat p_aux l env typ = P_aux (p_aux, (l, mk_tannot env typ)) -let annot_letbind (p_aux, exp) l env typ = - LB_aux (LB_val (annot_pat p_aux l env typ, exp), (l, mk_tannot env typ)) +let annot_letbind (p_aux, exp) l env typ = LB_aux (LB_val (annot_pat p_aux l env typ, exp), (l, mk_tannot env typ)) -let simple_num l n = E_aux ( - E_lit (L_aux (L_num n, gen_loc l)), - simple_annot (gen_loc l) - (atom_typ (Nexp_aux (Nexp_constant n, gen_loc l)))) +let simple_num l n = + E_aux (E_lit (L_aux (L_num n, gen_loc l)), simple_annot (gen_loc l) (atom_typ (Nexp_aux (Nexp_constant n, gen_loc l)))) let effectful eaux = Ast_util.effectful (effect_of eaux) let effectful_pexp pexp = - let (pat, guard, exp, _) = destruct_pexp pexp in - let guard_eff = match guard with - | Some g -> effect_of g - | None -> no_effect - in + let pat, guard, exp, _ = destruct_pexp pexp in + let guard_eff = match guard with Some g -> effect_of g | None -> no_effect in Ast_util.effectful (union_effects guard_eff (effect_of exp)) -let rec small (E_aux (exp,_)) = match exp with - | E_id _ - | E_lit _ -> true - | E_typ (_,e) -> small e +let rec small (E_aux (exp, _)) = + match exp with + | E_id _ | E_lit _ -> true + | E_typ (_, e) -> small e | E_list es -> List.for_all small es - | E_cons (e1,e2) -> small e1 && small e2 + | E_cons (e1, e2) -> small e1 && small e2 | E_sizeof _ -> true | _ -> false -let id_is_local_var id env = match Env.lookup_id id env with - | Local _ -> true - | _ -> false +let id_is_local_var id env = match Env.lookup_id id env with Local _ -> true | _ -> false -let id_is_unbound id env = match Env.lookup_id id env with - | Unbound _ -> true - | _ -> false +let id_is_unbound id env = match Env.lookup_id id env with Unbound _ -> true | _ -> false -let rec lexp_is_local (LE_aux (lexp, _)) env = match lexp with +let rec lexp_is_local (LE_aux (lexp, _)) env = + match lexp with | LE_app _ | LE_deref _ -> false - | LE_id id - | LE_typ (_, id) -> id_is_local_var id env + | LE_id id | LE_typ (_, id) -> id_is_local_var id env | LE_tuple lexps | LE_vector_concat lexps -> List.for_all (fun lexp -> lexp_is_local lexp env) lexps - | LE_vector (lexp,_) - | LE_vector_range (lexp,_,_) - | LE_field (lexp,_) -> lexp_is_local lexp env + | LE_vector (lexp, _) | LE_vector_range (lexp, _, _) | LE_field (lexp, _) -> lexp_is_local lexp env -let rec lexp_is_local_intro (LE_aux (lexp, _)) env = match lexp with +let rec lexp_is_local_intro (LE_aux (lexp, _)) env = + match lexp with | LE_app _ | LE_deref _ -> false - | LE_id id - | LE_typ (_, id) -> id_is_unbound id env + | LE_id id | LE_typ (_, id) -> id_is_unbound id env | LE_tuple lexps | LE_vector_concat lexps -> List.for_all (fun lexp -> lexp_is_local_intro lexp env) lexps - | LE_vector (lexp,_) - | LE_vector_range (lexp,_,_) - | LE_field (lexp,_) -> lexp_is_local_intro lexp env + | LE_vector (lexp, _) | LE_vector_range (lexp, _, _) | LE_field (lexp, _) -> lexp_is_local_intro lexp env let lexp_is_effectful (LE_aux (_, (_, tannot))) = Ast_util.effectful (effect_of_annot tannot) - + let find_used_vars exp = (* Overapproximates the set of used identifiers, but for the use cases below this is acceptable. *) - let e_id id = IdSet.singleton id, E_id id in - fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with e_id = e_id } exp) + let e_id id = (IdSet.singleton id, E_id id) in + fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with e_id } exp) let find_introduced_vars exp = let le_aux ((ids, lexp), annot) = - let ids = match lexp with - | LE_id id | LE_typ (_, id) - when id_is_unbound id (env_of_annot annot) -> IdSet.add id ids - | _ -> ids in - (ids, LE_aux (lexp, annot)) in - fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with le_aux = le_aux } exp) + let ids = + match lexp with + | (LE_id id | LE_typ (_, id)) when id_is_unbound id (env_of_annot annot) -> IdSet.add id ids + | _ -> ids + in + (ids, LE_aux (lexp, annot)) + in + fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with le_aux } exp) let find_updated_vars exp = let intros = find_introduced_vars exp in let le_aux ((ids, lexp), annot) = - let ids = match lexp with - | LE_id id | LE_typ (_, id) - when id_is_local_var id (env_of_annot annot) && not (IdSet.mem id intros) -> - IdSet.add id ids - | _ -> ids in - (ids, LE_aux (lexp, annot)) in - fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with le_aux = le_aux } exp) + let ids = + match lexp with + | (LE_id id | LE_typ (_, id)) when id_is_local_var id (env_of_annot annot) && not (IdSet.mem id intros) -> + IdSet.add id ids + | _ -> ids + in + (ids, LE_aux (lexp, annot)) + in + fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with le_aux } exp) let lookup_equal_kids env = let get_eq_kids kid eqs = try KBindings.find kid eqs with Not_found -> KidSet.singleton kid in let add_eq_kids kid1 kid2 eqs = let kids = KidSet.union (get_eq_kids kid2 eqs) (get_eq_kids kid1 eqs) in - eqs - |> KBindings.add kid1 kids - |> KBindings.add kid2 kids + eqs |> KBindings.add kid1 kids |> KBindings.add kid2 kids in let add_nc eqs = function - | NC_aux (NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var kid2, _)), _) -> - add_eq_kids kid1 kid2 eqs + | NC_aux (NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var kid2, _)), _) -> add_eq_kids kid1 kid2 eqs | _ -> eqs in List.fold_left add_nc KBindings.empty (Env.get_constraints env) let lookup_constant_kid env kid = let kids = - match KBindings.find kid (lookup_equal_kids env) with - | kids -> kids - | exception Not_found -> KidSet.singleton kid - in - let check_nc const nc = match const, nc with - | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _) - when KidSet.mem kid kids -> - Some i + match KBindings.find kid (lookup_equal_kids env) with kids -> kids | exception Not_found -> KidSet.singleton kid + in + let check_nc const nc = + match (const, nc) with + | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _) when KidSet.mem kid kids -> + Some i | _, _ -> const in List.fold_left check_nc None (Env.get_constraints env) -let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with +let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = + match nexp with | Nexp_id id -> Env.expand_nexp_synonyms env nexp_aux - | Nexp_var kid -> - begin - match lookup_constant_kid env kid with - | Some i -> nconstant i - | None -> nexp_aux - end + | Nexp_var kid -> begin match lookup_constant_kid env kid with Some i -> nconstant i | None -> nexp_aux end | Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) | Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) | Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) @@ -227,31 +204,25 @@ let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with | _ -> nexp_aux let rewrite_ast_nexp_ids, _rewrite_typ_nexp_ids = - let rec rewrite_typ env (Typ_aux (typ, l) as typ_aux) = match typ with - | Typ_fn (arg_ts, ret_t) -> - Typ_aux (Typ_fn (List.map (rewrite_typ env) arg_ts, rewrite_typ env ret_t), l) - | Typ_tuple ts -> - Typ_aux (Typ_tuple (List.map (rewrite_typ env) ts), l) - | Typ_exist (kids, c, typ) -> - Typ_aux (Typ_exist (kids, c, rewrite_typ env typ), l) - | Typ_app (id, targs) -> - Typ_aux (Typ_app (id, List.map (rewrite_typ_arg env) targs), l) + let rec rewrite_typ env (Typ_aux (typ, l) as typ_aux) = + match typ with + | Typ_fn (arg_ts, ret_t) -> Typ_aux (Typ_fn (List.map (rewrite_typ env) arg_ts, rewrite_typ env ret_t), l) + | Typ_tuple ts -> Typ_aux (Typ_tuple (List.map (rewrite_typ env) ts), l) + | Typ_exist (kids, c, typ) -> Typ_aux (Typ_exist (kids, c, rewrite_typ env typ), l) + | Typ_app (id, targs) -> Typ_aux (Typ_app (id, List.map (rewrite_typ_arg env) targs), l) | _ -> typ_aux - and rewrite_typ_arg env (A_aux (targ, l)) = match targ with - | A_nexp nexp -> - A_aux (A_nexp (rewrite_nexp_ids env nexp), l) - | A_typ typ -> - A_aux (A_typ (rewrite_typ env typ), l) - | A_order ord -> - A_aux (A_order ord, l) - | A_bool nc -> - A_aux (A_bool nc, l) + and rewrite_typ_arg env (A_aux (targ, l)) = + match targ with + | A_nexp nexp -> A_aux (A_nexp (rewrite_nexp_ids env nexp), l) + | A_typ typ -> A_aux (A_typ (rewrite_typ env typ), l) + | A_order ord -> A_aux (A_order ord, l) + | A_bool nc -> A_aux (A_bool nc, l) in let rewrite_annot (l, tannot) = match destruct_tannot tannot with - | Some (env, typ) -> l, replace_typ (rewrite_typ env typ) tannot - | None -> l, empty_tannot + | Some (env, typ) -> (l, replace_typ (rewrite_typ env typ) tannot) + | None -> (l, empty_tannot) in let rewrite_typschm env (TypSchm_aux (TypSchm_ts (tq, typ), l)) = @@ -260,90 +231,111 @@ let rewrite_ast_nexp_ids, _rewrite_typ_nexp_ids = let rewrite_def env rewriters = function | DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, id, exts, b), a)), def_annot) -> - let typschm = rewrite_typschm env typschm in - let a = rewrite_annot a in - DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, id, exts, b), a)), def_annot) + let typschm = rewrite_typschm env typschm in + let a = rewrite_annot a in + DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, id, exts, b), a)), def_annot) | DEF_aux (DEF_type (TD_aux (TD_abbrev (id, typq, typ_arg), a)), def_annot) -> - DEF_aux (DEF_type (TD_aux (TD_abbrev (id, typq, rewrite_typ_arg env typ_arg), a)), def_annot) + DEF_aux (DEF_type (TD_aux (TD_abbrev (id, typq, rewrite_typ_arg env typ_arg), a)), def_annot) | DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields, b), a)), def_annot) -> - let fields' = List.map (fun (t, id) -> (rewrite_typ env t, id)) fields in - DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields', b), a)), def_annot) + let fields' = List.map (fun (t, id) -> (rewrite_typ env t, id)) fields in + DEF_aux (DEF_type (TD_aux (TD_record (id, typq, fields', b), a)), def_annot) | DEF_aux (DEF_type (TD_aux (TD_variant (id, typq, constrs, b), a)), def_annot) -> - let constrs' = - List.map (fun (Tu_aux (Tu_ty_id (t, id), l)) -> - Tu_aux (Tu_ty_id (rewrite_typ env t, id), l)) - constrs - in - DEF_aux (DEF_type (TD_aux (TD_variant (id, typq, constrs', b), a)), def_annot) + let constrs' = + List.map (fun (Tu_aux (Tu_ty_id (t, id), l)) -> Tu_aux (Tu_ty_id (rewrite_typ env t, id), l)) constrs + in + DEF_aux (DEF_type (TD_aux (TD_variant (id, typq, constrs', b), a)), def_annot) | d -> Rewriter.rewrite_def rewriters d in - (fun env defs -> rewrite_ast_base { rewriters_base with - rewrite_exp = (fun _ -> map_exp_annot rewrite_annot); - rewrite_def = rewrite_def env - } defs), - rewrite_typ + ( (fun env defs -> + rewrite_ast_base + { rewriters_base with rewrite_exp = (fun _ -> map_exp_annot rewrite_annot); rewrite_def = rewrite_def env } + defs + ), + rewrite_typ + ) let rewrite_ast_remove_vector_subrange_pats env ast = let rewrite_pattern pat = let appends = ref Bindings.empty in let rec insert_into_append (n1, m1, id1, typ1) = function | (n2, m2, id2, typ2) :: xs -> - if Big_int.greater m1 n2 then - (n1, m1, id1, typ1) :: (n2, m2, id2, typ2) :: xs - else - (n2, m2, id2, typ2) :: insert_into_append (n1, m1, id1, typ1) xs - | [] -> [(n1, m1, id1, typ1)] in - let pat_alg = { + if Big_int.greater m1 n2 then (n1, m1, id1, typ1) :: (n2, m2, id2, typ2) :: xs + else (n2, m2, id2, typ2) :: insert_into_append (n1, m1, id1, typ1) xs + | [] -> [(n1, m1, id1, typ1)] + in + let pat_alg = + { id_pat_alg with p_aux = (fun (aux, annot) -> let typ = typ_of_annot annot in match aux with | P_vector_subrange (id, n, m) -> - let range_id = Printf.ksprintf mk_id "%s_%s_%s#" - (string_of_id id) (Big_int.to_string n) (Big_int.to_string m) in - appends := Bindings.update id (fun a -> Some (insert_into_append (n, m, range_id, typ) (Option.value a ~default:[]))) !appends; - P_aux (P_typ (typ, P_aux (P_id range_id, annot)), annot) + let range_id = + Printf.ksprintf mk_id "%s_%s_%s#" (string_of_id id) (Big_int.to_string n) (Big_int.to_string m) + in + appends := + Bindings.update id + (fun a -> Some (insert_into_append (n, m, range_id, typ) (Option.value a ~default:[]))) + !appends; + P_aux (P_typ (typ, P_aux (P_id range_id, annot)), annot) | _ -> P_aux (aux, annot) - ) - } in + ); + } + in let pat = fold_pat pat_alg pat in - pat, !appends + (pat, !appends) in let rewrite_pexp pat body = let pat, appends = rewrite_pattern pat in let body = - Bindings.fold (fun id append body -> + Bindings.fold + (fun id append body -> match append with | (_, _, id1, _) :: tl_append -> - let env = List.fold_left (fun env (_, _, id, typ) -> Env.add_local id (Immutable, typ) env) (env_of body) append in - let append_exp = List.fold_left (fun e1 (_, _, id2, _) -> mk_exp (E_vector_append (e1, mk_exp (E_id id2)))) (mk_exp (E_id id1)) tl_append in - let bind = mk_exp (E_let (mk_letbind (mk_pat (P_id id)) append_exp, mk_lit_exp L_unit)) in - let bind = check_exp env bind unit_typ in - begin match bind with - | E_aux (E_let (letbind, _), annot) -> - E_aux (E_let (letbind, body), annot) - | _ -> assert false - end - | [] -> - body - ) appends body in - pat, body - in - let exp_alg = { + let env = + List.fold_left (fun env (_, _, id, typ) -> Env.add_local id (Immutable, typ) env) (env_of body) append + in + let append_exp = + List.fold_left + (fun e1 (_, _, id2, _) -> mk_exp (E_vector_append (e1, mk_exp (E_id id2)))) + (mk_exp (E_id id1)) tl_append + in + let bind = mk_exp (E_let (mk_letbind (mk_pat (P_id id)) append_exp, mk_lit_exp L_unit)) in + let bind = check_exp env bind unit_typ in + begin + match bind with + | E_aux (E_let (letbind, _), annot) -> E_aux (E_let (letbind, body), annot) + | _ -> assert false + end + | [] -> body + ) + appends body + in + (pat, body) + in + let exp_alg = + { id_exp_alg with - pat_exp = (fun (pat, body) -> let pat, body = rewrite_pexp pat body in Pat_exp (pat, body)); - pat_when = (fun (pat, guard, body) -> let pat, body = rewrite_pexp pat body in Pat_when (pat, guard, body)) - } in + pat_exp = + (fun (pat, body) -> + let pat, body = rewrite_pexp pat body in + Pat_exp (pat, body) + ); + pat_when = + (fun (pat, guard, body) -> + let pat, body = rewrite_pexp pat body in + Pat_when (pat, guard, body) + ); + } + in let rewrite_exp _ = fold_exp exp_alg in - let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = - FCL_aux (FCL_funcl (id, fold_pexp exp_alg pexp), annot) in + let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = FCL_aux (FCL_funcl (id, fold_pexp exp_alg pexp), annot) in let rewrite_fun _ (FD_aux (FD_function (r_o, t_o, funcls), a)) = - FD_aux (FD_function (r_o, t_o, List.map rewrite_funcl funcls), a) in - rewrite_ast_base - { rewriters_base with rewrite_exp = rewrite_exp; rewrite_fun = rewrite_fun } - ast + FD_aux (FD_function (r_o, t_o, List.map rewrite_funcl funcls), a) + in + rewrite_ast_base { rewriters_base with rewrite_exp; rewrite_fun } ast let remove_vector_concat_pat pat = let fresh_id_v = fresh_id "v__" in @@ -354,70 +346,76 @@ let remove_vector_concat_pat pat = (* introduce names for all patterns of form P_vector_concat *) let name_vector_concat_roots = - { p_lit = (fun lit -> P_lit lit) - ; p_typ = (fun (typ,p) -> P_typ (typ,p false)) (* cannot happen *) - ; p_wild = P_wild - (* ToDo: I have no idea what the boolean parameter means so guessed that - * "true" was a good value to use. - * (Adding a comment explaining the boolean might be useful?) - *) - ; p_or = (fun (pat1, pat2) -> P_or (pat1 true, pat2 true)) - ; p_not = (fun pat -> P_not (pat true)) - ; p_as = (fun (pat,id) -> P_as (pat true,id)) - ; p_id = (fun id -> P_id id) - ; p_var = (fun (pat,kid) -> P_var (pat true,kid)) - ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps)) - ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)) - ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) - ; p_vector_subrange = (fun (id, n, m) -> P_vector_subrange (id, n, m)) - ; p_tuple = (fun ps -> P_tuple (List.map (fun p -> p false) ps)) - ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) - ; p_cons = (fun (p,ps) -> P_cons (p false, ps false)) - ; p_string_append = (fun (ps) -> P_string_append (List.map (fun p -> p false) ps)) - ; p_aux = - (fun (pat,((l,_) as annot)) contained_in_p_as -> + { + p_lit = (fun lit -> P_lit lit); + p_typ = (fun (typ, p) -> P_typ (typ, p false)) (* cannot happen *); + p_wild = + P_wild + (* ToDo: I have no idea what the boolean parameter means so guessed that + * "true" was a good value to use. + * (Adding a comment explaining the boolean might be useful?) + *); + p_or = (fun (pat1, pat2) -> P_or (pat1 true, pat2 true)); + p_not = (fun pat -> P_not (pat true)); + p_as = (fun (pat, id) -> P_as (pat true, id)); + p_id = (fun id -> P_id id); + p_var = (fun (pat, kid) -> P_var (pat true, kid)); + p_app = (fun (id, ps) -> P_app (id, List.map (fun p -> p false) ps)); + p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)); + p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)); + p_vector_subrange = (fun (id, n, m) -> P_vector_subrange (id, n, m)); + p_tuple = (fun ps -> P_tuple (List.map (fun p -> p false) ps)); + p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)); + p_cons = (fun (p, ps) -> P_cons (p false, ps false)); + p_string_append = (fun ps -> P_string_append (List.map (fun p -> p false) ps)); + p_aux = + (fun (pat, ((l, _) as annot)) contained_in_p_as -> match pat with | P_vector_concat pats -> - (if contained_in_p_as - then P_aux (pat,annot) - else P_aux (P_as (P_aux (pat,annot),fresh_id_v l),annot)) - | _ -> P_aux (pat,annot) - ) - } in + if contained_in_p_as then P_aux (pat, annot) else P_aux (P_as (P_aux (pat, annot), fresh_id_v l), annot) + | _ -> P_aux (pat, annot) + ); + } + in let pat = (fold_pat name_vector_concat_roots pat) false in (* introduce names for all unnamed child nodes of P_vector_concat *) let name_vector_concat_elements = let p_vector_concat pats = - let rec aux ((P_aux (p,((l,_) as a))) as pat) = match p with - | P_vector _ -> P_aux (P_as (pat, fresh_id_v l),a) + let rec aux (P_aux (p, ((l, _) as a)) as pat) = + match p with + | P_vector _ -> P_aux (P_as (pat, fresh_id_v l), a) | P_lit _ -> P_aux (P_as (pat, fresh_id_v l), a) - | P_id id -> P_aux (P_id id,a) - | P_as (p,id) -> P_aux (P_as (p,id),a) - | P_typ (typ, pat) -> P_aux (P_typ (typ, aux pat),a) - | P_wild -> P_aux (P_wild,a) - | P_app (id, pats) when Env.is_mapping id (env_of_annot a) -> - P_aux (P_app (id, List.map aux pats), a) + | P_id id -> P_aux (P_id id, a) + | P_as (p, id) -> P_aux (P_as (p, id), a) + | P_typ (typ, pat) -> P_aux (P_typ (typ, aux pat), a) + | P_wild -> P_aux (P_wild, a) + | P_app (id, pats) when Env.is_mapping id (env_of_annot a) -> P_aux (P_app (id, List.map aux pats), a) | _ -> - raise - (Reporting.err_unreachable - l __POS__ "name_vector_concat_elements: Non-vector in vector-concat pattern") in - P_vector_concat (List.map aux pats) in - {id_pat_alg with p_vector_concat = p_vector_concat} in + raise + (Reporting.err_unreachable l __POS__ "name_vector_concat_elements: Non-vector in vector-concat pattern") + in + P_vector_concat (List.map aux pats) + in + { id_pat_alg with p_vector_concat } + in let pat = fold_pat name_vector_concat_elements pat in let rec tag_last = function - | x :: xs -> let is_last = xs = [] in (x,is_last) :: tag_last xs - | _ -> [] in + | x :: xs -> + let is_last = xs = [] in + (x, is_last) :: tag_last xs + | _ -> [] + in (* remove names from vectors in vector_concat patterns and collect them as declarations for the function body or expression *) let unname_vector_concat_elements = (* build a let-expression of the form "let child = root[i..j] in body" *) - let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) = - let (l,_) = cannot in + let letbind_vec typ_opt (rootid, rannot) (child, cannot) (i, j) = + let l, _ = cannot in let env = env_of_annot rannot in let rootname = string_of_id rootid in let childname = string_of_id child in @@ -430,127 +428,158 @@ let remove_vector_concat_pat pat = let id_pat = match typ_opt with - | Some typ -> add_p_typ env typ (P_aux (P_id child,cannot)) - | None -> P_aux (P_id child,cannot) in - let letbind = LB_aux (LB_val (id_pat,subv),cannot) in - (letbind, - (fun body -> - if IdSet.mem child (find_used_vars body) - then annot_exp (E_let (letbind,body)) l env (typ_of body) - else body), - (rootname,childname)) in + | Some typ -> add_p_typ env typ (P_aux (P_id child, cannot)) + | None -> P_aux (P_id child, cannot) + in + let letbind = LB_aux (LB_val (id_pat, subv), cannot) in + ( letbind, + (fun body -> + if IdSet.mem child (find_used_vars body) then annot_exp (E_let (letbind, body)) l env (typ_of body) else body + ), + (rootname, childname) + ) + in let p_aux = function - | ((P_as (P_aux (P_vector_concat pats,rannot'),rootid),decls),rannot) -> - let rtyp = Env.base_typ_of (env_of_annot rannot') (typ_of_annot rannot') in - let (start,last_idx) = (match vector_start_index rtyp, vector_typ_args_of rtyp with - | Nexp_aux (Nexp_constant start,_), (Nexp_aux (Nexp_constant length,_), ord, _) -> - (start, if is_order_inc ord - then Big_int.sub (Big_int.add start length) (Big_int.of_int 1) - else Big_int.add (Big_int.sub start length) (Big_int.of_int 1)) - | _ -> - Reporting.unreachable (fst rannot') __POS__ - "unname_vector_concat_elements: vector of unspecified length in vector-concat pattern") in - let rec aux typ_opt (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = - let ctyp = Env.base_typ_of (env_of_annot cannot) (typ_of_annot cannot) in - let (length,ord,_) = vector_typ_args_of ctyp in - let (pos',index_j) = match Type_check.solve_unique (env_of_annot cannot) length with - | Some i -> - if is_order_inc ord - then (Big_int.add pos i, Big_int.sub (Big_int.add pos i) (Big_int.of_int 1)) - else (Big_int.sub pos i, Big_int.add (Big_int.sub pos i) (Big_int.of_int 1)) - | None -> - if is_last then (pos,last_idx) - else - Reporting.unreachable - (fst cannot) __POS__ ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern") in - (match p with + | (P_as (P_aux (P_vector_concat pats, rannot'), rootid), decls), rannot -> + let rtyp = Env.base_typ_of (env_of_annot rannot') (typ_of_annot rannot') in + let start, last_idx = + match (vector_start_index rtyp, vector_typ_args_of rtyp) with + | Nexp_aux (Nexp_constant start, _), (Nexp_aux (Nexp_constant length, _), ord, _) -> + ( start, + if is_order_inc ord then Big_int.sub (Big_int.add start length) (Big_int.of_int 1) + else Big_int.add (Big_int.sub start length) (Big_int.of_int 1) + ) + | _ -> + Reporting.unreachable (fst rannot') __POS__ + "unname_vector_concat_elements: vector of unspecified length in vector-concat pattern" + in + let rec aux typ_opt (pos, pat_acc, decl_acc) (P_aux (p, cannot), is_last) = + let ctyp = Env.base_typ_of (env_of_annot cannot) (typ_of_annot cannot) in + let length, ord, _ = vector_typ_args_of ctyp in + let pos', index_j = + match Type_check.solve_unique (env_of_annot cannot) length with + | Some i -> + if is_order_inc ord then (Big_int.add pos i, Big_int.sub (Big_int.add pos i) (Big_int.of_int 1)) + else (Big_int.sub pos i, Big_int.add (Big_int.sub pos i) (Big_int.of_int 1)) + | None -> + if is_last then (pos, last_idx) + else + Reporting.unreachable (fst cannot) __POS__ + "unname_vector_concat_elements: vector of unspecified length in vector-concat pattern" + in + match p with (* if we see a named vector pattern, remove the name and remember to - declare it later *) - | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in - (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) + declare it later *) + | P_as (P_aux (p, cannot), cname) -> + let lb, decl, info = letbind_vec typ_opt (rootid, rannot) (cname, cannot) (pos, index_j) in + (pos', pat_acc @ [P_aux (p, cannot)], decl_acc @ [((lb, decl), info)]) (* if we see a P_id variable, remember to declare it later *) | P_id cname -> - let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in - (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) - | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last) + let lb, decl, info = letbind_vec typ_opt (rootid, rannot) (cname, cannot) (pos, index_j) in + (pos', pat_acc @ [P_aux (P_id cname, cannot)], decl_acc @ [((lb, decl), info)]) + | P_typ (typ, pat) -> aux (Some typ) (pos, pat_acc, decl_acc) (pat, is_last) (* | P_app (cname, pats) if Env.is_mapping cname (en) -> * let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in * (pos', pat_acc @ [P_aux (P_app (cname,pats),cannot)], decl_acc @ [((lb,decl),info)]) *) (* normal vector patterns are fine *) - | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc)) in + | _ -> (pos', pat_acc @ [P_aux (p, cannot)], decl_acc) + in let pats_tagged = tag_last pats in - let (_,pats',decls') = List.fold_left (aux None) (start,[],[]) pats_tagged in + let _, pats', decls' = List.fold_left (aux None) (start, [], []) pats_tagged in (* abuse P_vector_concat as a P_vector_const pattern: it has the of - patterns as an argument but they're meant to be consed together *) - (P_aux (P_as (P_aux (P_vector_concat pats',rannot'),rootid),rannot), decls @ decls') - | ((p,decls),annot) -> (P_aux (p,annot),decls) in - - { p_lit = (fun lit -> (P_lit lit,[])) - ; p_wild = (P_wild,[]) - ; p_or = (fun ((pat1, ds1), (pat2, ds2)) -> (P_or(pat1, pat2), ds1 @ ds2)) - ; p_not = (fun (pat, ds) -> (P_not(pat), ds)) - ; p_as = (fun ((pat,decls),id) -> (P_as (pat,id),decls)) - ; p_typ = (fun (typ,(pat,decls)) -> (P_typ (typ,pat),decls)) - ; p_id = (fun id -> (P_id id,[])) - ; p_var = (fun ((pat,decls),kid) -> (P_var (pat,kid),decls)) - ; p_app = (fun (id,ps) -> let (ps,decls) = List.split ps in - (P_app (id,ps),List.flatten decls)) - ; p_vector = (fun ps -> let (ps,decls) = List.split ps in - (P_vector ps,List.flatten decls)) - ; p_vector_concat = (fun ps -> let (ps,decls) = List.split ps in - (P_vector_concat ps,List.flatten decls)) - ; p_vector_subrange = (fun (id, n, m) -> (P_vector_subrange (id, n, m), [])) - ; p_tuple = (fun ps -> let (ps,decls) = List.split ps in - (P_tuple ps,List.flatten decls)) - ; p_list = (fun ps -> let (ps,decls) = List.split ps in - (P_list ps,List.flatten decls)) - ; p_string_append = (fun ps -> let (ps,decls) = List.split ps in - (P_string_append ps,List.flatten decls)) - ; p_cons = (fun ((p,decls),(p',decls')) -> (P_cons (p,p'), decls @ decls')) - ; p_aux = (fun ((pat,decls),annot) -> p_aux ((pat,decls),annot)) - } in - - let (pat,decls) = fold_pat unname_vector_concat_elements pat in + patterns as an argument but they're meant to be consed together *) + (P_aux (P_as (P_aux (P_vector_concat pats', rannot'), rootid), rannot), decls @ decls') + | (p, decls), annot -> (P_aux (p, annot), decls) + in + + { + p_lit = (fun lit -> (P_lit lit, [])); + p_wild = (P_wild, []); + p_or = (fun ((pat1, ds1), (pat2, ds2)) -> (P_or (pat1, pat2), ds1 @ ds2)); + p_not = (fun (pat, ds) -> (P_not pat, ds)); + p_as = (fun ((pat, decls), id) -> (P_as (pat, id), decls)); + p_typ = (fun (typ, (pat, decls)) -> (P_typ (typ, pat), decls)); + p_id = (fun id -> (P_id id, [])); + p_var = (fun ((pat, decls), kid) -> (P_var (pat, kid), decls)); + p_app = + (fun (id, ps) -> + let ps, decls = List.split ps in + (P_app (id, ps), List.flatten decls) + ); + p_vector = + (fun ps -> + let ps, decls = List.split ps in + (P_vector ps, List.flatten decls) + ); + p_vector_concat = + (fun ps -> + let ps, decls = List.split ps in + (P_vector_concat ps, List.flatten decls) + ); + p_vector_subrange = (fun (id, n, m) -> (P_vector_subrange (id, n, m), [])); + p_tuple = + (fun ps -> + let ps, decls = List.split ps in + (P_tuple ps, List.flatten decls) + ); + p_list = + (fun ps -> + let ps, decls = List.split ps in + (P_list ps, List.flatten decls) + ); + p_string_append = + (fun ps -> + let ps, decls = List.split ps in + (P_string_append ps, List.flatten decls) + ); + p_cons = (fun ((p, decls), (p', decls')) -> (P_cons (p, p'), decls @ decls')); + p_aux = (fun ((pat, decls), annot) -> p_aux ((pat, decls), annot)); + } + in + + let pat, decls = fold_pat unname_vector_concat_elements pat in (* We need to put the decls in the right order so letbinds are generated correctly for nested patterns *) - let module G = Graph.Make(String) in + let module G = Graph.Make (String) in let root_graph = List.fold_left (fun g (_, (root_id, child_id)) -> G.add_edge root_id child_id g) G.empty decls in let root_order = G.topsort root_graph in let find_root root_id = - try List.find (fun (_, (root_id', _)) -> root_id = root_id') decls with - | Not_found -> - (* If it's not a root then it's a leaf node in the graph, so search for child_id *) - try List.find (fun (_, (_, child_id)) -> root_id = child_id) decls with - | Not_found -> assert false (* Should never happen *) + try List.find (fun (_, (root_id', _)) -> root_id = root_id') decls + with Not_found -> ( + (* If it's not a root then it's a leaf node in the graph, so search for child_id *) + try List.find (fun (_, (_, child_id)) -> root_id = child_id) decls + with Not_found -> assert false (* Should never happen *) + ) in let decls = List.map find_root root_order in - let (letbinds,decls) = + let letbinds, decls = let decls = List.map fst decls in - List.split decls in + List.split decls + in let decls = List.fold_left (fun f g x -> f (g x)) (fun b -> b) decls in (* Finally we patch up the top location for the expressions wrapped by decls, otherwise this can cause the coverage instrumentation - to get super confused by the generated locations *) + to get super confused by the generated locations *) let decls (E_aux (_, (l, _)) as exp) = - let E_aux (aux, (_, annot)) = decls exp in - E_aux (aux, (gen_loc l, annot)) in - + let (E_aux (aux, (_, annot))) = decls exp in + E_aux (aux, (gen_loc l, annot)) + in + (* at this point shouldn't have P_as patterns in P_vector_concat patterns any more, all P_as and P_id vectors should have their declarations in decls. Now flatten all vector_concat patterns *) let flatten = let p_vector_concat ps = - let aux p acc = match p with - | (P_aux (P_vector_concat pats,_)) -> pats @ acc - | pat -> pat :: acc in - P_vector_concat (List.fold_right aux ps []) in - {id_pat_alg with p_vector_concat = p_vector_concat} in + let aux p acc = match p with P_aux (P_vector_concat pats, _) -> pats @ acc | pat -> pat :: acc in + P_vector_concat (List.fold_right aux ps []) + in + { id_pat_alg with p_vector_concat } + in let pat = fold_pat flatten pat in @@ -558,29 +587,31 @@ let remove_vector_concat_pat pat = with vector_concats patterns as direct child-nodes anymore *) let range a b = let rec aux a b = if Big_int.greater a b then [] else a :: aux (Big_int.add a (Big_int.of_int 1)) b in - if Big_int.greater a b then List.rev (aux b a) else aux a b in + if Big_int.greater a b then List.rev (aux b a) else aux a b + in let remove_vector_concats = let p_vector_concat ps = - let aux acc (P_aux (p,annot),is_last) = + let aux acc (P_aux (p, annot), is_last) = let env = env_of_annot annot in let typ = Env.base_typ_of env (typ_of_annot annot) in - let (l,_) = annot in - let wild _ = P_aux (P_wild,(gen_loc l, mk_tannot env bit_typ)) in - if is_vector_typ typ || is_bitvector_typ typ then - match p, vector_typ_args_of typ with + let l, _ = annot in + let wild _ = P_aux (P_wild, (gen_loc l, mk_tannot env bit_typ)) in + if is_vector_typ typ || is_bitvector_typ typ then ( + match (p, vector_typ_args_of typ) with | P_vector ps, _ -> acc @ ps - | _, (nexp, _, _) -> - begin match Type_check.solve_unique env nexp with - | Some length -> - acc @ (List.map wild (range Big_int.zero (Big_int.sub length (Big_int.of_int 1)))) - | None -> - acc @ [wild Big_int.zero] - end - else raise - (Reporting.err_unreachable l __POS__ - ("remove_vector_concats: Non-vector in vector-concat pattern " ^ - string_of_typ (typ_of_annot annot))) in + | _, (nexp, _, _) -> begin + match Type_check.solve_unique env nexp with + | Some length -> acc @ List.map wild (range Big_int.zero (Big_int.sub length (Big_int.of_int 1))) + | None -> acc @ [wild Big_int.zero] + end + ) + else + raise + (Reporting.err_unreachable l __POS__ + ("remove_vector_concats: Non-vector in vector-concat pattern " ^ string_of_typ (typ_of_annot annot)) + ) + in let ps_tagged = tag_last ps in let ps' = List.fold_left aux [] ps_tagged in @@ -588,196 +619,185 @@ let remove_vector_concat_pat pat = P_vector ps' in - {id_pat_alg with p_vector_concat = p_vector_concat} in + { id_pat_alg with p_vector_concat } + in let pat = fold_pat remove_vector_concats pat in - (pat,letbinds,decls) + (pat, letbinds, decls) (* assumes there are no more E_internal expressions *) -let rewrite_exp_remove_vector_concat_pat rewriters (E_aux (exp,(l,annot)) as full_exp) = - let rewrap e = E_aux (e,(l,annot)) in +let rewrite_exp_remove_vector_concat_pat rewriters (E_aux (exp, (l, annot)) as full_exp) = + let rewrap e = E_aux (e, (l, annot)) in let rewrite_rec = rewriters.rewrite_exp rewriters in let rewrite_base = rewrite_exp rewriters in match exp with - | E_match (e,ps) -> - let aux = function - | (Pat_aux (Pat_exp (pat,body),annot')) -> - let (pat,_,decls) = remove_vector_concat_pat pat in - Pat_aux (Pat_exp (pat, decls (rewrite_rec body)),annot') - | (Pat_aux (Pat_when (pat,guard,body),annot')) -> - let (pat,_,decls) = remove_vector_concat_pat pat in - Pat_aux (Pat_when (pat, decls (rewrite_rec guard), decls (rewrite_rec body)),annot') in - rewrap (E_match (rewrite_rec e, List.map aux ps)) - | E_let (LB_aux (LB_val (pat,v),annot'),body) -> - let (pat,_,decls) = remove_vector_concat_pat pat in - rewrap (E_let (LB_aux (LB_val (pat,rewrite_rec v),annot'), - decls (rewrite_rec body))) + | E_match (e, ps) -> + let aux = function + | Pat_aux (Pat_exp (pat, body), annot') -> + let pat, _, decls = remove_vector_concat_pat pat in + Pat_aux (Pat_exp (pat, decls (rewrite_rec body)), annot') + | Pat_aux (Pat_when (pat, guard, body), annot') -> + let pat, _, decls = remove_vector_concat_pat pat in + Pat_aux (Pat_when (pat, decls (rewrite_rec guard), decls (rewrite_rec body)), annot') + in + rewrap (E_match (rewrite_rec e, List.map aux ps)) + | E_let (LB_aux (LB_val (pat, v), annot'), body) -> + let pat, _, decls = remove_vector_concat_pat pat in + rewrap (E_let (LB_aux (LB_val (pat, rewrite_rec v), annot'), decls (rewrite_rec body))) | exp -> rewrite_base full_exp -let rewrite_fun_remove_vector_concat_pat - rewriters (FD_aux (FD_function(recopt,tannotopt,funcls),(l,fdannot))) = - let rewrite_funcl (FCL_aux (FCL_funcl(id,pexp),(l,annot))) = - let pat,guard,exp,pannot = destruct_pexp pexp in - let (pat',_,decls) = remove_vector_concat_pat pat in - let guard' = match guard with - | Some exp -> Some (decls (rewriters.rewrite_exp rewriters exp)) - | None -> None in +let rewrite_fun_remove_vector_concat_pat rewriters (FD_aux (FD_function (recopt, tannotopt, funcls), (l, fdannot))) = + let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), (l, annot))) = + let pat, guard, exp, pannot = destruct_pexp pexp in + let pat', _, decls = remove_vector_concat_pat pat in + let guard' = match guard with Some exp -> Some (decls (rewriters.rewrite_exp rewriters exp)) | None -> None in let exp' = decls (rewriters.rewrite_exp rewriters exp) in - let pexp' = construct_pexp (pat',guard',exp',pannot) in - (FCL_aux (FCL_funcl (id,pexp'),(l,annot))) - in FD_aux (FD_function(recopt,tannotopt,List.map rewrite_funcl funcls),(l,fdannot)) + let pexp' = construct_pexp (pat', guard', exp', pannot) in + FCL_aux (FCL_funcl (id, pexp'), (l, annot)) + in + FD_aux (FD_function (recopt, tannotopt, List.map rewrite_funcl funcls), (l, fdannot)) let rewrite_ast_remove_vector_concat env ast = let rewriters = - {rewrite_exp = rewrite_exp_remove_vector_concat_pat; - rewrite_pat = rewrite_pat; - rewrite_let = rewrite_let; - rewrite_lexp = rewrite_lexp; - rewrite_fun = rewrite_fun_remove_vector_concat_pat; - rewrite_def = rewrite_def; - rewrite_ast = rewrite_ast_base} in + { + rewrite_exp = rewrite_exp_remove_vector_concat_pat; + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun = rewrite_fun_remove_vector_concat_pat; + rewrite_def; + rewrite_ast = rewrite_ast_base; + } + in let rewrite_def d = let d = rewriters.rewrite_def rewriters d in match d with - | DEF_aux (DEF_let (LB_aux (LB_val (pat,exp),a)), def_annot) -> - let (pat,letbinds,_) = remove_vector_concat_pat pat in - let defvals = List.map (fun lb -> DEF_aux (DEF_let lb, mk_def_annot (gen_loc def_annot.loc))) letbinds in - [DEF_aux (DEF_let (LB_aux (LB_val (pat,exp),a)), def_annot)] @ defvals - | d -> [d] in + | DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), a)), def_annot) -> + let pat, letbinds, _ = remove_vector_concat_pat pat in + let defvals = List.map (fun lb -> DEF_aux (DEF_let lb, mk_def_annot (gen_loc def_annot.loc))) letbinds in + [DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), a)), def_annot)] @ defvals + | d -> [d] + in { ast with defs = List.flatten (List.map rewrite_def ast.defs) } (* A few helper functions for rewriting guarded pattern clauses. Used both by the rewriting of P_when and separately by the rewriting of bitvectors in parameter patterns of function clauses *) -let remove_wildcards pre (P_aux (_,(l,_)) as pat) = +let remove_wildcards pre (P_aux (_, (l, _)) as pat) = fold_pat - {id_pat_alg with - p_aux = function - | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot)) - | (p,annot) -> P_aux (p,annot) } + { + id_pat_alg with + p_aux = (function P_wild, (l, annot) -> P_aux (P_id (fresh_id pre l), (l, annot)) | p, annot -> P_aux (p, annot)); + } pat -let rec is_irrefutable_pattern (P_aux (p,ann)) = +let rec is_irrefutable_pattern (P_aux (p, ann)) = match p with - | P_lit (L_aux (L_unit,_)) - | P_wild - -> true - | P_or(pat1, pat2) -> is_irrefutable_pattern pat1 && is_irrefutable_pattern pat2 - | P_not(pat) -> is_irrefutable_pattern pat + | P_lit (L_aux (L_unit, _)) | P_wild -> true + | P_or (pat1, pat2) -> is_irrefutable_pattern pat1 && is_irrefutable_pattern pat2 + | P_not pat -> is_irrefutable_pattern pat | P_lit _ -> false - | P_as (p1,_) - | P_typ (_,p1) - -> is_irrefutable_pattern p1 + | P_as (p1, _) | P_typ (_, p1) -> is_irrefutable_pattern p1 | P_vector_subrange _ -> true | P_id id -> begin - match Env.lookup_id id (env_of_annot ann) with - | Local _ | Unbound _ -> true - | Register _ -> false (* should be impossible, anyway *) - | Enum enum -> - match enum with - | Typ_aux (Typ_id enum_id,_) -> - List.length (Env.get_enum enum_id (env_of_annot ann)) <= 1 - | _ -> false (* should be impossible, anyway *) - end - | P_var (p1,_) -> is_irrefutable_pattern p1 - | P_app (f,args) -> - Env.is_singleton_union_constructor f (env_of_annot ann) && - List.for_all is_irrefutable_pattern args - | P_vector ps - | P_vector_concat ps - | P_tuple ps - | P_list ps - -> List.for_all is_irrefutable_pattern ps - | P_cons (p1,p2) -> is_irrefutable_pattern p1 && is_irrefutable_pattern p2 - | P_string_append ps - -> List.for_all is_irrefutable_pattern ps + match Env.lookup_id id (env_of_annot ann) with + | Local _ | Unbound _ -> true + | Register _ -> false (* should be impossible, anyway *) + | Enum enum -> ( + match enum with + | Typ_aux (Typ_id enum_id, _) -> List.length (Env.get_enum enum_id (env_of_annot ann)) <= 1 + | _ -> false (* should be impossible, anyway *) + ) + end + | P_var (p1, _) -> is_irrefutable_pattern p1 + | P_app (f, args) -> + Env.is_singleton_union_constructor f (env_of_annot ann) && List.for_all is_irrefutable_pattern args + | P_vector ps | P_vector_concat ps | P_tuple ps | P_list ps -> List.for_all is_irrefutable_pattern ps + | P_cons (p1, p2) -> is_irrefutable_pattern p1 && is_irrefutable_pattern p2 + | P_string_append ps -> List.for_all is_irrefutable_pattern ps (* Check if one pattern subsumes the other, and if so, calculate a substitution of variables that are used in the same position. TODO: Check somewhere that there are no variable clashes (the same variable name used in different positions of the patterns) - *) -let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = - let rewrap p = P_aux (p,annot1) in +*) +let rec subsumes_pat (P_aux (p1, annot1) as pat1) (P_aux (p2, annot2) as pat2) = + let rewrap p = P_aux (p, annot1) in let subsumes_list pats1 pats2 = - if List.length pats1 = List.length pats2 - then + if List.length pats1 = List.length pats2 then ( let subs = List.map2 subsumes_pat pats1 pats2 in List.fold_right - (fun p acc -> match p, acc with - | Some subst, Some substs -> Some (subst @ substs) - | _ -> None) + (fun p acc -> match (p, acc) with Some subst, Some substs -> Some (subst @ substs) | _ -> None) subs (Some []) - else None in - match p1, p2 with - | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) -> - if lit1 = lit2 then Some [] else None - | P_or(pat1, pat2), _ -> (* todo: possibly not the right answer *) None - | _, P_or(pat1, pat2) -> (* todo: possibly not the right answer *) None - | P_not(pat), _ -> (* todo: possibly not the right answer *) None - | _, P_not(pat) -> (* todo: possibly not the right answer *) None + ) + else None + in + match (p1, p2) with + | P_lit (L_aux (lit1, _)), P_lit (L_aux (lit2, _)) -> if lit1 = lit2 then Some [] else None + | P_or (pat1, pat2), _ -> (* todo: possibly not the right answer *) None + | _, P_or (pat1, pat2) -> (* todo: possibly not the right answer *) None + | P_not pat, _ -> (* todo: possibly not the right answer *) None + | _, P_not pat -> (* todo: possibly not the right answer *) None | P_as (pat1, id1), _ -> - (* Abuse subsumes_list to check that both the nested pattern and the - * variable binding can subsume the other pattern *) - subsumes_list [P_aux (P_id id1, annot1); pat1] [pat2; pat2] + (* Abuse subsumes_list to check that both the nested pattern and the + * variable binding can subsume the other pattern *) + subsumes_list [P_aux (P_id id1, annot1); pat1] [pat2; pat2] | _, P_as (pat2, id2) -> - (* Ditto for the other direction *) - subsumes_list [pat1; pat1] [P_aux (P_id id2, annot2); pat2] - | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2 - | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2 - | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) -> - if id1 = id2 then Some [] else - if is_unbound (Env.lookup_id aid1 (env_of_annot annot1)) then - if is_unbound (Env.lookup_id aid2 (env_of_annot annot2)) then Some [(id2,id1)] else Some [] - else None - | P_id id1, _ -> - if is_unbound (Env.lookup_id id1 (env_of_annot annot1)) then Some [] else None - | P_var (pat1,_), P_var (pat2,_) -> subsumes_pat pat1 pat2 + (* Ditto for the other direction *) + subsumes_list [pat1; pat1] [P_aux (P_id id2, annot2); pat2] + | P_typ (_, pat1), _ -> subsumes_pat pat1 pat2 + | _, P_typ (_, pat2) -> subsumes_pat pat1 pat2 + | P_id (Id_aux (id1, _) as aid1), P_id (Id_aux (id2, _) as aid2) -> + if id1 = id2 then Some [] + else if is_unbound (Env.lookup_id aid1 (env_of_annot annot1)) then + if is_unbound (Env.lookup_id aid2 (env_of_annot annot2)) then Some [(id2, id1)] else Some [] + else None + | P_id id1, _ -> if is_unbound (Env.lookup_id id1 (env_of_annot annot1)) then Some [] else None + | P_var (pat1, _), P_var (pat2, _) -> subsumes_pat pat1 pat2 | P_wild, _ -> Some [] - | P_app (Id_aux (id1,_),args1), P_app (Id_aux (id2,_),args2) -> - if id1 = id2 then subsumes_list args1 args2 else None + | P_app (Id_aux (id1, _), args1), P_app (Id_aux (id2, _), args2) -> + if id1 = id2 then subsumes_list args1 args2 else None | P_vector pats1, P_vector pats2 | P_vector_concat pats1, P_vector_concat pats2 | P_tuple pats1, P_tuple pats2 | P_list pats1, P_list pats2 -> - subsumes_list pats1 pats2 - | P_list (pat1 :: pats1), P_cons _ -> - subsumes_pat (rewrap (P_cons (pat1, rewrap (P_list pats1)))) pat2 - | P_cons _, P_list (pat2 :: pats2)-> - subsumes_pat pat1 (rewrap (P_cons (pat2, rewrap (P_list pats2)))) - | P_cons (pat1, pats1), P_cons (pat2, pats2) -> - (match subsumes_pat pat1 pat2, subsumes_pat pats1 pats2 with - | Some substs1, Some substs2 -> Some (substs1 @ substs2) - | _ -> None) + subsumes_list pats1 pats2 + | P_list (pat1 :: pats1), P_cons _ -> subsumes_pat (rewrap (P_cons (pat1, rewrap (P_list pats1)))) pat2 + | P_cons _, P_list (pat2 :: pats2) -> subsumes_pat pat1 (rewrap (P_cons (pat2, rewrap (P_list pats2)))) + | P_cons (pat1, pats1), P_cons (pat2, pats2) -> ( + match (subsumes_pat pat1 pat2, subsumes_pat pats1 pats2) with + | Some substs1, Some substs2 -> Some (substs1 @ substs2) + | _ -> None + ) | _, P_wild -> if is_irrefutable_pattern pat1 then Some [] else None | _ -> None let vector_string_to_bits_pat (L_aux (lit, _) as l_aux) (l, tannot) = - let bit_annot = match destruct_tannot tannot with - | Some (env, _) -> mk_tannot env bit_typ - | None -> empty_tannot - in - begin match lit with - | L_hex _ | L_bin _ -> P_aux (P_vector (List.map (fun p -> P_aux (P_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot)) - | lit -> P_aux (P_lit l_aux, (l, tannot)) + let bit_annot = match destruct_tannot tannot with Some (env, _) -> mk_tannot env bit_typ | None -> empty_tannot in + begin + match lit with + | L_hex _ | L_bin _ -> + P_aux + (P_vector (List.map (fun p -> P_aux (P_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot)) + | lit -> P_aux (P_lit l_aux, (l, tannot)) end let vector_string_to_bits_exp (L_aux (lit, _) as l_aux) (l, tannot) = - let bit_annot = match destruct_tannot tannot with - | Some (env, _) -> mk_tannot env bit_typ - | None -> empty_tannot - in - begin match lit with - | L_hex _ | L_bin _ -> E_aux (E_vector (List.map (fun p -> E_aux (E_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot)) - | lit -> E_aux (E_lit l_aux, (l, tannot)) + let bit_annot = match destruct_tannot tannot with Some (env, _) -> mk_tannot env bit_typ | None -> empty_tannot in + begin + match lit with + | L_hex _ | L_bin _ -> + E_aux + (E_vector (List.map (fun p -> E_aux (E_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot)) + | lit -> E_aux (E_lit l_aux, (l, tannot)) end (* A simple check for pattern disjointness; used for optimisation in the guarded pattern rewrite step *) -let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = - match p1, p2 with +let rec disjoint_pat env (P_aux (p1, annot1) as pat1) (P_aux (p2, annot2) as pat2) = + match (p1, p2) with | P_as (pat1, _), _ -> disjoint_pat env pat1 pat2 | _, P_as (pat2, _) -> disjoint_pat env pat1 pat2 | P_typ (_, pat1), _ -> disjoint_pat env pat1 pat2 @@ -788,68 +808,59 @@ let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) | _, P_id id when id_is_unbound id env -> false | P_id id1, P_id id2 -> Id.compare id1 id2 <> 0 | P_lit (L_aux ((L_bin _ | L_hex _), _) as lit), _ -> - disjoint_pat env (vector_string_to_bits_pat lit (Unknown, empty_tannot)) pat2 + disjoint_pat env (vector_string_to_bits_pat lit (Unknown, empty_tannot)) pat2 | _, P_lit (L_aux ((L_bin _ | L_hex _), _) as lit) -> - disjoint_pat env pat1 (vector_string_to_bits_pat lit (Unknown, empty_tannot)) - | P_lit (L_aux (L_num n1, _)), P_lit (L_aux (L_num n2, _)) -> - not (Big_int.equal n1 n2) - | P_lit (L_aux (l1, _)), P_lit (L_aux (l2, _)) -> - l1 <> l2 - | P_app (id1, args1), P_app (id2, args2) -> - Id.compare id1 id2 <> 0 || List.exists2 (disjoint_pat env) args1 args2 - | P_vector pats1, P_vector pats2 - | P_tuple pats1, P_tuple pats2 - | P_list pats1, P_list pats2 -> - List.length pats1 <> List.length pats2 || List.exists2 (disjoint_pat env) pats1 pats2 + disjoint_pat env pat1 (vector_string_to_bits_pat lit (Unknown, empty_tannot)) + | P_lit (L_aux (L_num n1, _)), P_lit (L_aux (L_num n2, _)) -> not (Big_int.equal n1 n2) + | P_lit (L_aux (l1, _)), P_lit (L_aux (l2, _)) -> l1 <> l2 + | P_app (id1, args1), P_app (id2, args2) -> Id.compare id1 id2 <> 0 || List.exists2 (disjoint_pat env) args1 args2 + | P_vector pats1, P_vector pats2 | P_tuple pats1, P_tuple pats2 | P_list pats1, P_list pats2 -> + List.length pats1 <> List.length pats2 || List.exists2 (disjoint_pat env) pats1 pats2 | _ -> false let equiv_pats pat1 pat2 = - match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with - | Some _, Some _ -> true - | _, _ -> false + match (subsumes_pat pat1 pat2, subsumes_pat pat2 pat1) with Some _, Some _ -> true | _, _ -> false -let subst_id_pat pat (id1,id2) = - let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in - fold_pat {id_pat_alg with p_id = p_id} pat +let subst_id_pat pat (id1, id2) = + let p_id (Id_aux (id, l)) = if id = id1 then P_id (Id_aux (id2, l)) else P_id (Id_aux (id, l)) in + fold_pat { id_pat_alg with p_id } pat -let subst_id_exp exp (id1,id2) = - Ast_util.subst (Id_aux (id1,Parse_ast.Unknown)) (E_aux (E_id (Id_aux (id2,Parse_ast.Unknown)),(Parse_ast.Unknown,empty_tannot))) exp +let subst_id_exp exp (id1, id2) = + Ast_util.subst + (Id_aux (id1, Parse_ast.Unknown)) + (E_aux (E_id (Id_aux (id2, Parse_ast.Unknown)), (Parse_ast.Unknown, empty_tannot))) + exp -let rec pat_to_exp ((P_aux (pat,(l,annot))) as p_aux) = - let rewrap e = E_aux (e,(l,annot)) in +let rec pat_to_exp (P_aux (pat, (l, annot)) as p_aux) = + let rewrap e = E_aux (e, (l, annot)) in let env = env_of_pat p_aux in let typ = typ_of_pat p_aux in match pat with | P_lit lit -> rewrap (E_lit lit) - | P_wild -> raise (Reporting.err_unreachable l __POS__ - "pat_to_exp given wildcard pattern") - | P_or(pat1, pat2) -> (* todo: insert boolean or *) pat_to_exp pat1 - | P_not(pat) -> (* todo: insert boolean not *) pat_to_exp pat - | P_as (pat,id) -> rewrap (E_id id) + | P_wild -> raise (Reporting.err_unreachable l __POS__ "pat_to_exp given wildcard pattern") + | P_or (pat1, pat2) -> (* todo: insert boolean or *) pat_to_exp pat1 + | P_not pat -> (* todo: insert boolean not *) pat_to_exp pat + | P_as (pat, id) -> rewrap (E_id id) | P_var (pat, _) -> pat_to_exp pat - | P_typ (_,pat) -> pat_to_exp pat + | P_typ (_, pat) -> pat_to_exp pat | P_id id -> rewrap (E_id id) | P_vector_subrange (id, n, m) -> - let subrange = mk_exp (E_vector_subrange (mk_exp (E_id id), mk_lit_exp (L_num n), mk_lit_exp (L_num m))) in - check_exp env subrange typ - | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) + let subrange = mk_exp (E_vector_subrange (mk_exp (E_id id), mk_lit_exp (L_num n), mk_lit_exp (L_num m))) in + check_exp env subrange typ + | P_app (id, pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats)) | P_vector_concat pats -> begin let empty_vec = E_aux (E_vector [], (l, empty_uannot)) in - let concat_vectors vec1 vec2 = - E_aux (E_vector_append (vec1, vec2), (l, empty_uannot)) - in + let concat_vectors vec1 vec2 = E_aux (E_vector_append (vec1, vec2), (l, empty_uannot)) in check_exp env (List.fold_right concat_vectors (List.map (fun p -> strip_exp (pat_to_exp p)) pats) empty_vec) typ end | P_tuple pats -> rewrap (E_tuple (List.map pat_to_exp pats)) | P_list pats -> rewrap (E_list (List.map pat_to_exp pats)) - | P_cons (p,ps) -> rewrap (E_cons (pat_to_exp p, pat_to_exp ps)) - | P_string_append (pats) -> begin + | P_cons (p, ps) -> rewrap (E_cons (pat_to_exp p, pat_to_exp ps)) + | P_string_append pats -> begin let empty_string = annot_exp (E_lit (L_aux (L_string "", l))) l env string_typ in - let string_append str1 str2 = - annot_exp (E_app (mk_id "string_append", [str1; str2])) l env string_typ - in - (List.fold_right string_append (List.map pat_to_exp pats) empty_string) + let string_append str1 str2 = annot_exp (E_app (mk_id "string_append", [str1; str2])) l env string_typ in + List.fold_right string_append (List.map pat_to_exp pats) empty_string end let case_exp e t cs = @@ -857,12 +868,11 @@ let case_exp e t cs = let env = env_of e in match cs with | [(P_aux (P_wild, _), body, _)] -> body - | [(P_aux (P_id id, pannot) as pat, body, _)] -> - annot_exp (E_let (LB_aux (LB_val (pat, e), pannot), body)) l env t + | [((P_aux (P_id id, pannot) as pat), body, _)] -> annot_exp (E_let (LB_aux (LB_val (pat, e), pannot), body)) l env t | _ -> - let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in - let ps = List.map pexp cs in - annot_exp (E_match (e,ps)) l env t + let pexp (pat, body, annot) = Pat_aux (Pat_exp (pat, body), annot) in + let ps = List.map pexp cs in + annot_exp (E_match (e, ps)) l env t module PC_config = struct type t = tannot @@ -870,106 +880,115 @@ module PC_config = struct let add_attribute l attr arg = map_uannot (add_attribute l attr arg) end -module PC = Pattern_completeness.Make(PC_config);; +module PC = Pattern_completeness.Make (PC_config) let pats_complete l env ps typ = - let ctx = { + let ctx = + { Pattern_completeness.variants = Env.get_variants env; Pattern_completeness.enums = Env.get_enums env; Pattern_completeness.constraints = Env.get_constraints env; - } in + } + in PC.is_complete l ctx ps typ - + (* Rewrite guarded patterns into a combination of if-expressions and - unguarded pattern matches - - Strategy: - - Split clauses into groups where the first pattern subsumes all the - following ones - - Translate the groups in reverse order, using the next group as a - fall-through target, if there is one - - Within a group, - - translate the sequence of clauses to an if-then-else cascade using the - guards as long as the patterns are equivalent modulo substitution, or - - recursively translate the remaining clauses to a pattern match if - there is a difference in the patterns. - - TODO: Compare this more closely with the algorithm in the CPP'18 paper of - Spector-Zabusky et al, who seem to use the opposite grouping and merging - strategy to ours: group *mutually exclusive* clauses, and try to merge them - into a pattern match first instead of an if-then-else cascade. + unguarded pattern matches + + Strategy: + - Split clauses into groups where the first pattern subsumes all the + following ones + - Translate the groups in reverse order, using the next group as a + fall-through target, if there is one + - Within a group, + - translate the sequence of clauses to an if-then-else cascade using the + guards as long as the patterns are equivalent modulo substitution, or + - recursively translate the remaining clauses to a pattern match if + there is a difference in the patterns. + + TODO: Compare this more closely with the algorithm in the CPP'18 paper of + Spector-Zabusky et al, who seem to use the opposite grouping and merging + strategy to ours: group *mutually exclusive* clauses, and try to merge them + into a pattern match first instead of an if-then-else cascade. *) -let rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ (cs : (tannot pat * tannot exp option * tannot exp * tannot clause_annot) list) = +let rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ + (cs : (tannot pat * tannot exp option * tannot exp * tannot clause_annot) list) = let annot_from_clause (def_annot, tannot) = (def_annot.loc, tannot) in let fix_fallthrough (pat, guard, exp, (l, tannot)) = (pat, guard, exp, (mk_def_annot l, tannot)) in - + let rec group fallthrough clauses = let add_clause (pat, cls, annot) c = (pat, cls @ [c], annot) in let rec group_aux current acc = function - | ((pat, guard, body, annot) as c) :: cs -> - let (current_pat,_,_) = current in - (match subsumes_pat current_pat pat with - | Some substs -> - let pat' = List.fold_left subst_id_pat pat substs in - let guard' = (match guard with - | Some exp -> Some (List.fold_left subst_id_exp exp substs) - | None -> None) in - let body' = List.fold_left subst_id_exp body substs in - let c' = (pat', guard', body', annot) in - group_aux (add_clause current c') acc cs - | None -> - let pat = match cs with _::_ -> remove_wildcards "g__" pat | _ -> pat in - group_aux (pat,[c], annot_from_clause annot) (acc @ [current]) cs) - | [] -> acc @ [current] in - let groups = match clauses with - | [(pat, guard, body, annot) as c] -> - [(pat, [c], annot_from_clause annot)] + | ((pat, guard, body, annot) as c) :: cs -> ( + let current_pat, _, _ = current in + match subsumes_pat current_pat pat with + | Some substs -> + let pat' = List.fold_left subst_id_pat pat substs in + let guard' = + match guard with Some exp -> Some (List.fold_left subst_id_exp exp substs) | None -> None + in + let body' = List.fold_left subst_id_exp body substs in + let c' = (pat', guard', body', annot) in + group_aux (add_clause current c') acc cs + | None -> + let pat = match cs with _ :: _ -> remove_wildcards "g__" pat | _ -> pat in + group_aux (pat, [c], annot_from_clause annot) (acc @ [current]) cs + ) + | [] -> acc @ [current] + in + let groups = + match clauses with + | [((pat, guard, body, annot) as c)] -> [(pat, [c], annot_from_clause annot)] | ((pat, guard, body, annot) as c) :: cs -> group_aux (remove_wildcards "g__" pat, [c], annot_from_clause annot) [] cs - | _ -> - raise (Reporting.err_unreachable l __POS__ - "group given empty list in rewrite_guarded_clauses") in - let add_group cs groups = (if_pexp (groups @ fallthrough) cs) :: groups in + | _ -> raise (Reporting.err_unreachable l __POS__ "group given empty list in rewrite_guarded_clauses") + in + let add_group cs groups = if_pexp (groups @ fallthrough) cs :: groups in List.fold_right add_group groups [] - - and if_pexp fallthrough (pat, cs, annot) = (match cs with + and if_pexp fallthrough (pat, cs, annot) = + match cs with | c :: _ -> let body = if_exp fallthrough pat cs in (pat, body, annot) - | [] -> - raise (Reporting.err_unreachable l __POS__ - "if_pexp given empty list in rewrite_guarded_clauses")) - + | [] -> raise (Reporting.err_unreachable l __POS__ "if_pexp given empty list in rewrite_guarded_clauses") and if_exp fallthrough current_pat = function - | (pat, guard, body, annot) :: ((pat', guard', body', annot') as c') :: cs -> - (match guard with - | Some exp -> - let else_exp = - if equiv_pats current_pat pat' - then if_exp fallthrough current_pat (c' :: cs) - else case_exp (pat_to_exp current_pat) (typ_of body') (group fallthrough (c' :: cs)) in - annot_exp (E_if (exp, body, else_exp)) (fst annot).loc (env_of exp) (typ_of body) - | None -> body) - | [(pat, guard, body, annot)] -> + | (pat, guard, body, annot) :: ((pat', guard', body', annot') as c') :: cs -> ( + match guard with + | Some exp -> + let else_exp = + if equiv_pats current_pat pat' then if_exp fallthrough current_pat (c' :: cs) + else case_exp (pat_to_exp current_pat) (typ_of body') (group fallthrough (c' :: cs)) + in + annot_exp (E_if (exp, body, else_exp)) (fst annot).loc (env_of exp) (typ_of body) + | None -> body + ) + | [(pat, guard, body, annot)] -> ( (* For singleton clauses with a guard, use fallthrough clauses if the guard is not satisfied, but only those fallthrough clauses that are not disjoint with the current pattern *) let overlapping_clause (pat, _, _) = not (disjoint_pat env current_pat pat) in let fallthrough = List.filter overlapping_clause fallthrough in - (match guard, fallthrough with - | Some exp, _ :: _ -> - let else_exp = case_exp (pat_to_exp current_pat) (typ_of body) fallthrough in - annot_exp (E_if (exp, body, else_exp)) (fst annot).loc (env_of exp) (typ_of body) - | _, _ -> body) - | [] -> - raise (Reporting.err_unreachable l __POS__ - "if_exp given empty list in rewrite_guarded_clauses") in - - let is_complete = pats_complete l env (List.map (fun (pat, guard, body, cl_annot) -> construct_pexp (pat, guard, body, annot_from_clause cl_annot)) cs) pat_typ in - let fallthrough = if not is_complete then [fix_fallthrough (destruct_pexp (mk_fallthrough l env pat_typ typ))] else [] in + match (guard, fallthrough) with + | Some exp, _ :: _ -> + let else_exp = case_exp (pat_to_exp current_pat) (typ_of body) fallthrough in + annot_exp (E_if (exp, body, else_exp)) (fst annot).loc (env_of exp) (typ_of body) + | _, _ -> body + ) + | [] -> raise (Reporting.err_unreachable l __POS__ "if_exp given empty list in rewrite_guarded_clauses") + in + + let is_complete = + pats_complete l env + (List.map (fun (pat, guard, body, cl_annot) -> construct_pexp (pat, guard, body, annot_from_clause cl_annot)) cs) + pat_typ + in + let fallthrough = + if not is_complete then [fix_fallthrough (destruct_pexp (mk_fallthrough l env pat_typ typ))] else [] + in group [] (cs @ fallthrough) -let rewrite_guarded_clauses mk_fallthrough l env pat_typ typ (cs : (tannot pat * tannot exp option * tannot exp * tannot annot) list) = +let rewrite_guarded_clauses mk_fallthrough l env pat_typ typ + (cs : (tannot pat * tannot exp option * tannot exp * tannot annot) list) = let map_clause_annot f cs = List.map (fun (pat, guard, body, annot) -> (pat, guard, body, f annot)) cs in let cs = map_clause_annot (fun (l, tannot) -> (mk_def_annot l, tannot)) cs in rewrite_toplevel_guarded_clauses mk_fallthrough l env pat_typ typ cs @@ -983,80 +1002,81 @@ let mk_pattern_match_failure_pexp l env pat_typ typ = construct_pexp (p, None, e, (gen_loc l, ann)) let mk_rethrow_pexp l env pat_typ typ = - let (p, env') = bind_pat_no_guard env (mk_pat (P_id (mk_id "e"))) pat_typ in + let p, env' = bind_pat_no_guard env (mk_pat (P_id (mk_id "e"))) pat_typ in let (E_aux (_, a) as e) = check_exp env' (mk_exp ~loc:(gen_loc l) (E_throw (mk_exp (E_id (mk_id "e"))))) typ in construct_pexp (p, None, e, a) let bitwise_and_exp exp1 exp2 = - let (E_aux (_,(l,_))) = exp1 in + let (E_aux (_, (l, _))) = exp1 in let andid = Id_aux (Id "and_bool", gen_loc l) in - annot_exp (E_app(andid,[exp1;exp2])) l (env_of exp1) bool_typ + annot_exp (E_app (andid, [exp1; exp2])) l (env_of exp1) bool_typ -let compose_guard_opt g1 g2 = match g1, g2 with +let compose_guard_opt g1 g2 = + match (g1, g2) with | Some g1, Some g2 -> Some (bitwise_and_exp g1 g2) | Some g1, None -> Some g1 | None, Some g2 -> Some g2 | None, None -> None -let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with -| P_lit _ | P_wild | P_id _ -> false -| P_vector_subrange _ -> true -| P_as (pat,_) | P_typ (_,pat) | P_var (pat,_) -> contains_bitvector_pat pat -| P_or(pat1, pat2) -> contains_bitvector_pat pat1 || contains_bitvector_pat pat2 -| P_not(pat) -> contains_bitvector_pat pat -| P_vector _ | P_vector_concat _ -> - let typ = Env.base_typ_of (env_of_annot annot) (typ_of_annot annot) in - is_bitvector_typ typ -| P_app (_,pats) | P_tuple pats | P_list pats -> - List.exists contains_bitvector_pat pats -| P_cons (p,ps) -> contains_bitvector_pat p || contains_bitvector_pat ps -| P_string_append (ps) -> List.exists contains_bitvector_pat ps +let rec contains_bitvector_pat (P_aux (pat, annot)) = + match pat with + | P_lit _ | P_wild | P_id _ -> false + | P_vector_subrange _ -> true + | P_as (pat, _) | P_typ (_, pat) | P_var (pat, _) -> contains_bitvector_pat pat + | P_or (pat1, pat2) -> contains_bitvector_pat pat1 || contains_bitvector_pat pat2 + | P_not pat -> contains_bitvector_pat pat + | P_vector _ | P_vector_concat _ -> + let typ = Env.base_typ_of (env_of_annot annot) (typ_of_annot annot) in + is_bitvector_typ typ + | P_app (_, pats) | P_tuple pats | P_list pats -> List.exists contains_bitvector_pat pats + | P_cons (p, ps) -> contains_bitvector_pat p || contains_bitvector_pat ps + | P_string_append ps -> List.exists contains_bitvector_pat ps let contains_bitvector_pexp = function -| Pat_aux (Pat_exp (pat,_),_) | Pat_aux (Pat_when (pat,_,_),_) -> - contains_bitvector_pat pat + | Pat_aux (Pat_exp (pat, _), _) | Pat_aux (Pat_when (pat, _, _), _) -> contains_bitvector_pat pat (* Rewrite bitvector patterns to guarded patterns *) let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = - - let env = try env_of_pat pat with _ -> raise (Reporting.err_unreachable l __POS__ "Pattern without annotation found") in + let env = + try env_of_pat pat with _ -> raise (Reporting.err_unreachable l __POS__ "Pattern without annotation found") + in (* first introduce names for bitvector patterns *) let name_bitvector_roots = - { p_lit = (fun lit -> P_lit lit) - ; p_typ = (fun (typ,p) -> P_typ (typ,p false)) - ; p_wild = P_wild - (* todo: I have no idea what the boolean parameter means - so I randomly - * passed "true". A comment to explain the bool might be a good idea? - *) - ; p_or = (fun (pat1, pat2) -> P_or (pat1 true, pat2 true)) - ; p_not = (fun pat -> P_not (pat true)) - ; p_as = (fun (pat,id) -> P_as (pat true,id)) - ; p_id = (fun id -> P_id id) - ; p_var = (fun (pat,kid) -> P_var (pat true,kid)) - ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps)) - ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)) - ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) - ; p_vector_subrange = (fun (id, n, m) -> P_vector_subrange (id, n, m)) - ; p_string_append = (fun ps -> P_string_append (List.map (fun p -> p false) ps)) - ; p_tuple = (fun ps -> P_tuple (List.map (fun p -> p false) ps)) - ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) - ; p_cons = (fun (p,ps) -> P_cons (p false, ps false)) - ; p_aux = - (fun (pat,annot) contained_in_p_as -> + { + p_lit = (fun lit -> P_lit lit); + p_typ = (fun (typ, p) -> P_typ (typ, p false)); + p_wild = + P_wild + (* todo: I have no idea what the boolean parameter means - so I randomly + * passed "true". A comment to explain the bool might be a good idea? + *); + p_or = (fun (pat1, pat2) -> P_or (pat1 true, pat2 true)); + p_not = (fun pat -> P_not (pat true)); + p_as = (fun (pat, id) -> P_as (pat true, id)); + p_id = (fun id -> P_id id); + p_var = (fun (pat, kid) -> P_var (pat true, kid)); + p_app = (fun (id, ps) -> P_app (id, List.map (fun p -> p false) ps)); + p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)); + p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)); + p_vector_subrange = (fun (id, n, m) -> P_vector_subrange (id, n, m)); + p_string_append = (fun ps -> P_string_append (List.map (fun p -> p false) ps)); + p_tuple = (fun ps -> P_tuple (List.map (fun p -> p false) ps)); + p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)); + p_cons = (fun (p, ps) -> P_cons (p false, ps false)); + p_aux = + (fun (pat, annot) contained_in_p_as -> let env = env_of_annot annot in let t = Env.base_typ_of env (typ_of_annot annot) in - let (l,_) = annot in - match pat, is_bitvector_typ t, contained_in_p_as with - | P_vector _, true, false -> - P_aux (P_as (P_aux (pat,annot),fresh_id "b__" l), annot) - | _ -> P_aux (pat,annot) - ) - } in - let pat, env = bind_pat_no_guard env - (strip_pat ((fold_pat name_bitvector_roots pat) false)) - (typ_of_pat pat) in + let l, _ = annot in + match (pat, is_bitvector_typ t, contained_in_p_as) with + | P_vector _, true, false -> P_aux (P_as (P_aux (pat, annot), fresh_id "b__" l), annot) + | _ -> P_aux (pat, annot) + ); + } + in + let pat, env = bind_pat_no_guard env (strip_pat ((fold_pat name_bitvector_roots pat) false)) (typ_of_pat pat) in (* Then collect guard expressions testing whether the literal bits of a bitvector pattern match those of a given bitvector, and collect let @@ -1067,265 +1087,291 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = let mk_num_exp i = mk_lit_exp (L_num i) in let check_eq_exp l r = let exp = mk_exp (E_app_infix (l, Id_aux (Operator "==", Parse_ast.Unknown), r)) in - check_exp (Env.no_casts env) exp bool_typ in + check_exp (Env.no_casts env) exp bool_typ + in let access_bit_exp rootid l typ idx = let access_aux = E_vector_access (mk_exp (E_id rootid), mk_num_exp idx) in - check_exp env (mk_exp access_aux) bit_typ in + check_exp env (mk_exp access_aux) bit_typ + in let test_subvec_exp rootid l typ i j lits = let start = vector_start_index typ in - let (length, ord, _) = vector_typ_args_of typ in + let length, ord, _ = vector_typ_args_of typ in let subvec_exp = - match start, length with + match (start, length) with | Nexp_aux (Nexp_constant s, _), Nexp_aux (Nexp_constant l, _) when Big_int.equal s i && Big_int.equal l (Big_int.of_int (List.length lits)) -> - mk_exp (E_id rootid) - | _ -> - mk_exp (E_vector_subrange (mk_exp (E_id rootid), mk_num_exp i, mk_num_exp j)) in - check_eq_exp subvec_exp (mk_exp (E_vector (List.map strip_exp lits))) in + mk_exp (E_id rootid) + | _ -> mk_exp (E_vector_subrange (mk_exp (E_id rootid), mk_num_exp i, mk_num_exp j)) + in + check_eq_exp subvec_exp (mk_exp (E_vector (List.map strip_exp lits))) + in let letbind_bit_exp rootid l typ idx id = let elem = access_bit_exp rootid l typ idx in let e = annot_pat (P_id id) l env bit_typ in - let letbind = LB_aux (LB_val (e,elem), (l, mk_tannot env bit_typ)) in - let letexp = (fun body -> - let (E_aux (_,(_,bannot))) = body in - if IdSet.mem id (find_used_vars body) - then annot_exp (E_let (letbind,body)) l env (typ_of body) - else body) in - (letexp, letbind) in + let letbind = LB_aux (LB_val (e, elem), (l, mk_tannot env bit_typ)) in + let letexp body = + let (E_aux (_, (_, bannot))) = body in + if IdSet.mem id (find_used_vars body) then annot_exp (E_let (letbind, body)) l env (typ_of body) else body + in + (letexp, letbind) + in let compose_guards guards = List.fold_right compose_guard_opt guards None in let flatten_guards_decls gd = - let (guards,decls,letbinds) = Util.split3 gd in - (compose_guards guards, (List.fold_right (@@) decls), List.flatten letbinds) in + let guards, decls, letbinds = Util.split3 gd in + (compose_guards guards, List.fold_right ( @@ ) decls, List.flatten letbinds) + in (* Collect guards and let bindings *) let guard_bitvector_pat = let collect_guards_decls ps rootid t = let start = vector_start_index t in - let (_,ord,_) = vector_typ_args_of t in - let start_idx = match start with - | Nexp_aux (Nexp_constant s, _) -> s - | _ -> - raise (Reporting.err_unreachable l __POS__ - "guard_bitvector_pat called on pattern with non-constant start index") in + let _, ord, _ = vector_typ_args_of t in + let start_idx = + match start with + | Nexp_aux (Nexp_constant s, _) -> s + | _ -> + raise + (Reporting.err_unreachable l __POS__ "guard_bitvector_pat called on pattern with non-constant start index") + in let add_bit_pat (idx, current, guards, dls) pat = let idx' = - if is_order_inc ord - then Big_int.add idx (Big_int.of_int 1) - else Big_int.sub idx (Big_int.of_int 1) in - let ids = fst (fold_pat - { (compute_pat_alg IdSet.empty IdSet.union) with - p_id = (fun id -> IdSet.singleton id, P_id id); - p_as = (fun ((ids, pat), id) -> IdSet.add id ids, P_as (pat, id)) } - pat) in - let lits = fst (fold_pat - { (compute_pat_alg [] (@)) with - p_aux = (fun ((lits, paux), (l, annot)) -> - let lits = match paux with - | P_lit lit -> E_aux (E_lit lit, (l, annot)) :: lits - | _ -> lits in - lits, P_aux (paux, (l, annot))) } - pat) in + if is_order_inc ord then Big_int.add idx (Big_int.of_int 1) else Big_int.sub idx (Big_int.of_int 1) + in + let ids = + fst + (fold_pat + { + (compute_pat_alg IdSet.empty IdSet.union) with + p_id = (fun id -> (IdSet.singleton id, P_id id)); + p_as = (fun ((ids, pat), id) -> (IdSet.add id ids, P_as (pat, id))); + } + pat + ) + in + let lits = + fst + (fold_pat + { + (compute_pat_alg [] ( @ )) with + p_aux = + (fun ((lits, paux), (l, annot)) -> + let lits = match paux with P_lit lit -> E_aux (E_lit lit, (l, annot)) :: lits | _ -> lits in + (lits, P_aux (paux, (l, annot))) + ); + } + pat + ) + in let add_letbind id dls = dls @ [letbind_bit_exp rootid l t idx id] in let dls' = IdSet.fold add_letbind ids dls in let current', guards' = match current with | Some (l, i, j, lits') -> - if lits = [] - then None, guards @ [Some (test_subvec_exp rootid l t i j lits')] - else Some (l, i, idx, lits' @ lits), guards - | None -> - begin - match lits with - | E_aux (_, (l, _)) :: _ -> Some (l, idx, idx, lits), guards - | [] -> None, guards - end - in - (idx', current', guards', dls') in - let (_, final, guards, dls) = List.fold_left add_bit_pat (start_idx, None, [], []) ps in - let guards = match final with - | Some (l,i,j,lits) -> - guards @ [Some (test_subvec_exp rootid l t i j lits)] - | None -> guards in - let (decls,letbinds) = List.split dls in - (compose_guards guards, List.fold_right (@@) decls, letbinds) in - - { p_lit = (fun lit -> (P_lit lit, (None, (fun b -> b), []))) - ; p_wild = (P_wild, (None, (fun b -> b), [])) - ; p_or = (fun ((pat1, gdl1), (pat2, gdl2)) -> - (P_or(pat1, pat2), flatten_guards_decls [gdl1; gdl2])) - ; p_not = (fun (pat, gdl) -> (P_not(pat), gdl)) - ; p_as = (fun ((pat,gdls),id) -> (P_as (pat,id), gdls)) - ; p_typ = (fun (typ,(pat,gdls)) -> (P_typ (typ,pat), gdls)) - ; p_id = (fun id -> (P_id id, (None, (fun b -> b), []))) - ; p_var = (fun ((pat,gdls),kid) -> (P_var (pat,kid), gdls)) - ; p_app = (fun (id,ps) -> let (ps,gdls) = List.split ps in - (P_app (id,ps), flatten_guards_decls gdls)) - ; p_vector = (fun ps -> let (ps,gdls) = List.split ps in - (P_vector ps, flatten_guards_decls gdls)) - ; p_vector_concat = (fun ps -> let (ps,gdls) = List.split ps in - (P_vector_concat ps, flatten_guards_decls gdls)) - ; p_vector_subrange = (fun (id, n, m) -> (P_vector_subrange (id, n, m), (None, (fun b -> b), []))) - ; p_string_append = (fun ps -> let (ps,gdls) = List.split ps in - (P_string_append ps, flatten_guards_decls gdls)) - ; p_tuple = (fun ps -> let (ps,gdls) = List.split ps in - (P_tuple ps, flatten_guards_decls gdls)) - ; p_list = (fun ps -> let (ps,gdls) = List.split ps in - (P_list ps, flatten_guards_decls gdls)) - ; p_cons = (fun ((p,gdls),(p',gdls')) -> - (P_cons (p,p'), flatten_guards_decls [gdls;gdls'])) - ; p_aux = (fun ((pat,gdls),annot) -> - let env = env_of_annot annot in - let t = Env.base_typ_of env (typ_of_annot annot) in - (match pat, is_bitvector_typ t with - | P_as (P_aux (P_vector ps, _), id), true -> - (P_aux (P_id id, annot), collect_guards_decls ps id t) - | _, _ -> (P_aux (pat,annot), gdls))) - } in + if lits = [] then (None, guards @ [Some (test_subvec_exp rootid l t i j lits')]) + else (Some (l, i, idx, lits' @ lits), guards) + | None -> begin + match lits with E_aux (_, (l, _)) :: _ -> (Some (l, idx, idx, lits), guards) | [] -> (None, guards) + end + in + (idx', current', guards', dls') + in + let _, final, guards, dls = List.fold_left add_bit_pat (start_idx, None, [], []) ps in + let guards = + match final with + | Some (l, i, j, lits) -> guards @ [Some (test_subvec_exp rootid l t i j lits)] + | None -> guards + in + let decls, letbinds = List.split dls in + (compose_guards guards, List.fold_right ( @@ ) decls, letbinds) + in + + { + p_lit = (fun lit -> (P_lit lit, (None, (fun b -> b), []))); + p_wild = (P_wild, (None, (fun b -> b), [])); + p_or = (fun ((pat1, gdl1), (pat2, gdl2)) -> (P_or (pat1, pat2), flatten_guards_decls [gdl1; gdl2])); + p_not = (fun (pat, gdl) -> (P_not pat, gdl)); + p_as = (fun ((pat, gdls), id) -> (P_as (pat, id), gdls)); + p_typ = (fun (typ, (pat, gdls)) -> (P_typ (typ, pat), gdls)); + p_id = (fun id -> (P_id id, (None, (fun b -> b), []))); + p_var = (fun ((pat, gdls), kid) -> (P_var (pat, kid), gdls)); + p_app = + (fun (id, ps) -> + let ps, gdls = List.split ps in + (P_app (id, ps), flatten_guards_decls gdls) + ); + p_vector = + (fun ps -> + let ps, gdls = List.split ps in + (P_vector ps, flatten_guards_decls gdls) + ); + p_vector_concat = + (fun ps -> + let ps, gdls = List.split ps in + (P_vector_concat ps, flatten_guards_decls gdls) + ); + p_vector_subrange = (fun (id, n, m) -> (P_vector_subrange (id, n, m), (None, (fun b -> b), []))); + p_string_append = + (fun ps -> + let ps, gdls = List.split ps in + (P_string_append ps, flatten_guards_decls gdls) + ); + p_tuple = + (fun ps -> + let ps, gdls = List.split ps in + (P_tuple ps, flatten_guards_decls gdls) + ); + p_list = + (fun ps -> + let ps, gdls = List.split ps in + (P_list ps, flatten_guards_decls gdls) + ); + p_cons = (fun ((p, gdls), (p', gdls')) -> (P_cons (p, p'), flatten_guards_decls [gdls; gdls'])); + p_aux = + (fun ((pat, gdls), annot) -> + let env = env_of_annot annot in + let t = Env.base_typ_of env (typ_of_annot annot) in + match (pat, is_bitvector_typ t) with + | P_as (P_aux (P_vector ps, _), id), true -> (P_aux (P_id id, annot), collect_guards_decls ps id t) + | _, _ -> (P_aux (pat, annot), gdls) + ); + } + in fold_pat guard_bitvector_pat pat -let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_exp) = - let rewrap e = E_aux (e,(l,annot)) in +let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp, (l, annot)) as full_exp) = + let rewrap e = E_aux (e, (l, annot)) in let rewrite_rec = rewriters.rewrite_exp rewriters in let rewrite_base = rewrite_exp rewriters in match exp with - | E_match (e,ps) - when List.exists contains_bitvector_pexp ps -> - let rewrite_pexp = function - | Pat_aux (Pat_exp (pat,body),annot') -> - let (pat',(guard',decls,_)) = remove_bitvector_pat pat in - let body' = decls (rewrite_rec body) in - (match guard' with - | Some guard' -> Pat_aux (Pat_when (pat', guard', body'), annot') - | None -> Pat_aux (Pat_exp (pat', body'), annot')) - | Pat_aux (Pat_when (pat,guard,body),annot') -> - let (pat',(guard',decls,_)) = remove_bitvector_pat pat in - let guard'' = rewrite_rec guard in - let body' = decls (rewrite_rec body) in - (match guard' with - | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp (decls guard'') guard', body'), annot') - | None -> Pat_aux (Pat_when (pat', (decls guard''), body'), annot')) in - rewrap (E_match (e, List.map rewrite_pexp ps)) - | E_let (LB_aux (LB_val (pat,v),annot'),body) -> - let (pat,(_,decls,_)) = remove_bitvector_pat pat in - rewrap (E_let (LB_aux (LB_val (pat,rewrite_rec v),annot'), - decls (rewrite_rec body))) + | E_match (e, ps) when List.exists contains_bitvector_pexp ps -> + let rewrite_pexp = function + | Pat_aux (Pat_exp (pat, body), annot') -> ( + let pat', (guard', decls, _) = remove_bitvector_pat pat in + let body' = decls (rewrite_rec body) in + match guard' with + | Some guard' -> Pat_aux (Pat_when (pat', guard', body'), annot') + | None -> Pat_aux (Pat_exp (pat', body'), annot') + ) + | Pat_aux (Pat_when (pat, guard, body), annot') -> ( + let pat', (guard', decls, _) = remove_bitvector_pat pat in + let guard'' = rewrite_rec guard in + let body' = decls (rewrite_rec body) in + match guard' with + | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp (decls guard'') guard', body'), annot') + | None -> Pat_aux (Pat_when (pat', decls guard'', body'), annot') + ) + in + rewrap (E_match (e, List.map rewrite_pexp ps)) + | E_let (LB_aux (LB_val (pat, v), annot'), body) -> + let pat, (_, decls, _) = remove_bitvector_pat pat in + rewrap (E_let (LB_aux (LB_val (pat, rewrite_rec v), annot'), decls (rewrite_rec body))) | _ -> rewrite_base full_exp -let rewrite_fun_remove_bitvector_pat - rewriters (FD_aux (FD_function(recopt,tannotopt,funcls),(l,fdannot))) = +let rewrite_fun_remove_bitvector_pat rewriters (FD_aux (FD_function (recopt, tannotopt, funcls), (l, fdannot))) = let _ = reset_fresh_name_counter () in - let funcls = match funcls with - | (FCL_aux (FCL_funcl(id,_),_) :: _) -> - let clause (FCL_aux (FCL_funcl(_,pexp), fcl_annot)) = - let pat,fguard,exp,pannot = destruct_pexp pexp in - let (pat,(guard,decls,_)) = remove_bitvector_pat pat in - let guard = match guard,fguard with - | None,e | e,None -> e - | Some g, Some wh -> - Some (bitwise_and_exp g (decls (rewriters.rewrite_exp rewriters wh))) + let funcls = + match funcls with + | FCL_aux (FCL_funcl (id, _), _) :: _ -> + let clause (FCL_aux (FCL_funcl (_, pexp), fcl_annot)) = + let pat, fguard, exp, pannot = destruct_pexp pexp in + let pat, (guard, decls, _) = remove_bitvector_pat pat in + let guard = + match (guard, fguard) with + | None, e | e, None -> e + | Some g, Some wh -> Some (bitwise_and_exp g (decls (rewriters.rewrite_exp rewriters wh))) in let exp = decls (rewriters.rewrite_exp rewriters exp) in (* AA: Why can't this use pannot ? *) - FCL_aux (FCL_funcl (id,construct_pexp (pat,guard,exp, ((fst fcl_annot).loc, snd fcl_annot))), fcl_annot) in + FCL_aux (FCL_funcl (id, construct_pexp (pat, guard, exp, ((fst fcl_annot).loc, snd fcl_annot))), fcl_annot) + in List.map clause funcls - | _ -> funcls in - FD_aux (FD_function(recopt,tannotopt,funcls),(l,fdannot)) + | _ -> funcls + in + FD_aux (FD_function (recopt, tannotopt, funcls), (l, fdannot)) let rewrite_ast_remove_bitvector_pats env ast = let rewriters = - {rewrite_exp = rewrite_exp_remove_bitvector_pat; - rewrite_pat = rewrite_pat; - rewrite_let = rewrite_let; - rewrite_lexp = rewrite_lexp; - rewrite_fun = rewrite_fun_remove_bitvector_pat; - rewrite_def = rewrite_def; - rewrite_ast = rewrite_ast_base } in + { + rewrite_exp = rewrite_exp_remove_bitvector_pat; + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun = rewrite_fun_remove_bitvector_pat; + rewrite_def; + rewrite_ast = rewrite_ast_base; + } + in let rewrite_def d = let d = rewriters.rewrite_def rewriters d in match d with - | DEF_aux (DEF_let (LB_aux (LB_val (pat,exp),a)), def_annot) -> - let (pat',(_,_,letbinds)) = remove_bitvector_pat pat in - let defvals = List.map (fun lb -> DEF_aux (DEF_let lb, mk_def_annot (gen_loc def_annot.loc))) letbinds in - [DEF_aux (DEF_let (LB_aux (LB_val (pat',exp),a)), def_annot)] @ defvals - | d -> [d] in + | DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), a)), def_annot) -> + let pat', (_, _, letbinds) = remove_bitvector_pat pat in + let defvals = List.map (fun lb -> DEF_aux (DEF_let lb, mk_def_annot (gen_loc def_annot.loc))) letbinds in + [DEF_aux (DEF_let (LB_aux (LB_val (pat', exp), a)), def_annot)] @ defvals + | d -> [d] + in (* FIXME See above in rewrite_sizeof *) (* fst (check initial_env ( *) { ast with defs = List.flatten (List.map rewrite_def ast.defs) } - (* )) *) +(* )) *) (* Rewrite literal number patterns to guarded patterns Those numeral patterns are not handled very well by Lem (or Isabelle) - *) +*) let rewrite_ast_remove_numeral_pats env = let p_lit outer_env = function | L_aux (L_num n, l) -> - let id = fresh_id "l__" Parse_ast.Unknown in - let typ = atom_typ (nconstant n) in - let guard = - mk_exp (E_app_infix ( - mk_exp (E_id id), - mk_id "==", - mk_lit_exp (L_num n) - )) in - (* Check expression in reasonable approx of environment to resolve overriding *) - let env = Env.add_local id (Immutable, typ) outer_env in - let checked_guard = check_exp env guard bool_typ in - (Some checked_guard, P_id id) - | lit -> (None, P_lit lit) in - let guard_pat outer_env = - fold_pat { (compute_pat_alg None compose_guard_opt) with p_lit = p_lit outer_env } in + let id = fresh_id "l__" Parse_ast.Unknown in + let typ = atom_typ (nconstant n) in + let guard = mk_exp (E_app_infix (mk_exp (E_id id), mk_id "==", mk_lit_exp (L_num n))) in + (* Check expression in reasonable approx of environment to resolve overriding *) + let env = Env.add_local id (Immutable, typ) outer_env in + let checked_guard = check_exp env guard bool_typ in + (Some checked_guard, P_id id) + | lit -> (None, P_lit lit) + in + let guard_pat outer_env = fold_pat { (compute_pat_alg None compose_guard_opt) with p_lit = p_lit outer_env } in let pat_aux (pexp_aux, a) = - let pat,guard,exp,a = destruct_pexp (Pat_aux (pexp_aux, a)) in - let guard',pat = guard_pat (env_of_pat pat) pat in + let pat, guard, exp, a = destruct_pexp (Pat_aux (pexp_aux, a)) in + let guard', pat = guard_pat (env_of_pat pat) pat in match compose_guard_opt guard guard' with | Some g -> Pat_aux (Pat_when (pat, g, exp), a) - | None -> Pat_aux (Pat_exp (pat, exp), a) in - let exp_alg = { id_exp_alg with pat_aux = pat_aux } in + | None -> Pat_aux (Pat_exp (pat, exp), a) + in + let exp_alg = { id_exp_alg with pat_aux } in let rewrite_exp _ = fold_exp exp_alg in - let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = - FCL_aux (FCL_funcl (id, fold_pexp exp_alg pexp), annot) in + let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = FCL_aux (FCL_funcl (id, fold_pexp exp_alg pexp), annot) in let rewrite_fun _ (FD_aux (FD_function (r_o, t_o, funcls), a)) = - FD_aux (FD_function (r_o, t_o, List.map rewrite_funcl funcls), a) in - rewrite_ast_base - { rewriters_base with rewrite_exp = rewrite_exp; rewrite_fun = rewrite_fun } + FD_aux (FD_function (r_o, t_o, List.map rewrite_funcl funcls), a) + in + rewrite_ast_base { rewriters_base with rewrite_exp; rewrite_fun } let rewrite_ast_vector_string_pats_to_bit_list env = let rewrite_p_aux (pat, (annot : tannot annot)) = - match pat with - | P_lit lit -> vector_string_to_bits_pat lit annot - | pat -> (P_aux (pat, annot)) + match pat with P_lit lit -> vector_string_to_bits_pat lit annot | pat -> P_aux (pat, annot) in let rewrite_e_aux (exp, (annot : tannot annot)) = - match exp with - | E_lit lit -> vector_string_to_bits_exp lit annot - | exp -> (E_aux (exp, annot)) + match exp with E_lit lit -> vector_string_to_bits_exp lit annot | exp -> E_aux (exp, annot) in let pat_alg = { id_pat_alg with p_aux = rewrite_p_aux } in - let rewrite_pat rw pat = - fold_pat pat_alg pat - in - let rewrite_exp rw exp = - fold_exp { id_exp_alg with e_aux = rewrite_e_aux; pat_alg = pat_alg } exp - in - rewrite_ast_base { rewriters_base with rewrite_pat = rewrite_pat; rewrite_exp = rewrite_exp } + let rewrite_pat rw pat = fold_pat pat_alg pat in + let rewrite_exp rw exp = fold_exp { id_exp_alg with e_aux = rewrite_e_aux; pat_alg } exp in + rewrite_ast_base { rewriters_base with rewrite_pat; rewrite_exp } let rewrite_bit_lists_to_lits env = (* TODO Make all rewriting passes support bitvector literals instead of converting back and forth *) let open Sail2_values in - let bit_of_lit = function - | L_aux (L_zero, _) -> Some B0 - | L_aux (L_one, _) -> Some B1 - | _ -> None - in + let bit_of_lit = function L_aux (L_zero, _) -> Some B0 | L_aux (L_one, _) -> Some B1 | _ -> None in let bit_of_exp = function E_aux (E_lit lit, _) -> bit_of_lit lit | _ -> None in let string_of_chars cs = String.concat "" (List.map (String.make 1) cs) in - let lit_of_bits bits = match hexstring_of_bits bits with + let lit_of_bits bits = + match hexstring_of_bits bits with | Some h -> L_hex (string_of_chars h) | None -> L_bin (string_of_chars (List.map bitU_char bits)) in @@ -1335,156 +1381,169 @@ let rewrite_bit_lists_to_lits env = let env = env_of_annot (l, annot) in let typ = typ_of_annot (l, annot) in match e with - | E_vector es when is_bitvector_typ typ -> - (match just_list (List.map bit_of_exp es) with - | Some bits -> - check_exp env (mk_exp (E_typ (typ, mk_lit_exp (lit_of_bits bits)))) typ - | None -> rewrap e) + | E_vector es when is_bitvector_typ typ -> ( + match just_list (List.map bit_of_exp es) with + | Some bits -> check_exp env (mk_exp (E_typ (typ, mk_lit_exp (lit_of_bits bits)))) typ + | None -> rewrap e + ) | E_typ (typ', E_aux (E_typ (_, e'), _)) -> rewrap (E_typ (typ', e')) | _ -> rewrap e with _ -> rewrap e in - let rewrite_exp rw = fold_exp { id_exp_alg with e_aux = e_aux; } in - rewrite_ast_base { rewriters_base with rewrite_exp = rewrite_exp } + let rewrite_exp rw = fold_exp { id_exp_alg with e_aux } in + rewrite_ast_base { rewriters_base with rewrite_exp } (* Remove pattern guards by rewriting them to if-expressions within the pattern expression. *) -let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) = - let rewrap e = E_aux (e,(l,annot)) in +let rewrite_exp_guarded_pats rewriters (E_aux (exp, (l, annot)) as full_exp) = + let rewrap e = E_aux (e, (l, annot)) in let rewrite_rec = rewriters.rewrite_exp rewriters in let rewrite_base = rewrite_exp rewriters in - let is_guarded_pexp = function - | Pat_aux (Pat_when (_,_,_),_) -> true - | _ -> false - in + let is_guarded_pexp = function Pat_aux (Pat_when (_, _, _), _) -> true | _ -> false in (* Also rewrite potentially incomplete pattern matches, adding a fallthrough clause *) match exp with - | E_match (e,ps) - when List.exists is_guarded_pexp ps || not (pats_complete l (env_of full_exp) ps (typ_of full_exp)) -> - let clause = function - | Pat_aux (Pat_exp (pat, body), annot) -> - (pat, None, rewrite_rec body, annot) - | Pat_aux (Pat_when (pat, guard, body), annot) -> - (pat, Some (rewrite_rec guard), rewrite_rec body, annot) in - let clauses = rewrite_guarded_clauses mk_pattern_match_failure_pexp l (env_of full_exp) (typ_of e) (typ_of full_exp) (List.map clause ps) in - let e = rewrite_rec e in - if (effectful e) then - let (E_aux (_,(el,eannot))) = e in - let pat_e' = fresh_id_pat "p__" (el, mk_tannot (env_of e) (typ_of e)) in - let exp_e' = pat_to_exp pat_e' in - let letbind_e = LB_aux (LB_val (pat_e',e), (el,eannot)) in - let exp' = case_exp exp_e' (typ_of full_exp) clauses in - rewrap (E_let (letbind_e, exp')) - else case_exp e (typ_of full_exp) clauses - | E_try (e,ps) - when List.exists is_guarded_pexp ps || not (pats_complete l (env_of full_exp) ps (typ_of full_exp)) -> - let e = rewrite_rec e in - let clause = function - | Pat_aux (Pat_exp (pat, body), annot) -> - (pat, None, rewrite_rec body, annot) - | Pat_aux (Pat_when (pat, guard, body), annot) -> - (pat, Some (rewrite_rec guard), rewrite_rec body, annot) in - let clauses = rewrite_guarded_clauses mk_rethrow_pexp l (env_of full_exp) exc_typ (typ_of full_exp) (List.map clause ps) in - let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in - let ps = List.map pexp clauses in - annot_exp (E_try (e,ps)) l (env_of full_exp) (typ_of full_exp) + | E_match (e, ps) when List.exists is_guarded_pexp ps || not (pats_complete l (env_of full_exp) ps (typ_of full_exp)) + -> + let clause = function + | Pat_aux (Pat_exp (pat, body), annot) -> (pat, None, rewrite_rec body, annot) + | Pat_aux (Pat_when (pat, guard, body), annot) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot) + in + let clauses = + rewrite_guarded_clauses mk_pattern_match_failure_pexp l (env_of full_exp) (typ_of e) (typ_of full_exp) + (List.map clause ps) + in + let e = rewrite_rec e in + if effectful e then ( + let (E_aux (_, (el, eannot))) = e in + let pat_e' = fresh_id_pat "p__" (el, mk_tannot (env_of e) (typ_of e)) in + let exp_e' = pat_to_exp pat_e' in + let letbind_e = LB_aux (LB_val (pat_e', e), (el, eannot)) in + let exp' = case_exp exp_e' (typ_of full_exp) clauses in + rewrap (E_let (letbind_e, exp')) + ) + else case_exp e (typ_of full_exp) clauses + | E_try (e, ps) when List.exists is_guarded_pexp ps || not (pats_complete l (env_of full_exp) ps (typ_of full_exp)) -> + let e = rewrite_rec e in + let clause = function + | Pat_aux (Pat_exp (pat, body), annot) -> (pat, None, rewrite_rec body, annot) + | Pat_aux (Pat_when (pat, guard, body), annot) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot) + in + let clauses = + rewrite_guarded_clauses mk_rethrow_pexp l (env_of full_exp) exc_typ (typ_of full_exp) (List.map clause ps) + in + let pexp (pat, body, annot) = Pat_aux (Pat_exp (pat, body), annot) in + let ps = List.map pexp clauses in + annot_exp (E_try (e, ps)) l (env_of full_exp) (typ_of full_exp) | _ -> rewrite_base full_exp -let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r,t,funcls),(l,fdannot))) = - let funcls = match funcls with - | (FCL_aux (FCL_funcl(id,pexp), fcl_annot) :: _) -> - let clause (FCL_aux (FCL_funcl(_,pexp),annot)) = - let pat,guard,exp,_ = destruct_pexp pexp in - let exp = rewriters.rewrite_exp rewriters exp in - (pat,guard,exp,annot) in - let pexp_pat_typ, pexp_ret_typ = - let pat, _, exp, _ = destruct_pexp pexp in - (typ_of_pat pat, typ_of exp) - in - let pat_typ, ret_typ = match Env.get_val_spec_orig id (env_of_tannot (snd fcl_annot)) with - | (tq, Typ_aux (Typ_fn ([arg_typ], ret_typ), _)) -> (arg_typ, ret_typ) - | (tq, Typ_aux (Typ_fn (arg_typs, ret_typ), _)) -> (tuple_typ arg_typs, ret_typ) - | _ -> (pexp_pat_typ, pexp_ret_typ) | exception _ -> (pexp_pat_typ, pexp_ret_typ) - in - let cs = rewrite_toplevel_guarded_clauses mk_pattern_match_failure_pexp l (env_of_tannot (snd fcl_annot)) pat_typ ret_typ (List.map clause funcls) in - List.map (fun (pat,exp,annot) -> - FCL_aux (FCL_funcl(id,construct_pexp (pat,None,exp,(Parse_ast.Unknown,empty_tannot))), (mk_def_annot (fst annot), snd annot))) cs - | _ -> funcls (* TODO is the empty list possible here? *) in - FD_aux (FD_function(r,t,funcls),(l,fdannot)) +let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r, t, funcls), (l, fdannot))) = + let funcls = + match funcls with + | FCL_aux (FCL_funcl (id, pexp), fcl_annot) :: _ -> + let clause (FCL_aux (FCL_funcl (_, pexp), annot)) = + let pat, guard, exp, _ = destruct_pexp pexp in + let exp = rewriters.rewrite_exp rewriters exp in + (pat, guard, exp, annot) + in + let pexp_pat_typ, pexp_ret_typ = + let pat, _, exp, _ = destruct_pexp pexp in + (typ_of_pat pat, typ_of exp) + in + let pat_typ, ret_typ = + match Env.get_val_spec_orig id (env_of_tannot (snd fcl_annot)) with + | tq, Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> (arg_typ, ret_typ) + | tq, Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> (tuple_typ arg_typs, ret_typ) + | _ -> (pexp_pat_typ, pexp_ret_typ) + | exception _ -> (pexp_pat_typ, pexp_ret_typ) + in + let cs = + rewrite_toplevel_guarded_clauses mk_pattern_match_failure_pexp l + (env_of_tannot (snd fcl_annot)) + pat_typ ret_typ (List.map clause funcls) + in + List.map + (fun (pat, exp, annot) -> + FCL_aux + ( FCL_funcl (id, construct_pexp (pat, None, exp, (Parse_ast.Unknown, empty_tannot))), + (mk_def_annot (fst annot), snd annot) + ) + ) + cs + | _ -> funcls (* TODO is the empty list possible here? *) + in + FD_aux (FD_function (r, t, funcls), (l, fdannot)) let rewrite_ast_guarded_pats env = - rewrite_ast_base { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats; - rewrite_fun = rewrite_fun_guarded_pats } - + rewrite_ast_base + { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats; rewrite_fun = rewrite_fun_guarded_pats } -let rec rewrite_lexp_to_rhs ((LE_aux(lexp,((l,_) as annot))) as le) = +let rec rewrite_lexp_to_rhs (LE_aux (lexp, ((l, _) as annot)) as le) = match lexp with - | LE_id _ | LE_typ (_, _) | LE_tuple _ | LE_deref _ -> (le, (fun exp -> exp)) + | LE_id _ | LE_typ (_, _) | LE_tuple _ | LE_deref _ -> (le, fun exp -> exp) | LE_vector (lexp, e) -> - let (lhs, rhs) = rewrite_lexp_to_rhs lexp in - (lhs, (fun exp -> rhs (E_aux (E_vector_update (lexp_to_exp lexp, e, exp), annot)))) + let lhs, rhs = rewrite_lexp_to_rhs lexp in + (lhs, fun exp -> rhs (E_aux (E_vector_update (lexp_to_exp lexp, e, exp), annot))) | LE_vector_range (lexp, e1, e2) -> - let (lhs, rhs) = rewrite_lexp_to_rhs lexp in - (lhs, (fun exp -> rhs (E_aux (E_vector_update_subrange (lexp_to_exp lexp, e1, e2, exp), annot)))) - | LE_field (lexp, id) -> - begin - let (lhs, rhs) = rewrite_lexp_to_rhs lexp in - let (LE_aux (_, lannot)) = lexp in - let env = env_of_annot lannot in - match Env.expand_synonyms env (typ_of_annot lannot) with - | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> + let lhs, rhs = rewrite_lexp_to_rhs lexp in + (lhs, fun exp -> rhs (E_aux (E_vector_update_subrange (lexp_to_exp lexp, e1, e2, exp), annot))) + | LE_field (lexp, id) -> begin + let lhs, rhs = rewrite_lexp_to_rhs lexp in + let (LE_aux (_, lannot)) = lexp in + let env = env_of_annot lannot in + match Env.expand_synonyms env (typ_of_annot lannot) with + | (Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _)) when Env.is_record rectyp_id env -> let field_update exp = FE_aux (FE_fexp (id, exp), annot) in - (lhs, (fun exp -> rhs (E_aux (E_struct_update (lexp_to_exp lexp, [field_update exp]), lannot)))) - | _ -> raise (Reporting.err_unreachable l __POS__ ("Unsupported lexp: " ^ string_of_lexp le)) - end + (lhs, fun exp -> rhs (E_aux (E_struct_update (lexp_to_exp lexp, [field_update exp]), lannot))) + | _ -> raise (Reporting.err_unreachable l __POS__ ("Unsupported lexp: " ^ string_of_lexp le)) + end | _ -> raise (Reporting.err_unreachable l __POS__ ("Unsupported lexp: " ^ string_of_lexp le)) let updates_vars exp = - let e_assign ((_, lexp), (u, exp)) = - (u || lexp_is_local lexp (env_of exp), E_assign (lexp, exp)) in - fst (fold_exp { (compute_exp_alg false (||)) with e_assign = e_assign } exp) - + let e_assign ((_, lexp), (u, exp)) = (u || lexp_is_local lexp (env_of exp), E_assign (lexp, exp)) in + fst (fold_exp { (compute_exp_alg false ( || )) with e_assign } exp) (*Expects to be called after rewrite_ast; thus the following should not appear: internal_exp of any form lit vectors in patterns or expressions - *) -let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as full_exp) = +*) +let rewrite_exp_lift_assign_intro rewriters (E_aux (exp, ((l, _) as annot)) as full_exp) = let rewrite_rec = rewriters.rewrite_exp rewriters in let rewrite_base = rewrite_exp rewriters in match exp with | E_block exps -> - let rec walker exps = match exps with - | [] -> [] - | (E_aux(E_assign(le,e), (l, tannot)))::exps - when not (is_empty_tannot tannot) && lexp_is_local_intro le (env_of_annot (l, tannot)) -> - let env = env_of_annot (l, tannot) in - let (le', re') = rewrite_lexp_to_rhs le in - let e' = re' (rewrite_base e) in - let exps' = walker exps in - let block = E_aux (E_block exps', (gen_loc l, mk_tannot env unit_typ)) in - [E_aux (E_var(le', e', block), annot)] - | e::exps -> (rewrite_rec e)::(walker exps) - in - E_aux (E_block (walker exps), annot) - - | E_assign(le,e) - when lexp_is_local_intro le (env_of full_exp) && not (lexp_is_effectful le) -> - let (le', re') = rewrite_lexp_to_rhs le in - let e' = re' (rewrite_base e) in - let block = annot_exp (E_block []) (gen_loc l) (env_of full_exp) unit_typ in - E_aux (E_var (le', e', block), annot) - + let rec walker exps = + match exps with + | [] -> [] + | E_aux (E_assign (le, e), (l, tannot)) :: exps + when (not (is_empty_tannot tannot)) && lexp_is_local_intro le (env_of_annot (l, tannot)) -> + let env = env_of_annot (l, tannot) in + let le', re' = rewrite_lexp_to_rhs le in + let e' = re' (rewrite_base e) in + let exps' = walker exps in + let block = E_aux (E_block exps', (gen_loc l, mk_tannot env unit_typ)) in + [E_aux (E_var (le', e', block), annot)] + | e :: exps -> rewrite_rec e :: walker exps + in + E_aux (E_block (walker exps), annot) + | E_assign (le, e) when lexp_is_local_intro le (env_of full_exp) && not (lexp_is_effectful le) -> + let le', re' = rewrite_lexp_to_rhs le in + let e' = re' (rewrite_base e) in + let block = annot_exp (E_block []) (gen_loc l) (env_of full_exp) unit_typ in + E_aux (E_var (le', e', block), annot) | _ -> rewrite_base full_exp -let rewrite_ast_exp_lift_assign env defs = rewrite_ast_base - {rewrite_exp = rewrite_exp_lift_assign_intro; - rewrite_pat = rewrite_pat; - rewrite_let = rewrite_let; - rewrite_lexp = rewrite_lexp (*_lift_assign_intro*); - rewrite_fun = rewrite_fun; - rewrite_def = rewrite_def; - rewrite_ast = rewrite_ast_base} defs +let rewrite_ast_exp_lift_assign env defs = + rewrite_ast_base + { + rewrite_exp = rewrite_exp_lift_assign_intro; + rewrite_pat; + rewrite_let; + rewrite_lexp; + (*_lift_assign_intro*) rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; + } + defs (* Remove redundant return statements, and translate remaining ones into an (effectful) call to builtin function "early_return" (in the Lem shallow @@ -1492,106 +1551,107 @@ let rewrite_ast_exp_lift_assign env defs = rewrite_ast_base TODO: Maybe separate generic removal of redundant returns, and Lem-specific rewriting of early returns - *) +*) let rewrite_ast_early_return effect_info env ast = - let is_unit (E_aux (exp, _)) = match exp with - | E_lit (L_aux (L_unit, _)) -> true - | _ -> false in + let is_unit (E_aux (exp, _)) = match exp with E_lit (L_aux (L_unit, _)) -> true | _ -> false in - let rec is_return (E_aux (exp, _)) = match exp with - | E_return _ -> true - | E_typ (_, e) -> is_return e - | _ -> false in + let rec is_return (E_aux (exp, _)) = match exp with E_return _ -> true | E_typ (_, e) -> is_return e | _ -> false in - let rec get_return (E_aux (e, annot) as exp) = match e with - | E_return e -> e - | E_typ (typ, e) -> E_aux (E_typ (typ, get_return e), annot) - | _ -> exp in + let rec get_return (E_aux (e, annot) as exp) = + match e with E_return e -> e | E_typ (typ, e) -> E_aux (E_typ (typ, get_return e), annot) | _ -> exp + in let contains_return exp = - fst (fold_exp - { (compute_exp_alg false (||)) - with e_return = (fun (_, r) -> (true, E_return r)) } exp) in + fst (fold_exp { (compute_exp_alg false ( || )) with e_return = (fun (_, r) -> (true, E_return r)) } exp) + in let e_if (e1, e2, e3) = - if is_return e2 && is_return e3 then + if is_return e2 && is_return e3 then ( let (E_aux (_, annot)) = get_return e2 in E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) - else E_if (e1, e2, e3) in + ) + else E_if (e1, e2, e3) + in let rec e_block es = (* If one of the branches of an if-expression in a block is an early return, fold the rest of the block after the if-expression into the other branch *) - let fold_if_return exp block = match exp with + let fold_if_return exp block = + match exp with | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t -> - let annot = match block with - | [] -> annot - | _ -> let (E_aux (_, annot)) = Util.last block in annot - in - let block = if is_unit e then block else e :: block in - let e' = E_aux (e_block block, annot) in - [E_aux (e_if (c, t, e'), annot)] + let annot = + match block with + | [] -> annot + | _ -> + let (E_aux (_, annot)) = Util.last block in + annot + in + let block = if is_unit e then block else e :: block in + let e' = E_aux (e_block block, annot) in + [E_aux (e_if (c, t, e'), annot)] | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e -> - let annot = match block with - | [] -> annot - | _ -> let (E_aux (_, annot)) = Util.last block in annot - in - let block = if is_unit t then block else t :: block in - let t' = E_aux (e_block block, annot) in - [E_aux (e_if (c, t', e), annot)] - | _ -> exp :: block in + let annot = + match block with + | [] -> annot + | _ -> + let (E_aux (_, annot)) = Util.last block in + annot + in + let block = if is_unit t then block else t :: block in + let t' = E_aux (e_block block, annot) in + [E_aux (e_if (c, t', e), annot)] + | _ -> exp :: block + in let es = List.fold_right fold_if_return es [] in match es with | [E_aux (e, _)] -> e | _ :: _ when is_return (Util.last es) -> - let (E_aux (_, annot) as e) = get_return (Util.last es) in - E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot)) - | _ -> E_block es in + let (E_aux (_, annot) as e) = get_return (Util.last es) in + E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot)) + | _ -> E_block es + in let e_case (e, pes) = - let is_return_pexp (Pat_aux (pexp, _)) = match pexp with - | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in - let get_return_pexp (Pat_aux (pexp, a)) = match pexp with - | Pat_exp (p, e) -> Pat_aux (Pat_exp (p, get_return e), a) - | Pat_when (p, g, e) -> Pat_aux (Pat_when (p, g, get_return e), a) in - let annot = match List.map get_return_pexp pes with - | Pat_aux (Pat_exp (_, E_aux (_, annot)), _) :: _ -> annot - | Pat_aux (Pat_when (_, _, E_aux (_, annot)), _) :: _ -> annot - | [] -> (Parse_ast.Unknown, empty_tannot) in - if List.for_all is_return_pexp pes - then E_return (E_aux (E_match (e, List.map get_return_pexp pes), annot)) - else E_match (e, pes) in + let is_return_pexp (Pat_aux (pexp, _)) = match pexp with Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in + let get_return_pexp (Pat_aux (pexp, a)) = + match pexp with + | Pat_exp (p, e) -> Pat_aux (Pat_exp (p, get_return e), a) + | Pat_when (p, g, e) -> Pat_aux (Pat_when (p, g, get_return e), a) + in + let annot = + match List.map get_return_pexp pes with + | Pat_aux (Pat_exp (_, E_aux (_, annot)), _) :: _ -> annot + | Pat_aux (Pat_when (_, _, E_aux (_, annot)), _) :: _ -> annot + | [] -> (Parse_ast.Unknown, empty_tannot) + in + if List.for_all is_return_pexp pes then E_return (E_aux (E_match (e, List.map get_return_pexp pes), annot)) + else E_match (e, pes) + in let e_let (lb, exp) = let (E_aux (_, annot) as ret_exp) = get_return exp in - if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot)) - else E_let (lb, exp) in + if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot)) else E_let (lb, exp) + in let e_var (lexp, exp1, exp2) = let (E_aux (_, annot) as ret_exp2) = get_return exp2 in - if is_return exp2 then - E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot)) - else E_var (lexp, exp1, exp2) in - - let e_app (id, es) = - try E_return (get_return (List.find is_return es)) - with - | Not_found -> E_app (id, es) + if is_return exp2 then E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot)) else E_var (lexp, exp1, exp2) in + let e_app (id, es) = try E_return (get_return (List.find is_return es)) with Not_found -> E_app (id, es) in + let e_aux (exp, (l, annot)) = let full_exp = E_aux (exp, (l, annot)) in match full_exp with | E_aux (E_return exp, (l, tannot)) when not (is_empty_tannot tannot) -> - let typ = typ_of_annot (l, tannot) in - let env = env_of_annot (l, tannot) in - let tannot' = mk_tannot env typ in - let exp' = match Env.get_ret_typ env with - | Some typ -> add_e_typ env typ exp - | None -> exp in - E_aux (E_app (mk_id "early_return", [exp']), (l, tannot')) - | _ -> full_exp in + let typ = typ_of_annot (l, tannot) in + let env = env_of_annot (l, tannot) in + let tannot' = mk_tannot env typ in + let exp' = match Env.get_ret_typ env with Some typ -> add_e_typ env typ exp | None -> exp in + E_aux (E_app (mk_id "early_return", [exp']), (l, tannot')) + | _ -> full_exp + in (* Make sure that all final leaves of an expression (e.g. all branches of the last if-expression) are wrapped in a return statement. This allows @@ -1601,108 +1661,97 @@ let rewrite_ast_early_return effect_info env ast = let rec add_final_return nested (E_aux (e, annot) as exp) = let rewrap e = E_aux (e, annot) in match e with - | E_return _ -> exp - | E_typ (typ, e') -> - begin - let (E_aux (e_aux', annot') as e') = add_final_return nested e' in - match e_aux' with - | E_return e' -> rewrap (E_return (rewrap (E_typ (typ, e')))) - | _ -> rewrap (E_typ (typ, e')) - end - | E_block ((_ :: _) as es) -> - rewrap (E_block (Util.butlast es @ [add_final_return true (Util.last es)])) - | E_if (c, t, e) -> - rewrap (E_if (c, add_final_return true t, add_final_return true e)) - | E_match (e, pes) -> - let add_final_return_pexp = function - | Pat_aux (Pat_exp (p, e), a) -> - Pat_aux (Pat_exp (p, add_final_return true e), a) - | Pat_aux (Pat_when (p, g, e), a) -> - Pat_aux (Pat_when (p, g, add_final_return true e), a) - in - rewrap (E_match (e, List.map add_final_return_pexp pes)) - | E_let (lb, exp) -> - rewrap (E_let (lb, add_final_return true exp)) - | E_var (lexp, e1, e2) -> - rewrap (E_var (lexp, e1, add_final_return true e2)) - | _ -> - if nested && not (contains_return exp) then rewrap (E_return exp) else exp + | E_return _ -> exp + | E_typ (typ, e') -> begin + let (E_aux (e_aux', annot') as e') = add_final_return nested e' in + match e_aux' with E_return e' -> rewrap (E_return (rewrap (E_typ (typ, e')))) | _ -> rewrap (E_typ (typ, e')) + end + | E_block (_ :: _ as es) -> rewrap (E_block (Util.butlast es @ [add_final_return true (Util.last es)])) + | E_if (c, t, e) -> rewrap (E_if (c, add_final_return true t, add_final_return true e)) + | E_match (e, pes) -> + let add_final_return_pexp = function + | Pat_aux (Pat_exp (p, e), a) -> Pat_aux (Pat_exp (p, add_final_return true e), a) + | Pat_aux (Pat_when (p, g, e), a) -> Pat_aux (Pat_when (p, g, add_final_return true e), a) + in + rewrap (E_match (e, List.map add_final_return_pexp pes)) + | E_let (lb, exp) -> rewrap (E_let (lb, add_final_return true exp)) + | E_var (lexp, e1, e2) -> rewrap (E_var (lexp, e1, add_final_return true e2)) + | _ -> if nested && not (contains_return exp) then rewrap (E_return exp) else exp in let rewrite_funcl_early_return _ (FCL_aux (FCL_funcl (id, pexp), a)) = - let pat,guard,exp,pannot = destruct_pexp pexp in + let pat, guard, exp, pannot = destruct_pexp pexp in let exp = - if contains_return exp then + if contains_return exp then ( (* Try to pull out early returns as far as possible *) let exp' = - fold_exp - { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; - e_let = e_let; e_var = e_var; e_app = e_app } - (add_final_return false exp) in + fold_exp { id_exp_alg with e_block; e_if; e_case; e_let; e_var; e_app } (add_final_return false exp) + in (* Remove early return if we can pull it out completely, and rewrite remaining early returns to "early_return" calls *) - fold_exp - { id_exp_alg with e_aux = e_aux } - (if is_return exp' then get_return exp' else exp) + fold_exp { id_exp_alg with e_aux } (if is_return exp' then get_return exp' else exp) + ) else exp in - let a = match destruct_tannot (snd a) with - | Some (env, typ) -> - (fst a, mk_tannot env typ) - | _ -> a in - FCL_aux (FCL_funcl (id, construct_pexp (pat, guard, exp, pannot)), a) in - - let rewrite_fun_early_return rewriters - (FD_aux (FD_function (rec_opt, tannot_opt, funcls), a)) = - FD_aux (FD_function (rec_opt, tannot_opt, - List.map (rewrite_funcl_early_return rewriters) funcls), a) in - - let early_ret_spec = fst (Type_error.check_defs initial_env [gen_vs ~pure:false - ("early_return", "forall ('a : Type) ('b : Type). 'a -> 'b")]) in + let a = match destruct_tannot (snd a) with Some (env, typ) -> (fst a, mk_tannot env typ) | _ -> a in + FCL_aux (FCL_funcl (id, construct_pexp (pat, guard, exp, pannot)), a) + in + + let rewrite_fun_early_return rewriters (FD_aux (FD_function (rec_opt, tannot_opt, funcls), a)) = + FD_aux (FD_function (rec_opt, tannot_opt, List.map (rewrite_funcl_early_return rewriters) funcls), a) + in + + let early_ret_spec = + fst + (Type_error.check_defs initial_env + [gen_vs ~pure:false ("early_return", "forall ('a : Type) ('b : Type). 'a -> 'b")] + ) + in let effect_info = Effects.add_monadic_built_in (mk_id "early_return") effect_info in let new_ast = rewrite_ast_base { rewriters_base with rewrite_fun = rewrite_fun_early_return } { ast with defs = early_ret_spec @ ast.defs } - in new_ast, effect_info, env + in + (new_ast, effect_info, env) -let swaptyp typ (l,tannot) = match destruct_tannot tannot with +let swaptyp typ (l, tannot) = + match destruct_tannot tannot with | Some (env, typ') -> (l, mk_tannot env typ) | _ -> raise (Reporting.err_unreachable l __POS__ "swaptyp called with empty type annotation") let is_funcl_rec (FCL_aux (FCL_funcl (id, pexp), _)) = fold_pexp - { (pure_exp_alg false (||)) with - e_app = (fun (id',args) -> - Id.compare id id' == 0 || List.exists (fun x -> x) args); - e_app_infix = (fun (arg1,id',arg2) -> - arg1 || arg2 || Id.compare id id' == 0) - } pexp + { + (pure_exp_alg false ( || )) with + e_app = (fun (id', args) -> Id.compare id id' == 0 || List.exists (fun x -> x) args); + e_app_infix = (fun (arg1, id', arg2) -> arg1 || arg2 || Id.compare id id' == 0); + } + pexp (* Sail code isn't required to declare recursive functions as recursive, so if a backend needs them then this rewrite updates them. (Also see minimise_recursive_functions.) *) let rewrite_add_unspecified_rec env ast = - let rewrite_function (FD_aux (FD_function (recopt,topt,funcls),ann) as fd) = + let rewrite_function (FD_aux (FD_function (recopt, topt, funcls), ann) as fd) = match recopt with | Rec_aux (Rec_nonrec, l) when List.exists is_funcl_rec funcls -> - FD_aux (FD_function (Rec_aux (Rec_rec, Generated l),topt,funcls),ann) + FD_aux (FD_function (Rec_aux (Rec_rec, Generated l), topt, funcls), ann) | _ -> fd in let rewrite_def = function | DEF_aux (DEF_fundef fd, def_annot) -> DEF_aux (DEF_fundef (rewrite_function fd), def_annot) | d -> d - in { ast with defs = List.map rewrite_def ast.defs } + in + { ast with defs = List.map rewrite_def ast.defs } let pat_var (P_aux (paux, a)) = let env = env_of_annot a in let is_var id = - not (Env.is_union_constructor id env) && - match Env.lookup_id id env with Enum _ -> false | _ -> true - in match paux with - | (P_as (_, id) | P_id id) when is_var id -> Some id - | _ -> None + (not (Env.is_union_constructor id env)) && match Env.lookup_id id env with Enum _ -> false | _ -> true + in + match paux with (P_as (_, id) | P_id id) when is_var id -> Some id | _ -> None (** Split out function clauses for individual union constructor patterns (e.g. AST nodes) into auxiliary functions. Used for the execute function. @@ -1734,140 +1783,140 @@ let rewrite_split_fun_ctor_pats fun_name effect_info env ast = let pat, guard, exp, annot = destruct_pexp pexp in match pat with | P_aux (P_app (ctor_id, args), pannot) -> - let ctor_typq, ctor_typ = Env.get_union_id ctor_id env in - let args = match args with [P_aux (P_tuple args, _)] -> args | _ -> args in - let argstup_typ = tuple_typ (List.map typ_of_pat args) in - let pannot' = swaptyp argstup_typ pannot in - let pat' = - match args with - | [arg] -> arg - | _ -> P_aux (P_tuple args, pannot') - in - let pexp' = construct_pexp (pat', guard, exp, annot) in - let aux_fun_id = prepend_id (fun_name ^ "_") ctor_id in - let aux_funcl = FCL_aux (FCL_funcl (aux_fun_id, pexp'), (mk_def_annot (fst pannot'), snd pannot')) in - begin - try - let aux_clauses = Bindings.find aux_fun_id aux_funs in - clauses, - Bindings.add aux_fun_id (aux_clauses @ [(aux_funcl, ctor_typq, ctor_typ)]) aux_funs - with Not_found -> - let argpats, argexps = List.split (List.mapi - (fun idx (P_aux (_,a) as pat) -> - let id = match pat_var pat with - | Some id -> id - | None -> mk_id ("arg" ^ string_of_int idx) - in - P_aux (P_id id, a), E_aux (E_id id, a)) - args) - in - let pexp = construct_pexp - (P_aux (P_app (ctor_id, argpats), pannot), - None, - E_aux (E_app (aux_fun_id, argexps), annot), - annot) - in - clauses @ [FCL_aux (FCL_funcl (id, pexp), fannot)], - Bindings.add aux_fun_id [(aux_funcl, ctor_typq, ctor_typ)] aux_funs - end - | _ -> clauses @ [clause], aux_funs) + let ctor_typq, ctor_typ = Env.get_union_id ctor_id env in + let args = match args with [P_aux (P_tuple args, _)] -> args | _ -> args in + let argstup_typ = tuple_typ (List.map typ_of_pat args) in + let pannot' = swaptyp argstup_typ pannot in + let pat' = match args with [arg] -> arg | _ -> P_aux (P_tuple args, pannot') in + let pexp' = construct_pexp (pat', guard, exp, annot) in + let aux_fun_id = prepend_id (fun_name ^ "_") ctor_id in + let aux_funcl = FCL_aux (FCL_funcl (aux_fun_id, pexp'), (mk_def_annot (fst pannot'), snd pannot')) in + begin + try + let aux_clauses = Bindings.find aux_fun_id aux_funs in + (clauses, Bindings.add aux_fun_id (aux_clauses @ [(aux_funcl, ctor_typq, ctor_typ)]) aux_funs) + with Not_found -> + let argpats, argexps = + List.split + (List.mapi + (fun idx (P_aux (_, a) as pat) -> + let id = + match pat_var pat with Some id -> id | None -> mk_id ("arg" ^ string_of_int idx) + in + (P_aux (P_id id, a), E_aux (E_id id, a)) + ) + args + ) + in + let pexp = + construct_pexp + (P_aux (P_app (ctor_id, argpats), pannot), None, E_aux (E_app (aux_fun_id, argexps), annot), annot) + in + ( clauses @ [FCL_aux (FCL_funcl (id, pexp), fannot)], + Bindings.add aux_fun_id [(aux_funcl, ctor_typq, ctor_typ)] aux_funs + ) + end + | _ -> (clauses @ [clause], aux_funs) + ) ([], Bindings.empty) clauses in let add_aux_def id aux_funs defs = let funcls = List.map (fun (fcl, _, _) -> fcl) aux_funs in - let env, quants, args_typ, ret_typ = match aux_funs with + let env, quants, args_typ, ret_typ = + match aux_funs with | (FCL_aux (FCL_funcl (_, pexp), _), ctor_typq, ctor_typ) :: _ -> - let pat, _, exp, _ = destruct_pexp pexp in - let ctor_quants args_typ = - List.filter (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ args_typ)) - (quant_items ctor_typq) - in - begin match ctor_typ with - | Typ_aux (Typ_fn ([Typ_aux (Typ_exist (kopts, nc, args_typ), _)], _), _) -> - env_of exp, ctor_quants args_typ @ List.map mk_qi_kopt kopts @ [mk_qi_nc nc], args_typ, typ_of exp - | Typ_aux (Typ_fn ([args_typ], _), _) -> env_of exp, ctor_quants args_typ, args_typ, typ_of exp - | _ -> - raise (Reporting.err_unreachable l __POS__ - ("Union constructor has non-function type: " ^ string_of_typ ctor_typ)) - end - | _ -> - raise (Reporting.err_unreachable l __POS__ - "rewrite_split_fun_constr_pats: empty auxiliary function") + let pat, _, exp, _ = destruct_pexp pexp in + let ctor_quants args_typ = + List.filter + (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ args_typ)) + (quant_items ctor_typq) + in + begin + match ctor_typ with + | Typ_aux (Typ_fn ([Typ_aux (Typ_exist (kopts, nc, args_typ), _)], _), _) -> + (env_of exp, ctor_quants args_typ @ List.map mk_qi_kopt kopts @ [mk_qi_nc nc], args_typ, typ_of exp) + | Typ_aux (Typ_fn ([args_typ], _), _) -> (env_of exp, ctor_quants args_typ, args_typ, typ_of exp) + | _ -> + raise + (Reporting.err_unreachable l __POS__ + ("Union constructor has non-function type: " ^ string_of_typ ctor_typ) + ) + end + | _ -> raise (Reporting.err_unreachable l __POS__ "rewrite_split_fun_constr_pats: empty auxiliary function") in let fun_typ = (* Because we got the argument type from a pattern we need to do this. *) match args_typ with - | Typ_aux (Typ_tuple (args_typs), _) -> - function_typ args_typs ret_typ - | _ -> - function_typ [args_typ] ret_typ + | Typ_aux (Typ_tuple args_typs, _) -> function_typ args_typs ret_typ + | _ -> function_typ [args_typ] ret_typ in let val_spec = - VS_aux (VS_val_spec - (mk_typschm (mk_typquant quants) fun_typ, id, None, false), - (Parse_ast.Unknown, empty_tannot)) + VS_aux + (VS_val_spec (mk_typschm (mk_typquant quants) fun_typ, id, None, false), (Parse_ast.Unknown, empty_tannot)) in let fundef = FD_aux (FD_function (r_o, t_o, funcls), fdannot) in let def_annot = mk_def_annot (gen_loc def_annot.loc) in [DEF_aux (DEF_val val_spec, def_annot); DEF_aux (DEF_fundef fundef, def_annot)] @ defs in - Bindings.fold add_aux_def aux_funs - [DEF_aux (DEF_fundef (FD_aux (FD_function (r_o, t_o, rec_clauses @ clauses), fdannot)), def_annot)], - List.map fst (Bindings.bindings aux_funs) + ( Bindings.fold add_aux_def aux_funs + [DEF_aux (DEF_fundef (FD_aux (FD_function (r_o, t_o, rec_clauses @ clauses), fdannot)), def_annot)], + List.map fst (Bindings.bindings aux_funs) + ) in - let typquant = List.fold_left (fun tq def -> match def with - | DEF_aux (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tq, _), _), id, _, _), _)), _) - when string_of_id id = fun_name -> tq - | _ -> tq) (mk_typquant []) ast.defs + let typquant = + List.fold_left + (fun tq def -> + match def with + | DEF_aux (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tq, _), _), id, _, _), _)), _) + when string_of_id id = fun_name -> + tq + | _ -> tq + ) + (mk_typquant []) ast.defs in let defs, new_effect_info = - List.fold_right (fun def (defs, effect_info) -> + List.fold_right + (fun def (defs, effect_info) -> match def with | DEF_aux (DEF_fundef fundef, def_annot) when string_of_id (id_of_fundef fundef) = fun_name -> - let new_defs, new_ids = rewrite_fundef typquant fundef def_annot in - (new_defs @ defs, List.fold_left (Effects.copy_function_effect (id_of_fundef fundef)) effect_info new_ids) + let new_defs, new_ids = rewrite_fundef typquant fundef def_annot in + (new_defs @ defs, List.fold_left (Effects.copy_function_effect (id_of_fundef fundef)) effect_info new_ids) | _ -> (def :: defs, effect_info) - ) ast.defs ([], effect_info) + ) + ast.defs ([], effect_info) in - { ast with defs = defs }, new_effect_info, env + ({ ast with defs }, new_effect_info, env) -let rewrite_type_union_typs rw_typ (Tu_aux (Tu_ty_id (typ, id), annot)) = - Tu_aux (Tu_ty_id (rw_typ typ, id), annot) +let rewrite_type_union_typs rw_typ (Tu_aux (Tu_ty_id (typ, id), annot)) = Tu_aux (Tu_ty_id (rw_typ typ, id), annot) let rewrite_type_def_typs rw_typ rw_typquant (TD_aux (td, annot)) = match td with | TD_abbrev (id, typq, A_aux (A_typ typ, l)) -> - TD_aux (TD_abbrev (id, rw_typquant typq, A_aux (A_typ (rw_typ typ), l)), annot) - | TD_abbrev (id, typq, typ_arg) -> - TD_aux (TD_abbrev (id, rw_typquant typq, typ_arg), annot) + TD_aux (TD_abbrev (id, rw_typquant typq, A_aux (A_typ (rw_typ typ), l)), annot) + | TD_abbrev (id, typq, typ_arg) -> TD_aux (TD_abbrev (id, rw_typquant typq, typ_arg), annot) | TD_record (id, typq, typ_ids, flag) -> - TD_aux (TD_record (id, rw_typquant typq, List.map (fun (typ, id) -> (rw_typ typ, id)) typ_ids, flag), annot) + TD_aux (TD_record (id, rw_typquant typq, List.map (fun (typ, id) -> (rw_typ typ, id)) typ_ids, flag), annot) | TD_variant (id, typq, tus, flag) -> - TD_aux (TD_variant (id, rw_typquant typq, List.map (rewrite_type_union_typs rw_typ) tus, flag), annot) + TD_aux (TD_variant (id, rw_typquant typq, List.map (rewrite_type_union_typs rw_typ) tus, flag), annot) | TD_enum (id, ids, flag) -> TD_aux (TD_enum (id, ids, flag), annot) | TD_bitfield _ -> assert false (* Processed before re-writing *) (* FIXME: rewrite in opt_exp? *) let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) = - match ds with - | DEC_reg (typ, id, opt_exp) -> DEC_aux (DEC_reg (rw_typ typ, id, opt_exp), annot) + match ds with DEC_reg (typ, id, opt_exp) -> DEC_aux (DEC_reg (rw_typ typ, id, opt_exp), annot) (* Remove overload definitions and cast val specs from the specification because the interpreter doesn't know about them.*) let rewrite_overload_cast env ast = let remove_cast_vs (VS_aux (vs_aux, annot)) = - match vs_aux with - | VS_val_spec (typschm, id, ext, _) -> VS_aux (VS_val_spec (typschm, id, ext, false), annot) + match vs_aux with VS_val_spec (typschm, id, ext, _) -> VS_aux (VS_val_spec (typschm, id, ext, false), annot) in let simple_def = function | DEF_aux (DEF_val vs, def_annot) -> DEF_aux (DEF_val (remove_cast_vs vs), def_annot) | def -> def in - let is_overload = function - | DEF_aux (DEF_overload _, _) -> true - | _ -> false - in + let is_overload = function DEF_aux (DEF_overload _, _) -> true | _ -> false in let defs = List.map simple_def ast.defs in { ast with defs = List.filter (fun def -> not (is_overload def)) defs } @@ -1875,39 +1924,37 @@ let rewrite_undefined mwords env = let rewrite_e_aux (E_aux (e_aux, _) as exp) = match e_aux with | E_lit (L_aux (L_undef, l)) -> - check_exp (env_of exp) (undefined_of_typ mwords l (fun _ -> empty_uannot) (Env.expand_synonyms (env_of exp) (typ_of exp))) (typ_of exp) + check_exp (env_of exp) + (undefined_of_typ mwords l (fun _ -> empty_uannot) (Env.expand_synonyms (env_of exp) (typ_of exp))) + (typ_of exp) | _ -> exp in let rewrite_exp_undefined = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp_undefined) } let rewrite_undefined_if_gen always_bitvector env defs = - if !Initial_check.opt_undefined_gen - then rewrite_undefined (always_bitvector || !Monomorphise.opt_mwords) env defs + if !Initial_check.opt_undefined_gen then rewrite_undefined (always_bitvector || !Monomorphise.opt_mwords) env defs else defs let rec simple_typ (Typ_aux (typ_aux, l)) = Typ_aux (simple_typ_aux l typ_aux, l) + and simple_typ_aux l = function | Typ_id id -> Typ_id id | Typ_app (id, [_; _; A_aux (A_typ typ, l)]) when Id.compare id (mk_id "vector") = 0 -> - Typ_app (mk_id "list", [A_aux (A_typ (simple_typ typ), l)]) + Typ_app (mk_id "list", [A_aux (A_typ (simple_typ typ), l)]) | Typ_app (id, [_; _]) when Id.compare id (mk_id "bitvector") = 0 -> - Typ_app (mk_id "list", [A_aux (A_typ bit_typ, gen_loc l)]) - | Typ_app (id, [_]) when Id.compare id (mk_id "atom") = 0 -> - Typ_id (mk_id "int") - | Typ_app (id, [_; _]) when Id.compare id (mk_id "range") = 0 -> - Typ_id (mk_id "int") - | Typ_app (id, [_]) when Id.compare id (mk_id "atom_bool") = 0 -> - Typ_id (mk_id "bool") + Typ_app (mk_id "list", [A_aux (A_typ bit_typ, gen_loc l)]) + | Typ_app (id, [_]) when Id.compare id (mk_id "atom") = 0 -> Typ_id (mk_id "int") + | Typ_app (id, [_; _]) when Id.compare id (mk_id "range") = 0 -> Typ_id (mk_id "int") + | Typ_app (id, [_]) when Id.compare id (mk_id "atom_bool") = 0 -> Typ_id (mk_id "bool") | Typ_app (id, args) -> Typ_app (id, List.concat (List.map simple_typ_arg args)) | Typ_fn (arg_typs, ret_typ) -> Typ_fn (List.map simple_typ arg_typs, simple_typ ret_typ) | Typ_tuple typs -> Typ_tuple (List.map simple_typ typs) | Typ_exist (_, _, Typ_aux (typ, l)) -> simple_typ_aux l typ | typ_aux -> typ_aux + and simple_typ_arg (A_aux (typ_arg_aux, l)) = - match typ_arg_aux with - | A_typ typ -> [A_aux (A_typ (simple_typ typ), l)] - | _ -> [] + match typ_arg_aux with A_typ typ -> [A_aux (A_typ (simple_typ typ), l)] | _ -> [] (* This pass aims to remove all the Num quantifiers from the specification. *) let rewrite_simple_types env ast = @@ -1930,48 +1977,54 @@ let rewrite_simple_types env ast = let simple_lit (L_aux (lit_aux, l) as lit) = match lit_aux with | L_bin _ | L_hex _ -> - E_list (List.map (fun b -> E_aux (E_lit b, simple_annot l bit_typ)) (vector_string_to_bit_list lit)) + E_list (List.map (fun b -> E_aux (E_lit b, simple_annot l bit_typ)) (vector_string_to_bit_list lit)) | _ -> E_lit lit in let simple_def (DEF_aux (aux, def_annot)) = - let aux = match aux with + let aux = + match aux with | DEF_val vs -> DEF_val (simple_vs vs) | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant td) | DEF_register ds -> DEF_register (rewrite_dec_spec_typs simple_typ ds) - | _ -> aux in + | _ -> aux + in DEF_aux (aux, def_annot) in - let simple_pat = { + let simple_pat = + { id_pat_alg with p_typ = (fun (typ, pat) -> P_typ (simple_typ typ, pat)); p_var = (fun (pat, kid) -> unaux_pat pat); - p_vector = (fun pats -> P_list pats) - } in - let simple_exp = { + p_vector = (fun pats -> P_list pats); + } + in + let simple_exp = + { id_exp_alg with e_lit = simple_lit; e_vector = (fun exps -> E_list exps); e_typ = (fun (typ, exp) -> E_typ (simple_typ typ, exp)); (* e_assert = (fun (E_aux (_, annot), str) -> E_assert (E_aux (E_lit (mk_lit L_true), annot), str)); *) le_typ = (fun (typ, lexp) -> LE_typ (simple_typ typ, lexp)); - pat_alg = simple_pat - } in - let simple_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp simple_exp); - rewrite_pat = (fun _ -> fold_pat simple_pat) } + pat_alg = simple_pat; + } + in + let simple_defs = + { rewriters_base with rewrite_exp = (fun _ -> fold_exp simple_exp); rewrite_pat = (fun _ -> fold_pat simple_pat) } in let ast = { ast with defs = List.map simple_def ast.defs } in rewrite_ast_base simple_defs ast let rewrite_vector_concat_assignments env defs = let lit_int i = mk_exp (E_lit (mk_lit (L_num i))) in - let sub m n = match m, n with - | E_aux (E_lit (L_aux (L_num m, _)),_), E_aux (E_lit (L_aux (L_num n, _)),_) -> - lit_int (Big_int.sub m n) + let sub m n = + match (m, n) with + | E_aux (E_lit (L_aux (L_num m, _)), _), E_aux (E_lit (L_aux (L_num n, _)), _) -> lit_int (Big_int.sub m n) | _, _ -> mk_exp (E_app_infix (m, mk_id "-", n)) in - let add m n = match m, n with - | E_aux (E_lit (L_aux (L_num m, _)),_), E_aux (E_lit (L_aux (L_num n, _)),_) -> - lit_int (Big_int.add m n) + let add m n = + match (m, n) with + | E_aux (E_lit (L_aux (L_num m, _)), _), E_aux (E_lit (L_aux (L_num n, _)), _) -> lit_int (Big_int.add m n) | _, _ -> mk_exp (E_app_infix (m, mk_id "+", n)) in @@ -1979,55 +2032,52 @@ let rewrite_vector_concat_assignments env defs = let env = env_of_annot annot in match e_aux with | E_assign (LE_aux (LE_vector_concat lexps, lannot), exp) -> - let typ = Env.base_typ_of env (typ_of exp) in - if is_vector_typ typ || is_bitvector_typ typ then - (* let _ = Pretty_print_common.print stderr (Pretty_print_sail.doc_exp (E_aux (e_aux, annot))) in *) - let start = vector_start_index typ in - let (_, ord, etyp) = vector_typ_args_of typ in - let len (LE_aux (le, lannot)) = - let ltyp = Env.base_typ_of env (typ_of_annot lannot) in - if is_vector_typ ltyp || is_bitvector_typ ltyp then - let (len, _, _) = vector_typ_args_of ltyp in - match Type_check.solve_unique (env_of_annot lannot) len with - | Some len -> mk_exp (E_lit (mk_lit (L_num len))) - | None -> mk_exp (E_sizeof (nexp_simp len)) - else Reporting.unreachable (fst lannot) __POS__ "Lexp in vector concatenation assignment is not a vector" - in - let next i step = - if is_order_inc ord - then (sub (add i step) (lit_int (Big_int.of_int 1)), add i step) - else (add (sub i step) (lit_int (Big_int.of_int 1)), sub i step) - in - let i = match Type_check.solve_unique (env_of exp) start with - | Some i -> lit_int i - | None -> mk_exp (E_sizeof (nexp_simp start)) - in - let vec_id = mk_id "split_vec" in - let exp' = if small exp then strip_exp exp else mk_exp (E_id vec_id) in - let lexp_to_exp (i, exps) lexp = - let (j, i') = next i (len lexp) in - let sub = mk_exp (E_vector_subrange (exp', i, j)) in - (i', exps @ [sub]) - in - let (_, exps) = List.fold_left lexp_to_exp (i, []) lexps in - let assign lexp exp = mk_exp (E_assign (strip_lexp lexp, exp)) in - let block = mk_exp (E_block (List.map2 assign lexps exps)) in - let full_exp = - if small exp then block else - mk_exp (E_let (mk_letbind (mk_pat (P_id vec_id)) (strip_exp exp), block)) - in - begin - try check_exp env full_exp unit_typ with - | Type_error (_, l, err) -> - raise (Reporting.err_typ l (Type_error.string_of_type_error err)) - end - else E_aux (e_aux, annot) + let typ = Env.base_typ_of env (typ_of exp) in + if is_vector_typ typ || is_bitvector_typ typ then ( + (* let _ = Pretty_print_common.print stderr (Pretty_print_sail.doc_exp (E_aux (e_aux, annot))) in *) + let start = vector_start_index typ in + let _, ord, etyp = vector_typ_args_of typ in + let len (LE_aux (le, lannot)) = + let ltyp = Env.base_typ_of env (typ_of_annot lannot) in + if is_vector_typ ltyp || is_bitvector_typ ltyp then ( + let len, _, _ = vector_typ_args_of ltyp in + match Type_check.solve_unique (env_of_annot lannot) len with + | Some len -> mk_exp (E_lit (mk_lit (L_num len))) + | None -> mk_exp (E_sizeof (nexp_simp len)) + ) + else Reporting.unreachable (fst lannot) __POS__ "Lexp in vector concatenation assignment is not a vector" + in + let next i step = + if is_order_inc ord then (sub (add i step) (lit_int (Big_int.of_int 1)), add i step) + else (add (sub i step) (lit_int (Big_int.of_int 1)), sub i step) + in + let i = + match Type_check.solve_unique (env_of exp) start with + | Some i -> lit_int i + | None -> mk_exp (E_sizeof (nexp_simp start)) + in + let vec_id = mk_id "split_vec" in + let exp' = if small exp then strip_exp exp else mk_exp (E_id vec_id) in + let lexp_to_exp (i, exps) lexp = + let j, i' = next i (len lexp) in + let sub = mk_exp (E_vector_subrange (exp', i, j)) in + (i', exps @ [sub]) + in + let _, exps = List.fold_left lexp_to_exp (i, []) lexps in + let assign lexp exp = mk_exp (E_assign (strip_lexp lexp, exp)) in + let block = mk_exp (E_block (List.map2 assign lexps exps)) in + let full_exp = + if small exp then block else mk_exp (E_let (mk_letbind (mk_pat (P_id vec_id)) (strip_exp exp), block)) + in + begin + try check_exp env full_exp unit_typ + with Type_error (_, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) + end + ) + else E_aux (e_aux, annot) | _ -> E_aux (e_aux, annot) in - let assign_exp = { - id_exp_alg with - e_aux = (fun (e_aux, annot) -> assign_tuple e_aux annot) - } in + let assign_exp = { id_exp_alg with e_aux = (fun (e_aux, annot) -> assign_tuple e_aux annot) } in let assign_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp assign_exp) } in rewrite_ast_base assign_defs defs @@ -2036,23 +2086,23 @@ let rewrite_tuple_assignments env defs = let env = env_of_annot annot in match e_aux with | E_assign (LE_aux (LE_tuple lexps, _), exp) -> - let (_, ids) = List.fold_left (fun (n, ids) _ -> (n + 1, ids @ [mk_id ("tup__" ^ string_of_int n)])) (0, []) lexps in - let block_assign i lexp = mk_exp (E_assign (strip_lexp lexp, mk_exp (E_id (mk_id ("tup__" ^ string_of_int i))))) in - let block = mk_exp (E_block (List.mapi block_assign lexps)) in - let pat = mk_pat (P_tuple (List.map (fun id -> mk_pat (P_id id)) ids)) in - let exp' = add_e_typ env (typ_of exp) exp in - let let_exp = mk_exp (E_let (mk_letbind pat (strip_exp exp'), block)) in - begin - try check_exp env let_exp unit_typ with - | Type_error (_, l, err) -> - raise (Reporting.err_typ l (Type_error.string_of_type_error err)) - end + let _, ids = + List.fold_left (fun (n, ids) _ -> (n + 1, ids @ [mk_id ("tup__" ^ string_of_int n)])) (0, []) lexps + in + let block_assign i lexp = + mk_exp (E_assign (strip_lexp lexp, mk_exp (E_id (mk_id ("tup__" ^ string_of_int i))))) + in + let block = mk_exp (E_block (List.mapi block_assign lexps)) in + let pat = mk_pat (P_tuple (List.map (fun id -> mk_pat (P_id id)) ids)) in + let exp' = add_e_typ env (typ_of exp) exp in + let let_exp = mk_exp (E_let (mk_letbind pat (strip_exp exp'), block)) in + begin + try check_exp env let_exp unit_typ + with Type_error (_, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) + end | _ -> E_aux (e_aux, annot) in - let assign_exp = { - id_exp_alg with - e_aux = (fun (e_aux, annot) -> assign_tuple e_aux annot) - } in + let assign_exp = { id_exp_alg with e_aux = (fun (e_aux, annot) -> assign_tuple e_aux annot) } in let assign_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp assign_exp) } in rewrite_ast_base assign_defs defs @@ -2061,24 +2111,20 @@ let rewrite_simple_assignments allow_fields env defs = match aux with | LE_id _ -> true | LE_typ _ -> true - | LE_field (lexp, _) when allow_fields -> is_simple lexp + | LE_field (lexp, _) when allow_fields -> is_simple lexp | _ -> false in let assign_e_aux e_aux annot = let env = env_of_annot annot in match e_aux with - | E_assign (lexp, _) when is_simple lexp -> - E_aux (e_aux, annot) + | E_assign (lexp, _) when is_simple lexp -> E_aux (e_aux, annot) | E_assign (lexp, exp) -> - let (lexp, rhs) = rewrite_lexp_to_rhs lexp in - let assign = mk_exp (E_assign (strip_lexp lexp, strip_exp (rhs exp))) in - check_exp env assign unit_typ + let lexp, rhs = rewrite_lexp_to_rhs lexp in + let assign = mk_exp (E_assign (strip_lexp lexp, strip_exp (rhs exp))) in + check_exp env assign unit_typ | _ -> E_aux (e_aux, annot) in - let assign_exp = { - id_exp_alg with - e_aux = (fun (e_aux, annot) -> assign_e_aux e_aux annot) - } in + let assign_exp = { id_exp_alg with e_aux = (fun (e_aux, annot) -> assign_e_aux e_aux annot) } in let assign_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp assign_exp) } in rewrite_ast_base assign_defs defs @@ -2089,392 +2135,324 @@ let rewrite_ast_remove_blocks env = let typ = typ_of v in let wild = annot_pat P_wild l env typ in let e_aux = E_let (annot_letbind (unaux_pat wild, v) l env typ, body) in - annot_exp e_aux l env (typ_of body) - |> add_typs_let env typ (typ_of body) + annot_exp e_aux l env (typ_of body) |> add_typs_let env typ (typ_of body) in let rec f l = function - | [] -> E_aux (E_lit (L_aux (L_unit,gen_loc l)), (simple_annot l unit_typ)) - | [e] -> e (* check with Kathy if that annotation is fine *) - | e :: es -> letbind_wild e (f l es) in + | [] -> E_aux (E_lit (L_aux (L_unit, gen_loc l)), simple_annot l unit_typ) + | [e] -> e (* check with Kathy if that annotation is fine *) + | e :: es -> letbind_wild e (f l es) + in - let e_aux = function - | (E_block es,(l,_)) -> f l es - | (e,annot) -> E_aux (e,annot) in + let e_aux = function E_block es, (l, _) -> f l es | e, annot -> E_aux (e, annot) in - let alg = { id_exp_alg with e_aux = e_aux } in + let alg = { id_exp_alg with e_aux } in rewrite_ast_base - {rewrite_exp = (fun _ -> fold_exp alg) - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base + { + rewrite_exp = (fun _ -> fold_exp alg); + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; } - let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = (* body is a function : E_id variable -> actual body *) - let (E_aux (_,(l,annot))) = v in + let (E_aux (_, (l, annot))) = v in match destruct_tannot annot with | Some (env, typ) when is_unit_typ typ -> - let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in - let body_typ = try typ_of body with _ -> unit_typ in - let wild = annot_pat P_wild l env typ in - let lb = annot_letbind (unaux_pat wild, v) l env unit_typ in - annot_exp (E_let (lb, body)) l env body_typ - |> add_typs_let env typ body_typ + let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in + let body_typ = try typ_of body with _ -> unit_typ in + let wild = annot_pat P_wild l env typ in + let lb = annot_letbind (unaux_pat wild, v) l env unit_typ in + annot_exp (E_let (lb, body)) l env body_typ |> add_typs_let env typ body_typ | Some (env, typ) -> - let id = fresh_id "w__" l in - let pat = annot_pat (P_id id) l env typ in - let lb = annot_letbind (unaux_pat pat, v) l env typ in - let body = body (annot_exp (E_id id) l env typ) in - annot_exp (E_let (lb, body)) l env (typ_of body) - |> add_typs_let env typ (typ_of body) - | None -> - Reporting.unreachable l __POS__ "no type information" - -let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp = - match l with - | [] -> k [] - | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps))) + let id = fresh_id "w__" l in + let pat = annot_pat (P_id id) l env typ in + let lb = annot_letbind (unaux_pat pat, v) l env typ in + let body = body (annot_exp (E_id id) l env typ) in + annot_exp (E_let (lb, body)) l env (typ_of body) |> add_typs_let env typ (typ_of body) + | None -> Reporting.unreachable l __POS__ "no type information" + +let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp = + match l with [] -> k [] | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps))) let rewrite_ast_letbind_effects effect_info env = - let monadic (E_aux (aux, (l, tannot))) = - E_aux (aux, (l, add_effect_annot tannot monadic_effect)) in + let monadic (E_aux (aux, (l, tannot))) = E_aux (aux, (l, add_effect_annot tannot monadic_effect)) in - let purify (E_aux (aux, (l, tannot))) = - E_aux (aux, (l, add_effect_annot tannot no_effect)) in - - let value ((E_aux (exp_aux,_)) as exp) = - not (effectful exp || updates_vars exp) in + let purify (E_aux (aux, (l, tannot))) = E_aux (aux, (l, add_effect_annot tannot no_effect)) in + + let value (E_aux (exp_aux, _) as exp) = not (effectful exp || updates_vars exp) in let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = n_exp exp (fun exp -> if value exp then k exp else monadic (letbind exp k)) - and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = n_exp exp (fun exp -> if value exp then k exp else monadic (letbind exp k)) - - and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = - mapCont n_exp_name exps k - + and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = mapCont n_exp_name exps k and n_fexp (fexp : 'a fexp) (k : 'a fexp -> 'a exp) : 'a exp = - let (FE_aux (FE_fexp (id,exp),annot)) = fexp in - n_exp_name exp (fun exp -> - k (FE_aux (FE_fexp (id,exp),annot))) - - and n_fexpL (fexps : 'a fexp list) (k : 'a fexp list -> 'a exp) : 'a exp = - mapCont n_fexp fexps k - - and n_pexp : 'b. bool -> 'a pexp -> ('a pexp -> 'b) -> 'b = fun newreturn pexp k -> + let (FE_aux (FE_fexp (id, exp), annot)) = fexp in + n_exp_name exp (fun exp -> k (FE_aux (FE_fexp (id, exp), annot))) + and n_fexpL (fexps : 'a fexp list) (k : 'a fexp list -> 'a exp) : 'a exp = mapCont n_fexp fexps k + and n_pexp : 'b. bool -> 'a pexp -> ('a pexp -> 'b) -> 'b = + fun newreturn pexp k -> match pexp with - | Pat_aux (Pat_exp (pat,exp),annot) -> - k (Pat_aux (Pat_exp (pat, n_exp_term newreturn exp), annot)) - | Pat_aux (Pat_when (pat,guard,exp),annot) -> - k (Pat_aux (Pat_when (pat, n_exp_term newreturn guard, n_exp_term newreturn exp), annot)) - + | Pat_aux (Pat_exp (pat, exp), annot) -> k (Pat_aux (Pat_exp (pat, n_exp_term newreturn exp), annot)) + | Pat_aux (Pat_when (pat, guard, exp), annot) -> + k (Pat_aux (Pat_when (pat, n_exp_term newreturn guard, n_exp_term newreturn exp), annot)) and n_pexpL (newreturn : bool) (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp = mapCont (n_pexp newreturn) pexps k - and n_lb (lb : 'a letbind) (k : 'a letbind -> 'a exp) : 'a exp = - let (LB_aux (lb,annot)) = lb in - match lb with - | LB_val (pat,exp1) -> - n_exp exp1 (fun exp1 -> - k (LB_aux (LB_val (pat,exp1),annot))) - + let (LB_aux (lb, annot)) = lb in + match lb with LB_val (pat, exp1) -> n_exp exp1 (fun exp1 -> k (LB_aux (LB_val (pat, exp1), annot))) and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp = - let (LE_aux (lexp_aux,annot)) = lexp in + let (LE_aux (lexp_aux, annot)) = lexp in match lexp_aux with | LE_id _ -> k lexp - | LE_deref exp -> - n_exp_name exp (fun exp -> - k (LE_aux (LE_deref exp, annot))) - | LE_app (id,es) -> - n_exp_nameL es (fun es -> - k (LE_aux (LE_app (id,es),annot))) - | LE_tuple es -> - n_lexpL es (fun es -> - k (LE_aux (LE_tuple es,annot))) - | LE_typ (typ,id) -> - k (LE_aux (LE_typ (typ,id),annot)) - | LE_vector (lexp,e) -> - n_lexp lexp (fun lexp -> - n_exp_name e (fun e -> - k (LE_aux (LE_vector (lexp,e),annot)))) - | LE_vector_range (lexp,e1,e2) -> - n_lexp lexp (fun lexp -> - n_exp_name e1 (fun e1 -> - n_exp_name e2 (fun e2 -> - k (LE_aux (LE_vector_range (lexp,e1,e2),annot))))) - | LE_vector_concat es -> - n_lexpL es (fun es -> - k (LE_aux (LE_vector_concat es,annot))) - | LE_field (lexp,id) -> - n_lexp lexp (fun lexp -> - k (LE_aux (LE_field (lexp,id),annot))) - - and n_lexpL (lexps : 'a lexp list) (k : 'a lexp list -> 'a exp) : 'a exp = - mapCont n_lexp lexps k - - and n_exp_term ?cast:(cast=false) (newreturn : bool) (exp : 'a exp) : 'a exp = - let (E_aux (_,(l,tannot))) = exp in + | LE_deref exp -> n_exp_name exp (fun exp -> k (LE_aux (LE_deref exp, annot))) + | LE_app (id, es) -> n_exp_nameL es (fun es -> k (LE_aux (LE_app (id, es), annot))) + | LE_tuple es -> n_lexpL es (fun es -> k (LE_aux (LE_tuple es, annot))) + | LE_typ (typ, id) -> k (LE_aux (LE_typ (typ, id), annot)) + | LE_vector (lexp, e) -> n_lexp lexp (fun lexp -> n_exp_name e (fun e -> k (LE_aux (LE_vector (lexp, e), annot)))) + | LE_vector_range (lexp, e1, e2) -> + n_lexp lexp (fun lexp -> + n_exp_name e1 (fun e1 -> n_exp_name e2 (fun e2 -> k (LE_aux (LE_vector_range (lexp, e1, e2), annot)))) + ) + | LE_vector_concat es -> n_lexpL es (fun es -> k (LE_aux (LE_vector_concat es, annot))) + | LE_field (lexp, id) -> n_lexp lexp (fun lexp -> k (LE_aux (LE_field (lexp, id), annot))) + and n_lexpL (lexps : 'a lexp list) (k : 'a lexp list -> 'a exp) : 'a exp = mapCont n_lexp lexps k + and n_exp_term ?(cast = false) (newreturn : bool) (exp : 'a exp) : 'a exp = + let (E_aux (_, (l, tannot))) = exp in let exp = if newreturn then ( (* let typ = try typ_of exp with _ -> unit_typ in *) let exp = if cast then add_e_typ (env_of exp) (typ_of exp) exp else exp in annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp) - ) else exp + ) + else exp in - (* n_exp_term forces an expression to be translated into a form + (* n_exp_term forces an expression to be translated into a form "let .. let .. let .. in EXP" where EXP has no effect and does not update variables *) n_exp_pure exp (fun exp -> exp) - - and n_exp (E_aux (exp_aux,annot) as exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = - let rewrap e = E_aux (e,annot) in + and n_exp (E_aux (exp_aux, annot) as exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = + let rewrap e = E_aux (e, annot) in let pure_rewrap e = purify (rewrap e) in match exp_aux with | E_block es -> failwith "E_block should have been removed till now" | E_id id -> k exp | E_ref id -> k exp | E_lit _ -> k exp - | E_typ (typ,exp') -> - n_exp_name exp' (fun exp' -> - k (pure_rewrap (E_typ (typ, exp')))) - | E_app (op_bool, [l; r]) - when string_of_id op_bool = "and_bool" || string_of_id op_bool = "or_bool" -> - (* Leave effectful operands of Boolean "and"/"or" in place to allow - short-circuiting. *) - let newreturn = effectful l || effectful r in - let l = n_exp_term ~cast:true newreturn l in - let r = n_exp_term ~cast:true newreturn r in - k (rewrap (E_app (op_bool, [l; r]))) + | E_typ (typ, exp') -> n_exp_name exp' (fun exp' -> k (pure_rewrap (E_typ (typ, exp')))) + | E_app (op_bool, [l; r]) when string_of_id op_bool = "and_bool" || string_of_id op_bool = "or_bool" -> + (* Leave effectful operands of Boolean "and"/"or" in place to allow + short-circuiting. *) + let newreturn = effectful l || effectful r in + let l = n_exp_term ~cast:true newreturn l in + let r = n_exp_term ~cast:true newreturn r in + k (rewrap (E_app (op_bool, [l; r]))) | E_app (id, exps) -> - let fix_eff = if Effects.function_is_pure id effect_info then purify else (fun exp -> exp) in - n_exp_nameL exps (fun exps -> - k (fix_eff (rewrap (E_app (id, exps))))) + let fix_eff = if Effects.function_is_pure id effect_info then purify else fun exp -> exp in + n_exp_nameL exps (fun exps -> k (fix_eff (rewrap (E_app (id, exps))))) | E_app_infix (exp1, id, exp2) -> - let fix_eff = if Effects.function_is_pure id effect_info then purify else (fun exp -> exp) in - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - k (fix_eff (rewrap (E_app_infix (exp1, id, exp2)))))) - | E_tuple exps -> - n_exp_nameL exps (fun exps -> - k (pure_rewrap (E_tuple exps))) - | E_if (exp1,exp2,exp3) -> - let e_if exp1 = - let newreturn = effectful exp2 || effectful exp3 in - let exp2 = n_exp_term newreturn exp2 in - let exp3 = n_exp_term newreturn exp3 in - k (rewrap (E_if (exp1,exp2,exp3))) - in - if value exp1 then e_if (n_exp_term false exp1) else n_exp_name exp1 e_if - | E_for (id,start,stop,by,dir,body) -> - n_exp_name start (fun start -> - n_exp_name stop (fun stop -> - n_exp_name by (fun by -> - let body = n_exp_term (effectful body) body in - k (rewrap (E_for (id,start,stop,by,dir,body)))))) + let fix_eff = if Effects.function_is_pure id effect_info then purify else fun exp -> exp in + n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> k (fix_eff (rewrap (E_app_infix (exp1, id, exp2)))))) + | E_tuple exps -> n_exp_nameL exps (fun exps -> k (pure_rewrap (E_tuple exps))) + | E_if (exp1, exp2, exp3) -> + let e_if exp1 = + let newreturn = effectful exp2 || effectful exp3 in + let exp2 = n_exp_term newreturn exp2 in + let exp3 = n_exp_term newreturn exp3 in + k (rewrap (E_if (exp1, exp2, exp3))) + in + if value exp1 then e_if (n_exp_term false exp1) else n_exp_name exp1 e_if + | E_for (id, start, stop, by, dir, body) -> + n_exp_name start (fun start -> + n_exp_name stop (fun stop -> + n_exp_name by (fun by -> + let body = n_exp_term (effectful body) body in + k (rewrap (E_for (id, start, stop, by, dir, body))) + ) + ) + ) | E_loop (loop, measure, cond, body) -> - let measure = match measure with - | Measure_aux (Measure_none,_) -> measure - | Measure_aux (Measure_some exp,l) -> - Measure_aux (Measure_some (n_exp_term false exp),l) - in - let cond = n_exp_term ~cast:true (effectful cond) cond in - let body = n_exp_term (effectful body) body in - k (rewrap (E_loop (loop,measure,cond,body))) - | E_vector exps -> - n_exp_nameL exps (fun exps -> - k (pure_rewrap (E_vector exps))) - | E_vector_access (exp1,exp2) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - k (pure_rewrap (E_vector_access (exp1,exp2))))) - | E_vector_subrange (exp1,exp2,exp3) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - n_exp_name exp3 (fun exp3 -> - k (pure_rewrap (E_vector_subrange (exp1,exp2,exp3)))))) - | E_vector_update (exp1,exp2,exp3) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - n_exp_name exp3 (fun exp3 -> - k (pure_rewrap (E_vector_update (exp1,exp2,exp3)))))) - | E_vector_update_subrange (exp1,exp2,exp3,exp4) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - n_exp_name exp3 (fun exp3 -> - n_exp_name exp4 (fun exp4 -> - k (pure_rewrap (E_vector_update_subrange (exp1,exp2,exp3,exp4))))))) - | E_vector_append (exp1,exp2) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - k (pure_rewrap (E_vector_append (exp1,exp2))))) - | E_list exps -> - n_exp_nameL exps (fun exps -> - k (pure_rewrap (E_list exps))) - | E_cons (exp1,exp2) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - k (pure_rewrap (E_cons (exp1,exp2))))) - | E_struct fexps -> - n_fexpL fexps (fun fexps -> - k (pure_rewrap (E_struct fexps))) - | E_struct_update (exp1,fexps) -> - n_exp_name exp1 (fun exp1 -> - n_fexpL fexps (fun fexps -> - k (pure_rewrap (E_struct_update (exp1,fexps))))) - | E_field (exp1,id) -> - n_exp_name exp1 (fun exp1 -> - k (pure_rewrap (E_field (exp1,id)))) - | E_match (exp1,pexps) -> - let newreturn = List.exists effectful_pexp pexps in - n_exp_name exp1 (fun exp1 -> - n_pexpL newreturn pexps (fun pexps -> - k (rewrap (E_match (exp1,pexps))))) - | E_try (exp1,pexps) -> - let newreturn = effectful exp1 || List.exists effectful_pexp pexps in - let exp1 = n_exp_term newreturn exp1 in - n_pexpL newreturn pexps (fun pexps -> - k (rewrap (E_try (exp1,pexps)))) - | E_let (lb,body) -> - n_lb lb (fun lb -> - rewrap (E_let (lb, n_exp body k))) - | E_sizeof nexp -> - k (rewrap (E_sizeof nexp)) - | E_constraint nc -> - k (rewrap (E_constraint nc)) - | E_assign (lexp,exp1) -> - n_lexp lexp (fun lexp -> - n_exp_name exp1 (fun exp1 -> - k (rewrap (E_assign (lexp,exp1))))) - | E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'),annot)) - | E_assert (exp1,exp2) -> - n_exp_name exp1 (fun exp1 -> - n_exp_name exp2 (fun exp2 -> - k (rewrap (E_assert (exp1, exp2))))) - | E_var (lexp,exp1,exp2) -> - n_lexp lexp (fun lexp -> - n_exp exp1 (fun exp1 -> - rewrap (E_var (lexp,exp1,n_exp exp2 k)))) + let measure = + match measure with + | Measure_aux (Measure_none, _) -> measure + | Measure_aux (Measure_some exp, l) -> Measure_aux (Measure_some (n_exp_term false exp), l) + in + let cond = n_exp_term ~cast:true (effectful cond) cond in + let body = n_exp_term (effectful body) body in + k (rewrap (E_loop (loop, measure, cond, body))) + | E_vector exps -> n_exp_nameL exps (fun exps -> k (pure_rewrap (E_vector exps))) + | E_vector_access (exp1, exp2) -> + n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> k (pure_rewrap (E_vector_access (exp1, exp2))))) + | E_vector_subrange (exp1, exp2, exp3) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> k (pure_rewrap (E_vector_subrange (exp1, exp2, exp3)))) + ) + ) + | E_vector_update (exp1, exp2, exp3) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> k (pure_rewrap (E_vector_update (exp1, exp2, exp3)))) + ) + ) + | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + n_exp_name exp4 (fun exp4 -> k (pure_rewrap (E_vector_update_subrange (exp1, exp2, exp3, exp4)))) + ) + ) + ) + | E_vector_append (exp1, exp2) -> + n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> k (pure_rewrap (E_vector_append (exp1, exp2))))) + | E_list exps -> n_exp_nameL exps (fun exps -> k (pure_rewrap (E_list exps))) + | E_cons (exp1, exp2) -> + n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> k (pure_rewrap (E_cons (exp1, exp2))))) + | E_struct fexps -> n_fexpL fexps (fun fexps -> k (pure_rewrap (E_struct fexps))) + | E_struct_update (exp1, fexps) -> + n_exp_name exp1 (fun exp1 -> n_fexpL fexps (fun fexps -> k (pure_rewrap (E_struct_update (exp1, fexps))))) + | E_field (exp1, id) -> n_exp_name exp1 (fun exp1 -> k (pure_rewrap (E_field (exp1, id)))) + | E_match (exp1, pexps) -> + let newreturn = List.exists effectful_pexp pexps in + n_exp_name exp1 (fun exp1 -> n_pexpL newreturn pexps (fun pexps -> k (rewrap (E_match (exp1, pexps))))) + | E_try (exp1, pexps) -> + let newreturn = effectful exp1 || List.exists effectful_pexp pexps in + let exp1 = n_exp_term newreturn exp1 in + n_pexpL newreturn pexps (fun pexps -> k (rewrap (E_try (exp1, pexps)))) + | E_let (lb, body) -> n_lb lb (fun lb -> rewrap (E_let (lb, n_exp body k))) + | E_sizeof nexp -> k (rewrap (E_sizeof nexp)) + | E_constraint nc -> k (rewrap (E_constraint nc)) + | E_assign (lexp, exp1) -> n_lexp lexp (fun lexp -> n_exp_name exp1 (fun exp1 -> k (rewrap (E_assign (lexp, exp1))))) + | E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'), annot)) + | E_assert (exp1, exp2) -> + n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> k (rewrap (E_assert (exp1, exp2))))) + | E_var (lexp, exp1, exp2) -> + n_lexp lexp (fun lexp -> n_exp exp1 (fun exp1 -> rewrap (E_var (lexp, exp1, n_exp exp2 k)))) | E_internal_return exp1 -> - let is_early_return = function - | E_aux (E_app (id, _), _) -> string_of_id id = "early_return" - | _ -> false in - n_exp_name exp1 (fun exp1 -> - k (if effectful exp1 || is_early_return exp1 then exp1 else rewrap (E_internal_return exp1))) - | E_internal_value v -> - k (rewrap (E_internal_value v)) - | E_return exp' -> - n_exp_name exp' (fun exp' -> - k (pure_rewrap (E_return exp'))) - | E_throw exp' -> - n_exp_name exp' (fun exp' -> - k (rewrap (E_throw exp'))) - | E_internal_assume (nc, exp') -> - rewrap (E_internal_assume (nc, n_exp exp' k)) - | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" in - - let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,funcls),fdannot)) = + let is_early_return = function E_aux (E_app (id, _), _) -> string_of_id id = "early_return" | _ -> false in + n_exp_name exp1 (fun exp1 -> + k (if effectful exp1 || is_early_return exp1 then exp1 else rewrap (E_internal_return exp1)) + ) + | E_internal_value v -> k (rewrap (E_internal_value v)) + | E_return exp' -> n_exp_name exp' (fun exp' -> k (pure_rewrap (E_return exp'))) + | E_throw exp' -> n_exp_name exp' (fun exp' -> k (rewrap (E_throw exp'))) + | E_internal_assume (nc, exp') -> rewrap (E_internal_assume (nc, n_exp exp' k)) + | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" + in + + let rewrite_fun _ (FD_aux (FD_function (recopt, tannotopt, funcls), fdannot)) = (* TODO EFFECT *) let effectful_vs = false in - let effectful_funcl (FCL_aux (FCL_funcl(_, pexp), _)) = effectful_pexp pexp in + let effectful_funcl (FCL_aux (FCL_funcl (_, pexp), _)) = effectful_pexp pexp in let newreturn = effectful_vs || List.exists effectful_funcl funcls in - let rewrite_funcl (FCL_aux (FCL_funcl(id,pexp),annot)) = + let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), annot)) = let _ = reset_fresh_name_counter () in - FCL_aux (FCL_funcl (id,n_pexp newreturn pexp (fun x -> x)),annot) + FCL_aux (FCL_funcl (id, n_pexp newreturn pexp (fun x -> x)), annot) in - FD_aux (FD_function(recopt,tannotopt,List.map rewrite_funcl funcls),fdannot) + FD_aux (FD_function (recopt, tannotopt, List.map rewrite_funcl funcls), fdannot) in let rewrite_def rewriters (DEF_aux (aux, def_annot)) = - let aux = match aux with + let aux = + match aux with | DEF_let (LB_aux (lb, annot)) -> - let rewrap lb = DEF_let (LB_aux (lb, annot)) in - begin - match lb with - | LB_val (pat, exp) -> - rewrap (LB_val (pat, n_exp_term (effectful exp) exp)) - end + let rewrap lb = DEF_let (LB_aux (lb, annot)) in + begin + match lb with LB_val (pat, exp) -> rewrap (LB_val (pat, n_exp_term (effectful exp) exp)) + end | DEF_fundef fdef -> DEF_fundef (rewrite_fun rewriters fdef) - | DEF_internal_mutrec fdefs -> - DEF_internal_mutrec (List.map (rewrite_fun rewriters) fdefs) - | _ -> aux in + | DEF_internal_mutrec fdefs -> DEF_internal_mutrec (List.map (rewrite_fun rewriters) fdefs) + | _ -> aux + in DEF_aux (aux, def_annot) in - (fun ast -> - rewrite_ast_base - { rewrite_exp = rewrite_exp - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base - } ast, - effect_info, - env) + fun ast -> + ( rewrite_ast_base + { + rewrite_exp; + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; + } + ast, + effect_info, + env + ) let rewrite_ast_internal_lets env = - - let rec pat_of_local_lexp (LE_aux (lexp, ((l, _) as annot))) = match lexp with + let rec pat_of_local_lexp (LE_aux (lexp, ((l, _) as annot))) = + match lexp with | LE_id id -> P_aux (P_id id, annot) | LE_typ (typ, id) -> add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot)) | LE_tuple lexps -> P_aux (P_tuple (List.map pat_of_local_lexp lexps), annot) - | _ -> raise (Reporting.err_unreachable l __POS__ "unexpected local lexp") in + | _ -> raise (Reporting.err_unreachable l __POS__ "unexpected local lexp") + in - let e_let (lb,body) = + let e_let (lb, body) = match lb with - | LB_aux (LB_val (P_aux ((P_wild | P_typ (_, P_aux (P_wild, _))), _), - E_aux (E_assign ((LE_aux (_, annot) as le), exp), (l, _))), _) + | LB_aux + ( LB_val + ( P_aux ((P_wild | P_typ (_, P_aux (P_wild, _))), _), + E_aux (E_assign ((LE_aux (_, annot) as le), exp), (l, _)) + ), + _ + ) when lexp_is_local le (env_of_annot annot) && not (lexp_is_effectful le) -> - (* Rewrite assignments to local variables into let bindings *) - let (lhs, rhs) = rewrite_lexp_to_rhs le in - let (LE_aux (_, lannot)) = lhs in - let ltyp = typ_of_annot - (* The type in the lannot might come from exp rather than being the - type of the storage, so ask the type checker what it really is. *) - (match infer_lexp (env_of_annot lannot) (strip_lexp lhs) with - | LE_aux (_,lexp_annot') -> lexp_annot' - | exception _ -> lannot) - in - let rhs = add_e_typ (env_of exp) ltyp (rhs exp) in - E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs), annot), body) - | LB_aux (LB_val (pat,exp'),annot') -> - if effectful exp' - then E_internal_plet (pat,exp',body) - else E_let (lb,body) in - - let e_var = fun (lexp,exp1,exp2) -> - let paux, annot = match lexp with - | LE_aux (LE_id id, annot) -> - (P_id id, annot) - | LE_aux (LE_typ (typ, id), annot) -> - (unaux_pat (add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot))), annot) - | _ -> failwith "E_var with unexpected lexp" in - if effectful exp1 then - E_internal_plet (P_aux (paux, annot), exp1, exp2) - else - E_let (LB_aux (LB_val (P_aux (paux, annot), exp1), annot), exp2) in - - let alg = { id_exp_alg with e_let = e_let; e_var = e_var } in + (* Rewrite assignments to local variables into let bindings *) + let lhs, rhs = rewrite_lexp_to_rhs le in + let (LE_aux (_, lannot)) = lhs in + let ltyp = + typ_of_annot + (* The type in the lannot might come from exp rather than being the + type of the storage, so ask the type checker what it really is. *) + ( match infer_lexp (env_of_annot lannot) (strip_lexp lhs) with + | LE_aux (_, lexp_annot') -> lexp_annot' + | exception _ -> lannot + ) + in + let rhs = add_e_typ (env_of exp) ltyp (rhs exp) in + E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs), annot), body) + | LB_aux (LB_val (pat, exp'), annot') -> + if effectful exp' then E_internal_plet (pat, exp', body) else E_let (lb, body) + in + + let e_var (lexp, exp1, exp2) = + let paux, annot = + match lexp with + | LE_aux (LE_id id, annot) -> (P_id id, annot) + | LE_aux (LE_typ (typ, id), annot) -> + (unaux_pat (add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot))), annot) + | _ -> failwith "E_var with unexpected lexp" + in + if effectful exp1 then E_internal_plet (P_aux (paux, annot), exp1, exp2) + else E_let (LB_aux (LB_val (P_aux (paux, annot), exp1), annot), exp2) + in + + let alg = { id_exp_alg with e_let; e_var } in rewrite_ast_base - { rewrite_exp = (fun _ exp -> fold_exp alg exp) - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base + { + rewrite_exp = (fun _ exp -> fold_exp alg exp); + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; } let fold_typed_guards env guards = match guards with | [] -> annot_exp (E_lit (mk_lit L_true)) Parse_ast.Unknown env bool_typ - | g :: gs -> List.fold_left (fun g g' -> annot_exp (E_app (mk_id "and_bool", [g; g'])) Parse_ast.Unknown env bool_typ) g gs + | g :: gs -> + List.fold_left (fun g g' -> annot_exp (E_app (mk_id "and_bool", [g; g'])) Parse_ast.Unknown env bool_typ) g gs let pexp_rewriters rewrite_pexp = let alg = { id_exp_alg with pat_aux = (fun (pexp_aux, annot) -> rewrite_pexp (Pat_aux (pexp_aux, annot))) } in @@ -2483,7 +2461,7 @@ let pexp_rewriters rewrite_pexp = let stringappend_counter = ref 0 let fresh_stringappend_id () = - let id = mk_id ("_s" ^ (string_of_int !stringappend_counter) ^ "#") in + let id = mk_id ("_s" ^ string_of_int !stringappend_counter ^ "#") in stringappend_counter := !stringappend_counter + 1; id @@ -2493,15 +2471,13 @@ let unkt = (Parse_ast.Unknown, empty_tannot) let construct_bool_match env (match_on : tannot exp) (pexp : tannot pexp) : tannot exp = let true_exp = annot_exp (E_lit (mk_lit L_true)) unk env bool_typ in let false_exp = annot_exp (E_lit (mk_lit L_false)) unk env bool_typ in - let true_pexp = - match pexp with - | Pat_aux (Pat_exp (pat, exp), annot) -> - Pat_aux (Pat_exp (pat, true_exp), unkt) - | Pat_aux (Pat_when (pat, guards, exp), annot) -> - Pat_aux (Pat_when (pat, guards, true_exp), unkt) - in - let false_pexp = Pat_aux (Pat_exp (annot_pat P_wild unk env (typ_of match_on), false_exp), unkt) in - annot_exp (E_match (match_on, [true_pexp; false_pexp])) unk env bool_typ + let true_pexp = + match pexp with + | Pat_aux (Pat_exp (pat, exp), annot) -> Pat_aux (Pat_exp (pat, true_exp), unkt) + | Pat_aux (Pat_when (pat, guards, exp), annot) -> Pat_aux (Pat_when (pat, guards, true_exp), unkt) + in + let false_pexp = Pat_aux (Pat_exp (annot_pat P_wild unk env (typ_of match_on), false_exp), unkt) in + annot_exp (E_match (match_on, [true_pexp; false_pexp])) unk env bool_typ let rec bindings_of_pat (P_aux (p_aux, p_annot) as pat) = match p_aux with @@ -2514,15 +2490,9 @@ let rec bindings_of_pat (P_aux (p_aux, p_annot) as pat) = | P_as (p, id) -> [annot_pat (P_id id) unk (env_of_pat p) (typ_of_pat p)] | P_cons (left, right) -> bindings_of_pat left @ bindings_of_pat right (* todo: is this right for negated patterns? *) - | P_not p - | P_typ (_, p) - | P_var (p, _) -> bindings_of_pat p - | P_app (_, ps) - | P_vector ps - | P_vector_concat ps - | P_tuple ps - | P_list ps - | P_string_append ps -> List.map bindings_of_pat ps |> List.flatten + | P_not p | P_typ (_, p) | P_var (p, _) -> bindings_of_pat p + | P_app (_, ps) | P_vector ps | P_vector_concat ps | P_tuple ps | P_list ps | P_string_append ps -> + List.map bindings_of_pat ps |> List.flatten let rec binding_typs_of_pat (P_aux (p_aux, p_annot) as pat) = match p_aux with @@ -2535,15 +2505,9 @@ let rec binding_typs_of_pat (P_aux (p_aux, p_annot) as pat) = | P_as (p, id) -> [typ_of_pat p] | P_cons (left, right) -> binding_typs_of_pat left @ binding_typs_of_pat right (* todo: is this right for negated patterns? *) - | P_not p - | P_typ (_, p) - | P_var (p, _) -> binding_typs_of_pat p - | P_app (_, ps) - | P_vector ps - | P_vector_concat ps - | P_tuple ps - | P_list ps - | P_string_append ps -> List.map binding_typs_of_pat ps |> List.flatten + | P_not p | P_typ (_, p) | P_var (p, _) -> binding_typs_of_pat p + | P_app (_, ps) | P_vector ps | P_vector_concat ps | P_tuple ps | P_list ps | P_string_append ps -> + List.map binding_typs_of_pat ps |> List.flatten let construct_toplevel_string_append_call env f_id bindings binding_typs guard expr = (* s# if match f#(s#) { Some (bindings) => guard, _ => false) } => let Some(bindings) = f#(s#) in expr *) @@ -2554,38 +2518,49 @@ let construct_toplevel_string_append_call env f_id bindings binding_typs guard e | Typ_app (Id_aux (Id "atom", _), [_]) -> int_typ | _ -> typ in - let option_typ = app_typ (mk_id "option") [A_aux (A_typ (match binding_typs with - | [] -> unit_typ - | [typ] -> hack_typ typ - | typs -> tuple_typ (List.map hack_typ typs) - ), unk)] - in - let bindings = if bindings = [] then - [annot_pat (P_lit (mk_lit L_unit)) unk env unit_typ] - else - bindings + let option_typ = + app_typ (mk_id "option") + [ + A_aux + ( A_typ + ( match binding_typs with + | [] -> unit_typ + | [typ] -> hack_typ typ + | typs -> tuple_typ (List.map hack_typ typs) + ), + unk + ); + ] in + let bindings = if bindings = [] then [annot_pat (P_lit (mk_lit L_unit)) unk env unit_typ] else bindings in let new_pat = annot_pat (P_id s_id) unk env string_typ in - let new_guard = annot_exp ( - E_match (annot_exp (E_app (f_id, [annot_exp (E_id s_id) unk env string_typ])) unk env option_typ, - [ - Pat_aux (Pat_exp (annot_pat (P_app (mk_id "Some", bindings)) unk env option_typ, guard), unkt); - Pat_aux (Pat_exp (annot_pat P_wild unk env option_typ, annot_exp (E_lit (mk_lit L_false)) unk env bool_typ), unkt) - ] - ) - ) unk env bool_typ in - let new_letbind = annot_letbind (P_app (mk_id "Some", bindings), annot_exp (E_app (f_id, [annot_exp (E_id s_id) unk env string_typ])) unk env option_typ) unk env option_typ in + let new_guard = + annot_exp + (E_match + ( annot_exp (E_app (f_id, [annot_exp (E_id s_id) unk env string_typ])) unk env option_typ, + [ + Pat_aux (Pat_exp (annot_pat (P_app (mk_id "Some", bindings)) unk env option_typ, guard), unkt); + Pat_aux + (Pat_exp (annot_pat P_wild unk env option_typ, annot_exp (E_lit (mk_lit L_false)) unk env bool_typ), unkt); + ] + ) + ) + unk env bool_typ + in + let new_letbind = + annot_letbind + ( P_app (mk_id "Some", bindings), + annot_exp (E_app (f_id, [annot_exp (E_id s_id) unk env string_typ])) unk env option_typ + ) + unk env option_typ + in let new_expr = annot_exp (E_let (new_letbind, expr)) unk env (typ_of expr) in (new_pat, [new_guard], new_expr) let construct_toplevel_string_append_func effect_info env f_id pat = let binding_typs = binding_typs_of_pat pat in let bindings = bindings_of_pat pat in - let bindings = if bindings = [] then - [annot_pat (P_lit (mk_lit L_unit)) unk env unit_typ] - else - bindings - in + let bindings = if bindings = [] then [annot_pat (P_lit (mk_lit L_unit)) unk env unit_typ] else bindings in (* AA: Pulling the types out of a pattern with binding_typs_of_pat is broken here because they might contain type variables that were bound locally to the pattern, so we can't lift them out to @@ -2597,115 +2572,171 @@ let construct_toplevel_string_append_func effect_info env f_id pat = | Typ_app (Id_aux (Id "atom", _), [_]) -> int_typ | _ -> typ in - let option_typ = app_typ (mk_id "option") [A_aux (A_typ (match binding_typs with - | [] -> unit_typ - | [typ] -> hack_typ typ - | typs -> tuple_typ (List.map hack_typ typs) - ), unk)] + let option_typ = + app_typ (mk_id "option") + [ + A_aux + ( A_typ + ( match binding_typs with + | [] -> unit_typ + | [typ] -> hack_typ typ + | typs -> tuple_typ (List.map hack_typ typs) + ), + unk + ); + ] + in + let fun_typ = mk_typ (Typ_fn ([string_typ], option_typ)) in + let new_val_spec = + VS_aux (VS_val_spec (mk_typschm (TypQ_aux (TypQ_no_forall, unk)) fun_typ, f_id, None, false), no_annot) in - let fun_typ = (mk_typ (Typ_fn ([string_typ], option_typ))) in - let new_val_spec = VS_aux (VS_val_spec (mk_typschm (TypQ_aux (TypQ_no_forall, unk)) fun_typ, f_id, None, false), no_annot) in let new_val_spec, env = Type_check.check_val_spec env (mk_def_annot Parse_ast.Unknown) new_val_spec in - let non_rec = (Rec_aux (Rec_nonrec, Parse_ast.Unknown)) in - let no_tannot = (Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown)) in + let non_rec = Rec_aux (Rec_nonrec, Parse_ast.Unknown) in + let no_tannot = Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown) in let s_id = fresh_stringappend_id () in let arg_pat = mk_pat (P_id s_id) in (* We can ignore guards here because we've already removed them *) let rec rewrite_pat env (pat, guards, expr) = match pat with - (* "lit" ^ pat2 ^ ... ^ patn => Some(...) ---> s# if startswith(s#, "lit") => match string_drop(s#, strlen("lit")) { - pat2 => Some(...) - _ => None() - } - *) - | P_aux (P_string_append ( - P_aux (P_lit (L_aux (L_string s, _) as lit), _) - :: pats - ), psa_annot) -> - let s_id = fresh_stringappend_id () in - let drop_exp = annot_exp (E_app (mk_id "string_drop", [annot_exp (E_id s_id) unk env string_typ; annot_exp (E_app (mk_id "string_length", [annot_exp (E_lit lit) unk env string_typ])) unk env nat_typ])) unk env string_typ in - (* recurse into pat2 .. patn *) - let new_pat2_pexp = - match rewrite_pat env (P_aux (P_string_append pats, psa_annot), guards, expr) with - | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) - | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) - in - let new_guard = annot_exp (E_app (mk_id "string_startswith", [annot_exp (E_id s_id) unk env string_typ; - annot_exp (E_lit lit) unk env string_typ] - )) unk env bool_typ in - let new_wildcard = Pat_aux (Pat_exp (annot_pat P_wild unk env string_typ, annot_exp (E_app (mk_id "None", [annot_exp (E_lit (mk_lit L_unit)) unk env unit_typ])) unk env option_typ), unkt) in - let new_expr = annot_exp (E_match (drop_exp, [new_pat2_pexp; new_wildcard])) unk env (typ_of expr) in - (annot_pat (P_id s_id) unk env string_typ, [new_guard], new_expr) + (* "lit" ^ pat2 ^ ... ^ patn => Some(...) ---> s# if startswith(s#, "lit") => match string_drop(s#, strlen("lit")) { + pat2 => Some(...) + _ => None() + } + *) + | P_aux (P_string_append (P_aux (P_lit (L_aux (L_string s, _) as lit), _) :: pats), psa_annot) -> + let s_id = fresh_stringappend_id () in + let drop_exp = + annot_exp + (E_app + ( mk_id "string_drop", + [ + annot_exp (E_id s_id) unk env string_typ; + annot_exp (E_app (mk_id "string_length", [annot_exp (E_lit lit) unk env string_typ])) unk env nat_typ; + ] + ) + ) + unk env string_typ + in + (* recurse into pat2 .. patn *) + let new_pat2_pexp = + match rewrite_pat env (P_aux (P_string_append pats, psa_annot), guards, expr) with + | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) + | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) + in + let new_guard = + annot_exp + (E_app + ( mk_id "string_startswith", + [annot_exp (E_id s_id) unk env string_typ; annot_exp (E_lit lit) unk env string_typ] + ) + ) + unk env bool_typ + in + let new_wildcard = + Pat_aux + ( Pat_exp + ( annot_pat P_wild unk env string_typ, + annot_exp + (E_app (mk_id "None", [annot_exp (E_lit (mk_lit L_unit)) unk env unit_typ])) + unk env option_typ + ), + unkt + ) + in + let new_expr = annot_exp (E_match (drop_exp, [new_pat2_pexp; new_wildcard])) unk env (typ_of expr) in + (annot_pat (P_id s_id) unk env string_typ, [new_guard], new_expr) (* mapping(x) ^ pat2 ^ .. ^ patn => Some(...) ---> s# => match map_matches_prefix(s#) { Some(x, n#) => match string_drop(s#, n#) { pat2 ^ .. ^ patn => Some(...) _ => None() } } - *) - | P_aux (P_string_append ( - P_aux (P_app (mapping_id, arg_pats) , _) - :: pats - ), psa_annot) - when Env.is_mapping mapping_id env -> - (* common things *) - let mapping_prefix_func = - match mapping_id with - | Id_aux (Id id, _) - | Id_aux (Operator id, _) -> id ^ "_matches_prefix" - in - let mapping_inner_typ = - match Env.get_val_spec (mk_id mapping_prefix_func) env with - | (_, Typ_aux (Typ_fn (_, Typ_aux (Typ_app (_, [A_aux (A_typ typ, _)]), _)), _)) -> typ - | _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "mapping prefix func without correct function type?") - in - - let s_id = fresh_stringappend_id () in - let len_id = fresh_stringappend_id () in - - (* construct drop expression -- string_drop(s#, len#) *) - let drop_exp = annot_exp (E_app (mk_id "string_drop", - [annot_exp (E_id s_id) unk env string_typ; - annot_exp (E_id len_id) unk env nat_typ])) - unk env string_typ in - (* construct func expression -- maybe_atoi s# *) - let func_exp = annot_exp (E_app (mk_id mapping_prefix_func, - [annot_exp (E_id s_id) unk env string_typ])) - unk env mapping_inner_typ in - effect_info := Effects.add_function_effect (mk_id mapping_prefix_func) !effect_info f_id; - - (* construct some pattern -- Some (n#, len#) *) - let opt_typ = app_typ (mk_id "option") [A_aux (A_typ mapping_inner_typ, unk)] in - let tup_arg_pat = match arg_pats with - | [] -> assert false - | [arg_pat] -> arg_pat - | arg_pats -> annot_pat (P_tuple arg_pats) unk env (tuple_typ (List.map typ_of_pat arg_pats)) - in - - let some_pat = annot_pat (P_app (mk_id "Some", - [tup_arg_pat; - annot_pat (P_id len_id) unk env nat_typ])) - unk env opt_typ in - let some_pat, some_pat_env, _ = bind_pat env (strip_pat some_pat) opt_typ in - - let new_wildcard = Pat_aux (Pat_exp (annot_pat P_wild unk env string_typ, annot_exp (E_app (mk_id "None", [annot_exp (E_lit (mk_lit L_unit)) unk env unit_typ])) unk env option_typ), unkt) in - - (* recurse into pat2 .. patn *) - let new_pat2_pexp = - match rewrite_pat env (P_aux (P_string_append (pats), psa_annot), guards, expr) with - | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) - | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) - in - - let inner_match = annot_exp (E_match (drop_exp, [new_pat2_pexp; new_wildcard])) unk env option_typ in - - let outer_match = annot_exp (E_match (func_exp, [Pat_aux (Pat_exp (some_pat, inner_match), unkt); new_wildcard])) unk env option_typ in - - (annot_pat (P_id s_id) unk env string_typ, [], outer_match) + *) + | P_aux (P_string_append (P_aux (P_app (mapping_id, arg_pats), _) :: pats), psa_annot) + when Env.is_mapping mapping_id env -> + (* common things *) + let mapping_prefix_func = + match mapping_id with Id_aux (Id id, _) | Id_aux (Operator id, _) -> id ^ "_matches_prefix" + in + let mapping_inner_typ = + match Env.get_val_spec (mk_id mapping_prefix_func) env with + | _, Typ_aux (Typ_fn (_, Typ_aux (Typ_app (_, [A_aux (A_typ typ, _)]), _)), _) -> typ + | _ -> + raise + (Reporting.err_unreachable Parse_ast.Unknown __POS__ + "mapping prefix func without correct function type?" + ) + in + + let s_id = fresh_stringappend_id () in + let len_id = fresh_stringappend_id () in + + (* construct drop expression -- string_drop(s#, len#) *) + let drop_exp = + annot_exp + (E_app + (mk_id "string_drop", [annot_exp (E_id s_id) unk env string_typ; annot_exp (E_id len_id) unk env nat_typ]) + ) + unk env string_typ + in + (* construct func expression -- maybe_atoi s# *) + let func_exp = + annot_exp + (E_app (mk_id mapping_prefix_func, [annot_exp (E_id s_id) unk env string_typ])) + unk env mapping_inner_typ + in + effect_info := Effects.add_function_effect (mk_id mapping_prefix_func) !effect_info f_id; + + (* construct some pattern -- Some (n#, len#) *) + let opt_typ = app_typ (mk_id "option") [A_aux (A_typ mapping_inner_typ, unk)] in + let tup_arg_pat = + match arg_pats with + | [] -> assert false + | [arg_pat] -> arg_pat + | arg_pats -> annot_pat (P_tuple arg_pats) unk env (tuple_typ (List.map typ_of_pat arg_pats)) + in + + let some_pat = + annot_pat (P_app (mk_id "Some", [tup_arg_pat; annot_pat (P_id len_id) unk env nat_typ])) unk env opt_typ + in + let some_pat, some_pat_env, _ = bind_pat env (strip_pat some_pat) opt_typ in + + let new_wildcard = + Pat_aux + ( Pat_exp + ( annot_pat P_wild unk env string_typ, + annot_exp + (E_app (mk_id "None", [annot_exp (E_lit (mk_lit L_unit)) unk env unit_typ])) + unk env option_typ + ), + unkt + ) + in + + (* recurse into pat2 .. patn *) + let new_pat2_pexp = + match rewrite_pat env (P_aux (P_string_append pats, psa_annot), guards, expr) with + | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) + | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) + in + + let inner_match = annot_exp (E_match (drop_exp, [new_pat2_pexp; new_wildcard])) unk env option_typ in + + let outer_match = + annot_exp + (E_match (func_exp, [Pat_aux (Pat_exp (some_pat, inner_match), unkt); new_wildcard])) + unk env option_typ + in + + (annot_pat (P_id s_id) unk env string_typ, [], outer_match) | _ -> (pat, guards, expr) in - let new_pat, new_guards, new_expr = rewrite_pat env (pat, [], annot_exp (E_app (mk_id "Some", List.map (fun p -> pat_to_exp p) bindings)) unk env option_typ) in - let new_pexp = match new_guards with + let new_pat, new_guards, new_expr = + rewrite_pat env + (pat, [], annot_exp (E_app (mk_id "Some", List.map (fun p -> pat_to_exp p) bindings)) unk env option_typ) + in + let new_pexp = + match new_guards with | [] -> Pat_aux (Pat_exp (new_pat, new_expr), unkt) | gs -> Pat_aux (Pat_when (new_pat, fold_typed_guards env gs, new_expr), unkt) in @@ -2720,25 +2751,25 @@ let rewrite_ast_toplevel_string_append effect_info env ast = let effect_info = ref effect_info in let rewrite_pexp (Pat_aux (pexp_aux, pexp_annot)) = (* merge cases of Pat_exp and Pat_when *) - let (P_aux (p_aux, p_annot) as pat, guards, expr) = - match pexp_aux with - | Pat_exp (pat, expr) -> (pat, [], expr) - | Pat_when (pat, guard, expr) -> (pat, [guard], expr) + let (P_aux (p_aux, p_annot) as pat), guards, expr = + match pexp_aux with Pat_exp (pat, expr) -> (pat, [], expr) | Pat_when (pat, guard, expr) -> (pat, [guard], expr) in let env = env_of_annot p_annot in - let (new_pat, new_guards, new_expr) = + let new_pat, new_guards, new_expr = match pat with | P_aux (P_string_append appends, psa_annot) -> - let f_id = fresh_stringappend_id () in - new_defs := (construct_toplevel_string_append_func effect_info env f_id pat) @ !new_defs; - construct_toplevel_string_append_call env f_id (bindings_of_pat pat) (binding_typs_of_pat pat) (fold_typed_guards env guards) expr + let f_id = fresh_stringappend_id () in + new_defs := construct_toplevel_string_append_func effect_info env f_id pat @ !new_defs; + construct_toplevel_string_append_call env f_id (bindings_of_pat pat) (binding_typs_of_pat pat) + (fold_typed_guards env guards) expr | _ -> (pat, guards, expr) in (* un-merge Pat_exp and Pat_when cases *) - let new_pexp = match new_guards with + let new_pexp = + match new_guards with | [] -> Pat_aux (Pat_exp (new_pat, new_expr), pexp_annot) | gs -> Pat_aux (Pat_when (new_pat, fold_typed_guards env gs, new_expr), pexp_annot) in @@ -2752,8 +2783,7 @@ let rewrite_ast_toplevel_string_append effect_info env ast = in let new_defs = List.map rewrite ast.defs |> List.flatten in - { ast with defs = new_defs }, !effect_info, env - + ({ ast with defs = new_defs }, !effect_info, env) let rewrite_ast_pat_string_append env = let rec rewrite_pat env ((pat : tannot pat), (guards : tannot exp list), (expr : tannot exp)) = @@ -2775,36 +2805,47 @@ let rewrite_ast_pat_string_append env = pat2 => expr } *) - | P_aux (P_string_append ( - P_aux (P_lit (L_aux (L_string s, _) as lit), _) - :: pats - ), psa_annot) -> - - let id = fresh_stringappend_id () in - - (* construct drop expression -- string_drop(s#, strlen("lit")) *) - let drop_exp = annot_exp (E_app (mk_id "string_drop", [annot_exp (E_id id) unk env string_typ; annot_exp (E_app (mk_id "string_length", [annot_exp (E_lit lit) unk env string_typ])) unk env nat_typ])) unk env string_typ in - - (* recurse into pat2 *) - let new_pat2_pexp = - match rewrite_pat env (P_aux (P_string_append (pats), psa_annot), guards, expr) with - | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) - | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) - in + | P_aux (P_string_append (P_aux (P_lit (L_aux (L_string s, _) as lit), _) :: pats), psa_annot) -> + let id = fresh_stringappend_id () in + + (* construct drop expression -- string_drop(s#, strlen("lit")) *) + let drop_exp = + annot_exp + (E_app + ( mk_id "string_drop", + [ + annot_exp (E_id id) unk env string_typ; + annot_exp (E_app (mk_id "string_length", [annot_exp (E_lit lit) unk env string_typ])) unk env nat_typ; + ] + ) + ) + unk env string_typ + in - (* construct the two new guards *) - let guard1 = annot_exp (E_app (mk_id "string_startswith", - [annot_exp (E_id id) unk env string_typ; - annot_exp (E_lit lit) unk env string_typ] - )) unk env bool_typ in - let guard2 = construct_bool_match env drop_exp new_pat2_pexp in + (* recurse into pat2 *) + let new_pat2_pexp = + match rewrite_pat env (P_aux (P_string_append pats, psa_annot), guards, expr) with + | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) + | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) + in - (* construct new match expr *) - let new_expr = annot_exp (E_match (drop_exp, [new_pat2_pexp])) unk env (typ_of expr) in + (* construct the two new guards *) + let guard1 = + annot_exp + (E_app + ( mk_id "string_startswith", + [annot_exp (E_id id) unk env string_typ; annot_exp (E_lit lit) unk env string_typ] + ) + ) + unk env bool_typ + in + let guard2 = construct_bool_match env drop_exp new_pat2_pexp in - (* construct final result *) - annot_pat (P_id id) unk env string_typ, [guard1; guard2], new_expr + (* construct new match expr *) + let new_expr = annot_exp (E_match (drop_exp, [new_pat2_pexp])) unk env (typ_of expr) in + (* construct final result *) + (annot_pat (P_id id) unk env string_typ, [guard1; guard2], new_expr) (* (builtin x) ^^ pat2 => expr ---> s# if match maybe_atoi s# { Some (n#, len#) => @@ -2820,187 +2861,198 @@ let rewrite_ast_pat_string_append env = pat2 => expr } *) + | P_aux (P_string_append (P_aux (P_app (mapping_id, arg_pats), _) :: pats), psa_annot) + when Env.is_mapping mapping_id env -> + (* common things *) + let mapping_prefix_func = + match mapping_id with Id_aux (Id id, _) | Id_aux (Operator id, _) -> id ^ "_matches_prefix" + in + let mapping_inner_typ = + match Env.get_val_spec (mk_id mapping_prefix_func) env with + | _, Typ_aux (Typ_fn (_, Typ_aux (Typ_app (_, [A_aux (A_typ typ, _)]), _)), _) -> typ + | _ -> typ_error env Parse_ast.Unknown "mapping prefix func without correct function type?" + in - | P_aux (P_string_append ( - P_aux (P_app (mapping_id, arg_pats) , _) - :: pats - ), psa_annot) - when Env.is_mapping mapping_id env -> - (* common things *) - let mapping_prefix_func = - match mapping_id with - | Id_aux (Id id, _) - | Id_aux (Operator id, _) -> id ^ "_matches_prefix" - in - let mapping_inner_typ = - match Env.get_val_spec (mk_id mapping_prefix_func) env with - | (_, Typ_aux (Typ_fn (_, Typ_aux (Typ_app (_, [A_aux (A_typ typ, _)]), _)), _)) -> typ - | _ -> typ_error env Parse_ast.Unknown "mapping prefix func without correct function type?" - in - - let s_id = fresh_stringappend_id () in - let len_id = fresh_stringappend_id () in - - (* construct drop expression -- string_drop(s#, len#) *) - let drop_exp = annot_exp (E_app (mk_id "string_drop", - [annot_exp (E_id s_id) unk env string_typ; - annot_exp (E_id len_id) unk env nat_typ])) - unk env string_typ in - (* construct func expression -- maybe_atoi s# *) - let func_exp = annot_exp (E_app (mk_id mapping_prefix_func, - [annot_exp (E_id s_id) unk env string_typ])) - unk env mapping_inner_typ in - (* construct some pattern -- Some (n#, len#) *) - let opt_typ = app_typ (mk_id "option") [A_aux (A_typ mapping_inner_typ, unk)] in - let tup_arg_pat = match arg_pats with - | [] -> assert false - | [arg_pat] -> arg_pat - | arg_pats -> annot_pat (P_tuple arg_pats) unk env (tuple_typ (List.map typ_of_pat arg_pats)) - in - - let some_pat = annot_pat (P_app (mk_id "Some", - [tup_arg_pat; - annot_pat (P_id len_id) unk env nat_typ])) - unk env opt_typ in - let some_pat, some_pat_env, _ = bind_pat env (strip_pat some_pat) opt_typ in - - (* need to add the Some(...) env to tup_arg_pats for pat_to_exp below as it calls the typechecker *) - let tup_arg_pat = map_pat_annot (fun (l, tannot) -> (l, replace_env some_pat_env tannot)) tup_arg_pat in - - (* construct None pattern *) - let none_pat = annot_pat P_wild unk env opt_typ in - - (* recurse into pat2 *) - let new_pat2_pexp = - match rewrite_pat env (P_aux (P_string_append (pats), psa_annot), guards, expr) with - | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) - | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) - in - - (* construct the new guard *) - let guard_inner_match = construct_bool_match env drop_exp new_pat2_pexp in - let new_guard = annot_exp (E_match (func_exp, [ - Pat_aux (Pat_exp (some_pat, guard_inner_match), unkt); - Pat_aux (Pat_exp (none_pat, annot_exp (E_lit (mk_lit (L_false))) unk env bool_typ), unkt) - ])) unk env bool_typ in - - (* construct the new match *) - let new_match = annot_exp (E_match (drop_exp, [new_pat2_pexp])) unk env (typ_of expr) in - - (* construct the new let *) - let new_binding = annot_exp (E_typ (mapping_inner_typ, - annot_exp (E_match (func_exp, [ - Pat_aux (Pat_exp (some_pat, - annot_exp (E_tuple [ - pat_to_exp tup_arg_pat; - annot_exp (E_id len_id) unk env nat_typ - ]) unk env mapping_inner_typ - ), unkt) - ])) unk env mapping_inner_typ - )) unk env mapping_inner_typ in - let new_letbind = - match arg_pats with - | [] -> assert false - | [arg_pat] -> annot_letbind - (P_tuple [arg_pat; annot_pat (P_id len_id) unk env nat_typ], new_binding) - unk env (tuple_typ [typ_of_pat arg_pat; nat_typ]) - | arg_pats -> annot_letbind - (P_tuple - [annot_pat (P_tuple arg_pats) unk env (tuple_typ (List.map typ_of_pat arg_pats)); - annot_pat (P_id len_id) unk env nat_typ], - new_binding) - unk env (tuple_typ [tuple_typ (List.map typ_of_pat arg_pats); nat_typ]) - in - let new_let = annot_exp (E_let (new_letbind, new_match)) unk env (typ_of expr) in - - (* construct final result *) - annot_pat (P_id s_id) unk env string_typ, - [new_guard], - new_let - - | P_aux (P_string_append [pat], _) -> - pat, guards, expr - - | P_aux (P_string_append [], (l, _)) -> - annot_pat (P_lit (L_aux (L_string "", l))) l env string_typ, guards, expr + let s_id = fresh_stringappend_id () in + let len_id = fresh_stringappend_id () in - | P_aux (P_string_append _, _) -> - failwith ("encountered a variety of string append pattern that is not yet implemented: " ^ string_of_pat pat) + (* construct drop expression -- string_drop(s#, len#) *) + let drop_exp = + annot_exp + (E_app + (mk_id "string_drop", [annot_exp (E_id s_id) unk env string_typ; annot_exp (E_id len_id) unk env nat_typ]) + ) + unk env string_typ + in + (* construct func expression -- maybe_atoi s# *) + let func_exp = + annot_exp + (E_app (mk_id mapping_prefix_func, [annot_exp (E_id s_id) unk env string_typ])) + unk env mapping_inner_typ + in + (* construct some pattern -- Some (n#, len#) *) + let opt_typ = app_typ (mk_id "option") [A_aux (A_typ mapping_inner_typ, unk)] in + let tup_arg_pat = + match arg_pats with + | [] -> assert false + | [arg_pat] -> arg_pat + | arg_pats -> annot_pat (P_tuple arg_pats) unk env (tuple_typ (List.map typ_of_pat arg_pats)) + in - | P_aux (P_or(pat1, pat2), p_annot) -> - (* todo: this is wrong - no idea what is happening here *) - let (pat1', guards1, expr1) = rewrite_pat env (pat1, guards, expr) in - let (pat2', guards2, expr2) = rewrite_pat env (pat2, guards, expr) in - (P_aux (P_or(pat1', pat2'), p_annot), guards1 @ guards2, expr2) + let some_pat = + annot_pat (P_app (mk_id "Some", [tup_arg_pat; annot_pat (P_id len_id) unk env nat_typ])) unk env opt_typ + in + let some_pat, some_pat_env, _ = bind_pat env (strip_pat some_pat) opt_typ in + + (* need to add the Some(...) env to tup_arg_pats for pat_to_exp below as it calls the typechecker *) + let tup_arg_pat = map_pat_annot (fun (l, tannot) -> (l, replace_env some_pat_env tannot)) tup_arg_pat in + + (* construct None pattern *) + let none_pat = annot_pat P_wild unk env opt_typ in + + (* recurse into pat2 *) + let new_pat2_pexp = + match rewrite_pat env (P_aux (P_string_append pats, psa_annot), guards, expr) with + | pat, [], expr -> Pat_aux (Pat_exp (pat, expr), unkt) + | pat, gs, expr -> Pat_aux (Pat_when (pat, fold_typed_guards env gs, expr), unkt) + in - | P_aux (P_not(pat), p_annot) -> - let (pat', guards, expr) = rewrite_pat env (pat, guards, expr) in - (P_aux (P_not(pat'), p_annot), guards, expr) + (* construct the new guard *) + let guard_inner_match = construct_bool_match env drop_exp new_pat2_pexp in + let new_guard = + annot_exp + (E_match + ( func_exp, + [ + Pat_aux (Pat_exp (some_pat, guard_inner_match), unkt); + Pat_aux (Pat_exp (none_pat, annot_exp (E_lit (mk_lit L_false)) unk env bool_typ), unkt); + ] + ) + ) + unk env bool_typ + in + + (* construct the new match *) + let new_match = annot_exp (E_match (drop_exp, [new_pat2_pexp])) unk env (typ_of expr) in + + (* construct the new let *) + let new_binding = + annot_exp + (E_typ + ( mapping_inner_typ, + annot_exp + (E_match + ( func_exp, + [ + Pat_aux + ( Pat_exp + ( some_pat, + annot_exp + (E_tuple [pat_to_exp tup_arg_pat; annot_exp (E_id len_id) unk env nat_typ]) + unk env mapping_inner_typ + ), + unkt + ); + ] + ) + ) + unk env mapping_inner_typ + ) + ) + unk env mapping_inner_typ + in + let new_letbind = + match arg_pats with + | [] -> assert false + | [arg_pat] -> + annot_letbind + (P_tuple [arg_pat; annot_pat (P_id len_id) unk env nat_typ], new_binding) + unk env + (tuple_typ [typ_of_pat arg_pat; nat_typ]) + | arg_pats -> + annot_letbind + ( P_tuple + [ + annot_pat (P_tuple arg_pats) unk env (tuple_typ (List.map typ_of_pat arg_pats)); + annot_pat (P_id len_id) unk env nat_typ; + ], + new_binding + ) + unk env + (tuple_typ [tuple_typ (List.map typ_of_pat arg_pats); nat_typ]) + in + let new_let = annot_exp (E_let (new_letbind, new_match)) unk env (typ_of expr) in + (* construct final result *) + (annot_pat (P_id s_id) unk env string_typ, [new_guard], new_let) + | P_aux (P_string_append [pat], _) -> (pat, guards, expr) + | P_aux (P_string_append [], (l, _)) -> (annot_pat (P_lit (L_aux (L_string "", l))) l env string_typ, guards, expr) + | P_aux (P_string_append _, _) -> + failwith ("encountered a variety of string append pattern that is not yet implemented: " ^ string_of_pat pat) + | P_aux (P_or (pat1, pat2), p_annot) -> + (* todo: this is wrong - no idea what is happening here *) + let pat1', guards1, expr1 = rewrite_pat env (pat1, guards, expr) in + let pat2', guards2, expr2 = rewrite_pat env (pat2, guards, expr) in + (P_aux (P_or (pat1', pat2'), p_annot), guards1 @ guards2, expr2) + | P_aux (P_not pat, p_annot) -> + let pat', guards, expr = rewrite_pat env (pat, guards, expr) in + (P_aux (P_not pat', p_annot), guards, expr) | P_aux (P_as (inner_pat, inner_id), p_annot) -> - let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in - P_aux (P_as (inner_pat, inner_id), p_annot), guards, expr + let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in + (P_aux (P_as (inner_pat, inner_id), p_annot), guards, expr) | P_aux (P_typ (inner_typ, inner_pat), p_annot) -> - let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in - P_aux (P_typ (inner_typ, inner_pat), p_annot), guards, expr + let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in + (P_aux (P_typ (inner_typ, inner_pat), p_annot), guards, expr) | P_aux (P_var (inner_pat, typ_pat), p_annot) -> - let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in - P_aux (P_var (inner_pat, typ_pat), p_annot), guards, expr + let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in + (P_aux (P_var (inner_pat, typ_pat), p_annot), guards, expr) | P_aux (P_vector pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_vector pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_vector pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_vector_concat pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_vector_concat pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_vector_concat pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_tuple pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_tuple pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_tuple pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_list pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_list pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_list pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_app (f, pats), p_annot) -> - let pats = List.map folder pats in - P_aux (P_app (f, pats), p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_app (f, pats), p_annot), !guards_ref, !expr_ref) | P_aux (P_cons (pat1, pat2), p_annot) -> - let pat1, guards, expr = rewrite_pat env (pat1, guards, expr) in - let pat2, guards, expr = rewrite_pat env (pat2, guards, expr) in - P_aux (P_cons (pat1, pat2), p_annot), guards, expr - | P_aux (P_id _, _) - | P_aux (P_vector_subrange _, _) - | P_aux (P_lit _, _) - | P_aux (P_wild, _) -> pat, guards, expr + let pat1, guards, expr = rewrite_pat env (pat1, guards, expr) in + let pat2, guards, expr = rewrite_pat env (pat2, guards, expr) in + (P_aux (P_cons (pat1, pat2), p_annot), guards, expr) + | P_aux (P_id _, _) | P_aux (P_vector_subrange _, _) | P_aux (P_lit _, _) | P_aux (P_wild, _) -> (pat, guards, expr) in let rewrite_pexp (Pat_aux (pexp_aux, pexp_annot)) = - (* merge cases of Pat_exp and Pat_when *) - let (P_aux (p_aux, p_annot) as pat, guards, expr) = - match pexp_aux with - | Pat_exp (pat, expr) -> (pat, [], expr) - | Pat_when (pat, guard, expr) -> (pat, [guard], expr) + let (P_aux (p_aux, p_annot) as pat), guards, expr = + match pexp_aux with Pat_exp (pat, expr) -> (pat, [], expr) | Pat_when (pat, guard, expr) -> (pat, [guard], expr) in let env = env_of_annot p_annot in - let (new_pat, new_guards, new_expr) = - rewrite_pat env (pat, guards, expr) - in + let new_pat, new_guards, new_expr = rewrite_pat env (pat, guards, expr) in (* un-merge Pat_exp and Pat_when cases *) - let new_pexp = match new_guards with - | [] -> Pat_aux (Pat_exp (new_pat, new_expr), pexp_annot) - | gs -> Pat_aux (Pat_when (new_pat, fold_typed_guards env gs, new_expr), pexp_annot) + let new_pexp = + match new_guards with + | [] -> Pat_aux (Pat_exp (new_pat, new_expr), pexp_annot) + | gs -> Pat_aux (Pat_when (new_pat, fold_typed_guards env gs, new_expr), pexp_annot) in new_pexp - in - pexp_rewriters rewrite_pexp + pexp_rewriters rewrite_pexp let mappingpatterns_counter = ref 0 let fresh_mappingpatterns_id () = - let id = mk_id ("_mappingpatterns_" ^ (string_of_int !mappingpatterns_counter) ^ "#") in + let id = mk_id ("_mappingpatterns_" ^ string_of_int !mappingpatterns_counter ^ "#") in mappingpatterns_counter := !mappingpatterns_counter + 1; id @@ -3016,7 +3068,7 @@ let rewrite_ast_mapping_patterns env = in let env = env_of_pat pat in match pat with - (* + (* mapping(args) if g => expr ----> s# if mapping_matches(s#) & (if mapping_matches(s#) then let args = mapping(s#) in g) => let args = mapping(s#) in expr @@ -3024,143 +3076,134 @@ let rewrite_ast_mapping_patterns env = (plus 'infer the mapping type' shenanigans) *) | P_aux (P_app (mapping_id, arg_pats), p_annot) when Env.is_mapping mapping_id env -> + let mapping_in_typ = typ_of_annot p_annot in + + let x = Env.get_val_spec mapping_id env in + + let typ1, typ2 = + match x with + | _, Typ_aux (Typ_bidir (typ1, typ2), _) -> (typ1, typ2) + | _, typ -> + raise + (Reporting.err_unreachable (fst p_annot) __POS__ + ("Must be bi-directional mapping: " ^ string_of_typ typ) + ) + in + + let mapping_direction = if mapping_in_typ = typ1 then "forwards" else "backwards" in + + let mapping_out_typ = if mapping_in_typ = typ2 then typ2 else typ1 in + + let mapping_name = match mapping_id with Id_aux (Id id, _) | Id_aux (Operator id, _) -> id in - let mapping_in_typ = typ_of_annot p_annot in - - let x = Env.get_val_spec mapping_id env in - - let typ1, typ2 = match x with - | (_, Typ_aux(Typ_bidir(typ1, typ2), _)) -> typ1, typ2 - | (_, typ) -> raise (Reporting.err_unreachable (fst p_annot) __POS__ ("Must be bi-directional mapping: " ^ string_of_typ typ)) - in - - let mapping_direction = - if mapping_in_typ = typ1 then - "forwards" - else - "backwards" - in - - let mapping_out_typ = - if mapping_in_typ = typ2 then - typ2 - else - typ1 - in - - let mapping_name = - match mapping_id with - | Id_aux (Id id, _) - | Id_aux (Operator id, _) -> id - in - - let mapping_matches_id = mk_id (mapping_name ^ "_" ^ mapping_direction ^ "_matches") in - let mapping_perform_id = mk_id (mapping_name ^ "_" ^ mapping_direction) in - let s_id = fresh_mappingpatterns_id () in - - let s_exp = annot_exp (E_id s_id) unk env mapping_in_typ in - let new_guard = annot_exp (E_app (mapping_matches_id, [s_exp])) unk env bool_typ in - let new_binding = annot_exp (E_app (mapping_perform_id, [s_exp])) unk env typ2 in - let new_letbind, expr = match arg_pats with - | [] -> assert false - | [arg_pat] -> - let arg_pat, _, expr = rewrite_pat env (arg_pat, [], expr) in - LB_aux (LB_val (arg_pat, new_binding), unkt), expr - | arg_pats -> - let checked_tup = annot_pat (P_tuple arg_pats) unk env mapping_out_typ in - LB_aux (LB_val (checked_tup, new_binding), unkt), expr - in - - let new_let = annot_exp (E_let (new_letbind, expr)) unk env (typ_of expr) in - - let false_exp = annot_exp (E_lit (L_aux (L_false, unk))) unk env bool_typ in - let new_complete_guard = - match guards with - | [] -> new_guard - | _ -> - annot_exp (E_if (new_guard, - (annot_exp (E_let (new_letbind, annot_exp (E_typ (bool_typ, fold_typed_guards env guards)) unk env bool_typ)) unk env bool_typ), - false_exp)) unk env bool_typ - in - - annot_pat (P_typ (mapping_in_typ, annot_pat (P_id s_id) unk env mapping_in_typ)) unk env mapping_in_typ, [new_complete_guard], new_let + let mapping_matches_id = mk_id (mapping_name ^ "_" ^ mapping_direction ^ "_matches") in + let mapping_perform_id = mk_id (mapping_name ^ "_" ^ mapping_direction) in + let s_id = fresh_mappingpatterns_id () in + let s_exp = annot_exp (E_id s_id) unk env mapping_in_typ in + let new_guard = annot_exp (E_app (mapping_matches_id, [s_exp])) unk env bool_typ in + let new_binding = annot_exp (E_app (mapping_perform_id, [s_exp])) unk env typ2 in + let new_letbind, expr = + match arg_pats with + | [] -> assert false + | [arg_pat] -> + let arg_pat, _, expr = rewrite_pat env (arg_pat, [], expr) in + (LB_aux (LB_val (arg_pat, new_binding), unkt), expr) + | arg_pats -> + let checked_tup = annot_pat (P_tuple arg_pats) unk env mapping_out_typ in + (LB_aux (LB_val (checked_tup, new_binding), unkt), expr) + in + + let new_let = annot_exp (E_let (new_letbind, expr)) unk env (typ_of expr) in + + let false_exp = annot_exp (E_lit (L_aux (L_false, unk))) unk env bool_typ in + let new_complete_guard = + match guards with + | [] -> new_guard + | _ -> + annot_exp + (E_if + ( new_guard, + annot_exp + (E_let (new_letbind, annot_exp (E_typ (bool_typ, fold_typed_guards env guards)) unk env bool_typ)) + unk env bool_typ, + false_exp + ) + ) + unk env bool_typ + in + + ( annot_pat (P_typ (mapping_in_typ, annot_pat (P_id s_id) unk env mapping_in_typ)) unk env mapping_in_typ, + [new_complete_guard], + new_let + ) | P_aux (P_as (inner_pat, inner_id), p_annot) -> - let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in - P_aux (P_as (inner_pat, inner_id), p_annot), guards, expr + let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in + (P_aux (P_as (inner_pat, inner_id), p_annot), guards, expr) | P_aux (P_typ (inner_typ, inner_pat), p_annot) -> - let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in - P_aux (P_typ (inner_typ, inner_pat), p_annot), guards, expr + let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in + (P_aux (P_typ (inner_typ, inner_pat), p_annot), guards, expr) | P_aux (P_var (inner_pat, typ_pat), p_annot) -> - let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in - P_aux (P_var (inner_pat, typ_pat), p_annot), guards, expr + let inner_pat, guards, expr = rewrite_pat env (inner_pat, guards, expr) in + (P_aux (P_var (inner_pat, typ_pat), p_annot), guards, expr) | P_aux (P_vector pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_vector pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_vector pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_vector_concat pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_vector_concat pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_vector_concat pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_tuple pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_tuple pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_tuple pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_list pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_list pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_list pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_app (f, pats), p_annot) -> - let pats = List.map folder pats in - P_aux (P_app (f, pats), p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_app (f, pats), p_annot), !guards_ref, !expr_ref) | P_aux (P_string_append pats, p_annot) -> - let pats = List.map folder pats in - P_aux (P_string_append pats, p_annot), !guards_ref, !expr_ref + let pats = List.map folder pats in + (P_aux (P_string_append pats, p_annot), !guards_ref, !expr_ref) | P_aux (P_cons (pat1, pat2), p_annot) -> - let pat1, guards, expr = rewrite_pat env (pat1, guards, expr) in - let pat2, guards, expr = rewrite_pat env (pat2, guards, expr) in - P_aux (P_cons (pat1, pat2), p_annot), guards, expr + let pat1, guards, expr = rewrite_pat env (pat1, guards, expr) in + let pat2, guards, expr = rewrite_pat env (pat2, guards, expr) in + (P_aux (P_cons (pat1, pat2), p_annot), guards, expr) | P_aux (P_or (pat1, pat2), p_annot) -> - let pat1, guards, expr = rewrite_pat env (pat1, guards, expr) in - let pat2, guards, expr = rewrite_pat env (pat2, guards, expr) in - P_aux (P_or (pat1, pat2), p_annot), guards, expr + let pat1, guards, expr = rewrite_pat env (pat1, guards, expr) in + let pat2, guards, expr = rewrite_pat env (pat2, guards, expr) in + (P_aux (P_or (pat1, pat2), p_annot), guards, expr) | P_aux (P_not p, p_annot) -> - let p', guards, expr = rewrite_pat env (p, guards, expr) in - P_aux (P_not p', p_annot), guards, expr - | P_aux (P_id _, _) - | P_aux (P_vector_subrange _, _) - | P_aux (P_lit _, _) - | P_aux (P_wild, _) -> pat, guards, expr + let p', guards, expr = rewrite_pat env (p, guards, expr) in + (P_aux (P_not p', p_annot), guards, expr) + | P_aux (P_id _, _) | P_aux (P_vector_subrange _, _) | P_aux (P_lit _, _) | P_aux (P_wild, _) -> (pat, guards, expr) in let rewrite_pexp (Pat_aux (pexp_aux, pexp_annot)) = - (* merge cases of Pat_exp and Pat_when *) - let (P_aux (p_aux, p_annot) as pat, guards, expr) = - match pexp_aux with - | Pat_exp (pat, expr) -> (pat, [], expr) - | Pat_when (pat, guard, expr) -> (pat, [guard], expr) + let (P_aux (p_aux, p_annot) as pat), guards, expr = + match pexp_aux with Pat_exp (pat, expr) -> (pat, [], expr) | Pat_when (pat, guard, expr) -> (pat, [guard], expr) in let env = env_of_annot p_annot in - let (new_pat, new_guards, new_expr) = - rewrite_pat env (pat, guards, expr) - in + let new_pat, new_guards, new_expr = rewrite_pat env (pat, guards, expr) in (* un-merge Pat_exp and Pat_when cases *) - let new_pexp = match new_guards with - | [] -> Pat_aux (Pat_exp (new_pat, new_expr), pexp_annot) - | gs -> Pat_aux (Pat_when (new_pat, fold_typed_guards env gs, new_expr), pexp_annot) + let new_pexp = + match new_guards with + | [] -> Pat_aux (Pat_exp (new_pat, new_expr), pexp_annot) + | gs -> Pat_aux (Pat_when (new_pat, fold_typed_guards env gs, new_expr), pexp_annot) in new_pexp - in + pexp_rewriters rewrite_pexp -let rewrite_lit_lem (L_aux (lit, _)) = match lit with - | L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ -> true - | _ -> false +let rewrite_lit_lem (L_aux (lit, _)) = + match lit with L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ -> true | _ -> false -let rewrite_lit_ocaml (L_aux (lit, _)) = match lit with - | L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ | L_unit -> false - | _ -> true +let rewrite_lit_ocaml (L_aux (lit, _)) = + match lit with L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ | L_unit -> false | _ -> true let rewrite_ast_pat_lits rewrite_lit env ast = let rewrite_pexp (Pat_aux (pexp_aux, annot)) = @@ -3169,43 +3212,52 @@ let rewrite_ast_pat_lits rewrite_lit env ast = let rewrite_pat = function (* Matching on unit is always the same as matching on wildcard *) - | P_lit (L_aux (L_unit, _) as lit), p_annot when rewrite_lit lit -> - P_aux (P_wild, p_annot) + | P_lit (L_aux (L_unit, _) as lit), p_annot when rewrite_lit lit -> P_aux (P_wild, p_annot) | P_lit lit, p_annot when rewrite_lit lit -> - let env = env_of_annot p_annot in - let typ = typ_of_annot p_annot in - let id = mk_id ("p" ^ string_of_int !counter ^ "#") in - let guard = mk_exp (E_app_infix (mk_exp (E_id id), mk_id "==", mk_exp (E_lit lit))) in - let guard = check_exp (Env.add_local id (Immutable, typ) env) guard bool_typ in - guards := guard :: !guards; - incr counter; - P_aux (P_id id, p_annot) - | p_aux, p_annot -> - P_aux (p_aux, p_annot) + let env = env_of_annot p_annot in + let typ = typ_of_annot p_annot in + let id = mk_id ("p" ^ string_of_int !counter ^ "#") in + let guard = mk_exp (E_app_infix (mk_exp (E_id id), mk_id "==", mk_exp (E_lit lit))) in + let guard = check_exp (Env.add_local id (Immutable, typ) env) guard bool_typ in + guards := guard :: !guards; + incr counter; + P_aux (P_id id, p_annot) + | p_aux, p_annot -> P_aux (p_aux, p_annot) in match pexp_aux with - | Pat_exp (pat, exp) -> - begin - let pat = fold_pat { id_pat_alg with p_aux = rewrite_pat } pat in - match !guards with - | [] -> Pat_aux (Pat_exp (pat, exp), annot) - | (g :: gs) -> + | Pat_exp (pat, exp) -> begin + let pat = fold_pat { id_pat_alg with p_aux = rewrite_pat } pat in + match !guards with + | [] -> Pat_aux (Pat_exp (pat, exp), annot) + | g :: gs -> let guard_annot = (fst annot, mk_tannot (env_of exp) bool_typ) in - Pat_aux (Pat_when (pat, List.fold_left (fun g g' -> E_aux (E_app (mk_id "and_bool", [g; g']), guard_annot)) g gs, exp), annot) - end - | Pat_when (pat, guard, exp) -> - begin - let pat = fold_pat { id_pat_alg with p_aux = rewrite_pat } pat in - let guard_annot = (fst annot, mk_tannot (env_of exp) bool_typ) in - Pat_aux (Pat_when (pat, List.fold_left (fun g g' -> E_aux (E_app (mk_id "and_bool", [g; g']), guard_annot)) guard !guards, exp), annot) - end + Pat_aux + ( Pat_when + (pat, List.fold_left (fun g g' -> E_aux (E_app (mk_id "and_bool", [g; g']), guard_annot)) g gs, exp), + annot + ) + end + | Pat_when (pat, guard, exp) -> begin + let pat = fold_pat { id_pat_alg with p_aux = rewrite_pat } pat in + let guard_annot = (fst annot, mk_tannot (env_of exp) bool_typ) in + Pat_aux + ( Pat_when + ( pat, + List.fold_left (fun g g' -> E_aux (E_app (mk_id "and_bool", [g; g']), guard_annot)) guard !guards, + exp + ), + annot + ) + end in let rewrite_funcl (FCL_aux (FCL_funcl (id, pexp), (l, annot))) = - FCL_aux (FCL_funcl (id, rewrite_pexp pexp), (l, annot)) in + FCL_aux (FCL_funcl (id, rewrite_pexp pexp), (l, annot)) + in let rewrite_fun (FD_aux (FD_function (recopt, tannotopt, funcls), (l, annot))) = - FD_aux (FD_function (recopt, tannotopt, List.map rewrite_funcl funcls), (l, annot)) in + FD_aux (FD_function (recopt, tannotopt, List.map rewrite_funcl funcls), (l, annot)) + in let rewrite_def = function | DEF_aux (DEF_fundef fdef, def_annot) -> DEF_aux (DEF_fundef (rewrite_fun fdef), def_annot) | def -> def @@ -3215,17 +3267,13 @@ let rewrite_ast_pat_lits rewrite_lit env ast = let ast = rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp alg) } ast in { ast with defs = List.map rewrite_def ast.defs } - (* Now all expressions have no blocks anymore, any term is a sequence of let-expressions, * internal let-expressions, or internal plet-expressions ended by a term that does not * access memory or registers and does not update variables *) -type 'a updated_term = - | Added_vars of 'a exp * 'a pat - | Same_vars of 'a exp - -let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = +type 'a updated_term = Added_vars of 'a exp * 'a pat | Same_vars of 'a exp +let rec rewrite_var_updates (E_aux (expaux, ((l, _) as annot)) as exp) = let env = env_of exp in let tuple_exp = function @@ -3237,142 +3285,160 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let tuple_pat = function | [] -> annot_pat P_wild l env unit_typ | [pat] -> - let typ = typ_of_pat pat in - add_p_typ env typ pat + let typ = typ_of_pat pat in + add_p_typ env typ pat | pats -> - let typ = tuple_typ (List.map typ_of_pat pats) in - add_p_typ env typ (annot_pat (P_tuple pats) l env typ) + let typ = tuple_typ (List.map typ_of_pat pats) in + add_p_typ env typ (annot_pat (P_tuple pats) l env typ) in - let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars = + let rec add_vars overwrite (E_aux (expaux, annot) as exp) vars = match expaux with - | E_let (lb,exp) -> - let exp = add_vars overwrite exp vars in - E_aux (E_let (lb,exp), swaptyp (typ_of exp) annot) - | E_var (lexp,exp1,exp2) -> - let exp2 = add_vars overwrite exp2 vars in - E_aux (E_var (lexp,exp1,exp2), swaptyp (typ_of exp2) annot) - | E_internal_plet (pat,exp1,exp2) -> - let exp2 = add_vars overwrite exp2 vars in - E_aux (E_internal_plet (pat,exp1,exp2), swaptyp (typ_of exp2) annot) + | E_let (lb, exp) -> + let exp = add_vars overwrite exp vars in + E_aux (E_let (lb, exp), swaptyp (typ_of exp) annot) + | E_var (lexp, exp1, exp2) -> + let exp2 = add_vars overwrite exp2 vars in + E_aux (E_var (lexp, exp1, exp2), swaptyp (typ_of exp2) annot) + | E_internal_plet (pat, exp1, exp2) -> + let exp2 = add_vars overwrite exp2 vars in + E_aux (E_internal_plet (pat, exp1, exp2), swaptyp (typ_of exp2) annot) | E_internal_return exp2 -> - let exp2 = add_vars overwrite exp2 vars in - E_aux (E_internal_return exp2, swaptyp (typ_of exp2) annot) + let exp2 = add_vars overwrite exp2 vars in + E_aux (E_internal_return exp2, swaptyp (typ_of exp2) annot) | E_typ (typ, exp) -> - let (E_aux (expaux, annot) as exp) = add_vars overwrite exp vars in - let typ' = typ_of exp in - add_e_typ (env_of exp) typ' (E_aux (expaux, swaptyp typ' annot)) + let (E_aux (expaux, annot) as exp) = add_vars overwrite exp vars in + let typ' = typ_of exp in + add_e_typ (env_of exp) typ' (E_aux (expaux, swaptyp typ' annot)) | E_app (early_return, args) when string_of_id early_return = "early_return" -> - (* Special case early return: It has to be monadic for the prover - * backends, so the addition of vars below wouldn't work without an - * extra E_internal_return. But threading through local vars to the - * outer block isn't necessary anyway, because we will exit the - * function, so just keep the early_return expression as is. *) - exp - | E_internal_assume (nc,exp) -> - let exp = add_vars overwrite exp vars in - E_aux (E_internal_assume (nc,exp), swaptyp (typ_of exp) annot) + (* Special case early return: It has to be monadic for the prover + * backends, so the addition of vars below wouldn't work without an + * extra E_internal_return. But threading through local vars to the + * outer block isn't necessary anyway, because we will exit the + * function, so just keep the early_return expression as is. *) + exp + | E_internal_assume (nc, exp) -> + let exp = add_vars overwrite exp vars in + E_aux (E_internal_assume (nc, exp), swaptyp (typ_of exp) annot) | _ -> - (* after rewrite_ast_letbind_effects there cannot be terms that have - effects/update local variables in "tail-position": check n_exp_term - and where it is used. *) - if overwrite then - let lb = LB_aux (LB_val (P_aux (P_wild, annot), exp), annot) in - let exp' = tuple_exp vars in - E_aux (E_let (lb, exp'), swaptyp (typ_of exp') annot) - |> add_typs_let env (typ_of exp) (typ_of exp') - else tuple_exp (exp :: vars) in + (* after rewrite_ast_letbind_effects there cannot be terms that have + effects/update local variables in "tail-position": check n_exp_term + and where it is used. *) + if overwrite then ( + let lb = LB_aux (LB_val (P_aux (P_wild, annot), exp), annot) in + let exp' = tuple_exp vars in + E_aux (E_let (lb, exp'), swaptyp (typ_of exp') annot) |> add_typs_let env (typ_of exp) (typ_of exp') + ) + else tuple_exp (exp :: vars) + in let mk_var_exps_pats l env ids = - ids - |> IdSet.elements - |> List.map - (fun id -> - let (E_aux (_, a) as exp) = infer_exp env (E_aux (E_id id, (l, empty_uannot))) in - exp, P_aux (P_id id, a)) - |> List.split in + ids |> IdSet.elements + |> List.map (fun id -> + let (E_aux (_, a) as exp) = infer_exp env (E_aux (E_id id, (l, empty_uannot))) in + (exp, P_aux (P_id id, a)) + ) + |> List.split + in - let rec rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) = + let rec rewrite used_vars (E_aux (expaux, ((el, _) as annot)) as full_exp) (P_aux (paux, (pl, pannot)) as pat) = let env = env_of_annot annot in - let overwrite = match paux with - | P_wild | P_typ (_, P_aux (P_wild, _)) -> true - | _ -> false in + let overwrite = match paux with P_wild | P_typ (_, P_aux (P_wild, _)) -> true | _ -> false in match expaux with - | E_for(id,exp1,exp2,exp3,order,exp4) -> - (* Translate for loops into calls to one of the foreach combinators. - The loop body becomes a function of the loop variable and any - mutable local variables that are updated inside the loop and - are used after or within the loop. - Since the foreach* combinators are higher-order functions, - they cannot be represented faithfully in the AST. The following - code abuses the parameters of an E_app node, embedding the loop body - function as an expression followed by the list of variables it - expects. In (Lem) pretty-printing, this turned into an anonymous - function and passed to foreach*. *) - let vars, varpats = - find_updated_vars exp4 - |> IdSet.inter (IdSet.union used_vars (find_used_vars full_exp)) - |> mk_var_exps_pats pl env - in - let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in - (* Bind the loop variable in the body, annotated with constraints *) - let lvar_kid = mk_kid ("loop_" ^ string_of_id id) in - let lower_id = mk_id ("loop_" ^ string_of_id id ^ "_lower") in - let upper_id = mk_id ("loop_" ^ string_of_id id ^ "_upper") in - let lower_kid = mk_kid ("loop_" ^ string_of_id id ^ "_lower") in - let upper_kid = mk_kid ("loop_" ^ string_of_id id ^ "_upper") in - let env' = - env - |> Env.add_typ_var el (mk_kopt K_int lvar_kid) - |> Env.add_typ_var el (mk_kopt K_int lower_kid) - |> Env.add_typ_var el (mk_kopt K_int upper_kid) - in - let lower_id_exp = annot_exp (E_id lower_id) el env' (atom_typ (nvar lower_kid)) in - let upper_id_exp = annot_exp (E_id upper_id) el env' (atom_typ (nvar upper_kid)) in - let annot_bool_lit lit = annot_exp (E_lit lit) el env' bool_typ in - let ord_exp, lower_exp, upper_exp, exp1, exp2 = - if is_order_inc order - then annot_bool_lit (mk_lit L_true), exp1, exp2, lower_id_exp, upper_id_exp - else annot_bool_lit (mk_lit L_false), exp2, exp1, upper_id_exp, lower_id_exp - in - let lvar_nc = nc_and (nc_lteq (nvar lower_kid) (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) (nvar upper_kid)) in - let lvar_typ = mk_typ (Typ_exist (List.map (mk_kopt K_int) [lvar_kid], lvar_nc, atom_typ (nvar lvar_kid))) in - let lvar_pat = unaux_pat (annot_pat (P_var ( - annot_pat (P_id id) el env' (atom_typ (nvar lvar_kid)), - TP_aux (TP_var lvar_kid, gen_loc el))) el env' lvar_typ) - in - let lb = annot_letbind (lvar_pat, exp1) el env' lvar_typ in - let body = - annot_exp (E_let (lb, exp4)) el env' (typ_of exp4) - |> add_typs_let env' lvar_typ (typ_of exp4) - in - (* If lower > upper, the loop body never gets executed, and the type - checker might not be able to prove that the initial value exp1 - satisfies the constraints on the loop variable. - - Make this explicit by guarding the loop body with lower <= upper. - (for type-checking; the guard is later removed again by the Lem - pretty-printer). This could be implemented with an assertion, but - that would force the loop to be effectful, so we use an if-expression - instead. This code assumes that the loop bounds have (possibly - existential) atom types, and the loop body has type unit. *) - let lower_pat = P_var (annot_pat (P_id lower_id) el env (typ_of lower_exp), mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var lower_kid)]))) in - let lb_lower = annot_letbind (lower_pat, lower_exp) el env (typ_of lower_exp) in - let upper_pat = P_var (annot_pat (P_id upper_id) el env (typ_of upper_exp), mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var upper_kid)]))) in - let lb_upper = annot_letbind (upper_pat, upper_exp) el env (typ_of upper_exp) in - let guard = annot_exp (E_constraint (nc_lteq (nvar lower_kid) (nvar upper_kid))) el env' bool_typ in - let unit_exp = annot_exp (E_lit (mk_lit L_unit)) el env' unit_typ in - let skip_val = tuple_exp (if overwrite then vars else unit_exp :: vars) in - let guarded_body = annot_exp (E_if (guard, body, skip_val)) el env' (typ_of exp4) in - let v = - annot_exp (E_let (lb_lower, - annot_exp (E_let (lb_upper, - annot_exp (E_app (mk_id "foreach#", [exp1; exp2; exp3; ord_exp; tuple_exp vars; guarded_body])) - el env (typ_of exp4))) - el env (typ_of exp4))) - el env (typ_of exp4) in - Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_loop(loop,Measure_aux (measure,_),cond,body) -> + | E_for (id, exp1, exp2, exp3, order, exp4) -> + (* Translate for loops into calls to one of the foreach combinators. + The loop body becomes a function of the loop variable and any + mutable local variables that are updated inside the loop and + are used after or within the loop. + Since the foreach* combinators are higher-order functions, + they cannot be represented faithfully in the AST. The following + code abuses the parameters of an E_app node, embedding the loop body + function as an expression followed by the list of variables it + expects. In (Lem) pretty-printing, this turned into an anonymous + function and passed to foreach*. *) + let vars, varpats = + find_updated_vars exp4 + |> IdSet.inter (IdSet.union used_vars (find_used_vars full_exp)) + |> mk_var_exps_pats pl env + in + let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in + (* Bind the loop variable in the body, annotated with constraints *) + let lvar_kid = mk_kid ("loop_" ^ string_of_id id) in + let lower_id = mk_id ("loop_" ^ string_of_id id ^ "_lower") in + let upper_id = mk_id ("loop_" ^ string_of_id id ^ "_upper") in + let lower_kid = mk_kid ("loop_" ^ string_of_id id ^ "_lower") in + let upper_kid = mk_kid ("loop_" ^ string_of_id id ^ "_upper") in + let env' = + env + |> Env.add_typ_var el (mk_kopt K_int lvar_kid) + |> Env.add_typ_var el (mk_kopt K_int lower_kid) + |> Env.add_typ_var el (mk_kopt K_int upper_kid) + in + let lower_id_exp = annot_exp (E_id lower_id) el env' (atom_typ (nvar lower_kid)) in + let upper_id_exp = annot_exp (E_id upper_id) el env' (atom_typ (nvar upper_kid)) in + let annot_bool_lit lit = annot_exp (E_lit lit) el env' bool_typ in + let ord_exp, lower_exp, upper_exp, exp1, exp2 = + if is_order_inc order then (annot_bool_lit (mk_lit L_true), exp1, exp2, lower_id_exp, upper_id_exp) + else (annot_bool_lit (mk_lit L_false), exp2, exp1, upper_id_exp, lower_id_exp) + in + let lvar_nc = nc_and (nc_lteq (nvar lower_kid) (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) (nvar upper_kid)) in + let lvar_typ = mk_typ (Typ_exist (List.map (mk_kopt K_int) [lvar_kid], lvar_nc, atom_typ (nvar lvar_kid))) in + let lvar_pat = + unaux_pat + (annot_pat + (P_var (annot_pat (P_id id) el env' (atom_typ (nvar lvar_kid)), TP_aux (TP_var lvar_kid, gen_loc el))) + el env' lvar_typ + ) + in + let lb = annot_letbind (lvar_pat, exp1) el env' lvar_typ in + let body = annot_exp (E_let (lb, exp4)) el env' (typ_of exp4) |> add_typs_let env' lvar_typ (typ_of exp4) in + (* If lower > upper, the loop body never gets executed, and the type + checker might not be able to prove that the initial value exp1 + satisfies the constraints on the loop variable. + + Make this explicit by guarding the loop body with lower <= upper. + (for type-checking; the guard is later removed again by the Lem + pretty-printer). This could be implemented with an assertion, but + that would force the loop to be effectful, so we use an if-expression + instead. This code assumes that the loop bounds have (possibly + existential) atom types, and the loop body has type unit. *) + let lower_pat = + P_var + ( annot_pat (P_id lower_id) el env (typ_of lower_exp), + mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var lower_kid)])) + ) + in + let lb_lower = annot_letbind (lower_pat, lower_exp) el env (typ_of lower_exp) in + let upper_pat = + P_var + ( annot_pat (P_id upper_id) el env (typ_of upper_exp), + mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var upper_kid)])) + ) + in + let lb_upper = annot_letbind (upper_pat, upper_exp) el env (typ_of upper_exp) in + let guard = annot_exp (E_constraint (nc_lteq (nvar lower_kid) (nvar upper_kid))) el env' bool_typ in + let unit_exp = annot_exp (E_lit (mk_lit L_unit)) el env' unit_typ in + let skip_val = tuple_exp (if overwrite then vars else unit_exp :: vars) in + let guarded_body = annot_exp (E_if (guard, body, skip_val)) el env' (typ_of exp4) in + let v = + annot_exp + (E_let + ( lb_lower, + annot_exp + (E_let + ( lb_upper, + annot_exp + (E_app (mk_id "foreach#", [exp1; exp2; exp3; ord_exp; tuple_exp vars; guarded_body])) + el env (typ_of exp4) + ) + ) + el env (typ_of exp4) + ) + ) + el env (typ_of exp4) + in + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) + | E_loop (loop, Measure_aux (measure, _), cond, body) -> (* Find variables that might be updated in the loop body and are used either after or within the loop. *) let vars, varpats = @@ -3382,224 +3448,218 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = in let body = rewrite_var_updates (add_vars overwrite body vars) in let body = add_e_typ env (typ_of body) body in - let (E_aux (_,(_,bannot))) = body in - let fname, measure = match loop, measure with - | While, Measure_none -> "while#", [] - | Until, Measure_none -> "until#", [] - | While, Measure_some exp -> "while#t", [exp] - | Until, Measure_some exp -> "until#t", [exp] + let (E_aux (_, (_, bannot))) = body in + let fname, measure = + match (loop, measure) with + | While, Measure_none -> ("while#", []) + | Until, Measure_none -> ("until#", []) + | While, Measure_some exp -> ("while#t", [exp]) + | Until, Measure_some exp -> ("until#t", [exp]) in - let funcl = Id_aux (Id fname,gen_loc el) in - let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]@measure), (gen_loc el, bannot)) in + let funcl = Id_aux (Id fname, gen_loc el) in + let v = E_aux (E_app (funcl, [cond; tuple_exp vars; body] @ measure), (gen_loc el, bannot)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_if (c,e1,e2) -> - let vars, varpats = - IdSet.union (find_updated_vars e1) (find_updated_vars e2) - |> IdSet.inter used_vars - |> mk_var_exps_pats pl env in - if vars = [] then - (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) - else - let e1 = rewrite_var_updates (add_vars overwrite e1 vars) in - let e2 = rewrite_var_updates (add_vars overwrite e2 vars) in - (* after rewrite_ast_letbind_effects c has no variable updates *) - let env = env_of_annot annot in - let typ = typ_of e1 in - let v = E_aux (E_if (c,e1,e2), (gen_loc el, mk_tannot env typ)) in - Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_match (e1,ps) | E_try (e1, ps) -> - let is_case = match expaux with E_match _ -> true | _ -> false in - let vars, varpats = - (* for E_match, e1 needs no rewriting after rewrite_ast_letbind_effects *) - ((if is_case then [] else [e1]) @ - List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) ps) - |> List.map find_updated_vars - |> List.fold_left IdSet.union IdSet.empty - |> IdSet.inter used_vars - |> mk_var_exps_pats pl env in - let e1 = if is_case then e1 else rewrite_var_updates (add_vars overwrite e1 vars) in - if vars = [] then - let ps = List.map (function - | Pat_aux (Pat_exp (p,e),a) -> - Pat_aux (Pat_exp (p,rewrite_var_updates e),a) - | Pat_aux (Pat_when (p,g,e),a) -> - Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in - let expaux = if is_case then E_match (e1, ps) else E_try (e1, ps) in - Same_vars (E_aux (expaux, annot)) - else - let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with - | Pat_exp (pat, exp) -> - let exp = rewrite_var_updates (add_vars overwrite exp vars) in - let pannot = (l, mk_tannot (env_of exp) (typ_of exp)) in - Pat_aux (Pat_exp (pat, exp), pannot) - | Pat_when _ -> - raise (Reporting.err_unreachable l __POS__ - "Guarded patterns should have been rewritten already") in - let ps = List.map rewrite_pexp ps in - let expaux = if is_case then E_match (e1, ps) else E_try (e1, ps) in - let typ = match ps with - | Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first - | _ -> unit_typ in - let v = annot_exp expaux pl env typ in - Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_assign (lexp,vexp) -> - let mk_id_pat id = - let typ = lvar_typ (Env.lookup_id id env) in - add_p_typ env typ (annot_pat (P_id id) pl env typ) - in - if effectful full_exp then ( - Same_vars (E_aux (E_assign (lexp,vexp),annot)) - ) else - (match lexp with - | LE_aux (LE_id id,annot) -> - Added_vars (vexp, mk_id_pat id) - | LE_aux (LE_typ (typ,id),annot) -> - let pat = add_p_typ env typ (annot_pat (P_id id) pl env (typ_of vexp)) in - Added_vars (vexp,pat) - | LE_aux (LE_vector (LE_aux (LE_id id,((l2,_) as annot2)),i),((l1,_) as annot)) -> - let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in - let vexp = annot_exp (E_vector_update (eid,i,vexp)) l1 env (typ_of_annot annot) in - let pat = annot_pat (P_id id) pl env (typ_of vexp) in - Added_vars (vexp,pat) - | LE_aux (LE_vector_range (LE_aux (LE_id id,((l2,_) as annot2)),i,j), - ((l,_) as annot)) -> - let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in - let vexp = annot_exp (E_vector_update_subrange (eid,i,j,vexp)) l env (typ_of_annot annot) in - let pat = annot_pat (P_id id) pl env (typ_of vexp) in - Added_vars (vexp,pat) - | _ -> Same_vars (E_aux (E_assign (lexp,vexp),annot))) - | E_typ (typ, exp) -> - begin match rewrite used_vars exp pat with - | Added_vars (exp', pat') -> - Added_vars (add_e_typ (env_of exp') (typ_of exp') exp', pat') - | Same_vars (exp') -> - Same_vars (E_aux (E_typ (typ, exp'), annot)) - end + | E_if (c, e1, e2) -> + let vars, varpats = + IdSet.union (find_updated_vars e1) (find_updated_vars e2) |> IdSet.inter used_vars |> mk_var_exps_pats pl env + in + if vars = [] then Same_vars (E_aux (E_if (c, rewrite_var_updates e1, rewrite_var_updates e2), annot)) + else ( + let e1 = rewrite_var_updates (add_vars overwrite e1 vars) in + let e2 = rewrite_var_updates (add_vars overwrite e2 vars) in + (* after rewrite_ast_letbind_effects c has no variable updates *) + let env = env_of_annot annot in + let typ = typ_of e1 in + let v = E_aux (E_if (c, e1, e2), (gen_loc el, mk_tannot env typ)) in + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) + ) + | E_match (e1, ps) | E_try (e1, ps) -> + let is_case = match expaux with E_match _ -> true | _ -> false in + let vars, varpats = + (* for E_match, e1 needs no rewriting after rewrite_ast_letbind_effects *) + (if is_case then [] else [e1]) @ List.map (fun (Pat_aux ((Pat_exp (_, e) | Pat_when (_, _, e)), _)) -> e) ps + |> List.map find_updated_vars |> List.fold_left IdSet.union IdSet.empty |> IdSet.inter used_vars + |> mk_var_exps_pats pl env + in + let e1 = if is_case then e1 else rewrite_var_updates (add_vars overwrite e1 vars) in + if vars = [] then ( + let ps = + List.map + (function + | Pat_aux (Pat_exp (p, e), a) -> Pat_aux (Pat_exp (p, rewrite_var_updates e), a) + | Pat_aux (Pat_when (p, g, e), a) -> Pat_aux (Pat_when (p, g, rewrite_var_updates e), a) + ) + ps + in + let expaux = if is_case then E_match (e1, ps) else E_try (e1, ps) in + Same_vars (E_aux (expaux, annot)) + ) + else ( + let rewrite_pexp (Pat_aux (pexp, (l, _))) = + match pexp with + | Pat_exp (pat, exp) -> + let exp = rewrite_var_updates (add_vars overwrite exp vars) in + let pannot = (l, mk_tannot (env_of exp) (typ_of exp)) in + Pat_aux (Pat_exp (pat, exp), pannot) + | Pat_when _ -> + raise (Reporting.err_unreachable l __POS__ "Guarded patterns should have been rewritten already") + in + let ps = List.map rewrite_pexp ps in + let expaux = if is_case then E_match (e1, ps) else E_try (e1, ps) in + let typ = + match ps with + | Pat_aux ((Pat_exp (_, first) | Pat_when (_, _, first)), _) :: _ -> typ_of first + | _ -> unit_typ + in + let v = annot_exp expaux pl env typ in + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) + ) + | E_assign (lexp, vexp) -> + let mk_id_pat id = + let typ = lvar_typ (Env.lookup_id id env) in + add_p_typ env typ (annot_pat (P_id id) pl env typ) + in + if effectful full_exp then Same_vars (E_aux (E_assign (lexp, vexp), annot)) + else ( + match lexp with + | LE_aux (LE_id id, annot) -> Added_vars (vexp, mk_id_pat id) + | LE_aux (LE_typ (typ, id), annot) -> + let pat = add_p_typ env typ (annot_pat (P_id id) pl env (typ_of vexp)) in + Added_vars (vexp, pat) + | LE_aux (LE_vector (LE_aux (LE_id id, ((l2, _) as annot2)), i), ((l1, _) as annot)) -> + let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in + let vexp = annot_exp (E_vector_update (eid, i, vexp)) l1 env (typ_of_annot annot) in + let pat = annot_pat (P_id id) pl env (typ_of vexp) in + Added_vars (vexp, pat) + | LE_aux (LE_vector_range (LE_aux (LE_id id, ((l2, _) as annot2)), i, j), ((l, _) as annot)) -> + let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in + let vexp = annot_exp (E_vector_update_subrange (eid, i, j, vexp)) l env (typ_of_annot annot) in + let pat = annot_pat (P_id id) pl env (typ_of vexp) in + Added_vars (vexp, pat) + | _ -> Same_vars (E_aux (E_assign (lexp, vexp), annot)) + ) + | E_typ (typ, exp) -> begin + match rewrite used_vars exp pat with + | Added_vars (exp', pat') -> Added_vars (add_e_typ (env_of exp') (typ_of exp') exp', pat') + | Same_vars exp' -> Same_vars (E_aux (E_typ (typ, exp'), annot)) + end | _ -> - (* after rewrite_ast_letbind_effects this expression is pure and updates - no variables: check n_exp_term and where it's used. *) - Same_vars (E_aux (expaux,annot)) in + (* after rewrite_ast_letbind_effects this expression is pure and updates + no variables: check n_exp_term and where it's used. *) + Same_vars (E_aux (expaux, annot)) + in match expaux with - | E_let (lb,body) -> - let body = rewrite_var_updates body in - let (LB_aux (LB_val (pat, v), lbannot)) = lb in - let lb = match rewrite (find_used_vars body) v pat with - | Added_vars (v, P_aux (pat, _)) -> - annot_letbind (pat, v) (get_loc_exp v) env (typ_of v) - | Same_vars v -> LB_aux (LB_val (pat, v),lbannot) in - annot_exp (E_let (lb, body)) l env (typ_of body) - | E_var (lexp,v,body) -> - (* Rewrite E_var into E_let and call recursively *) - let rec aux lexp = match lexp with - | LE_aux (LE_id id, _) -> - P_id id, typ_of v - | LE_aux (LE_typ (typ, id), _) -> - unaux_pat (add_p_typ env typ (annot_pat (P_id id) l env (typ_of v))), typ - | LE_aux (LE_tuple lexps, _) -> - let pauxs_typs = List.map aux lexps in - let pats, typs = List.split (List.map (fun (paux, typ) -> - annot_pat paux l env typ, typ) pauxs_typs) in - P_tuple pats, mk_typ (Typ_tuple typs) - | _ -> - raise (Reporting.err_unreachable l __POS__ - ("E_var with a lexp that is not a variable: " ^ string_of_lexp lexp)) in - let paux, typ = aux lexp in - let lb = annot_letbind (paux, v) l env typ in - let exp = annot_exp (E_let (lb, body)) l env (typ_of body) in - rewrite_var_updates exp + | E_let (lb, body) -> + let body = rewrite_var_updates body in + let (LB_aux (LB_val (pat, v), lbannot)) = lb in + let lb = + match rewrite (find_used_vars body) v pat with + | Added_vars (v, P_aux (pat, _)) -> annot_letbind (pat, v) (get_loc_exp v) env (typ_of v) + | Same_vars v -> LB_aux (LB_val (pat, v), lbannot) + in + annot_exp (E_let (lb, body)) l env (typ_of body) + | E_var (lexp, v, body) -> + (* Rewrite E_var into E_let and call recursively *) + let rec aux lexp = + match lexp with + | LE_aux (LE_id id, _) -> (P_id id, typ_of v) + | LE_aux (LE_typ (typ, id), _) -> (unaux_pat (add_p_typ env typ (annot_pat (P_id id) l env (typ_of v))), typ) + | LE_aux (LE_tuple lexps, _) -> + let pauxs_typs = List.map aux lexps in + let pats, typs = List.split (List.map (fun (paux, typ) -> (annot_pat paux l env typ, typ)) pauxs_typs) in + (P_tuple pats, mk_typ (Typ_tuple typs)) + | _ -> + raise + (Reporting.err_unreachable l __POS__ ("E_var with a lexp that is not a variable: " ^ string_of_lexp lexp)) + in + let paux, typ = aux lexp in + let lb = annot_letbind (paux, v) l env typ in + let exp = annot_exp (E_let (lb, body)) l env (typ_of body) in + rewrite_var_updates exp | E_for _ | E_loop _ | E_if _ | E_match _ | E_assign _ -> - let var_id = fresh_id "u__" l in - let lb = LB_aux (LB_val (P_aux (P_id var_id, annot), exp), annot) in - let exp' = E_aux (E_let (lb, E_aux (E_id var_id, annot)), annot) in - rewrite_var_updates exp' - | E_internal_plet (pat,v,body) -> - failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" + let var_id = fresh_id "u__" l in + let lb = LB_aux (LB_val (P_aux (P_id var_id, annot), exp), annot) in + let exp' = E_aux (E_let (lb, E_aux (E_id var_id, annot)), annot) in + rewrite_var_updates exp' + | E_internal_plet (pat, v, body) -> failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" | E_typ (typ, exp) -> - let exp' = rewrite_var_updates exp in - E_aux (E_typ (typ, exp'), annot) + let exp' = rewrite_var_updates exp in + E_aux (E_typ (typ, exp'), annot) | E_internal_assume (nc, exp) -> - let exp' = rewrite_var_updates exp in - E_aux (E_internal_assume (nc, exp'), annot) + let exp' = rewrite_var_updates exp in + E_aux (E_internal_assume (nc, exp'), annot) (* There are no other expressions that have effects or variable updates in "tail-position": check the definition nexp_term and where it is used. *) | _ -> exp let replace_memwrite_e_assign exp = - let e_aux = fun (expaux,annot) -> + let e_aux (expaux, annot) = match expaux with - | E_assign (LE_aux (LE_app (id,args),_),v) -> E_aux (E_app (id,args @ [v]),annot) - | _ -> E_aux (expaux,annot) in - fold_exp { id_exp_alg with e_aux = e_aux } exp - - + | E_assign (LE_aux (LE_app (id, args), _), v) -> E_aux (E_app (id, args @ [v]), annot) + | _ -> E_aux (expaux, annot) + in + fold_exp { id_exp_alg with e_aux } exp let remove_reference_types exp = - - let rec rewrite_t (Typ_aux (t_aux,a)) = (Typ_aux (rewrite_t_aux t_aux,a)) - and rewrite_t_aux t_aux = match t_aux with - | Typ_app (Id_aux (Id "reg",_), [A_aux (A_typ (Typ_aux (t_aux2, _)), _)]) -> - rewrite_t_aux t_aux2 - | Typ_app (name,t_args) -> Typ_app (name,List.map rewrite_t_arg t_args) + let rec rewrite_t (Typ_aux (t_aux, a)) = Typ_aux (rewrite_t_aux t_aux, a) + and rewrite_t_aux t_aux = + match t_aux with + | Typ_app (Id_aux (Id "reg", _), [A_aux (A_typ (Typ_aux (t_aux2, _)), _)]) -> rewrite_t_aux t_aux2 + | Typ_app (name, t_args) -> Typ_app (name, List.map rewrite_t_arg t_args) | Typ_fn (arg_typs, ret_typ) -> Typ_fn (List.map rewrite_t arg_typs, rewrite_t ret_typ) | Typ_tuple ts -> Typ_tuple (List.map rewrite_t ts) | _ -> t_aux - and rewrite_t_arg t_arg = match t_arg with - | A_aux (A_typ t, a) -> A_aux (A_typ (rewrite_t t), a) - | _ -> t_arg in + and rewrite_t_arg t_arg = match t_arg with A_aux (A_typ t, a) -> A_aux (A_typ (rewrite_t t), a) | _ -> t_arg in let rewrite_annot (l, tannot) = match destruct_tannot tannot with - | None -> l, empty_tannot - | Some (_, typ) -> l, replace_typ (rewrite_t typ) tannot in + | None -> (l, empty_tannot) + | Some (_, typ) -> (l, replace_typ (rewrite_t typ) tannot) + in map_exp_annot rewrite_annot exp - - let rewrite_ast_remove_superfluous_letbinds env = - - let e_aux (exp,annot) = match exp with - | E_let (LB_aux (LB_val (pat, exp1), _), exp2) - | E_internal_plet (pat, exp1, exp2) -> - begin match untyp_pat pat, uncast_exp exp1, uncast_exp exp2 with - (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) - | (P_aux (P_id id, _), _), _, (E_aux (E_id id', _), _) - when Id.compare id id' = 0 -> - exp1 - (* "let _ = () in exp" can be replaced with exp *) - | (P_aux (P_wild, _), _), (E_aux (E_lit (L_aux (L_unit, _)), _), _), _ -> - exp2 - (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at - least when EXP1 is 'small' enough *) - | (P_aux (P_id id, _), _), _, (E_aux (E_internal_return (E_aux (E_id id', _)), _), _) - when Id.compare id id' = 0 && small exp1 && not (effectful exp1) -> - let (E_aux (_,e1annot)) = exp1 in - E_aux (E_internal_return (exp1),e1annot) - | _, (E_aux (E_throw e, a), _), _ -> E_aux (E_throw e, a) - | (pat, _), (E_aux (E_assert (c, msg), a) as assert_exp, _), _ -> - begin match typ_of c with - | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) - when prove __POS__ (env_of c) (nc_not nc) -> - (* Drop rest of block after an 'assert(false)' *) - let exit_exp = E_aux (E_exit (infer_exp (env_of c) (mk_lit_exp L_unit)), a) in - E_aux (E_internal_plet (pat, assert_exp, exit_exp), annot) - | _ -> - E_aux (exp, annot) + let e_aux (exp, annot) = + match exp with + | E_let (LB_aux (LB_val (pat, exp1), _), exp2) | E_internal_plet (pat, exp1, exp2) -> begin + match (untyp_pat pat, uncast_exp exp1, uncast_exp exp2) with + (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) + | (P_aux (P_id id, _), _), _, (E_aux (E_id id', _), _) when Id.compare id id' = 0 -> exp1 + (* "let _ = () in exp" can be replaced with exp *) + | (P_aux (P_wild, _), _), (E_aux (E_lit (L_aux (L_unit, _)), _), _), _ -> exp2 + (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at + least when EXP1 is 'small' enough *) + | (P_aux (P_id id, _), _), _, (E_aux (E_internal_return (E_aux (E_id id', _)), _), _) + when Id.compare id id' = 0 && small exp1 && not (effectful exp1) -> + let (E_aux (_, e1annot)) = exp1 in + E_aux (E_internal_return exp1, e1annot) + | _, (E_aux (E_throw e, a), _), _ -> E_aux (E_throw e, a) + | (pat, _), ((E_aux (E_assert (c, msg), a) as assert_exp), _), _ -> begin + match typ_of c with + | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) + when prove __POS__ (env_of c) (nc_not nc) -> + (* Drop rest of block after an 'assert(false)' *) + let exit_exp = E_aux (E_exit (infer_exp (env_of c) (mk_lit_exp L_unit)), a) in + E_aux (E_internal_plet (pat, assert_exp, exit_exp), annot) + | _ -> E_aux (exp, annot) end - | _ -> E_aux (exp,annot) - end - | _ -> E_aux (exp,annot) in + | _ -> E_aux (exp, annot) + end + | _ -> E_aux (exp, annot) + in - let alg = { id_exp_alg with e_aux = e_aux } in + let alg = { id_exp_alg with e_aux } in rewrite_ast_base - { rewrite_exp = (fun _ -> fold_exp alg) - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base + { + rewrite_exp = (fun _ -> fold_exp alg); + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; } (* FIXME: We shouldn't allow nested not-patterns *) @@ -3611,190 +3671,209 @@ let rewrite_ast_not_pats env = let rewrite_not_pat (pat_aux, annot) = match pat_aux with | P_not pat -> - incr not_counter; - let np_id = mk_id ("np#" ^ string_of_int !not_counter) in - let guard = - mk_exp (E_match (mk_exp (E_id np_id), - [mk_pexp (Pat_exp (strip_pat pat, mk_lit_exp L_false)); - mk_pexp (Pat_exp (mk_pat P_wild, mk_lit_exp L_true))])) - in - guards := (np_id, typ_of_annot annot, guard) :: !guards; - P_aux (P_id np_id, annot) - + incr not_counter; + let np_id = mk_id ("np#" ^ string_of_int !not_counter) in + let guard = + mk_exp + (E_match + ( mk_exp (E_id np_id), + [ + mk_pexp (Pat_exp (strip_pat pat, mk_lit_exp L_false)); + mk_pexp (Pat_exp (mk_pat P_wild, mk_lit_exp L_true)); + ] + ) + ) + in + guards := (np_id, typ_of_annot annot, guard) :: !guards; + P_aux (P_id np_id, annot) | _ -> P_aux (pat_aux, annot) - in - let pat = fold_pat { id_pat_alg with p_aux = rewrite_not_pat } pat in - begin match !guards with - | [] -> - Pat_aux (pexp_aux, annot) - | guards -> - let guard_exp = - match orig_guard, guards with - | Some guard, _ -> - List.fold_left (fun exp1 (_, _, exp2) -> mk_exp (E_app_infix (exp1, mk_id "&", exp2))) guard guards - | None, (_, _, guard) :: guards -> - List.fold_left (fun exp1 (_, _, exp2) -> mk_exp (E_app_infix (exp1, mk_id "&", exp2))) guard guards - | _ -> raise (Reporting.err_unreachable (fst annot) __POS__ "Case in not-pattern re-writing should be unreachable") - in - (* We need to construct an environment to check the match guard in *) - let env = env_of_pat pat in - let env = List.fold_left (fun env (np_id, np_typ, _) -> Env.add_local np_id (Immutable, np_typ) env) env guards in - let guard_exp = Type_check.check_exp env guard_exp bool_typ in - Pat_aux (Pat_when (pat, guard_exp, exp), annot) - end + in + let pat = fold_pat { id_pat_alg with p_aux = rewrite_not_pat } pat in + begin + match !guards with + | [] -> Pat_aux (pexp_aux, annot) + | guards -> + let guard_exp = + match (orig_guard, guards) with + | Some guard, _ -> + List.fold_left (fun exp1 (_, _, exp2) -> mk_exp (E_app_infix (exp1, mk_id "&", exp2))) guard guards + | None, (_, _, guard) :: guards -> + List.fold_left (fun exp1 (_, _, exp2) -> mk_exp (E_app_infix (exp1, mk_id "&", exp2))) guard guards + | _ -> + raise + (Reporting.err_unreachable (fst annot) __POS__ + "Case in not-pattern re-writing should be unreachable" + ) + in + (* We need to construct an environment to check the match guard in *) + let env = env_of_pat pat in + let env = + List.fold_left (fun env (np_id, np_typ, _) -> Env.add_local np_id (Immutable, np_typ) env) env guards + in + let guard_exp = Type_check.check_exp env guard_exp bool_typ in + Pat_aux (Pat_when (pat, guard_exp, exp), annot) + end in match pexp_aux with - | Pat_exp (pat, exp) -> - rewrite_pexp' pat exp None - | Pat_when (pat, guard, exp) -> - rewrite_pexp' pat exp (Some (strip_exp guard)) + | Pat_exp (pat, exp) -> rewrite_pexp' pat exp None + | Pat_when (pat, guard, exp) -> rewrite_pexp' pat exp (Some (strip_exp guard)) in let rw_exp = { id_exp_alg with pat_aux = rewrite_pexp } in rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } let rewrite_ast_remove_superfluous_returns env = - let add_opt_cast typopt1 typopt2 annot exp = - match typopt1, typopt2 with - | Some typ, _ | _, Some typ -> add_e_typ (env_of exp) typ exp - | None, None -> exp - in - - let e_aux (exp,annot) = match exp with - | E_let (LB_aux (LB_val (pat, exp1), _), exp2) - | E_internal_plet (pat, exp1, exp2) - when effectful exp1 -> - begin match untyp_pat pat, uncast_exp exp2 with - | (P_aux (P_lit (L_aux (lit,_)),_), ptyp), - (E_aux (E_internal_return (E_aux (E_lit (L_aux (lit',_)),_)), a), etyp) - when lit = lit' -> - add_opt_cast ptyp etyp a exp1 - | (P_aux (P_wild,pannot), ptyp), - (E_aux ((E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)) | E_lit (L_aux (L_unit,_))), a), etyp) - when is_unit_typ (typ_of exp1) -> - add_opt_cast ptyp etyp a exp1 - | (P_aux (P_id id,_), ptyp), - (E_aux (E_internal_return (E_aux (E_id id',_)), a), etyp) - when Id.compare id id' == 0 -> - add_opt_cast ptyp etyp a exp1 - | (P_aux (P_tuple ps, _), ptyp), - (E_aux (E_internal_return (E_aux (E_tuple es, _)), a), etyp) - when List.length ps = List.length es -> - let same_id (P_aux (p, _)) (E_aux (e, _)) = match p, e with - | P_id id, E_id id' -> Id.compare id id' == 0 - | _, _ -> false - in - let ps = List.map fst (List.map untyp_pat ps) in - let es = List.map fst (List.map uncast_exp es) in - if List.for_all2 same_id ps es - then add_opt_cast ptyp etyp a exp1 - else E_aux (exp,annot) - | _ -> E_aux (exp,annot) - end - | _ -> E_aux (exp,annot) in - - let alg = { id_exp_alg with e_aux = e_aux } in + match (typopt1, typopt2) with Some typ, _ | _, Some typ -> add_e_typ (env_of exp) typ exp | None, None -> exp + in + + let e_aux (exp, annot) = + match exp with + | (E_let (LB_aux (LB_val (pat, exp1), _), exp2) | E_internal_plet (pat, exp1, exp2)) when effectful exp1 -> begin + match (untyp_pat pat, uncast_exp exp2) with + | ( (P_aux (P_lit (L_aux (lit, _)), _), ptyp), + (E_aux (E_internal_return (E_aux (E_lit (L_aux (lit', _)), _)), a), etyp) ) + when lit = lit' -> + add_opt_cast ptyp etyp a exp1 + | ( (P_aux (P_wild, pannot), ptyp), + (E_aux ((E_internal_return (E_aux (E_lit (L_aux (L_unit, _)), _)) | E_lit (L_aux (L_unit, _))), a), etyp) ) + when is_unit_typ (typ_of exp1) -> + add_opt_cast ptyp etyp a exp1 + | (P_aux (P_id id, _), ptyp), (E_aux (E_internal_return (E_aux (E_id id', _)), a), etyp) + when Id.compare id id' == 0 -> + add_opt_cast ptyp etyp a exp1 + | (P_aux (P_tuple ps, _), ptyp), (E_aux (E_internal_return (E_aux (E_tuple es, _)), a), etyp) + when List.length ps = List.length es -> + let same_id (P_aux (p, _)) (E_aux (e, _)) = + match (p, e) with P_id id, E_id id' -> Id.compare id id' == 0 | _, _ -> false + in + let ps = List.map fst (List.map untyp_pat ps) in + let es = List.map fst (List.map uncast_exp es) in + if List.for_all2 same_id ps es then add_opt_cast ptyp etyp a exp1 else E_aux (exp, annot) + | _ -> E_aux (exp, annot) + end + | _ -> E_aux (exp, annot) + in + + let alg = { id_exp_alg with e_aux } in rewrite_ast_base - { rewrite_exp = (fun _ -> fold_exp alg) - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base + { + rewrite_exp = (fun _ -> fold_exp alg); + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; } - let rewrite_ast_remove_e_assign env ast = - let loop_specs = fst (Type_error.check_defs initial_env (List.map (gen_vs ~pure:true) - [("foreach#", "forall ('vars_in 'vars_out : Type). (int, int, int, bool, 'vars_in, 'vars_out) -> 'vars_out"); - ("while#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out"); - ("until#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out"); - ("while#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out"); - ("until#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out")])) in - let rewrite_exp _ e = - replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in + let loop_specs = + fst + (Type_error.check_defs initial_env + (List.map (gen_vs ~pure:true) + [ + ("foreach#", "forall ('vars_in 'vars_out : Type). (int, int, int, bool, 'vars_in, 'vars_out) -> 'vars_out"); + ("while#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out"); + ("until#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out"); + ("while#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out"); + ("until#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out"); + ] + ) + ) + in + let rewrite_exp _ e = replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_ast_base - { rewrite_exp = rewrite_exp - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base - } { ast with defs = loop_specs @ ast.defs } + { rewrite_exp; rewrite_pat; rewrite_let; rewrite_lexp; rewrite_fun; rewrite_def; rewrite_ast = rewrite_ast_base } + { ast with defs = loop_specs @ ast.defs } let merge_funcls env ast = - let merge_function (FD_aux (FD_function (r,t,fcls),ann) as f) = + let merge_function (FD_aux (FD_function (r, t, fcls), ann) as f) = match fcls with | [] | [_] -> f - | (FCL_aux (FCL_funcl (id,_),(def_annot,_)))::_ -> - let l = def_annot.loc in - let var = mk_id "merge#var" in - let l_g = Parse_ast.Generated l in - let ann_g : _ * tannot = (l_g,empty_tannot) in - let clauses = List.map (fun (FCL_aux (FCL_funcl (_,pexp),_)) -> pexp) fcls in - FD_aux (FD_function (r,t,[ - FCL_aux (FCL_funcl (id,Pat_aux (Pat_exp (P_aux (P_id var,ann_g), - E_aux (E_match (E_aux (E_id var,ann_g),clauses),ann_g)),ann_g)), - (mk_def_annot l,empty_tannot))]),ann) + | FCL_aux (FCL_funcl (id, _), (def_annot, _)) :: _ -> + let l = def_annot.loc in + let var = mk_id "merge#var" in + let l_g = Parse_ast.Generated l in + let ann_g : _ * tannot = (l_g, empty_tannot) in + let clauses = List.map (fun (FCL_aux (FCL_funcl (_, pexp), _)) -> pexp) fcls in + FD_aux + ( FD_function + ( r, + t, + [ + FCL_aux + ( FCL_funcl + ( id, + Pat_aux + ( Pat_exp + (P_aux (P_id var, ann_g), E_aux (E_match (E_aux (E_id var, ann_g), clauses), ann_g)), + ann_g + ) + ), + (mk_def_annot l, empty_tannot) + ); + ] + ), + ann + ) in let merge_in_def = function | DEF_aux (DEF_fundef f, def_annot) -> DEF_aux (DEF_fundef (merge_function f), def_annot) - | DEF_aux (DEF_internal_mutrec fs, def_annot) -> DEF_aux (DEF_internal_mutrec (List.map merge_function fs), def_annot) + | DEF_aux (DEF_internal_mutrec fs, def_annot) -> + DEF_aux (DEF_internal_mutrec (List.map merge_function fs), def_annot) | d -> d - in { ast with defs = List.map merge_in_def ast.defs } + in + { ast with defs = List.map merge_in_def ast.defs } let rec pat_of_mpat (MP_aux (mpat, annot)) = match mpat with - | MP_lit lit -> P_aux (P_lit lit, annot) - | MP_id id -> P_aux (P_id id, annot) - | MP_app (id, args) -> P_aux (P_app (id, (List.map pat_of_mpat args)), annot) - | MP_vector mpats -> P_aux (P_vector (List.map pat_of_mpat mpats), annot) - | MP_vector_concat mpats -> P_aux (P_vector_concat (List.map pat_of_mpat mpats), annot) - | MP_vector_subrange (id, n, m) -> P_aux (P_vector_subrange (id, n, m), annot) - | MP_tuple mpats -> P_aux (P_tuple (List.map pat_of_mpat mpats), annot) - | MP_list mpats -> P_aux (P_list (List.map pat_of_mpat mpats), annot) - | MP_cons (mpat1, mpat2) -> P_aux ((P_cons (pat_of_mpat mpat1, pat_of_mpat mpat2), annot)) - | MP_string_append (mpats) -> P_aux ((P_string_append (List.map pat_of_mpat mpats), annot)) - | MP_typ (mpat, typ) -> P_aux (P_typ (typ, pat_of_mpat mpat), annot) - | MP_as (mpat, id) -> P_aux (P_as (pat_of_mpat mpat, id), annot) + | MP_lit lit -> P_aux (P_lit lit, annot) + | MP_id id -> P_aux (P_id id, annot) + | MP_app (id, args) -> P_aux (P_app (id, List.map pat_of_mpat args), annot) + | MP_vector mpats -> P_aux (P_vector (List.map pat_of_mpat mpats), annot) + | MP_vector_concat mpats -> P_aux (P_vector_concat (List.map pat_of_mpat mpats), annot) + | MP_vector_subrange (id, n, m) -> P_aux (P_vector_subrange (id, n, m), annot) + | MP_tuple mpats -> P_aux (P_tuple (List.map pat_of_mpat mpats), annot) + | MP_list mpats -> P_aux (P_list (List.map pat_of_mpat mpats), annot) + | MP_cons (mpat1, mpat2) -> P_aux (P_cons (pat_of_mpat mpat1, pat_of_mpat mpat2), annot) + | MP_string_append mpats -> P_aux (P_string_append (List.map pat_of_mpat mpats), annot) + | MP_typ (mpat, typ) -> P_aux (P_typ (typ, pat_of_mpat mpat), annot) + | MP_as (mpat, id) -> P_aux (P_as (pat_of_mpat mpat, id), annot) let rec exp_of_mpat (MP_aux (mpat, (l, annot))) = let empty_vec = E_aux (E_vector [], (l, empty_uannot)) in - let concat_vectors vec1 vec2 = - E_aux (E_vector_append (vec1, vec2), (l, empty_uannot)) - in + let concat_vectors vec1 vec2 = E_aux (E_vector_append (vec1, vec2), (l, empty_uannot)) in let empty_string = E_aux (E_lit (L_aux (L_string "", Parse_ast.Unknown)), (l, empty_uannot)) in - let string_append str1 str2 = - E_aux (E_app (mk_id "string_append", [str1; str2]), (l, empty_uannot)) - in + let string_append str1 str2 = E_aux (E_app (mk_id "string_append", [str1; str2]), (l, empty_uannot)) in match mpat with - | MP_lit lit -> E_aux (E_lit lit, (l,annot)) - | MP_id id -> E_aux (E_id id, (l,annot)) - | MP_app (id, args) -> E_aux (E_app (id, (List.map exp_of_mpat args)), (l,annot)) - | MP_vector mpats -> E_aux (E_vector (List.map exp_of_mpat mpats), (l,annot)) - | MP_vector_concat mpats -> List.fold_right concat_vectors (List.map (fun m -> exp_of_mpat m) mpats) empty_vec - | MP_vector_subrange (id, n, m) -> E_aux (E_vector_subrange (mk_exp ~loc:(id_loc id) (E_id id), mk_lit_exp (L_num n), mk_lit_exp (L_num m)), (l, annot)) - | MP_tuple mpats -> E_aux (E_tuple (List.map exp_of_mpat mpats), (l,annot)) - | MP_list mpats -> E_aux (E_list (List.map exp_of_mpat mpats), (l,annot)) - | MP_cons (mpat1, mpat2) -> E_aux (E_cons (exp_of_mpat mpat1, exp_of_mpat mpat2), (l,annot)) - | MP_string_append mpats -> List.fold_right string_append (List.map (fun m -> exp_of_mpat m) mpats) empty_string - | MP_typ (mpat, typ) -> E_aux (E_typ (typ, exp_of_mpat mpat), (l,annot)) - | MP_as (mpat, id) -> E_aux (E_match (E_aux (E_id id, (l,annot)), [ - Pat_aux (Pat_exp (pat_of_mpat mpat, exp_of_mpat mpat), (l,annot)) - ]), (l, annot)) (* TODO FIXME location information? *) + | MP_lit lit -> E_aux (E_lit lit, (l, annot)) + | MP_id id -> E_aux (E_id id, (l, annot)) + | MP_app (id, args) -> E_aux (E_app (id, List.map exp_of_mpat args), (l, annot)) + | MP_vector mpats -> E_aux (E_vector (List.map exp_of_mpat mpats), (l, annot)) + | MP_vector_concat mpats -> List.fold_right concat_vectors (List.map (fun m -> exp_of_mpat m) mpats) empty_vec + | MP_vector_subrange (id, n, m) -> + E_aux + (E_vector_subrange (mk_exp ~loc:(id_loc id) (E_id id), mk_lit_exp (L_num n), mk_lit_exp (L_num m)), (l, annot)) + | MP_tuple mpats -> E_aux (E_tuple (List.map exp_of_mpat mpats), (l, annot)) + | MP_list mpats -> E_aux (E_list (List.map exp_of_mpat mpats), (l, annot)) + | MP_cons (mpat1, mpat2) -> E_aux (E_cons (exp_of_mpat mpat1, exp_of_mpat mpat2), (l, annot)) + | MP_string_append mpats -> List.fold_right string_append (List.map (fun m -> exp_of_mpat m) mpats) empty_string + | MP_typ (mpat, typ) -> E_aux (E_typ (typ, exp_of_mpat mpat), (l, annot)) + | MP_as (mpat, id) -> + E_aux + ( E_match (E_aux (E_id id, (l, annot)), [Pat_aux (Pat_exp (pat_of_mpat mpat, exp_of_mpat mpat), (l, annot))]), + (l, annot) + ) +(* TODO FIXME location information? *) let rewrite_ast_realize_mappings effect_info env ast = let effect_info = ref effect_info in let realize_mpexps forwards mpexp1 mpexp2 = - let mpexp_pat, mpexp_exp = - if forwards then mpexp1, mpexp2 else mpexp2, mpexp1 - in + let mpexp_pat, mpexp_exp = if forwards then (mpexp1, mpexp2) else (mpexp2, mpexp1) in let exp = match mpexp_exp with - | MPat_aux ((MPat_pat mpat), _) -> exp_of_mpat mpat - | MPat_aux ((MPat_when (mpat, _), _)) -> exp_of_mpat mpat + | MPat_aux (MPat_pat mpat, _) -> exp_of_mpat mpat + | MPat_aux (MPat_when (mpat, _), _) -> exp_of_mpat mpat in match mpexp_pat with | MPat_aux (MPat_pat mpat, annot) -> Pat_aux (Pat_exp (pat_of_mpat mpat, exp), annot) @@ -3802,201 +3881,327 @@ let rewrite_ast_realize_mappings effect_info env ast = in let realize_single_mpexp mpexp exp = match mpexp with - | MPat_aux (MPat_pat mpat, annot) -> - Pat_aux (Pat_exp (pat_of_mpat mpat, exp), annot) - | MPat_aux (MPat_when (mpat, guard), annot) -> - Pat_aux (Pat_when (pat_of_mpat mpat, guard, exp), annot) + | MPat_aux (MPat_pat mpat, annot) -> Pat_aux (Pat_exp (pat_of_mpat mpat, exp), annot) + | MPat_aux (MPat_when (mpat, guard), annot) -> Pat_aux (Pat_when (pat_of_mpat mpat, guard, exp), annot) in let realize_mapcl forwards id mapcl = match mapcl with - | (MCL_aux (MCL_bidir (mpexp1, mpexp2), _)) -> - [realize_mpexps forwards mpexp1 mpexp2] - | (MCL_aux (MCL_forwards (mpexp, exp), _)) -> - if forwards then - [realize_single_mpexp mpexp exp] - else - [] - | (MCL_aux (MCL_backwards (mpexp, exp), _)) -> - if forwards then - [] - else - [realize_single_mpexp mpexp exp] + | MCL_aux (MCL_bidir (mpexp1, mpexp2), _) -> [realize_mpexps forwards mpexp1 mpexp2] + | MCL_aux (MCL_forwards (mpexp, exp), _) -> if forwards then [realize_single_mpexp mpexp exp] else [] + | MCL_aux (MCL_backwards (mpexp, exp), _) -> if forwards then [] else [realize_single_mpexp mpexp exp] in let realize_bool_mapcl forwards id mapcl = match mapcl with - | (MCL_aux (MCL_bidir (mpexp1, mpexp2), _)) -> - let mpexp = if forwards then mpexp1 else mpexp2 in - [realize_mpexps true mpexp (mk_mpexp (MPat_pat (mk_mpat (MP_lit (mk_lit L_true)))))] - | (MCL_aux (MCL_forwards (mpexp, exp), _)) -> - if forwards then - [realize_single_mpexp mpexp (mk_lit_exp L_true)] - else - [] - | (MCL_aux (MCL_backwards (mpexp, exp), _)) -> - if forwards then - [] - else - [realize_single_mpexp mpexp (mk_lit_exp L_true)] + | MCL_aux (MCL_bidir (mpexp1, mpexp2), _) -> + let mpexp = if forwards then mpexp1 else mpexp2 in + [realize_mpexps true mpexp (mk_mpexp (MPat_pat (mk_mpat (MP_lit (mk_lit L_true)))))] + | MCL_aux (MCL_forwards (mpexp, exp), _) -> + if forwards then [realize_single_mpexp mpexp (mk_lit_exp L_true)] else [] + | MCL_aux (MCL_backwards (mpexp, exp), _) -> + if forwards then [] else [realize_single_mpexp mpexp (mk_lit_exp L_true)] in let arg_id = mk_id "arg#" in - let arg_exp = (mk_exp (E_id arg_id)) in + let arg_exp = mk_exp (E_id arg_id) in let arg_pat = mk_pat (P_id arg_id) in let placeholder_id = mk_id "s#" in let append_placeholder = function | MPat_aux (MPat_pat (MP_aux (MP_string_append mpats, p_annot)), aux_annot) -> - MPat_aux (MPat_pat (MP_aux (MP_string_append (mpats @ [mk_mpat (MP_id placeholder_id)]), p_annot)), aux_annot) + MPat_aux (MPat_pat (MP_aux (MP_string_append (mpats @ [mk_mpat (MP_id placeholder_id)]), p_annot)), aux_annot) | MPat_aux (MPat_when (MP_aux (MP_string_append mpats, p_annot), guard), aux_annot) -> - MPat_aux (MPat_when (MP_aux (MP_string_append (mpats @ [mk_mpat (MP_id placeholder_id)]), p_annot), guard), aux_annot) + MPat_aux + (MPat_when (MP_aux (MP_string_append (mpats @ [mk_mpat (MP_id placeholder_id)]), p_annot), guard), aux_annot) | MPat_aux (MPat_pat mpat, aux_annot) -> - MPat_aux (MPat_pat (mk_mpat (MP_string_append [mpat; mk_mpat (MP_id placeholder_id)])), aux_annot) + MPat_aux (MPat_pat (mk_mpat (MP_string_append [mpat; mk_mpat (MP_id placeholder_id)])), aux_annot) | MPat_aux (MPat_when (mpat, guard), aux_annot) -> - MPat_aux (MPat_when (mk_mpat (MP_string_append [mpat; mk_mpat (MP_id placeholder_id)]), guard), aux_annot) + MPat_aux (MPat_when (mk_mpat (MP_string_append [mpat; mk_mpat (MP_id placeholder_id)]), guard), aux_annot) in let realize_prefix_mapcl forwards id mapcl = - let strlen = ( - mk_mpat (MP_app ( mk_id "sub_nat", - [ - mk_mpat (MP_app ( mk_id "string_length" , [mk_mpat (MP_id arg_id )])); - mk_mpat (MP_app ( mk_id "string_length" , [mk_mpat (MP_id placeholder_id)])); - ] - )) - ) in + let strlen = + mk_mpat + (MP_app + ( mk_id "sub_nat", + [ + mk_mpat (MP_app (mk_id "string_length", [mk_mpat (MP_id arg_id)])); + mk_mpat (MP_app (mk_id "string_length", [mk_mpat (MP_id placeholder_id)])); + ] + ) + ) + in match mapcl with - | (MCL_aux (MCL_bidir (mpexp1, mpexp2), _)) -> begin - let mpexp = if forwards then mpexp1 else mpexp2 in - let other = if forwards then mpexp2 else mpexp1 in - match other with - | MPat_aux (MPat_pat mpat2, _) - | MPat_aux (MPat_when (mpat2, _), _)-> - [realize_mpexps true (append_placeholder mpexp) (mk_mpexp (MPat_pat (mk_mpat (MP_app ((mk_id "Some"), [ mk_mpat (MP_tuple [mpat2; strlen]) ])))))] + | MCL_aux (MCL_bidir (mpexp1, mpexp2), _) -> begin + let mpexp = if forwards then mpexp1 else mpexp2 in + let other = if forwards then mpexp2 else mpexp1 in + match other with + | MPat_aux (MPat_pat mpat2, _) | MPat_aux (MPat_when (mpat2, _), _) -> + [ + realize_mpexps true (append_placeholder mpexp) + (mk_mpexp (MPat_pat (mk_mpat (MP_app (mk_id "Some", [mk_mpat (MP_tuple [mpat2; strlen])]))))); + ] end - | (MCL_aux (MCL_forwards (mpexp, exp), _)) -> begin + | MCL_aux (MCL_forwards (mpexp, exp), _) -> begin if forwards then - [realize_single_mpexp (append_placeholder mpexp) (mk_exp (E_app ((mk_id "Some"), [mk_exp (E_tuple [exp; exp_of_mpat strlen])])))] - else - [] + [ + realize_single_mpexp (append_placeholder mpexp) + (mk_exp (E_app (mk_id "Some", [mk_exp (E_tuple [exp; exp_of_mpat strlen])]))); + ] + else [] end - | (MCL_aux (MCL_backwards (mpexp, exp), _)) -> begin - if forwards then - [] + | MCL_aux (MCL_backwards (mpexp, exp), _) -> begin + if forwards then [] else - [realize_single_mpexp (append_placeholder mpexp) (mk_exp (E_app ((mk_id "Some"), [mk_exp (E_tuple [exp; exp_of_mpat strlen])])))] + [ + realize_single_mpexp (append_placeholder mpexp) + (mk_exp (E_app (mk_id "Some", [mk_exp (E_tuple [exp; exp_of_mpat strlen])]))); + ] end in let realize_val_spec def_annot = function - | (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, Typ_aux (Typ_bidir (typ1, typ2), l)), _), id, _, _), ((_, (tannot:tannot)) as annot))) -> - let forwards_id = mk_id (string_of_id id ^ "_forwards") in - let forwards_matches_id = mk_id (string_of_id id ^ "_forwards_matches") in - let backwards_id = mk_id (string_of_id id ^ "_backwards") in - let backwards_matches_id = mk_id (string_of_id id ^ "_backwards_matches") in - - effect_info := List.fold_left (Effects.copy_mapping_to_function id) !effect_info [forwards_id; forwards_matches_id; backwards_id; backwards_matches_id]; - - let env = env_of_annot annot in - let forwards_typ = Typ_aux (Typ_fn ([typ1], typ2), l) in - let forwards_matches_typ = Typ_aux (Typ_fn ([typ1], bool_typ), l) in - let backwards_typ = Typ_aux (Typ_fn ([typ2], typ1), l) in - let backwards_matches_typ = Typ_aux (Typ_fn ([typ2], bool_typ), l) in - - let forwards_spec = VS_aux (VS_val_spec (mk_typschm typq forwards_typ, forwards_id, None, false), no_annot) in - let backwards_spec = VS_aux (VS_val_spec (mk_typschm typq backwards_typ, backwards_id, None, false), no_annot) in - let forwards_matches_spec = VS_aux (VS_val_spec (mk_typschm typq forwards_matches_typ, forwards_matches_id, None, false), no_annot) in - let backwards_matches_spec = VS_aux (VS_val_spec (mk_typschm typq backwards_matches_typ, backwards_matches_id, None, false), no_annot) in - - let forwards_spec, env = Type_check.check_val_spec env def_annot forwards_spec in - let backwards_spec, env = Type_check.check_val_spec env def_annot backwards_spec in - let forwards_matches_spec, env = Type_check.check_val_spec env def_annot forwards_matches_spec in - let backwards_matches_spec, env = Type_check.check_val_spec env def_annot backwards_matches_spec in - - let prefix_id = mk_id (string_of_id id ^ "_matches_prefix") in - let string_defs = - begin if subtype_check env typ1 string_typ && subtype_check env string_typ typ1 then begin - effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; - let forwards_prefix_typ = Typ_aux (Typ_fn ([typ1], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ2; nat_typ]), Parse_ast.Unknown)]), Parse_ast.Unknown) in - let forwards_prefix_spec = VS_aux (VS_val_spec (mk_typschm typq forwards_prefix_typ, prefix_id, None, false), no_annot) in - let forwards_prefix_spec, env = Type_check.check_val_spec env def_annot forwards_prefix_spec in - forwards_prefix_spec - end else - if subtype_check env typ2 string_typ && subtype_check env string_typ typ2 then begin - effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; - let backwards_prefix_typ = Typ_aux (Typ_fn ([typ2], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ1; nat_typ]), Parse_ast.Unknown)]), Parse_ast.Unknown) in - let backwards_prefix_spec = VS_aux (VS_val_spec (mk_typschm typq backwards_prefix_typ, prefix_id, None, false), no_annot) in - let backwards_prefix_spec, env = Type_check.check_val_spec env def_annot backwards_prefix_spec in - backwards_prefix_spec - end else - [] - end - in + | VS_aux + ( VS_val_spec (TypSchm_aux (TypSchm_ts (typq, Typ_aux (Typ_bidir (typ1, typ2), l)), _), id, _, _), + ((_, (tannot : tannot)) as annot) + ) -> + let forwards_id = mk_id (string_of_id id ^ "_forwards") in + let forwards_matches_id = mk_id (string_of_id id ^ "_forwards_matches") in + let backwards_id = mk_id (string_of_id id ^ "_backwards") in + let backwards_matches_id = mk_id (string_of_id id ^ "_backwards_matches") in + + effect_info := + List.fold_left (Effects.copy_mapping_to_function id) !effect_info + [forwards_id; forwards_matches_id; backwards_id; backwards_matches_id]; - forwards_spec - @ backwards_spec - @ forwards_matches_spec - @ backwards_matches_spec - @ string_defs + let env = env_of_annot annot in + let forwards_typ = Typ_aux (Typ_fn ([typ1], typ2), l) in + let forwards_matches_typ = Typ_aux (Typ_fn ([typ1], bool_typ), l) in + let backwards_typ = Typ_aux (Typ_fn ([typ2], typ1), l) in + let backwards_matches_typ = Typ_aux (Typ_fn ([typ2], bool_typ), l) in + + let forwards_spec = VS_aux (VS_val_spec (mk_typschm typq forwards_typ, forwards_id, None, false), no_annot) in + let backwards_spec = + VS_aux (VS_val_spec (mk_typschm typq backwards_typ, backwards_id, None, false), no_annot) + in + let forwards_matches_spec = + VS_aux (VS_val_spec (mk_typschm typq forwards_matches_typ, forwards_matches_id, None, false), no_annot) + in + let backwards_matches_spec = + VS_aux (VS_val_spec (mk_typschm typq backwards_matches_typ, backwards_matches_id, None, false), no_annot) + in + + let forwards_spec, env = Type_check.check_val_spec env def_annot forwards_spec in + let backwards_spec, env = Type_check.check_val_spec env def_annot backwards_spec in + let forwards_matches_spec, env = Type_check.check_val_spec env def_annot forwards_matches_spec in + let backwards_matches_spec, env = Type_check.check_val_spec env def_annot backwards_matches_spec in + + let prefix_id = mk_id (string_of_id id ^ "_matches_prefix") in + let string_defs = + begin + if subtype_check env typ1 string_typ && subtype_check env string_typ typ1 then begin + effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; + let forwards_prefix_typ = + Typ_aux + ( Typ_fn + ([typ1], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ2; nat_typ]), Parse_ast.Unknown)]), + Parse_ast.Unknown + ) + in + let forwards_prefix_spec = + VS_aux (VS_val_spec (mk_typschm typq forwards_prefix_typ, prefix_id, None, false), no_annot) + in + let forwards_prefix_spec, env = Type_check.check_val_spec env def_annot forwards_prefix_spec in + forwards_prefix_spec + end + else if subtype_check env typ2 string_typ && subtype_check env string_typ typ2 then begin + effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; + let backwards_prefix_typ = + Typ_aux + ( Typ_fn + ([typ2], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ1; nat_typ]), Parse_ast.Unknown)]), + Parse_ast.Unknown + ) + in + let backwards_prefix_spec = + VS_aux (VS_val_spec (mk_typschm typq backwards_prefix_typ, prefix_id, None, false), no_annot) + in + let backwards_prefix_spec, env = Type_check.check_val_spec env def_annot backwards_prefix_spec in + backwards_prefix_spec + end + else [] + end + in + + forwards_spec @ backwards_spec @ forwards_matches_spec @ backwards_matches_spec @ string_defs | vs -> [DEF_aux (DEF_val vs, def_annot)] in - let realize_mapdef def_annot (MD_aux (MD_mapping (id, _, mapcls), (l, (tannot:tannot)))) = + let realize_mapdef def_annot (MD_aux (MD_mapping (id, _, mapcls), (l, (tannot : tannot)))) = let forwards_id = mk_id (string_of_id id ^ "_forwards") in let forwards_matches_id = mk_id (string_of_id id ^ "_forwards_matches") in let backwards_id = mk_id (string_of_id id ^ "_backwards") in let backwards_matches_id = mk_id (string_of_id id ^ "_backwards_matches") in - effect_info := List.fold_left (Effects.copy_mapping_to_function id) !effect_info [forwards_id; forwards_matches_id; backwards_id; backwards_matches_id]; - - let non_rec = (Rec_aux (Rec_nonrec, Parse_ast.Unknown)) in + effect_info := + List.fold_left (Effects.copy_mapping_to_function id) !effect_info + [forwards_id; forwards_matches_id; backwards_id; backwards_matches_id]; + + let non_rec = Rec_aux (Rec_nonrec, Parse_ast.Unknown) in (* We need to make sure we get the environment for the last mapping clause *) - let env = match List.rev mapcls with + let env = + match List.rev mapcls with | MCL_aux (_, (_, mapcl_tannot)) :: _ -> env_of_tannot mapcl_tannot | _ -> raise (Reporting.err_unreachable l __POS__ "mapping with no clauses?") in - let (typq, bidir_typ) = Env.get_val_spec id env in - let (typ1, typ2, l) = match bidir_typ with - | Typ_aux (Typ_bidir (typ1, typ2), l) -> typ1, typ2, l + let typq, bidir_typ = Env.get_val_spec id env in + let typ1, typ2, l = + match bidir_typ with + | Typ_aux (Typ_bidir (typ1, typ2), l) -> (typ1, typ2, l) | _ -> raise (Reporting.err_unreachable l __POS__ "non-bidir type of mapping?") in - let no_tannot = (Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown)) in - let forwards_match = mk_exp (E_match (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realize_mapcl true forwards_id) mapcls) |> List.flatten))) in - let backwards_match = mk_exp (E_match (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realize_mapcl false backwards_id) mapcls) |> List.flatten))) in + let no_tannot = Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown) in + let forwards_match = + mk_exp + (E_match + (arg_exp, List.map (fun mapcl -> strip_mapcl mapcl |> realize_mapcl true forwards_id) mapcls |> List.flatten) + ) + in + let backwards_match = + mk_exp + (E_match + ( arg_exp, + List.map (fun mapcl -> strip_mapcl mapcl |> realize_mapcl false backwards_id) mapcls |> List.flatten + ) + ) + in let wildcard = mk_pexp (Pat_exp (mk_pat P_wild, mk_exp (E_lit (mk_lit L_false)))) in - let forwards_matches_match = mk_exp (E_match (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realize_bool_mapcl true forwards_matches_id) mapcls) |> List.flatten) @ [wildcard])) in - let backwards_matches_match = mk_exp (E_match (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realize_bool_mapcl false backwards_matches_id) mapcls) |> List.flatten) @ [wildcard])) in - - let forwards_fun = (FD_aux (FD_function (non_rec, no_tannot, [mk_funcl forwards_id arg_pat forwards_match]), (l, empty_uannot))) in - let backwards_fun = (FD_aux (FD_function (non_rec, no_tannot, [mk_funcl backwards_id arg_pat backwards_match]), (l, empty_uannot))) in - let forwards_matches_fun = (FD_aux (FD_function (non_rec, no_tannot, [mk_funcl forwards_matches_id arg_pat forwards_matches_match]), (l, empty_uannot))) in - let backwards_matches_fun = (FD_aux (FD_function (non_rec, no_tannot, [mk_funcl backwards_matches_id arg_pat backwards_matches_match]), (l, empty_uannot))) in - - typ_debug (lazy (Printf.sprintf "forwards for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef forwards_fun |> Pretty_print_sail.to_string))); - typ_debug (lazy (Printf.sprintf "backwards for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef backwards_fun |> Pretty_print_sail.to_string))); - typ_debug (lazy (Printf.sprintf "forwards matches for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef forwards_matches_fun |> Pretty_print_sail.to_string))); - typ_debug (lazy (Printf.sprintf "backwards matches for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef backwards_matches_fun |> Pretty_print_sail.to_string))); + let forwards_matches_match = + mk_exp + (E_match + ( arg_exp, + (List.map (fun mapcl -> strip_mapcl mapcl |> realize_bool_mapcl true forwards_matches_id) mapcls + |> List.flatten + ) + @ [wildcard] + ) + ) + in + let backwards_matches_match = + mk_exp + (E_match + ( arg_exp, + (List.map (fun mapcl -> strip_mapcl mapcl |> realize_bool_mapcl false backwards_matches_id) mapcls + |> List.flatten + ) + @ [wildcard] + ) + ) + in + + let forwards_fun = + FD_aux (FD_function (non_rec, no_tannot, [mk_funcl forwards_id arg_pat forwards_match]), (l, empty_uannot)) + in + let backwards_fun = + FD_aux (FD_function (non_rec, no_tannot, [mk_funcl backwards_id arg_pat backwards_match]), (l, empty_uannot)) + in + let forwards_matches_fun = + FD_aux + ( FD_function (non_rec, no_tannot, [mk_funcl forwards_matches_id arg_pat forwards_matches_match]), + (l, empty_uannot) + ) + in + let backwards_matches_fun = + FD_aux + ( FD_function (non_rec, no_tannot, [mk_funcl backwards_matches_id arg_pat backwards_matches_match]), + (l, empty_uannot) + ) + in + + typ_debug + ( lazy + (Printf.sprintf "forwards for mapping %s: %s\n%!" (string_of_id id) + (Pretty_print_sail.doc_fundef forwards_fun |> Pretty_print_sail.to_string) + ) + ); + typ_debug + ( lazy + (Printf.sprintf "backwards for mapping %s: %s\n%!" (string_of_id id) + (Pretty_print_sail.doc_fundef backwards_fun |> Pretty_print_sail.to_string) + ) + ); + typ_debug + ( lazy + (Printf.sprintf "forwards matches for mapping %s: %s\n%!" (string_of_id id) + (Pretty_print_sail.doc_fundef forwards_matches_fun |> Pretty_print_sail.to_string) + ) + ); + typ_debug + ( lazy + (Printf.sprintf "backwards matches for mapping %s: %s\n%!" (string_of_id id) + (Pretty_print_sail.doc_fundef backwards_matches_fun |> Pretty_print_sail.to_string) + ) + ); let forwards_fun, _ = Type_check.check_fundef env def_annot forwards_fun in let backwards_fun, _ = Type_check.check_fundef env def_annot backwards_fun in let forwards_matches_fun, _ = Type_check.check_fundef env def_annot forwards_matches_fun in let backwards_matches_fun, _ = Type_check.check_fundef env def_annot backwards_matches_fun in let prefix_id = mk_id (string_of_id id ^ "_matches_prefix") in - let prefix_wildcard = mk_pexp (Pat_exp (mk_pat P_wild, mk_exp (E_app (mk_id "None", [mk_exp (E_lit (mk_lit L_unit))])))) in + let prefix_wildcard = + mk_pexp (Pat_exp (mk_pat P_wild, mk_exp (E_app (mk_id "None", [mk_exp (E_lit (mk_lit L_unit))])))) + in let string_defs = - begin if subtype_check env typ1 string_typ && subtype_check env string_typ typ1 then begin - effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; - let forwards_prefix_match = mk_exp (E_match (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realize_prefix_mapcl true prefix_id) mapcls) |> List.flatten) @ [prefix_wildcard])) in - let forwards_prefix_fun = (FD_aux (FD_function (non_rec, no_tannot, [mk_funcl prefix_id arg_pat forwards_prefix_match]), (l, empty_uannot))) in - typ_debug (lazy (Printf.sprintf "forwards prefix matches for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef forwards_prefix_fun |> Pretty_print_sail.to_string))); - let forwards_prefix_fun, _ = Type_check.check_fundef env def_annot forwards_prefix_fun in - forwards_prefix_fun - end else - if subtype_check env typ2 string_typ && subtype_check env string_typ typ2 then begin - effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; - let backwards_prefix_match = mk_exp (E_match (arg_exp, ((List.map (fun mapcl -> strip_mapcl mapcl |> realize_prefix_mapcl false prefix_id) mapcls) |> List.flatten) @ [prefix_wildcard])) in - let backwards_prefix_fun = (FD_aux (FD_function (non_rec, no_tannot, [mk_funcl prefix_id arg_pat backwards_prefix_match]), (l, empty_uannot))) in - typ_debug (lazy (Printf.sprintf "backwards prefix matches for mapping %s: %s\n%!" (string_of_id id) (Pretty_print_sail.doc_fundef backwards_prefix_fun |> Pretty_print_sail.to_string))); - let backwards_prefix_fun, _ = Type_check.check_fundef env def_annot backwards_prefix_fun in - backwards_prefix_fun - end else - [] + begin + if subtype_check env typ1 string_typ && subtype_check env string_typ typ1 then begin + effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; + let forwards_prefix_match = + mk_exp + (E_match + ( arg_exp, + (List.map (fun mapcl -> strip_mapcl mapcl |> realize_prefix_mapcl true prefix_id) mapcls + |> List.flatten + ) + @ [prefix_wildcard] + ) + ) + in + let forwards_prefix_fun = + FD_aux + (FD_function (non_rec, no_tannot, [mk_funcl prefix_id arg_pat forwards_prefix_match]), (l, empty_uannot)) + in + typ_debug + ( lazy + (Printf.sprintf "forwards prefix matches for mapping %s: %s\n%!" (string_of_id id) + (Pretty_print_sail.doc_fundef forwards_prefix_fun |> Pretty_print_sail.to_string) + ) + ); + let forwards_prefix_fun, _ = Type_check.check_fundef env def_annot forwards_prefix_fun in + forwards_prefix_fun + end + else if subtype_check env typ2 string_typ && subtype_check env string_typ typ2 then begin + effect_info := Effects.copy_mapping_to_function id !effect_info prefix_id; + let backwards_prefix_match = + mk_exp + (E_match + ( arg_exp, + (List.map (fun mapcl -> strip_mapcl mapcl |> realize_prefix_mapcl false prefix_id) mapcls + |> List.flatten + ) + @ [prefix_wildcard] + ) + ) + in + let backwards_prefix_fun = + FD_aux + (FD_function (non_rec, no_tannot, [mk_funcl prefix_id arg_pat backwards_prefix_match]), (l, empty_uannot)) + in + typ_debug + ( lazy + (Printf.sprintf "backwards prefix matches for mapping %s: %s\n%!" (string_of_id id) + (Pretty_print_sail.doc_fundef backwards_prefix_fun |> Pretty_print_sail.to_string) + ) + ); + let backwards_prefix_fun, _ = Type_check.check_fundef env def_annot backwards_prefix_fun in + backwards_prefix_fun + end + else [] end in let has_def id = IdSet.mem id (ids_of_ast ast) in @@ -4005,7 +4210,7 @@ let rewrite_ast_realize_mappings effect_info env ast = @ (if has_def backwards_id then [] else backwards_fun) @ (if has_def forwards_matches_id then [] else forwards_matches_fun) @ (if has_def backwards_matches_id then [] else backwards_matches_fun) - @ (if has_def prefix_id then [] else string_defs) + @ if has_def prefix_id then [] else string_defs in let rewrite_def def = match def with @@ -4014,7 +4219,7 @@ let rewrite_ast_realize_mappings effect_info env ast = | d -> [d] in let ast = { ast with defs = List.map rewrite_def ast.defs |> List.flatten } in - ast, !effect_info, env + (ast, !effect_info, env) (* Rewrite to make all pattern matches in Coq output exhaustive. Assumes that guards, vector patterns, etc have been rewritten already, @@ -4028,352 +4233,331 @@ let rewrite_ast_realize_mappings effect_info env ast = Note: if this naive implementation turns out to be too slow or buggy, we could look at implementing Maranget JFP 17(3), 2007. - *) +*) let opt_coq_warn_nonexhaustive = ref false - -module MakeExhaustive = -struct - -type rlit = - | RL_unit - | RL_true - | RL_false - | RL_inf - -let string_of_rlit = function - | RL_unit -> "()" - | RL_true -> "true" - | RL_false -> "false" - | RL_inf -> "..." - -let rlit_of_lit (L_aux (l,_)) = - match l with - | L_unit -> RL_unit - | L_zero -> RL_inf - | L_one -> RL_inf - | L_true -> RL_true - | L_false -> RL_false - | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf - | L_undef -> assert false - -let inv_rlit_of_lit (L_aux (l,_)) = - match l with - | L_unit -> [] - | L_zero -> [RL_inf] - | L_one -> [RL_inf] - | L_true -> [RL_false] - | L_false -> [RL_true] - | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf] - | L_undef -> assert false - -type residual_pattern = - | RP_any - | RP_lit of rlit - | RP_enum of id - | RP_app of id * residual_pattern list - | RP_tuple of residual_pattern list - | RP_nil - | RP_cons of residual_pattern * residual_pattern - -let rec string_of_rp = function - | RP_any -> "_" - | RP_lit rlit -> string_of_rlit rlit - | RP_enum id -> string_of_id id - | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")" - | RP_tuple rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")" - | RP_nil -> "[| |]" - | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2 - -type ctx = { - env : Env.t; - enum_to_rest: (residual_pattern list) Bindings.t; - constructor_to_rest: (residual_pattern list) Bindings.t -} - -let make_enum_mappings ids m = - IdSet.fold (fun id m -> - Bindings.add id - (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m) - ids - m - -let make_cstr_mappings env ids m = - let ids = IdSet.elements ids in - let constructors = List.map - (fun id -> - let _,ty = Env.get_val_spec id env in - let args = match ty with - | Typ_aux (Typ_fn (tys,_),_) -> List.map (fun _ -> RP_any) tys - | _ -> [RP_any] - in RP_app (id,args)) ids in - let rec aux ids acc l = - match ids, l with - | [], [] -> m - | id::ids, rp::t -> - let m = aux ids (acc@[rp]) t in - Bindings.add id (acc@t) m - | _ -> assert false - in aux ids [] constructors - -let ctx_from_env env = - { env = env; - enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m) - (Env.get_enums env) Bindings.empty; - constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m) - (Bindings.map (fun (_, tus) -> IdSet.of_list (List.map type_union_id tus)) (Env.get_variants env)) Bindings.empty + +module MakeExhaustive = struct + type rlit = RL_unit | RL_true | RL_false | RL_inf + + let string_of_rlit = function RL_unit -> "()" | RL_true -> "true" | RL_false -> "false" | RL_inf -> "..." + + let rlit_of_lit (L_aux (l, _)) = + match l with + | L_unit -> RL_unit + | L_zero -> RL_inf + | L_one -> RL_inf + | L_true -> RL_true + | L_false -> RL_false + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf + | L_undef -> assert false + + let inv_rlit_of_lit (L_aux (l, _)) = + match l with + | L_unit -> [] + | L_zero -> [RL_inf] + | L_one -> [RL_inf] + | L_true -> [RL_false] + | L_false -> [RL_true] + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf] + | L_undef -> assert false + + type residual_pattern = + | RP_any + | RP_lit of rlit + | RP_enum of id + | RP_app of id * residual_pattern list + | RP_tuple of residual_pattern list + | RP_nil + | RP_cons of residual_pattern * residual_pattern + + let rec string_of_rp = function + | RP_any -> "_" + | RP_lit rlit -> string_of_rlit rlit + | RP_enum id -> string_of_id id + | RP_app (f, args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")" + | RP_tuple rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")" + | RP_nil -> "[| |]" + | RP_cons (rp1, rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2 + + type ctx = { + env : Env.t; + enum_to_rest : residual_pattern list Bindings.t; + constructor_to_rest : residual_pattern list Bindings.t; } -let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat = - let subpats rm_pats res_pats = - let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in - let rec aux acc fixed residual = - match fixed, residual with - | [], [] -> [] - | (fh::ft), (rh::rt) -> - let rt' = aux (acc@[fh]) ft rt in - let newr = List.map (fun x -> acc @ (x::ft)) rh in - newr @ rt' - | _,_ -> assert false (* impossible because we managed map2 above *) - in aux [] res_pats res_pats' - in - let inconsistent () = - raise (Reporting.err_unreachable (fst ann) __POS__ - ("Inconsistency during exhaustiveness analysis with " ^ - string_of_rp res_pat)) - in - (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in - let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in - let _ = printprefix := " " ^ !printprefix in*) - let rp' = - match rm_pat with - | P_wild -> [] - | P_id id when (match Env.lookup_id id ctx.env with Unbound _ | Local _ -> true | _ -> false) -> [] - | P_lit lit -> - (match res_pat with - | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit) - | RP_lit RL_inf -> [res_pat] - | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat] - | _ -> inconsistent ()) - | P_as (p,_) - | P_typ (_,p) - | P_var (p,_) - -> remove_clause_from_pattern ctx p res_pat - | P_id id -> - (match Env.lookup_id id ctx.env with - | Enum enum -> - (match res_pat with - | RP_any -> Bindings.find id ctx.enum_to_rest - | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat] - | _ -> inconsistent ()) - | _ -> assert false) - | P_tuple rm_pats -> - let previous_res_pats = - match res_pat with - | RP_tuple res_pats -> res_pats - | RP_any -> List.map (fun _ -> RP_any) rm_pats - | _ -> inconsistent () - in - let res_pats' = subpats rm_pats previous_res_pats in - List.map (fun rps -> RP_tuple rps) res_pats' - | P_app (id,args) -> - (match res_pat with - | RP_app (id',residual_args) -> - if Id.compare id id' == 0 then - let res_pats' = - (* Constructors that were specified without a return type might get - an extra tuple in their type; expand that here if necessary. - TODO: this should go away if we enforce proper arities. *) - match args, residual_args with - | [], [RP_any] - | _::_::_, [RP_any] - -> subpats args (List.map (fun _ -> RP_any) args) - | _,_ -> - subpats args residual_args in - List.map (fun rps -> RP_app (id,rps)) res_pats' - else [res_pat] - | RP_any -> - let res_args = subpats args (List.map (fun _ -> RP_any) args) in - (List.map (fun l -> (RP_app (id,l))) res_args) @ - (Bindings.find id ctx.constructor_to_rest) - | _ -> inconsistent () - ) - | P_list ps -> - (match ps with - | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat - | [] -> - match res_pat with - | RP_any -> [RP_cons (RP_any,RP_any)] - | RP_cons _ -> [res_pat] - | RP_nil -> [] - | _ -> inconsistent ()) - | P_cons (p1,p2) -> begin - let rp',rps = - match res_pat with - | RP_cons (rp1,rp2) -> [], Some [rp1;rp2] - | RP_any -> [RP_nil], Some [RP_any;RP_any] - | RP_nil -> [RP_nil], None - | _ -> inconsistent () - in - match rps with - | None -> rp' - | Some rps -> - let res_pats = subpats [p1;p2] rps in - rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats - end - | P_or _ -> - raise (Reporting.err_unreachable (fst ann) __POS__ "Or pattern not supported") - | P_not _ -> - raise (Reporting.err_unreachable (fst ann) __POS__ "Negated pattern not supported") - | P_vector _ - | P_vector_concat _ - | P_vector_subrange _ - | P_string_append _ -> - raise (Reporting.err_unreachable (fst ann) __POS__ - "Found pattern that should have been rewritten away in earlier stage") - - (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2) - in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*) - in rp' - -let process_pexp env = - let ctx = ctx_from_env env in - fun rps patexp -> - (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in - let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*) - match patexp with - | Pat_aux (Pat_exp (p,_),_) -> - List.concat (List.map (remove_clause_from_pattern ctx p) rps) - | Pat_aux (Pat_when _,(l,_)) -> - raise (Reporting.err_unreachable l __POS__ - "Guarded pattern should have been rewritten away") - -(* We do some minimal redundancy checking to remove bogus wildcard patterns here *) -let check_cases process is_wild loc_of cases = - let rec aux rps acc = function - | [] -> acc, rps - | [p] when is_wild p && match rps with [] -> true | _ -> false -> - let () = Reporting.print_err - (loc_of p) "Match checking" "Redundant wildcard clause" in - acc, [] - | h::t -> aux (process rps h) (h::acc) t - in - let cases, rps = aux [RP_any] [] cases in - List.rev cases, rps - -let not_enum env id = - match Env.lookup_id id env with - | Enum _ -> false - | _ -> true - -let pexp_is_wild = function - | (Pat_aux (Pat_exp (P_aux (P_wild,_),_),_)) -> true - | (Pat_aux (Pat_exp (P_aux (P_id id,ann),_),_)) - when not_enum (env_of_annot ann) id -> true - | _ -> false + let make_enum_mappings ids m = + IdSet.fold + (fun id m -> Bindings.add id (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m) + ids m -let pexp_loc = function - | (Pat_aux (Pat_exp (P_aux (_,(l,_)),_),_)) -> l - | (Pat_aux (Pat_when (P_aux (_,(l,_)),_,_),_)) -> l - -let funcl_is_wild = function - | (FCL_aux (FCL_funcl (_,pexp),_)) -> pexp_is_wild pexp - -let funcl_loc (FCL_aux (_, (def_annot, _))) = def_annot.loc - -let rewrite_case (e,ann) = - match e with - | E_match (e1,cases) - | E_try (e1,cases) -> - begin - let env = env_of_annot ann in - let cases, rps = check_cases (process_pexp env) pexp_is_wild pexp_loc cases in - let rebuild cases = match e with - | E_match _ -> E_match (e1,cases) - | E_try _ -> E_try (e1,cases) - | _ -> assert false - in - match rps with - | [] -> E_aux (rebuild cases,ann) - | (example::_) -> + let make_cstr_mappings env ids m = + let ids = IdSet.elements ids in + let constructors = + List.map + (fun id -> + let _, ty = Env.get_val_spec id env in + let args = match ty with Typ_aux (Typ_fn (tys, _), _) -> List.map (fun _ -> RP_any) tys | _ -> [RP_any] in + RP_app (id, args) + ) + ids + in + let rec aux ids acc l = + match (ids, l) with + | [], [] -> m + | id :: ids, rp :: t -> + let m = aux ids (acc @ [rp]) t in + Bindings.add id (acc @ t) m + | _ -> assert false + in + aux ids [] constructors + + let ctx_from_env env = + { + env; + enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m) (Env.get_enums env) Bindings.empty; + constructor_to_rest = + Bindings.fold + (fun _ ids m -> make_cstr_mappings env ids m) + (Bindings.map (fun (_, tus) -> IdSet.of_list (List.map type_union_id tus)) (Env.get_variants env)) + Bindings.empty; + } - let _ = - if !opt_coq_warn_nonexhaustive - then Reporting.print_err - (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in + let rec remove_clause_from_pattern ctx (P_aux (rm_pat, ann)) res_pat = + let subpats rm_pats res_pats = + let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in + let rec aux acc fixed residual = + match (fixed, residual) with + | [], [] -> [] + | fh :: ft, rh :: rt -> + let rt' = aux (acc @ [fh]) ft rt in + let newr = List.map (fun x -> acc @ (x :: ft)) rh in + newr @ rt' + | _, _ -> assert false (* impossible because we managed map2 above *) + in + aux [] res_pats res_pats' + in + let inconsistent () = + raise + (Reporting.err_unreachable (fst ann) __POS__ + ("Inconsistency during exhaustiveness analysis with " ^ string_of_rp res_pat) + ) + in + (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in + let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in + let _ = printprefix := " " ^ !printprefix in*) + let rp' = + match rm_pat with + | P_wild -> [] + | P_id id when match Env.lookup_id id ctx.env with Unbound _ | Local _ -> true | _ -> false -> [] + | P_lit lit -> ( + match res_pat with + | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit) + | RP_lit RL_inf -> [res_pat] + | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat] + | _ -> inconsistent () + ) + | P_as (p, _) | P_typ (_, p) | P_var (p, _) -> remove_clause_from_pattern ctx p res_pat + | P_id id -> ( + match Env.lookup_id id ctx.env with + | Enum enum -> ( + match res_pat with + | RP_any -> Bindings.find id ctx.enum_to_rest + | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat] + | _ -> inconsistent () + ) + | _ -> assert false + ) + | P_tuple rm_pats -> + let previous_res_pats = + match res_pat with + | RP_tuple res_pats -> res_pats + | RP_any -> List.map (fun _ -> RP_any) rm_pats + | _ -> inconsistent () + in + let res_pats' = subpats rm_pats previous_res_pats in + List.map (fun rps -> RP_tuple rps) res_pats' + | P_app (id, args) -> ( + match res_pat with + | RP_app (id', residual_args) -> + if Id.compare id id' == 0 then ( + let res_pats' = + (* Constructors that were specified without a return type might get + an extra tuple in their type; expand that here if necessary. + TODO: this should go away if we enforce proper arities. *) + match (args, residual_args) with + | [], [RP_any] | _ :: _ :: _, [RP_any] -> subpats args (List.map (fun _ -> RP_any) args) + | _, _ -> subpats args residual_args + in + List.map (fun rps -> RP_app (id, rps)) res_pats' + ) + else [res_pat] + | RP_any -> + let res_args = subpats args (List.map (fun _ -> RP_any) args) in + List.map (fun l -> RP_app (id, l)) res_args @ Bindings.find id ctx.constructor_to_rest + | _ -> inconsistent () + ) + | P_list ps -> ( + match ps with + | p1 :: ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1, P_aux (P_list ptl, ann)), ann)) res_pat + | [] -> ( + match res_pat with + | RP_any -> [RP_cons (RP_any, RP_any)] + | RP_cons _ -> [res_pat] + | RP_nil -> [] + | _ -> inconsistent () + ) + ) + | P_cons (p1, p2) -> begin + let rp', rps = + match res_pat with + | RP_cons (rp1, rp2) -> ([], Some [rp1; rp2]) + | RP_any -> ([RP_nil], Some [RP_any; RP_any]) + | RP_nil -> ([RP_nil], None) + | _ -> inconsistent () + in + match rps with + | None -> rp' + | Some rps -> + let res_pats = subpats [p1; p2] rps in + rp' @ List.map (function [rp1; rp2] -> RP_cons (rp1, rp2) | _ -> assert false) res_pats + end + | P_or _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Or pattern not supported") + | P_not _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Negated pattern not supported") + | P_vector _ | P_vector_concat _ | P_vector_subrange _ | P_string_append _ -> + raise + (Reporting.err_unreachable (fst ann) __POS__ + "Found pattern that should have been rewritten away in earlier stage" + ) + (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2) + in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*) + in - let l = Parse_ast.Generated Parse_ast.Unknown in - let p = P_aux (P_wild, (l, empty_tannot)) in - let ann' = mk_tannot env (typ_of_annot ann) in - (* TODO: use an expression that specifically indicates a failed pattern match *) - let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,empty_tannot))),(l,ann')) in - E_aux (rebuild (cases@[Pat_aux (Pat_exp (p,b),(l,empty_tannot))]),ann) - end - | E_let (LB_aux (LB_val (pat,e1),lb_ann),e2) -> - begin - let env = env_of_annot ann in - let ctx = ctx_from_env env in - let rps = remove_clause_from_pattern ctx pat RP_any in - match rps with - | [] -> E_aux (e,ann) - | (example::_) -> + rp' + + let process_pexp env = + let ctx = ctx_from_env env in + fun rps patexp -> + (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in + let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*) + match patexp with + | Pat_aux (Pat_exp (p, _), _) -> List.concat (List.map (remove_clause_from_pattern ctx p) rps) + | Pat_aux (Pat_when _, (l, _)) -> + raise (Reporting.err_unreachable l __POS__ "Guarded pattern should have been rewritten away") + + (* We do some minimal redundancy checking to remove bogus wildcard patterns here *) + let check_cases process is_wild loc_of cases = + let rec aux rps acc = function + | [] -> (acc, rps) + | [p] when is_wild p && match rps with [] -> true | _ -> false -> + let () = Reporting.print_err (loc_of p) "Match checking" "Redundant wildcard clause" in + (acc, []) + | h :: t -> aux (process rps h) (h :: acc) t + in + let cases, rps = aux [RP_any] [] cases in + (List.rev cases, rps) + + let not_enum env id = match Env.lookup_id id env with Enum _ -> false | _ -> true + + let pexp_is_wild = function + | Pat_aux (Pat_exp (P_aux (P_wild, _), _), _) -> true + | Pat_aux (Pat_exp (P_aux (P_id id, ann), _), _) when not_enum (env_of_annot ann) id -> true + | _ -> false + + let pexp_loc = function + | Pat_aux (Pat_exp (P_aux (_, (l, _)), _), _) -> l + | Pat_aux (Pat_when (P_aux (_, (l, _)), _, _), _) -> l + + let funcl_is_wild = function FCL_aux (FCL_funcl (_, pexp), _) -> pexp_is_wild pexp + + let funcl_loc (FCL_aux (_, (def_annot, _))) = def_annot.loc + + let rewrite_case (e, ann) = + match e with + | E_match (e1, cases) | E_try (e1, cases) -> begin + let env = env_of_annot ann in + let cases, rps = check_cases (process_pexp env) pexp_is_wild pexp_loc cases in + let rebuild cases = + match e with E_match _ -> E_match (e1, cases) | E_try _ -> E_try (e1, cases) | _ -> assert false + in + match rps with + | [] -> E_aux (rebuild cases, ann) + | example :: _ -> + let _ = + if !opt_coq_warn_nonexhaustive then + Reporting.print_err (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) + in + + let l = Parse_ast.Generated Parse_ast.Unknown in + let p = P_aux (P_wild, (l, empty_tannot)) in + let ann' = mk_tannot env (typ_of_annot ann) in + (* TODO: use an expression that specifically indicates a failed pattern match *) + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in + E_aux (rebuild (cases @ [Pat_aux (Pat_exp (p, b), (l, empty_tannot))]), ann) + end + | E_let (LB_aux (LB_val (pat, e1), lb_ann), e2) -> begin + let env = env_of_annot ann in + let ctx = ctx_from_env env in + let rps = remove_clause_from_pattern ctx pat RP_any in + match rps with + | [] -> E_aux (e, ann) + | example :: _ -> + let _ = + if !opt_coq_warn_nonexhaustive then + Reporting.print_err (fst ann) "Non-exhaustive let" ("Example: " ^ string_of_rp example) + in + let l = Parse_ast.Generated Parse_ast.Unknown in + let p = P_aux (P_wild, (l, empty_tannot)) in + let ann' = mk_tannot env (typ_of_annot ann) in + (* TODO: use an expression that specifically indicates a failed pattern match *) + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in + E_aux (E_match (e1, [Pat_aux (Pat_exp (pat, e2), ann); Pat_aux (Pat_exp (p, b), (l, empty_tannot))]), ann) + end + | _ -> E_aux (e, ann) + + let rewrite_fun rewriters (FD_aux (FD_function (r, t, fcls), f_ann)) = + let id, fcl_ann = + match fcls with + | FCL_aux (FCL_funcl (id, _), ann) :: _ -> (id, ann) + | [] -> raise (Reporting.err_unreachable (fst f_ann) __POS__ "Empty function") + in + let env = env_of_tannot (snd fcl_ann) in + let process_funcl rps (FCL_aux (FCL_funcl (_, pexp), _)) = process_pexp env rps pexp in + let fcls, rps = check_cases process_funcl funcl_is_wild funcl_loc fcls in + let fcls' = + List.map + (function FCL_aux (FCL_funcl (id, pexp), ann) -> FCL_aux (FCL_funcl (id, rewrite_pexp rewriters pexp), ann)) + fcls + in + match rps with + | [] -> FD_aux (FD_function (r, t, fcls'), f_ann) + | example :: _ -> let _ = - if !opt_coq_warn_nonexhaustive - then Reporting.print_err - (fst ann) "Non-exhaustive let" ("Example: " ^ string_of_rp example) in + if !opt_coq_warn_nonexhaustive then + Reporting.print_err (fst f_ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) + in + let l = Parse_ast.Generated Parse_ast.Unknown in let p = P_aux (P_wild, (l, empty_tannot)) in - let ann' = mk_tannot env (typ_of_annot ann) in + let ann' = mk_tannot env (typ_of_tannot (snd fcl_ann)) in (* TODO: use an expression that specifically indicates a failed pattern match *) - let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,empty_tannot))),(l,ann')) in - E_aux (E_match (e1,[Pat_aux (Pat_exp(pat,e2),ann); - Pat_aux (Pat_exp (p,b),(l,empty_tannot))]),ann) - end - | _ -> E_aux (e,ann) - -let rewrite_fun rewriters (FD_aux (FD_function (r,t,fcls),f_ann)) = - let id,fcl_ann = - match fcls with - | FCL_aux (FCL_funcl (id,_),ann) :: _ -> id, ann - | [] -> raise (Reporting.err_unreachable (fst f_ann) __POS__ - "Empty function") - in - let env = env_of_tannot (snd fcl_ann) in - let process_funcl rps (FCL_aux (FCL_funcl (_,pexp),_)) = process_pexp env rps pexp in - let fcls, rps = check_cases process_funcl funcl_is_wild funcl_loc fcls in - let fcls' = List.map (function FCL_aux (FCL_funcl (id,pexp),ann) -> - FCL_aux (FCL_funcl (id, rewrite_pexp rewriters pexp),ann)) - fcls in - match rps with - | [] -> FD_aux (FD_function (r,t,fcls'),f_ann) - | (example::_) -> - let _ = - if !opt_coq_warn_nonexhaustive - then Reporting.print_err - (fst f_ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in - - let l = Parse_ast.Generated Parse_ast.Unknown in - let p = P_aux (P_wild, (l, empty_tannot)) in - let ann' = mk_tannot env (typ_of_tannot (snd fcl_ann)) in - (* TODO: use an expression that specifically indicates a failed pattern match *) - let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,empty_tannot))),(l,ann')) in - let default = FCL_aux (FCL_funcl (id,Pat_aux (Pat_exp (p,b),(l,empty_tannot))),fcl_ann) in - - FD_aux (FD_function (r,t,fcls'@[default]),f_ann) - -let rewrite env = - let alg = { id_exp_alg with e_aux = rewrite_case } in - rewrite_ast_base - { rewrite_exp = (fun _ -> fold_exp alg) - ; rewrite_pat = rewrite_pat - ; rewrite_let = rewrite_let - ; rewrite_lexp = rewrite_lexp - ; rewrite_fun = rewrite_fun - ; rewrite_def = rewrite_def - ; rewrite_ast = rewrite_ast_base - } + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in + let default = FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (p, b), (l, empty_tannot))), fcl_ann) in + FD_aux (FD_function (r, t, fcls' @ [default]), f_ann) + let rewrite env = + let alg = { id_exp_alg with e_aux = rewrite_case } in + rewrite_ast_base + { + rewrite_exp = (fun _ -> fold_exp alg); + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base; + } end (* Splitting a function (e.g., an execute function on an AST) can produce @@ -4381,57 +4565,55 @@ end see if the flag can be turned off. Doesn't handle mutual recursion for now. *) let minimise_recursive_functions env ast = - let rewrite_function (FD_aux (FD_function (recopt,topt,funcls),ann) as fd) = + let rewrite_function (FD_aux (FD_function (recopt, topt, funcls), ann) as fd) = match recopt with | Rec_aux (Rec_nonrec, _) -> fd | Rec_aux ((Rec_rec | Rec_measure _), l) -> - if List.exists is_funcl_rec funcls - then fd - else FD_aux (FD_function (Rec_aux (Rec_nonrec, Generated l),topt,funcls),ann) + if List.exists is_funcl_rec funcls then fd + else FD_aux (FD_function (Rec_aux (Rec_nonrec, Generated l), topt, funcls), ann) in let rewrite_def = function | DEF_aux (DEF_fundef fd, def_annot) -> DEF_aux (DEF_fundef (rewrite_function fd), def_annot) | d -> d - in { ast with defs = List.map rewrite_def ast.defs } + in + { ast with defs = List.map rewrite_def ast.defs } (* Move recursive function termination measures into the function definitions. *) let move_termination_measures env ast = let scan_for id defs = let rec aux = function | [] -> None - | (DEF_aux (DEF_measure (id',pat,exp),_))::t -> - if Id.compare id id' == 0 then Some (pat,exp) else aux t - | (DEF_aux (DEF_fundef (FD_aux (FD_function (_,_,FCL_aux (FCL_funcl (id',_),_)::_),_)),_))::_ - | (DEF_aux (DEF_val (VS_aux (VS_val_spec (_,id',_,_),_)),_)::_) - when Id.compare id id' == 0 -> None - | _::t -> aux t - in aux defs + | DEF_aux (DEF_measure (id', pat, exp), _) :: t -> if Id.compare id id' == 0 then Some (pat, exp) else aux t + | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, FCL_aux (FCL_funcl (id', _), _) :: _), _)), _) :: _ + | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id', _, _), _)), _) :: _ + when Id.compare id id' == 0 -> + None + | _ :: t -> aux t + in + aux defs in let rec aux acc = function | [] -> List.rev acc - | (DEF_aux (DEF_fundef (FD_aux (FD_function (r,ty,fs),(l,f_ann))),def_annot) as d)::t -> begin - let id = match fs with - | [] -> assert false (* TODO *) - | (FCL_aux (FCL_funcl (id,_),_))::_ -> id - in - match scan_for id t with - | None -> aux (d::acc) t - | Some (pat,exp) -> - let r = Rec_aux (Rec_measure (pat,exp), Generated l) in - aux (DEF_aux (DEF_fundef (FD_aux (FD_function (r,ty,fs),(l,f_ann))),def_annot)::acc) t + | (DEF_aux (DEF_fundef (FD_aux (FD_function (r, ty, fs), (l, f_ann))), def_annot) as d) :: t -> begin + let id = match fs with [] -> assert false (* TODO *) | FCL_aux (FCL_funcl (id, _), _) :: _ -> id in + match scan_for id t with + | None -> aux (d :: acc) t + | Some (pat, exp) -> + let r = Rec_aux (Rec_measure (pat, exp), Generated l) in + aux (DEF_aux (DEF_fundef (FD_aux (FD_function (r, ty, fs), (l, f_ann))), def_annot) :: acc) t end - | (DEF_aux (DEF_measure _,_))::t -> aux acc t - | h::t -> aux (h::acc) t - in { ast with defs = aux [] ast.defs } + | DEF_aux (DEF_measure _, _) :: t -> aux acc t + | h :: t -> aux (h :: acc) t + in + { ast with defs = aux [] ast.defs } (* Make recursive functions with a measure use the measure as an explicit recursion limit, enforced by an assertion. *) let rewrite_explicit_measure effect_info env ast = let effect_info = ref effect_info in let scan_function measures = function - | FD_aux (FD_function (Rec_aux (Rec_measure (mpat,mexp),rl),topt, - FCL_aux (FCL_funcl (id,_),_)::_),ann) -> - Bindings.add id (mpat,mexp) measures + | FD_aux (FD_function (Rec_aux (Rec_measure (mpat, mexp), rl), topt, FCL_aux (FCL_funcl (id, _), _) :: _), ann) -> + Bindings.add id (mpat, mexp) measures | _ -> measures in let scan_def measures = function @@ -4441,152 +4623,188 @@ let rewrite_explicit_measure effect_info env ast = in let measures = List.fold_left scan_def Bindings.empty ast.defs in (* NB: the Coq backend relies on recognising the #rec# prefix *) - let rec_id = function - | Id_aux (Id id,l) - | Id_aux (Operator id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l) - in + let rec_id = function Id_aux (Id id, l) | Id_aux (Operator id, l) -> Id_aux (Id ("#rec#" ^ id), Generated l) in let limit = mk_id "#reclimit" in (* Add helper function with extra argument to spec *) - let rewrite_spec (VS_aux (VS_val_spec (typsch,id,extern,flag),ann) as vs) = + let rewrite_spec (VS_aux (VS_val_spec (typsch, id, extern, flag), ann) as vs) = match Bindings.find id measures with | _ -> begin - match typsch with - | TypSchm_aux (TypSchm_ts (tq, - Typ_aux (Typ_fn (args,res),typl)),tsl) -> - [VS_aux (VS_val_spec ( - TypSchm_aux (TypSchm_ts (tq, - Typ_aux (Typ_fn (args@[int_typ],res),typl)),tsl) - ,rec_id id,extern,flag),ann); - VS_aux (VS_val_spec ( - TypSchm_aux (TypSchm_ts (tq, - Typ_aux (Typ_fn (args,res),typl)),tsl) - ,id,extern,flag),ann)] - | _ -> [vs] (* TODO warn *) + match typsch with + | TypSchm_aux (TypSchm_ts (tq, Typ_aux (Typ_fn (args, res), typl)), tsl) -> + [ + VS_aux + ( VS_val_spec + ( TypSchm_aux (TypSchm_ts (tq, Typ_aux (Typ_fn (args @ [int_typ], res), typl)), tsl), + rec_id id, + extern, + flag + ), + ann + ); + VS_aux + ( VS_val_spec (TypSchm_aux (TypSchm_ts (tq, Typ_aux (Typ_fn (args, res), typl)), tsl), id, extern, flag), + ann + ); + ] + | _ -> [vs] + (* TODO warn *) end | exception Not_found -> [vs] in (* Add extra argument and assertion to each funcl, and rewrite recursive calls *) - let rewrite_funcl recset (FCL_aux (FCL_funcl (id,pexp),fcl_ann)) = + let rewrite_funcl recset (FCL_aux (FCL_funcl (id, pexp), fcl_ann)) = let loc = Parse_ast.Generated (fst fcl_ann).loc in - let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in - let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in - let pat = match pat with - | P_tuple pats -> P_tuple (pats@[extra_pat]) - | p -> P_tuple [P_aux (p,pann);extra_pat] + let P_aux (pat, pann), guard, body, ann = destruct_pexp pexp in + let extra_pat = P_aux (P_id limit, (loc, empty_tannot)) in + let pat = + match pat with P_tuple pats -> P_tuple (pats @ [extra_pat]) | p -> P_tuple [P_aux (p, pann); extra_pat] in let assert_exp = - E_aux (E_assert - (E_aux (E_app (mk_id "gteq_int", - [E_aux (E_id limit,(loc,empty_tannot)); - E_aux (E_lit (L_aux (L_num Big_int.zero,loc)),(loc,empty_tannot))]), - (loc,empty_tannot)), - (E_aux (E_lit (L_aux (L_string "recursion limit reached",loc)),(loc,empty_tannot)))), - (loc,empty_tannot)) + E_aux + ( E_assert + ( E_aux + ( E_app + ( mk_id "gteq_int", + [ + E_aux (E_id limit, (loc, empty_tannot)); + E_aux (E_lit (L_aux (L_num Big_int.zero, loc)), (loc, empty_tannot)); + ] + ), + (loc, empty_tannot) + ), + E_aux (E_lit (L_aux (L_string "recursion limit reached", loc)), (loc, empty_tannot)) + ), + (loc, empty_tannot) + ) in let tick = - E_aux (E_app (mk_id "sub_int", - [E_aux (E_id limit,(loc,empty_tannot)); - E_aux (E_lit (L_aux (L_num (Big_int.of_int 1),loc)),(loc,empty_tannot))]), - (loc,empty_tannot)) + E_aux + ( E_app + ( mk_id "sub_int", + [ + E_aux (E_id limit, (loc, empty_tannot)); + E_aux (E_lit (L_aux (L_num (Big_int.of_int 1), loc)), (loc, empty_tannot)); + ] + ), + (loc, empty_tannot) + ) in let open Rewriter in let body = - fold_exp { id_exp_alg with - e_app = (fun (f,args) -> - if IdSet.mem f recset - then E_app (rec_id f, args@[tick]) - else E_app (f, args)) - } body + fold_exp + { + id_exp_alg with + e_app = (fun (f, args) -> if IdSet.mem f recset then E_app (rec_id f, args @ [tick]) else E_app (f, args)); + } + body in - let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in + let body = E_aux (E_block [assert_exp; body], (loc, empty_tannot)) in let new_id = rec_id id in effect_info := Effects.copy_function_effect id !effect_info new_id; - FCL_aux (FCL_funcl (new_id, construct_pexp (P_aux (pat,pann),guard,body,ann)),fcl_ann) + FCL_aux (FCL_funcl (new_id, construct_pexp (P_aux (pat, pann), guard, body, ann)), fcl_ann) in - let rewrite_function recset (FD_aux (FD_function (r,t,fcls),ann) as fd) = + let rewrite_function recset (FD_aux (FD_function (r, t, fcls), ann) as fd) = let loc = Parse_ast.Generated (fst ann) in match fcls with - | FCL_aux (FCL_funcl (id,_),fcl_ann)::_ -> begin + | FCL_aux (FCL_funcl (id, _), fcl_ann) :: _ -> begin match Bindings.find id measures with - | (measure_pat, measure_exp) -> - let arg_typs = match Env.get_val_spec id (env_of_tannot (snd fcl_ann)) with - | _, Typ_aux (Typ_fn (args,_),_) -> args - | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__ - "Function doesn't have function type") - in - let measure_pats = match arg_typs, measure_pat with - | [_], _ -> [measure_pat] - | _, P_aux (P_tuple ps,_) -> ps - | _, _ -> [measure_pat] - in - let mk_wrap i (P_aux (p,(l,_)) as p_full) = - let id = - match p with - | P_id id - | P_typ (_,(P_aux (P_id id,_))) -> id - | P_lit _ - | P_wild - | P_typ (_,(P_aux (P_wild,_))) -> - mk_id ("_arg" ^ string_of_int i) - | _ -> raise (Reporting.err_todo l ("Measure patterns can only be identifiers or wildcards, not " ^ string_of_pat p_full)) - in - P_aux (P_id id,(loc,empty_tannot)), - E_aux (E_id id,(loc,empty_tannot)) - in - let wpats,wexps = List.split (List.mapi mk_wrap measure_pats) in - let wpat = match wpats with - | [wpat] -> wpat - | _ -> P_aux (P_tuple wpats,(loc,empty_tannot)) - in - let measure_exp = E_aux (E_typ (int_typ, measure_exp),(loc,empty_tannot)) in - let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in - let wrapper = - FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(mk_def_annot loc,empty_tannot)) - in - let new_rec = - Rec_aux (Rec_measure (P_aux (P_tuple (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc) - in - FD_aux (FD_function (new_rec,t,List.map (rewrite_funcl recset) fcls),ann), - [FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,[wrapper]),ann)] - | exception Not_found -> fd,[] + | measure_pat, measure_exp -> + let arg_typs = + match Env.get_val_spec id (env_of_tannot (snd fcl_ann)) with + | _, Typ_aux (Typ_fn (args, _), _) -> args + | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Function doesn't have function type") + in + let measure_pats = + match (arg_typs, measure_pat) with + | [_], _ -> [measure_pat] + | _, P_aux (P_tuple ps, _) -> ps + | _, _ -> [measure_pat] + in + let mk_wrap i (P_aux (p, (l, _)) as p_full) = + let id = + match p with + | P_id id | P_typ (_, P_aux (P_id id, _)) -> id + | P_lit _ | P_wild | P_typ (_, P_aux (P_wild, _)) -> mk_id ("_arg" ^ string_of_int i) + | _ -> + raise + (Reporting.err_todo l + ("Measure patterns can only be identifiers or wildcards, not " ^ string_of_pat p_full) + ) + in + (P_aux (P_id id, (loc, empty_tannot)), E_aux (E_id id, (loc, empty_tannot))) + in + let wpats, wexps = List.split (List.mapi mk_wrap measure_pats) in + let wpat = match wpats with [wpat] -> wpat | _ -> P_aux (P_tuple wpats, (loc, empty_tannot)) in + let measure_exp = E_aux (E_typ (int_typ, measure_exp), (loc, empty_tannot)) in + let wbody = E_aux (E_app (rec_id id, wexps @ [measure_exp]), (loc, empty_tannot)) in + let wrapper = + FCL_aux + (FCL_funcl (id, Pat_aux (Pat_exp (wpat, wbody), (loc, empty_tannot))), (mk_def_annot loc, empty_tannot)) + in + let new_rec = + Rec_aux + ( Rec_measure + ( P_aux + ( P_tuple + (List.map (fun _ -> P_aux (P_wild, (loc, empty_tannot))) measure_pats + @ [P_aux (P_id limit, (loc, empty_tannot))] + ), + (loc, empty_tannot) + ), + E_aux (E_id limit, (loc, empty_tannot)) + ), + loc + ) + in + ( FD_aux (FD_function (new_rec, t, List.map (rewrite_funcl recset) fcls), ann), + [FD_aux (FD_function (Rec_aux (Rec_nonrec, loc), t, [wrapper]), ann)] + ) + | exception Not_found -> (fd, []) end - | _ -> fd,[] + | _ -> (fd, []) in let rewrite_def = function | DEF_aux (DEF_val vs, def_annot) -> List.map (fun vs -> DEF_aux (DEF_val vs, def_annot)) (rewrite_spec vs) | DEF_aux (DEF_fundef fd, def_annot) -> - let fd,extra = rewrite_function (IdSet.singleton (id_of_fundef fd)) fd in - List.map (fun f -> DEF_aux (DEF_fundef f, def_annot)) (fd::extra) - | (DEF_aux (DEF_internal_mutrec fds, def_annot)) as d -> - let recset = ids_of_def d in - let fds,extras = List.split (List.map (rewrite_function recset) fds) in - let extras = List.concat extras in - (DEF_aux (DEF_internal_mutrec fds, def_annot))::(List.map (fun f -> DEF_aux (DEF_fundef f, def_annot)) extras) + let fd, extra = rewrite_function (IdSet.singleton (id_of_fundef fd)) fd in + List.map (fun f -> DEF_aux (DEF_fundef f, def_annot)) (fd :: extra) + | DEF_aux (DEF_internal_mutrec fds, def_annot) as d -> + let recset = ids_of_def d in + let fds, extras = List.split (List.map (rewrite_function recset) fds) in + let extras = List.concat extras in + DEF_aux (DEF_internal_mutrec fds, def_annot) :: List.map (fun f -> DEF_aux (DEF_fundef f, def_annot)) extras | d -> [d] in let defs = List.flatten (List.map rewrite_def ast.defs) in - { ast with defs }, !effect_info, env + ({ ast with defs }, !effect_info, env) (* Add a dummy assert to loops for backends that require loops to be able to fail. Note that the Coq backend will spot the assert and omit it. *) let rewrite_loops_with_escape_effect env defs = - let dummy_ann = Parse_ast.Unknown,empty_tannot in + let dummy_ann = (Parse_ast.Unknown, empty_tannot) in let assert_exp = - E_aux (E_assert (E_aux (E_lit (L_aux (L_true,Unknown)),dummy_ann), - E_aux (E_lit (L_aux (L_string "loop dummy assert",Unknown)),dummy_ann)), - dummy_ann) + E_aux + ( E_assert + ( E_aux (E_lit (L_aux (L_true, Unknown)), dummy_ann), + E_aux (E_lit (L_aux (L_string "loop dummy assert", Unknown)), dummy_ann) + ), + dummy_ann + ) in let rewrite_exp rws exp = - let (E_aux (e,ann) as exp) = Rewriter.rewrite_exp rws exp in + let (E_aux (e, ann) as exp) = Rewriter.rewrite_exp rws exp in match e with - | E_loop (l,(Measure_aux (Measure_some _,_) as m),guard,body) -> - (* TODO EFFECT *) - if (* has_effect (effect_of exp) BE_escape *) false then exp else - let body = match body with - | E_aux (E_block es,ann) -> - E_aux (E_block (assert_exp::es),ann) - | _ -> E_aux (E_block [assert_exp;body],dummy_ann) - in E_aux (E_loop (l,m,guard,body),ann) + | E_loop (l, (Measure_aux (Measure_some _, _) as m), guard, body) -> + (* TODO EFFECT *) + if (* has_effect (effect_of exp) BE_escape *) false then exp + else ( + let body = + match body with + | E_aux (E_block es, ann) -> E_aux (E_block (assert_exp :: es), ann) + | _ -> E_aux (E_block [assert_exp; body], dummy_ann) + in + E_aux (E_loop (l, m, guard, body), ann) + ) | _ -> exp in rewrite_ast_base { rewriters_base with rewrite_exp } defs @@ -4601,22 +4819,27 @@ let remove_duplicate_valspecs env ast = List.fold_left (fun last_externs def -> match def with - | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, externs, _), _)), _) -> - Bindings.add id externs last_externs - | _ -> last_externs) Bindings.empty ast.defs + | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, externs, _), _)), _) -> Bindings.add id externs last_externs + | _ -> last_externs + ) + Bindings.empty ast.defs in - let (_, rev_defs) = + let _, rev_defs = List.fold_left (fun (set, defs) def -> match def with | DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, id, _, cast), l)), def_annot) -> if IdSet.mem id set then (set, defs) - else + else ( let externs = Bindings.find id last_externs in let vs = VS_aux (VS_val_spec (typschm, id, externs, cast), l) in - (IdSet.add id set, (DEF_aux (DEF_val vs, def_annot))::defs) - | _ -> (set, def::defs)) (IdSet.empty, []) ast.defs - in { ast with defs = List.rev rev_defs } + (IdSet.add id set, DEF_aux (DEF_val vs, def_annot) :: defs) + ) + | _ -> (set, def :: defs) + ) + (IdSet.empty, []) ast.defs + in + { ast with defs = List.rev rev_defs } (* Move loop termination measures into loop AST nodes. This is used before type checking so that we avoid the complexity of type checking separate @@ -4627,47 +4850,48 @@ let move_loop_measures ast = (fun m d -> match d with | DEF_aux (DEF_loop_measures (id, measures), _) -> - (* Allow multiple measure definitions, concatenating them *) - Bindings.add id - (match Bindings.find_opt id m with - | None -> measures - | Some m -> m @ measures) - m - | _ -> m) Bindings.empty ast.defs - in - let do_exp exp_rec measures (E_aux (e,ann) as exp) = - match e, measures with - | E_loop (loop, _, e1, e2), (Loop (loop',exp))::t when loop = loop' -> - let t,e1 = exp_rec t e1 in - let t,e2 = exp_rec t e2 in - t,E_aux (E_loop (loop, Measure_aux (Measure_some exp, exp_loc exp), e1, e2),ann) + (* Allow multiple measure definitions, concatenating them *) + Bindings.add id (match Bindings.find_opt id m with None -> measures | Some m -> m @ measures) m + | _ -> m + ) + Bindings.empty ast.defs + in + let do_exp exp_rec measures (E_aux (e, ann) as exp) = + match (e, measures) with + | E_loop (loop, _, e1, e2), Loop (loop', exp) :: t when loop = loop' -> + let t, e1 = exp_rec t e1 in + let t, e2 = exp_rec t e2 in + (t, E_aux (E_loop (loop, Measure_aux (Measure_some exp, exp_loc exp), e1, e2), ann)) | _ -> exp_rec measures exp in - let do_funcl (m,acc) (FCL_aux (FCL_funcl (id, pexp),ann) as fcl) = + let do_funcl (m, acc) (FCL_aux (FCL_funcl (id, pexp), ann) as fcl) = match Bindings.find_opt id m with | Some measures -> - let measures,pexp = foldin_pexp do_exp measures pexp in - Bindings.add id measures m, (FCL_aux (FCL_funcl (id, pexp),ann))::acc - | None -> m, fcl::acc + let measures, pexp = foldin_pexp do_exp measures pexp in + (Bindings.add id measures m, FCL_aux (FCL_funcl (id, pexp), ann) :: acc) + | None -> (m, fcl :: acc) in - let unused,rev_defs = + let unused, rev_defs = List.fold_left - (fun (m,acc) d -> + (fun (m, acc) d -> match d with - | DEF_aux (DEF_loop_measures _, _) -> m, acc - | DEF_aux (DEF_fundef (FD_aux (FD_function (r,t,fcls),ann)),def_annot) -> - let m,rfcls = List.fold_left do_funcl (m,[]) fcls in - m, (DEF_aux (DEF_fundef (FD_aux (FD_function (r,t,List.rev rfcls),ann)), def_annot))::acc - | _ -> m, d::acc) - (loop_measures,[]) ast.defs - in let () = Bindings.iter - (fun id -> function - | [] -> () - | _::_ -> - Reporting.print_err (id_loc id) "Warning" - ("unused loop measure for function " ^ string_of_id id)) - unused - in { ast with defs = List.rev rev_defs } + | DEF_aux (DEF_loop_measures _, _) -> (m, acc) + | DEF_aux (DEF_fundef (FD_aux (FD_function (r, t, fcls), ann)), def_annot) -> + let m, rfcls = List.fold_left do_funcl (m, []) fcls in + (m, DEF_aux (DEF_fundef (FD_aux (FD_function (r, t, List.rev rfcls), ann)), def_annot) :: acc) + | _ -> (m, d :: acc) + ) + (loop_measures, []) ast.defs + in + let () = + Bindings.iter + (fun id -> function + | [] -> () + | _ :: _ -> Reporting.print_err (id_loc id) "Warning" ("unused loop measure for function " ^ string_of_id id) + ) + unused + in + { ast with defs = List.rev rev_defs } let rewrite_toplevel_consts target type_env ast = let istate = Constant_fold.initial_state ast type_env in @@ -4678,25 +4902,25 @@ let rewrite_toplevel_consts target type_env ast = IdSet.fold (fun id -> subst id (Bindings.find id consts)) subst_ids exp in let rewrite_def (revdefs, consts) = function - | DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), a) as lb), def_annot) -> - begin match unaux_pat pat with - | P_id id | P_typ (_, P_aux (P_id id, _)) -> + | DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), a) as lb), def_annot) -> begin + match unaux_pat pat with + | P_id id | P_typ (_, P_aux (P_id id, _)) -> let exp' = Constant_fold.rewrite_exp_once target istate (subst consts exp) in - if Constant_fold.is_constant exp' then + if Constant_fold.is_constant exp' then ( try let exp' = infer_exp (env_of exp') (strip_exp exp') in let pannot = (pat_loc pat, mk_tannot (env_of_pat pat) (typ_of exp')) in let pat' = P_aux (P_typ (typ_of exp', P_aux (P_id id, pannot)), pannot) in let consts' = Bindings.add id exp' consts in (DEF_aux (DEF_let (LB_aux (LB_val (pat', exp'), a)), def_annot) :: revdefs, consts') - with - | _ -> (DEF_aux (DEF_let lb, def_annot) :: revdefs, consts) + with _ -> (DEF_aux (DEF_let lb, def_annot) :: revdefs, consts) + ) else (DEF_aux (DEF_let lb, def_annot) :: revdefs, consts) - | _ -> (DEF_aux (DEF_let lb, def_annot) :: revdefs, consts) - end + | _ -> (DEF_aux (DEF_let lb, def_annot) :: revdefs, consts) + end | def -> (def :: revdefs, consts) in - let (revdefs, _) = List.fold_left rewrite_def ([], Bindings.empty) ast.defs in + let revdefs, _ = List.fold_left rewrite_def ([], Bindings.empty) ast.defs in { ast with defs = List.rev revdefs } (* Hex literals are always a multiple of 4 bits long. If one of a different size is needed, users may truncate @@ -4705,37 +4929,31 @@ let rewrite_toplevel_consts target type_env ast = let rewrite_truncate_hex_literals _type_env defs = let rewrite_aux (e, annot) = match e with - | E_app (Id_aux (Id "truncate", _), [E_aux (E_lit (L_aux (L_hex hex, l_ann)),_); E_aux (E_lit (L_aux (L_num len, _)),_)]) -> - let bin = hex_to_bin hex in - let len = Nat_big_num.to_int len in - let truncation = String.sub bin (String.length bin - len) len in - E_aux (E_lit (L_aux (L_bin truncation, l_ann)), annot) + | E_app + ( Id_aux (Id "truncate", _), + [E_aux (E_lit (L_aux (L_hex hex, l_ann)), _); E_aux (E_lit (L_aux (L_num len, _)), _)] + ) -> + let bin = hex_to_bin hex in + let len = Nat_big_num.to_int len in + let truncation = String.sub bin (String.length bin - len) len in + E_aux (E_lit (L_aux (L_bin truncation, l_ann)), annot) | _ -> E_aux (e, annot) in rewrite_ast_base - { rewriters_base with - rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) } + { rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) } defs let opt_mono_rewrites = ref false let opt_mono_complex_nexps = ref true -let mono_rewrites env defs = - if !opt_mono_rewrites then - Monomorphise.mono_rewrites defs - else defs +let mono_rewrites env defs = if !opt_mono_rewrites then Monomorphise.mono_rewrites defs else defs -let rewrite_toplevel_nexps env defs = - if !opt_mono_complex_nexps then - Monomorphise.rewrite_toplevel_nexps defs - else defs +let rewrite_toplevel_nexps env defs = if !opt_mono_complex_nexps then Monomorphise.rewrite_toplevel_nexps defs else defs let rewrite_complete_record_params env defs = - if !opt_mono_complex_nexps then - Monomorphise.rewrite_complete_record_params env defs - else defs + if !opt_mono_complex_nexps then Monomorphise.rewrite_complete_record_params env defs else defs -let opt_mono_split = ref ([]:((string * int) * string) list) +let opt_mono_split = ref ([] : ((string * int) * string) list) let opt_dmono_analysis = ref 0 let opt_auto_mono = ref false let opt_dall_split_errors = ref false @@ -4743,27 +4961,26 @@ let opt_dmono_continue = ref false let monomorphise target effect_info env defs = let open Monomorphise in - monomorphise - target - effect_info - { auto = !opt_auto_mono; - debug_analysis = !opt_dmono_analysis; - all_split_errors = !opt_dall_split_errors; - continue_anyway = !opt_dmono_continue } - !opt_mono_split - defs, effect_info, env + ( monomorphise target effect_info + { + auto = !opt_auto_mono; + debug_analysis = !opt_dmono_analysis; + all_split_errors = !opt_dall_split_errors; + continue_anyway = !opt_dmono_continue; + } + !opt_mono_split defs, + effect_info, + env + ) let if_mono f effect_info env ast = - match !opt_mono_split, !opt_auto_mono with - | [], false -> ast, effect_info, env - | _, _ -> f effect_info env ast + match (!opt_mono_split, !opt_auto_mono) with [], false -> (ast, effect_info, env) | _, _ -> f effect_info env ast (* Also turn mwords stages on when we're just trying out mono *) let if_mwords f effect_info env ast = if !Monomorphise.opt_mwords then f effect_info env ast else if_mono f effect_info env ast -let if_flag flag f effect_info env ast = - if !flag then f effect_info env ast else (ast, effect_info, env) +let if_flag flag f effect_info env ast = if !flag then f effect_info env ast else (ast, effect_info, env) type rewriter = | Base_rewriter of (Effects.side_effect_info -> Env.t -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t) @@ -4771,9 +4988,14 @@ type rewriter = | String_rewriter of (string -> rewriter) | Literal_rewriter of ((lit -> bool) -> rewriter) -let basic_rewriter f = Base_rewriter (fun effect_info env ast -> f env ast, effect_info, env) -let checking_rewriter f = Base_rewriter (fun effect_info env ast -> let ast, env = f env ast in ast, effect_info, env) - +let basic_rewriter f = Base_rewriter (fun effect_info env ast -> (f env ast, effect_info, env)) +let checking_rewriter f = + Base_rewriter + (fun effect_info env ast -> + let ast, env = f env ast in + (ast, effect_info, env) + ) + type rewriter_arg = | If_mono_arg | If_mwords_arg @@ -4788,17 +5010,20 @@ let rec describe_rewriter = function | Bool_rewriter rw -> "" :: describe_rewriter (rw false) | Literal_rewriter rw -> "(ocaml|lem|all)" :: describe_rewriter (rw (fun _ -> true)) | Base_rewriter _ -> [] - + let instantiate_rewriter rewriter args = let selector_function = function | "ocaml" -> rewrite_lit_ocaml | "lem" -> rewrite_lit_lem - | "all" -> (fun _ -> true) + | "all" -> fun _ -> true | arg -> - raise (Reporting.err_general Parse_ast.Unknown ("No rewrite for literal target \"" ^ arg ^ "\", valid targets are ocaml/lem/all")) + raise + (Reporting.err_general Parse_ast.Unknown + ("No rewrite for literal target \"" ^ arg ^ "\", valid targets are ocaml/lem/all") + ) in let instantiate rewriter arg = - match rewriter, arg with + match (rewriter, arg) with | Base_rewriter rw, If_mono_arg -> Base_rewriter (if_mono rw) | Base_rewriter rw, If_mwords_arg -> Base_rewriter (if_mwords rw) | Base_rewriter rw, If_flag flag -> Base_rewriter (if_flag flag rw) @@ -4806,15 +5031,14 @@ let instantiate_rewriter rewriter args = | Bool_rewriter rw, Bool_arg b -> rw b | String_rewriter rw, String_arg str -> rw str | Literal_rewriter rw, Literal_arg selector -> rw (selector_function selector) - | _, _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid rewrite argument" + | _, _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid rewrite argument" in match List.fold_left instantiate rewriter args with | Base_rewriter rw -> rw - | _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Rewrite not fully instantiated" + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Rewrite not fully instantiated" -let all_rewriters = [ +let all_rewriters = + [ ("recheck_defs", checking_rewriter (fun _ ast -> Type_error.check initial_env (strip_ast ast))); ("optimize_recheck_defs", checking_rewriter (fun _ -> Optimize.recheck)); ("realize_mappings", Base_rewriter rewrite_ast_realize_mappings); @@ -4828,7 +5052,9 @@ let all_rewriters = [ ("toplevel_nexps", basic_rewriter rewrite_toplevel_nexps); ("toplevel_consts", String_rewriter (fun target -> basic_rewriter (rewrite_toplevel_consts target))); ("monomorphise", String_rewriter (fun target -> Base_rewriter (monomorphise target))); - ("atoms_to_singletons", String_rewriter (fun target -> (basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons target)))); + ( "atoms_to_singletons", + String_rewriter (fun target -> basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons target)) + ); ("add_bitvector_casts", basic_rewriter Monomorphise.add_bitvector_casts); ("remove_impossible_int_cases", basic_rewriter Constant_propagation.remove_impossible_int_cases); ("const_prop_mutrec", String_rewriter (fun target -> Base_rewriter (Constant_propagation_mutrec.rewrite_ast target))); @@ -4863,17 +5089,34 @@ let all_rewriters = [ ("rewrite_loops_with_escape_effect", basic_rewriter rewrite_loops_with_escape_effect); ("simple_types", basic_rewriter rewrite_simple_types); ("overload_cast", basic_rewriter rewrite_overload_cast); - ("instantiate_outcomes", String_rewriter (fun target -> basic_rewriter (fun _ -> Outcome_rewrites.instantiate target))); + ( "instantiate_outcomes", + String_rewriter (fun target -> basic_rewriter (fun _ -> Outcome_rewrites.instantiate target)) + ); ("top_sort_defs", basic_rewriter (fun _ -> top_sort_defs)); - ("constant_fold", String_rewriter (fun target -> basic_rewriter (fun _ -> Constant_fold.(rewrite_constant_function_calls no_fixed target)))); + ( "constant_fold", + String_rewriter + (fun target -> basic_rewriter (fun _ -> Constant_fold.(rewrite_constant_function_calls no_fixed target))) + ); ("split", String_rewriter (fun str -> Base_rewriter (rewrite_split_fun_ctor_pats str))); ("properties", basic_rewriter (fun _ -> Property.rewrite)); - ("attach_effects", Base_rewriter (fun effect_info env ast -> Effects.rewrite_attach_effects effect_info ast, effect_info, env)); - ("prover_regstate", Bool_rewriter (fun mwords -> Base_rewriter (fun effect_info env ast -> let env, ast = State.add_regstate_defs mwords env ast in ast, effect_info, env))); + ( "attach_effects", + Base_rewriter (fun effect_info env ast -> (Effects.rewrite_attach_effects effect_info ast, effect_info, env)) + ); + ( "prover_regstate", + Bool_rewriter + (fun mwords -> + Base_rewriter + (fun effect_info env ast -> + let env, ast = State.add_regstate_defs mwords env ast in + (ast, effect_info, env) + ) + ) + ); ("add_unspecified_rec", basic_rewriter rewrite_add_unspecified_rec); ] -let rewrites_interpreter = [ +let rewrites_interpreter = + [ ("instantiate_outcomes", [String_arg "interpreter"]); ("realize_mappings", []); ("toplevel_string_append", []); @@ -4882,53 +5125,59 @@ let rewrites_interpreter = [ ("undefined", [Bool_arg false]); ("tuple_assignments", []); ("vector_concat_assignments", []); - ("simple_assignments", []) + ("simple_assignments", []); ] -type rewrite_sequence = (string * (Effects.side_effect_info -> Env.t -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t)) list - +type rewrite_sequence = + (string * (Effects.side_effect_info -> Env.t -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t)) list + let instantiate_rewrites rws = let get_rewriter name = match List.assoc_opt name all_rewriters with | Some rewrite -> rewrite - | None -> - Reporting.unreachable Parse_ast.Unknown __POS__ ("Attempted to execute unknown rewrite " ^ name) + | None -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Attempted to execute unknown rewrite " ^ name) in List.map (fun (name, args) -> (name, instantiate_rewriter (get_rewriter name) args)) rws let opt_ddump_rewrite_ast = ref None - + let rewrite_step n total (ast, effect_info, env) (name, rewriter) = let t = Profile.start () in let ast, effect_info, env = rewriter effect_info env ast in Profile.finish ("rewrite " ^ name) t; - begin match !opt_ddump_rewrite_ast with - | Some (f, i) -> - let filename = f ^ "_rewrite_" ^ string_of_int i ^ "_" ^ name ^ ".sail" in - let ((ot,_,_,_) as ext_ot) = Util.open_output_with_check_unformatted None filename in - Pretty_print_sail.pp_ast ot (strip_ast ast); - Util.close_output_with_check ext_ot; - opt_ddump_rewrite_ast := Some (f, i + 1) - | _ -> () + begin + match !opt_ddump_rewrite_ast with + | Some (f, i) -> + let filename = f ^ "_rewrite_" ^ string_of_int i ^ "_" ^ name ^ ".sail" in + let ((ot, _, _, _) as ext_ot) = Util.open_output_with_check_unformatted None filename in + Pretty_print_sail.pp_ast ot (strip_ast ast); + Util.close_output_with_check ext_ot; + opt_ddump_rewrite_ast := Some (f, i + 1) + | _ -> () end; Util.progress "Rewrite " name n total; - ast, effect_info, env + (ast, effect_info, env) let rewrite effect_info env rewriters ast = let total = List.length rewriters in - try snd (List.fold_left (fun (n, astenv) rw -> n + 1, rewrite_step n total astenv rw) (1, (ast, effect_info, env)) rewriters) with - | Type_check.Type_error (_, l, err) -> - raise (Reporting.err_typ l (Type_error.string_of_type_error err)) + try + snd + (List.fold_left + (fun (n, astenv) rw -> (n + 1, rewrite_step n total astenv rw)) + (1, (ast, effect_info, env)) + rewriters + ) + with Type_check.Type_error (_, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) let () = let open Interactive in - - ActionUnit (fun _ -> - let print_rewriter (name, rw) = - print_endline (name ^ " " ^ Util.(String.concat " " (describe_rewriter rw) |> yellow |> clear)) - in - List.sort (fun a b -> String.compare (fst a) (fst b)) all_rewriters - |> List.iter print_rewriter - ) |> register_command ~name:"list_rewrites" ~help:"List all rewrites for use with the :rewrite command"; + ActionUnit + (fun _ -> + let print_rewriter (name, rw) = + print_endline (name ^ " " ^ Util.(String.concat " " (describe_rewriter rw) |> yellow |> clear)) + in + List.sort (fun a b -> String.compare (fst a) (fst b)) all_rewriters |> List.iter print_rewriter + ) + |> register_command ~name:"list_rewrites" ~help:"List all rewrites for use with the :rewrite command" diff --git a/src/lib/rewrites.mli b/src/lib/rewrites.mli index e640857b3..54f36028b 100644 --- a/src/lib/rewrites.mli +++ b/src/lib/rewrites.mli @@ -84,8 +84,8 @@ val opt_coq_warn_nonexhaustive : bool ref (** Output each rewrite step (as produced by the rewrite function) to a file for debugging *) -val opt_ddump_rewrite_ast : ((string * int) option) ref - +val opt_ddump_rewrite_ast : (string * int) option ref + (** Generate a fresh id with the given prefix *) val fresh_id : string -> l -> id @@ -93,7 +93,7 @@ val fresh_id : string -> l -> id val move_loop_measures : 'a ast -> 'a ast val pat_of_mpat : 'a mpat -> 'a pat - + (** Re-write undefined to functions created by -undefined_gen flag *) val rewrite_undefined : bool -> Env.t -> tannot ast -> tannot ast @@ -104,10 +104,11 @@ type rewriter = | Literal_rewriter of ((lit -> bool) -> rewriter) val describe_rewriter : rewriter -> string list - + val all_rewriters : (string * rewriter) list -type rewrite_sequence = (string * (Effects.side_effect_info -> Env.t -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t)) list +type rewrite_sequence = + (string * (Effects.side_effect_info -> Env.t -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t)) list val rewrite_lit_ocaml : lit -> bool val rewrite_lit_lem : lit -> bool @@ -125,8 +126,9 @@ type rewriter_arg = val instantiate_rewrites : (string * rewriter_arg list) list -> rewrite_sequence (** Apply a sequence of rewrites to an AST *) -val rewrite : Effects.side_effect_info -> Env.t -> rewrite_sequence -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t +val rewrite : + Effects.side_effect_info -> Env.t -> rewrite_sequence -> tannot ast -> tannot ast * Effects.side_effect_info * Env.t val rewrites_interpreter : (string * rewriter_arg list) list - + val simple_typ : typ -> typ diff --git a/src/lib/sail_lib.ml b/src/lib/sail_lib.ml index bd28d83f7..0c6daf55e 100644 --- a/src/lib/sail_lib.ml +++ b/src/lib/sail_lib.ml @@ -74,8 +74,8 @@ module type BitType = sig val b1 : t end -type 'a return = { return : 'b . 'a -> 'b } -type 'za zoption = | ZNone of unit | ZSome of 'za;; +type 'a return = { return : 'b. 'a -> 'b } +type 'za zoption = ZNone of unit | ZSome of 'za let zint_forwards i = string_of_int (Big_int.to_int i) @@ -84,42 +84,35 @@ let opt_trace = ref false let trace_depth = ref 0 let random = ref false - let opt_cycle_limit = ref 0 let cycle_count = ref 0 -let cycle_limit_reached () = +let cycle_limit_reached () = cycle_count := !cycle_count + 1; !opt_cycle_limit != 0 && !cycle_count >= !opt_cycle_limit let sail_call (type t) (f : _ -> t) = - let module M = - struct exception Return of t end - in + let module M = struct + exception Return of t + end in let return = { return = (fun x -> raise (M.Return x)) } in - try - f return - with M.Return x -> x + try f return with M.Return x -> x let trace str = - if !opt_trace - then - begin - if !trace_depth < 0 then trace_depth := 0 else (); - prerr_endline (String.make (!trace_depth * 2) ' ' ^ str) - end + if !opt_trace then begin + if !trace_depth < 0 then trace_depth := 0 else (); + prerr_endline (String.make (!trace_depth * 2) ' ' ^ str) + end else () -let trace_write name str = - trace ("Write: " ^ name ^ " " ^ str) +let trace_write name str = trace ("Write: " ^ name ^ " " ^ str) -let trace_read name str = - trace ("Read: " ^ name ^ " " ^ str) +let trace_read name str = trace ("Read: " ^ name ^ " " ^ str) let sail_trace_call (type t) (name : string) (in_string : string) (string_of_out : t -> string) (f : _ -> t) = - let module M = - struct exception Return of t end - in + let module M = struct + exception Return of t + end in let return = { return = (fun x -> raise (M.Return x)) } in trace ("Call: " ^ name ^ " " ^ in_string); incr trace_depth; @@ -129,7 +122,8 @@ let sail_trace_call (type t) (name : string) (in_string : string) (string_of_out result let trace_call str = - trace str; incr trace_depth + trace str; + incr trace_depth type bit = B0 | B1 @@ -137,18 +131,11 @@ let eq_anything (a, b) = a = b let eq_bit (a, b) = a = b -let and_bit = function - | B1, B1 -> B1 - | _, _ -> B0 +let and_bit = function B1, B1 -> B1 | _, _ -> B0 -let or_bit = function - | B0, B0 -> B0 - | _, _ -> B1 +let or_bit = function B0, B0 -> B0 | _, _ -> B1 -let xor_bit = function - | B1, B0 -> B1 - | B0, B1 -> B1 - | _, _ -> B0 +let xor_bit = function B1, B0 -> B1 | B0, B1 -> B1 | _, _ -> B0 let and_vec (xs, ys) = assert (List.length xs = List.length ys); @@ -166,64 +153,42 @@ let xor_vec (xs, ys) = assert (List.length xs = List.length ys); List.map2 (fun x y -> xor_bit (x, y)) xs ys -let xor_bool (b1, b2) = (b1 || b2) && (b1 != b2) +let xor_bool (b1, b2) = (b1 || b2) && b1 != b2 -let undefined_bit () = - if !random - then (if Random.bool () then B0 else B1) - else B0 +let undefined_bit () = if !random then if Random.bool () then B0 else B1 else B0 -let undefined_bool () = - if !random then Random.bool () else false +let undefined_bool () = if !random then Random.bool () else false let rec undefined_vector (len, item) = - if Big_int.equal len Big_int.zero - then [] - else item :: undefined_vector (Big_int.sub len (Big_int.of_int 1), item) + if Big_int.equal len Big_int.zero then [] else item :: undefined_vector (Big_int.sub len (Big_int.of_int 1), item) let undefined_list _ = [] let undefined_bitvector len = - if Big_int.equal len Big_int.zero - then [] - else B0 :: undefined_vector (Big_int.sub len (Big_int.of_int 1), B0) + if Big_int.equal len Big_int.zero then [] else B0 :: undefined_vector (Big_int.sub len (Big_int.of_int 1), B0) let undefined_string () = "" let undefined_unit () = () -let undefined_int () = - if !random then Big_int.of_int (Random.int 0xFFFF) else Big_int.zero +let undefined_int () = if !random then Big_int.of_int (Random.int 0xFFFF) else Big_int.zero let undefined_nat () = Big_int.zero let undefined_range (lo, _) = lo -let internal_pick list = - if !random - then List.nth list (Random.int (List.length list)) - else List.nth list 0 +let internal_pick list = if !random then List.nth list (Random.int (List.length list)) else List.nth list 0 let eq_int (n, m) = Big_int.equal n m let eq_bool ((x : bool), (y : bool)) : bool = x = y -let rec drop n xs = - match n, xs with - | 0, xs -> xs - | _, [] -> [] - | n, (_ :: xs) -> drop (n -1) xs +let rec drop n xs = match (n, xs) with 0, xs -> xs | _, [] -> [] | n, _ :: xs -> drop (n - 1) xs -let rec take n xs = - match n, xs with - | 0, _ -> [] - | n, (x :: xs) -> x :: take (n - 1) xs - | _, [] -> [] +let rec take n xs = match (n, xs) with 0, _ -> [] | n, x :: xs -> x :: take (n - 1) xs | _, [] -> [] let count_leading_zeros xs = - let rec aux bs acc = match bs with - | (B0 :: bs') -> aux bs' (acc + 1) - | _ -> acc in + let rec aux bs acc = match bs with B0 :: bs' -> aux bs' (acc + 1) | _ -> acc in Big_int.of_int (aux xs 0) let subrange (list, n, m) = @@ -243,14 +208,11 @@ let access (xs, n) = List.nth (List.rev xs) (Big_int.to_int n) let append (xs, ys) = xs @ ys let update (xs, n, x) = - let n = (List.length xs - Big_int.to_int n) - 1 in + let n = List.length xs - Big_int.to_int n - 1 in take n xs @ [x] @ drop (n + 1) xs let update_subrange (xs, n, _, ys) = - let rec aux xs o = function - | [] -> xs - | (y :: ys) -> aux (update (xs, o, y)) (Big_int.sub o (Big_int.of_int 1)) ys - in + let rec aux xs o = function [] -> xs | y :: ys -> aux (update (xs, o, y)) (Big_int.sub o (Big_int.of_int 1)) ys in aux xs n ys let vector_truncate (xs, n) = List.rev (take (Big_int.to_int n) (List.rev xs)) @@ -259,13 +221,11 @@ let vector_truncateLSB (xs, n) = take (Big_int.to_int n) xs let length xs = Big_int.of_int (List.length xs) -let big_int_of_bit = function - | B0 -> Big_int.zero - | B1 -> (Big_int.of_int 1) +let big_int_of_bit = function B0 -> Big_int.zero | B1 -> Big_int.of_int 1 let uint xs = let uint_bit x (n, pos) = - Big_int.add n (Big_int.mul (Big_int.pow_int_positive 2 pos) (big_int_of_bit x)), pos + 1 + (Big_int.add n (Big_int.mul (Big_int.pow_int_positive 2 pos) (big_int_of_bit x)), pos + 1) in fst (List.fold_right uint_bit xs (Big_int.zero, 0)) @@ -273,11 +233,9 @@ let sint = function | [] -> Big_int.zero | [msb] -> Big_int.negate (big_int_of_bit msb) | msb :: xs -> - let msb_pos = List.length xs in - let complement = - Big_int.negate (Big_int.mul (Big_int.pow_int_positive 2 msb_pos) (big_int_of_bit msb)) - in - Big_int.add complement (uint xs) + let msb_pos = List.length xs in + let complement = Big_int.negate (Big_int.mul (Big_int.pow_int_positive 2 msb_pos) (big_int_of_bit msb)) in + Big_int.add complement (uint xs) let add_int (x, y) = Big_int.add x y let sub_int (x, y) = Big_int.sub x y @@ -291,12 +249,10 @@ let mult (x, y) = Big_int.mul x y let quotient (x, y) = Big_int.div x y (* This is the same as tdiv_int, kept for compatibility with old preludes *) -let quot_round_zero (x, y) = - Big_int.integerDiv_t x y +let quot_round_zero (x, y) = Big_int.integerDiv_t x y (* The corresponding remainder function for above just respects the sign of x *) -let rem_round_zero (x, y) = - Big_int.integerRem_t x y +let rem_round_zero (x, y) = Big_int.integerRem_t x y (* Lem provides euclidian modulo by default *) let modulus (x, y) = Big_int.modulus x y @@ -308,46 +264,47 @@ let tdiv_int (x, y) = Big_int.integerDiv_t x y let tmod_int (x, y) = Big_int.integerRem_t x y let add_bit_with_carry (x, y, carry) = - match x, y, carry with - | B0, B0, B0 -> B0, B0 - | B0, B1, B0 -> B1, B0 - | B1, B0, B0 -> B1, B0 - | B1, B1, B0 -> B0, B1 - | B0, B0, B1 -> B1, B0 - | B0, B1, B1 -> B0, B1 - | B1, B0, B1 -> B0, B1 - | B1, B1, B1 -> B1, B1 + match (x, y, carry) with + | B0, B0, B0 -> (B0, B0) + | B0, B1, B0 -> (B1, B0) + | B1, B0, B0 -> (B1, B0) + | B1, B1, B0 -> (B0, B1) + | B0, B0, B1 -> (B1, B0) + | B0, B1, B1 -> (B0, B1) + | B1, B0, B1 -> (B0, B1) + | B1, B1, B1 -> (B1, B1) let sub_bit_with_carry (x, y, carry) = - match x, y, carry with - | B0, B0, B0 -> B0, B0 - | B0, B1, B0 -> B0, B1 - | B1, B0, B0 -> B1, B0 - | B1, B1, B0 -> B0, B0 - | B0, B0, B1 -> B1, B0 - | B0, B1, B1 -> B0, B0 - | B1, B0, B1 -> B1, B1 - | B1, B1, B1 -> B1, B0 - -let not_bit = function - | B0 -> B1 - | B1 -> B0 + match (x, y, carry) with + | B0, B0, B0 -> (B0, B0) + | B0, B1, B0 -> (B0, B1) + | B1, B0, B0 -> (B1, B0) + | B1, B1, B0 -> (B0, B0) + | B0, B0, B1 -> (B1, B0) + | B0, B1, B1 -> (B0, B0) + | B1, B0, B1 -> (B1, B1) + | B1, B1, B1 -> (B1, B0) + +let not_bit = function B0 -> B1 | B1 -> B0 let not_vec xs = List.map not_bit xs let add_vec_carry (xs, ys) = assert (List.length xs = List.length ys); - let (carry, result) = - List.fold_right2 (fun x y (c, result) -> let (z, c) = add_bit_with_carry (x, y, c) in (c, z :: result)) xs ys (B0, []) + let carry, result = + List.fold_right2 + (fun x y (c, result) -> + let z, c = add_bit_with_carry (x, y, c) in + (c, z :: result) + ) + xs ys (B0, []) in - carry, result + (carry, result) let add_vec (xs, ys) = snd (add_vec_carry (xs, ys)) let rec replicate_bits (bits, n) = - if Big_int.less_equal n Big_int.zero - then [] - else bits @ replicate_bits (bits, Big_int.sub n (Big_int.of_int 1)) + if Big_int.less_equal n Big_int.zero then [] else bits @ replicate_bits (bits, Big_int.sub n (Big_int.of_int 1)) let identity x = x @@ -357,11 +314,11 @@ Uses twos-complement representation for m<0 and pads most significant bits in si Most significant bit is head of returned list. *) let rec get_slice_int' (n, m, o) = - if n <= 0 then - [] - else - let bit = if (Big_int.extract_num m (n + o - 1) 1) == Big_int.zero then B0 else B1 in - bit :: get_slice_int' (n-1, m, o) + if n <= 0 then [] + else ( + let bit = if Big_int.extract_num m (n + o - 1) 1 == Big_int.zero then B0 else B1 in + bit :: get_slice_int' (n - 1, m, o) + ) (* as above but taking Big_int for all arguments *) let get_slice_int (n, m, o) = get_slice_int' (Big_int.to_int n, m, Big_int.to_int o) @@ -374,34 +331,31 @@ let to_bits (len, n) = get_slice_int' (Big_int.to_int len, n, 0) (* unsigned multiplication of two n bit lists producing a list of 2n bits *) let mult_vec (x, y) = - let xi = uint(x) in - let yi = uint(y) in + let xi = uint x in + let yi = uint y in let len = List.length x in let prod = Big_int.mul xi yi in - to_bits' (2*len, prod) + to_bits' (2 * len, prod) (* signed multiplication of two n bit lists producing a list of 2n bits. *) let mults_vec (x, y) = - let xi = sint(x) in - let yi = sint(y) in + let xi = sint x in + let yi = sint y in let len = List.length x in let prod = Big_int.mul xi yi in - to_bits' (2*len, prod) + to_bits' (2 * len, prod) let add_vec_int (v, n) = - let n_bits = to_bits'(List.length v, n) in - add_vec(v, n_bits) + let n_bits = to_bits' (List.length v, n) in + add_vec (v, n_bits) -let sub_vec (xs, ys) = add_vec (xs, add_vec_int (not_vec ys, (Big_int.of_int 1))) +let sub_vec (xs, ys) = add_vec (xs, add_vec_int (not_vec ys, Big_int.of_int 1)) let sub_vec_int (v, n) = - let n_bits = to_bits'(List.length v, n) in - sub_vec(v, n_bits) + let n_bits = to_bits' (List.length v, n) in + sub_vec (v, n_bits) -let bin_char = function - | '0' -> B0 - | '1' -> B1 - | _ -> failwith "Invalid binary character" +let bin_char = function '0' -> B0 | '1' -> B1 | _ -> failwith "Invalid binary character" let hex_char = function | '0' -> [B0; B0; B0; B0] @@ -423,39 +377,24 @@ let hex_char = function | _ -> failwith "Invalid hex character" let list_of_string s = - let rec aux i acc = - if i < 0 then acc - else aux (i-1) (s.[i] :: acc) - in aux (String.length s - 1) [] + let rec aux i acc = if i < 0 then acc else aux (i - 1) (s.[i] :: acc) in + aux (String.length s - 1) [] -let bits_of_string str = - List.concat (List.map hex_char (list_of_string str)) +let bits_of_string str = List.concat (List.map hex_char (list_of_string str)) let concat_str (str1, str2) = str1 ^ str2 -let rec break n = function - | [] -> [] - | (_ :: _ as xs) -> [take n xs] @ break n (drop n xs) +let rec break n = function [] -> [] | _ :: _ as xs -> [take n xs] @ break n (drop n xs) -let string_of_bit = function - | B0 -> "0" - | B1 -> "1" +let string_of_bit = function B0 -> "0" | B1 -> "1" -let char_of_bit = function - | B0 -> '0' - | B1 -> '1' +let char_of_bit = function B0 -> '0' | B1 -> '1' -let int_of_bit = function - | B0 -> 0 - | B1 -> 1 +let int_of_bit = function B0 -> 0 | B1 -> 1 -let bool_of_bit = function - | B0 -> false - | B1 -> true +let bool_of_bit = function B0 -> false | B1 -> true -let bit_of_bool = function - | false -> B0 - | true -> B1 +let bit_of_bool = function false -> B0 | true -> B1 let bigint_of_bit b = Big_int.of_int (int_of_bit b) @@ -478,24 +417,20 @@ let string_of_hex = function | [B1; B1; B1; B1] -> "F" | _ -> failwith "Cannot convert binary sequence to hex" - let string_of_bits bits = - if List.length bits mod 4 == 0 - then "0x" ^ String.concat "" (List.map string_of_hex (break 4 bits)) + if List.length bits mod 4 == 0 then "0x" ^ String.concat "" (List.map string_of_hex (break 4 bits)) else "0b" ^ String.concat "" (List.map string_of_bit bits) let decimal_string_of_bits bits = let place_values = - List.mapi - (fun i b -> (Big_int.mul (bigint_of_bit b) (Big_int.pow_int_positive 2 i))) - (List.rev bits) + List.mapi (fun i b -> Big_int.mul (bigint_of_bit b) (Big_int.pow_int_positive 2 i)) (List.rev bits) in let sum = List.fold_left Big_int.add Big_int.zero place_values in Big_int.to_string sum let hex_slice (str, n, m) = let bits = List.concat (List.map hex_char (list_of_string (String.sub str 2 (String.length str - 2)))) in - let padding = replicate_bits([B0], n) in + let padding = replicate_bits ([B0], n) in let bits = padding @ bits in let slice = List.rev (take (Big_int.to_int n) (drop (Big_int.to_int m) (List.rev bits))) in slice @@ -505,46 +440,39 @@ let putchar n = flush stdout let rec bits_of_int bit n = - if bit <> 0 - then - begin - if n / bit > 0 - then B1 :: bits_of_int (bit / 2) (n - bit) - else B0 :: bits_of_int (bit / 2) n - end + if bit <> 0 then begin + if n / bit > 0 then B1 :: bits_of_int (bit / 2) (n - bit) else B0 :: bits_of_int (bit / 2) n + end else [] let rec bits_of_big_int pow n = if pow < 1 then [] else begin - let bit = (Big_int.pow_int_positive 2 (pow - 1)) in - if Big_int.greater (Big_int.div n bit) Big_int.zero then - B1 :: bits_of_big_int (pow - 1) (Big_int.sub n bit) - else - B0 :: bits_of_big_int (pow - 1) n - end + let bit = Big_int.pow_int_positive 2 (pow - 1) in + if Big_int.greater (Big_int.div n bit) Big_int.zero then B1 :: bits_of_big_int (pow - 1) (Big_int.sub n bit) + else B0 :: bits_of_big_int (pow - 1) n + end let byte_of_int n = bits_of_int 128 n module Mem = struct - include Map.Make(struct - type t = Big_int.num - let compare = Big_int.compare - end) + include Map.Make (struct + type t = Big_int.num + let compare = Big_int.compare + end) end -let mem_pages = (ref Mem.empty : (Bytes.t Mem.t) ref);; +let mem_pages = (ref Mem.empty : Bytes.t Mem.t ref) let page_shift_bits = 20 (* 1M page *) -let page_size_bytes = 1 lsl page_shift_bits;; +let page_size_bytes = 1 lsl page_shift_bits let page_no_of_addr a = Big_int.shift_right a page_shift_bits let bottom_addr_of_page p = Big_int.shift_left p page_shift_bits let top_addr_of_page p = Big_int.shift_left (Big_int.succ p) page_shift_bits let get_mem_page p = - try - Mem.find p !mem_pages - with Not_found -> + try Mem.find p !mem_pages + with Not_found -> let new_page = Bytes.make page_size_bytes '\000' in mem_pages := Mem.add p new_page !mem_pages; new_page @@ -558,10 +486,9 @@ let rec add_mem_bytes addr buf off len = let bytes_left_in_page = Big_int.sub page_top addr in let to_copy = min (Big_int.to_int bytes_left_in_page) len in Bytes.blit buf off page page_off to_copy; - if (to_copy < len) then - add_mem_bytes page_top buf (off + to_copy) (len - to_copy) + if to_copy < len then add_mem_bytes page_top buf (off + to_copy) (len - to_copy) -let rec read_mem_bytes addr len = +let rec read_mem_bytes addr len = let page_no = page_no_of_addr addr in let page_bot = bottom_addr_of_page page_no in let page_top = top_addr_of_page page_no in @@ -570,20 +497,19 @@ let rec read_mem_bytes addr len = let bytes_left_in_page = Big_int.sub page_top addr in let to_get = min (Big_int.to_int bytes_left_in_page) len in let bytes = Bytes.sub page page_off to_get in - if to_get >= len then - bytes - else - Bytes.cat bytes (read_mem_bytes page_top (len - to_get)) + if to_get >= len then bytes else Bytes.cat bytes (read_mem_bytes page_top (len - to_get)) let write_ram' (data_size, addr, data) = - let len = Big_int.to_int data_size in - let bytes = Bytes.create len in begin + let len = Big_int.to_int data_size in + let bytes = Bytes.create len in + begin List.iteri (fun i byte -> Bytes.set bytes (len - i - 1) (char_of_int (Big_int.to_int (uint byte)))) (break 8 data); add_mem_bytes addr bytes 0 len end let write_ram (_addr_size, data_size, _hex_ram, addr, data) = - write_ram' (data_size, uint addr, data); true + write_ram' (data_size, uint addr, data); + true let wram addr byte = let bytes = Bytes.make 1 (char_of_int byte) in @@ -593,17 +519,17 @@ let read_ram (_addr_size, data_size, _hex_ram, addr) = let addr = uint addr in let bytes = read_mem_bytes addr (Big_int.to_int data_size) in let vector = ref [] in - Bytes.iter (fun byte -> vector := (byte_of_int (int_of_char byte)) @ !vector) bytes; + Bytes.iter (fun byte -> vector := byte_of_int (int_of_char byte) @ !vector) bytes; !vector let fast_read_ram (data_size, addr) = let addr = uint addr in let bytes = read_mem_bytes addr (Big_int.to_int data_size) in let vector = ref [] in - Bytes.iter (fun byte -> vector := (byte_of_int (int_of_char byte)) @ !vector) bytes; + Bytes.iter (fun byte -> vector := byte_of_int (int_of_char byte) @ !vector) bytes; !vector -let tag_ram = (ref Mem.empty : (bool Mem.t) ref);; +let tag_ram = (ref Mem.empty : bool Mem.t ref) let write_tag_bool (addr, tag) = let addri = uint addr in @@ -613,9 +539,7 @@ let read_tag_bool addr = let addri = uint addr in try Mem.find addri !tag_ram with Not_found -> false -let rec reverse_endianness bits = - if List.length bits <= 8 then bits else - reverse_endianness (drop 8 bits) @ (take 8 bits) +let rec reverse_endianness bits = if List.length bits <= 8 then bits else reverse_endianness (drop 8 bits) @ take 8 bits (* FIXME: Casts can't be externed *) let zcast_unit_vec x = [x] @@ -630,45 +554,42 @@ let debug (str1, n, str2, v) = prerr_endline (str1 ^ Big_int.to_string n ^ str2 let eq_string (str1, str2) = String.compare str1 str2 == 0 -let string_startswith (str1, str2) = String.length str1 >= String.length str2 && String.compare (String.sub str1 0 (String.length str2)) str2 == 0 +let string_startswith (str1, str2) = + String.length str1 >= String.length str2 && String.compare (String.sub str1 0 (String.length str2)) str2 == 0 -let string_drop (str, n) = if (Big_int.less_equal (Big_int.of_int (String.length str)) n) then "" else let n = Big_int.to_int n in String.sub str n (String.length str - n) +let string_drop (str, n) = + if Big_int.less_equal (Big_int.of_int (String.length str)) n then "" + else ( + let n = Big_int.to_int n in + String.sub str n (String.length str - n) + ) let string_take (str, n) = let n = Big_int.to_int n in - if String.length str <= n then - str - else - String.sub str 0 n + if String.length str <= n then str else String.sub str 0 n let string_length str = Big_int.of_int (String.length str) let string_append (s1, s2) = s1 ^ s2 -let int_of_string_opt s = - try - Some (Big_int.of_string s) - with - | Invalid_argument _ -> None +let int_of_string_opt s = try Some (Big_int.of_string s) with Invalid_argument _ -> None (* highly inefficient recursive implementation *) let rec maybe_int_of_prefix = function | "" -> ZNone () - | str -> - let len = String.length str in - match int_of_string_opt str with - | Some n -> ZSome (n, Big_int.of_int len) - | None -> maybe_int_of_prefix (String.sub str 0 (len - 1)) + | str -> ( + let len = String.length str in + match int_of_string_opt str with + | Some n -> ZSome (n, Big_int.of_int len) + | None -> maybe_int_of_prefix (String.sub str 0 (len - 1)) + ) -let maybe_int_of_string str = - match int_of_string_opt str with - | None -> ZNone () - | Some n -> ZSome n +let maybe_int_of_string str = match int_of_string_opt str with None -> ZNone () | Some n -> ZSome n let lt_int (x, y) = Big_int.less x y let set_slice (out_len, _slice_len, out, n, slice) = - let out = update_subrange(out, Big_int.add n (Big_int.of_int (List.length slice - 1)), n, slice) in + let out = update_subrange (out, Big_int.add n (Big_int.of_int (List.length slice - 1)), n, slice) in assert (List.length out = Big_int.to_int out_len); out @@ -689,10 +610,8 @@ let negate_real x = Rational.neg x let neg_real x = Rational.neg x let string_of_real x = - if Big_int.equal (Rational.den x) (Big_int.of_int 1) then - Big_int.to_string (Rational.num x) - else - Big_int.to_string (Rational.num x) ^ "/" ^ Big_int.to_string (Rational.den x) + if Big_int.equal (Rational.den x) (Big_int.of_int 1) then Big_int.to_string (Rational.num x) + else Big_int.to_string (Rational.num x) ^ "/" ^ Big_int.to_string (Rational.den x) let print_real (str, r) = print_endline (str ^ string_of_real r) let prerr_real (str, r) = prerr_endline (str ^ string_of_real r) @@ -714,22 +633,22 @@ let sqrt_real x = let s = Big_int.sqrt (Rational.num x) in if Big_int.equal (Rational.den x) (Big_int.of_int 1) && Big_int.equal (Big_int.mul s s) (Rational.num x) then to_real s - else + else ( let p = ref (to_real (Big_int.sqrt (Big_int.div (Rational.num x) (Rational.den x)))) in let n = ref (Rational.of_int 0) in - let convergence = ref (Rational.div (Rational.of_int 1) (Rational.of_big_int (Big_int.pow_int_positive 10 precision))) in + let convergence = + ref (Rational.div (Rational.of_int 1) (Rational.of_big_int (Big_int.pow_int_positive 10 precision))) + in let quit_loop = ref false in while not !quit_loop do n := Rational.div (Rational.add !p (Rational.div x !p)) (Rational.of_int 2); - if Rational.lt (Rational.abs (Rational.sub !p !n)) !convergence then - quit_loop := true - else - p := !n + if Rational.lt (Rational.abs (Rational.sub !p !n)) !convergence then quit_loop := true else p := !n done; !n + ) -let random_real () = Rational.div (Rational.of_int (Random.bits ())) (Rational.of_int (Random.bits())) +let random_real () = Rational.div (Rational.of_int (Random.bits ())) (Rational.of_int (Random.bits ())) let lt (x, y) = Big_int.less x y let gt (x, y) = Big_int.greater x y @@ -746,16 +665,14 @@ let string_of_int x = Big_int.to_string x let undefined_real () = Rational.of_int 0 -let rec pow x = function - | 0 -> 1 - | n -> x * pow x (n - 1) +let rec pow x = function 0 -> 1 | n -> x * pow x (n - 1) let real_of_string str = match Util.split_on_char '.' str with | [whole; frac] -> - let whole = Rational.of_int (int_of_string whole) in - let frac = Rational.div (Rational.of_int (int_of_string frac)) (Rational.of_int (pow 10 (String.length frac))) in - Rational.add whole frac + let whole = Rational.of_int (int_of_string whole) in + let frac = Rational.div (Rational.of_int (int_of_string frac)) (Rational.of_int (pow 10 (String.length frac))) in + Rational.add whole frac | [_] -> Rational.of_int (int_of_string str) | _ -> failwith "invalid real literal" @@ -763,43 +680,33 @@ let print str = Stdlib.print_string str let prerr str = Stdlib.prerr_string str -let print_int (str, x) = - print_endline (str ^ Big_int.to_string x) +let print_int (str, x) = print_endline (str ^ Big_int.to_string x) -let prerr_int (str, x) = - prerr_endline (str ^ Big_int.to_string x) +let prerr_int (str, x) = prerr_endline (str ^ Big_int.to_string x) -let print_bits (str, xs) = - print_endline (str ^ string_of_bits xs) +let print_bits (str, xs) = print_endline (str ^ string_of_bits xs) -let prerr_bits (str, xs) = - prerr_endline (str ^ string_of_bits xs) +let prerr_bits (str, xs) = prerr_endline (str ^ string_of_bits xs) -let print_string(str, msg) = - print_endline (str ^ msg) +let print_string (str, msg) = print_endline (str ^ msg) -let prerr_string(str, msg) = - prerr_endline (str ^ msg) +let prerr_string (str, msg) = prerr_endline (str ^ msg) let reg_deref r = !r -let string_of_zbit = function - | B0 -> "0" - | B1 -> "1" +let string_of_zbit = function B0 -> "0" | B1 -> "1" let string_of_znat n = Big_int.to_string n let string_of_zint n = Big_int.to_string n let string_of_zimplicit n = Big_int.to_string n let string_of_zunit () = "()" -let string_of_zbool = function - | true -> "true" - | false -> "false" +let string_of_zbool = function true -> "true" | false -> "false" let string_of_zreal _ = "REAL" let string_of_zstring str = "\"" ^ String.escaped str ^ "\"" let rec string_of_list sep string_of = function | [] -> "" | [x] -> string_of x - | x::ls -> (string_of x) ^ sep ^ (string_of_list sep string_of ls) + | x :: ls -> string_of x ^ sep ^ string_of_list sep string_of ls let skip () = () @@ -807,9 +714,7 @@ let memea (_, _) = () let zero_extend (vec, n) = let m = Big_int.to_int n in - if m <= List.length vec - then take m vec - else replicate_bits ([B0], Big_int.of_int (m - List.length vec)) @ vec + if m <= List.length vec then take m vec else replicate_bits ([B0], Big_int.of_int (m - List.length vec)) @ vec let sign_extend (vec, n) = let m = Big_int.to_int n in @@ -819,10 +724,10 @@ let sign_extend (vec, n) = | B1 :: _ as vec -> replicate_bits ([B1], Big_int.of_int (m - List.length vec)) @ vec let zeros n = replicate_bits ([B0], n) -let ones n = replicate_bits ([B1], n) +let ones n = replicate_bits ([B1], n) let shift_bits_right_arith (x, y) = - let ybi = uint(y) in + let ybi = uint y in let msbs = replicate_bits (take 1 x, ybi) in let rbits = msbs @ x in take (List.length x) rbits @@ -837,17 +742,15 @@ let arith_shiftr (x, y) = let rbits = msbs @ x in take (List.length x) rbits -let shift_bits_right (x, y) = - shiftr (x, uint(y)) +let shift_bits_right (x, y) = shiftr (x, uint y) let shiftl (x, y) = - let yi = Big_int.to_int y in + let yi = Big_int.to_int y in let zeros = zeros y in let rbits = x @ zeros in drop yi rbits -let shift_bits_left (x, y) = - shiftl (x, uint(y)) +let shift_bits_left (x, y) = shiftl (x, uint y) let speculate_conditional_success () = true @@ -855,375 +758,301 @@ let speculate_conditional_success () = true let get_time_ns () = Big_int.of_int (int_of_float (1e9 *. Unix.gettimeofday ())) (* Python: -f = """let hex_bits_{0}_matches_prefix s = - match maybe_int_of_prefix s with - | ZNone () -> ZNone () - | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 {0}) then - ZSome ((bits_of_big_int {0} n, len)) - else - ZNone () -""" - -for i in list(range(1, 34)) + [48, 64]: - print(f.format(i)) - + f = """let hex_bits_{0}_matches_prefix s = + match maybe_int_of_prefix s with + | ZNone () -> ZNone () + | ZSome (n, len) -> + if Big_int.less_equal Big_int.zero n + && Big_int.less n (Big_int.pow_int_positive 2 {0}) then + ZSome ((bits_of_big_int {0} n, len)) + else + ZNone () + """ + + for i in list(range(1, 34)) + [48, 64]: + print(f.format(i)) *) let hex_bits_1_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 1) then - ZSome ((bits_of_big_int 1 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 1) then + ZSome (bits_of_big_int 1 n, len) + else ZNone () let hex_bits_2_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 2) then - ZSome ((bits_of_big_int 2 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 2) then + ZSome (bits_of_big_int 2 n, len) + else ZNone () let hex_bits_3_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 3) then - ZSome ((bits_of_big_int 3 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 3) then + ZSome (bits_of_big_int 3 n, len) + else ZNone () let hex_bits_4_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 4) then - ZSome ((bits_of_big_int 4 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 4) then + ZSome (bits_of_big_int 4 n, len) + else ZNone () let hex_bits_5_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 5) then - ZSome ((bits_of_big_int 5 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 5) then + ZSome (bits_of_big_int 5 n, len) + else ZNone () let hex_bits_6_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 6) then - ZSome ((bits_of_big_int 6 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 6) then + ZSome (bits_of_big_int 6 n, len) + else ZNone () let hex_bits_7_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 7) then - ZSome ((bits_of_big_int 7 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 7) then + ZSome (bits_of_big_int 7 n, len) + else ZNone () let hex_bits_8_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 8) then - ZSome ((bits_of_big_int 8 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 8) then + ZSome (bits_of_big_int 8 n, len) + else ZNone () let hex_bits_9_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 9) then - ZSome ((bits_of_big_int 9 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 9) then + ZSome (bits_of_big_int 9 n, len) + else ZNone () let hex_bits_10_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 10) then - ZSome ((bits_of_big_int 10 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 10) then + ZSome (bits_of_big_int 10 n, len) + else ZNone () let hex_bits_11_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 11) then - ZSome ((bits_of_big_int 11 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 11) then + ZSome (bits_of_big_int 11 n, len) + else ZNone () let hex_bits_12_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 12) then - ZSome ((bits_of_big_int 12 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 12) then + ZSome (bits_of_big_int 12 n, len) + else ZNone () let hex_bits_13_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 13) then - ZSome ((bits_of_big_int 13 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 13) then + ZSome (bits_of_big_int 13 n, len) + else ZNone () let hex_bits_14_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 14) then - ZSome ((bits_of_big_int 14 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 14) then + ZSome (bits_of_big_int 14 n, len) + else ZNone () let hex_bits_15_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 15) then - ZSome ((bits_of_big_int 15 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 15) then + ZSome (bits_of_big_int 15 n, len) + else ZNone () let hex_bits_16_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 16) then - ZSome ((bits_of_big_int 16 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 16) then + ZSome (bits_of_big_int 16 n, len) + else ZNone () let hex_bits_17_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 17) then - ZSome ((bits_of_big_int 17 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 17) then + ZSome (bits_of_big_int 17 n, len) + else ZNone () let hex_bits_18_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 18) then - ZSome ((bits_of_big_int 18 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 18) then + ZSome (bits_of_big_int 18 n, len) + else ZNone () let hex_bits_19_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 19) then - ZSome ((bits_of_big_int 19 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 19) then + ZSome (bits_of_big_int 19 n, len) + else ZNone () let hex_bits_20_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 20) then - ZSome ((bits_of_big_int 20 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 20) then + ZSome (bits_of_big_int 20 n, len) + else ZNone () let hex_bits_21_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 21) then - ZSome ((bits_of_big_int 21 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 21) then + ZSome (bits_of_big_int 21 n, len) + else ZNone () let hex_bits_22_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 22) then - ZSome ((bits_of_big_int 22 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 22) then + ZSome (bits_of_big_int 22 n, len) + else ZNone () let hex_bits_23_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 23) then - ZSome ((bits_of_big_int 23 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 23) then + ZSome (bits_of_big_int 23 n, len) + else ZNone () let hex_bits_24_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 24) then - ZSome ((bits_of_big_int 24 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 24) then + ZSome (bits_of_big_int 24 n, len) + else ZNone () let hex_bits_25_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 25) then - ZSome ((bits_of_big_int 25 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 25) then + ZSome (bits_of_big_int 25 n, len) + else ZNone () let hex_bits_26_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 26) then - ZSome ((bits_of_big_int 26 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 26) then + ZSome (bits_of_big_int 26 n, len) + else ZNone () let hex_bits_27_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 27) then - ZSome ((bits_of_big_int 27 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 27) then + ZSome (bits_of_big_int 27 n, len) + else ZNone () let hex_bits_28_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 28) then - ZSome ((bits_of_big_int 28 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 28) then + ZSome (bits_of_big_int 28 n, len) + else ZNone () let hex_bits_29_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 29) then - ZSome ((bits_of_big_int 29 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 29) then + ZSome (bits_of_big_int 29 n, len) + else ZNone () let hex_bits_30_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 30) then - ZSome ((bits_of_big_int 30 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 30) then + ZSome (bits_of_big_int 30 n, len) + else ZNone () let hex_bits_31_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 31) then - ZSome ((bits_of_big_int 31 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 31) then + ZSome (bits_of_big_int 31 n, len) + else ZNone () let hex_bits_32_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 32) then - ZSome ((bits_of_big_int 32 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 32) then + ZSome (bits_of_big_int 32 n, len) + else ZNone () let hex_bits_33_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 33) then - ZSome ((bits_of_big_int 33 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 33) then + ZSome (bits_of_big_int 33 n, len) + else ZNone () let hex_bits_48_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 48) then - ZSome ((bits_of_big_int 48 n, len)) - else - ZNone () + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 48) then + ZSome (bits_of_big_int 48 n, len) + else ZNone () let hex_bits_64_matches_prefix s = match maybe_int_of_prefix s with | ZNone () -> ZNone () | ZSome (n, len) -> - if Big_int.less_equal Big_int.zero n - && Big_int.less n (Big_int.pow_int_positive 2 64) then - ZSome ((bits_of_big_int 64 n, len)) - else - ZNone () - + if Big_int.less_equal Big_int.zero n && Big_int.less n (Big_int.pow_int_positive 2 64) then + ZSome (bits_of_big_int 64 n, len) + else ZNone () -let string_of_bool = function - | true -> "true" - | false -> "false" +let string_of_bool = function true -> "true" | false -> "false" let dec_str x = Big_int.to_string x @@ -1246,8 +1075,7 @@ let load_raw (paddr, file) = wram (Big_int.add paddr (Big_int.of_int !i)) byte; incr i done - with - | End_of_file -> () + with End_of_file -> () (* XXX this could count cycles and exit after given limit *) let cycle_count () = () @@ -1256,14 +1084,12 @@ let cycle_count () = () let rand_zvector (g : 'generators) (size : int) (_order : bool) (elem_gen : 'generators -> 'a) : 'a list = Util.list_init size (fun _ -> elem_gen g) -let rand_zbit (_ : 'generators) : bit = - bit_of_bool (Random.bool()) +let rand_zbit (_ : 'generators) : bit = bit_of_bool (Random.bool ()) let rand_zbitvector (g : 'generators) (size : int) (_order : bool) : bit list = Util.list_init size (fun _ -> rand_zbit g) -let rand_zbool (_ : 'generators) : bool = - Random.bool() +let rand_zbool (_ : 'generators) : bool = Random.bool () let rand_zunit (_ : 'generators) : unit = () diff --git a/src/lib/scattered.ml b/src/lib/scattered.ml index ff4e16c92..ecce1d4bb 100644 --- a/src/lib/scattered.ml +++ b/src/lib/scattered.ml @@ -72,14 +72,12 @@ open Ast_util let funcl_id (FCL_aux (FCL_funcl (id, _), _)) = id let rec last_scattered_funcl id = function - | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, _)), _) :: _ - when Id.compare (funcl_id funcl) id = 0 -> false + | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, _)), _) :: _ when Id.compare (funcl_id funcl) id = 0 -> false | _ :: defs -> last_scattered_funcl id defs | [] -> true let rec last_scattered_mapcl id = function - | DEF_aux (DEF_scattered (SD_aux (SD_mapcl (mid, _), _)), _) :: _ - when Id.compare mid id = 0 -> false + | DEF_aux (DEF_scattered (SD_aux (SD_mapcl (mid, _), _)), _) :: _ when Id.compare mid id = 0 -> false | _ :: defs -> last_scattered_mapcl id defs | [] -> true @@ -90,16 +88,13 @@ let no_tannot_opt l = Typ_annot_opt_aux (Typ_annot_opt_none, gen_loc l) let rec filter_union_clauses id = function | DEF_aux (DEF_scattered (SD_aux (SD_unioncl (uid, tu), _)), _) :: defs when Id.compare id uid = 0 -> - filter_union_clauses id defs - | def :: defs -> - def :: filter_union_clauses id defs + filter_union_clauses id defs + | def :: defs -> def :: filter_union_clauses id defs | [] -> [] -let patch_funcl_loc def_annot (FCL_aux (aux, (_, tannot))) = - FCL_aux (aux, (def_annot, tannot)) +let patch_funcl_loc def_annot (FCL_aux (aux, (_, tannot))) = FCL_aux (aux, (def_annot, tannot)) -let patch_mapcl_annot def_annot (MCL_aux (aux, (_, tannot))) = - MCL_aux (aux, (def_annot, tannot)) +let patch_mapcl_annot def_annot (MCL_aux (aux, (_, tannot))) = MCL_aux (aux, (def_annot, tannot)) module PC_config = struct type t = Type_check.tannot @@ -107,69 +102,69 @@ module PC_config = struct let add_attribute l attr arg = Type_check.map_uannot (add_attribute l attr arg) end -module PC = Pattern_completeness.Make(PC_config) +module PC = Pattern_completeness.Make (PC_config) let rec descatter' funcls mapcls = function (* For scattered functions we collect all the seperate function clauses until we find the last one, then we turn that function clause into a DEF_fundef containing all the clauses. *) | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, (l, tannot))), def_annot) :: defs - when last_scattered_funcl (funcl_id funcl) defs -> - let funcl = patch_funcl_loc def_annot funcl in - let clauses = match Bindings.find_opt (funcl_id funcl) funcls with - | Some clauses -> List.rev (funcl :: clauses) - | None -> [funcl] - in - let clauses, update_attr = Type_check.(check_funcls_complete l (env_of_tannot tannot) clauses (typ_of_tannot tannot)) in - DEF_aux (DEF_fundef (FD_aux (FD_function (fake_rec_opt l, no_tannot_opt l, clauses), - (gen_loc l, tannot))), - update_attr (mk_def_annot (gen_loc l))) - :: descatter' funcls mapcls defs - + when last_scattered_funcl (funcl_id funcl) defs -> + let funcl = patch_funcl_loc def_annot funcl in + let clauses = + match Bindings.find_opt (funcl_id funcl) funcls with + | Some clauses -> List.rev (funcl :: clauses) + | None -> [funcl] + in + let clauses, update_attr = + Type_check.(check_funcls_complete l (env_of_tannot tannot) clauses (typ_of_tannot tannot)) + in + DEF_aux + ( DEF_fundef (FD_aux (FD_function (fake_rec_opt l, no_tannot_opt l, clauses), (gen_loc l, tannot))), + update_attr (mk_def_annot (gen_loc l)) + ) + :: descatter' funcls mapcls defs | DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, (l, _))), def_annot) :: defs -> - let id = funcl_id funcl in - let funcl = patch_funcl_loc def_annot funcl in - begin match Bindings.find_opt id funcls with - | Some clauses -> descatter' (Bindings.add id (funcl :: clauses) funcls) mapcls defs - | None -> descatter' (Bindings.add id [funcl] funcls) mapcls defs - end - + let id = funcl_id funcl in + let funcl = patch_funcl_loc def_annot funcl in + begin + match Bindings.find_opt id funcls with + | Some clauses -> descatter' (Bindings.add id (funcl :: clauses) funcls) mapcls defs + | None -> descatter' (Bindings.add id [funcl] funcls) mapcls defs + end (* Scattered mappings are handled the same way as scattered functions *) | DEF_aux (DEF_scattered (SD_aux (SD_mapcl (id, mapcl), (l, tannot))), def_annot) :: defs - when last_scattered_mapcl id defs -> - let mapcl = patch_mapcl_annot def_annot mapcl in - let clauses = match Bindings.find_opt id mapcls with - | Some clauses -> List.rev (mapcl :: clauses) - | None -> [mapcl] - in - DEF_aux (DEF_mapdef (MD_aux (MD_mapping (id, no_tannot_opt l, clauses), - (gen_loc l, tannot))), - mk_def_annot (gen_loc l)) - :: descatter' funcls mapcls defs - + when last_scattered_mapcl id defs -> + let mapcl = patch_mapcl_annot def_annot mapcl in + let clauses = + match Bindings.find_opt id mapcls with Some clauses -> List.rev (mapcl :: clauses) | None -> [mapcl] + in + DEF_aux + (DEF_mapdef (MD_aux (MD_mapping (id, no_tannot_opt l, clauses), (gen_loc l, tannot))), mk_def_annot (gen_loc l)) + :: descatter' funcls mapcls defs | DEF_aux (DEF_scattered (SD_aux (SD_mapcl (id, mapcl), _)), def_annot) :: defs -> - let mapcl = patch_mapcl_annot def_annot mapcl in - begin match Bindings.find_opt id mapcls with - | Some clauses -> descatter' funcls (Bindings.add id (mapcl :: clauses) mapcls) defs - | None -> descatter' funcls (Bindings.add id [mapcl] mapcls) defs - end - + let mapcl = patch_mapcl_annot def_annot mapcl in + begin + match Bindings.find_opt id mapcls with + | Some clauses -> descatter' funcls (Bindings.add id (mapcl :: clauses) mapcls) defs + | None -> descatter' funcls (Bindings.add id [mapcl] mapcls) defs + end (* For scattered unions, when we find a union declaration we immediately grab all the future clauses and turn it into a regular union declaration. *) | DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), (l, _))), def_annot) :: defs -> - let tus = get_scattered_union_clauses id defs in - begin match tus with - | [] -> raise (Reporting.err_general l "No clauses found for scattered union type") - | _ -> - DEF_aux (DEF_type (TD_aux (TD_variant (id, typq, tus, false), (gen_loc l, Type_check.empty_tannot))), def_annot) - :: descatter' funcls mapcls (filter_union_clauses id defs) - end - + let tus = get_scattered_union_clauses id defs in + begin + match tus with + | [] -> raise (Reporting.err_general l "No clauses found for scattered union type") + | _ -> + DEF_aux + (DEF_type (TD_aux (TD_variant (id, typq, tus, false), (gen_loc l, Type_check.empty_tannot))), def_annot) + :: descatter' funcls mapcls (filter_union_clauses id defs) + end (* Therefore we should never see SD_unioncl... *) | DEF_aux (DEF_scattered (SD_aux (SD_unioncl _, (l, _))), _) :: _ -> - raise (Reporting.err_unreachable l __POS__ "Found union clause during de-scattering") - + raise (Reporting.err_unreachable l __POS__ "Found union clause during de-scattering") | def :: defs -> def :: descatter' funcls mapcls defs | [] -> [] diff --git a/src/lib/spec_analysis.ml b/src/lib/spec_analysis.ml index 9fea2260b..10534b084 100644 --- a/src/lib/spec_analysis.ml +++ b/src/lib/spec_analysis.ml @@ -70,7 +70,7 @@ open Ast_defs open Ast_util open Util -module Nameset = Set.Make(String) +module Nameset = Set.Make (String) let mt = Nameset.empty @@ -79,422 +79,427 @@ let mt = Nameset.empty let conditional_add typ_or_exp bound used id = let known_list = - if typ_or_exp (*true for typ*) - then ["bit";"vector";"unit";"string";"int";"bool"] - else ["=="; "!="; "|";"~";"&";"add_int"] in - let i = (string_of_id (if typ_or_exp then prepend_id "typ:" id else id)) in - if Nameset.mem i bound || List.mem i known_list - then used - else Nameset.add i used + if typ_or_exp (*true for typ*) then ["bit"; "vector"; "unit"; "string"; "int"; "bool"] + else ["=="; "!="; "|"; "~"; "&"; "add_int"] + in + let i = string_of_id (if typ_or_exp then prepend_id "typ:" id else id) in + if Nameset.mem i bound || List.mem i known_list then used else Nameset.add i used let conditional_add_typ = conditional_add true let conditional_add_exp = conditional_add false - let nameset_bigunion = List.fold_left Nameset.union mt - -let rec free_type_names_t consider_var (Typ_aux (t, l)) = match t with +let rec free_type_names_t consider_var (Typ_aux (t, l)) = + match t with | Typ_var name -> if consider_var then Nameset.add (string_of_kid name) mt else mt | Typ_id name -> Nameset.add (string_of_id name) mt - | Typ_fn (arg_typs,ret_typ) -> - List.fold_left Nameset.union (free_type_names_t consider_var ret_typ) (List.map (free_type_names_t consider_var) arg_typs) - | Typ_bidir (t1,t2) -> Nameset.union (free_type_names_t consider_var t1) - (free_type_names_t consider_var t2) + | Typ_fn (arg_typs, ret_typ) -> + List.fold_left Nameset.union + (free_type_names_t consider_var ret_typ) + (List.map (free_type_names_t consider_var) arg_typs) + | Typ_bidir (t1, t2) -> Nameset.union (free_type_names_t consider_var t1) (free_type_names_t consider_var t2) | Typ_tuple ts -> free_type_names_ts consider_var ts - | Typ_app (name,targs) -> Nameset.add (string_of_id name) (free_type_names_t_args consider_var targs) - | Typ_exist (kopts,_,t') -> List.fold_left (fun s kopt -> Nameset.remove (string_of_kid (kopt_kid kopt)) s) (free_type_names_t consider_var t') kopts + | Typ_app (name, targs) -> Nameset.add (string_of_id name) (free_type_names_t_args consider_var targs) + | Typ_exist (kopts, _, t') -> + List.fold_left + (fun s kopt -> Nameset.remove (string_of_kid (kopt_kid kopt)) s) + (free_type_names_t consider_var t') kopts | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" + and free_type_names_ts consider_var ts = nameset_bigunion (List.map (free_type_names_t consider_var) ts) -and free_type_names_t_arg consider_var = function - | A_aux (A_typ t, _) -> free_type_names_t consider_var t - | _ -> mt -and free_type_names_t_args consider_var targs = - nameset_bigunion (List.map (free_type_names_t_arg consider_var) targs) +and free_type_names_t_arg consider_var = function A_aux (A_typ t, _) -> free_type_names_t consider_var t | _ -> mt + +and free_type_names_t_args consider_var targs = nameset_bigunion (List.map (free_type_names_t_arg consider_var) targs) -let rec fv_of_typ consider_var bound used (Typ_aux (t,l)) : Nameset.t = +let rec fv_of_typ consider_var bound used (Typ_aux (t, l)) : Nameset.t = match t with - | Typ_var (Kid_aux (Var v,l)) -> - if consider_var - then conditional_add_typ bound used (Ast.Id_aux (Ast.Id v,l)) - else used + | Typ_var (Kid_aux (Var v, l)) -> + if consider_var then conditional_add_typ bound used (Ast.Id_aux (Ast.Id v, l)) else used | Typ_id id -> conditional_add_typ bound used id - | Typ_fn(arg,ret) -> - fv_of_typ consider_var bound (List.fold_left Nameset.union Nameset.empty (List.map (fv_of_typ consider_var bound used) arg)) ret - | Typ_bidir(t1,t2) -> fv_of_typ consider_var bound (fv_of_typ consider_var bound used t1) t2 (* TODO FIXME? *) + | Typ_fn (arg, ret) -> + fv_of_typ consider_var bound + (List.fold_left Nameset.union Nameset.empty (List.map (fv_of_typ consider_var bound used) arg)) + ret + | Typ_bidir (t1, t2) -> fv_of_typ consider_var bound (fv_of_typ consider_var bound used t1) t2 (* TODO FIXME? *) | Typ_tuple ts -> List.fold_right (fun t n -> fv_of_typ consider_var bound n t) ts used - | Typ_app(id,targs) -> - List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) - | Typ_exist (kopts,_,t') -> - fv_of_typ consider_var - (List.fold_left (fun b (KOpt_aux (KOpt_kind (_, (Kid_aux (Var v,_))), _)) -> Nameset.add v b) bound kopts) - used t' + | Typ_app (id, targs) -> + List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) + | Typ_exist (kopts, _, t') -> + fv_of_typ consider_var + (List.fold_left (fun b (KOpt_aux (KOpt_kind (_, Kid_aux (Var v, _)), _)) -> Nameset.add v b) bound kopts) + used t' | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" and fv_of_tannot consider_var bound used tannot = - match Type_check.destruct_tannot tannot with - | None -> mt - | Some (_, t) -> fv_of_typ consider_var bound used t + match Type_check.destruct_tannot tannot with None -> mt | Some (_, t) -> fv_of_typ consider_var bound used t -and fv_of_targ consider_var bound used (Ast.A_aux(targ,_)) : Nameset.t = match targ with +and fv_of_targ consider_var bound used (Ast.A_aux (targ, _)) : Nameset.t = + match targ with | A_typ t -> fv_of_typ consider_var bound used t | A_nexp n -> fv_of_nexp consider_var bound used n | _ -> used -and fv_of_nexp consider_var bound used (Ast.Nexp_aux(n,_)) = match n with +and fv_of_nexp consider_var bound used (Ast.Nexp_aux (n, _)) = + match n with | Nexp_id id -> conditional_add_typ bound used id - | Nexp_var (Ast.Kid_aux (Ast.Var i,_)) -> - if consider_var - then conditional_add_typ bound used (Ast.Id_aux (Ast.Id i, Parse_ast.Unknown)) - else used - | Nexp_times (n1,n2) | Ast.Nexp_sum (n1,n2) | Ast.Nexp_minus(n1,n2) -> - fv_of_nexp consider_var bound (fv_of_nexp consider_var bound used n1) n2 + | Nexp_var (Ast.Kid_aux (Ast.Var i, _)) -> + if consider_var then conditional_add_typ bound used (Ast.Id_aux (Ast.Id i, Parse_ast.Unknown)) else used + | Nexp_times (n1, n2) | Ast.Nexp_sum (n1, n2) | Ast.Nexp_minus (n1, n2) -> + fv_of_nexp consider_var bound (fv_of_nexp consider_var bound used n1) n2 | Nexp_exp n | Ast.Nexp_neg n -> fv_of_nexp consider_var bound used n | _ -> used -and fv_of_nconstraint consider_var bound used (Ast.NC_aux(nc,_)) = match nc with - | NC_equal (n1,n2) | NC_bounded_ge (n1,n2) | NC_bounded_gt (n1, n2) | NC_bounded_le (n1,n2) - | NC_bounded_lt (n1,n2) | NC_not_equal (n1, n2) -> - fv_of_nexp consider_var bound (fv_of_nexp consider_var bound used n1) n2 - | NC_set (Ast.Kid_aux (Ast.Var i,_), _) - | NC_var (Ast.Kid_aux (Ast.Var i,_)) -> - if consider_var - then conditional_add_typ bound used (Ast.Id_aux (Ast.Id i, Parse_ast.Unknown)) - else used - | NC_or (nc1,nc2) | NC_and (nc1,nc2) -> - fv_of_nconstraint consider_var bound (fv_of_nconstraint consider_var bound used nc1) nc2 +and fv_of_nconstraint consider_var bound used (Ast.NC_aux (nc, _)) = + match nc with + | NC_equal (n1, n2) + | NC_bounded_ge (n1, n2) + | NC_bounded_gt (n1, n2) + | NC_bounded_le (n1, n2) + | NC_bounded_lt (n1, n2) + | NC_not_equal (n1, n2) -> + fv_of_nexp consider_var bound (fv_of_nexp consider_var bound used n1) n2 + | NC_set (Ast.Kid_aux (Ast.Var i, _), _) | NC_var (Ast.Kid_aux (Ast.Var i, _)) -> + if consider_var then conditional_add_typ bound used (Ast.Id_aux (Ast.Id i, Parse_ast.Unknown)) else used + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> + fv_of_nconstraint consider_var bound (fv_of_nconstraint consider_var bound used nc1) nc2 | NC_app (id, targs) -> - List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) + List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) | NC_true | NC_false -> used -let typq_bindings (TypQ_aux(tq,_)) = match tq with +let typq_bindings (TypQ_aux (tq, _)) = + match tq with | TypQ_tq quants -> - List.fold_right (fun (QI_aux (qi,_)) bounds -> - match qi with - | QI_id (KOpt_aux(k,_)) -> - (match k with - | KOpt_kind (_, Kid_aux (Var s,_)) -> Nameset.add s bounds) - | _ -> bounds) quants mt - | TypQ_no_forall -> mt - -let fv_of_typschm consider_var bound used (Ast.TypSchm_aux ((Ast.TypSchm_ts(typq,typ)),_)) = + List.fold_right + (fun (QI_aux (qi, _)) bounds -> + match qi with + | QI_id (KOpt_aux (k, _)) -> ( + match k with KOpt_kind (_, Kid_aux (Var s, _)) -> Nameset.add s bounds + ) + | _ -> bounds + ) + quants mt + | TypQ_no_forall -> mt + +let fv_of_typschm consider_var bound used (Ast.TypSchm_aux (Ast.TypSchm_ts (typq, typ), _)) = let ts_bound = if consider_var then typq_bindings typq else mt in - ts_bound, fv_of_typ consider_var (Nameset.union bound ts_bound) used typ + (ts_bound, fv_of_typ consider_var (Nameset.union bound ts_bound) used typ) let rec fv_of_typ_pat consider_var bound used (TP_aux (tp, _)) = match tp with - | TP_wild -> bound, used - | TP_var (Kid_aux (Var v, l)) -> - Nameset.add (string_of_id (Ast.Id_aux (Ast.Id v,l))) bound, used + | TP_wild -> (bound, used) + | TP_var (Kid_aux (Var v, l)) -> (Nameset.add (string_of_id (Ast.Id_aux (Ast.Id v, l))) bound, used) | TP_app (id, tps) -> - let u = conditional_add_typ bound used id in - List.fold_right (fun ta (b, u) -> fv_of_typ_pat consider_var b u ta) tps (bound, u) + let u = conditional_add_typ bound used id in + List.fold_right (fun ta (b, u) -> fv_of_typ_pat consider_var b u ta) tps (bound, u) -let rec pat_bindings consider_var bound used (P_aux(p,(_,tannot))) = - let list_fv bound used ps = List.fold_right (fun p (b,n) -> pat_bindings consider_var b n p) ps (bound, used) in +let rec pat_bindings consider_var bound used (P_aux (p, (_, tannot))) = + let list_fv bound used ps = List.fold_right (fun p (b, n) -> pat_bindings consider_var b n p) ps (bound, used) in match p with - | P_lit _ | P_wild -> bound,used - | P_or(p1,p2) -> - (* The typechecker currently drops bindings in disjunctions entirely *) - let _b1, u1 = pat_bindings consider_var bound used p1 in - let _b2, u2 = pat_bindings consider_var bound used p2 in - bound, Nameset.union u1 u2 + | P_lit _ | P_wild -> (bound, used) + | P_or (p1, p2) -> + (* The typechecker currently drops bindings in disjunctions entirely *) + let _b1, u1 = pat_bindings consider_var bound used p1 in + let _b2, u2 = pat_bindings consider_var bound used p2 in + (bound, Nameset.union u1 u2) | P_not p -> - let _b, u = pat_bindings consider_var bound used p in - bound, u - | P_as(p,id) -> let b,ns = pat_bindings consider_var bound used p in - Nameset.add (string_of_id id) b,ns - | P_typ(t,p) -> - let used = fv_of_tannot consider_var bound used tannot in - let ns = fv_of_typ consider_var bound used t in pat_bindings consider_var bound ns p + let _b, u = pat_bindings consider_var bound used p in + (bound, u) + | P_as (p, id) -> + let b, ns = pat_bindings consider_var bound used p in + (Nameset.add (string_of_id id) b, ns) + | P_typ (t, p) -> + let used = fv_of_tannot consider_var bound used tannot in + let ns = fv_of_typ consider_var bound used t in + pat_bindings consider_var bound ns p | P_id id | P_vector_subrange (id, _, _) -> - let used = fv_of_tannot consider_var bound used tannot in - Nameset.add (string_of_id id) bound,used + let used = fv_of_tannot consider_var bound used tannot in + (Nameset.add (string_of_id id) bound, used) | P_var (p, typ_p) -> - let b, u = pat_bindings consider_var bound used p in - fv_of_typ_pat consider_var b u typ_p - | P_app(id,pats) -> - let used = fv_of_tannot consider_var bound used tannot in - list_fv bound (Nameset.add (string_of_id id) used) pats - | P_vector pats | Ast.P_vector_concat pats | Ast.P_tuple pats | Ast.P_list pats | P_string_append pats -> list_fv bound used pats - | P_cons (p1,p2) -> - let b1, u1 = pat_bindings consider_var bound used p1 in - pat_bindings consider_var b1 u1 p2 - -let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset.t * Nameset.t * Nameset.t) = - let list_fv b n s es = List.fold_right (fun e (b,n,s) -> fv_of_exp consider_var b n s e) es (b,n,s) in + let b, u = pat_bindings consider_var bound used p in + fv_of_typ_pat consider_var b u typ_p + | P_app (id, pats) -> + let used = fv_of_tannot consider_var bound used tannot in + list_fv bound (Nameset.add (string_of_id id) used) pats + | P_vector pats | Ast.P_vector_concat pats | Ast.P_tuple pats | Ast.P_list pats | P_string_append pats -> + list_fv bound used pats + | P_cons (p1, p2) -> + let b1, u1 = pat_bindings consider_var bound used p1 in + pat_bindings consider_var b1 u1 p2 + +let rec fv_of_exp consider_var bound used set (E_aux (e, (_, tannot))) : Nameset.t * Nameset.t * Nameset.t = + let list_fv b n s es = List.fold_right (fun e (b, n, s) -> fv_of_exp consider_var b n s e) es (b, n, s) in match e with - | E_lit _ - | E_internal_value _ -> bound,used,set - | E_block es | Ast.E_tuple es | Ast.E_vector es | Ast.E_list es -> - list_fv bound used set es + | E_lit _ | E_internal_value _ -> (bound, used, set) + | E_block es | Ast.E_tuple es | Ast.E_vector es | Ast.E_list es -> list_fv bound used set es | E_id id | E_ref id -> - let used = conditional_add_exp bound used id in - let used = fv_of_tannot consider_var bound used tannot in - bound,used,set - | E_typ (t,e) -> - let u = fv_of_typ consider_var (if consider_var then bound else mt) used t in - fv_of_exp consider_var bound u set e - | E_app(id,es) -> - let us = conditional_add_exp bound used id in - let us = conditional_add_exp bound us (prepend_id "val:" id) in - list_fv bound us set es - | E_app_infix(l,id,r) -> - let us = conditional_add_exp bound used id in - let us = conditional_add_exp bound us (prepend_id "val:" id) in - list_fv bound us set [l;r] - | E_if(c,t,e) -> list_fv bound used set [c;t;e] - | E_for(id,from,to_,by,_,body) -> - let _,used,set = list_fv bound used set [from;to_;by] in - fv_of_exp consider_var (Nameset.add (string_of_id id) bound) used set body - | E_loop(_, measure, cond, body) -> - let m = match measure with Measure_aux (Measure_some exp,_) -> [exp] | _ -> [] in - list_fv bound used set (m @ [cond; body]) - | E_vector_access(v,i) -> list_fv bound used set [v;i] - | E_vector_subrange(v,i1,i2) -> list_fv bound used set [v;i1;i2] - | E_vector_update(v,i,e) -> list_fv bound used set [v;i;e] - | E_vector_update_subrange(v,i1,i2,e) -> list_fv bound used set [v;i1;i2;e] - | E_vector_append(e1,e2) | E_cons(e1,e2) -> list_fv bound used set [e1;e2] + let used = conditional_add_exp bound used id in + let used = fv_of_tannot consider_var bound used tannot in + (bound, used, set) + | E_typ (t, e) -> + let u = fv_of_typ consider_var (if consider_var then bound else mt) used t in + fv_of_exp consider_var bound u set e + | E_app (id, es) -> + let us = conditional_add_exp bound used id in + let us = conditional_add_exp bound us (prepend_id "val:" id) in + list_fv bound us set es + | E_app_infix (l, id, r) -> + let us = conditional_add_exp bound used id in + let us = conditional_add_exp bound us (prepend_id "val:" id) in + list_fv bound us set [l; r] + | E_if (c, t, e) -> list_fv bound used set [c; t; e] + | E_for (id, from, to_, by, _, body) -> + let _, used, set = list_fv bound used set [from; to_; by] in + fv_of_exp consider_var (Nameset.add (string_of_id id) bound) used set body + | E_loop (_, measure, cond, body) -> + let m = match measure with Measure_aux (Measure_some exp, _) -> [exp] | _ -> [] in + list_fv bound used set (m @ [cond; body]) + | E_vector_access (v, i) -> list_fv bound used set [v; i] + | E_vector_subrange (v, i1, i2) -> list_fv bound used set [v; i1; i2] + | E_vector_update (v, i, e) -> list_fv bound used set [v; i; e] + | E_vector_update_subrange (v, i1, i2, e) -> list_fv bound used set [v; i1; i2; e] + | E_vector_append (e1, e2) | E_cons (e1, e2) -> list_fv bound used set [e1; e2] | E_struct fexps -> - let used = fv_of_tannot consider_var bound used tannot in - List.fold_right - (fun (FE_aux(FE_fexp(_,e),_)) (b,u,s) -> fv_of_exp consider_var b u s e) fexps (bound,used,set) - | E_struct_update(e, fexps) -> - let b,u,s = fv_of_exp consider_var bound used set e in - List.fold_right - (fun (FE_aux(FE_fexp(_,e),_)) (b,u,s) -> fv_of_exp consider_var b u s e) fexps (b,u,s) - | E_field(e,_) -> fv_of_exp consider_var bound used set e - | E_match(e,pes) - | E_try(e,pes) -> - let b,u,s = fv_of_exp consider_var bound used set e in - fv_of_pes consider_var b u s pes - | E_let(lebind,e) -> - let b,u,s = fv_of_let consider_var bound used set lebind in - fv_of_exp consider_var b u s e + let used = fv_of_tannot consider_var bound used tannot in + List.fold_right + (fun (FE_aux (FE_fexp (_, e), _)) (b, u, s) -> fv_of_exp consider_var b u s e) + fexps (bound, used, set) + | E_struct_update (e, fexps) -> + let b, u, s = fv_of_exp consider_var bound used set e in + List.fold_right (fun (FE_aux (FE_fexp (_, e), _)) (b, u, s) -> fv_of_exp consider_var b u s e) fexps (b, u, s) + | E_field (e, _) -> fv_of_exp consider_var bound used set e + | E_match (e, pes) | E_try (e, pes) -> + let b, u, s = fv_of_exp consider_var bound used set e in + fv_of_pes consider_var b u s pes + | E_let (lebind, e) -> + let b, u, s = fv_of_let consider_var bound used set lebind in + fv_of_exp consider_var b u s e | E_var (lexp, exp1, exp2) -> - let b,u,s = fv_of_lexp consider_var bound used set lexp in - let _,used,set = fv_of_exp consider_var bound used set exp1 in - fv_of_exp consider_var b used set exp2 - | E_assign(lexp,e) -> - let b,u,s = fv_of_lexp consider_var bound used set lexp in - let _,used,set = fv_of_exp consider_var bound u s e in - b,used,set + let b, u, s = fv_of_lexp consider_var bound used set lexp in + let _, used, set = fv_of_exp consider_var bound used set exp1 in + fv_of_exp consider_var b used set exp2 + | E_assign (lexp, e) -> + let b, u, s = fv_of_lexp consider_var bound used set lexp in + let _, used, set = fv_of_exp consider_var bound u s e in + (b, used, set) | E_exit e -> fv_of_exp consider_var bound used set e - | E_assert(c,m) -> list_fv bound used set [c;m] - | E_sizeof ne -> bound, fv_of_nexp consider_var bound used ne, set - | E_return e - | E_throw e - | E_internal_return e -> - fv_of_exp consider_var bound used set e + | E_assert (c, m) -> list_fv bound used set [c; m] + | E_sizeof ne -> (bound, fv_of_nexp consider_var bound used ne, set) + | E_return e | E_throw e | E_internal_return e -> fv_of_exp consider_var bound used set e | E_internal_plet (pat, exp1, exp2) -> - let bp,up = pat_bindings consider_var bound used pat in - let _,u1,s1 = fv_of_exp consider_var bound used set exp1 in - fv_of_exp consider_var bp (Nameset.union up u1) s1 exp2 - | E_constraint nc -> bound, fv_of_nconstraint consider_var bound used nc, set - | E_internal_assume (nc, e) -> - fv_of_exp consider_var bound (fv_of_nconstraint consider_var bound used nc) set e + let bp, up = pat_bindings consider_var bound used pat in + let _, u1, s1 = fv_of_exp consider_var bound used set exp1 in + fv_of_exp consider_var bp (Nameset.union up u1) s1 exp2 + | E_constraint nc -> (bound, fv_of_nconstraint consider_var bound used nc, set) + | E_internal_assume (nc, e) -> fv_of_exp consider_var bound (fv_of_nconstraint consider_var bound used nc) set e and fv_of_pes consider_var bound used set pes = match pes with - | [] -> bound,used,set - | Pat_aux(Pat_exp (p,e),_)::pes -> - let bound_p,us_p = pat_bindings consider_var bound used p in - let bound_e,us_e,set_e = fv_of_exp consider_var bound_p us_p set e in - fv_of_pes consider_var bound us_e set_e pes - | Pat_aux(Pat_when (p,g,e),_)::pes -> - let bound_p,us_p = pat_bindings consider_var bound used p in - let bound_g,us_g,set_g = fv_of_exp consider_var bound_p us_p set g in - let bound_e,us_e,set_e = fv_of_exp consider_var bound_g us_g set_g e in - fv_of_pes consider_var bound us_e set_e pes - -and fv_of_let consider_var bound used set (LB_aux(lebind,_)) = match lebind with - | LB_val(pat,exp) -> - let bound_p, us_p = pat_bindings consider_var bound used pat in - let _,us_e,set_e = fv_of_exp consider_var bound used set exp in - bound_p,Nameset.union us_p us_e,set_e - -and fv_of_lexp consider_var bound used set (LE_aux(lexp,(_,tannot))) = + | [] -> (bound, used, set) + | Pat_aux (Pat_exp (p, e), _) :: pes -> + let bound_p, us_p = pat_bindings consider_var bound used p in + let bound_e, us_e, set_e = fv_of_exp consider_var bound_p us_p set e in + fv_of_pes consider_var bound us_e set_e pes + | Pat_aux (Pat_when (p, g, e), _) :: pes -> + let bound_p, us_p = pat_bindings consider_var bound used p in + let bound_g, us_g, set_g = fv_of_exp consider_var bound_p us_p set g in + let bound_e, us_e, set_e = fv_of_exp consider_var bound_g us_g set_g e in + fv_of_pes consider_var bound us_e set_e pes + +and fv_of_let consider_var bound used set (LB_aux (lebind, _)) = + match lebind with + | LB_val (pat, exp) -> + let bound_p, us_p = pat_bindings consider_var bound used pat in + let _, us_e, set_e = fv_of_exp consider_var bound used set exp in + (bound_p, Nameset.union us_p us_e, set_e) + +and fv_of_lexp consider_var bound used set (LE_aux (lexp, (_, tannot))) = match lexp with | LE_id id -> - let used = fv_of_tannot consider_var bound used tannot in - let i = string_of_id id in - if Nameset.mem i bound - then bound, used, Nameset.add i set - else Nameset.add i bound, Nameset.add i used, set - | LE_deref exp -> - fv_of_exp consider_var bound used set exp - | LE_typ(typ,id) -> - let used = fv_of_tannot consider_var bound used tannot in - let i = string_of_id id in - let used_t = fv_of_typ consider_var bound used typ in - if Nameset.mem i bound - then bound, used_t, Nameset.add i set - else Nameset.add i bound, Nameset.add i used_t, set - | LE_tuple(tups) -> - List.fold_right (fun l (b,u,s) -> fv_of_lexp consider_var b u s l) tups (bound,used,set) - | LE_app(id,args) -> - let (bound,used,set) = - List.fold_right - (fun e (b,u,s) -> - fv_of_exp consider_var b u s e) args (bound,used,set) in - bound,Nameset.add (string_of_id id) used,set - | LE_vector_concat(args) -> - List.fold_right - (fun e (b,u,s) -> - fv_of_lexp consider_var b u s e) args (bound,used,set) - | LE_field(lexp,_) -> fv_of_lexp consider_var bound used set lexp - | LE_vector(lexp,exp) -> - let bound_l,used,set = fv_of_lexp consider_var bound used set lexp in - let _,used,set = fv_of_exp consider_var bound used set exp in - bound_l,used,set - | LE_vector_range(lexp,e1,e2) -> - let bound_l,used,set = fv_of_lexp consider_var bound used set lexp in - let _,used,set = fv_of_exp consider_var bound used set e1 in - let _,used,set = fv_of_exp consider_var bound used set e2 in - bound_l,used,set + let used = fv_of_tannot consider_var bound used tannot in + let i = string_of_id id in + if Nameset.mem i bound then (bound, used, Nameset.add i set) else (Nameset.add i bound, Nameset.add i used, set) + | LE_deref exp -> fv_of_exp consider_var bound used set exp + | LE_typ (typ, id) -> + let used = fv_of_tannot consider_var bound used tannot in + let i = string_of_id id in + let used_t = fv_of_typ consider_var bound used typ in + if Nameset.mem i bound then (bound, used_t, Nameset.add i set) + else (Nameset.add i bound, Nameset.add i used_t, set) + | LE_tuple tups -> List.fold_right (fun l (b, u, s) -> fv_of_lexp consider_var b u s l) tups (bound, used, set) + | LE_app (id, args) -> + let bound, used, set = + List.fold_right (fun e (b, u, s) -> fv_of_exp consider_var b u s e) args (bound, used, set) + in + (bound, Nameset.add (string_of_id id) used, set) + | LE_vector_concat args -> List.fold_right (fun e (b, u, s) -> fv_of_lexp consider_var b u s e) args (bound, used, set) + | LE_field (lexp, _) -> fv_of_lexp consider_var bound used set lexp + | LE_vector (lexp, exp) -> + let bound_l, used, set = fv_of_lexp consider_var bound used set lexp in + let _, used, set = fv_of_exp consider_var bound used set exp in + (bound_l, used, set) + | LE_vector_range (lexp, e1, e2) -> + let bound_l, used, set = fv_of_lexp consider_var bound used set lexp in + let _, used, set = fv_of_exp consider_var bound used set e1 in + let _, used, set = fv_of_exp consider_var bound used set e2 in + (bound_l, used, set) let init_env s = Nameset.singleton s let typ_variants consider_var bound tunions = List.fold_right - (fun (Tu_aux(Tu_ty_id(t,id),_)) (b,n) -> Nameset.add (string_of_id id) b, fv_of_typ consider_var b n t) - tunions - (bound,mt) + (fun (Tu_aux (Tu_ty_id (t, id), _)) (b, n) -> (Nameset.add (string_of_id id) b, fv_of_typ consider_var b n t)) + tunions (bound, mt) let fv_of_abbrev consider_var bound used typq typ_arg = let ts_bound = if consider_var then typq_bindings typq else mt in - ts_bound, fv_of_targ consider_var (Nameset.union bound ts_bound) used typ_arg - -let fv_of_type_def consider_var (TD_aux(t,_)) = match t with - | TD_abbrev(id,typq,typ_arg) -> - init_env ("typ:" ^ string_of_id id), snd (fv_of_abbrev consider_var mt mt typq typ_arg) - | TD_record(id,typq,tids,_) -> - let binds = init_env ("typ:" ^ string_of_id id) in - let bounds = if consider_var then typq_bindings typq else mt in - binds, List.fold_right (fun (t,_) n -> fv_of_typ consider_var bounds n t) tids mt - | TD_variant(id,typq,tunions,_) -> - let bindings = Nameset.add ("typ:" ^ string_of_id id) (if consider_var then typq_bindings typq else mt) in - typ_variants consider_var bindings tunions - | TD_enum(id,ids,_) -> - Nameset.of_list (("typ:" ^ string_of_id id) :: List.map string_of_id ids),mt - | TD_bitfield(id,typ,_) -> - init_env ("typ:" ^ string_of_id id), Nameset.empty (* fv_of_typ consider_var mt typ *) - -let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,funcls),_) as fd) = - let fun_name = match funcls with + (ts_bound, fv_of_targ consider_var (Nameset.union bound ts_bound) used typ_arg) + +let fv_of_type_def consider_var (TD_aux (t, _)) = + match t with + | TD_abbrev (id, typq, typ_arg) -> + (init_env ("typ:" ^ string_of_id id), snd (fv_of_abbrev consider_var mt mt typq typ_arg)) + | TD_record (id, typq, tids, _) -> + let binds = init_env ("typ:" ^ string_of_id id) in + let bounds = if consider_var then typq_bindings typq else mt in + (binds, List.fold_right (fun (t, _) n -> fv_of_typ consider_var bounds n t) tids mt) + | TD_variant (id, typq, tunions, _) -> + let bindings = Nameset.add ("typ:" ^ string_of_id id) (if consider_var then typq_bindings typq else mt) in + typ_variants consider_var bindings tunions + | TD_enum (id, ids, _) -> (Nameset.of_list (("typ:" ^ string_of_id id) :: List.map string_of_id ids), mt) + | TD_bitfield (id, typ, _) -> (init_env ("typ:" ^ string_of_id id), Nameset.empty (* fv_of_typ consider_var mt typ *)) + +let fv_of_fun consider_var (FD_aux (FD_function (rec_opt, tannot_opt, funcls), _) as fd) = + let fun_name = + match funcls with | [] -> failwith "fv_of_fun fell off the end looking for the function name" - | FCL_aux(FCL_funcl(id,_),_)::_ -> string_of_id id in - let base_bounds = match rec_opt with + | FCL_aux (FCL_funcl (id, _), _) :: _ -> string_of_id id + in + let base_bounds = + match rec_opt with (* Current Sail does not require syntax for declaring functions as recursive, - and type checker does not check whether functions are recursive, so - just always add a self-dependency of functions on themselves, as well as - adding dependencies from any specified termination measure further below - | Rec_aux(Ast.Rec_rec,_) -> init_env fun_name - | _ -> mt*) - | _ -> init_env fun_name in - let base_bounds,ns_r = match tannot_opt with - | Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) -> - let bindings = if consider_var then typq_bindings typq else mt in - let bound = Nameset.union bindings base_bounds in - bound, fv_of_typ consider_var bound mt typ - | Typ_annot_opt_aux(Typ_annot_opt_none, _) -> - base_bounds, mt in - let ns_measure = match rec_opt with - | Rec_aux(Rec_measure (pat,exp),_) -> - let pat_bs,pat_ns = pat_bindings consider_var base_bounds mt pat in - let _, exp_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in - exp_ns + and type checker does not check whether functions are recursive, so + just always add a self-dependency of functions on themselves, as well as + adding dependencies from any specified termination measure further below + | Rec_aux(Ast.Rec_rec,_) -> init_env fun_name + | _ -> mt*) + | _ -> init_env fun_name + in + let base_bounds, ns_r = + match tannot_opt with + | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), _) -> + let bindings = if consider_var then typq_bindings typq else mt in + let bound = Nameset.union bindings base_bounds in + (bound, fv_of_typ consider_var bound mt typ) + | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> (base_bounds, mt) + in + let ns_measure = + match rec_opt with + | Rec_aux (Rec_measure (pat, exp), _) -> + let pat_bs, pat_ns = pat_bindings consider_var base_bounds mt pat in + let _, exp_ns, _ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in + exp_ns | _ -> mt in - let ns = List.fold_right (fun (FCL_aux(FCL_funcl(_,pexp),_)) ns -> - match pexp with - | Pat_aux(Pat_exp (pat,exp),_) -> - let pat_bs,pat_ns = pat_bindings consider_var base_bounds ns pat in - let _, exp_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in - exp_ns - | Pat_aux(Pat_when (pat,guard,exp),_) -> - let pat_bs,pat_ns = pat_bindings consider_var base_bounds ns pat in - let guard_bs, guard_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty guard in - let _, exp_ns,_ = fv_of_exp consider_var guard_bs guard_ns Nameset.empty exp in - exp_ns - ) funcls mt in - let ns_vs = init_env ("val:" ^ (string_of_id (id_of_fundef fd))) in + let ns = + List.fold_right + (fun (FCL_aux (FCL_funcl (_, pexp), _)) ns -> + match pexp with + | Pat_aux (Pat_exp (pat, exp), _) -> + let pat_bs, pat_ns = pat_bindings consider_var base_bounds ns pat in + let _, exp_ns, _ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in + exp_ns + | Pat_aux (Pat_when (pat, guard, exp), _) -> + let pat_bs, pat_ns = pat_bindings consider_var base_bounds ns pat in + let guard_bs, guard_ns, _ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty guard in + let _, exp_ns, _ = fv_of_exp consider_var guard_bs guard_ns Nameset.empty exp in + exp_ns + ) + funcls mt + in + let ns_vs = init_env ("val:" ^ string_of_id (id_of_fundef fd)) in (* let _ = Printf.eprintf "Function %s uses %s\n" fun_name (set_to_string (Nameset.union ns ns_r)) in *) - init_env fun_name, Nameset.union ns_vs (Nameset.union ns (Nameset.union ns_r ns_measure)) + (init_env fun_name, Nameset.union ns_vs (Nameset.union ns (Nameset.union ns_r ns_measure))) -let fv_of_vspec consider_var (VS_aux(vspec,_)) = match vspec with - | VS_val_spec(ts,id,_,_) -> - init_env ("val:" ^ (string_of_id id)), snd (fv_of_typschm consider_var mt mt ts) +let fv_of_vspec consider_var (VS_aux (vspec, _)) = + match vspec with + | VS_val_spec (ts, id, _, _) -> (init_env ("val:" ^ string_of_id id), snd (fv_of_typschm consider_var mt mt ts)) let rec find_scattered_of name = function | [] -> [] - | DEF_scattered (SD_aux(sda,_) as sd):: defs -> - (match sda with - | SD_function(_,_,id) - | SD_funcl(FCL_aux(FCL_funcl(id,_),_)) - | SD_unioncl(id,_) -> - if name = string_of_id id - then [sd] else [] - | _ -> [])@ - (find_scattered_of name defs) - | _::defs -> find_scattered_of name defs - -let rec fv_of_scattered consider_var consider_scatter_as_one all_defs (SD_aux(sd,(l, _))) = match sd with - | SD_function(_,tannot_opt,id) -> - let b,ns = (match tannot_opt with - | Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) -> - let bindings = if consider_var then typq_bindings typq else mt in - bindings, fv_of_typ consider_var bindings mt typ - | Typ_annot_opt_aux(Typ_annot_opt_none, _) -> - mt, mt) in - init_env (string_of_id id),ns - | SD_funcl (FCL_aux(FCL_funcl(id,pexp),_)) -> - begin - match pexp with - | Pat_aux(Pat_exp (pat,exp),_) -> - let pat_bs,pat_ns = pat_bindings consider_var mt mt pat in - let _,exp_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in - let scattered_binds = match pat with - | P_aux(P_app(pid,_),_) -> init_env ((string_of_id id) ^ "/" ^ (string_of_id pid)) - | _ -> mt in - scattered_binds, exp_ns - | Pat_aux(Pat_when (pat,guard,exp),_) -> - let pat_bs,pat_ns = pat_bindings consider_var mt mt pat in - let guard_bs, guard_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty guard in - let _, exp_ns,_ = fv_of_exp consider_var guard_bs guard_ns Nameset.empty exp in - let scattered_binds = match pat with - | P_aux(P_app(pid,_),_) -> init_env ((string_of_id id) ^ "/" ^ (string_of_id pid)) - | _ -> mt in - scattered_binds, exp_ns - end - | SD_variant (id,_) -> - let name = string_of_id id in - let uses = - if consider_scatter_as_one - then - let variant_defs = find_scattered_of name all_defs in - let pieces_uses = - List.fold_right (fun (binds,uses) all_uses -> Nameset.union uses all_uses) - (List.map (fv_of_scattered consider_var false []) variant_defs) mt in - Nameset.remove name pieces_uses - else mt in - init_env name, uses - | SD_unioncl(id, type_union) -> - let typ_name = string_of_id id in - let b = init_env typ_name in - let (b,r) = typ_variants consider_var b [type_union] in - (Nameset.remove typ_name b, Nameset.add typ_name r) + | DEF_scattered (SD_aux (sda, _) as sd) :: defs -> + ( match sda with + | SD_function (_, _, id) | SD_funcl (FCL_aux (FCL_funcl (id, _), _)) | SD_unioncl (id, _) -> + if name = string_of_id id then [sd] else [] + | _ -> [] + ) + @ find_scattered_of name defs + | _ :: defs -> find_scattered_of name defs + +let rec fv_of_scattered consider_var consider_scatter_as_one all_defs (SD_aux (sd, (l, _))) = + match sd with + | SD_function (_, tannot_opt, id) -> + let b, ns = + match tannot_opt with + | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), _) -> + let bindings = if consider_var then typq_bindings typq else mt in + (bindings, fv_of_typ consider_var bindings mt typ) + | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> (mt, mt) + in + (init_env (string_of_id id), ns) + | SD_funcl (FCL_aux (FCL_funcl (id, pexp), _)) -> begin + match pexp with + | Pat_aux (Pat_exp (pat, exp), _) -> + let pat_bs, pat_ns = pat_bindings consider_var mt mt pat in + let _, exp_ns, _ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in + let scattered_binds = + match pat with P_aux (P_app (pid, _), _) -> init_env (string_of_id id ^ "/" ^ string_of_id pid) | _ -> mt + in + (scattered_binds, exp_ns) + | Pat_aux (Pat_when (pat, guard, exp), _) -> + let pat_bs, pat_ns = pat_bindings consider_var mt mt pat in + let guard_bs, guard_ns, _ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty guard in + let _, exp_ns, _ = fv_of_exp consider_var guard_bs guard_ns Nameset.empty exp in + let scattered_binds = + match pat with P_aux (P_app (pid, _), _) -> init_env (string_of_id id ^ "/" ^ string_of_id pid) | _ -> mt + in + (scattered_binds, exp_ns) + end + | SD_variant (id, _) -> + let name = string_of_id id in + let uses = + if consider_scatter_as_one then ( + let variant_defs = find_scattered_of name all_defs in + let pieces_uses = + List.fold_right + (fun (binds, uses) all_uses -> Nameset.union uses all_uses) + (List.map (fv_of_scattered consider_var false []) variant_defs) + mt + in + Nameset.remove name pieces_uses + ) + else mt + in + (init_env name, uses) + | SD_unioncl (id, type_union) -> + let typ_name = string_of_id id in + let b = init_env typ_name in + let b, r = typ_variants consider_var b [type_union] in + (Nameset.remove typ_name b, Nameset.add typ_name r) | SD_end id -> - let name = string_of_id id in - let uses = if consider_scatter_as_one - (*Note: if this is a function ending, the dec is included *) - then - let scattered_defs = find_scattered_of name all_defs in - List.fold_right (fun (binds,uses) all_uses -> Nameset.union (Nameset.union binds uses) all_uses) - (List.map (fv_of_scattered consider_var false []) scattered_defs) (init_env name) - else init_env name in - init_env (name ^ "/end"), uses + let name = string_of_id id in + let uses = + if consider_scatter_as_one (*Note: if this is a function ending, the dec is included *) then ( + let scattered_defs = find_scattered_of name all_defs in + List.fold_right + (fun (binds, uses) all_uses -> Nameset.union (Nameset.union binds uses) all_uses) + (List.map (fv_of_scattered consider_var false []) scattered_defs) + (init_env name) + ) + else init_env name + in + (init_env (name ^ "/end"), uses) | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to find free variables for scattered mapping clause") let fv_of_rd consider_var (DEC_aux (d, annot)) = @@ -505,66 +510,60 @@ let fv_of_rd consider_var (DEC_aux (d, annot)) = let open Type_check in let env = env_of_annot annot in match d with - | DEC_reg(t, id, _) -> - let t' = Env.expand_synonyms env t in - init_env (string_of_id id), - Nameset.union (fv_of_typ consider_var mt mt t) (fv_of_typ consider_var mt mt t') + | DEC_reg (t, id, _) -> + let t' = Env.expand_synonyms env t in + (init_env (string_of_id id), Nameset.union (fv_of_typ consider_var mt mt t) (fv_of_typ consider_var mt mt t')) let fv_of_def consider_var consider_scatter_as_one all_defs (DEF_aux (aux, _) as def) = match aux with | DEF_type tdef -> fv_of_type_def consider_var tdef | DEF_fundef fdef -> fv_of_fun consider_var fdef - | DEF_mapdef mdef -> mt,mt (* fv_of_map consider_var mdef *) - | DEF_let lebind -> ((fun (b,u,_) -> (b,u)) (fv_of_let consider_var mt mt mt lebind)) + | DEF_mapdef mdef -> (mt, mt (* fv_of_map consider_var mdef *)) + | DEF_let lebind -> (fun (b, u, _) -> (b, u)) (fv_of_let consider_var mt mt mt lebind) | DEF_val vspec -> fv_of_vspec consider_var vspec - | DEF_fixity _ -> mt,mt - | DEF_overload (id,ids) -> - init_env (string_of_id id), - List.fold_left (fun ns id -> Nameset.add ("val:" ^ string_of_id id) ns) mt ids - | DEF_default def -> mt,mt + | DEF_fixity _ -> (mt, mt) + | DEF_overload (id, ids) -> + (init_env (string_of_id id), List.fold_left (fun ns id -> Nameset.add ("val:" ^ string_of_id id) ns) mt ids) + | DEF_default def -> (mt, mt) | DEF_internal_mutrec fdefs -> - let fvs = List.map (fv_of_fun consider_var) fdefs in - List.fold_left Nameset.union Nameset.empty (List.map fst fvs), - List.fold_left Nameset.union Nameset.empty (List.map snd fvs) + let fvs = List.map (fv_of_fun consider_var) fdefs in + ( List.fold_left Nameset.union Nameset.empty (List.map fst fvs), + List.fold_left Nameset.union Nameset.empty (List.map snd fvs) + ) | DEF_scattered sdef -> fv_of_scattered consider_var consider_scatter_as_one all_defs sdef | DEF_register rdec -> fv_of_rd consider_var rdec - | DEF_pragma _ -> mt,mt + | DEF_pragma _ -> (mt, mt) (* removed beforehand for Coq, but may still be present otherwise *) - | DEF_measure(id,pat,exp) -> - let i = string_of_id id in - let used = Nameset.of_list [i; "val:"^i] in - ((fun (_,u,_) -> Nameset.singleton ("measure:"^i),u) - (fv_of_pes consider_var mt used mt - [Pat_aux(Pat_exp (pat,exp),(Unknown,Type_check.empty_tannot))])) + | DEF_measure (id, pat, exp) -> + let i = string_of_id id in + let used = Nameset.of_list [i; "val:" ^ i] in + (fun (_, u, _) -> (Nameset.singleton ("measure:" ^ i), u)) + (fv_of_pes consider_var mt used mt [Pat_aux (Pat_exp (pat, exp), (Unknown, Type_check.empty_tannot))]) | DEF_loop_measures _ | DEF_impl _ | DEF_outcome _ | DEF_instantiation _ -> - Reporting.unreachable (def_loc def) __POS__ - "Found definition that should have been rewritten previously" + Reporting.unreachable (def_loc def) __POS__ "Found definition that should have been rewritten previously" (* * Sorting definitions, take 3 *) -module Namemap = Map.Make(String) -module NameGraph = Graph.Make(String) +module Namemap = Map.Make (String) +module NameGraph = Graph.Make (String) let add_def_to_map id d defset = - Namemap.add id - (match Namemap.find id defset with - | t -> t@[d] - | exception Not_found -> [d]) - defset + Namemap.add id (match Namemap.find id defset with t -> t @ [d] | exception Not_found -> [d]) defset let add_def_to_graph (prelude, original_order, defset, graph) d = let bound, used = fv_of_def false true [] d in - let used = match d with + let used = + match d with | DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, _, _), annot)), _) -> - (* For a register, we need to ensure that any undefined_type - functions for types used by the register are placed before - the register declaration. *) - let env = Type_check.env_of_annot annot in - let typ' = Type_check.Env.expand_synonyms env typ in - let undefineds = Nameset.map (fun name -> "undefined_" ^ name) (free_type_names_t false typ') in - Nameset.union undefineds used + (* For a register, we need to ensure that any undefined_type + functions for types used by the register are placed before + the register declaration. *) + let env = Type_check.env_of_annot annot in + let typ' = Type_check.Env.expand_synonyms env typ in + let undefineds = Nameset.map (fun name -> "undefined_" ^ name) (free_type_names_t false typ') in + Nameset.union undefineds used | _ -> used in try @@ -578,90 +577,85 @@ let add_def_to_graph (prelude, original_order, defset, graph) d = NameGraph.add_edges id (other_ids @ Nameset.elements used) graph |> List.fold_right (fun id' g -> NameGraph.add_edge id' id g) other_ids in - prelude, - original_order @ [id], - add_def_to_map id d defset, - graph' - with - | Not_found -> - (* Some definitions do not bind any identifiers at all. This *should* - only happen for default bitvector order declarations, operator fixity - declarations, and comments. The sorting does not (currently) attempt - to preserve the positions of these AST nodes; they are collected - separately and placed at the beginning of the output. Comments are - currently ignored by the Lem and OCaml backends, anyway. For - default order and fixity declarations, this means that specifications - currently have to assume those declarations are moved to the - beginning when using a backend that requires topological sorting. *) - prelude @ [d], original_order, defset, graph + (prelude, original_order @ [id], add_def_to_map id d defset, graph') + with Not_found -> + (* Some definitions do not bind any identifiers at all. This *should* + only happen for default bitvector order declarations, operator fixity + declarations, and comments. The sorting does not (currently) attempt + to preserve the positions of these AST nodes; they are collected + separately and placed at the beginning of the output. Comments are + currently ignored by the Lem and OCaml backends, anyway. For + default order and fixity declarations, this means that specifications + currently have to assume those declarations are moved to the + beginning when using a backend that requires topological sorting. *) + (prelude @ [d], original_order, defset, graph) let def_of_component graph defset comp = let get_def id = if Namemap.mem id defset then Namemap.find id defset else [] in match List.concat (List.map get_def comp) with | [] -> [] | [def] -> [def] - | ((DEF_aux ((DEF_fundef _ | DEF_internal_mutrec _), _) as def) :: _) as defs -> - let get_fundefs = function - | DEF_aux (DEF_fundef fundef, _) -> [fundef] - | DEF_aux (DEF_internal_mutrec fundefs, _) -> fundefs - | _ -> - raise (Reporting.err_unreachable (def_loc def) __POS__ - "Trying to merge non-function definition with mutually recursive functions") in - let fundefs = List.concat (List.map get_fundefs defs) in - (* print_dot graph (List.map (fun fd -> string_of_id (id_of_fundef fd)) fundefs); *) - [mk_def (DEF_internal_mutrec fundefs)] + | (DEF_aux ((DEF_fundef _ | DEF_internal_mutrec _), _) as def) :: _ as defs -> + let get_fundefs = function + | DEF_aux (DEF_fundef fundef, _) -> [fundef] + | DEF_aux (DEF_internal_mutrec fundefs, _) -> fundefs + | _ -> + raise + (Reporting.err_unreachable (def_loc def) __POS__ + "Trying to merge non-function definition with mutually recursive functions" + ) + in + let fundefs = List.concat (List.map get_fundefs defs) in + (* print_dot graph (List.map (fun fd -> string_of_id (id_of_fundef fd)) fundefs); *) + [mk_def (DEF_internal_mutrec fundefs)] (* We could merge other stuff, in particular overloads, but don't need to just now *) | defs -> defs let top_sort_defs ast = let prelude, original_order, defset, graph = - List.fold_left add_def_to_graph ([], [], Namemap.empty, Namemap.empty) ast.defs in - let components = NameGraph.scc ~original_order:original_order graph in + List.fold_left add_def_to_graph ([], [], Namemap.empty, Namemap.empty) ast.defs + in + let components = NameGraph.scc ~original_order graph in { ast with defs = prelude @ List.concat (List.map (def_of_component graph defset) components) } (* Functions for finding the set of variables assigned to. Used in constant propagation and monomorphisation. *) let assigned_vars exp = - (Rewriter.fold_exp - { (Rewriter.pure_exp_alg IdSet.empty IdSet.union) with - Rewriter.le_id = (fun id -> IdSet.singleton id); - Rewriter.le_typ = (fun (ty,id) -> IdSet.singleton id) } - exp) + Rewriter.fold_exp + { + (Rewriter.pure_exp_alg IdSet.empty IdSet.union) with + Rewriter.le_id = (fun id -> IdSet.singleton id); + Rewriter.le_typ = (fun (ty, id) -> IdSet.singleton id); + } + exp let assigned_vars_in_fexps fes = - List.fold_left - (fun vs (FE_aux (FE_fexp (_,e),_)) -> IdSet.union vs (assigned_vars e)) - IdSet.empty - fes + List.fold_left (fun vs (FE_aux (FE_fexp (_, e), _)) -> IdSet.union vs (assigned_vars e)) IdSet.empty fes -let assigned_vars_in_pexp (Pat_aux (p,_)) = +let assigned_vars_in_pexp (Pat_aux (p, _)) = match p with - | Pat_exp (_,e) -> assigned_vars e - | Pat_when (p,e1,e2) -> IdSet.union (assigned_vars e1) (assigned_vars e2) + | Pat_exp (_, e) -> assigned_vars e + | Pat_when (p, e1, e2) -> IdSet.union (assigned_vars e1) (assigned_vars e2) -let rec assigned_vars_in_lexp (LE_aux (le,_)) = +let rec assigned_vars_in_lexp (LE_aux (le, _)) = match le with - | LE_id id - | LE_typ (_,id) -> IdSet.singleton id - | LE_tuple lexps - | LE_vector_concat lexps -> - List.fold_left (fun vs le -> IdSet.union vs (assigned_vars_in_lexp le)) IdSet.empty lexps - | LE_app (_,es) -> List.fold_left (fun vs e -> IdSet.union vs (assigned_vars e)) IdSet.empty es - | LE_vector (le,e) -> IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) - | LE_vector_range (le,e1,e2) -> - IdSet.union (assigned_vars_in_lexp le) (IdSet.union (assigned_vars e1) (assigned_vars e2)) - | LE_field (le,_) -> assigned_vars_in_lexp le + | LE_id id | LE_typ (_, id) -> IdSet.singleton id + | LE_tuple lexps | LE_vector_concat lexps -> + List.fold_left (fun vs le -> IdSet.union vs (assigned_vars_in_lexp le)) IdSet.empty lexps + | LE_app (_, es) -> List.fold_left (fun vs e -> IdSet.union vs (assigned_vars e)) IdSet.empty es + | LE_vector (le, e) -> IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) + | LE_vector_range (le, e1, e2) -> + IdSet.union (assigned_vars_in_lexp le) (IdSet.union (assigned_vars e1) (assigned_vars e2)) + | LE_field (le, _) -> assigned_vars_in_lexp le | LE_deref e -> assigned_vars e let bound_vars exp = let open Rewriter in - let pat_alg = { - (pure_pat_alg IdSet.empty IdSet.union) with - p_id = IdSet.singleton; - p_as = (fun (ids, id) -> IdSet.add id ids) - } in - fold_exp { (pure_exp_alg IdSet.empty IdSet.union) with pat_alg = pat_alg } exp + let pat_alg = + { (pure_pat_alg IdSet.empty IdSet.union) with p_id = IdSet.singleton; p_as = (fun (ids, id) -> IdSet.add id ids) } + in + fold_exp { (pure_exp_alg IdSet.empty IdSet.union) with pat_alg } exp let pat_id_is_variable env id = match Type_check.Env.lookup_id id env with @@ -670,43 +664,35 @@ let pat_id_is_variable env id = | Unbound _ (* Shadowing of immutable locals is allowed; mutable locals and registers are rejected by the type checker, so don't matter *) - | Local _ - | Register _ - -> true + | Local _ | Register _ -> + true | Enum _ -> false let bindings_from_pat p = - let rec aux_pat (P_aux (p,(l,annot))) = + let rec aux_pat (P_aux (p, (l, annot))) = let env = Type_check.env_of_annot (l, annot) in match p with - | P_lit _ - | P_wild - -> [] + | P_lit _ | P_wild -> [] | P_or (p1, p2) -> aux_pat p1 @ aux_pat p2 - | P_not (p) -> aux_pat p - | P_as (p,id) -> id::(aux_pat p) - | P_typ (_,p) -> aux_pat p + | P_not p -> aux_pat p + | P_as (p, id) -> id :: aux_pat p + | P_typ (_, p) -> aux_pat p | P_vector_subrange (id, _, _) -> [id] - | P_id id -> - if pat_id_is_variable env id then [id] else [] - | P_var (p,kid) -> aux_pat p - | P_vector ps - | P_vector_concat ps - | P_string_append ps - | P_app (_,ps) - | P_tuple ps - | P_list ps - -> List.concat (List.map aux_pat ps) - | P_cons (p1,p2) -> aux_pat p1 @ aux_pat p2 - in aux_pat p + | P_id id -> if pat_id_is_variable env id then [id] else [] + | P_var (p, kid) -> aux_pat p + | P_vector ps | P_vector_concat ps | P_string_append ps | P_app (_, ps) | P_tuple ps | P_list ps -> + List.concat (List.map aux_pat ps) + | P_cons (p1, p2) -> aux_pat p1 @ aux_pat p2 + in + aux_pat p (* TODO: replace the below with solutions that don't depend so much on the structure of the environment. *) let rec flatten_constraints = function | [] -> [] - | (NC_aux (NC_and (nc1,nc2),_))::t -> flatten_constraints (nc1::nc2::t) - | h::t -> h::(flatten_constraints t) + | NC_aux (NC_and (nc1, nc2), _) :: t -> flatten_constraints (nc1 :: nc2 :: t) + | h :: t -> h :: flatten_constraints t (* NB: this only looks for direct equalities with the given kid. It would be better in principle to find the entire set of equal kids, but it isn't @@ -714,10 +700,10 @@ let rec flatten_constraints = function checking P_var patterns, so we don't do it for now. *) let equal_kids_ncs kid ncs = let rec add_equal_kids_nc s = function - | NC_aux (NC_equal (Nexp_aux (Nexp_var var1,_), Nexp_aux (Nexp_var var2,_)),_) -> - if Kid.compare kid var1 == 0 then KidSet.add var2 s else - if Kid.compare kid var2 == 0 then KidSet.add var1 s else - s + | NC_aux (NC_equal (Nexp_aux (Nexp_var var1, _), Nexp_aux (Nexp_var var2, _)), _) -> + if Kid.compare kid var1 == 0 then KidSet.add var2 s + else if Kid.compare kid var2 == 0 then KidSet.add var1 s + else s | NC_aux (NC_and (nc1, nc2), _) -> add_equal_kids_nc (add_equal_kids_nc s nc1) nc2 | _ -> s in @@ -727,110 +713,98 @@ let equal_kids env kid = let ncs = flatten_constraints (Type_check.Env.get_constraints env) in equal_kids_ncs kid ncs - - (* TODO: kid shadowing *) let nexp_subst_fns substs = let s_t t = subst_kids_typ substs t in -(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in - hopefully don't need this anyway *)(* + (* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in + hopefully don't need this anyway *) + (* let s_typschm tsh = tsh in*) let s_tannot tannot = match Type_check.destruct_tannot tannot with | None -> Type_check.empty_tannot - | Some (env,t) -> Type_check.mk_tannot env (s_t t) (* TODO: what about env? *) + | Some (env, t) -> Type_check.mk_tannot env (s_t t) + (* TODO: what about env? *) in - let rec s_pat (P_aux (p,(l,annot))) = - let re p = P_aux (p,(l,s_tannot annot)) in + let rec s_pat (P_aux (p, (l, annot))) = + let re p = P_aux (p, (l, s_tannot annot)) in match p with | P_lit _ | P_wild | P_id _ | P_vector_subrange _ -> re p | P_or (p1, p2) -> re (P_or (s_pat p1, s_pat p2)) - | P_not (p) -> re (P_not (s_pat p)) - | P_var (p',tpat) -> re (P_var (s_pat p',tpat)) - | P_as (p',id) -> re (P_as (s_pat p', id)) - | P_typ (ty,p') -> re (P_typ (s_t ty,s_pat p')) - | P_app (id,ps) -> re (P_app (id, List.map s_pat ps)) + | P_not p -> re (P_not (s_pat p)) + | P_var (p', tpat) -> re (P_var (s_pat p', tpat)) + | P_as (p', id) -> re (P_as (s_pat p', id)) + | P_typ (ty, p') -> re (P_typ (s_t ty, s_pat p')) + | P_app (id, ps) -> re (P_app (id, List.map s_pat ps)) | P_vector ps -> re (P_vector (List.map s_pat ps)) | P_vector_concat ps -> re (P_vector_concat (List.map s_pat ps)) | P_string_append ps -> re (P_string_append (List.map s_pat ps)) | P_tuple ps -> re (P_tuple (List.map s_pat ps)) | P_list ps -> re (P_list (List.map s_pat ps)) - | P_cons (p1,p2) -> re (P_cons (s_pat p1, s_pat p2)) + | P_cons (p1, p2) -> re (P_cons (s_pat p1, s_pat p2)) in - let rec s_exp (E_aux (e,(l,annot))) = - let re e = E_aux (e,(l,s_tannot annot)) in - match e with - | E_block es -> re (E_block (List.map s_exp es)) - | E_id _ - | E_ref _ - | E_lit _ - | E_internal_value _ - -> re e - | E_sizeof ne -> begin - let ne' = subst_kids_nexp substs ne in - match ne' with - | Nexp_aux (Nexp_constant i,l) -> re (E_lit (L_aux (L_num i,l))) - | _ -> re (E_sizeof ne') + let rec s_exp (E_aux (e, (l, annot))) = + let re e = E_aux (e, (l, s_tannot annot)) in + match e with + | E_block es -> re (E_block (List.map s_exp es)) + | E_id _ | E_ref _ | E_lit _ | E_internal_value _ -> re e + | E_sizeof ne -> begin + let ne' = subst_kids_nexp substs ne in + match ne' with Nexp_aux (Nexp_constant i, l) -> re (E_lit (L_aux (L_num i, l))) | _ -> re (E_sizeof ne') end - | E_constraint nc -> re (E_constraint (subst_kids_nc substs nc)) - | E_typ (t,e') -> re (E_typ (s_t t, s_exp e')) - | E_app (id,es) -> re (E_app (id, List.map s_exp es)) - | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2)) - | E_tuple es -> re (E_tuple (List.map s_exp es)) - | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) - | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,s_exp e1,s_exp e2,s_exp e3,ord,s_exp e4)) - | E_loop (loop,m,e1,e2) -> re (E_loop (loop,s_measure m,s_exp e1,s_exp e2)) - | E_vector es -> re (E_vector (List.map s_exp es)) - | E_vector_access (e1,e2) -> re (E_vector_access (s_exp e1,s_exp e2)) - | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (s_exp e1,s_exp e2,s_exp e3)) - | E_vector_update (e1,e2,e3) -> re (E_vector_update (s_exp e1,s_exp e2,s_exp e3)) - | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (s_exp e1,s_exp e2,s_exp e3,s_exp e4)) - | E_vector_append (e1,e2) -> re (E_vector_append (s_exp e1,s_exp e2)) - | E_list es -> re (E_list (List.map s_exp es)) - | E_cons (e1,e2) -> re (E_cons (s_exp e1,s_exp e2)) - | E_struct fes -> re (E_struct (List.map s_fexp fes)) - | E_struct_update (e,fes) -> re (E_struct_update (s_exp e, List.map s_fexp fes)) - | E_field (e,id) -> re (E_field (s_exp e,id)) - | E_match (e,cases) -> re (E_match (s_exp e, List.map s_pexp cases)) - | E_let (lb,e) -> re (E_let (s_letbind lb, s_exp e)) - | E_assign (le,e) -> re (E_assign (s_lexp le, s_exp e)) - | E_exit e -> re (E_exit (s_exp e)) - | E_return e -> re (E_return (s_exp e)) - | E_assert (e1,e2) -> re (E_assert (s_exp e1,s_exp e2)) - | E_var (le,e1,e2) -> re (E_var (s_lexp le, s_exp e1, s_exp e2)) - | E_internal_plet (p,e1,e2) -> re (E_internal_plet (s_pat p, s_exp e1, s_exp e2)) - | E_internal_return e -> re (E_internal_return (s_exp e)) - | E_throw e -> re (E_throw (s_exp e)) - | E_try (e,cases) -> re (E_try (s_exp e, List.map s_pexp cases)) - | E_internal_assume (nc, e) -> re (E_internal_assume (subst_kids_nc substs nc, s_exp e)) - and s_measure (Measure_aux (m,l)) = - let m = match m with - | Measure_none -> m - | Measure_some exp -> Measure_some (s_exp exp) - in - Measure_aux (m,l) - and s_fexp (FE_aux (FE_fexp (id,e), (l,annot))) = - FE_aux (FE_fexp (id,s_exp e),(l,s_tannot annot)) - and s_pexp = function - | (Pat_aux (Pat_exp (p,e),(l,annot))) -> - Pat_aux (Pat_exp (s_pat p, s_exp e),(l,s_tannot annot)) - | (Pat_aux (Pat_when (p,e1,e2),(l,annot))) -> - Pat_aux (Pat_when (s_pat p, s_exp e1, s_exp e2),(l,s_tannot annot)) - and s_letbind (LB_aux (lb,(l,annot))) = - match lb with - | LB_val (p,e) -> LB_aux (LB_val (s_pat p,s_exp e), (l,s_tannot annot)) - and s_lexp (LE_aux (e,(l,annot))) = - let re e = LE_aux (e,(l,s_tannot annot)) in - match e with - | LE_id _ -> re e - | LE_typ (typ,id) -> re (LE_typ (s_t typ, id)) - | LE_app (id,es) -> re (LE_app (id,List.map s_exp es)) - | LE_tuple les -> re (LE_tuple (List.map s_lexp les)) - | LE_vector (le,e) -> re (LE_vector (s_lexp le, s_exp e)) - | LE_vector_range (le,e1,e2) -> re (LE_vector_range (s_lexp le, s_exp e1, s_exp e2)) - | LE_vector_concat les -> re (LE_vector_concat (List.map s_lexp les)) - | LE_field (le,id) -> re (LE_field (s_lexp le, id)) - | LE_deref e -> re (LE_deref (s_exp e)) - in (s_pat,s_exp) + | E_constraint nc -> re (E_constraint (subst_kids_nc substs nc)) + | E_typ (t, e') -> re (E_typ (s_t t, s_exp e')) + | E_app (id, es) -> re (E_app (id, List.map s_exp es)) + | E_app_infix (e1, id, e2) -> re (E_app_infix (s_exp e1, id, s_exp e2)) + | E_tuple es -> re (E_tuple (List.map s_exp es)) + | E_if (e1, e2, e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) + | E_for (id, e1, e2, e3, ord, e4) -> re (E_for (id, s_exp e1, s_exp e2, s_exp e3, ord, s_exp e4)) + | E_loop (loop, m, e1, e2) -> re (E_loop (loop, s_measure m, s_exp e1, s_exp e2)) + | E_vector es -> re (E_vector (List.map s_exp es)) + | E_vector_access (e1, e2) -> re (E_vector_access (s_exp e1, s_exp e2)) + | E_vector_subrange (e1, e2, e3) -> re (E_vector_subrange (s_exp e1, s_exp e2, s_exp e3)) + | E_vector_update (e1, e2, e3) -> re (E_vector_update (s_exp e1, s_exp e2, s_exp e3)) + | E_vector_update_subrange (e1, e2, e3, e4) -> re (E_vector_update_subrange (s_exp e1, s_exp e2, s_exp e3, s_exp e4)) + | E_vector_append (e1, e2) -> re (E_vector_append (s_exp e1, s_exp e2)) + | E_list es -> re (E_list (List.map s_exp es)) + | E_cons (e1, e2) -> re (E_cons (s_exp e1, s_exp e2)) + | E_struct fes -> re (E_struct (List.map s_fexp fes)) + | E_struct_update (e, fes) -> re (E_struct_update (s_exp e, List.map s_fexp fes)) + | E_field (e, id) -> re (E_field (s_exp e, id)) + | E_match (e, cases) -> re (E_match (s_exp e, List.map s_pexp cases)) + | E_let (lb, e) -> re (E_let (s_letbind lb, s_exp e)) + | E_assign (le, e) -> re (E_assign (s_lexp le, s_exp e)) + | E_exit e -> re (E_exit (s_exp e)) + | E_return e -> re (E_return (s_exp e)) + | E_assert (e1, e2) -> re (E_assert (s_exp e1, s_exp e2)) + | E_var (le, e1, e2) -> re (E_var (s_lexp le, s_exp e1, s_exp e2)) + | E_internal_plet (p, e1, e2) -> re (E_internal_plet (s_pat p, s_exp e1, s_exp e2)) + | E_internal_return e -> re (E_internal_return (s_exp e)) + | E_throw e -> re (E_throw (s_exp e)) + | E_try (e, cases) -> re (E_try (s_exp e, List.map s_pexp cases)) + | E_internal_assume (nc, e) -> re (E_internal_assume (subst_kids_nc substs nc, s_exp e)) + and s_measure (Measure_aux (m, l)) = + let m = match m with Measure_none -> m | Measure_some exp -> Measure_some (s_exp exp) in + Measure_aux (m, l) + and s_fexp (FE_aux (FE_fexp (id, e), (l, annot))) = FE_aux (FE_fexp (id, s_exp e), (l, s_tannot annot)) + and s_pexp = function + | Pat_aux (Pat_exp (p, e), (l, annot)) -> Pat_aux (Pat_exp (s_pat p, s_exp e), (l, s_tannot annot)) + | Pat_aux (Pat_when (p, e1, e2), (l, annot)) -> Pat_aux (Pat_when (s_pat p, s_exp e1, s_exp e2), (l, s_tannot annot)) + and s_letbind (LB_aux (lb, (l, annot))) = + match lb with LB_val (p, e) -> LB_aux (LB_val (s_pat p, s_exp e), (l, s_tannot annot)) + and s_lexp (LE_aux (e, (l, annot))) = + let re e = LE_aux (e, (l, s_tannot annot)) in + match e with + | LE_id _ -> re e + | LE_typ (typ, id) -> re (LE_typ (s_t typ, id)) + | LE_app (id, es) -> re (LE_app (id, List.map s_exp es)) + | LE_tuple les -> re (LE_tuple (List.map s_lexp les)) + | LE_vector (le, e) -> re (LE_vector (s_lexp le, s_exp e)) + | LE_vector_range (le, e1, e2) -> re (LE_vector_range (s_lexp le, s_exp e1, s_exp e2)) + | LE_vector_concat les -> re (LE_vector_concat (List.map s_lexp les)) + | LE_field (le, id) -> re (LE_field (s_lexp le, id)) + | LE_deref e -> re (LE_deref (s_exp e)) + in + (s_pat, s_exp) let nexp_subst_pat substs = fst (nexp_subst_fns substs) let nexp_subst_exp substs = snd (nexp_subst_fns substs) diff --git a/src/lib/spec_analysis.mli b/src/lib/spec_analysis.mli index 3f9b8d60a..d00e65cc2 100644 --- a/src/lib/spec_analysis.mli +++ b/src/lib/spec_analysis.mli @@ -71,16 +71,16 @@ open Ast_util open Util open Type_check -(*Determines if the first typ is within the range of the the second typ, - using the constraints provided when the first typ contains variables. +(*Determines if the first typ is within the range of the the second typ, + using the constraints provided when the first typ contains variables. It is an error for second typ to be anything other than a range type - If the first typ is a vector, then determines if the max representable + If the first typ is a vector, then determines if the max representable number is in the range of the second; it is an error for the first typ to be anything other than a vector, a range, an atom, or a bit (after - suitable unwrapping of abbreviations, reg, and registers). + suitable unwrapping of abbreviations, reg, and registers). *) (* val is_within_range: typ -> typ -> nexp_range list -> triple -val is_within_machine64 : typ -> nexp_range list -> triple *) + val is_within_machine64 : typ -> nexp_range list -> triple *) (* free variables and dependencies *) @@ -98,12 +98,14 @@ val top_sort_defs : tannot ast -> tannot ast (** Return the set of mutable variables assigned to in the given AST. *) val assigned_vars : 'a exp -> IdSet.t + val assigned_vars_in_fexps : 'a fexp list -> IdSet.t val assigned_vars_in_pexp : 'a pexp -> IdSet.t val assigned_vars_in_lexp : 'a lexp -> IdSet.t (** Variable bindings in patterns and expressions *) val pat_id_is_variable : env -> id -> bool + val bindings_from_pat : tannot pat -> id list val bound_vars : 'a exp -> IdSet.t @@ -113,4 +115,5 @@ val equal_kids : env -> kid -> KidSet.t (** Type-level substitutions into patterns and expressions. Also attempts to update type annotations, but not the associated environments. *) val nexp_subst_pat : nexp KBindings.t -> tannot pat -> tannot pat + val nexp_subst_exp : nexp KBindings.t -> tannot exp -> tannot exp diff --git a/src/lib/specialize.ml b/src/lib/specialize.ml index 34fc8232c..1535233f6 100644 --- a/src/lib/specialize.ml +++ b/src/lib/specialize.ml @@ -72,50 +72,54 @@ open Rewriter let opt_ddump_spec_ast = ref None -let is_typ_ord_arg = function - | A_aux (A_typ _, _) -> true - | A_aux (A_order _, _) -> true - | _ -> false +let is_typ_ord_arg = function A_aux (A_typ _, _) -> true | A_aux (A_order _, _) -> true | _ -> false type specialization = { - is_polymorphic : kinded_id -> bool; - instantiation_filter : kid -> typ_arg -> bool; - extern_filter : extern option -> bool - } + is_polymorphic : kinded_id -> bool; + instantiation_filter : kid -> typ_arg -> bool; + extern_filter : extern option -> bool; +} -let typ_ord_specialization = { +let typ_ord_specialization = + { is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt); instantiation_filter = (fun _ -> is_typ_ord_arg); - extern_filter = (fun _ -> false) + extern_filter = (fun _ -> false); } -let int_specialization = { +let int_specialization = + { is_polymorphic = is_int_kopt; - instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _) -> true | _ -> false); - extern_filter = (fun externs -> match Ast_util.extern_assoc "c" externs with Some _ -> true | None -> false) + instantiation_filter = + (fun _ arg -> match arg with A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _) -> true | _ -> false); + extern_filter = (fun externs -> match Ast_util.extern_assoc "c" externs with Some _ -> true | None -> false); } -let int_specialization_with_externs = { +let int_specialization_with_externs = + { is_polymorphic = is_int_kopt; - instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _) -> true | _ -> false); - extern_filter = (fun _ -> false) + instantiation_filter = + (fun _ arg -> match arg with A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _) -> true | _ -> false); + extern_filter = (fun _ -> false); } let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = - let typ_aux = match typ_aux with + let typ_aux = + match typ_aux with | Typ_id v -> Typ_id v | Typ_var kid -> Typ_var kid | Typ_tuple typs -> Typ_tuple (List.map nexp_simp_typ typs) | Typ_app (f, args) -> Typ_app (f, List.map nexp_simp_typ_arg args) | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, nexp_simp_typ typ) - | Typ_fn (arg_typs, ret_typ) -> - Typ_fn (List.map nexp_simp_typ arg_typs, nexp_simp_typ ret_typ) + | Typ_fn (arg_typs, ret_typ) -> Typ_fn (List.map nexp_simp_typ arg_typs, nexp_simp_typ ret_typ) | Typ_bidir (t1, t2) -> Typ_bidir (nexp_simp_typ t1, nexp_simp_typ t2) | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" in Typ_aux (typ_aux, l) + and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) = - let typ_arg_aux = match typ_arg_aux with + let typ_arg_aux = + match typ_arg_aux with | A_nexp n -> A_nexp (nexp_simp n) | A_typ typ -> A_typ (nexp_simp_typ typ) | A_order ord -> A_order ord @@ -127,7 +131,7 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) = This part of the typechecker API is a bit ugly. *) let fix_instantiation spec instantiation = let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in - let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in + let instantiation = List.map (fun (kid, arg) -> (Type_check.orig_kid kid, nexp_simp_typ_arg arg)) instantiation in List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation (* polymorphic_functions returns all functions that are polymorphic @@ -136,12 +140,10 @@ let fix_instantiation spec instantiation = return all Int-polymorphic functions. *) let rec polymorphic_functions ctx defs = match defs with - | DEF_aux (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, externs, _), _)), _) :: defs -> - let is_polymorphic = List.exists ctx.is_polymorphic (quant_kopts typq) in - if is_polymorphic && not (ctx.extern_filter externs) then - IdSet.add id (polymorphic_functions ctx defs) - else - polymorphic_functions ctx defs + | DEF_aux (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), _), id, externs, _), _)), _) :: defs -> + let is_polymorphic = List.exists ctx.is_polymorphic (quant_kopts typq) in + if is_polymorphic && not (ctx.extern_filter externs) then IdSet.add id (polymorphic_functions ctx defs) + else polymorphic_functions ctx defs | _ :: defs -> polymorphic_functions ctx defs | [] -> IdSet.empty @@ -156,17 +158,16 @@ let string_of_instantiation instantiation = let kid_names = ref KOptMap.empty in let kid_counter = ref 0 in let kid_name kid = - try KOptMap.find kid !kid_names with - | Not_found -> - let n = string_of_int !kid_counter in - kid_names := KOptMap.add kid n !kid_names; - incr kid_counter; - n + try KOptMap.find kid !kid_names + with Not_found -> + let n = string_of_int !kid_counter in + kid_names := KOptMap.add kid n !kid_names; + incr kid_counter; + n in (* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *) - let rec string_of_nexp = function - | Nexp_aux (nexp, _) -> string_of_nexp_aux nexp + let rec string_of_nexp = function Nexp_aux (nexp, _) -> string_of_nexp_aux nexp and string_of_nexp_aux = function | Nexp_id id -> string_of_id id | Nexp_var kid -> kid_name (mk_kopt K_int kid) @@ -179,22 +180,19 @@ let string_of_instantiation instantiation = | Nexp_neg n -> "- " ^ string_of_nexp n in - let rec string_of_typ = function - | Typ_aux (typ, l) -> string_of_typ_aux typ + let rec string_of_typ = function Typ_aux (typ, l) -> string_of_typ_aux typ and string_of_typ_aux = function | Typ_id id -> string_of_id id | Typ_var kid -> kid_name (mk_kopt K_type kid) | Typ_tuple typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" | Typ_fn (arg_typs, ret_typ) -> - "(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ - | Typ_bidir (t1, t2) -> - string_of_typ t1 ^ " <-> " ^ string_of_typ t2 + "(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ + | Typ_bidir (t1, t2) -> string_of_typ t1 ^ " <-> " ^ string_of_typ t2 | Typ_exist (kids, nc, typ) -> - "exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ + "exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ | Typ_internal_unknown -> "UNKNOWN" - and string_of_typ_arg = function - | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg + and string_of_typ_arg = function A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg and string_of_typ_arg_aux = function | A_nexp n -> string_of_nexp n | A_typ typ -> string_of_typ typ @@ -207,12 +205,10 @@ let string_of_instantiation instantiation = | NC_aux (NC_bounded_gt (n1, n2), _) -> string_of_nexp n1 ^ " > " ^ string_of_nexp n2 | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 | NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2 - | NC_aux (NC_or (nc1, nc2), _) -> - "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" - | NC_aux (NC_and (nc1, nc2), _) -> - "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" + | NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" + | NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_set (kid, ns), _) -> - kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" + kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" | NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid) @@ -229,7 +225,8 @@ let id_of_instantiation id instantiation = let rec variant_generic_typ id defs = match defs with | DEF_aux (DEF_type (TD_aux (TD_variant (id', typq, _, _), _)), _) :: _ when Id.compare id id' = 0 -> - mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) + mk_typ + (Typ_app (id', List.map (fun kopt -> mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) | _ :: defs -> variant_generic_typ id defs | [] -> failwith ("No variant with id " ^ string_of_id id) @@ -241,37 +238,46 @@ let instantiations_of spec id ast = let inspect_exp = function | E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in - instantiations := instantiation :: !instantiations; - exp + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in + instantiations := instantiation :: !instantiations; + exp | exp -> exp in (* We need to to check patterns in case id is a union constructor that is never called like a function. *) let inspect_pat = function - | P_aux (P_app (id', _), annot) as pat when Id.compare id id' = 0 -> - begin match Type_check.typ_of_annot annot with - | Typ_aux (Typ_app (variant_id, _), _) as typ -> - let open Type_check in - let instantiation = unify (fst annot) (env_of_annot annot) - (tyvars_of_typ (variant_generic_typ variant_id ast.defs)) - (variant_generic_typ variant_id ast.defs) - typ - in - instantiations := fix_instantiation spec instantiation :: !instantiations; - pat - | Typ_aux (Typ_id variant_id, _) -> pat - | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") - end + | P_aux (P_app (id', _), annot) as pat when Id.compare id id' = 0 -> begin + match Type_check.typ_of_annot annot with + | Typ_aux (Typ_app (variant_id, _), _) as typ -> + let open Type_check in + let instantiation = + unify (fst annot) (env_of_annot annot) + (tyvars_of_typ (variant_generic_typ variant_id ast.defs)) + (variant_generic_typ variant_id ast.defs) + typ + in + instantiations := fix_instantiation spec instantiation :: !instantiations; + pat + | Typ_aux (Typ_id variant_id, _) -> pat + | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") + end | pat -> pat in let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> inspect_pat (P_aux (pat, annot))) } in - let rewrite_exp = { id_exp_alg with pat_alg = rewrite_pat; - e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in - let _ = rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp); - rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast in + let rewrite_exp = + { id_exp_alg with pat_alg = rewrite_pat; e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } + in + let _ = + rewrite_ast_base + { + rewriters_base with + rewrite_exp = (fun _ -> fold_exp rewrite_exp); + rewrite_pat = (fun _ -> fold_pat rewrite_pat); + } + ast + in !instantiations @@ -280,58 +286,57 @@ let rewrite_polymorphic_calls spec id ast = let rewrite_e_aux = function | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in - let spec_id = id_of_instantiation id instantiation in - (* Make sure we only generate specialized calls when we've - specialized the valspec. The valspec may not be generated if - a polymorphic function calls another polymorphic function. - In this case a specialization of the first may require that - the second needs to be specialized again, but this may not - have happened yet. *) - if IdSet.mem spec_id vs_ids then - E_aux (E_app (spec_id, args), annot) - else - exp + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in + let spec_id = id_of_instantiation id instantiation in + (* Make sure we only generate specialized calls when we've + specialized the valspec. The valspec may not be generated if + a polymorphic function calls another polymorphic function. + In this case a specialization of the first may require that + the second needs to be specialized again, but this may not + have happened yet. *) + if IdSet.mem spec_id vs_ids then E_aux (E_app (spec_id, args), annot) else exp | exp -> exp in let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast -let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = +let rec typ_frees ?(exs = KidSet.empty) (Typ_aux (typ_aux, l)) = match typ_aux with | Typ_id v -> KidSet.empty | Typ_var kid when KidSet.mem kid exs -> KidSet.empty | Typ_var kid -> KidSet.singleton kid - | Typ_tuple typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs) - | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args) + | Typ_tuple typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs) typs) + | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs) args) | Typ_exist (kopts, nc, typ) -> typ_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ) -> - List.fold_left KidSet.union (typ_frees ~exs:exs ret_typ) (List.map (typ_frees ~exs:exs) arg_typs) - | Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs:exs t1) (typ_frees ~exs:exs t2) + List.fold_left KidSet.union (typ_frees ~exs ret_typ) (List.map (typ_frees ~exs) arg_typs) + | Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs t1) (typ_frees ~exs t2) | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" -and typ_arg_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = + +and typ_arg_frees ?(exs = KidSet.empty) (A_aux (typ_arg_aux, l)) = match typ_arg_aux with | A_nexp n -> KidSet.empty - | A_typ typ -> typ_frees ~exs:exs typ + | A_typ typ -> typ_frees ~exs typ | A_order ord -> KidSet.empty | A_bool _ -> KidSet.empty -let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = +let rec typ_int_frees ?(exs = KidSet.empty) (Typ_aux (typ_aux, l)) = match typ_aux with | Typ_id v -> KidSet.empty | Typ_var kid -> KidSet.empty - | Typ_tuple typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs:exs) typs) - | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs:exs) args) + | Typ_tuple typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs) typs) + | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs) args) | Typ_exist (kopts, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ) -> - List.fold_left KidSet.union (typ_int_frees ~exs:exs ret_typ) (List.map (typ_int_frees ~exs:exs) arg_typs) - | Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs:exs t1) (typ_int_frees ~exs:exs t2) + List.fold_left KidSet.union (typ_int_frees ~exs ret_typ) (List.map (typ_int_frees ~exs) arg_typs) + | Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs t1) (typ_int_frees ~exs t2) | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" -and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = + +and typ_arg_int_frees ?(exs = KidSet.empty) (A_aux (typ_arg_aux, l)) = match typ_arg_aux with | A_nexp n -> KidSet.diff (tyvars_of_nexp n) exs - | A_typ typ -> typ_int_frees ~exs:exs typ + | A_typ typ -> typ_int_frees ~exs typ | A_order ord -> KidSet.empty | A_bool _ -> KidSet.empty @@ -349,14 +354,13 @@ let rec remove_implicit (Typ_aux (aux, l)) = | Typ_id id -> Typ_aux (Typ_id id, l) | Typ_exist (kopts, nc, typ) -> Typ_aux (Typ_exist (kopts, nc, remove_implicit typ), l) | Typ_var v -> Typ_aux (Typ_var v, l) + and remove_implicit_arg (A_aux (aux, l)) = - match aux with - | A_typ typ -> A_aux (A_typ (remove_implicit typ), l) - | arg -> A_aux (arg, l) + match aux with A_typ typ -> A_aux (A_typ (remove_implicit typ), l) | arg -> A_aux (arg, l) let kopt_arg = function | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> arg_nexp (nvar kid) - | KOpt_aux (KOpt_kind (K_aux (K_type,_), kid), _) -> arg_typ (mk_typ (Typ_var kid)) + | KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _) -> arg_typ (mk_typ (Typ_var kid)) | KOpt_aux (KOpt_kind (K_aux (K_bool, _), kid), _) -> arg_bool (nc_var kid) | KOpt_aux (KOpt_kind (K_aux (K_order, _), kid), _) -> arg_order (mk_ord (Ord_var kid)) @@ -377,9 +381,13 @@ let safe_instantiation instantiation = |> List.fold_left KOptSet.union KOptSet.empty |> KOptSet.elements in - List.fold_left (fun (i, r) v -> KBindings.map (fun arg -> subst_kid typ_arg_subst (kopt_kid v) (prepend_kid "i#" (kopt_kid v)) arg) i, - KBindings.add (prepend_kid "i#" (kopt_kid v)) (kopt_arg v) r) - (instantiation, KBindings.empty) args + List.fold_left + (fun (i, r) v -> + ( KBindings.map (fun arg -> subst_kid typ_arg_subst (kopt_kid v) (prepend_kid "i#" (kopt_kid v)) arg) i, + KBindings.add (prepend_kid "i#" (kopt_kid v)) (kopt_arg v) r + ) + ) + (instantiation, KBindings.empty) args let instantiate_constraints instantiation ncs = List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) ncs @@ -388,66 +396,85 @@ let specialize_id_valspec spec instantiations id ast effect_info = match split_defs (is_valspec id) ast.defs with | None -> Reporting.unreachable (id_loc id) __POS__ ("Valspec " ^ string_of_id id ^ " does not exist!") | Some (pre_defs, vs, post_defs) -> - let typschm, externs, is_cast, annot, def_annot = match vs with - | DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)), def_annot) -> typschm, externs, is_cast, annot, def_annot - | _ -> Reporting.unreachable (id_loc id) __POS__ "val-spec is not actually a val-spec" - in - let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in - - (* Keep track of the specialized ids to avoid generating things twice. *) - let spec_ids = ref IdSet.empty in - - let specialize_instance instantiation = - let uninstantiated = quant_kopts typq |> List.map kopt_kid |> List.filter (fun v -> not (KBindings.mem v instantiation)) |> KidSet.of_list in - - (* Collect any new type variables introduced by the instantiation *) - let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in - let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in - let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids in - - let typq, typ = - List.fold_left (fun (typq, typ) free -> - if KidSet.mem free uninstantiated then - let fresh_v = prepend_kid "o#" free in - typquant_subst_kid free fresh_v typq, subst_kid typ_subst free fresh_v typ - else - typq, typ - ) (typq, typ) (typ_frees @ int_frees) - in - - let safe_instantiation, reverse = safe_instantiation instantiation in - (* Replace the polymorphic type variables in the type with their concrete instantiation. *) - let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in - - (* Remove type variables from the type quantifier. *) - let kopts, constraints = quant_split typq in - let constraints = instantiate_constraints safe_instantiation constraints in - let constraints = instantiate_constraints reverse constraints in - let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt && KBindings.mem (kopt_kid kopt) safe_instantiation)) kopts in - let typq = - if List.length (typ_frees @ int_frees) = 0 && List.length kopts = 0 then - mk_typquant [] - else - mk_typquant (List.map (mk_qi_id K_type) typ_frees - @ List.map (mk_qi_id K_int) int_frees - @ List.map mk_qi_kopt kopts - @ List.map mk_qi_nc constraints) in - let typschm = mk_typschm typq typ in - - let spec_id = id_of_instantiation id instantiation in - - if IdSet.mem spec_id !spec_ids then [] else - begin - spec_ids := IdSet.add spec_id !spec_ids; - [DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, spec_id, externs, is_cast), annot)), def_annot)] - end - in - - let specializations = List.map specialize_instance instantiations |> List.concat in - - let effect_info = IdSet.fold (fun id' effect_info -> Effects.copy_function_effect id effect_info id') !spec_ids effect_info in - - { ast with defs = pre_defs @ (vs :: specializations) @ post_defs }, effect_info + let typschm, externs, is_cast, annot, def_annot = + match vs with + | DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)), def_annot) -> + (typschm, externs, is_cast, annot, def_annot) + | _ -> Reporting.unreachable (id_loc id) __POS__ "val-spec is not actually a val-spec" + in + let (TypSchm_aux (TypSchm_ts (typq, typ), _)) = typschm in + + (* Keep track of the specialized ids to avoid generating things twice. *) + let spec_ids = ref IdSet.empty in + + let specialize_instance instantiation = + let uninstantiated = + quant_kopts typq |> List.map kopt_kid + |> List.filter (fun v -> not (KBindings.mem v instantiation)) + |> KidSet.of_list + in + + (* Collect any new type variables introduced by the instantiation *) + let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in + let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in + let int_frees = + KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids + in + + let typq, typ = + List.fold_left + (fun (typq, typ) free -> + if KidSet.mem free uninstantiated then ( + let fresh_v = prepend_kid "o#" free in + (typquant_subst_kid free fresh_v typq, subst_kid typ_subst free fresh_v typ) + ) + else (typq, typ) + ) + (typq, typ) (typ_frees @ int_frees) + in + + let safe_instantiation, reverse = safe_instantiation instantiation in + (* Replace the polymorphic type variables in the type with their concrete instantiation. *) + let typ = + remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) + in + + (* Remove type variables from the type quantifier. *) + let kopts, constraints = quant_split typq in + let constraints = instantiate_constraints safe_instantiation constraints in + let constraints = instantiate_constraints reverse constraints in + let kopts = + List.filter + (fun kopt -> not (spec.is_polymorphic kopt && KBindings.mem (kopt_kid kopt) safe_instantiation)) + kopts + in + let typq = + if List.length (typ_frees @ int_frees) = 0 && List.length kopts = 0 then mk_typquant [] + else + mk_typquant + (List.map (mk_qi_id K_type) typ_frees + @ List.map (mk_qi_id K_int) int_frees + @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints + ) + in + let typschm = mk_typschm typq typ in + + let spec_id = id_of_instantiation id instantiation in + + if IdSet.mem spec_id !spec_ids then [] + else begin + spec_ids := IdSet.add spec_id !spec_ids; + [DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, spec_id, externs, is_cast), annot)), def_annot)] + end + in + + let specializations = List.map specialize_instance instantiations |> List.concat in + + let effect_info = + IdSet.fold (fun id' effect_info -> Effects.copy_function_effect id effect_info id') !spec_ids effect_info + in + + ({ ast with defs = pre_defs @ (vs :: specializations) @ post_defs }, effect_info) (* When we specialize a function definition we also need to specialize all the types that appear as annotations within the function @@ -455,45 +482,39 @@ let specialize_id_valspec spec instantiations id ast effect_info = because at this point we have that as a separate valspec.*) let specialize_annotations instantiation fdef = let open Type_check in - let rw_pat = { - id_pat_alg with - p_typ = (fun (typ, pat) -> P_typ (subst_unifiers instantiation typ, pat)) - } in - let rw_exp = { + let rw_pat = { id_pat_alg with p_typ = (fun (typ, pat) -> P_typ (subst_unifiers instantiation typ, pat)) } in + let rw_exp = + { id_exp_alg with e_typ = (fun (typ, exp) -> E_typ (subst_unifiers instantiation typ, exp)); le_typ = (fun (typ, lexp) -> LE_typ (subst_unifiers instantiation typ, lexp)); - pat_alg = rw_pat - } in + pat_alg = rw_pat; + } + in let fdef = - rewrite_fun { - rewriters_base with - rewrite_exp = (fun _ -> fold_exp rw_exp); - rewrite_pat = (fun _ -> fold_pat rw_pat) - } fdef + rewrite_fun + { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp); rewrite_pat = (fun _ -> fold_pat rw_pat) } + fdef in match fdef with | FD_aux (FD_function (rec_opt, _, funcls), annot) -> - FD_aux (FD_function (rec_opt, - Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown), - funcls), - annot) + FD_aux (FD_function (rec_opt, Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown), funcls), annot) let specialize_id_fundef instantiations id ast = match split_defs (is_fundef id) ast.defs with | None -> ast | Some (pre_defs, DEF_aux (DEF_fundef fundef, def_annot), post_defs) -> - let spec_ids = ref IdSet.empty in - let specialize_fundef instantiation = - let spec_id = id_of_instantiation id instantiation in - if IdSet.mem spec_id !spec_ids then [] else - begin - spec_ids := IdSet.add spec_id !spec_ids; - [DEF_aux (DEF_fundef (specialize_annotations instantiation (rename_fundef spec_id fundef)), def_annot)] - end - in - let fundefs = List.map specialize_fundef instantiations |> List.concat in - { ast with defs = pre_defs @ (DEF_aux (DEF_fundef fundef, def_annot) :: fundefs) @ post_defs } + let spec_ids = ref IdSet.empty in + let specialize_fundef instantiation = + let spec_id = id_of_instantiation id instantiation in + if IdSet.mem spec_id !spec_ids then [] + else begin + spec_ids := IdSet.add spec_id !spec_ids; + [DEF_aux (DEF_fundef (specialize_annotations instantiation (rename_fundef spec_id fundef)), def_annot)] + end + in + let fundefs = List.map specialize_fundef instantiations |> List.concat in + { ast with defs = pre_defs @ (DEF_aux (DEF_fundef fundef, def_annot) :: fundefs) @ post_defs } | Some _ -> assert false (* unreachable *) let specialize_id_overloads instantiations id ast = @@ -502,8 +523,10 @@ let specialize_id_overloads instantiations id ast = let rec rewrite_overloads defs = match defs with | DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: defs -> - let overloads = List.concat (List.map (fun id' -> if Id.compare id' id = 0 then IdSet.elements ids else [id']) overloads) in - DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: rewrite_overloads defs + let overloads = + List.concat (List.map (fun id' -> if Id.compare id' id = 0 then IdSet.elements ids else [id']) overloads) + in + DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: rewrite_overloads defs | def :: defs -> def :: rewrite_overloads defs | [] -> [] in @@ -516,30 +539,34 @@ let specialize_id_overloads instantiations id ast = valspecs are then re-specialized. This process is iterated until the whole spec is specialized. *) -let initial_calls = ref (IdSet.of_list - [ mk_id "main"; - mk_id "__InitConfig"; - mk_id "__SetConfig"; - mk_id "__ListConfig"; - mk_id "execute"; - mk_id "decode"; - mk_id "initialize_registers"; - mk_id "prop"; - mk_id "append_64" (* used to construct bitvector literals in C backend *) - ]) +let initial_calls = + ref + (IdSet.of_list + [ + mk_id "main"; + mk_id "__InitConfig"; + mk_id "__SetConfig"; + mk_id "__ListConfig"; + mk_id "execute"; + mk_id "decode"; + mk_id "initialize_registers"; + mk_id "prop"; + mk_id "append_64" (* used to construct bitvector literals in C backend *); + ] + ) let add_initial_calls ids = initial_calls := IdSet.union ids !initial_calls let get_initial_calls () = IdSet.elements !initial_calls - + let remove_unused_valspecs env ast = let calls = ref !initial_calls in let vs_ids = val_spec_ids ast.defs in let inspect_exp = function | E_aux (E_app (call, _), _) as exp -> - calls := IdSet.add call !calls; - exp + calls := IdSet.add call !calls; + exp | exp -> exp in @@ -550,16 +577,13 @@ let remove_unused_valspecs env ast = let rec remove_unused defs id = match defs with - | def :: defs when is_fundef id def -> - remove_unused defs id - | def :: defs when is_valspec id def -> - remove_unused defs id - | DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: defs -> - begin - match List.filter (fun id' -> Id.compare id id' <> 0) overloads with - | [] -> remove_unused defs id - | overloads -> DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: remove_unused defs id - end + | def :: defs when is_fundef id def -> remove_unused defs id + | def :: defs when is_valspec id def -> remove_unused defs id + | DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: defs -> begin + match List.filter (fun id' -> Id.compare id id' <> 0) overloads with + | [] -> remove_unused defs id + | overloads -> DEF_aux (DEF_overload (overload_id, overloads), def_annot) :: remove_unused defs id + end | def :: defs -> def :: remove_unused defs id | [] -> [] in @@ -570,7 +594,7 @@ let specialize_id spec id ast effect_info = let instantiations = instantiations_of spec id ast in let ast, effect_info = specialize_id_valspec spec instantiations id ast effect_info in let ast = specialize_id_fundef instantiations id ast in - specialize_id_overloads instantiations id ast, effect_info + (specialize_id_overloads instantiations id ast, effect_info) (* When we generate specialized versions of functions, we need to ensure that the types they are specialized to appear before the @@ -580,9 +604,9 @@ let reorder_typedefs ast = let tdefs = ref [] in let rec filter_typedefs = function - | DEF_aux ((DEF_default _ | DEF_type _), _) as tdef :: defs -> - tdefs := tdef :: !tdefs; - filter_typedefs defs + | (DEF_aux ((DEF_default _ | DEF_type _), _) as tdef) :: defs -> + tdefs := tdef :: !tdefs; + filter_typedefs defs | def :: defs -> def :: filter_typedefs defs | [] -> [] in @@ -596,48 +620,55 @@ let specialize_ids spec ids ast effect_info = let _, (ast, effect_info) = List.fold_left (fun (n, (ast, effect_info)) id -> - Util.progress "Specializing " (string_of_id id) n total; (n + 1, specialize_id spec id ast effect_info)) - (1, (ast, effect_info)) (IdSet.elements ids) + Util.progress "Specializing " (string_of_id id) n total; + (n + 1, specialize_id spec id ast effect_info) + ) + (1, (ast, effect_info)) + (IdSet.elements ids) in let ast = reorder_typedefs ast in - begin match !opt_ddump_spec_ast with - | Some (f, i) -> - let filename = f ^ "_spec_" ^ string_of_int i ^ ".sail" in - let out_chan = open_out filename in - Pretty_print_sail.pp_ast out_chan (Type_check.strip_ast ast); - close_out out_chan; - opt_ddump_spec_ast := Some (f, i + 1) - | None -> () + begin + match !opt_ddump_spec_ast with + | Some (f, i) -> + let filename = f ^ "_spec_" ^ string_of_int i ^ ".sail" in + let out_chan = open_out filename in + Pretty_print_sail.pp_ast out_chan (Type_check.strip_ast ast); + close_out out_chan; + opt_ddump_spec_ast := Some (f, i + 1) + | None -> () end; let ast, _ = Type_error.check Type_check.initial_env (Type_check.strip_ast ast) in let _, ast = List.fold_left (fun (n, ast) id -> Util.progress "Rewriting " (string_of_id id) n total; - (n + 1, rewrite_polymorphic_calls spec id ast)) + (n + 1, rewrite_polymorphic_calls spec id ast) + ) (1, ast) (IdSet.elements ids) in let ast, env = Type_error.check Type_check.initial_env (Type_check.strip_ast ast) in let ast = remove_unused_valspecs env ast in Profile.finish "specialization pass" t; - ast, env, effect_info + (ast, env, effect_info) let rec specialize_passes n spec env ast effect_info = - if n = 0 then - ast, env, effect_info - else + if n = 0 then (ast, env, effect_info) + else ( let ids = polymorphic_functions spec ast.defs in - if IdSet.is_empty ids then - ast, env, effect_info - else + if IdSet.is_empty ids then (ast, env, effect_info) + else ( let ast, env, effect_info = specialize_ids spec ids ast effect_info in specialize_passes (n - 1) spec env ast effect_info + ) + ) let specialize = specialize_passes (-1) let () = let open Interactive in - Action (fun istate -> - let ast', env', effect_info' = specialize typ_ord_specialization istate.env istate.ast istate.effect_info in - { istate with ast = ast'; env = env'; effect_info = effect_info' } - ) |> register_command ~name:"specialize" ~help:"Specialize Type and Order type variables in the AST" + Action + (fun istate -> + let ast', env', effect_info' = specialize typ_ord_specialization istate.env istate.ast istate.effect_info in + { istate with ast = ast'; env = env'; effect_info = effect_info' } + ) + |> register_command ~name:"specialize" ~help:"Specialize Type and Order type variables in the AST" diff --git a/src/lib/specialize.mli b/src/lib/specialize.mli index 69eb619c5..b17816fe6 100644 --- a/src/lib/specialize.mli +++ b/src/lib/specialize.mli @@ -101,11 +101,18 @@ val get_initial_calls : unit -> id list AST with [Type_check.initial_env]. The env parameter is the environment to return if there is no polymorphism to remove, in which case specialize returns the AST unmodified. *) -val specialize : specialization -> Env.t -> tannot ast -> Effects.side_effect_info -> tannot ast * Env.t * Effects.side_effect_info +val specialize : + specialization -> Env.t -> tannot ast -> Effects.side_effect_info -> tannot ast * Env.t * Effects.side_effect_info (** specialize' n performs at most n specialization passes. Useful for int_specialization which is not guaranteed to terminate. *) -val specialize_passes : int -> specialization -> Env.t -> tannot ast -> Effects.side_effect_info -> tannot ast * Env.t * Effects.side_effect_info +val specialize_passes : + int -> + specialization -> + Env.t -> + tannot ast -> + Effects.side_effect_info -> + tannot ast * Env.t * Effects.side_effect_info (** return all instantiations of a function id, with the instantiations filtered according to the specialization. *) diff --git a/src/lib/splice.ml b/src/lib/splice.ml index 0c37c47ab..577484319 100644 --- a/src/lib/splice.ml +++ b/src/lib/splice.ml @@ -77,46 +77,42 @@ open Ast_util let scan_ast { defs; _ } = let scan (ids, specs) (DEF_aux (aux, _) as def) = match aux with - | DEF_fundef fd -> - IdSet.add (id_of_fundef fd) ids, specs - | DEF_val (VS_aux (VS_val_spec (_,id,_,_),_) as vs) -> - ids, Bindings.add id vs specs - | DEF_pragma (("file_start" | "file_end"), _ ,_) -> - ids, specs - | _ -> raise (Reporting.err_general (def_loc def) - "Definition in splice file isn't a spec or function") - in List.fold_left scan (IdSet.empty, Bindings.empty) defs + | DEF_fundef fd -> (IdSet.add (id_of_fundef fd) ids, specs) + | DEF_val (VS_aux (VS_val_spec (_, id, _, _), _) as vs) -> (ids, Bindings.add id vs specs) + | DEF_pragma (("file_start" | "file_end"), _, _) -> (ids, specs) + | _ -> raise (Reporting.err_general (def_loc def) "Definition in splice file isn't a spec or function") + in + List.fold_left scan (IdSet.empty, Bindings.empty) defs let filter_old_ast repl_ids repl_specs { defs; _ } = - let check (rdefs,spec_found) (DEF_aux (aux, def_annot) as def) = + let check (rdefs, spec_found) (DEF_aux (aux, def_annot) as def) = match aux with | DEF_fundef fd -> - let id = id_of_fundef fd in - if IdSet.mem id repl_ids - then rdefs, spec_found - else def::rdefs, spec_found - | DEF_val (VS_aux (VS_val_spec (_,id,_,_),_)) -> - (match Bindings.find_opt id repl_specs with - | Some vs -> DEF_aux (DEF_val vs, def_annot) :: rdefs, IdSet.add id spec_found - | None -> def::rdefs, spec_found) - | _ -> def::rdefs, spec_found + let id = id_of_fundef fd in + if IdSet.mem id repl_ids then (rdefs, spec_found) else (def :: rdefs, spec_found) + | DEF_val (VS_aux (VS_val_spec (_, id, _, _), _)) -> ( + match Bindings.find_opt id repl_specs with + | Some vs -> (DEF_aux (DEF_val vs, def_annot) :: rdefs, IdSet.add id spec_found) + | None -> (def :: rdefs, spec_found) + ) + | _ -> (def :: rdefs, spec_found) in - let rdefs, spec_found = List.fold_left check ([],IdSet.empty) defs in + let rdefs, spec_found = List.fold_left check ([], IdSet.empty) defs in (List.rev rdefs, spec_found) let filter_replacements spec_found { defs; _ } = let not_found = function - | DEF_aux (DEF_val (VS_aux (VS_val_spec (_,id,_,_),_)),_) -> not (IdSet.mem id spec_found) + | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, _, _), _)), _) -> not (IdSet.mem id spec_found) | _ -> true - in List.filter not_found defs + in + List.filter not_found defs let splice ast file = let parsed_ast = Initial_check.parse_file file |> snd in let repl_ast = Initial_check.process_ast ~generate:false (Parse_ast.Defs [(file, parsed_ast)]) in let repl_ast = Rewrites.move_loop_measures repl_ast in - let repl_ast = map_ast_annot (fun (l,_) -> l,Type_check.empty_tannot) repl_ast in + let repl_ast = map_ast_annot (fun (l, _) -> (l, Type_check.empty_tannot)) repl_ast in let repl_ids, repl_specs = scan_ast repl_ast in let defs1, specs_found = filter_old_ast repl_ids repl_specs ast in let defs2 = filter_replacements specs_found repl_ast in Type_error.check Type_check.initial_env (Type_check.strip_ast { ast with defs = defs1 @ defs2 }) - diff --git a/src/lib/state.ml b/src/lib/state.ml index 158bce565..8ac349af2 100644 --- a/src/lib/state.ml +++ b/src/lib/state.ml @@ -87,217 +87,201 @@ let find_registers defs = List.fold_left (fun acc def -> match def with - | DEF_aux (DEF_register (DEC_aux(DEC_reg (typ, id, _), (_, tannot))), _) -> - let env = match destruct_tannot tannot with - | Some (env, _) -> env - | _ -> Env.empty - in - (Env.expand_synonyms env typ, id) :: acc + | DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, _), (_, tannot))), _) -> + let env = match destruct_tannot tannot with Some (env, _) -> env | _ -> Env.empty in + (Env.expand_synonyms env typ, id) :: acc | _ -> acc - ) [] defs + ) + [] defs let generate_register_id_enum = function | [] -> ["type register_id = unit"] | registers -> - let reg (typ, id) = string_of_id id in - ["type register_id = " ^ String.concat " | " (List.map reg registers)] + let reg (typ, id) = string_of_id id in + ["type register_id = " ^ String.concat " | " (List.map reg registers)] -let rec id_of_regtyp builtins mwords (Typ_aux (t, l) as typ) = match t with +let rec id_of_regtyp builtins mwords (Typ_aux (t, l) as typ) = + match t with | Typ_id id -> id | Typ_app (id, args) -> - let name_arg (A_aux (targ, _)) = match targ with - | A_typ targ -> string_of_id (id_of_regtyp builtins mwords targ) - | A_nexp nexp when is_nexp_constant (nexp_simp nexp) -> - string_of_nexp (nexp_simp nexp) - | A_order (Ord_aux (Ord_inc, _)) -> "inc" - | A_order (Ord_aux (Ord_dec, _)) -> "dec" - | _ -> - raise (Reporting.err_typ l "Unsupported register type") - in - if IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) then id else - append_id id (String.concat "_" ("" :: List.map name_arg args)) + let name_arg (A_aux (targ, _)) = + match targ with + | A_typ targ -> string_of_id (id_of_regtyp builtins mwords targ) + | A_nexp nexp when is_nexp_constant (nexp_simp nexp) -> string_of_nexp (nexp_simp nexp) + | A_order (Ord_aux (Ord_inc, _)) -> "inc" + | A_order (Ord_aux (Ord_dec, _)) -> "dec" + | _ -> raise (Reporting.err_typ l "Unsupported register type") + in + if IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) then id + else append_id id (String.concat "_" ("" :: List.map name_arg args)) | _ -> raise (Reporting.err_typ l "Unsupported register type") let regstate_field typ = append_id (id_of_regtyp IdSet.empty false typ) "_reg" let generate_regstate registers = let regstate_def = - if registers = [] then - TD_abbrev (mk_id "regstate", mk_typquant [], mk_typ_arg (A_typ unit_typ)) - else + if registers = [] then TD_abbrev (mk_id "regstate", mk_typquant [], mk_typ_arg (A_typ unit_typ)) + else ( let fields = if !opt_type_grouped_regstate then - List.map - (fun (typ, id) -> - (function_typ [string_typ] typ, - regstate_field typ)) - registers + List.map (fun (typ, id) -> (function_typ [string_typ] typ, regstate_field typ)) registers |> List.sort_uniq (fun (typ1, id1) (typ2, id2) -> Id.compare id1 id2) else registers in TD_record (mk_id "regstate", mk_typquant [], fields, false) + ) in [DEF_aux (DEF_type (TD_aux (regstate_def, (Unknown, empty_uannot))), mk_def_annot Unknown)] let generate_initial_regstate defs = let registers = find_registers defs in - if registers = [] then [] else - try - (* Recursively choose a default value for every type in the spec. - vals, constructed below, maps user-defined types to default values. *) - let rec lookup_init_val vals (Typ_aux (typ_aux, _)) = - match typ_aux with - | Typ_id id -> - if string_of_id id = "bool" then "false" else - if string_of_id id = "bit" then "bitzero" else - if string_of_id id = "int" then "0" else - if string_of_id id = "nat" then "0" else - if string_of_id id = "real" then "0" else - if string_of_id id = "string" then "\"\"" else - if string_of_id id = "unit" then "()" else - Bindings.find id vals [] - | Typ_app (id, _) when string_of_id id = "list" -> "[||]" - | Typ_app (id, [A_aux (A_nexp nexp, _)]) when string_of_id id = "atom" -> - string_of_nexp nexp - | Typ_app (id, [A_aux (A_nexp nexp, _); _]) when string_of_id id = "range" -> - string_of_nexp nexp - | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_constant len, _)), _); _]) - when string_of_id id = "bitvector" -> - (* Output a literal binary zero value if this is a bitvector - and the environment has a default indexing order (required - by the typechecker for binary and hex literals) *) - let literal_bitvec = has_default_order defs in - let init_elem = if literal_bitvec then "0" else lookup_init_val vals bit_typ in - let rec elems len = - if (Nat_big_num.less_equal len Nat_big_num.zero) then [] else - init_elem :: elems (Nat_big_num.pred len) - in - if literal_bitvec then - "0b" ^ (String.concat "" (elems len)) - else - "[" ^ (String.concat ", " (elems len)) ^ "]" - | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_constant len, _)), _); _ ; - A_aux (A_typ etyp, _)]) - when string_of_id id = "vector" -> - (* Output a list of initial values of the vector elements. *) - let init_elem = lookup_init_val vals etyp in - let rec elems len = - if (Nat_big_num.less_equal len Nat_big_num.zero) then [] else - init_elem :: elems (Nat_big_num.pred len) - in - "[" ^ (String.concat ", " (elems len)) ^ "]" - | Typ_app (id, args) -> Bindings.find id vals args - | Typ_tuple typs -> - "(" ^ (String.concat ", " (List.map (lookup_init_val vals) typs)) ^ ")" - | Typ_exist (_, _, typ) -> lookup_init_val vals typ - | _ -> raise Not_found - in - let typ_subst_quant_item typ (QI_aux (qi, _)) arg = match qi with - | QI_id (KOpt_aux (KOpt_kind (_, kid), _)) -> - typ_subst kid arg typ - | _ -> typ - in - let typ_subst_typquant tq args typ = - List.fold_left2 typ_subst_quant_item typ (quant_items tq) args - in - let add_typ_init_val (defs', vals) = function - | TD_enum (id, id1 :: _, _) -> - (* Choose the first value of an enumeration type as default *) - (defs', Bindings.add id (fun _ -> string_of_id id1) vals) - | TD_variant (id, tq, (Tu_aux (Tu_ty_id (typ1, id1), _)) :: _, _) -> - (* Choose the first variant of a union type as default *) - let init_val args = - let typ1 = typ_subst_typquant tq args typ1 in - string_of_id id1 ^ " (" ^ lookup_init_val vals typ1 ^ ")" - in - (defs', Bindings.add id init_val vals) - | TD_abbrev (id, tq, A_aux (A_typ typ, _)) -> - let init_val args = lookup_init_val vals (typ_subst_typquant tq args typ) in - (defs', Bindings.add id init_val vals) - | TD_record (id, tq, fields, _) -> - let init_val args = - let init_field (typ, id) = - let typ = typ_subst_typquant tq args typ in - string_of_id id ^ " = " ^ lookup_init_val vals typ - in - "struct { " ^ (String.concat ", " (List.map init_field fields)) ^ " }" - in - let def_name = "initial_" ^ string_of_id id in - if quant_items tq = [] && not (is_defined defs def_name) then - (defs' @ ["let " ^ def_name ^ " : " ^ string_of_id id ^ " = " ^ init_val []], - Bindings.add id (fun _ -> def_name) vals) - else (defs', Bindings.add id init_val vals) - | TD_bitfield (id, typ, _) -> - (defs', Bindings.add id (fun _ -> lookup_init_val vals typ) vals) - | _ -> (defs', vals) - in - let (init_defs, init_vals) = List.fold_left (fun inits def -> match def with - | DEF_aux (DEF_type (TD_aux (td, _)), _) -> add_typ_init_val inits td - | _ -> inits) ([], Bindings.empty) defs - in - let init_reg (typ, id) = string_of_id id ^ " = " ^ lookup_init_val init_vals typ in - List.map (defs_of_string __POS__) - (init_defs @ - ["let initial_regstate : regstate = struct { " ^ - (String.concat ", " (List.map init_reg registers)) ^ - " }"]) - with - | _ -> [] (* Do not generate an initial register state if anything goes wrong *) + if registers = [] then [] + else ( + try + (* Recursively choose a default value for every type in the spec. + vals, constructed below, maps user-defined types to default values. *) + let rec lookup_init_val vals (Typ_aux (typ_aux, _)) = + match typ_aux with + | Typ_id id -> + if string_of_id id = "bool" then "false" + else if string_of_id id = "bit" then "bitzero" + else if string_of_id id = "int" then "0" + else if string_of_id id = "nat" then "0" + else if string_of_id id = "real" then "0" + else if string_of_id id = "string" then "\"\"" + else if string_of_id id = "unit" then "()" + else Bindings.find id vals [] + | Typ_app (id, _) when string_of_id id = "list" -> "[||]" + | Typ_app (id, [A_aux (A_nexp nexp, _)]) when string_of_id id = "atom" -> string_of_nexp nexp + | Typ_app (id, [A_aux (A_nexp nexp, _); _]) when string_of_id id = "range" -> string_of_nexp nexp + | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_constant len, _)), _); _]) when string_of_id id = "bitvector" -> + (* Output a literal binary zero value if this is a bitvector + and the environment has a default indexing order (required + by the typechecker for binary and hex literals) *) + let literal_bitvec = has_default_order defs in + let init_elem = if literal_bitvec then "0" else lookup_init_val vals bit_typ in + let rec elems len = + if Nat_big_num.less_equal len Nat_big_num.zero then [] else init_elem :: elems (Nat_big_num.pred len) + in + if literal_bitvec then "0b" ^ String.concat "" (elems len) else "[" ^ String.concat ", " (elems len) ^ "]" + | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_constant len, _)), _); _; A_aux (A_typ etyp, _)]) + when string_of_id id = "vector" -> + (* Output a list of initial values of the vector elements. *) + let init_elem = lookup_init_val vals etyp in + let rec elems len = + if Nat_big_num.less_equal len Nat_big_num.zero then [] else init_elem :: elems (Nat_big_num.pred len) + in + "[" ^ String.concat ", " (elems len) ^ "]" + | Typ_app (id, args) -> Bindings.find id vals args + | Typ_tuple typs -> "(" ^ String.concat ", " (List.map (lookup_init_val vals) typs) ^ ")" + | Typ_exist (_, _, typ) -> lookup_init_val vals typ + | _ -> raise Not_found + in + let typ_subst_quant_item typ (QI_aux (qi, _)) arg = + match qi with QI_id (KOpt_aux (KOpt_kind (_, kid), _)) -> typ_subst kid arg typ | _ -> typ + in + let typ_subst_typquant tq args typ = List.fold_left2 typ_subst_quant_item typ (quant_items tq) args in + let add_typ_init_val (defs', vals) = function + | TD_enum (id, id1 :: _, _) -> + (* Choose the first value of an enumeration type as default *) + (defs', Bindings.add id (fun _ -> string_of_id id1) vals) + | TD_variant (id, tq, Tu_aux (Tu_ty_id (typ1, id1), _) :: _, _) -> + (* Choose the first variant of a union type as default *) + let init_val args = + let typ1 = typ_subst_typquant tq args typ1 in + string_of_id id1 ^ " (" ^ lookup_init_val vals typ1 ^ ")" + in + (defs', Bindings.add id init_val vals) + | TD_abbrev (id, tq, A_aux (A_typ typ, _)) -> + let init_val args = lookup_init_val vals (typ_subst_typquant tq args typ) in + (defs', Bindings.add id init_val vals) + | TD_record (id, tq, fields, _) -> + let init_val args = + let init_field (typ, id) = + let typ = typ_subst_typquant tq args typ in + string_of_id id ^ " = " ^ lookup_init_val vals typ + in + "struct { " ^ String.concat ", " (List.map init_field fields) ^ " }" + in + let def_name = "initial_" ^ string_of_id id in + if quant_items tq = [] && not (is_defined defs def_name) then + ( defs' @ ["let " ^ def_name ^ " : " ^ string_of_id id ^ " = " ^ init_val []], + Bindings.add id (fun _ -> def_name) vals + ) + else (defs', Bindings.add id init_val vals) + | TD_bitfield (id, typ, _) -> (defs', Bindings.add id (fun _ -> lookup_init_val vals typ) vals) + | _ -> (defs', vals) + in + let init_defs, init_vals = + List.fold_left + (fun inits def -> + match def with DEF_aux (DEF_type (TD_aux (td, _)), _) -> add_typ_init_val inits td | _ -> inits + ) + ([], Bindings.empty) defs + in + let init_reg (typ, id) = string_of_id id ^ " = " ^ lookup_init_val init_vals typ in + List.map (defs_of_string __POS__) + (init_defs + @ ["let initial_regstate : regstate = struct { " ^ String.concat ", " (List.map init_reg registers) ^ " }"] + ) + with _ -> [] (* Do not generate an initial register state if anything goes wrong *) + ) -let regval_constr_id = id_of_regtyp (IdSet.of_list (List.map mk_id ["bool"; "int"; "real"; "string"; "vector"; "bitvector"; "list"; "option"])) +let regval_constr_id = + id_of_regtyp + (IdSet.of_list (List.map mk_id ["bool"; "int"; "real"; "string"; "vector"; "bitvector"; "list"; "option"])) let register_base_types mwords typs = let rec add_base_typs typs (Typ_aux (t, _) as typ) = - let builtins = IdSet.of_list (List.map mk_id ["bool"; "atom_bool"; "atom"; "int"; "real"; "string"; "vector"; "list"; "option"]) in + let builtins = + IdSet.of_list (List.map mk_id ["bool"; "atom_bool"; "atom"; "int"; "real"; "string"; "vector"; "list"; "option"]) + in match t with - | Typ_app (id, args) - when IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) -> - let add_typ_arg base_typs (A_aux (targ, _)) = - match targ with - | A_typ typ -> add_base_typs base_typs typ - | _ -> base_typs - in - List.fold_left add_typ_arg typs args - | Typ_id id when IdSet.mem id builtins -> typs - | _ -> Bindings.add (regval_constr_id mwords typ) typ typs + | Typ_app (id, args) when IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) -> + let add_typ_arg base_typs (A_aux (targ, _)) = + match targ with A_typ typ -> add_base_typs base_typs typ | _ -> base_typs + in + List.fold_left add_typ_arg typs args + | Typ_id id when IdSet.mem id builtins -> typs + | _ -> Bindings.add (regval_constr_id mwords typ) typ typs in List.fold_left add_base_typs Bindings.empty (bit_typ :: typs) let generate_regval_typ typs = - let constr (constr_id, typ) = - Printf.sprintf "Regval_%s : %s" (string_of_id constr_id) (to_string (doc_typ typ)) in + let constr (constr_id, typ) = Printf.sprintf "Regval_%s : %s" (string_of_id constr_id) (to_string (doc_typ typ)) in let builtins = - "Regval_vector : list(register_value), " ^ - "Regval_list : list(register_value), " ^ - "Regval_option : option(register_value), " ^ - "Regval_bool : bool, " ^ - "Regval_int : int, " ^ - "Regval_real : real, " ^ - "Regval_string : string" + "Regval_vector : list(register_value), " ^ "Regval_list : list(register_value), " + ^ "Regval_option : option(register_value), " ^ "Regval_bool : bool, " ^ "Regval_int : int, " + ^ "Regval_real : real, " ^ "Regval_string : string" in - [defs_of_string __POS__ - ("union register_value = { " ^ - (String.concat ", " (builtins :: List.map constr (Bindings.bindings typs))) ^ - " }")] + [ + defs_of_string __POS__ + ("union register_value = { " ^ String.concat ", " (builtins :: List.map constr (Bindings.bindings typs)) ^ " }"); + ] let regval_class_typs_lem = [("bool", "bool"); ("int", "integer"); ("real", "real"); ("string", "string")] let regval_instance_lem = let conv_def (name, typ) = - [ "val " ^ name ^ "_of_register_value : register_value -> maybe " ^ typ; + [ + "val " ^ name ^ "_of_register_value : register_value -> maybe " ^ typ; "let " ^ name ^ "_of_register_value rv = match rv with Regval_" ^ name ^ " v -> Just v | _ -> Nothing end"; "val register_value_of_" ^ name ^ " : " ^ typ ^ " -> register_value"; - "let register_value_of_" ^ name ^ " v = Regval_" ^ name ^ " v" ] + "let register_value_of_" ^ name ^ " v = Regval_" ^ name ^ " v"; + ] in let conv_inst (name, typ) = - [ "let " ^ name ^ "_of_regval = " ^ name ^ "_of_register_value"; - "let regval_of_" ^ name ^ " = register_value_of_" ^ name ] + [ + "let " ^ name ^ "_of_regval = " ^ name ^ "_of_register_value"; + "let regval_of_" ^ name ^ " = register_value_of_" ^ name; + ] in separate_map hardline string (List.concat (List.map conv_def regval_class_typs_lem) @ [""; "instance (Register_Value register_value)"] @ List.concat (List.map conv_inst regval_class_typs_lem) - @ ["end"]) + @ ["end"] + ) let add_regval_conv id typ defs = let id = string_of_id id in @@ -305,10 +289,10 @@ let add_regval_conv id typ defs = (* Create a function that converts from regval to the target type. *) let from_name = id ^ "_of_regval" in let from_val = Printf.sprintf "val %s : register_value -> option(%s)" from_name typ_str in - let from_function = String.concat "\n" [ - Printf.sprintf "function %s Regval_%s(v) = Some(v)" from_name id; - Printf.sprintf "and %s _ = None()" from_name - ] in + let from_function = + String.concat "\n" + [Printf.sprintf "function %s Regval_%s(v) = Some(v)" from_name id; Printf.sprintf "and %s _ = None()" from_name] + in let from_defs = if is_defined defs from_name then [] else [from_val; from_function] in (* Create a function that converts from target type to regval. *) let to_name = "regval_of_" ^ id in @@ -318,83 +302,104 @@ let add_regval_conv id typ defs = let cdefs = List.concat (List.map (defs_of_string __POS__) (from_defs @ to_defs)) in defs @ cdefs -let rec regval_convs mwords wrap_fun (Typ_aux (t, _) as typ) = match t with +let rec regval_convs mwords wrap_fun (Typ_aux (t, _) as typ) = + match t with | Typ_app _ when (is_vector_typ typ || is_bitvector_typ typ) && not (mwords && is_bitvector_typ typ) -> - let size, ord, etyp = vector_typ_args_of typ in - let etyp_of, of_etyp = regval_convs mwords wrap_fun etyp in - "vector_of_regval " ^ wrap_fun etyp_of, - "regval_of_vector " ^ wrap_fun of_etyp - | Typ_app (id, [A_aux (A_typ etyp, _)]) - when string_of_id id = "list" -> - let etyp_of, of_etyp = regval_convs mwords wrap_fun etyp in - "list_of_regval " ^ wrap_fun etyp_of, - "regval_of_list " ^ wrap_fun of_etyp - | Typ_app (id, [A_aux (A_typ etyp, _)]) - when string_of_id id = "option" -> - let etyp_of, of_etyp = regval_convs mwords wrap_fun etyp in - "option_of_regval " ^ wrap_fun etyp_of, - "regval_of_option " ^ wrap_fun of_etyp + let size, ord, etyp = vector_typ_args_of typ in + let etyp_of, of_etyp = regval_convs mwords wrap_fun etyp in + ("vector_of_regval " ^ wrap_fun etyp_of, "regval_of_vector " ^ wrap_fun of_etyp) + | Typ_app (id, [A_aux (A_typ etyp, _)]) when string_of_id id = "list" -> + let etyp_of, of_etyp = regval_convs mwords wrap_fun etyp in + ("list_of_regval " ^ wrap_fun etyp_of, "regval_of_list " ^ wrap_fun of_etyp) + | Typ_app (id, [A_aux (A_typ etyp, _)]) when string_of_id id = "option" -> + let etyp_of, of_etyp = regval_convs mwords wrap_fun etyp in + ("option_of_regval " ^ wrap_fun etyp_of, "regval_of_option " ^ wrap_fun of_etyp) | _ -> - let id = string_of_id (regval_constr_id mwords typ) in - if List.mem id (List.map fst regval_class_typs_lem) - then id ^ "_of_register_value", "register_value_of_" ^ id - else id ^ "_of_regval", "regval_of_" ^ id + let id = string_of_id (regval_constr_id mwords typ) in + if List.mem id (List.map fst regval_class_typs_lem) then (id ^ "_of_register_value", "register_value_of_" ^ id) + else (id ^ "_of_regval", "regval_of_" ^ id) let regval_convs_lem mwords = regval_convs mwords (fun conv -> "(fun v -> " ^ conv ^ " v)") let regval_convs_isa mwords = regval_convs mwords (fun conv -> "(\\v. " ^ conv ^ " v)") let register_refs_lem mwords pp_tannot registers = let generic_convs = - separate_map hardline string [ - "val vector_of_regval : forall 'a. (register_value -> maybe 'a) -> register_value -> maybe (list 'a)"; - "let vector_of_regval of_regval rv = match rv with"; - " | Regval_vector v -> just_list (List.map of_regval v)"; - " | _ -> Nothing"; - "end"; - ""; - "val regval_of_vector : forall 'a. ('a -> register_value) -> list 'a -> register_value"; - "let regval_of_vector regval_of xs = Regval_vector (List.map regval_of xs)"; - ""; - "val list_of_regval : forall 'a. (register_value -> maybe 'a) -> register_value -> maybe (list 'a)"; - "let list_of_regval of_regval rv = match rv with"; - " | Regval_list v -> just_list (List.map of_regval v)"; - " | _ -> Nothing"; - "end"; - ""; - "val regval_of_list : forall 'a. ('a -> register_value) -> list 'a -> register_value"; - "let regval_of_list regval_of xs = Regval_list (List.map regval_of xs)"; - ""; - "val option_of_regval : forall 'a. (register_value -> maybe 'a) -> register_value -> maybe (maybe 'a)"; - "let option_of_regval of_regval rv = match rv with"; - " | Regval_option v -> Just (Maybe.bind v of_regval)"; - " | _ -> Nothing"; - "end"; - ""; - "val regval_of_option : forall 'a. ('a -> register_value) -> maybe 'a -> register_value"; - "let regval_of_option regval_of v = Regval_option (Maybe.map regval_of v)"; - ""; - "" - ] + separate_map hardline string + [ + "val vector_of_regval : forall 'a. (register_value -> maybe 'a) -> register_value -> maybe (list 'a)"; + "let vector_of_regval of_regval rv = match rv with"; + " | Regval_vector v -> just_list (List.map of_regval v)"; + " | _ -> Nothing"; + "end"; + ""; + "val regval_of_vector : forall 'a. ('a -> register_value) -> list 'a -> register_value"; + "let regval_of_vector regval_of xs = Regval_vector (List.map regval_of xs)"; + ""; + "val list_of_regval : forall 'a. (register_value -> maybe 'a) -> register_value -> maybe (list 'a)"; + "let list_of_regval of_regval rv = match rv with"; + " | Regval_list v -> just_list (List.map of_regval v)"; + " | _ -> Nothing"; + "end"; + ""; + "val regval_of_list : forall 'a. ('a -> register_value) -> list 'a -> register_value"; + "let regval_of_list regval_of xs = Regval_list (List.map regval_of xs)"; + ""; + "val option_of_regval : forall 'a. (register_value -> maybe 'a) -> register_value -> maybe (maybe 'a)"; + "let option_of_regval of_regval rv = match rv with"; + " | Regval_option v -> Just (Maybe.bind v of_regval)"; + " | _ -> Nothing"; + "end"; + ""; + "val regval_of_option : forall 'a. ('a -> register_value) -> maybe 'a -> register_value"; + "let regval_of_option regval_of v = Regval_option (Maybe.map regval_of v)"; + ""; + ""; + ] in let register_ref (typ, id) = let idd = string (string_of_id id) in - let (read_from, write_to) = - if !opt_type_grouped_regstate then + let read_from, write_to = + if !opt_type_grouped_regstate then ( let field_idd = string (string_of_id (regstate_field typ)) in - (field_idd ^^ space ^^ dquotes idd, - doc_op equals field_idd (string "(fun reg -> if reg = \"" ^^ idd ^^ string "\" then v else s." ^^ field_idd ^^ string " reg)")) - else - (idd, doc_op equals idd (string "v")) + ( field_idd ^^ space ^^ dquotes idd, + doc_op equals field_idd + (string "(fun reg -> if reg = \"" ^^ idd ^^ string "\" then v else s." ^^ field_idd ^^ string " reg)") + ) + ) + else (idd, doc_op equals idd (string "v")) in (* let field = if prefix_recordtype then string "regstate_" ^^ idd else idd in *) let of_regval, regval_of = regval_convs_lem mwords typ in let tannot = pp_tannot typ in - concat [string "let "; idd; string "_ref "; tannot; string " = <|"; hardline; - string " name = \""; idd; string "\";"; hardline; - string " read_from = (fun s -> s."; read_from; string ");"; hardline; - string " write_to = (fun v s -> (<| s with "; write_to; string " |>));"; hardline; - string " of_regval = (fun v -> "; string of_regval; string " v);"; hardline; - string " regval_of = (fun v -> "; string regval_of; string " v) |>"; hardline] + concat + [ + string "let "; + idd; + string "_ref "; + tannot; + string " = <|"; + hardline; + string " name = \""; + idd; + string "\";"; + hardline; + string " read_from = (fun s -> s."; + read_from; + string ");"; + hardline; + string " write_to = (fun v s -> (<| s with "; + write_to; + string " |>));"; + hardline; + string " of_regval = (fun v -> "; + string of_regval; + string " v);"; + hardline; + string " regval_of = (fun v -> "; + string regval_of; + string " v) |>"; + hardline; + ] in let refs = separate_map hardline register_ref registers in let mk_reg_assoc (_, id) = @@ -402,19 +407,27 @@ let register_refs_lem mwords pp_tannot registers = let qidd = "\"" ^ idd ^ "\"" in string (" (" ^ qidd ^ ", register_ops_of " ^ idd ^ "_ref)") in - let reg_assocs = separate hardline [ - string "val registers : list (string * register_ops regstate register_value)"; - string "let registers = ["; - separate (string ";" ^^ hardline) (List.map mk_reg_assoc registers); - string " ]"] ^^ hardline + let reg_assocs = + separate hardline + [ + string "val registers : list (string * register_ops regstate register_value)"; + string "let registers = ["; + separate (string ";" ^^ hardline) (List.map mk_reg_assoc registers); + string " ]"; + ] + ^^ hardline in let getters_setters = - string "let register_accessors = mk_accessors (fun nm -> List.lookup nm registers)" ^^ - hardline ^^ hardline ^^ - string "val get_regval : string -> regstate -> maybe register_value" ^^ hardline ^^ - string "let get_regval = fst register_accessors" ^^ hardline ^^ hardline ^^ - string "val set_regval : string -> register_value -> regstate -> maybe regstate" ^^ hardline ^^ - string "let set_regval = snd register_accessors" ^^ hardline ^^ hardline + string "let register_accessors = mk_accessors (fun nm -> List.lookup nm registers)" + ^^ hardline ^^ hardline + ^^ string "val get_regval : string -> regstate -> maybe register_value" + ^^ hardline + ^^ string "let get_regval = fst register_accessors" + ^^ hardline ^^ hardline + ^^ string "val set_regval : string -> register_value -> regstate -> maybe regstate" + ^^ hardline + ^^ string "let set_regval = snd register_accessors" + ^^ hardline ^^ hardline (* string "let liftS s = liftState register_accessors s" ^^ hardline *) in separate hardline [generic_convs; refs; reg_assocs; getters_setters] @@ -423,114 +436,119 @@ let register_refs_lem mwords pp_tannot registers = asserting that all lists representing non-bit-vectors have the right length. *) let generate_isa_lemmas mwords defs = - let rec drop_while f = function - | x :: xs when f x -> drop_while f xs - | xs -> xs - in - let remove_leading_underscores str = - String.concat "_" (drop_while (fun s -> s = "") (Util.split_on_char '_' str)) - in + let rec drop_while f = function x :: xs when f x -> drop_while f xs | xs -> xs in + let remove_leading_underscores str = String.concat "_" (drop_while (fun s -> s = "") (Util.split_on_char '_' str)) in let remove_trailing_underscores str = - Util.split_on_char '_' str |> List.rev |> - drop_while (fun s -> s = "") |> List.rev |> - String.concat "_" + Util.split_on_char '_' str |> List.rev |> drop_while (fun s -> s = "") |> List.rev |> String.concat "_" in let remove_underscores str = remove_leading_underscores (remove_trailing_underscores str) in let registers = find_registers defs in - let regtyp_ids = - register_base_types mwords (List.map fst registers) - |> Bindings.bindings |> List.map fst - in + let regtyp_ids = register_base_types mwords (List.map fst registers) |> Bindings.bindings |> List.map fst in let regval_class_typ_ids = List.map (fun (t, _) -> mk_id t) regval_class_typs_lem in let register_defs = let reg_id id = remove_leading_underscores (string_of_id id) in - hang 2 (flow_map (break 1) string - (["lemmas register_defs"; "="; "get_regval_unfold"; "set_regval_unfold"] @ - (List.map (fun (typ, id) -> reg_id id ^ "_ref_def") registers))) + hang 2 + (flow_map (break 1) string + (["lemmas register_defs"; "="; "get_regval_unfold"; "set_regval_unfold"] + @ List.map (fun (typ, id) -> reg_id id ^ "_ref_def") registers + ) + ) in let conv_lemma typ_id = let typ_id = remove_trailing_underscores (string_of_id typ_id) in let typ_id' = remove_leading_underscores typ_id in - let (of_rv, rv_of) = - if List.mem typ_id (List.map fst regval_class_typs_lem) - then (typ_id' ^ "_of_register_value", "register_value_of_" ^ typ_id) + let of_rv, rv_of = + if List.mem typ_id (List.map fst regval_class_typs_lem) then + (typ_id' ^ "_of_register_value", "register_value_of_" ^ typ_id) else (typ_id' ^ "_of_regval", "regval_of_" ^ typ_id) in - string ("lemma " ^ of_rv ^ "_eq_Some_iff[simp]:") ^^ hardline ^^ - string (" \"" ^ of_rv ^ " rv = Some v \\ rv = Regval_" ^ typ_id ^ " v\"") ^^ hardline ^^ - string (" by (cases rv; auto)") ^^ hardline ^^ - hardline ^^ - string ("declare " ^ rv_of ^ "_def[simp]") ^^ hardline ^^ - hardline ^^ - string ("lemma regval_" ^ typ_id ^ "[simp]:") ^^ hardline ^^ - string (" \"" ^ of_rv ^ " (" ^ rv_of ^ " v) = Some v\"") ^^ hardline ^^ - string (" by auto") + string ("lemma " ^ of_rv ^ "_eq_Some_iff[simp]:") + ^^ hardline + ^^ string (" \"" ^ of_rv ^ " rv = Some v \\ rv = Regval_" ^ typ_id ^ " v\"") + ^^ hardline ^^ string " by (cases rv; auto)" ^^ hardline ^^ hardline + ^^ string ("declare " ^ rv_of ^ "_def[simp]") + ^^ hardline ^^ hardline + ^^ string ("lemma regval_" ^ typ_id ^ "[simp]:") + ^^ hardline + ^^ string (" \"" ^ of_rv ^ " (" ^ rv_of ^ " v) = Some v\"") + ^^ hardline ^^ string " by auto" in let register_lemmas (typ, id) = let id = remove_leading_underscores (string_of_id id) in - separate_map hardline string [ - "lemma liftS_read_reg_" ^ id ^ "[liftState_simp]:"; - " \"\\read_reg " ^ id ^ "_ref\\\\<^sub>S = read_regS " ^ id ^ "_ref\""; - " by (intro liftState_read_reg) (auto simp: register_defs)"; - ""; - "lemma liftS_write_reg_" ^ id ^ "[liftState_simp]:"; - " \"\\write_reg " ^ id ^ "_ref v\\\\<^sub>S = write_regS " ^ id ^ "_ref v\""; - " by (intro liftState_write_reg) (auto simp: register_defs)" - ] + separate_map hardline string + [ + "lemma liftS_read_reg_" ^ id ^ "[liftState_simp]:"; + " \"\\read_reg " ^ id ^ "_ref\\\\<^sub>S = read_regS " ^ id ^ "_ref\""; + " by (intro liftState_read_reg) (auto simp: register_defs)"; + ""; + "lemma liftS_write_reg_" ^ id ^ "[liftState_simp]:"; + " \"\\write_reg " ^ id ^ "_ref v\\\\<^sub>S = write_regS " ^ id ^ "_ref v\""; + " by (intro liftState_write_reg) (auto simp: register_defs)"; + ] in - let registers_eqs = separate hardline (List.map string [ - "lemma registers_distinct:"; - " \"distinct (map fst registers)\""; - " unfolding registers_def list.simps fst_conv"; - " by (distinct_string; simp)"; - ""; - "lemma registers_eqs_setup:"; - " \"!x : set registers. map_of registers (fst x) = Some (snd x)\""; - " using registers_distinct"; - " by simp"; - ""; - "lemmas map_of_registers_eqs[simp] ="; - " registers_eqs_setup[simplified arg_cong[where f=set, OF registers_def]"; - " list.simps ball_simps fst_conv snd_conv]"; - ""; - "lemmas get_regval_unfold = get_regval_def[THEN fun_cong,"; - " unfolded register_accessors_def mk_accessors_def fst_conv snd_conv]"; - "lemmas set_regval_unfold = set_regval_def[THEN fun_cong,"; - " unfolded register_accessors_def mk_accessors_def fst_conv snd_conv]"; - ]) + let registers_eqs = + separate hardline + (List.map string + [ + "lemma registers_distinct:"; + " \"distinct (map fst registers)\""; + " unfolding registers_def list.simps fst_conv"; + " by (distinct_string; simp)"; + ""; + "lemma registers_eqs_setup:"; + " \"!x : set registers. map_of registers (fst x) = Some (snd x)\""; + " using registers_distinct"; + " by simp"; + ""; + "lemmas map_of_registers_eqs[simp] ="; + " registers_eqs_setup[simplified arg_cong[where f=set, OF registers_def]"; + " list.simps ball_simps fst_conv snd_conv]"; + ""; + "lemmas get_regval_unfold = get_regval_def[THEN fun_cong,"; + " unfolded register_accessors_def mk_accessors_def fst_conv snd_conv]"; + "lemmas set_regval_unfold = set_regval_def[THEN fun_cong,"; + " unfolded register_accessors_def mk_accessors_def fst_conv snd_conv]"; + ] + ) in - let module StringMap = Map.Make(String) in + let module StringMap = Map.Make (String) in let field_id typ = remove_leading_underscores (string_of_id (id_of_regtyp IdSet.empty false typ)) in let field_id_stripped typ = remove_trailing_underscores (field_id typ) in let set_regval_type_cases = let add_reg_case cases (typ, id) = let of_regval = remove_underscores (fst (regval_convs_isa mwords typ)) in let case = - "(" ^ field_id_stripped typ ^ ") v where " ^ - "\"" ^ of_regval ^ " rv = Some v\" and " ^ - "\"s' = s\\" ^ field_id typ ^ "_reg := (" ^ field_id typ ^ "_reg s)(r := v)\\\"" + "(" ^ field_id_stripped typ ^ ") v where " ^ "\"" ^ of_regval ^ " rv = Some v\" and " ^ "\"s' = s\\" + ^ field_id typ ^ "_reg := (" ^ field_id typ ^ "_reg s)(r := v)\\\"" in StringMap.add (field_id typ) case cases in let cases = List.fold_left add_reg_case StringMap.empty registers |> StringMap.bindings |> List.map snd in - let prove_case (typ, id) = " subgoal using " ^ field_id_stripped typ ^ " by (auto simp: register_defs fun_upd_def)" in + let prove_case (typ, id) = + " subgoal using " ^ field_id_stripped typ ^ " by (auto simp: register_defs fun_upd_def)" + in if List.length cases > 0 && !opt_type_grouped_regstate then - string "lemma set_regval_Some_type_cases:" ^^ hardline ^^ - string " assumes \"set_regval r rv s = Some s'\"" ^^ hardline ^^ - string " obtains " ^^ separate_map (hardline ^^ string " | ") string cases ^^ hardline ^^ - string "proof -" ^^ hardline ^^ - string " from assms show ?thesis" ^^ hardline ^^ - string " unfolding set_regval_unfold registers_def" ^^ hardline ^^ - string " apply (elim option_bind_SomeE map_of_Cons_SomeE)" ^^ hardline ^^ - separate_map hardline string (List.map prove_case registers) ^^ hardline ^^ - string " by auto" ^^ hardline ^^ - string "qed" + string "lemma set_regval_Some_type_cases:" + ^^ hardline + ^^ string " assumes \"set_regval r rv s = Some s'\"" + ^^ hardline ^^ string " obtains " + ^^ separate_map (hardline ^^ string " | ") string cases + ^^ hardline ^^ string "proof -" ^^ hardline ^^ string " from assms show ?thesis" ^^ hardline + ^^ string " unfolding set_regval_unfold registers_def" + ^^ hardline + ^^ string " apply (elim option_bind_SomeE map_of_Cons_SomeE)" + ^^ hardline + ^^ separate_map hardline string (List.map prove_case registers) + ^^ hardline ^^ string " by auto" ^^ hardline ^^ string "qed" else string "" in let get_regval_type_cases = let add_reg_case cases (typ, id) = let regval_of = remove_underscores (snd (regval_convs_isa mwords typ)) in - let case = "(" ^ field_id_stripped typ ^ ") \"get_regval r = (\\s. Some (" ^ regval_of ^ " (" ^ field_id typ ^ "_reg s r)))\"" in + let case = + "(" ^ field_id_stripped typ ^ ") \"get_regval r = (\\s. Some (" ^ regval_of ^ " (" ^ field_id typ + ^ "_reg s r)))\"" + in StringMap.add (field_id typ) case cases in let cases = List.fold_left add_reg_case StringMap.empty registers in @@ -538,149 +556,192 @@ let generate_isa_lemmas mwords defs = let cases = (StringMap.bindings cases |> List.map snd) @ [fail_case] in let prove_case (typ, id) = " subgoal using " ^ field_id_stripped typ ^ " by (auto simp: register_defs)" in if !opt_type_grouped_regstate then - string "lemma get_regval_type_cases:" ^^ hardline ^^ - string " fixes r :: string" ^^ hardline ^^ - string " obtains " ^^ separate_map (hardline ^^ string " | ") string cases ^^ hardline ^^ - string "proof (cases \"map_of registers r\")" ^^ hardline ^^ - string " case (Some ops)" ^^ hardline ^^ - string " then show ?thesis" ^^ hardline ^^ - string " unfolding registers_def" ^^ hardline ^^ - string " apply (elim map_of_Cons_SomeE)" ^^ hardline ^^ - separate_map hardline string (List.map prove_case registers) ^^ hardline ^^ - string " by auto" ^^ hardline ^^ - string "qed (auto simp: get_regval_unfold)" + string "lemma get_regval_type_cases:" ^^ hardline ^^ string " fixes r :: string" ^^ hardline + ^^ string " obtains " + ^^ separate_map (hardline ^^ string " | ") string cases + ^^ hardline + ^^ string "proof (cases \"map_of registers r\")" + ^^ hardline ^^ string " case (Some ops)" ^^ hardline ^^ string " then show ?thesis" ^^ hardline + ^^ string " unfolding registers_def" ^^ hardline + ^^ string " apply (elim map_of_Cons_SomeE)" + ^^ hardline + ^^ separate_map hardline string (List.map prove_case registers) + ^^ hardline ^^ string " by auto" ^^ hardline + ^^ string "qed (auto simp: get_regval_unfold)" else string "" in - registers_eqs ^^ hardline ^^ hardline ^^ - string "abbreviation liftS (\"\\_\\\\<^sub>S\") where \"liftS \\ liftState (get_regval, set_regval)\"" ^^ - hardline ^^ hardline ^^ - register_defs ^^ - hardline ^^ hardline ^^ - separate_map (hardline ^^ hardline) conv_lemma (regval_class_typ_ids @ regtyp_ids) ^^ - hardline ^^ hardline ^^ - separate_map hardline string [ - "lemma vector_of_rv_rv_of_vector[simp]:"; - " assumes \"\\v. of_rv (rv_of v) = Some v\""; - " shows \"vector_of_regval of_rv (regval_of_vector rv_of v) = Some v\""; - "proof -"; - " from assms have \"of_rv \\ rv_of = Some\" by auto"; - " then show ?thesis by (auto simp: regval_of_vector_def)"; - "qed"; - ""; - "lemma option_of_rv_rv_of_option[simp]:"; - " assumes \"\\v. of_rv (rv_of v) = Some v\""; - " shows \"option_of_regval of_rv (regval_of_option rv_of v) = Some v\""; - " using assms by (cases v) (auto simp: regval_of_option_def)"; - ""; - "lemma list_of_rv_rv_of_list[simp]:"; - " assumes \"\\v. of_rv (rv_of v) = Some v\""; - " shows \"list_of_regval of_rv (regval_of_list rv_of v) = Some v\""; - "proof -"; - " from assms have \"of_rv \\ rv_of = Some\" by auto"; - " with assms show ?thesis by (induction v) (auto simp: regval_of_list_def)"; - "qed"] ^^ - hardline ^^ hardline ^^ - separate_map (hardline ^^ hardline) register_lemmas registers ^^ - hardline ^^ hardline ^^ - set_regval_type_cases ^^ - hardline ^^ hardline ^^ - get_regval_type_cases + registers_eqs ^^ hardline ^^ hardline + ^^ string + "abbreviation liftS (\"\\_\\\\<^sub>S\") where \"liftS \\ liftState (get_regval, \ + set_regval)\"" + ^^ hardline ^^ hardline ^^ register_defs ^^ hardline ^^ hardline + ^^ separate_map (hardline ^^ hardline) conv_lemma (regval_class_typ_ids @ regtyp_ids) + ^^ hardline ^^ hardline + ^^ separate_map hardline string + [ + "lemma vector_of_rv_rv_of_vector[simp]:"; + " assumes \"\\v. of_rv (rv_of v) = Some v\""; + " shows \"vector_of_regval of_rv (regval_of_vector rv_of v) = Some v\""; + "proof -"; + " from assms have \"of_rv \\ rv_of = Some\" by auto"; + " then show ?thesis by (auto simp: regval_of_vector_def)"; + "qed"; + ""; + "lemma option_of_rv_rv_of_option[simp]:"; + " assumes \"\\v. of_rv (rv_of v) = Some v\""; + " shows \"option_of_regval of_rv (regval_of_option rv_of v) = Some v\""; + " using assms by (cases v) (auto simp: regval_of_option_def)"; + ""; + "lemma list_of_rv_rv_of_list[simp]:"; + " assumes \"\\v. of_rv (rv_of v) = Some v\""; + " shows \"list_of_regval of_rv (regval_of_list rv_of v) = Some v\""; + "proof -"; + " from assms have \"of_rv \\ rv_of = Some\" by auto"; + " with assms show ?thesis by (induction v) (auto simp: regval_of_list_def)"; + "qed"; + ] + ^^ hardline ^^ hardline + ^^ separate_map (hardline ^^ hardline) register_lemmas registers + ^^ hardline ^^ hardline ^^ set_regval_type_cases ^^ hardline ^^ hardline ^^ get_regval_type_cases -let rec regval_convs_coq (Typ_aux (t, _) as typ) = match t with +let rec regval_convs_coq (Typ_aux (t, _) as typ) = + match t with | Typ_app _ when is_vector_typ typ && not (is_bitvector_typ typ) -> - let size, ord, etyp = vector_typ_args_of typ in - let size = string_of_nexp (nexp_simp size) in - let etyp_of, of_etyp = regval_convs_coq etyp in - "(fun v => vector_of_regval " ^ size ^ " " ^ etyp_of ^ " v)", - "(fun v => regval_of_vector " ^ of_etyp ^ " v)" - | Typ_app (id, [A_aux (A_typ etyp, _)]) - when string_of_id id = "list" -> - let etyp_of, of_etyp = regval_convs_coq etyp in - "(fun v => list_of_regval " ^ etyp_of ^ " v)", - "(fun v => regval_of_list " ^ of_etyp ^ " v)" - | Typ_app (id, [A_aux (A_typ etyp, _)]) - when string_of_id id = "option" -> - let etyp_of, of_etyp = regval_convs_coq etyp in - "(fun v => option_of_regval " ^ etyp_of ^ " v)", - "(fun v => regval_of_option " ^ of_etyp ^ " v)" + let size, ord, etyp = vector_typ_args_of typ in + let size = string_of_nexp (nexp_simp size) in + let etyp_of, of_etyp = regval_convs_coq etyp in + ("(fun v => vector_of_regval " ^ size ^ " " ^ etyp_of ^ " v)", "(fun v => regval_of_vector " ^ of_etyp ^ " v)") + | Typ_app (id, [A_aux (A_typ etyp, _)]) when string_of_id id = "list" -> + let etyp_of, of_etyp = regval_convs_coq etyp in + ("(fun v => list_of_regval " ^ etyp_of ^ " v)", "(fun v => regval_of_list " ^ of_etyp ^ " v)") + | Typ_app (id, [A_aux (A_typ etyp, _)]) when string_of_id id = "option" -> + let etyp_of, of_etyp = regval_convs_coq etyp in + ("(fun v => option_of_regval " ^ etyp_of ^ " v)", "(fun v => regval_of_option " ^ of_etyp ^ " v)") | _ -> - let id = string_of_id (regval_constr_id true typ) in - "(fun v => " ^ id ^ "_of_regval v)", "(fun v => regval_of_" ^ id ^ " v)" + let id = string_of_id (regval_constr_id true typ) in + ("(fun v => " ^ id ^ "_of_regval v)", "(fun v => regval_of_" ^ id ^ " v)") let register_refs_coq doc_id registers = let generic_convs = - separate_map hardline string [ - "Definition bool_of_regval (merge_var : register_value) : option bool :="; - " match merge_var with | Regval_bool v => Some v | _ => None end."; - ""; - "Definition regval_of_bool (v : bool) : register_value := Regval_bool v."; - ""; - "Definition int_of_regval (merge_var : register_value) : option Z :="; - " match merge_var with | Regval_int v => Some v | _ => None end."; - ""; - "Definition regval_of_int (v : Z) : register_value := Regval_int v."; - ""; - "Definition real_of_regval (merge_var : register_value) : option R :="; - " match merge_var with | Regval_real v => Some v | _ => None end."; - ""; - "Definition regval_of_real (v : R) : register_value := Regval_real v."; - ""; - "Definition string_of_regval (merge_var : register_value) : option string :="; - " match merge_var with | Regval_string v => Some v | _ => None end."; - ""; - "Definition regval_of_string (v : string) : register_value := Regval_string v."; - ""; - "Definition vector_of_regval {a} n (of_regval : register_value -> option a) (rv : register_value) : option (vec a n) := match rv with"; - " | Regval_vector v => if n =? length_list v then map_bind (vec_of_list n) (just_list (List.map of_regval v)) else None"; - " | _ => None"; - "end."; - ""; - "Definition regval_of_vector {a size} (regval_of : a -> register_value) (xs : vec a size) : register_value := Regval_vector (List.map regval_of (list_of_vec xs))."; - ""; - "Definition list_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (list a) := match rv with"; - " | Regval_list v => just_list (List.map of_regval v)"; - " | _ => None"; - "end."; - ""; - "Definition regval_of_list {a} (regval_of : a -> register_value) (xs : list a) : register_value := Regval_list (List.map regval_of xs)."; - ""; - "Definition option_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (option a) := match rv with"; - " | Regval_option v => option_map of_regval v"; - " | _ => None"; - "end."; - ""; - "Definition regval_of_option {a} (regval_of : a -> register_value) (v : option a) := Regval_option (option_map regval_of v)."; - ""; - "" - ] + separate_map hardline string + [ + "Definition bool_of_regval (merge_var : register_value) : option bool :="; + " match merge_var with | Regval_bool v => Some v | _ => None end."; + ""; + "Definition regval_of_bool (v : bool) : register_value := Regval_bool v."; + ""; + "Definition int_of_regval (merge_var : register_value) : option Z :="; + " match merge_var with | Regval_int v => Some v | _ => None end."; + ""; + "Definition regval_of_int (v : Z) : register_value := Regval_int v."; + ""; + "Definition real_of_regval (merge_var : register_value) : option R :="; + " match merge_var with | Regval_real v => Some v | _ => None end."; + ""; + "Definition regval_of_real (v : R) : register_value := Regval_real v."; + ""; + "Definition string_of_regval (merge_var : register_value) : option string :="; + " match merge_var with | Regval_string v => Some v | _ => None end."; + ""; + "Definition regval_of_string (v : string) : register_value := Regval_string v."; + ""; + "Definition vector_of_regval {a} n (of_regval : register_value -> option a) (rv : register_value) : option \ + (vec a n) := match rv with"; + " | Regval_vector v => if n =? length_list v then map_bind (vec_of_list n) (just_list (List.map of_regval v)) \ + else None"; + " | _ => None"; + "end."; + ""; + "Definition regval_of_vector {a size} (regval_of : a -> register_value) (xs : vec a size) : register_value := \ + Regval_vector (List.map regval_of (list_of_vec xs))."; + ""; + "Definition list_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (list \ + a) := match rv with"; + " | Regval_list v => just_list (List.map of_regval v)"; + " | _ => None"; + "end."; + ""; + "Definition regval_of_list {a} (regval_of : a -> register_value) (xs : list a) : register_value := Regval_list \ + (List.map regval_of xs)."; + ""; + "Definition option_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option \ + (option a) := match rv with"; + " | Regval_option v => option_map of_regval v"; + " | _ => None"; + "end."; + ""; + "Definition regval_of_option {a} (regval_of : a -> register_value) (v : option a) := Regval_option (option_map \ + regval_of v)."; + ""; + ""; + ] in let register_ref (typ, id) = let idd = doc_id id in (* let field = if prefix_recordtype then string "regstate_" ^^ idd else idd in *) let of_regval, regval_of = regval_convs_coq typ in - concat [string "Definition "; idd; string "_ref := {|"; hardline; - string " name := \""; idd; string "\";"; hardline; - string " read_from := (fun s => s.("; idd; string "));"; hardline; - string " write_to := (fun v s => ({[ s with "; idd; string " := v ]}));"; hardline; - string " of_regval := "; string of_regval; string ";"; hardline; - string " regval_of := "; string regval_of; string " |}."; hardline] + concat + [ + string "Definition "; + idd; + string "_ref := {|"; + hardline; + string " name := \""; + idd; + string "\";"; + hardline; + string " read_from := (fun s => s.("; + idd; + string "));"; + hardline; + string " write_to := (fun v s => ({[ s with "; + idd; + string " := v ]}));"; + hardline; + string " of_regval := "; + string of_regval; + string ";"; + hardline; + string " regval_of := "; + string regval_of; + string " |}."; + hardline; + ] in let refs = separate_map hardline register_ref registers in let get_set_reg (_, id) = let idd = doc_id id in - concat [string " if string_dec reg_name \""; idd; string "\" then Some ("; idd; string "_ref.(regval_of) ("; idd; string "_ref.(read_from) s)) else"], - concat [string " if string_dec reg_name \""; idd; string "\" then option_map (fun v => "; idd; string "_ref.(write_to) v s) ("; idd; string "_ref.(of_regval) v) else"] + ( concat + [ + string " if string_dec reg_name \""; + idd; + string "\" then Some ("; + idd; + string "_ref.(regval_of) ("; + idd; + string "_ref.(read_from) s)) else"; + ], + concat + [ + string " if string_dec reg_name \""; + idd; + string "\" then option_map (fun v => "; + idd; + string "_ref.(write_to) v s) ("; + idd; + string "_ref.(of_regval) v) else"; + ] + ) in let getters_setters = let getters, setters = List.split (List.map get_set_reg registers) in - string "Local Open Scope string." ^^ hardline ^^ - string "Definition get_regval (reg_name : string) (s : regstate) : option register_value :=" ^^ hardline ^^ - separate hardline getters ^^ hardline ^^ - string " None." ^^ hardline ^^ hardline ^^ - string "Definition set_regval (reg_name : string) (v : register_value) (s : regstate) : option regstate :=" ^^ hardline ^^ - separate hardline setters ^^ hardline ^^ - string " None." ^^ hardline ^^ hardline ^^ - string "Definition register_accessors := (get_regval, set_regval)." ^^ hardline ^^ hardline + string "Local Open Scope string." ^^ hardline + ^^ string "Definition get_regval (reg_name : string) (s : regstate) : option register_value :=" + ^^ hardline ^^ separate hardline getters ^^ hardline ^^ string " None." ^^ hardline ^^ hardline + ^^ string "Definition set_regval (reg_name : string) (v : register_value) (s : regstate) : option regstate :=" + ^^ hardline ^^ separate hardline setters ^^ hardline ^^ string " None." ^^ hardline ^^ hardline + ^^ string "Definition register_accessors := (get_regval, set_regval)." + ^^ hardline ^^ hardline in separate hardline [generic_convs; refs; getters_setters] @@ -694,8 +755,8 @@ let generate_regstate_defs mwords defs = let registers = find_registers defs in let regtyps = register_base_types mwords (List.map fst registers) in let option_typ = - if is_defined defs "option" then [] else - [defs_of_string __POS__ "union option ('a : Type) = {None : unit, Some : 'a}"] + if is_defined defs "option" then [] + else [defs_of_string __POS__ "union option ('a : Type) = {None : unit, Some : 'a}"] in let regval_typ = if is_defined defs "register_value" then [] else generate_regval_typ regtyps in let regstate_typ = if is_defined defs "regstate" then [] else [generate_regstate registers] in @@ -704,17 +765,14 @@ let generate_regstate_defs mwords defs = a regstate record with registers grouped per type; the latter would require record fields storing functions, which is not supported in Sail. *) - if is_defined defs "initial_regstate" || !opt_type_grouped_regstate then [] else - generate_initial_regstate defs + if is_defined defs "initial_regstate" || !opt_type_grouped_regstate then [] else generate_initial_regstate defs in let defs = - option_typ @ regval_typ @ regstate_typ @ initregstate - |> List.concat - |> Bindings.fold add_regval_conv regtyps + option_typ @ regval_typ @ regstate_typ @ initregstate |> List.concat |> Bindings.fold add_regval_conv regtyps in Initial_check.opt_undefined_gen := gen_undef; defs let add_regstate_defs mwords env ast = let reg_defs, env = Type_error.check_defs env (generate_regstate_defs mwords ast.defs) in - env, append_ast_defs ast reg_defs + (env, append_ast_defs ast reg_defs) diff --git a/src/lib/target.ml b/src/lib/target.ml index 7d0d493a0..6cae998d7 100644 --- a/src/lib/target.ml +++ b/src/lib/target.ml @@ -67,25 +67,25 @@ open Ast_defs open Type_check - -module StringMap = Map.Make(String) + +module StringMap = Map.Make (String) type target = { - name : string; - options : (Arg.key * Arg.spec * Arg.doc) list; - pre_parse_hook : (unit -> unit); - pre_rewrites_hook : (tannot ast -> Effects.side_effect_info -> Env.t -> unit); - rewrites : (string * Rewrites.rewriter_arg list) list; - action : string -> string option -> tannot ast -> Effects.side_effect_info -> Env.t -> unit; - asserts_termination : bool; - } + name : string; + options : (Arg.key * Arg.spec * Arg.doc) list; + pre_parse_hook : unit -> unit; + pre_rewrites_hook : tannot ast -> Effects.side_effect_info -> Env.t -> unit; + rewrites : (string * Rewrites.rewriter_arg list) list; + action : string -> string option -> tannot ast -> Effects.side_effect_info -> Env.t -> unit; + asserts_termination : bool; +} let name tgt = tgt.name let run_pre_parse_hook tgt = tgt.pre_parse_hook let run_pre_rewrites_hook tgt = tgt.pre_rewrites_hook - + let action tgt = tgt.action let rewrites tgt = Rewrites.instantiate_rewrites tgt.rewrites @@ -96,78 +96,76 @@ let targets = ref StringMap.empty let the_target = ref None -let register - ~name:name - ?flag:flag - ?description:desc - ?options:(options = []) - ?pre_parse_hook:(pre_parse_hook = (fun () -> ())) - ?pre_rewrites_hook:(pre_rewrites_hook = (fun _ _ _ -> ())) - ?rewrites:(rewrites = []) - ?asserts_termination:(asserts_termination = false) - action = - let set_target () = match !the_target with +let register ~name ?flag ?description:desc ?(options = []) ?(pre_parse_hook = fun () -> ()) + ?(pre_rewrites_hook = fun _ _ _ -> ()) ?(rewrites = []) ?(asserts_termination = false) action = + let set_target () = + match !the_target with | None -> the_target := Some name | Some tgt -> - prerr_endline ("Cannot use multiple Sail targets simultaneously: " ^ tgt ^ " and " ^ name); - exit 1 - in - let desc = match desc with - | Some desc -> desc - | None -> " invoke the Sail " ^ name ^ " target" - in - let flag = match flag with - | Some flag -> flag - | None -> name + prerr_endline ("Cannot use multiple Sail targets simultaneously: " ^ tgt ^ " and " ^ name); + exit 1 in - let tgt = { - name = name; + let desc = match desc with Some desc -> desc | None -> " invoke the Sail " ^ name ^ " target" in + let flag = match flag with Some flag -> flag | None -> name in + let tgt = + { + name; options = ("-" ^ flag, Arg.Unit set_target, desc) :: options; - pre_parse_hook = pre_parse_hook; - pre_rewrites_hook = pre_rewrites_hook; - rewrites = rewrites; - action = action; - asserts_termination = asserts_termination; - } in + pre_parse_hook; + pre_rewrites_hook; + rewrites; + action; + asserts_termination; + } + in targets := StringMap.add name tgt !targets; tgt -let get_the_target () = - match !the_target with - | Some name -> StringMap.find_opt name !targets - | None -> None +let get_the_target () = match !the_target with Some name -> StringMap.find_opt name !targets | None -> None -let get ~name:name = - StringMap.find_opt name !targets +let get ~name = StringMap.find_opt name !targets let extract_options () = - let opts = - StringMap.bindings !targets - |> List.map (fun (_, tgt) -> tgt.options) - |> List.concat in + let opts = StringMap.bindings !targets |> List.map (fun (_, tgt) -> tgt.options) |> List.concat in targets := StringMap.map (fun tgt -> { tgt with options = [] }) !targets; opts let () = let open Interactive in - ActionUnit (fun _ -> - List.iter (fun (name, _) -> - print_endline name - ) (StringMap.bindings !targets) - ) |> register_command ~name:"list_targets" ~help:"list available Sail targets for use with :target"; - - ArgString ("target", fun name -> Action (fun istate -> - match get ~name:name with - | Some tgt -> - let ast, effect_info, env = Rewrites.rewrite istate.effect_info istate.env (rewrites tgt) istate.ast in - { istate with ast = ast; env = env; effect_info = effect_info } - | None -> - print_endline ("No target " ^ name); - istate - )) |> register_command ~name:"rewrites" ~help:"perform rewrites for a target. See :list_targets for a list of targets"; - - ArgString ("target", fun name -> ArgString ("out", fun out -> ActionUnit (fun istate -> - match get ~name:name with - | Some tgt -> action tgt istate.default_sail_dir (Some out) istate.ast istate.effect_info istate.env; - | None -> print_endline ("No target " ^ name) - ))) |> register_command ~name:"target" ~help:"invoke Sail target. See :list_targets for a list of targets. out parameter is equivalent to command line -o option" + ActionUnit (fun _ -> List.iter (fun (name, _) -> print_endline name) (StringMap.bindings !targets)) + |> register_command ~name:"list_targets" ~help:"list available Sail targets for use with :target"; + + ArgString + ( "target", + fun name -> + Action + (fun istate -> + match get ~name with + | Some tgt -> + let ast, effect_info, env = Rewrites.rewrite istate.effect_info istate.env (rewrites tgt) istate.ast in + { istate with ast; env; effect_info } + | None -> + print_endline ("No target " ^ name); + istate + ) + ) + |> register_command ~name:"rewrites" ~help:"perform rewrites for a target. See :list_targets for a list of targets"; + + ArgString + ( "target", + fun name -> + ArgString + ( "out", + fun out -> + ActionUnit + (fun istate -> + match get ~name with + | Some tgt -> action tgt istate.default_sail_dir (Some out) istate.ast istate.effect_info istate.env + | None -> print_endline ("No target " ^ name) + ) + ) + ) + |> register_command ~name:"target" + ~help: + "invoke Sail target. See :list_targets for a list of targets. out parameter is equivalent to command line -o \ + option" diff --git a/src/lib/target.mli b/src/lib/target.mli index 58c814a67..7d542704b 100644 --- a/src/lib/target.mli +++ b/src/lib/target.mli @@ -77,7 +77,7 @@ open Ast_defs open Type_check (** {2 Target type and accessor functions} *) - + type target val name : target -> string @@ -87,13 +87,13 @@ val run_pre_parse_hook : target -> unit -> unit val run_pre_rewrites_hook : target -> tannot ast -> Effects.side_effect_info -> Env.t -> unit val rewrites : target -> Rewrites.rewrite_sequence - + val action : target -> string -> string option -> tannot ast -> Effects.side_effect_info -> Env.t -> unit val asserts_termination : target -> bool (** {2 Target registration} *) - + (** Used for plugins to register custom Sail targets/backends. [register_target ~name:"foo" action] will create an option -foo, diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index 92aaf8681..d31d299cc 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -92,22 +92,20 @@ let opt_smt_linearize = ref false (* Don't expand bitfields (when using old syntax), used for LaTeX output *) let opt_no_bitfield_expansion = ref false - + let depth = ref 0 -let rec indent n = match n with - | 0 -> "" - | n -> "| " ^ indent (n - 1) +let rec indent n = match n with 0 -> "" | n -> "| " ^ indent (n - 1) (* Lazily evaluate debugging message. This can make a big performance difference; for example, repeated calls to string_of_exp can be costly for deeply nested expressions, e.g. with long sequences of monadic binds. *) -let typ_debug ?level:(level=1) m = if !opt_tc_debug > level then prerr_endline (indent !depth ^ Lazy.force m) else () +let typ_debug ?(level = 1) m = if !opt_tc_debug > level then prerr_endline (indent !depth ^ Lazy.force m) else () let typ_print m = if !opt_tc_debug > 0 then prerr_endline (indent !depth ^ Lazy.force m) else () type constraint_reason = (Ast.l * string) option - + type type_error = (* First parameter is the error that caused us to start doing type coercions, the second is the errors encountered by all possible @@ -123,66 +121,60 @@ type type_error = let err_because (error1, l, error2) = Err_inner (error1, l, "Caused by", None, error2) -type env = - { top_val_specs : (typquant * typ) Bindings.t; - defined_val_specs : IdSet.t; - locals : (mut * typ) Bindings.t; - top_letbinds : IdSet.t; - union_ids : (typquant * typ) Bindings.t; - registers : typ Bindings.t; - variants : (typquant * type_union list) Bindings.t; - scattered_variant_envs : env Bindings.t; - mappings : (typquant * typ * typ) Bindings.t; - typ_vars : (Ast.l * kind_aux) KBindings.t; - shadow_vars : int KBindings.t; - typ_synonyms : (typquant * typ_arg) Bindings.t; - typ_params : typquant Bindings.t; - overloads : (id list) Bindings.t; - enums : IdSet.t Bindings.t; - records : (typquant * (typ * id) list) Bindings.t; - accessors : (typquant * typ) Bindings.t; - externs : extern Bindings.t; - casts : id list; - allow_casts : bool; - allow_bindings : bool; - constraints : (constraint_reason * n_constraint) list; - default_order : order option; - ret_typ : typ option; - poly_undefineds : bool; - prove : (env -> n_constraint -> bool) option; - allow_unknowns : bool; - bitfields : index_range Bindings.t Bindings.t; - toplevel : l option; - outcomes : (typquant * typ * kinded_id list * id list * env) Bindings.t; - outcome_typschm : (typquant * typ) option; - outcome_instantiation : (Ast.l * typ) KBindings.t; - } +type env = { + top_val_specs : (typquant * typ) Bindings.t; + defined_val_specs : IdSet.t; + locals : (mut * typ) Bindings.t; + top_letbinds : IdSet.t; + union_ids : (typquant * typ) Bindings.t; + registers : typ Bindings.t; + variants : (typquant * type_union list) Bindings.t; + scattered_variant_envs : env Bindings.t; + mappings : (typquant * typ * typ) Bindings.t; + typ_vars : (Ast.l * kind_aux) KBindings.t; + shadow_vars : int KBindings.t; + typ_synonyms : (typquant * typ_arg) Bindings.t; + typ_params : typquant Bindings.t; + overloads : id list Bindings.t; + enums : IdSet.t Bindings.t; + records : (typquant * (typ * id) list) Bindings.t; + accessors : (typquant * typ) Bindings.t; + externs : extern Bindings.t; + casts : id list; + allow_casts : bool; + allow_bindings : bool; + constraints : (constraint_reason * n_constraint) list; + default_order : order option; + ret_typ : typ option; + poly_undefineds : bool; + prove : (env -> n_constraint -> bool) option; + allow_unknowns : bool; + bitfields : index_range Bindings.t Bindings.t; + toplevel : l option; + outcomes : (typquant * typ * kinded_id list * id list * env) Bindings.t; + outcome_typschm : (typquant * typ) option; + outcome_instantiation : (Ast.l * typ) KBindings.t; +} -exception Type_error of env * l * type_error;; +exception Type_error of env * l * type_error let typ_error env l m = raise (Type_error (env, l, Err_other m)) let typ_raise env l err = raise (Type_error (env, l, err)) -let deinfix = function - | Id_aux (Id v, l) -> Id_aux (Operator v, l) - | Id_aux (Operator v, l) -> Id_aux (Operator v, l) +let deinfix = function Id_aux (Id v, l) -> Id_aux (Operator v, l) | Id_aux (Operator v, l) -> Id_aux (Operator v, l) let field_name rec_id id = - match rec_id, id with - | Id_aux (Id r, _), Id_aux (Id v, l) -> Id_aux (Id (r ^ "." ^ v), l) - | _, _ -> assert false + match (rec_id, id) with Id_aux (Id r, _), Id_aux (Id v, l) -> Id_aux (Id (r ^ "." ^ v), l) | _, _ -> assert false let string_of_bind (typquant, typ) = string_of_typquant typquant ^ ". " ^ string_of_typ typ let orig_kid (Kid_aux (Var v, l) as kid) = try let i = String.rindex v '#' in - if i >= 3 && String.sub v 0 3 = "'fv" then - Kid_aux (Var ("'" ^ String.sub v (i + 1) (String.length v - i - 1)), l) + if i >= 3 && String.sub v 0 3 = "'fv" then Kid_aux (Var ("'" ^ String.sub v (i + 1) (String.length v - i - 1)), l) else kid - with - | Not_found -> kid + with Not_found -> kid (* Rewrite mangled names of type variables to the original names *) let rec orig_nexp (Nexp_aux (nexp, l)) = @@ -197,30 +189,22 @@ let rec orig_nexp (Nexp_aux (nexp, l)) = | _ -> rewrap nexp let is_list (Typ_aux (typ_aux, _)) = - match typ_aux with - | Typ_app (f, [A_aux (A_typ typ, _)]) - when string_of_id f = "list" -> Some typ - | _ -> None + match typ_aux with Typ_app (f, [A_aux (A_typ typ, _)]) when string_of_id f = "list" -> Some typ | _ -> None -let is_unknown_type = function - | (Typ_aux (Typ_internal_unknown, _)) -> true - | _ -> false +let is_unknown_type = function Typ_aux (Typ_internal_unknown, _) -> true | _ -> false let is_atom (Typ_aux (typ_aux, _)) = - match typ_aux with - | Typ_app (f, [_]) when string_of_id f = "atom" -> true - | _ -> false + match typ_aux with Typ_app (f, [_]) when string_of_id f = "atom" -> true | _ -> false let is_atom_bool (Typ_aux (typ_aux, _)) = - match typ_aux with - | Typ_app (f, [_]) when string_of_id f = "atom_bool" -> true - | _ -> false + match typ_aux with Typ_app (f, [_]) when string_of_id f = "atom_bool" -> true | _ -> false let rec strip_id = function | Id_aux (Id x, _) -> Id_aux (Id x, Parse_ast.Unknown) | Id_aux (Operator x, _) -> Id_aux (Operator x, Parse_ast.Unknown) -and strip_kid = function - | Kid_aux (Var x, _) -> Kid_aux (Var x, Parse_ast.Unknown) + +and strip_kid = function Kid_aux (Var x, _) -> Kid_aux (Var x, Parse_ast.Unknown) + and strip_nexp_aux = function | Nexp_id id -> Nexp_id (strip_id id) | Nexp_var kid -> Nexp_var (strip_kid kid) @@ -231,8 +215,9 @@ and strip_nexp_aux = function | Nexp_minus (nexp1, nexp2) -> Nexp_minus (strip_nexp nexp1, strip_nexp nexp2) | Nexp_exp nexp -> Nexp_exp (strip_nexp nexp) | Nexp_neg nexp -> Nexp_neg (strip_nexp nexp) -and strip_nexp = function - | Nexp_aux (nexp_aux, _) -> Nexp_aux (strip_nexp_aux nexp_aux, Parse_ast.Unknown) + +and strip_nexp = function Nexp_aux (nexp_aux, _) -> Nexp_aux (strip_nexp_aux nexp_aux, Parse_ast.Unknown) + and strip_n_constraint_aux = function | NC_equal (nexp1, nexp2) -> NC_equal (strip_nexp nexp1, strip_nexp nexp2) | NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (strip_nexp nexp1, strip_nexp nexp2) @@ -247,21 +232,21 @@ and strip_n_constraint_aux = function | NC_app (id, args) -> NC_app (strip_id id, List.map strip_typ_arg args) | NC_true -> NC_true | NC_false -> NC_false -and strip_n_constraint = function - | NC_aux (nc_aux, _) -> NC_aux (strip_n_constraint_aux nc_aux, Parse_ast.Unknown) -and strip_typ_arg = function - | A_aux (typ_arg_aux, _) -> A_aux (strip_typ_arg_aux typ_arg_aux, Parse_ast.Unknown) + +and strip_n_constraint = function NC_aux (nc_aux, _) -> NC_aux (strip_n_constraint_aux nc_aux, Parse_ast.Unknown) + +and strip_typ_arg = function A_aux (typ_arg_aux, _) -> A_aux (strip_typ_arg_aux typ_arg_aux, Parse_ast.Unknown) + and strip_typ_arg_aux = function | A_nexp nexp -> A_nexp (strip_nexp nexp) | A_typ typ -> A_typ (strip_typ typ) | A_order ord -> A_order (strip_order ord) | A_bool nc -> A_bool (strip_n_constraint nc) -and strip_order = function - | Ord_aux (ord_aux, _) -> Ord_aux (strip_order_aux ord_aux, Parse_ast.Unknown) -and strip_order_aux = function - | Ord_var kid -> Ord_var (strip_kid kid) - | Ord_inc -> Ord_inc - | Ord_dec -> Ord_dec + +and strip_order = function Ord_aux (ord_aux, _) -> Ord_aux (strip_order_aux ord_aux, Parse_ast.Unknown) + +and strip_order_aux = function Ord_var kid -> Ord_var (strip_kid kid) | Ord_inc -> Ord_inc | Ord_dec -> Ord_dec + and strip_typ_aux : typ_aux -> typ_aux = function | Typ_internal_unknown -> Typ_internal_unknown | Typ_id id -> Typ_id (strip_id id) @@ -270,25 +255,29 @@ and strip_typ_aux : typ_aux -> typ_aux = function | Typ_bidir (typ1, typ2) -> Typ_bidir (strip_typ typ1, strip_typ typ2) | Typ_tuple typs -> Typ_tuple (List.map strip_typ typs) | Typ_exist (kopts, constr, typ) -> - Typ_exist ((List.map strip_kinded_id kopts), strip_n_constraint constr, strip_typ typ) + Typ_exist (List.map strip_kinded_id kopts, strip_n_constraint constr, strip_typ typ) | Typ_app (id, args) -> Typ_app (strip_id id, List.map strip_typ_arg args) -and strip_typ : typ -> typ = function - | Typ_aux (typ_aux, _) -> Typ_aux (strip_typ_aux typ_aux, Parse_ast.Unknown) + +and strip_typ : typ -> typ = function Typ_aux (typ_aux, _) -> Typ_aux (strip_typ_aux typ_aux, Parse_ast.Unknown) + and strip_typq = function TypQ_aux (typq_aux, l) -> TypQ_aux (strip_typq_aux typq_aux, Parse_ast.Unknown) + and strip_typq_aux = function | TypQ_no_forall -> TypQ_no_forall | TypQ_tq quants -> TypQ_tq (List.map strip_quant_item quants) -and strip_quant_item = function - | QI_aux (qi_aux, _) -> QI_aux (strip_qi_aux qi_aux, Parse_ast.Unknown) + +and strip_quant_item = function QI_aux (qi_aux, _) -> QI_aux (strip_qi_aux qi_aux, Parse_ast.Unknown) + and strip_qi_aux = function | QI_id kinded_id -> QI_id (strip_kinded_id kinded_id) | QI_constraint constr -> QI_constraint (strip_n_constraint constr) + and strip_kinded_id = function | KOpt_aux (kinded_id_aux, _) -> KOpt_aux (strip_kinded_id_aux kinded_id_aux, Parse_ast.Unknown) -and strip_kinded_id_aux = function - | KOpt_kind (kind, kid) -> KOpt_kind (strip_kind kind, strip_kid kid) -and strip_kind = function - | K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown) + +and strip_kinded_id_aux = function KOpt_kind (kind, kid) -> KOpt_kind (strip_kind kind, strip_kid kid) + +and strip_kind = function K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown) let rec typ_constraints (Typ_aux (typ_aux, _)) = match typ_aux with @@ -298,16 +287,11 @@ let rec typ_constraints (Typ_aux (typ_aux, _)) = | Typ_tuple typs -> List.concat (List.map typ_constraints typs) | Typ_app (_, args) -> List.concat (List.map typ_arg_nexps args) | Typ_exist (_, _, typ) -> typ_constraints typ - | Typ_fn (arg_typs, ret_typ) -> - List.concat (List.map typ_constraints arg_typs) @ typ_constraints ret_typ - | Typ_bidir (typ1, typ2) -> - typ_constraints typ1 @ typ_constraints typ2 + | Typ_fn (arg_typs, ret_typ) -> List.concat (List.map typ_constraints arg_typs) @ typ_constraints ret_typ + | Typ_bidir (typ1, typ2) -> typ_constraints typ1 @ typ_constraints typ2 + and typ_arg_nexps (A_aux (typ_arg_aux, _)) = - match typ_arg_aux with - | A_nexp _ -> [] - | A_typ typ -> typ_constraints typ - | A_bool nc -> [nc] - | A_order _ -> [] + match typ_arg_aux with A_nexp _ -> [] | A_typ typ -> typ_constraints typ | A_bool nc -> [nc] | A_order _ -> [] let rec typ_nexps (Typ_aux (typ_aux, _)) = match typ_aux with @@ -317,20 +301,25 @@ let rec typ_nexps (Typ_aux (typ_aux, _)) = | Typ_tuple typs -> List.concat (List.map typ_nexps typs) | Typ_app (f, args) -> List.concat (List.map typ_arg_nexps args) | Typ_exist (kids, nc, typ) -> typ_nexps typ - | Typ_fn (arg_typs, ret_typ) -> - List.concat (List.map typ_nexps arg_typs) @ typ_nexps ret_typ - | Typ_bidir (typ1, typ2) -> - typ_nexps typ1 @ typ_nexps typ2 + | Typ_fn (arg_typs, ret_typ) -> List.concat (List.map typ_nexps arg_typs) @ typ_nexps ret_typ + | Typ_bidir (typ1, typ2) -> typ_nexps typ1 @ typ_nexps typ2 + and typ_arg_nexps (A_aux (typ_arg_aux, l)) = match typ_arg_aux with | A_nexp n -> [n] | A_typ typ -> typ_nexps typ | A_bool nc -> constraint_nexps nc | A_order ord -> [] + and constraint_nexps (NC_aux (nc_aux, l)) = match nc_aux with - | NC_equal (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_le (n1, n2) | NC_bounded_gt (n1, n2) | NC_bounded_lt (n1, n2) | NC_not_equal (n1, n2) -> - [n1; n2] + | NC_equal (n1, n2) + | NC_bounded_ge (n1, n2) + | NC_bounded_le (n1, n2) + | NC_bounded_gt (n1, n2) + | NC_bounded_lt (n1, n2) + | NC_not_equal (n1, n2) -> + [n1; n2] | NC_set _ | NC_true | NC_false | NC_var _ -> [] | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> constraint_nexps nc1 @ constraint_nexps nc2 | NC_app (_, args) -> List.concat (List.map typ_arg_nexps args) @@ -339,21 +328,20 @@ and constraint_nexps (NC_aux (nc_aux, l)) = let rec replace_nexp_typ nexp nexp' (Typ_aux (typ_aux, l) as typ) = let rep_typ = replace_nexp_typ nexp nexp' in match typ_aux with - | Typ_internal_unknown - | Typ_id _ - | Typ_var _ - -> typ + | Typ_internal_unknown | Typ_id _ | Typ_var _ -> typ | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map rep_typ typs), l) | Typ_app (f, args) -> Typ_aux (Typ_app (f, List.map (replace_nexp_typ_arg nexp nexp') args), l) | Typ_exist (kids, nc, typ) -> Typ_aux (Typ_exist (kids, nc, rep_typ typ), l) | Typ_fn (arg_typs, ret_typ) -> Typ_aux (Typ_fn (List.map rep_typ arg_typs, rep_typ ret_typ), l) | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (rep_typ typ1, rep_typ typ2), l) + and replace_nexp_typ_arg nexp nexp' (A_aux (typ_arg_aux, l) as arg) = match typ_arg_aux with - | A_nexp n -> if Nexp.compare n nexp == 0 then (A_aux (A_nexp nexp', l)) else arg + | A_nexp n -> if Nexp.compare n nexp == 0 then A_aux (A_nexp nexp', l) else arg | A_typ typ -> A_aux (A_typ (replace_nexp_typ nexp nexp' typ), l) | A_bool nc -> A_aux (A_bool (replace_nexp_nc nexp nexp' nc), l) | A_order _ -> arg + and replace_nexp_nc nexp nexp' (NC_aux (nc_aux, l) as nc) = let rep_nc = replace_nexp_nc nexp nexp' in let rep n = if Nexp.compare n nexp == 0 then nexp' else n in @@ -373,15 +361,13 @@ and replace_nexp_nc nexp nexp' (NC_aux (nc_aux, l) as nc) = let rec replace_nc_typ nc nc' (Typ_aux (typ_aux, l) as typ) = let rep_typ = replace_nc_typ nc nc' in match typ_aux with - | Typ_internal_unknown - | Typ_id _ - | Typ_var _ - -> typ + | Typ_internal_unknown | Typ_id _ | Typ_var _ -> typ | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map rep_typ typs), l) | Typ_app (f, args) -> Typ_aux (Typ_app (f, List.map (replace_nc_typ_arg nc nc') args), l) | Typ_exist (kids, nc, typ) -> Typ_aux (Typ_exist (kids, nc, rep_typ typ), l) | Typ_fn (arg_typs, ret_typ) -> Typ_aux (Typ_fn (List.map rep_typ arg_typs, rep_typ ret_typ), l) | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (rep_typ typ1, rep_typ typ2), l) + and replace_nc_typ_arg nc nc' (A_aux (typ_arg_aux, l) as arg) = match typ_arg_aux with | A_nexp _ -> arg @@ -394,15 +380,11 @@ and replace_nc_typ_arg nc nc' (A_aux (typ_arg_aux, l) as arg) = let rec nexp_power_variables (Nexp_aux (aux, _)) = match aux with | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> - KidSet.union (nexp_power_variables n1) (nexp_power_variables n2) - | Nexp_neg n -> - nexp_power_variables n - | Nexp_id _ | Nexp_var _ | Nexp_constant _ -> - KidSet.empty - | Nexp_app (_, ns) -> - List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables ns) - | Nexp_exp n -> - tyvars_of_nexp n + KidSet.union (nexp_power_variables n1) (nexp_power_variables n2) + | Nexp_neg n -> nexp_power_variables n + | Nexp_id _ | Nexp_var _ | Nexp_constant _ -> KidSet.empty + | Nexp_app (_, ns) -> List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables ns) + | Nexp_exp n -> tyvars_of_nexp n let constraint_power_variables nc = List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables (constraint_nexps nc)) @@ -417,27 +399,28 @@ let ex_counter = ref 0 let fresh_existential l k = let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#"), l) in - incr ex_counter; mk_kopt ~loc:l k fresh + incr ex_counter; + mk_kopt ~loc:l k fresh -let named_existential l k = function - | Some n -> mk_kopt ~loc:l k (mk_kid n) - | None -> fresh_existential l k +let named_existential l k = function Some n -> mk_kopt ~loc:l k (mk_kid n) | None -> fresh_existential l k -let destruct_exist_plain ?name:(name=None) typ = +let destruct_exist_plain ?(name = None) typ = match typ with | Typ_aux (Typ_exist ([kopt], nc, typ), l) -> - let kid, fresh = kopt_kid kopt, named_existential l (unaux_kind (kopt_kind kopt)) name in - let nc = constraint_subst kid (arg_kopt fresh) nc in - let typ = typ_subst kid (arg_kopt fresh) typ in - Some ([fresh], nc, typ) + let kid, fresh = (kopt_kid kopt, named_existential l (unaux_kind (kopt_kind kopt)) name) in + let nc = constraint_subst kid (arg_kopt fresh) nc in + let typ = typ_subst kid (arg_kopt fresh) typ in + Some ([fresh], nc, typ) | Typ_aux (Typ_exist (kopts, nc, typ), l) -> - let add_num i = match name with Some n -> Some (n ^ string_of_int i) | None -> None in - let fresh_kopts = - List.mapi (fun i kopt -> (kopt_kid kopt, named_existential (kopt_loc kopt) (unaux_kind (kopt_kind kopt)) (add_num i))) kopts - in - let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_kopt fresh) nc) nc fresh_kopts in - let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_kopt fresh) typ) typ fresh_kopts in - Some (List.map snd fresh_kopts, nc, typ) + let add_num i = match name with Some n -> Some (n ^ string_of_int i) | None -> None in + let fresh_kopts = + List.mapi + (fun i kopt -> (kopt_kid kopt, named_existential (kopt_loc kopt) (unaux_kind (kopt_kind kopt)) (add_num i))) + kopts + in + let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_kopt fresh) nc) nc fresh_kopts in + let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_kopt fresh) typ) typ fresh_kopts in + Some (List.map snd fresh_kopts, nc, typ) | _ -> None (** Destructure and canonicalise a numeric type into a list of type @@ -448,52 +431,52 @@ let destruct_exist_plain ?name:(name=None) typ = - int => ['n], true, 'n (where x is fresh) - atom('n) => [], true, 'n **) -let destruct_numeric ?name:(name=None) typ = - match destruct_exist_plain ~name:name typ, typ with +let destruct_numeric ?(name = None) typ = + match (destruct_exist_plain ~name typ, typ) with | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" -> - Some (List.map kopt_kid kids, nc, nexp) - | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" -> - Some ([], nc_true, nexp) + Some (List.map kopt_kid kids, nc, nexp) + | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" -> Some ([], nc_true, nexp) | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), l) when string_of_id id = "range" -> - let kid = kopt_kid (named_existential l K_int name) in - Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid) + let kid = kopt_kid (named_existential l K_int name) in + Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid) | None, Typ_aux (Typ_id id, l) when string_of_id id = "nat" -> - let kid = kopt_kid (named_existential l K_int name) in - Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid) + let kid = kopt_kid (named_existential l K_int name) in + Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid) | None, Typ_aux (Typ_id id, l) when string_of_id id = "int" -> - let kid = kopt_kid (named_existential l K_int name) in - Some ([kid], nc_true, nvar kid) + let kid = kopt_kid (named_existential l K_int name) in + Some ([kid], nc_true, nvar kid) | _, _ -> None -let destruct_boolean ?name:(name=None) = function +let destruct_boolean ?(name = None) = function | Typ_aux (Typ_id (Id_aux (Id "bool", _)), l) -> - let kid = kopt_kid (fresh_existential l K_bool) in - Some (kid, nc_var kid) + let kid = kopt_kid (fresh_existential l K_bool) in + Some (kid, nc_var kid) | _ -> None -let destruct_exist ?name:(name=None) typ = - match destruct_numeric ~name:name typ with +let destruct_exist ?(name = None) typ = + match destruct_numeric ~name typ with | Some (kids, nc, nexp) -> Some (List.map (mk_kopt K_int) kids, nc, atom_typ nexp) - | None -> - match destruct_boolean ~name:name typ with - | Some (kid, nc) -> Some ([mk_kopt K_bool kid], nc_true, atom_bool_typ nc) - | None -> destruct_exist_plain ~name:name typ + | None -> ( + match destruct_boolean ~name typ with + | Some (kid, nc) -> Some ([mk_kopt K_bool kid], nc_true, atom_bool_typ nc) + | None -> destruct_exist_plain ~name typ + ) let adding = Util.("Adding " |> darkgray |> clear) let counter = ref 0 - -let fresh_kid ?kid:(kid=mk_kid "") env = + +let fresh_kid ?(kid = mk_kid "") env = let suffix = if Kid.compare kid (mk_kid "") = 0 then "#" else "#" ^ string_of_id (id_of_kid kid) in let fresh = Kid_aux (Var ("'fv" ^ string_of_int !counter ^ suffix), Parse_ast.Unknown) in - incr counter; fresh + incr counter; + fresh let freshen_kid env kid (typq, typ) = - let fresh = fresh_kid ~kid:kid env in + let fresh = fresh_kid ~kid env in if KidSet.mem kid (KidSet.of_list (List.map kopt_kid (quant_kopts typq))) then (typquant_subst_kid kid fresh typq, subst_kid typ_subst kid fresh typ) - else - (typq, typ) + else (typq, typ) let freshen_bind env bind = List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) @@ -510,7 +493,7 @@ module Env : sig val define_val_spec : id -> t -> t val get_defined_val_specs : t -> IdSet.t val get_val_spec : id -> t -> typquant * typ - val get_val_specs : t -> (typquant * typ ) Bindings.t + val get_val_specs : t -> (typquant * typ) Bindings.t val get_val_spec_orig : id -> t -> typquant * typ val get_outcome : l -> id -> t -> typquant * typ * kinded_id list * id list * t val get_outcome_instantiation : t -> (Ast.l * typ) KBindings.t @@ -537,7 +520,7 @@ module Env : sig val get_variants : t -> (typquant * type_union list) Bindings.t val get_scattered_variant_env : id -> t -> t val add_union_id : id -> typquant * typ -> t -> t - val get_union_id : id -> t -> typquant * typ + val get_union_id : id -> t -> typquant * typ val is_register : id -> t -> bool val get_register : id -> t -> typ val get_registers : t -> typ Bindings.t @@ -545,7 +528,7 @@ module Env : sig val is_mutable : id -> t -> bool val get_constraints : t -> n_constraint list val get_constraint_reasons : t -> (constraint_reason * n_constraint) list - val add_constraint : ?reason:(Ast.l * string) -> n_constraint -> t -> t + val add_constraint : ?reason:Ast.l * string -> n_constraint -> t -> t val add_typquant : l -> typquant -> t -> t val get_typ_var : kid -> t -> kind_aux val get_typ_var_loc_opt : kid -> t -> Ast.l option @@ -569,7 +552,7 @@ module Env : sig val set_default_order : order -> t -> t val add_enum : id -> id list -> t -> t val get_enum : id -> t -> id list - val get_enums : t -> IdSet.t Bindings.t + val get_enums : t -> IdSet.t Bindings.t val is_enum : id -> t -> bool val get_casts : t -> id list val allow_casts : t -> bool @@ -609,7 +592,8 @@ end = struct type t = env let empty = - { top_val_specs = Bindings.empty; + { + top_val_specs = Bindings.empty; defined_val_specs = IdSet.empty; locals = Bindings.empty; top_letbinds = IdSet.empty; @@ -654,47 +638,60 @@ end = struct variable is renamed. We can't just remove it because it may be referenced by constraints. *) let shadows v env = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 - + let add_typ_var_shadow l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env = if KBindings.mem v env.typ_vars then begin - let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in - let s_l, s_k = KBindings.find v env.typ_vars in - let s_v = Kid_aux (Var (string_of_kid v ^ "#" ^ string_of_int n), l) in - typ_print (lazy (Printf.sprintf "%stype variable (shadowing %s) %s : %s" adding (string_of_kid s_v) (string_of_kid v) (string_of_kind_aux k))); - { env with - constraints = List.map (fun (l, nc) -> (l, constraint_subst v (arg_kopt (mk_kopt s_k s_v)) nc)) env.constraints; + let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in + let s_l, s_k = KBindings.find v env.typ_vars in + let s_v = Kid_aux (Var (string_of_kid v ^ "#" ^ string_of_int n), l) in + typ_print + ( lazy + (Printf.sprintf "%stype variable (shadowing %s) %s : %s" adding (string_of_kid s_v) (string_of_kid v) + (string_of_kind_aux k) + ) + ); + ( { + env with + constraints = + List.map (fun (l, nc) -> (l, constraint_subst v (arg_kopt (mk_kopt s_k s_v)) nc)) env.constraints; typ_vars = KBindings.add v (l, k) (KBindings.add s_v (s_l, s_k) env.typ_vars); - locals = Bindings.map (fun (mut, typ) -> mut, typ_subst v (arg_kopt (mk_kopt s_k s_v)) typ) env.locals; - shadow_vars = KBindings.add v (n + 1) env.shadow_vars - }, Some s_v - end + locals = Bindings.map (fun (mut, typ) -> (mut, typ_subst v (arg_kopt (mk_kopt s_k s_v)) typ)) env.locals; + shadow_vars = KBindings.add v (n + 1) env.shadow_vars; + }, + Some s_v + ) + end else begin - typ_print (lazy (adding ^ "type variable " ^ string_of_kid v ^ " : " ^ string_of_kind_aux k)); - { env with typ_vars = KBindings.add v (l, k) env.typ_vars }, None - end + typ_print (lazy (adding ^ "type variable " ^ string_of_kid v ^ " : " ^ string_of_kind_aux k)); + ({ env with typ_vars = KBindings.add v (l, k) env.typ_vars }, None) + end let add_typ_var l kopt env = fst (add_typ_var_shadow l kopt env) - + let get_typ_var_loc_opt kid env = - match KBindings.find_opt kid env.typ_vars with - | Some (l, _) -> Some l - | None -> None - + match KBindings.find_opt kid env.typ_vars with Some (l, _) -> Some l | None -> None + let get_typ_var kid env = - try snd (KBindings.find kid env.typ_vars) with - | Not_found -> typ_error env (kid_loc kid) ("No type variable " ^ string_of_kid kid) + try snd (KBindings.find kid env.typ_vars) + with Not_found -> typ_error env (kid_loc kid) ("No type variable " ^ string_of_kid kid) let get_typ_vars env = KBindings.map snd env.typ_vars let get_typ_var_locs env = KBindings.map fst env.typ_vars let k_counter = ref 0 - let k_name () = let kid = mk_kid ("k#" ^ string_of_int !k_counter) in incr k_counter; kid + let k_name () = + let kid = mk_kid ("k#" ^ string_of_int !k_counter) in + incr k_counter; + kid let kinds_typq kinds = mk_typquant (List.map (fun k -> mk_qi_id k (k_name ())) kinds) let builtin_typs = - List.fold_left (fun m (name, kinds) -> Bindings.add (mk_id name) (kinds_typq kinds) m) Bindings.empty - [ ("range", [K_int; K_int]); + List.fold_left + (fun m (name, kinds) -> Bindings.add (mk_id name) (kinds_typq kinds) m) + Bindings.empty + [ + ("range", [K_int; K_int]); ("atom", [K_int]); ("implicit", [K_int]); ("vector", [K_int; K_order; K_type]); @@ -718,43 +715,38 @@ end = struct ] let bound_typ_id env id = - Bindings.mem id env.typ_synonyms - || Bindings.mem id env.variants - || Bindings.mem id env.records - || Bindings.mem id env.enums - || Bindings.mem id builtin_typs + Bindings.mem id env.typ_synonyms || Bindings.mem id env.variants || Bindings.mem id env.records + || Bindings.mem id env.enums || Bindings.mem id builtin_typs let get_binding_loc env id = - let has_key id' = Id.compare id id' = 0 in - if Bindings.mem id builtin_typs then - None - else if Bindings.mem id env.variants then - Some (id_loc (fst (Bindings.find_first has_key env.variants))) - else if Bindings.mem id env.records then - Some (id_loc (fst (Bindings.find_first has_key env.records))) - else if Bindings.mem id env.enums then - Some (id_loc (fst (Bindings.find_first has_key env.enums))) - else if Bindings.mem id env.typ_synonyms then - Some (id_loc (fst (Bindings.find_first has_key env.typ_synonyms))) - else - None + let has_key id' = Id.compare id id' = 0 in + if Bindings.mem id builtin_typs then None + else if Bindings.mem id env.variants then Some (id_loc (fst (Bindings.find_first has_key env.variants))) + else if Bindings.mem id env.records then Some (id_loc (fst (Bindings.find_first has_key env.records))) + else if Bindings.mem id env.enums then Some (id_loc (fst (Bindings.find_first has_key env.enums))) + else if Bindings.mem id env.typ_synonyms then Some (id_loc (fst (Bindings.find_first has_key env.typ_synonyms))) + else None let already_bound str id env = match get_binding_loc env id with | Some l -> - typ_raise env (id_loc id) (Err_inner (Err_other ("Cannot create " ^ str ^ " type " ^ string_of_id id ^ ", name is already bound"), - l, "", Some "previous binding", Err_other "")) + typ_raise env (id_loc id) + (Err_inner + ( Err_other ("Cannot create " ^ str ^ " type " ^ string_of_id id ^ ", name is already bound"), + l, + "", + Some "previous binding", + Err_other "" + ) + ) | None -> - let suffix = if Bindings.mem id builtin_typs then " as a built-in type" else "" in - typ_error env (id_loc id) ("Cannot create " ^ str ^ " type " ^ string_of_id id ^ ", name is already bound" ^ suffix) - - let bound_ctor_fn env id = - Bindings.mem id env.top_val_specs - || Bindings.mem id env.union_ids - - let get_overloads id env = - try Bindings.find id env.overloads with - | Not_found -> [] + let suffix = if Bindings.mem id builtin_typs then " as a built-in type" else "" in + typ_error env (id_loc id) + ("Cannot create " ^ str ^ " type " ^ string_of_id id ^ ", name is already bound" ^ suffix) + + let bound_ctor_fn env id = Bindings.mem id env.top_val_specs || Bindings.mem id env.union_ids + + let get_overloads id env = try Bindings.find id env.overloads with Not_found -> [] let add_overloads id ids env = typ_print (lazy (adding ^ "overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]")); @@ -762,186 +754,232 @@ end = struct { env with overloads = Bindings.add id (existing @ ids) env.overloads } let infer_kind env id = - if Bindings.mem id builtin_typs then - Bindings.find id builtin_typs - else if Bindings.mem id env.variants then - fst (Bindings.find id env.variants) - else if Bindings.mem id env.records then - fst (Bindings.find id env.records) - else if Bindings.mem id env.enums then - mk_typquant [] + if Bindings.mem id builtin_typs then Bindings.find id builtin_typs + else if Bindings.mem id env.variants then fst (Bindings.find id env.variants) + else if Bindings.mem id env.records then fst (Bindings.find id env.records) + else if Bindings.mem id env.enums then mk_typquant [] else if Bindings.mem id env.typ_synonyms then typ_error env (id_loc id) ("Cannot infer kind of type synonym " ^ string_of_id id) - else - typ_error env (id_loc id) ("Cannot infer kind of " ^ string_of_id id) + else typ_error env (id_loc id) ("Cannot infer kind of " ^ string_of_id id) let check_args_typquant id env args typq = let kopts, ncs = quant_split typq in let rec subst_args kopts args = - match kopts, args with + match (kopts, args) with | kopt :: kopts, (A_aux (A_nexp _, _) as arg) :: args when is_int_kopt kopt -> - List.map (constraint_subst (kopt_kid kopt) arg) (subst_args kopts args) - | kopt :: kopts, A_aux (A_typ arg, _) :: args when is_typ_kopt kopt -> - subst_args kopts args - | kopt :: kopts, A_aux (A_order arg, _) :: args when is_order_kopt kopt -> - subst_args kopts args - | kopt :: kopts, A_aux (A_bool arg, _) :: args when is_bool_kopt kopt -> - subst_args kopts args + List.map (constraint_subst (kopt_kid kopt) arg) (subst_args kopts args) + | kopt :: kopts, A_aux (A_typ arg, _) :: args when is_typ_kopt kopt -> subst_args kopts args + | kopt :: kopts, A_aux (A_order arg, _) :: args when is_order_kopt kopt -> subst_args kopts args + | kopt :: kopts, A_aux (A_bool arg, _) :: args when is_bool_kopt kopt -> subst_args kopts args | [], [] -> ncs - | _, A_aux (_, l) :: _ -> typ_error env l ("Error when processing type quantifer arguments " ^ string_of_typquant typq) - | _, _ -> typ_error env Parse_ast.Unknown ("Error when processing type quantifer arguments " ^ string_of_typquant typq) + | _, A_aux (_, l) :: _ -> + typ_error env l ("Error when processing type quantifer arguments " ^ string_of_typquant typq) + | _, _ -> + typ_error env Parse_ast.Unknown ("Error when processing type quantifer arguments " ^ string_of_typquant typq) in let ncs = subst_args kopts args in - if (match env.prove with Some prover -> List.for_all (prover env) ncs | None -> false) - then () - else typ_error env (id_loc id) ("Could not prove " ^ string_of_list ", " string_of_n_constraint ncs ^ " for type constructor " ^ string_of_id id) + if match env.prove with Some prover -> List.for_all (prover env) ncs | None -> false then () + else + typ_error env (id_loc id) + ("Could not prove " + ^ string_of_list ", " string_of_n_constraint ncs + ^ " for type constructor " ^ string_of_id id + ) let mk_synonym typq typ_arg = let kopts, ncs = quant_split typq in - let kopts = List.map (fun kopt -> kopt, fresh_existential (kopt_loc kopt) (unaux_kind (kopt_kind kopt))) kopts in - let ncs = List.map (fun nc -> List.fold_left (fun nc (kopt, fresh) -> constraint_subst (kopt_kid kopt) (arg_kopt fresh) nc) nc kopts) ncs in - let typ_arg = List.fold_left (fun typ_arg (kopt, fresh) -> typ_arg_subst (kopt_kid kopt) (arg_kopt fresh) typ_arg) typ_arg kopts in + let kopts = List.map (fun kopt -> (kopt, fresh_existential (kopt_loc kopt) (unaux_kind (kopt_kind kopt)))) kopts in + let ncs = + List.map + (fun nc -> + List.fold_left (fun nc (kopt, fresh) -> constraint_subst (kopt_kid kopt) (arg_kopt fresh) nc) nc kopts + ) + ncs + in + let typ_arg = + List.fold_left (fun typ_arg (kopt, fresh) -> typ_arg_subst (kopt_kid kopt) (arg_kopt fresh) typ_arg) typ_arg kopts + in let kopts = List.map snd kopts in let rec subst_args env l kopts args = - match kopts, args with + match (kopts, args) with | kopt :: kopts, A_aux (A_nexp arg, _) :: args when is_int_kopt kopt -> - let typ_arg, ncs = subst_args env l kopts args in - typ_arg_subst (kopt_kid kopt) (arg_nexp arg) typ_arg, - List.map (constraint_subst (kopt_kid kopt) (arg_nexp arg)) ncs + let typ_arg, ncs = subst_args env l kopts args in + ( typ_arg_subst (kopt_kid kopt) (arg_nexp arg) typ_arg, + List.map (constraint_subst (kopt_kid kopt) (arg_nexp arg)) ncs + ) | kopt :: kopts, A_aux (A_typ arg, _) :: args when is_typ_kopt kopt -> - let typ_arg, ncs = subst_args env l kopts args in - typ_arg_subst (kopt_kid kopt) (arg_typ arg) typ_arg, ncs + let typ_arg, ncs = subst_args env l kopts args in + (typ_arg_subst (kopt_kid kopt) (arg_typ arg) typ_arg, ncs) | kopt :: kopts, A_aux (A_order arg, _) :: args when is_order_kopt kopt -> - let typ_arg, ncs = subst_args env l kopts args in - typ_arg_subst (kopt_kid kopt) (arg_order arg) typ_arg, ncs + let typ_arg, ncs = subst_args env l kopts args in + (typ_arg_subst (kopt_kid kopt) (arg_order arg) typ_arg, ncs) | kopt :: kopts, A_aux (A_bool arg, _) :: args when is_bool_kopt kopt -> - let typ_arg, ncs = subst_args env l kopts args in - typ_arg_subst (kopt_kid kopt) (arg_bool arg) typ_arg, ncs - | [], [] -> typ_arg, ncs + let typ_arg, ncs = subst_args env l kopts args in + (typ_arg_subst (kopt_kid kopt) (arg_bool arg) typ_arg, ncs) + | [], [] -> (typ_arg, ncs) | _, _ -> typ_error env l "Synonym applied to bad arguments" in fun l env args -> - let typ_arg, ncs = subst_args env l kopts args in - if (match env.prove with Some prover -> List.for_all (prover env) ncs | None -> false) - then typ_arg - else typ_error env l ("Could not prove constraints " ^ string_of_list ", " string_of_n_constraint ncs - ^ " in type synonym " ^ string_of_typ_arg typ_arg - ^ " with " ^ Util.string_of_list ", " string_of_n_constraint (List.map snd env.constraints)) + let typ_arg, ncs = subst_args env l kopts args in + if match env.prove with Some prover -> List.for_all (prover env) ncs | None -> false then typ_arg + else + typ_error env l + ("Could not prove constraints " + ^ string_of_list ", " string_of_n_constraint ncs + ^ " in type synonym " ^ string_of_typ_arg typ_arg ^ " with " + ^ Util.string_of_list ", " string_of_n_constraint (List.map snd env.constraints) + ) let get_typ_synonym id env = - match Bindings.find_opt id env.typ_synonyms with - | Some (typq, arg) -> mk_synonym typq arg - | None -> raise Not_found + match Bindings.find_opt id env.typ_synonyms with Some (typq, arg) -> mk_synonym typq arg | None -> raise Not_found let get_typ_synonyms env = env.typ_synonyms - let get_constraints env = List.map snd (env.constraints) + let get_constraints env = List.map snd env.constraints let get_constraint_reasons env = env.constraints let wf_debug str f x exs = - typ_debug ~level:2 (lazy ("wf_" ^ str ^ ": " ^ f x ^ " exs: " ^ Util.string_of_list ", " string_of_kid (KidSet.elements exs))) + typ_debug ~level:2 + (lazy ("wf_" ^ str ^ ": " ^ f x ^ " exs: " ^ Util.string_of_list ", " string_of_kid (KidSet.elements exs))) (* Check if a type, order, n-expression or constraint is well-formed. Throws a type error if the type is badly formed. *) - let rec wf_typ' ?exs:(exs=KidSet.empty) env (Typ_aux (typ_aux, l) as typ) = + let rec wf_typ' ?(exs = KidSet.empty) env (Typ_aux (typ_aux, l) as typ) = match typ_aux with | Typ_id id when bound_typ_id env id -> - let typq = infer_kind env id in - if quant_kopts typq != [] - then typ_error env l ("Type constructor " ^ string_of_id id ^ " expected " ^ string_of_typquant typq) - else () + let typq = infer_kind env id in + if quant_kopts typq != [] then + typ_error env l ("Type constructor " ^ string_of_id id ^ " expected " ^ string_of_typquant typq) + else () | Typ_id id -> typ_error env l ("Undefined type " ^ string_of_id id) | Typ_var kid -> begin - match KBindings.find kid env.typ_vars with - | (_, K_type) -> () - | (_, k) -> typ_error env l ("Type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ - ^ " is " ^ string_of_kind_aux k ^ " rather than Type") - | exception Not_found -> - typ_error env l ("Unbound type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ) - end - | Typ_fn (arg_typs, ret_typ) -> List.iter (wf_typ' ~exs:exs env) arg_typs; wf_typ' ~exs:exs env ret_typ + match KBindings.find kid env.typ_vars with + | _, K_type -> () + | _, k -> + typ_error env l + ("Type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ ^ " is " ^ string_of_kind_aux k + ^ " rather than Type" + ) + | exception Not_found -> + typ_error env l ("Unbound type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ) + end + | Typ_fn (arg_typs, ret_typ) -> + List.iter (wf_typ' ~exs env) arg_typs; + wf_typ' ~exs env ret_typ | Typ_bidir (typ1, typ2) when strip_typ typ1 = strip_typ typ2 -> - typ_error env l "Bidirectional types cannot be the same on both sides" - | Typ_bidir (typ1, typ2) -> wf_typ' ~exs:exs env typ1; wf_typ' ~exs:exs env typ2 - | Typ_tuple typs -> List.iter (wf_typ' ~exs:exs env) typs - | Typ_app (id, [A_aux (A_nexp _, _) as arg]) when string_of_id id = "implicit" -> - wf_typ_arg ~exs:exs env arg + typ_error env l "Bidirectional types cannot be the same on both sides" + | Typ_bidir (typ1, typ2) -> + wf_typ' ~exs env typ1; + wf_typ' ~exs env typ2 + | Typ_tuple typs -> List.iter (wf_typ' ~exs env) typs + | Typ_app (id, [(A_aux (A_nexp _, _) as arg)]) when string_of_id id = "implicit" -> wf_typ_arg ~exs env arg | Typ_app (id, args) when bound_typ_id env id -> - List.iter (wf_typ_arg ~exs:exs env) args; - check_args_typquant id env args (infer_kind env id) + List.iter (wf_typ_arg ~exs env) args; + check_args_typquant id env args (infer_kind env id) | Typ_app (id, _) -> typ_error env l ("Undefined type " ^ string_of_id id) - | Typ_exist ([], _, _) -> typ_error env l ("Existential must have some type variables") + | Typ_exist ([], _, _) -> typ_error env l "Existential must have some type variables" | Typ_exist (kopts, nc, typ) when KidSet.is_empty exs -> - wf_constraint ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env nc; - wf_typ' ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env typ - | Typ_exist (_, _, _) -> typ_error env l ("Nested existentials are not allowed") + wf_constraint ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env nc; + wf_typ' ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env typ + | Typ_exist (_, _, _) -> typ_error env l "Nested existentials are not allowed" | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" - and wf_typ_arg ?exs:(exs=KidSet.empty) env (A_aux (typ_arg_aux, _)) = + + and wf_typ_arg ?(exs = KidSet.empty) env (A_aux (typ_arg_aux, _)) = match typ_arg_aux with - | A_nexp nexp -> wf_nexp ~exs:exs env nexp - | A_typ typ -> wf_typ' ~exs:exs env typ + | A_nexp nexp -> wf_nexp ~exs env nexp + | A_typ typ -> wf_typ' ~exs env typ | A_order ord -> wf_order env ord - | A_bool nc -> wf_constraint ~exs:exs env nc - and wf_nexp ?exs:(exs=KidSet.empty) env (Nexp_aux (nexp_aux, l) as nexp) = + | A_bool nc -> wf_constraint ~exs env nc + + and wf_nexp ?(exs = KidSet.empty) env (Nexp_aux (nexp_aux, l) as nexp) = wf_debug "nexp" string_of_nexp nexp exs; match nexp_aux with | Nexp_id id -> typ_error env l ("Undefined type synonym " ^ string_of_id id) | Nexp_var kid when KidSet.mem kid exs -> () - | Nexp_var kid -> - begin match get_typ_var kid env with - | K_int -> () - | kind -> typ_error env l ("Constraint is badly formed, " - ^ string_of_kid kid ^ " has kind " - ^ string_of_kind_aux kind ^ " but should have kind Int") - end + | Nexp_var kid -> begin + match get_typ_var kid env with + | K_int -> () + | kind -> + typ_error env l + ("Constraint is badly formed, " ^ string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind + ^ " but should have kind Int" + ) + end | Nexp_constant _ -> () - | Nexp_app (id, nexps) -> - List.iter (fun n -> wf_nexp ~exs:exs env n) nexps - | Nexp_times (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2 - | Nexp_sum (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2 - | Nexp_minus (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2 - | Nexp_exp nexp -> wf_nexp ~exs:exs env nexp (* MAYBE: Could put restrictions on what is allowed here *) - | Nexp_neg nexp -> wf_nexp ~exs:exs env nexp + | Nexp_app (id, nexps) -> List.iter (fun n -> wf_nexp ~exs env n) nexps + | Nexp_times (nexp1, nexp2) -> + wf_nexp ~exs env nexp1; + wf_nexp ~exs env nexp2 + | Nexp_sum (nexp1, nexp2) -> + wf_nexp ~exs env nexp1; + wf_nexp ~exs env nexp2 + | Nexp_minus (nexp1, nexp2) -> + wf_nexp ~exs env nexp1; + wf_nexp ~exs env nexp2 + | Nexp_exp nexp -> wf_nexp ~exs env nexp (* MAYBE: Could put restrictions on what is allowed here *) + | Nexp_neg nexp -> wf_nexp ~exs env nexp + and wf_order env (Ord_aux (ord_aux, l)) = match ord_aux with - | Ord_var kid -> - begin match get_typ_var kid env with - | K_order -> () - | kind -> typ_error env l ("Order is badly formed, " - ^ string_of_kid kid ^ " has kind " - ^ string_of_kind_aux kind ^ " but should have kind Order") - end + | Ord_var kid -> begin + match get_typ_var kid env with + | K_order -> () + | kind -> + typ_error env l + ("Order is badly formed, " ^ string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind + ^ " but should have kind Order" + ) + end | Ord_inc | Ord_dec -> () - and wf_constraint ?exs:(exs=KidSet.empty) env (NC_aux (nc_aux, l) as nc) = + + and wf_constraint ?(exs = KidSet.empty) env (NC_aux (nc_aux, l) as nc) = wf_debug "constraint" string_of_n_constraint nc exs; match nc_aux with - | NC_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 - | NC_not_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 - | NC_bounded_ge (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 - | NC_bounded_gt (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 - | NC_bounded_le (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 - | NC_bounded_lt (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 + | NC_equal (n1, n2) -> + wf_nexp ~exs env n1; + wf_nexp ~exs env n2 + | NC_not_equal (n1, n2) -> + wf_nexp ~exs env n1; + wf_nexp ~exs env n2 + | NC_bounded_ge (n1, n2) -> + wf_nexp ~exs env n1; + wf_nexp ~exs env n2 + | NC_bounded_gt (n1, n2) -> + wf_nexp ~exs env n1; + wf_nexp ~exs env n2 + | NC_bounded_le (n1, n2) -> + wf_nexp ~exs env n1; + wf_nexp ~exs env n2 + | NC_bounded_lt (n1, n2) -> + wf_nexp ~exs env n1; + wf_nexp ~exs env n2 | NC_set (kid, _) when KidSet.mem kid exs -> () - | NC_set (kid, _) -> - begin match get_typ_var kid env with - | K_int -> () - | kind -> typ_error env l ("Set constraint is badly formed, " - ^ string_of_kid kid ^ " has kind " - ^ string_of_kind_aux kind ^ " but should have kind Int") - end - | NC_or (nc1, nc2) -> wf_constraint ~exs:exs env nc1; wf_constraint ~exs:exs env nc2 - | NC_and (nc1, nc2) -> wf_constraint ~exs:exs env nc1; wf_constraint ~exs:exs env nc2 - | NC_app (id, args) -> List.iter (wf_typ_arg ~exs:exs env) args + | NC_set (kid, _) -> begin + match get_typ_var kid env with + | K_int -> () + | kind -> + typ_error env l + ("Set constraint is badly formed, " ^ string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind + ^ " but should have kind Int" + ) + end + | NC_or (nc1, nc2) -> + wf_constraint ~exs env nc1; + wf_constraint ~exs env nc2 + | NC_and (nc1, nc2) -> + wf_constraint ~exs env nc1; + wf_constraint ~exs env nc2 + | NC_app (id, args) -> List.iter (wf_typ_arg ~exs env) args | NC_var kid when KidSet.mem kid exs -> () - | NC_var kid -> - begin match get_typ_var kid env with - | K_bool -> () - | kind -> typ_error env l (string_of_kid kid ^ " has kind " - ^ string_of_kind_aux kind ^ " but should have kind Bool") - end + | NC_var kid -> begin + match get_typ_var kid env with + | K_bool -> () + | kind -> + typ_error env l (string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind ^ " but should have kind Bool") + end | NC_true | NC_false -> () - + let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) = match aux with | NC_or (nc1, nc2) -> NC_aux (NC_or (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l) @@ -952,35 +990,44 @@ end = struct | NC_bounded_lt (n1, n2) -> NC_aux (NC_bounded_lt (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_bounded_ge (n1, n2) -> NC_aux (NC_bounded_ge (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_bounded_gt (n1, n2) -> NC_aux (NC_bounded_gt (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) - | NC_app (id, args) -> - (try - begin match get_typ_synonym id env l env args with - | A_aux (A_bool nc, _) -> expand_constraint_synonyms env nc - | arg -> typ_error env l ("Expected Bool when expanding synonym " ^ string_of_id id ^ " got " ^ string_of_typ_arg arg) + | NC_app (id, args) -> ( + try + begin + match get_typ_synonym id env l env args with + | A_aux (A_bool nc, _) -> expand_constraint_synonyms env nc + | arg -> + typ_error env l + ("Expected Bool when expanding synonym " ^ string_of_id id ^ " got " ^ string_of_typ_arg arg) end - with Not_found -> NC_aux (NC_app (id, List.map (expand_arg_synonyms env) args), l)) + with Not_found -> NC_aux (NC_app (id, List.map (expand_arg_synonyms env) args), l) + ) | NC_true | NC_false | NC_var _ | NC_set _ -> nc and expand_nexp_synonyms env (Nexp_aux (aux, l) as nexp) = match aux with - | Nexp_app (id, args) -> - (try - begin match get_typ_synonym id env l env [] with - | A_aux (A_nexp nexp, _) -> expand_nexp_synonyms env nexp - | _ -> typ_error env l ("Expected Int when expanding synonym " ^ string_of_id id) + | Nexp_app (id, args) -> ( + try + begin + match get_typ_synonym id env l env [] with + | A_aux (A_nexp nexp, _) -> expand_nexp_synonyms env nexp + | _ -> typ_error env l ("Expected Int when expanding synonym " ^ string_of_id id) end - with - | Not_found -> Nexp_aux (Nexp_app (id, List.map (expand_nexp_synonyms env) args), l)) - | Nexp_id id -> - (try - begin match get_typ_synonym id env l env [] with - | A_aux (A_nexp nexp, _) -> expand_nexp_synonyms env nexp - | _ -> typ_error env l ("Expected Int when expanding synonym " ^ string_of_id id) + with Not_found -> Nexp_aux (Nexp_app (id, List.map (expand_nexp_synonyms env) args), l) + ) + | Nexp_id id -> ( + try + begin + match get_typ_synonym id env l env [] with + | A_aux (A_nexp nexp, _) -> expand_nexp_synonyms env nexp + | _ -> typ_error env l ("Expected Int when expanding synonym " ^ string_of_id id) end - with Not_found -> nexp) - | Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (expand_nexp_synonyms env nexp1, expand_nexp_synonyms env nexp2), l) + with Not_found -> nexp + ) + | Nexp_times (nexp1, nexp2) -> + Nexp_aux (Nexp_times (expand_nexp_synonyms env nexp1, expand_nexp_synonyms env nexp2), l) | Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (expand_nexp_synonyms env nexp1, expand_nexp_synonyms env nexp2), l) - | Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (expand_nexp_synonyms env nexp1, expand_nexp_synonyms env nexp2), l) + | Nexp_minus (nexp1, nexp2) -> + Nexp_aux (Nexp_minus (expand_nexp_synonyms env nexp1, expand_nexp_synonyms env nexp2), l) | Nexp_exp nexp -> Nexp_aux (Nexp_exp (expand_nexp_synonyms env nexp), l) | Nexp_neg nexp -> Nexp_aux (Nexp_neg (expand_nexp_synonyms env nexp), l) | Nexp_var kid -> Nexp_aux (Nexp_var kid, l) @@ -990,56 +1037,61 @@ end = struct match typ with | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l) | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map (expand_synonyms env) typs), l) - | Typ_fn (arg_typs, ret_typ) -> Typ_aux (Typ_fn (List.map (expand_synonyms env) arg_typs, expand_synonyms env ret_typ), l) + | Typ_fn (arg_typs, ret_typ) -> + Typ_aux (Typ_fn (List.map (expand_synonyms env) arg_typs, expand_synonyms env ret_typ), l) | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (expand_synonyms env typ1, expand_synonyms env typ2), l) - | Typ_app (id, args) -> - (try - begin match get_typ_synonym id env l env args with - | A_aux (A_typ typ, _) -> expand_synonyms env typ - | _ -> typ_error env l ("Expected Type when expanding synonym " ^ string_of_id id) + | Typ_app (id, args) -> ( + try + begin + match get_typ_synonym id env l env args with + | A_aux (A_typ typ, _) -> expand_synonyms env typ + | _ -> typ_error env l ("Expected Type when expanding synonym " ^ string_of_id id) end - with - | Not_found -> Typ_aux (Typ_app (id, List.map (expand_arg_synonyms env) args), l)) - | Typ_id id -> - (try - begin match get_typ_synonym id env l env [] with - | A_aux (A_typ typ, _) -> expand_synonyms env typ - | _ -> typ_error env l ("Expected Type when expanding synonym " ^ string_of_id id) + with Not_found -> Typ_aux (Typ_app (id, List.map (expand_arg_synonyms env) args), l) + ) + | Typ_id id -> ( + try + begin + match get_typ_synonym id env l env [] with + | A_aux (A_typ typ, _) -> expand_synonyms env typ + | _ -> typ_error env l ("Expected Type when expanding synonym " ^ string_of_id id) end - with - | Not_found -> Typ_aux (Typ_id id, l)) + with Not_found -> Typ_aux (Typ_id id, l) + ) | Typ_exist (kopts, nc, typ) -> - let nc = expand_constraint_synonyms env nc in - - (* When expanding an existential synonym we need to take care - to add the type variables and constraints to the - environment, so we can check constraints attached to type - synonyms within the existential. Furthermore, we must take - care to avoid clobbering any existing type variables in - scope while doing this. *) - let rebindings = ref [] in - - let rename_kopt (KOpt_aux (KOpt_kind (k, kid), l) as kopt) = - if KBindings.mem kid env.typ_vars then - KOpt_aux (KOpt_kind (k, prepend_kid "syn#" kid), l) - else kopt - in - let add_typ_var env (KOpt_aux (KOpt_kind (k, kid), l)) = - try - let (l, _) = KBindings.find kid env.typ_vars in - rebindings := kid :: !rebindings; - { env with typ_vars = KBindings.add (prepend_kid "syn#" kid) (l, unaux_kind k) env.typ_vars } - with - | Not_found -> - { env with typ_vars = KBindings.add kid (l, unaux_kind k) env.typ_vars } - in - - let env = List.fold_left add_typ_var env kopts in - let kopts = List.map rename_kopt kopts in - let nc = List.fold_left (fun nc kid -> constraint_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) nc) nc !rebindings in - let typ = List.fold_left (fun typ kid -> typ_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) typ) typ !rebindings in - let env = add_constraint nc env in - Typ_aux (Typ_exist (kopts, nc, expand_synonyms env typ), l) + let nc = expand_constraint_synonyms env nc in + + (* When expanding an existential synonym we need to take care + to add the type variables and constraints to the + environment, so we can check constraints attached to type + synonyms within the existential. Furthermore, we must take + care to avoid clobbering any existing type variables in + scope while doing this. *) + let rebindings = ref [] in + + let rename_kopt (KOpt_aux (KOpt_kind (k, kid), l) as kopt) = + if KBindings.mem kid env.typ_vars then KOpt_aux (KOpt_kind (k, prepend_kid "syn#" kid), l) else kopt + in + let add_typ_var env (KOpt_aux (KOpt_kind (k, kid), l)) = + try + let l, _ = KBindings.find kid env.typ_vars in + rebindings := kid :: !rebindings; + { env with typ_vars = KBindings.add (prepend_kid "syn#" kid) (l, unaux_kind k) env.typ_vars } + with Not_found -> { env with typ_vars = KBindings.add kid (l, unaux_kind k) env.typ_vars } + in + + let env = List.fold_left add_typ_var env kopts in + let kopts = List.map rename_kopt kopts in + let nc = + List.fold_left + (fun nc kid -> constraint_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) nc) + nc !rebindings + in + let typ = + List.fold_left (fun typ kid -> typ_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) typ) typ !rebindings + in + let env = add_constraint nc env in + Typ_aux (Typ_exist (kopts, nc, expand_synonyms env typ), l) | Typ_var v -> Typ_aux (Typ_var v, l) and expand_arg_synonyms env (A_aux (typ_arg, l)) = @@ -1054,36 +1106,49 @@ end = struct wf_constraint env constr; let power_vars = constraint_power_variables constr in if KidSet.cardinal power_vars > 1 && !opt_smt_linearize then - typ_error env l ("Cannot add constraint " ^ string_of_n_constraint constr - ^ " where more than two variables appear within an exponential") - else if KidSet.cardinal power_vars = 1 && !opt_smt_linearize then + typ_error env l + ("Cannot add constraint " ^ string_of_n_constraint constr + ^ " where more than two variables appear within an exponential" + ) + else if KidSet.cardinal power_vars = 1 && !opt_smt_linearize then ( let v = KidSet.choose power_vars in let constrs = List.fold_left nc_and nc_true (get_constraints env) in - begin match Constraint.solve_all_smt l constrs v with - | Some solutions -> - typ_print (lazy (Util.("Linearizing " |> red |> clear) ^ string_of_n_constraint constr - ^ " for " ^ string_of_kid v ^ " in " ^ Util.string_of_list ", " Big_int.to_string solutions)); - let linearized = - List.fold_left - (fun c s -> nc_or c (nc_and (nc_eq (nvar v) (nconstant s)) (constraint_subst v (arg_nexp (nconstant s)) constr))) - nc_false solutions - in - typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint linearized)); - { env with constraints = (reason, linearized) :: env.constraints } - | None -> - typ_error env l ("Type variable " ^ string_of_kid v - ^ " must have a finite number of solutions to add " ^ string_of_n_constraint constr) + begin + match Constraint.solve_all_smt l constrs v with + | Some solutions -> + typ_print + ( lazy + (Util.("Linearizing " |> red |> clear) + ^ string_of_n_constraint constr ^ " for " ^ string_of_kid v ^ " in " + ^ Util.string_of_list ", " Big_int.to_string solutions + ) + ); + let linearized = + List.fold_left + (fun c s -> + nc_or c (nc_and (nc_eq (nvar v) (nconstant s)) (constraint_subst v (arg_nexp (nconstant s)) constr)) + ) + nc_false solutions + in + typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint linearized)); + { env with constraints = (reason, linearized) :: env.constraints } + | None -> + typ_error env l + ("Type variable " ^ string_of_kid v ^ " must have a finite number of solutions to add " + ^ string_of_n_constraint constr + ) end - else + ) + else ( match nc_aux with | NC_true -> env | _ -> - typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint constr)); - { env with constraints = (reason, constr) :: env.constraints } + typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint constr)); + { env with constraints = (reason, constr) :: env.constraints } + ) let add_typquant l quant env = - let rec add_quant_item env = function - | QI_aux (qi, _) -> add_quant_item_aux env qi + let rec add_quant_item env = function QI_aux (qi, _) -> add_quant_item_aux env qi and add_quant_item_aux env = function | QI_constraint constr -> add_constraint constr env | QI_id kopt -> add_typ_var l kopt env @@ -1099,117 +1164,131 @@ end = struct try wf_typ' env typ; decr depth - with - | Type_error (env, err_l, err) -> - decr depth; - typ_raise env l (err_because (Err_other "Well-formedness check failed for type", - err_l, - err)) - + with Type_error (env, err_l, err) -> + decr depth; + typ_raise env l (err_because (Err_other "Well-formedness check failed for type", err_l, err)) + let add_typ_synonym id typq arg env = - if bound_typ_id env id then ( - typ_error env (id_loc id) ("Cannot define type synonym " ^ string_of_id id ^ ", as a type or synonym with that name already exists") - ) else ( + if bound_typ_id env id then + typ_error env (id_loc id) + ("Cannot define type synonym " ^ string_of_id id ^ ", as a type or synonym with that name already exists") + else ( let typq = - quant_map_items (function + quant_map_items + (function | QI_aux (QI_constraint nexp, aux) -> QI_aux (QI_constraint (expand_constraint_synonyms env nexp), aux) | quant_item -> quant_item - ) typq in - typ_print (lazy (adding ^ "type synonym " ^ string_of_id id ^ ", " ^ string_of_typquant typq ^ " = " ^ string_of_typ_arg arg)); - { env with typ_synonyms = Bindings.add id (typq, expand_arg_synonyms (add_typquant (id_loc id) typq env) arg) env.typ_synonyms } + ) + typq + in + typ_print + ( lazy + (adding ^ "type synonym " ^ string_of_id id ^ ", " ^ string_of_typquant typq ^ " = " ^ string_of_typ_arg arg) + ); + { + env with + typ_synonyms = + Bindings.add id (typq, expand_arg_synonyms (add_typquant (id_loc id) typq env) arg) env.typ_synonyms; + } ) let get_val_spec_orig id env = - try - Bindings.find id env.top_val_specs - with - | Not_found -> typ_error env (id_loc id) ("No type signature found for " ^ string_of_id id) + try Bindings.find id env.top_val_specs + with Not_found -> typ_error env (id_loc id) ("No type signature found for " ^ string_of_id id) let get_val_spec id env = try let bind = Bindings.find id env.top_val_specs in - typ_debug (lazy ("get_val_spec: Env has " ^ string_of_list ", " (fun (kid, (_, k)) -> string_of_kid kid ^ " => " ^ string_of_kind_aux k) (KBindings.bindings env.typ_vars))); - let bind' = List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) in + typ_debug + ( lazy + ("get_val_spec: Env has " + ^ string_of_list ", " + (fun (kid, (_, k)) -> string_of_kid kid ^ " => " ^ string_of_kind_aux k) + (KBindings.bindings env.typ_vars) + ) + ); + let bind' = + List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) + in typ_debug (lazy ("get_val_spec: freshened to " ^ string_of_bind bind')); bind' - with - | Not_found -> typ_error env (id_loc id) ("No type declaration found for " ^ string_of_id id) + with Not_found -> typ_error env (id_loc id) ("No type declaration found for " ^ string_of_id id) let get_val_specs env = env.top_val_specs - + let add_union_id id bind env = - if bound_ctor_fn env id - then typ_error env (id_loc id) ("A union constructor or function already exists with name " ^ string_of_id id ) - else - begin - typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); - { env with union_ids = Bindings.add id bind env.union_ids } - end - + if bound_ctor_fn env id then + typ_error env (id_loc id) ("A union constructor or function already exists with name " ^ string_of_id id) + else begin + typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); + { env with union_ids = Bindings.add id bind env.union_ids } + end + let get_union_id id env = try let bind = Bindings.find id env.union_ids in List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) - with - | Not_found -> typ_error env (id_loc id) ("No union constructor found for " ^ string_of_id id) + with Not_found -> typ_error env (id_loc id) ("No union constructor found for " ^ string_of_id id) let rec valid_implicits env start = function | Typ_aux (Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var v, _)), _)]), l) :: rest -> - if start then - valid_implicits env true rest - else - typ_error env l "Arguments are invalid, implicit arguments must come before all other arguments" + if start then valid_implicits env true rest + else typ_error env l "Arguments are invalid, implicit arguments must come before all other arguments" | Typ_aux (Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp _, l)]), _) :: rest -> - typ_error env l "Implicit argument must contain a single type variable" + typ_error env l "Implicit argument must contain a single type variable" | _ :: rest -> valid_implicits env false rest | [] -> () let rec update_val_spec id (typq, typ) env = let typq_env = add_typquant (id_loc id) typq env in - begin match expand_synonyms typq_env typ with - | Typ_aux (Typ_fn (arg_typs, ret_typ), l) -> - valid_implicits env true arg_typs; - - (* We perform some canonicalisation for function types where existentials appear on the left, so - ({'n, 'n >= 2, int('n)}, foo) -> bar - would become - forall 'n, 'n >= 2. (int('n), foo) -> bar - this enforces the invariant that all things on the left of functions are 'base types' (i.e. without existentials) - *) - let base_args = List.map (fun typ -> destruct_exist (expand_synonyms typq_env typ)) arg_typs in - let existential_arg typq = function - | None -> typq - | Some (exs, nc, _) -> - List.fold_left (fun typq kopt -> quant_add (mk_qi_kopt kopt) typq) (quant_add (mk_qi_nc nc) typq) exs - in - let typq = List.fold_left existential_arg typq base_args in - let arg_typs = List.map2 (fun typ -> function Some (_, _, typ) -> typ | None -> typ) arg_typs base_args in - let typ = Typ_aux (Typ_fn (arg_typs, ret_typ), l) in - typ_print (lazy (adding ^ "val " ^ string_of_id id ^ " : " ^ string_of_bind (typq, typ))); - { env with top_val_specs = Bindings.add id (typq, typ) env.top_val_specs } - - | Typ_aux (Typ_bidir (typ1, typ2), _) -> - let env = add_mapping id (typq, typ1, typ2) env in - typ_print (lazy (adding ^ "mapping " ^ string_of_id id ^ " : " ^ string_of_bind (typq, typ))); - { env with top_val_specs = Bindings.add id (typq, typ) env.top_val_specs } - - | _ -> typ_error env (id_loc id) "val definition must have a mapping or function type" + begin + match expand_synonyms typq_env typ with + | Typ_aux (Typ_fn (arg_typs, ret_typ), l) -> + valid_implicits env true arg_typs; + + (* We perform some canonicalisation for function types where existentials appear on the left, so + ({'n, 'n >= 2, int('n)}, foo) -> bar + would become + forall 'n, 'n >= 2. (int('n), foo) -> bar + this enforces the invariant that all things on the left of functions are 'base types' (i.e. without existentials) + *) + let base_args = List.map (fun typ -> destruct_exist (expand_synonyms typq_env typ)) arg_typs in + let existential_arg typq = function + | None -> typq + | Some (exs, nc, _) -> + List.fold_left (fun typq kopt -> quant_add (mk_qi_kopt kopt) typq) (quant_add (mk_qi_nc nc) typq) exs + in + let typq = List.fold_left existential_arg typq base_args in + let arg_typs = List.map2 (fun typ -> function Some (_, _, typ) -> typ | None -> typ) arg_typs base_args in + let typ = Typ_aux (Typ_fn (arg_typs, ret_typ), l) in + typ_print (lazy (adding ^ "val " ^ string_of_id id ^ " : " ^ string_of_bind (typq, typ))); + { env with top_val_specs = Bindings.add id (typq, typ) env.top_val_specs } + | Typ_aux (Typ_bidir (typ1, typ2), _) -> + let env = add_mapping id (typq, typ1, typ2) env in + typ_print (lazy (adding ^ "mapping " ^ string_of_id id ^ " : " ^ string_of_bind (typq, typ))); + { env with top_val_specs = Bindings.add id (typq, typ) env.top_val_specs } + | _ -> typ_error env (id_loc id) "val definition must have a mapping or function type" end - and add_val_spec ?(ignore_duplicate=false) id (bind_typq, bind_typ) env = - if not (Bindings.mem id env.top_val_specs) || ignore_duplicate then ( - update_val_spec id (bind_typq, bind_typ) env - ) else if ignore_duplicate then ( - env - ) else ( + and add_val_spec ?(ignore_duplicate = false) id (bind_typq, bind_typ) env = + if (not (Bindings.mem id env.top_val_specs)) || ignore_duplicate then update_val_spec id (bind_typq, bind_typ) env + else if ignore_duplicate then env + else ( let previous_loc = match Bindings.choose_opt (Bindings.filter (fun key _ -> Id.compare id key = 0) env.top_val_specs) with | Some (prev_id, _) -> id_loc prev_id - | None -> Parse_ast.Unknown in + | None -> Parse_ast.Unknown + in let open Error_format in - Reporting.format_warn ~once_from:__POS__ ("Duplicate function type definition for " ^ string_of_id id) (id_loc id) - (Seq [Line "This duplicate definition is being ignored!"; - Location ("", Some "previous definition here", previous_loc, Seq [])]); + Reporting.format_warn ~once_from:__POS__ + ("Duplicate function type definition for " ^ string_of_id id) + (id_loc id) + (Seq + [ + Line "This duplicate definition is being ignored!"; + Location ("", Some "previous definition here", previous_loc, Seq []); + ] + ); env ) @@ -1219,9 +1298,8 @@ end = struct and get_outcome l id env = match Bindings.find_opt id env.outcomes with | Some outcome -> outcome - | None -> - typ_error env l ("Outcome " ^ string_of_id id ^ " does not exist") - + | None -> typ_error env l ("Outcome " ^ string_of_id id ^ " does not exist") + and add_mapping id (typq, typ1, typ2) env = typ_print (lazy (adding ^ "mapping " ^ string_of_id id)); let forwards_id = mk_id (string_of_id id ^ "_forwards") in @@ -1240,60 +1318,70 @@ end = struct |> add_val_spec ~ignore_duplicate:true backwards_matches_id (typq, backwards_matches_typ) in let prefix_id = mk_id (string_of_id id ^ "_matches_prefix") in - if strip_typ typ1 = string_typ then - let forwards_prefix_typ = Typ_aux (Typ_fn ([typ1], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ2; nat_typ]), Parse_ast.Unknown)]), Parse_ast.Unknown) in + if strip_typ typ1 = string_typ then ( + let forwards_prefix_typ = + Typ_aux + ( Typ_fn ([typ1], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ2; nat_typ]), Parse_ast.Unknown)]), + Parse_ast.Unknown + ) + in add_val_spec ~ignore_duplicate:true prefix_id (typq, forwards_prefix_typ) env - else if strip_typ typ2 = string_typ then - let backwards_prefix_typ = Typ_aux (Typ_fn ([typ2], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ1; nat_typ]), Parse_ast.Unknown)]), Parse_ast.Unknown) in + ) + else if strip_typ typ2 = string_typ then ( + let backwards_prefix_typ = + Typ_aux + ( Typ_fn ([typ2], app_typ (mk_id "option") [A_aux (A_typ (tuple_typ [typ1; nat_typ]), Parse_ast.Unknown)]), + Parse_ast.Unknown + ) + in add_val_spec ~ignore_duplicate:true prefix_id (typq, backwards_prefix_typ) env - else - env + ) + else env let get_outcome_instantiation env = env.outcome_instantiation let add_outcome_variable l kid typ env = { env with outcome_instantiation = KBindings.add kid (l, typ) env.outcome_instantiation } - + let define_val_spec id env = - if IdSet.mem id env.defined_val_specs - then typ_error env (id_loc id) ("Function " ^ string_of_id id ^ " has already been declared") + if IdSet.mem id env.defined_val_specs then + typ_error env (id_loc id) ("Function " ^ string_of_id id ^ " has already been declared") else { env with defined_val_specs = IdSet.add id env.defined_val_specs } let get_defined_val_specs env = env.defined_val_specs - let is_ctor id (Tu_aux (tu, _)) = match tu with - | Tu_ty_id (_, ctor_id) when Id.compare id ctor_id = 0 -> true - | _ -> false - + let is_ctor id (Tu_aux (tu, _)) = + match tu with Tu_ty_id (_, ctor_id) when Id.compare id ctor_id = 0 -> true | _ -> false + let union_constructor_info id env = let type_unions = List.map (fun (id, (_, tus)) -> (id, tus)) (Bindings.bindings env.variants) in - Util.find_map (fun (union_id, tus) -> Option.map (fun (n, tu) -> (n, List.length tus, union_id, tu)) (Util.find_index_opt (is_ctor id) tus)) type_unions - + Util.find_map + (fun (union_id, tus) -> + Option.map (fun (n, tu) -> (n, List.length tus, union_id, tu)) (Util.find_index_opt (is_ctor id) tus) + ) + type_unions + let is_union_constructor id env = let type_unions = List.concat (List.map (fun (_, (_, tus)) -> tus) (Bindings.bindings env.variants)) in List.exists (is_ctor id) type_unions let is_singleton_union_constructor id env = let type_unions = List.map (fun (_, (_, tus)) -> tus) (Bindings.bindings env.variants) in - match List.find (List.exists (is_ctor id)) type_unions with - | l -> List.length l = 1 - | exception Not_found -> false + match List.find (List.exists (is_ctor id)) type_unions with l -> List.length l = 1 | exception Not_found -> false let is_mapping id env = Bindings.mem id env.mappings let add_enum id ids env = - if bound_typ_id env id then ( - already_bound "enum" id env - ) else ( + if bound_typ_id env id then already_bound "enum" id env + else ( typ_print (lazy (adding ^ "enum " ^ string_of_id id)); { env with enums = Bindings.add id (IdSet.of_list ids) env.enums } ) - + let get_enum id env = match Bindings.find_opt id env.enums with | Some enum -> IdSet.elements enum - | None -> - typ_error env (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist") + | None -> typ_error env (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist") let get_enums env = env.enums @@ -1302,63 +1390,65 @@ end = struct let is_record id env = Bindings.mem id env.records let get_record id env = Bindings.find id env.records - + let get_records env = env.records - + let add_record id typq fields env = let fields = List.map (fun (typ, id) -> (expand_synonyms env typ, id)) fields in - if bound_typ_id env id then ( - already_bound "struct" id env - ) else ( + if bound_typ_id env id then already_bound "struct" id env + else ( typ_print (lazy (adding ^ "record " ^ string_of_id id)); let rec record_typ_args = function | [] -> [] - | ((QI_aux (QI_id kopt, _)) :: qis) when is_int_kopt kopt -> - mk_typ_arg (A_nexp (nvar (kopt_kid kopt))) :: record_typ_args qis - | ((QI_aux (QI_id kopt, _)) :: qis) when is_typ_kopt kopt -> - mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt)))) :: record_typ_args qis - | ((QI_aux (QI_id kopt, _)) :: qis) when is_order_kopt kopt -> - mk_typ_arg (A_order (mk_ord (Ord_var (kopt_kid kopt)))) :: record_typ_args qis - | (_ :: qis) -> record_typ_args qis + | QI_aux (QI_id kopt, _) :: qis when is_int_kopt kopt -> + mk_typ_arg (A_nexp (nvar (kopt_kid kopt))) :: record_typ_args qis + | QI_aux (QI_id kopt, _) :: qis when is_typ_kopt kopt -> + mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt)))) :: record_typ_args qis + | QI_aux (QI_id kopt, _) :: qis when is_order_kopt kopt -> + mk_typ_arg (A_order (mk_ord (Ord_var (kopt_kid kopt)))) :: record_typ_args qis + | _ :: qis -> record_typ_args qis in - let rectyp = match record_typ_args (quant_items typq) with - | [] -> mk_id_typ id - | args -> mk_typ (Typ_app (id, args)) + let rectyp = + match record_typ_args (quant_items typq) with [] -> mk_id_typ id | args -> mk_typ (Typ_app (id, args)) in let fold_accessors accs (typ, fid) = let acc_typ = mk_typ (Typ_fn ([rectyp], typ)) in - typ_print (lazy (indent 1 ^ adding ^ "accessor " ^ string_of_id id ^ "." ^ string_of_id fid ^ " :: " ^ string_of_bind (typq, acc_typ))); + typ_print + ( lazy + (indent 1 ^ adding ^ "accessor " ^ string_of_id id ^ "." ^ string_of_id fid ^ " :: " + ^ string_of_bind (typq, acc_typ) + ) + ); Bindings.add (field_name id fid) (typq, acc_typ) accs in - { env with records = Bindings.add id (typq, fields) env.records; - accessors = List.fold_left fold_accessors env.accessors fields } + { + env with + records = Bindings.add id (typq, fields) env.records; + accessors = List.fold_left fold_accessors env.accessors fields; + } ) - + let get_accessor_fn rec_id id env = - let freshen_bind bind = List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) in + let freshen_bind bind = + List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) + in try freshen_bind (Bindings.find (field_name rec_id id) env.accessors) - with - | Not_found -> typ_error env (id_loc id) ("No accessor found for " ^ string_of_id (field_name rec_id id)) + with Not_found -> typ_error env (id_loc id) ("No accessor found for " ^ string_of_id (field_name rec_id id)) let get_accessor rec_id id env = match get_accessor_fn rec_id id env with (* All accessors should have a single argument (the record itself) *) - | (typq, Typ_aux (Typ_fn ([rec_typ], field_typ), _)) -> - (typq, rec_typ, field_typ) + | typq, Typ_aux (Typ_fn ([rec_typ], field_typ), _) -> (typq, rec_typ, field_typ) | _ -> typ_error env (id_loc id) ("Accessor with non-function type found for " ^ string_of_id (field_name rec_id id)) let is_mutable id env = try - let (mut, _) = Bindings.find id env.locals in - match mut with - | Mutable -> true - | Immutable -> false - with - | Not_found -> false + let mut, _ = Bindings.find id env.locals in + match mut with Mutable -> true | Immutable -> false + with Not_found -> false - let string_of_mtyp (mut, typ) = match mut with - | Immutable -> string_of_typ typ - | Mutable -> "ref<" ^ string_of_typ typ ^ ">" + let string_of_mtyp (mut, typ) = + match mut with Immutable -> string_of_typ typ | Mutable -> "ref<" ^ string_of_typ typ ^ ">" let add_local id mtyp env = if not env.allow_bindings then typ_error env (id_loc id) "Bindings are not allowed in this context" else (); @@ -1367,49 +1457,49 @@ end = struct typ_error env (id_loc id) ("Local variable " ^ string_of_id id ^ " is already bound as a function name") else (); typ_print (lazy (adding ^ "local binding " ^ string_of_id id ^ " : " ^ string_of_mtyp mtyp)); - { env with locals = Bindings.add id mtyp env.locals; - top_letbinds = IdSet.remove id env.top_letbinds } + { env with locals = Bindings.add id mtyp env.locals; top_letbinds = IdSet.remove id env.top_letbinds } - let add_toplevel_lets ids env = - { env with top_letbinds = IdSet.union ids env.top_letbinds } + let add_toplevel_lets ids env = { env with top_letbinds = IdSet.union ids env.top_letbinds } let get_toplevel_lets env = env.top_letbinds let is_variant id env = Bindings.mem id env.variants - + let add_variant id (typq, constructors) env = let constructors = - List.map (fun (Tu_aux (Tu_ty_id (typ, id), l)) -> + List.map + (fun (Tu_aux (Tu_ty_id (typ, id), l)) -> Tu_aux (Tu_ty_id (expand_synonyms (add_typquant l typq env) typ, id), l) - ) constructors in - if bound_typ_id env id then ( - already_bound "union" id env - ) else ( + ) + constructors + in + if bound_typ_id env id then already_bound "union" id env + else ( typ_print (lazy (adding ^ "variant " ^ string_of_id id)); { env with variants = Bindings.add id (typq, constructors) env.variants } ) - + let add_scattered_variant id typq env = - if bound_typ_id env id then ( - already_bound "scattered union" id env - ) else ( + if bound_typ_id env id then already_bound "scattered union" id env + else ( typ_print (lazy (adding ^ "scattered variant " ^ string_of_id id)); - { env with + { + env with variants = Bindings.add id (typq, []) env.variants; - scattered_variant_envs = Bindings.add id env env.scattered_variant_envs + scattered_variant_envs = Bindings.add id env env.scattered_variant_envs; } ) - + let add_variant_clause id tu env = match Bindings.find_opt id env.variants with | Some (typq, tus) -> { env with variants = Bindings.add id (typq, tus @ [tu]) env.variants } | None -> typ_error env (id_loc id) ("scattered union " ^ string_of_id id ^ " not found") let get_variants env = env.variants - + let get_variant id env = match Bindings.find_opt id env.variants with - | Some (typq, tus) -> typq, tus + | Some (typq, tus) -> (typq, tus) | None -> typ_error env (id_loc id) ("union " ^ string_of_id id ^ " not found") let get_scattered_variant_env id env = @@ -1417,60 +1507,53 @@ end = struct | Some env' -> env' | None -> typ_error env (id_loc id) ("scattered union " ^ string_of_id id ^ " has not been declared") - let is_register id env = - Bindings.mem id env.registers + let is_register id env = Bindings.mem id env.registers let get_register id env = - try Bindings.find id env.registers with - | Not_found -> typ_error env (id_loc id) ("No register binding found for " ^ string_of_id id) + try Bindings.find id env.registers + with Not_found -> typ_error env (id_loc id) ("No register binding found for " ^ string_of_id id) let get_registers env = env.registers let is_extern id env backend = - try not (Ast_util.extern_assoc backend (Bindings.find_opt id env.externs) = None) with - | Not_found -> false + try not (Ast_util.extern_assoc backend (Bindings.find_opt id env.externs) = None) with Not_found -> false - let add_extern id ext env = - { env with externs = Bindings.add id ext env.externs } + let add_extern id ext env = { env with externs = Bindings.add id ext env.externs } let get_extern id env backend = try match Ast_util.extern_assoc backend (Bindings.find_opt id env.externs) with | Some ext -> ext | None -> typ_error env (id_loc id) ("No extern binding found for " ^ string_of_id id) - with - | Not_found -> typ_error env (id_loc id) ("No extern binding found for " ^ string_of_id id) + with Not_found -> typ_error env (id_loc id) ("No extern binding found for " ^ string_of_id id) let get_casts env = env.casts let add_register id typ env = wf_typ env typ; - if Bindings.mem id env.registers - then typ_error env (id_loc id) ("Register " ^ string_of_id id ^ " is already bound") - else - begin - typ_print (lazy (adding ^ "register binding " ^ string_of_id id ^ " :: " ^ string_of_typ typ)); - { env with registers = Bindings.add id typ env.registers } - end + if Bindings.mem id env.registers then typ_error env (id_loc id) ("Register " ^ string_of_id id ^ " is already bound") + else begin + typ_print (lazy (adding ^ "register binding " ^ string_of_id id ^ " :: " ^ string_of_typ typ)); + { env with registers = Bindings.add id typ env.registers } + end let get_locals env = env.locals let lookup_id id env = try - let (mut, typ) = Bindings.find id env.locals in + let mut, typ = Bindings.find id env.locals in Local (mut, typ) - with - | Not_found -> - try - let typ = Bindings.find id env.registers in - Register typ - with - | Not_found -> - try - let (enum, _) = List.find (fun (_, ctors) -> IdSet.mem id ctors) (Bindings.bindings env.enums) in - Enum (mk_typ (Typ_id enum)) - with - | Not_found -> Unbound id + with Not_found -> ( + try + let typ = Bindings.find id env.registers in + Register typ + with Not_found -> ( + try + let enum, _ = List.find (fun (_, ctors) -> IdSet.mem id ctors) (Bindings.bindings env.enums) in + Enum (mk_typ (Typ_id enum)) + with Not_found -> Unbound id + ) + ) let get_ret_typ env = env.ret_typ @@ -1488,59 +1571,50 @@ end = struct let get_default_order env = match env.default_order with - | None -> typ_error env Parse_ast.Unknown ("No default order has been set") + | None -> typ_error env Parse_ast.Unknown "No default order has been set" | Some ord -> ord let get_default_order_option env = env.default_order - + let set_default_order o env = match o with | Ord_aux (Ord_var _, l) -> typ_error env l "Cannot have variable default order" - | Ord_aux (_, l) -> - match env.default_order with - | None -> { env with default_order = Some o } - | Some _ -> typ_error env l ("Cannot change default order once already set") + | Ord_aux (_, l) -> ( + match env.default_order with + | None -> { env with default_order = Some o } + | Some _ -> typ_error env l "Cannot change default order once already set" + ) let base_typ_of env typ = - let rec aux (Typ_aux (t,a)) = - let rewrap t = Typ_aux (t,a) in + let rec aux (Typ_aux (t, a)) = + let rewrap t = Typ_aux (t, a) in match t with - | Typ_fn (arg_typs, ret_typ) -> - rewrap (Typ_fn (List.map aux arg_typs, aux ret_typ)) - | Typ_tuple ts -> - rewrap (Typ_tuple (List.map aux ts)) - | Typ_app (r, [A_aux (A_typ rtyp,_)]) when string_of_id r = "register" -> - aux rtyp - | Typ_app (id, targs) -> - rewrap (Typ_app (id, List.map aux_arg targs)) + | Typ_fn (arg_typs, ret_typ) -> rewrap (Typ_fn (List.map aux arg_typs, aux ret_typ)) + | Typ_tuple ts -> rewrap (Typ_tuple (List.map aux ts)) + | Typ_app (r, [A_aux (A_typ rtyp, _)]) when string_of_id r = "register" -> aux rtyp + | Typ_app (id, targs) -> rewrap (Typ_app (id, List.map aux_arg targs)) | t -> rewrap t - and aux_arg (A_aux (targ,a)) = - let rewrap targ = A_aux (targ,a) in - match targ with - | A_typ typ -> rewrap (A_typ (aux typ)) - | targ -> rewrap targ in + and aux_arg (A_aux (targ, a)) = + let rewrap targ = A_aux (targ, a) in + match targ with A_typ typ -> rewrap (A_typ (aux typ)) | targ -> rewrap targ + in aux (expand_synonyms env typ) let is_bitfield id env = Bindings.mem id env.bitfields let get_bitfield_ranges id env = Bindings.find id env.bitfields - let add_bitfield id ranges env = - { env with bitfields = Bindings.add id ranges env.bitfields } + let add_bitfield id ranges env = { env with bitfields = Bindings.add id ranges env.bitfields } - let allow_polymorphic_undefineds env = - { env with poly_undefineds = true } + let allow_polymorphic_undefineds env = { env with poly_undefineds = true } let polymorphic_undefineds env = env.poly_undefineds - end let get_bitfield_range id field env = - try Bindings.find_opt field (Env.get_bitfield_ranges id env) - with Not_found -> None + try Bindings.find_opt field (Env.get_bitfield_ranges id env) with Not_found -> None -let expand_bind_synonyms l env (typq, typ) = - typq, Env.expand_synonyms (Env.add_typquant l typq env) typ +let expand_bind_synonyms l env (typq, typ) = (typq, Env.expand_synonyms (Env.add_typquant l typq env) typ) let wf_binding l env (typq, typ) = let env = Env.add_typquant l typq env in @@ -1560,11 +1634,12 @@ let add_existential l kopts nc env = let env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts in Env.add_constraint nc env -let add_typ_vars l kopts env = List.fold_left (fun env (KOpt_aux (_, kl) as kopt) -> Env.add_typ_var (Parse_ast.Hint ("derived from here", kl, l)) kopt env) env kopts +let add_typ_vars l kopts env = + List.fold_left + (fun env (KOpt_aux (_, kl) as kopt) -> Env.add_typ_var (Parse_ast.Hint ("derived from here", kl, l)) kopt env) + env kopts -let is_exist = function - | Typ_aux (Typ_exist (_, _, _), _) -> true - | _ -> false +let is_exist = function Typ_aux (Typ_exist (_, _, _), _) -> true | _ -> false let exist_typ l constr typ = let fresh = fresh_existential l K_int in @@ -1572,76 +1647,84 @@ let exist_typ l constr typ = let bind_numeric l typ env = match destruct_numeric (Env.expand_synonyms env typ) with - | Some (kids, nc, nexp) -> - nexp, add_existential l (List.map (mk_kopt K_int) kids) nc env + | Some (kids, nc, nexp) -> (nexp, add_existential l (List.map (mk_kopt K_int) kids) nc env) | None -> typ_error env l ("Expected " ^ string_of_typ typ ^ " to be numeric") let check_shadow_leaks l inner_env outer_env typ = typ_debug (lazy ("Shadow leaks: " ^ string_of_typ typ)); let vars = tyvars_of_typ typ in - List.iter (fun var -> + List.iter + (fun var -> if Env.shadows var inner_env > Env.shadows var outer_env then - typ_error outer_env l - ("Type variable " ^ string_of_kid var ^ " would leak into a scope where it is shadowed") - else + typ_error outer_env l ("Type variable " ^ string_of_kid var ^ " would leak into a scope where it is shadowed") + else ( match Env.get_typ_var_loc_opt var outer_env with | Some _ -> () - | None -> - match Env.get_typ_var_loc_opt var inner_env with - | Some leak_l -> - typ_raise outer_env l - (err_because - (Err_other ("The type variable " ^ string_of_kid var - ^ " would leak into an outer scope.\n\nTry adding a type annotation to this expression."), - leak_l, - Err_other ("Type variable " ^ string_of_kid var ^ " was introduced here"))) - | None -> Reporting.unreachable l __POS__ "Found a type with an unknown type variable" + | None -> ( + match Env.get_typ_var_loc_opt var inner_env with + | Some leak_l -> + typ_raise outer_env l + (err_because + ( Err_other + ("The type variable " ^ string_of_kid var + ^ " would leak into an outer scope.\n\nTry adding a type annotation to this expression." + ), + leak_l, + Err_other ("Type variable " ^ string_of_kid var ^ " was introduced here") + ) + ) + | None -> Reporting.unreachable l __POS__ "Found a type with an unknown type variable" + ) + ) ) (KidSet.elements vars); typ - + (** Pull an (potentially)-existentially qualified type into the global typing environment **) let bind_existential l name typ env = - match destruct_exist ~name:name (Env.expand_synonyms env typ) with - | Some (kids, nc, typ) -> typ, add_existential l kids nc env - | None -> typ, env + match destruct_exist ~name (Env.expand_synonyms env typ) with + | Some (kids, nc, typ) -> (typ, add_existential l kids nc env) + | None -> (typ, env) let bind_tuple_existentials l name (Typ_aux (aux, annot) as typ) env = match aux with | Typ_tuple typs -> - let typs, env = - List.fold_right (fun typ (typs, env) -> let typ, env = bind_existential l name typ env in typ :: typs, env) typs ([], env) - in - Typ_aux (Typ_tuple typs, annot), env - | _ -> typ, env + let typs, env = + List.fold_right + (fun typ (typs, env) -> + let typ, env = bind_existential l name typ env in + (typ :: typs, env) + ) + typs ([], env) + in + (Typ_aux (Typ_tuple typs, annot), env) + | _ -> (typ, env) let destruct_range env typ = - let kopts, constr, (Typ_aux (typ_aux, _)) = + let kopts, constr, Typ_aux (typ_aux, _) = Option.value ~default:([], nc_true, typ) (destruct_exist (Env.expand_synonyms env typ)) in match typ_aux with - | Typ_app (f, [A_aux (A_nexp n, _)]) - when string_of_id f = "atom" || string_of_id f = "implicit" -> Some (List.map kopt_kid kopts, constr, n, n) - | Typ_app (f, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]) - when string_of_id f = "range" -> Some (List.map kopt_kid kopts, constr, n1, n2) - | _ -> None + | Typ_app (f, [A_aux (A_nexp n, _)]) when string_of_id f = "atom" || string_of_id f = "implicit" -> + Some (List.map kopt_kid kopts, constr, n, n) + | Typ_app (f, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]) when string_of_id f = "range" -> + Some (List.map kopt_kid kopts, constr, n1, n2) + | _ -> None let destruct_vector env typ = let destruct_vector' = function - | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); - A_aux (A_order o, _); - A_aux (A_typ vtyp, _)] - ), _) when string_of_id id = "vector" -> Some (nexp_simp n1, o, vtyp) + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _); A_aux (A_typ vtyp, _)]), _) + when string_of_id id = "vector" -> + Some (nexp_simp n1, o, vtyp) | _ -> None in destruct_vector' (Env.expand_synonyms env typ) let destruct_bitvector env typ = let destruct_bitvector' = function - | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); - A_aux (A_order o, _)] - ), _) when string_of_id id = "bitvector" -> Some (nexp_simp n1, o) + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _)]), _) when string_of_id id = "bitvector" -> + Some (nexp_simp n1, o) | _ -> None in destruct_bitvector' (Env.expand_synonyms env typ) @@ -1655,6 +1738,7 @@ let rec is_typ_monomorphic (Typ_aux (typ, l)) = | Typ_bidir (typ1, typ2) -> is_typ_monomorphic typ1 && is_typ_monomorphic typ2 | Typ_exist _ | Typ_var _ -> false | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" + and is_typ_arg_monomorphic (A_aux (arg, _)) = match arg with | A_nexp _ -> true @@ -1667,29 +1751,21 @@ and is_typ_arg_monomorphic (A_aux (arg, _)) = (* 2. Subtyping and constraint solving *) (**************************************************************************) -type ('a, 'b) filter = - | Keep of 'a - | Remove of 'b +type ('a, 'b) filter = Keep of 'a | Remove of 'b -let rec filter_keep = function - | Keep x :: xs -> x :: filter_keep xs - | Remove _ :: xs -> filter_keep xs - | [] -> [] +let rec filter_keep = function Keep x :: xs -> x :: filter_keep xs | Remove _ :: xs -> filter_keep xs | [] -> [] -let rec filter_remove = function - | Keep _ :: xs -> filter_remove xs - | Remove x :: xs -> x :: filter_remove xs - | [] -> [] +let rec filter_remove = function Keep _ :: xs -> filter_remove xs | Remove x :: xs -> x :: filter_remove xs | [] -> [] let filter_split f g xs = let xs = List.map f xs in - filter_keep xs, g (filter_remove xs) + (filter_keep xs, g (filter_remove xs)) let rec simp_typ (Typ_aux (typ_aux, l)) = Typ_aux (simp_typ_aux typ_aux, l) + and simp_typ_aux = function | Typ_exist (kids1, nc1, Typ_aux (Typ_exist (kids2, nc2, typ), _)) -> - simp_typ_aux (Typ_exist (kids1 @ kids2, nc_and nc1 nc2, typ)) - + simp_typ_aux (Typ_exist (kids1 @ kids2, nc_and nc1 nc2, typ)) (* This removes redundant boolean variables in existentials, such that {('p: Bool) ('q:Bool) ('r: Bool), nc('r). bool('p & 'q & 'r)} would become {('s:Bool) ('r: Bool), nc('r). bool('s & 'r)}, @@ -1698,104 +1774,126 @@ and simp_typ_aux = function having to pass large numbers of pointless variables to SMT if we ever bind this existential. *) | Typ_exist (vars, nc, Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool b, _)]), l)) -> - let kids = KidSet.of_list (List.map kopt_kid vars) in - let constrained = tyvars_of_constraint nc in - let conjs = constraint_conj b in - let is_redundant = function - | NC_aux (NC_var v, _) when KidSet.mem v kids && not (KidSet.mem v constrained) -> Remove v - | nc -> Keep nc - in - let conjs, redundant = filter_split is_redundant KidSet.of_list conjs in - begin match conjs with - | [] -> Typ_id (mk_id "bool") - | conj :: conjs when KidSet.is_empty redundant -> - Typ_exist (vars, nc, atom_bool_typ (List.fold_left nc_and conj conjs)) - | conjs -> - let vars = List.filter (fun v -> not (KidSet.mem (kopt_kid v) redundant)) vars in - let var = fresh_existential l K_bool in - Typ_exist (var :: vars, nc, atom_bool_typ (List.fold_left nc_and (nc_var (kopt_kid var)) conjs)) - end - + let kids = KidSet.of_list (List.map kopt_kid vars) in + let constrained = tyvars_of_constraint nc in + let conjs = constraint_conj b in + let is_redundant = function + | NC_aux (NC_var v, _) when KidSet.mem v kids && not (KidSet.mem v constrained) -> Remove v + | nc -> Keep nc + in + let conjs, redundant = filter_split is_redundant KidSet.of_list conjs in + begin + match conjs with + | [] -> Typ_id (mk_id "bool") + | conj :: conjs when KidSet.is_empty redundant -> + Typ_exist (vars, nc, atom_bool_typ (List.fold_left nc_and conj conjs)) + | conjs -> + let vars = List.filter (fun v -> not (KidSet.mem (kopt_kid v) redundant)) vars in + let var = fresh_existential l K_bool in + Typ_exist (var :: vars, nc, atom_bool_typ (List.fold_left nc_and (nc_var (kopt_kid var)) conjs)) + end | typ_aux -> typ_aux (* Here's how the constraint generation works for subtyping -X(b,c...) --> {a. Y(a,b,c...)} \subseteq {a. Z(a,b,c...)} + X(b,c...) --> {a. Y(a,b,c...)} \subseteq {a. Z(a,b,c...)} -this is equivalent to + this is equivalent to -\forall b c. X(b,c) --> \forall a. Y(a,b,c) --> Z(a,b,c) + \forall b c. X(b,c) --> \forall a. Y(a,b,c) --> Z(a,b,c) -\forall b c. X(b,c) --> \forall a. !Y(a,b,c) \/ !Z^-1(a,b,c) + \forall b c. X(b,c) --> \forall a. !Y(a,b,c) \/ !Z^-1(a,b,c) -\forall b c. X(b,c) --> !\exists a. Y(a,b,c) /\ Z^-1(a,b,c) + \forall b c. X(b,c) --> !\exists a. Y(a,b,c) /\ Z^-1(a,b,c) -\forall b c. !X(b,c) \/ !\exists a. Y(a,b,c) /\ Z^-1(a,b,c) + \forall b c. !X(b,c) \/ !\exists a. Y(a,b,c) /\ Z^-1(a,b,c) -!\exists b c. X(b,c) /\ \exists a. Y(a,b,c) /\ Z^-1(a,b,c) + !\exists b c. X(b,c) /\ \exists a. Y(a,b,c) /\ Z^-1(a,b,c) -!\exists a b c. X(b,c) /\ Y(a,b,c) /\ Z^-1(a,b,c) + !\exists a b c. X(b,c) /\ Y(a,b,c) /\ Z^-1(a,b,c) -which is then a problem we can feed to the constraint solver expecting unsat. - *) + which is then a problem we can feed to the constraint solver expecting unsat. +*) let prove_smt env (NC_aux (_, l) as nc) = let ncs = Env.get_constraints env in match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs) with - | Constraint.Unsat -> typ_debug (lazy "unsat"); true - | Constraint.Sat -> typ_debug (lazy "sat"); false - | Constraint.Unknown -> - (* Work around versions of z3 that are confused by 2^n in - constraints, even when such constraints are irrelevant *) - let ncs' = List.concat (List.map constraint_conj ncs) in - let ncs' = List.filter (fun nc -> KidSet.is_empty (constraint_power_variables nc)) ncs' in - match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs') with - | Constraint.Unsat -> typ_debug (lazy "unsat"); true - | Constraint.Sat | Constraint.Unknown -> typ_debug (lazy "sat/unknown"); false + | Constraint.Unsat -> + typ_debug (lazy "unsat"); + true + | Constraint.Sat -> + typ_debug (lazy "sat"); + false + | Constraint.Unknown -> ( + (* Work around versions of z3 that are confused by 2^n in + constraints, even when such constraints are irrelevant *) + let ncs' = List.concat (List.map constraint_conj ncs) in + let ncs' = List.filter (fun nc -> KidSet.is_empty (constraint_power_variables nc)) ncs' in + match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs') with + | Constraint.Unsat -> + typ_debug (lazy "unsat"); + true + | Constraint.Sat | Constraint.Unknown -> + typ_debug (lazy "sat/unknown"); + false + ) let solve_unique env (Nexp_aux (_, l) as nexp) = - typ_print (lazy (Util.("Solve " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) - ^ " |- " ^ string_of_nexp nexp ^ " = ?")); + typ_print + ( lazy + (Util.("Solve " |> red |> clear) + ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) + ^ " |- " ^ string_of_nexp nexp ^ " = ?" + ) + ); match nexp with - | Nexp_aux (Nexp_constant n,_) -> Some n + | Nexp_aux (Nexp_constant n, _) -> Some n | _ -> - let env = Env.add_typ_var l (mk_kopt K_int (mk_kid "solve#")) env in - let vars = Env.get_typ_vars env in - let _vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in - let constr = List.fold_left nc_and (nc_eq (nvar (mk_kid "solve#")) nexp) (Env.get_constraints env) in - Constraint.solve_unique_smt l constr (mk_kid "solve#") + let env = Env.add_typ_var l (mk_kopt K_int (mk_kid "solve#")) env in + let vars = Env.get_typ_vars env in + let _vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in + let constr = List.fold_left nc_and (nc_eq (nvar (mk_kid "solve#")) nexp) (Env.get_constraints env) in + Constraint.solve_unique_smt l constr (mk_kid "solve#") -let debug_pos (file, line, _, _) = - "(" ^ file ^ "/" ^ string_of_int line ^ ") " +let debug_pos (file, line, _, _) = "(" ^ file ^ "/" ^ string_of_int line ^ ") " let prove pos env nc = - typ_print (lazy (Util.("Prove " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc)); + typ_print + ( lazy + (Util.("Prove " |> red |> clear) + ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) + ^ " |- " ^ string_of_n_constraint nc + ) + ); let (NC_aux (nc_aux, _) as nc) = constraint_simp (Env.expand_constraint_synonyms env nc) in if !Constraint.opt_smt_verbose then - prerr_endline (Util.("Prove " |> red |> clear) ^ debug_pos pos ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc) + prerr_endline + (Util.("Prove " |> red |> clear) + ^ debug_pos pos + ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) + ^ " |- " ^ string_of_n_constraint nc + ) else (); - match nc_aux with - | NC_true -> true - | _ -> prove_smt env nc + match nc_aux with NC_true -> true | _ -> prove_smt env nc (**************************************************************************) (* 3. Unification *) (**************************************************************************) -let rec nexp_frees ?exs:(exs=KidSet.empty) (Nexp_aux (nexp, l)) = +let rec nexp_frees ?(exs = KidSet.empty) (Nexp_aux (nexp, l)) = match nexp with | Nexp_id _ -> KidSet.empty | Nexp_var kid -> KidSet.singleton kid | Nexp_constant _ -> KidSet.empty - | Nexp_times (n1, n2) -> KidSet.union (nexp_frees ~exs:exs n1) (nexp_frees ~exs:exs n2) - | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees ~exs:exs n1) (nexp_frees ~exs:exs n2) - | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees ~exs:exs n1) (nexp_frees ~exs:exs n2) - | Nexp_app (id, ns) -> List.fold_left KidSet.union KidSet.empty (List.map (fun n -> nexp_frees ~exs:exs n) ns) - | Nexp_exp n -> nexp_frees ~exs:exs n - | Nexp_neg n -> nexp_frees ~exs:exs n + | Nexp_times (n1, n2) -> KidSet.union (nexp_frees ~exs n1) (nexp_frees ~exs n2) + | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees ~exs n1) (nexp_frees ~exs n2) + | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees ~exs n1) (nexp_frees ~exs n2) + | Nexp_app (id, ns) -> List.fold_left KidSet.union KidSet.empty (List.map (fun n -> nexp_frees ~exs n) ns) + | Nexp_exp n -> nexp_frees ~exs n + | Nexp_neg n -> nexp_frees ~exs n let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = - match nexp1, nexp2 with + match (nexp1, nexp2) with | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0 | Nexp_constant c1, Nexp_constant c2 -> Big_int.equal c1 c2 @@ -1805,18 +1903,18 @@ let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 | Nexp_app (f1, args1), Nexp_app (f2, args2) when List.length args1 = List.length args2 -> - Id.compare f1 f2 = 0 && List.for_all2 nexp_identical args1 args2 + Id.compare f1 f2 = 0 && List.for_all2 nexp_identical args1 args2 | _, _ -> false let ord_identical (Ord_aux (ord1, _)) (Ord_aux (ord2, _)) = - match ord1, ord2 with + match (ord1, ord2) with | Ord_var kid1, Ord_var kid2 -> Kid.compare kid1 kid2 = 0 | Ord_inc, Ord_inc -> true | Ord_dec, Ord_dec -> true | _, _ -> false let rec nc_identical (NC_aux (nc1, _)) (NC_aux (nc2, _)) = - match nc1, nc2 with + match (nc1, nc2) with | NC_equal (n1a, n1b), NC_equal (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | NC_not_equal (n1a, n1b), NC_not_equal (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | NC_bounded_ge (n1a, n1b), NC_bounded_ge (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b @@ -1828,14 +1926,14 @@ let rec nc_identical (NC_aux (nc1, _)) (NC_aux (nc2, _)) = | NC_true, NC_true -> true | NC_false, NC_false -> true | NC_set (kid1, ints1), NC_set (kid2, ints2) when List.length ints1 = List.length ints2 -> - Kid.compare kid1 kid2 = 0 && List.for_all2 (fun i1 i2 -> i1 = i2) ints1 ints2 + Kid.compare kid1 kid2 = 0 && List.for_all2 (fun i1 i2 -> i1 = i2) ints1 ints2 | NC_var kid1, NC_var kid2 -> Kid.compare kid1 kid2 = 0 | NC_app (id1, args1), NC_app (id2, args2) when List.length args1 = List.length args2 -> - Id.compare id1 id2 = 0 && List.for_all2 typ_arg_identical args1 args2 + Id.compare id1 id2 = 0 && List.for_all2 typ_arg_identical args1 args2 | _, _ -> false and typ_arg_identical (A_aux (arg1, _)) (A_aux (arg2, _)) = - match arg1, arg2 with + match (arg1, arg2) with | A_nexp n1, A_nexp n2 -> nexp_identical n1 n2 | A_typ typ1, A_typ typ2 -> typ_identical typ1 typ2 | A_order ord1, A_order ord2 -> ord_identical ord1 ord2 @@ -1843,260 +1941,255 @@ and typ_arg_identical (A_aux (arg1, _)) (A_aux (arg2, _)) = | _, _ -> false and typ_identical (Typ_aux (typ1, _)) (Typ_aux (typ2, _)) = - match typ1, typ2 with + match (typ1, typ2) with | Typ_id v1, Typ_id v2 -> Id.compare v1 v2 = 0 | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 = 0 - | Typ_fn (arg_typs1, ret_typ1), Typ_fn (arg_typs2, ret_typ2) - when List.length arg_typs1 = List.length arg_typs2 -> - List.for_all2 typ_identical arg_typs1 arg_typs2 - && typ_identical ret_typ1 ret_typ2 - | Typ_bidir (typ1, typ2), Typ_bidir (typ3, typ4) -> - typ_identical typ1 typ3 - && typ_identical typ2 typ4 - | Typ_tuple typs1, Typ_tuple typs2 -> - begin - try List.for_all2 typ_identical typs1 typs2 with - | Invalid_argument _ -> false - end - | Typ_app (f1, args1), Typ_app (f2, args2) -> - begin - try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with - | Invalid_argument _ -> false - end + | Typ_fn (arg_typs1, ret_typ1), Typ_fn (arg_typs2, ret_typ2) when List.length arg_typs1 = List.length arg_typs2 -> + List.for_all2 typ_identical arg_typs1 arg_typs2 && typ_identical ret_typ1 ret_typ2 + | Typ_bidir (typ1, typ2), Typ_bidir (typ3, typ4) -> typ_identical typ1 typ3 && typ_identical typ2 typ4 + | Typ_tuple typs1, Typ_tuple typs2 -> begin + try List.for_all2 typ_identical typs1 typs2 with Invalid_argument _ -> false + end + | Typ_app (f1, args1), Typ_app (f2, args2) -> begin + try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with Invalid_argument _ -> false + end | Typ_exist (kopts1, nc1, typ1), Typ_exist (kopts2, nc2, typ2) when List.length kopts1 = List.length kopts2 -> - List.for_all2 (fun k1 k2 -> KOpt.compare k1 k2 = 0) kopts1 kopts2 && nc_identical nc1 nc2 && typ_identical typ1 typ2 + List.for_all2 (fun k1 k2 -> KOpt.compare k1 k2 = 0) kopts1 kopts2 + && nc_identical nc1 nc2 && typ_identical typ1 typ2 | _, _ -> false -let expanded_typ_identical env typ1 typ2 = - typ_identical (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2) +let expanded_typ_identical env typ1 typ2 = typ_identical (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2) -exception Unification_error of l * string;; +exception Unification_error of l * string let unify_error l str = raise (Unification_error (l, str)) let merge_unifiers env l kid uvar1 uvar2 = - match uvar1, uvar2 with + match (uvar1, uvar2) with | Some arg1, Some arg2 when typ_arg_identical arg1 arg2 -> Some arg1 (* If the unifiers are equivalent nexps, use one, preferably a variable *) - | Some (A_aux (A_nexp nexp1, _) as arg1), - Some (A_aux (A_nexp nexp2, _) as arg2) - when prove __POS__ env (nc_eq nexp1 nexp2) -> - begin match nexp1, nexp2 with - | Nexp_aux (Nexp_var _, _), _ -> Some arg1 - | _, Nexp_aux (Nexp_var _, _) -> Some arg2 - | _, _ -> Some arg1 - end + | Some (A_aux (A_nexp nexp1, _) as arg1), Some (A_aux (A_nexp nexp2, _) as arg2) + when prove __POS__ env (nc_eq nexp1 nexp2) -> begin + match (nexp1, nexp2) with + | Nexp_aux (Nexp_var _, _), _ -> Some arg1 + | _, Nexp_aux (Nexp_var _, _) -> Some arg2 + | _, _ -> Some arg1 + end | Some arg1, Some arg2 -> - unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid - ^ ": " ^ string_of_typ_arg arg1 ^ " and " ^ string_of_typ_arg arg2) + unify_error l + ("Multiple non-identical unifiers for " ^ string_of_kid kid ^ ": " ^ string_of_typ_arg arg1 ^ " and " + ^ string_of_typ_arg arg2 + ) | None, Some u2 -> Some u2 | Some u1, None -> Some u1 | None, None -> None -let merge_uvars env l unifiers1 unifiers2 = - KBindings.merge (merge_unifiers env l) unifiers1 unifiers2 +let merge_uvars env l unifiers1 unifiers2 = KBindings.merge (merge_unifiers env l) unifiers1 unifiers2 let rec unify_typ l env goals (Typ_aux (aux1, _) as typ1) (Typ_aux (aux2, _) as typ2) = - typ_debug (lazy (Util.("Unify type " |> magenta |> clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2 - ^ " goals " ^ string_of_list ", " string_of_kid (KidSet.elements goals))); - match aux1, aux2 with - | Typ_internal_unknown, _ | _, Typ_internal_unknown - when Env.allow_unknowns env -> - KBindings.empty - + typ_debug + ( lazy + (Util.("Unify type " |> magenta |> clear) + ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2 ^ " goals " + ^ string_of_list ", " string_of_kid (KidSet.elements goals) + ) + ); + match (aux1, aux2) with + | (Typ_internal_unknown, _ | _, Typ_internal_unknown) when Env.allow_unknowns env -> KBindings.empty | Typ_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_typ typ2) - | Typ_var v1, Typ_var v2 when Kid.compare v1 v2 = 0 -> KBindings.empty - (* We need special cases for unifying range(n, m), nat, and int vs atom('n) *) | Typ_id int, Typ_app (atom, [A_aux (A_nexp n, _)]) when string_of_id int = "int" -> KBindings.empty - | Typ_id nat, Typ_app (atom, [A_aux (A_nexp n, _)]) when string_of_id nat = "nat" -> - if prove __POS__ env (nc_gteq n (nint 0)) then KBindings.empty - else unify_error l (string_of_typ typ2 ^ " must be a natural number") - - | Typ_app (range, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]), - Typ_app (atom, [A_aux (A_nexp m, _)]) - when string_of_id range = "range" && string_of_id atom = "atom" -> - let n1, n2 = nexp_simp n1, nexp_simp n2 in - begin match n1, n2 with - | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> - if prove __POS__ env (nc_and (nc_lteq n1 m) (nc_lteq m n2)) then KBindings.empty - else unify_error l (string_of_typ typ1 ^ " is not contained within " ^ string_of_typ typ1) - | _, _ -> - merge_uvars env l (unify_nexp l env goals n1 m) (unify_nexp l env goals n2 m) - end - + if prove __POS__ env (nc_gteq n (nint 0)) then KBindings.empty + else unify_error l (string_of_typ typ2 ^ " must be a natural number") + | Typ_app (range, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]), Typ_app (atom, [A_aux (A_nexp m, _)]) + when string_of_id range = "range" && string_of_id atom = "atom" -> + let n1, n2 = (nexp_simp n1, nexp_simp n2) in + begin + match (n1, n2) with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if prove __POS__ env (nc_and (nc_lteq n1 m) (nc_lteq m n2)) then KBindings.empty + else unify_error l (string_of_typ typ1 ^ " is not contained within " ^ string_of_typ typ1) + | _, _ -> merge_uvars env l (unify_nexp l env goals n1 m) (unify_nexp l env goals n2 m) + end | Typ_app (id1, args1), Typ_app (id2, args2) when List.length args1 = List.length args2 && Id.compare id1 id2 = 0 -> - List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2) - + List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2) | Typ_app (id1, []), Typ_id id2 when Id.compare id1 id2 = 0 -> KBindings.empty | Typ_id id1, Typ_app (id2, []) when Id.compare id1 id2 = 0 -> KBindings.empty | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> KBindings.empty - | Typ_tuple typs1, Typ_tuple typs2 when List.length typs1 = List.length typs2 -> - List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ l env goals) typs1 typs2) - + List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ l env goals) typs1 typs2) | Typ_fn (arg_typs1, ret_typ1), Typ_fn (arg_typs2, ret_typ2) when List.length arg_typs1 = List.length arg_typs2 -> - merge_uvars env l - (List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ l env goals) arg_typs1 arg_typs2)) - (unify_typ l env goals ret_typ1 ret_typ2) - + merge_uvars env l + (List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ l env goals) arg_typs1 arg_typs2)) + (unify_typ l env goals ret_typ1 ret_typ2) | _, _ -> unify_error l ("Could not unify " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2) and unify_typ_arg l env goals (A_aux (aux1, _) as typ_arg1) (A_aux (aux2, _) as typ_arg2) = - match aux1, aux2 with + match (aux1, aux2) with | A_typ typ1, A_typ typ2 -> unify_typ l env goals typ1 typ2 | A_nexp nexp1, A_nexp nexp2 -> unify_nexp l env goals nexp1 nexp2 | A_order ord1, A_order ord2 -> unify_order l goals ord1 ord2 | A_bool nc1, A_bool nc2 -> unify_constraint l env goals nc1 nc2 - | _, _ -> unify_error l ("Could not unify type arguments " ^ string_of_typ_arg typ_arg1 ^ " and " ^ string_of_typ_arg typ_arg2) + | _, _ -> + unify_error l + ("Could not unify type arguments " ^ string_of_typ_arg typ_arg1 ^ " and " ^ string_of_typ_arg typ_arg2) and unify_constraint l env goals (NC_aux (aux1, _) as nc1) (NC_aux (aux2, _) as nc2) = - typ_debug (lazy (Util.("Unify constraint " |> magenta |> clear) ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2)); - match aux1, aux2 with + typ_debug + ( lazy + (Util.("Unify constraint " |> magenta |> clear) + ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2 + ) + ); + match (aux1, aux2) with | NC_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_bool nc2) | NC_var v, NC_var v' when Kid.compare v v' = 0 -> KBindings.empty - | NC_and (nc1a, nc2a), NC_and (nc1b, nc2b) -> - begin - try - let conjs1 = List.sort NC.compare (constraint_conj nc1) in - let conjs2 = List.sort NC.compare (constraint_conj nc2) in - let unify_merge uv nc1 nc2 = merge_uvars env l uv (unify_constraint l env goals nc1 nc2) in - List.fold_left2 unify_merge KBindings.empty conjs1 conjs2 - with - | _ -> merge_uvars env l (unify_constraint l env goals nc1a nc1b) (unify_constraint l env goals nc2a nc2b) - end + | NC_and (nc1a, nc2a), NC_and (nc1b, nc2b) -> begin + try + let conjs1 = List.sort NC.compare (constraint_conj nc1) in + let conjs2 = List.sort NC.compare (constraint_conj nc2) in + let unify_merge uv nc1 nc2 = merge_uvars env l uv (unify_constraint l env goals nc1 nc2) in + List.fold_left2 unify_merge KBindings.empty conjs1 conjs2 + with _ -> merge_uvars env l (unify_constraint l env goals nc1a nc1b) (unify_constraint l env goals nc2a nc2b) + end | NC_or (nc1a, nc2a), NC_or (nc1b, nc2b) -> - merge_uvars env l (unify_constraint l env goals nc1a nc1b) (unify_constraint l env goals nc2a nc2b) + merge_uvars env l (unify_constraint l env goals nc1a nc1b) (unify_constraint l env goals nc2a nc2b) | NC_app (f1, args1), NC_app (f2, args2) when Id.compare f1 f2 = 0 && List.length args1 = List.length args2 -> - List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2) + List.fold_left (merge_uvars env l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2) | NC_equal (n1a, n2a), NC_equal (n1b, n2b) -> - merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_not_equal (n1a, n2a), NC_not_equal (n1b, n2b) -> - merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_bounded_ge (n1a, n2a), NC_bounded_ge (n1b, n2b) -> - merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_bounded_gt (n1a, n2a), NC_bounded_gt (n1b, n2b) -> - merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_bounded_le (n1a, n2a), NC_bounded_le (n1b, n2b) -> - merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_bounded_lt (n1a, n2a), NC_bounded_lt (n1b, n2b) -> - merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + merge_uvars env l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_true, NC_true -> KBindings.empty | NC_false, NC_false -> KBindings.empty - | _, _ -> unify_error l ("Could not unify constraints " ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2) + | _, _ -> + unify_error l ("Could not unify constraints " ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2) and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2) = typ_print (lazy (Util.("Unify order " |> magenta |> clear) ^ string_of_order ord1 ^ " and " ^ string_of_order ord2)); - match aux1, aux2 with + match (aux1, aux2) with | Ord_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_order ord2) | Ord_inc, Ord_inc -> KBindings.empty | Ord_dec, Ord_dec -> KBindings.empty | _, _ -> unify_error l ("Could not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2) and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) = - typ_debug (lazy (Util.("Unify nexp " |> magenta |> clear) ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 - ^ " goals " ^ string_of_list ", " string_of_kid (KidSet.elements goals))); - if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals) - then - begin - if prove __POS__ env (NC_aux (NC_equal (nexp1, nexp2), Parse_ast.Unknown)) - then KBindings.empty - else unify_error l ("Integer expressions " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal") - end - else + typ_debug + ( lazy + (Util.("Unify nexp " |> magenta |> clear) + ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " goals " + ^ string_of_list ", " string_of_kid (KidSet.elements goals) + ) + ); + if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals) then begin + if prove __POS__ env (NC_aux (NC_equal (nexp1, nexp2), Parse_ast.Unknown)) then KBindings.empty + else + unify_error l ("Integer expressions " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal") + end + else ( match nexp_aux1 with | Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp" | Nexp_var kid when KidSet.mem kid goals -> KBindings.singleton kid (arg_nexp nexp2) - | Nexp_constant c1 -> - begin - match nexp_aux2 with - | Nexp_constant c2 -> if c1 = c2 then KBindings.empty else unify_error l "Constants are not the same" - | _ -> unify_error l "Unification error" - end + | Nexp_constant c1 -> begin + match nexp_aux2 with + | Nexp_constant c2 -> if c1 = c2 then KBindings.empty else unify_error l "Constants are not the same" + | _ -> unify_error l "Unification error" + end | Nexp_sum (n1a, n1b) -> - if KidSet.is_empty (nexp_frees n1b) - then unify_nexp l env goals n1a (nminus nexp2 n1b) - else - if KidSet.is_empty (nexp_frees n1a) - then unify_nexp l env goals n1b (nminus nexp2 n1a) - else begin - match nexp_aux2 with - | Nexp_sum (n2a, n2b) -> - if KidSet.is_empty (nexp_frees n2a) - then unify_nexp l env goals n2b (nminus nexp1 n2a) - else - if KidSet.is_empty (nexp_frees n2a) - then unify_nexp l env goals n2a (nminus nexp1 n2b) - else merge_uvars env l (unify_nexp l env goals n1a n2a) (unify_nexp l env goals n1b n2b) - | _ -> unify_error l ("Both sides of Int expression " ^ string_of_nexp nexp1 - ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2) - end + if KidSet.is_empty (nexp_frees n1b) then unify_nexp l env goals n1a (nminus nexp2 n1b) + else if KidSet.is_empty (nexp_frees n1a) then unify_nexp l env goals n1b (nminus nexp2 n1a) + else begin + match nexp_aux2 with + | Nexp_sum (n2a, n2b) -> + if KidSet.is_empty (nexp_frees n2a) then unify_nexp l env goals n2b (nminus nexp1 n2a) + else if KidSet.is_empty (nexp_frees n2a) then unify_nexp l env goals n2a (nminus nexp1 n2b) + else merge_uvars env l (unify_nexp l env goals n1a n2a) (unify_nexp l env goals n1b n2b) + | _ -> + unify_error l + ("Both sides of Int expression " ^ string_of_nexp nexp1 + ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2 + ) + end | Nexp_minus (n1a, n1b) -> - if KidSet.is_empty (nexp_frees n1b) - then unify_nexp l env goals n1a (nsum nexp2 n1b) - else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + if KidSet.is_empty (nexp_frees n1b) then unify_nexp l env goals n1a (nsum nexp2 n1b) + else + unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) | Nexp_times (n1a, n1b) -> - (* If we have SMT operations div and mod, then we can use the - property that + (* If we have SMT operations div and mod, then we can use the + property that - mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C) + mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C) - to help us unify multiplications and divisions. + to help us unify multiplications and divisions. - In particular, the nexp rewriting used in monomorphisation adds - constraints of the form 8 * 'n == 'p8_times_n, and we sometimes need - to solve for 'n. + In particular, the nexp rewriting used in monomorphisation adds + constraints of the form 8 * 'n == 'p8_times_n, and we sometimes need + to solve for 'n. *) - let valid n c = prove __POS__ env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove __POS__ env (nc_neq c (nint 0)) in - (*if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then - unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) - else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then - unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) *) - if KidSet.is_empty (nexp_frees n1a) then - begin - match nexp_aux2 with - | Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1a, n2a), Parse_ast.Unknown)) -> + let valid n c = + prove __POS__ env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove __POS__ env (nc_neq c (nint 0)) + in + (*if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then + unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) + else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then + unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) *) + if KidSet.is_empty (nexp_frees n1a) then begin + match nexp_aux2 with + | Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1a, n2a), Parse_ast.Unknown)) -> unify_nexp l env goals n1b n2b - | Nexp_constant c2 -> - begin - match n1a with - | Nexp_aux (Nexp_constant c1,_) when Big_int.equal (Big_int.modulus c2 c1) Big_int.zero -> - unify_nexp l env goals n1b (nconstant (Big_int.div c2 c1)) - | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) - end - | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1a -> + | Nexp_constant c2 -> begin + match n1a with + | Nexp_aux (Nexp_constant c1, _) when Big_int.equal (Big_int.modulus c2 c1) Big_int.zero -> + unify_nexp l env goals n1b (nconstant (Big_int.div c2 c1)) + | _ -> + unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + end + | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1a -> unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) - | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) - end - else if KidSet.is_empty (nexp_frees n1b) then - begin - match nexp_aux2 with - | Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) -> + | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + end + else if KidSet.is_empty (nexp_frees n1b) then begin + match nexp_aux2 with + | Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) -> unify_nexp l env goals n1a n2a - | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1b -> + | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1b -> unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) - | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) - end - else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) - | Nexp_exp n1 -> - begin - match nexp_aux2 with - | Nexp_exp n2 -> unify_nexp l env goals n1 n2 - | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) - end + | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + end + else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + | Nexp_exp n1 -> begin + match nexp_aux2 with + | Nexp_exp n2 -> unify_nexp l env goals n1 n2 + | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + end | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + ) let unify l env goals typ1 typ2 = - typ_print (lazy (Util.("Unify " |> magenta |> clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2 - ^ " for " ^ Util.string_of_list ", " string_of_kid (KidSet.elements goals))); - let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in + typ_print + ( lazy + (Util.("Unify " |> magenta |> clear) + ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2 ^ " for " + ^ Util.string_of_list ", " string_of_kid (KidSet.elements goals) + ) + ); + let typ1, typ2 = (Env.expand_synonyms env typ1, Env.expand_synonyms env typ2) in if not (KidSet.is_empty (KidSet.inter goals (tyvars_of_typ typ2))) then - typ_error env l ("Occurs check failed: " ^ string_of_typ typ2 ^ " contains " - ^ Util.string_of_list ", " string_of_kid (KidSet.elements goals)) - else - unify_typ l env goals typ1 typ2 + typ_error env l + ("Occurs check failed: " ^ string_of_typ typ2 ^ " contains " + ^ Util.string_of_list ", " string_of_kid (KidSet.elements goals) + ) + else unify_typ l env goals typ1 typ2 let subst_unifiers unifiers typ = List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ (KBindings.bindings unifiers) @@ -2107,13 +2200,12 @@ let subst_unifiers_typ_arg unifiers typ_arg = let instantiate_quant env (v, arg) (QI_aux (aux, l) as qi) = match aux with | QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 -> - typ_debug (lazy ("Instantiated " ^ string_of_quant_item qi)); - None + typ_debug (lazy ("Instantiated " ^ string_of_quant_item qi)); + None | QI_id _ -> Some qi | QI_constraint nc -> Some (QI_aux (QI_constraint (constraint_subst v arg nc), l)) -let instantiate_quants env quants unifier = - List.map (instantiate_quant env unifier) quants |> Util.option_these +let instantiate_quants env quants unifier = List.map (instantiate_quant env unifier) quants |> Util.option_these (* During typechecking, we can run into the following issue, where we have a function like @@ -2139,10 +2231,7 @@ let rec ambiguous_vars' (Typ_aux (aux, _)) = | _ -> KidSet.empty and ambiguous_arg_vars (A_aux (aux, _)) = - match aux with - | A_bool nc -> ambiguous_nc_vars nc - | A_nexp nexp -> ambiguous_nexp_vars nexp - | _ -> KidSet.empty + match aux with A_bool nc -> ambiguous_nc_vars nc | A_nexp nexp -> ambiguous_nexp_vars nexp | _ -> KidSet.empty and ambiguous_nc_vars (NC_aux (aux, _)) = match aux with @@ -2151,8 +2240,7 @@ and ambiguous_nc_vars (NC_aux (aux, _)) = | NC_bounded_lt (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) | NC_bounded_ge (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) | NC_bounded_gt (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) - | NC_equal (n1, n2) | NC_not_equal (n1, n2) -> - KidSet.union (ambiguous_nexp_vars n1) (ambiguous_nexp_vars n2) + | NC_equal (n1, n2) | NC_not_equal (n1, n2) -> KidSet.union (ambiguous_nexp_vars n1) (ambiguous_nexp_vars n2) | _ -> KidSet.empty and ambiguous_nexp_vars (Nexp_aux (aux, _)) = @@ -2166,68 +2254,58 @@ let ambiguous_vars typ = let rec is_typ_inhabited env (Typ_aux (aux, l) as typ) = match aux with - | Typ_tuple typs -> - List.for_all (is_typ_inhabited env) typs + | Typ_tuple typs -> List.for_all (is_typ_inhabited env) typs | Typ_app (id, [A_aux (A_nexp len, _); _]) when Id.compare id (mk_id "bitvector") = 0 -> - prove __POS__ env (nc_gteq len (nint 0)) + prove __POS__ env (nc_gteq len (nint 0)) | Typ_app (id, [A_aux (A_nexp len, _); _; A_aux (A_typ elem_typ, _)]) when Id.compare id (mk_id "vector") = 0 -> - prove __POS__ env (nc_gteq len (nint 0)) - | Typ_app (id, _) when Id.compare id (mk_id "list") = 0 -> - true + prove __POS__ env (nc_gteq len (nint 0)) + | Typ_app (id, _) when Id.compare id (mk_id "list") = 0 -> true | Typ_app (id, args) when Env.is_variant id env -> - let typq, constructors = Env.get_variant id env in - let kopts, _ = quant_split typq in - let unifiers = List.fold_left2 (fun kb kopt arg -> KBindings.add (kopt_kid kopt) arg kb) KBindings.empty kopts args in - List.exists (fun (Tu_aux (Tu_ty_id (typ, id), _)) -> - is_typ_inhabited env (subst_unifiers unifiers typ) - ) constructors + let typq, constructors = Env.get_variant id env in + let kopts, _ = quant_split typq in + let unifiers = + List.fold_left2 (fun kb kopt arg -> KBindings.add (kopt_kid kopt) arg kb) KBindings.empty kopts args + in + List.exists + (fun (Tu_aux (Tu_ty_id (typ, id), _)) -> is_typ_inhabited env (subst_unifiers unifiers typ)) + constructors | Typ_id id when Env.is_record id env -> - let _, fields = Env.get_record id env in - List.for_all (fun (typ, field) -> - is_typ_inhabited env typ - ) fields + let _, fields = Env.get_record id env in + List.for_all (fun (typ, field) -> is_typ_inhabited env typ) fields | Typ_app (id, args) when Env.is_record id env -> - let typq, fields = Env.get_record id env in - let kopts, _ = quant_split typq in - let unifiers = List.fold_left2 (fun kb kopt arg -> KBindings.add (kopt_kid kopt) arg kb) KBindings.empty kopts args in - List.for_all (fun (typ, field) -> - is_typ_inhabited env (subst_unifiers unifiers typ) - ) fields - | Typ_app (_, args) -> - List.for_all (is_typ_arg_inhabited env) args - | (Typ_exist _) -> - let typ, env = bind_existential l None typ env in - is_typ_inhabited env typ - | Typ_id _ -> - true - | Typ_var _ -> - true - | Typ_fn _ | Typ_bidir _ -> - Reporting.unreachable l __POS__ "Inhabitedness check applied to function or mapping type" - | Typ_internal_unknown -> - Reporting.unreachable l __POS__ "Inhabitedness check applied to unknown type" - -and is_typ_arg_inhabited env (A_aux (aux, l)) = - match aux with - | A_typ typ -> is_typ_inhabited env typ - | _ -> true - + let typq, fields = Env.get_record id env in + let kopts, _ = quant_split typq in + let unifiers = + List.fold_left2 (fun kb kopt arg -> KBindings.add (kopt_kid kopt) arg kb) KBindings.empty kopts args + in + List.for_all (fun (typ, field) -> is_typ_inhabited env (subst_unifiers unifiers typ)) fields + | Typ_app (_, args) -> List.for_all (is_typ_arg_inhabited env) args + | Typ_exist _ -> + let typ, env = bind_existential l None typ env in + is_typ_inhabited env typ + | Typ_id _ -> true + | Typ_var _ -> true + | Typ_fn _ | Typ_bidir _ -> Reporting.unreachable l __POS__ "Inhabitedness check applied to function or mapping type" + | Typ_internal_unknown -> Reporting.unreachable l __POS__ "Inhabitedness check applied to unknown type" + +and is_typ_arg_inhabited env (A_aux (aux, l)) = match aux with A_typ typ -> is_typ_inhabited env typ | _ -> true + (**************************************************************************) (* 3.5. Subtyping with existentials *) (**************************************************************************) let destruct_atom_nexp env typ = match Env.expand_synonyms env typ with - | Typ_aux (Typ_app (f, [A_aux (A_nexp n, _)]), _) - when string_of_id f = "atom" || string_of_id f = "implicit" -> Some n + | Typ_aux (Typ_app (f, [A_aux (A_nexp n, _)]), _) when string_of_id f = "atom" || string_of_id f = "implicit" -> + Some n | Typ_aux (Typ_app (f, [A_aux (A_nexp n, _); A_aux (A_nexp m, _)]), _) - when string_of_id f = "range" && nexp_identical n m -> Some n + when string_of_id f = "range" && nexp_identical n m -> + Some n | _ -> None let destruct_atom_bool env typ = match Env.expand_synonyms env typ with - | Typ_aux (Typ_app (f, [A_aux (A_bool nc, _)]), _) when string_of_id f = "atom_bool" -> - Some nc + | Typ_aux (Typ_app (f, [A_aux (A_bool nc, _)]), _) when string_of_id f = "atom_bool" -> Some nc | _ -> None (* The kid_order function takes a set of Int-kinded kids, and returns @@ -2239,57 +2317,85 @@ let destruct_atom_bool env typ = let rec kid_order_nexp kind_map (Nexp_aux (aux, l)) = match aux with | Nexp_var kid when KBindings.mem kid kind_map -> - ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) + ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) | Nexp_var _ | Nexp_id _ | Nexp_constant _ -> ([], kind_map) | Nexp_exp nexp | Nexp_neg nexp -> kid_order_nexp kind_map nexp | Nexp_times (nexp1, nexp2) | Nexp_sum (nexp1, nexp2) | Nexp_minus (nexp1, nexp2) -> - let (ord, kids) = kid_order_nexp kind_map nexp1 in - let (ord', kids) = kid_order_nexp kids nexp2 in - (ord @ ord', kids) + let ord, kids = kid_order_nexp kind_map nexp1 in + let ord', kids = kid_order_nexp kids nexp2 in + (ord @ ord', kids) | Nexp_app (id, nexps) -> - List.fold_left (fun (ord, kids) nexp -> let (ord', kids) = kid_order_nexp kids nexp in (ord @ ord', kids)) ([], kind_map) nexps - + List.fold_left + (fun (ord, kids) nexp -> + let ord', kids = kid_order_nexp kids nexp in + (ord @ ord', kids) + ) + ([], kind_map) nexps let rec kid_order kind_map (Typ_aux (aux, l) as typ) = match aux with | Typ_var kid when KBindings.mem kid kind_map -> - ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) + ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) | Typ_id _ | Typ_var _ -> ([], kind_map) | Typ_tuple typs -> - List.fold_left (fun (ord, kids) typ -> let (ord', kids) = kid_order kids typ in (ord @ ord', kids)) ([], kind_map) typs + List.fold_left + (fun (ord, kids) typ -> + let ord', kids = kid_order kids typ in + (ord @ ord', kids) + ) + ([], kind_map) typs | Typ_app (_, args) -> - List.fold_left (fun (ord, kids) arg -> let (ord', kids) = kid_order_arg kids arg in (ord @ ord', kids)) ([], kind_map) args - | Typ_fn _ | Typ_bidir _ | Typ_exist _ -> typ_error Env.empty l ("Existential or function type cannot appear within existential type: " ^ string_of_typ typ) + List.fold_left + (fun (ord, kids) arg -> + let ord', kids = kid_order_arg kids arg in + (ord @ ord', kids) + ) + ([], kind_map) args + | Typ_fn _ | Typ_bidir _ | Typ_exist _ -> + typ_error Env.empty l ("Existential or function type cannot appear within existential type: " ^ string_of_typ typ) | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" + and kid_order_arg kind_map (A_aux (aux, l)) = match aux with | A_typ typ -> kid_order kind_map typ | A_nexp nexp -> kid_order_nexp kind_map nexp | A_bool nc -> kid_order_constraint kind_map nc | A_order _ -> ([], kind_map) + and kid_order_constraint kind_map (NC_aux (aux, l)) = match aux with - | NC_var kid | NC_set (kid, _) when KBindings.mem kid kind_map -> - ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) + | (NC_var kid | NC_set (kid, _)) when KBindings.mem kid kind_map -> + ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) | NC_var _ | NC_set _ -> ([], kind_map) | NC_true | NC_false -> ([], kind_map) - | NC_equal (n1, n2) | NC_not_equal (n1, n2) - | NC_bounded_le (n1, n2) | NC_bounded_ge (n1, n2) - | NC_bounded_lt (n1, n2) | NC_bounded_gt (n1, n2) -> - let ord1, kind_map = kid_order_nexp kind_map n1 in - let ord2, kind_map = kid_order_nexp kind_map n2 in - (ord1 @ ord2, kind_map) + | NC_equal (n1, n2) + | NC_not_equal (n1, n2) + | NC_bounded_le (n1, n2) + | NC_bounded_ge (n1, n2) + | NC_bounded_lt (n1, n2) + | NC_bounded_gt (n1, n2) -> + let ord1, kind_map = kid_order_nexp kind_map n1 in + let ord2, kind_map = kid_order_nexp kind_map n2 in + (ord1 @ ord2, kind_map) | NC_app (_, args) -> - List.fold_left (fun (ord, kind_map) arg -> let ord', kind_map = kid_order_arg kind_map arg in (ord @ ord', kind_map)) - ([], kind_map) args + List.fold_left + (fun (ord, kind_map) arg -> + let ord', kind_map = kid_order_arg kind_map arg in + (ord @ ord', kind_map) + ) + ([], kind_map) args | NC_and (nc1, nc2) | NC_or (nc1, nc2) -> - let ord1, kind_map = kid_order_constraint kind_map nc1 in - let ord2, kind_map = kid_order_constraint kind_map nc2 in - (ord1 @ ord2, kind_map) + let ord1, kind_map = kid_order_constraint kind_map nc1 in + let ord2, kind_map = kid_order_constraint kind_map nc2 in + (ord1 @ ord2, kind_map) let alpha_equivalent env typ1 typ2 = let counter = ref 0 in - let new_kid () = let kid = mk_kid ("alpha#" ^ string_of_int !counter) in (incr counter; kid) in + let new_kid () = + let kid = mk_kid ("alpha#" ^ string_of_int !counter) in + incr counter; + kid + in let rec relabel (Typ_aux (aux, l)) = let relabelled_aux = @@ -2300,38 +2406,42 @@ let alpha_equivalent env typ1 typ2 = | Typ_bidir (typ1, typ2) -> Typ_bidir (relabel typ1, relabel typ2) | Typ_tuple typs -> Typ_tuple (List.map relabel typs) | Typ_exist (kopts, nc, typ) -> - let kind_map = List.fold_left (fun m kopt -> KBindings.add (kopt_kid kopt) (kopt_kind kopt) m) KBindings.empty kopts in - let (kopts1, kind_map) = kid_order_constraint kind_map nc in - let (kopts2, _) = kid_order kind_map typ in - let kopts = kopts1 @ kopts2 in - let kopts = List.map (fun kopt -> (kopt_kid kopt, mk_kopt (unaux_kind (kopt_kind kopt)) (new_kid ()))) kopts in - let nc = List.fold_left (fun nc (kid, nk) -> constraint_subst kid (arg_kopt nk) nc) nc kopts in - let typ = List.fold_left (fun nc (kid, nk) -> typ_subst kid (arg_kopt nk) nc) typ kopts in - let kopts = List.map snd kopts in - Typ_exist (kopts, nc, typ) - | Typ_app (id, args) -> - Typ_app (id, List.map relabel_arg args) + let kind_map = + List.fold_left (fun m kopt -> KBindings.add (kopt_kid kopt) (kopt_kind kopt) m) KBindings.empty kopts + in + let kopts1, kind_map = kid_order_constraint kind_map nc in + let kopts2, _ = kid_order kind_map typ in + let kopts = kopts1 @ kopts2 in + let kopts = + List.map (fun kopt -> (kopt_kid kopt, mk_kopt (unaux_kind (kopt_kind kopt)) (new_kid ()))) kopts + in + let nc = List.fold_left (fun nc (kid, nk) -> constraint_subst kid (arg_kopt nk) nc) nc kopts in + let typ = List.fold_left (fun nc (kid, nk) -> typ_subst kid (arg_kopt nk) nc) typ kopts in + let kopts = List.map snd kopts in + Typ_exist (kopts, nc, typ) + | Typ_app (id, args) -> Typ_app (id, List.map relabel_arg args) in Typ_aux (relabelled_aux, l) and relabel_arg (A_aux (aux, l) as arg) = (* FIXME relabel constraint *) - match aux with - | A_nexp _ | A_order _ | A_bool _ -> arg - | A_typ typ -> A_aux (A_typ (relabel typ), l) + match aux with A_nexp _ | A_order _ | A_bool _ -> arg | A_typ typ -> A_aux (A_typ (relabel typ), l) in let typ1 = relabel (Env.expand_synonyms env typ1) in counter := 0; let typ2 = relabel (Env.expand_synonyms env typ2) in typ_debug (lazy ("Alpha equivalence for " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)); - if typ_identical typ1 typ2 - then (typ_debug (lazy "alpha-equivalent"); true) - else (typ_debug (lazy "Not alpha-equivalent"); false) + if typ_identical typ1 typ2 then ( + typ_debug (lazy "alpha-equivalent"); + true + ) + else ( + typ_debug (lazy "Not alpha-equivalent"); + false + ) let unifier_constraint env (v, arg) = - match arg with - | A_aux (A_nexp nexp, _) -> Env.add_constraint (nc_eq (nvar v) nexp) env - | _ -> env + match arg with A_aux (A_nexp nexp, _) -> Env.add_constraint (nc_eq (nvar v) nexp) env | _ -> env let canonicalize env typ = let typ = Env.expand_synonyms env typ in @@ -2339,140 +2449,150 @@ let canonicalize env typ = match aux with | Typ_var v -> Typ_aux (Typ_var v, l) | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l) - | Typ_id id when string_of_id id = "int" -> - exist_typ l (fun _ -> nc_true) (fun v -> atom_typ (nvar v)) + | Typ_id id when string_of_id id = "int" -> exist_typ l (fun _ -> nc_true) (fun v -> atom_typ (nvar v)) | Typ_id id -> Typ_aux (Typ_id id, l) | Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]) when string_of_id id = "range" -> - exist_typ l (fun v -> nc_and (nc_lteq lo (nvar v)) (nc_lteq (nvar v) hi)) (fun v -> atom_typ (nvar v)) - | Typ_app (id, args) -> - Typ_aux (Typ_app (id, List.map canon_arg args), l) + exist_typ l (fun v -> nc_and (nc_lteq lo (nvar v)) (nc_lteq (nvar v) hi)) (fun v -> atom_typ (nvar v)) + | Typ_app (id, args) -> Typ_aux (Typ_app (id, List.map canon_arg args), l) | Typ_tuple typs -> - let typs = List.map canon typs in - let fold_exist (kids, nc, typs) typ = - match destruct_exist typ with - | Some (kids', nc', typ') -> (kids @ kids', nc_and nc nc', typs @ [typ']) - | None -> (kids, nc, typs @ [typ]) - in - let kids, nc, typs = List.fold_left fold_exist ([], nc_true, []) typs in - if kids = [] then - Typ_aux (Typ_tuple typs, l) - else - Typ_aux (Typ_exist (kids, nc, Typ_aux (Typ_tuple typs, l)), l) - | Typ_exist (kids, nc, typ) -> - begin match destruct_exist (canon typ) with - | Some (kids', nc', typ') -> - Typ_aux (Typ_exist (kids @ kids', nc_and nc nc', typ'), l) - | None -> Typ_aux (Typ_exist (kids, nc, typ), l) - end - | Typ_fn _ | Typ_bidir _ -> raise (Reporting.err_unreachable l __POS__ "Function type passed to Type_check.canonicalize") - and canon_arg (A_aux (aux, l)) = - A_aux ((match aux with - | A_typ typ -> A_typ (canon typ) - | arg -> arg), - l) - in + let typs = List.map canon typs in + let fold_exist (kids, nc, typs) typ = + match destruct_exist typ with + | Some (kids', nc', typ') -> (kids @ kids', nc_and nc nc', typs @ [typ']) + | None -> (kids, nc, typs @ [typ]) + in + let kids, nc, typs = List.fold_left fold_exist ([], nc_true, []) typs in + if kids = [] then Typ_aux (Typ_tuple typs, l) else Typ_aux (Typ_exist (kids, nc, Typ_aux (Typ_tuple typs, l)), l) + | Typ_exist (kids, nc, typ) -> begin + match destruct_exist (canon typ) with + | Some (kids', nc', typ') -> Typ_aux (Typ_exist (kids @ kids', nc_and nc nc', typ'), l) + | None -> Typ_aux (Typ_exist (kids, nc, typ), l) + end + | Typ_fn _ | Typ_bidir _ -> + raise (Reporting.err_unreachable l __POS__ "Function type passed to Type_check.canonicalize") + and canon_arg (A_aux (aux, l)) = A_aux ((match aux with A_typ typ -> A_typ (canon typ) | arg -> arg), l) in canon typ let rec subtyp l env typ1 typ2 = let (Typ_aux (typ_aux1, _) as typ1) = Env.expand_synonyms env typ1 in let (Typ_aux (typ_aux2, _) as typ2) = Env.expand_synonyms env typ2 in typ_print (lazy (("Subtype " |> Util.green |> Util.clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)); - match destruct_numeric typ1, destruct_numeric typ2 with + match (destruct_numeric typ1, destruct_numeric typ2) with (* Ensure alpha equivalent types are always subtypes of one another - this ensures that we can always re-check inferred types. *) | _, _ when alpha_equivalent env typ1 typ2 -> () (* Special cases for two numeric (atom) types *) | Some (kids1, nc1, nexp1), Some ([], _, nexp2) -> - let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in - let prop = nc_eq nexp1 nexp2 in - if prove __POS__ env prop then () else typ_raise env l (Err_subtype (typ1, typ2, Some prop, Env.get_constraint_reasons env, Env.get_typ_var_locs env)) + let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in + let prop = nc_eq nexp1 nexp2 in + if prove __POS__ env prop then () + else + typ_raise env l (Err_subtype (typ1, typ2, Some prop, Env.get_constraint_reasons env, Env.get_typ_var_locs env)) | Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) -> - let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in - let env = add_typ_vars l (List.map (mk_kopt K_int) (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2)))) env in - let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in - if not (kids2 = []) then typ_error env l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else (); - (* TODO: Check this *) - let _vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) (Env.get_typ_vars env) in - begin match Constraint.call_smt l (nc_eq nexp1 nexp2) with - | Constraint.Sat -> - let env = Env.add_constraint (nc_eq nexp1 nexp2) env in - if prove __POS__ env nc2 then - () - else - typ_raise env l (Err_subtype (typ1, typ2, Some nc2, Env.get_constraint_reasons env, Env.get_typ_var_locs env)) - | _ -> - typ_error env l ("Constraint " ^ string_of_n_constraint (nc_eq nexp1 nexp2) ^ " is not satisfiable") - end - | _, _ -> - match typ_aux1, typ_aux2 with - | _, Typ_internal_unknown when Env.allow_unknowns env -> () - - | Typ_app (id1, _), Typ_id id2 when string_of_id id1 = "atom_bool" && string_of_id id2 = "bool" -> () - - | Typ_tuple typs1, Typ_tuple typs2 when List.length typs1 = List.length typs2 -> - List.iter2 (subtyp l env) typs1 typs2 - - | Typ_app (id1, args1), Typ_app (id2, args2) when Id.compare id1 id2 = 0 && List.length args1 = List.length args2 -> - List.iter2 (subtyp_arg l env) args1 args2 - - | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> () - | Typ_id id1, Typ_app (id2, []) when Id.compare id1 id2 = 0 -> () - | Typ_app (id1, []), Typ_id id2 when Id.compare id1 id2 = 0 -> () - - | Typ_fn (typ_args1, ret_typ1), Typ_fn (typ_args2, ret_typ2) -> - if List.compare_lengths typ_args1 typ_args2 <> 0 then ( - typ_error env l "Function types do not have the same number of arguments in subtype check" - ); - List.iter2 (subtyp l env) typ_args2 typ_args1; - subtyp l env ret_typ1 ret_typ2 - - | _, _ -> - match destruct_exist_plain typ1, destruct_exist (canonicalize env typ2) with - | Some (kopts, nc, typ1), _ -> - let env = add_existential l kopts nc env in subtyp l env typ1 typ2 - | None, Some (kopts, nc, typ2) -> - typ_debug (lazy "Subtype check with unification"); - let orig_env = env in - let typ1, env = bind_existential l None (canonicalize env typ1) env in - let env = add_typ_vars l kopts env in - let kids' = KidSet.elements (KidSet.diff (KidSet.of_list (List.map kopt_kid kopts)) (tyvars_of_typ typ2)) in - if not (kids' = []) then typ_error env l "Universally quantified constraint generated" else (); - let unifiers = - try unify l env (KidSet.diff (tyvars_of_typ typ2) (tyvars_of_typ typ1)) typ2 typ1 with - | Unification_error (_, m) -> typ_error env l m - in - let nc = List.fold_left (fun nc (kid, uvar) -> constraint_subst kid uvar nc) nc (KBindings.bindings unifiers) in - let env = List.fold_left unifier_constraint env (KBindings.bindings unifiers) in - if prove __POS__ env nc then () - else typ_raise env l (Err_subtype (typ1, typ2, Some nc, Env.get_constraint_reasons orig_env, Env.get_typ_var_locs env)) - | None, None -> typ_raise env l (Err_subtype (typ1, typ2, None, Env.get_constraint_reasons env, Env.get_typ_var_locs env)) + let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in + let env = + add_typ_vars l + (List.map (mk_kopt K_int) (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2)))) + env + in + let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in + if not (kids2 = []) then + typ_error env l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) + else (); + (* TODO: Check this *) + let _vars = + KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) (Env.get_typ_vars env) + in + begin + match Constraint.call_smt l (nc_eq nexp1 nexp2) with + | Constraint.Sat -> + let env = Env.add_constraint (nc_eq nexp1 nexp2) env in + if prove __POS__ env nc2 then () + else + typ_raise env l + (Err_subtype (typ1, typ2, Some nc2, Env.get_constraint_reasons env, Env.get_typ_var_locs env)) + | _ -> typ_error env l ("Constraint " ^ string_of_n_constraint (nc_eq nexp1 nexp2) ^ " is not satisfiable") + end + | _, _ -> ( + match (typ_aux1, typ_aux2) with + | _, Typ_internal_unknown when Env.allow_unknowns env -> () + | Typ_app (id1, _), Typ_id id2 when string_of_id id1 = "atom_bool" && string_of_id id2 = "bool" -> () + | Typ_tuple typs1, Typ_tuple typs2 when List.length typs1 = List.length typs2 -> + List.iter2 (subtyp l env) typs1 typs2 + | Typ_app (id1, args1), Typ_app (id2, args2) when Id.compare id1 id2 = 0 && List.length args1 = List.length args2 + -> + List.iter2 (subtyp_arg l env) args1 args2 + | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> () + | Typ_id id1, Typ_app (id2, []) when Id.compare id1 id2 = 0 -> () + | Typ_app (id1, []), Typ_id id2 when Id.compare id1 id2 = 0 -> () + | Typ_fn (typ_args1, ret_typ1), Typ_fn (typ_args2, ret_typ2) -> + if List.compare_lengths typ_args1 typ_args2 <> 0 then + typ_error env l "Function types do not have the same number of arguments in subtype check"; + List.iter2 (subtyp l env) typ_args2 typ_args1; + subtyp l env ret_typ1 ret_typ2 + | _, _ -> ( + match (destruct_exist_plain typ1, destruct_exist (canonicalize env typ2)) with + | Some (kopts, nc, typ1), _ -> + let env = add_existential l kopts nc env in + subtyp l env typ1 typ2 + | None, Some (kopts, nc, typ2) -> + typ_debug (lazy "Subtype check with unification"); + let orig_env = env in + let typ1, env = bind_existential l None (canonicalize env typ1) env in + let env = add_typ_vars l kopts env in + let kids' = + KidSet.elements (KidSet.diff (KidSet.of_list (List.map kopt_kid kopts)) (tyvars_of_typ typ2)) + in + if not (kids' = []) then typ_error env l "Universally quantified constraint generated" else (); + let unifiers = + try unify l env (KidSet.diff (tyvars_of_typ typ2) (tyvars_of_typ typ1)) typ2 typ1 + with Unification_error (_, m) -> typ_error env l m + in + let nc = + List.fold_left (fun nc (kid, uvar) -> constraint_subst kid uvar nc) nc (KBindings.bindings unifiers) + in + let env = List.fold_left unifier_constraint env (KBindings.bindings unifiers) in + if prove __POS__ env nc then () + else + typ_raise env l + (Err_subtype (typ1, typ2, Some nc, Env.get_constraint_reasons orig_env, Env.get_typ_var_locs env)) + | None, None -> + typ_raise env l (Err_subtype (typ1, typ2, None, Env.get_constraint_reasons env, Env.get_typ_var_locs env)) + ) + ) and subtyp_arg l env (A_aux (aux1, _) as arg1) (A_aux (aux2, _) as arg2) = - typ_print (lazy (("Subtype arg " |> Util.green |> Util.clear) ^ string_of_typ_arg arg1 ^ " and " ^ string_of_typ_arg arg2)); - let raise_failed_constraint nc = typ_raise env l (Err_failed_constraint (nc, Env.get_locals env, Env.get_constraints env)) in - match aux1, aux2 with + typ_print + (lazy (("Subtype arg " |> Util.green |> Util.clear) ^ string_of_typ_arg arg1 ^ " and " ^ string_of_typ_arg arg2)); + let raise_failed_constraint nc = + typ_raise env l (Err_failed_constraint (nc, Env.get_locals env, Env.get_constraints env)) + in + match (aux1, aux2) with | A_nexp n1, A_nexp n2 -> - let check = nc_eq n1 n2 in - if not (prove __POS__ env check) then raise_failed_constraint check + let check = nc_eq n1 n2 in + if not (prove __POS__ env check) then raise_failed_constraint check | A_typ typ1, A_typ typ2 -> subtyp l env typ1 typ2 | A_order ord1, A_order ord2 when ord_identical ord1 ord2 -> () | A_bool nc1, A_bool nc2 -> - let check = (nc_and (nc_or (nc_not nc1) nc2) (nc_or (nc_not nc2) nc1)) in - if not (prove __POS__ env check) then raise_failed_constraint check + let check = nc_and (nc_or (nc_not nc1) nc2) (nc_or (nc_not nc2) nc1) in + if not (prove __POS__ env check) then raise_failed_constraint check | _, _ -> typ_error env l "Mismatched argument types in sub-typing check" let typ_equality l env typ1 typ2 = - subtyp l env typ1 typ2; subtyp l env typ2 typ1 + subtyp l env typ1 typ2; + subtyp l env typ2 typ1 let subtype_check env typ1 typ2 = - try subtyp Parse_ast.Unknown env typ1 typ2; true with - | Type_error _ -> false + try + subtyp Parse_ast.Unknown env typ1 typ2; + true + with Type_error _ -> false (**************************************************************************) (* 4. Removing sizeof expressions *) (**************************************************************************) -exception No_simple_rewrite;; +exception No_simple_rewrite let rec move_to_front p ys = function | x :: xs when p x -> x :: (ys @ xs) @@ -2483,94 +2603,84 @@ let rec rewrite_sizeof' env (Nexp_aux (aux, l) as nexp) = let mk_exp exp = mk_exp ~loc:l exp in match aux with | Nexp_var v -> - (* Use a simple heuristic to find the most likely local we can - use, and move it to the front of the list. *) - let str = string_of_kid v in - let likely = - try let n = if str.[1] = '_' then 2 else 1 in String.sub str n (String.length str - n) with - | Invalid_argument _ -> str - in - let locals = Env.get_locals env |> Bindings.bindings in - let locals = move_to_front (fun local -> likely = string_of_id (fst local)) [] locals in - let same_size (local, (_, Typ_aux (aux, _))) = - match aux with - | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _)]) - when string_of_id id = "atom" && Kid.compare v v' = 0 -> true - - | Typ_app (id, [A_aux (A_nexp n, _)]) when string_of_id id = "atom" -> - prove __POS__ env (nc_eq (nvar v) n) - - | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _); _]) when string_of_id id = "bitvector" -> - Kid.compare v v' = 0 - - | _ -> - false - in - begin match List.find_opt same_size locals with - | Some (id, (_, typ)) -> mk_exp (E_app (mk_id "__size", [mk_exp (E_id id)])) - | None -> raise No_simple_rewrite - end - - | Nexp_constant c -> - mk_lit_exp (L_num c) - + (* Use a simple heuristic to find the most likely local we can + use, and move it to the front of the list. *) + let str = string_of_kid v in + let likely = + try + let n = if str.[1] = '_' then 2 else 1 in + String.sub str n (String.length str - n) + with Invalid_argument _ -> str + in + let locals = Env.get_locals env |> Bindings.bindings in + let locals = move_to_front (fun local -> likely = string_of_id (fst local)) [] locals in + let same_size (local, (_, Typ_aux (aux, _))) = + match aux with + | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _)]) + when string_of_id id = "atom" && Kid.compare v v' = 0 -> + true + | Typ_app (id, [A_aux (A_nexp n, _)]) when string_of_id id = "atom" -> prove __POS__ env (nc_eq (nvar v) n) + | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _); _]) when string_of_id id = "bitvector" -> + Kid.compare v v' = 0 + | _ -> false + in + begin + match List.find_opt same_size locals with + | Some (id, (_, typ)) -> mk_exp (E_app (mk_id "__size", [mk_exp (E_id id)])) + | None -> raise No_simple_rewrite + end + | Nexp_constant c -> mk_lit_exp (L_num c) | Nexp_neg nexp -> - let exp = rewrite_sizeof' env nexp in - mk_exp (E_app (mk_id "negate_atom", [exp])) - + let exp = rewrite_sizeof' env nexp in + mk_exp (E_app (mk_id "negate_atom", [exp])) | Nexp_sum (nexp1, nexp2) -> - let exp1 = rewrite_sizeof' env nexp1 in - let exp2 = rewrite_sizeof' env nexp2 in - mk_exp (E_app (mk_id "add_atom", [exp1; exp2])) - + let exp1 = rewrite_sizeof' env nexp1 in + let exp2 = rewrite_sizeof' env nexp2 in + mk_exp (E_app (mk_id "add_atom", [exp1; exp2])) | Nexp_minus (nexp1, nexp2) -> - let exp1 = rewrite_sizeof' env nexp1 in - let exp2 = rewrite_sizeof' env nexp2 in - mk_exp (E_app (mk_id "sub_atom", [exp1; exp2])) - + let exp1 = rewrite_sizeof' env nexp1 in + let exp2 = rewrite_sizeof' env nexp2 in + mk_exp (E_app (mk_id "sub_atom", [exp1; exp2])) | Nexp_times (nexp1, nexp2) -> - let exp1 = rewrite_sizeof' env nexp1 in - let exp2 = rewrite_sizeof' env nexp2 in - mk_exp (E_app (mk_id "mult_atom", [exp1; exp2])) - + let exp1 = rewrite_sizeof' env nexp1 in + let exp2 = rewrite_sizeof' env nexp2 in + mk_exp (E_app (mk_id "mult_atom", [exp1; exp2])) | Nexp_exp nexp -> - let exp = rewrite_sizeof' env nexp in - mk_exp (E_app (mk_id "pow2", [exp])) - + let exp = rewrite_sizeof' env nexp in + mk_exp (E_app (mk_id "pow2", [exp])) (* SMT solver div/mod is euclidian, so we must use those versions of div and mod in lib/smt.sail *) | Nexp_app (id, [nexp1; nexp2]) when string_of_id id = "div" -> - let exp1 = rewrite_sizeof' env nexp1 in - let exp2 = rewrite_sizeof' env nexp2 in - mk_exp (E_app (mk_id "ediv_int", [exp1; exp2])) - + let exp1 = rewrite_sizeof' env nexp1 in + let exp2 = rewrite_sizeof' env nexp2 in + mk_exp (E_app (mk_id "ediv_int", [exp1; exp2])) | Nexp_app (id, [nexp1; nexp2]) when string_of_id id = "mod" -> - let exp1 = rewrite_sizeof' env nexp1 in - let exp2 = rewrite_sizeof' env nexp2 in - mk_exp (E_app (mk_id "emod_int", [exp1; exp2])) - - | Nexp_app _ | Nexp_id _ -> - typ_error env l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")") + let exp1 = rewrite_sizeof' env nexp1 in + let exp2 = rewrite_sizeof' env nexp2 in + mk_exp (E_app (mk_id "emod_int", [exp1; exp2])) + | Nexp_app _ | Nexp_id _ -> typ_error env l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")") let rewrite_sizeof l env nexp = - try rewrite_sizeof' env nexp with - | No_simple_rewrite -> - let locals = Env.get_locals env |> Bindings.bindings in - let same_size (local, (_, Typ_aux (aux, _))) = - match aux with - | Typ_app (id, [A_aux (A_nexp n, _)]) when string_of_id id = "atom" -> - prove __POS__ env (nc_eq nexp n) - | _ -> false - in - begin match List.find_opt same_size locals with - | Some (id, (_, typ)) -> mk_exp (E_app (mk_id "__size", [mk_exp (E_id id)])) - | None -> - match solve_unique env nexp with - | Some n -> mk_lit_exp (L_num n) - | None -> typ_error env l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")") - end + try rewrite_sizeof' env nexp + with No_simple_rewrite -> + let locals = Env.get_locals env |> Bindings.bindings in + let same_size (local, (_, Typ_aux (aux, _))) = + match aux with + | Typ_app (id, [A_aux (A_nexp n, _)]) when string_of_id id = "atom" -> prove __POS__ env (nc_eq nexp n) + | _ -> false + in + begin + match List.find_opt same_size locals with + | Some (id, (_, typ)) -> mk_exp (E_app (mk_id "__size", [mk_exp (E_id id)])) + | None -> ( + match solve_unique env nexp with + | Some n -> mk_lit_exp (L_num n) + | None -> typ_error env l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")") + ) + end let rec rewrite_nc env (NC_aux (nc_aux, l)) = mk_exp ~loc:l (rewrite_nc_aux l env nc_aux) + and rewrite_nc_aux l env = let mk_exp exp = mk_exp ~loc:l exp in function @@ -2584,18 +2694,15 @@ and rewrite_nc_aux l env = | NC_or (nc1, nc2) -> E_app_infix (rewrite_nc env nc1, mk_id "|", rewrite_nc env nc2) | NC_false -> E_lit (mk_lit L_false) | NC_true -> E_lit (mk_lit L_true) - | NC_set (kid, []) -> E_lit (mk_lit (L_false)) + | NC_set (kid, []) -> E_lit (mk_lit L_false) | NC_set (kid, int :: ints) -> - let kid_eq kid int = nc_eq (nvar kid) (nconstant int) in - unaux_exp (rewrite_nc env (List.fold_left (fun nc int -> nc_or nc (kid_eq kid int)) (kid_eq kid int) ints)) - | NC_app (f, [A_aux (A_bool nc, _)]) when string_of_id f = "not" -> - E_app (mk_id "not_bool", [rewrite_nc env nc]) - | NC_app (f, args) -> - unaux_exp (rewrite_nc env (Env.expand_constraint_synonyms env (mk_nc (NC_app (f, args))))) + let kid_eq kid int = nc_eq (nvar kid) (nconstant int) in + unaux_exp (rewrite_nc env (List.fold_left (fun nc int -> nc_or nc (kid_eq kid int)) (kid_eq kid int) ints)) + | NC_app (f, [A_aux (A_bool nc, _)]) when string_of_id f = "not" -> E_app (mk_id "not_bool", [rewrite_nc env nc]) + | NC_app (f, args) -> unaux_exp (rewrite_nc env (Env.expand_constraint_synonyms env (mk_nc (NC_app (f, args))))) | NC_var v -> - (* Would be better to translate change E_sizeof to take a kid, then rewrite to E_sizeof *) - E_id (id_of_kid v) - + (* Would be better to translate change E_sizeof to take a kid, then rewrite to E_sizeof *) + E_id (id_of_kid v) (**************************************************************************) (* 5. Type checking expressions *) @@ -2605,64 +2712,39 @@ and rewrite_nc_aux l env = of these type annotations. The extra typ option is the expected type, that is, the type that the AST node was checked against, if there was one. *) type tannot' = { - env : Env.t; - typ : typ; - mutable monadic : effect; - expected : typ option; - instantiation : typ_arg KBindings.t option - } + env : Env.t; + typ : typ; + mutable monadic : effect; + expected : typ option; + instantiation : typ_arg KBindings.t option; +} type tannot = tannot' option * uannot let untyped_annot tannot = snd tannot -let mk_tannot ?uannot:(uannot=empty_uannot) env typ : tannot = - (Some { - env = env; - typ = Env.expand_synonyms env typ; - monadic = no_effect; - expected = None; - instantiation = None - }, - uannot) - -let mk_expected_tannot ?uannot:(uannot=empty_uannot) env typ expected : tannot = - (Some { - env = env; - typ = Env.expand_synonyms env typ; - monadic = no_effect; - expected = expected; - instantiation = None - }, - uannot) - -let get_instantiations = function - | (None, _) -> None - | (Some t, _) -> t.instantiation - +let mk_tannot ?(uannot = empty_uannot) env typ : tannot = + (Some { env; typ = Env.expand_synonyms env typ; monadic = no_effect; expected = None; instantiation = None }, uannot) + +let mk_expected_tannot ?(uannot = empty_uannot) env typ expected : tannot = + (Some { env; typ = Env.expand_synonyms env typ; monadic = no_effect; expected; instantiation = None }, uannot) + +let get_instantiations = function None, _ -> None | Some t, _ -> t.instantiation + let empty_tannot = (None, empty_uannot) -let is_empty_tannot tannot = match fst tannot with - | None -> true - | Some _ -> false +let is_empty_tannot tannot = match fst tannot with None -> true | Some _ -> false let map_uannot f (t, uannot) = (t, f uannot) let destruct_tannot tannot = Option.map (fun t -> (t.env, t.typ)) (fst tannot) let string_of_tannot tannot = - match destruct_tannot tannot with - | Some (_, typ) -> - "Some (_, " ^ string_of_typ typ ^ ")" - | None -> "None" + match destruct_tannot tannot with Some (_, typ) -> "Some (_, " ^ string_of_typ typ ^ ")" | None -> "None" -let replace_typ typ = function - | (Some t, u) -> (Some { t with typ = typ }, u) - | (None, u) -> (None, u) +let replace_typ typ = function Some t, u -> (Some { t with typ }, u) | None, u -> (None, u) -let replace_env env = function - | (Some t, u) -> (Some { t with env = env }, u) - | (None, u) -> (None, u) +let replace_env env = function Some t, u -> (Some { t with env }, u) | None, u -> (None, u) (* Helpers for implicit arguments in infer_funapp' *) let is_not_implicit (Typ_aux (aux, _)) = @@ -2677,8 +2759,9 @@ let implicit_to_int (Typ_aux (aux, l)) = let rec get_implicits typs = match typs with - | Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var impl, _)), _)]), _) :: typs when string_of_id id = "implicit" -> - impl :: get_implicits typs + | Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var impl, _)), _)]), _) :: typs + when string_of_id id = "implicit" -> + impl :: get_implicits typs | _ :: typs -> get_implicits typs | [] -> [] @@ -2692,103 +2775,87 @@ let infer_lit env (L_aux (lit_aux, l)) = | L_false -> atom_bool_typ nc_false | L_string _ -> string_typ | L_real _ -> real_typ - | L_bin str -> - begin - match Env.get_default_order env with - | Ord_aux (Ord_inc, _) | Ord_aux (Ord_dec, _) -> - bits_typ env (nint (String.length str)) - | Ord_aux (Ord_var _, _) -> typ_error env l default_order_error_string - end - | L_hex str -> - begin - match Env.get_default_order env with - | Ord_aux (Ord_inc, _) | Ord_aux (Ord_dec, _) -> - bits_typ env (nint (String.length str * 4)) - | Ord_aux (Ord_var _, _) -> typ_error env l default_order_error_string - end + | L_bin str -> begin + match Env.get_default_order env with + | Ord_aux (Ord_inc, _) | Ord_aux (Ord_dec, _) -> bits_typ env (nint (String.length str)) + | Ord_aux (Ord_var _, _) -> typ_error env l default_order_error_string + end + | L_hex str -> begin + match Env.get_default_order env with + | Ord_aux (Ord_inc, _) | Ord_aux (Ord_dec, _) -> bits_typ env (nint (String.length str * 4)) + | Ord_aux (Ord_var _, _) -> typ_error env l default_order_error_string + end | L_undef -> typ_error env l "Cannot infer the type of undefined" let instantiate_simple_equations = - let rec find_eqs kid (NC_aux (nc,_)) = + let rec find_eqs kid (NC_aux (nc, _)) = match nc with - | NC_equal (Nexp_aux (Nexp_var kid',_), nexp) - when Kid.compare kid kid' == 0 && - not (KidSet.mem kid (nexp_frees nexp)) -> - [arg_nexp nexp] - | NC_and (nexp1, nexp2) -> - find_eqs kid nexp1 @ find_eqs kid nexp2 + | NC_equal (Nexp_aux (Nexp_var kid', _), nexp) + when Kid.compare kid kid' == 0 && not (KidSet.mem kid (nexp_frees nexp)) -> + [arg_nexp nexp] + | NC_and (nexp1, nexp2) -> find_eqs kid nexp1 @ find_eqs kid nexp2 | _ -> [] in - let find_eqs_quant kid (QI_aux (qi,_)) = - match qi with - | QI_id _ -> [] - | QI_constraint nc -> find_eqs kid nc - in + let find_eqs_quant kid (QI_aux (qi, _)) = match qi with QI_id _ -> [] | QI_constraint nc -> find_eqs kid nc in let rec inst_from_eq = function | [] -> KBindings.empty - | (QI_aux (QI_id kinded_kid, _)) :: quants -> - let kid = kopt_kid kinded_kid in - let insts_tl = inst_from_eq quants in - begin - match List.concat (List.map (find_eqs_quant kid) quants) with - | [] -> insts_tl - | h::_ -> KBindings.add kid h (KBindings.map (typ_arg_subst kid h) insts_tl) - end - | quant :: quants -> - inst_from_eq quants - in inst_from_eq - -type destructed_vector = - | Destruct_vector of nexp * order * typ - | Destruct_bitvector of nexp * order + | QI_aux (QI_id kinded_kid, _) :: quants -> + let kid = kopt_kid kinded_kid in + let insts_tl = inst_from_eq quants in + begin + match List.concat (List.map (find_eqs_quant kid) quants) with + | [] -> insts_tl + | h :: _ -> KBindings.add kid h (KBindings.map (typ_arg_subst kid h) insts_tl) + end + | quant :: quants -> inst_from_eq quants + in + inst_from_eq + +type destructed_vector = Destruct_vector of nexp * order * typ | Destruct_bitvector of nexp * order let destruct_any_vector_typ l env typ = let destruct_any_vector_typ' l = function - | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); - A_aux (A_order o, _)] - ), _) when string_of_id id = "bitvector" -> Destruct_bitvector (n1, o) - | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); - A_aux (A_order o, _); - A_aux (A_typ vtyp, _)] - ), _) when string_of_id id = "vector" -> Destruct_vector (n1, o, vtyp) + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _)]), _) when string_of_id id = "bitvector" -> + Destruct_bitvector (n1, o) + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _); A_aux (A_typ vtyp, _)]), _) + when string_of_id id = "vector" -> + Destruct_vector (n1, o, vtyp) | typ -> typ_error env l ("Expected vector or bitvector type, got " ^ string_of_typ typ) in destruct_any_vector_typ' l (Env.expand_synonyms env typ) let destruct_vector_typ l env typ = let destruct_vector_typ' l = function - | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); - A_aux (A_order o, _); - A_aux (A_typ vtyp, _)] - ), _) when string_of_id id = "vector" -> n1, o, vtyp + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _); A_aux (A_typ vtyp, _)]), _) + when string_of_id id = "vector" -> + (n1, o, vtyp) | typ -> typ_error env l ("Expected vector type, got " ^ string_of_typ typ) in destruct_vector_typ' l (Env.expand_synonyms env typ) let destruct_bitvector_typ l env typ = let destruct_bitvector_typ' l = function - | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); - A_aux (A_order o, _)] - ), _) when string_of_id id = "bitvector" -> n1, o + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _)]), _) when string_of_id id = "bitvector" -> + (n1, o) | typ -> typ_error env l ("Expected bitvector type, got " ^ string_of_typ typ) in destruct_bitvector_typ' l (Env.expand_synonyms env typ) -let env_of_annot (l, tannot) = match fst tannot with - | Some t -> t.env - | None -> raise (Reporting.err_unreachable l __POS__ "no type annotation") +let env_of_annot (l, tannot) = + match fst tannot with Some t -> t.env | None -> raise (Reporting.err_unreachable l __POS__ "no type annotation") -let env_of_tannot tannot = match fst tannot with +let env_of_tannot tannot = + match fst tannot with | Some t -> t.env | None -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "no type annotation") -let typ_of_tannot tannot = match fst tannot with +let typ_of_tannot tannot = + match fst tannot with | Some t -> t.typ | None -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "no type annotation") -let typ_of_annot (l, tannot) = match fst tannot with - | Some t -> t.typ - | None -> raise (Reporting.err_unreachable l __POS__ "no type annotation") +let typ_of_annot (l, tannot) = + match fst tannot with Some t -> t.typ | None -> raise (Reporting.err_unreachable l __POS__ "no type annotation") let typ_of (E_aux (_, (l, tannot))) = typ_of_annot (l, tannot) @@ -2812,37 +2879,24 @@ let env_of_mpexp (MPat_aux (_, (l, tannot))) = env_of_annot (l, tannot) let lexp_typ_of (LE_aux (_, (l, tannot))) = typ_of_annot (l, tannot) -let expected_typ_of (l, tannot) = match fst tannot with - | Some t -> t.expected - | None -> raise (Reporting.err_unreachable l __POS__ "no type annotation") +let expected_typ_of (l, tannot) = + match fst tannot with Some t -> t.expected | None -> raise (Reporting.err_unreachable l __POS__ "no type annotation") (* Flow typing *) -type simple_numeric = - | Equal of nexp - | Constraint of (kid -> n_constraint) - | Anything +type simple_numeric = Equal of nexp | Constraint of (kid -> n_constraint) | Anything let to_simple_numeric l kids nc (Nexp_aux (aux, _) as n) = - match aux, kids with - | Nexp_var v, [v'] when Kid.compare v v' = 0 -> - Constraint (fun subst -> constraint_subst v (arg_nexp (nvar subst)) nc) - | _, [] -> - Equal n - | _ -> - typ_error Env.empty l "Numeric type is non-simple" + match (aux, kids) with + | Nexp_var v, [v'] when Kid.compare v v' = 0 -> Constraint (fun subst -> constraint_subst v (arg_nexp (nvar subst)) nc) + | _, [] -> Equal n + | _ -> typ_error Env.empty l "Numeric type is non-simple" let union_simple_numeric ex1 ex2 = - match ex1, ex2 with - | Equal nexp1, Equal nexp2 -> - Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp1) (nc_eq (nvar kid) nexp2)) - - | Equal nexp, Constraint c -> - Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp) (c kid)) - - | Constraint c, Equal nexp -> - Constraint (fun kid -> nc_or (c kid) (nc_eq (nvar kid) nexp)) - + match (ex1, ex2) with + | Equal nexp1, Equal nexp2 -> Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp1) (nc_eq (nvar kid) nexp2)) + | Equal nexp, Constraint c -> Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp) (c kid)) + | Constraint c, Equal nexp -> Constraint (fun kid -> nc_or (c kid) (nc_eq (nvar kid) nexp)) | _, _ -> Anything let typ_of_simple_numeric = function @@ -2850,67 +2904,55 @@ let typ_of_simple_numeric = function | Equal nexp -> atom_typ nexp | Constraint c -> exist_typ Parse_ast.Unknown c (fun kid -> atom_typ (nvar kid)) -let rec big_int_of_nexp (Nexp_aux (nexp, _)) = match nexp with +let rec big_int_of_nexp (Nexp_aux (nexp, _)) = + match nexp with | Nexp_constant c -> Some c - | Nexp_times (n1, n2) -> - Util.option_binop Big_int.add (big_int_of_nexp n1) (big_int_of_nexp n2) - | Nexp_sum (n1, n2) -> - Util.option_binop Big_int.add (big_int_of_nexp n1) (big_int_of_nexp n2) - | Nexp_minus (n1, n2) -> - Util.option_binop Big_int.add (big_int_of_nexp n1) (big_int_of_nexp n2) - | Nexp_exp n -> - Option.map (fun n -> Big_int.pow_int_positive 2 (Big_int.to_int n)) (big_int_of_nexp n) + | Nexp_times (n1, n2) -> Util.option_binop Big_int.add (big_int_of_nexp n1) (big_int_of_nexp n2) + | Nexp_sum (n1, n2) -> Util.option_binop Big_int.add (big_int_of_nexp n1) (big_int_of_nexp n2) + | Nexp_minus (n1, n2) -> Util.option_binop Big_int.add (big_int_of_nexp n1) (big_int_of_nexp n2) + | Nexp_exp n -> Option.map (fun n -> Big_int.pow_int_positive 2 (Big_int.to_int n)) (big_int_of_nexp n) | _ -> None let assert_nexp env exp = destruct_atom_nexp env (typ_of exp) -let combine_constraint b f x y = match b, x, y with - | true, Some x, Some y -> Some (f x y) - | true, Some x, None -> Some x - | true, None, Some y -> Some y +let combine_constraint b f x y = + match (b, x, y) with + | true, Some x, Some y -> Some (f x y) + | true, Some x, None -> Some x + | true, None, Some y -> Some y | false, Some x, Some y -> Some (f x y) | _, _, _ -> None let rec assert_constraint env b (E_aux (exp_aux, _) as exp) = typ_debug ~level:2 (lazy ("Asserting constraint for " ^ string_of_exp exp ^ " : " ^ string_of_typ (typ_of exp))); match typ_of exp with - | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) -> - Some nc - | _ -> - match exp_aux with - | E_constraint nc -> - Some nc - | E_lit (L_aux (L_true, _)) -> Some nc_true - | E_lit (L_aux (L_false, _)) -> Some nc_false - | E_let (_,e) -> - assert_constraint env b e (* TODO: beware of fresh type vars *) - | E_app (op, [x; y]) when string_of_id op = "or_bool" -> - combine_constraint (not b) nc_or (assert_constraint env b x) (assert_constraint env b y) - | E_app (op, [x; y]) when string_of_id op = "and_bool" -> - combine_constraint b nc_and (assert_constraint env b x) (assert_constraint env b y) - | E_app (op, [x; y]) when string_of_id op = "gteq_int" -> - option_binop nc_gteq (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "lteq_int" -> - option_binop nc_lteq (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "gt_int" -> - option_binop nc_gt (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "lt_int" -> - option_binop nc_lt (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "eq_int" -> - option_binop nc_eq (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "neq_int" -> - option_binop nc_neq (assert_nexp env x) (assert_nexp env y) - | _ -> - None + | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) -> Some nc + | _ -> ( + match exp_aux with + | E_constraint nc -> Some nc + | E_lit (L_aux (L_true, _)) -> Some nc_true + | E_lit (L_aux (L_false, _)) -> Some nc_false + | E_let (_, e) -> assert_constraint env b e (* TODO: beware of fresh type vars *) + | E_app (op, [x; y]) when string_of_id op = "or_bool" -> + combine_constraint (not b) nc_or (assert_constraint env b x) (assert_constraint env b y) + | E_app (op, [x; y]) when string_of_id op = "and_bool" -> + combine_constraint b nc_and (assert_constraint env b x) (assert_constraint env b y) + | E_app (op, [x; y]) when string_of_id op = "gteq_int" -> + option_binop nc_gteq (assert_nexp env x) (assert_nexp env y) + | E_app (op, [x; y]) when string_of_id op = "lteq_int" -> + option_binop nc_lteq (assert_nexp env x) (assert_nexp env y) + | E_app (op, [x; y]) when string_of_id op = "gt_int" -> option_binop nc_gt (assert_nexp env x) (assert_nexp env y) + | E_app (op, [x; y]) when string_of_id op = "lt_int" -> option_binop nc_lt (assert_nexp env x) (assert_nexp env y) + | E_app (op, [x; y]) when string_of_id op = "eq_int" -> option_binop nc_eq (assert_nexp env x) (assert_nexp env y) + | E_app (op, [x; y]) when string_of_id op = "neq_int" -> + option_binop nc_neq (assert_nexp env x) (assert_nexp env y) + | _ -> None + ) let add_opt_constraint l reason constr env = - match constr with - | None -> env - | Some constr -> Env.add_constraint ~reason:(l, reason) constr env + match constr with None -> env | Some constr -> Env.add_constraint ~reason:(l, reason) constr env -let solve_quant env = function - | QI_aux (QI_id _, _) -> false - | QI_aux (QI_constraint nc, _) -> prove __POS__ env nc +let solve_quant env = function QI_aux (QI_id _, _) -> false | QI_aux (QI_constraint nc, _) -> prove __POS__ env nc let check_function_instantiation l id env bind1 bind2 = let direction check (typq1, typ1) (typq2, typ2) = @@ -2921,24 +2963,22 @@ let check_function_instantiation l id env bind1 bind2 = try unify l check_env (quant_kopts typq2 |> List.map kopt_kid |> KidSet.of_list) typ2 typ1 with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) in - + let quants = List.fold_left (instantiate_quants check_env) (quant_items typq2) (KBindings.bindings unifiers) in - if not (List.for_all (solve_quant check_env) quants) then ( - typ_raise env l (Err_unresolved_quants (id, quants, Env.get_locals env, Env.get_constraints env)) - ); + if not (List.for_all (solve_quant check_env) quants) then + typ_raise env l (Err_unresolved_quants (id, quants, Env.get_locals env, Env.get_constraints env)); let typ2 = subst_unifiers unifiers typ2 in - + check check_env typ1 typ2 - ) else ( - check env typ1 typ2 ) + else check env typ1 typ2 in - try direction (fun check_env typ1 typ2 -> subtyp l check_env typ1 typ2) bind1 bind2 with - | Type_error (_, l1, err1) -> - try direction (fun check_env typ1 typ2 -> subtyp l check_env typ2 typ1) bind2 bind1 with - | Type_error (err_env, l2, err2) -> - typ_raise err_env l2 (Err_inner (err2, l1, "Also tried", None, err1)) - + try direction (fun check_env typ1 typ2 -> subtyp l check_env typ1 typ2) bind1 bind2 + with Type_error (_, l1, err1) -> ( + try direction (fun check_env typ1 typ2 -> subtyp l check_env typ2 typ1) bind2 bind1 + with Type_error (err_env, l2, err2) -> typ_raise err_env l2 (Err_inner (err2, l1, "Also tried", None, err1)) + ) + (* When doing implicit type coercion, for performance reasons we want to filter out the possible casts to only those that could reasonably apply. We don't mind if we try some coercions that are @@ -2946,9 +2986,9 @@ let check_function_instantiation l id env bind1 bind2 = cast - match_typ and filter_casts implement this logic. It must be the case that if two types unify, then they match. *) let rec match_typ env typ1 typ2 = - let Typ_aux (typ1_aux, _) = Env.expand_synonyms env typ1 in - let Typ_aux (typ2_aux, _) = Env.expand_synonyms env typ2 in - match typ1_aux, typ2_aux with + let (Typ_aux (typ1_aux, _)) = Env.expand_synonyms env typ1 in + let (Typ_aux (typ2_aux, _)) = Env.expand_synonyms env typ2 in + match (typ1_aux, typ2_aux) with | Typ_exist (_, _, typ1), _ -> match_typ env typ1 typ2 | _, Typ_exist (_, _, typ2) -> match_typ env typ1 typ2 | _, Typ_var kid2 -> true @@ -2956,11 +2996,11 @@ let rec match_typ env typ1 typ2 = | Typ_id v1, Typ_id v2 when string_of_id v1 = "int" && string_of_id v2 = "nat" -> true | Typ_tuple typs1, Typ_tuple typs2 -> List.for_all2 (match_typ env) typs1 typs2 | Typ_id v, Typ_app (f, _) when string_of_id v = "nat" && string_of_id f = "atom" -> true - | Typ_id v, Typ_app (f, _) when string_of_id v = "int" && string_of_id f = "atom" -> true - | Typ_id v, Typ_app (f, _) when string_of_id v = "nat" && string_of_id f = "range" -> true - | Typ_id v, Typ_app (f, _) when string_of_id v = "int" && string_of_id f = "range" -> true - | Typ_id v, Typ_app (f, _) when string_of_id v = "bool" && string_of_id f = "atom_bool" -> true - | Typ_app (f, _), Typ_id v when string_of_id v = "bool" && string_of_id f = "atom_bool" -> true + | Typ_id v, Typ_app (f, _) when string_of_id v = "int" && string_of_id f = "atom" -> true + | Typ_id v, Typ_app (f, _) when string_of_id v = "nat" && string_of_id f = "range" -> true + | Typ_id v, Typ_app (f, _) when string_of_id v = "int" && string_of_id f = "range" -> true + | Typ_id v, Typ_app (f, _) when string_of_id v = "bool" && string_of_id f = "atom_bool" -> true + | Typ_app (f, _), Typ_id v when string_of_id v = "bool" && string_of_id f = "atom_bool" -> true | Typ_app (f1, _), Typ_app (f2, _) when string_of_id f1 = "range" && string_of_id f2 = "atom" -> true | Typ_app (f1, _), Typ_app (f2, _) when string_of_id f1 = "atom" && string_of_id f2 = "range" -> true | Typ_app (f1, _), Typ_app (f2, _) when Id.compare f1 f2 = 0 -> true @@ -2970,42 +3010,35 @@ let rec match_typ env typ1 typ2 = let rec filter_casts env from_typ to_typ casts = match casts with - | (cast :: casts) -> - begin - let (quant, cast_typ) = Env.get_val_spec cast env in - match cast_typ with - (* A cast should be a function A -> B and have only a single argument type. *) - | Typ_aux (Typ_fn (arg_typs, cast_to_typ), _) -> - begin match List.filter is_not_implicit arg_typs with + | cast :: casts -> begin + let quant, cast_typ = Env.get_val_spec cast env in + match cast_typ with + (* A cast should be a function A -> B and have only a single argument type. *) + | Typ_aux (Typ_fn (arg_typs, cast_to_typ), _) -> begin + match List.filter is_not_implicit arg_typs with | [cast_from_typ] when match_typ env from_typ cast_from_typ && match_typ env to_typ cast_to_typ -> - typ_print (lazy ("Considering cast " ^ string_of_typ cast_typ - ^ " for " ^ string_of_typ from_typ ^ " to " ^ string_of_typ to_typ)); - cast :: filter_casts env from_typ to_typ casts + typ_print + ( lazy + ("Considering cast " ^ string_of_typ cast_typ ^ " for " ^ string_of_typ from_typ ^ " to " + ^ string_of_typ to_typ + ) + ); + cast :: filter_casts env from_typ to_typ casts | _ -> filter_casts env from_typ to_typ casts - end - | _ -> filter_casts env from_typ to_typ casts - end + end + | _ -> filter_casts env from_typ to_typ casts + end | [] -> [] -type pattern_duplicate = - | Pattern_singleton of l - | Pattern_duplicate of l * l +type pattern_duplicate = Pattern_singleton of l | Pattern_duplicate of l * l -let is_enum_member id env = match Env.lookup_id id env with - | Enum _ -> true - | _ -> false +let is_enum_member id env = match Env.lookup_id id env with Enum _ -> true | _ -> false (* Check if a pattern contains duplicate bindings, and raise a type error if this is the case *) let check_pattern_duplicates env pat = - let is_duplicate _ = function - | Pattern_duplicate _ -> true - | _ -> false - in - let one_loc = function - | Pattern_singleton l -> l - | Pattern_duplicate (l, _) -> l - in + let is_duplicate _ = function Pattern_duplicate _ -> true | _ -> false in + let one_loc = function Pattern_singleton l -> l | Pattern_duplicate (l, _) -> l in let ids = ref Bindings.empty in let subrange_ids = ref Bindings.empty in let rec collect_duplicates (P_aux (aux, (l, _))) = @@ -3015,70 +3048,82 @@ let check_pattern_duplicates env pat = | duplicate -> duplicate in match aux with - | P_id id when not (is_enum_member id env) -> - ids := Bindings.update id update_id !ids - | P_vector_subrange (id, _, _) -> - subrange_ids := Bindings.add id l !subrange_ids + | P_id id when not (is_enum_member id env) -> ids := Bindings.update id update_id !ids + | P_vector_subrange (id, _, _) -> subrange_ids := Bindings.add id l !subrange_ids | P_as (p, id) -> - ids := Bindings.update id update_id !ids; - collect_duplicates p + ids := Bindings.update id update_id !ids; + collect_duplicates p | P_id _ | P_lit _ | P_wild -> () - | P_not p | P_typ (_, p) | P_var (p, _) -> - collect_duplicates p + | P_not p | P_typ (_, p) | P_var (p, _) -> collect_duplicates p | P_or (p1, p2) | P_cons (p1, p2) -> - collect_duplicates p1; collect_duplicates p2 + collect_duplicates p1; + collect_duplicates p2 | P_app (_, ps) | P_vector ps | P_vector_concat ps | P_tuple ps | P_list ps | P_string_append ps -> - List.iter collect_duplicates ps + List.iter collect_duplicates ps in collect_duplicates pat; match Bindings.choose_opt (Bindings.filter is_duplicate !ids) with | Some (id, Pattern_duplicate (l1, l2)) -> - typ_raise env l2 - (err_because (Err_other ("Duplicate binding for " ^ string_of_id id ^ " in pattern"), - l1, - Err_other ("Previous binding of " ^ string_of_id id ^ " here"))) + typ_raise env l2 + (err_because + ( Err_other ("Duplicate binding for " ^ string_of_id id ^ " in pattern"), + l1, + Err_other ("Previous binding of " ^ string_of_id id ^ " here") + ) + ) | _ -> - Bindings.iter (fun subrange_id l -> - match Bindings.find_opt subrange_id !ids with - | Some pattern_info -> - typ_raise env l - (err_because (Err_other ("Vector subrange binding " ^ string_of_id subrange_id ^ " is also bound as a regular identifier"), - one_loc pattern_info, - Err_other "Regular binding is here")) - | None -> () - ) !subrange_ids + Bindings.iter + (fun subrange_id l -> + match Bindings.find_opt subrange_id !ids with + | Some pattern_info -> + typ_raise env l + (err_because + ( Err_other + ("Vector subrange binding " ^ string_of_id subrange_id ^ " is also bound as a regular identifier"), + one_loc pattern_info, + Err_other "Regular binding is here" + ) + ) + | None -> () + ) + !subrange_ids let bitvector_typ_from_range l env n m = - let len, order = match Env.get_default_order env with + let len, order = + match Env.get_default_order env with | Ord_aux (Ord_dec, _) -> - if Big_int.greater_equal n m then - Big_int.sub (Big_int.succ n) m, dec_ord - else - typ_error env l (Printf.sprintf "First index %s must be greater than or equal to second index %s (when default Order dec)" - (Big_int.to_string n) (Big_int.to_string m)) + if Big_int.greater_equal n m then (Big_int.sub (Big_int.succ n) m, dec_ord) + else + typ_error env l + (Printf.sprintf "First index %s must be greater than or equal to second index %s (when default Order dec)" + (Big_int.to_string n) (Big_int.to_string m) + ) | Ord_aux (Ord_inc, _) -> - if Big_int.less_equal n m then - Big_int.sub (Big_int.succ m) n, inc_ord - else - typ_error env l (Printf.sprintf "First index %s must be less than or equal to second index %s (when default Order inc)" - (Big_int.to_string n) (Big_int.to_string m)) - | _ -> - typ_error env l default_order_error_string + if Big_int.less_equal n m then (Big_int.sub (Big_int.succ m) n, inc_ord) + else + typ_error env l + (Printf.sprintf "First index %s must be less than or equal to second index %s (when default Order inc)" + (Big_int.to_string n) (Big_int.to_string m) + ) + | _ -> typ_error env l default_order_error_string in bitvector_typ (nconstant len) order let bind_pattern_vector_subranges (P_aux (_, (l, _)) as pat) env = let id_ranges = pattern_vector_subranges pat in - Bindings.fold (fun id ranges env -> + Bindings.fold + (fun id ranges env -> match ranges with - | [(n, m)] -> - Env.add_local id (Immutable, bitvector_typ_from_range l env n m) env + | [(n, m)] -> Env.add_local id (Immutable, bitvector_typ_from_range l env n m) env | (_, n) :: (m, _) :: _ -> - typ_error env l (Printf.sprintf "Cannot bind %s as pattern subranges are non-contiguous. %s[%s] is not defined." - (string_of_id id) (string_of_id id) (Big_int.to_string (Big_int.succ m))) - | _ -> - Reporting.unreachable l __POS__ "Found range pattern with no range" - ) id_ranges env + typ_error env l + (Printf.sprintf "Cannot bind %s as pattern subranges are non-contiguous. %s[%s] is not defined." + (string_of_id id) (string_of_id id) + (Big_int.to_string (Big_int.succ m)) + ) + | _ -> Reporting.unreachable l __POS__ "Found range pattern with no range" + ) + id_ranges env let crule r env exp typ = incr depth; @@ -3086,27 +3131,30 @@ let crule r env exp typ = try let checked_exp = r env exp typ in Env.wf_typ env (typ_of checked_exp); - decr depth; checked_exp - with - | Type_error (env, l, err) -> decr depth; typ_raise env l err + decr depth; + checked_exp + with Type_error (env, l, err) -> + decr depth; + typ_raise env l err let irule r env exp = incr depth; try let inferred_exp = r env exp in - typ_print (lazy (Util.("Infer " |> blue |> clear) ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp))); + typ_print + (lazy (Util.("Infer " |> blue |> clear) ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp))); Env.wf_typ env (typ_of inferred_exp); decr depth; inferred_exp - with - | Type_error (env, l, err) -> decr depth; typ_raise env l err - + with Type_error (env, l, err) -> + decr depth; + typ_raise env l err (* This function adds useful assertion messages to asserts missing them *) let assert_msg = function | E_aux (E_lit (L_aux (L_string "", _)), (l, _)) -> - let open Reporting in - locate (fun _ -> l) (mk_lit_exp (L_string (short_loc_to_string l))) + let open Reporting in + locate (fun _ -> l) (mk_lit_exp (L_string (short_loc_to_string l))) | msg -> msg let strip_exp exp = map_exp_annot (fun (l, tannot) -> (l, untyped_annot tannot)) exp @@ -3128,65 +3176,67 @@ let strip_ast ast = map_ast_annot (fun (l, tannot) -> (l, untyped_annot tannot)) (* A L-expression can either be declaring new variables, or updating existing variables, but never a mix of the two *) type lexp_assignment_type = Declaration | Update -let is_update = function - | Update -> true - | Declaration -> false +let is_update = function Update -> true | Declaration -> false + +let is_declaration = function Update -> false | Declaration -> true -let is_declaration = function - | Update -> false - | Declaration -> true - let rec lexp_assignment_type env (LE_aux (aux, (l, _))) = match aux with - | LE_id v -> - begin match Env.lookup_id v env with - | Register _ | Local (Mutable, _) -> Update - | Unbound _ -> Declaration - | Local (Immutable, _) | Enum _ -> - typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) - end - | LE_typ (_, v) -> - begin match Env.lookup_id v env with - | Register _ | Local (Mutable, _) -> - Reporting.warn ("Redundant type annotation on assignment to " ^ string_of_id v) l "Type is already known"; - Update - | Unbound _ -> Declaration - | Local (Immutable, _) | Enum _ -> - typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) - end + | LE_id v -> begin + match Env.lookup_id v env with + | Register _ | Local (Mutable, _) -> Update + | Unbound _ -> Declaration + | Local (Immutable, _) | Enum _ -> + typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) + end + | LE_typ (_, v) -> begin + match Env.lookup_id v env with + | Register _ | Local (Mutable, _) -> + Reporting.warn ("Redundant type annotation on assignment to " ^ string_of_id v) l "Type is already known"; + Update + | Unbound _ -> Declaration + | Local (Immutable, _) | Enum _ -> + typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) + end | LE_deref _ | LE_app _ -> Update - | LE_field (lexp, _) -> - begin match lexp_assignment_type env lexp with - | Update -> Update - | Declaration -> - typ_error env l "Field assignment can only be done to a variable that has already been declared" - end - | LE_vector (lexp, _) | LE_vector_range (lexp, _, _) -> - begin match lexp_assignment_type env lexp with - | Update -> Update - | Declaration -> - typ_error env l "Vector assignment can only be done to a variable that has already been declared" - end + | LE_field (lexp, _) -> begin + match lexp_assignment_type env lexp with + | Update -> Update + | Declaration -> typ_error env l "Field assignment can only be done to a variable that has already been declared" + end + | LE_vector (lexp, _) | LE_vector_range (lexp, _, _) -> begin + match lexp_assignment_type env lexp with + | Update -> Update + | Declaration -> typ_error env l "Vector assignment can only be done to a variable that has already been declared" + end | LE_tuple lexps | LE_vector_concat lexps -> - let lexp_is_update lexp = lexp_assignment_type env lexp |> is_update in - let lexp_is_declaration lexp = lexp_assignment_type env lexp |> is_declaration in - begin match List.find_opt lexp_is_update lexps, List.find_opt lexp_is_declaration lexps with - | Some (LE_aux (_, (l_u, _))), Some (LE_aux (_, (l_d, _)) as lexp_d) -> - typ_raise env l_d (Err_inner (Err_other ("Assignment declaring new variable " ^ string_of_lexp lexp_d ^ " is also assigning to an existing variable"), - l_u, - "", - Some "existing variable", - Err_other "")) - | None, _ -> Declaration - | _, None -> Update - end - - + let lexp_is_update lexp = lexp_assignment_type env lexp |> is_update in + let lexp_is_declaration lexp = lexp_assignment_type env lexp |> is_declaration in + begin + match (List.find_opt lexp_is_update lexps, List.find_opt lexp_is_declaration lexps) with + | Some (LE_aux (_, (l_u, _))), Some (LE_aux (_, (l_d, _)) as lexp_d) -> + typ_raise env l_d + (Err_inner + ( Err_other + ("Assignment declaring new variable " ^ string_of_lexp lexp_d + ^ " is also assigning to an existing variable" + ), + l_u, + "", + Some "existing variable", + Err_other "" + ) + ) + | None, _ -> Declaration + | _, None -> Update + end + let fresh_var = let counter = ref 0 in - fun () -> let n = !counter in - let () = counter := n+1 in - mk_id ("v#" ^ string_of_int n) + fun () -> + let n = !counter in + let () = counter := n + 1 in + mk_id ("v#" ^ string_of_int n) let rec exp_unconditionally_returns (E_aux (aux, _)) = match aux with @@ -3195,352 +3245,390 @@ let rec exp_unconditionally_returns (E_aux (aux, _)) = | E_block exps -> exp_unconditionally_returns (List.hd (List.rev exps)) | _ -> false -let tc_assume nc (E_aux (aux, annot)) = - E_aux (E_internal_assume (nc, E_aux (aux, annot)), annot) - +let tc_assume nc (E_aux (aux, annot)) = E_aux (E_internal_assume (nc, E_aux (aux, annot)), annot) + module PC_config = struct type t = tannot let typ_of_t = typ_of_tannot let add_attribute l attr arg = map_uannot (add_attribute l attr arg) end -module PC = Pattern_completeness.Make(PC_config) +module PC = Pattern_completeness.Make (PC_config) + +let pattern_completeness_ctx env = + { + Pattern_completeness.variants = Env.get_variants env; + Pattern_completeness.enums = Env.get_enums env; + Pattern_completeness.constraints = Env.get_constraints env; + } -let pattern_completeness_ctx env = { - Pattern_completeness.variants = Env.get_variants env; - Pattern_completeness.enums = Env.get_enums env; - Pattern_completeness.constraints = Env.get_constraints env; -} - let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_aux (typ_aux, _) as typ) : tannot exp = - let annot_exp exp typ' = E_aux (exp, (l, mk_expected_tannot ~uannot:uannot env typ' (Some typ))) in + let annot_exp exp typ' = E_aux (exp, (l, mk_expected_tannot ~uannot env typ' (Some typ))) in let update_uannot f (E_aux (aux, (l, (tannot, uannot)))) = E_aux (aux, (l, (tannot, f uannot))) in match (exp_aux, typ_aux) with - | E_block exps, _ -> - annot_exp (E_block (check_block l env exps (Some typ))) typ + | E_block exps, _ -> annot_exp (E_block (check_block l env exps (Some typ))) typ | E_match (exp, cases), _ -> - let inferred_exp = irule infer_exp env exp in - let inferred_typ = typ_of inferred_exp in - let checked_cases = List.map (fun case -> check_case env inferred_typ case typ) cases in - let checked_cases, attr_update = - if Option.is_some (get_attribute "complete" uannot) || Option.is_some (get_attribute "incomplete" uannot) then ( - checked_cases, (fun attrs -> attrs) - ) else ( - let ctx = pattern_completeness_ctx env in - match PC.is_complete_wildcarded l ctx checked_cases inferred_typ with - | Some wildcarded -> wildcarded, add_attribute (gen_loc l) "complete" "" - | None -> checked_cases, add_attribute (gen_loc l) "incomplete" "" - ) in - annot_exp (E_match (inferred_exp, checked_cases)) typ |> update_uannot attr_update + let inferred_exp = irule infer_exp env exp in + let inferred_typ = typ_of inferred_exp in + let checked_cases = List.map (fun case -> check_case env inferred_typ case typ) cases in + let checked_cases, attr_update = + if Option.is_some (get_attribute "complete" uannot) || Option.is_some (get_attribute "incomplete" uannot) then + (checked_cases, fun attrs -> attrs) + else ( + let ctx = pattern_completeness_ctx env in + match PC.is_complete_wildcarded l ctx checked_cases inferred_typ with + | Some wildcarded -> (wildcarded, add_attribute (gen_loc l) "complete" "") + | None -> (checked_cases, add_attribute (gen_loc l) "incomplete" "") + ) + in + annot_exp (E_match (inferred_exp, checked_cases)) typ |> update_uannot attr_update | E_try (exp, cases), _ -> - let checked_exp = crule check_exp env exp typ in - annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ - | E_cons (x, xs), _ -> - begin - match is_list (Env.expand_synonyms env typ) with - | Some elem_typ -> + let checked_exp = crule check_exp env exp typ in + annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ + | E_cons (x, xs), _ -> begin + match is_list (Env.expand_synonyms env typ) with + | Some elem_typ -> let checked_xs = crule check_exp env xs typ in let checked_x = crule check_exp env x elem_typ in annot_exp (E_cons (checked_x, checked_xs)) typ - | None -> typ_error env l ("Cons " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) - end - | E_list xs, _ -> - begin - match is_list (Env.expand_synonyms env typ) with - | Some elem_typ -> + | None -> typ_error env l ("Cons " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) + end + | E_list xs, _ -> begin + match is_list (Env.expand_synonyms env typ) with + | Some elem_typ -> let checked_xs = List.map (fun x -> crule check_exp env x elem_typ) xs in annot_exp (E_list checked_xs) typ - | None -> typ_error env l ("List " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) - end + | None -> typ_error env l ("List " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) + end | E_struct_update (exp, fexps), _ -> - let checked_exp = crule check_exp env exp typ in - let rectyp_id = match Env.expand_synonyms env typ with - | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> - rectyp_id - | _ -> typ_error env l ("The type " ^ string_of_typ typ ^ " is not a record") - in - let check_fexp (FE_aux (FE_fexp (field, exp), (l, _))) = - let (typq, rectyp_q, field_typ) = Env.get_accessor rectyp_id field env in - let unifiers = try unify l env (tyvars_of_typ rectyp_q) rectyp_q typ with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) in - let field_typ' = subst_unifiers unifiers field_typ in - let checked_exp = crule check_exp env exp field_typ' in - FE_aux (FE_fexp (field, checked_exp), (l, empty_tannot)) - in - annot_exp (E_struct_update (checked_exp, List.map check_fexp fexps)) typ + let checked_exp = crule check_exp env exp typ in + let rectyp_id = + match Env.expand_synonyms env typ with + | (Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _)) when Env.is_record rectyp_id env -> + rectyp_id + | _ -> typ_error env l ("The type " ^ string_of_typ typ ^ " is not a record") + in + let check_fexp (FE_aux (FE_fexp (field, exp), (l, _))) = + let typq, rectyp_q, field_typ = Env.get_accessor rectyp_id field env in + let unifiers = + try unify l env (tyvars_of_typ rectyp_q) rectyp_q typ + with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) + in + let field_typ' = subst_unifiers unifiers field_typ in + let checked_exp = crule check_exp env exp field_typ' in + FE_aux (FE_fexp (field, checked_exp), (l, empty_tannot)) + in + annot_exp (E_struct_update (checked_exp, List.map check_fexp fexps)) typ | E_struct fexps, _ -> - let rectyp_id = match Env.expand_synonyms env typ with - | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> - rectyp_id - | _ -> typ_error env l ("The type " ^ string_of_typ typ ^ " is not a record") - in - let record_fields = ref (Env.get_record rectyp_id env |> snd |> List.map snd |> IdSet.of_list) in - let check_fexp (FE_aux (FE_fexp (field, exp), (l, _))) = - record_fields := IdSet.remove field !record_fields; - let (typq, rectyp_q, field_typ) = Env.get_accessor rectyp_id field env in - let unifiers = try unify l env (tyvars_of_typ rectyp_q) rectyp_q typ with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) in - let field_typ' = subst_unifiers unifiers field_typ in - let checked_exp = crule check_exp env exp field_typ' in - FE_aux (FE_fexp (field, checked_exp), (l, empty_tannot)) - in - let fexps = List.map check_fexp fexps in - if IdSet.is_empty !record_fields then - annot_exp (E_struct fexps) typ - else - typ_error env l ("struct literal missing fields: " ^ string_of_list ", " string_of_id (IdSet.elements !record_fields)) - | E_let (LB_aux (letbind, (let_loc, _)), exp), _ -> - begin - match letbind with - | LB_val (P_aux (P_typ (ptyp, _), _) as pat, bind) -> + let rectyp_id = + match Env.expand_synonyms env typ with + | (Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _)) when Env.is_record rectyp_id env -> + rectyp_id + | _ -> typ_error env l ("The type " ^ string_of_typ typ ^ " is not a record") + in + let record_fields = ref (Env.get_record rectyp_id env |> snd |> List.map snd |> IdSet.of_list) in + let check_fexp (FE_aux (FE_fexp (field, exp), (l, _))) = + record_fields := IdSet.remove field !record_fields; + let typq, rectyp_q, field_typ = Env.get_accessor rectyp_id field env in + let unifiers = + try unify l env (tyvars_of_typ rectyp_q) rectyp_q typ + with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) + in + let field_typ' = subst_unifiers unifiers field_typ in + let checked_exp = crule check_exp env exp field_typ' in + FE_aux (FE_fexp (field, checked_exp), (l, empty_tannot)) + in + let fexps = List.map check_fexp fexps in + if IdSet.is_empty !record_fields then annot_exp (E_struct fexps) typ + else + typ_error env l + ("struct literal missing fields: " ^ string_of_list ", " string_of_id (IdSet.elements !record_fields)) + | E_let (LB_aux (letbind, (let_loc, _)), exp), _ -> begin + match letbind with + | LB_val ((P_aux (P_typ (ptyp, _), _) as pat), bind) -> Env.wf_typ env ptyp; let checked_bind = crule check_exp env bind ptyp in check_pattern_duplicates env pat; let env = bind_pattern_vector_subranges pat env in let tpat, inner_env = bind_pat_no_guard env pat ptyp in - annot_exp (E_let (LB_aux (LB_val (tpat, checked_bind), (let_loc, empty_tannot)), crule check_exp inner_env exp typ)) + annot_exp + (E_let (LB_aux (LB_val (tpat, checked_bind), (let_loc, empty_tannot)), crule check_exp inner_env exp typ)) (check_shadow_leaks l inner_env env typ) - | LB_val (pat, bind) -> + | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in check_pattern_duplicates env pat; let tpat, inner_env = bind_pat_no_guard env pat (typ_of inferred_bind) in - annot_exp (E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, empty_tannot)), crule check_exp inner_env exp typ)) + annot_exp + (E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, empty_tannot)), crule check_exp inner_env exp typ)) (check_shadow_leaks l inner_env env typ) - end - | E_app_infix (x, op, y), _ -> - check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, uannot))) typ + end + | E_app_infix (x, op, y), _ -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, uannot))) typ | E_app (f, [E_aux (E_constraint nc, _)]), _ when string_of_id f = "_prove" -> - Env.wf_constraint env nc; - if prove __POS__ env nc - then annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ - else typ_error env l ("Cannot prove " ^ string_of_n_constraint nc) + Env.wf_constraint env nc; + if prove __POS__ env nc then annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ + else typ_error env l ("Cannot prove " ^ string_of_n_constraint nc) | E_app (f, [E_aux (E_constraint nc, _)]), _ when string_of_id f = "_not_prove" -> - Env.wf_constraint env nc; - if prove __POS__ env nc - then typ_error env l ("Can prove " ^ string_of_n_constraint nc) - else annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ + Env.wf_constraint env nc; + if prove __POS__ env nc then typ_error env l ("Can prove " ^ string_of_n_constraint nc) + else annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ | E_app (f, [E_aux (E_typ (typ, exp), _)]), _ when string_of_id f = "_check" -> - Env.wf_typ env typ; - let _ = crule check_exp env exp typ in - annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ + Env.wf_typ env typ; + let _ = crule check_exp env exp typ in + annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ | E_app (f, [E_aux (E_typ (typ, exp), _)]), _ when string_of_id f = "_not_check" -> - Env.wf_typ env typ; - if (try (ignore (crule check_exp env exp typ); false) with Type_error _ -> true) - then annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ - else typ_error env l (Printf.sprintf "Expected _not_check(%s : %s) to fail" (string_of_exp exp) (string_of_typ typ)) + Env.wf_typ env typ; + if + try + ignore (crule check_exp env exp typ); + false + with Type_error _ -> true + then annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ + else + typ_error env l (Printf.sprintf "Expected _not_check(%s : %s) to fail" (string_of_exp exp) (string_of_typ typ)) (* All constructors and mappings are treated as having one argument so Ctor(x, y) is checked as Ctor((x, y)) *) | E_app (f, x :: y :: zs), _ when Env.is_union_constructor f env || Env.is_mapping f env -> - typ_print (lazy ("Checking multiple argument constructor or mapping: " ^ string_of_id f)); - crule check_exp env (mk_exp ~loc:l (E_app (f, [mk_exp ~loc:l (E_tuple (x :: y :: zs))]))) typ + typ_print (lazy ("Checking multiple argument constructor or mapping: " ^ string_of_id f)); + crule check_exp env (mk_exp ~loc:l (E_app (f, [mk_exp ~loc:l (E_tuple (x :: y :: zs))]))) typ | E_app (mapping, xs), _ when Env.is_mapping mapping env -> - let forwards_id = mk_id (string_of_id mapping ^ "_forwards") in - let backwards_id = mk_id (string_of_id mapping ^ "_backwards") in - typ_print (lazy("Trying forwards direction for mapping " ^ string_of_id mapping ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); - begin try crule check_exp env (E_aux (E_app (forwards_id, xs), (l, uannot))) typ with - | Type_error (_, _, err1) -> - typ_print (lazy ("Trying backwards direction for mapping " ^ string_of_id mapping ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); - begin try crule check_exp env (E_aux (E_app (backwards_id, xs), (l, uannot))) typ with - | Type_error (_, _, err2) -> - typ_raise env l (Err_no_overloading (mapping, [(forwards_id, err1); (backwards_id, err2)])) - end - end + let forwards_id = mk_id (string_of_id mapping ^ "_forwards") in + let backwards_id = mk_id (string_of_id mapping ^ "_backwards") in + typ_print + ( lazy + ("Trying forwards direction for mapping " ^ string_of_id mapping ^ "(" ^ string_of_list ", " string_of_exp xs + ^ ")" + ) + ); + begin + try crule check_exp env (E_aux (E_app (forwards_id, xs), (l, uannot))) typ + with Type_error (_, _, err1) -> + typ_print + ( lazy + ("Trying backwards direction for mapping " ^ string_of_id mapping ^ "(" + ^ string_of_list ", " string_of_exp xs ^ ")" + ) + ); + begin + try crule check_exp env (E_aux (E_app (backwards_id, xs), (l, uannot))) typ + with Type_error (_, _, err2) -> + typ_raise env l (Err_no_overloading (mapping, [(forwards_id, err1); (backwards_id, err2)])) + end + end | E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 -> - let rec try_overload = function - | (errs, []) -> typ_raise env l (Err_no_overloading (f, errs)) - | (errs, (f :: fs)) -> begin - typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); - try crule check_exp env (E_aux (E_app (f, xs), (l, uannot))) typ with - | Type_error (_, _, err) -> + let rec try_overload = function + | errs, [] -> typ_raise env l (Err_no_overloading (f, errs)) + | errs, f :: fs -> begin + typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); + try crule check_exp env (E_aux (E_app (f, xs), (l, uannot))) typ + with Type_error (_, _, err) -> typ_debug (lazy "Error"); try_overload (errs @ [(f, err)], fs) - end - in - try_overload ([], Env.get_overloads f env) - | E_app (f, [x; y]), _ when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> - (* We have to ensure that the type of y in (x || y) and (x && y) - is non-empty, otherwise it could force the entire type of the - expression to become empty even when unevaluted due to - short-circuiting. *) - begin match destruct_exist (typ_of (irule infer_exp env y)) with - | None | Some (_, NC_aux (NC_true, _), _) -> - let inferred_exp = infer_funapp l env f [x; y] (Some typ) in - type_coercion env inferred_exp typ - | Some _ -> - let inferred_exp = infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] (Some typ) in - type_coercion env inferred_exp typ - | exception Type_error _ -> - let inferred_exp = infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] (Some typ) in - type_coercion env inferred_exp typ - end + end + in + try_overload ([], Env.get_overloads f env) + | E_app (f, [x; y]), _ when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> begin + (* We have to ensure that the type of y in (x || y) and (x && y) + is non-empty, otherwise it could force the entire type of the + expression to become empty even when unevaluted due to + short-circuiting. *) + match destruct_exist (typ_of (irule infer_exp env y)) with + | None | Some (_, NC_aux (NC_true, _), _) -> + let inferred_exp = infer_funapp l env f [x; y] (Some typ) in + type_coercion env inferred_exp typ + | Some _ -> + let inferred_exp = infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] (Some typ) in + type_coercion env inferred_exp typ + | exception Type_error _ -> + let inferred_exp = infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] (Some typ) in + type_coercion env inferred_exp typ + end | E_app (f, xs), _ -> - let inferred_exp = infer_funapp l env f xs (Some typ) in - type_coercion env inferred_exp typ + let inferred_exp = infer_funapp l env f xs (Some typ) in + type_coercion env inferred_exp typ | E_return exp, _ -> - let checked_exp = match Env.get_ret_typ env with - | Some ret_typ -> crule check_exp env exp ret_typ - | None -> typ_error env l "Cannot use return outside a function" - in - annot_exp (E_return checked_exp) typ + let checked_exp = + match Env.get_ret_typ env with + | Some ret_typ -> crule check_exp env exp ret_typ + | None -> typ_error env l "Cannot use return outside a function" + in + annot_exp (E_return checked_exp) typ | E_tuple exps, Typ_tuple typs when List.length exps = List.length typs -> - let checked_exps = List.map2 (fun exp typ -> crule check_exp env exp typ) exps typs in - annot_exp (E_tuple checked_exps) typ + let checked_exps = List.map2 (fun exp typ -> crule check_exp env exp typ) exps typs in + annot_exp (E_tuple checked_exps) typ | E_if (cond, then_branch, else_branch), _ -> - let cond' = try irule infer_exp env cond with Type_error _ -> crule check_exp env cond bool_typ in - begin match destruct_exist (typ_of cond') with - | Some (kopts, nc, Typ_aux (Typ_app (ab, [A_aux (A_bool flow, _)]), _)) when string_of_id ab = "atom_bool" -> - let env = add_existential l kopts nc env in - let then_branch' = crule check_exp (Env.add_constraint ~reason:(l, "then branch") flow env) then_branch typ in - let else_branch' = crule check_exp (Env.add_constraint ~reason:(l, "else branch") (nc_not flow) env) else_branch typ in - annot_exp (E_if (cond', then_branch', else_branch')) typ - | _ -> - let cond' = type_coercion env cond' bool_typ in - let then_branch' = crule check_exp (add_opt_constraint l "then branch" (assert_constraint env true cond') env) then_branch typ in - let else_branch' = crule check_exp (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) else_branch typ in - annot_exp (E_if (cond', then_branch', else_branch')) typ - end + let cond' = try irule infer_exp env cond with Type_error _ -> crule check_exp env cond bool_typ in + begin + match destruct_exist (typ_of cond') with + | Some (kopts, nc, Typ_aux (Typ_app (ab, [A_aux (A_bool flow, _)]), _)) when string_of_id ab = "atom_bool" -> + let env = add_existential l kopts nc env in + let then_branch' = + crule check_exp (Env.add_constraint ~reason:(l, "then branch") flow env) then_branch typ + in + let else_branch' = + crule check_exp (Env.add_constraint ~reason:(l, "else branch") (nc_not flow) env) else_branch typ + in + annot_exp (E_if (cond', then_branch', else_branch')) typ + | _ -> + let cond' = type_coercion env cond' bool_typ in + let then_branch' = + crule check_exp + (add_opt_constraint l "then branch" (assert_constraint env true cond') env) + then_branch typ + in + let else_branch' = + crule check_exp + (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) + else_branch typ + in + annot_exp (E_if (cond', then_branch', else_branch')) typ + end | E_exit exp, _ -> - let checked_exp = crule check_exp env exp unit_typ in - annot_exp (E_exit checked_exp) typ + let checked_exp = crule check_exp env exp unit_typ in + annot_exp (E_exit checked_exp) typ | E_throw exp, _ -> - let checked_exp = crule check_exp env exp exc_typ in - annot_exp (E_throw checked_exp) typ - | E_var (lexp, bind, exp), _ -> - begin match lexp_assignment_type env lexp with - | Declaration -> - let lexp, bind, env = match bind_assignment l env lexp bind with - | E_aux (E_assign (lexp, bind), _), env -> lexp, bind, env - | _, _ -> assert false - in - let checked_exp = crule check_exp env exp typ in - annot_exp (E_var (lexp, bind, checked_exp)) typ - | Update -> - typ_error env l "var expression can only be used to declare new variables, not update them" - end + let checked_exp = crule check_exp env exp exc_typ in + annot_exp (E_throw checked_exp) typ + | E_var (lexp, bind, exp), _ -> begin + match lexp_assignment_type env lexp with + | Declaration -> + let lexp, bind, env = + match bind_assignment l env lexp bind with + | E_aux (E_assign (lexp, bind), _), env -> (lexp, bind, env) + | _, _ -> assert false + in + let checked_exp = crule check_exp env exp typ in + annot_exp (E_var (lexp, bind, checked_exp)) typ + | Update -> typ_error env l "var expression can only be used to declare new variables, not update them" + end | E_internal_return exp, _ -> - let checked_exp = crule check_exp env exp typ in - annot_exp (E_internal_return checked_exp) typ + let checked_exp = crule check_exp env exp typ in + annot_exp (E_internal_return checked_exp) typ | E_internal_plet (pat, bind, body), _ -> - let bind_exp, ptyp = match pat with - | P_aux (P_typ (ptyp, _), _) -> - Env.wf_typ env ptyp; - let checked_bind = crule check_exp env bind ptyp in - checked_bind, ptyp - | _ -> - let inferred_bind = irule infer_exp env bind in - inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat_no_guard env pat ptyp in - (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) - let env = match bind_exp with - | E_aux (E_assert (constr_exp, _), _) -> - begin + let bind_exp, ptyp = + match pat with + | P_aux (P_typ (ptyp, _), _) -> + Env.wf_typ env ptyp; + let checked_bind = crule check_exp env bind ptyp in + (checked_bind, ptyp) + | _ -> + let inferred_bind = irule infer_exp env bind in + (inferred_bind, typ_of inferred_bind) + in + let tpat, env = bind_pat_no_guard env pat ptyp in + (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) + let env = + match bind_exp with + | E_aux (E_assert (constr_exp, _), _) -> begin match assert_constraint env true constr_exp with | Some nc -> - typ_print (lazy ("Adding constraint " ^ string_of_n_constraint nc ^ " for assert")); - Env.add_constraint nc env + typ_print (lazy ("Adding constraint " ^ string_of_n_constraint nc ^ " for assert")); + Env.add_constraint nc env | None -> env end - | E_aux (E_if (cond, e_t, e_e), _) -> - begin + | E_aux (E_if (cond, e_t, e_e), _) -> begin match unaux_exp (fst (uncast_exp e_t)) with | E_throw _ | E_block [E_aux (E_throw _, _)] -> - add_opt_constraint l "if-throw" (Option.map nc_not (assert_constraint env false cond)) env + add_opt_constraint l "if-throw" (Option.map nc_not (assert_constraint env false cond)) env | _ -> env end - | _ -> env in - let checked_body = crule check_exp env body typ in - annot_exp (E_internal_plet (tpat, bind_exp, checked_body)) typ + | _ -> env + in + let checked_body = crule check_exp env body typ in + annot_exp (E_internal_plet (tpat, bind_exp, checked_body)) typ | E_vector vec, _ -> - let len, ord, vtyp = match destruct_any_vector_typ l env typ with - | Destruct_vector (len, ord, vtyp) -> len, ord, vtyp - | Destruct_bitvector (len, ord) -> len, ord, bit_typ - in - let checked_items = List.map (fun i -> crule check_exp env i vtyp) vec in - if prove __POS__ env (nc_eq (nint (List.length vec)) (nexp_simp len)) then annot_exp (E_vector checked_items) typ - else typ_error env l "List length didn't match" (* FIXME: improve error message *) + let len, ord, vtyp = + match destruct_any_vector_typ l env typ with + | Destruct_vector (len, ord, vtyp) -> (len, ord, vtyp) + | Destruct_bitvector (len, ord) -> (len, ord, bit_typ) + in + let checked_items = List.map (fun i -> crule check_exp env i vtyp) vec in + if prove __POS__ env (nc_eq (nint (List.length vec)) (nexp_simp len)) then annot_exp (E_vector checked_items) typ + else typ_error env l "List length didn't match" (* FIXME: improve error message *) | E_lit (L_aux (L_undef, _) as lit), _ -> - if (is_typ_monomorphic typ || Env.polymorphic_undefineds env) then ( - if is_typ_inhabited env (Env.expand_synonyms env typ) || Env.polymorphic_undefineds env then ( - annot_exp (E_lit lit) typ - ) else ( - typ_error env l ("Type " ^ string_of_typ typ ^ " is empty") - ) - ) else ( - typ_error env l ("Type " ^ string_of_typ typ ^ " must be monomorphic") - ) + if is_typ_monomorphic typ || Env.polymorphic_undefineds env then + if is_typ_inhabited env (Env.expand_synonyms env typ) || Env.polymorphic_undefineds env then + annot_exp (E_lit lit) typ + else typ_error env l ("Type " ^ string_of_typ typ ^ " is empty") + else typ_error env l ("Type " ^ string_of_typ typ ^ " must be monomorphic") | E_internal_assume (nc, exp), _ -> - Env.wf_constraint env nc; - let env = Env.add_constraint nc env in - let exp' = crule check_exp env exp typ in - annot_exp (E_internal_assume (nc, exp')) typ + Env.wf_constraint env nc; + let env = Env.add_constraint nc env in + let exp' = crule check_exp env exp typ in + annot_exp (E_internal_assume (nc, exp')) typ | _, _ -> - let inferred_exp = irule infer_exp env exp in - type_coercion env inferred_exp typ + let inferred_exp = irule infer_exp env exp in + type_coercion env inferred_exp typ and check_block l env exps ret_typ = - let final env exp = match ret_typ with - | Some typ -> crule check_exp env exp typ - | None -> irule infer_exp env exp - in + let final env exp = match ret_typ with Some typ -> crule check_exp env exp typ | None -> irule infer_exp env exp in let annot_exp exp typ exp_typ = E_aux (exp, (l, mk_expected_tannot env typ exp_typ)) in match Nl_flow.analyze exps with - | [] -> (match ret_typ with Some typ -> typ_equality l env typ unit_typ; [] | None -> []) + | [] -> ( + match ret_typ with + | Some typ -> + typ_equality l env typ unit_typ; + [] + | None -> [] + ) (* We need the special case for assign even if it's the last expression in the block because the block provides the scope when it's a declaration. *) - | (E_aux (E_assign (lexp, bind), (assign_l, _)) :: exps) -> - begin match lexp_assignment_type env lexp with - | Update -> - let texp, env = bind_assignment assign_l env lexp bind in - texp :: check_block l env exps ret_typ - | Declaration -> - let lexp, bind, env = match bind_assignment l env lexp bind with - | E_aux (E_assign (lexp, bind), _), env -> lexp, bind, env - | _, _ -> assert false - in - let rec last_typ = function [exp] -> typ_of exp | _ :: exps -> last_typ exps | [] -> unit_typ in - let rest = check_block l env exps ret_typ in - let typ = last_typ rest in - [annot_exp (E_var (lexp, bind, annot_exp (E_block rest) typ ret_typ)) typ ret_typ] - end + | E_aux (E_assign (lexp, bind), (assign_l, _)) :: exps -> begin + match lexp_assignment_type env lexp with + | Update -> + let texp, env = bind_assignment assign_l env lexp bind in + texp :: check_block l env exps ret_typ + | Declaration -> + let lexp, bind, env = + match bind_assignment l env lexp bind with + | E_aux (E_assign (lexp, bind), _), env -> (lexp, bind, env) + | _, _ -> assert false + in + let rec last_typ = function [exp] -> typ_of exp | _ :: exps -> last_typ exps | [] -> unit_typ in + let rest = check_block l env exps ret_typ in + let typ = last_typ rest in + [annot_exp (E_var (lexp, bind, annot_exp (E_block rest) typ ret_typ)) typ ret_typ] + end | [exp] -> [final env exp] - | (E_aux (E_app (f, [E_aux (E_constraint nc, _)]), _) :: exps) when string_of_id f = "_assume" -> - Env.wf_constraint env nc; - let env = Env.add_constraint nc env in - let annotated_exp = annot_exp (E_app (f, [annot_exp (E_constraint nc) bool_typ None])) unit_typ None in - annotated_exp :: check_block l env exps ret_typ - | ((E_aux (E_assert (constr_exp, msg), (assert_l, _))) :: exps) -> - let msg = assert_msg msg in - let constr_exp = crule check_exp env constr_exp bool_typ in - let checked_msg = crule check_exp env msg string_typ in - let env, added_constraint = match assert_constraint env true constr_exp with - | Some nc -> - typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint nc ^ " for assert")); - Env.add_constraint ~reason:(assert_l, "assertion") nc env, true - | None -> env, false - in - let texp = annot_exp (E_assert (constr_exp, checked_msg)) unit_typ (Some unit_typ) in - let checked_exps = check_block l env exps ret_typ in - (* If we can prove false, then any code after the assertion is - dead. In this inconsistent typing environment we can do some - broken things, so we eliminate this dead code here *) - if added_constraint && List.compare_length_with exps 1 >= 0 && prove __POS__ env nc_false then ( - let ret_typ = List.rev checked_exps |> List.hd |> typ_of in - texp :: [crule check_exp env (mk_exp ~loc:assert_l (E_exit (mk_lit_exp L_unit))) ret_typ] - ) else ( - texp :: checked_exps - ) - | ((E_aux (E_if (cond, (E_aux (E_throw _, _) | E_aux (E_block [E_aux (E_throw _, _)], _)), _), _) as exp) :: exps) -> - let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in - let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in - let env = add_opt_constraint l "if-throw" (Option.map nc_not (assert_constraint env false cond')) env in - texp :: check_block l env exps ret_typ - | ((E_aux (E_if (cond, then_exp, _), _) as exp) :: exps) when exp_unconditionally_returns then_exp -> - let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in - let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in - let env = add_opt_constraint l "unconditional if" (Option.map nc_not (assert_constraint env false cond')) env in - texp :: check_block l env exps ret_typ - | (exp :: exps) -> - let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in - texp :: check_block l env exps ret_typ + | E_aux (E_app (f, [E_aux (E_constraint nc, _)]), _) :: exps when string_of_id f = "_assume" -> + Env.wf_constraint env nc; + let env = Env.add_constraint nc env in + let annotated_exp = annot_exp (E_app (f, [annot_exp (E_constraint nc) bool_typ None])) unit_typ None in + annotated_exp :: check_block l env exps ret_typ + | E_aux (E_assert (constr_exp, msg), (assert_l, _)) :: exps -> + let msg = assert_msg msg in + let constr_exp = crule check_exp env constr_exp bool_typ in + let checked_msg = crule check_exp env msg string_typ in + let env, added_constraint = + match assert_constraint env true constr_exp with + | Some nc -> + typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint nc ^ " for assert")); + (Env.add_constraint ~reason:(assert_l, "assertion") nc env, true) + | None -> (env, false) + in + let texp = annot_exp (E_assert (constr_exp, checked_msg)) unit_typ (Some unit_typ) in + let checked_exps = check_block l env exps ret_typ in + (* If we can prove false, then any code after the assertion is + dead. In this inconsistent typing environment we can do some + broken things, so we eliminate this dead code here *) + if added_constraint && List.compare_length_with exps 1 >= 0 && prove __POS__ env nc_false then ( + let ret_typ = List.rev checked_exps |> List.hd |> typ_of in + texp :: [crule check_exp env (mk_exp ~loc:assert_l (E_exit (mk_lit_exp L_unit))) ret_typ] + ) + else texp :: checked_exps + | (E_aux (E_if (cond, (E_aux (E_throw _, _) | E_aux (E_block [E_aux (E_throw _, _)], _)), _), _) as exp) :: exps -> + let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in + let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in + let env = add_opt_constraint l "if-throw" (Option.map nc_not (assert_constraint env false cond')) env in + texp :: check_block l env exps ret_typ + | (E_aux (E_if (cond, then_exp, _), _) as exp) :: exps when exp_unconditionally_returns then_exp -> + let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in + let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in + let env = add_opt_constraint l "unconditional if" (Option.map nc_not (assert_constraint env false cond')) env in + texp :: check_block l env exps ret_typ + | exp :: exps -> + let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in + texp :: check_block l env exps ret_typ and check_case env pat_typ pexp typ = let pat, guard, case, ((l, _) as annot) = destruct_pexp pexp in @@ -3548,57 +3636,55 @@ and check_case env pat_typ pexp typ = let env = bind_pattern_vector_subranges pat env in match bind_pat env pat pat_typ with | tpat, env, guards -> - let guard = match guard, guards with - | None, h::t -> Some (h,t) - | Some x, l -> Some (x,l) - | None, [] -> None - in - let guard = match guard with - | Some (h,t) -> - Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) - | None -> None - in - let checked_guard, env' = match guard with - | None -> None, env - | Some guard -> - let checked_guard = check_exp env guard bool_typ in - Some checked_guard, add_opt_constraint l "guard pattern" (assert_constraint env true checked_guard) env - in - let checked_case = crule check_exp env' case typ in - construct_pexp (tpat, checked_guard, checked_case, (l, empty_tannot)) + let guard = + match (guard, guards) with None, h :: t -> Some (h, t) | Some x, l -> Some (x, l) | None, [] -> None + in + let guard = + match guard with + | Some (h, t) -> Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) + | None -> None + in + let checked_guard, env' = + match guard with + | None -> (None, env) + | Some guard -> + let checked_guard = check_exp env guard bool_typ in + (Some checked_guard, add_opt_constraint l "guard pattern" (assert_constraint env true checked_guard) env) + in + let checked_case = crule check_exp env' case typ in + construct_pexp (tpat, checked_guard, checked_case, (l, empty_tannot)) (* AA: Not sure if we still need this *) - | exception (Type_error _ as typ_exn) -> - match pat with - | P_aux (P_lit lit, _) -> - let guard' = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in - let guard = match guard with - | None -> guard' - | Some guard -> mk_exp (E_app_infix (guard, mk_id "&", guard')) - in - check_case env pat_typ (Pat_aux (Pat_when (mk_pat (P_id (mk_id "p#")), guard, case), annot)) typ - | _ -> raise typ_exn + | exception (Type_error _ as typ_exn) -> ( + match pat with + | P_aux (P_lit lit, _) -> + let guard' = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in + let guard = + match guard with None -> guard' | Some guard -> mk_exp (E_app_infix (guard, mk_id "&", guard')) + in + check_case env pat_typ (Pat_aux (Pat_when (mk_pat (P_id (mk_id "p#")), guard, case), annot)) typ + | _ -> raise typ_exn + ) and check_mpexp other_env env mpexp typ = - let mpat,guard,((l,_) as annot) = destruct_mpexp mpexp in + let mpat, guard, ((l, _) as annot) = destruct_mpexp mpexp in match bind_mpat false other_env env mpat typ with | checked_mpat, env, guards -> - let guard = match guard, guards with - | None, h::t -> Some (h,t) - | Some x, l -> Some (x,l) - | None, [] -> None - in - let guard = match guard with - | Some (h,t) -> - Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) - | None -> None - in - let checked_guard, _ = match guard with - | None -> None, env - | Some guard -> - let checked_guard = check_exp env guard bool_typ in - Some checked_guard, env - in - construct_mpexp (checked_mpat, checked_guard, (l, empty_tannot)) + let guard = + match (guard, guards) with None, h :: t -> Some (h, t) | Some x, l -> Some (x, l) | None, [] -> None + in + let guard = + match guard with + | Some (h, t) -> Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) + | None -> None + in + let checked_guard, _ = + match guard with + | None -> (None, env) + | Some guard -> + let checked_guard = check_exp env guard bool_typ in + (Some checked_guard, env) + in + construct_mpexp (checked_mpat, checked_guard, (l, empty_tannot)) (* type_coercion env exp typ takes a fully annoted (i.e. already type checked) expression exp, and attempts to cast (coerce) it to the @@ -3609,29 +3695,36 @@ and check_mpexp other_env env mpexp typ = and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ = let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, empty_tannot))) in let annot_exp exp typ' = E_aux (exp, (l, mk_expected_tannot env typ' (Some typ))) in - let switch_exp_typ exp = match exp with + let switch_exp_typ exp = + match exp with | E_aux (exp, (l, (Some tannot, uannot))) -> E_aux (exp, (l, (Some { tannot with expected = Some typ }, uannot))) | _ -> failwith "Cannot switch type for unannotated function" in let rec try_casts trigger errs = function | [] -> typ_raise env l (Err_no_casts (strip_exp annotated_exp, typ_of annotated_exp, typ, trigger, errs)) - | (cast :: casts) -> begin - typ_print (lazy ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " to " ^ string_of_typ typ)); + | cast :: casts -> begin + typ_print + ( lazy + ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " to " + ^ string_of_typ typ + ) + ); try let checked_cast = crule check_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) typ in annot_exp (E_typ (typ, checked_cast)) typ - with - | Type_error (_, _, err) -> try_casts trigger (err :: errs) casts + with Type_error (_, _, err) -> try_casts trigger (err :: errs) casts end in begin try - typ_debug (lazy ("Performing type coercion: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); - subtyp l env (typ_of annotated_exp) typ; switch_exp_typ annotated_exp + typ_debug + (lazy ("Performing type coercion: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); + subtyp l env (typ_of annotated_exp) typ; + switch_exp_typ annotated_exp with | Type_error (_, _, trigger) when Env.allow_casts env -> - let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in - try_casts trigger [] casts + let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in + try_casts trigger [] casts | Type_error (env, l, err) -> typ_raise env l err end @@ -3644,12 +3737,15 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, empty_tannot))) in let rec try_casts = function | [] -> unify_error l "No valid casts resulted in unification" - | (cast :: casts) -> begin - typ_print (lazy ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification")); + | cast :: casts -> begin + typ_print + ( lazy + ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification") + ); try let inferred_cast = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in let ityp, env = bind_existential l None (typ_of inferred_cast) env in - inferred_cast, unify l env (KidSet.diff goals (ambiguous_vars typ)) typ ityp, env + (inferred_cast, unify l env (KidSet.diff goals (ambiguous_vars typ)) typ ityp, env) with | Type_error _ -> try_casts casts | Unification_error _ -> try_casts casts @@ -3657,980 +3753,1064 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = in begin try - typ_debug (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); + typ_debug + (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); let atyp, env = bind_existential l None (typ_of annotated_exp) env in let atyp, env = bind_tuple_existentials l None atyp env in - annotated_exp, unify l env (KidSet.diff goals (ambiguous_vars typ)) typ atyp, env - with - | Unification_error (_, _) when Env.allow_casts env -> - let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in - try_casts casts + (annotated_exp, unify l env (KidSet.diff goals (ambiguous_vars typ)) typ atyp, env) + with Unification_error (_, _) when Env.allow_casts env -> + let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in + try_casts casts end and bind_pat_no_guard env (P_aux (_, (l, _)) as pat) typ = match bind_pat env pat typ with - | _, _, _::_ -> typ_error env l "Literal patterns not supported here" - | tpat, env, [] -> tpat, env + | _, _, _ :: _ -> typ_error env l "Literal patterns not supported here" + | tpat, env, [] -> (tpat, env) and bind_pat env (P_aux (pat_aux, (l, uannot)) as pat) typ = let typ, env = bind_existential l (name_pat pat) typ env in - typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_pat pat ^ " to " ^ string_of_typ typ)); - let annot_pat pat typ' = P_aux (pat, (l, mk_expected_tannot ~uannot:uannot env typ' (Some typ))) in - let switch_typ pat typ = match pat with - | P_aux (pat_aux, (l, (Some tannot, uannot))) -> P_aux (pat_aux, (l, (Some { tannot with typ = typ }, uannot))) + typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_pat pat ^ " to " ^ string_of_typ typ)); + let annot_pat pat typ' = P_aux (pat, (l, mk_expected_tannot ~uannot env typ' (Some typ))) in + let switch_typ pat typ = + match pat with + | P_aux (pat_aux, (l, (Some tannot, uannot))) -> P_aux (pat_aux, (l, (Some { tannot with typ }, uannot))) | _ -> typ_error env l "Cannot switch type for unannotated pattern" in let bind_tuple_pat (tpats, env, guards) pat typ = - let tpat, env, guards' = bind_pat env pat typ in tpat :: tpats, env, guards' @ guards + let tpat, env, guards' = bind_pat env pat typ in + (tpat :: tpats, env, guards' @ guards) in match pat_aux with - | P_id v -> - begin - (* If the identifier we're matching on is also a constructor of - a union, that's probably a mistake, so warn about it. *) - if Env.is_union_constructor v env then ( - Reporting.warn - (Printf.sprintf "Identifier %s found in pattern is also a union constructor at" (string_of_id v)) - l - (Printf.sprintf "Suggestion: Maybe you meant to match against %s() instead?" (string_of_id v)) - ); - match Env.lookup_id v env with - | Local _ | Unbound _ -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env, [] - | Register _ -> - typ_error env l ("Cannot shadow register in pattern " ^ string_of_pat pat) - | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env, [] - end + | P_id v -> begin + (* If the identifier we're matching on is also a constructor of + a union, that's probably a mistake, so warn about it. *) + if Env.is_union_constructor v env then + Reporting.warn + (Printf.sprintf "Identifier %s found in pattern is also a union constructor at" (string_of_id v)) + l + (Printf.sprintf "Suggestion: Maybe you meant to match against %s() instead?" (string_of_id v)); + match Env.lookup_id v env with + | Local _ | Unbound _ -> (annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env, []) + | Register _ -> typ_error env l ("Cannot shadow register in pattern " ^ string_of_pat pat) + | Enum enum -> + subtyp l env enum typ; + (annot_pat (P_id v) typ, env, []) + end | P_var (pat, typ_pat) -> - let env, typ = bind_typ_pat env typ_pat typ in - let typed_pat, env, guards = bind_pat env pat typ in - annot_pat (P_var (typed_pat, typ_pat)) typ, env, guards + let env, typ = bind_typ_pat env typ_pat typ in + let typed_pat, env, guards = bind_pat env pat typ in + (annot_pat (P_var (typed_pat, typ_pat)) typ, env, guards) | P_wild -> - let env = match get_attribute "int_wildcard" uannot with - | Some (_, arg) -> - (* If the patterh completeness checker replaced an numeric pattern, modify the environment as if it hadn't *) - let _, env, _ = bind_pat env (P_aux (P_lit (L_aux (L_num (Big_int.of_string arg), gen_loc l)), (l, uannot))) typ in - env - | None -> env in - annot_pat P_wild typ, env, [] + let env = + match get_attribute "int_wildcard" uannot with + | Some (_, arg) -> + (* If the patterh completeness checker replaced an numeric pattern, modify the environment as if it hadn't *) + let _, env, _ = + bind_pat env (P_aux (P_lit (L_aux (L_num (Big_int.of_string arg), gen_loc l)), (l, uannot))) typ + in + env + | None -> env + in + (annot_pat P_wild typ, env, []) | P_or (pat1, pat2) -> - let tpat1, _, guards1 = bind_pat (Env.no_bindings env) pat1 typ in - let tpat2, _, guards2 = bind_pat (Env.no_bindings env) pat2 typ in - annot_pat (P_or (tpat1, tpat2)) typ, env, guards1 @ guards2 + let tpat1, _, guards1 = bind_pat (Env.no_bindings env) pat1 typ in + let tpat2, _, guards2 = bind_pat (Env.no_bindings env) pat2 typ in + (annot_pat (P_or (tpat1, tpat2)) typ, env, guards1 @ guards2) | P_not pat -> - let tpat, _, guards = bind_pat (Env.no_bindings env) pat typ in - annot_pat (P_not(tpat)) typ, env, guards - | P_cons (hd_pat, tl_pat) -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> + let tpat, _, guards = bind_pat (Env.no_bindings env) pat typ in + (annot_pat (P_not tpat) typ, env, guards) + | P_cons (hd_pat, tl_pat) -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let hd_pat, env, hd_guards = bind_pat env hd_pat ltyp in let tl_pat, env, tl_guards = bind_pat env tl_pat typ in - annot_pat (P_cons (hd_pat, tl_pat)) typ, env, hd_guards @ tl_guards - | _ -> typ_error env l "Cannot match cons pattern against non-list type" - end - | P_string_append pats -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_id id, _) when Id.compare id (mk_id "string") = 0 -> + (annot_pat (P_cons (hd_pat, tl_pat)) typ, env, hd_guards @ tl_guards) + | _ -> typ_error env l "Cannot match cons pattern against non-list type" + end + | P_string_append pats -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_id id, _) when Id.compare id (mk_id "string") = 0 -> let rec process_pats env = function - | [] -> [], env, [] + | [] -> ([], env, []) | pat :: pats -> - let pat', env, guards = bind_pat env pat typ in - let pats', env, guards' = process_pats env pats in - pat' :: pats', env, guards @ guards' + let pat', env, guards = bind_pat env pat typ in + let pats', env, guards' = process_pats env pats in + (pat' :: pats', env, guards @ guards') in let pats, env, guards = process_pats env pats in - annot_pat (P_string_append pats) typ, env, guards - | _ -> typ_error env l "Cannot match string-append pattern against non-string type" - end - | P_list pats -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> + (annot_pat (P_string_append pats) typ, env, guards) + | _ -> typ_error env l "Cannot match string-append pattern against non-string type" + end + | P_list pats -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let rec process_pats env = function - | [] -> [], env, [] - | (pat :: pats) -> - let pat', env, guards = bind_pat env pat ltyp in - let pats', env, guards' = process_pats env pats in - pat' :: pats', env, guards @ guards' + | [] -> ([], env, []) + | pat :: pats -> + let pat', env, guards = bind_pat env pat ltyp in + let pats', env, guards' = process_pats env pats in + (pat' :: pats', env, guards @ guards') in let pats, env, guards = process_pats env pats in - annot_pat (P_list pats) typ, env, guards - | _ -> typ_error env l ("Cannot match list pattern " ^ string_of_pat pat ^ " against non-list type " ^ string_of_typ typ) - end - | P_tuple [] -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> - annot_pat (P_tuple []) typ, env, [] - | _ -> typ_error env l "Cannot match unit pattern against non-unit type" - end - | P_tuple pats -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_tuple typs, _) -> + (annot_pat (P_list pats) typ, env, guards) + | _ -> + typ_error env l + ("Cannot match list pattern " ^ string_of_pat pat ^ " against non-list type " ^ string_of_typ typ) + end + | P_tuple [] -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> (annot_pat (P_tuple []) typ, env, []) + | _ -> typ_error env l "Cannot match unit pattern against non-unit type" + end + | P_tuple pats -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_tuple typs, _) -> let tpats, env, guards = - try List.fold_left2 bind_tuple_pat ([], env, []) pats typs with - | Invalid_argument _ -> typ_error env l "Tuple pattern and tuple type have different length" + try List.fold_left2 bind_tuple_pat ([], env, []) pats typs + with Invalid_argument _ -> typ_error env l "Tuple pattern and tuple type have different length" in - annot_pat (P_tuple (List.rev tpats)) typ, env, guards - | _ -> - typ_error env l (Printf.sprintf "Cannot bind tuple pattern %s against non tuple type %s" - (string_of_pat pat) (string_of_typ typ)) - end + (annot_pat (P_tuple (List.rev tpats)) typ, env, guards) + | _ -> + typ_error env l + (Printf.sprintf "Cannot bind tuple pattern %s against non tuple type %s" (string_of_pat pat) + (string_of_typ typ) + ) + end | P_app (f, [pat]) when Env.is_union_constructor f env -> - let (typq, ctor_typ) = Env.get_union_id f env in - let quants = quant_items typq in - begin match Env.expand_synonyms (Env.add_typquant l typq env) ctor_typ with - | Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> - begin + let typq, ctor_typ = Env.get_union_id f env in + let quants = quant_items typq in + begin + match Env.expand_synonyms (Env.add_typquant l typq env) ctor_typ with + | Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> begin + try + let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in + typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ)); + let unifiers = unify l env goals ret_typ typ in + let arg_typ' = subst_unifiers unifiers arg_typ in + let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in + if not (List.for_all (solve_quant env) quants') then + typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env)) + else (); + let _ret_typ' = subst_unifiers unifiers ret_typ in + let arg_typ', env = bind_existential l None arg_typ' env in + let tpat, env, guards = bind_pat env pat arg_typ' in + (annot_pat (P_app (f, [tpat])) typ, env, guards) + with Unification_error (l, m) -> + typ_error env l ("Unification error when pattern matching against union constructor: " ^ m) + end + | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + end + | P_app (f, pats) when Env.is_union_constructor f env -> + (* Treat Ctor(x, y) as Ctor((x, y)) *) + bind_pat env (P_aux (P_app (f, [mk_pat (P_tuple pats)]), (l, uannot))) typ + | P_app (f, pats) when Env.is_mapping f env -> begin + let typq, mapping_typ = Env.get_val_spec f env in + let quants = quant_items typq in + let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with Typ_tuple typs -> typs | _ -> [typ] in + match Env.expand_synonyms env mapping_typ with + | Typ_aux (Typ_bidir (typ1, typ2), _) -> begin try - let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in - typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ)); - let unifiers = unify l env goals ret_typ typ in - let arg_typ' = subst_unifiers unifiers arg_typ in + typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ)); + + (* FIXME: There's no obvious goals here *) + let unifiers = unify l env (tyvars_of_typ typ2) typ2 typ in + let arg_typ' = subst_unifiers unifiers typ1 in let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in - if not (List.for_all (solve_quant env) quants') then - typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env)) + if match quants' with [] -> false | _ -> true then + typ_error env l + ("Quantifiers " + ^ string_of_list ", " string_of_quant_item quants' + ^ " not resolved in pattern " ^ string_of_pat pat + ) else (); - let _ret_typ' = subst_unifiers unifiers ret_typ in - let arg_typ', env = bind_existential l None arg_typ' env in - let tpat, env, guards = bind_pat env pat arg_typ' in - annot_pat (P_app (f, [tpat])) typ, env, guards - with - | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against union constructor: " ^ m) - end - | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) - end - | P_app (f, pats) when Env.is_union_constructor f env -> - (* Treat Ctor(x, y) as Ctor((x, y)) *) - bind_pat env (P_aux (P_app (f, [mk_pat (P_tuple pats)]), (l, uannot))) typ - - | P_app (f, pats) when Env.is_mapping f env -> - begin - let (typq, mapping_typ) = Env.get_val_spec f env in - let quants = quant_items typq in - let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with - | Typ_tuple typs -> typs - | _ -> [typ] - in - match Env.expand_synonyms env mapping_typ with - | Typ_aux (Typ_bidir (typ1, typ2), _) -> - begin + let _ret_typ' = subst_unifiers unifiers typ2 in + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') + with Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" + in + (annot_pat (P_app (f, List.rev tpats)) typ, env, guards) + with Unification_error (l, _) -> ( try + typ_debug (lazy "Unifying mapping forwards failed, trying backwards."); typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ)); - - (* FIXME: There's no obvious goals here *) - let unifiers = unify l env (tyvars_of_typ typ2) typ2 typ in - let arg_typ' = subst_unifiers unifiers typ1 in + let unifiers = unify l env (tyvars_of_typ typ1) typ1 typ in + let arg_typ' = subst_unifiers unifiers typ2 in let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in - if (match quants' with [] -> false | _ -> true) - then typ_error env l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) + if match quants' with [] -> false | _ -> true then + typ_error env l + ("Quantifiers " + ^ string_of_list ", " string_of_quant_item quants' + ^ " not resolved in pattern " ^ string_of_pat pat + ) else (); - - let _ret_typ' = subst_unifiers unifiers typ2 in + let _ret_typ' = subst_unifiers unifiers typ1 in let tpats, env, guards = - try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with - | Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" + try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') + with Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" in - annot_pat (P_app (f, List.rev tpats)) typ, env, guards - with - | Unification_error (l, _) -> - try - typ_debug (lazy "Unifying mapping forwards failed, trying backwards."); - typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ)); - let unifiers = unify l env (tyvars_of_typ typ1) typ1 typ in - let arg_typ' = subst_unifiers unifiers typ2 in - let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in - if (match quants' with [] -> false | _ -> true) - then typ_error env l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) - else (); - let _ret_typ' = subst_unifiers unifiers typ1 in - let tpats, env, guards = - try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with - | Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" - in - annot_pat (P_app (f, List.rev tpats)) typ, env, guards - with - | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against mapping constructor: " ^ m) - end - | _ -> typ_error env l ("Mal-formed mapping " ^ string_of_id f) - end - - | P_app (f, _) when (not (Env.is_union_constructor f env) && not (Env.is_mapping f env)) -> - typ_error env l (string_of_id f ^ " is not a union constructor or mapping in pattern " ^ string_of_pat pat) + (annot_pat (P_app (f, List.rev tpats)) typ, env, guards) + with Unification_error (l, m) -> + typ_error env l ("Unification error when pattern matching against mapping constructor: " ^ m) + ) + end + | _ -> typ_error env l ("Mal-formed mapping " ^ string_of_id f) + end + | P_app (f, _) when (not (Env.is_union_constructor f env)) && not (Env.is_mapping f env) -> + typ_error env l (string_of_id f ^ " is not a union constructor or mapping in pattern " ^ string_of_pat pat) | P_as (pat, id) -> - let (typed_pat, env, guards) = bind_pat env pat typ in - annot_pat (P_as (typed_pat, id)) (typ_of_pat typed_pat), Env.add_local id (Immutable, typ_of_pat typed_pat) env, guards + let typed_pat, env, guards = bind_pat env pat typ in + ( annot_pat (P_as (typed_pat, id)) (typ_of_pat typed_pat), + Env.add_local id (Immutable, typ_of_pat typed_pat) env, + guards + ) (* This is a special case for flow typing when we match a constant numeric literal. *) | P_lit (L_aux (L_num n, _) as lit) when is_atom typ -> - let nexp = match destruct_atom_nexp env typ with Some n -> n | None -> assert false in - annot_pat (P_lit lit) (atom_typ (nconstant n)), Env.add_constraint (nc_eq nexp (nconstant n)) env, [] + let nexp = match destruct_atom_nexp env typ with Some n -> n | None -> assert false in + (annot_pat (P_lit lit) (atom_typ (nconstant n)), Env.add_constraint (nc_eq nexp (nconstant n)) env, []) | P_lit (L_aux (L_true, _) as lit) when is_atom_bool typ -> - let nc = match destruct_atom_bool env typ with Some nc -> nc | None -> assert false in - annot_pat (P_lit lit) (atom_bool_typ nc_true), Env.add_constraint nc env, [] + let nc = match destruct_atom_bool env typ with Some nc -> nc | None -> assert false in + (annot_pat (P_lit lit) (atom_bool_typ nc_true), Env.add_constraint nc env, []) | P_lit (L_aux (L_false, _) as lit) when is_atom_bool typ -> - let nc = match destruct_atom_bool env typ with Some nc -> nc | None -> assert false in - annot_pat (P_lit lit) (atom_bool_typ nc_false), Env.add_constraint (nc_not nc) env, [] - | P_vector_concat (pat :: pats) -> - bind_vector_concat_pat l env uannot pat pats (Some typ) - | _ -> - let (inferred_pat, env, guards) = infer_pat env pat in - match subtyp l env typ (typ_of_pat inferred_pat) with - | () -> switch_typ inferred_pat (typ_of_pat inferred_pat), env, guards - | exception (Type_error _ as typ_exn) -> - match pat_aux with - | P_lit lit -> - let var = fresh_var () in - let guard = locate (fun _ -> l) (mk_exp (E_app_infix (mk_exp (E_id var), mk_id "==", mk_exp (E_lit lit)))) in - let (typed_pat, env, guards) = bind_pat env (mk_pat (P_id var)) typ in - typed_pat, env, guard::guards - | _ -> raise typ_exn + let nc = match destruct_atom_bool env typ with Some nc -> nc | None -> assert false in + (annot_pat (P_lit lit) (atom_bool_typ nc_false), Env.add_constraint (nc_not nc) env, []) + | P_vector_concat (pat :: pats) -> bind_vector_concat_pat l env uannot pat pats (Some typ) + | _ -> ( + let inferred_pat, env, guards = infer_pat env pat in + match subtyp l env typ (typ_of_pat inferred_pat) with + | () -> (switch_typ inferred_pat (typ_of_pat inferred_pat), env, guards) + | exception (Type_error _ as typ_exn) -> ( + match pat_aux with + | P_lit lit -> + let var = fresh_var () in + let guard = + locate (fun _ -> l) (mk_exp (E_app_infix (mk_exp (E_id var), mk_id "==", mk_exp (E_lit lit)))) + in + let typed_pat, env, guards = bind_pat env (mk_pat (P_id var)) typ in + (typed_pat, env, guard :: guards) + | _ -> raise typ_exn + ) + ) and infer_pat env (P_aux (pat_aux, (l, uannot)) as pat) = - let annot_pat pat typ = P_aux (pat, (l, mk_tannot ~uannot:uannot env typ)) in + let annot_pat pat typ = P_aux (pat, (l, mk_tannot ~uannot env typ)) in match pat_aux with - | P_id v -> - begin - match Env.lookup_id v env with - | Local (Immutable, _) | Unbound _ -> + | P_id v -> begin + match Env.lookup_id v env with + | Local (Immutable, _) | Unbound _ -> typ_error env l ("Cannot infer identifier in pattern " ^ string_of_pat pat ^ " - try adding a type annotation") - | Local (Mutable, _) | Register _ -> + | Local (Mutable, _) | Register _ -> typ_error env l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> annot_pat (P_id v) enum, env, [] - end - | P_app (f, _) when Env.is_union_constructor f env -> - begin - let (_, ctor_typ) = Env.get_val_spec f env in - match Env.expand_synonyms env ctor_typ with - | Typ_aux (Typ_fn (_, ret_typ), _) -> - bind_pat env pat ret_typ - | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f) - end - | P_app (f, _) when Env.is_mapping f env -> - begin - let (_, mapping_typ) = Env.get_val_spec f env in - match Env.expand_synonyms env mapping_typ with - | Typ_aux (Typ_bidir (typ1, typ2), _) -> - begin - try - bind_pat env pat typ2 - with - | Type_error _ -> - bind_pat env pat typ1 - end - | _ -> typ_error env l ("Malformed mapping type " ^ string_of_id f) - end + | Enum enum -> (annot_pat (P_id v) enum, env, []) + end + | P_app (f, _) when Env.is_union_constructor f env -> begin + let _, ctor_typ = Env.get_val_spec f env in + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn (_, ret_typ), _) -> bind_pat env pat ret_typ + | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f) + end + | P_app (f, _) when Env.is_mapping f env -> begin + let _, mapping_typ = Env.get_val_spec f env in + match Env.expand_synonyms env mapping_typ with + | Typ_aux (Typ_bidir (typ1, typ2), _) -> begin + try bind_pat env pat typ2 with Type_error _ -> bind_pat env pat typ1 + end + | _ -> typ_error env l ("Malformed mapping type " ^ string_of_id f) + end | P_typ (typ_annot, pat) -> - Env.wf_typ env typ_annot; - let (typed_pat, env, guards) = bind_pat env pat typ_annot in - annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env, guards - | P_lit lit -> - annot_pat (P_lit lit) (infer_lit env lit), env, [] + Env.wf_typ env typ_annot; + let typed_pat, env, guards = bind_pat env pat typ_annot in + (annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env, guards) + | P_lit lit -> (annot_pat (P_lit lit) (infer_lit env lit), env, []) | P_vector (pat :: pats) -> - let fold_pats (pats, env, guards) pat = - let typed_pat, env, guards' = bind_pat env pat bit_typ in - pats @ [typed_pat], env, guards' @ guards - in - let pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in - let len = nexp_simp (nint (List.length pats)) in - let etyp = typ_of_pat (List.hd pats) in - (* BVS TODO: Non-bitvector P_vector *) - List.iter (fun pat -> typ_equality l env etyp (typ_of_pat pat)) pats; - annot_pat (P_vector pats) (bits_typ env len), env, guards - | P_vector_concat (pat :: pats) -> - bind_vector_concat_pat l env uannot pat pats None + let fold_pats (pats, env, guards) pat = + let typed_pat, env, guards' = bind_pat env pat bit_typ in + (pats @ [typed_pat], env, guards' @ guards) + in + let pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in + let len = nexp_simp (nint (List.length pats)) in + let etyp = typ_of_pat (List.hd pats) in + (* BVS TODO: Non-bitvector P_vector *) + List.iter (fun pat -> typ_equality l env etyp (typ_of_pat pat)) pats; + (annot_pat (P_vector pats) (bits_typ env len), env, guards) + | P_vector_concat (pat :: pats) -> bind_vector_concat_pat l env uannot pat pats None | P_vector_subrange (id, n, m) -> - let typ = bitvector_typ_from_range l env n m in - annot_pat (P_vector_subrange (id, n, m)) typ, env, [] + let typ = bitvector_typ_from_range l env n m in + (annot_pat (P_vector_subrange (id, n, m)) typ, env, []) | P_string_append pats -> - let fold_pats (pats, env, guards) pat = - let inferred_pat, env, guards' = infer_pat env pat in - typ_equality l env (typ_of_pat inferred_pat) string_typ; - pats @ [inferred_pat], env, guards' @ guards - in - let typed_pats, env, guards = - List.fold_left fold_pats ([], env, []) pats - in - annot_pat (P_string_append typed_pats) string_typ, env, guards + let fold_pats (pats, env, guards) pat = + let inferred_pat, env, guards' = infer_pat env pat in + typ_equality l env (typ_of_pat inferred_pat) string_typ; + (pats @ [inferred_pat], env, guards' @ guards) + in + let typed_pats, env, guards = List.fold_left fold_pats ([], env, []) pats in + (annot_pat (P_string_append typed_pats) string_typ, env, guards) | P_as (pat, id) -> - let (typed_pat, env, guards) = infer_pat env pat in - annot_pat (P_as (typed_pat, id)) (typ_of_pat typed_pat), - Env.add_local id (Immutable, typ_of_pat typed_pat) env, - guards + let typed_pat, env, guards = infer_pat env pat in + ( annot_pat (P_as (typed_pat, id)) (typ_of_pat typed_pat), + Env.add_local id (Immutable, typ_of_pat typed_pat) env, + guards + ) | _ -> typ_error env l ("Couldn't infer type of pattern " ^ string_of_pat pat) and bind_vector_concat_pat l env uannot pat pats typ_opt = - let annot_vcp pats typ = P_aux (P_vector_concat pats, (l, mk_tannot ~uannot:uannot env typ)) in + let annot_vcp pats typ = P_aux (P_vector_concat pats, (l, mk_tannot ~uannot env typ)) in (* Try to infer a constant length, and the element type if non-bitvector *) let typ_opt = Option.bind typ_opt (fun typ -> match destruct_any_vector_typ l env typ with | Destruct_vector (len, order, elem_typ) -> - Option.map (fun l -> (l, order, Some elem_typ)) (solve_unique env len) - | Destruct_bitvector(len, order) -> - Option.map (fun l -> (l, order, None)) (solve_unique env len) - ) in + Option.map (fun l -> (l, order, Some elem_typ)) (solve_unique env len) + | Destruct_bitvector (len, order) -> Option.map (fun l -> (l, order, None)) (solve_unique env len) + ) + in (* Try to infer any subpatterns, skipping those we cannot infer *) let fold_pats (pats, env, guards) pat = let wrap_some (x, y, z) = (Ok x, y, z) in let inferred_pat, env, guards' = - if Option.is_none typ_opt then ( - wrap_some (infer_pat env pat) - ) else ( - try wrap_some (infer_pat env pat) with - | (Type_error _ as exn) -> (Error (pat, exn), env, []) - ) in - inferred_pat :: pats, env, guards' @ guards + if Option.is_none typ_opt then wrap_some (infer_pat env pat) + else (try wrap_some (infer_pat env pat) with Type_error _ as exn -> (Error (pat, exn), env, [])) + in + (inferred_pat :: pats, env, guards' @ guards) in let inferred_pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in let inferred_pats = List.rev inferred_pats in (* Will be none if the subpatterns are bitvectors *) - let elem_typ = match typ_opt with + let elem_typ = + match typ_opt with | Some (_, _, elem_typ) -> elem_typ - | None -> - match List.find_opt Result.is_ok inferred_pats with - | Some (Ok pat) -> - begin match destruct_any_vector_typ l env (typ_of_pat pat) with - | Destruct_vector (_, _, t) -> Some t - | Destruct_bitvector _ -> None + | None -> ( + match List.find_opt Result.is_ok inferred_pats with + | Some (Ok pat) -> begin + match destruct_any_vector_typ l env (typ_of_pat pat) with + | Destruct_vector (_, _, t) -> Some t + | Destruct_bitvector _ -> None end - | _ -> - typ_error env l "Could not infer type of subpatterns in vector concatenation pattern" in + | _ -> typ_error env l "Could not infer type of subpatterns in vector concatenation pattern" + ) + in (* We can handle a single None in inferred_pats from something like 0b00 @ _ @ 0b00, because we know the wildcard will be bits('n - 4) where 'n is the total length of the pattern. *) let before_uninferred, rest = Util.take_drop Result.is_ok inferred_pats in let before_uninferred = List.map Result.get_ok before_uninferred in - let uninferred, after_uninferred = match rest with + let uninferred, after_uninferred = + match rest with | Error (first_uninferred, exn) :: rest -> - begin match List.find_opt Result.is_error rest with - | Some (Error (second_uninferred, _)) -> - let msg = - "Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation pattern" - in - typ_raise env (pat_loc second_uninferred) - (err_because (Err_other msg, pat_loc first_uninferred, Err_other "A previous subpattern is here")) - | _ -> () - end; - begin match typ_opt with - | Some (total_len, order, _) -> Some (total_len, order, first_uninferred), List.map Result.get_ok rest - | None -> raise exn - end - | _ -> None, [] in - - let check_constant_len l n = match solve_unique env n with + begin + match List.find_opt Result.is_error rest with + | Some (Error (second_uninferred, _)) -> + let msg = + "Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation \ + pattern" + in + typ_raise env (pat_loc second_uninferred) + (err_because (Err_other msg, pat_loc first_uninferred, Err_other "A previous subpattern is here")) + | _ -> () + end; + begin + match typ_opt with + | Some (total_len, order, _) -> (Some (total_len, order, first_uninferred), List.map Result.get_ok rest) + | None -> raise exn + end + | _ -> (None, []) + in + + let check_constant_len l n = + match solve_unique env n with | Some c -> nconstant c - | None -> typ_error env l "Could not infer constant length for vector concatenation subpattern" in + | None -> typ_error env l "Could not infer constant length for vector concatenation subpattern" + in (* Now we have two similar cases for ordinary vectors and bitvectors *) match elem_typ with | Some elem_typ -> - let fold_len len pat = - let (len', _, elem_typ') = destruct_vector_typ l env (typ_of_pat pat) in - let len' = check_constant_len (pat_loc pat) len' in - typ_equality l env elem_typ elem_typ'; - nsum len len' - in - let before_len = List.fold_left fold_len (nint 0) before_uninferred in - let after_len = List.fold_left fold_len (nint 0) after_uninferred in - let inferred_len = nexp_simp (nsum before_len after_len) in - begin match uninferred with - | Some (total_len, order, uninferred_pat) -> - let total_len = nconstant total_len in - let uninferred_len = nexp_simp (nminus total_len inferred_len) in - let checked_pat, env, guards' = bind_pat env uninferred_pat (vector_typ uninferred_len order elem_typ) in - annot_vcp (before_uninferred @ [checked_pat] @ after_uninferred) (vector_typ total_len order elem_typ), - env, - guards' @ guards - | None -> - annot_vcp before_uninferred (dvector_typ env inferred_len elem_typ), env, guards - end - + let fold_len len pat = + let len', _, elem_typ' = destruct_vector_typ l env (typ_of_pat pat) in + let len' = check_constant_len (pat_loc pat) len' in + typ_equality l env elem_typ elem_typ'; + nsum len len' + in + let before_len = List.fold_left fold_len (nint 0) before_uninferred in + let after_len = List.fold_left fold_len (nint 0) after_uninferred in + let inferred_len = nexp_simp (nsum before_len after_len) in + begin + match uninferred with + | Some (total_len, order, uninferred_pat) -> + let total_len = nconstant total_len in + let uninferred_len = nexp_simp (nminus total_len inferred_len) in + let checked_pat, env, guards' = bind_pat env uninferred_pat (vector_typ uninferred_len order elem_typ) in + ( annot_vcp (before_uninferred @ [checked_pat] @ after_uninferred) (vector_typ total_len order elem_typ), + env, + guards' @ guards + ) + | None -> (annot_vcp before_uninferred (dvector_typ env inferred_len elem_typ), env, guards) + end | None -> - let fold_len len pat = - let (len', _) = destruct_bitvector_typ l env (typ_of_pat pat) in - let len' = check_constant_len (pat_loc pat) len' in - nsum len len' - in - let before_len = List.fold_left fold_len (nint 0) before_uninferred in - let after_len = List.fold_left fold_len (nint 0) after_uninferred in - let inferred_len = nexp_simp (nsum before_len after_len) in - begin match uninferred with - | Some (total_len, order, uninferred_pat) -> - let total_len = nconstant total_len in - let uninferred_len = nexp_simp (nminus total_len inferred_len) in - let checked_pat, env, guards' = bind_pat env uninferred_pat (bitvector_typ uninferred_len order) in - annot_vcp (before_uninferred @ [checked_pat] @ after_uninferred) (bitvector_typ total_len order), - env, - guards' @ guards - | None -> - annot_vcp before_uninferred (bits_typ env inferred_len), env, guards - end + let fold_len len pat = + let len', _ = destruct_bitvector_typ l env (typ_of_pat pat) in + let len' = check_constant_len (pat_loc pat) len' in + nsum len len' + in + let before_len = List.fold_left fold_len (nint 0) before_uninferred in + let after_len = List.fold_left fold_len (nint 0) after_uninferred in + let inferred_len = nexp_simp (nsum before_len after_len) in + begin + match uninferred with + | Some (total_len, order, uninferred_pat) -> + let total_len = nconstant total_len in + let uninferred_len = nexp_simp (nminus total_len inferred_len) in + let checked_pat, env, guards' = bind_pat env uninferred_pat (bitvector_typ uninferred_len order) in + ( annot_vcp (before_uninferred @ [checked_pat] @ after_uninferred) (bitvector_typ total_len order), + env, + guards' @ guards + ) + | None -> (annot_vcp before_uninferred (bits_typ env inferred_len), env, guards) + end and bind_typ_pat env (TP_aux (typ_pat_aux, l) as typ_pat) (Typ_aux (typ_aux, _) as typ) = - typ_print (lazy (Util.("Binding type pattern " |> yellow |> clear) ^ string_of_typ_pat typ_pat ^ " to " ^ string_of_typ typ)); - match typ_pat_aux, typ_aux with - | TP_wild, _ -> env, typ - | TP_var kid, _ -> - begin - match typ_nexps typ, typ_constraints typ with - | [nexp], [] -> + typ_print + (lazy (Util.("Binding type pattern " |> yellow |> clear) ^ string_of_typ_pat typ_pat ^ " to " ^ string_of_typ typ)); + match (typ_pat_aux, typ_aux) with + | TP_wild, _ -> (env, typ) + | TP_var kid, _ -> begin + match (typ_nexps typ, typ_constraints typ) with + | [nexp], [] -> let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in - Env.add_constraint ~reason:(l, "type pattern") (nc_eq (nvar kid) nexp) env, replace_nexp_typ nexp (Nexp_aux (Nexp_var kid, l)) typ - | [], [nc] -> + ( Env.add_constraint ~reason:(l, "type pattern") (nc_eq (nvar kid) nexp) env, + replace_nexp_typ nexp (Nexp_aux (Nexp_var kid, l)) typ + ) + | [], [nc] -> let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_bool kid) env in let nc = match shadow with Some s_v -> constraint_subst kid (arg_bool (nc_var s_v)) nc | None -> nc in - Env.add_constraint ~reason:(l, "type pattern") (nc_and (nc_or (nc_not nc) (nc_var kid)) (nc_or nc (nc_not (nc_var kid)))) env, - replace_nc_typ nc (NC_aux (NC_var kid, l)) typ - | [], [] -> + ( Env.add_constraint ~reason:(l, "type pattern") + (nc_and (nc_or (nc_not nc) (nc_var kid)) (nc_or nc (nc_not (nc_var kid)))) + env, + replace_nc_typ nc (NC_aux (NC_var kid, l)) typ + ) + | [], [] -> typ_error env l ("No numeric expressions in " ^ string_of_typ typ ^ " to bind " ^ string_of_kid kid ^ " to") - | _, _ -> - typ_error env l ("Type " ^ string_of_typ typ ^ " has multiple numeric or boolean expressions. Cannot bind " ^ string_of_kid kid) - end + | _, _ -> + typ_error env l + ("Type " ^ string_of_typ typ ^ " has multiple numeric or boolean expressions. Cannot bind " + ^ string_of_kid kid + ) + end | TP_app (f1, tpats), Typ_app (f2, typs) when Id.compare f1 f2 = 0 -> - let env, args = List.fold_right2 (fun tp arg (env, args) -> let env, arg = bind_typ_pat_arg env tp arg in env, arg::args) tpats typs (env, []) in - env, Typ_aux (Typ_app (f2, args), l) + let env, args = + List.fold_right2 + (fun tp arg (env, args) -> + let env, arg = bind_typ_pat_arg env tp arg in + (env, arg :: args) + ) + tpats typs (env, []) + in + (env, Typ_aux (Typ_app (f2, args), l)) | _, _ -> typ_error env l ("Couldn't bind type " ^ string_of_typ typ ^ " with " ^ string_of_typ_pat typ_pat) + and bind_typ_pat_arg env (TP_aux (typ_pat_aux, l) as typ_pat) (A_aux (typ_arg_aux, l_arg) as typ_arg) = - match typ_pat_aux, typ_arg_aux with - | TP_wild, _ -> env, typ_arg + match (typ_pat_aux, typ_arg_aux) with + | TP_wild, _ -> (env, typ_arg) | TP_var kid, A_nexp nexp -> - let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in - let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in - Env.add_constraint ~reason:(l, "type pattern") (nc_eq (nvar kid) nexp) env, arg_nexp ~loc:l (nvar kid) - | _, A_typ typ -> let env, typ' = bind_typ_pat env typ_pat typ in env, A_aux (A_typ typ', l_arg) + let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in + let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in + (Env.add_constraint ~reason:(l, "type pattern") (nc_eq (nvar kid) nexp) env, arg_nexp ~loc:l (nvar kid)) + | _, A_typ typ -> + let env, typ' = bind_typ_pat env typ_pat typ in + (env, A_aux (A_typ typ', l_arg)) | _, A_order _ -> typ_error env l "Cannot bind type pattern against order" - | _, _ -> typ_error env l ("Couldn't bind type argument " ^ string_of_typ_arg typ_arg ^ " with " ^ string_of_typ_pat typ_pat) + | _, _ -> + typ_error env l ("Couldn't bind type argument " ^ string_of_typ_arg typ_arg ^ " with " ^ string_of_typ_pat typ_pat) and bind_assignment assign_l env (LE_aux (lexp_aux, (lexp_l, uannot)) as lexp) exp = - let annot_assign lexp exp = E_aux (E_assign (lexp, exp), (assign_l, mk_tannot env (mk_typ (Typ_id (mk_id "unit"))))) in - let has_typ v env = - match Env.lookup_id v env with - | Local (Mutable, _) | Register _ -> true - | _ -> false + let annot_assign lexp exp = + E_aux (E_assign (lexp, exp), (assign_l, mk_tannot env (mk_typ (Typ_id (mk_id "unit"))))) in + let has_typ v env = match Env.lookup_id v env with Local (Mutable, _) | Register _ -> true | _ -> false in match lexp_aux with - | LE_app (f, xs) -> - check_exp env (E_aux (E_app (f, xs @ [exp]), (lexp_l, uannot))) unit_typ, env + | LE_app (f, xs) -> (check_exp env (E_aux (E_app (f, xs @ [exp]), (lexp_l, uannot))) unit_typ, env) | LE_typ (typ_annot, _) -> - let checked_exp = crule check_exp env exp typ_annot in - let tlexp, env' = bind_lexp env lexp (typ_of checked_exp) in - annot_assign tlexp checked_exp, env' - | LE_id v when has_typ v env -> - begin match Env.lookup_id v env with - | Local (Mutable, vtyp) | Register vtyp -> - let checked_exp = crule check_exp env exp vtyp in - let tlexp, env' = bind_lexp env lexp (typ_of checked_exp) in - annot_assign tlexp checked_exp, env' - | _ -> assert false - end - | _ -> - (* Here we have two options, we can infer the type from the - expression, or we can infer the type from the - l-expression. Both are useful in different cases, so try - both. *) - try - let inferred_exp = irule infer_exp env exp in - let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in - annot_assign tlexp inferred_exp, env' - with - | Type_error (_, l, err) -> + let checked_exp = crule check_exp env exp typ_annot in + let tlexp, env' = bind_lexp env lexp (typ_of checked_exp) in + (annot_assign tlexp checked_exp, env') + | LE_id v when has_typ v env -> begin + match Env.lookup_id v env with + | Local (Mutable, vtyp) | Register vtyp -> + let checked_exp = crule check_exp env exp vtyp in + let tlexp, env' = bind_lexp env lexp (typ_of checked_exp) in + (annot_assign tlexp checked_exp, env') + | _ -> assert false + end + | _ -> ( + (* Here we have two options, we can infer the type from the + expression, or we can infer the type from the + l-expression. Both are useful in different cases, so try + both. *) + try + let inferred_exp = irule infer_exp env exp in + let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in + (annot_assign tlexp inferred_exp, env') + with Type_error (_, l, err) -> ( try let inferred_lexp = infer_lexp env lexp in let checked_exp = crule check_exp env exp (lexp_typ_of inferred_lexp) in - annot_assign inferred_lexp checked_exp, env + (annot_assign inferred_lexp checked_exp, env) with Type_error (env, l', err') -> typ_raise env l' (err_because (err', l, err)) + ) + ) and bind_lexp env (LE_aux (lexp_aux, (l, _)) as lexp) typ = - typ_print (lazy ("Binding mutable " ^ string_of_lexp lexp ^ " to " ^ string_of_typ typ)); + typ_print (lazy ("Binding mutable " ^ string_of_lexp lexp ^ " to " ^ string_of_typ typ)); let annot_lexp lexp typ = LE_aux (lexp, (l, mk_tannot env typ)) in match lexp_aux with - | LE_typ (typ_annot, v) -> - begin match Env.lookup_id v env with - | Local (Immutable, _) | Enum _ -> + | LE_typ (typ_annot, v) -> begin + match Env.lookup_id v env with + | Local (Immutable, _) | Enum _ -> typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) - | Local (Mutable, vtyp) -> + | Local (Mutable, vtyp) -> subtyp l env typ typ_annot; subtyp l env typ_annot vtyp; - annot_lexp (LE_typ (typ_annot, v)) typ, Env.add_local v (Mutable, typ_annot) env - | Register vtyp -> + (annot_lexp (LE_typ (typ_annot, v)) typ, Env.add_local v (Mutable, typ_annot) env) + | Register vtyp -> subtyp l env typ typ_annot; subtyp l env typ_annot vtyp; - annot_lexp (LE_typ (typ_annot, v)) typ, env - | Unbound _ -> + (annot_lexp (LE_typ (typ_annot, v)) typ, env) + | Unbound _ -> subtyp l env typ typ_annot; - annot_lexp (LE_typ (typ_annot, v)) typ, Env.add_local v (Mutable, typ_annot) env - end - | LE_id v -> - begin match Env.lookup_id v env with - | Local (Immutable, _) | Enum _ -> - typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) - | Local (Mutable, vtyp) -> subtyp l env typ vtyp; annot_lexp (LE_id v) typ, env - | Register vtyp -> subtyp l env typ vtyp; annot_lexp (LE_id v) typ, env - | Unbound _ -> annot_lexp (LE_id v) typ, Env.add_local v (Mutable, typ) env - end - | LE_tuple lexps -> - begin - let typ = Env.expand_synonyms env typ in - let (Typ_aux (typ_aux, _)) = typ in - match typ_aux with - | Typ_tuple typs -> + (annot_lexp (LE_typ (typ_annot, v)) typ, Env.add_local v (Mutable, typ_annot) env) + end + | LE_id v -> begin + match Env.lookup_id v env with + | Local (Immutable, _) | Enum _ -> + typ_error env l ("Cannot modify immutable let-bound constant or enumeration constructor " ^ string_of_id v) + | Local (Mutable, vtyp) -> + subtyp l env typ vtyp; + (annot_lexp (LE_id v) typ, env) + | Register vtyp -> + subtyp l env typ vtyp; + (annot_lexp (LE_id v) typ, env) + | Unbound _ -> (annot_lexp (LE_id v) typ, Env.add_local v (Mutable, typ) env) + end + | LE_tuple lexps -> begin + let typ = Env.expand_synonyms env typ in + let (Typ_aux (typ_aux, _)) = typ in + match typ_aux with + | Typ_tuple typs -> let bind_tuple_lexp lexp typ (tlexps, env) = - let tlexp, env = bind_lexp env lexp typ in tlexp :: tlexps, env + let tlexp, env = bind_lexp env lexp typ in + (tlexp :: tlexps, env) in let tlexps, env = - try List.fold_right2 bind_tuple_lexp lexps typs ([], env) with - | Invalid_argument _ -> typ_error env l "Tuple l-expression and tuple type have different length" + try List.fold_right2 bind_tuple_lexp lexps typs ([], env) + with Invalid_argument _ -> typ_error env l "Tuple l-expression and tuple type have different length" in - annot_lexp (LE_tuple tlexps) typ, env - | _ -> typ_error env l ("Cannot bind tuple l-expression against non tuple type " ^ string_of_typ typ) - end + (annot_lexp (LE_tuple tlexps) typ, env) + | _ -> typ_error env l ("Cannot bind tuple l-expression against non tuple type " ^ string_of_typ typ) + end | _ -> - let inferred_lexp = infer_lexp env lexp in - subtyp l env typ (lexp_typ_of inferred_lexp); - inferred_lexp, env + let inferred_lexp = infer_lexp env lexp in + subtyp l env typ (lexp_typ_of inferred_lexp); + (inferred_lexp, env) and infer_lexp env (LE_aux (lexp_aux, (l, uannot)) as lexp) = - let annot_lexp lexp typ = LE_aux (lexp, (l, mk_tannot ~uannot:uannot env typ)) in + let annot_lexp lexp typ = LE_aux (lexp, (l, mk_tannot ~uannot env typ)) in match lexp_aux with - | LE_id v -> - begin match Env.lookup_id v env with - | Local (Mutable, typ) -> annot_lexp (LE_id v) typ - | Register typ -> annot_lexp (LE_id v) typ - | Local (Immutable, _) | Enum _ -> - typ_error env l ("Cannot modify let-bound constant or enumeration constructor " ^ string_of_id v) - | Unbound _ -> - typ_error env l ("Cannot create a new identifier in this l-expression " ^ string_of_lexp lexp) - end - | LE_vector_range (v_lexp, exp1, exp2) -> - begin - let inferred_v_lexp = infer_lexp env v_lexp in - let (Typ_aux (v_typ_aux, _)) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in - match v_typ_aux with - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) when Id.compare id (mk_id "bitvector") = 0 -> + | LE_id v -> begin + match Env.lookup_id v env with + | Local (Mutable, typ) -> annot_lexp (LE_id v) typ + | Register typ -> annot_lexp (LE_id v) typ + | Local (Immutable, _) | Enum _ -> + typ_error env l ("Cannot modify let-bound constant or enumeration constructor " ^ string_of_id v) + | Unbound _ -> typ_error env l ("Cannot create a new identifier in this l-expression " ^ string_of_lexp lexp) + end + | LE_vector_range (v_lexp, exp1, exp2) -> begin + let inferred_v_lexp = infer_lexp env v_lexp in + let (Typ_aux (v_typ_aux, _)) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in + match v_typ_aux with + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) when Id.compare id (mk_id "bitvector") = 0 -> let inferred_exp1 = infer_exp env exp1 in let inferred_exp2 = infer_exp env exp2 in let nexp1, env = bind_numeric l (typ_of inferred_exp1) env in let nexp2, env = bind_numeric l (typ_of inferred_exp2) env in - let (slice_len, check) = match ord with + let slice_len, check = + match ord with | Ord_aux (Ord_inc, _) -> - (nexp_simp (nsum (nminus nexp2 nexp1) (nint 1)), - nc_and (nc_and (nc_lteq (nint 0) nexp1) (nc_lteq nexp1 nexp2)) (nc_lt nexp2 len)) + ( nexp_simp (nsum (nminus nexp2 nexp1) (nint 1)), + nc_and (nc_and (nc_lteq (nint 0) nexp1) (nc_lteq nexp1 nexp2)) (nc_lt nexp2 len) + ) | Ord_aux (Ord_dec, _) -> - (nexp_simp (nsum (nminus nexp1 nexp2) (nint 1)), - nc_and (nc_and (nc_lteq (nint 0) nexp2) (nc_lteq nexp2 nexp1)) (nc_lt nexp1 len)) + ( nexp_simp (nsum (nminus nexp1 nexp2) (nint 1)), + nc_and (nc_and (nc_lteq (nint 0) nexp2) (nc_lteq nexp2 nexp1)) (nc_lt nexp1 len) + ) | Ord_aux (Ord_var _, _) -> - typ_error env l "Slice assignment to bitvector with variable indexing order unsupported" + typ_error env l "Slice assignment to bitvector with variable indexing order unsupported" in if !opt_no_lexp_bounds_check || prove __POS__ env check then annot_lexp (LE_vector_range (inferred_v_lexp, inferred_exp1, inferred_exp2)) (bitvector_typ slice_len ord) - else - typ_raise env l (Err_failed_constraint (check, Env.get_locals env, Env.get_constraints env)) - | _ -> typ_error env l "Cannot assign slice of non vector type" - end - | LE_vector (v_lexp, exp) -> - begin - let inferred_v_lexp = infer_lexp env v_lexp in - let (Typ_aux (v_typ_aux, _)) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in - match v_typ_aux with - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order _, _); A_aux (A_typ elem_typ, _)]) - when Id.compare id (mk_id "vector") = 0 -> + else typ_raise env l (Err_failed_constraint (check, Env.get_locals env, Env.get_constraints env)) + | _ -> typ_error env l "Cannot assign slice of non vector type" + end + | LE_vector (v_lexp, exp) -> begin + let inferred_v_lexp = infer_lexp env v_lexp in + let (Typ_aux (v_typ_aux, _)) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in + match v_typ_aux with + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order _, _); A_aux (A_typ elem_typ, _)]) + when Id.compare id (mk_id "vector") = 0 -> let inferred_exp = infer_exp env exp in let nexp, env = bind_numeric l (typ_of inferred_exp) env in let bounds_check = nc_and (nc_lteq (nint 0) nexp) (nc_lt nexp len) in if !opt_no_lexp_bounds_check || prove __POS__ env bounds_check then annot_lexp (LE_vector (inferred_v_lexp, inferred_exp)) elem_typ - else - typ_raise env l (Err_failed_constraint (bounds_check, Env.get_locals env, Env.get_constraints env)) - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order _, _)]) - when Id.compare id (mk_id "bitvector") = 0 -> + else typ_raise env l (Err_failed_constraint (bounds_check, Env.get_locals env, Env.get_constraints env)) + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order _, _)]) when Id.compare id (mk_id "bitvector") = 0 -> let inferred_exp = infer_exp env exp in let nexp, env = bind_numeric l (typ_of inferred_exp) env in let bounds_check = nc_and (nc_lteq (nint 0) nexp) (nc_lt nexp len) in if !opt_no_lexp_bounds_check || prove __POS__ env bounds_check then annot_lexp (LE_vector (inferred_v_lexp, inferred_exp)) bit_typ - else - typ_raise env l (Err_failed_constraint (bounds_check, Env.get_locals env, Env.get_constraints env)) - | Typ_id id -> - begin match exp with + else typ_raise env l (Err_failed_constraint (bounds_check, Env.get_locals env, Env.get_constraints env)) + | Typ_id id -> begin + match exp with | E_aux (E_id field, _) -> - let field_lexp = Bitfield.set_bits_field_lexp v_lexp in - let index_range = match get_bitfield_range id field env with - | Some range -> range - | None -> typ_error env l (Printf.sprintf "Unknown field %s in bitfield %s" (string_of_id field) (string_of_id id)) - in - infer_lexp env (Bitfield.set_field_lexp index_range field_lexp) - | _ -> - typ_error env l (string_of_exp exp ^ " is not a bitfield accessor") - end - | _ -> typ_error env l "Cannot assign vector element of non vector or bitfield type" - end + let field_lexp = Bitfield.set_bits_field_lexp v_lexp in + let index_range = + match get_bitfield_range id field env with + | Some range -> range + | None -> + typ_error env l + (Printf.sprintf "Unknown field %s in bitfield %s" (string_of_id field) (string_of_id id)) + in + infer_lexp env (Bitfield.set_field_lexp index_range field_lexp) + | _ -> typ_error env l (string_of_exp exp ^ " is not a bitfield accessor") + end + | _ -> typ_error env l "Cannot assign vector element of non vector or bitfield type" + end | LE_vector_concat [] -> typ_error env l "Cannot have empty vector concatenation l-expression" - | LE_vector_concat (v_lexp :: v_lexps) -> - begin - let sum_vector_lengths first_ord first_elem_typ acc (Typ_aux (v_typ_aux, _)) = - match v_typ_aux with - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) - when Id.compare id (mk_id "vector") = 0 && ord_identical ord first_ord -> + | LE_vector_concat (v_lexp :: v_lexps) -> begin + let sum_vector_lengths first_ord first_elem_typ acc (Typ_aux (v_typ_aux, _)) = + match v_typ_aux with + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) + when Id.compare id (mk_id "vector") = 0 && ord_identical ord first_ord -> typ_equality l env elem_typ first_elem_typ; nsum acc len - | _ -> typ_error env l "Vector concatentation l-expression must only contain vector types of the same order" - in - let sum_bitvector_lengths first_ord acc (Typ_aux (v_typ_aux, _)) = - match v_typ_aux with - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) - when Id.compare id (mk_id "bitvector") = 0 && ord_identical ord first_ord -> + | _ -> typ_error env l "Vector concatentation l-expression must only contain vector types of the same order" + in + let sum_bitvector_lengths first_ord acc (Typ_aux (v_typ_aux, _)) = + match v_typ_aux with + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) + when Id.compare id (mk_id "bitvector") = 0 && ord_identical ord first_ord -> nsum acc len - | _ -> typ_error env l "Bitvector concatentation l-expression must only contain bitvector types of the same order" - in - let inferred_v_lexp = infer_lexp env v_lexp in - let inferred_v_lexps = List.map (infer_lexp env) v_lexps in - let (Typ_aux (v_typ_aux, _) as v_typ) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in - let v_typs = List.map (fun lexp -> Env.expand_synonyms env (lexp_typ_of lexp)) inferred_v_lexps in - match v_typ_aux with - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) - when Id.compare id (mk_id "vector") = 0 -> + | _ -> + typ_error env l "Bitvector concatentation l-expression must only contain bitvector types of the same order" + in + let inferred_v_lexp = infer_lexp env v_lexp in + let inferred_v_lexps = List.map (infer_lexp env) v_lexps in + let (Typ_aux (v_typ_aux, _) as v_typ) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in + let v_typs = List.map (fun lexp -> Env.expand_synonyms env (lexp_typ_of lexp)) inferred_v_lexps in + match v_typ_aux with + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) + when Id.compare id (mk_id "vector") = 0 -> let len = List.fold_left (sum_vector_lengths ord elem_typ) len v_typs in annot_lexp (LE_vector_concat (inferred_v_lexp :: inferred_v_lexps)) (vector_typ (nexp_simp len) ord elem_typ) - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) - when Id.compare id (mk_id "bitvector") = 0 -> + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) when Id.compare id (mk_id "bitvector") = 0 -> let len = List.fold_left (sum_bitvector_lengths ord) len v_typs in annot_lexp (LE_vector_concat (inferred_v_lexp :: inferred_v_lexps)) (bitvector_typ (nexp_simp len) ord) - | _ -> typ_error env l ("Vector concatentation l-expression must only contain bitvector or vector types, found " ^ string_of_typ v_typ) - end + | _ -> + typ_error env l + ("Vector concatentation l-expression must only contain bitvector or vector types, found " + ^ string_of_typ v_typ + ) + end | LE_field ((LE_aux (_, (l, _)) as lexp), field_id) -> - let inferred_lexp = infer_lexp env lexp in - let rectyp = lexp_typ_of inferred_lexp in - begin match lexp_typ_of inferred_lexp with - | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> - let (_, rectyp_q, field_typ) = Env.get_accessor rectyp_id field_id env in - let unifiers = try unify l env (tyvars_of_typ rectyp_q) rectyp_q rectyp with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) in - let field_typ' = subst_unifiers unifiers field_typ in - annot_lexp (LE_field (inferred_lexp, field_id)) field_typ' - | _ -> typ_error env l "Field l-expression has invalid type" - end + let inferred_lexp = infer_lexp env lexp in + let rectyp = lexp_typ_of inferred_lexp in + begin + match lexp_typ_of inferred_lexp with + | (Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _)) when Env.is_record rectyp_id env -> + let _, rectyp_q, field_typ = Env.get_accessor rectyp_id field_id env in + let unifiers = + try unify l env (tyvars_of_typ rectyp_q) rectyp_q rectyp + with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) + in + let field_typ' = subst_unifiers unifiers field_typ in + annot_lexp (LE_field (inferred_lexp, field_id)) field_typ' + | _ -> typ_error env l "Field l-expression has invalid type" + end | LE_deref exp -> - let inferred_exp = infer_exp env exp in - begin match typ_of inferred_exp with - | Typ_aux (Typ_app (r, [A_aux (A_typ vtyp, _)]), _) when string_of_id r = "register" -> - annot_lexp (LE_deref inferred_exp) vtyp - | _ -> - typ_error env l (string_of_typ (typ_of inferred_exp) ^ " must be a register type in " ^ string_of_exp exp ^ ")") - end + let inferred_exp = infer_exp env exp in + begin + match typ_of inferred_exp with + | Typ_aux (Typ_app (r, [A_aux (A_typ vtyp, _)]), _) when string_of_id r = "register" -> + annot_lexp (LE_deref inferred_exp) vtyp + | _ -> + typ_error env l + (string_of_typ (typ_of inferred_exp) ^ " must be a register type in " ^ string_of_exp exp ^ ")") + end | LE_tuple lexps -> - let inferred_lexps = List.map (infer_lexp env) lexps in - annot_lexp (LE_tuple inferred_lexps) (tuple_typ (List.map lexp_typ_of inferred_lexps)) + let inferred_lexps = List.map (infer_lexp env) lexps in + annot_lexp (LE_tuple inferred_lexps) (tuple_typ (List.map lexp_typ_of inferred_lexps)) | _ -> typ_error env l ("Could not infer the type of " ^ string_of_lexp lexp) and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) = - let annot_exp exp typ = E_aux (exp, (l, mk_tannot ~uannot:uannot env typ)) in + let annot_exp exp typ = E_aux (exp, (l, mk_tannot ~uannot env typ)) in match exp_aux with | E_block exps -> - let rec last_typ = function [exp] -> typ_of exp | _ :: exps -> last_typ exps | [] -> unit_typ in - let inferred_block = check_block l env exps None in - annot_exp (E_block inferred_block) (last_typ inferred_block) - | E_id v -> - begin - match Env.lookup_id v env with - | Local (_, typ) | Enum typ -> annot_exp (E_id v) typ - | Register typ -> annot_exp (E_id v) typ - | Unbound _ -> + let rec last_typ = function [exp] -> typ_of exp | _ :: exps -> last_typ exps | [] -> unit_typ in + let inferred_block = check_block l env exps None in + annot_exp (E_block inferred_block) (last_typ inferred_block) + | E_id v -> begin + match Env.lookup_id v env with + | Local (_, typ) | Enum typ -> annot_exp (E_id v) typ + | Register typ -> annot_exp (E_id v) typ + | Unbound _ -> ( match Bindings.find_opt v (Env.get_val_specs env) with - | Some _ -> typ_error env l ("Identifier " ^ string_of_id v ^ " is unbound (Did you mean to call the " ^ string_of_id v ^ " function?)") + | Some _ -> + typ_error env l + ("Identifier " ^ string_of_id v ^ " is unbound (Did you mean to call the " ^ string_of_id v + ^ " function?)" + ) | None -> typ_error env l ("Identifier " ^ string_of_id v ^ " is unbound") - end + ) + end | E_lit lit -> annot_exp (E_lit lit) (infer_lit env lit) - | E_sizeof nexp -> - irule infer_exp env (rewrite_sizeof l env (Env.expand_nexp_synonyms env nexp)) + | E_sizeof nexp -> irule infer_exp env (rewrite_sizeof l env (Env.expand_nexp_synonyms env nexp)) | E_constraint nc -> - Env.wf_constraint env nc; - crule check_exp env (rewrite_nc env (Env.expand_constraint_synonyms env nc)) (atom_bool_typ nc) - | E_field (exp, field) -> - begin - let inferred_exp = irule infer_exp env exp in - match Env.expand_synonyms env (typ_of inferred_exp) with - (* Accessing a field of a record *) - | Typ_aux (Typ_id rectyp, _) when Env.is_record rectyp env -> - begin - let inferred_acc = infer_funapp' l (Env.no_casts env) field (Env.get_accessor_fn rectyp field env) [strip_exp inferred_exp] None in - match inferred_acc with - | E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc) - | _ -> assert false (* Unreachable *) - end - (* Not sure if we need to do anything different with args here. *) - | Typ_aux (Typ_app (rectyp, _), _) when Env.is_record rectyp env -> - begin - let inferred_acc = infer_funapp' l (Env.no_casts env) field (Env.get_accessor_fn rectyp field env) [strip_exp inferred_exp] None in - match inferred_acc with - | E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc) - | _ -> assert false (* Unreachable *) - end - | _ -> typ_error env l ("Field expression " ^ string_of_exp exp ^ " :: " ^ string_of_typ (typ_of inferred_exp) ^ " is not valid") - end + Env.wf_constraint env nc; + crule check_exp env (rewrite_nc env (Env.expand_constraint_synonyms env nc)) (atom_bool_typ nc) + | E_field (exp, field) -> begin + let inferred_exp = irule infer_exp env exp in + match Env.expand_synonyms env (typ_of inferred_exp) with + (* Accessing a field of a record *) + | Typ_aux (Typ_id rectyp, _) when Env.is_record rectyp env -> begin + let inferred_acc = + infer_funapp' l (Env.no_casts env) field (Env.get_accessor_fn rectyp field env) + [strip_exp inferred_exp] + None + in + match inferred_acc with + | E_aux (E_app (field, [inferred_exp]), _) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc) + | _ -> assert false (* Unreachable *) + end + (* Not sure if we need to do anything different with args here. *) + | Typ_aux (Typ_app (rectyp, _), _) when Env.is_record rectyp env -> begin + let inferred_acc = + infer_funapp' l (Env.no_casts env) field (Env.get_accessor_fn rectyp field env) + [strip_exp inferred_exp] + None + in + match inferred_acc with + | E_aux (E_app (field, [inferred_exp]), _) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc) + | _ -> assert false (* Unreachable *) + end + | _ -> + typ_error env l + ("Field expression " ^ string_of_exp exp ^ " :: " ^ string_of_typ (typ_of inferred_exp) ^ " is not valid") + end | E_tuple exps -> - let inferred_exps = List.map (irule infer_exp env) exps in - annot_exp (E_tuple inferred_exps) (mk_typ (Typ_tuple (List.map typ_of inferred_exps))) - | E_assign (lexp, bind) -> - begin match lexp_assignment_type env lexp with - | Update -> - fst (bind_assignment l env lexp bind) - | Declaration -> - typ_error env l "Variable declaration with unclear (or no) scope. Use an explicit var statement instead, or place in a block" - end + let inferred_exps = List.map (irule infer_exp env) exps in + annot_exp (E_tuple inferred_exps) (mk_typ (Typ_tuple (List.map typ_of inferred_exps))) + | E_assign (lexp, bind) -> begin + match lexp_assignment_type env lexp with + | Update -> fst (bind_assignment l env lexp bind) + | Declaration -> + typ_error env l + "Variable declaration with unclear (or no) scope. Use an explicit var statement instead, or place in a \ + block" + end | E_struct_update (exp, fexps) -> - let inferred_exp = irule infer_exp env exp in - let typ = typ_of inferred_exp in - let rectyp_id = match Env.expand_synonyms env typ with - | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> - rectyp_id - | _ -> typ_error env l ("The type " ^ string_of_typ typ ^ " is not a record") - in - let check_fexp (FE_aux (FE_fexp (field, exp), (l, _))) = - let (_, rectyp_q, field_typ) = Env.get_accessor rectyp_id field env in - let unifiers = try unify l env (tyvars_of_typ rectyp_q) rectyp_q typ with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) in - let field_typ' = subst_unifiers unifiers field_typ in - let inferred_exp = crule check_exp env exp field_typ' in - FE_aux (FE_fexp (field, inferred_exp), (l, empty_tannot)) - in - annot_exp (E_struct_update (inferred_exp, List.map check_fexp fexps)) typ + let inferred_exp = irule infer_exp env exp in + let typ = typ_of inferred_exp in + let rectyp_id = + match Env.expand_synonyms env typ with + | (Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _)) when Env.is_record rectyp_id env -> + rectyp_id + | _ -> typ_error env l ("The type " ^ string_of_typ typ ^ " is not a record") + in + let check_fexp (FE_aux (FE_fexp (field, exp), (l, _))) = + let _, rectyp_q, field_typ = Env.get_accessor rectyp_id field env in + let unifiers = + try unify l env (tyvars_of_typ rectyp_q) rectyp_q typ + with Unification_error (l, m) -> typ_error env l ("Unification error: " ^ m) + in + let field_typ' = subst_unifiers unifiers field_typ in + let inferred_exp = crule check_exp env exp field_typ' in + FE_aux (FE_fexp (field, inferred_exp), (l, empty_tannot)) + in + annot_exp (E_struct_update (inferred_exp, List.map check_fexp fexps)) typ | E_typ (typ, exp) -> - let checked_exp = crule check_exp env exp typ in - annot_exp (E_typ (typ, checked_exp)) typ + let checked_exp = crule check_exp env exp typ in + annot_exp (E_typ (typ, checked_exp)) typ | E_app_infix (x, op, y) -> infer_exp env (E_aux (E_app (deinfix op, [x; y]), (l, uannot))) (* Treat a multiple argument constructor as a single argument constructor taking a tuple, Ctor(x, y) -> Ctor((x, y)). *) | E_app (ctor, x :: y :: zs) when Env.is_union_constructor ctor env -> - typ_print (lazy ("Inferring multiple argument constructor: " ^ string_of_id ctor)); - irule infer_exp env (mk_exp ~loc:l (E_app (ctor, [mk_exp ~loc:l (E_tuple (x :: y :: zs))]))) + typ_print (lazy ("Inferring multiple argument constructor: " ^ string_of_id ctor)); + irule infer_exp env (mk_exp ~loc:l (E_app (ctor, [mk_exp ~loc:l (E_tuple (x :: y :: zs))]))) | E_app (mapping, xs) when Env.is_mapping mapping env -> - let forwards_id = mk_id (string_of_id mapping ^ "_forwards") in - let backwards_id = mk_id (string_of_id mapping ^ "_backwards") in - typ_print (lazy ("Trying forwards direction for mapping " ^ string_of_id mapping ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); - begin try irule infer_exp env (E_aux (E_app (forwards_id, xs), (l, uannot))) with - | Type_error (_, _, err1) -> - (* typ_print (lazy ("Error in forwards direction: " ^ string_of_type_error err1)); *) - typ_print (lazy ("Trying backwards direction for mapping " ^ string_of_id mapping ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); - begin try irule infer_exp env (E_aux (E_app (backwards_id, xs), (l, uannot))) with - | Type_error (env, _, err2) -> - (* typ_print (lazy ("Error in backwards direction: " ^ string_of_type_error err2)); *) - typ_raise env l (Err_no_overloading (mapping, [(forwards_id, err1); (backwards_id, err2)])) - end - end + let forwards_id = mk_id (string_of_id mapping ^ "_forwards") in + let backwards_id = mk_id (string_of_id mapping ^ "_backwards") in + typ_print + ( lazy + ("Trying forwards direction for mapping " ^ string_of_id mapping ^ "(" ^ string_of_list ", " string_of_exp xs + ^ ")" + ) + ); + begin + try irule infer_exp env (E_aux (E_app (forwards_id, xs), (l, uannot))) + with Type_error (_, _, err1) -> + (* typ_print (lazy ("Error in forwards direction: " ^ string_of_type_error err1)); *) + typ_print + ( lazy + ("Trying backwards direction for mapping " ^ string_of_id mapping ^ "(" + ^ string_of_list ", " string_of_exp xs ^ ")" + ) + ); + begin + try irule infer_exp env (E_aux (E_app (backwards_id, xs), (l, uannot))) + with Type_error (env, _, err2) -> + (* typ_print (lazy ("Error in backwards direction: " ^ string_of_type_error err2)); *) + typ_raise env l (Err_no_overloading (mapping, [(forwards_id, err1); (backwards_id, err2)])) + end + end | E_app (f, xs) when List.length (Env.get_overloads f env) > 0 -> - let rec try_overload = function - | (errs, []) -> typ_raise env l (Err_no_overloading (f, errs)) - | (errs, (f :: fs)) -> begin - typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); - try irule infer_exp env (E_aux (E_app (f, xs), (l, uannot))) with - | Type_error (_, _, err) -> + let rec try_overload = function + | errs, [] -> typ_raise env l (Err_no_overloading (f, errs)) + | errs, f :: fs -> begin + typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")")); + try irule infer_exp env (E_aux (E_app (f, xs), (l, uannot))) + with Type_error (_, _, err) -> typ_debug (lazy "Error"); try_overload (errs @ [(f, err)], fs) - end - in - try_overload ([], Env.get_overloads f env) - | E_app (f, [x; y]) when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> - begin match destruct_exist (typ_of (irule infer_exp env y)) with - | None | Some (_, NC_aux (NC_true, _), _) -> infer_funapp l env f [x; y] None - | Some _ -> infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] None - | exception Type_error _ -> infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] None - end + end + in + try_overload ([], Env.get_overloads f env) + | E_app (f, [x; y]) when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> begin + match destruct_exist (typ_of (irule infer_exp env y)) with + | None | Some (_, NC_aux (NC_true, _), _) -> infer_funapp l env f [x; y] None + | Some _ -> infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] None + | exception Type_error _ -> infer_funapp l env f [x; mk_exp (E_typ (bool_typ, y))] None + end | E_app (f, xs) -> infer_funapp l env f xs None | E_loop (loop_type, measure, cond, body) -> - let checked_cond = crule check_exp env cond bool_typ in - let checked_measure = match measure with - | Measure_aux (Measure_none,l) -> Measure_aux (Measure_none,l) - | Measure_aux (Measure_some exp,l) -> - Measure_aux (Measure_some (crule check_exp env exp int_typ),l) - in - let nc = match loop_type with - | While -> assert_constraint env true checked_cond - | Until -> None - in - let checked_body = crule check_exp (add_opt_constraint l "loop condition" nc env) body unit_typ in - annot_exp (E_loop (loop_type, checked_measure, checked_cond, checked_body)) unit_typ - | E_for (v, f, t, step, ord, body) -> - begin - let f, t, is_dec = match ord with - | Ord_aux (Ord_inc, _) -> f, t, false - | Ord_aux (Ord_dec, _) -> t, f, true (* reverse direction to typechecking downto as upto loop *) - | Ord_aux (Ord_var _, _) -> typ_error env l "Cannot check a loop with variable direction!" (* This should never happen *) - in - let inferred_f = irule infer_exp env f in - let inferred_t = irule infer_exp env t in - let checked_step = crule check_exp env step int_typ in - match destruct_numeric (typ_of inferred_f), destruct_numeric (typ_of inferred_t) with - | Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) -> + let checked_cond = crule check_exp env cond bool_typ in + let checked_measure = + match measure with + | Measure_aux (Measure_none, l) -> Measure_aux (Measure_none, l) + | Measure_aux (Measure_some exp, l) -> Measure_aux (Measure_some (crule check_exp env exp int_typ), l) + in + let nc = match loop_type with While -> assert_constraint env true checked_cond | Until -> None in + let checked_body = crule check_exp (add_opt_constraint l "loop condition" nc env) body unit_typ in + annot_exp (E_loop (loop_type, checked_measure, checked_cond, checked_body)) unit_typ + | E_for (v, f, t, step, ord, body) -> begin + let f, t, is_dec = + match ord with + | Ord_aux (Ord_inc, _) -> (f, t, false) + | Ord_aux (Ord_dec, _) -> (t, f, true (* reverse direction to typechecking downto as upto loop *)) + | Ord_aux (Ord_var _, _) -> + typ_error env l "Cannot check a loop with variable direction!" (* This should never happen *) + in + let inferred_f = irule infer_exp env f in + let inferred_t = irule infer_exp env t in + let checked_step = crule check_exp env step int_typ in + match (destruct_numeric (typ_of inferred_f), destruct_numeric (typ_of inferred_t)) with + | Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) -> let loop_kid = mk_kid ("loop_" ^ string_of_id v) in - let env = List.fold_left (fun env kid -> Env.add_typ_var l (mk_kopt K_int kid) env) env (loop_kid :: kids1 @ kids2) in + let env = + List.fold_left (fun env kid -> Env.add_typ_var l (mk_kopt K_int kid) env) env ((loop_kid :: kids1) @ kids2) + in let env = Env.add_constraint (nc_and nc1 nc2) env in let env = Env.add_constraint (nc_and (nc_lteq nexp1 (nvar loop_kid)) (nc_lteq (nvar loop_kid) nexp2)) env in let loop_vtyp = atom_typ (nvar loop_kid) in let checked_body = crule check_exp (Env.add_local v (Immutable, loop_vtyp) env) body unit_typ in - if not is_dec (* undo reverse direction in annotated ast for downto loop *) - then annot_exp (E_for (v, inferred_f, inferred_t, checked_step, ord, checked_body)) unit_typ + if not is_dec (* undo reverse direction in annotated ast for downto loop *) then + annot_exp (E_for (v, inferred_f, inferred_t, checked_step, ord, checked_body)) unit_typ else annot_exp (E_for (v, inferred_t, inferred_f, checked_step, ord, checked_body)) unit_typ - | _, _ -> typ_error env l "Ranges in foreach overlap" - end + | _, _ -> typ_error env l "Ranges in foreach overlap" + end | E_if (cond, then_branch, else_branch) -> - let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in - let then_branch' = irule infer_exp (add_opt_constraint l "then branch" (assert_constraint env true cond') env) then_branch in - (* We don't have generic type union in Sail, but we can union simple numeric types. *) - begin match destruct_numeric (Env.expand_synonyms env (typ_of then_branch')) with - | Some (kids, nc, then_nexp) -> - let then_sn = to_simple_numeric l kids nc then_nexp in - let else_branch' = irule infer_exp (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) else_branch in - begin match destruct_numeric (Env.expand_synonyms env (typ_of else_branch')) with - | Some (kids, nc, else_nexp) -> - let else_sn = to_simple_numeric l kids nc else_nexp in - let typ = typ_of_simple_numeric (union_simple_numeric then_sn else_sn) in - annot_exp (E_if (cond', then_branch', else_branch')) typ - | None -> typ_error env l ("Could not infer type of " ^ string_of_exp else_branch) - end - | None -> - begin match typ_of then_branch' with - | Typ_aux (Typ_app (f, [_]), _) when string_of_id f = "atom_bool" -> - let else_branch' = crule check_exp (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) else_branch bool_typ in - annot_exp (E_if (cond', then_branch', else_branch')) bool_typ - | _ -> - let else_branch' = crule check_exp (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in - annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch') - end - end - | E_vector_access (v, n) -> - begin - try infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, uannot))) with - | Type_error (err_env, err_l, err) -> - (try ( - let inferred_v = infer_exp env v in - begin match typ_of inferred_v, n with - | Typ_aux (Typ_id id, _), E_aux (E_id field, (f_l, _)) -> - let access_id = (Bitfield.field_accessor_ids id field).get in - infer_exp env (mk_exp ~loc:l (E_app (access_id, [v]))) - | _, _ -> - typ_error env l "Vector access could not be interpreted as a bitfield access" - end - ) with - | Type_error (_, err_l', err') -> - typ_raise err_env err_l (err_because (err, err_l', err'))) - | exn -> raise exn - end - | E_vector_update (v, n, exp) -> - begin - try infer_exp env (E_aux (E_app (mk_id "vector_update", [v; n; exp]), (l, uannot))) with - | Type_error (err_env, err_l, err) -> - (try ( - let inferred_v = infer_exp env v in - begin match typ_of inferred_v, n with - | Typ_aux (Typ_id id, _), E_aux (E_id field, (f_l, _)) -> - let update_id = (Bitfield.field_accessor_ids id field).update in - infer_exp env (mk_exp ~loc:l (E_app (update_id, [v; exp]))) - | _, _ -> - typ_error env l "Vector update could not be interpreted as a bitfield update" - end - ) with - | Type_error (_, err_l', err') -> - typ_raise err_env err_l (err_because (err, err_l', err'))) - | exn -> raise exn - end - | E_vector_update_subrange (v, n, m, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update_subrange", [v; n; m; exp]), (l, uannot))) + let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in + let then_branch' = + irule infer_exp (add_opt_constraint l "then branch" (assert_constraint env true cond') env) then_branch + in + (* We don't have generic type union in Sail, but we can union simple numeric types. *) + begin + match destruct_numeric (Env.expand_synonyms env (typ_of then_branch')) with + | Some (kids, nc, then_nexp) -> + let then_sn = to_simple_numeric l kids nc then_nexp in + let else_branch' = + irule infer_exp + (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) + else_branch + in + begin + match destruct_numeric (Env.expand_synonyms env (typ_of else_branch')) with + | Some (kids, nc, else_nexp) -> + let else_sn = to_simple_numeric l kids nc else_nexp in + let typ = typ_of_simple_numeric (union_simple_numeric then_sn else_sn) in + annot_exp (E_if (cond', then_branch', else_branch')) typ + | None -> typ_error env l ("Could not infer type of " ^ string_of_exp else_branch) + end + | None -> begin + match typ_of then_branch' with + | Typ_aux (Typ_app (f, [_]), _) when string_of_id f = "atom_bool" -> + let else_branch' = + crule check_exp + (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) + else_branch bool_typ + in + annot_exp (E_if (cond', then_branch', else_branch')) bool_typ + | _ -> + let else_branch' = + crule check_exp + (add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env) + else_branch (typ_of then_branch') + in + annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch') + end + end + | E_vector_access (v, n) -> begin + try infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, uannot))) with + | Type_error (err_env, err_l, err) -> ( + try + let inferred_v = infer_exp env v in + begin + match (typ_of inferred_v, n) with + | Typ_aux (Typ_id id, _), E_aux (E_id field, (f_l, _)) -> + let access_id = (Bitfield.field_accessor_ids id field).get in + infer_exp env (mk_exp ~loc:l (E_app (access_id, [v]))) + | _, _ -> typ_error env l "Vector access could not be interpreted as a bitfield access" + end + with Type_error (_, err_l', err') -> typ_raise err_env err_l (err_because (err, err_l', err')) + ) + | exn -> raise exn + end + | E_vector_update (v, n, exp) -> begin + try infer_exp env (E_aux (E_app (mk_id "vector_update", [v; n; exp]), (l, uannot))) with + | Type_error (err_env, err_l, err) -> ( + try + let inferred_v = infer_exp env v in + begin + match (typ_of inferred_v, n) with + | Typ_aux (Typ_id id, _), E_aux (E_id field, (f_l, _)) -> + let update_id = (Bitfield.field_accessor_ids id field).update in + infer_exp env (mk_exp ~loc:l (E_app (update_id, [v; exp]))) + | _, _ -> typ_error env l "Vector update could not be interpreted as a bitfield update" + end + with Type_error (_, err_l', err') -> typ_raise err_env err_l (err_because (err, err_l', err')) + ) + | exn -> raise exn + end + | E_vector_update_subrange (v, n, m, exp) -> + infer_exp env (E_aux (E_app (mk_id "vector_update_subrange", [v; n; m; exp]), (l, uannot))) | E_vector_append (v1, E_aux (E_vector [], _)) -> infer_exp env v1 | E_vector_append (v1, v2) -> infer_exp env (E_aux (E_app (mk_id "append", [v1; v2]), (l, uannot))) | E_vector_subrange (v, n, m) -> infer_exp env (E_aux (E_app (mk_id "vector_subrange", [v; n; m]), (l, uannot))) | E_vector [] -> typ_error env l "Cannot infer type of empty vector" - | E_vector ((item :: items) as vec) -> - let inferred_item = irule infer_exp env item in - let checked_items = List.map (fun i -> crule check_exp env i (typ_of inferred_item)) items in - begin match typ_of inferred_item with - | Typ_aux (Typ_id id, _) when string_of_id id = "bit" -> - let bitvec_typ = bits_typ env (nint (List.length vec)) in - annot_exp (E_vector (inferred_item :: checked_items)) bitvec_typ - | _ -> - let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in - annot_exp (E_vector (inferred_item :: checked_items)) vec_typ - end + | E_vector (item :: items as vec) -> + let inferred_item = irule infer_exp env item in + let checked_items = List.map (fun i -> crule check_exp env i (typ_of inferred_item)) items in + begin + match typ_of inferred_item with + | Typ_aux (Typ_id id, _) when string_of_id id = "bit" -> + let bitvec_typ = bits_typ env (nint (List.length vec)) in + annot_exp (E_vector (inferred_item :: checked_items)) bitvec_typ + | _ -> + let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in + annot_exp (E_vector (inferred_item :: checked_items)) vec_typ + end | E_assert (test, msg) -> - let msg = assert_msg msg in - let checked_test = crule check_exp env test bool_typ in - let checked_msg = crule check_exp env msg string_typ in - annot_exp (E_assert (checked_test, checked_msg)) unit_typ + let msg = assert_msg msg in + let checked_test = crule check_exp env test bool_typ in + let checked_msg = crule check_exp env msg string_typ in + annot_exp (E_assert (checked_test, checked_msg)) unit_typ | E_internal_return exp -> - let inferred_exp = irule infer_exp env exp in - annot_exp (E_internal_return inferred_exp) (typ_of inferred_exp) + let inferred_exp = irule infer_exp env exp in + annot_exp (E_internal_return inferred_exp) (typ_of inferred_exp) | E_internal_plet (pat, bind, body) -> - let bind_exp, ptyp = match pat with - | P_aux (P_typ (ptyp, _), _) -> - Env.wf_typ env ptyp; - let checked_bind = crule check_exp env bind ptyp in - checked_bind, ptyp - | _ -> - let inferred_bind = irule infer_exp env bind in - inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat_no_guard env pat ptyp in - (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) - let env = match bind_exp with - | E_aux (E_assert (constr_exp, _), _) -> - begin + let bind_exp, ptyp = + match pat with + | P_aux (P_typ (ptyp, _), _) -> + Env.wf_typ env ptyp; + let checked_bind = crule check_exp env bind ptyp in + (checked_bind, ptyp) + | _ -> + let inferred_bind = irule infer_exp env bind in + (inferred_bind, typ_of inferred_bind) + in + let tpat, env = bind_pat_no_guard env pat ptyp in + (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) + let env = + match bind_exp with + | E_aux (E_assert (constr_exp, _), _) -> begin match assert_constraint env true constr_exp with | Some nc -> - typ_print (lazy ("Adding constraint " ^ string_of_n_constraint nc ^ " for assert")); - Env.add_constraint nc env + typ_print (lazy ("Adding constraint " ^ string_of_n_constraint nc ^ " for assert")); + Env.add_constraint nc env | None -> env end - | _ -> env in - let inferred_body = irule infer_exp env body in - annot_exp (E_internal_plet (tpat, bind_exp, inferred_body)) (typ_of inferred_body) + | _ -> env + in + let inferred_body = irule infer_exp env body in + annot_exp (E_internal_plet (tpat, bind_exp, inferred_body)) (typ_of inferred_body) | E_let (LB_aux (letbind, (let_loc, _)), exp) -> - let bind_exp, pat, ptyp = match letbind with - | LB_val (P_aux (P_typ (ptyp, _), _) as pat, bind) -> - Env.wf_typ env ptyp; - let checked_bind = crule check_exp env bind ptyp in - checked_bind, pat, ptyp - | LB_val (pat, bind) -> - let inferred_bind = irule infer_exp env bind in - inferred_bind, pat, typ_of inferred_bind in - check_pattern_duplicates env pat; - let tpat, inner_env = bind_pat_no_guard env pat ptyp in - let inferred_exp = irule infer_exp inner_env exp in - annot_exp (E_let (LB_aux (LB_val (tpat, bind_exp), (let_loc, empty_tannot)), inferred_exp)) - (check_shadow_leaks l inner_env env (typ_of inferred_exp)) + let bind_exp, pat, ptyp = + match letbind with + | LB_val ((P_aux (P_typ (ptyp, _), _) as pat), bind) -> + Env.wf_typ env ptyp; + let checked_bind = crule check_exp env bind ptyp in + (checked_bind, pat, ptyp) + | LB_val (pat, bind) -> + let inferred_bind = irule infer_exp env bind in + (inferred_bind, pat, typ_of inferred_bind) + in + check_pattern_duplicates env pat; + let tpat, inner_env = bind_pat_no_guard env pat ptyp in + let inferred_exp = irule infer_exp inner_env exp in + annot_exp + (E_let (LB_aux (LB_val (tpat, bind_exp), (let_loc, empty_tannot)), inferred_exp)) + (check_shadow_leaks l inner_env env (typ_of inferred_exp)) | E_ref id when Env.is_register id env -> - let typ = Env.get_register id env in - annot_exp (E_ref id) (register_typ typ) + let typ = Env.get_register id env in + annot_exp (E_ref id) (register_typ typ) | E_internal_assume (nc, exp) -> - Env.wf_constraint env nc; - let env = Env.add_constraint nc env in - let exp' = irule infer_exp env exp in - annot_exp (E_internal_assume (nc, exp')) (typ_of exp') + Env.wf_constraint env nc; + let env = Env.add_constraint nc env in + let exp' = irule infer_exp env exp in + annot_exp (E_internal_assume (nc, exp')) (typ_of exp') | _ -> typ_error env l ("Cannot infer type of: " ^ string_of_exp exp) and infer_funapp l env f xs ret_ctx_typ = infer_funapp' l env f (Env.get_val_spec f env) xs ret_ctx_typ and instantiation_of (E_aux (_, (l, tannot)) as exp) = match fst tannot with - | Some t -> - begin match t.instantiation with - | Some inst -> inst - | None -> - raise (Reporting.err_unreachable l __POS__ "Passed non type-checked function to instantiation_of") - end + | Some t -> begin + match t.instantiation with + | Some inst -> inst + | None -> raise (Reporting.err_unreachable l __POS__ "Passed non type-checked function to instantiation_of") + end | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp) and instantiation_of_without_type (E_aux (exp_aux, (l, _)) as exp) = let env = env_of exp in match exp_aux with - | E_app (f, xs) -> instantiation_of (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) None) + | E_app (f, xs) -> + instantiation_of (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) None) | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp) and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = typ_print (lazy (Util.("Function " |> cyan |> clear) ^ string_of_id f)); let annot_exp exp typ inst = - E_aux (exp, (l, (Some { env = env; typ = typ; monadic = no_effect; expected = expected_ret_typ; instantiation = Some inst }, empty_uannot))) + E_aux + ( exp, + ( l, + (Some { env; typ; monadic = no_effect; expected = expected_ret_typ; instantiation = Some inst }, empty_uannot) + ) + ) in let is_bound env kid = KBindings.mem kid (Env.get_typ_vars env) in @@ -4644,18 +4824,24 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = let record_unifiers unifiers = let previous_unifiers = !all_unifiers in let updated_unifiers = KBindings.map (subst_unifiers_typ_arg unifiers) previous_unifiers in - all_unifiers := merge_uvars env l updated_unifiers unifiers; + all_unifiers := merge_uvars env l updated_unifiers unifiers in let quants, typ_args, typ_ret = match Env.expand_synonyms (Env.add_typquant l typq env) f_typ with - | Typ_aux (Typ_fn (typ_args, typ_ret), _) -> ref (quant_items typq), typ_args, ref typ_ret + | Typ_aux (Typ_fn (typ_args, typ_ret), _) -> (ref (quant_items typq), typ_args, ref typ_ret) | _ -> typ_error env l (string_of_typ f_typ ^ " is not a function type") in let unifiers = instantiate_simple_equations !quants in typ_debug (lazy "Instantiating from equations"); - typ_debug (lazy (string_of_list ", " (fun (kid, arg) -> string_of_kid kid ^ " => " ^ string_of_typ_arg arg) (KBindings.bindings unifiers))); + typ_debug + ( lazy + (string_of_list ", " + (fun (kid, arg) -> string_of_kid kid ^ " => " ^ string_of_typ_arg arg) + (KBindings.bindings unifiers) + ) + ); all_unifiers := unifiers; let typ_args = List.map (subst_unifiers unifiers) typ_args in List.iter (fun unifier -> quants := instantiate_quants env !quants unifier) (KBindings.bindings unifiers); @@ -4665,76 +4851,91 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = let implicits, typ_args, xs = let typ_args' = List.filter is_not_implicit typ_args in - match xs, typ_args' with - (* Support the case where a function has only implicit arguments; - allow it to be called either as f() or f(i...) *) - | [E_aux (E_lit (L_aux (L_unit, _)), _)], [] -> - get_implicits typ_args, [], [] + match (xs, typ_args') with + (* Support the case where a function has only implicit arguments; + allow it to be called either as f() or f(i...) *) + | [E_aux (E_lit (L_aux (L_unit, _)), _)], [] -> (get_implicits typ_args, [], []) | _ -> - if not (List.length typ_args = List.length xs) then - if not (List.length typ_args' = List.length xs) then - typ_error env l (Printf.sprintf "Function %s applied to %d args, expected %d (%d explicit): %s" (string_of_id f) (List.length xs) (List.length typ_args) (List.length typ_args') (String.concat ", " (List.map string_of_typ typ_args))) - else - get_implicits typ_args, typ_args', xs - else - [], List.map implicit_to_int typ_args, xs + if not (List.length typ_args = List.length xs) then + if not (List.length typ_args' = List.length xs) then + typ_error env l + (Printf.sprintf "Function %s applied to %d args, expected %d (%d explicit): %s" (string_of_id f) + (List.length xs) (List.length typ_args) (List.length typ_args') + (String.concat ", " (List.map string_of_typ typ_args)) + ) + else (get_implicits typ_args, typ_args', xs) + else ([], List.map implicit_to_int typ_args, xs) in - let typ_args = match expected_ret_typ with + let typ_args = + match expected_ret_typ with | None -> typ_args | Some expect when is_exist (Env.expand_synonyms env expect) || is_exist !typ_ret -> typ_args - | Some expect -> - let goals = quant_kopts (mk_typquant !quants) |> List.map kopt_kid |> KidSet.of_list in - try - let unifiers = unify l env (KidSet.diff goals (ambiguous_vars !typ_ret)) !typ_ret expect in - record_unifiers unifiers; - let unifiers = KBindings.bindings unifiers in - typ_debug (lazy (Util.("Unifiers " |> magenta |> clear) - ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers)); - List.iter (fun unifier -> quants := instantiate_quants env !quants unifier) unifiers; - List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers; - List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) typ_args - with Unification_error _ -> typ_args + | Some expect -> ( + let goals = quant_kopts (mk_typquant !quants) |> List.map kopt_kid |> KidSet.of_list in + try + let unifiers = unify l env (KidSet.diff goals (ambiguous_vars !typ_ret)) !typ_ret expect in + record_unifiers unifiers; + let unifiers = KBindings.bindings unifiers in + typ_debug + ( lazy + (Util.("Unifiers " |> magenta |> clear) + ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers + ) + ); + List.iter (fun unifier -> quants := instantiate_quants env !quants unifier) unifiers; + List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers; + List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) typ_args + with Unification_error _ -> typ_args + ) in (* We now iterate throught the function arguments, checking them and instantiating quantifiers. *) let instantiate env arg typ remaining_typs = - if KidSet.for_all (is_bound env) (tyvars_of_typ typ) then - crule check_exp env arg typ, remaining_typs, env - else + if KidSet.for_all (is_bound env) (tyvars_of_typ typ) then (crule check_exp env arg typ, remaining_typs, env) + else ( let goals = quant_kopts (mk_typquant !quants) |> List.map kopt_kid |> KidSet.of_list in typ_debug (lazy ("Quantifiers " ^ Util.string_of_list ", " string_of_quant_item !quants)); let inferred_arg = irule infer_exp env arg in let inferred_arg, unifiers, env = - try type_coercion_unify env goals inferred_arg typ with - | Unification_error (l, m) -> typ_error env l m + try type_coercion_unify env goals inferred_arg typ with Unification_error (l, m) -> typ_error env l m in record_unifiers unifiers; let unifiers = KBindings.bindings unifiers in - typ_debug (lazy (Util.("Unifiers " |> magenta |> clear) - ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers)); + typ_debug + ( lazy + (Util.("Unifiers " |> magenta |> clear) + ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers + ) + ); List.iter (fun unifier -> quants := instantiate_quants env !quants unifier) unifiers; List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers; let remaining_typs = List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) remaining_typs in - inferred_arg, remaining_typs, env + (inferred_arg, remaining_typs, env) + ) in let fold_instantiate (xs, args, env) x = match args with | arg :: remaining_args -> - let x, remaining_args, env = instantiate env x arg remaining_args in - (x :: xs, remaining_args, env) + let x, remaining_args, env = instantiate env x arg remaining_args in + (x :: xs, remaining_args, env) | [] -> raise (Reporting.err_unreachable l __POS__ "Empty arguments during instantiation") in let xs, _, env = List.fold_left fold_instantiate ([], typ_args, env) xs in let xs = List.rev xs in - let solve_implicit impl = match KBindings.find_opt impl !all_unifiers with + let solve_implicit impl = + match KBindings.find_opt impl !all_unifiers with | Some (A_aux (A_nexp (Nexp_aux (Nexp_constant c, _)), _)) -> irule infer_exp env (mk_lit_exp (L_num c)) | Some (A_aux (A_nexp n, _)) -> irule infer_exp env (mk_exp (E_sizeof n)) - | _ -> typ_error env l ("Cannot solve implicit " ^ string_of_kid impl ^ " in " ^ string_of_exp (mk_exp (E_app (f, List.map strip_exp xs)))) + | _ -> + typ_error env l + ("Cannot solve implicit " ^ string_of_kid impl ^ " in " + ^ string_of_exp (mk_exp (E_app (f, List.map strip_exp xs))) + ) in let xs = List.map solve_implicit implicits @ xs in @@ -4752,7 +4953,9 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = let universals = KBindings.bindings universals |> List.map fst |> KidSet.of_list in let typ_ret = - if KidSet.is_empty (KidSet.of_list (List.map kopt_kid existentials)) || KidSet.is_empty (KidSet.diff (tyvars_of_typ !typ_ret) universals) + if + KidSet.is_empty (KidSet.of_list (List.map kopt_kid existentials)) + || KidSet.is_empty (KidSet.diff (tyvars_of_typ !typ_ret) universals) then !typ_ret else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, !typ_ret)) in @@ -4763,357 +4966,356 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, _)) as mpat) typ = let typ, env = bind_existential l None typ env in - typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_mpat mpat ^ " to " ^ string_of_typ typ)); + typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_mpat mpat ^ " to " ^ string_of_typ typ)); let annot_mpat mpat typ' = MP_aux (mpat, (l, mk_expected_tannot env typ' (Some typ))) in - let switch_typ mpat typ = match mpat with - | MP_aux (pat_aux, (l, (Some tannot, uannot))) -> MP_aux (pat_aux, (l, (Some { tannot with typ = typ }, uannot))) + let switch_typ mpat typ = + match mpat with + | MP_aux (pat_aux, (l, (Some tannot, uannot))) -> MP_aux (pat_aux, (l, (Some { tannot with typ }, uannot))) | _ -> typ_error env l "Cannot switch type for unannotated mapping-pattern" in let bind_tuple_mpat (tpats, env, guards) mpat typ = - let tpat, env, guards' = bind_mpat allow_unknown other_env env mpat typ in tpat :: tpats, env, guards' @ guards + let tpat, env, guards' = bind_mpat allow_unknown other_env env mpat typ in + (tpat :: tpats, env, guards' @ guards) in match mpat_aux with - | MP_id v -> - begin - (* If the identifier we're matching on is also a constructor of - a union, that's probably a mistake, so warn about it. *) - if Env.is_union_constructor v env then - Reporting.warn (Printf.sprintf "Identifier %s found in mapping-pattern is also a union constructor at" - (string_of_id v)) - l "" - else (); - match Env.lookup_id v env with - | Local (Immutable, _) | Unbound _ -> annot_mpat (MP_id v) typ, Env.add_local v (Immutable, typ) env, [] - | Local (Mutable, _) | Register _ -> - typ_error env l ("Cannot shadow mutable local or register in switch statement mapping-pattern " ^ string_of_mpat mpat) - | Enum enum -> subtyp l env enum typ; annot_mpat (MP_id v) typ, env, [] - end - | MP_cons (hd_mpat, tl_mpat) -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> + | MP_id v -> begin + (* If the identifier we're matching on is also a constructor of + a union, that's probably a mistake, so warn about it. *) + if Env.is_union_constructor v env then + Reporting.warn + (Printf.sprintf "Identifier %s found in mapping-pattern is also a union constructor at" (string_of_id v)) + l "" + else (); + match Env.lookup_id v env with + | Local (Immutable, _) | Unbound _ -> (annot_mpat (MP_id v) typ, Env.add_local v (Immutable, typ) env, []) + | Local (Mutable, _) | Register _ -> + typ_error env l + ("Cannot shadow mutable local or register in switch statement mapping-pattern " ^ string_of_mpat mpat) + | Enum enum -> + subtyp l env enum typ; + (annot_mpat (MP_id v) typ, env, []) + end + | MP_cons (hd_mpat, tl_mpat) -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let hd_mpat, env, hd_guards = bind_mpat allow_unknown other_env env hd_mpat ltyp in let tl_mpat, env, tl_guards = bind_mpat allow_unknown other_env env tl_mpat typ in - annot_mpat (MP_cons (hd_mpat, tl_mpat)) typ, env, hd_guards @ tl_guards - | _ -> typ_error env l "Cannot match cons mapping-pattern against non-list type" - end - | MP_string_append mpats -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_id id, _) when Id.compare id (mk_id "string") = 0 -> + (annot_mpat (MP_cons (hd_mpat, tl_mpat)) typ, env, hd_guards @ tl_guards) + | _ -> typ_error env l "Cannot match cons mapping-pattern against non-list type" + end + | MP_string_append mpats -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_id id, _) when Id.compare id (mk_id "string") = 0 -> let rec process_mpats env = function - | [] -> [], env, [] + | [] -> ([], env, []) | pat :: pats -> - let pat', env, guards = bind_mpat allow_unknown other_env env pat typ in - let pats', env, guards' = process_mpats env pats in - pat' :: pats', env, guards @ guards' + let pat', env, guards = bind_mpat allow_unknown other_env env pat typ in + let pats', env, guards' = process_mpats env pats in + (pat' :: pats', env, guards @ guards') in let pats, env, guards = process_mpats env mpats in - annot_mpat (MP_string_append pats) typ, env, guards - | _ -> typ_error env l "Cannot match string-append pattern against non-string type" - end - | MP_list mpats -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> + (annot_mpat (MP_string_append pats) typ, env, guards) + | _ -> typ_error env l "Cannot match string-append pattern against non-string type" + end + | MP_list mpats -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_app (f, [A_aux (A_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let rec process_mpats env = function - | [] -> [], env, [] - | (_ :: mpats) -> - let mpat', env, guards = bind_mpat allow_unknown other_env env mpat ltyp in - let mpats', env, guards' = process_mpats env mpats in - mpat' :: mpats', env, guards @ guards' + | [] -> ([], env, []) + | _ :: mpats -> + let mpat', env, guards = bind_mpat allow_unknown other_env env mpat ltyp in + let mpats', env, guards' = process_mpats env mpats in + (mpat' :: mpats', env, guards @ guards') in let mpats, env, guards = process_mpats env mpats in - annot_mpat (MP_list mpats) typ, env, guards - | _ -> typ_error env l ("Cannot match list mapping-pattern " ^ string_of_mpat mpat ^ " against non-list type " ^ string_of_typ typ) - end - | MP_tuple [] -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> - annot_mpat (MP_tuple []) typ, env, [] - | _ -> typ_error env l "Cannot match unit mapping-pattern against non-unit type" - end - | MP_tuple mpats -> - begin - match Env.expand_synonyms env typ with - | Typ_aux (Typ_tuple typs, _) -> + (annot_mpat (MP_list mpats) typ, env, guards) + | _ -> + typ_error env l + ("Cannot match list mapping-pattern " ^ string_of_mpat mpat ^ " against non-list type " ^ string_of_typ typ) + end + | MP_tuple [] -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> (annot_mpat (MP_tuple []) typ, env, []) + | _ -> typ_error env l "Cannot match unit mapping-pattern against non-unit type" + end + | MP_tuple mpats -> begin + match Env.expand_synonyms env typ with + | Typ_aux (Typ_tuple typs, _) -> let tpats, env, guards = - try List.fold_left2 bind_tuple_mpat ([], env, []) mpats typs with - | Invalid_argument _ -> typ_error env l "Tuple mapping-pattern and tuple type have different length" + try List.fold_left2 bind_tuple_mpat ([], env, []) mpats typs + with Invalid_argument _ -> typ_error env l "Tuple mapping-pattern and tuple type have different length" in - annot_mpat (MP_tuple (List.rev tpats)) typ, env, guards - | _ -> typ_error env l "Cannot bind tuple mapping-pattern against non tuple type" - end - | MP_app (f, mpats) when Env.is_union_constructor f env -> - begin - let (typq, ctor_typ) = Env.get_val_spec f env in - let quants = quant_items typq in - let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with - | Typ_tuple typs -> typs - | _ -> [typ] - in - match Env.expand_synonyms env ctor_typ with - | Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> - begin - try - typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); - let unifiers = unify l env (tyvars_of_typ ret_typ) ret_typ typ in - let arg_typ' = subst_unifiers unifiers arg_typ in - let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in - if (match quants' with [] -> false | _ -> true) - then typ_error env l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat) - else (); - let _ret_typ' = subst_unifiers unifiers ret_typ in - let tpats, env, guards = - try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ') with - | Invalid_argument _ -> typ_error env l "Union constructor mapping-pattern arguments have incorrect length" - in - annot_mpat (MP_app (f, List.rev tpats)) typ, env, guards - with - | Unification_error (l, m) -> typ_error env l ("Unification error when mapping-pattern matching against union constructor: " ^ m) - end - | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) - end - | MP_app (other, mpats) when Env.is_mapping other env -> - begin - let (typq, mapping_typ) = Env.get_val_spec other env in - let quants = quant_items typq in - let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with - | Typ_tuple typs -> typs - | _ -> [typ] - in - match Env.expand_synonyms env mapping_typ with - | Typ_aux (Typ_bidir (typ1, typ2), _) -> - begin + (annot_mpat (MP_tuple (List.rev tpats)) typ, env, guards) + | _ -> typ_error env l "Cannot bind tuple mapping-pattern against non tuple type" + end + | MP_app (f, mpats) when Env.is_union_constructor f env -> begin + let typq, ctor_typ = Env.get_val_spec f env in + let quants = quant_items typq in + let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with Typ_tuple typs -> typs | _ -> [typ] in + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> begin + try + typ_debug + (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); + let unifiers = unify l env (tyvars_of_typ ret_typ) ret_typ typ in + let arg_typ' = subst_unifiers unifiers arg_typ in + let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in + if match quants' with [] -> false | _ -> true then + typ_error env l + ("Quantifiers " + ^ string_of_list ", " string_of_quant_item quants' + ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat + ) + else (); + let _ret_typ' = subst_unifiers unifiers ret_typ in + let tpats, env, guards = + try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ') + with Invalid_argument _ -> + typ_error env l "Union constructor mapping-pattern arguments have incorrect length" + in + (annot_mpat (MP_app (f, List.rev tpats)) typ, env, guards) + with Unification_error (l, m) -> + typ_error env l ("Unification error when mapping-pattern matching against union constructor: " ^ m) + end + | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + end + | MP_app (other, mpats) when Env.is_mapping other env -> begin + let typq, mapping_typ = Env.get_val_spec other env in + let quants = quant_items typq in + let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with Typ_tuple typs -> typs | _ -> [typ] in + match Env.expand_synonyms env mapping_typ with + | Typ_aux (Typ_bidir (typ1, typ2), _) -> begin + try + typ_debug + (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); + let unifiers = unify l env (tyvars_of_typ typ2) typ2 typ in + let arg_typ' = subst_unifiers unifiers typ1 in + let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in + if match quants' with [] -> false | _ -> true then + typ_error env l + ("Quantifiers " + ^ string_of_list ", " string_of_quant_item quants' + ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat + ) + else (); + let _ret_typ' = subst_unifiers unifiers typ2 in + let tpats, env, guards = + try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ') + with Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" + in + (annot_mpat (MP_app (other, List.rev tpats)) typ, env, guards) + with Unification_error (l, _) -> ( try - typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); - let unifiers = unify l env (tyvars_of_typ typ2) typ2 typ in - let arg_typ' = subst_unifiers unifiers typ1 in + typ_debug (lazy "Unifying mapping forwards failed, trying backwards."); + typ_debug + (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); + let unifiers = unify l env (tyvars_of_typ typ1) typ1 typ in + let arg_typ' = subst_unifiers unifiers typ2 in let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in - if (match quants' with [] -> false | _ -> true) - then typ_error env l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat) + if match quants' with [] -> false | _ -> true then + typ_error env l + ("Quantifiers " + ^ string_of_list ", " string_of_quant_item quants' + ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat + ) else (); - let _ret_typ' = subst_unifiers unifiers typ2 in + let _ret_typ' = subst_unifiers unifiers typ1 in let tpats, env, guards = - try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ') with - | Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" + try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ') + with Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" in - annot_mpat (MP_app (other, List.rev tpats)) typ, env, guards - with - | Unification_error (l, _) -> - try - typ_debug (lazy "Unifying mapping forwards failed, trying backwards."); - typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); - let unifiers = unify l env (tyvars_of_typ typ1) typ1 typ in - let arg_typ' = subst_unifiers unifiers typ2 in - let quants' = List.fold_left (instantiate_quants env) quants (KBindings.bindings unifiers) in - if (match quants' with [] -> false | _ -> true) - then typ_error env l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat) - else (); - let _ret_typ' = subst_unifiers unifiers typ1 in - let tpats, env, guards = - try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ') with - | Invalid_argument _ -> typ_error env l "Mapping pattern arguments have incorrect length" - in - annot_mpat (MP_app (other, List.rev tpats)) typ, env, guards - with - | Unification_error (l, m) -> typ_error env l ("Unification error when pattern matching against mapping constructor: " ^ m) - end - | _ -> - typ_error env l ("unifying mapping type, expanded synonyms to non-mapping type??") - end + (annot_mpat (MP_app (other, List.rev tpats)) typ, env, guards) + with Unification_error (l, m) -> + typ_error env l ("Unification error when pattern matching against mapping constructor: " ^ m) + ) + end + | _ -> typ_error env l "unifying mapping type, expanded synonyms to non-mapping type??" + end | MP_app (f, _) when not (Env.is_union_constructor f env || Env.is_mapping f env) -> - typ_error env l (string_of_id f ^ " is not a union constructor or mapping in mapping-pattern " ^ string_of_mpat mpat) + typ_error env l + (string_of_id f ^ " is not a union constructor or mapping in mapping-pattern " ^ string_of_mpat mpat) | MP_as (mpat, id) -> - let (typed_mpat, env, guards) = bind_mpat allow_unknown other_env env mpat typ in - (annot_mpat (MP_as (typed_mpat, id)) (typ_of_mpat typed_mpat), - Env.add_local id (Immutable, typ_of_mpat typed_mpat) env, - guards) + let typed_mpat, env, guards = bind_mpat allow_unknown other_env env mpat typ in + ( annot_mpat (MP_as (typed_mpat, id)) (typ_of_mpat typed_mpat), + Env.add_local id (Immutable, typ_of_mpat typed_mpat) env, + guards + ) (* This is a special case for flow typing when we match a constant numeric literal. *) | MP_lit (L_aux (L_num n, _) as lit) when is_atom typ -> - let nexp = match destruct_atom_nexp env typ with Some n -> n | None -> assert false in - annot_mpat (MP_lit lit) (atom_typ (nconstant n)), Env.add_constraint (nc_eq nexp (nconstant n)) env, [] + let nexp = match destruct_atom_nexp env typ with Some n -> n | None -> assert false in + (annot_mpat (MP_lit lit) (atom_typ (nconstant n)), Env.add_constraint (nc_eq nexp (nconstant n)) env, []) (* Similarly, for boolean literals *) | MP_lit (L_aux (L_true, _) as lit) when is_atom_bool typ -> - let nc = match destruct_atom_bool env typ with Some n -> n | None -> assert false in - annot_mpat (MP_lit lit) (atom_bool_typ nc_true), Env.add_constraint nc env, [] + let nc = match destruct_atom_bool env typ with Some n -> n | None -> assert false in + (annot_mpat (MP_lit lit) (atom_bool_typ nc_true), Env.add_constraint nc env, []) | MP_lit (L_aux (L_false, _) as lit) when is_atom_bool typ -> - let nc = match destruct_atom_bool env typ with Some n -> n | None -> assert false in - annot_mpat (MP_lit lit) (atom_bool_typ nc_false), Env.add_constraint (nc_not nc) env, [] - | _ -> - let (inferred_mpat, env, guards) = infer_mpat allow_unknown other_env env mpat in - match subtyp l env typ (typ_of_mpat inferred_mpat) with - | () -> switch_typ inferred_mpat (typ_of_mpat inferred_mpat), env, guards - | exception (Type_error _ as typ_exn) -> - match mpat_aux with - | MP_lit lit -> - let var = fresh_var () in - let guard = mk_exp (E_app_infix (mk_exp (E_id var), mk_id "==", mk_exp (E_lit lit))) in - let (typed_mpat, env, guards) = bind_mpat allow_unknown other_env env (mk_mpat (MP_id var)) typ in - typed_mpat, env, guard::guards - | _ -> raise typ_exn + let nc = match destruct_atom_bool env typ with Some n -> n | None -> assert false in + (annot_mpat (MP_lit lit) (atom_bool_typ nc_false), Env.add_constraint (nc_not nc) env, []) + | _ -> ( + let inferred_mpat, env, guards = infer_mpat allow_unknown other_env env mpat in + match subtyp l env typ (typ_of_mpat inferred_mpat) with + | () -> (switch_typ inferred_mpat (typ_of_mpat inferred_mpat), env, guards) + | exception (Type_error _ as typ_exn) -> ( + match mpat_aux with + | MP_lit lit -> + let var = fresh_var () in + let guard = mk_exp (E_app_infix (mk_exp (E_id var), mk_id "==", mk_exp (E_lit lit))) in + let typed_mpat, env, guards = bind_mpat allow_unknown other_env env (mk_mpat (MP_id var)) typ in + (typed_mpat, env, guard :: guards) + | _ -> raise typ_exn + ) + ) and infer_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, _)) as mpat) = let annot_mpat mpat typ = MP_aux (mpat, (l, mk_tannot env typ)) in match mpat_aux with - | MP_id v -> - begin - match Env.lookup_id v env with - | Local (Immutable, _) | Unbound _ -> - begin match Env.lookup_id v other_env with - | Local (Immutable, typ) -> bind_mpat allow_unknown other_env env (mk_mpat (MP_typ (mk_mpat (MP_id v), typ))) typ + | MP_id v -> begin + match Env.lookup_id v env with + | Local (Immutable, _) | Unbound _ -> begin + match Env.lookup_id v other_env with + | Local (Immutable, typ) -> + bind_mpat allow_unknown other_env env (mk_mpat (MP_typ (mk_mpat (MP_id v), typ))) typ | Unbound _ -> - if allow_unknown then annot_mpat (MP_id v) unknown_typ, env, [] else - typ_error env l ("Cannot infer identifier in mapping-pattern " ^ string_of_mpat mpat ^ " - try adding a type annotation") + if allow_unknown then (annot_mpat (MP_id v) unknown_typ, env, []) + else + typ_error env l + ("Cannot infer identifier in mapping-pattern " ^ string_of_mpat mpat + ^ " - try adding a type annotation" + ) | _ -> assert false - end - | Local (Mutable, _) | Register _ -> + end + | Local (Mutable, _) | Register _ -> typ_error env l ("Cannot shadow mutable local or register in mapping-pattern " ^ string_of_mpat mpat) - | Enum enum -> annot_mpat (MP_id v) enum, env, [] - end + | Enum enum -> (annot_mpat (MP_id v) enum, env, []) + end | MP_vector_subrange (id, n, m) -> - let len, order = match Env.get_default_order env with - | Ord_aux (Ord_dec, _) -> - if Big_int.greater_equal n m then - Big_int.sub (Big_int.succ n) m, dec_ord - else - typ_error env l (Printf.sprintf "%s must be greater than or equal to %s" (Big_int.to_string n) (Big_int.to_string m)) - | Ord_aux (Ord_inc, _) -> - if Big_int.less_equal n m then - Big_int.sub (Big_int.succ m) n, inc_ord - else - typ_error env l (Printf.sprintf "%s must be less than or equal to %s" (Big_int.to_string n) (Big_int.to_string m)) - | _ -> - typ_error env l default_order_error_string - in - begin - match Env.lookup_id id env with - | Local (Immutable, _) | Unbound _ -> - begin match Env.lookup_id id other_env with - | Unbound _ -> - if allow_unknown then - annot_mpat (MP_vector_subrange (id, n, m)) (bitvector_typ (nconstant len) order), env, [] - else - typ_error env l "Cannot infer identifier type in vector subrange pattern" - | Local (Immutable, other_typ) -> - let (id_len, id_order) = destruct_bitvector_typ l env other_typ in - if is_order_inc id_order <> is_order_inc order then ( - typ_error env l "Mismatching bitvector ordering in vector subrange pattern %b %b" - ); - begin match id_len with - | Nexp_aux (Nexp_constant id_len, _) when Big_int.greater_equal id_len len -> - annot_mpat (MP_vector_subrange (id, n, m)) (bitvector_typ (nconstant len) order), env, [] - | _ -> - typ_error env l (Printf.sprintf "%s must have a constant length greater than or equal to %s" - (string_of_id id) (Big_int.to_string len)) - end - | _ -> - typ_error env l "Invalid identifier in vector subrange pattern" - end - | Local _ | Register _ -> - typ_error env l "Invalid identifier in vector subrange pattern" - | Enum e -> - typ_error env l (Printf.sprintf "Identifier %s is a member of enumeration %s in vector subrange pattern" - (string_of_id id) (string_of_typ e)) - end - | MP_app (f, _) when Env.is_union_constructor f env -> - begin - let (_, ctor_typ) = Env.get_val_spec f env in - match Env.expand_synonyms env ctor_typ with - | Typ_aux (Typ_fn (_, ret_typ), _) -> - bind_mpat allow_unknown other_env env mpat ret_typ - | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f) - end - | MP_app (f, _) when Env.is_mapping f env -> - begin - let (_, mapping_typ) = Env.get_val_spec f env in - match Env.expand_synonyms env mapping_typ with - | Typ_aux (Typ_bidir (typ1, typ2), _) -> - begin - try - bind_mpat allow_unknown other_env env mpat typ2 - with - | Type_error _ -> - bind_mpat allow_unknown other_env env mpat typ1 + let len, order = + match Env.get_default_order env with + | Ord_aux (Ord_dec, _) -> + if Big_int.greater_equal n m then (Big_int.sub (Big_int.succ n) m, dec_ord) + else + typ_error env l + (Printf.sprintf "%s must be greater than or equal to %s" (Big_int.to_string n) (Big_int.to_string m)) + | Ord_aux (Ord_inc, _) -> + if Big_int.less_equal n m then (Big_int.sub (Big_int.succ m) n, inc_ord) + else + typ_error env l + (Printf.sprintf "%s must be less than or equal to %s" (Big_int.to_string n) (Big_int.to_string m)) + | _ -> typ_error env l default_order_error_string + in + begin + match Env.lookup_id id env with + | Local (Immutable, _) | Unbound _ -> begin + match Env.lookup_id id other_env with + | Unbound _ -> + if allow_unknown then + (annot_mpat (MP_vector_subrange (id, n, m)) (bitvector_typ (nconstant len) order), env, []) + else typ_error env l "Cannot infer identifier type in vector subrange pattern" + | Local (Immutable, other_typ) -> + let id_len, id_order = destruct_bitvector_typ l env other_typ in + if is_order_inc id_order <> is_order_inc order then + typ_error env l "Mismatching bitvector ordering in vector subrange pattern %b %b"; + begin + match id_len with + | Nexp_aux (Nexp_constant id_len, _) when Big_int.greater_equal id_len len -> + (annot_mpat (MP_vector_subrange (id, n, m)) (bitvector_typ (nconstant len) order), env, []) + | _ -> + typ_error env l + (Printf.sprintf "%s must have a constant length greater than or equal to %s" (string_of_id id) + (Big_int.to_string len) + ) + end + | _ -> typ_error env l "Invalid identifier in vector subrange pattern" end - | _ -> typ_error env l ("Malformed mapping type " ^ string_of_id f) - end - | MP_lit lit -> - annot_mpat (MP_lit lit) (infer_lit env lit), env, [] + | Local _ | Register _ -> typ_error env l "Invalid identifier in vector subrange pattern" + | Enum e -> + typ_error env l + (Printf.sprintf "Identifier %s is a member of enumeration %s in vector subrange pattern" (string_of_id id) + (string_of_typ e) + ) + end + | MP_app (f, _) when Env.is_union_constructor f env -> begin + let _, ctor_typ = Env.get_val_spec f env in + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn (_, ret_typ), _) -> bind_mpat allow_unknown other_env env mpat ret_typ + | _ -> typ_error env l ("Mal-formed constructor " ^ string_of_id f) + end + | MP_app (f, _) when Env.is_mapping f env -> begin + let _, mapping_typ = Env.get_val_spec f env in + match Env.expand_synonyms env mapping_typ with + | Typ_aux (Typ_bidir (typ1, typ2), _) -> begin + try bind_mpat allow_unknown other_env env mpat typ2 + with Type_error _ -> bind_mpat allow_unknown other_env env mpat typ1 + end + | _ -> typ_error env l ("Malformed mapping type " ^ string_of_id f) + end + | MP_lit lit -> (annot_mpat (MP_lit lit) (infer_lit env lit), env, []) | MP_typ (mpat, typ_annot) -> - Env.wf_typ env typ_annot; - let (typed_mpat, env, guards) = bind_mpat allow_unknown other_env env mpat typ_annot in - annot_mpat (MP_typ (typed_mpat, typ_annot)) typ_annot, env, guards + Env.wf_typ env typ_annot; + let typed_mpat, env, guards = bind_mpat allow_unknown other_env env mpat typ_annot in + (annot_mpat (MP_typ (typed_mpat, typ_annot)) typ_annot, env, guards) | MP_vector (mpat :: mpats) -> - let fold_mpats (mpats, env, guards) mpat = - let typed_mpat, env, guards' = bind_mpat allow_unknown other_env env mpat bit_typ in - mpats @ [typed_mpat], env, guards' @ guards - in - let mpats, env, guards = List.fold_left fold_mpats ([], env, []) (mpat :: mpats) in - let len = nexp_simp (nint (List.length mpats)) in - let etyp = typ_of_mpat (List.hd mpats) in - List.iter (fun mpat -> typ_equality l env etyp (typ_of_mpat mpat)) mpats; - annot_mpat (MP_vector mpats) (dvector_typ env len etyp), env, guards + let fold_mpats (mpats, env, guards) mpat = + let typed_mpat, env, guards' = bind_mpat allow_unknown other_env env mpat bit_typ in + (mpats @ [typed_mpat], env, guards' @ guards) + in + let mpats, env, guards = List.fold_left fold_mpats ([], env, []) (mpat :: mpats) in + let len = nexp_simp (nint (List.length mpats)) in + let etyp = typ_of_mpat (List.hd mpats) in + List.iter (fun mpat -> typ_equality l env etyp (typ_of_mpat mpat)) mpats; + (annot_mpat (MP_vector mpats) (dvector_typ env len etyp), env, guards) | MP_vector_concat (mpat :: mpats) -> - let fold_mpats (mpats, env, guards) mpat = - let inferred_mpat, env, guards' = infer_mpat allow_unknown other_env env mpat in - mpats @ [inferred_mpat], env, guards' @ guards - in - let inferred_mpats, env, guards = - List.fold_left fold_mpats ([], env, []) (mpat :: mpats) in - if allow_unknown && List.exists (fun mpat -> is_unknown_type (typ_of_mpat mpat)) inferred_mpats then - annot_mpat (MP_vector_concat inferred_mpats) unknown_typ, env, guards (* hack *) - else - begin match destruct_any_vector_typ l env (typ_of_mpat (List.hd inferred_mpats)) with - | Destruct_vector (len, _, vtyp) -> - let fold_len len mpat = - let (len', _, vtyp') = destruct_vector_typ l env (typ_of_mpat mpat) in - typ_equality l env vtyp vtyp'; - nsum len len' - in - let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in - annot_mpat (MP_vector_concat inferred_mpats) (dvector_typ env len vtyp), env, guards - | Destruct_bitvector (len, _) -> - let fold_len len mpat = - let (len', _) = destruct_bitvector_typ l env (typ_of_mpat mpat) in - nsum len len' - in - let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in - annot_mpat (MP_vector_concat inferred_mpats) (bits_typ env len), env, guards - end + let fold_mpats (mpats, env, guards) mpat = + let inferred_mpat, env, guards' = infer_mpat allow_unknown other_env env mpat in + (mpats @ [inferred_mpat], env, guards' @ guards) + in + let inferred_mpats, env, guards = List.fold_left fold_mpats ([], env, []) (mpat :: mpats) in + if allow_unknown && List.exists (fun mpat -> is_unknown_type (typ_of_mpat mpat)) inferred_mpats then + (annot_mpat (MP_vector_concat inferred_mpats) unknown_typ, env, guards (* hack *)) + else begin + match destruct_any_vector_typ l env (typ_of_mpat (List.hd inferred_mpats)) with + | Destruct_vector (len, _, vtyp) -> + let fold_len len mpat = + let len', _, vtyp' = destruct_vector_typ l env (typ_of_mpat mpat) in + typ_equality l env vtyp vtyp'; + nsum len len' + in + let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in + (annot_mpat (MP_vector_concat inferred_mpats) (dvector_typ env len vtyp), env, guards) + | Destruct_bitvector (len, _) -> + let fold_len len mpat = + let len', _ = destruct_bitvector_typ l env (typ_of_mpat mpat) in + nsum len len' + in + let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in + (annot_mpat (MP_vector_concat inferred_mpats) (bits_typ env len), env, guards) + end | MP_string_append mpats -> - let fold_pats (pats, env, guards) pat = - let inferred_pat, env, guards' = infer_mpat allow_unknown other_env env pat in - typ_equality l env (typ_of_mpat inferred_pat) string_typ; - pats @ [inferred_pat], env, guards' @ guards - in - let typed_mpats, env, guards = - List.fold_left fold_pats ([], env, []) mpats - in - annot_mpat (MP_string_append typed_mpats) string_typ, env, guards + let fold_pats (pats, env, guards) pat = + let inferred_pat, env, guards' = infer_mpat allow_unknown other_env env pat in + typ_equality l env (typ_of_mpat inferred_pat) string_typ; + (pats @ [inferred_pat], env, guards' @ guards) + in + let typed_mpats, env, guards = List.fold_left fold_pats ([], env, []) mpats in + (annot_mpat (MP_string_append typed_mpats) string_typ, env, guards) | MP_as (mpat, id) -> - let (typed_mpat, env, guards) = infer_mpat allow_unknown other_env env mpat in - (annot_mpat (MP_as (typed_mpat, id)) (typ_of_mpat typed_mpat), - Env.add_local id (Immutable, typ_of_mpat typed_mpat) env, - guards) - - | _ -> - typ_error env l ("Couldn't infer type of mapping-pattern " ^ string_of_mpat mpat) + let typed_mpat, env, guards = infer_mpat allow_unknown other_env env mpat in + ( annot_mpat (MP_as (typed_mpat, id)) (typ_of_mpat typed_mpat), + Env.add_local id (Immutable, typ_of_mpat typed_mpat) env, + guards + ) + | _ -> typ_error env l ("Couldn't infer type of mapping-pattern " ^ string_of_mpat mpat) (**************************************************************************) (* 6. Effect system *) (**************************************************************************) -let effect_of_annot = function -| (Some t, _) -> t.monadic -| (None, _) -> no_effect +let effect_of_annot = function Some t, _ -> t.monadic | None, _ -> no_effect let effect_of (E_aux (_, (_, annot))) = effect_of_annot annot -let add_effect_annot annot eff = match annot with - | (Some tannot, uannot) -> (Some { tannot with monadic = eff }, uannot) - | (None, uannot) -> (None, uannot) +let add_effect_annot annot eff = + match annot with Some tannot, uannot -> (Some { tannot with monadic = eff }, uannot) | None, uannot -> (None, uannot) let effect_of_pat (P_aux (_, (_, annot))) = effect_of_annot annot @@ -5123,41 +5325,39 @@ let effect_of_pat (P_aux (_, (_, annot))) = effect_of_annot annot let check_duplicate_letbinding l pat env = match IdSet.choose_opt (IdSet.inter (pat_ids pat) (Env.get_toplevel_lets env)) with - | Some id -> - typ_error env l ("Duplicate toplevel let binding " ^ string_of_id id) + | Some id -> typ_error env l ("Duplicate toplevel let binding " ^ string_of_id id) | None -> () let check_letdef orig_env def_annot (LB_aux (letbind, (l, _))) = typ_print (lazy ("\nChecking top-level let" |> cyan |> clear)); match letbind with - | LB_val (P_aux (P_typ (typ_annot, _), _) as pat, bind) -> - check_duplicate_letbinding l pat orig_env; - Env.wf_typ orig_env typ_annot; - let checked_bind = crule check_exp orig_env bind typ_annot in - let tpat, env = bind_pat_no_guard orig_env pat typ_annot in - [DEF_aux (DEF_let (LB_aux (LB_val (tpat, checked_bind), (l, empty_tannot))), def_annot)], - Env.add_toplevel_lets (pat_ids tpat) env - + | LB_val ((P_aux (P_typ (typ_annot, _), _) as pat), bind) -> + check_duplicate_letbinding l pat orig_env; + Env.wf_typ orig_env typ_annot; + let checked_bind = crule check_exp orig_env bind typ_annot in + let tpat, env = bind_pat_no_guard orig_env pat typ_annot in + ( [DEF_aux (DEF_let (LB_aux (LB_val (tpat, checked_bind), (l, empty_tannot))), def_annot)], + Env.add_toplevel_lets (pat_ids tpat) env + ) | LB_val (pat, bind) -> - check_duplicate_letbinding l pat orig_env; - let inferred_bind = irule infer_exp orig_env bind in - let tpat, env = bind_pat_no_guard orig_env pat (typ_of inferred_bind) in - [DEF_aux (DEF_let (LB_aux (LB_val (tpat, inferred_bind), (l, empty_tannot))), def_annot)], - Env.add_toplevel_lets (pat_ids tpat) env + check_duplicate_letbinding l pat orig_env; + let inferred_bind = irule infer_exp orig_env bind in + let tpat, env = bind_pat_no_guard orig_env pat (typ_of inferred_bind) in + ( [DEF_aux (DEF_let (LB_aux (LB_val (tpat, inferred_bind), (l, empty_tannot))), def_annot)], + Env.add_toplevel_lets (pat_ids tpat) env + ) let bind_funcl_arg_typ l env typ = match typ with - | Typ_aux (Typ_fn (typ_args, typ_ret), _) -> - begin - let env = Env.add_ret_typ typ_ret env in - match List.map implicit_to_int typ_args with - | [typ_arg] -> - typ_arg, typ_ret, env - | typ_args -> - (* This is one of the cases where we are allowed to treat - function arguments as like a tuple, normally we can't. *) - Typ_aux (Typ_tuple typ_args, l), typ_ret, env - end + | Typ_aux (Typ_fn (typ_args, typ_ret), _) -> begin + let env = Env.add_ret_typ typ_ret env in + match List.map implicit_to_int typ_args with + | [typ_arg] -> (typ_arg, typ_ret, env) + | typ_args -> + (* This is one of the cases where we are allowed to treat + function arguments as like a tuple, normally we can't. *) + (Typ_aux (Typ_tuple typ_args, l), typ_ret, env) + end | _ -> typ_error env l ("Function clause must have function type: " ^ string_of_typ typ ^ " is not a function type") let check_funcl env (FCL_aux (FCL_funcl (id, pexp), (def_annot, _))) typ = @@ -5170,15 +5370,14 @@ let check_funcl env (FCL_aux (FCL_funcl (id, pexp), (def_annot, _))) typ = re-write the polymorphic undefineds (due to the specific form the functions have *) let env = - if Str.string_match (Str.regexp_string "undefined_") (string_of_id id) 0 - then Env.allow_polymorphic_undefineds env + if Str.string_match (Str.regexp_string "undefined_") (string_of_id id) 0 then Env.allow_polymorphic_undefineds env else env in let typed_pexp = check_case env typ_arg pexp typ_ret in FCL_aux (FCL_funcl (id, typed_pexp), (def_annot, mk_expected_tannot env typ (Some typ))) let check_mapcl : Env.t -> uannot mapcl -> typ -> tannot mapcl = - fun env (MCL_aux (cl, (def_annot, _))) typ -> + fun env (MCL_aux (cl, (def_annot, _))) typ -> match typ with | Typ_aux (Typ_bidir (typ1, typ2), _) -> begin match cl with @@ -5208,57 +5407,69 @@ let check_mapcl : Env.t -> uannot mapcl -> typ -> tannot mapcl = MCL_aux (MCL_backwards (typed_mpexp, typed_exp), (def_annot, mk_expected_tannot env typ (Some typ))) end end - | _ -> typ_error env def_annot.loc ("Mapping clause must have mapping type: " ^ string_of_typ typ ^ " is not a mapping type") + | _ -> + typ_error env def_annot.loc + ("Mapping clause must have mapping type: " ^ string_of_typ typ ^ " is not a mapping type") let infer_funtyp l env tannotopt funcls = match tannotopt with - | Typ_annot_opt_aux (Typ_annot_opt_some (quant, ret_typ), _) -> - begin - let rec typ_from_pat (P_aux (pat_aux, (l, _)) as pat) = - match pat_aux with - | P_lit lit -> infer_lit env lit - | P_typ (typ, _) -> typ - | P_tuple pats -> mk_typ (Typ_tuple (List.map typ_from_pat pats)) - | _ -> typ_error env l ("Cannot infer type from pattern " ^ string_of_pat pat) - in - match funcls with - | [FCL_aux (FCL_funcl (_, Pat_aux (pexp,_)), _)] -> - let pat = match pexp with Pat_exp (pat,_) | Pat_when (pat,_,_) -> pat in + | Typ_annot_opt_aux (Typ_annot_opt_some (quant, ret_typ), _) -> begin + let rec typ_from_pat (P_aux (pat_aux, (l, _)) as pat) = + match pat_aux with + | P_lit lit -> infer_lit env lit + | P_typ (typ, _) -> typ + | P_tuple pats -> mk_typ (Typ_tuple (List.map typ_from_pat pats)) + | _ -> typ_error env l ("Cannot infer type from pattern " ^ string_of_pat pat) + in + match funcls with + | [FCL_aux (FCL_funcl (_, Pat_aux (pexp, _)), _)] -> + let pat = match pexp with Pat_exp (pat, _) | Pat_when (pat, _, _) -> pat in (* The function syntax lets us bind multiple function arguments with a single pattern, hence why we need to do this. But perhaps we don't want to allow this? *) let arg_typs = - match typ_from_pat pat with - | Typ_aux (Typ_tuple arg_typs, _) -> arg_typs - | arg_typ -> [arg_typ] + match typ_from_pat pat with Typ_aux (Typ_tuple arg_typs, _) -> arg_typs | arg_typ -> [arg_typ] in let fn_typ = mk_typ (Typ_fn (arg_typs, ret_typ)) in (quant, fn_typ) - | _ -> typ_error env l "Cannot infer function type for function with multiple clauses" - end + | _ -> typ_error env l "Cannot infer function type for function with multiple clauses" + end | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> typ_error env l "Cannot infer function type for unannotated function" (* This is used for functions and mappings that do not have an explicit type signature using val *) let synthesize_val_spec env typq typ id = - mk_def (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown), id, None, false), (Parse_ast.Unknown, mk_tannot env typ)))) + mk_def + (DEF_val + (VS_aux + ( VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown), id, None, false), + (Parse_ast.Unknown, mk_tannot env typ) + ) + ) + ) let check_tannotopt env typq ret_typ = function | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> () | Typ_annot_opt_aux (Typ_annot_opt_some (annot_typq, annot_ret_typ), l) -> - if expanded_typ_identical env ret_typ annot_ret_typ - then () - else typ_error env l (string_of_bind (typq, ret_typ) ^ " and " ^ string_of_bind (annot_typq, annot_ret_typ) ^ " do not match between function and val spec") + if expanded_typ_identical env ret_typ annot_ret_typ then () + else + typ_error env l + (string_of_bind (typq, ret_typ) + ^ " and " + ^ string_of_bind (annot_typq, annot_ret_typ) + ^ " do not match between function and val spec" + ) let check_termination_measure env arg_typs pat exp = - let typ = match arg_typs with [x] -> x | _ -> Typ_aux (Typ_tuple arg_typs,Unknown) in + let typ = match arg_typs with [x] -> x | _ -> Typ_aux (Typ_tuple arg_typs, Unknown) in let tpat, env = bind_pat_no_guard env pat typ in let texp = check_exp env exp int_typ in - tpat, texp + (tpat, texp) let check_termination_measure_decl env def_annot (id, pat, exp) = let quant, typ = Env.get_val_spec id env in - let arg_typs, l = match typ with - | Typ_aux (Typ_fn (arg_typs, _), l) -> arg_typs, l + let arg_typs, l = + match typ with + | Typ_aux (Typ_fn (arg_typs, _), l) -> (arg_typs, l) | _ -> typ_error env (id_loc id) "Function val spec is not a function type" in let env = Env.add_typquant l quant env in @@ -5269,32 +5480,39 @@ let check_funcls_complete l env funcls typ = let typ_arg, _, env = bind_funcl_arg_typ l env typ in let ctx = pattern_completeness_ctx env in match PC.is_complete_funcls_wildcarded ~keyword:"function" l ctx funcls typ_arg with - | Some funcls -> funcls, add_def_attribute (gen_loc l) "complete" "" - | None -> funcls, add_def_attribute (gen_loc l) "incomplete" "" + | Some funcls -> (funcls, add_def_attribute (gen_loc l) "complete" "") + | None -> (funcls, add_def_attribute (gen_loc l) "incomplete" "") let check_fundef env def_annot (FD_aux (FD_function (recopt, tannotopt, funcls), (l, _))) = let id = - match (List.fold_right - (fun (FCL_aux (FCL_funcl (id, _), _)) id' -> - match id' with - | Some id' -> if string_of_id id' = string_of_id id then Some id' - else typ_error env l ("Function declaration expects all definitions to have the same name, " - ^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id') - | None -> Some id) funcls None) + match + List.fold_right + (fun (FCL_aux (FCL_funcl (id, _), _)) id' -> + match id' with + | Some id' -> + if string_of_id id' = string_of_id id then Some id' + else + typ_error env l + ("Function declaration expects all definitions to have the same name, " ^ string_of_id id + ^ " differs from other definitions of " ^ string_of_id id' + ) + | None -> Some id + ) + funcls None with | Some id -> id | None -> typ_error env l "funcl list is empty" in typ_print (lazy ("\n" ^ Util.("Check function " |> cyan |> clear) ^ string_of_id id)); let have_val_spec, (quant, typ), env = - try true, Env.get_val_spec id env, env with - | Type_error (_, l, _) -> - let (quant, typ) = infer_funtyp l env tannotopt funcls in - false, (quant, typ), env + try (true, Env.get_val_spec id env, env) + with Type_error (_, l, _) -> + let quant, typ = infer_funtyp l env tannotopt funcls in + (false, (quant, typ), env) in - let vtyp_args, vtyp_ret, vl = match typ with - | Typ_aux (Typ_fn (vtyp_args, vtyp_ret), vl) -> - vtyp_args, vtyp_ret, vl + let vtyp_args, vtyp_ret, vl = + match typ with + | Typ_aux (Typ_fn (vtyp_args, vtyp_ret), vl) -> (vtyp_args, vtyp_ret, vl) | _ -> typ_error env l "Function val spec is not a function type" in check_tannotopt env quant vtyp_ret tannotopt; @@ -5305,441 +5523,467 @@ let check_fundef env def_annot (FD_aux (FD_function (recopt, tannotopt, funcls), | Rec_aux (Rec_nonrec, l) -> Rec_aux (Rec_nonrec, l) | Rec_aux (Rec_rec, l) -> Rec_aux (Rec_rec, l) | Rec_aux (Rec_measure (measure_p, measure_e), l) -> - let tpat, texp = check_termination_measure funcl_env vtyp_args measure_p measure_e in - Rec_aux (Rec_measure (tpat, texp), l) + let tpat, texp = check_termination_measure funcl_env vtyp_args measure_p measure_e in + Rec_aux (Rec_measure (tpat, texp), l) in let vs_def, env = - if not have_val_spec then + if not have_val_spec then ( let typ = Typ_aux (Typ_fn (vtyp_args, vtyp_ret), vl) in - [synthesize_val_spec env quant typ id], Env.add_val_spec id (quant, typ) env - else - [], env + ([synthesize_val_spec env quant typ id], Env.add_val_spec id (quant, typ) env) + ) + else ([], env) in let funcls = List.map (fun funcl -> check_funcl funcl_env funcl typ) funcls in let funcls, update_attr = - if Option.is_some (get_def_attribute "complete" def_annot) || Option.is_some (get_def_attribute "incomplete" def_annot) then ( - funcls, (fun attrs -> attrs) - ) else ( - check_funcls_complete l funcl_env funcls typ - ) in + if + Option.is_some (get_def_attribute "complete" def_annot) + || Option.is_some (get_def_attribute "incomplete" def_annot) + then (funcls, fun attrs -> attrs) + else check_funcls_complete l funcl_env funcls typ + in let env = Env.define_val_spec id env in - vs_def @ [DEF_aux (DEF_fundef (FD_aux (FD_function (recopt, tannotopt, funcls), (l, empty_tannot))), update_attr def_annot)], - env + ( vs_def + @ [ + DEF_aux (DEF_fundef (FD_aux (FD_function (recopt, tannotopt, funcls), (l, empty_tannot))), update_attr def_annot); + ], + env + ) let check_mapdef env def_annot (MD_aux (MD_mapping (id, tannot_opt, mapcls), (l, _))) = typ_print (lazy ("\nChecking mapping " ^ string_of_id id)); let have_val_spec, (quant, typ), env = - try true, Env.get_val_spec id env, env with - | Type_error (_, _, _) as err -> - match tannot_opt with - | Typ_annot_opt_aux (Typ_annot_opt_some (quant, typ), _) -> - false, (quant, typ), env - | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> - raise err + try (true, Env.get_val_spec id env, env) + with Type_error (_, _, _) as err -> ( + match tannot_opt with + | Typ_annot_opt_aux (Typ_annot_opt_some (quant, typ), _) -> (false, (quant, typ), env) + | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> raise err + ) in - begin match typ with + begin + match typ with | Typ_aux (Typ_bidir (_, _), _) -> () | _ -> typ_error env l "Mapping val spec was not a mapping type" end; - begin match tannot_opt with - | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> () - | Typ_annot_opt_aux (Typ_annot_opt_some (annot_typq, annot_typ), l) -> - if expanded_typ_identical env typ annot_typ then () - else typ_error env l (string_of_bind (quant, typ) ^ " and " ^ string_of_bind (annot_typq, annot_typ) ^ " do not match between mapping and val spec") + begin + match tannot_opt with + | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> () + | Typ_annot_opt_aux (Typ_annot_opt_some (annot_typq, annot_typ), l) -> + if expanded_typ_identical env typ annot_typ then () + else + typ_error env l + (string_of_bind (quant, typ) + ^ " and " + ^ string_of_bind (annot_typq, annot_typ) + ^ " do not match between mapping and val spec" + ) end; typ_debug (lazy ("Checking mapdef " ^ string_of_id id ^ " has type " ^ string_of_bind (quant, typ))); let vs_def, env = if not have_val_spec then - [synthesize_val_spec env quant (Env.expand_synonyms env typ) id], Env.add_val_spec id (quant, typ) env - else - [], env + ([synthesize_val_spec env quant (Env.expand_synonyms env typ) id], Env.add_val_spec id (quant, typ) env) + else ([], env) in let mapcl_env = Env.add_typquant l quant env in let mapcls = List.map (fun mapcl -> check_mapcl mapcl_env mapcl typ) mapcls in let env = Env.define_val_spec id env in - vs_def @ [DEF_aux (DEF_mapdef (MD_aux (MD_mapping (id, tannot_opt, mapcls), (l, empty_tannot))), def_annot)], env + (vs_def @ [DEF_aux (DEF_mapdef (MD_aux (MD_mapping (id, tannot_opt, mapcls), (l, empty_tannot))), def_annot)], env) let rec warn_if_unsafe_cast l env = function | Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> - List.iter (warn_if_unsafe_cast l env) arg_typs; - warn_if_unsafe_cast l env ret_typ + List.iter (warn_if_unsafe_cast l env) arg_typs; + warn_if_unsafe_cast l env ret_typ | Typ_aux (Typ_id id, _) when string_of_id id = "bool" -> () | Typ_aux (Typ_id id, _) when Env.is_enum id env -> () | Typ_aux (Typ_id id, _) when string_of_id id = "string" -> - Reporting.warn "Unsafe string cast" l - "A cast X -> string is unsafe, as it can cause 'x : X == y : X' to be checked as 'eq_string(cast(x), cast(y))'" + Reporting.warn "Unsafe string cast" l + "A cast X -> string is unsafe, as it can cause 'x : X == y : X' to be checked as 'eq_string(cast(x), cast(y))'" (* If we have a cast to an existential, it's probably done on purpose and we want to avoid false positives for warnings. *) | Typ_aux (Typ_exist _, _) -> () | typ when is_bitvector_typ typ -> () | typ when is_bit_typ typ -> () - | typ -> - Reporting.warn ("Potentially unsafe cast involving " ^ string_of_typ typ) l "" + | typ -> Reporting.warn ("Potentially unsafe cast involving " ^ string_of_typ typ) l "" (* Checking a val spec simply adds the type as a binding in the context. *) let check_val_spec env def_annot (VS_aux (vs, (l, _))) = - let annotate vs typq typ = DEF_aux (DEF_val (VS_aux (vs, (l, mk_tannot (Env.add_typquant l typq env) typ))), def_annot) in - let vs, id, typq, typ, env = match vs with - | VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l) as typschm, id, exts, is_cast) -> - typ_print (lazy (Util.("Check val spec " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm)); - wf_typschm env typschm; - let env = match exts with Some exts -> Env.add_extern id exts env | None -> env in - let env = if is_cast then (warn_if_unsafe_cast l env (Env.expand_synonyms env typ); Env.add_cast id env) else env in - let typq', typ' = expand_bind_synonyms ts_l env (typq, typ) in - (* !opt_expand_valspec controls whether the actual valspec in - the AST is expanded, the val_spec type stored in the - environment is always expanded and uses typq' and typ' *) - let typq, typ = - if !opt_expand_valspec then - (typq', typ') - else - (typq, typ) - in - let vs = VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l), id, exts, is_cast) in - (vs, id, typq', typ', env) + let annotate vs typq typ = + DEF_aux (DEF_val (VS_aux (vs, (l, mk_tannot (Env.add_typquant l typq env) typ))), def_annot) + in + let vs, id, typq, typ, env = + match vs with + | VS_val_spec ((TypSchm_aux (TypSchm_ts (typq, typ), ts_l) as typschm), id, exts, is_cast) -> + typ_print + (lazy (Util.("Check val spec " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm)); + wf_typschm env typschm; + let env = match exts with Some exts -> Env.add_extern id exts env | None -> env in + let env = + if is_cast then ( + warn_if_unsafe_cast l env (Env.expand_synonyms env typ); + Env.add_cast id env + ) + else env + in + let typq', typ' = expand_bind_synonyms ts_l env (typq, typ) in + (* !opt_expand_valspec controls whether the actual valspec in + the AST is expanded, the val_spec type stored in the + environment is always expanded and uses typq' and typ' *) + let typq, typ = if !opt_expand_valspec then (typq', typ') else (typq, typ) in + let vs = VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l), id, exts, is_cast) in + (vs, id, typq', typ', env) in - [annotate vs typq typ], Env.add_val_spec id (typq, typ) env + ([annotate vs typq typ], Env.add_val_spec id (typq, typ) env) let check_default env def_annot (DT_aux (DT_order order, l)) = - [DEF_aux (DEF_default (DT_aux (DT_order order, l)), def_annot)], - Env.set_default_order order env + ([DEF_aux (DEF_default (DT_aux (DT_order order, l)), def_annot)], Env.set_default_order order env) let kinded_id_arg kind_id = let typ_arg l arg = A_aux (arg, l) in match kind_id with - | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> - typ_arg (kid_loc kid) (A_nexp (nvar kid)) + | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> typ_arg (kid_loc kid) (A_nexp (nvar kid)) | KOpt_aux (KOpt_kind (K_aux (K_order, _), kid), _) -> - let l = kid_loc kid in - typ_arg l (A_order (Ord_aux (Ord_var kid, l))) - | KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _) -> - typ_arg (kid_loc kid) (A_typ (mk_typ (Typ_var kid))) - | KOpt_aux (KOpt_kind (K_aux (K_bool, _), kid), _) -> - typ_arg (kid_loc kid) (A_bool (nc_var kid)) + let l = kid_loc kid in + typ_arg l (A_order (Ord_aux (Ord_var kid, l))) + | KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _) -> typ_arg (kid_loc kid) (A_typ (mk_typ (Typ_var kid))) + | KOpt_aux (KOpt_kind (K_aux (K_bool, _), kid), _) -> typ_arg (kid_loc kid) (A_bool (nc_var kid)) let fold_union_quant quants (QI_aux (qi, _)) = - match qi with - | QI_id kind_id -> quants @ [kinded_id_arg kind_id] - | _ -> quants + match qi with QI_id kind_id -> quants @ [kinded_id_arg kind_id] | _ -> quants (* We wrap this around wf_binding checks that aim to forbid recursive types to explain any error messages raised if the well-formedness check fails. *) let forbid_recursive_types type_l f = - try f () with - | Type_error (env, l, err) -> - let msg = "Types are not well-formed within this type definition. Note that recursive types are forbidden." in - raise (Type_error (env, type_l, err_because (Err_other msg, l, err))) + try f () + with Type_error (env, l, err) -> + let msg = "Types are not well-formed within this type definition. Note that recursive types are forbidden." in + raise (Type_error (env, type_l, err_because (Err_other msg, l, err))) let check_type_union u_l non_rec_env env variant typq (Tu_aux (Tu_ty_id (arg_typ, v), l)) = let ret_typ = app_typ variant (List.fold_left fold_union_quant [] (quant_items typq)) in let typ = mk_typ (Typ_fn ([arg_typ], ret_typ)) in forbid_recursive_types u_l (fun () -> wf_binding l non_rec_env (typq, arg_typ)); wf_binding l env (typq, typ); - env - |> Env.add_union_id v (typq, typ) - |> Env.add_val_spec v (typq, typ) + env |> Env.add_union_id v (typq, typ) |> Env.add_val_spec v (typq, typ) -let rec check_typedef : Env.t -> def_annot -> uannot type_def -> (tannot def) list * Env.t = - fun env def_annot (TD_aux (tdef, (l, _))) -> +let rec check_typedef : Env.t -> def_annot -> uannot type_def -> tannot def list * Env.t = + fun env def_annot (TD_aux (tdef, (l, _))) -> match tdef with | TD_abbrev (id, typq, typ_arg) -> - begin match typ_arg with - | A_aux (A_typ typ, a_l) -> - forbid_recursive_types l (fun () -> wf_binding a_l env (typq, typ)); - | _ -> () - end; - [DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], - Env.add_typ_synonym id typq typ_arg env - + begin + match typ_arg with + | A_aux (A_typ typ, a_l) -> forbid_recursive_types l (fun () -> wf_binding a_l env (typq, typ)) + | _ -> () + end; + ([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_typ_synonym id typq typ_arg env) | TD_record (id, typq, fields, _) -> - forbid_recursive_types l (fun () -> List.iter (fun (Typ_aux (_, l) as field, _) -> wf_binding l env (typq, field)) fields); - [DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], - Env.add_record id typq fields env - + forbid_recursive_types l (fun () -> + List.iter (fun ((Typ_aux (_, l) as field), _) -> wf_binding l env (typq, field)) fields + ); + ([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_record id typq fields env) | TD_variant (id, typq, arms, _) -> - let rec_env = Env.add_variant id (typq, arms) env in - (* register_value is a special type used by theorem prover - backends that we allow to be recursive. *) - let non_rec_env = if string_of_id id = "register_value" then rec_env else env in - let env = - rec_env - |> (fun env -> List.fold_left (fun env tu -> check_type_union l non_rec_env env id typq tu) env arms) - in - [DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], - env - - | TD_enum (id, ids, _) -> - [DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], - Env.add_enum id ids env - + let rec_env = Env.add_variant id (typq, arms) env in + (* register_value is a special type used by theorem prover + backends that we allow to be recursive. *) + let non_rec_env = if string_of_id id = "register_value" then rec_env else env in + let env = + rec_env |> fun env -> List.fold_left (fun env tu -> check_type_union l non_rec_env env id typq tu) env arms + in + ([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], env) + | TD_enum (id, ids, _) -> ([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_enum id ids env) | TD_bitfield (id, typ, ranges) as unexpanded -> - let typ = Env.expand_synonyms env typ in - begin match typ with - (* The type of a bitfield must be a constant-width bitvector *) - | Typ_aux (Typ_app (v, [A_aux (A_nexp (Nexp_aux (Nexp_constant size, _)), _); - A_aux (A_order order, _)]), _) + let typ = Env.expand_synonyms env typ in + begin + match typ with + (* The type of a bitfield must be a constant-width bitvector *) + | Typ_aux (Typ_app (v, [A_aux (A_nexp (Nexp_aux (Nexp_constant size, _)), _); A_aux (A_order order, _)]), _) when string_of_id v = "bitvector" -> - let rec expand_range_synonyms = function - | BF_aux (BF_single nexp, l) -> - BF_aux (BF_single (Env.expand_nexp_synonyms env nexp), l) - | BF_aux (BF_range (nexp1, nexp2), l) -> - let nexp1 = Env.expand_nexp_synonyms env nexp1 in - let nexp2 = Env.expand_nexp_synonyms env nexp2 in - BF_aux (BF_range (nexp1, nexp2), l) - | BF_aux (BF_concat (r1, r2), l) -> - BF_aux (BF_concat (expand_range_synonyms r1, expand_range_synonyms r2), l) - in - let record_tdef = TD_record (id, mk_typquant [], [(typ, mk_id "bits")], false) in - let ranges = - List.map (fun (f, r) -> (f, expand_range_synonyms r)) ranges - |> List.to_seq |> Bindings.of_seq - in - let defs = - DEF_aux (DEF_type (TD_aux (record_tdef, (l, empty_uannot))), def_annot) - :: Bitfield.macro id size order ranges - in - let defs = - if !Initial_check.opt_undefined_gen - then Initial_check.generate_undefineds IdSet.empty defs - else defs - in - let defs, env = check_defs env defs in - let env = Env.add_bitfield id ranges env in - if !opt_no_bitfield_expansion - then [DEF_aux (DEF_type (TD_aux (unexpanded, (l, empty_tannot))), def_annot)], env - else defs, env - - | _ -> - typ_error env l "Underlying bitfield type must be a constant-width bitvector" - end - -and check_scattered : Env.t -> def_annot -> uannot scattered_def -> (tannot def) list * Env.t = - fun env def_annot (SD_aux (sdef, (l, uannot))) -> - match sdef with - | SD_function _ | SD_end _ | SD_mapping _ -> - [], env + let rec expand_range_synonyms = function + | BF_aux (BF_single nexp, l) -> BF_aux (BF_single (Env.expand_nexp_synonyms env nexp), l) + | BF_aux (BF_range (nexp1, nexp2), l) -> + let nexp1 = Env.expand_nexp_synonyms env nexp1 in + let nexp2 = Env.expand_nexp_synonyms env nexp2 in + BF_aux (BF_range (nexp1, nexp2), l) + | BF_aux (BF_concat (r1, r2), l) -> + BF_aux (BF_concat (expand_range_synonyms r1, expand_range_synonyms r2), l) + in + let record_tdef = TD_record (id, mk_typquant [], [(typ, mk_id "bits")], false) in + let ranges = + List.map (fun (f, r) -> (f, expand_range_synonyms r)) ranges |> List.to_seq |> Bindings.of_seq + in + let defs = + DEF_aux (DEF_type (TD_aux (record_tdef, (l, empty_uannot))), def_annot) + :: Bitfield.macro id size order ranges + in + let defs = + if !Initial_check.opt_undefined_gen then Initial_check.generate_undefineds IdSet.empty defs else defs + in + let defs, env = check_defs env defs in + let env = Env.add_bitfield id ranges env in + if !opt_no_bitfield_expansion then + ([DEF_aux (DEF_type (TD_aux (unexpanded, (l, empty_tannot))), def_annot)], env) + else (defs, env) + | _ -> typ_error env l "Underlying bitfield type must be a constant-width bitvector" + end +and check_scattered : Env.t -> def_annot -> uannot scattered_def -> tannot def list * Env.t = + fun env def_annot (SD_aux (sdef, (l, uannot))) -> + match sdef with + | SD_function _ | SD_end _ | SD_mapping _ -> ([], env) | SD_variant (id, typq) -> - [DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), (l, empty_tannot))), def_annot)], - Env.add_scattered_variant id typq env - - | SD_unioncl (id, tu) -> - [DEF_aux (DEF_scattered (SD_aux (SD_unioncl (id, tu), (l, empty_tannot))), def_annot)], - let env = Env.add_variant_clause id tu env in - let typq, _ = Env.get_variant id env in - let definition_env = Env.get_scattered_variant_env id env in - (try check_type_union l definition_env env id typq tu with - | Type_error (env, l', err) -> - let msg = "As this is a scattered union clause, this could \ - also be caused by using a type defined after the \ - 'scattered union' declaration" in - raise (Type_error (env, l', err_because (err, id_loc id, Err_other msg)))) - + ( [DEF_aux (DEF_scattered (SD_aux (SD_variant (id, typq), (l, empty_tannot))), def_annot)], + Env.add_scattered_variant id typq env + ) + | SD_unioncl (id, tu) -> ( + ( [DEF_aux (DEF_scattered (SD_aux (SD_unioncl (id, tu), (l, empty_tannot))), def_annot)], + let env = Env.add_variant_clause id tu env in + let typq, _ = Env.get_variant id env in + let definition_env = Env.get_scattered_variant_env id env in + try check_type_union l definition_env env id typq tu + with Type_error (env, l', err) -> + let msg = + "As this is a scattered union clause, this could also be caused by using a type defined after the \ + 'scattered union' declaration" + in + raise (Type_error (env, l', err_because (err, id_loc id, Err_other msg))) + ) + ) | SD_funcl (FCL_aux (FCL_funcl (id, _), (fcl_def_annot, _)) as funcl) -> - let typq, typ = Env.get_val_spec id env in - let funcl_env = Env.add_typquant fcl_def_annot.loc typq env in - let funcl = check_funcl funcl_env funcl typ in - [DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, (l, mk_tannot ~uannot:uannot funcl_env typ))), def_annot)], - env - + let typq, typ = Env.get_val_spec id env in + let funcl_env = Env.add_typquant fcl_def_annot.loc typq env in + let funcl = check_funcl funcl_env funcl typ in + ([DEF_aux (DEF_scattered (SD_aux (SD_funcl funcl, (l, mk_tannot ~uannot funcl_env typ))), def_annot)], env) | SD_mapcl (id, mapcl) -> - let typq, typ = Env.get_val_spec id env in - let mapcl_env = Env.add_typquant l typq env in - let mapcl = check_mapcl mapcl_env mapcl typ in - [DEF_aux (DEF_scattered (SD_aux (SD_mapcl (id, mapcl), (l, empty_tannot))), def_annot)], - env + let typq, typ = Env.get_val_spec id env in + let mapcl_env = Env.add_typquant l typq env in + let mapcl = check_mapcl mapcl_env mapcl typ in + ([DEF_aux (DEF_scattered (SD_aux (SD_mapcl (id, mapcl), (l, empty_tannot))), def_annot)], env) and check_outcome : Env.t -> outcome_spec -> uannot def list -> outcome_spec * tannot def list * Env.t = - fun env (OV_aux (OV_outcome (id, typschm, params), l)) defs -> + fun env (OV_aux (OV_outcome (id, typschm, params), l)) defs -> let valid_outcome_def = function | DEF_aux ((DEF_impl _ | DEF_val _), _) -> () - | def -> - typ_error env (def_loc def) "Forbidden definition in outcome block" + | def -> typ_error env (def_loc def) "Forbidden definition in outcome block" in typ_print (lazy (Util.("Check outcome " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm)); match env.toplevel with - | None -> - begin - incr depth; - try - let local_env = { (add_typ_vars l params env) with toplevel = Some l } in - wf_typschm local_env typschm; - let quant, typ = match typschm with - | TypSchm_aux (TypSchm_ts (typq, typ), _) -> typq, typ - in - let local_env = { local_env with outcome_typschm = Some (quant, typ) } in - List.iter valid_outcome_def defs; - let defs, local_env = check_defs local_env defs in - let vals = List.filter_map (function DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, _, _), _)), _) -> Some id | _ -> None) defs in - decr depth; - OV_aux (OV_outcome (id, typschm, params), l), defs, Env.add_outcome id (quant, typ, params, vals, local_env) env - with - | Type_error (env, err_l, err) -> - decr depth; - typ_raise env err_l err - end + | None -> begin + incr depth; + try + let local_env = { (add_typ_vars l params env) with toplevel = Some l } in + wf_typschm local_env typschm; + let quant, typ = match typschm with TypSchm_aux (TypSchm_ts (typq, typ), _) -> (typq, typ) in + let local_env = { local_env with outcome_typschm = Some (quant, typ) } in + List.iter valid_outcome_def defs; + let defs, local_env = check_defs local_env defs in + let vals = + List.filter_map + (function DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, _, _), _)), _) -> Some id | _ -> None) + defs + in + decr depth; + ( OV_aux (OV_outcome (id, typschm, params), l), + defs, + Env.add_outcome id (quant, typ, params, vals, local_env) env + ) + with Type_error (env, err_l, err) -> + decr depth; + typ_raise env err_l err + end | Some outer_l -> - let msg = "Outcome must be declared within top-level scope" in - typ_raise env l (err_because (Err_other msg, outer_l, Err_other "Containing scope declared here")) + let msg = "Outcome must be declared within top-level scope" in + typ_raise env l (err_because (Err_other msg, outer_l, Err_other "Containing scope declared here")) and check_impldef : Env.t -> def_annot -> uannot funcl -> tannot def list * Env.t = - fun env def_annot (FCL_aux (FCL_funcl (id, _), (fcl_def_annot, _)) as funcl) -> + fun env def_annot (FCL_aux (FCL_funcl (id, _), (fcl_def_annot, _)) as funcl) -> typ_print (lazy (Util.("Check impl " |> cyan |> clear) ^ string_of_id id)); match env.outcome_typschm with | Some (quant, typ) -> - let funcl_env = Env.add_typquant fcl_def_annot.loc quant env in - [DEF_aux (DEF_impl (check_funcl funcl_env funcl typ), def_annot)], env - | None -> - typ_error env fcl_def_annot.loc "Cannot declare an implementation outside of an outcome" + let funcl_env = Env.add_typquant fcl_def_annot.loc quant env in + ([DEF_aux (DEF_impl (check_funcl funcl_env funcl typ), def_annot)], env) + | None -> typ_error env fcl_def_annot.loc "Cannot declare an implementation outside of an outcome" -and check_outcome_instantiation : 'a. Env.t -> def_annot -> 'a instantiation_spec -> subst list -> tannot def list * Env.t = - fun env def_annot (IN_aux (IN_id id, (l, _))) substs -> +and check_outcome_instantiation : + 'a. Env.t -> def_annot -> 'a instantiation_spec -> subst list -> tannot def list * Env.t = + fun env def_annot (IN_aux (IN_id id, (l, _))) substs -> typ_print (lazy (Util.("Check instantiation " |> cyan |> clear) ^ string_of_id id)); let typq, typ, params, vals, outcome_env = Env.get_outcome l id env in (* Find the outcome parameters that were already instantiated by previous instantiation commands *) let instantiated, uninstantiated = - Util.map_split (fun kopt -> + Util.map_split + (fun kopt -> match KBindings.find_opt (kopt_kid kopt) (Env.get_outcome_instantiation env) with | Some (prev_l, existing_typ) -> Ok (kopt_kid kopt, (prev_l, kopt_kind kopt, existing_typ)) | None -> Error kopt - ) params in + ) + params + in let instantiated = List.fold_left (fun m (kid, inst) -> KBindings.add kid inst m) KBindings.empty instantiated in (* Instantiate the outcome type with these existing parameters *) - let typ = List.fold_left (fun typ (kid, (_, _, existing_typ)) -> typ_subst kid (mk_typ_arg (A_typ existing_typ)) typ) - typ (KBindings.bindings instantiated) in + let typ = + List.fold_left + (fun typ (kid, (_, _, existing_typ)) -> typ_subst kid (mk_typ_arg (A_typ existing_typ)) typ) + typ (KBindings.bindings instantiated) + in let instantiate_typ substs typ = - List.fold_left (fun (typ, new_instantiated, fns, env) -> function - | IS_aux (IS_typ (kid, subst_typ), decl_l) -> - begin match KBindings.find_opt kid instantiated with - | Some (_, _, existing_typ) when alpha_equivalent env subst_typ existing_typ -> - typ, new_instantiated, fns, env - | Some (prev_l, _, existing_typ) -> - let msg = Printf.sprintf "Cannot instantiate %s with %s, already instantiated as %s" - (string_of_kid kid) (string_of_typ subst_typ) (string_of_typ existing_typ) in - typ_raise env decl_l (err_because (Err_other msg, prev_l, Err_other "Previously instantiated here")) - | None -> - Env.wf_typ env subst_typ; - typ_subst kid (mk_typ_arg (A_typ subst_typ)) typ, - (kid, subst_typ) :: new_instantiated, - fns, - Env.add_outcome_variable decl_l kid subst_typ env - end - | IS_aux (IS_id (id_from, id_to), decl_l) -> - typ, new_instantiated, (id_from, id_to, decl_l) :: fns, env - ) (typ, [], [], env) substs + List.fold_left + (fun (typ, new_instantiated, fns, env) -> function + | IS_aux (IS_typ (kid, subst_typ), decl_l) -> begin + match KBindings.find_opt kid instantiated with + | Some (_, _, existing_typ) when alpha_equivalent env subst_typ existing_typ -> + (typ, new_instantiated, fns, env) + | Some (prev_l, _, existing_typ) -> + let msg = + Printf.sprintf "Cannot instantiate %s with %s, already instantiated as %s" (string_of_kid kid) + (string_of_typ subst_typ) (string_of_typ existing_typ) + in + typ_raise env decl_l (err_because (Err_other msg, prev_l, Err_other "Previously instantiated here")) + | None -> + Env.wf_typ env subst_typ; + ( typ_subst kid (mk_typ_arg (A_typ subst_typ)) typ, + (kid, subst_typ) :: new_instantiated, + fns, + Env.add_outcome_variable decl_l kid subst_typ env + ) + end + | IS_aux (IS_id (id_from, id_to), decl_l) -> (typ, new_instantiated, (id_from, id_to, decl_l) :: fns, env) + ) + (typ, [], [], env) substs in let typ, new_instantiated, fns, env = instantiate_typ substs typ in (* Make sure every required outcome parameter has been instantiated *) - List.iter (fun kopt -> + List.iter + (fun kopt -> if not (List.exists (fun (v, _) -> Kid.compare (kopt_kid kopt) v = 0) new_instantiated) then typ_error env l ("Type variable " ^ string_of_kinded_id kopt ^ " must be instantiated") - ) uninstantiated; + ) + uninstantiated; - begin match List.find_opt (fun id -> not (List.exists (fun (id_from, _, _) -> Id.compare id id_from = 0) fns)) vals with - | Some val_id -> - typ_error env l ("Function " ^ string_of_id val_id ^ " must be instantiated for " ^ string_of_id id) - | None -> () + begin + match List.find_opt (fun id -> not (List.exists (fun (id_from, _, _) -> Id.compare id id_from = 0) fns)) vals with + | Some val_id -> typ_error env l ("Function " ^ string_of_id val_id ^ " must be instantiated for " ^ string_of_id id) + | None -> () end; - - List.iter (fun (id_from, id_to, decl_l) -> - let (to_typq, to_typ) = Env.get_val_spec id_to env in - let (from_typq, from_typ) = Env.get_val_spec id_from outcome_env in + + List.iter + (fun (id_from, id_to, decl_l) -> + let to_typq, to_typ = Env.get_val_spec id_to env in + let from_typq, from_typ = Env.get_val_spec id_from outcome_env in typ_debug (lazy (string_of_bind (to_typq, to_typ))); - - let from_typ = List.fold_left (fun typ (v, subst_typ) -> typ_subst v (mk_typ_arg (A_typ subst_typ)) typ) from_typ new_instantiated in - let from_typ = List.fold_left (fun typ (v, (_, _, subst_typ)) -> typ_subst v (mk_typ_arg (A_typ subst_typ)) typ) from_typ (KBindings.bindings instantiated) in - - check_function_instantiation decl_l id_from env (to_typq, to_typ) (from_typq, from_typ); - ) fns; - [DEF_aux (DEF_instantiation (IN_aux (IN_id id, (l, mk_tannot env unit_typ)), substs), def_annot)], - Env.add_val_spec id (typq, typ) env + let from_typ = + List.fold_left + (fun typ (v, subst_typ) -> typ_subst v (mk_typ_arg (A_typ subst_typ)) typ) + from_typ new_instantiated + in + let from_typ = + List.fold_left + (fun typ (v, (_, _, subst_typ)) -> typ_subst v (mk_typ_arg (A_typ subst_typ)) typ) + from_typ (KBindings.bindings instantiated) + in + + check_function_instantiation decl_l id_from env (to_typq, to_typ) (from_typq, from_typ) + ) + fns; + + ( [DEF_aux (DEF_instantiation (IN_aux (IN_id id, (l, mk_tannot env unit_typ)), substs), def_annot)], + Env.add_val_spec id (typq, typ) env + ) and check_def : Env.t -> uannot def -> tannot def list * Env.t = - fun env (DEF_aux (aux, def_annot)) -> + fun env (DEF_aux (aux, def_annot)) -> match aux with - | DEF_fixity (prec, n, op) -> [DEF_aux (DEF_fixity (prec, n, op), def_annot)], env + | DEF_fixity (prec, n, op) -> ([DEF_aux (DEF_fixity (prec, n, op), def_annot)], env) | DEF_type tdef -> check_typedef env def_annot tdef | DEF_fundef fdef -> check_fundef env def_annot fdef | DEF_mapdef mdef -> check_mapdef env def_annot mdef | DEF_impl funcl -> check_impldef env def_annot funcl | DEF_internal_mutrec fdefs -> - let defs = List.concat (List.map (fun fdef -> fst (check_fundef env def_annot fdef)) fdefs) in - let split_fundef (defs, fdefs) def = match def with - | DEF_aux (DEF_fundef fdef, _) -> (defs, fdefs @ [fdef]) - | _ -> (defs @ [def], fdefs) in - let (defs, fdefs) = List.fold_left split_fundef ([], []) defs in - (defs @ [DEF_aux (DEF_internal_mutrec fdefs, def_annot)]), env + let defs = List.concat (List.map (fun fdef -> fst (check_fundef env def_annot fdef)) fdefs) in + let split_fundef (defs, fdefs) def = + match def with DEF_aux (DEF_fundef fdef, _) -> (defs, fdefs @ [fdef]) | _ -> (defs @ [def], fdefs) + in + let defs, fdefs = List.fold_left split_fundef ([], []) defs in + (defs @ [DEF_aux (DEF_internal_mutrec fdefs, def_annot)], env) | DEF_let letdef -> check_letdef env def_annot letdef | DEF_val vs -> check_val_spec env def_annot vs | DEF_outcome (outcome, defs) -> - let outcome, defs, env = check_outcome env outcome defs in - [DEF_aux (DEF_outcome (outcome, defs), def_annot)], env + let outcome, defs, env = check_outcome env outcome defs in + ([DEF_aux (DEF_outcome (outcome, defs), def_annot)], env) | DEF_instantiation (ispec, substs) -> check_outcome_instantiation env def_annot ispec substs | DEF_default default -> check_default env def_annot default - | DEF_overload (id, ids) -> [DEF_aux (DEF_overload (id, ids), def_annot)], Env.add_overloads id ids env + | DEF_overload (id, ids) -> ([DEF_aux (DEF_overload (id, ids), def_annot)], Env.add_overloads id ids env) | DEF_register (DEC_aux (DEC_reg (typ, id, None), (l, _))) -> - let env = Env.add_register id typ env in - [DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, None), (l, mk_expected_tannot env typ (Some typ)))), def_annot)], env + let env = Env.add_register id typ env in + ( [ + DEF_aux + (DEF_register (DEC_aux (DEC_reg (typ, id, None), (l, mk_expected_tannot env typ (Some typ)))), def_annot); + ], + env + ) | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), (l, _))) -> - let checked_exp = crule check_exp env exp typ in - let env = Env.add_register id typ env in - [DEF_aux (DEF_register (DEC_aux (DEC_reg (typ, id, Some checked_exp), (l, mk_expected_tannot env typ (Some typ)))), def_annot)], env - | DEF_pragma (pragma, arg, l) -> [DEF_aux (DEF_pragma (pragma, arg, l), def_annot)], env + let checked_exp = crule check_exp env exp typ in + let env = Env.add_register id typ env in + ( [ + DEF_aux + ( DEF_register (DEC_aux (DEC_reg (typ, id, Some checked_exp), (l, mk_expected_tannot env typ (Some typ)))), + def_annot + ); + ], + env + ) + | DEF_pragma (pragma, arg, l) -> ([DEF_aux (DEF_pragma (pragma, arg, l), def_annot)], env) | DEF_scattered sdef -> check_scattered env def_annot sdef - | DEF_measure (id, pat, exp) -> [check_termination_measure_decl env def_annot (id, pat, exp)], env + | DEF_measure (id, pat, exp) -> ([check_termination_measure_decl env def_annot (id, pat, exp)], env) | DEF_loop_measures (id, _) -> - Reporting.unreachable (id_loc id) __POS__ - "Loop termination measures should have been rewritten before type checking" + Reporting.unreachable (id_loc id) __POS__ + "Loop termination measures should have been rewritten before type checking" and check_defs_progress : int -> int -> Env.t -> uannot def list -> tannot def list * Env.t = - fun n total env defs -> + fun n total env defs -> match defs with - | [] -> [], env + | [] -> ([], env) | def :: defs -> - Util.progress "Type check " (string_of_int n ^ "/" ^ string_of_int total) n total; - let (def, env) = check_def env def in - let defs, env = check_defs_progress (n + 1) total env defs in - def @ defs, env + Util.progress "Type check " (string_of_int n ^ "/" ^ string_of_int total) n total; + let def, env = check_def env def in + let defs, env = check_defs_progress (n + 1) total env defs in + (def @ defs, env) and check_defs : Env.t -> uannot def list -> tannot def list * Env.t = - fun env defs -> let total = List.length defs in check_defs_progress 1 total env defs + fun env defs -> + let total = List.length defs in + check_defs_progress 1 total env defs let check : Env.t -> uannot ast -> tannot ast * Env.t = - fun env ast -> + fun env ast -> let total = List.length ast.defs in let defs, env = check_defs_progress 1 total env ast.defs in - { ast with defs = defs }, env + ({ ast with defs }, env) let rec check_with_envs : Env.t -> uannot def list -> (tannot def list * Env.t) list = - fun env defs -> + fun env defs -> match defs with | [] -> [] | def :: defs -> - let def, env = check_def env def in - (def, env) :: check_with_envs env defs + let def, env = check_def env def in + (def, env) :: check_with_envs env defs let initial_env = Env.empty |> Env.set_prover (Some (prove __POS__)) |> Env.add_extern (mk_id "size_itself_int") { pure = true; bindings = [("_", "size_itself_int")] } |> Env.add_val_spec (mk_id "size_itself_int") - (TypQ_aux (TypQ_tq [QI_aux (QI_id (mk_kopt K_int (mk_kid "n")), - Parse_ast.Unknown)],Parse_ast.Unknown), - function_typ [app_typ (mk_id "itself") [mk_typ_arg (A_nexp (nvar (mk_kid "n")))]] - (atom_typ (nvar (mk_kid "n")))) + ( TypQ_aux (TypQ_tq [QI_aux (QI_id (mk_kopt K_int (mk_kid "n")), Parse_ast.Unknown)], Parse_ast.Unknown), + function_typ [app_typ (mk_id "itself") [mk_typ_arg (A_nexp (nvar (mk_kid "n")))]] (atom_typ (nvar (mk_kid "n"))) + ) |> Env.add_extern (mk_id "make_the_value") { pure = true; bindings = [("_", "make_the_value")] } |> Env.add_val_spec (mk_id "make_the_value") - (TypQ_aux (TypQ_tq [QI_aux (QI_id (mk_kopt K_int (mk_kid "n")), - Parse_ast.Unknown)],Parse_ast.Unknown), - function_typ [atom_typ (nvar (mk_kid "n"))] - (app_typ (mk_id "itself") [mk_typ_arg (A_nexp (nvar (mk_kid "n")))])) + ( TypQ_aux (TypQ_tq [QI_aux (QI_id (mk_kopt K_int (mk_kid "n")), Parse_ast.Unknown)], Parse_ast.Unknown), + function_typ [atom_typ (nvar (mk_kid "n"))] (app_typ (mk_id "itself") [mk_typ_arg (A_nexp (nvar (mk_kid "n")))]) + ) (* __assume is used by property.ml to add guards for SMT generation, - but which don't affect flow-typing. *) + but which don't affect flow-typing. *) |> Env.add_val_spec (mk_id "sail_assume") - (TypQ_aux (TypQ_no_forall, Parse_ast.Unknown), - function_typ [bool_typ] unit_typ) + (TypQ_aux (TypQ_no_forall, Parse_ast.Unknown), function_typ [bool_typ] unit_typ) diff --git a/src/lib/type_check.mli b/src/lib/type_check.mli index 396cbfccb..bc965313e 100644 --- a/src/lib/type_check.mli +++ b/src/lib/type_check.mli @@ -110,7 +110,7 @@ type type_error = type env -exception Type_error of env * l * type_error;; +exception Type_error of env * l * type_error val typ_debug : ?level:int -> string Lazy.t -> unit val typ_print : string Lazy.t -> unit @@ -130,7 +130,7 @@ module Env : sig type variables. *) val get_val_spec : id -> t -> typquant * typ - val get_val_specs : t -> (typquant * typ ) Bindings.t + val get_val_specs : t -> (typquant * typ) Bindings.t val get_defined_val_specs : t -> IdSet.t @@ -165,7 +165,7 @@ module Env : sig (** Get the current set of constraints. *) val get_constraints : t -> n_constraint list - val add_constraint : ?reason:(Ast.l * string) -> n_constraint -> t -> t + val add_constraint : ?reason:Ast.l * string -> n_constraint -> t -> t (** Push all the type variables and constraints from a typquant into an environment *) @@ -306,6 +306,7 @@ type tannot calling destruct_tannot followed by mk_tannot returns an identical type annotation. *) val destruct_tannot : tannot -> (Env.t * typ) option + val mk_tannot : ?uannot:uannot -> Env.t -> typ -> tannot val untyped_annot : tannot -> uannot @@ -328,6 +329,7 @@ val strip_exp : tannot exp -> uannot exp (** Strip the type annotations from a pattern *) val strip_pat : tannot pat -> uannot pat + val strip_mpat : tannot mpat -> uannot mpat (** Strip the type annotations from a pattern-expression *) @@ -348,6 +350,7 @@ val strip_ast : tannot ast -> uannot ast (** Strip location information from types for comparison purposes *) val strip_typ : typ -> typ + val strip_typq : typquant -> typquant val strip_id : id -> id val strip_kid : kid -> kid @@ -391,14 +394,15 @@ val assert_constraint : Env.t -> bool -> tannot exp -> n_constraint option this is only exposed so that it can be used during descattering to check completeness of scattered functions, and should not be called otherwise. *) -val check_funcls_complete : Parse_ast.l -> Env.t -> tannot funcl list -> typ -> tannot funcl list * (def_annot -> def_annot) +val check_funcls_complete : + Parse_ast.l -> Env.t -> tannot funcl list -> typ -> tannot funcl list * (def_annot -> def_annot) (** Attempt to prove a constraint using z3. Returns true if z3 can prove that the constraint is true, returns false if z3 cannot prove the constraint true. Note that this does not guarantee that the constraint is actually false, as the constraint solver is somewhat untrustworthy. *) -val prove : (string * int * int * int) -> Env.t -> n_constraint -> bool +val prove : string * int * int * int -> Env.t -> n_constraint -> bool (** Returns Some c if there is a unique c such that nexp = c *) val solve_unique : Env.t -> nexp -> Big_int.num option @@ -464,6 +468,7 @@ val expected_typ_of : Ast.l * tannot -> typ option a collision. The "plain" version does not treat numeric types (i.e. range, int, nat) as existentials. *) val destruct_exist_plain : ?name:string option -> typ -> (kinded_id list * n_constraint * typ) option + val destruct_exist : ?name:string option -> typ -> (kinded_id list * n_constraint * typ) option val destruct_atom_nexp : Env.t -> typ -> nexp option diff --git a/src/lib/type_error.ml b/src/lib/type_error.ml index bfa62de67..e76fbd587 100644 --- a/src/lib/type_error.ml +++ b/src/lib/type_error.ml @@ -73,260 +73,281 @@ open Type_check let opt_explain_all_variables = ref false let opt_explain_constraints = ref false - -type suggestion = - | Suggest_add_constraint of n_constraint - | Suggest_none + +type suggestion = Suggest_add_constraint of n_constraint | Suggest_none let analyze_unresolved_quant locals ncs = function | QI_aux (QI_constraint nc, _) -> - let gen_kids = List.filter is_kid_generated (KidSet.elements (tyvars_of_constraint nc)) in - if gen_kids = [] then - Suggest_add_constraint nc - else - (* If there are generated kind-identifiers in the constraint, - we don't want to make a suggestion based on them, so try to - look for generated kid free nexps in the set of constraints - that are equal to the generated identifier. This often - occurs due to how the type-checker introduces new type - variables. *) - let is_subst v = function - | NC_aux (NC_equal (Nexp_aux (Nexp_var v', _), nexp), _) - when Kid.compare v v' = 0 && not (KidSet.exists is_kid_generated (tyvars_of_nexp nexp)) -> - [(v, nexp)] - | NC_aux (NC_equal (nexp, Nexp_aux (Nexp_var v', _)), _) - when Kid.compare v v' = 0 && not (KidSet.exists is_kid_generated (tyvars_of_nexp nexp)) -> - [(v, nexp)] - | _ -> [] - in - let substs = List.concat (List.map (fun v -> List.concat (List.map (fun nc -> is_subst v nc) ncs)) gen_kids) in - let nc = List.fold_left (fun nc (v, nexp) -> constraint_subst v (arg_nexp nexp) nc) nc substs in - if not (KidSet.exists is_kid_generated (tyvars_of_constraint nc)) then - Suggest_add_constraint nc - else - (* If we have a really anonymous type-variable, try to find a - regular variable that corresponds to it. *) - let is_linked v = function - | (id, (Immutable, (Typ_aux (Typ_app (ty_id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _)]), _) as typ))) - when Id.compare ty_id (mk_id "atom") = 0 && Kid.compare v v' = 0 -> - [(v, nid id, typ)] - | (id, (mut, typ)) -> - [] - in - let substs = List.concat (List.map (fun v -> List.concat (List.map (fun nc -> is_linked v nc) (Bindings.bindings locals))) gen_kids) in - let nc = List.fold_left (fun nc (v, nexp, _) -> constraint_subst v (arg_nexp nexp) nc) nc substs in - if not (KidSet.exists is_kid_generated (tyvars_of_constraint nc)) then - Suggest_none - else - Suggest_none - - | QI_aux (QI_id _, _) -> - Suggest_none + let gen_kids = List.filter is_kid_generated (KidSet.elements (tyvars_of_constraint nc)) in + if gen_kids = [] then Suggest_add_constraint nc + else ( + (* If there are generated kind-identifiers in the constraint, + we don't want to make a suggestion based on them, so try to + look for generated kid free nexps in the set of constraints + that are equal to the generated identifier. This often + occurs due to how the type-checker introduces new type + variables. *) + let is_subst v = function + | NC_aux (NC_equal (Nexp_aux (Nexp_var v', _), nexp), _) + when Kid.compare v v' = 0 && not (KidSet.exists is_kid_generated (tyvars_of_nexp nexp)) -> + [(v, nexp)] + | NC_aux (NC_equal (nexp, Nexp_aux (Nexp_var v', _)), _) + when Kid.compare v v' = 0 && not (KidSet.exists is_kid_generated (tyvars_of_nexp nexp)) -> + [(v, nexp)] + | _ -> [] + in + let substs = List.concat (List.map (fun v -> List.concat (List.map (fun nc -> is_subst v nc) ncs)) gen_kids) in + let nc = List.fold_left (fun nc (v, nexp) -> constraint_subst v (arg_nexp nexp) nc) nc substs in + if not (KidSet.exists is_kid_generated (tyvars_of_constraint nc)) then Suggest_add_constraint nc + else ( + (* If we have a really anonymous type-variable, try to find a + regular variable that corresponds to it. *) + let is_linked v = function + | id, (Immutable, (Typ_aux (Typ_app (ty_id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _)]), _) as typ)) + when Id.compare ty_id (mk_id "atom") = 0 && Kid.compare v v' = 0 -> + [(v, nid id, typ)] + | id, (mut, typ) -> [] + in + let substs = + List.concat + (List.map (fun v -> List.concat (List.map (fun nc -> is_linked v nc) (Bindings.bindings locals))) gen_kids) + in + let nc = List.fold_left (fun nc (v, nexp, _) -> constraint_subst v (arg_nexp nexp) nc) nc substs in + if not (KidSet.exists is_kid_generated (tyvars_of_constraint nc)) then Suggest_none else Suggest_none + ) + ) + | QI_aux (QI_id _, _) -> Suggest_none let readable_name (Kid_aux (Var str, l)) = let str = String.concat "" (String.split_on_char '#' str) in - let str = if String.length str > 1 && str.[1] = '_' then String.sub str 0 1 ^ String.sub str 2 (String.length str - 2) else str in + let str = + if String.length str > 1 && str.[1] = '_' then String.sub str 0 1 ^ String.sub str 2 (String.length str - 2) + else str + in Kid_aux (Var (String.concat "" (String.split_on_char '#' str)), l) let has_underscore (Kid_aux (Var str, l)) = String.length str > 1 && str.[1] = '_' - + let error_string_of_kid substs kid = - match KBindings.find_opt kid substs with - | Some nexp -> string_of_nexp nexp - | None -> string_of_kid kid + match KBindings.find_opt kid substs with Some nexp -> string_of_nexp nexp | None -> string_of_kid kid + +let _error_string_of_nexp substs nexp = string_of_nexp (subst_kids_nexp substs nexp) + +let error_string_of_nc substs nexp = string_of_n_constraint (subst_kids_nc substs nexp) + +let error_string_of_typ substs typ = string_of_typ (subst_kids_typ substs typ) -let _error_string_of_nexp substs nexp = - string_of_nexp (subst_kids_nexp substs nexp) +let error_string_of_typ_arg substs arg = string_of_typ_arg (subst_kids_typ_arg substs arg) -let error_string_of_nc substs nexp = - string_of_n_constraint (subst_kids_nc substs nexp) - -let error_string_of_typ substs typ = - string_of_typ (subst_kids_typ substs typ) +let has_variable set nexp = not (KidSet.is_empty (KidSet.inter set (tyvars_of_nexp nexp))) -let error_string_of_typ_arg substs arg = - string_of_typ_arg (subst_kids_typ_arg substs arg) - -let has_variable set nexp = - not (KidSet.is_empty (KidSet.inter set (tyvars_of_nexp nexp))) - let rewrite_equality preferred_on_right (NC_aux (aux, l) as nc) = - let equality = match aux with + let equality = + match aux with | NC_equal (lhs, rhs) -> - if has_variable preferred_on_right lhs && not (has_variable preferred_on_right rhs) then - Some (rhs, lhs) - else - Some (lhs, rhs) - | _ -> None in - match equality with - | Some (lhs, rhs) -> - NC_aux (NC_equal (lhs, rhs), l) - | None -> - nc + if has_variable preferred_on_right lhs && not (has_variable preferred_on_right rhs) then Some (rhs, lhs) + else Some (lhs, rhs) + | _ -> None + in + match equality with Some (lhs, rhs) -> NC_aux (NC_equal (lhs, rhs), l) | None -> nc let subst_preferred_variables prefs constraints = let simplified_by v arg = function - | Some (l, str) -> Some (l, fun substs -> "simplified by " ^ str ^ " with " ^ error_string_of_kid substs v ^ " = " ^ error_string_of_typ_arg substs arg) - | None -> None in - let original_location = function - | Some (l, str) -> Some (l, fun _ -> "introduced here by " ^ str) - | None -> None in - let all_substs, constraints = - Util.map_split (function - | (r, NC_aux (NC_equal (Nexp_aux (Nexp_var v, _), rhs), _)) - when has_variable prefs rhs && not (KidSet.mem v (tyvars_of_nexp rhs)) -> - Ok (r, v, arg_nexp rhs) - | (r, NC_aux (NC_app (id, [A_aux (A_bool (NC_aux (NC_var v, _)), _)]), _)) when string_of_id id = "not" -> - Ok (r, v, arg_bool nc_false) - | (r, NC_aux (NC_var v, _)) -> - Ok (r, v, arg_bool nc_true) - | (r, nc) -> - Error (r, nc) - ) constraints in + | Some (l, str) -> + Some + ( l, + fun substs -> + "simplified by " ^ str ^ " with " ^ error_string_of_kid substs v ^ " = " + ^ error_string_of_typ_arg substs arg + ) + | None -> None + in + let original_location = function Some (l, str) -> Some (l, fun _ -> "introduced here by " ^ str) | None -> None in + let all_substs, constraints = + Util.map_split + (function + | r, NC_aux (NC_equal (Nexp_aux (Nexp_var v, _), rhs), _) + when has_variable prefs rhs && not (KidSet.mem v (tyvars_of_nexp rhs)) -> + Ok (r, v, arg_nexp rhs) + | r, NC_aux (NC_app (id, [A_aux (A_bool (NC_aux (NC_var v, _)), _)]), _) when string_of_id id = "not" -> + Ok (r, v, arg_bool nc_false) + | r, NC_aux (NC_var v, _) -> Ok (r, v, arg_bool nc_true) + | r, nc -> Error (r, nc) + ) + constraints + in (* Filter out any substitutions that just rename variables *) let rename_substs, other_substs = - List.partition (function - | (_, _, A_aux (A_nexp (Nexp_aux (Nexp_var _, _)), _)) -> true - | _ -> false - ) all_substs in + List.partition (function _, _, A_aux (A_nexp (Nexp_aux (Nexp_var _, _)), _) -> true | _ -> false) all_substs + in (* and apply those renaming substitutions first *) let constraints = - List.map (fun (r, nc) -> - (r, List.fold_left (fun nc (_, v, arg) -> constraint_subst v arg nc) nc rename_substs) - ) constraints in + List.map + (fun (r, nc) -> (r, List.fold_left (fun nc (_, v, arg) -> constraint_subst v arg nc) nc rename_substs)) + constraints + in (* now apply the more interesting substitutions, keeping track of the reasons for using them *) - List.map (fun (r, orig_nc) -> - List.fold_left (fun (rs, b, nc) (r', v, arg) -> - if KidSet.mem v (tyvars_of_constraint nc) then - (simplified_by v arg r' :: rs, true, constraint_subst v arg nc) - else - (rs, b, nc) - ) ([original_location r], false, orig_nc) other_substs - |> (fun (r, changed, nc) -> (r, (if changed then Some orig_nc else None), constraint_simp nc)) - ) constraints + List.map + (fun (r, orig_nc) -> + List.fold_left + (fun (rs, b, nc) (r', v, arg) -> + if KidSet.mem v (tyvars_of_constraint nc) then (simplified_by v arg r' :: rs, true, constraint_subst v arg nc) + else (rs, b, nc) + ) + ([original_location r], false, orig_nc) + other_substs + |> fun (r, changed, nc) -> (r, (if changed then Some orig_nc else None), constraint_simp nc) + ) + constraints -let rec map_typ_arg ?under:(under = []) f (Typ_aux (aux, l)) = - let aux = match aux with +let rec map_typ_arg ?(under = []) f (Typ_aux (aux, l)) = + let aux = + match aux with | Typ_internal_unknown -> Typ_internal_unknown | Typ_id id -> Typ_id id | Typ_var v -> Typ_var v - | Typ_fn (typs, typ) -> Typ_fn (List.map (map_typ_arg ~under:under f) typs, map_typ_arg ~under:under f typ) - | Typ_bidir (typ1, typ2) -> Typ_bidir (map_typ_arg ~under:under f typ1, map_typ_arg ~under:under f typ2) - | Typ_tuple typs -> Typ_tuple (List.map (map_typ_arg ~under:under f) typs) + | Typ_fn (typs, typ) -> Typ_fn (List.map (map_typ_arg ~under f) typs, map_typ_arg ~under f typ) + | Typ_bidir (typ1, typ2) -> Typ_bidir (map_typ_arg ~under f typ1, map_typ_arg ~under f typ2) + | Typ_tuple typs -> Typ_tuple (List.map (map_typ_arg ~under f) typs) | Typ_app (id, args) -> - List.map (function - | A_aux (A_typ typ, l) -> - let typ = map_typ_arg ~under:under f typ in - f under id (A_aux (A_typ typ, l)) - | arg -> f under id arg - ) args - |> (fun args -> Typ_app (id, args)) - | Typ_exist (vars, nc, typ) -> - Typ_exist (vars, nc, map_typ_arg ~under:((vars, nc) :: under) f typ) + List.map + (function + | A_aux (A_typ typ, l) -> + let typ = map_typ_arg ~under f typ in + f under id (A_aux (A_typ typ, l)) + | arg -> f under id arg + ) + args + |> fun args -> Typ_app (id, args) + | Typ_exist (vars, nc, typ) -> Typ_exist (vars, nc, map_typ_arg ~under:((vars, nc) :: under) f typ) in Typ_aux (aux, l) -let simp_typ = map_typ_arg (fun _ _ -> function - | A_aux (A_nexp nexp, l) -> A_aux (A_nexp (nexp_simp nexp), l) - | A_aux (A_bool nc, l) -> A_aux (A_bool (constraint_simp nc), l) - | arg -> arg) - +let simp_typ = + map_typ_arg (fun _ _ -> function + | A_aux (A_nexp nexp, l) -> A_aux (A_nexp (nexp_simp nexp), l) + | A_aux (A_bool nc, l) -> A_aux (A_bool (constraint_simp nc), l) + | arg -> arg + ) + let message_of_type_error = let open Error_format in let rec msg = function | Err_inner (err, l', prefix, hint, err') -> - let prefix = if prefix = "" then "" else Util.((prefix ^ " ") |> yellow |> clear) in - Seq [msg err; - Line ""; - Location (prefix, hint, l', msg err')] - + let prefix = if prefix = "" then "" else Util.(prefix ^ " " |> yellow |> clear) in + Seq [msg err; Line ""; Location (prefix, hint, l', msg err')] | Err_other str -> if str = "" then Seq [] else Line str - | Err_no_overloading (id, errs) -> - Seq [Line ("No overloading for " ^ string_of_id id ^ ", tried:"); - List (List.map (fun (id, err) -> string_of_id id, msg err) errs)] - + Seq + [ + Line ("No overloading for " ^ string_of_id id ^ ", tried:"); + List (List.map (fun (id, err) -> (string_of_id id, msg err)) errs); + ] | Err_unresolved_quants (id, quants, locals, ncs) -> - Seq [Line ("Could not resolve quantifiers for " ^ string_of_id id); - Line (bullet ^ " " ^ Util.string_of_list ("\n" ^ bullet ^ " ") string_of_quant_item quants)] - - | Err_failed_constraint (check, locals, ncs) -> - Line ("Failed to prove constraint: " ^ string_of_n_constraint check) - + Seq + [ + Line ("Could not resolve quantifiers for " ^ string_of_id id); + Line (bullet ^ " " ^ Util.string_of_list ("\n" ^ bullet ^ " ") string_of_quant_item quants); + ] + | Err_failed_constraint (check, locals, ncs) -> Line ("Failed to prove constraint: " ^ string_of_n_constraint check) | Err_subtype (typ1, typ2, nc, all_constraints, all_vars) -> - let nc = Option.map constraint_simp nc in - let typ1, typ2 = simp_typ typ1, simp_typ typ2 in - let nc_vars = match nc with Some nc -> tyvars_of_constraint nc | None -> KidSet.empty in - (* Variables appearing in the types and constraint *) - let appear_vars = - KBindings.bindings all_vars - |> List.filter (fun (v, _) -> KidSet.mem v (KidSet.union nc_vars (KidSet.union (tyvars_of_typ typ1) (tyvars_of_typ typ2)))) in - let vars = List.filter (fun (v, _) -> is_kid_generated v || has_underscore v) appear_vars in - - let preferred = KidSet.of_list (List.map fst appear_vars) in - let rewritten_constraints = - List.map (fun (reason, nc) -> - (reason, rewrite_equality preferred nc) - ) all_constraints - |> subst_preferred_variables preferred in - - let var_constraints = - List.map (fun (v, l) -> - (v, l, List.filter (fun (_, _, nc) -> KidSet.mem v (tyvars_of_constraint nc)) rewritten_constraints) - ) (if !opt_explain_all_variables then appear_vars else vars) in + let nc = Option.map constraint_simp nc in + let typ1, typ2 = (simp_typ typ1, simp_typ typ2) in + let nc_vars = match nc with Some nc -> tyvars_of_constraint nc | None -> KidSet.empty in + (* Variables appearing in the types and constraint *) + let appear_vars = + KBindings.bindings all_vars + |> List.filter (fun (v, _) -> + KidSet.mem v (KidSet.union nc_vars (KidSet.union (tyvars_of_typ typ1) (tyvars_of_typ typ2))) + ) + in + let vars = List.filter (fun (v, _) -> is_kid_generated v || has_underscore v) appear_vars in - let substs = - List.fold_left (fun (substs, new_vars) (v, _) -> - if is_kid_generated v || has_underscore v then - let v' = readable_name v in - if not (KBindings.mem v' all_vars) && not (KidSet.mem v' new_vars) then - (KBindings.add v (nvar v') substs, KidSet.add v' new_vars) - else - (substs, new_vars) - else - (substs, new_vars) - ) (KBindings.empty, KidSet.empty) vars - |> fst in + let preferred = KidSet.of_list (List.map fst appear_vars) in + let rewritten_constraints = + List.map (fun (reason, nc) -> (reason, rewrite_equality preferred nc)) all_constraints + |> subst_preferred_variables preferred + in - let format_var_constraint (reasons, original_nc, nc) = - if List.for_all (function None -> true | Some _ -> false) reasons || not !opt_explain_constraints then - Line ("has constraint: " ^ error_string_of_nc substs nc) - else - Seq (Line ("has constraint " ^ error_string_of_nc substs nc) - :: (match original_nc with Some nc -> [Line ("original constraint was " ^ error_string_of_nc substs nc)] | None -> []) - @ List.filter_map (function - | None -> None - | Some (l, hint) -> Some (Location ("", Some (hint substs), Reporting.start_loc l, Seq [])) - ) reasons) - in - let format_var_constraints = - function - | [info] -> format_var_constraint info - | infos -> Seq (List.map format_var_constraint infos) - in - With ((fun ppf -> { ppf with loc_color = Util.yellow }), - Seq (Line (error_string_of_typ substs typ1 ^ " is not a subtype of " ^ error_string_of_typ substs typ2) - :: (match nc with Some nc -> [Line ("as " ^ error_string_of_nc substs nc ^ " could not be proven")] | None -> []) - @ List.map (fun (v, l, ncs) -> - Seq [Line ""; - Line ("type variable " ^ error_string_of_kid substs v ^ ":"); - Location ("", Some "bound here", l, format_var_constraints ncs)]) - var_constraints)) + let var_constraints = + List.map + (fun (v, l) -> + (v, l, List.filter (fun (_, _, nc) -> KidSet.mem v (tyvars_of_constraint nc)) rewritten_constraints) + ) + (if !opt_explain_all_variables then appear_vars else vars) + in - | Err_no_num_ident id -> - Line ("No num identifier " ^ string_of_id id) + let substs = + List.fold_left + (fun (substs, new_vars) (v, _) -> + if is_kid_generated v || has_underscore v then ( + let v' = readable_name v in + if (not (KBindings.mem v' all_vars)) && not (KidSet.mem v' new_vars) then + (KBindings.add v (nvar v') substs, KidSet.add v' new_vars) + else (substs, new_vars) + ) + else (substs, new_vars) + ) + (KBindings.empty, KidSet.empty) vars + |> fst + in - | Err_no_casts (exp, typ_from, typ_to, trigger, reasons) -> - let coercion = - Line ("Tried performing type coercion from " ^ string_of_typ typ_from - ^ " to " ^ string_of_typ typ_to - ^ " on " ^ string_of_exp exp) - in - Seq ([coercion; Line "Coercion failed because:"; msg trigger] - @ if not (reasons = []) then - Line "Possible reasons:" :: List.map msg reasons - else - []) + let format_var_constraint (reasons, original_nc, nc) = + if List.for_all (function None -> true | Some _ -> false) reasons || not !opt_explain_constraints then + Line ("has constraint: " ^ error_string_of_nc substs nc) + else + Seq + (Line ("has constraint " ^ error_string_of_nc substs nc) + :: + ( match original_nc with + | Some nc -> [Line ("original constraint was " ^ error_string_of_nc substs nc)] + | None -> [] + ) + @ List.filter_map + (function + | None -> None + | Some (l, hint) -> Some (Location ("", Some (hint substs), Reporting.start_loc l, Seq [])) + ) + reasons + ) + in + let format_var_constraints = function + | [info] -> format_var_constraint info + | infos -> Seq (List.map format_var_constraint infos) + in + With + ( (fun ppf -> { ppf with loc_color = Util.yellow }), + Seq + (Line (error_string_of_typ substs typ1 ^ " is not a subtype of " ^ error_string_of_typ substs typ2) + :: + ( match nc with + | Some nc -> [Line ("as " ^ error_string_of_nc substs nc ^ " could not be proven")] + | None -> [] + ) + @ List.map + (fun (v, l, ncs) -> + Seq + [ + Line ""; + Line ("type variable " ^ error_string_of_kid substs v ^ ":"); + Location ("", Some "bound here", l, format_var_constraints ncs); + ] + ) + var_constraints + ) + ) + | Err_no_num_ident id -> Line ("No num identifier " ^ string_of_id id) + | Err_no_casts (exp, typ_from, typ_to, trigger, reasons) -> + let coercion = + Line + ("Tried performing type coercion from " ^ string_of_typ typ_from ^ " to " ^ string_of_typ typ_to ^ " on " + ^ string_of_exp exp + ) + in + Seq + ([coercion; Line "Coercion failed because:"; msg trigger] + @ if not (reasons = []) then Line "Possible reasons:" :: List.map msg reasons else [] + ) in msg @@ -337,44 +358,37 @@ let string_of_type_error err = Buffer.contents b let rec collapse_errors = function - | (Err_no_overloading (_, errs) as no_collapse) -> - let errs = List.map (fun (_, err) -> collapse_errors err) errs in - let interesting = function - | Err_other _ -> false - | Err_no_casts _ -> false - | _ -> true - in - begin match List.filter interesting errs with - | err :: errs -> - let fold_equal msg err = - match msg, err with - | Some msg, Err_no_overloading _ -> Some msg - | Some msg, Err_no_casts _ -> Some msg - | Some msg, err when msg = string_of_type_error err -> Some msg - | _, _ -> None - in - begin match List.fold_left fold_equal (Some (string_of_type_error err)) errs with - | Some _ -> err - | None -> no_collapse - end - | [] -> no_collapse - end + | Err_no_overloading (_, errs) as no_collapse -> + let errs = List.map (fun (_, err) -> collapse_errors err) errs in + let interesting = function Err_other _ -> false | Err_no_casts _ -> false | _ -> true in + begin + match List.filter interesting errs with + | err :: errs -> + let fold_equal msg err = + match (msg, err) with + | Some msg, Err_no_overloading _ -> Some msg + | Some msg, Err_no_casts _ -> Some msg + | Some msg, err when msg = string_of_type_error err -> Some msg + | _, _ -> None + in + begin + match List.fold_left fold_equal (Some (string_of_type_error err)) errs with + | Some _ -> err + | None -> no_collapse + end + | [] -> no_collapse + end | Err_inner (err1, l, prefix, hint, err2) -> - let err1 = collapse_errors err1 in - let err2 = collapse_errors err2 in - if string_of_type_error err1 = string_of_type_error err2 then - err1 - else - Err_inner (err1, l, prefix, hint, err2) + let err1 = collapse_errors err1 in + let err2 = collapse_errors err2 in + if string_of_type_error err1 = string_of_type_error err2 then err1 else Err_inner (err1, l, prefix, hint, err2) | err -> err let check_defs : Env.t -> uannot def list -> tannot def list * Env.t = - fun env defs -> - try Type_check.check_defs env defs with - | Type_error (env, l, err) -> - raise (Reporting.err_typ l (string_of_type_error err)) + fun env defs -> + try Type_check.check_defs env defs + with Type_error (env, l, err) -> raise (Reporting.err_typ l (string_of_type_error err)) let check : Env.t -> uannot ast -> tannot ast * Env.t = - fun env defs -> - try Type_check.check env defs with - | Type_error (env, l, err) -> raise (Reporting.err_typ l (string_of_type_error err)) + fun env defs -> + try Type_check.check env defs with Type_error (env, l, err) -> raise (Reporting.err_typ l (string_of_type_error err)) diff --git a/src/lib/type_error.mli b/src/lib/type_error.mli index ead41881d..5baeacb7d 100644 --- a/src/lib/type_error.mli +++ b/src/lib/type_error.mli @@ -80,28 +80,17 @@ val opt_explain_all_variables : bool ref (** If false (default), we'll list relevant constraints, but not go into detail about how they were derived *) val opt_explain_constraints : bool ref - -type suggestion = - | Suggest_add_constraint of Ast.n_constraint - | Suggest_none + +type suggestion = Suggest_add_constraint of Ast.n_constraint | Suggest_none (** Analyze an unresolved quantifier type error *) -val analyze_unresolved_quant : - (Ast_util.mut * Ast.typ) Ast_util.Bindings.t -> - Ast.n_constraint list -> - Ast.quant_item -> - suggestion +val analyze_unresolved_quant : + (Ast_util.mut * Ast.typ) Ast_util.Bindings.t -> Ast.n_constraint list -> Ast.quant_item -> suggestion val collapse_errors : Type_check.type_error -> Type_check.type_error - + val string_of_type_error : Type_check.type_error -> string -val check_defs : - Type_check.Env.t -> - uannot Ast.def list -> - Type_check.tannot Ast.def list * Type_check.Env.t +val check_defs : Type_check.Env.t -> uannot Ast.def list -> Type_check.tannot Ast.def list * Type_check.Env.t -val check : - Type_check.Env.t -> - uannot Ast_defs.ast -> - Type_check.tannot Ast_defs.ast * Type_check.Env.t +val check : Type_check.Env.t -> uannot Ast_defs.ast -> Type_check.tannot Ast_defs.ast * Type_check.Env.t diff --git a/src/lib/util.ml b/src/lib/util.ml index 37106e971..f58538f27 100644 --- a/src/lib/util.ml +++ b/src/lib/util.ml @@ -114,321 +114,273 @@ let opt_colors = ref true let opt_verbosity = ref 0 -let rec last = function - | [x] -> x - | _ :: xs -> last xs - | [] -> raise (Failure "last") - -let rec last_opt = function - | [x] -> Some x - | _ :: xs -> last_opt xs - | [] -> None - -let rec butlast = function - | [_] -> [] - | x :: xs -> x :: butlast xs - | [] -> [] - -module Duplicate(S : Set.S) = struct - -type dups = - | No_dups of S.t - | Has_dups of S.elt - -let duplicates (x : S.elt list) : dups = - let rec f x acc = match x with - | [] -> No_dups(acc) - | s::rest -> - if S.mem s acc then - Has_dups(s) - else - f rest (S.add s acc) - in +let rec last = function [x] -> x | _ :: xs -> last xs | [] -> raise (Failure "last") + +let rec last_opt = function [x] -> Some x | _ :: xs -> last_opt xs | [] -> None + +let rec butlast = function [_] -> [] | x :: xs -> x :: butlast xs | [] -> [] + +module Duplicate (S : Set.S) = struct + type dups = No_dups of S.t | Has_dups of S.elt + + let duplicates (x : S.elt list) : dups = + let rec f x acc = + match x with [] -> No_dups acc | s :: rest -> if S.mem s acc then Has_dups s else f rest (S.add s acc) + in f x S.empty end let remove_duplicates l = let l' = List.sort Stdlib.compare l in - let rec aux acc l = match (acc, l) with - (_, []) -> List.rev acc - | ([], x :: xs) -> aux [x] xs - | (y::ys, x :: xs) -> if (x = y) then aux (y::ys) xs else aux (x::y::ys) xs + let rec aux acc l = + match (acc, l) with + | _, [] -> List.rev acc + | [], x :: xs -> aux [x] xs + | y :: ys, x :: xs -> if x = y then aux (y :: ys) xs else aux (x :: y :: ys) xs in aux [] l' let remove_dups compare eq l = let l' = List.sort compare l in - let rec aux acc l = match (acc, l) with - (_, []) -> List.rev acc - | ([], x :: xs) -> aux [x] xs - | (y::ys, x :: xs) -> if (eq x y) then aux (y::ys) xs else aux (x::y::ys) xs + let rec aux acc l = + match (acc, l) with + | _, [] -> List.rev acc + | [], x :: xs -> aux [x] xs + | y :: ys, x :: xs -> if eq x y then aux (y :: ys) xs else aux (x :: y :: ys) xs in aux [] l' let lex_ord_list comparison xs ys = let rec lex_lists xs ys = - match xs, ys with + match (xs, ys) with | x :: xs, y :: ys -> - let c = comparison x y in - if c = 0 then lex_lists xs ys else c + let c = comparison x y in + if c = 0 then lex_lists xs ys else c | [], [] -> 0 | _, _ -> assert false in - if List.length xs = List.length ys then - lex_lists xs ys - else if List.length xs < List.length ys then - -1 - else - 1 - -let rec power i tothe = - if tothe <= 0 - then 1 - else i * power i (tothe - 1) + if List.length xs = List.length ys then lex_lists xs ys else if List.length xs < List.length ys then -1 else 1 + +let rec power i tothe = if tothe <= 0 then 1 else i * power i (tothe - 1) let rec assoc_equal_opt eq k l = - match l with - | [] -> None - | (k',v)::l -> if (eq k k') then Some v else assoc_equal_opt eq k l + match l with [] -> None | (k', v) :: l -> if eq k k' then Some v else assoc_equal_opt eq k l let rec assoc_compare_opt cmp k l = - match l with - | [] -> None - | (k',v)::l -> if cmp k k' = 0 then Some v else assoc_compare_opt cmp k l + match l with [] -> None | (k', v) :: l -> if cmp k k' = 0 then Some v else assoc_compare_opt cmp k l let rec compare_list f l1 l2 = - match (l1,l2) with - | ([],[]) -> 0 - | (_,[]) -> 1 - | ([],_) -> -1 - | (x::l1,y::l2) -> - let c = f x y in - if c = 0 then - compare_list f l1 l2 - else - c - -let rec map_last f = function - | [] -> [] - | [x] -> [f true x] - | (x :: xs) -> f false x :: map_last f xs + match (l1, l2) with + | [], [] -> 0 + | _, [] -> 1 + | [], _ -> -1 + | x :: l1, y :: l2 -> + let c = f x y in + if c = 0 then compare_list f l1 l2 else c + +let rec map_last f = function [] -> [] | [x] -> [f true x] | x :: xs -> f false x :: map_last f xs let rec iter_last f = function | [] -> () | [x] -> f true x | x :: xs -> - f false x; - iter_last f xs - + f false x; + iter_last f xs + let rec split_on_char sep str = try let sep_pos = String.index str sep in String.sub str 0 sep_pos :: split_on_char sep (String.sub str (sep_pos + 1) (String.length str - (sep_pos + 1))) - with - | Not_found -> [str] + with Not_found -> [str] let map_changed_default d f l = let rec g = function - | [] -> ([],false) - | x::y -> - let (r,c) = g y in - match f x with - | None -> ((d x)::r,c) - | Some(x') -> (x'::r,true) + | [] -> ([], false) + | x :: y -> ( + let r, c = g y in + match f x with None -> (d x :: r, c) | Some x' -> (x' :: r, true) + ) in - let (r,c) = g l in - if c then - Some(r) - else - None + let r, c = g l in + if c then Some r else None let map_changed f l = map_changed_default (fun x -> x) f l let rec map_split f = function | [] -> ([], []) - | x :: xs -> - match f x with - | Ok x' -> - let (xs', ys') = map_split f xs in - (x' :: xs', ys') - | Error y' -> - let (xs', ys') = map_split f xs in - (xs', y' :: ys') - -let list_empty = function - | [] -> true - | _ -> false - + | x :: xs -> ( + match f x with + | Ok x' -> + let xs', ys' = map_split f xs in + (x' :: xs', ys') + | Error y' -> + let xs', ys' = map_split f xs in + (xs', y' :: ys') + ) + +let list_empty = function [] -> true | _ -> false + let list_index p l = - let rec aux i l = - match l with [] -> None - | (x :: xs) -> if p x then Some i else aux (i+1) xs - in + let rec aux i l = match l with [] -> None | x :: xs -> if p x then Some i else aux (i + 1) xs in aux 0 l -let option_get_exn e = function - | Some(o) -> o - | None -> raise e +let option_get_exn e = function Some o -> o | None -> raise e -let option_cases op f1 f2 = match op with - | Some(o) -> f1 o - | None -> f2 () +let option_cases op f1 f2 = match op with Some o -> f1 o | None -> f2 () -let option_binop f x y = match x, y with - | Some x, Some y -> Some (f x y) - | _ -> None +let option_binop f x y = match (x, y) with Some x, Some y -> Some (f x y) | _ -> None -let rec option_these = function - | Some x :: xs -> x :: option_these xs - | None :: xs -> option_these xs - | [] -> [] +let rec option_these = function Some x :: xs -> x :: option_these xs | None :: xs -> option_these xs | [] -> [] let rec option_all = function | [] -> Some [] | None :: _ -> None - | Some x :: xs -> - begin match option_all xs with - | None -> None - | Some xs -> Some (x :: xs) - end + | Some x :: xs -> begin match option_all xs with None -> None | Some xs -> Some (x :: xs) end let rec map_all (f : 'a -> 'b option) (l : 'a list) : 'b list option = - match l with [] -> Some [] - | x :: xs -> - match (f x) with None -> None - | Some x' -> Option.map (fun xs' -> x' :: xs') (map_all f xs) + match l with + | [] -> Some [] + | x :: xs -> ( + match f x with None -> None | Some x' -> Option.map (fun xs' -> x' :: xs') (map_all f xs) + ) let rec option_first f xL = match xL with - [] -> None - | (x :: xs) -> match f x with None -> option_first f xs | Some s -> Some s + | [] -> None + | x :: xs -> ( + match f x with None -> option_first f xs | Some s -> Some s + ) let list_to_front n l = - if n <= 0 then l else - let rec aux acc n l = - match (n, l) with - (0, x::xs) -> (x :: (List.rev_append acc xs)) - | (n, x::xs) -> aux (x :: acc) (n-1) xs - | (_, []) -> (* should not happen *) raise (Failure "list_to_front") - in aux [] n l + if n <= 0 then l + else ( + let rec aux acc n l = + match (n, l) with + | 0, x :: xs -> x :: List.rev_append acc xs + | n, x :: xs -> aux (x :: acc) (n - 1) xs + | _, [] -> (* should not happen *) raise (Failure "list_to_front") + in + aux [] n l + ) let undo_list_to_front n l = - if n <= 0 then l else - let rec aux acc n y l = - match (n, l) with - (0, xs) -> List.rev_append acc (y::xs) - | (n, x::xs) -> aux (x :: acc) (n-1) y xs - | (_, []) -> List.rev_append acc [y] - in match l with [] -> l | y::xs -> aux [] n y xs + if n <= 0 then l + else ( + let rec aux acc n y l = + match (n, l) with + | 0, xs -> List.rev_append acc (y :: xs) + | n, x :: xs -> aux (x :: acc) (n - 1) y xs + | _, [] -> List.rev_append acc [y] + in + match l with [] -> l | y :: xs -> aux [] n y xs + ) let split_after n l = - if n < 0 then raise (Failure "negative argument to split_after") else - let rec aux acc n ll = match (n, ll) with - (0, _) -> (List.rev acc, ll) - | (n, x :: xs) -> aux (x :: acc) (n-1) xs - | _ -> raise (Failure "index too large") - in aux [] n l + if n < 0 then raise (Failure "negative argument to split_after") + else ( + let rec aux acc n ll = + match (n, ll) with + | 0, _ -> (List.rev acc, ll) + | n, x :: xs -> aux (x :: acc) (n - 1) xs + | _ -> raise (Failure "index too large") + in + aux [] n l + ) let rec split3 = function | (x, y, z) :: xs -> - let (xs, ys, zs) = split3 xs in - (x :: xs, y :: ys, z :: zs) + let xs, ys, zs = split3 xs in + (x :: xs, y :: ys, z :: zs) | [] -> ([], [], []) let rec list_iter_sep (sf : unit -> unit) (f : 'a -> unit) l : unit = match l with - | [] -> () - | [x0] -> f x0 - | (x0 :: x1 :: xs) -> (f x0; sf(); list_iter_sep sf f (x1 :: xs)) + | [] -> () + | [x0] -> f x0 + | x0 :: x1 :: xs -> + f x0; + sf (); + list_iter_sep sf f (x1 :: xs) let string_to_list s = - let rec aux i acc = - if i < 0 then acc - else aux (i-1) (s.[i] :: acc) - in aux (String.length s - 1) [] - -module IntSet = Set.Make( - struct - let compare = Stdlib.compare - type t = int - end ) - -module IntIntSet = Set.Make( - struct - let compare = Stdlib.compare - type t = int * int - end ) + let rec aux i acc = if i < 0 then acc else aux (i - 1) (s.[i] :: acc) in + aux (String.length s - 1) [] + +module IntSet = Set.Make (struct + let compare = Stdlib.compare + type t = int +end) + +module IntIntSet = Set.Make (struct + let compare = Stdlib.compare + type t = int * int +end) let copy_file src dst = let len = 5096 in let b = Bytes.make len ' ' in let read_len = ref 0 in let i = open_in_bin src in - let o = open_out_bin dst in - while (read_len := input i b 0 len; !read_len <> 0) do + let o = open_out_bin dst in + while + read_len := input i b 0 len; + !read_len <> 0 + do output o b 0 !read_len done; close_in i; close_out o let move_file src dst = - if (Sys.file_exists dst) then Sys.remove dst; - try - (* try efficient version *) - Sys.rename src dst - with Sys_error _ -> - begin - (* OK, do it the the hard way *) - copy_file src dst; - Sys.remove src - end + if Sys.file_exists dst then Sys.remove dst; + try (* try efficient version *) + Sys.rename src dst + with Sys_error _ -> + (* OK, do it the the hard way *) + copy_file src dst; + Sys.remove src let input_byte_opt chan = try Some (input_byte chan) with End_of_file -> None let same_content_files file1 file2 : bool = - (Sys.file_exists file1) && (Sys.file_exists file2) && - begin - let s1 = open_in_bin file1 in - let s2 = open_in_bin file2 in - let rec comp s1 s2 = - match (input_byte_opt s1, input_byte_opt s2) with - | None, None -> true - | Some b1, Some b2 -> if b1 = b2 then comp s1 s2 else false - | _, _ -> false - in - let result = comp s1 s2 in - close_in s1; - close_in s2; - result - end + Sys.file_exists file1 && Sys.file_exists file2 + && begin + let s1 = open_in_bin file1 in + let s2 = open_in_bin file2 in + let rec comp s1 s2 = + match (input_byte_opt s1, input_byte_opt s2) with + | None, None -> true + | Some b1, Some b2 -> if b1 = b2 then comp s1 s2 else false + | _, _ -> false + in + let result = comp s1 s2 in + close_in s1; + close_in s2; + result + end (*String formatting *) let rec string_of_list sep string_of = function | [] -> "" | [x] -> string_of x - | x::ls -> (string_of x) ^ sep ^ (string_of_list sep string_of ls) + | x :: ls -> string_of x ^ sep ^ string_of_list sep string_of ls -let string_of_option string_of = function - | None -> "" - | Some x -> string_of x +let string_of_option string_of = function None -> "" | Some x -> string_of x let rec take_drop f = function | [] -> ([], []) - | (x :: xs) when not (f x) -> ([], x :: xs) - | (x :: xs) -> - let (ys, zs) = take_drop f xs in - (x :: ys, zs) + | x :: xs when not (f x) -> ([], x :: xs) + | x :: xs -> + let ys, zs = take_drop f xs in + (x :: ys, zs) + +let rec find_rest_opt f = function [] -> None | x :: xs when f x -> Some (x, xs) | _ :: xs -> find_rest_opt f xs -let rec find_rest_opt f = function - | [] -> None - | x :: xs when f x -> Some (x, xs) - | _ :: xs -> find_rest_opt f xs - let find_next f xs = let rec find_next' f acc = function - | x :: xs when f x -> List.rev acc, Some (x, xs) + | x :: xs when f x -> (List.rev acc, Some (x, xs)) | x :: xs -> find_next' f (x :: acc) xs - | [] -> List.rev acc, None + | [] -> (List.rev acc, None) in find_next' f [] xs @@ -441,16 +393,19 @@ let find_index_opt f xs = find_index_opt' f 0 xs let rec find_map f = function - | x :: xs -> - begin match f x with - | Some y -> Some y - | None -> find_map f xs - end + | x :: xs -> begin match f x with Some y -> Some y | None -> find_map f xs end | [] -> None let fold_left_concat_map f acc xs = - let ys, acc = List.fold_left (fun (ys, acc) x -> let acc, zs = f acc x in (List.rev zs @ ys, acc)) ([], acc) xs in - acc, List.rev ys + let ys, acc = + List.fold_left + (fun (ys, acc) x -> + let acc, zs = f acc x in + (List.rev zs @ ys, acc) + ) + ([], acc) xs + in + (acc, List.rev ys) let rec fold_left_last f acc = function | [] -> acc @@ -458,41 +413,22 @@ let rec fold_left_last f acc = function | x :: xs -> fold_left_last f (f false acc x) xs let fold_left_index f init xs = - let rec go n acc = function - | [] -> acc - | x :: xs -> go (n + 1) (f n acc x) xs - in + let rec go n acc = function [] -> acc | x :: xs -> go (n + 1) (f n acc x) xs in go 0 init xs let fold_left_index_last f init xs = - let rec go n acc = function - | [] -> acc - | [x] -> f n true acc x - | x :: xs -> go (n + 1) (f n false acc x) xs - in + let rec go n acc = function [] -> acc | [x] -> f n true acc x | x :: xs -> go (n + 1) (f n false acc x) xs in go 0 init xs -let rec take n xs = match n, xs with - | 0, _ -> [] - | _, [] -> [] - | n, (x :: xs) -> x :: take (n - 1) xs +let rec take n xs = match (n, xs) with 0, _ -> [] | _, [] -> [] | n, x :: xs -> x :: take (n - 1) xs -let rec drop n xs = match n, xs with - | 0, xs -> xs - | _, [] -> [] - | n, (_ :: xs) -> drop (n - 1) xs +let rec drop n xs = match (n, xs) with 0, xs -> xs | _, [] -> [] | n, _ :: xs -> drop (n - 1) xs let list_init len f = - let rec list_init' len f acc = - if acc >= len then [] - else f acc :: list_init' len f (acc + 1) - in + let rec list_init' len f acc = if acc >= len then [] else f acc :: list_init' len f (acc + 1) in list_init' len f 0 -let termcode n = - if !opt_colors then - "\x1B[" ^ string_of_int n ^ "m" - else "" +let termcode n = if !opt_colors then "\x1B[" ^ string_of_int n ^ "m" else "" let bold str = termcode 1 ^ str let dim str = termcode 2 ^ str @@ -533,53 +469,47 @@ let file_encode_string str = let md5 = Digest.to_hex (Digest.string zstr) in String.lowercase_ascii zstr ^ String.lowercase_ascii md5 -let log_line str line msg = - "\n[" ^ (str ^ ":" ^ string_of_int line |> blue |> clear) ^ "] " ^ msg +let log_line str line msg = "\n[" ^ (str ^ ":" ^ string_of_int line |> blue |> clear) ^ "] " ^ msg -let header str n = "\n" ^ str ^ "\n" ^ String.make (String.length str - 9 * n) '=' +let header str n = "\n" ^ str ^ "\n" ^ String.make (String.length str - (9 * n)) '=' let progress prefix msg n total = - if !opt_verbosity > 0 then - let len = truncate ((float n /. float total) *. 50.0) in - let percent = truncate ((float n /. float total) *. 100.0) in + if !opt_verbosity > 0 then ( + let len = truncate (float n /. float total *. 50.0) in + let percent = truncate (float n /. float total *. 100.0) in let msg = - if String.length msg <= 20 then - msg ^ ")" ^ String.make (20 - String.length msg) ' ' - else - String.sub msg 0 17 ^ "...)" + if String.length msg <= 20 then msg ^ ")" ^ String.make (20 - String.length msg) ' ' + else String.sub msg 0 17 ^ "...)" in - let str = prefix ^ "[" ^ String.make len '=' ^ String.make (50 - len) ' ' ^ "] " - ^ string_of_int percent ^ "%" - ^ " (" ^ msg + let str = + prefix ^ "[" ^ String.make len '=' ^ String.make (50 - len) ' ' ^ "] " ^ string_of_int percent ^ "%" ^ " (" ^ msg in prerr_string str; - if n = total then - prerr_char '\n' - else - prerr_string ("\x1B[" ^ string_of_int (String.length str) ^ "D"); + if n = total then prerr_char '\n' else prerr_string ("\x1B[" ^ string_of_int (String.length str) ^ "D"); flush stderr - else - () + ) + else () let open_output_with_check opt_dir file_name = - let (temp_file_name, o) = Filename.open_temp_file "ll_temp" "" in + let temp_file_name, o = Filename.open_temp_file "ll_temp" "" in let o' = Format.formatter_of_out_channel o in (o', (o, temp_file_name, opt_dir, file_name)) let open_output_with_check_unformatted opt_dir file_name = - let (temp_file_name, o) = Filename.open_temp_file "ll_temp" "" in + let temp_file_name, o = Filename.open_temp_file "ll_temp" "" in (o, temp_file_name, opt_dir, file_name) let always_replace_files = ref true let close_output_with_check (o, temp_file_name, opt_dir, file_name) = let _ = close_out o in - let file_name = match opt_dir with - | None -> file_name - | Some dir -> if Sys.file_exists dir then () - else Unix.mkdir dir 0o775; - Filename.concat dir file_name in - let do_replace = !always_replace_files || (not (same_content_files temp_file_name file_name)) in - let _ = if (not do_replace) then Sys.remove temp_file_name - else move_file temp_file_name file_name in + let file_name = + match opt_dir with + | None -> file_name + | Some dir -> + if Sys.file_exists dir then () else Unix.mkdir dir 0o775; + Filename.concat dir file_name + in + let do_replace = !always_replace_files || not (same_content_files temp_file_name file_name) in + let _ = if not do_replace then Sys.remove temp_file_name else move_file temp_file_name file_name in () diff --git a/src/lib/util.mli b/src/lib/util.mli index 966f1747b..7014cc6a9 100644 --- a/src/lib/util.mli +++ b/src/lib/util.mli @@ -74,14 +74,12 @@ val opt_verbosity : int ref val last : 'a list -> 'a val last_opt : 'a list -> 'a option - + val butlast : 'a list -> 'a list (** Mixed useful things *) -module Duplicate(S : Set.S) : sig - type dups = - | No_dups of S.t - | Has_dups of S.elt +module Duplicate (S : Set.S) : sig + type dups = No_dups of S.t | Has_dups of S.elt val duplicates : S.elt list -> dups end @@ -108,7 +106,7 @@ val power : int -> int -> int val map_last : (bool -> 'a -> 'b) -> 'a list -> 'b list val iter_last : (bool -> 'a -> unit) -> 'a list -> unit - + (** {2 Option Functions} *) (** [option_cases None f_s f_n] returns [f_n], whereas @@ -135,16 +133,16 @@ val option_all : 'a option list -> 'a list option (** {2 List Functions} *) val list_empty : 'a list -> bool - + (** [list_index p l] returns the first index [i] such that the predicate [p (l!i)] holds. If no such [i] exists, [None] is returned. *) -val list_index: ('a -> bool) -> 'a list -> int option +val list_index : ('a -> bool) -> 'a list -> int option (** [option_first f l] searches for the first element [x] of [l] such that the [f x] is not [None]. If such an element exists, [f x] is returned, otherwise [None]. *) -val option_first: ('a -> 'b option) -> 'a list -> 'b option +val option_first : ('a -> 'b option) -> 'a list -> 'b option (** [map_changed f l] maps [f] over [l]. If for all elements of [l] the @@ -165,7 +163,7 @@ val map_changed_default : ('a -> 'b) -> ('a -> 'b option) -> 'a list -> 'b list val list_iter_sep : (unit -> unit) -> ('a -> unit) -> 'a list -> unit val map_split : ('a -> ('b, 'c) result) -> 'a list -> 'b list * 'c list - + (** [map_all f l] maps [f] over [l]. If at least one entry is [None], [None] is returned. Otherwise, the [Some] function is removed from the list. *) val map_all : ('a -> 'b option) -> 'a list -> 'b list option @@ -191,11 +189,11 @@ val compare_list : ('a -> 'b -> int) -> 'a list -> 'b list -> int val take : int -> 'a list -> 'a list val drop : int -> 'a list -> 'a list -val take_drop : ('a -> bool) -> 'a list -> ('a list * 'a list) +val take_drop : ('a -> bool) -> 'a list -> 'a list * 'a list val find_rest_opt : ('a -> bool) -> 'a list -> ('a * 'a list) option -val find_next : ('a -> bool) -> 'a list -> ('a list * ('a * 'a list) option) +val find_next : ('a -> bool) -> 'a list -> 'a list * ('a * 'a list) option (** find an item in a list and return that item as well as its index *) val find_index_opt : ('a -> bool) -> 'a list -> (int * 'a) option @@ -242,6 +240,7 @@ val string_to_list : string -> char list (** Sets of Integers *) module IntSet : Set.S with type elt = int + module IntIntSet : Set.S with type elt = int * int (** {2 Formatting functions} *) @@ -295,9 +294,10 @@ val progress : string -> string -> int -> int -> unit files existed before. If it is set to [false] and an output file already exists, the output file is only updated, if its content really changes. *) val always_replace_files : bool ref - -val open_output_with_check : string option -> string -> (Format.formatter * (out_channel * string * string option * string)) -val open_output_with_check_unformatted : string option -> string -> (out_channel * string * string option * string) +val open_output_with_check : + string option -> string -> Format.formatter * (out_channel * string * string option * string) + +val open_output_with_check_unformatted : string option -> string -> out_channel * string * string option * string -val close_output_with_check : (out_channel * string * string option * string) -> unit +val close_output_with_check : out_channel * string * string option * string -> unit diff --git a/src/lib/value.ml b/src/lib/value.ml index 947428b94..c2f9aac4d 100644 --- a/src/lib/value.ml +++ b/src/lib/value.ml @@ -67,7 +67,7 @@ module Big_int = Nat_big_num -module StringMap = Map.Make(String) +module StringMap = Map.Make (String) let print_chan = ref stdout let print_redirected = ref false @@ -76,11 +76,7 @@ let output_redirect chan = print_chan := chan; print_redirected := true -let output_close () = - if !print_redirected then - close_out !print_chan - else - () +let output_close () = if !print_redirected then close_out !print_chan else () let output str = output_string !print_chan str; @@ -110,13 +106,9 @@ type value = with a direct register read. *) | V_attempted_read of string -let coerce_bit = function - | V_bit b -> b - | _ -> assert false +let coerce_bit = function V_bit b -> b | _ -> assert false -let is_bit = function - | V_bit _ -> true - | _ -> false +let is_bit = function V_bit _ -> true | _ -> false let rec string_of_value = function | V_vector vs when List.for_all is_bit vs -> Sail_lib.string_of_bits (List.map coerce_bit vs) @@ -134,11 +126,13 @@ let rec string_of_value = function | V_real r -> Sail_lib.string_of_real r | V_ctor (str, vals) -> str ^ "(" ^ Util.string_of_list ", " string_of_value vals ^ ")" | V_record record -> - "{" ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) ^ "}" + "{" + ^ Util.string_of_list ", " (fun (field, v) -> field ^ "=" ^ string_of_value v) (StringMap.bindings record) + ^ "}" | V_attempted_read _ -> assert false let rec eq_value v1 v2 = - match v1, v2 with + match (v1, v2) with | V_vector v1s, V_vector v2s when List.length v1s = List.length v2s -> List.for_all2 eq_value v1s v2s | V_list v1s, V_list v2s when List.length v1s = List.length v2s -> List.for_all2 eq_value v1s v2s | V_int n, V_int m -> Big_int.equal n m @@ -150,81 +144,45 @@ let rec eq_value v1 v2 = | V_string str1, V_string str2 -> str1 = str2 | V_ref str1, V_ref str2 -> str1 = str2 | V_ctor (name1, fields1), V_ctor (name2, fields2) when List.length fields1 = List.length fields2 -> - name1 = name2 && List.for_all2 eq_value fields1 fields2 - | V_record fields1, V_record fields2 -> - StringMap.equal eq_value fields1 fields2 + name1 = name2 && List.for_all2 eq_value fields1 fields2 + | V_record fields1, V_record fields2 -> StringMap.equal eq_value fields1 fields2 | _, _ -> false -let coerce_ctor = function - | V_ctor (str, vals) -> (str, vals) - | _ -> assert false +let coerce_ctor = function V_ctor (str, vals) -> (str, vals) | _ -> assert false -let coerce_bool = function - | V_bool b -> b - | _ -> assert false +let coerce_bool = function V_bool b -> b | _ -> assert false -let coerce_record = function - | V_record record -> record - | _ -> assert false +let coerce_record = function V_record record -> record | _ -> assert false -let and_bool = function - | [v1; v2] -> V_bool (coerce_bool v1 && coerce_bool v2) - | _ -> assert false +let and_bool = function [v1; v2] -> V_bool (coerce_bool v1 && coerce_bool v2) | _ -> assert false -let or_bool = function - | [v1; v2] -> V_bool (coerce_bool v1 || coerce_bool v2) - | _ -> assert false +let or_bool = function [v1; v2] -> V_bool (coerce_bool v1 || coerce_bool v2) | _ -> assert false let tuple_value (vs : value list) : value = V_tuple vs let mk_vector (bits : Sail_lib.bit list) : value = V_vector (List.map (fun bit -> V_bit bit) bits) -let coerce_bit = function - | V_bit b -> b - | _ -> assert false +let coerce_bit = function V_bit b -> b | _ -> assert false -let coerce_tuple = function - | V_tuple vs -> vs - | _ -> assert false +let coerce_tuple = function V_tuple vs -> vs | _ -> assert false -let coerce_list = function - | V_list vs -> vs - | _ -> assert false +let coerce_list = function V_list vs -> vs | _ -> assert false -let coerce_listlike = function - | V_tuple vs -> vs - | V_list vs -> vs - | V_unit -> [] - | _ -> assert false +let coerce_listlike = function V_tuple vs -> vs | V_list vs -> vs | V_unit -> [] | _ -> assert false -let coerce_int = function - | V_int i -> i - | _ -> assert false +let coerce_int = function V_int i -> i | _ -> assert false -let coerce_real = function - | V_real r -> r - | _ -> assert false +let coerce_real = function V_real r -> r | _ -> assert false -let coerce_cons = function - | V_list (v :: vs) -> Some (v, vs) - | V_list [] -> None - | _ -> assert false +let coerce_cons = function V_list (v :: vs) -> Some (v, vs) | V_list [] -> None | _ -> assert false -let coerce_gv = function - | V_vector vs -> vs - | _ -> assert false +let coerce_gv = function V_vector vs -> vs | _ -> assert false -let coerce_bv = function - | V_vector vs -> List.map coerce_bit vs - | _ -> assert false +let coerce_bv = function V_vector vs -> List.map coerce_bit vs | _ -> assert false -let coerce_string = function - | V_string str -> str - | _ -> assert false +let coerce_string = function V_string str -> str | _ -> assert false -let coerce_ref = function - | V_ref str -> str - | _ -> assert false +let coerce_ref = function V_ref str -> str | _ -> assert false let unit_value = V_unit @@ -244,13 +202,9 @@ let value_gteq = function | [v1; v2] -> V_bool (Sail_lib.gteq (coerce_int v1, coerce_int v2)) | _ -> failwith "value gteq" -let value_lt = function - | [v1; v2] -> V_bool (Sail_lib.lt (coerce_int v1, coerce_int v2)) - | _ -> failwith "value lt" +let value_lt = function [v1; v2] -> V_bool (Sail_lib.lt (coerce_int v1, coerce_int v2)) | _ -> failwith "value lt" -let value_gt = function - | [v1; v2] -> V_bool (Sail_lib.gt (coerce_int v1, coerce_int v2)) - | _ -> failwith "value gt" +let value_gt = function [v1; v2] -> V_bool (Sail_lib.gt (coerce_int v1, coerce_int v2)) | _ -> failwith "value gt" let value_eq_list = function | [v1; v2] -> V_bool (Sail_lib.eq_list (coerce_bv v1, coerce_bv v2)) @@ -279,17 +233,13 @@ let value_eq_bit = function | [v1; v2] -> V_bool (Sail_lib.eq_bit (coerce_bit v1, coerce_bit v2)) | _ -> failwith "value eq_bit" -let value_length = function - | [v] -> V_int (coerce_gv v |> List.length |> Big_int.of_int) - | _ -> failwith "value length" +let value_length = function [v] -> V_int (coerce_gv v |> List.length |> Big_int.of_int) | _ -> failwith "value length" let value_subrange = function | [v1; v2; v3] -> mk_vector (Sail_lib.subrange (coerce_bv v1, coerce_int v2, coerce_int v3)) | _ -> failwith "value subrange" -let value_access = function - | [v1; v2] -> Sail_lib.access (coerce_gv v1, coerce_int v2) - | _ -> failwith "value access" +let value_access = function [v1; v2] -> Sail_lib.access (coerce_gv v1, coerce_int v2) | _ -> failwith "value access" let value_update = function | [v1; v2; v3] -> V_vector (Sail_lib.update (coerce_gv v1, coerce_int v2, v3)) @@ -299,9 +249,7 @@ let value_update_subrange = function | [v1; v2; v3; v4] -> mk_vector (Sail_lib.update_subrange (coerce_bv v1, coerce_int v2, coerce_int v3, coerce_bv v4)) | _ -> failwith "value update_subrange" -let value_append = function - | [v1; v2] -> V_vector (coerce_gv v1 @ coerce_gv v2) - | _ -> failwith "value append" +let value_append = function [v1; v2] -> V_vector (coerce_gv v1 @ coerce_gv v2) | _ -> failwith "value append" let value_append_list = function | [v1; v2] -> V_list (coerce_list v1 @ coerce_list v2) @@ -311,13 +259,9 @@ let value_slice = function | [v1; v2; v3] -> V_vector (Sail_lib.slice (coerce_gv v1, coerce_int v2, coerce_int v3)) | _ -> failwith "value slice" -let value_not = function - | [v] -> V_bool (not (coerce_bool v)) - | _ -> failwith "value not" +let value_not = function [v] -> V_bool (not (coerce_bool v)) | _ -> failwith "value not" -let value_not_vec = function - | [v] -> mk_vector (Sail_lib.not_vec (coerce_bv v)) - | _ -> failwith "value not_vec" +let value_not_vec = function [v] -> mk_vector (Sail_lib.not_vec (coerce_bv v)) | _ -> failwith "value not_vec" let value_and_vec = function | [v1; v2] -> mk_vector (Sail_lib.and_vec (coerce_bv v1, coerce_bv v2)) @@ -331,31 +275,25 @@ let value_xor_vec = function | [v1; v2] -> mk_vector (Sail_lib.xor_vec (coerce_bv v1, coerce_bv v2)) | _ -> failwith "value xor_vec" -let value_uint = function - | [v] -> V_int (Sail_lib.uint (coerce_bv v)) - | _ -> failwith "value uint" +let value_uint = function [v] -> V_int (Sail_lib.uint (coerce_bv v)) | _ -> failwith "value uint" -let value_sint = function - | [v] -> V_int (Sail_lib.sint (coerce_bv v)) - | _ -> failwith "value sint" +let value_sint = function [v] -> V_int (Sail_lib.sint (coerce_bv v)) | _ -> failwith "value sint" let value_get_slice_int = function | [v1; v2; v3] -> mk_vector (Sail_lib.get_slice_int (coerce_int v1, coerce_int v2, coerce_int v3)) | _ -> failwith "value get_slice_int" let value_set_slice_int = function - | [v1; v2; v3; v4] -> - V_int (Sail_lib.set_slice_int (coerce_int v1, coerce_int v2, coerce_int v3, coerce_bv v4)) + | [v1; v2; v3; v4] -> V_int (Sail_lib.set_slice_int (coerce_int v1, coerce_int v2, coerce_int v3, coerce_bv v4)) | _ -> failwith "value set_slice_int" let value_set_slice = function | [v1; v2; v3; v4; v5] -> - mk_vector (Sail_lib.set_slice (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_int v4, coerce_bv v5)) + mk_vector (Sail_lib.set_slice (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_int v4, coerce_bv v5)) | _ -> failwith "value set_slice" let value_hex_slice = function - | [v1; v2; v3] -> - mk_vector (Sail_lib.hex_slice (coerce_string v1, coerce_int v2, coerce_int v3)) + | [v1; v2; v3] -> mk_vector (Sail_lib.hex_slice (coerce_string v1, coerce_int v2, coerce_int v3)) | _ -> failwith "value hex_slice" let value_add_int = function @@ -370,13 +308,9 @@ let value_sub_nat = function | [v1; v2] -> V_int (Sail_lib.sub_nat (coerce_int v1, coerce_int v2)) | _ -> failwith "value sub_nat" -let value_negate = function - | [v1] -> V_int (Sail_lib.negate (coerce_int v1)) - | _ -> failwith "value negate" +let value_negate = function [v1] -> V_int (Sail_lib.negate (coerce_int v1)) | _ -> failwith "value negate" -let value_pow2 = function - | [v1] -> V_int (Sail_lib.pow2 (coerce_int v1)) - | _ -> failwith "value pow2" +let value_pow2 = function [v1] -> V_int (Sail_lib.pow2 (coerce_int v1)) | _ -> failwith "value pow2" let value_int_power = function | [v1; v2] -> V_int (Sail_lib.int_power (coerce_int v1, coerce_int v2)) @@ -402,9 +336,7 @@ let value_modulus = function | [v1; v2] -> V_int (Sail_lib.modulus (coerce_int v1, coerce_int v2)) | _ -> failwith "value modulus" -let value_abs_int = function - | [v] -> V_int (Big_int.abs (coerce_int v)) - | _ -> failwith "value abs_int" +let value_abs_int = function [v] -> V_int (Big_int.abs (coerce_int v)) | _ -> failwith "value abs_int" let value_add_vec_int = function | [v1; v2] -> mk_vector (Sail_lib.add_vec_int (coerce_bv v1, coerce_int v2)) @@ -446,9 +378,7 @@ let value_count_leading_zeros = function | [v1] -> V_int (Sail_lib.count_leading_zeros (coerce_bv v1)) | _ -> failwith "value count_leading_zeros" -let is_ctor = function - | V_ctor _ -> true - | _ -> false +let is_ctor = function V_ctor _ -> true | _ -> false let value_sign_extend = function | [v1; v2] -> mk_vector (Sail_lib.sign_extend (coerce_bv v1, coerce_int v2)) @@ -458,13 +388,9 @@ let value_zero_extend = function | [v1; v2] -> mk_vector (Sail_lib.zero_extend (coerce_bv v1, coerce_int v2)) | _ -> failwith "value zero_extend" -let value_zeros = function - | [v] -> mk_vector (Sail_lib.zeros (coerce_int v)) - | _ -> failwith "value zeros" +let value_zeros = function [v] -> mk_vector (Sail_lib.zeros (coerce_int v)) | _ -> failwith "value zeros" -let value_ones = function - | [v] -> mk_vector (Sail_lib.ones (coerce_int v)) - | _ -> failwith "value ones" +let value_ones = function [v] -> mk_vector (Sail_lib.ones (coerce_int v)) | _ -> failwith "value ones" let value_shiftl = function | [v1; v2] -> mk_vector (Sail_lib.shiftl (coerce_bv v1, coerce_int v2)) @@ -494,34 +420,36 @@ let value_vector_truncateLSB = function | [v1; v2] -> mk_vector (Sail_lib.vector_truncateLSB (coerce_bv v1, coerce_int v2)) | _ -> failwith "value vector_truncateLSB" -let value_eq_anything = function - | [v1; v2] -> V_bool (eq_value v1 v2) - | _ -> failwith "value eq_anything" +let value_eq_anything = function [v1; v2] -> V_bool (eq_value v1 v2) | _ -> failwith "value eq_anything" let value_print = function - | [V_string str] -> output str; V_unit - | [v] -> output (string_of_value v |> Util.red |> Util.clear); V_unit + | [V_string str] -> + output str; + V_unit + | [v] -> + output (string_of_value v |> Util.red |> Util.clear); + V_unit | _ -> assert false let value_print_endline = function - | [V_string str] -> output_endline str; V_unit - | [v] -> output_endline (string_of_value v |> Util.red |> Util.clear); V_unit + | [V_string str] -> + output_endline str; + V_unit + | [v] -> + output_endline (string_of_value v |> Util.red |> Util.clear); + V_unit | _ -> assert false -let value_internal_pick = function - | [v1] -> List.hd (coerce_listlike v1); - | _ -> failwith "value internal_pick" +let value_internal_pick = function [v1] -> List.hd (coerce_listlike v1) | _ -> failwith "value internal_pick" let value_undefined_vector = function | [v1; v2] -> V_vector (Sail_lib.undefined_vector (coerce_int v1, v2)) | _ -> failwith "value undefined_vector" -let value_undefined_list = function - | [_] -> V_list [] - | _ -> failwith "value undefined_list" +let value_undefined_list = function [_] -> V_list [] | _ -> failwith "value undefined_list" let value_undefined_bitvector = function - | [v] -> V_vector (Sail_lib.undefined_vector (coerce_int v, V_bit (Sail_lib.B0))) + | [v] -> V_vector (Sail_lib.undefined_vector (coerce_int v, V_bit Sail_lib.B0)) | _ -> failwith "value undefined_bitvector" let value_read_ram = function @@ -530,80 +458,84 @@ let value_read_ram = function let value_write_ram = function | [v1; v2; v3; v4; v5] -> - let b = Sail_lib.write_ram (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_bv v4, coerce_bv v5) in - V_bool(b) + let b = Sail_lib.write_ram (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_bv v4, coerce_bv v5) in + V_bool b | _ -> failwith "value write_ram" let value_load_raw = function - | [v1; v2] -> Sail_lib.load_raw (coerce_bv v1, coerce_string v2) ; V_unit + | [v1; v2] -> + Sail_lib.load_raw (coerce_bv v1, coerce_string v2); + V_unit | _ -> failwith "value load_raw" let value_putchar = function | [v] -> - output_char !print_chan (char_of_int (Big_int.to_int (coerce_int v))); - flush !print_chan; - V_unit + output_char !print_chan (char_of_int (Big_int.to_int (coerce_int v))); + flush !print_chan; + V_unit | _ -> failwith "value putchar" -let value_dec_str = function - | [n] -> V_string (string_of_value n) - | _ -> failwith "value print_int" +let value_dec_str = function [n] -> V_string (string_of_value n) | _ -> failwith "value print_int" let value_print_bits = function - | [msg; bits] -> output_endline (coerce_string msg ^ string_of_value bits); V_unit + | [msg; bits] -> + output_endline (coerce_string msg ^ string_of_value bits); + V_unit | _ -> failwith "value print_bits" let value_print_int = function - | [msg; n] -> output_endline (coerce_string msg ^ string_of_value n); V_unit + | [msg; n] -> + output_endline (coerce_string msg ^ string_of_value n); + V_unit | _ -> failwith "value print_int" let value_print_string = function - | [msg; str] -> output_endline (coerce_string msg ^ coerce_string str); V_unit + | [msg; str] -> + output_endline (coerce_string msg ^ coerce_string str); + V_unit | _ -> failwith "value print_string" let value_prerr_bits = function - | [msg; bits] -> prerr_endline (coerce_string msg ^ string_of_value bits); V_unit + | [msg; bits] -> + prerr_endline (coerce_string msg ^ string_of_value bits); + V_unit | _ -> failwith "value prerr_bits" let value_prerr_int = function - | [msg; n] -> prerr_endline (coerce_string msg ^ string_of_value n); V_unit + | [msg; n] -> + prerr_endline (coerce_string msg ^ string_of_value n); + V_unit | _ -> failwith "value prerr_int" let value_prerr_string = function - | [msg; str] -> output_endline (coerce_string msg ^ coerce_string str); V_unit + | [msg; str] -> + output_endline (coerce_string msg ^ coerce_string str); + V_unit | _ -> failwith "value print_string" let value_concat_str = function | [v1; v2] -> V_string (Sail_lib.concat_str (coerce_string v1, coerce_string v2)) | _ -> failwith "value concat_str" -let value_to_real = function - | [v] -> V_real (Sail_lib.to_real (coerce_int v)) - | _ -> failwith "value to_real" +let value_to_real = function [v] -> V_real (Sail_lib.to_real (coerce_int v)) | _ -> failwith "value to_real" let value_print_real = function - | [v1; v2] -> output_endline (coerce_string v1 ^ string_of_value v2); V_unit + | [v1; v2] -> + output_endline (coerce_string v1 ^ string_of_value v2); + V_unit | _ -> failwith "value print_real" -let value_random_real = function - | [_] -> V_real (Sail_lib.random_real ()) - | _ -> failwith "value random_real" +let value_random_real = function [_] -> V_real (Sail_lib.random_real ()) | _ -> failwith "value random_real" -let value_sqrt_real = function - | [v] -> V_real (Sail_lib.sqrt_real (coerce_real v)) - | _ -> failwith "value sqrt_real" +let value_sqrt_real = function [v] -> V_real (Sail_lib.sqrt_real (coerce_real v)) | _ -> failwith "value sqrt_real" let value_quotient_real = function | [v1; v2] -> V_real (Sail_lib.quotient_real (coerce_real v1, coerce_real v2)) | _ -> failwith "value quotient_real" -let value_round_up = function - | [v] -> V_int (Sail_lib.round_up (coerce_real v)) - | _ -> failwith "value round_up" +let value_round_up = function [v] -> V_int (Sail_lib.round_up (coerce_real v)) | _ -> failwith "value round_up" -let value_round_down = function - | [v] -> V_int (Sail_lib.round_down (coerce_real v)) - | _ -> failwith "value round_down" +let value_round_down = function [v] -> V_int (Sail_lib.round_down (coerce_real v)) | _ -> failwith "value round_down" let value_quot_round_zero = function | [v1; v2] -> V_int (Sail_lib.quot_round_zero (coerce_int v1, coerce_int v2)) @@ -629,9 +561,7 @@ let value_div_real = function | [v1; v2] -> V_real (Sail_lib.div_real (coerce_real v1, coerce_real v2)) | _ -> failwith "value div_real" -let value_abs_real = function - | [v] -> V_real (Sail_lib.abs_real (coerce_real v)) - | _ -> failwith "value abs_real" +let value_abs_real = function [v] -> V_real (Sail_lib.abs_real (coerce_real v)) | _ -> failwith "value abs_real" let value_eq_real = function | [v1; v2] -> V_bool (Sail_lib.eq_real (coerce_real v1, coerce_real v2)) @@ -661,139 +591,149 @@ let value_decimal_string_of_bits = function | [v] -> V_string (Sail_lib.decimal_string_of_bits (coerce_bv v)) | _ -> failwith "value decimal_string_of_bits" -let primops = ref - (List.fold_left - (fun r (x, y) -> StringMap.add x y r) - StringMap.empty - [ ("and_bool", and_bool); - ("or_bool", or_bool); - ("print", value_print); - ("prerr", fun vs -> (prerr_string (string_of_value (List.hd vs)); V_unit)); - ("dec_str", value_dec_str); - ("print_endline", value_print_endline); - ("prerr_endline", fun vs -> (prerr_endline (string_of_value (List.hd vs)); V_unit)); - ("putchar", value_putchar); - ("string_of_int", fun vs -> V_string (string_of_value (List.hd vs))); - ("string_of_bits", fun vs -> V_string (string_of_value (List.hd vs))); - ("decimal_string_of_bits", value_decimal_string_of_bits); - ("print_bits", value_print_bits); - ("print_int", value_print_int); - ("print_string", value_print_string); - ("prerr_bits", value_print_bits); - ("prerr_int", value_print_int); - ("prerr_string", value_prerr_string); - ("concat_str", value_concat_str); - ("eq_int", value_eq_int); - ("lteq", value_lteq); - ("gteq", value_gteq); - ("lt", value_lt); - ("gt", value_gt); - ("eq_list", value_eq_list); - ("eq_bool", value_eq_bool); - ("eq_string", value_eq_string); - ("string_startswith", value_string_startswith); - ("string_drop", value_string_drop); - ("string_take", value_string_take); - ("string_length", value_string_length); - ("eq_bit", value_eq_bit); - ("eq_anything", value_eq_anything); - ("length", value_length); - ("subrange", value_subrange); - ("access", value_access); - ("update", value_update); - ("update_subrange", value_update_subrange); - ("slice", value_slice); - ("append", value_append); - ("append_list", value_append_list); - ("not", value_not); - ("not_vec", value_not_vec); - ("and_vec", value_and_vec); - ("or_vec", value_or_vec); - ("xor_vec", value_xor_vec); - ("uint", value_uint); - ("sint", value_sint); - ("get_slice_int", value_get_slice_int); - ("set_slice_int", value_set_slice_int); - ("set_slice", value_set_slice); - ("hex_slice", value_hex_slice); - ("zero_extend", value_zero_extend); - ("sign_extend", value_sign_extend); - ("zeros", value_zeros); - ("ones", value_ones); - ("shiftr", value_shiftr); - ("shiftl", value_shiftl); - ("arith_shiftr", value_arith_shiftr); - ("shift_bits_left", value_shift_bits_left); - ("shift_bits_right", value_shift_bits_right); - ("add_int", value_add_int); - ("sub_int", value_sub_int); - ("sub_nat", value_sub_nat); - ("div_int", value_quotient); - ("tdiv_int", value_tdiv_int); - ("tmod_int", value_tmod_int); - ("mult_int", value_mult); - ("mult", value_mult); - ("quotient", value_quotient); - ("modulus", value_modulus); - ("negate", value_negate); - ("pow2", value_pow2); - ("int_power", value_int_power); - ("shr_int", value_shr_int); - ("shl_int", value_shl_int); - ("max_int", value_max_int); - ("min_int", value_min_int); - ("abs_int", value_abs_int); - ("add_vec_int", value_add_vec_int); - ("sub_vec_int", value_sub_vec_int); - ("add_vec", value_add_vec); - ("sub_vec", value_sub_vec); - ("vector_truncate", value_vector_truncate); - ("vector_truncateLSB", value_vector_truncateLSB); - ("read_ram", value_read_ram); - ("write_ram", value_write_ram); - ("trace_memory_read", fun _ -> V_unit); - ("trace_memory_write", fun _ -> V_unit); - ("get_time_ns", fun _ -> V_int (Sail_lib.get_time_ns())); - ("load_raw", value_load_raw); - ("to_real", value_to_real); - ("eq_real", value_eq_real); - ("lt_real", value_lt_real); - ("gt_real", value_gt_real); - ("lteq_real", value_lteq_real); - ("gteq_real", value_gteq_real); - ("add_real", value_add_real); - ("sub_real", value_sub_real); - ("mult_real", value_mult_real); - ("round_up", value_round_up); - ("round_down", value_round_down); - ("quot_round_zero", value_quot_round_zero); - ("rem_round_zero", value_rem_round_zero); - ("quotient_real", value_quotient_real); - ("abs_real", value_abs_real); - ("div_real", value_div_real); - ("sqrt_real", value_sqrt_real); - ("print_real", value_print_real); - ("random_real", value_random_real); - ("undefined_unit", fun _ -> V_unit); - ("undefined_bit", fun _ -> V_bit Sail_lib.B0); - ("undefined_int", fun _ -> V_int Big_int.zero); - ("undefined_nat", fun _ -> V_int Big_int.zero); - ("undefined_bool", fun _ -> V_bool false); - ("undefined_bitvector", value_undefined_bitvector); - ("undefined_vector", value_undefined_vector); - ("undefined_list", value_undefined_list); - ("undefined_string", fun _ -> V_string ""); - ("internal_pick", value_internal_pick); - ("replicate_bits", value_replicate_bits); - ("count_leading_zeros", value_count_leading_zeros); - ("Elf_loader.elf_entry", fun _ -> V_int (!Elf_loader.opt_elf_entry)); - ("Elf_loader.elf_tohost", fun _ -> V_int (!Elf_loader.opt_elf_tohost)); - ("string_append", value_string_append); - ("string_length", value_string_length); - ("string_startswith", value_string_startswith); - ("string_drop", value_string_drop); - ("skip", fun _ -> V_unit); - ]) - -let add_primop name impl = - primops := StringMap.add name impl !primops +let primops = + ref + (List.fold_left + (fun r (x, y) -> StringMap.add x y r) + StringMap.empty + [ + ("and_bool", and_bool); + ("or_bool", or_bool); + ("print", value_print); + ( "prerr", + fun vs -> + prerr_string (string_of_value (List.hd vs)); + V_unit + ); + ("dec_str", value_dec_str); + ("print_endline", value_print_endline); + ( "prerr_endline", + fun vs -> + prerr_endline (string_of_value (List.hd vs)); + V_unit + ); + ("putchar", value_putchar); + ("string_of_int", fun vs -> V_string (string_of_value (List.hd vs))); + ("string_of_bits", fun vs -> V_string (string_of_value (List.hd vs))); + ("decimal_string_of_bits", value_decimal_string_of_bits); + ("print_bits", value_print_bits); + ("print_int", value_print_int); + ("print_string", value_print_string); + ("prerr_bits", value_print_bits); + ("prerr_int", value_print_int); + ("prerr_string", value_prerr_string); + ("concat_str", value_concat_str); + ("eq_int", value_eq_int); + ("lteq", value_lteq); + ("gteq", value_gteq); + ("lt", value_lt); + ("gt", value_gt); + ("eq_list", value_eq_list); + ("eq_bool", value_eq_bool); + ("eq_string", value_eq_string); + ("string_startswith", value_string_startswith); + ("string_drop", value_string_drop); + ("string_take", value_string_take); + ("string_length", value_string_length); + ("eq_bit", value_eq_bit); + ("eq_anything", value_eq_anything); + ("length", value_length); + ("subrange", value_subrange); + ("access", value_access); + ("update", value_update); + ("update_subrange", value_update_subrange); + ("slice", value_slice); + ("append", value_append); + ("append_list", value_append_list); + ("not", value_not); + ("not_vec", value_not_vec); + ("and_vec", value_and_vec); + ("or_vec", value_or_vec); + ("xor_vec", value_xor_vec); + ("uint", value_uint); + ("sint", value_sint); + ("get_slice_int", value_get_slice_int); + ("set_slice_int", value_set_slice_int); + ("set_slice", value_set_slice); + ("hex_slice", value_hex_slice); + ("zero_extend", value_zero_extend); + ("sign_extend", value_sign_extend); + ("zeros", value_zeros); + ("ones", value_ones); + ("shiftr", value_shiftr); + ("shiftl", value_shiftl); + ("arith_shiftr", value_arith_shiftr); + ("shift_bits_left", value_shift_bits_left); + ("shift_bits_right", value_shift_bits_right); + ("add_int", value_add_int); + ("sub_int", value_sub_int); + ("sub_nat", value_sub_nat); + ("div_int", value_quotient); + ("tdiv_int", value_tdiv_int); + ("tmod_int", value_tmod_int); + ("mult_int", value_mult); + ("mult", value_mult); + ("quotient", value_quotient); + ("modulus", value_modulus); + ("negate", value_negate); + ("pow2", value_pow2); + ("int_power", value_int_power); + ("shr_int", value_shr_int); + ("shl_int", value_shl_int); + ("max_int", value_max_int); + ("min_int", value_min_int); + ("abs_int", value_abs_int); + ("add_vec_int", value_add_vec_int); + ("sub_vec_int", value_sub_vec_int); + ("add_vec", value_add_vec); + ("sub_vec", value_sub_vec); + ("vector_truncate", value_vector_truncate); + ("vector_truncateLSB", value_vector_truncateLSB); + ("read_ram", value_read_ram); + ("write_ram", value_write_ram); + ("trace_memory_read", fun _ -> V_unit); + ("trace_memory_write", fun _ -> V_unit); + ("get_time_ns", fun _ -> V_int (Sail_lib.get_time_ns ())); + ("load_raw", value_load_raw); + ("to_real", value_to_real); + ("eq_real", value_eq_real); + ("lt_real", value_lt_real); + ("gt_real", value_gt_real); + ("lteq_real", value_lteq_real); + ("gteq_real", value_gteq_real); + ("add_real", value_add_real); + ("sub_real", value_sub_real); + ("mult_real", value_mult_real); + ("round_up", value_round_up); + ("round_down", value_round_down); + ("quot_round_zero", value_quot_round_zero); + ("rem_round_zero", value_rem_round_zero); + ("quotient_real", value_quotient_real); + ("abs_real", value_abs_real); + ("div_real", value_div_real); + ("sqrt_real", value_sqrt_real); + ("print_real", value_print_real); + ("random_real", value_random_real); + ("undefined_unit", fun _ -> V_unit); + ("undefined_bit", fun _ -> V_bit Sail_lib.B0); + ("undefined_int", fun _ -> V_int Big_int.zero); + ("undefined_nat", fun _ -> V_int Big_int.zero); + ("undefined_bool", fun _ -> V_bool false); + ("undefined_bitvector", value_undefined_bitvector); + ("undefined_vector", value_undefined_vector); + ("undefined_list", value_undefined_list); + ("undefined_string", fun _ -> V_string ""); + ("internal_pick", value_internal_pick); + ("replicate_bits", value_replicate_bits); + ("count_leading_zeros", value_count_leading_zeros); + ("Elf_loader.elf_entry", fun _ -> V_int !Elf_loader.opt_elf_entry); + ("Elf_loader.elf_tohost", fun _ -> V_int !Elf_loader.opt_elf_tohost); + ("string_append", value_string_append); + ("string_length", value_string_length); + ("string_startswith", value_string_startswith); + ("string_drop", value_string_drop); + ("skip", fun _ -> V_unit); + ] + ) + +let add_primop name impl = primops := StringMap.add name impl !primops diff --git a/src/sail_c_backend/c_backend.ml b/src/sail_c_backend/c_backend.ml index a33c71a84..1b1d1a968 100644 --- a/src/sail_c_backend/c_backend.ml +++ b/src/sail_c_backend/c_backend.ml @@ -90,15 +90,9 @@ let opt_extra_params = ref None let opt_extra_arguments = ref None let opt_branch_coverage = ref None -let extra_params () = - match !opt_extra_params with - | Some str -> str ^ ", " - | _ -> "" +let extra_params () = match !opt_extra_params with Some str -> str ^ ", " | _ -> "" -let extra_arguments is_extern = - match !opt_extra_arguments with - | Some str when not is_extern -> str ^ ", " - | _ -> "" +let extra_arguments is_extern = match !opt_extra_arguments with Some str when not is_extern -> str ^ ", " | _ -> "" (* Optimization flags *) let optimize_primops = ref false @@ -108,11 +102,10 @@ let optimize_alias = ref false let optimize_fixed_int = ref false let optimize_fixed_bits = ref false -let (gensym, _) = symbol_generator "cb" +let gensym, _ = symbol_generator "cb" let ngensym () = name (gensym ()) -let c_error ?loc:(l=Parse_ast.Unknown) message = - raise (Reporting.err_general l ("\nC backend: " ^ message)) +let c_error ?loc:(l = Parse_ast.Unknown) message = raise (Reporting.err_general l ("\nC backend: " ^ message)) let zencode_id id = Util.zencode_string (string_of_id id) @@ -133,7 +126,8 @@ let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) (** This function is used to split types into those we allocate on the stack, versus those which need to live on the heap, or otherwise require some additional memory management *) -let rec is_stack_ctyp ctyp = match ctyp with +let rec is_stack_ctyp ctyp = + match ctyp with | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_enum _ -> true | CT_fint n -> n <= 64 | CT_lint when !optimize_fixed_int -> true @@ -142,7 +136,10 @@ let rec is_stack_ctyp ctyp = match ctyp with | CT_lbits _ -> false | CT_real | CT_string | CT_list _ | CT_vector _ | CT_fvector _ -> false | CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields - | CT_variant (_, _) -> false (* List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors *) (* FIXME *) + | CT_variant (_, _) -> + false + (* List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors *) + (* FIXME *) | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps | CT_ref _ -> true | CT_poly _ -> true @@ -172,133 +169,127 @@ let hex_char = | 'E' | 'e' -> [B1; B1; B1; B0] | 'F' | 'f' -> [B1; B1; B1; B1] | _ -> failwith "Invalid hex character" - + let literal_to_fragment (L_aux (l_aux, _)) = match l_aux with | L_num n when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> - Some (V_lit (VL_int n, CT_fint 64)) + Some (V_lit (VL_int n, CT_fint 64)) | L_hex str when String.length str <= 16 -> - let padding = 16 - String.length str in - let padding = Util.list_init padding (fun _ -> Sail2_values.B0) in - let content = Util.string_to_list str |> List.map hex_char |> List.concat in - Some (V_lit (VL_bits (padding @ content, true), CT_fbits (String.length str * 4, true))) + let padding = 16 - String.length str in + let padding = Util.list_init padding (fun _ -> Sail2_values.B0) in + let content = Util.string_to_list str |> List.map hex_char |> List.concat in + Some (V_lit (VL_bits (padding @ content, true), CT_fbits (String.length str * 4, true))) | L_unit -> Some (V_lit (VL_unit, CT_unit)) | L_true -> Some (V_lit (VL_bool true, CT_bool)) | L_false -> Some (V_lit (VL_bool false, CT_bool)) | _ -> None - -module C_config(Opts : sig val branch_coverage : out_channel option end) : Config = struct - + +module C_config (Opts : sig + val branch_coverage : out_channel option +end) : Config = struct (** Convert a sail type into a C-type. This function can be quite slow, because it uses ctx.local_env and SMT to analyse the Sail types and attempts to fit them into the smallest possible C types, provided ctx.optimize_smt is true (default) **) let rec convert_typ ctx typ = - let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.local_env typ in + let (Typ_aux (typ_aux, l) as typ) = Env.expand_synonyms ctx.local_env typ in match typ_aux with - | Typ_id id when string_of_id id = "bit" -> CT_bit - | Typ_id id when string_of_id id = "bool" -> CT_bool - | Typ_id id when string_of_id id = "int" -> CT_lint - | Typ_id id when string_of_id id = "nat" -> CT_lint - | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when string_of_id id = "bit" -> CT_bit + | Typ_id id when string_of_id id = "bool" -> CT_bool + | Typ_id id when string_of_id id = "int" -> CT_lint + | Typ_id id when string_of_id id = "nat" -> CT_lint + | Typ_id id when string_of_id id = "unit" -> CT_unit | Typ_id id when string_of_id id = "string" -> CT_string - | Typ_id id when string_of_id id = "real" -> CT_real - + | Typ_id id when string_of_id id = "real" -> CT_real | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool - - | Typ_app (id, args) when string_of_id id = "itself" -> - convert_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) + | Typ_app (id, args) when string_of_id id = "itself" -> convert_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> - begin match destruct_range Env.empty typ with - | None -> assert false (* Checked if range type in guard *) - | Some (kids, constr, n, m) -> - let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env }in - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - CT_fint 64 - | n, m -> - if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then - CT_fint 64 - else - CT_lint - end - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> - CT_list (ctyp_suprema (convert_typ ctx typ)) - + begin + match destruct_range Env.empty typ with + | None -> assert false (* Checked if range type in guard *) + | Some (kids, constr, n, m) -> ( + let ctx = + { + ctx with + local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env; + } + in + match (nexp_simp n, nexp_simp m) with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> + CT_fint 64 + | n, m -> + if + prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) + && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) + then CT_fint 64 + else CT_lint + ) + end + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> CT_list (ctyp_suprema (convert_typ ctx typ)) (* When converting a sail bitvector type into C, we have three options in order of efficiency: - If the length is obviously static and smaller than 64, use the fixed bits type (aka uint64_t), fbits. - If the length is less than 64, then use a small bits type, sbits. - If the length may be larger than 64, use a large bits type lbits. *) - | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order _, _)]) - when string_of_id id = "bitvector" -> - let direction = true in (* match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in *) - begin match nexp_simp n with - | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> CT_fbits (Big_int.to_int n, direction) - | n when prove __POS__ ctx.local_env (nc_lteq n (nint 64)) -> CT_sbits (64, direction) - | _ -> CT_lbits direction - end - - | Typ_app (id, [A_aux (A_nexp _, _); - A_aux (A_order _, _); - A_aux (A_typ typ, _)]) - when string_of_id id = "vector" -> - let direction = true in (* let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in *) - CT_vector (direction, convert_typ ctx typ) - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> - CT_ref (convert_typ ctx typ) - - | Typ_id id when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> snd |> Bindings.bindings) + | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order _, _)]) when string_of_id id = "bitvector" -> + let direction = true in + (* match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in *) + begin + match nexp_simp n with + | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> + CT_fbits (Big_int.to_int n, direction) + | n when prove __POS__ ctx.local_env (nc_lteq n (nint 64)) -> CT_sbits (64, direction) + | _ -> CT_lbits direction + end + | Typ_app (id, [A_aux (A_nexp _, _); A_aux (A_order _, _); A_aux (A_typ typ, _)]) when string_of_id id = "vector" -> + let direction = true in + (* let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in *) + CT_vector (direction, convert_typ ctx typ) + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (convert_typ ctx typ) + | Typ_id id when Bindings.mem id ctx.records -> + CT_struct (id, Bindings.find id ctx.records |> snd |> Bindings.bindings) | Typ_app (id, typ_args) when Bindings.mem id ctx.records -> - let (typ_params, fields) = Bindings.find id ctx.records in - let quants = - List.fold_left2 (fun quants typ_param typ_arg -> - match typ_arg with - | A_aux (A_typ typ, _) -> - KBindings.add typ_param (convert_typ ctx typ) quants - | _ -> - Reporting.unreachable l __POS__ "Non-type argument for record here should be impossible" - ) ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) - in - let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in - CT_struct (id, Bindings.map fix_ctyp fields |> Bindings.bindings) - - | Typ_id id when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> snd |> Bindings.bindings) + let typ_params, fields = Bindings.find id ctx.records in + let quants = + List.fold_left2 + (fun quants typ_param typ_arg -> + match typ_arg with + | A_aux (A_typ typ, _) -> KBindings.add typ_param (convert_typ ctx typ) quants + | _ -> Reporting.unreachable l __POS__ "Non-type argument for record here should be impossible" + ) + ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) + in + let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in + CT_struct (id, Bindings.map fix_ctyp fields |> Bindings.bindings) + | Typ_id id when Bindings.mem id ctx.variants -> + CT_variant (id, Bindings.find id ctx.variants |> snd |> Bindings.bindings) | Typ_app (id, typ_args) when Bindings.mem id ctx.variants -> - let (typ_params, ctors) = Bindings.find id ctx.variants in - let quants = - List.fold_left2 (fun quants typ_param typ_arg -> - match typ_arg with - | A_aux (A_typ typ, _) -> - KBindings.add typ_param (convert_typ ctx typ) quants - | _ -> - Reporting.unreachable l __POS__ "Non-type argument for variant here should be impossible" - ) ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) - in - let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in - CT_variant (id, Bindings.map fix_ctyp ctors |> Bindings.bindings) - + let typ_params, ctors = Bindings.find id ctx.variants in + let quants = + List.fold_left2 + (fun quants typ_param typ_arg -> + match typ_arg with + | A_aux (A_typ typ, _) -> KBindings.add typ_param (convert_typ ctx typ) quants + | _ -> Reporting.unreachable l __POS__ "Non-type argument for variant here should be impossible" + ) + ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) + in + let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in + CT_variant (id, Bindings.map fix_ctyp ctors |> Bindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) - | Typ_tuple typs -> CT_tup (List.map (convert_typ ctx) typs) - - | Typ_exist _ -> - (* Use Type_check.destruct_exist when optimising with SMT, to - ensure that we don't cause any type variable clashes in - local_env, and that we can optimize the existential based - upon its constraints. *) - begin match destruct_exist typ with - | Some (kids, nc, typ) -> - let env = add_existential l kids nc ctx.local_env in - convert_typ { ctx with local_env = env } typ - | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") - end - + | Typ_exist _ -> begin + (* Use Type_check.destruct_exist when optimising with SMT, to + ensure that we don't cause any type variable clashes in + local_env, and that we can optimize the existential based + upon its constraints. *) + match destruct_exist typ with + | Some (kids, nc, typ) -> + let env = add_existential l kids nc ctx.local_env in + convert_typ { ctx with local_env = env } typ + | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") + end | Typ_var kid -> CT_poly kid - | _ -> c_error ~loc:l ("No C type for type " ^ string_of_typ typ) (**************************************************************************) @@ -307,12 +298,9 @@ module C_config(Opts : sig val branch_coverage : out_channel option end) : Confi let c_literals ctx = let rec c_literal env l = function - | AV_lit (lit, typ) as v when is_stack_ctyp (convert_typ { ctx with local_env = env } typ) -> - begin - match literal_to_fragment lit with - | Some cval -> AV_cval (cval, typ) - | None -> v - end + | AV_lit (lit, typ) as v when is_stack_ctyp (convert_typ { ctx with local_env = env } typ) -> begin + match literal_to_fragment lit with Some cval -> AV_cval (cval, typ) | None -> v + end | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) | v -> v in @@ -332,106 +320,94 @@ module C_config(Opts : sig val branch_coverage : out_channel option end) : Confi (** Used to make sure the -Ofixed_int and -Ofixed_bits don't interfere with assumptions made about optimizations in the common case. *) - let never_optimize = function - | CT_lbits _ | CT_lint -> true - | _ -> false + let never_optimize = function CT_lbits _ | CT_lint -> true | _ -> false let rec c_aval ctx = function - | AV_lit (lit, typ) as v -> - begin - match literal_to_fragment lit with - | Some cval -> AV_cval (cval, typ) - | None -> v - end + | AV_lit (lit, typ) as v -> begin + match literal_to_fragment lit with Some cval -> AV_cval (cval, typ) | None -> v + end | AV_cval (cval, typ) -> AV_cval (cval, typ) (* An id can be converted to a C fragment if its type can be - stack-allocated. *) - | AV_id (id, lvar) as v -> - begin - match lvar with - | Local (_, typ) -> + stack-allocated. *) + | AV_id (id, lvar) as v -> begin + match lvar with + | Local (_, typ) -> let ctyp = convert_typ ctx typ in - if is_stack_ctyp ctyp && not (never_optimize ctyp) then - begin - try - (* We need to check that id's type hasn't changed due to flow typing *) - let _, ctyp' = Bindings.find id ctx.locals in - if ctyp_equal ctyp ctyp' then - AV_cval (V_id (name_or_global ctx id, ctyp), typ) - else - (* id's type changed due to flow typing, so it's - really still heap allocated! *) - v - with - (* Hack: Assuming global letbindings don't change from flow typing... *) - Not_found -> AV_cval (V_id (name_or_global ctx id, ctyp), typ) - end - else - v - | Register typ -> + if is_stack_ctyp ctyp && not (never_optimize ctyp) then begin + try + (* We need to check that id's type hasn't changed due to flow typing *) + let _, ctyp' = Bindings.find id ctx.locals in + if ctyp_equal ctyp ctyp' then AV_cval (V_id (name_or_global ctx id, ctyp), typ) + else + (* id's type changed due to flow typing, so it's + really still heap allocated! *) + v + with (* Hack: Assuming global letbindings don't change from flow typing... *) + | Not_found -> + AV_cval (V_id (name_or_global ctx id, ctyp), typ) + end + else v + | Register typ -> let ctyp = convert_typ ctx typ in - if is_stack_ctyp ctyp && not (never_optimize ctyp) then - AV_cval (V_id (global id, ctyp), typ) - else - v - | _ -> v - end + if is_stack_ctyp ctyp && not (never_optimize ctyp) then AV_cval (V_id (global id, ctyp), typ) else v + | _ -> v + end | AV_vector (v, typ) when is_bitvector v && List.length v <= 64 -> - let bitstring = VL_bits (List.map value_of_aval_bit v, true) in - AV_cval (V_lit (bitstring, CT_fbits (List.length v, true)), typ) + let bitstring = VL_bits (List.map value_of_aval_bit v, true) in + AV_cval (V_lit (bitstring, CT_fbits (List.length v, true)), typ) | AV_tuple avals -> AV_tuple (List.map (c_aval ctx) avals) | aval -> aval (* Map over all the functions in an aexp. *) let rec analyze_functions ctx f (AE_aux (aexp, env, l)) = let ctx = { ctx with local_env = env } in - let aexp = match aexp with + let aexp = + match aexp with | AE_app (id, vs, typ) -> f ctx id vs typ - | AE_typ (aexp, typ) -> AE_typ (analyze_functions ctx f aexp, typ) - | AE_assign (alexp, aexp) -> AE_assign (alexp, analyze_functions ctx f aexp) - | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, analyze_functions ctx f aexp) - | AE_let (mut, id, typ1, aexp1, (AE_aux (_, env2, _) as aexp2), typ2) -> - let aexp1 = analyze_functions ctx f aexp1 in - (* Use aexp2's environment because it will contain constraints for id *) - let ctyp1 = convert_typ { ctx with local_env = env2 } typ1 in - let ctx = { ctx with locals = Bindings.add id (mut, ctyp1) ctx.locals } in - AE_let (mut, id, typ1, aexp1, analyze_functions ctx f aexp2, typ2) - - | AE_block (aexps, aexp, typ) -> AE_block (List.map (analyze_functions ctx f) aexps, analyze_functions ctx f aexp, typ) - + let aexp1 = analyze_functions ctx f aexp1 in + (* Use aexp2's environment because it will contain constraints for id *) + let ctyp1 = convert_typ { ctx with local_env = env2 } typ1 in + let ctx = { ctx with locals = Bindings.add id (mut, ctyp1) ctx.locals } in + AE_let (mut, id, typ1, aexp1, analyze_functions ctx f aexp2, typ2) + | AE_block (aexps, aexp, typ) -> + AE_block (List.map (analyze_functions ctx f) aexps, analyze_functions ctx f aexp, typ) | AE_if (aval, aexp1, aexp2, typ) -> - AE_if (aval, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2, typ) - - | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) - + AE_if (aval, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2, typ) + | AE_loop (loop_typ, aexp1, aexp2) -> + AE_loop (loop_typ, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - let aexp1 = analyze_functions ctx f aexp1 in - let aexp2 = analyze_functions ctx f aexp2 in - let aexp3 = analyze_functions ctx f aexp3 in - (* Currently we assume that loop indexes are always safe to put into an int64 *) - let ctx = { ctx with locals = Bindings.add id (Immutable, CT_fint 64) ctx.locals } in - let aexp4 = analyze_functions ctx f aexp4 in - AE_for (id, aexp1, aexp2, aexp3, order, aexp4) - + let aexp1 = analyze_functions ctx f aexp1 in + let aexp2 = analyze_functions ctx f aexp2 in + let aexp3 = analyze_functions ctx f aexp3 in + (* Currently we assume that loop indexes are always safe to put into an int64 *) + let ctx = { ctx with locals = Bindings.add id (Immutable, CT_fint 64) ctx.locals } in + let aexp4 = analyze_functions ctx f aexp4 in + AE_for (id, aexp1, aexp2, aexp3, order, aexp4) | AE_match (aval, cases, typ) -> - let analyze_case (AP_aux (_, env, _) as pat, aexp1, aexp2) = - let pat_bindings = Bindings.bindings (apat_types pat) in - let ctx = { ctx with local_env = env } in - let ctx = - List.fold_left (fun ctx (id, typ) -> { ctx with locals = Bindings.add id (Immutable, convert_typ ctx typ) ctx.locals }) ctx pat_bindings - in - pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2 - in - AE_match (aval, List.map analyze_case cases, typ) - + let analyze_case ((AP_aux (_, env, _) as pat), aexp1, aexp2) = + let pat_bindings = Bindings.bindings (apat_types pat) in + let ctx = { ctx with local_env = env } in + let ctx = + List.fold_left + (fun ctx (id, typ) -> { ctx with locals = Bindings.add id (Immutable, convert_typ ctx typ) ctx.locals }) + ctx pat_bindings + in + (pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) + in + AE_match (aval, List.map analyze_case cases, typ) | AE_try (aexp, cases, typ) -> - AE_try (analyze_functions ctx f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) cases, typ) - - | AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _ as v -> v + AE_try + ( analyze_functions ctx f aexp, + List.map + (fun (pat, aexp1, aexp2) -> (pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2)) + cases, + typ + ) + | (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v in AE_aux (aexp, env, l) @@ -440,151 +416,108 @@ module C_config(Opts : sig val branch_coverage : out_channel option end) : Confi let args = List.map (c_aval ctx) args in let extern = if ctx_is_extern id ctx then ctx_get_extern id ctx else failwith "Not extern" in - match extern, args with - | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - begin match cval_ctyp v1 with - | CT_fbits _ | CT_sbits _ -> - AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) - | _ -> no_change - end - - | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - begin match cval_ctyp v1 with - | CT_fbits _ | CT_sbits _ -> - AE_val (AV_cval (V_call (Neq, [v1; v2]), typ)) - | _ -> no_change - end - - | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) - - | "eq_bit", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) - - | "zeros", [_] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - let n = Big_int.to_int n in - AE_val (AV_cval (V_lit (VL_bits (Util.list_init n (fun _ -> Sail2_values.B0), true), CT_fbits (n, true)), typ)) - | _ -> no_change - end - - | "zero_extend", [AV_cval (v, _); _] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_cval (V_call (Zero_extend (Big_int.to_int n), [v]), typ)) - | _ -> no_change - end - - | "sign_extend", [AV_cval (v, _); _] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_cval (V_call (Sign_extend (Big_int.to_int n), [v]), typ)) - | _ -> no_change - end - - | "lteq", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Ilteq, [v1; v2]), typ)) - | "gteq", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Igteq, [v1; v2]), typ)) - | "lt", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Ilt, [v1; v2]), typ)) - | "gt", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Igt, [v1; v2]), typ)) - - | "append", [AV_cval (v1, _); AV_cval (v2, _)] -> - begin match convert_typ ctx typ with - | CT_fbits _ | CT_sbits _ -> - AE_val (AV_cval (V_call (Concat, [v1; v2]), typ)) - | _ -> no_change - end - - | "not_bits", [AV_cval (v, _)] -> - AE_val (AV_cval (V_call (Bvnot, [v]), typ)) - + match (extern, args) with + | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> begin + match cval_ctyp v1 with + | CT_fbits _ | CT_sbits _ -> AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | _ -> no_change + end + | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> begin + match cval_ctyp v1 with + | CT_fbits _ | CT_sbits _ -> AE_val (AV_cval (V_call (Neq, [v1; v2]), typ)) + | _ -> no_change + end + | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | "eq_bit", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | "zeros", [_] -> begin + match destruct_vector ctx.tc_env typ with + | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) + when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> + let n = Big_int.to_int n in + AE_val + (AV_cval (V_lit (VL_bits (Util.list_init n (fun _ -> Sail2_values.B0), true), CT_fbits (n, true)), typ)) + | _ -> no_change + end + | "zero_extend", [AV_cval (v, _); _] -> begin + match destruct_vector ctx.tc_env typ with + | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) + when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> + AE_val (AV_cval (V_call (Zero_extend (Big_int.to_int n), [v]), typ)) + | _ -> no_change + end + | "sign_extend", [AV_cval (v, _); _] -> begin + match destruct_vector ctx.tc_env typ with + | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) + when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> + AE_val (AV_cval (V_call (Sign_extend (Big_int.to_int n), [v]), typ)) + | _ -> no_change + end + | "lteq", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Ilteq, [v1; v2]), typ)) + | "gteq", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Igteq, [v1; v2]), typ)) + | "lt", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Ilt, [v1; v2]), typ)) + | "gt", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Igt, [v1; v2]), typ)) + | "append", [AV_cval (v1, _); AV_cval (v2, _)] -> begin + match convert_typ ctx typ with + | CT_fbits _ | CT_sbits _ -> AE_val (AV_cval (V_call (Concat, [v1; v2]), typ)) + | _ -> no_change + end + | "not_bits", [AV_cval (v, _)] -> AE_val (AV_cval (V_call (Bvnot, [v]), typ)) | "add_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvadd, [v1; v2]), typ)) - + AE_val (AV_cval (V_call (Bvadd, [v1; v2]), typ)) | "sub_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvsub, [v1; v2]), typ)) - + AE_val (AV_cval (V_call (Bvsub, [v1; v2]), typ)) | "and_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvand, [v1; v2]), typ)) - + AE_val (AV_cval (V_call (Bvand, [v1; v2]), typ)) | "or_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvor, [v1; v2]), typ)) - + AE_val (AV_cval (V_call (Bvor, [v1; v2]), typ)) | "xor_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvxor, [v1; v2]), typ)) - - | "vector_subrange", [AV_cval (vec, _); AV_cval (_, _); AV_cval (t, _)] -> - begin match convert_typ ctx typ with - | CT_fbits (n, true) -> - AE_val (AV_cval (V_call (Slice n, [vec; t]), typ)) - | _ -> no_change - end - - | "slice", [AV_cval (vec, _); AV_cval (start, _); AV_cval (len, _)] -> - begin match convert_typ ctx typ with - | CT_fbits (n, _) -> - AE_val (AV_cval (V_call (Slice n, [vec; start]), typ)) - | CT_sbits (64, _) -> - AE_val (AV_cval (V_call (Sslice 64, [vec; start; len]), typ)) - | _ -> no_change - end - - | "vector_access", [AV_cval (vec, _); AV_cval (n, _)] -> - AE_val (AV_cval (V_call (Bvaccess, [vec; n]), typ)) - - | "add_int", [AV_cval (op1, _); AV_cval (op2, _)] -> - begin match destruct_range ctx.local_env typ with - | None -> no_change - | Some (_, _, n, m) -> - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) - | n, m when prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) -> - AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) - | _ -> no_change - end - - | "replicate_bits", [AV_cval (vec, vtyp); _] -> - begin match destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp with - | Some (Nexp_aux (Nexp_constant n, _), _, _), Some (Nexp_aux (Nexp_constant m, _), _, _) - when Big_int.less_equal n (Big_int.of_int 64) -> - let times = Big_int.div n m in - if Big_int.equal (Big_int.mul m times) n then - AE_val (AV_cval (V_call (Replicate (Big_int.to_int times), [vec]), typ)) - else - no_change - | _, _ -> - no_change - end - - | "undefined_bit", _ -> - AE_val (AV_cval (V_lit (VL_bit Sail2_values.B0, CT_bit), typ)) - - | "undefined_bool", _ -> - AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) - - | _, _ -> - no_change + AE_val (AV_cval (V_call (Bvxor, [v1; v2]), typ)) + | "vector_subrange", [AV_cval (vec, _); AV_cval (_, _); AV_cval (t, _)] -> begin + match convert_typ ctx typ with + | CT_fbits (n, true) -> AE_val (AV_cval (V_call (Slice n, [vec; t]), typ)) + | _ -> no_change + end + | "slice", [AV_cval (vec, _); AV_cval (start, _); AV_cval (len, _)] -> begin + match convert_typ ctx typ with + | CT_fbits (n, _) -> AE_val (AV_cval (V_call (Slice n, [vec; start]), typ)) + | CT_sbits (64, _) -> AE_val (AV_cval (V_call (Sslice 64, [vec; start; len]), typ)) + | _ -> no_change + end + | "vector_access", [AV_cval (vec, _); AV_cval (n, _)] -> AE_val (AV_cval (V_call (Bvaccess, [vec; n]), typ)) + | "add_int", [AV_cval (op1, _); AV_cval (op2, _)] -> begin + match destruct_range ctx.local_env typ with + | None -> no_change + | Some (_, _, n, m) -> ( + match (nexp_simp n, nexp_simp m) with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> + AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) + | n, m + when prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) + && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) -> + AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) + | _ -> no_change + ) + end + | "replicate_bits", [AV_cval (vec, vtyp); _] -> begin + match (destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp) with + | Some (Nexp_aux (Nexp_constant n, _), _, _), Some (Nexp_aux (Nexp_constant m, _), _, _) + when Big_int.less_equal n (Big_int.of_int 64) -> + let times = Big_int.div n m in + if Big_int.equal (Big_int.mul m times) n then + AE_val (AV_cval (V_call (Replicate (Big_int.to_int times), [vec]), typ)) + else no_change + | _, _ -> no_change + end + | "undefined_bit", _ -> AE_val (AV_cval (V_lit (VL_bit Sail2_values.B0, CT_bit), typ)) + | "undefined_bool", _ -> AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) + | _, _ -> no_change let analyze_primop ctx id args typ = let no_change = AE_app (id, args, typ) in - if !optimize_primops then - try analyze_primop' ctx id args typ with - | Failure _ -> - no_change - else - no_change + if !optimize_primops then (try analyze_primop' ctx id args typ with Failure _ -> no_change) else no_change - let optimize_anf ctx aexp = - analyze_functions ctx analyze_primop (c_literals ctx aexp) + let optimize_anf ctx aexp = analyze_functions ctx analyze_primop (c_literals ctx aexp) let unroll_loops = None let specialize_calls = false @@ -595,7 +528,7 @@ module C_config(Opts : sig val branch_coverage : out_channel option end) : Confi let branch_coverage = Opts.branch_coverage let track_throw = true end - + (** Functions that have heap-allocated return types are implemented by passing a pointer a location where the return value should be stored. The ANF -> Sail IR pass for expressions simply outputs an @@ -615,100 +548,63 @@ let fix_early_heap_return ret instrs = let rec rewrite_return instrs = match instr_split_at is_return_recur instrs with | instrs, [] -> instrs - | before, I_aux (I_block instrs, _) :: after -> - before - @ [iblock (rewrite_return instrs)] - @ rewrite_return after + | before, I_aux (I_block instrs, _) :: after -> before @ [iblock (rewrite_return instrs)] @ rewrite_return after | before, I_aux (I_try_block instrs, (_, l)) :: after -> - before - @ [itry_block l (rewrite_return instrs)] - @ rewrite_return after + before @ [itry_block l (rewrite_return instrs)] @ rewrite_return after | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> - before - @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] - @ rewrite_return after + before @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] @ rewrite_return after | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after -> - before - @ [I_aux (I_funcall (CL_addr (CL_id (ret, CT_ref ctyp)), extern, fid, args), aux)] - @ rewrite_return after + before @ [I_aux (I_funcall (CL_addr (CL_id (ret, CT_ref ctyp)), extern, fid, args), aux)] @ rewrite_return after | before, I_aux (I_copy (CL_id (Return _, ctyp), cval), aux) :: after -> - before - @ [I_aux (I_copy (CL_addr (CL_id (ret, CT_ref ctyp)), cval), aux)] - @ rewrite_return after + before @ [I_aux (I_copy (CL_addr (CL_id (ret, CT_ref ctyp)), cval), aux)] @ rewrite_return after | before, I_aux ((I_end _ | I_undefined _), _) :: after -> - before - @ [igoto end_function_label] - @ rewrite_return after - | before, (I_aux ((I_copy _ | I_funcall _), _) as instr) :: after -> - before @ instr :: rewrite_return after + before @ [igoto end_function_label] @ rewrite_return after + | before, (I_aux ((I_copy _ | I_funcall _), _) as instr) :: after -> before @ (instr :: rewrite_return after) | _, _ -> assert false in - rewrite_return instrs - @ [ilabel end_function_label] + rewrite_return instrs @ [ilabel end_function_label] (* This is like fix_early_heap_return, but for stack allocated returns. *) let fix_early_stack_return ret ret_ctyp instrs = let is_return_recur (I_aux (instr, _)) = - match instr with - | I_if _ | I_block _ | I_try_block _ | I_end _ | I_funcall _ | I_copy _ -> true - | _ -> false + match instr with I_if _ | I_block _ | I_try_block _ | I_end _ | I_funcall _ | I_copy _ -> true | _ -> false in let rec rewrite_return instrs = match instr_split_at is_return_recur instrs with | instrs, [] -> instrs - | before, I_aux (I_block instrs, _) :: after -> - before - @ [iblock (rewrite_return instrs)] - @ rewrite_return after + | before, I_aux (I_block instrs, _) :: after -> before @ [iblock (rewrite_return instrs)] @ rewrite_return after | before, I_aux (I_try_block instrs, (_, l)) :: after -> - before - @ [itry_block l (rewrite_return instrs)] - @ rewrite_return after + before @ [itry_block l (rewrite_return instrs)] @ rewrite_return after | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> - before - @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] - @ rewrite_return after + before @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] @ rewrite_return after | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after -> - before - @ [I_aux (I_funcall (CL_id (ret, ctyp), extern, fid, args), aux)] - @ rewrite_return after + before @ [I_aux (I_funcall (CL_id (ret, ctyp), extern, fid, args), aux)] @ rewrite_return after | before, I_aux (I_copy (CL_id (Return _, ctyp), cval), aux) :: after -> - before - @ [I_aux (I_copy (CL_id (ret, ctyp), cval), aux)] - @ rewrite_return after - | before, I_aux (I_end _, _) :: after -> - before - @ [ireturn (V_id (ret, ret_ctyp))] - @ rewrite_return after - | before, (I_aux ((I_copy _ | I_funcall _), _) as instr) :: after -> - before @ instr :: rewrite_return after + before @ [I_aux (I_copy (CL_id (ret, ctyp), cval), aux)] @ rewrite_return after + | before, I_aux (I_end _, _) :: after -> before @ [ireturn (V_id (ret, ret_ctyp))] @ rewrite_return after + | before, (I_aux ((I_copy _ | I_funcall _), _) as instr) :: after -> before @ (instr :: rewrite_return after) | _, _ -> assert false in rewrite_return instrs let rec insert_heap_returns ret_ctyps = function | (CDEF_val (id, _, _, ret_ctyp) as cdef) :: cdefs -> - cdef :: insert_heap_returns (Bindings.add id ret_ctyp ret_ctyps) cdefs - + cdef :: insert_heap_returns (Bindings.add id ret_ctyp ret_ctyps) cdefs | CDEF_fundef (id, None, args, body) :: cdefs -> - let gs = gensym () in - begin match Bindings.find_opt id ret_ctyps with - | None -> - raise (Reporting.err_general (id_loc id) ("Cannot find return type for function " ^ string_of_id id)) - | Some ret_ctyp when not (is_stack_ctyp ret_ctyp) -> - CDEF_fundef (id, Some gs, args, fix_early_heap_return (name gs) body) - :: insert_heap_returns ret_ctyps cdefs - | Some ret_ctyp -> - CDEF_fundef (id, None, args, fix_early_stack_return (name gs) ret_ctyp (idecl (id_loc id) ret_ctyp (name gs) :: body)) - :: insert_heap_returns ret_ctyps cdefs - end - + let gs = gensym () in + begin + match Bindings.find_opt id ret_ctyps with + | None -> raise (Reporting.err_general (id_loc id) ("Cannot find return type for function " ^ string_of_id id)) + | Some ret_ctyp when not (is_stack_ctyp ret_ctyp) -> + CDEF_fundef (id, Some gs, args, fix_early_heap_return (name gs) body) :: insert_heap_returns ret_ctyps cdefs + | Some ret_ctyp -> + CDEF_fundef + (id, None, args, fix_early_stack_return (name gs) ret_ctyp (idecl (id_loc id) ret_ctyp (name gs) :: body)) + :: insert_heap_returns ret_ctyps cdefs + end | CDEF_fundef (id, _, _, _) :: _ -> - Reporting.unreachable (id_loc id) __POS__ "Found function with return already re-written in insert_heap_returns" - - | cdef :: cdefs -> - cdef :: insert_heap_returns ret_ctyps cdefs - + Reporting.unreachable (id_loc id) __POS__ "Found function with return already re-written in insert_heap_returns" + | cdef :: cdefs -> cdef :: insert_heap_returns ret_ctyps cdefs | [] -> [] (** To keep things neat we use GCC's local labels extension to limit @@ -719,30 +615,19 @@ let rec insert_heap_returns ret_ctyps = function See https://gcc.gnu.org/onlinedocs/gcc/Local-Labels.html **) let add_local_labels' instrs = - let is_label (I_aux (instr, _)) = - match instr with - | I_label str -> [str] - | _ -> [] - in + let is_label (I_aux (instr, _)) = match instr with I_label str -> [str] | _ -> [] in let labels = List.concat (List.map is_label instrs) in let local_label_decl = iraw ("__label__ " ^ String.concat ", " labels ^ ";\n") in - if labels = [] then - instrs - else - local_label_decl :: instrs + if labels = [] then instrs else local_label_decl :: instrs let add_local_labels instrs = - match map_instrs add_local_labels' (iblock instrs) with - | I_aux (I_block instrs, _) -> instrs - | _ -> assert false + match map_instrs add_local_labels' (iblock instrs) with I_aux (I_block instrs, _) -> instrs | _ -> assert false (**************************************************************************) (* 5. Optimizations *) (**************************************************************************) -let hoist_ctyp = function - | CT_lint | CT_lbits _ | CT_struct _ -> true - | _ -> false +let hoist_ctyp = function CT_lint | CT_lbits _ | CT_struct _ -> true | _ -> false let hoist_counter = ref 0 let hoist_id () = @@ -751,55 +636,44 @@ let hoist_id () = name id let hoist_allocations recursive_functions = function - | CDEF_fundef (function_id, _, _, _) as cdef when IdSet.mem function_id recursive_functions -> - [cdef] - + | CDEF_fundef (function_id, _, _, _) as cdef when IdSet.mem function_id recursive_functions -> [cdef] | CDEF_fundef (function_id, heap_return, args, body) -> - let decls = ref [] in - let cleanups = ref [] in - let rec hoist = function - | I_aux (I_decl (ctyp, decl_id), annot) :: instrs when hoist_ctyp ctyp -> - let hid = hoist_id () in - decls := idecl (snd annot) ctyp hid :: !decls; - cleanups := iclear ctyp hid :: !cleanups; - let instrs = instrs_rename decl_id hid instrs in - I_aux (I_reset (ctyp, hid), annot) :: hoist instrs - - | I_aux (I_init (ctyp, decl_id, cval), annot) :: instrs when hoist_ctyp ctyp -> - let hid = hoist_id () in - decls := idecl (snd annot) ctyp hid :: !decls; - cleanups := iclear ctyp hid :: !cleanups; - let instrs = instrs_rename decl_id hid instrs in - I_aux (I_reinit (ctyp, hid, cval), annot) :: hoist instrs - - | I_aux (I_clear (ctyp, _), _) :: instrs when hoist_ctyp ctyp -> - hoist instrs - - | I_aux (I_block block, annot) :: instrs -> - I_aux (I_block (hoist block), annot) :: hoist instrs - | I_aux (I_try_block block, annot) :: instrs -> - I_aux (I_try_block (hoist block), annot) :: hoist instrs - | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), annot) :: instrs -> - I_aux (I_if (cval, hoist then_instrs, hoist else_instrs, ctyp), annot) :: hoist instrs - - | instr :: instrs -> instr :: hoist instrs - | [] -> [] - in - let body = hoist body in - if !decls = [] then - [CDEF_fundef (function_id, heap_return, args, body)] - else - [CDEF_startup (function_id, List.rev !decls); - CDEF_fundef (function_id, heap_return, args, body); - CDEF_finish (function_id, !cleanups)] - + let decls = ref [] in + let cleanups = ref [] in + let rec hoist = function + | I_aux (I_decl (ctyp, decl_id), annot) :: instrs when hoist_ctyp ctyp -> + let hid = hoist_id () in + decls := idecl (snd annot) ctyp hid :: !decls; + cleanups := iclear ctyp hid :: !cleanups; + let instrs = instrs_rename decl_id hid instrs in + I_aux (I_reset (ctyp, hid), annot) :: hoist instrs + | I_aux (I_init (ctyp, decl_id, cval), annot) :: instrs when hoist_ctyp ctyp -> + let hid = hoist_id () in + decls := idecl (snd annot) ctyp hid :: !decls; + cleanups := iclear ctyp hid :: !cleanups; + let instrs = instrs_rename decl_id hid instrs in + I_aux (I_reinit (ctyp, hid, cval), annot) :: hoist instrs + | I_aux (I_clear (ctyp, _), _) :: instrs when hoist_ctyp ctyp -> hoist instrs + | I_aux (I_block block, annot) :: instrs -> I_aux (I_block (hoist block), annot) :: hoist instrs + | I_aux (I_try_block block, annot) :: instrs -> I_aux (I_try_block (hoist block), annot) :: hoist instrs + | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), annot) :: instrs -> + I_aux (I_if (cval, hoist then_instrs, hoist else_instrs, ctyp), annot) :: hoist instrs + | instr :: instrs -> instr :: hoist instrs + | [] -> [] + in + let body = hoist body in + if !decls = [] then [CDEF_fundef (function_id, heap_return, args, body)] + else + [ + CDEF_startup (function_id, List.rev !decls); + CDEF_fundef (function_id, heap_return, args, body); + CDEF_finish (function_id, !cleanups); + ] | cdef -> [cdef] let removed = icomment "REMOVED" -let is_not_removed = function - | I_aux (I_comment "REMOVED", _) -> false - | _ -> true +let is_not_removed = function I_aux (I_comment "REMOVED", _) -> false | _ -> true (** This optimization looks for patterns of the form: @@ -815,34 +689,19 @@ let remove_alias = let pattern ctyp id = let alias = ref None in let rec scan ctyp id n instrs = - match n, !alias, instrs with + match (n, !alias, instrs) with | 0, None, I_aux (I_copy (CL_id (id', ctyp'), V_id (a, ctyp'')), _) :: instrs - when Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' -> - alias := Some a; - scan ctyp id 1 instrs - + when Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' -> + alias := Some a; + scan ctyp id 1 instrs | 1, Some a, I_aux (I_copy (CL_id (a', ctyp'), V_id (id', ctyp'')), _) :: instrs - when Name.compare a a' = 0 && Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' -> - scan ctyp id 2 instrs - - | 1, Some a, instr :: instrs -> - if NameSet.mem a (instr_ids instr) then - None - else - scan ctyp id 1 instrs - - | 2, Some _, I_aux (I_clear (ctyp', id'), _) :: instrs - when Name.compare id id' = 0 && ctyp_equal ctyp ctyp' -> - scan ctyp id 2 instrs - - | 2, Some _, instr :: instrs -> - if NameSet.mem id (instr_ids instr) then - None - else - scan ctyp id 2 instrs - + when Name.compare a a' = 0 && Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' -> + scan ctyp id 2 instrs + | 1, Some a, instr :: instrs -> if NameSet.mem a (instr_ids instr) then None else scan ctyp id 1 instrs + | 2, Some _, I_aux (I_clear (ctyp', id'), _) :: instrs when Name.compare id id' = 0 && ctyp_equal ctyp ctyp' -> + scan ctyp id 2 instrs + | 2, Some _, instr :: instrs -> if NameSet.mem id (instr_ids instr) then None else scan ctyp id 2 instrs | 2, Some _, [] -> !alias - | n, _, _ :: instrs when n = 0 || n > 2 -> scan ctyp id n instrs | _, _, I_aux (_, (_, l)) :: _ -> Reporting.unreachable l __POS__ "optimize_alias" | _, _, [] -> None @@ -850,34 +709,32 @@ let remove_alias = scan ctyp id 0 in let remove_alias id alias = function - | I_aux (I_copy (CL_id (id', _), V_id (alias', _)), _) - when Name.compare id id' = 0 && Name.compare alias alias' = 0 -> removed - | I_aux (I_copy (CL_id (alias', _), V_id (id', _)), _) - when Name.compare id id' = 0 && Name.compare alias alias' = 0 -> removed + | I_aux (I_copy (CL_id (id', _), V_id (alias', _)), _) when Name.compare id id' = 0 && Name.compare alias alias' = 0 + -> + removed + | I_aux (I_copy (CL_id (alias', _), V_id (id', _)), _) when Name.compare id id' = 0 && Name.compare alias alias' = 0 + -> + removed | I_aux (I_clear (_, _), _) -> removed | instr -> instr in let rec opt = function - | I_aux (I_decl (ctyp, id), _) as instr :: instrs -> - begin match pattern ctyp id instrs with - | None -> instr :: opt instrs - | Some alias -> - let instrs = List.map (map_instr (remove_alias id alias)) instrs in - filter_instrs is_not_removed (List.map (instr_rename id alias) instrs) - end - + | (I_aux (I_decl (ctyp, id), _) as instr) :: instrs -> begin + match pattern ctyp id instrs with + | None -> instr :: opt instrs + | Some alias -> + let instrs = List.map (map_instr (remove_alias id alias)) instrs in + filter_instrs is_not_removed (List.map (instr_rename id alias) instrs) + end | I_aux (I_block block, aux) :: instrs -> I_aux (I_block (opt block), aux) :: opt instrs | I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (opt block), aux) :: opt instrs | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs -> - I_aux (I_if (cval, opt then_instrs, opt else_instrs, ctyp), aux) :: opt instrs - - | instr :: instrs -> - instr :: opt instrs + I_aux (I_if (cval, opt then_instrs, opt else_instrs, ctyp), aux) :: opt instrs + | instr :: instrs -> instr :: opt instrs | [] -> [] in function - | CDEF_fundef (function_id, heap_return, args, body) -> - [CDEF_fundef (function_id, heap_return, args, opt body)] + | CDEF_fundef (function_id, heap_return, args, body) -> [CDEF_fundef (function_id, heap_return, args, opt body)] | cdef -> [cdef] (** This optimization looks for patterns of the form @@ -893,44 +750,22 @@ let combine_variables = let pattern ctyp id = let combine = ref None in let rec scan id n instrs = - match n, !combine, instrs with - | 0, None, I_aux (I_block block, _) :: instrs -> - begin match scan id 0 block with - | Some combine -> Some combine - | None -> scan id 0 instrs - end - + match (n, !combine, instrs) with + | 0, None, I_aux (I_block block, _) :: instrs -> begin + match scan id 0 block with Some combine -> Some combine | None -> scan id 0 instrs + end | 0, None, I_aux (I_decl (ctyp', id'), _) :: instrs when ctyp_equal ctyp ctyp' -> - combine := Some id'; - scan id 1 instrs - + combine := Some id'; + scan id 1 instrs | 1, Some c, I_aux (I_copy (CL_id (id', ctyp'), V_id (c', ctyp'')), _) :: instrs - when Name.compare c c' = 0 && Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' -> - scan id 2 instrs - + when Name.compare c c' = 0 && Name.compare id id' = 0 && ctyp_equal ctyp ctyp' && ctyp_equal ctyp' ctyp'' -> + scan id 2 instrs (* Ignore seemingly early clears of x, as this can happen along exception paths *) - | 1, Some _, I_aux (I_clear (_, id'), _) :: instrs - when Name.compare id id' = 0 -> - scan id 1 instrs - - | 1, Some _, instr :: instrs -> - if NameSet.mem id (instr_ids instr) then - None - else - scan id 1 instrs - - | 2, Some c, I_aux (I_clear (ctyp', c'), _) :: _ - when Name.compare c c' = 0 && ctyp_equal ctyp ctyp' -> - !combine - - | 2, Some c, instr :: instrs -> - if NameSet.mem c (instr_ids instr) then - None - else - scan id 2 instrs - + | 1, Some _, I_aux (I_clear (_, id'), _) :: instrs when Name.compare id id' = 0 -> scan id 1 instrs + | 1, Some _, instr :: instrs -> if NameSet.mem id (instr_ids instr) then None else scan id 1 instrs + | 2, Some c, I_aux (I_clear (ctyp', c'), _) :: _ when Name.compare c c' = 0 && ctyp_equal ctyp ctyp' -> !combine + | 2, Some c, instr :: instrs -> if NameSet.mem c (instr_ids instr) then None else scan id 2 instrs | 2, Some _, [] -> !combine - | n, _, _ :: instrs -> scan id n instrs | _, _, [] -> None in @@ -946,28 +781,27 @@ let combine_variables = | _ -> true in let rec opt = function - | (I_aux (I_decl (ctyp, id), _) as instr) :: instrs -> - begin match pattern ctyp id instrs with - | None -> instr :: opt instrs - | Some combine -> - let instrs = List.map (map_instr (remove_variable combine)) instrs in - let instrs = filter_instrs (fun i -> is_not_removed i && is_not_self_assignment i) - (List.map (instr_rename combine id) instrs) in - opt (instr :: instrs) - end - + | (I_aux (I_decl (ctyp, id), _) as instr) :: instrs -> begin + match pattern ctyp id instrs with + | None -> instr :: opt instrs + | Some combine -> + let instrs = List.map (map_instr (remove_variable combine)) instrs in + let instrs = + filter_instrs + (fun i -> is_not_removed i && is_not_self_assignment i) + (List.map (instr_rename combine id) instrs) + in + opt (instr :: instrs) + end | I_aux (I_block block, aux) :: instrs -> I_aux (I_block (opt block), aux) :: opt instrs | I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (opt block), aux) :: opt instrs | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs -> - I_aux (I_if (cval, opt then_instrs, opt else_instrs, ctyp), aux) :: opt instrs - - | instr :: instrs -> - instr :: opt instrs + I_aux (I_if (cval, opt then_instrs, opt else_instrs, ctyp), aux) :: opt instrs + | instr :: instrs -> instr :: opt instrs | [] -> [] in function - | CDEF_fundef (function_id, heap_return, args, body) -> - [CDEF_fundef (function_id, heap_return, args, opt body)] + | CDEF_fundef (function_id, heap_return, args, body) -> [CDEF_fundef (function_id, heap_return, args, opt body)] | cdef -> [cdef] let concatMap f xs = List.concat (List.map f xs) @@ -978,7 +812,8 @@ let optimize recursive_functions cdefs = |> (if !optimize_alias then concatMap remove_alias else nothing) |> (if !optimize_alias then concatMap combine_variables else nothing) (* We need the runtime to initialize hoisted allocations *) - |> (if !optimize_hoist_allocations && not !opt_no_rts then concatMap (hoist_allocations recursive_functions) else nothing) + |> + if !optimize_hoist_allocations && not !opt_no_rts then concatMap (hoist_allocations recursive_functions) else nothing (**************************************************************************) (* 6. Code generation *) @@ -1046,32 +881,21 @@ let rec sgen_ctyp_name = function | CT_float n -> "float" ^ string_of_int n | CT_rounding_mode -> "rounding_mode" | CT_poly _ -> "POLY" (* c_error "Tried to generate code for non-monomorphic type" *) - + let sgen_mask n = - if n = 0 then - "UINT64_C(0)" - else if n <= 64 then + if n = 0 then "UINT64_C(0)" + else if n <= 64 then ( let chars_F = String.make (n / 4) 'F' in - let first = match (n mod 4) with - | 0 -> "" - | 1 -> "1" - | 2 -> "3" - | 3 -> "7" - | _ -> assert false - in + let first = match n mod 4 with 0 -> "" | 1 -> "1" | 2 -> "3" | 3 -> "7" | _ -> assert false in "UINT64_C(0x" ^ first ^ chars_F ^ ")" - else - failwith "Tried to create a mask literal for a vector greater than 64 bits." + ) + else failwith "Tried to create a mask literal for a vector greater than 64 bits." let sgen_value = function | VL_bits ([], _) -> "UINT64_C(0)" | VL_bits (bs, true) -> "UINT64_C(" ^ Sail2_values.show_bitlist bs ^ ")" | VL_bits (bs, false) -> "UINT64_C(" ^ Sail2_values.show_bitlist (List.rev bs) ^ ")" - | VL_int i -> - if Big_int.equal i (min_int 64) then - "INT64_MIN" - else - "INT64_C(" ^ Big_int.to_string i ^ ")" + | VL_int i -> if Big_int.equal i (min_int 64) then "INT64_MIN" else "INT64_C(" ^ Big_int.to_string i ^ ")" | VL_bool true -> "true" | VL_bool false -> "false" | VL_unit -> "UNIT" @@ -1083,200 +907,145 @@ let sgen_value = function | VL_empty_list -> "NULL" | VL_enum element -> Util.zencode_string element | VL_ref r -> "&" ^ Util.zencode_string r - | VL_undefined -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot generate C value for an undefined literal" + | VL_undefined -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot generate C value for an undefined literal" let rec sgen_cval = function | V_id (id, _) -> sgen_name id | V_lit (vl, _) -> sgen_value vl | V_call (op, cvals) -> sgen_call op cvals - | V_field (f, field) -> - Printf.sprintf "%s.%s" (sgen_cval f) (sgen_id field) - | V_tuple_member (f, _, n) -> - Printf.sprintf "%s.ztup%d" (sgen_cval f) n - | V_ctor_kind (f, ctor, _) -> - sgen_cval f ^ ".kind" - ^ " != Kind_" ^ zencode_uid ctor + | V_field (f, field) -> Printf.sprintf "%s.%s" (sgen_cval f) (sgen_id field) + | V_tuple_member (f, _, n) -> Printf.sprintf "%s.ztup%d" (sgen_cval f) n + | V_ctor_kind (f, ctor, _) -> sgen_cval f ^ ".kind" ^ " != Kind_" ^ zencode_uid ctor | V_struct (fields, _) -> - Printf.sprintf "{%s}" - (Util.string_of_list ", " (fun (field, cval) -> zencode_id field ^ " = " ^ sgen_cval cval) fields) - | V_ctor_unwrap (f, ctor, _) -> - Printf.sprintf "%s.%s" - (sgen_cval f) - (sgen_uid ctor) - | V_tuple _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot generate C value for a tuple literal" + Printf.sprintf "{%s}" + (Util.string_of_list ", " (fun (field, cval) -> zencode_id field ^ " = " ^ sgen_cval cval) fields) + | V_ctor_unwrap (f, ctor, _) -> Printf.sprintf "%s.%s" (sgen_cval f) (sgen_uid ctor) + | V_tuple _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot generate C value for a tuple literal" and sgen_call op cvals = let open Printf in - match op, cvals with + match (op, cvals) with | Bnot, [v] -> "!(" ^ sgen_cval v ^ ")" - | List_hd, [v] -> - sprintf "(%s).hd" ("*" ^ sgen_cval v) - | List_tl, [v] -> - sprintf "(%s).tl" ("*" ^ sgen_cval v) - | Eq, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_sbits _ -> - sprintf "eq_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> - sprintf "(%s == %s)" (sgen_cval v1) (sgen_cval v2) - end - | Neq, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_sbits _ -> - sprintf "neq_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> - sprintf "(%s != %s)" (sgen_cval v1) (sgen_cval v2) - end - | Ilt, [v1; v2] -> - sprintf "(%s < %s)" (sgen_cval v1) (sgen_cval v2) - | Igt, [v1; v2] -> - sprintf "(%s > %s)" (sgen_cval v1) (sgen_cval v2) - | Ilteq, [v1; v2] -> - sprintf "(%s <= %s)" (sgen_cval v1) (sgen_cval v2) - | Igteq, [v1; v2] -> - sprintf "(%s >= %s)" (sgen_cval v1) (sgen_cval v2) - | Iadd, [v1; v2] -> - sprintf "(%s + %s)" (sgen_cval v1) (sgen_cval v2) - | Isub, [v1; v2] -> - sprintf "(%s - %s)" (sgen_cval v1) (sgen_cval v2) - | Unsigned 64, [vec] -> - sprintf "((mach_int) %s)" (sgen_cval vec) - | Signed 64, [vec] -> - begin match cval_ctyp vec with - | CT_fbits (n, _) -> - sprintf "fast_signed(%s, %d)" (sgen_cval vec) n - | _ -> assert false - end - | Bvand, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_fbits _ -> - sprintf "(%s & %s)" (sgen_cval v1) (sgen_cval v2) - | CT_sbits _ -> - sprintf "and_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> assert false - end - | Bvnot, [v] -> - begin match cval_ctyp v with - | CT_fbits (n, _) -> - sprintf "(~(%s) & %s)" (sgen_cval v) (sgen_cval (v_mask_lower n)) - | CT_sbits _ -> - sprintf "not_sbits(%s)" (sgen_cval v) - | _ -> assert false - end - | Bvor, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_fbits _ -> - sprintf "(%s | %s)" (sgen_cval v1) (sgen_cval v2) - | CT_sbits _ -> - sprintf "or_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> assert false - end - | Bvxor, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_fbits _ -> - sprintf "(%s ^ %s)" (sgen_cval v1) (sgen_cval v2) - | CT_sbits _ -> - sprintf "xor_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> assert false - end - | Bvadd, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_fbits (n, _) -> - sprintf "((%s + %s) & %s)" (sgen_cval v1) (sgen_cval v2) (sgen_cval (v_mask_lower n)) - | CT_sbits _ -> - sprintf "add_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> assert false - end - | Bvsub, [v1; v2] -> - begin match cval_ctyp v1 with - | CT_fbits (n, _) -> - sprintf "((%s - %s) & %s)" (sgen_cval v1) (sgen_cval v2) (sgen_cval (v_mask_lower n)) - | CT_sbits _ -> - sprintf "sub_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> assert false - end - | Bvaccess, [vec; n] -> - begin match cval_ctyp vec with - | CT_fbits _ -> - sprintf "(UINT64_C(1) & (%s >> %s))" (sgen_cval vec) (sgen_cval n) - | CT_sbits _ -> - sprintf "(UINT64_C(1) & (%s.bits >> %s))" (sgen_cval vec) (sgen_cval n) - | _ -> assert false - end - | Slice len, [vec; start] -> - begin match cval_ctyp vec with - | CT_fbits _ -> - sprintf "(safe_rshift(UINT64_MAX, 64 - %d) & (%s >> %s))" len (sgen_cval vec) (sgen_cval start) - | CT_sbits _ -> - sprintf "(safe_rshift(UINT64_MAX, 64 - %d) & (%s.bits >> %s))" len (sgen_cval vec) (sgen_cval start) - | _ -> assert false - end - | Sslice 64, [vec; start; len] -> - begin match cval_ctyp vec with - | CT_fbits _ -> - sprintf "sslice(%s, %s, %s)" (sgen_cval vec) (sgen_cval start) (sgen_cval len) - | CT_sbits _ -> - sprintf "sslice(%s.bits, %s, %s)" (sgen_cval vec) (sgen_cval start) (sgen_cval len) - | _ -> assert false - end - | Set_slice, [vec; start; slice] -> - begin match cval_ctyp vec, cval_ctyp slice with - | CT_fbits (_, _), CT_fbits (m, _) -> - sprintf "((%s & ~(%s << %s)) | (%s << %s))" (sgen_cval vec) (sgen_mask m) (sgen_cval start) (sgen_cval slice) (sgen_cval start) - | _ -> assert false - end - | Zero_extend n, [v] -> - begin match cval_ctyp v with - | CT_fbits _ -> sgen_cval v - | CT_sbits _ -> - sprintf "fast_zero_extend(%s, %d)" (sgen_cval v) n - | _ -> assert false - end - | Sign_extend n, [v] -> - begin match cval_ctyp v with - | CT_fbits (m, _) -> - sprintf "fast_sign_extend(%s, %d, %d)" (sgen_cval v) m n - | CT_sbits _ -> - sprintf "fast_sign_extend2(%s, %d)" (sgen_cval v) n - | _ -> assert false - end - | Replicate n, [v] -> - begin match cval_ctyp v with - | CT_fbits (m, _) -> - sprintf "fast_replicate_bits(UINT64_C(%d), %s, %d)" m (sgen_cval v) n - | _ -> assert false - end - | Concat, [v1; v2] -> - (* Optimized routines for all combinations of fixed and small bits - appends, where the result is guaranteed to be smaller than 64. *) - begin match cval_ctyp v1, cval_ctyp v2 with - | CT_fbits (0, _), CT_fbits (_, _) -> - sgen_cval v2 - | CT_fbits (_, _), CT_fbits (n2, _) -> - sprintf "(%s << %d) | %s" (sgen_cval v1) n2 (sgen_cval v2) - | CT_sbits (64, _), CT_fbits (n2, _) -> - sprintf "append_sf(%s, %s, %d)" (sgen_cval v1) (sgen_cval v2) n2 - | CT_fbits (n1, _), CT_sbits (64, _) -> - sprintf "append_fs(%s, %d, %s)" (sgen_cval v1) n1 (sgen_cval v2) - | CT_sbits (64, _), CT_sbits (64, _) -> - sprintf "append_ss(%s, %s)" (sgen_cval v1) (sgen_cval v2) - | _ -> assert false - end - | _, _ -> - failwith "Could not generate cval primop" + | List_hd, [v] -> sprintf "(%s).hd" ("*" ^ sgen_cval v) + | List_tl, [v] -> sprintf "(%s).tl" ("*" ^ sgen_cval v) + | Eq, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_sbits _ -> sprintf "eq_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> sprintf "(%s == %s)" (sgen_cval v1) (sgen_cval v2) + end + | Neq, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_sbits _ -> sprintf "neq_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> sprintf "(%s != %s)" (sgen_cval v1) (sgen_cval v2) + end + | Ilt, [v1; v2] -> sprintf "(%s < %s)" (sgen_cval v1) (sgen_cval v2) + | Igt, [v1; v2] -> sprintf "(%s > %s)" (sgen_cval v1) (sgen_cval v2) + | Ilteq, [v1; v2] -> sprintf "(%s <= %s)" (sgen_cval v1) (sgen_cval v2) + | Igteq, [v1; v2] -> sprintf "(%s >= %s)" (sgen_cval v1) (sgen_cval v2) + | Iadd, [v1; v2] -> sprintf "(%s + %s)" (sgen_cval v1) (sgen_cval v2) + | Isub, [v1; v2] -> sprintf "(%s - %s)" (sgen_cval v1) (sgen_cval v2) + | Unsigned 64, [vec] -> sprintf "((mach_int) %s)" (sgen_cval vec) + | Signed 64, [vec] -> begin + match cval_ctyp vec with CT_fbits (n, _) -> sprintf "fast_signed(%s, %d)" (sgen_cval vec) n | _ -> assert false + end + | Bvand, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_fbits _ -> sprintf "(%s & %s)" (sgen_cval v1) (sgen_cval v2) + | CT_sbits _ -> sprintf "and_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvnot, [v] -> begin + match cval_ctyp v with + | CT_fbits (n, _) -> sprintf "(~(%s) & %s)" (sgen_cval v) (sgen_cval (v_mask_lower n)) + | CT_sbits _ -> sprintf "not_sbits(%s)" (sgen_cval v) + | _ -> assert false + end + | Bvor, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_fbits _ -> sprintf "(%s | %s)" (sgen_cval v1) (sgen_cval v2) + | CT_sbits _ -> sprintf "or_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvxor, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_fbits _ -> sprintf "(%s ^ %s)" (sgen_cval v1) (sgen_cval v2) + | CT_sbits _ -> sprintf "xor_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvadd, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_fbits (n, _) -> sprintf "((%s + %s) & %s)" (sgen_cval v1) (sgen_cval v2) (sgen_cval (v_mask_lower n)) + | CT_sbits _ -> sprintf "add_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvsub, [v1; v2] -> begin + match cval_ctyp v1 with + | CT_fbits (n, _) -> sprintf "((%s - %s) & %s)" (sgen_cval v1) (sgen_cval v2) (sgen_cval (v_mask_lower n)) + | CT_sbits _ -> sprintf "sub_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvaccess, [vec; n] -> begin + match cval_ctyp vec with + | CT_fbits _ -> sprintf "(UINT64_C(1) & (%s >> %s))" (sgen_cval vec) (sgen_cval n) + | CT_sbits _ -> sprintf "(UINT64_C(1) & (%s.bits >> %s))" (sgen_cval vec) (sgen_cval n) + | _ -> assert false + end + | Slice len, [vec; start] -> begin + match cval_ctyp vec with + | CT_fbits _ -> sprintf "(safe_rshift(UINT64_MAX, 64 - %d) & (%s >> %s))" len (sgen_cval vec) (sgen_cval start) + | CT_sbits _ -> + sprintf "(safe_rshift(UINT64_MAX, 64 - %d) & (%s.bits >> %s))" len (sgen_cval vec) (sgen_cval start) + | _ -> assert false + end + | Sslice 64, [vec; start; len] -> begin + match cval_ctyp vec with + | CT_fbits _ -> sprintf "sslice(%s, %s, %s)" (sgen_cval vec) (sgen_cval start) (sgen_cval len) + | CT_sbits _ -> sprintf "sslice(%s.bits, %s, %s)" (sgen_cval vec) (sgen_cval start) (sgen_cval len) + | _ -> assert false + end + | Set_slice, [vec; start; slice] -> begin + match (cval_ctyp vec, cval_ctyp slice) with + | CT_fbits (_, _), CT_fbits (m, _) -> + sprintf "((%s & ~(%s << %s)) | (%s << %s))" (sgen_cval vec) (sgen_mask m) (sgen_cval start) (sgen_cval slice) + (sgen_cval start) + | _ -> assert false + end + | Zero_extend n, [v] -> begin + match cval_ctyp v with + | CT_fbits _ -> sgen_cval v + | CT_sbits _ -> sprintf "fast_zero_extend(%s, %d)" (sgen_cval v) n + | _ -> assert false + end + | Sign_extend n, [v] -> begin + match cval_ctyp v with + | CT_fbits (m, _) -> sprintf "fast_sign_extend(%s, %d, %d)" (sgen_cval v) m n + | CT_sbits _ -> sprintf "fast_sign_extend2(%s, %d)" (sgen_cval v) n + | _ -> assert false + end + | Replicate n, [v] -> begin + match cval_ctyp v with + | CT_fbits (m, _) -> sprintf "fast_replicate_bits(UINT64_C(%d), %s, %d)" m (sgen_cval v) n + | _ -> assert false + end + | Concat, [v1; v2] -> begin + (* Optimized routines for all combinations of fixed and small bits + appends, where the result is guaranteed to be smaller than 64. *) + match (cval_ctyp v1, cval_ctyp v2) with + | CT_fbits (0, _), CT_fbits (_, _) -> sgen_cval v2 + | CT_fbits (_, _), CT_fbits (n2, _) -> sprintf "(%s << %d) | %s" (sgen_cval v1) n2 (sgen_cval v2) + | CT_sbits (64, _), CT_fbits (n2, _) -> sprintf "append_sf(%s, %s, %d)" (sgen_cval v1) (sgen_cval v2) n2 + | CT_fbits (n1, _), CT_sbits (64, _) -> sprintf "append_fs(%s, %d, %s)" (sgen_cval v1) n1 (sgen_cval v2) + | CT_sbits (64, _), CT_sbits (64, _) -> sprintf "append_ss(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | _, _ -> failwith "Could not generate cval primop" let sgen_cval_param cval = match cval_ctyp cval with - | CT_lbits direction -> - sgen_cval cval ^ ", " ^ string_of_bool direction - | CT_sbits (_, direction) -> - sgen_cval cval ^ ", " ^ string_of_bool direction - | CT_fbits (len, direction) -> - sgen_cval cval ^ ", UINT64_C(" ^ string_of_int len ^ ") , " ^ string_of_bool direction - | _ -> - sgen_cval cval + | CT_lbits direction -> sgen_cval cval ^ ", " ^ string_of_bool direction + | CT_sbits (_, direction) -> sgen_cval cval ^ ", " ^ string_of_bool direction + | CT_fbits (len, direction) -> sgen_cval cval ^ ", UINT64_C(" ^ string_of_int len ^ ") , " ^ string_of_bool direction + | _ -> sgen_cval cval let rec sgen_clexp l = function | CL_id (Have_exception _, _) -> "have_exception" @@ -1312,74 +1081,69 @@ let rec codegen_conversion l clexp cval = let open Printf in let ctyp_to = clexp_ctyp clexp in let ctyp_from = cval_ctyp cval in - match ctyp_to, ctyp_from with + match (ctyp_to, ctyp_from) with (* When both types are equal, we don't need any conversion. *) | _, _ when ctyp_equal ctyp_to ctyp_from -> - if is_stack_ctyp ctyp_to then - ksprintf string " %s = %s;" (sgen_clexp_pure l clexp) (sgen_cval cval) - else - ksprintf string " COPY(%s)(%s, %s);" (sgen_ctyp_name ctyp_to) (sgen_clexp l clexp) (sgen_cval cval) - - | CT_ref _, _ -> - codegen_conversion l (CL_addr clexp) cval - + if is_stack_ctyp ctyp_to then ksprintf string " %s = %s;" (sgen_clexp_pure l clexp) (sgen_cval cval) + else ksprintf string " COPY(%s)(%s, %s);" (sgen_ctyp_name ctyp_to) (sgen_clexp l clexp) (sgen_cval cval) + | CT_ref _, _ -> codegen_conversion l (CL_addr clexp) cval | CT_vector (_, ctyp_elem_to), CT_vector (_, ctyp_elem_from) -> - let i = ngensym () in - let from = ngensym () in - let into = ngensym () in - ksprintf string " KILL(%s)(%s);" (sgen_ctyp_name ctyp_to) (sgen_clexp l clexp) ^^ hardline - ^^ ksprintf string " internal_vector_init_%s(%s, %s.len);" (sgen_ctyp_name ctyp_to) (sgen_clexp l clexp) (sgen_cval cval) ^^ hardline - ^^ ksprintf string " for (int %s = 0; %s < %s.len; %s++) {" (sgen_name i) (sgen_name i) (sgen_cval cval) (sgen_name i) ^^ hardline - ^^ (if is_stack_ctyp ctyp_elem_from then - ksprintf string " %s %s = %s.data[%s];" (sgen_ctyp ctyp_elem_from) (sgen_name from) (sgen_cval cval) (sgen_name i) - else - ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_from) (sgen_name from) ^^ hardline - ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) ^^ hardline - ^^ ksprintf string " COPY(%s)(&%s, %s.data[%s]);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) (sgen_cval cval) (sgen_name i) - ) - ^^ hardline - ^^ ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_to) (sgen_name into) - ^^ (if is_stack_ctyp ctyp_elem_to then - empty - else - hardline ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into) - ) - ^^ nest 2 (hardline - ^^ codegen_conversion l (CL_id (into, ctyp_elem_to)) (V_id (from, ctyp_elem_from))) - ^^ hardline - ^^ (if is_stack_ctyp ctyp_elem_to then - ksprintf string " %s.data[%s] = %s;" (sgen_clexp_pure l clexp) (sgen_name i) (sgen_name into) - else - ksprintf string " COPY(%s)(&((%s)->data[%s]), %s);" (sgen_ctyp_name ctyp_elem_to) (sgen_clexp l clexp) (sgen_name i) (sgen_name into) - ^^ hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into) - ) - ^^ (if is_stack_ctyp ctyp_elem_from then - empty - else - hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) - ) - ^^ hardline - ^^ string " }" - + let i = ngensym () in + let from = ngensym () in + let into = ngensym () in + ksprintf string " KILL(%s)(%s);" (sgen_ctyp_name ctyp_to) (sgen_clexp l clexp) + ^^ hardline + ^^ ksprintf string " internal_vector_init_%s(%s, %s.len);" (sgen_ctyp_name ctyp_to) (sgen_clexp l clexp) + (sgen_cval cval) + ^^ hardline + ^^ ksprintf string " for (int %s = 0; %s < %s.len; %s++) {" (sgen_name i) (sgen_name i) (sgen_cval cval) + (sgen_name i) + ^^ hardline + ^^ ( if is_stack_ctyp ctyp_elem_from then + ksprintf string " %s %s = %s.data[%s];" (sgen_ctyp ctyp_elem_from) (sgen_name from) (sgen_cval cval) + (sgen_name i) + else + ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_from) (sgen_name from) + ^^ hardline + ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) + ^^ hardline + ^^ ksprintf string " COPY(%s)(&%s, %s.data[%s]);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) + (sgen_cval cval) (sgen_name i) + ) + ^^ hardline + ^^ ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_to) (sgen_name into) + ^^ ( if is_stack_ctyp ctyp_elem_to then empty + else hardline ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into) + ) + ^^ nest 2 (hardline ^^ codegen_conversion l (CL_id (into, ctyp_elem_to)) (V_id (from, ctyp_elem_from))) + ^^ hardline + ^^ ( if is_stack_ctyp ctyp_elem_to then + ksprintf string " %s.data[%s] = %s;" (sgen_clexp_pure l clexp) (sgen_name i) (sgen_name into) + else + ksprintf string " COPY(%s)(&((%s)->data[%s]), %s);" (sgen_ctyp_name ctyp_elem_to) (sgen_clexp l clexp) + (sgen_name i) (sgen_name into) + ^^ hardline + ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into) + ) + ^^ ( if is_stack_ctyp ctyp_elem_from then empty + else hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) + ) + ^^ hardline ^^ string " }" (* If we have to convert between tuple types, convert the fields individually. *) | CT_tup ctyps_to, CT_tup ctyps_from when List.length ctyps_to = List.length ctyps_from -> - let len = List.length ctyps_to in - let conversions = - List.mapi (fun i _ -> codegen_conversion l (CL_tuple (clexp, i)) (V_tuple_member (cval, len, i))) ctyps_from - in - string " /* conversions */" - ^^ hardline - ^^ separate hardline conversions - ^^ hardline - ^^ string " /* end conversions */" - + let len = List.length ctyps_to in + let conversions = + List.mapi (fun i _ -> codegen_conversion l (CL_tuple (clexp, i)) (V_tuple_member (cval, len, i))) ctyps_from + in + string " /* conversions */" ^^ hardline ^^ separate hardline conversions ^^ hardline + ^^ string " /* end conversions */" (* For anything not special cased, just try to call a appropriate CONVERT_OF function. *) | _, _ when is_stack_ctyp (clexp_ctyp clexp) -> - ksprintf string " %s = CONVERT_OF(%s, %s)(%s);" - (sgen_clexp_pure l clexp) (sgen_ctyp_name ctyp_to) (sgen_ctyp_name ctyp_from) (sgen_cval_param cval) + ksprintf string " %s = CONVERT_OF(%s, %s)(%s);" (sgen_clexp_pure l clexp) (sgen_ctyp_name ctyp_to) + (sgen_ctyp_name ctyp_from) (sgen_cval_param cval) | _, _ -> - ksprintf string " CONVERT_OF(%s, %s)(%s, %s);" - (sgen_ctyp_name ctyp_to) (sgen_ctyp_name ctyp_from) (sgen_clexp l clexp) (sgen_cval_param cval) + ksprintf string " CONVERT_OF(%s, %s)(%s, %s);" (sgen_ctyp_name ctyp_to) (sgen_ctyp_name ctyp_from) + (sgen_clexp l clexp) (sgen_cval_param cval) (* PPrint doesn't provide a nice way to filter out empty documents *) let squash_empty docs = List.filter (fun doc -> requirement doc > 0) docs @@ -1388,409 +1152,375 @@ let sq_separate_map sep f xs = separate sep (squash_empty (List.map f xs)) let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = let open Printf in match instr with - | I_decl (ctyp, id) when is_stack_ctyp ctyp -> - ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) + | I_decl (ctyp, id) when is_stack_ctyp ctyp -> ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) | I_decl (ctyp, id) -> - ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) ^^ hardline - ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id) - + ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) + ^^ hardline + ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id) | I_copy (clexp, cval) -> codegen_conversion l clexp cval - - | I_jump (cval, label) -> - ksprintf string " if (%s) goto %s;" (sgen_cval cval) label - + | I_jump (cval, label) -> ksprintf string " if (%s) goto %s;" (sgen_cval cval) label | I_if (cval, [then_instr], [], _) -> - ksprintf string " if (%s)" (sgen_cval cval) ^^ hardline - ^^ twice space ^^ codegen_instr fid ctx then_instr + ksprintf string " if (%s)" (sgen_cval cval) ^^ hardline ^^ twice space ^^ codegen_instr fid ctx then_instr | I_if (cval, then_instrs, [], _) -> - string " if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space - ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr fid ctx) then_instrs) (twice space ^^ rbrace) + string " if" ^^ space + ^^ parens (string (sgen_cval cval)) + ^^ space + ^^ surround 2 0 lbrace (separate_map hardline (codegen_instr fid ctx) then_instrs) (twice space ^^ rbrace) | I_if (cval, then_instrs, else_instrs, _) -> - string " if" ^^ space ^^ parens (string (sgen_cval cval)) ^^ space - ^^ surround 2 0 lbrace (sq_separate_map hardline (codegen_instr fid ctx) then_instrs) (twice space ^^ rbrace) - ^^ space ^^ string "else" ^^ space - ^^ surround 2 0 lbrace (sq_separate_map hardline (codegen_instr fid ctx) else_instrs) (twice space ^^ rbrace) - + string " if" ^^ space + ^^ parens (string (sgen_cval cval)) + ^^ space + ^^ surround 2 0 lbrace (sq_separate_map hardline (codegen_instr fid ctx) then_instrs) (twice space ^^ rbrace) + ^^ space ^^ string "else" ^^ space + ^^ surround 2 0 lbrace (sq_separate_map hardline (codegen_instr fid ctx) else_instrs) (twice space ^^ rbrace) | I_block instrs -> - string " {" - ^^ jump 2 2 (sq_separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline - ^^ string " }" - + string " {" ^^ jump 2 2 (sq_separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline ^^ string " }" | I_try_block instrs -> - string " { /* try */" - ^^ jump 2 2 (sq_separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline - ^^ string " }" - + string " { /* try */" + ^^ jump 2 2 (sq_separate_map hardline (codegen_instr fid ctx) instrs) + ^^ hardline ^^ string " }" | I_funcall (x, special_extern, f, args) -> - let c_args = Util.string_of_list ", " sgen_cval args in - let ctyp = clexp_ctyp x in - let is_extern = ctx_is_extern (fst f) ctx || special_extern in - let fname = - if special_extern then - string_of_id (fst f) - else if ctx_is_extern (fst f) ctx then - ctx_get_extern (fst f) ctx - else - sgen_function_uid f - in - let fname = - match fname, ctyp with - | "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp) - | "sail_cons", _ -> - begin match snd f with - | [ctyp] -> Util.zencode_string ("cons#" ^ string_of_ctyp ctyp) - | _ -> c_error "cons without specified type" + let c_args = Util.string_of_list ", " sgen_cval args in + let ctyp = clexp_ctyp x in + let is_extern = ctx_is_extern (fst f) ctx || special_extern in + let fname = + if special_extern then string_of_id (fst f) + else if ctx_is_extern (fst f) ctx then ctx_get_extern (fst f) ctx + else sgen_function_uid f + in + let fname = + match (fname, ctyp) with + | "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp) + | "sail_cons", _ -> begin + match snd f with + | [ctyp] -> Util.zencode_string ("cons#" ^ string_of_ctyp ctyp) + | _ -> c_error "cons without specified type" end - | "eq_anything", _ -> - begin match args with - | cval :: _ -> Printf.sprintf "eq_%s" (sgen_ctyp_name (cval_ctyp cval)) - | _ -> c_error "eq_anything function with bad arity." + | "eq_anything", _ -> begin + match args with + | cval :: _ -> Printf.sprintf "eq_%s" (sgen_ctyp_name (cval_ctyp cval)) + | _ -> c_error "eq_anything function with bad arity." end - | "length", _ -> - begin match args with - | cval :: _ -> Printf.sprintf "length_%s" (sgen_ctyp_name (cval_ctyp cval)) - | _ -> c_error "length function with bad arity." + | "length", _ -> begin + match args with + | cval :: _ -> Printf.sprintf "length_%s" (sgen_ctyp_name (cval_ctyp cval)) + | _ -> c_error "length function with bad arity." end - | "vector_access", CT_bit -> "bitvector_access" - | "vector_access", _ -> - begin match args with - | cval :: _ -> Printf.sprintf "vector_access_%s" (sgen_ctyp_name (cval_ctyp cval)) - | _ -> c_error "vector access function with bad arity." + | "vector_access", CT_bit -> "bitvector_access" + | "vector_access", _ -> begin + match args with + | cval :: _ -> Printf.sprintf "vector_access_%s" (sgen_ctyp_name (cval_ctyp cval)) + | _ -> c_error "vector access function with bad arity." end - | "vector_update_subrange", _ -> Printf.sprintf "vector_update_subrange_%s" (sgen_ctyp_name ctyp) - | "vector_subrange", _ -> Printf.sprintf "vector_subrange_%s" (sgen_ctyp_name ctyp) - | "vector_update", CT_fbits _ -> "update_fbits" - | "vector_update", CT_lbits _ -> "update_lbits" - | "vector_update", _ -> Printf.sprintf "vector_update_%s" (sgen_ctyp_name ctyp) - | "string_of_bits", _ -> - begin match cval_ctyp (List.nth args 0) with - | CT_fbits _ -> "string_of_fbits" - | CT_lbits _ -> "string_of_lbits" - | _ -> assert false + | "vector_update_subrange", _ -> Printf.sprintf "vector_update_subrange_%s" (sgen_ctyp_name ctyp) + | "vector_subrange", _ -> Printf.sprintf "vector_subrange_%s" (sgen_ctyp_name ctyp) + | "vector_update", CT_fbits _ -> "update_fbits" + | "vector_update", CT_lbits _ -> "update_lbits" + | "vector_update", _ -> Printf.sprintf "vector_update_%s" (sgen_ctyp_name ctyp) + | "string_of_bits", _ -> begin + match cval_ctyp (List.nth args 0) with + | CT_fbits _ -> "string_of_fbits" + | CT_lbits _ -> "string_of_lbits" + | _ -> assert false end - | "decimal_string_of_bits", _ -> - begin match cval_ctyp (List.nth args 0) with - | CT_fbits _ -> "decimal_string_of_fbits" - | CT_lbits _ -> "decimal_string_of_lbits" - | _ -> assert false + | "decimal_string_of_bits", _ -> begin + match cval_ctyp (List.nth args 0) with + | CT_fbits _ -> "decimal_string_of_fbits" + | CT_lbits _ -> "decimal_string_of_lbits" + | _ -> assert false end - | "internal_vector_update", _ -> Printf.sprintf "internal_vector_update_%s" (sgen_ctyp_name ctyp) - | "internal_vector_init", _ -> Printf.sprintf "internal_vector_init_%s" (sgen_ctyp_name ctyp) - | "undefined_bitvector", CT_fbits _ -> "UNDEFINED(fbits)" - | "undefined_bitvector", CT_lbits _ -> "UNDEFINED(lbits)" - | "undefined_bit", _ -> "UNDEFINED(fbits)" - | "undefined_vector", _ -> Printf.sprintf "UNDEFINED(vector_%s)" (sgen_ctyp_name ctyp) - | "undefined_list", _ -> Printf.sprintf "UNDEFINED(%s)" (sgen_ctyp_name ctyp) - | fname, _ -> fname - in - if fname = "reg_deref" then - if is_stack_ctyp ctyp then - string (Printf.sprintf " %s = *(%s);" (sgen_clexp_pure l x) c_args) - else - string (Printf.sprintf " COPY(%s)(&%s, *(%s));" (sgen_ctyp_name ctyp) (sgen_clexp_pure l x) c_args) - else - if is_stack_ctyp ctyp then - string (Printf.sprintf " %s = %s(%s%s);" (sgen_clexp_pure l x) fname (extra_arguments is_extern) c_args) - else - string (Printf.sprintf " %s(%s%s, %s);" fname (extra_arguments is_extern) (sgen_clexp l x) c_args) - - | I_clear (ctyp, _) when is_stack_ctyp ctyp -> - empty - | I_clear (ctyp, id) -> - string (Printf.sprintf " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) - + | "internal_vector_update", _ -> Printf.sprintf "internal_vector_update_%s" (sgen_ctyp_name ctyp) + | "internal_vector_init", _ -> Printf.sprintf "internal_vector_init_%s" (sgen_ctyp_name ctyp) + | "undefined_bitvector", CT_fbits _ -> "UNDEFINED(fbits)" + | "undefined_bitvector", CT_lbits _ -> "UNDEFINED(lbits)" + | "undefined_bit", _ -> "UNDEFINED(fbits)" + | "undefined_vector", _ -> Printf.sprintf "UNDEFINED(vector_%s)" (sgen_ctyp_name ctyp) + | "undefined_list", _ -> Printf.sprintf "UNDEFINED(%s)" (sgen_ctyp_name ctyp) + | fname, _ -> fname + in + if fname = "reg_deref" then + if is_stack_ctyp ctyp then string (Printf.sprintf " %s = *(%s);" (sgen_clexp_pure l x) c_args) + else string (Printf.sprintf " COPY(%s)(&%s, *(%s));" (sgen_ctyp_name ctyp) (sgen_clexp_pure l x) c_args) + else if is_stack_ctyp ctyp then + string (Printf.sprintf " %s = %s(%s%s);" (sgen_clexp_pure l x) fname (extra_arguments is_extern) c_args) + else string (Printf.sprintf " %s(%s%s, %s);" fname (extra_arguments is_extern) (sgen_clexp l x) c_args) + | I_clear (ctyp, _) when is_stack_ctyp ctyp -> empty + | I_clear (ctyp, id) -> string (Printf.sprintf " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) | I_init (ctyp, id, cval) -> - codegen_instr fid ctx (idecl l ctyp id) ^^ hardline - ^^ codegen_conversion l (CL_id (id, ctyp)) cval - + codegen_instr fid ctx (idecl l ctyp id) ^^ hardline ^^ codegen_conversion l (CL_id (id, ctyp)) cval | I_reinit (ctyp, id, cval) -> - codegen_instr fid ctx (ireset l ctyp id) ^^ hardline - ^^ codegen_conversion l (CL_id (id, ctyp)) cval - - | I_reset (ctyp, id) when is_stack_ctyp ctyp -> - string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_name id)) - | I_reset (ctyp, id) -> - string (Printf.sprintf " RECREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) - - | I_return cval -> - string (Printf.sprintf " return %s;" (sgen_cval cval)) - - | I_throw _ -> - c_error ~loc:l "I_throw reached code generator" - + codegen_instr fid ctx (ireset l ctyp id) ^^ hardline ^^ codegen_conversion l (CL_id (id, ctyp)) cval + | I_reset (ctyp, id) when is_stack_ctyp ctyp -> string (Printf.sprintf " %s %s;" (sgen_ctyp ctyp) (sgen_name id)) + | I_reset (ctyp, id) -> string (Printf.sprintf " RECREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) + | I_return cval -> string (Printf.sprintf " return %s;" (sgen_cval cval)) + | I_throw _ -> c_error ~loc:l "I_throw reached code generator" | I_undefined ctyp -> - let rec codegen_exn_return ctyp = - match ctyp with - | CT_unit -> "UNIT", [] - | CT_bit -> "UINT64_C(0)", [] - | CT_fint _ -> "INT64_C(0xdeadc0de)", [] - | CT_lint when !optimize_fixed_int -> "((sail_int) 0xdeadc0de)", [] - | CT_fbits _ -> "UINT64_C(0xdeadc0de)", [] - | CT_sbits _ -> "undefined_sbits()", [] - | CT_lbits _ when !optimize_fixed_bits -> "undefined_lbits(false)", [] - | CT_bool -> "false", [] - | CT_enum (_, ctor :: _) -> sgen_id ctor, [] - | CT_tup ctyps when is_stack_ctyp ctyp -> - let gs = ngensym () in - let fold (inits, prev) (n, ctyp) = - let init, prev' = codegen_exn_return ctyp in - Printf.sprintf ".ztup%d = %s" n init :: inits, prev @ prev' - in - let inits, prev = List.fold_left fold ([], []) (List.mapi (fun i x -> (i, x)) ctyps) in - sgen_name gs, - [Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs) - ^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev - | CT_struct (_, ctors) when is_stack_ctyp ctyp -> - let gs = ngensym () in - let fold (inits, prev) (id, ctyp) = - let init, prev' = codegen_exn_return ctyp in - Printf.sprintf ".%s = %s" (sgen_id id) init :: inits, prev @ prev' - in - let inits, prev = List.fold_left fold ([], []) ctors in - sgen_name gs, - [Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs) - ^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev - | CT_ref _ -> "NULL", [] - | ctyp -> c_error ("Cannot create undefined value for type: " ^ string_of_ctyp ctyp) - in - let ret, prev = codegen_exn_return ctyp in - separate_map hardline (fun str -> string (" " ^ str)) (List.rev prev) - ^^ hardline - ^^ string (Printf.sprintf " return %s;" ret) - - | I_comment str -> - string (" /* " ^ str ^ " */") - - | I_label str -> - string (str ^ ": ;") - - | I_goto str -> - string (Printf.sprintf " goto %s;" str) - + let rec codegen_exn_return ctyp = + match ctyp with + | CT_unit -> ("UNIT", []) + | CT_bit -> ("UINT64_C(0)", []) + | CT_fint _ -> ("INT64_C(0xdeadc0de)", []) + | CT_lint when !optimize_fixed_int -> ("((sail_int) 0xdeadc0de)", []) + | CT_fbits _ -> ("UINT64_C(0xdeadc0de)", []) + | CT_sbits _ -> ("undefined_sbits()", []) + | CT_lbits _ when !optimize_fixed_bits -> ("undefined_lbits(false)", []) + | CT_bool -> ("false", []) + | CT_enum (_, ctor :: _) -> (sgen_id ctor, []) + | CT_tup ctyps when is_stack_ctyp ctyp -> + let gs = ngensym () in + let fold (inits, prev) (n, ctyp) = + let init, prev' = codegen_exn_return ctyp in + (Printf.sprintf ".ztup%d = %s" n init :: inits, prev @ prev') + in + let inits, prev = List.fold_left fold ([], []) (List.mapi (fun i x -> (i, x)) ctyps) in + ( sgen_name gs, + [ + Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs) + ^ Util.string_of_list ", " (fun x -> x) inits + ^ " };"; + ] + @ prev + ) + | CT_struct (_, ctors) when is_stack_ctyp ctyp -> + let gs = ngensym () in + let fold (inits, prev) (id, ctyp) = + let init, prev' = codegen_exn_return ctyp in + (Printf.sprintf ".%s = %s" (sgen_id id) init :: inits, prev @ prev') + in + let inits, prev = List.fold_left fold ([], []) ctors in + ( sgen_name gs, + [ + Printf.sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs) + ^ Util.string_of_list ", " (fun x -> x) inits + ^ " };"; + ] + @ prev + ) + | CT_ref _ -> ("NULL", []) + | ctyp -> c_error ("Cannot create undefined value for type: " ^ string_of_ctyp ctyp) + in + let ret, prev = codegen_exn_return ctyp in + separate_map hardline (fun str -> string (" " ^ str)) (List.rev prev) + ^^ hardline + ^^ string (Printf.sprintf " return %s;" ret) + | I_comment str -> string (" /* " ^ str ^ " */") + | I_label str -> string (str ^ ": ;") + | I_goto str -> string (Printf.sprintf " goto %s;" str) | I_raw _ when ctx.no_raw -> empty - | I_raw str -> - string (" " ^ str) - + | I_raw str -> string (" " ^ str) | I_end _ -> assert false - - | I_exit _ -> - string (" sail_match_failure(\"" ^ String.escaped (string_of_id fid) ^ "\");") + | I_exit _ -> string (" sail_match_failure(\"" ^ String.escaped (string_of_id fid) ^ "\");") let codegen_type_def = function - | CTD_enum (id, ((first_id :: _) as ids)) -> - let codegen_eq = - let name = sgen_id id in - string (Printf.sprintf "static bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name) - in - let codegen_undefined = - let name = sgen_id id in - string (Printf.sprintf "static enum %s UNDEFINED(%s)(unit u) { return %s; }" name name (sgen_id first_id)) - in - string (Printf.sprintf "// enum %s" (string_of_id id)) ^^ hardline - ^^ separate space [string "enum"; codegen_id id; lbrace; separate_map (comma ^^ space) codegen_id ids; rbrace ^^ semi] - ^^ twice hardline - ^^ codegen_eq - ^^ twice hardline - ^^ codegen_undefined - + | CTD_enum (id, (first_id :: _ as ids)) -> + let codegen_eq = + let name = sgen_id id in + string (Printf.sprintf "static bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name) + in + let codegen_undefined = + let name = sgen_id id in + string (Printf.sprintf "static enum %s UNDEFINED(%s)(unit u) { return %s; }" name name (sgen_id first_id)) + in + string (Printf.sprintf "// enum %s" (string_of_id id)) + ^^ hardline + ^^ separate space + [string "enum"; codegen_id id; lbrace; separate_map (comma ^^ space) codegen_id ids; rbrace ^^ semi] + ^^ twice hardline ^^ codegen_eq ^^ twice hardline ^^ codegen_undefined | CTD_enum (id, []) -> c_error ("Cannot compile empty enum " ^ string_of_id id) - | CTD_struct (id, ctors) -> - let struct_ctyp = CT_struct (id, ctors) in - (* Generate a set_T function for every struct T *) - let codegen_set (id, ctyp) = - if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op.%s;" (sgen_id id) (sgen_id id)) - else - string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) - in - let codegen_setter id ctors = - string (let n = sgen_id id in Printf.sprintf "static void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space - ^^ surround 2 0 lbrace - (separate_map hardline codegen_set (Bindings.bindings ctors)) - rbrace - in - (* Generate an init/clear_T function for every struct T *) - let codegen_field_init f (id, ctyp) = - if not (is_stack_ctyp ctyp) then - [string (Printf.sprintf "%s(%s)(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_id id))] - else [] - in - let codegen_init f id ctors = - string (let n = sgen_id id in Printf.sprintf "static void %s(%s)(struct %s *op)" f n n) ^^ space - ^^ surround 2 0 lbrace - (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) - rbrace - in - let codegen_eq = - let codegen_eq_test (id, ctyp) = - string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) - in - string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2)" (sgen_id id) (sgen_id id) (sgen_id id)) - ^^ space - ^^ surround 2 0 lbrace - (string "return" ^^ space - ^^ separate_map (string " && ") codegen_eq_test ctors - ^^ string ";") - rbrace - in - (* Generate the struct and add the generated functions *) - let codegen_ctor (id, ctyp) = - string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id - in - string (Printf.sprintf "// struct %s" (string_of_id id)) ^^ hardline - ^^ string "struct" ^^ space ^^ codegen_id id ^^ space - ^^ surround 2 0 lbrace - (separate_map (semi ^^ hardline) codegen_ctor ctors ^^ semi) - rbrace - ^^ semi ^^ twice hardline - ^^ codegen_setter id (ctor_bindings ctors) - ^^ (if not (is_stack_ctyp struct_ctyp) then - twice hardline - ^^ codegen_init "CREATE" id (ctor_bindings ctors) - ^^ twice hardline - ^^ codegen_init "RECREATE" id (ctor_bindings ctors) - ^^ twice hardline - ^^ codegen_init "KILL" id (ctor_bindings ctors) - else empty) - ^^ twice hardline - ^^ codegen_eq - + let struct_ctyp = CT_struct (id, ctors) in + (* Generate a set_T function for every struct T *) + let codegen_set (id, ctyp) = + if is_stack_ctyp ctyp then string (Printf.sprintf "rop->%s = op.%s;" (sgen_id id) (sgen_id id)) + else string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + in + let codegen_setter id ctors = + string + (let n = sgen_id id in + Printf.sprintf "static void COPY(%s)(struct %s *rop, const struct %s op)" n n n + ) + ^^ space + ^^ surround 2 0 lbrace (separate_map hardline codegen_set (Bindings.bindings ctors)) rbrace + in + (* Generate an init/clear_T function for every struct T *) + let codegen_field_init f (id, ctyp) = + if not (is_stack_ctyp ctyp) then + [string (Printf.sprintf "%s(%s)(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_id id))] + else [] + in + let codegen_init f id ctors = + string + (let n = sgen_id id in + Printf.sprintf "static void %s(%s)(struct %s *op)" f n n + ) + ^^ space + ^^ surround 2 0 lbrace + (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) + rbrace + in + let codegen_eq = + let codegen_eq_test (id, ctyp) = + string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + in + string + (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2)" (sgen_id id) (sgen_id id) (sgen_id id)) + ^^ space + ^^ surround 2 0 lbrace + (string "return" ^^ space ^^ separate_map (string " && ") codegen_eq_test ctors ^^ string ";") + rbrace + in + (* Generate the struct and add the generated functions *) + let codegen_ctor (id, ctyp) = string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id in + string (Printf.sprintf "// struct %s" (string_of_id id)) + ^^ hardline ^^ string "struct" ^^ space ^^ codegen_id id ^^ space + ^^ surround 2 0 lbrace (separate_map (semi ^^ hardline) codegen_ctor ctors ^^ semi) rbrace + ^^ semi ^^ twice hardline + ^^ codegen_setter id (ctor_bindings ctors) + ^^ ( if not (is_stack_ctyp struct_ctyp) then + twice hardline + ^^ codegen_init "CREATE" id (ctor_bindings ctors) + ^^ twice hardline + ^^ codegen_init "RECREATE" id (ctor_bindings ctors) + ^^ twice hardline + ^^ codegen_init "KILL" id (ctor_bindings ctors) + else empty + ) + ^^ twice hardline ^^ codegen_eq | CTD_variant (id, tus) -> - let codegen_tu (ctor_id, ctyp) = - separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_id ctor_id ^^ semi; rbrace] - in - (* Create an if, else if, ... block that does something for each constructor *) - let rec each_ctor v f = function - | [] -> string "{}" - | [(ctor_id, ctyp)] -> - string (Printf.sprintf "if (%skind == Kind_%s)" v (sgen_id ctor_id)) ^^ lbrace ^^ hardline - ^^ jump 0 2 (f ctor_id ctyp) - ^^ hardline ^^ rbrace - | (ctor_id, ctyp) :: ctors -> - string (Printf.sprintf "if (%skind == Kind_%s) " v (sgen_id ctor_id)) ^^ lbrace ^^ hardline - ^^ jump 0 2 (f ctor_id ctyp) - ^^ hardline ^^ rbrace ^^ string " else " ^^ each_ctor v f ctors - in - let codegen_init = - let n = sgen_id id in - let ctor_id, ctyp = List.hd tus in - string (Printf.sprintf "static void CREATE(%s)(struct %s *op)" n n) - ^^ hardline - ^^ surround 2 0 lbrace - (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_id ctor_id)) ^^ hardline - ^^ if not (is_stack_ctyp ctyp) then - string (Printf.sprintf "CREATE(%s)(&op->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) - else empty) - rbrace - in - let codegen_reinit = - let n = sgen_id id in - string (Printf.sprintf "static void RECREATE(%s)(struct %s *op) {}" n n) - in - let clear_field v ctor_id ctyp = - if is_stack_ctyp ctyp then - string (Printf.sprintf "/* do nothing */") - else - string (Printf.sprintf "KILL(%s)(&%s->%s);" (sgen_ctyp_name ctyp) v (sgen_id ctor_id)) - in - let codegen_clear = - let n = sgen_id id in - string (Printf.sprintf "static void KILL(%s)(struct %s *op)" n n) ^^ hardline - ^^ surround 2 0 lbrace - (each_ctor "op->" (clear_field "op") tus ^^ semi) - rbrace - in - let codegen_ctor (ctor_id, ctyp) = - let ctor_args, tuple, tuple_cleanup = - Printf.sprintf "%s op" (sgen_ctyp ctyp), empty, empty - in - string (Printf.sprintf "static void %s(%sstruct %s *rop, %s)" (sgen_function_id ctor_id) (extra_params ()) (sgen_id id) ctor_args) ^^ hardline - ^^ surround 2 0 lbrace - (tuple - ^^ each_ctor "rop->" (clear_field "rop") tus ^^ hardline - ^^ string ("rop->kind = Kind_" ^ sgen_id ctor_id) ^^ semi ^^ hardline - ^^ if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op;" (sgen_id ctor_id)) - else - string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ hardline - ^^ string (Printf.sprintf "COPY(%s)(&rop->%s, op);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ hardline - ^^ tuple_cleanup) - rbrace - in - let codegen_setter = - let n = sgen_id id in - let set_field ctor_id ctyp = - if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op.%s;" (sgen_id ctor_id) (sgen_id ctor_id)) - else - string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) - ^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) - in - string (Printf.sprintf "static void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline - ^^ surround 2 0 lbrace - (each_ctor "rop->" (clear_field "rop") tus - ^^ semi ^^ hardline - ^^ string "rop->kind = op.kind" - ^^ semi ^^ hardline - ^^ each_ctor "op." set_field tus) - rbrace - in - let codegen_eq = - let codegen_eq_test ctor_id ctyp = - string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) - in - let rec codegen_eq_tests = function - | [] -> string "return false;" - | (ctor_id, ctyp) :: ctors -> - string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_id ctor_id) (sgen_id ctor_id)) ^^ lbrace ^^ hardline - ^^ jump 0 2 (codegen_eq_test ctor_id ctyp) - ^^ hardline ^^ rbrace ^^ string " else " ^^ codegen_eq_tests ctors - in - let n = sgen_id id in - string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2) " n n n) - ^^ surround 2 0 lbrace (codegen_eq_tests tus) rbrace - in - string (Printf.sprintf "// union %s" (string_of_id id)) ^^ hardline - ^^ string "enum" ^^ space - ^^ string ("kind_" ^ sgen_id id) ^^ space - ^^ separate space [ lbrace; - separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_id id)) (List.map fst tus); - rbrace ^^ semi ] - ^^ twice hardline - ^^ string "struct" ^^ space ^^ codegen_id id ^^ space - ^^ surround 2 0 lbrace - (separate space [string "enum"; string ("kind_" ^ sgen_id id); string "kind" ^^ semi] - ^^ hardline - ^^ string "union" ^^ space - ^^ surround 2 0 lbrace - (separate_map (semi ^^ hardline) codegen_tu tus ^^ semi) - rbrace - ^^ semi) - rbrace - ^^ semi - ^^ twice hardline - ^^ codegen_init - ^^ twice hardline - ^^ codegen_reinit - ^^ twice hardline - ^^ codegen_clear - ^^ twice hardline - ^^ codegen_setter - ^^ twice hardline - ^^ codegen_eq - ^^ twice hardline - ^^ separate_map (twice hardline) codegen_ctor tus - (* If this is the exception type, then we setup up some global variables to deal with exceptions. *) - ^^ if string_of_id id = "exception" then - twice hardline - ^^ string "struct zexception *current_exception = NULL;" - ^^ hardline - ^^ string "bool have_exception = false;" - ^^ hardline - ^^ string "sail_string *throw_location = NULL;" - else - empty + let codegen_tu (ctor_id, ctyp) = + separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_id ctor_id ^^ semi; rbrace] + in + (* Create an if, else if, ... block that does something for each constructor *) + let rec each_ctor v f = function + | [] -> string "{}" + | [(ctor_id, ctyp)] -> + string (Printf.sprintf "if (%skind == Kind_%s)" v (sgen_id ctor_id)) + ^^ lbrace ^^ hardline + ^^ jump 0 2 (f ctor_id ctyp) + ^^ hardline ^^ rbrace + | (ctor_id, ctyp) :: ctors -> + string (Printf.sprintf "if (%skind == Kind_%s) " v (sgen_id ctor_id)) + ^^ lbrace ^^ hardline + ^^ jump 0 2 (f ctor_id ctyp) + ^^ hardline ^^ rbrace ^^ string " else " ^^ each_ctor v f ctors + in + let codegen_init = + let n = sgen_id id in + let ctor_id, ctyp = List.hd tus in + string (Printf.sprintf "static void CREATE(%s)(struct %s *op)" n n) + ^^ hardline + ^^ surround 2 0 lbrace + (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_id ctor_id)) + ^^ hardline + ^^ + if not (is_stack_ctyp ctyp) then + string (Printf.sprintf "CREATE(%s)(&op->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) + else empty + ) + rbrace + in + let codegen_reinit = + let n = sgen_id id in + string (Printf.sprintf "static void RECREATE(%s)(struct %s *op) {}" n n) + in + let clear_field v ctor_id ctyp = + if is_stack_ctyp ctyp then string (Printf.sprintf "/* do nothing */") + else string (Printf.sprintf "KILL(%s)(&%s->%s);" (sgen_ctyp_name ctyp) v (sgen_id ctor_id)) + in + let codegen_clear = + let n = sgen_id id in + string (Printf.sprintf "static void KILL(%s)(struct %s *op)" n n) + ^^ hardline + ^^ surround 2 0 lbrace (each_ctor "op->" (clear_field "op") tus ^^ semi) rbrace + in + let codegen_ctor (ctor_id, ctyp) = + let ctor_args, tuple, tuple_cleanup = (Printf.sprintf "%s op" (sgen_ctyp ctyp), empty, empty) in + string + (Printf.sprintf "static void %s(%sstruct %s *rop, %s)" (sgen_function_id ctor_id) (extra_params ()) + (sgen_id id) ctor_args + ) + ^^ hardline + ^^ surround 2 0 lbrace + (tuple + ^^ each_ctor "rop->" (clear_field "rop") tus + ^^ hardline + ^^ string ("rop->kind = Kind_" ^ sgen_id ctor_id) + ^^ semi ^^ hardline + ^^ + if is_stack_ctyp ctyp then string (Printf.sprintf "rop->%s = op;" (sgen_id ctor_id)) + else + string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) + ^^ hardline + ^^ string (Printf.sprintf "COPY(%s)(&rop->%s, op);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) + ^^ hardline ^^ tuple_cleanup + ) + rbrace + in + let codegen_setter = + let n = sgen_id id in + let set_field ctor_id ctyp = + if is_stack_ctyp ctyp then string (Printf.sprintf "rop->%s = op.%s;" (sgen_id ctor_id) (sgen_id ctor_id)) + else + string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) + ^^ string + (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) + in + string (Printf.sprintf "static void COPY(%s)(struct %s *rop, struct %s op)" n n n) + ^^ hardline + ^^ surround 2 0 lbrace + (each_ctor "rop->" (clear_field "rop") tus + ^^ semi ^^ hardline ^^ string "rop->kind = op.kind" ^^ semi ^^ hardline ^^ each_ctor "op." set_field tus + ) + rbrace + in + let codegen_eq = + let codegen_eq_test ctor_id ctyp = + string + (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id) + ) + in + let rec codegen_eq_tests = function + | [] -> string "return false;" + | (ctor_id, ctyp) :: ctors -> + string + (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_id ctor_id) (sgen_id ctor_id)) + ^^ lbrace ^^ hardline + ^^ jump 0 2 (codegen_eq_test ctor_id ctyp) + ^^ hardline ^^ rbrace ^^ string " else " ^^ codegen_eq_tests ctors + in + let n = sgen_id id in + string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2) " n n n) + ^^ surround 2 0 lbrace (codegen_eq_tests tus) rbrace + in + string (Printf.sprintf "// union %s" (string_of_id id)) + ^^ hardline ^^ string "enum" ^^ space + ^^ string ("kind_" ^ sgen_id id) + ^^ space + ^^ separate space + [ + lbrace; + separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_id id)) (List.map fst tus); + rbrace ^^ semi; + ] + ^^ twice hardline ^^ string "struct" ^^ space ^^ codegen_id id ^^ space + ^^ surround 2 0 lbrace + (separate space [string "enum"; string ("kind_" ^ sgen_id id); string "kind" ^^ semi] + ^^ hardline ^^ string "union" ^^ space + ^^ surround 2 0 lbrace (separate_map (semi ^^ hardline) codegen_tu tus ^^ semi) rbrace + ^^ semi + ) + rbrace + ^^ semi ^^ twice hardline ^^ codegen_init ^^ twice hardline ^^ codegen_reinit ^^ twice hardline ^^ codegen_clear + ^^ twice hardline ^^ codegen_setter ^^ twice hardline ^^ codegen_eq ^^ twice hardline + ^^ separate_map (twice hardline) codegen_ctor tus + (* If this is the exception type, then we setup up some global variables to deal with exceptions. *) + ^^ + if string_of_id id = "exception" then + twice hardline + ^^ string "struct zexception *current_exception = NULL;" + ^^ hardline ^^ string "bool have_exception = false;" ^^ hardline + ^^ string "sail_string *throw_location = NULL;" + else empty (** GLOBAL: because C doesn't have real anonymous tuple types (anonymous structs don't quite work the way we need) every tuple @@ -1809,20 +1539,22 @@ let generated = ref IdSet.empty let codegen_tup ctyps = let id = mk_id ("tuple_" ^ string_of_ctyp (CT_tup ctyps)) in - if IdSet.mem id !generated then - empty - else - begin - let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) - (0, Bindings.empty) - ctyps - in - generated := IdSet.add id !generated; - codegen_type_def (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline - end + if IdSet.mem id !generated then empty + else begin + let _, fields = + List.fold_left + (fun (n, fields) ctyp -> (n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields)) + (0, Bindings.empty) ctyps + in + generated := IdSet.add id !generated; + codegen_type_def (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline + end let codegen_node id ctyp = - string (Printf.sprintf "struct node_%s {\n unsigned int rc;\n %s hd;\n struct node_%s *tl;\n};\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + string + (Printf.sprintf "struct node_%s {\n unsigned int rc;\n %s hd;\n struct node_%s *tl;\n};\n" (sgen_id id) + (sgen_ctyp ctyp) (sgen_id id) + ) ^^ string (Printf.sprintf "typedef struct node_%s *%s;" (sgen_id id) (sgen_id id)) let codegen_list_init id = @@ -1830,110 +1562,105 @@ let codegen_list_init id = let codegen_list_clear id ctyp = string (Printf.sprintf "static void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id)) - ^^ string " if (*rop == NULL) return;\n" - ^^ string " if ((*rop)->rc >= 1) {\n" - ^^ string " (*rop)->rc -= 1;\n" + ^^ string " if (*rop == NULL) return;\n" ^^ string " if ((*rop)->rc >= 1) {\n" ^^ string " (*rop)->rc -= 1;\n" ^^ string " }\n" ^^ string (Printf.sprintf " %s node = *rop;\n" (sgen_id id)) ^^ string " while (node != NULL && node->rc == 0) {\n" - ^^ (if is_stack_ctyp ctyp then empty - else string (Printf.sprintf " KILL(%s)(&node->hd);\n" (sgen_ctyp_name ctyp))) + ^^ (if is_stack_ctyp ctyp then empty else string (Printf.sprintf " KILL(%s)(&node->hd);\n" (sgen_ctyp_name ctyp))) ^^ string (Printf.sprintf " %s next = node->tl;\n" (sgen_id id)) - ^^ string " sail_free(node);\n" - ^^ string " node = next;\n" + ^^ string " sail_free(node);\n" ^^ string " node = next;\n" ^^ string (Printf.sprintf " internal_dec_%s(node);\n" (sgen_id id)) - ^^ string " }\n" - ^^ string "}" + ^^ string " }\n" ^^ string "}" let codegen_list_recreate id = - string (Printf.sprintf "static void RECREATE(%s)(%s *rop) { KILL(%s)(rop); *rop = NULL; }" (sgen_id id) (sgen_id id) (sgen_id id)) - + string + (Printf.sprintf "static void RECREATE(%s)(%s *rop) { KILL(%s)(rop); *rop = NULL; }" (sgen_id id) (sgen_id id) + (sgen_id id) + ) + let codegen_inc_reference_count id = string (Printf.sprintf "static void internal_inc_%s(%s l) {\n" (sgen_id id) (sgen_id id)) - ^^ string " if (l == NULL) return;\n" - ^^ string " l->rc += 1;\n" - ^^ string "}" + ^^ string " if (l == NULL) return;\n" ^^ string " l->rc += 1;\n" ^^ string "}" let codegen_dec_reference_count id = string (Printf.sprintf "static void internal_dec_%s(%s l) {\n" (sgen_id id) (sgen_id id)) - ^^ string " if (l == NULL) return;\n" - ^^ string " l->rc -= 1;\n" - ^^ string "}" + ^^ string " if (l == NULL) return;\n" ^^ string " l->rc -= 1;\n" ^^ string "}" let codegen_list_copy id = string (Printf.sprintf "static void COPY(%s)(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ string (Printf.sprintf " internal_inc_%s(op);\n" (sgen_id id)) ^^ string (Printf.sprintf " KILL(%s)(rop);\n" (sgen_id id)) - ^^ string " *rop = op;\n" - ^^ string "}" - + ^^ string " *rop = op;\n" ^^ string "}" + let codegen_cons id ctyp = let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in - string (Printf.sprintf "static void %s(%s *rop, %s x, %s xs) {\n" (sgen_function_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + string + (Printf.sprintf "static void %s(%s *rop, %s x, %s xs) {\n" (sgen_function_id cons_id) (sgen_id id) (sgen_ctyp ctyp) + (sgen_id id) + ) ^^ string " bool same = *rop == xs;\n" ^^ string (Printf.sprintf " *rop = sail_malloc(sizeof(struct node_%s));\n" (sgen_id id)) ^^ string " (*rop)->rc = 1;\n" - ^^ (if is_stack_ctyp ctyp then - string " (*rop)->hd = x;\n" - else - string (Printf.sprintf " CREATE(%s)(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) - ^^ string (Printf.sprintf " COPY(%s)(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp))) + ^^ ( if is_stack_ctyp ctyp then string " (*rop)->hd = x;\n" + else + string (Printf.sprintf " CREATE(%s)(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " COPY(%s)(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp)) + ) ^^ string (Printf.sprintf " if (!same) internal_inc_%s(xs);\n" (sgen_id id)) - ^^ string " (*rop)->tl = xs;\n" - ^^ string "}" + ^^ string " (*rop)->tl = xs;\n" ^^ string "}" let codegen_pick id ctyp = if is_stack_ctyp ctyp then - string (Printf.sprintf "static %s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id)) + string + (Printf.sprintf "static %s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) + (sgen_id id) + ) else - string (Printf.sprintf "static void pick_%s(%s *x, const %s xs) { COPY(%s)(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp)) + string + (Printf.sprintf "static void pick_%s(%s *x, const %s xs) { COPY(%s)(x, xs->hd); }" (sgen_ctyp_name ctyp) + (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp) + ) let codegen_list_equal id ctyp = let open Printf in ksprintf string "static bool EQUAL(%s)(const %s op1, const %s op2) {\n" (sgen_id id) (sgen_id id) (sgen_id id) ^^ ksprintf string " if (op1 == NULL && op2 == NULL) { return true; };\n" ^^ ksprintf string " if (op1 == NULL || op2 == NULL) { return false; };\n" - ^^ ksprintf string " return EQUAL(%s)(op1->hd, op2->hd) && EQUAL(%s)(op1->tl, op2->tl);\n" (sgen_ctyp_name ctyp) (sgen_id id) + ^^ ksprintf string " return EQUAL(%s)(op1->hd, op2->hd) && EQUAL(%s)(op1->tl, op2->tl);\n" (sgen_ctyp_name ctyp) + (sgen_id id) ^^ string "}" let codegen_list_undefined id ctyp = let open Printf in ksprintf string "static void UNDEFINED(%s)(%s *rop, %s u) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp) - ^^ ksprintf string " *rop = NULL;\n" - ^^ string "}" + ^^ ksprintf string " *rop = NULL;\n" ^^ string "}" let codegen_list ctyp = let id = mk_id (string_of_ctyp (CT_list ctyp)) in - if IdSet.mem id !generated then - empty - else - begin - generated := IdSet.add id !generated; - codegen_node id ctyp ^^ twice hardline - ^^ codegen_list_init id ^^ twice hardline - ^^ codegen_inc_reference_count id ^^ twice hardline - ^^ codegen_dec_reference_count id ^^ twice hardline - ^^ codegen_list_clear id ctyp ^^ twice hardline - ^^ codegen_list_recreate id ^^ twice hardline - ^^ codegen_list_copy id ^^ twice hardline - ^^ codegen_cons id ctyp ^^ twice hardline - ^^ codegen_pick id ctyp ^^ twice hardline - ^^ codegen_list_equal id ctyp ^^ twice hardline - ^^ codegen_list_undefined id ctyp ^^ twice hardline - end + if IdSet.mem id !generated then empty + else begin + generated := IdSet.add id !generated; + codegen_node id ctyp ^^ twice hardline ^^ codegen_list_init id ^^ twice hardline ^^ codegen_inc_reference_count id + ^^ twice hardline ^^ codegen_dec_reference_count id ^^ twice hardline ^^ codegen_list_clear id ctyp + ^^ twice hardline ^^ codegen_list_recreate id ^^ twice hardline ^^ codegen_list_copy id ^^ twice hardline + ^^ codegen_cons id ctyp ^^ twice hardline ^^ codegen_pick id ctyp ^^ twice hardline ^^ codegen_list_equal id ctyp + ^^ twice hardline ^^ codegen_list_undefined id ctyp ^^ twice hardline + end (* Generate functions for working with non-bit vectors of some specific type. *) let codegen_vector (direction, ctyp) = let id = mk_id (string_of_ctyp (CT_vector (direction, ctyp))) in - if IdSet.mem id !generated then - empty - else + if IdSet.mem id !generated then empty + else ( let vector_typedef = string (Printf.sprintf "struct %s {\n size_t len;\n %s *data;\n};\n" (sgen_id id) (sgen_ctyp ctyp)) ^^ string (Printf.sprintf "typedef struct %s %s;" (sgen_id id) (sgen_id id)) in let vector_init = - string (Printf.sprintf "static void CREATE(%s)(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id)) + string + (Printf.sprintf "static void CREATE(%s)(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) + (sgen_id id) + ) in let vector_set = string (Printf.sprintf "static void COPY(%s)(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) @@ -1941,224 +1668,234 @@ let codegen_vector (direction, ctyp) = ^^ string " rop->len = op.len;\n" ^^ string (Printf.sprintf " rop->data = sail_malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) ^^ string " for (int i = 0; i < op.len; i++) {\n" - ^^ string (if is_stack_ctyp ctyp then - " (rop->data)[i] = op.data[i];\n" - else - Printf.sprintf " CREATE(%s)((rop->data) + i);\n COPY(%s)((rop->data) + i, op.data[i]);\n" (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp)) - ^^ string " }\n" - ^^ string "}" + ^^ string + ( if is_stack_ctyp ctyp then " (rop->data)[i] = op.data[i];\n" + else + Printf.sprintf " CREATE(%s)((rop->data) + i);\n COPY(%s)((rop->data) + i, op.data[i]);\n" + (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp) + ) + ^^ string " }\n" ^^ string "}" in let vector_clear = string (Printf.sprintf "static void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id)) - ^^ (if is_stack_ctyp ctyp then empty - else - string " for (int i = 0; i < (rop->len); i++) {\n" - ^^ string (Printf.sprintf " KILL(%s)((rop->data) + i);\n" (sgen_ctyp_name ctyp)) - ^^ string " }\n") + ^^ ( if is_stack_ctyp ctyp then empty + else + string " for (int i = 0; i < (rop->len); i++) {\n" + ^^ string (Printf.sprintf " KILL(%s)((rop->data) + i);\n" (sgen_ctyp_name ctyp)) + ^^ string " }\n" + ) ^^ string " if (rop->data != NULL) sail_free(rop->data);\n" ^^ string "}" in let vector_reinit = - string (Printf.sprintf "static void RECREATE(%s)(%s *rop) { KILL(%s)(rop); CREATE(%s)(rop); }" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_id id)) + string + (Printf.sprintf "static void RECREATE(%s)(%s *rop) { KILL(%s)(rop); CREATE(%s)(rop); }" (sgen_id id) + (sgen_id id) (sgen_id id) (sgen_id id) + ) in let vector_update = - string (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, sail_int n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + string + (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, sail_int n, %s elem) {\n" (sgen_id id) + (sgen_id id) (sgen_id id) (sgen_ctyp ctyp) + ) ^^ string " int m = sail_int_get_ui(n);\n" ^^ string " if (rop->data == op.data) {\n" - ^^ string (if is_stack_ctyp ctyp then - " rop->data[m] = elem;\n" - else - Printf.sprintf " COPY(%s)((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp)) + ^^ string + ( if is_stack_ctyp ctyp then " rop->data[m] = elem;\n" + else Printf.sprintf " COPY(%s)((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp) + ) ^^ string " } else {\n" ^^ string (Printf.sprintf " COPY(%s)(rop, op);\n" (sgen_id id)) - ^^ string (if is_stack_ctyp ctyp then - " rop->data[m] = elem;\n" - else - Printf.sprintf " COPY(%s)((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp)) - ^^ string " }\n" - ^^ string "}" + ^^ string + ( if is_stack_ctyp ctyp then " rop->data[m] = elem;\n" + else Printf.sprintf " COPY(%s)((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp) + ) + ^^ string " }\n" ^^ string "}" in let internal_vector_update = - string (Printf.sprintf "static void internal_vector_update_%s(%s *rop, %s op, const int64_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) - ^^ string (if is_stack_ctyp ctyp then - " rop->data[n] = elem;\n" - else - Printf.sprintf " COPY(%s)((rop->data) + n, elem);\n" (sgen_ctyp_name ctyp)) + string + (Printf.sprintf "static void internal_vector_update_%s(%s *rop, %s op, const int64_t n, %s elem) {\n" + (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp) + ) + ^^ string + ( if is_stack_ctyp ctyp then " rop->data[n] = elem;\n" + else Printf.sprintf " COPY(%s)((rop->data) + n, elem);\n" (sgen_ctyp_name ctyp) + ) ^^ string "}" in let vector_access = if is_stack_ctyp ctyp then - string (Printf.sprintf "static %s vector_access_%s(%s op, sail_int n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) + string + (Printf.sprintf "static %s vector_access_%s(%s op, sail_int n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) ^^ string " int m = sail_int_get_ui(n);\n" - ^^ string " return op.data[m];\n" - ^^ string "}" + ^^ string " return op.data[m];\n" ^^ string "}" else - string (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, sail_int n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + string + (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, sail_int n) {\n" (sgen_id id) (sgen_ctyp ctyp) + (sgen_id id) + ) ^^ string " int m = sail_int_get_ui(n);\n" ^^ string (Printf.sprintf " COPY(%s)(rop, op.data[m]);\n" (sgen_ctyp_name ctyp)) ^^ string "}" in let internal_vector_init = - string (Printf.sprintf "static void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id)) + string + (Printf.sprintf "static void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id)) ^^ string " rop->len = len;\n" ^^ string (Printf.sprintf " rop->data = sail_malloc(len * sizeof(%s));\n" (sgen_ctyp ctyp)) - ^^ (if not (is_stack_ctyp ctyp) then - string " for (int i = 0; i < len; i++) {\n" - ^^ string (Printf.sprintf " CREATE(%s)((rop->data) + i);\n" (sgen_ctyp_name ctyp)) - ^^ string " }\n" - else empty) + ^^ ( if not (is_stack_ctyp ctyp) then + string " for (int i = 0; i < len; i++) {\n" + ^^ string (Printf.sprintf " CREATE(%s)((rop->data) + i);\n" (sgen_ctyp_name ctyp)) + ^^ string " }\n" + else empty + ) ^^ string "}" in let vector_undefined = - string (Printf.sprintf "static void undefined_vector_%s(%s *rop, sail_int len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + string + (Printf.sprintf "static void undefined_vector_%s(%s *rop, sail_int len, %s elem) {\n" (sgen_id id) (sgen_id id) + (sgen_ctyp ctyp) + ) ^^ string (Printf.sprintf " rop->len = sail_int_get_ui(len);\n") ^^ string (Printf.sprintf " rop->data = sail_malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) ^^ string " for (int i = 0; i < (rop->len); i++) {\n" - ^^ string (if is_stack_ctyp ctyp then - " (rop->data)[i] = elem;\n" - else - Printf.sprintf " CREATE(%s)((rop->data) + i);\n COPY(%s)((rop->data) + i, elem);\n" (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp)) - ^^ string " }\n" - ^^ string "}" + ^^ string + ( if is_stack_ctyp ctyp then " (rop->data)[i] = elem;\n" + else + Printf.sprintf " CREATE(%s)((rop->data) + i);\n COPY(%s)((rop->data) + i, elem);\n" + (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp) + ) + ^^ string " }\n" ^^ string "}" in let vector_equal = let open Printf in - ksprintf string "static bool EQUAL(%s)(const %s op1, const %s op2) {\n" (sgen_id id) (sgen_id id) (sgen_id id) - ^^ string " if (op1.len != op2.len) return false;\n" - ^^ string " bool result = true;" - ^^ string " for (int i = 0; i < op1.len; i++) {\n" + ksprintf string "static bool EQUAL(%s)(const %s op1, const %s op2) {\n" (sgen_id id) (sgen_id id) (sgen_id id) + ^^ string " if (op1.len != op2.len) return false;\n" + ^^ string " bool result = true;" + ^^ string " for (int i = 0; i < op1.len; i++) {\n" ^^ ksprintf string " result &= EQUAL(%s)(op1.data[i], op2.data[i]);" (sgen_ctyp_name ctyp) - ^^ string " }\n" - ^^ ksprintf string " return result;\n" - ^^ string "}" + ^^ string " }\n" ^^ ksprintf string " return result;\n" ^^ string "}" in let vector_length = let open Printf in - ksprintf string "static void length_%s(sail_int *rop, %s op) {\n" (sgen_id id) (sgen_id id) - ^^ ksprintf string " mpz_set_ui(*rop, (unsigned long int)(op.len));\n" - ^^ string "}" + ksprintf string "static void length_%s(sail_int *rop, %s op) {\n" (sgen_id id) (sgen_id id) + ^^ ksprintf string " mpz_set_ui(*rop, (unsigned long int)(op.len));\n" + ^^ string "}" in begin generated := IdSet.add id !generated; - vector_typedef ^^ twice hardline - ^^ vector_init ^^ twice hardline - ^^ vector_clear ^^ twice hardline - ^^ vector_reinit ^^ twice hardline - ^^ vector_undefined ^^ twice hardline - ^^ vector_access ^^ twice hardline - ^^ vector_set ^^ twice hardline - ^^ vector_update ^^ twice hardline - ^^ vector_equal ^^ twice hardline - ^^ vector_length ^^ twice hardline - ^^ internal_vector_update ^^ twice hardline - ^^ internal_vector_init ^^ twice hardline + vector_typedef ^^ twice hardline ^^ vector_init ^^ twice hardline ^^ vector_clear ^^ twice hardline + ^^ vector_reinit ^^ twice hardline ^^ vector_undefined ^^ twice hardline ^^ vector_access ^^ twice hardline + ^^ vector_set ^^ twice hardline ^^ vector_update ^^ twice hardline ^^ vector_equal ^^ twice hardline + ^^ vector_length ^^ twice hardline ^^ internal_vector_update ^^ twice hardline ^^ internal_vector_init + ^^ twice hardline end + ) -let is_decl = function - | I_aux (I_decl _, _) -> true - | _ -> false +let is_decl = function I_aux (I_decl _, _) -> true | _ -> false let codegen_decl = function - | I_aux (I_decl (ctyp, id), _) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_name id)) + | I_aux (I_decl (ctyp, id), _) -> string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_name id)) | _ -> assert false let codegen_alloc = function | I_aux (I_decl (ctyp, _), _) when is_stack_ctyp ctyp -> empty - | I_aux (I_decl (ctyp, id), _) -> - string (Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) + | I_aux (I_decl (ctyp, id), _) -> string (Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) | _ -> assert false let codegen_def' ctx = function | CDEF_register (id, ctyp, _) -> - string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline - ^^ string (Printf.sprintf "%s%s %s;" (static ()) (sgen_ctyp ctyp) (sgen_id id)) - + string (Printf.sprintf "// register %s" (string_of_id id)) + ^^ hardline + ^^ string (Printf.sprintf "%s%s %s;" (static ()) (sgen_ctyp ctyp) (sgen_id id)) | CDEF_val (id, _, arg_ctyps, ret_ctyp) -> - if ctx_is_extern id ctx then - empty - else if is_stack_ctyp ret_ctyp then - string (Printf.sprintf "%s%s %s(%s%s);" (static ()) (sgen_ctyp ret_ctyp) (sgen_function_id id) (extra_params ()) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) - else - string (Printf.sprintf "%svoid %s(%s%s *rop, %s);" (static ()) (sgen_function_id id) (extra_params ()) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) - + if ctx_is_extern id ctx then empty + else if is_stack_ctyp ret_ctyp then + string + (Printf.sprintf "%s%s %s(%s%s);" (static ()) (sgen_ctyp ret_ctyp) (sgen_function_id id) (extra_params ()) + (Util.string_of_list ", " sgen_ctyp arg_ctyps) + ) + else + string + (Printf.sprintf "%svoid %s(%s%s *rop, %s);" (static ()) (sgen_function_id id) (extra_params ()) + (sgen_ctyp ret_ctyp) + (Util.string_of_list ", " sgen_ctyp arg_ctyps) + ) | CDEF_fundef (id, ret_arg, args, instrs) -> - let _, arg_ctyps, ret_ctyp = match Bindings.find_opt id ctx.valspecs with - | Some vs -> vs - | None -> - c_error ~loc:(id_loc id) ("No valspec found for " ^ string_of_id id) - in - - (* Check that the function has the correct arity at this point. *) - if List.length arg_ctyps <> List.length args then - c_error ~loc:(id_loc id) ("function arguments " - ^ Util.string_of_list ", " string_of_id args - ^ " matched against type " - ^ Util.string_of_list ", " string_of_ctyp arg_ctyps) - else (); - - let instrs = add_local_labels instrs in - let args = Util.string_of_list ", " (fun x -> x) (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) in - let function_header = - match ret_arg with - | None -> - assert (is_stack_ctyp ret_ctyp); - (if !opt_static then string "static " else empty) - ^^ string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_function_id id ^^ parens (string (extra_params ()) ^^ string args) ^^ hardline - | Some gs -> - assert (not (is_stack_ctyp ret_ctyp)); - (if !opt_static then string "static " else empty) - ^^ string "void" ^^ space ^^ codegen_function_id id - ^^ parens (string (extra_params ()) ^^ string (sgen_ctyp ret_ctyp ^ " *" ^ sgen_id gs ^ ", ") ^^ string args) - ^^ hardline - in - function_header - ^^ string "{" - ^^ jump 0 2 (separate_map hardline (codegen_instr id ctx) instrs) ^^ hardline - ^^ string "}" - - | CDEF_type ctype_def -> - codegen_type_def ctype_def + let _, arg_ctyps, ret_ctyp = + match Bindings.find_opt id ctx.valspecs with + | Some vs -> vs + | None -> c_error ~loc:(id_loc id) ("No valspec found for " ^ string_of_id id) + in + (* Check that the function has the correct arity at this point. *) + if List.length arg_ctyps <> List.length args then + c_error ~loc:(id_loc id) + ("function arguments " + ^ Util.string_of_list ", " string_of_id args + ^ " matched against type " + ^ Util.string_of_list ", " string_of_ctyp arg_ctyps + ) + else (); + + let instrs = add_local_labels instrs in + let args = + Util.string_of_list ", " + (fun x -> x) + (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) + in + let function_header = + match ret_arg with + | None -> + assert (is_stack_ctyp ret_ctyp); + (if !opt_static then string "static " else empty) + ^^ string (sgen_ctyp ret_ctyp) + ^^ space ^^ codegen_function_id id + ^^ parens (string (extra_params ()) ^^ string args) + ^^ hardline + | Some gs -> + assert (not (is_stack_ctyp ret_ctyp)); + (if !opt_static then string "static " else empty) + ^^ string "void" ^^ space ^^ codegen_function_id id + ^^ parens (string (extra_params ()) ^^ string (sgen_ctyp ret_ctyp ^ " *" ^ sgen_id gs ^ ", ") ^^ string args) + ^^ hardline + in + function_header ^^ string "{" + ^^ jump 0 2 (separate_map hardline (codegen_instr id ctx) instrs) + ^^ hardline ^^ string "}" + | CDEF_type ctype_def -> codegen_type_def ctype_def | CDEF_startup (id, instrs) -> - let startup_header = string (Printf.sprintf "%svoid startup_%s(void)" (static ()) (sgen_function_id id)) in - separate_map hardline codegen_decl instrs - ^^ twice hardline - ^^ startup_header ^^ hardline - ^^ string "{" - ^^ jump 0 2 (separate_map hardline codegen_alloc instrs) ^^ hardline - ^^ string "}" - + let startup_header = string (Printf.sprintf "%svoid startup_%s(void)" (static ()) (sgen_function_id id)) in + separate_map hardline codegen_decl instrs + ^^ twice hardline ^^ startup_header ^^ hardline ^^ string "{" + ^^ jump 0 2 (separate_map hardline codegen_alloc instrs) + ^^ hardline ^^ string "}" | CDEF_finish (id, instrs) -> - let finish_header = string (Printf.sprintf "%svoid finish_%s(void)" (static ()) (sgen_function_id id)) in - separate_map hardline codegen_decl (List.filter is_decl instrs) - ^^ twice hardline - ^^ finish_header ^^ hardline - ^^ string "{" - ^^ jump 0 2 (separate_map hardline (codegen_instr id ctx) instrs) ^^ hardline - ^^ string "}" - + let finish_header = string (Printf.sprintf "%svoid finish_%s(void)" (static ()) (sgen_function_id id)) in + separate_map hardline codegen_decl (List.filter is_decl instrs) + ^^ twice hardline ^^ finish_header ^^ hardline ^^ string "{" + ^^ jump 0 2 (separate_map hardline (codegen_instr id ctx) instrs) + ^^ hardline ^^ string "}" | CDEF_let (number, bindings, instrs) -> - let instrs = add_local_labels instrs in - let setup = - List.concat (List.map (fun (id, ctyp) -> [idecl (id_loc id) ctyp (name id)]) bindings) - in - let cleanup = - List.concat (List.map (fun (id, ctyp) -> [iclear ~loc:(id_loc id) ctyp (name id)]) bindings) - in - separate_map hardline (fun (id, ctyp) -> string (Printf.sprintf "%s%s %s;" (static ()) (sgen_ctyp ctyp) (sgen_id id))) bindings - ^^ hardline ^^ string (Printf.sprintf "static void create_letbind_%d(void) " number) - ^^ string "{" - ^^ jump 0 2 (separate_map hardline codegen_alloc setup) ^^ hardline - ^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") { ctx with no_raw = true }) instrs) ^^ hardline - ^^ string "}" - ^^ hardline ^^ string (Printf.sprintf "static void kill_letbind_%d(void) " number) - ^^ string "{" - ^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") ctx) cleanup) ^^ hardline - ^^ string "}" - + let instrs = add_local_labels instrs in + let setup = List.concat (List.map (fun (id, ctyp) -> [idecl (id_loc id) ctyp (name id)]) bindings) in + let cleanup = List.concat (List.map (fun (id, ctyp) -> [iclear ~loc:(id_loc id) ctyp (name id)]) bindings) in + separate_map hardline + (fun (id, ctyp) -> string (Printf.sprintf "%s%s %s;" (static ()) (sgen_ctyp ctyp) (sgen_id id))) + bindings + ^^ hardline + ^^ string (Printf.sprintf "static void create_letbind_%d(void) " number) + ^^ string "{" + ^^ jump 0 2 (separate_map hardline codegen_alloc setup) + ^^ hardline + ^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") { ctx with no_raw = true }) instrs) + ^^ hardline ^^ string "}" ^^ hardline + ^^ string (Printf.sprintf "static void kill_letbind_%d(void) " number) + ^^ string "{" + ^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") ctx) cleanup) + ^^ hardline ^^ string "}" | CDEF_pragma _ -> empty - + (** As we generate C we need to generate specialized version of tuple, list, and vector type. These must be generated in the correct order. The ctyp_dependencies function generates a list of @@ -2166,20 +1903,19 @@ let codegen_def' ctx = function repeated in ctyp_dependencies so it's up to the code-generator not to repeat definitions pointlessly (using the !generated variable) *) -type c_gen_typ = - | CTG_tup of ctyp list - | CTG_list of ctyp - | CTG_vector of bool * ctyp +type c_gen_typ = CTG_tup of ctyp list | CTG_list of ctyp | CTG_vector of bool * ctyp let rec ctyp_dependencies = function | CT_tup ctyps -> List.concat (List.map ctyp_dependencies ctyps) @ [CTG_tup ctyps] | CT_list ctyp -> ctyp_dependencies ctyp @ [CTG_list ctyp] - | CT_vector (direction, ctyp) | CT_fvector (_, direction, ctyp) -> ctyp_dependencies ctyp @ [CTG_vector (direction, ctyp)] + | CT_vector (direction, ctyp) | CT_fvector (_, direction, ctyp) -> + ctyp_dependencies ctyp @ [CTG_vector (direction, ctyp)] | CT_ref ctyp -> ctyp_dependencies ctyp | CT_struct (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors) | CT_variant (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors) - | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool - | CT_real | CT_bit | CT_string | CT_enum _ | CT_poly _ | CT_constant _ | CT_float _ | CT_rounding_mode -> [] + | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string + | CT_enum _ | CT_poly _ | CT_constant _ | CT_float _ | CT_rounding_mode -> + [] let codegen_ctg = function | CTG_vector (direction, ctyp) -> codegen_vector (direction, ctyp) @@ -2191,51 +1927,47 @@ let codegen_ctg = function let codegen_def ctx def = let ctyps = cdef_ctyps def |> CTSet.elements in (* We should have erased any polymorphism introduced by variants at this point! *) - if List.exists is_polymorphic ctyps then + if List.exists is_polymorphic ctyps then ( let polymorphic_ctyps = List.filter is_polymorphic ctyps in - c_error (Printf.sprintf "Found polymorphic types:\n%s\nwhile generating definition." - (Util.string_of_list "\n" string_of_ctyp polymorphic_ctyps)) - else + c_error + (Printf.sprintf "Found polymorphic types:\n%s\nwhile generating definition." + (Util.string_of_list "\n" string_of_ctyp polymorphic_ctyps) + ) + ) + else ( let deps = List.concat (List.map ctyp_dependencies ctyps) in - separate_map hardline codegen_ctg deps - ^^ codegen_def' ctx def + separate_map hardline codegen_ctg deps ^^ codegen_def' ctx def + ) -let is_cdef_startup = function - | CDEF_startup _ -> true - | _ -> false +let is_cdef_startup = function CDEF_startup _ -> true | _ -> false let sgen_startup = function - | CDEF_startup (id, _) -> - Printf.sprintf " startup_%s();" (sgen_function_id id) + | CDEF_startup (id, _) -> Printf.sprintf " startup_%s();" (sgen_function_id id) | _ -> assert false -let sgen_instr id ctx instr = - Pretty_print_sail.to_string (codegen_instr id ctx instr) +let sgen_instr id ctx instr = Pretty_print_sail.to_string (codegen_instr id ctx instr) -let is_cdef_finish = function - | CDEF_startup _ -> true - | _ -> false +let is_cdef_finish = function CDEF_startup _ -> true | _ -> false let sgen_finish = function - | CDEF_startup (id, _) -> - Printf.sprintf " finish_%s();" (sgen_function_id id) + | CDEF_startup (id, _) -> Printf.sprintf " finish_%s();" (sgen_function_id id) | _ -> assert false let get_recursive_functions cdefs = let graph = Jib_compile.callgraph cdefs in let rf = IdGraph.self_loops graph in (* Use strongly-connected components for mutually recursive functions *) - List.fold_left (fun rf component -> - match component with [_] -> rf | mutual -> mutual @ rf - ) rf (IdGraph.scc graph) + List.fold_left (fun rf component -> match component with [_] -> rf | mutual -> mutual @ rf) rf (IdGraph.scc graph) |> IdSet.of_list let jib_of_ast env effect_info ast = - let module Jibc = Make(C_config(struct let branch_coverage = !opt_branch_coverage end)) in + let module Jibc = Make (C_config (struct + let branch_coverage = !opt_branch_coverage + end)) in let env, effect_info = add_special_functions env effect_info in let ctx = initial_ctx env effect_info in Jibc.compile_ast ctx ast - + let compile_ast env effect_info output_chan c_includes ast = try let cdefs, ctx = jib_of_ast env effect_info ast in @@ -2247,135 +1979,134 @@ let compile_ast env effect_info output_chan c_includes ast = let docs = separate_map (hardline ^^ hardline) (codegen_def ctx) cdefs in - let preamble = separate hardline - ((if !opt_no_lib then [] else [string "#include \"sail.h\""]) - @ (if !opt_no_rts then [] else - [ string "#include \"rts.h\""; - string "#include \"elf.h\"" ]) - @ (if Option.is_some !opt_branch_coverage then [string "#include \"sail_coverage.h\""] else []) - @ (List.map (fun h -> string (Printf.sprintf "#include \"%s\"" h)) c_includes)) + let preamble = + separate hardline + ((if !opt_no_lib then [] else [string "#include \"sail.h\""]) + @ (if !opt_no_rts then [] else [string "#include \"rts.h\""; string "#include \"elf.h\""]) + @ (if Option.is_some !opt_branch_coverage then [string "#include \"sail_coverage.h\""] else []) + @ List.map (fun h -> string (Printf.sprintf "#include \"%s\"" h)) c_includes + ) in let exn_boilerplate = - if not (Bindings.mem (mk_id "exception") ctx.variants) then ([], []) else - ([ " current_exception = sail_malloc(sizeof(struct zexception));"; - " CREATE(zexception)(current_exception);"; - " throw_location = sail_malloc(sizeof(sail_string));"; - " CREATE(sail_string)(throw_location);" ], - [ " if (have_exception) {fprintf(stderr, \"Exiting due to uncaught exception: %s\\n\", *throw_location);}"; - " KILL(zexception)(current_exception);"; - " sail_free(current_exception);"; - " KILL(sail_string)(throw_location);"; - " sail_free(throw_location);"; - " if (have_exception) {exit(EXIT_FAILURE);}" ]) + if not (Bindings.mem (mk_id "exception") ctx.variants) then ([], []) + else + ( [ + " current_exception = sail_malloc(sizeof(struct zexception));"; + " CREATE(zexception)(current_exception);"; + " throw_location = sail_malloc(sizeof(sail_string));"; + " CREATE(sail_string)(throw_location);"; + ], + [ + " if (have_exception) {fprintf(stderr, \"Exiting due to uncaught exception: %s\\n\", *throw_location);}"; + " KILL(zexception)(current_exception);"; + " sail_free(current_exception);"; + " KILL(sail_string)(throw_location);"; + " sail_free(throw_location);"; + " if (have_exception) {exit(EXIT_FAILURE);}"; + ] + ) in - let letbind_initializers = - List.map (fun n -> Printf.sprintf " create_letbind_%d();" n) (List.rev ctx.letbinds) - in - let letbind_finalizers = - List.map (fun n -> Printf.sprintf " kill_letbind_%d();" n) ctx.letbinds - in - let startup cdefs = - List.map sgen_startup (List.filter is_cdef_startup cdefs) - in - let finish cdefs = - List.map sgen_finish (List.filter is_cdef_finish cdefs) - in + let letbind_initializers = List.map (fun n -> Printf.sprintf " create_letbind_%d();" n) (List.rev ctx.letbinds) in + let letbind_finalizers = List.map (fun n -> Printf.sprintf " kill_letbind_%d();" n) ctx.letbinds in + let startup cdefs = List.map sgen_startup (List.filter is_cdef_startup cdefs) in + let finish cdefs = List.map sgen_finish (List.filter is_cdef_finish cdefs) in let regs = c_ast_registers cdefs in let register_init_clear (id, ctyp, instrs) = - if is_stack_ctyp ctyp then - List.map (sgen_instr (mk_id "reg") ctx) instrs, [] + if is_stack_ctyp ctyp then (List.map (sgen_instr (mk_id "reg") ctx) instrs, []) else - [ Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id) ] - @ List.map (sgen_instr (mk_id "reg") ctx) instrs, - [ Printf.sprintf " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id) ] + ( [Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)] + @ List.map (sgen_instr (mk_id "reg") ctx) instrs, + [Printf.sprintf " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)] + ) in let init_config_id = mk_id "__InitConfig" in - let model_init = separate hardline (List.map string - ( [ Printf.sprintf "%svoid model_init(void)" (static ()); - "{"; - " setup_rts();" ] - @ fst exn_boilerplate - @ startup cdefs - @ letbind_initializers - @ List.concat (List.map (fun r -> fst (register_init_clear r)) regs) - @ (if regs = [] then [] else [ Printf.sprintf " %s(UNIT);" (sgen_function_id (mk_id "initialize_registers")) ]) - @ (if ctx_has_val_spec init_config_id ctx then [ Printf.sprintf " %s(UNIT);" (sgen_function_id init_config_id) ] else []) - @ [ "}" ] )) + let model_init = + separate hardline + (List.map string + ([Printf.sprintf "%svoid model_init(void)" (static ()); "{"; " setup_rts();"] + @ fst exn_boilerplate @ startup cdefs @ letbind_initializers + @ List.concat (List.map (fun r -> fst (register_init_clear r)) regs) + @ (if regs = [] then [] else [Printf.sprintf " %s(UNIT);" (sgen_function_id (mk_id "initialize_registers"))]) + @ ( if ctx_has_val_spec init_config_id ctx then + [Printf.sprintf " %s(UNIT);" (sgen_function_id init_config_id)] + else [] + ) + @ ["}"] + ) + ) in - let model_fini = separate hardline (List.map string - ( [ Printf.sprintf "%svoid model_fini(void)" (static ()); - "{" ] - @ letbind_finalizers - @ List.concat (List.map (fun r -> snd (register_init_clear r)) regs) - @ finish cdefs - @ [ " cleanup_rts();" ] - @ snd exn_boilerplate - @ [ "}" ] )) + let model_fini = + separate hardline + (List.map string + ([Printf.sprintf "%svoid model_fini(void)" (static ()); "{"] + @ letbind_finalizers + @ List.concat (List.map (fun r -> snd (register_init_clear r)) regs) + @ finish cdefs @ [" cleanup_rts();"] @ snd exn_boilerplate @ ["}"] + ) + ) in let model_pre_exit = - [ "void model_pre_exit()"; - "{" ] - @ (if Option.is_some !opt_branch_coverage then - [ " if (sail_coverage_exit() != 0) {"; - " fprintf(stderr, \"Could not write coverage information\\n\");"; - " exit(EXIT_FAILURE);"; - " }"; - "}" ] - else - ["}"] - ) - |> List.map string - |> separate hardline + (["void model_pre_exit()"; "{"] + @ + if Option.is_some !opt_branch_coverage then + [ + " if (sail_coverage_exit() != 0) {"; + " fprintf(stderr, \"Could not write coverage information\\n\");"; + " exit(EXIT_FAILURE);"; + " }"; + "}"; + ] + else ["}"] + ) + |> List.map string |> separate hardline in let model_default_main = - ([ Printf.sprintf "%sint model_main(int argc, char *argv[])" (static ()); - "{"; - " model_init();"; - " if (process_arguments(argc, argv)) exit(EXIT_FAILURE);"; - Printf.sprintf " %s(UNIT);" (sgen_function_id (mk_id "main")); - " model_fini();"; - " model_pre_exit();"; - " return EXIT_SUCCESS;"; - "}" ]) - |> List.map string - |> separate hardline + [ + Printf.sprintf "%sint model_main(int argc, char *argv[])" (static ()); + "{"; + " model_init();"; + " if (process_arguments(argc, argv)) exit(EXIT_FAILURE);"; + Printf.sprintf " %s(UNIT);" (sgen_function_id (mk_id "main")); + " model_fini();"; + " model_pre_exit();"; + " return EXIT_SUCCESS;"; + "}"; + ] + |> List.map string |> separate hardline in - let model_main = separate hardline (if (!opt_no_main) then [] else List.map string - [ "int main(int argc, char *argv[])"; - "{"; - " return model_main(argc, argv);"; - "}" ] ) + let model_main = + separate hardline + ( if !opt_no_main then [] + else List.map string ["int main(int argc, char *argv[])"; "{"; " return model_main(argc, argv);"; "}"] + ) in let hlhl = hardline ^^ hardline in - Pretty_print_sail.to_string (preamble ^^ hlhl ^^ docs ^^ hlhl - ^^ (if not !opt_no_rts then - model_init ^^ hlhl - ^^ model_fini ^^ hlhl - ^^ model_pre_exit ^^ hlhl - ^^ model_default_main ^^ hlhl - else - empty) - ^^ model_main ^^ hardline) + Pretty_print_sail.to_string + (preamble ^^ hlhl ^^ docs ^^ hlhl + ^^ ( if not !opt_no_rts then + model_init ^^ hlhl ^^ model_fini ^^ hlhl ^^ model_pre_exit ^^ hlhl ^^ model_default_main ^^ hlhl + else empty + ) + ^^ model_main ^^ hardline + ) |> output_string output_chan - with - | Type_error (_, l, err) -> - c_error ~loc:l ("Unexpected type error when compiling to C:\n" ^ Type_error.string_of_type_error err) + with Type_error (_, l, err) -> + c_error ~loc:l ("Unexpected type error when compiling to C:\n" ^ Type_error.string_of_type_error err) let compile_ast_clib env effect_info ast codegen = let cdefs, ctx = jib_of_ast env effect_info ast in (* let cdefs', _ = Jib_optimize.remove_tuples cdefs ctx in *) let cdefs = insert_heap_returns Bindings.empty cdefs in codegen ctx cdefs - diff --git a/src/sail_c_backend/c_backend.mli b/src/sail_c_backend/c_backend.mli index 071073b22..863a66a88 100644 --- a/src/sail_c_backend/c_backend.mli +++ b/src/sail_c_backend/c_backend.mli @@ -106,10 +106,11 @@ val opt_prefix : string ref processor state, and each function will be passed the env argument when it is called. *) val opt_extra_params : string option ref + val opt_extra_arguments : string option ref val opt_branch_coverage : out_channel option ref - + (** Optimization flags *) val optimize_primops : bool ref diff --git a/src/sail_c_backend/dune b/src/sail_c_backend/dune index 376b8430d..3b6b2657a 100644 --- a/src/sail_c_backend/dune +++ b/src/sail_c_backend/dune @@ -1,10 +1,12 @@ - (executable - (name sail_plugin_c) - (modes (native plugin)) - (libraries libsail)) + (name sail_plugin_c) + (modes + (native plugin)) + (libraries libsail)) (install - (section (site (libsail plugins))) - (package sail_c_backend) - (files sail_plugin_c.cmxs)) + (section + (site + (libsail plugins))) + (package sail_c_backend) + (files sail_plugin_c.cmxs)) diff --git a/src/sail_c_backend/sail_plugin_c.ml b/src/sail_c_backend/sail_plugin_c.ml index 16d82601d..32c416870 100644 --- a/src/sail_c_backend/sail_plugin_c.ml +++ b/src/sail_c_backend/sail_plugin_c.ml @@ -70,57 +70,61 @@ open Libsail let opt_includes_c : string list ref = ref [] let opt_specialize_c = ref false -let c_options = [ - ( "-c_include", - Arg.String (fun i -> opt_includes_c := i::!opt_includes_c), - " provide additional include for C output"); - ( "-c_no_main", - Arg.Set C_backend.opt_no_main, - " do not generate the main() function" ); - ( "-c_no_rts", - Arg.Set C_backend.opt_no_rts, - " do not include the Sail runtime" ); - ( "-c_no_lib", - Arg.Tuple [Arg.Set C_backend.opt_no_lib; Arg.Set C_backend.opt_no_rts], - " do not include the Sail runtime or library" ); - ( "-c_prefix", - Arg.String (fun prefix -> C_backend.opt_prefix := prefix), - " prefix generated C functions" ); - ( "-c_extra_params", - Arg.String (fun params -> C_backend.opt_extra_params := Some params), - " generate C functions with additional parameters" ); - ( "-c_extra_args", - Arg.String (fun args -> C_backend.opt_extra_arguments := Some args), - " supply extra argument to every generated C function call" ); - ( "-c_specialize", - Arg.Set opt_specialize_c, - " specialize integer arguments in C output"); - ( "-c_preserve", - Arg.String (fun str -> Specialize.add_initial_calls (Ast_util.IdSet.singleton (Ast_util.mk_id str))), - " make sure the provided function identifier is preserved in C output"); - ( "-c_fold_unit", - Arg.String (fun str -> Constant_fold.opt_fold_to_unit := Util.split_on_char ',' str), - " remove comma separated list of functions from C output, replacing them with unit"); - ( "-c_coverage", - Arg.String (fun str -> C_backend.opt_branch_coverage := Some (open_out str)), - " Turn on coverage tracking and output information about all branches and functions to a file"); - ( "-O", - Arg.Tuple [Arg.Set C_backend.optimize_primops; - Arg.Set C_backend.optimize_hoist_allocations; - Arg.Set Initial_check.opt_fast_undefined; - Arg.Set C_backend.optimize_struct_updates; - Arg.Set C_backend.optimize_alias], - " turn on optimizations for C compilation"); - ( "-Ofixed_int", - Arg.Set C_backend.optimize_fixed_int, - " assume fixed size integers rather than GMP arbitrary precision integers"); - ( "-Ofixed_bits", - Arg.Set C_backend.optimize_fixed_bits, - " assume fixed size bitvectors rather than arbitrary precision bitvectors"); - ( "-static", - Arg.Set C_backend.opt_static, - " make generated C functions static"); -] +let c_options = + [ + ( "-c_include", + Arg.String (fun i -> opt_includes_c := i :: !opt_includes_c), + " provide additional include for C output" + ); + ("-c_no_main", Arg.Set C_backend.opt_no_main, " do not generate the main() function"); + ("-c_no_rts", Arg.Set C_backend.opt_no_rts, " do not include the Sail runtime"); + ( "-c_no_lib", + Arg.Tuple [Arg.Set C_backend.opt_no_lib; Arg.Set C_backend.opt_no_rts], + " do not include the Sail runtime or library" + ); + ("-c_prefix", Arg.String (fun prefix -> C_backend.opt_prefix := prefix), " prefix generated C functions"); + ( "-c_extra_params", + Arg.String (fun params -> C_backend.opt_extra_params := Some params), + " generate C functions with additional parameters" + ); + ( "-c_extra_args", + Arg.String (fun args -> C_backend.opt_extra_arguments := Some args), + " supply extra argument to every generated C function call" + ); + ("-c_specialize", Arg.Set opt_specialize_c, " specialize integer arguments in C output"); + ( "-c_preserve", + Arg.String (fun str -> Specialize.add_initial_calls (Ast_util.IdSet.singleton (Ast_util.mk_id str))), + " make sure the provided function identifier is preserved in C output" + ); + ( "-c_fold_unit", + Arg.String (fun str -> Constant_fold.opt_fold_to_unit := Util.split_on_char ',' str), + " remove comma separated list of functions from C output, replacing them with unit" + ); + ( "-c_coverage", + Arg.String (fun str -> C_backend.opt_branch_coverage := Some (open_out str)), + " Turn on coverage tracking and output information about all branches and functions to a file" + ); + ( "-O", + Arg.Tuple + [ + Arg.Set C_backend.optimize_primops; + Arg.Set C_backend.optimize_hoist_allocations; + Arg.Set Initial_check.opt_fast_undefined; + Arg.Set C_backend.optimize_struct_updates; + Arg.Set C_backend.optimize_alias; + ], + " turn on optimizations for C compilation" + ); + ( "-Ofixed_int", + Arg.Set C_backend.optimize_fixed_int, + " assume fixed size integers rather than GMP arbitrary precision integers" + ); + ( "-Ofixed_bits", + Arg.Set C_backend.optimize_fixed_bits, + " assume fixed size bitvectors rather than arbitrary precision bitvectors" + ); + ("-static", Arg.Set C_backend.opt_static, " make generated C functions static"); + ] let c_rewrites = let open Rewrites in @@ -150,23 +154,18 @@ let c_rewrites = ("exp_lift_assign", []); ("merge_function_clauses", []); ("optimize_recheck_defs", []); - ("constant_fold", [String_arg "c"]) + ("constant_fold", [String_arg "c"]); ] let c_target _ out_file ast effect_info _ = let ast, env = Type_error.check Type_check.initial_env (Type_check.strip_ast ast) in - let close, output_chan = match out_file with Some f -> true, open_out (f ^ ".c") | None -> false, stdout in + let close, output_chan = match out_file with Some f -> (true, open_out (f ^ ".c")) | None -> (false, stdout) in Reporting.opt_warnings := true; - C_backend.compile_ast env effect_info output_chan (!opt_includes_c) ast; + C_backend.compile_ast env effect_info output_chan !opt_includes_c ast; flush output_chan; - if close then ( - close_out output_chan - ) + if close then close_out output_chan let _ = - Target.register - ~name:"c" - ~options:c_options + Target.register ~name:"c" ~options:c_options ~pre_parse_hook:(fun () -> Initial_check.opt_undefined_gen := true) - ~rewrites:c_rewrites - c_target + ~rewrites:c_rewrites c_target diff --git a/src/sail_coq_backend/dune b/src/sail_coq_backend/dune index 172c59c03..4524ede4a 100644 --- a/src/sail_coq_backend/dune +++ b/src/sail_coq_backend/dune @@ -1,15 +1,20 @@ (env - (dev - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) - (release - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) + (dev + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) + (release + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) (executable - (name sail_plugin_coq) - (modes (native plugin)) - (libraries libsail)) + (name sail_plugin_coq) + (modes + (native plugin)) + (libraries libsail)) (install - (section (site (libsail plugins))) - (package sail_coq_backend) - (files sail_plugin_coq.cmxs)) + (section + (site + (libsail plugins))) + (package sail_coq_backend) + (files sail_plugin_coq.cmxs)) diff --git a/src/sail_coq_backend/pretty_print_coq.ml b/src/sail_coq_backend/pretty_print_coq.ml index f4a317000..60ae22f60 100644 --- a/src/sail_coq_backend/pretty_print_coq.ml +++ b/src/sail_coq_backend/pretty_print_coq.ml @@ -76,23 +76,20 @@ open Rewriter open PPrint open Pretty_print_common -module StringSet = Set.Make(String) +module StringSet = Set.Make (String) let rec list_contains cmp l1 = function | [] -> Some l1 - | h::t -> - let rec remove = function - | [] -> None - | h'::t' -> if cmp h h' = 0 then Some t' - else Option.map (List.cons h') (remove t') - in Option.bind (remove l1) (fun l1' -> list_contains cmp l1' t) + | h :: t -> + let rec remove = function + | [] -> None + | h' :: t' -> if cmp h h' = 0 then Some t' else Option.map (List.cons h') (remove t') + in + Option.bind (remove l1) (fun l1' -> list_contains cmp l1' t) (* We currently support OCaml versions that are too old for KBindings.filter_opt *) let kbindings_filter_map f m = - KBindings.fold - (fun kid v m -> match f kid v with None -> m | Some v' -> KBindings.add kid v' m) - m - KBindings.empty + KBindings.fold (fun kid v m -> match f kid v with None -> m | Some v' -> KBindings.add kid v' m) m KBindings.empty let opt_undef_axioms = ref false let opt_debug_on : string list ref = ref [] @@ -114,42 +111,44 @@ let opt_debug_on : string list ref = ref [] * must rely entirely on the type (like the Sail type checker). *) - type context = { types_mod : string; (* Name of the types module for disambiguation *) early_ret : typ option; - kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames, - used to avoid variable/type variable name clashes *) + kid_renames : kid KBindings.t; + (* Plain tyvar -> tyvar renames, + used to avoid variable/type variable name clashes *) (* Note that as well as these kid renames, we also attempt to replace entire n_constraints with equivalent variables in doc_nc_exp. *) - kid_id_renames : (id option) KBindings.t; (* tyvar -> argument renames *) + kid_id_renames : id option KBindings.t; (* tyvar -> argument renames *) kid_id_renames_rev : kid Bindings.t; (* reverse of kid_id_renames *) constant_kids : Nat_big_num.num KBindings.t; (* type variables that should be replaced by a constant definition *) bound_nvars : KidSet.t; build_at_return : string option; - recursive_fns : (int * int) Bindings.t; (* Number of implicit arguments and constraints for (mutually) recursive definitions *) + recursive_fns : (int * int) Bindings.t; + (* Number of implicit arguments and constraints for (mutually) recursive definitions *) debug : bool; ret_typ_pp : PPrint.document; (* Return type formatted for use with returnR *) effect_info : Effects.side_effect_info; is_monadic : bool; avoid_target_names : StringSet.t; } -let empty_ctxt = { - types_mod = ""; - early_ret = None; - kid_renames = KBindings.empty; - kid_id_renames = KBindings.empty; - kid_id_renames_rev = Bindings.empty; - constant_kids = KBindings.empty; - bound_nvars = KidSet.empty; - build_at_return = None; - recursive_fns = Bindings.empty; - debug = false; - ret_typ_pp = PPrint.empty; - effect_info = Effects.empty_side_effect_info; - is_monadic = false; - avoid_target_names = StringSet.empty; -} +let empty_ctxt = + { + types_mod = ""; + early_ret = None; + kid_renames = KBindings.empty; + kid_id_renames = KBindings.empty; + kid_id_renames_rev = Bindings.empty; + constant_kids = KBindings.empty; + bound_nvars = KidSet.empty; + build_at_return = None; + recursive_fns = Bindings.empty; + debug = false; + ret_typ_pp = PPrint.empty; + effect_info = Effects.empty_side_effect_info; + is_monadic = false; + avoid_target_names = StringSet.empty; + } let add_single_kid_id_rename ctxt id kid = let kir = @@ -157,21 +156,17 @@ let add_single_kid_id_rename ctxt id kid = | Some kid -> KBindings.add kid None ctxt.kid_id_renames | None -> ctxt.kid_id_renames in - { ctxt with - kid_id_renames = KBindings.add kid (Some id) kir; - kid_id_renames_rev = Bindings.add id kid ctxt.kid_id_renames_rev - } + { + ctxt with + kid_id_renames = KBindings.add kid (Some id) kir; + kid_id_renames_rev = Bindings.add id kid ctxt.kid_id_renames_rev; + } let debug_depth = ref 0 -let rec indent n = match n with - | 0 -> "" - | n -> "| " ^ indent (n - 1) +let rec indent n = match n with 0 -> "" | n -> "| " ^ indent (n - 1) -let debug ctxt m = - if ctxt.debug - then print_endline (indent !debug_depth ^ Lazy.force m) - else () +let debug ctxt m = if ctxt.debug then print_endline (indent !debug_depth ^ Lazy.force m) else () let langlebar = string "<|" let ranglebar = string "|>" @@ -184,79 +179,47 @@ let comment = enclose (string "(*") (string "*)") let separate_opt s f l = separate s (List.filter_map f l) let is_number_char c = - c = '0' || c = '1' || c = '2' || c = '3' || c = '4' || c = '5' || - c = '6' || c = '7' || c = '8' || c = '9' + c = '0' || c = '1' || c = '2' || c = '3' || c = '4' || c = '5' || c = '6' || c = '7' || c = '8' || c = '9' -let is_enum env id = - match Env.lookup_id id env with - | Enum _ -> true - | _ -> false +let is_enum env id = match Env.lookup_id id env with Enum _ -> true | _ -> false -let rec fix_id avoid remove_tick name = match name with - | "assert" - | "lsl" - | "lsr" - | "asr" - | "type" - | "fun" - | "function" - | "raise" - | "try" - | "match" - | "with" - | "check" - | "field" - | "LT" - | "GT" - | "EQ" - | "Z" - | "O" - | "R" - | "S" - | "mod" - | "M" - | "tt" - -> name ^ "'" +let rec fix_id avoid remove_tick name = + match name with + | "assert" | "lsl" | "lsr" | "asr" | "type" | "fun" | "function" | "raise" | "try" | "match" | "with" | "check" + | "field" | "LT" | "GT" | "EQ" | "Z" | "O" | "R" | "S" | "mod" | "M" | "tt" -> + name ^ "'" | _ -> - if StringSet.mem name avoid then - name ^ "'" - else if String.contains name '#' then - fix_id avoid remove_tick (String.concat "_" (Util.split_on_char '#' name)) - else if String.contains name '?' then - fix_id avoid remove_tick (String.concat "_pat_" (Util.split_on_char '?' name)) - else if String.contains name '^' then - fix_id avoid remove_tick (String.concat "__" (Util.split_on_char '^' name)) - else if name.[0] = '\'' then - let var = String.sub name 1 (String.length name - 1) in - if remove_tick then fix_id avoid remove_tick var else (var ^ "'") - else if is_number_char(name.[0]) then - ("v" ^ name ^ "'") - else name - -let string_id avoid (Id_aux(i,_)) = - match i with - | Id i -> fix_id avoid false i - | Operator x -> Util.zencode_string ("op " ^ x) + if StringSet.mem name avoid then name ^ "'" + else if String.contains name '#' then fix_id avoid remove_tick (String.concat "_" (Util.split_on_char '#' name)) + else if String.contains name '?' then + fix_id avoid remove_tick (String.concat "_pat_" (Util.split_on_char '?' name)) + else if String.contains name '^' then fix_id avoid remove_tick (String.concat "__" (Util.split_on_char '^' name)) + else if name.[0] = '\'' then ( + let var = String.sub name 1 (String.length name - 1) in + if remove_tick then fix_id avoid remove_tick var else var ^ "'" + ) + else if is_number_char name.[0] then "v" ^ name ^ "'" + else name + +let string_id avoid (Id_aux (i, _)) = + match i with Id i -> fix_id avoid false i | Operator x -> Util.zencode_string ("op " ^ x) let doc_id ctxt id = string (string_id ctxt.avoid_target_names id) -let doc_id_type types_mod avoid env (Id_aux(i,_) as id) = +let doc_id_type types_mod avoid env (Id_aux (i, _) as id) = let is_shadowed () = match env with | None -> false - | Some env -> - IdSet.mem id (Env.get_defined_val_specs env) || - not (is_unbound (Env.lookup_id id env)) + | Some env -> IdSet.mem id (Env.get_defined_val_specs env) || not (is_unbound (Env.lookup_id id env)) in match i with - | Id("int") -> string "Z" - | Id("real") -> string "R" - | Id i when is_shadowed () -> - string types_mod ^^ dot ^^ string (fix_id avoid false i) + | Id "int" -> string "Z" + | Id "real" -> string "R" + | Id i when is_shadowed () -> string types_mod ^^ dot ^^ string (fix_id avoid false i) | Id i -> string (fix_id avoid false i) | Operator x -> string (Util.zencode_string ("op " ^ x)) -let doc_id_ctor ctxt (Id_aux(i,_)) = +let doc_id_ctor ctxt (Id_aux (i, _)) = match i with | Id i -> string (fix_id ctxt.avoid_target_names false i) | Operator x -> string (Util.zencode_string ("op " ^ x)) @@ -266,60 +229,52 @@ let doc_var ctxt kid = | Some id -> doc_id ctxt id | None -> underscore (* The original id has been shadowed, hope Coq can work it out... TODO: warn? *) | exception Not_found -> - string (fix_id ctxt.avoid_target_names true (string_of_kid (try KBindings.find kid ctxt.kid_renames with Not_found -> kid))) + string + (fix_id ctxt.avoid_target_names true + (string_of_kid (try KBindings.find kid ctxt.kid_renames with Not_found -> kid)) + ) let simple_annot l typ = (Parse_ast.Generated l, Some (Env.empty, typ)) -let simple_num l n = E_aux ( - E_lit (L_aux (L_num n, Parse_ast.Generated l)), - simple_annot (Parse_ast.Generated l) - (atom_typ (Nexp_aux (Nexp_constant n, Parse_ast.Generated l)))) +let simple_num l n = + E_aux + ( E_lit (L_aux (L_num n, Parse_ast.Generated l)), + simple_annot (Parse_ast.Generated l) (atom_typ (Nexp_aux (Nexp_constant n, Parse_ast.Generated l))) + ) -let is_regtyp (Typ_aux (typ, _)) env = match typ with - | Typ_app(id, _) when string_of_id id = "register" -> true - | _ -> false +let is_regtyp (Typ_aux (typ, _)) env = + match typ with Typ_app (id, _) when string_of_id id = "register" -> true | _ -> false -let doc_nexp ctx ?(skip_vars=KidSet.empty) nexp = +let doc_nexp ctx ?(skip_vars = KidSet.empty) nexp = (* Print according to Coq's precedence rules *) - let rec plussub (Nexp_aux (n,l) as nexp) = + let rec plussub (Nexp_aux (n, l) as nexp) = match n with | Nexp_sum (n1, n2) -> separate space [plussub n1; plus; mul n2] | Nexp_minus (n1, n2) -> separate space [plussub n1; minus; mul n2] | _ -> mul nexp - and mul (Nexp_aux (n,l) as nexp) = - match n with - | Nexp_times (n1, n2) -> separate space [mul n1; star; uneg n2] - | _ -> uneg nexp - and uneg (Nexp_aux (n,l) as nexp) = - match n with - | Nexp_neg n -> separate space [minus; uneg n] - | _ -> exp nexp - and exp (Nexp_aux (n,l) as nexp) = + and mul (Nexp_aux (n, l) as nexp) = + match n with Nexp_times (n1, n2) -> separate space [mul n1; star; uneg n2] | _ -> uneg nexp + and uneg (Nexp_aux (n, l) as nexp) = match n with Nexp_neg n -> separate space [minus; uneg n] | _ -> exp nexp + and exp (Nexp_aux (n, l) as nexp) = + match n with Nexp_exp n -> separate space [string "2"; caret; exp n] | _ -> app nexp + and app (Nexp_aux (n, l) as nexp) = match n with - | Nexp_exp n -> separate space [string "2"; caret; exp n] - | _ -> app nexp - and app (Nexp_aux (n,l) as nexp) = - match n with - | Nexp_app (Id_aux (Id "div",_), [n1;n2]) - -> separate space [string "ZEuclid.div"; atomic n1; atomic n2] - | Nexp_app (Id_aux (Id "mod",_), [n1;n2]) - -> separate space [string "ZEuclid.modulo"; atomic n1; atomic n2] - | Nexp_app (Id_aux (Id "abs_atom",_), [n1]) - -> separate space [string "Z.abs"; atomic n1] + | Nexp_app (Id_aux (Id "div", _), [n1; n2]) -> separate space [string "ZEuclid.div"; atomic n1; atomic n2] + | Nexp_app (Id_aux (Id "mod", _), [n1; n2]) -> separate space [string "ZEuclid.modulo"; atomic n1; atomic n2] + | Nexp_app (Id_aux (Id "abs_atom", _), [n1]) -> separate space [string "Z.abs"; atomic n1] | _ -> atomic nexp - and atomic (Nexp_aux (n,l) as nexp) = + and atomic (Nexp_aux (n, l) as nexp) = match n with | Nexp_constant i -> string (Big_int.to_string i) | Nexp_var v when KidSet.mem v skip_vars -> string "_" | Nexp_var v -> doc_var ctx v | Nexp_id id -> doc_id ctx id | Nexp_sum _ | Nexp_minus _ | Nexp_times _ | Nexp_neg _ | Nexp_exp _ - | Nexp_app (Id_aux (Id ("div"|"mod"),_), [_;_]) - | Nexp_app (Id_aux (Id "abs_atom",_), [_]) - -> parens (plussub nexp) - | _ -> - raise (Reporting.err_unreachable l __POS__ - ("cannot pretty-print nexp \"" ^ string_of_nexp nexp ^ "\"")) - in atomic nexp + | Nexp_app (Id_aux (Id ("div" | "mod"), _), [_; _]) + | Nexp_app (Id_aux (Id "abs_atom", _), [_]) -> + parens (plussub nexp) + | _ -> raise (Reporting.err_unreachable l __POS__ ("cannot pretty-print nexp \"" ^ string_of_nexp nexp ^ "\"")) + in + atomic nexp (* Rewrite mangled names of type variables to the original names *) let rec orig_nexp (Nexp_aux (nexp, l)) = @@ -342,479 +297,467 @@ let rec orig_nc (NC_aux (nc, l) as full_nc) = | NC_bounded_le (nexp1, nexp2) -> rewrap (NC_bounded_le (orig_nexp nexp1, orig_nexp nexp2)) | NC_bounded_lt (nexp1, nexp2) -> rewrap (NC_bounded_lt (orig_nexp nexp1, orig_nexp nexp2)) | NC_not_equal (nexp1, nexp2) -> rewrap (NC_not_equal (orig_nexp nexp1, orig_nexp nexp2)) - | NC_set (kid,s) -> rewrap (NC_set (orig_kid kid, s)) + | NC_set (kid, s) -> rewrap (NC_set (orig_kid kid, s)) | NC_or (nc1, nc2) -> rewrap (NC_or (orig_nc nc1, orig_nc nc2)) | NC_and (nc1, nc2) -> rewrap (NC_and (orig_nc nc1, orig_nc nc2)) - | NC_app (f,args) -> rewrap (NC_app (f,List.map orig_typ_arg args)) + | NC_app (f, args) -> rewrap (NC_app (f, List.map orig_typ_arg args)) | NC_var kid -> rewrap (NC_var (orig_kid kid)) | NC_true | NC_false -> full_nc -and orig_typ_arg (A_aux (arg,l)) = - let rewrap a = (A_aux (a,l)) in + +and orig_typ_arg (A_aux (arg, l)) = + let rewrap a = A_aux (a, l) in match arg with | A_nexp nexp -> rewrap (A_nexp (orig_nexp nexp)) | A_bool nc -> rewrap (A_bool (orig_nc nc)) - | A_order _ | A_typ _ -> - raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") + | A_order _ | A_typ _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") (* Returns the set of type variables that will appear in the Coq output, which may be smaller than those in the Sail type. May need to be updated with do *) -let rec coq_nvars_of_typ (Typ_aux (t,l)) = +let rec coq_nvars_of_typ (Typ_aux (t, l)) = let trec = coq_nvars_of_typ in match t with | Typ_id _ -> KidSet.empty | Typ_var kid -> tyvars_of_nexp (orig_nexp (nvar kid)) - | Typ_fn (t1,t2) -> List.fold_left KidSet.union (trec t2) (List.map trec t1) - | Typ_tuple ts -> - List.fold_left (fun s t -> KidSet.union s (trec t)) - KidSet.empty ts - | Typ_app(Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> - trec etyp - | Typ_app(Id_aux (Id "implicit", _),_) + | Typ_fn (t1, t2) -> List.fold_left KidSet.union (trec t2) (List.map trec t1) + | Typ_tuple ts -> List.fold_left (fun s t -> KidSet.union s (trec t)) KidSet.empty ts + | Typ_app (Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> trec etyp + | Typ_app (Id_aux (Id "implicit", _), _) (* TODO: update when complex atom types are sorted out *) - | Typ_app(Id_aux (Id "atom", _), _) -> KidSet.empty - | Typ_app(Id_aux (Id "atom_bool", _), _) -> KidSet.empty - | Typ_app (_,tas) -> - List.fold_left (fun s ta -> KidSet.union s (coq_nvars_of_typ_arg ta)) - KidSet.empty tas - | Typ_exist (kopts,_,t) -> - List.fold_left (fun vs kopt -> KidSet.remove (kopt_kid kopt) vs) (trec t) kopts + | Typ_app (Id_aux (Id "atom", _), _) -> + KidSet.empty + | Typ_app (Id_aux (Id "atom_bool", _), _) -> KidSet.empty + | Typ_app (_, tas) -> List.fold_left (fun s ta -> KidSet.union s (coq_nvars_of_typ_arg ta)) KidSet.empty tas + | Typ_exist (kopts, _, t) -> List.fold_left (fun vs kopt -> KidSet.remove (kopt_kid kopt) vs) (trec t) kopts | Typ_bidir _ -> unreachable l __POS__ "Coq doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" -and coq_nvars_of_typ_arg (A_aux (ta,_)) = + +and coq_nvars_of_typ_arg (A_aux (ta, _)) = match ta with | A_nexp nexp -> tyvars_of_nexp (orig_nexp nexp) | A_typ typ -> coq_nvars_of_typ typ | A_order _ -> KidSet.empty | A_bool nc -> tyvars_of_constraint (orig_nc nc) -let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) = +let maybe_expand_range_type (Typ_aux (typ, l) as full_typ) = match typ with - | Typ_app(Id_aux (Id "range", _), [A_aux(A_nexp low,_); - A_aux(A_nexp high,_)]) -> - (* TODO: avoid name clashes *) - let kid = mk_kid "rangevar" in - let var = nvar kid in - let nc = nc_and (nc_lteq low var) (nc_lteq var high) in - Some (Typ_aux (Typ_exist ([mk_kopt K_int kid], nc, atom_typ var),Parse_ast.Generated l)) - | Typ_id (Id_aux (Id "nat",_)) -> - let kid = mk_kid "n" in - let var = nvar kid in - Some (Typ_aux (Typ_exist ([mk_kopt K_int kid], nc_gteq var (nconstant Nat_big_num.zero), atom_typ var), - Parse_ast.Generated l)) + | Typ_app (Id_aux (Id "range", _), [A_aux (A_nexp low, _); A_aux (A_nexp high, _)]) -> + (* TODO: avoid name clashes *) + let kid = mk_kid "rangevar" in + let var = nvar kid in + let nc = nc_and (nc_lteq low var) (nc_lteq var high) in + Some (Typ_aux (Typ_exist ([mk_kopt K_int kid], nc, atom_typ var), Parse_ast.Generated l)) + | Typ_id (Id_aux (Id "nat", _)) -> + let kid = mk_kid "n" in + let var = nvar kid in + Some + (Typ_aux + ( Typ_exist ([mk_kopt K_int kid], nc_gteq var (nconstant Nat_big_num.zero), atom_typ var), + Parse_ast.Generated l + ) + ) | _ -> None let expand_range_type typ = Option.value ~default:typ (maybe_expand_range_type typ) - let nice_and nc1 nc2 = -match nc1, nc2 with -| NC_aux (NC_true,_), _ -> nc2 -| _, NC_aux (NC_true,_) -> nc1 -| _,_ -> nc_and nc1 nc2 + match (nc1, nc2) with NC_aux (NC_true, _), _ -> nc2 | _, NC_aux (NC_true, _) -> nc1 | _, _ -> nc_and nc1 nc2 let nice_iff nc1 nc2 = -match nc1, nc2 with -| NC_aux (NC_true,_), _ -> nc2 -| _, NC_aux (NC_true,_) -> nc1 -| NC_aux (NC_false,_), _ -> nc_not nc2 -| _, NC_aux (NC_false,_) -> nc_not nc1 - (* TODO: replace this hacky iff with a proper NC_ constructor *) -| _,_ -> mk_nc (NC_app (mk_id "iff",[arg_bool nc1; arg_bool nc2])) + match (nc1, nc2) with + | NC_aux (NC_true, _), _ -> nc2 + | _, NC_aux (NC_true, _) -> nc1 + | NC_aux (NC_false, _), _ -> nc_not nc2 + | _, NC_aux (NC_false, _) -> nc_not nc1 + (* TODO: replace this hacky iff with a proper NC_ constructor *) + | _, _ -> mk_nc (NC_app (mk_id "iff", [arg_bool nc1; arg_bool nc2])) (* n_constraint functions are currently just Z3 functions *) -let doc_nc_fn ctx (Id_aux (id,_) as full_id) = +let doc_nc_fn ctx (Id_aux (id, _) as full_id) = match id with | Id "not" -> string "negb" | Operator "-->" -> string "implb" | Id "iff" -> string "Bool.eqb" | _ -> doc_id ctx full_id -let merge_kid_count = KBindings.union (fun _ m n -> Some (m+n)) +let merge_kid_count = KBindings.union (fun _ m n -> Some (m + n)) -let rec count_nexp_vars (Nexp_aux (nexp,_)) = +let rec count_nexp_vars (Nexp_aux (nexp, _)) = match nexp with - | Nexp_id _ - | Nexp_constant _ - -> KBindings.empty + | Nexp_id _ | Nexp_constant _ -> KBindings.empty | Nexp_var kid -> KBindings.singleton kid 1 - | Nexp_app (_,nes) -> - List.fold_left merge_kid_count KBindings.empty (List.map count_nexp_vars nes) - | Nexp_times (n1,n2) - | Nexp_sum (n1,n2) - | Nexp_minus (n1,n2) - -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) - | Nexp_exp n - | Nexp_neg n - -> count_nexp_vars n - -let rec count_nc_vars (NC_aux (nc,_)) = - let count_arg (A_aux (arg,_)) = + | Nexp_app (_, nes) -> List.fold_left merge_kid_count KBindings.empty (List.map count_nexp_vars nes) + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> + merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) + | Nexp_exp n | Nexp_neg n -> count_nexp_vars n + +let rec count_nc_vars (NC_aux (nc, _)) = + let count_arg (A_aux (arg, _)) = match arg with | A_bool nc -> count_nc_vars nc | A_nexp nexp -> count_nexp_vars nexp | A_typ _ | A_order _ -> KBindings.empty in match nc with - | NC_or (nc1,nc2) - | NC_and (nc1,nc2) - -> merge_kid_count (count_nc_vars nc1) (count_nc_vars nc2) - | NC_var kid - | NC_set (kid,_) - -> KBindings.singleton kid 1 - | NC_equal (n1,n2) - | NC_bounded_ge (n1,n2) - | NC_bounded_gt (n1,n2) - | NC_bounded_le (n1,n2) - | NC_bounded_lt (n1,n2) - | NC_not_equal (n1,n2) - -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) - | NC_true | NC_false - -> KBindings.empty - | NC_app (_,args) -> - List.fold_left merge_kid_count KBindings.empty (List.map count_arg args) + | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> merge_kid_count (count_nc_vars nc1) (count_nc_vars nc2) + | NC_var kid | NC_set (kid, _) -> KBindings.singleton kid 1 + | NC_equal (n1, n2) + | NC_bounded_ge (n1, n2) + | NC_bounded_gt (n1, n2) + | NC_bounded_le (n1, n2) + | NC_bounded_lt (n1, n2) + | NC_not_equal (n1, n2) -> + merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) + | NC_true | NC_false -> KBindings.empty + | NC_app (_, args) -> List.fold_left merge_kid_count KBindings.empty (List.map count_arg args) (* Simplify some of the complex boolean types created by the Sail type checker, whereever an existentially bound variable is used once in a trivial way, for example exists b, b and exists n, n = 32. *) -type atom_bool_prop = - Bool_boring -| Bool_complex of kinded_id list * n_constraint * n_constraint +type atom_bool_prop = Bool_boring | Bool_complex of kinded_id list * n_constraint * n_constraint let simplify_atom_bool l kopts nc atom_nc = -(*prerr_endline ("simplify " ^ string_of_n_constraint nc ^ " for bool " ^ string_of_n_constraint atom_nc);*) + (*prerr_endline ("simplify " ^ string_of_n_constraint nc ^ " for bool " ^ string_of_n_constraint atom_nc);*) let counter = ref 0 in let is_bound kid = List.exists (fun kopt -> Kid.compare kid (kopt_kid kopt) == 0) kopts in let ty_vars = merge_kid_count (count_nc_vars nc) (count_nc_vars atom_nc) in let lin_ty_vars = KBindings.filter (fun kid n -> is_bound kid && n = 1) ty_vars in - let rec simplify (NC_aux (nc,l) as nc_full) = - let is_ex_var news (NC_aux (nc,_)) = + let rec simplify (NC_aux (nc, l) as nc_full) = + let is_ex_var news (NC_aux (nc, _)) = match nc with | NC_var kid when KBindings.mem kid lin_ty_vars -> Some kid | NC_var kid when KidSet.mem kid news -> Some kid - | NC_equal (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_equal (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_ge (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_ge (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_gt (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_gt (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_le (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_le (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_lt (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_bounded_lt (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_not_equal (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_not_equal (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_set (kid, _::_) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_equal (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_equal (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_ge (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_ge (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_gt (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_gt (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_le (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_le (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_lt (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_lt (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_not_equal (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_not_equal (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_set (kid, _ :: _) when KBindings.mem kid lin_ty_vars -> Some kid | _ -> None in let replace kills vars = let v = mk_kid ("simp#" ^ string_of_int !counter) in let kills = KidSet.union kills (KidSet.of_list vars) in counter := !counter + 1; - KidSet.singleton v, kills, NC_aux (NC_var v,l) + (KidSet.singleton v, kills, NC_aux (NC_var v, l)) in match nc with - | NC_or (nc1,nc2) -> begin - let new1, kill1, nc1 = simplify nc1 in - let new2, kill2, nc2 = simplify nc2 in - let news, kills = KidSet.union new1 new2, KidSet.union kill1 kill2 in - match is_ex_var news nc1, is_ex_var news nc2 with - | Some kid1, Some kid2 -> replace kills [kid1;kid2] - | _ -> news, kills, NC_aux (NC_or (nc1,nc2),l) + | NC_or (nc1, nc2) -> begin + let new1, kill1, nc1 = simplify nc1 in + let new2, kill2, nc2 = simplify nc2 in + let news, kills = (KidSet.union new1 new2, KidSet.union kill1 kill2) in + match (is_ex_var news nc1, is_ex_var news nc2) with + | Some kid1, Some kid2 -> replace kills [kid1; kid2] + | _ -> (news, kills, NC_aux (NC_or (nc1, nc2), l)) end - | NC_and (nc1,nc2) -> begin - let new1, kill1, nc1 = simplify nc1 in - let new2, kill2, nc2 = simplify nc2 in - let news, kills = KidSet.union new1 new2, KidSet.union kill1 kill2 in - match is_ex_var news nc1, is_ex_var news nc2 with - | Some kid1, Some kid2 -> replace kills [kid1;kid2] - | _ -> news, kills, NC_aux (NC_and (nc1,nc2),l) + | NC_and (nc1, nc2) -> begin + let new1, kill1, nc1 = simplify nc1 in + let new2, kill2, nc2 = simplify nc2 in + let news, kills = (KidSet.union new1 new2, KidSet.union kill1 kill2) in + match (is_ex_var news nc1, is_ex_var news nc2) with + | Some kid1, Some kid2 -> replace kills [kid1; kid2] + | _ -> (news, kills, NC_aux (NC_and (nc1, nc2), l)) end - | NC_app (Id_aux (Id "not",_) as id,[A_aux (A_bool nc1,al)]) -> begin - let new1, kill1, nc1 = simplify nc1 in - match is_ex_var new1 nc1 with - | Some kid -> replace kill1 [kid] - | None -> new1, kill1, NC_aux (NC_app (id,[A_aux (A_bool nc1,al)]),l) + | NC_app ((Id_aux (Id "not", _) as id), [A_aux (A_bool nc1, al)]) -> begin + let new1, kill1, nc1 = simplify nc1 in + match is_ex_var new1 nc1 with + | Some kid -> replace kill1 [kid] + | None -> (new1, kill1, NC_aux (NC_app (id, [A_aux (A_bool nc1, al)]), l)) end (* We don't currently recurse into general uses of NC_app, but the "boring" cases we really want to get rid of won't contain those. *) - | _ -> - match is_ex_var KidSet.empty nc_full with - | Some kid -> replace KidSet.empty [kid] - | None -> KidSet.empty, KidSet.empty, nc_full + | _ -> ( + match is_ex_var KidSet.empty nc_full with + | Some kid -> replace KidSet.empty [kid] + | None -> (KidSet.empty, KidSet.empty, nc_full) + ) in let new_nc, kill_nc, nc = simplify nc in let new_atom, kill_atom, atom_nc = simplify atom_nc in let new_kids = KidSet.union new_nc new_atom in let kill_kids = KidSet.union kill_nc kill_atom in let kopts = - List.map (fun kid -> mk_kopt K_bool kid) (KidSet.elements new_kids) @ - List.filter (fun kopt -> not (KidSet.mem (kopt_kid kopt) kill_kids)) kopts + List.map (fun kid -> mk_kopt K_bool kid) (KidSet.elements new_kids) + @ List.filter (fun kopt -> not (KidSet.mem (kopt_kid kopt) kill_kids)) kopts in -(*prerr_endline ("now have " ^ string_of_n_constraint nc ^ " for bool " ^ string_of_n_constraint atom_nc);*) + (*prerr_endline ("now have " ^ string_of_n_constraint nc ^ " for bool " ^ string_of_n_constraint atom_nc);*) match atom_nc with - | NC_aux (NC_var kid,_) when KBindings.mem kid lin_ty_vars -> Bool_boring - | NC_aux (NC_var kid,_) when KidSet.mem kid new_kids -> Bool_boring + | NC_aux (NC_var kid, _) when KBindings.mem kid lin_ty_vars -> Bool_boring + | NC_aux (NC_var kid, _) when KidSet.mem kid new_kids -> Bool_boring | _ -> Bool_complex (kopts, nc, atom_nc) - type ex_kind = ExNone | ExGeneral -let string_of_ex_kind = function - | ExNone -> "none" - | ExGeneral -> "general" +let string_of_ex_kind = function ExNone -> "none" | ExGeneral -> "general" (* Should a Sail type be turned into a dependent pair in Coq? Optionally takes a variable that we're binding (to avoid trivial cases where the type is exactly the boolean we're binding), and whether to turn bools with interesting type-expressions into dependent pairs. *) -let classify_ex_type ctxt env ?binding ?(rawbools=false) (Typ_aux (t,l) as t0) = +let classify_ex_type ctxt env ?binding ?(rawbools = false) (Typ_aux (t, l) as t0) = let is_binding kid = - match binding, KBindings.find_opt kid ctxt.kid_id_renames with + match (binding, KBindings.find_opt kid ctxt.kid_id_renames) with | Some id, Some (Some id') when Id.compare id id' == 0 -> true | _ -> false in let simplify_atom_bool l kopts nc atom_nc = match simplify_atom_bool l kopts nc atom_nc with | Bool_boring -> Bool_boring - | Bool_complex (_,_,NC_aux (NC_var kid,_)) when is_binding kid -> Bool_boring - | Bool_complex (x,y,z) -> Bool_complex (x,y,z) + | Bool_complex (_, _, NC_aux (NC_var kid, _)) when is_binding kid -> Bool_boring + | Bool_complex (x, y, z) -> Bool_complex (x, y, z) in match t with - | Typ_exist (kopts,nc,Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_)) -> begin + | Typ_exist (kopts, nc, Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool atom_nc, _)]), _)) -> begin match simplify_atom_bool l kopts nc atom_nc with - | Bool_boring -> ExNone, [], bool_typ - | Bool_complex _ -> ExGeneral, [], bool_typ + | Bool_boring -> (ExNone, [], bool_typ) + | Bool_complex _ -> (ExGeneral, [], bool_typ) end - | Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]) -> begin - match rawbools, simplify_atom_bool l [] nc_true atom_nc with - | false, _ -> ExNone, [], bool_typ - | _,Bool_boring -> ExNone, [], bool_typ - | _,Bool_complex _ -> ExGeneral, [], bool_typ + | Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool atom_nc, _)]) -> begin + match (rawbools, simplify_atom_bool l [] nc_true atom_nc) with + | false, _ -> (ExNone, [], bool_typ) + | _, Bool_boring -> (ExNone, [], bool_typ) + | _, Bool_complex _ -> (ExGeneral, [], bool_typ) end - | Typ_exist (kopts,_,t1) -> ExGeneral,kopts,t1 - | _ -> ExNone,[],t0 + | Typ_exist (kopts, _, t1) -> (ExGeneral, kopts, t1) + | _ -> (ExNone, [], t0) -let rec flatten_nc (NC_aux (nc,l) as nc_full) = - match nc with - | NC_and (nc1,nc2) -> flatten_nc nc1 @ flatten_nc nc2 - | _ -> [nc_full] +let rec flatten_nc (NC_aux (nc, l) as nc_full) = + match nc with NC_and (nc1, nc2) -> flatten_nc nc1 @ flatten_nc nc2 | _ -> [nc_full] (* When making changes here, check whether they affect coq_nvars_of_typ *) let rec doc_typ_fns ctx env = (* following the structure of parser for precedence *) let rec typ ty = fn_typ true ty - and typ' ty = fn_typ false ty - and fn_typ atyp_needed ((Typ_aux (t, _)) as ty) = match t with - | Typ_fn(args,ret) -> - let ret_typ = - (* TODO EFFECT: Make this ICE, add a doc_fn_typ with a monadic parameter that only docs function types *) - (*if effectful efct - then separate space [string "M"; fn_typ true ret] - else *) separate space [fn_typ false ret] in - let arg_typs = List.map (app_typ false) args in - let tpp = separate (space ^^ arrow ^^ space) (arg_typs @ [ret_typ]) in - (* once we have proper excetions we need to know what the exceptions type is *) - if atyp_needed then parens tpp else tpp - | _ -> tup_typ atyp_needed ty - and tup_typ atyp_needed ((Typ_aux (t, _)) as ty) = match t with - | Typ_tuple typs -> - parens (separate_map (space ^^ star ^^ space) (app_typ false) typs) - | _ -> app_typ atyp_needed ty - and app_typ atyp_needed ((Typ_aux (t, l)) as ty) = match t with - | Typ_app(Id_aux (Id "bitvector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _)]) -> - (* TODO: remove duplication with exists, below *) - let tpp = string "mword " ^^ doc_nexp ctx m in - if atyp_needed then parens tpp else tpp - | Typ_app(Id_aux (Id "vector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _); - A_aux (A_typ elem_typ, _)]) -> - (* TODO: remove duplication with exists, below *) - let tpp = string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ doc_nexp ctx m in - if atyp_needed then parens tpp else tpp - | Typ_app(Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> - let tpp = string "register_ref regstate register_value " ^^ typ etyp in - if atyp_needed then parens tpp else tpp - | Typ_app(Id_aux (Id "range", _), _) - | Typ_id (Id_aux (Id "nat", _)) - | Typ_app(Id_aux (Id "implicit", _),_) - | Typ_app(Id_aux (Id "atom", _), [A_aux(A_nexp _,_)]) -> - string "Z" - | Typ_app(Id_aux (Id "atom_bool", _), [A_aux (A_bool _atom_nc,_)]) -> - string "bool" - | Typ_app(id,args) -> - let tpp = (doc_id_type ctx.types_mod ctx.avoid_target_names (Some env) id) ^^ space ^^ (separate_map space doc_typ_arg args) in - if atyp_needed then parens tpp else tpp - | _ -> atomic_typ atyp_needed ty - and atomic_typ atyp_needed ((Typ_aux (t, l)) as ty) = match t with - | Typ_id (Id_aux (Id "bool",_)) -> string "bool" - | Typ_id (Id_aux (Id "bit",_)) -> string "bitU" - | Typ_id (id) -> - (*if List.exists ((=) (string_of_id id)) regtypes - then string "register" - else*) doc_id_type ctx.types_mod ctx.avoid_target_names (Some env) id - | Typ_var v -> doc_var ctx v - | Typ_app _ | Typ_tuple _ | Typ_fn _ -> - (* exhaustiveness matters here to avoid infinite loops - * if we add a new Typ constructor *) - let tpp = typ ty in - if atyp_needed then parens tpp else tpp - (* TODO: handle non-integer kopts *) - | Typ_exist (kopts,nc,ty') -> - (* TODO: check for kopts used in ty', using coq_nvars_of_typ, but make sure that's correct *) - atomic_typ atyp_needed ty' - (* TODO: decide how to handle situations where an existential witness is required, e.g., - by turning {'n, 'n >= 0. bits('n)} into a pair of 'n and the bitvector. The code below - is the old implementation which used embedded proofs, but might prove useful. - begin - let kopts,nc,ty' = match maybe_expand_range_type ty' with - | Some (Typ_aux (Typ_exist (kopts',nc',ty'),_)) -> - kopts'@kopts,nc_and nc nc',ty' - | _ -> kopts,nc,ty' - in - match ty' with - | Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp nexp,_)]),_) -> - begin match nexp, kopts with - | (Nexp_aux (Nexp_var kid,_)), [kopt] when Kid.compare kid (kopt_kid kopt) == 0 -> - braces (separate space [doc_var ctx kid; colon; string "Z"; - ampersand; doc_arithfact ctx env nc]) - | _ -> - let var = mk_kid "_atom" in (* TODO collision avoid *) - let nc = nice_and (nc_eq (nvar var) nexp) nc in - braces (separate space [doc_var ctx var; colon; string "Z"; - ampersand; doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) nc]) - end - | Typ_aux (Typ_app (Id_aux (Id "bitvector",_), - [A_aux (A_nexp m, _); - A_aux (A_order ord, _)]), _) -> - (* TODO: proper handling of m, complex elem type, dedup with above *) - let var = mk_kid "_vec" in (* TODO collision avoid *) - let kid_set = KidSet.of_list (List.map kopt_kid kopts) in - let m_pp = doc_nexp ctx ~skip_vars:kid_set m in - let tpp, len_pp = string "mword " ^^ m_pp, string "length_mword" in - let length_constraint_pp = - if KidSet.is_empty (KidSet.inter kid_set (nexp_frees m)) - then None - else Some (separate space [len_pp; doc_var ctx var; string "=?"; doc_nexp ctx m]) - in - braces (separate space - [doc_var ctx var; colon; tpp; - ampersand; - doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc]) - | Typ_aux (Typ_app (Id_aux (Id "vector",_), - [A_aux (A_nexp m, _); - A_aux (A_order ord, _); - A_aux (A_typ elem_typ, _)]),_) -> - (* TODO: proper handling of m, complex elem type, dedup with above *) - let var = mk_kid "_vec" in (* TODO collision avoid *) - let kid_set = KidSet.of_list (List.map kopt_kid kopts) in - let m_pp = doc_nexp ctx ~skip_vars:kid_set m in - let tpp, len_pp = string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ m_pp, string "vec_length" in - let length_constraint_pp = - if KidSet.is_empty (KidSet.inter kid_set (nexp_frees m)) - then None - else Some (separate space [len_pp; doc_var ctx var; string "=?"; doc_nexp ctx m]) - in - braces (separate space - [doc_var ctx var; colon; tpp; - ampersand; - doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc]) - | Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_) -> begin - match simplify_atom_bool l kopts nc atom_nc with - | Bool_boring -> string "bool" - | Bool_complex (kopts,nc,atom_nc) -> - let var = mk_kid "_bool" in (* TODO collision avoid *) - let nc = nice_and (nice_iff atom_nc (nc_var var)) nc in - braces (separate space - [doc_var ctx var; colon; string "bool"; - ampersand; - doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) nc]) - end - | Typ_aux (Typ_tuple tys,l) -> begin - (* TODO: boolean existentials *) - let kid_set = KidSet.of_list (List.map kopt_kid kopts) in - let should_keep (Typ_aux (ty,_)) = - match ty with - | Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var var,_)),_)]) -> - not (KidSet.mem var kid_set) - | _ -> true - in - let out_tys = List.filter should_keep tys in - let binding_of_tyvar (KOpt_aux (KOpt_kind (K_aux (kind,_) as kaux,kid),_)) = - let kind_pp = match kind with - | K_int -> string "Z" - | _ -> - raise (Reporting.err_todo l - ("Non-atom existential type over " ^ string_of_kind kaux ^ " not yet supported in Coq: " ^ - string_of_typ ty)) - in doc_var ctx kid, kind_pp - in - let exvars_pp = List.map binding_of_tyvar kopts in - let pat = match exvars_pp with - | [v,k] -> v ^^ space ^^ colon ^^ space ^^ k - | _ -> - let vars, types = List.split exvars_pp in - squote ^^ parens (separate (string ", ") vars) ^/^ - colon ^/^ parens (separate (string " * ") types) - in - group (braces (group (pat ^^ space ^^ ampersand) ^/^ - group (tup_typ true (Typ_aux (Typ_tuple out_tys,l)) ^^ - string "%type ") ^^ - ampersand ^/^ - doc_arithfact ctx env nc)) - end - | _ -> - raise (Reporting.err_todo l - ("Non-atom existential type not yet supported in Coq: " ^ - string_of_typ ty)) - end + and typ' ty = fn_typ false ty + and fn_typ atyp_needed (Typ_aux (t, _) as ty) = + match t with + | Typ_fn (args, ret) -> + let ret_typ = + (* TODO EFFECT: Make this ICE, add a doc_fn_typ with a monadic parameter that only docs function types *) + (*if effectful efct + then separate space [string "M"; fn_typ true ret] + else *) + separate space [fn_typ false ret] + in + let arg_typs = List.map (app_typ false) args in + let tpp = separate (space ^^ arrow ^^ space) (arg_typs @ [ret_typ]) in + (* once we have proper excetions we need to know what the exceptions type is *) + if atyp_needed then parens tpp else tpp + | _ -> tup_typ atyp_needed ty + and tup_typ atyp_needed (Typ_aux (t, _) as ty) = + match t with + | Typ_tuple typs -> parens (separate_map (space ^^ star ^^ space) (app_typ false) typs) + | _ -> app_typ atyp_needed ty + and app_typ atyp_needed (Typ_aux (t, l) as ty) = + match t with + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _)]) -> + (* TODO: remove duplication with exists, below *) + let tpp = string "mword " ^^ doc_nexp ctx m in + if atyp_needed then parens tpp else tpp + | Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) -> + (* TODO: remove duplication with exists, below *) + let tpp = string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ doc_nexp ctx m in + if atyp_needed then parens tpp else tpp + | Typ_app (Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> + let tpp = string "register_ref regstate register_value " ^^ typ etyp in + if atyp_needed then parens tpp else tpp + | Typ_app (Id_aux (Id "range", _), _) + | Typ_id (Id_aux (Id "nat", _)) + | Typ_app (Id_aux (Id "implicit", _), _) + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp _, _)]) -> + string "Z" + | Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool _atom_nc, _)]) -> string "bool" + | Typ_app (id, args) -> + let tpp = + doc_id_type ctx.types_mod ctx.avoid_target_names (Some env) id ^^ space ^^ separate_map space doc_typ_arg args + in + if atyp_needed then parens tpp else tpp + | _ -> atomic_typ atyp_needed ty + and atomic_typ atyp_needed (Typ_aux (t, l) as ty) = + match t with + | Typ_id (Id_aux (Id "bool", _)) -> string "bool" + | Typ_id (Id_aux (Id "bit", _)) -> string "bitU" + | Typ_id id -> + (*if List.exists ((=) (string_of_id id)) regtypes + then string "register" + else*) + doc_id_type ctx.types_mod ctx.avoid_target_names (Some env) id + | Typ_var v -> doc_var ctx v + | Typ_app _ | Typ_tuple _ | Typ_fn _ -> + (* exhaustiveness matters here to avoid infinite loops + * if we add a new Typ constructor *) + let tpp = typ ty in + if atyp_needed then parens tpp else tpp + (* TODO: handle non-integer kopts *) + | Typ_exist (kopts, nc, ty') -> + (* TODO: check for kopts used in ty', using coq_nvars_of_typ, but make sure that's correct *) + atomic_typ atyp_needed ty' + (* TODO: decide how to handle situations where an existential witness is required, e.g., + by turning {'n, 'n >= 0. bits('n)} into a pair of 'n and the bitvector. The code below + is the old implementation which used embedded proofs, but might prove useful. + begin + let kopts,nc,ty' = match maybe_expand_range_type ty' with + | Some (Typ_aux (Typ_exist (kopts',nc',ty'),_)) -> + kopts'@kopts,nc_and nc nc',ty' + | _ -> kopts,nc,ty' + in + match ty' with + | Typ_aux (Typ_app (Id_aux (Id "atom",_), + [A_aux (A_nexp nexp,_)]),_) -> + begin match nexp, kopts with + | (Nexp_aux (Nexp_var kid,_)), [kopt] when Kid.compare kid (kopt_kid kopt) == 0 -> + braces (separate space [doc_var ctx kid; colon; string "Z"; + ampersand; doc_arithfact ctx env nc]) + | _ -> + let var = mk_kid "_atom" in (* TODO collision avoid *) + let nc = nice_and (nc_eq (nvar var) nexp) nc in + braces (separate space [doc_var ctx var; colon; string "Z"; + ampersand; doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) nc]) + end + | Typ_aux (Typ_app (Id_aux (Id "bitvector",_), + [A_aux (A_nexp m, _); + A_aux (A_order ord, _)]), _) -> + (* TODO: proper handling of m, complex elem type, dedup with above *) + let var = mk_kid "_vec" in (* TODO collision avoid *) + let kid_set = KidSet.of_list (List.map kopt_kid kopts) in + let m_pp = doc_nexp ctx ~skip_vars:kid_set m in + let tpp, len_pp = string "mword " ^^ m_pp, string "length_mword" in + let length_constraint_pp = + if KidSet.is_empty (KidSet.inter kid_set (nexp_frees m)) + then None + else Some (separate space [len_pp; doc_var ctx var; string "=?"; doc_nexp ctx m]) + in + braces (separate space + [doc_var ctx var; colon; tpp; + ampersand; + doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc]) + | Typ_aux (Typ_app (Id_aux (Id "vector",_), + [A_aux (A_nexp m, _); + A_aux (A_order ord, _); + A_aux (A_typ elem_typ, _)]),_) -> + (* TODO: proper handling of m, complex elem type, dedup with above *) + let var = mk_kid "_vec" in (* TODO collision avoid *) + let kid_set = KidSet.of_list (List.map kopt_kid kopts) in + let m_pp = doc_nexp ctx ~skip_vars:kid_set m in + let tpp, len_pp = string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ m_pp, string "vec_length" in + let length_constraint_pp = + if KidSet.is_empty (KidSet.inter kid_set (nexp_frees m)) + then None + else Some (separate space [len_pp; doc_var ctx var; string "=?"; doc_nexp ctx m]) + in + braces (separate space + [doc_var ctx var; colon; tpp; + ampersand; + doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc]) + | Typ_aux (Typ_app (Id_aux (Id "atom_bool",_), [A_aux (A_bool atom_nc,_)]),_) -> begin + match simplify_atom_bool l kopts nc atom_nc with + | Bool_boring -> string "bool" + | Bool_complex (kopts,nc,atom_nc) -> + let var = mk_kid "_bool" in (* TODO collision avoid *) + let nc = nice_and (nice_iff atom_nc (nc_var var)) nc in + braces (separate space + [doc_var ctx var; colon; string "bool"; + ampersand; + doc_arithfact ctx env ~exists:(List.map kopt_kid kopts) nc]) + end + | Typ_aux (Typ_tuple tys,l) -> begin + (* TODO: boolean existentials *) + let kid_set = KidSet.of_list (List.map kopt_kid kopts) in + let should_keep (Typ_aux (ty,_)) = + match ty with + | Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var var,_)),_)]) -> + not (KidSet.mem var kid_set) + | _ -> true + in + let out_tys = List.filter should_keep tys in + let binding_of_tyvar (KOpt_aux (KOpt_kind (K_aux (kind,_) as kaux,kid),_)) = + let kind_pp = match kind with + | K_int -> string "Z" + | _ -> + raise (Reporting.err_todo l + ("Non-atom existential type over " ^ string_of_kind kaux ^ " not yet supported in Coq: " ^ + string_of_typ ty)) + in doc_var ctx kid, kind_pp + in + let exvars_pp = List.map binding_of_tyvar kopts in + let pat = match exvars_pp with + | [v,k] -> v ^^ space ^^ colon ^^ space ^^ k + | _ -> + let vars, types = List.split exvars_pp in + squote ^^ parens (separate (string ", ") vars) ^/^ + colon ^/^ parens (separate (string " * ") types) + in + group (braces (group (pat ^^ space ^^ ampersand) ^/^ + group (tup_typ true (Typ_aux (Typ_tuple out_tys,l)) ^^ + string "%type ") ^^ + ampersand ^/^ + doc_arithfact ctx env nc)) + end + | _ -> + raise (Reporting.err_todo l + ("Non-atom existential type not yet supported in Coq: " ^ + string_of_typ ty)) + end -(* + (* - let add_tyvar tpp kid = - braces (separate space [doc_var ctx kid; colon; string "Z"; ampersand; tpp]) - in - match drop_duplicate_atoms kids ty with - | Some ty -> - let tpp = typ ty in - let tpp = match nc with NC_aux (NC_true,_) -> tpp | _ -> - braces (separate space [underscore; colon; parens (doc_arithfact ctx nc); ampersand; tpp]) - in - List.fold_left add_tyvar tpp kids - | None -> - match nc with -(* | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)*) - | _ -> List.fold_left add_tyvar (doc_arithfact ctx nc) kids - end*)*) - | Typ_bidir _ -> unreachable l __POS__ "Coq doesn't support bidir types" - | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" - and doc_typ_arg ?(prop_vars = false) (A_aux(t,_)) = match t with - | A_typ t -> app_typ true t - | A_nexp n -> doc_nexp ctx n - | A_order o -> empty - | A_bool nc -> parens (doc_nc_exp ctx env nc) - in typ', atomic_typ, doc_typ_arg -and doc_typ ctx env = let f,_,_ = doc_typ_fns ctx env in f -and doc_atomic_typ ctx env = let _,f,_ = doc_typ_fns ctx env in f -and doc_typ_arg ctx env = let _,_,f = doc_typ_fns ctx env in f + let add_tyvar tpp kid = + braces (separate space [doc_var ctx kid; colon; string "Z"; ampersand; tpp]) + in + match drop_duplicate_atoms kids ty with + | Some ty -> + let tpp = typ ty in + let tpp = match nc with NC_aux (NC_true,_) -> tpp | _ -> + braces (separate space [underscore; colon; parens (doc_arithfact ctx nc); ampersand; tpp]) + in + List.fold_left add_tyvar tpp kids + | None -> + match nc with + (* | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)*) + | _ -> List.fold_left add_tyvar (doc_arithfact ctx nc) kids + end*)*) + | Typ_bidir _ -> unreachable l __POS__ "Coq doesn't support bidir types" + | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" + and doc_typ_arg ?(prop_vars = false) (A_aux (t, _)) = + match t with + | A_typ t -> app_typ true t + | A_nexp n -> doc_nexp ctx n + | A_order o -> empty + | A_bool nc -> parens (doc_nc_exp ctx env nc) + in + (typ', atomic_typ, doc_typ_arg) + +and doc_typ ctx env = + let f, _, _ = doc_typ_fns ctx env in + f + +and doc_atomic_typ ctx env = + let _, f, _ = doc_typ_fns ctx env in + f + +and doc_typ_arg ctx env = + let _, _, f = doc_typ_fns ctx env in + f and doc_arithfact ctxt env ?(exists = []) ?extra nc = let prop = doc_nc_exp ctxt env nc in - let prop = match extra with - | None -> prop - | Some pp -> separate space [parens pp; string "&&"; parens prop] - in + let prop = match extra with None -> prop | Some pp -> separate space [parens pp; string "&&"; parens prop] in let prop = prop in match exists with | [] -> string "ArithFact" ^^ space ^^ parens prop - | _ -> string "ArithFactP" ^^ space ^^ - parens (separate space ([string "exists"]@(List.map (doc_var ctxt) exists)@[comma; prop; equals; string "true"])) + | _ -> + string "ArithFactP" ^^ space + ^^ parens + (separate space ([string "exists"] @ List.map (doc_var ctxt) exists @ [comma; prop; equals; string "true"])) (* Follows Coq precedence levels *) and doc_nc_exp ctx env nc = @@ -822,11 +765,12 @@ and doc_nc_exp ctx env nc = let nc = Env.expand_constraint_synonyms env nc in let nc_id_map = List.fold_left - (fun m (v,(_,Typ_aux (typ,_))) -> + (fun m (v, (_, Typ_aux (typ, _))) -> match typ with - | Typ_app (id, [A_aux (A_bool nc,_)]) when string_of_id id = "atom_bool" -> - (flatten_nc nc, v)::m - | _ -> m) [] locals + | Typ_app (id, [A_aux (A_bool nc, _)]) when string_of_id id = "atom_bool" -> (flatten_nc nc, v) :: m + | _ -> m + ) + [] locals in (* Look for variables in the environment which exactly express the nc, and use them instead. As well as often being shorter, this avoids unbound type @@ -834,13 +778,13 @@ and doc_nc_exp ctx env nc = let rec newnc f nc = let ncs = flatten_nc nc in let candidates = - List.filter_map (fun (ncs',id) -> Option.map (fun x -> x,id) (list_contains NC.compare ncs ncs')) nc_id_map + List.filter_map (fun (ncs', id) -> Option.map (fun x -> (x, id)) (list_contains NC.compare ncs ncs')) nc_id_map in - match List.sort (fun (l,_) (l',_) -> compare l l') candidates with - | ([],id)::_ -> doc_id ctx id - | ((h::t),id)::_ -> parens (doc_op (string "&&") (doc_id ctx id) (l10 (List.fold_left nc_and h t))) + match List.sort (fun (l, _) (l', _) -> compare l l') candidates with + | ([], id) :: _ -> doc_id ctx id + | (h :: t, id) :: _ -> parens (doc_op (string "&&") (doc_id ctx id) (l10 (List.fold_left nc_and h t))) | [] -> f nc - and l70 (NC_aux (nc,_) as nc_full) = + and l70 (NC_aux (nc, _) as nc_full) = match nc with | NC_equal (ne1, ne2) -> doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) @@ -848,226 +792,210 @@ and doc_nc_exp ctx env nc = | NC_bounded_le (ne1, ne2) -> doc_op (string "<=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_lt (ne1, ne2) -> doc_op (string " l50 nc_full - and l50 (NC_aux (nc,_) as nc_full) = - match nc with - | NC_or (nc1, nc2) -> doc_op (string "||") (newnc l50 nc1) (newnc l40 nc2) - | _ -> l40 nc_full - and l40 (NC_aux (nc,_) as nc_full) = + and l50 (NC_aux (nc, _) as nc_full) = + match nc with NC_or (nc1, nc2) -> doc_op (string "||") (newnc l50 nc1) (newnc l40 nc2) | _ -> l40 nc_full + and l40 (NC_aux (nc, _) as nc_full) = + match nc with NC_and (nc1, nc2) -> doc_op (string "&&") (newnc l40 nc1) (newnc l10 nc2) | _ -> l10 nc_full + and l10 (NC_aux (nc, _) as nc_full) = match nc with - | NC_and (nc1, nc2) -> doc_op (string "&&") (newnc l40 nc1) (newnc l10 nc2) - | _ -> l10 nc_full - and l10 (NC_aux (nc,_) as nc_full) = - match nc with - | NC_not_equal (ne1, ne2) -> string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)) + | NC_not_equal (ne1, ne2) -> + string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)) | NC_set (kid, is) -> - separate space [string "member_Z_list"; doc_var ctx kid; - brackets (separate (string "; ") - (List.map (fun i -> string (Nat_big_num.to_string i)) is))] - | NC_app (f,args) -> separate space (doc_nc_fn ctx f::List.map doc_typ_arg_exp args) + separate space + [ + string "member_Z_list"; + doc_var ctx kid; + brackets (separate (string "; ") (List.map (fun i -> string (Nat_big_num.to_string i)) is)); + ] + | NC_app (f, args) -> separate space (doc_nc_fn ctx f :: List.map doc_typ_arg_exp args) | _ -> l0 nc_full - and l0 (NC_aux (nc,_) as nc_full) = + and l0 (NC_aux (nc, _) as nc_full) = match nc with | NC_true -> string "true" | NC_false -> string "false" | NC_var kid -> doc_nexp ctx (nvar kid) - | NC_not_equal _ - | NC_set _ - | NC_app _ - | NC_equal _ - | NC_bounded_ge _ - | NC_bounded_gt _ - | NC_bounded_le _ - | NC_bounded_lt _ - | NC_or _ - | NC_and _ -> parens (l70 nc_full) - and doc_typ_arg_exp (A_aux (arg,l)) = + | NC_not_equal _ | NC_set _ | NC_app _ | NC_equal _ | NC_bounded_ge _ | NC_bounded_gt _ | NC_bounded_le _ + | NC_bounded_lt _ | NC_or _ | NC_and _ -> + parens (l70 nc_full) + and doc_typ_arg_exp (A_aux (arg, l)) = match arg with | A_nexp nexp -> doc_nexp ctx nexp | A_bool nc -> newnc l0 nc | A_order _ | A_typ _ -> - raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") - in newnc l70 nc + raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") + in + newnc l70 nc (* Check for variables in types that would be pretty-printed and are not bound in the val spec of the function. *) -let contains_t_pp_var ctxt (Typ_aux (t,a) as typ) = - KidSet.subset (coq_nvars_of_typ typ) ctxt.bound_nvars +let contains_t_pp_var ctxt (Typ_aux (t, a) as typ) = KidSet.subset (coq_nvars_of_typ typ) ctxt.bound_nvars (* TODO: should we resurrect this? -let replace_typ_size ctxt env (Typ_aux (t,a)) = - match t with - | Typ_app (Id_aux (Id "vector",_) as id, [A_aux (A_nexp size,_);ord;typ']) -> - begin - let mk_typ nexp = - Some (Typ_aux (Typ_app (id, [A_aux (A_nexp nexp,Parse_ast.Unknown);ord;typ']),a)) - in - match Type_check.solve env size with - | Some n -> mk_typ (nconstant n) - | None -> - let is_equal nexp = - prove __POS__ env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) - in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with - | nexp -> mk_typ nexp - | exception Not_found -> None - end - | _ -> None*) + let replace_typ_size ctxt env (Typ_aux (t,a)) = + match t with + | Typ_app (Id_aux (Id "vector",_) as id, [A_aux (A_nexp size,_);ord;typ']) -> + begin + let mk_typ nexp = + Some (Typ_aux (Typ_app (id, [A_aux (A_nexp nexp,Parse_ast.Unknown);ord;typ']),a)) + in + match Type_check.solve env size with + | Some n -> mk_typ (nconstant n) + | None -> + let is_equal nexp = + prove __POS__ env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) + in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with + | nexp -> mk_typ nexp + | exception Not_found -> None + end + | _ -> None*) let doc_tannot_core ctxt env eff typ = let of_typ typ = let ta = doc_typ ctxt env typ in - if eff then + if eff then ( match ctxt.early_ret with | Some ret_typ -> - if ctxt.is_monadic - then string "MR " ^^ parens ta ^^ string " " ^^ parens (doc_typ ctxt env ret_typ) - else string "sum " ^^ parens (doc_typ ctxt env ret_typ) ^^ string " " ^^ parens ta + if ctxt.is_monadic then string "MR " ^^ parens ta ^^ string " " ^^ parens (doc_typ ctxt env ret_typ) + else string "sum " ^^ parens (doc_typ ctxt env ret_typ) ^^ string " " ^^ parens ta | None -> string "M " ^^ parens ta + ) else ta - in of_typ typ + in + of_typ typ -let doc_tannot ctxt env eff typ = - string " : " ^^ doc_tannot_core ctxt env eff typ +let doc_tannot ctxt env eff typ = string " : " ^^ doc_tannot_core ctxt env eff typ (* Only double-quotes need escaped - by doubling them. *) -let coq_escape_string s = - Str.global_replace (Str.regexp "\"") "\"\"" s - -let doc_lit (L_aux(lit,l)) = +let coq_escape_string s = Str.global_replace (Str.regexp "\"") "\"\"" s + +let doc_lit (L_aux (lit, l)) = match lit with - | L_unit -> utf8string "tt" - | L_zero -> utf8string "B0" - | L_one -> utf8string "B1" + | L_unit -> utf8string "tt" + | L_zero -> utf8string "B0" + | L_one -> utf8string "B1" | L_false -> utf8string "false" - | L_true -> utf8string "true" + | L_true -> utf8string "true" | L_num i -> - let s = Big_int.to_string i in - let ipp = utf8string s in - if Big_int.less i Big_int.zero then parens ipp else ipp + let s = Big_int.to_string i in + let ipp = utf8string s in + if Big_int.less i Big_int.zero then parens ipp else ipp (* Not a typo, the bbv hex notation uses the letter O *) (* These need parens because of the 'sz 'b "..."' variants :( *) | L_hex n -> utf8string ("(Ox\"" ^ n ^ "\")") | L_bin n -> utf8string ("('b\"" ^ n ^ "\")") - | L_undef -> - utf8string "(Fail \"undefined value of unsupported type\")" - | L_string s -> utf8string ("\"" ^ (coq_escape_string s) ^ "\"") + | L_undef -> utf8string "(Fail \"undefined value of unsupported type\")" + | L_string s -> utf8string ("\"" ^ coq_escape_string s ^ "\"") | L_real s -> - (* Lem does not support decimal syntax, so we translate a string - of the form "x.y" into the ratio (x * 10^len(y) + y) / 10^len(y). - The OCaml library has a conversion function from strings to floats, but - not from floats to ratios. ZArith's Q library does have the latter, but - using this would require adding a dependency on ZArith to Sail. *) - let parts = Util.split_on_char '.' s in - let (num, denom) = match parts with - | [i] -> (Big_int.of_string i, Big_int.of_int 1) - | [i;f] -> - let denom = Big_int.pow_int_positive 10 (String.length f) in - (Big_int.add (Big_int.mul (Big_int.of_string i) denom) (Big_int.of_string f), denom) - | _ -> - raise (Reporting.err_syntax_loc l "could not parse real literal") in - parens (separate space (List.map string [ - "realFromFrac"; Big_int.to_string num; Big_int.to_string denom])) + (* Lem does not support decimal syntax, so we translate a string + of the form "x.y" into the ratio (x * 10^len(y) + y) / 10^len(y). + The OCaml library has a conversion function from strings to floats, but + not from floats to ratios. ZArith's Q library does have the latter, but + using this would require adding a dependency on ZArith to Sail. *) + let parts = Util.split_on_char '.' s in + let num, denom = + match parts with + | [i] -> (Big_int.of_string i, Big_int.of_int 1) + | [i; f] -> + let denom = Big_int.pow_int_positive 10 (String.length f) in + (Big_int.add (Big_int.mul (Big_int.of_string i) denom) (Big_int.of_string f), denom) + | _ -> raise (Reporting.err_syntax_loc l "could not parse real literal") + in + parens (separate space (List.map string ["realFromFrac"; Big_int.to_string num; Big_int.to_string denom])) -let doc_quant_item_id ?(prop_vars=false) ctx delimit (QI_aux (qi,_)) = +let doc_quant_item_id ?(prop_vars = false) ctx delimit (QI_aux (qi, _)) = match qi with - | QI_id (KOpt_aux (KOpt_kind (K_aux (kind,_),kid),_)) -> begin - if KBindings.mem kid ctx.kid_id_renames then None else - match kind with - | K_type -> Some (delimit (separate space [doc_var ctx kid; colon; string "Type"])) - | K_int -> begin - match KBindings.find_opt kid ctx.constant_kids with - | Some value -> Some (parens (separate space [doc_var ctx kid; colon; string "Z :="; string (Big_int.to_string value)])) - | None -> Some (delimit (separate space [doc_var ctx kid; colon; string "Z"])) - end - | K_order -> None - | K_bool -> Some (delimit (separate space [doc_var ctx kid; colon; - string (if prop_vars then "Prop" else "bool")])) - end + | QI_id (KOpt_aux (KOpt_kind (K_aux (kind, _), kid), _)) -> begin + if KBindings.mem kid ctx.kid_id_renames then None + else ( + match kind with + | K_type -> Some (delimit (separate space [doc_var ctx kid; colon; string "Type"])) + | K_int -> begin + match KBindings.find_opt kid ctx.constant_kids with + | Some value -> + Some (parens (separate space [doc_var ctx kid; colon; string "Z :="; string (Big_int.to_string value)])) + | None -> Some (delimit (separate space [doc_var ctx kid; colon; string "Z"])) + end + | K_order -> None + | K_bool -> + Some (delimit (separate space [doc_var ctx kid; colon; string (if prop_vars then "Prop" else "bool")])) + ) + end | QI_constraint _nc -> None -let quant_item_id_name ctx (QI_aux (qi,_)) = +let quant_item_id_name ctx (QI_aux (qi, _)) = match qi with - | QI_id (KOpt_aux (KOpt_kind (K_aux (kind,_),kid),_)) -> begin - if KBindings.mem kid ctx.kid_id_renames then None else - match kind with - | K_type -> Some (doc_var ctx kid) - | K_int -> Some (doc_var ctx kid) - | K_order -> None - | K_bool -> Some (doc_var ctx kid) - end + | QI_id (KOpt_aux (KOpt_kind (K_aux (kind, _), kid), _)) -> begin + if KBindings.mem kid ctx.kid_id_renames then None + else ( + match kind with + | K_type -> Some (doc_var ctx kid) + | K_int -> Some (doc_var ctx kid) + | K_order -> None + | K_bool -> Some (doc_var ctx kid) + ) + end | QI_constraint _nc -> None -let doc_quant_item_constr ?(prop_vars=false) ctx env delimit (QI_aux (qi,_)) = - match qi with - | QI_id _ -> None - | QI_constraint nc -> Some (comment (doc_nc_exp ctx env nc)) +let doc_quant_item_constr ?(prop_vars = false) ctx env delimit (QI_aux (qi, _)) = + match qi with QI_id _ -> None | QI_constraint nc -> Some (comment (doc_nc_exp ctx env nc)) (* At the moment these are all anonymous - when used we rely on Coq to fill them in. *) -let quant_item_constr_name ctx (QI_aux (qi,_)) = - match qi with - | QI_id _ -> None - | QI_constraint _nc -> None (*Some underscore*) +let quant_item_constr_name ctx (QI_aux (qi, _)) = + match qi with QI_id _ -> None | QI_constraint _nc -> None (*Some underscore*) -let doc_typquant_items ?(prop_vars=false) ctx env delimit (TypQ_aux (tq,_)) = +let doc_typquant_items ?(prop_vars = false) ctx env delimit (TypQ_aux (tq, _)) = match tq with | TypQ_tq qis -> - separate_opt space (doc_quant_item_id ~prop_vars ctx delimit) qis ^^ - separate_opt space (doc_quant_item_constr ~prop_vars ctx env delimit) qis + separate_opt space (doc_quant_item_id ~prop_vars ctx delimit) qis + ^^ separate_opt space (doc_quant_item_constr ~prop_vars ctx env delimit) qis | TypQ_no_forall -> empty -let doc_typquant_items_separate ctx env delimit (TypQ_aux (tq,_)) = +let doc_typquant_items_separate ctx env delimit (TypQ_aux (tq, _)) = match tq with | TypQ_tq qis -> - List.filter_map (doc_quant_item_id ctx delimit) qis, - List.filter_map (doc_quant_item_constr ctx env delimit) qis - | TypQ_no_forall -> [], [] + (List.filter_map (doc_quant_item_id ctx delimit) qis, List.filter_map (doc_quant_item_constr ctx env delimit) qis) + | TypQ_no_forall -> ([], []) -let typquant_names_separate ctx (TypQ_aux (tq,_)) = +let typquant_names_separate ctx (TypQ_aux (tq, _)) = match tq with - | TypQ_tq qis -> - List.filter_map (quant_item_id_name ctx) qis, - List.filter_map (quant_item_constr_name ctx) qis - | TypQ_no_forall -> [], [] - + | TypQ_tq qis -> (List.filter_map (quant_item_id_name ctx) qis, List.filter_map (quant_item_constr_name ctx) qis) + | TypQ_no_forall -> ([], []) -let doc_typquant ctx env (TypQ_aux(tq,_)) typ = match tq with -| TypQ_tq ((_ :: _) as qs) -> - string "forall " ^^ separate_opt space (doc_quant_item_id ctx braces) qs ^/^ - separate_opt space (doc_quant_item_constr ctx env parens) qs ^^ string ", " ^^ typ -| _ -> typ +let doc_typquant ctx env (TypQ_aux (tq, _)) typ = + match tq with + | TypQ_tq (_ :: _ as qs) -> + string "forall " + ^^ separate_opt space (doc_quant_item_id ctx braces) qs + ^/^ separate_opt space (doc_quant_item_constr ctx env parens) qs + ^^ string ", " ^^ typ + | _ -> typ (* Produce Size type constraints for bitvector sizes when using machine words. Often these will be unnecessary, but this simple approach will do for now. *) -let rec typeclass_nexps (Typ_aux(t,l)) = +let rec typeclass_nexps (Typ_aux (t, l)) = match t with - | Typ_id _ - | Typ_var _ - -> NexpSet.empty - | Typ_fn (t1,t2) -> List.fold_left NexpSet.union (typeclass_nexps t2) (List.map typeclass_nexps t1) + | Typ_id _ | Typ_var _ -> NexpSet.empty + | Typ_fn (t1, t2) -> List.fold_left NexpSet.union (typeclass_nexps t2) (List.map typeclass_nexps t1) | Typ_tuple ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) - | Typ_app (Id_aux (Id "bitvector",_), - [A_aux (A_nexp size_nexp,_); _]) - | Typ_app (Id_aux (Id "itself",_), - [A_aux (A_nexp size_nexp,_)]) -> - let size_nexp = nexp_simp size_nexp in - if is_nexp_constant size_nexp then NexpSet.empty else - NexpSet.singleton (orig_nexp size_nexp) + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp size_nexp, _); _]) + | Typ_app (Id_aux (Id "itself", _), [A_aux (A_nexp size_nexp, _)]) -> + let size_nexp = nexp_simp size_nexp in + if is_nexp_constant size_nexp then NexpSet.empty else NexpSet.singleton (orig_nexp size_nexp) | Typ_app _ -> NexpSet.empty - | Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) + | Typ_exist (kids, _, t) -> NexpSet.empty (* todo *) | Typ_bidir _ -> unreachable l __POS__ "Coq doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" -let doc_typschm ctx env quants (TypSchm_aux(TypSchm_ts(tq,t),_)) = +let doc_typschm ctx env quants (TypSchm_aux (TypSchm_ts (tq, t), _)) = let pt = doc_typ ctx env t in if quants then doc_typquant ctx env tq pt else pt -let is_ctor env id = match Env.lookup_id id env with -| Enum _ -> true -| _ -> false +let is_ctor env id = match Env.lookup_id id env with Enum _ -> true | _ -> false -let is_auto_decomposed_exist ctxt env ?(rawbools=false) typ = +let is_auto_decomposed_exist ctxt env ?(rawbools = false) typ = let typ = expand_range_type typ in match classify_ex_type ctxt env ~rawbools (Env.expand_synonyms env typ) with | ExGeneral, kopts, typ' -> Some (kopts, typ') @@ -1080,327 +1008,323 @@ let is_auto_decomposed_exist ctxt env ?(rawbools=false) typ = remaining 'a-type pairs. *) let filter_dep_tuple kopts vals_typs = let kid_set = KidSet.of_list (List.map kopt_kid kopts) in - let should_keep (_,Typ_aux (ty,_)) = + let should_keep (_, Typ_aux (ty, _)) = match ty with - | Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var var,_)),_)]) -> - not (KidSet.mem var kid_set) + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var var, _)), _)]) -> not (KidSet.mem var kid_set) | _ -> true in let tup_val_typs, ex_val_typs = List.partition should_keep vals_typs in - let is_kid kid (Typ_aux (t,_)) = + let is_kid kid (Typ_aux (t, _)) = match t with - | Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var var,_)),_)]) -> Kid.compare kid var == 0 + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var var, _)), _)]) -> Kid.compare kid var == 0 | _ -> false in - let find_val kopt = List.find_opt (fun (_,ty) -> is_kid (kopt_kid kopt) ty) ex_val_typs in - List.map find_val kopts, tup_val_typs - -let filter_dep_pattern_tuple ctxt kopts (P_aux (p,ann) as pat) typ = - match p, typ with - | P_tuple ps, Typ_aux (Typ_tuple ts,l) -> - let ex_pat_typs, tup_pat_typs = filter_dep_tuple kopts (List.combine ps ts) in - let map_ex_pat x = - match x with - | Some (P_aux (P_wild,_),_) -> string "_" - | Some (P_aux (P_id id,_),_) -> doc_id ctxt id - | Some (p,t) -> raise (Reporting.err_unreachable l __POS__ ("inconsistent type " ^ string_of_typ t ^ " and pattern " ^ string_of_pat p)) - | None -> string "_" - in - let coq_typats = List.map map_ex_pat ex_pat_typs in - let coq_typat = - match coq_typats with - | [p] -> p - | _ -> parens (separate (string ", ") coq_typats) - in - let coq_pat = P_tuple (List.map fst tup_pat_typs) in - let coq_typ = Typ_aux (Typ_tuple (List.map snd tup_pat_typs), l) in - Some coq_typat, P_aux (coq_pat,ann), coq_typ - | _ -> None, pat, typ + let find_val kopt = List.find_opt (fun (_, ty) -> is_kid (kopt_kid kopt) ty) ex_val_typs in + (List.map find_val kopts, tup_val_typs) + +let filter_dep_pattern_tuple ctxt kopts (P_aux (p, ann) as pat) typ = + match (p, typ) with + | P_tuple ps, Typ_aux (Typ_tuple ts, l) -> + let ex_pat_typs, tup_pat_typs = filter_dep_tuple kopts (List.combine ps ts) in + let map_ex_pat x = + match x with + | Some (P_aux (P_wild, _), _) -> string "_" + | Some (P_aux (P_id id, _), _) -> doc_id ctxt id + | Some (p, t) -> + raise + (Reporting.err_unreachable l __POS__ + ("inconsistent type " ^ string_of_typ t ^ " and pattern " ^ string_of_pat p) + ) + | None -> string "_" + in + let coq_typats = List.map map_ex_pat ex_pat_typs in + let coq_typat = match coq_typats with [p] -> p | _ -> parens (separate (string ", ") coq_typats) in + let coq_pat = P_tuple (List.map fst tup_pat_typs) in + let coq_typ = Typ_aux (Typ_tuple (List.map snd tup_pat_typs), l) in + (Some coq_typat, P_aux (coq_pat, ann), coq_typ) + | _ -> (None, pat, typ) (*Note: vector concatenation, literal vectors, indexed vectors, and record should be removed prior to pp. The latter two have never yet been seen *) -let rec doc_pat ctxt apat_needed exists_as_pairs (P_aux (p,(l,annot)) as pat, typ) = - let env = env_of_annot (l,annot) in +let rec doc_pat ctxt apat_needed exists_as_pairs ((P_aux (p, (l, annot)) as pat), typ) = + let env = env_of_annot (l, annot) in let typ = Env.expand_synonyms env typ in - match p with - (* Special case translation of the None constructor to remove the unit arg *) - | P_app(id, _) when string_of_id id = "None" -> string "None" - | P_app(id, ((_ :: _) as pats)) -> begin - (* Following the type checker to get the subpattern types, TODO perhaps ought - to persuade the type checker to output these somehow. *) - let (typq, ctor_typ) = Env.get_union_id id env in - let arg_typs = - match Env.expand_synonyms env ctor_typ with - | Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> + match p with + (* Special case translation of the None constructor to remove the unit arg *) + | P_app (id, _) when string_of_id id = "None" -> string "None" + | P_app (id, (_ :: _ as pats)) -> begin + (* Following the type checker to get the subpattern types, TODO perhaps ought + to persuade the type checker to output these somehow. *) + let typq, ctor_typ = Env.get_union_id id env in + let arg_typs = + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> let unifiers = unify l env (tyvars_of_typ ret_typ) ret_typ typ in List.map (subst_unifiers unifiers) arg_typs - | _ -> assert false - in - debug ctxt (lazy ("constructor " ^ string_of_id id ^ " with type " ^ - string_of_typ ctor_typ ^ - " gives types for subpatterns of " ^ - String.concat ", " (List.map string_of_typ arg_typs))); - (* Constructors that were specified without a return type might get - an extra tuple in their type; expand that here if necessary. - TODO: this should go away if we enforce proper arities. *) - let arg_typs = match pats, arg_typs with - | _::_::_, [Typ_aux (Typ_tuple typs,_)] -> typs - | _,_ -> arg_typs - in - let pats_pp = separate_map comma (doc_pat ctxt true true) (List.combine pats arg_typs) in - let pats_pp = match pats with [_] -> pats_pp | _ -> parens pats_pp in - let ppp = doc_unop (doc_id_ctor ctxt id) pats_pp in - if apat_needed then parens ppp else ppp - end - | P_app(id, []) -> doc_id_ctor ctxt id - | P_lit lit -> doc_lit lit - | P_wild -> underscore - | P_id id -> doc_id ctxt id - | P_var(p,_) -> doc_pat ctxt true exists_as_pairs (p, typ) - | P_as(p,id) -> parens (separate space [doc_pat ctxt true exists_as_pairs (p, typ); string "as"; doc_id ctxt id]) - | P_typ(ptyp,p) -> - let doc_p = doc_pat ctxt true exists_as_pairs (p, typ) in - doc_p - (* Type annotations aren't allowed everywhere in patterns in Coq *) - (*parens (doc_op colon doc_p (doc_typ typ))*) - | P_vector pats -> - let el_typ = - match destruct_vector env typ with - | Some (_,_,t) -> t - | None -> raise (Reporting.err_unreachable l __POS__ "vector pattern doesn't have vector type") - in - let ppp = brackets (separate_map semi (fun p -> doc_pat ctxt true exists_as_pairs (p,el_typ)) pats) in - if apat_needed then parens ppp else ppp - | P_vector_concat pats -> - raise (Reporting.err_unreachable l __POS__ - "vector concatenation patterns should have been removed before pretty-printing") - | P_vector_subrange _ -> unreachable l __POS__ "Must have been rewritten before Coq backend" - | P_tuple pats -> - let typs = match typ with - | Typ_aux (Typ_tuple typs, _) -> typs - | Typ_aux (Typ_exist _,_) -> - raise (Reporting.err_todo l "existential types not yet supported here") - | _ -> raise (Reporting.err_unreachable l __POS__ "tuple pattern doesn't have tuple type") - in - (match pats, typs with - | [p], [typ'] -> doc_pat ctxt apat_needed true (p, typ') - | [_], _ -> raise (Reporting.err_unreachable l __POS__ "tuple pattern length does not match tuple type length") - | _ -> parens (separate_map comma_sp (doc_pat ctxt false true) (List.combine pats typs))) - | P_list pats -> - let el_typ = match typ with - | Typ_aux (Typ_app (f, [A_aux (A_typ el_typ,_)]),_) - when Id.compare f (mk_id "list") = 0 -> el_typ - | _ -> raise (Reporting.err_unreachable l __POS__ "list pattern not a list") - in - brackets (separate_map semi (fun p -> doc_pat ctxt false true (p, el_typ)) pats) - | P_cons (p,p') -> - let el_typ = match typ with - | Typ_aux (Typ_app (f, [A_aux (A_typ el_typ,_)]),_) - when Id.compare f (mk_id "list") = 0 -> el_typ - | _ -> raise (Reporting.err_unreachable l __POS__ "list pattern not a list") - in - doc_op (string "::") (doc_pat ctxt true true (p, el_typ)) (doc_pat ctxt true true (p', typ)) - | P_string_append _ -> unreachable l __POS__ - "string append pattern found in Coq backend, should have been rewritten" - | P_not _ -> unreachable l __POS__ "Coq backend doesn't support not patterns" - | P_or _ -> unreachable l __POS__ "Coq backend doesn't support or patterns yet" + | _ -> assert false + in + debug ctxt + ( lazy + ("constructor " ^ string_of_id id ^ " with type " ^ string_of_typ ctor_typ + ^ " gives types for subpatterns of " + ^ String.concat ", " (List.map string_of_typ arg_typs) + ) + ); + (* Constructors that were specified without a return type might get + an extra tuple in their type; expand that here if necessary. + TODO: this should go away if we enforce proper arities. *) + let arg_typs = + match (pats, arg_typs) with _ :: _ :: _, [Typ_aux (Typ_tuple typs, _)] -> typs | _, _ -> arg_typs + in + let pats_pp = separate_map comma (doc_pat ctxt true true) (List.combine pats arg_typs) in + let pats_pp = match pats with [_] -> pats_pp | _ -> parens pats_pp in + let ppp = doc_unop (doc_id_ctor ctxt id) pats_pp in + if apat_needed then parens ppp else ppp + end + | P_app (id, []) -> doc_id_ctor ctxt id + | P_lit lit -> doc_lit lit + | P_wild -> underscore + | P_id id -> doc_id ctxt id + | P_var (p, _) -> doc_pat ctxt true exists_as_pairs (p, typ) + | P_as (p, id) -> parens (separate space [doc_pat ctxt true exists_as_pairs (p, typ); string "as"; doc_id ctxt id]) + | P_typ (ptyp, p) -> + let doc_p = doc_pat ctxt true exists_as_pairs (p, typ) in + doc_p + (* Type annotations aren't allowed everywhere in patterns in Coq *) + (*parens (doc_op colon doc_p (doc_typ typ))*) + | P_vector pats -> + let el_typ = + match destruct_vector env typ with + | Some (_, _, t) -> t + | None -> raise (Reporting.err_unreachable l __POS__ "vector pattern doesn't have vector type") + in + let ppp = brackets (separate_map semi (fun p -> doc_pat ctxt true exists_as_pairs (p, el_typ)) pats) in + if apat_needed then parens ppp else ppp + | P_vector_concat pats -> + raise + (Reporting.err_unreachable l __POS__ + "vector concatenation patterns should have been removed before pretty-printing" + ) + | P_vector_subrange _ -> unreachable l __POS__ "Must have been rewritten before Coq backend" + | P_tuple pats -> ( + let typs = + match typ with + | Typ_aux (Typ_tuple typs, _) -> typs + | Typ_aux (Typ_exist _, _) -> raise (Reporting.err_todo l "existential types not yet supported here") + | _ -> raise (Reporting.err_unreachable l __POS__ "tuple pattern doesn't have tuple type") + in + match (pats, typs) with + | [p], [typ'] -> doc_pat ctxt apat_needed true (p, typ') + | [_], _ -> raise (Reporting.err_unreachable l __POS__ "tuple pattern length does not match tuple type length") + | _ -> parens (separate_map comma_sp (doc_pat ctxt false true) (List.combine pats typs)) + ) + | P_list pats -> + let el_typ = + match typ with + | Typ_aux (Typ_app (f, [A_aux (A_typ el_typ, _)]), _) when Id.compare f (mk_id "list") = 0 -> el_typ + | _ -> raise (Reporting.err_unreachable l __POS__ "list pattern not a list") + in + brackets (separate_map semi (fun p -> doc_pat ctxt false true (p, el_typ)) pats) + | P_cons (p, p') -> + let el_typ = + match typ with + | Typ_aux (Typ_app (f, [A_aux (A_typ el_typ, _)]), _) when Id.compare f (mk_id "list") = 0 -> el_typ + | _ -> raise (Reporting.err_unreachable l __POS__ "list pattern not a list") + in + doc_op (string "::") (doc_pat ctxt true true (p, el_typ)) (doc_pat ctxt true true (p', typ)) + | P_string_append _ -> unreachable l __POS__ "string append pattern found in Coq backend, should have been rewritten" + | P_not _ -> unreachable l __POS__ "Coq backend doesn't support not patterns" + | P_or _ -> unreachable l __POS__ "Coq backend doesn't support or patterns yet" let contains_early_return exp = let e_app (f, args) = let rets, args = List.split args in - (List.fold_left (||) (string_of_id f = "early_return") rets, - E_app (f, args)) in - fst (fold_exp - { (Rewriter.compute_exp_alg false (||)) - with e_return = (fun (_, r) -> (true, E_return r)); e_app = e_app } exp) + (List.fold_left ( || ) (string_of_id f = "early_return") rets, E_app (f, args)) + in + fst + (fold_exp { (Rewriter.compute_exp_alg false ( || )) with e_return = (fun (_, r) -> (true, E_return r)); e_app } exp) let find_e_ids exp = - let e_id id = IdSet.singleton id, E_id id in - fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with e_id = e_id } exp) + let e_id id = (IdSet.singleton id, E_id id) in + fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with e_id } exp) -let typ_id_of (Typ_aux (typ, l)) = match typ with +let typ_id_of (Typ_aux (typ, l)) = + match typ with | Typ_id id -> id - | Typ_app (register, [A_aux (A_typ (Typ_aux (Typ_id id, _)), _)]) - when string_of_id register = "register" -> id + | Typ_app (register, [A_aux (A_typ (Typ_aux (Typ_id id, _)), _)]) when string_of_id register = "register" -> id | Typ_app (id, _) -> id | _ -> raise (Reporting.err_unreachable l __POS__ "failed to get type id") (* TODO: maybe Nexp_exp, division? *) (* Evaluation of constant nexp subexpressions, because Coq will be able to do those itself *) -let rec nexp_const_eval (Nexp_aux (n,l) as nexp) = +let rec nexp_const_eval (Nexp_aux (n, l) as nexp) = let binop f re l n1 n2 = - match nexp_const_eval n1, nexp_const_eval n2 with - | Nexp_aux (Nexp_constant c1,_), Nexp_aux (Nexp_constant c2,_) -> - Nexp_aux (Nexp_constant (f c1 c2),l) - | n1', n2' -> Nexp_aux (re n1' n2',l) + match (nexp_const_eval n1, nexp_const_eval n2) with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> Nexp_aux (Nexp_constant (f c1 c2), l) + | n1', n2' -> Nexp_aux (re n1' n2', l) in let unop f re l n1 = - match nexp_const_eval n1 with - | Nexp_aux (Nexp_constant c1,_) -> Nexp_aux (Nexp_constant (f c1),l) - | n1' -> Nexp_aux (re n1',l) + match nexp_const_eval n1 with + | Nexp_aux (Nexp_constant c1, _) -> Nexp_aux (Nexp_constant (f c1), l) + | n1' -> Nexp_aux (re n1', l) in match n with - | Nexp_times (n1,n2) -> binop Big_int.mul (fun n1 n2 -> Nexp_times (n1,n2)) l n1 n2 - | Nexp_sum (n1,n2) -> binop Big_int.add (fun n1 n2 -> Nexp_sum (n1,n2)) l n1 n2 - | Nexp_minus (n1,n2) -> binop Big_int.sub (fun n1 n2 -> Nexp_minus (n1,n2)) l n1 n2 + | Nexp_times (n1, n2) -> binop Big_int.mul (fun n1 n2 -> Nexp_times (n1, n2)) l n1 n2 + | Nexp_sum (n1, n2) -> binop Big_int.add (fun n1 n2 -> Nexp_sum (n1, n2)) l n1 n2 + | Nexp_minus (n1, n2) -> binop Big_int.sub (fun n1 n2 -> Nexp_minus (n1, n2)) l n1 n2 | Nexp_neg n1 -> unop Big_int.negate (fun n -> Nexp_neg n) l n1 | _ -> nexp (* Decide whether two nexps used in a vector size are similar; if not a cast will be inserted *) let similar_nexps ctxt env n1 n2 = - let rec same_nexp_shape (Nexp_aux (n1,_)) (Nexp_aux (n2,_)) = - match n1, n2 with + let rec same_nexp_shape (Nexp_aux (n1, _)) (Nexp_aux (n2, _)) = + match (n1, n2) with | Nexp_id _, Nexp_id _ -> true (* TODO: this is really just an approximation to what we really want: will the Coq types have the same names? We could probably do better by tracking which existential kids are equal to bound kids. *) | Nexp_var k1, Nexp_var k2 -> - Kid.compare k1 k2 == 0 || - (prove __POS__ env (nc_eq (nvar k1) (nvar k2)) && ( - not (KidSet.mem k1 ctxt.bound_nvars) || - not (KidSet.mem k2 ctxt.bound_nvars))) + Kid.compare k1 k2 == 0 + || prove __POS__ env (nc_eq (nvar k1) (nvar k2)) + && ((not (KidSet.mem k1 ctxt.bound_nvars)) || not (KidSet.mem k2 ctxt.bound_nvars)) | Nexp_constant c1, Nexp_constant c2 -> Nat_big_num.equal c1 c2 - | Nexp_app (f1,args1), Nexp_app (f2,args2) -> - Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2 - | Nexp_times (n1,n2), Nexp_times (n3,n4) - | Nexp_sum (n1,n2), Nexp_sum (n3,n4) - | Nexp_minus (n1,n2), Nexp_minus (n3,n4) - -> same_nexp_shape n1 n3 && same_nexp_shape n2 n4 - | Nexp_exp n1, Nexp_exp n2 - | Nexp_neg n1, Nexp_neg n2 - -> same_nexp_shape n1 n2 + | Nexp_app (f1, args1), Nexp_app (f2, args2) -> Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2 + | Nexp_times (n1, n2), Nexp_times (n3, n4) + | Nexp_sum (n1, n2), Nexp_sum (n3, n4) + | Nexp_minus (n1, n2), Nexp_minus (n3, n4) -> + same_nexp_shape n1 n3 && same_nexp_shape n2 n4 + | Nexp_exp n1, Nexp_exp n2 | Nexp_neg n1, Nexp_neg n2 -> same_nexp_shape n1 n2 | _ -> false - in if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false + in + if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_int"] let condition_produces_constraint ctxt exp = let env = env_of exp in - match classify_ex_type ctxt env ~rawbools:true (typ_of exp) with - | ExNone, _, _ -> false - | ExGeneral, _, _ -> true + match classify_ex_type ctxt env ~rawbools:true (typ_of exp) with ExNone, _, _ -> false | ExGeneral, _, _ -> true (* For most functions whose return types are non-trivial atoms we return a dependent pair with a proof that the result is the expected integer. This is redundant for basic arithmetic functions and functions which we unfold in the constraint solver. *) -let no_proof_fns = ["Z.add"; "Z.sub"; "Z.opp"; "Z.mul"; "Z.rem"; - "length_mword"; "length"; "vec_length"; - "negb"; "andb"; "orb"; - "Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"] +let no_proof_fns = + [ + "Z.add"; + "Z.sub"; + "Z.opp"; + "Z.mul"; + "Z.rem"; + "length_mword"; + "length"; + "vec_length"; + "negb"; + "andb"; + "orb"; + "Z.leb"; + "Z.geb"; + "Z.ltb"; + "Z.gtb"; + "Z.eqb"; + ] let is_no_proof_fn env id = - if Env.is_extern id env "coq" - then + if Env.is_extern id env "coq" then ( let s = Env.get_extern id env "coq" in List.exists (fun x -> String.compare x s == 0) no_proof_fns + ) else false let replace_atom_return_type ret_typ = (* TODO: more complex uses of atom *) match ret_typ with - | Typ_aux (Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp nexp,_)]),l) -> - let kid = mk_kid "_retval" in (* TODO: collision avoidance *) - Some "build_ex", Typ_aux (Typ_exist ([mk_kopt K_int kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Parse_ast.Generated l) - | Typ_aux (Typ_app (Id_aux (Id "atom_bool",il), ([A_aux (A_bool _,_)] as args)),l) -> - Some "build_ex", ret_typ - | _ -> None, ret_typ - -let is_range_from_atom env (Typ_aux (argty,_)) (Typ_aux (fnty,_)) = - match argty, fnty with - | Typ_app(Id_aux (Id "atom", _), [A_aux (A_nexp nexp,_)]), - Typ_app(Id_aux (Id "range", _), [A_aux(A_nexp low,_); - A_aux(A_nexp high,_)]) -> - Type_check.prove __POS__ env (nc_and (nc_eq nexp low) (nc_eq nexp high)) + | Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp nexp, _)]), l) -> + let kid = mk_kid "_retval" in + (* TODO: collision avoidance *) + ( Some "build_ex", + Typ_aux (Typ_exist ([mk_kopt K_int kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)), Parse_ast.Generated l) + ) + | Typ_aux (Typ_app (Id_aux (Id "atom_bool", il), ([A_aux (A_bool _, _)] as args)), l) -> (Some "build_ex", ret_typ) + | _ -> (None, ret_typ) + +let is_range_from_atom env (Typ_aux (argty, _)) (Typ_aux (fnty, _)) = + match (argty, fnty) with + | ( Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp nexp, _)]), + Typ_app (Id_aux (Id "range", _), [A_aux (A_nexp low, _); A_aux (A_nexp high, _)]) ) -> + Type_check.prove __POS__ env (nc_and (nc_eq nexp low) (nc_eq nexp high)) | _ -> false (* Get a more general type for an annotation/expression - i.e., like typ_of but using the expected type if there was one *) -let general_typ_of_annot annot = - match expected_typ_of annot with - | None -> typ_of_annot annot - | Some typ -> typ +let general_typ_of_annot annot = match expected_typ_of annot with None -> typ_of_annot annot | Some typ -> typ -let general_typ_of (E_aux (_,annot)) = general_typ_of_annot annot +let general_typ_of (E_aux (_, annot)) = general_typ_of_annot annot let is_prefix s s' = let l = String.length s in - String.length s' >= l && - String.sub s' 0 l = s + String.length s' >= l && String.sub s' 0 l = s let merge_new_tyvars ctxt old_env pat new_env = - let remove_binding id (m,r) = + let remove_binding id (m, r) = match Bindings.find_opt id r with | Some kid -> - debug ctxt (lazy ("Removing " ^ string_of_kid kid ^ " to " ^ string_of_id id)); - KBindings.add kid None m, Bindings.remove id r - | None -> m,r + debug ctxt (lazy ("Removing " ^ string_of_kid kid ^ " to " ^ string_of_id id)); + (KBindings.add kid None m, Bindings.remove id r) + | None -> (m, r) in - let check_kid id kid (m,r) = + let check_kid id kid (m, r) = try let _ = Env.get_typ_var kid old_env in debug ctxt (lazy (" tyvar " ^ string_of_kid kid ^ " already in env")); - m,r + (m, r) with _ -> debug ctxt (lazy (" adding tyvar mapping " ^ string_of_kid kid ^ " to " ^ string_of_id id)); - KBindings.add kid (Some id) m, Bindings.add id kid r + (KBindings.add kid (Some id) m, Bindings.add id kid r) in let merge_new_kids id m = let typ = lvar_typ (Env.lookup_id id new_env) in - debug ctxt (lazy (" considering tyvar mapping for " ^ string_of_id id ^ " at type " ^ string_of_typ typ )); - match destruct_numeric typ, destruct_atom_bool new_env typ with - | Some ([],_,Nexp_aux (Nexp_var kid,_)), _ - | _, Some (NC_aux (NC_var kid,_)) - -> check_kid id kid m + debug ctxt (lazy (" considering tyvar mapping for " ^ string_of_id id ^ " at type " ^ string_of_typ typ)); + match (destruct_numeric typ, destruct_atom_bool new_env typ) with + | Some ([], _, Nexp_aux (Nexp_var kid, _)), _ | _, Some (NC_aux (NC_var kid, _)) -> check_kid id kid m | _ -> - debug ctxt (lazy (" not suitable type")); - m + debug ctxt (lazy " not suitable type"); + m in - let rec merge_pat m (P_aux (p,(l,_))) = + let rec merge_pat m (P_aux (p, (l, _))) = match p with - | P_lit _ | P_wild - -> m + | P_lit _ | P_wild -> m | P_not _ -> unreachable l __POS__ "Coq backend doesn't support not patterns" | P_or _ -> unreachable l __POS__ "Coq backend doesn't support or patterns yet" | P_vector_subrange _ -> unreachable l __POS__ "Must have been rewritten before Coq backend" - | P_typ (_,p) -> merge_pat m p - | P_as (p,id) -> merge_new_kids id (merge_pat m p) + | P_typ (_, p) -> merge_pat m p + | P_as (p, id) -> merge_new_kids id (merge_pat m p) | P_id id -> merge_new_kids id m - | P_var (p,ty_p) -> - begin match p, ty_p with - | _, TP_aux (TP_wild,_) -> merge_pat m p - | P_aux (P_id id,_), TP_aux (TP_var kid,_) -> check_kid id kid (merge_pat m p) - | _ -> merge_pat m p - end + | P_var (p, ty_p) -> begin + match (p, ty_p) with + | _, TP_aux (TP_wild, _) -> merge_pat m p + | P_aux (P_id id, _), TP_aux (TP_var kid, _) -> check_kid id kid (merge_pat m p) + | _ -> merge_pat m p + end (* Some of these don't make it through to the backend, but it's obvious what they'd do *) - | P_app (_,ps) - | P_vector ps - | P_vector_concat ps - | P_tuple ps - | P_list ps - | P_string_append ps - -> List.fold_left merge_pat m ps - | P_cons (p1,p2) -> merge_pat (merge_pat m p1) p2 + | P_app (_, ps) | P_vector ps | P_vector_concat ps | P_tuple ps | P_list ps | P_string_append ps -> + List.fold_left merge_pat m ps + | P_cons (p1, p2) -> merge_pat (merge_pat m p1) p2 in - let m,r = IdSet.fold remove_binding (pat_ids pat) (ctxt.kid_id_renames, ctxt.kid_id_renames_rev) in - let m,r = merge_pat (m, r) pat in + let m, r = IdSet.fold remove_binding (pat_ids pat) (ctxt.kid_id_renames, ctxt.kid_id_renames_rev) in + let m, r = merge_pat (m, r) pat in { ctxt with kid_id_renames = m; kid_id_renames_rev = r } let maybe_parens_comma_list f ls = - match ls with - | [x] -> f true x - | xs -> parens (separate (string ", ") (List.map (f false) xs)) + match ls with [x] -> f true x | xs -> parens (separate (string ", ") (List.map (f false) xs)) let prefix_recordtype = true let report = Reporting.err_unreachable let doc_exp, doc_let = - let rec top_exp (ctxt : context) (aexp_needed : bool) - (E_aux (e, (l,annot)) as full_exp) = - let top_exp c a e = + let rec top_exp (ctxt : context) (aexp_needed : bool) (E_aux (e, (l, annot)) as full_exp) = + let top_exp c a e = let () = debug_depth := !debug_depth + 1 in let r = top_exp c a e in let () = debug_depth := !debug_depth - 1 in @@ -1409,7 +1333,7 @@ let doc_exp, doc_let = let expY = top_exp ctxt true in let expN = top_exp ctxt false in let expV = top_exp ctxt in - let wrap_parens doc = if aexp_needed then parens (doc) else doc in + let wrap_parens doc = if aexp_needed then parens doc else doc in let maybe_cast descr typ pp = let env = env_of full_exp in let exp_typ = expand_range_type (Env.expand_synonyms env typ) in @@ -1417,10 +1341,10 @@ let doc_exp, doc_let = let ann_typ = expand_range_type (Env.expand_synonyms env ann_typ) in let autocast = (* Avoid using helper functions which simplify the nexps *) - match exp_typ, ann_typ with - | Typ_aux (Typ_app (Id_aux (Id "bitvector",_),[A_aux (A_nexp n1,_);_]),_), - Typ_aux (Typ_app (Id_aux (Id "bitvector",_),[A_aux (A_nexp n2,_);_]),_) -> - not (similar_nexps ctxt env n1 n2) + match (exp_typ, ann_typ) with + | ( Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n1, _); _]), _), + Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n2, _); _]), _) ) -> + not (similar_nexps ctxt env n1 n2) | _ -> false in let () = @@ -1428,773 +1352,834 @@ let doc_exp, doc_let = debug ctxt (lazy (" expected type " ^ string_of_typ ann_typ)); debug ctxt (lazy (" autocast " ^ string_of_bool autocast)) in - if autocast then - wrap_parens (string "autocast" ^/^ pp) - else - pp + if autocast then wrap_parens (string "autocast" ^/^ pp) else pp in let liftR doc = - if Option.is_some ctxt.early_ret && effectful (effect_of full_exp) - then separate space [string "liftR"; parens (doc)] - else doc in + if Option.is_some ctxt.early_ret && effectful (effect_of full_exp) then + separate space [string "liftR"; parens doc] + else doc + in match e with - | E_assign((LE_aux(le_act,tannot) as le), e) -> - (* can only be register writes *) - (match le_act (*, t, tag*) with - | LE_vector_range (le,e2,e3) -> - (match le with - | LE_aux (LE_field ((LE_aux (_, lannot) as le),id), fannot) -> - if is_bit_typ (typ_of_annot fannot) then - raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") - else - let field_ref = - doc_id ctxt (typ_id_of (typ_of_annot lannot)) ^^ - underscore ^^ - doc_id ctxt id in - liftR ((prefix 2 1) - (string "write_reg_field_range") - (align (doc_lexp_deref ctxt le ^/^ - field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e))) + | E_assign ((LE_aux (le_act, tannot) as le), e) -> ( + (* can only be register writes *) + match le_act (*, t, tag*) with + | LE_vector_range (le, e2, e3) -> ( + match le with + | LE_aux (LE_field ((LE_aux (_, lannot) as le), id), fannot) -> + if is_bit_typ (typ_of_annot fannot) then + raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") + else ( + let field_ref = doc_id ctxt (typ_id_of (typ_of_annot lannot)) ^^ underscore ^^ doc_id ctxt id in + liftR + ((prefix 2 1) (string "write_reg_field_range") + (align (doc_lexp_deref ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e)) + ) + ) | _ -> - let deref = doc_lexp_deref ctxt le in - liftR ((prefix 2 1) - (string "write_reg_range") - (align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e))) - | LE_vector (le,e2) -> - (match le with - | LE_aux (LE_field ((LE_aux (_, lannot) as le),id), fannot) -> - if is_bit_typ (typ_of_annot fannot) then - raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") - else - let field_ref = - doc_id ctxt (typ_id_of (typ_of_annot lannot)) ^^ - underscore ^^ - doc_id ctxt id in - let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then "write_reg_field_bit" else "write_reg_field_pos" in - liftR ((prefix 2 1) - (string call) - (align (doc_lexp_deref ctxt le ^/^ - field_ref ^/^ expY e2 ^/^ expY e))) + let deref = doc_lexp_deref ctxt le in + liftR ((prefix 2 1) (string "write_reg_range") (align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e)) + ) + | LE_vector (le, e2) -> ( + match le with + | LE_aux (LE_field ((LE_aux (_, lannot) as le), id), fannot) -> + if is_bit_typ (typ_of_annot fannot) then + raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") + else ( + let field_ref = doc_id ctxt (typ_id_of (typ_of_annot lannot)) ^^ underscore ^^ doc_id ctxt id in + let call = + if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then + "write_reg_field_bit" + else "write_reg_field_pos" + in + liftR + ((prefix 2 1) (string call) (align (doc_lexp_deref ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e))) + ) | LE_aux (_, lannot) -> - let deref = doc_lexp_deref ctxt le in - let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" else "write_reg_pos" in - liftR ((prefix 2 1) (string call) - (deref ^/^ expY e2 ^/^ expY e)) - ) - | LE_field ((LE_aux (_, lannot) as le),id) -> - let field_ref = - doc_id ctxt (typ_id_of (typ_of_annot lannot)) ^^ - underscore ^^ - doc_id ctxt id (*^^ - dot ^^ - string "set_field"*) in - liftR ((prefix 2 1) - (string "write_reg_field") - (doc_lexp_deref ctxt le ^^ space ^^ - field_ref ^/^ expY e)) - | LE_deref re -> - liftR ((prefix 2 1) (string "write_reg") (expY re ^/^ expY e)) - | _ -> - liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref ctxt le ^/^ expY e))) - | E_vector_append(le,re) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_append should have been rewritten before pretty-printing") - | E_cons(le,re) -> doc_op (group (colon^^colon)) (expY le) (expY re) - | E_if(c,t,e) -> - let epp = if_exp ctxt (env_of full_exp) (typ_of full_exp) false c t e in - if aexp_needed then parens (align epp) else epp - | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> - raise (report l __POS__ "E_for should have been rewritten before pretty-printing") - | E_loop _ -> - raise (report l __POS__ "E_loop should have been rewritten before pretty-printing") - | E_let(leb,e) -> - let pat = match leb with LB_aux (LB_val (p,_),_) -> p in - let () = debug ctxt (lazy ("Let with pattern " ^ string_of_pat pat)) in - let new_ctxt = merge_new_tyvars ctxt (env_of_annot (l,annot)) pat (env_of e) in - let epp = let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ top_exp new_ctxt false e in - if aexp_needed then parens epp else epp - | E_app(f,args) -> - let env = env_of full_exp in - let doc_loop_var (E_aux (e,(l,_)) as exp) = - match e with - | E_id id -> - let id_pp = doc_id ctxt id in - let typ = general_typ_of exp in - id_pp, id_pp - | E_lit (L_aux (L_unit,_)) -> string "tt", underscore - | _ -> raise (Reporting.err_unreachable l __POS__ - ("Bad expression for variable in loop: " ^ string_of_exp exp)) - in - let make_loop_vars extra_binders varstuple = - match varstuple with - | E_aux (E_tuple vs, _) -> - let vs = List.map doc_loop_var vs in - let mkpp f vs = separate (string ", ") (List.map f vs) in - let tup_pp = mkpp (fun (pp,_) -> pp) vs in - let match_pp = mkpp (fun (_,pp) -> pp) vs in - parens tup_pp, - separate space (string "fun" :: extra_binders @ - [squote ^^ parens match_pp; bigarrow]) - | _ -> - let exp_pp,match_pp = doc_loop_var varstuple in - exp_pp, - separate space (string "fun" :: extra_binders @ [match_pp; bigarrow]) - in - begin match f with - | Id_aux (Id "and_bool", _) | Id_aux (Id "or_bool", _) - when effectful (effect_of full_exp) -> - let suffix = "M" in - let call = doc_id ctxt (append_id f suffix) in - debug ctxt (lazy ("Effectful boolean op: " ^ string_of_id f)); - let doc_arg exp = - expY exp - in - let epp = hang 2 (flow (break 1) (call :: List.map doc_arg args)) in - wrap_parens epp - (* temporary hack to make the loop body a function of the temporary variables *) - | Id_aux (Id "None", _) as none -> doc_id_ctor ctxt none - | Id_aux (Id "foreach#", _) -> - begin - match args with - | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] -> - let loopvar, body = match body with - | E_aux (E_if (_, - E_aux (E_let (LB_aux (LB_val ( - ((P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), _), _)), _)) - | (P_aux (P_var (P_aux (P_id id, _), _), _)) - | (P_aux (P_id id, _))), _), _), - body), _), _), _) -> id, body - | _ -> raise (Reporting.err_unreachable l __POS__ ("Unable to find loop variable in " ^ string_of_exp body)) in - let dir = match ord_exp with - | E_aux (E_lit (L_aux (L_false, _)), _) -> "_down" - | E_aux (E_lit (L_aux (L_true, _)), _) -> "_up" - | _ -> raise (Reporting.err_unreachable l __POS__ ("Unexpected loop direction " ^ string_of_exp ord_exp)) - in - let effects = effectful (effect_of body) in - let combinator = - if effects - then if ctxt.is_monadic - then "foreach_ZM" - else "foreach_ZE" - else "foreach_Z" in - let combinator = combinator ^ dir in - let body_ctxt = add_single_kid_id_rename ctxt loopvar (mk_kid ("loop_" ^ string_of_id loopvar)) in - let from_exp_pp, to_exp_pp, step_exp_pp = - expY from_exp, expY to_exp, expY step_exp - in - (* The body has the right type for deciding whether a proof is necessary *) - let vartuple_retyped = check_exp env (strip_exp vartuple) (general_typ_of body) in - let vartuple_pp, body_lambda = - make_loop_vars [doc_id ctxt loopvar] vartuple_retyped - in - (* TODO: this should probably be construct_dep_pairs, but we would need - to change it to use the updated context. *) - let body_pp = top_exp body_ctxt false body in - let loop_pp = - parens ( - (prefix 2 1) - ((separate space) [string combinator; - from_exp_pp; to_exp_pp; step_exp_pp; - vartuple_pp]) - (parens - (prefix 2 1 (group body_lambda) body_pp) - ) - ) - in - loop_pp - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for loop combinator") - end - | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> - let combinator = String.sub combinator 0 (String.index combinator '#') in - begin - let cond, varstuple, body, measure = - match args with - | [cond; varstuple; body] -> cond, varstuple, body, None - | [cond; varstuple; body; measure] -> cond, varstuple, body, Some measure - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for loop combinator") - in - let return (E_aux (e, (l,a))) = - let a' = mk_tannot (env_of_annot (l,a)) bool_typ in - E_aux (E_internal_return (E_aux (e, (l,a))), (l,a')) - in - let simple_bool (E_aux (_, (l,a)) as exp) = - let a' = mk_tannot (env_of_annot (l,a)) bool_typ in - E_aux (E_typ (bool_typ, exp), (l,a')) - in - let monad = if ctxt.is_monadic then "M" else "E" in - let csuffix, cond, body, body_effectful = - match effectful (effect_of cond), effectful (effect_of body) with - | false, false -> "", cond, body, false - | false, true -> monad, return cond, body, true - | true, false -> monad, simple_bool cond, return body, true - | true, true -> monad, simple_bool cond, body, true - in - (* If rewrite_loops_with_escape_effect added a dummy assertion to - ensure that the loop can escape when it reaches the limit, omit - the dummy assert here. *) - let body = match body with - | E_aux (E_internal_plet - (P_aux ((P_wild | P_typ (_,P_aux (P_wild, _))),_), - E_aux (E_assert - (E_aux (E_lit (L_aux (L_true,_)),_), - E_aux (E_lit (L_aux (L_string "loop dummy assert",_)),_)) - ,_),body'),_) -> body' - | _ -> body - in - (* TODO: does this still make sense? *) - (* The variable tuple (and the loop body) may have - overspecific types, so use the loop's type for deciding - whether a proof is necessary *) - let body_pp = - if body_effectful then expV false body - else construct_dep_pairs ctxt (env_of body) false body (general_typ_of full_exp) in - let varstuple_retyped = check_exp env (strip_exp varstuple) (general_typ_of full_exp) in - let varstuple_pp, lambda = - make_loop_vars [] varstuple_retyped - in - let msuffix, measure_pp = - match measure with - | None -> "", [] - | Some exp -> "T", [parens (prefix 2 1 (group lambda) (expN exp))] - in - parens ( - (prefix 2 1) - (string (combinator ^ csuffix ^ msuffix)) - (separate (break 1) - (varstuple_pp::measure_pp@ - [parens (prefix 2 1 (group lambda) (expN cond)); - parens (prefix 2 1 (group lambda) body_pp)])) - ) - end - | Id_aux (Id "early_return", _) -> - begin - match args with - | [exp] -> - let exp_pp = expY exp in - let ret_typ_pp = doc_atomic_typ ctxt (env_of exp) false (typ_of exp) in - let local_typ_pp = doc_atomic_typ ctxt (env_of full_exp) false (typ_of full_exp) in - let inj, monad, args = - if ctxt.is_monadic - then "early_return", "MR", [local_typ_pp; ret_typ_pp] - else "inl", "sum", [ret_typ_pp; local_typ_pp] - in - let epp = separate space [string inj; exp_pp] in - let tannot = separate space (string monad :: args) - in - parens (doc_op colon epp tannot) - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for early_return builtin") - end - | _ -> - let env = env_of_annot (l,annot) in - let () = debug ctxt (lazy ("Function application " ^ string_of_id f)) in - let call, is_extern, is_ctor, is_rec = - if Env.is_union_constructor f env then doc_id_ctor ctxt f, false, true, None else - if Env.is_extern f env "coq" - then string (Env.get_extern f env "coq"), true, false, None - else doc_id ctxt f, false, false, Bindings.find_opt f ctxt.recursive_fns - in - let (tqs,fn_ty) = - if is_ctor then Env.get_union_id f env else Env.get_val_spec f env - in - (* Calculate the renaming *) - let tqs_map = List.fold_left - (fun m k -> - let kid = kopt_kid k in - KBindings.add (orig_kid kid) kid m) - KBindings.empty (quant_kopts tqs) in - let arg_typs, ret_typ = match fn_ty with - | Typ_aux (Typ_fn (arg_typs,ret_typ),_) -> arg_typs, ret_typ - | _ -> raise (Reporting.err_unreachable l __POS__ "Function not a function type") - in - let fn_typ_env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env (quant_kopts tqs) in - let is_monadic = not (Effects.function_is_pure f ctxt.effect_info) in - let inst, inst_env = - (* We attempt to get an instantiation of the function signature's - type variables which agrees with Coq by - 1. using dummy variables with the expected type of each argument - (avoiding the inferred type, which might have (e.g.) stripped - out an existential quantifier) - 2. calculating the instantiation without using the expected - return type, so that we can work out if we need a cast around - the function call. *) - let dummy_args = - List.mapi (fun i exp -> mk_id ("#coq#arg" ^ string_of_int i), - general_typ_of exp) args - in - let () = debug ctxt (lazy (" arg types: " ^ String.concat ", " (List.map (fun (_,ty) -> string_of_typ ty) dummy_args))) in - let dummy_exp = mk_exp (E_app (f, List.map (fun (id,_) -> mk_exp (E_id id)) dummy_args)) in - let dummy_env = List.fold_left (fun env (id,typ) -> Env.add_local id (Immutable,typ) env) env dummy_args in - let (E_aux (_, (_, inst_tannot))) as inst_exp = - try infer_exp dummy_env dummy_exp - with ex -> - debug ctxt (lazy (" cannot infer dummy application " ^ Printexc.to_string ex)); - full_exp - in - (* We may have inherited existentials from the arguments, - so add any to the environment. *) - let inst_env = - match typ_of inst_exp with - | Typ_aux (Typ_exist (kopts, _, _), l) -> - List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts - | _ -> env - in - match get_instantiations inst_tannot with - | Some x -> x, inst_env - (* Not all function applications can be inferred, so try falling back to the - type inferred when we know the target type. - TODO: there are probably some edge cases where this won't pick up a need - to cast. *) - | None -> - (debug ctxt (lazy (" unable to infer function instantiation without return type " ^ string_of_typ (typ_of full_exp))); - instantiation_of full_exp, env) - in - let () = debug ctxt (lazy (" instantiations pre-rename: " ^ String.concat ", " (List.map (fun (kid,tyarg) -> string_of_kid kid ^ " => " ^ string_of_typ_arg tyarg) (KBindings.bindings inst)))) in - let inst = KBindings.fold (fun k u m -> - match KBindings.find_opt (orig_kid k) tqs_map with - | Some k' -> KBindings.add k' u m - | None -> m (* must have been an existential *) ) inst KBindings.empty in - let () = debug ctxt (lazy (" instantiations: " ^ String.concat ", " (List.map (fun (kid,tyarg) -> string_of_kid kid ^ " => " ^ string_of_typ_arg tyarg) (KBindings.bindings inst)))) in - - (* Decide whether to unpack an existential result, pack one, or cast. - To do this we compare the expected type stored in the checked expression - with the inferred type. *) - let ret_typ_inst = - subst_unifiers inst ret_typ - in - - (* TODO: clean up some remnants of the embedded proofs *) - let autocast = - let ann_typ = Env.expand_synonyms env (general_typ_of_annot (l,annot)) in - let ann_typ = expand_range_type ann_typ in - let ret_typ_inst = expand_range_type (Env.expand_synonyms inst_env ret_typ_inst) in - let ret_typ_inst = - if is_no_proof_fn env f then ret_typ_inst - else snd (replace_atom_return_type ret_typ_inst) in - let () = - debug ctxt (lazy (" type returned " ^ string_of_typ ret_typ_inst)); - debug ctxt (lazy (" type expected " ^ string_of_typ ann_typ)) - in - let in_typ = (* TODO: just existential stripping? *) - if is_no_proof_fn env f then ret_typ_inst else - match classify_ex_type ctxt inst_env ~rawbools:true ret_typ_inst with - | ExGeneral, _, t1 -> t1 - | ExNone, _, t1 -> t1 - in - let out_typ = - match ann_typ with - | Typ_aux (Typ_exist (_,_,t1),_) -> t1 - | t1 -> t1 + let deref = doc_lexp_deref ctxt le in + let call = + if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" + else "write_reg_pos" + in + liftR ((prefix 2 1) (string call) (deref ^/^ expY e2 ^/^ expY e)) + ) + | LE_field ((LE_aux (_, lannot) as le), id) -> + let field_ref = + doc_id ctxt (typ_id_of (typ_of_annot lannot)) ^^ underscore ^^ doc_id ctxt id + (*^^ + dot ^^ + string "set_field"*) in - let autocast = - (* Avoid using helper functions which simplify the nexps *) - match in_typ, out_typ with - | Typ_aux (Typ_app (Id_aux (Id "bitvector",_),[A_aux (A_nexp n1,_);_]),_), - Typ_aux (Typ_app (Id_aux (Id "bitvector",_),[A_aux (A_nexp n2,_);_]),_) -> - not (similar_nexps ctxt env n1 n2) - | _ -> false - in autocast - in + liftR ((prefix 2 1) (string "write_reg_field") (doc_lexp_deref ctxt le ^^ space ^^ field_ref ^/^ expY e)) + | LE_deref re -> liftR ((prefix 2 1) (string "write_reg") (expY re ^/^ expY e)) + | _ -> liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref ctxt le ^/^ expY e)) + ) + | E_vector_append (le, re) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_append should have been rewritten before pretty-printing") + | E_cons (le, re) -> doc_op (group (colon ^^ colon)) (expY le) (expY re) + | E_if (c, t, e) -> + let epp = if_exp ctxt (env_of full_exp) (typ_of full_exp) false c t e in + if aexp_needed then parens (align epp) else epp + | E_for (id, exp1, exp2, exp3, Ord_aux (order, _), exp4) -> + raise (report l __POS__ "E_for should have been rewritten before pretty-printing") + | E_loop _ -> raise (report l __POS__ "E_loop should have been rewritten before pretty-printing") + | E_let (leb, e) -> + let pat = match leb with LB_aux (LB_val (p, _), _) -> p in + let () = debug ctxt (lazy ("Let with pattern " ^ string_of_pat pat)) in + let new_ctxt = merge_new_tyvars ctxt (env_of_annot (l, annot)) pat (env_of e) in + let epp = let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ top_exp new_ctxt false e in + if aexp_needed then parens epp else epp + | E_app (f, args) -> + let env = env_of full_exp in + let doc_loop_var (E_aux (e, (l, _)) as exp) = + match e with + | E_id id -> + let id_pp = doc_id ctxt id in + let typ = general_typ_of exp in + (id_pp, id_pp) + | E_lit (L_aux (L_unit, _)) -> (string "tt", underscore) + | _ -> + raise (Reporting.err_unreachable l __POS__ ("Bad expression for variable in loop: " ^ string_of_exp exp)) + in + let make_loop_vars extra_binders varstuple = + match varstuple with + | E_aux (E_tuple vs, _) -> + let vs = List.map doc_loop_var vs in + let mkpp f vs = separate (string ", ") (List.map f vs) in + let tup_pp = mkpp (fun (pp, _) -> pp) vs in + let match_pp = mkpp (fun (_, pp) -> pp) vs in + (parens tup_pp, separate space ((string "fun" :: extra_binders) @ [squote ^^ parens match_pp; bigarrow])) + | _ -> + let exp_pp, match_pp = doc_loop_var varstuple in + (exp_pp, separate space ((string "fun" :: extra_binders) @ [match_pp; bigarrow])) + in + begin + match f with + | (Id_aux (Id "and_bool", _) | Id_aux (Id "or_bool", _)) when effectful (effect_of full_exp) -> + let suffix = "M" in + let call = doc_id ctxt (append_id f suffix) in + debug ctxt (lazy ("Effectful boolean op: " ^ string_of_id f)); + let doc_arg exp = expY exp in + let epp = hang 2 (flow (break 1) (call :: List.map doc_arg args)) in + wrap_parens epp + (* temporary hack to make the loop body a function of the temporary variables *) + | Id_aux (Id "None", _) as none -> doc_id_ctor ctxt none + | Id_aux (Id "foreach#", _) -> begin + match args with + | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] -> + let loopvar, body = + match body with + | E_aux + ( E_if + ( _, + E_aux + ( E_let + ( LB_aux + ( LB_val + ( ( P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), _), _)), _) + | P_aux (P_var (P_aux (P_id id, _), _), _) + | P_aux (P_id id, _) ), + _ + ), + _ + ), + body + ), + _ + ), + _ + ), + _ + ) -> + (id, body) + | _ -> + raise + (Reporting.err_unreachable l __POS__ ("Unable to find loop variable in " ^ string_of_exp body)) + in + let dir = + match ord_exp with + | E_aux (E_lit (L_aux (L_false, _)), _) -> "_down" + | E_aux (E_lit (L_aux (L_true, _)), _) -> "_up" + | _ -> + raise + (Reporting.err_unreachable l __POS__ ("Unexpected loop direction " ^ string_of_exp ord_exp)) + in + let effects = effectful (effect_of body) in + let combinator = + if effects then if ctxt.is_monadic then "foreach_ZM" else "foreach_ZE" else "foreach_Z" + in + let combinator = combinator ^ dir in + let body_ctxt = add_single_kid_id_rename ctxt loopvar (mk_kid ("loop_" ^ string_of_id loopvar)) in + let from_exp_pp, to_exp_pp, step_exp_pp = (expY from_exp, expY to_exp, expY step_exp) in + (* The body has the right type for deciding whether a proof is necessary *) + let vartuple_retyped = check_exp env (strip_exp vartuple) (general_typ_of body) in + let vartuple_pp, body_lambda = make_loop_vars [doc_id ctxt loopvar] vartuple_retyped in + (* TODO: this should probably be construct_dep_pairs, but we would need + to change it to use the updated context. *) + let body_pp = top_exp body_ctxt false body in + let loop_pp = + parens + ((prefix 2 1) + ((separate space) [string combinator; from_exp_pp; to_exp_pp; step_exp_pp; vartuple_pp]) + (parens (prefix 2 1 (group body_lambda) body_pp)) + ) + in + loop_pp + | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") + end + | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> + let combinator = String.sub combinator 0 (String.index combinator '#') in + begin + let cond, varstuple, body, measure = + match args with + | [cond; varstuple; body] -> (cond, varstuple, body, None) + | [cond; varstuple; body; measure] -> (cond, varstuple, body, Some measure) + | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") + in + let return (E_aux (e, (l, a))) = + let a' = mk_tannot (env_of_annot (l, a)) bool_typ in + E_aux (E_internal_return (E_aux (e, (l, a))), (l, a')) + in + let simple_bool (E_aux (_, (l, a)) as exp) = + let a' = mk_tannot (env_of_annot (l, a)) bool_typ in + E_aux (E_typ (bool_typ, exp), (l, a')) + in + let monad = if ctxt.is_monadic then "M" else "E" in + let csuffix, cond, body, body_effectful = + match (effectful (effect_of cond), effectful (effect_of body)) with + | false, false -> ("", cond, body, false) + | false, true -> (monad, return cond, body, true) + | true, false -> (monad, simple_bool cond, return body, true) + | true, true -> (monad, simple_bool cond, body, true) + in + (* If rewrite_loops_with_escape_effect added a dummy assertion to + ensure that the loop can escape when it reaches the limit, omit + the dummy assert here. *) + let body = + match body with + | E_aux + ( E_internal_plet + ( P_aux ((P_wild | P_typ (_, P_aux (P_wild, _))), _), + E_aux + ( E_assert + ( E_aux (E_lit (L_aux (L_true, _)), _), + E_aux (E_lit (L_aux (L_string "loop dummy assert", _)), _) + ), + _ + ), + body' + ), + _ + ) -> + body' + | _ -> body + in + (* TODO: does this still make sense? *) + (* The variable tuple (and the loop body) may have + overspecific types, so use the loop's type for deciding + whether a proof is necessary *) + let body_pp = + if body_effectful then expV false body + else construct_dep_pairs ctxt (env_of body) false body (general_typ_of full_exp) + in + let varstuple_retyped = check_exp env (strip_exp varstuple) (general_typ_of full_exp) in + let varstuple_pp, lambda = make_loop_vars [] varstuple_retyped in + let msuffix, measure_pp = + match measure with + | None -> ("", []) + | Some exp -> ("T", [parens (prefix 2 1 (group lambda) (expN exp))]) + in + parens + ((prefix 2 1) + (string (combinator ^ csuffix ^ msuffix)) + (separate (break 1) + ((varstuple_pp :: measure_pp) + @ [parens (prefix 2 1 (group lambda) (expN cond)); parens (prefix 2 1 (group lambda) body_pp)] + ) + ) + ) + end + | Id_aux (Id "early_return", _) -> begin + match args with + | [exp] -> + let exp_pp = expY exp in + let ret_typ_pp = doc_atomic_typ ctxt (env_of exp) false (typ_of exp) in + let local_typ_pp = doc_atomic_typ ctxt (env_of full_exp) false (typ_of full_exp) in + let inj, monad, args = + if ctxt.is_monadic then ("early_return", "MR", [local_typ_pp; ret_typ_pp]) + else ("inl", "sum", [ret_typ_pp; local_typ_pp]) + in + let epp = separate space [string inj; exp_pp] in + let tannot = separate space (string monad :: args) in + parens (doc_op colon epp tannot) + | _ -> + raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for early_return builtin") + end + | _ -> + let env = env_of_annot (l, annot) in + let () = debug ctxt (lazy ("Function application " ^ string_of_id f)) in + let call, is_extern, is_ctor, is_rec = + if Env.is_union_constructor f env then (doc_id_ctor ctxt f, false, true, None) + else if Env.is_extern f env "coq" then (string (Env.get_extern f env "coq"), true, false, None) + else (doc_id ctxt f, false, false, Bindings.find_opt f ctxt.recursive_fns) + in + let tqs, fn_ty = if is_ctor then Env.get_union_id f env else Env.get_val_spec f env in + (* Calculate the renaming *) + let tqs_map = + List.fold_left + (fun m k -> + let kid = kopt_kid k in + KBindings.add (orig_kid kid) kid m + ) + KBindings.empty (quant_kopts tqs) + in + let arg_typs, ret_typ = + match fn_ty with + | Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> (arg_typs, ret_typ) + | _ -> raise (Reporting.err_unreachable l __POS__ "Function not a function type") + in + let fn_typ_env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env (quant_kopts tqs) in + let is_monadic = not (Effects.function_is_pure f ctxt.effect_info) in + let inst, inst_env = + (* We attempt to get an instantiation of the function signature's + type variables which agrees with Coq by + 1. using dummy variables with the expected type of each argument + (avoiding the inferred type, which might have (e.g.) stripped + out an existential quantifier) + 2. calculating the instantiation without using the expected + return type, so that we can work out if we need a cast around + the function call. *) + let dummy_args = + List.mapi (fun i exp -> (mk_id ("#coq#arg" ^ string_of_int i), general_typ_of exp)) args + in + let () = + debug ctxt + (lazy (" arg types: " ^ String.concat ", " (List.map (fun (_, ty) -> string_of_typ ty) dummy_args))) + in + let dummy_exp = mk_exp (E_app (f, List.map (fun (id, _) -> mk_exp (E_id id)) dummy_args)) in + let dummy_env = + List.fold_left (fun env (id, typ) -> Env.add_local id (Immutable, typ) env) env dummy_args + in + let (E_aux (_, (_, inst_tannot)) as inst_exp) = + try infer_exp dummy_env dummy_exp + with ex -> + debug ctxt (lazy (" cannot infer dummy application " ^ Printexc.to_string ex)); + full_exp + in + (* We may have inherited existentials from the arguments, + so add any to the environment. *) + let inst_env = + match typ_of inst_exp with + | Typ_aux (Typ_exist (kopts, _, _), l) -> + List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts + | _ -> env + in + match get_instantiations inst_tannot with + | Some x -> (x, inst_env) + (* Not all function applications can be inferred, so try falling back to the + type inferred when we know the target type. + TODO: there are probably some edge cases where this won't pick up a need + to cast. *) + | None -> + debug ctxt + ( lazy + (" unable to infer function instantiation without return type " + ^ string_of_typ (typ_of full_exp) + ) + ); + (instantiation_of full_exp, env) + in + let () = + debug ctxt + ( lazy + (" instantiations pre-rename: " + ^ String.concat ", " + (List.map + (fun (kid, tyarg) -> string_of_kid kid ^ " => " ^ string_of_typ_arg tyarg) + (KBindings.bindings inst) + ) + ) + ) + in + let inst = + KBindings.fold + (fun k u m -> + match KBindings.find_opt (orig_kid k) tqs_map with + | Some k' -> KBindings.add k' u m + | None -> m (* must have been an existential *) + ) + inst KBindings.empty + in + let () = + debug ctxt + ( lazy + (" instantiations: " + ^ String.concat ", " + (List.map + (fun (kid, tyarg) -> string_of_kid kid ^ " => " ^ string_of_typ_arg tyarg) + (KBindings.bindings inst) + ) + ) + ) + in - let simple_type_equations = Type_check.instantiate_simple_equations (quant_items tqs) in + (* Decide whether to unpack an existential result, pack one, or cast. + To do this we compare the expected type stored in the checked expression + with the inferred type. *) + let ret_typ_inst = subst_unifiers inst ret_typ in + + (* TODO: clean up some remnants of the embedded proofs *) + let autocast = + let ann_typ = Env.expand_synonyms env (general_typ_of_annot (l, annot)) in + let ann_typ = expand_range_type ann_typ in + let ret_typ_inst = expand_range_type (Env.expand_synonyms inst_env ret_typ_inst) in + let ret_typ_inst = + if is_no_proof_fn env f then ret_typ_inst else snd (replace_atom_return_type ret_typ_inst) + in + let () = + debug ctxt (lazy (" type returned " ^ string_of_typ ret_typ_inst)); + debug ctxt (lazy (" type expected " ^ string_of_typ ann_typ)) + in + let in_typ = + (* TODO: just existential stripping? *) + if is_no_proof_fn env f then ret_typ_inst + else ( + match classify_ex_type ctxt inst_env ~rawbools:true ret_typ_inst with + | ExGeneral, _, t1 -> t1 + | ExNone, _, t1 -> t1 + ) + in + let out_typ = match ann_typ with Typ_aux (Typ_exist (_, _, t1), _) -> t1 | t1 -> t1 in + let autocast = + (* Avoid using helper functions which simplify the nexps *) + match (in_typ, out_typ) with + | ( Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n1, _); _]), _), + Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n2, _); _]), _) ) -> + not (similar_nexps ctxt env n1 n2) + | _ -> false + in + autocast + in - let doc_arg want_parens arg typ_from_fn = - let env = env_of arg in - let fixed_ghost_arg = - match destruct_atom_nexp fn_typ_env typ_from_fn with - | Some (Nexp_aux (Nexp_var kid, _)) -> begin - match KBindings.find_opt kid simple_type_equations with - | Some (A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _)) -> true + let simple_type_equations = Type_check.instantiate_simple_equations (quant_items tqs) in + + let doc_arg want_parens arg typ_from_fn = + let env = env_of arg in + let fixed_ghost_arg = + match destruct_atom_nexp fn_typ_env typ_from_fn with + | Some (Nexp_aux (Nexp_var kid, _)) -> begin + match KBindings.find_opt kid simple_type_equations with + | Some (A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _)) -> true + | _ -> false + end | _ -> false - end - | _ -> false - in - let typ_from_fn = subst_unifiers inst typ_from_fn in - let typ_from_fn = Env.expand_synonyms inst_env typ_from_fn in - (* TODO: more sophisticated check *) - let () = - debug ctxt (lazy (" arg type found " ^ string_of_typ (typ_of arg))); - debug ctxt (lazy (" arg type expected " ^ string_of_typ typ_from_fn)) - in - let typ_of_arg = Env.expand_synonyms env (typ_of arg) in - let typ_of_arg = expand_range_type typ_of_arg in - let typ_of_arg' = match typ_of_arg with Typ_aux (Typ_exist (_,_,t),_) -> t | t -> t in - let typ_from_fn' = match typ_from_fn with Typ_aux (Typ_exist (_,_,t),_) -> t | t -> t in - (* If the argument is an integer that can be inferred from the - context in a different form, let Coq fill it in. E.g., - when "64" is really "8 * width". Avoid cases where the - type checker has introduced a phantom type variable while - calculating the instantiations. *) - let vars_in_env n = - let ekids = Env.get_typ_vars env in - let frees = nexp_frees n in - not (KidSet.is_empty frees) && - KidSet.for_all (fun kid -> KBindings.mem kid ekids) frees - in - match destruct_atom_nexp env typ_of_arg, destruct_atom_nexp env typ_from_fn with - | _, _ when fixed_ghost_arg -> - (* Comment out an argument whose value is fixed by an equation in the function's - type signature, because it's let-bound in the Coq definition rather than being - a real argument. *) - comment (construct_dep_pairs ctxt inst_env want_parens arg typ_from_fn) - | Some n1, Some n2 - when (not autocast) && vars_in_env n2 && not (similar_nexps ctxt env n1 n2) -> - debug ctxt (lazy (" leaving int arg implicit because of non-trivial types " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2)); - underscore - | Some (Nexp_aux (Nexp_var _,_)), Some (Nexp_aux (Nexp_constant c,_)) -> - string (Big_int.to_string c) - | _ -> - construct_dep_pairs ctxt inst_env want_parens arg typ_from_fn - in - let epp = - if is_ctor - then - let argspp = match args, arg_typs with - | [arg], [arg_typ] -> doc_arg true arg arg_typ - | _, _ -> parens (flow (comma ^^ break 1) (List.map2 (doc_arg false) args arg_typs)) - in group (hang 2 (call ^^ break 1 ^^ argspp)) - else - let argspp = List.map2 (doc_arg true) args arg_typs in - let all = - match is_rec with - | Some (pre,post) -> call :: List.init pre (fun _ -> underscore) @ argspp @ - List.init post (fun _ -> underscore) @ - [parens (string "_limit_reduces _acc")] - | None -> - match f with - | Id_aux (Id x,_) when is_prefix "#rec#" x -> - call :: argspp @ [parens (string "Zwf_guarded _")] - | _ -> call :: argspp - in hang 2 (flow (break 1) all) in + in + let typ_from_fn = subst_unifiers inst typ_from_fn in + let typ_from_fn = Env.expand_synonyms inst_env typ_from_fn in + (* TODO: more sophisticated check *) + let () = + debug ctxt (lazy (" arg type found " ^ string_of_typ (typ_of arg))); + debug ctxt (lazy (" arg type expected " ^ string_of_typ typ_from_fn)) + in + let typ_of_arg = Env.expand_synonyms env (typ_of arg) in + let typ_of_arg = expand_range_type typ_of_arg in + let typ_of_arg' = match typ_of_arg with Typ_aux (Typ_exist (_, _, t), _) -> t | t -> t in + let typ_from_fn' = match typ_from_fn with Typ_aux (Typ_exist (_, _, t), _) -> t | t -> t in + (* If the argument is an integer that can be inferred from the + context in a different form, let Coq fill it in. E.g., + when "64" is really "8 * width". Avoid cases where the + type checker has introduced a phantom type variable while + calculating the instantiations. *) + let vars_in_env n = + let ekids = Env.get_typ_vars env in + let frees = nexp_frees n in + (not (KidSet.is_empty frees)) && KidSet.for_all (fun kid -> KBindings.mem kid ekids) frees + in + match (destruct_atom_nexp env typ_of_arg, destruct_atom_nexp env typ_from_fn) with + | _, _ when fixed_ghost_arg -> + (* Comment out an argument whose value is fixed by an equation in the function's + type signature, because it's let-bound in the Coq definition rather than being + a real argument. *) + comment (construct_dep_pairs ctxt inst_env want_parens arg typ_from_fn) + | Some n1, Some n2 when (not autocast) && vars_in_env n2 && not (similar_nexps ctxt env n1 n2) -> + debug ctxt + ( lazy + (" leaving int arg implicit because of non-trivial types " ^ string_of_nexp n1 ^ " and " + ^ string_of_nexp n2 + ) + ); + underscore + | Some (Nexp_aux (Nexp_var _, _)), Some (Nexp_aux (Nexp_constant c, _)) -> string (Big_int.to_string c) + | _ -> construct_dep_pairs ctxt inst_env want_parens arg typ_from_fn + in + let epp = + if is_ctor then ( + let argspp = + match (args, arg_typs) with + | [arg], [arg_typ] -> doc_arg true arg arg_typ + | _, _ -> parens (flow (comma ^^ break 1) (List.map2 (doc_arg false) args arg_typs)) + in + group (hang 2 (call ^^ break 1 ^^ argspp)) + ) + else ( + let argspp = List.map2 (doc_arg true) args arg_typs in + let all = + match is_rec with + | Some (pre, post) -> + (call :: List.init pre (fun _ -> underscore)) + @ argspp + @ List.init post (fun _ -> underscore) + @ [parens (string "_limit_reduces _acc")] + | None -> ( + match f with + | Id_aux (Id x, _) when is_prefix "#rec#" x -> + (call :: argspp) @ [parens (string "Zwf_guarded _")] + | _ -> call :: argspp + ) + in + hang 2 (flow (break 1) all) + ) + in - let () = - debug ctxt (lazy (" autocast: " ^ string_of_bool autocast)) - in - let autocast_id = if is_monadic then "autocast_m" else "autocast" in - let epp = if autocast then string autocast_id ^^ space ^^ parens epp else epp in - liftR (if aexp_needed then parens (align epp) else epp) - end - | E_vector_access (v,e) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_access should have been rewritten before pretty-printing") - | E_vector_subrange (v,e1,e2) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_subrange should have been rewritten before pretty-printing") - | E_field((E_aux(_,(l,fannot)) as fexp),id) -> - (match destruct_tannot fannot with - | Some(env, (Typ_aux (Typ_id tid, _))) - | Some(env, (Typ_aux (Typ_app (tid, _), _))) - when Env.is_record tid env -> - let fname = - if prefix_recordtype && string_of_id tid <> "regstate" - then (string (string_of_id tid ^ "_")) ^^ doc_id ctxt id - else doc_id ctxt id in - let exp_pp = expY fexp ^^ dot ^^ parens fname in - let field_typ = expand_range_type (Env.expand_synonyms env (typ_of_annot (l,annot))) in - exp_pp - | _ -> - raise (report l __POS__ "E_field expression with no register or record type")) + let () = debug ctxt (lazy (" autocast: " ^ string_of_bool autocast)) in + let autocast_id = if is_monadic then "autocast_m" else "autocast" in + let epp = if autocast then string autocast_id ^^ space ^^ parens epp else epp in + liftR (if aexp_needed then parens (align epp) else epp) + end + | E_vector_access (v, e) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_access should have been rewritten before pretty-printing") + | E_vector_subrange (v, e1, e2) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_subrange should have been rewritten before pretty-printing") + | E_field ((E_aux (_, (l, fannot)) as fexp), id) -> ( + match destruct_tannot fannot with + | (Some (env, Typ_aux (Typ_id tid, _)) | Some (env, Typ_aux (Typ_app (tid, _), _))) when Env.is_record tid env + -> + let fname = + if prefix_recordtype && string_of_id tid <> "regstate" then + string (string_of_id tid ^ "_") ^^ doc_id ctxt id + else doc_id ctxt id + in + let exp_pp = expY fexp ^^ dot ^^ parens fname in + let field_typ = expand_range_type (Env.expand_synonyms env (typ_of_annot (l, annot))) in + exp_pp + | _ -> raise (report l __POS__ "E_field expression with no register or record type") + ) | E_block [] -> string "tt" | E_block exps -> raise (report l __POS__ "Blocks should have been removed till now.") | E_id id | E_ref id -> - let env = env_of full_exp in - let typ = typ_of full_exp in - let eff = effect_of full_exp in - let base_typ = Env.base_typ_of env typ in - if Env.is_register id env && (match e with E_id _ -> true | _ -> false) then - let epp = separate space [string "read_reg"; doc_id ctxt id ^^ string "_ref"] in - if is_bitvector_typ base_typ - then wrap_parens (align (group (prefix 0 1 (parens (liftR epp)) (doc_tannot ctxt env true base_typ)))) - else liftR epp - else if Env.is_register id env && (match e with E_ref _ -> true | _ -> false) then doc_id ctxt id ^^ string "_ref" - else if is_ctor env id then doc_id_ctor ctxt id - else begin - match Env.lookup_id id env with - | Local (_,typ) -> - let id_pp = doc_id ctxt id in - maybe_cast ("Variable " ^ string_of_id id) typ id_pp - | _ -> doc_id ctxt id - end + let env = env_of full_exp in + let typ = typ_of full_exp in + let eff = effect_of full_exp in + let base_typ = Env.base_typ_of env typ in + if Env.is_register id env && match e with E_id _ -> true | _ -> false then ( + let epp = separate space [string "read_reg"; doc_id ctxt id ^^ string "_ref"] in + if is_bitvector_typ base_typ then + wrap_parens (align (group (prefix 0 1 (parens (liftR epp)) (doc_tannot ctxt env true base_typ)))) + else liftR epp + ) + else if Env.is_register id env && match e with E_ref _ -> true | _ -> false then doc_id ctxt id ^^ string "_ref" + else if is_ctor env id then doc_id_ctor ctxt id + else begin + match Env.lookup_id id env with + | Local (_, typ) -> + let id_pp = doc_id ctxt id in + maybe_cast ("Variable " ^ string_of_id id) typ id_pp + | _ -> doc_id ctxt id + end | E_lit lit -> - let lit_pp = doc_lit lit in - maybe_cast "Literal" (typ_of full_exp) lit_pp - | E_tuple _ - | E_typ(_, E_aux (E_tuple _, _)) -> - construct_dep_pairs ctxt (env_of_annot (l,annot)) true full_exp (general_typ_of full_exp) - | E_typ(typ,e) -> - let env = env_of_annot (l,annot) in - let outer_typ = Env.expand_synonyms env (general_typ_of_annot (l,annot)) in - let outer_typ = expand_range_type outer_typ in - let cast_typ = expand_range_type (Env.expand_synonyms env typ) in - let inner_typ = Env.expand_synonyms env (typ_of e) in - let inner_typ = expand_range_type inner_typ in - let () = - debug ctxt (lazy ("Cast of type " ^ string_of_typ cast_typ)); - debug ctxt (lazy (" on expr of type " ^ string_of_typ inner_typ)); - debug ctxt (lazy (" where type expected is " ^ string_of_typ outer_typ)) - in - let epp = expV true e in - let outer_ex,_,outer_typ' = classify_ex_type ctxt env outer_typ in - let cast_ex,_,cast_typ' = classify_ex_type ctxt env ~rawbools:true cast_typ in - let inner_ex,_,inner_typ' = classify_ex_type ctxt env inner_typ in - let autocast_out = - (* Avoid using helper functions which simplify the nexps *) - match outer_typ', cast_typ' with - | Typ_aux (Typ_app (Id_aux (Id "bitvector",_),[A_aux (A_nexp n1,_);_]),_), - Typ_aux (Typ_app (Id_aux (Id "bitvector",_),[A_aux (A_nexp n2,_);_]),_) -> + let lit_pp = doc_lit lit in + maybe_cast "Literal" (typ_of full_exp) lit_pp + | E_tuple _ | E_typ (_, E_aux (E_tuple _, _)) -> + construct_dep_pairs ctxt (env_of_annot (l, annot)) true full_exp (general_typ_of full_exp) + | E_typ (typ, e) -> + let env = env_of_annot (l, annot) in + let outer_typ = Env.expand_synonyms env (general_typ_of_annot (l, annot)) in + let outer_typ = expand_range_type outer_typ in + let cast_typ = expand_range_type (Env.expand_synonyms env typ) in + let inner_typ = Env.expand_synonyms env (typ_of e) in + let inner_typ = expand_range_type inner_typ in + let () = + debug ctxt (lazy ("Cast of type " ^ string_of_typ cast_typ)); + debug ctxt (lazy (" on expr of type " ^ string_of_typ inner_typ)); + debug ctxt (lazy (" where type expected is " ^ string_of_typ outer_typ)) + in + let epp = expV true e in + let outer_ex, _, outer_typ' = classify_ex_type ctxt env outer_typ in + let cast_ex, _, cast_typ' = classify_ex_type ctxt env ~rawbools:true cast_typ in + let inner_ex, _, inner_typ' = classify_ex_type ctxt env inner_typ in + let autocast_out = + (* Avoid using helper functions which simplify the nexps *) + match (outer_typ', cast_typ') with + | ( Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n1, _); _]), _), + Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n2, _); _]), _) ) -> not (similar_nexps ctxt env n1 n2) - | _ -> false - in - let effects = effectful (effect_of e) in - (* We don't currently have a version of autocast under existentials, - but they're rare and may be unnecessary *) - let autocast_out = - if effects && outer_ex = ExGeneral then false else autocast_out - in - let () = - debug ctxt (lazy (" effectful: " ^ string_of_bool effects ^ - " outer_ex: " ^ string_of_ex_kind outer_ex ^ - " cast_ex: " ^ string_of_ex_kind cast_ex ^ - " inner_ex: " ^ string_of_ex_kind inner_ex ^ - " autocast_out: " ^ string_of_bool autocast_out)) - in - let epp = epp ^/^ doc_tannot ctxt (env_of e) effects typ in - let epp = - if autocast_out then - string (if effects then "autocast_m" else "autocast") ^^ space ^^ parens epp - else epp - in - if aexp_needed then parens epp else epp + | _ -> false + in + let effects = effectful (effect_of e) in + (* We don't currently have a version of autocast under existentials, + but they're rare and may be unnecessary *) + let autocast_out = if effects && outer_ex = ExGeneral then false else autocast_out in + let () = + debug ctxt + ( lazy + (" effectful: " ^ string_of_bool effects ^ " outer_ex: " ^ string_of_ex_kind outer_ex ^ " cast_ex: " + ^ string_of_ex_kind cast_ex ^ " inner_ex: " ^ string_of_ex_kind inner_ex ^ " autocast_out: " + ^ string_of_bool autocast_out + ) + ) + in + let epp = epp ^/^ doc_tannot ctxt (env_of e) effects typ in + let epp = + if autocast_out then string (if effects then "autocast_m" else "autocast") ^^ space ^^ parens epp else epp + in + if aexp_needed then parens epp else epp | E_struct fexps -> - let recordtyp = match destruct_tannot annot with - | Some (env, Typ_aux (Typ_id tid,_)) - | Some (env, Typ_aux (Typ_app (tid, _), _)) -> - (* when Env.is_record tid env -> *) - tid - | _ -> raise (report l __POS__ ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp)) in - let epp = enclose_record (align (separate_map - (semi_sp ^^ break 1) - (doc_fexp ctxt recordtyp) fexps)) in - if aexp_needed then parens epp else epp - | E_struct_update(e, fexps) -> - let recordtyp, env = match destruct_tannot annot with - | Some (env, Typ_aux (Typ_id tid,_)) - | Some (env, Typ_aux (Typ_app (tid, _), _)) - when Env.is_record tid env -> - tid, env - | _ -> raise (report l __POS__ ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp)) in - if List.length fexps > 1 then - let _,fields = Env.get_record recordtyp env in - let var, let_pp = - match e with - | E_aux (E_id id,_) -> id, empty - | _ -> let v = mk_id "_record" in (* TODO: collision avoid *) - v, separate space [string "let "; doc_id ctxt v; coloneq; top_exp ctxt true e; string "in"] ^^ break 1 - in - let doc_field (_,id) = - match List.find (fun (FE_aux (FE_fexp (id',_),_)) -> Id.compare id id' == 0) fexps with - | fexp -> doc_fexp ctxt recordtyp fexp - | exception Not_found -> - let fname = - if prefix_recordtype && string_of_id recordtyp <> "regstate" - then (string (string_of_id recordtyp ^ "_")) ^^ doc_id ctxt id - else doc_id ctxt id in - doc_op coloneq fname (doc_id ctxt var ^^ dot ^^ parens fname) - in let_pp ^^ enclose_record (align (separate_map (semi_sp ^^ break 1) - doc_field fields)) - else - enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) + let recordtyp = + match destruct_tannot annot with + | Some (env, Typ_aux (Typ_id tid, _)) | Some (env, Typ_aux (Typ_app (tid, _), _)) -> + (* when Env.is_record tid env -> *) + tid + | _ -> + raise + (report l __POS__ + ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp) + ) + in + let epp = enclose_record (align (separate_map (semi_sp ^^ break 1) (doc_fexp ctxt recordtyp) fexps)) in + if aexp_needed then parens epp else epp + | E_struct_update (e, fexps) -> + let recordtyp, env = + match destruct_tannot annot with + | (Some (env, Typ_aux (Typ_id tid, _)) | Some (env, Typ_aux (Typ_app (tid, _), _))) when Env.is_record tid env + -> + (tid, env) + | _ -> + raise + (report l __POS__ + ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp) + ) + in + if List.length fexps > 1 then ( + let _, fields = Env.get_record recordtyp env in + let var, let_pp = + match e with + | E_aux (E_id id, _) -> (id, empty) + | _ -> + let v = mk_id "_record" in + (* TODO: collision avoid *) + (v, separate space [string "let "; doc_id ctxt v; coloneq; top_exp ctxt true e; string "in"] ^^ break 1) + in + let doc_field (_, id) = + match List.find (fun (FE_aux (FE_fexp (id', _), _)) -> Id.compare id id' == 0) fexps with + | fexp -> doc_fexp ctxt recordtyp fexp + | exception Not_found -> + let fname = + if prefix_recordtype && string_of_id recordtyp <> "regstate" then + string (string_of_id recordtyp ^ "_") ^^ doc_id ctxt id + else doc_id ctxt id + in + doc_op coloneq fname (doc_id ctxt var ^^ dot ^^ parens fname) + in + let_pp ^^ enclose_record (align (separate_map (semi_sp ^^ break 1) doc_field fields)) + ) + else + enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) | E_vector exps -> - let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in - let start, (len, order, etyp) = - if is_vector_typ t || is_bitvector_typ t then vector_start_index t, vector_typ_args_of t - else raise (Reporting.err_unreachable l __POS__ - "E_vector of non-vector type") in - let dir,dir_out = if is_order_inc order then (true,"true") else (false, "false") in - let expspp = align (group (flow_map (semi ^^ break 0) expN exps)) in - let epp = brackets expspp in - let (epp,aexp_needed) = - if is_bitvector_typ t then - let bepp = string "vec_of_bits" ^^ space ^^ align epp in - (align (group (prefix 0 1 bepp (doc_tannot ctxt (env_of full_exp) false t))), true) - else - let vepp = string "vec_of_list_len" ^^ space ^^ align epp in - (vepp,aexp_needed) in - if aexp_needed then parens (align epp) else epp - | E_vector_update(v,e1,e2) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_update should have been rewritten before pretty-printing") - | E_vector_update_subrange(v,e1,e2,e3) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_update should have been rewritten before pretty-printing") - | E_list exps -> - brackets (separate_map (semi ^^ break 1) (expN) exps) - | E_match(e,pexps) -> - let only_integers e = expY e in - let epp = - group ((separate space [string "match"; only_integers e; string "with"]) ^/^ - (separate_map (break 1) (doc_case ctxt (env_of_annot (l,annot)) (typ_of e)) pexps) ^/^ - (string "end")) in - if aexp_needed then parens (align epp) else align epp + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in + let start, (len, order, etyp) = + if is_vector_typ t || is_bitvector_typ t then (vector_start_index t, vector_typ_args_of t) + else raise (Reporting.err_unreachable l __POS__ "E_vector of non-vector type") + in + let dir, dir_out = if is_order_inc order then (true, "true") else (false, "false") in + let expspp = align (group (flow_map (semi ^^ break 0) expN exps)) in + let epp = brackets expspp in + let epp, aexp_needed = + if is_bitvector_typ t then ( + let bepp = string "vec_of_bits" ^^ space ^^ align epp in + (align (group (prefix 0 1 bepp (doc_tannot ctxt (env_of full_exp) false t))), true) + ) + else ( + let vepp = string "vec_of_list_len" ^^ space ^^ align epp in + (vepp, aexp_needed) + ) + in + if aexp_needed then parens (align epp) else epp + | E_vector_update (v, e1, e2) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_update should have been rewritten before pretty-printing") + | E_vector_update_subrange (v, e1, e2, e3) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_update should have been rewritten before pretty-printing") + | E_list exps -> brackets (separate_map (semi ^^ break 1) expN exps) + | E_match (e, pexps) -> + let only_integers e = expY e in + let epp = + group + (separate space [string "match"; only_integers e; string "with"] + ^/^ separate_map (break 1) (doc_case ctxt (env_of_annot (l, annot)) (typ_of e)) pexps + ^/^ string "end" + ) + in + if aexp_needed then parens (align epp) else align epp | E_try (e, pexps) -> - if effectful (effect_of e) then - let try_catch = if Option.is_some ctxt.early_ret then "try_catchR" else "try_catch" in - let epp = - (* TODO capture avoidance for __catch_val *) - group ((separate space [string try_catch; expY e; string "(fun __catch_val => match __catch_val with "]) ^/^ - (separate_map (break 1) (doc_case ctxt (env_of_annot (l,annot)) exc_typ) pexps) ^/^ - (string "end)")) in - if aexp_needed then parens (align epp) else align epp - else - raise (Reporting.err_todo l "Warning: try-block around pure expression") + if effectful (effect_of e) then ( + let try_catch = if Option.is_some ctxt.early_ret then "try_catchR" else "try_catch" in + let epp = + (* TODO capture avoidance for __catch_val *) + group + (separate space [string try_catch; expY e; string "(fun __catch_val => match __catch_val with "] + ^/^ separate_map (break 1) (doc_case ctxt (env_of_annot (l, annot)) exc_typ) pexps + ^/^ string "end)" + ) + in + if aexp_needed then parens (align epp) else align epp + ) + else raise (Reporting.err_todo l "Warning: try-block around pure expression") | E_throw e -> - let epp = liftR (separate space [string "throw"; expY e]) in - if aexp_needed then parens (align epp) else align epp + let epp = liftR (separate space [string "throw"; expY e]) in + if aexp_needed then parens (align epp) else align epp | E_exit e -> liftR (separate space [string "exit"; expY e]) - | E_assert (e1,e2) -> - let epp = liftR (separate space [string "assert_exp"; expY e1; expY e2]) in - if aexp_needed then parens (align epp) else align epp - | E_app_infix (e1,id,e2) -> - raise (Reporting.err_unreachable l __POS__ - "E_app_infix should have been rewritten before pretty-printing") - | E_var(lexp, eq_exp, in_exp) -> - raise (report l __POS__ "E_vars should have been removed before pretty-printing") - | E_internal_plet (pat,e1,e2) -> - begin - let () = - debug ctxt (lazy ("Internal plet, pattern " ^ string_of_pat pat)); - debug ctxt (lazy (" type of e1 " ^ string_of_typ (typ_of e1))) - in - let outer_env = env_of_annot (l,annot) in - let new_ctxt = merge_new_tyvars ctxt outer_env pat (env_of e2) in - match pat, e1, e2 with - | (P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)), - (E_aux (E_assert (assert_e1,assert_e2),_)), _ -> - let assert_fn, mid = - match assert_constraint outer_env true assert_e1 with - | Some _ -> "assert_exp'", ">>= fun _ =>" - | None -> "assert_exp", ">>" - in - let epp = liftR (separate space [string assert_fn; expY assert_e1; expY assert_e2]) in - let epp = infix 0 1 (string mid) epp (top_exp new_ctxt false e2) in - if aexp_needed then parens (align epp) else align epp - | _ -> + | E_assert (e1, e2) -> + let epp = liftR (separate space [string "assert_exp"; expY e1; expY e2]) in + if aexp_needed then parens (align epp) else align epp + | E_app_infix (e1, id, e2) -> + raise (Reporting.err_unreachable l __POS__ "E_app_infix should have been rewritten before pretty-printing") + | E_var (lexp, eq_exp, in_exp) -> raise (report l __POS__ "E_vars should have been removed before pretty-printing") + | E_internal_plet (pat, e1, e2) -> begin + let () = + debug ctxt (lazy ("Internal plet, pattern " ^ string_of_pat pat)); + debug ctxt (lazy (" type of e1 " ^ string_of_typ (typ_of e1))) + in + let outer_env = env_of_annot (l, annot) in + let new_ctxt = merge_new_tyvars ctxt outer_env pat (env_of e2) in + match (pat, e1, e2) with + | (P_aux (P_wild, _) | P_aux (P_typ (_, P_aux (P_wild, _)), _)), E_aux (E_assert (assert_e1, assert_e2), _), _ + -> + let assert_fn, mid = + match assert_constraint outer_env true assert_e1 with + | Some _ -> ("assert_exp'", ">>= fun _ =>") + | None -> ("assert_exp", ">>") + in + let epp = liftR (separate space [string assert_fn; expY assert_e1; expY assert_e2]) in + let epp = infix 0 1 (string mid) epp (top_exp new_ctxt false e2) in + if aexp_needed then parens (align epp) else align epp + | _ -> let epp = let middle = - if ctxt.is_monadic then + if ctxt.is_monadic then ( let env1 = env_of e1 in match pat with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) - when is_unit_typ (typ_of_pat pat) -> - string ">>" - | P_aux (P_id id,_) - when not (is_enum (env_of e1) id) -> - separate space [string ">>= fun"; doc_id ctxt id; bigarrow] - | P_aux (P_typ (typ, P_aux (P_id id,_)),_) - when (is_enum (env_of e1) id) -> - separate space [string ">>= fun"; doc_id ctxt id; colon; doc_typ ctxt outer_env typ; bigarrow] - (* TODO: is this still needed? *) - | P_aux (P_typ (typ, P_aux (P_id id,_)),_) - | P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id,_),_),_)),_) - | P_aux (P_var (P_aux (P_typ (typ, P_aux (P_id id,_)),_),_),_) - when not (is_enum env1 id) -> - let full_typ = (expand_range_type typ) in - let binder = parens (separate space [doc_id ctxt id; colon; doc_typ ctxt outer_env typ]) - in separate space [string ">>= fun"; binder; bigarrow] - | P_aux (P_id id,_) -> - let typ = typ_of e1 in - (* Ideally we'd drop the parens and the squote when possible, but it's - easier to keep both, and avoids clashes with 'b"..." bitvector literals. *) - let binder = squote ^^ parens (doc_pat ctxt false true (pat, typ_of e1)) in - separate space [string ">>= fun"; binder; bigarrow] - | P_aux (P_typ (typ, pat'),_) -> - separate space [string ">>= fun"; squote ^^ parens (doc_pat ctxt true true (pat, typ_of e1) ^/^ colon ^^ space ^^ doc_typ ctxt outer_env typ); bigarrow] + | (P_aux (P_wild, _) | P_aux (P_typ (_, P_aux (P_wild, _)), _)) when is_unit_typ (typ_of_pat pat) -> + string ">>" + | P_aux (P_id id, _) when not (is_enum (env_of e1) id) -> + separate space [string ">>= fun"; doc_id ctxt id; bigarrow] + | P_aux (P_typ (typ, P_aux (P_id id, _)), _) when is_enum (env_of e1) id -> + separate space [string ">>= fun"; doc_id ctxt id; colon; doc_typ ctxt outer_env typ; bigarrow] + (* TODO: is this still needed? *) + | P_aux (P_typ (typ, P_aux (P_id id, _)), _) + | P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), _), _)), _) + | P_aux (P_var (P_aux (P_typ (typ, P_aux (P_id id, _)), _), _), _) + when not (is_enum env1 id) -> + let full_typ = expand_range_type typ in + let binder = parens (separate space [doc_id ctxt id; colon; doc_typ ctxt outer_env typ]) in + separate space [string ">>= fun"; binder; bigarrow] + | P_aux (P_id id, _) -> + let typ = typ_of e1 in + (* Ideally we'd drop the parens and the squote when possible, but it's + easier to keep both, and avoids clashes with 'b"..." bitvector literals. *) + let binder = squote ^^ parens (doc_pat ctxt false true (pat, typ_of e1)) in + separate space [string ">>= fun"; binder; bigarrow] + | P_aux (P_typ (typ, pat'), _) -> + separate space + [ + string ">>= fun"; + squote + ^^ parens + (doc_pat ctxt true true (pat, typ_of e1) ^/^ colon ^^ space ^^ doc_typ ctxt outer_env typ); + bigarrow; + ] | _ -> - separate space [string ">>= fun"; squote ^^ parens (doc_pat ctxt false true (pat, typ_of e1)); bigarrow] - else + separate space + [string ">>= fun"; squote ^^ parens (doc_pat ctxt false true (pat, typ_of e1)); bigarrow] + ) + else ( match pat with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) - when is_unit_typ (typ_of_pat pat) -> - string ">>$" + | (P_aux (P_wild, _) | P_aux (P_typ (_, P_aux (P_wild, _)), _)) when is_unit_typ (typ_of_pat pat) -> + string ">>$" | _ -> - separate space [string ">>$= fun"; squote ^^ parens (doc_pat ctxt false false (pat, typ_of e1)); bigarrow] - in + separate space + [string ">>$= fun"; squote ^^ parens (doc_pat ctxt false false (pat, typ_of e1)); bigarrow] + ) + in let e1_pp = expY e1 in let e2_pp = top_exp new_ctxt false e2 in infix 0 1 middle e1_pp e2_pp in if aexp_needed then parens (align epp) else epp - end - | E_internal_return (e1) -> - let exp_typ = typ_of e1 in - let ret_typ = general_typ_of full_exp in - let () = - debug ctxt (lazy ("Monad return of " ^ string_of_exp e1)); - debug ctxt (lazy (" with type " ^ string_of_typ exp_typ)); - debug ctxt (lazy (" at type " ^ string_of_typ ret_typ)) - in - let valpp = - let env = env_of e1 in - construct_dep_pairs ctxt env true e1 ret_typ ~rawbools:true - in - if Option.is_some ctxt.early_ret then - if ctxt.is_monadic - then wrap_parens (group (align (separate space [string "returnR"; parens ctxt.ret_typ_pp; valpp]))) - else wrap_parens (group (align (separate space [string "inr"; valpp]))) - else - wrap_parens (group (align (separate space [string "returnM"; valpp]))) - | E_sizeof nexp -> - (match nexp_simp nexp with + end + | E_internal_return e1 -> + let exp_typ = typ_of e1 in + let ret_typ = general_typ_of full_exp in + let () = + debug ctxt (lazy ("Monad return of " ^ string_of_exp e1)); + debug ctxt (lazy (" with type " ^ string_of_typ exp_typ)); + debug ctxt (lazy (" at type " ^ string_of_typ ret_typ)) + in + let valpp = + let env = env_of e1 in + construct_dep_pairs ctxt env true e1 ret_typ ~rawbools:true + in + if Option.is_some ctxt.early_ret then + if ctxt.is_monadic then + wrap_parens (group (align (separate space [string "returnR"; parens ctxt.ret_typ_pp; valpp]))) + else wrap_parens (group (align (separate space [string "inr"; valpp]))) + else wrap_parens (group (align (separate space [string "returnM"; valpp]))) + | E_sizeof nexp -> ( + match nexp_simp nexp with | Nexp_aux (Nexp_constant i, _) -> doc_lit (L_aux (L_num i, l)) | _ -> - raise (Reporting.err_unreachable l __POS__ - "pretty-printing non-constant sizeof expressions to Lem not supported")) + raise + (Reporting.err_unreachable l __POS__ + "pretty-printing non-constant sizeof expressions to Lem not supported" + ) + ) | E_return r -> - let ret_monad = " : MR" in - let exp_pp = - match ctxt.build_at_return with - | Some s -> parens (string s ^/^ expY r) - | None -> expY r - in - let ta = - if contains_t_pp_var ctxt (typ_of full_exp) || contains_t_pp_var ctxt (typ_of r) - then empty - else separate space - [string ret_monad; - parens (doc_typ ctxt (env_of full_exp) (typ_of full_exp)); - parens (doc_typ ctxt (env_of full_exp) (typ_of r))] in - align (parens (string "early_return" ^//^ exp_pp ^//^ ta)) + let ret_monad = " : MR" in + let exp_pp = match ctxt.build_at_return with Some s -> parens (string s ^/^ expY r) | None -> expY r in + let ta = + if contains_t_pp_var ctxt (typ_of full_exp) || contains_t_pp_var ctxt (typ_of r) then empty + else + separate space + [ + string ret_monad; + parens (doc_typ ctxt (env_of full_exp) (typ_of full_exp)); + parens (doc_typ ctxt (env_of full_exp) (typ_of r)); + ] + in + align (parens (string "early_return" ^//^ exp_pp ^//^ ta)) | E_constraint nc -> wrap_parens (doc_nc_exp ctxt (env_of full_exp) nc) | E_internal_assume (nc, e1) -> - string "(* " ^^ doc_nc_exp ctxt (env_of full_exp) nc ^^ string " *)" ^/^ wrap_parens (expN e1) + string "(* " ^^ doc_nc_exp ctxt (env_of full_exp) nc ^^ string " *)" ^/^ wrap_parens (expN e1) | E_internal_value _ -> - raise (Reporting.err_unreachable l __POS__ - "unsupported internal expression encountered while pretty-printing") - + raise (Reporting.err_unreachable l __POS__ "unsupported internal expression encountered while pretty-printing") (* TODO: no dep pairs now, what should this be? *) - and construct_dep_pairs ctxt ?(rawbools=false) env = - let rec aux want_parens (E_aux (e,_) as exp) typ = + and construct_dep_pairs ctxt ?(rawbools = false) env = + let rec aux want_parens (E_aux (e, _) as exp) typ = match e with - | E_tuple exps - | E_typ (_, E_aux (E_tuple exps,_)) -> - let typs = List.map general_typ_of exps in - parens (separate (string ", ") (List.map2 (aux false) exps typs)) + | E_tuple exps | E_typ (_, E_aux (E_tuple exps, _)) -> + let typs = List.map general_typ_of exps in + parens (separate (string ", ") (List.map2 (aux false) exps typs)) | _ -> - let typ' = expand_range_type (Env.expand_synonyms (env_of exp) typ) in - debug ctxt (lazy ("Constructing " ^ string_of_exp exp ^ " at type " ^ string_of_typ typ)); - let out_typ = - match classify_ex_type ctxt (env_of exp) ~rawbools typ' with - | ExNone, _, _ -> typ' - | ExGeneral, _, typ' -> typ' - in - let in_typ = expand_range_type (Env.expand_synonyms (env_of exp) (typ_of exp)) in - let in_typ = match destruct_exist_plain in_typ with Some (_,_,t) -> t | None -> in_typ in - let exp_pp = top_exp ctxt want_parens exp in - exp_pp - in aux - + let typ' = expand_range_type (Env.expand_synonyms (env_of exp) typ) in + debug ctxt (lazy ("Constructing " ^ string_of_exp exp ^ " at type " ^ string_of_typ typ)); + let out_typ = + match classify_ex_type ctxt (env_of exp) ~rawbools typ' with + | ExNone, _, _ -> typ' + | ExGeneral, _, typ' -> typ' + in + let in_typ = expand_range_type (Env.expand_synonyms (env_of exp) (typ_of exp)) in + let in_typ = match destruct_exist_plain in_typ with Some (_, _, t) -> t | None -> in_typ in + let exp_pp = top_exp ctxt want_parens exp in + exp_pp + in + aux and if_exp ctxt full_env full_typ (elseif : bool) c t e = let if_pp = string (if elseif then "else if" else "if") in let c_pp = top_exp ctxt false c in @@ -2202,145 +2187,127 @@ let doc_exp, doc_let = across if expressions in complex situations, so provide an annotation for monadic expressions. *) let add_type_pp pp = - if effectful (effect_of t) then - pp ^/^ string "return" ^/^ doc_tannot_core ctxt full_env true full_typ - else pp + if effectful (effect_of t) then pp ^/^ string "return" ^/^ doc_tannot_core ctxt full_env true full_typ else pp in let t_pp = top_exp ctxt false t in - let else_pp = match e with - | E_aux (E_if (c', t', e'), _) - | E_aux (E_typ (_, E_aux (E_if (c', t', e'), _)), _) -> - if_exp ctxt full_env full_typ true c' t' e' + let else_pp = + match e with + | E_aux (E_if (c', t', e'), _) | E_aux (E_typ (_, E_aux (E_if (c', t', e'), _)), _) -> + if_exp ctxt full_env full_typ true c' t' e' (* Special case to prevent current arm decoder becoming a staircase *) (* TODO: replace with smarter pretty printing *) - | E_aux (E_internal_plet (pat,exp1,E_aux (E_typ (typ, (E_aux (E_if (_, _, _), _) as exp2)),_)),ann) when Typ.compare typ unit_typ == 0 -> - string "else" ^/^ top_exp ctxt false (E_aux (E_internal_plet (pat,exp1,exp2),ann)) + | E_aux (E_internal_plet (pat, exp1, E_aux (E_typ (typ, (E_aux (E_if (_, _, _), _) as exp2)), _)), ann) + when Typ.compare typ unit_typ == 0 -> + string "else" ^/^ top_exp ctxt false (E_aux (E_internal_plet (pat, exp1, exp2), ann)) | _ -> prefix 2 1 (string "else") (top_exp ctxt false e) in - (prefix 2 1 - (soft_surround 2 1 if_pp - (add_type_pp c_pp) - (string "then")) - t_pp) ^^ - break 1 ^^ - else_pp - and let_exp ctxt (LB_aux(lb,_)) = match lb with + prefix 2 1 (soft_surround 2 1 if_pp (add_type_pp c_pp) (string "then")) t_pp ^^ break 1 ^^ else_pp + and let_exp ctxt (LB_aux (lb, _)) = + match lb with (* Prefer simple lets over patterns, because I've found Coq can struggle to work out return types otherwise *) - | LB_val(P_aux (P_id id,_),e) - when not (is_enum (env_of e) id) -> - prefix 2 1 - (separate space [string "let"; doc_id ctxt id; coloneq]) - (top_exp ctxt false e) - | LB_val(P_aux (P_typ (typ,P_aux (P_id id,_)),_),e) - when not (is_enum (env_of e) id) -> - prefix 2 1 - (separate space [string "let"; doc_id ctxt id; colon; doc_typ ctxt (env_of e) typ; coloneq]) - (top_exp ctxt false e) - | LB_val(P_aux (P_typ (typ,pat),_),(E_aux (_,e_ann) as e)) -> - prefix 2 1 - (separate space [string "let"; squote ^^ parens (doc_pat ctxt true false (pat, typ)); coloneq]) - (top_exp ctxt false (E_aux (E_typ (typ,e),e_ann))) - | LB_val(pat,e) -> - prefix 2 1 - (separate space [string "let"; squote ^^ parens (doc_pat ctxt true false (pat, typ_of e)); coloneq]) - (top_exp ctxt false e) - - and doc_fexp ctxt recordtyp (FE_aux(FE_fexp(id,e),_)) = + | LB_val (P_aux (P_id id, _), e) when not (is_enum (env_of e) id) -> + prefix 2 1 (separate space [string "let"; doc_id ctxt id; coloneq]) (top_exp ctxt false e) + | LB_val (P_aux (P_typ (typ, P_aux (P_id id, _)), _), e) when not (is_enum (env_of e) id) -> + prefix 2 1 + (separate space [string "let"; doc_id ctxt id; colon; doc_typ ctxt (env_of e) typ; coloneq]) + (top_exp ctxt false e) + | LB_val (P_aux (P_typ (typ, pat), _), (E_aux (_, e_ann) as e)) -> + prefix 2 1 + (separate space [string "let"; squote ^^ parens (doc_pat ctxt true false (pat, typ)); coloneq]) + (top_exp ctxt false (E_aux (E_typ (typ, e), e_ann))) + | LB_val (pat, e) -> + prefix 2 1 + (separate space [string "let"; squote ^^ parens (doc_pat ctxt true false (pat, typ_of e)); coloneq]) + (top_exp ctxt false e) + and doc_fexp ctxt recordtyp (FE_aux (FE_fexp (id, e), _)) = let fname = - if prefix_recordtype && string_of_id recordtyp <> "regstate" - then (string (string_of_id recordtyp ^ "_")) ^^ doc_id ctxt id - else doc_id ctxt id in + if prefix_recordtype && string_of_id recordtyp <> "regstate" then + string (string_of_id recordtyp ^ "_") ^^ doc_id ctxt id + else doc_id ctxt id + in let e_pp = construct_dep_pairs ctxt (env_of e) false e (general_typ_of e) in group (doc_op coloneq fname e_pp) - and doc_case ctxt old_env typ = function - | Pat_aux(Pat_exp(pat,e),_) -> - let new_ctxt = merge_new_tyvars ctxt old_env pat (env_of e) in - group (prefix 3 1 (separate space [pipe; doc_pat ctxt false false (pat,typ);bigarrow]) - (group (top_exp new_ctxt false e))) - | Pat_aux(Pat_when(_,_,_),(l,_)) -> - raise (Reporting.err_unreachable l __POS__ - "guarded pattern expression should have been rewritten before pretty-printing") - - and doc_lexp_deref ctxt ((LE_aux(lexp,(l,annot)))) = match lexp with - | LE_field (le,id) -> - parens (separate empty [doc_lexp_deref ctxt le;dot;doc_id ctxt id]) + | Pat_aux (Pat_exp (pat, e), _) -> + let new_ctxt = merge_new_tyvars ctxt old_env pat (env_of e) in + group + (prefix 3 1 + (separate space [pipe; doc_pat ctxt false false (pat, typ); bigarrow]) + (group (top_exp new_ctxt false e)) + ) + | Pat_aux (Pat_when (_, _, _), (l, _)) -> + raise + (Reporting.err_unreachable l __POS__ + "guarded pattern expression should have been rewritten before pretty-printing" + ) + and doc_lexp_deref ctxt (LE_aux (lexp, (l, annot))) = + match lexp with + | LE_field (le, id) -> parens (separate empty [doc_lexp_deref ctxt le; dot; doc_id ctxt id]) | LE_id id -> doc_id ctxt id ^^ string "_ref" - | LE_typ (typ,id) -> doc_id ctxt id ^^ string "_ref" + | LE_typ (typ, id) -> doc_id ctxt id ^^ string "_ref" | LE_tuple lexps -> parens (separate_map comma_sp (doc_lexp_deref ctxt) lexps) - | _ -> - raise (Reporting.err_unreachable l __POS__ ("doc_lexp_deref: Unsupported lexp")) - (* expose doc_exp and doc_let *) - in top_exp, let_exp + | _ -> raise (Reporting.err_unreachable l __POS__ "doc_lexp_deref: Unsupported lexp") + (* expose doc_exp and doc_let *) + in + (top_exp, let_exp) (* FIXME: A temporary definition of List.init until 4.06 is more standard *) let list_init n f = Array.to_list (Array.init n f) let types_used_with_generic_eq defs = - let rec add_typ idset (Typ_aux (typ,_)) = + let rec add_typ idset (Typ_aux (typ, _)) = match typ with | Typ_id id -> IdSet.add id idset - | Typ_app (id,args) -> - List.fold_left add_typ_arg (IdSet.add id idset) args + | Typ_app (id, args) -> List.fold_left add_typ_arg (IdSet.add id idset) args | Typ_tuple ts -> List.fold_left add_typ idset ts | _ -> idset - and add_typ_arg idset (A_aux (ta,_)) = - match ta with - | A_typ typ -> add_typ idset typ - | _ -> idset - in + and add_typ_arg idset (A_aux (ta, _)) = match ta with A_typ typ -> add_typ idset typ | _ -> idset in let alg = - { (Rewriter.compute_exp_alg IdSet.empty IdSet.union) with - Rewriter.e_aux = fun ((typs,exp),annot) -> - let typs' = - match exp with - | E_app (f,[arg1;_]) -> - if Env.is_extern f (env_of_annot annot) "coq" then - let f' = Env.get_extern f (env_of_annot annot) "coq" in - if f' = "generic_eq" || f' = "generic_neq" then - add_typ typs (Env.expand_synonyms (env_of arg1) (typ_of arg1)) - else typs - else typs - | _ -> typs - in typs', E_aux (exp,annot) } - in - let typs_req_funcl (FCL_aux (FCL_funcl (_,pexp), _)) = - fst (Rewriter.fold_pexp alg pexp) + { + (Rewriter.compute_exp_alg IdSet.empty IdSet.union) with + Rewriter.e_aux = + (fun ((typs, exp), annot) -> + let typs' = + match exp with + | E_app (f, [arg1; _]) -> + if Env.is_extern f (env_of_annot annot) "coq" then ( + let f' = Env.get_extern f (env_of_annot annot) "coq" in + if f' = "generic_eq" || f' = "generic_neq" then + add_typ typs (Env.expand_synonyms (env_of arg1) (typ_of arg1)) + else typs + ) + else typs + | _ -> typs + in + (typs', E_aux (exp, annot)) + ); + } in - let typs_req_fundef (FD_aux (FD_function (_,_,fcls),_)) = + let typs_req_funcl (FCL_aux (FCL_funcl (_, pexp), _)) = fst (Rewriter.fold_pexp alg pexp) in + let typs_req_fundef (FD_aux (FD_function (_, _, fcls), _)) = List.fold_left IdSet.union IdSet.empty (List.map typs_req_funcl fcls) in let typs_req_def (DEF_aux (aux, _) as def) = match aux with - | DEF_type _ - | DEF_val _ - | DEF_fixity _ - | DEF_overload _ - | DEF_default _ - | DEF_pragma _ - | DEF_register _ - -> IdSet.empty + | DEF_type _ | DEF_val _ | DEF_fixity _ | DEF_overload _ | DEF_default _ | DEF_pragma _ | DEF_register _ -> + IdSet.empty | DEF_fundef fd -> typs_req_fundef fd - | DEF_internal_mutrec fds -> - List.fold_left IdSet.union IdSet.empty (List.map typs_req_fundef fds) - | DEF_let lb -> - fst (Rewriter.fold_letbind alg lb) - | DEF_mapdef _ | DEF_scattered _ | DEF_measure _ | DEF_loop_measures _ | DEF_impl _ | DEF_instantiation _ | DEF_outcome _ -> - unreachable (def_loc def) __POS__ - "Definition found in the Coq back-end that should have been rewritten away" + | DEF_internal_mutrec fds -> List.fold_left IdSet.union IdSet.empty (List.map typs_req_fundef fds) + | DEF_let lb -> fst (Rewriter.fold_letbind alg lb) + | DEF_mapdef _ | DEF_scattered _ | DEF_measure _ | DEF_loop_measures _ | DEF_impl _ | DEF_instantiation _ + | DEF_outcome _ -> + unreachable (def_loc def) __POS__ "Definition found in the Coq back-end that should have been rewritten away" in List.fold_left IdSet.union IdSet.empty (List.map typs_req_def defs) -let doc_type_union ctxt typ_name (Tu_aux(Tu_ty_id(typ,id),_)) = - separate space [doc_id_ctor ctxt id; colon; - doc_typ ctxt Env.empty typ; arrow; typ_name] +let doc_type_union ctxt typ_name (Tu_aux (Tu_ty_id (typ, id), _)) = + separate space [doc_id_ctor ctxt id; colon; doc_typ ctxt Env.empty typ; arrow; typ_name] (* For records and variants we declare the type parameters as implicit so that they're implicit in the constructors. Currently Coq also makes them implicit in the type, so undo that here. *) -let doc_reset_implicits id_pp typq = - separate space ([string "Arguments"; id_pp; colon; string "clear implicits"]) ^^ dot +let doc_reset_implicits id_pp typq = separate space [string "Arguments"; id_pp; colon; string "clear implicits"] ^^ dot (* let rec doc_range ctxt (BF_aux(r,_)) = match r with @@ -2350,231 +2317,260 @@ let rec doc_range ctxt (BF_aux(r,_)) = match r with *) (* TODO: check use of empty_ctxt below doesn't cause problems due to missing info *) -let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux(td, (l, annot))) = +let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, annot))) = let bare_ctxt = { empty_ctxt with avoid_target_names } in match td with - | TD_abbrev(id,typq,A_aux (A_typ typ, _)) -> - let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in - doc_op coloneq - (separate space [string "Definition"; doc_id_type types_mod avoid_target_names None id; - doc_typquant_items bare_ctxt Env.empty parens typq; - colon; string "Type"]) - (doc_typschm bare_ctxt Env.empty false typschm) ^^ dot ^^ twice hardline - | TD_abbrev(id,typq,A_aux (A_nexp nexp,_)) -> - let idpp = doc_id_type types_mod avoid_target_names None id in - doc_op coloneq - (separate space [string "Definition"; idpp; - doc_typquant_items bare_ctxt Env.empty parens typq; - colon; string "Z"]) - (doc_nexp bare_ctxt nexp) ^^ dot ^^ hardline ^^ - separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."] ^^ - twice hardline - | TD_abbrev(id,typq,A_aux (A_bool nc,_)) -> - let idpp = doc_id_type types_mod avoid_target_names None id in - doc_op coloneq - (separate space [string "Definition"; idpp; - doc_typquant_items bare_ctxt Env.empty parens typq; - colon; string "bool"]) - (doc_nc_exp bare_ctxt Env.empty nc) ^^ dot ^^ hardline ^^ - separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."] ^^ - twice hardline + | TD_abbrev (id, typq, A_aux (A_typ typ, _)) -> + let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in + doc_op coloneq + (separate space + [ + string "Definition"; + doc_id_type types_mod avoid_target_names None id; + doc_typquant_items bare_ctxt Env.empty parens typq; + colon; + string "Type"; + ] + ) + (doc_typschm bare_ctxt Env.empty false typschm) + ^^ dot ^^ twice hardline + | TD_abbrev (id, typq, A_aux (A_nexp nexp, _)) -> + let idpp = doc_id_type types_mod avoid_target_names None id in + doc_op coloneq + (separate space + [string "Definition"; idpp; doc_typquant_items bare_ctxt Env.empty parens typq; colon; string "Z"] + ) + (doc_nexp bare_ctxt nexp) + ^^ dot ^^ hardline + ^^ separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."] + ^^ twice hardline + | TD_abbrev (id, typq, A_aux (A_bool nc, _)) -> + let idpp = doc_id_type types_mod avoid_target_names None id in + doc_op coloneq + (separate space + [string "Definition"; idpp; doc_typquant_items bare_ctxt Env.empty parens typq; colon; string "bool"] + ) + (doc_nc_exp bare_ctxt Env.empty nc) + ^^ dot ^^ hardline + ^^ separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."] + ^^ twice hardline | TD_abbrev _ -> empty (* TODO? *) | TD_bitfield _ -> empty (* TODO? *) - | TD_record(id,typq,fs,_) -> - let fname fid = if prefix_recordtype && string_of_id id <> "regstate" - then concat [doc_id bare_ctxt id;string "_";doc_id_type types_mod avoid_target_names None fid;] - else doc_id_type types_mod avoid_target_names None fid in - let f_pp (typ,fid) = - concat [fname fid;space;colon;space;doc_typ bare_ctxt Env.empty typ; semi] in - let rectyp = match typq with - | TypQ_aux (TypQ_tq qs, _) -> - let quant_item = function - | QI_aux (QI_id (KOpt_aux (KOpt_kind (_, kid), _)), l) -> - [A_aux (A_nexp (Nexp_aux (Nexp_var kid, l)), l)] - | _ -> [] in - let targs = List.concat (List.map quant_item qs) in - mk_typ (Typ_app (id, targs)) - | TypQ_aux (TypQ_no_forall, _) -> mk_id_typ id in - let fs_doc = group (separate_map (break 1) f_pp fs) in - let type_id_pp = doc_id_type types_mod avoid_target_names None id in - let match_parameters = - match quant_kopts typq with - | [] -> empty - | l -> space ^^ separate_map space (fun _ -> underscore) l - in - let build_parameters = - let (kopts,_) = quant_split typq in - match kopts with - | [] -> empty - | _ -> space ^^ separate_map space (fun _ -> underscore) kopts - in - let doc_inhabited_req = function - | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_type,_),kid),_)),_) -> - Some (string "`{Inhabited " ^^ doc_var bare_ctxt kid ^^ string "}") - | _ -> None - in - let doc_update_field (_,fid) = - let idpp = fname fid in - let pp_field alt i (_,fid') = - if Id.compare fid fid' == 0 then string alt else - let id = "f" ^ string_of_int i in - string id + | TD_record (id, typq, fs, _) -> + let fname fid = + if prefix_recordtype && string_of_id id <> "regstate" then + concat [doc_id bare_ctxt id; string "_"; doc_id_type types_mod avoid_target_names None fid] + else doc_id_type types_mod avoid_target_names None fid in - match fs with - | [_] -> - string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" :=" ^//^ - string "{| " ^^ idpp ^^ string " := e |} (only parsing)." - | _ -> - string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" :=" ^//^ - string "match r with Build_" ^^ type_id_pp ^^ match_parameters ^^ space ^^ separate space (List.mapi (pp_field "_") fs) ^^ string " =>" ^//^ - string "Build_" ^^ type_id_pp ^^ build_parameters ^^ space ^^ separate space (List.mapi (pp_field "e") fs) ^//^ - string "end" ^^ dot - in - let updates_pp = separate hardline (List.map doc_update_field fs) in - let numfields = List.length fs in - let intros_pp s = - string " intros [" ^^ - separate space (list_init numfields (fun n -> string (s ^ string_of_int n))) ^^ - string "]." ^^ hardline - in - let eq_pp = - if IdSet.mem id generic_eq_types then - string "#[export] Instance Decidable_eq_" ^^ type_id_pp ^^ space ^^ colon ^/^ - string "forall (x y : " ^^ type_id_pp ^^ string "), Decidable (x = y)." ^^ - hardline ^^ intros_pp "x" ^^ intros_pp "y" ^^ - separate hardline (list_init numfields - (fun n -> - let ns = string_of_int n in - string ("cmp_record_field x" ^ ns ^ " y" ^ ns ^ "."))) ^^ - hardline ^^ - string "refine (Build_Decidable _ true _). subst. split; reflexivity." ^^ hardline ^^ - string "Defined." ^^ twice hardline - else empty - in - let typqs_pp = doc_typquant_items bare_ctxt Env.empty braces typq in - let inhabited_pp = - let reqs_pp = separate (break 1) (List.filter_map doc_inhabited_req (quant_items typq)) in - let params_pp = separate space (List.filter_map (quant_item_id_name bare_ctxt) (quant_items typq)) in - let field_pp (_,fid) = fname fid ^^ string " := inhabitant" in - group (prefix 2 1 (group (string "#[export] Instance dummy_" ^^ type_id_pp ^/^ typqs_pp ^/^ reqs_pp ^^ colon ^/^ - string "Inhabited (" ^^ type_id_pp ^^ space ^^ params_pp ^^ string ") := {")) - (prefix 2 1 (string "inhabitant := {|") - (separate_map (string ";" ^^ break 1) field_pp fs)) ^/^ - string "|} }.") ^^ hardline - in - let reset_implicits_pp = doc_reset_implicits type_id_pp typq in - doc_op coloneq - (separate space [string "Record"; type_id_pp; typqs_pp]) - ((*doc_typquant typq*) (braces (space ^^ align fs_doc ^^ space))) ^^ - dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ - eq_pp ^^ updates_pp ^^ hardline ^^ - inhabited_pp ^^ twice hardline - | TD_variant(id,typq,ar,_) -> - (match id with - | Id_aux ((Id "read_kind"),_) -> empty - | Id_aux ((Id "write_kind"),_) -> empty - | Id_aux ((Id "a64_barrier_domain"),_) -> empty - | Id_aux ((Id "a64_barrier_type"),_) -> empty - | Id_aux ((Id "barrier_kind"),_) -> empty - | Id_aux ((Id "trans_kind"),_) -> empty - | Id_aux ((Id "instruction_kind"),_) -> empty + let f_pp (typ, fid) = concat [fname fid; space; colon; space; doc_typ bare_ctxt Env.empty typ; semi] in + let rectyp = + match typq with + | TypQ_aux (TypQ_tq qs, _) -> + let quant_item = function + | QI_aux (QI_id (KOpt_aux (KOpt_kind (_, kid), _)), l) -> [A_aux (A_nexp (Nexp_aux (Nexp_var kid, l)), l)] + | _ -> [] + in + let targs = List.concat (List.map quant_item qs) in + mk_typ (Typ_app (id, targs)) + | TypQ_aux (TypQ_no_forall, _) -> mk_id_typ id + in + let fs_doc = group (separate_map (break 1) f_pp fs) in + let type_id_pp = doc_id_type types_mod avoid_target_names None id in + let match_parameters = + match quant_kopts typq with [] -> empty | l -> space ^^ separate_map space (fun _ -> underscore) l + in + let build_parameters = + let kopts, _ = quant_split typq in + match kopts with [] -> empty | _ -> space ^^ separate_map space (fun _ -> underscore) kopts + in + let doc_inhabited_req = function + | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _)), _) -> + Some (string "`{Inhabited " ^^ doc_var bare_ctxt kid ^^ string "}") + | _ -> None + in + let doc_update_field (_, fid) = + let idpp = fname fid in + let pp_field alt i (_, fid') = + if Id.compare fid fid' == 0 then string alt + else ( + let id = "f" ^ string_of_int i in + string id + ) + in + match fs with + | [_] -> + string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" :=" ^//^ string "{| " ^^ idpp + ^^ string " := e |} (only parsing)." + | _ -> + string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" :=" ^//^ string "match r with Build_" + ^^ type_id_pp ^^ match_parameters ^^ space + ^^ separate space (List.mapi (pp_field "_") fs) + ^^ string " =>" ^//^ string "Build_" ^^ type_id_pp ^^ build_parameters ^^ space + ^^ separate space (List.mapi (pp_field "e") fs) + ^//^ string "end" ^^ dot + in + let updates_pp = separate hardline (List.map doc_update_field fs) in + let numfields = List.length fs in + let intros_pp s = + string " intros [" + ^^ separate space (list_init numfields (fun n -> string (s ^ string_of_int n))) + ^^ string "]." ^^ hardline + in + let eq_pp = + if IdSet.mem id generic_eq_types then + string "#[export] Instance Decidable_eq_" + ^^ type_id_pp ^^ space ^^ colon ^/^ string "forall (x y : " ^^ type_id_pp ^^ string "), Decidable (x = y)." + ^^ hardline ^^ intros_pp "x" ^^ intros_pp "y" + ^^ separate hardline + (list_init numfields (fun n -> + let ns = string_of_int n in + string ("cmp_record_field x" ^ ns ^ " y" ^ ns ^ ".") + ) + ) + ^^ hardline + ^^ string "refine (Build_Decidable _ true _). subst. split; reflexivity." + ^^ hardline ^^ string "Defined." ^^ twice hardline + else empty + in + let typqs_pp = doc_typquant_items bare_ctxt Env.empty braces typq in + let inhabited_pp = + let reqs_pp = separate (break 1) (List.filter_map doc_inhabited_req (quant_items typq)) in + let params_pp = separate space (List.filter_map (quant_item_id_name bare_ctxt) (quant_items typq)) in + let field_pp (_, fid) = fname fid ^^ string " := inhabitant" in + group + (prefix 2 1 + (group + (string "#[export] Instance dummy_" ^^ type_id_pp ^/^ typqs_pp ^/^ reqs_pp ^^ colon + ^/^ string "Inhabited (" ^^ type_id_pp ^^ space ^^ params_pp ^^ string ") := {" + ) + ) + (prefix 2 1 (string "inhabitant := {|") (separate_map (string ";" ^^ break 1) field_pp fs)) + ^/^ string "|} }." + ) + ^^ hardline + in + let reset_implicits_pp = doc_reset_implicits type_id_pp typq in + doc_op coloneq + (separate space [string "Record"; type_id_pp; typqs_pp]) + ((*doc_typquant typq*) braces (space ^^ align fs_doc ^^ space)) + ^^ dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ eq_pp ^^ updates_pp ^^ hardline ^^ inhabited_pp + ^^ twice hardline + | TD_variant (id, typq, ar, _) -> ( + match id with + | Id_aux (Id "read_kind", _) -> empty + | Id_aux (Id "write_kind", _) -> empty + | Id_aux (Id "a64_barrier_domain", _) -> empty + | Id_aux (Id "a64_barrier_type", _) -> empty + | Id_aux (Id "barrier_kind", _) -> empty + | Id_aux (Id "trans_kind", _) -> empty + | Id_aux (Id "instruction_kind", _) -> empty (* | Id_aux ((Id "regfp"),_) -> empty - | Id_aux ((Id "niafp"),_) -> empty - | Id_aux ((Id "diafp"),_) -> empty *) - | Id_aux ((Id "option"),_) -> empty + | Id_aux ((Id "niafp"),_) -> empty + | Id_aux ((Id "diafp"),_) -> empty *) + | Id_aux (Id "option", _) -> empty | _ -> - let id_pp = doc_id_type types_mod avoid_target_names None id in - let typ_nm = separate space [id_pp; doc_typquant_items bare_ctxt Env.empty braces typq] in - let ar_doc = group (separate_map (break 1) (fun x -> pipe ^^ space ^^ doc_type_union bare_ctxt id_pp x) ar) in - let typ_pp = - (doc_op coloneq) - (concat [string "Inductive"; space; typ_nm]) - ((*doc_typquant typq*) ar_doc) in - let reset_implicits_pp = doc_reset_implicits id_pp typq in - let doc_dec_eq_req = function - | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_type,_),kid),_)),_) -> - (* TODO: collision avoidance for x y *) - Some (string "`{forall x y : " ^^ doc_var bare_ctxt kid ^^ string ", Decidable (x = y)}") - | _ -> None - in - let doc_inhabited_req = function - | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_type,_),kid),_)),_) -> - Some (string "`{Inhabited " ^^ doc_var bare_ctxt kid ^^ string "}") - | _ -> None - in - let typ_use_pp = - separate space (id_pp::List.filter_map (quant_item_id_name bare_ctxt) (quant_items typq)) - in - let eq_pp = - if IdSet.mem id generic_eq_types then - let eq_reqs_pp = - separate (break 1) (List.filter_map doc_dec_eq_req (quant_items typq)) - in - string "#[export] Instance Decidable_eq_" ^^ typ_nm ^^ space ^^ eq_reqs_pp ^^ colon ^/^ - string "forall (x y : " ^^ typ_use_pp ^^ string "), Decidable (x = y)." ^^ hardline ^^ - string "refine (Decidable_eq_from_dec (fun x y => _))." ^^ hardline ^^ - string "decide equality; refine (generic_dec _ _)." ^^ hardline ^^ - string "Defined." ^^ hardline - else empty - in - let inhabited_pp = - match ar with - | Tu_aux (Tu_ty_id(typ,example_id),_)::_ -> - let reqs_pp = - separate (break 1) (List.filter_map doc_inhabited_req (quant_items typq)) - in - group (prefix 2 1 (group (string "#[export] Instance dummy_" ^^ typ_nm ^^ space ^^ reqs_pp ^^ colon ^/^ - string "Inhabited (" ^^ typ_use_pp ^^ string ") := {")) - (prefix 2 1 (string "inhabitant :=") - (doc_id_ctor bare_ctxt example_id ^^ string " inhabitant")) ^/^ - string "}.") ^^ hardline - | [] -> Reporting.print_err l "Warning" ("Empty type: " ^ string_of_id id); - empty - in - typ_pp ^^ dot ^^ hardline ^^ - reset_implicits_pp ^^ hardline ^^ - eq_pp ^^ hardline ^^ - inhabited_pp ^^ hardline) - | TD_enum(id,enums,_) -> - (match id with - | Id_aux ((Id "read_kind"),_) -> empty - | Id_aux ((Id "write_kind"),_) -> empty - | Id_aux ((Id "a64_barrier_domain"),_) -> empty - | Id_aux ((Id "a64_barrier_type"),_) -> empty - | Id_aux ((Id "barrier_kind"),_) -> empty - | Id_aux ((Id "trans_kind"),_) -> empty - | Id_aux ((Id "instruction_kind"),_) -> empty - | Id_aux ((Id "cache_op_kind"),_) -> empty - | Id_aux ((Id "regfp"),_) -> empty - | Id_aux ((Id "niafp"),_) -> empty - | Id_aux ((Id "diafp"),_) -> empty + let id_pp = doc_id_type types_mod avoid_target_names None id in + let typ_nm = separate space [id_pp; doc_typquant_items bare_ctxt Env.empty braces typq] in + let ar_doc = group (separate_map (break 1) (fun x -> pipe ^^ space ^^ doc_type_union bare_ctxt id_pp x) ar) in + let typ_pp = (doc_op coloneq) (concat [string "Inductive"; space; typ_nm]) (*doc_typquant typq*) ar_doc in + let reset_implicits_pp = doc_reset_implicits id_pp typq in + let doc_dec_eq_req = function + | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _)), _) -> + (* TODO: collision avoidance for x y *) + Some (string "`{forall x y : " ^^ doc_var bare_ctxt kid ^^ string ", Decidable (x = y)}") + | _ -> None + in + let doc_inhabited_req = function + | QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _)), _) -> + Some (string "`{Inhabited " ^^ doc_var bare_ctxt kid ^^ string "}") + | _ -> None + in + let typ_use_pp = + separate space (id_pp :: List.filter_map (quant_item_id_name bare_ctxt) (quant_items typq)) + in + let eq_pp = + if IdSet.mem id generic_eq_types then ( + let eq_reqs_pp = separate (break 1) (List.filter_map doc_dec_eq_req (quant_items typq)) in + string "#[export] Instance Decidable_eq_" + ^^ typ_nm ^^ space ^^ eq_reqs_pp ^^ colon ^/^ string "forall (x y : " ^^ typ_use_pp + ^^ string "), Decidable (x = y)." ^^ hardline + ^^ string "refine (Decidable_eq_from_dec (fun x y => _))." + ^^ hardline + ^^ string "decide equality; refine (generic_dec _ _)." + ^^ hardline ^^ string "Defined." ^^ hardline + ) + else empty + in + let inhabited_pp = + match ar with + | Tu_aux (Tu_ty_id (typ, example_id), _) :: _ -> + let reqs_pp = separate (break 1) (List.filter_map doc_inhabited_req (quant_items typq)) in + group + (prefix 2 1 + (group + (string "#[export] Instance dummy_" ^^ typ_nm ^^ space ^^ reqs_pp ^^ colon + ^/^ string "Inhabited (" ^^ typ_use_pp ^^ string ") := {" + ) + ) + (prefix 2 1 (string "inhabitant :=") (doc_id_ctor bare_ctxt example_id ^^ string " inhabitant")) + ^/^ string "}." + ) + ^^ hardline + | [] -> + Reporting.print_err l "Warning" ("Empty type: " ^ string_of_id id); + empty + in + typ_pp ^^ dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ eq_pp ^^ hardline ^^ inhabited_pp ^^ hardline + ) + | TD_enum (id, enums, _) -> ( + match id with + | Id_aux (Id "read_kind", _) -> empty + | Id_aux (Id "write_kind", _) -> empty + | Id_aux (Id "a64_barrier_domain", _) -> empty + | Id_aux (Id "a64_barrier_type", _) -> empty + | Id_aux (Id "barrier_kind", _) -> empty + | Id_aux (Id "trans_kind", _) -> empty + | Id_aux (Id "instruction_kind", _) -> empty + | Id_aux (Id "cache_op_kind", _) -> empty + | Id_aux (Id "regfp", _) -> empty + | Id_aux (Id "niafp", _) -> empty + | Id_aux (Id "diafp", _) -> empty | _ -> - let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) (doc_id_ctor bare_ctxt) enums) in - let id_pp = doc_id_type types_mod avoid_target_names None id in - let typ_pp = (doc_op coloneq) - (concat [string "Inductive"; space; id_pp]) - (enums_doc) in - let eq1_pp = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in - let eq2_pp = string "#[export] Instance Decidable_eq_" ^^ id_pp ^^ space ^^ colon ^/^ - string "forall (x y : " ^^ id_pp ^^ string "), Decidable (x = y) :=" ^/^ - string "Decidable_eq_from_dec " ^^ id_pp ^^ string "_eq_dec." in - let inhabited_pp = - match enums with - | example_id::_ -> - group (prefix 2 1 (group (string "#[export] Instance dummy_" ^^ id_pp ^^ space ^^ colon ^/^ - string "Inhabited " ^^ id_pp ^^ string " := {")) - (string "inhabitant :=" ^/^ doc_id_ctor bare_ctxt example_id) ^/^ - string "}.") ^^ hardline - | [] -> Reporting.print_err l "Warning" ("Empty type: " ^ string_of_id id); - empty - in - typ_pp ^^ dot ^^ hardline ^^ eq1_pp ^^ hardline ^^ eq2_pp ^^ hardline ^^ inhabited_pp ^^ twice hardline) + let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) (doc_id_ctor bare_ctxt) enums) in + let id_pp = doc_id_type types_mod avoid_target_names None id in + let typ_pp = (doc_op coloneq) (concat [string "Inductive"; space; id_pp]) enums_doc in + let eq1_pp = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in + let eq2_pp = + string "#[export] Instance Decidable_eq_" + ^^ id_pp ^^ space ^^ colon ^/^ string "forall (x y : " ^^ id_pp ^^ string "), Decidable (x = y) :=" + ^/^ string "Decidable_eq_from_dec " ^^ id_pp ^^ string "_eq_dec." + in + let inhabited_pp = + match enums with + | example_id :: _ -> + group + (prefix 2 1 + (group + (string "#[export] Instance dummy_" ^^ id_pp ^^ space ^^ colon ^/^ string "Inhabited " ^^ id_pp + ^^ string " := {" + ) + ) + (string "inhabitant :=" ^/^ doc_id_ctor bare_ctxt example_id) + ^/^ string "}." + ) + ^^ hardline + | [] -> + Reporting.print_err l "Warning" ("Empty type: " ^ string_of_id id); + empty + in + typ_pp ^^ dot ^^ hardline ^^ eq1_pp ^^ hardline ^^ eq2_pp ^^ hardline ^^ inhabited_pp ^^ twice hardline + ) let args_of_typ l env typs = let arg i typ = let id = mk_id ("arg" ^ string_of_int i) in - (P_aux (P_id id, (l, mk_tannot env typ)), typ), - E_aux (E_id id, (l, mk_tannot env typ)) in + ((P_aux (P_id id, (l, mk_tannot env typ)), typ), E_aux (E_id id, (l, mk_tannot env typ))) + in List.split (List.mapi arg typs) (* Sail currently has a single pattern to match against a list of @@ -2585,122 +2581,117 @@ let args_of_typ l env typs = into multiple binders and reconstruct it in the function body. *) let rec untuple_args_pat typs (P_aux (paux, ((l, _) as annot)) as pat) = let env = env_of_annot annot in - let identity = (fun body -> body) in - match paux, typs with + let identity body = body in + match (paux, typs) with | P_tuple [], _ -> - let annot = (l, mk_tannot Env.empty unit_typ) in - [P_aux (P_lit (mk_lit L_unit), annot), unit_typ], identity - | P_tuple pats, _ -> List.combine pats typs, identity + let annot = (l, mk_tannot Env.empty unit_typ) in + ([(P_aux (P_lit (mk_lit L_unit), annot), unit_typ)], identity) + | P_tuple pats, _ -> (List.combine pats typs, identity) | P_wild, _ -> - let wild typ = P_aux (P_wild, (l, mk_tannot env typ)), typ in - List.map wild typs, identity + let wild typ = (P_aux (P_wild, (l, mk_tannot env typ)), typ) in + (List.map wild typs, identity) | P_typ (_, pat), _ -> untuple_args_pat typs pat - | P_as _, _::_::_ | P_id _, _::_::_ -> - let argpats, argexps = args_of_typ l env typs in - let argexp = E_aux (E_tuple argexps, annot) in - let bindargs (E_aux (_, bannot) as body) = - E_aux (E_let (LB_aux (LB_val (pat, argexp), annot), body), bannot) in - argpats, bindargs - | _, [typ] -> - [pat,typ], identity - | _, _ -> - unreachable l __POS__ "Unexpected pattern/type combination" + | P_as _, _ :: _ :: _ | P_id _, _ :: _ :: _ -> + let argpats, argexps = args_of_typ l env typs in + let argexp = E_aux (E_tuple argexps, annot) in + let bindargs (E_aux (_, bannot) as body) = E_aux (E_let (LB_aux (LB_val (pat, argexp), annot), body), bannot) in + (argpats, bindargs) + | _, [typ] -> ([(pat, typ)], identity) + | _, _ -> unreachable l __POS__ "Unexpected pattern/type combination" let doc_fun_body ctxt is_monadic exp = let doc_exp = doc_exp ctxt false exp in - if Option.is_some ctxt.early_ret - then - if is_monadic - then align (string "catch_early_return" ^//^ parens (doc_exp)) - else align (string "pure_early_return" ^//^ parens (doc_exp)) + if Option.is_some ctxt.early_ret then + if is_monadic then align (string "catch_early_return" ^//^ parens doc_exp) + else align (string "pure_early_return" ^//^ parens doc_exp) else doc_exp (* Coq doesn't support "as" patterns well in Definition binders, so we push them over to the r.h.s. of the := *) -let demote_as_pattern i (P_aux (_,p_annot) as pat,typ) = +let demote_as_pattern i ((P_aux (_, p_annot) as pat), typ) = let open Rewriter in - if fst (fold_pat ({ (compute_pat_alg false (||)) with p_as = (fun ((_,p),id) -> true, P_as (p,id)) }) pat) - then - let id = mk_id ("arg" ^ string_of_int i) in (* TODO: name conflicts *) - (P_aux (P_id id, p_annot),typ), - fun (E_aux (_,e_ann) as e) -> - E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) - else (pat,typ), fun e -> e - -let pat_is_plain_binder env (P_aux (p,_)) = + if fst (fold_pat { (compute_pat_alg false ( || )) with p_as = (fun ((_, p), id) -> (true, P_as (p, id))) } pat) then ( + let id = mk_id ("arg" ^ string_of_int i) in + (* TODO: name conflicts *) + ( (P_aux (P_id id, p_annot), typ), + fun (E_aux (_, e_ann) as e) -> E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)), p_annot), e), e_ann) + ) + ) + else ((pat, typ), fun e -> e) + +let pat_is_plain_binder env (P_aux (p, _)) = match p with - | P_id id - | P_typ (_,P_aux (P_id id,_)) - when not (is_enum env id) -> Some (Some id) + | (P_id id | P_typ (_, P_aux (P_id id, _))) when not (is_enum env id) -> Some (Some id) | P_wild -> Some None | _ -> None -let demote_all_patterns env i (P_aux (p,p_annot) as pat,typ) = +let demote_all_patterns env i ((P_aux (p, p_annot) as pat), typ) = match pat_is_plain_binder env pat with | Some id -> - if Option.is_none (is_auto_decomposed_exist empty_ctxt env typ) (* TODO? *) - then (pat,typ), fun e -> e - else begin - match id with - | Some id -> - (P_aux (P_id id, p_annot),typ), - fun (E_aux (_,e_ann) as e) -> - E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) - | None -> (P_aux (P_wild, p_annot),typ), fun e -> e - end + if Option.is_none (is_auto_decomposed_exist empty_ctxt env typ) (* TODO? *) then ((pat, typ), fun e -> e) + else begin + match id with + | Some id -> + ( (P_aux (P_id id, p_annot), typ), + fun (E_aux (_, e_ann) as e) -> + E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)), p_annot), e), e_ann) + ) + | None -> ((P_aux (P_wild, p_annot), typ), fun e -> e) + end | None -> - let id = mk_id ("arg" ^ string_of_int i) in (* TODO: name conflicts *) - (P_aux (P_id id, p_annot),typ), - fun (E_aux (_,e_ann) as e) -> - E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) + let id = mk_id ("arg" ^ string_of_int i) in + (* TODO: name conflicts *) + ( (P_aux (P_id id, p_annot), typ), + fun (E_aux (_, e_ann) as e) -> E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)), p_annot), e), e_ann) + ) (* Note equality constraints between arguments and nexps in a comment, except in the case that they've been merged. *) let rec atom_constraint ctxt (pat, typ) = let typ = Env.base_typ_of (env_of_pat pat) typ in - match pat, typ with - | P_aux (P_id id, _), - Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp nexp,_)]),_) -> - (match nexp with - (* When the kid is mapped to the id, we don't need a constraint *) - | Nexp_aux (Nexp_var kid,_) - when (try Id.compare (Util.option_get_exn Not_found (KBindings.find kid ctxt.kid_id_renames)) id == 0 with _ -> false) -> - None - | _ -> - Some (comment (doc_op (string "=?") (doc_id ctxt id) (doc_nexp ctxt nexp)))) - | P_aux (P_typ (_,p),_), _ -> atom_constraint ctxt (p, typ) + match (pat, typ) with + | P_aux (P_id id, _), Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp nexp, _)]), _) -> ( + match nexp with + (* When the kid is mapped to the id, we don't need a constraint *) + | Nexp_aux (Nexp_var kid, _) + when try Id.compare (Util.option_get_exn Not_found (KBindings.find kid ctxt.kid_id_renames)) id == 0 + with _ -> false -> + None + | _ -> Some (comment (doc_op (string "=?") (doc_id ctxt id) (doc_nexp ctxt nexp))) + ) + | P_aux (P_typ (_, p), _), _ -> atom_constraint ctxt (p, typ) | _ -> None let all_ids pexp = let open Rewriter in - fold_pexp ( - { (pure_exp_alg IdSet.empty IdSet.union) with + fold_pexp + { + (pure_exp_alg IdSet.empty IdSet.union) with e_id = (fun id -> IdSet.singleton id); e_ref = (fun id -> IdSet.singleton id); - e_app = (fun (id,ids) -> - List.fold_left IdSet.union (IdSet.singleton id) ids); - e_app_infix = (fun (ids1,id,ids2) -> - IdSet.add id (IdSet.union ids1 ids2)); - e_for = (fun (id,ids1,ids2,ids3,_,ids4) -> - IdSet.add id (IdSet.union ids1 (IdSet.union ids2 (IdSet.union ids3 ids4)))); + e_app = (fun (id, ids) -> List.fold_left IdSet.union (IdSet.singleton id) ids); + e_app_infix = (fun (ids1, id, ids2) -> IdSet.add id (IdSet.union ids1 ids2)); + e_for = + (fun (id, ids1, ids2, ids3, _, ids4) -> + IdSet.add id (IdSet.union ids1 (IdSet.union ids2 (IdSet.union ids3 ids4))) + ); le_id = IdSet.singleton; - le_app = (fun (id,ids) -> - List.fold_left IdSet.union (IdSet.singleton id) ids); - le_typ = (fun (_,id) -> IdSet.singleton id); - pat_alg = { (pure_pat_alg IdSet.empty IdSet.union) with - p_as = (fun (ids,id) -> IdSet.add id ids); - p_id = IdSet.singleton; - p_app = (fun (id,ids) -> - List.fold_left IdSet.union (IdSet.singleton id) ids); - } - }) pexp - -let tyvars_of_typquant (TypQ_aux (tq,_)) = + le_app = (fun (id, ids) -> List.fold_left IdSet.union (IdSet.singleton id) ids); + le_typ = (fun (_, id) -> IdSet.singleton id); + pat_alg = + { + (pure_pat_alg IdSet.empty IdSet.union) with + p_as = (fun (ids, id) -> IdSet.add id ids); + p_id = IdSet.singleton; + p_app = (fun (id, ids) -> List.fold_left IdSet.union (IdSet.singleton id) ids); + }; + } + pexp + +let tyvars_of_typquant (TypQ_aux (tq, _)) = match tq with | TypQ_no_forall -> KidSet.empty - | TypQ_tq qs -> List.fold_left KidSet.union KidSet.empty - (List.map tyvars_of_quant_item qs) + | TypQ_tq qs -> List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item qs) let mk_kid_renames avoid_target_names ids_to_avoid kids = let map_id = function @@ -2708,93 +2699,97 @@ let mk_kid_renames avoid_target_names ids_to_avoid kids = | Id_aux (Operator _, _) -> None in let ids = StringSet.of_list (List.filter_map map_id (IdSet.elements ids_to_avoid)) in - let check_kid kid (newkids,rebindings) = + let check_kid kid (newkids, rebindings) = let rec check kid1 = let kid_string = fix_id avoid_target_names true (string_of_kid kid1) in - if StringSet.mem kid_string ids - then let kid2 = match kid1 with Kid_aux (Var x,l) -> Kid_aux (Var (x ^ "0"),l) in - check kid2 - else - KidSet.add kid1 newkids, KBindings.add kid kid1 rebindings - in check kid - in snd (KidSet.fold check_kid kids (kids, KBindings.empty)) + if StringSet.mem kid_string ids then ( + let kid2 = match kid1 with Kid_aux (Var x, l) -> Kid_aux (Var (x ^ "0"), l) in + check kid2 + ) + else (KidSet.add kid1 newkids, KBindings.add kid kid1 rebindings) + in + check kid + in + snd (KidSet.fold check_kid kids (kids, KBindings.empty)) let merge_kids_atoms pats = - let try_eliminate (acc,gone,map,seen) (pat,typ) = + let try_eliminate (acc, gone, map, seen) (pat, typ) = let tryon maybe_id env typ = let merge kid l = - if KidSet.mem kid seen then + if KidSet.mem kid seen then ( let () = Reporting.print_err l "merge_kids_atoms" - ("want to merge tyvar and argument for " ^ string_of_kid kid ^ - " but rearranging arguments isn't supported yet") in - (pat,typ)::acc,gone,map,seen - else - let pat,id = match maybe_id with - | Some id -> pat,id - (* TODO: name clashes *) - | None -> let id = id_of_kid kid in - P_aux (P_id id,match pat with P_aux (_,ann) -> ann), id + ("want to merge tyvar and argument for " ^ string_of_kid kid + ^ " but rearranging arguments isn't supported yet" + ) in - (pat,typ)::acc, - KidSet.add kid gone, KBindings.add kid (Some id) map, KidSet.add kid seen + ((pat, typ) :: acc, gone, map, seen) + ) + else ( + let pat, id = + match maybe_id with + | Some id -> (pat, id) (* TODO: name clashes *) + | None -> + let id = id_of_kid kid in + (P_aux (P_id id, match pat with P_aux (_, ann) -> ann), id) + in + ((pat, typ) :: acc, KidSet.add kid gone, KBindings.add kid (Some id) map, KidSet.add kid seen) + ) in match Type_check.destruct_atom_nexp env typ with - | Some (Nexp_aux (Nexp_var kid,l)) -> merge kid l - | _ -> - match Type_check.destruct_atom_bool env typ with - | Some (NC_aux (NC_var kid,l)) -> merge kid l - | _ -> (pat,typ)::acc,gone,map,KidSet.union seen (tyvars_of_typ typ) + | Some (Nexp_aux (Nexp_var kid, l)) -> merge kid l + | _ -> ( + match Type_check.destruct_atom_bool env typ with + | Some (NC_aux (NC_var kid, l)) -> merge kid l + | _ -> ((pat, typ) :: acc, gone, map, KidSet.union seen (tyvars_of_typ typ)) + ) in - match pat,typ with - | P_aux (P_id id, ann), typ - | P_aux (P_typ (_,P_aux (P_id id, ann)),_), typ -> - tryon (Some id) (env_of_annot ann) typ - | P_aux (P_wild, ann), typ -> - tryon None (env_of_annot ann) typ - | _ -> (pat,typ)::acc,gone,map,KidSet.union seen (tyvars_of_typ typ) + match (pat, typ) with + | P_aux (P_id id, ann), typ | P_aux (P_typ (_, P_aux (P_id id, ann)), _), typ -> + tryon (Some id) (env_of_annot ann) typ + | P_aux (P_wild, ann), typ -> tryon None (env_of_annot ann) typ + | _ -> ((pat, typ) :: acc, gone, map, KidSet.union seen (tyvars_of_typ typ)) in - let r_pats,gone,map,_ = List.fold_left try_eliminate ([],KidSet.empty, KBindings.empty, KidSet.empty) pats in - List.rev r_pats,gone,map - + let r_pats, gone, map, _ = List.fold_left try_eliminate ([], KidSet.empty, KBindings.empty, KidSet.empty) pats in + (List.rev r_pats, gone, map) let merge_var_patterns map pats = - let map,pats = List.fold_left (fun (map,pats) (pat, typ) -> - match pat with - | P_aux (P_var (P_aux (P_id id,_), TP_aux (TP_var kid,_)),ann) -> - KBindings.add kid (Some id) map, (P_aux (P_id id,ann), typ) :: pats - | _ -> map, (pat,typ)::pats) (map,[]) pats - in map, List.rev pats + let map, pats = + List.fold_left + (fun (map, pats) (pat, typ) -> + match pat with + | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), ann) -> + (KBindings.add kid (Some id) map, (P_aux (P_id id, ann), typ) :: pats) + | _ -> (map, (pat, typ) :: pats) + ) + (map, []) pats + in + (map, List.rev pats) type mutrec_pos = NotMutrec | FirstFn | LaterFn -let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_set (FCL_aux(FCL_funcl(id, pexp), annot)) = +let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_set + (FCL_aux (FCL_funcl (id, pexp), annot)) = let env = env_of_tannot (snd annot) in - let (tq,typ) = Env.get_val_spec_orig id env in - let (arg_typs, ret_typ, _) = match typ with - | Typ_aux (Typ_fn (arg_typs, ret_typ),_) -> arg_typs, ret_typ, no_effect + let tq, typ = Env.get_val_spec_orig id env in + let arg_typs, ret_typ, _ = + match typ with + | Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> (arg_typs, ret_typ, no_effect) | _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type") in let is_monadic = not (Effects.function_is_pure id effect_info) in let ids_to_avoid = all_ids pexp in let bound_kids = tyvars_of_typquant tq in - let pat,guard,exp,(l,_) = destruct_pexp pexp in + let pat, guard, exp, (l, _) = destruct_pexp pexp in let pats, bind = untuple_args_pat arg_typs pat in (* Fixpoint definitions can only use simple binders, but even Definitions can't handle as patterns *) - let pattern_elim = - match rec_opt with - | Rec_aux (Rec_nonrec,_) -> demote_as_pattern - | _ -> demote_all_patterns env - in + let pattern_elim = match rec_opt with Rec_aux (Rec_nonrec, _) -> demote_as_pattern | _ -> demote_all_patterns env in let pats, binds = List.split (List.mapi pattern_elim pats) in let pats, eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in let kids_used = KidSet.diff bound_kids eliminated_kids in - let is_measured = match rec_opt with - | Rec_aux (Rec_measure _,_) -> true - | _ -> false - in + let is_measured = match rec_opt with Rec_aux (Rec_measure _, _) -> true | _ -> false in let kir_rev = KBindings.fold (fun kid idopt m -> match idopt with Some id -> Bindings.add id kid m | None -> m) @@ -2802,39 +2797,53 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_ in let simple_type_equations = Type_check.instantiate_simple_equations (quant_items tq) in let constant_kids = - kbindings_filter_map (fun kid inst -> - match inst with - | A_aux (A_nexp (Nexp_aux (Nexp_constant value, _)), _) -> Some value - | _ -> None) simple_type_equations + kbindings_filter_map + (fun kid inst -> match inst with A_aux (A_nexp (Nexp_aux (Nexp_constant value, _)), _) -> Some value | _ -> None) + simple_type_equations in let ctxt0 = - { types_mod = types_mod; - early_ret = None; (* filled in below *) + { + types_mod; + early_ret = None; + (* filled in below *) kid_renames = mk_kid_renames avoid_target_names ids_to_avoid kids_used; kid_id_renames = kid_to_arg_rename; kid_id_renames_rev = kir_rev; constant_kids; bound_nvars = bound_kids; - build_at_return = None; (* filled in below *) - recursive_fns = Bindings.empty; (* filled in later *) - debug = List.mem (string_of_id id) (!opt_debug_on); - ret_typ_pp = PPrint.empty; (* filled in below *) + build_at_return = None; + (* filled in below *) + recursive_fns = Bindings.empty; + (* filled in later *) + debug = List.mem (string_of_id id) !opt_debug_on; + ret_typ_pp = PPrint.empty; + (* filled in below *) effect_info; is_monadic; avoid_target_names; - } in - let ctxt = { ctxt0 with - early_ret = if contains_early_return exp then Some ret_typ else None; - ret_typ_pp = doc_typ ctxt0 Env.empty ret_typ - } in + } + in + let ctxt = + { + ctxt0 with + early_ret = (if contains_early_return exp then Some ret_typ else None); + ret_typ_pp = doc_typ ctxt0 Env.empty ret_typ; + } + in let () = debug ctxt (lazy ("Function " ^ string_of_id id)); debug ctxt (lazy (" return type " ^ string_of_typ ret_typ)); debug ctxt (lazy (if is_monadic then " monadic" else " pure")); - debug ctxt (lazy (" kid_id_renames " ^ String.concat ", " (List.map - (fun (kid,id) -> string_of_kid kid ^ " |-> " ^ - match id with Some id -> string_of_id id | None -> "<>") - (KBindings.bindings kid_to_arg_rename)))) + debug ctxt + ( lazy + (" kid_id_renames " + ^ String.concat ", " + (List.map + (fun (kid, id) -> string_of_kid kid ^ " |-> " ^ match id with Some id -> string_of_id id | None -> "<>") + (KBindings.bindings kid_to_arg_rename) + ) + ) + ) in (* Put the constraints after pattern matching so that any type variable that's @@ -2842,12 +2851,12 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_ let quantspp, constrspp = doc_typquant_items_separate ctxt env braces tq in let is_fixed_by_eqn env typ = match destruct_atom_nexp env typ with - | Some (Nexp_aux (Nexp_var kid,_)) -> KBindings.find_opt kid constant_kids + | Some (Nexp_aux (Nexp_var kid, _)) -> KBindings.find_opt kid constant_kids | _ -> None in let exp = List.fold_left (fun body f -> f body) (bind exp) binds in let used_a_pattern = ref false in - let doc_binder (P_aux (p,ann) as pat, typ) = + let doc_binder ((P_aux (p, ann) as pat), typ) = let env = env_of_annot ann in let exp_typ = Env.expand_synonyms env typ in let () = @@ -2860,148 +2869,153 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_ let id_pp = match id with Some id -> doc_id ctxt id | None -> underscore in match is_fixed_by_eqn env exp_typ with | Some constant -> - parens (separate space [id_pp; colon; doc_typ ctxt Env.empty typ; string ":="; string (Big_int.to_string constant)]) - | None -> - match classify_ex_type ctxt env ?binding:id exp_typ with - | _, _, typ' -> - parens (separate space [id_pp; colon; doc_typ ctxt Env.empty typ']) + parens + (separate space + [id_pp; colon; doc_typ ctxt Env.empty typ; string ":="; string (Big_int.to_string constant)] + ) + | None -> ( + match classify_ex_type ctxt env ?binding:id exp_typ with + | _, _, typ' -> parens (separate space [id_pp; colon; doc_typ ctxt Env.empty typ']) + ) end | None -> - let typ = - match classify_ex_type ctxt env ~binding:id exp_typ with - | _, _, typ' -> typ' - in - (used_a_pattern := true; - squote ^^ parens (separate space [doc_pat ctxt true true (pat, exp_typ); colon; doc_typ ctxt Env.empty typ])) + let typ = match classify_ex_type ctxt env ~binding:id exp_typ with _, _, typ' -> typ' in + used_a_pattern := true; + squote ^^ parens (separate space [doc_pat ctxt true true (pat, exp_typ); colon; doc_typ ctxt Env.empty typ]) in let patspp = flow_map (break 1) doc_binder pats in let atom_constrs = List.filter_map (atom_constraint ctxt) pats in let retpp = (* TODO: again, probably should provide proper environment *) - if is_monadic - then string "M" ^^ space ^^ parens ctxt.ret_typ_pp - else doc_typ ctxt Env.empty ret_typ + if is_monadic then string "M" ^^ space ^^ parens ctxt.ret_typ_pp else doc_typ ctxt Env.empty ret_typ in let idpp = doc_id ctxt id in - let intropp, accpp, measurepp, fixupspp = match rec_opt with - | Rec_aux (Rec_measure _,_) -> - let fixupspp = - List.filter_map (fun (pat,typ) -> - match pat_is_plain_binder env pat with - | Some (Some id) -> begin - match destruct_exist_plain (Env.expand_synonyms env (expand_range_type typ)) with - | Some (_, NC_aux (NC_true,_), _) -> None - | Some ([KOpt_aux (KOpt_kind (_, kid), _)], nc, - Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_)) - when Kid.compare kid kid' == 0 -> - Some (string "let " ^^ doc_id ctxt id ^^ string " := projT1 " ^^ doc_id ctxt id ^^ string " in") - | _ -> None - end - | _ -> None) pats - in - string "Fixpoint", - [parens (string "_acc : Acc (Zwf 0) _reclimit")], - [string "{struct _acc}"], - fixupspp - | Rec_aux (r,_) -> - let d = match r with Rec_nonrec -> "Definition" | _ -> "Fixpoint" in - string d, [], [], [] - in - let intropp = - match mutrec with - | NotMutrec -> intropp - | FirstFn -> string "Fixpoint" - | LaterFn -> string "with" + let intropp, accpp, measurepp, fixupspp = + match rec_opt with + | Rec_aux (Rec_measure _, _) -> + let fixupspp = + List.filter_map + (fun (pat, typ) -> + match pat_is_plain_binder env pat with + | Some (Some id) -> begin + match destruct_exist_plain (Env.expand_synonyms env (expand_range_type typ)) with + | Some (_, NC_aux (NC_true, _), _) -> None + | Some + ( [KOpt_aux (KOpt_kind (_, kid), _)], + nc, + Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _)]), _) + ) + when Kid.compare kid kid' == 0 -> + Some (string "let " ^^ doc_id ctxt id ^^ string " := projT1 " ^^ doc_id ctxt id ^^ string " in") + | _ -> None + end + | _ -> None + ) + pats + in + (string "Fixpoint", [parens (string "_acc : Acc (Zwf 0) _reclimit")], [string "{struct _acc}"], fixupspp) + | Rec_aux (r, _) -> + let d = match r with Rec_nonrec -> "Definition" | _ -> "Fixpoint" in + (string d, [], [], []) in + let intropp = match mutrec with NotMutrec -> intropp | FirstFn -> string "Fixpoint" | LaterFn -> string "with" in let ctxt = - if is_measured then - { ctxt with recursive_fns = Bindings.singleton id (List.length quantspp, 0) } - else ctxt in - let _ = match guard with + if is_measured then { ctxt with recursive_fns = Bindings.singleton id (List.length quantspp, 0) } else ctxt + in + let _ = + match guard with | None -> () | _ -> - raise (Reporting.err_unreachable l __POS__ - "guarded pattern expression should have been rewritten before pretty-printing") in - (group (flow (break 1) ([intropp; idpp] @ quantspp @ [patspp] @ constrspp @ atom_constrs @ accpp) ^/^ - flow (break 1) (measurepp @ [colon; retpp])), - ctxt, - (exp, is_monadic, fixupspp)) - + raise + (Reporting.err_unreachable l __POS__ + "guarded pattern expression should have been rewritten before pretty-printing" + ) + in + ( group + (flow (break 1) ([intropp; idpp] @ quantspp @ [patspp] @ constrspp @ atom_constrs @ accpp) + ^/^ flow (break 1) (measurepp @ [colon; retpp]) + ), + ctxt, + (exp, is_monadic, fixupspp) + ) let doc_funcl_body ctxt (exp, is_monadic, fixupspp) = let bodypp = doc_fun_body ctxt is_monadic exp in let bodypp = - if is_monadic - then + if is_monadic then (* Sometimes a function is marked effectful by effect inference when it's not (especially mappings)... TODO: this seems bad!? *) - if not (effectful (effect_of exp)) - then string "returnM" ^/^ parens bodypp - else bodypp + if not (effectful (effect_of exp)) then string "returnM" ^/^ parens bodypp else bodypp else if Option.is_some ctxt.early_ret then bodypp - else bodypp in + else bodypp + in let bodypp = separate (break 1) (fixupspp @ [bodypp]) in group bodypp -let get_id = function - | [] -> failwith "FD_function with empty list" - | (FCL_aux (FCL_funcl (id,_),_))::_ -> id +let get_id = function [] -> failwith "FD_function with empty list" | FCL_aux (FCL_funcl (id, _), _) :: _ -> id (* Coq doesn't support multiple clauses for a single function joined by "and". However, all the funcls should have been merged by the merge_funcls rewrite now. *) -let doc_fundef_rhs types_mod avoid_target_names effect_info ?(mutrec=NotMutrec) rec_set (FD_aux(FD_function(r, typa, funcls),(l,_))) = +let doc_fundef_rhs types_mod avoid_target_names effect_info ?(mutrec = NotMutrec) rec_set + (FD_aux (FD_function (r, typa, funcls), (l, _))) = match funcls with | [] -> unreachable l __POS__ "function with no clauses" | [funcl] -> doc_funcl_init types_mod avoid_target_names effect_info mutrec r ~rec_set funcl - | (FCL_aux (FCL_funcl (id,_),_))::_ -> unreachable l __POS__ ("function " ^ string_of_id id ^ " has multiple clauses in backend") + | FCL_aux (FCL_funcl (id, _), _) :: _ -> + unreachable l __POS__ ("function " ^ string_of_id id ^ " has multiple clauses in backend") let doc_mutrec types_mod avoid_target_names effect_info rec_set = function | [] -> failwith "DEF_internal_mutrec with empty function list" - | fundef::fundefs -> - let pre1,ctxt1,details1 = doc_fundef_rhs types_mod avoid_target_names effect_info ~mutrec:FirstFn rec_set fundef in - let pren,ctxtn,detailsn = Util.split3 (List.map (doc_fundef_rhs types_mod avoid_target_names effect_info ~mutrec:LaterFn rec_set) fundefs) in - let recursive_fns = List.fold_left (fun m c -> Bindings.union (fun _ x _ -> Some x) m c.recursive_fns) ctxt1.recursive_fns ctxtn in - let ctxts = List.map (fun c -> { c with recursive_fns }) (ctxt1::ctxtn) in - let bodies = List.map2 doc_funcl_body ctxts (details1::detailsn) in - let idpps = List.map (fun fd -> string (string_of_id (id_of_fundef fd))) (fundef::fundefs) in - let bodies = List.map2 (fun idpp b -> surround 3 0 (string "(*" ^^ idpp ^^ string "*) exact (") b (string ").")) idpps bodies in - let pres = pre1::pren in - separate hardline pres ^^ dot ^^ hardline ^^ - separate hardline bodies ^^ - break 1 ^^ string "Defined." ^^ hardline + | fundef :: fundefs -> + let pre1, ctxt1, details1 = + doc_fundef_rhs types_mod avoid_target_names effect_info ~mutrec:FirstFn rec_set fundef + in + let pren, ctxtn, detailsn = + Util.split3 (List.map (doc_fundef_rhs types_mod avoid_target_names effect_info ~mutrec:LaterFn rec_set) fundefs) + in + let recursive_fns = + List.fold_left (fun m c -> Bindings.union (fun _ x _ -> Some x) m c.recursive_fns) ctxt1.recursive_fns ctxtn + in + let ctxts = List.map (fun c -> { c with recursive_fns }) (ctxt1 :: ctxtn) in + let bodies = List.map2 doc_funcl_body ctxts (details1 :: detailsn) in + let idpps = List.map (fun fd -> string (string_of_id (id_of_fundef fd))) (fundef :: fundefs) in + let bodies = + List.map2 (fun idpp b -> surround 3 0 (string "(*" ^^ idpp ^^ string "*) exact (") b (string ").")) idpps bodies + in + let pres = pre1 :: pren in + separate hardline pres ^^ dot ^^ hardline ^^ separate hardline bodies ^^ break 1 ^^ string "Defined." ^^ hardline let doc_funcl types_mod avoid_target_names effect_info mutrec r funcl = - let pre,ctxt,details = doc_funcl_init types_mod avoid_target_names effect_info mutrec r funcl in + let pre, ctxt, details = doc_funcl_init types_mod avoid_target_names effect_info mutrec r funcl in let body = doc_funcl_body ctxt details in - pre,body + (pre, body) -let doc_fundef types_mod avoid_target_names effect_info (FD_aux(FD_function(r, typa, fcls),fannot)) = +let doc_fundef types_mod avoid_target_names effect_info (FD_aux (FD_function (r, typa, fcls), fannot)) = match fcls with | [] -> failwith "FD_function with empty function list" - | [FCL_aux (FCL_funcl(id,_),annot) as funcl] - when not (Env.is_extern id (env_of_tannot (snd annot)) "coq") -> - begin - let pre,body = doc_funcl types_mod avoid_target_names effect_info NotMutrec r funcl in - match r with - | Rec_aux (Rec_measure _,_) -> - group (pre ^^ dot ^^ hardline ^^ - surround 3 0 (string "exact (") body (string ").") ^^ - hardline ^^ string "Defined.") ^^ hardline - | _ -> group (prefix 3 1 (pre ^^ space ^^ coloneq) (body ^^ dot)) - end - | [_] -> empty (* extern *) + | [(FCL_aux (FCL_funcl (id, _), annot) as funcl)] when not (Env.is_extern id (env_of_tannot (snd annot)) "coq") -> + begin + let pre, body = doc_funcl types_mod avoid_target_names effect_info NotMutrec r funcl in + match r with + | Rec_aux (Rec_measure _, _) -> + group + (pre ^^ dot ^^ hardline + ^^ surround 3 0 (string "exact (") body (string ").") + ^^ hardline ^^ string "Defined." + ) + ^^ hardline + | _ -> group (prefix 3 1 (pre ^^ space ^^ coloneq) (body ^^ dot)) + end + | [_] -> empty (* extern *) | _ -> failwith "FD_function with more than one clause" - - let doc_dec avoid_target_names (DEC_aux (reg, (l, _))) = let bare_ctxt = { empty_ctxt with avoid_target_names } in match reg with | DEC_reg (typ, id, None) -> empty - (* + (* let env = env_of_annot annot in let rt = Env.base_typ_of env typ in if is_vector_typ rt then @@ -3020,183 +3034,229 @@ let doc_dec avoid_target_names (DEC_aux (reg, (l, _))) = else raise (Reporting.err_unreachable l __POS__ ("can't deal with register type " ^ string_of_typ typ)) else raise (Reporting.err_unreachable l __POS__ ("can't deal with register type " ^ string_of_typ typ)) *) (* For now treat configuration registers as regular registers *) - | DEC_reg (typ, id, Some exp) -> - empty - (*separate space [string "Definition"; doc_id bare_ctxt id; coloneq; doc_exp empty_ctxt false exp] ^^ dot ^^ hardline*) + | DEC_reg (typ, id, Some exp) -> empty +(*separate space [string "Definition"; doc_id bare_ctxt id; coloneq; doc_exp empty_ctxt false exp] ^^ dot ^^ hardline*) let is_field_accessor regtypes fdef = let is_field_of regtyp field = - List.exists (fun (tname, (_, _, fields)) -> tname = regtyp && - List.exists (fun (_, fid) -> string_of_id fid = field) fields) regtypes in + List.exists + (fun (tname, (_, _, fields)) -> tname = regtyp && List.exists (fun (_, fid) -> string_of_id fid = field) fields) + regtypes + in match Util.split_on_char '_' (string_of_id (id_of_fundef fdef)) with - | [access; regtyp; field] -> - (access = "get" || access = "set") && is_field_of regtyp field + | [access; regtyp; field] -> (access = "get" || access = "set") && is_field_of regtyp field | _ -> false - let int_of_field_index tname fid nexp = match int_of_nexp_opt nexp with | Some i -> i - | None -> raise (Reporting.err_typ Parse_ast.Unknown - ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname)) + | None -> + raise + (Reporting.err_typ Parse_ast.Unknown + ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname) + ) let doc_regtype_fields avoid_target_names (tname, (n1, n2, fields)) = let bare_ctxt = { empty_ctxt with avoid_target_names } in let const_int fid idx = int_of_field_index tname fid idx in - let i1, i2 = match n1, n2 with - | Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> i1, i2 - | _ -> raise (Reporting.err_typ Parse_ast.Unknown - ("Non-constant indices in register type " ^ tname)) in + let i1, i2 = + match (n1, n2) with + | Nexp_aux (Nexp_constant i1, _), Nexp_aux (Nexp_constant i2, _) -> (i1, i2) + | _ -> raise (Reporting.err_typ Parse_ast.Unknown ("Non-constant indices in register type " ^ tname)) + in let dir_b = i1 < i2 in - let dir = (if dir_b then "true" else "false") in + let dir = if dir_b then "true" else "false" in let doc_field (fr, fid) = - let i, j = match fr with - | BF_aux (BF_single i, _) -> let i = const_int fid i in (i, i) - | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j) - | _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ - ("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname)) in + let i, j = + match fr with + | BF_aux (BF_single i, _) -> + let i = const_int fid i in + (i, i) + | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j) + | _ -> + raise + (Reporting.err_unreachable Parse_ast.Unknown __POS__ + ("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname) + ) + in let fsize = Big_int.succ (Big_int.abs (Big_int.sub i j)) in (* TODO Assumes normalised, decreasing bitvector slices; however, since start indices or indexing order do not appear in Lem type annotations, this does not matter. *) let ftyp = vector_typ (nconstant fsize) dec_ord bit_typ in let reftyp = - mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), - [mk_typ_arg (A_typ (mk_id_typ (mk_id tname))); - mk_typ_arg (A_typ ftyp)])) in + mk_typ + (Typ_app + ( Id_aux (Id "field_ref", Parse_ast.Unknown), + [mk_typ_arg (A_typ (mk_id_typ (mk_id tname))); mk_typ_arg (A_typ ftyp)] + ) + ) + in let rfannot = doc_tannot empty_ctxt Env.empty false reftyp in doc_op equals - (concat [string "let "; parens (concat [string tname; underscore; doc_id bare_ctxt fid; rfannot])]) - (concat [ - space; langlebar; string " field_name = \"" ^^ doc_id bare_ctxt fid ^^ string "\";"; hardline; - space; space; space; string (" field_start = " ^ Big_int.to_string i ^ ";"); hardline; - space; space; space; string (" field_is_inc = " ^ dir ^ ";"); hardline; - space; space; space; string (" get_field = get_" ^ tname ^ "_" ^ string_of_id fid ^ ";"); hardline; - space; space; space; string (" set_field = set_" ^ tname ^ "_" ^ string_of_id fid ^ " "); ranglebar]) + (concat [string "let "; parens (concat [string tname; underscore; doc_id bare_ctxt fid; rfannot])]) + (concat + [ + space; + langlebar; + string " field_name = \"" ^^ doc_id bare_ctxt fid ^^ string "\";"; + hardline; + space; + space; + space; + string (" field_start = " ^ Big_int.to_string i ^ ";"); + hardline; + space; + space; + space; + string (" field_is_inc = " ^ dir ^ ";"); + hardline; + space; + space; + space; + string (" get_field = get_" ^ tname ^ "_" ^ string_of_id fid ^ ";"); + hardline; + space; + space; + space; + string (" set_field = set_" ^ tname ^ "_" ^ string_of_id fid ^ " "); + ranglebar; + ] + ) in separate_map hardline doc_field fields (* Remove some type variables in a similar fashion to merge_kids_atoms *) -let doc_axiom_typschm typ_env is_monadic l (tqs,typ) = +let doc_axiom_typschm typ_env is_monadic l (tqs, typ) = let typ_env = Env.add_typquant l tqs typ_env in match typ with - | Typ_aux (Typ_fn (typs, ret_ty),l') -> - let check_typ (args,used) typ = - match Type_check.destruct_atom_nexp typ_env typ with - | Some (Nexp_aux (Nexp_var kid,_)) -> - if KidSet.mem kid used then args,used else - KidSet.add kid args, used - | Some _ -> args, used - | _ -> - match Type_check.destruct_atom_bool typ_env typ with - | Some (NC_aux (NC_var kid,_)) -> - if KidSet.mem kid used then args,used else - KidSet.add kid args, used - | _ -> - args, KidSet.union used (tyvars_of_typ typ) - in - let args, used = List.fold_left check_typ (KidSet.empty, KidSet.empty) typs in - let used = if is_number ret_ty then used else KidSet.union used (tyvars_of_typ ret_ty) in - let kopts,constraints = quant_split tqs in - let used = List.fold_left (fun used nc -> KidSet.union used (tyvars_of_constraint nc)) used constraints in - let tqs = match tqs with - | TypQ_aux (TypQ_tq qs,l) -> TypQ_aux (TypQ_tq (List.filter (function - | QI_aux (QI_id kopt,_) -> - let kid = kopt_kid kopt in - KidSet.mem kid used && not (KidSet.mem kid args) - | _ -> true) qs),l) - | _ -> tqs - in - let typ_count = ref 0 in - let fresh_var () = - let n = !typ_count in - let () = typ_count := n+1 in - string ("x" ^ string_of_int n) - in - let doc_typ' typ = - match Type_check.destruct_atom_nexp typ_env typ with - | Some (Nexp_aux (Nexp_var kid,_)) when KidSet.mem kid args -> - parens (doc_var empty_ctxt kid ^^ string " : Z") - (* This case is silly, but useful for tests *) - | Some (Nexp_aux (Nexp_constant n,_)) -> - let v = fresh_var () in - parens (v ^^ string " : Z") ^/^ - bquote ^^ braces (string "ArithFact " ^^ - parens (v ^^ string " =? " ^^ string (Big_int.to_string n))) - | _ -> - match Type_check.destruct_atom_bool typ_env typ with - | Some (NC_aux (NC_var kid,_)) when KidSet.mem kid args -> - parens (doc_var empty_ctxt kid ^^ string " : bool") - | _ -> - parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt Env.empty typ) - in - let arg_typs_pp = separate space (List.map doc_typ' typs) in - let _, ret_ty = replace_atom_return_type ret_ty in - let ret_typ_pp = doc_typ empty_ctxt Env.empty ret_ty in - let ret_typ_pp = - if is_monadic - then string "M" ^^ space ^^ parens ret_typ_pp - else ret_typ_pp - in - let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt typ_env braces tqs in - string "forall" ^/^ separate space tyvars_pp ^/^ - arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp - | _ -> doc_typschm empty_ctxt typ_env true (TypSchm_aux (TypSchm_ts (tqs,typ),l)) - -let doc_val_spec def_annot unimplemented avoid_target_names effect_info (VS_aux (VS_val_spec(_,id,_,_),(l,ann)) as vs) = + | Typ_aux (Typ_fn (typs, ret_ty), l') -> + let check_typ (args, used) typ = + match Type_check.destruct_atom_nexp typ_env typ with + | Some (Nexp_aux (Nexp_var kid, _)) -> if KidSet.mem kid used then (args, used) else (KidSet.add kid args, used) + | Some _ -> (args, used) + | _ -> ( + match Type_check.destruct_atom_bool typ_env typ with + | Some (NC_aux (NC_var kid, _)) -> if KidSet.mem kid used then (args, used) else (KidSet.add kid args, used) + | _ -> (args, KidSet.union used (tyvars_of_typ typ)) + ) + in + let args, used = List.fold_left check_typ (KidSet.empty, KidSet.empty) typs in + let used = if is_number ret_ty then used else KidSet.union used (tyvars_of_typ ret_ty) in + let kopts, constraints = quant_split tqs in + let used = List.fold_left (fun used nc -> KidSet.union used (tyvars_of_constraint nc)) used constraints in + let tqs = + match tqs with + | TypQ_aux (TypQ_tq qs, l) -> + TypQ_aux + ( TypQ_tq + (List.filter + (function + | QI_aux (QI_id kopt, _) -> + let kid = kopt_kid kopt in + KidSet.mem kid used && not (KidSet.mem kid args) + | _ -> true + ) + qs + ), + l + ) + | _ -> tqs + in + let typ_count = ref 0 in + let fresh_var () = + let n = !typ_count in + let () = typ_count := n + 1 in + string ("x" ^ string_of_int n) + in + let doc_typ' typ = + match Type_check.destruct_atom_nexp typ_env typ with + | Some (Nexp_aux (Nexp_var kid, _)) when KidSet.mem kid args -> parens (doc_var empty_ctxt kid ^^ string " : Z") + (* This case is silly, but useful for tests *) + | Some (Nexp_aux (Nexp_constant n, _)) -> + let v = fresh_var () in + parens (v ^^ string " : Z") + ^/^ bquote + ^^ braces (string "ArithFact " ^^ parens (v ^^ string " =? " ^^ string (Big_int.to_string n))) + | _ -> ( + match Type_check.destruct_atom_bool typ_env typ with + | Some (NC_aux (NC_var kid, _)) when KidSet.mem kid args -> + parens (doc_var empty_ctxt kid ^^ string " : bool") + | _ -> parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt Env.empty typ) + ) + in + let arg_typs_pp = separate space (List.map doc_typ' typs) in + let _, ret_ty = replace_atom_return_type ret_ty in + let ret_typ_pp = doc_typ empty_ctxt Env.empty ret_ty in + let ret_typ_pp = if is_monadic then string "M" ^^ space ^^ parens ret_typ_pp else ret_typ_pp in + let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt typ_env braces tqs in + string "forall" ^/^ separate space tyvars_pp ^/^ arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp + | _ -> doc_typschm empty_ctxt typ_env true (TypSchm_aux (TypSchm_ts (tqs, typ), l)) + +let doc_val_spec def_annot unimplemented avoid_target_names effect_info + (VS_aux (VS_val_spec (_, id, _, _), (l, ann)) as vs) = let bare_ctxt = { empty_ctxt with avoid_target_names } in - if !opt_undef_axioms && IdSet.mem id unimplemented then - let typ_env = env_of_annot (l,ann) in + if !opt_undef_axioms && IdSet.mem id unimplemented then ( + let typ_env = env_of_annot (l, ann) in (* The type checker will expand the type scheme, and we need to look at the environment afterwards to find it. *) let _, next_env = check_val_spec typ_env def_annot (strip_val_spec vs) in let tys = Env.get_val_spec id next_env in let is_monadic = not (Effects.function_is_pure id effect_info) in - group (separate space - [string "Axiom"; doc_id bare_ctxt id; colon; doc_axiom_typschm typ_env is_monadic l tys] ^^ dot) ^/^ hardline + group + (separate space [string "Axiom"; doc_id bare_ctxt id; colon; doc_axiom_typschm typ_env is_monadic l tys] ^^ dot) + ^/^ hardline + ) else empty (* Type signatures appear in definitions *) (* If a top-level value is declared with an existential type, we turn it into a type annotation expression instead (unless it duplicates an existing one). *) let doc_val avoid_target_names pat exp = let bare_ctxt = { empty_ctxt with avoid_target_names } in - let (id,pat_typ) = match pat with - | P_aux (P_typ (typ, P_aux (P_id id,_)),_) -> id, Some typ - | P_aux (P_id id, _) -> id, None - | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)),_) when Id.compare id (id_of_kid kid) == 0 -> - id, None - | P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)),_)),_) when Id.compare id (id_of_kid kid) == 0 -> - id, Some typ - | _ -> raise (Reporting.err_todo (pat_loc pat) - "Top-level value definition with complex pattern not supported for Coq yet") + let id, pat_typ = + match pat with + | P_aux (P_typ (typ, P_aux (P_id id, _)), _) -> (id, Some typ) + | P_aux (P_id id, _) -> (id, None) + | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _) when Id.compare id (id_of_kid kid) == 0 -> (id, None) + | P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _) + when Id.compare id (id_of_kid kid) == 0 -> + (id, Some typ) + | _ -> + raise + (Reporting.err_todo (pat_loc pat) "Top-level value definition with complex pattern not supported for Coq yet") in - let typpp = match pat_typ with - | None -> empty - | Some typ -> space ^^ colon ^^ space ^^ doc_typ empty_ctxt Env.empty typ + let typpp = + match pat_typ with None -> empty | Some typ -> space ^^ colon ^^ space ^^ doc_typ empty_ctxt Env.empty typ in let env = env_of exp in - let ctxt = { empty_ctxt with debug = List.mem (string_of_id id) (!opt_debug_on) } in + let ctxt = { empty_ctxt with debug = List.mem (string_of_id id) !opt_debug_on } in let () = debug ctxt (lazy ("Checking definition " ^ string_of_id id)); debug_depth := 1 in let typpp, exp = match pat_typ with - | None -> typpp, exp - | Some typ -> - let typ = expand_range_type (Env.expand_synonyms env typ) in - match destruct_exist_plain typ with - | None -> typpp, exp - | Some _ -> - empty, match exp with - | E_aux (E_typ (typ',_),_) when alpha_equivalent env typ typ' -> exp - | _ -> E_aux (E_typ (typ,exp), (Parse_ast.Unknown, mk_tannot env typ)) + | None -> (typpp, exp) + | Some typ -> ( + let typ = expand_range_type (Env.expand_synonyms env typ) in + match destruct_exist_plain typ with + | None -> (typpp, exp) + | Some _ -> ( + ( empty, + match exp with + | E_aux (E_typ (typ', _), _) when alpha_equivalent env typ typ' -> exp + | _ -> E_aux (E_typ (typ, exp), (Parse_ast.Unknown, mk_tannot env typ)) + ) + ) + ) in let idpp = doc_id bare_ctxt id in let base_pp = doc_exp ctxt false exp ^^ dot in let () = debug_depth := 0 in - group (string "Definition" ^^ space ^^ idpp ^^ typpp ^^ space ^^ coloneq ^/^ base_pp) ^^ hardline ^^ - group (separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."]) ^^ hardline + group (string "Definition" ^^ space ^^ idpp ^^ typpp ^^ space ^^ coloneq ^/^ base_pp) + ^^ hardline + ^^ group (separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."]) + ^^ hardline let doc_def types_mod unimplemented avoid_target_names generic_eq_types effect_info (DEF_aux (aux, def_annot) as def) = match aux with @@ -3205,46 +3265,39 @@ let doc_def types_mod unimplemented avoid_target_names generic_eq_types effect_i | DEF_overload _ -> empty | DEF_type t_def -> doc_typdef types_mod avoid_target_names generic_eq_types t_def | DEF_register dec -> group (doc_dec avoid_target_names dec) - | DEF_default df -> empty | DEF_fundef fdef -> group (doc_fundef types_mod avoid_target_names effect_info fdef) ^/^ hardline - | DEF_internal_mutrec fundefs -> doc_mutrec types_mod avoid_target_names effect_info (ids_of_def def) fundefs ^/^ hardline + | DEF_internal_mutrec fundefs -> + doc_mutrec types_mod avoid_target_names effect_info (ids_of_def def) fundefs ^/^ hardline | DEF_let (LB_aux (LB_val (pat, exp), _)) -> doc_val avoid_target_names pat exp | DEF_scattered sdef -> failwith "doc_def: shoulnd't have DEF_scattered at this point" - | DEF_mapdef (MD_aux (_, (l,_))) -> unreachable l __POS__ "Coq doesn't support mappings" + | DEF_mapdef (MD_aux (_, (l, _))) -> unreachable l __POS__ "Coq doesn't support mappings" | DEF_pragma _ -> empty - | DEF_measure (id,_,_) -> unreachable (id_loc id) __POS__ - ("Termination measure for " ^ string_of_id id ^ - " should have been rewritten before backend") - | DEF_loop_measures (id,_) -> - unreachable (id_loc id) __POS__ - ("Loop termination measures for " ^ string_of_id id ^ - " should have been rewritten before backend") - | (DEF_impl _ | DEF_outcome _ | DEF_instantiation _) -> - unreachable (def_loc def) __POS__ "Event definition should have been rewritten before backend" - - + | DEF_measure (id, _, _) -> + unreachable (id_loc id) __POS__ + ("Termination measure for " ^ string_of_id id ^ " should have been rewritten before backend") + | DEF_loop_measures (id, _) -> + unreachable (id_loc id) __POS__ + ("Loop termination measures for " ^ string_of_id id ^ " should have been rewritten before backend") + | DEF_impl _ | DEF_outcome _ | DEF_instantiation _ -> + unreachable (def_loc def) __POS__ "Event definition should have been rewritten before backend" + let find_exc_typ defs = let is_exc_typ_def = function | DEF_aux (DEF_type td, _) -> string_of_id (id_of_type_def td) = "exception" - | _ -> false in + | _ -> false + in if List.exists is_exc_typ_def defs then "exception" else "unit" let find_unimplemented defs = - let adjust_fundef unimplemented (FD_aux (FD_function (_,_,funcls),_)) = - match funcls with - | [] -> unimplemented - | (FCL_aux (FCL_funcl (id,_),_))::_ -> - IdSet.remove id unimplemented + let adjust_fundef unimplemented (FD_aux (FD_function (_, _, funcls), _)) = + match funcls with [] -> unimplemented | FCL_aux (FCL_funcl (id, _), _) :: _ -> IdSet.remove id unimplemented in let adjust_def unimplemented = function - | DEF_aux (DEF_val (VS_aux (VS_val_spec (_,id,exts,_),_)),_) -> begin - match Ast_util.extern_assoc "coq" exts with - | Some _ -> unimplemented - | None -> IdSet.add id unimplemented - end - | DEF_aux (DEF_internal_mutrec fds, _) -> - List.fold_left adjust_fundef unimplemented fds + | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, exts, _), _)), _) -> begin + match Ast_util.extern_assoc "coq" exts with Some _ -> unimplemented | None -> IdSet.add id unimplemented + end + | DEF_aux (DEF_internal_mutrec fds, _) -> List.fold_left adjust_fundef unimplemented fds | DEF_aux (DEF_fundef fd, _) -> adjust_fundef unimplemented fd | _ -> unimplemented in @@ -3252,87 +3305,108 @@ let find_unimplemented defs = let builtin_target_names defs = let check_def names = function - | DEF_aux (DEF_val (VS_aux (VS_val_spec (_,_,exts,_),_)),_) -> begin - match Ast_util.extern_assoc "coq" exts with - | Some name -> StringSet.add name names - | None -> names + | DEF_aux (DEF_val (VS_aux (VS_val_spec (_, _, exts, _), _)), _) -> begin + match Ast_util.extern_assoc "coq" exts with Some name -> StringSet.add name names | None -> names end | _ -> names - in List.fold_left check_def StringSet.empty defs - -let pp_ast_coq (types_file,types_modules) (defs_file,defs_modules) type_defs_module effect_info { defs; _ } top_line suppress_MR_M = -try - (* let regtypes = find_regtypes d in *) - let state_ids = - State.generate_regstate_defs true defs - |> val_spec_ids - in - let is_state_def = function - | DEF_aux (DEF_val vs, _) -> IdSet.mem (id_of_val_spec vs) state_ids - | DEF_aux (DEF_fundef fd, _) -> IdSet.mem (id_of_fundef fd) state_ids - | _ -> false - in - let is_typ_def = function - | DEF_aux (DEF_type _, _) -> true - | _ -> false - in - let exc_typ = find_exc_typ defs in - let typdefs, defs = List.partition is_typ_def defs in - let statedefs, defs = List.partition is_state_def defs in - let unimplemented = find_unimplemented defs in - let avoid_target_names = builtin_target_names defs in - let bare_doc_id = doc_id { empty_ctxt with avoid_target_names } in - let register_refs = State.register_refs_coq bare_doc_id (State.find_registers defs) in - let generic_eq_types = types_used_with_generic_eq defs in - let doc_def = doc_def type_defs_module unimplemented avoid_target_names generic_eq_types effect_info in - let () = if !opt_undef_axioms || IdSet.is_empty unimplemented then () else - Reporting.print_err Parse_ast.Unknown "Warning" - ("The following functions were declared but are undefined:\n" ^ - String.concat "\n" (List.map string_of_id (IdSet.elements unimplemented))) - in - (print types_file) - (concat - [string "(*" ^^ (string top_line) ^^ string "*)";hardline; - (separate_map hardline) - (fun lib -> separate space [string "Require Import";string lib] ^^ dot) types_modules;hardline; - string "Import ListNotations."; - hardline; - string "Open Scope string."; hardline; - string "Open Scope bool."; hardline; - string "Open Scope Z."; hardline; - hardline; - separate empty (List.map doc_def typdefs); hardline; - hardline; - separate empty (List.map doc_def statedefs); hardline; - hardline; - register_refs; hardline; - (if suppress_MR_M then empty else concat [ - string ("Definition MR a r := monadR register_value a r " ^ exc_typ ^ "."); hardline; - string ("Definition M a := monad register_value a " ^ exc_typ ^ "."); hardline; - string ("Definition returnM {A:Type} := @returnm register_value A " ^ exc_typ ^ "."); hardline; - string ("Definition returnR {A:Type} (R:Type) := @returnm register_value A (R + " ^ exc_typ ^ ")."); hardline - ]) - ]); - (print defs_file) - (concat - [string "(*" ^^ (string top_line) ^^ string "*)";hardline; - (separate_map hardline) - (fun lib -> separate space [string "Require Import";string lib] ^^ dot) defs_modules;hardline; - string "Import ListNotations."; - hardline; - string "Open Scope string."; hardline; - string "Open Scope bool."; hardline; - string "Open Scope Z."; hardline; - hardline; - hardline; - separate empty (List.map doc_def defs); - hardline; - hardline]) -with Type_check.Type_error (env,l,err) -> - let extra = - "\nError during Coq printing\n" ^ - if Printexc.backtrace_status () - then "\n" ^ Printexc.get_backtrace () - else "(backtracing unavailable)" in - raise (Reporting.err_typ l (Type_error.string_of_type_error err ^ extra)) + List.fold_left check_def StringSet.empty defs + +let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_module effect_info { defs; _ } top_line + suppress_MR_M = + try + (* let regtypes = find_regtypes d in *) + let state_ids = State.generate_regstate_defs true defs |> val_spec_ids in + let is_state_def = function + | DEF_aux (DEF_val vs, _) -> IdSet.mem (id_of_val_spec vs) state_ids + | DEF_aux (DEF_fundef fd, _) -> IdSet.mem (id_of_fundef fd) state_ids + | _ -> false + in + let is_typ_def = function DEF_aux (DEF_type _, _) -> true | _ -> false in + let exc_typ = find_exc_typ defs in + let typdefs, defs = List.partition is_typ_def defs in + let statedefs, defs = List.partition is_state_def defs in + let unimplemented = find_unimplemented defs in + let avoid_target_names = builtin_target_names defs in + let bare_doc_id = doc_id { empty_ctxt with avoid_target_names } in + let register_refs = State.register_refs_coq bare_doc_id (State.find_registers defs) in + let generic_eq_types = types_used_with_generic_eq defs in + let doc_def = doc_def type_defs_module unimplemented avoid_target_names generic_eq_types effect_info in + let () = + if !opt_undef_axioms || IdSet.is_empty unimplemented then () + else + Reporting.print_err Parse_ast.Unknown "Warning" + ("The following functions were declared but are undefined:\n" + ^ String.concat "\n" (List.map string_of_id (IdSet.elements unimplemented)) + ) + in + (print types_file) + (concat + [ + string "(*" ^^ string top_line ^^ string "*)"; + hardline; + (separate_map hardline) + (fun lib -> separate space [string "Require Import"; string lib] ^^ dot) + types_modules; + hardline; + string "Import ListNotations."; + hardline; + string "Open Scope string."; + hardline; + string "Open Scope bool."; + hardline; + string "Open Scope Z."; + hardline; + hardline; + separate empty (List.map doc_def typdefs); + hardline; + hardline; + separate empty (List.map doc_def statedefs); + hardline; + hardline; + register_refs; + hardline; + ( if suppress_MR_M then empty + else + concat + [ + string ("Definition MR a r := monadR register_value a r " ^ exc_typ ^ "."); + hardline; + string ("Definition M a := monad register_value a " ^ exc_typ ^ "."); + hardline; + string ("Definition returnM {A:Type} := @returnm register_value A " ^ exc_typ ^ "."); + hardline; + string ("Definition returnR {A:Type} (R:Type) := @returnm register_value A (R + " ^ exc_typ ^ ")."); + hardline; + ] + ); + ] + ); + (print defs_file) + (concat + [ + string "(*" ^^ string top_line ^^ string "*)"; + hardline; + (separate_map hardline) (fun lib -> separate space [string "Require Import"; string lib] ^^ dot) defs_modules; + hardline; + string "Import ListNotations."; + hardline; + string "Open Scope string."; + hardline; + string "Open Scope bool."; + hardline; + string "Open Scope Z."; + hardline; + hardline; + hardline; + separate empty (List.map doc_def defs); + hardline; + hardline; + ] + ) + with Type_check.Type_error (env, l, err) -> + let extra = + "\nError during Coq printing\n" + ^ if Printexc.backtrace_status () then "\n" ^ Printexc.get_backtrace () else "(backtracing unavailable)" + in + raise (Reporting.err_typ l (Type_error.string_of_type_error err ^ extra)) diff --git a/src/sail_coq_backend/sail_plugin_coq.ml b/src/sail_coq_backend/sail_plugin_coq.ml index b24b8ad01..9b323f29a 100644 --- a/src/sail_coq_backend/sail_plugin_coq.ml +++ b/src/sail_coq_backend/sail_plugin_coq.ml @@ -72,28 +72,37 @@ let opt_libs_coq : string list ref = ref [] let opt_alt_modules_coq : string list ref = ref [] let opt_alt_modules2_coq : string list ref = ref [] -let coq_options = [ - ( "-coq_output_dir", - Arg.String (fun dir -> opt_coq_output_dir := Some dir), - " set a custom directory to output generated Coq"); - ( "-coq_lib", - Arg.String (fun l -> opt_libs_coq := l::!opt_libs_coq), - " provide additional library to open in Coq output"); - ( "-coq_alt_modules", - Arg.String (fun l -> opt_alt_modules_coq := l::!opt_alt_modules_coq), - " provide alternative modules to open in Coq output"); - ( "-coq_alt_modules2", - Arg.String (fun l -> opt_alt_modules2_coq := l::!opt_alt_modules2_coq), - " provide additional alternative modules to open only in main (non-_types) Coq output, and suppress default definitions of MR and M monads"); - ( "-dcoq_undef_axioms", - Arg.Set Pretty_print_coq.opt_undef_axioms, - " (debug) generate axioms for functions that are declared but not defined"); - ( "-dcoq_warn_nonex", - Arg.Set Rewrites.opt_coq_warn_nonexhaustive, - " (debug) generate warnings for non-exhaustive pattern matches in the Coq backend"); - ( "-dcoq_debug_on", - Arg.String (fun f -> Pretty_print_coq.opt_debug_on := f::!Pretty_print_coq.opt_debug_on), - " (debug) produce debug messages for Coq output on given function"); +let coq_options = + [ + ( "-coq_output_dir", + Arg.String (fun dir -> opt_coq_output_dir := Some dir), + " set a custom directory to output generated Coq" + ); + ( "-coq_lib", + Arg.String (fun l -> opt_libs_coq := l :: !opt_libs_coq), + " provide additional library to open in Coq output" + ); + ( "-coq_alt_modules", + Arg.String (fun l -> opt_alt_modules_coq := l :: !opt_alt_modules_coq), + " provide alternative modules to open in Coq output" + ); + ( "-coq_alt_modules2", + Arg.String (fun l -> opt_alt_modules2_coq := l :: !opt_alt_modules2_coq), + " provide additional alternative modules to open only in main (non-_types) Coq output, and suppress \ + default definitions of MR and M monads" + ); + ( "-dcoq_undef_axioms", + Arg.Set Pretty_print_coq.opt_undef_axioms, + " (debug) generate axioms for functions that are declared but not defined" + ); + ( "-dcoq_warn_nonex", + Arg.Set Rewrites.opt_coq_warn_nonexhaustive, + " (debug) generate warnings for non-exhaustive pattern matches in the Coq backend" + ); + ( "-dcoq_debug_on", + Arg.String (fun f -> Pretty_print_coq.opt_debug_on := f :: !Pretty_print_coq.opt_debug_on), + " (debug) produce debug messages for Coq output on given function" + ); ] let coq_rewrites = @@ -136,7 +145,7 @@ let coq_rewrites = ("recheck_defs", []); ("make_cases_exhaustive", []); (* merge funcls before adding the measure argument so that it doesn't - disappear into an internal pattern match *) + disappear into an internal pattern match *) ("merge_function_clauses", []); ("recheck_defs", []); ("rewrite_explicit_measure", []); @@ -152,32 +161,28 @@ let coq_rewrites = ("remove_superfluous_returns", []); ("bit_lists_to_lits", []); ("recheck_defs", []); - ("attach_effects", []) + ("attach_effects", []); ] - -let generated_line f = - Printf.sprintf "Generated by Sail from %s." f + +let generated_line f = Printf.sprintf "Generated by Sail from %s." f let output_coq opt_dir filename alt_modules alt_modules2 libs effect_info ast = let generated_line = generated_line filename in - let types_module = (filename ^ "_types") in + let types_module = filename ^ "_types" in let base_imports_default = ["Sail.Base"; "Sail.Real"] in let base_imports = - (match alt_modules with + match alt_modules with | [] -> base_imports_default | _ -> Str.split (Str.regexp "[ \t]+") (String.concat " " alt_modules) - ) in - let ((ot,_,_,_) as ext_ot) = - Util.open_output_with_check_unformatted opt_dir (types_module ^ ".v") in - let ((o,_,_,_) as ext_o) = - Util.open_output_with_check_unformatted opt_dir (filename ^ ".v") in - (Pretty_print_coq.pp_ast_coq - (ot, base_imports) + in + let ((ot, _, _, _) as ext_ot) = Util.open_output_with_check_unformatted opt_dir (types_module ^ ".v") in + let ((o, _, _, _) as ext_o) = Util.open_output_with_check_unformatted opt_dir (filename ^ ".v") in + (Pretty_print_coq.pp_ast_coq (ot, base_imports) (o, base_imports @ (types_module :: libs) @ alt_modules2) - types_module - effect_info - ast generated_line) - (alt_modules2 <> []); (* suppress MR and M defns if alt_modules2 present*) + types_module effect_info ast generated_line + ) + (alt_modules2 <> []); + (* suppress MR and M defns if alt_modules2 present*) Util.close_output_with_check ext_ot; Util.close_output_with_check ext_o @@ -185,17 +190,12 @@ let output libs files = List.iter (fun (f, effect_info, _, ast) -> let f' = Filename.basename (Filename.remove_extension f) in - output_coq !opt_coq_output_dir f' !opt_alt_modules_coq !opt_alt_modules2_coq libs effect_info ast) + output_coq !opt_coq_output_dir f' !opt_alt_modules_coq !opt_alt_modules2_coq libs effect_info ast + ) files let coq_target _ out_file ast effect_info env = let out_file = match out_file with Some f -> f | None -> "out" in - output (!opt_libs_coq) [(out_file, effect_info, env, ast)] + output !opt_libs_coq [(out_file, effect_info, env, ast)] -let _ = - Target.register - ~name:"coq" - ~options:coq_options - ~rewrites:coq_rewrites - ~asserts_termination:true - coq_target +let _ = Target.register ~name:"coq" ~options:coq_options ~rewrites:coq_rewrites ~asserts_termination:true coq_target diff --git a/src/sail_doc_backend/docinfo.ml b/src/sail_doc_backend/docinfo.ml index e027d9133..db040bbbe 100644 --- a/src/sail_doc_backend/docinfo.ml +++ b/src/sail_doc_backend/docinfo.ml @@ -83,8 +83,7 @@ open Ast_util info with what the external tooling supports. *) let docinfo_version = 1 -let same_file f1 f2 = - Filename.basename f1 = Filename.basename f2 && Filename.dirname f1 = Filename.dirname f2 +let same_file f1 f2 = Filename.basename f1 = Filename.basename f2 && Filename.dirname f1 = Filename.dirname f2 let process_file f filename = let chan = open_in filename in @@ -104,104 +103,104 @@ let process_file f filename = let read_source (p1 : Lexing.position) (p2 : Lexing.position) = process_file (fun contents -> String.sub contents p1.pos_cnum (p2.pos_cnum - p1.pos_cnum)) p1.pos_fname -let hash_file filename = - process_file Digest.string filename |> Digest.to_hex +let hash_file filename = process_file Digest.string filename |> Digest.to_hex type embedding = Plain | Base64 -let embedding_string = function - | Plain -> "plain" - | Base64 -> "base64" +let embedding_string = function Plain -> "plain" | Base64 -> "base64" let bindings_to_json b f = - Bindings.bindings b - |> List.map (fun (key, elem) -> (string_of_id key, f elem)) - |> (fun elements -> `Assoc elements) + Bindings.bindings b |> List.map (fun (key, elem) -> (string_of_id key, f elem)) |> fun elements -> `Assoc elements -type location_or_raw = - | Raw of string - | Location of string * int * int * int * int * int * int +type location_or_raw = Raw of string | Location of string * int * int * int * int * int * int let location_or_raw_to_json = function | Raw s -> `String s | Location (fname, line1, bol1, char1, line2, bol2, char2) -> - `Assoc [("file", `String fname); ("loc", `List [`Int line1; `Int bol1; `Int char1; `Int line2; `Int bol2; `Int char2])] + `Assoc + [("file", `String fname); ("loc", `List [`Int line1; `Int bol1; `Int char1; `Int line2; `Int bol2; `Int char2])] type hyper_location = string * int * int let included_loc files l = match Reporting.loc_file l with - | Some file -> - Util.list_empty files || List.exists (same_file file) files - | None -> - Util.list_empty files + | Some file -> Util.list_empty files || List.exists (same_file file) files + | None -> Util.list_empty files let hyper_loc l = match Reporting.simp_loc l with | Some (p1, p2) when p1.pos_fname = p2.pos_fname && Filename.is_relative p1.pos_fname -> - Some (p1.pos_fname, p1.pos_cnum, p2.pos_cnum) + Some (p1.pos_fname, p1.pos_cnum, p2.pos_cnum) | _ -> None -type hyperlink = - | Function of id * hyper_location - | Register of id * hyper_location +type hyperlink = Function of id * hyper_location | Register of id * hyper_location let hyperlink_to_json = function | Function (id, (file, c1, c2)) -> - `Assoc [("type", `String "function"); ("id", `String (string_of_id id)); ("file", `String file); ("loc", `List [`Int c1; `Int c2])] + `Assoc + [ + ("type", `String "function"); + ("id", `String (string_of_id id)); + ("file", `String file); + ("loc", `List [`Int c1; `Int c2]); + ] | Register (id, (file, c1, c2)) -> - `Assoc [("type", `String "register"); ("id", `String (string_of_id id)); ("file", `String file); ("loc", `List [`Int c1; `Int c2])] + `Assoc + [ + ("type", `String "register"); + ("id", `String (string_of_id id)); + ("file", `String file); + ("loc", `List [`Int c1; `Int c2]); + ] -let hyperlinks_to_json = function - | [] -> `Null - | links -> `List (List.map hyperlink_to_json links) +let hyperlinks_to_json = function [] -> `Null | links -> `List (List.map hyperlink_to_json links) let hyperlinks_from_def files def = let open Rewriter in let links = ref [] in let link f l = - if included_loc files l then ( - match hyper_loc l with - | Some hloc -> links := f hloc :: !links - | None -> () - ) in + if included_loc files l then (match hyper_loc l with Some hloc -> links := f hloc :: !links | None -> ()) + in let scan_lexp lexp_aux annot = let env = Type_check.env_of_annot annot in - begin match lexp_aux with - | LE_typ (_, id) | LE_id id -> - begin match Type_check.Env.lookup_id id env with - | Register _ -> - link (fun hloc -> Register (id, hloc)) (id_loc id) - | _ -> () - end - | _ -> () + begin + match lexp_aux with + | LE_typ (_, id) | LE_id id -> begin + match Type_check.Env.lookup_id id env with + | Register _ -> link (fun hloc -> Register (id, hloc)) (id_loc id) + | _ -> () + end + | _ -> () end; LE_aux (lexp_aux, annot) in let scan_exp e_aux annot = let env = Type_check.env_of_annot annot in - begin match e_aux with - | E_id id -> - begin match Type_check.Env.lookup_id id env with - | Register _ -> - link (fun hloc -> Register (id, hloc)) (id_loc id) - | _ -> () - end - | E_app (f, _) -> - link (fun hloc -> Function (f, hloc)) (id_loc f) - | _ -> () + begin + match e_aux with + | E_id id -> begin + match Type_check.Env.lookup_id id env with + | Register _ -> link (fun hloc -> Register (id, hloc)) (id_loc id) + | _ -> () + end + | E_app (f, _) -> link (fun hloc -> Function (f, hloc)) (id_loc f) + | _ -> () end; E_aux (e_aux, annot) in let rw_exp _ exp = - fold_exp { - id_exp_alg with e_aux = (fun (e_aux, annot) -> scan_exp e_aux annot); - le_aux = (fun (l_aux, annot) -> scan_lexp l_aux annot) - } exp in + fold_exp + { + id_exp_alg with + e_aux = (fun (e_aux, annot) -> scan_exp e_aux annot); + le_aux = (fun (l_aux, annot) -> scan_lexp l_aux annot); + } + exp + in ignore (rewrite_ast_defs { rewriters_base with rewrite_exp = rw_exp } [def]); !links @@ -216,10 +215,18 @@ let rec pat_to_json (P_aux (aux, _)) = | P_typ (_, pat) -> pat_to_json pat | P_id id -> `Assoc [pat_type "id"; ("id", `String (string_of_id id))] | P_var (pat, _) -> `Assoc [pat_type "var"; ("pattern", pat_to_json pat)] - | P_app (id, pats) -> `Assoc [pat_type "app"; ("id", `String (string_of_id id)); ("patterns", `List (List.map pat_to_json pats))] + | P_app (id, pats) -> + `Assoc [pat_type "app"; ("id", `String (string_of_id id)); ("patterns", `List (List.map pat_to_json pats))] | P_vector pats -> seq_pat_json "vector" pats | P_vector_concat pats -> seq_pat_json "vector_concat" pats - | P_vector_subrange (id, n, m) -> `Assoc [pat_type "vector_subrange"; ("id", `String (string_of_id id)); ("from", `Int (Big_int.to_int n)); ("to", `Int (Big_int.to_int m))] + | P_vector_subrange (id, n, m) -> + `Assoc + [ + pat_type "vector_subrange"; + ("id", `String (string_of_id id)); + ("from", `Int (Big_int.to_int n)); + ("to", `Int (Big_int.to_int m)); + ] | P_tuple pats -> seq_pat_json "tuple" pats | P_list pats -> seq_pat_json "list" pats | P_cons (pat_hd, pat_tl) -> `Assoc [pat_type "cons"; ("hd", pat_to_json pat_hd); ("tl", pat_to_json pat_tl)] @@ -227,109 +234,81 @@ let rec pat_to_json (P_aux (aux, _)) = | P_or _ | P_not _ -> `Null type 'a function_clause_doc = { - number : int; - source : location_or_raw; - pat : 'a pat; - wavedrom : string option; - guard_source : location_or_raw option; - body_source : location_or_raw; - comment : string option; - splits : location_or_raw Bindings.t option; - } + number : int; + source : location_or_raw; + pat : 'a pat; + wavedrom : string option; + guard_source : location_or_raw option; + body_source : location_or_raw; + comment : string option; + splits : location_or_raw Bindings.t option; +} let function_clause_doc_to_json docinfo = - `Assoc ( - [ - ("number", `Int docinfo.number); - ("source", location_or_raw_to_json docinfo.source); - ("pattern", pat_to_json docinfo.pat); - ] - @ (match docinfo.wavedrom with Some w -> [("wavedrom", `String w)] | None -> []) - @ (match docinfo.comment with Some s -> [("comment", `String s)] | None -> []) - @ (match docinfo.guard_source with Some s -> [("guard", location_or_raw_to_json s)] | None -> []) - @ [ - ("body", location_or_raw_to_json docinfo.body_source) - ] - @ (match docinfo.splits with Some s -> [("splits", bindings_to_json s location_or_raw_to_json)] | None -> []) + `Assoc + ([ + ("number", `Int docinfo.number); + ("source", location_or_raw_to_json docinfo.source); + ("pattern", pat_to_json docinfo.pat); + ] + @ (match docinfo.wavedrom with Some w -> [("wavedrom", `String w)] | None -> []) + @ (match docinfo.comment with Some s -> [("comment", `String s)] | None -> []) + @ (match docinfo.guard_source with Some s -> [("guard", location_or_raw_to_json s)] | None -> []) + @ [("body", location_or_raw_to_json docinfo.body_source)] + @ match docinfo.splits with Some s -> [("splits", bindings_to_json s location_or_raw_to_json)] | None -> [] ) -type 'a function_doc = - | Multiple_clauses of 'a function_clause_doc list - | Single_clause of 'a function_clause_doc +type 'a function_doc = Multiple_clauses of 'a function_clause_doc list | Single_clause of 'a function_clause_doc let function_doc_to_json = function | Multiple_clauses docinfos -> `List (List.map function_clause_doc_to_json docinfos) | Single_clause docinfo -> function_clause_doc_to_json docinfo type 'a mapping_clause_doc = { - number : int; - source : location_or_raw; - left : 'a pat option; - left_wavedrom : string option; - right : 'a pat option; - right_wavedrom : string option; - body : location_or_raw option; - } + number : int; + source : location_or_raw; + left : 'a pat option; + left_wavedrom : string option; + right : 'a pat option; + right_wavedrom : string option; + body : location_or_raw option; +} let mapping_clause_doc_to_json docinfo = - `Assoc ( - [ - ("number", `Int docinfo.number); - ("source", location_or_raw_to_json docinfo.source) - ] - @ (match docinfo.left with Some p -> [("left", pat_to_json p)] | None -> []) - @ (match docinfo.left_wavedrom with Some w -> [("left_wavedrom", `String w)] | None -> []) - @ (match docinfo.right with Some p -> [("right", pat_to_json p)] | None -> []) - @ (match docinfo.right_wavedrom with Some w -> [("right_wavedrom", `String w)] | None -> []) - @ (match docinfo.body with Some s -> [("body", location_or_raw_to_json s)] | None -> []) + `Assoc + ([("number", `Int docinfo.number); ("source", location_or_raw_to_json docinfo.source)] + @ (match docinfo.left with Some p -> [("left", pat_to_json p)] | None -> []) + @ (match docinfo.left_wavedrom with Some w -> [("left_wavedrom", `String w)] | None -> []) + @ (match docinfo.right with Some p -> [("right", pat_to_json p)] | None -> []) + @ (match docinfo.right_wavedrom with Some w -> [("right_wavedrom", `String w)] | None -> []) + @ match docinfo.body with Some s -> [("body", location_or_raw_to_json s)] | None -> [] ) type 'a mapping_doc = 'a mapping_clause_doc list let mapping_doc_to_json docinfos = `List (List.map mapping_clause_doc_to_json docinfos) -type valspec_doc = { - source : location_or_raw; - type_source : location_or_raw; - } +type valspec_doc = { source : location_or_raw; type_source : location_or_raw } let valspec_doc_to_json docinfo = - `Assoc [ - ("source", location_or_raw_to_json docinfo.source); - ("type", location_or_raw_to_json docinfo.type_source) - ] + `Assoc [("source", location_or_raw_to_json docinfo.source); ("type", location_or_raw_to_json docinfo.type_source)] type typdef_doc = location_or_raw let typdef_doc_to_json = location_or_raw_to_json -type register_doc = { - source : location_or_raw; - type_source : location_or_raw; - exp_source : location_or_raw option; - } +type register_doc = { source : location_or_raw; type_source : location_or_raw; exp_source : location_or_raw option } let register_doc_to_json docinfo = - `Assoc ( - [ - ("source", location_or_raw_to_json docinfo.source); - ("type", location_or_raw_to_json docinfo.type_source) - ] - @ (match docinfo.exp_source with - | None -> [] - | Some source -> [("exp", location_or_raw_to_json source)]) + `Assoc + ([("source", location_or_raw_to_json docinfo.source); ("type", location_or_raw_to_json docinfo.type_source)] + @ match docinfo.exp_source with None -> [] | Some source -> [("exp", location_or_raw_to_json source)] ) -type let_doc = { - source : location_or_raw; - exp_source : location_or_raw; - } +type let_doc = { source : location_or_raw; exp_source : location_or_raw } let let_doc_to_json docinfo = - `Assoc [ - ("source", location_or_raw_to_json docinfo.source); - ("exp", location_or_raw_to_json docinfo.exp_source) - ] + `Assoc [("source", location_or_raw_to_json docinfo.source); ("exp", location_or_raw_to_json docinfo.exp_source)] let pair_to_json x_label f y_label g (x, y) = match (f x, g y) with @@ -338,145 +317,148 @@ let pair_to_json x_label f y_label g (x, y) = | `Null, y -> `Assoc [(y_label, y)] | x, y -> `Assoc [(x_label, x); (y_label, y)] -type anchor_doc = { - source : location_or_raw; - comment : string option - } +type anchor_doc = { source : location_or_raw; comment : string option } let anchor_doc_to_json docinfo = - `Assoc ( - [ - ("source", location_or_raw_to_json docinfo.source); - ] - @ (match docinfo.comment with Some c -> [("comment", `String c)] | None -> []) + `Assoc + ([("source", location_or_raw_to_json docinfo.source)] + @ match docinfo.comment with Some c -> [("comment", `String c)] | None -> [] ) type 'a docinfo = { - embedding : embedding; - git : (string * bool) option; - hashes : (string * string) list; - functions : ('a function_doc * hyperlink list) Bindings.t; - mappings : ('a mapping_doc * hyperlink list) Bindings.t; - valspecs : (valspec_doc * hyperlink list) Bindings.t; - typdefs : (typdef_doc * hyperlink list) Bindings.t; - registers : (register_doc * hyperlink list) Bindings.t; - lets : (let_doc * hyperlink list) Bindings.t; - anchors : (anchor_doc * hyperlink list) Bindings.t; - spans : location_or_raw Bindings.t; - } + embedding : embedding; + git : (string * bool) option; + hashes : (string * string) list; + functions : ('a function_doc * hyperlink list) Bindings.t; + mappings : ('a mapping_doc * hyperlink list) Bindings.t; + valspecs : (valspec_doc * hyperlink list) Bindings.t; + typdefs : (typdef_doc * hyperlink list) Bindings.t; + registers : (register_doc * hyperlink list) Bindings.t; + lets : (let_doc * hyperlink list) Bindings.t; + anchors : (anchor_doc * hyperlink list) Bindings.t; + spans : location_or_raw Bindings.t; +} let span_to_json loc = `Assoc [("span", location_or_raw_to_json loc)] let docinfo_to_json docinfo = let assoc = - [ - ("version", `Int docinfo_version) - ] - @ (match docinfo.git with Some (commit, dirty) -> [("git", `Assoc [("commit", `String commit); ("dirty", `Bool dirty)])] | None -> []) + [("version", `Int docinfo_version)] + @ ( match docinfo.git with + | Some (commit, dirty) -> [("git", `Assoc [("commit", `String commit); ("dirty", `Bool dirty)])] + | None -> [] + ) @ [ ("embedding", `String (embedding_string docinfo.embedding)); ("hashes", `Assoc (List.map (fun (key, hash) -> (key, `Assoc [("md5", `String hash)])) docinfo.hashes)); - ("functions", bindings_to_json docinfo.functions (pair_to_json "function" function_doc_to_json "links" hyperlinks_to_json)); - ("mappings", bindings_to_json docinfo.mappings (pair_to_json "mapping" mapping_doc_to_json "links" hyperlinks_to_json)); + ( "functions", + bindings_to_json docinfo.functions (pair_to_json "function" function_doc_to_json "links" hyperlinks_to_json) + ); + ( "mappings", + bindings_to_json docinfo.mappings (pair_to_json "mapping" mapping_doc_to_json "links" hyperlinks_to_json) + ); ("vals", bindings_to_json docinfo.valspecs (pair_to_json "val" valspec_doc_to_json "links" hyperlinks_to_json)); ("types", bindings_to_json docinfo.typdefs (pair_to_json "type" typdef_doc_to_json "links" hyperlinks_to_json)); - ("registers", bindings_to_json docinfo.registers (pair_to_json "register" register_doc_to_json "links" hyperlinks_to_json)); + ( "registers", + bindings_to_json docinfo.registers (pair_to_json "register" register_doc_to_json "links" hyperlinks_to_json) + ); ("lets", bindings_to_json docinfo.lets (pair_to_json "let" let_doc_to_json "links" hyperlinks_to_json)); - ("anchors", bindings_to_json docinfo.anchors (pair_to_json "anchor" anchor_doc_to_json "links" hyperlinks_to_json)); + ( "anchors", + bindings_to_json docinfo.anchors (pair_to_json "anchor" anchor_doc_to_json "links" hyperlinks_to_json) + ); ("spans", bindings_to_json docinfo.spans span_to_json); - ] in + ] + in `Assoc assoc let git_command args = try let git_out, git_in, git_err = Unix.open_process_full ("git " ^ args) (Unix.environment ()) in let res = input_line git_out in - match Unix.close_process_full (git_out, git_in, git_err) with - | Unix.WEXITED 0 -> - Some res - | _ -> - None - with - | _ -> None + match Unix.close_process_full (git_out, git_in, git_err) with Unix.WEXITED 0 -> Some res | _ -> None + with _ -> None module type CONFIG = sig val embedding_mode : embedding option end -module Generator(Converter : Markdown.CONVERTER)(Config : CONFIG) = struct - let encode str = - match Config.embedding_mode with - | Some Plain | None -> str - | Some Base64 -> Base64.encode_string str +module Generator (Converter : Markdown.CONVERTER) (Config : CONFIG) = struct + let encode str = match Config.embedding_mode with Some Plain | None -> str | Some Base64 -> Base64.encode_string str - let embedding_format () = - match Config.embedding_mode with - | Some Plain | None -> Plain - | Some Base64 -> Base64 + let embedding_format () = match Config.embedding_mode with Some Plain | None -> Plain | Some Base64 -> Base64 let doc_lexing_pos p1 p2 = match Config.embedding_mode with - | Some _ -> - Raw (read_source p1 p2 |> encode) - | None -> - Location (p1.pos_fname, p1.pos_lnum, p1.pos_bol, p1.pos_cnum, p2.pos_lnum, p2.pos_bol, p2.pos_cnum) + | Some _ -> Raw (read_source p1 p2 |> encode) + | None -> Location (p1.pos_fname, p1.pos_lnum, p1.pos_bol, p1.pos_cnum, p2.pos_lnum, p2.pos_bol, p2.pos_cnum) let doc_loc l g f x = match Reporting.simp_loc l with - | Some (p1, p2) when p1.pos_fname = p2.pos_fname && Filename.is_relative p1.pos_fname -> - doc_lexing_pos p1 p2 - | _ -> - Raw (g x |> f |> Pretty_print_sail.to_string |> encode) + | Some (p1, p2) when p1.pos_fname = p2.pos_fname && Filename.is_relative p1.pos_fname -> doc_lexing_pos p1 p2 + | _ -> Raw (g x |> f |> Pretty_print_sail.to_string |> encode) let get_doc_comment def_annot = - Option.map (fun comment -> + Option.map + (fun comment -> let conf = Converter.default_config ~loc:def_annot.loc in Converter.convert conf comment - ) def_annot.doc_comment + ) + def_annot.doc_comment - let docinfo_for_valspec (VS_aux (VS_val_spec ((TypSchm_aux (_, ts_l) as ts), _, _, _), vs_annot) as vs) = { + let docinfo_for_valspec (VS_aux (VS_val_spec ((TypSchm_aux (_, ts_l) as ts), _, _, _), vs_annot) as vs) = + { source = doc_loc (fst vs_annot) Type_check.strip_val_spec Pretty_print_sail.doc_spec vs; - type_source = doc_loc ts_l (fun ts -> ts) Pretty_print_sail.doc_typschm ts + type_source = doc_loc ts_l (fun ts -> ts) Pretty_print_sail.doc_typschm ts; } - let docinfo_for_typdef (TD_aux (_, annot) as td) = doc_loc (fst annot) Type_check.strip_typedef Pretty_print_sail.doc_typdef td + let docinfo_for_typdef (TD_aux (_, annot) as td) = + doc_loc (fst annot) Type_check.strip_typedef Pretty_print_sail.doc_typdef td - let docinfo_for_register (DEC_aux (DEC_reg ((Typ_aux (_, typ_l) as typ), _, exp), rd_annot) as rd) = { + let docinfo_for_register (DEC_aux (DEC_reg ((Typ_aux (_, typ_l) as typ), _, exp), rd_annot) as rd) = + { source = doc_loc (fst rd_annot) Type_check.strip_register Pretty_print_sail.doc_dec rd; type_source = doc_loc typ_l (fun typ -> typ) Pretty_print_sail.doc_typ typ; - exp_source = Option.map (fun (E_aux (_, (l, _)) as exp) -> doc_loc l Type_check.strip_exp Pretty_print_sail.doc_exp exp) exp + exp_source = + Option.map (fun (E_aux (_, (l, _)) as exp) -> doc_loc l Type_check.strip_exp Pretty_print_sail.doc_exp exp) exp; } - let docinfo_for_let (LB_aux (LB_val (_, exp), annot) as lbind) = { + let docinfo_for_let (LB_aux (LB_val (_, exp), annot) as lbind) = + { source = doc_loc (fst annot) Type_check.strip_letbind Pretty_print_sail.doc_letbind lbind; exp_source = doc_loc (exp_loc exp) Type_check.strip_exp Pretty_print_sail.doc_exp exp; } let funcl_splits ~ast ~error_loc:l attrs exp = (* The constant propagation tends to strip away block formatting, so put it back to make the pretty_printed output a bit nicer. *) - let pretty_printer = match exp with - | E_aux (E_block _, _) -> (fun exp -> Pretty_print_sail.doc_block [exp]) - | _ -> (fun exp -> Pretty_print_sail.doc_exp exp) + let pretty_printer = + match exp with + | E_aux (E_block _, _) -> fun exp -> Pretty_print_sail.doc_block [exp] + | _ -> fun exp -> Pretty_print_sail.doc_exp exp in match find_attribute_opt "split" attrs with | None -> None - | Some split_id -> - let split_id = mk_id split_id in - let env = Type_check.env_of exp in - match Type_check.Env.lookup_id split_id env with - | Local (_, (Typ_aux (Typ_id enum_id, _) as enum_typ)) -> - let members = Type_check.Env.get_enum enum_id env in - let splits = - List.fold_left (fun splits member -> - let checked_member = Type_check.check_exp env (mk_exp (E_id member)) enum_typ in - let substs = (Bindings.singleton split_id checked_member, KBindings.empty) in - let (propagated, _) = Constant_propagation.const_prop "doc" ast IdSet.empty substs Bindings.empty exp in - let propagated_doc = Raw (pretty_printer (Type_check.strip_exp propagated) |> Pretty_print_sail.to_string |> encode) in - Bindings.add member propagated_doc splits - ) Bindings.empty members in - Some splits - | _ -> - raise (Reporting.err_general l ("Could not split on variable " ^ string_of_id split_id)) + | Some split_id -> ( + let split_id = mk_id split_id in + let env = Type_check.env_of exp in + match Type_check.Env.lookup_id split_id env with + | Local (_, (Typ_aux (Typ_id enum_id, _) as enum_typ)) -> + let members = Type_check.Env.get_enum enum_id env in + let splits = + List.fold_left + (fun splits member -> + let checked_member = Type_check.check_exp env (mk_exp (E_id member)) enum_typ in + let substs = (Bindings.singleton split_id checked_member, KBindings.empty) in + let propagated, _ = Constant_propagation.const_prop "doc" ast IdSet.empty substs Bindings.empty exp in + let propagated_doc = + Raw (pretty_printer (Type_check.strip_exp propagated) |> Pretty_print_sail.to_string |> encode) + in + Bindings.add member propagated_doc splits + ) + Bindings.empty members + in + Some splits + | _ -> raise (Reporting.err_general l ("Could not split on variable " ^ string_of_id split_id)) + ) let docinfo_for_funcl ~ast ?outer_annot n (FCL_aux (FCL_funcl (_, pexp), annot) as clause) = (* If we have just a single clause, we use the annotation for the @@ -490,34 +472,42 @@ module Generator(Converter : Markdown.CONVERTER)(Config : CONFIG) = struct let attrs = match outer_annot with None -> (fst annot).attrs | Some outer -> (fst outer).attrs in let source = doc_loc (fst annot).loc Type_check.strip_funcl Pretty_print_sail.doc_funcl clause in - let pat, guard, exp = match pexp with - | Pat_aux (Pat_exp (pat, exp), _) -> pat, None, exp - | Pat_aux (Pat_when (pat, guard, exp), _) -> pat, Some guard, exp in - let guard_source = Option.map (fun exp -> doc_loc (exp_loc exp) Type_check.strip_exp Pretty_print_sail.doc_exp exp) guard in - let body_source = match exp with + let pat, guard, exp = + match pexp with + | Pat_aux (Pat_exp (pat, exp), _) -> (pat, None, exp) + | Pat_aux (Pat_when (pat, guard, exp), _) -> (pat, Some guard, exp) + in + let guard_source = + Option.map (fun exp -> doc_loc (exp_loc exp) Type_check.strip_exp Pretty_print_sail.doc_exp exp) guard + in + let body_source = + match exp with | E_aux (E_block (exp :: exps), _) -> - let first_loc = exp_loc exp in - let last_loc = exp_loc (Util.last (exp :: exps)) in - begin match Reporting.simp_loc first_loc, Reporting.simp_loc last_loc with - | Some (p1, _), Some (_, p2) when p1.pos_fname = p2.pos_fname && Filename.is_relative p1.pos_fname -> - (* Make sure the first line is indented correctly *) - doc_lexing_pos { p1 with pos_cnum = p1.pos_bol } p2 - | _, _ -> - let block = Type_check.strip_exp exp :: List.map Type_check.strip_exp exps in - Raw (Pretty_print_sail.doc_block block |> Pretty_print_sail.to_string |> encode) - end - | _ -> doc_loc (exp_loc exp) Type_check.strip_exp Pretty_print_sail.doc_exp exp in - - let splits = funcl_splits ~ast:ast ~error_loc:(pat_loc pat) attrs exp in - - { number = n; - source = source; - pat = pat; + let first_loc = exp_loc exp in + let last_loc = exp_loc (Util.last (exp :: exps)) in + begin + match (Reporting.simp_loc first_loc, Reporting.simp_loc last_loc) with + | Some (p1, _), Some (_, p2) when p1.pos_fname = p2.pos_fname && Filename.is_relative p1.pos_fname -> + (* Make sure the first line is indented correctly *) + doc_lexing_pos { p1 with pos_cnum = p1.pos_bol } p2 + | _, _ -> + let block = Type_check.strip_exp exp :: List.map Type_check.strip_exp exps in + Raw (Pretty_print_sail.doc_block block |> Pretty_print_sail.to_string |> encode) + end + | _ -> doc_loc (exp_loc exp) Type_check.strip_exp Pretty_print_sail.doc_exp exp + in + + let splits = funcl_splits ~ast ~error_loc:(pat_loc pat) attrs exp in + + { + number = n; + source; + pat; wavedrom = Wavedrom.of_pattern ~labels:None pat |> Option.map encode; - guard_source = guard_source; - body_source = body_source; + guard_source; + body_source; comment = Option.map encode comment; - splits = splits + splits; } let included_clause files (FCL_aux (_, (clause_annot, _))) = included_loc files clause_annot.loc @@ -526,62 +516,60 @@ module Generator(Converter : Markdown.CONVERTER)(Config : CONFIG) = struct let clauses = List.filter (included_clause files) clauses in match clauses with | [] -> None - | [clause] -> - Some (Single_clause (docinfo_for_funcl ~ast:ast ~outer_annot:(def_annot, snd annot) 0 clause)) - | _ -> - Some (Multiple_clauses (List.mapi (docinfo_for_funcl ~ast:ast) clauses)) + | [clause] -> Some (Single_clause (docinfo_for_funcl ~ast ~outer_annot:(def_annot, snd annot) 0 clause)) + | _ -> Some (Multiple_clauses (List.mapi (docinfo_for_funcl ~ast) clauses)) let docinfo_for_mpexp (MPat_aux (aux, _)) = - match aux with - | MPat_pat mpat -> Rewrites.pat_of_mpat mpat - | MPat_when (mpat, _) -> Rewrites.pat_of_mpat mpat + match aux with MPat_pat mpat -> Rewrites.pat_of_mpat mpat | MPat_when (mpat, _) -> Rewrites.pat_of_mpat mpat let docinfo_for_mapcl n (MCL_aux (aux, (def_annot, _)) as clause) = let source = doc_loc def_annot.loc Type_check.strip_mapcl Pretty_print_sail.doc_mapcl clause in let wavedrom_attr = find_attribute_opt "wavedrom" def_annot.attrs in - let left, left_wavedrom, right, right_wavedrom, body = match aux with + let left, left_wavedrom, right, right_wavedrom, body = + match aux with | MCL_bidir (left, right) -> - let left = docinfo_for_mpexp left in - let left_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr left in - let right = docinfo_for_mpexp right in - let right_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr right in - (Some left, left_wavedrom, Some right, right_wavedrom, None) + let left = docinfo_for_mpexp left in + let left_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr left in + let right = docinfo_for_mpexp right in + let right_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr right in + (Some left, left_wavedrom, Some right, right_wavedrom, None) | MCL_forwards (left, body) -> - let left = docinfo_for_mpexp left in - let left_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr left in - let body = doc_loc (exp_loc body) Type_check.strip_exp Pretty_print_sail.doc_exp body in - (Some left, left_wavedrom, None, None, Some body) + let left = docinfo_for_mpexp left in + let left_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr left in + let body = doc_loc (exp_loc body) Type_check.strip_exp Pretty_print_sail.doc_exp body in + (Some left, left_wavedrom, None, None, Some body) | MCL_backwards (right, body) -> - let right = docinfo_for_mpexp right in - let right_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr right in - let body = doc_loc (exp_loc body) Type_check.strip_exp Pretty_print_sail.doc_exp body in - (None, None, Some right, right_wavedrom, Some body) in - - { number = n; - source = source; - left = left; + let right = docinfo_for_mpexp right in + let right_wavedrom = Wavedrom.of_pattern ~labels:wavedrom_attr right in + let body = doc_loc (exp_loc body) Type_check.strip_exp Pretty_print_sail.doc_exp body in + (None, None, Some right, right_wavedrom, Some body) + in + + { + number = n; + source; + left; left_wavedrom = Option.map encode left_wavedrom; - right = right; + right; right_wavedrom = Option.map encode right_wavedrom; - body = body + body; } let included_mapping_clause files (MCL_aux (_, (def_annot, _))) = included_loc files def_annot.loc let docinfo_for_mapdef files (MD_aux (MD_mapping (_, _, clauses), _)) = let clauses = List.filter (included_mapping_clause files) clauses in - match clauses with - | [] -> None - | _ -> - Some (List.mapi docinfo_for_mapcl clauses) + match clauses with [] -> None | _ -> Some (List.mapi docinfo_for_mapcl clauses) let docinfo_for_ast ~files ~hyperlinks ast = let gitinfo = git_command "rev-parse HEAD" - |> Option.map (fun checksum -> (checksum, Option.is_none (git_command "diff --quiet"))) in + |> Option.map (fun checksum -> (checksum, Option.is_none (git_command "diff --quiet"))) + in - let empty_docinfo = { + let empty_docinfo = + { embedding = embedding_format (); git = gitinfo; hashes = []; @@ -593,91 +581,79 @@ module Generator(Converter : Markdown.CONVERTER)(Config : CONFIG) = struct lets = Bindings.empty; anchors = Bindings.empty; spans = Bindings.empty; - } in - let initial_skip = match files with - | [] -> false - | _ -> true in - let skip_file file = - if List.exists (same_file file) files then ( - false - ) else ( - initial_skip - ) in - let skipping = function - | true :: _ -> true - | _ -> false in + } + in + let initial_skip = match files with [] -> false | _ -> true in + let skip_file file = if List.exists (same_file file) files then false else initial_skip in + let skipping = function true :: _ -> true | _ -> false in let docinfo_for_def (docinfo, skips) (DEF_aux (aux, def_annot) as def) = let links = hyperlinks files def in match aux with (* Maintain a stack of booleans, for each file if it was not specified via -doc_file, we push true to skip it. If no -doc_file flags are passed, include everything. *) - | DEF_pragma (("file_start" | "include_start"), path, _) -> - docinfo, (skip_file path :: skips) - | DEF_pragma (("file_end" | "include_end"), _, _) -> - docinfo, (match skips with _ :: skips -> skips | [] -> []) - + | DEF_pragma (("file_start" | "include_start"), path, _) -> (docinfo, skip_file path :: skips) + | DEF_pragma (("file_end" | "include_end"), _, _) -> ( + (docinfo, match skips with _ :: skips -> skips | [] -> []) + ) (* Function definiton may be scattered, so we can't skip it *) | DEF_fundef fdef -> - let id = id_of_fundef fdef in - begin match docinfo_for_fundef ~ast:ast def_annot files fdef with - | None -> docinfo - | Some info -> { docinfo with functions = Bindings.add id (info, links) docinfo.functions } - end, - skips - + let id = id_of_fundef fdef in + ( begin + match docinfo_for_fundef ~ast def_annot files fdef with + | None -> docinfo + | Some info -> { docinfo with functions = Bindings.add id (info, links) docinfo.functions } + end, + skips + ) | DEF_mapdef mdef -> - let id = id_of_mapdef mdef in - begin match docinfo_for_mapdef files mdef with - | None -> docinfo - | Some info -> { docinfo with mappings = Bindings.add id (info, links) docinfo.mappings } - end, - skips - - | _ when skipping skips -> - docinfo, skips - + let id = id_of_mapdef mdef in + ( begin + match docinfo_for_mapdef files mdef with + | None -> docinfo + | Some info -> { docinfo with mappings = Bindings.add id (info, links) docinfo.mappings } + end, + skips + ) + | _ when skipping skips -> (docinfo, skips) | DEF_val vs -> - let id = id_of_val_spec vs in - { docinfo with valspecs = Bindings.add id (docinfo_for_valspec vs, links) docinfo.valspecs }, - skips - + let id = id_of_val_spec vs in + ({ docinfo with valspecs = Bindings.add id (docinfo_for_valspec vs, links) docinfo.valspecs }, skips) | DEF_type td -> - let id = id_of_type_def td in - { docinfo with typdefs = Bindings.add id (docinfo_for_typdef td, links) docinfo.typdefs }, - skips - + let id = id_of_type_def td in + ({ docinfo with typdefs = Bindings.add id (docinfo_for_typdef td, links) docinfo.typdefs }, skips) | DEF_register rd -> - let id = id_of_dec_spec rd in - { docinfo with registers = Bindings.add id (docinfo_for_register rd, links) docinfo.registers }, - skips - - | DEF_let (LB_aux (LB_val (pat, _ ), _) as letbind) -> - let ids = pat_ids pat in - IdSet.fold (fun id docinfo -> - { docinfo with lets = Bindings.add id (docinfo_for_let letbind, links) docinfo.lets } - ) ids docinfo, - skips - - | _ -> - docinfo, skips + let id = id_of_dec_spec rd in + ({ docinfo with registers = Bindings.add id (docinfo_for_register rd, links) docinfo.registers }, skips) + | DEF_let (LB_aux (LB_val (pat, _), _) as letbind) -> + let ids = pat_ids pat in + ( IdSet.fold + (fun id docinfo -> { docinfo with lets = Bindings.add id (docinfo_for_let letbind, links) docinfo.lets }) + ids docinfo, + skips + ) + | _ -> (docinfo, skips) in let docinfo = List.fold_left docinfo_for_def (empty_docinfo, [initial_skip]) ast.defs |> fst in let process_anchors docinfo = let anchored = ref Bindings.empty in - List.iter (fun (DEF_aux (aux, def_annot) as def) -> + List.iter + (fun (DEF_aux (aux, def_annot) as def) -> let l = def_loc def in match aux with | DEF_pragma ("anchor", arg, _) -> - let links = hyperlinks files def in - let anchor_info = { - source = doc_loc l Type_check.strip_def Pretty_print_sail.doc_def def; - comment = def_annot.doc_comment - } in - anchored := Bindings.add (mk_id arg) (anchor_info, links) !anchored; + let links = hyperlinks files def in + let anchor_info = + { + source = doc_loc l Type_check.strip_def Pretty_print_sail.doc_def def; + comment = def_annot.doc_comment; + } + in + anchored := Bindings.add (mk_id arg) (anchor_info, links) !anchored | _ -> () - ) ast.defs; + ) + ast.defs; { docinfo with anchors = !anchored } in let docinfo = process_anchors docinfo in @@ -685,56 +661,48 @@ module Generator(Converter : Markdown.CONVERTER)(Config : CONFIG) = struct let process_spans docinfo = let spans = ref Bindings.empty in let current_span = ref None in - List.iter (fun (DEF_aux (aux, def_annot)) -> + List.iter + (fun (DEF_aux (aux, def_annot)) -> match aux with - | DEF_pragma ("span", arg, _) when Option.is_none !current_span -> - begin match String.split_on_char ' ' arg with - | ["start"; name] -> - current_span := Some (name, def_annot.loc) - | _ -> - raise (Reporting.err_general def_annot.loc "Invalid span directive") - end - - | DEF_pragma ("span", arg, _) when arg = "end" -> - begin match !current_span with - | Some (name, start_l) -> - let end_l = def_annot.loc in - begin match Reporting.simp_loc start_l, Reporting.simp_loc end_l with - | Some (_, p1), Some (p2, _) when p1.pos_fname = p2.pos_fname -> - (* Adjust the span for p2 to end at the very start of the directive *) - let p2 = { p2 with pos_cnum = p2.pos_bol } in - spans := Bindings.add (mk_id name) (doc_lexing_pos p1 p2) !spans - | _, _ -> - raise (Reporting.err_general def_annot.loc "Invalid locations found when ending span") - end - | None -> - raise (Reporting.err_general def_annot.loc "No start span for this end span") - end - + | DEF_pragma ("span", arg, _) when Option.is_none !current_span -> begin + match String.split_on_char ' ' arg with + | ["start"; name] -> current_span := Some (name, def_annot.loc) + | _ -> raise (Reporting.err_general def_annot.loc "Invalid span directive") + end + | DEF_pragma ("span", arg, _) when arg = "end" -> begin + match !current_span with + | Some (name, start_l) -> + let end_l = def_annot.loc in + begin + match (Reporting.simp_loc start_l, Reporting.simp_loc end_l) with + | Some (_, p1), Some (p2, _) when p1.pos_fname = p2.pos_fname -> + (* Adjust the span for p2 to end at the very start of the directive *) + let p2 = { p2 with pos_cnum = p2.pos_bol } in + spans := Bindings.add (mk_id name) (doc_lexing_pos p1 p2) !spans + | _, _ -> raise (Reporting.err_general def_annot.loc "Invalid locations found when ending span") + end + | None -> raise (Reporting.err_general def_annot.loc "No start span for this end span") + end | DEF_pragma ("span", _, _) -> - raise (Reporting.err_general def_annot.loc "Previous span must be ended before this one can begin") - + raise (Reporting.err_general def_annot.loc "Previous span must be ended before this one can begin") | _ -> () - ) ast.defs; + ) + ast.defs; { docinfo with spans = !spans } in let docinfo = process_spans docinfo in - let module StringMap = Map.Make(String) in + let module StringMap = Map.Make (String) in let process_file_hashes hashes (DEF_aux (_, doc_annot)) = if included_loc files doc_annot.loc then ( match Reporting.simp_loc doc_annot.loc with | None -> hashes | Some (p1, _) -> - if StringMap.mem p1.pos_fname hashes then ( - hashes - ) else ( - StringMap.add p1.pos_fname (hash_file p1.pos_fname) hashes - ) - ) else ( - hashes - ) in + if StringMap.mem p1.pos_fname hashes then hashes + else StringMap.add p1.pos_fname (hash_file p1.pos_fname) hashes + ) + else hashes + in let hashes = List.fold_left process_file_hashes StringMap.empty ast.defs in { docinfo with hashes = StringMap.bindings hashes } - end diff --git a/src/sail_doc_backend/docinfo.mli b/src/sail_doc_backend/docinfo.mli index ea6c690b8..e137ed931 100644 --- a/src/sail_doc_backend/docinfo.mli +++ b/src/sail_doc_backend/docinfo.mli @@ -85,6 +85,7 @@ module type CONFIG = sig val embedding_mode : embedding option end -module Generator(Converter : Markdown.CONVERTER)(Config: CONFIG) : sig - val docinfo_for_ast : files:string list -> hyperlinks:(string list -> tannot def -> hyperlink list) -> tannot ast -> tannot docinfo +module Generator (Converter : Markdown.CONVERTER) (Config : CONFIG) : sig + val docinfo_for_ast : + files:string list -> hyperlinks:(string list -> tannot def -> hyperlink list) -> tannot ast -> tannot docinfo end diff --git a/src/sail_doc_backend/dune b/src/sail_doc_backend/dune index 41d2e158e..0377e7dbe 100644 --- a/src/sail_doc_backend/dune +++ b/src/sail_doc_backend/dune @@ -1,12 +1,14 @@ - (executable - (name sail_plugin_doc) - (modes (native plugin)) - (link_flags -linkall) - (libraries libsail omd yojson base64) - (embed_in_plugin_libraries omd base64)) + (name sail_plugin_doc) + (modes + (native plugin)) + (link_flags -linkall) + (libraries libsail omd yojson base64) + (embed_in_plugin_libraries omd base64)) (install - (section (site (libsail plugins))) - (package sail_doc_backend) - (files sail_plugin_doc.cmxs)) + (section + (site + (libsail plugins))) + (package sail_doc_backend) + (files sail_plugin_doc.cmxs)) diff --git a/src/sail_doc_backend/markdown.ml b/src/sail_doc_backend/markdown.ml index 131563b00..b44175e45 100644 --- a/src/sail_doc_backend/markdown.ml +++ b/src/sail_doc_backend/markdown.ml @@ -87,44 +87,33 @@ module AsciidocConverter : CONVERTER = struct open Printf open Omd - type config = { - this : Ast.id option; - loc : Parse_ast.l; - list_depth : int - } + type config = { this : Ast.id option; loc : Parse_ast.l; list_depth : int } - let default_config ~loc = { - this = None; - loc = loc; - list_depth = 1 - } + let default_config ~loc = { this = None; loc; list_depth = 1 } - let rec format_elem (conf: config) = function + let rec format_elem (conf : config) = function | Paragraph elems -> format conf elems ^ "\n\n" | Text str -> str | Emph elems -> sprintf "_%s_" (format conf elems) | Bold elems -> sprintf "*%s*" (format conf elems) | Code (_, code) -> sprintf "`%s`" code - | Code_block (lang, code) -> - sprintf "[source,%s]\n----\n%s\n----\n\n" lang code + | Code_block (lang, code) -> sprintf "[source,%s]\n----\n%s\n----\n\n" lang code | Br -> "\n" | NL -> "\n" | H1 header -> "= " ^ format conf header ^ "\n" | H2 header -> "== " ^ format conf header ^ "\n" | H3 header -> "=== " ^ format conf header ^ "\n" | H4 header -> "==== " ^ format conf header ^ "\n" - | (Ul list | Ulp list) -> - Util.string_of_list "" (fun item -> - let new_conf = { conf with list_depth = conf.list_depth + 1 } in - "\n" ^ String.make conf.list_depth '*' ^ " " ^ format new_conf item - ) list - | _ -> - raise (Reporting.err_general conf.loc "Cannot convert markdown element to Asciidoc") + | Ul list | Ulp list -> + Util.string_of_list "" + (fun item -> + let new_conf = { conf with list_depth = conf.list_depth + 1 } in + "\n" ^ String.make conf.list_depth '*' ^ " " ^ format new_conf item + ) + list + | _ -> raise (Reporting.err_general conf.loc "Cannot convert markdown element to Asciidoc") - and format conf elems = - String.concat "" (List.map (format_elem conf) elems) - - let convert conf comment = - format conf (Omd.of_string comment) + and format conf elems = String.concat "" (List.map (format_elem conf) elems) + let convert conf comment = format conf (Omd.of_string comment) end diff --git a/src/sail_doc_backend/sail_plugin_doc.ml b/src/sail_doc_backend/sail_plugin_doc.ml index e59c68796..aa141380e 100644 --- a/src/sail_doc_backend/sail_plugin_doc.ml +++ b/src/sail_doc_backend/sail_plugin_doc.ml @@ -73,40 +73,37 @@ let opt_doc_embed = ref None let opt_doc_compact = ref false let opt_doc_bundle = ref "doc.json" -let embedding_option () = match !opt_doc_embed with +let embedding_option () = + match !opt_doc_embed with | None -> None | Some "plain" -> Some Docinfo.Plain | Some "base64" -> Some Docinfo.Base64 | Some embedding -> - Printf.eprintf "Unknown embedding type %s for -doc_embed, allowed values are 'plain' or 'base64'\n" embedding; - exit 1 + Printf.eprintf "Unknown embedding type %s for -doc_embed, allowed values are 'plain' or 'base64'\n" embedding; + exit 1 -let doc_options = [ - ( "-doc_format", - Arg.String (fun format -> opt_doc_format := format), - " Output documentation in the chosen format, either latex or asciidoc (default asciidoc)"); - ( "-doc_file", - Arg.String (fun file -> opt_doc_files := file :: !opt_doc_files), - " Document only the provided files"); - ( "-doc_embed", - Arg.String (fun format -> opt_doc_embed := Some format), - " Embed all documentation contents into the documentation bundle rather than referencing it"); - ( "-doc_compact", - Arg.Unit (fun _ -> opt_doc_compact := true), - " Use compact documentation format"); - ( "-doc_bundle", - Arg.String (fun file -> opt_doc_bundle := file), - " Name for documentation bundle file"); +let doc_options = + [ + ( "-doc_format", + Arg.String (fun format -> opt_doc_format := format), + " Output documentation in the chosen format, either latex or asciidoc (default asciidoc)" + ); + ( "-doc_file", + Arg.String (fun file -> opt_doc_files := file :: !opt_doc_files), + " Document only the provided files" + ); + ( "-doc_embed", + Arg.String (fun format -> opt_doc_embed := Some format), + " Embed all documentation contents into the documentation bundle rather than referencing it" + ); + ("-doc_compact", Arg.Unit (fun _ -> opt_doc_compact := true), " Use compact documentation format"); + ("-doc_bundle", Arg.String (fun file -> opt_doc_bundle := file), " Name for documentation bundle file"); ] let output_docinfo doc_dir docinfo = let chan = open_out (Filename.concat doc_dir !opt_doc_bundle) in let json = Docinfo.docinfo_to_json docinfo in - if !opt_doc_compact then ( - Yojson.to_channel ~std:true chan json - ) else ( - Yojson.pretty_to_channel ~std:true chan json - ); + if !opt_doc_compact then Yojson.to_channel ~std:true chan json else Yojson.pretty_to_channel ~std:true chan json; output_char chan '\n'; close_out chan @@ -119,30 +116,26 @@ let doc_target _ out_file ast _ _ = prerr_endline ("Failure: documentation output location exists and is not a directory: " ^ doc_dir); exit 1 ) - with Sys_error(_) -> Unix.mkdir doc_dir 0o755 + with Sys_error _ -> Unix.mkdir doc_dir 0o755 end; - if !opt_doc_format = "asciidoc" || !opt_doc_format = "adoc" then ( + if !opt_doc_format = "asciidoc" || !opt_doc_format = "adoc" then let module Config = struct - let embedding_mode = embedding_option() - end in - let module Gen = Docinfo.Generator(Markdown.AsciidocConverter)(Config) in + let embedding_mode = embedding_option () + end in + let module Gen = Docinfo.Generator (Markdown.AsciidocConverter) (Config) in let docinfo = Gen.docinfo_for_ast ~files:!opt_doc_files ~hyperlinks:Docinfo.hyperlinks_from_def ast in output_docinfo doc_dir docinfo - ) else if !opt_doc_format = "identity" then ( + else if !opt_doc_format = "identity" then let module Config = struct - let embedding_mode = embedding_option() - end in - let module Gen = Docinfo.Generator(Markdown.IdentityConverter)(Config) in + let embedding_mode = embedding_option () + end in + let module Gen = Docinfo.Generator (Markdown.IdentityConverter) (Config) in let docinfo = Gen.docinfo_for_ast ~files:!opt_doc_files ~hyperlinks:Docinfo.hyperlinks_from_def ast in output_docinfo doc_dir docinfo - ) else ( - Printf.eprintf "Unknown documentation format: %s\n" !opt_doc_format - ) + else Printf.eprintf "Unknown documentation format: %s\n" !opt_doc_format let _ = - Target.register - ~name:"doc" - ~options:doc_options + Target.register ~name:"doc" ~options:doc_options ~pre_parse_hook:(fun () -> Type_check.opt_expand_valspec := false; Type_check.opt_no_bitfield_expansion := true diff --git a/src/sail_doc_backend/wavedrom.ml b/src/sail_doc_backend/wavedrom.ml index 291220b33..fc5e57164 100644 --- a/src/sail_doc_backend/wavedrom.ml +++ b/src/sail_doc_backend/wavedrom.ml @@ -78,14 +78,11 @@ exception Invalid_wavedrom let process_attr_arg = function | None -> [] | Some arg -> - let labels = String.split_on_char ' ' arg |> List.filter (fun label -> label <> "") in - List.map (function - | "_" -> None - | label -> Some label - ) labels + let labels = String.split_on_char ' ' arg |> List.filter (fun label -> label <> "") in + List.map (function "_" -> None | label -> Some label) labels let rec zip_labels xs ys = - match xs, ys with + match (xs, ys) with | [], ys -> List.map (fun y -> (None, y)) ys | _, [] -> [] | x :: xs, y :: ys -> (x, y) :: zip_labels xs ys @@ -96,56 +93,45 @@ let wavedrom_label size = function let binary_to_hex str = let open Sail2_values in - let padded = match String.length str mod 4 with - | 0 -> str - | 1 -> "000" ^ str - | 2 -> "00" ^ str - | _ -> "0" ^ str in + let padded = match String.length str mod 4 with 0 -> str | 1 -> "000" ^ str | 2 -> "00" ^ str | _ -> "0" ^ str in Util.string_to_list padded |> List.map (function '0' -> B0 | _ -> B1) - |> hexstring_of_bits - |> Option.get + |> hexstring_of_bits |> Option.get |> Util.string_of_list "" (fun c -> String.make 1 c) let rec wavedrom_elem_string size label (P_aux (aux, _)) = match aux with | P_id id -> - Printf.sprintf " { bits: %d, name: '%s'%s, type: 2 }" size (string_of_id id) (wavedrom_label size label) + Printf.sprintf " { bits: %d, name: '%s'%s, type: 2 }" size (string_of_id id) (wavedrom_label size label) | P_lit (L_aux (L_bin bin, _)) -> - Printf.sprintf " { bits: %d, name: 0x%s%s, type: 8 }" size (binary_to_hex bin) (wavedrom_label size label) + Printf.sprintf " { bits: %d, name: 0x%s%s, type: 8 }" size (binary_to_hex bin) (wavedrom_label size label) | P_lit (L_aux (L_hex hex, _)) -> - Printf.sprintf " { bits: %d, name: 0x%s%s, type: 8 }" size hex (wavedrom_label size label) + Printf.sprintf " { bits: %d, name: 0x%s%s, type: 8 }" size hex (wavedrom_label size label) | P_vector_subrange (_, n, m) when Big_int.equal n m -> - Printf.sprintf " { bits: %d, name: '[%s]'%s, type: 3 }" - size (Big_int.to_string n) (wavedrom_label size label) + Printf.sprintf " { bits: %d, name: '[%s]'%s, type: 3 }" size (Big_int.to_string n) (wavedrom_label size label) | P_vector_subrange (id, n, m) -> - Printf.sprintf " { bits: %d, name: '%s[%s..%s]'%s, type: 3 }" - size (string_of_id id) (Big_int.to_string n) (Big_int.to_string m) (wavedrom_label size label) - | P_as (pat, _) | P_typ (_, pat) -> - wavedrom_elem_string size label pat + Printf.sprintf " { bits: %d, name: '%s[%s..%s]'%s, type: 3 }" size (string_of_id id) (Big_int.to_string n) + (Big_int.to_string m) (wavedrom_label size label) + | P_as (pat, _) | P_typ (_, pat) -> wavedrom_elem_string size label pat | _ -> raise Invalid_wavedrom let wavedrom_elem (label, (P_aux (_, (_, tannot)) as pat)) = match Type_check.destruct_tannot tannot with | None -> raise Invalid_wavedrom - | Some (env, typ) -> - match Type_check.destruct_bitvector env typ with - | Some (Nexp_aux (Nexp_constant size, _), _) -> - let size = Big_int.to_int size in - wavedrom_elem_string size label pat - | _ -> raise Invalid_wavedrom + | Some (env, typ) -> ( + match Type_check.destruct_bitvector env typ with + | Some (Nexp_aux (Nexp_constant size, _), _) -> + let size = Big_int.to_int size in + wavedrom_elem_string size label pat + | _ -> raise Invalid_wavedrom + ) let of_pattern' attr_arg = function | P_aux (P_vector_concat xs, _) -> - let labels = process_attr_arg attr_arg in - let elems = List.rev_map wavedrom_elem (zip_labels labels xs) in - let strs = Util.string_of_list ",\n" (fun x -> x) elems in - Printf.sprintf "{reg:[\n%s\n]}" strs - | _ -> - raise Invalid_wavedrom + let labels = process_attr_arg attr_arg in + let elems = List.rev_map wavedrom_elem (zip_labels labels xs) in + let strs = Util.string_of_list ",\n" (fun x -> x) elems in + Printf.sprintf "{reg:[\n%s\n]}" strs + | _ -> raise Invalid_wavedrom -let of_pattern ~labels:attr_arg pat = - try - Some (of_pattern' attr_arg pat) - with - | Invalid_wavedrom -> None +let of_pattern ~labels:attr_arg pat = try Some (of_pattern' attr_arg pat) with Invalid_wavedrom -> None diff --git a/src/sail_latex_backend/dune b/src/sail_latex_backend/dune index 14235d144..9a518cbd3 100644 --- a/src/sail_latex_backend/dune +++ b/src/sail_latex_backend/dune @@ -1,17 +1,22 @@ (env - (dev - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) - (release - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) + (dev + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) + (release + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) (executable - (name sail_plugin_latex) - (modes (native plugin)) - (link_flags -linkall) - (libraries libsail omd) - (embed_in_plugin_libraries omd)) + (name sail_plugin_latex) + (modes + (native plugin)) + (link_flags -linkall) + (libraries libsail omd) + (embed_in_plugin_libraries omd)) (install - (section (site (libsail plugins))) - (package sail_latex_backend) - (files sail_plugin_latex.cmxs)) + (section + (site + (libsail plugins))) + (package sail_latex_backend) + (files sail_plugin_latex.cmxs)) diff --git a/src/sail_latex_backend/latex.ml b/src/sail_latex_backend/latex.ml index 809ed62bb..90243e58d 100644 --- a/src/sail_latex_backend/latex.ml +++ b/src/sail_latex_backend/latex.ml @@ -73,7 +73,7 @@ open Ast_util open PPrint open Printf -module StringSet = Set.Make(String);; +module StringSet = Set.Make (String) let opt_prefix = ref "sail" let opt_directory = ref "sail_latex" @@ -81,20 +81,17 @@ let opt_simple_val = ref true let opt_abbrevs = ref ["e.g."; "i.e."] let rec unique_postfix n = - if n < 0 then - "" - else if n >= 26 then - String.make 1 (Char.chr (n mod 26 + 65)) ^ unique_postfix (n - 26) - else - String.make 1 (Char.chr (n mod 26 + 65)) - -type latex_state = - { mutable noindent : bool; - mutable this : id option; - mutable norefs : StringSet.t; - mutable generated_names : string Bindings.t; - mutable commands : StringSet.t - } + if n < 0 then "" + else if n >= 26 then String.make 1 (Char.chr ((n mod 26) + 65)) ^ unique_postfix (n - 26) + else String.make 1 (Char.chr ((n mod 26) + 65)) + +type latex_state = { + mutable noindent : bool; + mutable this : id option; + mutable norefs : StringSet.t; + mutable generated_names : string Bindings.t; + mutable commands : StringSet.t; +} let reset_state state = state.noindent <- false; @@ -104,20 +101,18 @@ let reset_state state = state.commands <- StringSet.empty let state = - { noindent = false; + { + noindent = false; this = None; norefs = StringSet.empty; generated_names = Bindings.empty; - commands = StringSet.empty + commands = StringSet.empty; } let rec unique_postfix n = - if n < 0 then - "" - else if n >= 26 then - String.make 1 (Char.chr (n mod 26 + 65)) ^ unique_postfix (n - 26) - else - String.make 1 (Char.chr (n mod 26 + 65)) + if n < 0 then "" + else if n >= 26 then String.make 1 (Char.chr ((n mod 26) + 65)) ^ unique_postfix (n - 26) + else String.make 1 (Char.chr ((n mod 26) + 65)) type id_category = | Function @@ -132,29 +127,28 @@ type id_category = | Outcome let number_replacements = - [ ("0", "Zero"); - ("1", "One"); - ("2", "Two"); - ("3", "Three"); - ("4", "Four"); - ("5", "Five"); - ("6", "Six"); - ("7", "Seven"); - ("8", "Eight"); - ("9", "Nine") ] + [ + ("0", "Zero"); + ("1", "One"); + ("2", "Two"); + ("3", "Three"); + ("4", "Four"); + ("5", "Five"); + ("6", "Six"); + ("7", "Seven"); + ("8", "Eight"); + ("9", "Nine"); + ] (* add to this as needed *) -let other_replacements = - [ ("_", "Underscore") ] +let other_replacements = [("_", "Underscore")] let char_replace str replacements = List.fold_left (fun str (from, into) -> Str.global_replace (Str.regexp_string from) into str) str replacements -let replace_numbers str = - char_replace str number_replacements +let replace_numbers str = char_replace str number_replacements -let replace_others str = - char_replace str other_replacements +let replace_others str = char_replace str other_replacements let category_name = function | Function -> "fn" @@ -163,16 +157,14 @@ let category_name = function | Overload n -> "overload" ^ unique_postfix n | FunclNum n -> "fcl" ^ unique_postfix n | FunclCtor (id, n) -> - let str = replace_others (replace_numbers (Util.zencode_string (string_of_id id))) in - "fcl" ^ String.sub str 1 (String.length str - 1) ^ unique_postfix n + let str = replace_others (replace_numbers (Util.zencode_string (string_of_id id))) in + "fcl" ^ String.sub str 1 (String.length str - 1) ^ unique_postfix n | FunclApp str -> "fcl" ^ str | Let -> "let" | Register -> "register" | Outcome -> "outcome" -let category_name_val = function - | Val -> "" - | cat -> category_name cat +let category_name_val = function Val -> "" | cat -> category_name cat let category_name_simple = function | Function -> "fn" @@ -190,9 +182,8 @@ let category_name_simple = function a mapping from identifiers to strings in state so we always return the same latex id for a sail id. *) let latex_id_raw id = - if Bindings.mem id state.generated_names then - Bindings.find id state.generated_names - else + if Bindings.mem id state.generated_names then Bindings.find id state.generated_names + else ( let str = string_of_id id in let r = Str.regexp {|_\([a-zA-Z0-9]\)|} in let str = @@ -203,7 +194,8 @@ let latex_id_raw id = ignore (Str.search_forward r !str 0); let replace = (Str.matched_group 0 !str).[1] |> Char.uppercase_ascii |> String.make 1 in str := Str.replace_first r replace !str - done; "" + done; + "" with Not_found -> !str in (* If we have any other weird symbols in the id, remove them using Util.zencode_string (removing the z prefix) *) @@ -215,17 +207,15 @@ let latex_id_raw id = let generated = state.generated_names |> Bindings.bindings |> List.map snd |> StringSet.of_list in (* The above makes maps different names to the same name, so we need - to keep track of what names we've generated an ensure that they - remain unique. *) + to keep track of what names we've generated an ensure that they + remain unique. *) let rec unique n str = - if StringSet.mem (str ^ unique_postfix n) generated then - unique (n + 1) str - else - str ^ unique_postfix n + if StringSet.mem (str ^ unique_postfix n) generated then unique (n + 1) str else str ^ unique_postfix n in let str = unique (-1) str in state.generated_names <- Bindings.add id str state.generated_names; str + ) let latex_cat_id cat id = !opt_prefix ^ category_name cat ^ latex_id_raw id @@ -259,8 +249,7 @@ let guard_abbrevs str = Str.global_replace regex "\\saildocabbrev{\\1}\\2" str let text_code str = - str - |> guard_abbrevs + str |> guard_abbrevs |> Str.global_replace (Str.regexp_string "_") "\\_" |> Str.global_replace (Str.regexp_string ">") "$<$" |> Str.global_replace (Str.regexp_string "<") "$>$" @@ -268,70 +257,65 @@ let text_code str = let replace_this str = match state.this with | Some id -> - str - |> Str.global_replace (Str.regexp_string "NAME") (text_code (string_of_id id)) - |> Str.global_replace (Str.regexp_string "THIS") (inline_code (string_of_id id)) + str + |> Str.global_replace (Str.regexp_string "NAME") (text_code (string_of_id id)) + |> Str.global_replace (Str.regexp_string "THIS") (inline_code (string_of_id id)) | None -> str let latex_of_markdown str = let open Omd in let open Printf in - let rec format_elem = function | Paragraph elems -> - let prepend = if state.noindent then (state.noindent <- false; "\\noindent ") else "" in - prepend ^ format elems ^ "\n\n" + let prepend = + if state.noindent then ( + state.noindent <- false; + "\\noindent " + ) + else "" + in + prepend ^ format elems ^ "\n\n" | Text str -> text_code str | Emph elems -> sprintf "\\emph{%s}" (format elems) | Bold elems -> sprintf "\\textbf{%s}" (format elems) - | Ref (r, "THIS", alt, _) -> - begin match state.this with - | Some id -> sprintf "\\hyperref[%s]{%s}" (refcode_id id) (replace_this alt) - | None -> failwith "Cannot create link to THIS" - end + | Ref (r, "THIS", alt, _) -> begin + match state.this with + | Some id -> sprintf "\\hyperref[%s]{%s}" (refcode_id id) (replace_this alt) + | None -> failwith "Cannot create link to THIS" + end | Ref (r, name, alt, _) -> - (* special case for [id] (format as code) *) - let format_fn = if name = alt then inline_code else replace_this in - (* Do not attempt to escape link destinations wrapped in <> *) - if Str.string_match (Str.regexp "<.+>") name 0 then - sprintf "\\hyperref[%s]{%s}" (String.sub name 1 ((String.length name) - 2)) (format_fn alt) - else - begin match r#get_ref name with - | None -> sprintf "\\hyperref[%s]{%s}" (refcode_string name) (format_fn alt) - | Some (link, _) -> sprintf "\\hyperref[%s]{%s}" (refcode_string link) (format_fn alt) - end - | Url (href, text, "") -> - sprintf "\\href{%s}{%s}" href (format text) - | Url (href, text, reference) -> - sprintf "%s\\footnote{%s~\\url{%s}}" (format text) reference href - | Code (_, code) -> - sprintf "\\lstinline`%s`" code + (* special case for [id] (format as code) *) + let format_fn = if name = alt then inline_code else replace_this in + (* Do not attempt to escape link destinations wrapped in <> *) + if Str.string_match (Str.regexp "<.+>") name 0 then + sprintf "\\hyperref[%s]{%s}" (String.sub name 1 (String.length name - 2)) (format_fn alt) + else begin + match r#get_ref name with + | None -> sprintf "\\hyperref[%s]{%s}" (refcode_string name) (format_fn alt) + | Some (link, _) -> sprintf "\\hyperref[%s]{%s}" (refcode_string link) (format_fn alt) + end + | Url (href, text, "") -> sprintf "\\href{%s}{%s}" href (format text) + | Url (href, text, reference) -> sprintf "%s\\footnote{%s~\\url{%s}}" (format text) reference href + | Code (_, code) -> sprintf "\\lstinline`%s`" code | Code_block (lang, code) -> - let lang = if lang = "" then "sail" else lang in - let uid = Digest.string str |> Digest.to_hex in - let chan = open_out (Filename.concat !opt_directory (sprintf "block%s.%s" uid lang)) in - output_string chan code; - close_out chan; - sprintf "\\lstinputlisting[language=%s]{%s/block%s.%s}" lang !opt_directory uid lang - | (Ul list | Ulp list) -> - "\\begin{itemize}\n\\item " - ^ Util.string_of_list "\n\\item " format list - ^ "\n\\end{itemize}\n" - | (Ol list | Olp list) -> - "\\begin{enumerate}\n\\item " - ^ Util.string_of_list "\n\\item " format list - ^ "\n\\end{enumerate}\n" - | H1 header -> "\\section*{" ^ (format header) ^ "}\n" - | H2 header -> "\\subsection*{" ^ (format header) ^ "}\n" - | H3 header -> "\\subsubsection*{" ^ (format header) ^ "}\n" - | H4 header -> "\\paragraph*{" ^ (format header) ^ "}\n" + let lang = if lang = "" then "sail" else lang in + let uid = Digest.string str |> Digest.to_hex in + let chan = open_out (Filename.concat !opt_directory (sprintf "block%s.%s" uid lang)) in + output_string chan code; + close_out chan; + sprintf "\\lstinputlisting[language=%s]{%s/block%s.%s}" lang !opt_directory uid lang + | Ul list | Ulp list -> + "\\begin{itemize}\n\\item " ^ Util.string_of_list "\n\\item " format list ^ "\n\\end{itemize}\n" + | Ol list | Olp list -> + "\\begin{enumerate}\n\\item " ^ Util.string_of_list "\n\\item " format list ^ "\n\\end{enumerate}\n" + | H1 header -> "\\section*{" ^ format header ^ "}\n" + | H2 header -> "\\subsection*{" ^ format header ^ "}\n" + | H3 header -> "\\subsubsection*{" ^ format header ^ "}\n" + | H4 header -> "\\paragraph*{" ^ format header ^ "}\n" | Br -> "\n" | NL -> "\n" | elem -> failwith ("Can't convert to latex: " ^ Omd_backend.sexpr_of_md [elem]) - - and format elems = - String.concat "" (List.map format_elem elems) - in + and format elems = String.concat "" (List.map format_elem elems) in replace_this (format (of_string str)) @@ -340,57 +324,73 @@ let docstring _ = empty let add_links str = let r = Str.regexp {|\([a-zA-Z0-9_]+\)\([ ]*\)(|} in let subst s = - let keywords = StringSet.of_list - [ "function"; "forall"; "if"; "then"; "else"; "exit"; "return"; "match"; "vector"; - "assert"; "constraint"; "let"; "in"; "atom"; "range"; "throw"; "sizeof"; "foreach" ] + let keywords = + StringSet.of_list + [ + "function"; + "forall"; + "if"; + "then"; + "else"; + "exit"; + "return"; + "match"; + "vector"; + "assert"; + "constraint"; + "let"; + "in"; + "atom"; + "range"; + "throw"; + "sizeof"; + "foreach"; + ] in let fn = Str.matched_group 1 s in let spacing = Str.matched_group 2 s in - if StringSet.mem fn keywords || StringSet.mem fn state.norefs then - fn ^ spacing ^ "(" + if StringSet.mem fn keywords || StringSet.mem fn state.norefs then fn ^ spacing ^ "(" else - Printf.sprintf "#\\hyperref[%s]{%s}#%s(" (refcode_string fn) (Str.global_replace (Str.regexp "_") {|\_|} fn) spacing + Printf.sprintf "#\\hyperref[%s]{%s}#%s(" (refcode_string fn) + (Str.global_replace (Str.regexp "_") {|\_|} fn) + spacing in Str.global_substitute r subst str let rec skip_lines in_chan = function | n when n <= 0 -> () - | n -> ignore (input_line in_chan); skip_lines in_chan (n - 1) + | n -> + ignore (input_line in_chan); + skip_lines in_chan (n - 1) let rec read_lines in_chan = function | n when n <= 0 -> [] | n -> - let l = input_line in_chan in - let ls = read_lines in_chan (n - 1) in - l :: ls + let l = input_line in_chan in + let ls = read_lines in_chan (n - 1) in + l :: ls let latex_loc no_loc l = match Reporting.simp_loc l with - | Some (p1, p2) -> - begin - let open Lexing in - try - let in_chan = open_in p1.pos_fname in - try - skip_lines in_chan (p1.pos_lnum - 3); - let code = read_lines in_chan ((p2.pos_lnum - p1.pos_lnum) + 3) in - close_in in_chan; - let doc = match code with - | _ :: _ :: code -> string (add_links (String.concat "\n" code)) - | _ -> empty - in - doc ^^ hardline - with - | _ -> close_in_noerr in_chan; docstring l ^^ no_loc - with - | _ -> docstring l ^^ no_loc - end + | Some (p1, p2) -> begin + let open Lexing in + try + let in_chan = open_in p1.pos_fname in + try + skip_lines in_chan (p1.pos_lnum - 3); + let code = read_lines in_chan (p2.pos_lnum - p1.pos_lnum + 3) in + close_in in_chan; + let doc = match code with _ :: _ :: code -> string (add_links (String.concat "\n" code)) | _ -> empty in + doc ^^ hardline + with _ -> + close_in_noerr in_chan; + docstring l ^^ no_loc + with _ -> docstring l ^^ no_loc + end | None -> docstring l ^^ no_loc let doc_spec_simple (VS_aux (VS_val_spec (ts, id, ext, is_cast), _)) = - Pretty_print_sail.doc_id id ^^ space - ^^ colon ^^ space - ^^ Pretty_print_sail.doc_typschm ~simple:true ts + Pretty_print_sail.doc_id id ^^ space ^^ colon ^^ space ^^ Pretty_print_sail.doc_typschm ~simple:true ts let latex_command cat id no_loc l = state.this <- Some id; @@ -402,19 +402,21 @@ let latex_command cat id no_loc l = output_string chan (Pretty_print_sail.to_string doc); close_out chan; let command = sprintf "\\%s" (latex_cat_id cat id) in - if StringSet.mem command state.commands then - (Reporting.warn "" l ("Multiple instances of " ^ string_of_id id ^ " only generating latex for the first"); empty) - else - begin - state.commands <- StringSet.add command state.commands; - - ksprintf string "\\newcommand{%s}{\\saildoclabelled{%s}{\\saildoc%s{" command (refcode_cat_id cat id) (category_name_simple cat) - ^^ docstring l ^^ string "}{" - ^^ ksprintf string "\\lstinputlisting[language=sail]{%s}}}}" (Filename.concat !opt_directory code_file) - end + if StringSet.mem command state.commands then ( + Reporting.warn "" l ("Multiple instances of " ^ string_of_id id ^ " only generating latex for the first"); + empty + ) + else begin + state.commands <- StringSet.add command state.commands; + + ksprintf string "\\newcommand{%s}{\\saildoclabelled{%s}{\\saildoc%s{" command (refcode_cat_id cat id) + (category_name_simple cat) + ^^ docstring l ^^ string "}{" + ^^ ksprintf string "\\lstinputlisting[language=sail]{%s}}}}" (Filename.concat !opt_directory code_file) + end let latex_funcls def = - let module StringMap = Map.Make(String) in + let module StringMap = Map.Make (String) in let counter = ref 0 in let app_codes = ref StringMap.empty in let ctors = ref Bindings.empty in @@ -425,21 +427,23 @@ let latex_funcls def = let funcl_command (FCL_funcl (id, pexp)) = match pexp with | Pat_aux (Pat_exp (P_aux (P_app (ctor, _), _), _), _) -> - let n = try Bindings.find ctor !ctors with Not_found -> -1 in - ctors := Bindings.add ctor (n + 1) !ctors; - FunclCtor (ctor, n), id + let n = try Bindings.find ctor !ctors with Not_found -> -1 in + ctors := Bindings.add ctor (n + 1) !ctors; + (FunclCtor (ctor, n), id) | Pat_aux (Pat_exp (_, exp), _) -> - let ac = app_code exp in - let n = try StringMap.find ac !app_codes with Not_found -> -1 in - app_codes := StringMap.add ac (n + 1) !app_codes; - FunclApp (ac ^ unique_postfix n), id - | _ -> incr counter; (FunclNum (!counter + 64), id) + let ac = app_code exp in + let n = try StringMap.find ac !app_codes with Not_found -> -1 in + app_codes := StringMap.add ac (n + 1) !app_codes; + (FunclApp (ac ^ unique_postfix n), id) + | _ -> + incr counter; + (FunclNum (!counter + 64), id) in function | (FCL_aux (funcl_aux, annot) as funcl) :: funcls -> - let cat, id = funcl_command funcl_aux in - let first = latex_command cat id (Pretty_print_sail.doc_funcl funcl) (fst annot).loc in - first ^^ next funcls + let cat, id = funcl_command funcl_aux in + let first = latex_command cat id (Pretty_print_sail.doc_funcl funcl) (fst annot).loc in + first ^^ next funcls | [] -> empty in latex_funcls' def @@ -451,22 +455,19 @@ let process_pragma l command = match cmd with | "noindent" -> - state.noindent <- true; - None - + state.noindent <- true; + None | "noref" -> - state.norefs <- StringSet.add arg state.norefs; - None - + state.norefs <- StringSet.add arg state.norefs; + None | "newcommand" -> - let n = try String.index arg ' ' with Not_found -> failwith "No command given" in - let name = Str.string_before arg n in - let body = String.trim (latex_of_markdown (Str.string_after arg n)) in - Some (ksprintf string "\\newcommand{\\%s}{%s}" name body) - + let n = try String.index arg ' ' with Not_found -> failwith "No command given" in + let name = Str.string_before arg n in + let body = String.trim (latex_of_markdown (Str.string_after arg n)) in + Some (ksprintf string "\\newcommand{\\%s}{%s}" name body) | _ -> - Reporting.warn "Bad latex pragma at" l ""; - None + Reporting.warn "Bad latex pragma at" l ""; + None let tdef_id = function | TD_abbrev (id, _, _) -> id @@ -478,16 +479,16 @@ let tdef_id = function let defs { defs; _ } = reset_state state; - let preamble = string ("\\providecommand\\saildoclabelled[2]{\\phantomsection\\label{#1}#2}\n" ^ - "\\providecommand\\saildocval[2]{#1 #2}\n" ^ - "\\providecommand\\saildocoutcome[2]{#1 #2}\n" ^ - "\\providecommand\\saildocfcl[2]{#1 #2}\n" ^ - "\\providecommand\\saildoctype[2]{#1 #2}\n" ^ - "\\providecommand\\saildocfn[2]{#1 #2}\n" ^ - "\\providecommand\\saildocoverload[2]{#1 #2}\n" ^ - "\\providecommand\\saildocabbrev[1]{#1\\@}\n" ^ - "\\providecommand\\saildoclet[2]{#1 #2}\n" ^ - "\\providecommand\\saildocregister[2]{#1 #2}\n\n") in + let preamble = + string + ("\\providecommand\\saildoclabelled[2]{\\phantomsection\\label{#1}#2}\n" + ^ "\\providecommand\\saildocval[2]{#1 #2}\n" ^ "\\providecommand\\saildocoutcome[2]{#1 #2}\n" + ^ "\\providecommand\\saildocfcl[2]{#1 #2}\n" ^ "\\providecommand\\saildoctype[2]{#1 #2}\n" + ^ "\\providecommand\\saildocfn[2]{#1 #2}\n" ^ "\\providecommand\\saildocoverload[2]{#1 #2}\n" + ^ "\\providecommand\\saildocabbrev[1]{#1\\@}\n" ^ "\\providecommand\\saildoclet[2]{#1 #2}\n" + ^ "\\providecommand\\saildocregister[2]{#1 #2}\n\n" + ) + in let overload_counters = ref Bindings.empty in @@ -503,64 +504,49 @@ let defs { defs; _ } = let latex_def (DEF_aux (aux, _) as def) = match aux with | DEF_overload (id, ids) -> - let doc = - string (Printf.sprintf "overload %s = {%s}" (string_of_id id) (Util.string_of_list ", " string_of_id ids)) - in - overload_counters := Bindings.update id (function None -> Some 0 | Some n -> Some (n + 1)) !overload_counters; - let count = Bindings.find id !overload_counters in - Some (latex_command (Overload count) id doc (id_loc id)) - + let doc = + string (Printf.sprintf "overload %s = {%s}" (string_of_id id) (Util.string_of_list ", " string_of_id ids)) + in + overload_counters := Bindings.update id (function None -> Some 0 | Some n -> Some (n + 1)) !overload_counters; + let count = Bindings.find id !overload_counters in + Some (latex_command (Overload count) id doc (id_loc id)) | DEF_val (VS_aux (VS_val_spec (_, id, _, _), annot) as vs) -> - valspecs := Bindings.add id id !valspecs; - if !opt_simple_val then - Some (latex_command Val id (doc_spec_simple vs) (fst annot)) - else - Some (latex_command Val id (Pretty_print_sail.doc_spec vs) (fst annot)) - + valspecs := Bindings.add id id !valspecs; + if !opt_simple_val then Some (latex_command Val id (doc_spec_simple vs) (fst annot)) + else Some (latex_command Val id (Pretty_print_sail.doc_spec vs) (fst annot)) | DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, _), _)]), annot)) -> - fundefs := Bindings.add id id !fundefs; - Some (latex_command Function id (Pretty_print_sail.doc_def def) (fst annot)) - + fundefs := Bindings.add id id !fundefs; + Some (latex_command Function id (Pretty_print_sail.doc_def def) (fst annot)) | DEF_let (LB_aux (LB_val (pat, _), annot)) -> - let ids = pat_ids pat in - begin match IdSet.min_elt_opt ids with - | None -> None - | Some base_id -> - letdefs := IdSet.fold (fun id -> Bindings.add id base_id) ids !letdefs; - Some (latex_command Let base_id (Pretty_print_sail.doc_def def) (fst annot)) - end - + let ids = pat_ids pat in + begin + match IdSet.min_elt_opt ids with + | None -> None + | Some base_id -> + letdefs := IdSet.fold (fun id -> Bindings.add id base_id) ids !letdefs; + Some (latex_command Let base_id (Pretty_print_sail.doc_def def) (fst annot)) + end | DEF_type (TD_aux (tdef, annot)) -> - let id = tdef_id tdef in - typedefs := Bindings.add id id !typedefs; - Some (latex_command Type id (Pretty_print_sail.doc_def def) (fst annot)) - - | DEF_fundef (FD_aux (FD_function (_, _, funcls), annot)) as def -> - Some (latex_funcls def funcls) - - | DEF_pragma ("latex", command, l) -> - process_pragma l command - + let id = tdef_id tdef in + typedefs := Bindings.add id id !typedefs; + Some (latex_command Type id (Pretty_print_sail.doc_def def) (fst annot)) + | DEF_fundef (FD_aux (FD_function (_, _, funcls), annot)) as def -> Some (latex_funcls def funcls) + | DEF_pragma ("latex", command, l) -> process_pragma l command | DEF_register (DEC_aux (_, annot) as dec) -> - let id = id_of_dec_spec dec in - regdefs := Bindings.add id id !regdefs; - Some (latex_command Register id (Pretty_print_sail.doc_def def) (fst annot)) - + let id = id_of_dec_spec dec in + regdefs := Bindings.add id id !regdefs; + Some (latex_command Register id (Pretty_print_sail.doc_def def) (fst annot)) | DEF_outcome (OV_aux (OV_outcome (id, _, _), l), _) -> - outcomedefs := Bindings.add id id !outcomedefs; - Some (latex_command Outcome id (Pretty_print_sail.doc_def def) l) - + outcomedefs := Bindings.add id id !outcomedefs; + Some (latex_command Outcome id (Pretty_print_sail.doc_def def) l) | _ -> None in let rec process_defs = function | [] -> empty | def :: defs -> - let tex = match latex_def def with - | Some tex -> tex ^^ twice hardline - | None -> empty - in - tex ^^ process_defs defs + let tex = match latex_def def with Some tex -> tex ^^ twice hardline | None -> empty in + tex ^^ process_defs defs in let tex = process_defs defs in @@ -573,39 +559,48 @@ let defs { defs; _ } = (* Accept both the plain identifier and an escaped one that can be used in macros that might also typeset the argument. *) let add_encoded_ids ids = - List.concat (List.map (fun (id,base) -> - let s = Str.global_replace (Str.regexp_string "#") "\\#" (string_of_id id) in - let s' = text_code s in - if String.compare s s' == 0 then [(s,base)] else [(s,base); (s',base)]) ids) + List.concat + (List.map + (fun (id, base) -> + let s = Str.global_replace (Str.regexp_string "#") "\\#" (string_of_id id) in + let s' = text_code s in + if String.compare s s' == 0 then [(s, base)] else [(s, base); (s', base)] + ) + ids + ) in let id_command cat ids = sprintf "\\newcommand{\\%s%s}[1]{\n " !opt_prefix (category_name cat) - ^ Util.string_of_list "%\n " (fun (s, id) -> sprintf "\\ifstrequal{#1}{%s}{\\%s}{}" s (latex_cat_id cat id)) - (add_encoded_ids (Bindings.bindings ids)) + ^ Util.string_of_list "%\n " + (fun (s, id) -> sprintf "\\ifstrequal{#1}{%s}{\\%s}{}" s (latex_cat_id cat id)) + (add_encoded_ids (Bindings.bindings ids)) ^ "}" |> string in let ref_command cat ids = sprintf "\\newcommand{\\%sref%s}[2]{\n " !opt_prefix (category_name cat) - ^ Util.string_of_list "%\n " (fun (s, id) -> sprintf "\\ifstrequal{#1}{%s}{\\hyperref[%s]{#2}}{}" s (refcode_cat_id cat id)) - (add_encoded_ids (Bindings.bindings ids)) + ^ Util.string_of_list "%\n " + (fun (s, id) -> sprintf "\\ifstrequal{#1}{%s}{\\hyperref[%s]{#2}}{}" s (refcode_cat_id cat id)) + (add_encoded_ids (Bindings.bindings ids)) ^ "}" |> string in - preamble - ^^ tex - ^^ separate (twice hardline) [id_command Val !valspecs; - ref_command Val !valspecs; - id_command Function !fundefs; - ref_command Function !fundefs; - id_command Type !typedefs; - ref_command Type !typedefs; - id_command Let !letdefs; - ref_command Let !letdefs; - id_command Register !regdefs; - ref_command Register !regdefs; - id_command Outcome !outcomedefs; - ref_command Outcome !outcomedefs;] + preamble ^^ tex + ^^ separate (twice hardline) + [ + id_command Val !valspecs; + ref_command Val !valspecs; + id_command Function !fundefs; + ref_command Function !fundefs; + id_command Type !typedefs; + ref_command Type !typedefs; + id_command Let !letdefs; + ref_command Let !letdefs; + id_command Register !regdefs; + ref_command Register !regdefs; + id_command Outcome !outcomedefs; + ref_command Outcome !outcomedefs; + ] ^^ hardline diff --git a/src/sail_latex_backend/sail_plugin_latex.ml b/src/sail_latex_backend/sail_plugin_latex.ml index 96a840e6f..34e00f7ff 100644 --- a/src/sail_latex_backend/sail_plugin_latex.ml +++ b/src/sail_latex_backend/sail_plugin_latex.ml @@ -67,22 +67,29 @@ open Libsail -let latex_options = [ - ( "-latex_prefix", - Arg.String (fun prefix -> Latex.opt_prefix := prefix), - " set a custom prefix for generated LaTeX labels and macro commands (default sail)"); - ( "-latex_full_valspecs", - Arg.Clear Latex.opt_simple_val, - " print full valspecs in LaTeX output"); - ( "-latex_abbrevs", - Arg.String (fun s -> - let abbrevs = String.split_on_char ';' s in - let filtered = List.filter (fun abbrev -> not (String.equal "" abbrev)) abbrevs in - match List.find_opt (fun abbrev -> not (String.equal "." (String.sub abbrev (String.length abbrev - 1) 1))) filtered with - | None -> Latex.opt_abbrevs := filtered - | Some abbrev -> raise (Arg.Bad (abbrev ^ " does not end in a '.'"))), - " semicolon-separated list of abbreviations to fix spacing for in LaTeX output (default 'e.g.;i.e.')"); -] +let latex_options = + [ + ( "-latex_prefix", + Arg.String (fun prefix -> Latex.opt_prefix := prefix), + " set a custom prefix for generated LaTeX labels and macro commands (default sail)" + ); + ("-latex_full_valspecs", Arg.Clear Latex.opt_simple_val, " print full valspecs in LaTeX output"); + ( "-latex_abbrevs", + Arg.String + (fun s -> + let abbrevs = String.split_on_char ';' s in + let filtered = List.filter (fun abbrev -> not (String.equal "" abbrev)) abbrevs in + match + List.find_opt + (fun abbrev -> not (String.equal "." (String.sub abbrev (String.length abbrev - 1) 1))) + filtered + with + | None -> Latex.opt_abbrevs := filtered + | Some abbrev -> raise (Arg.Bad (abbrev ^ " does not end in a '.'")) + ), + " semicolon-separated list of abbreviations to fix spacing for in LaTeX output (default 'e.g.;i.e.')" + ); + ] let latex_target _ out_file ast effect_info env = Reporting.opt_warnings := true; @@ -90,20 +97,18 @@ let latex_target _ out_file ast effect_info env = begin try if not (Sys.is_directory latex_dir) then begin - prerr_endline ("Failure: latex output location exists and is not a directory: " ^ latex_dir); - exit 1 - end - with Sys_error(_) -> Unix.mkdir latex_dir 0o755 + prerr_endline ("Failure: latex output location exists and is not a directory: " ^ latex_dir); + exit 1 + end + with Sys_error _ -> Unix.mkdir latex_dir 0o755 end; Latex.opt_directory := latex_dir; let chan = open_out (Filename.concat latex_dir "commands.tex") in output_string chan (Pretty_print_sail.to_string (Latex.defs (Type_check.strip_ast ast))); close_out chan - + let _ = - Target.register - ~name:"latex" - ~options:latex_options + Target.register ~name:"latex" ~options:latex_options ~pre_parse_hook:(fun () -> Type_check.opt_expand_valspec := false; Type_check.opt_no_bitfield_expansion := true diff --git a/src/sail_lem_backend/dune b/src/sail_lem_backend/dune index 20ef5d730..d06d46815 100644 --- a/src/sail_lem_backend/dune +++ b/src/sail_lem_backend/dune @@ -1,15 +1,20 @@ (env - (dev - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) - (release - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) + (dev + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) + (release + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) (executable - (name sail_plugin_lem) - (modes (native plugin)) - (libraries libsail)) + (name sail_plugin_lem) + (modes + (native plugin)) + (libraries libsail)) (install - (section (site (libsail plugins))) - (package sail_lem_backend) - (files sail_plugin_lem.cmxs)) + (section + (site + (libsail plugins))) + (package sail_lem_backend) + (files sail_plugin_lem.cmxs)) diff --git a/src/sail_lem_backend/pretty_print_lem.ml b/src/sail_lem_backend/pretty_print_lem.ml index e39892421..985d241b6 100644 --- a/src/sail_lem_backend/pretty_print_lem.ml +++ b/src/sail_lem_backend/pretty_print_lem.ml @@ -83,13 +83,14 @@ open Pretty_print_common let opt_sequential = ref false type context = { - early_ret : bool; - monadic : bool; - bound_nexps : NexpSet.t; - top_env : Env.t; - params_to_print : Util.IntSet.t Bindings.t; + early_ret : bool; + monadic : bool; + bound_nexps : NexpSet.t; + top_env : Env.t; + params_to_print : Util.IntSet.t Bindings.t; } -let empty_ctxt = { +let empty_ctxt = + { early_ret = false; monadic = false; bound_nexps = NexpSet.empty; @@ -102,114 +103,93 @@ let langlebar = string "<|" let ranglebar = string "|>" let anglebars = enclose langlebar ranglebar -let doc_var (Kid_aux(Var v,_)) = string v +let doc_var (Kid_aux (Var v, _)) = string v let is_number_char c = - c = '0' || c = '1' || c = '2' || c = '3' || c = '4' || c = '5' || - c = '6' || c = '7' || c = '8' || c = '9' - -let rec fix_id remove_tick name = match name with - | "assert" - | "lsl" - | "lsr" - | "asr" - | "type" - | "fun" - | "function" - | "raise" - | "try" - | "match" - | "with" - | "check" - | "field" - | "LT" | "lt" | "lteq" - | "GT" | "gt" | "gteq" - | "EQ" | "eq" | "neq" - | "integer" - -> name ^ "'" - | _ -> - if String.contains name '#' then - fix_id remove_tick (String.concat "_" (Util.split_on_char '#' name)) - else if String.contains name '?' then - fix_id remove_tick (String.concat "_pat_" (Util.split_on_char '?' name)) - else if name.[0] = '\'' then - let var = String.sub name 1 (String.length name - 1) in - if remove_tick then var else (var ^ "'") - else if is_number_char(name.[0]) then - ("v" ^ name ^ "'") - else name - -let doc_id_lem (Id_aux(i,_)) = - match i with - | Id i -> string (fix_id false i) - | Operator x -> string (Util.zencode_string ("op " ^ x)) + c = '0' || c = '1' || c = '2' || c = '3' || c = '4' || c = '5' || c = '6' || c = '7' || c = '8' || c = '9' -let doc_id_lem_type (Id_aux(i,_)) = +let rec fix_id remove_tick name = + match name with + | "assert" | "lsl" | "lsr" | "asr" | "type" | "fun" | "function" | "raise" | "try" | "match" | "with" | "check" + | "field" | "LT" | "lt" | "lteq" | "GT" | "gt" | "gteq" | "EQ" | "eq" | "neq" | "integer" -> + name ^ "'" + | _ -> + if String.contains name '#' then fix_id remove_tick (String.concat "_" (Util.split_on_char '#' name)) + else if String.contains name '?' then fix_id remove_tick (String.concat "_pat_" (Util.split_on_char '?' name)) + else if name.[0] = '\'' then ( + let var = String.sub name 1 (String.length name - 1) in + if remove_tick then var else var ^ "'" + ) + else if is_number_char name.[0] then "v" ^ name ^ "'" + else name + +let doc_id_lem (Id_aux (i, _)) = + match i with Id i -> string (fix_id false i) | Operator x -> string (Util.zencode_string ("op " ^ x)) + +let doc_id_lem_type (Id_aux (i, _)) = match i with - | Id("int") -> string "ii" - | Id("nat") -> string "ii" - | Id("option") -> string "maybe" + | Id "int" -> string "ii" + | Id "nat" -> string "ii" + | Id "option" -> string "maybe" | Id i -> string (fix_id false i) | Operator x -> string (Util.zencode_string ("op " ^ x)) -let doc_id_lem_ctor (Id_aux(i,_)) = +let doc_id_lem_ctor (Id_aux (i, _)) = match i with - | Id("bit") -> string "bitU" - | Id("int") -> string "integer" - | Id("nat") -> string "integer" - | Id("Some") -> string "Just" - | Id("None") -> string "Nothing" + | Id "bit" -> string "bitU" + | Id "int" -> string "integer" + | Id "nat" -> string "integer" + | Id "Some" -> string "Just" + | Id "None" -> string "Nothing" | Id i -> string (fix_id false (String.capitalize_ascii i)) | Operator x -> string (Util.zencode_string ("op " ^ x)) -let deinfix = function - | Id_aux (Id v, l) -> Id_aux (Operator v, l) - | Id_aux (Operator v, l) -> Id_aux (Operator v, l) +let deinfix = function Id_aux (Id v, l) -> Id_aux (Operator v, l) | Id_aux (Operator v, l) -> Id_aux (Operator v, l) let doc_var_lem kid = string (fix_id true (string_of_kid kid)) let simple_annot l typ = (Parse_ast.Generated l, Some (Env.empty, typ, no_effect)) -let simple_num l n = E_aux ( - E_lit (L_aux (L_num n, Parse_ast.Generated l)), - simple_annot (Parse_ast.Generated l) - (atom_typ (Nexp_aux (Nexp_constant n, Parse_ast.Generated l)))) +let simple_num l n = + E_aux + ( E_lit (L_aux (L_num n, Parse_ast.Generated l)), + simple_annot (Parse_ast.Generated l) (atom_typ (Nexp_aux (Nexp_constant n, Parse_ast.Generated l))) + ) -let is_regtyp (Typ_aux (typ, _)) env = match typ with - | Typ_app(id, _) when string_of_id id = "register" -> true - | _ -> false +let is_regtyp (Typ_aux (typ, _)) env = + match typ with Typ_app (id, _) when string_of_id id = "register" -> true | _ -> false let lemnum default n = - if Big_int.less_equal Big_int.zero n && Big_int.less_equal n (Big_int.of_int 128) then - "int" ^ Big_int.to_string n - else if Big_int.greater_equal n Big_int.zero then - default n - else ("(int0 - " ^ (default (Big_int.abs n)) ^ ")") + if Big_int.less_equal Big_int.zero n && Big_int.less_equal n (Big_int.of_int 128) then "int" ^ Big_int.to_string n + else if Big_int.greater_equal n Big_int.zero then default n + else "(int0 - " ^ default (Big_int.abs n) ^ ")" let doc_nexp_lem nexp = let nice_kid kid = - let (Kid_aux (Var kid,l)) = orig_kid kid in - Kid_aux (Var (String.map (function '#' -> '_' | c -> c) kid),l) + let (Kid_aux (Var kid, l)) = orig_kid kid in + Kid_aux (Var (String.map (function '#' -> '_' | c -> c) kid), l) in let (Nexp_aux (nexp, l) as full_nexp) = nexp_simp nexp in match nexp with | Nexp_constant i -> string ("ty" ^ Big_int.to_string i) | Nexp_var v -> string (string_of_kid (nice_kid v)) | _ -> - let rec mangle_nexp (Nexp_aux (nexp, _)) = begin - match nexp with - | Nexp_id id -> string_of_id id - | Nexp_var kid -> string_of_id (id_of_kid (nice_kid kid)) - | Nexp_constant i -> lemnum Big_int.to_string i - | Nexp_times (n1, n2) -> mangle_nexp n1 ^ "_times_" ^ mangle_nexp n2 - | Nexp_sum (n1, n2) -> mangle_nexp n1 ^ "_plus_" ^ mangle_nexp n2 - | Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2 - | Nexp_exp n -> "exp_" ^ mangle_nexp n - | Nexp_neg n -> "neg_" ^ mangle_nexp n - | _ -> - raise (Reporting.err_unreachable l __POS__ - ("cannot pretty-print nexp \"" ^ string_of_nexp full_nexp ^ "\"")) - end in - string ("'" ^ mangle_nexp full_nexp) + let rec mangle_nexp (Nexp_aux (nexp, _)) = + begin + match nexp with + | Nexp_id id -> string_of_id id + | Nexp_var kid -> string_of_id (id_of_kid (nice_kid kid)) + | Nexp_constant i -> lemnum Big_int.to_string i + | Nexp_times (n1, n2) -> mangle_nexp n1 ^ "_times_" ^ mangle_nexp n2 + | Nexp_sum (n1, n2) -> mangle_nexp n1 ^ "_plus_" ^ mangle_nexp n2 + | Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2 + | Nexp_exp n -> "exp_" ^ mangle_nexp n + | Nexp_neg n -> "neg_" ^ mangle_nexp n + | _ -> + raise + (Reporting.err_unreachable l __POS__ ("cannot pretty-print nexp \"" ^ string_of_nexp full_nexp ^ "\"")) + end + in + string ("'" ^ mangle_nexp full_nexp) (* Rewrite mangled names of type variables to the original names *) let rec orig_nexp (Nexp_aux (nexp, l)) = @@ -229,19 +209,21 @@ let type_parameters_to_print env defs : Util.IntSet.t Bindings.t = let make_type_size_map env id typq typs type_size_map = let type_params, _ = List.fold_left - (fun (is,i) q -> + (fun (is, i) q -> match q with - | KOpt_aux (KOpt_kind (K_aux (K_type, _),_), _) -> - (Util.IntSet.add i is, i+1) - | _ -> (is,i+1)) (Util.IntSet.empty, 0) (quant_kopts typq) + | KOpt_aux (KOpt_kind (K_aux (K_type, _), _), _) -> (Util.IntSet.add i is, i + 1) + | _ -> (is, i + 1) + ) + (Util.IntSet.empty, 0) (quant_kopts typq) in let local_ints, _ = List.fold_left - (fun (map,i) q -> + (fun (map, i) q -> match q with - | KOpt_aux (KOpt_kind (K_aux (K_int, _),kid), _) -> - (KBindings.add kid i map, i+1) - | _ -> (map,i+1)) (KBindings.empty, 0) (quant_kopts typq) + | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> (KBindings.add kid i map, i + 1) + | _ -> (map, i + 1) + ) + (KBindings.empty, 0) (quant_kopts typq) in let rec check_typ is typ = match Env.expand_synonyms env typ with @@ -249,22 +231,23 @@ let type_parameters_to_print env defs : Util.IntSet.t Bindings.t = match Bindings.find_opt id type_size_map with | None -> is | Some js -> - let is' = - Util.IntSet.fold (fun j is -> - match List.nth_opt args j with - | Some (A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)) -> - (match KBindings.find_opt kid local_ints with - | Some i -> Util.IntSet.add i is - | None -> is) - | _ -> is) js is - in - List.fold_left (fun is (A_aux (arg, _)) -> - match arg with - | A_typ typ -> check_typ is typ - | _ -> is) is' args + let is' = + Util.IntSet.fold + (fun j is -> + match List.nth_opt args j with + | Some (A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)) -> ( + match KBindings.find_opt kid local_ints with Some i -> Util.IntSet.add i is | None -> is + ) + | _ -> is + ) + js is + in + List.fold_left + (fun is (A_aux (arg, _)) -> match arg with A_typ typ -> check_typ is typ | _ -> is) + is' args end | Typ_aux (Typ_tuple typs, _) -> List.fold_left check_typ is typs - | Typ_aux (Typ_exist (_,_,typ), _) -> check_typ is typ + | Typ_aux (Typ_exist (_, _, typ), _) -> check_typ is typ | _ -> is in let is = List.fold_left (fun is typ -> check_typ is typ) type_params typs in @@ -273,31 +256,26 @@ let type_parameters_to_print env defs : Util.IntSet.t Bindings.t = let check_def type_size_map (DEF_aux (def, _)) = match def with | DEF_type (TD_aux (TD_record (id, typq, fs, _), _)) -> - let env = Env.add_typquant Unknown typq env in - make_type_size_map env id typq (List.map fst fs) type_size_map + let env = Env.add_typquant Unknown typq env in + make_type_size_map env id typq (List.map fst fs) type_size_map | DEF_type (TD_aux (TD_variant (id, typq, tus, _), _)) -> - let env = Env.add_typquant Unknown typq env in - make_type_size_map env id typq (List.map (fun (Tu_aux (Tu_ty_id (t,_),_)) -> t) tus) type_size_map + let env = Env.add_typquant Unknown typq env in + make_type_size_map env id typq (List.map (fun (Tu_aux (Tu_ty_id (t, _), _)) -> t) tus) type_size_map | DEF_type (TD_aux (TD_abbrev (id, typq, typ_arg), _)) -> begin - let env = Env.add_typquant Unknown typq env in + let env = Env.add_typquant Unknown typq env in match typ_arg with - | A_aux (A_typ typ, _) -> - make_type_size_map env id typq [typ] type_size_map + | A_aux (A_typ typ, _) -> make_type_size_map env id typq [typ] type_size_map | _ -> type_size_map end | _ -> type_size_map in (* Seed parameters to print with builtin types that need a parameter to be printed *) - let bitvector_itself_prints = - if !Monomorphise.opt_mwords - then Util.IntSet.singleton 0 - else Util.IntSet.empty - in + let bitvector_itself_prints = if !Monomorphise.opt_mwords then Util.IntSet.singleton 0 else Util.IntSet.empty in let init_map = - Bindings.empty |> - Bindings.add (mk_id "bitvector") bitvector_itself_prints |> - Bindings.add (mk_id "itself") bitvector_itself_prints + Bindings.empty + |> Bindings.add (mk_id "bitvector") bitvector_itself_prints + |> Bindings.add (mk_id "itself") bitvector_itself_prints in let map = List.fold_left check_def init_map defs in @@ -311,45 +289,34 @@ let type_parameters_to_print env defs : Util.IntSet.t Bindings.t = (* Returns the set of type variables that will appear in the Lem output, which may be smaller than those in the Sail type. May need to be updated with doc_typ_lem *) -let rec lem_nexps_of_typ params_to_print (Typ_aux (t,l)) = +let rec lem_nexps_of_typ params_to_print (Typ_aux (t, l)) = let trec = lem_nexps_of_typ params_to_print in match t with | Typ_id _ -> NexpSet.empty | Typ_var kid -> NexpSet.singleton (orig_nexp (nvar kid)) - | Typ_fn (t1,t2) -> List.fold_left NexpSet.union (trec t2) (List.map trec t1) - | Typ_tuple ts -> - List.fold_left (fun s t -> NexpSet.union s (trec t)) - NexpSet.empty ts - | Typ_app(Id_aux (Id "bitvector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _)]) -> - let m = nexp_simp m in - if !Monomorphise.opt_mwords && not (is_nexp_constant m) then - NexpSet.singleton (orig_nexp m) - else trec bit_typ - | Typ_app(Id_aux (Id "vector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _); - A_aux (A_typ elem_typ, _)]) -> - trec elem_typ - | Typ_app(Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> - trec etyp - | Typ_app(Id_aux (Id "range", _),_) - | Typ_app(Id_aux (Id "implicit", _),_) - | Typ_app(Id_aux (Id "atom", _), _) -> NexpSet.empty - | Typ_app (id,tas) -> begin - match Bindings.find_opt id params_to_print with - | Some is -> - Util.IntSet.fold (fun i s -> NexpSet.union s (lem_nexps_of_typ_arg params_to_print (List.nth tas i))) - is NexpSet.empty - | None -> - List.fold_left (fun s ta -> NexpSet.union s (lem_nexps_of_typ_arg params_to_print ta)) - NexpSet.empty tas + | Typ_fn (t1, t2) -> List.fold_left NexpSet.union (trec t2) (List.map trec t1) + | Typ_tuple ts -> List.fold_left (fun s t -> NexpSet.union s (trec t)) NexpSet.empty ts + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _)]) -> + let m = nexp_simp m in + if !Monomorphise.opt_mwords && not (is_nexp_constant m) then NexpSet.singleton (orig_nexp m) else trec bit_typ + | Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) -> + trec elem_typ + | Typ_app (Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> trec etyp + | Typ_app (Id_aux (Id "range", _), _) | Typ_app (Id_aux (Id "implicit", _), _) | Typ_app (Id_aux (Id "atom", _), _) -> + NexpSet.empty + | Typ_app (id, tas) -> begin + match Bindings.find_opt id params_to_print with + | Some is -> + Util.IntSet.fold + (fun i s -> NexpSet.union s (lem_nexps_of_typ_arg params_to_print (List.nth tas i))) + is NexpSet.empty + | None -> List.fold_left (fun s ta -> NexpSet.union s (lem_nexps_of_typ_arg params_to_print ta)) NexpSet.empty tas end - | Typ_exist (kids,_,t) -> trec t + | Typ_exist (kids, _, t) -> trec t | Typ_bidir _ -> raise (Reporting.err_unreachable l __POS__ "Lem doesn't support bidir types") | Typ_internal_unknown -> raise (Reporting.err_unreachable l __POS__ "escaped Typ_internal_unknown") -and lem_nexps_of_typ_arg params_to_print (A_aux (ta,_)) = + +and lem_nexps_of_typ_arg params_to_print (A_aux (ta, _)) = match ta with | A_nexp nexp -> let nexp = nexp_simp (orig_nexp nexp) in @@ -359,102 +326,103 @@ and lem_nexps_of_typ_arg params_to_print (A_aux (ta,_)) = | A_bool _ -> NexpSet.empty let lem_tyvars_of_typ params_to_print typ = - NexpSet.fold (fun nexp ks -> KidSet.union ks (tyvars_of_nexp nexp)) + NexpSet.fold + (fun nexp ks -> KidSet.union ks (tyvars_of_nexp nexp)) (lem_nexps_of_typ params_to_print typ) KidSet.empty (* When making changes here, check whether they affect lem_tyvars_of_typ *) let doc_typ_lem, doc_typ_lem_brackets, doc_atomic_typ_lem = (* following the structure of parser for precedence *) let rec typ params_to_print atyp_needed ty = tup_typ params_to_print atyp_needed ty - and tup_typ params_to_print atyp_needed (Typ_aux (t, l) as ty) = match t with - | Typ_fn(args,ret) -> - let ret_typ = - (* TODO EFFECT: Monadicity as parameter or separate function. See Coq *) - (* + and tup_typ params_to_print atyp_needed (Typ_aux (t, l) as ty) = + match t with + | Typ_fn (args, ret) -> + let ret_typ = + (* TODO EFFECT: Monadicity as parameter or separate function. See Coq *) + (* if effectful efct then separate space [string "M"; tup_typ true ret] - else *) separate space [tup_typ params_to_print false ret] in - let arg_typs = List.map (tup_typ params_to_print false) args in - let tpp = separate (space ^^ arrow ^^ space) (arg_typs @ [ret_typ]) in - (* once we have proper excetions we need to know what the exceptions type is *) - if atyp_needed then parens tpp else tpp - | Typ_tuple typs -> - parens (separate_map (space ^^ star ^^ space) (app_typ params_to_print false) typs) - | _ -> app_typ params_to_print atyp_needed ty - and app_typ params_to_print atyp_needed ((Typ_aux (t, l)) as ty) = match t with - | Typ_app(Id_aux (Id "vector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _); - A_aux (A_typ elem_typ, _)]) -> - let tpp = string "list" ^^ space ^^ typ params_to_print true elem_typ in - if atyp_needed then parens tpp else tpp - | Typ_app(Id_aux (Id "bitvector", _), [ - A_aux (A_nexp m, _); - A_aux (A_order ord, _)]) -> - let tpp = - if !Monomorphise.opt_mwords then - string "mword " ^^ doc_nexp_lem (nexp_simp m) - else - string "list" ^^ space ^^ typ params_to_print true bit_typ - in - if atyp_needed then parens tpp else tpp - | Typ_app(Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> - let tpp = string "register_ref regstate register_value " ^^ typ params_to_print true etyp in - if atyp_needed then parens tpp else tpp - | Typ_app(Id_aux (Id "range", _),_) -> - (string "integer") - | Typ_app(Id_aux (Id "implicit", _),_) -> - (string "integer") - | Typ_app(Id_aux (Id "atom", _), [A_aux(A_nexp n,_)]) -> - (string "integer") - | Typ_app(Id_aux (Id "atom_bool", _), [A_aux(A_bool nc,_)]) -> - (string "bool") - | Typ_app(id,args) -> - let args = - match Bindings.find_opt id params_to_print with - | None -> args - | Some is -> - let args,_ = List.fold_left (fun (l,i) a -> if Util.IntSet.mem i is then (a::l,i+1) else (l,i+1)) ([],0) args in + else *) + separate space [tup_typ params_to_print false ret] + in + let arg_typs = List.map (tup_typ params_to_print false) args in + let tpp = separate (space ^^ arrow ^^ space) (arg_typs @ [ret_typ]) in + (* once we have proper excetions we need to know what the exceptions type is *) + if atyp_needed then parens tpp else tpp + | Typ_tuple typs -> parens (separate_map (space ^^ star ^^ space) (app_typ params_to_print false) typs) + | _ -> app_typ params_to_print atyp_needed ty + and app_typ params_to_print atyp_needed (Typ_aux (t, l) as ty) = + match t with + | Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) -> + let tpp = string "list" ^^ space ^^ typ params_to_print true elem_typ in + if atyp_needed then parens tpp else tpp + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _); A_aux (A_order ord, _)]) -> + let tpp = + if !Monomorphise.opt_mwords then string "mword " ^^ doc_nexp_lem (nexp_simp m) + else string "list" ^^ space ^^ typ params_to_print true bit_typ + in + if atyp_needed then parens tpp else tpp + | Typ_app (Id_aux (Id "register", _), [A_aux (A_typ etyp, _)]) -> + let tpp = string "register_ref regstate register_value " ^^ typ params_to_print true etyp in + if atyp_needed then parens tpp else tpp + | Typ_app (Id_aux (Id "range", _), _) -> string "integer" + | Typ_app (Id_aux (Id "implicit", _), _) -> string "integer" + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp n, _)]) -> string "integer" + | Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]) -> string "bool" + | Typ_app (id, args) -> + let args = + match Bindings.find_opt id params_to_print with + | None -> args + | Some is -> + let args, _ = + List.fold_left + (fun (l, i) a -> if Util.IntSet.mem i is then (a :: l, i + 1) else (l, i + 1)) + ([], 0) args + in List.rev args - in - let tpp = (doc_id_lem_type id) ^^ space ^^ (separate_map space (doc_typ_arg_lem params_to_print) args) in - if atyp_needed then parens tpp else tpp - | _ -> atomic_typ params_to_print atyp_needed ty - and atomic_typ params_to_print atyp_needed ((Typ_aux (t, l)) as ty) = match t with - | Typ_id (Id_aux (Id "bool",_)) -> string "bool" - | Typ_id (Id_aux (Id "bit",_)) -> string "bitU" - | Typ_id (id) -> - (*if List.exists ((=) (string_of_id id)) regtypes - then string "register" - else*) doc_id_lem_type id - | Typ_var v -> doc_var v - | Typ_app _ | Typ_tuple _ | Typ_fn _ -> - (* exhaustiveness matters here to avoid infinite loops - * if we add a new Typ constructor *) - let tpp = typ params_to_print true ty in - if atyp_needed then parens tpp else tpp - | Typ_exist (kopts,_,ty) when List.for_all is_int_kopt kopts -> begin - let kids = List.map kopt_kid kopts in - let tpp = typ params_to_print true ty in - let visible_vars = lem_tyvars_of_typ params_to_print ty in - match List.filter (fun kid -> KidSet.mem kid visible_vars) kids with - | [] -> if atyp_needed then parens tpp else tpp - | bad -> raise (Reporting.err_general l - ("Existential type variable(s) " ^ - String.concat ", " (List.map string_of_kid bad) ^ - " escape into Lem")) - end - (* AA: I think the correct thing is likely to filter out - non-integer kinded_id's, then use the above code. *) - | Typ_exist (_,_,Typ_aux(Typ_app(id,[_]),_)) when string_of_id id = "atom_bool" -> string "bool" - | Typ_exist _ -> unreachable l __POS__ "Non-integer existentials currently unsupported in Lem" (* TODO *) - | Typ_bidir _ -> unreachable l __POS__ "Lem doesn't support bidir types" - | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" - and doc_typ_arg_lem params_to_print (A_aux(t,_)) = match t with - | A_typ t -> app_typ params_to_print true t - | A_nexp n -> doc_nexp_lem (nexp_simp n) - | A_order o -> empty - | A_bool _ -> empty + in + let tpp = doc_id_lem_type id ^^ space ^^ separate_map space (doc_typ_arg_lem params_to_print) args in + if atyp_needed then parens tpp else tpp + | _ -> atomic_typ params_to_print atyp_needed ty + and atomic_typ params_to_print atyp_needed (Typ_aux (t, l) as ty) = + match t with + | Typ_id (Id_aux (Id "bool", _)) -> string "bool" + | Typ_id (Id_aux (Id "bit", _)) -> string "bitU" + | Typ_id id -> + (*if List.exists ((=) (string_of_id id)) regtypes + then string "register" + else*) + doc_id_lem_type id + | Typ_var v -> doc_var v + | Typ_app _ | Typ_tuple _ | Typ_fn _ -> + (* exhaustiveness matters here to avoid infinite loops + * if we add a new Typ constructor *) + let tpp = typ params_to_print true ty in + if atyp_needed then parens tpp else tpp + | Typ_exist (kopts, _, ty) when List.for_all is_int_kopt kopts -> begin + let kids = List.map kopt_kid kopts in + let tpp = typ params_to_print true ty in + let visible_vars = lem_tyvars_of_typ params_to_print ty in + match List.filter (fun kid -> KidSet.mem kid visible_vars) kids with + | [] -> if atyp_needed then parens tpp else tpp + | bad -> + raise + (Reporting.err_general l + ("Existential type variable(s) " ^ String.concat ", " (List.map string_of_kid bad) ^ " escape into Lem") + ) + end + (* AA: I think the correct thing is likely to filter out + non-integer kinded_id's, then use the above code. *) + | Typ_exist (_, _, Typ_aux (Typ_app (id, [_]), _)) when string_of_id id = "atom_bool" -> string "bool" + | Typ_exist _ -> unreachable l __POS__ "Non-integer existentials currently unsupported in Lem" (* TODO *) + | Typ_bidir _ -> unreachable l __POS__ "Lem doesn't support bidir types" + | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" + and doc_typ_arg_lem params_to_print (A_aux (t, _)) = + match t with + | A_typ t -> app_typ params_to_print true t + | A_nexp n -> doc_nexp_lem (nexp_simp n) + | A_order o -> empty + | A_bool _ -> empty in let top atyp_needed params_to_print env ty = (* If we use the bitlist representation of bitvectors, the type argument in @@ -462,769 +430,779 @@ let doc_typ_lem, doc_typ_lem_brackets, doc_atomic_typ_lem = workaround, we expand type synonyms in this case. *) let ty' = if !Monomorphise.opt_mwords then ty else Env.expand_synonyms env ty in typ params_to_print atyp_needed ty' - in top false, top true, atomic_typ + in + (top false, top true, atomic_typ) -let doc_fn_typ_lem ?(monad = empty) params_to_print env (Typ_aux (aux, l) as ty) = match aux with +let doc_fn_typ_lem ?(monad = empty) params_to_print env (Typ_aux (aux, l) as ty) = + match aux with | Typ_fn (args, ret) -> - separate (space ^^ arrow ^^ space) (List.map (doc_typ_lem params_to_print env) args @ [monad ^^ (doc_typ_lem_brackets params_to_print env) ret]) - | _ -> - doc_typ_lem params_to_print env ty - + separate + (space ^^ arrow ^^ space) + (List.map (doc_typ_lem params_to_print env) args @ [monad ^^ (doc_typ_lem_brackets params_to_print env) ret]) + | _ -> doc_typ_lem params_to_print env ty + (* Check for variables in types that would be pretty-printed. *) -let contains_t_pp_var ctxt (Typ_aux (t,a) as typ) = - lem_nexps_of_typ ctxt.params_to_print typ - |> NexpSet.exists (fun nexp -> not (is_nexp_constant nexp)) +let contains_t_pp_var ctxt (Typ_aux (t, a) as typ) = + lem_nexps_of_typ ctxt.params_to_print typ |> NexpSet.exists (fun nexp -> not (is_nexp_constant nexp)) let rec replace_typ_size ctxt env (Typ_aux (t, a) as typ) = let rewrap t = Typ_aux (t, a) in let recur = replace_typ_size ctxt env in match t with - | Typ_tuple typs -> - begin match Util.option_all (List.map recur typs) with - | Some typs' -> Some (rewrap (Typ_tuple typs')) - | None -> None - end - | Typ_app (id, args) when contains_t_pp_var ctxt typ -> - begin match Util.option_all (List.map (replace_typ_arg_size ctxt env) args) with - | Some args' -> Some (rewrap (Typ_app (id, args'))) - | None -> None - end + | Typ_tuple typs -> begin + match Util.option_all (List.map recur typs) with Some typs' -> Some (rewrap (Typ_tuple typs')) | None -> None + end + | Typ_app (id, args) when contains_t_pp_var ctxt typ -> begin + match Util.option_all (List.map (replace_typ_arg_size ctxt env) args) with + | Some args' -> Some (rewrap (Typ_app (id, args'))) + | None -> None + end | Typ_app _ -> Some typ | Typ_id _ -> Some typ - | Typ_fn (argtyps, rtyp) -> - begin match (Util.option_all (List.map recur argtyps), recur rtyp) with - | (Some argtyps', Some rtyp') -> Some (rewrap (Typ_fn (argtyps', rtyp'))) - | _ -> None - end + | Typ_fn (argtyps, rtyp) -> begin + match (Util.option_all (List.map recur argtyps), recur rtyp) with + | Some argtyps', Some rtyp' -> Some (rewrap (Typ_fn (argtyps', rtyp'))) + | _ -> None + end | Typ_var kid -> - let is_kid nexp = Nexp.compare nexp (nvar kid) = 0 in - if NexpSet.exists is_kid ctxt.bound_nexps then Some typ else None - | Typ_exist (kids, nc, typ) -> - begin match recur typ with - | Some typ' -> Some (rewrap (Typ_exist (kids, nc, typ'))) - | None -> None - end - | Typ_internal_unknown - | Typ_bidir (_, _) -> None + let is_kid nexp = Nexp.compare nexp (nvar kid) = 0 in + if NexpSet.exists is_kid ctxt.bound_nexps then Some typ else None + | Typ_exist (kids, nc, typ) -> begin + match recur typ with Some typ' -> Some (rewrap (Typ_exist (kids, nc, typ'))) | None -> None + end + | Typ_internal_unknown | Typ_bidir (_, _) -> None + and replace_typ_arg_size ctxt env (A_aux (ta, a) as targ) = let rewrap ta = A_aux (ta, a) in match ta with - | A_nexp nexp -> - begin match Type_check.solve_unique env nexp with - | Some n -> Some (rewrap (A_nexp (nconstant n))) - | None -> - let is_equal nexp' = - prove __POS__ env (NC_aux (NC_equal (nexp,nexp'),Parse_ast.Unknown)) - in + | A_nexp nexp -> begin + match Type_check.solve_unique env nexp with + | Some n -> Some (rewrap (A_nexp (nconstant n))) + | None -> ( + let is_equal nexp' = prove __POS__ env (NC_aux (NC_equal (nexp, nexp'), Parse_ast.Unknown)) in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with | nexp' -> Some (rewrap (A_nexp nexp')) | exception Not_found -> None - end - | A_typ typ -> - begin match replace_typ_size ctxt env typ with - | Some typ' -> Some (rewrap (A_typ typ')) - | None -> None - end + ) + end + | A_typ typ -> begin + match replace_typ_size ctxt env typ with Some typ' -> Some (rewrap (A_typ typ')) | None -> None + end | A_order _ | A_bool _ -> Some targ let make_printable_type ctxt env typ = - if contains_t_pp_var ctxt typ - then - try replace_typ_size ctxt env (Env.expand_synonyms env typ) with - | _ -> None + if contains_t_pp_var ctxt typ then (try replace_typ_size ctxt env (Env.expand_synonyms env typ) with _ -> None) else Some typ let doc_tannot_lem ctxt env eff typ = match make_printable_type ctxt env typ with | None -> empty | Some typ -> - let ta = doc_typ_lem ctxt.params_to_print env typ in - if eff then string " : M " ^^ parens ta - else string " : " ^^ ta + let ta = doc_typ_lem ctxt.params_to_print env typ in + if eff then string " : M " ^^ parens ta else string " : " ^^ ta let min_int32 = Big_int.of_int64 (Int64.of_int32 Int32.min_int) let max_int32 = Big_int.of_int64 (Int64.of_int32 Int32.max_int) -let rec doc_lit_lem (L_aux(lit,l)) = +let rec doc_lit_lem (L_aux (lit, l)) = match lit with - | L_unit -> utf8string "()" - | L_zero -> utf8string "B0" - | L_one -> utf8string "B1" + | L_unit -> utf8string "()" + | L_zero -> utf8string "B0" + | L_one -> utf8string "B1" | L_false -> utf8string "false" - | L_true -> utf8string "true" + | L_true -> utf8string "true" | L_num i -> - let ipp = Big_int.to_string i in - utf8string ( - if Big_int.less i Big_int.zero then "((0"^ipp^"):ii)" - else "("^ipp^":ii)") + let ipp = Big_int.to_string i in + utf8string (if Big_int.less i Big_int.zero then "((0" ^ ipp ^ "):ii)" else "(" ^ ipp ^ ":ii)") | L_hex n when !Monomorphise.opt_mwords -> utf8string ("0x" ^ n) | L_bin n when !Monomorphise.opt_mwords -> utf8string ("0b" ^ n) | L_hex _ | L_bin _ -> - vector_string_to_bit_list (L_aux(lit,l)) - |> flow_map (semi ^^ break 0) doc_lit_lem - |> group |> align |> brackets - | L_undef -> - utf8string "(return (failwith \"undefined value of unsupported type\"))" - | L_string s -> utf8string ("\"" ^ (String.escaped s) ^ "\"") + vector_string_to_bit_list (L_aux (lit, l)) |> flow_map (semi ^^ break 0) doc_lit_lem |> group |> align |> brackets + | L_undef -> utf8string "(return (failwith \"undefined value of unsupported type\"))" + | L_string s -> utf8string ("\"" ^ String.escaped s ^ "\"") | L_real s -> - (* Lem does not support decimal syntax, so we translate a string - of the form "x.y" into the ratio (x * 10^len(y) + y) / 10^len(y). - The OCaml library has a conversion function from strings to floats, but - not from floats to ratios. ZArith's Q library does have the latter, but - using this would require adding a dependency on ZArith to Sail. *) - let parts = Util.split_on_char '.' s in - let (num, denom) = match parts with - | [i] -> (Big_int.of_string i, Big_int.of_int 1) - | [i;f] -> - let denom = Big_int.pow_int_positive 10 (String.length f) in - (Big_int.add (Big_int.mul (Big_int.of_string i) denom) (Big_int.of_string f), denom) - | _ -> - raise (Reporting.err_syntax_loc l "could not parse real literal") in - parens (separate space (List.map string [ - "realFromFrac"; Big_int.to_string num; Big_int.to_string denom])) + (* Lem does not support decimal syntax, so we translate a string + of the form "x.y" into the ratio (x * 10^len(y) + y) / 10^len(y). + The OCaml library has a conversion function from strings to floats, but + not from floats to ratios. ZArith's Q library does have the latter, but + using this would require adding a dependency on ZArith to Sail. *) + let parts = Util.split_on_char '.' s in + let num, denom = + match parts with + | [i] -> (Big_int.of_string i, Big_int.of_int 1) + | [i; f] -> + let denom = Big_int.pow_int_positive 10 (String.length f) in + (Big_int.add (Big_int.mul (Big_int.of_string i) denom) (Big_int.of_string f), denom) + | _ -> raise (Reporting.err_syntax_loc l "could not parse real literal") + in + parens (separate space (List.map string ["realFromFrac"; Big_int.to_string num; Big_int.to_string denom])) let kid_nexps_of_typquant tq = quant_kopts tq |> List.filter (fun k -> is_int_kopt k || is_typ_kopt k) |> List.map kopt_kid |> List.map nvar -let doc_typquant_items_lem quant_nexps = - separate_map space doc_nexp_lem quant_nexps +let doc_typquant_items_lem quant_nexps = separate_map space doc_nexp_lem quant_nexps (* Produce Size type constraints for bitvector sizes when using machine words. Often these will be unnecessary, but this simple approach will do for now. *) -let rec typeclass_nexps params_to_print (Typ_aux(t,l)) = +let rec typeclass_nexps params_to_print (Typ_aux (t, l)) = let typeclass_nexps = typeclass_nexps params_to_print in - if !Monomorphise.opt_mwords then + if !Monomorphise.opt_mwords then ( match t with - | Typ_id _ - | Typ_var _ - -> NexpSet.empty - | Typ_fn (ts,t) -> List.fold_left NexpSet.union (typeclass_nexps t) (List.map typeclass_nexps ts) + | Typ_id _ | Typ_var _ -> NexpSet.empty + | Typ_fn (ts, t) -> List.fold_left NexpSet.union (typeclass_nexps t) (List.map typeclass_nexps ts) | Typ_tuple ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) | Typ_app (id, args) -> - let add_arg_subtyp_nexps nexps = function - | A_aux (A_typ typ, _) -> NexpSet.union nexps (typeclass_nexps typ) - | _ -> nexps - in - let add_arg_nexps nexps = function - | A_aux (A_nexp nexp, _) -> - let nexp = nexp_simp nexp in - if is_nexp_constant nexp - then nexps else - NexpSet.add (orig_nexp nexp) nexps - | _ -> nexps - in begin - let subtyp_nexps = List.fold_left add_arg_subtyp_nexps NexpSet.empty args in - match Bindings.find_opt id params_to_print with - | Some is -> - Util.IntSet.fold (fun i set -> add_arg_nexps set (List.nth args i)) is subtyp_nexps - | None -> subtyp_nexps - end - | Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) + let add_arg_subtyp_nexps nexps = function + | A_aux (A_typ typ, _) -> NexpSet.union nexps (typeclass_nexps typ) + | _ -> nexps + in + let add_arg_nexps nexps = function + | A_aux (A_nexp nexp, _) -> + let nexp = nexp_simp nexp in + if is_nexp_constant nexp then nexps else NexpSet.add (orig_nexp nexp) nexps + | _ -> nexps + in + begin + let subtyp_nexps = List.fold_left add_arg_subtyp_nexps NexpSet.empty args in + match Bindings.find_opt id params_to_print with + | Some is -> Util.IntSet.fold (fun i set -> add_arg_nexps set (List.nth args i)) is subtyp_nexps + | None -> subtyp_nexps + end + | Typ_exist (kids, _, t) -> NexpSet.empty (* todo *) | Typ_bidir _ -> unreachable l __POS__ "Lem doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" + ) else NexpSet.empty let doc_typclasses_lem params_to_print t = let nexps = typeclass_nexps params_to_print t in - if NexpSet.is_empty nexps then (empty, NexpSet.empty) else - (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps) + if NexpSet.is_empty nexps then (empty, NexpSet.empty) + else + ( separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", + nexps + ) -let doc_typschm_lem ?(monad = empty) params_to_print env quants (TypSchm_aux(TypSchm_ts(tq,t),l)) = +let doc_typschm_lem ?(monad = empty) params_to_print env quants (TypSchm_aux (TypSchm_ts (tq, t), l)) = let env = Env.add_typquant l tq env in - let pt = doc_fn_typ_lem ~monad:monad params_to_print env t in - if quants - then + let pt = doc_fn_typ_lem ~monad params_to_print env t in + if quants then ( let nexps_used = lem_nexps_of_typ params_to_print t in let ptyc, nexps_sizes = doc_typclasses_lem params_to_print t in let nexps_to_include = NexpSet.union nexps_used nexps_sizes in - if NexpSet.is_empty nexps_to_include - then pt + if NexpSet.is_empty nexps_to_include then pt else string "forall " ^^ doc_typquant_items_lem (NexpSet.elements nexps_to_include) ^^ string ". " ^^ ptyc ^^ pt + ) else pt -let is_ctor env id = match Env.lookup_id id env with -| Enum _ -> true -| _ -> false +let is_ctor env id = match Env.lookup_id id env with Enum _ -> true | _ -> false (*Note: vector concatenation, literal vectors, indexed vectors, and record should be removed prior to pp. The latter two have never yet been seen *) -let rec doc_pat_lem ctxt apat_needed (P_aux (p,(l,annot)) as pa) = match p with - | P_app(id, _) when string_of_id id = "None" -> string "Nothing" - | P_app(id, ((_ :: _) as pats)) -> - let ppp = doc_unop (doc_id_lem_ctor id) - (parens (separate_map comma (doc_pat_lem ctxt true) pats)) in - if apat_needed then parens ppp else ppp - | P_app(id, []) -> doc_id_lem_ctor id - | P_lit lit -> doc_lit_lem lit +let rec doc_pat_lem ctxt apat_needed (P_aux (p, (l, annot)) as pa) = + match p with + | P_app (id, _) when string_of_id id = "None" -> string "Nothing" + | P_app (id, (_ :: _ as pats)) -> + let ppp = doc_unop (doc_id_lem_ctor id) (parens (separate_map comma (doc_pat_lem ctxt true) pats)) in + if apat_needed then parens ppp else ppp + | P_app (id, []) -> doc_id_lem_ctor id + | P_lit lit -> doc_lit_lem lit | P_wild -> underscore | P_id id -> doc_id_lem id - | P_var(p,_) -> doc_pat_lem ctxt true p - | P_as(p,id) -> parens (separate space [doc_pat_lem ctxt true p; string "as"; doc_id_lem id]) - | P_typ(Typ_aux (Typ_tuple typs, _), P_aux (P_tuple pats, _)) -> - (* Isabelle does not seem to like type-annotated tuple patterns; - it gives a syntax error. Avoid this by annotating the tuple elements instead *) - let env = env_of_pat pa in - let doc_elem typ pat = doc_pat_lem ctxt true (add_p_typ env typ pat) in - parens (separate comma_sp (List.map2 doc_elem typs pats)) - | P_typ(typ,p) -> - let doc_p = doc_pat_lem ctxt true p in - (match make_printable_type ctxt (env_of_annot (l,annot)) typ with - | None -> doc_p - | Some typ -> parens (doc_op colon doc_p (doc_typ_lem ctxt.params_to_print (env_of_annot (l,annot)) typ))) + | P_var (p, _) -> doc_pat_lem ctxt true p + | P_as (p, id) -> parens (separate space [doc_pat_lem ctxt true p; string "as"; doc_id_lem id]) + | P_typ (Typ_aux (Typ_tuple typs, _), P_aux (P_tuple pats, _)) -> + (* Isabelle does not seem to like type-annotated tuple patterns; + it gives a syntax error. Avoid this by annotating the tuple elements instead *) + let env = env_of_pat pa in + let doc_elem typ pat = doc_pat_lem ctxt true (add_p_typ env typ pat) in + parens (separate comma_sp (List.map2 doc_elem typs pats)) + | P_typ (typ, p) -> ( + let doc_p = doc_pat_lem ctxt true p in + match make_printable_type ctxt (env_of_annot (l, annot)) typ with + | None -> doc_p + | Some typ -> parens (doc_op colon doc_p (doc_typ_lem ctxt.params_to_print (env_of_annot (l, annot)) typ)) + ) | P_vector pats -> - let ppp = brackets (separate_map semi (doc_pat_lem ctxt true) pats) in - if apat_needed then parens ppp else ppp + let ppp = brackets (separate_map semi (doc_pat_lem ctxt true) pats) in + if apat_needed then parens ppp else ppp | P_vector_concat pats -> - raise (Reporting.err_unreachable l __POS__ - "vector concatenation patterns should have been removed before pretty-printing") - | P_tuple pats -> - (match pats with + raise + (Reporting.err_unreachable l __POS__ + "vector concatenation patterns should have been removed before pretty-printing" + ) + | P_tuple pats -> ( + match pats with | [p] -> doc_pat_lem ctxt apat_needed p - | _ -> parens (separate_map comma_sp (doc_pat_lem ctxt false) pats)) + | _ -> parens (separate_map comma_sp (doc_pat_lem ctxt false) pats) + ) | P_list pats -> brackets (separate_map semi (doc_pat_lem ctxt false) pats) (*Never seen but easy in lem*) - | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem ctxt true p) (doc_pat_lem ctxt true p') + | P_cons (p, p') -> doc_op (string "::") (doc_pat_lem ctxt true p) (doc_pat_lem ctxt true p') | P_string_append _ -> unreachable l __POS__ "Lem doesn't support string append patterns" | P_vector_subrange _ -> unreachable l __POS__ "Lem doesn't support vector subrange patterns" | P_not _ -> unreachable l __POS__ "Lem doesn't support not patterns" | P_or _ -> unreachable l __POS__ "Lem doesn't support or patterns" -let rec typ_needs_printed params_to_print (Typ_aux (t,_) as typ) = +let rec typ_needs_printed params_to_print (Typ_aux (t, _) as typ) = let typ_needs_printed = typ_needs_printed params_to_print in match t with | Typ_tuple ts -> List.exists typ_needs_printed ts - | Typ_app (id, targs) -> - begin - match Bindings.find_opt id params_to_print with - | Some is when not (Util.IntSet.is_empty is) -> true - | _ -> List.exists (typ_needs_printed_arg params_to_print) targs - end - | Typ_fn (ts,t) -> List.exists typ_needs_printed ts || typ_needs_printed t - | Typ_exist (kopts,_,t) -> - let kids = List.map kopt_kid kopts in (* TODO: Check this *) - let visible_kids = KidSet.inter (KidSet.of_list kids) (lem_tyvars_of_typ params_to_print t) in - typ_needs_printed t && KidSet.is_empty visible_kids - | _ -> false -and typ_needs_printed_arg params_to_print (A_aux (targ, _)) = match targ with - | A_typ t -> typ_needs_printed params_to_print t + | Typ_app (id, targs) -> begin + match Bindings.find_opt id params_to_print with + | Some is when not (Util.IntSet.is_empty is) -> true + | _ -> List.exists (typ_needs_printed_arg params_to_print) targs + end + | Typ_fn (ts, t) -> List.exists typ_needs_printed ts || typ_needs_printed t + | Typ_exist (kopts, _, t) -> + let kids = List.map kopt_kid kopts in + (* TODO: Check this *) + let visible_kids = KidSet.inter (KidSet.of_list kids) (lem_tyvars_of_typ params_to_print t) in + typ_needs_printed t && KidSet.is_empty visible_kids | _ -> false +and typ_needs_printed_arg params_to_print (A_aux (targ, _)) = + match targ with A_typ t -> typ_needs_printed params_to_print t | _ -> false + let contains_early_return exp = let e_app (f, args) = let rets, args = List.split args in - (List.fold_left (||) (string_of_id f = "early_return") rets, - E_app (f, args)) in - fst (fold_exp - { (Rewriter.compute_exp_alg false (||)) - with e_return = (fun (_, r) -> (true, E_return r)); e_app = e_app } exp) + (List.fold_left ( || ) (string_of_id f = "early_return") rets, E_app (f, args)) + in + fst + (fold_exp { (Rewriter.compute_exp_alg false ( || )) with e_return = (fun (_, r) -> (true, E_return r)); e_app } exp) (* Does the expression have the form of a bitvector cast from the monomorphiser? *) type is_bitvector_cast = BVC_yes | BVC_allowed | BVC_not let is_bitvector_cast_out exp = - let merge x y = match x,y with + let merge x y = + match (x, y) with | BVC_allowed, _ -> y | _, BVC_allowed -> x | BVC_not, _ -> BVC_not | _, BVC_not -> BVC_not | _ -> BVC_yes in - let rec aux (E_aux (e,_)) = + let rec aux (E_aux (e, _)) = match e with | E_tuple es -> List.fold_left merge BVC_allowed (List.map aux es) - | E_typ (_,e) -> aux e - | E_app (Id_aux (Id "bitvector_cast_out",_),_) -> BVC_yes + | E_typ (_, e) -> aux e + | E_app (Id_aux (Id "bitvector_cast_out", _), _) -> BVC_yes | E_id _ -> BVC_allowed | _ -> BVC_not - in aux exp = BVC_yes + in + aux exp = BVC_yes -let replace_env_for_cast_out new_env pat = - map_pat_annot (fun (l,a) -> (l,replace_env new_env a)) pat +let replace_env_for_cast_out new_env pat = map_pat_annot (fun (l, a) -> (l, replace_env new_env a)) pat let find_e_ids exp = - let e_id id = IdSet.singleton id, E_id id in - fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with e_id = e_id } exp) + let e_id id = (IdSet.singleton id, E_id id) in + fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with e_id } exp) -let typ_id_of (Typ_aux (typ, l)) = match typ with +let typ_id_of (Typ_aux (typ, l)) = + match typ with | Typ_id id -> id - | Typ_app (register, [A_aux (A_typ (Typ_aux (Typ_id id, _)), _)]) - when string_of_id register = "register" -> id + | Typ_app (register, [A_aux (A_typ (Typ_aux (Typ_id id, _)), _)]) when string_of_id register = "register" -> id | Typ_app (id, _) -> id | _ -> raise (Reporting.err_unreachable l __POS__ "failed to get type id") let prefix_recordtype = true let report = Reporting.err_unreachable let doc_exp_lem, doc_let_lem = - let rec top_exp (ctxt : context) (aexp_needed : bool) - (E_aux (e, (l,annot)) as full_exp) = + let rec top_exp (ctxt : context) (aexp_needed : bool) (E_aux (e, (l, annot)) as full_exp) = let expY = top_exp ctxt true in let expN = top_exp ctxt false in let expV = top_exp ctxt in - let wrap_parens doc = if aexp_needed then parens (doc) else doc in + let wrap_parens doc = if aexp_needed then parens doc else doc in let liftR doc = - if ctxt.early_ret && effectful (effect_of full_exp) - then wrap_parens (separate space [string "liftR"; parens (doc)]) - else wrap_parens doc in + if ctxt.early_ret && effectful (effect_of full_exp) then wrap_parens (separate space [string "liftR"; parens doc]) + else wrap_parens doc + in match e with - | E_assign((LE_aux(le_act,tannot) as le), e) -> - (* can only be register writes *) - let t = typ_of_annot tannot in - (match le_act (*, t, tag*) with - | LE_vector_range (le,e2,e3) -> - (match le with - | LE_aux (LE_field ((LE_aux (_, lannot) as le),id), fannot) -> - if is_bit_typ (typ_of_annot fannot) then - raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") - else - let field_ref = - doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^ - underscore ^^ - doc_id_lem id in - liftR ((prefix 2 1) - (string "write_reg_field_range") - (align (doc_lexp_deref_lem ctxt le ^/^ - field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e))) + | E_assign ((LE_aux (le_act, tannot) as le), e) -> ( + (* can only be register writes *) + let t = typ_of_annot tannot in + match le_act (*, t, tag*) with + | LE_vector_range (le, e2, e3) -> ( + match le with + | LE_aux (LE_field ((LE_aux (_, lannot) as le), id), fannot) -> + if is_bit_typ (typ_of_annot fannot) then + raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") + else ( + let field_ref = doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^ underscore ^^ doc_id_lem id in + liftR + ((prefix 2 1) (string "write_reg_field_range") + (align (doc_lexp_deref_lem ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e)) + ) + ) | _ -> - let deref = doc_lexp_deref_lem ctxt le in - liftR ((prefix 2 1) - (string "write_reg_range") - (align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e))) - | LE_vector (le,e2) -> - (match le with - | LE_aux (LE_field ((LE_aux (_, lannot) as le),id), fannot) -> - if is_bit_typ (typ_of_annot fannot) then - raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") - else - let field_ref = - doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^ - underscore ^^ - doc_id_lem id in - let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then "write_reg_field_bit" else "write_reg_field_pos" in - liftR ((prefix 2 1) - (string call) - (align (doc_lexp_deref_lem ctxt le ^/^ - field_ref ^/^ expY e2 ^/^ expY e))) + let deref = doc_lexp_deref_lem ctxt le in + liftR ((prefix 2 1) (string "write_reg_range") (align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e)) + ) + | LE_vector (le, e2) -> ( + match le with + | LE_aux (LE_field ((LE_aux (_, lannot) as le), id), fannot) -> + if is_bit_typ (typ_of_annot fannot) then + raise (report l __POS__ "indexing a register's (single bit) bitfield not supported") + else ( + let field_ref = doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^ underscore ^^ doc_id_lem id in + let call = + if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then + "write_reg_field_bit" + else "write_reg_field_pos" + in + liftR + ((prefix 2 1) (string call) + (align (doc_lexp_deref_lem ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e)) + ) + ) | LE_aux (_, lannot) -> - let deref = doc_lexp_deref_lem ctxt le in - let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" else "write_reg_pos" in - liftR ((prefix 2 1) (string call) - (deref ^/^ expY e2 ^/^ expY e)) - ) - | LE_field ((LE_aux (_, lannot) as le),id) -> - let field_ref = - doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^ - underscore ^^ - doc_id_lem id (*^^ - dot ^^ - string "set_field"*) in - liftR ((prefix 2 1) - (string "write_reg_field") - (doc_lexp_deref_lem ctxt le ^^ space ^^ - field_ref ^/^ expY e)) - | LE_deref re -> - liftR ((prefix 2 1) (string "write_reg") (expY re ^/^ expY e)) - | _ -> - liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem ctxt le ^/^ expY e))) - | E_vector_append(le,re) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_append should have been rewritten before pretty-printing") - | E_cons(le,re) -> doc_op (group (colon^^colon)) (expY le) (expY re) - | E_if(c,t,e) -> wrap_parens (align (if_exp ctxt false c t e)) - | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> - raise (report l __POS__ "E_for should have been rewritten before pretty-printing") - | E_loop _ -> - raise (report l __POS__ "E_loop should have been rewritten before pretty-printing") - | E_let(leb,e) -> - wrap_parens (let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ expN e) - | E_app(f,args) -> - begin match f with - | Id_aux (Id "None", _) as none -> doc_id_lem_ctor none - | Id_aux (Id "and_bool", _) | Id_aux (Id "or_bool", _) - when effectful (effect_of full_exp) -> - let call = doc_id_lem (append_id f "M") in - wrap_parens (hang 2 (flow (break 1) (call :: List.map expY args))) - (* temporary hack to make the loop body a function of the temporary variables *) - | Id_aux (Id "foreach#", _) -> - begin + let deref = doc_lexp_deref_lem ctxt le in + let call = + if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" + else "write_reg_pos" + in + liftR ((prefix 2 1) (string call) (deref ^/^ expY e2 ^/^ expY e)) + ) + | LE_field ((LE_aux (_, lannot) as le), id) -> + let field_ref = + doc_id_lem (typ_id_of (typ_of_annot lannot)) ^^ underscore ^^ doc_id_lem id + (*^^ + dot ^^ + string "set_field"*) + in + liftR ((prefix 2 1) (string "write_reg_field") (doc_lexp_deref_lem ctxt le ^^ space ^^ field_ref ^/^ expY e)) + | LE_deref re -> liftR ((prefix 2 1) (string "write_reg") (expY re ^/^ expY e)) + | _ -> liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem ctxt le ^/^ expY e)) + ) + | E_vector_append (le, re) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_append should have been rewritten before pretty-printing") + | E_cons (le, re) -> doc_op (group (colon ^^ colon)) (expY le) (expY re) + | E_if (c, t, e) -> wrap_parens (align (if_exp ctxt false c t e)) + | E_for (id, exp1, exp2, exp3, Ord_aux (order, _), exp4) -> + raise (report l __POS__ "E_for should have been rewritten before pretty-printing") + | E_loop _ -> raise (report l __POS__ "E_loop should have been rewritten before pretty-printing") + | E_let (leb, e) -> wrap_parens (let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ expN e) + | E_app (f, args) -> begin + match f with + | Id_aux (Id "None", _) as none -> doc_id_lem_ctor none + | (Id_aux (Id "and_bool", _) | Id_aux (Id "or_bool", _)) when effectful (effect_of full_exp) -> + let call = doc_id_lem (append_id f "M") in + wrap_parens (hang 2 (flow (break 1) (call :: List.map expY args))) + (* temporary hack to make the loop body a function of the temporary variables *) + | Id_aux (Id "foreach#", _) -> begin match args with | [exp1; exp2; exp3; ord_exp; vartuple; body] -> - let loopvar, body = match body with - | E_aux (E_if (_, - E_aux (E_let (LB_aux (LB_val ( - ((P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), _), _)), _)) - | (P_aux (P_var (P_aux (P_id id, _), _), _)) - | (P_aux (P_id id, _))), _), _), - body), _), _), _) -> id, body - | _ -> raise (Reporting.err_unreachable l __POS__ ("Unable to find loop variable in " ^ string_of_exp body)) in - let step = match ord_exp with - | E_aux (E_lit (L_aux (L_false, _)), _) -> - parens (separate space [string "integerNegate"; expY exp3]) - | _ -> expY exp3 - in - let combinator = - if ctxt.monadic && effectful (effect_of body) then - "foreachM" - else if effectful (effect_of body) then - "foreachE" - else - "foreach" in - let indices_pp = parens (separate space [string "index_list"; expY exp1; expY exp2; step]) in - let used_vars_body = find_e_ids body in - let body_lambda = - (* Work around indentation issues in Lem when translating - tuple or literal unit patterns to Isabelle *) - match fst (uncast_exp vartuple) with - | E_aux (E_tuple _, _) - when not (IdSet.mem (mk_id "varstup") used_vars_body)-> + let loopvar, body = + match body with + | E_aux + ( E_if + ( _, + E_aux + ( E_let + ( LB_aux + ( LB_val + ( ( P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), _), _)), _) + | P_aux (P_var (P_aux (P_id id, _), _), _) + | P_aux (P_id id, _) ), + _ + ), + _ + ), + body + ), + _ + ), + _ + ), + _ + ) -> + (id, body) + | _ -> + raise + (Reporting.err_unreachable l __POS__ ("Unable to find loop variable in " ^ string_of_exp body)) + in + let step = + match ord_exp with + | E_aux (E_lit (L_aux (L_false, _)), _) -> parens (separate space [string "integerNegate"; expY exp3]) + | _ -> expY exp3 + in + let combinator = + if ctxt.monadic && effectful (effect_of body) then "foreachM" + else if effectful (effect_of body) then "foreachE" + else "foreach" + in + let indices_pp = parens (separate space [string "index_list"; expY exp1; expY exp2; step]) in + let used_vars_body = find_e_ids body in + let body_lambda = + (* Work around indentation issues in Lem when translating + tuple or literal unit patterns to Isabelle *) + match fst (uncast_exp vartuple) with + | E_aux (E_tuple _, _) when not (IdSet.mem (mk_id "varstup") used_vars_body) -> separate space [string "fun"; doc_id_lem loopvar; string "varstup"; arrow] - ^^ break 1 ^^ - separate space [string "let"; expY vartuple; string "= varstup in"] - | E_aux (E_lit (L_aux (L_unit, _)), _) - when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> + ^^ break 1 + ^^ separate space [string "let"; expY vartuple; string "= varstup in"] + | E_aux (E_lit (L_aux (L_unit, _)), _) when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> separate space [string "fun"; doc_id_lem loopvar; string "unit_var"; arrow] - | _ -> - separate space [string "fun"; doc_id_lem loopvar; expY vartuple; arrow] - in - parens ( - (prefix 2 1) + | _ -> separate space [string "fun"; doc_id_lem loopvar; expY vartuple; arrow] + in + parens + ((prefix 2 1) ((separate space) [string combinator; indices_pp; expY vartuple]) - (parens - (prefix 2 1 (group body_lambda) (expN body)) - ) - ) - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for loop combinator") + (parens (prefix 2 1 (group body_lambda) (expN body))) + ) + | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> - let combinator = String.sub combinator 0 (String.index combinator '#') in - begin + | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> + let combinator = String.sub combinator 0 (String.index combinator '#') in + begin + match args with + | [cond; varstuple; body] | [cond; varstuple; body; _] -> + (* Ignore termination measures - not used in Lem *) + let return (E_aux (e, a)) = E_aux (E_internal_return (E_aux (e, a)), a) in + let csuffix, cond, body = + match (effectful (effect_of cond), effectful (effect_of body)) with + | false, false -> ("", cond, body) + | false, true -> ("M", return cond, body) + | true, false -> ("M", cond, return body) + | true, true -> ("M", cond, body) + in + let used_vars_body = find_e_ids body in + let lambda = + (* Work around indentation issues in Lem when translating + tuple or literal unit patterns to Isabelle *) + match fst (uncast_exp varstuple) with + | E_aux (E_tuple _, _) when not (IdSet.mem (mk_id "varstup") used_vars_body) -> + separate space [string "fun varstup"; arrow] + ^^ break 1 + ^^ separate space [string "let"; expY varstuple; string "= varstup in"] + | E_aux (E_lit (L_aux (L_unit, _)), _) when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> + separate space [string "fun unit_var"; arrow] + | _ -> separate space [string "fun"; expY varstuple; arrow] + in + parens + ((prefix 2 1) + ((separate space) [string (combinator ^ csuffix); expY varstuple]) + ((prefix 0 1) + (parens (prefix 2 1 (group lambda) (expN cond))) + (parens (prefix 2 1 (group lambda) (expN body))) + ) + ) + | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") + end + | Id_aux (Id "early_return", _) -> begin match args with - | [cond; varstuple; body] - | [cond; varstuple; body; _] -> (* Ignore termination measures - not used in Lem *) - let return (E_aux (e, a)) = E_aux (E_internal_return (E_aux (e, a)), a) in - let csuffix, cond, body = - match effectful (effect_of cond), effectful (effect_of body) with - | false, false -> "", cond, body - | false, true -> "M", return cond, body - | true, false -> "M", cond, return body - | true, true -> "M", cond, body - in - let used_vars_body = find_e_ids body in - let lambda = - (* Work around indentation issues in Lem when translating - tuple or literal unit patterns to Isabelle *) - match fst (uncast_exp varstuple) with - | E_aux (E_tuple _, _) - when not (IdSet.mem (mk_id "varstup") used_vars_body)-> - separate space [string "fun varstup"; arrow] ^^ break 1 ^^ - separate space [string "let"; expY varstuple; string "= varstup in"] - | E_aux (E_lit (L_aux (L_unit, _)), _) - when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> - separate space [string "fun unit_var"; arrow] - | _ -> - separate space [string "fun"; expY varstuple; arrow] - in - parens ( - (prefix 2 1) - ((separate space) [string (combinator ^ csuffix); expY varstuple]) - ((prefix 0 1) - (parens (prefix 2 1 (group lambda) (expN cond))) - (parens (prefix 2 1 (group lambda) (expN body)))) - ) - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for loop combinator") + | [exp] -> + let returner, monad, arg_order = + if ctxt.monadic then ("early_return", "MR", fun x -> x) else ("Left", "either", List.rev) + in + let epp = separate space [string returner; expY exp] in + let aexp_needed, tepp = + match + ( Option.bind (Env.get_ret_typ (env_of exp)) (make_printable_type ctxt ctxt.top_env), + make_printable_type ctxt (env_of full_exp) (typ_of full_exp) + ) + with + | Some typ, Some full_typ -> + let tannot = + separate space + ([string monad] + @ arg_order + [ + doc_atomic_typ_lem ctxt.params_to_print false full_typ; + doc_atomic_typ_lem ctxt.params_to_print false typ; + ] + ) + in + (true, doc_op colon epp tannot) + | _ -> (aexp_needed, epp) + in + if aexp_needed then parens tepp else tepp + | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for early_return builtin") end - | Id_aux (Id "early_return", _) -> - begin match args with - | [exp] -> - let (returner, monad, arg_order) = if ctxt.monadic then ("early_return", "MR", fun x -> x) else ("Left", "either", List.rev) in - let epp = separate space [string returner; expY exp] in - let aexp_needed, tepp = - match Option.bind (Env.get_ret_typ (env_of exp)) (make_printable_type ctxt ctxt.top_env), - make_printable_type ctxt (env_of full_exp) (typ_of full_exp) with - | Some typ, Some full_typ -> - let tannot = separate space ([string monad] - @ arg_order [doc_atomic_typ_lem ctxt.params_to_print false full_typ; - doc_atomic_typ_lem ctxt.params_to_print false typ]) in - true, doc_op colon epp tannot - | _ -> aexp_needed, epp - in - if aexp_needed then parens tepp else tepp - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for early_return builtin") - end - | _ -> - begin match destruct_tannot annot with - | Some (env, typ) when Env.is_union_constructor f env -> - let unwrap opt = match opt with - | Some x -> x - | None -> Reporting.unreachable l __POS__ ("Failed to get information about constructor " ^ string_of_id f) in - let (_, _, union_id, _) = Env.union_constructor_info f env |> unwrap in - let (typq, _) = Env.get_variants env |> Bindings.find_opt union_id |> unwrap in - (* If the union has type variables, we may need an annotation for Lem to typecheck it *) - let annotation_needed = false (* List.length (quant_items typq) > 0 *) in - let wrap_union doc = if aexp_needed || annotation_needed then parens doc else doc in - let epp = - match args with - | [] -> doc_id_lem_ctor f - | [arg] -> doc_id_lem_ctor f ^^ space ^^ expV true arg - | _ -> - doc_id_lem_ctor f ^^ space ^^ - parens (separate_map comma (expV false) args) in - wrap_union (if annotation_needed then align epp ^^ doc_tannot_lem ctxt env false typ else align epp) - | _ -> - let call, is_extern = match destruct_tannot annot with - | Some (env, _) when Env.is_extern f env "lem" -> - string (Env.get_extern f env "lem"), true - | _ -> doc_id_lem f, false in - let epp = hang 2 (flow (break 1) (call :: List.map expY args)) in - let (taepp,aexp_needed) = - let env = env_of full_exp in - let t = Env.expand_synonyms env (typ_of full_exp) in - let eff = effect_of full_exp in - if typ_needs_printed ctxt.params_to_print t then - if Id.compare f (mk_id "bitvector_cast_out") <> 0 && - Id.compare f (mk_id "zero_extend_type_hack") <> 0 - then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true) - (* TODO: coordinate with the code in monomorphise.ml to find the correct - typing environment to use *) - else (align (group (prefix 0 1 epp (doc_tannot_lem ctxt ctxt.top_env (effectful eff) t))), true) - else (epp, aexp_needed) in - liftR (if aexp_needed then parens (align taepp) else taepp) + | _ -> begin + match destruct_tannot annot with + | Some (env, typ) when Env.is_union_constructor f env -> + let unwrap opt = + match opt with + | Some x -> x + | None -> + Reporting.unreachable l __POS__ ("Failed to get information about constructor " ^ string_of_id f) + in + let _, _, union_id, _ = Env.union_constructor_info f env |> unwrap in + let typq, _ = Env.get_variants env |> Bindings.find_opt union_id |> unwrap in + (* If the union has type variables, we may need an annotation for Lem to typecheck it *) + let annotation_needed = false (* List.length (quant_items typq) > 0 *) in + let wrap_union doc = if aexp_needed || annotation_needed then parens doc else doc in + let epp = + match args with + | [] -> doc_id_lem_ctor f + | [arg] -> doc_id_lem_ctor f ^^ space ^^ expV true arg + | _ -> doc_id_lem_ctor f ^^ space ^^ parens (separate_map comma (expV false) args) + in + wrap_union (if annotation_needed then align epp ^^ doc_tannot_lem ctxt env false typ else align epp) + | _ -> + let call, is_extern = + match destruct_tannot annot with + | Some (env, _) when Env.is_extern f env "lem" -> (string (Env.get_extern f env "lem"), true) + | _ -> (doc_id_lem f, false) + in + let epp = hang 2 (flow (break 1) (call :: List.map expY args)) in + let taepp, aexp_needed = + let env = env_of full_exp in + let t = Env.expand_synonyms env (typ_of full_exp) in + let eff = effect_of full_exp in + if typ_needs_printed ctxt.params_to_print t then + if + Id.compare f (mk_id "bitvector_cast_out") <> 0 + && Id.compare f (mk_id "zero_extend_type_hack") <> 0 + then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true) + (* TODO: coordinate with the code in monomorphise.ml to find the correct + typing environment to use *) + else (align (group (prefix 0 1 epp (doc_tannot_lem ctxt ctxt.top_env (effectful eff) t))), true) + else (epp, aexp_needed) + in + liftR (if aexp_needed then parens (align taepp) else taepp) end - end - | E_vector_access (v,e) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_access should have been rewritten before pretty-printing") - | E_vector_subrange (v,e1,e2) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_subrange should have been rewritten before pretty-printing") - | E_field((E_aux(_,(l,fannot)) as fexp),id) -> - let ft = typ_of_annot (l,fannot) in - (match destruct_tannot fannot with - | Some(env, (Typ_aux (Typ_id tid, _))) - | Some(env, (Typ_aux (Typ_app (tid, _), _))) - when Env.is_record tid env -> - let fname = - if prefix_recordtype && string_of_id tid <> "regstate" - then (string (string_of_id tid ^ "_")) ^^ doc_id_lem id - else doc_id_lem id in - expY fexp ^^ dot ^^ fname - | _ -> - raise (report l __POS__ "E_field expression with no register or record type")) + end + | E_vector_access (v, e) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_access should have been rewritten before pretty-printing") + | E_vector_subrange (v, e1, e2) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_subrange should have been rewritten before pretty-printing") + | E_field ((E_aux (_, (l, fannot)) as fexp), id) -> ( + let ft = typ_of_annot (l, fannot) in + match destruct_tannot fannot with + | (Some (env, Typ_aux (Typ_id tid, _)) | Some (env, Typ_aux (Typ_app (tid, _), _))) when Env.is_record tid env + -> + let fname = + if prefix_recordtype && string_of_id tid <> "regstate" then + string (string_of_id tid ^ "_") ^^ doc_id_lem id + else doc_id_lem id + in + expY fexp ^^ dot ^^ fname + | _ -> raise (report l __POS__ "E_field expression with no register or record type") + ) | E_block [] -> string "()" | E_block exps -> raise (report l __POS__ "Blocks should have been removed till now.") | E_id id | E_ref id -> - let env = env_of full_exp in - let typ = typ_of full_exp in - let eff = effect_of full_exp in - let base_typ = Env.base_typ_of env typ in - if Env.is_register id env && (match e with E_id _ -> true | _ -> false) then - let epp = separate space [string "read_reg";doc_id_lem (append_id id "_ref")] in - if is_bitvector_typ base_typ - then liftR (parens (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env true base_typ))))) - else liftR epp - else if Env.is_register id env && (match e with E_ref _ -> true | _ -> false) then doc_id_lem (append_id id "_ref") - else if is_ctor env id then doc_id_lem_ctor id - else doc_id_lem id + let env = env_of full_exp in + let typ = typ_of full_exp in + let eff = effect_of full_exp in + let base_typ = Env.base_typ_of env typ in + if Env.is_register id env && match e with E_id _ -> true | _ -> false then ( + let epp = separate space [string "read_reg"; doc_id_lem (append_id id "_ref")] in + if is_bitvector_typ base_typ then + liftR (parens (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env true base_typ))))) + else liftR epp + ) + else if Env.is_register id env && match e with E_ref _ -> true | _ -> false then doc_id_lem (append_id id "_ref") + else if is_ctor env id then doc_id_lem_ctor id + else doc_id_lem id | E_lit lit -> - let env = env_of full_exp in - let typ = Env.expand_synonyms env (typ_of full_exp) in - let eff = effect_of full_exp in - if typ_needs_printed ctxt.params_to_print typ - then parens (doc_lit_lem lit ^^ doc_tannot_lem ctxt env (effectful eff) typ) - else doc_lit_lem lit - | E_typ (typ,e) -> expV aexp_needed e (*parens (expN e ^^ doc_tannot_lem ctxt (env_of full_exp) (effectful (effect_of full_exp)) typ)*) - | E_tuple exps -> - parens (align (group (separate_map (comma ^^ break 1) expN exps))) + let env = env_of full_exp in + let typ = Env.expand_synonyms env (typ_of full_exp) in + let eff = effect_of full_exp in + if typ_needs_printed ctxt.params_to_print typ then + parens (doc_lit_lem lit ^^ doc_tannot_lem ctxt env (effectful eff) typ) + else doc_lit_lem lit + | E_typ (typ, e) -> + expV aexp_needed + e (*parens (expN e ^^ doc_tannot_lem ctxt (env_of full_exp) (effectful (effect_of full_exp)) typ)*) + | E_tuple exps -> parens (align (group (separate_map (comma ^^ break 1) expN exps))) | E_struct fexps -> - let recordtyp, annotation_needed, env, typ = match destruct_tannot annot with - | Some (env, (Typ_aux (Typ_id tid,_) as typ)) -> tid, false, env, typ - (* We need an annotation here because some record type parameters may be phantom *) - | Some (env, (Typ_aux (Typ_app (tid, _), _) as typ)) -> tid, true, env, typ - | _ -> Reporting.unreachable l __POS__ ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp) in - let wrap_record doc = if aexp_needed || annotation_needed then parens doc else doc in - wrap_record (anglebars (space ^^ (align (separate_map - (semi_sp ^^ break 1) - (doc_fexp ctxt recordtyp) fexps)) ^^ space) - ^^ if annotation_needed then doc_tannot_lem ctxt env false typ else empty) - | E_struct_update(e, fexps) -> - let recordtyp = match destruct_tannot annot with - | Some (env, Typ_aux (Typ_id tid,_)) - | Some (env, Typ_aux (Typ_app (tid, _), _)) - when Env.is_record tid env -> - tid - | _ -> raise (report l __POS__ ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp)) in - anglebars (space ^^ doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps) ^^ space) + let recordtyp, annotation_needed, env, typ = + match destruct_tannot annot with + | Some (env, (Typ_aux (Typ_id tid, _) as typ)) -> (tid, false, env, typ) + (* We need an annotation here because some record type parameters may be phantom *) + | Some (env, (Typ_aux (Typ_app (tid, _), _) as typ)) -> (tid, true, env, typ) + | _ -> + Reporting.unreachable l __POS__ + ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp) + in + let wrap_record doc = if aexp_needed || annotation_needed then parens doc else doc in + wrap_record + (anglebars (space ^^ align (separate_map (semi_sp ^^ break 1) (doc_fexp ctxt recordtyp) fexps) ^^ space) + ^^ if annotation_needed then doc_tannot_lem ctxt env false typ else empty + ) + | E_struct_update (e, fexps) -> + let recordtyp = + match destruct_tannot annot with + | (Some (env, Typ_aux (Typ_id tid, _)) | Some (env, Typ_aux (Typ_app (tid, _), _))) when Env.is_record tid env + -> + tid + | _ -> + raise + (report l __POS__ + ("cannot get record type from annot " ^ string_of_tannot annot ^ " of exp " ^ string_of_exp full_exp) + ) + in + anglebars + (space ^^ doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps) ^^ space) | E_vector exps -> - let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in - let start, (len, order, etyp) = - if is_vector_typ t || is_bitvector_typ t then vector_start_index t, vector_typ_args_of t - else raise (Reporting.err_unreachable l __POS__ - "E_vector of non-vector type") in - let dir,dir_out = if is_order_inc order then (true,"true") else (false, "false") in - let start = match nexp_simp start with - | Nexp_aux (Nexp_constant i, _) -> Big_int.to_string i - | _ -> if dir then "0" else string_of_int (List.length exps) in - (* let expspp = - match exps with - | [] -> empty - | e :: es -> - let (expspp,_) = - List.fold_left - (fun (pp,count) e -> - (pp ^^ semi ^^ (if count = 20 then break 0 else empty) ^^ - expN e), - if count = 20 then 0 else count + 1) - (expN e,0) es in - align (group expspp) in *) - let expspp = align (group (flow_map (semi ^^ break 0) expN exps)) in - let epp = brackets expspp in - let (epp,aexp_needed) = - if is_bit_typ etyp && !Monomorphise.opt_mwords then - let bepp = string "vec_of_bits" ^^ space ^^ align epp in - (align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true) - else (epp,aexp_needed) in - if aexp_needed then parens (align epp) else epp - | E_vector_update(v,e1,e2) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_update should have been rewritten before pretty-printing") - | E_vector_update_subrange(v,e1,e2,e3) -> - raise (Reporting.err_unreachable l __POS__ - "E_vector_update should have been rewritten before pretty-printing") - | E_list exps -> - brackets (separate_map semi (expN) exps) - | E_match(e,pexps) -> - let only_integers e = expY e in - wrap_parens - (group ((separate space [string "match"; only_integers e; string "with"]) ^/^ - (separate_map (break 1) (doc_case ctxt) pexps) ^/^ - (string "end"))) + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in + let start, (len, order, etyp) = + if is_vector_typ t || is_bitvector_typ t then (vector_start_index t, vector_typ_args_of t) + else raise (Reporting.err_unreachable l __POS__ "E_vector of non-vector type") + in + let dir, dir_out = if is_order_inc order then (true, "true") else (false, "false") in + let start = + match nexp_simp start with + | Nexp_aux (Nexp_constant i, _) -> Big_int.to_string i + | _ -> if dir then "0" else string_of_int (List.length exps) + in + (* let expspp = + match exps with + | [] -> empty + | e :: es -> + let (expspp,_) = + List.fold_left + (fun (pp,count) e -> + (pp ^^ semi ^^ (if count = 20 then break 0 else empty) ^^ + expN e), + if count = 20 then 0 else count + 1) + (expN e,0) es in + align (group expspp) in *) + let expspp = align (group (flow_map (semi ^^ break 0) expN exps)) in + let epp = brackets expspp in + let epp, aexp_needed = + if is_bit_typ etyp && !Monomorphise.opt_mwords then ( + let bepp = string "vec_of_bits" ^^ space ^^ align epp in + (align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true) + ) + else (epp, aexp_needed) + in + if aexp_needed then parens (align epp) else epp + | E_vector_update (v, e1, e2) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_update should have been rewritten before pretty-printing") + | E_vector_update_subrange (v, e1, e2, e3) -> + raise (Reporting.err_unreachable l __POS__ "E_vector_update should have been rewritten before pretty-printing") + | E_list exps -> brackets (separate_map semi expN exps) + | E_match (e, pexps) -> + let only_integers e = expY e in + wrap_parens + (group + (separate space [string "match"; only_integers e; string "with"] + ^/^ separate_map (break 1) (doc_case ctxt) pexps + ^/^ string "end" + ) + ) | E_try (e, pexps) -> - if effectful (effect_of e) then - let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in - wrap_parens - (group ((separate space [string try_catch; expY e; string "(function "]) ^/^ - (separate_map (break 1) (doc_case ctxt) pexps) ^/^ - (string "end)"))) - else - raise (Reporting.err_todo l "Warning: try-block around pure expression") - | E_throw e -> - align (liftR (separate space [string "throw"; expY e])) + if effectful (effect_of e) then ( + let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in + wrap_parens + (group + (separate space [string try_catch; expY e; string "(function "] + ^/^ separate_map (break 1) (doc_case ctxt) pexps + ^/^ string "end)" + ) + ) + ) + else raise (Reporting.err_todo l "Warning: try-block around pure expression") + | E_throw e -> align (liftR (separate space [string "throw"; expY e])) | E_exit e -> liftR (separate space [string "exit"; expY e]) - | E_assert (e1,e2) -> - align (liftR (separate space [string "assert_exp"; expY e1; expY e2])) - | E_app_infix (e1,id,e2) -> - expV aexp_needed (E_aux (E_app (deinfix id, [e1; e2]), (l, annot))) - | E_var(lexp, eq_exp, in_exp) -> - raise (report l __POS__ "E_vars should have been removed before pretty-printing") - | E_internal_plet (pat,e1,e2) -> - let bind, bind_unit = if ctxt.monadic then (">>=", ">>") else (">>$=", ">>$") in - let epp = - let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in - let middle = - match fst (untyp_pat pat) with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) - when is_unit_typ (typ_of_pat pat) -> - string bind_unit - | P_aux (P_tuple _, _) - when not (IdSet.mem (mk_id "varstup") (find_e_ids e2)) -> - (* Work around indentation issues in Lem when translating - tuple patterns to Isabelle *) - separate space - [string (bind ^ " fun varstup -> let"); - doc_pat_lem ctxt true pat; - string "= varstup in"] - | _ -> - separate space [string (bind ^ " fun"); - doc_pat_lem ctxt true pat; arrow] - in - infix 0 1 middle (expV b e1) (expN e2) - in - wrap_parens (align epp) - | E_internal_return (e1) -> - let return = if ctxt.monadic then "return" else "Right" in - wrap_parens (align (separate space [string return; expY e1])) - | E_sizeof nexp -> - (match nexp_simp nexp with + | E_assert (e1, e2) -> align (liftR (separate space [string "assert_exp"; expY e1; expY e2])) + | E_app_infix (e1, id, e2) -> expV aexp_needed (E_aux (E_app (deinfix id, [e1; e2]), (l, annot))) + | E_var (lexp, eq_exp, in_exp) -> raise (report l __POS__ "E_vars should have been removed before pretty-printing") + | E_internal_plet (pat, e1, e2) -> + let bind, bind_unit = if ctxt.monadic then (">>=", ">>") else (">>$=", ">>$") in + let epp = + let b = match e1 with E_aux (E_if _, _) -> true | _ -> false in + let middle = + match fst (untyp_pat pat) with + | (P_aux (P_wild, _) | P_aux (P_typ (_, P_aux (P_wild, _)), _)) when is_unit_typ (typ_of_pat pat) -> + string bind_unit + | P_aux (P_tuple _, _) when not (IdSet.mem (mk_id "varstup") (find_e_ids e2)) -> + (* Work around indentation issues in Lem when translating + tuple patterns to Isabelle *) + separate space [string (bind ^ " fun varstup -> let"); doc_pat_lem ctxt true pat; string "= varstup in"] + | _ -> separate space [string (bind ^ " fun"); doc_pat_lem ctxt true pat; arrow] + in + infix 0 1 middle (expV b e1) (expN e2) + in + wrap_parens (align epp) + | E_internal_return e1 -> + let return = if ctxt.monadic then "return" else "Right" in + wrap_parens (align (separate space [string return; expY e1])) + | E_sizeof nexp -> ( + match nexp_simp nexp with | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem (L_aux (L_num i, l)) | _ -> - raise (Reporting.err_unreachable l __POS__ - "pretty-printing non-constant sizeof expressions to Lem not supported")) + raise + (Reporting.err_unreachable l __POS__ + "pretty-printing non-constant sizeof expressions to Lem not supported" + ) + ) | E_return r -> - let ta = - match Option.bind (Env.get_ret_typ (env_of full_exp)) (make_printable_type ctxt ctxt.top_env), - make_printable_type ctxt (env_of r) (typ_of r) with - | Some full_typ, Some r_typ -> - separate space - [string ": MR"; - parens (doc_typ_lem ctxt.params_to_print (env_of full_exp) full_typ); - parens (doc_typ_lem ctxt.params_to_print (env_of r) r_typ)] - | _ -> empty - in - align (parens (string "early_return" ^//^ expV true r ^//^ ta)) + let ta = + match + ( Option.bind (Env.get_ret_typ (env_of full_exp)) (make_printable_type ctxt ctxt.top_env), + make_printable_type ctxt (env_of r) (typ_of r) + ) + with + | Some full_typ, Some r_typ -> + separate space + [ + string ": MR"; + parens (doc_typ_lem ctxt.params_to_print (env_of full_exp) full_typ); + parens (doc_typ_lem ctxt.params_to_print (env_of r) r_typ); + ] + | _ -> empty + in + align (parens (string "early_return" ^//^ expV true r ^//^ ta)) | E_constraint _ -> string "true" | E_internal_assume (nc, e1) -> - string "(* " ^^ string (string_of_n_constraint nc) ^^ string " *)" ^/^ wrap_parens (expN e1) + string "(* " ^^ string (string_of_n_constraint nc) ^^ string " *)" ^/^ wrap_parens (expN e1) | E_internal_value _ -> - raise (Reporting.err_unreachable l __POS__ - "unsupported internal expression encountered while pretty-printing") + raise (Reporting.err_unreachable l __POS__ "unsupported internal expression encountered while pretty-printing") and if_exp ctxt (elseif : bool) c t e = let if_pp = string (if elseif then "else if" else "if") in - let else_pp = match e with - | E_aux (E_if (c', t', e'), _) - | E_aux (E_typ (_, E_aux (E_if (c', t', e'), _)), _) -> - if_exp ctxt true c' t' e' + let else_pp = + match e with + | E_aux (E_if (c', t', e'), _) | E_aux (E_typ (_, E_aux (E_if (c', t', e'), _)), _) -> if_exp ctxt true c' t' e' (* Special case to prevent current arm decoder becoming a staircase *) (* TODO: replace with smarter pretty printing *) - | E_aux (E_internal_plet (pat,exp1,E_aux (E_typ (typ, (E_aux (E_if (_, _, _), _) as exp2)),_)),ann) when Typ.compare typ unit_typ == 0 -> - string "else" ^/^ top_exp ctxt false (E_aux (E_internal_plet (pat,exp1,exp2),ann)) + | E_aux (E_internal_plet (pat, exp1, E_aux (E_typ (typ, (E_aux (E_if (_, _, _), _) as exp2)), _)), ann) + when Typ.compare typ unit_typ == 0 -> + string "else" ^/^ top_exp ctxt false (E_aux (E_internal_plet (pat, exp1, exp2), ann)) | _ -> prefix 2 1 (string "else") (top_exp ctxt false e) in - (prefix 2 1 - (soft_surround 2 1 if_pp (top_exp ctxt true c) (string "then")) - (top_exp ctxt false t)) ^^ - break 1 ^^ - else_pp - and let_exp ctxt (LB_aux(lb,_)) = match lb with - | LB_val(pat,e) -> - let pat = if is_bitvector_cast_out e then replace_env_for_cast_out ctxt.top_env pat else pat in - prefix 2 1 - (separate space [string "let"; doc_pat_lem ctxt true pat; equals]) - (top_exp ctxt false e) - - and doc_fexp ctxt recordtyp (FE_aux(FE_fexp(id,e),_)) = + prefix 2 1 (soft_surround 2 1 if_pp (top_exp ctxt true c) (string "then")) (top_exp ctxt false t) + ^^ break 1 ^^ else_pp + and let_exp ctxt (LB_aux (lb, _)) = + match lb with + | LB_val (pat, e) -> + let pat = if is_bitvector_cast_out e then replace_env_for_cast_out ctxt.top_env pat else pat in + prefix 2 1 (separate space [string "let"; doc_pat_lem ctxt true pat; equals]) (top_exp ctxt false e) + and doc_fexp ctxt recordtyp (FE_aux (FE_fexp (id, e), _)) = let fname = - if prefix_recordtype && string_of_id recordtyp <> "regstate" - then (string (string_of_id recordtyp ^ "_")) ^^ doc_id_lem id - else doc_id_lem id in + if prefix_recordtype && string_of_id recordtyp <> "regstate" then + string (string_of_id recordtyp ^ "_") ^^ doc_id_lem id + else doc_id_lem id + in group (doc_op equals fname (top_exp ctxt true e)) - and doc_case ctxt = function - | Pat_aux(Pat_exp(pat,e),_) -> - group (prefix 3 1 (separate space [pipe; doc_pat_lem ctxt false pat;arrow]) - (group (top_exp ctxt false e))) - | Pat_aux(Pat_when(_,_,_),(l,_)) -> - raise (Reporting.err_unreachable l __POS__ - "guarded pattern expression should have been rewritten before pretty-printing") - - and doc_lexp_deref_lem ctxt ((LE_aux(lexp,(l,annot))) as le) = match lexp with - | LE_field (le,id) -> - parens (separate empty [doc_lexp_deref_lem ctxt le;dot;doc_id_lem id]) + | Pat_aux (Pat_exp (pat, e), _) -> + group (prefix 3 1 (separate space [pipe; doc_pat_lem ctxt false pat; arrow]) (group (top_exp ctxt false e))) + | Pat_aux (Pat_when (_, _, _), (l, _)) -> + raise + (Reporting.err_unreachable l __POS__ + "guarded pattern expression should have been rewritten before pretty-printing" + ) + and doc_lexp_deref_lem ctxt (LE_aux (lexp, (l, annot)) as le) = + match lexp with + | LE_field (le, id) -> parens (separate empty [doc_lexp_deref_lem ctxt le; dot; doc_id_lem id]) | LE_id id -> doc_id_lem (append_id id "_ref") - | LE_typ (typ,id) -> doc_id_lem (append_id id "_ref") + | LE_typ (typ, id) -> doc_id_lem (append_id id "_ref") | LE_tuple lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem ctxt) lexps) - | _ -> - raise (Reporting.err_unreachable l __POS__ ("doc_lexp_deref_lem: Unsupported lexp")) - (* expose doc_exp_lem and doc_let *) - in top_exp, let_exp + | _ -> raise (Reporting.err_unreachable l __POS__ "doc_lexp_deref_lem: Unsupported lexp") + (* expose doc_exp_lem and doc_let *) + in + (top_exp, let_exp) (*TODO Upcase and downcase type and constructors as needed*) -let doc_type_union_lem params_to_print env (Tu_aux(Tu_ty_id(typ,id),_)) = - separate space [pipe; doc_id_lem_ctor id; string "of"; - parens (doc_typ_lem params_to_print env typ)] +let doc_type_union_lem params_to_print env (Tu_aux (Tu_ty_id (typ, id), _)) = + separate space [pipe; doc_id_lem_ctor id; string "of"; parens (doc_typ_lem params_to_print env typ)] (* let rec doc_range_lem (BF_aux(r,_)) = match r with @@ -1233,474 +1211,580 @@ let rec doc_range_lem (BF_aux(r,_)) = match r with | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) *) -let doc_typquant_sorts idpp (TypQ_aux (typq,_)) = +let doc_typquant_sorts idpp (TypQ_aux (typq, _)) = match typq with | TypQ_tq qs -> - let q (QI_aux (qi,_)) = - match qi with - | QI_id (KOpt_aux (KOpt_kind (K_aux (K_int,_),kid),_)) -> Some (string "`len`") - | QI_id (KOpt_aux (KOpt_kind (K_aux (K_type,_),kid),_)) -> Some underscore - | QI_id (KOpt_aux (KOpt_kind (K_aux ((K_order|K_bool),_),kid),_)) -> None - | QI_constraint _ -> None - in - if List.exists (function (QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_int,_),_),_)),_)) -> true | _ -> false) qs then - let qs_pp = List.filter_map q qs in - string "declare isabelle target_sorts " ^^ idpp ^^ space ^^ separate space (equals::qs_pp) ^^ hardline - else empty + let q (QI_aux (qi, _)) = + match qi with + | QI_id (KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _)) -> Some (string "`len`") + | QI_id (KOpt_aux (KOpt_kind (K_aux (K_type, _), kid), _)) -> Some underscore + | QI_id (KOpt_aux (KOpt_kind (K_aux ((K_order | K_bool), _), kid), _)) -> None + | QI_constraint _ -> None + in + if + List.exists + (function QI_aux (QI_id (KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _)), _) -> true | _ -> false) + qs + then ( + let qs_pp = List.filter_map q qs in + string "declare isabelle target_sorts " ^^ idpp ^^ space ^^ separate space (equals :: qs_pp) ^^ hardline + ) + else empty | TypQ_no_forall -> empty -let doc_sia_id (Id_aux(i,_)) = - match i with - | Id i -> string i - | Operator x -> string ("operator " ^ x) +let doc_sia_id (Id_aux (i, _)) = match i with Id i -> string i | Operator x -> string ("operator " ^ x) let typq_to_print params_to_print id typq = match Bindings.find_opt id params_to_print with | None -> typq - | Some is -> - match typq with - | TypQ_aux (TypQ_no_forall, _) -> typq - | TypQ_aux (TypQ_tq qs, l) -> - List.fold_left (fun (t,i) h -> - if is_quant_kopt h then - if Util.IntSet.mem i is then (h::t,i+1) else (t,i+1) - else (t,i)) ([],0) qs |> - fst |> List.rev |> fun qs -> TypQ_aux (TypQ_tq qs, l) - -let doc_typdef_lem params_to_print env (TD_aux(td, (l, annot))) = match td with - | TD_abbrev(id,typq,A_aux (A_typ typ, _)) -> - let typq_to_print = typq_to_print params_to_print id typq in - let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in - doc_op equals - (separate space [string "type"; doc_id_lem_type id; doc_typquant_items_lem (kid_nexps_of_typquant typq_to_print)]) - (doc_typschm_lem params_to_print env false typschm) + | Some is -> ( + match typq with + | TypQ_aux (TypQ_no_forall, _) -> typq + | TypQ_aux (TypQ_tq qs, l) -> + List.fold_left + (fun (t, i) h -> + if is_quant_kopt h then if Util.IntSet.mem i is then (h :: t, i + 1) else (t, i + 1) else (t, i) + ) + ([], 0) qs + |> fst |> List.rev + |> fun qs -> TypQ_aux (TypQ_tq qs, l) + ) + +let doc_typdef_lem params_to_print env (TD_aux (td, (l, annot))) = + match td with + | TD_abbrev (id, typq, A_aux (A_typ typ, _)) -> + let typq_to_print = typq_to_print params_to_print id typq in + let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in + doc_op equals + (separate space + [string "type"; doc_id_lem_type id; doc_typquant_items_lem (kid_nexps_of_typquant typq_to_print)] + ) + (doc_typschm_lem params_to_print env false typschm) | TD_abbrev _ -> empty - | TD_record(id,typq,fs,_) -> - let fname fid = if prefix_recordtype && string_of_id id <> "regstate" - then concat [doc_id_lem id;string "_";doc_id_lem_type fid;] - else doc_id_lem_type fid in - let f_pp (typ,fid) = - concat [fname fid;space;colon;space;doc_typ_lem params_to_print env typ; semi] in - let rectyp = match typq with - | TypQ_aux (TypQ_tq qs, _) -> - let quant_item = function - | QI_aux (QI_id (KOpt_aux (KOpt_kind (_, kid), _)), l) -> - [A_aux (A_nexp (Nexp_aux (Nexp_var kid, l)), l)] - | _ -> [] in - let targs = List.concat (List.map quant_item qs) in - mk_typ (Typ_app (id, targs)) - | TypQ_aux (TypQ_no_forall, _) -> mk_id_typ id in - let fs_doc = group (separate_map (break 1) f_pp fs) in - (* let doc_field (ftyp, fid) = - let reftyp = - mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), - [mk_typ_arg (A_typ rectyp); - mk_typ_arg (A_typ ftyp)])) in - let rfannot = doc_tannot_lem empty_ctxt env false reftyp in - let get, set = - string "rec_val" ^^ dot ^^ fname fid, - anglebars (space ^^ string "rec_val with " ^^ - (doc_op equals (fname fid) (string "v")) ^^ space) in - let base_ftyp = match annot with - | Some (env, _, _) -> Env.base_typ_of env ftyp - | _ -> ftyp in - let (start, is_inc) = - try - let start, (_, ord, _) = vector_start_index base_ftyp, vector_typ_args_of base_ftyp in - match nexp_simp start with - | Nexp_aux (Nexp_constant i, _) -> (i, is_order_inc ord) - | _ -> - raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ - ("register " ^ string_of_id id ^ " has non-constant start index " ^ string_of_nexp start)) - with - | _ -> (Big_int.zero, true) in + | TD_record (id, typq, fs, _) -> + let fname fid = + if prefix_recordtype && string_of_id id <> "regstate" then + concat [doc_id_lem id; string "_"; doc_id_lem_type fid] + else doc_id_lem_type fid + in + let f_pp (typ, fid) = concat [fname fid; space; colon; space; doc_typ_lem params_to_print env typ; semi] in + let rectyp = + match typq with + | TypQ_aux (TypQ_tq qs, _) -> + let quant_item = function + | QI_aux (QI_id (KOpt_aux (KOpt_kind (_, kid), _)), l) -> [A_aux (A_nexp (Nexp_aux (Nexp_var kid, l)), l)] + | _ -> [] + in + let targs = List.concat (List.map quant_item qs) in + mk_typ (Typ_app (id, targs)) + | TypQ_aux (TypQ_no_forall, _) -> mk_id_typ id + in + let fs_doc = group (separate_map (break 1) f_pp fs) in + (* let doc_field (ftyp, fid) = + let reftyp = + mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), + [mk_typ_arg (A_typ rectyp); + mk_typ_arg (A_typ ftyp)])) in + let rfannot = doc_tannot_lem empty_ctxt env false reftyp in + let get, set = + string "rec_val" ^^ dot ^^ fname fid, + anglebars (space ^^ string "rec_val with " ^^ + (doc_op equals (fname fid) (string "v")) ^^ space) in + let base_ftyp = match annot with + | Some (env, _, _) -> Env.base_typ_of env ftyp + | _ -> ftyp in + let (start, is_inc) = + try + let start, (_, ord, _) = vector_start_index base_ftyp, vector_typ_args_of base_ftyp in + match nexp_simp start with + | Nexp_aux (Nexp_constant i, _) -> (i, is_order_inc ord) + | _ -> + raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ + ("register " ^ string_of_id id ^ " has non-constant start index " ^ string_of_nexp start)) + with + | _ -> (Big_int.zero, true) in + doc_op equals + (concat [string "let "; parens (concat [doc_id_lem id; underscore; doc_id_lem fid; rfannot])]) + (anglebars (concat [space; + doc_op equals (string "field_name") (string_lit (doc_id_lem fid)); semi_sp; + doc_op equals (string "field_start") (string (Big_int.to_string start)); semi_sp; + doc_op equals (string "field_is_inc") (string (if is_inc then "true" else "false")); semi_sp; + doc_op equals (string "get_field") (parens (doc_op arrow (string "fun rec_val") get)); semi_sp; + doc_op equals (string "set_field") (parens (doc_op arrow (string "fun rec_val v") set)); space])) in *) + let typq_to_print = typq_to_print params_to_print id typq in + let sorts_pp = doc_typquant_sorts (doc_id_lem_type id) typq_to_print in doc_op equals - (concat [string "let "; parens (concat [doc_id_lem id; underscore; doc_id_lem fid; rfannot])]) - (anglebars (concat [space; - doc_op equals (string "field_name") (string_lit (doc_id_lem fid)); semi_sp; - doc_op equals (string "field_start") (string (Big_int.to_string start)); semi_sp; - doc_op equals (string "field_is_inc") (string (if is_inc then "true" else "false")); semi_sp; - doc_op equals (string "get_field") (parens (doc_op arrow (string "fun rec_val") get)); semi_sp; - doc_op equals (string "set_field") (parens (doc_op arrow (string "fun rec_val v") set)); space])) in *) - let typq_to_print = typq_to_print params_to_print id typq in - let sorts_pp = doc_typquant_sorts (doc_id_lem_type id) typq_to_print in - doc_op equals - (separate space [string "type"; doc_id_lem_type id; doc_typquant_items_lem (kid_nexps_of_typquant typq_to_print)]) - ((*doc_typquant_lem typq*) (anglebars (space ^^ align fs_doc ^^ space))) ^^ hardline ^^ sorts_pp - (* if !opt_sequential && string_of_id id = "regstate" then empty - else separate_map hardline doc_field fs *) - | TD_variant(id,typq,ar,_) -> - (match id with - | Id_aux ((Id "read_kind"),_) -> empty - | Id_aux ((Id "write_kind"),_) -> empty - | Id_aux ((Id "a64_barrier_domain"),_) -> empty - | Id_aux ((Id "a64_barrier_type"),_) -> empty - | Id_aux ((Id "barrier_kind"),_) -> empty - | Id_aux ((Id "trans_kind"),_) -> empty - | Id_aux ((Id "instruction_kind"),_) -> empty + (separate space + [string "type"; doc_id_lem_type id; doc_typquant_items_lem (kid_nexps_of_typquant typq_to_print)] + ) + ((*doc_typquant_lem typq*) anglebars (space ^^ align fs_doc ^^ space)) + ^^ hardline ^^ sorts_pp + (* if !opt_sequential && string_of_id id = "regstate" then empty + else separate_map hardline doc_field fs *) + | TD_variant (id, typq, ar, _) -> ( + match id with + | Id_aux (Id "read_kind", _) -> empty + | Id_aux (Id "write_kind", _) -> empty + | Id_aux (Id "a64_barrier_domain", _) -> empty + | Id_aux (Id "a64_barrier_type", _) -> empty + | Id_aux (Id "barrier_kind", _) -> empty + | Id_aux (Id "trans_kind", _) -> empty + | Id_aux (Id "instruction_kind", _) -> empty (* | Id_aux ((Id "regfp"),_) -> empty - | Id_aux ((Id "niafp"),_) -> empty - | Id_aux ((Id "diafp"),_) -> empty *) - | Id_aux ((Id "option"),_) -> empty + | Id_aux ((Id "niafp"),_) -> empty + | Id_aux ((Id "diafp"),_) -> empty *) + | Id_aux (Id "option", _) -> empty | _ -> - let env = Env.add_typquant l typq env in - let ar_doc = group (separate_map (break 1) (doc_type_union_lem params_to_print env) ar) in - let typq_to_print = typq_to_print params_to_print id typq in - let typ_pp = - (doc_op equals) - (concat [string "type"; space; doc_id_lem_type id; space; doc_typquant_items_lem (kid_nexps_of_typquant typq_to_print)]) - ((*doc_typquant_lem typq*) ar_doc) in - let make_id pat id = - separate space [string "SIA.Id_aux"; - parens (string "SIA.Id " ^^ string_lit (doc_sia_id id)); - if pat then underscore else string "SIA.Unknown"] in - let fromInterpValueF = concat [doc_id_lem_type id;string "FromInterpValue"] in - let toInterpValueF = concat [doc_id_lem_type id;string "ToInterpValue"] in - let fromInterpValuePP = - (prefix 2 1) - (separate space [string "let rec";fromInterpValueF;string "v";equals;string "match v with"]) - ( - ((separate_map (break 1)) - (fun (Tu_aux (Tu_ty_id (ty,cid),_)) -> - (separate space) - [pipe;string "SI.V_ctor";parens (make_id true cid);underscore;underscore;string "v"; + let env = Env.add_typquant l typq env in + let ar_doc = group (separate_map (break 1) (doc_type_union_lem params_to_print env) ar) in + let typq_to_print = typq_to_print params_to_print id typq in + let typ_pp = + (doc_op equals) + (concat + [ + string "type"; + space; + doc_id_lem_type id; + space; + doc_typquant_items_lem (kid_nexps_of_typquant typq_to_print); + ] + ) + (*doc_typquant_lem typq*) ar_doc + in + let make_id pat id = + separate space + [ + string "SIA.Id_aux"; + parens (string "SIA.Id " ^^ string_lit (doc_sia_id id)); + (if pat then underscore else string "SIA.Unknown"); + ] + in + let fromInterpValueF = concat [doc_id_lem_type id; string "FromInterpValue"] in + let toInterpValueF = concat [doc_id_lem_type id; string "ToInterpValue"] in + let fromInterpValuePP = + (prefix 2 1) + (separate space [string "let rec"; fromInterpValueF; string "v"; equals; string "match v with"]) + ((separate_map (break 1)) + (fun (Tu_aux (Tu_ty_id (ty, cid), _)) -> + (separate space) + [ + pipe; + string "SI.V_ctor"; + parens (make_id true cid); + underscore; + underscore; + string "v"; arrow; doc_id_lem_ctor cid; - parens (string "fromInterpValue v")]) - ar) ^/^ - - ((separate space) [pipe;string "SI.V_tuple [v]";arrow;fromInterpValueF;string "v"]) ^/^ - - let failmessage = - (string_lit - (concat [string "fromInterpValue";space;doc_id_lem_type id;colon;space;string "unexpected value. ";])) - ^^ - (string " ^ Interp.debug_print_value v") in - ((separate space) [pipe;string "v";arrow;string "failwith";parens failmessage]) ^/^ - string "end") in - let toInterpValuePP = - (prefix 2 1) - (separate space [string "let";toInterpValueF;equals;string "function"]) - ( - ((separate_map (break 1)) - (fun (Tu_aux (Tu_ty_id (ty,cid),_)) -> - (separate space) - [pipe;doc_id_lem_ctor cid;string "v";arrow; + parens (string "fromInterpValue v"); + ] + ) + ar + ^/^ (separate space) [pipe; string "SI.V_tuple [v]"; arrow; fromInterpValueF; string "v"] + ^/^ + let failmessage = + string_lit + (concat + [string "fromInterpValue"; space; doc_id_lem_type id; colon; space; string "unexpected value. "] + ) + ^^ string " ^ Interp.debug_print_value v" + in + (separate space) [pipe; string "v"; arrow; string "failwith"; parens failmessage] ^/^ string "end" + ) + in + let toInterpValuePP = + (prefix 2 1) + (separate space [string "let"; toInterpValueF; equals; string "function"]) + ((separate_map (break 1)) + (fun (Tu_aux (Tu_ty_id (ty, cid), _)) -> + (separate space) + [ + pipe; + doc_id_lem_ctor cid; + string "v"; + arrow; string "SI.V_ctor"; parens (make_id false cid); parens (string "SIA.T_id " ^^ string_lit (doc_sia_id id)); string "SI.C_Union"; - parens (string "toInterpValue v")]) - ar) ^/^ - string "end") in - let fromToInterpValuePP = - ((prefix 2 1) - (concat [string "instance ";parens (string "ToFromInterpValue " ^^ doc_id_lem_type id)]) - (concat [string "let toInterpValue = ";toInterpValueF;hardline; - string "let fromInterpValue = ";fromInterpValueF])) - ^/^ string "end" in - typ_pp ^^ hardline ^^ hardline ^^ - if !print_to_from_interp_value then - toInterpValuePP ^^ hardline ^^ hardline ^^ - fromInterpValuePP ^^ hardline ^^ hardline ^^ - fromToInterpValuePP ^^ hardline - else empty) - | TD_enum(id,enums,_) -> - (match id with - | Id_aux ((Id "read_kind"),_) -> empty - | Id_aux ((Id "write_kind"),_) -> empty - | Id_aux ((Id "a64_barrier_domain"),_) -> empty - | Id_aux ((Id "a64_barrier_type"),_) -> empty - | Id_aux ((Id "barrier_kind"),_) -> empty - | Id_aux ((Id "trans_kind"),_) -> empty - | Id_aux ((Id "instruction_kind"),_) -> empty - | Id_aux ((Id "cache_op_kind"),_) -> empty - | Id_aux ((Id "regfp"),_) -> empty - | Id_aux ((Id "niafp"),_) -> empty - | Id_aux ((Id "diafp"),_) -> empty + parens (string "toInterpValue v"); + ] + ) + ar + ^/^ string "end" + ) + in + let fromToInterpValuePP = + (prefix 2 1) + (concat [string "instance "; parens (string "ToFromInterpValue " ^^ doc_id_lem_type id)]) + (concat + [ + string "let toInterpValue = "; + toInterpValueF; + hardline; + string "let fromInterpValue = "; + fromInterpValueF; + ] + ) + ^/^ string "end" + in + typ_pp ^^ hardline ^^ hardline + ^^ + if !print_to_from_interp_value then + toInterpValuePP ^^ hardline ^^ hardline ^^ fromInterpValuePP ^^ hardline ^^ hardline ^^ fromToInterpValuePP + ^^ hardline + else empty + ) + | TD_enum (id, enums, _) -> ( + match id with + | Id_aux (Id "read_kind", _) -> empty + | Id_aux (Id "write_kind", _) -> empty + | Id_aux (Id "a64_barrier_domain", _) -> empty + | Id_aux (Id "a64_barrier_type", _) -> empty + | Id_aux (Id "barrier_kind", _) -> empty + | Id_aux (Id "trans_kind", _) -> empty + | Id_aux (Id "instruction_kind", _) -> empty + | Id_aux (Id "cache_op_kind", _) -> empty + | Id_aux (Id "regfp", _) -> empty + | Id_aux (Id "niafp", _) -> empty + | Id_aux (Id "diafp", _) -> empty | _ -> - let rec range i j = if i > j then [] else i :: (range (i+1) j) in - let nats = range 0 in - let enums_doc = group (pipe ^^ space ^^ separate_map (break 1 ^^ pipe ^^ space) doc_id_lem_ctor enums) in - let typ_pp = (doc_op equals) - (concat [string "type"; space; doc_id_lem_type id;]) - (enums_doc) in - let fromInterpValueF = concat [doc_id_lem_type id;string "FromInterpValue"] in - let toInterpValueF = concat [doc_id_lem_type id;string "ToInterpValue"] in - let make_id pat id = - separate space [string "SIA.Id_aux"; - parens (string "SIA.Id " ^^ string_lit (doc_sia_id id)); - if pat then underscore else string "SIA.Unknown"] in - let fromInterpValuePP = - (prefix 2 1) - (separate space [string "let rec";fromInterpValueF;string "v";equals;string "match v with"]) - ( - ((separate_map (break 1)) - (fun (cid) -> - (separate space) - [pipe;string "SI.V_ctor";parens (make_id true cid);underscore;underscore;string "v"; - arrow;doc_id_lem_ctor cid] + let rec range i j = if i > j then [] else i :: range (i + 1) j in + let nats = range 0 in + let enums_doc = group (pipe ^^ space ^^ separate_map (break 1 ^^ pipe ^^ space) doc_id_lem_ctor enums) in + let typ_pp = (doc_op equals) (concat [string "type"; space; doc_id_lem_type id]) enums_doc in + let fromInterpValueF = concat [doc_id_lem_type id; string "FromInterpValue"] in + let toInterpValueF = concat [doc_id_lem_type id; string "ToInterpValue"] in + let make_id pat id = + separate space + [ + string "SIA.Id_aux"; + parens (string "SIA.Id " ^^ string_lit (doc_sia_id id)); + (if pat then underscore else string "SIA.Unknown"); + ] + in + let fromInterpValuePP = + (prefix 2 1) + (separate space [string "let rec"; fromInterpValueF; string "v"; equals; string "match v with"]) + ((separate_map (break 1)) + (fun cid -> + (separate space) + [ + pipe; + string "SI.V_ctor"; + parens (make_id true cid); + underscore; + underscore; + string "v"; + arrow; + doc_id_lem_ctor cid; + ] + ) + enums + ^/^ align + ((prefix 3 1) + (separate space [pipe; string "SI.V_lit (SIA.L_aux (SIA.L_num n) _)"; arrow]) + (separate space [string "match"; parens (string "natFromInteger n"); string "with"] + ^/^ (separate_map (break 1)) + (fun (cid, number) -> + (separate space) [pipe; string (string_of_int number); arrow; doc_id_lem_ctor cid] + ) + (List.combine enums (nats (List.length enums - 1))) + ^/^ string "end" + ) + ) + ^/^ (separate space) [pipe; string "SI.V_tuple [v]"; arrow; fromInterpValueF; string "v"] + ^/^ + let failmessage = + string_lit + (concat + [string "fromInterpValue"; space; doc_id_lem_type id; colon; space; string "unexpected value. "] ) - enums - ) ^/^ - ( - (align - ((prefix 3 1) - (separate space [pipe;string ("SI.V_lit (SIA.L_aux (SIA.L_num n) _)");arrow]) - (separate space [string "match";parens(string "natFromInteger n");string "with"] ^/^ - ( - ((separate_map (break 1)) - (fun (cid,number) -> - (separate space) - [pipe;string (string_of_int number);arrow;doc_id_lem_ctor cid] - ) - (List.combine enums (nats ((List.length enums) - 1))) - ) ^/^ string "end" - ) - ) - ) - ) - ) ^/^ - - ((separate space) [pipe;string "SI.V_tuple [v]";arrow;fromInterpValueF;string "v"]) ^/^ - - let failmessage = - (string_lit - (concat [string "fromInterpValue";space;doc_id_lem_type id;colon;space;string "unexpected value. ";])) - ^^ - (string " ^ Interp.debug_print_value v") in - ((separate space) [pipe;string "v";arrow;string "failwith";parens failmessage]) ^/^ - - string "end") in - let toInterpValuePP = - (prefix 2 1) - (separate space [string "let";toInterpValueF;equals;string "function"]) - ( - ((separate_map (break 1)) - (fun (cid,number) -> - (separate space) - [pipe;doc_id_lem_ctor cid;arrow; + ^^ string " ^ Interp.debug_print_value v" + in + (separate space) [pipe; string "v"; arrow; string "failwith"; parens failmessage] ^/^ string "end" + ) + in + let toInterpValuePP = + (prefix 2 1) + (separate space [string "let"; toInterpValueF; equals; string "function"]) + ((separate_map (break 1)) + (fun (cid, number) -> + (separate space) + [ + pipe; + doc_id_lem_ctor cid; + arrow; string "SI.V_ctor"; parens (make_id false cid); parens (string "SIA.T_id " ^^ string_lit (doc_sia_id id)); parens (string ("SI.C_Enum " ^ string_of_int number)); - parens (string "toInterpValue ()")]) - (List.combine enums (nats ((List.length enums) - 1)))) ^/^ - string "end") in - let fromToInterpValuePP = - ((prefix 2 1) - (concat [string "instance ";parens (string "ToFromInterpValue " ^^ doc_id_lem_type id)]) - (concat [string "let toInterpValue = ";toInterpValueF;hardline; - string "let fromInterpValue = ";fromInterpValueF])) - ^/^ string "end" in - typ_pp ^^ hardline ^^ hardline ^^ - if !print_to_from_interp_value - then toInterpValuePP ^^ hardline ^^ hardline ^^ - fromInterpValuePP ^^ hardline ^^ hardline ^^ - fromToInterpValuePP ^^ hardline - else empty) - | _ -> raise (Reporting.err_unreachable l __POS__ "register with non-constant indices") + parens (string "toInterpValue ()"); + ] + ) + (List.combine enums (nats (List.length enums - 1))) + ^/^ string "end" + ) + in + let fromToInterpValuePP = + (prefix 2 1) + (concat [string "instance "; parens (string "ToFromInterpValue " ^^ doc_id_lem_type id)]) + (concat + [ + string "let toInterpValue = "; + toInterpValueF; + hardline; + string "let fromInterpValue = "; + fromInterpValueF; + ] + ) + ^/^ string "end" + in + typ_pp ^^ hardline ^^ hardline + ^^ + if !print_to_from_interp_value then + toInterpValuePP ^^ hardline ^^ hardline ^^ fromInterpValuePP ^^ hardline ^^ hardline ^^ fromToInterpValuePP + ^^ hardline + else empty + ) + | _ -> raise (Reporting.err_unreachable l __POS__ "register with non-constant indices") let args_of_typs l env typs = let arg i typ = let id = mk_id ("arg" ^ string_of_int i) in - P_aux (P_id id, (l, mk_tannot env typ)), - E_aux (E_id id, (l, mk_tannot env typ)) in + (P_aux (P_id id, (l, mk_tannot env typ)), E_aux (E_id id, (l, mk_tannot env typ))) + in List.split (List.mapi arg typs) let rec untuple_args_pat (P_aux (paux, ((l, _) as annot)) as pat) arg_typs = let env = env_of_annot annot in - let identity = (fun body -> body) in - match paux, arg_typs with + let identity body = body in + match (paux, arg_typs) with | P_tuple [], _ -> - let annot = (l, mk_tannot Env.empty unit_typ) in - [P_aux (P_lit (mk_lit L_unit), annot)], identity - | P_wild, (_::_::_) -> - let wild typ = P_aux (P_wild, (l, mk_tannot env typ)) in - List.map wild arg_typs, identity + let annot = (l, mk_tannot Env.empty unit_typ) in + ([P_aux (P_lit (mk_lit L_unit), annot)], identity) + | P_wild, _ :: _ :: _ -> + let wild typ = P_aux (P_wild, (l, mk_tannot env typ)) in + (List.map wild arg_typs, identity) | P_typ (_, pat), _ -> untuple_args_pat pat arg_typs - | P_as _, (_::_::_) - | P_id _, (_::_::_) -> - let argpats, argexps = args_of_typs l env arg_typs in - let argexp = E_aux (E_tuple argexps, annot) in - let bindargs (E_aux (_, bannot) as body) = - E_aux (E_let (LB_aux (LB_val (pat, argexp), annot), body), bannot) in - argpats, bindargs + | P_as _, _ :: _ :: _ | P_id _, _ :: _ :: _ -> + let argpats, argexps = args_of_typs l env arg_typs in + let argexp = E_aux (E_tuple argexps, annot) in + let bindargs (E_aux (_, bannot) as body) = E_aux (E_let (LB_aux (LB_val (pat, argexp), annot), body), bannot) in + (argpats, bindargs) (* The type checker currently has a special case for a single arg type; if that is removed, then remove the next case. *) - | P_tuple pats, [_] -> [pat], identity - | P_tuple pats, _ -> pats, identity - | _, _ -> - [pat], identity + | P_tuple pats, [_] -> ([pat], identity) + | P_tuple pats, _ -> (pats, identity) + | _, _ -> ([pat], identity) -let doc_tannot_opt_lem params_to_print env (Typ_annot_opt_aux(t,_)) = match t with - | Typ_annot_opt_some(tq,typ) -> (*doc_typquant_lem tq*) (doc_typ_lem params_to_print env typ) +let doc_tannot_opt_lem params_to_print env (Typ_annot_opt_aux (t, _)) = + match t with + | Typ_annot_opt_some (tq, typ) -> (*doc_typquant_lem tq*) doc_typ_lem params_to_print env typ | Typ_annot_opt_none -> empty let doc_fun_body_lem ctxt exp = let doc_exp = doc_exp_lem ctxt false exp in - if ctxt.early_ret && ctxt.monadic then - align (string "catch_early_return" ^//^ parens (doc_exp)) - else if ctxt.early_ret then - align (string "pure_early_return" ^//^ parens (doc_exp)) - else - doc_exp + if ctxt.early_ret && ctxt.monadic then align (string "catch_early_return" ^//^ parens doc_exp) + else if ctxt.early_ret then align (string "pure_early_return" ^//^ parens doc_exp) + else doc_exp -let doc_funcl_lem monadic params_to_print type_env (FCL_aux(FCL_funcl(id, pexp), ((def_annot, _) as annot))) = +let doc_funcl_lem monadic params_to_print type_env (FCL_aux (FCL_funcl (id, pexp), ((def_annot, _) as annot))) = let l = def_annot.loc in - let (tq, typ) = - try Env.get_val_spec_orig id type_env with - | _ -> raise (unreachable l __POS__ ("Could not get val-spec of " ^ string_of_id id)) + let tq, typ = + try Env.get_val_spec_orig id type_env + with _ -> raise (unreachable l __POS__ ("Could not get val-spec of " ^ string_of_id id)) in - let arg_typs = match typ with + let arg_typs = + match typ with | Typ_aux (Typ_fn (arg_typs, typ_ret), _) -> arg_typs | Typ_aux (_, l) -> raise (unreachable l __POS__ "Non-function type for funcl") in - let pat,guard,exp,(l,_) = destruct_pexp pexp in + let pat, guard, exp, (l, _) = destruct_pexp pexp in let ctxt = - { early_ret = contains_early_return exp; - monadic = monadic; + { + early_ret = contains_early_return exp; + monadic; bound_nexps = NexpSet.union (lem_nexps_of_typ params_to_print typ) (typeclass_nexps params_to_print typ); top_env = type_env; - params_to_print - } in + params_to_print; + } + in let pats, bind = untuple_args_pat pat arg_typs in let patspp = separate_map space (doc_pat_lem ctxt true) pats in let wrap_monadic = - if monadic && not (effectful (effect_of exp)) then - (fun doc -> string "return" ^^ space ^^ parens doc) - else (fun doc -> doc) in - let _ = match guard with + if monadic && not (effectful (effect_of exp)) then fun doc -> string "return" ^^ space ^^ parens doc + else fun doc -> doc + in + let _ = + match guard with | None -> () | _ -> - raise (Reporting.err_unreachable l __POS__ - "guarded pattern expression should have been rewritten before pretty-printing") in - group (prefix 3 1 - (separate space [doc_id_lem id; patspp; equals]) - (wrap_monadic (doc_fun_body_lem ctxt (bind exp)))) + raise + (Reporting.err_unreachable l __POS__ + "guarded pattern expression should have been rewritten before pretty-printing" + ) + in + group (prefix 3 1 (separate space [doc_id_lem id; patspp; equals]) (wrap_monadic (doc_fun_body_lem ctxt (bind exp)))) -let get_id = function - | [] -> failwith "FD_function with empty list" - | (FCL_aux (FCL_funcl (id,_),_))::_ -> id +let get_id = function [] -> failwith "FD_function with empty list" | FCL_aux (FCL_funcl (id, _), _) :: _ -> id -module StringSet = Set.Make(String) +module StringSet = Set.Make (String) (* Strictly speaking, Lem doesn't support multiple clauses for a single function joined by "and", although it has worked for Isabelle before. However, all - the funcls should have been merged by the merge_funcls rewrite now. *) + the funcls should have been merged by the merge_funcls rewrite now. *) let doc_fundef_rhs_lem monadic params_to_print env (FD_aux (FD_function (r, typa, funcls), fannot) as fd) = separate_map (hardline ^^ string "and ") (doc_funcl_lem monadic params_to_print env) funcls let doc_mutrec_lem effect_info params_to_print env = function | [] -> Reporting.unreachable Parse_ast.Unknown __POS__ "Empty internal_mutrec" - | (fundef :: _ as fundefs) -> - let id = id_of_fundef fundef in - let required_monadic = not (Effects.function_is_pure id effect_info) in - string "let rec " ^^ - separate_map (hardline ^^ string "and ") (doc_fundef_rhs_lem required_monadic params_to_print env) fundefs + | fundef :: _ as fundefs -> + let id = id_of_fundef fundef in + let required_monadic = not (Effects.function_is_pure id effect_info) in + string "let rec " + ^^ separate_map (hardline ^^ string "and ") (doc_fundef_rhs_lem required_monadic params_to_print env) fundefs let doc_fundef_lem effect_info params_to_print env (FD_aux (FD_function (r, typa, fcls), fannot) as fd) = match fcls with | [] -> Reporting.unreachable (fst fannot) __POS__ "FD_function with empty function list" - | FCL_aux (FCL_funcl (id, pexp), annot) :: _ - when not (Env.is_extern id env "lem") -> - (* A function is required to be monadic if Sail thinks it is impure *) - let required_monadic = not (Effects.function_is_pure id effect_info) in - (* Output "rec" modifier if function calls itself. Mutually recursive - functions are handled separately by doc_mutrec_lem. *) - let is_funcl_rec = - fold_pexp - { (pure_exp_alg false (||)) with - e_app = (fun (id', args) -> List.fold_left (||) (Id.compare id id' = 0) args); - e_app_infix = (fun (l, id', r) -> l || (Id.compare id id' = 0) || r) } - pexp - in - let doc_rec = if is_funcl_rec then [string "rec"] else [] in - separate space ([string "let"] @ doc_rec @ [doc_fundef_rhs_lem required_monadic params_to_print env fd]) + | FCL_aux (FCL_funcl (id, pexp), annot) :: _ when not (Env.is_extern id env "lem") -> + (* A function is required to be monadic if Sail thinks it is impure *) + let required_monadic = not (Effects.function_is_pure id effect_info) in + (* Output "rec" modifier if function calls itself. Mutually recursive + functions are handled separately by doc_mutrec_lem. *) + let is_funcl_rec = + fold_pexp + { + (pure_exp_alg false ( || )) with + e_app = (fun (id', args) -> List.fold_left ( || ) (Id.compare id id' = 0) args); + e_app_infix = (fun (l, id', r) -> l || Id.compare id id' = 0 || r); + } + pexp + in + let doc_rec = if is_funcl_rec then [string "rec"] else [] in + separate space ([string "let"] @ doc_rec @ [doc_fundef_rhs_lem required_monadic params_to_print env fd]) | _ -> empty -let doc_dec_lem (DEC_aux (reg, ((l, _) as annot))) = - match reg with - | DEC_reg (typ, id, _) -> empty - (* if !opt_sequential then empty - else - let env = env_of_annot annot in - let rt = Env.base_typ_of env typ in - if is_vector_typ rt then - let start, (size, order, etyp) = vector_start_index rt, vector_typ_args_of rt in - if is_bit_typ etyp && is_nexp_constant start && is_nexp_constant size then - let o = if is_order_inc order then "true" else "false" in - (doc_op equals) - (string "let" ^^ space ^^ doc_id_lem id) - (string "Register" ^^ space ^^ - align (separate space [string_lit(doc_id_lem id); - doc_nexp (size); - doc_nexp (start); - string o; - string "[]"])) - ^/^ hardline - else raise (Reporting.err_unreachable l __POS__ ("can't deal with register type " ^ string_of_typ typ)) - else raise (Reporting.err_unreachable l __POS__ ("can't deal with register type " ^ string_of_typ typ)) *) - (*| DEC_reg (typ, id, Some exp) -> - separate space [string "let"; doc_id_lem id; equals; doc_exp_lem empty_ctxt false exp] ^^ hardline*) +let doc_dec_lem (DEC_aux (reg, ((l, _) as annot))) = match reg with DEC_reg (typ, id, _) -> empty +(* if !opt_sequential then empty + else + let env = env_of_annot annot in + let rt = Env.base_typ_of env typ in + if is_vector_typ rt then + let start, (size, order, etyp) = vector_start_index rt, vector_typ_args_of rt in + if is_bit_typ etyp && is_nexp_constant start && is_nexp_constant size then + let o = if is_order_inc order then "true" else "false" in + (doc_op equals) + (string "let" ^^ space ^^ doc_id_lem id) + (string "Register" ^^ space ^^ + align (separate space [string_lit(doc_id_lem id); + doc_nexp (size); + doc_nexp (start); + string o; + string "[]"])) + ^/^ hardline + else raise (Reporting.err_unreachable l __POS__ ("can't deal with register type " ^ string_of_typ typ)) + else raise (Reporting.err_unreachable l __POS__ ("can't deal with register type " ^ string_of_typ typ)) *) +(*| DEC_reg (typ, id, Some exp) -> + separate space [string "let"; doc_id_lem id; equals; doc_exp_lem empty_ctxt false exp] ^^ hardline*) let doc_spec_lem effect_info params_to_print env (VS_aux (valspec, annot)) = match valspec with | VS_val_spec (typschm, id, exts, _) when Ast_util.extern_assoc "lem" exts = None -> - let monad = if Effects.function_is_pure id effect_info then empty else string "M" ^^ space in - (* let (TypSchm_aux (TypSchm_ts (tq, typ), _)) = typschm in - if contains_t_pp_var typ then empty else *) - separate space [string "val"; doc_id_lem id; string ":"; doc_typschm_lem ~monad:monad params_to_print env true typschm] ^/^ hardline + let monad = if Effects.function_is_pure id effect_info then empty else string "M" ^^ space in + (* let (TypSchm_aux (TypSchm_ts (tq, typ), _)) = typschm in + if contains_t_pp_var typ then empty else *) + separate space [string "val"; doc_id_lem id; string ":"; doc_typschm_lem ~monad params_to_print env true typschm] + ^/^ hardline (* | VS_val_spec (_,_,Some _,_) -> empty *) | _ -> empty let is_field_accessor regtypes fdef = let is_field_of regtyp field = - List.exists (fun (tname, (_, _, fields)) -> tname = regtyp && - List.exists (fun (_, fid) -> string_of_id fid = field) fields) regtypes in + List.exists + (fun (tname, (_, _, fields)) -> tname = regtyp && List.exists (fun (_, fid) -> string_of_id fid = field) fields) + regtypes + in match Util.split_on_char '_' (string_of_id (id_of_fundef fdef)) with - | [access; regtyp; field] -> - (access = "get" || access = "set") && is_field_of regtyp field + | [access; regtyp; field] -> (access = "get" || access = "set") && is_field_of regtyp field | _ -> false let int_of_field_index tname fid nexp = match int_of_nexp_opt nexp with | Some i -> i - | None -> raise (Reporting.err_typ Parse_ast.Unknown - ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname)) + | None -> + raise + (Reporting.err_typ Parse_ast.Unknown + ("Non-constant bitfield index in field " ^ string_of_id fid ^ " of " ^ tname) + ) let doc_regtype_fields (tname, (n1, n2, fields)) = let const_int fid idx = int_of_field_index tname fid idx in - let i1, i2 = match n1, n2 with - | Nexp_aux(Nexp_constant i1,_),Nexp_aux(Nexp_constant i2,_) -> i1, i2 - | _ -> raise (Reporting.err_typ Parse_ast.Unknown - ("Non-constant indices in register type " ^ tname)) in + let i1, i2 = + match (n1, n2) with + | Nexp_aux (Nexp_constant i1, _), Nexp_aux (Nexp_constant i2, _) -> (i1, i2) + | _ -> raise (Reporting.err_typ Parse_ast.Unknown ("Non-constant indices in register type " ^ tname)) + in let dir_b = i1 < i2 in - let dir = (if dir_b then "true" else "false") in + let dir = if dir_b then "true" else "false" in let doc_field (fr, fid) = - let i, j = match fr with - | BF_aux (BF_single i, _) -> let i = const_int fid i in (i, i) - | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j) - | _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ - ("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname)) in + let i, j = + match fr with + | BF_aux (BF_single i, _) -> + let i = const_int fid i in + (i, i) + | BF_aux (BF_range (i, j), _) -> (const_int fid i, const_int fid j) + | _ -> + raise + (Reporting.err_unreachable Parse_ast.Unknown __POS__ + ("Unsupported type in field " ^ string_of_id fid ^ " of " ^ tname) + ) + in let fsize = Big_int.succ (Big_int.abs (Big_int.sub i j)) in (* TODO Assumes normalised, decreasing bitvector slices; however, since start indices or indexing order do not appear in Lem type annotations, this does not matter. *) let ftyp = bitvector_typ (nconstant fsize) dec_ord in let reftyp = - mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), - [mk_typ_arg (A_typ (mk_id_typ (mk_id tname))); - mk_typ_arg (A_typ ftyp)])) in + mk_typ + (Typ_app + ( Id_aux (Id "field_ref", Parse_ast.Unknown), + [mk_typ_arg (A_typ (mk_id_typ (mk_id tname))); mk_typ_arg (A_typ ftyp)] + ) + ) + in let rfannot = doc_tannot_lem empty_ctxt Env.empty false reftyp in doc_op equals - (concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])]) - (concat [ - space; langlebar; string " field_name = \"" ^^ doc_id_lem fid ^^ string "\";"; hardline; - space; space; space; string (" field_start = " ^ Big_int.to_string i ^ ";"); hardline; - space; space; space; string (" field_is_inc = " ^ dir ^ ";"); hardline; - space; space; space; string (" get_field = get_" ^ tname ^ "_" ^ string_of_id fid ^ ";"); hardline; - space; space; space; string (" set_field = set_" ^ tname ^ "_" ^ string_of_id fid ^ " "); ranglebar]) + (concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])]) + (concat + [ + space; + langlebar; + string " field_name = \"" ^^ doc_id_lem fid ^^ string "\";"; + hardline; + space; + space; + space; + string (" field_start = " ^ Big_int.to_string i ^ ";"); + hardline; + space; + space; + space; + string (" field_is_inc = " ^ dir ^ ";"); + hardline; + space; + space; + space; + string (" get_field = get_" ^ tname ^ "_" ^ string_of_id fid ^ ";"); + hardline; + space; + space; + space; + string (" set_field = set_" ^ tname ^ "_" ^ string_of_id fid ^ " "); + ranglebar; + ] + ) in separate_map hardline doc_field fields @@ -1715,10 +1799,11 @@ let doc_def_lem effect_info params_to_print type_env (DEF_aux (aux, _) as def) = | DEF_fundef fdef -> group (doc_fundef_lem effect_info params_to_print type_env fdef) ^/^ hardline | DEF_internal_mutrec fundefs -> doc_mutrec_lem effect_info params_to_print type_env fundefs ^/^ hardline | DEF_let (LB_aux (LB_val (pat, _), _) as lbind) -> - group (doc_let_lem { empty_ctxt with params_to_print } lbind) ^/^ hardline + group (doc_let_lem { empty_ctxt with params_to_print } lbind) ^/^ hardline | DEF_scattered sdef -> unreachable (def_loc def) __POS__ "doc_def_lem: shoulnd't have DEF_scattered at this point" | DEF_mapdef (MD_aux (_, (l, _))) -> unreachable l __POS__ "Lem doesn't support mappings" - | (DEF_outcome _ | DEF_impl _ | DEF_instantiation _) -> unreachable (def_loc def) __POS__ "Event definition found when generating lem" + | DEF_outcome _ | DEF_impl _ | DEF_instantiation _ -> + unreachable (def_loc def) __POS__ "Event definition found when generating lem" | DEF_pragma _ -> empty | DEF_measure _ -> empty (* we might use these in future *) | DEF_loop_measures _ -> empty @@ -1726,64 +1811,81 @@ let doc_def_lem effect_info params_to_print type_env (DEF_aux (aux, _) as def) = let find_exc_typ defs = let is_exc_typ_def = function | DEF_aux (DEF_type td, _) -> string_of_id (id_of_type_def td) = "exception" - | _ -> false in + | _ -> false + in if List.exists is_exc_typ_def defs then "exception" else "unit" -let pp_ast_lem (types_file,types_modules) (defs_file,defs_modules) effect_info type_env { defs; _ } top_line = +let pp_ast_lem (types_file, types_modules) (defs_file, defs_modules) effect_info type_env { defs; _ } top_line = (* let regtypes = find_regtypes d in *) - let state_ids = - State.generate_regstate_defs !Monomorphise.opt_mwords defs - |> val_spec_ids - in + let state_ids = State.generate_regstate_defs !Monomorphise.opt_mwords defs |> val_spec_ids in let is_state_def = function | DEF_aux (DEF_val vs, _) -> IdSet.mem (id_of_val_spec vs) state_ids | DEF_aux (DEF_fundef fd, _) -> IdSet.mem (id_of_fundef fd) state_ids | _ -> false in - let is_typ_def = function - | DEF_aux (DEF_type _, _) -> true - | _ -> false - in + let is_typ_def = function DEF_aux (DEF_type _, _) -> true | _ -> false in let exc_typ = find_exc_typ defs in let params_to_print = type_parameters_to_print type_env defs in let typdefs, defs = List.partition is_typ_def defs in let statedefs, defs = List.partition is_state_def defs in - let register_ref_tannot typ = string " : register_ref regstate register_value " ^^ parens (doc_typ_lem params_to_print type_env typ) in - let register_refs = State.register_refs_lem !Monomorphise.opt_mwords register_ref_tannot (State.find_registers defs) in + let register_ref_tannot typ = + string " : register_ref regstate register_value " ^^ parens (doc_typ_lem params_to_print type_env typ) + in + let register_refs = + State.register_refs_lem !Monomorphise.opt_mwords register_ref_tannot (State.find_registers defs) + in (print types_file) (concat - [string "(*" ^^ (string top_line) ^^ string "*)";hardline; - (separate_map hardline) - (fun lib -> separate space [string "open import";string lib]) types_modules;hardline; - if !print_to_from_interp_value - then - concat - [(separate_map hardline) - (fun lib -> separate space [string " import";string lib]) ["Interp";"Interp_ast"]; - string "open import Deep_shallow_convert"; + [ + string "(*" ^^ string top_line ^^ string "*)"; + hardline; + (separate_map hardline) (fun lib -> separate space [string "open import"; string lib]) types_modules; + hardline; + ( if !print_to_from_interp_value then + concat + [ + (separate_map hardline) + (fun lib -> separate space [string " import"; string lib]) + ["Interp"; "Interp_ast"]; + string "open import Deep_shallow_convert"; + hardline; + hardline; + string "module SI = Interp"; + hardline; + string "module SIA = Interp_ast"; + hardline; + hardline; + ] + else empty + ); + separate empty (List.map (doc_def_lem effect_info params_to_print type_env) typdefs); + hardline; + hardline; + separate empty (List.map (doc_def_lem effect_info params_to_print type_env) statedefs); + hardline; + hardline; + State.regval_instance_lem; + hardline; + register_refs; + hardline; + concat + [ + string ("type MR 'a 'r = base_monadR register_value regstate 'a 'r " ^ exc_typ); hardline; + string ("type M 'a = base_monad register_value regstate 'a " ^ exc_typ); hardline; - string "module SI = Interp"; hardline; - string "module SIA = Interp_ast"; hardline; - hardline] - else empty; - separate empty (List.map (doc_def_lem effect_info params_to_print type_env) typdefs); hardline; - hardline; - separate empty (List.map (doc_def_lem effect_info params_to_print type_env) statedefs); hardline; - hardline; - State.regval_instance_lem; - hardline; - register_refs; hardline; - concat [ - string ("type MR 'a 'r = base_monadR register_value regstate 'a 'r " ^ exc_typ); hardline; - string ("type M 'a = base_monad register_value regstate 'a " ^ exc_typ); hardline - ] - ]); + ]; + ] + ); (print defs_file) (concat - [string "(*" ^^ (string top_line) ^^ string "*)";hardline; - (separate_map hardline) - (fun lib -> separate space [string "open import";string lib]) defs_modules;hardline; - hardline; - separate empty (List.map (doc_def_lem effect_info params_to_print type_env) defs); - hardline]); + [ + string "(*" ^^ string top_line ^^ string "*)"; + hardline; + (separate_map hardline) (fun lib -> separate space [string "open import"; string lib]) defs_modules; + hardline; + hardline; + separate empty (List.map (doc_def_lem effect_info params_to_print type_env) defs); + hardline; + ] + ) diff --git a/src/sail_lem_backend/sail_plugin_lem.ml b/src/sail_lem_backend/sail_plugin_lem.ml index 190d4e809..4fcee0194 100644 --- a/src/sail_lem_backend/sail_plugin_lem.ml +++ b/src/sail_lem_backend/sail_plugin_lem.ml @@ -69,28 +69,28 @@ open Libsail open Ast_defs open PPrint - + let opt_libs_lem : string list ref = ref [] let opt_lem_output_dir : string option ref = ref None let opt_isa_output_dir : string option ref = ref None - -let lem_options = [ - ( "-lem_output_dir", - Arg.String (fun dir -> opt_lem_output_dir := Some dir), - " set a custom directory to output generated Lem"); - ( "-isa_output_dir", - Arg.String (fun dir -> opt_isa_output_dir := Some dir), - " set a custom directory to output generated Isabelle auxiliary theories"); - ( "-lem_lib", - Arg.String (fun l -> opt_libs_lem := l::!opt_libs_lem), - " provide additional library to open in Lem output"); - ( "-lem_sequential", - Arg.Set Pretty_print_lem.opt_sequential, - " use sequential state monad for Lem output"); - ( "-lem_mwords", - Arg.Set Monomorphise.opt_mwords, - " use native machine word library for Lem output"); -] + +let lem_options = + [ + ( "-lem_output_dir", + Arg.String (fun dir -> opt_lem_output_dir := Some dir), + " set a custom directory to output generated Lem" + ); + ( "-isa_output_dir", + Arg.String (fun dir -> opt_isa_output_dir := Some dir), + " set a custom directory to output generated Isabelle auxiliary theories" + ); + ( "-lem_lib", + Arg.String (fun l -> opt_libs_lem := l :: !opt_libs_lem), + " provide additional library to open in Lem output" + ); + ("-lem_sequential", Arg.Set Pretty_print_lem.opt_sequential, " use sequential state monad for Lem output"); + ("-lem_mwords", Arg.Set Monomorphise.opt_mwords, " use native machine word library for Lem output"); + ] let lem_rewrites = let open Rewrites in @@ -147,60 +147,50 @@ let lem_rewrites = ("merge_function_clauses", []); ("bit_lists_to_lits", []); ("recheck_defs", []); - ("attach_effects", []) + ("attach_effects", []); ] -let generated_line f = - Printf.sprintf "Generated by Sail from %s." f - +let generated_line f = Printf.sprintf "Generated by Sail from %s." f + let output_lem filename libs effect_info type_env ast = let generated_line = generated_line filename in (* let seq_suffix = if !Pretty_print_lem.opt_sequential then "_sequential" else "" in *) - let types_module = (filename ^ "_types") in + let types_module = filename ^ "_types" in let monad_modules = ["Sail2_prompt_monad"; "Sail2_prompt"] in let undefined_modules = if !Initial_check.opt_undefined_gen then ["Sail2_undefined"] else [] in - let operators_module = - if !Monomorphise.opt_mwords - then "Sail2_operators_mwords" - else "Sail2_operators_bitlists" in + let operators_module = if !Monomorphise.opt_mwords then "Sail2_operators_mwords" else "Sail2_operators_bitlists" in (* let libs = List.map (fun lib -> lib ^ seq_suffix) libs in *) - let base_imports = [ - "Pervasives_extra"; - "Sail2_instr_kinds"; - "Sail2_values"; - "Sail2_string"; - operators_module - ] @ monad_modules - @ undefined_modules + let base_imports = + ["Pervasives_extra"; "Sail2_instr_kinds"; "Sail2_values"; "Sail2_string"; operators_module] + @ monad_modules @ undefined_modules in let isa_thy_name = String.capitalize_ascii filename ^ "_lemmas" in let isa_lemmas = - separate hardline [ - string ("theory " ^ isa_thy_name); - string " imports"; - string " Sail.Sail2_values_lemmas"; - string " Sail.Sail2_state_lemmas"; - string (" " ^ String.capitalize_ascii filename); - string "begin"; - string ""; - State.generate_isa_lemmas !Monomorphise.opt_mwords ast.defs; - string ""; - string "end" - ] ^^ hardline + separate hardline + [ + string ("theory " ^ isa_thy_name); + string " imports"; + string " Sail.Sail2_values_lemmas"; + string " Sail.Sail2_state_lemmas"; + string (" " ^ String.capitalize_ascii filename); + string "begin"; + string ""; + State.generate_isa_lemmas !Monomorphise.opt_mwords ast.defs; + string ""; + string "end"; + ] + ^^ hardline + in + let ((ot, _, _, _) as ext_ot) = + Util.open_output_with_check_unformatted !opt_lem_output_dir (filename ^ "_types" ^ ".lem") in - let ((ot,_,_,_) as ext_ot) = - Util.open_output_with_check_unformatted !opt_lem_output_dir (filename ^ "_types" ^ ".lem") in - let ((o,_,_,_) as ext_o) = - Util.open_output_with_check_unformatted !opt_lem_output_dir (filename ^ ".lem") in - (Pretty_print_lem.pp_ast_lem - (ot, base_imports) - (o, base_imports @ (String.capitalize_ascii types_module :: libs)) - effect_info - type_env ast generated_line); + let ((o, _, _, _) as ext_o) = Util.open_output_with_check_unformatted !opt_lem_output_dir (filename ^ ".lem") in + Pretty_print_lem.pp_ast_lem (ot, base_imports) + (o, base_imports @ (String.capitalize_ascii types_module :: libs)) + effect_info type_env ast generated_line; Util.close_output_with_check ext_ot; Util.close_output_with_check ext_o; - let ((ol,_,_,_) as ext_ol) = - Util.open_output_with_check_unformatted !opt_isa_output_dir (isa_thy_name ^ ".thy") in + let ((ol, _, _, _) as ext_ol) = Util.open_output_with_check_unformatted !opt_isa_output_dir (isa_thy_name ^ ".thy") in Pretty_print_common.print ol isa_lemmas; Util.close_output_with_check ext_ol @@ -208,16 +198,12 @@ let output libs files = List.iter (fun (f, effect_info, env, ast) -> let f' = Filename.basename (Filename.remove_extension f) in - output_lem f' libs effect_info env ast) + output_lem f' libs effect_info env ast + ) files let lem_target _ out_file ast effect_info env = let out_file = match out_file with Some f -> f | None -> "out" in - output (!opt_libs_lem) [(out_file, effect_info, env, ast)] + output !opt_libs_lem [(out_file, effect_info, env, ast)] -let _ = - Target.register - ~name:"lem" - ~options:lem_options - ~rewrites:lem_rewrites - lem_target +let _ = Target.register ~name:"lem" ~options:lem_options ~rewrites:lem_rewrites lem_target diff --git a/src/sail_manifest/dune b/src/sail_manifest/dune index 470ea8318..3a43baba6 100644 --- a/src/sail_manifest/dune +++ b/src/sail_manifest/dune @@ -1,6 +1,5 @@ (executable - (name sail_manifest) - (public_name sail_manifest) - (package sail_manifest) - (libraries unix)) - + (name sail_manifest) + (public_name sail_manifest) + (package sail_manifest) + (libraries unix)) diff --git a/src/sail_manifest/sail_manifest.ml b/src/sail_manifest/sail_manifest.ml index 56f7c625f..e3a962706 100644 --- a/src/sail_manifest/sail_manifest.ml +++ b/src/sail_manifest/sail_manifest.ml @@ -69,24 +69,15 @@ open Printf let opt_gen_manifest = ref false -let options = Arg.align [ - ( "-gen_manifest", - Arg.Set opt_gen_manifest, - "generate manifest.ml") -] - +let options = Arg.align [("-gen_manifest", Arg.Set opt_gen_manifest, "generate manifest.ml")] + let git_command args = try let git_out, git_in, git_err = Unix.open_process_full ("git " ^ args) (Unix.environment ()) in let res = input_line git_out in - match Unix.close_process_full (git_out, git_in, git_err) with - | Unix.WEXITED 0 -> - res - | _ -> - "unknown" - with - | _ -> "unknown" - + match Unix.close_process_full (git_out, git_in, git_err) with Unix.WEXITED 0 -> res | _ -> "unknown" + with _ -> "unknown" + let gen_manifest () = ksprintf print_endline "let dir = \"%s\"" (Sys.getcwd ()); ksprintf print_endline "let commit = \"%s\"" (git_command "rev-parse HEAD"); @@ -94,11 +85,9 @@ let gen_manifest () = ksprintf print_endline "let version = \"%s\"" (git_command "describe") let usage = "sail_install_tool " - + let main () = Arg.parse options (fun _ -> ()) usage; - if !opt_gen_manifest then ( - gen_manifest () - ) + if !opt_gen_manifest then gen_manifest () let () = main () diff --git a/src/sail_ocaml_backend/dune b/src/sail_ocaml_backend/dune index 489ecab29..251d7d0f2 100644 --- a/src/sail_ocaml_backend/dune +++ b/src/sail_ocaml_backend/dune @@ -1,17 +1,22 @@ (env - (dev - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) - (release - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) + (dev + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) + (release + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) (executable - (name sail_plugin_ocaml) - (modes (native plugin)) - (link_flags -linkall) - (libraries libsail base64) - (embed_in_plugin_libraries base64)) + (name sail_plugin_ocaml) + (modes + (native plugin)) + (link_flags -linkall) + (libraries libsail base64) + (embed_in_plugin_libraries base64)) (install - (section (site (libsail plugins))) - (package sail_ocaml_backend) - (files sail_plugin_ocaml.cmxs)) + (section + (site + (libsail plugins))) + (package sail_ocaml_backend) + (files sail_plugin_ocaml.cmxs)) diff --git a/src/sail_ocaml_backend/ocaml_backend.ml b/src/sail_ocaml_backend/ocaml_backend.ml index 034f36ac4..e689df2bc 100644 --- a/src/sail_ocaml_backend/ocaml_backend.ml +++ b/src/sail_ocaml_backend/ocaml_backend.ml @@ -77,6 +77,7 @@ module Big_int = Nat_big_num (* Option to turn tracing features on or off *) let opt_trace_ocaml = ref false + (* Option to not build generated ocaml by default *) let opt_ocaml_nobuild = ref false let opt_ocaml_coverage = ref false @@ -85,20 +86,10 @@ let opt_ocaml_build_dir = ref "_sbuild" (* OCaml variant type can have at most 246 non-constant constructors. *) let ocaml_variant_max_constructors = 246 - -type ctx = - { register_inits : tannot exp list; - externs : id Bindings.t; - val_specs : typ Bindings.t; - records : IdSet.t - } - -let empty_ctx = - { register_inits = []; - externs = Bindings.empty; - val_specs = Bindings.empty; - records = IdSet.empty - } + +type ctx = { register_inits : tannot exp list; externs : id Bindings.t; val_specs : typ Bindings.t; records : IdSet.t } + +let empty_ctx = { register_inits = []; externs = Bindings.empty; val_specs = Bindings.empty; records = IdSet.empty } let gensym_counter = ref 0 @@ -108,12 +99,11 @@ let gensym () = string gs let zencode ctx id = - try string (string_of_id (Bindings.find id ctx.externs)) with - | Not_found -> string (zencode_string (string_of_id id)) + try string (string_of_id (Bindings.find id ctx.externs)) with Not_found -> string (zencode_string (string_of_id id)) let zencode_upper ctx id = - try string (string_of_id (Bindings.find id ctx.externs)) with - | Not_found -> string (zencode_upper_string (string_of_id id)) + try string (string_of_id (Bindings.find id ctx.externs)) + with Not_found -> string (zencode_upper_string (string_of_id id)) let zencode_kid kid = string ("'" ^ zencode_string (string_of_id (id_of_kid kid))) @@ -128,20 +118,25 @@ let rec ocaml_string_typ (Typ_aux (typ_aux, l)) arg = | Typ_id id when string_of_id id = "exception" -> string "Printexc.to_string" ^^ space ^^ arg | Typ_id id -> ocaml_string_of id ^^ space ^^ arg | Typ_app (id, []) -> ocaml_string_of id ^^ space ^^ arg - | Typ_app (id, [A_aux (A_typ (Typ_aux (Typ_id eid, _)), _)]) - when string_of_id id = "list" && string_of_id eid = "bit" -> - string "string_of_bits" ^^ space ^^ arg + | Typ_app (id, [A_aux (A_typ (Typ_aux (Typ_id eid, _)), _)]) when string_of_id id = "list" && string_of_id eid = "bit" + -> + string "string_of_bits" ^^ space ^^ arg | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> - let farg = gensym () in - separate space [string "string_of_list \", \""; parens (separate space [string "fun"; farg; string "->"; ocaml_string_typ typ farg]); arg] + let farg = gensym () in + separate space + [ + string "string_of_list \", \""; + parens (separate space [string "fun"; farg; string "->"; ocaml_string_typ typ farg]); + arg; + ] | Typ_app (_, _) -> string "\"APP\"" | Typ_tuple typs -> - let args = List.map (fun _ -> gensym ()) typs in - let body = - ocaml_string_parens (separate_map ocaml_string_comma (fun (typ, arg) -> ocaml_string_typ typ arg) (List.combine typs args)) - in - parens (separate space [string "fun"; parens (separate (comma ^^ space) args); string "->"; body]) - ^^ space ^^ arg + let args = List.map (fun _ -> gensym ()) typs in + let body = + ocaml_string_parens + (separate_map ocaml_string_comma (fun (typ, arg) -> ocaml_string_typ typ arg) (List.combine typs args)) + in + parens (separate space [string "fun"; parens (separate (comma ^^ space) args); string "->"; body]) ^^ space ^^ arg | Typ_fn (typ1, typ2) -> string "\"FN\"" | Typ_bidir (t1, t2) -> string "\"BIDIR\"" | Typ_var kid -> string "\"VAR\"" @@ -174,6 +169,7 @@ let rec ocaml_typ ctx (Typ_aux (typ_aux, l)) = | Typ_var kid -> zencode_kid kid | Typ_exist _ -> assert false | Typ_internal_unknown -> raise (Reporting.err_unreachable l __POS__ "escaped Typ_internal_unknown") + and ocaml_typ_arg ctx (A_aux (typ_arg_aux, _) as typ_arg) = match typ_arg_aux with | A_typ typ -> ocaml_typ ctx typ @@ -183,7 +179,7 @@ let ocaml_typquant (TypQ_aux (_, l) as typq) = let ocaml_qi = function | QI_aux (QI_id kopt, _) -> zencode_kid (kopt_kid kopt) | QI_aux (QI_constraint _, _) -> - raise (Reporting.err_general l "Ocaml: type quantifiers should no longer contain constraints") + raise (Reporting.err_general l "Ocaml: type quantifiers should no longer contain constraints") in match quant_items typq with | [] -> empty @@ -200,12 +196,10 @@ let ocaml_lit (L_aux (lit_aux, _)) = | L_true -> string "true" | L_false -> string "false" | L_num n -> - if Big_int.equal n Big_int.zero then - string "Big_int.zero" - else if Big_int.less_equal (Big_int.of_int min_int) n && Big_int.less_equal n (Big_int.of_int max_int) then - parens (string "Big_int.of_int" ^^ space ^^ parens (string (Big_int.to_string n))) - else - parens (string "Big_int.of_string" ^^ space ^^ dquotes (string (Big_int.to_string n))) + if Big_int.equal n Big_int.zero then string "Big_int.zero" + else if Big_int.less_equal (Big_int.of_int min_int) n && Big_int.less_equal n (Big_int.of_int max_int) then + parens (string "Big_int.of_int" ^^ space ^^ parens (string (Big_int.to_string n))) + else parens (string "Big_int.of_string" ^^ space ^^ dquotes (string (Big_int.to_string n))) | L_undef -> failwith "undefined should have been re-written prior to ocaml backend" | L_string str -> string_lit str | L_real str -> parens (string "real_of_string" ^^ space ^^ dquotes (string (String.escaped str))) @@ -213,64 +207,66 @@ let ocaml_lit (L_aux (lit_aux, _)) = let rec ocaml_pat ctx (P_aux (pat_aux, _) as pat) = match pat_aux with - | P_id id -> - begin - match Env.lookup_id id (env_of_pat pat) with - | Local (_, _) | Unbound _ -> zencode ctx id - | Enum _ -> zencode_upper ctx id - | _ -> failwith ("Ocaml: Cannot pattern match on register: " ^ string_of_pat pat) - end + | P_id id -> begin + match Env.lookup_id id (env_of_pat pat) with + | Local (_, _) | Unbound _ -> zencode ctx id + | Enum _ -> zencode_upper ctx id + | _ -> failwith ("Ocaml: Cannot pattern match on register: " ^ string_of_pat pat) + end | P_lit lit -> ocaml_lit lit | P_typ (_, pat) -> ocaml_pat ctx pat | P_tuple pats -> parens (separate_map (comma ^^ space) (ocaml_pat ctx) pats) | P_list pats -> brackets (separate_map (semi ^^ space) (ocaml_pat ctx) pats) | P_wild -> string "_" | P_as (pat, id) -> separate space [ocaml_pat ctx pat; string "as"; zencode ctx id] - | P_app (id, pats) -> - begin match Env.union_constructor_info id (env_of_pat pat) with - | Some (_, m, _, _) when m > ocaml_variant_max_constructors -> - (string "`" ^^ zencode_upper ctx id) ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_pat ctx) pats) - | _ -> - zencode_upper ctx id ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_pat ctx) pats) - end + | P_app (id, pats) -> begin + match Env.union_constructor_info id (env_of_pat pat) with + | Some (_, m, _, _) when m > ocaml_variant_max_constructors -> + (string "`" ^^ zencode_upper ctx id) ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_pat ctx) pats) + | _ -> zencode_upper ctx id ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_pat ctx) pats) + end | P_cons (hd_pat, tl_pat) -> ocaml_pat ctx hd_pat ^^ string " :: " ^^ ocaml_pat ctx tl_pat | _ -> string ("PAT<" ^ string_of_pat pat ^ ">") let begin_end doc = group (string "begin" ^^ nest 2 (break 1 ^^ doc) ^/^ string "end") (* Returns true if a type is a register being passed by name *) -let is_passed_by_name = function - | (Typ_aux (Typ_app (tid, _), _)) -> string_of_id tid = "register" - | _ -> false +let is_passed_by_name = function Typ_aux (Typ_app (tid, _), _) -> string_of_id tid = "register" | _ -> false -let record_id l exp = match typ_of exp with +let record_id l exp = + match typ_of exp with | Typ_aux (Typ_id id, _) when Env.is_record id (env_of exp) -> id | Typ_aux (Typ_app (id, _), _) when Env.is_record id (env_of exp) -> id - | typ -> Reporting.unreachable l __POS__ ("Found a struct without a record type when generating OCaml. Type found: " ^ string_of_typ typ) - + | typ -> + Reporting.unreachable l __POS__ + ("Found a struct without a record type when generating OCaml. Type found: " ^ string_of_typ typ) + let rec ocaml_exp ctx (E_aux (exp_aux, (l, _)) as exp) = match exp_aux with - | E_app (f, xs) -> - begin match Env.union_constructor_info f (env_of exp) with - | Some (_, m, _, _) -> - let name = if m > ocaml_variant_max_constructors then (string "`" ^^ zencode_upper ctx f) else zencode_upper ctx f in - begin match xs with - | [x] -> name ^^ space ^^ ocaml_atomic_exp ctx x - | xs -> name ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_atomic_exp ctx) xs) - end - | None -> - begin match xs with - | [x] -> zencode ctx f ^^ space ^^ ocaml_atomic_exp ctx x - (* Make sure we get the correct short circuiting semantics for and and or *) - | [x; y] when string_of_id f = "and_bool" -> - separate space [ocaml_atomic_exp ctx x; string "&&"; ocaml_atomic_exp ctx y] - | [x; y] when string_of_id f = "or_bool" -> - separate space [ocaml_atomic_exp ctx x; string "||"; ocaml_atomic_exp ctx y] - | xs -> - zencode ctx f ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_atomic_exp ctx) xs) + | E_app (f, xs) -> begin + match Env.union_constructor_info f (env_of exp) with + | Some (_, m, _, _) -> + let name = + if m > ocaml_variant_max_constructors then string "`" ^^ zencode_upper ctx f else zencode_upper ctx f + in + begin + match xs with + | [x] -> name ^^ space ^^ ocaml_atomic_exp ctx x + | xs -> name ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_atomic_exp ctx) xs) + end + | None -> begin + match xs with + | [x] -> zencode ctx f ^^ space ^^ ocaml_atomic_exp ctx x + (* Make sure we get the correct short circuiting semantics for and and or *) + | [x; y] when string_of_id f = "and_bool" -> + separate space [ocaml_atomic_exp ctx x; string "&&"; ocaml_atomic_exp ctx y] + | [x; y] when string_of_id f = "or_bool" -> + separate space [ocaml_atomic_exp ctx x; string "||"; ocaml_atomic_exp ctx y] + | xs -> zencode ctx f ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_atomic_exp ctx) xs) end - end - | E_vector_subrange (exp1, exp2, exp3) -> string "subrange" ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_atomic_exp ctx) [exp1; exp2; exp3]) + end + | E_vector_subrange (exp1, exp2, exp3) -> + string "subrange" ^^ space ^^ parens (separate_map (comma ^^ space) (ocaml_atomic_exp ctx) [exp1; exp2; exp3]) | E_return exp -> separate space [string "r.return"; ocaml_atomic_exp ctx exp] | E_assert (exp, _) -> separate space [string "assert"; ocaml_atomic_exp ctx exp] | E_typ (_, exp) -> ocaml_exp ctx exp @@ -281,153 +277,203 @@ let rec ocaml_exp ctx (E_aux (exp_aux, (l, _)) as exp) = | E_exit exp -> string "exit 0" | E_throw exp -> string "raise" ^^ space ^^ ocaml_atomic_exp ctx exp | E_match (exp, pexps) -> - begin_end (separate space [string "match"; ocaml_atomic_exp ctx exp; string "with"] - ^/^ ocaml_pexps ctx pexps) + begin_end (separate space [string "match"; ocaml_atomic_exp ctx exp; string "with"] ^/^ ocaml_pexps ctx pexps) | E_try (exp, pexps) -> - begin_end (separate space [string "try"; ocaml_atomic_exp ctx exp; string "with"] - ^/^ ocaml_pexps ctx pexps) + begin_end (separate space [string "try"; ocaml_atomic_exp ctx exp; string "with"] ^/^ ocaml_pexps ctx pexps) | E_assign (lexp, exp) -> ocaml_assignment ctx lexp exp - | E_if (c, t, e) -> separate space [string "if"; ocaml_atomic_exp ctx c; - string "then"; ocaml_atomic_exp ctx t; - string "else"; ocaml_atomic_exp ctx e] + | E_if (c, t, e) -> + separate space + [ + string "if"; + ocaml_atomic_exp ctx c; + string "then"; + ocaml_atomic_exp ctx t; + string "else"; + ocaml_atomic_exp ctx e; + ] | E_struct fexps -> - enclose lbrace rbrace (group (separate_map (semi ^^ break 1) (ocaml_fexp (record_id l exp) ctx) fexps)) + enclose lbrace rbrace (group (separate_map (semi ^^ break 1) (ocaml_fexp (record_id l exp) ctx) fexps)) | E_struct_update (exp, fexps) -> - enclose lbrace rbrace (separate space [ocaml_atomic_exp ctx exp; - string "with"; - separate_map (semi ^^ space) (ocaml_fexp (record_id l exp) ctx) fexps]) - | E_let (lb, exp) -> - separate space [string "let"; ocaml_letbind ctx lb; string "in"] - ^/^ ocaml_exp ctx exp + enclose lbrace rbrace + (separate space + [ + ocaml_atomic_exp ctx exp; + string "with"; + separate_map (semi ^^ space) (ocaml_fexp (record_id l exp) ctx) fexps; + ] + ) + | E_let (lb, exp) -> separate space [string "let"; ocaml_letbind ctx lb; string "in"] ^/^ ocaml_exp ctx exp | E_var (lexp, exp1, exp2) -> - separate space [string "let"; ocaml_atomic_lexp ctx lexp; - equals; string "ref"; parens (ocaml_atomic_exp ctx exp1 ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (Rewrites.simple_typ (typ_of exp1))); string "in"] - ^/^ ocaml_exp ctx exp2 + separate space + [ + string "let"; + ocaml_atomic_lexp ctx lexp; + equals; + string "ref"; + parens + (ocaml_atomic_exp ctx exp1 ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (Rewrites.simple_typ (typ_of exp1))); + string "in"; + ] + ^/^ ocaml_exp ctx exp2 | E_loop (Until, _, cond, body) -> - let loop_body = - (ocaml_atomic_exp ctx body ^^ semi) - ^/^ - separate space [string "if"; ocaml_atomic_exp ctx cond; - string "then ()"; - string "else loop ()"] - in - (string "let rec loop () =" ^//^ loop_body) - ^/^ string "in" - ^/^ string "loop ()" + let loop_body = + (ocaml_atomic_exp ctx body ^^ semi) + ^/^ separate space [string "if"; ocaml_atomic_exp ctx cond; string "then ()"; string "else loop ()"] + in + (string "let rec loop () =" ^//^ loop_body) ^/^ string "in" ^/^ string "loop ()" | E_loop (While, _, cond, body) -> - let loop_body = - separate space [string "if"; ocaml_atomic_exp ctx cond; - string "then"; parens (ocaml_atomic_exp ctx body ^^ semi ^^ space ^^ string "loop ()"); - string "else ()"] - in - (string "let rec loop () =" ^//^ loop_body) - ^/^ string "in" - ^/^ string "loop ()" + let loop_body = + separate space + [ + string "if"; + ocaml_atomic_exp ctx cond; + string "then"; + parens (ocaml_atomic_exp ctx body ^^ semi ^^ space ^^ string "loop ()"); + string "else ()"; + ] + in + (string "let rec loop () =" ^//^ loop_body) ^/^ string "in" ^/^ string "loop ()" | E_lit _ | E_list _ | E_id _ | E_tuple _ | E_ref _ -> ocaml_atomic_exp ctx exp | E_for (id, exp_from, exp_to, exp_step, ord, exp_body) -> - let loop_var = separate space [string "let"; zencode ctx id; equals; string "ref"; ocaml_atomic_exp ctx exp_from; string "in"] in - let loop_mod = - match ord with - | Ord_aux (Ord_inc, _) -> string "Big_int.add" ^^ space ^^ zencode ctx id ^^ space ^^ ocaml_atomic_exp ctx exp_step - | Ord_aux (Ord_dec, _) -> string "Big_int.sub" ^^ space ^^ zencode ctx id ^^ space ^^ ocaml_atomic_exp ctx exp_step - | Ord_aux (Ord_var _, _) -> failwith "Cannot have variable loop order!" - in - let loop_compare = - match ord with - | Ord_aux (Ord_inc, _) -> string "Big_int.less_equal" - | Ord_aux (Ord_dec, _) -> string "Big_int.greater_equal" - | Ord_aux (Ord_var _, _) -> failwith "Cannot have variable loop order!" - in - let loop_body = - separate space [string "if"; loop_compare; zencode ctx id; ocaml_atomic_exp ctx exp_to] - ^/^ separate space [string "then"; - parens (ocaml_atomic_exp ctx exp_body ^^ semi ^^ space ^^ string "loop" ^^ space ^^ parens loop_mod)] - ^/^ string "else ()" - in - (string ("let rec loop " ^ zencode_string (string_of_id id) ^ " =") ^//^ loop_body) - ^/^ string "in" - ^/^ (string "loop" ^^ space ^^ ocaml_atomic_exp ctx exp_from) + let loop_var = + separate space [string "let"; zencode ctx id; equals; string "ref"; ocaml_atomic_exp ctx exp_from; string "in"] + in + let loop_mod = + match ord with + | Ord_aux (Ord_inc, _) -> + string "Big_int.add" ^^ space ^^ zencode ctx id ^^ space ^^ ocaml_atomic_exp ctx exp_step + | Ord_aux (Ord_dec, _) -> + string "Big_int.sub" ^^ space ^^ zencode ctx id ^^ space ^^ ocaml_atomic_exp ctx exp_step + | Ord_aux (Ord_var _, _) -> failwith "Cannot have variable loop order!" + in + let loop_compare = + match ord with + | Ord_aux (Ord_inc, _) -> string "Big_int.less_equal" + | Ord_aux (Ord_dec, _) -> string "Big_int.greater_equal" + | Ord_aux (Ord_var _, _) -> failwith "Cannot have variable loop order!" + in + let loop_body = + separate space [string "if"; loop_compare; zencode ctx id; ocaml_atomic_exp ctx exp_to] + ^/^ separate space + [ + string "then"; + parens (ocaml_atomic_exp ctx exp_body ^^ semi ^^ space ^^ string "loop" ^^ space ^^ parens loop_mod); + ] + ^/^ string "else ()" + in + (string ("let rec loop " ^ zencode_string (string_of_id id) ^ " =") ^//^ loop_body) + ^/^ string "in" ^/^ string "loop" ^^ space ^^ ocaml_atomic_exp ctx exp_from | E_cons (x, xs) -> ocaml_exp ctx x ^^ string " :: " ^^ ocaml_exp ctx xs | _ -> string ("EXP(" ^ string_of_exp exp ^ ")") + and ocaml_letbind ctx (LB_aux (lb_aux, _)) = - match lb_aux with - | LB_val (pat, exp) -> separate space [ocaml_pat ctx pat; equals; ocaml_atomic_exp ctx exp] + match lb_aux with LB_val (pat, exp) -> separate space [ocaml_pat ctx pat; equals; ocaml_atomic_exp ctx exp] + and ocaml_pexps ctx = function | [pexp] -> ocaml_pexp ctx pexp | pexp :: pexps -> ocaml_pexp ctx pexp ^/^ ocaml_pexps ctx pexps | [] -> empty + and ocaml_pexp ctx = function | Pat_aux (Pat_exp (pat, exp), _) -> - separate space [bar; ocaml_pat ctx pat; string "->"] - ^//^ group (ocaml_exp ctx exp) + separate space [bar; ocaml_pat ctx pat; string "->"] ^//^ group (ocaml_exp ctx exp) | Pat_aux (Pat_when (pat, wh, exp), _) -> - separate space [bar; ocaml_pat ctx pat; string "when"; ocaml_atomic_exp ctx wh; string "->"] - ^//^ group (ocaml_exp ctx exp) + separate space [bar; ocaml_pat ctx pat; string "when"; ocaml_atomic_exp ctx wh; string "->"] + ^//^ group (ocaml_exp ctx exp) + and ocaml_block ctx = function | [exp] -> ocaml_exp ctx exp - | E_aux (E_let _, _) as exp :: exps -> ocaml_atomic_exp ctx exp ^^ semi ^/^ ocaml_block ctx exps + | (E_aux (E_let _, _) as exp) :: exps -> ocaml_atomic_exp ctx exp ^^ semi ^/^ ocaml_block ctx exps | exp :: exps -> ocaml_exp ctx exp ^^ semi ^/^ ocaml_block ctx exps | _ -> assert false + and ocaml_fexp record_id ctx (FE_aux (FE_fexp (id, exp), _)) = separate space [zencode_upper ctx record_id ^^ dot ^^ zencode ctx id; equals; ocaml_exp ctx exp] + and ocaml_atomic_exp ctx (E_aux (exp_aux, _) as exp) = match exp_aux with | E_lit lit -> ocaml_lit lit | E_ref id -> zencode ctx id - | E_id id -> - begin - match Env.lookup_id id (env_of exp) with - | Local (Immutable, _) | Unbound _ -> zencode ctx id - | Enum _ -> zencode_upper ctx id - | Register _ when is_passed_by_name (typ_of exp) -> zencode ctx id - | Register typ -> - if !opt_trace_ocaml then + | E_id id -> begin + match Env.lookup_id id (env_of exp) with + | Local (Immutable, _) | Unbound _ -> zencode ctx id + | Enum _ -> zencode_upper ctx id + | Register _ when is_passed_by_name (typ_of exp) -> zencode ctx id + | Register typ -> + if !opt_trace_ocaml then ( let var = gensym () in let str_typ = parens (ocaml_string_typ (Rewrites.simple_typ typ) var) in - parens (separate space [string "let"; var; equals; bang ^^ zencode ctx id; string "in"; - string "trace_read" ^^ space ^^ string_lit (string_of_id id) ^^ space ^^ str_typ ^^ semi; var]) + parens + (separate space + [ + string "let"; + var; + equals; + bang ^^ zencode ctx id; + string "in"; + string "trace_read" ^^ space ^^ string_lit (string_of_id id) ^^ space ^^ str_typ ^^ semi; + var; + ] + ) + ) else bang ^^ zencode ctx id - | Local (Mutable, _) -> bang ^^ zencode ctx id - end + | Local (Mutable, _) -> bang ^^ zencode ctx id + end | E_list exps -> enclose lbracket rbracket (separate_map (semi ^^ space) (ocaml_exp ctx) exps) | E_tuple exps -> parens (separate_map (comma ^^ space) (ocaml_exp ctx) exps) | _ -> parens (ocaml_exp ctx exp) + and ocaml_assignment ctx (LE_aux (lexp_aux, _) as lexp) exp = match lexp_aux with - | LE_typ (_, id) | LE_id id -> - begin - match Env.lookup_id id (env_of exp) with - | Register typ -> + | LE_typ (_, id) | LE_id id -> begin + match Env.lookup_id id (env_of exp) with + | Register typ -> let var = gensym () in let traced_exp = - if !opt_trace_ocaml then + if !opt_trace_ocaml then ( let var = gensym () in let str_typ = parens (ocaml_string_typ (Rewrites.simple_typ typ) var) in - parens (separate space [string "let"; var; equals; ocaml_atomic_exp ctx exp; string "in"; - string "trace_write" ^^ space ^^ string_lit (string_of_id id) ^^ space ^^ str_typ ^^ semi; var]) + parens + (separate space + [ + string "let"; + var; + equals; + ocaml_atomic_exp ctx exp; + string "in"; + string "trace_write" ^^ space ^^ string_lit (string_of_id id) ^^ space ^^ str_typ ^^ semi; + var; + ] + ) + ) else ocaml_atomic_exp ctx exp in separate space [zencode ctx id; string ":="; traced_exp] - | _ -> separate space [zencode ctx id; string ":="; parens (ocaml_exp ctx exp)] - end - | LE_deref ref_exp -> - separate space [ocaml_atomic_exp ctx ref_exp; string ":="; parens (ocaml_exp ctx exp)] + | _ -> separate space [zencode ctx id; string ":="; parens (ocaml_exp ctx exp)] + end + | LE_deref ref_exp -> separate space [ocaml_atomic_exp ctx ref_exp; string ":="; parens (ocaml_exp ctx exp)] | _ -> string ("LEXP<" ^ string_of_lexp lexp ^ ">") + and ocaml_lexp ctx (LE_aux (lexp_aux, _) as lexp) = match lexp_aux with | LE_typ _ | LE_id _ -> ocaml_atomic_lexp ctx lexp | LE_deref exp -> ocaml_exp ctx exp | _ -> string ("LEXP<" ^ string_of_lexp lexp ^ ">") + and ocaml_atomic_lexp ctx (LE_aux (lexp_aux, _) as lexp) = - match lexp_aux with - | LE_typ (_, id) -> zencode ctx id - | LE_id id -> zencode ctx id - | _ -> parens (ocaml_lexp ctx lexp) + match lexp_aux with LE_typ (_, id) -> zencode ctx id | LE_id id -> zencode ctx id | _ -> parens (ocaml_lexp ctx lexp) let rec get_initialize_registers = function - | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (_, E_aux (E_block inits, _)),_)), _)]), _)), _) :: defs - when Id.compare id (mk_id "initialize_registers") = 0 -> - inits + | DEF_aux + ( DEF_fundef + (FD_aux + (FD_function (_, _, [FCL_aux (FCL_funcl (id, Pat_aux (Pat_exp (_, E_aux (E_block inits, _)), _)), _)]), _) + ), + _ + ) + :: defs + when Id.compare id (mk_id "initialize_registers") = 0 -> + inits | _ :: defs -> get_initialize_registers defs | [] -> [] @@ -443,152 +489,208 @@ let initial_value_for id inits = let ocaml_dec_spec ctx (DEC_aux (reg, _)) = match reg with | DEC_reg (typ, id, None) -> - separate space [string "let"; zencode ctx id; colon; - parens (ocaml_typ ctx typ); string "ref"; equals; - string "ref"; parens (ocaml_exp ctx (initial_value_for id ctx.register_inits))] + separate space + [ + string "let"; + zencode ctx id; + colon; + parens (ocaml_typ ctx typ); + string "ref"; + equals; + string "ref"; + parens (ocaml_exp ctx (initial_value_for id ctx.register_inits)); + ] | DEC_reg (typ, id, Some exp) -> - separate space [string "let"; zencode ctx id; colon; - parens (ocaml_typ ctx typ); string "ref"; equals; - string "ref"; parens (ocaml_exp ctx exp)] + separate space + [ + string "let"; + zencode ctx id; + colon; + parens (ocaml_typ ctx typ); + string "ref"; + equals; + string "ref"; + parens (ocaml_exp ctx exp); + ] let first_function = ref true let function_header () = - if !first_function - then (first_function := false; string "let rec") + if !first_function then ( + first_function := false; + string "let rec" + ) else string "and" -let funcls_id = function - | [] -> failwith "Ocaml: empty function" - | FCL_aux (FCL_funcl (id, _),_) :: _ -> id +let funcls_id = function [] -> failwith "Ocaml: empty function" | FCL_aux (FCL_funcl (id, _), _) :: _ -> id -let ocaml_funcl_match ctx (FCL_aux (FCL_funcl (id, pexp), _)) = - ocaml_pexp ctx pexp +let ocaml_funcl_match ctx (FCL_aux (FCL_funcl (id, pexp), _)) = ocaml_pexp ctx pexp let rec ocaml_funcl_matches ctx = function | [] -> failwith "Ocaml: empty function" | [clause] -> ocaml_funcl_match ctx clause - | (clause :: clauses) -> ocaml_funcl_match ctx clause ^/^ ocaml_funcl_matches ctx clauses + | clause :: clauses -> ocaml_funcl_match ctx clause ^/^ ocaml_funcl_matches ctx clauses let ocaml_funcls ctx = (* Create functions string_of_arg and string_of_ret that print the argument and return types of the function respectively *) let trace_info typ1 typ2 = - let arg_sym = gensym () in - let ret_sym = gensym () in - let kids = KidSet.union (tyvars_of_typ typ1) (tyvars_of_typ typ2) in - let foralls = - if KidSet.is_empty kids then empty else - separate space (List.map zencode_kid (KidSet.elements kids)) ^^ dot; - in - let string_of_arg = - separate space [function_header (); arg_sym; colon; foralls; ocaml_typ ctx typ1; string "-> string = fun arg ->"; - ocaml_string_typ typ1 (string "arg")] - in - let string_of_ret = - separate space [function_header (); ret_sym; colon; foralls; ocaml_typ ctx typ2; string "-> string = fun arg ->"; - ocaml_string_typ typ2 (string "arg")] - in - (arg_sym, string_of_arg, ret_sym, string_of_ret) + let arg_sym = gensym () in + let ret_sym = gensym () in + let kids = KidSet.union (tyvars_of_typ typ1) (tyvars_of_typ typ2) in + let foralls = + if KidSet.is_empty kids then empty else separate space (List.map zencode_kid (KidSet.elements kids)) ^^ dot + in + let string_of_arg = + separate space + [ + function_header (); + arg_sym; + colon; + foralls; + ocaml_typ ctx typ1; + string "-> string = fun arg ->"; + ocaml_string_typ typ1 (string "arg"); + ] + in + let string_of_ret = + separate space + [ + function_header (); + ret_sym; + colon; + foralls; + ocaml_typ ctx typ2; + string "-> string = fun arg ->"; + ocaml_string_typ typ2 (string "arg"); + ] + in + (arg_sym, string_of_arg, ret_sym, string_of_ret) in let sail_call id arg_sym pat_sym ret_sym = - if !opt_trace_ocaml - then separate space [string "sail_trace_call"; string_lit (string_of_id id); parens (arg_sym ^^ space ^^ pat_sym); ret_sym] + if !opt_trace_ocaml then + separate space + [string "sail_trace_call"; string_lit (string_of_id id); parens (arg_sym ^^ space ^^ pat_sym); ret_sym] else separate space [string "sail_call"] in let ocaml_funcl call string_of_arg string_of_ret = - if !opt_trace_ocaml - then (call ^^ twice hardline ^^ string_of_arg ^^ twice hardline ^^ string_of_ret) - else call + if !opt_trace_ocaml then call ^^ twice hardline ^^ string_of_arg ^^ twice hardline ^^ string_of_ret else call in function | [] -> failwith "Ocaml: empty function" - | [FCL_aux (FCL_funcl (id, pexp),_)] -> - if Bindings.mem id ctx.externs - then string ("(* Omitting externed function " ^ string_of_id id ^ " *)") ^^ hardline - else - let arg_typs, ret_typ = - match Bindings.find id ctx.val_specs with - | Typ_aux (Typ_fn (typs, typ), _) -> (typs, typ) - | _ -> failwith "Found val spec which was not a function!" - | exception Not_found -> failwith ("No val spec found for " ^ string_of_id id) - in - (* Any remaining type variables after simple_typ rewrite should - indicate Type-polymorphism. If we have it, we need to generate - explicit type signatures with universal quantification. *) - let kids = List.fold_left KidSet.union (tyvars_of_typ ret_typ) (List.map tyvars_of_typ arg_typs) in - let pat_sym = gensym () in - let pat, guard, exp = - match pexp with - | Pat_aux (Pat_exp (pat, exp),_) -> pat, None, exp - | Pat_aux (Pat_when (pat, guard, exp),_) -> pat, Some guard, exp - in - let ocaml_guarded_exp ctx exp = function - | Some guard -> - separate space [string "if"; ocaml_exp ctx guard; - string "then"; parens (ocaml_exp ctx exp); - string "else"; Printf.ksprintf string "failwith \"Pattern match failure in %s\"" (string_of_id id)] - | None -> ocaml_exp ctx exp - in - let annot_pat = - let pat = - if KidSet.is_empty kids then - parens (ocaml_pat ctx pat ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (mk_typ (Typ_tuple arg_typs))) - else - ocaml_pat ctx pat - in - if !opt_trace_ocaml - then parens (separate space [pat; string "as"; pat_sym]) - else pat - in - let call_header = function_header () in - let arg_sym, string_of_arg, ret_sym, string_of_ret = trace_info (mk_typ (Typ_tuple arg_typs)) ret_typ in - let call = - if KidSet.is_empty kids then - separate space [call_header; zencode ctx id; - annot_pat; colon; ocaml_typ ctx ret_typ; equals; - sail_call id arg_sym pat_sym ret_sym; string "(fun r ->"] - ^//^ ocaml_guarded_exp ctx exp guard - ^^ rparen - else - separate space [call_header; zencode ctx id; colon; - separate space (List.map zencode_kid (KidSet.elements kids)) ^^ dot; - ocaml_typ ctx (mk_typ (Typ_tuple arg_typs)); string "->"; ocaml_typ ctx ret_typ; equals; - string "fun"; annot_pat; string "->"; - sail_call id arg_sym pat_sym ret_sym; string "(fun r ->"] - ^//^ ocaml_guarded_exp ctx exp guard - ^^ rparen - in - ocaml_funcl call string_of_arg string_of_ret + | [FCL_aux (FCL_funcl (id, pexp), _)] -> + if Bindings.mem id ctx.externs then + string ("(* Omitting externed function " ^ string_of_id id ^ " *)") ^^ hardline + else ( + let arg_typs, ret_typ = + match Bindings.find id ctx.val_specs with + | Typ_aux (Typ_fn (typs, typ), _) -> (typs, typ) + | _ -> failwith "Found val spec which was not a function!" + | exception Not_found -> failwith ("No val spec found for " ^ string_of_id id) + in + (* Any remaining type variables after simple_typ rewrite should + indicate Type-polymorphism. If we have it, we need to generate + explicit type signatures with universal quantification. *) + let kids = List.fold_left KidSet.union (tyvars_of_typ ret_typ) (List.map tyvars_of_typ arg_typs) in + let pat_sym = gensym () in + let pat, guard, exp = + match pexp with + | Pat_aux (Pat_exp (pat, exp), _) -> (pat, None, exp) + | Pat_aux (Pat_when (pat, guard, exp), _) -> (pat, Some guard, exp) + in + let ocaml_guarded_exp ctx exp = function + | Some guard -> + separate space + [ + string "if"; + ocaml_exp ctx guard; + string "then"; + parens (ocaml_exp ctx exp); + string "else"; + Printf.ksprintf string "failwith \"Pattern match failure in %s\"" (string_of_id id); + ] + | None -> ocaml_exp ctx exp + in + let annot_pat = + let pat = + if KidSet.is_empty kids then + parens (ocaml_pat ctx pat ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (mk_typ (Typ_tuple arg_typs))) + else ocaml_pat ctx pat + in + if !opt_trace_ocaml then parens (separate space [pat; string "as"; pat_sym]) else pat + in + let call_header = function_header () in + let arg_sym, string_of_arg, ret_sym, string_of_ret = trace_info (mk_typ (Typ_tuple arg_typs)) ret_typ in + let call = + if KidSet.is_empty kids then + separate space + [ + call_header; + zencode ctx id; + annot_pat; + colon; + ocaml_typ ctx ret_typ; + equals; + sail_call id arg_sym pat_sym ret_sym; + string "(fun r ->"; + ] + ^//^ ocaml_guarded_exp ctx exp guard ^^ rparen + else + separate space + [ + call_header; + zencode ctx id; + colon; + separate space (List.map zencode_kid (KidSet.elements kids)) ^^ dot; + ocaml_typ ctx (mk_typ (Typ_tuple arg_typs)); + string "->"; + ocaml_typ ctx ret_typ; + equals; + string "fun"; + annot_pat; + string "->"; + sail_call id arg_sym pat_sym ret_sym; + string "(fun r ->"; + ] + ^//^ ocaml_guarded_exp ctx exp guard ^^ rparen + in + ocaml_funcl call string_of_arg string_of_ret + ) | funcls -> - let id = funcls_id funcls in - if Bindings.mem id ctx.externs - then string ("(* Omitting externed function " ^ string_of_id id ^ " *)") ^^ hardline - else - let arg_typs, ret_typ = - match Bindings.find id ctx.val_specs with - | Typ_aux (Typ_fn (typs, typ), _) -> (typs, typ) - | _ -> failwith "Found val spec which was not a function!" - in - let kids = List.fold_left KidSet.union (tyvars_of_typ ret_typ) (List.map tyvars_of_typ arg_typs) in - if not (KidSet.is_empty kids) then failwith "Cannot handle polymorphic multi-clause function in OCaml backend" else (); - let pat_sym = gensym () in - let call_header = function_header () in - let arg_sym, string_of_arg, ret_sym, string_of_ret = trace_info (mk_typ (Typ_tuple arg_typs)) ret_typ in - let call = - separate space [call_header; zencode ctx id; parens (pat_sym ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (mk_typ (Typ_tuple arg_typs))); equals; - sail_call id arg_sym pat_sym ret_sym; string "(fun r ->"] - ^//^ (separate space [string "match"; pat_sym; string "with"] ^^ hardline ^^ ocaml_funcl_matches ctx funcls) - ^^ rparen - in - ocaml_funcl call string_of_arg string_of_ret - -let ocaml_fundef ctx (FD_aux (FD_function (_, _, funcls), _)) = - ocaml_funcls ctx funcls + let id = funcls_id funcls in + if Bindings.mem id ctx.externs then + string ("(* Omitting externed function " ^ string_of_id id ^ " *)") ^^ hardline + else ( + let arg_typs, ret_typ = + match Bindings.find id ctx.val_specs with + | Typ_aux (Typ_fn (typs, typ), _) -> (typs, typ) + | _ -> failwith "Found val spec which was not a function!" + in + let kids = List.fold_left KidSet.union (tyvars_of_typ ret_typ) (List.map tyvars_of_typ arg_typs) in + if not (KidSet.is_empty kids) then failwith "Cannot handle polymorphic multi-clause function in OCaml backend" + else (); + let pat_sym = gensym () in + let call_header = function_header () in + let arg_sym, string_of_arg, ret_sym, string_of_ret = trace_info (mk_typ (Typ_tuple arg_typs)) ret_typ in + let call = + separate space + [ + call_header; + zencode ctx id; + parens (pat_sym ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (mk_typ (Typ_tuple arg_typs))); + equals; + sail_call id arg_sym pat_sym ret_sym; + string "(fun r ->"; + ] + ^//^ (separate space [string "match"; pat_sym; string "with"] ^^ hardline ^^ ocaml_funcl_matches ctx funcls) + ^^ rparen + in + ocaml_funcl call string_of_arg string_of_ret + ) + +let ocaml_fundef ctx (FD_aux (FD_function (_, _, funcls), _)) = ocaml_funcls ctx funcls let rec ocaml_fields ctx = - let ocaml_field typ id = - separate space [zencode ctx id; colon; ocaml_typ ctx typ] - in + let ocaml_field typ id = separate space [zencode ctx id; colon; ocaml_typ ctx typ] in function | [(typ, id)] -> ocaml_field typ id | (typ, id) :: fields -> ocaml_field typ id ^^ semi ^/^ ocaml_fields ctx fields @@ -596,13 +698,11 @@ let rec ocaml_fields ctx = let rec ocaml_cases polymorphic_variant ctx = let ocaml_case (Tu_aux (Tu_ty_id (typ, id), _)) = - let name = if polymorphic_variant then (string "`" ^^ zencode_upper ctx id) else zencode_upper ctx id in + let name = if polymorphic_variant then string "`" ^^ zencode_upper ctx id else zencode_upper ctx id in separate space [bar; name; string "of"; ocaml_typ ctx typ] in function - | [tu] -> ocaml_case tu - | tu :: tus -> ocaml_case tu ^/^ ocaml_cases polymorphic_variant ctx tus - | [] -> empty + | [tu] -> ocaml_case tu | tu :: tus -> ocaml_case tu ^/^ ocaml_cases polymorphic_variant ctx tus | [] -> empty let rec ocaml_exceptions ctx = let ocaml_exception (Tu_aux (Tu_ty_id (typ, id), _)) = @@ -615,7 +715,7 @@ let rec ocaml_exceptions ctx = let rec ocaml_enum ctx = function | [id] -> zencode_upper ctx id - | id :: ids -> zencode_upper ctx id ^/^ (bar ^^ space ^^ ocaml_enum ctx ids) + | id :: ids -> zencode_upper ctx id ^/^ bar ^^ space ^^ ocaml_enum ctx ids | [] -> empty (* We generate a string_of_X ocaml function for each type X, to be used for debugging purposes *) @@ -623,22 +723,31 @@ let rec ocaml_enum ctx = function let ocaml_def_end = string ";;" ^^ twice hardline let ocaml_string_of_enum ctx id ids = - let ocaml_case id = - separate space [bar; zencode_upper ctx id; string "->"; string ("\"" ^ string_of_id id ^ "\"")] - in - separate space [string "let"; ocaml_string_of id; equals; string "function"] - ^//^ (separate_map hardline ocaml_case ids) + let ocaml_case id = separate space [bar; zencode_upper ctx id; string "->"; string ("\"" ^ string_of_id id ^ "\"")] in + separate space [string "let"; ocaml_string_of id; equals; string "function"] ^//^ separate_map hardline ocaml_case ids + +let ocaml_struct_type ctx id = zencode_upper ctx id ^^ dot ^^ zencode ctx id -let ocaml_struct_type ctx id = - zencode_upper ctx id ^^ dot ^^ zencode ctx id - let ocaml_string_of_struct ctx struct_id typq fields = let arg = gensym () in let ocaml_field (typ, id) = - separate space [string (string_of_id id ^ " = \""); string "^"; ocaml_string_typ typ (arg ^^ dot ^^ zencode_upper ctx struct_id ^^ dot ^^ zencode ctx id)] + separate space + [ + string (string_of_id id ^ " = \""); + string "^"; + ocaml_string_typ typ (arg ^^ dot ^^ zencode_upper ctx struct_id ^^ dot ^^ zencode ctx id); + ] in - separate space [string "let"; ocaml_string_of struct_id; parens (arg ^^ space ^^ colon ^^ space ^^ ocaml_typquant typq ^^ space ^^ ocaml_struct_type ctx struct_id); equals] - ^//^ (string "\"{" ^^ separate_map (hardline ^^ string "^ \", ") ocaml_field fields ^^ string " ^ \"}\"") + separate space + [ + string "let"; + ocaml_string_of struct_id; + parens (arg ^^ space ^^ colon ^^ space ^^ ocaml_typquant typq ^^ space ^^ ocaml_struct_type ctx struct_id); + equals; + ] + ^//^ string "\"{" + ^^ separate_map (hardline ^^ string "^ \", ") ocaml_field fields + ^^ string " ^ \"}\"" let ocaml_string_of_abbrev ctx id typq typ = let arg = gensym () in @@ -651,50 +760,40 @@ let ocaml_string_of_variant ctx id typq cases = let ocaml_typedef ctx (TD_aux (td_aux, (l, _))) = match td_aux with | TD_record (id, typq, fields, _) -> - (separate space [string "module"; zencode_upper ctx id; equals; string "struct"] + (separate space [string "module"; zencode_upper ctx id; equals; string "struct"] ^//^ ((separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; lbrace] - ^//^ ocaml_fields ctx fields) - ^/^ rbrace) - ^/^ string "end") - ^^ ocaml_def_end - ^^ ocaml_string_of_struct ctx id typq fields - ^^ ocaml_def_end - | TD_variant (id, _, cases, _) when string_of_id id = "exception" -> - ocaml_exceptions ctx cases - ^^ ocaml_def_end + ^//^ ocaml_fields ctx fields + ) + ^/^ rbrace + ) + ^/^ string "end" + ) + ^^ ocaml_def_end + ^^ ocaml_string_of_struct ctx id typq fields + ^^ ocaml_def_end + | TD_variant (id, _, cases, _) when string_of_id id = "exception" -> ocaml_exceptions ctx cases ^^ ocaml_def_end | TD_variant (id, typq, cases, _) -> - (if List.length cases > ocaml_variant_max_constructors then ( - separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; string "["] - ^//^ ocaml_cases true ctx cases - ^/^ string "]" - ) else ( - separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals] - ^//^ ocaml_cases false ctx cases - )) - ^^ ocaml_def_end - ^^ ocaml_string_of_variant ctx id typq cases - ^^ ocaml_def_end + ( if List.length cases > ocaml_variant_max_constructors then + separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; string "["] + ^//^ ocaml_cases true ctx cases ^/^ string "]" + else + separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals] ^//^ ocaml_cases false ctx cases + ) + ^^ ocaml_def_end + ^^ ocaml_string_of_variant ctx id typq cases + ^^ ocaml_def_end | TD_enum (id, ids, _) -> - (separate space [string "type"; zencode ctx id; equals] - ^//^ (bar ^^ space ^^ ocaml_enum ctx ids)) - ^^ ocaml_def_end - ^^ ocaml_string_of_enum ctx id ids - ^^ ocaml_def_end + (separate space [string "type"; zencode ctx id; equals] ^//^ bar ^^ space ^^ ocaml_enum ctx ids) + ^^ ocaml_def_end ^^ ocaml_string_of_enum ctx id ids ^^ ocaml_def_end | TD_abbrev (id, typq, A_aux (A_typ typ, _)) -> - separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; ocaml_typ ctx typ] - ^^ ocaml_def_end - ^^ ocaml_string_of_abbrev ctx id typq typ - ^^ ocaml_def_end - | TD_abbrev _ -> - empty - | TD_bitfield _ -> - Reporting.unreachable l __POS__ "Bitfield should be re-written" + separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; ocaml_typ ctx typ] + ^^ ocaml_def_end ^^ ocaml_string_of_abbrev ctx id typq typ ^^ ocaml_def_end + | TD_abbrev _ -> empty + | TD_bitfield _ -> Reporting.unreachable l __POS__ "Bitfield should be re-written" let get_externs defs = let extern_id (VS_aux (VS_val_spec (typschm, id, exts, _), _)) = - match Ast_util.extern_assoc "ocaml" exts with - | None -> [] - | Some ext -> [(id, mk_id ext)] + match Ast_util.extern_assoc "ocaml" exts with None -> [] | Some ext -> [(id, mk_id ext)] in let rec extern_ids = function | DEF_aux (DEF_val vs, _) :: defs -> extern_id vs :: extern_ids defs @@ -702,28 +801,31 @@ let get_externs defs = | [] -> [] in List.fold_left (fun exts (id, name) -> Bindings.add id name exts) Bindings.empty (List.concat (extern_ids defs)) - + let nf_group doc = first_function := true; group doc -let ocaml_def ctx (DEF_aux (aux, _)) = match aux with +let ocaml_def ctx (DEF_aux (aux, _)) = + match aux with | DEF_register ds -> nf_group (ocaml_dec_spec ctx ds) ^^ ocaml_def_end | DEF_fundef fd -> group (ocaml_fundef ctx fd) ^^ twice hardline | DEF_internal_mutrec fds -> - separate_map (twice hardline) (fun fd -> group (ocaml_fundef ctx fd)) fds ^^ twice hardline + separate_map (twice hardline) (fun fd -> group (ocaml_fundef ctx fd)) fds ^^ twice hardline | DEF_type td -> nf_group (ocaml_typedef ctx td) | DEF_let lb -> nf_group (string "let" ^^ space ^^ ocaml_letbind ctx lb) ^^ ocaml_def_end | _ -> empty let val_spec_typs defs = - let typs = ref (Bindings.empty) in + let typs = ref Bindings.empty in let val_spec_typ (VS_aux (vs_aux, _)) = match vs_aux with | VS_val_spec (TypSchm_aux (TypSchm_ts (_, typ), _), id, _, _) -> typs := Bindings.add id typ !typs in let rec vs_typs = function - | DEF_aux (DEF_val vs, _) :: defs -> val_spec_typ vs; vs_typs defs + | DEF_aux (DEF_val vs, _) :: defs -> + val_spec_typ vs; + vs_typs defs | _ :: defs -> vs_typs defs | [] -> [] in @@ -736,64 +838,47 @@ let val_spec_typs defs = full type information is available. For example, vectors are simplified to lists, so to produce lists of the right length we need to know what the size of the vector is. - *) +*) let orig_types_for_ocaml_generator defs = - List.filter_map (function - | DEF_aux (DEF_type td, _) -> Some td - | _ -> None) defs + List.filter_map (function DEF_aux (DEF_type td, _) -> Some td | _ -> None) defs let ocaml_pp_generators ctx defs orig_types required = - let add_def typemap td = - Bindings.add (id_of_type_def td) td typemap - in + let add_def typemap td = Bindings.add (id_of_type_def td) td typemap in let typemap = List.fold_left add_def Bindings.empty orig_types in let required = IdSet.of_list required in let rec always_add_req_from_id required id = match Bindings.find id typemap with | td -> add_req_from_td (IdSet.add id required) td | exception Not_found -> - if Bindings.mem id Type_check.Env.builtin_typs - then IdSet.add id required - else - raise (Reporting.err_general (id_loc id) - ("Required generator of unknown type " ^ string_of_id id)) - and add_req_from_id required id = - if IdSet.mem id required then required - else always_add_req_from_id required id - and add_req_from_typ required (Typ_aux (typ,_) as full_typ) = + if Bindings.mem id Type_check.Env.builtin_typs then IdSet.add id required + else raise (Reporting.err_general (id_loc id) ("Required generator of unknown type " ^ string_of_id id)) + and add_req_from_id required id = if IdSet.mem id required then required else always_add_req_from_id required id + and add_req_from_typ required (Typ_aux (typ, _) as full_typ) = match typ with | Typ_id id -> add_req_from_id required id - | Typ_var _ - -> required - | Typ_internal_unknown - | Typ_fn _ - | Typ_bidir _ - -> raise (Reporting.err_unreachable (typ_loc full_typ) __POS__ - ("Required generator for type that should not appear: " ^ - string_of_typ full_typ)) - | Typ_tuple typs -> - List.fold_left add_req_from_typ required typs + | Typ_var _ -> required + | Typ_internal_unknown | Typ_fn _ | Typ_bidir _ -> + raise + (Reporting.err_unreachable (typ_loc full_typ) __POS__ + ("Required generator for type that should not appear: " ^ string_of_typ full_typ) + ) + | Typ_tuple typs -> List.fold_left add_req_from_typ required typs | Typ_exist _ -> - raise (Reporting.err_todo (typ_loc full_typ) - ("Generators for existential types not yet supported: " ^ - string_of_typ full_typ)) - | Typ_app (id,args) -> - List.fold_left add_req_from_typarg (add_req_from_id required id) args - and add_req_from_typarg required (A_aux (arg,_)) = - match arg with - | A_typ typ -> add_req_from_typ required typ - | A_nexp _ | A_order _ | A_bool _ -> required - and add_req_from_td required (TD_aux (td,(l,_))) = + raise + (Reporting.err_todo (typ_loc full_typ) + ("Generators for existential types not yet supported: " ^ string_of_typ full_typ) + ) + | Typ_app (id, args) -> List.fold_left add_req_from_typarg (add_req_from_id required id) args + and add_req_from_typarg required (A_aux (arg, _)) = + match arg with A_typ typ -> add_req_from_typ required typ | A_nexp _ | A_order _ | A_bool _ -> required + and add_req_from_td required (TD_aux (td, (l, _))) = match td with - | TD_abbrev (_, _, A_aux (A_typ typ, _)) -> - add_req_from_typ required typ + | TD_abbrev (_, _, A_aux (A_typ typ, _)) -> add_req_from_typ required typ | TD_abbrev _ -> required - | TD_record (_, _, fields, _) -> - List.fold_left (fun req (typ,_) -> add_req_from_typ req typ) required fields + | TD_record (_, _, fields, _) -> List.fold_left (fun req (typ, _) -> add_req_from_typ req typ) required fields | TD_variant (_, _, variants, _) -> - List.fold_left (fun req (Tu_aux (Tu_ty_id (typ,_),_)) -> - add_req_from_typ req typ) required variants + List.fold_left (fun req (Tu_aux (Tu_ty_id (typ, _), _)) -> add_req_from_typ req typ) required variants | TD_enum _ -> required | TD_bitfield _ -> raise (Reporting.err_todo l "Generators for bitfields not yet supported") in @@ -802,198 +887,184 @@ let ocaml_pp_generators ctx defs orig_types required = let make_gen_field id = let allquants = match Bindings.find id typemap with - | TD_aux (td,_) -> - (match td with - | TD_abbrev (_,tqs,A_aux (A_typ _, _)) -> tqs - | TD_record (_,tqs,_,_) -> tqs - | TD_variant (_,tqs,_,_) -> tqs - | TD_enum _ -> TypQ_aux (TypQ_no_forall,Unknown) + | TD_aux (td, _) -> ( + match td with + | TD_abbrev (_, tqs, A_aux (A_typ _, _)) -> tqs + | TD_record (_, tqs, _, _) -> tqs + | TD_variant (_, tqs, _, _) -> tqs + | TD_enum _ -> TypQ_aux (TypQ_no_forall, Unknown) | TD_abbrev (_, _, _) -> assert false - | TD_bitfield _ -> assert false) - | exception Not_found -> - Bindings.find id Type_check.Env.builtin_typs + | TD_bitfield _ -> assert false + ) + | exception Not_found -> Bindings.find id Type_check.Env.builtin_typs in let tquants = quant_kopts allquants in - let gen_tyvars = List.map (fun k -> kopt_kid k |> zencode_kid) - (List.filter is_typ_kopt tquants) in + let gen_tyvars = List.map (fun k -> kopt_kid k |> zencode_kid) (List.filter is_typ_kopt tquants) in let print_quant kindedid = - if is_int_kopt kindedid then string "int" else - if is_order_kopt kindedid then string "bool" else - parens (separate space [string "generators"; string "->"; zencode_kid (kopt_kid kindedid)]) + if is_int_kopt kindedid then string "int" + else if is_order_kopt kindedid then string "bool" + else parens (separate space [string "generators"; string "->"; zencode_kid (kopt_kid kindedid)]) in let name = "gen_" ^ type_name id in let make_tyarg kindedid = - if is_int_kopt kindedid - then mk_typ_arg (A_nexp (nvar (kopt_kid kindedid))) - else if is_order_kopt kindedid - then mk_typ_arg (A_order (mk_ord (Ord_var (kopt_kid kindedid)))) + if is_int_kopt kindedid then mk_typ_arg (A_nexp (nvar (kopt_kid kindedid))) + else if is_order_kopt kindedid then mk_typ_arg (A_order (mk_ord (Ord_var (kopt_kid kindedid)))) else mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kindedid)))) in let targs = List.map make_tyarg tquants in - let gen_tyvars_pp = match gen_tyvars with - | [] -> empty - | _ -> separate space gen_tyvars ^^ dot ^^ space - in - let out_typ = mk_typ (Typ_app (id,targs)) in + let gen_tyvars_pp = match gen_tyvars with [] -> empty | _ -> separate space gen_tyvars ^^ dot ^^ space in + let out_typ = mk_typ (Typ_app (id, targs)) in let out_typ = Rewrites.simple_typ out_typ in - let types = string "generators" :: List.map print_quant tquants @ [ocaml_typ ctx out_typ] in - string name ^^ colon ^^ space ^^ - gen_tyvars_pp ^^ separate (string " -> ") types + let types = (string "generators" :: List.map print_quant tquants) @ [ocaml_typ ctx out_typ] in + string name ^^ colon ^^ space ^^ gen_tyvars_pp ^^ separate (string " -> ") types in let fields = separate_map (string ";" ^^ break 1) make_gen_field (IdSet.elements required) in let gen_record_type_pp = string "type generators = {" ^^ group (nest 2 (break 0 ^^ fields) ^^ break 0) ^^ string "}" in let make_rand_gen id = - if Bindings.mem id Type_check.Env.builtin_typs - then empty - else + if Bindings.mem id Type_check.Env.builtin_typs then empty + else ( let mk_arg kid = string (zencode_string (string_of_kid kid)) in - let rec gen_type (Typ_aux (typ,l) as full_typ) = - let typ_str, args_pp = match typ with - | Typ_id id -> type_name id, [string "g"] - | Typ_app (id,args) -> type_name id, string "g"::List.map typearg args - | _ -> raise (Reporting.err_todo l - ("Unsupported type for generators: " ^ string_of_typ full_typ)) - in - let args_pp = match args_pp with [] -> empty - | _ -> space ^^ separate space args_pp + let rec gen_type (Typ_aux (typ, l) as full_typ) = + let typ_str, args_pp = + match typ with + | Typ_id id -> (type_name id, [string "g"]) + | Typ_app (id, args) -> (type_name id, string "g" :: List.map typearg args) + | _ -> raise (Reporting.err_todo l ("Unsupported type for generators: " ^ string_of_typ full_typ)) in + let args_pp = match args_pp with [] -> empty | _ -> space ^^ separate space args_pp in string ("g.gen_" ^ typ_str) ^^ args_pp - and typearg (A_aux (arg,l)) = + and typearg (A_aux (arg, l)) = match arg with - | A_nexp (Nexp_aux (nexp,l) as full_nexp) -> - (match nexp with + | A_nexp (Nexp_aux (nexp, l) as full_nexp) -> ( + match nexp with | Nexp_constant c -> string (Big_int.to_string c) (* TODO: overflow *) | Nexp_var v -> mk_arg v - | _ -> raise (Reporting.err_todo l - ("Unsupported nexp for generators: " ^ string_of_nexp full_nexp))) - | A_order (Ord_aux (ord,_)) -> - (match ord with - | Ord_var v -> mk_arg v - | Ord_inc -> string "true" - | Ord_dec -> string "false") + | _ -> raise (Reporting.err_todo l ("Unsupported nexp for generators: " ^ string_of_nexp full_nexp)) + ) + | A_order (Ord_aux (ord, _)) -> ( + match ord with Ord_var v -> mk_arg v | Ord_inc -> string "true" | Ord_dec -> string "false" + ) | A_typ typ -> parens (string "fun g -> " ^^ gen_type typ) - | A_bool nc -> raise (Reporting.err_todo l ("Unsupported constraint for generators: " ^ string_of_n_constraint nc)) + | A_bool nc -> + raise (Reporting.err_todo l ("Unsupported constraint for generators: " ^ string_of_n_constraint nc)) in - let make_subgen (Typ_aux (typ,l) as full_typ) = + let make_subgen (Typ_aux (typ, l) as full_typ) = let typ_str, args_pp = match typ with - | Typ_id id -> type_name id, [] - | Typ_app (id,args) -> type_name id, List.map typearg args - | _ -> raise (Reporting.err_todo l - ("Unsupported type for generators: " ^ string_of_typ full_typ)) + | Typ_id id -> (type_name id, []) + | Typ_app (id, args) -> (type_name id, List.map typearg args) + | _ -> raise (Reporting.err_todo l ("Unsupported type for generators: " ^ string_of_typ full_typ)) in - let args_pp = match args_pp with [] -> empty - | _ -> space ^^ separate space args_pp - in string ("g.gen_" ^ typ_str) ^^ space ^^ string "g" ^^ args_pp + let args_pp = match args_pp with [] -> empty | _ -> space ^^ separate space args_pp in + string ("g.gen_" ^ typ_str) ^^ space ^^ string "g" ^^ args_pp in - let make_variant (Tu_aux (Tu_ty_id (typ,id),_)) = - let arg_typs = match typ with - | Typ_aux (Typ_fn (typs,_),_) -> typs - | Typ_aux (Typ_tuple typs,_) -> typs - | _ -> [typ] + let make_variant (Tu_aux (Tu_ty_id (typ, id), _)) = + let arg_typs = + match typ with Typ_aux (Typ_fn (typs, _), _) -> typs | Typ_aux (Typ_tuple typs, _) -> typs | _ -> [typ] in - zencode_upper ctx id ^^ space ^^ - parens (separate_map (string ", ") make_subgen arg_typs) - in - let rand_variant variant = - parens (string "fun g -> " ^^ make_variant variant) - in - let variant_constructor (Tu_aux (Tu_ty_id (_,id),_)) = - dquotes (string (string_of_id id)) + zencode_upper ctx id ^^ space ^^ parens (separate_map (string ", ") make_subgen arg_typs) in + let rand_variant variant = parens (string "fun g -> " ^^ make_variant variant) in + let variant_constructor (Tu_aux (Tu_ty_id (_, id), _)) = dquotes (string (string_of_id id)) in let build_constructor variant = - separate space [bar; variant_constructor variant; string "->"; - make_variant variant] - in - let enum_constructor id = - dquotes (string (string_of_id id)) + separate space [bar; variant_constructor variant; string "->"; make_variant variant] in + let enum_constructor id = dquotes (string (string_of_id id)) in let build_enum_constructor id = - separate space [bar; dquotes (string (string_of_id id)); string "->"; - zencode_upper ctx id] - in - let rand_field (typ,id) = - zencode ctx id ^^ space ^^ equals ^^ space ^^ make_subgen typ + separate space [bar; dquotes (string (string_of_id id)); string "->"; zencode_upper ctx id] in + let rand_field (typ, id) = zencode ctx id ^^ space ^^ equals ^^ space ^^ make_subgen typ in let make_args tqs = - string "g" ^^ + string "g" + ^^ match quant_kopts tqs with | [] -> empty - | kopts -> - space ^^ - separate_map space (fun kdid -> mk_arg (kopt_kid kdid)) kopts + | kopts -> space ^^ separate_map space (fun kdid -> mk_arg (kopt_kid kdid)) kopts in let tqs, body, constructors, builders = - let TD_aux (td,(l,_)) = Bindings.find id typemap in + let (TD_aux (td, (l, _))) = Bindings.find id typemap in match td with - | TD_abbrev (_,tqs,A_aux (A_typ typ, _)) -> - tqs, gen_type typ, None, None - | TD_variant (_,tqs,variants,_) -> - tqs, - string "let c = rand_choice [" ^^ group (nest 2 (break 0 ^^ - separate_map (string ";" ^^ break 1) rand_variant variants) ^^ - break 0) ^^ - string "] in c g", - Some (separate_map (string ";" ^^ break 1) variant_constructor variants), - Some (separate_map (break 1) build_constructor variants) - | TD_enum (_,variants,_) -> - TypQ_aux (TypQ_no_forall, Parse_ast.Unknown), - string "rand_choice [" ^^ group (nest 2 (break 0 ^^ - separate_map (string ";" ^^ break 1) (zencode_upper ctx) variants) ^^ - break 0) ^^ - string "]", - Some (separate_map (string ";" ^^ break 1) enum_constructor variants), - Some (separate_map (break 1) build_enum_constructor variants) - | TD_record (_,tqs,fields,_) -> - tqs, braces (separate_map (string ";" ^^ break 1) rand_field fields), None, None - | _ -> - raise (Reporting.err_todo l "Generators for bitfields not yet supported") + | TD_abbrev (_, tqs, A_aux (A_typ typ, _)) -> (tqs, gen_type typ, None, None) + | TD_variant (_, tqs, variants, _) -> + ( tqs, + string "let c = rand_choice [" + ^^ group (nest 2 (break 0 ^^ separate_map (string ";" ^^ break 1) rand_variant variants) ^^ break 0) + ^^ string "] in c g", + Some (separate_map (string ";" ^^ break 1) variant_constructor variants), + Some (separate_map (break 1) build_constructor variants) + ) + | TD_enum (_, variants, _) -> + ( TypQ_aux (TypQ_no_forall, Parse_ast.Unknown), + string "rand_choice [" + ^^ group (nest 2 (break 0 ^^ separate_map (string ";" ^^ break 1) (zencode_upper ctx) variants) ^^ break 0) + ^^ string "]", + Some (separate_map (string ";" ^^ break 1) enum_constructor variants), + Some (separate_map (break 1) build_enum_constructor variants) + ) + | TD_record (_, tqs, fields, _) -> + (tqs, braces (separate_map (string ";" ^^ break 1) rand_field fields), None, None) + | _ -> raise (Reporting.err_todo l "Generators for bitfields not yet supported") in let name = type_name id in - let constructors_pp = match constructors with + let constructors_pp = + match constructors with | None -> empty | Some pp -> - nest 2 (separate space - [string "let"; string ("constructors_" ^ name); equals; lbracket] ^^ - break 1 ^^ pp ^^ break 1 ^^ rbracket) ^^ hardline + nest 2 + (separate space [string "let"; string ("constructors_" ^ name); equals; lbracket] + ^^ break 1 ^^ pp ^^ break 1 ^^ rbracket + ) + ^^ hardline in - let build_pp = match builders with + let build_pp = + match builders with | None -> empty | Some pp -> - nest 2 (separate space - [string "let"; string ("build_" ^ name); string "g"; string "c"; equals; - string "match c with"] ^^ - break 1 ^^ pp) ^^ hardline + nest 2 + (separate space + [string "let"; string ("build_" ^ name); string "g"; string "c"; equals; string "match c with"] + ^^ break 1 ^^ pp + ) + ^^ hardline in - nest 2 (separate space [string "let"; string ("rand_" ^ name); make_args tqs; equals] ^^ break 1 ^^ - body) ^^ hardline ^^ constructors_pp ^^ build_pp + nest 2 (separate space [string "let"; string ("rand_" ^ name); make_args tqs; equals] ^^ break 1 ^^ body) + ^^ hardline ^^ constructors_pp ^^ build_pp + ) in let rand_record_pp = - string "let rand_gens : generators = {" ^^ group (nest 2 (break 0 ^^ - separate_map (string ";" ^^ break 1) - (fun id -> - string ("gen_" ^ type_name id) ^^ space ^^ equals ^^ space ^^ - string ("rand_" ^ type_name id)) (IdSet.elements required)) ^^ - break 0) ^^ string "}" ^^ hardline + string "let rand_gens : generators = {" + ^^ group + (nest 2 + (break 0 + ^^ separate_map + (string ";" ^^ break 1) + (fun id -> + string ("gen_" ^ type_name id) ^^ space ^^ equals ^^ space ^^ string ("rand_" ^ type_name id) + ) + (IdSet.elements required) + ) + ^^ break 0 + ) + ^^ string "}" ^^ hardline in - gen_record_type_pp ^^ hardline ^^ hardline ^^ - separate_map hardline make_rand_gen (IdSet.elements required) ^^ - hardline ^^ rand_record_pp + gen_record_type_pp ^^ hardline ^^ hardline + ^^ separate_map hardline make_rand_gen (IdSet.elements required) + ^^ hardline ^^ rand_record_pp let ocaml_ast ast generator_info = - let ctx = { register_inits = get_initialize_registers ast.defs; - externs = get_externs ast.defs; - val_specs = val_spec_typs ast.defs; - records = record_ids ast.defs - } + let ctx = + { + register_inits = get_initialize_registers ast.defs; + externs = get_externs ast.defs; + val_specs = val_spec_typs ast.defs; + records = record_ids ast.defs; + } in let empty_reg_init = - if ctx.register_inits = [] - then - separate space [string "let"; string "zinitializze_registers"; string "()"; equals; string "()"] - ^^ ocaml_def_end + if ctx.register_inits = [] then + separate space [string "let"; string "zinitializze_registers"; string "()"; equals; string "()"] ^^ ocaml_def_end else empty in let gen_pp = @@ -1004,8 +1075,7 @@ let ocaml_ast ast generator_info = (string "open Sail_lib;;" ^^ hardline) ^^ (string "module Big_int = Nat_big_num" ^^ ocaml_def_end) ^^ concat (List.map (ocaml_def ctx) ast.defs) - ^^ empty_reg_init - ^^ gen_pp + ^^ empty_reg_init ^^ gen_pp let ocaml_main spec sail_dir = let lines = ref [] in @@ -1015,33 +1085,35 @@ let ocaml_main spec sail_dir = while true do let line = input_line chan in lines := line :: !lines - done; - with - | End_of_file -> close_in chan; lines := List.rev !lines + done + with End_of_file -> + close_in chan; + lines := List.rev !lines end; - (("open " ^ String.capitalize_ascii spec ^ ";;\n\n") :: !lines - @ [ " zinitializze_registers ();"; - if !opt_trace_ocaml then " Sail_lib.opt_trace := true;" else " ();"; - " Printexc.record_backtrace true;"; - " try zmain () with exn -> (prerr_endline(\"Exiting due to uncaught exception:\\n\" ^ Printexc.to_string exn); exit 1)\n";]) + (("open " ^ String.capitalize_ascii spec ^ ";;\n\n") :: !lines) + @ [ + " zinitializze_registers ();"; + (if !opt_trace_ocaml then " Sail_lib.opt_trace := true;" else " ();"); + " Printexc.record_backtrace true;"; + " try zmain () with exn -> (prerr_endline(\"Exiting due to uncaught exception:\\n\" ^ Printexc.to_string exn); \ + exit 1)\n"; + ] |> String.concat "\n" -let ocaml_pp_ast f ast generator_types = - ToChannel.pretty 1. 80 f (ocaml_ast ast generator_types) - +let ocaml_pp_ast f ast generator_types = ToChannel.pretty 1. 80 f (ocaml_ast ast generator_types) let system_checked str = match Unix.system str with | Unix.WEXITED 0 -> () | Unix.WEXITED n -> - prerr_endline (str ^ " terminated with code " ^ string_of_int n); - exit 1 + prerr_endline (str ^ " terminated with code " ^ string_of_int n); + exit 1 | Unix.WSIGNALED _ -> - prerr_endline (str ^ " was killed by a signal"); - exit 1 + prerr_endline (str ^ " was killed by a signal"); + exit 1 | Unix.WSTOPPED _ -> - prerr_endline (str ^ " was stopped by a signal"); - exit 1 + prerr_endline (str ^ " was stopped by a signal"); + exit 1 let ocaml_compile default_sail_dir spec ast generator_types = let sail_dir = Reporting.get_sail_dir default_sail_dir in @@ -1055,25 +1127,21 @@ let ocaml_compile default_sail_dir spec ast generator_types = let _ = Unix.system ("cp -r " ^ sail_dir ^ "/lib/" ^ tags_file ^ " _tags") in let out_chan = open_out (spec ^ ".ml") in if !opt_ocaml_coverage then - ignore(Unix.system ("cp -r " ^ sail_dir ^ "/lib/myocamlbuild_coverage.ml myocamlbuild.ml")); + ignore (Unix.system ("cp -r " ^ sail_dir ^ "/lib/myocamlbuild_coverage.ml myocamlbuild.ml")); ocaml_pp_ast out_chan ast generator_types; close_out out_chan; - if IdSet.mem (mk_id "main") (val_spec_ids ast.defs) - then - begin - print_endline "Generating main"; - let out_chan = open_out "main.ml" in - output_string out_chan (ocaml_main spec sail_dir); - close_out out_chan; - if not !opt_ocaml_nobuild then ( - if !opt_ocaml_coverage then - system_checked "BISECT_COVERAGE=YES ocamlbuild -use-ocamlfind -plugin-tag 'package(bisect_ppx-ocamlbuild)' main.native" - else - system_checked "ocamlbuild -use-ocamlfind main.native"; - ignore (Unix.system ("cp main.native " ^ cwd ^ "/" ^ spec)) - ) - end - else (if not !opt_ocaml_nobuild then - system_checked ("ocamlbuild -use-ocamlfind " ^ spec ^ ".cmo") - ); + if IdSet.mem (mk_id "main") (val_spec_ids ast.defs) then begin + print_endline "Generating main"; + let out_chan = open_out "main.ml" in + output_string out_chan (ocaml_main spec sail_dir); + close_out out_chan; + if not !opt_ocaml_nobuild then ( + if !opt_ocaml_coverage then + system_checked + "BISECT_COVERAGE=YES ocamlbuild -use-ocamlfind -plugin-tag 'package(bisect_ppx-ocamlbuild)' main.native" + else system_checked "ocamlbuild -use-ocamlfind main.native"; + ignore (Unix.system ("cp main.native " ^ cwd ^ "/" ^ spec)) + ) + end + else if not !opt_ocaml_nobuild then system_checked ("ocamlbuild -use-ocamlfind " ^ spec ^ ".cmo"); Unix.chdir cwd diff --git a/src/sail_ocaml_backend/sail_plugin_ocaml.ml b/src/sail_ocaml_backend/sail_plugin_ocaml.ml index 1aa88fcd9..f329be865 100644 --- a/src/sail_ocaml_backend/sail_plugin_ocaml.ml +++ b/src/sail_ocaml_backend/sail_plugin_ocaml.ml @@ -67,26 +67,29 @@ open Libsail -let opt_ocaml_generators = ref ([]:string list) +let opt_ocaml_generators = ref ([] : string list) + +let ocaml_options = + [ + ("-ocaml_nobuild", Arg.Set Ocaml_backend.opt_ocaml_nobuild, " do not build generated OCaml"); + ( "-ocaml_trace", + Arg.Set Ocaml_backend.opt_trace_ocaml, + " output an OCaml translated version of the input with tracing instrumentation, implies -ocaml" + ); + ( "-ocaml_build_dir", + Arg.String (fun dir -> Ocaml_backend.opt_ocaml_build_dir := dir), + " set a custom directory to build generated OCaml" + ); + ( "-ocaml_coverage", + Arg.Set Ocaml_backend.opt_ocaml_coverage, + " build OCaml with bisect_ppx coverage reporting (requires opam packages bisect_ppx-ocamlbuild and bisect_ppx)." + ); + ( "-ocaml_generators", + Arg.String (fun s -> opt_ocaml_generators := s :: !opt_ocaml_generators), + " produce random generators for the given types" + ); + ] -let ocaml_options = [ - ( "-ocaml_nobuild", - Arg.Set Ocaml_backend.opt_ocaml_nobuild, - " do not build generated OCaml"); - ( "-ocaml_trace", - Arg.Set Ocaml_backend.opt_trace_ocaml, - " output an OCaml translated version of the input with tracing instrumentation, implies -ocaml"); - ( "-ocaml_build_dir", - Arg.String (fun dir -> Ocaml_backend.opt_ocaml_build_dir := dir), - " set a custom directory to build generated OCaml"); - ( "-ocaml_coverage", - Arg.Set Ocaml_backend.opt_ocaml_coverage, - " build OCaml with bisect_ppx coverage reporting (requires opam packages bisect_ppx-ocamlbuild and bisect_ppx)."); - ( "-ocaml_generators", - Arg.String (fun s -> opt_ocaml_generators := s :: !opt_ocaml_generators), - " produce random generators for the given types"); -] - let ocaml_generator_info : (Type_check.tannot Ast.type_def list * string list) option ref = ref None let stash_pre_rewrite_info (ast : _ Ast_defs.ast) _ type_envs = @@ -94,7 +97,7 @@ let stash_pre_rewrite_info (ast : _ Ast_defs.ast) _ type_envs = match !opt_ocaml_generators with | [] -> None | _ -> Some (Ocaml_backend.orig_types_for_ocaml_generator ast.defs, !opt_ocaml_generators) - + let ocaml_rewrites = let open Rewrites in [ @@ -117,35 +120,35 @@ let ocaml_rewrites = ("exp_lift_assign", []); ("top_sort_defs", []); ("simple_types", []); - ("overload_cast", []) + ("overload_cast", []); ] - + let ocaml_target default_sail_dir out_file ast effect_info env = let out = match out_file with None -> "out" | Some s -> s in Ocaml_backend.ocaml_compile default_sail_dir out ast !ocaml_generator_info - + let _ = - Target.register - ~name:"ocaml" - ~options:ocaml_options + Target.register ~name:"ocaml" ~options:ocaml_options ~pre_parse_hook:(fun () -> Initial_check.opt_undefined_gen := true) - ~pre_rewrites_hook:stash_pre_rewrite_info - ~rewrites:ocaml_rewrites - ocaml_target + ~pre_rewrites_hook:stash_pre_rewrite_info ~rewrites:ocaml_rewrites ocaml_target let opt_tofrominterp_output_dir : string option ref = ref None -let tofrominterp_options = [ - ( "-tofrominterp_lem", - Arg.Set ToFromInterp_backend.lem_mode, - " output embedding translation for the Lem backend rather than the OCaml backend, implies -tofrominterp"); - ( "-tofrominterp_mwords", - Arg.Set ToFromInterp_backend.mword_mode, - " output embedding translation in machine-word mode rather than bit-list mode, implies -tofrominterp"); - ( "-tofrominterp_output_dir", - Arg.String (fun dir -> opt_tofrominterp_output_dir := Some dir), - " set a custom directory to output embedding translation OCaml"); -] +let tofrominterp_options = + [ + ( "-tofrominterp_lem", + Arg.Set ToFromInterp_backend.lem_mode, + " output embedding translation for the Lem backend rather than the OCaml backend, implies -tofrominterp" + ); + ( "-tofrominterp_mwords", + Arg.Set ToFromInterp_backend.mword_mode, + " output embedding translation in machine-word mode rather than bit-list mode, implies -tofrominterp" + ); + ( "-tofrominterp_output_dir", + Arg.String (fun dir -> opt_tofrominterp_output_dir := Some dir), + " set a custom directory to output embedding translation OCaml" + ); + ] let tofrominterp_rewrites = let open Rewrites in @@ -158,38 +161,29 @@ let tofrominterp_rewrites = ("undefined", [Bool_arg false]); ("tuple_assignments", []); ("vector_concat_assignments", []); - ("simple_assignments", []) + ("simple_assignments", []); ] - + let tofrominterp_target _ out_file ast _ _ = let out = match out_file with None -> "out" | Some s -> s in ToFromInterp_backend.tofrominterp_output !opt_tofrominterp_output_dir out ast let _ = - Target.register - ~name:"tofrominterp" + Target.register ~name:"tofrominterp" ~description:"output OCaml functions to translate between shallow embedding and interpreter" - ~options:tofrominterp_options - ~rewrites:tofrominterp_rewrites - tofrominterp_target + ~options:tofrominterp_options ~rewrites:tofrominterp_rewrites tofrominterp_target let marshal_target _ out_file ast _ env = let out_filename = match out_file with None -> "out" | Some s -> s in let f = open_out_bin (out_filename ^ ".defs") in let remove_prover (l, tannot) = - if Type_check.is_empty_tannot tannot then - (l, Type_check.empty_tannot) - else - (l, Type_check.replace_env (Type_check.Env.set_prover None (Type_check.env_of_tannot tannot)) tannot) + if Type_check.is_empty_tannot tannot then (l, Type_check.empty_tannot) + else (l, Type_check.replace_env (Type_check.Env.set_prover None (Type_check.env_of_tannot tannot)) tannot) in Marshal.to_string (Ast_util.map_ast_annot remove_prover ast, Type_check.Env.set_prover None env) [Marshal.Compat_32] - |> Base64.encode_string - |> output_string f; + |> Base64.encode_string |> output_string f; close_out f let _ = - Target.register - ~name:"marshal" - ~description:"OCaml-marshal out the rewritten AST to a file" - ~rewrites:tofrominterp_rewrites - marshal_target + Target.register ~name:"marshal" ~description:"OCaml-marshal out the rewritten AST to a file" + ~rewrites:tofrominterp_rewrites marshal_target diff --git a/src/sail_ocaml_backend/toFromInterp_backend.ml b/src/sail_ocaml_backend/toFromInterp_backend.ml index af15acf4b..896e31718 100644 --- a/src/sail_ocaml_backend/toFromInterp_backend.ml +++ b/src/sail_ocaml_backend/toFromInterp_backend.ml @@ -88,136 +88,193 @@ let rec rewriteExistential (kids : kinded_id list) (Typ_aux (typ_aux, annot) as | Typ_tuple typs -> Typ_aux (Typ_tuple (List.map (rewriteExistential kids) typs), annot) | Typ_exist _ -> Reporting.unreachable annot __POS__ "nested Typ_exist in rewriteExistential" | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var kid, _)), _)]) - when (string_of_id id = "atom" || string_of_id id = "int") -> - (* List.exists (fun k -> string_of_kid (kopt_kid k) = string_of_kid kid) kids -> *) - print_endline("*** rewriting to int - kid is '" ^ string_of_kid kid ^ "'" ); - Typ_aux (Typ_id (mk_id "int"), annot) - | Typ_internal_unknown - | Typ_id _ - | Typ_var _ - | Typ_fn _ - | Typ_bidir _ - | Typ_app _ -> - typ + when string_of_id id = "atom" || string_of_id id = "int" -> + (* List.exists (fun k -> string_of_kid (kopt_kid k) = string_of_kid kid) kids -> *) + print_endline ("*** rewriting to int - kid is '" ^ string_of_kid kid ^ "'"); + Typ_aux (Typ_id (mk_id "int"), annot) + | Typ_internal_unknown | Typ_id _ | Typ_var _ | Typ_fn _ | Typ_bidir _ | Typ_app _ -> typ let frominterp_typedef (TD_aux (td_aux, (l, _))) = - let fromValueArgs (Typ_aux (typ_aux, _)) = match typ_aux with - | Typ_tuple typs -> brackets (separate space [string "V_tuple"; brackets (separate (semi ^^ space) (List.mapi (fun i _ -> string ("v" ^ (string_of_int i))) typs))]) + let fromValueArgs (Typ_aux (typ_aux, _)) = + match typ_aux with + | Typ_tuple typs -> + brackets + (separate space + [ + string "V_tuple"; + brackets (separate (semi ^^ space) (List.mapi (fun i _ -> string ("v" ^ string_of_int i)) typs)); + ] + ) | _ -> brackets (string "v0") in - let fromValueKid (Kid_aux ((Var name), _)) = - string ("typq_" ^ name) - in - let fromValueNexp ((Nexp_aux (nexp_aux, annot)) as nexp) = match nexp_aux with - | Nexp_constant num -> parens (separate space [string "Big_int.of_string"; dquotes (string (Nat_big_num.to_string num))]) + let fromValueKid (Kid_aux (Var name, _)) = string ("typq_" ^ name) in + let fromValueNexp (Nexp_aux (nexp_aux, annot) as nexp) = + match nexp_aux with + | Nexp_constant num -> + parens (separate space [string "Big_int.of_string"; dquotes (string (Nat_big_num.to_string num))]) | Nexp_var var -> fromValueKid var | Nexp_id id -> string (string_of_id id ^ "FromInterpValue") | _ -> string ("NEXP(" ^ string_of_nexp nexp ^ ")") in - let rec fromValueTypArg (A_aux (a_aux, _)) = match a_aux with - | A_typ typ -> parens ((string "fun v -> ") ^^ parens (fromValueTyp typ "v")) + let rec fromValueTypArg (A_aux (a_aux, _)) = + match a_aux with + | A_typ typ -> parens (string "fun v -> " ^^ parens (fromValueTyp typ "v")) | A_nexp nexp -> fromValueNexp nexp - | A_order order -> string ("Order_" ^ (string_of_order order)) + | A_order order -> string ("Order_" ^ string_of_order order) | A_bool _ -> parens (string "boolFromInterpValue") - and fromValueTyp ((Typ_aux (typ_aux, l)) as typ) arg_name = match typ_aux with - | Typ_id id -> parens (concat [string (maybe_zencode (string_of_id id)); string ("FromInterpValue"); space; string arg_name]) + and fromValueTyp (Typ_aux (typ_aux, l) as typ) arg_name = + match typ_aux with + | Typ_id id -> + parens (concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"; space; string arg_name]) (* special case bit vectors for lem *) - | Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp len_nexp, _); - A_aux (A_order (Ord_aux (Ord_dec, _)), _); - A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _)), _)]) when !lem_mode -> - parens (separate space ([string (maybe_zencode "bitsFromInterpValue"); string arg_name])) + | Typ_app + ( Id_aux (Id "vector", _), + [ + A_aux (A_nexp len_nexp, _); + A_aux (A_order (Ord_aux (Ord_dec, _)), _); + A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _)), _); + ] + ) + when !lem_mode -> + parens (separate space [string (maybe_zencode "bitsFromInterpValue"); string arg_name]) | Typ_app (typ_id, typ_args) -> - assert (typ_args <> []); - if string_of_id typ_id = "bits" then - parens (separate space ([string "bitsFromInterpValue"] @ [string arg_name])) - else - parens (separate space ([string (maybe_zencode (string_of_id typ_id) ^ "FromInterpValue")] @ List.map fromValueTypArg typ_args @ [string arg_name])) + assert (typ_args <> []); + if string_of_id typ_id = "bits" then parens (separate space ([string "bitsFromInterpValue"] @ [string arg_name])) + else + parens + (separate space + ([string (maybe_zencode (string_of_id typ_id) ^ "FromInterpValue")] + @ List.map fromValueTypArg typ_args + @ [string arg_name] + ) + ) | Typ_var kid -> parens (separate space [fromValueKid kid; string arg_name]) | Typ_fn _ -> parens (string "failwith \"fromValueTyp: Typ_fn arm unimplemented\"") | Typ_bidir _ -> parens (string "failwith \"fromValueTyp: Typ_bidir arm unimplemented\"") | Typ_exist (kids, _, t) -> parens (fromValueTyp (rewriteExistential kids t) arg_name) - | Typ_tuple typs -> parens (string ("match " ^ arg_name ^ " with V_tuple ") ^^ - brackets (separate (string ";" ^^ space) - (List.mapi (fun i _ -> string (arg_name ^ "_tup" ^ string_of_int i)) typs)) ^^ - (string " -> ") ^^ - parens (separate comma_sp (List.mapi (fun i t -> fromValueTyp t (arg_name ^ "_tup" ^ string_of_int i)) typs))) + | Typ_tuple typs -> + parens + (string ("match " ^ arg_name ^ " with V_tuple ") + ^^ brackets + (separate + (string ";" ^^ space) + (List.mapi (fun i _ -> string (arg_name ^ "_tup" ^ string_of_int i)) typs) + ) + ^^ string " -> " + ^^ parens + (separate comma_sp (List.mapi (fun i t -> fromValueTyp t (arg_name ^ "_tup" ^ string_of_int i)) typs)) + ) | Typ_internal_unknown -> failwith "escaped Typ_internal_unknown" in - let fromValueVals ((Typ_aux (typ_aux, l)) as typ) = match typ_aux with - | Typ_tuple typs -> parens (separate comma_sp (List.mapi (fun i typ -> fromValueTyp typ ("v" ^ (string_of_int i))) typs)) + let fromValueVals (Typ_aux (typ_aux, l) as typ) = + match typ_aux with + | Typ_tuple typs -> + parens (separate comma_sp (List.mapi (fun i typ -> fromValueTyp typ ("v" ^ string_of_int i)) typs)) | _ -> fromValueTyp typ "v0" in - let fromValueTypq (QI_aux (qi_aux, _)) = match qi_aux with + let fromValueTypq (QI_aux (qi_aux, _)) = + match qi_aux with | QI_id (KOpt_aux (KOpt_kind (K_aux (kind_aux, _), kid), _)) -> fromValueKid kid | QI_constraint _ -> empty in - let fromValueTypqs (TypQ_aux (typq_aux, _)) = match typq_aux with - | TypQ_no_forall -> [empty] - | TypQ_tq quants -> List.map fromValueTypq quants + let fromValueTypqs (TypQ_aux (typq_aux, _)) = + match typq_aux with TypQ_no_forall -> [empty] | TypQ_tq quants -> List.map fromValueTypq quants in match td_aux with - | TD_variant (id, typq, arms, _) -> - begin match id with - | Id_aux ((Id "read_kind"),_) -> empty - | Id_aux ((Id "write_kind"),_) -> empty - | Id_aux ((Id "a64_barrier_domain"),_) -> empty - | Id_aux ((Id "a64_barrier_type"),_) -> empty - | Id_aux ((Id "barrier_kind"),_) -> empty - | Id_aux ((Id "trans_kind"),_) -> empty - | Id_aux ((Id "instruction_kind"),_) -> empty - | Id_aux ((Id "cache_op_kind"),_) -> empty - | Id_aux ((Id "regfp"),_) -> empty - | Id_aux ((Id "regfps"),_) -> empty - | Id_aux ((Id "niafp"),_) -> empty - | Id_aux ((Id "niafps"),_) -> empty - | Id_aux ((Id "diafp"),_) -> empty - | Id_aux ((Id "diafps"),_) -> empty - (* | Id_aux ((Id "option"),_) -> empty *) - | Id_aux ((Id id_string), _) - | Id_aux ((Operator id_string), _) -> - if !lem_mode && id_string = "option" then empty else - let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in - let fromFallback = separate space [pipe; underscore; arrow; string "failwith"; - dquotes (string ("invalid interpreter value for " ^ (string_of_id id)))] in - let fromInterpValue = - prefix 2 1 - (separate space [string "let"; fromInterpValueName; separate space (fromValueTypqs typq @ [string "v"]); equals; string "match v with"]) - ((separate_map hardline - (fun (Tu_aux (Tu_ty_id (typ, ctor_id), _)) -> - separate space - [pipe; string "V_ctor"; parens (concat [dquotes (string (string_of_id ctor_id)); comma_sp; - fromValueArgs typ - ]); - arrow; string (maybe_zencode_upper (string_of_id ctor_id)); fromValueVals typ - ] - ) - arms) - ^^ hardline ^^ fromFallback) - in - fromInterpValue ^^ (twice hardline) - end + | TD_variant (id, typq, arms, _) -> begin + match id with + | Id_aux (Id "read_kind", _) -> empty + | Id_aux (Id "write_kind", _) -> empty + | Id_aux (Id "a64_barrier_domain", _) -> empty + | Id_aux (Id "a64_barrier_type", _) -> empty + | Id_aux (Id "barrier_kind", _) -> empty + | Id_aux (Id "trans_kind", _) -> empty + | Id_aux (Id "instruction_kind", _) -> empty + | Id_aux (Id "cache_op_kind", _) -> empty + | Id_aux (Id "regfp", _) -> empty + | Id_aux (Id "regfps", _) -> empty + | Id_aux (Id "niafp", _) -> empty + | Id_aux (Id "niafps", _) -> empty + | Id_aux (Id "diafp", _) -> empty + | Id_aux (Id "diafps", _) -> empty + (* | Id_aux ((Id "option"),_) -> empty *) + | Id_aux (Id id_string, _) | Id_aux (Operator id_string, _) -> + if !lem_mode && id_string = "option" then empty + else ( + let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in + let fromFallback = + separate space + [ + pipe; + underscore; + arrow; + string "failwith"; + dquotes (string ("invalid interpreter value for " ^ string_of_id id)); + ] + in + let fromInterpValue = + prefix 2 1 + (separate space + [ + string "let"; + fromInterpValueName; + separate space (fromValueTypqs typq @ [string "v"]); + equals; + string "match v with"; + ] + ) + (separate_map hardline + (fun (Tu_aux (Tu_ty_id (typ, ctor_id), _)) -> + separate space + [ + pipe; + string "V_ctor"; + parens (concat [dquotes (string (string_of_id ctor_id)); comma_sp; fromValueArgs typ]); + arrow; + string (maybe_zencode_upper (string_of_id ctor_id)); + fromValueVals typ; + ] + ) + arms + ^^ hardline ^^ fromFallback + ) + in + fromInterpValue ^^ twice hardline + ) + end | TD_abbrev (Id_aux (Id "regfps", _), _, _) -> empty | TD_abbrev (Id_aux (Id "niafps", _), _, _) -> empty | TD_abbrev (Id_aux (Id "bits", _), _, _) when !lem_mode -> empty - | TD_abbrev (id, typq, typ_arg) -> - begin - let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in - (* HACK: print a type annotation for abbrevs of unquantified types, to help cases ocaml can't type-infer on its own *) - let fromInterpValspec = - (* HACK because of lem renaming *) - if string_of_id id = "opcode" || string_of_id id = "integer" then empty else - match typ_arg with - | A_aux (A_typ _, _) -> begin match typq with - | TypQ_aux (TypQ_no_forall, _) -> separate space [colon; string "value"; arrow; string (maybe_zencode (string_of_id id))] - | _ -> empty - end - | _ -> empty - in - let fromInterpValue = - (separate space [string "let"; fromInterpValueName; separate space (fromValueTypqs typq); fromInterpValspec; equals; fromValueTypArg typ_arg]) - in - fromInterpValue ^^ (twice hardline) - end + | TD_abbrev (id, typq, typ_arg) -> begin + let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in + (* HACK: print a type annotation for abbrevs of unquantified types, to help cases ocaml can't type-infer on its own *) + let fromInterpValspec = + (* HACK because of lem renaming *) + if string_of_id id = "opcode" || string_of_id id = "integer" then empty + else ( + match typ_arg with + | A_aux (A_typ _, _) -> begin + match typq with + | TypQ_aux (TypQ_no_forall, _) -> + separate space [colon; string "value"; arrow; string (maybe_zencode (string_of_id id))] + | _ -> empty + end + | _ -> empty + ) + in + let fromInterpValue = + separate space + [ + string "let"; + fromInterpValueName; + separate space (fromValueTypqs typq); + fromInterpValspec; + equals; + fromValueTypArg typ_arg; + ] + in + fromInterpValue ^^ twice hardline + end | TD_enum (Id_aux (Id "read_kind", _), _, _) -> empty | TD_enum (Id_aux (Id "write_kind", _), _, _) -> empty | TD_enum (Id_aux (Id "a64_barrier_domain", _), _, _) -> empty @@ -226,151 +283,243 @@ let frominterp_typedef (TD_aux (td_aux, (l, _))) = | TD_enum (Id_aux (Id "trans_kind", _), _, _) -> empty | TD_enum (Id_aux (Id "cache_op_kind", _), _, _) -> empty | TD_enum (id, ids, _) -> - let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in - let fromFallback = separate space [pipe; underscore; arrow; string "failwith"; - dquotes (string ("invalid interpreter value for " ^ (string_of_id id)))] in - let fromInterpValue = - prefix 2 1 - (separate space [string "let"; fromInterpValueName; string "v"; equals; string "match v with"]) - ((separate_map hardline + let fromInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "FromInterpValue"] in + let fromFallback = + separate space + [ + pipe; + underscore; + arrow; + string "failwith"; + dquotes (string ("invalid interpreter value for " ^ string_of_id id)); + ] + in + let fromInterpValue = + prefix 2 1 + (separate space [string "let"; fromInterpValueName; string "v"; equals; string "match v with"]) + (separate_map hardline (fun id -> separate space - [pipe; string "V_ctor"; parens (concat [dquotes (string (string_of_id id)); comma_sp; string "[]"]); - arrow; string (maybe_zencode_upper (string_of_id id))] + [ + pipe; + string "V_ctor"; + parens (concat [dquotes (string (string_of_id id)); comma_sp; string "[]"]); + arrow; + string (maybe_zencode_upper (string_of_id id)); + ] ) - ids) - ^^ hardline ^^ fromFallback) - in - fromInterpValue ^^ (twice hardline) + ids + ^^ hardline ^^ fromFallback + ) + in + fromInterpValue ^^ twice hardline | TD_record (record_id, typq, fields, _) -> - let fromInterpField (typ, id) = - separate space [string (maybe_zencode ((if !lem_mode then string_of_id record_id ^ "_" else "") ^ string_of_id id)); equals; fromValueTyp typ ("(StringMap.find \"" ^ (string_of_id id) ^ "\" fs)")] - in - let fromInterpValueName = concat [string (maybe_zencode (string_of_id record_id)); string "FromInterpValue"] in - let fromFallback = separate space [pipe; underscore; arrow; string "failwith"; - dquotes (string ("invalid interpreter value for " ^ (string_of_id record_id)))] in - let fromInterpValue = - prefix 2 1 - (separate space [string "let"; fromInterpValueName; separate space (fromValueTypqs typq @ [string "v"]); equals; string "match v with"]) - ((separate space [pipe; string "V_record fs"; arrow; braces (separate_map (semi ^^ space) fromInterpField fields)]) - ^^ hardline ^^ fromFallback) - in - fromInterpValue ^^ (twice hardline) + let fromInterpField (typ, id) = + separate space + [ + string (maybe_zencode ((if !lem_mode then string_of_id record_id ^ "_" else "") ^ string_of_id id)); + equals; + fromValueTyp typ ("(StringMap.find \"" ^ string_of_id id ^ "\" fs)"); + ] + in + let fromInterpValueName = concat [string (maybe_zencode (string_of_id record_id)); string "FromInterpValue"] in + let fromFallback = + separate space + [ + pipe; + underscore; + arrow; + string "failwith"; + dquotes (string ("invalid interpreter value for " ^ string_of_id record_id)); + ] + in + let fromInterpValue = + prefix 2 1 + (separate space + [ + string "let"; + fromInterpValueName; + separate space (fromValueTypqs typq @ [string "v"]); + equals; + string "match v with"; + ] + ) + (separate space + [pipe; string "V_record fs"; arrow; braces (separate_map (semi ^^ space) fromInterpField fields)] + ^^ hardline ^^ fromFallback + ) + in + fromInterpValue ^^ twice hardline | _ -> empty let tointerp_typedef (TD_aux (td_aux, (l, _))) = - let toValueArgs (Typ_aux (typ_aux, _)) = match typ_aux with - | Typ_tuple typs -> parens (separate comma_sp (List.mapi (fun i _ -> string ("v" ^ (string_of_int i))) typs)) + let toValueArgs (Typ_aux (typ_aux, _)) = + match typ_aux with + | Typ_tuple typs -> parens (separate comma_sp (List.mapi (fun i _ -> string ("v" ^ string_of_int i)) typs)) | _ -> parens (string "v0") in - let toValueKid (Kid_aux ((Var name), _)) = - string ("typq_" ^ name) - in - let toValueNexp ((Nexp_aux (nexp_aux, _)) as nexp) = match nexp_aux with - | Nexp_constant num -> parens (separate space [string "Big_int.of_string"; dquotes (string (Nat_big_num.to_string num))]) + let toValueKid (Kid_aux (Var name, _)) = string ("typq_" ^ name) in + let toValueNexp (Nexp_aux (nexp_aux, _) as nexp) = + match nexp_aux with + | Nexp_constant num -> + parens (separate space [string "Big_int.of_string"; dquotes (string (Nat_big_num.to_string num))]) | Nexp_var var -> toValueKid var | Nexp_id id -> string (string_of_id id ^ "ToInterpValue") | _ -> string ("NEXP(" ^ string_of_nexp nexp ^ ")") in - let rec toValueTypArg (A_aux (a_aux, _)) = match a_aux with - | A_typ typ -> parens ((string "fun v -> ") ^^ parens (toValueTyp typ "v")) + let rec toValueTypArg (A_aux (a_aux, _)) = + match a_aux with + | A_typ typ -> parens (string "fun v -> " ^^ parens (toValueTyp typ "v")) | A_nexp nexp -> toValueNexp nexp - | A_order order -> string ("Order_" ^ (string_of_order order)) - | A_bool _ -> parens (string "boolToInterpValue") - and toValueTyp ((Typ_aux (typ_aux, l)) as typ) arg_name = match typ_aux with - | Typ_id id -> parens (concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"; space; string arg_name]) + | A_order order -> string ("Order_" ^ string_of_order order) + | A_bool _ -> parens (string "boolToInterpValue") + and toValueTyp (Typ_aux (typ_aux, l) as typ) arg_name = + match typ_aux with + | Typ_id id -> + parens (concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"; space; string arg_name]) (* special case bit vectors for lem *) - | Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp len_nexp, _); - A_aux (A_order (Ord_aux (Ord_dec, _)), _); - A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _)), _)]) when !lem_mode -> - parens (separate space ([string (maybe_zencode "bitsToInterpValue"); string arg_name])) + | Typ_app + ( Id_aux (Id "vector", _), + [ + A_aux (A_nexp len_nexp, _); + A_aux (A_order (Ord_aux (Ord_dec, _)), _); + A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _)), _); + ] + ) + when !lem_mode -> + parens (separate space [string (maybe_zencode "bitsToInterpValue"); string arg_name]) | Typ_app (typ_id, typ_args) -> - assert (typ_args <> []); - if string_of_id typ_id = "bits" then - parens (separate space ([string "bitsToInterpValue"] @ [string arg_name])) - else - parens (separate space ([string ((maybe_zencode (string_of_id typ_id)) ^ "ToInterpValue")] @ List.map toValueTypArg typ_args @ [string arg_name])) + assert (typ_args <> []); + if string_of_id typ_id = "bits" then parens (separate space ([string "bitsToInterpValue"] @ [string arg_name])) + else + parens + (separate space + ([string (maybe_zencode (string_of_id typ_id) ^ "ToInterpValue")] + @ List.map toValueTypArg typ_args + @ [string arg_name] + ) + ) | Typ_var kid -> parens (separate space [toValueKid kid; string arg_name]) | Typ_fn _ -> parens (string "failwith \"toValueTyp: Typ_fn arm unimplemented\"") | Typ_bidir _ -> parens (string "failwith \"toValueTyp: Typ_bidir arm unimplemented\"") | Typ_exist (kids, _, t) -> parens (toValueTyp (rewriteExistential kids t) arg_name) - | Typ_tuple typs -> parens (string ("match " ^ arg_name ^ " with ") ^^ - parens (separate comma_sp (List.mapi (fun i _ -> string (arg_name ^ "_tup" ^ string_of_int i)) typs)) ^^ - (string " -> V_tuple ") ^^ - brackets (separate (string ";" ^^ space) - (List.mapi (fun i t -> toValueTyp t (arg_name ^ "_tup" ^ string_of_int i)) typs))) + | Typ_tuple typs -> + parens + (string ("match " ^ arg_name ^ " with ") + ^^ parens (separate comma_sp (List.mapi (fun i _ -> string (arg_name ^ "_tup" ^ string_of_int i)) typs)) + ^^ string " -> V_tuple " + ^^ brackets + (separate + (string ";" ^^ space) + (List.mapi (fun i t -> toValueTyp t (arg_name ^ "_tup" ^ string_of_int i)) typs) + ) + ) | Typ_internal_unknown -> failwith "escaped Typ_internal_unknown" in - let toValueVals ((Typ_aux (typ_aux, _)) as typ) = match typ_aux with - | Typ_tuple typs -> brackets (separate space [string "V_tuple"; brackets (separate (semi ^^ space) (List.mapi (fun i typ -> toValueTyp typ ("v" ^ (string_of_int i))) typs))]) + let toValueVals (Typ_aux (typ_aux, _) as typ) = + match typ_aux with + | Typ_tuple typs -> + brackets + (separate space + [ + string "V_tuple"; + brackets (separate (semi ^^ space) (List.mapi (fun i typ -> toValueTyp typ ("v" ^ string_of_int i)) typs)); + ] + ) | _ -> brackets (toValueTyp typ "v0") in - let toValueTypq (QI_aux (qi_aux, _)) = match qi_aux with + let toValueTypq (QI_aux (qi_aux, _)) = + match qi_aux with | QI_id (KOpt_aux (KOpt_kind (K_aux (kind_aux, _), kid), _)) -> toValueKid kid | QI_constraint _ -> empty in - let toValueTypqs (TypQ_aux (typq_aux, _)) = match typq_aux with - | TypQ_no_forall -> [empty] - | TypQ_tq quants -> List.map toValueTypq quants + let toValueTypqs (TypQ_aux (typq_aux, _)) = + match typq_aux with TypQ_no_forall -> [empty] | TypQ_tq quants -> List.map toValueTypq quants in match td_aux with - | TD_variant (id, typq, arms, _) -> - begin match id with - | Id_aux ((Id "read_kind"),_) -> empty - | Id_aux ((Id "write_kind"),_) -> empty - | Id_aux ((Id "a64_barrier_domain"),_) -> empty - | Id_aux ((Id "a64_barrier_type"),_) -> empty - | Id_aux ((Id "barrier_kind"),_) -> empty - | Id_aux ((Id "trans_kind"),_) -> empty - | Id_aux ((Id "instruction_kind"),_) -> empty - | Id_aux ((Id "cache_op_kind"),_) -> empty - | Id_aux ((Id "regfp"),_) -> empty - | Id_aux ((Id "regfps"),_) -> empty - | Id_aux ((Id "niafp"),_) -> empty - | Id_aux ((Id "niafps"),_) -> empty - | Id_aux ((Id "diafp"),_) -> empty - | Id_aux ((Id "diafps"),_) -> empty - (* | Id_aux ((Id "option"),_) -> empty *) - | Id_aux ((Id id_string), _) - | Id_aux ((Operator id_string), _) -> - if !lem_mode && id_string = "option" then empty else - let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in - let toInterpValue = - prefix 2 1 - (separate space [string "let"; toInterpValueName; separate space (toValueTypqs typq @ [string "v"]); equals; string "match v with"]) - ((separate_map hardline - (fun (Tu_aux (Tu_ty_id (typ, ctor_id), _)) -> - separate space - [pipe; string (maybe_zencode_upper (string_of_id ctor_id)); toValueArgs typ; - arrow; string "V_ctor"; parens (concat [dquotes (string (string_of_id ctor_id)); comma_sp; toValueVals typ]) - ] - ) - arms)) - in - toInterpValue ^^ (twice hardline) - end + | TD_variant (id, typq, arms, _) -> begin + match id with + | Id_aux (Id "read_kind", _) -> empty + | Id_aux (Id "write_kind", _) -> empty + | Id_aux (Id "a64_barrier_domain", _) -> empty + | Id_aux (Id "a64_barrier_type", _) -> empty + | Id_aux (Id "barrier_kind", _) -> empty + | Id_aux (Id "trans_kind", _) -> empty + | Id_aux (Id "instruction_kind", _) -> empty + | Id_aux (Id "cache_op_kind", _) -> empty + | Id_aux (Id "regfp", _) -> empty + | Id_aux (Id "regfps", _) -> empty + | Id_aux (Id "niafp", _) -> empty + | Id_aux (Id "niafps", _) -> empty + | Id_aux (Id "diafp", _) -> empty + | Id_aux (Id "diafps", _) -> empty + (* | Id_aux ((Id "option"),_) -> empty *) + | Id_aux (Id id_string, _) | Id_aux (Operator id_string, _) -> + if !lem_mode && id_string = "option" then empty + else ( + let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in + let toInterpValue = + prefix 2 1 + (separate space + [ + string "let"; + toInterpValueName; + separate space (toValueTypqs typq @ [string "v"]); + equals; + string "match v with"; + ] + ) + (separate_map hardline + (fun (Tu_aux (Tu_ty_id (typ, ctor_id), _)) -> + separate space + [ + pipe; + string (maybe_zencode_upper (string_of_id ctor_id)); + toValueArgs typ; + arrow; + string "V_ctor"; + parens (concat [dquotes (string (string_of_id ctor_id)); comma_sp; toValueVals typ]); + ] + ) + arms + ) + in + toInterpValue ^^ twice hardline + ) + end | TD_abbrev (Id_aux (Id "regfps", _), _, _) -> empty | TD_abbrev (Id_aux (Id "niafps", _), _, _) -> empty | TD_abbrev (Id_aux (Id "bits", _), _, _) when !lem_mode -> empty - | TD_abbrev (id, typq, typ_arg) -> - begin - let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in - (* HACK: print a type annotation for abbrevs of unquantified types, to help cases ocaml can't type-infer on its own *) - let toInterpValspec = - (* HACK because of lem renaming *) - if string_of_id id = "opcode" || string_of_id id = "integer" then empty else - match typ_arg with - | A_aux (A_typ _, _) -> begin match typq with - | TypQ_aux (TypQ_no_forall, _) -> separate space [colon; string (maybe_zencode (string_of_id id)); arrow; string "value"] - | _ -> empty - end - | _ -> empty - in - let toInterpValue = - (separate space [string "let"; toInterpValueName; separate space (toValueTypqs typq); toInterpValspec; equals; toValueTypArg typ_arg]) - in - toInterpValue ^^ (twice hardline) - end + | TD_abbrev (id, typq, typ_arg) -> begin + let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in + (* HACK: print a type annotation for abbrevs of unquantified types, to help cases ocaml can't type-infer on its own *) + let toInterpValspec = + (* HACK because of lem renaming *) + if string_of_id id = "opcode" || string_of_id id = "integer" then empty + else ( + match typ_arg with + | A_aux (A_typ _, _) -> begin + match typq with + | TypQ_aux (TypQ_no_forall, _) -> + separate space [colon; string (maybe_zencode (string_of_id id)); arrow; string "value"] + | _ -> empty + end + | _ -> empty + ) + in + let toInterpValue = + separate space + [ + string "let"; + toInterpValueName; + separate space (toValueTypqs typq); + toInterpValspec; + equals; + toValueTypArg typ_arg; + ] + in + toInterpValue ^^ twice hardline + end | TD_enum (Id_aux (Id "read_kind", _), _, _) -> empty | TD_enum (Id_aux (Id "write_kind", _), _, _) -> empty | TD_enum (Id_aux (Id "a64_barrier_domain", _), _, _) -> empty @@ -379,51 +528,80 @@ let tointerp_typedef (TD_aux (td_aux, (l, _))) = | TD_enum (Id_aux (Id "trans_kind", _), _, _) -> empty | TD_enum (Id_aux (Id "cache_op_kind", _), _, _) -> empty | TD_enum (id, ids, _) -> - let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in - let toInterpValue = - prefix 2 1 - (separate space [string "let"; toInterpValueName; string "v"; equals; string "match v with"]) - ((separate_map hardline + let toInterpValueName = concat [string (maybe_zencode (string_of_id id)); string "ToInterpValue"] in + let toInterpValue = + prefix 2 1 + (separate space [string "let"; toInterpValueName; string "v"; equals; string "match v with"]) + (separate_map hardline (fun id -> separate space - [pipe; string (maybe_zencode_upper (string_of_id id)); - arrow; string "V_ctor"; parens (concat [dquotes (string (string_of_id id)); comma_sp; string "[]"])] + [ + pipe; + string (maybe_zencode_upper (string_of_id id)); + arrow; + string "V_ctor"; + parens (concat [dquotes (string (string_of_id id)); comma_sp; string "[]"]); + ] ) - ids)) - in - toInterpValue ^^ (twice hardline) + ids + ) + in + toInterpValue ^^ twice hardline | TD_record (record_id, typq, fields, _) -> - let toInterpField (typ, id) = - parens (separate comma_sp [dquotes (string (string_of_id id)); toValueTyp typ ("r." ^ (maybe_zencode ((if !lem_mode then string_of_id record_id ^ "_" else "") ^ string_of_id id)))]) - in - let toInterpValueName = concat [string (maybe_zencode (string_of_id record_id)); string "ToInterpValue"] in - let toInterpValue = - prefix 2 1 - (separate space [string "let"; toInterpValueName; separate space (toValueTypqs typq @ [string "r"]); equals]) - (separate space [string "V_record"; parens (separate space [string "List.fold_left (fun m (k, v) -> StringMap.add k v m) StringMap.empty"; (brackets (separate_map (semi ^^ space) toInterpField fields))])]) - in - toInterpValue ^^ (twice hardline) + let toInterpField (typ, id) = + parens + (separate comma_sp + [ + dquotes (string (string_of_id id)); + toValueTyp typ + ("r." ^ maybe_zencode ((if !lem_mode then string_of_id record_id ^ "_" else "") ^ string_of_id id)); + ] + ) + in + let toInterpValueName = concat [string (maybe_zencode (string_of_id record_id)); string "ToInterpValue"] in + let toInterpValue = + prefix 2 1 + (separate space [string "let"; toInterpValueName; separate space (toValueTypqs typq @ [string "r"]); equals]) + (separate space + [ + string "V_record"; + parens + (separate space + [ + string "List.fold_left (fun m (k, v) -> StringMap.add k v m) StringMap.empty"; + brackets (separate_map (semi ^^ space) toInterpField fields); + ] + ); + ] + ) + in + toInterpValue ^^ twice hardline | _ -> empty - -let tofrominterp_def (DEF_aux (aux, _)) = match aux with +let tofrominterp_def (DEF_aux (aux, _)) = + match aux with | DEF_type td -> group (frominterp_typedef td ^^ twice hardline ^^ tointerp_typedef td ^^ twice hardline) | _ -> empty let tofrominterp_ast name { defs; _ } = (string "open Sail_lib;;" ^^ hardline) ^^ (string "open Value;;" ^^ hardline) - ^^ (if !lem_mode then (string "open Sail2_instr_kinds;;" ^^ hardline) else empty) + ^^ (if !lem_mode then string "open Sail2_instr_kinds;;" ^^ hardline else empty) ^^ (string ("open " ^ String.capitalize_ascii name ^ ";;") ^^ hardline) - ^^ (if !lem_mode then (string ("open " ^ String.capitalize_ascii name ^ "_types;;") ^^ hardline) else empty) - ^^ (if !lem_mode then (string ("open " ^ String.capitalize_ascii name ^ "_extras;;") ^^ hardline) else empty) + ^^ (if !lem_mode then string ("open " ^ String.capitalize_ascii name ^ "_types;;") ^^ hardline else empty) + ^^ (if !lem_mode then string ("open " ^ String.capitalize_ascii name ^ "_extras;;") ^^ hardline else empty) ^^ (string "module Big_int = Nat_big_num" ^^ ocaml_def_end) - ^^ (if !mword_mode then (string "include ToFromInterp_lib_mword" ^^ hardline) else empty) - ^^ (if not !mword_mode then (string "include ToFromInterp_lib_bitlist.Make(struct type t = Sail2_values.bitU0 let b0 = Sail2_values.B00 let b1 = Sail2_values.B10 end)" ^^ hardline) else empty) + ^^ (if !mword_mode then string "include ToFromInterp_lib_mword" ^^ hardline else empty) + ^^ ( if not !mword_mode then + string + "include ToFromInterp_lib_bitlist.Make(struct type t = Sail2_values.bitU0 let b0 = Sail2_values.B00 let b1 \ + = Sail2_values.B10 end)" + ^^ hardline + else empty + ) ^^ concat (List.map tofrominterp_def defs) -let tofrominterp_pp_ast name f ast = - ToChannel.pretty 1. 80 f (tofrominterp_ast name ast) +let tofrominterp_pp_ast name f ast = ToChannel.pretty 1. 80 f (tofrominterp_ast name ast) let tofrominterp_output maybe_dir name ast = let dir = match maybe_dir with Some dir -> dir | None -> "." in diff --git a/src/sail_output/dune b/src/sail_output/dune index 9307deb30..4a0fa1a5e 100644 --- a/src/sail_output/dune +++ b/src/sail_output/dune @@ -1,10 +1,12 @@ (executable - (name sail_plugin_output) - (modes (native plugin)) - (libraries libsail)) + (name sail_plugin_output) + (modes + (native plugin)) + (libraries libsail)) (install - (section (site (libsail plugins))) - (package sail_output) - (files sail_plugin_output.cmxs)) - \ No newline at end of file + (section + (site + (libsail plugins))) + (package sail_output) + (files sail_plugin_output.cmxs)) diff --git a/src/sail_output/sail_plugin_output.ml b/src/sail_output/sail_plugin_output.ml index 79818ae9f..6d50976f6 100644 --- a/src/sail_output/sail_plugin_output.ml +++ b/src/sail_output/sail_plugin_output.ml @@ -67,23 +67,19 @@ open Libsail -let output_sail_options = [ - ( "-output_sail_dir", - Arg.String (fun dir -> Frontend.opt_reformat := Some dir), - " set a directory to output pretty-printed Sail"); -] +let output_sail_options = + [ + ( "-output_sail_dir", + Arg.String (fun dir -> Frontend.opt_reformat := Some dir), + " set a directory to output pretty-printed Sail" + ); + ] let sail_target _ out_file ast _effect_info _env = - let close, output_chan = match out_file with - | Some f -> true, open_out (f ^ ".sail") - | None -> false, stdout in - Pretty_print_sail.pp_ast output_chan (Type_check.strip_ast ast); - if close then close_out output_chan + let close, output_chan = match out_file with Some f -> (true, open_out (f ^ ".sail")) | None -> (false, stdout) in + Pretty_print_sail.pp_ast output_chan (Type_check.strip_ast ast); + if close then close_out output_chan let _ = - Target.register - ~name:"sail" - ~flag:"output_sail" - ~options:output_sail_options - ~description:" print Sail code after type checking and initial rewriting" - sail_target + Target.register ~name:"sail" ~flag:"output_sail" ~options:output_sail_options + ~description:" print Sail code after type checking and initial rewriting" sail_target diff --git a/src/sail_smt_backend/dune b/src/sail_smt_backend/dune index 02b04ae9d..7a0128ab8 100644 --- a/src/sail_smt_backend/dune +++ b/src/sail_smt_backend/dune @@ -1,15 +1,20 @@ (env - (dev - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) - (release - (flags (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) + (dev + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37))) + (release + (flags + (:standard -w -33 -w -27 -w -32 -w -26 -w -37)))) (executable - (name sail_plugin_smt) - (modes (native plugin)) - (libraries libsail)) + (name sail_plugin_smt) + (modes + (native plugin)) + (libraries libsail)) (install - (section (site (libsail plugins))) - (package sail_smt_backend) - (files sail_plugin_smt.cmxs)) + (section + (site + (libsail plugins))) + (package sail_smt_backend) + (files sail_plugin_smt.cmxs)) diff --git a/src/sail_smt_backend/jib_ir.ml b/src/sail_smt_backend/jib_ir.ml index b93b7f779..51cd7a990 100644 --- a/src/sail_smt_backend/jib_ir.ml +++ b/src/sail_smt_backend/jib_ir.ml @@ -76,20 +76,16 @@ open Printf let zencode_id id = Util.zencode_string (string_of_id id) -module StringMap = Map.Make(String) +module StringMap = Map.Make (String) let string_of_name = - let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in + let ssa_num n = if n = -1 then "" else "/" ^ string_of_int n in function | Name (id, n) | Global (id, n) -> zencode_id id ^ ssa_num n - | Have_exception n -> - "have_exception" ^ ssa_num n - | Return n -> - "return" ^ ssa_num n - | Current_exception n -> - "current_exception" ^ ssa_num n - | Throw_location n -> - "throw_location" ^ ssa_num n + | Have_exception n -> "have_exception" ^ ssa_num n + | Return n -> "return" ^ ssa_num n + | Current_exception n -> "current_exception" ^ ssa_num n + | Throw_location n -> "throw_location" ^ ssa_num n let rec string_of_clexp = function | CL_id (id, ctyp) -> string_of_name id @@ -121,8 +117,10 @@ module Ir_formatter = struct let file_number file_name = let rec scan n = function - | (f :: fs) -> if f = file_name then n else scan (n + 1) fs - | [] -> (file_map := !file_map @ [file_name]; n) + | f :: fs -> if f = file_name then n else scan (n + 1) fs + | [] -> + file_map := !file_map @ [file_name]; + n in scan 0 !file_map @@ -130,114 +128,112 @@ module Ir_formatter = struct match Reporting.simp_loc l with | None -> "`" | Some (p1, p2) -> - Printf.sprintf "%d %d:%d-%d:%d" - (file_number p1.pos_fname) p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum (p2.pos_cnum - p2.pos_bol) + Printf.sprintf "%d %d:%d-%d:%d" (file_number p1.pos_fname) p1.pos_lnum (p1.pos_cnum - p1.pos_bol) p2.pos_lnum + (p2.pos_cnum - p2.pos_bol) let output_files buf = Buffer.add_string buf (C.keyword "files"); - List.iter (fun file_name -> - Buffer.add_string buf (" \"" ^ file_name ^ "\"") - ) !file_map - + List.iter (fun file_name -> Buffer.add_string buf (" \"" ^ file_name ^ "\"")) !file_map + let rec output_instr n buf indent label_map (I_aux (instr, (_, l))) = match instr with | I_decl (ctyp, id) | I_reset (ctyp, id) -> - add_instr n buf indent (string_of_name id ^ " : " ^ C.typ ctyp ^ " `" ^ output_loc l) + add_instr n buf indent (string_of_name id ^ " : " ^ C.typ ctyp ^ " `" ^ output_loc l) | I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> - add_instr n buf indent (string_of_name id ^ " : " ^ C.typ ctyp ^ " = " ^ C.value cval ^ " `" ^ output_loc l) - | I_clear (ctyp, id) -> - add_instr n buf indent ("!" ^ string_of_name id) - | I_label label -> - C.output_label_instr buf label_map label + add_instr n buf indent (string_of_name id ^ " : " ^ C.typ ctyp ^ " = " ^ C.value cval ^ " `" ^ output_loc l) + | I_clear (ctyp, id) -> add_instr n buf indent ("!" ^ string_of_name id) + | I_label label -> C.output_label_instr buf label_map label | I_jump (cval, label) -> - add_instr n buf indent (C.keyword "jump" ^ " " ^ C.value cval ^ " " - ^ C.keyword "goto" ^ " " ^ C.string_of_label (StringMap.find label label_map) - ^ " `" ^ output_loc l) + add_instr n buf indent + (C.keyword "jump" ^ " " ^ C.value cval ^ " " ^ C.keyword "goto" ^ " " + ^ C.string_of_label (StringMap.find label label_map) + ^ " `" ^ output_loc l + ) | I_goto label -> - add_instr n buf indent (C.keyword "goto" ^ " " ^ C.string_of_label (StringMap.find label label_map)) - | I_exit cause -> - add_instr n buf indent (C.keyword "failure" ^ " " ^ cause) - | I_undefined _ -> - add_instr n buf indent (C.keyword "arbitrary") - | I_end _ -> - add_instr n buf indent (C.keyword "end") - | I_copy (clexp, cval) -> - add_instr n buf indent (string_of_clexp clexp ^ " = " ^ C.value cval) + add_instr n buf indent (C.keyword "goto" ^ " " ^ C.string_of_label (StringMap.find label label_map)) + | I_exit cause -> add_instr n buf indent (C.keyword "failure" ^ " " ^ cause) + | I_undefined _ -> add_instr n buf indent (C.keyword "arbitrary") + | I_end _ -> add_instr n buf indent (C.keyword "end") + | I_copy (clexp, cval) -> add_instr n buf indent (string_of_clexp clexp ^ " = " ^ C.value cval) | I_funcall (clexp, false, id, args) -> - add_instr n buf indent (string_of_clexp clexp ^ " = " ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")" ^ " `" ^ output_loc l) + add_instr n buf indent + (string_of_clexp clexp ^ " = " ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")" ^ " `" + ^ output_loc l + ) | I_funcall (clexp, true, id, args) -> - add_instr n buf indent (string_of_clexp clexp ^ " = $" ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")" ^ " `" ^ output_loc l) - | I_return cval -> - add_instr n buf indent (C.keyword "return" ^ " " ^ C.value cval) - | I_comment str -> - Buffer.add_string buf (String.make indent ' ' ^ "/*" ^ str ^ "*/\n") - | I_raw str -> - Buffer.add_string buf str - | I_throw cval -> - add_instr n buf indent (C.keyword "throw" ^ " " ^ C.value cval) - | I_if _ | I_block _ | I_try_block _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Can only format flat IR" + add_instr n buf indent + (string_of_clexp clexp ^ " = $" ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")" + ^ " `" ^ output_loc l + ) + | I_return cval -> add_instr n buf indent (C.keyword "return" ^ " " ^ C.value cval) + | I_comment str -> Buffer.add_string buf (String.make indent ' ' ^ "/*" ^ str ^ "*/\n") + | I_raw str -> Buffer.add_string buf str + | I_throw cval -> add_instr n buf indent (C.keyword "throw" ^ " " ^ C.value cval) + | I_if _ | I_block _ | I_try_block _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Can only format flat IR" and output_instrs n buf indent label_map = function | (I_aux (I_label _, _) as instr) :: instrs -> - output_instr n buf indent label_map instr; - output_instrs n buf indent label_map instrs + output_instr n buf indent label_map instr; + output_instrs n buf indent label_map instrs | instr :: instrs -> - output_instr n buf indent label_map instr; - output_instrs (n + 1) buf indent label_map instrs + output_instr n buf indent label_map instr; + output_instrs (n + 1) buf indent label_map instrs | [] -> () - let id_ctyp (id, ctyp) = - sprintf "%s: %s" (zencode_id id) (C.typ ctyp) + let id_ctyp (id, ctyp) = sprintf "%s: %s" (zencode_id id) (C.typ ctyp) - let uid_ctyp (uid, ctyp) = - sprintf "%s: %s" (string_of_uid uid) (C.typ ctyp) + let uid_ctyp (uid, ctyp) = sprintf "%s: %s" (string_of_uid uid) (C.typ ctyp) let output_def buf = function | CDEF_register (id, ctyp, _) -> - Buffer.add_string buf (sprintf "%s %s : %s" (C.keyword "register") (zencode_id id) (C.typ ctyp)) + Buffer.add_string buf (sprintf "%s %s : %s" (C.keyword "register") (zencode_id id) (C.typ ctyp)) | CDEF_val (id, None, ctyps, ctyp) -> - Buffer.add_string buf (sprintf "%s %s : (%s) -> %s" (C.keyword "val") (zencode_id id) (Util.string_of_list ", " C.typ ctyps) (C.typ ctyp)); + Buffer.add_string buf + (sprintf "%s %s : (%s) -> %s" (C.keyword "val") (zencode_id id) (Util.string_of_list ", " C.typ ctyps) + (C.typ ctyp) + ) | CDEF_val (id, Some extern, ctyps, ctyp) -> - Buffer.add_string buf (sprintf "%s %s = \"%s\" : (%s) -> %s" (C.keyword "val") (zencode_id id) extern (Util.string_of_list ", " C.typ ctyps) (C.typ ctyp)); + Buffer.add_string buf + (sprintf "%s %s = \"%s\" : (%s) -> %s" (C.keyword "val") (zencode_id id) extern + (Util.string_of_list ", " C.typ ctyps) (C.typ ctyp) + ) | CDEF_fundef (id, ret, args, instrs) -> - let instrs = C.modify_instrs instrs in - let label_map = C.make_label_map instrs in - let ret = match ret with - | None -> "" - | Some id -> " " ^ zencode_id id - in - Buffer.add_string buf (sprintf "%s %s%s(%s) {\n" (C.keyword "fn") (zencode_id id) ret (Util.string_of_list ", " zencode_id args)); - output_instrs 0 buf 2 label_map instrs; - Buffer.add_string buf "}" + let instrs = C.modify_instrs instrs in + let label_map = C.make_label_map instrs in + let ret = match ret with None -> "" | Some id -> " " ^ zencode_id id in + Buffer.add_string buf + (sprintf "%s %s%s(%s) {\n" (C.keyword "fn") (zencode_id id) ret (Util.string_of_list ", " zencode_id args)); + output_instrs 0 buf 2 label_map instrs; + Buffer.add_string buf "}" | CDEF_type (CTD_enum (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "enum") (zencode_id id) (Util.string_of_list ",\n " zencode_id ids)) + Buffer.add_string buf + (sprintf "%s %s {\n %s\n}" (C.keyword "enum") (zencode_id id) (Util.string_of_list ",\n " zencode_id ids)) | CDEF_type (CTD_struct (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "struct") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) + Buffer.add_string buf + (sprintf "%s %s {\n %s\n}" (C.keyword "struct") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) | CDEF_type (CTD_variant (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "union") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) + Buffer.add_string buf + (sprintf "%s %s {\n %s\n}" (C.keyword "union") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) | CDEF_let (_, bindings, instrs) -> - let instrs = C.modify_instrs instrs in - let label_map = C.make_label_map instrs in - Buffer.add_string buf (sprintf "%s (%s) {\n" (C.keyword "let") (Util.string_of_list ", " id_ctyp bindings)); - output_instrs 0 buf 2 label_map instrs; - Buffer.add_string buf "}" - | CDEF_startup _ | CDEF_finish _ -> - Reporting.unreachable Parse_ast.Unknown __POS__ "Unexpected startup / finish" + let instrs = C.modify_instrs instrs in + let label_map = C.make_label_map instrs in + Buffer.add_string buf (sprintf "%s (%s) {\n" (C.keyword "let") (Util.string_of_list ", " id_ctyp bindings)); + output_instrs 0 buf 2 label_map instrs; + Buffer.add_string buf "}" + | CDEF_startup _ | CDEF_finish _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Unexpected startup / finish" | CDEF_pragma _ -> () let rec output_defs' buf = function | def :: defs -> - output_def buf def; - Buffer.add_string buf "\n\n"; - output_defs' buf defs + output_def buf def; + Buffer.add_string buf "\n\n"; + output_defs' buf defs | [] -> () let output_defs buf defs = output_defs' buf defs; output_files buf; Buffer.add_string buf "\n\n" - end end @@ -254,10 +250,8 @@ module Flat_ir_config : Ir_formatter.Config = struct let make_label_map instrs = let rec make_label_map' n = function - | I_aux (I_label label, _) :: instrs -> - StringMap.add label n (make_label_map' n instrs) - | _ :: instrs -> - make_label_map' (n + 1) instrs + | I_aux (I_label label, _) :: instrs -> StringMap.add label n (make_label_map' n instrs) + | _ :: instrs -> make_label_map' (n + 1) instrs | [] -> StringMap.empty in make_label_map' 0 instrs @@ -265,30 +259,19 @@ module Flat_ir_config : Ir_formatter.Config = struct let modify_instrs instrs = let open Jib_optimize in reset_flat_counter (); - instrs - |> flatten_instrs - |> remove_clear - |> remove_dead_code + instrs |> flatten_instrs |> remove_clear |> remove_dead_code let string_of_label = string_of_int let output_label_instr buf _ label = () - let color f = - if !colored_ir then - f - else - (fun str -> str) - - let keyword str = - str |> color Util.red |> color Util.clear + let color f = if !colored_ir then f else fun str -> str - let typ str = - string_of_ctyp str |> color Util.yellow |> color Util.clear + let keyword str = str |> color Util.red |> color Util.clear - let value str = - string_of_cval str |> color Util.cyan |> color Util.clear + let typ str = string_of_ctyp str |> color Util.yellow |> color Util.clear + let value str = string_of_cval str |> color Util.cyan |> color Util.clear end -module Flat_ir_formatter = Ir_formatter.Make(Flat_ir_config) +module Flat_ir_formatter = Ir_formatter.Make (Flat_ir_config) diff --git a/src/sail_smt_backend/jib_smt.ml b/src/sail_smt_backend/jib_smt.ml index d7f613ff4..e26e57a49 100644 --- a/src/sail_smt_backend/jib_smt.ml +++ b/src/sail_smt_backend/jib_smt.ml @@ -76,8 +76,14 @@ open Jib_util open Smtlib open Property -module IntSet = Set.Make(struct type t = int let compare = compare end) -module IntMap = Map.Make(struct type t = int let compare = compare end) +module IntSet = Set.Make (struct + type t = int + let compare = compare +end) +module IntMap = Map.Make (struct + type t = int + let compare = compare +end) let zencode_upper_id id = Util.zencode_upper_string (string_of_id id) let zencode_id id = Util.zencode_string (string_of_id id) @@ -97,30 +103,30 @@ let opt_propagate_vars = ref false let opt_unroll_limit = ref 10 -module EventMap = Map.Make(Event) +module EventMap = Map.Make (Event) (* Note that we have to use x : ty ref rather than mutable x : ty, to make sure { ctx with x = ... } doesn't break the mutable state. *) (* See mli file for a description of each field *) type ctx = { - lbits_index : int; - lint_size : int; - vector_index : int; - register_map : id list CTMap.t; - tuple_sizes : IntSet.t ref; - tc_env : Type_check.Env.t; - pragma_l : Ast.l; - arg_stack : (int * string) Stack.t; - ast : Type_check.tannot ast; - shared : ctyp Bindings.t; - preserved : IdSet.t; - events : smt_exp Stack.t EventMap.t ref; - node : int; - pathcond : smt_exp Lazy.t; - use_string : bool ref; - use_real : bool ref - } + lbits_index : int; + lint_size : int; + vector_index : int; + register_map : id list CTMap.t; + tuple_sizes : IntSet.t ref; + tc_env : Type_check.Env.t; + pragma_l : Ast.l; + arg_stack : (int * string) Stack.t; + ast : Type_check.tannot ast; + shared : ctyp Bindings.t; + preserved : IdSet.t; + events : smt_exp Stack.t EventMap.t ref; + node : int; + pathcond : smt_exp Lazy.t; + use_string : bool ref; + use_real : bool ref; +} (* These give the default bounds for various SMT types, stored in the initial_ctx. They shouldn't be read or written by anything else! If @@ -130,7 +136,8 @@ let opt_default_lint_size = ref 128 let opt_default_lbits_index = ref 8 let opt_default_vector_index = ref 5 -let initial_ctx () = { +let initial_ctx () = + { lbits_index = !opt_default_lbits_index; lint_size = !opt_default_lint_size; vector_index = !opt_default_vector_index; @@ -153,16 +160,15 @@ let event_stack ctx ev = match EventMap.find_opt ev !(ctx.events) with | Some stack -> stack | None -> - let stack = Stack.create () in - ctx.events := EventMap.add ev stack !(ctx.events); - stack + let stack = Stack.create () in + ctx.events := EventMap.add ev stack !(ctx.events); + stack let add_event ctx ev smt = let stack = event_stack ctx ev in Stack.push (Fn ("and", [Lazy.force ctx.pathcond; smt])) stack -let add_pathcond_event ctx ev = - Stack.push (Lazy.force ctx.pathcond) (event_stack ctx ev) +let add_pathcond_event ctx ev = Stack.push (Lazy.force ctx.pathcond) (event_stack ctx ev) let overflow_check ctx smt = if not !opt_ignore_overflow then ( @@ -181,10 +187,7 @@ let smt_lbits ctx = mk_record "Bits" [("size", Bitvec ctx.lbits_index); ("bits", represent an integer n *) let required_width n = let rec required_width' n = - if Big_int.equal n Big_int.zero then - 1 - else - 1 + required_width' (Big_int.shift_right n 1) + if Big_int.equal n Big_int.zero then 1 else 1 + required_width' (Big_int.shift_right n 1) in required_width' (Big_int.abs n) @@ -198,34 +201,30 @@ let rec smt_ctyp ctx = function | CT_sbits (n, _) -> smt_lbits ctx | CT_lbits _ -> smt_lbits ctx | CT_bool -> Bool - | CT_enum (id, elems) -> - mk_enum (zencode_upper_id id) (List.map zencode_id elems) + | CT_enum (id, elems) -> mk_enum (zencode_upper_id id) (List.map zencode_id elems) | CT_struct (id, fields) -> - mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) fields) + mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) fields) | CT_variant (id, ctors) -> - mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors) + mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors) | CT_tup ctyps -> - ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes); - Tuple (List.map (smt_ctyp ctx) ctyps) + ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes); + Tuple (List.map (smt_ctyp ctx) ctyps) | CT_vector (_, ctyp) -> Array (Bitvec !vector_index, smt_ctyp ctx ctyp) | CT_string -> - ctx.use_string := true; - String + ctx.use_string := true; + String | CT_real -> - ctx.use_real := true; - Real - | CT_ref ctyp -> - begin match CTMap.find_opt ctyp ctx.register_map with - | Some regs -> Bitvec (required_width (Big_int.of_int (List.length regs))) - | _ -> failwith ("No registers with ctyp: " ^ string_of_ctyp ctyp) - end + ctx.use_real := true; + Real + | CT_ref ctyp -> begin + match CTMap.find_opt ctyp ctx.register_map with + | Some regs -> Bitvec (required_width (Big_int.of_int (List.length regs))) + | _ -> failwith ("No registers with ctyp: " ^ string_of_ctyp ctyp) + end | CT_list _ -> raise (Reporting.err_todo ctx.pragma_l "Lists not yet supported in SMT generation") - | CT_float _ | CT_rounding_mode -> - Reporting.unreachable ctx.pragma_l __POS__ "Floating point in SMT property" - | CT_fvector _ -> - Reporting.unreachable ctx.pragma_l __POS__ "Found CT_fvector in SMT property" - | CT_poly _ -> - Reporting.unreachable ctx.pragma_l __POS__ "Found polymorphic type in SMT property" + | CT_float _ | CT_rounding_mode -> Reporting.unreachable ctx.pragma_l __POS__ "Floating point in SMT property" + | CT_fvector _ -> Reporting.unreachable ctx.pragma_l __POS__ "Found CT_fvector in SMT property" + | CT_poly _ -> Reporting.unreachable ctx.pragma_l __POS__ "Found polymorphic type in SMT property" (* We often need to create a SMT bitvector of a length sz with integer value x. [bvpint sz x] does this for positive integers, and [bvint sz x] @@ -238,99 +237,92 @@ let bvpint sz x = let x = Big_int.to_int x in match Printf.sprintf "%X" x |> Util.string_to_list |> List.map nibble_of_char |> Util.option_all with | Some nibbles -> - let bin = List.map (fun (a, b, c, d) -> [a; b; c; d]) nibbles |> List.concat in - let _, bin = Util.take_drop (function B0 -> true | _ -> false) bin in - let padding = List.init (sz - List.length bin) (fun _ -> B0) in - Bitvec_lit (padding @ bin) + let bin = List.map (fun (a, b, c, d) -> [a; b; c; d]) nibbles |> List.concat in + let _, bin = Util.take_drop (function B0 -> true | _ -> false) bin in + let padding = List.init (sz - List.length bin) (fun _ -> B0) in + Bitvec_lit (padding @ bin) | None -> assert false - ) else if Big_int.greater x (Big_int.of_int max_int) then ( + ) + else if Big_int.greater x (Big_int.of_int max_int) then ( let y = ref x in let bin = ref [] in - while (not (Big_int.equal !y Big_int.zero)) do - let (q, m) = Big_int.quomod !y (Big_int.of_int 2) in + while not (Big_int.equal !y Big_int.zero) do + let q, m = Big_int.quomod !y (Big_int.of_int 2) in bin := (if Big_int.equal m Big_int.zero then B0 else B1) :: !bin; y := q done; let padding_size = sz - List.length !bin in if padding_size < 0 then - raise (Reporting.err_general Parse_ast.Unknown - (Printf.sprintf "Could not create a %d-bit integer with value %s.\nTry increasing the maximum integer size" - sz (Big_int.to_string x))); + raise + (Reporting.err_general Parse_ast.Unknown + (Printf.sprintf "Could not create a %d-bit integer with value %s.\nTry increasing the maximum integer size" + sz (Big_int.to_string x) + ) + ); let padding = List.init padding_size (fun _ -> B0) in Bitvec_lit (padding @ !bin) - ) else failwith "Invalid bvpint" + ) + else failwith "Invalid bvpint" let bvint sz x = if Big_int.less x Big_int.zero then Fn ("bvadd", [Fn ("bvnot", [bvpint sz (Big_int.abs x)]); bvpint sz (Big_int.of_int 1)]) - else - bvpint sz x + else bvpint sz x (** [force_size ctx n m exp] takes a smt expression assumed to be a integer (signed bitvector) of length m and forces it to be length n by either sign extending it or truncating it as required *) -let force_size ?checked:(checked=true) ctx n m smt = - if n = m then - smt - else if n > m then - SignExtend (n - m, smt) - else +let force_size ?(checked = true) ctx n m smt = + if n = m then smt + else if n > m then SignExtend (n - m, smt) + else ( let check = (* If the top bit of the truncated number is one *) - Ite (Fn ("=", [Extract (n - 1, n - 1, smt); Bitvec_lit [Sail2_values.B1]]), - (* Then we have an overflow, unless all bits we truncated were also one *) - Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvones (m - n)])]), - (* Otherwise, all the top bits must be zero *) - Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])])) + Ite + ( Fn ("=", [Extract (n - 1, n - 1, smt); Bitvec_lit [Sail2_values.B1]]), + (* Then we have an overflow, unless all bits we truncated were also one *) + Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvones (m - n)])]), + (* Otherwise, all the top bits must be zero *) + Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])]) + ) in if checked then overflow_check ctx check else (); Extract (n - 1, 0, smt) + ) (** [unsigned_size ctx n m exp] is much like force_size, but it assumes that the bitvector is unsigned *) -let unsigned_size ?checked:(checked=true) ctx n m smt = - if n = m then - smt - else if n > m then - Fn ("concat", [bvzero (n - m); smt]) - else - Extract (n - 1, 0, smt) +let unsigned_size ?(checked = true) ctx n m smt = + if n = m then smt else if n > m then Fn ("concat", [bvzero (n - m); smt]) else Extract (n - 1, 0, smt) let smt_conversion ctx from_ctyp to_ctyp x = - match from_ctyp, to_ctyp with + match (from_ctyp, to_ctyp) with | _, _ when ctyp_equal from_ctyp to_ctyp -> x - | CT_constant c, CT_fint sz -> - bvint sz c - | CT_constant c, CT_lint -> - bvint ctx.lint_size c - | CT_fint sz, CT_lint -> - force_size ctx ctx.lint_size sz x - | CT_lint, CT_fint sz -> - force_size ctx sz ctx.lint_size x - | CT_lint, CT_fbits (n, _) -> - force_size ctx n ctx.lint_size x + | CT_constant c, CT_fint sz -> bvint sz c + | CT_constant c, CT_lint -> bvint ctx.lint_size c + | CT_fint sz, CT_lint -> force_size ctx ctx.lint_size sz x + | CT_lint, CT_fint sz -> force_size ctx sz ctx.lint_size x + | CT_lint, CT_fbits (n, _) -> force_size ctx n ctx.lint_size x | CT_lint, CT_lbits _ -> - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int ctx.lint_size); force_size ctx (lbits_size ctx) ctx.lint_size x]) - | CT_fint n, CT_lbits _ -> - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); force_size ctx (lbits_size ctx) n x]) - | CT_lbits _, CT_fbits (n, _) -> - unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [x])) - | CT_fbits (n, _), CT_fbits (m, _) -> - unsigned_size ctx m n x + Fn + ("Bits", [bvint ctx.lbits_index (Big_int.of_int ctx.lint_size); force_size ctx (lbits_size ctx) ctx.lint_size x]) + | CT_fint n, CT_lbits _ -> Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); force_size ctx (lbits_size ctx) n x]) + | CT_lbits _, CT_fbits (n, _) -> unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [x])) + | CT_fbits (n, _), CT_fbits (m, _) -> unsigned_size ctx m n x | CT_fbits (n, _), CT_lbits _ -> - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); unsigned_size ctx (lbits_size ctx) n x]) - - | _, _ -> failwith (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) + Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); unsigned_size ctx (lbits_size ctx) n x]) + | _, _ -> + failwith + (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) (* Translate Jib literals into SMT *) let smt_value ctx vl ctyp = let open Value2 in - match vl, ctyp with - | VL_bits (bv, true), CT_fbits (n, _) -> - unsigned_size ctx n (List.length bv) (Bitvec_lit bv) + match (vl, ctyp) with + | VL_bits (bv, true), CT_fbits (n, _) -> unsigned_size ctx n (List.length bv) (Bitvec_lit bv) | VL_bits (bv, true), CT_lbits _ -> - let sz = List.length bv in - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int sz); unsigned_size ctx (lbits_size ctx) sz (Bitvec_lit bv)]) + let sz = List.length bv in + Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int sz); unsigned_size ctx (lbits_size ctx) sz (Bitvec_lit bv)]) | VL_bool b, _ -> Bool_lit b | VL_int n, CT_constant m -> bvint (required_width n) n | VL_int n, CT_fint sz -> bvint sz n @@ -338,87 +330,74 @@ let smt_value ctx vl ctyp = | VL_bit b, CT_bit -> Bitvec_lit [b] | VL_unit, _ -> Enum "unit" | VL_string str, _ -> - ctx.use_string := true; - String_lit (String.escaped str) + ctx.use_string := true; + String_lit (String.escaped str) | VL_real str, _ -> - ctx.use_real := true; - if str.[0] = '-' then - Fn ("-", [Real_lit (String.sub str 1 (String.length str - 1))]) - else - Real_lit str + ctx.use_real := true; + if str.[0] = '-' then Fn ("-", [Real_lit (String.sub str 1 (String.length str - 1))]) else Real_lit str | VL_enum str, _ -> Enum (Util.zencode_string str) | VL_ref reg_name, _ -> - let id = mk_id reg_name in - let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in - assert (CTMap.cardinal rmap = 1); - begin match CTMap.min_binding_opt rmap with - | Some (ctyp, regs) -> - begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with - | Some i -> - bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) - | None -> assert false - end - | _ -> assert false - end + let id = mk_id reg_name in + let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in + assert (CTMap.cardinal rmap = 1); + begin + match CTMap.min_binding_opt rmap with + | Some (ctyp, regs) -> begin + match Util.list_index (fun reg -> Id.compare reg id = 0) regs with + | Some i -> bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) + | None -> assert false + end + | _ -> assert false + end | _ -> failwith ("Cannot translate literal to SMT: " ^ string_of_value vl ^ " : " ^ string_of_ctyp ctyp) let rec smt_cval ctx cval = match cval_ctyp cval with - | CT_constant n -> - bvint (required_width n) n - | _ -> - match cval with - | V_lit (vl, ctyp) -> smt_value ctx vl ctyp - | V_id ((Name (id, _) | Global (id, _)) as ssa_id, _) -> - begin match Type_check.Env.lookup_id id ctx.tc_env with - | Enum _ -> Enum (zencode_id id) - | _ when Bindings.mem id ctx.shared -> Shared (zencode_id id) - | _ -> Var (zencode_name ssa_id) + | CT_constant n -> bvint (required_width n) n + | _ -> ( + match cval with + | V_lit (vl, ctyp) -> smt_value ctx vl ctyp + | V_id (((Name (id, _) | Global (id, _)) as ssa_id), _) -> begin + match Type_check.Env.lookup_id id ctx.tc_env with + | Enum _ -> Enum (zencode_id id) + | _ when Bindings.mem id ctx.shared -> Shared (zencode_id id) + | _ -> Var (zencode_name ssa_id) end - | V_id (ssa_id, _) -> Var (zencode_name ssa_id) - | V_call (Neq, [cval1; cval2]) -> - Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])]) - | V_call (Bvor, [cval1; cval2]) -> - Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Eq, [cval1; cval2]) -> - Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Bnot, [cval]) -> - Fn ("not", [smt_cval ctx cval]) - | V_call (Band, cvals) -> - smt_conj (List.map (smt_cval ctx) cvals) - | V_call (Bor, cvals) -> - smt_disj (List.map (smt_cval ctx) cvals) - | V_call (Igt, [cval1; cval2]) -> - Fn ("bvsgt", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Iadd, [cval1; cval2]) -> - Fn ("bvadd", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_ctor_kind (union, ctor, _) -> - Fn ("not", [Tester (zencode_uid ctor, smt_cval ctx union)]) - | V_ctor_unwrap (union, ctor, _) -> - Fn ("un" ^ zencode_uid ctor, [smt_cval ctx union]) - | V_field (record, field) -> - begin match cval_ctyp record with - | CT_struct (struct_id, _) -> - Field (zencode_upper_id struct_id ^ "_" ^ zencode_id field, smt_cval ctx record) - | _ -> failwith "Field for non-struct type" + | V_id (ssa_id, _) -> Var (zencode_name ssa_id) + | V_call (Neq, [cval1; cval2]) -> Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])]) + | V_call (Bvor, [cval1; cval2]) -> Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_call (Eq, [cval1; cval2]) -> Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_call (Bnot, [cval]) -> Fn ("not", [smt_cval ctx cval]) + | V_call (Band, cvals) -> smt_conj (List.map (smt_cval ctx) cvals) + | V_call (Bor, cvals) -> smt_disj (List.map (smt_cval ctx) cvals) + | V_call (Igt, [cval1; cval2]) -> Fn ("bvsgt", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_call (Iadd, [cval1; cval2]) -> Fn ("bvadd", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_ctor_kind (union, ctor, _) -> Fn ("not", [Tester (zencode_uid ctor, smt_cval ctx union)]) + | V_ctor_unwrap (union, ctor, _) -> Fn ("un" ^ zencode_uid ctor, [smt_cval ctx union]) + | V_field (record, field) -> begin + match cval_ctyp record with + | CT_struct (struct_id, _) -> Field (zencode_upper_id struct_id ^ "_" ^ zencode_id field, smt_cval ctx record) + | _ -> failwith "Field for non-struct type" end - | V_struct (fields, ctyp) -> - begin match ctyp with - | CT_struct (struct_id, field_ctyps) -> - let set_field (field, cval) = - match Util.assoc_compare_opt Id.compare field field_ctyps with - | None -> failwith "Field type not found" - | Some ctyp -> - zencode_upper_id struct_id ^ "_" ^ zencode_id field, - smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval) - in - Struct (zencode_upper_id struct_id, List.map set_field fields) - | _ -> failwith "Struct does not have struct type" + | V_struct (fields, ctyp) -> begin + match ctyp with + | CT_struct (struct_id, field_ctyps) -> + let set_field (field, cval) = + match Util.assoc_compare_opt Id.compare field field_ctyps with + | None -> failwith "Field type not found" + | Some ctyp -> + ( zencode_upper_id struct_id ^ "_" ^ zencode_id field, + smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval) + ) + in + Struct (zencode_upper_id struct_id, List.map set_field fields) + | _ -> failwith "Struct does not have struct type" end - | V_tuple_member (frag, len, n) -> - ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); - Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) - | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) + | V_tuple_member (frag, len, n) -> + ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); + Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) + | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) + ) (**************************************************************************) (* 1. Generating SMT for Sail builtins *) @@ -428,38 +407,28 @@ let builtin_type_error ctx fn cvals = let args = Util.string_of_list ", " (fun cval -> string_of_ctyp (cval_ctyp cval)) cvals in function | Some ret_ctyp -> - let message = Printf.sprintf "%s : (%s) -> %s" fn args (string_of_ctyp ret_ctyp) in - raise (Reporting.err_todo ctx.pragma_l message) - | None -> - raise (Reporting.err_todo ctx.pragma_l (Printf.sprintf "%s : (%s)" fn args)) + let message = Printf.sprintf "%s : (%s) -> %s" fn args (string_of_ctyp ret_ctyp) in + raise (Reporting.err_todo ctx.pragma_l message) + | None -> raise (Reporting.err_todo ctx.pragma_l (Printf.sprintf "%s : (%s)" fn args)) (* ***** Basic comparisons: lib/flow.sail ***** *) let builtin_int_comparison fn big_int_fn ctx v1 v2 = - match cval_ctyp v1, cval_ctyp v2 with - | CT_lint, CT_lint -> - Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) + match (cval_ctyp v1, cval_ctyp v2) with + | CT_lint, CT_lint -> Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) | CT_fint sz1, CT_fint sz2 -> - if sz1 == sz2 then - Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) - else if sz1 > sz2 then - Fn (fn, [smt_cval ctx v1; SignExtend (sz1 - sz2, smt_cval ctx v2)]) - else - Fn (fn, [SignExtend (sz2 - sz1, smt_cval ctx v1); smt_cval ctx v2]) - | CT_constant c, CT_fint sz -> - Fn (fn, [bvint sz c; smt_cval ctx v2]) - | CT_constant c, CT_lint -> - Fn (fn, [bvint ctx.lint_size c; smt_cval ctx v2]) - | CT_fint sz, CT_constant c -> - Fn (fn, [smt_cval ctx v1; bvint sz c]) + if sz1 == sz2 then Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) + else if sz1 > sz2 then Fn (fn, [smt_cval ctx v1; SignExtend (sz1 - sz2, smt_cval ctx v2)]) + else Fn (fn, [SignExtend (sz2 - sz1, smt_cval ctx v1); smt_cval ctx v2]) + | CT_constant c, CT_fint sz -> Fn (fn, [bvint sz c; smt_cval ctx v2]) + | CT_constant c, CT_lint -> Fn (fn, [bvint ctx.lint_size c; smt_cval ctx v2]) + | CT_fint sz, CT_constant c -> Fn (fn, [smt_cval ctx v1; bvint sz c]) | CT_fint sz, CT_lint when sz < ctx.lint_size -> - Fn (fn, [SignExtend (ctx.lint_size - sz, smt_cval ctx v1); smt_cval ctx v2]) + Fn (fn, [SignExtend (ctx.lint_size - sz, smt_cval ctx v1); smt_cval ctx v2]) | CT_lint, CT_fint sz when sz < ctx.lint_size -> - Fn (fn, [smt_cval ctx v1; SignExtend (ctx.lint_size - sz, smt_cval ctx v2)]) - | CT_lint, CT_constant c -> - Fn (fn, [smt_cval ctx v1; bvint ctx.lint_size c]) - | CT_constant c1, CT_constant c2 -> - Bool_lit (big_int_fn c1 c2) + Fn (fn, [smt_cval ctx v1; SignExtend (ctx.lint_size - sz, smt_cval ctx v2)]) + | CT_lint, CT_constant c -> Fn (fn, [smt_cval ctx v1; bvint ctx.lint_size c]) + | CT_constant c1, CT_constant c2 -> Bool_lit (big_int_fn c1 c2) | _, _ -> builtin_type_error ctx fn [v1; v2] None let builtin_eq_int = builtin_int_comparison "=" Big_int.equal @@ -482,19 +451,23 @@ let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp = to some size determined by a padding function, then check we don't lose precision when going back after performing the operation. *) - let padding = if !opt_ignore_overflow then (fun x -> x) else padding in - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with - | _, _, CT_constant c -> - bvint (required_width c) c - | CT_constant c1, CT_constant c2, _ -> - bvint (int_size ctx ret_ctyp) (big_int_fn c1 c2) - + let padding = if !opt_ignore_overflow then fun x -> x else padding in + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with + | _, _, CT_constant c -> bvint (required_width c) c + | CT_constant c1, CT_constant c2, _ -> bvint (int_size ctx ret_ctyp) (big_int_fn c1 c2) | ctyp1, ctyp2, _ -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - force_size ctx ret_sz (padding ret_sz) (Fn (fn, [force_size ctx (padding ret_sz) (int_size ctx ctyp1) smt1; - force_size ctx (padding ret_sz) (int_size ctx ctyp2) smt2])) + let ret_sz = int_size ctx ret_ctyp in + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + force_size ctx ret_sz (padding ret_sz) + (Fn + ( fn, + [ + force_size ctx (padding ret_sz) (int_size ctx ctyp1) smt1; + force_size ctx (padding ret_sz) (int_size ctx ctyp2) smt2; + ] + ) + ) let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1) let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) @@ -502,111 +475,91 @@ let builtin_mult_int = builtin_arith "bvmul" Big_int.mul (fun x -> x * 2) let builtin_sub_nat ctx v1 v2 ret_ctyp = let result = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) ctx v1 v2 ret_ctyp in - Ite (Fn ("bvslt", [result; bvint (int_size ctx ret_ctyp) Big_int.zero]), - bvint (int_size ctx ret_ctyp) Big_int.zero, - result) + Ite + ( Fn ("bvslt", [result; bvint (int_size ctx ret_ctyp) Big_int.zero]), + bvint (int_size ctx ret_ctyp) Big_int.zero, + result + ) let builtin_negate_int ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with - | _, CT_constant c -> - bvint (required_width c) c - | CT_constant c, _ -> - bvint (int_size ctx ret_ctyp) (Big_int.negate c) + match (cval_ctyp v, ret_ctyp) with + | _, CT_constant c -> bvint (required_width c) c + | CT_constant c, _ -> bvint (int_size ctx ret_ctyp) (Big_int.negate c) | ctyp, _ -> - let open Sail2_values in - let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in - overflow_check ctx (Fn ("=", [smt; Bitvec_lit (B1 :: List.init (int_size ctx ret_ctyp - 1) (fun _ -> B0))])); - Fn ("bvneg", [smt]) + let open Sail2_values in + let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in + overflow_check ctx (Fn ("=", [smt; Bitvec_lit (B1 :: List.init (int_size ctx ret_ctyp - 1) (fun _ -> B0))])); + Fn ("bvneg", [smt]) let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with - | _, _, CT_constant c -> - bvint (required_width c) c - | CT_constant c1, CT_constant c2, _ -> - bvint (int_size ctx ret_ctyp) (big_int_fn c1 (Big_int.to_int c2)) - + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with + | _, _, CT_constant c -> bvint (required_width c) c + | CT_constant c1, CT_constant c2, _ -> bvint (int_size ctx ret_ctyp) (big_int_fn c1 (Big_int.to_int c2)) | ctyp, CT_constant c, _ -> - let n = int_size ctx ctyp in - force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) + let n = int_size ctx ctyp in + force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) | CT_constant c, ctyp, _ -> - let n = int_size ctx ctyp in - force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) - + let n = int_size ctx ctyp in + force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) | ctyp1, ctyp2, _ -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - (Fn (fn, [force_size ctx ret_sz (int_size ctx ctyp1) smt1; - force_size ctx ret_sz (int_size ctx ctyp2) smt2])) + let ret_sz = int_size ctx ret_ctyp in + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn (fn, [force_size ctx ret_sz (int_size ctx ctyp1) smt1; force_size ctx ret_sz (int_size ctx ctyp2) smt2]) let builtin_shl_int = builtin_shift_int "bvshl" Big_int.shift_left let builtin_shr_int = builtin_shift_int "bvashr" Big_int.shift_right let builtin_abs_int ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with - | _, CT_constant c -> - bvint (required_width c) c - | CT_constant c, _ -> - bvint (int_size ctx ret_ctyp) (Big_int.abs c) + match (cval_ctyp v, ret_ctyp) with + | _, CT_constant c -> bvint (required_width c) c + | CT_constant c, _ -> bvint (int_size ctx ret_ctyp) (Big_int.abs c) | ctyp, _ -> - let sz = int_size ctx ctyp in - let smt = smt_cval ctx v in - Ite (Fn ("=", [Extract (sz - 1, sz -1, smt); Bitvec_lit [Sail2_values.B1]]), + let sz = int_size ctx ctyp in + let smt = smt_cval ctx v in + Ite + ( Fn ("=", [Extract (sz - 1, sz - 1, smt); Bitvec_lit [Sail2_values.B1]]), force_size ctx (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])), - force_size ctx (int_size ctx ret_ctyp) sz smt) + force_size ctx (int_size ctx ret_ctyp) sz smt + ) let builtin_pow2 ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with + match (cval_ctyp v, ret_ctyp) with | CT_constant n, _ when Big_int.greater_equal n Big_int.zero -> - bvint (int_size ctx ret_ctyp) (Big_int.pow_int_positive 2 (Big_int.to_int n)) - + bvint (int_size ctx ret_ctyp) (Big_int.pow_int_positive 2 (Big_int.to_int n)) | _ -> builtin_type_error ctx "pow2" [v] (Some ret_ctyp) let builtin_max_int ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2 with - | CT_constant n, CT_constant m -> - bvint (int_size ctx ret_ctyp) (max n m) - + match (cval_ctyp v1, cval_ctyp v2) with + | CT_constant n, CT_constant m -> bvint (int_size ctx ret_ctyp) (max n m) | ctyp1, ctyp2 -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in - let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in - Ite (Fn ("bvslt", [smt1; smt2]), - smt2, - smt1) + let ret_sz = int_size ctx ret_ctyp in + let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in + let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in + Ite (Fn ("bvslt", [smt1; smt2]), smt2, smt1) let builtin_min_int ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2 with - | CT_constant n, CT_constant m -> - bvint (int_size ctx ret_ctyp) (min n m) - + match (cval_ctyp v1, cval_ctyp v2) with + | CT_constant n, CT_constant m -> bvint (int_size ctx ret_ctyp) (min n m) | ctyp1, ctyp2 -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in - let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in - Ite (Fn ("bvslt", [smt1; smt2]), - smt1, - smt2) + let ret_sz = int_size ctx ret_ctyp in + let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in + let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in + Ite (Fn ("bvslt", [smt1; smt2]), smt1, smt2) let builtin_min_int ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2 with - | CT_constant n, CT_constant m -> - bvint (int_size ctx ret_ctyp) (min n m) - + match (cval_ctyp v1, cval_ctyp v2) with + | CT_constant n, CT_constant m -> bvint (int_size ctx ret_ctyp) (min n m) | ctyp1, ctyp2 -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in - let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in - Ite (Fn ("bvslt", [smt1; smt2]), - smt1, - smt2) - -let builtin_tdiv_int = - builtin_arith "bvudiv" (Sail2_values.tdiv_int) (fun x -> x) - -let builtin_tmod_int = - builtin_arith "bvurem" (Sail2_values.tmod_int) (fun x -> x) - + let ret_sz = int_size ctx ret_ctyp in + let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in + let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in + Ite (Fn ("bvslt", [smt1; smt2]), smt1, smt2) + +let builtin_tdiv_int = builtin_arith "bvudiv" Sail2_values.tdiv_int (fun x -> x) + +let builtin_tmod_int = builtin_arith "bvurem" Sail2_values.tmod_int (fun x -> x) + let bvmask ctx len = let all_ones = bvones (lbits_size ctx) in let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); len]) in @@ -615,45 +568,45 @@ let bvmask ctx len = let fbits_mask ctx n len = bvnot (bvshl (bvones n) len) let builtin_eq_bits ctx v1 v2 = - match cval_ctyp v1, cval_ctyp v2 with + match (cval_ctyp v1, cval_ctyp v2) with | CT_fbits (n, _), CT_fbits (m, _) -> - let o = max n m in - let smt1 = unsigned_size ctx o n (smt_cval ctx v1) in - let smt2 = unsigned_size ctx o n (smt_cval ctx v2) in - Fn ("=", [smt1; smt2]) - + let o = max n m in + let smt1 = unsigned_size ctx o n (smt_cval ctx v1) in + let smt2 = unsigned_size ctx o n (smt_cval ctx v2) in + Fn ("=", [smt1; smt2]) | CT_lbits _, CT_lbits _ -> - let len1 = Fn ("len", [smt_cval ctx v1]) in - let contents1 = Fn ("contents", [smt_cval ctx v1]) in - let len2 = Fn ("len", [smt_cval ctx v1]) in - let contents2 = Fn ("contents", [smt_cval ctx v1]) in - Fn ("and", [Fn ("=", [len1; len2]); - Fn ("=", [Fn ("bvand", [bvmask ctx len1; contents1]); Fn ("bvand", [bvmask ctx len2; contents2])])]) - + let len1 = Fn ("len", [smt_cval ctx v1]) in + let contents1 = Fn ("contents", [smt_cval ctx v1]) in + let len2 = Fn ("len", [smt_cval ctx v1]) in + let contents2 = Fn ("contents", [smt_cval ctx v1]) in + Fn + ( "and", + [ + Fn ("=", [len1; len2]); + Fn ("=", [Fn ("bvand", [bvmask ctx len1; contents1]); Fn ("bvand", [bvmask ctx len2; contents2])]); + ] + ) | CT_lbits _, CT_fbits (n, _) -> - let smt1 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v1])) in - Fn ("=", [smt1; smt_cval ctx v2]) - + let smt1 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v1])) in + Fn ("=", [smt1; smt_cval ctx v2]) | CT_fbits (n, _), CT_lbits _ -> - let smt2 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v2])) in - Fn ("=", [smt_cval ctx v1; smt2]) - + let smt2 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v2])) in + Fn ("=", [smt_cval ctx v1; smt2]) | _ -> builtin_type_error ctx "eq_bits" [v1; v2] None let builtin_zeros ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with + match (cval_ctyp v, ret_ctyp) with | _, CT_fbits (n, _) -> bvzero n - | CT_constant c, CT_lbits _ -> - Fn ("Bits", [bvint ctx.lbits_index c; bvzero (lbits_size ctx)]) + | CT_constant c, CT_lbits _ -> Fn ("Bits", [bvint ctx.lbits_index c; bvzero (lbits_size ctx)]) | ctyp, CT_lbits _ when int_size ctx ctyp >= ctx.lbits_index -> - Fn ("Bits", [extract (ctx.lbits_index - 1) 0 (smt_cval ctx v); bvzero (lbits_size ctx)]) + Fn ("Bits", [extract (ctx.lbits_index - 1) 0 (smt_cval ctx v); bvzero (lbits_size ctx)]) | _ -> builtin_type_error ctx "zeros" [v] (Some ret_ctyp) let builtin_ones ctx cval = function | CT_fbits (n, _) -> bvones n | CT_lbits _ -> - let len = extract (ctx.lbits_index - 1) 0 (smt_cval ctx cval) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvones (lbits_size ctx)])]); + let len = extract (ctx.lbits_index - 1) 0 (smt_cval ctx cval) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvones (lbits_size ctx)])]) | ret_ctyp -> builtin_type_error ctx "ones" [cval] (Some ret_ctyp) (* [bvzeint ctx esz cval] (BitVector Zero Extend INTeger), takes a cval @@ -663,91 +616,72 @@ let builtin_ones ctx cval = function let bvzeint ctx esz cval = let sz = int_size ctx (cval_ctyp cval) in match cval with - | V_lit (VL_int n, _) -> - bvint esz n + | V_lit (VL_int n, _) -> bvint esz n | _ -> - let smt = smt_cval ctx cval in - if esz = sz then - smt - else if esz > sz then - Fn ("concat", [bvzero (esz - sz); smt]) - else - Extract (esz - 1, 0, smt) + let smt = smt_cval ctx cval in + if esz = sz then smt else if esz > sz then Fn ("concat", [bvzero (esz - sz); smt]) else Extract (esz - 1, 0, smt) let builtin_zero_extend ctx vbits vlen ret_ctyp = - match cval_ctyp vbits, ret_ctyp with - | CT_fbits (n, _), CT_fbits (m, _) when n = m -> - smt_cval ctx vbits + match (cval_ctyp vbits, ret_ctyp) with + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> smt_cval ctx vbits | CT_fbits (n, _), CT_fbits (m, _) -> - let bv = smt_cval ctx vbits in - Fn ("concat", [bvzero (m - n); bv]) + let bv = smt_cval ctx vbits in + Fn ("concat", [bvzero (m - n); bv]) | CT_lbits _, CT_fbits (m, _) -> - assert (lbits_size ctx >= m); - Extract (m - 1, 0, Fn ("contents", [smt_cval ctx vbits])) + assert (lbits_size ctx >= m); + Extract (m - 1, 0, Fn ("contents", [smt_cval ctx vbits])) | CT_fbits (n, _), CT_lbits _ -> - assert (lbits_size ctx >= n); - let vbits = - if lbits_size ctx = n then smt_cval ctx vbits else - if lbits_size ctx > n then Fn ("concat", [bvzero (lbits_size ctx - n); smt_cval ctx vbits]) else - assert false - in - Fn ("Bits", [bvzeint ctx ctx.lbits_index vlen; vbits]) - + assert (lbits_size ctx >= n); + let vbits = + if lbits_size ctx = n then smt_cval ctx vbits + else if lbits_size ctx > n then Fn ("concat", [bvzero (lbits_size ctx - n); smt_cval ctx vbits]) + else assert false + in + Fn ("Bits", [bvzeint ctx ctx.lbits_index vlen; vbits]) | _ -> builtin_type_error ctx "zero_extend" [vbits; vlen] (Some ret_ctyp) let builtin_sign_extend ctx vbits vlen ret_ctyp = - match cval_ctyp vbits, ret_ctyp with - | CT_fbits (n, _), CT_fbits (m, _) when n = m -> - smt_cval ctx vbits + match (cval_ctyp vbits, ret_ctyp) with + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> smt_cval ctx vbits | CT_fbits (n, _), CT_fbits (m, _) -> - let bv = smt_cval ctx vbits in - let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bitvec_lit [Sail2_values.B1]]) in - Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv])) - + let bv = smt_cval ctx vbits in + let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bitvec_lit [Sail2_values.B1]]) in + Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv])) | _ -> builtin_type_error ctx "sign_extend" [vbits; vlen] (Some ret_ctyp) let builtin_shift shiftop ctx vbits vshift ret_ctyp = match cval_ctyp vbits with | CT_fbits (n, _) -> - let bv = smt_cval ctx vbits in - let len = bvzeint ctx n vshift in - Fn (shiftop, [bv; len]) - + let bv = smt_cval ctx vbits in + let len = bvzeint ctx n vshift in + Fn (shiftop, [bv; len]) | CT_lbits _ -> - let bv = smt_cval ctx vbits in - let shift = bvzeint ctx (lbits_size ctx) vshift in - Fn ("Bits", [Fn ("len", [bv]); Fn (shiftop, [Fn ("contents", [bv]); shift])]) - + let bv = smt_cval ctx vbits in + let shift = bvzeint ctx (lbits_size ctx) vshift in + Fn ("Bits", [Fn ("len", [bv]); Fn (shiftop, [Fn ("contents", [bv]); shift])]) | _ -> builtin_type_error ctx shiftop [vbits; vshift] (Some ret_ctyp) let builtin_not_bits ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with - | CT_lbits _, CT_fbits (n, _) -> - bvnot (Extract (n - 1, 0, Fn ("contents", [smt_cval ctx v]))) - + match (cval_ctyp v, ret_ctyp) with + | CT_lbits _, CT_fbits (n, _) -> bvnot (Extract (n - 1, 0, Fn ("contents", [smt_cval ctx v]))) | CT_lbits _, CT_lbits _ -> - let bv = smt_cval ctx v in - let len = Fn ("len", [bv]) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvnot (Fn ("contents", [bv]))])]) - - | CT_fbits (n, _), CT_fbits (m, _) when n = m -> - bvnot (smt_cval ctx v) - + let bv = smt_cval ctx v in + let len = Fn ("len", [bv]) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvnot (Fn ("contents", [bv]))])]) + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> bvnot (smt_cval ctx v) | _, _ -> builtin_type_error ctx "not_bits" [v] (Some ret_ctyp) let builtin_bitwise fn ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> - assert (n = m && m = o); - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn (fn, [smt1; smt2]) - + assert (n = m && m = o); + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn (fn, [smt1; smt2]) | CT_lbits _, CT_lbits _, CT_lbits _ -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("Bits", [Fn ("len", [smt1]); Fn (fn, [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn ("Bits", [Fn ("len", [smt1]); Fn (fn, [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) | _ -> builtin_type_error ctx fn [v1; v2] (Some ret_ctyp) let builtin_and_bits = builtin_bitwise "bvand" @@ -755,365 +689,315 @@ let builtin_or_bits = builtin_bitwise "bvor" let builtin_xor_bits = builtin_bitwise "bvxor" let builtin_append ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> - assert (n + m = o); - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("concat", [smt1; smt2]) - + assert (n + m = o); + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn ("concat", [smt1; smt2]) | CT_fbits (n, _), CT_lbits _, CT_lbits _ -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("concat", [bvzero (lbits_size ctx - n); smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - Fn ("Bits", [bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt2])); - bvor (bvshl x shift) (Fn ("contents", [smt2]))]) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + let x = Fn ("concat", [bvzero (lbits_size ctx - n); smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in + Fn + ( "Bits", + [ + bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt2])); + bvor (bvshl x shift) (Fn ("contents", [smt2])); + ] + ) | CT_lbits _, CT_fbits (n, _), CT_fbits (m, _) -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Extract (m - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Extract (m - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])) | CT_lbits _, CT_fbits (n, _), CT_lbits _ -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("Bits", [bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt1])); - Extract (lbits_size ctx - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2]))]) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn + ( "Bits", + [ + bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt1])); + Extract (lbits_size ctx - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])); + ] + ) | CT_fbits (n, _), CT_fbits (m, _), CT_lbits _ -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int (n + m)); - unsigned_size ctx (lbits_size ctx) (n + m) (Fn ("concat", [smt1; smt2]))]) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn + ( "Bits", + [ + bvint ctx.lbits_index (Big_int.of_int (n + m)); + unsigned_size ctx (lbits_size ctx) (n + m) (Fn ("concat", [smt1; smt2])); + ] + ) | CT_lbits _, CT_lbits _, CT_lbits _ -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("contents", [smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + let x = Fn ("contents", [smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in + Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) | CT_lbits _, CT_lbits _, CT_fbits (n, _) -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("contents", [smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - unsigned_size ctx n (lbits_size ctx) (bvor (bvshl x shift) (Fn ("contents", [smt2]))) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + let x = Fn ("contents", [smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in + unsigned_size ctx n (lbits_size ctx) (bvor (bvshl x shift) (Fn ("contents", [smt2]))) | _ -> builtin_type_error ctx "append" [v1; v2] (Some ret_ctyp) let builtin_length ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with - | CT_fbits (n, _), (CT_constant _ | CT_fint _ | CT_lint) -> - bvint (int_size ctx ret_ctyp) (Big_int.of_int n) - + match (cval_ctyp v, ret_ctyp) with + | CT_fbits (n, _), (CT_constant _ | CT_fint _ | CT_lint) -> bvint (int_size ctx ret_ctyp) (Big_int.of_int n) | CT_lbits _, (CT_constant _ | CT_fint _ | CT_lint) -> - let sz = ctx.lbits_index in - let m = int_size ctx ret_ctyp in - let len = Fn ("len", [smt_cval ctx v]) in - if m = sz then - len - else if m > sz then - Fn ("concat", [bvzero (m - sz); len]) - else - Extract (m - 1, 0, len) - + let sz = ctx.lbits_index in + let m = int_size ctx ret_ctyp in + let len = Fn ("len", [smt_cval ctx v]) in + if m = sz then len else if m > sz then Fn ("concat", [bvzero (m - sz); len]) else Extract (m - 1, 0, len) | _, _ -> builtin_type_error ctx "length" [v] (Some ret_ctyp) let builtin_vector_subrange ctx vec i j ret_ctyp = - match cval_ctyp vec, cval_ctyp i, cval_ctyp j, ret_ctyp with + match (cval_ctyp vec, cval_ctyp i, cval_ctyp j, ret_ctyp) with | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits _ -> - Extract (Big_int.to_int i, Big_int.to_int j, smt_cval ctx vec) - + Extract (Big_int.to_int i, Big_int.to_int j, smt_cval ctx vec) | CT_lbits _, CT_constant i, CT_constant j, CT_fbits _ -> - Extract (Big_int.to_int i, Big_int.to_int j, Fn ("contents", [smt_cval ctx vec])) - + Extract (Big_int.to_int i, Big_int.to_int j, Fn ("contents", [smt_cval ctx vec])) | CT_fbits (n, _), i_ctyp, CT_constant j, CT_lbits _ when Big_int.equal j Big_int.zero -> - let i' = force_size ~checked:false ctx ctx.lbits_index (int_size ctx i_ctyp) (smt_cval ctx i) in - let len = bvadd i' (bvint ctx.lbits_index (Big_int.of_int 1)) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; unsigned_size ctx (lbits_size ctx) n (smt_cval ctx vec)])]) - + let i' = force_size ~checked:false ctx ctx.lbits_index (int_size ctx i_ctyp) (smt_cval ctx i) in + let len = bvadd i' (bvint ctx.lbits_index (Big_int.of_int 1)) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; unsigned_size ctx (lbits_size ctx) n (smt_cval ctx vec)])]) | CT_fbits (n, b), i_ctyp, j_ctyp, ret_ctyp -> - let i' = force_size ctx n (int_size ctx i_ctyp) (smt_cval ctx i) in - let j' = force_size ctx n (int_size ctx j_ctyp) (smt_cval ctx j) in - let len = bvadd (bvadd i' (bvneg j')) (bvint n (Big_int.of_int 1)) in - let vec' = bvand (bvlshr (smt_cval ctx vec) j') (fbits_mask ctx n len) in - smt_conversion ctx (CT_fbits (n, b)) ret_ctyp vec' - + let i' = force_size ctx n (int_size ctx i_ctyp) (smt_cval ctx i) in + let j' = force_size ctx n (int_size ctx j_ctyp) (smt_cval ctx j) in + let len = bvadd (bvadd i' (bvneg j')) (bvint n (Big_int.of_int 1)) in + let vec' = bvand (bvlshr (smt_cval ctx vec) j') (fbits_mask ctx n len) in + smt_conversion ctx (CT_fbits (n, b)) ret_ctyp vec' | _ -> builtin_type_error ctx "vector_subrange" [vec; i; j] (Some ret_ctyp) let builtin_vector_access ctx vec i ret_ctyp = - match cval_ctyp vec, cval_ctyp i, ret_ctyp with - | CT_fbits (n, _), CT_constant i, CT_bit -> - Extract (Big_int.to_int i, Big_int.to_int i, smt_cval ctx vec) + match (cval_ctyp vec, cval_ctyp i, ret_ctyp) with + | CT_fbits (n, _), CT_constant i, CT_bit -> Extract (Big_int.to_int i, Big_int.to_int i, smt_cval ctx vec) | CT_lbits _, CT_constant i, CT_bit -> - Extract (Big_int.to_int i, Big_int.to_int i, Fn ("contents", [smt_cval ctx vec])) - + Extract (Big_int.to_int i, Big_int.to_int i, Fn ("contents", [smt_cval ctx vec])) | CT_lbits _, i_ctyp, CT_bit -> - let shift = force_size ~checked:false ctx (lbits_size ctx) (int_size ctx i_ctyp) (smt_cval ctx i) in - Extract (0, 0, Fn ("bvlshr", [Fn ("contents", [smt_cval ctx vec]); shift])) - - | CT_vector _, CT_constant i, _ -> - Fn ("select", [smt_cval ctx vec; bvint !vector_index i]) + let shift = force_size ~checked:false ctx (lbits_size ctx) (int_size ctx i_ctyp) (smt_cval ctx i) in + Extract (0, 0, Fn ("bvlshr", [Fn ("contents", [smt_cval ctx vec]); shift])) + | CT_vector _, CT_constant i, _ -> Fn ("select", [smt_cval ctx vec; bvint !vector_index i]) | CT_vector _, index_ctyp, _ -> - Fn ("select", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)]) - + Fn ("select", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)]) | _ -> builtin_type_error ctx "vector_access" [vec; i] (Some ret_ctyp) let builtin_vector_update ctx vec i x ret_ctyp = - match cval_ctyp vec, cval_ctyp i, cval_ctyp x, ret_ctyp with + match (cval_ctyp vec, cval_ctyp i, cval_ctyp x, ret_ctyp) with | CT_fbits (n, _), CT_constant i, CT_bit, CT_fbits (m, _) when n - 1 > Big_int.to_int i && Big_int.to_int i > 0 -> - assert (n = m); - let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in - let bot = Extract (Big_int.to_int i - 1, 0, smt_cval ctx vec) in - Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) - + assert (n = m); + let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in + let bot = Extract (Big_int.to_int i - 1, 0, smt_cval ctx vec) in + Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) | CT_fbits (n, _), CT_constant i, CT_bit, CT_fbits (m, _) when n - 1 = Big_int.to_int i && Big_int.to_int i > 0 -> - let bot = Extract (Big_int.to_int i - 1, 0, smt_cval ctx vec) in - Fn ("concat", [smt_cval ctx x; bot]) - + let bot = Extract (Big_int.to_int i - 1, 0, smt_cval ctx vec) in + Fn ("concat", [smt_cval ctx x; bot]) | CT_fbits (n, _), CT_constant i, CT_bit, CT_fbits (m, _) when n - 1 > Big_int.to_int i && Big_int.to_int i = 0 -> - let top = Extract (n - 1, 1, smt_cval ctx vec) in - Fn ("concat", [top; smt_cval ctx x]) - - | CT_fbits (n, _), CT_constant i, CT_bit, CT_fbits (m, _) when n - 1 = 0 && Big_int.to_int i = 0 -> - smt_cval ctx x - + let top = Extract (n - 1, 1, smt_cval ctx vec) in + Fn ("concat", [top; smt_cval ctx x]) + | CT_fbits (n, _), CT_constant i, CT_bit, CT_fbits (m, _) when n - 1 = 0 && Big_int.to_int i = 0 -> smt_cval ctx x | CT_vector _, CT_constant i, ctyp, CT_vector _ -> - Fn ("store", [smt_cval ctx vec; bvint !vector_index i; smt_cval ctx x]) + Fn ("store", [smt_cval ctx vec; bvint !vector_index i; smt_cval ctx x]) | CT_vector _, index_ctyp, _, CT_vector _ -> - Fn ("store", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x]) - + Fn + ( "store", + [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x] + ) | _ -> builtin_type_error ctx "vector_update" [vec; i; x] (Some ret_ctyp) let builtin_vector_update_subrange ctx vec i j x ret_ctyp = - match cval_ctyp vec, cval_ctyp i, cval_ctyp j, cval_ctyp x, ret_ctyp with - | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) when n - 1 > Big_int.to_int i && Big_int.to_int j > 0 -> - assert (n = m); - let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in - let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in - Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) - - | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) when n - 1 = Big_int.to_int i && Big_int.to_int j > 0 -> - assert (n = m); - let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in - Fn ("concat", [smt_cval ctx x; bot]) - - | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) when n - 1 > Big_int.to_int i && Big_int.to_int j = 0 -> - assert (n = m); - let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in - Fn ("concat", [top; smt_cval ctx x]) - - | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) when n - 1 = Big_int.to_int i && Big_int.to_int j = 0 -> - smt_cval ctx x - + match (cval_ctyp vec, cval_ctyp i, cval_ctyp j, cval_ctyp x, ret_ctyp) with + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) + when n - 1 > Big_int.to_int i && Big_int.to_int j > 0 -> + assert (n = m); + let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in + let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in + Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) + when n - 1 = Big_int.to_int i && Big_int.to_int j > 0 -> + assert (n = m); + let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in + Fn ("concat", [smt_cval ctx x; bot]) + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) + when n - 1 > Big_int.to_int i && Big_int.to_int j = 0 -> + assert (n = m); + let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in + Fn ("concat", [top; smt_cval ctx x]) + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) + when n - 1 = Big_int.to_int i && Big_int.to_int j = 0 -> + smt_cval ctx x | CT_fbits (n, b), ctyp_i, ctyp_j, ctyp_x, CT_fbits (m, _) -> - assert (n = m); - let i' = force_size ctx n (int_size ctx ctyp_i) (smt_cval ctx i) in - let j' = force_size ctx n (int_size ctx ctyp_j) (smt_cval ctx j) in - let x' = smt_conversion ctx ctyp_x (CT_fbits (n, b)) (smt_cval ctx x) in - let len = bvadd (bvadd i' (bvneg j')) (bvint n (Big_int.of_int 1)) in - let mask = bvshl (fbits_mask ctx n len) j' in - bvor (bvand (smt_cval ctx vec) (bvnot mask)) (bvand (bvshl x' j') mask) - + assert (n = m); + let i' = force_size ctx n (int_size ctx ctyp_i) (smt_cval ctx i) in + let j' = force_size ctx n (int_size ctx ctyp_j) (smt_cval ctx j) in + let x' = smt_conversion ctx ctyp_x (CT_fbits (n, b)) (smt_cval ctx x) in + let len = bvadd (bvadd i' (bvneg j')) (bvint n (Big_int.of_int 1)) in + let mask = bvshl (fbits_mask ctx n len) j' in + bvor (bvand (smt_cval ctx vec) (bvnot mask)) (bvand (bvshl x' j') mask) | _ -> builtin_type_error ctx "vector_update_subrange" [vec; i; j; x] (Some ret_ctyp) let builtin_unsigned ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with + match (cval_ctyp v, ret_ctyp) with | CT_fbits (n, _), CT_fint m when m > n -> - let smt = smt_cval ctx v in - Fn ("concat", [bvzero (m - n); smt]) - + let smt = smt_cval ctx v in + Fn ("concat", [bvzero (m - n); smt]) | CT_fbits (n, _), CT_lint -> - if n >= ctx.lint_size then - failwith "Overflow detected" - else - let smt = smt_cval ctx v in - Fn ("concat", [bvzero (ctx.lint_size - n); smt]) - - | CT_lbits _, CT_lint -> - Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) - + if n >= ctx.lint_size then failwith "Overflow detected" + else ( + let smt = smt_cval ctx v in + Fn ("concat", [bvzero (ctx.lint_size - n); smt]) + ) + | CT_lbits _, CT_lint -> Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) | CT_lbits _, CT_fint m -> - let smt = Fn ("contents", [smt_cval ctx v]) in - force_size ctx m (lbits_size ctx) smt - + let smt = Fn ("contents", [smt_cval ctx v]) in + force_size ctx m (lbits_size ctx) smt | ctyp, _ -> builtin_type_error ctx "unsigned" [v] (Some ret_ctyp) let builtin_signed ctx v ret_ctyp = - match cval_ctyp v, ret_ctyp with - | CT_fbits (n, _), CT_fint m when m >= n -> - SignExtend(m - n, smt_cval ctx v) - - | CT_fbits (n, _), CT_lint -> - SignExtend(ctx.lint_size - n, smt_cval ctx v) - - | CT_lbits _, CT_lint -> - Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) - + match (cval_ctyp v, ret_ctyp) with + | CT_fbits (n, _), CT_fint m when m >= n -> SignExtend (m - n, smt_cval ctx v) + | CT_fbits (n, _), CT_lint -> SignExtend (ctx.lint_size - n, smt_cval ctx v) + | CT_lbits _, CT_lint -> Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) | ctyp, _ -> builtin_type_error ctx "signed" [v] (Some ret_ctyp) let builtin_add_bits ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> - assert (n = m && m = o); - Fn ("bvadd", [smt_cval ctx v1; smt_cval ctx v2]) - + assert (n = m && m = o); + Fn ("bvadd", [smt_cval ctx v1; smt_cval ctx v2]) | CT_lbits _, CT_lbits _, CT_lbits _ -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) - + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) | _ -> builtin_type_error ctx "add_bits" [v1; v2] (Some ret_ctyp) let builtin_sub_bits ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> - assert (n = m && m = o); - Fn ("bvadd", [smt_cval ctx v1; Fn ("bvneg", [smt_cval ctx v2])]) - + assert (n = m && m = o); + Fn ("bvadd", [smt_cval ctx v1; Fn ("bvneg", [smt_cval ctx v2])]) | _ -> failwith "Cannot compile sub_bits" let builtin_add_bits_int ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with - | CT_fbits (n, _), CT_constant c, CT_fbits (o, _) when n = o -> - Fn ("bvadd", [smt_cval ctx v1; bvint o c]) - + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with + | CT_fbits (n, _), CT_constant c, CT_fbits (o, _) when n = o -> Fn ("bvadd", [smt_cval ctx v1; bvint o c]) | CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o -> - Fn ("bvadd", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) - + Fn ("bvadd", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) | CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o -> - Fn ("bvadd", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) - + Fn ("bvadd", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) | CT_lbits _, CT_fint n, CT_lbits _ when n < lbits_size ctx -> - let smt1 = smt_cval ctx v1 in - let smt2 = force_size ctx (lbits_size ctx) n (smt_cval ctx v2) in - Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); smt2])]) - + let smt1 = smt_cval ctx v1 in + let smt2 = force_size ctx (lbits_size ctx) n (smt_cval ctx v2) in + Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); smt2])]) | _ -> builtin_type_error ctx "add_bits_int" [v1; v2] (Some ret_ctyp) let builtin_sub_bits_int ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_constant c, CT_fbits (o, _) when n = o -> - Fn ("bvadd", [smt_cval ctx v1; bvint o (Big_int.negate c)]) - + Fn ("bvadd", [smt_cval ctx v1; bvint o (Big_int.negate c)]) | CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o -> - Fn ("bvsub", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) - + Fn ("bvsub", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) | CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o -> - Fn ("bvsub", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) - + Fn ("bvsub", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) | _ -> builtin_type_error ctx "sub_bits_int" [v1; v2] (Some ret_ctyp) let builtin_replicate_bits ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_constant c, CT_fbits (m, _) -> - assert (n * Big_int.to_int c = m); - let smt = smt_cval ctx v1 in - Fn ("concat", List.init (Big_int.to_int c) (fun _ -> smt)) - + assert (n * Big_int.to_int c = m); + let smt = smt_cval ctx v1 in + Fn ("concat", List.init (Big_int.to_int c) (fun _ -> smt)) | CT_fbits (n, _), _, CT_fbits (m, _) -> - let smt = smt_cval ctx v1 in - let c = m / n in - Fn ("concat", List.init c (fun _ -> smt)) - + let smt = smt_cval ctx v1 in + let c = m / n in + Fn ("concat", List.init c (fun _ -> smt)) | CT_fbits (n, _), v2_ctyp, CT_lbits _ -> - let times = (lbits_size ctx / n) + 1 in - let len = force_size ~checked:false ctx ctx.lbits_index (int_size ctx v2_ctyp) (smt_cval ctx v2) in - let smt1 = smt_cval ctx v1 in - let contents = Extract (lbits_size ctx - 1, 0, Fn ("concat", List.init times (fun _ -> smt1))) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) - + let times = (lbits_size ctx / n) + 1 in + let len = force_size ~checked:false ctx ctx.lbits_index (int_size ctx v2_ctyp) (smt_cval ctx v2) in + let smt1 = smt_cval ctx v1 in + let contents = Extract (lbits_size ctx - 1, 0, Fn ("concat", List.init times (fun _ -> smt1))) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) | _ -> builtin_type_error ctx "replicate_bits" [v1; v2] (Some ret_ctyp) let builtin_sail_truncate ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_constant c, CT_fbits (m, _) -> - assert (Big_int.to_int c = m); - Extract (Big_int.to_int c - 1, 0, smt_cval ctx v1) - + assert (Big_int.to_int c = m); + Extract (Big_int.to_int c - 1, 0, smt_cval ctx v1) | CT_lbits _, CT_constant c, CT_fbits (m, _) -> - assert (Big_int.to_int c = m && m < lbits_size ctx); - Extract (Big_int.to_int c - 1, 0, Fn ("contents", [smt_cval ctx v1])) - + assert (Big_int.to_int c = m && m < lbits_size ctx); + Extract (Big_int.to_int c - 1, 0, Fn ("contents", [smt_cval ctx v1])) | CT_fbits (n, _), _, CT_lbits _ -> - let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in - let smt2 = bvzeint ctx ctx.lbits_index v2 in - Fn ("Bits", [smt2; Fn ("bvand", [bvmask ctx smt2; smt1])]) - + let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in + let smt2 = bvzeint ctx ctx.lbits_index v2 in + Fn ("Bits", [smt2; Fn ("bvand", [bvmask ctx smt2; smt1])]) | _ -> builtin_type_error ctx "sail_truncate" [v1; v2] (Some ret_ctyp) let builtin_sail_truncateLSB ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits (n, _), CT_constant c, CT_fbits (m, _) -> - assert (Big_int.to_int c = m); - Extract (n - 1, n - Big_int.to_int c, smt_cval ctx v1) - + assert (Big_int.to_int c = m); + Extract (n - 1, n - Big_int.to_int c, smt_cval ctx v1) | _ -> builtin_type_error ctx "sail_truncateLSB" [v1; v2] (Some ret_ctyp) let builtin_slice ctx v1 v2 v3 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, ret_ctyp) with | CT_lbits _, CT_constant start, CT_constant len, CT_fbits (_, _) -> - let top = Big_int.pred (Big_int.add start len) in - Extract(Big_int.to_int top, Big_int.to_int start, Fn ("contents", [smt_cval ctx v1])) - + let top = Big_int.pred (Big_int.add start len) in + Extract (Big_int.to_int top, Big_int.to_int start, Fn ("contents", [smt_cval ctx v1])) | CT_fbits (_, _), CT_constant start, CT_constant len, CT_fbits (_, _) -> - let top = Big_int.pred (Big_int.add start len) in - Extract(Big_int.to_int top, Big_int.to_int start, smt_cval ctx v1) - + let top = Big_int.pred (Big_int.add start len) in + Extract (Big_int.to_int top, Big_int.to_int start, smt_cval ctx v1) | CT_fbits (_, ord), CT_fint _, CT_constant len, CT_fbits (_, _) -> - Extract(Big_int.to_int (Big_int.pred len), 0, builtin_shift "bvlshr" ctx v1 v2 (cval_ctyp v1)) - - | CT_fbits(n, ord), ctyp2, _, CT_lbits _ -> - let smt1 = force_size ctx (lbits_size ctx) n (smt_cval ctx v1) in - let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in - let smt3 = bvzeint ctx ctx.lbits_index v3 in - Fn ("Bits", [smt3; Fn ("bvand", [Fn ("bvlshr", [smt1; smt2]); bvmask ctx smt3])]) - + Extract (Big_int.to_int (Big_int.pred len), 0, builtin_shift "bvlshr" ctx v1 v2 (cval_ctyp v1)) + | CT_fbits (n, ord), ctyp2, _, CT_lbits _ -> + let smt1 = force_size ctx (lbits_size ctx) n (smt_cval ctx v1) in + let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in + let smt3 = bvzeint ctx ctx.lbits_index v3 in + Fn ("Bits", [smt3; Fn ("bvand", [Fn ("bvlshr", [smt1; smt2]); bvmask ctx smt3])]) | _ -> builtin_type_error ctx "slice" [v1; v2; v3] (Some ret_ctyp) let builtin_get_slice_int ctx v1 v2 v3 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, ret_ctyp with + match (cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, ret_ctyp) with | CT_constant len, ctyp, CT_constant start, CT_fbits (ret_sz, _) -> - let len = Big_int.to_int len in - let start = Big_int.to_int start in - let in_sz = int_size ctx ctyp in - let smt = - if in_sz < len + start then - force_size ctx (len + start) in_sz (smt_cval ctx v2) - else - smt_cval ctx v2 - in - Extract ((start + len) - 1, start, smt) - + let len = Big_int.to_int len in + let start = Big_int.to_int start in + let in_sz = int_size ctx ctyp in + let smt = if in_sz < len + start then force_size ctx (len + start) in_sz (smt_cval ctx v2) else smt_cval ctx v2 in + Extract (start + len - 1, start, smt) | CT_lint, CT_lint, CT_constant start, CT_lbits _ when Big_int.equal start Big_int.zero -> - let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in - let contents = unsigned_size ~checked:false ctx (lbits_size ctx) ctx.lint_size (smt_cval ctx v2) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) - + let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in + let contents = unsigned_size ~checked:false ctx (lbits_size ctx) ctx.lint_size (smt_cval ctx v2) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) | CT_lint, ctyp2, ctyp3, ret_ctyp -> - let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in - let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in - let smt3 = force_size ctx (lbits_size ctx) (int_size ctx ctyp3) (smt_cval ctx v3) in - let result = bvand (bvmask ctx len) (bvlshr smt2 smt3) in - smt_conversion ctx CT_lint ret_ctyp result - + let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in + let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in + let smt3 = force_size ctx (lbits_size ctx) (int_size ctx ctyp3) (smt_cval ctx v3) in + let result = bvand (bvmask ctx len) (bvlshr smt2 smt3) in + smt_conversion ctx CT_lint ret_ctyp result | _ -> builtin_type_error ctx "get_slice_int" [v1; v2; v3] (Some ret_ctyp) let builtin_count_leading_zeros ctx v ret_ctyp = let ret_sz = int_size ctx ret_ctyp in let rec lzcnt sz smt = if sz == 1 then - Ite (Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]), - bvint ret_sz (Big_int.of_int 1), - bvint ret_sz (Big_int.zero)) + Ite + ( Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]), + bvint ret_sz (Big_int.of_int 1), + bvint ret_sz Big_int.zero + ) else ( assert (sz land (sz - 1) = 0); let hsz = sz / 2 in - Ite (Fn ("=", [Extract (sz - 1, hsz, smt); bvzero hsz]), - Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); lzcnt hsz (Extract (hsz - 1, 0, smt))]), - lzcnt hsz (Extract (sz - 1, hsz, smt))) + Ite + ( Fn ("=", [Extract (sz - 1, hsz, smt); bvzero hsz]), + Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); lzcnt hsz (Extract (hsz - 1, 0, smt))]), + lzcnt hsz (Extract (sz - 1, hsz, smt)) + ) ) in let smallest_greater_power_of_two n = @@ -1125,46 +1009,55 @@ let builtin_count_leading_zeros ctx v ret_ctyp = !m in match cval_ctyp v with - | CT_fbits (sz, _) when sz land (sz - 1) = 0 -> - lzcnt sz (smt_cval ctx v) - + | CT_fbits (sz, _) when sz land (sz - 1) = 0 -> lzcnt sz (smt_cval ctx v) | CT_fbits (sz, _) -> - let padded_sz = smallest_greater_power_of_two sz in - let padding = bvzero (padded_sz - sz) in - Fn ("bvsub", [lzcnt padded_sz (Fn ("concat", [padding; smt_cval ctx v])); - bvint ret_sz (Big_int.of_int (padded_sz - sz))]) - + let padded_sz = smallest_greater_power_of_two sz in + let padding = bvzero (padded_sz - sz) in + Fn + ( "bvsub", + [lzcnt padded_sz (Fn ("concat", [padding; smt_cval ctx v])); bvint ret_sz (Big_int.of_int (padded_sz - sz))] + ) | CT_lbits _ -> - let smt = smt_cval ctx v in - Fn ("bvsub", [lzcnt (lbits_size ctx) (Fn ("contents", [smt])); - Fn ("bvsub", [bvint ret_sz (Big_int.of_int (lbits_size ctx)); - Fn ("concat", [bvzero (ret_sz - ctx.lbits_index); Fn ("len", [smt])])])]) - + let smt = smt_cval ctx v in + Fn + ( "bvsub", + [ + lzcnt (lbits_size ctx) (Fn ("contents", [smt])); + Fn + ( "bvsub", + [ + bvint ret_sz (Big_int.of_int (lbits_size ctx)); + Fn ("concat", [bvzero (ret_sz - ctx.lbits_index); Fn ("len", [smt])]); + ] + ); + ] + ) | _ -> builtin_type_error ctx "count_leading_zeros" [v] (Some ret_ctyp) let builtin_set_slice_bits ctx v1 v2 v3 v4 v5 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, cval_ctyp v4, cval_ctyp v5, ret_ctyp with - | CT_constant n', CT_constant m', CT_fbits (n, _), CT_constant pos, CT_fbits (m, _), CT_fbits(n'', _) + match (cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, cval_ctyp v4, cval_ctyp v5, ret_ctyp) with + | CT_constant n', CT_constant m', CT_fbits (n, _), CT_constant pos, CT_fbits (m, _), CT_fbits (n'', _) when Big_int.to_int m' = m && Big_int.to_int n' = n && n'' = n && Big_int.less_equal (Big_int.add pos m') n' -> - let pos = Big_int.to_int pos in - if pos = 0 then - let mask = Fn ("concat", [bvones (n - m); bvzero m]) in - let smt5 = Fn ("concat", [bvzero (n - m); smt_cval ctx v5]) in - Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) - else if n - m - pos = 0 then - let mask = Fn ("concat", [bvzero m; bvones pos]) in - let smt5 = Fn ("concat", [smt_cval ctx v5; bvzero pos]) in - Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) - else - let mask = Fn ("concat", [bvones (n - m - pos); Fn ("concat", [bvzero m; bvones pos])]) in - let smt5 = Fn ("concat", [bvzero (n - m - pos); Fn ("concat", [smt_cval ctx v5; bvzero pos])]) in - Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) - + let pos = Big_int.to_int pos in + if pos = 0 then ( + let mask = Fn ("concat", [bvones (n - m); bvzero m]) in + let smt5 = Fn ("concat", [bvzero (n - m); smt_cval ctx v5]) in + Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) + ) + else if n - m - pos = 0 then ( + let mask = Fn ("concat", [bvzero m; bvones pos]) in + let smt5 = Fn ("concat", [smt_cval ctx v5; bvzero pos]) in + Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) + ) + else ( + let mask = Fn ("concat", [bvones (n - m - pos); Fn ("concat", [bvzero m; bvones pos])]) in + let smt5 = Fn ("concat", [bvzero (n - m - pos); Fn ("concat", [smt_cval ctx v5; bvzero pos])]) in + Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) + ) (* set_slice_bits(len, slen, x, pos, y) = let mask = slice_mask(len, pos, slen) in (x AND NOT(mask)) OR ((unsigned_size(len, y) << pos) AND mask) *) - | CT_constant n', _, CT_fbits (n, _), _, CT_lbits _, CT_fbits (n'', _) - when Big_int.to_int n' = n && n'' = n -> + | CT_constant n', _, CT_fbits (n, _), _, CT_lbits _, CT_fbits (n'', _) when Big_int.to_int n' = n && n'' = n -> let pos = bvzeint ctx (lbits_size ctx) v4 in let slen = bvzeint ctx ctx.lbits_index v2 in let mask = Fn ("bvshl", [bvmask ctx slen; pos]) in @@ -1173,24 +1066,20 @@ let builtin_set_slice_bits ctx v1 v2 v3 v4 v5 ret_ctyp = let smt5 = Fn ("contents", [smt_cval ctx v5]) in let smt5' = Fn ("bvand", [Fn ("bvshl", [smt5; pos]); mask]) in Extract (n - 1, 0, Fn ("bvor", [smt3'; smt5'])) - | _ -> builtin_type_error ctx "set_slice" [v1; v2; v3; v4; v5] (Some ret_ctyp) let builtin_compare_bits fn ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2 with - | CT_fbits (n, _), CT_fbits (m, _) when n = m -> - Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) - + match (cval_ctyp v1, cval_ctyp v2) with + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) | _ -> builtin_type_error ctx fn [v1; v2] (Some ret_ctyp) (* ***** String operations: lib/real.sail ***** *) let builtin_decimal_string_of_bits ctx v = - begin match cval_ctyp v with - | CT_fbits (n, _) -> - Fn ("int.to.str", [Fn ("bv2nat", [smt_cval ctx v])]) - - | _ -> builtin_type_error ctx "decimal_string_of_bits" [v] None + begin + match cval_ctyp v with + | CT_fbits (n, _) -> Fn ("int.to.str", [Fn ("bv2nat", [smt_cval ctx v])]) + | _ -> builtin_type_error ctx "decimal_string_of_bits" [v] None end (* ***** Real number operations: lib/real.sail ***** *) @@ -1198,27 +1087,24 @@ let builtin_decimal_string_of_bits ctx v = let builtin_sqrt_real ctx root v = ctx.use_real := true; let smt = smt_cval ctx v in - [Declare_const (root, Real); - Assert (Fn ("and", [Fn ("=", [smt; Fn ("*", [Var root; Var root])]); - Fn (">=", [Var root; Real_lit "0.0"])]))] + [ + Declare_const (root, Real); + Assert (Fn ("and", [Fn ("=", [smt; Fn ("*", [Var root; Var root])]); Fn (">=", [Var root; Real_lit "0.0"])])); + ] let smt_builtin ctx name args ret_ctyp = - match name, args, ret_ctyp with + match (name, args, ret_ctyp) with | "eq_anything", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - (* lib/flow.sail *) - | "eq_bit", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) + | "eq_bit", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) | "eq_bool", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) | "eq_unit", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - - | "eq_int", [v1; v2], CT_bool -> builtin_eq_int ctx v1 v2 - + | "eq_int", [v1; v2], CT_bool -> builtin_eq_int ctx v1 v2 | "not", [v], _ -> Fn ("not", [smt_cval ctx v]) - | "lt", [v1; v2], _ -> builtin_lt ctx v1 v2 + | "lt", [v1; v2], _ -> builtin_lt ctx v1 v2 | "lteq", [v1; v2], _ -> builtin_lteq ctx v1 v2 - | "gt", [v1; v2], _ -> builtin_gt ctx v1 v2 + | "gt", [v1; v2], _ -> builtin_gt ctx v1 v2 | "gteq", [v1; v2], _ -> builtin_gteq ctx v1 v2 - (* lib/arith.sail *) | "add_int", [v1; v2], _ -> builtin_add_int ctx v1 v2 ret_ctyp | "sub_int", [v1; v2], _ -> builtin_sub_int ctx v1 v2 ret_ctyp @@ -1231,12 +1117,9 @@ let smt_builtin ctx name args ret_ctyp = | "shr_mach_int", [v1; v2], _ -> builtin_shr_int ctx v1 v2 ret_ctyp | "abs_int", [v], _ -> builtin_abs_int ctx v ret_ctyp | "pow2", [v], _ -> builtin_pow2 ctx v ret_ctyp - | "max_int", [v1; v2], _ -> builtin_max_int ctx v1 v2 ret_ctyp | "min_int", [v1; v2], _ -> builtin_min_int ctx v1 v2 ret_ctyp - | "ediv_int", [v1; v2], _ -> builtin_tdiv_int ctx v1 v2 ret_ctyp - (* All signed and unsigned bitvector comparisons *) | "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" ctx v1 v2 ret_ctyp | "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" ctx v1 v2 ret_ctyp @@ -1246,7 +1129,6 @@ let smt_builtin ctx name args ret_ctyp = | "ulteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvule" ctx v1 v2 ret_ctyp | "sgteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsge" ctx v1 v2 ret_ctyp | "ugteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvuge" ctx v1 v2 ret_ctyp - (* lib/vector_dec.sail *) | "eq_bits", [v1; v2], CT_bool -> builtin_eq_bits ctx v1 v2 | "zeros", [v], _ -> builtin_zeros ctx v ret_ctyp @@ -1281,27 +1163,54 @@ let smt_builtin ctx name args ret_ctyp = | "slice", [v1; v2; v3], ret_ctyp -> builtin_slice ctx v1 v2 v3 ret_ctyp | "get_slice_int", [v1; v2; v3], ret_ctyp -> builtin_get_slice_int ctx v1 v2 v3 ret_ctyp | "set_slice", [v1; v2; v3; v4; v5], ret_ctyp -> builtin_set_slice_bits ctx v1 v2 v3 v4 v5 ret_ctyp - (* string builtins *) - | "concat_str", [v1; v2], CT_string -> ctx.use_string := true; Fn ("str.++", [smt_cval ctx v1; smt_cval ctx v2]) - | "eq_string", [v1; v2], CT_bool -> ctx.use_string := true; Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "decimal_string_of_bits", [v], CT_string -> ctx.use_string := true; builtin_decimal_string_of_bits ctx v - + | "concat_str", [v1; v2], CT_string -> + ctx.use_string := true; + Fn ("str.++", [smt_cval ctx v1; smt_cval ctx v2]) + | "eq_string", [v1; v2], CT_bool -> + ctx.use_string := true; + Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) + | "decimal_string_of_bits", [v], CT_string -> + ctx.use_string := true; + builtin_decimal_string_of_bits ctx v (* lib/real.sail *) (* Note that sqrt_real is special and is handled by smt_instr. *) - | "eq_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "neg_real", [v], CT_real -> ctx.use_real := true; Fn ("-", [smt_cval ctx v]) - | "add_real", [v1; v2], CT_real -> ctx.use_real := true; Fn ("+", [smt_cval ctx v1; smt_cval ctx v2]) - | "sub_real", [v1; v2], CT_real -> ctx.use_real := true; Fn ("-", [smt_cval ctx v1; smt_cval ctx v2]) - | "mult_real", [v1; v2], CT_real -> ctx.use_real := true; Fn ("*", [smt_cval ctx v1; smt_cval ctx v2]) - | "div_real", [v1; v2], CT_real -> ctx.use_real := true; Fn ("/", [smt_cval ctx v1; smt_cval ctx v2]) - | "lt_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn ("<", [smt_cval ctx v1; smt_cval ctx v2]) - | "gt_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn (">", [smt_cval ctx v1; smt_cval ctx v2]) - | "lteq_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn ("<=", [smt_cval ctx v1; smt_cval ctx v2]) - | "gteq_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn (">=", [smt_cval ctx v1; smt_cval ctx v2]) - + | "eq_real", [v1; v2], CT_bool -> + ctx.use_real := true; + Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) + | "neg_real", [v], CT_real -> + ctx.use_real := true; + Fn ("-", [smt_cval ctx v]) + | "add_real", [v1; v2], CT_real -> + ctx.use_real := true; + Fn ("+", [smt_cval ctx v1; smt_cval ctx v2]) + | "sub_real", [v1; v2], CT_real -> + ctx.use_real := true; + Fn ("-", [smt_cval ctx v1; smt_cval ctx v2]) + | "mult_real", [v1; v2], CT_real -> + ctx.use_real := true; + Fn ("*", [smt_cval ctx v1; smt_cval ctx v2]) + | "div_real", [v1; v2], CT_real -> + ctx.use_real := true; + Fn ("/", [smt_cval ctx v1; smt_cval ctx v2]) + | "lt_real", [v1; v2], CT_bool -> + ctx.use_real := true; + Fn ("<", [smt_cval ctx v1; smt_cval ctx v2]) + | "gt_real", [v1; v2], CT_bool -> + ctx.use_real := true; + Fn (">", [smt_cval ctx v1; smt_cval ctx v2]) + | "lteq_real", [v1; v2], CT_bool -> + ctx.use_real := true; + Fn ("<=", [smt_cval ctx v1; smt_cval ctx v2]) + | "gteq_real", [v1; v2], CT_bool -> + ctx.use_real := true; + Fn (">=", [smt_cval ctx v1; smt_cval ctx v2]) | _ -> - Reporting.unreachable ctx.pragma_l __POS__ ("Unknown builtin " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) ^ " -> " ^ string_of_ctyp ret_ctyp) + Reporting.unreachable ctx.pragma_l __POS__ + ("Unknown builtin " ^ name ^ " " + ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) + ^ " -> " ^ string_of_ctyp ret_ctyp + ) let loc_doc _ = "UNKNOWN" @@ -1311,115 +1220,137 @@ let writes = ref (-1) let builtin_write_mem l ctx wk addr_size addr data_size data = incr writes; let name = "W" ^ string_of_int !writes in - [Write_mem { - name = name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - kind = smt_cval ctx wk; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - data = smt_cval ctx data; - data_type = smt_ctyp ctx (cval_ctyp data); - doc = loc_doc l - }], - Var (name ^ "_ret") + ( [ + Write_mem + { + name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + kind = smt_cval ctx wk; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + data = smt_cval ctx data; + data_type = smt_ctyp ctx (cval_ctyp data); + doc = loc_doc l; + }; + ], + Var (name ^ "_ret") + ) let ea_writes = ref (-1) let builtin_write_mem_ea ctx wk addr_size addr data_size = incr ea_writes; let name = "A" ^ string_of_int !ea_writes in - [Write_mem_ea (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx wk, - smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data_size, smt_ctyp ctx (cval_ctyp data_size))], - Enum "unit" + ( [ + Write_mem_ea + ( name, + ctx.node, + Lazy.force ctx.pathcond, + smt_cval ctx wk, + smt_cval ctx addr, + smt_ctyp ctx (cval_ctyp addr), + smt_cval ctx data_size, + smt_ctyp ctx (cval_ctyp data_size) + ); + ], + Enum "unit" + ) let reads = ref (-1) let builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp = incr reads; let name = "R" ^ string_of_int !reads in - [Read_mem { - name = name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - ret_type = smt_ctyp ctx ret_ctyp; - kind = smt_cval ctx rk; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - doc = loc_doc l - }], - Read_res name + ( [ + Read_mem + { + name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + ret_type = smt_ctyp ctx ret_ctyp; + kind = smt_cval ctx rk; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + doc = loc_doc l; + }; + ], + Read_res name + ) let excl_results = ref (-1) let builtin_excl_res ctx = incr excl_results; let name = "E" ^ string_of_int !excl_results in - [Excl_res (name, ctx.node, Lazy.force ctx.pathcond)], - Var (name ^ "_ret") + ([Excl_res (name, ctx.node, Lazy.force ctx.pathcond)], Var (name ^ "_ret")) let barriers = ref (-1) let builtin_barrier l ctx bk = incr barriers; let name = "B" ^ string_of_int !barriers in - [Barrier { - name = name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - kind = smt_cval ctx bk; - doc = loc_doc l - }], - Enum "unit" + ( [Barrier { name; node = ctx.node; active = Lazy.force ctx.pathcond; kind = smt_cval ctx bk; doc = loc_doc l }], + Enum "unit" + ) let cache_maintenances = ref (-1) let builtin_cache_maintenance l ctx cmk addr_size addr = incr cache_maintenances; let name = "M" ^ string_of_int !cache_maintenances in - [Cache_maintenance { - name = name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - kind = smt_cval ctx cmk; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - doc = loc_doc l - }], - Enum "unit" + ( [ + Cache_maintenance + { + name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + kind = smt_cval ctx cmk; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + doc = loc_doc l; + }; + ], + Enum "unit" + ) let branch_announces = ref (-1) let builtin_branch_announce l ctx addr_size addr = incr branch_announces; let name = "C" ^ string_of_int !branch_announces in - [Branch_announce { - name = name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - doc = loc_doc l - }], - Enum "unit" + ( [ + Branch_announce + { + name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + doc = loc_doc l; + }; + ], + Enum "unit" + ) let define_const ctx id ctyp exp = Define_const (zencode_name id, smt_ctyp ctx ctyp, exp) let preserve_const ctx id ctyp exp = Preserve_const (string_of_id id, smt_ctyp ctx ctyp, exp) let declare_const ctx id ctyp = Declare_const (zencode_name id, smt_ctyp ctx ctyp) let smt_ctype_def ctx = function - | CTD_enum (id, elems) -> - [declare_datatypes (mk_enum (zencode_upper_id id) (List.map zencode_id elems))] - + | CTD_enum (id, elems) -> [declare_datatypes (mk_enum (zencode_upper_id id) (List.map zencode_id elems))] | CTD_struct (id, fields) -> - [declare_datatypes - (mk_record (zencode_upper_id id) - (List.map (fun (field, ctyp) -> zencode_upper_id id ^ "_" ^ zencode_id field, smt_ctyp ctx ctyp) fields))] - + [ + declare_datatypes + (mk_record (zencode_upper_id id) + (List.map (fun (field, ctyp) -> (zencode_upper_id id ^ "_" ^ zencode_id field, smt_ctyp ctx ctyp)) fields) + ); + ] | CTD_variant (id, ctors) -> - [declare_datatypes - (mk_variant (zencode_upper_id id) - (List.map (fun (ctor, ctyp) -> zencode_id ctor, smt_ctyp ctx ctyp) ctors))] + [ + declare_datatypes + (mk_variant (zencode_upper_id id) (List.map (fun (ctor, ctyp) -> (zencode_id ctor, smt_ctyp ctx ctyp)) ctors)); + ] let rec generate_ctype_defs ctx = function | CDEF_type ctd :: cdefs -> smt_ctype_def ctx ctd :: generate_ctype_defs ctx cdefs @@ -1427,9 +1358,8 @@ let rec generate_ctype_defs ctx = function | [] -> [] let rec generate_reg_decs ctx inits = function - | CDEF_register (id, ctyp, _) :: cdefs when not (NameMap.mem (Global (id, 0)) inits)-> - Declare_const (zencode_name (Global (id, 0)), smt_ctyp ctx ctyp) - :: generate_reg_decs ctx inits cdefs + | CDEF_register (id, ctyp, _) :: cdefs when not (NameMap.mem (Global (id, 0)) inits) -> + Declare_const (zencode_name (Global (id, 0)), smt_ctyp ctx ctyp) :: generate_reg_decs ctx inits cdefs | _ :: cdefs -> generate_reg_decs ctx inits cdefs | [] -> [] @@ -1440,7 +1370,9 @@ let rec generate_reg_decs ctx inits = function let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1)) let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) -module SMT_config(Opts : sig val unroll_limit : int end) : Jib_compile.Config = struct +module SMT_config (Opts : sig + val unroll_limit : int +end) : Jib_compile.Config = struct open Jib_compile (** Convert a sail type into a C-type. This function can be quite @@ -1450,110 +1382,104 @@ module SMT_config(Opts : sig val unroll_limit : int end) : Jib_compile.Config = let rec convert_typ ctx typ = let open Ast in let open Type_check in - let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.local_env typ in + let (Typ_aux (typ_aux, l) as typ) = Env.expand_synonyms ctx.local_env typ in match typ_aux with - | Typ_id id when string_of_id id = "bit" -> CT_bit - | Typ_id id when string_of_id id = "bool" -> CT_bool - | Typ_id id when string_of_id id = "int" -> CT_lint - | Typ_id id when string_of_id id = "nat" -> CT_lint - | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when string_of_id id = "bit" -> CT_bit + | Typ_id id when string_of_id id = "bool" -> CT_bool + | Typ_id id when string_of_id id = "int" -> CT_lint + | Typ_id id when string_of_id id = "nat" -> CT_lint + | Typ_id id when string_of_id id = "unit" -> CT_unit | Typ_id id when string_of_id id = "string" -> CT_string - | Typ_id id when string_of_id id = "real" -> CT_real - + | Typ_id id when string_of_id id = "real" -> CT_real | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool - - | Typ_app (id, args) when string_of_id id = "itself" -> - convert_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) + | Typ_app (id, args) when string_of_id id = "itself" -> convert_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> - begin match destruct_range ctx.local_env typ with - | None -> assert false (* Checked if range type in guard *) - | Some (kids, constr, n, m) -> - let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env } in - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when n = m -> - CT_constant n - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - CT_fint 64 - | n, m -> - if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then - CT_fint 64 - else - CT_lint - end - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> - CT_list (convert_typ ctx typ) - + begin + match destruct_range ctx.local_env typ with + | None -> assert false (* Checked if range type in guard *) + | Some (kids, constr, n, m) -> ( + let ctx = + { + ctx with + local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env; + } + in + match (nexp_simp n, nexp_simp m) with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) when n = m -> CT_constant n + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> + CT_fint 64 + | n, m -> + if + prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) + && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) + then CT_fint 64 + else CT_lint + ) + end + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> CT_list (convert_typ ctx typ) (* Note that we have to use lbits for zero-length bitvectors because they are not allowed by SMTLIB *) - | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order ord, _)]) - when string_of_id id = "bitvector" -> - let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in - begin match nexp_simp n with - | Nexp_aux (Nexp_constant n, _) when Big_int.equal n Big_int.zero -> CT_lbits direction - | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n, direction) - | _ -> CT_lbits direction - end - - | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order ord, _); - A_aux (A_typ typ, _)]) - when string_of_id id = "vector" -> - let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in - CT_vector (direction, convert_typ ctx typ) - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> - CT_ref (convert_typ ctx typ) - - | Typ_id id when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> snd |> Bindings.bindings) + | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order ord, _)]) when string_of_id id = "bitvector" -> + let direction = + match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false + in + begin + match nexp_simp n with + | Nexp_aux (Nexp_constant n, _) when Big_int.equal n Big_int.zero -> CT_lbits direction + | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n, direction) + | _ -> CT_lbits direction + end + | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order ord, _); A_aux (A_typ typ, _)]) when string_of_id id = "vector" + -> + let direction = + match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false + in + CT_vector (direction, convert_typ ctx typ) + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (convert_typ ctx typ) + | Typ_id id when Bindings.mem id ctx.records -> + CT_struct (id, Bindings.find id ctx.records |> snd |> Bindings.bindings) | Typ_app (id, typ_args) when Bindings.mem id ctx.records -> - let (typ_params, fields) = Bindings.find id ctx.records in - let quants = - List.fold_left2 (fun quants typ_param typ_arg -> - match typ_arg with - | A_aux (A_typ typ, _) -> - KBindings.add typ_param (convert_typ ctx typ) quants - | _ -> - Reporting.unreachable l __POS__ "Non-type argument for record here should be impossible" - ) ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) - in - let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in - CT_struct (id, Bindings.map fix_ctyp fields |> Bindings.bindings) - - | Typ_id id when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> snd |> Bindings.bindings) + let typ_params, fields = Bindings.find id ctx.records in + let quants = + List.fold_left2 + (fun quants typ_param typ_arg -> + match typ_arg with + | A_aux (A_typ typ, _) -> KBindings.add typ_param (convert_typ ctx typ) quants + | _ -> Reporting.unreachable l __POS__ "Non-type argument for record here should be impossible" + ) + ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) + in + let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in + CT_struct (id, Bindings.map fix_ctyp fields |> Bindings.bindings) + | Typ_id id when Bindings.mem id ctx.variants -> + CT_variant (id, Bindings.find id ctx.variants |> snd |> Bindings.bindings) | Typ_app (id, typ_args) when Bindings.mem id ctx.variants -> - let (typ_params, ctors) = Bindings.find id ctx.variants in - let quants = - List.fold_left2 (fun quants typ_param typ_arg -> - match typ_arg with - | A_aux (A_typ typ, _) -> - KBindings.add typ_param (convert_typ ctx typ) quants - | _ -> - Reporting.unreachable l __POS__ "Non-type argument for variant here should be impossible" - ) ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) - in - let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in - CT_variant (id, Bindings.map fix_ctyp ctors |> Bindings.bindings) - + let typ_params, ctors = Bindings.find id ctx.variants in + let quants = + List.fold_left2 + (fun quants typ_param typ_arg -> + match typ_arg with + | A_aux (A_typ typ, _) -> KBindings.add typ_param (convert_typ ctx typ) quants + | _ -> Reporting.unreachable l __POS__ "Non-type argument for variant here should be impossible" + ) + ctx.quants typ_params (List.filter is_typ_arg_typ typ_args) + in + let fix_ctyp ctyp = if is_polymorphic ctyp then ctyp_suprema (subst_poly quants ctyp) else ctyp in + CT_variant (id, Bindings.map fix_ctyp ctors |> Bindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) - | Typ_tuple typs -> CT_tup (List.map (convert_typ ctx) typs) - - | Typ_exist _ -> - (* Use Type_check.destruct_exist when optimising with SMT, to - ensure that we don't cause any type variable clashes in - local_env, and that we can optimize the existential based - upon its constraints. *) - begin match destruct_exist typ with - | Some (kids, nc, typ) -> - let env = add_existential l kids nc ctx.local_env in - convert_typ { ctx with local_env = env } typ - | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") - end - + | Typ_exist _ -> begin + (* Use Type_check.destruct_exist when optimising with SMT, to + ensure that we don't cause any type variable clashes in + local_env, and that we can optimize the existential based + upon its constraints. *) + match destruct_exist typ with + | Some (kids, nc, typ) -> + let env = add_existential l kids nc ctx.local_env in + convert_typ { ctx with local_env = env } typ + | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") + end | Typ_var kid -> CT_poly kid - | _ -> raise (Reporting.err_unreachable l __POS__ ("No SMT type for type " ^ string_of_typ typ)) let hex_char = @@ -1581,8 +1507,8 @@ module SMT_config(Opts : sig val unroll_limit : int end) : Jib_compile.Config = match l_aux with | L_num n -> Some (V_lit (VL_int n, CT_constant n)) | L_hex str when String.length str <= 16 -> - let content = Util.string_to_list str |> List.map hex_char |> List.concat in - Some (V_lit (VL_bits (content, true), CT_fbits (String.length str * 4, true))) + let content = Util.string_to_list str |> List.map hex_char |> List.concat in + Some (V_lit (VL_bits (content, true), CT_fbits (String.length str * 4, true))) | L_unit -> Some (V_lit (VL_unit, CT_unit)) | L_true -> Some (V_lit (VL_bool true, CT_bool)) | L_false -> Some (V_lit (VL_bool false, CT_bool)) @@ -1590,44 +1516,47 @@ module SMT_config(Opts : sig val unroll_limit : int end) : Jib_compile.Config = let c_literals ctx = let rec c_literal env l = function - | AV_lit (lit, typ) as v -> - begin match literal_to_cval lit with - | Some cval -> AV_cval (cval, typ) - | None -> v - end + | AV_lit (lit, typ) as v -> begin match literal_to_cval lit with Some cval -> AV_cval (cval, typ) | None -> v end | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) | v -> v in map_aval c_literal -(* If we know the loop variables exactly (especially after - specialization), we can unroll the exact number of times required, - and omit any comparisons. *) -let unroll_static_foreach ctx = function - | AE_aux (AE_for (id, from_aexp, to_aexp, by_aexp, order, body), env, l) as aexp -> - begin match convert_typ ctx (aexp_typ from_aexp), convert_typ ctx (aexp_typ to_aexp), convert_typ ctx (aexp_typ by_aexp), order with - | CT_constant f, CT_constant t, CT_constant b, Ord_aux (Ord_inc, _) -> - let i = ref f in - let unrolled = ref [] in - while Big_int.less_equal !i t do - let current_index = AE_aux (AE_val (AV_lit (L_aux (L_num !i, gen_loc l), atom_typ (nconstant !i))), env, gen_loc l) in - let iteration = AE_aux (AE_let (Immutable, id, atom_typ (nconstant !i), current_index, body, unit_typ), env, gen_loc l) in - unrolled := iteration :: !unrolled; - i := Big_int.add !i b - done; - begin match !unrolled with - | last :: iterations -> - AE_aux (AE_block (List.rev iterations, last, unit_typ), env, gen_loc l) - | [] -> AE_aux (AE_val (AV_lit (L_aux (L_unit, gen_loc l), unit_typ)), env, gen_loc l) - end - | _ -> aexp - end - | aexp -> aexp - - let optimize_anf ctx aexp = - aexp - |> c_literals ctx - |> fold_aexp (unroll_static_foreach ctx) + (* If we know the loop variables exactly (especially after + specialization), we can unroll the exact number of times required, + and omit any comparisons. *) + let unroll_static_foreach ctx = function + | AE_aux (AE_for (id, from_aexp, to_aexp, by_aexp, order, body), env, l) as aexp -> begin + match + ( convert_typ ctx (aexp_typ from_aexp), + convert_typ ctx (aexp_typ to_aexp), + convert_typ ctx (aexp_typ by_aexp), + order + ) + with + | CT_constant f, CT_constant t, CT_constant b, Ord_aux (Ord_inc, _) -> + let i = ref f in + let unrolled = ref [] in + while Big_int.less_equal !i t do + let current_index = + AE_aux (AE_val (AV_lit (L_aux (L_num !i, gen_loc l), atom_typ (nconstant !i))), env, gen_loc l) + in + let iteration = + AE_aux (AE_let (Immutable, id, atom_typ (nconstant !i), current_index, body, unit_typ), env, gen_loc l) + in + unrolled := iteration :: !unrolled; + i := Big_int.add !i b + done; + begin + match !unrolled with + | last :: iterations -> AE_aux (AE_block (List.rev iterations, last, unit_typ), env, gen_loc l) + | [] -> AE_aux (AE_val (AV_lit (L_aux (L_unit, gen_loc l), unit_typ)), env, gen_loc l) + end + | _ -> aexp + end + | aexp -> aexp + + let optimize_anf ctx aexp = aexp |> c_literals ctx |> fold_aexp (unroll_static_foreach ctx) let specialize_calls = true let ignore_64 = true @@ -1639,13 +1568,11 @@ let unroll_static_foreach ctx = function let track_throw = false end - (**************************************************************************) (* 3. Generating SMT *) (**************************************************************************) -let push_smt_defs stack smt_defs = - List.iter (fun def -> Stack.push def stack) smt_defs +let push_smt_defs stack smt_defs = List.iter (fun def -> Stack.push def stack) smt_defs (* When generating SMT when we encounter joins between two or more blocks such as in the example below, we have to generate a muxer @@ -1675,28 +1602,25 @@ let smt_ssanode ctx cfg preds = let open Jib_ssa in function | Pi _ -> [] - | Phi (id, ctyp, ids) -> - let get_pi n = - match get_vertex cfg n with - | Some ((ssa_elems, _), _, _) -> - List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems) - | None -> failwith "Predecessor node does not exist" - in - let pis = List.map get_pi (IntSet.elements preds) in - let mux = - List.fold_right2 (fun pi id chain -> - let pathcond = smt_conj (List.map (smt_cval ctx) pi) in - match chain with - | Some smt -> - Some (Ite (pathcond, Var (zencode_name id), smt)) - | None -> - Some (Var (zencode_name id))) - pis ids None - in - match mux with - | None -> assert false - | Some mux -> - [Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)] + | Phi (id, ctyp, ids) -> ( + let get_pi n = + match get_vertex cfg n with + | Some ((ssa_elems, _), _, _) -> List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems) + | None -> failwith "Predecessor node does not exist" + in + let pis = List.map get_pi (IntSet.elements preds) in + let mux = + List.fold_right2 + (fun pi id chain -> + let pathcond = smt_conj (List.map (smt_cval ctx) pi) in + match chain with + | Some smt -> Some (Ite (pathcond, Var (zencode_name id), smt)) + | None -> Some (Var (zencode_name id)) + ) + pis ids None + in + match mux with None -> assert false | Some mux -> [Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)] + ) (* The pi condition are computed by traversing the dominator tree, with each node having a pi condition defined as the conjunction of @@ -1744,21 +1668,18 @@ let rec get_pathcond n cfg ctx = let get_pi m = match get_vertex cfg m with | Some ((ssa_elems, _), _, _) -> - V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems)) + V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems)) | None -> failwith "Node does not exist" in match get_vertex cfg n with - | Some ((_, CF_guard cond), _, _) -> - smt_cval ctx (get_pi n) + | Some ((_, CF_guard cond), _, _) -> smt_cval ctx (get_pi n) | Some (_, preds, succs) -> - if IntSet.cardinal preds = 0 then - Bool_lit true - else if IntSet.cardinal preds = 1 then - get_pathcond (IntSet.min_elt preds) cfg ctx - else - let pis = List.map get_pi (IntSet.elements preds) in - smt_cval ctx (V_call (Bor, pis)) - + if IntSet.cardinal preds = 0 then Bool_lit true + else if IntSet.cardinal preds = 1 then get_pathcond (IntSet.min_elt preds) cfg ctx + else ( + let pis = List.map get_pi (IntSet.elements preds) in + smt_cval ctx (V_call (Bor, pis)) + ) | None -> assert false (* Should never be called for a non-existent node *) (* For any complex l-expression we need to turn it into a @@ -1768,59 +1689,48 @@ let rec get_pathcond n cfg ctx = name but different SSA numbers. *) let rec rmw_write = function - | CL_rmw (_, write, ctyp) -> write, ctyp + | CL_rmw (_, write, ctyp) -> (write, ctyp) | CL_id _ -> assert false | CL_tuple (clexp, _) -> rmw_write clexp | CL_field (clexp, _) -> rmw_write clexp | clexp -> failwith "Could not understand l-expression" -let rmw_read = function - | CL_rmw (read, _, _) -> zencode_name read - | _ -> assert false +let rmw_read = function CL_rmw (read, _, _) -> zencode_name read | _ -> assert false let rmw_modify smt = function | CL_tuple (clexp, n) -> - let ctyp = clexp_ctyp clexp in - begin match ctyp with - | CT_tup ctyps -> - let len = List.length ctyps in - let set_tup i = - if i == n then - smt - else - Fn (Printf.sprintf "tup_%d_%d" len i, [Var (rmw_read clexp)]) - in - Fn ("tup" ^ string_of_int len, List.init len set_tup) - | _ -> - failwith "Tuple modify does not have tuple type" - end + let ctyp = clexp_ctyp clexp in + begin + match ctyp with + | CT_tup ctyps -> + let len = List.length ctyps in + let set_tup i = if i == n then smt else Fn (Printf.sprintf "tup_%d_%d" len i, [Var (rmw_read clexp)]) in + Fn ("tup" ^ string_of_int len, List.init len set_tup) + | _ -> failwith "Tuple modify does not have tuple type" + end | CL_field (clexp, field) -> - let ctyp = clexp_ctyp clexp in - begin match ctyp with - | CT_struct (struct_id, fields) -> - let set_field (field', _) = - if Id.compare field field' = 0 then - smt - else - Field (zencode_upper_id struct_id ^ "_" ^ zencode_id field', Var (rmw_read clexp)) - in - Fn (zencode_upper_id struct_id, List.map set_field fields) - | _ -> - failwith "Struct modify does not have struct type" - end + let ctyp = clexp_ctyp clexp in + begin + match ctyp with + | CT_struct (struct_id, fields) -> + let set_field (field', _) = + if Id.compare field field' = 0 then smt + else Field (zencode_upper_id struct_id ^ "_" ^ zencode_id field', Var (rmw_read clexp)) + in + Fn (zencode_upper_id struct_id, List.map set_field fields) + | _ -> failwith "Struct modify does not have struct type" + end | _ -> assert false let smt_terminator ctx = let open Jib_ssa in function | T_end id -> - add_event ctx Return (Var (zencode_name id)); - [] - + add_event ctx Return (Var (zencode_name id)); + [] | T_exit _ -> - add_pathcond_event ctx Match; - [] - + add_pathcond_event ctx Match; + [] | T_undefined _ | T_goto _ | T_jump _ | T_label _ | T_none -> [] (* For a basic block (contained in a control-flow node / cfnode), we @@ -1832,175 +1742,161 @@ let smt_instr ctx = let open Type_check in function | I_aux (I_funcall (CL_id (id, ret_ctyp), extern, function_id, args), (_, l)) -> - if Env.is_extern (fst function_id) ctx.tc_env "c" && not extern then - let name = Env.get_extern (fst function_id) ctx.tc_env "c" in - if name = "sqrt_real" then - begin match args with - | [v] -> builtin_sqrt_real ctx (zencode_name id) v - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for sqrt_real" - end - (* See lib/regfp.sail *) - else if name = "platform_write_mem" then - begin match args with - | [wk; addr_size; addr; data_size; data] -> - let mem_event, var = builtin_write_mem l ctx wk addr_size addr data_size data in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __write_mem" - end - else if name = "platform_write_mem_ea" then - begin match args with - | [wk; addr_size; addr; data_size] -> - let mem_event, var = builtin_write_mem_ea ctx wk addr_size addr data_size in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __write_mem_ea" - end - else if name = "platform_read_mem" then - begin match args with - | [rk; addr_size; addr; data_size] -> - let mem_event, var = builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __read_mem" - end - else if name = "platform_barrier" then - begin match args with - | [bk] -> - let mem_event, var = builtin_barrier l ctx bk in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __barrier" - end - else if name = "platform_cache_maintenance" then - begin match args with - | [cmk; addr_size; addr] -> - let mem_event, var = builtin_cache_maintenance l ctx cmk addr_size addr in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __barrier" - end - else if name = "platform_branch_announce" then - begin match args with - | [addr_size; addr] -> - let mem_event, var = builtin_branch_announce l ctx addr_size addr in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __barrier" - end - else if name = "platform_excl_res" then - begin match args with - | [_] -> - let mem_event, var = builtin_excl_res ctx in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for __excl_res" - end - else if name = "sail_exit" then - (add_event ctx Assertion (Bool_lit false); []) - else if name = "sail_assert" then - begin match args with - | [assertion; _] -> - let smt = smt_cval ctx assertion in - add_event ctx Assertion (Fn ("not", [smt])); - [] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for assertion" - end - else - let value = smt_builtin ctx name args ret_ctyp in - [define_const ctx id ret_ctyp (Syntactic (value, List.map (smt_cval ctx) args))] - else if extern && string_of_id (fst function_id) = "internal_vector_init" then - [declare_const ctx id ret_ctyp] - else if extern && string_of_id (fst function_id) = "internal_vector_update" then - begin match args with - | [vec; i; x] -> - let sz = int_size ctx (cval_ctyp i) in - [define_const ctx id ret_ctyp - (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" - end - else if (string_of_id (fst function_id) = "update_fbits" - || string_of_id (fst function_id) = "update_lbits") && extern then - begin match args with - | [vec; i; x] -> - [define_const ctx id ret_ctyp (builtin_vector_update ctx vec i x ret_ctyp)] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for update_{f,l}bits" - end - else if string_of_id (fst function_id) = "sail_assume" then - begin match args with - | [assumption] -> - let smt = smt_cval ctx assumption in - add_event ctx Assumption smt; + if Env.is_extern (fst function_id) ctx.tc_env "c" && not extern then ( + let name = Env.get_extern (fst function_id) ctx.tc_env "c" in + if name = "sqrt_real" then begin + match args with + | [v] -> builtin_sqrt_real ctx (zencode_name id) v + | _ -> Reporting.unreachable l __POS__ "Bad arguments for sqrt_real" + (* See lib/regfp.sail *) + end + else if name = "platform_write_mem" then begin + match args with + | [wk; addr_size; addr; data_size; data] -> + let mem_event, var = builtin_write_mem l ctx wk addr_size addr data_size data in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __write_mem" + end + else if name = "platform_write_mem_ea" then begin + match args with + | [wk; addr_size; addr; data_size] -> + let mem_event, var = builtin_write_mem_ea ctx wk addr_size addr data_size in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __write_mem_ea" + end + else if name = "platform_read_mem" then begin + match args with + | [rk; addr_size; addr; data_size] -> + let mem_event, var = builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __read_mem" + end + else if name = "platform_barrier" then begin + match args with + | [bk] -> + let mem_event, var = builtin_barrier l ctx bk in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" + end + else if name = "platform_cache_maintenance" then begin + match args with + | [cmk; addr_size; addr] -> + let mem_event, var = builtin_cache_maintenance l ctx cmk addr_size addr in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" + end + else if name = "platform_branch_announce" then begin + match args with + | [addr_size; addr] -> + let mem_event, var = builtin_branch_announce l ctx addr_size addr in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" + end + else if name = "platform_excl_res" then begin + match args with + | [_] -> + let mem_event, var = builtin_excl_res ctx in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for __excl_res" + end + else if name = "sail_exit" then ( + add_event ctx Assertion (Bool_lit false); [] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for assumption" - end - else if not extern then - let smt_args = List.map (smt_cval ctx) args in - [define_const ctx id ret_ctyp (Ctor (zencode_uid function_id, smt_args))] - else - failwith ("Unrecognised function " ^ string_of_uid function_id) - + ) + else if name = "sail_assert" then begin + match args with + | [assertion; _] -> + let smt = smt_cval ctx assertion in + add_event ctx Assertion (Fn ("not", [smt])); + [] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for assertion" + end + else ( + let value = smt_builtin ctx name args ret_ctyp in + [define_const ctx id ret_ctyp (Syntactic (value, List.map (smt_cval ctx) args))] + ) + ) + else if extern && string_of_id (fst function_id) = "internal_vector_init" then [declare_const ctx id ret_ctyp] + else if extern && string_of_id (fst function_id) = "internal_vector_update" then begin + match args with + | [vec; i; x] -> + let sz = int_size ctx (cval_ctyp i) in + [ + define_const ctx id ret_ctyp + (Fn + ( "store", + [ + smt_cval ctx vec; + force_size ~checked:false ctx ctx.vector_index sz (smt_cval ctx i); + smt_cval ctx x; + ] + ) + ); + ] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" + end + else if + (string_of_id (fst function_id) = "update_fbits" || string_of_id (fst function_id) = "update_lbits") && extern + then begin + match args with + | [vec; i; x] -> [define_const ctx id ret_ctyp (builtin_vector_update ctx vec i x ret_ctyp)] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for update_{f,l}bits" + end + else if string_of_id (fst function_id) = "sail_assume" then begin + match args with + | [assumption] -> + let smt = smt_cval ctx assumption in + add_event ctx Assumption smt; + [] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for assumption" + end + else if not extern then ( + let smt_args = List.map (smt_cval ctx) args in + [define_const ctx id ret_ctyp (Ctor (zencode_uid function_id, smt_args))] + ) + else failwith ("Unrecognised function " ^ string_of_uid function_id) | I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) -> - Reporting.unreachable l __POS__ "Register reference write should be re-written by now" - - | I_aux (I_init (ctyp, id, cval), _) | I_aux (I_copy (CL_id (id, ctyp), cval), _) -> - begin match id, cval with - | (Name (id, _) | Global (id, _)), _ when IdSet.mem id ctx.preserved -> - [preserve_const ctx id ctyp - (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] - | _, V_lit (VL_undefined, _) -> - (* Declare undefined variables as arbitrary but fixed *) - [declare_const ctx id ctyp] - | _, _ -> - [define_const ctx id ctyp - (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] - end - + Reporting.unreachable l __POS__ "Register reference write should be re-written by now" + | I_aux (I_init (ctyp, id, cval), _) | I_aux (I_copy (CL_id (id, ctyp), cval), _) -> begin + match (id, cval) with + | (Name (id, _) | Global (id, _)), _ when IdSet.mem id ctx.preserved -> + [preserve_const ctx id ctyp (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] + | _, V_lit (VL_undefined, _) -> + (* Declare undefined variables as arbitrary but fixed *) + [declare_const ctx id ctyp] + | _, _ -> [define_const ctx id ctyp (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] + end | I_aux (I_copy (clexp, cval), _) -> - let smt = smt_cval ctx cval in - let write, ctyp = rmw_write clexp in - [define_const ctx write ctyp (rmw_modify smt clexp)] - + let smt = smt_cval ctx cval in + let write, ctyp = rmw_write clexp in + [define_const ctx write ctyp (rmw_modify smt clexp)] | I_aux (I_decl (ctyp, id), (_, l)) -> - (* Function arguments have unique locations defined from the - $property pragma. We record how they will appear in the - generated SMT so we can check models. *) - begin match l with - | Unique (n, l') when l' = ctx.pragma_l -> - Stack.push (n, zencode_name id) ctx.arg_stack - | _ -> () - end; - [declare_const ctx id ctyp] - + (* Function arguments have unique locations defined from the + $property pragma. We record how they will appear in the + generated SMT so we can check models. *) + begin + match l with Unique (n, l') when l' = ctx.pragma_l -> Stack.push (n, zencode_name id) ctx.arg_stack | _ -> () + end; + [declare_const ctx id ctyp] | I_aux (I_clear _, _) -> [] - (* Should only appear as terminators for basic blocks. *) | I_aux ((I_jump _ | I_goto _ | I_end _ | I_exit _ | I_undefined _), (_, l)) -> - Reporting.unreachable l __POS__ "SMT: Instruction should only appear as block terminator" - - | I_aux (_, (_, l)) -> - Reporting.unreachable l __POS__ "Cannot translate instruction" + Reporting.unreachable l __POS__ "SMT: Instruction should only appear as block terminator" + | I_aux (_, (_, l)) -> Reporting.unreachable l __POS__ "Cannot translate instruction" let smt_cfnode all_cdefs ctx ssa_elems = let open Jib_ssa in function | CF_start inits -> - let smt_reg_decs = generate_reg_decs ctx inits all_cdefs in - let smt_start (id, ctyp) = - match id with - | Have_exception _ -> define_const ctx id ctyp (Bool_lit false) - | _ -> declare_const ctx id ctyp - in - smt_reg_decs @ List.map smt_start (NameMap.bindings inits) + let smt_reg_decs = generate_reg_decs ctx inits all_cdefs in + let smt_start (id, ctyp) = + match id with Have_exception _ -> define_const ctx id ctyp (Bool_lit false) | _ -> declare_const ctx id ctyp + in + smt_reg_decs @ List.map smt_start (NameMap.bindings inits) | CF_block (instrs, terminator) -> - let smt_instrs = List.map (smt_instr ctx) instrs in - let smt_term = smt_terminator ctx terminator in - List.concat (smt_instrs @ [smt_term]) + let smt_instrs = List.map (smt_instr ctx) instrs in + let smt_term = smt_terminator ctx terminator in + List.concat (smt_instrs @ [smt_term]) (* We can ignore any non basic-block/start control-flow nodes *) | _ -> [] @@ -2009,14 +1905,12 @@ let smt_cfnode all_cdefs ctx ssa_elems = keep track of any global letbindings between the spec and the fundef, so they can appear in the generated SMT. *) let rec find_function lets id = function - | CDEF_fundef (id', heap_return, args, body) :: _ when Id.compare id id' = 0 -> - lets, Some (heap_return, args, body) + | CDEF_fundef (id', heap_return, args, body) :: _ when Id.compare id id' = 0 -> (lets, Some (heap_return, args, body)) | CDEF_let (_, vars, setup) :: cdefs -> - let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (global id)) vars in - find_function (lets @ vars @ setup) id cdefs; - | _ :: cdefs -> - find_function lets id cdefs - | [] -> lets, None + let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (global id)) vars in + find_function (lets @ vars @ setup) id cdefs + | _ :: cdefs -> find_function lets id cdefs + | [] -> (lets, None) module type Sequence = sig type 'a t @@ -2024,78 +1918,82 @@ module type Sequence = sig val add : 'a -> 'a t -> unit end -module Make_optimizer(S : Sequence) = struct - +module Make_optimizer (S : Sequence) = struct let optimize stack = let stack' = Stack.create () in let uses = Hashtbl.create (Stack.length stack) in let rec uses_in_exp = function - | Var var -> - begin match Hashtbl.find_opt uses var with - | Some n -> Hashtbl.replace uses var (n + 1) - | None -> Hashtbl.add uses var 1 - end + | Var var -> begin + match Hashtbl.find_opt uses var with + | Some n -> Hashtbl.replace uses var (n + 1) + | None -> Hashtbl.add uses var 1 + end | Syntactic (exp, _) -> uses_in_exp exp | Shared _ | Enum _ | Read_res _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _ -> () - | Fn (_, exps) | Ctor (_, exps) -> - List.iter uses_in_exp exps - | Field (_, exp) -> - uses_in_exp exp - | Struct (_, fields) -> - List.iter (fun (_, exp) -> uses_in_exp exp) fields + | Fn (_, exps) | Ctor (_, exps) -> List.iter uses_in_exp exps + | Field (_, exp) -> uses_in_exp exp + | Struct (_, fields) -> List.iter (fun (_, exp) -> uses_in_exp exp) fields | Ite (cond, t, e) -> - uses_in_exp cond; uses_in_exp t; uses_in_exp e - | Extract (_, _, exp) | Tester (_, exp) | SignExtend (_, exp) -> - uses_in_exp exp + uses_in_exp cond; + uses_in_exp t; + uses_in_exp e + | Extract (_, _, exp) | Tester (_, exp) | SignExtend (_, exp) -> uses_in_exp exp | Forall _ -> assert false in let remove_unused () = function - | Declare_const (var, _) as def -> - begin match Hashtbl.find_opt uses var with - | None -> () - | Some _ -> - Stack.push def stack' - end - | Declare_fun _ as def -> - Stack.push def stack' + | Declare_const (var, _) as def -> begin + match Hashtbl.find_opt uses var with None -> () | Some _ -> Stack.push def stack' + end + | Declare_fun _ as def -> Stack.push def stack' | Preserve_const (_, _, exp) as def -> - uses_in_exp exp; - Stack.push def stack' - | Define_const (var, _, exp) as def -> - begin match Hashtbl.find_opt uses var with - | None -> () - | Some _ -> - uses_in_exp exp; - Stack.push def stack' - end - | (Declare_datatypes _ | Declare_tuple _) as def -> - Stack.push def stack' + uses_in_exp exp; + Stack.push def stack' + | Define_const (var, _, exp) as def -> begin + match Hashtbl.find_opt uses var with + | None -> () + | Some _ -> + uses_in_exp exp; + Stack.push def stack' + end + | (Declare_datatypes _ | Declare_tuple _) as def -> Stack.push def stack' | Write_mem w as def -> - uses_in_exp w.active; uses_in_exp w.kind; uses_in_exp w.addr; uses_in_exp w.data; - Stack.push def stack' + uses_in_exp w.active; + uses_in_exp w.kind; + uses_in_exp w.addr; + uses_in_exp w.data; + Stack.push def stack' | Write_mem_ea (_, _, active, wk, addr, _, data_size, _) as def -> - uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data_size; - Stack.push def stack' + uses_in_exp active; + uses_in_exp wk; + uses_in_exp addr; + uses_in_exp data_size; + Stack.push def stack' | Read_mem r as def -> - uses_in_exp r.active; uses_in_exp r.kind; uses_in_exp r.addr; - Stack.push def stack' + uses_in_exp r.active; + uses_in_exp r.kind; + uses_in_exp r.addr; + Stack.push def stack' | Barrier b as def -> - uses_in_exp b.active; uses_in_exp b.kind; - Stack.push def stack' + uses_in_exp b.active; + uses_in_exp b.kind; + Stack.push def stack' | Cache_maintenance m as def -> - uses_in_exp m.active; uses_in_exp m.kind; uses_in_exp m.addr; - Stack.push def stack' + uses_in_exp m.active; + uses_in_exp m.kind; + uses_in_exp m.addr; + Stack.push def stack' | Branch_announce c as def -> - uses_in_exp c.active; uses_in_exp c.addr; - Stack.push def stack' + uses_in_exp c.active; + uses_in_exp c.addr; + Stack.push def stack' | Excl_res (_, _, active) as def -> - uses_in_exp active; - Stack.push def stack' + uses_in_exp active; + Stack.push def stack' | Assert exp as def -> - uses_in_exp exp; - Stack.push def stack' + uses_in_exp exp; + Stack.push def stack' | Define_fun _ -> assert false in Stack.fold remove_unused () stack; @@ -2105,82 +2003,100 @@ module Make_optimizer(S : Sequence) = struct let seq = S.create () in let constant_propagate = function - | Declare_const _ as def -> - S.add def seq - | Declare_fun _ as def -> - S.add def seq - | Preserve_const (var, typ, exp) -> - S.add (Preserve_const (var, typ, simp_smt_exp vars kinds exp)) seq + | Declare_const _ as def -> S.add def seq + | Declare_fun _ as def -> S.add def seq + | Preserve_const (var, typ, exp) -> S.add (Preserve_const (var, typ, simp_smt_exp vars kinds exp)) seq | Define_const (var, typ, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin match Hashtbl.find_opt uses var, simp_smt_exp vars kinds exp with - | _, (Bitvec_lit _ | Bool_lit _) -> - Hashtbl.add vars var exp - | _, Var _ when !opt_propagate_vars -> - Hashtbl.add vars var exp - | _, (Ctor (str, _)) -> - Hashtbl.add kinds var str; - S.add (Define_const (var, typ, exp)) seq - | Some 1, _ -> - Hashtbl.add vars var exp - | Some _, exp -> - S.add (Define_const (var, typ, exp)) seq - | None, _ -> assert false - end + let exp = simp_smt_exp vars kinds exp in + begin + match (Hashtbl.find_opt uses var, simp_smt_exp vars kinds exp) with + | _, (Bitvec_lit _ | Bool_lit _) -> Hashtbl.add vars var exp + | _, Var _ when !opt_propagate_vars -> Hashtbl.add vars var exp + | _, Ctor (str, _) -> + Hashtbl.add kinds var str; + S.add (Define_const (var, typ, exp)) seq + | Some 1, _ -> Hashtbl.add vars var exp + | Some _, exp -> S.add (Define_const (var, typ, exp)) seq + | None, _ -> assert false + end | Write_mem w -> - S.add (Write_mem { w with active = simp_smt_exp vars kinds w.active; - kind = simp_smt_exp vars kinds w.kind; - addr = simp_smt_exp vars kinds w.addr; - data = simp_smt_exp vars kinds w.data }) - seq + S.add + (Write_mem + { + w with + active = simp_smt_exp vars kinds w.active; + kind = simp_smt_exp vars kinds w.kind; + addr = simp_smt_exp vars kinds w.addr; + data = simp_smt_exp vars kinds w.data; + } + ) + seq | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - S.add (Write_mem_ea (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk, - simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data_size, data_size_ty)) - seq + S.add + (Write_mem_ea + ( name, + node, + simp_smt_exp vars kinds active, + simp_smt_exp vars kinds wk, + simp_smt_exp vars kinds addr, + addr_ty, + simp_smt_exp vars kinds data_size, + data_size_ty + ) + ) + seq | Read_mem r -> - S.add (Read_mem { r with active = simp_smt_exp vars kinds r.active; - kind = simp_smt_exp vars kinds r.kind; - addr = simp_smt_exp vars kinds r.addr }) - seq + S.add + (Read_mem + { + r with + active = simp_smt_exp vars kinds r.active; + kind = simp_smt_exp vars kinds r.kind; + addr = simp_smt_exp vars kinds r.addr; + } + ) + seq | Barrier b -> - S.add (Barrier { b with active = simp_smt_exp vars kinds b.active; kind = simp_smt_exp vars kinds b.kind }) seq + S.add + (Barrier { b with active = simp_smt_exp vars kinds b.active; kind = simp_smt_exp vars kinds b.kind }) + seq | Cache_maintenance m -> - S.add (Cache_maintenance { m with active = simp_smt_exp vars kinds m.active; - kind = simp_smt_exp vars kinds m.kind; - addr = simp_smt_exp vars kinds m.addr }) - seq + S.add + (Cache_maintenance + { + m with + active = simp_smt_exp vars kinds m.active; + kind = simp_smt_exp vars kinds m.kind; + addr = simp_smt_exp vars kinds m.addr; + } + ) + seq | Branch_announce c -> - S.add (Branch_announce { c with active = simp_smt_exp vars kinds c.active; addr = simp_smt_exp vars kinds c.addr }) seq - | Excl_res (name, node, active) -> - S.add (Excl_res (name, node, simp_smt_exp vars kinds active)) seq - | Assert exp -> - S.add (Assert (simp_smt_exp vars kinds exp)) seq - | (Declare_datatypes _ | Declare_tuple _) as def -> - S.add def seq + S.add + (Branch_announce { c with active = simp_smt_exp vars kinds c.active; addr = simp_smt_exp vars kinds c.addr }) + seq + | Excl_res (name, node, active) -> S.add (Excl_res (name, node, simp_smt_exp vars kinds active)) seq + | Assert exp -> S.add (Assert (simp_smt_exp vars kinds exp)) seq + | (Declare_datatypes _ | Declare_tuple _) as def -> S.add def seq | Define_fun _ -> assert false in Stack.iter constant_propagate stack'; seq - end -module Queue_optimizer = - Make_optimizer(struct - type 'a t = 'a Queue.t - let create = Queue.create - let add = Queue.add - let iter = Queue.iter - end) +module Queue_optimizer = Make_optimizer (struct + type 'a t = 'a Queue.t + let create = Queue.create + let add = Queue.add + let iter = Queue.iter +end) (** [smt_header ctx cdefs] produces a list of smt definitions for all the datatypes in a specification *) let smt_header ctx cdefs = let smt_ctype_defs = List.concat (generate_ctype_defs ctx cdefs) in [declare_datatypes (mk_enum "Unit" ["unit"])] @ (IntSet.elements !(ctx.tuple_sizes) |> List.map (fun n -> Declare_tuple n)) - @ [declare_datatypes (mk_record "Bits" [("len", Bitvec ctx.lbits_index); - ("contents", Bitvec (lbits_size ctx))]) - - ] + @ [declare_datatypes (mk_record "Bits" [("len", Bitvec ctx.lbits_index); ("contents", Bitvec (lbits_size ctx))])] @ smt_ctype_defs (* For generating SMT when we have a reg_deref(r : register(t)) @@ -2189,87 +2105,90 @@ let smt_header ctx cdefs = register if it is. We also do a similar thing for *r = x *) let expand_reg_deref env register_map = function - | I_aux (I_funcall (CL_addr (CL_id (id, ctyp)), false, function_id, args), (_, l)) -> - begin match ctyp with - | CT_ref reg_ctyp -> - begin match CTMap.find_opt reg_ctyp register_map with - | Some regs -> - let end_label = label "end_reg_write_" in - let try_reg r = - let next_label = label "next_reg_write_" in - [ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; - ifuncall l (CL_id (global r, reg_ctyp)) function_id args; - igoto end_label; - ilabel next_label] - in - iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) - | None -> - raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) + | I_aux (I_funcall (CL_addr (CL_id (id, ctyp)), false, function_id, args), (_, l)) -> begin + match ctyp with + | CT_ref reg_ctyp -> begin + match CTMap.find_opt reg_ctyp register_map with + | Some regs -> + let end_label = label "end_reg_write_" in + let try_reg r = + let next_label = label "next_reg_write_" in + [ + ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; + ifuncall l (CL_id (global r, reg_ctyp)) function_id args; + igoto end_label; + ilabel next_label; + ] + in + iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) + | None -> raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) end - | _ -> - raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") - end + | _ -> + raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") + end | I_aux (I_funcall (clexp, false, function_id, [reg_ref]), (_, l)) as instr -> - let open Type_check in - begin match (if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") else None) with - | Some "reg_deref" -> - begin match cval_ctyp reg_ref with - | CT_ref reg_ctyp -> - (* Not find all the registers with this ctyp *) - begin match CTMap.find_opt reg_ctyp register_map with - | Some regs -> - let end_label = label "end_reg_deref_" in + let open Type_check in + begin + match + if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") else None + with + | Some "reg_deref" -> begin + match cval_ctyp reg_ref with + | CT_ref reg_ctyp -> begin + (* Not find all the registers with this ctyp *) + match CTMap.find_opt reg_ctyp register_map with + | Some regs -> + let end_label = label "end_reg_deref_" in + let try_reg r = + let next_label = label "next_reg_deref_" in + [ + ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); reg_ref])) next_label; + icopy l clexp (V_id (global r, reg_ctyp)); + igoto end_label; + ilabel next_label; + ] + in + iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) + | None -> + raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) + end + | _ -> raise (Reporting.err_general l "Register dereference must have a register reference as an argument") + end + | _ -> instr + end + | I_aux (I_copy (CL_addr (CL_id (id, ctyp)), cval), (_, l)) -> begin + match ctyp with + | CT_ref reg_ctyp -> begin + match CTMap.find_opt reg_ctyp register_map with + | Some regs -> + let end_label = label "end_reg_write_" in let try_reg r = - let next_label = label "next_reg_deref_" in - [ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); reg_ref])) next_label; - icopy l clexp (V_id (global r, reg_ctyp)); - igoto end_label; - ilabel next_label] + let next_label = label "next_reg_write_" in + [ + ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; + icopy l (CL_id (global r, reg_ctyp)) cval; + igoto end_label; + ilabel next_label; + ] in iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) - | None -> - raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) - end - | _ -> - raise (Reporting.err_general l "Register dereference must have a register reference as an argument") - end - | _ -> instr - end - | I_aux (I_copy (CL_addr (CL_id (id, ctyp)), cval), (_, l)) -> - begin match ctyp with - | CT_ref reg_ctyp -> - begin match CTMap.find_opt reg_ctyp register_map with - | Some regs -> - let end_label = label "end_reg_write_" in - let try_reg r = - let next_label = label "next_reg_write_" in - [ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; - icopy l (CL_id (global r, reg_ctyp)) cval; - igoto end_label; - ilabel next_label] - in - iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) - | None -> - raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) + | None -> raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) end - | _ -> - raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") - end + | _ -> + raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") + end | instr -> instr let rec smt_query ctx = function | Q_all ev -> - let stack = event_stack ctx ev in - smt_conj (Stack.fold (fun xs x -> x :: xs) [] stack) + let stack = event_stack ctx ev in + smt_conj (Stack.fold (fun xs x -> x :: xs) [] stack) | Q_exist ev -> - let stack = event_stack ctx ev in - smt_disj (Stack.fold (fun xs x -> x :: xs) [] stack) - | Q_not q -> - Fn ("not", [smt_query ctx q]) - | Q_and qs -> - Fn ("and", List.map (smt_query ctx) qs) - | Q_or qs -> - Fn ("or", List.map (smt_query ctx) qs) + let stack = event_stack ctx ev in + smt_disj (Stack.fold (fun xs x -> x :: xs) [] stack) + | Q_not q -> Fn ("not", [smt_query ctx q]) + | Q_and qs -> Fn ("and", List.map (smt_query ctx) qs) + | Q_or qs -> Fn ("or", List.map (smt_query ctx) qs) let dump_graph name cfg = let gv_file = name ^ ".gv" in @@ -2283,106 +2202,117 @@ let smt_instr_list name ctx all_cdefs instrs = let open Jib_ssa in let start, cfg = ssa instrs in let visit_order = - try topsort cfg with - | Not_a_DAG n -> - dump_graph name cfg; - raise (Reporting.err_general ctx.pragma_l - (Printf.sprintf "%s: control flow graph is not acyclic (node %d is in cycle)\nWrote graph to %s.gv" name n name)) + try topsort cfg + with Not_a_DAG n -> + dump_graph name cfg; + raise + (Reporting.err_general ctx.pragma_l + (Printf.sprintf "%s: control flow graph is not acyclic (node %d is in cycle)\nWrote graph to %s.gv" name n + name + ) + ) in - if !opt_debug_graphs then - dump_graph name cfg; + if !opt_debug_graphs then dump_graph name cfg; - List.iter (fun n -> + List.iter + (fun n -> match get_vertex cfg n with | None -> () | Some ((ssa_elems, cfnode), preds, succs) -> - let muxers = - ssa_elems |> List.map (smt_ssanode ctx cfg preds) |> List.concat - in - let ctx = { ctx with node = n; pathcond = lazy (get_pathcond n cfg ctx) } in - let basic_block = smt_cfnode all_cdefs ctx ssa_elems cfnode in - push_smt_defs stack muxers; - push_smt_defs stack basic_block - ) visit_order; + let muxers = ssa_elems |> List.map (smt_ssanode ctx cfg preds) |> List.concat in + let ctx = { ctx with node = n; pathcond = lazy (get_pathcond n cfg ctx) } in + let basic_block = smt_cfnode all_cdefs ctx ssa_elems cfnode in + push_smt_defs stack muxers; + push_smt_defs stack basic_block + ) + visit_order; - stack, start, cfg + (stack, start, cfg) let smt_cdef props lets name_file ctx all_cdefs = function - | CDEF_val (function_id, _, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> - begin match find_function [] function_id all_cdefs with - | intervening_lets, Some (None, args, instrs) -> - let prop_type, prop_args, pragma_l, vs = Bindings.find function_id props in - - let pragma = parse_pragma pragma_l prop_args in - - let ctx = { ctx with events = ref EventMap.empty; pragma_l = pragma_l; arg_stack = Stack.create () } in - - (* When we create each argument declaration, give it a unique - location from the $property pragma, so we can identify it later. *) - let arg_decls = - List.map2 (fun id ctyp -> let l = unique pragma_l in idecl l ctyp (name id)) args arg_ctyps - in - let instrs = - let open Jib_optimize in - (lets @ intervening_lets @ arg_decls @ instrs) - |> inline all_cdefs (fun _ -> true) - |> List.map (map_instr (expand_reg_deref ctx.tc_env ctx.register_map)) - |> flatten_instrs - |> remove_unused_labels - |> remove_pointless_goto - in - - let stack, _, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in - - let query = smt_query ctx pragma.query in - push_smt_defs stack [Assert (Fn ("not", [query]))]; - - let fname = name_file (string_of_id function_id) in - let out_chan = open_out fname in - if prop_type = "counterexample" then - output_string out_chan "(set-option :produce-models true)\n"; - - let header = smt_header ctx all_cdefs in - - if !(ctx.use_string) || !(ctx.use_real) then - output_string out_chan "(set-logic ALL)\n" - else - output_string out_chan "(set-logic QF_AUFBVDT)\n"; - - List.iter (fun def -> output_string out_chan (string_of_smt_def def); output_string out_chan "\n") header; - - let queue = Queue_optimizer.optimize stack in - Queue.iter (fun def -> output_string out_chan (string_of_smt_def def); output_string out_chan "\n") queue; - - output_string out_chan "(check-sat)\n"; - if prop_type = "counterexample" then - output_string out_chan "(get-model)\n"; - - close_out out_chan; - if prop_type = "counterexample" && !opt_auto then ( - let arg_names = Stack.fold (fun m (k, v) -> (k, v) :: m) [] ctx.arg_stack in - let arg_smt_names = - List.map (function - | (I_aux (I_decl (_, Name (id, _)), (_, Unique (n, _)))) -> - (id, List.assoc_opt n arg_names) - | _ -> assert false - ) arg_decls + | CDEF_val (function_id, _, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> begin + match find_function [] function_id all_cdefs with + | intervening_lets, Some (None, args, instrs) -> + let prop_type, prop_args, pragma_l, vs = Bindings.find function_id props in + + let pragma = parse_pragma pragma_l prop_args in + + let ctx = { ctx with events = ref EventMap.empty; pragma_l; arg_stack = Stack.create () } in + + (* When we create each argument declaration, give it a unique + location from the $property pragma, so we can identify it later. *) + let arg_decls = + List.map2 + (fun id ctyp -> + let l = unique pragma_l in + idecl l ctyp (name id) + ) + args arg_ctyps + in + let instrs = + let open Jib_optimize in + lets @ intervening_lets @ arg_decls @ instrs + |> inline all_cdefs (fun _ -> true) + |> List.map (map_instr (expand_reg_deref ctx.tc_env ctx.register_map)) + |> flatten_instrs |> remove_unused_labels |> remove_pointless_goto in - check_counterexample ctx.ast ctx.tc_env fname function_id args arg_ctyps arg_smt_names - ); - | _ -> failwith "Bad function body" - end + let stack, _, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in + + let query = smt_query ctx pragma.query in + push_smt_defs stack [Assert (Fn ("not", [query]))]; + + let fname = name_file (string_of_id function_id) in + let out_chan = open_out fname in + if prop_type = "counterexample" then output_string out_chan "(set-option :produce-models true)\n"; + + let header = smt_header ctx all_cdefs in + + if !(ctx.use_string) || !(ctx.use_real) then output_string out_chan "(set-logic ALL)\n" + else output_string out_chan "(set-logic QF_AUFBVDT)\n"; + + List.iter + (fun def -> + output_string out_chan (string_of_smt_def def); + output_string out_chan "\n" + ) + header; + + let queue = Queue_optimizer.optimize stack in + Queue.iter + (fun def -> + output_string out_chan (string_of_smt_def def); + output_string out_chan "\n" + ) + queue; + + output_string out_chan "(check-sat)\n"; + if prop_type = "counterexample" then output_string out_chan "(get-model)\n"; + + close_out out_chan; + if prop_type = "counterexample" && !opt_auto then ( + let arg_names = Stack.fold (fun m (k, v) -> (k, v) :: m) [] ctx.arg_stack in + let arg_smt_names = + List.map + (function + | I_aux (I_decl (_, Name (id, _)), (_, Unique (n, _))) -> (id, List.assoc_opt n arg_names) + | _ -> assert false + ) + arg_decls + in + check_counterexample ctx.ast ctx.tc_env fname function_id args arg_ctyps arg_smt_names + ) + | _ -> failwith "Bad function body" + end | _ -> () -let rec smt_cdefs props lets name_file ctx ast = - function +let rec smt_cdefs props lets name_file ctx ast = function | CDEF_let (_, vars, setup) :: cdefs -> - let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (global id)) vars in - smt_cdefs props (lets @ vars @ setup) name_file ctx ast cdefs; + let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (global id)) vars in + smt_cdefs props (lets @ vars @ setup) name_file ctx ast cdefs | cdef :: cdefs -> - smt_cdef props lets name_file ctx ast cdef; - smt_cdefs props lets name_file ctx ast cdefs + smt_cdef props lets name_file ctx ast cdef; + smt_cdefs props lets name_file ctx ast cdefs | [] -> () (* In order to support register references, we need to build a map @@ -2392,29 +2322,30 @@ let rec smt_cdefs props lets name_file ctx ast = *) let rec build_register_map rmap = function | CDEF_register (reg, ctyp, _) :: cdefs -> - let rmap = match CTMap.find_opt ctyp rmap with - | Some regs -> - CTMap.add ctyp (reg :: regs) rmap - | None -> - CTMap.add ctyp [reg] rmap - in - build_register_map rmap cdefs + let rmap = + match CTMap.find_opt ctyp rmap with + | Some regs -> CTMap.add ctyp (reg :: regs) rmap + | None -> CTMap.add ctyp [reg] rmap + in + build_register_map rmap cdefs | _ :: cdefs -> build_register_map rmap cdefs | [] -> rmap let compile env effect_info ast = let cdefs, jib_ctx = - let module Jibc = Jib_compile.Make(SMT_config(struct let unroll_limit = !opt_unroll_limit end)) in + let module Jibc = Jib_compile.Make (SMT_config (struct + let unroll_limit = !opt_unroll_limit + end)) in let env, effect_info = Jib_compile.add_special_functions env effect_info in let ctx = Jib_compile.initial_ctx env effect_info in let t = Profile.start () in let cdefs, ctx = Jibc.compile_ast ctx ast in Profile.finish "Compiling to Jib IR" t; - cdefs, ctx + (cdefs, ctx) in let cdefs = Jib_optimize.unique_per_function_ids cdefs in let rmap = build_register_map CTMap.empty cdefs in - cdefs, jib_ctx, { (initial_ctx ()) with tc_env = jib_ctx.tc_env; register_map = rmap; ast = ast } + (cdefs, jib_ctx, { (initial_ctx ()) with tc_env = jib_ctx.tc_env; register_map = rmap; ast }) let serialize_smt_model file env effect_info ast = let cdefs, _, ctx = compile env effect_info ast in @@ -2436,6 +2367,4 @@ let generate_smt props name_file env effect_info ast = try let cdefs, _, ctx = compile env effect_info ast in smt_cdefs props [] name_file ctx cdefs cdefs - with - | Type_check.Type_error (_, l, err) -> - raise (Reporting.err_typ l (Type_error.string_of_type_error err)); + with Type_check.Type_error (_, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) diff --git a/src/sail_smt_backend/jib_smt.mli b/src/sail_smt_backend/jib_smt.mli index ad5f658d0..35b5c4351 100644 --- a/src/sail_smt_backend/jib_smt.mli +++ b/src/sail_smt_backend/jib_smt.mli @@ -93,54 +93,48 @@ val opt_default_lbits_index : int ref val opt_default_vector_index : int ref type ctx = { - lbits_index : int; - (** Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *) - lint_size : int; - (** The size we use for integers where we don't know how large they are statically. *) - vector_index : int; - (** A generic vector, vector('a) becomes Array (BitVec vector_index) 'a. + lbits_index : int; + (** Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *) + lint_size : int; (** The size we use for integers where we don't know how large they are statically. *) + vector_index : int; + (** A generic vector, vector('a) becomes Array (BitVec vector_index) 'a. We need to take care that vector_index is large enough for all generic vectors. *) - register_map : id list CTMap.t; - (** A map from each ctyp to a list of registers of that ctyp *) - tuple_sizes : IntSet.t ref; - (** A set to keep track of all the tuple sizes we need to generate types for *) - tc_env : Type_check.Env.t; - (** tc_env is the global type-checking environment *) - pragma_l : Ast.l; - (** A location, usually the $counterexample or $property we are + register_map : id list CTMap.t; (** A map from each ctyp to a list of registers of that ctyp *) + tuple_sizes : IntSet.t ref; (** A set to keep track of all the tuple sizes we need to generate types for *) + tc_env : Type_check.Env.t; (** tc_env is the global type-checking environment *) + pragma_l : Ast.l; + (** A location, usually the $counterexample or $property we are generating the SMT for. Used for error messages. *) - arg_stack : (int * string) Stack.t; - (** Used internally to keep track of function argument names *) - ast : Type_check.tannot ast; - (** The fully type-checked ast *) - shared : ctyp Bindings.t; - (** Shared variables. These variables do not get renamed by + arg_stack : (int * string) Stack.t; (** Used internally to keep track of function argument names *) + ast : Type_check.tannot ast; (** The fully type-checked ast *) + shared : ctyp Bindings.t; + (** Shared variables. These variables do not get renamed by Smtlib.suffix_variables_def, and their SSA number is omitted. They should therefore only ever be read and never written. Used by sail-axiomatic for symbolic values in the initial litmus state. *) - preserved : IdSet.t; - (** icopy instructions to an id in preserved will generated a + preserved : IdSet.t; + (** icopy instructions to an id in preserved will generated a define-const (by using Smtlib.Preserved_const) that will not be simplified away or renamed. It will also not get a SSA number. Such variables can therefore only ever be written to once, and never read. They are used by sail-axiomatic to extract information from the generated SMT. *) - events : smt_exp Stack.t EventMap.t ref; - (** For every event type we have a stack of boolean SMT + events : smt_exp Stack.t EventMap.t ref; + (** For every event type we have a stack of boolean SMT expressions for each occurance of that event. See src/property.ml for the event types *) - node : int; - pathcond : smt_exp Lazy.t; - (** When generating SMT for an instruction pathcond will contain + node : int; + pathcond : smt_exp Lazy.t; + (** When generating SMT for an instruction pathcond will contain the global path conditional of the containing block/node in the control flow graph *) - use_string : bool ref; - use_real : bool ref - (** Set if we need to use strings or real numbers in the generated + use_string : bool ref; + use_real : bool ref; + (** Set if we need to use strings or real numbers in the generated SMT, which then requires set-logic ALL or similar depending on the solver *) - } +} (** Compile an AST into Jib suitable for SMT generation, and initialise a context. *) val compile : Type_check.Env.t -> Effects.side_effect_info -> Type_check.tannot ast -> cdef list * Jib_compile.ctx * ctx @@ -155,7 +149,8 @@ val smt_header : ctx -> cdef list -> smt_def list val smt_query : ctx -> Property.query -> smt_exp -val smt_instr_list : string -> ctx -> cdef list -> instr list -> smt_def Stack.t * int * (ssa_elem list * cf_node) Jib_ssa.array_graph +val smt_instr_list : + string -> ctx -> cdef list -> instr list -> smt_def Stack.t * int * (ssa_elem list * cf_node) Jib_ssa.array_graph module type Sequence = sig type 'a t @@ -168,22 +163,21 @@ end final SMTLIB file. Depending on the order in which we want to process the results we can either use a FIFO queue or a LIFO stack, or any other structure. *) -module Make_optimizer(S : Sequence) : sig +module Make_optimizer (S : Sequence) : sig val optimize : smt_def Stack.t -> smt_def S.t end -val serialize_smt_model : - string -> Type_check.Env.t -> Effects.side_effect_info -> Type_check.tannot ast -> unit +val serialize_smt_model : string -> Type_check.Env.t -> Effects.side_effect_info -> Type_check.tannot ast -> unit -val deserialize_smt_model : - string -> cdef list * ctx +val deserialize_smt_model : string -> cdef list * ctx (** Generate SMT for all the $property and $counterexample pragmas in an AST, and write it to appropriately named files. *) val generate_smt : - (string * string * l * 'a val_spec) Bindings.t (* See Property.find_properties *) - -> (string -> string) (* Applied to each function name to generate the file name for the smtlib file *) - -> Type_check.Env.t - -> Effects.side_effect_info - -> Type_check.tannot ast - -> unit + (string * string * l * 'a val_spec) Bindings.t (* See Property.find_properties *) -> + (string -> string) -> + (* Applied to each function name to generate the file name for the smtlib file *) + Type_check.Env.t -> + Effects.side_effect_info -> + Type_check.tannot ast -> + unit diff --git a/src/sail_smt_backend/jib_ssa.ml b/src/sail_smt_backend/jib_ssa.ml index 71ca46394..569efb484 100644 --- a/src/sail_smt_backend/jib_ssa.ml +++ b/src/sail_smt_backend/jib_ssa.ml @@ -71,40 +71,35 @@ open Ast_util open Jib open Jib_util -module IntSet = Set.Make(struct type t = int let compare = compare end) -module IntMap = Map.Make(struct type t = int let compare = compare end) +module IntSet = Set.Make (struct + type t = int + let compare = compare +end) +module IntMap = Map.Make (struct + type t = int + let compare = compare +end) (**************************************************************************) (* 1. Mutable graph type *) (**************************************************************************) type 'a array_graph = { - mutable next : int; - mutable nodes : ('a * IntSet.t * IntSet.t) option array; - mutable next_cond : int; - mutable conds : cval IntMap.t - } - -let make ~initial_size () = { - next = 0; - nodes = Array.make initial_size None; - next_cond = 1; - conds = IntMap.empty - } - -let get_cond graph n = - if n >= 0 then - IntMap.find n graph.conds - else - V_call (Bnot, [IntMap.find (abs n) graph.conds]) + mutable next : int; + mutable nodes : ('a * IntSet.t * IntSet.t) option array; + mutable next_cond : int; + mutable conds : cval IntMap.t; +} + +let make ~initial_size () = { next = 0; nodes = Array.make initial_size None; next_cond = 1; conds = IntMap.empty } + +let get_cond graph n = if n >= 0 then IntMap.find n graph.conds else V_call (Bnot, [IntMap.find (abs n) graph.conds]) let get_vertex graph n = graph.nodes.(n) let iter_graph f graph = for n = 0 to graph.next - 1 do - match graph.nodes.(n) with - | Some (x, y, z) -> f x y z - | None -> () + match graph.nodes.(n) with Some (x, y, z) -> f x y z | None -> () done let add_cond cval graph = @@ -116,12 +111,11 @@ let add_cond cval graph = (** Add a vertex to a graph, returning the node index *) let add_vertex data graph = let n = graph.next in - if n >= Array.length graph.nodes then - begin - let new_nodes = Array.make (Array.length graph.nodes * 2) None in - Array.blit graph.nodes 0 new_nodes 0 (Array.length graph.nodes); - graph.nodes <- new_nodes - end; + if n >= Array.length graph.nodes then begin + let new_nodes = Array.make (Array.length graph.nodes * 2) None in + Array.blit graph.nodes 0 new_nodes 0 (Array.length graph.nodes); + graph.nodes <- new_nodes + end; let n = graph.next in graph.nodes.(n) <- Some (data, IntSet.empty, IntSet.empty); graph.next <- n + 1; @@ -130,17 +124,14 @@ let add_vertex data graph = (** Add an edge between two existing vertices. Raises Invalid_argument if either of the vertices do not exist. *) let add_edge n m graph = - begin match graph.nodes.(n) with - | Some (data, parents, children) -> - graph.nodes.(n) <- Some (data, parents, IntSet.add m children) - | None -> - raise (Invalid_argument "Parent node does not exist in graph") + begin + match graph.nodes.(n) with + | Some (data, parents, children) -> graph.nodes.(n) <- Some (data, parents, IntSet.add m children) + | None -> raise (Invalid_argument "Parent node does not exist in graph") end; match graph.nodes.(m) with - | Some (data, parents, children) -> - graph.nodes.(m) <- Some (data, IntSet.add n parents, children) - | None -> - raise (Invalid_argument "Child node does not exist in graph") + | Some (data, parents, children) -> graph.nodes.(m) <- Some (data, IntSet.add n parents, children) + | None -> raise (Invalid_argument "Child node does not exist in graph") let cardinal graph = graph.next @@ -149,18 +140,15 @@ let reachable roots graph = let rec reachable' n = if IntSet.mem n !visited then () - else - begin - visited := IntSet.add n !visited; - match graph.nodes.(n) with - | Some (_, _, successors) -> - IntSet.iter reachable' successors - | None -> () - end + else begin + visited := IntSet.add n !visited; + match graph.nodes.(n) with Some (_, _, successors) -> IntSet.iter reachable' successors | None -> () + end in - IntSet.iter reachable' roots; !visited + IntSet.iter reachable' roots; + !visited -exception Not_a_DAG of int;; +exception Not_a_DAG of int let topsort graph = let marked = ref IntSet.empty in @@ -168,31 +156,26 @@ let topsort graph = let list = ref [] in let rec visit node = - if IntSet.mem node !temp_marked then - raise (Not_a_DAG node) - else if IntSet.mem node !marked then - () - else - begin match get_vertex graph node with + if IntSet.mem node !temp_marked then raise (Not_a_DAG node) + else if IntSet.mem node !marked then () + else begin + match get_vertex graph node with | None -> failwith "Node does not exist in topsort" | Some (_, _, succs) -> - temp_marked := IntSet.add node !temp_marked; - IntSet.iter visit succs; - marked := IntSet.add node !marked; - temp_marked := IntSet.remove node !temp_marked; - list := node :: !list - end + temp_marked := IntSet.add node !temp_marked; + IntSet.iter visit succs; + marked := IntSet.add node !marked; + temp_marked := IntSet.remove node !temp_marked; + list := node :: !list + end in let find_unmarked () = let unmarked = ref (-1) in let i = ref 0 in while !unmarked = -1 && !i < Array.length graph.nodes do - begin match get_vertex graph !i with - | None -> () - | Some _ -> - if not (IntSet.mem !i !marked) then - unmarked := !i + begin + match get_vertex graph !i with None -> () | Some _ -> if not (IntSet.mem !i !marked) then unmarked := !i end; incr i done; @@ -201,21 +184,21 @@ let topsort graph = let rec topsort' () = let unmarked = find_unmarked () in - if unmarked = -1 then - () - else - (visit unmarked; topsort' ()) + if unmarked = -1 then () + else ( + visit unmarked; + topsort' () + ) in - topsort' (); !list + topsort' (); + !list let prune visited graph = for i = 0 to graph.next - 1 do match graph.nodes.(i) with | Some (n, preds, succs) -> - if IntSet.mem i visited then - graph.nodes.(i) <- Some (n, IntSet.inter visited preds, IntSet.inter visited succs) - else - graph.nodes.(i) <- None + if IntSet.mem i visited then graph.nodes.(i) <- Some (n, IntSet.inter visited preds, IntSet.inter visited succs) + else graph.nodes.(i) <- None | None -> () done @@ -232,18 +215,14 @@ type terminator = | T_label of string | T_none -type cf_node = - | CF_label of string - | CF_block of instr list * terminator - | CF_guard of int - | CF_start of ctyp NameMap.t +type cf_node = CF_label of string | CF_block of instr list * terminator | CF_guard of int | CF_start of ctyp NameMap.t let to_terminator graph = function | I_label label -> T_label label | I_goto label -> T_goto label | I_jump (cval, label) -> - let n = add_cond cval graph in - T_jump (n, label) + let n = add_cond cval graph in + T_jump (n, label) | I_end name -> T_end name | I_exit cause -> T_exit cause | I_undefined ctyp -> T_undefined ctyp @@ -251,55 +230,53 @@ let to_terminator graph = function (* For now we only generate CFGs for flat lists of instructions *) let control_flow_graph instrs = - let module StringMap = Map.Make(String) in + let module StringMap = Map.Make (String) in let labels = ref StringMap.empty in let graph = make ~initial_size:512 () in - iter_instr (fun (I_aux (instr, annot)) -> + iter_instr + (fun (I_aux (instr, annot)) -> match instr with - | I_label label -> - labels := StringMap.add label (add_vertex ([], CF_label label) graph) !labels + | I_label label -> labels := StringMap.add label (add_vertex ([], CF_label label) graph) !labels | _ -> () - ) (iblock instrs); + ) + (iblock instrs); let cf_split (I_aux (aux, _)) = - match aux with - | I_label _ | I_goto _ | I_jump _ | I_end _ | I_exit _ | I_undefined _ -> true - | _ -> false + match aux with I_label _ | I_goto _ | I_jump _ | I_end _ | I_exit _ | I_undefined _ -> true | _ -> false in let rec cfg preds instrs = let before, after = instr_split_at cf_split instrs in - let terminator, after = match after with - | I_aux (instr, _) :: after -> to_terminator graph instr, after - | [] -> T_none, [] + let terminator, after = + match after with I_aux (instr, _) :: after -> (to_terminator graph instr, after) | [] -> (T_none, []) in - let preds = match before, terminator with + let preds = + match (before, terminator) with | [], (T_none | T_label _) -> preds | instrs, _ -> - let n = add_vertex ([], CF_block (instrs, terminator)) graph in - List.iter (fun p -> add_edge p n graph) preds; - [n] + let n = add_vertex ([], CF_block (instrs, terminator)) graph in + List.iter (fun p -> add_edge p n graph) preds; + [n] in match terminator with - | T_end _ | T_exit _ | T_undefined _ -> - cfg [] after - + | T_end _ | T_exit _ | T_undefined _ -> cfg [] after | T_goto label -> - List.iter (fun p -> add_edge p (StringMap.find label !labels) graph) preds; - cfg [] after - + List.iter (fun p -> add_edge p (StringMap.find label !labels) graph) preds; + cfg [] after | T_jump (cond, label) -> - let t = add_vertex ([], CF_guard cond) graph in - let f = add_vertex ([], CF_guard (- cond)) graph in - List.iter (fun p -> add_edge p t graph; add_edge p f graph) preds; - add_edge t (StringMap.find label !labels) graph; - cfg [f] after - - | T_label label -> - cfg (StringMap.find label !labels :: preds) after - + let t = add_vertex ([], CF_guard cond) graph in + let f = add_vertex ([], CF_guard (-cond)) graph in + List.iter + (fun p -> + add_edge p t graph; + add_edge p f graph + ) + preds; + add_edge t (StringMap.find label !labels) graph; + cfg [f] after + | T_label label -> cfg (StringMap.find label !labels :: preds) after | T_none -> preds in @@ -309,7 +286,7 @@ let control_flow_graph instrs = let visited = reachable (IntSet.singleton start) graph in prune visited graph; - start, finish, graph + (start, finish, graph) (**************************************************************************) (* 3. Computing dominators *) @@ -321,24 +298,23 @@ let control_flow_graph instrs = which runs in O(n log(n)) time. *) let immediate_dominators graph root = let none = -1 in - let vertex = Array.make (cardinal graph) 0 in - let parent = Array.make (cardinal graph) none in + let vertex = Array.make (cardinal graph) 0 in + let parent = Array.make (cardinal graph) none in let ancestor = Array.make (cardinal graph) none in - let semi = Array.make (cardinal graph) none in - let idom = Array.make (cardinal graph) none in - let samedom = Array.make (cardinal graph) none in - let best = Array.make (cardinal graph) none in - let dfnum = Array.make (cardinal graph) (-1) in - let bucket = Array.make (cardinal graph) IntSet.empty in + let semi = Array.make (cardinal graph) none in + let idom = Array.make (cardinal graph) none in + let samedom = Array.make (cardinal graph) none in + let best = Array.make (cardinal graph) none in + let dfnum = Array.make (cardinal graph) (-1) in + let bucket = Array.make (cardinal graph) IntSet.empty in let rec ancestor_with_lowest_semi v = let a = ancestor.(v) in - if ancestor.(a) <> none then + if ancestor.(a) <> none then ( let b = ancestor_with_lowest_semi a in ancestor.(v) <- ancestor.(a); - if dfnum.(semi.(b)) < dfnum.(semi.(best.(v))) then - best.(v) <- b - else (); + if dfnum.(semi.(b)) < dfnum.(semi.(best.(v))) then best.(v) <- b else () + ) else (); if best.(v) <> none then best.(v) else v in @@ -351,17 +327,15 @@ let immediate_dominators graph root = let count = ref 0 in let rec dfs p n = - if dfnum.(n) = -1 then - begin - dfnum.(n) <- !count; - vertex.(!count) <- n; - parent.(n) <- p; - incr count; - match graph.nodes.(n) with - | Some (_, _, successors) -> - IntSet.iter (fun w -> dfs n w) successors - | None -> assert false - end + if dfnum.(n) = -1 then begin + dfnum.(n) <- !count; + vertex.(!count) <- n; + parent.(n) <- p; + incr count; + match graph.nodes.(n) with + | Some (_, _, successors) -> IntSet.iter (fun w -> dfs n w) successors + | None -> assert false + end in dfs none root; @@ -370,34 +344,30 @@ let immediate_dominators graph root = let p = parent.(n) in let s = ref p in - begin match graph.nodes.(n) with - | Some (_, predecessors, _) -> - IntSet.iter (fun v -> - let s' = - if dfnum.(v) <= dfnum.(n) then - v - else - semi.(ancestor_with_lowest_semi v) - in - if dfnum.(s') < dfnum.(!s) then s := s' - ) predecessors - | None -> assert false + begin + match graph.nodes.(n) with + | Some (_, predecessors, _) -> + IntSet.iter + (fun v -> + let s' = if dfnum.(v) <= dfnum.(n) then v else semi.(ancestor_with_lowest_semi v) in + if dfnum.(s') < dfnum.(!s) then s := s' + ) + predecessors + | None -> assert false end; semi.(n) <- !s; bucket.(!s) <- IntSet.add n bucket.(!s); link p n; - IntSet.iter (fun v -> + IntSet.iter + (fun v -> let y = ancestor_with_lowest_semi v in - if semi.(y) = semi.(v) then - idom.(v) <- p - else - samedom.(v) <- y - ) bucket.(p); + if semi.(y) = semi.(v) then idom.(v) <- p else samedom.(v) <- y + ) + bucket.(p) done; for i = 1 to !count - 1 do let n = vertex.(i) in - if samedom.(n) <> none then - idom.(n) <- idom.(samedom.(n)) + if samedom.(n) <> none then idom.(n) <- idom.(samedom.(n)) done; idom @@ -409,8 +379,7 @@ let dominator_children idom = for n = 0 to Array.length idom - 1 do let p = idom.(n) in - if p <> none then - children.(p) <- IntSet.add n (children.(p)) + if p <> none then children.(p) <- IntSet.add n children.(p) done; children @@ -419,12 +388,7 @@ let dominator_children idom = let rec dominate idom n w = let none = -1 in let p = idom.(n) in - if p = none then - false - else if p = w then - true - else - dominate idom p w + if p = none then false else if p = w then true else dominate idom p w let dominance_frontiers graph root idom children = let df = Array.make (cardinal graph) IntSet.empty in @@ -432,21 +396,17 @@ let dominance_frontiers graph root idom children = let rec compute_df n = let set = ref IntSet.empty in - begin match graph.nodes.(n) with - | Some (content, _, succs) -> - IntSet.iter (fun y -> - if idom.(y) <> n then - set := IntSet.add y !set - ) succs - | None -> () + begin + match graph.nodes.(n) with + | Some (content, _, succs) -> IntSet.iter (fun y -> if idom.(y) <> n then set := IntSet.add y !set) succs + | None -> () end; - IntSet.iter (fun c -> + IntSet.iter + (fun c -> compute_df c; - IntSet.iter (fun w -> - if not (dominate idom n w) then - set := IntSet.add w !set - ) (df.(c)) - ) (children.(n)); + IntSet.iter (fun w -> if not (dominate idom n w) then set := IntSet.add w !set) df.(c) + ) + children.(n); df.(n) <- !set in compute_df root; @@ -456,9 +416,7 @@ let dominance_frontiers graph root idom children = (* 4. Conversion to SSA form *) (**************************************************************************) -type ssa_elem = - | Phi of Jib.name * Jib.ctyp * Jib.name list - | Pi of Jib.cval list +type ssa_elem = Phi of Jib.name * Jib.ctyp * Jib.name list | Pi of Jib.cval list let place_phi_functions graph df = let defsites = ref NameCTMap.empty in @@ -466,8 +424,7 @@ let place_phi_functions graph df = let all_vars = ref NameCTSet.empty in let rec all_decls = function - | I_aux ((I_init (ctyp, id, _) | I_decl (ctyp, id)), _) :: instrs -> - NameCTSet.add (id, ctyp) (all_decls instrs) + | I_aux ((I_init (ctyp, id, _) | I_decl (ctyp, id)), _) :: instrs -> NameCTSet.add (id, ctyp) (all_decls instrs) | _ :: instrs -> all_decls instrs | [] -> NameCTSet.empty in @@ -475,43 +432,53 @@ let place_phi_functions graph df = let orig_A n = match graph.nodes.(n) with | Some ((_, CF_block (instrs, _)), _, _) -> - let vars = List.fold_left NameCTSet.union NameCTSet.empty (List.map instr_typed_writes instrs) in - let vars = NameCTSet.diff vars (all_decls instrs) in - all_vars := NameCTSet.union vars !all_vars; - vars + let vars = List.fold_left NameCTSet.union NameCTSet.empty (List.map instr_typed_writes instrs) in + let vars = NameCTSet.diff vars (all_decls instrs) in + all_vars := NameCTSet.union vars !all_vars; + vars | Some _ -> NameCTSet.empty | None -> NameCTSet.empty in let phi_A = ref NameCTMap.empty in for n = 0 to graph.next - 1 do - NameCTSet.iter (fun a -> + NameCTSet.iter + (fun a -> let ds = match NameCTMap.find_opt a !defsites with Some ds -> ds | None -> IntSet.empty in defsites := NameCTMap.add a (IntSet.add n ds) !defsites - ) (orig_A n) + ) + (orig_A n) done; - NameCTSet.iter (fun a -> + NameCTSet.iter + (fun a -> let workset = ref (NameCTMap.find a !defsites) in while not (IntSet.is_empty !workset) do let n = IntSet.choose !workset in workset := IntSet.remove n !workset; - IntSet.iter (fun y -> + IntSet.iter + (fun y -> let phi_A_a = match NameCTMap.find_opt a !phi_A with Some set -> set | None -> IntSet.empty in - if not (IntSet.mem y phi_A_a) then + if not (IntSet.mem y phi_A_a) then begin begin - begin match graph.nodes.(y) with + match graph.nodes.(y) with | Some ((phis, cfnode), preds, succs) -> - graph.nodes.(y) <- Some ((Phi (fst a, snd a, Util.list_init (IntSet.cardinal preds) (fun _ -> fst a)) :: phis, cfnode), preds, succs) + graph.nodes.(y) <- + Some + ( (Phi (fst a, snd a, Util.list_init (IntSet.cardinal preds) (fun _ -> fst a)) :: phis, cfnode), + preds, + succs + ) | None -> assert false - end; - phi_A := NameCTMap.add a (IntSet.add y phi_A_a) !phi_A; - if not (NameCTSet.mem a (orig_A y)) then - workset := IntSet.add y !workset - end - ) df.(n) + end; + phi_A := NameCTMap.add a (IntSet.add y phi_A_a) !phi_A; + if not (NameCTSet.mem a (orig_A y)) then workset := IntSet.add y !workset + end + ) + df.(n) done - ) !all_vars + ) + !all_vars let rename_variables graph root children = let counts = ref NameMap.empty in @@ -528,45 +495,46 @@ let rename_variables graph root children = | Return _ -> Return i in - let get_count id = - match NameMap.find_opt id !counts with Some n -> n | None -> 0 - in - let top_stack id = - match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> 0 - in + let get_count id = match NameMap.find_opt id !counts with Some n -> n | None -> 0 in + let top_stack id = match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> 0 in let top_stack_phi id ctyp = - match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> (phi_zeros := NameMap.add (ssa_name 0 id) ctyp !phi_zeros; 0) + match NameMap.find_opt id !stacks with + | Some (x :: _) -> x + | Some [] -> 0 + | None -> + phi_zeros := NameMap.add (ssa_name 0 id) ctyp !phi_zeros; + 0 in let push_stack id n = - stacks := NameMap.add id (n :: match NameMap.find_opt id !stacks with Some s -> s | None -> []) !stacks + stacks := NameMap.add id (n :: (match NameMap.find_opt id !stacks with Some s -> s | None -> [])) !stacks in let rec fold_cval = function | V_id (id, ctyp) -> - let i = top_stack id in - V_id (ssa_name i id, ctyp) + let i = top_stack id in + V_id (ssa_name i id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (id, fs) -> V_call (id, List.map fold_cval fs) | V_field (f, field) -> V_field (fold_cval f, field) | V_tuple_member (f, len, n) -> V_tuple_member (fold_cval f, len, n) | V_ctor_kind (f, ctor, ctyp) -> V_ctor_kind (fold_cval f, ctor, ctyp) | V_ctor_unwrap (f, ctor, ctyp) -> V_ctor_unwrap (fold_cval f, ctor, ctyp) - | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, fold_cval cval) fields, ctyp) + | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> (field, fold_cval cval)) fields, ctyp) | V_tuple (members, ctyp) -> V_tuple (List.map fold_cval members, ctyp) in let rec fold_clexp rmw = function | CL_id (id, ctyp) when rmw -> - let i = top_stack id in - let j = get_count id + 1 in - counts := NameMap.add id j !counts; - push_stack id j; - CL_rmw (ssa_name i id, ssa_name j id, ctyp) + let i = top_stack id in + let j = get_count id + 1 in + counts := NameMap.add id j !counts; + push_stack id j; + CL_rmw (ssa_name i id, ssa_name j id, ctyp) | CL_id (id, ctyp) -> - let i = get_count id + 1 in - counts := NameMap.add id i !counts; - push_stack id i; - CL_id (ssa_name i id, ctyp) + let i = get_count id + 1 in + counts := NameMap.add id i !counts; + push_stack id i; + CL_id (ssa_name i id, ctyp) | CL_rmw _ -> assert false | CL_field (clexp, field) -> CL_field (fold_clexp true clexp, field) | CL_addr clexp -> CL_addr (fold_clexp false clexp) @@ -575,129 +543,131 @@ let rename_variables graph root children = in let ssa_instr (I_aux (aux, annot)) = - let aux = match aux with + let aux = + match aux with | I_funcall (clexp, extern, id, args) -> - let args = List.map fold_cval args in - I_funcall (fold_clexp false clexp, extern, id, args) + let args = List.map fold_cval args in + I_funcall (fold_clexp false clexp, extern, id, args) | I_copy (clexp, cval) -> - let cval = fold_cval cval in - I_copy (fold_clexp false clexp, cval) + let cval = fold_cval cval in + I_copy (fold_clexp false clexp, cval) | I_decl (ctyp, id) -> - let i = get_count id + 1 in - counts := NameMap.add id i !counts; - push_stack id i; - I_decl (ctyp, ssa_name i id) + let i = get_count id + 1 in + counts := NameMap.add id i !counts; + push_stack id i; + I_decl (ctyp, ssa_name i id) | I_init (ctyp, id, cval) -> - let cval = fold_cval cval in - let i = get_count id + 1 in - counts := NameMap.add id i !counts; - push_stack id i; - I_init (ctyp, ssa_name i id, cval) + let cval = fold_cval cval in + let i = get_count id + 1 in + counts := NameMap.add id i !counts; + push_stack id i; + I_init (ctyp, ssa_name i id, cval) | instr -> instr in I_aux (aux, annot) in let ssa_terminator = function - | T_jump (cond, label) -> - begin match IntMap.find_opt cond graph.conds with - | Some cval -> - graph.conds <- IntMap.add cond (fold_cval cval) graph.conds; - T_jump (cond, label) - | None -> assert false - end + | T_jump (cond, label) -> begin + match IntMap.find_opt cond graph.conds with + | Some cval -> + graph.conds <- IntMap.add cond (fold_cval cval) graph.conds; + T_jump (cond, label) + | None -> assert false + end | T_end id -> - let i = top_stack id in - T_end (ssa_name i id) + let i = top_stack id in + T_end (ssa_name i id) | terminator -> terminator in let ssa_cfnode = function | CF_start inits -> CF_start inits | CF_block (instrs, terminator) -> - let instrs = List.map ssa_instr instrs in - CF_block (instrs, ssa_terminator terminator) + let instrs = List.map ssa_instr instrs in + CF_block (instrs, ssa_terminator terminator) | CF_label label -> CF_label label | CF_guard cond -> CF_guard cond in let ssa_ssanode = function | Phi (id, ctyp, args) -> - let i = get_count id + 1 in - counts := NameMap.add id i !counts; - push_stack id i; - Phi (ssa_name i id, ctyp, args) + let i = get_count id + 1 in + counts := NameMap.add id i !counts; + push_stack id i; + Phi (ssa_name i id, ctyp, args) | Pi _ -> assert false (* Should not be introduced at this point *) in let fix_phi j = function | Phi (id, ctyp, ids) -> - let fix_arg k a = - if k = j then - let i = top_stack_phi a ctyp in - ssa_name i a - else a - in - Phi (id, ctyp, List.mapi fix_arg ids) + let fix_arg k a = + if k = j then ( + let i = top_stack_phi a ctyp in + ssa_name i a + ) + else a + in + Phi (id, ctyp, List.mapi fix_arg ids) | Pi _ -> assert false (* Should not be introduced at this point *) in let rec rename n = let old_stacks = !stacks in - begin match graph.nodes.(n) with - | Some ((ssa, cfnode), preds, succs) -> - let ssa = List.map ssa_ssanode ssa in - graph.nodes.(n) <- Some ((ssa, ssa_cfnode cfnode), preds, succs); - List.iter (fun succ -> - match graph.nodes.(succ) with - | Some ((ssa, cfnode), preds, succs) -> - (* Suppose n is the j-th predecessor of succ *) - let rec find_j n succ = function - | pred :: preds -> - if pred = succ then n else find_j (n + 1) succ preds - | [] -> assert false - in - let j = find_j 0 n (IntSet.elements preds) in - graph.nodes.(succ) <- Some ((List.map (fix_phi j) ssa, cfnode), preds, succs) - | None -> assert false - ) (IntSet.elements succs) - | None -> assert false + begin + match graph.nodes.(n) with + | Some ((ssa, cfnode), preds, succs) -> + let ssa = List.map ssa_ssanode ssa in + graph.nodes.(n) <- Some ((ssa, ssa_cfnode cfnode), preds, succs); + List.iter + (fun succ -> + match graph.nodes.(succ) with + | Some ((ssa, cfnode), preds, succs) -> + (* Suppose n is the j-th predecessor of succ *) + let rec find_j n succ = function + | pred :: preds -> if pred = succ then n else find_j (n + 1) succ preds + | [] -> assert false + in + let j = find_j 0 n (IntSet.elements preds) in + graph.nodes.(succ) <- Some ((List.map (fix_phi j) ssa, cfnode), preds, succs) + | None -> assert false + ) + (IntSet.elements succs) + | None -> assert false end; - IntSet.iter (fun child -> rename child) (children.(n)); + IntSet.iter (fun child -> rename child) children.(n); stacks := old_stacks in rename root; match graph.nodes.(root) with - | Some ((ssa, CF_start _), preds, succs) -> - graph.nodes.(root) <- Some ((ssa, CF_start !phi_zeros), preds, succs) + | Some ((ssa, CF_start _), preds, succs) -> graph.nodes.(root) <- Some ((ssa, CF_start !phi_zeros), preds, succs) | _ -> failwith "root node is not CF_start" let place_pi_functions graph start idom children = let get_guard = function - | CF_guard cond -> - begin match IntMap.find_opt (abs cond) graph.conds with - | Some guard when cond > 0 -> [guard] - | Some guard -> [V_call (Bnot, [guard])] - | None -> assert false - end + | CF_guard cond -> begin + match IntMap.find_opt (abs cond) graph.conds with + | Some guard when cond > 0 -> [guard] + | Some guard -> [V_call (Bnot, [guard])] + | None -> assert false + end | _ -> [] in - let get_pi_contents ssanodes = - List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes) - in + let get_pi_contents ssanodes = List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes) in let rec go n = - begin match graph.nodes.(n) with - | Some ((ssa, cfnode), preds, succs) -> - let p = idom.(n) in - if p <> -1 then - begin match graph.nodes.(p) with - | Some ((dom_ssa, _), _, _) -> - let args = get_guard cfnode @ get_pi_contents dom_ssa in - graph.nodes.(n) <- Some ((Pi args :: ssa, cfnode), preds, succs) - | None -> assert false - end - | None -> assert false + begin + match graph.nodes.(n) with + | Some ((ssa, cfnode), preds, succs) -> + let p = idom.(n) in + if p <> -1 then begin + match graph.nodes.(p) with + | Some ((dom_ssa, _), _, _) -> + let args = get_guard cfnode @ get_pi_contents dom_ssa in + graph.nodes.(n) <- Some ((Pi args :: ssa, cfnode), preds, succs) + | None -> assert false + end + | None -> assert false end; IntSet.iter go children.(n) in @@ -708,19 +678,23 @@ let remove_nodes remove_cf graph = for n = 0 to graph.next - 1 do match graph.nodes.(n) with | Some ((_, cfnode), preds, succs) when remove_cf cfnode -> - IntSet.iter (fun pred -> - match graph.nodes.(pred) with - | Some (content, preds', succs') -> - graph.nodes.(pred) <- Some (content, preds', IntSet.remove n (IntSet.union succs succs')) - | None -> assert false - ) preds; - IntSet.iter (fun succ -> - match graph.nodes.(succ) with - | Some (content, preds', succs') -> - graph.nodes.(succ) <- Some (content, IntSet.remove n (IntSet.union preds preds'), succs') - | None -> assert false - ) succs; - graph.nodes.(n) <- None + IntSet.iter + (fun pred -> + match graph.nodes.(pred) with + | Some (content, preds', succs') -> + graph.nodes.(pred) <- Some (content, preds', IntSet.remove n (IntSet.union succs succs')) + | None -> assert false + ) + preds; + IntSet.iter + (fun succ -> + match graph.nodes.(succ) with + | Some (content, preds', succs') -> + graph.nodes.(succ) <- Some (content, IntSet.remove n (IntSet.union preds preds'), succs') + | None -> assert false + ) + succs; + graph.nodes.(n) <- None | _ -> () done @@ -732,52 +706,50 @@ let ssa instrs = place_phi_functions cfg df; rename_variables cfg start children; place_pi_functions cfg start idom children; - start, cfg + (start, cfg) (* Debugging utilities for outputing Graphviz files. *) let string_of_ssainstr = function | Phi (id, ctyp, args) -> - string_of_name id ^ " : " ^ string_of_ctyp ctyp ^ " = φ(" ^ Util.string_of_list ", " string_of_name args ^ ")" - | Pi cvals -> - "π(" ^ Util.string_of_list ", " (fun v -> String.escaped (string_of_cval v)) cvals ^ ")" + string_of_name id ^ " : " ^ string_of_ctyp ctyp ^ " = φ(" ^ Util.string_of_list ", " string_of_name args ^ ")" + | Pi cvals -> "π(" ^ Util.string_of_list ", " (fun v -> String.escaped (string_of_cval v)) cvals ^ ")" -let string_of_phis = function - | [] -> "" - | phis -> Util.string_of_list "\\l" string_of_ssainstr phis ^ "\\l" +let string_of_phis = function [] -> "" | phis -> Util.string_of_list "\\l" string_of_ssainstr phis ^ "\\l" let string_of_node = function - | (phis, CF_label label) -> string_of_phis phis ^ label - | (phis, CF_block (instrs, terminator)) -> - let string_of_instr instr = - let buf = Buffer.create 128 in - Jib_ir.Flat_ir_formatter.output_instr 0 buf 0 Jib_ir.StringMap.empty instr; - Buffer.contents buf - in - string_of_phis phis ^ Util.string_of_list "\\l" (fun instr -> String.escaped (string_of_instr instr)) instrs - | (phis, CF_start inits) -> string_of_phis phis ^ "START" - | (phis, CF_guard cval) -> string_of_phis phis ^ string_of_int cval + | phis, CF_label label -> string_of_phis phis ^ label + | phis, CF_block (instrs, terminator) -> + let string_of_instr instr = + let buf = Buffer.create 128 in + Jib_ir.Flat_ir_formatter.output_instr 0 buf 0 Jib_ir.StringMap.empty instr; + Buffer.contents buf + in + string_of_phis phis ^ Util.string_of_list "\\l" (fun instr -> String.escaped (string_of_instr instr)) instrs + | phis, CF_start inits -> string_of_phis phis ^ "START" + | phis, CF_guard cval -> string_of_phis phis ^ string_of_int cval let vertex_color = function - | (_, CF_start _) -> "peachpuff" - | (_, CF_block _) -> "white" - | (_, CF_label _) -> "springgreen" - | (_, CF_guard _) -> "yellow" + | _, CF_start _ -> "peachpuff" + | _, CF_block _ -> "white" + | _, CF_label _ -> "springgreen" + | _, CF_guard _ -> "yellow" let make_dot out_chan graph = Util.opt_colors := false; output_string out_chan "digraph DEPS {\n"; let make_node i n = - output_string out_chan (Printf.sprintf " n%i [label=\"%i\\n%s\\l\";shape=box;style=filled;fillcolor=%s];\n" i i (string_of_node n) (vertex_color n)) - in - let make_line i s = - output_string out_chan (Printf.sprintf " n%i -> n%i [color=black];\n" i s) + output_string out_chan + (Printf.sprintf " n%i [label=\"%i\\n%s\\l\";shape=box;style=filled;fillcolor=%s];\n" i i (string_of_node n) + (vertex_color n) + ) in + let make_line i s = output_string out_chan (Printf.sprintf " n%i -> n%i [color=black];\n" i s) in for i = 0 to graph.next - 1 do match graph.nodes.(i) with | Some (n, _, successors) -> - make_node i n; - IntSet.iter (fun s -> make_line i s) successors + make_node i n; + IntSet.iter (fun s -> make_line i s) successors | None -> () done; output_string out_chan "}\n"; @@ -787,18 +759,20 @@ let make_dominators_dot out_chan idom graph = Util.opt_colors := false; output_string out_chan "digraph DOMS {\n"; let make_node i n = - output_string out_chan (Printf.sprintf " n%i [label=\"%i\\n%s\\l\";shape=box;style=filled;fillcolor=%s];\n" i i (string_of_node n) (vertex_color n)) - in - let make_line i s = - output_string out_chan (Printf.sprintf " n%i -> n%i [color=black];\n" i s) + output_string out_chan + (Printf.sprintf " n%i [label=\"%i\\n%s\\l\";shape=box;style=filled;fillcolor=%s];\n" i i (string_of_node n) + (vertex_color n) + ) in + let make_line i s = output_string out_chan (Printf.sprintf " n%i -> n%i [color=black];\n" i s) in for i = 0 to Array.length idom - 1 do match graph.nodes.(i) with | Some (n, _, _) -> - if idom.(i) = -1 then - make_node i n - else - (make_node i n; make_line i idom.(i)) + if idom.(i) = -1 then make_node i n + else ( + make_node i n; + make_line i idom.(i) + ) | None -> () done; output_string out_chan "}\n"; diff --git a/src/sail_smt_backend/jib_ssa.mli b/src/sail_smt_backend/jib_ssa.mli index a7bc735b4..89097d298 100644 --- a/src/sail_smt_backend/jib_ssa.mli +++ b/src/sail_smt_backend/jib_ssa.mli @@ -119,9 +119,7 @@ val control_flow_graph : Jib.instr list -> int * int list * ('a list * cf_node) dominators for a control flow graph with a specified root node. *) val immediate_dominators : 'a array_graph -> int -> int array -type ssa_elem = - | Phi of Jib.name * Jib.ctyp * Jib.name list - | Pi of Jib.cval list +type ssa_elem = Phi of Jib.name * Jib.ctyp * Jib.name list | Pi of Jib.cval list (** Convert a list of instructions into SSA form *) val ssa : Jib.instr list -> int * (ssa_elem list * cf_node) array_graph diff --git a/src/sail_smt_backend/sail_plugin_smt.ml b/src/sail_smt_backend/sail_plugin_smt.ml index ef5a5b48f..656487830 100644 --- a/src/sail_smt_backend/sail_plugin_smt.ml +++ b/src/sail_smt_backend/sail_plugin_smt.ml @@ -67,26 +67,24 @@ open Libsail -let smt_options = [ - ( "-smt_auto", - Arg.Tuple [Arg.Set Jib_smt.opt_auto], - " generate SMT and automatically call the CVC4 solver"); - ( "-smt_ignore_overflow", - Arg.Set Jib_smt.opt_ignore_overflow, - " ignore integer overflow in generated SMT"); - ( "-smt_propagate_vars", - Arg.Set Jib_smt.opt_propagate_vars, - " propgate variables through generated SMT"); - ( "-smt_int_size", - Arg.String (fun n -> Jib_smt.opt_default_lint_size := int_of_string n), - " set a bound of n on the maximum integer bitwidth for generated SMT (default 128)"); - ( "-smt_bits_size", - Arg.String (fun n -> Jib_smt.opt_default_lbits_index := int_of_string n), - " set a bound of 2 ^ n for bitvector bitwidth in generated SMT (default 8)"); - ( "-smt_vector_size", - Arg.String (fun n -> Jib_smt.opt_default_vector_index := int_of_string n), - " set a bound of 2 ^ n for generic vectors in generated SMT (default 5)"); -] +let smt_options = + [ + ("-smt_auto", Arg.Tuple [Arg.Set Jib_smt.opt_auto], " generate SMT and automatically call the CVC4 solver"); + ("-smt_ignore_overflow", Arg.Set Jib_smt.opt_ignore_overflow, " ignore integer overflow in generated SMT"); + ("-smt_propagate_vars", Arg.Set Jib_smt.opt_propagate_vars, " propgate variables through generated SMT"); + ( "-smt_int_size", + Arg.String (fun n -> Jib_smt.opt_default_lint_size := int_of_string n), + " set a bound of n on the maximum integer bitwidth for generated SMT (default 128)" + ); + ( "-smt_bits_size", + Arg.String (fun n -> Jib_smt.opt_default_lbits_index := int_of_string n), + " set a bound of 2 ^ n for bitvector bitwidth in generated SMT (default 8)" + ); + ( "-smt_vector_size", + Arg.String (fun n -> Jib_smt.opt_default_vector_index := int_of_string n), + " set a bound of 2 ^ n for generic vectors in generated SMT (default 5)" + ); + ] let smt_rewrites = let open Rewrites in @@ -117,7 +115,7 @@ let smt_rewrites = ("merge_function_clauses", []); ("optimize_recheck_defs", []); ("constant_fold", [String_arg "c"]); - ("properties", []) + ("properties", []); ] let smt_target _ out_file ast effect_info env = @@ -127,18 +125,13 @@ let smt_target _ out_file ast effect_info env = let ast = Callgraph.filter_ast_ids prop_ids IdSet.empty ast in Specialize.add_initial_calls prop_ids; let ast_smt, env, effect_info = Specialize.(specialize typ_ord_specialization env ast effect_info) in - let ast_smt, env, effect_info = Specialize.(specialize_passes 2 int_specialization_with_externs env ast_smt effect_info) in + let ast_smt, env, effect_info = + Specialize.(specialize_passes 2 int_specialization_with_externs env ast_smt effect_info) + in let name_file = - match out_file with - | Some f -> fun str -> f ^ "_" ^ str ^ ".smt2" - | None -> fun str -> str ^ ".smt2" + match out_file with Some f -> fun str -> f ^ "_" ^ str ^ ".smt2" | None -> fun str -> str ^ ".smt2" in Reporting.opt_warnings := true; Jib_smt.generate_smt props name_file env effect_info ast_smt -let _ = - Target.register - ~name:"smt" - ~options:smt_options - ~rewrites:smt_rewrites - smt_target +let _ = Target.register ~name:"smt" ~options:smt_options ~rewrites:smt_rewrites smt_target diff --git a/src/sail_smt_backend/smtlib.ml b/src/sail_smt_backend/smtlib.ml index ea6bb381e..31a616cce 100644 --- a/src/sail_smt_backend/smtlib.ml +++ b/src/sail_smt_backend/smtlib.ml @@ -80,7 +80,7 @@ type smt_typ = | Array of smt_typ * smt_typ let rec smt_typ_compare t1 t2 = - match t1, t2 with + match (t1, t2) with | Bitvec n, Bitvec m -> compare n m | Bool, Bool -> 0 | String, String -> 0 @@ -88,8 +88,8 @@ let rec smt_typ_compare t1 t2 = | Datatype (name1, _), Datatype (name2, _) -> String.compare name1 name2 | Tuple ts1, Tuple ts2 -> Util.lex_ord_list smt_typ_compare ts1 ts2 | Array (t11, t12), Array (t21, t22) -> - let c = smt_typ_compare t11 t21 in - if c = 0 then smt_typ_compare t12 t22 else c + let c = smt_typ_compare t11 t21 in + if c = 0 then smt_typ_compare t12 t22 else c | Bitvec _, _ -> 1 | _, Bitvec _ -> -1 | Bool, _ -> 1 @@ -104,31 +104,24 @@ let rec smt_typ_compare t1 t2 = | _, Tuple _ -> -1 let rec smt_typ_equal t1 t2 = - match t1, t2 with + match (t1, t2) with | Bitvec n, Bitvec m -> n = m | Bool, Bool -> true | Datatype (name1, ctors1), Datatype (name2, ctors2) -> - let field_equal (field_name1, typ1) (field_name2, typ2) = - field_name1 = field_name2 && smt_typ_equal typ1 typ2 - in - let ctor_equal (ctor_name1, fields1) (ctor_name2, fields2) = - ctor_name1 = ctor_name2 - && List.length fields1 = List.length fields2 - && List.for_all2 field_equal fields1 fields2 - in - name1 = name2 - && List.length ctors1 = List.length ctors2 - && List.for_all2 ctor_equal ctors1 ctors2 + let field_equal (field_name1, typ1) (field_name2, typ2) = field_name1 = field_name2 && smt_typ_equal typ1 typ2 in + let ctor_equal (ctor_name1, fields1) (ctor_name2, fields2) = + ctor_name1 = ctor_name2 + && List.length fields1 = List.length fields2 + && List.for_all2 field_equal fields1 fields2 + in + name1 = name2 && List.length ctors1 = List.length ctors2 && List.for_all2 ctor_equal ctors1 ctors2 | _, _ -> false -let mk_enum name elems = - Datatype (name, List.map (fun elem -> (elem, [])) elems) +let mk_enum name elems = Datatype (name, List.map (fun elem -> (elem, [])) elems) -let mk_record name fields = - Datatype (name, [(name, fields)]) +let mk_record name fields = Datatype (name, [(name, fields)]) -let mk_variant name ctors = - Datatype (name, List.map (fun (ctor, ty) -> (ctor, [("un" ^ ctor, ty)])) ctors) +let mk_variant name ctors = Datatype (name, List.map (fun (ctor, ty) -> (ctor, [("un" ^ ctor, ty)])) ctors) type smt_exp = | Bool_lit of bool @@ -161,68 +154,49 @@ let rec fold_smt_exp f = function | Forall (binders, exp) -> f (Forall (binders, fold_smt_exp f exp)) | Syntactic (exp, exps) -> f (Syntactic (fold_smt_exp f exp, List.map (fold_smt_exp f) exps)) | Field (name, exp) -> f (Field (name, fold_smt_exp f exp)) - | Struct (name, fields) -> f (Struct (name, List.map (fun (field, exp) -> field, fold_smt_exp f exp) fields)) - | (Bool_lit _ | Bitvec_lit _ | Real_lit _ | String_lit _ | Var _ | Shared _ | Read_res _ | Enum _ as exp) -> f exp + | Struct (name, fields) -> f (Struct (name, List.map (fun (field, exp) -> (field, fold_smt_exp f exp)) fields)) + | (Bool_lit _ | Bitvec_lit _ | Real_lit _ | String_lit _ | Var _ | Shared _ | Read_res _ | Enum _) as exp -> f exp -let smt_conj = function - | [] -> Bool_lit true - | [x] -> x - | xs -> Fn ("and", xs) +let smt_conj = function [] -> Bool_lit true | [x] -> x | xs -> Fn ("and", xs) -let smt_disj = function - | [] -> Bool_lit false - | [x] -> x - | xs -> Fn ("or", xs) +let smt_disj = function [] -> Bool_lit false | [x] -> x | xs -> Fn ("or", xs) let extract i j x = Extract (i, j, x) -let bvnot x = Fn ("bvnot", [x]) -let bvand x y = Fn ("bvand", [x; y]) -let bvor x y = Fn ("bvor", [x; y]) -let bvneg x = Fn ("bvneg", [x]) -let bvadd x y = Fn ("bvadd", [x; y]) -let bvmul x y = Fn ("bvmul", [x; y]) +let bvnot x = Fn ("bvnot", [x]) +let bvand x y = Fn ("bvand", [x; y]) +let bvor x y = Fn ("bvor", [x; y]) +let bvneg x = Fn ("bvneg", [x]) +let bvadd x y = Fn ("bvadd", [x; y]) +let bvmul x y = Fn ("bvmul", [x; y]) let bvudiv x y = Fn ("bvudiv", [x; y]) let bvurem x y = Fn ("bvurem", [x; y]) -let bvshl x y = Fn ("bvshl", [x; y]) +let bvshl x y = Fn ("bvshl", [x; y]) let bvlshr x y = Fn ("bvlshr", [x; y]) -let bvult x y = Fn ("bvult", [x; y]) +let bvult x y = Fn ("bvult", [x; y]) let bvzero n = Bitvec_lit (Sail2_operators_bitlists.zeros (Big_int.of_int n)) let bvones n = Bitvec_lit (Sail2_operators_bitlists.ones (Big_int.of_int n)) let simp_equal x y = - match x, y with - | Bitvec_lit bv1, Bitvec_lit bv2 -> Some (Sail2_operators_bitlists.eq_vec bv1 bv2) - | _, _ -> None + match (x, y) with Bitvec_lit bv1, Bitvec_lit bv2 -> Some (Sail2_operators_bitlists.eq_vec bv1 bv2) | _, _ -> None let simp_and xs = let xs = List.filter (function Bool_lit true -> false | _ -> true) xs in match xs with | [] -> Bool_lit true | [x] -> x - | _ -> - if List.exists (function Bool_lit false -> true | _ -> false) xs then - Bool_lit false - else - Fn ("and", xs) + | _ -> if List.exists (function Bool_lit false -> true | _ -> false) xs then Bool_lit false else Fn ("and", xs) let simp_or xs = let xs = List.filter (function Bool_lit false -> false | _ -> true) xs in match xs with | [] -> Bool_lit false | [x] -> x - | _ -> - if List.exists (function Bool_lit true -> true | _ -> false) xs then - Bool_lit true - else - Fn ("or", xs) + | _ -> if List.exists (function Bool_lit true -> true | _ -> false) xs then Bool_lit true else Fn ("or", xs) -let rec all_bitvec_lit = function - | Bitvec_lit _ :: rest -> all_bitvec_lit rest - | [] -> true - | _ :: _ -> false +let rec all_bitvec_lit = function Bitvec_lit _ :: rest -> all_bitvec_lit rest | [] -> true | _ :: _ -> false let rec merge_bitvec_lit = function | Bitvec_lit b :: rest -> b @ merge_bitvec_lit rest @@ -241,11 +215,7 @@ let simp_fn = function | Fn ("bvsub", [Bitvec_lit bv1; Bitvec_lit bv2]) -> Bitvec_lit (Sail2_operators_bitlists.sub_vec bv1 bv2) | Fn ("bvadd", [Bitvec_lit bv1; Bitvec_lit bv2]) -> Bitvec_lit (Sail2_operators_bitlists.add_vec bv1 bv2) | Fn ("concat", xs) when all_bitvec_lit xs -> Bitvec_lit (merge_bitvec_lit xs) - | Fn ("=", [x; y]) as exp -> - begin match simp_equal x y with - | Some b -> Bool_lit b - | None -> exp - end + | Fn ("=", [x; y]) as exp -> begin match simp_equal x y with Some b -> Bool_lit b | None -> exp end | exp -> exp let simp_ite = function @@ -257,109 +227,91 @@ let simp_ite = function | exp -> exp let rec simp_smt_exp vars kinds = function - | Var v -> - begin match Hashtbl.find_opt vars v with - | Some exp -> simp_smt_exp vars kinds exp - | None -> Var v - end - | (Read_res _ | Shared _ | Enum _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _ as exp) -> exp + | Var v -> begin match Hashtbl.find_opt vars v with Some exp -> simp_smt_exp vars kinds exp | None -> Var v end + | (Read_res _ | Shared _ | Enum _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _) as exp -> exp | Field (field, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin match exp with - | Struct (_, fields) -> - List.assoc field fields - | _ -> Field (field, exp) - end - | Struct (name, fields) -> - Struct (name, List.map (fun (field, exp) -> field, simp_smt_exp vars kinds exp) fields) + let exp = simp_smt_exp vars kinds exp in + begin + match exp with Struct (_, fields) -> List.assoc field fields | _ -> Field (field, exp) + end + | Struct (name, fields) -> Struct (name, List.map (fun (field, exp) -> (field, simp_smt_exp vars kinds exp)) fields) | Fn (f, exps) -> - let exps = List.map (simp_smt_exp vars kinds) exps in - simp_fn (Fn (f, exps)) + let exps = List.map (simp_smt_exp vars kinds) exps in + simp_fn (Fn (f, exps)) | Ctor (f, exps) -> - let exps = List.map (simp_smt_exp vars kinds) exps in - simp_fn (Ctor (f, exps)) + let exps = List.map (simp_smt_exp vars kinds) exps in + simp_fn (Ctor (f, exps)) | Ite (cond, t, e) -> - let cond = simp_smt_exp vars kinds cond in - let t = simp_smt_exp vars kinds t in - let e = simp_smt_exp vars kinds e in - simp_ite (Ite (cond, t, e)) + let cond = simp_smt_exp vars kinds cond in + let t = simp_smt_exp vars kinds t in + let e = simp_smt_exp vars kinds e in + simp_ite (Ite (cond, t, e)) | Extract (i, j, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin match exp with - | Bitvec_lit bv -> - Bitvec_lit (Sail2_operators_bitlists.subrange_vec_dec bv (Big_int.of_int i) (Big_int.of_int j)) - | _ -> Extract (i, j, exp) - end + let exp = simp_smt_exp vars kinds exp in + begin + match exp with + | Bitvec_lit bv -> + Bitvec_lit (Sail2_operators_bitlists.subrange_vec_dec bv (Big_int.of_int i) (Big_int.of_int j)) + | _ -> Extract (i, j, exp) + end | Tester (str, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin match exp with - | Var v -> - begin match Hashtbl.find_opt kinds v with - | Some str' when str = str' -> Bool_lit true - | Some str' -> Bool_lit false - | None -> Tester (str, exp) - end - | _ -> Tester (str, exp) - end + let exp = simp_smt_exp vars kinds exp in + begin + match exp with + | Var v -> begin + match Hashtbl.find_opt kinds v with + | Some str' when str = str' -> Bool_lit true + | Some str' -> Bool_lit false + | None -> Tester (str, exp) + end + | _ -> Tester (str, exp) + end | Syntactic (exp, _) -> exp | SignExtend (i, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin match exp with - | Bitvec_lit bv -> - Bitvec_lit (Sail2_operators_bitlists.sign_extend bv (Big_int.of_int (i + List.length bv))) - | _ -> SignExtend (i, exp) - end + let exp = simp_smt_exp vars kinds exp in + begin + match exp with + | Bitvec_lit bv -> Bitvec_lit (Sail2_operators_bitlists.sign_extend bv (Big_int.of_int (i + List.length bv))) + | _ -> SignExtend (i, exp) + end | Forall (binders, exp) -> Forall (binders, exp) type read_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - ret_type : smt_typ; - doc : string - } + name : string; + node : int; + active : smt_exp; + kind : smt_exp; + addr_type : smt_typ; + addr : smt_exp; + ret_type : smt_typ; + doc : string; +} type write_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - data_type : smt_typ; - data : smt_exp; - doc : string - } + name : string; + node : int; + active : smt_exp; + kind : smt_exp; + addr_type : smt_typ; + addr : smt_exp; + data_type : smt_typ; + data : smt_exp; + doc : string; +} -type barrier_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - doc : string - } +type barrier_info = { name : string; node : int; active : smt_exp; kind : smt_exp; doc : string } -type branch_info = { - name : string; - node : int; - active : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - doc : string - } +type branch_info = { name : string; node : int; active : smt_exp; addr_type : smt_typ; addr : smt_exp; doc : string } type cache_op_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - doc : string - } + name : string; + node : int; + active : smt_exp; + kind : smt_exp; + addr_type : smt_typ; + addr : smt_exp; + doc : string; +} type smt_def = | Define_fun of string * (string * smt_typ) list * smt_typ * smt_exp @@ -387,10 +339,10 @@ let smt_def_map_exp f = function | Preserve_const (name, ty, exp) -> Preserve_const (name, ty, f exp) | Write_mem w -> Write_mem { w with active = f w.active; kind = f w.kind; addr = f w.addr; data = f w.data } | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - Write_mem_ea (name, node, f active, f wk, f addr, addr_ty, f data_size, data_size_ty) + Write_mem_ea (name, node, f active, f wk, f addr, addr_ty, f data_size, data_size_ty) | Read_mem r -> Read_mem { r with active = f r.active; kind = f r.kind; addr = f r.addr } | Barrier b -> Barrier { b with active = f b.active; kind = f b.kind } - | Cache_maintenance m -> Cache_maintenance { m with active = f m.active; kind = f m.kind ; addr = f m.addr } + | Cache_maintenance m -> Cache_maintenance { m with active = f m.active; kind = f m.kind; addr = f m.addr } | Branch_announce c -> Branch_announce { c with active = f c.active; addr = f c.addr } | Excl_res (name, node, active) -> Excl_res (name, node, f active) | Declare_datatypes (name, ctors) -> Declare_datatypes (name, ctors) @@ -401,20 +353,35 @@ let smt_def_iter_exp f = function | Define_fun (name, args, ty, exp) -> f exp | Define_const (name, ty, exp) -> f exp | Preserve_const (name, ty, exp) -> f exp - | Write_mem w -> f w.active; f w.kind; f w.addr; f w.data + | Write_mem w -> + f w.active; + f w.kind; + f w.addr; + f w.data | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - f active; f wk; f addr; f data_size - | Read_mem r -> f r.active; f r.kind; f r.addr - | Barrier b -> f b.active; f b.kind - | Cache_maintenance m -> f m.active; f m.kind; f m.addr - | Branch_announce c -> f c.active; f c.addr + f active; + f wk; + f addr; + f data_size + | Read_mem r -> + f r.active; + f r.kind; + f r.addr + | Barrier b -> + f b.active; + f b.kind + | Cache_maintenance m -> + f m.active; + f m.kind; + f m.addr + | Branch_announce c -> + f c.active; + f c.addr | Excl_res (name, node, active) -> f active | Assert exp -> f exp | Declare_fun _ | Declare_const _ | Declare_tuple _ | Declare_datatypes _ -> () -let declare_datatypes = function - | Datatype (name, ctors) -> Declare_datatypes (name, ctors) - | _ -> assert false +let declare_datatypes = function Datatype (name, ctors) -> Declare_datatypes (name, ctors) | _ -> assert false (** For generating SMT with multiple threads (i.e. for litmus tests), we suffix all the variables in the generated SMT with a thread @@ -429,7 +396,14 @@ let suffix_variables_read_info sfx (r : read_info) = let suffix_variables_write_info sfx (w : write_info) = let suffix exp = suffix_variables_exp sfx exp in - { w with name = w.name ^ sfx; active = suffix w.active; kind = suffix w.kind; addr = suffix w.addr; data = suffix w.data } + { + w with + name = w.name ^ sfx; + active = suffix w.active; + kind = suffix w.kind; + addr = suffix w.addr; + data = suffix w.data; + } let suffix_variables_barrier_info sfx (b : barrier_info) = let suffix exp = suffix_variables_exp sfx exp in @@ -445,31 +419,31 @@ let suffix_variables_cache_op_info sfx (m : cache_op_info) = let suffix_variables_def sfx = function | Define_fun (name, args, ty, exp) -> - Define_fun (name ^ sfx, List.map (fun (arg, ty) -> sfx ^ arg, ty) args, ty, suffix_variables_exp sfx exp) - | Declare_fun (name, tys, ty) -> - Declare_fun (name ^ sfx, tys, ty) - | Declare_const (name, ty) -> - Declare_const (name ^ sfx, ty) - | Define_const (name, ty, exp) -> - Define_const (name ^ sfx, ty, suffix_variables_exp sfx exp) - | Preserve_const (name, ty, exp) -> - Preserve_const (name, ty, suffix_variables_exp sfx exp) + Define_fun (name ^ sfx, List.map (fun (arg, ty) -> (sfx ^ arg, ty)) args, ty, suffix_variables_exp sfx exp) + | Declare_fun (name, tys, ty) -> Declare_fun (name ^ sfx, tys, ty) + | Declare_const (name, ty) -> Declare_const (name ^ sfx, ty) + | Define_const (name, ty, exp) -> Define_const (name ^ sfx, ty, suffix_variables_exp sfx exp) + | Preserve_const (name, ty, exp) -> Preserve_const (name, ty, suffix_variables_exp sfx exp) | Write_mem w -> Write_mem (suffix_variables_write_info sfx w) - | Write_mem_ea (name, node, active , wk, addr, addr_ty, data_size, data_size_ty) -> - Write_mem_ea (name ^ sfx, node, suffix_variables_exp sfx active, suffix_variables_exp sfx wk, - suffix_variables_exp sfx addr, addr_ty, suffix_variables_exp sfx data_size, data_size_ty) + | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> + Write_mem_ea + ( name ^ sfx, + node, + suffix_variables_exp sfx active, + suffix_variables_exp sfx wk, + suffix_variables_exp sfx addr, + addr_ty, + suffix_variables_exp sfx data_size, + data_size_ty + ) | Read_mem r -> Read_mem (suffix_variables_read_info sfx r) | Barrier b -> Barrier (suffix_variables_barrier_info sfx b) | Cache_maintenance m -> Cache_maintenance (suffix_variables_cache_op_info sfx m) | Branch_announce c -> Branch_announce (suffix_variables_branch_info sfx c) - | Excl_res (name, node, active) -> - Excl_res (name ^ sfx, node, suffix_variables_exp sfx active) - | Declare_datatypes (name, ctors) -> - Declare_datatypes (name, ctors) - | Declare_tuple n -> - Declare_tuple n - | Assert exp -> - Assert (suffix_variables_exp sfx exp) + | Excl_res (name, node, active) -> Excl_res (name ^ sfx, node, suffix_variables_exp sfx active) + | Declare_datatypes (name, ctors) -> Declare_datatypes (name, ctors) + | Declare_tuple n -> Declare_tuple n + | Assert exp -> Assert (suffix_variables_exp sfx exp) let pp_sfun str docs = let open PPrint in @@ -486,7 +460,9 @@ let rec pp_smt_typ = | Tuple tys -> pp_sfun ("Tup" ^ string_of_int (List.length tys)) (List.map pp_smt_typ tys) | Array (ty1, ty2) -> pp_sfun "Array" [pp_smt_typ ty1; pp_smt_typ ty2] -let pp_str_smt_typ (str, ty) = let open PPrint in parens (string str ^^ space ^^ pp_smt_typ ty) +let pp_str_smt_typ (str, ty) = + let open PPrint in + parens (string str ^^ space ^^ pp_smt_typ ty) let rec pp_smt_exp = let open PPrint in @@ -504,101 +480,101 @@ let rec pp_smt_exp = | Struct (str, fields) -> parens (string str ^^ space ^^ separate_map space (fun (_, exp) -> pp_smt_exp exp) fields) | Ctor (str, exps) -> parens (string str ^^ space ^^ separate_map space pp_smt_exp exps) | Ite (cond, then_exp, else_exp) -> - parens (separate space [string "ite"; pp_smt_exp cond; pp_smt_exp then_exp; pp_smt_exp else_exp]) - | Extract (i, j, exp) -> - parens (string (Printf.sprintf "(_ extract %d %d)" i j) ^^ space ^^ pp_smt_exp exp) - | Tester (kind, exp) -> - parens (string (Printf.sprintf "(_ is %s)" kind) ^^ space ^^ pp_smt_exp exp) - | SignExtend (i, exp) -> - parens (string (Printf.sprintf "(_ sign_extend %d)" i) ^^ space ^^ pp_smt_exp exp) + parens (separate space [string "ite"; pp_smt_exp cond; pp_smt_exp then_exp; pp_smt_exp else_exp]) + | Extract (i, j, exp) -> parens (string (Printf.sprintf "(_ extract %d %d)" i j) ^^ space ^^ pp_smt_exp exp) + | Tester (kind, exp) -> parens (string (Printf.sprintf "(_ is %s)" kind) ^^ space ^^ pp_smt_exp exp) + | SignExtend (i, exp) -> parens (string (Printf.sprintf "(_ sign_extend %d)" i) ^^ space ^^ pp_smt_exp exp) | Syntactic (exp, _) -> pp_smt_exp exp | Forall (binders, exp) -> - parens (string "forall" ^^ space ^^ parens (separate_map space pp_str_smt_typ binders) ^^ space ^^ pp_smt_exp exp) + parens (string "forall" ^^ space ^^ parens (separate_map space pp_str_smt_typ binders) ^^ space ^^ pp_smt_exp exp) let pp_smt_def = let open PPrint in let open Printf in function | Define_fun (name, args, ty, exp) -> - parens (string "define-fun" ^^ space ^^ string name - ^^ space ^^ parens (separate_map space pp_str_smt_typ args) - ^^ space ^^ pp_smt_typ ty - ^//^ pp_smt_exp exp) - + parens + (string "define-fun" ^^ space ^^ string name ^^ space + ^^ parens (separate_map space pp_str_smt_typ args) + ^^ space ^^ pp_smt_typ ty ^//^ pp_smt_exp exp + ) | Declare_fun (name, args, ty) -> - parens (string "declare-fun" ^^ space ^^ string name - ^^ space ^^ parens (separate_map space pp_smt_typ args) - ^^ space ^^ pp_smt_typ ty) - - | Declare_const (name, ty) -> - pp_sfun "declare-const" [string name; pp_smt_typ ty] - + parens + (string "declare-fun" ^^ space ^^ string name ^^ space + ^^ parens (separate_map space pp_smt_typ args) + ^^ space ^^ pp_smt_typ ty + ) + | Declare_const (name, ty) -> pp_sfun "declare-const" [string name; pp_smt_typ ty] | Define_const (name, ty, exp) | Preserve_const (name, ty, exp) -> - pp_sfun "define-const" [string name; pp_smt_typ ty; pp_smt_exp exp] - + pp_sfun "define-const" [string name; pp_smt_typ ty; pp_smt_exp exp] | Write_mem w -> - pp_sfun "define-const" [string (w.name ^ "_kind"); string "Zwrite_kind"; pp_smt_exp w.kind] ^^ hardline - ^^ pp_sfun "define-const" [string (w.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp w.active] ^^ hardline - ^^ pp_sfun "define-const" [string (w.name ^ "_data"); pp_smt_typ w.data_type; pp_smt_exp w.data] ^^ hardline - ^^ pp_sfun "define-const" [string (w.name ^ "_addr"); pp_smt_typ w.addr_type; pp_smt_exp w.addr] ^^ hardline - ^^ pp_sfun "declare-const" [string (w.name ^ "_ret"); pp_smt_typ Bool] - + pp_sfun "define-const" [string (w.name ^ "_kind"); string "Zwrite_kind"; pp_smt_exp w.kind] + ^^ hardline + ^^ pp_sfun "define-const" [string (w.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp w.active] + ^^ hardline + ^^ pp_sfun "define-const" [string (w.name ^ "_data"); pp_smt_typ w.data_type; pp_smt_exp w.data] + ^^ hardline + ^^ pp_sfun "define-const" [string (w.name ^ "_addr"); pp_smt_typ w.addr_type; pp_smt_exp w.addr] + ^^ hardline + ^^ pp_sfun "declare-const" [string (w.name ^ "_ret"); pp_smt_typ Bool] | Write_mem_ea (name, _, active, wk, addr, addr_ty, data_size, data_size_ty) -> - pp_sfun "define-const" [string (name ^ "_kind"); string "Zwrite_kind"; pp_smt_exp wk] ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_active"); pp_smt_typ Bool; pp_smt_exp active] ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_size"); pp_smt_typ data_size_ty; pp_smt_exp data_size] ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_addr"); pp_smt_typ addr_ty; pp_smt_exp addr] - + pp_sfun "define-const" [string (name ^ "_kind"); string "Zwrite_kind"; pp_smt_exp wk] + ^^ hardline + ^^ pp_sfun "define-const" [string (name ^ "_active"); pp_smt_typ Bool; pp_smt_exp active] + ^^ hardline + ^^ pp_sfun "define-const" [string (name ^ "_size"); pp_smt_typ data_size_ty; pp_smt_exp data_size] + ^^ hardline + ^^ pp_sfun "define-const" [string (name ^ "_addr"); pp_smt_typ addr_ty; pp_smt_exp addr] | Read_mem r -> - pp_sfun "define-const" [string (r.name ^ "_kind"); string "Zread_kind"; pp_smt_exp r.kind] ^^ hardline - ^^ pp_sfun "define-const" [string (r.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp r.active] ^^ hardline - ^^ pp_sfun "define-const" [string (r.name ^ "_addr"); pp_smt_typ r.addr_type; pp_smt_exp r.addr] ^^ hardline - ^^ pp_sfun "declare-const" [string (r.name ^ "_ret"); pp_smt_typ r.ret_type] - + pp_sfun "define-const" [string (r.name ^ "_kind"); string "Zread_kind"; pp_smt_exp r.kind] + ^^ hardline + ^^ pp_sfun "define-const" [string (r.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp r.active] + ^^ hardline + ^^ pp_sfun "define-const" [string (r.name ^ "_addr"); pp_smt_typ r.addr_type; pp_smt_exp r.addr] + ^^ hardline + ^^ pp_sfun "declare-const" [string (r.name ^ "_ret"); pp_smt_typ r.ret_type] | Barrier b -> - pp_sfun "define-const" [string (b.name ^ "_kind"); string "Zbarrier_kind"; pp_smt_exp b.kind] ^^ hardline - ^^ pp_sfun "define-const" [string (b.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp b.active] - + pp_sfun "define-const" [string (b.name ^ "_kind"); string "Zbarrier_kind"; pp_smt_exp b.kind] + ^^ hardline + ^^ pp_sfun "define-const" [string (b.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp b.active] | Cache_maintenance m -> - pp_sfun "define-const" [string (m.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp m.active] ^^ hardline - ^^ pp_sfun "define-const" [string (m.name ^ "_kind"); string "Zcache_op_kind"; pp_smt_exp m.kind] ^^ hardline - ^^ pp_sfun "define-const" [string (m.name ^ "_addr"); pp_smt_typ m.addr_type; pp_smt_exp m.addr] - + pp_sfun "define-const" [string (m.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp m.active] + ^^ hardline + ^^ pp_sfun "define-const" [string (m.name ^ "_kind"); string "Zcache_op_kind"; pp_smt_exp m.kind] + ^^ hardline + ^^ pp_sfun "define-const" [string (m.name ^ "_addr"); pp_smt_typ m.addr_type; pp_smt_exp m.addr] | Branch_announce c -> - pp_sfun "define-const" [string (c.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp c.active] ^^ hardline - ^^ pp_sfun "define-const" [string (c.name ^ "_addr"); pp_smt_typ c.addr_type; pp_smt_exp c.addr] - + pp_sfun "define-const" [string (c.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp c.active] + ^^ hardline + ^^ pp_sfun "define-const" [string (c.name ^ "_addr"); pp_smt_typ c.addr_type; pp_smt_exp c.addr] | Excl_res (name, _, active) -> - pp_sfun "declare-const" [string (name ^ "_res"); pp_smt_typ Bool] ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_active"); pp_smt_typ Bool; pp_smt_exp active] - + pp_sfun "declare-const" [string (name ^ "_res"); pp_smt_typ Bool] + ^^ hardline + ^^ pp_sfun "define-const" [string (name ^ "_active"); pp_smt_typ Bool; pp_smt_exp active] | Declare_datatypes (name, ctors) -> - let pp_ctor (ctor_name, fields) = - match fields with - | [] -> parens (string ctor_name) - | _ -> pp_sfun ctor_name (List.map pp_str_smt_typ fields) - in - pp_sfun "declare-datatypes" - [Printf.ksprintf string "((%s 0))" name; - parens (parens (separate_map space pp_ctor ctors))] - + let pp_ctor (ctor_name, fields) = + match fields with [] -> parens (string ctor_name) | _ -> pp_sfun ctor_name (List.map pp_str_smt_typ fields) + in + pp_sfun "declare-datatypes" + [Printf.ksprintf string "((%s 0))" name; parens (parens (separate_map space pp_ctor ctors))] | Declare_tuple n -> - let par = separate_map space string (Util.list_init n (fun i -> "T" ^ string_of_int i)) in - let fields = separate space (Util.list_init n (fun i -> Printf.ksprintf string "(tup_%d_%d T%d)" n i i)) in - pp_sfun "declare-datatypes" - [Printf.ksprintf string "((Tup%d %d))" n n; - parens (parens (separate space - [string "par"; - parens par; - parens (parens (ksprintf string "tup%d" n ^^ space ^^ fields))]))] - - | Assert exp -> - pp_sfun "assert" [pp_smt_exp exp] + let par = separate_map space string (Util.list_init n (fun i -> "T" ^ string_of_int i)) in + let fields = separate space (Util.list_init n (fun i -> Printf.ksprintf string "(tup_%d_%d T%d)" n i i)) in + pp_sfun "declare-datatypes" + [ + Printf.ksprintf string "((Tup%d %d))" n n; + parens + (parens + (separate space + [string "par"; parens par; parens (parens (ksprintf string "tup%d" n ^^ space ^^ fields))] + ) + ); + ] + | Assert exp -> pp_sfun "assert" [pp_smt_exp exp] let string_of_smt_def def = Pretty_print_sail.to_string (pp_smt_def def) -let output_smt_defs out_chan smt = - List.iter (fun def -> output_string out_chan (string_of_smt_def def ^ "\n")) smt +let output_smt_defs out_chan smt = List.iter (fun def -> output_string out_chan (string_of_smt_def def ^ "\n")) smt (**************************************************************************) (* 2. Parser for SMT solver output *) @@ -609,7 +585,7 @@ let output_smt_defs out_chan smt = form of s-expression based representation. Therefore we define a simple parser for s-expressions using monadic parser combinators. *) -type sexpr = List of (sexpr list) | Atom of string +type sexpr = List of sexpr list | Atom of string let rec string_of_sexpr = function | List sexprs -> "(" ^ Util.string_of_list " " string_of_sexpr sexprs ^ ")" @@ -619,57 +595,49 @@ open Parser_combinators let lparen = token (function Str.Delim "(" -> Some () | _ -> None) let rparen = token (function Str.Delim ")" -> Some () | _ -> None) -let atom = token (function Str.Text str -> Some str | _ -> None) +let atom = token (function Str.Text str -> Some str | _ -> None) let rec sexp toks = let parse = pchoose (atom >>= fun str -> preturn (Atom str)) - (lparen >>= fun _ -> - plist sexp >>= fun xs -> - rparen >>= fun _ -> - preturn (List xs)) + ( lparen >>= fun _ -> + plist sexp >>= fun xs -> + rparen >>= fun _ -> preturn (List xs) + ) in parse toks let parse_sexps input = let delim = Str.regexp "[ \n\t]+\\|(\\|)" in let tokens = Str.full_split delim input in - let non_whitespace = function - | Str.Delim d when String.trim d = "" -> false - | _ -> true - in + let non_whitespace = function Str.Delim d when String.trim d = "" -> false | _ -> true in let tokens = List.filter non_whitespace tokens in - match plist sexp tokens with - | Ok (result, _) -> result - | Fail -> failwith "Parse failure" + match plist sexp tokens with Ok (result, _) -> result | Fail -> failwith "Parse failure" let value_of_sexpr sexpr = let open Jib in let open Value in function - | CT_fbits (n, _) -> - begin match sexpr with - | List [Atom "_"; Atom v; Atom m] - when int_of_string m = n && String.length v > 2 && String.sub v 0 2 = "bv" -> - let v = String.sub v 2 (String.length v - 2) in - mk_vector (Sail_lib.get_slice_int' (n, Big_int.of_string v, 0)) - | Atom v - when String.length v > 2 && String.sub v 0 2 = "#b" -> - let v = String.sub v 2 (String.length v - 2) in - mk_vector (Sail_lib.get_slice_int' (n, Big_int.of_string ("0b" ^ v), 0)) - | Atom v - when String.length v > 2 && String.sub v 0 2 = "#x" -> - let v = String.sub v 2 (String.length v - 2) in - mk_vector (Sail_lib.get_slice_int' (n, Big_int.of_string ("0x" ^ v), 0)) - | _ -> failwith ("Cannot parse sexpr as ctyp: " ^ string_of_sexpr sexpr) - end + | CT_fbits (n, _) -> begin + match sexpr with + | List [Atom "_"; Atom v; Atom m] when int_of_string m = n && String.length v > 2 && String.sub v 0 2 = "bv" -> + let v = String.sub v 2 (String.length v - 2) in + mk_vector (Sail_lib.get_slice_int' (n, Big_int.of_string v, 0)) + | Atom v when String.length v > 2 && String.sub v 0 2 = "#b" -> + let v = String.sub v 2 (String.length v - 2) in + mk_vector (Sail_lib.get_slice_int' (n, Big_int.of_string ("0b" ^ v), 0)) + | Atom v when String.length v > 2 && String.sub v 0 2 = "#x" -> + let v = String.sub v 2 (String.length v - 2) in + mk_vector (Sail_lib.get_slice_int' (n, Big_int.of_string ("0x" ^ v), 0)) + | _ -> failwith ("Cannot parse sexpr as ctyp: " ^ string_of_sexpr sexpr) + end | cty -> failwith ("Unsupported type in sexpr: " ^ Jib_util.string_of_ctyp cty) let rec find_arg id ctyp arg_smt_names = function | List [Atom "define-fun"; Atom str; List []; _; value] :: _ - when Util.assoc_compare_opt Id.compare id arg_smt_names = Some (Some str) -> - (id, value_of_sexpr value ctyp) + when Util.assoc_compare_opt Id.compare id arg_smt_names = Some (Some str) -> + (id, value_of_sexpr value ctyp) | _ :: sexps -> find_arg id ctyp arg_smt_names sexps | [] -> (id, V_unit) @@ -679,14 +647,10 @@ let build_counterexample args arg_ctyps arg_smt_names model = let rec run frame = match frame with | Interpreter.Done (state, v) -> Some v - | Interpreter.Step (lazy_str, _, _, _) -> - run (Interpreter.eval_frame frame) - | Interpreter.Break frame -> - run (Interpreter.eval_frame frame) - | Interpreter.Fail (_, _, _, _, msg) -> - None - | Interpreter.Effect_request (out, state, stack, eff) -> - run (Interpreter.default_effect_interp state eff) + | Interpreter.Step (lazy_str, _, _, _) -> run (Interpreter.eval_frame frame) + | Interpreter.Break frame -> run (Interpreter.eval_frame frame) + | Interpreter.Fail (_, _, _, _, msg) -> None + | Interpreter.Effect_request (out, state, stack, eff) -> run (Interpreter.default_effect_interp state eff) let check_counterexample ast env fname function_id args arg_ctyps arg_smt_names = let open Printf in @@ -698,28 +662,38 @@ let check_counterexample ast env fname function_id args arg_ctyps arg_smt_names while true do lines := input_line in_chan :: !lines done - with - | End_of_file -> () + with End_of_file -> () end; let solver_output = List.rev !lines |> String.concat "\n" |> parse_sexps in - begin match solver_output with - | Atom "sat" :: List (Atom "model" :: model) :: _ -> - let open Value in - let open Interpreter in - prerr_endline (sprintf "Solver found counterexample: %s" Util.("ok" |> green |> clear)); - let counterexample = build_counterexample args arg_ctyps arg_smt_names model in - List.iter (fun (id, v) -> prerr_endline (" " ^ string_of_id id ^ " -> " ^ string_of_value v)) counterexample; - let istate = initial_state ast env !primops in - let annot = (Parse_ast.Unknown, Type_check.mk_tannot env bool_typ) in - let call = E_aux (E_app (function_id, List.map (fun (_, v) -> E_aux (E_internal_value v, (Parse_ast.Unknown, Type_check.empty_tannot))) counterexample), annot) in - let result = run (Step (lazy "", istate, return call, [])) in - begin match result with - | Some (V_bool false) | None -> - ksprintf prerr_endline "Replaying counterexample: %s" Util.("ok" |> green |> clear) - | _ -> () - end - | _ -> - prerr_endline "Solver could not find counterexample" + begin + match solver_output with + | Atom "sat" :: List (Atom "model" :: model) :: _ -> + let open Value in + let open Interpreter in + prerr_endline (sprintf "Solver found counterexample: %s" Util.("ok" |> green |> clear)); + let counterexample = build_counterexample args arg_ctyps arg_smt_names model in + List.iter (fun (id, v) -> prerr_endline (" " ^ string_of_id id ^ " -> " ^ string_of_value v)) counterexample; + let istate = initial_state ast env !primops in + let annot = (Parse_ast.Unknown, Type_check.mk_tannot env bool_typ) in + let call = + E_aux + ( E_app + ( function_id, + List.map + (fun (_, v) -> E_aux (E_internal_value v, (Parse_ast.Unknown, Type_check.empty_tannot))) + counterexample + ), + annot + ) + in + let result = run (Step (lazy "", istate, return call, [])) in + begin + match result with + | Some (V_bool false) | None -> + ksprintf prerr_endline "Replaying counterexample: %s" Util.("ok" |> green |> clear) + | _ -> () + end + | _ -> prerr_endline "Solver could not find counterexample" end; let status = Unix.close_process_in in_chan in () diff --git a/test/builtins/myocamlbuild.ml b/test/builtins/myocamlbuild.ml index cbd112399..1a42cc52d 100644 --- a/test/builtins/myocamlbuild.ml +++ b/test/builtins/myocamlbuild.ml @@ -54,22 +54,29 @@ open Pathname open Outcome (* All -wl ignores should be removed if you want to see the pattern compilation, exhaustive, and unused var warnings *) -let lem_opts = [ A "-lib"; P "../../../../src/gen_lib"; - A "-lib"; P ".."; - A "-wl_pat_comp"; P "ign"; - A "-wl_pat_exh"; P "ign"; - A "-wl_pat_fail"; P "ign"; - A "-wl_unused_vars"; P "ign" ];; - -dispatch begin function -| After_rules -> - rule "lem -> ml" - ~prod: "%.ml" - ~dep: "%.lem" - (fun env builder -> Seq [ - Cmd (S ([P "lem"] @ lem_opts @ [ A "-ocaml"; P (env "%.lem") ])); - ]); - -| _ -> () -end;; +let lem_opts = + [ + A "-lib"; + P "../../../../src/gen_lib"; + A "-lib"; + P ".."; + A "-wl_pat_comp"; + P "ign"; + A "-wl_pat_exh"; + P "ign"; + A "-wl_pat_fail"; + P "ign"; + A "-wl_unused_vars"; + P "ign"; + ] +;; +dispatch + begin + function + | After_rules -> + rule "lem -> ml" ~prod:"%.ml" ~dep:"%.lem" (fun env builder -> + Seq [Cmd (S ([P "lem"] @ lem_opts @ [A "-ocaml"; P (env "%.lem")]))] + ) + | _ -> () + end diff --git a/test/c/lbuild/myocamlbuild.ml b/test/c/lbuild/myocamlbuild.ml index cc65b03ac..582f51cb1 100644 --- a/test/c/lbuild/myocamlbuild.ml +++ b/test/c/lbuild/myocamlbuild.ml @@ -48,30 +48,36 @@ (* SUCH DAMAGE. *) (**************************************************************************) -open Ocamlbuild_plugin ;; -open Command ;; -open Pathname ;; -open Outcome ;; +open Ocamlbuild_plugin +open Command +open Pathname +open Outcome (* paths relative to _build *) -let lem = "lem" ;; +let lem = "lem" (* All -wl ignores should be removed if you want to see the pattern compilation, exhaustive, and unused var warnings *) -let lem_opts = [A "-lib"; P ".."; - A "-wl_pat_comp"; P "ign"; - A "-wl_pat_exh"; P "ign"; - A "-wl_pat_fail"; P "ign"; - A "-wl_unused_vars"; P "ign"; - ] ;; +let lem_opts = + [ + A "-lib"; + P ".."; + A "-wl_pat_comp"; + P "ign"; + A "-wl_pat_exh"; + P "ign"; + A "-wl_pat_fail"; + P "ign"; + A "-wl_unused_vars"; + P "ign"; + ] +;; -dispatch begin function -| After_rules -> - rule "lem -> ml" - ~prod: "%.ml" - ~dep: "%.lem" - (fun env builder -> Seq [ - Cmd (S ([ P lem] @ lem_opts @ [ A "-ocaml"; P (env "%.lem") ])); - ]); - -| _ -> () -end ;; +dispatch + begin + function + | After_rules -> + rule "lem -> ml" ~prod:"%.ml" ~dep:"%.lem" (fun env builder -> + Seq [Cmd (S ([P lem] @ lem_opts @ [A "-ocaml"; P (env "%.lem")]))] + ) + | _ -> () + end diff --git a/test/isabelle/elf_loader.ml b/test/isabelle/elf_loader.ml index 6ec89ee65..3bfefbdcf 100644 --- a/test/isabelle/elf_loader.ml +++ b/test/isabelle/elf_loader.ml @@ -49,20 +49,15 @@ let opt_elf_tohost = ref Nat_big_num.zero type word8 = int -let escape_char c = - if int_of_char c <= 31 then '.' - else if int_of_char c >= 127 then '.' - else c +let escape_char c = if int_of_char c <= 31 then '.' else if int_of_char c >= 127 then '.' else c let hex_line bs = - let hex_char i c = - (if i mod 2 == 0 && i <> 0 then " " else "") ^ Printf.sprintf "%02x" (int_of_char c) - in - String.concat "" (List.mapi hex_char bs) ^ " " ^ String.concat "" (List.map (fun c -> Printf.sprintf "%c" (escape_char c)) bs) + let hex_char i c = (if i mod 2 == 0 && i <> 0 then " " else "") ^ Printf.sprintf "%02x" (int_of_char c) in + String.concat "" (List.mapi hex_char bs) + ^ " " + ^ String.concat "" (List.map (fun c -> Printf.sprintf "%c" (escape_char c)) bs) -let rec break n = function - | [] -> [] - | (_ :: _ as xs) -> [Lem_list.take n xs] @ break n (Lem_list.drop n xs) +let rec break n = function [] -> [] | _ :: _ as xs -> [Lem_list.take n xs] @ break n (Lem_list.drop n xs) let print_segment seg = let bs = seg.Elf_interpreted_segment.elf64_segment_body in @@ -73,32 +68,33 @@ let read name = let info = Sail_interface.populate_and_obtain_global_symbol_init_info name in prerr_endline "Elf read:"; - let (elf_file, elf_epi, symbol_map) = - begin match info with - | Error.Fail s -> failwith (Printf.sprintf "populate_and_obtain_global_symbol_init_info: %s" s) - | Error.Success ((elf_file: Elf_file.elf_file), - (elf_epi: Sail_interface.executable_process_image), - (symbol_map: Elf_file.global_symbol_init_info)) - -> - (* XXX disabled because it crashes if entry_point overflows an ocaml int :-( - prerr_endline (Sail_interface.string_of_executable_process_image elf_epi);*) - (elf_file, elf_epi, symbol_map) + let elf_file, elf_epi, symbol_map = + begin + match info with + | Error.Fail s -> failwith (Printf.sprintf "populate_and_obtain_global_symbol_init_info: %s" s) + | Error.Success + ( (elf_file : Elf_file.elf_file), + (elf_epi : Sail_interface.executable_process_image), + (symbol_map : Elf_file.global_symbol_init_info) + ) -> + (* XXX disabled because it crashes if entry_point overflows an ocaml int :-( + prerr_endline (Sail_interface.string_of_executable_process_image elf_epi);*) + (elf_file, elf_epi, symbol_map) end in prerr_endline "\nElf segments:"; - let (segments, e_entry, e_machine) = - begin match elf_epi, elf_file with - | (Sail_interface.ELF_Class_32 _, _) -> failwith "cannot handle ELF_Class_32" - | (_, Elf_file.ELF_File_32 _) -> failwith "cannot handle ELF_File_32" - | (Sail_interface.ELF_Class_64 (segments, e_entry, e_machine), Elf_file.ELF_File_64 f1) -> - (* remove all the auto generated segments (they contain only 0s) *) - let segments = - Lem_list.mapMaybe - (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) - segments - in - (segments, e_entry, e_machine) + let segments, e_entry, e_machine = + begin + match (elf_epi, elf_file) with + | Sail_interface.ELF_Class_32 _, _ -> failwith "cannot handle ELF_Class_32" + | _, Elf_file.ELF_File_32 _ -> failwith "cannot handle ELF_File_32" + | Sail_interface.ELF_Class_64 (segments, e_entry, e_machine), Elf_file.ELF_File_64 f1 -> + (* remove all the auto generated segments (they contain only 0s) *) + let segments = + Lem_list.mapMaybe (fun (seg, prov) -> if prov = Elf_file.FromELF then Some seg else None) segments + in + (segments, e_entry, e_machine) end in (segments, e_entry, symbol_map) @@ -113,14 +109,16 @@ let write_file chan paddr i byte = let load_elf name = let segments, e_entry, symbol_map = read name in opt_elf_entry := e_entry; - (if List.mem_assoc "tohost" symbol_map then - let (_, _, tohost_addr, _, _) = List.assoc "tohost" symbol_map in - opt_elf_tohost := tohost_addr); + if List.mem_assoc "tohost" symbol_map then ( + let _, _, tohost_addr, _, _ = List.assoc "tohost" symbol_map in + opt_elf_tohost := tohost_addr + ); (*List.iter (load_segment ~writer:writer) segments*) segments (* The sail model can access this by externing a unit -> Big_int.t function as Elf_loader.elf_entry. *) let elf_entry () = Big_int.big_int_of_string (Nat_big_num.to_string !opt_elf_entry) + (* Used by RISCV sail model test harness for exiting test *) let elf_tohost () = Big_int.big_int_of_string (Nat_big_num.to_string !opt_elf_tohost) diff --git a/test/isabelle/run_aarch64.ml b/test/isabelle/run_aarch64.ml index c60378669..ca71497a9 100644 --- a/test/isabelle/run_aarch64.ml +++ b/test/isabelle/run_aarch64.ml @@ -1,6 +1,4 @@ -open Aarch64_export;; - - +open Aarch64_export (**************************************************************************) (* Sail *) @@ -52,7 +50,7 @@ open Aarch64_export;; (* SUCH DAMAGE. *) (**************************************************************************) -open Elf_loader;; +open Elf_loader let opt_file_arguments = ref ([] : string list) @@ -60,10 +58,9 @@ let options = Arg.align [] let usage_msg = "Sail OCaml RTS options:" -let () = - Arg.parse options (fun s -> opt_file_arguments := !opt_file_arguments @ [s]) usage_msg +let () = Arg.parse options (fun s -> opt_file_arguments := !opt_file_arguments @ [s]) usage_msg -let (>>) = Aarch64.bindS +let ( >> ) = Aarch64.bindS let liftS = Aarch64.liftState (Aarch64.get_regval, Aarch64.set_regval) let load_elf_segment seg = @@ -83,11 +80,6 @@ let load_elf_segment seg = let _ = Random.self_init (); - let elf_segments = match !opt_file_arguments with - | f :: _ -> load_elf f - | _ -> [] - in + let elf_segments = match !opt_file_arguments with f :: _ -> load_elf f | _ -> [] in Aarch64.prerr_results - (Aarch64.initial_state |> - (Aarch64.iterS load_elf_segment elf_segments >> (fun _ -> - liftS (Aarch64.main ())))); + (Aarch64.initial_state |> (Aarch64.iterS load_elf_segment elf_segments >> fun _ -> liftS (Aarch64.main ()))) diff --git a/test/isabelle/run_cheri.ml b/test/isabelle/run_cheri.ml index c50d525d6..16630b140 100644 --- a/test/isabelle/run_cheri.ml +++ b/test/isabelle/run_cheri.ml @@ -1,6 +1,4 @@ -open Cheri_export;; - - +open Cheri_export (**************************************************************************) (* Sail *) @@ -52,7 +50,7 @@ open Cheri_export;; (* SUCH DAMAGE. *) (**************************************************************************) -open Elf_loader;; +open Elf_loader let opt_file_arguments = ref ([] : string list) @@ -60,10 +58,9 @@ let options = Arg.align [] let usage_msg = "Sail OCaml RTS options:" -let () = - Arg.parse options (fun s -> opt_file_arguments := !opt_file_arguments @ [s]) usage_msg +let () = Arg.parse options (fun s -> opt_file_arguments := !opt_file_arguments @ [s]) usage_msg -let (>>) = Sail2_state_monad.bindS +let ( >> ) = Sail2_state_monad.bindS (*let liftS = Sail2_state_lifting.liftState (Cheri_types.get_regval, Cheri_types.set_regval)*) let load_elf_segment seg = @@ -82,11 +79,6 @@ let load_elf_segment seg = let _ = Random.self_init (); - let elf_segments = match !opt_file_arguments with - | f :: _ -> load_elf f - | _ -> [] - in + let elf_segments = match !opt_file_arguments with f :: _ -> load_elf f | _ -> [] in (* State_monad.prerr_results *) - (Cheri_code.initial_state |> - (Sail2_state.iterS load_elf_segment elf_segments >> (fun _ -> - (Cheri_code.mainS ())))); + Cheri_code.initial_state |> (Sail2_state.iterS load_elf_segment elf_segments >> fun _ -> Cheri_code.mainS ()) diff --git a/test/mono/test.ml b/test/mono/test.ml index 6476f8f48..962a8b063 100644 --- a/test/mono/test.ml +++ b/test/mono/test.ml @@ -1,4 +1,8 @@ -match Out.run() with +match Out.run () with | Done _ -> exit 0 -| Fail s -> prerr_endline ("Fail: " ^ s); exit 1 -| _ -> prerr_endline "Unexpected outcome"; exit 1 +| Fail s -> + prerr_endline ("Fail: " ^ s); + exit 1 +| _ -> + prerr_endline "Unexpected outcome"; + exit 1 diff --git a/test/mono/test_with_state.ml b/test/mono/test_with_state.ml index bcb465cef..e97c9fc7b 100644 --- a/test/mono/test_with_state.ml +++ b/test/mono/test_with_state.ml @@ -1,12 +1,15 @@ -module P = Sail2_prompt_monad;; +module P = Sail2_prompt_monad let rec run regs m = match m with | P.Done _ -> 0 | P.Read_reg (r, k) -> run regs (k (Option.get (Out_types.get_regval r regs))) | P.Write_reg (r, v, k) -> run (Option.get (Out_types.set_regval r v regs)) k - | P.Fail s -> prerr_endline ("Fail: " ^ s); exit 1 - | _ -> prerr_endline "Unexpected outcome"; exit 1 -;; -run Out.initial_regstate (Out.run()) + | P.Fail s -> + prerr_endline ("Fail: " ^ s); + exit 1 + | _ -> + prerr_endline "Unexpected outcome"; + exit 1 ;; +run Out.initial_regstate (Out.run ())