structure funcCall (* :> funcCall *) =
struct

(* app load ["NormalTheory", "basic", "regAlloc"] *)

open HolKernel Parse boolLib bossLib;
open pairLib pairSyntax PairRules NormalTheory basic;

(* The constant Normal$atom, and a constructor that applies it to a term after
   instantiating the constant's type variable 'a to the term's type. *)
val atom_tm = prim_mk_const{Name="atom",Thy="Normal"}
fun mk_atom tm = mk_comb (inst [alpha |-> type_of tm] atom_tm,tm)

(*----------------------------------------------------------------------------------------------*)
(* Auxiliary structures and global settings                                                     *)
(*----------------------------------------------------------------------------------------------*)

structure M = Binarymap
structure S = Binaryset

(* Type given to the memory-slot variables ("m0", "m1", ...) created by save/restore.
   caller_save_call resets it to the type reported by pre_check. *)
val VarType = ref (Type `:word32`) (* numSyntax.num *)

(*----------------------------------------------------------------------------------------------*)
(* Pre-defined variables and functions                                                          *)
(*----------------------------------------------------------------------------------------------*)

(* Registers and memory slots are recognised purely by the first character of the
   printed term: "r..." is a register, "m..." is a memory slot.
   NOTE(review): relies on term_to_string printing the bare name (e.g. no type
   annotation when show_types is on) -- confirm. *)
fun is_reg x = (String.sub (term_to_string x,0) = #"r")
fun is_mem x = (String.sub (term_to_string x,0) = #"m")

(* Total order on terms by their printed form; the key order of the register/slot
   sets and of the function map. Terms that print identically compare EQUAL. *)
fun tvarOrder (t1:term,t2:term) =
  String.compare (term_to_string t1, term_to_string t2);

(* Is an expression a bare function symbol, i.e. a term of function type that is not
   itself an application (e.g. a function constant or variable)? Returns false when
   the type cannot be destructed. *)
fun is_fun exp =
  not (is_comb exp) andalso
  (#1 (dest_type (type_of exp)) = "fun")
  handle HOL_ERR _ => false;

(*----------------------------------------------------------------------------------------------*)
(* Function calls in a caller-save style.                                                       *)
(*----------------------------------------------------------------------------------------------*)

(* Traverse a function body and return (wS, slot) where
     wS   = the set of registers bound by let-patterns in the body, and
     slot = the largest index k of a memory-slot variable "mk" occurring in the body
            outside let-patterns (0 if there is none).
   Leaves that merely print with a leading "m" but are not variables with a numeric
   suffix are ignored.
   NOTE(review): format_call passes this value to save/restore as the first slot to
   allocate, so slot k itself is reused. If the first *unused* slot (k+1) is what is
   intended, this is off by one -- confirm against the back end's memory layout. *)
fun process_body body =
  let
    fun traverse t (rS, wS) =
      if is_let t then
        let val (v,M,N) = dest_plet t
            val (rS1, wS1) = traverse M (rS, S.addList(wS, List.filter is_reg (strip_pair v)))
        in
          traverse N (rS1, wS1)
        end
      else if is_cond t then
        let val (J,M,N) = dest_cond t in
          ((traverse N) o (traverse M) o (traverse J)) (rS, wS)
        end
      else if is_pair t then
        let val (M,N) = dest_pair t in
          ((traverse N) o (traverse M)) (rS, wS)
        end
      else if is_pabs t then
        let val (M,N) = dest_pabs t in
          ((traverse N) o (traverse M)) (rS, wS)
        end
      else if is_comb t then
        let val (M,N) = dest_comb t in
          ((traverse N) o (traverse M)) (rS, wS)
        end
      else if is_reg t orelse is_mem t then
        (S.add(rS, t), wS)
      else (rS, wS)

    (* Index k of a memory-slot variable named "mk"; NONE for constants and for names
       without a numeric suffix, which previously raised HOL_ERR / Option. *)
    fun slot_index v =
      if is_var v andalso is_mem v then
        Int.fromString (String.extract (#1 (dest_var v), 1, NONE))
      else NONE

    val (rS', wS') = traverse body (S.empty tvarOrder, S.empty tvarOrder)
    val max_slot =
      List.foldl (fn (v, i) =>
                    (case slot_index v of
                       SOME j => Int.max (i, j)
                     | NONE => i))
                 0 (S.listItems rS')
  in
    (wS', max_slot)
  end;

(* Map from function name to ((arguments, identify_output body), process_body body),
   filled in by investigate_def and read by format_call and mm. *)
val fmap = ref (M.mkDict tvarOrder);

(* Record in !fmap the interface and register/slot usage of the function defined by
   def. Handles both "f = \(args). body" and "f args = body" shaped definitions. *)
fun investigate_def def =
  let
    val (fname, fbody) = dest_eq (concl (SPEC_ALL def))
    (* not a paired abstraction: take the arguments from the left-hand side *)
    val (args,body) = dest_pabs fbody handle HOL_ERR _ => (#2 (dest_comb fname), fbody)
    val fname = if is_pabs fbody then fname else #1 (dest_comb fname)
  in
    fmap := M.insert(!fmap, fname, ((args, identify_output body), process_body body))
  end

(* Caller-save stores: for the i-th register r (0-based, in set order) of wS, wrap exp in
   "let m<next_slot+i> = atom r in ...". The last register's binding is outermost. *)
fun save (wS, next_slot) exp =
  #1 (List.foldl (fn (r, (e, slot)) =>
                    (mk_plet (mk_var("m" ^ Int.toString(slot) , !VarType), mk_atom r, e), slot + 1))
                 (exp, next_slot)
                 (S.listItems wS)
     )

(* Matching restores: for the i-th register r of wS, wrap exp in
   "let r = atom m<next_slot+i> in ...", i.e. the same slot save assigns to r. *)
fun restore (wS, next_slot) exp =
  #1 (List.foldl (fn (r, (e, slot)) =>
                    (mk_plet (r, mk_atom (mk_var("m" ^ Int.toString(slot), !VarType)), e), slot + 1))
                 (exp, next_slot)
                 (S.listItems wS)
     )

(*
  fmap := preprocess defs;
*)

(* The function currently being processed. Calls to it (recursive calls) are left
   untouched by caller_save and trav. Set by caller_save_call. *)
val tr_f = ref (``T``);

(* Build caller-save code for "let dst = fname src in cont", innermost last:
     save s2; move src into the callee's parameters src0; let dst0 = fname src0;
     move dst0 into dst; restore s2; cont.
   s2 = registers written by the callee plus its parameter and output variables, minus
   the variables of dst. regAlloc.parallel_move is used as (destination, source, body).
   Raises NotFound if fname has not been recorded by investigate_def. *)
fun format_call (fname, dst, src, cont) =
  let val ((src0, dst0), (s, next_slot)) = M.find(!fmap, fname)
      (* handle NotFound => (investigate_def (DB.definition (#1 (dest_const fname) ^ "_def"));
                             M.find(!fmap, fname))
      *)
      val s1 = S.addList (S.addList(s, strip_pair src0), strip_pair dst0)
      val s2 = S.difference(s1, S.addList(S.empty tvarOrder, strip_pair dst))

      val t1 = restore (s2, next_slot) cont
      val t2 = regAlloc.parallel_move dst dst0 t1
      val t3 = mk_plet(dst0, mk_comb(fname, src0), t2)
      val t4 = regAlloc.parallel_move src0 src t3
      val t5 = save (s2, next_slot) t4
  in
    t5
  end

(* Rewrite every non-recursive call "let v = f y in N" in t via format_call, recursing
   through lets, conditionals, pairs, paired abstractions and applications.
   In this file it is referenced only by the commented-out callerSave below. *)
fun caller_save t =
  if is_let t then
    let val (v,M,N) = dest_plet t in
      if is_comb M andalso not (is_atomic M) then
        let val (x,y) = dest_comb M in
          if is_fun x andalso not (x = !tr_f) then (* non-recursive function application *)
            let
              val (fname, dst, src, cont) = (x, v, y, caller_save N)
            in
              format_call (fname, dst, src, cont)
            end
          else
            mk_plet(v, caller_save M, caller_save N)
        end
      else
        mk_plet(v, caller_save M, caller_save N) (* not function application *)
    end
  else if is_cond t then
    let val (J,M,N) = dest_cond t in
      mk_cond (J, caller_save M, caller_save N)
    end
  else if is_pair t then
    let val (M,N) = dest_pair t in
      mk_pair (caller_save M, caller_save N)
    end
  else if is_pabs t then
    let val (M,N) = dest_pabs t in
      mk_pabs (caller_save M, caller_save N)
    end
  else if is_comb t then
    let val (M,N) = dest_comb t in
      mk_comb(caller_save M, caller_save N)
    end
  else t

(* Function "trav" traverses a term and adds pre-call saving and post-call restoring
   instructions for each non-recursive function call. It also makes the outputs of the
   two branches of a conditional match, i.e. both branches have the same outputs.
   Returns (t', output):
     - output is NONE until the first pair/atomic result is reached, which then becomes
       SOME of that result;
     - any later pair/atomic result is moved into that output with
       regAlloc.parallel_move. *)
fun trav t output =
  if is_let t then
    let val (v,M,N) = dest_plet t
        val (M', _) = trav M NONE
    in
      if is_comb M andalso not (is_atomic M) then
        let val (x,y) = dest_comb M in
          if is_fun x andalso not (x = !tr_f) then (* non-recursive function application *)
            let
              val (N', output') = trav N output
              val (fname, dst, src, cont) = (x, v, y, N')
              val t' = format_call (fname, dst, src, cont)
            in
              (t', output')
            end
          else
            let val (N', output') = trav N output
            in (mk_plet(v, M', N'), output') end
        end
      else
        let val (N', output') = trav N output
        in (mk_plet(v, M', N'), output') end (* not function application *)
    end
  else if is_cond t then
    let val (J,M1,M2) = dest_cond t
        val (M1', output1) = trav M1 output
        (* the else-branch is forced onto the then-branch's outputs *)
        val (M2', _) = trav M2 output1
    in
      (mk_cond(J, M1', M2'), output1)
    end
  else if is_pair t orelse is_atomic t then
    (case output of
       NONE => (t, SOME t)
     | SOME x => (regAlloc.parallel_move x t x, output))
  else if is_pabs t then
    let val (M,N) = dest_pabs t
        val (N', _) = trav N NONE
    in
      (mk_pabs (M, N'), output)
    end
  else (t, output)

(* Process function calls in a caller-save style.
   Side effects: sets VarType and tr_f, and records def in !fmap.
   If pre_check accepts the definition, the body is transformed by trav. The result is
   justified in HOL: the transformed body, simplified with LET_ATOM, must be
   alpha-equivalent to the original body (ALPHA fails otherwise). It is then rewritten
   into def and cleaned up with ATOM_ID and ELIM_USELESS_LET. Otherwise def is returned
   unchanged. *)
fun caller_save_call def =
  let val (fname, fbody) = dest_eq (concl (SPEC_ALL def))
      (* not a paired abstraction: take the arguments from the left-hand side *)
      val (args,body) = dest_pabs fbody handle HOL_ERR _ => (#2 (dest_comb fname), fbody)
      val (sane,var_type) = pre_check(args,body)
      val fname = if is_pabs fbody then fname else #1 (dest_comb fname)
      val _ = (VarType := var_type; tr_f := fname)
      val _ = investigate_def def (* store the information of the current function into !fmap *)
  in if sane then
       let
         val (body', _) = trav body NONE
         val th0 = SYM (QCONV(SIMP_CONV pure_ss [LET_ATOM]) body')
         val (r, _) = dest_eq(concl th0)
         val lem1 = ALPHA body r
         val th1 = TRANS lem1 th0
         val th2 = REWRITE_RULE [Once th1] def
         val th3 = REWRITE_RULE [ATOM_ID] th2
       in SIMP_RULE bool_ss [ELIM_USELESS_LET] th3
       end
     else def
  end

(*
fun callerSave defs =
  let
    fun one_fun def =
      let val (fname, fbody) = dest_eq (concl (SPEC_ALL def))
          val (args,body) = dest_pabs fbody handle _ => (#2 (dest_comb fname), fbody)
          val (sane,var_type) = pre_check(args,body)
          val fname = if is_pabs fbody then fname else #1 (dest_comb fname)
          val _ = (VarType := var_type; tr_f := fname)
      in if sane then
           let
             val body' = caller_save body
             val th0 = SYM (QCONV(SIMP_CONV pure_ss [LET_ATOM]) body')
             val th1 = REWRITE_RULE [ELIM_USELESS_LET] th0
             val th2 = REWRITE_RULE [Once th1] def
             val th3 = REWRITE_RULE [ATOM_ID] th2
           in th3
           end
         else def
      end
  in
    List.map one_fun defs
  end
*)

(*----------------------------------------------------------------------------------------------*)
(* For debugging                                                                                *)
(*----------------------------------------------------------------------------------------------*)

(* Debug view of !fmap: one (fname, (src, dst, written registers, slot)) per entry. *)
fun mm () =
  List.map (fn (fname, ((src, dst), (wS, slot))) => (fname, (src, dst, S.listItems wS, slot)))
           (M.listItems (!fmap))

end