1(* ===================================================================== *)
2(* FILE          : BoolExtractShared                                     *)
3(* DESCRIPTION   : Tools to extract shared terms from a boolean          *)
4(*                 expression                                            *)
5(*                                                                       *)
6(* AUTHORS       : Thomas Tuerk                                          *)
7(* DATE          : July 3, 2008                                          *)
8(* ===================================================================== *)
9
10
11structure  BoolExtractShared :>  BoolExtractShared =
12struct
13
14(*
15quietdec := true;
16*)
17
18open HolKernel Parse boolLib
19
20(*
21quietdec := false;
22*)
23
24
(*---------------------------------------------------------------------------
 * dest_neg_eq ``~(a = b)`` = (``a``, ``b``)
 * Fails (through dest_neg / dest_eq) on any term that is not a negated
 * equation; is_neg_eq tests whether dest_neg_eq succeeds.
 *---------------------------------------------------------------------------*)
  fun dest_neg_eq t =
    let val eq_tm = dest_neg t in dest_eq eq_tm end;
  fun is_neg_eq t = can dest_neg_eq t;
30
31
(*---------------------------------------------------------------------------
 * mk_neg___idempot a  = ~a
 * mk_neg___idempot ~a = a
 * where a is not itself a negation: an existing outer negation is removed
 * instead of a second one being added.
 *---------------------------------------------------------------------------*)
  fun mk_neg___idempot t =
    if not (is_neg t) then mk_neg t
    else dest_neg t;
39
40
(*---------------------------------------------------------------------------
 * eq_sym_aconv t1 t2 tests whether t1 and t2 are alpha-convertible modulo
 * the symmetry of equality: besides plain alpha-convertibility it accepts
 * (a = b) against (b = a) and ~(a = b) against ~(b = a).
 *---------------------------------------------------------------------------*)
  fun eq_sym_aconv t1 t2 =
      let
         (* both terms pass test, and after dest their sides are swapped
            copies of each other (up to alpha-conversion) *)
         fun swapped_aconv test dest =
             test t1 andalso test t2 andalso
             let
                val (l1, r1) = dest t1
                val (l2, r2) = dest t2
             in
                aconv r1 l2 andalso aconv l1 r2
             end
      in
         aconv t1 t2 orelse
         swapped_aconv is_eq dest_eq orelse
         swapped_aconv is_neg_eq dest_neg_eq
      end;
62
63
64
65
66  fun eq_sym_mem e [] = false
67    | eq_sym_mem e (h::l) =
68      (eq_sym_aconv e h) orelse eq_sym_mem e l;
69
70
71  fun findMatches ([], l2) = []
72    | findMatches (a::l1, l2) =
73         let val l1' = filter (fn e => not (eq_sym_aconv e a)) l1;
74             val l2' = filter (fn e => not (eq_sym_aconv e a)) l2;
75             val l = (findMatches (l1',l2')); in
76         if eq_sym_mem a l2 then a::l else l end;
77
78  fun find_negation_pair [] = NONE |
79      find_negation_pair (e::l) =
80      if eq_sym_mem (mk_neg___idempot e) l then SOME e else
81      find_negation_pair l;
82
83
84  fun dest_quant t = dest_abs (snd (dest_comb t));
85  fun is_quant t = is_forall t orelse is_exists t orelse
86                 is_exists1 t;
87
88
89  fun findMatches___multiple (l1:(term * bool) list,l2:(term * bool) list) =
90      map (fn t => (t, true)) (findMatches (map fst l1, map fst l2));
91
92
93  (*get_impl_terms returns a list of terms that imply the whole term and
94    a list of terms that are implied. Thus if get_impl_terms t = (disjL,conjL) then
95
96    forall x in disjL. x implies t and
97    forall x in conjL. t implies x holds.
98
99
100    so e. g. get_impl_terms ``a \/ b \/ ~c`` returns
101     ([``a \/ b \/ ~c``, ``a``, ``b \/ ~c``, ``b``, ``~c``], [``a \/ b \/ ~c``])
102
103
104    get_impl_terms___multiple augments the results with information, whether
105    multiple occurences of a term have been abbriated to one.
106   *)
107  fun get_impl_terms___multiple t =
108      if is_disj t then
109          (let val (t1,t2)=dest_disj t;
110               val (l11,l12)= get_impl_terms___multiple t1;
111               val (l21,l22)= get_impl_terms___multiple t2;
112           in
113              ((t,false)::(l11 @ l21), (t,false)::findMatches___multiple (l12, l22))
114           end)
115      else
116      if is_conj t then
117          (let val (t1,t2)=dest_conj t;
118               val (l11,l12)= get_impl_terms___multiple t1;
119               val (l21,l22)= get_impl_terms___multiple t2;
120           in
121              ((t,false)::findMatches___multiple (l11, l21), (t,false)::(l12 @ l22))
122           end)
123      else
124      if is_neg t then
125          (let val (l1,l2) = get_impl_terms___multiple (dest_neg t) in
126              (map (fn (t,b) => (mk_neg___idempot t, b)) l2, map (fn (t,b) => (mk_neg___idempot t, b)) l1)
127          end)
128      else
129      if is_imp t then
130          (let val (t1,t2)=dest_imp_only t;
131               val (l11',l12')= get_impl_terms___multiple t1;
132               val (l11, l12) = (map (fn (t,b) => (mk_neg___idempot t, b)) l12',
133                                 map (fn (t,b) => (mk_neg___idempot t, b)) l11')
134               val (l21,l22)= get_impl_terms___multiple t2;
135           in
136              ((t,false)::(l11 @ l21), (t,false)::findMatches___multiple (l12, l22))
137           end)
138      else
139      if is_quant t then
140          (let
141              val (v, b) = dest_quant t;
142              val (l1,l2) = get_impl_terms___multiple b;
143              fun filter_pred (t,b) = not (mem v (free_vars t));
144          in
145              ((t,false)::(filter filter_pred l1), (t,false)::(filter filter_pred l2))
146          end)
147      else
148        (if same_const T t orelse same_const F t then ([],[]) else
149           ([(t,false)],[(t,false)]));
150
151
152  fun clean_term_multiple_list [] = []
153    | clean_term_multiple_list ((t,b)::L) =
154      if (mem t (map fst L)) then
155         (t,true)::clean_term_multiple_list (filter (fn (t',b') => not (t = t')) L)
156      else
157         (t,b)::clean_term_multiple_list L;
158
159
160  fun get_impl_terms t =
161     let
162        val (l1,l2) = get_impl_terms___multiple t;
163     in
164        (map fst l1, map fst l2)
165     end
166
167
168
169
170
(* Rewrite theorems that hold under the assumption rewr (each carries rewr
   as a hypothesis, via ASSUME):
     a = b     : [(a = b) = T, (b = a) = T]
     ~(a = b)  : [~(a = b), ~(b = a)]
     otherwise : [rewr]
   Both orientations are produced so that a match found in the symmetric
   form (see eq_sym_aconv) still rewrites. *)
fun get_rewrite_assumption_thms rewr =
   let
      val assumed = ASSUME rewr
      fun both_ways thm = [thm, GSYM thm]
   in
      if is_eq rewr then map EQT_INTRO (both_ways assumed)
      else if is_neg_eq rewr then both_ways assumed
      else [assumed]
   end
179
(* case_split_REWRITE_CONV matches t tries to prove |- t = T.  It case-splits
   on every term of matches (with outer negations stripped) and rewrites t
   with REWRITE_CONV under each combination of assumptions.  The splits are
   discharged with EXCLUDED_MIDDLE / DISJ_CASES, so the assumptions do not
   remain as hypotheses.

   The result is always an equation with lhs t.  If some branch does not
   rewrite to T, a HOL_ERR is raised, either here (TRANS / DISJ_CASES) or by
   the caller's EQT_ELIM.  UNCHANGED raised by the base-case REWRITE_CONV is
   not caught.  The number of rewrites is exponential in length matches. *)
fun case_split_REWRITE_CONV [] t =
    (* Base case: no splits left, so just rewrite.  Return the equation
       itself (not EQT_ELIM of it), so that the TRANS in rec_prove below
       composes |- t = r with |- r = r'. *)
    REWRITE_CONV [] t
  | case_split_REWRITE_CONV (m::matches) t =
    let
       (* rewrite t under assumption rewr; if that does not reach T, split
          further on the remaining matches *)
       fun rec_prove rewr =
       let
          val match_thms = get_rewrite_assumption_thms rewr;
          val thm = REWRITE_CONV match_thms t handle UNCHANGED => REFL t;
          val r = rhs (concl thm);
       in
          if (r = T) then thm else
          TRANS thm (case_split_REWRITE_CONV matches r)
       end;

       val m_no_neg = fst (strip_neg m);
       val thm1 = rec_prove m_no_neg;
       val thm2 = rec_prove (mk_neg m_no_neg);

       (* |- m \/ ~m discharges the two case assumptions *)
       val disj_thm = SPEC m_no_neg EXCLUDED_MIDDLE
    in
       DISJ_CASES disj_thm thm1 thm2
    end;
204
205
206
207
(* bool_eq_imp_real_imp_CONV matches t rewrites t under the assumptions
   matches to some term c, then proves
      |- t = (m1 /\ ... /\ mn ==> c)      (or |- t = T when c is T)
   by case splitting on the matches.  Raises UNCHANGED when c is F or when
   the target term is identical to t.  (BOOL_EQ_IMP_CONV calls this with a
   boolean equation t.) *)
fun bool_eq_imp_real_imp_CONV matches t =
   let
      val rewr_thms = flatten (map get_rewrite_assumption_thms matches);
      val c = rhs (concl (REWRITE_CONV rewr_thms t));
   in
      if c = F then raise UNCHANGED
      else
         let
            val target = if c = T then T
                         else mk_imp (list_mk_conj matches, c);
         in
            if t = target then raise UNCHANGED
            else EQT_ELIM (case_split_REWRITE_CONV matches (mk_eq (t, target)))
         end
   end;
222
223
224
225
226
227
(* bool_extract_common_terms_internal_CONV disj matches t pulls matches out
   of t.
     disj = true : the matches are assumed false (negated with
                   mk_neg___idempot), t rewrites to c, and
                   |- t = c \/ m1 \/ ... \/ mn  is proved (|- t = T if c is T).
     disj = false: the matches are assumed true, t rewrites to c, and
                   |- t = c /\ m1 /\ ... /\ mn  is proved (|- t = F if c is F).
   The equation is proved by case splitting on the matches.  Raises
   UNCHANGED when the target term is identical to t. *)
fun bool_extract_common_terms_internal_CONV disj matches t =
   let
      val assumptions = if disj then map mk_neg___idempot matches else matches
      val rewr_thms = flatten (map get_rewrite_assumption_thms assumptions);
      val c = rhs (concl (REWRITE_CONV rewr_thms t));
      (* the constant that absorbs the connective, and the list builder *)
      val (absorbing, list_mk) = if disj then (T, list_mk_disj)
                                 else (F, list_mk_conj)
      val target = if c = absorbing then absorbing
                   else list_mk (c :: matches)
   in
      if t = target then raise UNCHANGED
      else EQT_ELIM (case_split_REWRITE_CONV matches (mk_eq (t, target)))
   end;
248
249
250
251
252
253
(*Cleans up the found matches by keeping just the simplest ones.
  clean_disj_matches removes each term that is implied by some other term
  of the list, and clean_conj_matches removes each term that implies some
  other term of the list (implications as detected by get_impl_terms).*)
258
(* clean_disj_matches ts acc drops every term t of ts for which one of the
   terms implying t (according to get_impl_terms) occurs among the remaining
   elements of ts or the already kept terms acc.  Kept terms are prepended
   to acc, so the survivors come out in reverse order. *)
fun clean_disj_matches ts acc =
   case ts of
       [] => acc
     | t :: rest =>
       let
          val (implying_t, _) = get_impl_terms t
          val redundant = not (null_intersection implying_t (rest @ acc))
       in
          clean_disj_matches rest (if redundant then acc else t :: acc)
       end;
270
271
(* clean_conj_matches ts acc drops every term t of ts for which one of the
   terms implied by t (according to get_impl_terms) occurs among the
   remaining elements of ts or the already kept terms acc.  Kept terms are
   prepended to acc, so the survivors come out in reverse order. *)
fun clean_conj_matches ts acc =
   case ts of
       [] => acc
     | t :: rest =>
       let
          val (_, implied_by_t) = get_impl_terms t
          val redundant = not (null_intersection implied_by_t (rest @ acc))
       in
          clean_conj_matches rest (if redundant then acc else t :: acc)
       end;
283
284
285
286
287
288
289
290
(*---------------------------------------------------------------------------
 * Given an equation with boolean expressions on both sides (b1 = b2),
 * this conversion tries to extract common parts of b1 and b2 into a
 * precondition.
 *
 * e.g.          (A \/ B \/ C) = (A \/ D) is converted to
 *       ~A ==> ((     B \/ C) =       D )
 *
 * Terms implying both sides become negated preconditions; terms implied by
 * both sides become positive preconditions.  Raises HOL_ERR if the term is
 * not an equation between booleans, and UNCHANGED if no common part exists.
 *---------------------------------------------------------------------------*)

fun BOOL_EQ_IMP_CONV t =
   let
      val (l,r) = dest_eq t;
      val _ = if type_of l = bool then ()
              else raise mk_HOL_ERR "BoolExtractShared" "BOOL_EQ_IMP_CONV"
                                    "not an equation between boolean terms";
      val (disj_l, conj_l) = get_impl_terms l;
      val (disj_r, conj_r) = get_impl_terms r;

      (* if one of these holds, both sides are true *)
      val disj_matches = clean_disj_matches (findMatches (disj_l, disj_r)) [];
      (* if one of these fails, both sides are false *)
      val conj_matches = clean_conj_matches (findMatches (conj_l, conj_r)) [];

      val matches = (map mk_neg___idempot disj_matches) @ conj_matches;
      val _ = if null matches then raise UNCHANGED else ();
   in
      bool_eq_imp_real_imp_CONV matches t
   end;
315
316
317
318
(*---------------------------------------------------------------------------
 * Tries to convert a boolean expression to true or false by
 * searching for a case split that will prove this expression.
 *
 * e.g. this conversion is able to prove
 *        (A \/ B \/ ~A) = T
 *   or   (A \/ B \/ ((D /\ A) ==> C)) = T
 *   or   (A /\ B /\ ~A) = F
 *
 * If some e and ~e both imply the term, it is T; otherwise, if some e and
 * ~e are both implied by it, it is F.  Raises HOL_ERR on non-boolean terms
 * and UNCHANGED when no such pair is found.
 *---------------------------------------------------------------------------*)


fun BOOL_NEG_PAIR_CONV t =
   let
      val _ = if type_of t = bool then ()
              else raise mk_HOL_ERR "BoolExtractShared" "BOOL_NEG_PAIR_CONV"
                                    "not a boolean term";
      val (disj_t, conj_t) = get_impl_terms t;

      (* the term to case split on, and the value t is proved equal to *)
      val (split, result) =
          case find_negation_pair disj_t of
              SOME e => (e, T)
            | NONE =>
              (case find_negation_pair conj_t of
                   SOME e => (e, F)
                 | NONE => raise UNCHANGED);

      val thm_term = mk_eq (t, result);
   in
      EQT_ELIM (case_split_REWRITE_CONV [split] thm_term)
   end;
347
348
349
(*---------------------------------------------------------------------------
 * Tries to extract parts of a boolean term that occur several times.
 *
 * e.g.   (D /\ A /\ ~C) \/ (A /\ B /\ ~C) = (D \/ B) /\ ~C /\ A
 *   or   A \/ B \/ A = B \/ A
 *
 * Shared terms implying the whole term are pulled out as disjuncts;
 * otherwise shared terms implied by it are pulled out as conjuncts.
 * Raises HOL_ERR on non-boolean terms and UNCHANGED if nothing is shared.
 *---------------------------------------------------------------------------*)
fun BOOL_EXTRACT_SHARED_CONV t =
   let
      val _ = if type_of t = bool then ()
              else raise mk_HOL_ERR "BoolExtractShared" "BOOL_EXTRACT_SHARED_CONV"
                                    "not a boolean term";
      val (disj_t___multiple,conj_t___multiple) = get_impl_terms___multiple t;
      val disj_t___multiple = clean_term_multiple_list disj_t___multiple;
      val conj_t___multiple = clean_term_multiple_list conj_t___multiple;
      (* keep only the terms flagged as occurring more than once *)
      val disj_t = clean_disj_matches (map fst (filter snd disj_t___multiple)) [];
      val conj_t = clean_conj_matches (map fst (filter snd conj_t___multiple)) [];
   in
      if (not (null disj_t)) then
         bool_extract_common_terms_internal_CONV true disj_t t
      else if (not (null conj_t)) then
         bool_extract_common_terms_internal_CONV false conj_t t
      else raise UNCHANGED
   end;
372
373
374
375
(* simpfrag.convdata records for the three conversions above.  Each uses
   trace level 2 and ignores both extra arguments passed to its conv field
   (K (K ...)).  BOOL_EQ_IMP is keyed on a boolean equation; the other two
   are keyed on an arbitrary boolean term. *)
local
   fun mk_bool_convdata (conv_name, key_tm, conv) : simpfrag.convdata =
      {name = conv_name, trace = 2, key = SOME ([], key_tm), conv = K (K conv)}
in
   val BOOL_EQ_IMP_convdata =
       mk_bool_convdata ("BOOL_EQ_IMP_CONV", ``(a:bool) = (b:bool)``,
                         BOOL_EQ_IMP_CONV)

   val BOOL_EXTRACT_SHARED_convdata =
       mk_bool_convdata ("BOOL_EXTRACT_SHARED_CONV", ``a:bool``,
                         BOOL_EXTRACT_SHARED_CONV)

   val BOOL_NEG_PAIR_convdata =
       mk_bool_convdata ("BOOL_NEG_PAIR_CONV", ``a:bool``,
                         BOOL_NEG_PAIR_CONV)
end
390
391
392
393end
394