structure annotatedIR = struct
local
open HolKernel Parse boolLib preARMTheory pairLib simpLib bossLib
     numSyntax optionSyntax listSyntax ILTheory IRSyntax
in

(*---------------------------------------------------------------------------------*)
(* Annotated IR tree                                                               *)
(*---------------------------------------------------------------------------------*)

(* annotation at each node of the IR tree *)
type annt = {fspec : thm, ins : exp, outs : exp, context : exp list};

(* conditions of CJ and TR: left operand, relational operator, right operand *)
type cond = exp * rop * exp;

(* TR   : tail-recursion structure (condition, body), built by convert_module
   SC   : sequential composition of two trees
   CJ   : conditional (condition, true branch, false branch)
   BLK  : annotated block of instructions
   STM  : bare instruction list without an annotation
   CALL : function call (callee name, pre, body, post) *)
datatype anntIR = TR of cond * anntIR * annt
                | SC of anntIR * anntIR * annt
                | CJ of cond * anntIR * anntIR * annt
                | BLK of instr list * annt
                | STM of instr list
                | CALL of string * anntIR * anntIR * anntIR * annt

(*---------------------------------------------------------------------------------*)
(* Print an annotated IR tree                                                      *)
(*---------------------------------------------------------------------------------*)

(* when false, a CALL node is printed only as CALL(name) *)
val show_call_detail = ref true;

(* Pretty-print the IR tree s0 to outstream, indenting nested trees by 2 spaces
   per level, then print a newline and flush the stream. *)
fun print_ir (outstream, s0) =
  let
    fun say s = TextIO.output (outstream, s)
    fun sayln s = (say s; say "\n")

    fun indent 0 = ()
      | indent i = (say " "; indent (i - 1))

    fun say_cond (exp1, rop, exp2) =
      say (format_exp exp1 ^ " " ^ print_rop rop ^ " " ^ format_exp exp2)

    fun printtree (SC (s1, s2, _), d) =
          (indent d; sayln "SC(";
           printtree (s1, d + 2); sayln ",";
           printtree (s2, d + 2); say ")")
      | printtree (CJ (cond, s1, s2, _), d) =
          (indent d; say "CJ("; say_cond cond; sayln ",";
           printtree (s1, d + 2); sayln ",";
           printtree (s2, d + 2); say ")")
      | printtree (TR (cond, s, _), d) =
          (* the condition is printed exactly like a CJ condition
             (operands and operator separated by spaces) *)
          (indent d; say "TR("; say_cond cond; sayln ",";
           printtree (s, d + 2); say ")")
      | printtree (CALL (fname, pre, body, post, info), d) =
          if not (!show_call_detail) then
            (indent d; say ("CALL(" ^ fname ^ ")"))
          else
            (* show the call as the sequence pre; body; post *)
            printtree (SC (pre, SC (body, post, info), info), d)
      | printtree (STM stmL, d) =
          (indent d;
           if null stmL then say "[]"
           else say ("[" ^ formatInst (hd stmL) ^
                     (itlist (curry (fn (stm, str) => "; " ^ formatInst stm ^ str))
                             (tl stmL) "") ^
                     "]"))
      | printtree (BLK (stmL, _), d) =
          printtree (STM stmL, d)
  in
    printtree (s0, 0); sayln ""; TextIO.flushOut outstream
  end

(* print an IR tree on standard output *)
fun printIR ir = print_ir (TextIO.stdOut, ir)

(* print the module name, its inputs and outputs, then its IR tree *)
fun printIR2 (f_name, f_type, (ins, ir, outs), defs) =
  ( print ("Module: " ^ f_name ^ "\n");
    print ("Inputs: " ^ format_exp ins ^ "\n" ^ "Outputs: " ^ format_exp outs ^ "\n");
    printIR ir
  )

(*---------------------------------------------------------------------------------*)
(* Assistant Graph Functions                                                       *)
(*---------------------------------------------------------------------------------*)

structure G = Graph;

(* Add the edge n0 -> n1 labelled lab to gr.  A self-loop, or an edge from n0 to n1
   that is already present (whatever its label), leaves gr unchanged. *)
fun mk_edge (n0, n1, lab) gr =
  if n0 = n1 then gr
  else if (List.find (fn m => #2 m = n0) (#1 (G.context (n1, gr)))) <> NONE then gr (* raise Edge *)
  else
    let val ((p1, m1, l1, s1), del_n1) = G.match (n1, gr)
    in G.embed (((lab, n0) :: p1, m1, l1, s1), del_n1)
    end;

(* Retrieve the subgraph of gr reachable from start_node, following successor
   edges.  A node for which "f node" holds is included, but its successors are not
   followed.  Copied nodes keep their labels from gr.  When a node is reached a
   second time, only the new edge is added, with its original label from gr.
   NOTE: the handler also covers exceptions raised while exploring b's own
   successors; in that case only the edge nodeNo -> b is recorded. *)
fun sub_graph gr start_node f =
  let
    fun one_node (nodeNo, gr') =
      if f nodeNo then gr'
      else
        List.foldl
          (fn ((a, b), gr'') =>
              one_node (b, G.embed (([(a, nodeNo)], b, #3 (G.context (b, gr)), []), gr''))
              handle e => mk_edge (nodeNo, b, a) gr'')
          gr' (#4 (G.context (nodeNo, gr)))
  in
    one_node (start_node, G.embed (([], start_node, #3 (G.context (start_node, gr)), []), G.empty))
  end;

(* Locate a node satisfying f by a depth-first search from start_node along
   successor edges.  The search along a path stops at the node numbered
   (noNodes gr - 1).  Returns SOME node for the first match found, NONE otherwise.
   NOTE(review): no visited set is kept, so a reachable cycle may prevent
   termination. *)
fun locate gr start_node f =
  let
    val maxNodeNo = G.noNodes gr - 1
    fun one_node nodeNo =
      if f nodeNo then (SOME nodeNo, true)
      else if nodeNo = maxNodeNo then (NONE, false)
      else
        List.foldl (fn ((_, b), (n, found)) => if found then (n, found) else one_node b)
                   (NONE, false) (#4 (G.context (nodeNo, gr)))
  in
    #1 (one_node start_node)
  end;

(*---------------------------------------------------------------------------------*)
(* Conversion from CFG to IR tree                                                  *)
(*---------------------------------------------------------------------------------*)

(* the label carried by a LABEL instruction *)
fun get_label (Assem.LABEL {lab = lab'}) = lab'
  | get_label _ = raise ERR "IL" ("Expecting a label instruction")

(* Starting at node n, follow first successors until the node holding a LABEL
   instruction (the function label) is found.  Returns (true, node) if that node
   has more than one incoming edge (a recursive call jumps back to it), or
   (false, node) otherwise.  Returns (false, 0) if no label node is reached. *)
fun is_rec gr n =
  let
    val context = G.context (n, gr)
    fun is_label (Assem.LABEL _) = true
      | is_label _ = false
  in
    if is_label (#instr ((#3 context) : CFG.node)) then
      (length (#1 context) > 1, n)
    else is_rec gr (#2 (hd (#4 context)))
  end
  handle e => (false, 0);  (* no label node in the graph *)


(* Given a TR cfg and the node including the function name label, break this cfg
   into three parts: the pre-condition part, the basic case part and the recursive
   case part.  The condition is also derived.
   Returns (cond, ((p_cfg, p_start), (b_cfg, b_start), (r_cfg, r_start)), join_node),
   where p_cfg and p_start are NONE when the pre-condition part is empty. *)
fun break_rec gr lab_node =
  let
    fun get_sucL n = #4 (G.context (n, gr))
    fun get_preL n = #1 (G.context (n, gr))
    (* the node at the other end of the first edge in adjL labelled lab *)
    fun along lab adjL = #2 (valOf (List.find (fn (flag, _) => flag = lab) adjL))

    val last_node = valOf (locate gr 0 (fn n => null (get_sucL n)))

    (* walk backwards along first predecessors until a node with several
       incoming edges is reached *)
    fun find_join_node n =
      if length (get_preL n) > 1 then n
      else find_join_node (#2 (hd (get_preL n)))
    val join_node = find_join_node last_node  (* where the basic and recursive parts join *)

    (* the nodes jumping to the join node *)
    val BAL_node = along 1 (get_preL join_node)
    val b_end_node = along 0 (get_preL join_node)
    val b_start_node = along 2 (get_sucL BAL_node)
    val b_cfg = sub_graph gr b_start_node (fn n => n = b_end_node)

    val cj_node = along 1 (get_preL b_start_node)
    val cmp_node = #2 (hd (get_preL cj_node))
    val cond = to_cond (#instr ((#3 (G.context (cmp_node, gr))) : CFG.node),
                        #instr ((#3 (G.context (cj_node, gr))) : CFG.node))

    val r_start_node = along 0 (get_sucL cj_node)
    val BL_node = along 1 (get_preL lab_node)
    val r_end_node = #2 (hd (get_preL BL_node))
    val r_cfg = sub_graph gr r_start_node (fn n => n = r_end_node)

    (* the pre-condition part may be empty: the label is then followed directly
       by the comparison *)
    val after_label = #2 (hd (get_sucL lab_node))
    val p_start_node = if after_label = cmp_node then NONE else SOME after_label
    val p_cfg = case p_start_node of
                  NONE => NONE
                | SOME k => SOME (sub_graph gr k (fn n => n = cmp_node))
  in
    (cond, ((p_cfg, p_start_node), (b_cfg, b_start_node), (r_cfg, r_start_node)), join_node)
  end

(* Given a CJ cfg and the node including the cmp instruction, break this cfg into
   the true case part and the false case part.
   Returns (cond, ((t_cfg, t_start), (f_cfg, f_start)), join_node). *)
fun break_cj gr cmp_node =
  let
    fun get_sucL n = #4 (G.context (n, gr))
    fun get_preL n = #1 (G.context (n, gr))
    (* the node at the other end of the first edge in adjL labelled lab *)
    fun along lab adjL = #2 (valOf (List.find (fn (flag, _) => flag = lab) adjL))

    val cj_node = #2 (hd (get_sucL cmp_node))
    val sucL = get_sucL cj_node
    val (t_start_node, f_start_node) = (along 1 sucL, along 0 sucL)
    val bal_node = along 2 (get_preL t_start_node)
    val f_end_node = #2 (hd (get_preL bal_node))

    val join_node = along 1 (get_sucL bal_node)
    val t_end_node = along 0 (get_preL join_node)

    val cond = to_cond (#instr ((#3 (G.context (cmp_node, gr))) : CFG.node),
                        #instr ((#3 (G.context (cj_node, gr))) : CFG.node))

    val (t_cfg, f_cfg) = (sub_graph gr t_start_node (fn n => n = t_end_node),
                          sub_graph gr f_start_node (fn n => n = f_end_node))
  in
    (cond, ((t_cfg, t_start_node), (f_cfg, f_start_node)), join_node)
  end


(* convert a condition triple component-wise with to_exp / to_rop *)
fun convert_cond (exp1, rop, exp2) =
  (to_exp exp1, to_rop rop, to_exp exp2);

(*---------------------------------------------------------------------------------*)
(* Convert cfg consisting of SC, CJ and CALL structures to ir tree                 *)
(*---------------------------------------------------------------------------------*)

val thm_t = DECIDE (Term `T`);
val init_info = {fspec = thm_t, ins = NA, outs = NA, context = []};

(* Convert the part of cfg starting at node n into an IR tree.  NOP and LABEL
   nodes are skipped; BL becomes a CALL; CMP starts a CJ (split by break_cj)
   followed by the code from the join node; any other instruction becomes a
   one-instruction STM followed by the rest. *)
fun convert cfg n =
  let
    val (_, _, {instr = inst', def = _, use = _}, sucL) = G.context (n, cfg)
    (* the IR of the code from the first successor onwards *)
    fun rest () = convert cfg (#2 (hd sucL))
  in
    case inst' of
      Assem.OPER {oper = (Assem.NOP, _, _), ...} =>
        if null sucL then STM [] else rest ()

    | Assem.OPER {oper = (Assem.BL, _, _), dst = dList, src = sList, jump = jp} =>
        let
          val stm = CALL (Symbol.name (hd (valOf jp)), STM [], STM [], STM [],
                          {fspec = thm_t, ins = list2pair (List.map to_exp sList),
                           outs = list2pair (List.map to_exp dList), context = []})
        in
          if null sucL then stm
          else SC (stm, rest (), init_info)
        end

    | Assem.LABEL _ =>
        if null sucL then STM [] else rest ()

    | Assem.OPER {oper = (Assem.CMP, _, _), ...} =>
        let
          val (cond, ((t_cfg, t_start_node), (f_cfg, f_start_node)), join_node) = break_cj cfg n
        in
          SC (CJ (convert_cond cond, convert t_cfg t_start_node, convert f_cfg f_start_node, init_info),
              convert cfg join_node, init_info)
        end

    | x =>
        if null sucL then STM [to_inst x]
        else SC (STM [to_inst x], rest (), init_info)
  end

(*---------------------------------------------------------------------------------*)
(* Convert a cfg corresponding to a function to ir                                 *)
(*---------------------------------------------------------------------------------*)

(* Convert the cfg of a whole function.  A non-recursive function is converted
   directly from its label node.  A recursive one is split by break_rec and
   rebuilt around a TR node.  The recursive part is followed by moves that copy
   the arguments of the recursive call into the inputs. *)
fun convert_module (ins, cfg, outs) =
  let
    val (flag, lab_node) = is_rec cfg 0
    val inL = pair2list ins
  in
    if not flag then convert cfg lab_node
    else
      let
        val (cond, ((p_cfg, p_start_node), (b_cfg, b_start_node), (r_cfg, r_start_node)), join_node) =
            break_rec cfg lab_node
        val (b_ir, r_ir) = (convert b_cfg b_start_node, convert r_cfg r_start_node)

        val bal_node = #2 (valOf (List.find (fn (flag, n) => flag = 1) (#1 (G.context (lab_node, cfg)))))
        val (Assem.OPER {src = rec_argL, ...}) = #instr ((#3 (G.context (bal_node, cfg))) : CFG.node)
        (* dst_exp = BL (src_exp) *)
        val rec_args_pass_ir =
            STM (List.map (fn (dexp, sexp) => {oper = mmov, src = [sexp], dst = [dexp]})
                          (zip inL (pair2list (to_exp (hd rec_argL)))))

        val info = {fspec = thm_t, ins = ins, outs = ins, context = []}
        val r_ir_1 = SC (r_ir, rec_args_pass_ir, info)

        val cond' = convert_cond cond
      in
        case p_cfg of
          NONE =>
            SC (TR (cond', r_ir_1, {fspec = thm_t, ins = ins, outs = ins, context = []}),
                b_ir, info)
        | SOME p_gr =>
            let
              val p_ir = convert p_gr (valOf p_start_node)
              val ir0 = SC (r_ir_1, p_ir, init_info)
              val ir1 = TR (cond', ir0, {fspec = thm_t, ins = ins, outs = outs, context = []})
              val ir2 = SC (ir1, b_ir, {fspec = thm_t, ins = ins, outs = ins, context = []})
            in
              SC (p_ir, ir2, {fspec = thm_t, ins = ins, outs = outs, context = []})
            end
      end
  end;

(*---------------------------------------------------------------------------------*)
(* Simplify the IR tree                                                            *)
(*---------------------------------------------------------------------------------*)

(* An instruction is a dummy if it has no effect: a move of a value onto itself;
   or dst := dst op 0 for add/sub/shift/rotate; or dst := 0 op dst for add/rsb. *)
fun is_dummy_inst ({oper = mmov, src = sList, dst = dList}) =
      (hd sList = hd dList)
  | is_dummy_inst {oper = op', src = sList, dst = dList} =
      if length sList < 2 then false
      else if hd dList = hd sList then
        hd (tl sList) = WCONST Arbint.zero andalso
        List.exists (fn x => x = op') [madd, msub, mlsl, mlsr, masr, mror]
      else if hd dList = hd (tl sList) then
        hd sList = WCONST Arbint.zero andalso (op' = madd orelse op' = mrsb)
      else
        false;

(* remove dummy instructions from every STM and BLK in the tree *)
fun rm_dummy_inst (SC (ir1, ir2, info)) =
      SC (rm_dummy_inst ir1, rm_dummy_inst ir2, info)
  | rm_dummy_inst (TR (cond, ir2, info)) =
      TR (cond, rm_dummy_inst ir2, info)
  | rm_dummy_inst (CJ (cond, ir1, ir2, info)) =
      CJ (cond, rm_dummy_inst ir1, rm_dummy_inst ir2, info)
  | rm_dummy_inst (STM stmL) =
      STM (List.filter (not o is_dummy_inst) stmL)
  | rm_dummy_inst (CALL (fname, pre, body, post, info)) =
      CALL (fname, rm_dummy_inst pre, rm_dummy_inst body, rm_dummy_inst post, info)
  | rm_dummy_inst (BLK (instL, info)) =
      BLK (List.filter (not o is_dummy_inst) instL, info)


(* Repeatedly merge adjacent instruction sequences (STM/BLK) into single BLKs
   and drop empty ones, until two successive passes give equal trees. *)
fun merge_ir ir =
  let
    (* equal instruction lists; lists of different lengths are unequal *)
    fun same_insts (l1, l2) =
      length l1 = length l2 andalso List.all (op =) (zip l1 l2)

    (* Determine whether two irs are equal or not (annotations are ignored) *)
    fun is_ir_equal (SC (s1, s2, _)) (SC (s3, s4, _)) =
          is_ir_equal s1 s3 andalso is_ir_equal s2 s4
      | is_ir_equal (TR (cond1, body1, _)) (TR (cond2, body2, _)) =
          cond1 = cond2 andalso is_ir_equal body1 body2
      | is_ir_equal (CJ (cond1, s1, s2, _)) (CJ (cond2, s3, s4, _)) =
          cond1 = cond2 andalso is_ir_equal s1 s3 andalso is_ir_equal s2 s4
      | is_ir_equal (BLK (s1, _)) (BLK (s2, _)) =
          same_insts (s1, s2)
      | is_ir_equal (STM s1) (STM s2) =
          same_insts (s1, s2)
      | is_ir_equal (CALL (name1, _, _, _, _)) (CALL (name2, _, _, _, _)) =
          name1 = name2
      | is_ir_equal _ _ = false

    (* one merging pass; clause order matters *)
    fun merge_stm (SC (STM s1, STM s2, info)) =
          BLK (s1 @ s2, info)
      | merge_stm (SC (STM s1, BLK (s2, _), info)) =
          BLK (s1 @ s2, info)
      | merge_stm (SC (BLK (s1, _), STM s2, info)) =
          BLK (s1 @ s2, info)
      | merge_stm (SC (BLK (s1, _), BLK (s2, _), info)) =
          BLK (s1 @ s2, info)
      | merge_stm (SC (STM [], s2, info)) =
          merge_stm s2
      | merge_stm (SC (s1, STM [], info)) =
          merge_stm s1
      | merge_stm (SC (BLK ([], _), s2, info)) =
          merge_stm s2
      | merge_stm (SC (s1, BLK ([], _), info)) =
          merge_stm s1
      | merge_stm (SC (s1, s2, info)) =
          SC (merge_stm s1, merge_stm s2, info)
      | merge_stm (TR (cond, body, info)) =
          TR (cond, merge_stm body, info)
      | merge_stm (CJ (cond, s1, s2, info)) =
          CJ (cond, merge_stm s1, merge_stm s2, info)
      | merge_stm (STM s) =
          BLK (s, init_info)
      | merge_stm stm =
          stm

    fun merge ir =
      let
        val ir' = merge_stm ir
        val ir'' = merge_stm ir'
      in
        if is_ir_equal ir' ir'' then ir'' else merge ir''
      end
  in
    merge ir
  end


(*---------------------------------------------------------------------------------*)
(* Assistant functions                                                             *)
(*---------------------------------------------------------------------------------*)

(* the annotation of a node; an STM has none, so a default one is returned *)
fun get_annt (BLK (_, info)) = info
  | get_annt (SC (_, _, info)) = info
  | get_annt (CJ (_, _, _, info)) = info
  | get_annt (TR (_, _, info)) = info
  | get_annt (CALL (_, _, _, _, info)) = info
  | get_annt (STM _) = {ins = NA, outs = NA, fspec = thm_t, context = []};

(* functional updates of a single annotation field *)
fun replace_ins {ins = ins', outs = outs', context = context', fspec = fspec'} ins'' =
  {ins = ins'', outs = outs', context = context', fspec = fspec'};

fun replace_outs {ins = ins', outs = outs', context = context', fspec = fspec'} outs'' =
  {ins = ins', outs = outs'', context = context', fspec = fspec'};

fun replace_context {ins = ins', outs = outs', context = context', fspec = fspec'} context'' =
  {ins = ins', outs = outs', context = context'', fspec = fspec'};

fun replace_fspec {ins = ins', outs = outs', context = context', fspec = fspec'} fspec'' =
  {ins = ins', outs = outs', context = context', fspec = fspec''};


(* apply f to the annotation of the root node; an STM is returned unchanged *)
fun apply_to_info (BLK (stmL, info)) f = BLK (stmL, f info)
  | apply_to_info (SC (s1, s2, info)) f = SC (s1, s2, f info)
  | apply_to_info (CJ (cond, s1, s2, info)) f = CJ (cond, s1, s2, f info)
  | apply_to_info (TR (cond, s, info)) f = TR (cond, s, f info)
  | apply_to_info (CALL (fname, pre, body, post, info)) f = CALL (fname, pre, body, post, f info)
  | apply_to_info (STM stmL) f = STM stmL;


(*---------------------------------------------------------------------------------*)
(* Calculate Modified Registers                                                    *)
(*---------------------------------------------------------------------------------*)

(* register numbers among the destinations of one instruction, in order *)
fun one_stm_modified_regs ({oper = _, dst = dlist, src = _}) =
  List.mapPartial (fn REG r => SOME r | _ => NONE) dlist;

(* register numbers written anywhere in the tree (duplicates kept); for a CALL,
   the registers among its annotated outputs *)
fun get_modified_regs (SC (ir1, ir2, _)) =
      get_modified_regs ir1 @ get_modified_regs ir2
  | get_modified_regs (TR (_, ir, _)) =
      get_modified_regs ir
  | get_modified_regs (CJ (_, ir1, ir2, _)) =
      get_modified_regs ir1 @ get_modified_regs ir2
  | get_modified_regs (CALL (_, _, _, _, info)) =
      List.mapPartial (fn REG r => SOME r | _ => NONE) (pair2list (#outs info))
  | get_modified_regs (STM l) =
      List.concat (List.map one_stm_modified_regs l)
  | get_modified_regs (BLK (l, _)) =
      List.concat (List.map one_stm_modified_regs l);


(*---------------------------------------------------------------------------------*)
(* Set input, output and context information                                      *)
(* Alignment functions                                                             *)
(*---------------------------------------------------------------------------------*)

(* Adjust the inputs of ir to be the inputs of outer_info (consistent with the
   outputs of the previous ir).  BLK gets new inputs; SC adjusts its first part;
   CJ adjusts both branches.  SC and CJ take outer_info as their annotation.
   TR and CALL keep their own annotations.  An STM with differing inputs raises
   Fail. *)
fun adjust_ins ir (outer_info as ({ins = outer_ins, ...} : annt)) =
  let
    val inner_info = get_annt ir
  in
    if outer_ins = #ins inner_info then ir
    else
      case ir of
        BLK (stmL, _) =>
          BLK (stmL, replace_ins inner_info outer_ins)
      | SC (s1, s2, _) =>
          SC (adjust_ins s1 outer_info, s2, outer_info)
      | CJ (cond, s1, s2, _) =>
          CJ (cond, adjust_ins s1 outer_info, adjust_ins s2 outer_info, outer_info)
      | TR (cond, s, info) =>
          TR (cond, adjust_ins s info, info)
      | CALL _ =>
          ir
      | _ =>
          raise Fail "adjust_ins: invalid IR tree"
  end


(* the arguments in ins1 that are not among outs0, as a pair *)
fun args_diff (ins1, outs0) =
  set2pair (S.difference (pair2set ins1, pair2set outs0));

(* Given the outputs and the context of an ir, calculate its inputs *)

fun back_trace (BLK (stmL, inner_info)) (outer_info as {outs = outer_outs, context = contextL, ...} : annt) =
      let
        val (inner_inL, inner_tempL, inner_outL) = getIO stmL
        (* required outputs the block does not produce must be read as inputs *)
        val gapL = set2list (S.difference (pair2set outer_outs, list2set inner_outL))
        val read_inS = list2set (gapL @ inner_inL)
      in
        BLK (stmL, {outs = outer_outs, ins = set2pair read_inS, context = contextL, fspec = #fspec inner_info})
      end

  | back_trace (STM stmL) info =
      back_trace (BLK (stmL, init_info)) info

  | back_trace (SC (s1, s2, inner_info)) (outer_info as {context = contextL, ...}) =
      let
        val s2' = back_trace s2 outer_info
        val (s1_info, s2_info) = (get_annt s1, get_annt s2')
        val s2'' = if S.isSubset (pair2set (#ins s2_info), pair2set (#outs s1_info))
                   then adjust_ins s2' (replace_ins s2_info (#outs s1_info))
                   else s2'
        val s2_info = get_annt s2''
        val s1' = back_trace s1 {ins = #ins s1_info, outs = #ins s2_info,
                                 context = #context s2_info, fspec = #fspec s1_info}
        val s1_info = get_annt s1'
      in
        SC (s1', s2'', {ins = #ins s1_info, outs = #outs s2_info, fspec = thm_t, context = contextL})
      end

  | back_trace (CJ (cond, s1, s2, inner_info)) outer_info =
      let
        (* only register and memory operands of the condition are inputs *)
        fun is_loc (REG _) = true
          | is_loc (MEM _) = true
          | is_loc _ = false
        val cond_expL = List.filter is_loc [#1 cond, #3 cond]
        val s1' = back_trace s1 outer_info
        val s2' = back_trace s2 outer_info
        val (ins1, ins2) = (#ins (get_annt s1'), #ins (get_annt s2'))
        (* union of the variables in the condition and the inputs of s1' and s2' *)
        val inS_0 = list2pair (set2list (list2set (cond_expL @ pair2list ins1 @ pair2list ins2)))
        val info_0 = replace_ins outer_info inS_0
      in
        CJ (cond, adjust_ins s1' info_0, adjust_ins s2' info_0, info_0)
      end

  | back_trace (TR (cond, body, info)) (outer_info as {outs = outer_outs, context = contextL, ...}) =
      let
        val extra_outs = args_diff (PAIR (outer_outs, list2pair contextL), #outs info)
        val info' = replace_context info (pair2list extra_outs)
      in
        TR (cond, adjust_ins body info', info')
      end

  (* the adjustment will be performed later in the funCall.sml *)
  | back_trace (CALL (fname, pre, body, post, info)) outer_info =
      CALL (fname, pre, BLK ([], info), post, replace_ins outer_info (#ins info));


(* Annotate ir by back_trace from the function outputs outs.  Then, where the
   result's outputs (as sets) or inputs differ from outs / ins, wrap it in SC
   nodes with empty BLKs that convert between them. *)
fun set_info (ins, ir, outs) =
  let
    fun extract_outputs ir0 =
      let
        val info0 = get_annt ir0
        val (inner_outS, outer_outS) = (pair2set (#outs info0), pair2set outs)
        val ir1 = if S.equal (inner_outS, outer_outS) then ir0
                  else SC (ir0, BLK ([], {context = #context info0, ins = #outs info0, outs = outs, fspec = thm_t}),
                           replace_outs info0 outs)
        val info1 = get_annt ir1
      in
        if #ins info0 = ins then ir1
        else SC (BLK ([], {context = #context info1, ins = ins, outs = #ins info0, fspec = thm_t}),
                 ir1, replace_ins info1 ins)
      end
    val info = {ins = ins, outs = outs, context = [], fspec = thm_t}
  in
    extract_outputs (back_trace ir info)
  end;

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Adjust the IR tree to make inputs and outputs consistent                                                             *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* In SC nodes (recursively) and CALL nodes, insert empty BLKs wherever the
   outputs of one part differ from the inputs of the next.  Other nodes are
   returned unchanged. *)
fun match_ins_outs (SC (s1, s2, info)) =
      let
        val (ir1, ir2) = (match_ins_outs s1, match_ins_outs s2)
        val (info1, info2) = (get_annt ir1, get_annt ir2)
        val (outs1, ins2) = (#outs info1, #ins info2)
      in
        if outs1 = ins2 then
          SC (ir1, ir2, info)
        else
          SC (ir1,
              SC (BLK ([], {ins = outs1, outs = ins2, context = #context info2, fspec = thm_t}), ir2,
                  replace_ins info2 outs1),
              info)
      end

  | match_ins_outs (ir as (CALL (fname, pre, body, post, info))) =
      let
        val (pre_ins, post_outs) = (#ins (get_annt pre), #outs (get_annt post))
        val (outer_ins, outer_outs) = (#ins info, #outs info)
        val ir' = if post_outs = outer_outs then ir
                  else
                    SC (ir,
                        BLK ([], {ins = post_outs, outs = outer_outs, context = #context info, fspec = thm_t}),
                        replace_ins info pre_ins)
      in
        if pre_ins = outer_ins then ir'
        else
          SC (BLK ([], {ins = outer_ins, outs = pre_ins, context = #context info, fspec = thm_t}),
              ir',
              info)
      end

  | match_ins_outs ir =
      ir;


(*---------------------------------------------------------------------------------*)
(* Interface                                                                       *)
(*---------------------------------------------------------------------------------*)

(* Build the annotated IR of a register-allocated function: convert its cfg,
   remove dummy instructions, merge instruction blocks and set the annotations. *)
fun build_ir (f_name, f_type, f_args, f_gr, f_outs, f_rs) =
  let
    val (ins, outs) = (to_exp f_args, to_exp f_outs)
    val ir0 = convert_module (ins, f_gr, outs)
    val ir1 = (merge_ir o rm_dummy_inst) ir0
    val ir2 = set_info (ins, ir1, outs)
  in
    (ins, ir2, outs)
  end

(* Translate a program: ANF conversion, register allocation of the first
   definition, then IR construction.  Returns (name, type, (ins, ir, outs), defs),
   where defs are the source definitions of all functions in the environment. *)
fun sfl2ir prog =
  let
    val env = ANF.toANF [] prog
    val defs = List.map (fn (name, (flag, src, anf, cps)) => src) env
    val (f_const, (_, f_defs, f_anf, f_ast)) = hd env
    val setting as (f_name, f_type, f_args, f_gr, f_outs, f_rs) = regAllocation.convert_to_ARM f_anf
    val f_ir = build_ir setting
  in
    (f_name, f_type, f_ir, defs)
  end

end
end