1structure funCall =
2struct
3
4local
5open HolKernel Parse boolLib IRSyntax annotatedIR
6structure T = IntMapTable(type key = int  fun getInt n = n);
7structure S = Binaryset
8structure IR = IRSyntax
9in
10
(* Raised when an argument, destination or register segment has a shape the
   stack-passing helpers cannot handle (see read_one_arg, pushL, popL). *)
exception invalidArgs;
(* Register index used as a scratch register when a value has to travel through
   a register on its way to/from memory (see read_one_arg, write_one_arg).
   Defaults to 10.  NOTE(review): the name suggests it is also the number of
   registers available to the allocator -- confirm against the allocator. *)
val numAvaiRegs = ref 10;
13
(* Total order on strings, suitable for ordered containers:
   LESS / EQUAL / GREATER according to the standard lexicographic ordering. *)
fun strOrder (s1 : string, s2 : string) = String.compare (s1, s2);
18
19(* ---------------------------------------------------------------------------------------------------------------------*)
(* The configuration for passing parameters, returning results and allocating stack space                               *)
21(* The stack goes upward.                                                                                               *)
22(* ---------------------------------------------------------------------------------------------------------------------*)
23(*
24                Address goes upward (i.e. from low address to high address)!!!
25
26                                          Content                 Address
27    caller's ip                 |  saved pc                     |   0
28    caller's fp                 |  saved lr                     |   1
29                                |  save sp                      |   2
30                                |  save fp                      |   3
31                                |  modified reg k               |
32                                |  modified reg k-1             |
33                                |  ...                          |
34                                |  local variable 0             |   4
35                                |  local variable 1             |   5
36                                |       ...                     |   .
37                                |  local variable n             |   .
38                                |  output k from callee         |   .
39                                |       ...                     |
40                                |  output 1 from callee         |
41                                |  output 0 from callee         |
42                                |  argument m'                  |
43                                |       ...                     |
44                                |  argument 1                   |
45                                |  argument 0                   |
46    caller's sp callee's ip     |  saved pc                     |
47    callee's fp                 |  saved lr                     |
48                                |  save sp                      |
49                                |  save fp                      |
50                                |  modified reg k               |
51                                |  modified reg k-1             |
52                                |  ...                          |
53                                |  local variable 0             |
54                                |  local variable 1             |
55                                |       ...                     |
56    callee's sp                 |  local variable n'            |
57*)
58
59
(* Distinguish inputs from local variables.                                               *)
(* Calculate the fp-relative offsets of the temporaries being read in a callee.          *)
(* If a temporary is an input, read it from the stack slot where the caller stored it.   *)
(* If it is a local variable, read it from the current frame.                            *)
64
(* Rewrite the temporaries (IR.TMEM) of a function into fp-relative memory slots
   (IR.MEM (fp, offset)).
     args         : the function's input expression (a PAIR tree)
     ir           : the annotated IR body
     outs         : the function's output expression
     numSavedRegs : number of saved registers to skip when placing locals
   A TMEM n found at position k of (IR.pair2list args) is an input and becomes
   MEM (fp, ~2 - k).  Any other TMEM is a local variable: on first sight it gets
   the next free index j and becomes MEM (fp, 3 + j + numSavedRegs); later
   occurrences reuse j.  Every other expression is returned unchanged.
   Returns (args', ir', outs', number of distinct local variables).
   NOTE(review): local indices follow evaluation order -- args, then the body,
   then outs (tuples/records are evaluated left to right).
   NOTE(review): here the saved registers are implicitly placed at positive
   offsets 3 .. 2+numSavedRegs (locals start after them), while
   compute_fcall_info's to_stack addresses saved registers at negative offsets
   ~3, ~4, ... from register 11; confirm the sign convention of MEM offsets.
   NOTE(review): for a CALL node only its outer annotation is rewritten; pre,
   body and post (including body's annotation) are returned unchanged. *)
fun calculate_relative_address (args,ir,outs,numSavedRegs) =
  let
    (* Identify those TMEMs that are actually arguments: map TMEM number -> position in args *)
    val argT = #1 (List.foldl (fn (IR.TMEM n, (t,i)) =>
                                  (T.enter(t, n, i), i+1)
                               |  (arg, (t,i)) => (t, i+1)
                              )
                        (T.empty, 0)
                        (IR.pair2list args)
                  );

    val i = ref 0                 (* next free local-variable index *)
    val localT = ref (T.empty);   (* Table for the local variables: TMEM number -> local index *)

    (* For those TMEMs that are local variables, assign them in the stack according to the order of their appearance *)

    fun filter_mems (IR.TMEM n) =
        ( case T.peek (argT, n) of
                SOME k => IR.MEM (IR.fromAlias IR.fp, ~2 - k)           (* inputs *)
           |     NONE =>
                ( case T.peek(!localT, n) of
                      SOME j => IR.MEM (IR.fromAlias IR.fp, 3 + j + numSavedRegs) (* existing local variable *)
                   |  NONE =>
                          (* first occurrence: record it under index !i, then advance the counter *)
                          ( localT := T.enter(!localT, n, !i);
                            i := !i + 1;
                            IR.MEM (IR.fromAlias IR.fp, 3 + (!i - 1) + numSavedRegs) (* local variables *)
                          )
                 )
        )
     |  filter_mems v = v

    (* Rewrite the operands of one statement; dst is rewritten before src *)
    fun one_stm ({oper = op1, dst = dst1, src = src1}) =
            {oper = op1, dst = List.map filter_mems dst1, src = List.map filter_mems src1}

    (* Rewrite every leaf of a PAIR tree *)
    fun adjust_exp (IR.PAIR (e1,e2)) =
            IR.PAIR(adjust_exp e1, adjust_exp e2)
     |  adjust_exp e =
            filter_mems e

    (* Rewrite the expressions of an annotation; fspec is kept as is *)
    fun adjust_info {ins = ins', outs = outs', context = context', fspec = fspec'} =
        {ins = adjust_exp ins', outs = adjust_exp outs', context = List.map adjust_exp context', fspec = fspec'}

    (* Walk the annotated IR, rewriting statements, conditions and annotations *)
    fun visit (SC(ir1,ir2,info)) =
         SC (visit ir1, visit ir2, adjust_info info)
    |  visit (TR((e1,rop,e2),ir,info)) =
         TR ((adjust_exp e1,rop,adjust_exp e2), visit ir, adjust_info info)
    |  visit (CJ((e1,rop,e2),ir1,ir2,info)) =
         CJ ((adjust_exp e1,rop,adjust_exp e2), visit ir1, visit ir2, adjust_info info)
    |  visit (CALL(fname,pre,body,post,info)) =
         CALL(fname, pre, body, post, adjust_info info)
    |  visit (STM l) =
         STM (List.map one_stm l)
    |  visit (BLK (l,info)) =
         BLK (List.map one_stm l, adjust_info info);

  in
        (adjust_exp args, visit ir, adjust_exp outs, T.numItems (!localT))
  end
123
124(* ---------------------------------------------------------------------------------------------------------------------*)
(*  Decrease and increase the value of a register by 4*n                                                                *)
(*  These functions are used to modify the base registers of loads and stores. Since                                    *)
(*  loads and stores always operate on 32-bit words, the address has to be divisible by 4.                              *)
(*  Therefore n is the number of words by which the register should change.                                             *)
129(* ---------------------------------------------------------------------------------------------------------------------*)
130
local
  (* Build the single statement  pt := pt <oper> 4*n , i.e. move the pointer
     register pt by n 32-bit words in the direction chosen by oper. *)
  fun step_p oper pt n =
      let val bytes = IR.WCONST (Arbint.fromInt (4 * n))
      in {oper = oper, dst = [IR.REG pt], src = [IR.REG pt, bytes]} end
in
  (* Decrease register pt by n words (msub of 4*n). *)
  fun dec_p pt n = step_p IR.msub pt n

  (* Increase register pt by n words (madd of 4*n). *)
  fun inc_p pt n = step_p IR.madd pt n
end;
134
135
136(* ---------------------------------------------------------------------------------------------------------------------*)
(* Pre-call and post-call processing in compliance with the ARM Procedure Call Standard                                 *)
138(* ---------------------------------------------------------------------------------------------------------------------*)
139
140(*      MOV     ip, sp
141        STMFA   sp!, {..., fp,ip,lr,pc}
        SUB     fp, ip, #4
143        SUB     sp, sp, #var (* skip local variables *)
144*)
145
(* Callee entry sequence.
     rs : registers to save in addition to fp, ip, lr and pc
     n  : number of local-variable words to reserve
   Emits, in order: ip := sp; push rs @ [fp, ip, lr, pc] using sp as base;
   fp := ip - 4; then dec_p on sp by n words to skip the local variables. *)
fun entry_blk rs n =
    let
        val spReg = IR.fromAlias IR.sp
        val ipR = IR.REG (IR.fromAlias IR.ip)
        val spR = IR.REG spReg
        val fpR = IR.REG (IR.fromAlias IR.fp)
        val lrR = IR.REG (IR.fromAlias IR.lr)
        val pcR = IR.REG (IR.fromAlias IR.pc)

        val copySp    = {oper = IR.mmov, dst = [ipR], src = [spR]}
        val saveFrame = {oper = IR.mpush, dst = [spR], src = rs @ [fpR, ipR, lrR, pcR]}
        val setFp     = {oper = IR.msub, dst = [fpR], src = [ipR, IR.WCONST (Arbint.fromInt 4)]}
    in
        [copySp, saveFrame, setFp, dec_p spReg n (* skip local variables *)]
    end
155
156(*
    SUB         sp, fp, #4*(3 + #modified registers)      (* Skip saved lr, sp, fp and modified registers *)
158    LDMFD       sp, {..., fp,sp,pc}
159*)
160
(* Callee exit sequence.
     rs : the same extra registers that entry_blk saved
   Emits: sp := fp - 4*(3 + length rs), then a pop with sp as base into
   rs @ [fp, sp, pc]. *)
fun exit_blk rs =
    let
        val fpR = IR.REG (IR.fromAlias IR.fp)
        val spR = IR.REG (IR.fromAlias IR.sp)
        val pcR = IR.REG (IR.fromAlias IR.pc)
        val skippedWords = 3 + length rs
        val rewindSp = {oper = IR.msub, dst = [spR],
                        src = [fpR, IR.WCONST (Arbint.fromInt (4 * skippedWords))]}
        val restore  = {oper = IR.mpop, dst = rs @ [fpR, spR, pcR], src = [spR]}
    in
        [rewindSp, restore]
    end;
167
168
169(* ---------------------------------------------------------------------------------------------------------------------*)
170(*  Given a list of registers and memory slots, group consecutive registers together to be used by mpush and mpop       *)
(*  For example, [r1,r2,m1,m3,r3,4w,r4,r5] is segmented to                                                              *)
(*  ([r1,r2],true,0), ([m1],false,2), ([m3],false,3), ([r3],true,4), ([4w],false,5), ([r4,r5],true,6),                  *)
(*  where the integer is the position in the input list of the segment's first element.                                 *)
173(* ---------------------------------------------------------------------------------------------------------------------*)
174
(* Split argL into segments (elements, isRegisterRun, startIndex), in list order.
   Maximal runs of consecutive IR.REG values form one segment flagged true; every
   other element forms its own singleton segment flagged false.  startIndex is the
   position in argL of the segment's first element. *)
fun mk_reg_segments argL =
  let
      (* Fold from the right so the result comes out in list order; idx is the
         position in argL of the element being processed. *)
      fun step (IR.REG r, (segs, idx)) =
            (case segs of
                 (regs as _ :: _, true, _) :: rest =>
                     (* the following element began a register run: prepend to it *)
                     ((IR.REG r :: regs, true, idx) :: rest, idx - 1)
               | _ =>
                     (([IR.REG r], true, idx) :: segs, idx - 1))
        | step (other, (segs, idx)) =
            (([other], false, idx) :: segs, idx - 1)

      val (segs, _) = List.foldr step ([], length argL - 1) argL
  in
      segs
  end
200
201(* ---------------------------------------------------------------------------------------------------------------------*)
202(*  Given a list of registers, generate a mpush or mpop statement                                                       *)
203(*  The sp is assumed to be in the right position                                                                       *)
(*  If the list contains only one register, then use mldr or mstr instead                                               *)
205(* ---------------------------------------------------------------------------------------------------------------------*)
206
(* Build one load-multiple / store-multiple statement with base register r.
   isPop = true  : pop into dataL (a single element becomes mldr from MEM (r,1));
   isPop = false : push dataL    (a single element becomes mstr to MEM (r,0)). *)
fun mk_ldm_stm isPop r dataL =
    case (isPop, dataL) of
        (true, [_])  => {oper = IR.mldr,  dst = dataL,          src = [IR.MEM (r,1)]}
      | (true, _)    => {oper = IR.mpop,  dst = dataL,          src = [IR.REG r]}
      | (false, [_]) => {oper = IR.mstr,  dst = [IR.MEM (r,0)], src = dataL}
      | (false, _)   => {oper = IR.mpush, dst = [IR.REG r],     src = dataL}
218
219(* ---------------------------------------------------------------------------------------------------------------------*)
220(*  Write one argument to the memory slot referred by regNo and offset                                                  *)
(*  Push arguments to the stack. If the argument comes from a register, store it into the stack directly;                *)
(*  if it comes from memory, first load it into the scratch register (R10 by default), then store it into the stack;    *)
(*  if it is a constant, first move it into the scratch register, then store it into the stack.                          *)
224(* ---------------------------------------------------------------------------------------------------------------------*)
225
(* Store one value into the memory slot MEM (regNo, offset).
     IR.MEM v : load it into the scratch register (!numAvaiRegs), then store;
     IR.REG r : store r directly;
     other    : (a constant such as NONCST n or WCONST w) move it into the
                scratch register, then store.
   The scratch register is always !numAvaiRegs (R10 by default), matching
   read_one_arg, so changing numAvaiRegs cannot make the constant case clobber
   a different register than the memory case. *)
fun write_one_arg (IR.MEM v) (regNo,offset) =
      let val scratch = IR.REG (!numAvaiRegs) in
        [ {oper = IR.mldr, dst = [scratch], src = [IR.MEM v]},
          {oper = IR.mstr, dst = [IR.MEM (regNo,offset)], src = [scratch]} ]
      end
 |  write_one_arg (IR.REG r) (regNo,offset) =
      [ {oper = IR.mstr, dst = [IR.MEM (regNo,offset)], src = [IR.REG r]} ]
 |  write_one_arg v (regNo,offset) =   (* v = NONCST n or WCONST w *)
      let val scratch = IR.REG (!numAvaiRegs) in
        [ {oper = IR.mmov, dst = [scratch], src = [v]},
          {oper = IR.mstr, dst = [IR.MEM (regNo,offset)], src = [scratch]} ]
      end
234
235(* ---------------------------------------------------------------------------------------------------------------------*)
236(*  Read one argument from the memory slot referred by regNo and offset                                                 *)
237(*  If the destination is a register, then load it into the register directly;                                          *)
(*  if it is in memory, first load the value into the scratch register (R10 by default), then store it there;           *)
(*  the destination cannot be a constant (invalidArgs is raised).                                                       *)
240(* ---------------------------------------------------------------------------------------------------------------------*)
241
(* Load the value in memory slot MEM (regNo, offset) into dest.
   A register destination is loaded directly.  A memory destination goes through
   the scratch register !numAvaiRegs (load, then store).  Any other destination
   raises invalidArgs. *)
fun read_one_arg dest (regNo,offset) =
    let
        val slot = IR.MEM (regNo, offset)
    in
        case dest of
            IR.REG _ =>
                [ {oper = IR.mldr, dst = [dest], src = [slot]} ]
          | IR.MEM _ =>
                let val scratch = IR.REG (!numAvaiRegs)
                in
                    [ {oper = IR.mldr, dst = [scratch], src = [slot]},
                      {oper = IR.mstr, dst = [dest], src = [scratch]} ]
                end
          | _ => raise invalidArgs
    end
249
250(* ---------------------------------------------------------------------------------------------------------------------*)
251(* Push a list of values that may be constants or in registers or in memory into the stack                              *)
252(* [1,2,3,...]                                                                                                          *)
253(* old pointer | 1 |                                                                                                    *)
254(*             | 2 |                                                                                                    *)
255(*             | 3 |                                                                                                    *)
256(*              ...                                                                                                     *)
257(* new pointer                                                                                                          *)
258(* Note that the elements in the list are stored in the memory from low addresses to high addresses                     *)
259(* ---------------------------------------------------------------------------------------------------------------------*)
260
(* Generate statements that store every value of argL relative to the pointer
   register regNo, then move regNo so that it has been moved by (length argL)
   words in total (dec_p amounts sum to length argL).
     regNo : register number used as the pointer (e.g. IR.fromAlias IR.sp)
     argL  : values to store -- registers, memory slots or constants
   argL is split with mk_reg_segments.  `offset` records how many words regNo has
   already been moved by the statements generated so far, so element i is stored
   at MEM (regNo, !offset - i).  A run of two or more registers is stored with a
   single push.  regNo is first moved (dec_p) so that it sits at element i.  Then the
   push is emitted, and offset is set to i + length regL.  This assumes the push
   itself moves regNo past the pushed registers (mk_ldm_stm puts the base in the
   push's dst).  Raises invalidArgs on a malformed segment. *)
fun pushL regNo argL =
  let
      val offset = ref 0;

      fun one_seg (regL, true, i) =
              if length regL = 1 then
                    write_one_arg (hd regL) (regNo, !offset - i) (* relative offset should be negative *)
              else
                  let val k = !offset in
                    ( offset := i + length regL;
                      [dec_p regNo (i - k),
                       (* reverse the regL in accordance with the order of STM *)
                       mk_ldm_stm false regNo (List.rev regL)])
                  end
       | one_seg ([v], false, i) =
                  write_one_arg v (regNo, !offset - i)
       | one_seg _ = raise invalidArgs
  in
      (* segments are processed left to right; one_seg reads and updates offset *)
      (List.foldl (fn (x,s) => s @ one_seg x) [] (mk_reg_segments argL)) @
       [dec_p regNo (length argL - !offset)]
  end
282
283(* ---------------------------------------------------------------------------------------------------------------------*)
(* Pop values from the stack into a list of destinations that may be registers or memory locations                    *)
285(*           ...                                                                                                        *)
286(*           | 3 |                                                                                                      *)
287(*           | 2 |                                                                                                      *)
288(*           | 1 |                                                                                                      *)
289(*  pointer                                                                                                             *)
290(*  be read to the list [1,2,3,...]                                                                                     *)
291(* ---------------------------------------------------------------------------------------------------------------------*)
292
(* Generate statements that load consecutive stack values, relative to the pointer
   register regNo, into the destinations in argL.  regNo is then moved up (inc_p)
   by (length argL) words in total.
     regNo : register number used as the pointer (e.g. IR.fromAlias IR.sp)
     argL  : destinations -- registers or memory slots
   `offset` records how many words regNo has already been moved by inc_p.  Element i
   is read from MEM (regNo, i - !offset + 1).  A run of two or more registers is
   loaded with a single pop, after moving regNo to element i.
   NOTE(review): after the pop, offset is set to i, not to i + length regL as pushL
   does.  So this assumes the pop leaves regNo unchanged (mk_ldm_stm puts the base
   only in the pop's src); confirm against the mpush/mpop semantics.
   Raises invalidArgs on a malformed segment or a constant destination. *)
fun popL regNo argL =
  let
      val offset = ref 0;

      fun one_seg (regL, true, i) =
              if length regL = 1 then
                    read_one_arg (hd regL) (regNo, i - !offset + 1) (* relative address should be positive *)
              else
                  let val k = !offset in
                    ( offset := i;
                      [inc_p regNo (i - k),
                       mk_ldm_stm true regNo regL])
                  end
       | one_seg ([v], false, i) =
                  read_one_arg v (regNo, i - !offset + 1)
       | one_seg _ = raise invalidArgs
  in
      (* segments are processed left to right; one_seg reads and updates offset *)
      (List.foldl (fn (x,s) => s @ one_seg x) [] (mk_reg_segments argL)) @
       [inc_p regNo (length argL - !offset)]
  end
313
314(* ---------------------------------------------------------------------------------------------------------------------*)
(* The caller passes the parameters to the callee                                                                       *)
316(* All arguments are passed through the stack                                                                           *)
317(* Stack status after passing                                                                                           *)
318(*          ...                                                                                                         *)
319(*        | arg 3 |                                                                                                     *)
320(*        | arg 2 |                                                                                                     *)
321(* sp     | arg 1 |                                                                                                     *)
322(* ---------------------------------------------------------------------------------------------------------------------*)
323
(* Caller side: push the actual arguments argL relative to sp.
   The list is reversed before pushing, so the last argument is stored first. *)
fun pass_args argL =
  let
      val lastFirst = List.rev argL
      val spReg = IR.fromAlias IR.sp
  in
      pushL spReg lastFirst
  end
326
327(* ---------------------------------------------------------------------------------------------------------------------*)
(* The callee obtains the arguments passed by the caller through the stack.                                            *)
(* Every argument is loaded from the stack (addressed via ip, which equals the caller's sp)                            *)
(* into its destination register or memory slot, in the order given by argL.                                          *)
(* Note that the register allocation assumes this convention.                                                          *)
332(* ---------------------------------------------------------------------------------------------------------------------*)
333
(* Callee side: load the arguments passed by the caller through the stack.
     argL : the callee's input locations (registers or memory slots), in order
   All arguments are read with popL relative to ip.  At this point ip equals the
   caller's sp, because entry_blk starts with ip := sp.  Raises invalidArgs (via
   popL / read_one_arg) if a destination is neither a register nor a memory slot.
   (The previous version computed a length, a bound on numAvaiRegs and a register
   list that were never used; they have been removed.) *)
fun get_args argL =
    (* Note that callee's IP equals to caller's SP, we use the IP here to load the arguments *)
    popL (IR.fromAlias IR.ip) argL;
346
347(* ---------------------------------------------------------------------------------------------------------------------*)
(* The callee passes the results to the caller                                                                          *)
349(* All results are passed through the stack                                                                             *)
350(* Stack status after passing                                                                                           *)
351(*                                                                                                                      *)
352(*             ...                                                                                                      *)
353(*         | result 3 |                                                                                                 *)
354(*         | result 2 |                                                                                                 *)
355(*         | result 1 |                                                                                                 *)
356(*    sp                                                                                                                *)
357(* ---------------------------------------------------------------------------------------------------------------------*)
358
(* Callee side: hand the results outL back to the caller through the stack.
     outL    : the callee's output values
     numArgs : number of arguments the caller pushed
   First sets sp := fp + 4*(numArgs + length outL + 1), skipping the arguments and
   the stored pc to the position right before the first output.  Then it pushes the
   outputs in reverse order relative to sp. *)
fun send_results outL numArgs =
   let
       val spReg = IR.fromAlias IR.sp
       val skipWords = numArgs + length outL + 1
       val moveSp =
           { oper = IR.madd,
             dst = [IR.REG spReg],
             src = [IR.REG (IR.fromAlias IR.fp), IR.WCONST (Arbint.fromInt (4 * skipWords))] }
   in
       moveSp :: pushL spReg (List.rev outL)
   end
371
372(* ---------------------------------------------------------------------------------------------------------------------*)
(* The caller retrieves the results passed by the callee through the stack                                             *)
374(* Stack status                                                                                                         *)
375(*                                                                                                                      *)
376(*             ...                                                                                                      *)
377(*         | result 3 |                                                                                                 *)
378(*         | result 2 |                                                                                                 *)
379(*         | result 1 |                                                                                                 *)
380(*     sp  |          |                                                                                                 *)
381(* ---------------------------------------------------------------------------------------------------------------------*)
382
(* Caller side: load the callee's results from the stack, relative to sp, into the
   call-site destinations outL (registers or memory slots). *)
fun get_results outL =
    let val spReg = IR.fromAlias IR.sp
    in popL spReg outL end
385
386(* ---------------------------------------------------------------------------------------------------------------------*)
387(* Function Call: Pre-processing, Post-processing and Adjusted Body                                                     *)
388(* ---------------------------------------------------------------------------------------------------------------------*)
389
(* Compute the annotations of the pre-call, body and post-call blocks of a function
   call, and the registers that must be preserved across it.
     (outer_ins, outer_outs)   : ins/outs of the whole CALL node
     (caller_src, caller_dst)  : actual arguments / result destinations at the call site
     (callee_ins, callee_outs) : the callee's formal inputs / outputs
     rs                        : set of register numbers recorded for the callee
                                 (link_ir stores the registers modified by its body)
     context                   : expressions of the surrounding context
   rs' = registers of rs that are not call results (caller_dst) and that appear in
   outer_outs or context.  During the body they are represented by the stack slots
   produced by to_stack.
   Returns ((pre_ins,pre_outs),(body_ins,body_outs),(post_ins,post_outs),
            rs' as a list, context minus rs'). *)
fun compute_fcall_info ((outer_ins,outer_outs),(caller_src,caller_dst),(callee_ins,callee_outs),rs,context) =
    let
        (* Map the i-th (1-based) of len expressions to MEM (11, ~(3 + len - i)),
           i.e. offsets ~(2+len) .. ~3.
           NOTE(review): register 11 is hard-coded here, whereas other blocks use
           IR.fromAlias IR.fp; the counter also relies on List.map visiting elements
           left to right -- confirm both. *)
        fun to_stack expL =
            let val len = length expL
                val i = ref 0;
            in
                List.map (fn exp => ( i := !i + 1; MEM(11, ~(3 + (len - !i))))) expL
            end

        (* candidate registers: modified by the callee but not written as call results *)
        val to_be_stored = S.difference (list2set (List.map REG (S.listItems rs)), pair2set caller_dst);
        (* keep only those that are still needed after the call *)
        val rs' = S.intersection(to_be_stored, S.union (pair2set outer_outs, list2set context));
        val pre_ins = trim_pair (PAIR (caller_src, set2pair rs'));
        val stored = (list2pair o to_stack o set2list) rs';
        val pre_outs = trim_pair (PAIR (callee_ins, stored));
        val body_ins = pre_outs;
        val body_outs = trim_pair(PAIR (callee_outs, stored));
        val post_ins = body_outs;
        val post_outs = trim_pair(PAIR (caller_dst, set2pair rs'));
        val context' = set2list (S.difference (list2set context, rs'))
    in
       ((pre_ins,pre_outs),(body_ins,body_outs),(post_ins,post_outs),set2list rs',context')
    end
412
413
(* Definition theorems of the callees expanded by convert_fcall.  link_ir resets
   this list and appends it to the definitions it returns. *)
val involved_defs = ref ([] : thm list);
415
(* Expand every CALL node of an annotated IR into explicit calling-sequence code.
   For CALL (fname, pre, body, post, info), the callee is looked up with
   declFuncs.getFunc and its definition theorem is appended to involved_defs.  The
   result keeps fname and the outer annotation, and replaces:
     pre  -> reserve space for the results, push the actual arguments (pass_args),
             the entry sequence (entry_blk) and loading of the callee's inputs (get_args);
     body -> the callee's IR with its top-level annotation replaced;
     post -> push the callee's outputs (send_results), load them into the call-site
             destinations (get_results) and the exit sequence (exit_blk).
   The original pre and post are discarded; only body's annotation is used.
   SC, CJ and TR are traversed recursively; any other node is returned unchanged.
   NOTE(review): callee_ir itself is not passed through convert_fcall here, so calls
   inside the callee body are not expanded by this invocation -- confirm intended. *)
fun convert_fcall (CALL(fname, pre, body, post, outer_info)) =
    let
        (* Outline: reserve result space and pass the arguments, enter the callee's  *)
        (* frame, run the callee body, then send back and collect the results and    *)
        (* leave the frame.                                                           *)

        val (outer_ins,outer_outs) = (#ins outer_info, #outs outer_info);
        (* the call site's arguments/destinations are taken from body's annotation *)
        val (caller_dst,caller_src) = let val x = get_annt body in (#outs x, #ins x) end;
        val {ir = (callee_ins, callee_ir, callee_outs), regs = rs, localNum = n, def = f_def, ...} = declFuncs.getFunc fname;
        val _ = involved_defs := (!involved_defs) @ [f_def];
        val ((pre_ins,pre_outs),(body_ins,body_outs),(post_ins,post_outs),rs',context) =
             compute_fcall_info ((outer_ins,outer_outs),(caller_src,caller_dst),(callee_ins,callee_outs),rs,#context outer_info);

        val reserve_space_for_outputs =
                [dec_p (fromAlias sp) (length (pair2list caller_dst))];

        val pre' = rm_dummy_inst (BLK (
                        reserve_space_for_outputs @
                        pass_args (pair2list caller_src) @
                        entry_blk rs' n @
                        get_args (pair2list callee_ins),
                    {ins = pre_ins, outs = pre_outs, context = context, fspec = thm_t}));

        val body' = apply_to_info callee_ir (fn info' => {ins = body_ins, outs = body_outs, context = context, fspec = thm_t});

        val post' = rm_dummy_inst (BLK (
                        send_results (IR.pair2list callee_outs) (length (IR.pair2list callee_ins)) @
                         get_results (IR.pair2list caller_dst) @
                        exit_blk rs',
                    {ins = post_ins, outs = post_outs, context = context, fspec = thm_t}))

    in
        CALL(fname, pre', body', post', outer_info)
    end

 |  convert_fcall (SC(s1,s2,info)) = SC (convert_fcall s1, convert_fcall s2, info)
 |  convert_fcall (CJ(cond,s1,s2,info)) = CJ (cond, convert_fcall s1, convert_fcall s2, info)
 |  convert_fcall (TR(cond,s,info)) = TR (cond, convert_fcall s, info)
 |  convert_fcall ir = ir;
455
456
457(* ---------------------------------------------------------------------------------------------------------------------*)
458(* Link caller and callees together                                                                                     *)
459(* ---------------------------------------------------------------------------------------------------------------------*)
460
(* Translate a program with sfl2ir and link its function calls.
   Steps: collect the registers modified by the body into a set; rewrite temporaries
   into fp-relative slots (calculate_relative_address); reset involved_defs and
   register the function with declFuncs.putFunc; expand calls (convert_fcall); then
   apply match_ins_outs.
   Returns (fname, ftype, (ins', ir', outs'), defs @ definitions of expanded callees).
   Raises Empty (from hd) if sfl2ir returns no definitions. *)
fun link_ir prog =
  let
      val (fname, ftype, f_ir as (ins,ir0,outs), defs) = sfl2ir prog;
      val rs = S.addList (S.empty regAllocation.intOrder, get_modified_regs ir0);
      val (ins1,ir1,outs1, localNum) = calculate_relative_address (ins,ir0,outs,S.numItems rs);
      (* register this function before expanding calls, so convert_fcall can look up fname *)
      val _ = (involved_defs := [];
               declFuncs.putFunc (fname, ftype, (ins1,ir1,outs1), rs, localNum, (hd defs)));
      val ir2 = convert_fcall ir1
      val ir3 = match_ins_outs ir2
  in
      (fname,ftype,(ins1,ir3,outs1), defs @ (!involved_defs))
  end;
473
474
475end (* local structure ... *)
476
477end (* structure *)
478