1(*  Title:      HOL/Tools/SMT/z3_interface.ML
2    Author:     Sascha Boehme, TU Muenchen
3
4Interface to Z3 based on a relaxed version of SMT-LIB.
5*)
6
signature Z3_INTERFACE =
sig
  (*SMT class for Z3: the SMT-LIB class extended with the "z3" tag*)
  val smtlib_z3C: SMT_Util.class

  (*a symbol name applied to a (possibly empty) list of argument symbols*)
  datatype sym = Sym of string * sym list

  (*extension hooks; each returns NONE when it does not handle its input*)
  type mk_builtins = {
    mk_builtin_typ: sym -> typ option,
    mk_builtin_num: theory -> int -> typ -> cterm option,
    mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
  val add_mk_builtins: mk_builtins -> Context.generic -> Context.generic

  (*construct HOL functions, numerals and types from symbols; the fixed
    built-in cases are tried first, then the registered extensions*)
  val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option
  val mk_builtin_num: Proof.context -> int -> typ -> cterm option
  val mk_builtin_typ: Proof.context -> sym -> typ option

  (*builtin numeral, or a constant application accepted as builtin function*)
  val is_builtin_theory_term: Proof.context -> term -> bool
end;
23
structure Z3_Interface: Z3_INTERFACE =
struct

val z3C = ["z3"]

(*the SMT-LIB class extended with the "z3" tag; all builtins, the extra
  normalisation and the translation configuration below are registered
  for this class*)
val smtlib_z3C = SMTLIB_Interface.smtlibC @ z3C


(* interface *)

local
  (*first-order translation that reuses the SMT-LIB serializer, with an
    empty logic string and only least-fixpoint kinds*)
  fun translate_config ctxt =
    {order = SMT_Util.First_Order,
     logic = K "",
     fp_kinds = [BNF_Util.Least_FP],
     serialize = #serialize (SMTLIB_Interface.translate_config SMT_Util.First_Order ctxt)}

  (*integer division or modulo constant?*)
  fun is_div_mod @{const divide (int)} = true
    | is_div_mod @{const modulo (int)} = true
    | is_div_mod _ = false

  val have_int_div_mod = exists (Term.exists_subterm is_div_mod o Thm.prop_of)

  (*if any (extra) theorem mentions integer div/mod, add the rules
    div_as_z3div and mod_as_z3mod to the extra theorems; z3div/z3mod are
    the constants mapped to Z3's "div"/"mod" in setup_builtins below*)
  fun add_div_mod _ (thms, extra_thms) =
    if have_int_div_mod thms orelse have_int_div_mod extra_thms then
      (thms, @{thms div_as_z3div mod_as_z3mod} @ extra_thms)
    else (thms, extra_thms)

  val setup_builtins =
    SMT_Builtin.add_builtin_fun' smtlib_z3C (@{const times (int)}, "*") #>
    SMT_Builtin.add_builtin_fun' smtlib_z3C (\<^const>\<open>z3div\<close>, "div") #>
    SMT_Builtin.add_builtin_fun' smtlib_z3C (\<^const>\<open>z3mod\<close>, "mod")
in

val _ = Theory.setup (Context.theory_map (
  setup_builtins #>
  SMT_Normalize.add_extra_norm (smtlib_z3C, add_div_mod) #>
  SMT_Translate.add_config (smtlib_z3C, translate_config)))

end


(* constructors *)

datatype sym = Sym of string * sym list


(** additional constructors **)

type mk_builtins = {
  mk_builtin_typ: sym -> typ option,
  mk_builtin_num: theory -> int -> typ -> cterm option,
  mk_builtin_fun: theory -> sym -> cterm list -> cterm option }

(*Each chained_* function tries the given extensions in list order and
  returns the first SOME result, or NONE if no extension applies.  The
  search is the Pure library's get_first.*)
fun chained_mk_builtin_typ bs sym =
  get_first (fn {mk_builtin_typ=mk, ...} : mk_builtins => mk sym) bs

fun chained_mk_builtin_num ctxt bs i T =
  let val thy = Proof_Context.theory_of ctxt
  in get_first (fn {mk_builtin_num=mk, ...} : mk_builtins => mk thy i T) bs end

fun chained_mk_builtin_fun ctxt bs s cts =
  let val thy = Proof_Context.theory_of ctxt
  in get_first (fn {mk_builtin_fun=mk, ...} : mk_builtins => mk thy s cts) bs end

fun fst_int_ord ((i1, _), (i2, _)) = int_ord (i1, i2)

(*registered extensions, kept as an ordered list keyed by the serial number
  drawn at registration time (see add_mk_builtins)*)
structure Mk_Builtins = Generic_Data
(
  type T = (int * mk_builtins) list
  val empty = []
  val extend = I
  (*eta-expanded on purpose: a "val merge = ..." application would not be
    generalised under the value restriction*)
  fun merge data = Ord_List.merge fst_int_ord data
)

fun add_mk_builtins mk = Mk_Builtins.map (Ord_List.insert fst_int_ord (serial (), mk))

(*extensions in ascending serial order*)
fun get_mk_builtins ctxt = map snd (Mk_Builtins.get (Context.Proof ctxt))


(** basic and additional constructors **)

(*Bool and Int (and their legacy lowercase spellings) are fixed; any other
  sort symbol is delegated to the registered extensions*)
fun mk_builtin_typ _ (Sym ("Bool", _)) = SOME \<^typ>\<open>bool\<close>
  | mk_builtin_typ _ (Sym ("Int", _)) = SOME \<^typ>\<open>int\<close>
  | mk_builtin_typ _ (Sym ("bool", _)) = SOME \<^typ>\<open>bool\<close>  (*FIXME: legacy*)
  | mk_builtin_typ _ (Sym ("int", _)) = SOME \<^typ>\<open>int\<close>  (*FIXME: legacy*)
  | mk_builtin_typ ctxt sym = chained_mk_builtin_typ (get_mk_builtins ctxt) sym

(*integer numerals are built directly; numerals of other types are
  delegated to the registered extensions*)
fun mk_builtin_num _ i \<^typ>\<open>int\<close> = SOME (Numeral.mk_cnumber \<^ctyp>\<open>int\<close> i)
  | mk_builtin_num ctxt i T =
      chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T

(*"true" is represented as ~False*)
val mk_true = Thm.cterm_of \<^context> (\<^const>\<open>Not\<close> $ \<^const>\<open>False\<close>)
val mk_false = Thm.cterm_of \<^context> \<^const>\<open>False\<close>
val mk_not = Thm.apply (Thm.cterm_of \<^context> \<^const>\<open>Not\<close>)
val mk_implies = Thm.mk_binop (Thm.cterm_of \<^context> \<^const>\<open>HOL.implies\<close>)
val mk_iff = Thm.mk_binop (Thm.cterm_of \<^context> @{const HOL.eq (bool)})
val conj = Thm.cterm_of \<^context> \<^const>\<open>HOL.conj\<close>
val disj = Thm.cterm_of \<^context> \<^const>\<open>HOL.disj\<close>

(*mk_nary c u [] = u; otherwise [a1, ..., an] is nested to the right with
  the binary operator c: c a1 (c a2 (... (c a(n-1) an)))*)
fun mk_nary _ cu [] = cu
  | mk_nary ct _ cts = uncurry (fold_rev (Thm.mk_binop ct)) (split_last cts)

(*HOL.eq instantiated at the type of its first argument*)
val eq = SMT_Util.mk_const_pat \<^theory> \<^const_name>\<open>HOL.eq\<close> Thm.dest_ctyp0
fun mk_eq ct cu = Thm.mk_binop (SMT_Util.instT' ct eq) ct cu

(*If instantiated at the type of its then-branch; mk_if cond then else*)
val if_term =
  SMT_Util.mk_const_pat \<^theory> \<^const_name>\<open>If\<close> (Thm.dest_ctyp0 o Thm.dest_ctyp1)
fun mk_if cc ct = Thm.mk_binop (Thm.apply (SMT_Util.instT' ct if_term) cc) ct

(*array read: fun_app instantiated at the array's (function) type*)
val access = SMT_Util.mk_const_pat \<^theory> \<^const_name>\<open>fun_app\<close> Thm.dest_ctyp0
fun mk_access array = Thm.apply (SMT_Util.instT' array access) array

(*array write: fun_upd instantiated at the array's domain and range types*)
val update =
  SMT_Util.mk_const_pat \<^theory> \<^const_name>\<open>fun_upd\<close> (Thm.dest_ctyp o Thm.dest_ctyp0)
fun mk_update array index value =
  let val cTs = Thm.dest_ctyp (Thm.ctyp_of_cterm array)
  in Thm.apply (Thm.mk_binop (SMT_Util.instTs cTs update) array index) value end

(*integer arithmetic and comparison*)
val mk_uminus = Thm.apply (Thm.cterm_of \<^context> @{const uminus (int)})
val add = Thm.cterm_of \<^context> @{const plus (int)}
val int0 = Numeral.mk_cnumber \<^ctyp>\<open>int\<close> 0
val mk_sub = Thm.mk_binop (Thm.cterm_of \<^context> @{const minus (int)})
val mk_mul = Thm.mk_binop (Thm.cterm_of \<^context> @{const times (int)})
val mk_div = Thm.mk_binop (Thm.cterm_of \<^context> \<^const>\<open>z3div\<close>)
val mk_mod = Thm.mk_binop (Thm.cterm_of \<^context> \<^const>\<open>z3mod\<close>)
val mk_lt = Thm.mk_binop (Thm.cterm_of \<^context> @{const less (int)})
val mk_le = Thm.mk_binop (Thm.cterm_of \<^context> @{const less_eq (int)})

(*Build the HOL term for symbol sym applied to cts.  Logical connectives,
  equality, if-then-else and array select/store are matched by name and
  arity only.  Arithmetic symbols are handled only when the first argument
  has type int; ">" and ">=" become "<" and "<=" with swapped arguments.
  Anything else is delegated to the registered extensions.*)
fun mk_builtin_fun ctxt sym cts =
  (case (sym, cts) of
    (Sym ("true", _), []) => SOME mk_true
  | (Sym ("false", _), []) => SOME mk_false
  | (Sym ("not", _), [ct]) => SOME (mk_not ct)
  | (Sym ("and", _), _) => SOME (mk_nary conj mk_true cts)
  | (Sym ("or", _), _) => SOME (mk_nary disj mk_false cts)
  | (Sym ("implies", _), [ct, cu]) => SOME (mk_implies ct cu)
  | (Sym ("iff", _), [ct, cu]) => SOME (mk_iff ct cu)
  | (Sym ("~", _), [ct, cu]) => SOME (mk_iff ct cu)
  | (Sym ("xor", _), [ct, cu]) => SOME (mk_not (mk_iff ct cu))
  | (Sym ("if", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3)
  | (Sym ("ite", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3) (* FIXME: remove *)
  | (Sym ("=", _), [ct, cu]) => SOME (mk_eq ct cu)
  | (Sym ("select", _), [ca, ck]) => SOME (Thm.apply (mk_access ca) ck)
  | (Sym ("store", _), [ca, ck, cv]) => SOME (mk_update ca ck cv)
  | _ =>
    (*type of the first argument, if there is one*)
    (case (sym, try (Thm.typ_of_cterm o hd) cts, cts) of
      (Sym ("+", _), SOME \<^typ>\<open>int\<close>, _) => SOME (mk_nary add int0 cts)
    | (Sym ("-", _), SOME \<^typ>\<open>int\<close>, [ct]) => SOME (mk_uminus ct)
    | (Sym ("-", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_sub ct cu)
    | (Sym ("*", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_mul ct cu)
    | (Sym ("div", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_div ct cu)
    | (Sym ("mod", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_mod ct cu)
    | (Sym ("<", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_lt ct cu)
    | (Sym ("<=", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_le ct cu)
    | (Sym (">", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_lt cu ct)
    | (Sym (">=", _), SOME \<^typ>\<open>int\<close>, [ct, cu]) => SOME (mk_le cu ct)
    | _ => chained_mk_builtin_fun ctxt (get_mk_builtins ctxt) sym cts))


(* abstraction *)

(*builtin numeral, or a constant head applied to arguments that
  SMT_Builtin accepts as a builtin function application*)
fun is_builtin_theory_term ctxt t =
  SMT_Builtin.is_builtin_num ctxt t orelse
  (case Term.strip_comb t of
    (Const c, ts) => SMT_Builtin.is_builtin_fun ctxt c ts
  | _ => false)

end;
197