structure boolTools :> boolTools =
struct

(*
quietdec := true;
*)

open HolKernel Parse boolLib bossLib;

(*
quietdec := false;
*)


(* ------------------------------------------------------------------------
   Basic term helpers
   ------------------------------------------------------------------------ *)

(* Destruct a negated equation ~(l = r) into (l, r).  Fails (through
   dest_neg / dest_eq) on any other term. *)
fun dest_neg_eq t = dest_eq (dest_neg t);

(* Holds iff the term has the shape ~(l = r). *)
val is_neg_eq = can dest_neg_eq;

(* Negate without piling up negations: ~t becomes t, anything else t
   becomes ~t. *)
fun logical_mk_neg t =
  if is_neg t then dest_neg t else mk_neg t;


(* rewrite_eq t1 t2 holds if t1 and t2 are alpha-equivalent, or if both are
   equations (or both negated equations) whose two sides are swapped,
   e.g. a = b versus b = a, or ~(a = b) versus ~(b = a). *)
fun rewrite_eq t1 t2 =
  aconv t1 t2 orelse
  (is_eq t1 andalso is_eq t2 andalso
   let
     val (t1l, t1r) = dest_eq t1;
     val (t2l, t2r) = dest_eq t2;
   in
     (aconv t1r t2l) andalso (aconv t1l t2r)
   end) orelse
  (is_neg_eq t1 andalso is_neg_eq t2 andalso
   let
     val (t1l, t1r) = dest_neg_eq t1;
     val (t2l, t2r) = dest_neg_eq t2;
   in
     (aconv t1r t2l) andalso (aconv t1l t2r)
   end);


(* List membership modulo rewrite_eq. *)
fun logical_mem e [] = false
  | logical_mem e (h::l) =
      (rewrite_eq e h) orelse logical_mem e l;


(* findMatches (l1, l2) returns the elements of l1 that also occur in l2
   modulo rewrite_eq.  After looking at the head a, every element
   structurally equal to a is removed from both lists before recursing, so
   such repeats of a are reported at most once. *)
fun findMatches ([], l2) = []
  | findMatches (a::l1, l2) =
      let
        val l1' = filter (fn e => not (e = a)) l1;
        val l2' = filter (fn e => not (e = a)) l2;
        val l = (findMatches (l1', l2'));
      in
        if logical_mem a l2 then a::l else l
      end;

(* Return SOME e for the first list element e whose negation (as built by
   logical_mk_neg) occurs later in the list modulo rewrite_eq; NONE if the
   list contains no such complementary pair. *)
fun find_negation_pair [] = NONE
  | find_negation_pair (e::l) =
      if logical_mem (logical_mk_neg e) l then SOME e else
      find_negation_pair l;


(* Split a binder application such as !x. b, ?x. b or ?!x. b into (x, b) by
   taking apart the abstraction that is the operand of the binder. *)
fun dest_quant t = dest_abs (snd (dest_comb t));

(* Recognise the three binders handled by dest_quant. *)
fun is_quant t = is_forall t orelse is_exists t orelse
                 is_exists1 t;


(* ------------------------------------------------------------------------
   get_impl_terms t returns a pair of term lists (l1, l2):
     - l1: terms x that are meant to imply t       (x ==> t)
     - l2: terms x that are meant to be implied by t (t ==> x)

   Original note:   (x ==> X, x <== X)

   The lists are built structurally:
     - a disjunction is implied by each disjunct, and implies whatever both
       disjuncts imply;
     - a conjunction is implied by whatever implies both conjuncts, and
       implies each conjunct;
     - a negation swaps the two lists and negates their elements
       (contraposition);
     - an implication a ==> b is processed as ~a \/ b, so its lists are
       headed by that disjunction rather than by the implication itself;
     - under a binder, only terms in which the bound variable does not occur
       free are kept.
   Apart from the implication case, each list is headed by t itself.

   NOTE(review): for ?! the l1 direction does not hold in general, because
   unique existence is not implied by a term that makes the body true for
   every x.  Callers only use these lists to choose case splits that are
   then proved, so this should only make such a proof attempt fail.
   Confirm if this is relied on.
   ------------------------------------------------------------------------ *)
fun get_impl_terms t =
  if is_disj t then
    (let val (t1,t2) = dest_disj t;
         val (l11,l12) = get_impl_terms t1;
         val (l21,l22) = get_impl_terms t2;
     in
       (t::(l11 @ l21), t::findMatches (l12, l22))
     end)
  else
  if is_conj t then
    (let val (t1,t2) = dest_conj t;
         val (l11,l12) = get_impl_terms t1;
         val (l21,l22) = get_impl_terms t2;
     in
       (t::findMatches (l11, l21), t::(l12 @ l22))
     end)
  else
  if is_neg t then
    (let val (l1,l2) = get_impl_terms (dest_neg t) in
       (map logical_mk_neg l2, map logical_mk_neg l1)
     end)
  else
  if is_imp t then
    (* checked after is_neg, so negations never reach this branch *)
    (let val (t1,t2) = dest_imp t;
         val neg_t1 = logical_mk_neg t1;
         val new_t = mk_disj (neg_t1, t2)
     in get_impl_terms new_t end)
  else
  if is_quant t then
    (let
       val (v, b) = dest_quant t;
       val (l1,l2) = get_impl_terms b;
       fun filter_pred t = not (mem v (free_vars t));
     in
       (t::(filter filter_pred l1), t::(filter filter_pred l2))
     end)
  else
    ([t],[t]);


(* ------------------------------------------------------------------------
   Tactics and conversions built on get_impl_terms
   ------------------------------------------------------------------------ *)

(* Finishing tactic: rewrite with the assumptions, simplify with std_ss
   using the assumptions, then call METIS. *)
val bool_eq_imp_solve_TAC = ASM_REWRITE_TAC[] THEN
                            ASM_SIMP_TAC std_ss [] THEN
                            METIS_TAC[];

(*
fun neg_eq_ASSUME_TAC tac =
  tac THENL [
    POP_ASSUM (fn thm => ASSUME_TAC thm THEN ASSUME_TAC (GSYM thm)),
    ALL_TAC
  ];
*)


(* Case split on h.  strip_neg removes the n leading negations of h, giving
   h', and the split is done with ASM_CASES_TAC h'.  When n is odd the two
   subgoals are reversed, so in every case the first subgoal assumes h up to
   double negation and the second assumes its negation. *)
fun bool_eq_imp_case_TAC h =
  let
    val (h', n) = strip_neg h;
    val org_cases_tac = ASM_CASES_TAC h';
    val cases_tac = if (n mod 2 = 0) then org_cases_tac else
                    Tactical.REVERSE org_cases_tac;
  in
    cases_tac
  end;


(* Prove t by a case split on c followed by bool_eq_imp_solve_TAC, and return
   |- t = T.  Raises a HOL error if the proof fails. *)
fun bool_eq_imp_solve_CONV c t =
  let
    val thm = prove (t, bool_eq_imp_case_TAC c THEN
                        bool_eq_imp_solve_TAC);
  in
    EQT_INTRO thm
  end;


(* Case split on each term of the list in turn.  The first subgoal of each
   split (the one assuming the term) continues with the rest of the list;
   the second subgoal, and the subgoal left once the list is exhausted, are
   closed with bool_eq_imp_solve_TAC. *)
fun bool_eq_imp_real_imp_TAC [] = bool_eq_imp_solve_TAC
  | bool_eq_imp_real_imp_TAC (h::l) =
      bool_eq_imp_case_TAC h THENL [
        bool_eq_imp_real_imp_TAC l,
        bool_eq_imp_solve_TAC
      ];


(* Assume every term in matches and rewrite t with these assumptions.
   Negated equations are also added with their sides swapped (GSYM).  If c
   is the rewritten term, the conversion proves
       |- t = (m1 /\ ... /\ mk ==> c)
   or |- t = T when c is T.
   Raises UNCHANGED when c is F, or when the goal would be t itself.
   Raises a HOL error if the final proof does not go through. *)
fun bool_eq_imp_real_imp_CONV matches t =
  let
    val matches_thms1 = map ASSUME matches
    val matches_thms2 = map GSYM (filter (fn thm => is_neg_eq (concl thm)) matches_thms1);
    val conc_term = rhs (concl (REWRITE_CONV (matches_thms1 @ matches_thms2) t));
    val _ = if (conc_term = F) then raise UNCHANGED else ();

    val goal_term = if (conc_term = T) then T else mk_imp (list_mk_conj matches, conc_term);
    val _ = if (t = goal_term) then raise UNCHANGED else ();
    (* set_goal ([], mk_eq(t, goal_term)) *)
    val thm = prove (mk_eq(t, goal_term),
                     bool_eq_imp_real_imp_TAC matches);
  in
    thm
  end;


(* Walk a list of candidate implicants and keep a term only if none of the
   terms implying it (first component of get_impl_terms) occurs among the
   terms still to be processed or the terms already kept.  Kept terms are
   accumulated in reverse order.  Comparison is structural, through
   null_intersection. *)
fun clean_disj_matches [] acc = acc
  | clean_disj_matches (t::ts) acc =
      let
        val (disj_imp,_) = get_impl_terms t;
        val acc' = if (null_intersection disj_imp (ts@acc)) then
                     t::acc
                   else
                     acc;
      in
        clean_disj_matches ts acc'
      end;


(* Dual of clean_disj_matches: a term is kept only if none of the terms it
   implies (second component of get_impl_terms) occurs among the remaining
   or already-kept terms. *)
fun clean_conj_matches [] acc = acc
  | clean_conj_matches (t::ts) acc =
      let
        val (_, conj_imp) = get_impl_terms t;
        val acc' = if (null_intersection conj_imp (ts@acc)) then
                     t::acc
                   else
                     acc;
      in
        clean_conj_matches ts acc'
      end;


(* Conversion on boolean equations l = r.
   - A term x implying both l and r makes the equation hold under x.
   - A term y implied by both sides makes the equation hold under ~y.
   Both kinds of term are collected, as ~x and y respectively.
   - If the collected terms contain a complementary pair, the equation is
     proved outright by a case split and rewritten to T.
   - Otherwise the conversion calls bool_eq_imp_real_imp_CONV on them, which
     proves t equal to an implication whose antecedent is the collected
     terms.
   Raises a HOL error if l is not boolean, and UNCHANGED if nothing common
   is found. *)
fun bool_eq_imp_CONV t =
  let
    val (l,r) = dest_eq t;
    val _ = if (type_of l = bool) then () else raise mk_HOL_ERR "Conv" "bool_eq_imp_CONV" "";
    val (disj_l, conj_l) = get_impl_terms l;
    val (disj_r, conj_r) = get_impl_terms r;

    val disj_matches = clean_disj_matches (findMatches (disj_l, disj_r)) [];
    val conj_matches = clean_conj_matches (findMatches (conj_l, conj_r)) [];

    val matches = (map logical_mk_neg disj_matches) @ conj_matches;
    val _ = if matches = [] then raise UNCHANGED else ();
    val solving_case_split = find_negation_pair matches;
  in
    if isSome solving_case_split then bool_eq_imp_solve_CONV (valOf solving_case_split) t else
    bool_eq_imp_real_imp_CONV matches t
  end;


(* Conversion that closes boolean terms containing a complementary pair.
   - If the terms implying t contain some e and its negation, t is proved
     equal to T.
   - Otherwise, if the terms implied by t contain such a pair, t is proved
     equal to F.
   Raises a HOL error on non-boolean input, and UNCHANGED if no pair is
   found. *)
fun bool_neg_pair_CONV t =
  let
    (* origin name fixed: it previously read bool_negation_pair_CONV *)
    val _ = if (type_of t = bool) then () else raise mk_HOL_ERR "Conv" "bool_neg_pair_CONV" "";
    val (disj_t, conj_t) = get_impl_terms t;
    val solving_case_split = find_negation_pair disj_t;
    val disj = isSome solving_case_split;
    val solving_case_split = if disj then solving_case_split else
                             find_negation_pair conj_t;

    val _ = if (isSome solving_case_split) then () else raise UNCHANGED;

    val thm_term = mk_eq (t, if disj then T else F);
    val thm = prove (thm_term, bool_eq_imp_case_TAC (valOf solving_case_split) THEN
                               bool_eq_imp_solve_TAC);
  in
    thm
  end;


(* Pull the proper implicants of t out as explicit antecedents.
   - Each kept implicant x of t (other than t itself) contributes ~x as an
     antecedent.
   - Implicants subsumed by another implicant are dropped first, with
     clean_disj_matches.
   - The result is |- t = (~x1 /\ ... /\ ~xk ==> t'), as produced by
     bool_eq_imp_real_imp_CONV.
   Example: for !x. P x \/ Q the result is |- t = (~Q ==> !x. P x).
   Raises a HOL error on non-boolean input, and UNCHANGED if there is
   nothing to extract. *)
fun bool_imp_extract_CONV t =
  let
    val _ = if (type_of t = bool) then () else raise mk_HOL_ERR "Conv" "bool_imp_extract_CONV" "";
    val (disj_t_refl,_) = get_impl_terms t;
    (* drop the head of the list, which is t itself (or its ~a \/ b form) *)
    val disj_t = tl disj_t_refl;
    val disj_matches = clean_disj_matches disj_t [];

    (* Use the cleaned list.  Previously disj_matches was computed but
       unused, so redundant antecedents such as ~(Q \/ R) next to ~Q and ~R
       were produced. *)
    val matches = (map logical_mk_neg disj_matches);
    val _ = if matches = [] then raise UNCHANGED else ();
  in
    bool_eq_imp_real_imp_CONV matches t
  end;


(* Simpset fragments wrapping the conversions above. *)
val bool_eq_imp_ss = simpLib.conv_ss {name = "bool_eq_imp_CONV",
                                      trace = 2,
                                      key = SOME ([],``(a:bool) = (b:bool)``),
                                      conv = K (K bool_eq_imp_CONV)};

val bool_imp_extract_ss = simpLib.conv_ss {name = "bool_imp_extract_ss",
                                           trace = 2,
                                           key = SOME ([],``a:bool``),
                                           conv = K (K bool_imp_extract_CONV)};

val bool_neg_pair_ss = simpLib.conv_ss {name = "bool_neg_pair_CONV",
                                        trace = 2,
                                        key = SOME ([],``a:bool``),
                                        conv = K (K bool_neg_pair_CONV)};


(* ------------------------------------------------------------------------
   Strengthening conversions.  These produce theorems of the shape
   |- t' ==> t, i.e. they replace a goal t by a sufficient condition t'.
   ------------------------------------------------------------------------ *)

(* Monotonicity of /\, \/ and ! with respect to ==>. *)
val imp_thm_conj = prove (``!b1 c1 b2 c2. (b1 ==> c1) ==>
                                          (b2 ==> c2) ==>
                                          (b1 /\ b2) ==>
                                          (c1 /\ c2)``, SIMP_TAC std_ss []);
val imp_thm_disj = prove (``!b1 c1 b2 c2. (b1 ==> c1) ==>
                                          (b2 ==> c2) ==>
                                          (b1 \/ b2) ==>
                                          (c1 \/ c2)``, SIMP_TAC std_ss [DISJ_IMP_THM]);

val imp_thm_forall = prove (``(!x. (b1 x ==> b2 x)) ==> ((!x. b1 x) ==> (!x. b2 x))``,
                            SIMP_TAC std_ss []);


(* From |- b1 ==> b2 (with v free) derive |- (!v. b1) ==> (!v. b2).  GEN
   fails if v is free in a hypothesis; see GEN_ASSUM for that case. *)
fun GEN_IMP v thm =
  let
    val thm1 = GEN v thm;
    val thm2 = HO_MATCH_MP imp_thm_forall thm1;
  in
    thm2
  end;


(* |- t ==> t *)
fun REFL_IMP_CONV t = DISCH t (ASSUME t);

(* Generalise v in thm even when v is free in some hypotheses.
   - Those hypotheses are discharged and the theorem is generalised over v.
   - Each discharged hypothesis a is then turned back into a hypothesis
     !v. a, using MONO_ALL followed by UNDISCH.
   The result is |- !v. c, with !v. a in place of each such hypothesis a. *)
fun GEN_ASSUM v thm =
  let
    val assums = filter (fn t => mem v (free_vars t)) (hyp thm);
    val thm2 = foldl (fn (t,thm) => DISCH t thm) thm assums;
    val thm3 = GEN v thm2;
    val thm4 = foldl (fn (_,thm) => UNDISCH (HO_MATCH_MP MONO_ALL thm))
                     thm3 assums;
  in
    thm4
  end


(* Normalise the result of conv t into a non-trivial |- t' ==> t:
   - |- a ==> t is returned as is (UNCHANGED if a is t itself);
   - |- t = r becomes |- r ==> t (UNCHANGED if r is t itself);
   - |- t, for t other than T, becomes |- T ==> t.
   Any other result shape raises UNCHANGED; exceptions from conv propagate. *)
fun STRENGTHEN_CONV_WRAPPER conv t =
  let
    val thm = conv t;
    val thm_term = concl thm;
  in
    if (is_imp thm_term) then
      let
        val (t1, t2) = dest_imp thm_term;
        val _ = if not (t2 = t) then raise UNCHANGED else ();
        val _ = if (t1 = t2) then raise UNCHANGED else ();
      in
        thm
      end
    else if (is_eq thm_term) then
      if ((lhs thm_term = t) andalso not (rhs thm_term = t)) then
        snd (EQ_IMP_RULE thm)
      else raise UNCHANGED
    else if (thm_term = t andalso not (t = T)) then
      snd (EQ_IMP_RULE (EQT_INTRO thm))
    else
      raise UNCHANGED
  end;


(* Build |- t' ==> t by descending through /\, \/ and !.
   - The pieces are recombined with the monotonicity theorems above.
   - At other subterms, STRENGTHEN_CONV_WRAPPER conv is applied, and the
     procedure recurses on the antecedent it produces.
   - Where conv does not apply (UNCHANGED or HOL_ERR), the subterm is kept
     via |- s ==> s.
   NOTE(review): if conv keeps producing new antecedents, the recursion on
   the antecedent does not terminate. *)
fun DEPTH_STRENGTHEN_CONV conv t =
  if (is_conj t) then
    let
      val (b1,b2) = dest_conj t;
      val thm1 = DEPTH_STRENGTHEN_CONV conv b1;
      val thm2 = DEPTH_STRENGTHEN_CONV conv b2;

      val (b1,c1) = dest_imp (concl thm1);
      val (b2,c2) = dest_imp (concl thm2);
      val thm3 = ISPECL [b1,c1,b2,c2] imp_thm_conj;
      val thm4 = MP thm3 thm1;
      val thm5 = MP thm4 thm2;
    in
      thm5
    end handle HOL_ERR _ => (raise UNCHANGED)
  else if (is_disj t) then
    let
      val (b1,b2) = dest_disj t;
      val thm1 = DEPTH_STRENGTHEN_CONV conv b1;
      val thm2 = DEPTH_STRENGTHEN_CONV conv b2;

      val (b1,c1) = dest_imp (concl thm1);
      val (b2,c2) = dest_imp (concl thm2);
      val thm3 = ISPECL [b1,c1,b2,c2] imp_thm_disj;
      val thm4 = MP thm3 thm1;
      val thm5 = MP thm4 thm2;
    in
      thm5
    end
  else if (is_forall t) then
    let
      val (var, body) = dest_forall t;
      val thm_body = DEPTH_STRENGTHEN_CONV conv body;
      val thm = GEN_ASSUM var thm_body;
      val thm2 = HO_MATCH_MP imp_thm_forall thm;
    in
      thm2
    end
  else
    ((let
        val thm = (STRENGTHEN_CONV_WRAPPER conv) t;
        val (ante,_) = dest_imp (concl thm);
        val thm2 = DEPTH_STRENGTHEN_CONV conv ante;
        val thm3 = IMP_TRANS thm2 thm;
      in
        thm3
      end handle HOL_ERR _ => REFL_IMP_CONV t)
     handle UNCHANGED => REFL_IMP_CONV t);

(* Run a strengthening conversion and raise UNCHANGED if it only produced the
   trivial |- t ==> t. *)
fun UNCHANGED_STRENGTHEN_CONV conv t =
  let
    val thm = conv t;
    val (ante,conc) = dest_imp (concl thm);
    val _ = if (ante = conc) then raise UNCHANGED else ();
  in
    thm
  end;


(* Try the conversions in order, moving on to the next one only when the
   current one raises UNCHANGED.  Other exceptions propagate.  Raises
   UNCHANGED if the list is exhausted. *)
fun ORELSE_STRENGTHEN_CONV [] t = raise UNCHANGED
  | ORELSE_STRENGTHEN_CONV (c1::L) t =
      c1 t handle UNCHANGED =>
      ORELSE_STRENGTHEN_CONV L t;


(* Run conv on t and then deal with the hypotheses of the resulting theorem.
   - Each hypothesis h that is not in preserve_hyps is strengthened
     recursively.
   - Where that yields a non-trivial |- h' ==> h, h is replaced by h' using
     PROVE_HYP.
   - All remaining hypotheses outside preserve_hyps are discharged and
     merged into the antecedent with AND_IMP_INTRO.
   - Finally the operator part of the conclusion, which contains the
     antecedent, is cleaned up with REWRITE_CONV [].
   Presumably conv returns theorems of the form |- a ==> t; TODO confirm. *)
fun CONJ_ASSUMPTIONS_STRENGTHEN_CONV conv preserve_hyps t =
  let
    val thm = conv t;
    val new_hyps = filter (fn t => not (mem t preserve_hyps)) (hyp thm);
    val hyp_thms = map (fn t =>
                          ((SOME (CONJ_ASSUMPTIONS_STRENGTHEN_CONV conv preserve_hyps t))
                           handle HOL_ERR _ => NONE)
                          handle UNCHANGED => NONE) new_hyps;

    val hyp_thms2 = filter (fn thm_opt => (isSome thm_opt andalso
                              let val (l,r) = dest_imp (concl (valOf thm_opt)) in (not (l = r)) end handle HOL_ERR _ => false)) hyp_thms;
    val hyp_thms3 = map (UNDISCH o valOf) hyp_thms2;

    val thm2 = foldr (fn (thm1,thm2) => PROVE_HYP thm1 thm2) thm hyp_thms3;


    val new_hyps2 = filter (fn t => not (mem t preserve_hyps)) (hyp thm2);
    val thm3 = foldr (fn (t,thm) => SUBST_MATCH (SPEC_ALL AND_IMP_INTRO) (DISCH t thm)) thm2 (new_hyps2);
    val thm4 = CONV_RULE (RATOR_CONV (REWRITE_CONV [])) thm3
  in
    thm4
  end;


(* CONJ_ASSUMPTIONS_STRENGTHEN_CONV around DEPTH_STRENGTHEN_CONV, with no
   preserved hypotheses. *)
fun CONJ_ASSUMPTIONS_DEPTH_STRENGTHEN_CONV conv =
  CONJ_ASSUMPTIONS_STRENGTHEN_CONV (DEPTH_STRENGTHEN_CONV conv) []


(* Given |- a ==> c, strengthen the antecedent: conv a must give
   |- a' ==> a, and the result is |- a' ==> c. *)
fun IMP_STRENGTHEN_CONV_RULE conv thm = let
    val (imp_term,_) = dest_imp (concl thm);
    val imp_thm = conv imp_term;
  in
    IMP_TRANS imp_thm thm
  end


(* Use a strengthening conversion backwards: conv t gives |- t' ==> t, and
   HO_MATCH_MP_TAC reduces the goal t to t'. *)
fun STRENGTHEN_CONV_TAC conv (asm,t) =
  HO_MATCH_MP_TAC (conv t) (asm,t);


(* STRENGTHEN_CONV_TAC with DEPTH_STRENGTHEN_CONV conv. *)
fun DEPTH_STRENGTHEN_CONV_TAC conv =
  STRENGTHEN_CONV_TAC (DEPTH_STRENGTHEN_CONV conv)


(* ------------------------------------------------------------------------
   Conditional rewriting
   ------------------------------------------------------------------------ *)

(* Turn a (possibly conditional) rewrite theorem into a pair
   (conversion, pattern), where pattern is the term the conversion matches:
   - |- c ==> l = r : rewrites instances of l; the instantiated condition
                      stays as a hypothesis (UNDISCH)
   - |- c ==> p     : proves instances of p equal to T, under hypothesis c
   - |- l = r       : rewrites instances of l
   - |- p           : proves instances of p equal to T *)
fun COND_REWR_CONV___with_match thm =
  if (is_imp (concl thm)) then
    if (is_eq (snd (dest_imp (concl thm)))) then
      (UNDISCH o (PART_MATCH (lhs o snd o dest_imp) thm),
       (lhs o snd o dest_imp o concl) thm)
    else
      (EQT_INTRO o UNDISCH o (PART_MATCH (snd o dest_imp) thm),
       (snd o dest_imp o concl) thm)
  else
    if (is_eq (concl thm)) then
      (PART_MATCH lhs thm,
       (lhs o concl) thm)
    else
      (EQT_INTRO o PART_MATCH I thm,
       concl thm)


(* The conversion part of COND_REWR_CONV___with_match. *)
fun COND_REWR_CONV thm =
  fst (COND_REWR_CONV___with_match thm);


(* Top-level conditional rewriting.
   - Every conjunct of the given theorems (BODY_CONJUNCTS) becomes a
     conversion, indexed in a Net by its pattern.
   - At the term itself, the first applicable conversion is applied,
     repeatedly (REPEATC).
   Conditions of conditional rewrites remain as hypotheses of the result. *)
fun COND_REWRITE_CONV thmL =
  let
    val thmL' = flatten (map BODY_CONJUNCTS thmL);
    val conv_termL = map COND_REWR_CONV___with_match thmL';
    val net = foldr (fn ((conv,t),net) => Net.insert (t,conv) net) Net.empty conv_termL;
  in
    REPEATC (fn t =>
               let
                 val convL = Net.match t net;
               in
                 FIRST_CONV convL t
               end)
  end


(* COND_REWRITE_CONV thmL, applied only to terms satisfying p; on any other
   term it fails via NO_CONV. *)
fun GUARDED_COND_REWRITE_CONV p thmL =
  let
    val conv = COND_REWRITE_CONV thmL
  in
    fn t => if p t then conv t else NO_CONV t
  end


(*
fun COND_REWRITE_RULE r thm =
  let
    val rs = flatten (map (fn thm => CONJUNCTS thm) r);
    val rs = map UNDISCH_ALL rs;
    val thm' = repeat (fn thm => tryfind (fn thm2 => SUBST_MATCH thm2 thm) rs) thm
  in
    thm'
  end;


*)

end