1(*  Title:      HOL/Tools/SMT/smt_replay_methods.ML
2    Author:     Sascha Boehme, TU Muenchen
3    Author:     Jasmin Blanchette, TU Muenchen
4    Author:     Mathias Fleury, MPII
5
6Proof methods for replaying SMT proofs.
7*)
8
signature SMT_REPLAY_METHODS =
sig
  (*diagnostics: pretty-print a replay step (message, rule name, assumptions, proposition)
    and emit it as a trace message*)
  val pretty_goal: Proof.context -> string -> string -> thm list -> term -> Pretty.T
  val trace_goal: Proof.context -> string -> thm list -> term -> unit
  val trace: Proof.context -> (unit -> string) -> unit

  (*both raise an error carrying the pretty-printed goal; they never return normally*)
  val replay_error: Proof.context -> string -> string -> thm list -> term -> 'a
  val replay_rule_error: Proof.context -> string -> thm list -> term -> 'a

  (*theory lemma methods*)
  (*methods are registered by name in generic context data*)
  type th_lemma_method = Proof.context -> thm list -> term -> thm
  val add_th_lemma_method: string * th_lemma_method -> Context.generic ->
    Context.generic
  val get_th_lemma_method: Proof.context -> th_lemma_method Symtab.table
  (*resolve the given theorems, in order, into the premises of a theorem starting at the
    given premise index*)
  val discharge: int -> thm list -> thm -> thm
  (*instantiate schematic type and term variables of a theorem by first-order matching of
    its conclusion against a term*)
  val match_instantiate: Proof.context -> term -> thm -> thm
  (*prove a term (Trueprop is added if missing) by the given tactic on the first subgoal*)
  val prove: Proof.context -> term -> (Proof.context -> int -> tactic) -> thm

  (*abstraction*)
  (*next index for fresh variable names, and table from abstracted terms to their variables*)
  type abs_context = int * term Termtab.table
  type 'a abstracter = term -> abs_context -> 'a * abs_context
  (*register an extra abstracter consulted by abstract_arith for otherwise unhandled terms*)
  val add_arith_abstracter: (term abstracter -> term option abstracter) ->
    Context.generic -> Context.generic

  (*abstracters for literals, conjunctions, disjunctions, negations, unit clauses,
    propositional formulas, arbitrary terms and arithmetic terms*)
  val abstract_lit: term -> abs_context -> term * abs_context
  val abstract_conj: term -> abs_context -> term * abs_context
  val abstract_disj: term -> abs_context -> term * abs_context
  val abstract_not:  (term -> abs_context -> term * abs_context) ->
    term -> abs_context -> term * abs_context
  val abstract_unit:  term -> abs_context -> term * abs_context
  val abstract_prop: term -> abs_context -> term * abs_context
  val abstract_term:  term -> abs_context -> term * abs_context
  val abstract_arith: Proof.context -> term -> abs_context -> term * abs_context

  (*prove an abstracted version of a goal (premises and conclusion produced by the
    abstraction function), instantiate it back to the original term and discharge its
    premises with the given theorems*)
  val prove_abstract:  Proof.context -> thm list -> term ->
    (Proof.context -> thm list -> int -> tactic) ->
    (abs_context -> (term list * term) * abs_context) -> thm
  val prove_abstract': Proof.context -> term -> (Proof.context -> thm list -> int -> tactic) ->
    (abs_context -> term * abs_context) -> thm
  (*try the named provers in order and return the first result; if all fail, report a
    replay error for the given rule name*)
  val try_provers:  Proof.context -> string -> (string * (term -> 'a)) list -> thm list -> term ->
    'a

  (*shared tactics*)
  (*congruence-based proof procedures for equality steps*)
  val cong_basic: Proof.context -> thm list -> term -> thm
  val cong_full: Proof.context -> thm list -> term -> thm
  val cong_unfolding_first: Proof.context -> thm list -> term -> thm

  (*certify a term as a proposition, adding Trueprop if missing*)
  val certify_prop: Proof.context -> term -> cterm

end;
59
60structure SMT_Replay_Methods: SMT_REPLAY_METHODS =
61struct
62
63(* utility functions *)
64
(*Emit a lazily built message through the SMT tracing facility.*)
fun trace ctxt mk_msg = SMT_Config.trace_msg ctxt mk_msg ()
66
(*Pretty-print the conclusion of a theorem.*)
fun pretty_thm ctxt thm = thm |> Thm.concl_of |> Syntax.pretty_term ctxt
68
(*Render "msg: 'rule'" followed by the assumptions (omitted when there are none) and the
  proposition of a replay step.*)
fun pretty_goal ctxt msg rule thms t =
  let
    val title = msg ^ ": " ^ quote rule
    val assumption_blocks =
      (case thms of
        [] => []
      | _ => [Pretty.big_list "assumptions:" (map (pretty_thm ctxt) thms)])
    val proposition_block = Pretty.big_list "proposition:" [Syntax.pretty_term ctxt t]
  in Pretty.big_list title (assumption_blocks @ [proposition_block]) end
77
(*Fail with the pretty-printed description of the replay step.*)
fun replay_error ctxt msg rule thms t =
  pretty_goal ctxt msg rule thms t
  |> Pretty.string_of
  |> error
79
(*replay_error with the fixed message used for failed proof-step replays.*)
fun replay_rule_error ctxt rule thms t =
  replay_error ctxt "Failed to replay Z3 proof step" rule thms t
81
(*Trace a replay step; the message is only built when tracing asks for it.*)
fun trace_goal ctxt rule thms t =
  let fun mk_msg () = Pretty.string_of (pretty_goal ctxt "Goal" rule thms t)
  in trace ctxt mk_msg end
84
(*Wrap a term in Trueprop unless it already is a Trueprop application.*)
fun as_prop t =
  (case t of
    Const (\<^const_name>\<open>Trueprop\<close>, _) $ _ => t
  | _ => HOLogic.mk_Trueprop t)
87
(*Strip an outermost Trueprop, leaving other terms unchanged.*)
fun dest_prop t =
  (case t of
    Const (\<^const_name>\<open>Trueprop\<close>, _) $ body => body
  | _ => t)
90
(*Conclusion of a theorem with any outer Trueprop removed.*)
fun dest_thm thm = thm |> Thm.concl_of |> dest_prop
92
93
94(* plug-ins *)
95
(*Abstraction state: the index used for the next fresh variable name ("t" ^ index) and the
  table mapping already abstracted terms to the free variables standing for them.*)
type abs_context = int * term Termtab.table

(*An abstracter maps a term to a result while threading the abstraction state.*)
type 'a abstracter = term -> abs_context -> 'a * abs_context

(*Shape shared with cong_basic/cong_full: context, premise theorems, goal term -> theorem.*)
type th_lemma_method = Proof.context -> thm list -> term -> thm
101
(*Order plug-in entries by their integer id, ignoring the payload.*)
fun id_ord (entry1, entry2) = int_ord (fst entry1, fst entry2)
103
(*Generic context data holding the registered plug-ins:
  - arithmetic abstracters tagged with a serial number and kept sorted by id_ord
    (presumably registration order, since ids come from serial ());
  - a name-indexed table of theory-lemma methods.*)
structure Plugins = Generic_Data
(
  type T =
    (int * (term abstracter -> term option abstracter)) list *
    th_lemma_method Symtab.table
  val empty = ([], Symtab.empty)
  val extend = I
  (*abstracter lists are merged as ordered lists by id; the method tables are merged with
    an equality test that always succeeds, so entries under the same name are never
    reported as duplicates*)
  fun merge ((abss1, ths1), (abss2, ths2)) = (
    Ord_List.merge id_ord (abss1, abss2),
    Symtab.merge (K true) (ths1, ths2))
)
115
(*Register an arithmetic abstracter; the serial is drawn once per call, before the
  context update is applied.*)
fun add_arith_abstracter abs =
  let val entry = (serial (), abs)
  in
    Plugins.map (fn (abstracters, methods) =>
      (Ord_List.insert id_ord entry abstracters, methods))
  end

(*Registered arithmetic abstracters in id order, without their ids.*)
fun get_arith_abstracters ctxt =
  Context.Proof ctxt |> Plugins.get |> fst |> map snd
118
(*Register a named theory-lemma method; Symtab.update_new rejects a name already present.*)
fun add_th_lemma_method method =
  Plugins.map (fn (abstracters, methods) =>
    (abstracters, Symtab.update_new method methods))

(*Table of all registered theory-lemma methods.*)
fun get_th_lemma_method ctxt = #2 (Plugins.get (Context.Proof ctxt))
121
(*First-order match of a pattern against a term, starting from empty type and term
  environments.*)
fun match ctxt pat t =
  Pattern.first_order_match (Proof_Context.theory_of ctxt) (pat, t)
    (Vartab.empty, Vartab.empty)
125
(*Match the theorem's conclusion against t and return the certified instantiation for
  the environment chosen by sel (fst: types, snd: terms).*)
fun gen_certify_inst sel cert ctxt thm t =
  let
    val env = sel (match ctxt (dest_thm thm) (dest_prop t))
  in Vartab.fold (fn (ix, (a, b)) => cons ((ix, a), cert b)) env [] end
131
(*Instantiate the schematic type variables of thm by matching against t; theorems
  without such variables are returned unchanged.*)
fun match_instantiateT ctxt t thm =
  let
    val has_tvars = Term.exists_type (Term.exists_subtype Term.is_TVar) (dest_thm thm)
  in
    if not has_tvars then thm
    else Thm.instantiate (gen_certify_inst fst (Thm.ctyp_of ctxt) ctxt thm t, []) thm
  end
136
(*Instantiate types first, then the schematic term variables, by matching against t.*)
fun match_instantiate ctxt t thm =
  let
    val typed_thm = match_instantiateT ctxt t thm
    val term_insts = gen_certify_inst snd (Thm.cterm_of ctxt) ctxt typed_thm t
  in Thm.instantiate ([], term_insts) typed_thm end
141
(*Resolve each rule into premise i of the current theorem, in list order; after each
  step the index moves past the premises the rule introduced.*)
fun discharge i rules thm =
  let
    fun step rule (j, st) = (j + Thm.nprems_of rule, rule RSN (j, st))
  in snd (fold step rules (i, thm)) end
144
(*Prove "ts ==> t" with tac on the first subgoal, generalize the free variables named in
  ns, then discharge the premises with thms.*)
fun by_tac ctxt thms ns ts t tac =
  let
    val proved =
      Goal.prove ctxt [] (map as_prop ts) (as_prop t)
        (fn {context = goal_ctxt, prems} => HEADGOAL (tac goal_ctxt prems))
  in discharge 1 thms (Drule.generalize ([], ns) proved) end
150
(*by_tac without premises, fixed names or theorems; the tactic ignores the premises.*)
fun prove ctxt t tac = by_tac ctxt [] [] [] t (fn goal_ctxt => K (tac goal_ctxt))
152
153
154(* abstraction *)
155
(*Abstract the goal with f, prove the abstract version generalizing the introduced free
  variables, instantiate it back to t and discharge its premises with thms.*)
fun prove_abstract ctxt thms t tac f =
  let
    val ((prems, concl), (_, abstractions)) = f (1, Termtab.empty)
    fun free_name (_, v) = fst (Term.dest_Free v)
    val names = Termtab.fold (cons o free_name) abstractions []
    val abstract_thm = by_tac ctxt [] names prems concl tac
  in discharge 1 thms (match_instantiate ctxt t abstract_thm) end
165
(*prove_abstract for an abstraction that yields only a conclusion (no premises).*)
fun prove_abstract' ctxt t tac f =
  let
    fun with_no_prems cx =
      let val (concl, cx') = f cx
      in (([], concl), cx') end
  in prove_abstract ctxt [] t tac with_no_prems end
168
(*Variable previously chosen for t, if any.*)
fun lookup_term cx t = Termtab.lookup (snd cx) t
170
(*Reuse the abstraction recorded for t; only run f when t has not been seen.*)
fun abstract_sub t f cx =
  let val cached = lookup_term cx t
  in if is_some cached then (the cached, cx) else f cx end
175
(*Introduce the free variable "t<i>" (of t's type) for t and record it.*)
fun mk_fresh_free t (i, terms) =
  let
    val name = "t" ^ string_of_int i
    val v = Free (name, fastype_of t)
    val terms' = Termtab.update (t, v) terms
  in (v, (i + 1, terms')) end
179
(*Try the abstracters in order; the first SOME result wins. A NONE result discards that
  abstracter's state changes, and the next one starts again from cx.*)
fun apply_abstracters abs abstracters t cx =
  (case abstracters of
    [] => (NONE, cx)
  | abstracter :: rest =>
      let val result = abstracter abs t cx
      in if is_some (fst result) then result else apply_abstracters abs rest t cx end)
185
(*Replace applications and lambda-abstractions by (shared) fresh variables; atoms are
  kept as they are.*)
fun abstract_term t =
  let
    val compound = (case t of _ $ _ => true | Abs _ => true | _ => false)
  in
    if compound then abstract_sub t (mk_fresh_free t) else pair t
  end
189
(*Abstract both operands of t (left first) and rebuild with f, unless t is cached.*)
fun abstract_bin abs f t t1 t2 =
  let
    val abs1 = abs t1
    val abs2 = abs t2
    fun combine cx0 =
      let
        val (u1, cx1) = abs1 cx0
        val (u2, cx2) = abs2 cx1
      in (f (u1, u2), cx2) end
  in abstract_sub t combine end
191
(*Ternary counterpart of abstract_bin: abstract t1, t2, t3 in order and rebuild with f.*)
fun abstract_ter abs f t t1 t2 t3 =
  let
    val abs1 = abs t1
    val abs2 = abs t2
    val abs3 = abs t3
    fun combine cx0 =
      let
        val (u1, cx1) = abs1 cx0
        val (u2, cx2) = abs2 cx1
        val (u3, cx3) = abs3 cx2
      in (f (u1, u2, u3), cx3) end
  in abstract_sub t combine end
194
(*A literal is an atom or its negation; the negation itself is kept.*)
fun abstract_lit t =
  (case t of
    \<^const>\<open>HOL.Not\<close> $ t' => abstract_term t' #>> HOLogic.mk_not
  | _ => abstract_term t)
197
(*Under a negation continue with abs; anything else is treated as a literal.*)
fun abstract_not abs t =
  (case t of
    \<^const>\<open>HOL.Not\<close> $ t1 => abstract_sub t (abs t1 #>> HOLogic.mk_not)
  | _ => abstract_lit t)
201
(*Abstract a nest of conjunctions down to its literals.*)
fun abstract_conj t =
  (case t of
    \<^const>\<open>HOL.conj\<close> $ t1 $ t2 =>
      abstract_bin abstract_conj HOLogic.mk_conj t t1 t2
  | _ => abstract_lit t)
205
(*Abstract a nest of disjunctions down to its literals.*)
fun abstract_disj t =
  (case t of
    \<^const>\<open>HOL.disj\<close> $ t1 $ t2 =>
      abstract_bin abstract_disj HOLogic.mk_disj t t1 t2
  | _ => abstract_lit t)
209
(*Abstract a propositional formula.
  - If-then-else on bool, disjunction, conjunction, implication and equality whose first
    argument is of type bool are traversed structurally and rebuilt.
  - Every other term goes to abstract_not: negations are traversed, and the remaining
    applications and abstractions become (shared) fresh free variables via abstract_lit.*)
fun abstract_prop (t as (c as @{const If (bool)}) $ t1 $ t2 $ t3) =
      abstract_ter abstract_prop (fn (t1, t2, t3) => c $ t1 $ t2 $ t3) t t1 t2 t3
  | abstract_prop (t as \<^const>\<open>HOL.disj\<close> $ t1 $ t2) =
      abstract_bin abstract_prop HOLogic.mk_disj t t1 t2
  | abstract_prop (t as \<^const>\<open>HOL.conj\<close> $ t1 $ t2) =
      abstract_bin abstract_prop HOLogic.mk_conj t t1 t2
  | abstract_prop (t as \<^const>\<open>HOL.implies\<close> $ t1 $ t2) =
      abstract_bin abstract_prop HOLogic.mk_imp t t1 t2
  | abstract_prop (t as \<^term>\<open>HOL.eq :: bool => _\<close> $ t1 $ t2) =
      abstract_bin abstract_prop HOLogic.mk_eq t t1 t2
  | abstract_prop t = abstract_not abstract_prop t
221
(*Abstract an arithmetic formula/term.
  - Hilbert-choice (Eps) terms are abstracted as a whole.
  - The bodies of other constants applied to a lambda are traversed.
  - If, Not, disj, unary minus and the binary operators listed in binary_ops keep their
    operator and have their arguments abstracted recursively.
  - Any other term goes to abs_leaf: numerals are kept; otherwise the registered arithmetic
    abstracters are tried in order, and failing that the term is abstracted by
    abstract_term.*)
fun abstract_arith ctxt u =
  let
    (*binary operators that are kept while both arguments are abstracted*)
    val binary_ops = [
      \<^const_name>\<open>plus_class.plus\<close>,
      \<^const_name>\<open>minus_class.minus\<close>,
      \<^const_name>\<open>times_class.times\<close>,
      \<^const_name>\<open>z3div\<close>,
      \<^const_name>\<open>z3mod\<close>,
      \<^const_name>\<open>HOL.eq\<close>,
      \<^const_name>\<open>ord_class.less\<close>,
      \<^const_name>\<open>ord_class.less_eq\<close>]

    fun abs (t as Const (\<^const_name>\<open>Hilbert_Choice.Eps\<close>, _) $ Abs _) =
          abstract_sub t (abstract_term t)
      | abs (t as (c as Const _) $ Abs (s, T, t')) =
          abstract_sub t (abs t' #>> (fn u' => c $ Abs (s, T, u')))
      | abs (t as (c as Const (\<^const_name>\<open>If\<close>, _)) $ t1 $ t2 $ t3) =
          abstract_ter abs (fn (t1, t2, t3) => c $ t1 $ t2 $ t3) t t1 t2 t3
      | abs (t as \<^const>\<open>HOL.Not\<close> $ t1) = abstract_sub t (abs t1 #>> HOLogic.mk_not)
      | abs (t as \<^const>\<open>HOL.disj\<close> $ t1 $ t2) =
          abstract_sub t (abs t1 ##>> abs t2 #>> HOLogic.mk_disj)
      | abs (t as (c as Const (\<^const_name>\<open>uminus_class.uminus\<close>, _)) $ t1) =
          abstract_sub t (abs t1 #>> (fn u => c $ u))
      | abs (t as (c as Const (name, _)) $ t1 $ t2) =
          if member (op =) binary_ops name then
            abstract_sub t (abs t1 ##>> abs t2 #>> (fn (u1, u2) => c $ u1 $ u2))
          else abs_leaf t
      | abs t = abs_leaf t
    (*terms not handled structurally: keep numerals, else consult the plug-ins*)
    and abs_leaf t = abstract_sub t (fn cx =>
          if can HOLogic.dest_number t then (t, cx)
          else
            (case apply_abstracters abs (get_arith_abstracters ctxt) t cx of
              (SOME u, cx') => (u, cx')
            | (NONE, _) => abstract_term t cx))
  in abs u end
258
(*Abstract a unit clause.
  - Disjunctions, negated disjunctions and negations are traversed structurally.
  - Equalities and negated equalities between booleans are traversed as well.
  - Equalities at other types, and all remaining terms, are treated as literals.*)
fun abstract_unit (t as (\<^const>\<open>HOL.Not\<close> $ (\<^const>\<open>HOL.disj\<close> $ t1 $ t2))) =
      abstract_sub t (abstract_unit t1 ##>> abstract_unit t2 #>>
        HOLogic.mk_not o HOLogic.mk_disj)
  | abstract_unit (t as (\<^const>\<open>HOL.disj\<close> $ t1 $ t2)) =
      abstract_sub t (abstract_unit t1 ##>> abstract_unit t2 #>>
        HOLogic.mk_disj)
  | abstract_unit (t as (Const(\<^const_name>\<open>HOL.eq\<close>, _) $ t1 $ t2)) =
      if fastype_of t1 = \<^typ>\<open>bool\<close> then
        abstract_sub t (abstract_unit t1 ##>> abstract_unit t2 #>>
          HOLogic.mk_eq)
      else abstract_lit t
  (*the equality must be parenthesized: "Not $ eq $ t1 $ t2" would parse as
    "((Not $ eq) $ t1) $ t2", a shape that never occurs in well-typed terms*)
  | abstract_unit (t as (\<^const>\<open>HOL.Not\<close> $ (Const(\<^const_name>\<open>HOL.eq\<close>, _) $ t1 $ t2))) =
      if fastype_of t1 = \<^typ>\<open>bool\<close> then
        abstract_sub t (abstract_unit t1 ##>> abstract_unit t2 #>>
          HOLogic.mk_eq #>> HOLogic.mk_not)
      else abstract_lit t
  | abstract_unit (t as (\<^const>\<open>HOL.Not\<close> $ t1)) =
      abstract_sub t (abstract_unit t1 #>> HOLogic.mk_not)
  | abstract_unit t = abstract_lit t
278
279
280(* theory lemmas *)
281
(*Run the named provers in order and return the first result. Exceptions raised by a
  prover are caught by try and treated as failure. When no prover succeeds, report a
  replay error for rule.*)
fun try_provers ctxt rule named_provers thms t =
  (case named_provers of
    [] => replay_rule_error ctxt rule thms t
  | (name, prover) :: rest =>
      let val () = trace ctxt (K ("Trying prover " ^ quote name))
      in
        (case try prover t of
          SOME thm => thm
        | NONE => try_provers ctxt rule rest thms t)
      end)
287
288
289(* congruence *)
290
(*Certify t as a proposition (adding Trueprop when needed).*)
fun certify_prop ctxt t = t |> as_prop |> Thm.cterm_of ctxt
292
(*Congruence tactic on subgoal i.
  - First try to close the subgoal by reflexivity or by one of the premises.
  - Otherwise apply cong_tac, then treat subgoal i + 1 and afterwards subgoal i recursively.
  NOTE(review): assumes cong_tac leaves its two new subgoals at positions i and i + 1;
  confirm against cong_tac.*)
fun ctac ctxt prems i st = st |> (
  resolve_tac ctxt (@{thm refl} :: prems) i
  ORELSE (cong_tac ctxt i THEN ctac ctxt prems (i + 1) THEN ctac ctxt prems i))
296
(*Prove t by ctac from the trivial theorem "t ==> t"; take the first result, or raise
  THM ("cong", 0, ...) listing the premises and the initial state.*)
fun cong_basic ctxt thms t =
  let
    val st = Thm.trivial (certify_prop ctxt t)
    val results = ctac ctxt thms 1 st
  in
    (case Seq.pull results of
      NONE => raise THM ("cong", 0, thms @ [st])
    | SOME (thm, _) => thm)
  end
304
(*Destruction rules deriving P = Q from a conjunction of the two clauses (not P | Q) and
  (P | not Q), in either order; used with dresolve_tac in cong_full.*)
val cong_dest_rules = @{lemma
  "(\<not> P \<or> Q) \<and> (P \<or> \<not> Q) \<Longrightarrow> P = Q"
  "(P \<or> \<not> Q) \<and> (\<not> P \<or> Q) \<Longrightarrow> P = Q"
  by fast+}
309
(*Either eliminate an equation premise by substitution and close the goal by
  reflexivity, or fall back to the classical reasoner.*)
fun cong_full_core_tac ctxt =
  let
    val subst_then_refl = eresolve_tac ctxt @{thms subst} THEN' resolve_tac ctxt @{thms refl}
  in subst_then_refl ORELSE' Classical.fast_tac ctxt end
314
(*Prove the equality step t from thms.
  - Insert the premises, then try cong_full_core_tac.
  - Failing that, turn an inserted premise of the form (not P | Q) & (P | not Q) into
    P = Q with cong_dest_rules and retry cong_full_core_tac.
  All tactics use the goal context supplied by prove.*)
fun cong_full ctxt thms t = prove ctxt t (fn ctxt' =>
  Method.insert_tac ctxt' thms
  THEN' (cong_full_core_tac ctxt'
    ORELSE' dresolve_tac ctxt' cong_dest_rules
    THEN' cong_full_core_tac ctxt'))
320
(*Prove t by first rewriting the goal with the premise equations and then running auto
  with the premises inserted.*)
fun cong_unfolding_first ctxt thms t =
  let
    (*Turn a premise "a = b" into a meta-level rewrite rule that rewrites the larger side
      (by term size) to the smaller one.
      - Premises that are not HOL equations make eq_reflection fail; the try below turns
        that into NONE and they are not used for rewriting.
      - A reflected theorem that is not a plain "==" (e.g. one still carrying premises)
        is used without reorientation.*)
    fun orient_for_simp thm =
      let
        val meta_eq = @{thm eq_reflection} OF [thm]
      in
        (case try Logic.dest_equals (Thm.prop_of meta_eq) of
          SOME (lhs, rhs) =>
            if Term.size_of_term lhs > Term.size_of_term rhs then meta_eq
            (*sym OF [thm] turns "a = b" into "b = a"*)
            else @{thm eq_reflection} OF [@{thm sym} OF [thm]]
        | NONE => meta_eq)
      end
    val reorder_for_simp = try orient_for_simp
  in
    prove ctxt t (fn ctxt' =>
      Raw_Simplifier.rewrite_goal_tac ctxt' (map_filter reorder_for_simp thms)
      THEN' Method.insert_tac ctxt' thms
      THEN' K (Clasimp.auto_tac ctxt'))
  end
336
337end;
338