1
2structure patternMatch =
3struct
4
(*****************************************************************************************)
(* Decision Tree for pattern matching.                                                   *)
(* When compiling a function in clausal form, we build a decision tree which determines  *)
(* the order in which subterms of any term are to be examined to find the first pattern  *)
(* that matches that term. We attempt to make this decision tree optimal or minimal, in  *)
(* the sense that the order imposed on subterm-testing is such that the matching pattern *)
(* can be found with a minimum number of tests. Each node of a decision tree represents  *)
(* a test that can be carried out on a sub-term discriminating between constructor cases.*)
(* The branches coming out of a node correspond to the possible results of the test      *)
(* performed at that node (i.e. possible constructors for that subterm, which are        *)
(* determined by the type of the subterm). Each branch is labeled with a type constructor*)
(* and a set of pattern indices representing the patterns that were still possibly       *)
(* matching (live) before the test and that have this constructor as result of the test. *)
(* At run-time, when a value term has to be matched against one of the patterns, the     *)
(* code executed corresponds to going down the decision tree from the root to one of the *)
(* leaves, executing the tests corresponding to the test nodes along the path.           *)
(*****************************************************************************************)
22
23open HolKernel Parse;
24
25structure S = Binaryset
26
27(*****************************************************************************************)
28(* Decision Tree.                                                                        *)
29(*****************************************************************************************)
30
(* One position of a pattern row. The parameter 'a is the type used to name     *)
(* variables and constructors; rule_type instantiates it with string.           *)
datatype 'a subterm
  = Wildcard             (* the anonymous pattern "_" *)
  | Var of 'a            (* a pattern variable *)
  | Constr of 'a         (* a single constructor *)
  | ConS of 'a S.set     (* a set of alternative constructors *)
35
(* A rule: an integer label paired with its list of string-named sub-patterns.  *)
type rule_type = int * string subterm list

(* A test: an integer paired with a constructor name. The integer is used as a  *)
(* position into a rule's sub-pattern list (see relevant_indices, inst_rules).  *)
type test_type = int * string
38
(* Decision tree. The parameter 'a only appears in the recursive sub-tree       *)
(* fields; neither constructor stores a value of that type.                     *)
datatype 'a tree
  = Node of { tset  : test_type S.set,   (* tests available at this node *)
              rlist : rule_type list,    (* rules live at this node *)
              ledge : test_type,         (* test labelling the left out-going edge *)
              left  : 'a tree,           (* left sub-tree *)
              right : 'a tree }          (* right sub-tree *)
  | Leaf of rule_type list               (* rules left when no further test is chosen *)
45
(* Total order on names: the standard lexicographic ordering of strings.        *)
(* Returns LESS, EQUAL or GREATER for name1 relative to name2.                  *)
fun strOrder (name1 : string, name2 : string) = String.compare (name1, name2)
50
(* Total order on tests: compare the integer components first and fall back to  *)
(* the lexicographic order of the names only when the integers are equal.       *)
fun testOrder ((index1, name1) : int * string, (index2, name2) : int * string) =
    case Int.compare (index1, index2) of
        EQUAL => String.compare (name1, name2)
      | decided => decided
58
59(*****************************************************************************************)
60(* Attempt to choose the optimal test.                                                   *)
61(*****************************************************************************************)
62
(* Does the sub-pattern accept the constructor named v?                         *)
(*   Constr c  accepts v exactly when c = v;                                    *)
(*   ConS cs   accepts v exactly when v is a member of cs;                      *)
(*   anything else (Wildcard, Var) accepts every v.                             *)
fun match_constr (pat, v) =
    case pat of
        Constr c => c = v
      | ConS cs => S.member (cs, v)
      | _ => true
66
(* Return, in the order given by S.listItems, the tests (index, name) of        *)
(* test_set for which both of the following hold:                               *)
(*   - index equals the first component of the rule, and                        *)
(*   - the rule's sub-pattern at position index does not accept name.           *)
(* NOTE(review): the first condition compares a test position with the rule's   *)
(* integer label; confirm that restricting each rule to the single position     *)
(* equal to its own label is intended.                                          *)
fun relevant_indices (test_set, rule) =
    let
      val (label, pats) = rule
      fun is_relevant (index, name) =
          index = label andalso not (match_constr (List.nth (pats, index), name))
    in
      List.filter is_relevant (S.listItems test_set)
    end
77
(* Walk the rules from the last one to the first (rev rules) and return the     *)
(* result of relevant_indices for the first rule where it is non-empty.         *)
(* Returns [] when no rule has a relevant test.                                 *)
fun next_indices (test_set, rules) =
    let
      fun scan [] = []
        | scan (rule :: rest) =
            (case relevant_indices (test_set, rule) of
                 [] => scan rest
               | found => found)
    in
      scan (rev rules)
    end
93
94(*****************************************************************************************)
95(* Eliminate redundant rules in a rule list.                                             *)
96(*****************************************************************************************)
97
(* Pair up the elements of two lists position by position.                      *)
(* The lists must have the same length; otherwise Fail is raised with a         *)
(* descriptive message (the match is exhaustive, so no anonymous Match can      *)
(* escape from here).                                                           *)
fun zip ([], []) = []
  | zip (x :: xs, y :: ys) = (x, y) :: zip (xs, ys)
  | zip _ = raise Fail "patternMatch.zip: lists have different lengths"
100
(* Drop every rule that is covered by a rule kept before it.                    *)
(* Rules are visited left to right; a rule is appended to the result unless     *)
(* some already-kept rule covers it, so the relative order of the survivors is  *)
(* preserved. Only the sub-pattern lists (second components) are compared.      *)
fun elim_redundant_rules rules =
    let
      (* position-wise test: does sub-pattern p of the earlier rule cover q? *)
      fun covers_at (Constr x, Constr y) = x = y
        | covers_at (Constr x, ConS s) = S.listItems s = [x]
        | covers_at (Constr _, _) = false          (* Wildcard or variables *)
        | covers_at (ConS s1, Constr y) = S.member (s1, y)
        | covers_at (ConS s1, ConS s2) = S.isSubset (s2, s1)
        | covers_at (ConS _, _) = false            (* Wildcard or variables *)
        | covers_at _ = true                       (* p is Wildcard or Var *)

      (* row1 covers row2 when every position of row1 covers that of row2 *)
      fun cover row1 row2 = List.all covers_at (zip (row1, row2))

      fun keep (candidate as (_, row2), kept) =
          if List.exists (fn (_, row1) => cover row1 row2) kept
          then kept                                (* an earlier rule consumes it *)
          else kept @ [candidate]
    in
      List.foldl keep [] rules
    end
128
(*****************************************************************************************)
(* Modify the rule by instantiating it with constructor information.                     *)
(*****************************************************************************************)
132
(* Return the rule with the sub-pattern at position index replaced by subterm.  *)
(* The rule's integer label is kept. List.take / List.drop raise Subscript      *)
(* when index does not denote a position of the sub-pattern list.               *)
fun inst_rule (rule : rule_type, index, subterm) =
    let
      val (label, pats) = rule
      val front = List.take (pats, index)
      val back = List.drop (pats, index + 1)
    in
      (label, front @ (subterm :: back))
    end
135
(* Specialise rules to the constructors named by test_set.                      *)
(*                                                                              *)
(* The position examined is the integer of the first test in S.listItems        *)
(* test_set; the constructor set is formed from the names of all tests in       *)
(* test_set. For each rule, in order, the sub-pattern at that position decides: *)
(*   Constr x : the rule is kept unchanged if x is in the constructor set,      *)
(*              otherwise it is dropped;                                        *)
(*   ConS s   : the rule is dropped if s and the constructor set are disjoint,  *)
(*              otherwise the position is narrowed to their intersection;       *)
(*   other    : the position is replaced by the constructor set.                *)
(* A one-element set is always stored as Constr rather than ConS.               *)
(*                                                                              *)
(* An empty test_set names no constructor, so no rule survives and [] is        *)
(* returned instead of failing on the head of an empty list.                    *)
(* List.nth raises Subscript if a rule is shorter than the examined position.   *)
fun inst_rules (rules : rule_type list, test_set) =
    case S.listItems test_set of
        [] => []
      | tests as ((index, _) :: _) =>
        let
          val constr_set = S.addList (S.empty strOrder, List.map snd tests)

          (* canonical sub-pattern for a non-empty constructor set *)
          fun as_subterm cs =
              case S.listItems cs of
                  [c] => Constr c
                | _ => ConS cs

          fun specialise (rule : rule_type) =
              case List.nth (#2 rule, index) of
                  Constr x =>
                    if S.member (constr_set, x) then SOME rule else NONE
                | ConS s =>
                    let
                      val sub_s = S.intersection (constr_set, s)
                    in
                      (* narrow to the intersection: constructors outside s *)
                      (* were never admitted by this rule                   *)
                      if S.isEmpty sub_s then NONE
                      else SOME (inst_rule (rule, index, as_subterm sub_s))
                    end
                | _ => SOME (inst_rule (rule, index, as_subterm constr_set))
        in
          List.mapPartial specialise rules
        end
160
(*****************************************************************************************)
(* Build the decision tree in a top-down manner.                                         *)
(* A test on a subterm is relevant to a pattern pi if and only if pi does not agree with *)
(* all the possible values on that subterm. In terms of decision-trees, a test on a sub- *)
(* term is relevant to a pattern pi if and only if i does not appear in the set of live  *)
(* rule indices which label each successor of that test. Given a set of possible tests   *)
(* tset and a set of live rule indices rset, the relevance heuristic searches for the    *)
(* least index i in rset such that at least one test in tset is relevant to pi. If there *)
(* is no such index, no test in tset is relevant to any pattern in rset and one has      *)
(* reached a leaf of the tree. Otherwise, one computes the subset trel of tset           *)
(* containing the tests that are relevant to pi. If trel is a singleton, its element is  *)
(* the desired next test. Otherwise, the next heuristic selection is applied on          *)
(* trel.                                                                                 *)
(*****************************************************************************************)
175
(* Build a decision tree for rules using the tests in test_set.                 *)
(*                                                                              *)
(* next_indices proposes candidate tests; with no candidate the result is       *)
(* Leaf rules. Otherwise the first candidate becomes the test of a Node:        *)
(*   - the left branch specialises the rules to that single test;               *)
(*   - the right branch specialises them to the remaining tests of test_set     *)
(*     that share the chosen test's position.                                   *)
(* Both branches recurse on test_set with the chosen test removed, after        *)
(* redundant rules have been eliminated from each specialised rule list.        *)
fun build_tree (test_set, rules) =
    case next_indices (test_set, rules) of
        [] => Leaf rules
      | (test as (position, _)) :: _ =>
        let
          val next_set = S.delete (test_set, test)
          val left_set = S.add (S.empty testOrder, test)
          val same_position =
              List.filter (fn (n, _) => n = position) (S.listItems next_set)
          val right_set = S.addList (S.empty testOrder, same_position)

          val left_rules = elim_redundant_rules (inst_rules (rules, left_set))
          val right_rules = elim_redundant_rules (inst_rules (rules, right_set))
        in
          Node { tset = test_set,
                 rlist = rules,
                 ledge = test,
                 left = build_tree (next_set, left_rules),
                 right = build_tree (next_set, right_rules) }
        end
201
202(* val test_set = S.addList(S.empty testOrder,
203         [(0,"nil"), (0,"cons"), (1,"nil"), (1,"cons")]);
204   val rules = [(0, [Constr "nil", Var ""]), (1, [Var "", Constr "nil"])];
205
206   build_tree(test_set, rules)
207
208*)
209
210(*****************************************************************************************)
211(* Branching Factor Heuristic.                                                           *)
212(* The branching factor heuristic tries to minimize the number of test-nodes by favoring *)
213(* the choice of tests with low branching factor first.                                  *)
214(*****************************************************************************************)
215
(* Branching-factor selection among candidate tests.                            *)
(*                                                                              *)
(* test_set : the set of available tests (position, name)                       *)
(* indices  : the candidate tests to choose from                                *)
(*                                                                              *)
(* The branching factor of a candidate is the number of tests in test_set that  *)
(* share the candidate's own position. The candidate with the smallest          *)
(* branching factor is returned; on a tie the earliest candidate in indices     *)
(* wins. A single candidate is returned as is.                                  *)
(*                                                                              *)
(* The fold is seeded with the first candidate, so the result is always an      *)
(* element of indices. Raises Option when indices is empty.                     *)
fun select_index (test_set, indices) =
    case indices of
        [] => raise Option
      | [only] => only
      | first :: rest =>
        let
          val test_list = S.listItems test_set

          (* number of tests in test_set at the candidate's position *)
          fun branching (position, _) =
              length (List.filter (fn (p, _) => p = position) test_list)

          fun pick (candidate, best as (_, best_n)) =
              let val n = branching candidate
              in  if n < best_n then (candidate, n) else best
              end

          val (chosen, _) = List.foldl pick (first, branching first) rest
        in
          chosen
        end
234
235(* val test_set = S.addList(S.empty testOrder,
236         [(0,"true"), (0,"false"), (1,"red"), (1,"blue"), (1,"green")]);
237   val test_set = S.addList(S.empty testOrder,
238         [(0,"true"), (0,"false"), (1,"green"), (1,"red/blue")]);
239   val rules = [(0, [Constr "true", Constr "green"]), (1, [Constr "false", Constr "green"])];
240
241   build_tree(test_set, rules)
242
243*)
244
245end
246
247(*
248Define `(f 0 _ = 0) /\
249          (f (SUC i) 0 = i) /\
250          (f (SUC i) (SUC j) = i + j)
251         `;
252*)
253