1(*
2 * Copyright 2018, Data61
3 * Commonwealth Scientific and Industrial Research Organisation (CSIRO)
4 * ABN 41 687 119 230.
5 *
6 * This software may be distributed and modified according to the terms of
7 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
8 * See "LICENSE_BSD2.txt" for details.
9 *
10 * @TAG(DATA61_BSD)
11 *)
12
13theory FastMap
14imports
15  LemmaBucket
16begin
17
18text \<open>
19  Efficient rules and tactics for working with large lookup tables (maps).
20
21  Features:
22    \<^item> Define a binary lookup tree for any lookup table (requires linorder keys)
23    \<^item> Conversion from lookup tree to lookup lists
24    \<^item> Pre-computation of lookup results and domain/range sets
25
26  See FastMap_Test for examples.
27\<close>
28
29(*
30 * TODO:
31 *
 *   • Storing the auxiliary list theorems with Local_Theory.notes
33 *     takes quadratic time. Unfortunately, this seems to be a problem
34 *     deep inside the Isabelle implementation. One might try to wrap
35 *     the lists in new constants, but Local_Theory.define also takes
36 *     quadratic time.
37 *
 *   • Still a bit slower than the StaticFun package. Streamline the
39 *     rulesets and proofs.
40 *
 *   • We use a lot of manual convs for performance and to avoid
42 *     relying on the dynamic simpset. However, we should clean up the
43 *     convs and move as much as possible to (simp only:) invocations.
44 *
45 *     Note that running simp on deeply nested terms (e.g. lists)
46 *     always takes quadratic time and we can't use it there. This is
47 *     because rewritec unconditionally calls eta_conversion (urgh).
48 *
 *   • The injectivity prover currently hardcodes inj_def into the
50 *     simpset. This should be changed at some point, probably by
51 *     asking the user to prove it beforehand.
52 *
 *   • The key ordering prover is currently hardcoded to simp_tac;
54 *     this should also be generalised. On the other hand, the user
55 *     could work around this by manually supplying a simpset with
56 *     precisely the needed theorems.
57 *
 *   • Using the simplifier to evaluate tree lookups is still quite
59 *     slow because it looks at the entire tree term (even though
60 *     most of it is irrelevant for any given lookup). We should
61 *     provide a tactic or simproc to do this.
62 *
63 *     We already generate lookup theorems for keys in the map, so
64 *     this tactic should be optimised for missing keys.
65 *
 *   • The linorder requirement can be cumbersome. It arises because
67 *     we express the map_of conversion as a general theorem using
68 *     lookup_tree_valid. An alternative approach is to extend what
69 *     StaticFun does, and cleverly extract the set of all relevant
70 *     bindings from the tree on a case-by-case basis.
71 *
72 *     We would still need to evaluate the key ordering function on
73 *     the input keys, but any arbitrary relation would be allowed.
74 *     This one probably calls for a wizard.
75 *)
76
77locale FastMap begin
78
79text \<open>
80  Binary lookup tree. This is largely an implementation detail, so we
81  choose the structure to make automation easier (e.g. separate fields
82  for the key and value).
83
84  We could reuse HOL.Tree instead, but the proofs would need changing.
85\<close>
(* Node k v l r stores key k with value v; l and r are the left and right subtrees. *)
datatype ('k, 'v) Tree =
    Leaf
  | Node 'k 'v "('k, 'v) Tree" "('k, 'v) Tree"

(* Binary-search lookup of x. Keys are compared only through the projection
   "key": a node matches when key x = key k, otherwise the search continues in
   the left subtree if key x is smaller and in the right subtree if it is larger.
   A Leaf yields None. For trees satisfying lookup_tree_valid this agrees with
   map_of on the flattened tree (see lookup_tree_to_list_of). *)
primrec lookup_tree :: "('k \<Rightarrow> 'ok :: linorder) \<Rightarrow> ('k, 'v) Tree \<Rightarrow> 'k \<Rightarrow> 'v option"
  where
    "lookup_tree key Leaf x = None"
  | "lookup_tree key (Node k v l r) x =
       (if key x = key k then Some v
        else if key x < key k then lookup_tree key l x
        else lookup_tree key r x)"
97
98text \<open>
99  Predicate for well-formed lookup trees.
100  This states that the keys are distinct and appear in ascending order.
101  It also returns the lowest and highest keys in the tree (or None for empty trees).
102\<close>
(* Result is (valid, range).
   - For a Node, valid holds iff both subtrees are valid, the left subtree's
     highest key (if any) is below the node key, and the node key is below the
     right subtree's lowest key (if any). Comparisons are on "key" projections.
   - range is None for Leaf. For a Node it is Some (lowest, highest), where
     lowest is the left subtree's lowest key (or the node key if the left
     subtree is empty) and highest is the right subtree's highest key (or the
     node key if the right subtree is empty). *)
primrec lookup_tree_valid ::
        "('k \<Rightarrow> 'ok :: linorder) \<Rightarrow> ('k, 'v) Tree \<Rightarrow> bool \<times> ('k \<times> 'k) option"
  where
    "lookup_tree_valid key Leaf = (True, None)"
  | "lookup_tree_valid key (Node k v lt rt) =
       (let (lt_valid, lt_range) = lookup_tree_valid key lt;
            (rt_valid, rt_range) = lookup_tree_valid key rt;
            lt_low = (case lt_range of None \<Rightarrow> k | Some (low, high) \<Rightarrow> low);
            rt_high = (case rt_range of None \<Rightarrow> k | Some (low, high) \<Rightarrow> high)
        in (lt_valid \<and> rt_valid \<and>
            (case lt_range of None \<Rightarrow> True | Some (low, high) \<Rightarrow> key high < key k) \<and>
            (case rt_range of None \<Rightarrow> True | Some (low, high) \<Rightarrow> key k < key low),
            Some (lt_low, rt_high)))"
116
(* Equations for lookup_tree_valid, one per combination of Leaf / Node subtrees.
   They are stated as conditional rewrites: each Node case assumes the subtrees
   are already known to be valid with given ranges, and that the keys are
   ordered. The ML function prove_tree_valid simplifies with exactly these
   rules, plus prod.sel and the proved key ordering facts, to show that a
   concrete tree is valid. *)
lemma lookup_tree_valid_simps':
  "lookup_tree_valid key Leaf = (True, None)"
  "lookup_tree_valid key (Node k v Leaf Leaf) = (True, Some (k, k))"
  "\<lbrakk> lookup_tree_valid key (Node lk lv llt lrt) = (True, Some (llow, lhigh));
     key lhigh < key k
   \<rbrakk> \<Longrightarrow> lookup_tree_valid key (Node k v (Node lk lv llt lrt) Leaf) =
           (True, Some (llow, k))"
  "\<lbrakk> lookup_tree_valid key (Node rk rv rlt rrt) = (True, Some (rlow, rhigh));
     key k < key rlow
   \<rbrakk> \<Longrightarrow> lookup_tree_valid key (Node k v Leaf (Node rk rv rlt rrt)) =
           (True, Some (k, rhigh))"
  "\<lbrakk> lookup_tree_valid key (Node lk lv llt lrt) = (True, Some (llow, lhigh));
     lookup_tree_valid key (Node rk rv rlt rrt) = (True, Some (rlow, rhigh));
     key lhigh < key k;
     key k < key rlow
   \<rbrakk> \<Longrightarrow> lookup_tree_valid key (Node k v (Node lk lv llt lrt) (Node rk rv rlt rrt)) =
           (True, Some (llow, rhigh))"
  by auto

(* A valid tree that reports no key range is the empty tree. *)
lemma lookup_tree_valid_empty:
  "lookup_tree_valid key tree = (True, None) \<Longrightarrow> tree = Leaf"
  apply (induct tree)
   apply simp
  apply (fastforce split: prod.splits option.splits if_splits)
  done

(* The reported range of a valid tree is ordered: key low <= key high. *)
lemma lookup_tree_valid_range:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow> key low \<le> key high"
  apply (induct tree arbitrary: low high)
   apply simp
  apply (fastforce split: prod.splits option.splits if_splits)
  done

(* Every key found by lookup_tree in a valid tree lies within the reported range. *)
lemma lookup_tree_valid_in_range:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow>
   lookup_tree key tree k = Some v \<Longrightarrow>
   key k \<in> {key low .. key high}"
  apply (induct tree arbitrary: k v low high)
   apply simp
  apply (fastforce split: prod.splits option.splits if_split_asm
                   dest: lookup_tree_valid_empty lookup_tree_valid_range)
  done

(* Contrapositive of lookup_tree_valid_in_range: keys outside the range are not found. *)
lemma lookup_tree_valid_in_range_None:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow>
   key k \<notin> {key low .. key high} \<Longrightarrow>
   lookup_tree key tree k = None"
  using lookup_tree_valid_in_range by fastforce
165
166text \<open>
167  Flatten a lookup tree into an assoc-list.
168  As long as the tree is well-formed, both forms are equivalent.
169\<close>
(* In-order traversal: left subtree, then the node's (key, value) pair, then
   the right subtree. *)
primrec lookup_tree_to_list :: "('k, 'v) Tree \<Rightarrow> ('k \<times> 'v) list"
  where
    "lookup_tree_to_list Leaf = []"
  | "lookup_tree_to_list (Node k v lt rt) =
        lookup_tree_to_list lt @ [(k, v)] @ lookup_tree_to_list rt"

(* Every pair in the flattened list of a valid tree has its key within the tree's range. *)
lemma lookup_tree_to_list_range:
  "lookup_tree_valid key tree = (True, Some (low, high)) \<Longrightarrow>
   (k, v) \<in> set (lookup_tree_to_list tree) \<Longrightarrow>
   key k \<in> {key low .. key high}"
  apply (induct tree arbitrary: k v low high)
   apply simp
  apply (fastforce split: prod.splits option.splits if_split_asm
                   dest: lookup_tree_valid_empty lookup_tree_valid_range)
  done

(* The flattened list of a valid tree has no duplicate (key, value) pairs and is
   sorted by key. The two conjuncts are split into separate lemmas below. *)
lemma lookup_tree_dom_distinct_sorted_:
  "fst (lookup_tree_valid key tree) \<Longrightarrow>
   distinct (lookup_tree_to_list tree) \<and> sorted (map (key o fst) (lookup_tree_to_list tree))"
  apply (induct tree)
   apply simp
  apply (fastforce simp: sorted_append
                   split: prod.splits option.splits if_splits
                   dest: lookup_tree_valid_empty lookup_tree_valid_range
                         lookup_tree_valid_in_range lookup_tree_to_list_range
                   elim: lookup_tree_valid_in_range_None)
  done

lemmas lookup_tree_dom_distinct = lookup_tree_dom_distinct_sorted_[THEN conjunct1]
lemmas lookup_tree_dom_sorted = lookup_tree_dom_distinct_sorted_[THEN conjunct2]
200
(* This goal is eta-expanded and flipped, which seems to help its proof *)
lemma lookup_tree_to_list_of_:
  "fst (lookup_tree_valid key tree) \<Longrightarrow>
   map_of (map (apfst key) (lookup_tree_to_list tree)) (key k) = lookup_tree key tree k"
  apply (induct tree arbitrary: k)
   apply simp
  (* this big blob just does case distinctions of both subtrees and
     all possible lookup results within each, then solves *)
  (* slow 10s *)
  apply (fastforce simp: apfst_def map_prod_def map_add_def
                   split: prod.splits option.splits if_splits
                   dest: lookup_tree_valid_empty lookup_tree_valid_range lookup_tree_valid_in_range
                   elim: lookup_tree_valid_in_range_None)
  done

(* Standard form of above: on a valid tree, lookup_tree is map_of of the
   flattened list with keys projected by "key", composed with "key". *)
lemma lookup_tree_to_list_of:
  "fst (lookup_tree_valid key tree) \<Longrightarrow>
   lookup_tree key tree = map_of (map (apfst key) (lookup_tree_to_list tree)) o key"
  apply (rule ext)
  apply (simp add: lookup_tree_to_list_of_)
  done

(* For an injective key projection, projecting the keys of an assoc-list
   commutes with map_of. *)
lemma map_of_key:
  "inj key \<Longrightarrow> map_of (map (apfst key) binds) o key = map_of binds"
  apply (rule ext)
  apply (induct binds)
   apply simp
  apply (clarsimp simp: inj_def dom_def)
  done

(* If binds is the flattened list of a valid tree, the projected keys of
   binds are distinct. *)
lemma lookup_tree_to_list_of_distinct:
  "\<lbrakk> fst (lookup_tree_valid key tree);
     lookup_tree_to_list tree = binds;
     lookup_tree key tree = map_of (map (apfst key) binds) o key
   \<rbrakk> \<Longrightarrow> distinct (map (key o fst) binds)"
  apply (drule sym[where t = binds])
  apply clarsimp
  apply (thin_tac "binds = _")
  apply (induct tree)
   apply simp
  apply (fastforce simp: map_add_def lookup_tree_to_list_of
                   split: prod.splits option.splits if_splits
                   dest: lookup_tree_valid_empty lookup_tree_valid_range
                         lookup_tree_valid_in_range lookup_tree_to_list_range
                   elim: lookup_tree_valid_in_range_None)
  done

(* Top-level rule for converting to lookup list.
   We add a distinctness assertion for inferring the range of values. *)
(* The premises are discharged in order by the ML code: inj key,
   tree validity, and evaluation of lookup_tree_to_list on the concrete tree. *)
lemma lookup_tree_to_list_of_gen:
  "\<lbrakk> inj key;
     fst (lookup_tree_valid key tree);
     lookup_tree_to_list tree = binds
   \<rbrakk> \<Longrightarrow> lookup_tree key tree = map_of binds \<and> distinct (map fst binds)"
  using lookup_tree_to_list_of
  apply (fastforce intro: lookup_tree_to_list_of_distinct
                   simp: map_of_key distinct_inj)
  done
260
261text \<open>
262  Domain and range of a @{const map_of}.
263  Like @{thm dom_map_of_conv_image_fst} but leaving out the set bloat.
264\<close>
(* The domain of map_of is the set of first components, stated as "set (map fst xs)". *)
lemma dom_map_of_conv_list:
  "dom (map_of xs) = set (map fst xs)"
  by (simp add: dom_map_of_conv_image_fst)

(* With distinct keys, the range of map_of is the set of second components. *)
lemma ran_map_of_conv_list:
  "distinct (map fst xs) \<Longrightarrow> ran (map_of xs) = set (map snd xs)"
  by (erule distinct_map_via_ran)
272
273text \<open>
274  Read lookup rules from a @{const map_of}.
275\<close>
276
(* If m is map_of binds with distinct keys, every binding (k, v) of binds is a
   lookup fact m k = Some v. The ML code unfolds the list_all with
   list_all_dest to obtain one theorem per binding. *)
lemma map_of_lookups:
  "m = map_of binds \<and> distinct (map fst binds) \<Longrightarrow>
   list_all (\<lambda>(k, v). m k = Some v) binds"
  apply (induct binds)
   apply simp
  apply (force simp: list_all_iff)
  done

(* Helper for converting from maps defined as @{const fun_upd} chains,
 * which are applied in reverse order *)
lemma map_of_rev:
  "distinct (map fst binds) \<Longrightarrow>
   map_of (rev binds) = map_of binds"
  apply (subgoal_tac "distinct (map fst (rev binds))")
   apply (rule ext)
   apply (induct binds)
    apply simp
   apply (force simp: map_add_def split: option.splits)
  apply (metis distinct_rev rev_map)
  done

(* Unfold list_all over a non-empty list one element at a time. These are
   meta-equations so that the ML code can use them with fo_rewr_conv. *)
lemma list_all_dest:
  "list_all P [(x, y)] \<equiv> P (x, y)"
  "list_all P ((x, y) # z # xs) \<equiv> (P (x, y) \<and> list_all P (z # xs))"
  by auto

(* Install lookup rules that don't depend on if_cong/if_weak_cong setup *)
(* One conditional rule per comparison outcome. These replace
   lookup_tree.simps as default simp rules in the declarations after the locale. *)
lemma lookup_tree_simps':
  "lookup_tree key Leaf x = None"
  "key x = key k \<Longrightarrow> lookup_tree key (Node k v l r) x = Some v"
  "key x < key k \<Longrightarrow> lookup_tree key (Node k v l r) x = lookup_tree key l x"
  "key x > key k \<Longrightarrow> lookup_tree key (Node k v l r) x = lookup_tree key r x"
  by auto
310end
311
(* Use the conditional rules lookup_tree_simps' as default simp rules for
   lookup_tree, instead of the if-then-else equations generated by primrec. *)
declare FastMap.lookup_tree.simps[simp del]
declare FastMap.lookup_tree_simps'[simp]
314
315ML \<open>
316structure FastMap = struct
317
318(* utils *)
(* Build / destruct the HOL type "'a option".
   The type name comes from the @{type_name} antiquotation, as in
   define_partial_map_tree, instead of a hard-coded string. The name is therefore
   checked against the theory when this ML block is compiled.
   dest_optionT raises TYPE for any other type. *)
fun mk_optionT typ = Type (@{type_name option}, [typ])
fun dest_optionT (Type (@{type_name option}, [typ])) = typ
  | dest_optionT t = raise TYPE ("dest_optionT", [t], [])
322
(* Turn a theorem "x = y" (HOL equality) into "x == y" (meta-equality).
   Equivalent to thm RS eq_reflection, but eq_reflection is instantiated
   directly with x and y and discharged by implies_elim, which avoids
   the cost of resolution. *)
fun then_eq_reflection thm =
  let
    val (lhs, rhs) = thm |> Thm.cprop_of |> Thm.dest_arg |> Thm.dest_binop
    val inst_rule =
      Thm.instantiate' [SOME (Thm.ctyp_of_cterm lhs)] [SOME lhs, SOME rhs] @{thm eq_reflection}
  in
    Thm.implies_elim inst_rule thm
  end;
329
(* For a binary application "f x y": lhs_conv cv rewrites x, rhs_conv cv rewrites y. *)
fun lhs_conv cv = Conv.fun_conv (Conv.arg_conv cv)
fun rhs_conv cv = Conv.arg_conv cv

(* First-order rewrite conversion. The left-hand side of rule is matched
   against ct with Thm.first_order_match, and rule is instantiated accordingly.
   rule may be a meta-equation (x == y) or a HOL equation (x = y). A HOL
   equation is turned into a meta-equation by then_eq_reflection after
   instantiation. Raises CTERM if the left-hand side does not match.
   Unlike Conv.rewr_conv, no beta-conversion is applied to the result. *)
fun fo_rewr_conv rule ct =
  let
    fun instantiate_for lhs = Thm.instantiate (Thm.first_order_match (lhs, ct)) rule
    val (is_meta_eq, eqn) =
      ((true, instantiate_for (Thm.lhs_of rule))
        handle TERM _ =>
          (false, instantiate_for (fst (Thm.dest_binop (Thm.dest_arg (Thm.cprop_of rule))))))
      handle Pattern.MATCH => raise CTERM ("fo_rewr_conv", [Thm.cprop_of rule, ct])
  in
    if is_meta_eq then eqn else then_eq_reflection eqn
  end;

(* Try each rule in order and use the first whose left-hand side matches. *)
fun fo_rewrs_conv rules = rules |> map fo_rewr_conv |> Conv.first_conv;

(* Rewrite a term with the given rules in a single top-down pass. Each
   subterm is tried once, and failures leave the subterm unchanged. Unlike
   the simplifier this does not iterate to a fixpoint. That is enough for
   tasks like pushing List.map through a list, and it is much faster. *)
fun fo_topdown_rewr_conv rules ctxt =
  let
    val rewrite_here = Conv.try_conv (fo_rewrs_conv rules)
  in
    Conv.top_conv (fn _ => rewrite_here) ctxt
  end
348
(* Sequential composition of conversions where the second conversion is given
   as a thunk. The thunk is only forced after the first conversion has run,
   so the second conversion may refer recursively to the conversion being
   defined. Reflexive steps are dropped instead of being chained with
   Thm.transitive. *)
infix 1 then_conv'
fun (cv1 then_conv' mk_cv2) ct =
  let
    val first_eq = cv1 ct
    val second_eq = mk_cv2 () (Thm.rhs_of first_eq)
  in
    (case (Thm.is_reflexive first_eq, Thm.is_reflexive second_eq) of
      (true, _) => second_eq
    | (false, true) => first_eq
    | (false, false) => Thm.transitive first_eq second_eq)
  end;
360
(*
 * Helper that makes it easier to describe where to apply a conv.
 * This takes a skeleton term and applies the conversion wherever "HERE"
 * appears in the skeleton.
 *
 * The skeleton is walked in parallel with ct:
 *   - Free "HERE" of any type: apply conv to the corresponding subterm.
 *     conv is called without a context.
 *   - application / abstraction: ct must have the same shape; recurse into
 *     both sides, or under the binder via Conv.abs_conv;
 *   - Const: ct must be a Const with the same name; it is left unchanged;
 *   - anything else, such as other Frees like "dummy1", Vars or Bounds:
 *     matches any subterm, which is left unchanged.
 * A shape or Const-name mismatch raises TERM "conv_at mismatch". The
 * exception carries the mismatching skeleton part and subterm, plus the
 * whole skeleton and term.
 * Example: the skeleton Trueprop (HERE = dummy1) rewrites the left-hand
 * side of an equation.
 *
 * FIXME: use HOL-Library.Rewrite instead
 *)
fun conv_at skel conv ctxt ct = let
  fun mismatch current_skel current_ct =
    raise TERM ("conv_at mismatch", [current_skel, Thm.term_of current_ct, skel, Thm.term_of ct])

  fun walk (Free ("HERE", _)) ctxt ct = conv ct
    | walk (skel as skel_f $ skel_x) ctxt ct =
        (case Thm.term_of ct of
            f $ x => Conv.combination_conv (walk skel_f ctxt) (walk skel_x ctxt) ct
          | _ => mismatch skel ct)
    | walk (skel as Abs (_, _, skel_body)) ctxt ct =
        (case Thm.term_of ct of
            Abs _ => Conv.abs_conv (fn (v, ctxt') => walk skel_body ctxt') ctxt ct
          | _ => mismatch skel ct)
    (* Also check that Consts match the skeleton pattern *)
    | walk (skel as Const (skel_name, _)) ctxt ct =
        if (case Thm.term_of ct of Const (name, _) => name = skel_name | _ => false)
        then Thm.reflexive ct
        else mismatch skel ct
    (* Default case *)
    | walk _ ctxt ct = Thm.reflexive ct
  in walk skel ctxt ct end
389
(* Tactic: rewrite the first subgoal at the positions marked by HERE in pat,
   using conv_at. Produces exactly one result state. A failing conversion
   raises an exception rather than making the tactic fail. *)
fun gconv_at_tac pat conv ctxt st =
  Seq.succeed (Conv.gconv_rule (conv_at pat conv ctxt) 1 st)
391
392
(* Tree builder code, copied from StaticFun *)

(* Build a balanced tree from a list of (key, value) pairs. The in-order
   traversal of the result is the input order. The element at position
   (length div 2) becomes the root, and the elements before and after it
   form the left and right subtrees. mk_node receives key, value, left
   subtree and right subtree. Theta (n lg(n)): every level splits its lists
   with length and chop. *)
fun build_tree' _ mk_leaf [] = mk_leaf
  | build_tree' mk_node mk_leaf xs =
      (case chop (length xs div 2) xs of
        (_, []) => error "build_tree': impossible"
      | (lower, (k, v) :: upper) =>
          mk_node k v (build_tree' mk_node mk_leaf lower)
                      (build_tree' mk_node mk_leaf upper))

(* Build a FastMap.Tree term from a non-empty list of (key, value) terms.
   The key and value types are taken from the first pair. *)
fun build_tree [] = error "build_tree : empty"
  | build_tree (xs as (idx, v) :: _) =
      let
        val idxT = fastype_of idx
        val vT = fastype_of v
        val treeT = Type (@{type_name FastMap.Tree}, [idxT, vT])
        val leaf = Const (@{const_name FastMap.Leaf}, treeT)
        val node = Const (@{const_name FastMap.Node},
                          idxT --> vT --> treeT --> treeT --> treeT)
      in
        build_tree' (fn a => fn b => fn l => fn r => node $ a $ b $ l $ r) leaf xs
      end
417
(* Define the constant map_name as "lookup_tree ord_term <tree>". <tree> is
   the balanced tree built from mappings, and the key and value types come
   from the first mapping. Returns the new constant, its defining theorem
   and the updated local theory. *)
fun define_partial_map_tree map_name mappings ord_term ctxt =
  let
    val (idxT, vT) = apply2 fastype_of (hd mappings)
    val treeT = Type (@{type_name FastMap.Tree}, [idxT, vT])
    val resultT = Type (@{type_name option}, [vT])
    val lookup =
      Const (@{const_name FastMap.lookup_tree},
             fastype_of ord_term --> treeT --> idxT --> resultT)
    val spec = ((map_name, NoSyn),
                ((Thm.def_binding map_name, []), lookup $ ord_term $ build_tree mappings))
    val ((map_const, (_, map_def)), ctxt') = Local_Theory.define spec ctxt
  in
    ((map_const, map_def), ctxt')
  end
430
(* Prove key ordering theorems "key k1 < key k2" for every adjacent pair of
   mappings, in list order, using simp with the simpset from simp_ctxt.
   This lets us issue precise error messages when the user gives us keys
   whose ordering cannot be verified. The theorems are also needed to prove
   the lookup_tree_valid property. On failure this raises TERM, naming the
   0-based positions of the offending pair and carrying the unproved
   proposition. *)
fun prove_key_ord_thms tree_name keyT mappings get_key simp_ctxt ctxt =
  let
    val solver = simp_tac (simp_ctxt ctxt [] []) 1
    val less_const = Const (@{const_name less}, keyT --> keyT --> HOLogic.boolT)
    fun prove_adjacent (i, ((k1, _), (k2, _))) =
      let
        val prop = HOLogic.mk_Trueprop (less_const $ (get_key $ k1) $ (get_key $ k2))
      in
        (case try (Goal.prove ctxt [] [] prop) (K solver) of
          SOME thm => thm
        | NONE =>
            raise TERM (tree_name ^ ": failed to prove less-than ordering for keys #" ^
                        string_of_int i ^ ", #" ^ string_of_int (i + 1),
                        [prop]))
      end
    val adjacent_pairs = fst (split_last mappings) ~~ tl mappings
  in
    map_index prove_adjacent adjacent_pairs
  end;
450
(* Prove "fst (lookup_tree_valid get_key <tree_term>)". The key ordering
   theorems are proved first, so badly ordered keys are reported precisely.
   The simplifier then runs with only prod.sel, lookup_tree_valid_simps'
   and those ordering facts. *)
fun prove_tree_valid tree_name mappings kT keyT tree_term get_key simp_ctxt ctxt =
  let
    val key_ord_thms = prove_key_ord_thms tree_name keyT mappings get_key simp_ctxt ctxt
    val resultT = HOLogic.mk_prodT (HOLogic.boolT, mk_optionT (HOLogic.mk_prodT (kT, kT)))
    val valid_const =
      Const (@{const_name FastMap.lookup_tree_valid},
             (kT --> keyT) --> fastype_of tree_term --> resultT)
    val fst_const = Const (@{const_name fst}, resultT --> HOLogic.boolT)
    val goal = HOLogic.mk_Trueprop (fst_const $ (valid_const $ get_key $ tree_term))
    val valid_simps = @{thms prod.sel FastMap.lookup_tree_valid_simps'} @ key_ord_thms
    val tac = simp_tac (put_simpset HOL_basic_ss ctxt addsimps valid_simps) 1
  in
    Goal.prove ctxt [] [] goal (K tac)
  end
466
(* Run simp_tac on subgoal i, which must solve it completely. If any subgoal
   remains, raise TERM "<name>: unsolved" carrying the original and the
   leftover goal. *)
fun solve_simp_tac name ctxt = SUBGOAL (fn (goal, i) =>
  let
    fun unsolved (leftover, _) = raise TERM (name ^ ": unsolved", [goal, leftover])
  in
    (simp_tac ctxt THEN_ALL_NEW SUBGOAL unsolved) i
  end)
470
(* Prove "map_const = map_of <mappings> \<and> distinct (map fst <mappings>)".

   Proof outline. Each step is wrapped in TIMED, which prints a timing message:
     1. unfold:        replace map_const with its definition (map_def);
     2. main rule:     resolve with lookup_tree_to_list_of_gen, leaving the
                       subgoals inj get_key, fst (lookup_tree_valid ...) and
                       lookup_tree_to_list <tree> = <mappings>;
     3. solve inj:     prove injectivity with simp (inj_def is hardcoded),
                       using the simpset built by simp_ctxt;
     4. resolve valid: discharge validity with tree_valid_thm;
     5. convert tree:  evaluate lookup_tree_to_list on the concrete tree with
                       the hand-written conversion below, then close the
                       remaining equation with refl.
   Raises an exception (from Goal.prove) if any step fails. *)
fun convert_to_lookup_list kT valT mappings map_const map_def tree_valid_thm simp_ctxt ctxt = let
  val lookupT = fastype_of map_const
  (* map_eq = "<tree_const> = map_of <mappings>" *)
  val bindT = HOLogic.mk_prodT (kT, valT)
  val lookup_list = HOLogic.mk_list bindT (map HOLogic.mk_prod mappings)
  val map_of_Const = Const (@{const_name map_of}, HOLogic.listT bindT --> lookupT)
  val map_eq = HOLogic.mk_eq (map_const, map_of_Const $ lookup_list)
  (* distinct_pred = "distinct (map fst <mappings>)" *)
  val distinct_pred =
        Const (@{const_name distinct}, HOLogic.listT kT --> HOLogic.boolT) $
          (Const (@{const_name map}, (bindT --> kT) --> HOLogic.listT bindT --> HOLogic.listT kT) $
             Const (@{const_name fst}, bindT --> kT) $
             lookup_list)
  val convert_prop = HOLogic.mk_Trueprop (
        HOLogic.mk_conj (map_eq, distinct_pred)
        )
  (* Wrap a tactic so that computing its first result is timed with
     Timing.timeap_msg; the timing message names desc. *)
  fun TIMED desc tac = fn st =>
        Seq.make (K (Timing.timeap_msg ("tactic timing for " ^ desc)
                          (fn () => Seq.pull (tac st)) ()))

  val append_basecase = @{thm append.simps(1)}
  val append_rec = @{thm append.simps(2)}
  val lookup_tree_to_list_basecase = @{thm FastMap.lookup_tree_to_list.simps(1)}
  (* Node equation with the singleton append pre-simplified, i.e.
     "to_list (Node k v lt rt) = to_list lt @ (k, v) # to_list rt" *)
  val lookup_tree_to_list_rec = @{thm FastMap.lookup_tree_to_list.simps(2)[simplified append.simps]}

  (* Evaluate lookup_tree_to_list on a concrete tree in one pass.
     to_map_conv rewrites Leaf to [], or a Node with the equation above. It
     then evaluates the left subtree (lhs of @) and the right subtree (tail
     of the Cons on the rhs of @). Finally append_conv pushes @ through the
     now-explicit left list. then_conv' defers building the recursive
     conversions until they are needed. *)
  val lookup_tree_to_list_eval = let
    fun append_conv () =
            fo_rewr_conv append_basecase else_conv
            (fo_rewr_conv append_rec then_conv'
             (fn () => rhs_conv (append_conv ())))
    fun to_map_conv () =
            fo_rewr_conv lookup_tree_to_list_basecase else_conv
            (fo_rewr_conv lookup_tree_to_list_rec then_conv'
             (fn () => lhs_conv (to_map_conv ())) then_conv'
             (fn () => rhs_conv (rhs_conv (to_map_conv ()))) then_conv
             append_conv ())
  in to_map_conv () end

  val solver =
        TIMED "unfold" (gconv_at_tac @{term "Trueprop (HERE = map_of dummy1 \<and> dummy2)"}
                                     (K map_def) ctxt)
        THEN
        TIMED "main rule" (resolve_tac ctxt @{thms FastMap.lookup_tree_to_list_of_gen} 1)
        THEN
          TIMED "solve inj" (solve_simp_tac "solve inj"
                              (simp_ctxt ctxt @{thms simp_thms} @{thms inj_def}) 1)
          THEN
         TIMED "resolve valid" (resolve_tac ctxt [tree_valid_thm] 1)
         THEN
        TIMED "convert tree" (gconv_at_tac @{term "Trueprop (HERE = dummy1)"}
                                           lookup_tree_to_list_eval ctxt)
        THEN
        resolve_tac ctxt @{thms refl} 1
  val convert_thm = Goal.prove ctxt [] [] convert_prop (K solver)
  in convert_thm end
526
(* Obtain domain and range from lookup list.
   Proves "dom_ran_const map_const = set xs", where dom_ran_const is dom or
   ran and xs are terms of type xT. lookup_list_eqn is "map_const = map_of
   <list>". map_const is rewritten with it, then intro_conv_tac turns the
   goal into one about "set (map fst/snd <list>)". The maps and projections
   are evaluated in a single top-down pass, and the goal is closed by refl. *)
fun domain_range_common dom_ran_const xT xs map_const lookup_list_eqn intro_conv_tac ctxt =
  let
    val setT = HOLogic.mk_setT xT
    val dom_ran_term = Const (dom_ran_const, fastype_of map_const --> setT) $ map_const
    val set_term = Const (@{const_name set}, HOLogic.listT xT --> setT) $ HOLogic.mk_list xT xs
    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (dom_ran_term, set_term))
    val map_meta_eq = then_eq_reflection lookup_list_eqn
    val eval_projections = fo_topdown_rewr_conv @{thms list.map prod.sel} ctxt
    val tac =
      gconv_at_tac @{term "Trueprop (dom_ran_dummy1 HERE = dummy2)"} (K map_meta_eq) ctxt
      THEN intro_conv_tac
      THEN gconv_at_tac @{term "Trueprop (HERE = dummy1)"} eval_projections ctxt
      THEN resolve_tac ctxt @{thms refl} 1
  in
    Goal.prove ctxt [] [] goal (K tac)
  end;
548
(* Prove "dom map_const = set <keys of mappings>". *)
fun tree_domain kT mappings map_const lookup_list_eqn ctxt =
  let
    (* like (subst dom_map_of_conv_list) but faster *)
    val intro_tac = resolve_tac ctxt @{thms FastMap.dom_map_of_conv_list[THEN trans]} 1
  in
    domain_range_common @{const_name dom} kT (map fst mappings) map_const
      lookup_list_eqn intro_tac ctxt
  end;

(* Prove "ran map_const = set <values of mappings>". The distinctness
   premise of ran_map_of_conv_list is discharged with map_distinct_thm. *)
fun tree_range valT mappings map_const lookup_list_eqn map_distinct_thm ctxt =
  let
    (* like (subst ran_map_of_conv_list) but faster *)
    val intro_tac =
      resolve_tac ctxt @{thms FastMap.ran_map_of_conv_list[THEN trans]} 1
      THEN resolve_tac ctxt [map_distinct_thm] 1
  in
    domain_range_common @{const_name ran} valT (map snd mappings) map_const
      lookup_list_eqn intro_tac ctxt
  end;
563
564
(* Choosing names for the const and its theorems. The constant will be named with
   map_name; Local_Theory.define may also add extra names, e.g. <map_name>_def. *)
type name_opts = {
    map_name: string,           (* the defined constant *)
    tree_valid_thm: string,     (* fst (lookup_tree_valid ...) *)
    to_lookup_list: string,     (* const = map_of <mappings> *)
    keys_distinct_thm: string,  (* distinct (map fst <mappings>) *)
    lookup_thms: string,        (* one "const k = Some v" per mapping *)
    domain_thm: string,         (* dom const = set <keys> *)
    range_thm: string           (* ran const = set <values> *)
};

(* Default naming scheme: every theorem name is map_name plus a fixed suffix. *)
fun name_opts_default (map_name: string): name_opts =
  let
    fun suffixed suffix = map_name ^ suffix
  in
    {map_name = map_name,
     tree_valid_thm = suffixed "_tree_valid",
     to_lookup_list = suffixed "_to_lookup_list",
     keys_distinct_thm = suffixed "_keys_distinct",
     lookup_thms = suffixed "_lookups",
     domain_thm = suffixed "_domain",
     range_thm = suffixed "_range"}
  end;
586
(* Top level interface.

   Defines a constant named #map_name name_opts as "lookup_tree get_key
   <tree>", where <tree> is a balanced tree built from mappings. It then
   proves and stores these theorems under the names given in name_opts:
     tree_valid_thm     fst (lookup_tree_valid get_key <tree>)
     to_lookup_list     <const> = map_of <mappings>
     keys_distinct_thm  distinct (map fst <mappings>)
     lookup_thms        <const> k = Some v, one theorem per mapping
     domain_thm         dom <const> = set <keys>
     range_thm          ran <const> = set <values>
   Returns the updated local theory. Progress and timings go to tracing.

   mappings must be non-empty and listed in strictly increasing order of
   get_key; otherwise an error or TERM exception is raised. extra_simps are
   used for the key ordering and injectivity proofs. With minimal_simp they
   are added to a basic simpset instead of the context's full simpset. *)
fun define_map
            (name_opts: name_opts)
            (mappings: (term * term) list)
            (get_key: term) (* function to get linorder key, must be injective *)
            (extra_simps: thm list)
            (minimal_simp: bool) (* true: start with minimal simpset; extra_simps must be adequate *)
            ctxt = let
    (* Reject empty input up front with a clear message. Otherwise the first
       hd below fails with an anonymous Empty exception. *)
    val _ = if null mappings
            then error (#map_name name_opts ^ ": no mappings given")
            else ()

    fun simp_ctxt ctxt basic_simps more_simps = if minimal_simp
      then put_simpset HOL_basic_ss ctxt addsimps (basic_simps @ extra_simps @ more_simps)
      else ctxt addsimps (extra_simps @ more_simps)

    val (kT, keyT) = dest_funT (fastype_of get_key)
    val valT = fastype_of (snd (hd mappings))

    val _ = tracing (#map_name name_opts ^ ": defining tree")
    val start = Timing.start ()
    val ((map_const, map_def), ctxt) =
            define_partial_map_tree
                (Binding.name (#map_name name_opts))
                mappings get_key ctxt
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    val _ = tracing (#map_name name_opts ^ ": proving tree is well-formed")
    val start = Timing.start ()
    (* map_def is "map_const == lookup_tree get_key <tree>"; extract <tree> *)
    val _ $ _ $ tree_term = Thm.term_of (Thm.rhs_of map_def)
    val tree_valid_thm =
          prove_tree_valid (#map_name name_opts) mappings kT keyT tree_term get_key simp_ctxt ctxt
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#tree_valid_thm name_opts), []), [([tree_valid_thm], [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    val _ = tracing (#map_name name_opts ^ ": converting tree to map")
    val start = Timing.start ()
    val convert_thm =
          convert_to_lookup_list kT valT mappings map_const map_def tree_valid_thm simp_ctxt ctxt
    val [lookup_list_eqn, map_distinct_thm] = HOLogic.conj_elims ctxt convert_thm
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))
    val _ = tracing (#map_name name_opts ^ ": storing map and distinctness theorems")
    val start = Timing.start ()
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#to_lookup_list name_opts), []), [([lookup_list_eqn], [])]),
         ((Binding.name (#keys_distinct_thm name_opts), []), [([map_distinct_thm], [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    val _ = tracing (#map_name name_opts ^ ": obtaining lookup rules")
    val start = Timing.start ()
    (* Unfold "list_all P [b1, ..., bn]" into "P b1 \<and> ... \<and> P bn" *)
    fun dest_list_all_conv () =
          fo_rewr_conv @{thm FastMap.list_all_dest(1)} else_conv
          (fo_rewr_conv @{thm FastMap.list_all_dest(2)} then_conv'
           (fn () => rhs_conv (dest_list_all_conv())))
    val combined_lookup_thm =
          (convert_thm RS @{thm FastMap.map_of_lookups})
          |> Conv.fconv_rule (conv_at @{term "Trueprop HERE"} (dest_list_all_conv ()) ctxt)
    val _ = tracing ("  splitting... " ^ Timing.message (Timing.result start))
    (* Each conjunct has the form "case_prod f (k, v)", with f the lambda
       from map_of_lookups. Rewriting with prod.case via fo_rewr_conv gives
       "f k v", and fo_rewr_conv does no beta-conversion, so that is still
       a beta-redex. Beta-reduce it so the stored lookup theorems have the
       plain form "map_const k = Some v". *)
    val split_lookup_conv =
          fo_rewr_conv @{thm prod.case[THEN eq_reflection]} then_conv Thm.beta_conversion true
    val lookup_thms =
          HOLogic.conj_elims ctxt combined_lookup_thm
          |> map (Conv.fconv_rule (conv_at @{term "Trueprop HERE"} split_lookup_conv ctxt))

    val _ = if length lookup_thms = length mappings then () else
              raise THM ("wrong number of lookup thms: " ^ string_of_int (length lookup_thms) ^
                         " instead of " ^ string_of_int (length mappings), 0,
                         lookup_thms)
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#lookup_thms name_opts), []), [(lookup_thms, [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))

    (* domain and range *)
    val _ = tracing (#map_name name_opts ^ ": getting domain and range")
    val start = Timing.start ()
    val domain_thm = Timing.timeap_msg "  calculate domain"
            (tree_domain kT mappings map_const lookup_list_eqn) ctxt
    val range_thm = Timing.timeap_msg "  calculate range"
            (tree_range valT mappings map_const lookup_list_eqn map_distinct_thm) ctxt
    val (_, ctxt) = ctxt |> Local_Theory.notes
        [((Binding.name (#domain_thm name_opts), []), [([domain_thm], [])]),
         ((Binding.name (#range_thm name_opts), []), [([range_thm], [])])]
    val _ = tracing ("  done: " ^ Timing.message (Timing.result start))
  in ctxt end
667
668end
669\<close>
670
671end