(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
    Author:     Steffen Juilf Smolka, TU Muenchen
    Author:     Jasmin Blanchette, TU Muenchen
Supplements a term with a locally minimal, complete set of type constraints.
Complete: The constraints suffice to infer the term's types. Minimal: Reducing
the set of constraints further will make it incomplete.
When configuring the pretty printer appropriately, the constraints will show up
as type annotations when printing the term. This allows the term to be printed
and reparsed without a change of types.
Terms should be unchecked before calling "annotate_types_in_term" to avoid
awkward syntax.
19  val annotate_types_in_term : Proof.context -> term -> term
22structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
(* Bottom-up traversal of a term that threads a state through all subterms.

   f is applied to every (already rebuilt) subterm together with the type computed for it
   and the current state; whatever f returns is the result for that subterm. The recursive
   calls expect that result to have the shape ((term, state), type).
   env holds the types of the enclosing abstractions, innermost first, and is used to look
   up the type of a Bound index. *)
fun post_traverse_term_type' f env tm s =
  (case tm of
    Const (_, T) => f tm T s
  | Free (_, T) => f tm T s
  | Var (_, T) => f tm T s
  | Bound i => f tm (nth env i) s
  | Abs (x, T1, body) =>
      let
        (* the body is traversed with the bound variable's type pushed onto env *)
        val ((body', s1), T2) = post_traverse_term_type' f (T1 :: env) body s
      in
        f (Abs (x, T1, body')) (T1 --> T2) s1
      end
  | u $ v =>
      (* the function part is visited first, then the argument; the type of the application
         is the second argument of the binary type constructor returned for u; a failing
         pattern match is reported as Fail *)
      (let
        val ((u', s1), Type (_, [_, T])) = post_traverse_term_type' f env u s
        val ((v', s2), _) = post_traverse_term_type' f env v s1
      in
        f (u' $ v') T s2
      end
      handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"))
(* Traverses t bottom-up starting from state s. f receives each rebuilt subterm, its type
   and the current state, and returns the replacement subterm paired with the new state.
   The result is the pair of the final term and the final state. *)
fun post_traverse_term_type f s t =
  let
    (* adapt f to the shape expected by the worker, which also wants the type back *)
    fun visit u T acc = (f u T acc, T)
    val (result, _) = post_traverse_term_type' visit [] t s
  in
    result
  end
(* Folds f over all subterms of t (bottom-up) together with their types, starting from
   state s. The term itself is left unchanged; only the final state is returned. *)
fun post_fold_term_type f s t =
  let
    fun step u T acc = (u, f u T acc)
    val (_, final) = post_traverse_term_type step s t
  in
    final
  end
(* Applies the state-passing function f to every atomic type (anything that is not a Type
   constructor application) inside a type, left to right, and rebuilds the type from the
   results. Returns the new type together with the final state. *)
fun fold_map_atypes f (Type (name, args)) s =
      fold_map (fold_map_atypes f) args s
      |>> (fn args' => Type (name, args'))
  | fold_map_atypes f atom s = f atom s
(* ordering on index names used throughout this file *)
val indexname_ord = Term_Ord.fast_indexname_ord

(* lexicographic ordering on costs of the shape (int, (int, int)) *)
val cost_ord =
  let val inner_ord = prod_ord int_ord int_ord
  in prod_ord int_ord inner_ord end

(* tables keyed by lists of index names, compared with list_ord over indexname_ord *)
structure Var_Set_Tab =
  Table(
    type key = indexname list
    val ord = list_ord indexname_ord)
(* Erases all types in t (replacing them by dummyT) and runs type inference on the result
   in a context switched to pattern mode, so that schematic type variables are used. *)
fun generalize_types ctxt t =
  let
    val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
    val untyped = map_types (K dummyT) t
  in
    singleton (Type_Infer_Context.infer_types pattern_ctxt) untyped
  end
(* Collects the types of all subterms of t1 and of t2, pairs them up position by position
   and accumulates a type matching over the pairs, starting from the empty environment.
   Pairs for which Sign.typ_match raises an exception are skipped. *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    fun types_of tm = post_fold_term_type (fn _ => fn T => fn Ts => T :: Ts) [] tm
    (* keep the old environment whenever matching this pair fails *)
    fun extend_match TU tyenv = perhaps (try (Sign.typ_match thy TU)) tyenv
    val type_pairs = types_of t1 ~~ types_of t2
  in
    fold extend_match type_pairs Vartab.empty
  end
(* Removes type information that refers to "trivial" TFrees, i.e. TFrees occurring in the
   range of subst whose names are not declared in ctxt.

   - In t', every TVar that subst maps directly to a trivial TFree is replaced by dummyT.
   - In subst, the entries of those TVars are deleted, and in the remaining entries every
     trivial TFree is replaced by dummyT.

   Returns the updated pair (t', subst). *)
fun handle_trivial_tfrees ctxt t' subst =
  let
    (* adds the names of all TFrees in the type of one substitution entry *)
    val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

    (* The list is queried with Ord_List.member below, which requires an ordered list;
       hence it is sorted (and deduplicated) with the same ordering, exactly like
       trivial_tvar_names further down. An unsorted list would make the membership test
       miss elements. *)
    val trivial_tfree_names =
      Vartab.fold add_tfree_names subst []
      |> filter_out (Variable.is_declared ctxt)
      |> sort_distinct fast_string_ord
    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

    (* TVars whose image under subst is exactly a trivial TFree *)
    val trivial_tvar_names =
      Vartab.fold
        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
               tfree_name_trivial tfree_name ? cons tvar_name
          | _ => I)
        subst
        []
      |> sort indexname_ord
    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

    (* blank out the trivial TVars in the term *)
    val t' =
      t' |> map_types
              (map_type_tvar
                (fn (idxn, sort) =>
                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

    (* drop the trivial TVars from the substitution and blank out trivial TFrees that
       still occur inside the remaining entries *)
    val subst =
      subst |> fold Vartab.delete trivial_tvar_names
            |> Vartab.map
               (K (apsnd (map_type_tfree
                           (fn (name, sort) =>
                              if tfree_name_trivial name then dummyT
                              else TFree (name, sort)))))
  in
    (t', subst)
  end
(* For a TVar, yields the function that inserts its index name into an ordered list;
   for any other atomic type, yields the identity. *)
fun key_of_atype T =
  (case T of
    TVar (ixn, _) => Ord_List.insert indexname_ord ixn
  | _ => I)

(* The ordered list of index names of all TVars occurring in T. *)
fun key_of_type T = [] |> fold_atyps key_of_atype T
(* One step of the fold that builds the typing spot table.

   t is the current subterm, T its type and pos the running position counter. If T contains
   TVars, the set of their index names is used as key; the entry for that key is created,
   or replaced when the cost of this spot, (size of T, (size of t, pos)), is strictly
   smaller than the stored one. The counter is always advanced by one. *)
fun update_tab t T (tab, pos) =
  let
    val key = key_of_type T

    fun record () =
      let
        val cost = (size_of_typ T, (size_of_term t, pos))
        fun cheaper_than old_cost = (cost_ord (cost, old_cost) = LESS)
      in
        (case Var_Set_Tab.lookup tab key of
          SOME old_cost =>
            if cheaper_than old_cost then Var_Set_Tab.update (key, cost) tab else tab
        | NONE => Var_Set_Tab.update_new (key, cost) tab)
      end

    (* types without TVars never become typing spots *)
    val tab' = if null key then tab else record ()
  in
    (tab', pos + 1)
  end
(* Builds the table of candidate typing spots of a term by folding update_tab over all of
   its subterms, starting with the empty table and position 0. *)
fun typing_spot_table t =
  fst (post_fold_term_type update_tab (Var_Set_Tab.empty, 0) t)
(* Selects the typing spots to keep.

   All table entries are visited in order of decreasing cost. An entry is dropped when each
   of its type variables is currently counted more than once (the counts of its variables
   are then decremented); otherwise its spot is kept. Returns the list of kept spots. *)
fun reverse_greedy typing_spot_tab =
  let
    (* add delta to the counter of one type variable; missing counters start at 0 *)
    fun bump delta tvar counts =
      Vartab.update (tvar, the_default 0 (Vartab.lookup counts tvar) + delta) counts
    fun update_count delta tvars counts = fold (bump delta) tvars counts

    fun counted_more_than_once counts tvar = the (Vartab.lookup counts tvar) > 1

    (* collect every entry and count how many entries mention each type variable *)
    fun register (entry as (tvars, _)) (entries, counts) =
      (entry :: entries, update_count 1 tvars counts)
    val (all_entries, initial_counts) =
      Var_Set_Tab.fold register typing_spot_tab ([], Vartab.empty)

    val by_decreasing_cost =
      sort_distinct (fn (e1, e2) => rev_order (cost_ord (snd e1, snd e2))) all_entries

    fun drop_superfluous (tvars, (_, (_, spot))) (spots, counts) =
      if forall (counted_more_than_once counts) tvars
      then (spots, update_count ~1 tvars counts)
      else (spot :: spots, counts)
  in
    fst (fold drop_superfluous by_decreasing_cost ([], initial_counts))
  end
(* Inserts type constraints into t at the positions listed in spots.

   First pass (over t'): subterms are numbered in traversal order; for every position that
   equals the next pending spot, the type of the subterm is instantiated via subst and
   recorded. Each TVar that has been instantiated once is afterwards mapped to dummyT in
   the substitution.
   Second pass (over t): the subterms at the recorded positions are wrapped with
   Type.constraint using the recorded types. *)
fun introduce_annotations subst spots t t' =
  let
    fun subst_atype T tyenv =
      (case T of
        TVar (ixn, S) =>
          (Envir.subst_type tyenv T, Vartab.update (ixn, (S, dummyT)) tyenv)
      | _ => (T, tyenv))

    (* state: (substitution, current position, pending spots, collected annotations) *)
    fun collect_annot _ _ (state as (_, _, [], _)) = state
      | collect_annot _ T (tyenv, cp, pending as p :: rest, acc) =
          if p = cp then
            let val (T', tyenv') = fold_map_atypes subst_atype T tyenv
            in (tyenv', cp + 1, rest, (p, T') :: acc) end
          else (tyenv, cp + 1, pending, acc)

    val (_, _, _, rev_annots) =
      post_fold_term_type collect_annot (subst, 0, spots, []) t'

    (* state: (current position, annotations still to be inserted) *)
    fun insert_annot u _ (state as (_, [])) = (u, state)
      | insert_annot u _ (cp, todo as (p, T) :: rest) =
          if p = cp then (Type.constraint T u, (cp + 1, rest))
          else (u, (cp + 1, todo))

    val (annotated, _) = post_traverse_term_type insert_annot (0, rev rev_annots) t
  in
    annotated
  end
(* Entry point: generalizes the types of t, matches the generalized term against t, blanks
   out trivial type information, selects the typing spots (sorted by position) and finally
   inserts the corresponding type constraints into t. *)
fun annotate_types_in_term ctxt t =
  let
    val generalized = generalize_types ctxt t
    val (pruned, tyenv) =
      handle_trivial_tfrees ctxt generalized (match_types ctxt generalized t)
    val spots = sort int_ord (reverse_greedy (typing_spot_table pruned))
  in
    introduce_annotations tyenv spots t pruned
  end