1structure funCall =
2struct
3
4local
5open HolKernel Parse boolLib IRSyntax annotatedIR
6structure T = IntMapTable(type key = int  fun getInt n = n);
7structure S = Binaryset
8structure IR = IRSyntax
9in
10
(* Raised when a list of arguments/results cannot be translated into moves:
   read_one_arg is given a destination that is neither a register nor a
   memory slot, or pushL/popL meet a segment of an unexpected shape. *)
exception invalidArgs;

(* Register-count setting (default 10).  The register numbered !numAvaiRegs
   serves as the scratch register when a value is copied between memory
   slots (read_one_arg, write_one_arg).
   NOTE(review): presumably registers below this number are left to the
   register allocator -- confirm. *)
val numAvaiRegs : int ref = ref 10;
13
(* Three-way comparison of two strings in lexicographic character order.
   This is exactly the Basis Library's String.compare (which also underlies
   the string operators < and >), so delegate to it instead of re-deriving
   the order from two separate comparisons. *)
fun strOrder (s1 : string, s2 : string) : order = String.compare (s1, s2);
18
19(* ---------------------------------------------------------------------------------------------------------------------*)
(* The configuration for passing parameters, returning results and allocating stack space                               *)
21(* The stack goes upward.                                                                                               *)
22(* ---------------------------------------------------------------------------------------------------------------------*)
23(*
24                Address goes upward (i.e. from low address to high address)!!!
25
26                                          Content                 Address
27    caller's ip                 |  saved pc                     |   0
28    caller's fp                 |  saved lr                     |   1
29                                |  save sp                      |   2
30                                |  save fp                      |   3
31                                |  modified reg k               |
32                                |  modified reg k-1             |
33                                |  ...                          |
34                                |  local variable 0             |   4
35                                |  local variable 1             |   5
36                                |       ...                     |   .
37                                |  local variable n             |   .
38                                |  output k from callee         |   .
39                                |       ...                     |
40                                |  output 1 from callee         |
41                                |  output 0 from callee         |
42                                |  argument m'                  |
43                                |       ...                     |
44                                |  argument 1                   |
45                                |  argument 0                   |
46    caller's sp callee's ip     |  saved pc                     |
47    callee's fp                 |  saved lr                     |
48                                |  save sp                      |
49                                |  save fp                      |
50                                |  modified reg k               |
51                                |  modified reg k-1             |
52                                |  ...                          |
53                                |  local variable 0             |
54                                |  local variable 1             |
55                                |       ...                     |
56    callee's sp                 |  local variable n'            |
57*)
58
59
60(* Distinguish inputs and local variables                                               *)
(* Calculate the fp-relative offset of each temporary being read in a callee          *)
62(* If a temporary is an input, then read it from the stack where the caller sets        *)
63(* If it is a local variable, then read it from the current frame                       *)
64
(* calculate_relative_address (args, ir, outs, numSavedRegs)
   Rewrites every IR.TMEM temporary occurring in args, ir and outs into an
   fp-relative IR.MEM slot:
     - a TMEM that appears in args at (0-based) position k is an input and
       becomes MEM (fp, ~2 - k), i.e. it is read where the caller placed it;
     - any other TMEM is a local variable.  Local variables receive slots
       0, 1, ... in order of first appearance, and slot j becomes
       MEM (fp, 3 + j + numSavedRegs).
   Slots are handed out while visiting args, then ir, then outs (record
   fields and tuple components are evaluated left to right), so that order
   fixes the numbering.  Inside a CALL node only the annotation is
   rewritten; its pre/body/post parts are left untouched.
   Returns (args', ir', outs', number of local-variable slots allocated). *)
fun calculate_relative_address (args, ir, outs, numSavedRegs) =
  let
    val fpReg = IR.fromAlias IR.fp

    (* Position of every TMEM argument; non-TMEM arguments still occupy a position. *)
    fun index_arg (IR.TMEM n, (tbl, pos)) = (T.enter (tbl, n, pos), pos + 1)
      | index_arg (_, (tbl, pos)) = (tbl, pos + 1)
    val (argT, _) = List.foldl index_arg (T.empty, 0) (IR.pair2list args)

    (* Local variables seen so far, and the next free local slot. *)
    val localT = ref T.empty
    val nextSlot = ref 0

    fun local_mem slot = IR.MEM (fpReg, 3 + slot + numSavedRegs)

    (* Give temporary n the next free local slot and return its address. *)
    fun fresh_local n =
        let val slot = !nextSlot
        in
          localT := T.enter (!localT, n, slot);
          nextSlot := slot + 1;
          local_mem slot
        end

    fun filter_mems (IR.TMEM n) =
          (case T.peek (argT, n) of
               SOME k => IR.MEM (fpReg, ~2 - k)            (* input *)
             | NONE =>
                 (case T.peek (!localT, n) of
                      SOME slot => local_mem slot         (* already-seen local variable *)
                    | NONE => fresh_local n))             (* first occurrence of a local *)
      | filter_mems e = e

    (* Destinations are rewritten before sources. *)
    fun one_stm {oper = opr, dst = dsts, src = srcs} =
        {oper = opr, dst = List.map filter_mems dsts, src = List.map filter_mems srcs}

    fun adjust_exp (IR.PAIR (e1, e2)) = IR.PAIR (adjust_exp e1, adjust_exp e2)
      | adjust_exp e = filter_mems e

    fun adjust_info {ins = insE, outs = outsE, context = ctx, fspec = spec} =
        {ins = adjust_exp insE, outs = adjust_exp outsE,
         context = List.map adjust_exp ctx, fspec = spec}

    fun visit (SC (s1, s2, info)) = SC (visit s1, visit s2, adjust_info info)
      | visit (TR ((lhs, rop, rhs), body, info)) =
          TR ((adjust_exp lhs, rop, adjust_exp rhs), visit body, adjust_info info)
      | visit (CJ ((lhs, rop, rhs), s1, s2, info)) =
          CJ ((adjust_exp lhs, rop, adjust_exp rhs), visit s1, visit s2, adjust_info info)
      | visit (CALL (fname, pre, body, post, info)) =
          CALL (fname, pre, body, post, adjust_info info)
      | visit (STM stms) = STM (List.map one_stm stms)
      | visit (BLK (stms, info)) = BLK (List.map one_stm stms, adjust_info info)

    (* Sequential bindings make the slot-assignment order explicit. *)
    val args' = adjust_exp args
    val ir' = visit ir
    val outs' = adjust_exp outs
  in
    (args', ir', outs', T.numItems (!localT))
  end
123
124(* ---------------------------------------------------------------------------------------------------------------------*)
125(*  Decrease and increase the value of a register by n                                                                  *)
126(* ---------------------------------------------------------------------------------------------------------------------*)
127
128fun dec_p pt n = {oper = IR.msub, dst = [IR.REG pt], src = [IR.REG pt, IR.WCONST (Arbint.fromInt n)]}
129
130fun inc_p pt n = {oper = IR.madd, dst = [IR.REG pt], src = [IR.REG pt, IR.WCONST (Arbint.fromInt n)]};
131
132
133(* ---------------------------------------------------------------------------------------------------------------------*)
(* Pre-call and post-call processing in compliance with the ARM Procedure Call Standard                                 *)
135(* ---------------------------------------------------------------------------------------------------------------------*)
136
137(*      MOV     ip, sp
138        STMFA   sp!, {..., fp,ip,lr,pc}
139        SUB     fp, ip, #1
140        SUB     sp, sp, #var (* skip local variables *)
141*)
142
(* entry_blk rs n: callee prologue.
     ip := sp
     push rs @ [fp, ip, lr, pc] through sp
     fp := ip - 1
     sp := sp - n          (skip the n local-variable slots)
   rs lists the modified registers to be saved alongside the frame registers. *)
fun entry_blk rs n =
    let
        fun reg alias = IR.REG (IR.fromAlias alias)
        val copySp = {oper = IR.mmov, dst = [reg IR.ip], src = [reg IR.sp]}
        val saveFrame = {oper = IR.mpush, dst = [reg IR.sp],
                         src = rs @ [reg IR.fp, reg IR.ip, reg IR.lr, reg IR.pc]}
        val setFp = {oper = IR.msub, dst = [reg IR.fp], src = [reg IR.ip, IR.WCONST Arbint.one]}
    in
        [copySp, saveFrame, setFp, dec_p (IR.fromAlias IR.sp) n]
    end
152
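(*
    ADD         sp, fp, 3 + #modified registers      (* Skip saved lr, sp, fp and modified registers *)
    LDMFD       sp, {..., fp,sp,pc}

    NOTE(review): exit_blk below emits msub, i.e. sp := fp - (3 + #modified registers),
    rather than the ADD shown above -- confirm which direction matches the frame layout.
*)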
157
(* exit_blk rs: callee epilogue.
     sp := fp - (3 + length rs)
     pop rs @ [fp, sp, pc] through sp
   rs is expected to be the same register list given to entry_blk
   (convert_fcall passes rs' to both). *)
fun exit_blk rs =
    let
        fun reg alias = IR.REG (IR.fromAlias alias)
        val frameWords = Arbint.fromInt (3 + length rs)
        val resetSp = {oper = IR.msub, dst = [reg IR.sp], src = [reg IR.fp, IR.WCONST frameWords]}
        val restore = {oper = IR.mpop, dst = rs @ [reg IR.fp, reg IR.sp, reg IR.pc],
                       src = [reg IR.sp]}
    in
        [resetSp, restore]
    end;
164
165
166(* ---------------------------------------------------------------------------------------------------------------------*)
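(*  Given a list of registers and memory slots, group consecutive registers together to be used by mpush and mpop       *)
(*  For example, [r1,r2,m1,m3,r3,4w,r4,r5] is segmented to                                                              *)
(*  ([r1,r2],true,0), ([m1],false,2), ([m3],false,3), ([r3],true,4), ([4w],false,5), ([r4,r5],true,6)                   *)
(*  where the integer is the index in the input list of the segment's first element, and every non-register element     *)
(*  (memory slot or constant) forms its own singleton segment.                                                          *)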
170(* ---------------------------------------------------------------------------------------------------------------------*)
171
(* mk_reg_segments argL
   Splits argL into maximal runs of consecutive registers; every non-register
   element (memory slot or constant) becomes its own singleton segment.
   Each segment is (elements, isRegisterRun, index), where index is the
   0-based position in argL of the segment's first element.  Segments are
   returned in the same order as argL.
   Example: [r1,r2,m1,r3] gives [([r1,r2],true,0), ([m1],false,2), ([r3],true,3)]. *)
fun mk_reg_segments argL =
  let
      (* Folding from the right lets each element be prepended to the result.
         A register extends the head segment exactly when that head is a
         register run, i.e. when the element to its right was also a register.
         Pattern matching on the head replaces the mutable "broken" flag. *)
      fun step (IR.REG r, ((regs, true, _) :: rest, idx)) =
            ((IR.REG r :: regs, true, idx) :: rest, idx - 1)
        | step (IR.REG r, (segs, idx)) =
            (([IR.REG r], true, idx) :: segs, idx - 1)
        | step (e, (segs, idx)) =
            (([e], false, idx) :: segs, idx - 1)

      val (segs, _) = List.foldr step ([], length argL - 1) argL
  in
      segs
  end
197
198(* ---------------------------------------------------------------------------------------------------------------------*)
(*  Given a list of registers, generate an mpush or mpop statement                                                      *)
(*  The sp is assumed to be in the right position                                                                       *)
(*  If the list contains only one register, then use mldr or mstr instead                                               *)
202(* ---------------------------------------------------------------------------------------------------------------------*)
203
(* mk_ldm_stm isPop r dataL: one instruction moving dataL through pointer register r.
     pop,  single element -> mldr  from MEM (r, 1)
     pop,  otherwise      -> mpop  with base register r
     push, single element -> mstr  to   MEM (r, 0)
     push, otherwise      -> mpush with r as destination *)
fun mk_ldm_stm isPop r dataL =
    case (isPop, dataL) of
        (true, [_])  => {oper = IR.mldr, dst = dataL, src = [IR.MEM (r, 1)]}
      | (true, _)    => {oper = IR.mpop, dst = dataL, src = [IR.REG r]}
      | (false, [_]) => {oper = IR.mstr, dst = [IR.MEM (r, 0)], src = dataL}
      | (false, _)   => {oper = IR.mpush, dst = [IR.REG r], src = dataL}
215
216(* ---------------------------------------------------------------------------------------------------------------------*)
217(*  Write one argument to the memory slot referred by regNo and offset                                                  *)
(*  Push arguments to the stack. If the argument is in a register, store it into the stack directly;                   *)
(*  if it is in memory, first load it into the scratch register (R10 by default), then store it into the stack;         *)
(*  if it is a constant, first move it into the scratch register, then store it into the stack.                         *)
221(* ---------------------------------------------------------------------------------------------------------------------*)
222
(* write_one_arg v (regNo, offset): store value v into MEM (regNo, offset).
     - register v:        a single mstr;
     - memory slot v:     mldr into the scratch register, then mstr;
     - anything else (constant, e.g. NONCST n or WCONST w):
                          mmov into the scratch register, then mstr.
   The scratch register is REG (!numAvaiRegs) in both indirect cases.
   Previously the constant case hard-coded REG 10, which disagreed with the
   memory case (and read_one_arg) whenever numAvaiRegs was changed. *)
local
  fun via_scratch move v (regNo, offset) =
      let val scratch = IR.REG (!numAvaiRegs)
      in
        [ {oper = move, dst = [scratch], src = [v]},
          {oper = IR.mstr, dst = [IR.MEM (regNo, offset)], src = [scratch]} ]
      end
in
fun write_one_arg (v as IR.MEM _) dest = via_scratch IR.mldr v dest
  | write_one_arg (IR.REG r) (regNo, offset) =
      [ {oper = IR.mstr, dst = [IR.MEM (regNo, offset)], src = [IR.REG r]} ]
  | write_one_arg v dest = via_scratch IR.mmov v dest
end
231
232(* ---------------------------------------------------------------------------------------------------------------------*)
233(*  Read one argument from the memory slot referred by regNo and offset                                                 *)
(*  If the destination is a register, load the value into it directly;                                                  *)
(*  if it is a memory slot, first load the value into the scratch register (R10 by default), then store it there;       *)
(*  the destination cannot be a constant: any other kind of destination raises invalidArgs.                             *)
237(* ---------------------------------------------------------------------------------------------------------------------*)
238
(* read_one_arg target (regNo, offset): load MEM (regNo, offset) into target.
     - register target:    a single mldr;
     - memory-slot target: mldr into scratch REG (!numAvaiRegs), then mstr;
     - any other target:   raises invalidArgs. *)
fun read_one_arg target (regNo, offset) =
    let val slot = IR.MEM (regNo, offset)
    in
      case target of
          IR.REG _ => [ {oper = IR.mldr, dst = [target], src = [slot]} ]
        | IR.MEM _ =>
            let val scratch = IR.REG (!numAvaiRegs)
            in
              [ {oper = IR.mldr, dst = [scratch], src = [slot]},
                {oper = IR.mstr, dst = [target], src = [scratch]} ]
            end
        | _ => raise invalidArgs
    end
246
247(* ---------------------------------------------------------------------------------------------------------------------*)
248(* Push a list of values that may be constants or in registers or in memory into the stack                              *)
249(* [1,2,3,...]                                                                                                          *)
250(* old pointer | 1 |                                                                                                    *)
251(*             | 2 |                                                                                                    *)
252(*             | 3 |                                                                                                    *)
253(*              ...                                                                                                     *)
254(* new pointer                                                                                                          *)
255(* Note that the elements in the list are stored in the memory from low addresses to high addresses                     *)
256(* ---------------------------------------------------------------------------------------------------------------------*)
257
(* pushL regNo argL
   Emits code writing the values of argL (registers, memory slots or
   constants) into consecutive stack slots addressed through register regNo.
   argL is split by mk_reg_segments:
     - a run of two or more registers is moved down to with dec_p and then
       written with one mpush (reversed, to match STM's store order);
     - every other value is stored by write_one_arg at offset (moved - index),
       which should be zero or negative.
   `moved` is how far regNo has been moved so far.  After a run starting at
   index i with m registers it becomes i + m; this assumes mpush writes back
   regNo -- TODO confirm.  The final dec_p moves regNo the rest of the way,
   to length argL.
   Raises invalidArgs on a malformed segment. *)
fun pushL regNo argL =
  let
      fun step ((regL, true, i), (code, moved)) =
            if length regL = 1 then
                (code @ write_one_arg (hd regL) (regNo, moved - i), moved)
            else
                (code @ [dec_p regNo (i - moved), mk_ldm_stm false regNo (List.rev regL)],
                 i + length regL)
        | step (([v], false, i), (code, moved)) =
                (code @ write_one_arg v (regNo, moved - i), moved)
        | step _ = raise invalidArgs

      val (code, moved) = List.foldl step ([], 0) (mk_reg_segments argL)
  in
      code @ [dec_p regNo (length argL - moved)]
  end
279
280(* ---------------------------------------------------------------------------------------------------------------------*)
(* Pop values from the stack into a list of destinations, each of which may be a register or a memory slot             *)
282(*           ...                                                                                                        *)
283(*           | 3 |                                                                                                      *)
284(*           | 2 |                                                                                                      *)
285(*           | 1 |                                                                                                      *)
286(*  pointer                                                                                                             *)
(*  are read into the list [1,2,3,...]                                                                                  *)
288(* ---------------------------------------------------------------------------------------------------------------------*)
289
(* popL regNo argL
   Emits code reading consecutive stack slots, addressed through register
   regNo, into the destinations of argL (registers or memory slots).
   argL is split by mk_reg_segments:
     - a run of two or more registers is reached with inc_p and then loaded
       with one mpop;
     - every other destination is loaded by read_one_arg at offset
       (index - moved + 1), which should be positive.
   `moved` is how far regNo has been moved so far.  After a run starting at
   index i it becomes i, not i + run length.  This matches mk_ldm_stm not
   listing regNo as an mpop destination, i.e. no write-back -- TODO confirm.
   The final inc_p moves regNo the rest of the way, to length argL.
   Raises invalidArgs on a malformed segment or a constant destination. *)
fun popL regNo argL =
  let
      fun step ((regL, true, i), (code, moved)) =
            if length regL = 1 then
                (code @ read_one_arg (hd regL) (regNo, i - moved + 1), moved)
            else
                (code @ [inc_p regNo (i - moved), mk_ldm_stm true regNo regL], i)
        | step (([v], false, i), (code, moved)) =
                (code @ read_one_arg v (regNo, i - moved + 1), moved)
        | step _ = raise invalidArgs

      val (code, moved) = List.foldl step ([], 0) (mk_reg_segments argL)
  in
      code @ [inc_p regNo (length argL - moved)]
  end
310
311(* ---------------------------------------------------------------------------------------------------------------------*)
(* The caller passes the arguments to the callee                                                                        *)
313(* All arguments are passed through the stack                                                                           *)
314(* Stack status after passing                                                                                           *)
315(*          ...                                                                                                         *)
316(*        | arg 3 |                                                                                                     *)
317(*        | arg 2 |                                                                                                     *)
318(* sp     | arg 1 |                                                                                                     *)
319(* ---------------------------------------------------------------------------------------------------------------------*)
320
(* pass_args argL: caller-side code pushing the arguments through sp.
   argL is reversed before pushL, so its last element is written first. *)
fun pass_args argL =
  let
      val spReg = IR.fromAlias IR.sp
      val lastFirst = List.rev argL
  in
      pushL spReg lastFirst
  end
323
324(* ---------------------------------------------------------------------------------------------------------------------*)
(* The callee obtains the arguments passed by the caller through the stack                                              *)
(* By default, the first four arguments are loaded into r0-r4                                                           *)
(* (NOTE(review): "four" vs. r0-r4 is inconsistent, and get_args below pops every argument via popL -- confirm)         *)
(* The remaining arguments are already in the right positions, so we need not fetch them explicitly                     *)
(* Note that the register allocation assumes the above convention                                                       *)
329(* ---------------------------------------------------------------------------------------------------------------------*)
330
(* get_args argL: callee-side code loading every argument of argL from the
   stack via popL.  The callee's ip equals the caller's sp (entry_blk sets
   ip := sp), so ip is used as the base register.
   The former locals len, len1 and mk_regL were never used and are removed. *)
fun get_args argL =
    popL (IR.fromAlias IR.ip) argL;
343
344(* ---------------------------------------------------------------------------------------------------------------------*)
(* The callee passes the results back to the caller                                                                     *)
346(* All results are passed through the stack                                                                             *)
347(* Stack status after passing                                                                                           *)
348(*                                                                                                                      *)
349(*             ...                                                                                                      *)
350(*         | result 3 |                                                                                                 *)
351(*         | result 2 |                                                                                                 *)
352(*         | result 1 |                                                                                                 *)
353(*    sp                                                                                                                *)
354(* ---------------------------------------------------------------------------------------------------------------------*)
355
(* send_results outL numArgs: callee-side code returning results on the stack.
     sp := fp + (numArgs + length outL + 1)
   followed by pushing outL (reversed) through sp. *)
fun send_results outL numArgs =
   let
       val spReg = IR.fromAlias IR.sp
       (* skip the arguments and the stored pc, then go to the position right before the first output *)
       val skip = numArgs + length outL + 1
       val placeSp = { oper = IR.madd,
                       dst = [IR.REG spReg],
                       src = [IR.REG (IR.fromAlias IR.fp), IR.WCONST (Arbint.fromInt skip)] }
   in
       placeSp :: pushL spReg (List.rev outL)
   end
368
369(* ---------------------------------------------------------------------------------------------------------------------*)
(* The caller retrieves the results passed by the callee through the stack                                              *)
371(* Stack status                                                                                                         *)
372(*                                                                                                                      *)
373(*             ...                                                                                                      *)
374(*         | result 3 |                                                                                                 *)
375(*         | result 2 |                                                                                                 *)
376(*         | result 1 |                                                                                                 *)
377(*     sp  |          |                                                                                                 *)
378(* ---------------------------------------------------------------------------------------------------------------------*)
379
(* get_results outL: caller-side code popping the results into outL through sp. *)
fun get_results outL =
    let val spReg = IR.fromAlias IR.sp
    in popL spReg outL end
382
383(* ---------------------------------------------------------------------------------------------------------------------*)
384(* Function Call: Pre-processing, Post-processing and Adjusted Body                                                     *)
385(* ---------------------------------------------------------------------------------------------------------------------*)
386
(* compute_fcall_info ((outer_ins, outer_outs), (caller_src, caller_dst),
                       (callee_ins, callee_outs), rs, context)
   Computes the ins/outs annotations of the pre, body and post parts of a
   converted call.
     - rs is the set of register numbers modified by the callee.
     - rs' = registers of rs that are not written by caller_dst but are still
       needed afterwards (they occur in outer_outs or in context).
     - Every register of rs' gets a stack slot from to_stack; those slots
       flow from pre's outputs through the body to post's inputs.
   Returns ((pre_ins, pre_outs), (body_ins, body_outs), (post_ins, post_outs),
            rs' as a list, context without rs').
   outer_ins is not used. *)
fun compute_fcall_info ((outer_ins, outer_outs), (caller_src, caller_dst), (callee_ins, callee_outs), rs, context) =
    let
        (* The k-th (0-based) of len expressions gets MEM (11, ~(2 + len - k)):
           the last one lands at offset ~3, the one before at ~4, and so on.
           NOTE(review): 11 is presumably the fp register number -- confirm. *)
        fun to_stack expL =
            let val len = length expL
            in List.tabulate (len, fn k => MEM (11, ~(2 + len - k))) end

        val modified = list2set (List.map REG (S.listItems rs))
        val clobbered = S.difference (modified, pair2set caller_dst)
        val rs' = S.intersection (clobbered, S.union (pair2set outer_outs, list2set context))
        val saved = set2pair rs'
        val stored = (list2pair o to_stack o set2list) rs'

        val pre_ins = trim_pair (PAIR (caller_src, saved))
        val pre_outs = trim_pair (PAIR (callee_ins, stored))
        val body_outs = trim_pair (PAIR (callee_outs, stored))
        val post_outs = trim_pair (PAIR (caller_dst, saved))
        val context' = set2list (S.difference (list2set context, rs'))
    in
        (* body consumes what pre produces; post consumes what the body produces *)
        ((pre_ins, pre_outs), (pre_outs, body_outs), (body_outs, post_outs), set2list rs', context')
    end
409
410
(* Definitions of the callees met while linking the current function:
   reset by link_ir and extended by convert_fcall for each CALL it converts. *)
val involved_defs : thm list ref = ref [];
412
(* convert_fcall ir
   Rewrites every CALL node into explicit calling-convention code.
   The callee is looked up with declFuncs.getFunc and its definition is
   appended to involved_defs.  The node is rebuilt as:
     pre  = reserve sp space for the caller's outputs, pass the arguments,
            callee prologue (entry_blk), callee loads its arguments;
     body = the callee's IR with its annotation replaced;
     post = callee sends results, caller retrieves them, callee epilogue.
   SC, CJ and TR nodes are traversed recursively; everything else is
   returned unchanged. *)
fun convert_fcall (CALL (fname, pre, body, post, outer_info)) =
    let
        val outer_ins = #ins outer_info
        val outer_outs = #outs outer_info
        val body_annt = get_annt body
        val caller_dst = #outs body_annt
        val caller_src = #ins body_annt

        val {ir = (callee_ins, callee_ir, callee_outs), regs = rs, localNum = n, def = f_def, ...} =
            declFuncs.getFunc fname
        val _ = involved_defs := (!involved_defs) @ [f_def]

        val ((pre_ins, pre_outs), (body_ins, body_outs), (post_ins, post_outs), rs', context) =
            compute_fcall_info ((outer_ins, outer_outs), (caller_src, caller_dst),
                                (callee_ins, callee_outs), rs, #context outer_info)

        fun annotate (insE, outsE) = {ins = insE, outs = outsE, context = context, fspec = thm_t}

        val pre_code =
            [dec_p (fromAlias sp) (length (pair2list caller_dst))]   (* room for the outputs *)
            @ pass_args (pair2list caller_src)
            @ entry_blk rs' n
            @ get_args (pair2list callee_ins)
        val pre' = rm_dummy_inst (BLK (pre_code, annotate (pre_ins, pre_outs)))

        val body' = apply_to_info callee_ir (fn _ => annotate (body_ins, body_outs))

        val post_code =
            send_results (IR.pair2list callee_outs) (length (IR.pair2list callee_ins))
            @ get_results (IR.pair2list caller_dst)
            @ exit_blk rs'
        val post' = rm_dummy_inst (BLK (post_code, annotate (post_ins, post_outs)))
    in
        CALL (fname, pre', body', post', outer_info)
    end
  | convert_fcall (SC (s1, s2, info)) = SC (convert_fcall s1, convert_fcall s2, info)
  | convert_fcall (CJ (cond, s1, s2, info)) = CJ (cond, convert_fcall s1, convert_fcall s2, info)
  | convert_fcall (TR (cond, s, info)) = TR (cond, convert_fcall s, info)
  | convert_fcall other = other;
452
453
454(* ---------------------------------------------------------------------------------------------------------------------*)
455(* Link caller and callees together                                                                                     *)
456(* ---------------------------------------------------------------------------------------------------------------------*)
457
(* link_ir prog
   Translates prog with sfl2ir, then:
     1. collects the registers the body modifies;
     2. makes temporaries fp-relative (calculate_relative_address);
     3. clears involved_defs and registers this function with declFuncs
        (before its calls are converted);
     4. converts its calls (convert_fcall) and applies match_ins_outs.
   Returns (fname, ftype, (ins', ir', outs'), defs @ callee definitions used). *)
fun link_ir prog =
  let
      val (fname, ftype, (ins, ir0, outs), defs) = sfl2ir prog
      val modified = S.addList (S.empty regAllocation.intOrder, get_modified_regs ir0)
      val (ins1, ir1, outs1, localNum) =
          calculate_relative_address (ins, ir0, outs, S.numItems modified)
      val _ = involved_defs := []
      val _ = declFuncs.putFunc (fname, ftype, (ins1, ir1, outs1), modified, localNum, hd defs)
      val linked = match_ins_outs (convert_fcall ir1)
  in
      (fname, ftype, (ins1, linked, outs1), defs @ (!involved_defs))
  end;
470
471
472end (* local structure ... *)
473
474end (* structure *)
475