1(*  Title:      HOL/Tools/SMT/verit_replay.ML
2    Author:     Mathias Fleury, MPII
3
4VeriT proof parsing and replay.
5*)
6
(*Public interface: replay a veriT proof inside Isabelle.*)
signature VERIT_REPLAY =
sig
  (*replay outer_ctxt replay_data output: parse the veriT proof contained in
    output (the solver's output lines), replay every proof step, and return
    the theorem proved by the last step, exported to outer_ctxt.
    A step that exceeds the reconstruction step timeout raises
    SMT_Failure.SMT SMT_Failure.Time_Out.*)
  val replay: Proof.context -> SMT_Translate.replay_data -> string list -> thm
end;
11
structure Verit_Replay: VERIT_REPLAY =
struct

(** Term post-processing shared by the replay functions **)

(*Rewrite a term with the translation's rewrite rules and, when lambda-lifting
  produced definitions (ll_defs non-empty), unlift those definitions again.*)
fun rewrite_and_unlift rewrite_rules ll_defs ctxt =
  Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt) rewrite_rules []
  #> not (null ll_defs) ? SMTLIB_Isar.unlift_term ll_defs

(*Turn a veriT conclusion or subproof assumption into an Isabelle proposition:
  rewrite with the rewrite rules, atomize, unlift lambda-lifted definitions
  (if there are any), map skolemized names back via
  SMTLIB_Isar.unskolemize_names, and wrap the result in Trueprop.*)
fun postprocess_prop rewrite_rules ll_defs ctxt =
  Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt) rewrite_rules []
  #> Object_Logic.atomize_term ctxt
  #> not (null ll_defs) ? SMTLIB_Isar.unlift_term ll_defs
  #> SMTLIB_Isar.unskolemize_names ctxt
  #> HOLogic.mk_Trueprop

(*Collect the theorems a replay method may use and apply the method f to them:
  unchanged_prems, the premises prems (varified by SMT_Replay.varify) and the
  theorems of the already replayed premise steps nthms.  names is only used
  for the debug output.*)
fun under_fixes f unchanged_prems (prems, nthms) names args (concl, ctxt) =
  let
    val thms1 = unchanged_prems @ map (SMT_Replay.varify ctxt) prems
    val _ =  SMT_Config.veriT_msg ctxt (fn () => \<^print>  ("names =", names))
    val thms2 = map snd nthms
    val _ = SMT_Config.veriT_msg ctxt (fn () => \<^print> ("prems=", prems))
    val _ = SMT_Config.veriT_msg ctxt (fn () => \<^print> ("nthms=", nthms))
    val _ = SMT_Config.veriT_msg ctxt (fn () => \<^print> ("thms1=", thms1))
    val _ = SMT_Config.veriT_msg ctxt (fn () => \<^print> ("thms2=", thms2))
  in (f ctxt (thms1 @ thms2) args concl) end


(** Replaying **)

(*Replay a single proof step.
  Input steps are looked up in assumed; every other step is proved by the
  method that method_for selects for its rule, applied to the post-processed
  conclusion.  The conclusion is passed through concl_transformation, then
  global_transformation, then postprocess_prop.  Raises Fail if an input step
  has no entry in assumed.*)
fun replay_thm method_for rewrite_rules ll_defs ctxt assumed unchanged_prems prems nthms
    concl_transformation global_transformation args
    (VeriT_Proof.VeriT_Replay_Node {id, rule, concl, bounds,  ...}) =
  let
    val _ = SMT_Config.veriT_msg ctxt (fn () => \<^print> id)
    val rewrite = rewrite_and_unlift rewrite_rules ll_defs ctxt
    val post = postprocess_prop rewrite_rules ll_defs ctxt
    val concl = concl
      |> concl_transformation
      |> global_transformation
      |> post
  in
    if rule = VeriT_Proof.veriT_input_rule then
      (case Symtab.lookup assumed id of
        SOME (_, thm) => thm
      | NONE => raise Fail ("veriT replay: no assumption recorded for input step " ^ quote id))
    else
      under_fixes (method_for rule) unchanged_prems
        (prems, nthms) (map fst bounds)
        (map rewrite args) (concl, ctxt)
  end

(*Add to the accumulator the assertion indices that the step (and,
  recursively, its subproof steps) uses as premises.  Only premise names that
  SMTLIB_Interface.assert_index_of_name accepts are collected.*)
fun add_used_asserts_in_step (VeriT_Proof.VeriT_Replay_Node {prems,
    subproof = (_, _, subproof), ...}) =
  union (op =) (map_filter (try SMTLIB_Interface.assert_index_of_name) prems @
     flat (map (fn x => add_used_asserts_in_step x []) subproof))

(*Drop a step whose id is an assertion with index below n.  In replay, n is
  the number of lambda-lifting definitions, which are proved separately.
  Every other step is kept.*)
fun remove_rewrite_rules_from_rules n (step as VeriT_Proof.VeriT_Replay_Node {id, ...}) =
  (case try SMTLIB_Interface.assert_index_of_name id of
    SOME a => if a < n then NONE else SOME step
  | NONE => SOME step)

(*Replay one step, its subproof first, threading the state
  (proofs, stats, ctxt, concl_transformation, global_transformation).
  The step's bound variables are fixed in ctxt.  The subproof is replayed in a
  context that fixes the subproof variables under fresh names and assumes the
  subproof assumptions.  Its theorems are exported back before the step itself
  is replayed under the reconstruction step timeout.  Raises Fail if a premise
  step has not been replayed, and SMT_Failure.SMT SMT_Failure.Time_Out on
  timeout.*)
fun replay_step rewrite_rules ll_defs assumed proof_prems
  (step as VeriT_Proof.VeriT_Replay_Node {id, rule, prems, bounds, args,
     subproof = (fixes, assms, subproof), concl, ...}) state =
  let
    val (proofs, stats, ctxt, concl_transformation, global_transformation) = state

    (*fix the step's bound variables and declare its (unskolemized) conclusion*)
    val ctxt =
      let val (_, ctxt') = Variable.variant_fixes (map fst bounds) ctxt
      in fold Variable.declare_term [SMTLIB_Isar.unskolemize_names ctxt' concl] ctxt' end

    (*fix the subproof's variables under fresh names and rename them accordingly*)
    val (names, sub_ctxt) = Variable.variant_fixes (map fst fixes) ctxt
       ||> fold Variable.declare_term (map Free fixes)
    val export_vars =
      Term.subst_free (ListPair.zip (map Free fixes, map Free (ListPair.zip (names, map snd fixes))))
      o concl_transformation

    val post = postprocess_prop rewrite_rules ll_defs ctxt
    val assms = map (export_vars o global_transformation o post) assms
    val (proof_prems', sub_ctxt2) = Assumption.add_assumes (map (Thm.cterm_of sub_ctxt) assms)
      sub_ctxt

    val all_proof_prems = proof_prems @ proof_prems'
    val (proofs', stats, _, _, sub_global_rew) =
       fold (replay_step rewrite_rules ll_defs assumed all_proof_prems) subproof
         (assumed, stats, sub_ctxt2, export_vars, global_transformation)
    val export_thm = singleton (Proof_Context.export sub_ctxt2 ctxt)

    (*premises come from the subproof's table if this step has a subproof*)
    val prem_proofs = if null subproof then proofs else proofs'
    fun lookup_prem p =
      (case Symtab.lookup prem_proofs p of
        SOME nthm => apsnd export_thm nthm
      | NONE => raise Fail ("veriT replay: premise " ^ quote p ^ " of step " ^ quote id ^
          " has not been replayed"))
    val nthms = map lookup_prem prems

    val proof_prems =
       if Verit_Replay_Methods.veriT_step_requires_subproof_assms rule then proof_prems else []
    val replay = Timing.timing (replay_thm Verit_Replay_Methods.method_for rewrite_rules ll_defs
       ctxt assumed [] proof_prems nthms concl_transformation global_transformation args)
    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout replay step
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val stats' = Symtab.cons_list (rule, Time.toMilliseconds elapsed) stats
  in (Symtab.update (id, (map fst bounds, thm)) proofs, stats', ctxt,
       concl_transformation, sub_global_rew) end

(*Prove a lambda-lifting definition term by inserting the assertions' theorems
  and calling fast, under the reconstruction step timeout.  The elapsed time
  is recorded in stats under "ll_defs".  Returns (thm, stats').*)
fun replay_ll_def assms ll_defs rewrite_rules stats ctxt term =
  let
    val rewrite = rewrite_and_unlift rewrite_rules ll_defs ctxt
    val replay = Timing.timing (SMT_Replay_Methods.prove ctxt (rewrite term))
    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout replay
         (fn _ => Method.insert_tac ctxt (map snd assms) THEN' Classical.fast_tac ctxt)
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val stats' = Symtab.cons_list ("ll_defs", Time.toMilliseconds elapsed) stats
  in
    (thm, stats')
  end

(*Replay a veriT proof and return the theorem of its last step, exported to
  outer_ctxt.
  Assertion indices below length ll_defs denote lambda-lifting definitions;
  they are proved by replay_ll_def.  The remaining indices denote assms,
  shifted by length ll_defs; each used one becomes a synthesized input step.*)
fun replay outer_ctxt
    ({context = ctxt, typs, terms, rewrite_rules, assms, ll_defs, ...} : SMT_Translate.replay_data)
     output =
  let
    (*rewrite rules unifiable with verit_eq_true_simplify are not used during replay*)
    val rewrite_rules =
      filter_out (fn thm => Term.could_unify (Thm.prop_of @{thm verit_eq_true_simplify},
          Thm.prop_of thm))
        rewrite_rules
    val num_ll_defs = length ll_defs
    val index_of_id = Integer.add (~ num_ll_defs)
    val id_of_index = Integer.add num_ll_defs

    val (actual_steps, ctxt2) =
      VeriT_Proof.parse_replay typs terms output ctxt

    (*synthesize an input step for the j-th assertion of assms*)
    fun step_of_assume (j, (_, th)) =
      VeriT_Proof.VeriT_Replay_Node {
        id = SMTLIB_Interface.assert_name_of_index (id_of_index j),
        rule = VeriT_Proof.veriT_input_rule,
        args = [],
        prems = [],
        proof_ctxt = [],
        (*NOTE(review): the rule list passed to rewrite_term is empty; rewrite_rules
          only reach the simpset, from which just the theory is taken.  Confirm
          whether rewriting with rewrite_rules was intended here.*)
        concl = Thm.prop_of th
          |> Raw_Simplifier.rewrite_term (Proof_Context.theory_of
               (empty_simpset ctxt addsimps rewrite_rules)) [] [],
        bounds = [],
        subproof = ([], [], [])}
    val used_assert_ids = fold add_used_asserts_in_step actual_steps []
    fun normalize_term t =
      Raw_Simplifier.rewrite_term (Proof_Context.theory_of ctxt) rewrite_rules [] t

    (*used assertions that are proper assumptions (not lambda-lifting definitions)*)
    val used_assm_js =
      map_filter (fn id =>
          let val i = index_of_id id
          in if i >= 0 then SOME (i, nth assms i) else NONE end)
        used_assert_ids

    val assm_steps = map step_of_assume used_assm_js
    val steps = assm_steps @ actual_steps

    fun extract (VeriT_Proof.VeriT_Replay_Node {id, rule, concl, bounds, ...}) =
      (id, rule, concl, map fst bounds)
    fun cond rule = rule = VeriT_Proof.veriT_input_rule
    val add_assert = SMT_Replay.add_asserted Symtab.update Symtab.empty extract cond
    val ((_, _), (ctxt3, assumed)) =
      add_assert outer_ctxt rewrite_rules assms
        (map_filter (remove_rewrite_rules_from_rules num_ll_defs) steps) ctxt2

    (*used lambda-lifting definitions are proved separately and added to assumed*)
    val used_rew_js =
      map_filter (fn id =>
          if index_of_id id < 0 then SOME (id, normalize_term (nth ll_defs id)) else NONE)
        used_assert_ids
    val (assumed, stats) =
      fold (fn (id, t) => fn (assumed, stats) =>
          let val (thm, stats) = replay_ll_def assms ll_defs rewrite_rules stats ctxt t
          in (Symtab.update (SMTLIB_Interface.assert_name_of_index id, ([], thm)) assumed, stats) end)
        used_rew_js (assumed, Symtab.empty)

    val ctxt4 =
      ctxt3
      |> put_simpset (SMT_Replay.make_simpset ctxt3 [])
      |> Config.put SAT.solver (Config.get ctxt3 SMT_Config.sat_solver)
    val len = length steps
    val start = Timing.start ()
    val print_runtime_statistics = SMT_Replay.intermediate_statistics ctxt4 start len
    (*print intermediate statistics every 100 steps*)
    fun blockwise f (i, x) y =
      (if i > 0 andalso i mod 100 = 0 then print_runtime_statistics i else (); f x y)
    val (proofs, stats, ctxt5, _, _) =
      fold_index (blockwise (replay_step rewrite_rules ll_defs assumed [])) steps
        (assumed, stats, ctxt4, fn x => x, fn x => x)
    val _ = print_runtime_statistics len
    val total = Time.toMilliseconds (#elapsed (Timing.result start))
    val (_, VeriT_Proof.VeriT_Replay_Node {id, ...}) = split_last steps
    val _ = SMT_Config.statistics_msg ctxt5
      (Pretty.string_of o SMT_Replay.pretty_statistics "veriT" total) stats
  in
    Symtab.lookup proofs id |> the |> snd |> singleton (Proof_Context.export ctxt5 outer_ctxt)
  end

end
208