1(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
2    Author:     Steffen Juilf Smolka, TU Muenchen
3    Author:     Jasmin Blanchette, TU Muenchen
4
Supplements a term with a locally minimal, complete set of type constraints.
Complete: The constraints suffice to infer the term's types. Minimal: Removing
any constraint from the set would make it incomplete.
8
9When configuring the pretty printer appropriately, the constraints will show up
10as type annotations when printing the term. This allows the term to be printed
11and reparsed without a change of types.
12
13Terms should be unchecked before calling "annotate_types_in_term" to avoid
14awkward syntax.
15*)
16
signature SLEDGEHAMMER_ISAR_ANNOTATE =
sig
  (* annotate_types_in_term ctxt t returns t with type constraints
     (Type.constraint) inserted at a selected set of its subterms, as described
     in the header comment. *)
  val annotate_types_in_term : Proof.context -> term -> term
end;
21
22structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
23struct
24
(* post_traverse_term_type' f env t s visits the subterms of t bottom-up
   (post-order), threading the state s from left to right.

   env holds the binder types of the enclosing abstractions, innermost first;
   it is used to type loose Bound variables.

   For every subterm, f is called with the rebuilt subterm, its type and the
   current state. It must return ((new_subterm, new_state), typ). The returned
   typ is the one that enclosing abstractions and applications use.

   An application whose function part does not come back with a type of the
   shape Type (_, [_, T]) raises Fail. So does any Bind raised while processing
   that application. *)
fun post_traverse_term_type' f =
  let
    fun visit env t s =
      (case t of
        Const (_, T) => f t T s
      | Free (_, T) => f t T s
      | Var (_, T) => f t T s
      | Bound i => f t (nth env i) s
      | Abs (x, T1, body) =>
        let val ((body', s1), T2) = visit (T1 :: env) body s in
          f (Abs (x, T1, body')) (T1 --> T2) s1
        end
      | u $ v =>
        (case visit env u s of
          ((u', s1), Type (_, [_, T])) =>
          let val ((v', s2), _) = visit env v s1 in
            f (u' $ v') T s2
          end
        | _ => raise Bind)
        handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'")
  in
    visit
  end
39
(* post_traverse_term_type f s t rebuilds t bottom-up from the empty binder
   environment.

   f t T s is called on every subterm t of type T and must return
   (new_subterm, new_state). The type seen by enclosing terms is always the
   original T.

   The result is (rebuilt term, final state). *)
fun post_traverse_term_type f s t =
  let
    fun keep_type u T st = (f u T st, T)
    val (result, _) = post_traverse_term_type' keep_type [] t s
  in
    result
  end

(* post_fold_term_type f s t folds f over all subterms of t in post-order,
   leaving the term unchanged.

   f t T s receives each subterm t, its type T and the current state s, and
   returns the new state. The result is the final state. *)
fun post_fold_term_type f s t =
  let
    fun keep_term u T st = (u, f u T st)
    val (_, final_state) = post_traverse_term_type keep_term s t
  in
    final_state
  end
44
(* fold_map_atypes f T s applies f to every atomic (non-Type) component of T,
   left to right, threading the state s. It rebuilds the Type constructors
   around the results and returns (new type, final state). *)
fun fold_map_atypes f (Type (name, Ts)) s =
    let val (Us, s') = fold_map (fold_map_atypes f) Ts s in
      (Type (name, Us), s')
    end
  | fold_map_atypes f T s = f T s
52
(* Order on indexnames. It is used for sorted lists of type variables and as
   the element order of Var_Set_Tab keys. *)
val indexname_ord = Term_Ord.fast_indexname_ord

(* Lexicographic order on costs of the shape (int, (int, int)). In
   update_tab these are (type size, (term size, position)). *)
val cost_ord =
  let val term_size_pos_ord = prod_ord int_ord int_ord in
    prod_ord int_ord term_size_pos_ord
  end

(* Tables keyed by lists of type variable indexnames, compared elementwise
   with indexname_ord. *)
structure Var_Set_Tab = Table
(
  type key = indexname list
  val ord = list_ord indexname_ord
)
59
(* generalize_types ctxt t erases every type in t (replacing it by dummyT) and
   re-runs type inference. Inference runs in a context switched to pattern
   mode, so that schematic type variables are used for the inferred types. *)
fun generalize_types ctxt t =
  let
    val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
  in
    t
    |> map_types (K dummyT)
    |> singleton (Type_Infer_Context.infer_types pattern_ctxt)
  end
69
(* match_types ctxt t1 t2 collects the types of all subterms of t1 and t2 in
   the same post-order and pairs them up.

   It then matches each type of t1 (the pattern) against the corresponding
   type of t2, accumulating a type environment that starts empty. Pairs that
   fail to match are skipped.

   Both terms are expected to have the same shape; otherwise the pairing with
   ~~ fails. *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    fun types_of t = post_fold_term_type (fn _ => fn T => fn Ts => T :: Ts) [] t
    fun match_pair TU tyenv = the_default tyenv (try (Sign.typ_match thy TU) tyenv)
  in
    fold match_pair (types_of t1 ~~ types_of t2) Vartab.empty
  end
77
(* handle_trivial_tfrees ctxt t' subst deals with "trivial" free type
   variables: TFrees that occur in the range of the type substitution subst
   but are not declared in ctxt.

   - Schematic type variables that subst maps directly to a trivial TFree are
     replaced by dummyT in t' and deleted from subst.
   - Any remaining occurrence of a trivial TFree inside a type in subst's range
     is replaced by dummyT.

   NOTE(review): subst is presumably the result of match_types on t' and the
   original term (as in annotate_types_in_term); confirm for other callers.

   Returns the updated pair (t', subst). *)
fun handle_trivial_tfrees ctxt t' subst =
  let
    val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

    (* Ord_List.member requires a list sorted by the same order and stops at
       the first larger element. A merely deduplicated list would make it miss
       names and wrongly treat them as non-trivial, so sort as well. *)
    val trivial_tfree_names =
      Vartab.fold add_tfree_names subst []
      |> filter_out (Variable.is_declared ctxt)
      |> sort_distinct fast_string_ord
    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

    (* Type variables that subst instantiates with exactly a trivial TFree. *)
    val trivial_tvar_names =
      Vartab.fold
        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
               tfree_name_trivial tfree_name ? cons tvar_name
          | _ => I)
        subst
        []
      |> sort indexname_ord
    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

    val t' =
      t' |> map_types
              (map_type_tvar
                (fn (idxn, sort) =>
                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

    val subst =
      subst |> fold Vartab.delete trivial_tvar_names
            |> Vartab.map
               (K (apsnd (map_type_tfree
                           (fn (name, sort) =>
                              if tfree_name_trivial name then dummyT
                              else TFree (name, sort)))))
  in
    (t', subst)
  end
114
(* key_of_atype T inserts the indexname of a schematic type variable T into
   an indexname_ord-sorted list. Any other atomic type leaves the list
   unchanged. *)
fun key_of_atype T =
  (case T of
    TVar (idxn, _) => Ord_List.insert indexname_ord idxn
  | _ => I)

(* key_of_type T is the sorted, duplicate-free list of the schematic type
   variables occurring in T. *)
fun key_of_type T = [] |> fold_atyps key_of_atype T
118
(* update_tab t T (tab, pos) records the subterm t of type T, at position pos,
   as a candidate typing spot in tab.

   The key is key_of_type T. Types without schematic type variables are
   skipped. For each key, only the spot with the smallest cost
   (size of T, (size of t, pos)) under cost_ord is kept.

   pos is always advanced by one. *)
fun update_tab t T (tab, pos) =
  let
    fun record key =
      let
        val cost = (size_of_typ T, (size_of_term t, pos))
        fun improves old = cost_ord (cost, old) = LESS
      in
        (case Var_Set_Tab.lookup tab key of
          SOME old => if improves old then Var_Set_Tab.update (key, cost) tab else tab
        | NONE => Var_Set_Tab.update_new (key, cost) tab)
      end
    val tab' =
      (case key_of_type T of
        [] => tab
      | key => record key)
  in
    (tab', pos + 1)
  end

(* typing_spot_table t builds the table of cheapest typing spots of t. It
   numbers the subterms of t in post-order, starting at 0. *)
fun typing_spot_table t =
  let val (tab, _) = post_fold_term_type update_tab (Var_Set_Tab.empty, 0) t in
    tab
  end
134
(* reverse_greedy typing_spot_tab selects the positions of the typing spots to
   keep.

   Spots are considered from the most to the least expensive. A spot is
   dropped when each of its type variables is still covered by more than one
   remaining spot; otherwise its position is kept.

   The result is a list of positions in no particular order. *)
fun reverse_greedy typing_spot_tab =
  let
    (* Add delta to the number of spots covering each variable in tvars. *)
    fun bump delta tvars counts =
      fold (fn tvar => fn tab =>
          Vartab.update (tvar, the_default 0 (Vartab.lookup tab tvar) + delta) tab)
        tvars counts

    fun redundant counts = forall (fn tvar => the (Vartab.lookup counts tvar) > 1)

    fun consider (tvars, (_, (_, pos))) (kept, counts) =
      if redundant counts tvars then (kept, bump ~1 tvars counts)
      else (pos :: kept, counts)

    fun collect (entry as (tvars, _)) (entries, counts) =
      (entry :: entries, bump 1 tvars counts)

    val (entries, counts) =
      Var_Set_Tab.fold collect typing_spot_tab ([], Vartab.empty)
    val most_expensive_first = sort_distinct (rev_order o cost_ord o apply2 snd) entries
    val (kept, _) = fold consider most_expensive_first ([], counts)
  in
    kept
  end
154
(* introduce_annotations subst spots t t' inserts type constraints into t.

   First pass (over t'): walk t' in post-order, numbering subterms from 0. At
   every position listed in spots (an ascending list), take the subterm's type
   and instantiate its schematic type variables from subst. Each variable is
   replaced by dummyT in subst once it has been instantiated, so later
   annotations show dummyT for it.

   Second pass (over t): walk t in the same post-order and wrap the subterm at
   each collected position with Type.constraint. This relies on t and t'
   having the same term structure. *)
fun introduce_annotations subst spots t t' =
  let
    fun instantiate_atype T env =
      (case T of
        TVar (idxn, S) => (Envir.subst_type env T, Vartab.update (idxn, (S, dummyT)) env)
      | _ => (T, env))

    fun collect _ T (env, pos, next :: rest, acc) =
        if pos = next then
          let val (T', env') = fold_map_atypes instantiate_atype T env in
            (env', pos + 1, rest, (next, T') :: acc)
          end
        else
          (env, pos + 1, next :: rest, acc)
      | collect _ _ state = state

    val (_, _, _, collected) = post_fold_term_type collect (subst, 0, spots, []) t'

    fun insert u _ (pos, (next, T) :: rest) =
        if pos = next then (Type.constraint T u, (pos + 1, rest))
        else (u, (pos + 1, (next, T) :: rest))
      | insert u _ state = (u, state)

    val (annotated, _) = post_traverse_term_type insert (0, rev collected) t
  in
    annotated
  end
180
(* annotate_types_in_term ctxt t works in five steps:

   1. Generalize the types of t.
   2. Match the generalized term against t to obtain a type substitution.
   3. Remove trivial TFrees.
   4. Choose a reduced set of typing spots.
   5. Insert type constraints at those spots in t. *)
fun annotate_types_in_term ctxt t =
  let
    val generalized = generalize_types ctxt t
    val (pruned, subst) =
      handle_trivial_tfrees ctxt generalized (match_types ctxt generalized t)
    val spots = sort int_ord (reverse_greedy (typing_spot_table pruned))
  in
    introduce_annotations subst spots t pruned
  end
190
191end;
192