structure reg_allocLib :> reg_allocLib =
struct

open HolKernel boolLib bossLib Parse;
open wordsTheory codegen_inputLib helperLib;

(* Abbreviations expanded by COMPILER_UNABBREV_CONV before compilation.
   The initial table is built from WORD_MUL_LSL instantiated at w:word32
   and each shift amount 1..31, with the right-hand side evaluated by EVAL
   and the equation turned around by GSYM; every theorem is included in
   both operand orders via WORD_MULT_COMM.  Presumably this rewrites a
   multiplication by a power-of-two constant into a left shift; confirm
   against the exact statement of WORD_MUL_LSL. *)
local
  val compiler_abbrevs = ref let
    fun kk i = if i < 32 then i::kk(i+1) else []
    val ys = map (numSyntax.mk_numeral o Arbnum.fromInt) (kk 1)
    val ys = map (fn x => ISPECL [mk_var("w",``:word32``),x] WORD_MUL_LSL) ys
    val ys = map (GSYM o (CONV_RULE (RAND_CONV EVAL))) ys
    val ys = map (ONCE_REWRITE_RULE [WORD_MULT_COMM]) ys @ ys
    in ys end
in
  (* Prepend thms to the abbreviation table. *)
  fun add_abbrevs thms = (compiler_abbrevs := thms @ (!compiler_abbrevs))
  (* Rewrite tm with every theorem currently in the abbreviation table. *)
  fun COMPILER_UNABBREV_CONV tm = REWRITE_CONV (!compiler_abbrevs) tm
end;

(* Fresh variable generators.  Temporaries introduced by this library are
   named _1, _2, ...; the variables the register allocator may rename are
   named t1, t2, ... (REG_ALLOC_CONV uses is_t_var as its colourability
   test).  Both counters are global and never reset.  Note that the
   recognisers only look at the first character of the name, so any
   variable whose name starts with an underscore or with t is accepted. *)

val get_temp_name = let
  val n = ref 0
  in (fn () => (n := (!n) + 1; "_" ^ int_to_string (!n))) end
fun mk_temp_var ty = mk_var(get_temp_name(),ty)
fun is_temp_var v = String.isPrefix "_" (fst (dest_var v)) handle HOL_ERR _ => false

val get_t_name = let
  val n = ref 0
  in (fn () => (n := (!n) + 1; "t" ^ int_to_string (!n))) end
fun mk_t_var ty = mk_var(get_t_name(),ty)
fun is_t_var v = String.isPrefix "t" (fst (dest_var v)) handle HOL_ERR _ => false


(* various helpers *)

(* Remove duplicates, keeping the first occurrence of each element. *)
fun all_distinct [] = []
  | all_distinct (x::xs) = x :: all_distinct (filter (fn y => not (x = y)) xs)

(* Concatenate a list of lists. *)
fun append_lists [] = []
  | append_lists (y::ys) = y @ append_lists ys

(* Elements of xs that are not in ys, in the order of xs. *)
fun diff xs ys = filter (fn y => not (mem y ys)) xs
(* Elements of ys that are also in xs, in the order of ys. *)
fun intersect xs ys = filter (fn y => mem y xs) ys

(* Split a right-nested pair term into its components; a non-pair term
   yields a singleton list. *)
fun dest_tuple tm =
  let val (x,y) = pairSyntax.dest_pair tm in x :: dest_tuple y end handle HOL_ERR e => [tm];

(* Association-list lookup: the value paired with the first key equal to x.
   Raises HOL_ERR (via fail) when x is absent. *)
fun list_find x [] = fail()
  | list_find x ((y,z)::zs) = if x = y then z else list_find x zs

(* Unfold LET_DEF at the head of a let-term and beta-reduce, turning
   LET (\v. body) e into body with e substituted for v. *)
val EXPAND_LET_CONV =
  (RATOR_CONV o RATOR_CONV) (ONCE_REWRITE_CONV [LET_DEF]) THENC
  RATOR_CONV BETA_CONV THENC BETA_CONV THENC BETA_CONV

(* Build a right-nested tuple; the empty list gives the unit value. *)
fun mk_tuple [] = ``()``
  | mk_tuple [x] = x
  | mk_tuple (x::xs) = pairSyntax.mk_pair(x,mk_tuple xs)


(* Generic traversals.  BOTTOM_UP_CONV rewrites the children of a term
   (operand first, then operator, or the body of an abstraction) before
   trying c at the term itself.  TOP_DOWN_CONV tries c first and then
   descends into the result.  In both, failures of c are ignored. *)

fun BOTTOM_UP_CONV c tm =
  case dest_term tm of
    COMB _ => (RAND_CONV (BOTTOM_UP_CONV c) THENC
               RATOR_CONV (BOTTOM_UP_CONV c) THENC
               TRY_CONV c) tm
  | LAMB _ => (ABS_CONV (BOTTOM_UP_CONV c) THENC
               TRY_CONV c) tm
  | _ => (TRY_CONV c) tm

fun TOP_DOWN_CONV c tm =
  (TRY_CONV c THENC (fn tm =>
     case dest_term tm of
       COMB _ => (RAND_CONV (TOP_DOWN_CONV c) THENC RATOR_CONV (TOP_DOWN_CONV c)) tm
     | LAMB _ => (ABS_CONV (TOP_DOWN_CONV c)) tm
     | _ => ALL_CONV tm)) tm


(* This conversion flattens large expressions into compilable assignments.

   tm is a conjunction of (possibly universally quantified) equations.  For
   the right-hand side of each one, every returned value, let-bound
   expression and if-guard is broken into a sequence of let-bindings of
   subterms that term2assign (for values) or term2guard (for guards) from
   codegen_inputLib accept.  The result is proved equal to the input. *)
fun FLATTEN_EXPS_CONV tm = let
  (* Can tm be compiled as one assignment?  All word32 free variables are
     renamed to r0 and the result is offered to term2assign with target r1.
     The case expression is parenthesised so that the handler covers
     exceptions raised by term2assign itself. *)
  fun is_compilable tm = let
    val vs = filter (fn x => type_of x = ``:word32``) (free_vars tm)
    val r0 = mk_var("r0",``:word32``)
    val tm = subst (map (fn x => x |-> r0) vs) tm
    val result = (case term2assign ``r1:word32`` tm of
                    ASSIGN_OTHER _ => false
                  | _ => true)
                 handle HOL_ERR _ => false | Empty => false
    in result end handle HOL_ERR _ => false
  (* The same test for guards, via term2guard. *)
  fun is_c_guard tm = let
    val vs = filter (fn x => type_of x = ``:word32``) (free_vars tm)
    val r0 = mk_var("r0",``:word32``)
    val tm = subst (map (fn x => x |-> r0) vs) tm
    val result = (case term2guard tm of
                    GUARD_OTHER _ => false
                  | GUARD_NOT (GUARD_OTHER _) => false
                  | _ => true)
                 handle HOL_ERR _ => false | Empty => false
    in result end handle HOL_ERR _ => false
  (* Until g accepts rhs, pull out a compilable non-variable subterm t and
     bind it.  A word32 t gets a fresh temporary; other types reuse a
     variable of the same type that occurs in t, if there is one.  Stops,
     keeping what was done so far, when no such subterm exists. *)
  fun divide_aux g (xs,rhs) = if g rhs then (xs,rhs) else let
    val t = find_term (fn x => is_compilable x andalso not (is_var x)) rhs
    val ty = type_of t
    val temp = mk_temp_var ty
    val temp = if ty = ``:word32`` then temp else
                 find_term (fn v => is_var v andalso (ty = type_of v)) t
                 handle HOL_ERR _ => temp
    in divide_aux g (xs @ [(temp,t)], subst [t |-> temp] rhs) end
    handle HOL_ERR _ => (xs,rhs)
  (* Stable partition: elements satisfying p first. *)
  fun partition p xs = filter p xs @ filter (not o p) xs
  (* divide_aux, followed by moving the word32 bindings to the front. *)
  fun divide g (xs,rhs) = let
    val (xs,rhs) = divide_aux g (xs,rhs)
    val xs = partition (fn x => type_of (fst x) = ``:word32``) xs
    in (xs,rhs) end
  fun CONJUNCTS_CONV c tm =
    if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  fun FORALL_CONV c tm =
    if is_forall tm then QUANT_CONV (FORALL_CONV c) tm else c tm
  (* Apply a conversion to the right-hand side of every equation. *)
  val FUNC_BODY_CONV = CONJUNCTS_CONV o FORALL_CONV o RAND_CONV
  (* Flatten one right-hand side and prove the result equal to it by
     expanding the introduced temporary lets again. *)
  fun FLAT_CONV tm = let
    val f = tm2ftree tm
    fun lets ([],y) = y
      | lets ((x1,x2)::xs,y) = FUN_LET (x1,x2,lets (xs,y))
    fun ftree_each (FUN_VAL rhs) = let
          val (xs,rhs2) = divide is_compilable ([],rhs)
          in lets (xs,FUN_VAL rhs2) end
      | ftree_each (FUN_LET (lhs,rhs,t)) = let
          val (xs,rhs2) = divide is_compilable ([],rhs)
          in lets (xs,FUN_LET (lhs,rhs2,ftree_each t)) end
      | ftree_each (FUN_IF (b,t1,t2)) = let
          val (xs,b2) = divide is_c_guard ([],b)
          in lets (xs,FUN_IF (b2,ftree_each t1,ftree_each t2)) end
      | ftree_each (FUN_COND (b,t)) = FUN_COND (b,ftree_each t)
    val tm2 = ftree2tm (ftree_each f)
    fun EXPAND_TEMPVARLET_CONV tm = let
      val (v,x) = dest_abs (fst (dest_let tm))
      in if is_temp_var v then EXPAND_LET_CONV tm else NO_CONV tm end
      handle HOL_ERR _ => NO_CONV tm
    val goal = mk_eq(tm,tm2)
    val result = auto_prove "FLAT_CONV" (goal,
      CONV_TAC (BOTTOM_UP_CONV EXPAND_TEMPVARLET_CONV) THEN REWRITE_TAC [])
    in result end
  val result = FUNC_BODY_CONV FLAT_CONV tm
  in result end;


(* Translation into SSA form, at least for word32 variables other than
   the fixed registers r0, r1, ... and s0, s1, ... *)

(* True for a word32 variable whose name is NOT r or s followed only by
   digits and primes; false for any other term. *)
fun not_fixed_reg v = let
  val (name,ty) = dest_var v
  val ii = explode name
  val reg = mem (hd ii) [#"r",#"s"] andalso
            (filter (fn x => not (mem x [#"0",#"1",#"2",#"3",#"4",#"5",#"6",#"7",#"8",#"9",#"'"])) (tl ii) = [])
  in (ty = ``:word32``) andalso not reg end
  handle HOL_ERR _ => false

(* Alpha-rename every bound variable (any abstraction, hence let- and
   quantifier-bound ones too) that satisfies not_fixed_reg to a fresh
   t-variable. *)
val SSA_CONV = let
  fun rename tm = let
    val (v,x) = dest_abs tm
    in if not_fixed_reg v then ALPHA_CONV (mk_t_var(type_of v)) tm
       else NO_CONV tm end
  in BOTTOM_UP_CONV rename end

(* Common subexpression elimination.  For  let v = y in x  where v is not a
   fixed register and y occurs in x, every occurrence of y in x is replaced
   by v.  Afterwards, lets that merely copy one variable into a non-fixed
   variable are inlined. *)
val COMMON_SUBEXP_CONV = let
  fun aux tm = let
    val (x,y) = dest_let tm
    val (v,x) = dest_abs x
    val _ = dest_var v
    val _ = if not_fixed_reg v then () else fail()
    val _ = find_term (fn x => x = y) x
    val x2 = subst [y|->v] x
    val tm2 = mk_let(mk_abs(v,x2),y)
    val goal = mk_eq(tm,tm2)
    val thi = auto_prove "" (goal,
      CONV_TAC (BINOP_CONV EXPAND_LET_CONV) THEN REWRITE_TAC [])
    fun DELETE_EXTRA_MOVE_CONV tm = let
      val (x,y) = dest_let tm
      val (v,x) = dest_abs x
      val _ = dest_var v
      val _ = dest_var y
      val _ = if not_fixed_reg v then () else fail()
      in EXPAND_LET_CONV tm end
    in ((fn tm => thi) THENC BOTTOM_UP_CONV DELETE_EXTRA_MOVE_CONV) tm end
  in TOP_DOWN_CONV aux end


(* Restrict register names. *)

(* Parallel assignment tm2 := tm.  Both arguments must be tuples of
   variables.  The result is a list of (lhs,rhs) variable assignments that,
   executed in order, leave in each component of tm2 the value the matching
   component of tm had before.  Components that already coincide are
   skipped.  Values pass through fresh temporaries, which are then
   optimised away where possible. *)
fun parallel_assign tm2 tm = let
  (* make basic parallel assignment *)
  val (m,_) = match_term tm tm2
  val xs = filter (fn x => not (x = subst m x)) (dest_tuple tm)
  val vs = map (fn x => mk_temp_var (type_of x)) xs
  val rs = zip vs xs @ zip (map (subst m) xs) vs
  (* optimise: copy forward.  aux records known copies lhs = rhs; a copy is
     forgotten once a variable in its rhs is overwritten. *)
  fun forward [] aux = []
    | forward ((x,y)::xs) aux = let
        val y = list_find y aux handle HOL_ERR _ => y
        val aux = filter (fn (lhs,rhs) => not (mem x (free_vars rhs))) aux
        in (x,y) :: forward xs ((x,y)::aux) end
  val rs = forward rs []
  (* optimise: remove temporaries that are never read afterwards *)
  fun is_used x [] = not (is_temp_var x)
    | is_used x ((y,z)::xs) = if mem x (free_vars z) then true else is_used x xs
  fun in_tail [] = []
    | in_tail ((x,y)::xs) = if is_used x xs then (x,y)::in_tail xs else in_tail xs
  val rs = in_tail rs
  in rs end;

(* Give every function one fixed calling convention.  For each defined
   function f, its arguments are the formal parameter tuple of its
   definition, and its return tuple is the first leaf of its body that is
   not a call of the form f args, with t-variables replaced by fresh ones.
   Each definition is then rewritten so that calls pass their actual
   arguments through the callee's formal parameters, results come back
   through the callee's return tuple, and non-call leaves return through
   the function's own return tuple.  Parallel assignments are inserted
   where needed.  Each rewrite is proved by SIMP_TAC with LET_DEF. *)
fun FIX_CALL_RETURN_VALUES_CONV tm = let
  (* find one return value for each function *)
  fun in_out x = let
    val (lhs,rhs) = dest_eq x
    fun leaves (FUN_COND (_,t)) = leaves t
      | leaves (FUN_LET (_,_,t)) = leaves t
      | leaves (FUN_IF (_,t1,t2)) = leaves t1 @ leaves t2
      | leaves (FUN_VAL tm) = [tm]
    val bases = filter (not o can (match_term lhs)) (leaves (tm2ftree rhs))
    in (car lhs, (cdr lhs, hd bases)) end
  val xs = map (repeat (snd o dest_forall)) (list_dest dest_conj tm)
  val io = map in_out xs
  (* invent new temporaries for each return value *)
  fun invent_new_temps (x,(y,z)) = let
    val f = map (fn z => if is_t_var z then mk_t_var(type_of z) else z)
    in (x,(y,mk_tuple (f (dest_tuple z)))) end
  val io = map invent_new_temps io
  (* add restrictions on already compiled components *)
  (* ... *)
  (* make sure all function calls/returns respect this io restriction *)
  fun CONJUNCTS_CONV c tm =
    if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  fun FORALL_CONV c tm =
    if is_forall tm then QUANT_CONV (FORALL_CONV c) tm else c tm
  val FUNC_BODY_CONV = CONJUNCTS_CONV o FORALL_CONV
  (* tm is one whole equation  f args = body. *)
  fun FLAT_CONV tm = let
    val func_tm = (car o fst o dest_eq) tm
    val f = tm2ftree (cdr tm)
    fun lets [] y = y
      | lets ((x1,x2)::xs) y = FUN_LET (x1,x2,lets xs y)
    fun ftree_each (FUN_IF (b,t1,t2)) = FUN_IF (b,ftree_each t1,ftree_each t2)
      | ftree_each (FUN_COND (b,t)) = FUN_COND (b,ftree_each t)
      | ftree_each (FUN_VAL rhs) = let
          (* a tail call of f itself, or a returned value *)
          val call = (car rhs = func_tm) handle HOL_ERR _ => false
          val x = (if call then fst else snd) (list_find func_tm io)
          val rhs2 = if call then cdr rhs else rhs
          val rs1 = parallel_assign x rhs2
          val ret = if call then mk_comb(func_tm,x) else x
          in lets rs1 (FUN_VAL ret) end
      | ftree_each (FUN_LET (lhs,rhs,t)) = let
          (* a call of a known function; any other let is left alone *)
          val (x,y) = list_find (car rhs) io
          val rs1 = parallel_assign x (cdr rhs)
          val rs2 = parallel_assign lhs y
          in lets rs1 (FUN_LET (y,mk_comb(car rhs,x),lets rs2 (ftree_each t))) end
          handle HOL_ERR _ => FUN_LET (lhs,rhs,ftree_each t)
    val tm2 = ftree2tm (ftree_each f)
    val goal = mk_eq(tm,mk_eq((fst o dest_eq) tm,tm2))
    val result = auto_prove "FLAT_CONV" (goal,SIMP_TAC std_ss [LET_DEF])
    in result end
  val result = FUNC_BODY_CONV FLAT_CONV tm
  in result end;


(* Clash graph and register allocation. *)

(* All free variables occurring anywhere in an ftree, without duplicates. *)
fun ftree_free_vars t = let
  fun vars (FUN_VAL tm) = free_vars tm
    | vars (FUN_COND (tm,t)) = free_vars tm @ vars t
    | vars (FUN_IF (tm,x1,x2)) = all_distinct (free_vars tm @ vars x1 @ vars x2)
    | vars (FUN_LET (lhs,rhs,t)) = all_distinct (free_vars lhs @ free_vars rhs @ vars t)
  in all_distinct (vars t) end;

(* Variables used inside a function body that are neither among its
   arguments (free variables of cdr tm) nor free in any returned leaf. *)
fun subroutine_internal_vars (tm,t) = let
  val vs = free_vars (cdr tm)
  fun leaves (FUN_COND (_,t)) = leaves t
    | leaves (FUN_LET (_,_,t)) = leaves t
    | leaves (FUN_IF (_,t1,t2)) = leaves t1 @ leaves t2
    | leaves (FUN_VAL tm) = [tm]
  val xs = append_lists (map free_vars (leaves t))
  in diff (ftree_free_vars t) (vs @ xs) end

(* Build the interference graph of the word32 variables of all functions in
   ts, a list of (f args, body ftree) pairs.  Bodies must not contain
   FUN_COND nodes.  A backwards liveness pass annotates each node with its
   live set, temporarily encoded as a FUN_COND node holding an equation
   between two word32 lists.  The second list is non-empty only for calls
   to functions in ts: it pairs the variables live after the call with the
   callee's internal variables.  Two variables clash if they are in the
   same live set, or one is live across a call and the other is internal
   to the callee.  Vertices are the variables occurring in some live set.
   The result is an adjacency list of (vertex, clashing vertices). *)
fun clash_graph ts = let
  fun ok_var x = (type_of x = ``:word32``)
  fun add_live_set2 ls1 ls2 t = FUN_COND
    (mk_eq(listSyntax.mk_list(all_distinct ls1,``:word32``),
           listSyntax.mk_list(all_distinct ls2,``:word32``)),t)
  fun add_live_set ls t = add_live_set2 ls [] t
  val fs = map (car o fst) ts
  fun get_internal_vars rhs =
    if (not (mem (car rhs) fs) handle HOL_ERR _ => true) then [] else
    subroutine_internal_vars (hd (filter (fn (x,_) => x = rhs) ts))
  fun live (FUN_VAL tm) = let
        val ls = filter ok_var (free_vars tm)
        val t = add_live_set ls (FUN_VAL tm)
        in (ls,t) end
    | live (FUN_COND (tm,t)) = fail()
    | live (FUN_IF (tm,x1,x2)) = let
        val (ls1,y1) = live x1
        val (ls2,y2) = live x2
        val ls = (filter ok_var (free_vars tm)) @ ls1 @ ls2
        val t = add_live_set ls (FUN_IF (tm,y1,y2))
        in (ls,t) end
    | live (FUN_LET (lhs,rhs,t)) = let
        val (ls,tt) = live t
        val vs = (filter ok_var (free_vars lhs))
        val ls2 = diff ls vs
        val ls = ls2 @ (filter ok_var (free_vars rhs))
        val ii = get_internal_vars rhs
        val t = if ii = [] then add_live_set ls (FUN_LET (lhs,rhs,tt)) else
                  add_live_set ls (add_live_set2 ls2 ii (FUN_LET (lhs,rhs,tt)))
        in (ls,t) end
  fun collect (FUN_VAL tm) = []
    | collect (FUN_IF (tm,x1,x2)) = collect x1 @ collect x2
    | collect (FUN_LET (lhs,rhs,t)) = collect t
    | collect (FUN_COND (tm,t)) = let
        val f = fst o listSyntax.dest_list
        val (x1,x2) = dest_eq tm
        in (f x1, f x2) :: collect t end
  val live_sets = append_lists (map (fn (f,t) => (collect (snd (live t)))) ts)
  fun clash [] y z = false
    | clash ((x1,x2)::xs) y z =
        (mem y x1 andalso mem z x1) orelse
        (mem y x1 andalso mem z x2) orelse
        (mem y x2 andalso mem z x1) orelse clash xs y z
  val all_vars = all_distinct (append_lists (map fst live_sets))
  val graph = map (fn v => (v,filter (clash live_sets v) all_vars)) all_vars
  val graph = map (fn (v,cs) => (v,filter (fn y => not (y = v)) cs)) graph
  in graph end

(* All variable-to-variable lets (x,y), i.e. let x = y, in the bodies of
   ts.  These are the coalescing candidates.  The graph argument is
   currently unused; moves between variables of any type are returned. *)
fun move_assignments ts graph = let
  fun pref (FUN_COND (_,t)) = pref t
    | pref (FUN_IF (_,t1,t2)) = pref t1 @ pref t2
    | pref (FUN_VAL tm) = []
    | pref (FUN_LET (x,y,t)) =
        if is_var x andalso is_var y then (x,y)::pref t else pref t
  val moves = append_lists (map (pref o snd) ts)
  in moves end;

(* iterated_register_coalescing implements the algorithm of George and
   Appel (1996).

   graph         - adjacency list (vertex, neighbours); vertices named
                   r0 .. r(n-1) are treated as precoloured
   moves         - (x,y) move pairs that are candidates for coalescing
   freq          - (variable, weight) list; a higher weight means busier
   is_colourable - which vertices may be renamed, removed or spilled
   n             - number of registers, named r0 .. r(n-1)

   Returns a function from variables to their colour: a register ri, or a
   stack slot si for spilled vertices.  Variables the allocator did not
   handle are mapped to themselves.  Raises HOL_ERR when a vertex is left
   without a register, or when the final validity check fails. *)
fun iterated_register_coalescing graph moves freq is_colourable n = let
  fun kk n = if n < 0 then [] else n::kk(n-1)
  val regs = map (fn n => mk_var("r" ^ (int_to_string n),``:word32``)) (rev (kk (n-1)))
  (* precoloured vertices present in the graph *)
  val r = map fst (filter (fn (x,xs) => mem x regs) graph)
  fun move_related [] = []
    | move_related ((x,y)::moves) = x::y::move_related moves
  (* debugging aid, not called *)
  fun print_graph graph =
    (map (fn (v,ns) => (print "\n "; print_term v; print ":";
       map (fn x => (print " "; print_term x)) ns)) graph; print "\n")
  (* follow the chain of coalesced vertices to its representative *)
  fun join_all joined x = join_all joined (list_find x joined) handle HOL_ERR _ => x
  (* merge vertex y into vertex x *)
  fun merge_vertexes x y (graph,moves,joined) = let
    val xs = filter (fn v => not (v = x) andalso not (v = y)) (list_find x graph)
    val ys = filter (fn v => not (v = x) andalso not (v = y)) (list_find y graph)
    val graph = filter (fn (v,ns) => not (v = x) andalso not (v = y)) graph
    val graph = map (fn (v,ns) => (v,all_distinct (map (fn n => if n = y then x else n) ns))) graph
    val graph = (x,all_distinct (xs @ ys)) :: graph
    val moves = filter (fn z => not (z = (x,y))) moves
    val moves = map (fn (z1,z2) => (if z1 = y then x else z1,if z2 = y then x else z2)) moves
    val joined = (y,x)::joined
    in (graph,moves,joined) end
  fun delete_vertex w (graph,moves,joined) = let
    val graph = filter (fn (v,ns) => not (v = w)) graph
    val graph = map (fn (v,ns) => (v,filter (fn n => not (n = w)) ns)) graph
    val moves = filter (fn (x,y) => not (x = w) andalso not (y = w)) moves
    in (graph,moves,joined) end
  fun busy w = list_find w freq handle HOL_ERR _ => 0
  (* Despite its name, this prints a trace line for each decision. *)
  fun no_print s = print (" " ^ s ^ "\n")
  (* Repeatedly simplify, coalesce, freeze or spill, in that order of
     preference.  Removed vertices are pushed onto result tagged r (to
     receive a register) or s (spilled).  The head of the returned stack
     is the vertex removed last. *)
  fun build_stack graph moves joined n result =
    (* simplification: ?w. ~(w IN ms) and degree w < n, then remove from graph *) let
      val ms = move_related moves
      val not_ms_graph = filter (fn (v,neighbours) => not (mem v ms)) graph
      val ws = map fst (filter (fn (v,ns) => length ns < n) not_ms_graph)
      val ws = filter is_colourable ws
      val ws = sort (fn x => fn y => busy x >= busy y) ws
      val w = first (K true) ws (* select most busy *)
      val (graph,moves,joined) = delete_vertex w (graph,moves,joined)
      val _ = no_print ("!" ^ term_to_string w ^ " ")
      in build_stack graph moves joined n ((w,"r")::result) end handle HOL_ERR _ =>
    (* coalescing: ?x y. (x,y) IN moves and degree (x UNION y) < n, then combine x,y *) let
      fun combined_degree (x,y) = length (all_distinct (list_find x graph @ list_find y graph))
        handle HOL_ERR _ => n+1000
      (* Moves whose endpoints are not graph vertices are not candidates;
         without the handler, a single such move would abort coalescing of
         all the others. *)
      val moves2 = filter (fn (x,y) =>
                     (not (mem x (list_find y graph)) handle HOL_ERR _ => false)) moves
      val moves2 = filter (fn (x,y) => combined_degree (x,y) < n) moves2
      val moves2 = sort (fn (x1,x2) => fn (y1,y2) => busy x1 + busy x2 >= busy y1 + busy y2) moves2
      val moves2 = filter (fn (x,y) => is_colourable x orelse is_colourable y) moves2
      val moves2 = filter (fn (x,y) => not (x = y)) moves2
      val (x,y) = first (fn (x,y) => true) moves2
      (* the vertex merged away must be colourable *)
      val (x,y) = if is_colourable y then (x,y) else (y,x)
      val (graph,moves,joined) = merge_vertexes x y (graph,moves,joined)
      val _ = no_print (term_to_string x ^ "<--" ^ term_to_string y ^ " ")
      in build_stack graph moves joined n result end handle HOL_ERR _ =>
    (* freezing: removing the move property from an edge *) let
      val ((x,y),moves) = if moves = [] then fail () else (hd moves,tl moves)
      val _ = no_print (term_to_string x ^ "-/-" ^ term_to_string y ^ " ")
      in build_stack graph moves joined n result end handle HOL_ERR _ =>
    (* spilling: select a vertex and spill it *) let
      val ws = map fst graph
      val ws = filter is_colourable ws
      val ws = sort (fn x => fn y => busy x <= busy y) ws
      val w = if ws = [] then fail () else hd ws (* select least busy *)
      val (graph,moves,joined) = delete_vertex w (graph,moves,joined)
      val _ = no_print ("^" ^ term_to_string w ^ " ")
      in build_stack graph moves joined n ((w,"s")::result) end handle HOL_ERR _ =>
    (* Return the stack unreversed, so that the select phase pops vertices
       last-removed-first.  A vertex removed with fewer than n neighbours
       then sees at most those neighbours already coloured.  Colouring in
       removal order could run out of registers on colourable graphs. *)
    (result, joined)
  val (stack,joined) = build_stack graph moves [] n []
  val coalesced = join_all joined
  fun update x y z i = if x = i then y else z i
  (* Pick the register from options that makes the most moves trivial.
     Running out of options is reported as HOL_ERR. *)
  fun select_colour x options c =
    if null options then failwith "no more registers" else
    (let
       fun score c = foldr (op +) 0 (map (fn (x,y) => if c x = c y then 1 else 0) moves)
       val xs = map (fn p => (p,score (update x p c))) options
       val result = fst (hd (sort (fn (_,x) => fn (_,y) => y <= x) xs))
     in result end
     handle HOL_ERR _ => hd options)
  (* Colour the stack.  c is the colouring so far and r the vertices
     already coloured, starting with the precoloured registers. *)
  fun colour [] (c,r) = c
    | colour ((x,ty)::stack) (c,r) =
        if ty = "r" then let
          val qs = map snd (filter (fn (v,ns) => coalesced v = x) graph)
          val qs = map coalesced (append_lists qs)
          val zs = filter (fn z => mem z r) qs
          val zs = map c zs
          val new_colour = select_colour x (diff regs zs) c
          in colour stack (update x new_colour c, x::r) end
        else let
          val qs = map snd (filter (fn (v,ns) => coalesced v = x) graph)
          val qs = map coalesced (append_lists qs)
          val zs = filter (fn z => mem z r) qs
          val zs = map c zs
          (* lowest stack slot not used by a coloured neighbour *)
          fun next_stack i = let
            val z = mk_var("s" ^ int_to_string i,``:word32``)
            in if mem z zs then next_stack (i+1) else z end
          val z = next_stack 0
          in colour stack (update x z c, x::r) end
  val colouring = colour stack (I,r) o join_all joined
  (* check validity of colouring *)
  val g = map (fn (v,ns) => (colouring v, map colouring ns)) graph
  val _ = if filter (fn (x,xs) => mem x xs) g = [] then ()
          else (print "\n\nRegister allocator produced invalid result.\n\n"; fail())
  in (colouring) end

(* Provide a list giving the frequency of use/def of each variable.
   Use/defs inside a call to a function whose body ends in a leaf that is
   not a tuple (taken to be a loop) are multiplied by 16 for each level of
   loop nesting.  Counting starts from the body of the last function in ts;
   the defined function names themselves are excluded. *)
fun frequency ts = let
  fun is_rec (FUN_VAL tm) = not (pairSyntax.is_pair tm)
    | is_rec (FUN_COND (tm,t)) = is_rec t
    | is_rec (FUN_IF (tm,x1,x2)) = is_rec x1 orelse is_rec x2
    | is_rec (FUN_LET (lhs,rhs,t)) = is_rec t
  val fs = map (car o fst) ts
  val vs = all_distinct (append_lists (map (ftree_free_vars o snd) ts))
  val vs = diff vs fs
  fun occ v tm s = s + (if mem v (free_vars tm) then 1 else 0)
  fun count v (FUN_VAL tm) s = occ v tm s
    | count v (FUN_COND (tm,t)) s = count v t (occ v tm s)
    | count v (FUN_IF (tm,x1,x2)) s = count v x1 (count v x2 (occ v tm s))
    | count v (FUN_LET (lhs,rhs,t2)) s = let
        val t = (list_find rhs ts handle HOL_ERR _ => FUN_VAL T)
        val inner_s = count v t 0
        val inner_s = if is_rec t then inner_s * 16 else inner_s
        in count v t2 (occ v lhs (occ v rhs (inner_s + s))) end
  val freq = map (fn v => (v,count v (snd (last ts)) 0)) vs
  in freq end;

(* Expand  let v = v in body, a self-assignment. *)
fun REMOVE_REFL_LET_CONV tm = let
  val (x,y) = dest_let tm
  val (v,x) = dest_abs x
  in if v = y then EXPAND_LET_CONV tm else NO_CONV tm end
  handle HOL_ERR _ => NO_CONV tm;

(* Expand a let whose bound variable does not occur in its body. *)
fun REMOVE_DEAD_LET_CONV tm = let
  val (x,y) = dest_let tm
  val (v,x) = dest_abs x
  in if mem v (free_vars x) then NO_CONV tm else EXPAND_LET_CONV tm end
  handle HOL_ERR _ => NO_CONV tm;

(* Allocate n registers for the t-variables of the function definitions in
   tm.  Dead lets are removed, bound variables are alpha-renamed to their
   colours, and lets that have become self-assignments are removed. *)
fun REG_ALLOC_CONV n tm = let
  val xs = map (repeat (snd o dest_forall)) (list_dest dest_conj tm)
  val ts = map (fn x => ((cdr o car) x, tm2ftree (cdr x))) xs
  val graph = clash_graph ts
  val moves = move_assignments ts graph
  val freq = frequency ts
  val is_colourable = is_t_var
  val colouring = iterated_register_coalescing graph moves freq is_colourable n
  fun COLOUR_ALPHA_CONV colouring tm =
    ALPHA_CONV (colouring (fst (dest_abs tm))) tm handle HOL_ERR _ => NO_CONV tm
  val thi = (BOTTOM_UP_CONV REMOVE_DEAD_LET_CONV THENC
             BOTTOM_UP_CONV (COLOUR_ALPHA_CONV colouring) THENC
             BOTTOM_UP_CONV REMOVE_REFL_LET_CONV) tm
  in thi end;

(* Universally quantify each conjunct of tm over its free variables, other
   than the names of the defined functions, and prove that the quantified
   conjunction implies tm. *)
fun initial_clean tm = let
  val tms = list_dest dest_conj tm
  fun function_name tm = repeat car (fst (dest_eq tm))
  val fs = map function_name tms
  fun add_foralls t = list_mk_forall (diff (free_vars t) fs, t)
  val tms2 = map add_foralls tms
  val tm2 = list_mk_conj tms2
  val goal = mk_imp(tm2,tm)
  val imp = auto_prove "initial_clean" (goal,
    ONCE_REWRITE_TAC [EQ_SYM_EQ] THEN SIMP_TAC bool_ss []
    THEN ONCE_REWRITE_TAC [EQ_SYM_EQ] THEN SIMP_TAC bool_ss [])
  in imp end;

(* Entry point: run the whole pipeline (unabbreviation, flattening, SSA,
   common subexpression elimination, calling conventions, register
   allocation with n registers) on the antecedent of initial_clean's
   theorem for input_tm. *)
fun allocate_registers n input_tm = let
  val imp = initial_clean input_tm
  val tm = (fst o dest_imp o concl) imp
  val cc = COMPILER_UNABBREV_CONV
           THENC FLATTEN_EXPS_CONV
           THENC SSA_CONV THENC COMMON_SUBEXP_CONV
           THENC FIX_CALL_RETURN_VALUES_CONV
           THENC REG_ALLOC_CONV n
  (*
  val tm = (snd o dest_eq o concl) (cc tm)
  *)
  in CONV_RULE ((RATOR_CONV o RAND_CONV) cc) imp end



(*
  for x86: 1. split binary ops into two parts
              let x = y ?? z in  -->  let x = y in let x = x ?? z in
              this might lead the reg allocator to coalesce x and y;
              alternatively augment moves with an artificial (x,y) edge,
              but many of these ops are commutative, so should there be
              an (x,[y,z]) edge instead?
           2. assume an infinite number of regs, and map
              regs 5,6,7,etc. --> stack locations 0,1,2,3,etc.
              reserve one register for loading when two stack locations
              are used in the same instruction
*)

(* Exhaustive test harness for iterated_register_coalescing (disabled):
   enumerates every clash graph and move set over four t-variables and
   reports the first input on which allocation fails with HOL_ERR. *)
(*

val n = 3
val pref_list = [4,3,2,1]
val k = length pref_list
fun t_vars n = if n = 0 then [] else mk_t_var(``:word32``)::t_vars (n-1)
val qs = t_vars k
fun cross xs ys = append_lists (map (fn x => map (fn y => (x,y)) ys) xs)
fun rest [] = []
  | rest ((x,y)::xs) = (x,y)::rest (filter (fn z => not (z = (y,x))) xs)
val edges = rest (filter (fn (x,y) => not (x = y)) (cross qs qs))
val max = foldr (op * ) 1 (map (K 2) edges)
val max2 = max * max
val freq = zip qs pref_list
val is_colourable = is_t_var

fun get_graph i = let
  fun n_filter i [] = []
    | n_filter i (x::xs) =
        if i mod 2 = 0 then n_filter (i div 2) xs
        else x :: n_filter (i div 2) xs
  fun adj vs edges = map (fn v => (v,map snd (filter (fn x => fst x = v) edges))) vs
  val ts = n_filter i edges
  val ts = ts @ map (fn (x,y) => (y,x)) ts
  val moves = n_filter (i div max) edges
  in (adj qs ts, moves) end;

fun try_inst i = let
  val (graph,moves) = get_graph i
  val _ = print (int_to_string i)
  val _ = print "/"
  val _ = print (int_to_string max2)
  val _ = print " "
  val ok = (iterated_register_coalescing graph moves freq is_colourable n; true)
           handle HOL_ERR _ => false
  val _ = print "\n"
  in if not ok then print ("\n\nFailed at "^int_to_string i^".\n\n") else
     if i < max2 then try_inst (i+1) else print "\n\nDone!\n\n" end;

val _ = try_inst 0

*)

end;