(*---------------------------------------------------------------------------*)
(* Interactive HOL4 script: a small compilation front end (constant          *)
(* expansion, normalisation, inlining, closure conversion, register          *)
(* allocation) joined to the ARM back end from arm_compilerLib, followed by  *)
(* a series of test definitions (field and curve arithmetic).                *)
(*---------------------------------------------------------------------------*)

(* Append HOLDIR ^ x to the interactive load path. *)
fun load_path_add x = loadPath := !loadPath @ [concat [Globals.HOLDIR,x]];

load_path_add "/examples/dev/sw2";
load_path_add "/examples/machine-code/compiler";
load_path_add "/examples/machine-code/decompiler";
load_path_add "/examples/machine-code/hoare-triple";
load_path_add "/examples/machine-code/instruction-set-models";
load_path_add "/examples/ARM/v4";
load_path_add "/tools/mlyacc/mlyacclib";

quietdec := true;
app load ["compilerLib", "regAlloc", "closure", "inline"];
open arm_compilerLib;
open Normal inline closure regAlloc NormalTheory;
open wordsSyntax numSyntax pairSyntax;
val mapPartial = List.mapPartial;
quietdec := false;

numLib.prefer_num();
Globals.priming := NONE;

val ERR = mk_HOL_ERR "compileDefine" ;

val _ = optimise_code := true;
val _ = abbrev_code := false;

(*---------------------------------------------------------------------------*)
(* To eventually go to wordsSyntax.                                          *)
(*---------------------------------------------------------------------------*)

val is_word_literal = is_n2w;

(* Flatten a tree of word-or applications into its leaves and return them    *)
(* as (all-but-last, last).                                                  *)
fun strip_word_or tm =
 let fun f tm =
      case total dest_word_or tm
       of SOME (l,r) => f(l) @ f(r)
        | NONE => [tm]
 in Lib.front_last(f tm)
 end;

(* Right-nested word-or of a non-empty list of terms. *)
val list_mk_word_or = end_itlist (curry mk_word_or);

(* Flatten a list of pairs: [(a,b),(c,d)] becomes [a,b,c,d]. *)
fun pflat [] = []
  | pflat ((x,y)::t) = x::y::pflat t;

(*---------------------------------------------------------------------------*)
(* Translating constants into compound expressions.                          *)
(*---------------------------------------------------------------------------*)

(* Printed form of a type with its first character dropped. *)
fun type_to_name t = String.extract(Hol_pp.type_to_string t, 1, NONE);

(* Read a numeric (index) type as a number by parsing its printed name.      *)
(* The pretty-printing has to happen while pp_num_types is set, so the two   *)
(* steps are composed and run together under with_flag.  Previously the      *)
(* name was computed as an argument, i.e. before the flag was set.           *)
fun numeric_type_to_num t =
 Lib.with_flag(type_pp.pp_num_types,true)
   (Arbnum.fromString o type_to_name) t;

(* Prepend i zero characters to s (no-op when i <= 0). *)
fun pad s i =
 let fun loop n acc = if n <= 0 then acc else loop(n-1) (#"0"::acc)
 in String.implode (loop i (String.explode s))
 end;

(* Binary digit string of the word literal w, left-padded with zeros up to   *)
(* the width of its type.  The string is longer than the width when the      *)
(* numeral does not fit.                                                     *)
fun bit_pattern w =
 let open wordsSyntax
     val (ntm,ty) = dest_n2w w
     val n = numSyntax.dest_numeral ntm
     val ty_width = Arbnum.toInt(numeric_type_to_num ty)
     val str = Arbnum.toBinString n
 in
   pad str (ty_width - String.size str)
 end;

(* Word literal of index type width whose value is the binary string s. *)
fun word_of_string(s,width) =
    mk_n2w(mk_numeral(Arbnum.fromBinString s),width);

val index32 = fcpLib.index_type (Arbnum.fromInt 32);
fun word32_of_string s = word_of_string(s,index32);

fun word_of_int (i,width) =
  mk_n2w(numSyntax.term_of_int i,
         fcpLib.index_type(Arbnum.fromInt width));

(* Split a 32-character bit string into four 8-character groups, each        *)
(* returned as a word32 literal, first group first.  Raises Subscript when   *)
(* s has fewer than 32 characters.                                           *)
fun chunk s =
 let open String numSyntax
     val s1 = substring(s,0,8)
     val s2 = substring(s,8,8)
     val s3 = substring(s,16,8)
     val s4 = substring(s,24,8)
 in
   (word32_of_string s1,word32_of_string s2,
    word32_of_string s3,word32_of_string s4)
 end;

val zero32 = ``0w:word32``;
val n8 = term_of_int 8;
val n16 = term_of_int 16;
val n24 = term_of_int 24;
val n256 = Arbnum.fromInt 256;

(* Build a nested let-expression that assembles a word from its bytes.       *)
(* Each non-zero byte among b1, b2, b3 is bound to a fresh variable vi and   *)
(* then rebound to vi shifted left by 24, 16 or 8; a final variable is       *)
(* bound to the word-or of those variables with b4 and is the body.          *)
fun bytes_to_let (b1,b2,b3,b4) =
 let val plist = mapPartial
                   (fn p as (b,s) => if b = zero32 then NONE else SOME p)
                   [(b1,n24), (b2,n16), (b3,n8)]
     val plist' = enumerate 0 plist
     val plist'' = map (fn (i,p as (b,_)) =>
                          (mk_var("v"^Int.toString i,type_of b),p)) plist'
     val vlist = list_mk_word_or (map fst plist'' @ [b4])
     fun foo (v,(c,s)) = ((v,c),(v,mk_word_lsl(v,s)))
     val vb4 = mk_var("v"^Int.toString(length plist''), type_of b4)
     val plist3 = pflat(map foo plist'') @ [(vb4,vlist)]
 in list_mk_anylet (map single plist3,vb4)
 end;

(* Conversion proving  |- c = <let-expression from bytes_to_let>  for a word *)
(* literal c whose numeral is at least 256.  Fails with HOL_ERR on smaller   *)
(* literals and on literals whose bit pattern is not exactly 32 characters.  *)
(* The length check keeps chunk from raising Subscript on narrower words;    *)
(* DEPTH_CONV only recovers from HOL_ERR.                                    *)
fun IMMEDIATE_CONST_CONV c =
 let val n = numSyntax.dest_numeral(fst(dest_n2w c))
 in if Arbnum.<(n,n256) then failwith "CONST_CONV" else
    let val bstr = bit_pattern c
        val _ = if String.size bstr = 32 then ()
                else failwith "CONST_CONV"
        val res = bytes_to_let (chunk bstr)
    in EQT_ELIM (wordsLib.WORDS_CONV (mk_eq(c,res)))
    end
 end;

(*---------------------------------------------------------------------------*)
(* Want to remove returned constants, in favour of returned registers.       *)
(* So "if P then c else d" where c or d are constants, needs to become       *)
(*                                                                           *)
(*   if P then (let x = c in x) else (let x = d in x)                        *)
(*---------------------------------------------------------------------------*)

(* Given the conditional tm and theorems rewriting its two arms, return the  *)
(* theorem rewriting both arms of the conditional with the test unchanged.   *)
fun MK_COND tm th1 th2 =
 let val core = rator(rator tm)
     val thm1 = MK_COMB (REFL core, th1)
     val thm2 = MK_COMB (thm1,th2)
 in thm2
 end;

(* |- c = (let v = c in v)  for a fresh variable v. *)
fun letify c =
 let val v = genvar (type_of c)
     val tm = mk_let (mk_abs(v,v),c)
 in SYM(BETA_RULE (REWR_CONV LET_THM tm))
 end;

(* Wrap every arm of a conditional that is a constant or word literal in a   *)
(* trivial let.  Fails when neither arm qualifies.                           *)
fun COND_CONST_ELIM_CONV tm =
 let val (t,a1,a2) = dest_cond tm
 in case (is_const a1 orelse is_word_literal a1,
          is_const a2 orelse is_word_literal a2)
     of (true,true) => MK_COND tm (letify a1) (letify a2)
      | (true,false) => MK_COND tm (letify a1) (REFL a2)
      | (false,true) => MK_COND tm (REFL a1) (letify a2)
      | (false,false) => failwith ""
 end
  handle HOL_ERR _ => raise ERR "COND_CONST_ELIM_CONV" "";

(*---------------------------------------------------------------------------*)
(* Expand constants, and eliminate returned constants (return values must    *)
(* be in registers).                                                         *)
(*---------------------------------------------------------------------------*)

val pass0 = CONV_RULE (DEPTH_CONV IMMEDIATE_CONST_CONV);
val pass0a = CONV_RULE (DEPTH_CONV COND_CONST_ELIM_CONV);
val pass0b = Ho_Rewrite.PURE_REWRITE_RULE [FLATTEN_LET];

(*---------------------------------------------------------------------------*)
(* Compiling a list of functions.                                            *)
(*---------------------------------------------------------------------------*)

(* Name of the constant at the head of the left-hand side of th. *)
fun defname th =
  fst(dest_const(fst(strip_comb(lhs(snd(strip_forall(concl th)))))));

(* Fold comp over a list of definitions.  Each successful result is pushed   *)
(* onto the environment passed to later calls.  Returns PASS with the        *)
(* results in input order, or FAIL with the environment built so far         *)
(* (most recent first) and the definitions from the failing one onwards.     *)
fun compenv comp =
 let fun compile (env,[]) = PASS(rev env)
       | compile (env,h::t) =
          let val name = defname h
          in
            print ("Compiling "^name^" ... ");
            case total comp (env,h)
             of SOME def1 => (print "succeeded.\n"; compile(def1::env,t))
              | NONE => (print "failed.\n"; FAIL(env,h::t))
          end
 in
    compile
 end;

(*---------------------------------------------------------------------------*)
(* Compile a list of definitions, accumulating the environment.              *)
(*---------------------------------------------------------------------------*)

fun complist passes deflist = compenv passes ([],deflist);

(*---------------------------------------------------------------------------*)
(* Basic flattening via CPS and unique names                                 *)
(*---------------------------------------------------------------------------*)

fun pass1 def = SSA_RULE (pass0b (pass0a (pass0 (normalize def))));

(*---------------------------------------------------------------------------*)
(* All previous, plus inlining and optimizations                             *)
(*---------------------------------------------------------------------------*)

fun pass2 (env,def) =
 let val def1 = pass1 def
 in
   SSA_RULE (optimize_norm env def1)
 end;

(*---------------------------------------------------------------------------*)
(* All previous, and closure conversion.                                     *)
(*---------------------------------------------------------------------------*)

fun pass3 (env,def) =
 let val def1 = pass2 (env,def)
 in
   (* Fall back to the pass2 result when closure conversion fails. *)
   case total closure_convert def1
    of SOME thm => SSA_RULE (optimize_norm env thm)
     | NONE => def1
 end;

(*---------------------------------------------------------------------------*)
(* All previous, and register allocation.                                    *)
(*---------------------------------------------------------------------------*)

fun pass4 (env,def) =
 let val def1 = pass3 (env,def)
 in
   reg_alloc def1
 end;

(*---------------------------------------------------------------------------*)
(* Different pass4, in which some intermediate steps are skipped.            *)
(* The env component is ignored.                                             *)
(*---------------------------------------------------------------------------*)

fun pass4a (env,def) =
 let val def1 = pass1 def
     val def2 = reg_alloc def1
 in
   def2
 end;

val compile1 = complist (fn (e,d) => pass1 d);
val compile2 = complist pass2;
val compile3 = complist pass3;
val compile4 = complist pass4;
val compile4a = complist pass4a;

(*---------------------------------------------------------------------------*)
(* Some simplifications used after compilation                               *)
(*---------------------------------------------------------------------------*)

val lift_cond_above_let = Q.prove
(`(let x = (if v1 then v2 else v3) in rst x) =
  if v1 then (let x = v2 in rst x) else (let x = v3 in rst x)`,
 RW_TAC std_ss [LET_DEF]);

val lift_cond_above_let1 = Q.prove
(`(let x = val in (if v1 then v2 x else v3 x)) =
  if v1 then (let x = val in v2 x) else (let x = val in v3 x)`,
 RW_TAC std_ss [LET_DEF]);

val lift_cond_above_trivlet = Q.prove
(`(let x = (if v1 then v2 else v3) in x) =
  if v1 then v2 else v3`,
 SIMP_TAC std_ss [LET_DEF]);

val id_let = Q.prove
(`LET (\x.x) y = y`,
 SIMP_TAC std_ss [LET_DEF]);

(*---------------------------------------------------------------------------*)
(* Eliminates "let x = x in M" redundancies                                  *)
(*---------------------------------------------------------------------------*)

fun ID_BIND_CONV tm =
 let open pairSyntax
     val (abs,r) = dest_let tm
 in
   if bvar abs = r
    then RIGHT_BETA (REWR_CONV LET_THM tm)
    else failwith "ID_BIND_CONV"
 end;

(* Head constant and arguments of the left-hand side of the first conjunct   *)
(* of eqns.                                                                  *)
fun head eqns =
  strip_comb (lhs(snd(strip_forall(hd (strip_conj(concl eqns))))));

(*---------------------------------------------------------------------------*)
(* Define the function given by quotation q, run it through pass4a, tidy     *)
(* the result, and make a second definition from the tidied equation,        *)
(* under the same constant name.  Returns the conjoined equations of the     *)
(* second definition and its induction theorem, if any.  The env argument    *)
(* is handed to pass4a, which ignores it.                                    *)
(*                                                                           *)
(* Also want to show that the second definition satisfies the rec. eqn of    *)
(* the first definition.                                                     *)
(*---------------------------------------------------------------------------*)

fun compileDefine (env,q) =
 let open TotalDefn pairSyntax
     val _ = HOL_MESG "Initial definition"
     val def1 = Define q
     val (const,_) = head def1
     val cinfo as (cname,_) = dest_const const
     val cvar = mk_var cinfo
     val compiled = pass4a(env,def1)
     val args = fst(dest_pabs(rhs(concl compiled)))
     val th1 = CONV_RULE (RHS_CONV pairLib.GEN_BETA_CONV)
                         (AP_THM compiled args)
     val th2 = CONV_RULE (RHS_CONV (SIMP_CONV bool_ss
                  [lift_cond_above_let,lift_cond_above_let1,
                   lift_cond_above_trivlet,FLATTEN_LET,
                   COND_RATOR, COND_RAND])) th1
     val th3 = CONV_RULE (DEPTH_CONV ID_BIND_CONV) th2
     (* Replace the defined constant by a variable of the same name so the  *)
     (* compiled equation can be given as a fresh definition.               *)
     val vtm = subst [const |-> cvar] (concl th3)
     val _ = HOL_MESG "Second definition"
     (* Report a failed second definition as a HOL_ERR, not a Bind. *)
     val (defn2,_) =
        case std_apiDefine (cname,vtm)
         of PASS res => res
          | _ => raise ERR "compileDefine" "second definition failed"
     val def2 = LIST_CONJ (Defn.eqns_of defn2)
     val ind = Defn.ind_of defn2
 in
   (def2,ind)
 end;

(*---------------------------------------------------------------------------*)
(* Join the front end with Magnus' backend                                   *)
(*---------------------------------------------------------------------------*)

(* Returns (definition, induction theorem or TRUTH, arm_compile theorem). *)
fun test_compile' style q =
 let val (def, indopt) = compileDefine ([],q)
     val indth = (case indopt of NONE => TRUTH | SOME ind => ind)
     val (th,strs) = arm_compile (SPEC_ALL def) indth style
 in
   (def,indth,th)
 end;

fun test_compile q = test_compile' InLineCode q;
fun test_compile_proc q = test_compile' SimpleProcedure q;

val _ = abbrev_code := true;
val _ = reset_compiled();

(*
val (load_751_def,_,load_751_arm) =
 test_compile_proc `
   load_751 = let r10 = 2w:word32 in
              let r10 = r10 << 8 in
              let r10 = r10 + 239w in r10`;

val load_751_def = Define
  `load_751 = let r10 = 2w:word32 in
              let r10 = r10 << 8 in
              let r10 = r10 + 239w in r10`;

val (load_751_arm,_) = arm_compile load_751_def TRUTH InLineCode;
*)

val (field_neg_def,_,field_neg_arm) =
 test_compile
   `field_neg (x:word32) = if x = 0w then (0w:word32) else 751w - x`;

val field_neg_triple =
  REWRITE_RULE [fetch "-" "field_neg_code1_def"] field_neg_arm;

val (field_add_def,_,field_add_arm) =
 test_compile
   `field_add (x:word32,y:word32) =
      let z = x + y
      in if z < 751w then z else z - 751w`;

val field_add_triple =
  REWRITE_RULE [fetch "-" "field_add_code1_def"] field_add_arm;

val (field_sub_def,_,field_sub_arm) =
 test_compile
   `field_sub (x:word32,y:word32) = field_add(x,field_neg y)`;

REWRITE_RULE [fetch "-" "field_sub_code1_def"] field_sub_arm;

val (field_mult_aux_def,_,field_mult_aux_arm) =
 test_compile
   `field_mult_aux (x:word32,y:word32,acc:word32) =
      if y = 0w then acc
      else let x' = field_add (x,x) in
           let y' = y >>> 1 in
           let acc' = (if y && 1w = 0w then acc else field_add (acc,x))
           in field_mult_aux (x',y',acc')`;

(* NOTE(review): the sibling lookups use "<name>_code1_def"; this name does  *)
(* not follow that pattern.  Confirm it against what arm_compile stores.     *)
REWRITE_RULE [fetch "-" "field_mult_aux1_def"] field_mult_aux_arm;

(* test_compile returns a triple, so bind all three components. *)
val (field_mult_def,_,field_mult_arm) =
 test_compile
   `field_mult (x,y) = field_mult_aux (x,y,0w)`;

val (field_exp_aux_def, _, field_exp_aux_arm) =
 test_compile
   `field_exp_aux (x:word32,n:word32,acc:word32) =
      if n = 0w then acc
      else let x' = field_mult (x,x) in
           let n' = n >>> 1 in
           let acc' = (if n && 1w = 0w then acc else field_mult (acc,x))
           in field_exp_aux (x',n',acc')`;

val (field_exp_def,NONE) =
 compileDefine([],
   `field_exp (x,n) = field_exp_aux (x,n,1w)`);

val (field_inv_def,NONE) =
 compileDefine ([],
   `field_inv x = field_exp (x,749w)`);

val (field_div_def,NONE) =
 compileDefine([],
   `field_div (x,y) = field_mult (x,field_inv y)`);

val (curve_neg_def,NONE) =
 compileDefine([],
   `curve_neg (x1,y1) =
      if (x1 = 0w) /\ (y1 = 0w) then (0w,0w)
      else let y = field_sub
                     (field_sub (field_neg y1,field_mult (0w,x1)),1w)
           in (x1,y)`);

val (curve_double_def,NONE) =
 compileDefine([],
   `curve_double (x1,y1) =
      if (x1 = 0w) /\ (y1 = 0w) then (0w,0w)
      else
        let d = field_add
                  (field_add (field_mult (2w,y1),
                              field_mult (0w,x1)),1w)
        in
          if d = 0w then (0w,0w)
          else
            let l = field_div
                      (field_sub
                         (field_add
                            (field_add
                               (field_mult(3w,field_exp (x1,2w)),
                                field_mult(field_mult (2w,0w),x1)),750w),
                          field_mult (0w,y1)),d) in
            let m = field_div
                      (field_sub
                         (field_add
                            (field_add
                               (field_neg (field_exp (x1,3w)),
                                field_mult (750w,x1)),
                             field_mult (2w,0w)),
                          field_mult (1w,y1)),d) in
            let x = field_sub
                      (field_sub
                         (field_add(field_exp (l,2w),
                                    field_mult (0w,l)),0w),
                       field_mult (2w,x1)) in
            let y = field_sub
                      (field_sub
                         (field_mult
                            (field_neg (field_add (l,0w)),x),m),1w)
            in
              (x,y)`);

val (curve_add_def,NONE) =
 compileDefine([],
   `curve_add (x1,y1,x2,y2) =
      if (x1 = x2) /\ (y1 = y2) then curve_double (x2,y2) else
      if (x1 = 0w) /\ (y1 = 0w) then (x2,y2) else
      if (x2 = 0w) /\ (y2 = 0w) then (x1,y1) else
      if x1 = x2 then (0w,0w)
      else
        let d = field_sub (x2,x1) in
        let l = field_div (field_sub (y2,y1),d) in
        let m = field_div
                  (field_sub (field_mult (y1,x2),
                              field_mult (y2,x1)),d) in
        let x = field_sub
                  (field_sub
                     (field_sub
                        (field_add (field_exp (l,2w),
                                    field_mult (0w,l)),0w),x1),x2) in
        let y = field_sub
                  (field_sub
                     (field_mult
                        (field_neg (field_add (l,0w)),x),m),1w)
        in
          (x,y)`);

(* curve_mult_aux is recursive, so compileDefine is expected to return an    *)
(* induction theorem (SOME _) here; a NONE pattern would raise Bind.         *)
val (curve_mult_aux_def,_) =
 compileDefine([],
   `curve_mult_aux (x,y,n:word32,acc_x,acc_y) =
      if n = 0w then (acc_x:word32,acc_y:word32)
      else
        let (x',y') = curve_double (x,y) in
        let n' = n >>> 1 in
        let (acc_x',acc_y') =
              (if n && 1w = 0w then (acc_x,acc_y)
               else curve_add (x,y,acc_x,acc_y))
        in
          curve_mult_aux (x',y',n',acc_x',acc_y')`);

val curve_mult_def =
 Define
   `curve_mult (x,y,n) = curve_mult_aux (x,y,n,0w,0w)`;