1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory FastMap
8imports
9  LemmaBucket
10begin
11
12text \<open>
13  Efficient rules and tactics for working with large lookup tables (maps).
14
15  Features:
16    \<^item> Define a binary lookup tree for any lookup table (requires linorder keys)
17    \<^item> Conversion from lookup tree to lookup lists
18    \<^item> Pre-computation of lookup results and domain/range sets
19
20  See FastMap_Test for examples.
21\<close>
22
23(*
24 * TODO:
25 *
 *   - Storing the auxiliary list theorems with Local_Theory.notes
27 *     takes quadratic time. Unfortunately, this seems to be a problem
28 *     deep inside the Isabelle implementation. One might try to wrap
29 *     the lists in new constants, but Local_Theory.define also takes
30 *     quadratic time.
31 *
 *   - Still a bit slower than the StaticFun package. Streamline the
33 *     rulesets and proofs.
34 *
 *   - We use a lot of manual convs for performance and to avoid
36 *     relying on the dynamic simpset. However, we should clean up the
37 *     convs and move as much as possible to (simp only:) invocations.
38 *
39 *     Note that running simp on deeply nested terms (e.g. lists)
40 *     always takes quadratic time and we can't use it there. This is
41 *     because rewritec unconditionally calls eta_conversion (urgh).
42 *
 *   - The injectivity prover currently hardcodes inj_def into the
44 *     simpset. This should be changed at some point, probably by
45 *     asking the user to prove it beforehand.
46 *
 *   - The key ordering prover is currently hardcoded to simp_tac;
48 *     this should also be generalised. On the other hand, the user
49 *     could work around this by manually supplying a simpset with
50 *     precisely the needed theorems.
51 *
 *   - Using the simplifier to evaluate tree lookups is still quite
53 *     slow because it looks at the entire tree term (even though
54 *     most of it is irrelevant for any given lookup). We should
55 *     provide a tactic or simproc to do this.
56 *
57 *     We already generate lookup theorems for keys in the map, so
58 *     this tactic should be optimised for missing keys.
59 *
 *   - The linorder requirement can be cumbersome. It arises because
61 *     we express the map_of conversion as a general theorem using
62 *     lookup_tree_valid. An alternative approach is to extend what
63 *     StaticFun does, and cleverly extract the set of all relevant
64 *     bindings from the tree on a case-by-case basis.
65 *
66 *     We would still need to evaluate the key ordering function on
67 *     the input keys, but any arbitrary relation would be allowed.
68 *     This one probably calls for a wizard.
69 *)
70
71locale FastMap begin
72
text \<open>
  Binary lookup tree. This is largely an implementation detail, so we
  choose the structure to make automation easier (e.g. separate fields
  for the key and value).

  We could reuse HOL.Tree instead, but the proofs would need changing.
\<close>
(* Node k v l r stores the binding of key k to value v, with subtrees l and r.
   For trees accepted by lookup_tree_valid, the keys in l are below k and the
   keys in r are above k (compared through the key function). *)
datatype ('k, 'v) Tree =
    Leaf
  | Node 'k 'v "('k, 'v) Tree" "('k, 'v) Tree"
83
(* Binary search for x: compare key x with the key of the current node, returning
   the node's value on equality, descending into the left subtree if key x is
   smaller and into the right subtree otherwise. Keys are only compared through
   key, so two keys with the same image are indistinguishable here. *)
primrec lookup_tree :: "('k \<Rightarrow> 'ok :: linorder) \<Rightarrow> ('k, 'v) Tree \<Rightarrow> 'k \<Rightarrow> 'v option"
  where
    "lookup_tree key Leaf x = None"
  | "lookup_tree key (Node k v l r) x =
       (if key x = key k then Some v
        else if key x < key k then lookup_tree key l x
        else lookup_tree key r x)"
91
text \<open>
  Predicate for well-formed lookup trees.
  This states that the keys are distinct and appear in ascending order.
  It also returns the lowest and highest keys in the tree (or None for empty trees).
\<close>
(* Computed bottom-up in a single recursion over the tree. For a Node, the
   returned range is (lowest key of the left subtree, or k if it is empty;
   highest key of the right subtree, or k if it is empty). The node is valid if
   both subtrees are valid, the left subtree's highest key is below k and k is
   below the right subtree's lowest key. *)
primrec lookup_tree_valid ::
        "('k \<Rightarrow> 'ok :: linorder) \<Rightarrow> ('k, 'v) Tree \<Rightarrow> bool \<times> ('k \<times> 'k) option"
  where
    "lookup_tree_valid key Leaf = (True, None)"
  | "lookup_tree_valid key (Node k v lt rt) =
       (let (lt_valid, lt_range) = lookup_tree_valid key lt;
            (rt_valid, rt_range) = lookup_tree_valid key rt;
            lt_low = (case lt_range of None \<Rightarrow> k | Some (low, high) \<Rightarrow> low);
            rt_high = (case rt_range of None \<Rightarrow> k | Some (low, high) \<Rightarrow> high)
        in (lt_valid \<and> rt_valid \<and>
            (case lt_range of None \<Rightarrow> True | Some (low, high) \<Rightarrow> key high < key k) \<and>
            (case rt_range of None \<Rightarrow> True | Some (low, high) \<Rightarrow> key k < key low),
            Some (lt_low, rt_high)))"
110
(* Rewrite rules for lookup_tree_valid, one per node shape (no, left, right or
   both subtrees non-empty). Unlike the primrec equation they contain no
   let/case expressions. The ML prover prove_tree_valid uses them, with a basic
   simpset plus the key ordering facts, to establish validity of a concrete tree. *)
lemma lookup_tree_valid_simps':
  "lookup_tree_valid key Leaf = (True, None)"
  "lookup_tree_valid key (Node k v Leaf Leaf) = (True, Some (k, k))"
  "\<lbrakk> lookup_tree_valid key (Node lk lv llt lrt) = (True, Some (llow, lhigh));
     key lhigh < key k
   \<rbrakk> \<Longrightarrow> lookup_tree_valid key (Node k v (Node lk lv llt lrt) Leaf) =
           (True, Some (llow, k))"
  "\<lbrakk> lookup_tree_valid key (Node rk rv rlt rrt) = (True, Some (rlow, rhigh));
     key k < key rlow
   \<rbrakk> \<Longrightarrow> lookup_tree_valid key (Node k v Leaf (Node rk rv rlt rrt)) =
           (True, Some (k, rhigh))"
  "\<lbrakk> lookup_tree_valid key (Node lk lv llt lrt) = (True, Some (llow, lhigh));
     lookup_tree_valid key (Node rk rv rlt rrt) = (True, Some (rlow, rhigh));
     key lhigh < key k;
     key k < key rlow
   \<rbrakk> \<Longrightarrow> lookup_tree_valid key (Node k v (Node lk lv llt lrt) (Node rk rv rlt rrt)) =
           (True, Some (llow, rhigh))"
  by auto
129
(* Only the empty tree reports no key range. *)
lemma lookup_tree_valid_empty:
  "lookup_tree_valid key tree = (True, None) \<Longrightarrow> tree = Leaf"
  apply (induct tree)
   apply simp
  apply (fastforce split: prod.splits option.splits if_splits)
  done

(* The range reported for a valid tree is ordered: key low <= key high. *)
lemma lookup_tree_valid_range:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow> key low \<le> key high"
  apply (induct tree arbitrary: low high)
   apply simp
  apply (fastforce split: prod.splits option.splits if_splits)
  done

(* Every key that lookup_tree finds in a valid tree lies within the reported range. *)
lemma lookup_tree_valid_in_range:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow>
   lookup_tree key tree k = Some v \<Longrightarrow>
   key k \<in> {key low .. key high}"
  apply (induct tree arbitrary: k v low high)
   apply simp
  apply (fastforce split: prod.splits option.splits if_split_asm
                   dest: lookup_tree_valid_empty lookup_tree_valid_range)
  done

(* Contrapositive of the above: keys outside the reported range are not found. *)
lemma lookup_tree_valid_in_range_None:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow>
   key k \<notin> {key low .. key high} \<Longrightarrow>
   lookup_tree key tree k = None"
  using lookup_tree_valid_in_range by fastforce
159
text \<open>
  Flatten a lookup tree into an assoc-list.
  As long as the tree is well-formed, both forms are equivalent.
\<close>
(* In-order traversal: left subtree, then the node's binding, then the right subtree. *)
primrec lookup_tree_to_list :: "('k, 'v) Tree \<Rightarrow> ('k \<times> 'v) list"
  where
    "lookup_tree_to_list Leaf = []"
  | "lookup_tree_to_list (Node k v lt rt) =
        lookup_tree_to_list lt @ [(k, v)] @ lookup_tree_to_list rt"

(* Every binding in the flattened list of a valid tree has its key within the
   range reported by lookup_tree_valid. *)
lemma lookup_tree_to_list_range:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow>
   (k, v) \<in> set (lookup_tree_to_list tree) \<Longrightarrow>
   key k \<in> {key low .. key high}"
  apply (induct tree arbitrary: k v low high)
   apply simp
  apply (fastforce split: prod.splits option.splits if_split_asm
                   dest: lookup_tree_valid_empty lookup_tree_valid_range)
  done
179
(* For a valid tree, the flattened list has no duplicate bindings and is sorted by
   key. Both facts are proved as one conjunction by a single induction and then
   split into the two lemmas below. *)
lemma lookup_tree_dom_distinct_sorted_:
  "fst (lookup_tree_valid key tree) \<Longrightarrow>
   distinct (lookup_tree_to_list tree) \<and> sorted (map (key o fst) (lookup_tree_to_list tree))"
  apply (induct tree)
   apply simp
  apply (fastforce simp: sorted_append
                   split: prod.splits option.splits if_splits
                   dest: lookup_tree_valid_empty lookup_tree_valid_range
                         lookup_tree_valid_in_range lookup_tree_to_list_range
                   elim: lookup_tree_valid_in_range_None)
  done

lemmas lookup_tree_dom_distinct = lookup_tree_dom_distinct_sorted_[THEN conjunct1]
lemmas lookup_tree_dom_sorted = lookup_tree_dom_distinct_sorted_[THEN conjunct2]
194
(* This goal is eta-expanded and flipped, which seems to help its proof *)
lemma lookup_tree_to_list_of_:
  "fst (lookup_tree_valid key tree) \<Longrightarrow>
   map_of (map (apfst key) (lookup_tree_to_list tree)) (key k) = lookup_tree key tree k"
  apply (induct tree arbitrary: k)
   apply simp
  (* this big blob just does case distinctions of both subtrees and
     all possible lookup results within each, then solves *)
  (* slow 10s *)
  apply (fastforce simp: apfst_def map_prod_def map_add_def
                   split: prod.splits option.splits if_splits
                   dest: lookup_tree_valid_empty lookup_tree_valid_range lookup_tree_valid_in_range
                   elim: lookup_tree_valid_in_range_None)
  done

(* Standard (point-free) form of the above: for a valid tree, lookup_tree is
   map_of over the flattened list with keys mapped through key, composed with key. *)
lemma lookup_tree_to_list_of:
  "fst (lookup_tree_valid key tree) \<Longrightarrow>
   lookup_tree key tree = map_of (map (apfst key) (lookup_tree_to_list tree)) o key"
  apply (rule ext)
  apply (simp add: lookup_tree_to_list_of_)
  done

(* An injective key function cancels out: mapping it over the keys of an assoc-list
   and precomposing it gives back the original map_of. *)
lemma map_of_key:
  "inj key \<Longrightarrow> map_of (map (apfst key) binds) o key = map_of binds"
  apply (rule ext)
  apply (induct binds)
   apply simp
  apply (clarsimp simp: inj_def dom_def)
  done
225
(* The keys of the flattened list of a valid tree are distinct after applying key. *)
lemma lookup_tree_to_list_of_distinct:
  "\<lbrakk> fst (lookup_tree_valid key tree);
     lookup_tree_to_list tree = binds;
     lookup_tree key tree = map_of (map (apfst key) binds) o key
   \<rbrakk> \<Longrightarrow> distinct (map (key o fst) binds)"
  apply (drule sym[where t = binds])
  apply clarsimp
  apply (thin_tac "binds = _")
  apply (induct tree)
   apply simp
  apply (fastforce simp: map_add_def lookup_tree_to_list_of
                   split: prod.splits option.splits if_splits
                   dest: lookup_tree_valid_empty lookup_tree_valid_range
                         lookup_tree_valid_in_range lookup_tree_to_list_range
                   elim: lookup_tree_valid_in_range_None)
  done

(* Top-level rule for converting to lookup list.
   We add a distinctness assertion for inferring the range of values.
   The ML function convert_to_lookup_list resolves with this rule and discharges
   the premises in order: inj key (by simp with inj_def), validity (with the
   pre-proved tree_valid theorem) and evaluation of lookup_tree_to_list on the
   concrete tree. *)
lemma lookup_tree_to_list_of_gen:
  "\<lbrakk> inj key;
     fst (lookup_tree_valid key tree);
     lookup_tree_to_list tree = binds
   \<rbrakk> \<Longrightarrow> lookup_tree key tree = map_of binds \<and> distinct (map fst binds)"
  using lookup_tree_to_list_of
  apply (fastforce intro: lookup_tree_to_list_of_distinct
                   simp: map_of_key distinct_inj)
  done
254
text \<open>
  Domain and range of a @{const map_of}.
  Like @{thm dom_map_of_conv_image_fst} but leaving out the set bloat.
\<close>
lemma dom_map_of_conv_list:
  "dom (map_of xs) = set (map fst xs)"
  by (simp add: dom_map_of_conv_image_fst)

(* Distinct keys are required: with duplicate keys map_of only returns the first
   binding for each key, so later values would be missing from the range. *)
lemma ran_map_of_conv_list:
  "distinct (map fst xs) \<Longrightarrow> ran (map_of xs) = set (map snd xs)"
  by (erule distinct_map_via_ran)
266
text \<open>
  Read lookup rules from a @{const map_of}.
\<close>

(* If m is map_of of an assoc-list with distinct keys, then every binding of the
   list is a lookup fact for m. The ML function define_map resolves its
   conversion theorem with this rule and then unfolds the list_all. *)
lemma map_of_lookups:
  "m = map_of binds \<and> distinct (map fst binds) \<Longrightarrow>
   list_all (\<lambda>(k, v). m k = Some v) binds"
  apply (induct binds)
   apply simp
  apply (force simp: list_all_iff)
  done

(* Helper for converting from maps defined as @{const fun_upd} chains,
 * which are applied in reverse order *)
lemma map_of_rev:
  "distinct (map fst binds) \<Longrightarrow>
   map_of (rev binds) = map_of binds"
  apply (subgoal_tac "distinct (map fst (rev binds))")
   apply (rule ext)
   apply (induct binds)
    apply simp
   apply (force simp: map_add_def split: option.splits)
  apply (metis distinct_rev rev_map)
  done

(* Meta-equations that unfold list_all over a non-empty list of pairs one element
   at a time. They are stated with \<equiv> so that the ML code can use them directly
   as first-order rewrite rules (dest_list_all_conv in define_map). *)
lemma list_all_dest:
  "list_all P [(x, y)] \<equiv> P (x, y)"
  "list_all P ((x, y) # z # xs) \<equiv> (P (x, y) \<and> list_all P (z # xs))"
  by auto

(* Install lookup rules that don't depend on if_cong/if_weak_cong setup *)
lemma lookup_tree_simps':
  "lookup_tree key Leaf x = None"
  "key x = key k \<Longrightarrow> lookup_tree key (Node k v l r) x = Some v"
  "key x < key k \<Longrightarrow> lookup_tree key (Node k v l r) x = lookup_tree key l x"
  "key x > key k \<Longrightarrow> lookup_tree key (Node k v l r) x = lookup_tree key r x"
  by auto
304end
305
(* Swap the primrec equations of lookup_tree (which contain an if-then-else) for
   the conditional rules lookup_tree_simps', so that simp evaluates lookups by
   deciding the key comparisons and does not depend on if_cong/if_weak_cong. *)
declare FastMap.lookup_tree.simps[simp del] FastMap.lookup_tree_simps'[simp]
308
309ML \<open>
310structure FastMap = struct
311
(* utils *)

(* Build / destruct the HOL option type. The type name comes from the
   @{type_name} antiquotation, so it is checked when this ML text is compiled
   (consistent with the other Type/Const constructions in this structure)
   instead of being a hard-coded string. dest_optionT raises TYPE on any
   other type. *)
fun mk_optionT typ = Type (@{type_name option}, [typ])
fun dest_optionT (Type (@{type_name option}, [typ])) = typ
  | dest_optionT t = raise TYPE ("dest_optionT", [t], [])
316
(* Turn a theorem "x = y" (HOL equality under Trueprop) into "x == y".
   Same result as thm RS @{thm eq_reflection}, but the rule is instantiated
   directly with x, y and their type instead of being resolved, which keeps
   this O(1) even for very large x and y. *)
fun then_eq_reflection thm =
  let
    val hol_eq = Thm.dest_arg (Thm.cprop_of thm)   (* drop Trueprop *)
    val (lhs, rhs) = Thm.dest_binop hol_eq
    val inst_rule =
      Thm.instantiate' [SOME (Thm.ctyp_of_cterm lhs)] [SOME lhs, SOME rhs] @{thm eq_reflection}
  in
    Thm.implies_elim inst_rule thm
  end;
323
(* Apply a conversion to the first / second argument of a binary application
   "f a b", e.g. to either side of "a @ b". *)
fun lhs_conv cv = Conv.fun_conv (Conv.arg_conv cv)
fun rhs_conv cv = Conv.arg_conv cv
(* first order rewr_conv *)
(* Rewrite ct with a single rule, using first-order matching of the rule's lhs
   against ct (no higher-order patterns).
   - If rule is a meta-equation "lhs == rhs", it is instantiated and returned.
   - Otherwise (Thm.lhs_of raises TERM) rule is taken to be "Trueprop (lhs = rhs)";
     it is instantiated and then turned into a meta-equation by then_eq_reflection.
   Raises CTERM when ct does not match the rule's lhs, so it composes with
   Conv.else_conv / Conv.first_conv / Conv.try_conv. *)
fun fo_rewr_conv rule ct = let
  val (pure_eq, eqn) =
        ((true, Thm.instantiate (Thm.first_order_match (Thm.lhs_of rule, ct)) rule)
         handle TERM _ =>
           (false, Thm.instantiate (Thm.first_order_match
                      (fst (Thm.dest_binop (Thm.dest_arg (Thm.cprop_of rule))), ct)) rule))
        handle Pattern.MATCH => raise CTERM ("fo_rewr_conv", [Thm.cprop_of rule, ct]);
  in if pure_eq then eqn else then_eq_reflection eqn end;
(* Try the rules in order; the first one that matches is used. Fails with CTERM
   (via Conv.first_conv) if none matches. *)
fun fo_rewrs_conv rules = Conv.first_conv (map fo_rewr_conv rules);
336
(* Rewrite a term with first-order rules in a single top-down pass: at every
 * position the first matching rule (if any) is applied once, then the
 * traversal continues into the result's subterms. Unlike the simplifier there
 * is no iteration to a fixpoint, but this suffices for tasks such as pushing
 * List.map through a list literal, and it is much faster. *)
fun fo_topdown_rewr_conv rules ctxt =
  let
    val rewrite_here = Conv.try_conv (fo_rewrs_conv rules)
  in
    Conv.top_conv (fn _ => rewrite_here) ctxt
  end
342
(* Sequential composition of conversions where the second one is produced on
   demand: mk_second_cv () is only called after first_cv has run, so a
   conversion may mention itself recursively on the right-hand side without
   looping at construction time. Reflexive steps are dropped rather than being
   chained with Thm.transitive. *)
infix 1 then_conv'
fun (first_cv then_conv' mk_second_cv) ct =
  let
    val step1 = first_cv ct
    val step2 = mk_second_cv () (Thm.rhs_of step1)
  in
    case (Thm.is_reflexive step1, Thm.is_reflexive step2) of
      (true, _) => step2
    | (false, true) => step1
    | (false, false) => Thm.transitive step1 step2
  end;
354
(*
 * Helper that makes it easier to describe where to apply a conv.
 * This takes a skeleton term and applies the conversion wherever "HERE"
 * appears in the skeleton.
 *
 * The skeleton is walked in parallel with ct:
 *   - Free "HERE": apply conv to the corresponding subterm of ct;
 *   - applications and abstractions must match structurally;
 *   - Consts must have the same name as in ct (types are not compared);
 *   - any other skeleton node (e.g. Frees such as dummy1, Vars, Bounds) matches
 *     anything and that part of ct is left unchanged.
 * Raises TERM "conv_at mismatch" (with the local and the whole skeleton and
 * term) when the structure of ct does not fit the skeleton.
 *
 * FIXME: use HOL-Library.Rewrite instead
 *)
fun conv_at skel conv ctxt ct = let
  fun mismatch current_skel current_ct =
    raise TERM ("conv_at mismatch", [current_skel, Thm.term_of current_ct, skel, Thm.term_of ct])

  fun walk (Free ("HERE", _)) ctxt ct = conv ct
    | walk (skel as skel_f $ skel_x) ctxt ct =
        (case Thm.term_of ct of
            f $ x => Conv.combination_conv (walk skel_f ctxt) (walk skel_x ctxt) ct
          | _ => mismatch skel ct)
    | walk (skel as Abs (_, _, skel_body)) ctxt ct =
        (case Thm.term_of ct of
            Abs _ => Conv.abs_conv (fn (v, ctxt') => walk skel_body ctxt') ctxt ct
          | _ => mismatch skel ct)
    (* Also check that Consts match the skeleton pattern *)
    | walk (skel as Const (skel_name, _)) ctxt ct =
        if (case Thm.term_of ct of Const (name, _) => name = skel_name | _ => false)
        then Thm.reflexive ct
        else mismatch skel ct
    (* Default case: leave this part of ct untouched *)
    | walk _ ctxt ct = Thm.reflexive ct
  in walk skel ctxt ct end
383
(* Tactic: rewrite premise 1 of the goal state (the first subgoal) with conv at
   the positions marked "HERE" in pat. Yields exactly one new state; failures of
   the conversion propagate as exceptions rather than as an empty sequence. *)
fun gconv_at_tac pat conv ctxt st =
  Seq.succeed (Conv.gconv_rule (conv_at pat conv ctxt) 1 st)
385
386
(* Tree builder code, copied from StaticFun *)

(* Build a balanced tree from a list of (key, value) pairs: the middle element
   becomes the root and the elements before/after it form the left/right
   subtrees, so an in-order traversal reproduces the list order.
   Takes Theta(n log n) time, since each level re-measures and splits its list. *)
fun build_tree' _ mk_leaf [] = mk_leaf
  | build_tree' mk_node mk_leaf xs =
      (case chop (length xs div 2) xs of
         (_, []) => error "build_tree': impossible"
       | (lower, (a, b) :: upper) =>
           mk_node a b (build_tree' mk_node mk_leaf lower)
                       (build_tree' mk_node mk_leaf upper))
398
(* Turn a non-empty list of (key term, value term) pairs into a HOL term of type
   (keyT, valT) FastMap.Tree; both types are read off the first pair. *)
fun build_tree [] = error "build_tree : empty"
  | build_tree (xs as (idx, v) :: _) =
      let
        val idxT = fastype_of idx
        val vT = fastype_of v
        val treeT = Type (@{type_name FastMap.Tree}, [idxT, vT])
        val leaf = Const (@{const_name FastMap.Leaf}, treeT)
        val node = Const (@{const_name FastMap.Node}, idxT --> vT --> treeT --> treeT --> treeT)
        fun mk_node a b l r = node $ a $ b $ l $ r
      in
        build_tree' mk_node leaf xs
      end
411
(* Define the constant map_name as "lookup_tree ord_term <tree>", where <tree> is
   the balanced lookup tree built from mappings by build_tree (in list order, so
   it is only well-formed if the mappings are sorted by ord_term).
   Returns ((constant, defining theorem), updated local theory).
   Raises ERROR for an empty mappings list (previously this failed with an
   uninformative List.Empty from hd). *)
fun define_partial_map_tree map_name mappings ord_term ctxt = let
    val (idxT, vT) =
          (case mappings of
             [] => error "define_partial_map_tree: empty mappings"
           | first :: _ => apply2 fastype_of first)
    val treeT = Type (@{type_name FastMap.Tree}, [idxT, vT])
    val lookup = Const (@{const_name FastMap.lookup_tree},
                        fastype_of ord_term --> treeT --> idxT
                            --> Type (@{type_name option}, [vT]))
    val map_term = lookup $ ord_term $ build_tree mappings
    val ((map_const, (_, map_def)), ctxt) =
          Local_Theory.define ((map_name, NoSyn), ((Thm.def_binding map_name, []), map_term)) ctxt
  in
    ((map_const, map_def), ctxt)
  end
424
(* Prove "get_key k1 < get_key k2" for every pair of adjacent mappings, with the
   simpset given by simp_ctxt. Doing this up front lets us report exactly which
   adjacent keys (0-based positions) could not be shown to be ordered; the
   resulting theorems are also needed to prove the lookup_tree_valid property. *)
fun prove_key_ord_thms tree_name keyT mappings get_key simp_ctxt ctxt =
  let
    val solver = simp_tac (simp_ctxt ctxt [] []) 1
    val less_const = Const (@{const_name less}, keyT --> keyT --> HOLogic.boolT)
    fun prove_adjacent (i, ((k1, _), (k2, _))) =
      let
        val prop = HOLogic.mk_Trueprop (less_const $ (get_key $ k1) $ (get_key $ k2))
      in
        (case try (Goal.prove ctxt [] [] prop) (K solver) of
           SOME thm => thm
         | NONE =>
             raise TERM (tree_name ^ ": failed to prove less-than ordering for keys #" ^
                         string_of_int i ^ ", #" ^ string_of_int (i + 1),
                         [prop]))
      end
    val adjacent_pairs = fst (split_last mappings) ~~ tl mappings
  in
    map_index prove_adjacent adjacent_pairs
  end;
444
(* Prove "fst (lookup_tree_valid get_key tree_term)" for the concrete tree, using
   only prod.sel, lookup_tree_valid_simps' and the adjacent key ordering facts
   from prove_key_ord_thms on top of HOL_basic_ss. *)
fun prove_tree_valid tree_name mappings kT keyT tree_term get_key simp_ctxt ctxt =
  let
    val ord_thms = prove_key_ord_thms tree_name keyT mappings get_key simp_ctxt ctxt
    val resultT = HOLogic.mk_prodT (HOLogic.boolT, mk_optionT (HOLogic.mk_prodT (kT, kT)))
    val valid_const =
      Const (@{const_name FastMap.lookup_tree_valid},
             (kT --> keyT) --> fastype_of tree_term --> resultT)
    val fst_const = Const (@{const_name fst}, resultT --> HOLogic.boolT)
    val goal = HOLogic.mk_Trueprop (fst_const $ (valid_const $ get_key $ tree_term))
    val valid_ctxt =
      put_simpset HOL_basic_ss ctxt
        addsimps (@{thms prod.sel FastMap.lookup_tree_valid_simps'} @ ord_thms)
  in
    Goal.prove ctxt [] [] goal (fn _ => simp_tac valid_ctxt 1)
  end
460
(* Run simp_tac on subgoal i and insist that it closes the goal completely: if any
   subgoal is left over, raise TERM naming the original goal and the remainder.
   If simp_tac itself fails, the tactic fails normally. *)
fun solve_simp_tac name ctxt =
  let
    fun attempt (goal, i) =
      let
        fun unsolved (leftover, _) = raise TERM (name ^ ": unsolved", [goal, leftover])
      in
        (simp_tac ctxt THEN_ALL_NEW SUBGOAL unsolved) i
      end
  in
    SUBGOAL attempt
  end
464
(* Prove "<map_const> = map_of <mappings> \<and> distinct (map fst <mappings>)".
   Steps: unfold map_const with map_def, resolve with lookup_tree_to_list_of_gen,
   then discharge its premises in order: inj get_key (simp with inj_def),
   tree validity (tree_valid_thm), and "lookup_tree_to_list <tree> = <list>",
   which is evaluated by a hand-written conversion and closed by refl.
   Each step's run time is reported via TIMED. *)
fun convert_to_lookup_list kT valT mappings map_const map_def tree_valid_thm simp_ctxt ctxt = let
  val lookupT = fastype_of map_const
  (* map_eq = "<tree_const> = map_of <mappings>" *)
  val bindT = HOLogic.mk_prodT (kT, valT)
  val lookup_list = HOLogic.mk_list bindT (map HOLogic.mk_prod mappings)
  val map_of_Const = Const (@{const_name map_of}, HOLogic.listT bindT --> lookupT)
  val map_eq = HOLogic.mk_eq (map_const, map_of_Const $ lookup_list)
  (* distinct_pred = "distinct (map fst <mappings>)" *)
  val distinct_pred =
        Const (@{const_name distinct}, HOLogic.listT kT --> HOLogic.boolT) $
          (Const (@{const_name map}, (bindT --> kT) --> HOLogic.listT bindT --> HOLogic.listT kT) $
             Const (@{const_name fst}, bindT --> kT) $
             lookup_list)
  val convert_prop = HOLogic.mk_Trueprop (
        HOLogic.mk_conj (map_eq, distinct_pred)
        )
  (* Wrap a tactic so that the time to produce its first result is reported. *)
  fun TIMED desc tac = fn st =>
        Seq.make (K (Timing.timeap_msg ("tactic timing for " ^ desc)
                          (fn () => Seq.pull (tac st)) ()))

  val append_basecase = @{thm append.simps(1)}
  val append_rec = @{thm append.simps(2)}
  val lookup_tree_to_list_basecase = @{thm FastMap.lookup_tree_to_list.simps(1)}
  val lookup_tree_to_list_rec = @{thm FastMap.lookup_tree_to_list.simps(2)[simplified append.simps]}

  (* Evaluate lookup_tree_to_list on a concrete tree by structural recursion:
     flatten both subtrees, then push the outer append through the (now literal)
     left list. Recursion is deferred with then_conv' so the convs are only
     built on demand. *)
  val lookup_tree_to_list_eval = let
    fun append_conv () =
            fo_rewr_conv append_basecase else_conv
            (fo_rewr_conv append_rec then_conv'
             (fn () => rhs_conv (append_conv ())))
    fun to_map_conv () =
            fo_rewr_conv lookup_tree_to_list_basecase else_conv
            (fo_rewr_conv lookup_tree_to_list_rec then_conv'
             (fn () => lhs_conv (to_map_conv ())) then_conv'
             (fn () => rhs_conv (rhs_conv (to_map_conv ()))) then_conv
             append_conv ())
  in to_map_conv () end

  val solver =
        TIMED "unfold" (gconv_at_tac @{term "Trueprop (HERE = map_of dummy1 \<and> dummy2)"}
                                     (K map_def) ctxt)
        THEN
        TIMED "main rule" (resolve_tac ctxt @{thms FastMap.lookup_tree_to_list_of_gen} 1)
        THEN
          TIMED "solve inj" (solve_simp_tac "solve inj"
                              (simp_ctxt ctxt @{thms simp_thms} @{thms inj_def}) 1)
          THEN
         TIMED "resolve valid" (resolve_tac ctxt [tree_valid_thm] 1)
         THEN
        TIMED "convert tree" (gconv_at_tac @{term "Trueprop (HERE = dummy1)"}
                                           lookup_tree_to_list_eval ctxt)
        THEN
        resolve_tac ctxt @{thms refl} 1
  val convert_thm = Goal.prove ctxt [] [] convert_prop (K solver)
  in convert_thm end
520
(* Obtain domain and range from lookup list.
   Proves "dom_ran_const <map_const> = set <xs>" (dom_ran_const is dom or ran) by:
   rewriting map_const with lookup_list_eqn, running intro_conv_tac to turn the
   goal into one about the list, evaluating map fst / map snd over the literal
   list with list.map and prod.sel, and closing with refl. *)
fun domain_range_common dom_ran_const xT xs map_const lookup_list_eqn intro_conv_tac ctxt =
  let
    val setT = HOLogic.mk_setT xT
    val lhs = Const (dom_ran_const, fastype_of map_const --> setT) $ map_const
    val rhs = Const (@{const_name set}, HOLogic.listT xT --> setT) $ HOLogic.mk_list xT xs
    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
    val meta_eqn = then_eq_reflection lookup_list_eqn
    val eval_map_conv = fo_topdown_rewr_conv @{thms list.map prod.sel} ctxt
    val tac =
      EVERY
        [gconv_at_tac @{term "Trueprop (dom_ran_dummy1 HERE = dummy2)"} (K meta_eqn) ctxt,
         intro_conv_tac,
         gconv_at_tac @{term "Trueprop (HERE = dummy1)"} eval_map_conv ctxt,
         resolve_tac ctxt @{thms refl} 1]
  in
    Goal.prove ctxt [] [] goal (K tac)
  end;
542
(* Prove "dom <map_const> = set [<keys of mappings>]" from the lookup-list equation. *)
fun tree_domain kT mappings map_const lookup_list_eqn ctxt =
  let
    val keys = map fst mappings
    (* like (subst dom_map_of_conv_list) but faster *)
    val dom_intro_tac = resolve_tac ctxt @{thms FastMap.dom_map_of_conv_list[THEN trans]} 1
  in
    domain_range_common @{const_name dom} kT keys map_const lookup_list_eqn dom_intro_tac ctxt
  end;
549
(* Prove "ran <map_const> = set [<values of mappings>]" from the lookup-list
   equation; map_distinct_thm supplies the key distinctness this requires. *)
fun tree_range valT mappings map_const lookup_list_eqn map_distinct_thm ctxt =
  let
    val vals = map snd mappings
    (* like (subst ran_map_of_conv_list) but faster *)
    val ran_intro_tac =
      resolve_tac ctxt @{thms FastMap.ran_map_of_conv_list[THEN trans]} 1 THEN
      resolve_tac ctxt [map_distinct_thm] 1
  in
    domain_range_common @{const_name ran} valT vals map_const lookup_list_eqn ran_intro_tac ctxt
  end;
557
558
(* Choosing names for the const and its theorems. The constant will be named with
   map_name; Local_Theory.define may also add extra names (e.g. <map_name>_def) *)
type name_opts = {
    map_name: string,
    tree_valid_thm: string,
    to_lookup_list: string,
    keys_distinct_thm: string,
    lookup_thms: string,
    domain_thm: string,
    range_thm: string
};

(* Default naming scheme: each theorem is named after the constant plus a fixed suffix. *)
fun name_opts_default (map_name: string): name_opts =
  let
    fun suffixed suffix = map_name ^ suffix
  in
    {map_name = map_name,
     tree_valid_thm = suffixed "_tree_valid",
     to_lookup_list = suffixed "_to_lookup_list",
     keys_distinct_thm = suffixed "_keys_distinct",
     lookup_thms = suffixed "_lookups",
     domain_thm = suffixed "_domain",
     range_thm = suffixed "_range"}
  end;
580
(* Top level interface *)
(*
 * Define a constant named #map_name name_opts that looks up the given
 * (key, value) mappings through a binary lookup tree. Then store the theorems
 * named in name_opts: tree well-formedness, the map_of conversion, key
 * distinctness, one lookup rule per mapping, and the domain and range.
 *
 * mappings must be non-empty and strictly ascending w.r.t. get_key (each
 * adjacent pair is checked by prove_key_ord_thms). get_key must be injective
 * (proved by simp with inj_def during the conversion).
 * minimal_simp: if true, proofs start from HOL_basic_ss and extra_simps must be
 * adequate; otherwise extra_simps are added to ctxt's simpset.
 * Returns the updated local theory. Raises ERROR if mappings is empty.
 *)
fun define_map
            (name_opts: name_opts)
            (mappings: (term * term) list)
            (get_key: term) (* function to get linorder key, must be injective *)
            (extra_simps: thm list)
            (minimal_simp: bool) (* true: start with minimal simpset; extra_simps must be adequate *)
            ctxt = let
    (* Fail early with a clear message: everything below assumes at least one
       mapping (hd mappings here, split_last in prove_key_ord_thms, ...). *)
    val _ = if null mappings
            then error (#map_name name_opts ^ ": no mappings given")
            else ()

    fun simp_ctxt ctxt basic_simps more_simps = if minimal_simp
      then put_simpset HOL_basic_ss ctxt addsimps (basic_simps @ extra_simps @ more_simps)
      else ctxt addsimps (extra_simps @ more_simps)

    val (kT, keyT) = dest_funT (fastype_of get_key)
    val valT = fastype_of (snd (hd mappings))

    val _ = tracing (#map_name name_opts ^ ": defining tree")
    val start = Timing.start ()
    val ((map_const, map_def), ctxt) =
            define_partial_map_tree
                (Binding.name (#map_name name_opts))
                mappings get_key ctxt
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    val _ = tracing (#map_name name_opts ^ ": proving tree is well-formed")
    val start = Timing.start ()
    (* map_def is "map_const == lookup_tree get_key <tree>"; extract <tree> *)
    val _ $ _ $ tree_term = Thm.term_of (Thm.rhs_of map_def)
    val tree_valid_thm =
          prove_tree_valid (#map_name name_opts) mappings kT keyT tree_term get_key simp_ctxt ctxt
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#tree_valid_thm name_opts), []), [([tree_valid_thm], [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    val _ = tracing (#map_name name_opts ^ ": converting tree to map")
    val start = Timing.start ()
    val convert_thm =
          convert_to_lookup_list kT valT mappings map_const map_def tree_valid_thm simp_ctxt ctxt
    val [lookup_list_eqn, map_distinct_thm] = HOLogic.conj_elims ctxt convert_thm
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))
    val _ = tracing (#map_name name_opts ^ ": storing map and distinctness theorems")
    val start = Timing.start ()
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#to_lookup_list name_opts), []), [([lookup_list_eqn], [])]),
         ((Binding.name (#keys_distinct_thm name_opts), []), [([map_distinct_thm], [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    val _ = tracing (#map_name name_opts ^ ": obtaining lookup rules")
    val start = Timing.start ()
    (* Unfold list_all over the literal list into a conjunction, one conjunct per mapping *)
    fun dest_list_all_conv () =
          fo_rewr_conv @{thm FastMap.list_all_dest(1)} else_conv
          (fo_rewr_conv @{thm FastMap.list_all_dest(2)} then_conv'
           (fn () => rhs_conv (dest_list_all_conv())))
    val combined_lookup_thm =
          (convert_thm RS @{thm FastMap.map_of_lookups})
          |> Conv.fconv_rule (conv_at @{term "Trueprop HERE"} (dest_list_all_conv ()) ctxt)
    val _ = tracing ("  splitting... " ^ Timing.message (Timing.result start))
    val lookup_thms =
          HOLogic.conj_elims ctxt combined_lookup_thm
          |> map (Conv.fconv_rule (conv_at @{term "Trueprop HERE"}
                                     (fo_rewr_conv @{thm prod.case[THEN eq_reflection]}) ctxt))

    (* Sanity check: conj_elims must have produced exactly one rule per mapping *)
    val _ = if length lookup_thms = length mappings then () else
              raise THM ("wrong number of lookup thms: " ^ string_of_int (length lookup_thms) ^
                         " instead of " ^ string_of_int (length mappings), 0,
                         lookup_thms)
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#lookup_thms name_opts), []), [(lookup_thms, [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    (* domain and range *)
    val _ = tracing (#map_name name_opts ^ ": getting domain and range")
    val start = Timing.start ()
    val domain_thm = Timing.timeap_msg "  calculate domain"
            (tree_domain kT mappings map_const lookup_list_eqn) ctxt
    val range_thm = Timing.timeap_msg "  calculate range"
            (tree_range valT mappings map_const lookup_list_eqn map_distinct_thm) ctxt
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#domain_thm name_opts), []), [([domain_thm], [])]),
         ((Binding.name (#range_thm name_opts), []), [([range_thm], [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))
  in ctxt end
661
662end
663\<close>
664
665end