(* Append HOLDIR ^ x to HOL's load path. *)
fun load_path_add x = loadPath := !loadPath @ [Globals.HOLDIR ^ x];

(* Directories (relative to HOLDIR) needed by this script, appended in order. *)
val _ = List.app load_path_add
          ["/examples/dev/sw2",
           "/examples/machine-code/compiler",
           "/examples/machine-code/decompiler",
           "/examples/machine-code/hoare-triple",
           "/examples/machine-code/instruction-set-models",
           "/examples/ARM/v4",
           "/tools/mlyacc/mlyacclib"];
9
10quietdec := true;
11app load ["compilerLib", "regAlloc", "closure", "inline"];
12open arm_compilerLib;
13open Normal inline closure regAlloc NormalTheory;
14open wordsSyntax numSyntax pairSyntax;
15val mapPartial = List.mapPartial;
16quietdec := false;
17
18
(* Prefer :num for overloaded numerals; set Globals.priming to NONE. *)
val _ = numLib.prefer_num ();
val _ = Globals.priming := NONE;

(* HOL_ERR builder for this script: ERR function message. *)
fun ERR function message = mk_HOL_ERR "compileDefine" function message;

(* Back-end switches: optimise generated code, no code abbreviation
   (abbrev_code is switched back on further down the script). *)
val _ = (optimise_code := true; abbrev_code := false);
26
27(*---------------------------------------------------------------------------*)
28(* To eventually go to wordsSyntax.                                          *)
29(*---------------------------------------------------------------------------*)
30
(* A word literal is recognised as an n2w application (wordsSyntax.is_n2w). *)
fun is_word_literal tm = is_n2w tm;

(* Collect the operands of a (possibly nested) word_or, left to right, and
   return them split as (all-but-last, last) via Lib.front_last.
   NOTE(review): this pair is not the flat list the name suggests and is not
   the inverse of list_mk_word_or -- confirm this is intended. *)
fun strip_word_or tm =
 let fun leaves t =
       case total dest_word_or t
        of NONE => [t]
         | SOME (a, b) => leaves a @ leaves b
 in Lib.front_last (leaves tm)
 end;

(* Right-nested word_or of a list of words (expects a non-empty list). *)
fun list_mk_word_or tms = end_itlist (curry mk_word_or) tms;

(* Flatten a list of pairs: [(a,b),(c,d)] becomes [a,b,c,d]. *)
fun pflat pairs = List.foldr (fn ((x, y), rest) => x :: y :: rest) [] pairs;
45
46(*---------------------------------------------------------------------------*)
47(* Translating constants into compound expressions.                          *)
48(*---------------------------------------------------------------------------*)
49
(* Pretty-print a type and drop its first character (presumably the leading
   ':' produced by Hol_pp.type_to_string). *)
fun type_to_name t = String.extract(Hol_pp.type_to_string t, 1, NONE);

(* The number denoted by a numeric type such as :32, as an Arbnum.num.
   type_pp.pp_num_types must be set while the type is PRINTED, so with_flag
   wraps type_to_name itself.  (Applying with_flag to Arbnum.fromString would
   evaluate the argument type_to_name t before the flag is set.) *)
fun numeric_type_to_num t =
 let val name = Lib.with_flag (type_pp.pp_num_types, true) type_to_name t
 in Arbnum.fromString name
 end;
55
(* Prefix s with i copies of #"0"; a non-positive i leaves s unchanged. *)
fun pad s i =
 let val zeros = List.tabulate (Int.max (i, 0), fn _ => #"0")
 in String.implode zeros ^ s
 end;
60
(* Binary string of a word literal n2w n : bool[ty], zero-padded on the left
   to the width of ty.  Nothing is truncated if n needs more bits. *)
fun bit_pattern w =
 let val (num_tm, width_ty) = wordsSyntax.dest_n2w w
     val n = numSyntax.dest_numeral num_tm
     val width = Arbnum.toInt (numeric_type_to_num width_ty)
     val bits = Arbnum.toBinString n
 in pad bits (width - String.size bits)
 end;

(* Word literal of index type `width` whose value is the binary string s. *)
fun word_of_string (s, width) =
  let val value = Arbnum.fromBinString s
  in mk_n2w (mk_numeral value, width)
  end;

(* The index type :32, and 32-bit literals built from binary strings. *)
val index32 = fcpLib.index_type (Arbnum.fromInt 32);
fun word32_of_string bits = word_of_string (bits, index32);

(* Word literal with value i and bit-width `width` (both ML ints). *)
fun word_of_int (i, width) =
  let val value_tm = numSyntax.term_of_int i
      val width_ty = fcpLib.index_type (Arbnum.fromInt width)
  in mk_n2w (value_tm, width_ty)
  end;
80
(* Cut the first 32 characters of a bit string into four 8-character pieces,
   most significant first, each turned into a 32-bit word literal.
   All four substrings are taken before any literal is built. *)
fun chunk s =
 let fun piece k = String.substring (s, 8 * k, 8)
     val (p1, p2, p3, p4) = (piece 0, piece 1, piece 2, piece 3)
 in
   (word32_of_string p1, word32_of_string p2,
    word32_of_string p3, word32_of_string p4)
 end;

(* Constants used when rebuilding a literal from its bytes. *)
val zero32 = ``0w:word32``;
val n8 = term_of_int 8;
val n16 = term_of_int 16;
val n24 = term_of_int 24;
val n256 = Arbnum.fromInt 256;
97
(* Build a nest of lets that reassembles a word from its bytes:
     let v0 = b in let v0 = v0 << s in ... let vk = v0 || ... || b4 in vk
   Only the upper bytes b1 (shift 24), b2 (shift 16), b3 (shift 8) that are
   not 0w get a variable (v0, v1, ... in that order).  b4 is or-ed in
   directly.  The result variable is v<k>, where k is the number of kept
   bytes. *)
fun bytes_to_let (b1,b2,b3,b4) =
 let val kept = List.filter (fn (b, _) => b <> zero32)
                            [(b1,n24), (b2,n16), (b3,n8)]
     fun var_for i b = mk_var ("v" ^ Int.toString i, type_of b)
     val named = List.tabulate (length kept,
                   fn i => let val (b, s) = List.nth (kept, i)
                           in (var_for i b, b, s) end)
     val result_var = mk_var ("v" ^ Int.toString (length named), type_of b4)
     val combined = list_mk_word_or (map #1 named @ [b4])
     fun steps (v, b, s) = [(v, b), (v, mk_word_lsl (v, s))]
     val bindings = List.concat (map steps named) @ [(result_var, combined)]
 in list_mk_anylet (map (fn bnd => [bnd]) bindings, result_var)
 end;
111
(* Conversion proving  |- c = let ... in vk  for a 32-bit word literal c whose
   value is >= 256.  The right-hand side rebuilds c byte by byte (see
   bytes_to_let) and the equation is proved with wordsLib.WORDS_CONV.
   It fails with HOL_ERR on anything else, so it is safe under DEPTH_CONV.

   Literals of other widths are rejected up front.  chunk takes four 8-bit
   substrings, so a narrower word would make String.substring raise
   Subscript.  Subscript is not a HOL_ERR, so it would escape DEPTH_CONV and
   abort the whole pass. *)
fun IMMEDIATE_CONST_CONV c =
 let val (ntm, width_ty) = dest_n2w c
     val n = numSyntax.dest_numeral ntm
 in if width_ty <> index32 orelse Arbnum.<(n,n256)
    then failwith "CONST_CONV"
    else
    let val bstr = bit_pattern c
        (* a numeral >= 2^32 prints longer than 32 bits; chunk would drop bits *)
        val _ = if String.size bstr <> 32 then failwith "CONST_CONV" else ()
        val res = bytes_to_let (chunk bstr)
    in EQT_ELIM (wordsLib.WORDS_CONV (mk_eq(c,res)))
    end
 end;
120
121(*---------------------------------------------------------------------------*)
122(* Want to remove returned constants, in favour of returned registers.       *)
123(* So "if P then c else d" where c or d are constants, needs to become       *)
124(*                                                                           *)
125(*   if P then (let x = c in x) else (let x = d in x)                        *)
126(*---------------------------------------------------------------------------*)
127
(* For tm = f x y (in this file, a conditional COND t a1 a2), combine
   th1 : |- l1 = r1 and th2 : |- l2 = r2 into |- f l1 l2 = f r1 r2.
   Only the head f of tm (rator of rator) is used. *)
fun MK_COND tm th1 th2 =
 let val head_tm = rator (rator tm)
 in MK_COMB (MK_COMB (REFL head_tm, th1), th2)
 end;

(* |- c = LET (\v. v) c  for a fresh variable v of c's type. *)
fun letify c =
 let val x = genvar (type_of c)
     val let_tm = mk_let (mk_abs (x, x), c)
     val unfolded = BETA_RULE (REWR_CONV LET_THM let_tm)
 in SYM unfolded
 end;
140
(* On  if t then a1 else a2, wrap every arm that is a constant or a word
   literal in a trivial let (via letify) and leave other arms alone.
   Fails, raising ERR "COND_CONST_ELIM_CONV", when tm is not a conditional
   or neither arm is a constant/literal. *)
fun COND_CONST_ELIM_CONV tm =
 let val (_,a1,a2) = dest_cond tm
     fun literal a = is_const a orelse is_word_literal a
     fun rewrite_arm a = if literal a then letify a else REFL a
 in if literal a1 orelse literal a2
    then MK_COND tm (rewrite_arm a1) (rewrite_arm a2)
    else failwith ""
 end
  handle HOL_ERR _ => raise ERR "COND_CONST_ELIM_CONV" "";
151
152
153(*---------------------------------------------------------------------------*)
154(* Expand constants, and eliminate returned constants (return values must    *)
155(* be in registers).                                                         *)
156(*---------------------------------------------------------------------------*)
157
(* pass0  : replace word literals >= 256 by byte-building lets, everywhere.
   pass0a : put constant/literal arms of conditionals under a trivial let.
   pass0b : flatten nested lets with FLATTEN_LET. *)
val pass0 : thm -> thm = CONV_RULE (DEPTH_CONV IMMEDIATE_CONST_CONV);
val pass0a : thm -> thm = CONV_RULE (DEPTH_CONV COND_CONST_ELIM_CONV);
val pass0b : thm -> thm = Ho_Rewrite.PURE_REWRITE_RULE [FLATTEN_LET];
161
162(*---------------------------------------------------------------------------*)
163(* Compiling a list of functions.                                            *)
164(*---------------------------------------------------------------------------*)
165  
(* Name of the constant defined by th.  The first conjunct is used, so both
   single equations |- !x. f x = ... and multi-clause definitions
   |- (!x. f p1 = ...) /\ ... work (as in `head`).  The original version
   applied lhs to the whole conclusion and raised HOL_ERR on conjunctions;
   compenv calls it outside `total`, so that aborted the compile loop. *)
fun defname th =
  let val first_eqn = hd (strip_conj (concl th))
      val (_, eqn) = strip_forall first_eqn
      val (f, _) = strip_comb (lhs eqn)
  in fst (dest_const f)
  end;
168
(* Run comp over a list of definitions, threading the list of definitions
   compiled so far (most recent first) and printing progress for each.
   - All succeed: PASS with the compiled definitions in input order.
   - comp raises on a definition (caught by Lib.total): FAIL (env, rest).
     env is the compiled-so-far list, most recent first.  rest starts with
     the failing definition.
   NOTE(review): env is reversed on PASS but not on FAIL -- confirm intended. *)
fun compenv comp =
 let fun step (done, []) = PASS (rev done)
       | step (done, todo as d :: rest) =
           (print ("Compiling " ^ defname d ^ " ... ");
            case total comp (done, d)
             of NONE => (print "failed.\n"; FAIL (done, todo))
              | SOME d' => (print "succeeded.\n"; step (d' :: done, rest)))
 in step
 end;
182
183(*---------------------------------------------------------------------------*)
184(* Compile a list of definitions, accumulating the environment.              *)
185(*---------------------------------------------------------------------------*)
186
(* Compile deflist with the given pass, starting from an empty environment. *)
fun complist passes deflist =
  let val start = ([], deflist) in compenv passes start end;
188
189(*---------------------------------------------------------------------------*)
190(* Basic flattening via CPS and unique names                                 *)
191(*---------------------------------------------------------------------------*)
192
(* normalize, then pass0, pass0a, pass0b, and finally SSA_RULE. *)
fun pass1 def = (SSA_RULE o pass0b o pass0a o pass0 o normalize) def;
194
195
196(*---------------------------------------------------------------------------*)
197(* All previous, plus inlining and optimizations                             *)
198(*---------------------------------------------------------------------------*)
199
(* pass1, then optimize_norm against env, then SSA_RULE again. *)
fun pass2 (env, def) =
  let val normalized = pass1 def
      val optimized = optimize_norm env normalized
  in SSA_RULE optimized
  end;
205
206(*---------------------------------------------------------------------------*)
207(* All previous, and closure conversion.                                     *)
208(*---------------------------------------------------------------------------*)
209
(* pass2, then closure conversion.  If closure_convert raises, the pass2
   result is returned unchanged.  Otherwise its output is re-optimized and
   re-SSA'd. *)
fun pass3 (env, def) =
  let val base = pass2 (env, def)
  in case total closure_convert base
      of NONE => base
       | SOME converted => SSA_RULE (optimize_norm env converted)
  end;
216
217(*---------------------------------------------------------------------------*)
218(* All previous, and register allocation.                                    *)
219(*---------------------------------------------------------------------------*)
220
(* pass3 followed by register allocation. *)
fun pass4 (env, def) =
  let val converted = pass3 (env, def) in reg_alloc converted end;
226
227(*---------------------------------------------------------------------------*)
228(* Different pass4, in which some intermediate steps are skipped.            *)
229(*---------------------------------------------------------------------------*)
230
(* Register allocation straight after pass1: no optimize_norm and no closure
   conversion.  The environment argument is accepted but not used. *)
fun pass4a (_, def) = reg_alloc (pass1 def);
237
(* List compilers, one per pass level above. *)
fun compile1 defs = complist (fn (_, d) => pass1 d) defs;
fun compile2 defs = complist pass2 defs;
fun compile3 defs = complist pass3 defs;
fun compile4 defs = complist pass4 defs;
fun compile4a defs = complist pass4a defs;
243
244(*---------------------------------------------------------------------------*)
245(* Some simplifications used after compilation                               *)
246(*---------------------------------------------------------------------------*)
247
(* A let whose bound value is a conditional splits into a conditional of
   two lets.  Used by compileDefine's simplification step. *)
val lift_cond_above_let = Q.prove
(`(let x = (if v1 then v2 else v3) in rst x) = 
  if v1 then (let x = v2 in rst x) else (let x = v3 in rst x)`,
 RW_TAC std_ss [LET_DEF]);

(* A conditional in a let body, whose guard v1 does not mention x, is lifted
   above the let, duplicating the binding into both branches. *)
val lift_cond_above_let1 = Q.prove
(`(let x = val in (if v1 then v2 x else v3 x)) = 
  if v1 then (let x = val in v2 x) else (let x = val in v3 x)`,
 RW_TAC std_ss [LET_DEF]);

(* let x = (if ...) in x  is just the conditional. *)
val lift_cond_above_trivlet = Q.prove
(`(let x = (if v1 then v2 else v3) in x) = 
  if v1 then v2 else v3`,
 SIMP_TAC std_ss [LET_DEF]);

(* The identity let.  NOTE(review): not referenced by the visible code
   (compileDefine's simp list omits it). *)
val id_let = Q.prove
(`LET (\x.x) y = y`,
 SIMP_TAC std_ss [LET_DEF]);
266
267
268(*---------------------------------------------------------------------------*)
269(*  Eliminates "let x = x in M" redundancies                                 *)
270(*---------------------------------------------------------------------------*)
271
(* For tm = LET (\x. M) x -- a let binding x to itself -- prove |- tm = M.
   Fails with HOL_ERR on any other term. *)
fun ID_BIND_CONV tm =
  let open pairSyntax
      val (f, bound) = dest_let tm
  in
    if bvar f <> bound then failwith "ID_BIND_CONV"
    else RIGHT_BETA (REWR_CONV LET_THM tm)
  end;
280
281
(* Head constant and arguments of the left-hand side of the first
   (universally quantified) conjunct of eqns. *)
fun head eqns =
  let val clauses = strip_conj (concl eqns)
      val (_, first) = strip_forall (hd clauses)
  in strip_comb (lhs first)
  end;
286
287(*---------------------------------------------------------------------------*)
(* TODO: also prove that the second definition satisfies the recursion       *)
(* equations of the first definition.                                        *)
290(*---------------------------------------------------------------------------*)
291
(* Define a function from quotation q, then run the front end on it (pass4a:
   pass1 followed by register allocation).  Simplify the result and
   re-define the same name from that simplified form.

   env : forwarded to pass4a.
   Returns (def2, ind): the conjunction of the second definition's equations
   and Defn.ind_of of that definition.  ind is a thm option; callers in this
   script match NONE for non-recursive definitions.
   Raises HOL_ERR if std_apiDefine does not accept the compiled form.  A
   plain `val PASS ... =` binding would instead raise an uninformative Bind. *)
fun compileDefine (env,q) =
 let open TotalDefn pairSyntax
     val _ = HOL_MESG "Initial definition"
     val def1 = Define q
     val (const,_) = head def1
     val cinfo as (cname,_) = dest_const const
     (* variable that replaces the constant when re-defining *)
     val cvar = mk_var cinfo
     val compiled = pass4a(env,def1)
     (* compiled is expected to have the shape  |- c = \args. body  *)
     val args = fst(dest_pabs(rhs(concl compiled)))
     val th1 = CONV_RULE (RHS_CONV pairLib.GEN_BETA_CONV)
                         (AP_THM compiled args)
     (* push conditionals above lets and flatten nested lets *)
     val th2 = CONV_RULE
                (RHS_CONV (SIMP_CONV bool_ss
                             [lift_cond_above_let,lift_cond_above_let1,
                              lift_cond_above_trivlet,FLATTEN_LET,
                              COND_RATOR, COND_RAND])) th1
     (* remove  let x = x in M  bindings *)
     val th3 = CONV_RULE (DEPTH_CONV ID_BIND_CONV) th2
     val vtm = subst [const |-> cvar] (concl th3)
     val _ = HOL_MESG "Second definition"
     val defn2 =
       case std_apiDefine (cname,vtm)
        of PASS (defn,_) => defn
         | _ => raise mk_HOL_ERR "compileDefine" "compileDefine"
                  ("could not re-define " ^ cname ^ " from its compiled form")
     val def2 = LIST_CONJ (Defn.eqns_of defn2)
     val ind = Defn.ind_of defn2
 in
  (def2,ind)
 end;
317
318(*---------------------------------------------------------------------------*)
319(* Join the front end with Magnus' backend                                   *)
320(*---------------------------------------------------------------------------*)
321
(* Run compileDefine on q with an empty environment, then arm_compile the
   specialised definition.  TRUTH is passed as the induction theorem when
   compileDefine returned none.
   Returns (definition, induction theorem or TRUTH, arm_compile's theorem). *)
fun test_compile' style q =
 let val (def, indopt) = compileDefine ([], q)
     val indth = Option.getOpt (indopt, TRUTH)
     val (th, _) = arm_compile (SPEC_ALL def) indth style
 in (def, indth, th)
 end;

(* The InLineCode and SimpleProcedure back-end styles. *)
val test_compile = fn q => test_compile' InLineCode q;
val test_compile_proc = fn q => test_compile' SimpleProcedure q;
332
(* From here on, abbreviate generated code, and call reset_compiled
   (presumably clears the back end's record of compiled functions). *)
val _ = (abbrev_code := true; reset_compiled ());
335
336(*
337val (load_751_def,_,load_751_arm) = test_compile_proc `
338  load_751 = 
339    let r10 = 2w:word32 in
340    let r10 = r10 << 8 in
341    let r10 = r10 + 239w in r10`;
342
343val load_751_def = 
344 Define 
345   `load_751 = 
346    let r10 = 2w:word32 in
347    let r10 = r10 << 8 in
348    let r10 = r10 + 239w 
349    in r10`;
350
351val (load_751_arm,_) = arm_compile load_751_def TRUTH InLineCode;
352*)
353
354
(* field_neg: 0w stays 0w, otherwise 751w - x.  Compiled to ARM; the code
   abbreviation is then unfolded into the resulting triple. *)
val (field_neg_def, _, field_neg_arm) =
 test_compile
   `field_neg (x:word32) = 
     if x = 0w then (0w:word32) 
     else 751w - x`;

val field_neg_triple =
  let val code_def = fetch "-" "field_neg_code1_def"
  in REWRITE_RULE [code_def] field_neg_arm
  end;
363
(* field_add: x + y, minus 751w when the sum is not below 751w
   (presumably assumes both inputs are < 751w -- confirm). *)
val (field_add_def, _, field_add_arm) =
 test_compile
  `field_add (x:word32,y:word32) =
     let z = x + y 
      in 
       if z < 751w then z else z - 751w`;

val field_add_triple =
  let val code_def = fetch "-" "field_add_code1_def"
  in REWRITE_RULE [code_def] field_add_arm
  end;
372
(* field_sub (x,y) = field_add (x, field_neg y), compiled to ARM. *)
val (field_sub_def,_,field_sub_arm) = 
  test_compile
   `field_sub (x:word32,y:word32) = field_add(x,field_neg y)`;
(* Bound by name like field_neg_triple / field_add_triple, rather than being
   left in `it` where the next top-level expression overwrites it. *)
val field_sub_triple =
  REWRITE_RULE [fetch "-" "field_sub_code1_def"] field_sub_arm;
377
(* field_mult_aux: shift-and-add multiplication using field_add; recurses on
   y >>> 1 until y = 0w, accumulating into acc. *)
val (field_mult_aux_def,_,field_mult_aux_arm) = 
  test_compile
  `field_mult_aux (x:word32,y:word32,acc:word32) =
      if y = 0w then acc 
      else let 
        x' = field_add (x,x) in let 
        y' = y >>> 1         in let 
        acc' = (if y && 1w = 0w then acc else field_add (acc,x))
        in
          field_mult_aux (x',y',acc')`;
(* Every other compiled function's code abbreviation is fetched as
   "<name>_code1_def"; "field_mult_aux1_def" broke that pattern.
   NOTE(review): confirm the theorem name the back end produces here. *)
val field_mult_aux_triple =
  REWRITE_RULE [fetch "-" "field_mult_aux_code1_def"] field_mult_aux_arm;
389
(* field_mult (x,y) = field_mult_aux (x,y,0w), compiled to ARM.
   test_compile returns a triple (def, induction-or-TRUTH, arm theorem); the
   previous 2-tuple pattern (field_mult_def,NONE) did not type-check. *)
val (field_mult_def, _, field_mult_arm) = 
 test_compile
   `field_mult (x,y) = field_mult_aux (x,y,0w)`;
393
(* field_exp_aux: square-and-multiply exponentiation using field_mult;
   recurses on n >>> 1 until n = 0w, accumulating into acc.
   (The stray ")" after the quotation, a syntax error, has been removed.) *)
val (field_exp_aux_def, _, field_exp_aux_arm) = 
 test_compile
  `field_exp_aux (x:word32,n:word32,acc:word32) =
      if n = 0w then acc
      else
       let x' = field_mult (x,x) in
       let n' = n >>> 1 in
       let acc' = (if n && 1w = 0w then acc else field_mult (acc,x))
        in
          field_exp_aux (x',n',acc')`;
404
(* The following definitions go through compileDefine only (front end).
   Unlike test_compile, no arm_compile call is made for them.
   The NONE pattern requires that no induction theorem is returned; if one
   were, the binding would raise Bind. *)

(* field_exp (x,n) = field_exp_aux (x,n,1w). *)
val (field_exp_def,NONE) = 
 compileDefine([],
  `field_exp (x,n) = field_exp_aux (x,n,1w)`);

(* field_inv x = x raised to 749w.  NOTE(review): presumably the inverse
   x^(p-2) for p = 751 -- confirm. *)
val (field_inv_def,NONE) =
 compileDefine ([],
  `field_inv x = field_exp (x,749w)`);

(* field_div (x,y) = x multiplied by field_inv y. *)
val (field_div_def,NONE) = 
 compileDefine([],
  `field_div (x,y) = field_mult (x,field_inv y)`);
416
(* curve_neg: (0w,0w) maps to itself; otherwise (x1,y1) maps to
   (x1, (-y1 - 0*x1) - 1) in field arithmetic.
   NOTE(review): looks like point negation on a curve with coefficients
   a1 = 0, a3 = 1 -- confirm. *)
val (curve_neg_def,NONE) = 
 compileDefine([],
  `curve_neg (x1,y1) =
       if (x1 = 0w) /\ (y1 = 0w) then (0w,0w)
       else
        let y = field_sub
                  (field_sub
                    (field_neg y1,field_mult (0w,x1)),1w)
         in
            (x1,y)`);
427
(* curve_double: (0w,0w), or a point whose denominator d evaluates to 0w,
   gives (0w,0w).  Otherwise the new point is computed from l and m, both
   divided by d, using only the field_* operations.
   NOTE(review): the constants (0w, 1w, 750w) look like fixed curve
   coefficients for doubling on a curve over 751 -- confirm against the
   intended curve. *)
val (curve_double_def,NONE) = 
 compileDefine([],
  `curve_double (x1,y1) =
      if (x1 = 0w) /\ (y1 = 0w) then (0w,0w)
      else
       let d = field_add
                 (field_add
                   (field_mult (2w,y1),
                    field_mult (0w,x1)),1w)
       in
        if d = 0w then (0w,0w)
        else
         let l = field_div
                  (field_sub
                    (field_add
                      (field_add
                        (field_mult(3w,field_exp (x1,2w)),
                         field_mult(field_mult (2w,0w),x1)),750w),
                       field_mult (0w,y1)),d) in
         let m = field_div
                  (field_sub
                    (field_add
                      (field_add
                           (field_neg (field_exp (x1,3w)),
                            field_mult (750w,x1)),
                       field_mult (2w,0w)),
                     field_mult (1w,y1)),d) in
         let x = field_sub
                  (field_sub
                    (field_add(field_exp (l,2w),
                                   field_mult (0w,l)),0w),
                     field_mult (2w,x1)) in
         let y = field_sub
                  (field_sub
                     (field_mult
                       (field_neg (field_add (l,0w)),x),m),1w)
         in
           (x,y)`);
466
467
(* curve_add: the cases are checked in order.
     - equal points: delegate to curve_double;
     - (0w,0w) operand: return the other point;
     - equal x only: return (0w,0w);
     - otherwise: compute the sum from l and m, with d = x2 - x1.
   Non-recursive, so compileDefine is expected to return NONE for induction. *)
val (curve_add_def,NONE) = 
 compileDefine([],
  `curve_add (x1,y1,x2,y2) =
       if (x1 = x2) /\ (y1 = y2) then curve_double (x2,y2) else 
       if (x1 = 0w) /\ (y1 = 0w) then (x2,y2) else
       if (x2 = 0w) /\ (y2 = 0w) then (x1,y1) else
       if x1 = x2 then (0w,0w)
       else
         let d = field_sub (x2,x1) in
         let l = field_div (field_sub (y2,y1),d) in
         let m = field_div
                   (field_sub (field_mult (y1,x2),
                                   field_mult (y2,x1)),d) in
         let x = field_sub
                  (field_sub
                    (field_sub
                      (field_add
                        (field_exp (l,2w),
                         field_mult (0w,l)),0w),x1),x2) in
         let y = field_sub
                  (field_sub
                    (field_mult
                      (field_neg (field_add (l,0w)),x),m),1w)
         in
          (x,y)`);
493
(* curve_mult_aux: double-and-add scalar multiplication.  It recurses on
   n >>> 1 until n = 0w, adding the current point into (acc_x,acc_y) when the
   low bit of n is set.
   The definition is recursive, so Defn.ind_of yields SOME induction theorem.
   The previous pattern (curve_mult_aux_def,NONE) would therefore raise Bind;
   the induction result is now ignored. *)
val (curve_mult_aux_def, _) = 
 compileDefine([],
  `curve_mult_aux (x,y,n:word32,acc_x,acc_y) =
      if n = 0w then (acc_x:word32,acc_y:word32)
      else
       let (x',y') = curve_double (x,y) in
       let n' = n >>> 1 in
       let (acc_x',acc_y') =
              (if n && 1w = 0w then (acc_x,acc_y)
               else curve_add (x,y,acc_x,acc_y))
       in
        curve_mult_aux (x',y',n',acc_x',acc_y')`);
506
(* curve_mult (x,y,n): curve_mult_aux started from the accumulator (0w,0w).
   Defined with plain Define, not compileDefine, so it does not go through
   the compiler front end. *)
val curve_mult_def = 
 Define
  `curve_mult (x,y,n) = curve_mult_aux (x,y,n,0w,0w)`;
510
511