structure reg_allocLib :> reg_allocLib =
struct

open HolKernel boolLib bossLib Parse;
open wordsTheory codegen_inputLib helperLib;

(* Rewrites used by COMPILER_UNABBREV_CONV.  Initially, for i = 1..31,
   WORD_MUL_LSL is specialised to a word32 variable w and the constant is
   evaluated, giving equations that turn a multiplication of w by 2**i
   (in either operand order) into a left shift w << i.  add_abbrevs
   prepends further theorems to this list. *)
local
  val compiler_abbrevs = ref let
    fun kk i = if i < 32 then i::kk(i+1) else []
    val ys = map (numSyntax.mk_numeral o Arbnum.fromInt) (kk 1)
    val ys = map (fn x => ISPECL [mk_var("w",``:word32``),x] WORD_MUL_LSL) ys
    val ys = map (GSYM o (CONV_RULE (RAND_CONV EVAL))) ys
    val ys = map (ONCE_REWRITE_RULE [WORD_MULT_COMM]) ys @ ys
    in ys end
in
  fun add_abbrevs thms = (compiler_abbrevs := thms @ (!compiler_abbrevs))
  fun COMPILER_UNABBREV_CONV tm = REWRITE_CONV (!compiler_abbrevs) tm
end;

(* Fresh temporaries "_1", "_2", ...: introduced by flattening and by
   parallel assignments, and later expanded away. *)
val get_temp_name = let
  val n = ref 0
  in (fn () => (n := (!n) + 1; "_" ^ int_to_string (!n))) end
fun mk_temp_var ty = mk_var(get_temp_name(),ty)
fun is_temp_var v = String.isPrefix "_" (fst (dest_var v)) handle HOL_ERR _ => false

(* Fresh SSA variables "t1", "t2", ...: these are the variables the register
   allocator is allowed to colour (see REG_ALLOC_CONV).
   NOTE(review): is_t_var accepts ANY variable whose name starts with "t"
   (e.g. a user variable "tmp"), not only names made by mk_t_var. *)
val get_t_name = let
  val n = ref 0
  in (fn () => (n := (!n) + 1; "t" ^ int_to_string (!n))) end
fun mk_t_var ty = mk_var(get_t_name(),ty)
fun is_t_var v = String.isPrefix "t" (fst (dest_var v)) handle HOL_ERR _ => false



(* various helpers *)

(* Removes alpha-equivalent duplicates, keeping the first occurrence. *)
fun all_distinct [] = []
  | all_distinct (x::xs) = x :: all_distinct (filter (fn y => x !~ y) xs)

fun append_lists xss = List.concat xss

fun diff xs ys = filter (fn y => not (mem y ys)) xs
fun intersect xs ys = filter (fn y => mem y xs) ys

(* Splits a right-nested tuple (a,(b,c)) into [a,b,c]; a non-pair gives [tm]. *)
fun dest_tuple tm =
  let val (x,y) = pairSyntax.dest_pair tm in x :: dest_tuple y end handle HOL_ERR e => [tm];

(* Association-list lookup up to alpha-equivalence; fails (HOL_ERR) if absent. *)
fun list_find x [] = fail()
  | list_find x ((y,z)::zs) = if x ~~ y then z else list_find x zs

(* Expands  LET (\v. body) x  into  body[x/v]  (unfold LET_DEF, beta-reduce). *)
val EXPAND_LET_CONV =
  (RATOR_CONV o RATOR_CONV) (ONCE_REWRITE_CONV [LET_DEF]) THENC
  RATOR_CONV BETA_CONV THENC BETA_CONV THENC BETA_CONV

(* Inverse of dest_tuple; the empty list gives ``()``. *)
fun mk_tuple [] = ``()``
  | mk_tuple [x] = x
  | mk_tuple (x::xs) = pairSyntax.mk_pair(x,mk_tuple xs)


(* this conversion flattens large expressions into compilable assignments *)

(* Tries c at every subterm, children first; in a combination the operand is
   processed before the operator.  Never fails. *)
fun BOTTOM_UP_CONV c tm =
  case dest_term tm of
    COMB _ => (RAND_CONV (BOTTOM_UP_CONV c) THENC
               RATOR_CONV (BOTTOM_UP_CONV c) THENC
               TRY_CONV c) tm
  | LAMB _ => (ABS_CONV (BOTTOM_UP_CONV c) THENC
               TRY_CONV c) tm
  | _ => (TRY_CONV c) tm

(* Tries c at a node first, then recurses into the (possibly rewritten)
   subterms.  Never fails. *)
fun TOP_DOWN_CONV c tm =
  (TRY_CONV c THENC (fn tm =>
     case dest_term tm of
       COMB _ => (RAND_CONV (TOP_DOWN_CONV c) THENC RATOR_CONV (TOP_DOWN_CONV c)) tm
     | LAMB _ => (ABS_CONV (TOP_DOWN_CONV c)) tm
     | _ => ALL_CONV tm)) tm

(* For every function body in a conjunction of (universally quantified)
   definitions, pulls subexpressions that codegen_inputLib can compile out
   into let-bound temporaries, proving the rewrite with auto_prove. *)
fun FLATTEN_EXPS_CONV tm = let
  (* replace every free word32 variable by r0, so that the shape of the term,
     not its variable names, decides whether it is compilable *)
  fun subst_word_vars_r0 tm = let
    val vs = filter (fn x => type_of x = ``:word32``) (free_vars tm)
    val r0 = mk_var("r0",``:word32``)
    in subst (map (fn x => x |-> r0) vs) tm end
  (* the parentheses make the handlers cover term2assign/term2guard itself;
     unparenthesised they bound only the last case arm, so Empty escaped *)
  fun is_compilable tm =
    (case term2assign ``r1:word32`` (subst_word_vars_r0 tm) of
       ASSIGN_OTHER _ => false
     | _ => true)
    handle HOL_ERR _ => false | Empty => false
  fun is_c_guard tm =
    (case term2guard (subst_word_vars_r0 tm) of
       GUARD_OTHER _ => false
     | GUARD_NOT (GUARD_OTHER _) => false
     | _ => true)
    handle HOL_ERR _ => false | Empty => false
  (* repeatedly abstract a compilable, non-variable subterm of rhs into a new
     binding until g accepts rhs; returns (bindings in order, new rhs) *)
  fun divide_aux g (xs,rhs) =
    if g rhs then (xs,rhs) else
      (let
         val t = find_term (fn x => is_compilable x andalso not (is_var x)) rhs
         val ty = type_of t
         val temp = mk_temp_var ty
         (* non-word32 subterms reuse a variable of the same type found in t *)
         val temp = if ty = ``:word32`` then temp
                    else (find_term (fn v => is_var v andalso (ty = type_of v)) t
                          handle HOL_ERR _ => temp)
       in divide_aux g (xs @ [(temp,t)], subst [t |-> temp] rhs) end
       handle HOL_ERR _ => (xs,rhs))
  (* as divide_aux, but word32 bindings are moved (stably) to the front *)
  fun divide g (xs,rhs) = let
    val (xs,rhs) = divide_aux g (xs,rhs)
    fun is_word (v,_) = type_of v = ``:word32``
    in (filter is_word xs @ filter (not o is_word) xs, rhs) end
  fun CONJUNCTS_CONV c tm =
    if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  fun FORALL_CONV c tm =
    if is_forall tm then QUANT_CONV (FORALL_CONV c) tm else c tm
  val FUNC_BODY_CONV = CONJUNCTS_CONV o FORALL_CONV o RAND_CONV
  fun FLAT_CONV tm = let
    val f = tm2ftree tm
    fun lets ([],y) = y
      | lets ((x1,x2)::xs,y) = FUN_LET (x1,x2,lets (xs,y))
    fun ftree_each (FUN_VAL rhs) = let
          val (xs,rhs2) = divide is_compilable ([],rhs)
          in lets (xs,FUN_VAL rhs2) end
      | ftree_each (FUN_LET (lhs,rhs,t)) = let
          val (xs,rhs2) = divide is_compilable ([],rhs)
          in lets (xs,FUN_LET (lhs,rhs2,ftree_each t)) end
      | ftree_each (FUN_IF (b,t1,t2)) = let
          val (xs,b2) = divide is_c_guard ([],b)
          in lets (xs,FUN_IF (b2,ftree_each t1,ftree_each t2)) end
      | ftree_each (FUN_COND (b,t)) = FUN_COND (b,ftree_each t)
    val tm2 = ftree2tm (ftree_each f)
    (* proof: expanding the new temporaries recovers the original body *)
    fun EXPAND_TEMPVARLET_CONV tm = let
      val (v,x) = dest_abs (fst (dest_let tm))
      in if is_temp_var v then EXPAND_LET_CONV tm else NO_CONV tm end
      handle HOL_ERR _ => NO_CONV tm
    val goal = mk_eq(tm,tm2)
    in auto_prove "FLAT_CONV" (goal,
         CONV_TAC (BOTTOM_UP_CONV EXPAND_TEMPVARLET_CONV) THEN REWRITE_TAC []) end
  in FUNC_BODY_CONV FLAT_CONV tm end;


(* translation into ssa form, at least for word32 variables other than r0,r1... *)

(* True for word32 variables that are NOT fixed locations.  A name counts as
   fixed when it is "r" or "s" followed only by digits and primes: the
   register (r<i>) and stack-slot (s<i>) names made by the allocator. *)
fun not_fixed_reg v = let
  val (name,ty) = dest_var v
  fun is_index_char c = Char.isDigit c orelse c = #"'"
  val reg = case explode name of
              (c::cs) => mem c [#"r",#"s"] andalso List.all is_index_char cs
            | [] => false
  in (ty = ``:word32``) andalso not reg end
  handle HOL_ERR _ => false

(* Alpha-renames every binder of a non-fixed word32 variable to a fresh
   t-variable, so each let-bound value gets its own name. *)
val SSA_CONV = let
  fun rename tm = let
    val (v,x) = dest_abs tm
    in if not_fixed_reg v then ALPHA_CONV (mk_t_var(type_of v)) tm
       else NO_CONV tm end
  in BOTTOM_UP_CONV rename end

(* For  let v = y in body  with v non-fixed: replaces later occurrences of
   the expression y in body by v, then expands away resulting moves
   let v' = <var> in ...  with v' non-fixed. *)
val COMMON_SUBEXP_CONV = let
  fun DELETE_EXTRA_MOVE_CONV tm = let
    val (x,y) = dest_let tm
    val (v,x) = dest_abs x
    val _ = dest_var v
    val _ = dest_var y
    val _ = if not_fixed_reg v then () else fail()
    in EXPAND_LET_CONV tm end
  fun aux tm = let
    val (x,y) = dest_let tm
    val (v,x) = dest_abs x
    val _ = dest_var v
    val _ = if not_fixed_reg v then () else fail()
    val _ = find_term (fn x => x ~~ y) x   (* y must occur again in the body *)
    val x2 = subst [y|->v] x
    val tm2 = mk_let(mk_abs(v,x2),y)
    val goal = mk_eq(tm,tm2)
    val thi = auto_prove "" (goal,
      CONV_TAC (BINOP_CONV EXPAND_LET_CONV) THEN REWRITE_TAC [])
    in ((fn tm => thi) THENC BOTTOM_UP_CONV DELETE_EXTRA_MOVE_CONV) tm end
  in TOP_DOWN_CONV aux end


(* restrict register names *)

(* parallel_assign tm2 tm: a list of (var, value) single assignments, in
   execution order, implementing the simultaneous assignment tm2 := tm.
   Both must be tuples of variables, and tm2 must be an instance of tm under
   match_term.  Temporaries are introduced to avoid clobbering, then the
   list is simplified by copy-forwarding and dropping unused temporaries. *)
fun parallel_assign tm2 tm = let
  (* make basic parallel assignment *)
  val (m,_) = match_term tm tm2
  val xs = filter (fn x => x !~ subst m x) (dest_tuple tm)
  val vs = map (fn x => mk_temp_var (type_of x)) xs
  val rs = zip vs xs @ zip (map (subst m) xs) vs
  (* optimise: copy forward *)
  fun forward [] aux = []
    | forward ((x,y)::xs) aux = let
        val y = list_find y aux handle HOL_ERR _ => y
        val aux = filter (fn (lhs,rhs) => not (x IN FVs rhs)) aux
        in (x,y) :: forward xs ((x,y)::aux) end
  val rs = forward rs []
  (* optimise: remove unused temporary variables *)
  fun is_used x [] = not (is_temp_var x)
    | is_used x ((y,z)::xs) = if x IN FVs z then true else is_used x xs
  fun in_tail [] = []
    | in_tail ((x,y)::xs) = if is_used x xs then (x,y)::in_tail xs else in_tail xs
  in in_tail rs end;

(* The FUN_VAL result terms at the leaves of an ftree, left to right. *)
fun ftree_leaf_terms (FUN_COND (_,t)) = ftree_leaf_terms t
  | ftree_leaf_terms (FUN_LET (_,_,t)) = ftree_leaf_terms t
  | ftree_leaf_terms (FUN_IF (_,t1,t2)) = ftree_leaf_terms t1 @ ftree_leaf_terms t2
  | ftree_leaf_terms (FUN_VAL tm) = [tm]

(* Gives every function one canonical input tuple and one output tuple, and
   rewrites all calls and returns (with parallel assignments) to go through
   them. *)
fun FIX_CALL_RETURN_VALUES_CONV tm = let
  (* one return value per function: the first leaf that is not a call back
     to the function itself *)
  fun in_out x = let
    val (lhs,rhs) = dest_eq x
    val bases = filter (not o can (match_term lhs)) (ftree_leaf_terms (tm2ftree rhs))
    in (car lhs, (cdr lhs, hd bases)) end
  val xs = map (repeat (snd o dest_forall)) (list_dest dest_conj tm)
  val io = map in_out xs
  (* invent new temporaries for each return value *)
  fun invent_new_temps (x,(y,z)) = let
    val f = map (fn z => if is_t_var z then mk_t_var(type_of z) else z)
    in (x,(y,mk_tuple (f (dest_tuple z)))) end
  val io = map invent_new_temps io
  (* add restrictions on already compiled components: not implemented *)
  (* make sure all function calls/returns respect this io restriction *)
  fun CONJUNCTS_CONV c tm =
    if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  fun FORALL_CONV c tm =
    if is_forall tm then QUANT_CONV (FORALL_CONV c) tm else c tm
  val FUNC_BODY_CONV = CONJUNCTS_CONV o FORALL_CONV
  fun FLAT_CONV tm = let
    val func_tm = (car o fst o dest_eq) tm
    val f = tm2ftree (cdr tm)
    fun lets [] y = y
      | lets ((x1,x2)::xs) y = FUN_LET (x1,x2,lets xs y)
    fun ftree_each (FUN_IF (b,t1,t2)) = FUN_IF (b,ftree_each t1,ftree_each t2)
      | ftree_each (FUN_COND (b,t)) = FUN_COND (b,ftree_each t)
      | ftree_each (FUN_VAL rhs) = let
          (* tail call: load the canonical inputs; return: load the outputs *)
          val call = (car rhs ~~ func_tm) handle HOL_ERR _ => false
          val x = (if call then fst else snd) (list_find func_tm io)
          val rhs2 = if call then cdr rhs else rhs
          val rs1 = parallel_assign x rhs2
          val ret = if call then mk_comb(func_tm,x) else x
          in lets rs1 (FUN_VAL ret) end
      | ftree_each (FUN_LET (lhs,rhs,t)) = let
          (* call of a known function: load its inputs, bind its canonical
             outputs, then copy them into lhs *)
          val (x,y) = list_find (car rhs) io
          val rs1 = parallel_assign x (cdr rhs)
          val rs2 = parallel_assign lhs y
          in lets rs1 (FUN_LET (y,mk_comb(car rhs,x),lets rs2 (ftree_each t))) end
          handle HOL_ERR _ => FUN_LET (lhs,rhs,ftree_each t)
    val tm2 = ftree2tm (ftree_each f)
    val goal = mk_eq(tm,mk_eq((fst o dest_eq) tm,tm2))
    in auto_prove "FLAT_CONV" (goal,SIMP_TAC std_ss [LET_DEF]) end
  in FUNC_BODY_CONV FLAT_CONV tm end;


(* clash graph and reg allocation *)

(* All free variables of an ftree, without alpha-duplicates. *)
fun ftree_free_vars t = let
  fun vars (FUN_VAL tm) = free_vars tm
    | vars (FUN_COND (tm,t)) = free_vars tm @ vars t
    | vars (FUN_IF (tm,x1,x2)) = all_distinct (free_vars tm @ vars x1 @ vars x2)
    | vars (FUN_LET (lhs,rhs,t)) = all_distinct (free_vars lhs @ free_vars rhs @ vars t)
  in all_distinct (vars t) end;

(* Variables of a subroutine that are neither among its inputs (tm is the
   call pattern) nor mentioned in any of its return values. *)
fun subroutine_internal_vars (tm,t) = let
  val vs = free_vars (cdr tm)
  val xs = append_lists (map free_vars (ftree_leaf_terms t))
  in op_set_diff aconv (ftree_free_vars t) (vs @ xs) end

(* Builds the interference graph, as (vertex, neighbours) pairs, over the
   word32 variables of ts = [(call pattern, ftree), ...]. *)
fun clash_graph ts = let
  fun ok_var x = (type_of x = ``:word32``)
  (* annotate t with the live-set pair (ls1,ls2), encoded as the guard
     [ls1] = [ls2] of a FUN_COND; collect reads it back below *)
  fun add_live_set2 ls1 ls2 t = FUN_COND
    (mk_eq(listSyntax.mk_list(all_distinct ls1,``:word32``),
           listSyntax.mk_list(all_distinct ls2,``:word32``)),t)
  fun add_live_set ls t = add_live_set2 ls [] t
  val fs = map (car o fst) ts
  (* internal variables of the called subroutine, or [] if rhs is not a call
     of one of the functions in ts *)
  fun get_internal_vars rhs =
    if (not (tmem (car rhs) fs) handle HOL_ERR _ => true) then [] else
      subroutine_internal_vars (hd (filter (fn (x,_) => x ~~ rhs) ts))
  (* backwards liveness: returns the live-in set and the annotated tree *)
  fun live (FUN_VAL tm) = let
        val ls = filter ok_var (free_vars tm)
        val t = add_live_set ls (FUN_VAL tm)
        in (ls,t) end
    | live (FUN_COND (tm,t)) = fail()
    | live (FUN_IF (tm,x1,x2)) = let
        val (ls1,y1) = live x1
        val (ls2,y2) = live x2
        val ls = (filter ok_var (free_vars tm)) @ ls1 @ ls2
        val t = add_live_set ls (FUN_IF (tm,y1,y2))
        in (ls,t) end
    | live (FUN_LET (lhs,rhs,t)) = let
        val (ls,tt) = live t
        val vs = (filter ok_var (free_vars lhs))
        val ls2 = op_set_diff aconv ls vs
        val ls = ls2 @ (filter ok_var (free_vars rhs))
        val ii = get_internal_vars rhs
        val t = if null ii then add_live_set ls (FUN_LET (lhs,rhs,tt)) else
                  add_live_set ls (add_live_set2 ls2 ii (FUN_LET (lhs,rhs,tt)))
        in (ls,t) end
  fun collect (FUN_VAL tm) = []
    | collect (FUN_IF (tm,x1,x2)) = collect x1 @ collect x2
    | collect (FUN_LET (lhs,rhs,t)) = collect t
    | collect (FUN_COND (tm,t)) = let
        val f = fst o listSyntax.dest_list
        val (x1,x2) = dest_eq tm
        in (f x1, f x2) :: collect t end
  val live_sets = append_lists (map (fn (f,t) => (collect (snd (live t)))) ts)
  (* y and z clash if both are in the first set of some pair, or one is in
     the first set and the other in the second (live across a call vs.
     internal to the callee) *)
  fun clash [] y z = false
    | clash ((x1,x2)::xs) y z =
        (tmem y x1 andalso tmem z x1) orelse
        (tmem y x1 andalso tmem z x2) orelse
        (tmem y x2 andalso tmem z x1) orelse clash xs y z
  val all_vars = all_distinct (append_lists (map fst live_sets))
  val graph = map (fn v => (v,filter (clash live_sets v) all_vars)) all_vars
  val graph = map (fn (v,cs) => (v,filter (fn y => y !~ v) cs)) graph
  in graph end

(* Variable-to-variable lets  let x = y in ...  found in the ftrees; these are
   the coalescing candidates.  graph is currently unused. *)
fun move_assignments ts graph = let
  fun pref (FUN_COND (_,t)) = pref t
    | pref (FUN_IF (_,t1,t2)) = pref t1 @ pref t2
    | pref (FUN_VAL tm) = []
    | pref (FUN_LET (x,y,t)) =
        if is_var x andalso is_var y then (x,y)::pref t else pref t
  in append_lists (map (pref o snd) ts) end;

(* iterated_register_coalescing implements algorithm by George and Appel '96.
   graph: (vertex, neighbours) pairs; moves: move-related pairs (x,y);
   freq: (variable, weight) pairs preferring busy variables; is_colourable:
   which vertices may be simplified/spilled; n: registers r0..r(n-1).
   Returns a map from variables to registers r<i> or stack slots s<i>;
   fails if the result is an invalid colouring, or with "no more registers"
   if a simplified vertex finds no free register. *)
fun iterated_register_coalescing graph moves freq is_colourable n = let
  fun kk n = if n < 0 then [] else n::kk(n-1)
  val regs = map (fn n => mk_var("r" ^ (int_to_string n),``:word32``)) (rev (kk (n-1)))
  (* vertices that already are machine registers (pre-coloured) *)
  val r = map fst (filter (fn (x,xs) => tmem x regs) graph)
  fun move_related [] = []
    | move_related ((x,y)::moves) = x::y::move_related moves
  (* follow the chain of coalesced vertices to its representative *)
  fun join_all joined x = join_all joined (list_find x joined) handle HOL_ERR _ => x
  (* merge y into x: x inherits y's edges and moves; record y -> x *)
  fun merge_vertexes x y (graph,moves,joined) = let
    val xs = filter (fn v => v !~ x andalso v !~ y) (list_find x graph)
    val ys = filter (fn v => v !~ x andalso v !~ y) (list_find y graph)
    val graph = filter (fn (v,ns) => v !~ x andalso v !~ y) graph
    val graph = map (fn (v,ns) =>
                       (v,all_distinct
                            (map (fn n => if n ~~ y then x else n) ns)))
                    graph
    val graph = (x,all_distinct (xs @ ys)) :: graph
    val moves = filter (fn z => not (tmp_eq z (x,y))) moves
    val moves =
      map (fn (z1,z2) =>
             (if z1 ~~ y then x else z1,if z2 ~~ y then x else z2))
          moves
    val joined = (y,x)::joined
    in (graph,moves,joined) end;
  fun delete_vertex w (graph,moves,joined) = let
    val graph = filter (fn (v,ns) => v !~ w) graph
    val graph = map (fn (v,ns) => (v,filter (fn n => n !~ w) ns)) graph
    val moves = filter (fn (x,y) => x !~ w andalso y !~ w) moves
    in (graph,moves,joined) end;
  fun busy w = list_find w freq handle HOL_ERR _ => 0
  (* NOTE(review): despite its name this prints a trace line to stdout *)
  fun no_print s = print (" " ^ s ^ "\n")
  (* returns (stack of (vertex, "r" = colour | "s" = spill), joined) *)
  fun build_stack graph moves joined n result =
    (* simplification: ?w. ~(w IN ms) and degree w < n, then remove from graph *)
    let
      val ms = move_related moves
      val not_ms_graph = filter (fn (v,neighbours) => not (tmem v ms)) graph
      val ws = map fst (filter (fn (v,ns) => length ns < n) not_ms_graph)
      val ws = filter is_colourable ws
      val ws = sort (fn x => fn y => busy x >= busy y) ws
      val w = first (K true) ws (* select most busy *)
      val (graph,moves,joined) = delete_vertex w (graph,moves,joined)
      val _ = no_print ("!" ^ term_to_string w ^ " ")
    in build_stack graph moves joined n ((w,"r")::result) end
    handle HOL_ERR _ =>
    (* coalescing: ?x y. (x,y) IN moves and degree (x UNION y) < n, then combine x,y *)
    let
      fun combined_degree (x,y) = length (all_distinct (list_find x graph @ list_find y graph))
                                  handle HOL_ERR _ => n+1000
      val moves2 = filter (fn (x,y) => not (tmem x (list_find y graph))) moves
      val moves2 = filter (fn (x,y) => combined_degree (x,y) < n) moves2
      val moves2 = sort (fn (x1,x2) => fn (y1,y2) => busy x1 + busy x2 >= busy y1 + busy y2) moves2
      val moves2 = filter (fn (x,y) => is_colourable x orelse is_colourable y) moves2
      val moves2 = filter (fn (x,y) => x !~ y) moves2
      val (x,y) = first (fn (x,y) => true) moves2
      val (x,y) = if is_colourable y then (x,y) else (y,x)
      val (graph,moves,joined) = merge_vertexes x y (graph,moves,joined)
      val _ = no_print (term_to_string x ^ "<--" ^ term_to_string y ^ " ")
    in build_stack graph moves joined n result end
    handle HOL_ERR _ =>
    (* freezing: removing the move property from an edge *)
    let
      val ((x,y),moves) = if null moves then fail () else (hd moves,tl moves)
      val _ = no_print (term_to_string x ^ "-/-" ^ term_to_string y ^ " ")
    in build_stack graph moves joined n result end
    handle HOL_ERR _ =>
    (* spilling: select a vertex and spill it *)
    let
      val ws = map fst graph
      val ws = filter is_colourable ws
      val ws = sort (fn x => fn y => busy x <= busy y) ws
      val w = if null ws then fail () else hd ws (* select least busy *)
      val (graph,moves,joined) = delete_vertex w (graph,moves,joined)
      val _ = no_print ("^" ^ term_to_string w ^ " ")
    in build_stack graph moves joined n ((w,"s")::result) end
    handle HOL_ERR _ =>
    (rev result, joined)
  val (stack,joined) = build_stack graph moves [] n []
  val coalesced = join_all joined
  fun update x y z i = if x ~~ i then y else z i
  (* choose the option that makes the most moves trivial (same colour at both
     ends).  The explicit null check replaces a handler that was nested inside
     the HOL_ERR arm and therefore never caught the Empty from hd. *)
  fun select_colour x options c =
    if null options then failwith "no more registers" else let
      fun score c =
        foldr (op +) 0 (map (fn (x,y) => if c x ~~ c y then 1 else 0) moves)
      val xs = map (fn p => (p,score (update x p c))) options
      in fst (hd (sort (fn (_,x) => fn (_,y) => y <= x) xs)) end
      handle HOL_ERR _ => hd options
  (* colours already given to the (coalesced) neighbours of x in the original
     graph, considering only the vertices in r *)
  fun neighbour_colours x c r = let
    val qs = map snd (filter (fn (v,ns) => coalesced v ~~ x) graph)
    val qs = map coalesced (append_lists qs)
    in map c (filter (fn z => tmem z r) qs) end
  fun colour [] (c,r) = c
    | colour ((x,ty)::stack) (c,r) = let
        val zs = neighbour_colours x c r
        val new_colour =
          if ty = "r" then select_colour x (op_set_diff aconv regs zs) c
          else let
            (* lowest stack slot s<i> not used by a neighbour *)
            fun next_stack i = let
              val z = mk_var("s" ^ int_to_string i,``:word32``)
              in if tmem z zs then next_stack (i+1) else z end
            in next_stack 0 end
        in colour stack (update x new_colour c, x::r) end
  val colouring = colour stack (I,r) o join_all joined
  (* check validity of colouring *)
  val g = map (fn (v,ns) => (colouring v, map colouring ns)) graph
  val _ = if null (filter (fn (x,xs) => tmem x xs) g) then ()
          else (print "\n\nRegister allocator produced invalid result.\n\n"; fail())
  in colouring end

(* Returns (variable, weight) pairs giving how often each variable is used or
   defined, counted from the body of the last definition in ts.  Counts
   inside a call to a function that is_rec classifies as looping (some
   result leaf is not a tuple, e.g. a tail call) are multiplied by 16 for
   each level of such nesting. *)
fun frequency ts = let
  fun is_rec (FUN_VAL tm) = not (pairSyntax.is_pair tm)
    | is_rec (FUN_COND (tm,t)) = is_rec t
    | is_rec (FUN_IF (tm,x1,x2)) = is_rec x1 orelse is_rec x2
    | is_rec (FUN_LET (lhs,rhs,t)) = is_rec t
  val fs = map (car o fst) ts
  val vs = all_distinct (append_lists (map (ftree_free_vars o snd) ts))
  val vs = op_set_diff aconv vs fs
  fun occ v tm s = s + (if v IN FVs tm then 1 else 0)
  fun count v (FUN_VAL tm) s = occ v tm s
    | count v (FUN_COND (tm,t)) s = count v t (occ v tm s)
    | count v (FUN_IF (tm,x1,x2)) s = count v x1 (count v x2 (occ v tm s))
    | count v (FUN_LET (lhs,rhs,t2)) s = let
        val t = (list_find rhs ts handle HOL_ERR _ => FUN_VAL T)
        val inner_s = count v t 0
        val inner_s = if is_rec t then inner_s * 16 else inner_s
        in count v t2 (occ v lhs (occ v rhs (inner_s + s))) end
  in map (fn v => (v,count v (snd (last ts)) 0)) vs end;

(* Expands  let v = v in ...  (a move of a variable onto itself). *)
fun REMOVE_REFL_LET_CONV tm = let
  val (x,y) = dest_let tm
  val (v,x) = dest_abs x
  in if v ~~ y then EXPAND_LET_CONV tm else NO_CONV tm end
  handle HOL_ERR _ => NO_CONV tm;

(* Expands a let whose bound variable does not occur in its body. *)
fun REMOVE_DEAD_LET_CONV tm = let
  val (x,y) = dest_let tm
  val (v,x) = dest_abs x
  in if v IN FVs x then NO_CONV tm else EXPAND_LET_CONV tm end
  handle HOL_ERR _ => NO_CONV tm;

(* Colours the t-variables of the definitions in tm with n registers (plus
   stack slots), renames binders accordingly and removes dead/self moves. *)
fun REG_ALLOC_CONV n tm = let
  val xs = map (repeat (snd o dest_forall)) (list_dest dest_conj tm)
  val ts = map (fn x => ((cdr o car) x, tm2ftree (cdr x))) xs
  val graph = clash_graph ts
  val moves = move_assignments ts graph
  val freq = frequency ts
  val is_colourable = is_t_var
  val colouring = iterated_register_coalescing graph moves freq is_colourable n
  fun COLOUR_ALPHA_CONV colouring tm =
    ALPHA_CONV (colouring (fst (dest_abs tm))) tm handle HOL_ERR _ => NO_CONV tm
  in (BOTTOM_UP_CONV REMOVE_DEAD_LET_CONV THENC
      BOTTOM_UP_CONV (COLOUR_ALPHA_CONV colouring) THENC
      BOTTOM_UP_CONV REMOVE_REFL_LET_CONV) tm end;

(* Proves  (!vs. def1) /\ ... ==> tm  where tm is the conjunction of the
   input definitions and each def is universally quantified over its free
   variables other than the defined function names. *)
fun initial_clean tm = let
  val tms = list_dest dest_conj tm
  fun function_name tm = repeat car (fst (dest_eq tm))
  val fs = map function_name tms
  fun add_foralls t = list_mk_forall (op_set_diff aconv (free_vars t) fs, t)
  val tms2 = map add_foralls tms
  val tm2 = list_mk_conj tms2
  val goal = mk_imp(tm2,tm)
  in auto_prove "initial_clean" (goal,
       ONCE_REWRITE_TAC [EQ_SYM_EQ] THEN SIMP_TAC bool_ss []
       THEN ONCE_REWRITE_TAC [EQ_SYM_EQ] THEN SIMP_TAC bool_ss []) end;

(* Full pipeline: unabbreviate, flatten, SSA, CSE, fix call/return values and
   allocate n registers, applied to the antecedent of initial_clean's
   theorem. *)
fun allocate_registers n input_tm = let
  val imp = initial_clean input_tm
  val cc = COMPILER_UNABBREV_CONV
           THENC FLATTEN_EXPS_CONV
           THENC SSA_CONV THENC COMMON_SUBEXP_CONV
           THENC FIX_CALL_RETURN_VALUES_CONV
           THENC REG_ALLOC_CONV n
  in CONV_RULE ((RATOR_CONV o RAND_CONV) cc) imp end



(*
  Notes for an x86 port:
  1. split binary ops into two parts
       let x = y ?? z in  -->  let x = y in let x = x ?? z in
     this might lead the reg allocator to coalesce x and y;
     alternatively, augment 'moves' with an artificial (x,y) edge.
     But many of these ops are commutative: should there be an
     (x,[y,z]) edge?
  2. assume an infinite number of regs, and map
       regs 5,6,7,etc. --> stack locations 0,1,2,3,etc.
     reserving one register for loading when two stack locations are
     used in the same instruction.
*)

(*

val n = 3
val pref_list = [4,3,2,1]
val k = length pref_list
fun t_vars n = if n = 0 then [] else mk_t_var(``:word32``)::t_vars (n-1)
val qs = t_vars k
fun cross xs ys = append_lists (map (fn x => map (fn y => (x,y)) ys) xs)
fun rest [] = []
  | rest ((x,y)::xs) = (x,y)::rest (filter (fn z => not (z = (y,x))) xs)
val edges = rest (filter (fn (x,y) => not (x = y)) (cross qs qs))
val max = foldr (op * ) 1 (map (K 2) edges)
val max2 = max * max
val freq = zip qs pref_list
val is_colourable = is_t_var

fun get_graph i = let
  fun n_filter i [] = []
    | n_filter i (x::xs) =
        if i mod 2 = 0 then n_filter (i div 2) xs
        else x :: n_filter (i div 2) xs
  fun adj vs edges = map (fn v => (v,map snd (filter (fn x => fst x = v) edges))) vs
  val ts = n_filter i edges
  val ts = ts @ map (fn (x,y) => (y,x)) ts
  val moves = n_filter (i div max) edges
  in (adj qs ts, moves) end;

fun try_inst i = let
  val (graph,moves) = get_graph i
  val _ = print (int_to_string i)
  val _ = print "/"
  val _ = print (int_to_string max2)
  val _ = print " "
  val ok = (iterated_register_coalescing graph moves freq is_colourable n; true)
           handle HOL_ERR _ => false
  val _ = print "\n"
  in if not ok then print ("\n\nFailed at "^int_to_string i^".\n\n") else
     if i < max2 then try_inst (i+1) else print "\n\nDone!\n\n" end;

val _ = try_inst 0

*)

end;