1structure funcCall (* :> funcCall *) =
2struct
3
4(* app load ["NormalTheory", "basic", "regAlloc"] *)
5
6open HolKernel Parse boolLib bossLib;
7open pairLib pairSyntax PairRules NormalTheory basic;
8
(* The HOL constant Normal$atom, and a constructor that applies it (instantiated at
   the argument's type) to a term: mk_atom tm = ``atom tm``. *)
val atom_tm = prim_mk_const {Thy = "Normal", Name = "atom"}
fun mk_atom tm =
  let val ty = type_of tm
  in mk_comb (inst [alpha |-> ty] atom_tm, tm)
  end
11
12(*----------------------------------------------------------------------------------------------*)
(* Auxiliary structures and global settings                                                     *)
14(*----------------------------------------------------------------------------------------------*)
15
16structure M = Binarymap
17structure S = Binaryset
18val VarType = ref (Type `:word32`) (* numSyntax.num *)
19
20(*----------------------------------------------------------------------------------------------*)
21(* Pre-defined variables and functions                                                          *)
22(*----------------------------------------------------------------------------------------------*)
23
(* is_reg x: x is a HOL variable whose name starts with "r" (a register).
   is_mem x: x is a HOL variable whose name starts with "m" (a memory slot).
   The test inspects the variable's name instead of the pretty-printed term, so it
   does not depend on printer settings and is false for constants (the previous
   string-based test also accepted constants such as a function named "max"). *)
fun is_reg x = is_var x andalso String.isPrefix "r" (#1 (dest_var x))
fun is_mem x = is_var x andalso String.isPrefix "m" (#1 (dest_var x))
26
(* Total order on terms, comparing their printed forms lexicographically.
   It keys the sets (S) and maps (M) used in this file.  Note that two distinct
   terms with the same printed form compare EQUAL. *)
fun tvarOrder (t1 : term, t2 : term) =
  String.compare (term_to_string t1, term_to_string t2)
33
(* Is an expression a function, i.e. a term of function type that is not itself an application? *)
35
(* is_fun exp: true iff exp is not a combination and its type is a function type
   (type operator "fun").  Any HOL_ERR raised on the way (e.g. by dest_type on a
   type variable) yields false. *)
fun is_fun exp =
  (if is_comb exp then false
   else case dest_type (type_of exp) of
          ("fun", _) => true
        | _ => false)
  handle HOL_ERR _ => false;
40
41(*----------------------------------------------------------------------------------------------*)
42(* Function calls in a caller-save style.                                                       *)
43(*----------------------------------------------------------------------------------------------*)
44
45(* Traverse the function body to find all modified registers and the next available memory slot *)
46
(* process_body body: traverse a function body and return (wS, slot) where
     wS   - the set of registers (is_reg) bound by let-expressions in the body,
            i.e. the registers the body may overwrite;
     slot - the first memory slot index the body does not use: one more than the
            largest i such that a memory variable m<i> occurs in the body, or 0 if
            no memory variable occurs.
   Memory variables count as used whether they are read or bound by a let, so slots
   handed out from [slot] onwards cannot clash with the body's own slots.  (The
   previous version returned the largest used index itself, so m0 being used and no
   slot being used both gave 0, and let-bound slots were not counted.) *)
fun process_body body =
 let
   (* rS accumulates every register/memory leaf and every let-bound memory
      variable; wS accumulates the registers bound by let-expressions. *)
   fun traverse t (rS, wS) =
     if is_let t then
       let val (v, e, rest) = dest_plet t
           val bound = strip_pair v
           val rS0 = S.addList (rS, List.filter is_mem bound)
           val wS0 = S.addList (wS, List.filter is_reg bound)
       in
         traverse rest (traverse e (rS0, wS0))
       end
     else if is_cond t then
       let val (c, e1, e2) = dest_cond t in
         traverse e2 (traverse e1 (traverse c (rS, wS)))
       end
     else if is_pair t then
       let val (a, b) = dest_pair t in traverse b (traverse a (rS, wS)) end
     else if is_pabs t then
       let val (p, b) = dest_pabs t in traverse b (traverse p (rS, wS)) end
     else if is_comb t then
       let val (f, x) = dest_comb t in traverse x (traverse f (rS, wS)) end
     else if is_reg t orelse is_mem t then
       (S.add (rS, t), wS)
     else (rS, wS)

   (* index i of a memory variable m<i>; NONE for constants and for names without
      a numeric suffix, which therefore cannot crash the slot computation *)
   fun slot_index v =
     if is_var v andalso is_mem v then
       Int.fromString (String.extract (#1 (dest_var v), 1, NONE))
     else NONE

   val (rS', wS') = traverse body (S.empty tvarOrder, S.empty tvarOrder)
   val next_avail_slot =   (* the first unused memory slot *)
     List.foldl (fn (v, n) => case slot_index v of
                                 SOME i => Int.max (i + 1, n)
                               | NONE => n)
                0 (S.listItems rS')
 in
   (wS', next_avail_slot)
 end;
92
(* Record, for a function definition, its inputs/outputs, the registers it modifies and its next
   available memory slot in the global table fmap *)
94
(* Global table: function name |-> ((inputs, outputs), (written registers, first free slot)).
   The outputs are whatever identify_output computes for the body; the second
   component is the result of process_body. *)
val fmap = ref (M.mkDict tvarOrder);

(* investigate_def def: record in fmap the entry for the function defined by def,
   which must have the shape  |- f = \args. body  or  |- f args = body.
   The two shapes are distinguished explicitly with is_pabs rather than with a
   catch-all exception handler, which would also have swallowed Interrupt. *)
fun investigate_def def =
  let
    val (def_lhs, def_rhs) = dest_eq (concl (SPEC_ALL def))
    val (fname, args, body) =
      if is_pabs def_rhs then
        let val (args, body) = dest_pabs def_rhs in (def_lhs, args, body) end
      else
        let val (f, args) = dest_comb def_lhs in (f, args, def_rhs) end
  in
    fmap := M.insert (!fmap, fname, ((args, identify_output body), process_body body))
  end
105
(* Wrap an expression with caller-save code: "save" stores registers into consecutive memory
   slots before it, "restore" reloads them from the same slots *)
107
(* save (wS, next_slot) exp: prefix exp with one binding per register in wS, taken in
   S.listItems order.  The i-th register r (counting from 0) is stored by
       let m<next_slot+i> = atom r in ...
   where m<k> is a variable of type !VarType.  The binding for the first register is
   the innermost one. *)
fun save (wS, next_slot) exp =
  let
    fun slot_var k = mk_var ("m" ^ Int.toString k, !VarType)
    fun spill ([], _, acc) = acc
      | spill (r :: rs, k, acc) =
          spill (rs, k + 1, mk_plet (slot_var k, mk_atom r, acc))
  in
    spill (S.listItems wS, next_slot, exp)
  end
114
(* restore (wS, next_slot) exp: prefix exp with one binding per register in wS, taken
   in S.listItems order.  The i-th register r (counting from 0) is reloaded by
       let r = atom m<next_slot+i> in ...
   i.e. from the same slot that save (wS, next_slot) assigns to it.  The binding for
   the first register is the innermost one. *)
fun restore (wS, next_slot) exp =
  let
    fun reload (r, (body, k)) =
      (mk_plet (r, mk_atom (mk_var ("m" ^ Int.toString k, !VarType)), body), k + 1)
    val (result, _) = List.foldl reload (exp, next_slot) (S.listItems wS)
  in
    result
  end
121
122(*
123  fmap := preprocess defs;
124*)
125
(* The function currently being transformed (set by caller_save_call).  Applications
   of this function are treated as recursive calls and are not wrapped with
   caller-save code. *)
val tr_f : term ref = ref ``T``
127
(* format_call (fname, dst, src, cont): generate caller-save code for
   `let dst = fname src in cont`, using fname's entry in fmap
   ((formal inputs src0, formal outputs dst0), (registers it writes, first free slot)).
   The saved set is: written registers + variables of src0 + variables of dst0,
   minus the variables of dst.  The result is, from outside in:
     save saved-set; parallel_move src0 src; let dst0 = fname src0 in
     parallel_move dst dst0; restore saved-set; cont
   (assuming regAlloc.parallel_move x y e moves y into x before e - confirm in regAlloc).
   Raises NotFound (from M.find) if fname has not been recorded in fmap. *)
fun format_call (fname, dst, src, cont) =
  let
    val ((src0, dst0), (written, slot)) = M.find (!fmap, fname)
    fun var_set ts = S.addList (S.empty tvarOrder, ts)
    val clobbered = S.addList (S.addList (written, strip_pair src0), strip_pair dst0)
    val to_save = S.difference (clobbered, var_set (strip_pair dst))
    val restored = restore (to_save, slot) cont
    val after_call = regAlloc.parallel_move dst dst0 restored
    val call = mk_plet (dst0, mk_comb (fname, src0), after_call)
    val before_call = regAlloc.parallel_move src0 src call
  in
    save (to_save, slot) before_call
  end
144
(* caller_save t: rewrite t so that every let-bound application `let v = f y in N`
   of a function f other than !tr_f (a non-recursive call) is expanded by format_call.
   The continuation N is transformed first; the call's own argument y is left as is.
   All other constructs are rebuilt with caller_save applied to their subterms; the
   condition of an if-then-else is kept unchanged.
   (Used by the commented-out callerSave driver; caller_save_call uses trav.) *)
fun caller_save t =
  let
    (* SOME (f, arg) when e is an application of a non-recursive function f *)
    fun non_rec_call e =
      if is_comb e andalso not (is_atomic e) then
        let val (f, arg) = dest_comb e
        in if is_fun f andalso not (f = !tr_f) then SOME (f, arg) else NONE
        end
      else NONE
  in
    if is_let t then
      let val (v, e, rest) = dest_plet t in
        case non_rec_call e of
          SOME (f, arg) => format_call (f, v, arg, caller_save rest)
        | NONE => mk_plet (v, caller_save e, caller_save rest)
      end
    else if is_cond t then
      let val (c, e1, e2) = dest_cond t
      in mk_cond (c, caller_save e1, caller_save e2)
      end
    else if is_pair t then
      let val (a, b) = dest_pair t
      in mk_pair (caller_save a, caller_save b)
      end
    else if is_pabs t then
      let val (p, b) = dest_pabs t
      in mk_pabs (caller_save p, caller_save b)
      end
    else if is_comb t then
      let val (f, x) = dest_comb t
      in mk_comb (caller_save f, caller_save x)
      end
    else t
  end
179
(* Function "trav" traverses a term and adds pre-call saving and post-call restoring instructions
   around each non-recursive function call; it also makes the outputs of the two branches of a
   conditional statement match, i.e. both branches deliver their results in the same variables.
*)
184
(* trav t output: return (t', output') where t' is t with caller-save code added
   around each non-recursive call (via format_call).
   output is the target for results in tail position:
     NONE   - not fixed yet; the first pair/atomic result reached fixes it;
     SOME x - every pair/atomic result r is replaced by regAlloc.parallel_move x r x.
   In a conditional, the then-branch is processed with output and the output it
   returns is imposed on the else-branch; the then-branch's output is returned.
   Let-bound right-hand sides and lambda bodies get their own output (NONE).
   Any other term is returned unchanged together with output. *)
fun trav t output =
  let
    (* SOME (f, arg) when e is an application of a non-recursive function f *)
    fun non_rec_call e =
      if is_comb e andalso not (is_atomic e) then
        let val (f, arg) = dest_comb e
        in if is_fun f andalso not (f = !tr_f) then SOME (f, arg) else NONE
        end
      else NONE
  in
    if is_let t then
      let
        val (v, e, rest) = dest_plet t
        val (e', _) = trav e NONE
      in
        case non_rec_call e of
          SOME (f, arg) =>
            let val (rest', output') = trav rest output
            in (format_call (f, v, arg, rest'), output')
            end
        | NONE =>
            let val (rest', output') = trav rest output
            in (mk_plet (v, e', rest'), output')
            end
      end
    else if is_cond t then
      let
        val (c, e1, e2) = dest_cond t
        val (e1', output1) = trav e1 output
        val (e2', _) = trav e2 output1
      in
        (mk_cond (c, e1', e2'), output1)
      end
    else if is_pair t orelse is_atomic t then
      (case output of
         NONE => (t, SOME t)
       | SOME x => (regAlloc.parallel_move x t x, output))
    else if is_pabs t then
      let
        val (p, b) = dest_pabs t
        val (b', _) = trav b NONE
      in
        (mk_pabs (p, b'), output)
      end
    else (t, output)
  end
226
227(* Process function calls in a caller-save style *)
228
(* caller_save_call def: rewrite the definition theorem def, of shape
   |- f = \args. body  or  |- f args = body, so that the non-recursive calls in its
   body are performed in caller-save style, and return the rewritten theorem.
   pre_check (defined elsewhere) decides whether the transformation applies (sane)
   and provides the type for memory-slot variables (var_type); if it does not apply,
   def is returned unchanged.
   Side effects: sets VarType and tr_f, and records f in fmap via investigate_def.
   The two definition shapes are distinguished with is_pabs rather than with a
   catch-all exception handler, which would also have swallowed Interrupt. *)
fun caller_save_call def =
  let
    val (def_lhs, def_rhs) = dest_eq (concl (SPEC_ALL def))
    val (fname, args, body) =
      if is_pabs def_rhs then
        let val (args, body) = dest_pabs def_rhs in (def_lhs, args, body) end
      else
        let val (f, args) = dest_comb def_lhs in (f, args, def_rhs) end
    val (sane, var_type) = pre_check (args, body)
    val _ = (VarType := var_type; tr_f := fname)
    val _ = investigate_def def   (* store the information of the current function into !fmap *)
  in
    if not sane then def
    else
      let
        val (body', _) = trav body NONE
        (* th0 : simplified body' = body'  (simplification with LET_ATOM) *)
        val th0 = SYM (QCONV (SIMP_CONV pure_ss [LET_ATOM]) body')
        val (simplified, _) = dest_eq (concl th0)
        (* ALPHA fails unless the simplified term is alpha-equivalent to the original
           body, so th1 : body = body' *)
        val th1 = TRANS (ALPHA body simplified) th0
        val th2 = REWRITE_RULE [Once th1] def
        val th3 = REWRITE_RULE [ATOM_ID] th2
      in
        SIMP_RULE bool_ss [ELIM_USELESS_LET] th3
      end
  end
249
250(*
251fun callerSave defs =
252 let
253   fun one_fun def =
254    let val (fname, fbody) = dest_eq (concl (SPEC_ALL def))
255        val (args,body) = dest_pabs fbody handle _ => (#2 (dest_comb fname), fbody)
256        val (sane,var_type) = pre_check(args,body)
257        val fname = if is_pabs fbody then fname else #1 (dest_comb fname)
258        val _ = (VarType := var_type; tr_f := fname)
259    in if sane then
260        let
261          val body' = caller_save body
262          val th0 = SYM (QCONV(SIMP_CONV pure_ss [LET_ATOM]) body')
263          val th1 = REWRITE_RULE [ELIM_USELESS_LET] th0
264          val th2 = REWRITE_RULE [Once th1] def
265          val th3 = REWRITE_RULE [ATOM_ID] th2
266        in th3
267        end
268       else def
269    end
270 in
271   List.map one_fun defs
272 end
273*)
274
275(*----------------------------------------------------------------------------------------------*)
276(* For debugging                                                                                *)
277(*----------------------------------------------------------------------------------------------*)
278
(* For debugging: list every entry of fmap as
   (function, (inputs, outputs, written registers as a list, first free slot)). *)
fun mm () =
  let
    fun flatten (fname, ((ins, outs), (written, slot))) =
      (fname, (ins, outs, S.listItems written, slot))
  in
    map flatten (M.listItems (!fmap))
  end
281
282
283end
284