structure annotatedIR = struct
local

(*
quietdec := true;
*)

open HolKernel Parse boolLib preARMTheory pairLib simpLib bossLib
     numSyntax optionSyntax listSyntax ILTheory IRSyntax

(*
quietdec := false;
open annotatedIR
*)

in

  (*---------------------------------------------------------------------------------*)
  (* Annotated IR tree                                                               *)
  (*---------------------------------------------------------------------------------*)

  (* Annotation attached to each node of the IR tree.
       fspec   : a theorem attached to the node (this file always initialises it
                 to thm_t, i.e. |- T)
       ins     : the inputs of the node (an exp, possibly a PAIR of exps)
       outs    : the outputs of the node
       context : extra exps carried alongside the node *)
  type annt = {fspec : thm, ins : exp, outs : exp, context : exp list};

  (* Condition of a CJ or TR node: (left operand, relational operator, right operand). *)
  type cond = exp * rop * exp;

  (* The annotated IR tree.
       TR   : tail-recursion node built from a recursive function (condition, body)
       SC   : sequential composition of two trees
       CJ   : conditional; the first subtree is the jump target of the
              conditional branch, the second the fall-through code
       BLK  : an annotated list of instructions
       STM  : a bare, un-annotated list of instructions
       CALL : call of a named function with pre-call, body and post-call parts *)
  datatype anntIR = TR of cond * anntIR * annt
                  | SC of anntIR * anntIR * annt
                  | CJ of cond * anntIR * anntIR * annt
                  | BLK of instr list * annt
                  | STM of instr list
                  | CALL of string * anntIR * anntIR * anntIR * annt


  (*---------------------------------------------------------------------------------*)
  (* Print an annotated IR tree                                                      *)
  (*---------------------------------------------------------------------------------*)

  (* When true, print_ir expands a CALL node into SC(pre, SC(body, post));
     otherwise only "CALL(<name>)" is printed. *)
  val show_call_detail = ref true;

  (* Pretty-print the tree s0 to outstream, indenting nested nodes by two
     spaces per level, then end the line and flush the stream. *)
  fun print_ir (outstream, s0) =
    let
      fun say s = TextIO.output (outstream, s)
      fun sayln s = (say s; say "\n")

      fun indent 0 = ()
        | indent i = (say " "; indent (i - 1))

      (* "<exp1> <rop> <exp2>"; shared by CJ and TR so both print the same way
         (TR previously omitted the space before the right operand) *)
      fun format_cond (exp1, rop, exp2) =
        format_exp exp1 ^ " " ^ print_rop rop ^ " " ^ format_exp exp2

      fun printtree (SC (s1, s2, _), d) =
            (indent d; sayln "SC(";
             printtree (s1, d + 2); sayln ",";
             printtree (s2, d + 2); say ")")
        | printtree (CJ (cnd, s1, s2, _), d) =
            (indent d; say "CJ("; say (format_cond cnd); sayln ",";
             printtree (s1, d + 2); sayln ",";
             printtree (s2, d + 2); say ")")
        | printtree (TR (cnd, s, _), d) =
            (indent d; say "TR("; say (format_cond cnd); sayln ",";
             printtree (s, d + 2); say ")")
        | printtree (CALL (fname, pre, body, post, info), d) =
            if not (!show_call_detail) then
              (indent d; say ("CALL(" ^ fname ^ ")"))
            else
              printtree (SC (pre, SC (body, post, info), info), d)
        | printtree (STM [], d) =
            (indent d; say "[]")
        | printtree (STM (first :: rest), d) =
            (indent d;
             say ("[" ^ formatInst first ^
                  itlist (fn stm => fn str => "; " ^ formatInst stm ^ str) rest "" ^
                  "]"))
        | printtree (BLK (stmL, _), d) =
            printtree (STM stmL, d)
    in
      printtree (s0, 0); sayln ""; TextIO.flushOut outstream
    end

  (* Print an IR tree to standard output. *)
  fun printIR ir = print_ir (TextIO.stdOut, ir)

  (* Print a module as produced by sfl2ir: its name, inputs, outputs and IR tree. *)
  fun printIR2 (f_name, f_type, (ins, ir, outs), defs) =
    ( print ("Module: " ^ f_name ^ "\n");
      print ("Inputs: " ^ format_exp ins ^ "\n" ^ "Outputs: " ^ format_exp outs ^ "\n");
      printIR ir
    )

  (*---------------------------------------------------------------------------------*)
  (* Assistant Graph Functions                                                       *)
  (*---------------------------------------------------------------------------------*)

  structure G = Graph;

  (* Add an edge labelled lab from n0 to n1.  Self-loops and edges from a node
     that is already a predecessor of n1 are silently ignored. *)
  fun mk_edge (n0, n1, lab) gr =
    if n0 = n1 then gr
    else if List.exists (fn (_, p) => p = n0) (#1 (G.context (n1, gr))) then gr (* raise Edge *)
    else
      let val ((p1, m1, l1, s1), del_n1) = G.match (n1, gr)
      in G.embed (((lab, n0) :: p1, m1, l1, s1), del_n1)
      end;

  (* Retrieve the subgraph of gr reachable from start_node.  The traversal
     stops at any node for which "f node" holds; such a node is included in the
     result but its successors are not.  If embedding a successor fails (e.g.
     the node is already present), only an edge to it is added. *)
  fun sub_graph gr start_node f =
    let
      fun one_node (nodeNo, gr') =
        if f nodeNo then gr'
        else
          List.foldl
            (fn ((a, b), gr'') =>
                one_node (b, G.embed (([(a, nodeNo)], b, #3 (G.context (b, gr)), []), gr''))
                handle _ => mk_edge (nodeNo, b, 0) gr'')
                (* NOTE(review): the fallback edge gets label 0 rather than the
                   original label a; callers inspect labels 0/1/2 - confirm *)
            gr' (#4 (G.context (nodeNo, gr)))
    in
      one_node (start_node,
                G.embed (([], start_node, #3 (G.context (start_node, gr)), []), G.empty))
    end;

  (* Depth-first search from start_node for a node satisfying f, following
     successor edges in order.  Returns SOME node when one is found and NONE
     otherwise; a path that reaches the highest-numbered node
     (G.noNodes gr - 1) without success ends there.
     Previously a failed search could return SOME start_node. *)
  fun locate gr start_node f =
    let
      val maxNodeNo = G.noNodes gr - 1
      fun one_node nodeNo =
        if f nodeNo then (SOME nodeNo, true)
        else if nodeNo = maxNodeNo then (NONE, false)
        else
          List.foldl (fn ((_, b), acc as (_, found)) =>
                         if found then acc else one_node b)
                     (NONE, false) (#4 (G.context (nodeNo, gr)))
    in
      #1 (one_node start_node)
    end;

  (*---------------------------------------------------------------------------------*)
  (* Conversion from CFG to IR tree                                                  *)
  (*---------------------------------------------------------------------------------*)

  (* The label carried by a LABEL instruction; raises ERR for any other instruction. *)
  fun get_label (Assem.LABEL {lab = lab'}) = lab'
    | get_label _ = raise ERR "IL" ("Expecting a label instruction")

  (* Starting at node n, follow first successors until a node holding a label
     instruction is found.  Returns (true, m) if that node m has more than one
     incoming edge (a recursive function), (false, m) otherwise, and (false, 0)
     if the walk fails (e.g. no label node in the graph). *)
  fun is_rec gr n =
    let
      val (preds, _, nd, succs) = G.context (n, gr)
      fun is_label (Assem.LABEL _) = true
        | is_label _ = false
    in
      if is_label (#instr (nd : CFG.node)) then (length preds > 1, n)
      else is_rec gr (#2 (hd succs))
    end
    handle _ => (false, 0); (* no label node in the graph *)


  (* Given a TR cfg and the node holding the function-name label, break the cfg
     into three parts: the pre-condition part (possibly absent), the basic case
     part and the recursive case part.  The loop condition is also derived.
     Returns (cond, ((p_cfg, p_start), (b_cfg, b_start), (r_cfg, r_start)), join_node)
     where p_cfg and p_start are options. *)
  fun break_rec gr lab_node =
    let
      fun get_sucL n = #4 (G.context (n, gr))
      fun get_preL n = #1 (G.context (n, gr))
      (* target of the first edge in edgeL carrying label lab *)
      fun edge_to lab edgeL = #2 (valOf (List.find (fn (l, _) => l = lab) edgeL))
      fun instr_of n = #instr ((#3 (G.context (n, gr))) : CFG.node)

      val last_node = valOf (locate gr 0 (fn n => null (get_sucL n)))
      (* walk back along first predecessors to a node with several predecessors *)
      fun find_join_node n =
        if length (get_preL n) > 1 then n
        else find_join_node (#2 (hd (get_preL n)))
      (* the node where the basic and recursive parts join *)
      val join_node = find_join_node last_node

      (* the nodes jumping to the join node *)
      val BAL_node = edge_to 1 (get_preL join_node)
      val b_end_node = edge_to 0 (get_preL join_node)
      val b_start_node = edge_to 2 (get_sucL BAL_node)
      val b_cfg = sub_graph gr b_start_node (fn n => n = b_end_node)

      val cj_node = edge_to 1 (get_preL b_start_node)
      val cmp_node = #2 (hd (get_preL cj_node))
      val cond = to_cond (instr_of cmp_node, instr_of cj_node)

      val r_start_node = edge_to 0 (get_sucL cj_node)
      val BL_node = edge_to 1 (get_preL lab_node)
      val r_end_node = #2 (hd (get_preL BL_node))
      val r_cfg = sub_graph gr r_start_node (fn n => n = r_end_node)

      val first_node = #2 (hd (get_sucL lab_node))
      val p_start_node = if first_node = cmp_node then NONE else SOME first_node
      (* the pre-condition part may be empty.
         NOTE(review): sub_graph keeps the stop node, so p_cfg also contains
         cmp_node (which has no successors there); confirm that converting
         p_cfg cannot reach it *)
      val p_cfg = case p_start_node of
                    NONE => NONE
                  | SOME k => SOME (sub_graph gr k (fn n => n = cmp_node))
    in
      (cond, ((p_cfg, p_start_node), (b_cfg, b_start_node), (r_cfg, r_start_node)), join_node)
    end

  (* Given a CJ cfg and the node holding the cmp instruction, break the cfg
     into the true case part (jump target, label 1) and the false case part
     (fall-through, label 0).  Returns
     (cond, ((t_cfg, t_start), (f_cfg, f_start)), join_node). *)
  fun break_cj gr cmp_node =
    let
      fun get_sucL n = #4 (G.context (n, gr))
      fun get_preL n = #1 (G.context (n, gr))
      (* target of the first edge in edgeL carrying label lab *)
      fun edge_to lab edgeL = #2 (valOf (List.find (fn (l, _) => l = lab) edgeL))
      fun instr_of n = #instr ((#3 (G.context (n, gr))) : CFG.node)

      val cj_node = #2 (hd (get_sucL cmp_node))
      val sucL = get_sucL cj_node
      val (t_start_node, f_start_node) = (edge_to 1 sucL, edge_to 0 sucL)
      val bal_node = edge_to 2 (get_preL t_start_node)
      val f_end_node = #2 (hd (get_preL bal_node))

      val join_node = edge_to 1 (get_sucL bal_node)
      val t_end_node = edge_to 0 (get_preL join_node)

      val cond = to_cond (instr_of cmp_node, instr_of cj_node)

      val (t_cfg, f_cfg) = (sub_graph gr t_start_node (fn n => n = t_end_node),
                            sub_graph gr f_start_node (fn n => n = f_end_node))
    in
      (cond, ((t_cfg, t_start_node), (f_cfg, f_start_node)), join_node)
    end


  (* Convert a condition from CFG form into IR form. *)
  fun convert_cond (exp1, rop, exp2) =
    (to_exp exp1, to_rop rop, to_exp exp2);

  (*---------------------------------------------------------------------------------*)
  (* Convert cfg consisting of SC, CJ and CALL structures to ir tree                 *)
  (*---------------------------------------------------------------------------------*)

  (* |- T, used as the default fspec of every annotation *)
  val thm_t = DECIDE (Term `T`);
  (* default annotation: no inputs/outputs, empty context *)
  val init_info = {fspec = thm_t, ins = NA, outs = NA, context = []};

  (* Convert the straight-line code of cfg starting at node n into an IR tree.
     NOP and LABEL nodes are skipped, BL becomes a CALL node, CMP starts a
     conditional (split with break_cj), any other instruction becomes an STM.
     Sequencing always follows the first successor of a node. *)
  fun convert cfg n =
    let
      val (_, _, {instr = inst', def = _, use = _}, sucL) = G.context (n, cfg)
      (* the conversion of the code after this node *)
      fun rest () = convert cfg (#2 (hd sucL))
    in
      case inst' of
        Assem.OPER {oper = (Assem.NOP, _, _), ...} =>
          if null sucL then STM [] else rest ()

      | Assem.OPER {oper = (Assem.BL, _, _), dst = dList, src = sList, jump = jp} =>
          let val stm = CALL (Symbol.name (hd (valOf jp)), STM [], STM [], STM [],
                              {fspec = thm_t, ins = list2pair (List.map to_exp sList),
                               outs = list2pair (List.map to_exp dList), context = []})
          in if null sucL then stm else SC (stm, rest (), init_info)
          end

      | Assem.LABEL _ =>
          if null sucL then STM [] else rest ()

      | Assem.OPER {oper = (Assem.CMP, _, _), ...} =>
          let val (cnd, ((t_cfg, t_start_node), (f_cfg, f_start_node)), join_node) =
                  break_cj cfg n
          in
            SC (CJ (convert_cond cnd, convert t_cfg t_start_node,
                    convert f_cfg f_start_node, init_info),
                convert cfg join_node, init_info)
          end

      | x =>
          if null sucL then STM [to_inst x]
          else SC (STM [to_inst x], rest (), init_info)
    end

  (*---------------------------------------------------------------------------------*)
  (* Convert a cfg corresponding to a function to ir                                 *)
  (*---------------------------------------------------------------------------------*)

  (* Convert the cfg of a function with inputs ins and outputs outs.  A
     non-recursive function is converted directly.  A recursive one becomes a
     TR node whose body is the recursive case followed by moves passing the
     recursive-call arguments back into the inputs, sequenced with the basic case.
     Raises ERR if the node jumping back to the function label does not hold an
     OPER instruction. *)
  fun convert_module (ins, cfg, outs) =
    let
      val (flag, lab_node) = is_rec cfg 0
      val inL = pair2list ins
    in
      if not flag then convert cfg lab_node
      else
        let
          val (cond, ((p_cfg, p_start_node), (b_cfg, b_start_node), (r_cfg, r_start_node)), _) =
              break_rec cfg lab_node
          val (b_ir, r_ir) = (convert b_cfg b_start_node, convert r_cfg r_start_node)

          (* the recursive call: the node jumping (label 1) to the function label *)
          val bal_node =
              #2 (valOf (List.find (fn (l, _) => l = 1) (#1 (G.context (lab_node, cfg)))))
          val rec_argL =
              (case #instr ((#3 (G.context (bal_node, cfg))) : CFG.node) of
                 Assem.OPER {src = argL, ...} => argL
               | _ => raise ERR "convert_module" "Expecting an instruction for the recursive call")
          (* dst_exp = BL (src_exp): move each recursive argument into its input *)
          val rec_args_pass_ir =
              STM (List.map (fn (dexp, sexp) => {oper = mmov, src = [sexp], dst = [dexp]})
                            (zip inL (pair2list (to_exp (hd rec_argL)))))

          val info = {fspec = thm_t, ins = ins, outs = ins, context = []}
          val r_ir_1 = SC (r_ir, rec_args_pass_ir, info)

          val cond' = convert_cond cond
        in
          case p_cfg of
            NONE =>
              SC (TR (cond', r_ir_1, info), b_ir, info)
          | SOME p_gr =>
              let
                val p_ir = convert p_gr (valOf p_start_node)
                val ir0 = SC (r_ir_1, p_ir, init_info)
                val ir1 = TR (cond', ir0, {fspec = thm_t, ins = ins, outs = outs, context = []})
                val ir2 = SC (ir1, b_ir, {fspec = thm_t, ins = ins, outs = ins, context = []})
              in
                SC (p_ir, ir2, {fspec = thm_t, ins = ins, outs = outs, context = []})
              end
        end
    end;

  (*---------------------------------------------------------------------------------*)
  (* Simplify the IR tree                                                            *)
  (*---------------------------------------------------------------------------------*)

  (* True for instructions that leave their destination unchanged:
     a move of a location onto itself, d := d op 0 for
     add/sub/lsl/lsr/asr/ror, or d := 0 + d / 0 rsb d. *)
  fun is_dummy_inst ({oper = mmov, src = sList, dst = dList}) =
        hd sList = hd dList
    | is_dummy_inst ({oper = op', src = sList, dst = dList}) =
        if length sList < 2 then false
        else if hd dList = hd sList then
          hd (tl sList) = WCONST Arbint.zero andalso
          mem op' [madd, msub, mlsl, mlsr, masr, mror]
        else if hd dList = hd (tl sList) then
          hd sList = WCONST Arbint.zero andalso (op' = madd orelse op' = mrsb)
        else
          false;

  (* Remove dummy instructions (see is_dummy_inst) from every STM and BLK in the tree. *)
  fun rm_dummy_inst (SC (ir1, ir2, info)) =
        SC (rm_dummy_inst ir1, rm_dummy_inst ir2, info)
    | rm_dummy_inst (TR (cnd, ir2, info)) =
        TR (cnd, rm_dummy_inst ir2, info)
    | rm_dummy_inst (CJ (cnd, ir1, ir2, info)) =
        CJ (cnd, rm_dummy_inst ir1, rm_dummy_inst ir2, info)
    | rm_dummy_inst (STM stmL) =
        STM (List.filter (not o is_dummy_inst) stmL)
    | rm_dummy_inst (CALL (fname, pre, body, post, info)) =
        CALL (fname, rm_dummy_inst pre, rm_dummy_inst body, rm_dummy_inst post, info)
    | rm_dummy_inst (BLK (instL, info)) =
        BLK (List.filter (not o is_dummy_inst) instL, info)


  (* Repeatedly merge adjacent instruction lists (STM/BLK) into single BLK
     nodes and drop empty ones, until a pass leaves the tree's shape unchanged. *)
  fun merge_ir ir =
    let
      (* Structural equality of two trees, ignoring annotations; CALL nodes are
         compared by name only.  Instruction lists of different lengths compare
         unequal (Lib.zip, used previously, raised an exception instead). *)
      fun is_ir_equal (SC (s1, s2, _)) (SC (s3, s4, _)) =
            is_ir_equal s1 s3 andalso is_ir_equal s2 s4
        | is_ir_equal (TR (cond1, body1, _)) (TR (cond2, body2, _)) =
            cond1 = cond2 andalso is_ir_equal body1 body2
        | is_ir_equal (CJ (cond1, s1, s2, _)) (CJ (cond2, s3, s4, _)) =
            cond1 = cond2 andalso is_ir_equal s1 s3 andalso is_ir_equal s2 s4
        | is_ir_equal (BLK (l1, _)) (BLK (l2, _)) = l1 = l2
        | is_ir_equal (STM l1) (STM l2) = l1 = l2
        | is_ir_equal (CALL (name1, _, _, _, _)) (CALL (name2, _, _, _, _)) =
            name1 = name2
        | is_ir_equal _ _ = false;

      (* one merging pass *)
      fun merge_stm (SC (STM s1, STM s2, info)) =
            BLK (s1 @ s2, info)
        | merge_stm (SC (STM s1, BLK (s2, _), info)) =
            BLK (s1 @ s2, info)
        | merge_stm (SC (BLK (s1, _), STM s2, info)) =
            BLK (s1 @ s2, info)
        | merge_stm (SC (BLK (s1, _), BLK (s2, _), info)) =
            BLK (s1 @ s2, info)
        | merge_stm (SC (STM [], s2, info)) =
            merge_stm s2
        | merge_stm (SC (s1, STM [], info)) =
            merge_stm s1
        | merge_stm (SC (BLK ([], _), s2, info)) =
            merge_stm s2
        | merge_stm (SC (s1, BLK ([], _), info)) =
            merge_stm s1
        | merge_stm (SC (s1, s2, info)) =
            SC (merge_stm s1, merge_stm s2, info)
        | merge_stm (TR (cnd, body, info)) =
            TR (cnd, merge_stm body, info)
        | merge_stm (CJ (cnd, s1, s2, info)) =
            CJ (cnd, merge_stm s1, merge_stm s2, info)
        | merge_stm (STM s) =
            BLK (s, init_info)
        | merge_stm stm =
            stm

      fun merge ir =
        let val ir' = merge_stm ir
            val ir'' = merge_stm ir'
        in
          if is_ir_equal ir' ir'' then ir'' else merge ir''
        end
    in
      merge ir
    end


  (*---------------------------------------------------------------------------------*)
  (* Assistant functions                                                             *)
  (*---------------------------------------------------------------------------------*)

  (* The annotation of a node; an STM node carries none, so init_info is returned. *)
  fun get_annt (BLK (_, info)) = info
    | get_annt (SC (_, _, info)) = info
    | get_annt (CJ (_, _, _, info)) = info
    | get_annt (TR (_, _, info)) = info
    | get_annt (CALL (_, _, _, _, info)) = info
    | get_annt (STM _) = init_info;

  (* Copies of an annotation with one field replaced. *)
  fun replace_ins {ins = ins', outs = outs', context = context', fspec = fspec'} ins'' =
    {ins = ins'', outs = outs', context = context', fspec = fspec'};

  fun replace_outs {ins = ins', outs = outs', context = context', fspec = fspec'} outs'' =
    {ins = ins', outs = outs'', context = context', fspec = fspec'};

  fun replace_context {ins = ins', outs = outs', context = context', fspec = fspec'} context'' =
    {ins = ins', outs = outs', context = context'', fspec = fspec'};

  fun replace_fspec {ins = ins', outs = outs', context = context', fspec = fspec'} fspec'' =
    {ins = ins', outs = outs', context = context', fspec = fspec''};


  (* Apply f to the annotation of the root node (STM nodes are returned unchanged). *)
  fun apply_to_info (BLK (stmL, info)) f = BLK (stmL, f info)
    | apply_to_info (SC (s1, s2, info)) f = SC (s1, s2, f info)
    | apply_to_info (CJ (cnd, s1, s2, info)) f = CJ (cnd, s1, s2, f info)
    | apply_to_info (TR (cnd, s, info)) f = TR (cnd, s, f info)
    | apply_to_info (CALL (fname, pre, body, post, info)) f = CALL (fname, pre, body, post, f info)
    | apply_to_info (STM stmL) f = STM stmL;


  (*---------------------------------------------------------------------------------*)
  (* Calculate Modified Registers                                                    *)
  (*---------------------------------------------------------------------------------*)

  (* Numbers of the registers written by one instruction (its REG destinations). *)
  fun one_stm_modified_regs ({oper = _, dst = dlist, src = _}) =
    List.mapPartial (fn (REG r) => SOME r | _ => NONE) dlist;


  (* Register numbers (possibly repeated) written anywhere in the tree.  For a
     CALL, the registers of its pre and body parts and its REG outputs are
     counted; non-REG outputs map to ~1, and ~1 and 13 are filtered out
     (presumably r13 is the ARM stack pointer, restored by the call - confirm). *)
  fun get_modified_regs (SC (ir1, ir2, _)) =
        get_modified_regs ir1 @ get_modified_regs ir2
    | get_modified_regs (TR (_, ir, _)) =
        get_modified_regs ir
    | get_modified_regs (CJ (_, ir1, ir2, _)) =
        get_modified_regs ir1 @ get_modified_regs ir2
    | get_modified_regs (CALL (_, pre, body, _, info)) =
        let
          val preL = get_modified_regs pre
          val bodyL = get_modified_regs body
          val outsL = List.map (fn (REG r) => r | _ => ~1) (pair2list (#outs info))
          val restoredL = [13, ~1]
        in
          filter (fn e => not (mem e restoredL)) (preL @ bodyL @ outsL)
        end
    | get_modified_regs (STM l) =
        List.concat (List.map one_stm_modified_regs l)
    | get_modified_regs (BLK (l, _)) =
        List.concat (List.map one_stm_modified_regs l);

  (*---------------------------------------------------------------------------------*)
  (* Set input, output and context information                                       *)
  (* Alignment functions                                                             *)
  (*---------------------------------------------------------------------------------*)

  (* Make the inputs of ir equal to the inputs of outer_info, pushing the new
     inputs into the first child of an SC, both branches of a CJ and the body
     of a TR.  Raises Fail for an STM whose (absent) inputs differ. *)
  fun adjust_ins ir (outer_info as ({ins = outer_ins, ...} : annt)) =
    if #ins (get_annt ir) = outer_ins then ir
    else
      case ir of
        BLK (stmL, info) =>
          BLK (stmL, replace_ins info outer_ins)
      | SC (s1, s2, info) =>
          SC (adjust_ins s1 outer_info, s2, replace_ins info outer_ins)
      | CJ (cnd, s1, s2, info) =>
          CJ (cnd, adjust_ins s1 outer_info, adjust_ins s2 outer_info, replace_ins info outer_ins)
      | TR (cnd, s, info) =>
          (* NOTE(review): the body is adjusted to the TR's own (old) info,
             unlike the other cases which use outer_info - confirm *)
          TR (cnd, adjust_ins s info, replace_ins info outer_ins)
      | CALL (fname, pre, body, post, info) =>
          CALL (fname, pre, body, post, replace_ins info outer_ins)
      | STM _ =>
          raise Fail "adjust_ins: invalid IR tree"


  (* The elements of ins1 not in outs0, as a pair expression. *)
  fun args_diff (ins1, outs0) =
    set2pair (S.difference (pair2set ins1, pair2set outs0));

  (* Given the outputs and the context of an ir (in outer_info), calculate its
     inputs, working backwards through the tree. *)
  fun back_trace (BLK (stmL, inner_info)) ({outs = outer_outs, context = contextL, ...} : annt) =
        let
          val (inner_inL, _, inner_outL) = getIO stmL
          (* outputs needed afterwards but not written here must be read in *)
          val gapL = set2list (S.difference (pair2set outer_outs, list2set inner_outL))
          val read_inS = list2set (gapL @ inner_inL)
        in
          BLK (stmL, {outs = outer_outs, ins = set2pair read_inS, context = contextL,
                      fspec = #fspec inner_info})
        end

    | back_trace (STM stmL) info =
        back_trace (BLK (stmL, init_info)) info

    | back_trace (SC (s1, s2, _)) (outer_info as {context = contextL, ...}) =
        let
          val s2' = back_trace s2 outer_info
          val (s1_info, s2_info) = (get_annt s1, get_annt s2')
          val s2'' = if S.isSubset (pair2set (#ins s2_info), pair2set (#outs s1_info))
                     then adjust_ins s2' (replace_ins s2_info (#outs s1_info))
                     else s2'
          val s2_info = get_annt s2''
          val s1' = back_trace s1 {ins = #ins outer_info, outs = #ins s2_info,
                                   context = #context s2_info, fspec = #fspec s1_info}
          val s1_info = get_annt s1'
        in
          SC (s1', s2'', {ins = #ins s1_info, outs = #outs s2_info, fspec = thm_t,
                          context = contextL})
        end

    | back_trace (CJ (cnd, s1, s2, _)) outer_info =
        let
          fun filter_exp (REG _) = true
            | filter_exp (MEM _) = true
            | filter_exp _ = false

          val cond_expL = [#1 cnd, #3 cnd]
          val s1' = back_trace s1 outer_info
          val s2' = back_trace s2 outer_info
          val ({ins = ins1, ...}, {ins = ins2, ...}) = (get_annt s1', get_annt s2')
          (* union of the REG/MEM variables in the condition and the inputs of s1' and s2' *)
          val inS_0 = set2pair (list2set (filter filter_exp
                                 (cond_expL @ pair2list ins1 @ pair2list ins2)))
          val info_0 = replace_ins outer_info inS_0
        in
          CJ (cnd, adjust_ins s1' info_0, adjust_ins s2' info_0, info_0)
        end

    | back_trace (TR (cnd, body, info)) {outs = outer_outs, context = contextL, ...} =
        let
          val extra_outs = args_diff (PAIR (outer_outs, list2pair contextL), #outs info)
          val info' = replace_context info (pair2list extra_outs)
          val body' = adjust_ins body info'
        in
          TR (cnd, body', info')
        end

    (* the adjustment will be performed later in the funCall.sml *)
    | back_trace (CALL (fname, pre, _, post, info)) (outer_info as {outs = outer_outs, ...}) =
        let
          val extra_outs_set = S.difference (pair2set outer_outs, pair2set (#outs info))
          val new_ins = set2pair (S.union (extra_outs_set, pair2set (#ins info)))
          val info' = replace_ins outer_info new_ins
        in
          CALL (fname, pre, BLK ([], info), post, info')
        end;

  (* Annotate ir with the function's inputs ins and outputs outs: compute the
     node inputs with back_trace, then add empty BLK nodes where needed so the
     root's outputs equal outs and its inputs equal ins. *)
  fun set_info (ins, ir, outs) =
    let
      fun extract_outputs ir0 =
        let
          val info0 = get_annt ir0
          val (inner_outS, outer_outS) = (IRSyntax.pair2set (#outs info0), IRSyntax.pair2set outs)
          val ir1 = if Binaryset.equal (inner_outS, outer_outS) then ir0
                    else SC (ir0, BLK ([], {context = #context info0, ins = #outs info0,
                                            outs = outs, fspec = thm_t}),
                             replace_outs info0 outs)
          val info1 = get_annt ir1
        in
          if #ins info0 = ins then ir1
          else SC (BLK ([], {context = #context info1, ins = ins, outs = #ins info0,
                             fspec = thm_t}),
                   ir1, replace_ins info1 ins)
        end
    in
      extract_outputs (back_trace ir {ins = ins, outs = outs, context = [], fspec = thm_t})
    end;

  (*---------------------------------------------------------------------------------*)
  (* Adjust the IR tree to make inputs and outputs consistent                        *)
  (*---------------------------------------------------------------------------------*)

  (* Insert empty BLK nodes so that, within every SC, the outputs of the first
     part equal the inputs of the second.  A CALL gets such nodes around it
     when its pre-part inputs or post-part outputs differ from its own. *)
  fun match_ins_outs (SC (s1, s2, info)) =
        let
          val (ir1, ir2) = (match_ins_outs s1, match_ins_outs s2)
          val (info1, info2) = (get_annt ir1, get_annt ir2)
          val (outs1, ins2) = (#outs info1, #ins info2)
        in
          if outs1 = ins2 then
            SC (ir1, ir2, info)
          else
            SC (ir1,
                SC (BLK ([], {ins = outs1, outs = ins2, context = #context info2, fspec = thm_t}),
                    ir2, replace_ins info2 outs1),
                info)
        end

    | match_ins_outs (ir as (CALL (_, pre, _, post, info))) =
        let
          val (pre_ins, post_outs) = (#ins (get_annt pre), #outs (get_annt post))
          val (outer_ins, outer_outs) = (#ins info, #outs info)
          val ir' = if post_outs = outer_outs then ir
                    else SC (ir,
                             BLK ([], {ins = post_outs, outs = outer_outs,
                                       context = #context info, fspec = thm_t}),
                             replace_ins info pre_ins)
        in
          if pre_ins = outer_ins then ir'
          else SC (BLK ([], {ins = outer_ins, outs = pre_ins, context = #context info,
                             fspec = thm_t}),
                   ir', info)
        end

    | match_ins_outs ir =
        ir;


  (*---------------------------------------------------------------------------------*)
  (* Interface                                                                       *)
  (*---------------------------------------------------------------------------------*)

  (* Build the annotated IR of a register-allocated function: convert its cfg,
     remove dummy instructions, merge instruction lists and set the
     input/output annotations.  Returns (ins, ir, outs). *)
  fun build_ir (f_name, f_type, f_args, f_gr, f_outs, f_rs) =
    let
      val (ins, outs) = (IRSyntax.to_exp f_args, IRSyntax.to_exp f_outs)
      val ir0 = convert_module (ins, f_gr, outs)
      val ir1 = (merge_ir o rm_dummy_inst) ir0
      val ir2 = set_info (ins, ir1, outs)
    in
      (ins, ir2, outs)
    end

  (* Translate prog: convert it with ANF.toANF, allocate registers for the first
     definition's ANF with regAllocation.convert_to_ARM and build its IR.
     Returns (f_name, f_type, (ins, ir, outs), defs) where defs are the source
     definitions of all entries. *)
  fun sfl2ir prog =
    let
      val env = ANF.toANF [] prog
      val defs = List.map (fn (name, (flag, src, anf, cps)) => src) env
      val (_, (_, _, f_anf, _)) = hd env
      val setting as (f_name, f_type, _, _, _, _) = regAllocation.convert_to_ARM f_anf
      val f_ir = build_ir setting
    in
      (f_name, f_type, f_ir, defs)
    end
end
end