1(*  Title:      HOL/Tools/SMT/z3_replay.ML
2    Author:     Sascha Boehme, TU Muenchen
3    Author:     Jasmin Blanchette, TU Muenchen
4
5Z3 proof parsing and replay.
6*)
7
signature Z3_REPLAY =
sig
  (*Parse Z3's proof output without replaying it in the kernel.  Returns which of the given
    facts the proof used (fact_ids) together with a thunk (atp_proof) that converts the parsed
    steps into an ATP-style proof.
    Arguments: outer context, translation replay data, the candidate facts
    ((name, stature), thm), the premises, the conclusion, and the raw solver output lines
    (presumably one string per output line -- confirm with Z3_Proof.parse).*)
  val parse_proof: Proof.context -> SMT_Translate.replay_data ->
    ((string * ATP_Problem_Generate.stature) * thm) list -> term list -> term -> string list ->
    SMT_Solver.parsed_proof
  (*Replay Z3's proof output step by step and return the theorem of the last step, exported
    back to the outer context with its remaining premises discharged.*)
  val replay: Proof.context -> SMT_Translate.replay_data -> string list -> thm
end;
15
16structure Z3_Replay: Z3_REPLAY =
17struct
18
local
  (*project a Z3 step onto the fields needed to recognise asserted facts*)
  fun step_fields (Z3_Proof.Z3_Step {id, rule, concl, fixes, ...}) = (id, rule, concl, fixes)

  (*assumption-like rules count as assertions, except for local hypotheses*)
  fun is_asserted_rule r =
    Z3_Proof.is_assumption r andalso r <> Z3_Proof.Hypothesis
in

(*collect Z3's asserted steps into an Inttab, via the generic SMT_Replay.add_asserted*)
val add_asserted =
  SMT_Replay.add_asserted Inttab.update Inttab.empty step_fields is_asserted_rule

end
27
(*Extend the association list nTs with one (name, type) entry per parameter reported by
  SMT_Replay.params_of t, pairing the parameters positionally with names.  Lengths must
  agree (fold2 raises ListPair.UnequalLengths otherwise).*)
fun add_paramTs names t nTs =
  let val paramTs = map snd (SMT_Replay.params_of t)
  in fold2 (fn name => fn T => AList.update (op =) (name, T)) names paramTs nTs end
30
(*Fix one fresh variable per (name, type) entry of nTs.  Returns the extended context and a
  table mapping each original name to the certified fresh Free of the recorded type.*)
fun new_fixes ctxt nTs =
  let
    val (fresh_names, ctxt') = Variable.variant_fixes (map (K "") nTs) ctxt
    fun bind ((n, T), fresh) = (n, Thm.cterm_of ctxt' (Free (fresh, T)))
    val env = Symtab.make (map bind (nTs ~~ fresh_names))
  in (ctxt', env) end
36
(*Term-level counterpart of Thm.forall_elim: instantiate the outermost Pure.all binder of a
  term with the term of ct by beta reduction.  Unlike Thm.forall_elim, the type of ct is not
  checked against the binder type.
  Raises TERM (tagged with this function's name) if the term is not of the form
  Pure.all (Abs ...).*)
fun forall_elim_term ct (Const (\<^const_name>\<open>Pure.all\<close>, _) $ (a as Abs _)) =
      Term.betapply (a, Thm.term_of ct)
  | forall_elim_term _ qt = raise TERM ("forall_elim_term", [qt])
40
(*Eliminate one universal per name, in order, using the fixed variable env assigns to it.
  Raises Option if a name is missing from env.*)
fun apply_fixes elim env = fold (elim o the o Symtab.lookup env)

(*theorem version: instantiate a (names, thm) premise*)
fun apply_fixes_prem env (names, thm) = apply_fixes Thm.forall_elim env names thm

(*term version: instantiate the conclusion*)
fun apply_fixes_concl env names t = apply_fixes forall_elim_term env names t

(*Re-generalize a theorem over the fixed variables env assigns to names.*)
fun export_fixes env names =
  Drule.forall_intr_list (map (fn name => the (Symtab.lookup env name)) names)
47
(*Run the replay method f under fresh fixed variables.  The fixes are named by names plus the
  fixes of every (fixes, thm) entry of nthms.  The conclusion and the nthms premises are
  instantiated with these fixes, and the plain prems are varified.  The resulting theorem is
  generalized again over names.*)
fun under_fixes f ctxt (prems, nthms) names concl =
  let
    val varified_prems = map (SMT_Replay.varify ctxt) prems
    val paramTs =
      fold (fn (fixes, thm) => add_paramTs fixes (Thm.prop_of thm)) nthms
        (add_paramTs names concl [])
    val (ctxt', env) = new_fixes ctxt paramTs
    val instantiated_prems = map (apply_fixes_prem env) nthms
    val inner_concl = apply_fixes_concl env names concl
  in export_fixes env names (f ctxt' (varified_prems @ instantiated_prems) inner_concl) end
58
(*Produce the theorem for one Z3 step.  For an assumption rule, use the previously asserted
  theorem if one is recorded under the step id, and otherwise assume the conclusion.  For any
  other rule, run its replay method under the step's fixes.  For fix steps, the premise
  theorems are passed as plain premises and not instantiated.*)
fun replay_thm ctxt assumed nthms (Z3_Proof.Z3_Step {id, rule, concl, fixes, is_fix_step, ...}) =
  if not (Z3_Proof.is_assumption rule) then
    let
      val method = Z3_Replay_Methods.method_for rule
      val inputs = if is_fix_step then (map snd nthms, []) else ([], nthms)
    in under_fixes method ctxt inputs fixes concl end
  else
    (case Inttab.lookup assumed id of
      NONE => Thm.assume (Thm.cterm_of ctxt concl)
    | SOME (_, thm) => thm)
67
(*Replay a single step within the per-step time limit.
  State: (proofs, stats).  proofs maps step ids to (fixes, thm); stats collects per-rule
  timings in milliseconds.  Exceeding the time limit is reported as SMT_Failure.Time_Out.*)
fun replay_step ctxt assumed (step as Z3_Proof.Z3_Step {id, rule, prems, fixes, ...}) (proofs, stats) =
  let
    val premise_proofs = map (fn p => the (Inttab.lookup proofs p)) prems
    val timed_replay = Timing.timing (replay_thm ctxt assumed premise_proofs)
    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout timed_replay step
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val rule_name = Z3_Proof.string_of_rule rule
    val stats' = Symtab.cons_list (rule_name, Time.toMilliseconds elapsed) stats
    val proofs' = Inttab.update (id, (fixes, thm)) proofs
  in (proofs', stats') end
78
(* |- (EX x. P x) = P c     |- ~ (ALL x. P x) = ~ P c *)
local
  (*Hilbert-choice rules: when c is (provably) the SOME-witness, an existential (or negated
    universal) equals its instance at c; presumably used for Skolem constants introduced by
    Z3 -- confirm against Z3_Proof*)
  val sk_rules = @{lemma
    "c = (SOME x. P x) \<Longrightarrow> (\<exists>x. P x) = P c"
    "c = (SOME x. \<not> P x) \<Longrightarrow> (\<not> (\<forall>x. P x)) = (\<not> P c)"
    by (metis someI_ex)+}
in

(*Solve an equational subgoal i by chaining one sk_rules step with the rest of the equation.
  - trans splits the goal r = t into r = ?s (i) and ?s = t (i+1).
  - One of sk_rules rewrites subgoal i, leaving its choice side condition c = (SOME ...) as
    subgoal i.
  - The remaining equation (i+1) is closed by refl, or recursively by another step.
  - Finally the side condition (i) is closed by refl.*)
fun discharge_sk_tac ctxt i st =
  (resolve_tac ctxt @{thms trans} i
   THEN resolve_tac ctxt sk_rules i
   THEN (resolve_tac ctxt @{thms refl} ORELSE' discharge_sk_tac ctxt) (i+1)
   THEN resolve_tac ctxt @{thms refl} i) st

end
94
(*|- ~ False, closes premises of that form*)
val true_thm = @{lemma "\<not>False" by simp}

(*The given rules followed by generic closing rules (allI, refl, reflexive, true_thm), used
  when discharging the premises left on the final replayed theorem.*)
fun make_discharge_rules rules =
  let val generic_rules = [@{thm allI}, @{thm refl}, @{thm reflexive}, true_thm]
  in rules @ generic_rules end
97
(*Both orientations of the tautology (~P | P) & (P | ~P).  discharge_assms_tac tries them
  before the caller-supplied rules to close remaining premises of this shape (presumably
  produced by Z3's definition-introduction steps -- confirm).*)
val intro_def_rules = @{lemma
  "(\<not> P \<or> P) \<and> (P \<or> \<not> P)"
  "(P \<or> \<not> P) \<and> (\<not> P \<or> P)"
  by fast+}
102
(*Repeatedly work on the first subgoal.  Each iteration either resolves it with
  intro_def_rules @ rules, or closes it completely with discharge_sk_tac.*)
fun discharge_assms_tac ctxt rules =
  let
    val intro_tac = resolve_tac ctxt (intro_def_rules @ rules)
    val skolem_tac = SOLVED' (discharge_sk_tac ctxt)
  in REPEAT (HEADGOAL (intro_tac ORELSE' skolem_tac)) end
107
(*Discharge all premises of thm using discharge_assms_tac, keeping the first result, and
  normalize it with Goal.norm_result.  Raises THM if the tactic yields no result.*)
fun discharge_assms ctxt rules thm =
  let
    val discharged =
      if Thm.nprems_of thm > 0 then
        (case Seq.pull (discharge_assms_tac ctxt rules thm) of
          NONE => raise THM ("failed to discharge premise", 1, [thm])
        | SOME (thm', _) => thm')
      else thm
  in Goal.norm_result ctxt discharged end
116
(*Export a theorem from inner_ctxt back to outer_ctxt, then discharge its premises using rules
  plus the generic closing rules.*)
fun discharge rules outer_ctxt inner_ctxt thm =
  thm
  |> singleton (Proof_Context.export inner_ctxt outer_ctxt)
  |> discharge_assms outer_ctxt (make_discharge_rules rules)
120
(*Parse Z3's output and map its asserted steps back to the conjecture, premises, facts and
  helper theorems.  Returns the facts used plus a thunk building an ATP-style proof.*)
fun parse_proof outer_ctxt
    ({context = ctxt, typs, terms, ll_defs, rewrite_rules, assms} : SMT_Translate.replay_data)
    xfacts prems concl output =
  let
    val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
    val ((iidths, _), _) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2

    (*Z3 step id asserted for the given index, or ~1 if there is none*)
    fun id_of_index i =
      (case AList.lookup (op =) iidths i of
        SOME (id, _) => id
      | NONE => ~1)

    (*index layout: 0 = conjecture, then one index per premise, then the facts*)
    val conjecture_i = 0
    val prems_i = 1
    val facts_i = prems_i + length prems

    val conjecture_id = id_of_index conjecture_i
    val prem_ids = map id_of_index (prems_i upto facts_i - 1)

    (*indices outside the fact range make nth fail and are therefore dropped*)
    fun fact_of (i, (id, _)) = Option.map (pair id) (try (nth xfacts) (i - facts_i))
    val fact_ids' = map_filter fact_of iidths

    (*helper theorems are recorded under index ~1*)
    val helper_ids' = map_filter (fn (~1, idth) => SOME idth | _ => NONE) iidths

    val helper_ts =
      map (fn (_, th) => (ATP_Util.short_thm_name ctxt th, Thm.prop_of th)) helper_ids'
    val fact_ts = map (fn (_, ((s, _), th)) => (s, Thm.prop_of th)) fact_ids'
    val fact_helper_ts = helper_ts @ fact_ts

    val helper_names = map (fn (id, th) => (id, ATP_Util.short_thm_name ctxt th)) helper_ids'
    val fact_names = map (fn (id, ((s, _), _)) => (id, s)) fact_ids'
    val fact_helper_ids' = helper_names @ fact_names
  in
    {outcome = NONE, fact_ids = SOME fact_ids',
     atp_proof = fn () => Z3_Isar.atp_proof_of_z3_proof ctxt ll_defs rewrite_rules prems concl
       fact_helper_ts prem_ids conjecture_id fact_helper_ids' steps}
  end
150
(*Replay every parsed Z3 step in order, printing intermediate statistics every 100 steps and
  per-rule timing statistics at the end.  Then export the last step's theorem to outer_ctxt
  and discharge its premises.*)
fun replay outer_ctxt
    ({context = ctxt, typs, terms, rewrite_rules, assms, ...} : SMT_Translate.replay_data) output =
  let
    val (steps, ctxt2) = Z3_Proof.parse typs terms output ctxt
    val ((_, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
    val simpset = SMT_Replay.make_simpset ctxt3 []
    val sat_solver = Config.get ctxt3 SMT_Config.sat_solver
    val ctxt4 = Config.put SAT.solver sat_solver (put_simpset simpset ctxt3)

    val num_steps = length steps
    val start = Timing.start ()
    val report = SMT_Replay.intermediate_statistics ctxt4 start num_steps

    fun replay_indexed (i, step) acc =
      let val _ = if i > 0 andalso i mod 100 = 0 then report i else ()
      in replay_step ctxt4 assumed step acc end

    val (proofs, stats) = fold_index replay_indexed steps (assumed, Symtab.empty)
    val _ = report num_steps
    val total = Time.toMilliseconds (#elapsed (Timing.result start))
    val (_, Z3_Proof.Z3_Step {id = last_id, ...}) = split_last steps
    val _ =
      SMT_Config.statistics_msg ctxt4
        (Pretty.string_of o SMT_Replay.pretty_statistics "Z3" total) stats
    val (_, final_thm) = the (Inttab.lookup proofs last_id)
  in discharge rules outer_ctxt ctxt4 final_thm end
175
176end;
177