1structure compilerLib :> compilerLib =
2struct
3
4open HolKernel boolLib bossLib Parse;
5open decompilerLib;
6open codegenLib;
7open codegen_x86Lib;
8open reg_allocLib;
9
10open prog_armLib prog_ppcLib prog_x86Lib prog_x64Lib;
11open wordsTheory wordsLib addressTheory;
12open helperLib;
13open tailrecLib;
14
15
(* AUTO_ALPHA_CONV () builds a conversion that renames every bound variable
   of a term to "auto_<k>", where <k> is drawn from a counter private to the
   returned conversion. The counter is created once per call of
   AUTO_ALPHA_CONV and is never reset, so numbering continues across
   successive applications of the same returned conversion. Abstractions are
   renamed before their bodies are visited; in a combination the operator is
   processed before the operand. *)
fun AUTO_ALPHA_CONV () = let
  val next_index = ref Arbnum.zero
  (* return a variable named after the current index, then advance it *)
  fun fresh_var ty = let
    val n = !next_index
    val _ = next_index := Arbnum.+ (n, Arbnum.one)
    in mk_var ("auto_" ^ Arbnum.toString n, ty) end
  fun rename tm =
    if is_abs tm then let
      val (bound, _) = dest_abs tm
      val to_fresh = ALPHA_CONV (fresh_var (type_of bound))
      in (to_fresh THENC ABS_CONV rename) tm end
    else if is_comb tm then
      (RATOR_CONV rename THENC RAND_CONV rename) tm
    else
      ALL_CONV tm
  in rename end
28
(* A duplicated leading conjunct (resp. disjunct) can be dropped.
   Used as a rewrite by COMPILER_TAC. *)
val COMPILER_TAC_LEMMA = let
  val goal = ``!a b:bool. (a /\ a /\ b = a /\ b) /\ (a \/ a \/ b = a \/ b)``
  (* split the goal into its equations, then each equation into implications *)
  val split_tac = REPEAT STRIP_TAC THEN EQ_TAC
  val finish_tac = REPEAT STRIP_TAC THEN ASM_SIMP_TAC std_ss []
  in prove (goal, split_tac THEN finish_tac) end;
32
(* COMPILER_TAC: the tactic used in this file to discharge equivalence goals
   (see basic_compile, compile_all_aux and mc_compile). The stages below run
   strictly in the listed order. The same normalising simplification is run
   twice: once before and once after the bound-variable renaming. *)
val COMPILER_TAC = let
  (* shared normalisation stage: identities plus AC-rewriting of the word and
     boolean operators *)
  val normalise_tac =
    SIMP_TAC std_ss [WORD_SUB_RZERO, WORD_ADD_0, IF_IF,
                     AC WORD_ADD_COMM WORD_ADD_ASSOC,
                     AC WORD_MULT_COMM WORD_MULT_ASSOC,
                     AC WORD_AND_COMM WORD_AND_ASSOC,
                     AC WORD_OR_COMM WORD_OR_ASSOC,
                     AC WORD_XOR_COMM WORD_XOR_ASSOC,
                     AC CONJ_COMM CONJ_ASSOC,
                     AC DISJ_COMM DISJ_ASSOC,
                     IMP_DISJ_THM, WORD_MULT_CLAUSES]
  (* Each renaming conversion owns its own counter; both are created once,
     here, when COMPILER_TAC itself is defined. *)
  val rename_rand_tac = CONV_TAC (RAND_CONV (AUTO_ALPHA_CONV ()))
  val rename_rator_rand_tac =
    CONV_TAC ((RATOR_CONV o RAND_CONV) (AUTO_ALPHA_CONV ()))
  in
    EVERY
      [SIMP_TAC bool_ss [LET_DEF,word_div_def,word_mod_def,w2w_CLAUSES],
       SIMP_TAC std_ss [WORD_OR_CLAUSES,GUARD_def],
       REWRITE_TAC [WORD_CMP_NORMALISE],
       REWRITE_TAC [WORD_HIGHER,WORD_GREATER,WORD_HIGHER_EQ,
                    WORD_GREATER_EQ,GSYM WORD_NOT_LOWER,GSYM WORD_NOT_LESS],
       SIMP_TAC std_ss [WORD_MUL_LSL,word_mul_n2w,word_add_n2w,NOT_IF,
                        WORD_OR_CLAUSES],
       CONV_TAC (EVAL_ANY_MATCH_CONV word_patterns),
       normalise_tac,
       REPEAT STRIP_TAC,
       rename_rand_tac,
       rename_rator_rand_tac,
       SIMP_TAC std_ss [AC CONJ_ASSOC CONJ_COMM, AC DISJ_COMM DISJ_ASSOC,
                        COMPILER_TAC_LEMMA],
       normalise_tac,
       EQ_TAC,
       ONCE_REWRITE_TAC [GSYM DUMMY_EQ_def],
       REWRITE_TAC [FLATTEN_IF],
       REPEAT STRIP_TAC,
       ASM_SIMP_TAC std_ss []]
  end;
70
(* basic_compile target tm

   Compiles the single equation tm for the named target.
     target : one of "arm"/"ARM", "x86"/"i32"/"386", "x64"/"X64",
              "ppc"/"Power"/"PowerPC"
     tm     : an equation whose left-hand side is the function (a constant
              or a variable) applied to its input
   Steps: generate code with generate_code, decompile it with
   basic_decompile, prove with COMPILER_TAC that tm holds once the function
   symbol is replaced by the decompiled constant, and register the result
   with add_compiler_assignment.

   Returns (th1, th3, pre): th1 is the first theorem returned by
   basic_decompile, th3 is the proved instance of tm, and pre is the last
   conjunct of the second decompiler theorem (TRUTH if that theorem is not a
   conjunction), rewritten with IF_IF and WORD_TIMES2.

   Raises HOL_ERR for an unknown target, or when no result leaf of tm is a
   variable or a pair. *)
fun basic_compile target tm = let
  val (tools,target,model_name,s) =
    if mem target ["arm","ARM"] then (arm_tools,"arm","ARM_MODEL",[]) else
    if mem target ["x86","i32","386"] then (x86_tools,"x86","X86_MODEL",to_x86_regs ()) else
    if mem target ["x64","X64"] then (x64_tools,"x64","X64_MODEL",[]) else
    if mem target ["ppc","Power","PowerPC"] then (ppc_tools,"ppc","PPC_MODEL",[]) else
      raise mk_HOL_ERR "compilerLib" "basic_compile" ("unknown target: " ^ target)
  val x = fst (dest_eq tm)
  (* the function symbol is either a constant or a variable *)
  val name = fst (dest_const (repeat car x))
             handle HOL_ERR _ => fst (dest_var (repeat car x))
  val _ = echo 1 ("\nCompiling " ^ name ^ " into "^ target ^ "...\n")
  val (code,len) = generate_code target model_name true tm
  (* join the generated lines, terminating each one with a newline *)
  fun append' [] = "" | append' (x::xs) = x ^ "\n" ^ append' xs
  val in_tm = cdr x
  (* collect the result leaves that are variables or pairs *)
  fun ends (FUN_IF (_,c1,c2)) = ends c1 @ ends c2
    | ends (FUN_LET (_,_,c)) = ends c
    | ends (FUN_COND (_,c)) = ends c
    | ends (FUN_VAL t) = if is_var t orelse pairSyntax.is_pair t then [t] else []
  val out_tm2 =
    case ends (tm2ftree (cdr tm)) of
      t :: _ => t
    | [] => raise mk_HOL_ERR "compilerLib" "basic_compile"
                    ("no variable or pair result found in " ^ name)
  val (in_tm,out_tm) = (subst s in_tm, subst s out_tm2)
  val function_in_out = SOME (in_tm,out_tm)
  val qcode = ([QUOTE (append' code)]) : term Lib.frag list
  val (th1,th2) = basic_decompile tools name function_in_out qcode
  val _ = print "Proving equivalence, "
  (* restate tm in terms of the constant introduced by the decompiler *)
  val const = (repeat car o fst o dest_eq o concl o hd o CONJUNCTS) th2
  val tm = subst [ repeat car x |-> const ] tm
  val pre = if is_conj (concl th2) then (last o CONJUNCTS) th2 else TRUTH
  val pre = RW [IF_IF,WORD_TIMES2] pre
  val th3 = auto_prove "compile" (tm,
(*
  set_goal([],tm)
*)
    REPEAT STRIP_TAC
    THEN CONV_TAC (RATOR_CONV (ONCE_REWRITE_CONV [th2]))
    THEN COMPILER_TAC)
  val _ = add_compiler_assignment out_tm2 (fst (dest_eq tm)) name len model_name
  val _ = print "done.\n"
  in (th1,th3,pre) end;
108
(* compile target tm

   Compiles each conjunct of tm, in order, with basic_compile for the given
   target. Code abbreviation is switched on for the duration and switched
   off again afterwards, also when a compilation step raises (the exception
   is then re-raised unchanged).

   Returns (th, def, pre): th is the first theorem returned for the last
   conjunct, passed through UNABBREV_CODE_RULE; def and pre are the
   conjunctions of the second and third components of all results,
   simplified with RW []. *)
fun compile target tm = let
  val _ = set_abbreviate_code true
  val (th,def,pre) = let
    val xs = map (basic_compile target) (list_dest dest_conj tm)
    val def = RW [] (foldr (uncurry CONJ) TRUTH (map (fn (_,d,_) => d) xs))
    val pre = RW [] (foldr (uncurry CONJ) TRUTH (map (fn (_,_,p) => p) xs))
    val (th,_,_) = last xs
    in (th,def,pre) end
    (* do not leave the global flag set when compilation fails *)
    handle e => (set_abbreviate_code false; raise e)
  val _ = set_abbreviate_code false
  val th = UNABBREV_CODE_RULE th
  in (th,def,pre) end;
125
(* compile_all_aux tm

   Compiles the equation tm for "ppc", "x86" and "arm", in that order. The
   conclusion of each target's definition theorem is the input for the next
   target. The precondition theorem for each target is fetched from the
   current theory as "<name>_pre_def" right after that target is compiled.
   Every target's theorem is then rewritten with a proved lemma equating its
   function and precondition with those of the last target, and the
   rewritten theorems are passed to add_compiled.

   Returns (joined, last_def, last_pre), where joined pairs each target name
   with its rewritten theorem. *)
fun compile_all_aux tm = let
  val head_tm = repeat car (fst (dest_eq tm))
  val name = fst (dest_const head_tm) handle _ => fst (dest_var head_tm)
  fun compile_targets _ [] = []
    | compile_targets spec (target :: rest) = let
        val (cert, def, _) = basic_compile target spec
        val pre = fetch "-" (name ^ "_pre_def")
        in (target, cert, def, pre) :: compile_targets (concl def) rest end
  val compiled = compile_targets tm ["ppc","x86","arm"]
  val (_, _, last_def, last_pre) = last compiled
  (* head symbol of the left-hand side of a (possibly quantified) equation *)
  val defined_const = repeat car o fst o dest_eq o concl o SPEC_ALL
  fun prove_eq (target, cert, def, pre) = let
    val same_def = mk_eq (defined_const def, defined_const last_def)
    val same_pre = mk_eq (defined_const pre, defined_const last_pre)
    val lemma = auto_prove "compiler prove_eq"
                  (mk_conj (same_def, same_pre),
                   TAILREC_TAC THEN REPEAT STRIP_TAC THEN COMPILER_TAC)
    val _ = echo 1 (" " ^ target)
    in (target, ONCE_REWRITE_RULE [lemma] cert) end
  val _ = echo 1 "\nJoining definitions:"
  val joined = map prove_eq compiled
  val _ = echo 1 ".\n"
  val _ = add_compiled (map snd joined)
  in (joined, last_def, last_pre) end;
151
(* compile_all tm

   Runs compile_all_aux on each conjunct of tm in order. After a conjunct is
   compiled, its function symbol is replaced in the remaining conjuncts by
   the symbol of the resulting definition. Code abbreviation is switched on
   for the duration and switched off again afterwards, also when a step
   raises (the exception is then re-raised unchanged).

   Returns (ys, defs, pres): ys are the per-target theorems of the last
   conjunct, each passed through UNABBREV_CODE_RULE; defs and pres are the
   conjunctions of all definitions and preconditions, simplified with
   REWRITE_RULE []. *)
fun compile_all tm = let
  val _ = set_abbreviate_code true
  fun compile_each [] = []
    | compile_each (tm::tms) = let
    val x = car (fst (dest_eq tm))
    val (ys,def,pre) = compile_all_aux tm
    val x2 = (car o fst o dest_eq o concl o SPEC_ALL) def
    (* later conjuncts must refer to the newly defined symbol *)
    val tms = map (subst [x |-> x2]) tms
    in (ys,def,pre) :: compile_each tms end
  val (ys,defs,pres) = let
    val xs = compile_each (list_dest dest_conj tm)
    val f = REWRITE_RULE [] o foldr (uncurry CONJ) TRUTH
    val defs = f (map (fn (_,d,_) => d) xs)
    val pres = f (map (fn (_,_,p) => p) xs)
    val (ys,_,_) = last xs
    in (ys,defs,pres) end
    (* do not leave the global flag set when compilation fails *)
    handle e => (set_abbreviate_code false; raise e)
  val _ = set_abbreviate_code false
  val ys = map (fn (n,th) => (n,UNABBREV_CODE_RULE th)) ys
  in (ys,defs,pres) end
170
171
172(* this compiler maintains a to-do list: a list of functions to be compiled *)
173
(* entries are (function name, (definition theorem, precondition theorem)) *)
val to_compile : (string * (thm * thm)) list ref = ref [];
175
176
177(* for each function it generates a precondition, asserting termination etc. *)
178
(* append_lists xss concatenates the lists in xss, preserving order. *)
fun append_lists xss = List.foldr (fn (xs, rest) => xs @ rest) [] xss
181
(* list_find x pairs returns the value paired with the first key equal to x;
   calls fail() when no key matches. *)
fun list_find x pairs =
  case List.find (fn (key, _) => x = key) pairs of
    SOME (_, value) => value
  | NONE => fail()
184
(* all_distinct xs removes duplicates from xs, keeping the first occurrence
   of every element and the original relative order. *)
fun all_distinct xs = let
  (* seen holds the elements kept so far, most recent first *)
  fun keep_new (x, seen) =
    if List.exists (fn y => x = y) seen then seen else x :: seen
  in List.rev (List.foldl keep_new [] xs) end
187
(* generate_pre fname args tm

   Builds the right-hand side of a precondition for the function equation
   tm. The body of tm (its right-hand side) is turned into a function tree;
   at every node the following side conditions are collected and, if any
   exist, attached with FUN_COND:
     - for every application of a variable of type word32 -> word32, and for
       every function update of type word32 -> word32, an instance of the
       left-hand side of ALIGNED_def for the address plus membership of the
       address in a set variable named "d" ^ <function variable name>;
     - for every call matching a definition in to_compile whose precondition
       is not trivially T, the corresponding precondition instance.
   A result leaf whose head is named fname additionally requires
   <fname>_pre applied to the leaf's argument. The resulting term is
   simplified with REWRITE_CONV []. *)
fun generate_pre fname args tm = let
  val lhs_of = fst o dest_eq o concl o SPEC_ALL
  (* (call pattern, precondition pattern) of earlier functions with a
     non-trivial precondition *)
  val nontrivial = filter (fn (_,(_,pre)) => not (cdr (concl pre) = T)) (!to_compile)
  val aux_pre = map (fn (_,(def,pre)) => (lhs_of def, lhs_of pre)) nontrivial
  fun mk_ALIGNED addr =
    fst (dest_eq (concl (SPEC addr addressTheory.ALIGNED_def)))
  fun word32_read_write t = let
    val mem_ty = ``:word32 -> word32``
    val set_ty = ``:word32 set``
    fun is_read x =
      is_comb x andalso is_var (car x) andalso type_of (car x) = mem_ty
    fun is_write x =
      is_comb x andalso combinSyntax.is_update (car x) andalso type_of x = mem_ty
    (* (name of the memory variable, accessed address) *)
    fun read_access x = let
      val (f, addr) = dest_comb x
      in (fst (dest_var f), addr) end
    fun write_access x = let
      val addr = fst (combinSyntax.dest_update (car x))
      val f = repeat cdr x
      in (fst (dest_var f), addr) end
    val reads = map read_access (find_terms is_read t)
    val writes = map write_access (find_terms is_write t)
    val accesses = reads @ writes
    val aligned = map (fn (_, addr) => mk_ALIGNED addr) accesses
    val in_domain =
      map (fn (n, addr) => pred_setSyntax.mk_in (addr, mk_var ("d" ^ n, set_ty)))
          accesses
    in aligned @ in_domain end
  fun aux_fun_pre t (def,pre) = let
    val calls = find_terms (can (match_term def)) t
    in map (fn call => subst (fst (match_term def call)) pre) calls end
  fun all_aux_fun_pre t =
    all_distinct (append_lists (map (aux_fun_pre t) aux_pre))
  fun get_pre t = word32_read_write t @ all_aux_fun_pre t
  (* guard subtree f with the side conditions of t, when there are any *)
  fun cond_pre t f =
    case get_pre t of
      [] => f
    | conds => FUN_COND (list_mk_conj conds, f)
  val cond_var = mk_var ("cond", ``:bool``)
  fun get_name t =
    fst (dest_var (car t))
    handle HOL_ERR _ => (fst (dest_const (car t)) handle _ => "    ")
  val pre_f = mk_var (fname ^ "_pre", type_of args --> ``:bool``)
  fun add_pre tree =
    case tree of
      FUN_VAL t =>
        if get_name t = fname
        then cond_pre t (FUN_VAL (mk_conj (mk_comb (pre_f, cdr t), cond_var)))
        else cond_pre t (FUN_VAL cond_var)
    | FUN_IF (guard,t1,t2) => cond_pre guard (FUN_IF (guard, add_pre t1, add_pre t2))
    | FUN_LET (lhs,rhs,t) => cond_pre rhs (FUN_LET (lhs, rhs, add_pre t))
    | FUN_COND (c,t) => FUN_COND (c, add_pre t)
  val raw_pre = subst [cond_var |-> T] (ftree2tm (add_pre (tm2ftree (cdr tm))))
  val simplified = QCONV (REWRITE_CONV []) raw_pre
  in snd (dest_eq (concl simplified)) end;
227
228
229(* mc_define defines a function, generates a precondition and adds these to to-do list *)
230
(* mc_define q

   Parses the quotation q as a function equation, hides the function's name
   from the parser before building the term, generates a precondition with
   generate_pre, and defines both with tailrec_define_with_pre. The pair
   (def, pre) is recorded under the function's name at the front of
   to_compile, replacing any earlier entry of the same name, and is also
   returned. *)
fun mc_define q = let
  val absyn = Parse.Absyn q
  val fname = Absyn.dest_ident (fst (Absyn.dest_app (fst (Absyn.dest_eq absyn))))
  val _ = Parse.hide fname
  val tm = Parse.Term q
  val (_,args) = dest_comb (fst (dest_eq tm))
  val pre_body = generate_pre fname args tm
  val pre_f = mk_var(fname ^ "_pre",mk_type ("fun",[type_of args, ``:bool``]))
  (* the defining equation  fname_pre args = pre_body *)
  val pre_tm = mk_eq(mk_comb(pre_f,args),pre_body)
  val (def,pre) = tailrec_define_with_pre tm pre_tm
  val _ = (to_compile := (fname,(def,pre)) ::
                         filter (fn (n,_) => not (fname = n)) (!to_compile))
  in (def,pre) end;
245
246
247(* mc_compile performs actual compilation with help of reg_allocLib and compilerLib *)
248
(* collect_aux_fnames fname

   Returns the names of fname and of every function in to_compile that fname
   uses, directly or indirectly, listed in the order in which they appear in
   to_compile. A function "uses" another when the other's constant occurs in
   its definition theorem. Calls fail() (via list_find) when fname has no
   entry. Each function is visited at most once, so the traversal terminates
   even if the recorded definitions refer to each other cyclically. *)
fun collect_aux_fnames fname = let
  val fconst = car o fst o dest_eq o concl o SPEC_ALL
  val all_fconsts = map (fconst o fst o snd) (!to_compile)
  (* names of the other recorded functions occurring in def *)
  fun uses def =
    (all_distinct o map (fst o dest_const) o filter (fn x => not (x = fconst def)) o
     find_terms (fn x => mem x all_fconsts) o concl) def
  val all_uses = map (fn (_,(def,_)) => ((fst o dest_const o fconst) def, uses def)) (!to_compile)
  (* depth-first traversal; seen accumulates every name reached so far *)
  fun visit seen [] = seen
    | visit seen (n::ns) =
        if mem n seen then visit seen ns
        else visit (visit (n::seen) (list_find n all_uses)) ns
  val reachable = visit [] [fname]
  val fnames = filter (fn x => mem x reachable) (map fst (!to_compile))
  in fnames end;
261
(* mc_compile fname target

   Compiles fname, together with every function from to_compile that it
   depends on (see collect_aux_fnames), for target, which must be one of
   "arm", "ppc" or "x86".
   Steps:
     1. the definitions are conjoined with every function constant replaced
        by a temporary variable named "__" ^ <constant name>;
     2. allocate_registers is run on that term with the register count
        associated with the target;
     3. the conjuncts of the antecedent of its result are compiled one at a
        time with compile, substituting each newly compiled function for
        its temporary variable in the remaining conjuncts;
     4. each compiled function and precondition is proved equal to the
        originally defined one, and the theorem of the last compiled
        function is rewritten with these equations;
     5. the constants on the left-hand sides of the rewrite equations are
        deleted.
   Returns the rewritten theorem. Raises HOL_ERR when target is not
   supported. *)
fun mc_compile fname target = let
  val fnames = collect_aux_fnames fname
  val qs = map (fn f => (f,list_find f (!to_compile))) (rev fnames)
  val fun_const = car o fst o dest_eq o concl
  val cs = map (fun_const o fst o snd) qs
  fun make_temp_name s = "__" ^ s
  val s = map (fn c => c |-> mk_var(make_temp_name (fst (dest_const c)),type_of c)) cs
  val tm = subst s (list_mk_conj (map (concl o fst o snd) qs))
  (* register count handed to allocate_registers for each target *)
  val n = list_find target [("arm",13),("ppc",31),("x86",1000)]
          handle HOL_ERR _ =>
            raise mk_HOL_ERR "compilerLib" "mc_compile"
                    ("unsupported target: " ^ target)
  val imp = allocate_registers n tm
  fun strip_forall tm = strip_forall (snd (dest_forall tm)) handle HOL_ERR _ => tm
  val gs = (map strip_forall o list_dest dest_conj o fst o dest_imp o concl) imp
  val xs = zip (rev fnames) gs
  (* compile in order; results accumulate most-recent-first *)
  fun do_all [] thms = thms
    | do_all ((f,tm)::xs) thms = let
        val (x1,x2,x3) = compile target tm
        val x2c = fun_const x2
        (* later conjuncts must call the function that was just compiled *)
        val s = [mk_var(make_temp_name(f),type_of x2c) |-> x2c]
        val xs = map (fn (x,tm) => (x,subst s tm)) xs
        in do_all xs ((x1,x2,x3)::thms) end
  val zs = do_all xs []
  (* each proved equation becomes a rewrite for the remaining proofs *)
  fun prove_equiv [] rws = rws
    | prove_equiv (((_,(x2,x3)),(_,y2,y3))::ts) rws = let
        val goal = mk_conj(mk_eq(fun_const x2,fun_const y2),
                           mk_eq(fun_const x3,fun_const y3))
        val th2 = auto_prove "mc_compile" (goal,
          STRIP_TAC THEN TAILREC_TAC THEN AP_TERM_TAC THEN REWRITE_TAC rws
          THEN SIMP_TAC std_ss [GUARD_def,FUN_EQ_THM] THEN REPEAT STRIP_TAC
          THEN SIMP_TAC std_ss [ALIGNED_INTRO] THEN COMPILER_TAC)
        in prove_equiv ts (th2::rws) end
  val rws = prove_equiv (zip qs (rev zs)) []
  val rw = GSYM (RW [GSYM CONJ_ASSOC] (LIST_CONJ rws))
  val (th,_,_) = hd zs
  val th = RW [rw] th
  val temp_cs = map (fst o dest_eq) (list_dest dest_conj (concl rw))
  val _ = List.app (delete_const o fst o dest_const) temp_cs
  in th end;
304
(* mc_compile_all fname runs mc_compile for "arm", "ppc" and "x86", in that
   order, and pairs each target name with the resulting theorem. *)
fun mc_compile_all fname = let
  fun compile_for target = (target, mc_compile fname target)
  in map compile_for ["arm","ppc","x86"] end
307
308end;
309