1(* ===================================================================== *)
2(* FILE          : BoolExtractShared                                     *)
3(* DESCRIPTION   : Tools to extract shared terms from a boolean          *)
4(*                 expression                                            *)
5(*                                                                       *)
6(* AUTHORS       : Thomas Tuerk                                          *)
7(* DATE          : July 3, 2008                                          *)
8(* ===================================================================== *)
9
10
11structure  BoolExtractShared :>  BoolExtractShared =
12struct
13
14(*
15quietdec := true;
16*)
17
18open HolKernel Parse boolLib
19
20(*
21quietdec := false;
22*)
23
24
25(*---------------------------------------------------------------------------
26 * dest_neg_eq (``~(a = b)``) = (``a``, ``b``);
27 *---------------------------------------------------------------------------*)
28  fun dest_neg_eq t = dest_eq (dest_neg t);
29  val is_neg_eq = can dest_neg_eq;
30
31
32(*---------------------------------------------------------------------------
33 * mk_neg___idempot a = ~a
34 * mk_neg___idempot ~a = a
 * where a is not itself a negation
36 *---------------------------------------------------------------------------*)
37  fun mk_neg___idempot t =
38    if is_neg t then dest_neg t else mk_neg t;
39
40
41(*---------------------------------------------------------------------------
 * for t1 and t2 being equations or disequations,
 * eq_sym_aconv tests whether they are alpha-convertible with
 * respect to the symmetry of equality
45 *---------------------------------------------------------------------------*)
46  fun eq_sym_aconv t1 t2 =
47      aconv t1 t2 orelse
48      (is_eq t1 andalso is_eq t2 andalso
49       let
50           val (t1l, t1r) = dest_eq t1;
51           val (t2l, t2r) = dest_eq t2;
52       in
53          (aconv t1r t2l) andalso (aconv t1l t2r)
54       end) orelse
55      (is_neg_eq t1 andalso is_neg_eq t2 andalso
56       let
57           val (t1l, t1r) = dest_neg_eq t1;
58           val (t2l, t2r) = dest_neg_eq t2;
59       in
60          (aconv t1r t2l) andalso (aconv t1l t2r)
61       end);
62
63
64
65
66  fun eq_sym_mem e [] = false
67    | eq_sym_mem e (h::l) =
68      (eq_sym_aconv e h) orelse eq_sym_mem e l;
69
70
71  fun findMatches ([], l2) = []
72    | findMatches (a::l1, l2) =
73         let val l1' = filter (fn e => not (eq_sym_aconv e a)) l1;
74             val l2' = filter (fn e => not (eq_sym_aconv e a)) l2;
75             val l = (findMatches (l1',l2')); in
76         if eq_sym_mem a l2 then a::l else l end;
77
78  fun find_negation_pair [] = NONE |
79      find_negation_pair (e::l) =
80      if eq_sym_mem (mk_neg___idempot e) l then SOME e else
81      find_negation_pair l;
82
83
84  fun dest_quant t = dest_abs (snd (dest_comb t));
85  fun is_quant t = is_forall t orelse is_exists t orelse
86                 is_exists1 t;
87
88
89  fun findMatches___multiple (l1:(term * bool) list,l2:(term * bool) list) =
90      map (fn t => (t, true)) (findMatches (map fst l1, map fst l2));
91
92
93  (*get_impl_terms returns a list of terms that imply the whole term and
94    a list of terms that are implied. Thus if get_impl_terms t = (disjL,conjL) then
95
96    forall x in disjL. x implies t and
97    forall x in conjL. t implies x holds.
98
99
100    so e. g. get_impl_terms ``a \/ b \/ ~c`` returns
101     ([``a \/ b \/ ~c``, ``a``, ``b \/ ~c``, ``b``, ``~c``], [``a \/ b \/ ~c``])
102
103
    get_impl_terms___multiple augments the results with information on whether
    multiple occurrences of a term have been abbreviated to one.
106   *)
107  fun get_impl_terms___multiple t =
108      if is_disj t then
109          (let val (t1,t2)=dest_disj t;
110               val (l11,l12)= get_impl_terms___multiple t1;
111               val (l21,l22)= get_impl_terms___multiple t2;
112           in
113              ((t,false)::(l11 @ l21), (t,false)::findMatches___multiple (l12, l22))
114           end)
115      else
116      if is_conj t then
117          (let val (t1,t2)=dest_conj t;
118               val (l11,l12)= get_impl_terms___multiple t1;
119               val (l21,l22)= get_impl_terms___multiple t2;
120           in
121              ((t,false)::findMatches___multiple (l11, l21), (t,false)::(l12 @ l22))
122           end)
123      else
124      if is_neg t then
125          (let val (l1,l2) = get_impl_terms___multiple (dest_neg t) in
126              (map (fn (t,b) => (mk_neg___idempot t, b)) l2, map (fn (t,b) => (mk_neg___idempot t, b)) l1)
127          end)
128      else
129      if is_imp t then
130          (let val (t1,t2)=dest_imp_only t;
131               val (l11',l12')= get_impl_terms___multiple t1;
132               val (l11, l12) = (map (fn (t,b) => (mk_neg___idempot t, b)) l12',
133                                 map (fn (t,b) => (mk_neg___idempot t, b)) l11')
134               val (l21,l22)= get_impl_terms___multiple t2;
135           in
136              ((t,false)::(l11 @ l21), (t,false)::findMatches___multiple (l12, l22))
137           end)
138      else
139      if is_quant t then
140          (let
141              val (v, b) = dest_quant t
142              val (l1,l2) = get_impl_terms___multiple b
143              fun filter_pred (t,b) = not (free_in v t)
144          in
145              ((t,false)::filter filter_pred l1,
146               (t,false)::filter filter_pred l2)
147          end)
148      else
149        (if same_const T t orelse same_const F t then ([],[]) else
150           ([(t,false)],[(t,false)]));
151
152
153  fun clean_term_multiple_list [] = []
154    | clean_term_multiple_list ((t,b)::L) =
155        if op_mem aconv t (map fst L) then
156          (t,true) ::
157          clean_term_multiple_list (filter (fn (t',b') => not (aconv t t')) L)
158        else
159          (t,b)::clean_term_multiple_list L
160
161
162  fun get_impl_terms t =
163     let
164        val (l1,l2) = get_impl_terms___multiple t;
165     in
166        (map fst l1, map fst l2)
167     end
168
169
170
171
172
(* Builds the rewrite theorems obtained from assuming the term rewr.
   - For an equation, both orientations are returned, each stated as
     being equal to T.
   - For a negated equation, the assumption and its symmetric version
     are returned.
   - For any other term, just the assumption itself is returned. *)
fun get_rewrite_assumption_thms rewr =
   let
      val assumed = ASSUME rewr
      (* the assumption and its GSYM variant, each passed through wrap *)
      fun both_orientations wrap = [wrap assumed, wrap (GSYM assumed)]
   in
      if is_eq rewr then both_orientations EQT_INTRO
      else if is_neg_eq rewr then both_orientations (fn th => th)
      else [assumed]
   end
181
(* case_split_REWRITE_CONV matches t tries to prove |- t = T by case
   splits on the terms in matches, taken one after another.  For each
   match m (with outer negations stripped), t is rewritten once assuming
   m and once assuming ~m; whenever a branch does not rewrite to T, the
   remaining matches are tried on what is left.  The two branches are
   combined with the excluded middle.

   With no matches left, plain rewriting must reduce t to T.

   The result always has the form |- t = T.  The empty-list case used to
   return |- t instead, which could never be chained by TRANS below and
   did not fit the EQT_ELIM the callers apply to the result.  The
   exceptions raised when t cannot be rewritten to T are unchanged. *)
fun case_split_REWRITE_CONV [] t =
    (* EQT_ELIM checks that rewriting reached T; EQT_INTRO restores the
       equational form |- t = T *)
    EQT_INTRO (EQT_ELIM (REWRITE_CONV [] t))
  | case_split_REWRITE_CONV (m::matches) t =
    let
       (* proves t = T under the assumption rewr *)
       fun rec_prove rewr =
       let
          val match_thms = get_rewrite_assumption_thms rewr;
          (* if the assumption does not rewrite anything, continue with t itself *)
          val thm = REWRITE_CONV match_thms t handle UNCHANGED => REFL t;
          val r = rhs (concl thm);
       in
          if r ~~ T then thm else
          TRANS thm (case_split_REWRITE_CONV matches r)
       end;

       val m_no_neg = fst (strip_neg m);
       val thm1 = rec_prove m_no_neg;
       val thm2 = rec_prove (mk_neg m_no_neg);

       (* |- m_no_neg \/ ~m_no_neg discharges the two branch assumptions *)
       val disj_thm = SPEC m_no_neg EXCLUDED_MIDDLE
    in
       DISJ_CASES disj_thm thm1 thm2
    end;
206
207
208
209
(* Rewrites t under the assumptions in matches and proves
     t = (conjunction of matches ==> rewritten t),
   or t = T when the rewritten term is T.
   Raises UNCHANGED when the rewritten term is F or when the resulting
   right-hand side is alpha-equivalent to t. *)
fun bool_eq_imp_real_imp_CONV matches t =
   let
      fun unchanged_if cond = if cond then raise UNCHANGED else ()

      val assumption_thms = flatten (map get_rewrite_assumption_thms matches)
      val simplified = rhs (concl (REWRITE_CONV assumption_thms t))
      val _ = unchanged_if (aconv simplified F)

      val target =
          if aconv simplified T then T
          else mk_imp (list_mk_conj matches, simplified)
      val _ = unchanged_if (aconv t target)
   in
      EQT_ELIM (case_split_REWRITE_CONV matches (mk_eq (t, target)))
   end;
225
226
227
228
229
230
(* Pulls the terms in matches out of t.
   With disj = true, t is rewritten assuming the negations of the matches
   and the result is   rewritten t \/ m1 \/ ... \/ mn   (or T when the
   rewritten term is T).
   With disj = false, t is rewritten assuming the matches themselves and
   the result is   rewritten t /\ m1 /\ ... /\ mn   (or F when the
   rewritten term is F).
   The equation between t and this result is proved by case splits on the
   matches.  Raises UNCHANGED when the result is alpha-equivalent to t. *)
fun bool_extract_common_terms_internal_CONV disj matches t =
   let
      val assumed = if disj then map mk_neg___idempot matches else matches
      val rewrites = flatten (map get_rewrite_assumption_thms assumed)
      val residue = rhs (concl (REWRITE_CONV rewrites t))

      (* the constant that absorbs the connective, and the connective *)
      val (absorbing, combine) =
          if disj then (T, list_mk_disj) else (F, list_mk_conj)

      val target =
          if aconv residue absorbing then absorbing
          else combine (residue::matches)

      val _ = if aconv t target then raise UNCHANGED else ()
   in
      EQT_ELIM (case_split_REWRITE_CONV matches (mk_eq (t, target)))
   end;
253
254
255
256
257
258
259(*cleans up the found matches by using just the simplest ones.
260  So clean_disj_matches removes terms from the list that are implied by
261  one other in the list and clean_conj_matches removes terms that imply
262  another term*)
263
(* E is the empty term set; s +=+ l adds all elements of the list l
   to the set s. *)
infix +=+
val E = empty_tmset
val op +=+ = HOLset.addList;
266
267
(* Walks through the list and keeps a term only if none of the terms
   implying it (as computed by get_impl_terms) occurs among the terms
   still to be processed or among those already kept.  Kept terms are
   consed onto acc. *)
fun clean_disj_matches [] acc = acc
  | clean_disj_matches (t::ts) acc =
    let
      val implying = E +=+ fst (get_impl_terms t)
      val others = E +=+ ts +=+ acc
      val keep = HOLset.isEmpty (HOLset.intersection (implying, others))
    in
       clean_disj_matches ts (if keep then t::acc else acc)
    end;
280
281
(* Walks through the list and keeps a term only if none of the terms
   implied by it (as computed by get_impl_terms) occurs among the terms
   still to be processed or among those already kept.  Kept terms are
   consed onto acc. *)
fun clean_conj_matches [] acc = acc
  | clean_conj_matches (t::ts) acc =
    let
       val implied = E +=+ snd (get_impl_terms t)
       val others = E +=+ ts +=+ acc
       val keep = HOLset.isEmpty (HOLset.intersection (implied, others))
    in
       clean_conj_matches ts (if keep then t::acc else acc)
    end;
295
296
297
298
299
300
301
302
303(*---------------------------------------------------------------------------
 * Given an equation with boolean expressions on both sides (b1 = b2),
 * this conversion tries to extract common parts of b1 and b2 into a precondition.
306 *
307 * e.g.          (A \/ B \/ C) = (A \/ D) is converted to
308 *       ~A ==> ((     B \/ C) =       D )
309 *
310 *---------------------------------------------------------------------------*)
311
(* Conversion for equations between boolean terms.  Collects the terms
   shared by both sides (terms implying both sides, negated, plus terms
   implied by both sides) and hands them to bool_eq_imp_real_imp_CONV as
   the precondition.
   Raises a HOL_ERR when the sides are not of type bool and UNCHANGED
   when no shared term is found. *)
fun BOOL_EQ_IMP_CONV t =
   let
      val (l,r) = dest_eq t
      val _ = if type_of l <> bool
              then raise mk_HOL_ERR "Conv" "bool_eq_imp_CONV" ""
              else ()
      val (l_implying, l_implied) = get_impl_terms l
      val (r_implying, r_implied) = get_impl_terms r

      val shared_disj =
          clean_disj_matches (findMatches (l_implying, r_implying)) []
      val shared_conj =
          clean_conj_matches (findMatches (l_implied, r_implied)) []
   in
      case map mk_neg___idempot shared_disj @ shared_conj of
         [] => raise UNCHANGED
       | matches => bool_eq_imp_real_imp_CONV matches t
   end;
328
329
330
331
332(*---------------------------------------------------------------------------
333 * Tries to convert a boolean expression to true or false by
334 * searching for a case split that will prove this expression.
335 *
336 * e.g. this conversion is able to prove
337 *        (A \/ B \/ ~A) = T
338 *   or   (A \/ B \/ ((D /\ A) ==> C)) = T
339 *   or   (A /\ B /\ ~A) = F
340 *
341 *---------------------------------------------------------------------------*)
342
343
(* Proves t = T when the terms implying t contain a term together with
   its negation, and t = F when the terms implied by t do.  The proof is
   a case split on the term found; the implying terms are searched first.
   Raises a HOL_ERR when t is not of type bool and UNCHANGED when neither
   list contains such a pair. *)
fun BOOL_NEG_PAIR_CONV t =
   let
      val _ = if type_of t <> bool
              then raise mk_HOL_ERR "Conv" "bool_negation_pair_CONV" ""
              else ()
      val (disj_t, conj_t) = get_impl_terms t
      val (split_term, truth_value) =
          case find_negation_pair disj_t of
             SOME e => (e, T)
           | NONE =>
               (case find_negation_pair conj_t of
                   SOME e => (e, F)
                 | NONE => raise UNCHANGED)
   in
      EQT_ELIM (case_split_REWRITE_CONV [split_term] (mk_eq (t, truth_value)))
   end;
360
361
362
363(*---------------------------------------------------------------------------
 * Tries to extract parts of a boolean term that occur several times.
365 *
366 * e.g.   (D /\ A /\ ~C) \/ (A /\ B /\ ~C) = (D \/ B) /\ ~C /\ A
367 *   or   A \/ B \/ A = B \/ A
368 *
369 *---------------------------------------------------------------------------*)
(* Extracts terms that occur several times in the boolean term t.
   Candidates are the entries of get_impl_terms___multiple t that end up
   flagged after duplicate removal.  Disjunctive candidates take
   precedence over conjunctive ones.
   Raises a HOL_ERR when t is not of type bool and UNCHANGED when there
   is nothing to extract. *)
fun BOOL_EXTRACT_SHARED_CONV t =
   let
      val _ = if type_of t <> bool
              then raise mk_HOL_ERR "Conv" "bool_imp_extract_CONV" ""
              else ()
      val (implying, implied) = get_impl_terms___multiple t
      (* terms whose flag is set once duplicates have been merged *)
      fun repeated tagged =
          map fst (filter snd (clean_term_multiple_list tagged))
      val disj_t = clean_disj_matches (repeated implying) []
      val conj_t = clean_conj_matches (repeated implied) []
   in
      case (disj_t, conj_t) of
         (_::_, _) => bool_extract_common_terms_internal_CONV true disj_t t
       | ([], _::_) => bool_extract_common_terms_internal_CONV false conj_t t
       | ([], []) => raise UNCHANGED
   end;
385
386
387
388
(* simpfrag.convdata record wrapping BOOL_EQ_IMP_CONV, keyed on equations
   between booleans.  The conversion ignores the two arguments supplied
   ahead of the term. *)
val BOOL_EQ_IMP_convdata : simpfrag.convdata =
   {conv = fn _ => fn _ => BOOL_EQ_IMP_CONV,
    key = SOME ([],``(a:bool) = (b:bool)``),
    trace = 2,
    name = "BOOL_EQ_IMP_CONV"};
393
(* simpfrag.convdata record wrapping BOOL_EXTRACT_SHARED_CONV, keyed on
   any boolean term.  The conversion ignores the two arguments supplied
   ahead of the term. *)
val BOOL_EXTRACT_SHARED_convdata : simpfrag.convdata =
   {conv = fn _ => fn _ => BOOL_EXTRACT_SHARED_CONV,
    key = SOME ([],``a:bool``),
    trace = 2,
    name = "BOOL_EXTRACT_SHARED_CONV"};
398
(* simpfrag.convdata record wrapping BOOL_NEG_PAIR_CONV, keyed on any
   boolean term.  The conversion ignores the two arguments supplied
   ahead of the term. *)
val BOOL_NEG_PAIR_convdata : simpfrag.convdata =
   {conv = fn _ => fn _ => BOOL_NEG_PAIR_CONV,
    key = SOME ([],``a:bool``),
    trace = 2,
    name = "BOOL_NEG_PAIR_CONV"};
403
404
405
406end
407