1structure compiler :> compiler =
2struct
3
4(* for interactive use
5fun load_path_add x = loadPath := !loadPath @ [concat Globals.HOLDIR x];
6val _ = load_path_add "/examples/mc-logic";
7val _ = load_path_add "/examples/ARM/v4";
8val _ = load_path_add "/tools/mlyacc/mlyacclib";
9(* load_path_add "/examples/dev/sw2"; *)
10
11val _ = quietdec := true;
12app load ["arm_compilerLib", "Normal", "inline", "closure",
13          "regAlloc", "funcCall", "refine"];
14val _ = quietdec := false;
15*)
16
17open HolKernel Parse boolLib bossLib boolSyntax;
18open compilerLib;
19
(* Presumably makes the parser prefer type :num for overloaded numerals and
   arithmetic operators -- see numLib.prefer_num. *)
val _ = numLib.prefer_num();

(* NOTE(review): Globals.priming controls how HOL builds variant variable
   names; setting it to NONE is assumed to select the default (priming)
   scheme -- confirm against the HOL4 Globals documentation. *)
val _ = Globals.priming := NONE;
23
24(*---------------------------------------------------------------------------*)
25(* Interface Functions.                                                      *)
26(*---------------------------------------------------------------------------*)
27
28(* val arm_compile = arm_compilerLib.arm_compile; *)
29
(* Short names for the individual front-end transformations, grouped by the
   library that implements them. *)
val normalize        = Normal.normalize            (* normalisation / SSA *)
and SSA_RULE         = Normal.SSA_RULE
and expand_anonymous = inline.expand_anonymous     (* inlining *)
and expand_named     = inline.expand_named
and optimize_norm    = inline.optimize_norm
and close_one_by_one = closure.close_one_by_one    (* closure conversion *)
and close_all        = closure.close_all
and closure_convert  = closure.closure_convert
and parallel_move    = regAlloc.parallel_move      (* register allocation *)
and reg_alloc        = regAlloc.reg_alloc;
40(*
41val printSAL = SALGen.printSAL;
42val certified_gen = SALGen.certified_gen;
43*)
44
45(*---------------------------------------------------------------------------*)
46(* Auxiliary functions.                                                      *)
47(*---------------------------------------------------------------------------*)
48
structure M = Binarymap;

(* resultOf v: extract the payload of a successful verdict (PASS x => x).
   A FAIL verdict is reported with an explicit HOL_ERR instead of the
   anonymous Match exception a non-exhaustive pattern would raise. *)
fun resultOf (PASS x) = x
  | resultOf (FAIL _) =
      raise mk_HOL_ERR "compiler" "resultOf" "compilation did not succeed";
52
(* head eqns: take the first conjunct of the theorem eqns, strip its
   universal quantifiers and return strip_comb of the equation's left-hand
   side, i.e. (function constant, argument list). *)
fun head eqns =
  let
    val first_clause = hd (strip_conj (concl eqns))
    val (_, eqn) = strip_forall first_clause
  in
    strip_comb (lhs eqn)
  end;
55
56(*---------------------------------------------------------------------------*)
57(* Compiling a list of functions.                                            *)
58(*---------------------------------------------------------------------------*)
59
(* defname th: name of the constant defined by the (possibly universally
   quantified) equation th. *)
fun defname th =
  let
    val (_, eqn) = strip_forall (concl th)
    val (f, _) = strip_comb (lhs eqn)
    val (name, _) = dest_const f
  in
    name
  end;

(* compenv comp (env, defs): apply comp to each definition in turn, passing
   the results produced so far (most recent first) as its environment.
   Progress is printed per definition.
   Returns PASS results (in input order) when every definition compiles, or
   FAIL (env, remaining) at the first failure, where env is the reversed
   list of results so far and remaining starts with the failing definition. *)
fun compenv comp =
  let
    fun step (acc, []) = PASS (rev acc)
      | step (acc, pending as (d :: rest)) =
          let
            val () = print ("Compiling " ^ defname d ^ " ... ")
          in
            case total comp (acc, d) of
              NONE => (print "failed.\n"; FAIL (acc, pending))
            | SOME compiled =>
                (print "succeeded.\n"; step (compiled :: acc, rest))
          end
  in
    step
  end;

(*---------------------------------------------------------------------------*)
(* complist comp defs: compile defs with comp, starting from an empty        *)
(* environment and accumulating results (see compenv).                       *)
(*---------------------------------------------------------------------------*)

fun complist comp defs = compenv comp ([], defs);
82
83(*---------------------------------------------------------------------------*)
84(* Conversion 1.                                                             *)
85(* Simplify to standard format.                                              *)
86(* Basic flattening via CPS and unique names                                 *)
87(*---------------------------------------------------------------------------*)
88
(* When true, convert1 applies refine.lift_cond before normalisation. *)
val cond_lift = ref true;

(* convert1 (env, def): Conversion 1 on a single definition.
   Pipeline: Normal.pre_process, then refine.lift_cond (only if !cond_lift),
   then normalize, refine.refine_const and finally Normal.SSA_RULE.
   The environment component is not used by this conversion. *)
fun convert1 (_, def) =
  let
    val prepared = Normal.pre_process def
    val lifted   = if !cond_lift then refine.lift_cond prepared else prepared
    val refined  = refine.refine_const (normalize lifted)
  in
    Normal.SSA_RULE refined
  end;
97
98(* All previous, plus inlining and optimizations                             *)
99
(* convert1a (env, def): Conversion 1 on def, then inline.optimize_norm
   against env (the previously compiled definitions), then SSA renaming. *)
fun convert1a (env,def) =
  let
    (* convert1 takes the whole (env, def) pair; passing def alone was a
       type error. *)
    val def1 = convert1 (env,def)
  in
    Normal.SSA_RULE (inline.optimize_norm env def1)
  end;
105
106(* All previous, and closure conversion.                                     *)
107
(* convert1b (env, def): convert1a, then closure conversion when it applies.
   If closure.closure_convert fails, the convert1a result is returned
   unchanged. Otherwise the closed form is re-optimised against env and put
   back into SSA form. *)
fun convert1b (env,def) =
  let
    val inlined = convert1a (env,def)
  in
    case total closure.closure_convert inlined of
      NONE => inlined
    | SOME closed => Normal.SSA_RULE (inline.optimize_norm env closed)
  end;
114
115(*---------------------------------------------------------------------------*)
116(* Conversion 2.                                                             *)
(* Conversion 1 (convert1), followed by register allocation.                 *)
118(*---------------------------------------------------------------------------*)
119
(* convert2 (env, def): register allocation (regAlloc.reg_alloc) applied to
   the result of Conversion 1. *)
fun convert2 (env,def) =
  regAlloc.reg_alloc (convert1 (env,def));
126
127(* Different convert2, in which some intermediate steps are skipped.         *)
128
(* convert2a (env, def): register allocation applied to the convert1a result
   (Conversion 1 plus inlining/optimisation) instead of plain convert1. *)
fun convert2a (env,def) =
  regAlloc.reg_alloc (convert1a (env,def));
134
135(*---------------------------------------------------------------------------*)
136(* Conversion 3.                                                             *)
(* Conversion 2, and refinement: caller-save-style function calls.          *)
138(*---------------------------------------------------------------------------*)
139
(* convert3 (env, def): Conversion 2 followed by funcCall.caller_save_call,
   which enforces caller-save-style function calls. *)
fun convert3 (env,def) =
  funcCall.caller_save_call (convert2 (env,def))
145
146(*---------------------------------------------------------------------------*)
147(* Compiling a list of source functions.                                     *)
148(*---------------------------------------------------------------------------*)
149
(* defname, compenv and complist: see their definitions under the banner
   "Compiling a list of functions" earlier in this structure. The passes
   below use those definitions directly; a verbatim second copy that only
   shadowed them is no longer repeated here. *)
170
171(*---------------------------------------------------------------------------*)
172(* Compilation phases in front end.                                          *)
173(*---------------------------------------------------------------------------*)
174
175(* Basic flattening via CPS and unique names                                 *)
176
(* pass1 defs: Conversion 1 over defs (flattening, unique names). *)
fun pass1 defs = complist convert1 defs

(* pass2 defs: Conversion 1 plus register allocation. *)
fun pass2 defs = complist convert2 defs

(* pass3 defs: pass2 plus caller-save function calls (convert3). *)
fun pass3 defs = complist convert3 defs
186
(* Possible further pass (none is defined in this file): all previous, and  *)
(*  1. enforce caller-save-style function calls (convert3/pass3 already     *)
(*     applies funcCall.caller_save_call)                                    *)
(*  2. tune the normal forms for the back-end                                *)
190
191
192(*---------------------------------------------------------------------------*)
193(* Replace function f with f' when f is called.                              *)
194(*---------------------------------------------------------------------------*)
195
(* Renaming table filled by f_compile_basic: maps a function constant that
   had to be redefined to its replacement constant (old |-> new).
   Keys are ordered by regAlloc.tvarOrder. *)
val renamed = ref (M.mkDict regAlloc.tvarOrder)

(* Substitution [old |-> new, ...] built from every entry currently in
   !renamed, listed in descending key order. *)
fun renamed_rules () =
  List.rev (List.map (fn (old, new) => old |-> new) (M.listItems (!renamed)));

(* Flag not referenced elsewhere in this structure. *)
val only_one = ref false;
205
206(*---------------------------------------------------------------------------*)
207(* Front end.                                                                *)
(* flag = true starts afresh: the renaming table !renamed is reset first.    *)
209(*---------------------------------------------------------------------------*)
210
(* f_compile_basic defs flag: run the front end (pass3) over defs, then
   re-state each compiled definition via a freshly defined constant whose
   name is the original name with a prime appended.
   - flag = true resets the renaming table !renamed before starting;
     flag = false keeps entries recorded by earlier calls.
   - Returns PASS [(def', ind'), ...] on success, or pass3's FAIL payload.
   NOTE(review): the partial `val PASS ...` binding and the catch-all
   `handle _` below are noted inline. *)
fun f_compile_basic defs flag =
 let
  val _ = if flag then (renamed := M.mkDict regAlloc.tvarOrder) else ()

  (* redefine def: def is expected to have the shape |- f = \args. body
     (dest_eq / pairSyntax.dest_pabs fail otherwise). Returns a pair
     (equation theorem, induction theorem or TRUTH as a placeholder). *)
  fun redefine def =
   let
     val (fname, fbody) = dest_eq (concl def)
     (* NOTE(review): body is unused *)
     val (args,body) = pairSyntax.dest_pabs fbody
     (* def1 : |- f args = body -- apply both sides to args, beta-reduce RHS *)
     val def1 = CONV_RULE (RHS_CONV pairLib.GEN_BETA_CONV) (AP_THM def args)

     val (cname, ctype) = dest_const fname
     val new_name = cname ^ "'"
     val cvar = mk_var(new_name, ctype)

     (* replace the constant f by a variable f' so the equation can be given
        to the definition package as the definition of new_name *)
     val vtm = subst [fname |-> cvar] (concl def1)
     (* rename the function name in function application *)
     (* (uses entries recorded by earlier redefinitions, see below) *)
     val vtm1 = subst (renamed_rules ()) vtm
     (* val _ = HOL_MESG "redefinition" *)
     (* NOTE(review): non-exhaustive binding -- raises Bind if
        std_apiDefine does not return PASS; tcs is unused. *)
     val PASS (defn2,tcs) = TotalDefn.std_apiDefine (new_name,vtm1)
     val def2 = LIST_CONJ (Defn.eqns_of defn2)
     val ind = Defn.ind_of defn2
     (* First try to prove f' = f. On success keep def1 (stated for f) and
        rewrite the induction theorem of f' into one about f.
        If anything in that attempt raises, keep f' as a separate constant:
        record f |-> f' in !renamed (so later redefinitions substitute f')
        and return f''s own equations and induction theorem.
        NOTE(review): `handle _` catches every exception from the proof
        attempt, presumably including interrupts. *)
     val (def', ind') =
        let val lem = prove(mk_eq(mk_const(new_name, ctype), fname),
                 REWRITE_TAC [FUN_EQ_THM] THEN
                 SIMP_TAC bool_ss [pairTheory.FORALL_PROD] THEN
                 REWRITE_TAC [Once def1, Once def2])
            val ind2 = case ind of
                   SOME th => ONCE_REWRITE_RULE [lem] th
                 | NONE => TRUTH
        in (def1, ind2) end
        handle _ =>
          let val newc = mk_const(new_name, ctype)
              val _ = renamed := M.insert(!renamed, fname, newc)
          in  (def2, case ind of
                         SOME th => th
                       | NONE => TRUTH)
          end
   in
     (def',ind')
   end;

   val result = pass3 defs
 in
   (* redefine only runs when every definition passed the front end *)
   case result of
        PASS defs' => PASS (List.map redefine defs')
     |  FAIL x => FAIL x
 end
258
(* f_compile_one def: run the front end on the single definition def,
   keeping the current renaming table, and return its (equation, induction)
   pair. Raises an exception if compilation fails. *)
fun f_compile_one def =
  let val compiled = resultOf (f_compile_basic [def] false)
  in hd compiled end;

(* f_compile defs: run the front end on defs starting afresh (the renaming
   table is reset first); returns the verdict from f_compile_basic. *)
fun f_compile defs = f_compile_basic defs true;
261
262(*---------------------------------------------------------------------------*)
263(* Join the front end with Magnus' backend                                   *)
264(*---------------------------------------------------------------------------*)
265(*
266val style = ref (InLineCode);
267
268fun b_compile_one (def, ind) =
269  let
270    val (th,strs) = arm_compilerLib.arm_compile (SPEC_ALL def) ind (!style)
271    val code = fetch "-" (#1 (dest_const (#1 (head def))) ^ "_code1_def")
272  in  (th, code)
273  end
274
275fun b_compile norms =
276  let val _ = arm_compilerLib.abbrev_code := true
277      val _ = arm_compilerLib.reset_compiled()
278      val _ = optimise_code := true;
279  in
280    case norms of
281         PASS defs => PASS (List.map b_compile_one defs)
282      |  FAIL x => FAIL x
283  end
284
285(*---------------------------------------------------------------------------*)
286(* End-to-end compiler.                                                      *)
287(*---------------------------------------------------------------------------*)
288
289val pp_compile_one = b_compile_one o f_compile_one;
290val pp_compile = b_compile o f_compile;
291*)
292
293(*---------------------------------------------------------------------------*)
294
295end
296