1structure regAllocation =
2struct
3
4local open HolKernel Parse boolLib pairLib bossLib
5
6(* ---------------------------------------------------------------------------------------------------------------------*)
7(* Definition of types                                                                                                  *)
8(* ---------------------------------------------------------------------------------------------------------------------*)
9
10structure T = IntMapTable(type key = int fun getInt n = n);
11structure S = Binaryset;
12structure G = Graph;
13
14in
15
(* Label carried by control-flow-graph edges.  CalLiveness ignores edges
   labelled 2 (spurious edges); insertBefore / insertAfter create edges
   labelled 0. *)
type edgeLab = int;
17
18(* ---------------------------------------------------------------------------------------------------------------------*)
19(* Machine model                                                                                                        *)
20(* ---------------------------------------------------------------------------------------------------------------------*)
21
(* Number of machine registers available to the colorer; colors are the
   integers 0 .. !NumRegs - 1. *)
val NumRegs : int ref = ref 11;

(* Spill policy used when coloring fails: if true, only one candidate node is
   spilled per round; if false, every candidate in toBeSpilled is spilled at
   once. *)
val spillOneOnce : bool ref = ref true;
28
29(* ---------------------------------------------------------------------------------------------------------------------*)
30(* Inputs of the whole program                                                                                          *)
31(* ---------------------------------------------------------------------------------------------------------------------*)
32
(* Total order on integers; the comparison used by every int-keyed
   Binaryset in this file. *)
fun intOrder (s1 : int, s2 : int) : order = Int.compare (s1, s2);
37
(* Temporaries bound to a fixed machine register.  Nothing in this section
   fills it in; presumably the driver sets it before allocation (confirm). *)
val precolored = ref (S.empty intOrder) : int S.set ref;

(* Argument temporaries in order; AssignColors gives the i-th one color i. *)
val firstnArgL = ref [] : int list ref;
40
(* Maps every temporary number to its printable name. *)
val tmpTable : (string T.table) ref = ref (T.empty);

(* Allocate a fresh temporary, register it in tmpTable and return its number.

   Candidates start at Temp.newtemp () and are skipped while they are already
   a key of tmpTable, or while their Temp.makestring name is already used by
   some entry.  The previous version only checked names, so a candidate whose
   number was already a key (under a different name) silently overwrote that
   entry. *)
fun newTmp () =
   let
      val usedKeys  = T.listKeys (!tmpTable)
      val usedNames = T.listItems (!tmpTable)

      fun taken n =
          List.exists (fn k => k = n) usedKeys orelse
          List.exists (fn s => s = Temp.makestring n) usedNames

      fun firstFree n = if taken n then firstFree (n + 1) else n

      val fresh = firstFree (Temp.newtemp ())
   in
      tmpTable := T.enter (!tmpTable, fresh, Temp.makestring fresh);
      fresh
   end;
57
(* The program being allocated: one instruction per node, each annotated with
   the temporaries it defines and uses. *)
val cfg = ref G.empty
    : (({def : int list, instr : Assem.instr, use : int list}, edgeLab) G.graph) ref;

(* Memory slot of the variable currently being spilled.  updateProgram
   increments it before handling each spilled variable, so it holds the slot
   most recently handed out (~1 = none yet). *)
val memIndex : int ref = ref ~1;
61
62(* ---------------------------------------------------------------------------------------------------------------------*)
63(* for debugging                                                                                                        *)
64(* ---------------------------------------------------------------------------------------------------------------------*)
65
(* Debugging helper: the element list of every set in a list of integer sets. *)
fun els (sets : int Binaryset.set list) = List.map S.listItems sets;
68
69(* ---------------------------------------------------------------------------------------------------------------------*)
70(* Data Structures                                                                                                      *)
71(* ---------------------------------------------------------------------------------------------------------------------*)
72(*
73   Node work-lists, sets, and stacks:
74        precolored: machine registers, preassigned color
        initial: temporary registers, not precolored and not yet processed
76        simplifyWorklist: list of low-degree non-move-related nodes
77        freezeWorklist: low-degree move-related nodes
78        spillWorklist: high-degree nodes
79        spilledNodes: nodes marked for spilling during this round; initially empty
80        coalescedNodes: registers that have been coalesced; when u <- v is coalesced, v
81                        is added to this set and u put back on some work-list (or vice versa)
82        coloredNodes: nodes successfully colored
83        selectStack: stack containing temporaries removed from the graph
84
85   Move sets:
86        coalescedMoves: moves that have been coalesced
        constrainedMoves: moves whose source and target interfere
88        frozenMoves: moves that will no longer be considered for coalescing
89        worklistMoves: moves enabled for possible coalescing
90        activeMoves: moves not yet ready for coalescing
91
92   Other data Structures:
93        adjSet: the set of interference edges (u,v) in the graph. If (u,v) in adjSet then (v,u) in adjSet
94        adjList: adjacency list representation of the graph; for each non-precolored temporary u,
95                adjList[u] is the set of nodes that interfere with u
96        degree: an array containing the current degree of each node
97        moveList: a mapping from a node to the list of moves it is associated with
98        alias: when a move (u,v) has been coalesced, and v put in coalescedNodes, then alias(v) = u.
99        color: the color chosen by the algorithm for a node; for precolored nodes this is
100                initialized to the given color.
101*)
102
(* Lexicographic order on integer pairs (first component, then second); the
   comparison used by adjSet. *)
fun intTupleOrder ((u1, v1) : int * int, (u2, v2) : int * int) : order =
  case Int.compare (u1, u2) of
      EQUAL => Int.compare (v1, v2)
    | ord   => ord;
109
(* The ascending list [m, m+1, ..., n]; empty when m > n. *)
fun buildNumList m n =
  let
    fun build k acc = if k < m then acc else build (k - 1) (k :: acc)
  in
    build n []
  end;
113
(* ---- node work-lists and sets ---- *)
val initial          = ref (S.empty intOrder) : int S.set ref;

val simplifyWorklist = ref (S.empty intOrder) : int S.set ref;
val freezeWorklist   = ref (S.empty intOrder) : int S.set ref;
val spillWorklist    = ref (S.empty intOrder) : int S.set ref;
val spilledNodes     = ref (S.empty intOrder) : int S.set ref;
(* Snapshot of spillWorklist taken by MakeWorklist; the spill candidates
   AssignColors falls back on when coloring fails. *)
val toBeSpilled      = ref (S.empty intOrder) : int S.set ref;
val coalescedNodes   = ref (S.empty intOrder) : int S.set ref;
val coloredNodes     = ref (S.empty intOrder) : int S.set ref;
val selectStack      = ref (Stack.empty ()) : int Stack.stack ref;

(* ---- move sets; a move is identified by the number of its CFG node ---- *)
val coalescedMoves   = ref (S.empty intOrder) : int S.set ref;
val constrainedMoves = ref (S.empty intOrder) : int S.set ref;
val frozenMoves      = ref (S.empty intOrder) : int S.set ref;
val worklistMoves    = ref (S.empty intOrder) : int S.set ref;
val activeMoves      = ref (S.empty intOrder) : int S.set ref;

(* ---- interference graph and per-node tables ---- *)
val adjSet   = ref (S.empty intTupleOrder) : (int * int) S.set ref;
val adjList  = ref T.empty : int S.set T.table ref;
val degree   = ref T.empty : int T.table ref;
val moveList = ref T.empty : int S.set T.table ref;

(* init sets every entry of both tables to ~1 (no alias / no color yet). *)
val alias = ref T.empty : int T.table ref;
val color = ref T.empty : int T.table ref;

(* Temporaries spilled in earlier rounds (accumulated by init). *)
val spilledTmps = ref (S.empty intOrder);

(* Not referenced in the visible part of this file. *)
val MAX_DEGREE = 65535;

(* Spilled variable -> memory slot index (filled in by updateProgram). *)
val memT = ref T.empty : int T.table ref;
144
145(* ---------------------------------------------------------------------------------------------------------------------*)
146(* Auxiliary functions                                                                                                  *)
147(* ---------------------------------------------------------------------------------------------------------------------*)
148
(* True when predicate p holds for every element of the list (true for []);
   stops at the first failing element. *)
fun forall p xs = List.all p xs;
151
(* Pair two lists element by element.  The lists must have the same length.
   The previous definition had a non-exhaustive match and raised a bare
   Match on unequal lengths; this version raises an explicit Fail. *)
fun zip ([], []) = []
  | zip (x :: xs, y :: ys) = (x, y) :: zip (xs, ys)
  | zip _ = raise Fail "regAllocation.zip: lists of unequal length";
154
155
(* The label (instruction record) of node n in the current program !cfg. *)
fun getNodeLab (n : int) =
   let val (_, _, lab, _) = G.context (n, !cfg) in lab end;
158
(* Add the edge n0 -> n1 labelled lab to gr.  Self-loops are never added, and
   gr is returned unchanged if n1 already has an incoming edge from n0 (with
   any label).  Works by decomposing n1 with G.match and re-embedding it with
   the extra predecessor. *)
fun mk_edge ((n0 : G.node, n1 : G.node, lab : edgeLab), gr : ('a, edgeLab) G.graph) =
    let
      fun alreadyLinked () =
          Option.isSome (List.find (fn (_, src) => src = n0) (#1 (G.context (n1, gr))))
    in
      if n0 = n1 orelse alreadyLinked () then gr
      else
        let val ((preds, node, nodeLab, succs), rest) = G.match (n1, gr)
        in G.embed (((lab, n0) :: preds, node, nodeLab, succs), rest) end
    end;
166
(* Remove every edge n0 -> n1 (whatever its label) from gr.  A self-loop
   request (n0 = n1) returns gr unchanged; raises G.Edge when n1 has no
   incoming edge from n0. *)
fun rm_edge ((n0 : G.node, n1 : G.node), gr : ('a, edgeLab) G.graph) =
    if n0 = n1 then gr
    else
      let
        val incoming = #1 (G.context (n1, gr))
        val fromN0 = fn (_, src) => src = n0
      in
        if not (List.exists fromN0 incoming) then raise G.Edge
        else
          let val ((_, node, nodeLab, succs), rest) = G.match (n1, gr)
          in G.embed ((List.filter (not o fromN0) incoming, node, nodeLab, succs), rest) end
      end;
175
(* Print the program one formatted instruction per line, in the order
   produced by CFG.linearizeCFG.  tmpT is accepted but unused.  Returns the
   list of (unit) print results. *)
fun PrintProgram (cfg : ({instr : Assem.instr, use : int list, def : int list}, edgeLab) G.graph, tmpT) =
    let
      val insts = CFG.linearizeCFG cfg
      fun printInst inst = print (Assem.formatInst inst ^ "\n")
    in
      List.map printInst insts
    end;
179
(* True exactly for a MOVE whose source and destination are both
   temporaries; these are the coalescing candidates. *)
fun isMoveInst inst =
    case inst of
        Assem.MOVE {dst = Assem.TEMP _, src = Assem.TEMP _} => true
      | _ => false
182
183
184(* ---------------------------------------------------------------------------------------------------------------------*)
185(* Calculate liveness of a program                                                                                      *)
186(* ---------------------------------------------------------------------------------------------------------------------*)
187
(* Build two tables indexed by node number: the set of temporaries each
   instruction uses and the set it defines.  Returns (useTable, defTable). *)
fun computeUseDef (gr : ({instr:Assem.instr, use:int list, def:int list}, edgeLab) Graph.graph) =
    let
      fun toSet xs = S.addList (S.empty intOrder, xs)
      fun step ((_, nodeNo, lab : {instr:Assem.instr, use:int list, def:int list}, _),
                (useT : Temp.temp Binaryset.set T.table,
                 defT : Temp.temp Binaryset.set T.table)) =
          (T.enter (useT, nodeNo, toSet (#use lab)),
           T.enter (defT, nodeNo, toSet (#def lab)))
    in
      G.ufold step (T.empty, T.empty) gr
    end;
196
(* Iterative backward liveness analysis over cfg.

   Node numbers are assumed to be 0 .. N-1, with G.nodes cfg listing them in
   that order: both results are lists indexed by node number.  Edges labelled
   2 are spurious and are not followed.  Starting from all-empty sets, the
   function iterates
       in[n]  = use[n] U (out[n] - def[n])
       out[n] = U { in[s] | s a real successor of n }
   until neither list changes, and returns (in, out).

   Indexed lookups go through vectors, so one round is linear in the number of
   nodes.  The previous List.nth / repeated-append formulation was quadratic. *)
fun CalLiveness (cfg)
=
  let
     val nodes = G.nodes cfg
     val emptySets = List.map (fn _ => S.empty intOrder) nodes
     val (useT, defT) = computeUseDef cfg

     fun unionAll sets = List.foldr (fn (s, acc) => S.union (s, acc)) (S.empty intOrder) sets

     (* successors of n, discarding spurious (label 2) edges *)
     fun getSuc n =
        List.map (fn (_, s) => s)
                 (List.filter (fn (lab, _) => lab <> 2) (#4 (G.context (n, cfg))))

     fun sameSets (xs, ys) = forall (fn (a, b) => S.equal (a, b)) (zip (xs, ys))

     fun round (inS, outS) =
       let
         val outV = Vector.fromList outS
         val newIn =
             List.map (fn n => S.union (T.look (useT, n),
                                        S.difference (Vector.sub (outV, n), T.look (defT, n))))
                      nodes
         val inV = Vector.fromList newIn
         val newOut =
             List.tabulate (Vector.length inV,
                            fn n => unionAll (List.map (fn s => Vector.sub (inV, s)) (getSuc n)))
       in
         if sameSets (newIn, inS) andalso sameSets (newOut, outS)
         then (newIn, newOut)
         else round (newIn, newOut)
       end
  in
     round (emptySets, emptySets)
  end;
238
239
240(* -------------------------------------------------------------------------------------------------------------------*)
241(* Initialize data structures, then compute the values of these structures                                            *)
242(* -------------------------------------------------------------------------------------------------------------------*)
243
(* Reset all allocator state; called whenever the coloring process has to
   restart after the program was rewritten for spills.

   Temporaries spilled in the previous round (spilledNodes) are accumulated
   into spilledTmps.  They are excluded from the new initial set, as are the
   precolored temporaries.  Every key of tmpTable gets an empty adjacency set,
   an empty move set, degree 0, alias ~1 and color ~1.  The cfg argument is
   not used.

   adjList is now built from T.empty like the other per-node tables.  The
   previous version folded into the old !adjList, which could keep stale
   entries for keys missing from tmpTable. *)
fun init (cfg, tmpTable) =
    let
        val keys = T.listKeys tmpTable
        fun tableOf v = List.foldl (fn (n, tbl) => T.enter (tbl, n, v)) T.empty keys
        val noSets = tableOf (S.empty intOrder)
        val unset = tableOf ~1
    in
        spilledTmps := S.union (!spilledNodes, !spilledTmps);
        initial := S.difference (S.difference (S.addList (S.empty intOrder, keys), !precolored),
                                 !spilledTmps);

        simplifyWorklist := S.empty intOrder;
        freezeWorklist := S.empty intOrder;
        spillWorklist := S.empty intOrder;
        spilledNodes := S.empty intOrder;
        coalescedNodes := S.empty intOrder;
        coloredNodes := S.empty intOrder;
        selectStack := Stack.empty ();

        coalescedMoves := S.empty intOrder;
        constrainedMoves := S.empty intOrder;
        frozenMoves := S.empty intOrder;
        worklistMoves := S.empty intOrder;
        activeMoves := S.empty intOrder;

        adjSet := S.empty intTupleOrder;
        adjList := noSets;
        degree := tableOf 0;
        moveList := noSets;

        alias := unset;
        color := unset
    end;
278
279
(* Record an interference edge between u and v.  This is a no-op when u = v
   or the edge is already present.  adjSet is kept symmetric; adjacency sets
   and degrees are only maintained for endpoints that are not precolored. *)
fun AddEdge (u : G.node, v : G.node) =
    let
      fun link (a, b) =
          if S.member (!precolored, a) then ()
          else (adjList := T.enter (!adjList, a, S.add (T.look (!adjList, a), b));
                degree := T.enter (!degree, a, T.look (!degree, a) + 1))
    in
      if u = v orelse S.member (!adjSet, (u, v)) then ()
      else (adjSet := S.add (S.add (!adjSet, (u, v)), (v, u));
            link (u, v);
            link (v, u))
    end;
292
(* Construct the interference graph from the liveness results.  Every
   register-to-register move is put into worklistMoves and into the moveList
   of each temporary it mentions.

   For each node (one instruction), live starts as the node's live-out set.
   For a move d <- s, s is removed from live before edges are added, so d and
   s do not interfere merely because s stays live after the copy; this is
   what keeps the move coalescable.  Each defined temporary then gets an edge
   to every temporary in live U def.

   Fix: the previous version bound live - use to a shadowing local `live`
   inside the move branch and then discarded it.  The source was therefore
   never removed, and such moves became constrained instead of coalescable.

   Returns, per node, use U (live-out - def); AllocateReg ignores it. *)
fun Build (cfg) =
  let
    val (_, outS) = CalLiveness cfg
    fun toSet xs = S.addList (S.empty intOrder, xs)

    fun investigate_one_node nodeNo =
      let
        val nd = getNodeLab nodeNo
        val def = toSet (#def nd)
        val use = toSet (#use nd)
        val liveOut = List.nth (outS, nodeNo)
        val isMove = isMoveInst (#instr nd)
        (* for a move, the source does not interfere with the destination *)
        val live = if isMove then S.difference (liveOut, use) else liveOut
        val _ = if isMove then
                  (moveList := List.foldl
                      (fn (varNo, tbl) => T.enter (tbl, varNo, S.add (T.look (tbl, varNo), nodeNo)))
                      (!moveList) (S.listItems (S.union (def, use)));
                   worklistMoves := S.add (!worklistMoves, nodeNo))
                else ()
        val live = S.union (live, def)
        val _ = List.app (fn d => List.app (fn l => AddEdge (l, d)) (S.listItems live))
                         (S.listItems def)
      in
        S.union (use, S.difference (live, def))
      end
  in
    List.map investigate_one_node (G.nodes cfg)
  end;
327
328
(* Moves of n that are still coalescing candidates (active or on the
   worklist). *)
fun NodeMoves n =
   let val pending = S.union (!activeMoves, !worklistMoves)
   in S.intersection (T.look (!moveList, n), pending) end;
332
(* Empty initial into the work-lists: high-degree nodes go to spillWorklist,
   other move-related nodes to freezeWorklist, and the rest to
   simplifyWorklist.  The resulting spillWorklist is saved in toBeSpilled. *)
fun MakeWorklist() =
   let
      fun classify n =
        (initial := S.delete (!initial, n);
         if T.look (!degree, n) >= !NumRegs then
            spillWorklist := S.add (!spillWorklist, n)
         else if S.isEmpty (NodeMoves n) then
            simplifyWorklist := S.add (!simplifyWorklist, n)
         else
            freezeWorklist := S.add (!freezeWorklist, n))
   in
      List.app classify (S.listItems (!initial));
      toBeSpilled := !spillWorklist
   end;
348
349(* -------------------------------------------------------------------------------------------------------------------*)
350(* Merge moves, push simplified nodes into the stack                                                                  *)
351(* -------------------------------------------------------------------------------------------------------------------*)
352
(* Current neighbours of n: its interference neighbours, minus the nodes
   already removed to selectStack or coalesced away. *)
fun Adjacent n =
   let
      val onStack = Stack.stack2set (!selectStack, S.empty intOrder)
      val removed = S.union (onStack, !coalescedNodes)
   in
      S.difference (T.look (!adjList, n), removed)
   end;
356
(* When a node's degree drops from K to K - 1, moves of that node and of its
   neighbours may become coalescable.  For every node in nodes, move its
   active moves back to worklistMoves. *)
fun EnableMoves nodes =
   let
      fun enable m =
          if S.member (!activeMoves, m) then
             (activeMoves := S.delete (!activeMoves, m);
              worklistMoves := S.add (!worklistMoves, m))
          else ()
   in
      List.app (fn n => List.app enable (S.listItems (NodeMoves n))) (S.listItems nodes)
   end;
371
372
(* Lower m's degree by one.  When it drops from !NumRegs to !NumRegs - 1,
   the moves of m and its neighbours are enabled and m leaves spillWorklist.
   It then goes to freezeWorklist if still move-related, otherwise to
   simplifyWorklist. *)
fun DecrementDegree m =
   let
      val d = T.look (!degree, m)
      fun becomesInsignificant () =
          (EnableMoves (S.add (Adjacent m, m));
           spillWorklist := S.difference (!spillWorklist, S.add (S.empty intOrder, m));
           if S.isEmpty (NodeMoves m)
           then simplifyWorklist := S.add (!simplifyWorklist, m)
           else freezeWorklist := S.add (!freezeWorklist, m))
   in
      degree := T.enter (!degree, m, d - 1);
      if d <> !NumRegs then () else becomesInsignificant ()
   end;
386
(* Take the smallest node of simplifyWorklist, push it onto selectStack and
   decrement the degree of its non-precolored neighbours.  Raises Empty if
   simplifyWorklist is empty. *)
fun Simplify() =
   case S.listItems (!simplifyWorklist) of
       [] => raise Empty
     | n :: _ =>
         (simplifyWorklist := S.delete (!simplifyWorklist, n);
          selectStack := Stack.push (n, !selectStack);
          List.app DecrementDegree (S.listItems (S.difference (Adjacent n, !precolored))));
398
399(* Coalesce moves                                                    *)
400
(* Move u from freezeWorklist to simplifyWorklist once it is an ordinary
   (non-precolored) temporary that is no longer move-related and has
   insignificant degree. *)
fun AddWorklist u =
  let
    val ready = not (S.member (!precolored, u))
                andalso S.isEmpty (NodeMoves u)
                andalso T.look (!degree, u) < !NumRegs
  in
    if ready then
      (freezeWorklist := S.delete (!freezeWorklist, u);
       simplifyWorklist := S.add (!simplifyWorklist, u))
    else ()
  end;
407
(* Neighbour test for coalescing with the precolored node r: t is harmless
   if it has insignificant degree, is precolored, or already interferes
   with r. *)
fun OK (t, r) =
   let val insignificant = T.look (!degree, t) < !NumRegs
   in insignificant orelse S.member (!precolored, t) orelse S.member (!adjSet, (t, r)) end;
412
(* Conservative coalescing test: safe when fewer than !NumRegs of the given
   nodes have significant degree (>= !NumRegs). *)
fun Conservative nodes =
   let
      fun significant n = T.look (!degree, n) >= !NumRegs
      val count = List.length (List.filter significant (S.listItems nodes))
   in
      count < !NumRegs
   end;
417
418(*
419fun Conservative(u,v) =
420   let
421     fun dgr u n = if S.member(Adjacent u, n) then T.look(!degree,n)-1 else T.look(!degree,n);
422   val k =
423        List.foldl (fn (n,k) => if dgr n >= (!NumRegs) then k+1 else k) 0 (S.listItems (Adjacent v))
424   in k < (!NumRegs) end;
425*)
426
(* Follow alias links from n until reaching a node that was not coalesced. *)
fun GetAlias n =
   let
      val current = ref n
   in
      while S.member (!coalescedNodes, !current) do
         current := T.look (!alias, !current);
      !current
   end;
431
(* Coalesce v into u:
   - take v off its work-list and mark it coalesced with alias u;
   - merge v's moves into u's and re-enable v's moves;
   - transfer v's interference edges to u, decrementing each neighbour's
     degree;
   - move u to spillWorklist if it became significant while on
     freezeWorklist. *)
fun Combine (u, v) =
   let
      val single = S.add (S.empty intOrder, v)
      fun unlinkFromWorklist () =
          if S.member (!freezeWorklist, v)
          then freezeWorklist := S.delete (!freezeWorklist, v)
          else spillWorklist := S.difference (!spillWorklist, single)
      fun transferEdge t = (AddEdge (t, u); DecrementDegree t)
   in
      unlinkFromWorklist ();
      coalescedNodes := S.add (!coalescedNodes, v);
      alias := T.enter (!alias, v, GetAlias u);
      moveList := T.enter (!moveList, u, S.union (T.look (!moveList, u), T.look (!moveList, v)));
      EnableMoves single;
      List.app transferEdge (S.listItems (Adjacent v));
      if T.look (!degree, u) >= !NumRegs andalso S.member (!freezeWorklist, u) then
         (freezeWorklist := S.delete (!freezeWorklist, u);
          spillWorklist := S.add (!spillWorklist, u))
      else ()
   end;
448
(* Take the smallest move m from worklistMoves and classify it:
   - coalesced, if both ends already share an alias;
   - constrained, if the ends interfere or the "other" end v is precolored;
   - coalesced via Combine, if the OK test (when u is precolored) or the
     Conservative test (otherwise) allows it;
   - active, otherwise.
   u is always the precolored endpoint when one exists.  Raises Empty if
   worklistMoves is empty. *)
fun Coalesce () =
   let
      val m = hd (S.listItems (!worklistMoves))
      val lab = getNodeLab m
      val x = GetAlias (hd (#def lab))
      val y = GetAlias (hd (#use lab))
      fun isPre n = S.member (!precolored, n)
      val (u, v) = if isPre y then (y, x) else (x, y)
      fun neighboursOK () = List.all (fn t => OK (t, u)) (S.listItems (Adjacent v))
      fun conservativeOK () = Conservative (S.union (Adjacent u, Adjacent v))
   in
      worklistMoves := S.delete (!worklistMoves, m);
      if u = v then
         (coalescedMoves := S.add (!coalescedMoves, m);
          AddWorklist u)
      else if isPre v orelse S.member (!adjSet, (u, v)) then
         (constrainedMoves := S.add (!constrainedMoves, m);
          AddWorklist u;
          AddWorklist v)
      else if (isPre u andalso neighboursOK ())
              orelse (not (isPre u) andalso conservativeOK ()) then
         (coalescedMoves := S.add (!coalescedMoves, m);
          Combine (u, v);
          AddWorklist u)
      else
         activeMoves := S.add (!activeMoves, m)
   end;
473
(* Give up on coalescing every move still pending for u.  Each such move goes
   from activeMoves to frozenMoves.  Its other endpoint v moves from
   freezeWorklist to simplifyWorklist once v is no longer move-related and
   has insignificant degree.

   Fix: v must not be precolored.  Precolored temporaries never enter the
   work-lists (init keeps them out of initial; AddWorklist checks it), and
   AddEdge never raises their degree.  The previous version could therefore
   push a precolored endpoint onto simplifyWorklist.  Simplify would then
   stack it and AssignColors would overwrite any color assign_precolored
   gave it. *)
fun FreezeMoves u =
  let
    fun freezeOneMove m =
      let
        val inst = getNodeLab m
        val (x, y) = (hd (#def inst), hd (#use inst))
        val v = if GetAlias y = GetAlias u then GetAlias x else GetAlias y
      in
        activeMoves := S.difference (!activeMoves, S.add (S.empty intOrder, m));
        frozenMoves := S.add (!frozenMoves, m);
        if not (S.member (!precolored, v))
           andalso S.isEmpty (NodeMoves v)
           andalso T.look (!degree, v) < !NumRegs then
          (freezeWorklist := S.difference (!freezeWorklist, S.add (S.empty intOrder, v));
           simplifyWorklist := S.add (!simplifyWorklist, v))
        else ()
      end
  in
    List.app freezeOneMove (S.listItems (NodeMoves u))
  end;
490
491
(* Pick the smallest node of freezeWorklist, make it simplifiable and freeze
   all of its moves.  Raises Empty if freezeWorklist is empty. *)
fun Freeze () =
   case S.listItems (!freezeWorklist) of
       [] => raise Empty
     | u :: _ =>
         (freezeWorklist := S.delete (!freezeWorklist, u);
          simplifyWorklist := S.add (!simplifyWorklist, u);
          FreezeMoves u);
498
(* Spill heuristic: the node of nodes with the highest current degree; on a
   tie, the smallest such node wins.  Raises Empty when nodes is empty. *)
fun spillPriorities nodes =
    let
      val items = S.listItems nodes
      fun pick (n, best) = if T.look (!degree, n) > T.look (!degree, best) then n else best
    in
      case items of
          [] => raise Empty
        | first :: _ => List.foldl pick first items
    end;
502
(* Optimistically move the highest-degree spill candidate to
   simplifyWorklist and freeze its moves. *)
fun SelectSpill () =
   let
      val candidate = spillPriorities (!spillWorklist)
      val () = spillWorklist := S.delete (!spillWorklist, candidate)
      val () = simplifyWorklist := S.add (!simplifyWorklist, candidate)
   in
      FreezeMoves candidate
   end;
511
512(* -------------------------------------------------------------------------------------------------------------------*)
513(* Assign colors to nodes, spill nodes when a valid allocation couldn't be found                                      *)
514(* -------------------------------------------------------------------------------------------------------------------*)
515
(* In an effort to comply with the APCS standard, *)
(* the arguments are in r0-r9 and then the stack. *)
518
(* Assign colors 0 .. !NumRegs - 1.

   First the temporaries of firstnArgL get colors 0, 1, 2, ... in order.
   Nodes are then popped from selectStack; each gets the smallest color not
   taken by the alias of one of its interference neighbours, provided that
   alias is colored or precolored.  Once the stack is empty, coalesced nodes
   inherit the color of their alias.

   If a node finds no free color, the stack is abandoned and spilledNodes is
   set to the spill candidates.  With !spillOneOnce this is the
   highest-degree node of toBeSpilled; otherwise it is all of toBeSpilled.

   Fixes:
   - When toBeSpilled is empty (e.g. the node only became significant through
     coalescing after MakeWorklist), the failing node itself is spilled.
     Previously this raised Empty (spillOneOnce) or left spilledNodes empty,
     so AllocateReg stopped with the node uncolored.
   - initColors no longer recurses forever for a negative argument.
   - The unused, potentially non-terminating chaseColor helper is removed. *)
fun AssignColors () =
   let
     (* the colors 0 .. n; empty when n < 0 *)
     fun initColors n =
        if n < 0 then S.empty intOrder else S.add (initColors (n - 1), n)

     fun assign_precolored () =
        List.foldl (fn (n, c) => (color := T.enter (!color, n, c); c + 1)) 0 (!firstnArgL)

     fun spillNodes n =
        if S.isEmpty (!toBeSpilled) then S.add (S.empty intOrder, n)
        else if !spillOneOnce then S.add (S.empty intOrder, spillPriorities (!toBeSpilled))
        else !toBeSpilled

     fun colorCoalesced () =
        List.app (fn n => color := T.enter (!color, n, T.look (!color, GetAlias n)))
                 (S.listItems (!coalescedNodes))

     fun assign () =
      if Stack.isEmpty (!selectStack) then colorCoalesced ()
      else
        let
          val n = Stack.top (!selectStack)
          val () = selectStack := Stack.pop (!selectStack)
          (* does not change while okColors is computed, so build it once *)
          val colouredOrFixed = S.union (!coloredNodes, !precolored)
          fun exclude (w, free) =
              let val a = GetAlias w in
                if S.member (colouredOrFixed, a)
                then S.difference (free, S.add (S.empty intOrder, T.look (!color, a)))
                else free
              end
          val okColors = List.foldl exclude (initColors (!NumRegs - 1))
                                    (S.listItems (T.look (!adjList, n)))
        in
          if S.isEmpty okColors then
            (spilledNodes := spillNodes n;
             selectStack := Stack.empty ())
          else
            (coloredNodes := S.add (!coloredNodes, n);
             color := T.enter (!color, n, hd (S.listItems okColors));
             assign ())
        end
   in
     ignore (assign_precolored ());
     assign ()
   end;
570
571(* -------------------------------------------------------------------------------------------------------------------*)
572(* When a node is spilled, we modify the program by replacing the temporary with a memory slot                        *)
573(* -------------------------------------------------------------------------------------------------------------------*)
574
(* Replace the label of node n in cfg by inst, keeping all of its edges. *)
fun updateNode (cfg, n : int, inst) =
   case G.match (n, cfg) of
       ((preds, node, _, succs), rest) => G.embed ((preds, node, inst, succs), rest);
581
582
(* Insert a new node labelled inst immediately before node n.  Every edge
   p -> n becomes p -> new (with its label kept), and a single edge new -> n
   labelled 0 is added.  The new node is numbered G.noNodes cfg, which
   assumes node numbers are contiguous from 0.

   Fix: each distinct predecessor is disconnected only once.  rm_edge removes
   all edges p -> n in one call and raises G.Edge when none is left, so a
   predecessor with several edges into n used to make this function raise. *)
fun insertBefore(cfg, n:int, inst) =
   let
        val (preN, _, _, _) = G.context (n, cfg)
        fun addUnique (p, acc) = if List.exists (fn q => q = p) acc then acc else acc @ [p]
        val predNodes = List.foldl addUnique [] (List.map (fn (_, p) => p) preN)
        val newCfg = List.foldl (fn (p, gr) => rm_edge ((p, n), gr)) cfg predNodes
        val ct = (preN, G.noNodes cfg, inst, [(0, n)])
   in
        G.embed (ct, newCfg)
   end;
592
(* Insert a new node labelled inst immediately after node n.  Every edge
   n -> s becomes new -> s (with its label kept), and a single edge n -> new
   labelled 0 is added.  The new node is numbered G.noNodes cfg, which
   assumes node numbers are contiguous from 0.

   Fix: each distinct successor is disconnected only once.  rm_edge removes
   all edges n -> s in one call and raises G.Edge when none is left, so a
   successor reached by several edges from n used to make this function
   raise. *)
fun insertAfter(cfg, n:int, inst) =
   let
        val (_, _, _, sucN) = G.context (n, cfg)
        fun addUnique (s, acc) = if List.exists (fn q => q = s) acc then acc else acc @ [s]
        val succNodes = List.foldl addUnique [] (List.map (fn (_, s) => s) sucN)
        val newCfg = List.foldl (fn (s, gr) => rm_edge ((n, s), gr)) cfg succNodes
        val ct = ([(0, n)], G.noNodes cfg, inst, sucN)
   in
        G.embed (ct, newCfg)
   end;
601
602
(* old_cfg : control-flow graph to rewrite,  spilled : temporaries to be kept in memory *)
(* Rewrite old_cfg so that every temporary in spilled lives in memory.

   Each spilled variable gets the next memory slot (memIndex is advanced, and
   memT records variable -> slot).  Then every instruction mentioning it is
   rewritten:
   - a register move that reads it becomes an LDR of its slot;
   - any other instruction that reads it reads a fresh temporary instead,
     loaded by an LDR inserted before it;
   - an instruction that writes it writes a fresh temporary instead, stored
     by an STR inserted after it;
   - BL and NOP instructions (with the operand tuples matched below) get the
     memory slot substituted directly, and the variable is dropped from
     their def/use lists.

   Fix: a fresh temporary is now created only when it is actually used.  The
   previous version called newTmp () even for BL/NOP, leaving an unused
   temporary in tmpTable; init then treated it as a real temporary and
   printAllocation listed it. *)
fun updateProgram (old_cfg, spilled : int S.set) =
   let
     (* Replace operand old by new in the destination operands (when inDst)
        and/or the source operands (when inSrc) of an instruction. *)
     fun substituteVars ins old new inDst inSrc =
         let
           fun swapAll ops = List.map (fn x => if x = old then new else x) ops
         in
           case ins of
               Assem.OPER {oper = p, dst = d1, src = s1, jump = j1} =>
                 Assem.OPER {oper = p,
                             dst = if inDst then swapAll d1 else d1,
                             src = if inSrc then swapAll s1 else s1,
                             jump = j1}
             | Assem.LABEL x => Assem.LABEL x
             | Assem.MOVE {dst = d1, src = s1} =>
                 Assem.MOVE {dst = if inDst andalso d1 = old then new else d1,
                             src = if inSrc andalso s1 = old then new else s1}
         end

     (* BL and NOP take the memory slot directly; every other instruction
        goes through a fresh temporary plus an explicit LDR/STR. *)
     fun takesTemps (Assem.OPER {oper = (Assem.BL, NONE, false), ...}) = false
       | takesTemps (Assem.OPER {oper = (Assem.NOP, NONE, false), ...}) = false
       | takesTemps _ = true

     fun nodeLabel gr nodeNo = #3 (G.context (nodeNo, gr))

     (* Replace varNo by the temporary tmp in the defs (inDef) and/or uses
        (inUse) of node nodeNo, both in the lists and in the instruction. *)
     fun rewriteWithTemp gr nodeNo (varNo, tmp) (inDef, inUse) =
         let
           val {instr = ins, def = dfs, use = uss} = nodeLabel gr nodeNo
           fun swap flag ts =
               if flag andalso List.exists (fn t => t = varNo) ts
               then tmp :: List.filter (fn t => t <> varNo andalso t <> tmp) ts
               else ts
         in
           updateNode (gr, nodeNo,
                       {def = swap inDef dfs,
                        use = swap inUse uss,
                        instr = substituteVars ins (Assem.TEMP varNo) (Assem.TEMP tmp) inDef inUse})
         end

     (* Replace varNo by the current memory slot inside the instruction and
        drop it from the selected def/use lists. *)
     fun rewriteWithMem gr nodeNo varNo (inDef, inUse) =
         let
           val {instr = ins, def = dfs, use = uss} = nodeLabel gr nodeNo
           fun drop flag ts = if flag then List.filter (fn t => t <> varNo) ts else ts
         in
           updateNode (gr, nodeNo,
                       {def = drop inDef dfs,
                        use = drop inUse uss,
                        instr = substituteVars ins (Assem.TEMP varNo) (Assem.TMEM (!memIndex)) inDef inUse})
         end

     (* Make node nodeNo read varNo from memory. *)
     fun insertLoadInst gr nodeNo varNo =
         let val {instr = ins, def = _, use = _} = nodeLabel gr nodeNo
         in
           if takesTemps ins then
             let
               val tmp = newTmp ()
               val load = {instr = Assem.OPER {oper = (Assem.LDR, NONE, false),
                                               dst = [Assem.TEMP tmp],
                                               src = [Assem.TMEM (!memIndex)],
                                               jump = NONE},
                           def = [tmp], use = []}
             in
               insertBefore (rewriteWithTemp gr nodeNo (varNo, tmp) (false, true), nodeNo, load)
             end
           else
             rewriteWithMem gr nodeNo varNo (false, true)
         end

     (* Make node nodeNo write varNo to memory. *)
     fun insertStoreInst gr nodeNo varNo =
         let val {instr = ins, def = _, use = _} = nodeLabel gr nodeNo
         in
           if takesTemps ins then
             let
               val tmp = newTmp ()
               val store = {instr = Assem.OPER {oper = (Assem.STR, NONE, false),
                                                dst = [Assem.TMEM (!memIndex)],
                                                src = [Assem.TEMP tmp],
                                                jump = NONE},
                            def = [], use = [tmp]}
             in
               insertAfter (rewriteWithTemp gr nodeNo (varNo, tmp) (true, false), nodeNo, store)
             end
           else
             rewriteWithMem gr nodeNo varNo (true, false)
         end

     (* Give varNo a fresh slot, then rewrite every node that uses or
        defines it. *)
     fun process_one_variable gr varNo =
         let
           val () = (memIndex := !memIndex + 1;
                     memT := T.enter (!memT, varNo, !memIndex))
           fun mentions ts = List.exists (fn t => t = varNo) ts
           fun visit ((_, nodeNo, {instr = ins, def = dfs, use = uss}, _), g) =
               let
                 val g1 =
                     if not (mentions uss) then g
                     else
                       case ins of
                           Assem.MOVE {dst = (d1 as Assem.TEMP _), src = Assem.TEMP _} =>
                             updateNode (g, nodeNo,
                                         {instr = Assem.OPER {oper = (Assem.LDR, NONE, false),
                                                              dst = [d1],
                                                              src = [Assem.TMEM (!memIndex)],
                                                              jump = NONE},
                                          def = dfs, use = []})
                         | _ => insertLoadInst g nodeNo varNo
               in
                 if mentions dfs then insertStoreInst g1 nodeNo varNo else g1
               end
         in
           G.ufold visit gr gr
         end
   in
     List.foldl (fn (varNo, gr) => process_one_variable gr varNo) old_cfg (S.listItems spilled)
   end;
703
704
(* Rewrite the program so that the nodes in spilledNodes live in memory, then
   reset the allocator state for the next round. *)
fun RewriteProgram () =
    let val rewritten = updateProgram (!cfg, !spilledNodes)
    in
      cfg := rewritten;
      init (rewritten, !tmpTable)
    end;
709
710
711(* -------------------------------------------------------------------------------------------------------------------*)
712(* Register Allocation                                                                                                *)
713(* -------------------------------------------------------------------------------------------------------------------*)
714
fun printAllocation () =
  let
    (* Name of the register a colour stands for; colours outside
       [0, NumRegs) have no register. *)
    fun color2reg c =
       if c >= 0 andalso c < (!NumRegs) then "r" ^ Int.toString c
       else "not found";

    (* Print one row (name, colour, register) for temporary n, unless n was
       spilled to memory (spilled temporaries have no colour to report). *)
    fun printRow n =
       if S.member (!spilledTmps, n) then ()
       else
         let val c = T.look (!color, n) in
           print (T.look (!tmpTable, n) ^ "\t" ^ (Int.toString c) ^ "\t" ^ (color2reg c) ^ "\n")
         end
  in
    ( print "Node\tColor\tRegister\n";
      List.app printRow (T.listKeys (!tmpTable)) )
  end;
732
fun AllocateReg () =
  let
    (* True once every worklist used by the main allocation loop is empty. *)
    fun worklistsEmpty () =
        S.isEmpty (!simplifyWorklist) andalso S.isEmpty (!worklistMoves) andalso
        S.isEmpty (!freezeWorklist) andalso S.isEmpty (!spillWorklist)

    (* Perform one action on the first non-empty worklist, in priority order:
       simplify, coalesce, freeze, then select a potential spill. *)
    fun step () =
        if not (S.isEmpty (!simplifyWorklist)) then Simplify ()
        else if not (S.isEmpty (!worklistMoves)) then Coalesce ()
        else if not (S.isEmpty (!freezeWorklist)) then Freeze ()
        else if not (S.isEmpty (!spillWorklist)) then SelectSpill ()
        else ()

    (* Keep stepping until all worklists are drained. *)
    fun drain () =
        ( step ();
          if worklistsEmpty () then () else drain () )
  in
    Build (!cfg);
    MakeWorklist ();
    drain ();
    AssignColors ();
    (* Some nodes could not be coloured: rewrite the program around them and
       run the whole allocation again on the new graph. *)
    if S.isEmpty (!spilledNodes) then ()
    else ( RewriteProgram ();
           AllocateReg () )
  end;
757
758
fun RewrWithReg () =
  let
    (* Replace every temporary (also inside pairs) by the register its colour
       denotes; every other operand is returned unchanged. *)
    fun toReg (Assem.PAIR (x, y)) = Assem.PAIR (toReg x, toReg y)
      | toReg (Assem.TEMP t) = Assem.REG (T.look (!color, t))
      | toReg other = other

    (* Rewrite the operands of a single instruction. *)
    fun rewriteInstr (Assem.OPER {oper = op', dst = ds, src = ss, jump = jmp}) =
          Assem.OPER {oper = op', dst = List.map toReg ds, src = List.map toReg ss, jump = jmp}
      | rewriteInstr (Assem.LABEL lab) = Assem.LABEL lab
      | rewriteInstr (Assem.MOVE {dst = d, src = s}) =
          Assem.MOVE {dst = toReg d, src = toReg s}

    (* Rewrite the instruction stored at one CFG node, keeping its def/use sets. *)
    fun rewriteNode ((_, nodeNo, {use = us, def = df, instr = stm}, _), acc) =
          updateNode (acc, nodeNo, {use = us, def = df, instr = rewriteInstr stm})
  in
    G.ufold rewriteNode (!cfg) (!cfg)
  end;
779
780
781(* ---------------------------------------------------------------------------------------------------------------------*)
782(* Interface                                                                                                            *)
783(* ---------------------------------------------------------------------------------------------------------------------*)
784
fun RegisterAllocation (gr, tmpT, preC) =
  let
    (* At most NumRegs of the incoming arguments can be pre-coloured. *)
    val numPre = Int.min (length preC, !NumRegs)
  in
    tmpTable := tmpT;
    firstnArgL := List.take (preC, numPre);
    precolored := S.addList (S.empty intOrder, !firstnArgL);
    cfg := gr;
    (* Reset the spill bookkeeping (memory-slot counter, slot table, spilled
       temporaries) before allocating. *)
    memIndex := ~1;
    memT := T.empty;
    spilledTmps := S.empty intOrder;
    init (gr, tmpT);
    AllocateReg ();
    PrintProgram (!cfg, !tmpTable);
    printAllocation ();
    (RewrWithReg (), !tmpTable)
  end;
806
807 fun regOrder (r1,r2) =
808      let val (Assem.REG n1, Assem.REG n2) = (r1, r2) in
809                if n1 > n2 then GREATER
810                else if n1 = n2 then EQUAL
811                else LESS
812      end
813
814 fun getModifiedRegs stms =
815      List.foldl (fn (Assem.OPER {dst = d, ...}, regs) =>
816                (List.foldl (fn (Assem.REG r, regs)  => Binaryset.add(regs,Assem.REG r)
817                                     |   _ => regs) regs d)
818              |  (Assem.MOVE {dst = d, ...}, regs) => Binaryset.add(regs,d)
819              |  (_,regs) => regs
820             )
821             (Binaryset.empty regOrder) stms
822
fun convert_to_ARM prog =
  let
    (* Location of a variable after allocation: its memory slot if it was
       spilled, otherwise the register its colour denotes.  Pairs are mapped
       component-wise; any other operand is left unchanged. *)
    fun locate (Assem.TEMP v) =
          (case T.peek (!memT, v) of
               SOME slot => Assem.TMEM slot
             | NONE => Assem.REG (T.look (!color, v)))
      | locate (Assem.PAIR (a, b)) = Assem.PAIR (locate a, locate b)
      | locate other = other

    val ((fun_name, fun_type, args, gr, outs), tmps) = CFG.convert_to_CFG prog
    (* The argument temporaries are the pre-colouring candidates. *)
    val precolors = List.map Assem.eval_exp (Assem.pair2list args)
    val (allocated, _) = RegisterAllocation (gr, tmps, precolors)

    val args' = locate args
    val outs' = locate outs
    val modified = getModifiedRegs (CFG.linearizeCFG allocated)
  in
    (fun_name, fun_type, args', allocated, outs', modified)
  end
845
846(* ---------------------------------------------------------------------------------------------------------------------*)
(* Verification of register allocation on original program                                                              *)
848(* ---------------------------------------------------------------------------------------------------------------------*)
849
fun RewrBody rules body =
  (* Apply the substitution rules to a chain of (paired) let-bindings: the
     bound pattern, the bound expression and, recursively, the scope of every
     let are rewritten; the final non-let term is rewritten directly. *)
  let
    fun rename tm = subst rules tm

    fun walk tm =
      if is_let tm then
        let
          val (lam, bound) = dest_let tm
          val (pat, scope) = dest_pabs lam
        in
          mk_let (mk_pabs (rename pat, walk scope), rename bound)
        end
      else rename tm
  in
    walk body
  end
865
866
fun check_allocation prog =
  (* Rename the variables of the program theorem prog according to the
     computed allocation (register variables r-number, memory variables
     m-number) and prove the renamed equation from prog.  Returns the
     generalised theorem. *)
  let
     (* Rules renaming each non-negative temporary that has no memory slot to
        the variable r followed by its colour, at type tp. *)
     fun mk_reg_rules tp =
        List.map (fn n => {redex = mk_var (T.look(!tmpTable,n), tp),
                           residue = mk_var ("r" ^ Int.toString (T.look(!color,n)), tp)})
                (List.filter (fn n => n >= 0 andalso T.peek(!memT,n) = NONE) (T.listKeys (!tmpTable)));

    (* Rules renaming each non-negative temporary with a memory slot to the
       variable m followed by its slot number, at type tp. *)
    fun mk_mem_rules tp =
        List.map (fn n => {redex = mk_var (T.look(!tmpTable,n), tp),
                           residue = mk_var ("m" ^ Int.toString(T.look(!memT,n)), tp)})
                (List.filter (fn n => n >= 0) (T.listKeys (!memT)));

     (* Composition applies right to left: register rules (word32, then num),
        then memory rules (word32, then num). *)
     val rw =   (RewrBody (mk_mem_rules (Type `:num`))) o (RewrBody (mk_mem_rules (Type `:word32`))) o
                (RewrBody (mk_reg_rules (Type `:num`))) o (RewrBody (mk_reg_rules (Type `:word32`)));

     (* Descend through lets, conditionals and pairs, renaming at the leaves. *)
     fun replace exp =
        if is_let exp then
            let val (lt, rhs) = dest_let exp;
                val (lhs, rest) = dest_pabs lt
            in
                mk_let (mk_pabs(replace lhs, replace rest), replace rhs)
            end
        else if is_cond exp then
            let val (c,t,f) = dest_cond exp
            in
                mk_cond (replace c, replace t, replace f)
            end
        else if is_pair exp then
            let val (t1,t2) = dest_pair exp
            in  mk_pair (replace t1, replace t2)
            end
        else rw exp

     (* The program must be an equation  f args = body. *)
     val (decl,body) =
           dest_eq(concl(SPEC_ALL prog))
           handle HOL_ERR _
           => (print "not an program in function format\n";
             raise ERR "check_allocation" "invalid program format");
     val (f, args) = dest_comb decl;

     val newDecl = mk_comb (f, rw args)
     val newProg = mk_eq(newDecl, replace body)

  in
    GEN_ALL (prove (
        newProg,
        METIS_TAC [prog, LET_THM]
    ))
  end
916
917end (* local open *)
918end (* structure *)
919
920