1structure annotatedIR = struct
2  local
3  open HolKernel Parse boolLib preARMTheory pairLib simpLib bossLib
4       numSyntax optionSyntax listSyntax CFLTheory IRSyntax
5  in
6
7 (*---------------------------------------------------------------------------------*)
8 (*      Annotated IR tree                                                          *)
9 (*---------------------------------------------------------------------------------*)
10
   (* Annotation attached to each node of the IR tree.
        fspec   : a specification theorem; every annotation built in this
                  file uses the trivial theorem thm_t as a placeholder.
        ins     : the inputs of the node, as a pair expression (built with
                  list2pair / split with pair2list).
        outs    : the outputs of the node, as a pair expression.
        context : extra expressions the node must preserve; back_trace sets
                  it for TR nodes from the enclosing outputs/context. *)
   type annt = {fspec : thm, ins : exp, outs : exp, context : exp list};

   (* Condition of CJ and TR nodes: (left operand, relational operator,
      right operand), as produced by convert_cond. *)
   type cond = exp * rop * exp;

   (* Annotated IR tree.
        TR (cond, body, a)            loop node built by convert_module for a
                                      recursive function; body is the
                                      recursive case (presumably repeated
                                      while cond fails -- TODO confirm).
        SC (s1, s2, a)                sequential composition: s1 then s2.
        CJ (cond, s1, s2, a)          conditional; s1 comes from the
                                      flag-1 (true) successor of the jump,
                                      s2 from the flag-0 one (see break_cj).
        BLK (instrs, a)               annotated straight-line block.
        STM instrs                    unannotated straight-line block;
                                      merge_ir turns these into BLKs.
        CALL (name, pre, body, post, a)
                                      call of function [name]; convert
                                      creates pre/body/post as empty STMs. *)
   datatype anntIR =  TR of cond * anntIR * annt
                   |  SC of anntIR * anntIR * annt
                   |  CJ of cond * anntIR * anntIR * annt
                   |  BLK of instr list * annt
                   |  STM of instr list
                   |  CALL of string * anntIR * anntIR * anntIR * annt
23
24 (*---------------------------------------------------------------------------------*)
25 (*      Print an annotated IR tree                                                 *)
26 (*---------------------------------------------------------------------------------*)
27
  (* When true (the default), print_ir expands every CALL node into
     SC (pre, SC (body, post)); when false it prints only "CALL(name)". *)
  val show_call_detail = ref true;
29
30  fun print_ir (outstream, s0) =
31    let
32
33      fun say s = TextIO.output(outstream,s)
34      fun sayln s= (say s; say "\n")
35
36      fun indent 0 = ()
37        | indent i = (say " "; indent(i-1))
38
39      fun printtree(SC(s1,s2,info),d) =
40            (indent d; sayln "SC("; printtree(s1,d+2); sayln ","; printtree(s2,d+2); say ")")
41        | printtree(CJ((exp1,rop,exp2),s1,s2,info),d) =
42            (indent d; say "CJ("; say (format_exp exp1 ^ " " ^ print_rop rop ^ " " ^ format_exp exp2); sayln ",";
43                       printtree(s1,d+2); sayln ","; printtree(s2,d+2); say ")")
44        | printtree(TR((exp1,rop,exp2),s,info),d) =
45            (indent d; say "TR("; say (format_exp exp1 ^ " " ^ print_rop rop ^ format_exp exp2); sayln ",";
46                       printtree(s,d+2); say ")")
47        | printtree(CALL(fname,pre,body,post,info),d) =
48            if not (!show_call_detail) then
49                (indent d; say ("CALL(" ^ fname ^ ")"))
50            else
51                let val ir' = SC (pre, SC (body, post, info), info)
52                    (* val ir'' = (merge_stm o rm_dummy_inst) ir' *)
53                in
54                    printtree (ir', d)
55                end
56
57        | printtree(STM stmL,d) =
58            (indent d; if null stmL then say ("[]")
59                       else say ("[" ^ formatInst (hd stmL) ^
60                                 (itlist (curry (fn (stm,str) => "; " ^ formatInst stm ^ str)) (tl stmL) "") ^ "]"))
61        | printtree(BLK(stmL, info),d) =
62            printtree(STM stmL,d)
63
64    in
65      printtree(s0,0); sayln ""; TextIO.flushOut outstream
66    end
67
68  fun printIR ir = print_ir (TextIO.stdOut, ir)
69
70  fun printIR2 (f_name, f_type, (ins,ir,outs), defs) =
71      ( print ("Module: " ^ f_name ^ "\n");
72        print ("Inputs: " ^ format_exp ins ^ "\n" ^ "Outputs: " ^ format_exp outs ^ "\n");
73        printIR ir
74      )
75
76 (*---------------------------------------------------------------------------------*)
77 (*      Assistant Graph Functions                                                  *)
78 (*---------------------------------------------------------------------------------*)
79
   (* Short alias for the graph library used to represent CFGs below
      (provides context / match / embed / noNodes / empty). *)
   structure G = Graph;
81
82   fun mk_edge (n0,n1,lab) gr =
83    if n0 = n1 then gr
84    else if (List.find (fn m => #2 m = n0) (#1(G.context(n1,gr)))) <> NONE then gr              (* raise Edge *)
85    else
86       let val ((p1,m1,l1,s1), del_n1) = G.match(n1,gr)
87       in G.embed(((lab,n0)::p1,m1,l1,s1),del_n1)
88    end;
89
90   (* retreive the subgraph starting from the start node and ends when "f node" holds *)
91   fun sub_graph gr start_node f =
92       let
93           fun one_node (nodeNo,gr') =
94               if f nodeNo then
95                   gr'
96               else
97                   List.foldl (fn ((a,b),gr'') =>
98                                   one_node (b, G.embed (([(a,nodeNo)],b,#3 (G.context(b,gr)),[]), gr''))
99                              handle e => mk_edge (nodeNo,b,0) gr'')
100                            gr' (#4 (G.context(nodeNo,gr)))
101       in
102           one_node (start_node, G.embed(([],start_node, #3 (G.context(start_node,gr)),[]),G.empty))
103       end;
104
105    (* locate a node, whenever the final node is reached, search terminats *)
106    fun locate gr start_node f =
107       let
108           val maxNodeNo = G.noNodes gr - 1;
109           fun one_node nodeNo =
110               if f nodeNo then
111                   (SOME nodeNo,true)
112               else if nodeNo = maxNodeNo then
113                   (NONE,false)
114               else
115                   List.foldl (fn ((a,b),(n,found)) => if found then (n,found) else one_node b)
116                            (SOME start_node,false) (#4 (G.context(nodeNo,gr)))
117       in
118           #1 (one_node start_node)
119       end;
120
121 (*---------------------------------------------------------------------------------*)
122 (*      Conversion from CFG to IR tree                                             *)
123 (*---------------------------------------------------------------------------------*)
124
125   fun get_label (Assem.LABEL {lab = lab'}) = lab'
126    |  get_label _ = raise ERR "IL" ("Expecting a label instruction")
127
128   (* find the node including the function label, if there are more than one incoming edges, then a rec is found *)
129
130   fun is_rec gr n =
131       let
132           val context = G.context(n,gr)
133           fun is_label (Assem.LABEL _) = true
134            |  is_label _ = false
135       in
136           if is_label (#instr ((#3 context):CFG.node)) then
137                (length (#1 context) > 1, n)
138           else is_rec gr (#2 (hd (#4 context)))
139       end
140       handle e => (false,0);        (* no label node in the graph *)
141
142
143   (* Given a TR cfg and the node including the function name label, break this cfg into three parts: the pre-condition part,
144      the basic case part and the recursive case part. The condition is also derived.                                          *)
145
146   fun break_rec gr lab_node =
147     let
148        fun get_sucL n = #4 (G.context(n,gr));
149        fun get_preL n = #1 (G.context(n,gr));
150
151        val last_node = valOf (locate gr 0 (fn n => null (get_sucL n)));
152
153        fun find_join_node n =
154            if length (get_preL n) > 1 then n
155            else find_join_node (#2 (hd (get_preL last_node)));
156        val join_node = find_join_node last_node;  (* the node that the basic and recursive parts join *)
157
158        (* the nodes jumping to the join node *)
159        val BAL_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_preL join_node)));
160        val b_end_node = #2 (valOf (List.find (fn (flag,n) => flag = 0) (get_preL join_node)));
161        val b_start_node = #2 (valOf (List.find (fn (flag,n) => flag = 2) (get_sucL BAL_node)));
162        val b_cfg = sub_graph gr b_start_node (fn n => n = b_end_node);
163
164        val cj_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_preL b_start_node)));
165        val cmp_node = #2 (hd (get_preL cj_node));
166        val cond = to_cond (#instr ((#3 (G.context(cmp_node,gr))):CFG.node),
167                            #instr ((#3 (G.context(cj_node,gr))):CFG.node));
168
169        val r_start_node = #2 (valOf (List.find (fn (flag,n) => flag = 0) (get_sucL cj_node)));
170        val BL_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_preL lab_node)));
171        val r_end_node = #2 (hd (get_preL BL_node));
172        val r_cfg = sub_graph gr r_start_node (fn n => n = r_end_node);
173
174        val p_start_node = if #2 (hd (get_sucL lab_node)) = cmp_node then NONE
175                           else SOME (#2 (hd (get_sucL lab_node)));
176        val p_cfg = case p_start_node of
177                          NONE => NONE  (* the pre-condition part may be empty *)
178                      |   SOME k => SOME (sub_graph gr k (fn n => n = #2 (hd (get_preL cj_node))))
179    in
180        (cond, ((p_cfg,p_start_node), (b_cfg,b_start_node), (r_cfg,r_start_node)), join_node)
181    end
182
183   (* Given a CJ cfg and the node including the cmp instruction, break this cfg into the true case part and
184      the false case part.                                                        *)
185
186   fun break_cj gr cmp_node =
187     let
188        fun get_sucL n = #4 (G.context(n,gr));
189        fun get_preL n = #1 (G.context(n,gr));
190
191        val cj_node = #2 (hd (get_sucL cmp_node));
192        val sucL = #4 (G.context(cj_node,gr));
193        val (t_start_node, f_start_node) =  (#2 (valOf (List.find (fn (flag,n) => flag = 1) sucL)),
194                                             #2 (valOf (List.find (fn (flag,n) => flag = 0) sucL)));
195        val bal_node = #2 (valOf (List.find (fn (flag,n) => flag = 2) (get_preL t_start_node)));
196        val f_end_node = #2 (hd (get_preL bal_node));
197
198        val join_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (get_sucL bal_node)));
199        val t_end_node =  #2 (valOf (List.find (fn (flag,n) => flag = 0) (get_preL join_node)));
200
201        val cond = to_cond (#instr ((#3 (G.context(cmp_node,gr))):CFG.node),
202                            #instr ((#3 (G.context(cj_node,gr))):CFG.node));
203
204        val (t_cfg,f_cfg) = (sub_graph gr t_start_node (fn n => n = t_end_node),
205                             sub_graph gr f_start_node (fn n => n = f_end_node))
206    in
207        (cond, ((t_cfg,t_start_node), (f_cfg,f_start_node)), join_node)
208    end
209
210
211   fun convert_cond (exp1, rop, exp2) =
212       (to_exp exp1, to_rop rop, to_exp exp2);
213
214 (*---------------------------------------------------------------------------------*)
215 (*  Convert cfg consisting of SC, CJ and CALL structures to ir tree                *)
216 (*---------------------------------------------------------------------------------*)
217
   (* The theorem for the term T, proved by DECIDE; used as the placeholder
      fspec of every annotation created in this file. *)
   val thm_t = DECIDE (Term `T`);
   (* Default annotation: placeholder fspec, empty context, and NA for ins
      and outs (NA presumably means "not yet computed" -- confirm in
      IRSyntax); the real values are filled in later by set_info. *)
   val init_info = {fspec = thm_t, ins = NA, outs = NA, context = []};
220
221   fun convert cfg n =
222        let
223           val (preL,_,{instr = inst', def = def', use = sue'},sucL) = G.context(n,cfg);
224        in
225           case inst' of
226               Assem.OPER {oper = (Assem.NOP, _, _), ...} =>
227                     if not (null sucL) then convert cfg (#2 (hd sucL)) else STM []
228
229           |   Assem.OPER {oper = (Assem.BL,_,_), dst = dList, src = sList, jump = jp} =>
230                     let val stm = CALL(Symbol.name(hd (valOf jp)), STM [], STM [], STM [],
231                                       {fspec = thm_t, ins = list2pair (List.map to_exp sList),
232                                        outs = list2pair (List.map to_exp dList), context = []})
233                     in  if null sucL then stm
234                         else SC (stm, convert cfg (#2 (hd sucL)), init_info)
235                     end
236           |   Assem.LABEL {...} =>
237                     if not (null sucL) then convert cfg (#2 (hd sucL)) else STM []
238
239           |   Assem.OPER {oper = (Assem.CMP, _, _), ...} =>
240                     let val (cond, ((t_cfg,t_start_node),(f_cfg,f_start_node)), join_node) = break_cj cfg n;
241                              (* val rest_cfg = sub_graph gr join_node (fn k => k = end_node); *)
242                     in
243                         SC (CJ (convert_cond cond, convert t_cfg t_start_node, convert f_cfg f_start_node, init_info),
244                             convert cfg join_node, init_info)
245                     end
246
247           |   x =>
248                     if not (null sucL) then
249                         SC (STM [to_inst x], convert cfg (#2 (hd sucL)), init_info)
250                     else STM [to_inst x]
251        end
252
253(*---------------------------------------------------------------------------------*)
(*  Convert a cfg corresponding to a function to an IR tree                      *)
255(*---------------------------------------------------------------------------------*)
256
   (* Convert the CFG of one function, with inputs [ins] and outputs [outs]
      (pair expressions), into an IR tree.
      When is_rec reports no recursion, the CFG is converted directly from
      the node it returns.  Otherwise break_rec splits the CFG, and the tree
      is built as
        no pre-condition:  SC (TR (cond, SC (r_ir, args), _), b_ir, _)
        pre-condition p:   SC (p_ir, SC (TR (cond, SC (SC (r_ir, args), p_ir), _), b_ir, _), _)
      Here [args] is a block of mmov instructions that copy the actual
      arguments of the recursive BL into the formal inputs.
      NOTE(review): the annotations are not symmetric between the two cases.
        - In the NONE case the outer SC uses [info] (outs = ins).
        - In the SOME case the TR gets outs = outs and ir2 gets outs = ins.
      Confirm which is intended; back_trace keeps a TR's own outs. *)
   fun convert_module (ins,cfg,outs) =
      let
         val (flag, lab_node) = is_rec cfg 0
         val (inL,outL) = (pair2list ins, pair2list outs)
      in
         if not flag then convert cfg lab_node
         else
            let val (cond, ((p_cfg,p_start_node), (b_cfg,b_start_node), (r_cfg,r_start_node)), join_node) = break_rec cfg lab_node
                val (b_ir,r_ir) = (convert b_cfg b_start_node, convert r_cfg r_start_node);

                (* the recursive BL node: the flag-1 predecessor of the label *)
                val bal_node = #2 (valOf (List.find (fn (flag,n) => flag = 1) (#1 (G.context(lab_node,cfg)))));
                (* NOTE(review): non-exhaustive binding; raises Bind if that
                   node is not an Assem.OPER *)
                val (Assem.OPER {src = rec_argL,...}) = #instr ((#3(G.context(bal_node,cfg))):CFG.node);
                (* formal input i := i-th actual argument of the recursive call *)
                val rec_args_pass_ir = STM (List.map (fn (dexp,sexp) => {oper = mmov, src = [sexp], dst = [dexp]})
                             (zip inL (pair2list (to_exp (hd rec_argL)))))    (* dst_exp  = BL (src_exp) *)

                val info = {fspec = DECIDE(Term`T`), ins = ins, outs = ins, context = []}
                val r_ir_1 = SC(r_ir,rec_args_pass_ir, info)

                val cond' = convert_cond cond
                val tr_ir = case p_cfg of
                                 NONE =>
                                    SC (TR (cond', r_ir_1, {fspec = thm_t, ins = ins, outs = ins, context = []}),
                                        b_ir, info)
                               |  SOME p_gr =>
                                    (* the pre-condition code runs before the
                                       loop and again at the end of each
                                       iteration *)
                                    let val p_ir = convert p_gr (valOf p_start_node);
                                        val ir0 = SC (r_ir_1, p_ir, init_info)
                                        val ir1 = TR (cond', ir0, {fspec = thm_t, ins = ins, outs = outs, context = []})
                                        val ir2 = SC (ir1, b_ir, {fspec = thm_t, ins = ins, outs = ins, context = []})
                                    in
                                         SC (p_ir, ir2, {fspec = thm_t, ins = ins, outs = outs, context = []})
                                    end
             in
                tr_ir
             end
       end;
292
293 (*---------------------------------------------------------------------------------*)
294 (*      Simplify the IR tree                                                       *)
295 (*---------------------------------------------------------------------------------*)
296
297   fun is_dummy_inst ({oper = mmov, src = sList, dst = dList}) =
298         hd sList = hd dList
299    |  is_dummy_inst (stm as {oper = op', src = sList, dst = dList}) =
300         if length sList < 2 then false
301         else if hd dList = hd sList then
302            (hd (tl sList) = WCONST Arbint.zero andalso (op' = madd orelse op' = msub orelse op' = mlsl orelse
303                     op' = mlsr orelse op' = masr orelse op' = mror))
304         else if hd dList = hd (tl (sList)) then
305            hd sList = WCONST Arbint.zero andalso (op' = madd orelse op' = mrsb)
306         else
307            false;
308
309   fun rm_dummy_inst (SC(ir1,ir2,info)) =
310         SC (rm_dummy_inst ir1, rm_dummy_inst ir2, info)
311    |  rm_dummy_inst (TR(cond,ir2,info)) =
312         TR(cond, rm_dummy_inst ir2, info)
313    |  rm_dummy_inst (CJ(cond,ir1,ir2,info)) =
314         CJ(cond, rm_dummy_inst ir1, rm_dummy_inst ir2, info)
315    |  rm_dummy_inst (STM stmL) =
316         STM (List.filter (not o is_dummy_inst) stmL)
317    | rm_dummy_inst (CALL (fname,pre,body,post,info)) =
318         CALL (fname, rm_dummy_inst pre, rm_dummy_inst body, rm_dummy_inst post, info)
319    | rm_dummy_inst (BLK (instL, info)) =
320         BLK (List.filter (not o is_dummy_inst) instL,info)
321
322
323   fun merge_ir ir =
324     let
325
326       (* Determine whether two irs are equal or not *)
327
328       fun is_ir_equal (SC(s1,s2,_)) (SC(s3,s4,_)) =
329             is_ir_equal s1 s3 andalso is_ir_equal s2 s4
330        |  is_ir_equal (TR(cond1,body1,_)) (TR(cond2,body2,_)) =
331             cond1 = cond2 andalso is_ir_equal body1 body2
332        |  is_ir_equal (CJ(cond1,s1,s2,_)) (CJ(cond2,s3,s4,_)) =
333             cond1 = cond2 andalso is_ir_equal s1 s3 andalso is_ir_equal s2 s4
334        |  is_ir_equal (BLK(s1,_)) (BLK(s2,_)) =
335             List.all (op =) (zip s1 s2)
336        |  is_ir_equal (STM s1) (STM s2) =
337             List.all (op =) (zip s1 s2)
338        |  is_ir_equal (CALL(name1,_,_,_,_)) (CALL(name2,_,_,_,_)) =
339             name1 = name2
340        |  is_ir_equal _ _ = false;
341
342       fun merge_stm (SC(STM s1, STM s2, info)) =
343             BLK (s1 @ s2, info)
344        |  merge_stm (SC(STM s1, BLK(s2,_), info)) =
345             BLK (s1 @ s2, info)
346        |  merge_stm (SC(BLK(s1,_), STM s2, info)) =
347             BLK (s1 @ s2, info)
348        |  merge_stm (SC(BLK(s1,_), BLK(s2,_), info)) =
349             BLK (s1 @ s2, info)
350        |  merge_stm (SC(STM [], s2, info)) =
351             merge_stm s2
352        |  merge_stm (SC(s1, STM [], info)) =
353             merge_stm s1
354        |  merge_stm (SC(BLK([],_), s2, info)) =
355             merge_stm s2
356        |  merge_stm (SC(s1, BLK([],_), info)) =
357             merge_stm s1
358        |  merge_stm (SC(s1,s2,info)) =
359             SC (merge_stm s1, merge_stm s2, info)
360        |  merge_stm (TR(cond, body, info)) =
361             TR(cond, merge_stm body, info)
362        |  merge_stm (CJ(cond, s1, s2, info)) =
363             CJ(cond, merge_stm s1, merge_stm s2, info)
364        |  merge_stm (STM s) =
365             BLK(s, init_info)
366        |  merge_stm stm =
367             stm
368
369        fun merge ir = let val ir' = merge_stm ir;
370                           val ir'' = merge_stm ir' in
371                         if is_ir_equal ir' ir'' then ir'' else merge ir''
372                       end
373     in
374        merge ir
375     end
376
377
378 (*---------------------------------------------------------------------------------*)
379 (*      Assistant functions                                                        *)
380 (*---------------------------------------------------------------------------------*)
381
382   fun get_annt (BLK (stmL,info)) = info
383    |  get_annt (SC(s1,s2,info)) = info
384    |  get_annt (CJ(cond,s1,s2,info)) = info
385    |  get_annt (TR(cond,s,info)) = info
386    |  get_annt (CALL(fname,pre,body,post,info)) = info
387    |  get_annt (STM stmL) = {ins = NA, outs = NA, fspec = thm_t, context = []};
388
389   fun replace_ins {ins = ins', outs = outs', context = context', fspec = fspec'} ins'' =
390           {ins = ins'', outs = outs', context = context', fspec = fspec'};
391
392   fun replace_outs {ins = ins', outs = outs', context = context', fspec = fspec'} outs'' =
393           {ins = ins', outs = outs'', context = context', fspec = fspec'};
394
395   fun replace_context {ins = ins', outs = outs', context = context', fspec = fspec'} context'' =
396           {ins = ins', outs = outs', context = context'', fspec = fspec'};
397
398   fun replace_fspec {ins = ins', outs = outs', context = context', fspec = fspec'} fspec'' =
399           {ins = ins', outs = outs', context = context', fspec = fspec''};
400
401
402   fun apply_to_info (BLK (stmL,info)) f =  BLK (stmL, f info)
403    |  apply_to_info (SC(s1,s2,info)) f = SC(s1, s2, f info)
404    |  apply_to_info (CJ(cond,s1,s2,info)) f = CJ(cond, s1, s2, f info)
405    |  apply_to_info (TR(cond,s,info)) f = TR(cond, s, f info)
406    |  apply_to_info (CALL(fname,pre,body,post,info)) f = CALL(fname, pre, body, post, f info)
407    |  apply_to_info (STM stmL) f = STM stmL;
408
409
410 (*---------------------------------------------------------------------------------*)
411 (*      Calculate Modified Registers                                               *)
412 (*---------------------------------------------------------------------------------*)
413
414   fun one_stm_modified_regs ({oper = op1, dst = dlist, src = slist}) =
415       List.map (fn (REG r) => r | _ => ~1) (List.filter (fn (REG r) => true | _ => false) dlist);
416
417   fun get_modified_regs (SC(ir1,ir2,info)) =
418         (get_modified_regs ir1) @ (get_modified_regs ir2)
419    |  get_modified_regs (TR(cond,ir,info)) =
420         get_modified_regs ir
421    |  get_modified_regs (CJ(cond,ir1,ir2,info)) =
422         (get_modified_regs ir1) @ (get_modified_regs ir2)
423    |  get_modified_regs (CALL(fname,pre,body,post,info)) =
424         List.map (fn (REG r) => r | _ => ~1) (List.filter (fn (REG r) => true | _ => false) (pair2list (#outs info)))
425    |  get_modified_regs (STM l) =
426         itlist (curry (fn (a,b) => one_stm_modified_regs a @ b)) l []
427    |  get_modified_regs (BLK (l,info)) =
428         itlist (curry (fn (a,b) => one_stm_modified_regs a @ b)) l [];
429
430
431 (*---------------------------------------------------------------------------------*)
432 (*      Set input, output and context information                                  *)
433 (*      Alignment functions                                                        *)
434 (*---------------------------------------------------------------------------------*)
435
   (* Adjust the inputs of [ir] so they match [outer_info]'s ins (typically
      the outputs of the preceding IR).  Nothing changes when the inputs are
      already equal.  Otherwise:
        BLK  : its ins are replaced by the outer ins;
        SC   : the first component is adjusted recursively, and the SC's
               whole annotation (including outs/context) becomes outer_info;
        CJ   : both branches are adjusted, and the CJ takes outer_info;
        TR   : the body is adjusted to the TR's OWN annotation, whose ins
               stay unchanged -- NOTE(review): outer_info is ignored here;
        CALL : returned unchanged;
        STM  : raises Fail (STM nodes are expected to be BLKs by now).
      NOTE(review): inner_inS / outer_inS are computed but never used. *)

   fun adjust_ins ir (outer_info as ({ins = outer_ins, outs = outer_outs, context = outer_context, ...}:annt)) =
        let val inner_info as {ins = inner_ins, context = inner_context, outs = inner_outs, fspec = inner_spec, ...} = get_annt ir;
            val (inner_inS,outer_inS) = ((list2set o pair2list) inner_ins, (list2set o pair2list) outer_ins);
        in  if outer_ins = inner_ins then
               ir
            else
               case ir of
                   (BLK (stmL, info)) =>
                       BLK (stmL, replace_ins inner_info outer_ins)
                |  (SC (s1,s2,info)) =>
                       SC (adjust_ins s1 outer_info, s2, outer_info)
                |  (CJ (cond,s1,s2,info)) =>
                       CJ (cond, adjust_ins s1 outer_info, adjust_ins s2 outer_info, outer_info)
                |  (TR (cond,s,info)) =>
                       TR(cond, adjust_ins s info, info)
                |  (CALL (fname,pre,body,post,info)) =>
                       CALL(fname, pre, body, post, info)
                |  _ =>
                       raise Fail "adjust_ins: invalid IR tree"
        end
458
459
460   fun args_diff (ins1, outs0) =
461        set2pair (S.difference (pair2set ins1, pair2set outs0));
462
   (* Given the outputs and the context of an ir (in [outer_info]), compute
      its inputs and rebuild its annotations, working backwards from the
      end of the tree. *)

   (* BLK: ins = (outer outputs the block does not itself produce) together
      with the block's own inputs, both taken from getIO.  outs and context
      come from outer_info; the block's own fspec is kept. *)
   fun back_trace (BLK (stmL,inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}:annt) =
          let
              val (inner_inL, inner_tempL, inner_outL) =  getIO stmL
              val gapL = set2list (S.difference (pair2set outer_outs, list2set inner_outL));
              val read_inS = list2set (gapL @ inner_inL);
             (* val contextS = S.difference (list2set contextL, real_outS) *)
          in
              BLK (stmL, {outs = outer_outs, ins = set2pair read_inS, context = contextL, fspec = #fspec inner_info})
          end

    (* STM: treated as a BLK with the default annotation *)
    |  back_trace (STM stmL) info =
          back_trace (BLK (stmL,init_info)) info

    (* SC: trace s2 first.  If s2's inputs are a subset of s1's current
       outputs, widen s2's inputs to those outputs.  Then trace s1 so that it
       produces s2's inputs.  The result runs from s1's ins to s2's outs. *)
    |  back_trace (SC(s1,s2,inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
           let
              val s2' = back_trace s2 outer_info
              val (s1_info, s2_info) = (get_annt s1, get_annt s2');
              val s2'' = if S.isSubset (pair2set (#ins s2_info), pair2set (#outs s1_info))
                         then adjust_ins s2' (replace_ins s2_info (#outs s1_info))
                         else s2';
              (* shadows the earlier s2_info: annotation after the adjustment *)
              val s2_info = get_annt s2'';
              val s1' = back_trace s1 {ins = #ins s1_info, outs = #ins s2_info, context = #context s2_info, fspec = #fspec s1_info};
              val s1_info = get_annt s1'
           in
              SC(s1',s2'', {ins = #ins s1_info, outs = #outs s2_info, fspec = thm_t, context = contextL})
           end

    (* CJ: trace both branches against the same outer_info.  The CJ's
       inputs are the union of the REG/MEM operands of the condition and the
       inputs of both branches; both branches are adjusted to them. *)
    |  back_trace (CJ(cond, s1, s2, inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
          let
              fun filter_exp (REG e) = [REG e]
               |  filter_exp (MEM e) = [MEM e]
               |  filter_exp _ = []

              val cond_expL = filter_exp (#1 cond) @ filter_exp (#3 cond);
              val s1' = back_trace s1 outer_info
              val s2' = back_trace s2 outer_info
              val ({ins = ins1, outs = outs1, ...}, {ins = ins2, outs = outs2, ...}) = (get_annt s1', get_annt s2');
              val inS_0 = list2pair (set2list (list2set (cond_expL @ (pair2list ins1) @ (pair2list ins2))));
                      (* union of the variables in the condition and the inputs of the s1' and s2' *)
              val info_0 = replace_ins outer_info inS_0
          in
              CJ(cond, adjust_ins s1' info_0, adjust_ins s2' info_0, info_0)
          end

    (* TR: the context becomes (outer outputs + outer context) minus the
       TR's own outputs.  The body is adjusted to that annotation; the TR's
       ins/outs are kept and its body is not traced further. *)
    |  back_trace (TR (cond, body, info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
           let
              val extra_outs = args_diff (PAIR (outer_outs, list2pair contextL), #outs info)
              val info' = replace_context info (pair2list extra_outs);
              val body' = adjust_ins body info'
          in
              TR(cond, body', info')
          end

          (* the adjustment will be performed later in the funCall.sml *)
    (* CALL: the body is replaced by an empty BLK carrying the call's own
       annotation.  The CALL's annotation becomes outer_info with the call's
       ins.  NOTE(review): extra_outs is computed but unused here. *)
    |  back_trace (CALL (fname, pre, body, post, info)) (outer_info as {outs = outer_outs, context = contextL, fspec = fout_spec, ...}) =
          let
              val extra_outs = args_diff (PAIR (outer_outs, list2pair contextL), #outs info)
          in
              CALL (fname, pre, BLK ([], info), post, replace_ins outer_info (#ins info))
          end;
525
526
527   fun set_info (ins,ir,outs) =
528       let
529          fun extract_outputs ir0 =
530              let val info0 = get_annt ir0;
531                  val (inner_outS, outer_outS) = (pair2set (#outs info0), pair2set outs)
532                  val ir1 = if S.equal(inner_outS, outer_outS) then ir0
533                            else SC (ir0, BLK ([], {context = #context info0, ins = #outs info0, outs = outs, fspec = thm_t}),
534                                     replace_outs info0 outs)
535                  val info1 = get_annt ir1
536              in
537                  if #ins info0 = ins then ir1
538                  else SC (BLK ([], {context = #context info1, ins = ins, outs = #ins info0, fspec = thm_t}),
539                           ir1, replace_ins info1 ins)
540              end
541          val info =  {ins = ins, outs = outs, context = [], fspec = thm_t}
542       in
543          extract_outputs (back_trace ir info)
544       end;
545
546(* ---------------------------------------------------------------------------------------------------------------------*)
547(* Adjust the IR tree to make inputs and outputs consistent                                                             *)
548(* ---------------------------------------------------------------------------------------------------------------------*)
549
  (* Insert empty bridging BLKs where adjacent annotations disagree.
       SC   : both components are processed first.  If the first one's outs
              differ from the second one's ins, a BLK (outs1 -> ins2) is put
              in front of the second component.
       CALL : a BLK (post outs -> call outs) is appended when the post part's
              outputs differ from the call's.  A BLK (call ins -> pre ins) is
              prepended when the pre part's inputs differ from the call's.
       other nodes (TR, CJ, BLK, STM) are returned unchanged, without
       recursing into their children. *)
  fun match_ins_outs (SC(s1,s2,info)) =
      let
          val (ir1,ir2) = (match_ins_outs s1, match_ins_outs s2);
          val (info1,info2) = (get_annt ir1, get_annt ir2);
          val (outs1, ins2) = (#outs info1, #ins info2);
      in
          if outs1 = ins2 then
              SC(ir1,ir2,info)
          else
              SC(ir1,
                 SC (BLK ([], {ins = outs1, outs = ins2, context = #context info2, fspec = thm_t}), ir2,
                          replace_ins info2 outs1),
                 info)
      end

   |  match_ins_outs (ir as (CALL (fname, pre, body, post, info))) =
      let
          val ((pre_ins,post_outs),(outer_ins,outer_outs)) = ((#ins (get_annt pre), #outs (get_annt post)), (#ins info, #outs info))
          (* append an output bridge if needed; the resulting SC starts from pre_ins *)
          val ir' = if post_outs = outer_outs then ir
                    else
                        SC (ir,
                            BLK ([], {ins = post_outs, outs = outer_outs, context = #context info, fspec = thm_t}),
                            replace_ins info pre_ins)
      in
          if pre_ins = outer_ins then ir'
          else
              SC (BLK([], {ins = outer_ins, outs = pre_ins, context = #context info, fspec = thm_t}),
                  ir',
                  info)
      end

   |  match_ins_outs ir =
      ir;
583
584
585 (*---------------------------------------------------------------------------------*)
586 (*      Interface                                                                  *)
587 (*---------------------------------------------------------------------------------*)
588
589   fun build_ir (f_name, f_type, f_args, f_gr, f_outs, f_rs) =
590      let
591         val (ins, outs) = (to_exp f_args, to_exp f_outs)
592         val ir0 = convert_module (ins, f_gr, outs)
593         val ir1 = (merge_ir o rm_dummy_inst) ir0
594         val ir2 = set_info (ins,ir1,outs)
595      in
596         (ins,ir2,outs)
597      end
598
599   fun sfl2ir prog =
600     let
601        val env = ANF.toANF [] prog;
602        val defs = List.map (fn (name, (flag,src,anf,cps)) => src) env;
603        val (f_const,(_, f_defs, f_anf, f_ast)) = hd env
604        val setting as (f_name, f_type, f_args, f_gr, f_outs, f_rs) = regAllocation.convert_to_ARM (f_anf);
605        val f_ir = build_ir setting
606     in
607        (f_name, f_type, f_ir, defs)
608     end
609
610end
611end
612