1structure annotatedIR = struct
2  local
3
4(*
5quietdec := true;
6*)
7
8  open HolKernel Parse boolLib preARMTheory pairLib simpLib bossLib
9       numSyntax optionSyntax listSyntax ILTheory IRSyntax
10(*
11        quietdec := false;
12        open annotatedIR
13*)
14  in
15
16 (*---------------------------------------------------------------------------------*)
17 (*      Annotated IR tree                                                          *)
18 (*---------------------------------------------------------------------------------*)
19
   (* Annotation attached to each node of the IR tree.
        fspec   : a theorem for the node; the annotations built in this file
                  all use |- T (thm_t or DECIDE ``T``), presumably as a
                  placeholder refined by later passes -- NOTE(review): confirm
        ins     : inputs of the node, a pair-encoded expression (see pair2list)
        outs    : outputs of the node, same encoding
        context : further expressions carried along with the node (filled in
                  by back_trace)                                              *)
   type annt = {fspec : thm, ins : exp, outs : exp, context : exp list};

   (* Condition of CJ and TR nodes: left operand, relational operator, right operand *)
   type cond = exp * rop * exp;

   (* Annotated IR tree:
        TR   : tail-recursion (loop) node built by convert_module: condition and body
        SC   : sequential composition of two trees
        CJ   : conditional with condition, true branch and false branch
        BLK  : annotated list of instructions
        STM  : unannotated list of instructions (get_annt gives it a default annotation)
        CALL : call of the named function, with pre, body and post parts         *)
   datatype anntIR =  TR of cond * anntIR * annt
                   |  SC of anntIR * anntIR * annt
                   |  CJ of cond * anntIR * anntIR * annt
                   |  BLK of instr list * annt
                   |  STM of instr list
                   |  CALL of string * anntIR * anntIR * anntIR * annt
32
33
34 (*---------------------------------------------------------------------------------*)
35 (*      Print an annotated IR tree                                                 *)
36 (*---------------------------------------------------------------------------------*)
37
38  val show_call_detail = ref true;
39
40  fun print_ir (outstream, s0) =
41    let
42
43      fun say s = TextIO.output(outstream,s)
44      fun sayln s= (say s; say "\n")
45
46      fun indent 0 = ()
47        | indent i = (say " "; indent(i-1))
48
49      fun printtree(SC(s1,s2,info),d) =
50            (indent d; sayln "SC("; printtree(s1,d+2); sayln ","; printtree(s2,d+2); say ")")
51        | printtree(CJ((exp1,rop,exp2),s1,s2,info),d) =
52            (indent d; say "CJ("; say (format_exp exp1 ^ " " ^ print_rop rop ^ " " ^ format_exp exp2); sayln ",";
53                       printtree(s1,d+2); sayln ","; printtree(s2,d+2); say ")")
54        | printtree(TR((exp1,rop,exp2),s,info),d) =
55            (indent d; say "TR("; say (format_exp exp1 ^ " " ^ print_rop rop ^ format_exp exp2); sayln ",";
56                       printtree(s,d+2); say ")")
57        | printtree(CALL(fname,pre,body,post,info),d) =
58            if not (!show_call_detail) then
59                (indent d; say ("CALL(" ^ fname ^ ")"))
60            else
61                let val ir' = SC (pre, SC (body, post, info), info)
62                    (* val ir'' = (merge_stm o rm_dummy_inst) ir' *)
63                in
64                    printtree (ir', d)
65                end
66
67        | printtree(STM stmL,d) =
68            (indent d; if null stmL then say ("[]")
69                       else say ("[" ^ formatInst (hd stmL) ^
70                                 (itlist (curry (fn (stm,str) => "; " ^ formatInst stm ^ str)) (tl stmL) "") ^ "]"))
71        | printtree(BLK(stmL, info),d) =
72            printtree(STM stmL,d)
73
74    in
75      printtree(s0,0); sayln ""; TextIO.flushOut outstream
76    end
77
78  fun printIR ir = print_ir (TextIO.stdOut, ir)
79
80  fun printIR2 (f_name, f_type, (ins,ir,outs), defs) =
81      ( print ("Module: " ^ f_name ^ "\n");
82        print ("Inputs: " ^ format_exp ins ^ "\n" ^ "Outputs: " ^ format_exp outs ^ "\n");
83        printIR ir
84      )
85
86 (*---------------------------------------------------------------------------------*)
87 (*      Assistant Graph Functions                                                  *)
88 (*---------------------------------------------------------------------------------*)
89
90   structure G = Graph;
91
92   fun mk_edge (n0,n1,lab) gr =
93    if n0 = n1 then gr
94    else if (List.find (fn m => #2 m = n0) (#1(G.context(n1,gr)))) <> NONE then gr              (* raise Edge *)
95    else
96       let val ((p1,m1,l1,s1), del_n1) = G.match(n1,gr)
97       in G.embed(((lab,n0)::p1,m1,l1,s1),del_n1)
98    end;
99
100   (* retreive the subgraph starting from the start node and ends when "f node" holds *)
101   fun sub_graph gr start_node f =
102       let
103           fun one_node (nodeNo,gr') =
104               if f nodeNo then
105                   gr'
106               else
107                   List.foldl (fn ((a,b),gr'') =>
108                                   one_node (b, G.embed (([(a,nodeNo)],b,#3 (G.context(b,gr)),[]), gr''))
109                              handle e => mk_edge (nodeNo,b,0) gr'')
110                            gr' (#4 (G.context(nodeNo,gr)))
111       in
112           one_node (start_node, G.embed(([],start_node, #3 (G.context(start_node,gr)),[]),G.empty))
113       end;
114
115    (* locate a node, whenever the final node is reached, search terminats *)
116    fun locate gr start_node f =
117       let
118           val maxNodeNo = G.noNodes gr - 1;
119           fun one_node nodeNo =
120               if f nodeNo then
121                   (SOME nodeNo,true)
122               else if nodeNo = maxNodeNo then
123                   (NONE,false)
124               else
125                   List.foldl (fn ((a,b),(n,found)) => if found then (n,found) else one_node b)
126                            (SOME start_node,false) (#4 (G.context(nodeNo,gr)))
127       in
128           #1 (one_node start_node)
129       end;
130
131 (*---------------------------------------------------------------------------------*)
132 (*      Conversion from CFG to IR tree                                             *)
133 (*---------------------------------------------------------------------------------*)
134
135   fun get_label (Assem.LABEL {lab = lab'}) = lab'
136    |  get_label _ = raise ERR "IL" ("Expecting a label instruction")
137
138   (* find the node including the function label, if there are more than one incoming edges, then a rec is found *)
139
140   fun is_rec gr n =
141       let
142           val context = G.context(n,gr)
143           fun is_label (Assem.LABEL _) = true
144            |  is_label _ = false
145       in
146           if is_label (#instr ((#3 context):CFG.node)) then
147                (length (#1 context) > 1, n)
148           else is_rec gr (#2 (hd (#4 context)))
149       end
150       handle e => (false,0);        (* no label node in the graph *)
151
152
   (* Given a TR (tail-recursive) cfg and the node holding the function-name label,
      break the cfg into three parts: the pre-condition part, the base-case part and
      the recursive-case part.  The loop condition is also derived.

      Returns (cond, ((p_cfg, p_start), (b_cfg, b_start), (r_cfg, r_start)), join_node);
      p_cfg and p_start are NONE when there is no pre-condition part.

      Edge labels are used as flags below.  Their meaning is inferred from these
      lookups only -- NOTE(review): confirm against the CFG builder:
        - the conditional-jump node cj_node has out-edges labelled 1 (to the
          base case) and 0 (to the recursive case);
        - the base case starts at the 2-labelled successor of BAL_node;
        - the recursive BL call reaches the label node over a 1-labelled edge.
      Raises Option (valOf) or Empty (hd) if the cfg does not have this shape. *)


   fun break_rec gr lab_node =
     let
        (* successor / predecessor lists, as (label, node) pairs, of node n in gr *)
        fun get_sucL n = #4 (G.context(n,gr));
        fun get_preL n = #1 (G.context(n,gr));

        (* a node with no successors, searched from node 0 *)
        val last_node = valOf (locate gr 0 (fn n => null (get_sucL n)));
        (* walk back along first predecessors to the first node with several predecessors *)
        fun find_join_node n =
            if length (get_preL n) > 1 then n
            else find_join_node (#2 (hd (get_preL n)));
        val join_node = find_join_node last_node;  (* the node that the basic and recursive parts join *)

        (* the nodes jumping to the join node *)
        val BAL_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_preL join_node)));
        val b_end_node = #2 (valOf (List.find (fn (flag,n) => flag = 0) (get_preL join_node)));
        (* base case: from the 2-labelled successor of BAL_node up to b_end_node *)
        val b_start_node = #2 (valOf (List.find (fn (flag,n) => flag = 2) (get_sucL BAL_node)));
        val b_cfg = sub_graph gr b_start_node (fn n => n = b_end_node);

        (* the conditional jump into the base case and the compare just before it *)
        val cj_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_preL b_start_node)));
        val cmp_node = #2 (hd (get_preL cj_node));
        val cond = to_cond (#instr ((#3 (G.context(cmp_node,gr))):CFG.node),
                            #instr ((#3 (G.context(cj_node,gr))):CFG.node));

        (* recursive case: from the 0-labelled successor of cj_node up to the node
           just before the BL node (the 1-labelled predecessor of the label) *)
        val r_start_node = #2 (valOf (List.find (fn (flag,n) => flag = 0) (get_sucL cj_node)));
        val BL_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_preL lab_node)));
        val r_end_node = #2 (hd (get_preL BL_node));
        val r_cfg = sub_graph gr r_start_node (fn n => n = r_end_node);

        (* pre-condition part: present when the label is not followed directly by
           the compare.  It ends at the first predecessor of cj_node, which is
           cmp_node itself, so p_cfg contains the compare node.
           NOTE(review): converting p_cfg then reaches the CMP case of convert,
           whose break_cj needs cmp_node's successors, which p_cfg does not
           contain -- confirm this path works. *)
        val p_start_node = if #2 (hd (get_sucL lab_node)) = cmp_node then NONE
                           else SOME (#2 (hd (get_sucL lab_node)));
        val p_cfg = case p_start_node of
                          NONE => NONE  (* the pre-condition part may be empty *)
                      |   SOME k => SOME (sub_graph gr k (fn n => n = #2 (hd (get_preL cj_node))))
    in
        (cond, ((p_cfg,p_start_node), (b_cfg,b_start_node), (r_cfg,r_start_node)), join_node)
    end
192
   (* Given a CJ cfg and the node including the cmp instruction, break this cfg into the true case part and
      the false case part.

      Returns (cond, ((t_cfg, t_start), (f_cfg, f_start)), join_node).  cond is
      built by to_cond from the cmp instruction and the conditional-jump
      instruction that follows it; join_node is where the two parts meet again.

      Shape assumed by the lookups below (labels inferred from this code only --
      NOTE(review): confirm against the CFG builder):
        cmp_node -> cj_node;   cj_node -1-> t_start;   cj_node -0-> f_start;
        f_start ... f_end -> bal_node;   bal_node -2-> t_start;
        bal_node -1-> join_node;   t_start ... t_end -0-> join_node.
      Raises Option (valOf) or Empty (hd) if the cfg does not have this shape. *)

   fun break_cj gr cmp_node =
     let
        (* successor / predecessor lists, as (label, node) pairs, of node n in gr *)
        fun get_sucL n = #4 (G.context(n,gr));
        fun get_preL n = #1 (G.context(n,gr));

        val cj_node = #2 (hd (get_sucL cmp_node));
        val sucL = #4 (G.context(cj_node,gr));
        val (t_start_node, f_start_node) =  (#2 (valOf (List.find (fn (flag,n) => flag = 1) sucL)),
                                             #2 (valOf (List.find (fn (flag,n) => flag = 0) sucL)));
        (* the false part ends at the first predecessor of the branch node that
           precedes the true part *)
        val bal_node = #2 (valOf (List.find (fn (flag,n) => flag = 2) (get_preL t_start_node)));
        val f_end_node = #2 (hd (get_preL bal_node));

        (* the true part ends at the join node's 0-labelled predecessor *)
        val join_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_sucL bal_node)));
        val t_end_node =  #2 (valOf (List.find (fn (flag,n) => flag = 0) (get_preL join_node)));

        val cond = to_cond (#instr ((#3 (G.context(cmp_node,gr))):CFG.node),
                            #instr ((#3 (G.context(cj_node,gr))):CFG.node));

        val (t_cfg,f_cfg) = (sub_graph gr t_start_node (fn n => n = t_end_node),
                             sub_graph gr f_start_node (fn n => n = f_end_node))
    in
        (cond, ((t_cfg,t_start_node), (f_cfg,f_start_node)), join_node)
    end
219
220
221   fun convert_cond (exp1, rop, exp2) =
222       (to_exp exp1, to_rop rop, to_exp exp2);
223
224 (*---------------------------------------------------------------------------------*)
225 (*  Convert cfg consisting of SC, CJ and CALL structures to ir tree                *)
226 (*---------------------------------------------------------------------------------*)
227
   (* The theorem |- T, used as the placeholder fspec of annotations built here. *)
   val thm_t = DECIDE (Term `T`);
   (* Default annotation: ins/outs NA (presumably "not available"), empty
      context, placeholder fspec. *)
   val init_info = {fspec = thm_t, ins = NA, outs = NA, context = []};
230(*
231        val (cfg, n) = (r_cfg, r_start_node)
232        val n = (#2 (hd sucL))
233fun extract (Assem.OPER {oper = (Assem.BL,_,_), dst = dList, src = sList, jump = jp}) = (dList, sList, jp);
234val (dList, sList, jp) = extract inst'
235val x = inst'
236*)
237
   (* convert cfg n translates the part of cfg reachable from node n into an IR
      tree, following the first successor of each node:
        - NOP and LABEL instructions produce nothing (STM [] at a node with no
          successors);
        - BL becomes a CALL node named after the first jump target, with empty
          pre/body/post and an annotation recording the call's source (ins)
          and destination (outs) expressions;
        - CMP starts a conditional: break_cj splits off the true and false
          parts, which are converted separately into a CJ node, and conversion
          continues at their join node;
        - any other instruction becomes a one-instruction STM (via to_inst).
      Consecutive pieces are chained with SC nodes annotated with init_info.
      NOTE(review): preL, def' and sue' are bound but never used. *)
   fun convert cfg n =
        let
           val (preL,_,{instr = inst', def = def', use = sue'},sucL) = G.context(n,cfg);
        in
           case inst' of
               Assem.OPER {oper = (Assem.NOP, _, _), ...} =>
                     if not (null sucL) then convert cfg (#2 (hd sucL)) else STM []

           |   Assem.OPER {oper = (Assem.BL,_,_), dst = dList, src = sList, jump = jp} =>
                     let val stm = CALL(Symbol.name(hd (valOf jp)), STM [], STM [], STM [],
                                      {fspec = thm_t, ins = list2pair (List.map to_exp sList),
                                        outs = list2pair (List.map to_exp dList), context = []})
                     in  if null sucL then stm
                         else SC (stm, convert cfg (#2 (hd sucL)), init_info)
                     end
           |   Assem.LABEL {...} =>
                     if not (null sucL) then convert cfg (#2 (hd sucL)) else STM []

           (* the two branches are converted in their own sub-graphs; the rest of
              the function is converted from the join node in the full cfg *)
           |   Assem.OPER {oper = (Assem.CMP, _, _), ...} =>
                     let val (cond, ((t_cfg,t_start_node),(f_cfg,f_start_node)), join_node) = break_cj cfg n;
                              (* val rest_cfg = sub_graph gr join_node (fn k => k = end_node); *)
                     in
                         SC (CJ (convert_cond cond, convert t_cfg t_start_node, convert f_cfg f_start_node, init_info),
                             convert cfg join_node, init_info)
                     end

           |   x =>
                     if not (null sucL) then
                         SC (STM [to_inst x], convert cfg (#2 (hd sucL)), init_info)
                     else STM [to_inst x]
        end
269
270(*---------------------------------------------------------------------------------*)
(*  Convert a cfg corresponding to a function to ir                                *)
272(*---------------------------------------------------------------------------------*)
273
   (* convert_module (ins, cfg, outs) converts the cfg of one function into an IR
      tree; ins and outs are pair-encoded expressions for the function's inputs
      and outputs.  is_rec, started at node 0, decides whether the function
      calls itself:
        - not recursive: the cfg is converted from the label node returned by
          is_rec (node 0 when is_rec found none);
        - recursive: break_rec splits the cfg into an optional pre-condition
          part p, the base case b and the recursive case r, giving
            SC (TR (cond, SC (r, pass_args)), b)                     without p
            SC (p, SC (TR (cond, SC (SC (r, pass_args), p)), b))     with p
          where pass_args moves the arguments of the recursive BL call into the
          function's inputs.                                                    *)
   fun convert_module (ins,cfg,outs) =
      let
         val (flag, lab_node) = is_rec cfg 0
         (* NOTE(review): outL is never used *)
         val (inL,outL) = (pair2list ins, pair2list outs)
      in
         if not flag then convert cfg lab_node
         else
            let val (cond, ((p_cfg,p_start_node), (b_cfg,b_start_node), (r_cfg,r_start_node)), join_node) = break_rec cfg lab_node
                val (b_ir,r_ir) = (convert b_cfg b_start_node, convert r_cfg r_start_node);

                (* the 1-labelled predecessor of the label node, i.e. the recursive BL
                   call (the node break_rec calls BL_node); the binding below raises
                   Bind if it does not hold an OPER instruction *)
                val bal_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (#1 (G.context(lab_node,cfg)))));
                val (Assem.OPER {src = rec_argL,...}) = #instr ((#3(G.context(bal_node,cfg))):CFG.node);
                (* one mmov per input: input_i := i-th argument of the recursive call;
                   zip (HOL Lib) raises if the two lists differ in length *)
                val rec_args_pass_ir = STM (List.map (fn (dexp,sexp) => {oper = mmov, src = [sexp], dst = [dexp]})
                             (zip inL (pair2list (to_exp (hd rec_argL)))))    (* dst_exp  = BL (src_exp) *)

                (* NOTE(review): this annotation and the TR annotation of the NONE case
                   use ins as outs, whereas the SOME case annotates the TR with outs *)
                val info = {fspec = DECIDE(Term`T`), ins = ins, outs = ins, context = []}
                val r_ir_1 = SC(r_ir,rec_args_pass_ir, info)

                val cond' = convert_cond cond
                val tr_ir = case p_cfg of
                                 NONE =>
                                    SC (TR (cond', r_ir_1, {fspec = thm_t, ins = ins, outs = ins, context = []}),
                                        b_ir, info)
                               |  SOME p_gr =>
                                    (* the pre-condition code runs before the loop and again
                                       at the end of every iteration *)
                                    let val p_ir = convert p_gr (valOf p_start_node);
                                        val ir0 = SC (r_ir_1, p_ir, init_info)
                                        val ir1 = TR (cond', ir0, {fspec = thm_t, ins = ins, outs = outs, context = []})
                                        val ir2 = SC (ir1, b_ir, {fspec = thm_t, ins = ins, outs = ins, context = []})
                                    in
                                         SC (p_ir, ir2, {fspec = thm_t, ins = ins, outs = outs, context = []})
                                    end
             in
                tr_ir
             end
       end;
309
310 (*---------------------------------------------------------------------------------*)
311 (*      Simplify the IR tree                                                       *)
312 (*---------------------------------------------------------------------------------*)
313
314   fun is_dummy_inst ({oper = mmov, src = sList, dst = dList}) =
315         hd sList = hd dList
316    |  is_dummy_inst (stm as {oper = op', src = sList, dst = dList}) =
317         if length sList < 2 then false
318         else if hd dList = hd sList then
319            (hd (tl sList) = WCONST Arbint.zero andalso (op' = madd orelse op' = msub orelse op' = mlsl orelse
320                     op' = mlsr orelse op' = masr orelse op' = mror))
321         else if hd dList = hd (tl (sList)) then
322            hd sList = WCONST Arbint.zero andalso (op' = madd orelse op' = mrsb)
323         else
324            false;
325
326   fun rm_dummy_inst (SC(ir1,ir2,info)) =
327         SC (rm_dummy_inst ir1, rm_dummy_inst ir2, info)
328    |  rm_dummy_inst (TR(cond,ir2,info)) =
329         TR(cond, rm_dummy_inst ir2, info)
330    |  rm_dummy_inst (CJ(cond,ir1,ir2,info)) =
331         CJ(cond, rm_dummy_inst ir1, rm_dummy_inst ir2, info)
332    |  rm_dummy_inst (STM stmL) =
333         STM (List.filter (not o is_dummy_inst) stmL)
334    | rm_dummy_inst (CALL (fname,pre,body,post,info)) =
335         CALL (fname, rm_dummy_inst pre, rm_dummy_inst body, rm_dummy_inst post,info)
336    | rm_dummy_inst (BLK (instL, info)) =
337         BLK (List.filter (not o is_dummy_inst) instL,info)
338
339
340   fun merge_ir ir =
341     let
342
343       (* Determine whether two irs are equal or not *)
344
345       fun is_ir_equal (SC(s1,s2,_)) (SC(s3,s4,_)) =
346             is_ir_equal s1 s3 andalso is_ir_equal s2 s4
347        |  is_ir_equal (TR(cond1,body1,_)) (TR(cond2,body2,_)) =
348             cond1 = cond2 andalso is_ir_equal body1 body2
349        |  is_ir_equal (CJ(cond1,s1,s2,_)) (CJ(cond2,s3,s4,_)) =
350             cond1 = cond2 andalso is_ir_equal s1 s3 andalso is_ir_equal s2 s4
351        |  is_ir_equal (BLK(s1,_)) (BLK(s2,_)) =
352             List.all (op =) (zip s1 s2)
353        |  is_ir_equal (STM s1) (STM s2) =
354             List.all (op =) (zip s1 s2)
355        |  is_ir_equal (CALL(name1,_,_,_,_)) (CALL(name2,_,_,_,_)) =
356             name1 = name2
357        |  is_ir_equal _ _ = false;
358
359       fun merge_stm (SC(STM s1, STM s2, info)) =
360             BLK (s1 @ s2, info)
361        |  merge_stm (SC(STM s1, BLK(s2,_), info)) =
362             BLK (s1 @ s2, info)
363        |  merge_stm (SC(BLK(s1,_), STM s2, info)) =
364             BLK (s1 @ s2, info)
365        |  merge_stm (SC(BLK(s1,_), BLK(s2,_), info)) =
366             BLK (s1 @ s2, info)
367        |  merge_stm (SC(STM [], s2, info)) =
368             merge_stm s2
369        |  merge_stm (SC(s1, STM [], info)) =
370             merge_stm s1
371        |  merge_stm (SC(BLK([],_), s2, info)) =
372             merge_stm s2
373        |  merge_stm (SC(s1, BLK([],_), info)) =
374             merge_stm s1
375        |  merge_stm (SC(s1,s2,info)) =
376             SC (merge_stm s1, merge_stm s2, info)
377        |  merge_stm (TR(cond, body, info)) =
378             TR(cond, merge_stm body, info)
379        |  merge_stm (CJ(cond, s1, s2, info)) =
380             CJ(cond, merge_stm s1, merge_stm s2, info)
381        |  merge_stm (STM s) =
382             BLK(s, init_info)
383        |  merge_stm stm =
384             stm
385
386        fun merge ir = let val ir' = merge_stm ir;
387                           val ir'' = merge_stm ir' in
388                         if is_ir_equal ir' ir'' then ir'' else merge ir''
389                       end
390     in
391        merge ir
392     end
393
394
395 (*---------------------------------------------------------------------------------*)
396 (*      Assistant functions                                                        *)
397 (*---------------------------------------------------------------------------------*)
398
399   fun get_annt (BLK (stmL,info)) = info
400    |  get_annt (SC(s1,s2,info)) = info
401    |  get_annt (CJ(cond,s1,s2,info)) = info
402    |  get_annt (TR(cond,s,info)) = info
403    |  get_annt (CALL(fname,pre,body,post,info)) = info
404    |  get_annt (STM stmL) = {ins = NA, outs = NA, fspec = thm_t, context = []};
405
406   fun replace_ins {ins = ins', outs = outs', context = context', fspec = fspec'} ins'' =
407           {ins = ins'', outs = outs', context = context', fspec = fspec'};
408
409   fun replace_outs {ins = ins', outs = outs', context = context', fspec = fspec'} outs'' =
410           {ins = ins', outs = outs'', context = context', fspec = fspec'};
411
412   fun replace_context {ins = ins', outs = outs', context = context', fspec = fspec'} context'' =
413           {ins = ins', outs = outs', context = context'', fspec = fspec'};
414
415   fun replace_fspec {ins = ins', outs = outs', context = context', fspec = fspec'} fspec'' =
416           {ins = ins', outs = outs', context = context', fspec = fspec''};
417
418
419   fun apply_to_info (BLK (stmL,info)) f =  BLK (stmL, f info)
420    |  apply_to_info (SC(s1,s2,info)) f = SC(s1, s2, f info)
421    |  apply_to_info (CJ(cond,s1,s2,info)) f = CJ(cond, s1, s2, f info)
422    |  apply_to_info (TR(cond,s,info)) f = TR(cond, s, f info)
423    |  apply_to_info (CALL(fname,pre,body,post,info)) f = CALL(fname, pre, body, post, f info)
424    |  apply_to_info (STM stmL) f = STM stmL;
425
426
427 (*---------------------------------------------------------------------------------*)
428 (*      Calculate Modified Registers                                               *)
429 (*---------------------------------------------------------------------------------*)
430
431   fun one_stm_modified_regs ({oper = op1, dst = dlist, src = slist}) =
432       List.map (fn (REG r) => r | _ => ~1) (List.filter (fn (REG r) => true | _ => false) dlist);
433
434
435   fun get_modified_regs (SC(ir1,ir2,info)) =
436         (get_modified_regs ir1) @ (get_modified_regs ir2)
437    |  get_modified_regs (TR(cond,ir,info)) =
438         get_modified_regs ir
439    |  get_modified_regs (CJ(cond,ir1,ir2,info)) =
440         (get_modified_regs ir1) @ (get_modified_regs ir2)
441    |  get_modified_regs (CALL(fname,pre,body,post,info)) =
442                        let
443                                val preL = get_modified_regs pre
444                                val bodyL = get_modified_regs body
445                                val outsL = (List.map (fn (REG r) => r | _ => ~1) (pair2list (#outs info)))
446                                val restoredL = [13, ~1]
447                                val modL = filter (fn e => not (mem e restoredL)) (preL @ bodyL @ outsL)
448                        in
449                                modL
450                        end
451    |  get_modified_regs (STM l) =
452         itlist (curry (fn (a,b) => one_stm_modified_regs a @ b)) l []
453    |  get_modified_regs (BLK (l,info)) =
454         itlist (curry (fn (a,b) => one_stm_modified_regs a @ b)) l [];
455
456 (*---------------------------------------------------------------------------------*)
457 (*      Set input, output and context information                                  *)
458 (*      Alignment functions                                                        *)
459 (*---------------------------------------------------------------------------------*)
460
   (* Adjust the inputs to be consistent with the outputs of the previous ir.

      adjust_ins ir outer_info sets the inputs annotated on ir to #ins outer_info.
      Nothing changes when they are already equal; otherwise the root
      annotation's ins is replaced and the change is pushed down:
        SC        : into the first part only;
        CJ        : into both branches;
        TR        : the body is adjusted to the TR's own previous annotation --
                    NOTE(review): unlike SC/CJ this does not use outer_info;
                    confirm that is intended;
        BLK, CALL : only the root annotation changes.
      A bare STM (default ins NA) raises Fail unless the new inputs are NA. *)

   fun adjust_ins ir (outer_info as ({ins = outer_ins, outs = outer_outs, context = outer_context, ...}:annt)) =
        (* NOTE(review): apart from inner_ins, the names bound here (including
           inner_inS and outer_inS) are never used *)
        let val inner_info as {ins = inner_ins, context = inner_context, outs = inner_outs, fspec = inner_spec, ...} = get_annt ir;
            val (inner_inS,outer_inS) = ((list2set o pair2list) inner_ins, (list2set o pair2list) outer_ins);
        in  if outer_ins = inner_ins then
               ir
            else
               case ir of
                   (BLK (stmL, info)) =>
                       BLK (stmL, replace_ins info outer_ins)
                |  (SC (s1,s2,info)) =>
                       SC (adjust_ins s1 outer_info, s2, replace_ins info outer_ins)
                |  (CJ (cond,s1,s2,info)) =>
                       CJ (cond, adjust_ins s1 outer_info, adjust_ins s2 outer_info, replace_ins info outer_ins)
                |  (TR (cond,s,info)) =>
                       TR(cond, adjust_ins s info, replace_ins info outer_ins)
                |  (CALL (fname,pre,body,post,info)) =>
                       CALL(fname, pre, body, post, replace_ins info outer_ins)
                |  _ =>
                       raise Fail "adjust_ins: invalid IR tree"
        end
483
484
485   fun args_diff (ins1, outs0) =
486        set2pair (S.difference (pair2set ins1, pair2set outs0));
487
488   (* Given the outputs and the context of an ir, calculate its inputs *)
489
490
491(*
492val irX = (back_trace ir1 info)
493
494fun extract (SC(s1,s2,inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}:annt) =
495        (s1, s2, inner_info, outer_outs, contextL, outer_info)
496
497val (s1, s2, inner_info, outer_outs, contextL, outer_info)  = extract ir1 info
498
499              val s1' = back_trace
500
501val s2' = back_trace s2 outer_info
502
503fun extract (CJ(cond, s1, s2, inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}:annt) =
504(cond, s1, s2, inner_info, outer_info, outer_outs, contextL)
505
506val (cond, s1, s2, inner_info, outer_info, outer_outs, contextL) = extract ir1 info
507
508fun extract (CALL (fname, pre, body, post, info)) (outer_info as {outs = outer_outs, context = contextL, fspec = fout_spec, ...}:annt) =
509(fname, pre, body, post, info, outer_info, outer_outs, contextL, fout_spec);
510
511val (fname, pre, body, post, info, outer_info, outer_outs, contextL, fout_spec) = extract s1 {ins = #ins outer_info, outs = #ins s2_info, context = #context s2_info, fspec = #fspec s1_info};
512
513*)
514
   (* back_trace ir outer_info computes input annotations backwards from the
      outputs (#outs outer_info) required of ir:
        BLK/STM : ins = the block's own inputs (from getIO) plus every required
                  output the block does not itself write; outs and context
                  are taken from outer_info.
        SC      : the second part is traced first; if its inputs are a subset
                  of the first part's (current) outputs, they are widened to
                  those outputs; the first part is then traced so that it
                  produces the second part's inputs.
        CJ      : both branches are traced against outer_info; the node's ins
                  are the REG/MEM expressions among the condition operands and
                  the branches' inputs, and both branches are adjusted to them.
        TR      : the body is not traced; the node's context becomes the
                  components of (outer outs, outer context) not among its own
                  outs, and the body's inputs are aligned with the TR's own ins.
        CALL    : ins = the call's inputs plus the required outputs the call
                  does not produce; outs/context/fspec come from outer_info.
                  The call's body is replaced by an empty BLK carrying the
                  call's previous annotation. *)
   fun back_trace (BLK (stmL,inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}:annt) =
          let
              (* NOTE(review): inner_tempL is not used *)
              val (inner_inL, inner_tempL, inner_outL) =  getIO stmL
              (* required outputs that the block does not write must be passed through *)
              val gapL = set2list (S.difference (pair2set outer_outs, list2set inner_outL));
              val read_inS = list2set (gapL @ inner_inL);
             (* val contextS = S.difference (list2set contextL, real_outS) *)
          in
              BLK (stmL, {outs = outer_outs, ins = set2pair read_inS, context = contextL, fspec = #fspec inner_info})
          end

    |  back_trace (STM stmL) info =
          back_trace (BLK (stmL,init_info)) info

    |  back_trace (SC(s1,s2,inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
           let
              val s2' = back_trace s2 outer_info
              val (s1_info, s2_info) = (get_annt s1, get_annt s2');
              val s2'' = if S.isSubset (pair2set (#ins s2_info), pair2set (#outs s1_info))
                         then adjust_ins s2' (replace_ins s2_info (#outs s1_info))
                         else s2';
              val s2_info = get_annt s2'';
              (* s1 must produce exactly what s2'' reads *)
              val s1' = back_trace s1 {ins = #ins outer_info, outs = #ins s2_info, context = #context s2_info, fspec = #fspec s1_info};
              val s1_info = get_annt s1'
           in
              SC(s1',s2'', {ins = #ins s1_info, outs = #outs s2_info, fspec = thm_t, context = contextL})
           end

    |  back_trace (CJ(cond, s1, s2, inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
          let
              (* keep only register and memory expressions *)
              fun filter_exp (REG e) = true
               |  filter_exp (MEM e) = true
               |  filter_exp _ = false

              val cond_expL = [(#1 cond), (#3 cond)];
              val s1' = back_trace s1 outer_info
              val s2' = back_trace s2 outer_info
              val ({ins = ins1, outs = outs1, ...}, {ins = ins2, outs = outs2, ...}) = (get_annt s1', get_annt s2');
              val inS_0 = set2pair (list2set (filter filter_exp
                                        (cond_expL @ (pair2list ins1) @ (pair2list ins2))));
                      (* union of the variables in the condition and the inputs of the s1' and s2' *)
              val info_0 = replace_ins outer_info inS_0
          in
              CJ(cond, adjust_ins s1' info_0, adjust_ins s2' info_0, info_0)
          end

    |  back_trace (TR (cond, body, info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
           let
              val extra_outs = args_diff (PAIR (outer_outs, list2pair contextL), #outs info)
              val info' = replace_context info (pair2list extra_outs);
              val body' = adjust_ins body info'
          in
              TR(cond, body', info')
          end

          (* the adjustment will be performed later in the funCall.sml *)
    |  back_trace (CALL (fname, pre, body, post, info)) (outer_info as {outs = outer_outs, context = contextL, fspec = fout_spec, ...}) =
          let
                                  val outer_outs_set = pair2set outer_outs
                                  val inner_outs_set = pair2set (#outs info);
                                  val extra_outs_set = S.difference (outer_outs_set, inner_outs_set);

                                  val inner_ins_set = pair2set (#ins info)
                                  val new_ins_set = S.union(extra_outs_set, inner_ins_set);

                                  val new_ins = set2pair new_ins_set;
                                  val info' = replace_ins outer_info new_ins
          in
              CALL (fname, pre, BLK ([], info), post, info')
          end;
584
585   fun set_info (ins,ir,outs) =
586       let
587          fun extract_outputs ir0 =
588              let val info0 = get_annt ir0;
589                  val (inner_outS, outer_outS) = (IRSyntax.pair2set (#outs info0), IRSyntax.pair2set outs)
590                  val ir1 = if Binaryset.equal(inner_outS, outer_outS) then ir0
591                            else SC (ir0, BLK ([], {context = #context info0, ins = #outs info0, outs = outs, fspec = thm_t}),
592                                     replace_outs info0 outs)
593                  val info1 = get_annt ir1
594              in
595                  if #ins info0 = ins then ir1
596                  else SC (BLK ([], {context = #context info1, ins = ins, outs = #ins info0, fspec = thm_t}),
597                           ir1, replace_ins info1 ins)
598              end
599          val info =  {ins = ins, outs = outs, context = [], fspec = thm_t}
600       in
601          extract_outputs (back_trace ir info)
602       end;
603(* ---------------------------------------------------------------------------------------------------------------------*)
604(* Adjust the IR tree to make inputs and outputs consistent                                                             *)
605(* ---------------------------------------------------------------------------------------------------------------------*)
606
  (* match_ins_outs ir inserts empty "glue" BLKs where adjacent annotations
     disagree, so that a part's outputs are exactly the next part's inputs:
       SC   : after processing both parts, if the first part's outs differ from
              the second part's ins, an empty BLK mapping one to the other is
              placed before the second part;
       CALL : an empty BLK is added after the call when its post part's outs
              differ from the call's outs, and before it when its pre part's
              ins differ from the call's ins; the CALL node itself is kept.
     Any other node (including CJ and TR, whose sub-trees are not visited) is
     returned unchanged. *)
  fun match_ins_outs (SC(s1,s2,info)) =
      let
          val (ir1,ir2) = (match_ins_outs s1, match_ins_outs s2);
          val (info1,info2) = (get_annt ir1, get_annt ir2);
          val (outs1, ins2) = (#outs info1, #ins info2);
      in
          if outs1 = ins2 then
              SC(ir1,ir2,info)
          else
              SC(ir1,
                 SC (BLK ([], {ins = outs1, outs = ins2, context = #context info2, fspec = thm_t}), ir2,
                          replace_ins info2 outs1),
                 info)
      end

   |  match_ins_outs (ir as (CALL (fname, pre, body, post, info))) =
      let
          val ((pre_ins,post_outs),(outer_ins,outer_outs)) = ((#ins (get_annt pre), #outs (get_annt post)), (#ins info, #outs info))
          (* the wrapped call is annotated as starting from pre_ins, which the
             glue block added below (if any) provides *)
          val ir' = if post_outs = outer_outs then ir
                    else
                        SC (ir,
                            BLK ([], {ins = post_outs, outs = outer_outs, context = #context info, fspec = thm_t}),
                            replace_ins info pre_ins)
      in
          if pre_ins = outer_ins then ir'
          else
              SC (BLK([], {ins = outer_ins, outs = pre_ins, context = #context info, fspec = thm_t}),
                  ir',
                  info)
      end

   |  match_ins_outs ir =
      ir;
640
641
642 (*---------------------------------------------------------------------------------*)
643 (*      Interface                                                                  *)
644 (*---------------------------------------------------------------------------------*)
645
646   fun build_ir (f_name, f_type, f_args, f_gr, f_outs, f_rs) =
647      let
648         val (ins, outs) = (IRSyntax.to_exp f_args, IRSyntax.to_exp f_outs)
649         val ir0 = convert_module (ins, f_gr, outs)
650         val ir1 = (merge_ir o rm_dummy_inst) ir0
651         val ir2 = set_info (ins,ir1,outs)
652      in
653         (ins,ir2,outs)
654      end
655
656   fun sfl2ir prog =
657     let
658        val env = ANF.toANF [] prog;
659        val defs = List.map (fn (name, (flag,src,anf,cps)) => src) env;
660        val (f_const,(_, f_defs, f_anf, f_ast)) = hd env
661        val setting as (f_name, f_type, f_args, f_gr, f_outs, f_rs) = regAllocation.convert_to_ARM (f_anf);
662        val f_ir = build_ir setting
663     in
664        (f_name, f_type, f_ir, defs)
665     end
666end
667end
668