structure PmatchHeuristics :> PmatchHeuristics =
struct

open HolKernel boolSyntax;


(*----------------------------------------------------------------------------
   Boilerplate: datatype information and pattern-matrix helpers
 ----------------------------------------------------------------------------*)

(* Lookup function for datatype information: maps the theory and name of a
   type operator to the case constant and constructors of that datatype,
   or NONE if nothing is known about it. *)
type thry = {Tyop : string, Thy : string} ->
            {case_const : term, constructors : term list} option

(* Project the constructor list out of a datatype-information record. *)
fun constructors_of {case_const : term, constructors : term list} = constructors

(* The {Thy, Tyop} key of a type, as expected by a [thry] lookup.
   Callers wrap uses of this function in [handle HOL_ERR _]. *)
fun type_names ty =
  let val {Thy, Tyop, ...} = Type.dest_thy_type ty
  in {Thy = Thy, Tyop = Tyop}
  end;

(* Transpose a pattern matrix, given as a list of rows, into the list of its
   columns. All rows are assumed to have the same length as the first one.
   A matrix without rows, or whose rows are empty, has no columns. *)
fun columns_of_rows rowL =
  let
    fun aux acc [] = List.rev acc
      | aux acc (rows as (r :: _)) =
          if null r then List.rev acc
          else aux (List.map hd rows :: acc) (List.map tl rows)
  in
    aux [] rowL
  end

(* Index of the first column of the pattern matrix [rowL] that consists
   only of variables, if any. The number of columns is taken from the first
   row. NOTE: this scans the columns; scanning the rows instead would either
   run past the last column or miss later columns. *)
fun first_all_vars_column [] = NONE
  | first_all_vars_column (rowL as (r :: _)) =
    let
      val ncols = List.length r
      fun all_vars n = List.all (fn row => is_var (List.nth (row, n))) rowL
      fun find n =
        if n >= ncols then NONE
        else if all_vars n then SOME n
        else find (n + 1)
    in
      find 0
    end


(*----------------------------------------------------------------------------
   Various heuristics for pattern compilation
 ----------------------------------------------------------------------------*)

(* A pattern-compilation heuristic. [col_fun] receives the datatype
   information and the current pattern matrix (a list of rows) and returns
   the index of the column to split on next. [skip_rows] and [collapse_cases]
   are options for the pattern compiler; setting both to true is what the
   comments below call "quits early and collapses". *)
type pmatch_heuristic = {skip_rows : bool, collapse_cases : bool, col_fun : thry -> term list list -> int}

(* the old heuristic used by HOL 4 *)
val pheu_classic : pmatch_heuristic =
  { skip_rows = false, collapse_cases = false, col_fun = (fn _ => fn _ => 0) }

(* one that always uses the first column, but quits early and collapses *)
val pheu_first_col : pmatch_heuristic =
  { skip_rows = true, collapse_cases = true, col_fun = (fn _ => fn _ => 0) }

(* one that always uses the last column, but quits early and collapses;
   an empty first row yields column 0 rather than the invalid index ~1 *)
val pheu_last_col : pmatch_heuristic =
  { skip_rows = true, collapse_cases = true,
    col_fun = (fn _ => fn rowL =>
      case rowL of
          [] => 0
        | (r :: _) => Int.max (0, length r - 1)) }

(* A heuristic based on ranking functions. The functions in [rankL] are
   applied in order; each keeps only the columns that receive its maximal
   rank, so later functions act as tie-breakers. The first surviving column
   is chosen, and an empty pattern matrix yields column 0. *)
fun pheu_rank (rankL : (thry -> term list -> int) list) : pmatch_heuristic =
  let
    fun col_fun ty_info rowL =
      let
        val ncolL = Lib.enumerate 0 (columns_of_rows rowL)

        (* keep only the (index, column) pairs of maximal rank; ncolL <> [] *)
        fun step rank ncolL =
          let
            val ranked_cols = List.map (fn (i, pL) => ((i, pL), rank ty_info pL)) ncolL
            val best = List.foldl (fn ((_, r), m) => if r > m then r else m)
                                  (snd (hd ranked_cols)) (tl ranked_cols)
          in
            List.map fst (List.filter (fn (_, r) => r = best) ranked_cols)
          end

        fun steps [] ncolL = ncolL
          | steps _ [] = []
          | steps _ [e] = [e]
          | steps (rf :: rankL) ncolL = steps rankL (step rf ncolL)
      in
        case steps rankL ncolL of
            [] => 0 (* no columns at all *)
          | ((i, _) :: _) => i
      end
  in
    { skip_rows = true, collapse_cases = true, col_fun = col_fun }
  end

(* ranking functions *)

(* 1 if the first pattern of the column is not a variable, otherwise 0 *)
fun prheu_first_row _ [] = 0
  | prheu_first_row _ (p :: _) = if is_var p then 0 else 1

(* Like prheu_first_row, but a non-variable first pattern whose datatype is
   known to have exactly one constructor also ranks 0. A type without
   datatype information ranks 1; a HOL_ERR during the lookup ranks 0. *)
fun prheu_first_row_constr tybase [] = 0
  | prheu_first_row_constr tybase (p :: _) =
      if is_var p then 0 else
      (let val (_, ty) = strip_fun (type_of p) in
         case tybase (type_names ty) of
             NONE => 1
           | SOME tyinfo =>
               if length (constructors_of tyinfo) = 1 then 0 else 1
       end handle HOL_ERR _ => 0)

(* length of the longest prefix of the column without variables *)
val prheu_constr_prefix : (thry -> term list -> int) =
  let fun aux n [] = n
        | aux n (p :: pL) = if is_var p then n else aux (n+1) pL
  in (fn _ => aux 0) end

(* For a column whose datatype is known (judged by its first pattern),
   returns the distinct constructors heading its patterns together with all
   constructors of the datatype. NONE for an empty column, an unknown
   datatype, or a HOL_ERR during the lookup. *)
fun prheu_get_constr_set tybase pL =
  case pL of
      [] => NONE
    | p :: _ =>
      (let
         val (_, ty) = strip_fun (type_of p)
       in
         case tybase (type_names ty) of
             NONE => NONE
           | SOME tyinfo =>
             let
               val constrL = constructors_of tyinfo
               val cL = List.map (fn p => fst (strip_comb p)) pL
               val cL' = List.filter (fn c => op_mem same_const c constrL) cL
             in
               SOME (op_mk_set aconv cL', constrL)
             end
       end handle HOL_ERR _ => NONE)

(* distinct non-variable heads of the patterns in a column *)
fun prheu_get_nonvar_set pL =
  let
    val cL = List.map (fn p => fst (strip_comb p)) pL
    val cL' = List.filter (fn c => not (is_var c)) cL
  in
    op_mk_set aconv cL'
  end

(* Prefers columns that produce few case branches: minus the number of
   constructors used (plus one when not all constructors appear). For an
   unknown datatype it is minus (number of non-variable heads + 2). *)
fun prheu_small_branching_factor ty_info pL =
  case prheu_get_constr_set ty_info pL of
      SOME (cL, full_constrL) =>
        (~(length cL + (if length cL = length full_constrL then 0 else 1)))
    | NONE => (~(length (prheu_get_nonvar_set pL) + 2))

(* sum of the argument counts of the constructors used in the column *)
fun prheu_arity ty_info pL =
  case prheu_get_constr_set ty_info pL of
      SOME (cL, _) =>
        List.foldl (fn (c, s) => s + length (fst (strip_fun (type_of c)))) 0 cL
    | NONE => 0


(* heuristics defined using ranking functions *)
val pheu_first_row = pheu_rank [prheu_first_row]
val pheu_constr_prefix = pheu_rank [prheu_constr_prefix]
val pheu_qba = pheu_rank [prheu_constr_prefix, prheu_small_branching_factor, prheu_arity]
val pheu_cqba = pheu_rank [prheu_first_row_constr, prheu_constr_prefix, prheu_small_branching_factor, prheu_arity]

(* A list of all the standard heuristics *)
val default_heuristic_list = [pheu_qba, pheu_cqba, pheu_first_row, pheu_last_col, pheu_first_col]


(*----------------------------------------------------------------------------
   Heuristic funs
 ----------------------------------------------------------------------------*)

(* Run multiple heuristics and take the best result. A heuristic fun yields
   a comparison on results (pattern rows and resulting case term) and a
   generator returning the next heuristic to try, or NONE when done. *)
type pmatch_heuristic_res_compare = (term list list * term) Lib.cmp
type pmatch_heuristic_fun = unit -> pmatch_heuristic_res_compare * (unit -> pmatch_heuristic option)


(*----------------------
   comparing the results
 -----------------------*)

(* Average depth of a nested case term: a term [f x b1 ... bn] with variable
   [x] and n >= 1 counts as one level plus the average depth of the (abstraction-
   stripped) bodies [b1 ... bn]; anything else has depth 0. *)
fun average_tree_depth t =
let
  val (_, ts) = strip_comb t
  val ts' = tl ts
  val _ = if is_var (hd ts) andalso not (null ts') then () else fail()
  val ts'' = map (snd o strip_abs) ts'
  val ds = List.foldl (fn (t, s) => s + average_tree_depth t) 0.0 ts''
  val ds' = (ds / (real (length ts''))) + 1.0
in
  ds'
end handle Empty => 0.0
         | HOL_ERR _ => 0.0

(* lexicographic combination of two orders; Unordered falls through to the
   second order, and Unordered from the second order counts as EQUAL *)
fun lex_order (ord1 : 'a cmp) (ord2 : 'a cmp) xy =
  case ord1 xy of
      LESS => LESS
    | GREATER => GREATER
    | EQUAL => (ord2 xy handle Unordered => EQUAL)
  handle Unordered => (ord2 xy handle Unordered => EQUAL)

(* compare by number of resulting pattern rows *)
val pmatch_heuristic_cases_base_cmp : pmatch_heuristic_res_compare =
  fn ((patts1, case_tm1), (patts2, case_tm2)) => Int.compare (length patts1, length patts2)

(* compare by size of the resulting case term *)
fun pmatch_heuristic_size_base_cmp ((patts1, case_tm1), (patts2, case_tm2)) =
  Int.compare (term_size case_tm1, term_size case_tm2)

(* compare by average depth of the resulting case tree *)
fun pmatch_heuristic_tree_base_cmp ((patts1, case_tm1), (patts2, case_tm2)) =
  Real.compare (average_tree_depth case_tm1, average_tree_depth case_tm2)

(* cases first, then term size, then tree depth *)
fun pmatch_heuristic_cases_cmp xy = lex_order pmatch_heuristic_cases_base_cmp
  (lex_order pmatch_heuristic_size_base_cmp pmatch_heuristic_tree_base_cmp) xy

(* term size first, then tree depth, then cases *)
fun pmatch_heuristic_size_cmp xy = lex_order pmatch_heuristic_size_base_cmp
  (lex_order pmatch_heuristic_tree_base_cmp pmatch_heuristic_cases_base_cmp) xy


(*-------------------------
   try a list of heuristics
 --------------------------*)

(* Heuristic fun that tries the heuristics of [l] in order, comparing results
   with [min_fun]. Each call with () starts a fresh pass over [l]. *)
fun pmatch_heuristic_list min_fun l () : (pmatch_heuristic_res_compare * (unit -> pmatch_heuristic option)) =
  let
    val hL_ref = ref l
    fun aux () = case (!hL_ref) of
        [] => NONE
      | h::hL => (hL_ref := hL; SOME h)
  in
    (min_fun, aux)
  end

val default_heuristic_fun = (pmatch_heuristic_list pmatch_heuristic_cases_cmp default_heuristic_list);
val classic_heuristic_fun = (pmatch_heuristic_list pmatch_heuristic_cases_cmp [pheu_classic]);


(*-----------------------------------
   an exhaustive search heuristic-fun
 ------------------------------------*)

(* Heuristic fun that explores every column choice (apart from trivial ones),
   printing its progress, and compares the results with [cmp]. *)
fun exhaustive_heuristic_fun cmp =
let
  (* heuristics that still have to be tried *)
  val heuristicL_ref = ref ([]:pmatch_heuristic list)
  fun add_heu heu = (heuristicL_ref := heu :: (!heuristicL_ref))

  (* [heu prefix] first replays the choices in [prefix], then makes fresh
     choices and records them. A choice is a triple (branched, col, ncols).
     When it has to pick among ncols >= 2 columns, none of which consists
     only of variables, it picks column 0 and queues one new heuristic for
     each alternative column 1 .. ncols-1. *)
  fun heu (prefix : (bool * int * int) list) : pmatch_heuristic =
  let
    val current_prefix = ref prefix
    val remaining_prefix = ref prefix

    (* append a fresh choice to the recorded prefix and return it *)
    fun record choice =
      (current_prefix := (!current_prefix) @ [choice]; choice)

    fun colfun_print thry rowL =
      case (!remaining_prefix) of
          (i :: is) => (remaining_prefix := is; i)
        | [] =>
          (case rowL of
               [] => record (false, 0, 1)
             | (r :: _) =>
               let
                 val l = List.length r
               in
                 if l < 2 then record (false, 0, 1) else
                 case first_all_vars_column rowL of
                     SOME c => record (false, c, 1)
                   | NONE =>
                     (Lib.appi (fn i => fn _ =>
                                   add_heu (heu ((!current_prefix) @ [(true, i+1, l)])))
                               (tl r);
                      record (true, 0, l))
               end)

    (* print real branching decisions as "col/ncols " (1-based column) *)
    fun colfun thry rowL = let
      val (do_it, r, l) = colfun_print thry rowL
      val _ = if do_it then (print (int_to_string (r+1)); print "/"; print (int_to_string l); print " ") else ()
    in
      r
    end
  in
    { skip_rows = true, collapse_cases = true, col_fun = colfun }
  end

  fun next_heu () =
    case (!heuristicL_ref) of
        [] => NONE
      | (h :: hs) => (print "\n"; heuristicL_ref := hs; SOME h)

  fun init () =
    let
      val _ = heuristicL_ref := [heu []]
      fun cmp' (r1, r2) = (
        print "\nCases in Result: ";
        print (int_to_string (length (fst r2)));
        print " (best so far ";
        print (int_to_string (length (fst r1)));
        print ")";
        cmp (r1, r2)
      )
    in
      (cmp', next_heu)
    end
in
  init
end

(* A manual one taking an explicit order: each non-trivial column choice
   consumes the next index of [res_input]. Exhausted or out-of-range indices
   fall back to column 0. *)
fun pheu_manual (res_input : int list) : pmatch_heuristic =
let
  (* print one row of patterns *)
  fun pr ts = let
    val ts_sl = List.map (fn t => "``" ^ (Hol_pp.term_to_string t) ^ "``") ts
    val ts_s = String.concatWith ", " ts_sl
  in
    print ts_s;
    print "\n"
  end

  (* user-supplied choices not consumed yet *)
  val res = ref res_input

  (* consume the next user choice; only called with a non-empty rowL *)
  fun man_choice rowL =
    let
      val ncols = List.length (List.hd rowL)
      val r = case (!res) of
          [] => 0
        | (i::is) =>
          (res := is;
           if 0 <= i andalso i < ncols then i else
             (print ("Error, can't use "^(int_to_string i)^"\n"); 0))
      val _ = List.app pr rowL
      val _ = print ("Using "^(int_to_string r)^"\n\n\n")
    in
      r
    end

  fun colfun thry rowL =
    if null rowL orelse length (hd rowL) < 2 then 0 else
    case first_all_vars_column rowL of
        SOME r => r
      | NONE => man_choice rowL
in
  { skip_rows = true, collapse_cases = true, col_fun = colfun }
end


(*----------------------------------------------------------------------------
   A reference to store the current heuristic_fun
 ----------------------------------------------------------------------------*)

val pmatch_heuristic = ref default_heuristic_fun
val classic = ref false
fun is_classic () = !classic

fun set_heuristic_fun heu_fun = (pmatch_heuristic := heu_fun)
fun set_heuristic_list_size heuL = set_heuristic_fun (pmatch_heuristic_list pmatch_heuristic_size_cmp heuL)
fun set_heuristic_list_cases heuL = set_heuristic_fun (pmatch_heuristic_list pmatch_heuristic_cases_cmp heuL)
fun set_heuristic heu = set_heuristic_list_cases [heu]

fun set_default_heuristic () =
  (classic := false; set_heuristic_fun default_heuristic_fun)

fun set_default_heuristic_size () =
  (classic := false; set_heuristic_list_size default_heuristic_list)

fun set_default_heuristic_cases () =
  (classic := false; set_heuristic_list_cases default_heuristic_list)

fun set_classic_heuristic () =
  (classic := true; set_heuristic_fun classic_heuristic_fun)

(* run [f] with [heu] as the only heuristic, restoring the flags afterwards *)
fun with_heuristic heu f =
  with_flag (classic, false)
    (with_flag (pmatch_heuristic,
                pmatch_heuristic_list pmatch_heuristic_cases_cmp [heu]) f)

(* run [f] with the classic heuristic, restoring the flags afterwards *)
fun with_classic_heuristic f =
  with_flag (classic, true)
    (with_flag (pmatch_heuristic, classic_heuristic_fun) f)

(* run with the manual heuristic using the given column choices *)
fun with_manual_heuristic choices =
  with_heuristic (pheu_manual choices)

end;