1structure PmatchHeuristics :> PmatchHeuristics =
2struct
3
4open HolKernel boolSyntax;
5
6
7(*----------------------------------------------------------------------------
8      Boilerplate stuff
9 ----------------------------------------------------------------------------*)
10
(* A type-base lookup: given the theory and name of a type operator, return
   that type's case constant and list of constructors, or NONE (presumably
   when the type base has no entry for the type -- TODO confirm). *)
type thry   = {Tyop : string, Thy : string} ->
              {case_const : term, constructors : term list} option
13
(* Project the constructor list out of a type-base entry. *)
fun constructors_of ({constructors, ...} : {case_const : term, constructors : term list}) =
  constructors
15
(* Reduce a type to the {Thy, Tyop} key used to query a [thry]; the type
   arguments are dropped.  Callers in this file wrap uses in
   `handle HOL_ERR _`. *)
fun type_names ty =
  let
    val {Thy = thy, Tyop = tyop, ...} = Type.dest_thy_type ty
  in
    {Thy = thy, Tyop = tyop}
  end;
20
21
22(*----------------------------------------------------------------------------
23      Various heuristics for pattern compilation
24 ----------------------------------------------------------------------------*)
25
(* A pattern-compilation heuristic.
   - skip_rows, collapse_cases : flags handed to the pattern compiler; they
       are not interpreted in this file (presumably "quit early" and
       "collapse identical cases" -- TODO confirm against the compiler).
   - col_fun : given the type base and the current pattern matrix (a list of
       rows, each a list of patterns), return the index of the column to
       split on next. *)
type pmatch_heuristic = {skip_rows : bool, collapse_cases : bool, col_fun : thry -> term list list -> int}
27
28(* the old heuristic used by HOL 4 *)
val pheu_classic : pmatch_heuristic =
  let
    (* always split on the leftmost column *)
    fun leftmost _ _ = 0
  in
    {skip_rows = false, collapse_cases = false, col_fun = leftmost}
  end
30
(* one that always uses the first column, but quits early and collapses *)
val pheu_first_col : pmatch_heuristic =
  {col_fun = fn _ => fn _ => 0,   (* always column 0 *)
   collapse_cases = true,
   skip_rows = true}
33
(* one that always uses the last column, but quits early and collapses *)
val pheu_last_col : pmatch_heuristic =
  let
    (* index of the last column, judged by the first row; 0 when there are
       no rows at all *)
    fun last_column _ [] = 0
      | last_column _ (first_row :: _) = List.length first_row - 1
  in
    {skip_rows = true, collapse_cases = true, col_fun = last_column}
  end
37
38(* A heuristic based on ranking functions *)
(* [pheu_rank rankL] builds a heuristic from a list of ranking functions.
   Every column of the pattern matrix is ranked by the first ranking
   function and only the columns of maximal rank are kept; the survivors
   are re-ranked by the next function, and so on.  Ranking stops when the
   functions run out or at most one column is left, and the leftmost
   surviving column is chosen.  A matrix without rows or without columns
   yields column 0, as in the other column functions of this file.
   Rows are assumed to all have the same length. *)
fun pheu_rank (rankL : (thry -> term list -> int) list) : pmatch_heuristic =
  { skip_rows = true,
    collapse_cases = true,
    col_fun = (fn ty_info => fn rowL =>
      let
        (* transpose rows into columns; an empty matrix gives no columns
           instead of raising Empty *)
        fun columns acc rows =
          if null rows orelse null (hd rows) then List.rev acc
          else columns (List.map hd rows :: acc) (List.map tl rows)

        (* keep only the (index, column) pairs of maximal rank; only
           called on lists with at least two elements *)
        fun step rank ncolL =
          let
            val ranked = List.map (fn (i, pL) => ((i, pL), rank ty_info pL)) ncolL
            val best = List.foldl (fn ((_, r), m) => Int.max (r, m))
                                  (snd (hd ranked)) (tl ranked)
          in
            List.map fst (List.filter (fn (_, r) => r = best) ranked)
          end

        fun steps [] ncolL = ncolL
          | steps _ [] = []
          | steps _ [e] = [e]
          | steps (rf :: rfs) ncolL = steps rfs (step rf ncolL)
      in
        case steps rankL (Lib.enumerate 0 (columns [] rowL)) of
            [] => 0
          | ((i, _) :: _) => i
      end) }
69
70(* ranking functions *)
(* Rank 1 when the column's first pattern is not a variable, else 0
   (an empty column also ranks 0). *)
fun prheu_first_row _ pL =
  case pL of
      p :: _ => if is_var p then 0 else 1
    | [] => 0
73
(* Like prheu_first_row, but a non-variable first pattern whose type has a
   single constructor in [tybase] ranks 0; types unknown to [tybase] rank 1,
   and any HOL_ERR during the lookup gives 0. *)
fun prheu_first_row_constr tybase pL =
  case pL of
      [] => 0
    | p :: _ =>
        if is_var p then 0
        else
          (let
             val ty = snd (strip_fun (type_of p))
           in
             case tybase (type_names ty) of
                 NONE => 1
               | SOME tyinfo =>
                   if length (constructors_of tyinfo) = 1 then 0 else 1
           end handle HOL_ERR _ => 0);
79
(* Rank a column by how many non-variable patterns precede its first
   variable pattern. *)
val prheu_constr_prefix : (thry -> term list -> int) =
  fn _ => fn pL =>
    let
      fun count n (p :: rest) = if is_var p then n else count (n + 1) rest
        | count n [] = n
    in
      count 0 pL
    end
84
(* For a column pL, look up the type of its first pattern in [tybase].
   Returns SOME (cL, constrL): constrL is the type's full constructor list
   and cL the distinct (modulo aconv) constructors heading some pattern of
   the column.  NONE for an empty column, a type unknown to [tybase], or
   when a HOL_ERR is raised. *)
fun prheu_get_constr_set tybase [] = NONE
  | prheu_get_constr_set tybase (pL as p :: _) =
      (let
         val ty = snd (strip_fun (type_of p))
       in
         case tybase (type_names ty) of
             NONE => NONE
           | SOME tyinfo =>
               let
                 val constrL = constructors_of tyinfo
                 val heads = List.map (fst o strip_comb) pL
                 val used = List.filter (fn c => op_mem same_const c constrL) heads
               in
                 SOME (op_mk_set aconv used, constrL)
               end
       end handle HOL_ERR _ => NONE)
98
(* The distinct (modulo aconv) non-variable head terms of the patterns
   in pL. *)
fun prheu_get_nonvar_set pL =
  op_mk_set aconv (List.filter (not o is_var) (List.map (fst o strip_comb) pL))
107
(* Prefer columns that split into few branches: the negated number of
   constructors used, plus one for a catch-all branch when not every
   constructor appears.  Without type information, the negated number of
   distinct non-variable heads plus 2. *)
fun prheu_small_branching_factor ty_info pL =
  case prheu_get_constr_set ty_info pL of
      NONE => ~ (length (prheu_get_nonvar_set pL) + 2)
    | SOME (cL, full_constrL) =>
        let
          val n = length cL
          val extra = if n = length full_constrL then 0 else 1
        in
          ~ (n + extra)
        end
113
(* Sum of the arities of the constructors used in the column; 0 when no
   type information is available. *)
fun prheu_arity ty_info pL =
  case prheu_get_constr_set ty_info pL of
      NONE => 0
    | SOME (cL, _) =>
        let
          fun arity c = length (fst (strip_fun (type_of c)))
        in
          List.foldl (fn (c, acc) => acc + arity c) 0 cL
        end
119
120
121(* heuristics defined using ranking functions *)
val pheu_first_row : pmatch_heuristic = pheu_rank [prheu_first_row]

val pheu_constr_prefix : pmatch_heuristic = pheu_rank [prheu_constr_prefix]

(* constructor prefix, then small branching factor, then arity *)
val pheu_qba : pmatch_heuristic =
  pheu_rank [prheu_constr_prefix, prheu_small_branching_factor, prheu_arity]

(* as pheu_qba, but first prefer a constructor pattern in the first row *)
val pheu_cqba : pmatch_heuristic =
  pheu_rank [prheu_first_row_constr, prheu_constr_prefix,
             prheu_small_branching_factor, prheu_arity]
126
127(* A list of all the standard heuristics *)
val default_heuristic_list : pmatch_heuristic list =
  [pheu_qba,
   pheu_cqba,
   pheu_first_row,
   pheu_last_col,
   pheu_first_col]
129
130
131(*----------------------------------------------------------------------------
132      Heuristic funs
133 ----------------------------------------------------------------------------*)
134
135(* run multiple heuristics and take the best result *)
(* Comparison used to rank two compilation results.  A result is a pair of
   a list of pattern rows and a case term (presumably the rows of the
   compiled split and the resulting case expression -- TODO confirm). *)
type pmatch_heuristic_res_compare = (term list list * term) Lib.cmp
(* A heuristic fun: applied to (), it returns a comparison for ranking
   results together with a generator that yields the heuristics to try,
   one per call, and NONE once they are exhausted. *)
type pmatch_heuristic_fun = unit -> pmatch_heuristic_res_compare * (unit -> pmatch_heuristic option)
138
139
140(*----------------------
141  comparing the results
142 -----------------------*)
143
(* Average depth of the decision tree encoded by a case term.  A node is an
   application whose first argument is a variable (the scrutinee) followed
   by at least one branch; its depth is 1 plus the mean depth of the branch
   bodies (after stripping their bound variables).  Any other term is a
   leaf of depth 0. *)
fun average_tree_depth t =
  (case snd (strip_comb t) of
       [] => 0.0
     | scrut :: branches =>
         if is_var scrut andalso not (null branches) then
           let
             val bodies = List.map (snd o strip_abs) branches
             val total =
               List.foldl (fn (b, acc) => acc + average_tree_depth b) 0.0 bodies
           in
             (total / real (length bodies)) + 1.0
           end
         else 0.0)
  handle HOL_ERR _ => 0.0
156
(* Lexicographic combination of two comparisons: decide with ord1 and use
   ord2 only to break an EQUAL.  A comparison raising Unordered (such as
   Real.compare on NaN) is treated as EQUAL, so an unordered ord1 falls
   back to ord2 and an unordered ord2 yields EQUAL. *)
fun lex_order (ord1 : 'a cmp) (ord2 : 'a cmp) xy =
  (* The handler for ord1 must sit on the scrutinee: a `handle` written
     after the last case arm binds to that arm only, so it would never
     see an Unordered raised by ord1 itself. *)
  case (ord1 xy handle Unordered => EQUAL) of
      LESS => LESS
    | GREATER => GREATER
    | EQUAL => (ord2 xy handle Unordered => EQUAL)
163
(* Fewer pattern rows (cases) is better. *)
val pmatch_heuristic_cases_base_cmp : pmatch_heuristic_res_compare =
  fn (r1, r2) => Int.compare (length (fst r1), length (fst r2))
166
(* A smaller case term is better. *)
fun pmatch_heuristic_size_base_cmp ((_, tm1), (_, tm2)) =
  let
    val s1 = term_size tm1
    val s2 = term_size tm2
  in
    Int.compare (s1, s2)
  end
169
(* A shallower average decision-tree depth is better. *)
fun pmatch_heuristic_tree_base_cmp ((_, tm1), (_, tm2)) =
  let
    val d1 = average_tree_depth tm1
    val d2 = average_tree_depth tm2
  in
    Real.compare (d1, d2)
  end
172
(* Order by number of cases, then term size, then tree depth. *)
fun pmatch_heuristic_cases_cmp xy =
  let
    val tie_break =
      lex_order pmatch_heuristic_size_base_cmp pmatch_heuristic_tree_base_cmp
  in
    lex_order pmatch_heuristic_cases_base_cmp tie_break xy
  end
175
(* Order by term size, then tree depth, then number of cases. *)
fun pmatch_heuristic_size_cmp xy =
  let
    val tie_break =
      lex_order pmatch_heuristic_tree_base_cmp pmatch_heuristic_cases_base_cmp
  in
    lex_order pmatch_heuristic_size_base_cmp tie_break xy
  end
178
179
180
181(*-------------------------
182  try a list of heuristics
183 --------------------------*)
184
(* Heuristic fun that tries the heuristics of [l] in order and ranks the
   results with [min_fun].  Each application to () starts a fresh pass
   over [l]. *)
fun pmatch_heuristic_list min_fun l () =
  let
    val pending = ref l
    fun next () =
      case !pending of
          h :: rest => (pending := rest; SOME h)
        | [] => NONE
  in
    (min_fun, next) : pmatch_heuristic_res_compare * (unit -> pmatch_heuristic option)
  end
191
val default_heuristic_fun : pmatch_heuristic_fun =
  pmatch_heuristic_list pmatch_heuristic_cases_cmp default_heuristic_list;

val classic_heuristic_fun : pmatch_heuristic_fun =
  pmatch_heuristic_list pmatch_heuristic_cases_cmp [pheu_classic];
194
195
196(*-----------------------------------
197  an exhaustive search heuristic-fun
198 ------------------------------------*)
199
(* Build a heuristic fun that enumerates every column choice.

   Each generated heuristic first replays a recorded prefix of choices,
   given as (branching?, column, number of columns) triples.  Once the
   prefix is used up, it records every new choice it makes:
   - For a matrix with no rows or fewer than two columns, column 0 is
     taken without branching.
   - For a matrix containing a column made only of variables, that column
     is taken without branching.
   - Otherwise column 0 is taken, and a new heuristic is queued for every
     other column, with that alternative appended to the prefix.
   Branching choices are printed as "k/n" (1-based column k of n).  The
   comparison returned by init prints the number of cases of each result
   before delegating to [cmp]. *)
fun exhaustive_heuristic_fun cmp =
let
  val heuristicL_ref = ref ([]:pmatch_heuristic list)
  fun add_heu heu = (heuristicL_ref := heu :: (!heuristicL_ref))

  (* a heuristic that replays [prefix], then explores and records *)
  fun heu (prefix : (bool * int * int) list) : pmatch_heuristic =
  let
    val current_prefix = ref prefix
    val remaining_prefix = ref prefix

    (* append [choice] to the recorded prefix and return it *)
    fun record choice = (current_prefix := (!current_prefix) @ [choice]; choice)

    fun colfun_print thry rowL =
      case (!remaining_prefix) of
          (i :: is) => (remaining_prefix := is; i)
        | [] =>
            let
              val ncols = case rowL of [] => 0 | (r :: _) => List.length r
              fun all_vars n = List.all (fn r => is_var (List.nth (r, n))) rowL
              fun var_col n _ = if all_vars n then SOME n else NONE
              (* No real choice for an empty matrix or a single column.
                 Otherwise look for an all-variable column, scanning the
                 columns (the entries of the first row), not the rows. *)
              val forced =
                if ncols < 2 then SOME 0 else Lib.first_opt var_col (hd rowL)
            in
              case forced of
                  SOME r => record (false, r, 1)
                | NONE =>
                    (* queue one heuristic per alternative column
                       1 .. ncols-1, then continue with column 0 here *)
                    (Lib.appi (fn i => fn _ =>
                        add_heu (heu ((!current_prefix) @ [(true, i+1, ncols)])))
                       (tl (hd rowL));
                     record (true, 0, ncols))
            end

    fun colfun thry rowL = let
      val (do_it, r, l) = colfun_print thry rowL
      val _ = if do_it then (print (int_to_string (r+1)); print "/"; print (int_to_string l); print " ") else ()
    in
      r
    end
  in
    { skip_rows = true, collapse_cases = true, col_fun = colfun }
  end

  fun next_heu () =
    case (!heuristicL_ref) of
       [] => NONE
     | (h :: hs) => (print "\n"; heuristicL_ref := hs; SOME h)

  (* restart the search from a single root heuristic *)
  fun init () =
  let
    val _ = heuristicL_ref := [heu []]
    fun cmp' (r1, r2) = (
       print "\nCases in Result: ";
       print (int_to_string (length (fst r2)));
       print " (best so far ";
       print (int_to_string (length (fst r1)));
       print ")";
       cmp (r1, r2)
    )
  in
    (cmp', next_heu)
  end
in
  init
end
269
270(* A manual one taking an explicit order *)
(* A manual heuristic following an explicit list of column choices.
   Matrices with no rows or fewer than two columns use column 0, and a
   column consisting only of variables is chosen automatically.  At every
   other choice point the next entry of [res_input] is consumed:
   - If it is a valid column index (0 <= i < number of columns), it is used.
   - Otherwise an error is printed and column 0 is used.
   Once [res_input] is exhausted, column 0 is used.  Each such choice
   prints the pattern matrix and the column used. *)
fun pheu_manual (res_input : int list) : pmatch_heuristic =
let
  fun pr ts = let
    val ts_sl = List.map (fn t => "``" ^ (Hol_pp.term_to_string t) ^ "``") ts
    val ts_s = String.concatWith ", " ts_sl
  in
    print ts_s;
    print "\n"
  end
  val res = ref res_input
  fun man_choice rowL =
    let
      val r = case (!res) of
          [] => 0
        | (i::is) =>
              let
                val _ = res := is
              in
                (* reject negative indices as well as too-large ones *)
                if 0 <= i andalso i < List.length (List.hd rowL) then i else
                   (print ("Error, can't use "^(int_to_string i)^"\n");0)
              end
      val _ = List.app pr rowL
      val _ = print ("Using "^(int_to_string r)^"\n\n\n")
    in
      r
    end

  fun colfun thry rowL =
    if ((null rowL) orelse (length (hd rowL) < 2)) then 0 else
    let
      fun all_vars n = List.all (fn r => is_var (List.nth (r, n))) rowL
      fun col_fun n _ = if all_vars n then SOME n else NONE
    in
      (* scan the columns (entries of the first row), not the rows *)
      case (Lib.first_opt col_fun (hd rowL)) of
          SOME r => r
        | NONE => man_choice rowL
    end
in
  { skip_rows = true, collapse_cases = true, col_fun = colfun }
end
312
313
314(*----------------------------------------------------------------------------
315      A reference to store the current heuristic_fun
316 ----------------------------------------------------------------------------*)
317
(* The heuristic fun currently in use. *)
val pmatch_heuristic : pmatch_heuristic_fun ref = ref default_heuristic_fun

(* Whether the classic heuristic is selected. *)
val classic : bool ref = ref false

fun is_classic () : bool = ! classic
321
fun set_heuristic_fun heu_fun = (pmatch_heuristic := heu_fun)

local
  (* install a heuristic fun trying [heuL] in order, ranked by [cmp] *)
  fun install_list cmp heuL = set_heuristic_fun (pmatch_heuristic_list cmp heuL)
in
  fun set_heuristic_list_size heuL = install_list pmatch_heuristic_size_cmp heuL
  fun set_heuristic_list_cases heuL = install_list pmatch_heuristic_cases_cmp heuL
end

fun set_heuristic heu = set_heuristic_list_cases [heu]
326
local
  (* leave classic mode, then run the installer *)
  fun non_classic install = (classic := false; install ())
in
  fun set_default_heuristic () =
    non_classic (fn () => set_heuristic_fun default_heuristic_fun)

  fun set_default_heuristic_size () =
    non_classic (fn () => set_heuristic_list_size default_heuristic_list)

  fun set_default_heuristic_cases () =
    non_classic (fn () => set_heuristic_list_cases default_heuristic_list)
end

fun set_classic_heuristic () =
  (classic := true; set_heuristic_fun classic_heuristic_fun)
338
(* Run f with [heu] as the only heuristic (results ranked by number of
   cases) and classic mode off, via with_flag (presumably restoring the
   previous settings afterwards). *)
fun with_heuristic heu f =
  let
    val heu_fun = pmatch_heuristic_list pmatch_heuristic_cases_cmp [heu]
  in
    with_flag (classic, false) (with_flag (pmatch_heuristic, heu_fun) f)
  end

(* Run f with the classic heuristic and classic mode on. *)
fun with_classic_heuristic f =
  let
    val inner = with_flag (pmatch_heuristic, classic_heuristic_fun) f
  in
    with_flag (classic, true) inner
  end

(* Run with a manual heuristic following [choices]; the manual heuristic is
   built once per application to [choices]. *)
fun with_manual_heuristic choices =
  let
    val heu = pheu_manual choices
  in
    with_heuristic heu
  end
350
351end;
352