1(*  Title:      HOL/Tools/monomorph.ML
2    Author:     Sascha Boehme, TU Muenchen
3
4Monomorphization of theorems, i.e., computation of ground instances for
5theorems with type variables.  This procedure is incomplete in general,
6but works well for most practical problems.
7
8Monomorphization is essentially an enumeration of substitutions that map
9schematic types to ground types. Applying these substitutions to theorems
10with type variables results in monomorphized ground instances. The
11enumeration is driven by schematic constants (constants occurring with
12type variables) and all ground instances of such constants (occurrences
13without type variables). The enumeration is organized in rounds in which
14all substitutions for the schematic constants are computed that are induced
15by the ground instances. Any new ground instance may induce further
substitutions in a subsequent round. To prevent nontermination, there are
upper limits on the number of rounds and on the number of monomorphized
ground instances computed.
19
Theorems to be monomorphized must be tagged with a number indicating the
initial round in which they first participate. The initial round is
22ignored for theorems without type variables. For any other theorem, the
23initial round must be greater than zero. Returned monomorphized theorems
24carry a number showing from which monomorphization round they emerged.
25*)
26
signature MONOMORPH =
sig
  (** helpers for inspecting schematic types and constants **)

  (*true iff the given type contains a schematic type variable (TVar)*)
  val typ_has_tvars: typ -> bool

  (*map each constant name to the distinct schematic types at which the
    constant occurs in a term; the add_ variant extends an existing table*)
  val all_schematic_consts_of: term -> typ list Symtab.table
  val add_schematic_consts_of:
    term -> typ list Symtab.table -> typ list Symtab.table

  (** limits of the enumeration (registered as configuration attributes) **)

  val max_rounds: int Config.T
  val max_new_instances: int Config.T
  val max_thm_instances: int Config.T
  val max_new_const_instances_per_round: int Config.T

  (** entry point **)

  (*monomorph schematic_consts_of ctxt rthms yields, for every input theorem
    (tagged with its initial round), the list of its instances, each tagged
    with the round it emerged from*)
  val monomorph: (term -> typ list Symtab.table) -> Proof.context ->
    (int * thm) list -> (int * thm) list list
end
45
structure Monomorph: MONOMORPH =
struct

(* utility functions *)

(*true iff the type contains a schematic type variable (TVar) anywhere*)
val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)

(*record the type of constant c under its name, but only if that type is
  schematic; an already recorded type is not inserted twice*)
fun add_schematic_const (c as (_, T)) =
  if typ_has_tvars T then Symtab.insert_list (op =) c else I

(*add all schematic constant occurrences of t to the given table*)
fun add_schematic_consts_of t =
  Term.fold_aterms (fn Const c => add_schematic_const c | _ => I) t

fun all_schematic_consts_of t = add_schematic_consts_of t Symtab.empty

(*keep every constant name of the table, but forget its recorded types*)
fun clear_grounds grounds = Symtab.map (K (K [])) grounds


(* configuration options *)

val max_rounds = Attrib.setup_config_int \<^binding>\<open>monomorph_max_rounds\<close> (K 5)

val max_new_instances =
  Attrib.setup_config_int \<^binding>\<open>monomorph_max_new_instances\<close> (K 500)

val max_thm_instances =
  Attrib.setup_config_int \<^binding>\<open>monomorph_max_thm_instances\<close> (K 20)

val max_new_const_instances_per_round =
  Attrib.setup_config_int \<^binding>\<open>monomorph_max_new_const_instances_per_round\<close> (K 5)

(*Upper limit on how often an already computed instance may be rediscovered
  during the enumeration. Rediscovered instances do not count towards
  max_new_instances, and the same complete substitution can be reached along
  many different matching orders. Without this bound, duplicate-heavy
  enumerations are limited only by max_rounds.*)
val max_duplicated_instances =
  Attrib.setup_config_int \<^binding>\<open>monomorph_max_duplicated_instances\<close> (K 16000)

(*apply f ctxt i to the state successively for rounds i = 1, ..., max_rounds*)
fun limit_rounds ctxt f =
  let
    val max = Config.get ctxt max_rounds
    fun round i x = if i > max then x else round (i + 1) (f ctxt i x)
  in round 1 end


(* theorem information and related functions *)

(*Classification of an input theorem:
  Ground: the theorem has no type variables and is passed through unchanged;
  Ignored: the theorem can never be made ground (see groundable below);
  Schematic: a theorem to be instantiated. It carries a unique id, the
    theorem with shifted type-variable indices, its type variables, its
    schematic constant occurrences and the first round it takes part in.*)
datatype thm_info =
  Ground of thm |
  Ignored |
  Schematic of {
    id: int,
    theorem: thm,
    tvars: (indexname * sort) list,
    schematics: (string * typ) list,
    initial_round: int}

(*fold f over the theorems of all Ground entries*)
fun fold_grounds f = fold (fn Ground thm => f thm | _ => I)

(*apply f to the components of a Schematic entry; identity otherwise*)
fun fold_schematic f thm_info =
  (case thm_info of
    Schematic {id, theorem, tvars, schematics, initial_round} =>
      f id theorem tvars schematics initial_round
  | _ => I)

(*fold f over all Schematic entries whose initial round satisfies pred*)
fun fold_schematics pred f =
  let
    fun apply id thm tvars schematics initial_round x =
      if pred initial_round then f id thm tvars schematics x else x
  in fold (fold_schematic apply) end


(* collecting data *)

(*
  Theorems with type variables that cannot be instantiated should be ignored.
  A type variable can only be instantiated if it occurs in the type of one
  of the schematic constants.
*)
fun groundable all_tvars schematics =
  let val tvars' = Symtab.fold (fold Term.add_tvarsT o snd) schematics []
  in forall (member (op =) tvars') all_tvars end


(*Classify every tagged input theorem, assigning consecutive ids from 0.
  Schematic theorems get their type-variable indices shifted so that no two
  theorems share a type variable. Also returns a table with an (initially
  empty) entry for every constant name that occurs schematically in some
  theorem.*)
fun prepare schematic_consts_of rthms =
  let
    fun prep (initial_round, thm) ((id, idx), consts) =
      if Term.exists_type typ_has_tvars (Thm.prop_of thm) then
        let
          (* increase indices to avoid clashes of type variables *)
          val thm' = Thm.incr_indexes idx thm
          val idx' = Thm.maxidx_of thm' + 1

          val tvars = Term.add_tvars (Thm.prop_of thm') []
          val schematics = schematic_consts_of (Thm.prop_of thm')
          val schematics' =
            Symtab.fold (fn (n, Ts) => fold (cons o pair n) Ts) schematics []

          (* collect the names of all constants that need to be instantiated *)
          val consts' =
            consts
            |> Symtab.fold (fn (n, _) => Symtab.update (n, [])) schematics

          val thm_info =
            if not (groundable tvars schematics) then Ignored
            else
              Schematic {
                id = id,
                theorem = thm',
                tvars = tvars,
                schematics = schematics',
                initial_round = initial_round}
        in (thm_info, ((id + 1, idx'), consts')) end
      else (Ground thm, ((id + 1, idx + Thm.maxidx_of thm + 1), consts))
  in
    fold_map prep rthms ((0, 0), Symtab.empty) ||> snd
  end


(* collecting instances *)

(*instantiate the type variables of a theorem with a type substitution as
  produced by Sign.typ_match*)
fun instantiate ctxt subst =
  let
    fun cert (ix, (S, T)) = ((ix, S), Thm.ctyp_of ctxt T)
    fun cert' subst = Vartab.fold (cons o cert) subst []
  in Thm.instantiate (cert' subst, []) end

(*Add to the accumulator every constant occurrence of thm whose name is a key
  of used_grounds and whose type is recorded neither in used_grounds nor in
  new_grounds.*)
fun add_new_grounds used_grounds new_grounds thm =
  let
    fun mem tab (n, T) = member (op =) (Symtab.lookup_list tab n) T
    fun add (Const (c as (n, _))) =
          if mem used_grounds c orelse mem new_grounds c then I
          else if not (Symtab.defined used_grounds n) then I
          else Symtab.insert_list (op =) c
      | add _ = I
  in Term.fold_aterms add (Thm.prop_of thm) end

(*Enumerate the ground instances of theorem thm (identified by id) induced by
  matching its schematic constants against the given grounds.

  The accumulator is (next_grounds, (n, k, insts)):
  - next_grounds collects the constant occurrences of the produced instances
    that are not yet known (see add_new_grounds);
  - n counts new instances. Instances beyond max_thm_insts for one theorem
    are still kept, but do not count;
  - k counts rediscovered (duplicate) instances;
  - insts maps theorem ids to ((round, substitution), instance) lists.

  The enumeration stops early once n reaches max_instances or k reaches
  max_dups.*)
fun add_insts max_instances max_dups max_thm_insts ctxt round used_grounds
    new_grounds id thm tvars schematics cx =
  let
    exception ENOUGH of
      typ list Symtab.table *
      (int * int * ((int * (sort * typ) Vartab.table) * thm) list Inttab.table)

    val thy = Proof_Context.theory_of ctxt

    fun add subst (cx as (next_grounds, (n, k, insts))) =
      if n >= max_instances orelse k >= max_dups then
        raise ENOUGH cx
      else
        let
          val thm' = instantiate ctxt subst thm
          val rthm = ((round, subst), thm')
          val rthms = Inttab.lookup_list insts id
          val counts_insts' =
            if member (eq_snd Thm.eq_thm) rthms rthm then
              (* nothing new, but count the duplicate so that the
                 enumeration cannot go on indefinitely *)
              (n, k + 1, insts)
            else
              (if length rthms >= max_thm_insts then n else n + 1, k,
               Inttab.cons_list (id, rthm) insts)
          val next_grounds' =
            add_new_grounds used_grounds new_grounds thm' next_grounds
        in (next_grounds', counts_insts') end

    fun with_grounds (n, T) f subst (n', Us) =
      let
        fun matching U = (* one-step refinement of the given substitution *)
          (case try (Sign.typ_match thy (T, U)) subst of
            NONE => I
          | SOME subst' => f subst')
      in if n = n' then fold matching Us else I end

    fun with_matching_ground c subst f =
      (* Try new grounds before already used grounds. Otherwise only
         substitutions already seen in previous rounds get enumerated. *)
      Symtab.fold (with_grounds c (f true) subst) new_grounds #>
      Symtab.fold (with_grounds c (f false) subst) used_grounds

    fun is_complete subst =
      (* Check if a substitution is defined for all TVars of the theorem,
         which guarantees that the instantiation with this substitution results
         in a ground theorem since all matchings that led to this substitution
         are with ground types only. *)
      subset (op =) (tvars, Vartab.fold (cons o apsnd fst) subst [])

    fun for_schematics _ [] _ = I
      | for_schematics used_new (c :: cs) subst =
          with_matching_ground c subst (fn new => fn subst' =>
            if is_complete subst' then
              if used_new orelse new then add subst'
              else I
            else for_schematics (used_new orelse new) cs subst') #>
          for_schematics used_new cs subst
  in
    (* Enumerate all substitutions that lead to a ground instance of the
       theorem not seen before. A necessary condition for such a new ground
       instance is the usage of at least one ground from the new_grounds
       table. The approach used here is to match all schematics of the theorem
       with all relevant grounds. *)
    for_schematics false schematics Vartab.empty cx
    handle ENOUGH cx' => cx'
  end

(*a theorem is new in its initial round and active in every later round*)
fun is_new round initial_round = (round = initial_round)
fun is_active round initial_round = (round > initial_round)

(*One round of the enumeration.

  The new grounds of the previous round are first capped at max_new_grounds
  per constant, keeping the first ones in Term_Ord.typ_ord order. Active
  theorems are then matched against the previously known grounds plus these
  new grounds. Theorems entering in this round are matched against all known
  grounds. The unseen grounds produced by the instances become the new grounds
  of the next round.

  Nothing is enumerated once either the instance limit or the duplicate
  limit has been reached.*)
fun find_instances max_instances max_dups max_thm_insts max_new_grounds thm_infos
    ctxt round (known_grounds, new_grounds0, insts) =
  let
    val new_grounds =
      Symtab.map (fn _ => fn grounds =>
        if length grounds <= max_new_grounds then grounds
        else take max_new_grounds (sort Term_Ord.typ_ord grounds)) new_grounds0

    val add_new = add_insts max_instances max_dups max_thm_insts ctxt round
    fun consider_all pred f (cx as (_, (n, k, _))) =
      if n >= max_instances orelse k >= max_dups then cx
      else fold_schematics pred f thm_infos cx

    val known_grounds' = Symtab.merge_list (op =) (known_grounds, new_grounds)
    val empty_grounds = clear_grounds known_grounds'

    val (new_grounds', insts') =
      (Symtab.empty, insts)
      |> consider_all (is_active round) (add_new known_grounds new_grounds)
      |> consider_all (is_new round) (add_new empty_grounds known_grounds')
  in
    (known_grounds', new_grounds', insts')
  end

(*record every constant occurrence of thm whose name is already a key of the
  table (other constants are left out)*)
fun add_ground_types thm =
  let fun add (n, T) = Symtab.map_entry n (insert (op =) T)
  in Term.fold_aterms (fn Const c => add c | _ => I) (Thm.prop_of thm) end

(*Run all rounds and return the table of instances per theorem id.

  The grounds of the Ground theorems seed the enumeration as the "new" grounds
  of round 1. The instance limit is max_new_instances plus the number of
  Schematic theorems.*)
fun collect_instances ctxt max_thm_insts max_new_grounds thm_infos consts =
  let
    val known_grounds = fold_grounds add_ground_types thm_infos consts
    val empty_grounds = clear_grounds known_grounds
    val max_instances = Config.get ctxt max_new_instances
      |> fold (fn Schematic _ => Integer.add 1 | _ => I) thm_infos
    val max_dups = Config.get ctxt max_duplicated_instances
  in
    (empty_grounds, known_grounds, (0, 0, Inttab.empty))
    |> limit_rounds ctxt (find_instances max_instances max_dups max_thm_insts
      max_new_grounds thm_infos)
    |> (fn (_, _, (_, _, insts)) => insts)
  end


(* monomorphization *)

(*total size of the types a substitution assigns*)
fun size_of_subst subst =
  Vartab.fold (Integer.add o size_of_typ o snd o snd) subst 0

(*order substitutions by total size of their assigned types*)
fun subst_ord subst = int_ord (apply2 size_of_subst subst)

(*Final instances of one theorem, tagged with their round:
  - a Ground theorem yields itself in round 0;
  - an Ignored theorem yields nothing;
  - a Schematic theorem yields at most max_thm_insts instances. Earlier rounds
    are preferred, and within a round smaller substitutions are preferred.*)
fun instantiated_thms _ _ (Ground thm) = [(0, thm)]
  | instantiated_thms _ _ Ignored = []
  | instantiated_thms max_thm_insts insts (Schematic {id, ...}) =
    Inttab.lookup_list insts id
    |> (fn rthms =>
      if length rthms <= max_thm_insts then rthms
      else take max_thm_insts (sort (prod_ord int_ord subst_ord o apply2 fst) rthms))
    |> map (apfst fst)

(*Monomorphize the tagged theorems. The result has one list of
  (round, theorem) instances per input theorem, in input order. When no
  constant occurs schematically, no enumeration takes place.*)
fun monomorph schematic_consts_of ctxt rthms =
  let
    val max_thm_insts = Config.get ctxt max_thm_instances
    val max_new_grounds = Config.get ctxt max_new_const_instances_per_round
    val (thm_infos, consts) = prepare schematic_consts_of rthms
    val insts =
      if Symtab.is_empty consts then Inttab.empty
      else collect_instances ctxt max_thm_insts max_new_grounds thm_infos consts
  in map (instantiated_thms max_thm_insts insts) thm_infos end

end
312