structure funCall =
struct

local
  open HolKernel Parse boolLib IRSyntax annotatedIR
  structure T = IntMapTable(type key = int fun getInt n = n);
  structure S = Binaryset
  structure IR = IRSyntax
in

(* Raised by the argument/result transfer helpers below when a list contains
   a value that cannot be transferred (e.g. reading into a constant). *)
exception invalidArgs;

(* Register number !numAvaiRegs (R10 with the default value) is used as the
   scratch register by write_one_arg / read_one_arg.
   NOTE(review): presumably registers below this number are the ones handed
   out by the register allocator -- confirm against regAllocation. *)
val numAvaiRegs = ref 10;

(* Comparison function on strings (built-in lexicographic order), returning
   an `order` as expected by ordered containers. *)
fun strOrder (s1:string, s2:string) =
  if s1 > s2 then GREATER
  else if s1 = s2 then EQUAL
  else LESS;

(* ------------------------------------------------------------------------- *)
(* The configuration for passing parameters, returning results and           *)
(* allocating stack space.                                                   *)
(* The stack goes upward.                                                    *)
(* ------------------------------------------------------------------------- *)
(*
   Address goes upward (i.e. from low address to high address)!!!

                              Content                  Address
   caller's ip              | saved pc              |    0
   caller's fp              | saved lr              |    1
                            | save sp               |    2
                            | save fp               |    3
                            | modified reg k        |
                            | modified reg k-1      |
                            | ...                   |
                            | local variable 0      |    4
                            | local variable 1      |    5
                            | ...                   |    .
                            | local variable n      |    .
                            | output k from callee  |    .
                            | ...                   |
                            | output 1 from callee  |
                            | output 0 from callee  |
                            | argument m'           |
                            | ...                   |
                            | argument 1            |
                            | argument 0            |
   caller's sp  callee's ip | saved pc              |
                callee's fp | saved lr              |
                            | save sp               |
                            | save fp               |
                            | modified reg k        |
                            | modified reg k-1      |
                            | ...                   |
                            | local variable 0      |
                            | local variable 1      |
                            | ...                   |
   callee's sp              | local variable n'     |

   NOTE(review): the Address column is illustrative; the fp-relative offsets
   actually used are the ones computed in calculate_relative_address.
*)

(* ------------------------------------------------------------------------- *)
(* Distinguish inputs from local variables.                                  *)
(* Calculate the fp-relative address of every temporary (TMEM) in a callee:  *)
(*  - if the temporary is an input, read it from the stack slot where the    *)
(*    caller put it: fp - 2 - k, where k is its position in the flattened    *)
(*    argument list;                                                         *)
(*  - otherwise it is a local variable of the current frame: the j-th        *)
(*    distinct local (numbered in order of first appearance) lives at        *)
(*    fp + 3 + numSavedRegs + j.                                             *)
(* Returns (args', ir', outs', number of local variables).                   *)
(* ------------------------------------------------------------------------- *)

fun calculate_relative_address (args, ir, outs, numSavedRegs) =
  let
    (* Identify those TMEMs that are actually arguments:
       TMEM n -> its position in the flattened argument list *)
    val argT =
      #1 (List.foldl (fn (IR.TMEM n, (t, i)) => (T.enter (t, n, i), i + 1)
                       | (_, (t, i)) => (t, i + 1))
                     (T.empty, 0)
                     (IR.pair2list args));

    val i = ref 0                    (* next free local-variable index *)
    val localT = ref (T.empty);      (* Table for the local variables *)

    (* For those TMEMs that are local variables, assign them stack slots in
       the order of their first appearance.  The numbering therefore depends
       on the (left-to-right) traversal order below. *)
    fun filter_mems (IR.TMEM n) =
          (case T.peek (argT, n) of
             SOME k => IR.MEM (IR.fromAlias IR.fp, ~2 - k)                   (* inputs *)
           | NONE =>
               (case T.peek (!localT, n) of
                  SOME j => IR.MEM (IR.fromAlias IR.fp, 3 + j + numSavedRegs) (* existing local variable *)
                | NONE =>
                    (localT := T.enter (!localT, n, !i);
                     i := !i + 1;
                     IR.MEM (IR.fromAlias IR.fp, 3 + (!i - 1) + numSavedRegs))))  (* new local variable *)
      | filter_mems v = v

    fun one_stm {oper = op1, dst = dst1, src = src1} =
          {oper = op1, dst = List.map filter_mems dst1, src = List.map filter_mems src1}

    fun adjust_exp (IR.PAIR (e1, e2)) = IR.PAIR (adjust_exp e1, adjust_exp e2)
      | adjust_exp e = filter_mems e

    fun adjust_info {ins = ins', outs = outs', context = context', fspec = fspec'} =
          {ins = adjust_exp ins', outs = adjust_exp outs',
           context = List.map adjust_exp context', fspec = fspec'}

    fun visit (SC (ir1, ir2, info)) =
          SC (visit ir1, visit ir2, adjust_info info)
      | visit (TR ((e1, rop, e2), ir, info)) =
          TR ((adjust_exp e1, rop, adjust_exp e2), visit ir, adjust_info info)
      | visit (CJ ((e1, rop, e2), ir1, ir2, info)) =
          CJ ((adjust_exp e1, rop, adjust_exp e2), visit ir1, visit ir2, adjust_info info)
      | visit (CALL (fname, pre, body, post, info)) =
          (* the pre/body/post of a call are not rewritten; only its annotation *)
          CALL (fname, pre, body, post, adjust_info info)
      | visit (STM l) =
          STM (List.map one_stm l)
      | visit (BLK (l, info)) =
          BLK (List.map one_stm l, adjust_info info);
  in
    (adjust_exp args, visit ir, adjust_exp outs, T.numItems (!localT))
  end

(* ------------------------------------------------------------------------- *)
(* Decrease and increase the value of register pt by the constant n         *)
(* ------------------------------------------------------------------------- *)

fun dec_p pt n = {oper = IR.msub, dst = [IR.REG pt], src = [IR.REG pt, IR.WCONST (Arbint.fromInt n)]}

fun inc_p pt n = {oper = IR.madd, dst = [IR.REG pt], src = [IR.REG pt, IR.WCONST (Arbint.fromInt n)]};


(* ------------------------------------------------------------------------- *)
(* Pre-call and post-call processing in compliance with the ARM Procedure    *)
(* Call Standard                                                             *)
(* ------------------------------------------------------------------------- *)

(* Callee prologue:
       MOV   ip, sp
       STMFA sp!, {..., fp, ip, lr, pc}
       SUB   fp, ip, #1
       SUB   sp, sp, #n          -- skip the n local variables
   rs : the modified registers to save; pushed before fp, ip, lr, pc
   n  : the number of local-variable slots to reserve *)

fun entry_blk rs n =
  [ {oper = IR.mmov, dst = [IR.REG (IR.fromAlias IR.ip)], src = [IR.REG (IR.fromAlias IR.sp)]},
    {oper = IR.mpush, dst = [IR.REG (IR.fromAlias IR.sp)],
     src = rs @ [IR.REG (IR.fromAlias IR.fp), IR.REG (IR.fromAlias IR.ip),
                 IR.REG (IR.fromAlias IR.lr), IR.REG (IR.fromAlias IR.pc)]},
    {oper = IR.msub, dst = [IR.REG (IR.fromAlias IR.fp)],
     src = [IR.REG (IR.fromAlias IR.ip), IR.WCONST Arbint.one]},
    dec_p (IR.fromAlias IR.sp) n   (* skip local variables *)
  ]

(* Callee epilogue:
       SUB   sp, fp, #(3 + number of modified registers)
             -- skip the saved lr, sp, fp and the modified registers
       LDMFD sp, {..., fp, sp, pc}
   rs : the modified registers saved by entry_blk, restored here *)

fun exit_blk rs =
  [ {oper = IR.msub, dst = [IR.REG (IR.fromAlias IR.sp)],
     src = [IR.REG (IR.fromAlias IR.fp), IR.WCONST (Arbint.fromInt (3 + length rs))]},
    {oper = IR.mpop,
     dst = rs @ [IR.REG (IR.fromAlias IR.fp), IR.REG (IR.fromAlias IR.sp), IR.REG (IR.fromAlias IR.pc)],
     src = [IR.REG (IR.fromAlias IR.sp)]}
  ];


(* ------------------------------------------------------------------------- *)
(* Given a list of registers, memory slots and constants, group consecutive  *)
(* registers together to be used by mpush and mpop.                          *)
(* Each segment is (elements, isRegisterSegment, index), where index is the  *)
(* position in the input list of the segment's first element.                *)
(* For example, [r1,r2,m1,m3,r3,4w,r4,r5] is segmented to                    *)
(*   [([r1,r2],true,0), ([m1],false,2), ([m3],false,3), ([r3],true,4),       *)
(*    ([4w],false,5), ([r4,r5],true,6)]                                      *)
(* ------------------------------------------------------------------------- *)

fun mk_reg_segments argL =
  let
    (* Proceeds in reverse order of the list (foldr): the head of the
       accumulator is the segment immediately to the right of the current
       element.  A register extends that segment only when it is itself a
       register segment; any other value always starts its own segment. *)
    fun one_arg (IR.REG r, (segs, i)) =
          (case segs of
             (regL, true, _) :: rest => ((IR.REG r :: regL, true, i) :: rest, i - 1)
           | _ => (([IR.REG r], true, i) :: segs, i - 1))
      | one_arg (exp, (segs, i)) =
          (([exp], false, i) :: segs, i - 1)

    val (segs, _) = List.foldr one_arg ([], length argL - 1) argL
  in
    segs
  end

(* ------------------------------------------------------------------------- *)
(* Given a list of registers, generate a mpush or mpop statement based at r. *)
(* The pointer r is assumed to be in the right position.                     *)
(* If the list contains only one register, use mldr (from [r,#1]) or mstr    *)
(* (to [r,#0]) instead.                                                      *)
(* ------------------------------------------------------------------------- *)

fun mk_ldm_stm isPop r dataL =
  if isPop then
    if length dataL = 1 then
      {oper = IR.mldr, dst = dataL, src = [IR.MEM (r, 1)]}
    else
      {oper = IR.mpop, dst = dataL, src = [IR.REG r]}
  else
    if length dataL = 1 then
      {oper = IR.mstr, dst = [IR.MEM (r, 0)], src = dataL}
    else
      {oper = IR.mpush, dst = [IR.REG r], src = dataL}

(* ------------------------------------------------------------------------- *)
(* Write one value to the memory slot referred to by (regNo, offset).        *)
(*  - a register is stored directly;                                         *)
(*  - a memory operand is first loaded into the scratch register             *)
(*    REG (!numAvaiRegs) and then stored;                                    *)
(*  - anything else (presumably a constant: NONCST n or WCONST w) is moved   *)
(*    into the same scratch register and then stored.                        *)
(* The scratch register is the same one read_one_arg uses, so that changing  *)
(* numAvaiRegs never makes this clobber an allocatable register.             *)
(* ------------------------------------------------------------------------- *)

fun write_one_arg (IR.MEM v) (regNo, offset) =
      [ {oper = IR.mldr, dst = [IR.REG (!numAvaiRegs)], src = [IR.MEM v]},
        {oper = IR.mstr, dst = [IR.MEM (regNo, offset)], src = [IR.REG (!numAvaiRegs)]} ]
  | write_one_arg (IR.REG r) (regNo, offset) =
      [ {oper = IR.mstr, dst = [IR.MEM (regNo, offset)], src = [IR.REG r]} ]
  | write_one_arg v (regNo, offset) =   (* v = NONCST n or WCONST w *)
      [ {oper = IR.mmov, dst = [IR.REG (!numAvaiRegs)], src = [v]},
        {oper = IR.mstr, dst = [IR.MEM (regNo, offset)], src = [IR.REG (!numAvaiRegs)]} ]

(* ------------------------------------------------------------------------- *)
(* Read one value from the memory slot referred to by (regNo, offset).       *)
(*  - if the destination is a register, load it directly;                   *)
(*  - if it is a memory location, load the value into the scratch register   *)
(*    REG (!numAvaiRegs) first, then store it there.                         *)
(* The destination cannot be a constant: raises invalidArgs.                 *)
(* ------------------------------------------------------------------------- *)

fun read_one_arg (IR.REG r) (regNo, offset) =
      [ {oper = IR.mldr, dst = [IR.REG r], src = [IR.MEM (regNo, offset)]} ]
  | read_one_arg (IR.MEM v) (regNo, offset) =
      [ {oper = IR.mldr, dst = [IR.REG (!numAvaiRegs)], src = [IR.MEM (regNo, offset)]},
        {oper = IR.mstr, dst = [IR.MEM v], src = [IR.REG (!numAvaiRegs)]} ]
  | read_one_arg _ _ =
      raise invalidArgs

(* ------------------------------------------------------------------------- *)
(* Push a list of values (constants, registers or memory) onto the stack     *)
(* addressed by regNo:                                                       *)
(*        [1,2,3,...]                                                        *)
(*        old pointer | 1 |                                                  *)
(*                    | 2 |                                                  *)
(*                    | 3 |                                                  *)
(*                     ...                                                   *)
(*        new pointer                                                        *)
(* Runs of registers are pushed with one mpush; other values are written     *)
(* individually relative to the current pointer.  At the end regNo has been  *)
(* decreased by length argL in total.                                        *)
(* ------------------------------------------------------------------------- *)

fun pushL regNo argL =
  let
    (* How far regNo has been moved so far; assumes mpush (which lists regNo
       as a destination) advances regNo by the number of registers pushed. *)
    val offset = ref 0;

    fun one_seg (regL, true, i) =
          if length regL = 1 then
            write_one_arg (hd regL) (regNo, !offset - i)   (* relative offset should be negative *)
          else
            let val k = !offset in
              ( offset := i + length regL;
                [dec_p regNo (i - k),
                 (* reverse the regL in accordance with the order of STM *)
                 mk_ldm_stm false regNo (List.rev regL)] )
            end
      | one_seg ([v], false, i) =
          write_one_arg v (regNo, !offset - i)
      | one_seg _ = raise invalidArgs
  in
    (List.foldl (fn (x, s) => s @ one_seg x) [] (mk_reg_segments argL)) @
    [dec_p regNo (length argL - !offset)]
  end

(* ------------------------------------------------------------------------- *)
(* Pop from the stack addressed by regNo a list of values into registers or  *)
(* memory locations:                                                         *)
(*                     ...                                                   *)
(*                    | 3 |                                                  *)
(*                    | 2 |                                                  *)
(*                    | 1 |                                                  *)
(*        pointer                                                            *)
(* is read into the list [1,2,3,...].  At the end regNo has been increased   *)
(* by length argL in total.                                                  *)
(* ------------------------------------------------------------------------- *)

fun popL regNo argL =
  let
    val offset = ref 0;

    (* NOTE(review): unlike pushL, offset is not advanced past the registers
       of a multi-register segment.  This matches an mpop that does not write
       back regNo (its dst lists only the loaded registers) -- confirm. *)
    fun one_seg (regL, true, i) =
          if length regL = 1 then
            read_one_arg (hd regL) (regNo, i - !offset + 1)   (* relative address should be positive *)
          else
            let val k = !offset in
              ( offset := i;
                [inc_p regNo (i - k),
                 mk_ldm_stm true regNo regL] )
            end
      | one_seg ([v], false, i) =
          read_one_arg v (regNo, i - !offset + 1)
      | one_seg _ = raise invalidArgs
  in
    (List.foldl (fn (x, s) => s @ one_seg x) [] (mk_reg_segments argL)) @
    [inc_p regNo (length argL - !offset)]
  end

(* ------------------------------------------------------------------------- *)
(* The caller passes the parameters to the callee.                           *)
(* All arguments are passed through the stack.                               *)
(* Stack status after passing:                                               *)
(*            ...                                                            *)
(*           | arg 3 |                                                       *)
(*           | arg 2 |                                                       *)
(*        sp | arg 1 |                                                       *)
(* ------------------------------------------------------------------------- *)

fun pass_args argL =
  pushL (IR.fromAlias IR.sp) (List.rev argL)

(* ------------------------------------------------------------------------- *)
(* The callee obtains the arguments passed by the caller through the stack.  *)
(* Every element of argL (a register or a memory location) is loaded from    *)
(* the slot where pass_args put it, with popL relative to ip.                *)
(* The callee's ip equals the caller's sp (entry_blk starts with MOV ip, sp),*)
(* so ip addresses the argument area.                                        *)
(* Note that the register allocation assumes the above convention.           *)
(* ------------------------------------------------------------------------- *)

fun get_args argL =
  popL (IR.fromAlias IR.ip) argL

(* ------------------------------------------------------------------------- *)
(* The callee passes the results to the caller.                              *)
(* All results are passed through the stack.                                 *)
(* sp is first set to fp + numArgs + length outL + 1, then the results are   *)
(* pushed.  Stack status after passing:                                      *)
(*            ...                                                            *)
(*           | result 3 |                                                    *)
(*           | result 2 |                                                    *)
(*           | result 1 |                                                    *)
(*        sp                                                                 *)
(* ------------------------------------------------------------------------- *)

fun send_results outL numArgs =
  let
    (* skip the arguments and the stored pc, then go to the position right
       before the first output *)
    val sOffset = numArgs + length outL + 1;
    val stms = pushL (IR.fromAlias IR.sp) (List.rev outL)
  in
    {oper = IR.madd,
     dst = [IR.REG (IR.fromAlias IR.sp)],
     src = [IR.REG (IR.fromAlias IR.fp), IR.WCONST (Arbint.fromInt sOffset)]}
    :: stms
  end

(* ------------------------------------------------------------------------- *)
(* The caller retrieves the results passed by the callee through the stack. *)
(* Stack status:                                                             *)
(*            ...                                                            *)
(*           | result 3 |                                                    *)
(*           | result 2 |                                                    *)
(*           | result 1 |                                                    *)
(*        sp |          |                                                    *)
(* ------------------------------------------------------------------------- *)

fun get_results outL =
  popL (IR.fromAlias IR.sp) outL

(* ------------------------------------------------------------------------- *)
(* Function Call: Pre-processing, Post-processing and Adjusted Body          *)
(* ------------------------------------------------------------------------- *)

(* Compute the ins/outs annotations of the pre, body and post parts of a
   call, and the registers rs' that must be saved across it.
   rs' = callee's modified registers (rs), minus those receiving the call's
   results (caller_dst), restricted to those occurring in the outer outputs
   or in the context.
   Returns ((pre_ins,pre_outs), (body_ins,body_outs), (post_ins,post_outs),
            rs' as a list, context without rs'). *)

fun compute_fcall_info ((outer_ins, outer_outs), (caller_src, caller_dst),
                        (callee_ins, callee_outs), rs, context) =
  let
    (* One stack slot per element of expL: the k-th element (0-based) of a
       list of length len maps to MEM (11, ~(3 + (len - (k + 1)))).
       NOTE(review): 11 is presumably fp -- confirm against IR.fromAlias. *)
    fun to_stack expL =
      let val len = length expL
      in List.tabulate (len, fn k => MEM (11, ~(3 + (len - (k + 1)))))
      end

    val to_be_stored = S.difference (list2set (List.map REG (S.listItems rs)), pair2set caller_dst);
    val rs' = S.intersection (to_be_stored, S.union (pair2set outer_outs, list2set context));
    val pre_ins = trim_pair (PAIR (caller_src, set2pair rs'));
    val stored = (list2pair o to_stack o set2list) rs';
    val pre_outs = trim_pair (PAIR (callee_ins, stored));
    val body_ins = pre_outs;
    val body_outs = trim_pair (PAIR (callee_outs, stored));
    val post_ins = body_outs;
    val post_outs = trim_pair (PAIR (caller_dst, set2pair rs'));
    val context' = set2list (S.difference (list2set context, rs'))
  in
    ((pre_ins, pre_outs), (body_ins, body_outs), (post_ins, post_outs), set2list rs', context')
  end


(* Definitions of the functions whose calls were expanded by convert_fcall;
   reset by link_ir before each expansion. *)
val involved_defs = ref ([] : thm list);

(* Expand every CALL node.  The new pre block reserves space for the outputs,
   pushes the arguments (pass_args), runs the callee prologue (entry_blk) and
   loads the callee's inputs (get_args).  The body is the callee's IR (as
   registered in declFuncs) with fresh annotations.  The post block sends the
   results (send_results), loads them into the caller's destinations
   (get_results) and runs the epilogue (exit_blk).
   Side effect: appends the callee's definition to involved_defs.
   SC / CJ / TR are traversed structurally; other nodes are returned as is. *)

fun convert_fcall (CALL (fname, pre, body, post, outer_info)) =
      let
        val (outer_ins, outer_outs) = (#ins outer_info, #outs outer_info);
        val (caller_dst, caller_src) = let val x = get_annt body in (#outs x, #ins x) end;
        val {ir = (callee_ins, callee_ir, callee_outs), regs = rs, localNum = n, def = f_def, ...} =
              declFuncs.getFunc fname;
        val _ = involved_defs := (!involved_defs) @ [f_def];
        val ((pre_ins, pre_outs), (body_ins, body_outs), (post_ins, post_outs), rs', context) =
              compute_fcall_info ((outer_ins, outer_outs), (caller_src, caller_dst),
                                  (callee_ins, callee_outs), rs, #context outer_info);

        val reserve_space_for_outputs =
              [dec_p (fromAlias sp) (length (pair2list caller_dst))];

        val pre' = rm_dummy_inst (BLK (
                     reserve_space_for_outputs @
                     pass_args (pair2list caller_src) @
                     entry_blk rs' n @
                     get_args (pair2list callee_ins),
                     {ins = pre_ins, outs = pre_outs, context = context, fspec = thm_t}));

        val body' = apply_to_info callee_ir
                      (fn info' => {ins = body_ins, outs = body_outs, context = context, fspec = thm_t});

        val post' = rm_dummy_inst (BLK (
                      send_results (IR.pair2list callee_outs) (length (IR.pair2list callee_ins)) @
                      get_results (IR.pair2list caller_dst) @
                      exit_blk rs',
                      {ins = post_ins, outs = post_outs, context = context, fspec = thm_t}))
      in
        CALL (fname, pre', body', post', outer_info)
      end
  | convert_fcall (SC (s1, s2, info)) = SC (convert_fcall s1, convert_fcall s2, info)
  | convert_fcall (CJ (cond, s1, s2, info)) = CJ (cond, convert_fcall s1, convert_fcall s2, info)
  | convert_fcall (TR (cond, s, info)) = TR (cond, convert_fcall s, info)
  | convert_fcall ir = ir;


(* ------------------------------------------------------------------------- *)
(* Link caller and callees together.                                         *)
(* Translates prog with sfl2ir, computes its modified registers and its      *)
(* fp-relative addressing, registers it with declFuncs before expanding      *)
(* calls (so a call to fname itself can be looked up), then expands its      *)
(* calls and applies match_ins_outs.                                         *)
(* Returns (fname, ftype, (ins, ir, outs), defs @ definitions of callees).   *)
(* ------------------------------------------------------------------------- *)

fun link_ir prog =
  let
    val (fname, ftype, (ins, ir0, outs), defs) = sfl2ir prog;
    val rs = S.addList (S.empty regAllocation.intOrder, get_modified_regs ir0);
    val (ins1, ir1, outs1, localNum) = calculate_relative_address (ins, ir0, outs, S.numItems rs);
    val _ = (involved_defs := [];
             declFuncs.putFunc (fname, ftype, (ins1, ir1, outs1), rs, localNum, (hd defs)));
    val ir2 = convert_fcall ir1
    val ir3 = match_ins_outs ir2
  in
    (fname, ftype, (ins1, ir3, outs1), defs @ (!involved_defs))
  end;


end (* local structure ... *)

end (* structure *)