1structure boolTools :> boolTools =
2struct
3
4(*
5quietdec := true;
6*)
7
8open HolKernel Parse boolLib bossLib;
9
10(*
11quietdec := false;
12*)
13
(* dest_neg_eq ``~(l = r)`` gives (l, r); any other shape fails.
   is_neg_eq is the corresponding boolean test. *)
fun dest_neg_eq tm =
    let val eqn = dest_neg tm in dest_eq eqn end;
fun is_neg_eq tm = can dest_neg_eq tm;
16
(* Negate a boolean term, cancelling an existing outer negation instead of
   stacking a second one: ~p becomes p, anything else q becomes ~q. *)
fun logical_mk_neg tm =
    case total dest_neg tm of
        SOME body => body
      | NONE => mk_neg tm;
19
20
(* rewrite_eq t1 t2 holds when the two terms are alpha-equivalent, or when
   both are equations (or both negated equations) whose sides are swapped,
   e.g. ``a = b`` versus ``b = a``. *)
fun rewrite_eq t1 t2 =
    let
        (* true iff both terms destruct with dest and their sides are swapped *)
        fun swapped dest =
            case (total dest t1, total dest t2) of
                (SOME (a1, b1), SOME (a2, b2)) => aconv b1 a2 andalso aconv a1 b2
              | _ => false
    in
        aconv t1 t2 orelse swapped dest_eq orelse swapped dest_neg_eq
    end;
37
38
39  fun logical_mem e [] = false
40    | logical_mem e (h::l) =
41      (rewrite_eq e h) orelse logical_mem e l;
42
43
44  fun findMatches ([], l2) = []
45    | findMatches (a::l1, l2) =
46         let val l1' = filter (fn e => not (e = a)) l1;
47             val l2' = filter (fn e => not (e = a)) l2;
48             val l = (findMatches (l1',l2')); in
49         if logical_mem a l2 then a::l else l end;
50
51  fun find_negation_pair [] = NONE |
52      find_negation_pair (e::l) =
53      if logical_mem (logical_mk_neg e) l then SOME e else
54      find_negation_pair l;
55
56
57  fun dest_quant t = dest_abs (snd (dest_comb t));
58  fun is_quant t = is_forall t orelse is_exists t orelse
59                 is_exists1 t;
60
61
  (* get_impl_terms t returns a pair (l1, l2) of term lists:
       l1 -- terms x with x ==> t  (terms that imply t)
       l2 -- terms x with t ==> x  (terms implied by t)
   *)
68
69
70
71  fun get_impl_terms t =
72      if is_disj t then
73          (let val (t1,t2)=dest_disj t;
74               val (l11,l12)= get_impl_terms t1;
75               val (l21,l22)= get_impl_terms t2;
76           in
77              (t::(l11 @ l21), t::findMatches (l12, l22))
78           end)
79      else
80      if is_conj t then
81          (let val (t1,t2)=dest_conj t;
82               val (l11,l12)= get_impl_terms t1;
83               val (l21,l22)= get_impl_terms t2;
84           in
85              (t::findMatches (l11, l21), t::(l12 @ l22))
86           end)
87      else
88      if is_neg t then
89          (let val (l1,l2) = get_impl_terms (dest_neg t) in
90              (map logical_mk_neg l2, map logical_mk_neg l1)
91          end)
92      else
93      if is_imp t then
94          (let val (t1,t2)=dest_imp t;
95               val neg_t1 = logical_mk_neg t1;
96               val new_t = mk_disj (neg_t1, t2)
97           in get_impl_terms new_t end)
98      else
99      if is_quant t then
100          (let
101              val (v, b) = dest_quant t;
102              val (l1,l2) = get_impl_terms b;
103              fun filter_pred t = not (mem v (free_vars t));
104          in
105              (t::(filter filter_pred l1), t::(filter filter_pred l2))
106          end)
107      else
108      ([t],[t]);
109
110
111
112
113
(* Closing tactic used after each case split: rewrite with the assumptions,
   simplify with std_ss, then finish with METIS. *)
val bool_eq_imp_solve_TAC =
    EVERY [ASM_REWRITE_TAC [],
           ASM_SIMP_TAC std_ss [],
           METIS_TAC []];
117
118(*
119fun neg_eq_ASSUME_TAC tac =
120   tac THENL [
121      POP_ASSUM (fn thm => ASSUME_TAC thm THEN ASSUME_TAC (GSYM thm)),
122      ALL_TAC
123   ];
124*)
125
126
(* Case split on h with all outer negations stripped.  When h carries an odd
   number of negations the two resulting subgoals are swapped, so that the
   first subgoal is always the one where h itself is assumed. *)
fun bool_eq_imp_case_TAC h =
    case strip_neg h of
        (core, negs) =>
          if negs mod 2 = 0 then ASM_CASES_TAC core
          else Tactical.REVERSE (ASM_CASES_TAC core);
136
137
138
(* Prove t outright by case-splitting on c and return |- t = T. *)
fun bool_eq_imp_solve_CONV c t =
   EQT_INTRO (prove (t, bool_eq_imp_case_TAC c THEN bool_eq_imp_solve_TAC));
146
147
148
149
(* Nested case splits on each term of the list in turn: the "h holds"
   subgoal continues with the remaining terms, the other subgoal is closed
   by bool_eq_imp_solve_TAC.  An empty list just runs the solver. *)
fun bool_eq_imp_real_imp_TAC splits =
    foldr (fn (h, inner) =>
              bool_eq_imp_case_TAC h THENL [inner, bool_eq_imp_solve_TAC])
          bool_eq_imp_solve_TAC
          splits;
156
157
158
159
160
161
162
163
(* Rewrite t under the assumptions `matches` (negated equations are also used
   in flipped form) giving t'.  The result is
     |- t = T                          if t' is T,
     |- t = (conj of matches ==> t')   otherwise.
   Raises UNCHANGED when t' is F or the new right-hand side equals t. *)
fun bool_eq_imp_real_imp_CONV matches t =
   let
      val assumed = map ASSUME matches
      val flipped = map GSYM (filter (is_neg_eq o concl) assumed)
      val rewritten = rhs (concl (REWRITE_CONV (assumed @ flipped) t))
   in
      if rewritten = F then raise UNCHANGED
      else
         let
            val goal = if rewritten = T then T
                       else mk_imp (list_mk_conj matches, rewritten)
         in
            if t = goal then raise UNCHANGED
            else prove (mk_eq (t, goal), bool_eq_imp_real_imp_TAC matches)
         end
   end;
178
179
180
local
  (* Shared worker: walk the candidates, keeping t (pushed onto acc, so the
     result is in reverse order) only when none of the terms selected from
     get_impl_terms t by `select` occurs among the not-yet-examined
     candidates or the ones already kept. *)
  fun clean_matches select [] acc = acc
    | clean_matches select (t :: ts) acc =
        let
          val related = select (get_impl_terms t);
          val acc' = if null_intersection related (ts @ acc) then t :: acc
                     else acc;
        in
          clean_matches select ts acc'
        end
in
  (* Uses the terms that imply each candidate. *)
  fun clean_disj_matches ts acc = clean_matches fst ts acc
  (* Uses the terms implied by each candidate. *)
  fun clean_conj_matches ts acc = clean_matches snd ts acc
end;
205
206
207
208
209
210
211
(* Simplify a boolean equation l = r using sub-terms shared by both sides.
   Common "implying" sub-terms (negated) and common "implied" sub-terms form
   the case-split list.  If that list contains a complementary pair the
   equation is proved outright; otherwise the equation is rewritten under
   the list as assumptions.  Raises UNCHANGED when nothing is shared. *)
fun bool_eq_imp_CONV t =
   let
      val (lhs_t, rhs_t) = dest_eq t
      val () = if type_of lhs_t = bool then ()
               else raise mk_HOL_ERR "Conv" "bool_eq_imp_CONV" ""
      val (disj_l, conj_l) = get_impl_terms lhs_t
      val (disj_r, conj_r) = get_impl_terms rhs_t
      val shared_disj = clean_disj_matches (findMatches (disj_l, disj_r)) []
      val shared_conj = clean_conj_matches (findMatches (conj_l, conj_r)) []
      val matches = map logical_mk_neg shared_disj @ shared_conj
   in
      if null matches then raise UNCHANGED
      else
         case find_negation_pair matches of
             SOME split => bool_eq_imp_solve_CONV split t
           | NONE => bool_eq_imp_real_imp_CONV matches t
   end;
229
230
231
(* Decide a boolean term t from complementary pairs among its sub-terms.
   A pair x, ~x among the terms implying t proves |- t = T; otherwise a pair
   among the terms implied by t proves |- t = F.  Raises UNCHANGED when
   neither exists, and HOL_ERR if t is not boolean. *)
fun bool_neg_pair_CONV t =
   let
      (* origin now names this function (previously "bool_negation_pair_CONV") *)
      val _ = if (type_of t = bool) then () else raise mk_HOL_ERR "Conv" "bool_neg_pair_CONV" "";
      val (disj_t, conj_t) = get_impl_terms t;
      val (split, result) =
          case find_negation_pair disj_t of
              SOME x => (x, T)
            | NONE =>
                (case find_negation_pair conj_t of
                     SOME x => (x, F)
                   | NONE => raise UNCHANGED);
   in
      prove (mk_eq (t, result),
             bool_eq_imp_case_TAC split THEN bool_eq_imp_solve_TAC)
   end;
249
250
251
(* Rewrite a boolean term t under the negations of the sub-terms that imply
   it, producing |- t = (~x1 /\ ... /\ ~xn ==> t') via
   bool_eq_imp_real_imp_CONV.  Raises UNCHANGED when there are no such
   sub-terms, and HOL_ERR if t is not boolean. *)
fun bool_imp_extract_CONV t =
   let
      val _ = if (type_of t = bool) then () else raise mk_HOL_ERR "Conv" "bool_imp_extract_CONV" "";
      (* the head of get_impl_terms' first list is t itself (or an equivalent
         rewriting of it), so drop it *)
      val disj_t = tl (fst (get_impl_terms t));
      (* NOTE(review): a clean_disj_matches pass over disj_t used to be computed
         here but its result was never used; it was removed as dead work.
         Confirm whether the cleaned list was meant to feed `matches`. *)
      val matches = map logical_mk_neg disj_t;
      val _ = if null matches then raise UNCHANGED else ();
   in
      bool_eq_imp_real_imp_CONV matches t
   end;
264
265
266
267
(* Simpset fragments wrapping the conversions above.  Each fragment is named
   after the conversion it runs (presumably the name shown in simplifier
   traces); bool_imp_extract_ss previously used its own ss name instead. *)
val bool_eq_imp_ss = simpLib.conv_ss {name = "bool_eq_imp_CONV",
            trace = 2,
            key = SOME ([],``(a:bool) = (b:bool)``),
            conv = K (K bool_eq_imp_CONV)};

val bool_imp_extract_ss = simpLib.conv_ss {name = "bool_imp_extract_CONV",
            trace = 2,
            key = SOME ([],``a:bool``),
            conv = K (K bool_imp_extract_CONV)};

val bool_neg_pair_ss = simpLib.conv_ss {name = "bool_neg_pair_CONV",
            trace = 2,
            key = SOME ([],``a:bool``),
            conv = K (K bool_neg_pair_CONV)};
282
283
284
285
286
287
(* Monotonicity of /\ under implication: strengthening both conjuncts
   strengthens the conjunction.  Used by DEPTH_STRENGTHEN_CONV. *)
val imp_thm_conj = prove (``!b1 c1 b2 c2. (b1 ==> c1) ==>
                                          (b2 ==> c2) ==>
                                          (b1 /\ b2) ==>
                                          (c1 /\ c2)``, SIMP_TAC std_ss []);
(* Monotonicity of \/ under implication. *)
val imp_thm_disj = prove (``!b1 c1 b2 c2. (b1 ==> c1) ==>
                                          (b2 ==> c2) ==>
                                          (b1 \/ b2) ==>
                                          (c1 \/ c2)``, SIMP_TAC std_ss [DISJ_IMP_THM]);

(* Monotonicity of ! under implication.  b1 and b2 are left free (not
   quantified) so that HO_MATCH_MP can instantiate them. *)
val imp_thm_forall = prove (``(!x. (b1 x ==> b2 x)) ==> ((!x. b1 x) ==> (!x. b2 x))``,
                             SIMP_TAC std_ss []);
299
300
(* From |- b v ==> c v derive |- (!v. b v) ==> (!v. c v). *)
fun GEN_IMP v thm = HO_MATCH_MP imp_thm_forall (GEN v thm);
308
309
310
(* The trivial strengthening |- t ==> t. *)
fun REFL_IMP_CONV tm =
    let val th = ASSUME tm in DISCH tm th end;
312
(* Generalise v in thm even when some hypotheses mention v: those hypotheses
   are discharged, v is generalised, and each one is re-assumed in universally
   quantified form (!v. h) by peeling it back off with MONO_ALL. *)
fun GEN_ASSUM v thm =
  let
    val v_hyps = filter (fn h => mem v (free_vars h)) (hyp thm)
    val discharged = foldl (fn (h, th) => DISCH h th) thm v_hyps
    fun peel (_, th) = UNDISCH (HO_MATCH_MP MONO_ALL th)
  in
    foldl peel (GEN v discharged) v_hyps
  end
323
324
(* Normalise the result of a strengthening conversion on t into the form
   |- t' ==> t.  Accepted results of conv t:
     |- t' ==> t   (with t' different from t)  -- returned as is;
     |- t = t'     (with t' different from t)  -- turned into t' ==> t;
     |- t          (with t different from T)   -- turned into T ==> t.
   Anything else raises UNCHANGED. *)
fun STRENGTHEN_CONV_WRAPPER conv t =
let
   val thm = conv t;
   val thm_term = concl thm;
in
   (* HOL4's is_imp/dest_imp also read ~p as p ==> F; the _only variants make
      a conv that proves a negated t outright reach the last branch instead
      of being rejected here *)
   if (is_imp_only thm_term) then
      let
         val (t1, t2) = dest_imp_only thm_term;
         val _ = if not (t2 = t) then raise UNCHANGED else ();
         val _ = if (t1 = t2) then raise UNCHANGED else ();
      in
         thm
      end
   else if (is_eq thm_term) then
      if ((lhs thm_term = t) andalso not (rhs thm_term = t)) then
         snd (EQ_IMP_RULE thm)
      else raise UNCHANGED
   else if (thm_term = t andalso not (t = T)) then
      snd (EQ_IMP_RULE (EQT_INTRO thm))
   else
      raise UNCHANGED
end;
347
348
(* DEPTH_STRENGTHEN_CONV conv t returns a theorem |- t' ==> t.
   It descends through conjunctions, disjunctions and universal quantifiers,
   combining the sub-results with imp_thm_conj, imp_thm_disj and
   imp_thm_forall.  At any other term it tries conv (through
   STRENGTHEN_CONV_WRAPPER); if that gives |- a ==> t it recurses on the new
   antecedent a and chains the two with IMP_TRANS.  When conv fails or
   reports UNCHANGED there, the trivial |- t ==> t is returned.
   NOTE(review): the recursion on the antecedent has no depth bound; a conv
   that keeps producing new antecedents would presumably loop — confirm. *)
fun DEPTH_STRENGTHEN_CONV conv t =
  if (is_conj t) then
     let
         val (b1,b2) = dest_conj t;
         val thm1 = DEPTH_STRENGTHEN_CONV conv b1;
         val thm2 = DEPTH_STRENGTHEN_CONV conv b2;

         (* b1/b2 are rebound here to the strengthened antecedents *)
         val (b1,c1) = dest_imp (concl thm1);
         val (b2,c2) = dest_imp (concl thm2);
         val thm3 = ISPECL [b1,c1,b2,c2] imp_thm_conj;
         val thm4 = MP thm3 thm1;
         val thm5 = MP thm4 thm2;
     in
        thm5
     end handle HOL_ERR _ => (raise UNCHANGED)
     (* NOTE(review): only the conjunction branch maps HOL_ERR to UNCHANGED;
        the disjunction and forall branches let HOL_ERR escape. *)
   else if (is_disj t) then
     let
         val (b1,b2) = dest_disj t;
         val thm1 = DEPTH_STRENGTHEN_CONV conv b1;
         val thm2 = DEPTH_STRENGTHEN_CONV conv b2;

         val (b1,c1) = dest_imp (concl thm1);
         val (b2,c2) = dest_imp (concl thm2);
         val thm3 = ISPECL [b1,c1,b2,c2] imp_thm_disj;
         val thm4 = MP thm3 thm1;
         val thm5 = MP thm4 thm2;
     in
        thm5
     end
   else if (is_forall t) then
     let
        val (var, body) = dest_forall t;
        val thm_body = DEPTH_STRENGTHEN_CONV conv body;
        (* hypotheses mentioning var must be quantified before GEN *)
        val thm = GEN_ASSUM var thm_body;
        val thm2 = HO_MATCH_MP imp_thm_forall thm;
     in
        thm2
     end
   else
     (* leaf: apply conv, then keep strengthening the new antecedent *)
     ((let
         val thm = (STRENGTHEN_CONV_WRAPPER conv) t;
         val (ante,_) = dest_imp (concl thm);
         val thm2 = DEPTH_STRENGTHEN_CONV conv ante;
         val thm3 = IMP_TRANS thm2 thm;
     in
         thm3
     end handle HOL_ERR _ => REFL_IMP_CONV t)
         handle UNCHANGED => REFL_IMP_CONV t);
397
(* Run a strengthening conv and reject the trivial result |- a ==> a by
   raising UNCHANGED. *)
fun UNCHANGED_STRENGTHEN_CONV conv t =
    let
       val result = conv t
       val (a, c) = dest_imp (concl result)
    in
       if a = c then raise UNCHANGED else result
    end;
406
407
(* Try each conversion in order, moving on only when one raises UNCHANGED;
   raises UNCHANGED if the list is exhausted.  Other exceptions propagate. *)
fun ORELSE_STRENGTHEN_CONV convs t =
    case convs of
        [] => raise UNCHANGED
      | c :: rest => (c t handle UNCHANGED => ORELSE_STRENGTHEN_CONV rest t);
412
413
414
415
416
(* Apply a strengthening conv to t and fold the extra hypotheses of the result
   (those not in preserve_hyps) into its antecedent.
   1. thm = conv t (presumably of the form  hyps |- a ==> t — TODO confirm).
   2. Each new hypothesis h is itself strengthened recursively; when that
      yields a non-trivial |- h' ==> h, PROVE_HYP replaces h by h'.
   3. The remaining new hypotheses are discharged and merged into the
      antecedent with AND_IMP_INTRO (a ==> b ==> c  becomes  a /\ b ==> c).
   4. The antecedent is finally cleaned up with REWRITE_CONV [] (applied via
      RATOR_CONV, i.e. to the "$==> antecedent" part). *)
fun CONJ_ASSUMPTIONS_STRENGTHEN_CONV conv preserve_hyps t =
let
    val thm = conv t;
    val new_hyps = filter (fn t => not (mem t preserve_hyps)) (hyp thm);
    (* failures / UNCHANGED while strengthening a hypothesis are ignored *)
    val hyp_thms = map (fn t =>
                       ((SOME (CONJ_ASSUMPTIONS_STRENGTHEN_CONV conv preserve_hyps t))
                        handle HOL_ERR _ => NONE)
                        handle UNCHANGED => NONE) new_hyps;

    (* keep only successful, non-trivial (l <> r) strengthenings l ==> r *)
    val hyp_thms2 = filter (fn thm_opt => (isSome thm_opt andalso
                                           let val (l,r) = dest_imp (concl (valOf thm_opt)) in (not (l = r)) end handle HOL_ERR _ => false)) hyp_thms;
    val hyp_thms3 = map (UNDISCH o valOf) hyp_thms2;

    (* replace each strengthened hypothesis by its new antecedent *)
    val thm2 = foldr (fn (thm1,thm2) => PROVE_HYP thm1 thm2) thm hyp_thms3;


    val new_hyps2 = filter (fn t => not (mem t preserve_hyps)) (hyp thm2);
    (* discharge the remaining new hypotheses, conjoining them onto the antecedent *)
    val thm3 = foldr (fn (t,thm) => SUBST_MATCH (SPEC_ALL AND_IMP_INTRO) (DISCH t thm)) thm2 (new_hyps2);
    val thm4 = CONV_RULE (RATOR_CONV (REWRITE_CONV [])) thm3
in
    thm4
end;
439
440
(* Depth strengthening with all new hypotheses folded into the antecedent
   (nothing is preserved as a hypothesis). *)
fun CONJ_ASSUMPTIONS_DEPTH_STRENGTHEN_CONV conv t =
    CONJ_ASSUMPTIONS_STRENGTHEN_CONV (DEPTH_STRENGTHEN_CONV conv) [] t
443
444
(* Given |- a ==> c and a strengthening conv with conv a = |- a' ==> a,
   return |- a' ==> c. *)
fun IMP_STRENGTHEN_CONV_RULE conv thm =
  IMP_TRANS (conv (fst (dest_imp (concl thm)))) thm
451
452
(* Turn a strengthening conv into a tactic: with conv t = |- t' ==> t, the
   goal t is reduced (via HO_MATCH_MP_TAC) to t'. *)
fun STRENGTHEN_CONV_TAC conv (g as (_, t)) =
    HO_MATCH_MP_TAC (conv t) g;
455
456
(* Tactic form of DEPTH_STRENGTHEN_CONV. *)
fun DEPTH_STRENGTHEN_CONV_TAC conv g =
    STRENGTHEN_CONV_TAC (DEPTH_STRENGTHEN_CONV conv) g
459
460
461
462
463
464
465
466
467
(* Build a (possibly conditional) rewriting conversion from thm, together with
   the pattern it matches (used as the Net key by COND_REWRITE_CONV):
     |- c ==> (l = r)  ->  rewrites l to r, adding c as a hypothesis
     |- c ==> p        ->  rewrites p to T, adding c as a hypothesis
     |- l = r          ->  rewrites l to r
     |- p              ->  rewrites p to T
   The _only destructors are used because HOL4's is_imp/dest_imp also read
   ~p as p ==> F.  That used to send |- ~p down the conditional branch with
   pattern F, yielding a conversion that rewrote F to T under hypothesis p. *)
fun COND_REWR_CONV___with_match thm =
  let
     val c = concl thm
  in
     if (is_imp_only c) then
        if (is_eq (snd (dest_imp_only c))) then
           (UNDISCH o (PART_MATCH (lhs o snd o dest_imp_only) thm),
            (lhs o snd o dest_imp_only) c)
        else
           (EQT_INTRO o UNDISCH o (PART_MATCH (snd o dest_imp_only) thm),
            (snd o dest_imp_only) c)
     else if (is_eq c) then
        (PART_MATCH lhs thm, lhs c)
     else
        (EQT_INTRO o PART_MATCH I thm, c)
  end
483
484
(* The conversion part of COND_REWR_CONV___with_match (pattern discarded). *)
fun COND_REWR_CONV thm =
    #1 (COND_REWR_CONV___with_match thm);
487
488
489
490
(* Repeatedly rewrite at the top of the term with the (conditional) rewrites
   in thmL.  Conjunctions are split and outer quantifiers stripped via
   BODY_CONJUNCTS.  Candidate rewrites are looked up in a term net keyed on
   their patterns, and the first applicable one fires. *)
fun COND_REWRITE_CONV thmL =
   let
     val rewrites =
         map COND_REWR_CONV___with_match (flatten (map BODY_CONJUNCTS thmL))
     fun add ((conv, pat), net) = Net.insert (pat, conv) net
     val net = foldr add Net.empty rewrites
     fun step t = FIRST_CONV (Net.match t net) t
   in
     REPEATC step
   end
504
505
(* COND_REWRITE_CONV thmL restricted to terms satisfying p; on other terms it
   fails like NO_CONV. *)
fun GUARDED_COND_REWRITE_CONV p thmL =
   let
      val rewrite = COND_REWRITE_CONV thmL
      fun guarded t = if not (p t) then NO_CONV t else rewrite t
   in
      guarded
   end
512
513
514(*
515fun COND_REWRITE_RULE r thm =
516   let
517      val rs = flatten (map (fn thm => CONJUNCTS thm) r);
518      val rs = map UNDISCH_ALL rs;
519      val thm' = repeat (fn thm => tryfind (fn thm2 => SUBST_MATCH thm2 thm) rs) thm
520   in
521      thm'
522   end;
523
524
525*)
526
527end
528