1(*  Title:      HOL/Tools/SMT/smt_translate.ML
2    Author:     Sascha Boehme, TU Muenchen
3
4Translate theorems into an SMT intermediate format and serialize them.
5*)
6
signature SMT_TRANSLATE =
sig
  (*intermediate term structure*)

  (*quantifier kinds of the intermediate language*)
  datatype squant = SForall | SExists
  (*multi-pattern of a trigger: SPat collects the terms of pat markers,
    SNoPat those of nopat markers*)
  datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
  (*intermediate terms:
    SVar (i, args) -- bound variable with de Bruijn index i, applied to args;
    SConst (n, args) -- built-in or generated symbol n, applied to args;
    SQua (q, sorts, pats, body) -- quantifier q binding variables of the given
      sort names, with trigger patterns pats*)
  datatype sterm =
    SVar of int * sterm list |
    SConst of string * sterm list |
    SQua of squant * string list * sterm spattern list * sterm

  (*translation configuration*)

  (*signature handed to the serializer: logic name, sorts to declare,
    (co)datatype declarations, and functions as (name, (argument sorts, result sort))*)
  type sign = {
    logic: string,
    sorts: string list,
    dtyps: (BNF_Util.fp_kind * (string * (string * (string * string) list) list)) list,
    funcs: (string * (string list * string)) list }
  (*per-solver-class configuration: term order, logic selection from the terms,
    supported (co)datatype kinds, and the serializer producing the problem text
    from SMT options, comment lines, the signature and the terms*)
  type config = {
    order: SMT_Util.order,
    logic: term list -> string,
    fp_kinds: BNF_Util.fp_kind list,
    serialize: (string * string) list -> string list -> sign -> sterm list -> string }
  (*data returned alongside the problem text (presumably consumed for proof replay):
    context, generated names mapped back to types and terms, lambda-lifting
    definitions, rewrite rules used by the translation, and the numbered inputs*)
  type replay_data = {
    context: Proof.context,
    typs: typ Symtab.table,
    terms: term Symtab.table,
    ll_defs: term list,
    rewrite_rules: thm list,
    assms: (int * thm) list }

  (*translation*)

  (*register the configuration function for a solver class*)
  val add_config: SMT_Util.class * (Proof.context -> config) -> Context.generic -> Context.generic
  (*translate ctxt smt_options comments ithms: produce the problem text for the
    solver class of ctxt, together with the replay data*)
  val translate: Proof.context -> (string * string) list -> string list -> (int * thm) list ->
    string * replay_data
end;
41
structure SMT_Translate: SMT_TRANSLATE =
struct


(* intermediate term structure *)

(*quantifier kinds of the intermediate language (from HOL All and Ex, see quantifier below)*)
datatype squant = SForall | SExists

(*multi-pattern of a trigger: SPat collects the terms of pat markers, SNoPat those of
  nopat markers (see dest_pats below)*)
datatype 'a spattern =
  SPat of 'a list | SNoPat of 'a list

(*intermediate terms, as produced by intermediate (below):
  SVar (i, args) -- bound variable with de Bruijn index i, applied to args;
  SConst (n, args) -- built-in or generated symbol n, applied to args;
  SQua (q, sorts, pats, body) -- quantifier q binding variables of the given sort
    names (outermost first), with trigger patterns pats*)
datatype sterm =
  SVar of int * sterm list |
  SConst of string * sterm list |
  SQua of squant * string list * sterm spattern list * sterm


(* translation configuration *)

(*signature handed to the serializer (built by sign_of below):
  logic -- logic name computed by the configuration;
  sorts -- names of the non-built-in sorts to declare;
  dtyps -- (co)datatype declarations: kind, type name, and for each constructor its
    name and (selector name, selector result sort) pairs;
  funcs -- declared functions: name and (argument sorts, result sort)*)
type sign = {
  logic: string,
  sorts: string list,
  dtyps: (BNF_Util.fp_kind * (string * (string * (string * string) list) list)) list,
  funcs: (string * (string list * string)) list }

(*per-solver-class translation configuration:
  order -- first-order mode additionally introduces explicit applications (fun_app);
  logic -- computes the logic name from the translated terms;
  fp_kinds -- (co)datatype kinds to declare (empty: no (co)datatype declarations);
  serialize -- renders SMT options, comment lines, signature and terms as problem text*)
type config = {
  order: SMT_Util.order,
  logic: term list -> string,
  fp_kinds: BNF_Util.fp_kind list,
  serialize: (string * string) list -> string list -> sign -> sterm list -> string }

(*data returned alongside the problem text (presumably consumed for proof replay):
  context -- proof context after translation;
  typs, terms -- map generated names back to HOL types and terms;
  ll_defs -- definitions introduced by lambda lifting;
  rewrite_rules -- rules returned by folify, plus fun_app_def in first-order mode;
  assms -- the numbered input theorems*)
type replay_data = {
  context: Proof.context,
  typs: typ Symtab.table,
  terms: term Symtab.table,
  ll_defs: term list,
  rewrite_rules: thm list,
  assms: (int * thm) list }


(* translation context *)

(*name components of a type: base names of type constructors (placed after the components
  of their arguments) and names of free type variables without the leading apostrophe;
  schematic type variables contribute nothing*)
fun add_components_of_typ (Type (s, Ts)) =
    cons (Long_Name.base_name s) #> fold_rev add_components_of_typ Ts
  | add_components_of_typ (TFree (s, _)) = cons (perhaps (try (unprefix "'")) s)
  | add_components_of_typ _ = I;

(*suggested name of a type, e.g. nat_list for the constructor list applied to nat*)
fun suggested_name_of_typ T = space_implode "_" (add_components_of_typ T []);

(*suggested name of a term: base name of a constant, name of a free variable,
  Name.uu for anything else*)
fun suggested_name_of_term (Const (s, _)) = Long_Name.base_name s
  | suggested_name_of_term (Free (s, _)) = s
  | suggested_name_of_term _ = Name.uu

(*empty translation context: (used names, table mapping types to (name, proper),
  table mapping terms to (name, optional sort signature))*)
val empty_tr_context = (Name.context, Typtab.empty, Termtab.empty)
(*suffix appended to every generated name; NOTE(review): presumably keeps generated names
  apart from solver keywords and built-in symbols -- confirm*)
val safe_suffix = "$"

(*name of type T in the translation context cx: an existing entry is reused (its proper
  flag is left unchanged); otherwise a fresh name is derived from the suggested name and
  T is recorded with the given proper flag (proper types become sorts in sign_of)*)
fun add_typ T proper (cx as (names, typs, terms)) =
  (case Typtab.lookup typs T of
    SOME (name, _) => (name, cx)
  | NONE =>
      let
        val sugg = Name.desymbolize (SOME true) (suggested_name_of_typ T) ^ safe_suffix
        val (name, names') = Name.variant sugg names
        val typs' = Typtab.update (T, (name, proper)) typs
      in (name, (names', typs', terms)) end)

(*name of term t in cx: an existing entry is reused (its sort is left unchanged);
  otherwise a fresh name is derived from the suggested name and t is recorded with sort,
  which is NONE or SOME (argument sorts, result sort); only SOME entries become functions
  in sign_of*)
fun add_fun t sort (cx as (names, typs, terms)) =
  (case Termtab.lookup terms t of
    SOME (name, _) => (name, cx)
  | NONE =>
      let
        val sugg = Name.desymbolize (SOME false) (suggested_name_of_term t) ^ safe_suffix
        val (name, names') = Name.variant sugg names
        val terms' = Termtab.update (t, (name, sort)) terms
      in (name, (names', typs, terms')) end)

(*signature of a translation context: sorts are the proper types, funcs the terms
  recorded with a sort signature*)
fun sign_of logic dtyps (_, typs, terms) = {
  logic = logic,
  sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
  dtyps = dtyps,
  funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms []}

(*replay data of a translation context: the name-to-type and name-to-term tables invert
  the tables of the context*)
fun replay_data_of ctxt ll_defs rules assms (_, typs, terms) =
  let
    fun add_typ (T, (n, _)) = Symtab.update (n, T)
    val typs' = Typtab.fold add_typ typs Symtab.empty

    fun add_fun (t, (n, _)) = Symtab.update (n, t)
    val terms' = Termtab.fold add_fun terms Symtab.empty
  in
    {context = ctxt, typs = typs', terms = terms', ll_defs = ll_defs, rewrite_rules = rules,
     assms = assms}
  end


(* preprocessing *)

(** (co)datatype declarations **)

(*collect the declarations of (co)datatypes of the kinds fp_kinds that occur in the types
  of ts (via SMT_Datatypes.add_decls, threading the proof context), and register each
  declared type, constructor and selector in the translation context. Returns
  ((funcs, fp_decls', tr_context', ctxt'), ts) with ts unchanged, where funcs maps every
  constructor to its number of selectors and every selector to 1, and fp_decls' are the
  declarations expressed with generated names*)
fun collect_co_datatypes fp_kinds (tr_context, ctxt) ts =
  let
    val (fp_decls, ctxt') =
      ([], ctxt)
      |> fold (Term.fold_types (SMT_Datatypes.add_decls fp_kinds)) ts
      |>> flat

    (*whether T is one of the collected (co)datatypes*)
    fun is_decl_typ T = exists (equal T o fst o snd) fp_decls

    (*built-in types are named directly, all others are registered in the context*)
    fun add_typ' T proper =
      (case SMT_Builtin.dest_builtin_typ ctxt' T of
        SOME (n, Ts) => pair n (* FIXME HO: Consider Ts *)
      | NONE => add_typ T proper)

    (*selectors are registered without sort signature; the selector's result type is a
      proper sort unless it is itself one of the collected (co)datatypes*)
    fun tr_select sel =
      let val T = Term.range_type (Term.fastype_of sel)
      in add_fun sel NONE ##>> add_typ' T (not (is_decl_typ T)) end
    fun tr_constr (constr, selects) =
      add_fun constr NONE ##>> fold_map tr_select selects
    (*the (co)datatype itself is registered as non-proper: it is declared via dtyps,
      not as a plain sort*)
    fun tr_typ (fp, (T, cases)) =
      add_typ' T false ##>> fold_map tr_constr cases #>> pair fp

    val (fp_decls', tr_context') = fold_map tr_typ fp_decls tr_context

    (*arities: a constructor takes one argument per selector, a selector takes one*)
    fun add (constr, selects) =
      Termtab.update (constr, length selects) #>
      fold (Termtab.update o rpair 1) selects

    val funcs = fold (fold add o snd o snd) fp_decls Termtab.empty

  in ((funcs, fp_decls', tr_context', ctxt'), ts) end
    (* FIXME: also return necessary (co)datatype theorems *)


(** eta-expand quantifiers, let expressions and built-ins **)

local
  (*lambda x::T. f (t x), with t lifted over the new bound variable*)
  fun eta f T t = Abs (Name.uu, T, f (Term.incr_boundvars 1 t $ Bound 0))

  (*eta-expand the argument t of a quantifier of type T over its bound type,
    i.e. the domain of the domain of T*)
  fun exp f T = eta f (Term.domain_type (Term.domain_type T))

  (*expand an unapplied quantifier q of type T to lambda P. q (lambda x. P x)*)
  fun exp2 T q =
    let val U = Term.domain_type T
    in Abs (Name.uu, U, q $ eta I (Term.domain_type U) (Bound 0)) end

  (*eta-expand t -- a function of type T already applied to i arguments -- over the
    argument types of T from position i up to k; nothing happens when i >= k*)
  fun expf k i T t =
    let val Ts = drop i (fst (SMT_Util.dest_funT k T))
    in
      Term.incr_boundvars (length Ts) t
      |> fold_rev (fn i => fn u => u $ Bound i) (0 upto length Ts - 1)
      |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts
    end
in

(*eta-expand the given terms so that: quantifiers always take an explicit abstraction;
  Let is beta-reduced when applied to at least two arguments and turned into an
  abstraction otherwise; built-ins (per SMT_Builtin.dest_builtin) are applied to exactly
  their arity k, with surplus arguments applied afterwards and missing ones eta-expanded;
  and terms in funcs applied to fewer arguments than their recorded arity are
  eta-expanded to it*)
fun eta_expand ctxt funcs =
  let
    fun exp_func t T ts =
      (case Termtab.lookup funcs t of
        SOME k => Term.list_comb (t, ts) |> k <> length ts ? expf k (length ts) T
      | NONE => Term.list_comb (t, ts))

    fun expand ((q as Const (\<^const_name>\<open>All\<close>, _)) $ Abs a) = q $ abs_expand a
      | expand ((q as Const (\<^const_name>\<open>All\<close>, T)) $ t) = q $ exp expand T t
      | expand (q as Const (\<^const_name>\<open>All\<close>, T)) = exp2 T q
      | expand ((q as Const (\<^const_name>\<open>Ex\<close>, _)) $ Abs a) = q $ abs_expand a
      | expand ((q as Const (\<^const_name>\<open>Ex\<close>, T)) $ t) = q $ exp expand T t
      | expand (q as Const (\<^const_name>\<open>Ex\<close>, T)) = exp2 T q
      | expand (Const (\<^const_name>\<open>Let\<close>, T) $ t) =
          (*partially applied Let t becomes lambda f. f t; t itself is expanded like
            every other subterm*)
          let val U = Term.domain_type (Term.range_type T)
          in Abs (Name.uu, U, Bound 0 $ Term.incr_boundvars 1 (expand t)) end
      | expand (Const (\<^const_name>\<open>Let\<close>, T)) =
          let val U = Term.domain_type (Term.range_type T)
          in Abs (Name.uu, Term.domain_type T, Abs (Name.uu, U, Bound 0 $ Bound 1)) end
      | expand t =
          (case Term.strip_comb t of
            (Const (\<^const_name>\<open>Let\<close>, _), t1 :: t2 :: ts) =>
            Term.betapplys (Term.betapply (expand t2, expand t1), map expand ts)
          | (u as Const (c as (_, T)), ts) =>
              (case SMT_Builtin.dest_builtin ctxt c ts of
                SOME (_, k, us, mk) =>
                  if k = length us then mk (map expand us)
                  else if k < length us then chop k (map expand us) |>> mk |> Term.list_comb
                  else expf k (length ts) T (mk (map expand us))
              | NONE => exp_func u T (map expand ts))
          | (u as Free (_, T), ts) => exp_func u T (map expand ts)
          | (Abs a, ts) => Term.list_comb (abs_expand a, map expand ts)
          | (u, ts) => Term.list_comb (u, map expand ts))

    and abs_expand (n, T, t) = Abs (n, T, expand t)

  in map expand end

end


(** introduce explicit applications **)

local
  (*
    Make application explicit for functions with varying number of arguments.
  *)

  (*record the minimum number of arguments t is applied to*)
  fun add t i = apfst (Termtab.map_default (t, i) (Integer.min i))
  (*record the function type of an abstraction-bound variable*)
  fun add_type T = apsnd (Typtab.update (T, ()))

  (*collect the minimal arities of constants and free variables, and the function types
    of abstraction-bound variables, occurring in t*)
  fun min_arities t =
    (case Term.strip_comb t of
      (u as Const _, ts) => add u (length ts) #> fold min_arities ts
    | (u as Free _, ts) => add u (length ts) #> fold min_arities ts
    | (Abs (_, T, u), ts) => (can dest_funT T ? add_type T) #> min_arities u #> fold min_arities ts
    | (_, ts) => fold min_arities ts)

  (*cap the arity i of t at the first argument position j whose remaining function type
    (type of t applied to j arguments) is among the collected bound-variable types*)
  fun take_vars_into_account types t i =
    let
      fun find_min j (T as Type (\<^type_name>\<open>fun\<close>, [_, T'])) =
          if j = i orelse Typtab.defined types T then j else find_min (j + 1) T'
        | find_min j _ = j
    in find_min 0 (Term.type_of t) end

  (*explicit application fun_app t u of t (of function type T) to u, with result type*)
  fun app u (t, T) = (Const (\<^const_name>\<open>fun_app\<close>, T --> T) $ t $ u, Term.range_type T)

  (*apply t (of type T) directly to the first i arguments of ts and to the remaining
    ones via explicit application*)
  fun apply i t T ts =
    let
      val (ts1, ts2) = chop i ts
      val (_, U) = SMT_Util.dest_funT i T
    in fst (fold app ts2 (Term.list_comb (t, ts1), U)) end
in

(*make applications beyond a function's arity explicit (via fun_app). The arity depends on
  the explicit_application option: 0 -- minimal arity over all occurrences in ts;
  1 -- as 0, capped by take_vars_into_account; 2 -- arity 0 for every constant and free
  variable; other values raise an error. Built-ins are applied directly to their own
  arity k, terms in funcs are never made explicit, and bound variables are always
  applied explicitly. Quantifier bodies (including triggers) and Let with an abstraction
  argument are traversed structurally*)
fun intro_explicit_application ctxt funcs ts =
  let
    val explicit_application = Config.get ctxt SMT_Config.explicit_application
    val get_arities =
      (case explicit_application of
        0 => min_arities
      | 1 => min_arities
      | 2 => K I
      | n => error ("Illegal value for " ^ quote (Config.name_of SMT_Config.explicit_application) ^
          ": " ^ string_of_int n))

    val (arities, types) = fold get_arities ts (Termtab.empty, Typtab.empty)
    val arities' = arities |> explicit_application = 1 ? Termtab.map (take_vars_into_account types)

    fun app_func t T ts =
      if is_some (Termtab.lookup funcs t) then Term.list_comb (t, ts)
      else apply (the_default 0 (Termtab.lookup arities' t)) t T ts

    fun in_list T f t = SMT_Util.mk_symb_list T (map f (SMT_Util.dest_symb_list t))

    fun traverse Ts t =
      (case Term.strip_comb t of
        (q as Const (\<^const_name>\<open>All\<close>, _), [Abs (x, T, u)]) =>
          q $ Abs (x, T, in_trigger (T :: Ts) u)
      | (q as Const (\<^const_name>\<open>Ex\<close>, _), [Abs (x, T, u)]) =>
          q $ Abs (x, T, in_trigger (T :: Ts) u)
      | (q as Const (\<^const_name>\<open>Let\<close>, _), [u1, u2 as Abs _]) =>
          q $ traverse Ts u1 $ traverse Ts u2
      | (u as Const (c as (_, T)), ts) =>
          (case SMT_Builtin.dest_builtin ctxt c ts of
            SOME (_, k, us, mk) =>
              let
                val (ts1, ts2) = chop k (map (traverse Ts) us)
                val U = Term.strip_type T |>> snd o chop k |> (op --->)
              in apply 0 (mk ts1) U ts2 end
          | NONE => app_func u T (map (traverse Ts) ts))
      | (u as Free (_, T), ts) => app_func u T (map (traverse Ts) ts)
      | (u as Bound i, ts) => apply 0 u (nth Ts i) (map (traverse Ts) ts)
      | (Abs (n, T, u), ts) => traverses Ts (Abs (n, T, traverse (T::Ts) u)) ts
      | (u, ts) => traverses Ts u ts)
    and in_trigger Ts ((c as \<^const>\<open>trigger\<close>) $ p $ t) = c $ in_pats Ts p $ traverse Ts t
      | in_trigger Ts t = traverse Ts t
    and in_pats Ts ps =
      in_list \<^typ>\<open>pattern symb_list\<close> (in_list \<^typ>\<open>pattern\<close> (in_pat Ts)) ps
    and in_pat Ts ((p as Const (\<^const_name>\<open>pat\<close>, _)) $ t) = p $ traverse Ts t
      | in_pat Ts ((p as Const (\<^const_name>\<open>nopat\<close>, _)) $ t) = p $ traverse Ts t
      | in_pat _ t = raise TERM ("bad pattern", [t])
    and traverses Ts t ts = Term.list_comb (t, map (traverse Ts) ts)
  in map (traverse []) ts end

(*fun_app_def as a meta-equation; translate adds it to the rewrite rules in first-order
  mode*)
val fun_app_eq = mk_meta_eq @{thm fun_app_def}

end


(** map HOL formulas to FOL formulas (i.e., separate formulas from terms) **)

local
  val is_quant = member (op =) [\<^const_name>\<open>All\<close>, \<^const_name>\<open>Ex\<close>]

  (*rewrite rules returned by folify; NOTE(review): presumably needed to justify this
    step when replaying proofs -- confirm*)
  val fol_rules = [
    Let_def,
    @{lemma "P = True == P" by (rule eq_reflection) simp}]

  (*raised for If-expressions and formulas inside trigger patterns*)
  exception BAD_PATTERN of unit

  (*whether c applied to ts is a built-in connective or a built-in predicate*)
  fun is_builtin_conn_or_pred ctxt c ts =
    is_some (SMT_Builtin.dest_builtin_conn ctxt c ts) orelse
    is_some (SMT_Builtin.dest_builtin_pred ctxt c ts)
in

(*separate formulas from terms: built-in connectives, built-in predicates and quantifiers
  are translated as formulas (also where they occur in term positions), everything else
  is traversed as a term. Inside trigger patterns, If-expressions and formulas raise
  BAD_PATTERN; any pattern whose translation raises an exception is dropped (via try).
  Returns ((fol_rules, I), ts'), where I is the identity transformer that translate
  applies to SMT_Builtin.dest_builtin*)
fun folify ctxt =
  let
    fun in_list T f t = SMT_Util.mk_symb_list T (map_filter f (SMT_Util.dest_symb_list t))

    fun in_term pat t =
      (case Term.strip_comb t of
        (\<^const>\<open>True\<close>, []) => t
      | (\<^const>\<open>False\<close>, []) => t
      | (u as Const (\<^const_name>\<open>If\<close>, _), [t1, t2, t3]) =>
          if pat then raise BAD_PATTERN () else u $ in_form t1 $ in_term pat t2 $ in_term pat t3
      | (Const (c as (n, _)), ts) =>
          if is_builtin_conn_or_pred ctxt c ts orelse is_quant n then
            if pat then raise BAD_PATTERN () else in_form t
          else
            Term.list_comb (Const c, map (in_term pat) ts)
      | (Free c, ts) => Term.list_comb (Free c, map (in_term pat) ts)
      | _ => t)

    and in_pat ((p as Const (\<^const_name>\<open>pat\<close>, _)) $ t) =
          p $ in_term true t
      | in_pat ((p as Const (\<^const_name>\<open>nopat\<close>, _)) $ t) =
          p $ in_term true t
      | in_pat t = raise TERM ("bad pattern", [t])

    and in_pats ps =
      in_list \<^typ>\<open>pattern symb_list\<close> (SOME o in_list \<^typ>\<open>pattern\<close> (try in_pat)) ps

    and in_trigger ((c as \<^const>\<open>trigger\<close>) $ p $ t) = c $ in_pats p $ in_form t
      | in_trigger t = in_form t

    and in_form t =
      (case Term.strip_comb t of
        (q as Const (qn, _), [Abs (n, T, u)]) =>
          if is_quant qn then q $ Abs (n, T, in_trigger u)
          else in_term false t
      | (Const c, ts) =>
          (case SMT_Builtin.dest_builtin_conn ctxt c ts of
            SOME (_, _, us, mk) => mk (map in_form us)
          | NONE =>
              (case SMT_Builtin.dest_builtin_pred ctxt c ts of
                SOME (_, _, us, mk) => mk (map (in_term false) us)
              | NONE => in_term false t))
      | _ => in_term false t)
  in
    map in_form #>
    pair (fol_rules, I)
  end

end


(* translation into intermediate format *)

(** utility functions **)

(*quantifier kind of a constant name: All and Ex only*)
val quantifier = (fn
    \<^const_name>\<open>All\<close> => SOME SForall
  | \<^const_name>\<open>Ex\<close> => SOME SExists
  | _ => NONE)

(*collect the binder types of directly nested quantifiers named qname onto Ts (innermost
  first), returning them with the remaining body*)
fun group_quant qname Ts (t as Const (q, _) $ Abs (_, T, u)) =
      if q = qname then group_quant qname (T :: Ts) u else (Ts, t)
  | group_quant _ Ts t = (Ts, t)

(*pattern marker to (term, true) for pat and (term, false) for nopat*)
fun dest_pat (Const (\<^const_name>\<open>pat\<close>, _) $ t) = (t, true)
  | dest_pat (Const (\<^const_name>\<open>nopat\<close>, _) $ t) = (t, false)
  | dest_pat t = raise TERM ("bad pattern", [t])

(*prepend a multi-pattern: empty ones are skipped, and all markers must agree*)
fun dest_pats [] = I
  | dest_pats ts =
      (case map dest_pat ts |> split_list ||> distinct (op =) of
        (ps, [true]) => cons (SPat ps)
      | (ps, [false]) => cons (SNoPat ps)
      | _ => raise TERM ("bad multi-pattern", ts))

(*split a trigger into its multi-patterns (in original order) and the body; terms
  without trigger have no patterns*)
fun dest_trigger (\<^const>\<open>trigger\<close> $ tl $ t) =
      (rev (fold (dest_pats o SMT_Util.dest_symb_list) (SMT_Util.dest_symb_list tl) []), t)
  | dest_trigger t = ([], t)

(*for All and Ex: SOME (kind, binder types outermost first, patterns, body), grouping
  directly nested quantifiers of the same kind; NONE for other constant names*)
fun dest_quant qn T t = quantifier qn |> Option.map (fn q =>
  let
    val (Ts, u) = group_quant qn [T] t
    val (ps, p) = dest_trigger u
  in (q, rev Ts, ps, p) end)

(*fold_map over the terms of a pattern, keeping its constructor*)
fun fold_map_pat f (SPat ts) = fold_map f ts #>> SPat
  | fold_map_pat f (SNoPat ts) = fold_map f ts #>> SNoPat


(** translation from Isabelle terms into SMT intermediate terms **)

(*translate ts into intermediate terms, threading the translation context trx. Types
  become built-in type names or registered sorts (schematic type variables raise TYPE);
  All and Ex become SQua with their triggers (other constants applied to a single
  abstraction raise TERM); built-ins according to builtin become SConst with the built-in
  name; other constants and free variables are registered as functions with their
  argument and result sorts; bound variables become SVar. Returns
  ((signature, intermediate terms), trx')*)
fun intermediate logic dtyps builtin ctxt ts trx =
  let
    fun transT (T as TFree _) = add_typ T true
      | transT (T as TVar _) = (fn _ => raise TYPE ("bad SMT type", [T], []))
      | transT (T as Type _) =
          (case SMT_Builtin.dest_builtin_typ ctxt T of
            SOME (n, []) => pair n
          | SOME (n, Ts) =>
            fold_map transT Ts
            #>> (fn ns => enclose "(" ")" (space_implode " " (n :: ns)))
          | NONE => add_typ T true)

    fun trans t =
      (case Term.strip_comb t of
        (Const (qn, _), [Abs (_, T, t1)]) =>
          (case dest_quant qn T t1 of
            SOME (q, Ts, ps, b) =>
              fold_map transT Ts ##>> fold_map (fold_map_pat trans) ps ##>>
              trans b #>> (fn ((Ts', ps'), b') => SQua (q, Ts', ps', b'))
          | NONE => raise TERM ("unsupported quantifier", [t]))
      | (u as Const (c as (_, T)), ts) =>
          (case builtin ctxt c ts of
            SOME (n, _, us, _) => fold_map trans us #>> curry SConst n
          | NONE => trans_applied_fun u T ts)
      | (u as Free (_, T), ts) => trans_applied_fun u T ts
      | (Bound i, ts) => pair i ##>> fold_map trans ts #>> SVar
      | _ => raise TERM ("bad SMT term", [t]))

    (*register t with the sorts of its arguments (as applied here) and of its result*)
    and trans_applied_fun t T ts =
      let val (Us, U) = SMT_Util.dest_funT (length ts) T
      in
        fold_map transT Us ##>> transT U #-> (fn Up =>
          add_fun t (SOME Up) ##>> fold_map trans ts #>> SConst)
      end

    val (us, trx') = fold_map trans ts trx
  in ((sign_of (logic ts) dtyps trx', us), trx') end


(* translation *)

(*configuration functions indexed by solver class; NOTE(review): conflicts on merge are
  resolved by SMT_Util.dict_merge fst -- check which side wins*)
structure Configs = Generic_Data
(
  type T = (Proof.context -> config) SMT_Util.dict
  val empty = []
  val extend = I
  fun merge data = SMT_Util.dict_merge fst data
)

(*register (or replace) the configuration function for solver class cs*)
fun add_config (cs, cfg) = Configs.map (SMT_Util.dict_update (cs, cfg))

(*configuration for the solver class of ctxt; raises an error if none is registered*)
fun get_config ctxt =
  let val cs = SMT_Config.solver_class_of ctxt
  in
    (case SMT_Util.dict_get (Configs.get (Context.Proof ctxt)) cs of
      SOME cfg => cfg ctxt
    | NONE => error ("SMT: no translation configuration found " ^
        "for solver class " ^ quote (SMT_Util.string_of_class cs)))
  end

(*translate the numbered theorems ithms into the problem text of the configured solver
  class, passing smt_options and comments to the serializer. Pipeline: beta-eta-contract
  the propositions; collect (co)datatypes (only if the configuration has fp_kinds);
  eta-expand; lift lambdas (Let and quantifiers count as binders); prepend the
  lambda-lifting definitions, each annotated with a trigger; in first-order mode make
  applications explicit; separate formulas from terms; translate into intermediate terms
  and serialize. Returns the problem text and the replay data*)
fun translate ctxt smt_options comments ithms =
  let
    val {order, logic, fp_kinds, serialize} = get_config ctxt

    fun no_dtyps (tr_context, ctxt) ts =
      ((Termtab.empty, [], tr_context, ctxt), ts)

    val ts1 = map (Envir.beta_eta_contract o SMT_Util.prop_of o snd) ithms

    val ((funcs, dtyps, tr_context, ctxt1), ts2) =
      ((empty_tr_context, ctxt), ts1)
      |-> (if null fp_kinds then no_dtyps else collect_co_datatypes fp_kinds)

    (*binder predicate for lambda lifting: Let applied to an argument, or a quantifier
      according to Lambda_Lifting.is_quantifier*)
    fun is_binder (Const (\<^const_name>\<open>Let\<close>, _) $ _) = true
      | is_binder t = Lambda_Lifting.is_quantifier t

    (*annotate a universally quantified equation lhs = rhs with a trigger whose single
      pattern is lhs; other bodies are left unchanged*)
    fun mk_trigger ((q as Const (\<^const_name>\<open>All\<close>, _)) $ Abs (n, T, t)) =
          q $ Abs (n, T, mk_trigger t)
      | mk_trigger (eq as (Const (\<^const_name>\<open>HOL.eq\<close>, T) $ lhs $ _)) =
          Term.domain_type T --> \<^typ>\<open>pattern\<close>
          |> (fn T => Const (\<^const_name>\<open>pat\<close>, T) $ lhs)
          |> SMT_Util.mk_symb_list \<^typ>\<open>pattern\<close> o single
          |> SMT_Util.mk_symb_list \<^typ>\<open>pattern symb_list\<close> o single
          |> (fn t => \<^const>\<open>trigger\<close> $ t $ eq)
      | mk_trigger t = t

    val (ctxt2, (ts3, ll_defs)) =
      ts2
      |> eta_expand ctxt1 funcs
      |> rpair ctxt1
      |-> Lambda_Lifting.lift_lambdas NONE is_binder
      |-> (fn (ts', ll_defs) => fn ctxt' =>
        let
          val ts'' = map mk_trigger ll_defs @ ts'
            |> order = SMT_Util.First_Order ? intro_explicit_application ctxt' funcs
        in
          (ctxt', (ts'', ll_defs))
        end)
    val ((rewrite_rules, builtin), ts4) = folify ctxt2 ts3
      |>> order = SMT_Util.First_Order ? apfst (cons fun_app_eq)
  in
    (ts4, tr_context)
    |-> intermediate logic dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
    |>> uncurry (serialize smt_options comments)
    ||> replay_data_of ctxt2 ll_defs rewrite_rules ithms
  end

end;
539