1structure reg_allocLib :> reg_allocLib =
2struct
3
4open HolKernel boolLib bossLib Parse;
5open wordsTheory codegen_inputLib helperLib;
6
local
  (* Seed abbreviation theorems.  For every shift amount i = 1..31, the
     theorem WORD_MUL_LSL is instantiated at the word32 variable w and the
     numeral i, its right-hand side is evaluated with EVAL, and the equation
     is turned around with GSYM.  Each one appears twice in the table: once
     with the operands of the product commuted (WORD_MULT_COMM) and once
     as is.  REWRITE_CONV uses them left-to-right. *)
  val shift_amounts = List.tabulate (31, fn k => k + 1)
  fun lsl_abbrev i =
    GSYM (CONV_RULE (RAND_CONV EVAL)
            (ISPECL [mk_var ("w", ``:word32``),
                     numSyntax.mk_numeral (Arbnum.fromInt i)] WORD_MUL_LSL))
  val plain = map lsl_abbrev shift_amounts
  val commuted = map (ONCE_REWRITE_RULE [WORD_MULT_COMM]) plain
  val compiler_abbrevs = ref (commuted @ plain)
in
  (* add_abbrevs thms: put thms in front of the abbreviation table used by
     COMPILER_UNABBREV_CONV. *)
  fun add_abbrevs thms = (compiler_abbrevs := thms @ (!compiler_abbrevs))
  (* Rewrite tm with every theorem currently in the abbreviation table. *)
  fun COMPILER_UNABBREV_CONV tm = REWRITE_CONV (!compiler_abbrevs) tm
end;
19
(* Fresh-name supplies.  Each call returns a new name built from a private
   counter:
     "_1", "_2", ... : temporaries (mk_temp_var) introduced while flattening
                       expressions and building parallel assignments;
     "t1", "t2", ... : variables introduced by SSA renaming (mk_t_var).
   is_temp_var / is_t_var recognise the two families by their first
   character and answer false for anything that is not a variable. *)
local
  fun name_supply prefix = let
    val counter = ref 0
    in fn () => (counter := !counter + 1; prefix ^ int_to_string (!counter)) end
in
  val get_temp_name = name_supply "_"
  val get_t_name = name_supply "t"
end
fun mk_temp_var ty = mk_var (get_temp_name (), ty)
fun is_temp_var v =
  let val (name, _) = dest_var v in String.isPrefix "_" name end
  handle HOL_ERR _ => false
fun mk_t_var ty = mk_var (get_t_name (), ty)
fun is_t_var v =
  let val (name, _) = dest_var v in String.isPrefix "t" name end
  handle HOL_ERR _ => false
31
32
33
34(* various helpers *)
35
(* all_distinct xs: xs with duplicates removed.  The first occurrence of each
   element is kept and the relative order of the survivors is preserved. *)
fun all_distinct xs = let
  fun keep_first ([], seen) = rev seen
    | keep_first (y::ys, seen) =
        if List.exists (fn z => z = y) seen then keep_first (ys, seen)
        else keep_first (ys, y::seen)
  in keep_first (xs, []) end
38
(* append_lists xss: concatenate the lists in xss, left to right. *)
fun append_lists xss = foldr (op @) [] xss
41
(* diff xs ys: the elements of xs that do not occur in ys (order of xs). *)
fun diff xs ys = List.filter (fn x => not (List.exists (fn y => y = x) ys)) xs
(* intersect xs ys: the elements of ys that also occur in xs (order of ys). *)
fun intersect xs ys = List.filter (fn y => List.exists (fn x => x = y) xs) ys
44
(* dest_tuple tm: split a right-nested pair (a,b,...,z) into [a,b,...,z].
   A term that is not a pair yields the singleton [tm]. *)
fun dest_tuple tm = let
  fun split t = SOME (pairSyntax.dest_pair t) handle HOL_ERR _ => NONE
  fun strip acc t =
    case split t of
      SOME (component, rest) => strip (component :: acc) rest
    | NONE => rev (t :: acc)
  in strip [] tm end;
47
(* list_find key assoc: the value paired with the first occurrence of key in
   the association list assoc; fails (HOL_ERR via fail) when key is absent. *)
fun list_find key assoc =
  case List.find (fn (k, _) => k = key) assoc of
    SOME (_, value) => value
  | NONE => fail()
50
(* EXPAND_LET_CONV: unfold an outermost "LET (\v. body) e" to body[e/v].
   The LET constant is rewritten with LET_DEF, and the three beta-redexes
   this creates are then reduced one after the other. *)
val EXPAND_LET_CONV =
  RATOR_CONV (RATOR_CONV (ONCE_REWRITE_CONV [LET_DEF]))
  THENC RATOR_CONV BETA_CONV
  THENC BETA_CONV
  THENC BETA_CONV
54
(* mk_tuple xs: build the right-nested tuple of xs; [] gives the unit value
   and a single element is returned unchanged. *)
fun mk_tuple xs =
  case xs of
    [] => ``()``
  | [single] => single
  | (first :: rest) => pairSyntax.mk_pair (first, mk_tuple rest)
58
59
60(* this conversion flattens large expressions into compilable assignments *)
61
(* BOTTOM_UP_CONV c: rewrite all subterms first (for an application the
   operand before the operator, for an abstraction its body), then try c
   once on the rebuilt term.  Failures of c are ignored (TRY_CONV). *)
fun BOTTOM_UP_CONV c tm = let
  val recur = BOTTOM_UP_CONV c
  in case dest_term tm of
       COMB _ => (RAND_CONV recur THENC RATOR_CONV recur THENC TRY_CONV c) tm
     | LAMB _ => (ABS_CONV recur THENC TRY_CONV c) tm
     | _ => TRY_CONV c tm
  end
70
(* TOP_DOWN_CONV c: try c once at the root, then descend into the subterms
   of the (possibly rewritten) term: operand before operator for an
   application, the body for an abstraction.  c is not re-applied to its own
   output at the same position. *)
fun TOP_DOWN_CONV c tm = let
  val recur = TOP_DOWN_CONV c
  fun descend t =
    case dest_term t of
      COMB _ => (RAND_CONV recur THENC RATOR_CONV recur) t
    | LAMB _ => ABS_CONV recur t
    | _ => ALL_CONV t
  in (TRY_CONV c THENC descend) tm end
77
(* FLATTEN_EXPS_CONV tm
   tm is a conjunction of (possibly universally quantified) equations
   "f args = body".  In every body, each returned value, let right-hand side
   and if-guard that cannot be compiled directly is split up: compilable,
   non-variable subterms are bound to new let variables until what remains
   is compilable.  The equation tm = tm' is proved with auto_prove by
   expanding the new "_" temporaries again. *)
fun FLATTEN_EXPS_CONV tm = let
  (* Rename every word32 variable of tm to r0, so that only the shape of an
     expression, not the names of its variables, decides compilability. *)
  fun regs_to_r0 tm = let
    val vs = filter (fn x => type_of x = ``:word32``) (free_vars tm)
    val r0 = mk_var("r0",``:word32``)
    in subst (map (fn x => x |-> r0) vs) tm end
  (* Can tm be the right-hand side of a single assignment (term2assign)?
     The case expression is parenthesised so that the handler covers the
     term2assign call.  Unparenthesised, it would attach to the last branch
     only, and an Empty raised by term2assign would escape. *)
  fun is_compilable tm =
    (case term2assign ``r1:word32`` (regs_to_r0 tm) of
       ASSIGN_OTHER _ => false
     | _ => true)
    handle HOL_ERR _ => false | Empty => false
  (* Can tm be compiled as a branch condition (term2guard)?  Parenthesised
     for the same reason as is_compilable. *)
  fun is_c_guard tm =
    (case term2guard (regs_to_r0 tm) of
       GUARD_OTHER _ => false
     | GUARD_NOT (GUARD_OTHER _) => false
     | _ => true)
    handle HOL_ERR _ => false | Empty => false
  (* While g rejects rhs, take the first compilable, non-variable subterm t
     that find_term reports, append the binding (name,t) to xs, and replace
     t by name in rhs.  A word32 subterm is named by a fresh "_" temporary.
     Any other subterm reuses a variable of the same type occurring inside
     t, if there is one.  Stops when no such subterm exists. *)
  fun divide_aux g (xs,rhs) = if g rhs then (xs,rhs) else let
    val t = find_term (fn x => is_compilable x andalso not (is_var x)) rhs
    val ty = type_of t
    val temp = mk_temp_var ty
    val temp = if ty = ``:word32`` then temp else
                 find_term (fn v => is_var v andalso (ty = type_of v)) t
                 handle HOL_ERR _ => temp
    in divide_aux g (xs @ [(temp,t)], subst [t |-> temp] rhs) end
    handle HOL_ERR _ => (xs,rhs)
  (* stable partition: elements satisfying p first, original order kept *)
  fun partition p xs = filter p xs @ filter (not o p) xs
  (* split rhs, then put the word32 bindings before all other bindings *)
  fun divide g (xs,rhs) = let
    val (xs,rhs) = divide_aux g (xs,rhs)
    val xs = partition (fn x => type_of (fst x) = ``:word32``) xs
    in (xs,rhs) end
  fun CONJUNCTS_CONV c tm =
    if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  fun FORALL_CONV c tm =
    if is_forall tm then QUANT_CONV (FORALL_CONV c) tm else c tm
  (* apply a conversion to the right-hand side of every equation *)
  val FUNC_BODY_CONV = CONJUNCTS_CONV o FORALL_CONV o RAND_CONV
  (* Flatten one function body and prove the old and new bodies equal. *)
  fun FLAT_CONV tm = let
    val f = tm2ftree tm
    fun lets ([],y) = y
      | lets ((x1,x2)::xs,y) = FUN_LET (x1,x2,lets (xs,y))
    fun ftree_each (FUN_VAL rhs) = let
          val (xs,rhs2) = divide is_compilable ([],rhs)
          in lets (xs,FUN_VAL rhs2) end
      | ftree_each (FUN_LET (lhs,rhs,t)) = let
          val (xs,rhs2) = divide is_compilable ([],rhs)
          in lets (xs,FUN_LET (lhs,rhs2,ftree_each t)) end
      | ftree_each (FUN_IF (b,t1,t2)) = let
          val (xs,b2) = divide is_c_guard ([],b)
          in lets (xs,FUN_IF (b2,ftree_each t1,ftree_each t2)) end
      | ftree_each (FUN_COND (b,t)) = FUN_COND (b,ftree_each t)
    val tm2 = ftree2tm (ftree_each f)
    (* inline only the lets that bind "_" temporaries *)
    fun EXPAND_TEMPVARLET_CONV tm = let
      val (v,x) = dest_abs (fst (dest_let tm))
      in if is_temp_var v then EXPAND_LET_CONV tm else NO_CONV tm end
      handle HOL_ERR _ => NO_CONV tm
    val goal = mk_eq(tm,tm2)
    val result = auto_prove "FLAT_CONV" (goal,
      CONV_TAC (BOTTOM_UP_CONV EXPAND_TEMPVARLET_CONV) THEN REWRITE_TAC [])
    in result end
  val result = FUNC_BODY_CONV FLAT_CONV tm
  in result end;
142
143
(* Translation into SSA form, for the word32 variables other than the fixed registers r0, r1, ... and stack slots s0, s1, ... *)
145
(* not_fixed_reg v: true when v is a word32 variable whose name does NOT
   look like a fixed register or stack slot.  A fixed name starts with "r" or
   "s" and continues with only decimal digits and primes, e.g. r0, s12, r3'.
   The bare names "r" and "s" also count as fixed.  Answers false for
   non-variables. *)
fun not_fixed_reg v = let
  val (name, ty) = dest_var v
  val chars = explode name
  fun reg_suffix_char c = Char.isDigit c orelse c = #"'"
  val looks_like_reg =
    (hd chars = #"r" orelse hd chars = #"s") andalso
    List.all reg_suffix_char (tl chars)
  in ty = ``:word32`` andalso not looks_like_reg end
  handle HOL_ERR _ => false
153
(* SSA_CONV: working bottom-up, alpha-rename the bound variable of every
   abstraction to a fresh "tN" variable when it is a word32 variable that is
   not a fixed register (see not_fixed_reg).  Since each binder gets its own
   name, this is what puts the bodies into SSA form. *)
val SSA_CONV = let
  fun rename_binder abs_tm = let
    val bound = #1 (dest_abs abs_tm)
    in (if not_fixed_reg bound then ALPHA_CONV (mk_t_var (type_of bound))
        else NO_CONV) abs_tm end
  in BOTTOM_UP_CONV rename_binder end
160
(* COMMON_SUBEXP_CONV: visiting the term top-down, at every
   "let v = y in x" where v is a variable that is not a fixed register and
   y also occurs somewhere in the body x, replace those occurrences of y by
   v.  The rewrite is proved by expanding the outer let on both sides.  Any
   "let v = w in ..." copies (w a variable, v not a fixed register) in the
   result are then inlined by DELETE_EXTRA_MOVE_CONV.  Uses the file-level
   EXPAND_LET_CONV for both steps. *)
val COMMON_SUBEXP_CONV = let
  fun aux tm = let
    val (x,y) = dest_let tm
    val (v,x) = dest_abs x
    val _ = dest_var v
    val _ = if not_fixed_reg v then () else fail()
    (* fails (HOL_ERR) unless y occurs in the body *)
    val _ = find_term (fn x => x = y) x
    val x2 = subst [y|->v] x
    val tm2 = mk_let(mk_abs(v,x2),y)
    val goal = mk_eq(tm,tm2)
    val thi = auto_prove "" (goal,
      CONV_TAC (BINOP_CONV EXPAND_LET_CONV) THEN REWRITE_TAC [])
    (* inline a copy "let v = w in ..." with w a variable *)
    fun DELETE_EXTRA_MOVE_CONV tm = let
      val (x,y) = dest_let tm
      val (v,x) = dest_abs x
      val _ = dest_var v
      val _ = dest_var y
      val _ = if not_fixed_reg v then () else fail()
      in EXPAND_LET_CONV tm end
    in ((fn tm => thi) THENC BOTTOM_UP_CONV DELETE_EXTRA_MOVE_CONV) tm end
  in TOP_DOWN_CONV aux end
185
186
187(* restrict register names *)
188
(* parallel_assign tm2 tm
   Returns an ordered list of single assignments (lhs,rhs), meant to be
   emitted as nested lets in list order.  After them, each variable of the
   tuple tm2 holds the value the corresponding component of tm held before,
   i.e. they implement the simultaneous assignment tm2 := tm.
   Construction: every component x of tm whose target differs from x is
   first copied to a fresh "_" temporary, and each temporary is then copied
   to its target.  Two passes then shorten this list: copy propagation, and
   dropping temporaries nobody reads.
   Fails (HOL_ERR from match_term) when tm2 is not an instance of tm. *)
fun parallel_assign tm2 tm = let (* both tm and tm2 must be tuples of variables *)
  (* make basic parallel assignment *)
  (* m maps each variable of tm to the corresponding component of tm2 *)
  val (m,_) = match_term tm tm2
  (* components whose value actually has to move *)
  val xs = filter (fn x => not (x = subst m x)) (dest_tuple tm)
  val vs = map (fn x => mk_temp_var (type_of x)) xs
  (* temp_i := x_i for all i, then target_i := temp_i for all i *)
  val rs = zip vs xs @ zip (map (subst m) xs) vs
  (* optimise: copy forward *)
  (* aux remembers earlier copies (lhs,rhs) that are still valid.  A rhs that
     names such an lhs is replaced by that copy's rhs.  Assigning x
     invalidates every remembered copy whose rhs mentions x. *)
  fun forward [] aux = []
    | forward ((x,y)::xs) aux = let
        val y = list_find y aux handle HOL_ERR _ => y
        val aux = filter (fn (lhs,rhs) => not (mem x (free_vars rhs))) aux
        in (x,y) :: forward xs ((x,y)::aux) end
  val rs = forward rs []
  (* optimise: remove unused temporary variables *)
  (* x is needed if a later rhs mentions it; otherwise it is kept only if it
     is not a "_" temporary (i.e. it is one of the real targets) *)
  fun is_used x [] = not (is_temp_var x)
    | is_used x ((y,z)::xs) = if mem x (free_vars z) then true else is_used x xs
  fun in_tail [] = []
    | in_tail ((x,y)::xs) = if is_used x xs then (x,y)::in_tail xs else in_tail xs
  val rs = in_tail rs
  in rs end;
209
(* FIX_CALL_RETURN_VALUES_CONV tm
   tm is a conjunction of (possibly universally quantified) equations
   "f args = body".  Each function f gets one fixed input tuple (its own
   parameter tuple args) and one fixed output tuple (its first non-recursive
   return value, with the "tN" variables in it renamed afresh).  Every
   recursive tail call, every return, and every let-bound call of one of the
   listed functions is then rewritten so that its values move through these
   fixed tuples (using parallel_assign).  Each rewritten equation is proved
   equal to the original with auto_prove.
   Fails with a descriptive HOL_ERR when some function has no
   non-recursive return value. *)
fun FIX_CALL_RETURN_VALUES_CONV tm = let
  (* find one return value for each function:
     (f, (args, first leaf of body that is not an instance of "f args")) *)
  fun in_out x = let
    val (lhs,rhs) = dest_eq x
    fun leaves (FUN_COND (_,t)) = leaves t
      | leaves (FUN_LET (_,_,t)) = leaves t
      | leaves (FUN_IF (_,t1,t2)) = leaves t1 @ leaves t2
      | leaves (FUN_VAL tm) = [tm]
    val bases = filter (not o can (match_term lhs)) (leaves (tm2ftree rhs))
    (* without a base case there is no output tuple to choose *)
    val base = case bases of
                 (b::_) => b
               | [] => failwith ("no non-recursive return value for " ^
                                 term_to_string lhs)
    in (car lhs, (cdr lhs, base)) end
  val xs = map (repeat (snd o dest_forall)) (list_dest dest_conj tm)
  val io = map in_out xs
  (* invent new temporaries for each return value *)
  fun invent_new_temps (x,(y,z)) = let
    val f = map (fn z => if is_t_var z then mk_t_var(type_of z) else z)
    in (x,(y,mk_tuple (f (dest_tuple z)))) end
  val io = map invent_new_temps io
  (* add restrictions on already compiled components *)
  (* ... *)
  (* make sure all function calls/returns respect this io restriction *)
  fun CONJUNCTS_CONV c tm =
    if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  fun FORALL_CONV c tm =
    if is_forall tm then QUANT_CONV (FORALL_CONV c) tm else c tm
  val FUNC_BODY_CONV = CONJUNCTS_CONV o FORALL_CONV
  fun FLAT_CONV tm = let
    val func_tm = (car o fst o dest_eq) tm
    val f = tm2ftree (cdr tm)
    fun lets [] y = y
      | lets ((x1,x2)::xs) y = FUN_LET (x1,x2,lets xs y)
    fun ftree_each (FUN_IF (b,t1,t2)) = FUN_IF (b,ftree_each t1,ftree_each t2)
      | ftree_each (FUN_COND (b,t)) = FUN_COND (b,ftree_each t)
      | ftree_each (FUN_VAL rhs) = let
          (* a recursive tail call moves its arguments into the input
             tuple; a return moves its value into the output tuple *)
          val call = (car rhs = func_tm) handle HOL_ERR _ => false
          val x = (if call then fst else snd) (list_find func_tm io)
          val rhs2 = if call then cdr rhs else rhs
          val rs1 = parallel_assign x rhs2
          val ret = if call then mk_comb(func_tm,x) else x
          in lets rs1 (FUN_VAL ret) end
      | ftree_each (FUN_LET (lhs,rhs,t)) = let
          (* call of a listed function g: load g's input tuple, call g,
             then copy g's output tuple into lhs *)
          val (x,y) = list_find (car rhs) io
          val rs1 = parallel_assign x (cdr rhs)
          val rs2 = parallel_assign lhs y
          in lets rs1 (FUN_LET (y,mk_comb(car rhs,x),lets rs2 (ftree_each t))) end
          handle HOL_ERR _ => FUN_LET (lhs,rhs,ftree_each t)
    val tm2 = ftree2tm (ftree_each f)
    val goal = mk_eq(tm,mk_eq((fst o dest_eq) tm,tm2))
    val result = auto_prove "FLAT_CONV" (goal,SIMP_TAC std_ss [LET_DEF])
    in result end
  val result = FUNC_BODY_CONV FLAT_CONV tm
  in result end;
265
266
267(* clash graph and reg allocation *)
268
(* ftree_free_vars t: every free variable of the ftree t, listed once.  It
   covers guards, let patterns, let right-hand sides and returned values, in
   order of first occurrence in a pre-order walk. *)
fun ftree_free_vars t = let
  fun flat (FUN_VAL tm) = free_vars tm
    | flat (FUN_COND (guard,rest)) = free_vars guard @ flat rest
    | flat (FUN_IF (guard,yes,no)) = free_vars guard @ flat yes @ flat no
    | flat (FUN_LET (lhs,rhs,rest)) = free_vars lhs @ free_vars rhs @ flat rest
  in all_distinct (flat t) end;
275
(* subroutine_internal_vars (tm,t): for the equation with left-hand side tm
   ("f args") and body t, the variables of t that are neither parameters
   (free in args) nor mentioned in any returned (leaf) value. *)
fun subroutine_internal_vars (tm,t) = let
  fun leaf_terms (FUN_VAL v) = [v]
    | leaf_terms (FUN_COND (_,rest)) = leaf_terms rest
    | leaf_terms (FUN_LET (_,_,rest)) = leaf_terms rest
    | leaf_terms (FUN_IF (_,yes,no)) = leaf_terms yes @ leaf_terms no
  val params = free_vars (cdr tm)
  val returned = List.concat (map free_vars (leaf_terms t))
  in diff (ftree_free_vars t) (params @ returned) end
284
(* clash_graph ts: build the interference ("clash") graph for the register
   allocator.  ts is a list of (lhs, body) pairs, one per function, where lhs
   is "f args" and body is its ftree.
   Liveness is computed backwards over each body.  Each live set is
   recorded as a FUN_COND annotation whose guard is a list equation
   "ls1 = ls2":
     ls1 - word32 variables live at that point;
     ls2 - at a let whose right-hand side calls one of the functions in ts,
           the callee's internal variables (subroutine_internal_vars);
           otherwise [].
   Two variables clash when both are in the same ls1, or one is in ls1 and
   the other in the paired ls2.  Variables in ls2 do not clash with each
   other through that annotation.
   Returns an adjacency list (v, clashing variables other than v) over every
   variable that occurs in some ls1. *)
fun clash_graph ts = let
  (* only word32 variables take part in register allocation *)
  fun ok_var x = (type_of x = ``:word32``)
  (* wrap t in the annotation "ls1 = ls2" (duplicates removed) *)
  fun add_live_set2 ls1 ls2 t = FUN_COND
       (mk_eq(listSyntax.mk_list(all_distinct ls1,``:word32``),
              listSyntax.mk_list(all_distinct ls2,``:word32``)),t)
  fun add_live_set ls t = add_live_set2 ls [] t
  val fs = map (car o fst) ts
  (* internal variables of the callee when rhs calls a function of ts, []
     otherwise.  NOTE(review): hd raises Empty if car rhs is a listed
     function but rhs is not literally one of the lhs terms in ts. *)
  fun get_internal_vars rhs =
    if not (mem (car rhs) fs) handle HOL_ERR _ => true then [] else
      subroutine_internal_vars (hd (filter (fn (x,_) => x = rhs) ts))
  (* live t = (variables live on entry to t, t annotated with live sets).
     Fails on FUN_COND, so the input bodies presumably contain no FUN_COND
     nodes yet -- TODO confirm. *)
  fun live (FUN_VAL tm) = let
        val ls = filter ok_var (free_vars tm)
        val t = add_live_set ls (FUN_VAL tm)
        in (ls,t) end
    | live (FUN_COND (tm,t)) = fail()
    | live (FUN_IF (tm,x1,x2)) = let
        val (ls1,y1) = live x1
        val (ls2,y2) = live x2
        val ls = (filter ok_var (free_vars tm)) @ ls1 @ ls2
        val t = add_live_set ls (FUN_IF (tm,y1,y2))
        in (ls,t) end
    | live (FUN_LET (lhs,rhs,t)) = let
        (* live before = (live after, minus what lhs binds) + rhs vars;
           ls2 is what stays live across the let, which must not share
           registers with a callee's internal variables *)
        val (ls,tt) = live t
        val vs = (filter ok_var (free_vars lhs))
        val ls2 = diff ls vs
        val ls = ls2 @ (filter ok_var (free_vars rhs))
        val ii = get_internal_vars rhs
        val t = if ii = [] then add_live_set ls (FUN_LET (lhs,rhs,tt)) else
                  add_live_set ls (add_live_set2 ls2 ii (FUN_LET (lhs,rhs,tt)))
        in (ls,t) end
  (* gather the (ls1, ls2) pairs of all annotations *)
  fun collect (FUN_VAL tm) = []
    | collect (FUN_IF (tm,x1,x2)) = collect x1 @ collect x2
    | collect (FUN_LET (lhs,rhs,t)) = collect t
    | collect (FUN_COND (tm,t)) = let
         val f = fst o listSyntax.dest_list
         val (x1,x2) = dest_eq tm
         in (f x1, f x2) :: collect t end
  val live_sets = append_lists (map (fn (f,t) => (collect (snd (live t)))) ts)
  (* do y and z clash according to any (ls1, ls2) pair? *)
  fun clash [] y z = false
    | clash ((x1,x2)::xs) y z =
        (mem y x1 andalso mem z x1) orelse
        (mem y x1 andalso mem z x2) orelse
        (mem y x2 andalso mem z x1) orelse clash xs y z
  val all_vars = all_distinct (append_lists (map fst live_sets))
  val graph = map (fn v => (v,filter (clash live_sets v) all_vars)) all_vars
  (* a variable never clashes with itself *)
  val graph = map (fn (v,cs) => (v,filter (fn y => not (y = v)) cs)) graph
  in graph end
332
(* move_assignments ts graph: every copy "let x = y in ..." (x and y both
   variables) in the bodies of ts, as (x,y) pairs in body order.  These are
   the candidates for coalescing.  graph is currently not consulted. *)
fun move_assignments ts graph = let
  fun copies (FUN_VAL _) = []
    | copies (FUN_COND (_,rest)) = copies rest
    | copies (FUN_IF (_,yes,no)) = copies yes @ copies no
    | copies (FUN_LET (lhs,rhs,rest)) =
        (if is_var lhs andalso is_var rhs then [(lhs,rhs)] else []) @ copies rest
  in List.concat (map (copies o snd) ts) end;
341
(* iterated_register_coalescing implements the iterated register coalescing algorithm of George and Appel (1996) *)
(* iterated_register_coalescing graph moves freq is_colourable n
     graph         - clash graph as an adjacency list (vertex, neighbours)
     moves         - (x,y) copy pairs; coalescing x and y removes the copy
     freq          - (variable, use/def count) list; simplification takes the
                     busiest eligible vertex first, spilling the least busy
     is_colourable - vertices for which this is false are never simplified,
                     spilled or merged away; they keep their own names
     n             - number of registers r0 .. r(n-1)
   Returns a function mapping each variable to its register rK or stack slot
   sK (other variables map to themselves).  Fails with "no more registers"
   if some vertex cannot be given a register, and fails after printing a
   message if the colouring gives two clashing vertices the same colour. *)
fun iterated_register_coalescing graph moves freq is_colourable n = let
  fun kk n = if n < 0 then [] else n::kk(n-1)
  (* the available registers r0, r1, ..., r(n-1) *)
  val regs = map (fn n => mk_var("r" ^ (int_to_string n),``:word32``)) (rev (kk (n-1)))
  (* register vertices already in the graph count as coloured from the start
     (the colouring below starts from the identity) *)
  val r = map fst (filter (fn (x,xs) => mem x regs) graph)
  fun move_related [] = []
    | move_related ((x,y)::moves) = x::y::move_related moves
  (* debugging aid: print an adjacency list *)
  fun print_graph graph =
    (map (fn (v,ns) => (print "\n  "; print_term v; print ":";
          map (fn x => (print " "; print_term x)) ns)) graph; print "\n")
  (* follow the recorded merges from x to its representative vertex *)
  fun join_all joined x = join_all joined (list_find x joined) handle HOL_ERR _ => x
  (* merge y into x: union their neighbours, redirect edges and moves from y
     to x, drop the move (x,y), and record y -> x in joined *)
  fun merge_vertexes x y (graph,moves,joined) = let
    val xs = filter (fn v => not (v = x) andalso not (v = y)) (list_find x graph)
    val ys = filter (fn v => not (v = x) andalso not (v = y)) (list_find y graph)
    val graph = filter (fn (v,ns) => not (v = x) andalso not (v = y)) graph
    val graph = map (fn (v,ns) => (v,all_distinct (map (fn n => if n = y then x else n) ns))) graph
    val graph = (x,all_distinct (xs @ ys)) :: graph
    val moves = filter (fn z => not (z = (x,y))) moves
    val moves = map (fn (z1,z2) => (if z1 = y then x else z1,if z2 = y then x else z2)) moves
    val joined = (y,x)::joined
    in (graph,moves,joined) end;
  (* remove vertex w together with its edges and moves *)
  fun delete_vertex w (graph,moves,joined) = let
    val graph = filter (fn (v,ns) => not (v = w)) graph
    val graph = map (fn (v,ns) => (v,filter (fn n => not (n = w)) ns)) graph
    val moves = filter (fn (x,y) => not (x = w) andalso not (y = w)) moves
    in (graph,moves,joined) end;
  fun busy w = list_find w freq handle HOL_ERR _ => 0
  (* progress trace; note that this does print *)
  fun no_print s = print ("    " ^ s ^ "\n")
  (* Try, in order: simplify, coalesce, freeze, spill.  A step that does not
     apply fails with HOL_ERR and falls through to the next one.  When none
     applies, the stack is returned with the most recently removed vertex
     first.  That is the order in which vertices must be coloured: a vertex
     removed with fewer than n neighbours then meets at most n-1 coloured
     neighbours.  (Colouring in removal order can run out of registers even
     on a 2-colourable graph with n = 2.) *)
  fun build_stack graph moves joined n result =
    (* simplification: ?w. ~(w IN ms) and degree w < n, then remove from graph *) let
    (* val _ = no_print_graph graph *)
    val ms = move_related moves
    val not_ms_graph = filter (fn (v,neighbours) => not (mem v ms)) graph
    val ws = map fst (filter (fn (v,ns) => length ns < n) not_ms_graph)
    val ws = filter is_colourable ws
    val ws = sort (fn x => fn y => busy x >= busy y) ws
    val w = first (K true) ws (* select most busy *)
    val (graph,moves,joined) = delete_vertex w (graph,moves,joined)
    val _ = no_print ("!" ^ term_to_string w ^ " ")
    in build_stack graph moves joined n ((w,"r")::result) end handle HOL_ERR _ =>
    (* coalescing: ?x y. (x,y) IN moves and degree (x UNION y) < n, then combine x,y *) let
    fun combined_degree (x,y) = length (all_distinct (list_find x graph @ list_find y graph))
                                handle HOL_ERR _ => n+1000
    val moves2 = filter (fn (x,y) => not (mem x (list_find y graph))) moves
    val moves2 = filter (fn (x,y) => combined_degree (x,y) < n) moves2
    val moves2 = sort (fn (x1,x2) => fn (y1,y2) => busy x1 + busy x2 >= busy y1 + busy y2) moves2
    val moves2 = filter (fn (x,y) => is_colourable x orelse is_colourable y) moves2
    val moves2 = filter (fn (x,y) => not (x = y)) moves2
    val (x,y) = first (fn (x,y) => true) moves2
    val (x,y) = if is_colourable y then (x,y) else (y,x)
    val (graph,moves,joined) = merge_vertexes x y (graph,moves,joined)
    val _ = no_print (term_to_string x ^ "<--" ^ term_to_string y ^ " ")
    in build_stack graph moves joined n result end handle HOL_ERR _ =>
    (* freezing: removing the move property from an edge *) let
    val ((x,y),moves) = if moves = [] then fail () else (hd moves,tl moves)
    val _ = no_print (term_to_string x ^ "-/-" ^ term_to_string y ^ " ")
    in build_stack graph moves joined n result end handle HOL_ERR _ =>
    (* spilling: select a vertex and spill it *) let
    val ws = map fst graph
    val ws = filter is_colourable ws
    val ws = sort (fn x => fn y => busy x <= busy y) ws
    val w = if ws = [] then fail () else hd ws (* select least busy *)
    val (graph,moves,joined) = delete_vertex w (graph,moves,joined)
    val _ = no_print ("^" ^ term_to_string w ^ " ")
    in build_stack graph moves joined n ((w,"s")::result) end handle HOL_ERR _ =>
    (result, joined)
  val (stack,joined) = build_stack graph moves [] n []
  val coalesced = join_all joined
  fun update x y z i = if x = i then y else z i
  (* Choose, among options, the register that puts the most move pairs in
     the same colour; fall back to the first option if scoring fails.  The
     inner part is parenthesised so that the Empty handler also covers an
     empty options list reaching hd inside the scoring code. *)
  fun select_colour x options c =
    (let
       fun score c = foldr (op +) 0 (map (fn (x,y) => if c x = c y then 1 else 0) moves)
       val xs = map (fn p => (p,score (update x p c))) options
       val result = fst (hd (sort (fn (_,x) => fn (_,y) => y <= x) xs))
     in result end
     handle HOL_ERR _ => hd options)
    handle Empty => failwith "no more registers"
  (* Colour the stack in order.  c is the colouring so far and r the
     vertices coloured so far.  An "r" vertex gets a register unused by its
     coloured neighbours (neighbours of every vertex merged into it).  An
     "s" vertex gets the first stack slot sK unused by them. *)
  fun colour [] (c,r) = c
    | colour ((x,ty)::stack) (c,r) =
        if ty = "r" then let
          val qs = map snd (filter (fn (v,ns) => coalesced v = x) graph)
          val qs = map coalesced (append_lists qs)
          val zs = filter (fn z => mem z r) qs
          val zs = map c zs
          val new_colour = select_colour x (diff regs zs) c
          in colour stack (update x new_colour c, x::r) end
        else let
          val qs = map snd (filter (fn (v,ns) => coalesced v = x) graph)
          val qs = map coalesced (append_lists qs)
          val zs = filter (fn z => mem z r) qs
          val zs = map c zs
          fun next_stack i = let
            val z = mk_var("s" ^ int_to_string i,``:word32``)
            in if mem z zs then next_stack (i+1) else z end
          val z = next_stack 0
          in colour stack (update x z c, x::r) end
  val colouring = colour stack (I,r) o join_all joined
  (* check validity of colouring *)
  val g = map (fn (v,ns) => (colouring v, map colouring ns)) graph
  val _ = if filter (fn (x,xs) => mem x xs) g = [] then ()
          else (print "\n\nRegister allocator produced invalid result.\n\n"; fail())
  in (colouring) end
445
(* Provide a list giving, for each variable, an estimate of how often it is used or defined;
   uses/defs inside loops are multiplied by the constant 16 for each level of loop nesting. *)
(* frequency ts: (variable, count) for every variable of the bodies in ts
   other than the function names.  A count is the number of ftree nodes
   whose guard, let pattern, let right-hand side or returned value mentions
   the variable.  Counting starts from the body of the last equation in ts.
   A let that calls a function of ts also counts that function's body,
   multiplied by 16 when the callee loops. *)
fun frequency ts = let
  val fs = map (car o fst) ts
  (* A body loops when one of its leaves is a tail call to one of the
     compiled functions.  A leaf that just returns a tuple, a single
     variable or () is not a call. *)
  fun is_rec (FUN_VAL tm) = is_comb tm andalso mem (car tm) fs
    | is_rec (FUN_COND (tm,t)) = is_rec t
    | is_rec (FUN_IF (tm,x1,x2)) = is_rec x1 orelse is_rec x2
    | is_rec (FUN_LET (lhs,rhs,t)) = is_rec t
  val vs = all_distinct (append_lists (map (ftree_free_vars o snd) ts))
  val vs = diff vs fs
  fun occ v tm s = s + (if mem v (free_vars tm) then 1 else 0)
  fun count v (FUN_VAL tm) s = occ v tm s
    | count v (FUN_COND (tm,t)) s = count v t (occ v tm s)
    | count v (FUN_IF (tm,x1,x2)) s = count v x1 (count v x2 (occ v tm s))
    | count v (FUN_LET (lhs,rhs,t2)) s = let
        (* rhs equal to the lhs of an equation in ts: count inside the
           callee too; FUN_VAL T (no free variables) stands for "no callee" *)
        val t = (list_find rhs ts handle HOL_ERR _ => FUN_VAL T)
        val inner_s = count v t 0
        val inner_s = if is_rec t then inner_s * 16 else inner_s
        in count v t2 (occ v lhs (occ v rhs (inner_s + s))) end
  val freq = map (fn v => (v,count v (snd (last ts)) 0)) vs
  in freq end;
467
(* REMOVE_REFL_LET_CONV: expand a self-copy "let v = v in body" to body.
   Fails (HOL_ERR) on any other term. *)
fun REMOVE_REFL_LET_CONV tm = let
  val (abstraction, bound_value) = dest_let tm
  val (bound_var, _) = dest_abs abstraction
  val is_self_copy = (bound_var = bound_value)
  in (if is_self_copy then EXPAND_LET_CONV else NO_CONV) tm end
  handle HOL_ERR _ => NO_CONV tm;
473
(* REMOVE_DEAD_LET_CONV: expand "let v = e in body" when v does not occur
   free in body, i.e. drop a binding nobody reads.  Fails (HOL_ERR) on any
   other term. *)
fun REMOVE_DEAD_LET_CONV tm = let
  val (abstraction, _) = dest_let tm
  val (bound_var, body) = dest_abs abstraction
  val is_dead = not (mem bound_var (free_vars body))
  in (if is_dead then EXPAND_LET_CONV else NO_CONV) tm end
  handle HOL_ERR _ => NO_CONV tm;
479
(* REG_ALLOC_CONV n tm: register-allocate the conjunction of (possibly
   universally quantified) equations tm with n registers.  It builds the
   clash graph, move list and use frequencies, colours the "tN" variables
   with iterated_register_coalescing, and then rewrites tm bottom-up in
   three passes:
     1. drop dead lets;
     2. alpha-rename every bound variable to its colour;
     3. drop the self-copies "let r = r" this leaves behind. *)
fun REG_ALLOC_CONV n tm = let
  val equations = map (repeat (snd o dest_forall)) (list_dest dest_conj tm)
  val ts = map (fn eq => (cdr (car eq), tm2ftree (cdr eq))) equations
  val graph = clash_graph ts
  val moves = move_assignments ts graph
  val freq = frequency ts
  val colouring = iterated_register_coalescing graph moves freq is_t_var n
  fun rename_binder t =
    ALPHA_CONV (colouring (fst (dest_abs t))) t handle HOL_ERR _ => NO_CONV t
  in (BOTTOM_UP_CONV REMOVE_DEAD_LET_CONV
      THENC BOTTOM_UP_CONV rename_binder
      THENC BOTTOM_UP_CONV REMOVE_REFL_LET_CONV) tm end;
494
(* initial_clean tm: tm is a conjunction of equations "f args = body".
   Returns |- tm' ==> tm.  tm' quantifies each equation universally over its
   free variables, except the function names (the heads of the left-hand
   sides). *)
fun initial_clean tm = let
  val equations = list_dest dest_conj tm
  val fnames = map (fn eq => repeat car (fst (dest_eq eq))) equations
  fun close eq = list_mk_forall (diff (free_vars eq) fnames, eq)
  val goal = mk_imp (list_mk_conj (map close equations), tm)
  in auto_prove "initial_clean" (goal,
       ONCE_REWRITE_TAC [EQ_SYM_EQ] THEN SIMP_TAC bool_ss []
       THEN ONCE_REWRITE_TAC [EQ_SYM_EQ] THEN SIMP_TAC bool_ss []) end;
507
(* allocate_registers n input_tm: input_tm is a conjunction of equations
   "f args = body".  It starts from |- closed ==> input_tm (initial_clean)
   and rewrites the antecedent with the compilation pipeline: unabbreviate,
   flatten expressions, SSA renaming, common-subexpression elimination,
   fixed call/return tuples, and register allocation with n registers. *)
fun allocate_registers n input_tm = let
  val imp = initial_clean input_tm
  val pipeline = COMPILER_UNABBREV_CONV
                 THENC FLATTEN_EXPS_CONV
                 THENC SSA_CONV THENC COMMON_SUBEXP_CONV
                 THENC FIX_CALL_RETURN_VALUES_CONV
                 THENC REG_ALLOC_CONV n
  in CONV_RULE (RATOR_CONV (RAND_CONV pipeline)) imp end
520
521
522
523(*
524  for x86: 1. split binary ops into two parts
525                let x = y ?? z in  --> let x = y in let x = x ?? z in
526              this might lead the reg allocator to coalesce x and y,
527              alternatively augment 'moves' to have artificial (x,y) edge
528              but many of these are commutative, should there be (x,[y,z]) edge?
529           2. assume infinite number of regs, make
530                regs 5,6,7,etc. --> stack locations 0,1,2,3,etc.
531              reserve one register for loading when two stack locations are
532              used in the same instruction
533*)
534
535(*
536
537val n = 3
538val pref_list = [4,3,2,1]
539val k = length pref_list
540fun t_vars n = if n = 0 then [] else mk_t_var(``:word32``)::t_vars (n-1)
541val qs = t_vars k
542fun cross xs ys = append_lists (map (fn x => map (fn y => (x,y)) ys) xs)
543fun rest [] = []
544  | rest ((x,y)::xs) = (x,y)::rest (filter (fn z => not (z = (y,x))) xs)
545val edges = rest (filter (fn (x,y) => not (x = y)) (cross qs qs))
546val max = foldr (op * ) 1 (map (K 2) edges)
547val max2 = max * max
548val freq = zip qs pref_list
549val is_colourable = is_t_var
550
551fun get_graph i = let
552  fun n_filter i [] = []
553    | n_filter i (x::xs) =
554        if i mod 2 = 0 then n_filter (i div 2) xs
555        else x :: n_filter (i div 2) xs
556  fun adj vs edges = map (fn v => (v,map snd (filter (fn x => fst x = v) edges))) vs
557  val ts = n_filter i edges
558  val ts = ts @ map (fn (x,y) => (y,x)) ts
559  val moves = n_filter (i div max) edges
560  in (adj qs ts, moves) end;
561
562fun try_inst i = let
563  val (graph,moves) = get_graph i
564  val _ = print (int_to_string i)
565  val _ = print "/"
566  val _ = print (int_to_string max2)
567  val _ = print " "
568  val ok = (iterated_register_coalescing graph moves freq is_colourable n; true)
569           handle HOL_ERR _ => false
570  val _ = print "\n"
571  in if not ok then print ("\n\nFailed at "^int_to_string i^".\n\n") else
572     if i < max2 then try_inst (i+1) else print "\n\nDone!\n\n" end;
573
574val _ = try_inst 0
575
576*)
577
578end;
579