1(*  Title:      HOL/Tools/SMT/smtlib_proof.ML
2    Author:     Sascha Boehme, TU Muenchen
3    Author:     Mathias Fleury, ENS Rennes
4    Author:     Jasmin Blanchette, TU Muenchen
5
6SMT-LIB-2-style proofs: parsing and abstract syntax tree.
7*)
8
signature SMTLIB_PROOF =
sig
  (* parser context: Isabelle context, id counter, symbol bindings, known types and
     functions, and client data of type 'a *)
  datatype 'b shared = Tree of SMTLIB.tree | Term of term | Proof of 'b | None
  type ('a, 'b) context

  val mk_context: Proof.context -> int -> 'b shared Symtab.table -> typ Symtab.table ->
    term Symtab.table -> 'a -> ('a, 'b) context
  val empty_context: Proof.context -> typ Symtab.table -> term Symtab.table ->
    ('a list, 'b) context
  val ctxt_of: ('a, 'b) context -> Proof.context
  val next_id: ('a, 'b) context -> int * ('a, 'b) context

  (* symbol bindings and scoping *)
  val lookup_binding: ('a, 'b) context -> string -> 'b shared
  val update_binding: string * 'b shared -> ('a, 'b) context -> ('a, 'b) context
  val with_bindings: (string * 'b shared) list ->
    (('a, 'b) context -> 'c * ('d, 'b) context) ->
    ('a, 'b) context -> 'c * ('d, 'b) context
  val with_fresh_names:
    (('a list, 'b) context -> term * ((string * (string * typ)) list, 'b) context) ->
    ('c, 'b) context -> term * string list

  (* pluggable type and term parsers *)
  type type_parser = SMTLIB.tree * typ list -> typ option
  type term_parser = SMTLIB.tree * term list -> term option
  val add_type_parser: type_parser -> Context.generic -> Context.generic
  val add_term_parser: term_parser -> Context.generic -> Context.generic

  (* translation of SMT-LIB trees into HOL types and terms *)
  exception SMTLIB_PARSE of string * SMTLIB.tree
  val declare_fun: string -> typ -> ((string * typ) list, 'a) context ->
    ((string * typ) list, 'a) context
  val dest_binding: SMTLIB.tree -> string * 'a shared
  val type_of: ('a, 'b) context -> SMTLIB.tree -> typ
  val term_of: SMTLIB.tree -> ((string * (string * typ)) list, 'a) context ->
    term * ((string * (string * typ)) list, 'a) context
end;
41
structure SMTLIB_Proof: SMTLIB_PROOF =
struct

(* proof parser context *)

(*What an SMT-LIB symbol can be bound to during parsing:
    Tree  - an untranslated tree, e.g. the right-hand side of a let binding; term_of
            translates it on first use and caches the result as a Term,
    Term  - an already translated HOL term, e.g. a quantified variable,
    Proof - a value of the client-chosen type 'b,
    None  - no binding at all.*)
datatype 'b shared = Tree of SMTLIB.tree | Term of term | Proof of 'b | None

(*Parser state threaded through all parsing functions:
    ctxt  - Isabelle context, used for fresh names and type inference,
    id    - counter handed out by next_id,
    syms  - bindings of symbols (let-bound and quantified variables),
    typs  - SMT type names with a known HOL type,
    funs  - SMT function names with a known HOL term,
    extra - client data; with_fresh_names collects declared free variables here.*)
type ('a, 'b) context = {
  ctxt: Proof.context,
  id: int,
  syms: 'b shared Symtab.table,
  typs: typ Symtab.table,
  funs: term Symtab.table,
  extra: 'a}

fun mk_context ctxt id syms typs funs extra: ('a, 'b) context =
  {ctxt = ctxt, id = id, syms = syms, typs = typs, funs = funs, extra = extra}

(*Initial context: counter at 1, no symbol bindings, empty extra list.*)
fun empty_context ctxt typs funs = mk_context ctxt 1 Symtab.empty typs funs []

fun ctxt_of ({ctxt, ...}: ('a, 'b) context) = ctxt

(*Binding of a symbol, or None if the symbol is unbound.*)
fun lookup_binding ({syms, ...}: ('a, 'b) context) =
  the_default None o Symtab.lookup syms

fun map_syms f ({ctxt, id, syms, typs, funs, extra}: ('a, 'b) context) =
  mk_context ctxt id (f syms) typs funs extra

fun update_binding b = map_syms (Symtab.update b)

(*Run f with the bindings bs installed, then restore the previous bindings of exactly
  those names in the context returned by f. A name without a previous binding is
  restored as None, which lookup_binding treats the same as unbound.*)
fun with_bindings bs f cx =
  let val bs' = map (lookup_binding cx o fst) bs
  in
    cx
    |> fold update_binding bs
    |> f
    ||> fold2 (fn (name, _) => update_binding o pair name) bs bs'
  end

(*Current counter value, paired with a context whose counter is incremented.*)
fun next_id ({ctxt, id, syms, typs, funs, extra}: ('a, 'b) context) =
  (id, mk_context ctxt (id + 1) syms typs funs extra)

(*Run f on a copy of the context whose extra list starts empty (the input's extra data
  is ignored). f records each free variable it introduces in extra as
  (SMT name, (HOL name, type)), which is what declare_free does. The formula returned by
  f is wrapped in Trueprop and closed with the meta-level universal quantifier over the
  recorded variables. Type inference runs only if some type still contains dummyT or a
  schematic type variable. Returns the closed term and the SMT names of the variables.*)
fun with_fresh_names f ({ctxt, id, syms, typs, funs, ...}: ('a, 'b) context) =
  let
    fun bind (_, v as (_, T)) t = Logic.all_const T $ Term.absfree v t

    val needs_inferT = equal Term.dummyT orf Term.is_TVar
    val needs_infer = Term.exists_type (Term.exists_subtype needs_inferT)
    fun infer_types ctxt =
      singleton (Type_Infer_Context.infer_types ctxt) #>
      singleton (Proof_Context.standard_term_check_finish ctxt)
    fun infer ctxt t = if needs_infer t then infer_types ctxt t else t

    val (t, {ctxt = ctxt', extra = names, ...}: ((string * (string * typ)) list, 'b) context) =
      f (mk_context ctxt id syms typs funs [])
    val t' = infer ctxt' (fold_rev bind names (HOLogic.mk_Trueprop t))
  in
    (t', map fst names)
  end

fun lookup_typ ({typs, ...}: ('a, 'b) context) = Symtab.lookup typs
fun lookup_fun ({funs, ...}: ('a, 'b) context) = Symtab.lookup funs


(* core type and term parser *)

(*Built-in SMT-LIB sorts: Bool and Int.*)
fun core_type_parser (SMTLIB.Sym "Bool", []) = SOME \<^typ>\<open>HOL.bool\<close>
  | core_type_parser (SMTLIB.Sym "Int", []) = SOME \<^typ>\<open>Int.int\<close>
  | core_type_parser _ = NONE

(*Operator n applied to t, at type T => T where T is the type of t.*)
fun mk_unary n t =
  let val T = fastype_of t
  in Const (n, T --> T) $ t end

(*Operator n of type T => T => U applied to t1 and t2.*)
fun mk_binary' n T U t1 t2 = Const (n, [T, T] ---> U) $ t1 $ t2

(*Binary operator whose operands and result all take the type of t1.*)
fun mk_binary n t1 t2 =
  let val T = fastype_of t1
  in mk_binary' n T T t1 t2 end

(*Right-nested combination: mk_rassoc f t [] = t and
  mk_rassoc f t [u1, u2] = f t (f u1 u2).*)
fun mk_rassoc f t ts =
  let val us = rev (t :: ts)
  in fold f (tl us) (hd us) end

(*Left-nested combination: mk_lassoc f t [u1, u2] = f (f t u1) u2.*)
fun mk_lassoc f t ts = fold (fn u1 => fn u2 => f u2 u1) ts t

(*Left-nested chain of the binary operator n.*)
fun mk_lassoc' n = mk_lassoc (mk_binary n)

(*Binary predicate n. The operand type is taken from the first operand whose type is
  not dummyT. If both are dummyT, a fresh schematic type variable of sort S is used; its
  ?-prefixed name presumably marks it as a type-inference parameter (confirm against
  Type_Infer).*)
fun mk_binary_pred n S t1 t2 =
  let
    val T1 = fastype_of t1
    val T2 = fastype_of t2
    val T =
      if T1 <> Term.dummyT then T1
      else if T2 <> Term.dummyT then T2
      else TVar (("?a", serial ()), S)
  in mk_binary' n T \<^typ>\<open>HOL.bool\<close> t1 t2 end

fun mk_less t1 t2 = mk_binary_pred \<^const_name>\<open>ord_class.less\<close> \<^sort>\<open>linorder\<close> t1 t2
fun mk_less_eq t1 t2 = mk_binary_pred \<^const_name>\<open>ord_class.less_eq\<close> \<^sort>\<open>linorder\<close> t1 t2

(*Built-in SMT-LIB operators, given the head tree and the already translated arguments.
  Logic: true, false, not; and/or nest to the right.
  Implication: => and implies take two or more arguments and nest to the right, since
  SMT-LIB declares => right-associative; the binary case is a single implication.
  Equality: =, and binary ~, which is also read as HOL equality.
  If-then-else: ite.
  Integer numerals have type int.
  Arithmetic: unary - and unary ~ are negation; n-ary plus, minus and times nest to the
  left; div and mod map to z3div and z3mod.
  Comparisons: > and >= are < and <= with the operands swapped.
  Clause order matters: the unary minus clause must precede the n-ary one.*)
fun core_term_parser (SMTLIB.Sym "true", _) = SOME \<^const>\<open>HOL.True\<close>
  | core_term_parser (SMTLIB.Sym "false", _) = SOME \<^const>\<open>HOL.False\<close>
  | core_term_parser (SMTLIB.Sym "not", [t]) = SOME (HOLogic.mk_not t)
  | core_term_parser (SMTLIB.Sym "and", t :: ts) = SOME (mk_rassoc (curry HOLogic.mk_conj) t ts)
  | core_term_parser (SMTLIB.Sym "or", t :: ts) = SOME (mk_rassoc (curry HOLogic.mk_disj) t ts)
  | core_term_parser (SMTLIB.Sym "=>", t1 :: t2 :: ts) =
      SOME (mk_rassoc (curry HOLogic.mk_imp) t1 (t2 :: ts))
  | core_term_parser (SMTLIB.Sym "implies", t1 :: t2 :: ts) =
      SOME (mk_rassoc (curry HOLogic.mk_imp) t1 (t2 :: ts))
  | core_term_parser (SMTLIB.Sym "=", [t1, t2]) = SOME (HOLogic.mk_eq (t1, t2))
  | core_term_parser (SMTLIB.Sym "~", [t1, t2]) = SOME (HOLogic.mk_eq (t1, t2))
  | core_term_parser (SMTLIB.Sym "ite", [t1, t2, t3]) =
      let
        val T = fastype_of t2
        val c = Const (\<^const_name>\<open>HOL.If\<close>, [\<^typ>\<open>HOL.bool\<close>, T, T] ---> T)
      in SOME (c $ t1 $ t2 $ t3) end
  | core_term_parser (SMTLIB.Num i, []) = SOME (HOLogic.mk_number \<^typ>\<open>Int.int\<close> i)
  | core_term_parser (SMTLIB.Sym "-", [t]) = SOME (mk_unary \<^const_name>\<open>uminus_class.uminus\<close> t)
  | core_term_parser (SMTLIB.Sym "~", [t]) = SOME (mk_unary \<^const_name>\<open>uminus_class.uminus\<close> t)
  | core_term_parser (SMTLIB.Sym "+", t :: ts) =
      SOME (mk_lassoc' \<^const_name>\<open>plus_class.plus\<close> t ts)
  | core_term_parser (SMTLIB.Sym "-", t :: ts) =
      SOME (mk_lassoc' \<^const_name>\<open>minus_class.minus\<close> t ts)
  | core_term_parser (SMTLIB.Sym "*", t :: ts) =
      SOME (mk_lassoc' \<^const_name>\<open>times_class.times\<close> t ts)
  | core_term_parser (SMTLIB.Sym "div", [t1, t2]) = SOME (mk_binary \<^const_name>\<open>z3div\<close> t1 t2)
  | core_term_parser (SMTLIB.Sym "mod", [t1, t2]) = SOME (mk_binary \<^const_name>\<open>z3mod\<close> t1 t2)
  | core_term_parser (SMTLIB.Sym "<", [t1, t2]) = SOME (mk_less t1 t2)
  | core_term_parser (SMTLIB.Sym ">", [t1, t2]) = SOME (mk_less t2 t1)
  | core_term_parser (SMTLIB.Sym "<=", [t1, t2]) = SOME (mk_less_eq t1 t2)
  | core_term_parser (SMTLIB.Sym ">=", [t1, t2]) = SOME (mk_less_eq t2 t1)
  | core_term_parser _ = NONE


(* custom type and term parsers *)

type type_parser = SMTLIB.tree * typ list -> typ option

type term_parser = SMTLIB.tree * term list -> term option

fun id_ord ((id1, _), (id2, _)) = int_ord (id1, id2)

(*Registered parsers, each list kept in ascending order of its serial number. The core
  parsers are entered when the data slot is created. Parsers added later receive fresh
  serials and are therefore tried after the core ones.*)
structure Parsers = Generic_Data
(
  type T = (int * type_parser) list * (int * term_parser) list
  val empty : T = ([(serial (), core_type_parser)], [(serial (), core_term_parser)])
  val extend = I
  fun merge ((tys1, ts1), (tys2, ts2)) =
    (Ord_List.merge id_ord (tys1, tys2), Ord_List.merge id_ord (ts1, ts2))
)

fun add_type_parser type_parser =
  Parsers.map (apfst (Ord_List.insert id_ord (serial (), type_parser)))

fun add_term_parser term_parser =
  Parsers.map (apsnd (Ord_List.insert id_ord (serial (), term_parser)))

fun get_type_parsers ctxt = map snd (fst (Parsers.get (Context.Proof ctxt)))
fun get_term_parsers ctxt = map snd (snd (Parsers.get (Context.Proof ctxt)))

(*Result of the first parser, in list order, that accepts x; NONE if none does.*)
fun apply_parsers parsers x = get_first (fn parser => parser x) parsers


(* type and term parsing *)

exception SMTLIB_PARSE of string * SMTLIB.tree

(*Isabelle name for an SMT symbol: drop one leading ? (if present), then apply
  Name.desymbolize (SOME false).*)
val desymbolize = Name.desymbolize (SOME false) o perhaps (try (unprefix "?"))

(*Declare the SMT function name as a fresh Free of type T, named by a variant of n that
  is not yet fixed in the Isabelle context. The Free is registered under name in funs,
  and (HOL name, T) is recorded in extra via add. Returns the Free and the new context.*)
fun fresh_fun add name n T ({ctxt, id, syms, typs, funs, extra}: ('a, 'b) context) =
  let
    val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
    val t = Free (n', T)
    val funs' = Symtab.update (name, t) funs
  in (t, mk_context ctxt' id syms typs funs' (add (n', T) extra)) end

(*Records (HOL name, T) in extra and returns only the context.*)
fun declare_fun name = snd oo fresh_fun cons name (desymbolize name)
(*Records (SMT name, (HOL name, T)) in extra and returns the Free with the context.*)
fun declare_free name = fresh_fun (cons o pair name) name (desymbolize name)

(*Apply the type constructor tree ty to the argument types Ts. Registered parsers are
  tried first, then a nullary symbol is looked up among the known types.*)
fun parse_type cx ty Ts =
  (case apply_parsers (get_type_parsers (ctxt_of cx)) (ty, Ts) of
    SOME T => T
  | NONE =>
      (case ty of
        SMTLIB.Sym name =>
          (case lookup_typ cx name of
            SOME T => T
          | NONE => raise SMTLIB_PARSE ("unknown SMT type", ty))
      | _ => raise SMTLIB_PARSE ("bad SMT type format", ty)))

(*Apply the head tree t to the translated arguments ts. Registered parsers are tried
  first. Otherwise a known function symbol is applied to ts, and an unknown symbol without
  arguments is declared as a fresh free variable of type dummyT.*)
fun parse_term t ts cx =
  (case apply_parsers (get_term_parsers (ctxt_of cx)) (t, ts) of
    SOME u => (u, cx)
  | NONE =>
      (case t of
        SMTLIB.Sym name =>
          (case lookup_fun cx name of
            SOME u => (Term.list_comb (u, ts), cx)
          | NONE =>
              if null ts then declare_free name Term.dummyT cx
              else raise SMTLIB_PARSE ("bad SMT term", t))
      | _ => raise SMTLIB_PARSE ("bad SMT term format", t)))

(*Translate an SMT sort. The whole tree is first tried as a nullary sort (any failure is
  caught by try), then as an application (head arg1 ... argn) with recursively
  translated arguments.*)
fun type_of cx ty =
  (case try (parse_type cx ty) [] of
    SOME T => T
  | NONE =>
      (case ty of
        SMTLIB.S (ty' :: tys) => parse_type cx ty' (map (type_of cx) tys)
      | _ => raise SMTLIB_PARSE ("bad SMT type", ty)))

(*Sorted variable (name sort) as (SMT name, (HOL name, HOL type)).*)
fun dest_var cx (SMTLIB.S [SMTLIB.Sym name, ty]) = (name, (desymbolize name, type_of cx ty))
  | dest_var _ v = raise SMTLIB_PARSE ("bad SMT quantifier variable format", v)

(*Strip (possibly nested) ! annotation wrappers, keeping the annotated subterm.*)
fun dest_body (SMTLIB.S (SMTLIB.Sym "!" :: body :: _)) = dest_body body
  | dest_body body = body

(*Let binding (name tree) as a lazily translated Tree binding.*)
fun dest_binding (SMTLIB.S [SMTLIB.Sym name, t]) = (name, Tree t)
  | dest_binding b = raise SMTLIB_PARSE ("bad SMT let binding format", b)

(*Hilbert choice over x of type T: HOLogic.choice_const applied to the abstraction of P.*)
fun mk_choice (x, T, P) = HOLogic.choice_const T $ absfree (x, T) P

(*Translate an SMT-LIB term tree into a HOL term, threading the parser context.
  Arguments of applications are translated left to right.
  NOTE(review): let-bound trees are translated lazily, in the context of their first use,
  and the result is cached for the rest of the let body. A bound tree that mentions a name
  rebound between the let and that use (by the same let or an inner quantifier) would see
  the inner binding. This seems harmless for solver output with unique names; confirm.*)
fun term_of t cx =
  (case t of
    SMTLIB.S [SMTLIB.Sym "forall", SMTLIB.S vars, body] => quant HOLogic.mk_all vars body cx
  | SMTLIB.S [SMTLIB.Sym "exists", SMTLIB.S vars, body] => quant HOLogic.mk_exists vars body cx
  | SMTLIB.S [SMTLIB.Sym "choice", SMTLIB.S vars, body] => quant mk_choice vars body cx
  | SMTLIB.S [SMTLIB.Sym "let", SMTLIB.S bindings, body] =>
      with_bindings (map dest_binding bindings) (term_of body) cx
  | SMTLIB.S (SMTLIB.Sym "!" :: t :: _) => term_of t cx
  | SMTLIB.S (f :: args) =>
      cx
      |> fold_map term_of args
      |-> parse_term f
  | SMTLIB.Sym name =>
      (case lookup_binding cx name of
        Tree u =>
          cx
          |> term_of u
          |-> (fn u' => pair u' o update_binding (name, Term u'))
      | Term u => (u, cx)
      | None => parse_term t [] cx
      | _ => raise SMTLIB_PARSE ("bad SMT term format", t))
  | _ => parse_term t [] cx)

(*Translate the body with each variable bound to a Free of its declared type, then wrap
  it with q once per variable, the first variable outermost.*)
and quant q vars body cx =
  let val vs = map (dest_var cx) vars
  in
    cx
    |> with_bindings (map (apsnd (Term o Free)) vs) (term_of (dest_body body))
    |>> fold_rev (fn (_, (n, T)) => fn t => q (n, T, t)) vs
  end

end;
302