forked from CakeML/cakeml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ml_monad_translatorLib.sml
4101 lines (3697 loc) · 167 KB
/
ml_monad_translatorLib.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(*
The ML code that implements the main part of the monadic translator.
*)
structure ml_monad_translatorLib :> ml_monad_translatorLib = struct
(******************************************************************************)
open preamble
astTheory libTheory semanticPrimitivesTheory evaluateTheory
ml_translatorTheory ml_progTheory ml_progLib
ml_pmatchTheory ml_monadBaseTheory ml_monad_translatorBaseTheory
ml_monad_translatorTheory evaluateTheory cfTacticsLib
Net List packLib stringSimps
open ml_monadBaseLib
open ml_monadStoreLib
(******************************************************************************
Global library variables.
******************************************************************************)
(* TODO add useful debug/info printing *)
val DEBUG = false;
val INFO = true;
(*
 Conditional printers. Each prints "msg: <pretty-printed value>" followed by
 a newline, but only when its flag (INFO or DEBUG) is set; when the flag is
 off the value is not even pretty-printed.
*)
local
  fun print_when flag to_string msg x =
    if flag then print (msg ^ ": " ^ to_string x ^ "\n") else ()
in
  fun info_print msg term = print_when INFO term_to_string msg term
  fun debug_print msg term = print_when DEBUG term_to_string msg term
  fun debug_print_ty msg typ = print_when DEBUG type_to_string msg typ
end;
val _ = (print_asts := true);
(******************************************************************************
Get constant terms and types from ml_monad_translatorTheory.
Prevents parsing in the wrong context.
******************************************************************************)
(*
 Parse every term and type the translator needs exactly once, using the
 grammars stored with ml_monad_translatorTheory rather than the grammar of
 whatever theory is current when this library is loaded. The parsed values
 are reachable only through get_term / get_type below.
*)
local
(* Shadow Parse's Type/Term parsers with ones bound to the theory grammars *)
structure Parse = struct
open Parse
val (Type,Term) =
parse_from_grammars
ml_monad_translatorTheory.ml_monad_translator_grammars
end
open Parse
(* key -> pre-parsed HOL type *)
val type_alist =
[("exp",``:ast$exp``),
("string_ty",``:tvarN``),
("unit",``:unit``),
("pair", ``:'a # 'b``),
("num", ``:num``),
("poly_M_type",``:'a -> ('b, 'c) exc # 'a``),
("v_bool_ty",``:v -> bool``),
("hprop_ty",``:hprop``),
("recclosure_exp_ty",``:(tvarN, tvarN # ast$exp) alist``),
("register_pure_type_pat",``:('a, 'b) ml_monadBase$exc``),
("exc_ty",``:('a, 'b) exc``),
("ffi",``:'ffi``),
("v_list", ``:v list``)
]
(* key -> pre-parsed HOL term (constants, patterns and abstractions) *)
val term_alist =[("EqSt remove",``!a st. EqSt a st = (a : ('a, 'b) H)``),
("PURE ArrowP ro eq", ``PURE(ArrowP ro H (PURE (Eq a x)) b)``),
("ArrowP ro PURE", ``ArrowP ro H a (PURE b)``),
("ArrowP ro EqSt", ``ArrowP ro H (EqSt a st) b``),
("ArrowM_const",``ArrowM``),
("Eval_const",``Eval``),
("EvalM_const",``EvalM``),
("MONAD_const",``MONAD : (α->v->bool) -> (β->v->bool) -> ((γ,α,β)M,γ) H``),
("PURE_const",``PURE : (α -> v -> bool) -> (α, β) H``),
("FST_const",``FST : 'a # 'b -> 'a``),
("SND_const",``SND : 'a # 'b -> 'b``),
("LENGTH_const", ``LENGTH : 'a list -> num``),
("EL_const", ``EL : num -> 'a list -> 'a``),
("Fun_const",``ast$Fun``),
("Var_const",``ast$Var``),
("Closure_const",``semanticPrimitives$Closure``),
("failure_pat",``\v. (M_failure(C v), state_var)``),
("Eval_pat",``Eval env exp (P (res:'a))``),
("Eval_pat2",``Eval env exp P``),
("derive_case_EvalM_abs",
``\EXN_TYPE res (H:('a -> hprop) # 'ffi ffi_proj).
EvalM ro env st exp (MONAD P EXN_TYPE res) H``),
("Eval_name_RI_abs",``\name RI. Eval env (Var (Short name)) RI``),
("write_const",``write``),
("RARRAY_REL_const",``RARRAY_REL``),
("ARRAY_REL_const",``ARRAY_REL``),
("run_const",``ml_monadBase$run``),
("EXC_TYPE_aux_const",``EXC_TYPE_aux``),
("return_pat",``st_ex_return x``),
("bind_pat",``st_ex_bind x y``),
("pure_seq_pat",``pure_seq x y``),
("otherwise_pat",``x otherwise y``),
("if_statement_pat",``if b then (x:('a,'b,'c) M) else (y:('a,'b,'c) M)``),
("PreImp_EvalM_abs",
``\a name RI f (H: ('a -> hprop) # 'ffi ffi_proj).
PreImp a (!st. EvalM ro env st (Var (Short name)) (RI f) H)``),
("refs_emp",``\refs. emp``),
("UNIT_TYPE",``UNIT_TYPE``),
("nsLookup_val_pat",
``nsLookup (env : env_val) (Short (vname : tvarN)) = SOME (loc : v)``),
("CONTAINER",``ml_translator$CONTAINER (b:bool)``),
("EvalM_pat",``EvalM ro env st e p H``),
("var_assum",``Eval env (Var n) (a (y:'a))``),
("nsLookup_assum",``nsLookup env name = opt``),
("lookup_cons_assum",``lookup_cons name env = opt``),
("eqtype_assum",``EqualityType A``),
("nsLookup_closure_pat",``nsLookup env1.v (Short name1) =
SOME (Closure env2 name2 exp)``),
("nsLookup_recclosure_pat",``nsLookup env1.v (Short name1) =
SOME (Recclosure env2 exps name2)``),
("Eq_pat",``Eq a x``),
("EqSt_pat",``EqSt a x``),
("PreImp simp",``(PreImp a b /\ PRECONDITION a) <=> b``),
("PRECONDITION_pat",``ml_translator$PRECONDITION x``),
("LOOKUP_VAR_pat",``LOOKUP_VAR name env exp``),
("nsLookup_pat",``nsLookup (env : v sem_env).v (Short name) = SOME exp``),
("emp_tm",``set_sep$emp``),
("ffi_ffi_proj", ``p:'ffi ffi_proj``)
]
in
(*
 get_term key / get_type key return the pre-parsed term/type stored under
 key. An unknown key fails (HOL_ERR, as before) with a message that names
 the key, instead of assoc's generic "not found".
*)
fun get_term str = assoc str term_alist
handle HOL_ERR _ => failwith ("get_term: no term named \"" ^ str ^ "\"")
fun get_type str = assoc str type_alist
handle HOL_ERR _ => failwith ("get_type: no type named \"" ^ str ^ "\"")
end (* local *)
(* Some constant types, looked up once via get_type (or built directly) *)
val exp_ty = get_type "exp";
val string_ty = get_type "string_ty";
val (a_ty, b_ty, c_ty) = (alpha, beta, gamma);
val bool_ty = bool;
val unit_ty = get_type "unit";
val pair_ty = get_type "pair";
val num_ty = get_type "num";
(* :v sem_env *)
val venvironment = mk_environment v_ty;
(* :α -> (β, γ) exc # α *)
val poly_M_type = get_type "poly_M_type";
(* :v -> bool *)
val v_bool_ty = get_type "v_bool_ty";
(* :hprop = :(heap_part -> bool) -> bool *)
val hprop_ty = get_type "hprop_ty";
val recclosure_exp_ty = get_type "recclosure_exp_ty";
val exc_ty = get_type "exc_ty";
val v_list_ty = get_type "v_list";
val ffi_ty_var = get_type "ffi";
val register_pure_type_pat = get_type "register_pure_type_pat";
(* Some constant terms. Each pattern/constant is bound exactly once. *)
(* free variables used when building environments and expressions *)
val v_env = mk_var("env",venvironment);
val exp_var = mk_var("exp", exp_ty);
val cl_env_tm = mk_var("cl_env",venvironment);
val v_var = mk_var("v",v_ty);
(* constants *)
val ArrowM_const = get_term "ArrowM_const";
val Eval_const = get_term "Eval_const";
val EvalM_const = get_term "EvalM_const";
val MONAD_const = get_term "MONAD_const";
val PURE_const = get_term "PURE_const";
val FST_const = get_term "FST_const";
val SND_const = get_term "SND_const";
val LENGTH_const = get_term "LENGTH_const";
val EL_const = get_term "EL_const";
val Fun_const = get_term "Fun_const";
val Var_const = get_term "Var_const";
val Closure_const = get_term "Closure_const";
(* patterns and abstractions *)
val failure_pat = get_term "failure_pat";
val Eval_pat = get_term "Eval_pat";
val Eval_pat2 = get_term "Eval_pat2";
val derive_case_EvalM_abs = get_term "derive_case_EvalM_abs";
val Eval_name_RI_abs = get_term "Eval_name_RI_abs";
val write_const = get_term "write_const";
val RARRAY_REL_const = get_term "RARRAY_REL_const";
val ARRAY_REL_const = get_term "ARRAY_REL_const";
val run_const = get_term "run_const";
val EXC_TYPE_aux_const = get_term "EXC_TYPE_aux_const";
val return_pat = get_term "return_pat";
val bind_pat = get_term "bind_pat";
val pure_seq_pat = get_term "pure_seq_pat";
val otherwise_pat = get_term "otherwise_pat";
val if_statement_pat = get_term "if_statement_pat";
val PreImp_EvalM_abs = get_term "PreImp_EvalM_abs";
(* ArrowM partially applied to a free ro:bool *)
val ArrowM_ro_tm = mk_comb(ArrowM_const, mk_var("ro", bool_ty));
val ArrowP_PURE_pat = get_term "ArrowP ro PURE";
val ArrowP_EqSt_pat = get_term "ArrowP ro EqSt";
val nsLookup_val_pat = get_term "nsLookup_val_pat";
val EvalM_pat = get_term "EvalM_pat";
(* hypothesis shapes (previously each of these three was bound twice) *)
val nsLookup_assum = get_term "nsLookup_assum";
val lookup_cons_assum = get_term "lookup_cons_assum";
val eqtype_assum = get_term "eqtype_assum";
val emp_tm = get_term "emp_tm";
val nsLookup_closure_pat = get_term "nsLookup_closure_pat";
val nsLookup_recclosure_pat = get_term "nsLookup_recclosure_pat";
val var_assum = get_term "var_assum";
val Eq_pat = get_term "Eq_pat";
val EqSt_pat = get_term "EqSt_pat";
val PRECONDITION_pat = get_term "PRECONDITION_pat";
val LOOKUP_VAR_pat = get_term "LOOKUP_VAR_pat";
val nsLookup_pat = get_term "nsLookup_pat";
val EVAL_T_F = ml_monad_translatorTheory.EVAL_T_F;
val refs_emp = get_term "refs_emp";
(******************************************************************************
Copy/paste from other files.
******************************************************************************)
(*
From ml_translatorLib .
*)
(*
 MY_MATCH_MP (|- a ==> b) (|- a'): match the antecedent a against a'
 (term and type matching), instantiate the implication accordingly and
 apply modus ponens.
*)
fun MY_MATCH_MP th1 th2 =
  let
    val antecedent = fst (dest_imp (concl th1))
    val (tm_theta, ty_theta) = match_term antecedent (concl th2)
    val inst_th1 = INST tm_theta (INST_TYPE ty_theta th1)
  in
    MP inst_th1 th2
  end;
(*
 rm_fix res: rewrite res with the "eq remove" equation (from ml_translatorLib)
 and the "EqSt remove" equation (!a st. EqSt a st = a) and return the
 rewritten term. Both equations are asserted with mk_thm, not proved. QCONV
 makes an unchanged res come back as itself.
*)
fun rm_fix res =
  let
    val eq_remove = mk_thm ([], ml_translatorLib.get_term "eq remove")
    val eqst_remove = mk_thm ([], get_term "EqSt remove")
    val rewritten = QCONV (REWRITE_CONV [eq_remove, eqst_remove]) res
  in
    rhs (concl rewritten)
  end;
(******************************************************************************
Translator config and datatypes.
******************************************************************************)
(* Mutable configuration and bookkeeping shared by the monadic translator.
   Every field is a ref so that later calls can update it in place. *)
val translator_state = {
(* the store predicate *)
H_def = ref UNIT_TYPE_def,
default_H = ref refs_emp,
(* current store predicate; specialised into theorems by ISPEC_EvalM and
   used as the H argument of ArrowM in get_arrow_type_inv *)
H = ref refs_emp,
(* the type of the references object *)
refs_type = ref unit_ty,
(* the exception refinement invariant and type *)
EXN_TYPE_def = ref UNIT_TYPE_def, (* to replace EXN_TYPE_def_ref *)
EXN_TYPE = ref (get_term "UNIT_TYPE"),
exn_type = ref unit_ty, (* EXN_TYPE is the refinement-invariant TERM (passed to
   MONAD in get_m_type_inv); exn_type is the HOL TYPE of exceptions
   (substituted into the monad type by M_type) *)
(* when SOME, ISPEC_EvalM_VALID discharges the store-validity antecedent
   with it; when NONE the antecedent is left as a hypothesis *)
VALID_STORE_THM = ref (NONE : thm option),
(* NOTE: current_theory() is evaluated once, when this record is built *)
type_theories = ref ([current_theory(), "ml_translator"] : string list),
exn_handles = ref ([] : (term * thm) list),
exn_raises = ref ([] : (term * thm) list),
exn_functions_defs = ref ([] : (thm * thm) list),
(* extended by add_access_pattern *)
access_patterns = ref([] : (term * thm) list),
refs_functions_defs = ref([] : (thm * thm) list),
rarrays_functions_defs = ref([] : (thm * thm * thm * thm * thm * thm) list),
farrays_functions_defs = ref([] : (thm * thm * thm * thm * thm) list),
induction_helper_thms = ref ([] : (string * thm) list),
(* ^theorems saved in case induction can't be proved, left to aid user *)
(* prefix and counter for the fresh environment variables made by
   get_curr_env when local_state_init_H is set *)
local_environment_var_name = ref "%env",
num_local_environment_vars = ref 0,
local_state_init_H = ref false,
(* ^ to replace dynamic_init_H - does store predicate have free vars?*)
(* when true, INST_ro leaves ro free and get_curr_env returns fresh vars *)
store_pinv_def = ref (NONE : thm option),
(* The specifications for dynamically initialized stores *)
(* entries are read by lookup_dynamic_v_thm as
   (name, ml_name, hol_fun, th, pre_cond, module) *)
dynamic_v_thms =
ref (Net.empty : (string * string * term * thm * thm * string option) net),
(* The environment extension for dynamically initialized stores *)
dynamic_refs_bindings = ref ([] : (term * term) list),
(* Abbreviations of the function values in the case of a
dynamically initialized store*)
local_code_abbrevs = ref([] : thm list),
mem_derive_case_ref = ref ([] : (hol_type * thm) list)
}
(******************************************************************************
Helper functions.
******************************************************************************)
(* This is not used in this file, but is in the signature *)
(*
 Normalise th (split conjunctive antecedents into curried ones, strip
 universals, undischarge everything), record the normalised theorem in
 access_patterns keyed by rand (rand (rator (concl ...))), and return the
 ORIGINAL th.
*)
fun add_access_pattern th =
  let
    val norm_th = th |> REWRITE_RULE [GSYM AND_IMP_INTRO] |> SPEC_ALL |> UNDISCH_ALL
    val key_tm = norm_th |> concl |> rator |> rand |> rand
    val patterns = #access_patterns translator_state
  in
    patterns := (key_tm, norm_th) :: !patterns;
    th
  end;
(* the (SOME x) = x; the NONE fails with "the of NONE" *)
fun the (SOME x) = x
  | the NONE = failwith("the of NONE");
(* mk_write name v env = ``write ^name ^v ^env``, taken from the left-hand
   side of write_def instantiated at those arguments *)
fun mk_write name v env =
  let
    val def_inst = ISPECL [name, v, env] write_def
  in
    rand (rator (concl def_inst))
  end;
val my_list_mk_comb = ml_monadBaseLib.my_list_mk_comb;
(*
 ISPECL for terms rather than theorems:
   ISPECL_TM [``s1``, ``s2``, ...] ``λ x1 x2 ... . body`` =
     ``body[s1/x1, s2/x2, ...]``
 The abstraction is applied to the arguments and then beta-reduced once per
 argument.
*)
fun ISPECL_TM specs tm =
  let
    (* k arguments: k nested RATOR_CONVs, so the innermost application is
       reduced first *)
    fun betas [] = ALL_CONV
      | betas (_ :: rest) = RATOR_CONV (betas rest) THENC BETA_CONV
  in
    my_list_mk_comb (tm, specs) |> betas specs |> concl |> rand
  end;
(* dest_monad_type ``:α -> (β, γ) exc # α`` = (``:α``, ``:β``, ``:γ``)
   i.e. (state type, result type, exception type); fails if the type does
   not match poly_M_type *)
fun dest_monad_type monad_type =
  let
    val theta = match_type poly_M_type monad_type
    val inst = type_subst theta
  in
    (inst a_ty, inst b_ty, inst c_ty)
  end;
(* Should be moved somewhere else *)
(*
 list_dest f tm: split tm with the destructor f, recursively on both halves,
 and return the leaves from left to right. A term on which f raises HOL_ERR
 is a leaf.
*)
fun list_dest f tm =
  case (SOME (f tm) handle HOL_ERR _ => NONE) of
    NONE => [tm]
  | SOME (l, r) => list_dest f l @ list_dest f r;
(* the arguments of a (possibly curried) application *)
fun dest_args tm = snd (strip_comb tm);
(*
 Converts any theorem into the standard form
   |- hyp1 ∧ hyp2 ∧ ... ⇒ concl
 by discharging all hypotheses and merging the resulting chain of
 implications into one conjunctive antecedent (AND_IMP_INTRO). If the
 result is not an implication, T is used as the antecedent:
   |- T ⇒ concl
 NOTE(review): HOL4's is_imp also accepts negations ~p, and a hypothesis-free
 |- a ⇒ b is already an implication; in both cases the theorem is returned
 without a T antecedent. Confirm callers expect this.
*)
fun disch_asms th =
  let
    val flattened = PURE_REWRITE_RULE [AND_IMP_INTRO] (DISCH_ALL th)
  in
    if is_imp (concl flattened) then flattened
    else DISCH T flattened
  end;
(* Three fresh type variables, used as intermediate placeholders when
   instantiating polymorphic constants/types (here, in mk_PURE and in
   get_m_type_inv). *)
val (gen1, gen2, gen3) = (gen_tyvar (), gen_tyvar (), gen_tyvar ());
(*
 M_type ty: the monad type with return type ty, and state/exception types
 taken from translator_state (refs_type and exn_type). poly_M_type's type
 variables are first renamed to the fresh gen1..gen3 and only then replaced,
 presumably to avoid any clash with type variables in ty or the configured
 types.
*)
fun M_type ty =
  let
    val {refs_type, exn_type, ...} = translator_state
    val fresh_M =
      Type.type_subst [a_ty |-> gen1, b_ty |-> gen2, c_ty |-> gen3] poly_M_type
  in
    Type.type_subst [gen1 |-> !refs_type, gen2 |-> ty, gen3 |-> !exn_type] fresh_M
  end;
(* Some minor functions: specialise the leading universal(s) of a theorem to
   the current store predicate H and/or exception invariant EXN_TYPE. *)
fun ISPEC_EvalM th = ISPEC (!(#H translator_state)) th;
fun ISPEC_EvalM_EXN_TYPE th = ISPEC (!(#EXN_TYPE translator_state)) th;
fun ISPEC_EvalM_MONAD th =
  ISPECL [!(#H translator_state), !(#EXN_TYPE translator_state)] th;
(* As ISPEC_EvalM, then discharge the antecedent with VALID_STORE_THM if one
   is registered; otherwise move the antecedent into the hypotheses. *)
fun ISPEC_EvalM_VALID th =
  let
    val spec_th = ISPEC (!(#H translator_state)) th
  in
    case !(#VALID_STORE_THM translator_state) of
      NONE => UNDISCH spec_th
    | SOME store_th => MP spec_th store_th
  end;
(*
 Instantiate the ro argument in theorems (INST_ro) or terms (INST_ro_tm):
 unless local_state_init_H is set, the free variable ro:bool becomes F;
 otherwise the input is returned unchanged.
*)
local
  fun ro_kept () = !(#local_state_init_H translator_state)
  fun ro_to_F () = [mk_var ("ro", bool) |-> F]
in
  fun INST_ro th = if ro_kept () then th else INST (ro_to_F ()) th;
  fun INST_ro_tm tm = if ro_kept () then tm else subst (ro_to_F ()) tm;
end;
(* Functions to manipulate the current closure environment *)
(* Read the current ml_prog state without modifying it: ml_prog_update is
   given an identity update that records the state it is applied to. *)
fun get_curr_prog_state () =
  let
    val seen = ref init_state
    fun record s = (seen := s; s)
  in
    ml_prog_update record;
    !seen
  end;
(*
 Local environment.
 With local_state_init_H set: bump num_local_environment_vars to n and return
 a fresh variable of type :v sem_env named <local_environment_var_name><n>.
 Otherwise: the environment of the current ml_prog state.
*)
fun get_curr_env () =
  if not (!(#local_state_init_H translator_state)) then
    get_env (get_curr_prog_state ())
  else
    let
      val counter = #num_local_environment_vars translator_state
      val _ = counter := !counter + 1
      val var_name =
        !(#local_environment_var_name translator_state) ^ int_to_string (!counter)
    in
      mk_var (var_name, venvironment)
    end;
(* mk_PURE x builds PURE x, with PURE's β type parameter set to the configured
   refs_type (going through the fresh gen1 to match x's type first). *)
fun mk_PURE x =
  let
    val pure_tm = Term.inst [b_ty |-> gen1] PURE_const
    val applied = mk_icomb (pure_tm, x)
  in
    Term.inst [gen1 |-> !(#refs_type translator_state)] applied
  end;
(*
 Retrieve the parameters given to Eval or EvalM:
   Eval env exp (P x)
   EvalM ro env st exp (MONAD A B x) H
 get_Eval_arg returns x, get_Eval_env returns env, get_Eval_exp returns exp.
 Any term whose head is not the Eval constant is treated as an EvalM term.
*)
local
  fun is_Eval tm = same_const (fst (strip_comb tm)) Eval_const
in
  fun get_Eval_arg e =
    if is_Eval e then rand (rand e) else rand (rand (rator e));
  fun get_Eval_env e =
    if is_Eval e then rand (rator (rator e))
    else rand (rator (rator (rator (rator e))));
  fun get_Eval_exp e =
    if is_Eval e then rand (rator e) else rand (rator (rator e));
end;
(* get_EvalM_state (EvalM ro env st exp P H) = st *)
fun get_EvalM_state tm = rand (rator (rator (rator tm)));
(*
 Replace the code carried by nsLookup hypotheses of th, i.e.
   nsLookup env1.v (Short name1) = SOME (Closure env2 name2 exp)
   nsLookup env1.v (Short name1) = SOME (Recclosure env2 exps name2)
 by freshly defined abbreviation constants ____<name>_code____. The new
 definitions are prepended to local_code_abbrevs in translator_state and are
 used (via GSYM) to rewrite th with HYP_CONV_RULE.
*)
fun abbrev_nsLookup_code th = let
(* Prevent the introduction of abbreviations for already abbreviated code *)
val th = HYP_CONV_RULE (fn x => true)
(PURE_REWRITE_CONV
(List.map GSYM (!(#local_code_abbrevs translator_state)))) th
val pat1 = nsLookup_closure_pat
val pat2 = nsLookup_recclosure_pat
(* does the hypothesis look up a Closure or a Recclosure? *)
fun can_match_pat tm = can (match_term pat1) tm orelse
can (match_term pat2) tm
val lookup_assums = List.filter can_match_pat (hyp th)
(* the ML string name1 from nsLookup _ (Short name1) on the left-hand side *)
val get_fun_name =
stringSyntax.fromHOLstring o rand o rand o rand o rator
(* Closure: the exp argument; Recclosure: the exps argument *)
fun get_code tm =
if can (match_term pat1) tm then (rand o rand o rand) tm
else (rand o rator o rand o rand) tm
(* code that is already a constant (e.g. an earlier abbreviation) is skipped *)
val name_code_pairs =
List.map(fn x => (get_fun_name x, get_code x)) lookup_assums |>
List.filter (not o is_const o snd)
(* define a temporary constant ____<name>_code____ equal to the code term
   NOTE(review): two different code terms under the same ML name would reuse
   the same constant name — confirm new_definition handles that as intended *)
fun find_abbrev (name, code) = let
val n = Theory.temp_binding ("____" ^ name ^ "_code____")
val code_def = Definition.new_definition(n,mk_eq(mk_var(n,type_of code),code))
in code_def end
val abbrevs = List.map find_abbrev name_code_pairs
in
(#local_code_abbrevs translator_state) :=
List.concat [abbrevs, !(#local_code_abbrevs translator_state)];
(* fold each code term into its new abbreviation constant *)
HYP_CONV_RULE (fn x => true) (PURE_REWRITE_CONV (List.map GSYM abbrevs)) th
end;
(*
 Find the stored specification theorem for the HOL constant tm among
 translator_state's dynamic_v_thms and turn it into an Eval theorem for the
 ML variable it is bound to. Eval_Var_Short is applied, and the result is
 specialised to that ML name and the env variable, with its antecedents
 undischarged. Fails (via `first`) if no stored theorem is about tm.
 The unused local `v_name` of the previous version has been removed.
*)
fun lookup_dynamic_v_thm tm =
  let
    val candidates = Net.match tm (!(#dynamic_v_thms translator_state))
    (* Net.match only narrows the search; keep the entry whose theorem is
       really about the constant tm *)
    fun is_about_tm (_, _, _, v_th, _, _) =
      same_const tm (rand (rator (concl v_th)))
    val (_, ml_name, _, v_th, _, _) = first is_about_tm candidates
    val eval_th = MATCH_MP Eval_Var_Short v_th
  in
    eval_th
    |> SPECL [stringSyntax.fromMLstring ml_name, v_env]
    |> UNDISCH_ALL
  end;
(******************************************************************************
Get refinement invariants from monad and arrow types.
******************************************************************************)
(*
 Refinement invariant for a monad type:
   MONAD RI EXN_TYPE
 RI is the plain invariant of the monad's result type, EXN_TYPE comes from
 translator_state, and the state type is instantiated to refs_type.
*)
fun get_m_type_inv ty =
  let
    val (_, ret_ty, _) = dest_monad_type ty
    val ret_inv = get_type_inv ret_ty
    val monad_gen = Term.inst [c_ty |-> gen1] MONAD_const
    val monad_app =
      my_list_mk_comb (monad_gen, [ret_inv, !(#EXN_TYPE translator_state)])
  in
    Term.inst [gen1 |-> !(#refs_type translator_state)] monad_app
  end;
(* Refinement invariant for a curried function type ending in a monad.
   A monadic type goes straight to get_m_type_inv. For an arrow type:
   - the domain invariant is computed recursively, and a HOL_ERR falls
     back to mk_PURE of the plain type invariant;
   - the range invariant is computed recursively with no fallback.
   The two are combined through ArrowM_ro_tm together with the
   translator's H. Raises HOL_ERR when ty is neither monadic nor an arrow. *)
fun get_arrow_type_inv ty =
  if can dest_monad_type ty then get_m_type_inv ty
  else let
    val (arg_ty, res_ty) = dom_rng ty
    val arg_inv = get_arrow_type_inv arg_ty
                  handle HOL_ERR _ => mk_PURE (get_type_inv arg_ty)
    val res_inv = get_arrow_type_inv res_ty
    val h = !(#H translator_state)
  in
    my_list_mk_comb (ArrowM_ro_tm, [h, arg_inv, res_inv])
  end;
(* Refinement invariant for ty, preferring the arrow-invariant form.
   For a non-monadic type that admits an arrow invariant
   (get_arrow_type_inv succeeds), ArrowM_def is unfolded once and the rand
   of the resulting right-hand side is returned. In every other case the
   plain get_type_inv ty is used. *)
fun smart_get_type_inv ty =
  (* total evaluates get_arrow_type_inv a single time, instead of first
     testing it with can and then recomputing it. *)
  case (if can dest_monad_type ty then NONE
        else total get_arrow_type_inv ty) of
      SOME arrow_inv =>
        arrow_inv |> ONCE_REWRITE_CONV [ArrowM_def] |> concl |> rand |> rand
    | NONE => get_type_inv ty;
(* Look through nested ArrowP invariants for an EqSt argument.
   Terms matching ArrowP_PURE_pat are skipped by recursing into their
   rand of rand. On a match of ArrowP_EqSt_pat, SOME of the subterm at
   rator/rand/rand is returned (presumably the EqSt state variable).
   Anything else yields NONE. *)
fun get_EqSt_var tm = let
  fun matches pat = can (match_term pat) tm
in
  if matches ArrowP_PURE_pat then get_EqSt_var (tm |> rand |> rand)
  else if matches ArrowP_EqSt_pat then SOME (tm |> rator |> rand |> rand)
  else NONE
end
(******************************************************************************
Exceptions - prove specifications for exception raise and handle.
******************************************************************************)
local
(* Prove the specifications for the exception raise.
   exn_ri_def is the definition of the exception refinement invariant and
   EXN_RI_tm its constant. The triple gives:
   - a raise function definition,
   - the exception constructor it raises,
   - that constructor's stamp.
   EvalM_raise is instantiated and its side conditions are discharged. The
   result is generalised over its free variables and over the constructor
   arguments paired with their expression variables. It is then saved
   under EvalM_ followed by the raise function's constant name, and
   returned. *)
fun prove_raise_spec exn_ri_def EXN_RI_tm (raise_fun_def, cons, stamp) = let
  val fun_tm = concl raise_fun_def |> strip_forall |>
               snd |> lhs |> strip_comb |> fst
  val exn_param_types = (fst o ml_monadBaseLib.dest_fun_type o type_of) cons
  val refin_invs = List.map smart_get_type_inv exn_param_types
  val raise_fun = raise_fun_def |> concl |> strip_forall |> snd |> lhs
  (* E is the exception value built by the raise function: the rand of the
     first pair component in the body of the right-hand side abstraction *)
  val E = raise_fun_def |> concl |> strip_forall |> snd |> rhs |> dest_abs
          |> snd |> dest_pair |> fst |> rand
  val cons_params = strip_comb E |> snd
  val ri_type = mk_type("fun", [v_ty, bool_ty])
  (* One condition per constructor argument: its invariant applied to it *)
  val EVAL_CONDS = List.map mk_comb (zip refin_invs cons_params)
  val EVAL_CONDS = listSyntax.mk_list (EVAL_CONDS, ri_type)
  val arity = List.length cons_params
  val arity_tm = arity |> numSyntax.term_of_int
  val exprs_vars = ml_monadBaseLib.mk_list_vars_same "_expr" exp_ty arity
  val exprs = listSyntax.mk_list (exprs_vars, exp_ty)
  (* Instantiate the raise specification *)
  val cv = mk_var (mk_cons_name cons, astSyntax.str_id_ty)
  val raise_spec = ISPECL [cv, stamp, EXN_RI_tm, EVAL_CONDS,
                           arity_tm, E, exprs, raise_fun] EvalM_raise
  val free_vars = strip_forall (concl raise_spec) |> fst
  val raise_spec = SPEC_ALL raise_spec
  (* Prove the assumptions *)
  (* Exception RI assumption *)
  val take_assumption = fst o dest_imp o concl
  val exn_ri_assum = take_assumption raise_spec
  val num_eq_tm = mk_eq(mk_var("n",num_ty), mk_var("m",num_ty))
  (* Case split on the list variable found under the first numeric equation
     among the assumptions (presumably a LENGTH equation), then simplify.
     When no such equation exists, fail with a descriptive HOL_ERR; the
     previous handle Match could never fire, because the lookup failure
     surfaced as a generic HOL_ERR raised by the. *)
  fun case_on_values (asl,w) = let
    val eqn =
      case List.find (can (match_term num_eq_tm)) asl of
          SOME eqn => eqn
        | NONE => raise (ERR "prove_raise_spec" "case_on_values failed")
    (* Strip num-typed applications to reach the underlying list variable *)
    fun get_values_var a =
      if (type_of a) = num_ty
      then get_values_var (rand a)
      else a
    val values_var = get_values_var (lhs eqn)
  in (Cases_on `^values_var` >> FULL_SIMP_TAC list_ss []) (asl,w) end
  val prove_exn_ri_assum =
    rpt strip_tac
    \\ ntac arity case_on_values
    \\ TRY case_on_values
    \\ FULL_SIMP_TAC list_ss [LIST_CONJ_def]
    \\ SIMP_TAC list_ss [exn_ri_def]
    \\ instantiate
  val exn_ri_lemma = prove(exn_ri_assum, prove_exn_ri_assum)
  val raise_spec = MP raise_spec exn_ri_lemma
  (* Trivial assumptions: each is closed by simplification with the raise
     function definition *)
  fun prove_assumption th = let
    val a = take_assumption th
    val lemma = SIMP_CONV list_ss [raise_fun_def] a |> EQT_ELIM
  in MP th lemma end
  val raise_spec = prove_assumption raise_spec
                   |> prove_assumption
                   |> prove_assumption
  (* Rewrite, generalize *)
  val raise_spec = SIMP_RULE list_ss [LIST_CONJ_def,MAP,ZIP,
                                      GSYM AND_IMP_INTRO] raise_spec
  val raise_spec = GENL free_vars raise_spec
  fun GEN_pair ((x,y),th) = GENL [x,y] th
  val raise_spec =
    List.foldr GEN_pair raise_spec (zip cons_params exprs_vars)
  val thm_name = "EvalM_" ^(dest_const fun_tm |> fst)
  val raise_spec = save_thm(thm_name, raise_spec)
  val _ = print ("Saved theorem __ \"" ^thm_name ^"\"\n")
in raise_spec end
(* Prove the specifications for the exception handle.
   exn_ri_def is the definition of the exception refinement invariant and
   EXN_RI_tm its constant. The triple gives:
   - the handle function definition,
   - the exception constructor it handles,
   - that constructor's stamp.
   Steps:
   1. Re-express the handle definition in the shape expected by
      EvalM_handle.
   2. Instantiate EvalM_handle and discharge its side conditions one by
      one.
   3. Eliminate E and paramsv from the remaining assumptions.
   4. Check the result, generalise it over free_vars, save it under
      EvalM_ followed by the handle function's constant name, and return
      it. *)
fun prove_handle_spec exn_ri_def EXN_RI_tm (handle_fun_def, cons, stamp) = let
(* Rename the variables in handle_fun_def: its universally quantified
   variables are re-quantified with fresh names built by mk_list_vars from
   the prefix x *)
val handle_fun_def = let
val vars = concl handle_fun_def |> strip_forall |> fst
val types = List.map type_of vars
val new_vars = ml_monadBaseLib.mk_list_vars "x" types
in GENL new_vars (SPECL new_vars handle_fun_def) end
val handle_fun = concl handle_fun_def |> strip_forall |> snd |> lhs
(* The exception type: first type argument of the type of EXN_RI_tm *)
val exn_type = type_of EXN_RI_tm |> dest_type |> snd |> List.hd
(* NOTE(review): cons_name is never used below *)
val cons_name = fst (dest_const cons) |> stringSyntax.fromMLstring
(* Instantiate the EvalM specification *)
(* CORRECT_CONS: an abstraction over e of type exn_type whose body is a
   case split on e, returning T in the branch for cons and F in the branch
   of every other constructor *)
val CORRECT_CONS = let
val case_tm = TypeBase.case_const_of exn_type
|> Term.inst [alpha |-> bool_ty]
val all_cons = TypeBase.constructors_of exn_type
fun mk_case_bool_fun c = let
val types = (fst o ml_monadBaseLib.dest_fun_type o type_of) c
val vars = ml_monadBaseLib.mk_list_vars "e" types
val bool_tm = if same_const c cons then T else F
val body = list_mk_abs (vars, bool_tm)
in body end
val case_funs = List.map mk_case_bool_fun all_cons
val e_var = mk_var("e", exn_type)
val case_tm = list_mk_comb (case_tm, e_var::case_funs)
in mk_abs(e_var, case_tm) end
(* Argument types, arity and fresh argument variables of cons *)
val params_types = (fst o ml_monadBaseLib.dest_fun_type o type_of) cons
val arity = List.length params_types
val arity_tm = arity |> numSyntax.term_of_int
val params_vars = ml_monadBaseLib.mk_list_vars "_x" params_types
(* PARAMS_CONDITIONS: an abstraction over _E and _paramsv whose body is a
   case split on _E. The branch for cons combines, via list_mk mk_conj with
   base T:
   - LENGTH _paramsv = arity;
   - for each argument i, its type invariant applied to the argument and
     to EL i _paramsv.
   Every other branch is F. *)
val PARAMS_CONDITIONS = let
val case_tm = TypeBase.case_const_of exn_type
|> Term.inst [alpha |-> bool_ty]
val all_cons = TypeBase.constructors_of exn_type
val paramsv_var = mk_var("_paramsv", v_list_ty)
val LENGTH_cond = Term.inst [alpha |-> v_ty] LENGTH_const
val LENGTH_cond = mk_eq(mk_comb(LENGTH_cond, paramsv_var),arity_tm)
(* The numerals 0, 1, ..., n-1 as HOL terms *)
fun mk_indices n = let
fun mk_aux i =
if i = n then []
else (numSyntax.term_of_int i)::mk_aux(i+1)
in mk_aux 0 end
val indices = mk_indices arity
val EL_tm = Term.inst [alpha |-> v_ty] EL_const
fun mk_type_condition (var, i) = let
val TYPE = get_type_inv (type_of var)
val getter = list_mk_comb(EL_tm, [i, paramsv_var])
val tm = list_mk_comb(TYPE, [var,getter])
in tm end
val type_conditions =
List.map mk_type_condition (zip params_vars indices)
val all_conditions = LENGTH_cond::type_conditions
val conditions = list_mk mk_conj all_conditions T
val conditions_fun = list_mk_abs (params_vars, conditions)
(* NOTE(review): vars below is unused; the non-cons branches abstract over
   fresh _x variables instead *)
fun mk_case_bool_fun c = let
val types = (fst o ml_monadBaseLib.dest_fun_type o type_of) c
val vars = ml_monadBaseLib.mk_list_vars "e" types
val body =
if same_const c cons then conditions_fun
else list_mk_abs(ml_monadBaseLib.mk_list_vars "_x" types, F)
in body end
val case_funs = List.map mk_case_bool_fun all_cons
val e_var = mk_var("_E", exn_type)
val case_tm = list_mk_comb (case_tm, e_var::case_funs)
in list_mk_abs([e_var, paramsv_var], case_tm) end
(* We need to rewrite the handle function in a different
manner before instantiating the theorem *)
(* alt_handle_fun abstracts the handle body over x1, a fresh variable x2
   and the initial state. Inside the case on the result, the exception
   branch is replaced by x2 applied to the exception variable and the
   state.
   alt_x1 is the first argument of handle_fun; alt_x2 is the original
   exception branch abstracted over the exception and state variables. *)
val (alt_handle_fun, alt_x1, alt_x2) = let
val x1_var = strip_comb handle_fun |> snd |> List.hd
val handle_fun_def_rhs = concl handle_fun_def |> strip_forall
|> snd |> rhs
val (state_var1, case_tm1) = dest_abs handle_fun_def_rhs
val case_tm0 = dest_comb case_tm1 |> fst
val (res_var_state_var2, case_tm2) = dest_comb case_tm1
|> snd |> strip_abs
val (res_var, state_var2) = (el 1 res_var_state_var2,
el 2 res_var_state_var2)
val case_tm3 = rator case_tm2
val alt_x2_tm = rand case_tm2
val (e_var,alt_x2_tm) = dest_abs alt_x2_tm
val alt_x2_tm = list_mk_abs([e_var,state_var2], alt_x2_tm)
val alt_x2_type = type_of alt_x2_tm
val alt_x2_var = mk_var("x2", alt_x2_type)
val alt_handle_fun = list_mk_comb(alt_x2_var, [e_var,state_var2])
val alt_handle_fun = mk_abs(e_var,alt_handle_fun)
val alt_handle_fun = mk_comb(case_tm3, alt_handle_fun)
val alt_handle_fun =
list_mk_abs([res_var, state_var2], alt_handle_fun)
val alt_handle_fun = mk_comb(case_tm0,alt_handle_fun)
val alt_handle_fun =
list_mk_abs([x1_var, alt_x2_var, state_var1],alt_handle_fun)
in (alt_handle_fun, x1_var, alt_x2_tm) end
(* We also need to rewrite a2 in a different manner *)
(* a2_alt: an abstraction over st and E whose body is a case split on E.
   The branch for cons applies a fresh variable a2 to st and the
   constructor arguments; every other branch is F. *)
val a2_alt = let
val state_type = !(#refs_type translator_state)
val a2_type =
ml_monadBaseLib.mk_fun_type(state_type::params_types, bool_ty)
val a2_var = mk_var("a2", a2_type)
val E_var = mk_var("E", exn_type)
val case_tm = mk_comb(TypeBase.case_const_of exn_type, E_var)
|> Term.inst [alpha |-> bool_ty]
val consl = TypeBase.constructors_of exn_type
val state_var = mk_var("st", state_type)
fun mk_condition c = let
val types = fst(ml_monadBaseLib.dest_fun_type (type_of c))
val vars = ml_monadBaseLib.mk_list_vars "e" types
in if same_const c cons
then list_mk_abs(vars, list_mk_comb(a2_var, state_var::vars))
else list_mk_abs(vars, F)
end
val conditions = List.map mk_condition consl
val a2_alt = list_mk_comb(case_tm, conditions)
val a2_alt = list_mk_abs([state_var, E_var], a2_alt)
in a2_alt end
(* Instantiate the specification *)
val cv = mk_var (mk_cons_name cons, astSyntax.str_id_ty)
val handle_spec = ISPECL ([cv, stamp, CORRECT_CONS, PARAMS_CONDITIONS,
EXN_RI_tm, alt_handle_fun, alt_x1, alt_x2,
arity_tm])
EvalM_handle
(* The first remaining quantified variable is instantiated with a2_alt,
   after matching the type of a2_alt against it. All other remaining
   variables are specialised and kept in free_vars for the final
   generalisation. *)
val ty_subst = match_type
(type_of a2_alt)
(type_of (handle_spec |> concl |> dest_forall |> fst))
val a2_alt' = Term.inst ty_subst a2_alt
val handle_spec = handle_spec |> SPEC a2_alt'
val free_vars = concl handle_spec |> strip_forall |> fst
val handle_spec = SPECL free_vars handle_spec
(* Prove the assumptions *)
val take_assumption = fst o dest_imp o concl
(* branch handle *)
val branch_handle_assum = take_assumption handle_spec
val case_thms = [
TypeBase.case_def_of exn_type,
TypeBase.case_def_of pair_ty,
TypeBase.case_def_of exc_ty
]
(* set_goal ([], branch_handle_assum) *)
val prove_branch_handle_assum =
rpt strip_tac >> FULL_SIMP_TAC bool_ss case_thms
val branch_handle_lemma = prove(branch_handle_assum,
prove_branch_handle_assum)
val handle_spec = MP handle_spec branch_handle_lemma
(* branch let through *)
val branch_let_through_assum = take_assumption handle_spec
(* set_goal ([], branch_let_through_assum) *)
val branch_let_through_lemma = prove(branch_let_through_assum,
rpt strip_tac
\\ FULL_SIMP_TAC bool_ss case_thms
\\ rpt (BasicProvers.PURE_CASE_TAC \\ fs[]))
val handle_spec = MP handle_spec branch_let_through_lemma
(* refinement invariant equality *)
val ref_inv_eq_assum = take_assumption handle_spec
(* set_goal ([], ref_inv_eq_assum) *)
val ref_inv_eq_lemma = prove(ref_inv_eq_assum,
rpt strip_tac
\\ FULL_SIMP_TAC bool_ss case_thms
\\ rpt (BasicProvers.PURE_CASE_TAC \\ fs[exn_ri_def]))
val handle_spec = MP handle_spec ref_inv_eq_lemma
(* refinement invariant inequality *)
val ref_inv_ineq_assum = take_assumption handle_spec
(* set_goal ([], ref_inv_ineq_assum) *)
val ref_inv_ineq_lemma = prove(ref_inv_ineq_assum,
BasicProvers.Cases
\\ rpt strip_tac
\\ fs[exn_ri_def])
val handle_spec = MP handle_spec ref_inv_ineq_lemma
(* Rewrite - replace the variables E and paramsv,
remove the case expressions making a disjunction on the type of the
constructor of E *)
val remove_imp = snd o dest_imp
(* EvalM_assum is the fourth remaining antecedent of handle_spec. Its
   second conjunct is universally quantified over st, E and paramsv. *)
val EvalM_assum =
concl handle_spec |> remove_imp |> remove_imp
|> remove_imp |> dest_imp |> fst
(* Eliminate E and paramsv in the assumptions *)
val EvalM_assum2 = dest_conj EvalM_assum |> snd
val (vars, EvalM_assum2_body) = strip_forall EvalM_assum2
val (st_var,E_var,paramsv_var) = (el 1 vars,el 2 vars,el 3 vars)
(* Equivalent form of EvalM_assum2: additionally quantify over fresh e and
   v variables and assume that E is cons applied to the e variables and
   that paramsv is the explicit list of the v variables *)
val v_vars = ml_monadBaseLib.mk_list_vars_same "v" v_ty arity
val E_params_vars = ml_monadBaseLib.mk_list_vars "e" params_types
val params_v_eq = mk_eq(paramsv_var,
listSyntax.mk_list (v_vars, v_ty))
val E_eq = mk_eq(E_var, list_mk_comb(cons, E_params_vars))
val EvalM_assum2_alt = list_mk_imp ([E_eq, params_v_eq],
EvalM_assum2_body)
val EvalM_assum2_alt = list_mk_forall (vars@E_params_vars@v_vars,
EvalM_assum2_alt)
val EvalM_assum2_eq = mk_eq(EvalM_assum2, EvalM_assum2_alt)
(* Quoted variables used to specialise the assumption in the proof below *)
val aquoted_vars = [`^st_var`, `^E_var`, `^paramsv_var`]
val num_eq_pat = mk_eq(mk_var("n",num_ty),mk_var("m",num_ty))
(* Tactic: for an assumption that is a num equation n = m, case split on
   the innermost rand of its left-hand side; the assumption is selected by
   last_assum *)
fun cases_on_paramsv a =
if can (match_term num_eq_pat) (concl a) then let
fun get_last_rand x =
if can dest_comb x then get_last_rand (rand x) else x
val paramsv_var = concl a |> lhs |> get_last_rand
in Cases_on `^paramsv_var` end
else failwith "cases_on_paramsv"
val cases_on_paramsv = last_assum cases_on_paramsv
(* set_goal ([], EvalM_assum2_eq) *)
val EvalM_assum2_eq_lemma = prove(EvalM_assum2_eq,
EQ_TAC \\ rpt strip_tac
>-(last_x_assum(qspecl_then aquoted_vars assume_tac)
\\ fs[] \\ rw[] \\ fs[])
\\ last_x_assum(qspecl_then aquoted_vars assume_tac)
\\ Cases_on `^E_var` \\ fs[]
(* Cases on paramsv *)
\\ ntac (arity+1) (cases_on_paramsv \\ fs[]));
val handle_spec = PURE_REWRITE_RULE[EvalM_assum2_eq_lemma] handle_spec
(* Eliminate E and paramsv in the generated precondition *)
(* pre is the fifth remaining antecedent: quantified over st and E, with
   a body that is the conjunction of pre1 and pre2. pre2 is restated with
   E replaced by cons applied to the fresh e variables, quantifying over
   st and those variables instead. *)
val pre =
concl handle_spec |> remove_imp |> remove_imp
|> remove_imp |> remove_imp |> dest_imp |> fst
val (vars, pre_body) = strip_forall pre
val (st_var, E_var) = (el 1 vars, el 2 vars)
val (pre1,pre2) = dest_conj pre_body
val E_with_params = list_mk_comb(cons, E_params_vars)
val pre2_alt = Term.subst [E_var |-> E_with_params] pre2
val pre_alt = list_mk_forall(st_var::E_params_vars,
mk_conj(pre1,pre2_alt))
val pre_eq = mk_eq(pre,pre_alt)
(* set_goal ([], pre_eq) *)
(* Specialise a universally quantified theorem to a fresh genvar and add
   the instance as an assumption *)
fun random_gen_tac th = let
val (x,_) = dest_forall(concl th)
val y = genvar (type_of x)
in ASSUME_TAC(SPEC y th) end
val pre_eq_lemma = prove(pre_eq,
EQ_TAC
>-(rpt strip_tac >> fs[CONTAINER_def]
\\ `^pre1` by (rpt (first_x_assum random_gen_tac \\ fs[]) \\ fs[])
\\ fs[])
\\ rpt strip_tac
\\ TRY(Cases_on `^E_var`)
\\ fs[CONTAINER_def])
val handle_spec = PURE_REWRITE_RULE[pre_eq_lemma] handle_spec
(* Rewrite *)
val handle_spec =
SIMP_RULE list_ss [GSYM handle_fun_def,
TypeBase.case_def_of exn_type,
GSYM AND_IMP_INTRO] handle_spec
(* Check *)
(* The function term extracted from the undischarged conclusion must
   agree with handle_fun under ~~, otherwise the spec is rejected *)
val f = UNDISCH_ALL handle_spec |> concl |> rator |> rand |> rand
val _ = if f ~~ handle_fun then () else raise (ERR "prove_handle_spec"
"Error : the generated spec does not have the proper form")
(* Generalize *)
val handle_spec = GENL free_vars handle_spec
(* Save the theorem *)
val fun_tm = strip_comb handle_fun |> fst
val thm_name = "EvalM_" ^(dest_const fun_tm |> fst)
val handle_spec = save_thm(thm_name, handle_spec)
val _ = print ("Saved theorem __ \"" ^thm_name ^"\"\n")
in handle_spec end;
in
(* Prove and register the EvalM specifications for raising and handling
   exceptions.
   Arguments:
   - exn_functions: a list of (raise function definition, handle function
     definition) pairs;
   - exn_ri_def: the definition of the exception refinement invariant, one
     conjunct per constructor.
   Constructors and stamps are read off exn_ri_def, and each raise/handle
   definition is linked to its constructor. The specifications are proved
   (and saved) by prove_raise_spec / prove_handle_spec, and their patterns
   are prepended to the translator state's exn_raises and exn_handles.
   Returns the raise and handle theorems zipped pairwise. *)
fun add_raise_handle_functions exn_functions exn_ri_def = let
  val (raise_functions, handle_functions) = unzip exn_functions
  (* Extract information from the exception refinement invariant *)
  val exn_ri_cases = CONJUNCTS exn_ri_def
  val EXN_RI_tm =
    List.hd exn_ri_cases |> concl |> strip_forall |>
    snd |> lhs |> rator |> rator
  val exn_ri_pats =
    List.map (rand o rator o lhs o snd o strip_forall o concl) exn_ri_cases
  val exn_ri_cons = List.map (fst o strip_comb) exn_ri_pats
  val exn_ri_params_types =
    List.map (fst o ml_monadBaseLib.dest_fun_type o type_of) exn_ri_cons
  (* Register constructor argument types that have no invariant yet;
     mapfilter drops registrations that fail *)
  fun safe_register_type ty = if can get_type_inv ty then ()
                              else register_type ty
  val _ = mapfilter safe_register_type (List.concat exn_ri_params_types)
  val exn_ri_decomposed_rhs =
    List.map ((list_dest dest_conj) o snd o strip_exists o
              rhs o snd o strip_forall o concl) exn_ri_cases
  val exn_ri_deep_cons = List.map (rator o rand o hd) exn_ri_decomposed_rhs
  val exn_ri_stamps =
    List.map (rand o rand o rand) exn_ri_deep_cons
  val exn_info = zip exn_ri_cons exn_ri_stamps
  val exn_type = type_of EXN_RI_tm |> dest_type |> snd |> List.hd
  (* Pair the definition d with the (constructor, stamp) entry of exn_info
     whose constructor is tm; fails (via tryfind) when there is none *)
  fun attach_info (d, tm) =
    tryfind (fn (x1, x2) =>
               if same_const x1 tm then (d, x1, x2) else failwith "")
            exn_info
  (* Link the raise definitions with the appropriate information *)
  val raise_funct_pairs =
    List.map (fn x => (x, concl x |> strip_forall |> snd |> rhs
                          |> dest_abs |> snd |> dest_pair |> fst |> rand
                          |> strip_comb |> fst)) raise_functions
  val raise_info = List.map attach_info raise_funct_pairs
  (*
   * Prove the raise specifications
   *)
  val raise_specs =
    List.map (prove_raise_spec exn_ri_def EXN_RI_tm) raise_info
  (* Link the handle definitions with the appropriate information.
     The handled constructor is the one whose case branch in the handle
     definition is not a pair. If no branch qualifies, fail with a message
     naming this function; the previous handle Bind could never fire,
     because the failure surfaced as a generic HOL_ERR raised by the. *)
  fun get_handle_cons handle_fun_def = let
    val exn_cases = concl handle_fun_def |> strip_forall |> snd |> rhs |>
                    dest_abs |> snd |> rand |> strip_abs |> snd |> rand |>
                    dest_abs |> snd
    val cases_list = strip_comb exn_cases |> snd |> List.tl
    val cases_list = List.map (snd o strip_abs) cases_list
    val cases_cons = TypeBase.constructors_of exn_type
    val cases_pairs = zip cases_cons cases_list
    val handled_cons =
      case List.find (fn (_, y) => not (can dest_pair y)) cases_pairs of
          SOME p => p
        | NONE => failwith "get_handle_cons"
  in (handle_fun_def, fst handled_cons) end
  val handle_funct_pairs = List.map get_handle_cons handle_functions
  val handle_info = List.map attach_info handle_funct_pairs
  (*
   * Prove the handle specifications
   *)
  val handle_specs =
    List.map (prove_handle_spec exn_ri_def EXN_RI_tm) handle_info
  (* Store the proved theorems, keyed by the function term found in the
     conclusion of each specification *)
  fun extract_pattern th = let
    val pat = concl th |> strip_forall |> snd |> strip_imp
              |> snd |> rator |> rand |> rand
  in (pat, th) end
  val {exn_raises=e_raises, exn_handles=e_handles,...} = translator_state
in
  e_raises := ((List.map extract_pattern raise_specs) @ (!e_raises));
  e_handles := ((List.map extract_pattern handle_specs) @ (!e_handles));
  zip raise_specs handle_specs
end;
end; (* end local *)
(******************************************************************************
Translation initialisation
******************************************************************************)
fun compute_dynamic_refs_bindings all_access_specs = let
val store_vars = FVL [(!(#H translator_state)) |> dest_pair |> fst]
empty_varset;
fun get_dynamic_init_bindings spec = let
val spec = SPEC_ALL spec |> UNDISCH_ALL
val pat = nsLookup_val_pat
val lookup_assums =
List.filter (can (match_term pat)) (hyp (UNDISCH_ALL spec))
fun get_name_loc tm = ((rand o rand o rand o rator) tm, (rand o rand) tm)
val bindings = List.map get_name_loc lookup_assums |>
List.filter (fn (x, y) => HOLset.member(store_vars, y))
in bindings end
val all_bindings =
List.concat(List.map get_dynamic_init_bindings all_access_specs)
val bindings_map =
List.foldl (fn ((n, v), m) => Redblackmap.insert(m, v, n))
(Redblackmap.mkDict Term.compare) all_bindings
val store_varsl =
strip_comb ((!(#H translator_state)) |> dest_pair |> fst) |> snd
val store_varsl = store_varsl |>
filter (fn t => not (can (match_type ``:'a -> v -> bool``) (type_of t)))
val final_bindings =