1(*  Title:      HOL/Tools/SMT/smt_real.ML
2    Author:     Sascha Boehme, TU Muenchen
3
4SMT setup for reals.
5*)
6
7structure SMT_Real: sig end =
8struct
9
10
11(* SMT-LIB logic *)
12
(*Choose the SMT-LIB logic for a problem: "AUFLIRA" is requested as soon as
  the type real occurs (possibly nested inside another type) in the type of
  some subterm of the given terms; otherwise no logic is suggested.*)
fun smtlib_logic ts =
  let
    val is_real = equal \<^typ>\<open>real\<close>
    val mentions_real = Term.exists_type (Term.exists_subtype is_real)
  in
    if exists mentions_real ts then SOME "AUFLIRA" else NONE
  end
17
18
19(* SMT-LIB and Z3 built-ins *)
20
local
  (*the multiplication constant on reals, shared by the helpers below*)
  val real_times = @{const times (real)}

  (*numeral translation for reals: the decimal rendering of the integer
    followed by ".0"; the first argument is ignored*)
  fun real_num _ n = SOME (string_of_int n ^ ".0")

  (*an argument list counts as linear if it has one or two elements and
    at least one of them is a number according to SMT_Util.is_number*)
  fun is_linear ts =
    let val n = length ts
    in (n = 1 orelse n = 2) andalso exists SMT_Util.is_number ts end

  (*rebuild a real multiplication from its argument list*)
  fun mk_times args = Term.list_comb (real_times, args)

  (*built-in handler for real multiplication: only linear products are
    accepted as built-in "*"; the first two arguments are ignored*)
  fun times _ _ args =
    (case is_linear args of
      true => SOME ("*", 2, args, mk_times)
    | false => NONE)

  (*order and arithmetic operations on reals with their SMT-LIB names*)
  val smtlib_funs = [
    (@{const less (real)}, "<"),
    (@{const less_eq (real)}, "<="),
    (@{const uminus (real)}, "-"),
    (@{const plus (real)}, "+"),
    (@{const minus (real)}, "-") ]

  (*operations registered unconditionally for the Z3 class only*)
  val z3_funs = [
    (real_times, "*"),
    (@{const divide (real)}, "/") ]
in

(*Register the type real (as "Real"), its numerals and its operations as
  built-ins: the basic operations and linear multiplication for the plain
  SMT-LIB class, unrestricted multiplication and division for the Z3 class.*)
val setup_builtins =
  SMT_Builtin.add_builtin_typ SMTLIB_Interface.smtlibC
    (\<^typ>\<open>real\<close>, K (SOME ("Real", [])), real_num) #>
  fold (SMT_Builtin.add_builtin_fun' SMTLIB_Interface.smtlibC) smtlib_funs #>
  SMT_Builtin.add_builtin_fun SMTLIB_Interface.smtlibC
    (Term.dest_Const real_times, times) #>
  fold (SMT_Builtin.add_builtin_fun' Z3_Interface.smtlib_z3C) z3_funs

end
50
51
52(* Z3 constructors *)
53
local
  val realT = \<^typ>\<open>real\<close>

  (*certify a term in the static context of this file*)
  val cert = Thm.cterm_of \<^context>

  (*certified binary operation built from a real constant*)
  fun real_binop c = Thm.mk_binop (cert c)

  (*map the Z3 sort symbol for reals back to the HOL type*)
  fun z3_mk_builtin_typ sym =
    (case sym of
      Z3_Interface.Sym ("Real", _) => SOME realT
    | Z3_Interface.Sym ("real", _) => SOME realT
        (*FIXME: delete*)
    | _ => NONE)

  (*numerals are only produced for the type real; the first argument is
    ignored*)
  fun z3_mk_builtin_num _ i T =
    (case T of
      \<^typ>\<open>real\<close> => SOME (Numeral.mk_cnumber \<^ctyp>\<open>real\<close> i)
    | _ => NONE)

  (*combine the arguments with the binary operation ct, nesting to the
    right; cu is the result for an empty argument list*)
  fun mk_nary ct cu cts =
    if null cts then cu
    else
      let val (front, final) = split_last cts
      in fold_rev (Thm.mk_binop ct) front final end

  val mk_uminus = Thm.apply (cert @{const uminus (real)})
  val add = cert @{const plus (real)}
  val real0 = Numeral.mk_cnumber \<^ctyp>\<open>real\<close> 0
  val mk_sub = real_binop @{const minus (real)}
  val mk_mul = real_binop @{const times (real)}
  val mk_div = real_binop @{const divide (real)}
  val mk_lt = real_binop @{const less (real)}
  val mk_le = real_binop @{const less_eq (real)}

  (*translate a Z3 function symbol applied to certified arguments; the
    cases are tried in order, so unary "-" is negation and binary "-" is
    subtraction, while ">" and ">=" become "<" and "<=" with swapped
    arguments*)
  fun z3_mk_builtin_fun sym cts =
    (case (sym, cts) of
      (Z3_Interface.Sym ("-", _), [ct]) => SOME (mk_uminus ct)
    | (Z3_Interface.Sym ("+", _), _) => SOME (mk_nary add real0 cts)
    | (Z3_Interface.Sym ("-", _), [ct, cu]) => SOME (mk_sub ct cu)
    | (Z3_Interface.Sym ("*", _), [ct, cu]) => SOME (mk_mul ct cu)
    | (Z3_Interface.Sym ("/", _), [ct, cu]) => SOME (mk_div ct cu)
    | (Z3_Interface.Sym ("<", _), [ct, cu]) => SOME (mk_lt ct cu)
    | (Z3_Interface.Sym ("<=", _), [ct, cu]) => SOME (mk_le ct cu)
    | (Z3_Interface.Sym (">", _), [ct, cu]) => SOME (mk_lt cu ct)
    | (Z3_Interface.Sym (">=", _), [ct, cu]) => SOME (mk_le cu ct)
    | _ => NONE)

  (*only handle applications whose first argument has type real; an empty
    argument list (or any failure to obtain that type) yields NONE*)
  fun real_builtin_fun _ sym cts =
    let val head_typ = try (Thm.typ_of_cterm o hd) cts
    in if head_typ = SOME realT then z3_mk_builtin_fun sym cts else NONE end
in

(*Constructors for turning Z3's real sort, numerals and function symbols
  back into certified HOL terms.*)
val z3_mk_builtins = {
  mk_builtin_typ = z3_mk_builtin_typ,
  mk_builtin_num = z3_mk_builtin_num,
  mk_builtin_fun = real_builtin_fun }

end
97
98
99(* Z3 proof replay *)
100
(*Simproc "fast_real_arith": runs Lin_Arith.simproc on strict and non-strict
  inequalities and on equations between reals.*)
val real_linarith_proc =
  let
    val patterns = [
      \<^term>\<open>(m::real) < n\<close>,
      \<^term>\<open>(m::real) \<le> n\<close>,
      \<^term>\<open>(m::real) = n\<close>]
  in
    Simplifier.make_simproc \<^context> "fast_real_arith"
      {lhss = patterns, proc = K Lin_Arith.simproc}
  end
105
106
107(* setup *)
108
(*Install everything defined above into the theory: the logic selection
  (registered with number 10), the built-in symbols, the Z3 term
  constructors, and the arithmetic simproc for proof replay.*)
val _ =
  let
    val real_setup =
      SMTLIB_Interface.add_logic (10, smtlib_logic)
      #> setup_builtins
      #> Z3_Interface.add_mk_builtins z3_mk_builtins
      #> SMT_Replay.add_simproc real_linarith_proc
  in
    Theory.setup (Context.theory_map real_setup)
  end
114
115end;
116