From cafb3668cf686f8b11a52a8b46fe4a1420cb2dc5 Mon Sep 17 00:00:00 2001
From: Jakob von Raumer <jakob@von-raumer.de>
Date: Thu, 19 Dec 2024 17:50:13 +0100
Subject: [PATCH] first attempt at managing imports

---
 src/sail_lean_backend/pretty_print_lean.ml | 48 +++++++++++++++++++---
 src/sail_lean_backend/sail_plugin_lean.ml  | 27 +++++++++---
 test/lean/import.expected.lean             |  7 ++++
 test/lean/import.sail                      |  2 +
 4 files changed, 74 insertions(+), 10 deletions(-)
 create mode 100644 test/lean/import.expected.lean
 create mode 100644 test/lean/import.sail

diff --git a/src/sail_lean_backend/pretty_print_lean.ml b/src/sail_lean_backend/pretty_print_lean.ml
index bd49c7009..e058cfce1 100644
--- a/src/sail_lean_backend/pretty_print_lean.ml
+++ b/src/sail_lean_backend/pretty_print_lean.ml
@@ -9,6 +9,8 @@ open Rewriter
 open PPrint
 open Pretty_print_common
 
+module StringMap = Map.Make (String)
+
 let implicit_parens x = enclose (string "{") (string "}") x
 
 let doc_id_ctor (Id_aux (i, _)) =
@@ -283,6 +285,27 @@ let doc_typdef (TD_aux (td, tannot) as full_typdef) =
       nest 2 (flow (break 1) [string "structure"; string id; string "where"] ^^ hardline ^^ enums_doc)
   | _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.")
 
+let string_of_def_con (DEF_aux (d, _)) =
+  match d with
+  | DEF_constraint _ -> "DEF_constraint"
+  | DEF_default _ -> "DEF_default"
+  | DEF_fixity _ -> "DEF_fixity"
+  | DEF_fundef _ -> "DEF_fundef"
+  | DEF_impl _ -> "DEF_impl"
+  | DEF_instantiation _ -> "DEF_instantiation"
+  | DEF_internal_mutrec _ -> "DEF_internal_mutrec"
+  | DEF_let _ -> "DEF_let"
+  | DEF_loop_measures _ -> "DEF_loop_measures"
+  | DEF_mapdef _ -> "DEF_mapdef"
+  | DEF_measure _ -> "DEF_measure"
+  | DEF_outcome _ -> "DEF_outcome"
+  | DEF_overload _ -> "DEF_overload"
+  | DEF_pragma _ -> "DEF_pragma"
+  | DEF_register _ -> "DEF_register"
+  | DEF_scattered _ -> "DEF_scattered"
+  | DEF_type _ -> "DEF_type"
+  | DEF_val _ -> "DEF_val"
+
 let doc_def (DEF_aux (aux, def_annot) as def) =
   match aux with
   | DEF_fundef fdef -> group (doc_fundef fdef) ^/^ hardline
@@ -297,8 +320,23 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en
   | DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1)
   | d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth
 
-let pp_ast_lean ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
-  let defs = remove_imports defs 0 in
-  let output : document = separate_map empty doc_def defs in
-  print o output;
-  ()
+let rec collect_imports defs =
+  match defs with
+  | [] -> []
+  | DEF_aux (DEF_pragma ("include_start", x, _), _) :: ds -> x :: collect_imports ds
+  | d :: ds -> collect_imports ds
+
+let rec pp_ast_lean (defs : (tannot, env) def list) (import_outputs : (string * out_channel) StringMap.t)
+    main_output =
+  match defs with
+  | [] -> []
+  | DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> 
+    ds
+  | DEF_aux (DEF_pragma ("include_start", file, _), _) :: ds ->
+    let (new_module, new_main) = StringMap.find file import_outputs in
+    print main_output (string ("import ") ^^ string new_module ^^ hardline ^^ hardline);
+    let defs_after_import = pp_ast_lean ds import_outputs new_main in
+    pp_ast_lean defs_after_import import_outputs main_output
+  | d :: ds ->
+    print main_output (doc_def d);
+    pp_ast_lean ds import_outputs main_output
diff --git a/src/sail_lean_backend/sail_plugin_lean.ml b/src/sail_lean_backend/sail_plugin_lean.ml
index 882063b41..40e70493f 100644
--- a/src/sail_lean_backend/sail_plugin_lean.ml
+++ b/src/sail_lean_backend/sail_plugin_lean.ml
@@ -187,13 +187,30 @@ let create_lake_project (out_name : string) default_sail_dir =
   in
   let project_main = open_out (Filename.concat project_dir (out_name_camel ^ ".lean")) in
   output_string project_main ("import " ^ out_name_camel ^ ".Sail.Sail\n\n");
-  project_main
+  (lean_src_dir, project_main)
 
-let output (out_name : string) ast default_sail_dir =
-  let project_main = create_lake_project out_name default_sail_dir in
+let output (out_name : string) ({ defs; _ } as ast : Libsail.Type_check.typed_ast) default_sail_dir =
+  let lean_src_dir, project_main = create_lake_project out_name default_sail_dir in
   (* Uncomment for debug output of the Sail code after the rewrite passes *)
-  (* Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast); *)
-  Pretty_print_lean.pp_ast_lean ast project_main;
+  Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast);
+  (* let (defs, _) = ast in *)
+  let imports = Pretty_print_lean.collect_imports defs in
+  let ref foo : out_channel Pretty_print_lean.StringMap.t = Pretty_print_lean.StringMap.empty in
+  (* Build a map from strings to output channels for each imported file. *)
+  let import_outputs =
+    List.fold_left
+      (fun map file ->
+        let filename = Libsail.Util.to_upper_camel_case (Filename.chop_suffix (Filename.basename file) ".sail") in
+        let new_out = open_out (Filename.concat lean_src_dir (filename ^ ".lean")) in
+        let out_name_camel = Libsail.Util.to_upper_camel_case out_name in
+        let module_name = out_name_camel ^ "." ^ filename in
+        Pretty_print_lean.StringMap.add file (module_name, new_out) map
+      )
+      Pretty_print_lean.StringMap.empty imports
+  in
+  (* let defs = Pretty_print_lean.remove_imports defs 0 in *)
+  let _ = Pretty_print_lean.pp_ast_lean defs import_outputs project_main in
+  Pretty_print_lean.StringMap.iter (fun _ (_, ch) -> close_out ch) import_outputs;
   close_out project_main
 
 let lean_target out_name { default_sail_dir; ctx; ast; effect_info; env; _ } =
diff --git a/test/lean/import.expected.lean b/test/lean/import.expected.lean
new file mode 100644
index 000000000..dc0114492
--- /dev/null
+++ b/test/lean/import.expected.lean
@@ -0,0 +1,7 @@
+import Out.Sail.Sail
+
+import Out.Trivial
+
+def initialize_registers : Unit :=
+  ()
+
diff --git a/test/lean/import.sail b/test/lean/import.sail
new file mode 100644
index 000000000..9cc84159e
--- /dev/null
+++ b/test/lean/import.sail
@@ -0,0 +1,2 @@
+$include "trivial.sail"
+