structure closure :> closure =
struct

open HolKernel Parse boolLib pairLib PairRules simpLib boolSimps bossLib
     NormalTheory pairSyntax basic Normal;

val atom_tm = prim_mk_const{Name="atom",Thy="Normal"};
val fun_tm = prim_mk_const{Name="fun",Thy="Normal"};

(*---------------------------------------------------------------------------*)
(* Term-building helpers shared by the closure passes below                  *)
(*---------------------------------------------------------------------------*)

(* [mk_atom v] builds the term  atom v,  instantiating atom's type to v's.   *)
fun mk_atom v = mk_comb (inst [alpha |-> type_of v] atom_tm, v);

(* [mk_fun M] builds the term  fun M,  instantiating fun's type to M's.      *)
fun mk_fun M = mk_comb (inst [alpha |-> type_of M] fun_tm, M);

(* [map_fun_lets on_fun tm] rebuilds [tm], descending through (paired) lets,
   conditionals and (paired) abstractions.  Every  let v = M in N  whose
   bound value M is a (paired) abstraction is handed to
   [on_fun trav (v, M, N)], where [trav] is this traversal itself, so the
   callback decides how M and N are processed and how the let is rebuilt.
   Every other let is rebuilt as  let v = trav M in trav N.  Any term that is
   not a let, conditional or abstraction (e.g. an application) is returned
   unchanged and is not descended into. *)
fun map_fun_lets on_fun tm =
  let
    fun trav t =
      if is_let t then
        let val (v, M, N) = dest_plet t
        in
          if is_pabs M then on_fun trav (v, M, N)
          else mk_plet (v, trav M, trav N)
        end
      else if is_cond t then
        let val (J, M, N) = dest_cond t
        in mk_cond (J, trav M, trav N) end
      else if is_pabs t then
        let val (M, N) = dest_pabs t
        in mk_pabs (trav M, trav N) end
      else t
  in
    trav tm
  end;

(*---------------------------------------------------------------------------*)
(* Close the function variable by variable                                   *)
(* Consider one free variable at a time                                      *)
(*---------------------------------------------------------------------------*)

(* [abs_fvars tm] turns each local function definition  let v = M in N
   (M an abstraction) into  let v = fun M' in N',  where M' and N' are M and
   N processed recursively.  That whole let is then wrapped in one
   let x = atom x in ...  binding per free variable x of the original M.
   Because List.foldl is used, the last variable returned by free_vars
   becomes the outermost binding. *)
fun abs_fvars tm =
  let
    fun close_up t body =
      List.foldl (fn (x, acc) => mk_plet (x, mk_atom x, acc))
                 t (free_vars body)
  in
    map_fun_lets
      (fn trav => fn (v, M, N) =>
          close_up (mk_plet (v, mk_fun (trav M), trav N)) M)
      tm
  end;

(* [close_one_by_one def] takes the theorem produced by [abs_fun def] and
   rewrites its right-hand side into the abs_fvars form.  It proves the
   equation with LET_ATOM and fun_def, rewrites once with it, and then
   simplifies the right-hand side with CLOSE_ONE.  The lemmas and abs_fun
   come from the opened theories/structures; their statements are not
   visible here. *)
fun close_one_by_one def =
  let
    val th1 = abs_fun def
    val body = rhs (concl th1)
    val t1 = abs_fvars body
    val th2 = CONV_RULE (LHS_CONV (SIMP_CONV bool_ss [fun_def]))
          (GSYM (SIMP_CONV std_ss [LET_ATOM] t1))
    val th3 = ONCE_REWRITE_RULE [th2] th1          (* abs forms *)
    val th4 = CONV_RULE (RHS_CONV (SIMP_CONV bool_ss [CLOSE_ONE])) th3
  in
    th4
  end

(*---------------------------------------------------------------------------*)
(* Close the function over all of its free variables at once                 *)
(* Consider all free variables each time                                     *)
(*---------------------------------------------------------------------------*)

(* [identify_fun tm] marks every local function definition
   let v = M in N  (M an abstraction) as  let v = fun M' in N',  with M' and
   N' processed recursively.  Nothing else in the term is changed. *)
fun identify_fun tm =
  map_fun_lets
    (fn trav => fn (v, M, N) => mk_plet (v, mk_fun (trav M), trav N))
    tm;

(* [abs_all_fvars tm] closes each local function  let v = M in N  (M an
   abstraction) over the tuple cls of M's free variables.  The result is
       let v' = fun (\cls. M') in N'[v := v' cls]
   where M' and N' are M and N processed recursively, and v' is a variable
   with v's name and the type of the closed abstraction.
   NOTE(review): list_mk_pair appears to raise on an empty list, so a local
   function with no free variables makes this raise -- confirm whether such
   terms can reach this pass. *)
fun abs_all_fvars tm =
  map_fun_lets
    (fn trav => fn (v, M, N) =>
       let
         val cls = list_mk_pair (free_vars M)
         val f = mk_pabs (cls, trav M)
         (* v has function type (it is bound to an abstraction), so it is a
            variable; take its name directly rather than via the pretty-
            printer, whose output depends on settings such as show_types. *)
         val v' = mk_var (#1 (dest_var v), type_of f)
       in
         mk_plet (v', mk_fun f, subst_exp [v |-> mk_comb (v', cls)] (trav N))
       end)
    tm;

(* [close_all def] proves that the right-hand side of [abs_fun def] equals
   its abs_all_fvars form.  It chains three theorems with TRANS:
   - th3 relates the identify_fun form of the body to its LET_FUN-simplified,
     beta-reduced form;
   - th2 is the reversed INLINE_EXPAND simplification of the closed term t1;
   - the fun_def-simplified left-hand side of th3 o th2 is joined to th1. *)
fun close_all def =
  let
    val th1 = abs_fun def
    val body = rhs (concl th1)
    val t1 = abs_all_fvars body
    val th2 = GSYM (PBETA_RULE (SIMP_CONV pure_ss [INLINE_EXPAND] t1))
    val t2 = identify_fun body
    val th3 = PBETA_RULE (CONV_RULE (RHS_CONV (SIMP_CONV bool_ss [LET_FUN])) (REFL t2))
    val th4 = TRANS th3 th2
    val th5 = TRANS th1 (CONV_RULE (LHS_CONV (SIMP_CONV bool_ss [fun_def])) th4)
  in
    th5
  end

(*---------------------------------------------------------------------------*)
(*   Closure conversion                                                      *)
(*   Move all function definitions to the top level                          *)
(*---------------------------------------------------------------------------*)

(* Rewrites with TOP_LEVEL_LET and TOP_LEVEL_COND_1/2 to float local
   definitions outward.  As the original author notes, this may loop
   forever; it is to be improved. *)
val TOP_LEVEL_RULE =  (* may loop forever, to be improved *)
  SIMP_RULE pure_ss [TOP_LEVEL_LET, TOP_LEVEL_COND_1, TOP_LEVEL_COND_2];

(*---------------------------------------------------------------------------*)
(*   Closure conversion                                                      *)
(*   Interface                                                               *)
(*---------------------------------------------------------------------------*)

(* [closure_convert def] is the full closure-conversion pipeline:
   close_all, then TOP_LEVEL_RULE, then FLATTEN_LET simplification, and
   finally SSA_RULE (defined elsewhere). *)
fun closure_convert def =
  let
    val th1 = close_all def
    val th2 = TOP_LEVEL_RULE th1
    val th3 = SIMP_RULE pure_ss [FLATTEN_LET] th2
    val th4 = SSA_RULE th3
  in
    th4
  end

end (* struct *)
147
148