(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
    Author:     Steffen Juilf Smolka, TU Muenchen
    Author:     Jasmin Blanchette, TU Muenchen

Supplements term with a locally minimal, complete set of type constraints.
Complete: The constraints suffice to infer the term's types. Minimal: Reducing
the set of constraints further will make it incomplete.

When configuring the pretty printer appropriately, the constraints will show up
as type annotations when printing the term. This allows the term to be printed
and reparsed without a change of types.

Terms should be unchecked before calling "annotate_types_in_term" to avoid
awkward syntax.
*)

signature SLEDGEHAMMER_ISAR_ANNOTATE =
sig
  val annotate_types_in_term : Proof.context -> term -> term
end;

structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
struct

(* Post-order traversal of a term that threads a state "s" through all subterms.
   "f" is applied to every (already rebuilt) subterm together with its type and
   the current state; "env" holds the types of the enclosing bound variables,
   innermost first. The result is ((term', state'), type of the subterm).
   For an application the type of the function part must be a type constructor
   with exactly two arguments, the second being the result type; otherwise the
   pattern match raises Bind, which is turned into Fail. *)
fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    let val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s in
      f (Abs (x, T1, b')) (T1 --> T2) s'
    end
  | post_traverse_term_type' f env (u $ v) s =
    let
      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
      val ((v', s''), _) = post_traverse_term_type' f env v s'
    in f (u' $ v') T s'' end
    handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"

(* Map-and-fold over all subterms: "f t T s" returns (t', s'); the subterm type
   is passed on unchanged. Returns the rebuilt term and the final state. *)
fun post_traverse_term_type f s t =
  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst

(* Pure fold over all subterms in post-order: "f t T s" returns the new state. *)
fun post_fold_term_type f s t =
  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd

(* Applies "f" to every atomic (non-"Type") component of "T" from left to
   right, threading the state "s" and rebuilding the type. *)
fun fold_map_atypes f T s =
  (case T of
    Type (name, Ts) =>
    let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
      (Type (name, Ts), s)
    end
  | _ => f T s)

val indexname_ord = Term_Ord.fast_indexname_ord
(* Order on costs of the form (int, (int, int)), compared lexicographically. *)
val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)

(* Tables keyed by lists of type variable names. *)
structure Var_Set_Tab = Table(
  type key = indexname list
  val ord = list_ord indexname_ord)

(* Replaces all types in "t" by dummyT and re-infers them in pattern mode. *)
fun generalize_types ctxt t =
  let
    val erase_types = map_types (fn _ => dummyT)
    (* use schematic type variables *)
    val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
    val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
  in
    t |> erase_types |> infer_types
  end

(* Collects the types of all subterms of "t1" and "t2" (same traversal order for
   both), pairs them up, and accumulates a type matching from the pairs. Pairs
   for which "Sign.typ_match" raises are skipped ("perhaps o try"). *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    val get_types = post_fold_term_type (K cons) []
  in
    fold (perhaps o try o Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
  end

(* Handles TFrees occurring in the range of "subst" that are not declared in
   "ctxt" ("trivial" TFrees):
   - TVars that "subst" maps directly to a trivial TFree are replaced by dummyT
     in "t'" and removed from "subst";
   - remaining occurrences of trivial TFrees in the range of "subst" are
     replaced by dummyT.
   Returns the updated term and substitution. *)
fun handle_trivial_tfrees ctxt t' subst =
  let
    val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

    (* The list is queried with "Ord_List.member" below, which expects its
       argument to be ordered w.r.t. the given order; hence sort (and remove
       duplicates) with that same order instead of merely calling "distinct",
       which would leave the list unsorted and make lookups miss elements. *)
    val trivial_tfree_names =
      Vartab.fold add_tfree_names subst []
      |> filter_out (Variable.is_declared ctxt)
      |> sort_distinct fast_string_ord
    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

    (* TVars whose image under "subst" is exactly a trivial TFree; sorted so
       that "Ord_List.member" can be used on the result. *)
    val trivial_tvar_names =
      Vartab.fold
        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
               tfree_name_trivial tfree_name ? cons tvar_name
          | _ => I)
        subst
        []
      |> sort indexname_ord
    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

    val t' =
      t' |> map_types
              (map_type_tvar
                (fn (idxn, sort) =>
                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

    val subst =
      subst |> fold Vartab.delete trivial_tvar_names
            |> Vartab.map
               (K (apsnd (map_type_tfree
                           (fn (name, sort) =>
                              if tfree_name_trivial name then dummyT
                              else TFree (name, sort)))))
  in
    (t', subst)
  end

(* Key of a type: the names of the TVars occurring in it, collected via
   "Ord_List.insert" w.r.t. "indexname_ord". *)
fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
  | key_of_atype _ = I
fun key_of_type T = fold_atyps key_of_atype T []

(* Fold step for building the typing spot table. "pos" is the post-order
   position of subterm "t" and is incremented on every call. A subterm whose
   type "T" contains TVars is recorded under the key of "T" with cost
   (size of T, (size of t, pos)), unless an entry with a cost that is not
   greater already exists for that key. *)
fun update_tab t T (tab, pos) =
  ((case key_of_type T of
     [] => tab
   | key =>
     let val cost = (size_of_typ T, (size_of_term t, pos)) in
       (case Var_Set_Tab.lookup tab key of
         NONE => Var_Set_Tab.update_new (key, cost) tab
       | SOME old_cost =>
         (case cost_ord (cost, old_cost) of
           LESS => Var_Set_Tab.update (key, cost) tab
         | _ => tab))
     end),
   pos + 1)

(* Table mapping each TVar set to the cheapest position annotating exactly it. *)
val typing_spot_table = post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst

(* Selects the positions to annotate. Every TVar is first counted once per table
   entry mentioning it. The entries are then visited in order of decreasing
   cost; an entry is dropped when each of its TVars is still counted more than
   once (the counts are decremented accordingly), otherwise its position is
   kept. Returns the list of kept positions. *)
fun reverse_greedy typing_spot_tab =
  let
    fun update_count z =
      fold (fn tvar => fn tab =>
        let val c = Vartab.lookup tab tvar |> the_default 0 in
          Vartab.update (tvar, c + z) tab
        end)
    fun superfluous tcount = forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
      else (spot :: spots, tcount)

    val (typing_spots, tvar_count_tab) =
      Var_Set_Tab.fold (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
        typing_spot_tab ([], Vartab.empty)
      |>> sort_distinct (rev_order o cost_ord o apply2 snd)
  in
    fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst
  end

(* Inserts type constraints into "t" at the post-order positions listed in
   "spots" (expected in ascending order, as both passes consume them front to
   back while counting positions upwards). The constraint types are computed
   from the types of the corresponding subterms of "t'" by applying "subst". *)
fun introduce_annotations subst spots t t' =
  let
    (* Instantiate a TVar via "subst", then rebind it to dummyT in the
       substitution that is threaded on to later occurrences. *)
    fun subst_atype (T as TVar (idxn, S)) subst =
        (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
      | subst_atype T subst = (T, subst)

    val subst_type = fold_map_atypes subst_atype

    (* "cp" is the current position; "ps" are the positions still to be
       annotated. Annotations are accumulated in reverse order. *)
    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
        if p <> cp then
          (subst, cp + 1, ps, annots)
        else
          let val (T, subst) = subst_type T subst in
            (subst, cp + 1, ps', (p, T) :: annots)
          end
      | collect_annot _ _ x = x

    val (_, _, _, annots) = post_fold_term_type collect_annot (subst, 0, spots, []) t'

    (* Wraps the subterm at position "p" into a constraint with type "T". *)
    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
        if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
      | insert_annot t _ x = (t, x)
  in
    t |> post_traverse_term_type insert_annot (0, rev annots) |> fst
  end

(* Entry point: generalizes the types of "t", matches the generalized term
   against "t", handles trivial TFrees, selects the annotation positions, and
   returns "t" with type constraints inserted at these positions. *)
fun annotate_types_in_term ctxt t =
  let
    val t' = generalize_types ctxt t
    val subst = match_types ctxt t' t
    val (t'', subst') = handle_trivial_tfrees ctxt t' subst
    val typing_spots = t'' |> typing_spot_table |> reverse_greedy |> sort int_ord
  in
    introduce_annotations subst' typing_spots t t''
  end

end;