(* Graph-colouring register allocator with move coalescing and spilling.       *)
(* The structure follows the worklist formulation of iterated register         *)
(* coalescing: Build, MakeWorklist, Simplify / Coalesce / Freeze / SelectSpill *)
(* until all worklists are empty, AssignColors, and, if something was spilled, *)
(* RewriteProgram followed by another round.                                   *)
(* All allocator state lives in the mutable references declared below.         *)

structure regAllocation =
struct

local open HolKernel Parse boolLib pairLib bossLib

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Definition of types                                                                                                  *)
(* ---------------------------------------------------------------------------------------------------------------------*)

structure T = IntMapTable(type key = int  fun getInt n = n);
structure S = Binaryset;
structure G = Graph;

in

type edgeLab = int;

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Machine model                                                                                                        *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* NumRegs is the number of registers (the number of available colours) *)
val NumRegs = ref (11);

(* The following flag controls whether we spill only one node at a time. *)
(* If set to false, then all potential spill nodes are spilled at once.  *)
val spillOneOnce = ref (true);

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Inputs of the whole program                                                                                          *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* Total order on integers, used as the comparison function of every int set. *)
fun intOrder (s1:int,s2:int) = Int.compare (s1,s2);

val precolored : (int S.set) ref = ref (S.empty intOrder);
val firstnArgL : (int list) ref = ref [];

(* Maps a temporary's number to its printed name. *)
val tmpTable : (string T.table) ref = ref (T.empty);

(* Allocate a fresh temporary whose printed name is not already an item of
   tmpTable, record it in tmpTable and return its number. *)
fun newTmp () =
  let
    val tmps = T.listItems (!tmpTable);

    (* skip candidate numbers whose name is already taken *)
    fun getNewVarNo n =
      if List.exists (fn x => x = Temp.makestring n) tmps
      then getNewVarNo (n+1)
      else n;

    val newVarNo = getNewVarNo (Temp.newtemp())
  in
    ( tmpTable := T.enter(!tmpTable, newVarNo, Temp.makestring newVarNo);
      newVarNo
    )
  end;

(* The control-flow graph being allocated: one node per instruction, labelled with
   the instruction and its def/use temporaries. *)
val cfg : (({def : int list, instr : Assem.instr, use : int list}, edgeLab) G.graph) ref = ref (G.empty);

val memIndex = ref (~1);   (* index of the most recently assigned spill slot; ~1 means none assigned yet *)

(* ---------------------------------------------------------------------------------------------------------------------*)
(* for debugging                                                                                                        *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* Convert a list of sets into a list of element lists. *)
fun els (sets : (int Binaryset.set) list) = List.map S.listItems sets;

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Data Structures                                                                                                      *)
(* ---------------------------------------------------------------------------------------------------------------------*)
(*
   Node work-lists, sets, and stacks:
        precolored: machine registers, preassigned color
        initial: temporary registers, not precolored and not yet processed
        simplifyWorklist: list of low-degree non-move-related nodes
        freezeWorklist: low-degree move-related nodes
        spillWorklist: high-degree nodes
        spilledNodes: nodes marked for spilling during this round; initially empty
        coalescedNodes: registers that have been coalesced; when u <- v is coalesced, v
                        is added to this set and u put back on some work-list (or vice versa)
        coloredNodes: nodes successfully colored
        selectStack: stack containing temporaries removed from the graph

   Move sets:
        coalescedMoves: moves that have been coalesced
        constrainedMoves: moves whose source and target interfere
        frozenMoves: moves that will no longer be considered for coalescing
        worklistMoves: moves enabled for possible coalescing
        activeMoves: moves not yet ready for coalescing

   Other data Structures:
        adjSet: the set of interference edges (u,v) in the graph. If (u,v) in adjSet then (v,u) in adjSet
        adjList: adjacency list representation of the graph; for each non-precolored temporary u,
                 adjList[u] is the set of nodes that interfere with u
        degree: an array containing the current degree of each node
                (precolored nodes are never incremented by AddEdge, so theirs stays <= 0)
        moveList: a mapping from a node to the list of moves it is associated with
        alias: when a move (u,v) has been coalesced, and v put in coalescedNodes, then alias(v) = u.
        color: the color chosen by the algorithm for a node; for precolored nodes this is
               initialized to the given color.
*)

(* Lexicographic order on pairs of integers (comparison function of adjSet). *)
fun intTupleOrder ((u1,v1):int*int,(u2,v2):int*int) =
  case Int.compare (u1,u2) of
      EQUAL => Int.compare (v1,v2)
    | ord   => ord;

(* The list [m, m+1, ..., n]; empty when m > n. *)
fun buildNumList m n =
  if m > n then []
  else m :: buildNumList (m+1) n;

val initial : (int S.set) ref = ref (S.empty intOrder);

val simplifyWorklist : (int S.set) ref = ref (S.empty intOrder);
val freezeWorklist : (int S.set) ref = ref (S.empty intOrder);
val spillWorklist : (int S.set) ref = ref (S.empty intOrder);
val spilledNodes : (int S.set) ref = ref (S.empty intOrder);
val toBeSpilled : (int S.set) ref = ref (S.empty intOrder);
val coalescedNodes : (int S.set) ref = ref (S.empty intOrder);
val coloredNodes : (int S.set) ref = ref (S.empty intOrder);
val selectStack : (int Stack.stack) ref = ref (Stack.empty ());

val coalescedMoves : (int S.set) ref = ref (S.empty intOrder);
val constrainedMoves: (int S.set) ref = ref (S.empty intOrder);
val frozenMoves: (int S.set) ref = ref (S.empty intOrder);
val worklistMoves: (int S.set) ref = ref (S.empty intOrder);
val activeMoves: (int S.set) ref = ref (S.empty intOrder);

val adjSet : ((int * int) S.set) ref = ref (S.empty intTupleOrder);
val adjList : ((int S.set) T.table) ref = ref (T.empty);
val degree : (int T.table) ref = ref (T.empty);
val moveList : (int S.set T.table) ref = ref (T.empty);

val alias : (int T.table) ref = ref (T.empty);
val color : (int T.table) ref = ref (T.empty);

(* every temporary spilled so far, accumulated over all rounds *)
val spilledTmps = ref (S.empty intOrder);

(* NOTE(review): not referenced anywhere in this structure. *)
val MAX_DEGREE = 65535;

(* maps a spilled temporary to the index of its memory slot *)
val memT : (int T.table) ref = ref (T.empty);

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Auxiliary functions                                                                                                  *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* True iff p holds for every element of the list (short-circuits on the first failure). *)
fun forall p l = List.all p l;

(* Pair up two lists element-wise.  Lists of different lengths raise Match,
   because only the equal-length cases are covered. *)
fun zip ([], []) = []
  | zip (x::xs, y::ys) = (x,y) :: zip(xs,ys);


(* The label (instruction record) of node n in the global cfg. *)
fun getNodeLab (n:int) =
  #3 (G.context(n, !cfg));

(* Add an edge n0 -> n1 labelled lab, unless it is a self-edge or n0 is already
   listed among the predecessors of n1. *)
fun mk_edge ((n0:G.node,n1:G.node,lab:edgeLab), gr:('a,edgeLab) G.graph) =
  if n0 = n1 then gr
  else if List.exists (fn m => #2 m = n0) (#1(G.context(n1,gr))) then gr
  else
    let val ((p1,m1,l1,s1), del_n1) = G.match(n1,gr)
    in G.embed(((lab,n0)::p1,m1,l1,s1),del_n1)
    end;

(* Remove the edge n0 -> n1.  A self-edge request returns the graph unchanged;
   a missing edge raises G.Edge. *)
fun rm_edge ((n0:G.node,n1:G.node), gr:('a,edgeLab) G.graph) =
  if n0 = n1 then gr
  else if not (List.exists (fn m => #2 m = n0) (#1(G.context(n1,gr)))) then raise G.Edge
  else
    let val ((p1,m1,l1,s1), del_n1) = G.match(n1,gr)
    in G.embed((List.filter (fn el => #2 el <> n0) (#1(G.context(n1,gr))),
                m1,l1,s1),del_n1)
    end;

(* Print the linearized program, one formatted instruction per line.
   tmpT is accepted but not used.  Returns a list of units. *)
fun PrintProgram (cfg : ({instr : Assem.instr, use : int list, def : int list}, edgeLab) G.graph, tmpT) =
  List.map (fn inst => print ((Assem.formatInst inst) ^ "\n"))
           (CFG.linearizeCFG cfg);

(* True only for a MOVE whose source and destination are both temporaries. *)
fun isMoveInst (Assem.MOVE {dst = Assem.TEMP d1, src = Assem.TEMP s1}) = true
  | isMoveInst _ = false;


(* ---------------------------------------------------------------------------------------------------------------------*)
(* Calculate liveness of a program                                                                                      *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* Build two tables indexed by node number: the set of temporaries used by the
   node and the set of temporaries it defines. *)
fun computeUseDef (gr : ({instr:Assem.instr, use:int list, def:int list}, edgeLab) Graph.graph) =
  G.ufold (fn ((predL,nodeNo,inst,sucL),
               (use:Temp.temp Binaryset.set T.table,def:Temp.temp Binaryset.set T.table)) =>
             ( let val (dst,src) = (#def inst, #use inst) in
                 ( T.enter(use, nodeNo, S.addList(S.empty intOrder, src)),
                   T.enter(def, nodeNo, S.addList(S.empty intOrder, dst)))
               end))
          (T.empty,T.empty) gr;

(* Iterate the liveness equations
       in[n]  = use[n] U (out[n] \ def[n])
       out[n] = U { in[s] | s is a (non-spurious) successor of n }
   to a fixed point.  Returns (inS, outS) as lists.
   NOTE(review): the lists are built in the order of G.nodes cfg but indexed by
   node number via List.nth; this presumes G.nodes yields 0,1,...,k-1 in order -
   confirm against Graph. *)
fun CalLiveness (cfg)
=
  let
    val inS = List.map (fn elm => (S.empty intOrder)) (G.nodes cfg);
    val outS = inS;
    val (useT,defT) = computeUseDef cfg;

    fun unionAll [] = S.empty intOrder
      | unionAll (h::tl) = S.union(h, unionAll tl);

    (* discard spurious edges (i.e. those whose labels equal to 2 ) *)
    fun getSuc n =
      let val sucEdges = #4 (G.context(n,cfg))
          val realEdges = List.filter (fn edge => not (#1 edge = 2)) sucEdges
      in
        List.map (fn edge => #2 edge) realEdges
      end

    (* out sets of nodes 0..n-1, computed from the given in sets *)
    fun compOut n inS =
      List.tabulate
        (n, fn i => unionAll (List.map (fn index => (List.nth(inS,index))) (getSuc i)));

    fun round (inS, outS)=
      let
        val old_inS = inS;
        val old_outS = outS;

        val inS = List.map
                    (fn n => S.union(T.look(useT,n), S.difference(List.nth(outS,n),T.look(defT,n))))
                    (G.nodes cfg);

        val outS = compOut (length inS) inS
      in
        if forall (fn (x1,x2) => S.equal(x1,x2)) (zip(inS,old_inS)) andalso
           forall (fn (x1,x2) => S.equal(x1,x2)) (zip(outS,old_outS))
        then (inS, outS)
        else round(inS,outS)
      end
  in
    round (inS,outS)
  end;


(* -------------------------------------------------------------------------------------------------------------------*)
(* Initialize data structures, then compute the values of these structures                                           *)
(* -------------------------------------------------------------------------------------------------------------------*)

(* When a valid allocation isn't found, the whole process needs to restart.  Procedure init()
   resets the data structures.  The nodes spilled in the previous round are folded into
   spilledTmps first; the cfg argument is accepted but not used.
   Every per-node table is rebuilt from an empty table, so no entry survives from an
   earlier round or an earlier function. *)

fun init (cfg, tmpTable) =
  let
    val keys = T.listKeys tmpTable
    (* a table mapping every known temporary to the same initial value v *)
    fun tableOf v = List.foldl (fn (n,tbl) => T.enter(tbl,n,v)) T.empty keys
  in
  (
    spilledTmps := S.union (!spilledNodes, !spilledTmps);

    initial := S.difference (S.difference(S.addList(S.empty intOrder, keys),
                                          !precolored), !spilledTmps);

    simplifyWorklist := S.empty intOrder;
    freezeWorklist := S.empty intOrder;
    spillWorklist := S.empty intOrder;
    spilledNodes := S.empty intOrder;
    coalescedNodes := S.empty intOrder;
    coloredNodes := S.empty intOrder;
    selectStack := Stack.empty ();

    coalescedMoves := S.empty intOrder;
    constrainedMoves := S.empty intOrder;
    frozenMoves := S.empty intOrder;
    worklistMoves := S.empty intOrder;
    activeMoves := S.empty intOrder;

    adjSet := S.empty intTupleOrder;
    adjList := tableOf (S.empty intOrder);
    degree := tableOf 0;
    moveList := !adjList;

    alias := tableOf ~1;
    color := !alias
  )
  end;


(* Record an interference edge between u and v (both directions in adjSet).
   Adjacency lists and degrees are only maintained for non-precolored nodes. *)
fun AddEdge (u:G.node, v:G.node) =
  if not (S.member(!adjSet,(u,v))) andalso (u <> v) then
    (adjSet := S.add(S.add(!adjSet,(u,v)),(v,u));
     if not (S.member(!precolored,u)) then
       (adjList := T.enter (!adjList, u, S.add (T.look(!adjList,u),v));
        degree := T.enter (!degree, u, T.look(!degree,u)+1))
     else ();
     if not (S.member(!precolored,v)) then
       (adjList := T.enter (!adjList, v, S.add (T.look(!adjList,v),u));
        degree := T.enter (!degree, v, T.look(!degree,v)+1))
     else ())
  else ();

(* Construct the interference graph using the results of static liveness analysis,
   and initialize worklistMoves to contain all the moves.
   For a move instruction the source is removed from the live set before edges are
   added, so a move does not by itself make its source and destination interfere.
   Returns, for each node, use U (live \ def). *)

fun Build (cfg) =
  let

    val (inS,outS) = CalLiveness cfg;

    fun investigate_one_node nodeNo =
      let
        val nd = #3 (G.context(nodeNo, cfg));
        val def = S.addList(S.empty intOrder, #def nd);
        val use = S.addList(S.empty intOrder, #use nd);
        val inst = #instr(nd);
        val liveOut = List.nth(outS,nodeNo);
        val live =
          if isMoveInst inst then
            ( moveList := List.foldl
                            (fn (varNo,lst) => T.enter(lst,varNo,S.add(T.look(lst,varNo),nodeNo)))
                            (!moveList) (S.listItems(S.union(def,use)));
              worklistMoves := S.add (!worklistMoves,nodeNo);
              S.difference(liveOut, use)
            )
          else liveOut;
        val live = S.union(live, def);
        val _ = List.app
                  (fn d => List.app (fn l => AddEdge(l,d)) (S.listItems live))
                  (S.listItems def)
      in
        S.union(use, S.difference(live, def))
      end

  in
    List.map investigate_one_node (G.nodes cfg)
  end;


(* The moves involving n that are still candidates for coalescing. *)
fun NodeMoves n =
  S.intersection(T.look(!moveList,n),
                 S.union(!activeMoves, !worklistMoves));

(* Distribute every node of `initial` to the spill, freeze or simplify worklist,
   then remember the initial spill candidates in toBeSpilled. *)
fun MakeWorklist() =
  let
    fun round n =
      (initial := S.delete(!initial,n);
       if T.look(!degree,n) >= (!NumRegs) then
         spillWorklist := S.add(!spillWorklist,n)
       else if not (S.isEmpty(NodeMoves n)) then
         freezeWorklist := S.add(!freezeWorklist,n)
       else
         simplifyWorklist := S.add(!simplifyWorklist,n))

    val _ = List.app round (S.listItems (!initial))
  in
    toBeSpilled := !spillWorklist
  end;

(* -------------------------------------------------------------------------------------------------------------------*)
(* Merge moves, push simplified nodes into the stack                                                                  *)
(* -------------------------------------------------------------------------------------------------------------------*)

(* Neighbours of n that are neither on the select stack nor coalesced away. *)
fun Adjacent n =
  S.difference (T.look(!adjList, n),
                S.union(Stack.stack2set (!selectStack, S.empty intOrder), !coalescedNodes));

(* When the degree of a neighbor transitions from K to K - 1, moves associated with its neighbors may be enabled *)

fun EnableMoves nodes =
  let fun round n =
        List.app
          (fn m =>
             if S.member(!activeMoves,m) then
               (activeMoves := S.delete (!activeMoves, m);
                worklistMoves := S.add (!worklistMoves,m))
             else ())
          (S.listItems (NodeMoves n))
  in
    List.app round (S.listItems nodes)
  end;


(* Lower the degree of m by one; when it drops from K to K-1 the node leaves the
   spill worklist and its (and its neighbours') moves are re-enabled. *)
fun DecrementDegree m =
  let val d = T.look(!degree,m) in
    (degree := T.enter(!degree, m, d-1);
     if (d = !NumRegs) then
       (EnableMoves (S.add(Adjacent m, m));
        (* S.difference rather than S.delete: m need not be in spillWorklist *)
        spillWorklist := S.difference(!spillWorklist,S.add(S.empty intOrder,m));
        if not (S.isEmpty(NodeMoves m)) then
          freezeWorklist := S.add (!freezeWorklist,m)
        else
          simplifyWorklist := S.add (!simplifyWorklist,m))
     else ()
    )
  end;

(* Add a low degree node into the select stack *)

fun Simplify() =
  let val n = hd(S.listItems (!simplifyWorklist))
  in (
    simplifyWorklist := S.delete (!simplifyWorklist,n);
    selectStack := Stack.push(n, !selectStack);
    List.app DecrementDegree
             (S.listItems (S.difference(Adjacent n, !precolored)))
  )
  end;

(* Coalesce moves *)

(* Move u from the freeze worklist to the simplify worklist once it is no longer
   move-related and has low degree. *)
fun AddWorklist u =
  if not (S.member (!precolored,u)) andalso S.isEmpty(NodeMoves u) andalso
     T.look(!degree,u) < (!NumRegs) then
    (freezeWorklist := S.delete (!freezeWorklist,u);
     simplifyWorklist := S.add (!simplifyWorklist,u))
  else ();

(* Coalescing test for a precolored r: neighbour t is harmless if it has low degree,
   is precolored, or already interferes with r. *)
fun OK (t,r) =
  T.look(!degree,t) < (!NumRegs) orelse
  S.member (!precolored,t) orelse
  S.member (!adjSet, (t,r));

(* Conservative coalescing test: fewer than K of the given nodes are significant.
   Precolored nodes are always counted as significant: their degree entry is never
   incremented by AddEdge, so the degree test alone would treat them as trivially
   colourable. *)
fun Conservative nodes =
  let
    fun significant n =
      S.member(!precolored,n) orelse T.look(!degree,n) >= (!NumRegs)
    val k = List.foldl (fn (n,k) => if significant n then k+1 else k) 0 (S.listItems nodes)
  in k < (!NumRegs)
  end;

(* Alternative formulation, kept for reference (not compiled):
fun Conservative(u,v) =
  let
    fun dgr u n = if S.member(Adjacent u, n) then T.look(!degree,n)-1 else T.look(!degree,n);
    val k =
      List.foldl (fn (n,k) => if dgr n >= (!NumRegs) then k+1 else k) 0 (S.listItems (Adjacent v))
  in k < (!NumRegs) end;
*)

(* Follow alias links until a node that has not been coalesced away is reached. *)
fun GetAlias n =
  if S.member (!coalescedNodes, n) then
    GetAlias (T.look(!alias,n))
  else n;

(* Merge v into u: v becomes an alias of u, u inherits v's moves and edges. *)
fun Combine (u,v) =
  ( if S.member(!freezeWorklist,v) then
      freezeWorklist := S.delete (!freezeWorklist,v)
    else
      spillWorklist := S.difference (!spillWorklist,S.add(S.empty intOrder,v));
    coalescedNodes := S.add(!coalescedNodes,v);
    alias := T.enter(!alias, v, GetAlias u);
    moveList := T.enter (!moveList, u, S.union(T.look(!moveList,u), T.look(!moveList,v)));
    EnableMoves (S.add(S.empty intOrder,v));
    List.app (fn t => (AddEdge(t,u); DecrementDegree t))
             (S.listItems (Adjacent v));
    if T.look(!degree,u)>=(!NumRegs) andalso S.member(!freezeWorklist,u) then
      (freezeWorklist := S.delete (!freezeWorklist,u);
       spillWorklist := S.add (!spillWorklist,u))
    else ()
  );

(* Take one move from worklistMoves and either coalesce it, mark it constrained,
   or park it in activeMoves. *)
fun Coalesce () =
  let val m = hd(S.listItems (!worklistMoves));
      val inst = getNodeLab m;
      val (def,use) = (#def inst, #use inst);
      val (x,y) = (GetAlias (hd def), GetAlias (hd use));
      (* if either end is precolored, make sure it is u *)
      val (u,v) = if (S.member(!precolored,y))
                  then (y,x) else (x,y)
  in
    (worklistMoves := S.delete (!worklistMoves,m);
     if u = v then
       ( coalescedMoves := S.add(!coalescedMoves,m);
         AddWorklist u)
     else if S.member(!precolored,v) orelse S.member(!adjSet,(u,v)) then
       ( constrainedMoves := S.add(!constrainedMoves,m);
         AddWorklist u;
         AddWorklist v)
     else if (S.member(!precolored,u) andalso (forall (fn t => OK(t,u)) (S.listItems (Adjacent v))))
             orelse (not (S.member(!precolored,u)) andalso
                     Conservative (S.union(Adjacent u,Adjacent v))) then
       ( coalescedMoves := S.add(!coalescedMoves,m);
         Combine(u,v);
         AddWorklist u)
     else activeMoves := S.add(!activeMoves,m))
  end;

(* Give up coalescing every remaining move of u.  The other end v of each move is
   promoted to the simplify worklist when it becomes non-move-related and has low
   degree.  Precolored nodes are excluded from the promotion: their degree entry is
   never incremented, so without the guard they would pass the degree test, be
   pushed on the select stack and later be assigned a fresh colour. *)
fun FreezeMoves u =
  let fun freezeOneMove m =
        let
          val inst = getNodeLab m;
          val (x,y) = (hd (#def inst), hd (#use inst));
          val v = if GetAlias y = GetAlias u then GetAlias x else GetAlias y
        in
          ( activeMoves := S.difference (!activeMoves, S.add(S.empty intOrder,m));
            frozenMoves := S.add (!frozenMoves, m);
            if not (S.member(!precolored,v)) andalso
               S.isEmpty(NodeMoves v) andalso (T.look(!degree,v) < (!NumRegs)) then
              ( freezeWorklist := S.difference (!freezeWorklist, S.add(S.empty intOrder, v));
                simplifyWorklist := S.add (!simplifyWorklist, v))
            else ()
          ) end
  in List.app freezeOneMove (S.listItems (NodeMoves u))
  end;


(* Pick a node from the freeze worklist, freeze its moves and make it simplifiable. *)
fun Freeze () =
  let val u = hd (S.listItems (!freezeWorklist)) in
    (freezeWorklist := S.delete (!freezeWorklist, u);
     simplifyWorklist := S.add (!simplifyWorklist, u);
     FreezeMoves u)
  end;

(* The node of highest current degree in a non-empty set (the first such node in
   set order on ties).  Raises Empty on an empty set. *)
fun spillPriorities nodes =
  List.foldl (fn (n,max) => if T.look(!degree,n) > T.look(!degree,max) then n else max)
             (hd(S.listItems nodes)) (S.listItems nodes);

(* Choose a potential spill node and treat it as simplifiable. *)
fun SelectSpill () =
  let
    val m = spillPriorities (!spillWorklist)
  in
    ( spillWorklist := S.delete (!spillWorklist,m);
      simplifyWorklist := S.add(!simplifyWorklist,m);
      FreezeMoves m)
  end;

(* -------------------------------------------------------------------------------------------------------------------*)
(* Assign colors to nodes, spill nodes when a valid allocation couldn't be found                                      *)
(* -------------------------------------------------------------------------------------------------------------------*)

(* In an effort to comply with the APCS standard *)
(* The arguments are in r0-r9 and then the stack *)

(* Pop the select stack, giving each node the smallest colour not used by an already
   coloured or precolored neighbour.  When a node has no colour left, spilledNodes
   is set and the stack is discarded; the caller then rewrites the program and
   restarts.  Coalesced nodes take the colour of their alias once the stack is empty. *)
fun AssignColors () =
  let

    (* the set {0, ..., n} *)
    fun initColors ~1 = S.empty intOrder
      | initColors n = S.add(initColors (n-1),n);

    (* the i-th precolored argument gets colour i *)
    fun assign_precolored () =
      List.foldl (fn (n,c) => (color := T.enter(!color,n,c);c+1)) 0 (!firstnArgL);

    (* Nodes to spill when n could not be coloured.  If no initial spill candidate
       was recorded, spill n itself rather than failing on an empty set. *)
    fun spillNodes n =
      if S.isEmpty (!toBeSpilled) then
        S.add(S.empty intOrder, n)
      else if (!spillOneOnce) then
        S.add(S.empty intOrder, spillPriorities (!toBeSpilled))
      else
        (!toBeSpilled)

    fun assign () =
      if Stack.isEmpty (!selectStack) then
        (List.app (fn n => color := T.enter(!color,n,T.look(!color, GetAlias n)))
                  (S.listItems (!coalescedNodes));
         ())
      else
        let val n = Stack.top (!selectStack);
            val _ = (selectStack := Stack.pop(!selectStack));
            val okColors = initColors ((!NumRegs)-1);
            val okColors = List.foldl (fn (w,s) => if S.member(S.union(!coloredNodes,!precolored),GetAlias w) then
                                                     S.difference (s, S.add(S.empty intOrder, T.look(!color, GetAlias w)))
                                                   else s)
                                      okColors
                                      (S.listItems (T.look(!adjList,n)))
        in
          if S.isEmpty(okColors) then
            (spilledNodes := spillNodes n;
             selectStack := Stack.empty ())
          else
            ( coloredNodes := S.add(!coloredNodes, n);
              color := T.enter(!color, n, hd (S.listItems okColors));
              assign()
            )
        end
  in
    ( assign_precolored();
      assign()
    )
  end;

(* -------------------------------------------------------------------------------------------------------------------*)
(* When a node is spilled, we modify the program by replacing the temporary with a memory slot                        *)
(* -------------------------------------------------------------------------------------------------------------------*)

(* Replace the label of node n by inst, keeping its edges. *)
fun updateNode(cfg, n:int, inst) =
  let
    val ((preN,nodeNo,nodeLab,sucN),cfg1) = G.match (n,cfg);
  in
    G.embed((preN,nodeNo,inst,sucN), cfg1)
  end;


(* Insert a new node labelled inst between n and its predecessors.
   The new node is numbered G.noNodes cfg. *)
fun insertBefore(cfg, n:int, inst) =
  let
    val (preN,nodeNo,nodeLab,sucN) = G.context (n,cfg);
    val newCfg = List.foldl (fn (p,gr) => rm_edge((p,n),gr)) cfg (List.map (fn (lab,p) => p) preN);
    val ct = (preN, G.noNodes cfg, inst, [(0,n)])
  in
    G.embed(ct, newCfg)

  end;

(* Insert a new node labelled inst between n and its successors.
   The new node is numbered G.noNodes cfg. *)
fun insertAfter(cfg, n:int, inst) =
  let
    val (preN,nodeNo,nodeLab,sucN) = G.context (n,cfg);
    val newCfg = List.foldl (fn (s,gr) => rm_edge((n,s),gr)) cfg (List.map (fn (lab,s) => s) sucN);
    val ct = ([(0,n)], G.noNodes cfg, inst, sucN)
  in
    G.embed(ct, newCfg)
  end;


(* Rewrite old_cfg so that every temporary in `spilled` lives in its own memory slot:
   each use is preceded by a load into a fresh temporary and each definition is
   followed by a store from a fresh temporary.  BL and NOP instructions get the
   memory operand substituted directly and no load/store is inserted.
   Side effects: advances memIndex, extends memT and (through newTmp) tmpTable. *)
fun updateProgram (old_cfg, spilled : int S.set) =
  let

    fun replace l old new = List.map (fn x => (if (x = old) then new else x)) l;

    (* substitute `new` for `old` in the destination operands when inDst holds and in
       the source operands when inSrc holds *)
    fun substituteVars (Assem.OPER {oper = p, dst = d1, src = s1, jump = j1}) old new inDst inSrc =
          Assem.OPER {oper = p, dst = if inDst then replace d1 old new else d1,
                      src = if inSrc then replace s1 old new else s1, jump = j1}
      | substituteVars (Assem.LABEL x) old new inDst inSrc = Assem.LABEL x
      | substituteVars (Assem.MOVE {dst = d1, src = s1}) old new inDst inSrc =
          Assem.MOVE {dst = if inDst andalso d1 = old then new else d1,
                      src = if inSrc andalso s1 = old then new else s1};

    fun not_bl (Assem.OPER {oper = (Assem.BL, NONE, false), ...}) = false
      | not_bl _ = true
    fun is_nop (Assem.OPER {oper = (Assem.NOP, NONE, false), ...}) = true
      | is_nop _ = false


    (* rewrite node nodeNo so that varNo is replaced by newVarNo (or, for BL/NOP, by
       the current memory slot) on the requested side(s) *)
    fun update_node gr nodeNo (varNo,newVarNo) (for_lhs,for_rhs) =
      let
        val {instr = curInst, def = df, use = us} = #3 (G.context(nodeNo,gr))
      in
        updateNode(gr, nodeNo,
          (if not_bl curInst andalso not (is_nop curInst) then
             { def = if for_lhs andalso List.exists (fn n => n = varNo) df then newVarNo ::
                          (List.filter (fn n => not (n = varNo orelse n = newVarNo)) df) else df,
               use = if for_rhs andalso List.exists (fn n => n = varNo) us then newVarNo ::
                          (List.filter (fn n => not (n = varNo orelse n = newVarNo)) us) else us,
               instr = substituteVars curInst (Assem.TEMP varNo) (Assem.TEMP newVarNo) for_lhs for_rhs
             }
           else
             { def = if for_lhs then (List.filter (fn n => not (n = varNo)) df) else df,
               use = if for_rhs then (List.filter (fn n => not (n = varNo)) us) else us,
               instr = substituteVars curInst (Assem.TEMP varNo) (Assem.TMEM (!memIndex)) for_lhs for_rhs
             }
          ))
      end

    (* load the spilled value into a fresh temporary just before its use *)
    fun insertLoadInst gr nodeNo varNo =
      let val newVarNo = newTmp ();
          val {instr = curInst, def = df, use = us} = #3 (G.context(nodeNo,gr));
          val newInst = {instr = Assem.OPER {oper = (Assem.LDR,NONE,false),
                                             dst = [Assem.TEMP newVarNo],
                                             src = [Assem.TMEM (!memIndex)],
                                             jump = NONE},
                         def = [newVarNo], use = []};
          val gr1 = update_node gr nodeNo (varNo,newVarNo) (false,true)
      in
        if not_bl curInst andalso not (is_nop curInst) then
          insertBefore(gr1, nodeNo, newInst)
        else gr1
      end;

    (* store the freshly defined value to the spill slot just after its definition *)
    fun insertStoreInst gr nodeNo varNo =
      let val newVarNo = newTmp ();
          val {instr = curInst, def = df, use = us} = #3 (G.context(nodeNo,gr));
          val newInst = {instr = Assem.OPER {oper = (Assem.STR,NONE,false),
                                             dst = [Assem.TMEM (!memIndex)],
                                             src = [Assem.TEMP newVarNo],
                                             jump = NONE},
                         def = [], use = [newVarNo]};
          val gr1 = update_node gr nodeNo (varNo,newVarNo) (true,false)
      in
        if not_bl curInst andalso not (is_nop curInst) then
          insertAfter(gr1, nodeNo, newInst)
        else gr1
      end;

    fun process_one_variable gr varNo =
      let
        (* reserve the next memory slot for varNo *)
        val _ = (memIndex:= !memIndex + 1;
                 memT := T.enter(!memT, varNo, !memIndex)
                )
      in
        G.ufold (fn ((predL,nodeNo,inst,sucL),gr) =>
                   let val (def,use) = (#def inst, #use inst);
                       val (def,use) = (S.addList(S.empty intOrder, def), S.addList(S.empty intOrder, use));
                       val gr1 = if S.member(use, varNo) then
                                   (if isMoveInst (#instr inst) then
                                      (* a temp-to-temp move from the spilled variable becomes a plain load *)
                                      let val (Assem.MOVE {dst = d1, src = s1}) = #instr inst in
                                        updateNode(gr,nodeNo,
                                                   {instr = Assem.OPER {oper = (Assem.LDR,NONE,false), dst = [d1],
                                                                        src = [Assem.TMEM (!memIndex)], jump = NONE},
                                                    def = #def inst, use = []})
                                      end
                                    else
                                      insertLoadInst gr nodeNo varNo
                                   )
                                 else gr
                   in
                     if S.member(def, varNo) then
                       insertStoreInst gr1 nodeNo varNo
                     else gr1 end)
                gr gr
      end

  in
    List.foldl (fn (varNo, gr) => process_one_variable gr varNo) old_cfg (S.listItems (spilled))
  end;


(* Rewrite the global cfg for the nodes spilled in this round and reset the allocator state. *)
fun RewriteProgram () =
  ( cfg := updateProgram (!cfg,!spilledNodes);
    init (!cfg,!tmpTable)
  );


(* -------------------------------------------------------------------------------------------------------------------*)
(* Register Allocation                                                                                                *)
(* -------------------------------------------------------------------------------------------------------------------*)

(* Print name, colour and register of every temporary that has not been spilled. *)
fun printAllocation () =
  let
    fun color2reg c =
      if c >= 0 andalso c < (!NumRegs) then "r" ^ Int.toString c
      else "not found";

    (* the argument is not used; every key of tmpTable is visited *)
    fun allocation tmps =
      List.app
        (fn n => if S.member (!spilledTmps, n) then ()
                 else let val c = T.look(!color,n) in
                        print (T.look(!tmpTable, n) ^ "\t" ^ (Int.toString c) ^ "\t" ^ (color2reg c) ^ "\n")
                      end)
        (T.listKeys (!tmpTable))
  in
    ( print "Node\tColor\tRegister\n";
      allocation ((T.numItems (!tmpTable)) - 1))
  end;

(* Main loop: build the interference graph, drain the worklists, colour, and repeat
   after rewriting the program while something had to be spilled. *)
fun AllocateReg () =
  let
    fun allocate () =
      ( if not (S.isEmpty (!simplifyWorklist)) then Simplify()
        else if not (S.isEmpty (!worklistMoves)) then Coalesce()
        else if not (S.isEmpty (!freezeWorklist)) then Freeze()
        else if not (S.isEmpty (!spillWorklist)) then SelectSpill ()
        else ();
        if S.isEmpty (!simplifyWorklist) andalso S.isEmpty (!worklistMoves) andalso
           S.isEmpty (!freezeWorklist) andalso S.isEmpty (!spillWorklist) then
          ()
        else allocate ())

  in
    ( Build(!cfg);
      MakeWorklist();
      allocate();
      AssignColors();
      if not (S.isEmpty (!spilledNodes)) then
        ( RewriteProgram ();
          AllocateReg ())
      else ()
    )
  end;


(* Return a copy of the global cfg in which every TEMP operand is replaced by the
   REG of its colour. *)
fun RewrWithReg () =
  let

    fun subs (Assem.PAIR(a,b)) = Assem.PAIR(subs a, subs b)
      | subs ( Assem.TEMP n) = Assem.REG (T.look(!color, n))
      | subs x = x

    fun replace ll = List.map subs ll;
    fun substituteVars (Assem.OPER {oper = p, dst = d1, src = s1, jump = j1}) =
          Assem.OPER {oper = p, dst = replace d1, src = replace s1, jump = j1}
      | substituteVars (Assem.LABEL x) = Assem.LABEL x
      | substituteVars (Assem.MOVE {dst = d1, src = s1}) =
          Assem.MOVE {dst = hd (replace [d1]), src = hd (replace [s1])};

  in
    G.ufold (fn ((predL,nodeNo,nd,sucL),gr) =>
               let val {use = us, def = df, instr = stm} = nd in
                 updateNode (gr, nodeNo, {use = us, def = df, instr = substituteVars stm})
               end) (!cfg) (!cfg)
  end;


(* ---------------------------------------------------------------------------------------------------------------------*)
(* Interface                                                                                                            *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* Allocate registers for control-flow graph gr.
     tmpT : table of temporaries (number -> name)
     preC : temporaries holding the arguments; at most the first NumRegs are precolored
   Prints the program and the allocation, and returns the register-rewritten graph
   together with the (possibly extended) table of temporaries. *)
fun RegisterAllocation (gr, tmpT, preC) =
  ( tmpTable := tmpT;

    firstnArgL := List.take (preC, Int.min (length preC, !NumRegs));
    precolored := S.addList(S.empty intOrder, !firstnArgL);

    cfg := gr;

    memIndex := ~1;
    memT := T.empty;

    spilledTmps := S.empty intOrder;

    init (gr, tmpT);

    AllocateReg ();

    PrintProgram(!cfg, !tmpTable);
    printAllocation ();
    (RewrWithReg (),!tmpTable)
  );

(* Order on register operands by register number.
   Both arguments must be Assem.REG; anything else raises Bind. *)
fun regOrder (r1,r2) =
  let val (Assem.REG n1, Assem.REG n2) = (r1, r2) in
    if n1 > n2 then GREATER
    else if n1 = n2 then EQUAL
    else LESS
  end

(* The set of registers written by a list of instructions.
   Non-register destinations of an OPER are skipped without losing the registers
   already collected from the same instruction.
   NOTE(review): a MOVE destination is added as-is, so it must be an Assem.REG
   (see regOrder) - confirm that callers only pass register-rewritten code. *)
fun getModifiedRegs stms =
  List.foldl (fn (Assem.OPER {dst = d, ...}, regs) =>
                   (List.foldl (fn (Assem.REG r, acc) => Binaryset.add(acc,Assem.REG r)
                                 | (_, acc) => acc) regs d)
               | (Assem.MOVE {dst = d, ...}, regs) => Binaryset.add(regs,d)
               | (_,regs) => regs
             )
             (Binaryset.empty regOrder) stms

(* Translate a program into its CFG, allocate registers and return
   (name, type, arguments, allocated graph, outputs, modified registers), where the
   arguments and outputs have their temporaries replaced by registers or, for
   spilled temporaries, by memory slots. *)
fun convert_to_ARM prog =
  let
    fun replace_with_regs (Assem.PAIR(v1,v2)) =
          Assem.PAIR (replace_with_regs v1, replace_with_regs v2)
      | replace_with_regs (Assem.TEMP v) =
          ( case T.peek(!memT, v) of
                NONE => Assem.REG (T.look(!color, v))
              | SOME n => Assem.TMEM n
          )
      | replace_with_regs v = v

    val ((fun_name, fun_type, args, gr, outs), t) = CFG.convert_to_CFG prog;
    val argL = List.map Assem.eval_exp (Assem.pair2list args);
    val (gr',t) = RegisterAllocation (gr, t, argL);

    val (args,outs) = (replace_with_regs args, replace_with_regs outs);
    val stms = CFG.linearizeCFG gr';
    val rs = getModifiedRegs stms

  in
    (fun_name, fun_type, args, gr', outs, rs)
  end

(* ---------------------------------------------------------------------------------------------------------------------*)
(* Verification of register allocation on original program                                                              *)
(* ---------------------------------------------------------------------------------------------------------------------*)

(* Apply the substitution `rules` throughout a chain of nested let-expressions:
   to each bound pattern, each bound expression and the final body. *)
fun RewrBody rules body =
  let
    fun stripInst insts =
      if not (is_let insts) then subst rules insts
      else
        let
          val (lt, rhs) = dest_let insts;
          val (lhs, rest) = dest_pabs lt;
          val lhs = subst rules lhs
        in
          mk_let (mk_pabs (lhs, stripInst rest), subst rules rhs)
        end
  in
    stripInst body
  end


(* Prove that renaming the variables of `prog` according to the current allocation
   (registers r<i> for coloured temporaries, m<i> for spilled ones) yields a theorem
   that follows from `prog` itself. *)
fun check_allocation prog =
  let
    fun mk_reg_rules tp =
      List.map (fn n => {redex = mk_var (T.look(!tmpTable,n), tp),
                         residue = mk_var ("r" ^ Int.toString (T.look(!color,n)), tp)})
               (List.filter (fn n => n >= 0 andalso T.peek(!memT,n) = NONE) (T.listKeys (!tmpTable)));

    fun mk_mem_rules tp =
      List.map (fn n => {redex = mk_var (T.look(!tmpTable,n), tp),
                         residue = mk_var ("m" ^ Int.toString(T.look(!memT,n)), tp)})
               (List.filter (fn n => n >= 0) (T.listKeys (!memT)));

    val rw = (RewrBody (mk_mem_rules (Type `:num`))) o (RewrBody (mk_mem_rules (Type `:word32`))) o
             (RewrBody (mk_reg_rules (Type `:num`))) o (RewrBody (mk_reg_rules (Type `:word32`)));

    fun replace exp =
      if is_let exp then
        let val (lt, rhs) = dest_let exp;
            val (lhs, rest) = dest_pabs lt
        in
          mk_let (mk_pabs(replace lhs, replace rest), replace rhs)
        end
      else if is_cond exp then
        let val (c,t,f) = dest_cond exp
        in
          mk_cond (replace c, replace t, replace f)
        end
      else if is_pair exp then
        let val (t1,t2) = dest_pair exp
        in mk_pair (replace t1, replace t2)
        end
      else rw exp

    val (decl,body) =
      dest_eq(concl(SPEC_ALL prog))
      handle HOL_ERR _
        => (print "not an program in function format\n";
            raise ERR "buildCFG" "invalid program format");
    val (f, args) = dest_comb decl;

    val newDecl = mk_comb (f, rw args)
    val newProg = mk_eq(newDecl, replace body)

  in
    GEN_ALL (prove (
      newProg,
      METIS_TAC [prog, LET_THM]
    ))
  end

end (* local open *)
end (* structure *)