(*  Title:      HOL/Tools/SMT/z3_replay.ML
    Author:     Sascha Boehme, TU Muenchen
    Author:     Jasmin Blanchette, TU Muenchen

Z3 proof parsing and replay.
*)

signature Z3_REPLAY =
sig
  (*parse Z3 output and package it as SMT_Solver.parsed_proof (used facts plus a lazily
    built ATP-style proof); no proof step is replayed here*)
  val parse_proof: Proof.context -> SMT_Translate.replay_data ->
    ((string * ATP_Problem_Generate.stature) * thm) list -> term list -> term -> string list ->
    SMT_Solver.parsed_proof
  (*replay every Z3 proof step and return the theorem of the last step, exported to the
    outer context with its remaining premises discharged*)
  val replay: Proof.context -> SMT_Translate.replay_data -> string list -> thm
end;

structure Z3_Replay: Z3_REPLAY =
struct

(* asserted steps *)

local
  fun extract (Z3_Proof.Z3_Step {id, rule, concl, fixes, ...}) = (id, rule, concl, fixes)
  fun cond rule = Z3_Proof.is_assumption rule andalso rule <> Z3_Proof.Hypothesis
in

(*SMT_Replay.add_asserted specialized to Z3 steps: Inttab.update/Inttab.empty are the table
  operations, extract projects (id, rule, concl, fixes) out of a step, and cond selects the
  steps whose rule is an assumption rule other than Hypothesis*)
val add_asserted = SMT_Replay.add_asserted Inttab.update Inttab.empty extract cond

end


(* fixed variables *)

(*pair each name in names with the type of the corresponding parameter of t
  (SMT_Replay.params_of) and record the pair in the association list*)
fun add_paramTs names t =
  fold2 (fn n => fn (_, T) => AList.update (op =) (n, T)) names (SMT_Replay.params_of t)

(*invent one fresh variable per entry of nTs; map each original name to the certified Free
  of its fresh name, keeping the original type*)
fun new_fixes ctxt nTs =
  let
    val (ns, ctxt') = Variable.variant_fixes (replicate (length nTs) "") ctxt
    fun mk (n, T) n' = (n, Thm.cterm_of ctxt' (Free (n', T)))
  in (ctxt', Symtab.make (map2 mk nTs ns)) end

(*term-level counterpart of Thm.forall_elim: instantiate an outermost Pure.all binder with ct
  by beta-reduction; any other term is rejected with TERM*)
fun forall_elim_term ct (Const (\<^const_name>\<open>Pure.all\<close>, _) $ (a as Abs _)) =
      Term.betapply (a, Thm.term_of ct)
  | forall_elim_term _ qt = raise TERM ("forall_elim_term", [qt])

(*eliminate one binder per name with the fresh variable that env records for that name
  ("the" fails if a name is missing from env)*)
fun apply_fixes elim env = fold (elim o the o Symtab.lookup env)

(*instantiate a (fixes, thm) pair, respectively a conclusion term, via env*)
val apply_fixes_prem = uncurry o apply_fixes Thm.forall_elim
val apply_fixes_concl = apply_fixes forall_elim_term

(*generalize again over the fresh variables recorded in env for names*)
fun export_fixes env names = Drule.forall_intr_list (map (the o Symtab.lookup env) names)

(*Run the replay method f on an instantiated step.  The binders of concl (for names) and of
  every (fixes, thm) in nthms are instantiated with fresh variables; equal names share one
  variable because env is keyed by name.  Parameter types are collected from concl and the
  propositions of nthms.  prems are passed to f only after SMT_Replay.varify.  The result of f
  is generalized over the fresh variables of names.*)
fun under_fixes f ctxt (prems, nthms) names concl =
  let
    val thms1 = map (SMT_Replay.varify ctxt) prems
    val (ctxt', env) =
      add_paramTs names concl []
      |> fold (uncurry add_paramTs o apsnd Thm.prop_of) nthms
      |> new_fixes ctxt
    val thms2 = map (apply_fixes_prem env) nthms
    val t = apply_fixes_concl env names concl
  in export_fixes env names (f ctxt' (thms1 @ thms2) t) end


(* step replay *)

(*Replay a single step from the (fixes, thm) pairs nthms of its premises.  Assumption steps
  are taken from assumed when present, otherwise assumed via Thm.assume.  All other steps use
  the replay method for their rule.  For fix steps, the premise theorems are passed as they
  are (their fixes are ignored); otherwise they are instantiated through their fixes.*)
fun replay_thm ctxt assumed nthms (Z3_Proof.Z3_Step {id, rule, concl, fixes, is_fix_step, ...}) =
  if Z3_Proof.is_assumption rule then
    (case Inttab.lookup assumed id of
      SOME (_, thm) => thm
    | NONE => Thm.assume (Thm.cterm_of ctxt concl))
  else
    under_fixes (Z3_Replay_Methods.method_for rule) ctxt
      (if is_fix_step then (map snd nthms, []) else ([], nthms)) fixes concl

(*Replay one step under the per-step reconstruction timeout (a timeout is reported as
  SMT_Failure.Time_Out).  Record the elapsed milliseconds under the rule name in stats, and
  store (fixes, thm) under the step id in proofs.  "the" fails if a premise has not been
  replayed yet.*)
fun replay_step ctxt assumed (step as Z3_Proof.Z3_Step {id, rule, prems, fixes, ...}) state =
  let
    val (proofs, stats) = state
    val nthms = map (the o Inttab.lookup proofs) prems
    val replay = Timing.timing (replay_thm ctxt assumed nthms)
    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout replay step
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val stats' = Symtab.cons_list (Z3_Proof.string_of_rule rule, Time.toMilliseconds elapsed) stats
  in (Inttab.update (id, (fixes, thm)) proofs, stats') end


(* discharging remaining premises *)

(*Skolemization rules:  |- (EX x. P x) = P c     |- ~ (ALL x. P x) = ~ P c
  under the premise that c equals the corresponding Hilbert choice (SOME) term*)
local
  val sk_rules = @{lemma
    "c = (SOME x. P x) \<Longrightarrow> (\<exists>x. P x) = P c"
    "c = (SOME x. \<not> P x) \<Longrightarrow> (\<not> (\<forall>x. P x)) = (\<not> P c)"
    by (metis someI_ex)+}
in

(*at goal i: apply trans and then one of sk_rules; close subgoal i+1 by refl or by a
  recursive call; finally close goal i by refl*)
fun discharge_sk_tac ctxt i st =
  (resolve_tac ctxt @{thms trans} i
   THEN resolve_tac ctxt sk_rules i
   THEN (resolve_tac ctxt @{thms refl} ORELSE' discharge_sk_tac ctxt) (i+1)
   THEN resolve_tac ctxt @{thms refl} i) st

end

val true_thm = @{lemma "\<not>False" by simp}

(*the caller-provided rules, extended by a few generic closing rules*)
fun make_discharge_rules rules = rules @ [@{thm allI}, @{thm refl}, @{thm reflexive}, true_thm]

val intro_def_rules = @{lemma
  "(\<not> P \<or> P) \<and> (P \<or> \<not> P)"
  "(P \<or> \<not> P) \<and> (\<not> P \<or> P)"
  by fast+}

(*repeatedly close the head goal by resolution with intro_def_rules @ rules, or by a
  discharge_sk_tac run that solves it completely*)
fun discharge_assms_tac ctxt rules =
  REPEAT
    (HEADGOAL (resolve_tac ctxt (intro_def_rules @ rules) ORELSE'
      SOLVED' (discharge_sk_tac ctxt)))

(*discharge the premises of thm (if any) and normalize the result with Goal.norm_result.
  NOTE(review): REPEAT yields the unchanged state when its tactic no longer applies, so
  Seq.pull presumably never returns NONE here and undischarged premises would stay in the
  result instead of raising THM; confirm before relying on the error branch*)
fun discharge_assms ctxt rules thm =
  (if Thm.nprems_of thm = 0 then
    thm
  else
    (case Seq.pull (discharge_assms_tac ctxt rules thm) of
      SOME (thm', _) => thm'
    | NONE => raise THM ("failed to discharge premise", 1, [thm])))
  |> Goal.norm_result ctxt

(*export a theorem from inner_ctxt to outer_ctxt, then discharge its premises*)
fun discharge rules outer_ctxt inner_ctxt =
  singleton (Proof_Context.export inner_ctxt outer_ctxt)
  #> discharge_assms outer_ctxt (make_discharge_rules rules)


(* proof parsing *)

fun parse_proof outer_ctxt
    ({context = ctxt, typs, terms, ll_defs, rewrite_rules, assms} : SMT_Translate.replay_data)
    xfacts prems concl output =
  let
    val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
    val ((iidths, _), _) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2

    (*step id of the asserted entry with index i, or ~1 if there is none*)
    fun id_of_index i = the_default ~1 (Option.map fst (AList.lookup (op =) iidths i))

    (*index layout: 0 = conjecture, 1 .. length prems = premises, then the entries of xfacts*)
    val conjecture_i = 0
    val prems_i = 1
    val facts_i = prems_i + length prems

    val conjecture_id = id_of_index conjecture_i
    val prem_ids = map id_of_index (prems_i upto facts_i - 1)

    (*pair every entry whose index falls into the xfacts range with its fact*)
    val num_xfacts = length xfacts
    val fact_ids' =
      map_filter (fn (i, (id, _)) =>
          let val j = i - facts_i
          in if j >= 0 andalso j < num_xfacts then SOME (id, nth xfacts j) else NONE end)
        iidths
    (*entries with index ~1 are the helper facts*)
    val helper_ids' = map_filter (fn (i, idth) => if i = ~1 then SOME idth else NONE) iidths

    val fact_helper_ts =
      map (fn (_, th) => (ATP_Util.short_thm_name ctxt th, Thm.prop_of th)) helper_ids' @
      map (fn (_, ((s, _), th)) => (s, Thm.prop_of th)) fact_ids'
    val fact_helper_ids' =
      map (apsnd (ATP_Util.short_thm_name ctxt)) helper_ids' @ map (apsnd (fst o fst)) fact_ids'
  in
    {outcome = NONE, fact_ids = SOME fact_ids',
     atp_proof = fn () => Z3_Isar.atp_proof_of_z3_proof ctxt ll_defs rewrite_rules prems concl
       fact_helper_ts prem_ids conjecture_id fact_helper_ids' steps}
  end


(* proof replay *)

fun replay outer_ctxt
    ({context = ctxt, typs, terms, rewrite_rules, assms, ...} : SMT_Translate.replay_data) output =
  let
    val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
    val ((_, rules), (ctxt3, assumed)) =
      add_asserted outer_ctxt rewrite_rules assms steps ctxt2
    (*replay context: SMT simpset and the configured SAT solver*)
    val ctxt4 =
      ctxt3
      |> put_simpset (SMT_Replay.make_simpset ctxt3 [])
      |> Config.put SAT.solver (Config.get ctxt3 SMT_Config.sat_solver)
    val len = length steps
    val start = Timing.start ()
    val print_runtime_statistics = SMT_Replay.intermediate_statistics ctxt4 start len
    (*print intermediate statistics every 100 steps*)
    fun blockwise f (i, x) y =
      (if i > 0 andalso i mod 100 = 0 then print_runtime_statistics i else (); f x y)
    (*the table of replayed proofs starts out with the asserted theorems*)
    val (proofs, stats) =
      fold_index (blockwise (replay_step ctxt4 assumed)) steps (assumed, Symtab.empty)
    val _ = print_runtime_statistics len
    val total = Time.toMilliseconds (#elapsed (Timing.result start))
    (*NOTE(review): split_last raises Empty if the parsed proof has no steps*)
    val (_, Z3_Proof.Z3_Step {id, ...}) = split_last steps
    val _ = SMT_Config.statistics_msg ctxt4
      (Pretty.string_of o SMT_Replay.pretty_statistics "Z3" total) stats
  in
    Inttab.lookup proofs id |> the |> snd |> discharge rules outer_ctxt ctxt4
  end

end;