(*  Title:      HOL/Tools/SMT/smt_normalize.ML
    Author:     Sascha Boehme, TU Muenchen

Normalization steps on theorems required by SMT solvers.
*)

(* Public interface of the normalization module.  The three tables/entries
   pair a constant name with the theorem used in this file to unfold it;
   "normalize" is the overall entry point and returns index/theorem pairs. *)
signature SMT_NORMALIZE =
sig
  val drop_fact_warning: Proof.context -> thm -> unit
  val atomize_conv: Proof.context -> conv

  val special_quant_table: (string * thm) list
  val case_bool_entry: string * thm
  val abs_min_max_table: (string * thm) list

  type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
  val add_extra_norm: SMT_Util.class * extra_norm -> Context.generic -> Context.generic
  val normalize: Proof.context -> thm list -> (int * thm) list
end;

structure SMT_Normalize: SMT_NORMALIZE =
struct

(* Report that a theorem is being dropped: the theorem is rendered with
   Thm.string_of_thm, prefixed with a fixed warning text and handed to
   SMT_Config.verbose_msg (presumably only shown in verbose mode -- confirm
   in SMT_Config). *)
fun drop_fact_warning ctxt =
  SMT_Config.verbose_msg ctxt (prefix "Warning: dropping assumption: " o
    Thm.string_of_thm ctxt)


(* general theorem normalizations *)

(** instantiate elimination rules **)

local
  (* The backquote combinator yields (f x, x): cfalse is the certified term
     False, cpfalse is SMT_Util.mk_cprop applied to it (presumably the
     proposition "Trueprop False" -- confirm in SMT_Util). *)
  val (cpfalse, cfalse) = `SMT_Util.mk_cprop (Thm.cterm_of \<^context> \<^const>\<open>False\<close>)

  (* Instantiate the schematic variable selected by f from the final
     conclusion of thm (after stripping all premises) with ct. *)
  fun inst f ct thm =
    let val cv = f (Drule.strip_imp_concl (Thm.cprop_of thm))
    in Thm.instantiate ([], [(dest_Var (Thm.term_of cv), ct)]) thm end
in

(* If the conclusion of thm is a bare schematic variable, either a boolean
   one under Trueprop or one standing for the whole proposition, replace
   that variable by False; otherwise return thm unchanged. *)
fun instantiate_elim thm =
  (case Thm.concl_of thm of
    \<^const>\<open>Trueprop\<close> $ Var (_, \<^typ>\<open>bool\<close>) => inst Thm.dest_arg cfalse thm
  | Var _ => inst I cpfalse thm
  | _ => thm)

end


(** normalize definitions **)

(* While the right-hand side of the equation is a lambda abstraction, apply
   fun_cong to turn "f = (%x. t)" into an equation on applied terms.  A meta
   equation with an abstraction on the right is first converted to an
   object-level equation via HOLogic.mk_obj_eq and then processed again. *)
fun norm_def thm =
  (case Thm.prop_of thm of
    \<^const>\<open>Trueprop\<close> $ (Const (\<^const_name>\<open>HOL.eq\<close>, _) $ _ $ Abs _) =>
      norm_def (thm RS @{thm fun_cong})
  | Const (\<^const_name>\<open>Pure.eq\<close>, _) $ _ $ Abs _ => norm_def (HOLogic.mk_obj_eq thm)
  | _ => thm)


(** atomization **)

(* Rewrite the Pure connectives (implication, equality, universal
   quantification) with the atomize_imp / atomize_eq / atomize_all rules,
   after first converting the immediate subterms recursively.  Any other term
   is left unchanged, and a CTERM exception raised along the way also results
   in the identity conversion on the original ct. *)
fun atomize_conv ctxt ct =
  (case Thm.term_of ct of
    \<^const>\<open>Pure.imp\<close> $ _ $ _ =>
      Conv.binop_conv (atomize_conv ctxt) then_conv Conv.rewr_conv @{thm atomize_imp}
  | Const (\<^const_name>\<open>Pure.eq\<close>, _) $ _ $ _ =>
      Conv.binop_conv (atomize_conv ctxt) then_conv Conv.rewr_conv @{thm atomize_eq}
  | Const (\<^const_name>\<open>Pure.all\<close>, _) $ Abs _ =>
      Conv.binder_conv (atomize_conv o snd) ctxt then_conv Conv.rewr_conv @{thm atomize_all}
  | _ => Conv.all_conv) ct
  handle CTERM _ => Conv.all_conv ct

(* Register the Pure connectives and Trueprop through
   SMT_Builtin.add_builtin_fun_ext''. *)
val setup_atomize =
  fold SMT_Builtin.add_builtin_fun_ext'' [\<^const_name>\<open>Pure.imp\<close>, \<^const_name>\<open>Pure.eq\<close>,
    \<^const_name>\<open>Pure.all\<close>, \<^const_name>\<open>Trueprop\<close>]


(** unfold special quantifiers **)

(* Constant name paired with the raw definition used to unfold it. *)
val special_quant_table = [
  (\<^const_name>\<open>Ex1\<close>, @{thm Ex1_def_raw}),
  (\<^const_name>\<open>Ball\<close>, @{thm Ball_def_raw}),
  (\<^const_name>\<open>Bex\<close>, @{thm Bex_def_raw})]

local
  (* Look up the unfolding theorem for a constant; NONE for anything that is
     not a constant listed in special_quant_table. *)
  fun special_quant (Const (n, _)) = AList.lookup (op =) special_quant_table n
    | special_quant _ = NONE

  (* Rewrite ct with its table entry if it is a listed constant, otherwise
     leave it unchanged.  The context argument is ignored. *)
  fun special_quant_conv _ ct =
    (case special_quant (Thm.term_of ct) of
      SOME thm => Conv.rewr_conv thm
    | NONE => Conv.all_conv) ct
in

(* Unfold all listed quantifier constants throughout a term.
   NOTE(review): SMT_Util.if_exists_conv presumably runs the conversion only
   when some subterm satisfies the predicate -- confirm in SMT_Util. *)
fun unfold_special_quants_conv ctxt =
  SMT_Util.if_exists_conv (is_some o special_quant) (Conv.top_conv special_quant_conv ctxt)

(* Register every constant of special_quant_table through
   SMT_Builtin.add_builtin_fun_ext''. *)
val setup_unfolded_quants = fold (SMT_Builtin.add_builtin_fun_ext'' o fst) special_quant_table

end


(** trigger inference **)

local
  (*** check trigger syntax ***)

  (* Classify a pattern annotation: SOME true for "pat", SOME false for
     "nopat", NONE for any other term. *)
  fun dest_trigger (Const (\<^const_name>\<open>pat\<close>, _) $ _) = SOME true
    | dest_trigger (Const (\<^const_name>\<open>nopat\<close>, _) $ _) = SOME false
    | dest_trigger _ = NONE

  (* True iff the list is non-empty and all elements equal the first one. *)
  fun eq_list [] = false
    | eq_list (b :: bs) = forall (equal b) bs

  (* A trigger term is proper if it decomposes (via SMT_Util.dest_symb_list,
     with failures treated as the empty list) into a non-empty list of
     multipatterns, each of which is non-empty and consists solely of "pat"
     or solely of "nopat" annotations.  Entries that are neither are dropped
     by map_filter before the check. *)
  fun proper_trigger t =
    t
    |> these o try SMT_Util.dest_symb_list
    |> map (map_filter dest_trigger o these o try SMT_Util.dest_symb_list)
    |> (fn [] => false | bss => forall eq_list bss)

  (* Check every occurrence of "trigger" in t.  The flag "inside" is true
     only for the immediate body of an All/Ex quantifier and is reset to
     false below any abstraction, application or trigger; an occurrence is
     accepted only when "inside" holds and f accepts its pattern argument. *)
  fun proper_quant inside f t =
    (case t of
      Const (\<^const_name>\<open>All\<close>, _) $ Abs (_, _, u) => proper_quant true f u
    | Const (\<^const_name>\<open>Ex\<close>, _) $ Abs (_, _, u) => proper_quant true f u
    | \<^const>\<open>trigger\<close> $ p $ u =>
        (if inside then f p else false) andalso proper_quant false f u
    | Abs (_, _, u) => proper_quant false f u
    | u1 $ u2 => proper_quant false f u1 andalso proper_quant false f u2
    | _ => true)

  (* Abort with an error message that shows the offending term. *)
  fun check_trigger_error ctxt t =
    error ("SMT triggers must only occur under quantifier and multipatterns " ^
      "must have the same kind: " ^ Syntax.string_of_term ctxt t)

  (* Identity conversion that fails with check_trigger_error when the term
     contains an ill-placed or ill-formed trigger.  The check runs on
     SMT_Util.term_of ct, the error message shows Thm.term_of ct. *)
  fun check_trigger_conv ctxt ct =
    if proper_quant false proper_trigger (SMT_Util.term_of ct) then Conv.all_conv ct
    else check_trigger_error ctxt (Thm.term_of ct)


  (*** infer simple triggers ***)

  (* Return both sides of a (possibly conditional) HOL equation: premises of
     HOL implications are skipped; anything else raises CTERM. *)
  fun dest_cond_eq ct =
    (case Thm.term_of ct of
      Const (\<^const_name>\<open>HOL.eq\<close>, _) $ _ $ _ => Thm.dest_binop ct
    | \<^const>\<open>HOL.implies\<close> $ _ $ _ => dest_cond_eq (Thm.dest_arg ct)
    | _ => raise CTERM ("no equation", [ct]))

  (* Constructors of a type constructor application as reported by
     BNF_LFP_Compat.get_constrs; the empty list for non-Type types or when
     nothing is known. *)
  fun get_constrs thy (Type (n, _)) = these (BNF_LFP_Compat.get_constrs thy n)
    | get_constrs _ _ = []

  (* True iff (n, T) is a constructor of the body type of T: the name must
     coincide and T must be an instance of the declared constructor type. *)
  fun is_constr thy (n, T) =
    let fun match (m, U) = m = n andalso Sign.typ_instance thy (T, U)
    in can (the o find_first match o get_constrs thy o Term.body_type) T end

  (* A constructor pattern is a free variable without arguments, or a
     constructor applied to constructor patterns. *)
  fun is_constr_pat thy t =
    (case Term.strip_comb t of
      (Free _, []) => true
    | (Const c, ts) => is_constr thy c andalso forall (is_constr_pat thy) ts
    | _ => false)

  (* A "simple" left-hand side is a constant applied to at least one
     argument, where the application is not reported as builtin by
     SMT_Builtin.is_builtin_fun_ext and all arguments are constructor
     patterns. *)
  fun is_simp_lhs ctxt t =
    (case Term.strip_comb t of
      (Const c, ts as _ :: _) =>
        not (SMT_Builtin.is_builtin_fun_ext ctxt c ts) andalso
        forall (is_constr_pat (Proof_Context.theory_of ctxt)) ts
    | _ => false)

  (* True iff every term in vs occurs among the free variables of t. *)
  fun has_all_vars vs t =
    subset (op aconv) (vs, map Free (Term.add_frees t []))

  (* Collect multipatterns from applications that mention all of vs.  An
     application none of whose two parts yields a pattern contributes itself
     as a singleton multipattern; otherwise the results of both parts are
     united.  Non-applications and terms lacking some variable yield no
     patterns. *)
  fun minimal_pats vs ct =
    if has_all_vars vs (Thm.term_of ct) then
      (case Thm.term_of ct of
        _ $ _ =>
          (case apply2 (minimal_pats vs) (Thm.dest_comb ct) of
            ([], []) => [[ct]]
          | (ctss, ctss') => union (eq_set (op aconvc)) ctss ctss')
      | _ => [])
    else []

  (* Reject the empty multipattern.  Otherwise accept cts only if no subterm
     of u is matched by one of the generalized patterns (gen applied to the
     pattern terms) while differing from the corresponding original pattern.
     NOTE(review): presumably this guards against triggers that would keep
     firing on the right-hand side of the equation -- confirm intent. *)
  fun proper_mpat _ _ _ [] = false
    | proper_mpat thy gen u cts =
        let
          val tps = (op ~~) (`gen (map Thm.term_of cts))
          fun some_match u = tps |> exists (fn (t', t) =>
            Pattern.matches thy (t', u) andalso not (t aconv u))
        in not (Term.exists_subterm some_match u) end

  (* Certified "pat" constant and a function wrapping a cterm in it, with the
     type instantiated through SMT_Util.instT'. *)
  val pat = SMT_Util.mk_const_pat \<^theory> \<^const_name>\<open>pat\<close> Thm.dest_ctyp0
  fun mk_pat ct = Thm.apply (SMT_Util.instT' ct pat) ct

  (* Certified cons/nil constants of the symbolic list over element type T,
     and a right fold that builds such a list from f applied to each item. *)
  fun mk_clist T =
    apply2 (Thm.cterm_of \<^context>) (SMT_Util.symb_cons_const T, SMT_Util.symb_nil_const T)
  fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil
  val mk_pat_list = mk_list (mk_clist \<^typ>\<open>pattern\<close>)
  val mk_mpat_list = mk_list (mk_clist \<^typ>\<open>pattern symb_list\<close>)
  (* Build the trigger argument: a list of multipatterns, each a list of
     "pat" annotations. *)
  fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss

  val trigger_eq = mk_meta_eq @{lemma "p = trigger t p" by (simp add: trigger_def)}

  (* With no multipatterns this is the identity conversion.  Otherwise
     instantiate trigger_eq so that p becomes ct and t becomes the trigger
     built from ctss, giving the equation "ct == trigger ... ct". *)
  fun insert_trigger_conv [] ct = Conv.all_conv ct
    | insert_trigger_conv ctss ct =
        let
          val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct
          val inst = map (apfst (dest_Var o Thm.term_of)) [cp, (ctr, mk_trigger ctss)]
        in Thm.instantiate ([], inst) trigger_eq end

  (* For a (conditional) equation whose left-hand side is simple, compute the
     minimal multipatterns over the variables cvs, keep those accepted by
     proper_mpat with respect to the right-hand side, and insert them as a
     trigger.  dest_cond_eq raises CTERM for non-equations.
     NOTE(review): cvs are presumably the variables bound by the enclosing
     quantifiers, as supplied by SMT_Util.under_quant_conv -- confirm. *)
  fun infer_trigger_eq_conv outer_ctxt (ctxt, cvs) ct =
    let
      val (lhs, rhs) = dest_cond_eq ct

      val vs = map Thm.term_of cvs
      val thy = Proof_Context.theory_of ctxt

      fun get_mpats ct =
        if is_simp_lhs ctxt (Thm.term_of ct) then minimal_pats vs ct
        else []
      val gen = Variable.export_terms ctxt outer_ctxt
      val filter_mpats = filter (proper_mpat thy gen (Thm.term_of rhs))

    in insert_trigger_conv (filter_mpats (get_mpats lhs)) ct end

  (* True iff the term is an application of "trigger" to two arguments. *)
  fun has_trigger (\<^const>\<open>trigger\<close> $ _ $ _) = true
    | has_trigger _ = false

  (* Leave ct alone if SMT_Util.under_quant has_trigger holds for it
     (presumably: a trigger is already present under the quantifiers);
     otherwise try cv, falling back to the identity if cv fails. *)
  fun try_trigger_conv cv ct =
    if SMT_Util.under_quant has_trigger (SMT_Util.term_of ct) then Conv.all_conv ct
    else Conv.try_conv cv ct

  (* Trigger inference is performed only when the configuration option
     SMT_Config.infer_triggers is set; otherwise this is the identity. *)
  fun infer_trigger_conv ctxt =
    if Config.get ctxt SMT_Config.infer_triggers then
      try_trigger_conv (SMT_Util.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt)
    else Conv.all_conv
in

(* First validate existing triggers (raising an error on malformed ones),
   then optionally infer new ones; both wrapped in SMT_Util.prop_conv. *)
fun trigger_conv ctxt =
  SMT_Util.prop_conv (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt)

(* Register the trigger-related constants through
   SMT_Builtin.add_builtin_fun_ext''. *)
val setup_trigger =
  fold SMT_Builtin.add_builtin_fun_ext''
    [\<^const_name>\<open>pat\<close>, \<^const_name>\<open>nopat\<close>, \<^const_name>\<open>trigger\<close>]

end


(** combined general normalizations **)

(* Conversion part of the general normalization, applied in this order:
   atomization, unfolding of special quantifiers, full beta normalization,
   trigger checking/inference. *)
fun gen_normalize1_conv ctxt =
  atomize_conv ctxt then_conv
  unfold_special_quants_conv ctxt then_conv
  Thm.beta_conversion true then_conv
  trigger_conv ctxt

(* General normalization of a single theorem.  The steps are composed left
   to right: instantiate elimination rules, normalize definitions, beta/eta
   normalize, universally quantify schematic variables, run
   gen_normalize1_conv, and finally rewrite nested implications. *)
fun gen_normalize1 ctxt =
  instantiate_elim #>
  norm_def #>
  Conv.fconv_rule (Thm.beta_conversion true then_conv Thm.eta_conversion) #>
  Drule.forall_intr_vars #>
  Conv.fconv_rule (gen_normalize1_conv ctxt) #>
  (* Z3 4.3.1 silently normalizes "P --> Q --> R" to "P & Q --> R" *)
  Raw_Simplifier.rewrite_rule ctxt @{thms HOL.imp_conjL[symmetric, THEN eq_reflection]}

(* Normalize one indexed theorem; if normalization raises an exception the
   theorem is dropped with a warning instead of aborting. *)
fun gen_norm1_safe ctxt (i, thm) =
  (case try (gen_normalize1 ctxt) thm of
    SOME thm' => SOME (i, thm')
  | NONE => (drop_fact_warning ctxt thm; NONE))

(* Normalize all indexed theorems, keeping only those that succeed. *)
fun gen_normalize ctxt iwthms = map_filter (gen_norm1_safe ctxt) iwthms


(* unfolding of definitions and theory-specific rewritings *)

(* Apply cv to the head of a (possibly nested) application and, on the way
   back up each application node, try a beta conversion (called with flag
   false) so that an abstraction introduced at the head consumes its
   arguments. *)
fun expand_head_conv cv ct =
  (case Thm.term_of ct of
    _ $ _ =>
      Conv.fun_conv (expand_head_conv cv) then_conv
      Conv.try_conv (Thm.beta_conversion false)
  | _ => cv) ct


(** rewrite bool case expressions as if expressions **)

(* Name of the bool case combinator paired with the theorem used to rewrite
   it into an if-expression. *)
val case_bool_entry = (\<^const_name>\<open>bool.case_bool\<close>, @{thm case_bool_if})

local
  (* True iff the term is the bool case combinator constant. *)
  fun is_case_bool (Const (\<^const_name>\<open>bool.case_bool\<close>, _)) = true
    | is_case_bool _ = false

  (* Rewrite with case_bool_if at the head of a term whose head is the bool
     case combinator; the context argument is ignored.
     NOTE(review): SMT_Util.if_true_conv presumably applies the conversion
     only when the predicate holds and is the identity otherwise -- confirm. *)
  fun unfold_conv _ =
    SMT_Util.if_true_conv (is_case_bool o Term.head_of)
      (expand_head_conv (Conv.rewr_conv @{thm case_bool_if}))
in

(* Rewrite all bool case expressions in a term into if-expressions. *)
fun rewrite_case_bool_conv ctxt =
  SMT_Util.if_exists_conv is_case_bool (Conv.top_conv unfold_conv ctxt)

(* Register the bool case combinator through
   SMT_Builtin.add_builtin_fun_ext''. *)
val setup_case_bool = SMT_Builtin.add_builtin_fun_ext'' \<^const_name>\<open>bool.case_bool\<close>

end


(** unfold abs, min and max **)

(* Constant name paired with the raw theorem used to unfold it. *)
val abs_min_max_table = [
  (\<^const_name>\<open>min\<close>, @{thm min_def_raw}),
  (\<^const_name>\<open>max\<close>, @{thm max_def_raw}),
  (\<^const_name>\<open>abs\<close>, @{thm abs_if_raw})]

local
  (* Return the unfolding theorem for a listed constant of function type,
     but only if its argument type T is accepted by
     SMT_Builtin.is_builtin_typ_ext; NONE in every other case. *)
  fun abs_min_max ctxt (Const (n, Type (\<^type_name>\<open>fun\<close>, [T, _]))) =
        (case AList.lookup (op =) abs_min_max_table n of
          NONE => NONE
        | SOME thm => if SMT_Builtin.is_builtin_typ_ext ctxt T then SOME thm else NONE)
    | abs_min_max _ _ = NONE

  (* Unfold at the head of ct if the head is an eligible abs/min/max
     constant; otherwise leave ct unchanged. *)
  fun unfold_amm_conv ctxt ct =
    (case abs_min_max ctxt (Term.head_of (Thm.term_of ct)) of
      SOME thm => expand_head_conv (Conv.rewr_conv thm)
    | NONE => Conv.all_conv) ct
in

(* Unfold all eligible occurrences of abs, min and max in a term. *)
fun unfold_abs_min_max_conv ctxt =
  SMT_Util.if_exists_conv (is_some o abs_min_max ctxt) (Conv.top_conv unfold_amm_conv ctxt)

(* Register every constant of abs_min_max_table through
   SMT_Builtin.add_builtin_fun_ext''. *)
val setup_abs_min_max = fold (SMT_Builtin.add_builtin_fun_ext'' o fst) abs_min_max_table

end


(** embedding of standard natural number operations into integer operations **)

local
  (* Operations on nat that are registered as builtin in setup_nat_as_int. *)
  val simple_nat_ops = [
    @{const HOL.eq (nat)}, @{const less (nat)}, @{const less_eq (nat)},
    \<^const>\<open>Suc\<close>, @{const plus (nat)}, @{const minus (nat)}]

  (* All nat constants whose presence makes nat_as_int_conv run. *)
  val nat_consts = simple_nat_ops @
    [@{const numeral (nat)}, @{const zero_class.zero (nat)}, @{const one_class.one (nat)}] @
    [@{const times (nat)}, @{const divide (nat)}, @{const modulo (nat)}]

  val is_nat_const = member (op aconv) nat_consts

  (* Rewrite rules as meta equations; nat_int_thm is nat_int read from right
     to left. *)
  val nat_int_thm = Thm.symmetric (mk_meta_eq @{thm nat_int})
  val nat_int_comp_thms = map mk_meta_eq @{thms nat_int_comparison}
  val int_ops_thms = map mk_meta_eq @{thms int_ops}
  val int_if_thm = mk_meta_eq @{thm int_if}

  (* Conversion for a term of the shape "c $ a1 $ a2 $ a3" (here an
     if-expression): cv1 goes to a1, cv2 to both a2 and a3. *)
  fun if_conv cv1 cv2 = Conv.combination_conv (Conv.combination_conv (Conv.arg_conv cv1) cv2) cv2

  (* Push the nat-to-int coercion inwards.  On a coerced if-expression,
     distribute the coercion with int_if_thm, then convert the condition with
     cv and the branches recursively.  On any other coerced term, first try
     the int_ops rules followed by a sweep for remaining coercions, and if
     that fails convert below the argument of the coercion with cv.  Fails
     (Conv.no_conv) on terms that are not coercions. *)
  fun int_ops_conv cv ctxt ct =
    (case Thm.term_of ct of
      @{const of_nat (int)} $ (Const (\<^const_name>\<open>If\<close>, _) $ _ $ _ $ _) =>
        Conv.rewr_conv int_if_thm then_conv
        if_conv (cv ctxt) (int_ops_conv cv ctxt)
    | @{const of_nat (int)} $ _ =>
        (Conv.rewrs_conv int_ops_thms then_conv
          Conv.top_sweep_conv (int_ops_conv cv) ctxt) else_conv
        Conv.arg_conv (Conv.sub_conv cv ctxt)
    | _ => Conv.no_conv) ct

  val unfold_nat_let_conv = Conv.rewr_conv @{lemma "Let (n::nat) f \<equiv> f n" by (rule Let_def)}
  val drop_nat_int_conv = Conv.rewr_conv (Thm.symmetric nat_int_thm)

  (* Three passes over the term: unfold Let on nat values, embed nat
     operations into int via nat_conv, and finally rewrite with nat_int from
     left to right wherever that applies. *)
  fun nat_to_int_conv ctxt ct = (
    Conv.top_conv (K (Conv.try_conv unfold_nat_let_conv)) ctxt then_conv
    Conv.top_sweep_conv nat_conv ctxt then_conv
    Conv.top_conv (K (Conv.try_conv drop_nat_int_conv)) ctxt) ct

  (* Rewrite once with nat_int_thm or one of the comparison rules, then push
     the coercions introduced by that step inwards with int_ops_conv. *)
  and nat_conv ctxt ct = (
    Conv.rewrs_conv (nat_int_thm :: nat_int_comp_thms) then_conv
    Conv.top_sweep_conv (int_ops_conv nat_to_int_conv) ctxt) ct

  (* Accumulator step for add_apps: ct is the function part and cu the
     argument part of an application.  For a coercion applied to cu, set the
     flag to true if cu mentions one of the variables in vs, otherwise record
     cu (without duplicates).  Other applications leave the state alone. *)
  fun add_int_of_nat vs ct cu (q, cts) =
    (case Thm.term_of ct of
      @{const of_nat(int)} =>
        if Term.exists_subterm (member (op aconv) vs) (Thm.term_of cu) then (true, cts)
        else (q, insert (op aconvc) cu cts)
    | _ => (q, cts))

  (* Traverse ct and apply f to every application (function part, argument
     part).  Under an abstraction the variable produced by Thm.dest_abs is
     added to vs before descending into the body. *)
  fun add_apps f vs ct =
    (case Thm.term_of ct of
      _ $ _ =>
        let val (cu1, cu2) = Thm.dest_comb ct
        in f vs cu1 cu2 #> add_apps f vs cu1 #> add_apps f vs cu2 end
    | Abs _ =>
        let val (cv, cu) = Thm.dest_abs NONE ct
        in add_apps f (Thm.term_of cv :: vs) cu end
    | _ => I)

  (* Non-negativity of a coerced nat, for one schematic n ... *)
  val int_thm = @{lemma "(0::int) <= int (n::nat)" by simp}
  (* ... and the quantified facts used when a coercion depends on a variable
     introduced by an abstraction. *)
  val nat_int_thms = @{lemma
    "\<forall>n::nat. (0::int) <= int n"
    "\<forall>n::nat. nat (int n) = n"
    "\<forall>i::int. int (nat i) = (if 0 <= i then i else 0)"
    by simp_all}
  (* The schematic variable n of int_thm, reached by taking the argument
     three times starting from the proposition of int_thm. *)
  val var = Term.dest_Var (Thm.term_of (funpow 3 Thm.dest_arg (Thm.cprop_of int_thm)))
in

(* Embed nat operations into int; runs only if the term mentions one of
   nat_consts (via SMT_Util.if_exists_conv). *)
fun nat_as_int_conv ctxt = SMT_Util.if_exists_conv is_nat_const (nat_to_int_conv ctxt)

(* Return the theorems unchanged together with extra facts about coerced
   nat terms: the quantified nat_int_thms if some coercion argument mentions
   an abstraction variable, otherwise one instance of int_thm per distinct
   coercion argument found. *)
fun add_int_of_nat_constraints thms =
  let val (q, cts) = fold (add_apps add_int_of_nat [] o Thm.cprop_of) thms (false, [])
  in
    if q then (thms, nat_int_thms)
    else (thms, map (fn ct => Thm.instantiate ([], [(var, ct)]) int_thm) cts)
  end

(* Register type nat as an extended builtin type whose status is decided by
   the configuration option SMT_Config.nat_as_int, and register the simple
   nat operations through SMT_Builtin.add_builtin_fun_ext'. *)
val setup_nat_as_int =
  SMT_Builtin.add_builtin_typ_ext (\<^typ>\<open>nat\<close>,
    fn ctxt => K (Config.get ctxt SMT_Config.nat_as_int)) #>
  fold (SMT_Builtin.add_builtin_fun_ext' o Term.dest_Const) simple_nat_ops

end


(** normalize numerals **)

local
  (*
    rewrite Numeral1 into 1
    rewrite - 0 into 0
  *)

  (* Syntactic check for the two irregular forms: numeral applied to
     num.One, and unary minus applied to zero. *)
  fun is_irregular_number (Const (\<^const_name>\<open>numeral\<close>, _) $ Const (\<^const_name>\<open>num.One\<close>, _)) =
        true
    | is_irregular_number (Const (\<^const_name>\<open>uminus\<close>, _) $ Const (\<^const_name>\<open>Groups.zero\<close>, _)) =
        true
    | is_irregular_number _ = false

  (* Irregular form that is additionally accepted by
     SMT_Builtin.is_builtin_num. *)
  fun is_strange_number ctxt t = is_irregular_number t andalso SMT_Builtin.is_builtin_num ctxt t

  (* Simpset containing just the two rules needed for the rewrites above, on
     top of HOL_ss. *)
  val proper_num_ss =
    simpset_of (put_simpset HOL_ss \<^context> addsimps @{thms Num.numeral_One minus_zero})

  (* Simplify a strange number with proper_num_ss; the third argument
     Conv.no_conv is presumably the branch taken by SMT_Util.if_conv when the
     predicate does not hold -- confirm in SMT_Util. *)
  fun norm_num_conv ctxt =
    SMT_Util.if_conv (is_strange_number ctxt) (Simplifier.rewrite (put_simpset proper_num_ss ctxt))
      Conv.no_conv
in

(* Normalize all strange numbers occurring in a term. *)
fun normalize_numerals_conv ctxt =
  SMT_Util.if_exists_conv (is_strange_number ctxt) (Conv.top_sweep_conv norm_num_conv ctxt)

end


(** combined unfoldings and rewritings **)

(* Apply f to the theorems of an indexed list.  The transformed theorems are
   paired again with the original indices, and the extra theorems returned
   by f are appended with index ~1.
   NOTE(review): f is expected to return as many theorems as it received,
   since the indices are zipped back with the infix ~~ -- confirm how ~~
   treats lists of unequal length. *)
fun burrow_ids f ithms =
  let
    val (is, thms) = split_list ithms
    val (thms', extra_thms) = f thms
  in (is ~~ thms') @ map (pair ~1) extra_thms end

(* Unfoldings that are applied before monomorphization, in this order: bool
   case expressions, abs/min/max, the nat-to-int embedding (only if the
   option SMT_Config.nat_as_int is set), full beta normalization. *)
fun unfold_conv ctxt =
  rewrite_case_bool_conv ctxt then_conv
  unfold_abs_min_max_conv ctxt then_conv
  (if Config.get ctxt SMT_Config.nat_as_int then nat_as_int_conv ctxt
   else Conv.all_conv) then_conv
  Thm.beta_conversion true

(* Apply unfold_conv to every indexed theorem. *)
fun unfold_polymorph ctxt = map (apsnd (Conv.fconv_rule (unfold_conv ctxt)))
(* Normalize numerals in every indexed theorem and, if the option
   SMT_Config.nat_as_int is set, append the constraints computed by
   add_int_of_nat_constraints. *)
fun unfold_monomorph ctxt =
  map (apsnd (Conv.fconv_rule (normalize_numerals_conv ctxt)))
  #> Config.get ctxt SMT_Config.nat_as_int ? burrow_ids add_int_of_nat_constraints


(* overall normalization *)

(* An extra normalization receives the context and a pair of theorem lists
   and returns such a pair; apply_extra_norms starts it with the current
   theorems and an empty second list, whose final content is appended as
   extra theorems. *)
type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list

(* Extra normalizations stored in the generic context as an SMT_Util.dict
   (presumably keyed by solver class -- confirm in SMT_Util). *)
structure Extra_Norms = Generic_Data
(
  type T = extra_norm SMT_Util.dict
  val empty = []
  val extend = I
  fun merge data = SMT_Util.dict_merge fst data
)

(* Store the normalization norm under the class cs. *)
fun add_extra_norm (cs, norm) = Extra_Norms.map (SMT_Util.dict_update (cs, norm))

(* Look up the normalizations for the solver class of the current context
   and fold them over the theorems, starting with no extra theorems. *)
fun apply_extra_norms ctxt ithms =
  let
    val cs = SMT_Config.solver_class_of ctxt
    val es = SMT_Util.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs
  in burrow_ids (fold (fn e => e ctxt) es o rpair []) ithms end

local
  (* Logical constants that are skipped when collecting schematic
     constants. *)
  val ignored = member (op =) [\<^const_name>\<open>All\<close>, \<^const_name>\<open>Ex\<close>,
    \<^const_name>\<open>Let\<close>, \<^const_name>\<open>If\<close>, \<^const_name>\<open>HOL.eq\<close>]

  (* Collect constants of a term via Monomorph.add_schematic_consts_of,
     starting from an empty symbol table.  Constants listed in "ignored" are
     skipped.  In a trigger only the terms inside pat/nopat annotations are
     visited (plus the triggered body), so the trigger constants themselves
     are not collected. *)
  val schematic_consts_of =
    let
      fun collect (\<^const>\<open>trigger\<close> $ p $ t) = collect_trigger p #> collect t
        | collect (t $ u) = collect t #> collect u
        | collect (Abs (_, _, t)) = collect t
        | collect (t as Const (n, _)) =
            if not (ignored n) then Monomorph.add_schematic_consts_of t else I
        | collect _ = I
      and collect_trigger t =
        let val dest = these o try SMT_Util.dest_symb_list
        in fold (fold collect_pat o dest) (dest t) end
      and collect_pat (Const (\<^const_name>\<open>pat\<close>, _) $ t) = collect t
        | collect_pat (Const (\<^const_name>\<open>nopat\<close>, _) $ t) = collect t
        | collect_pat _ = I
    in (fn t => collect t Symtab.empty) end
in

(* Monomorphize the theorems with Monomorph.monomorph.  Each theorem is
   passed paired with the number 1; each result group is paired back with
   the tag of the theorem at the same position, and the groups are
   flattened.
   NOTE(review): this assumes Monomorph.monomorph returns one group of
   results per input theorem, in input order -- confirm in Monomorph. *)
fun monomorph ctxt xthms =
  let val (xs, thms) = split_list xthms
  in
    map (pair 1) thms
    |> Monomorph.monomorph schematic_consts_of ctxt
    |> maps (uncurry (map o pair)) o map2 pair xs o map (map snd)
  end

end

(* Overall normalization pipeline: index the theorems by position, apply the
   general normalizations (dropping theorems that fail), the pre-monomorph
   unfoldings, monomorphization, the post-monomorph rewritings, and finally
   the extra normalizations stored for the solver class.  Theorems added
   along the way by burrow_ids carry index ~1. *)
fun normalize ctxt wthms =
  wthms
  |> map_index I
  |> gen_normalize ctxt
  |> unfold_polymorph ctxt
  |> monomorph ctxt
  |> unfold_monomorph ctxt
  |> apply_extra_norms ctxt

(* Install all builtin registrations defined in this file when the theory is
   set up. *)
val _ = Theory.setup (Context.theory_map (
  setup_atomize #>
  setup_unfolded_quants #>
  setup_trigger #>
  setup_case_bool #>
  setup_abs_min_max #>
  setup_nat_as_int))

end;