1(*  Title:      HOL/Tools/BNF/bnf_lfp_countable.ML
2    Author:     Jasmin Blanchette, TU Muenchen
3    Copyright   2014
4
5Countability tactic for BNF datatypes.
6*)
7
signature BNF_LFP_COUNTABLE =
sig
  (* Countability tactic: tries intro_classes, then proves every remaining subgoal of the shape
     EX f. inj_on f top for the datatypes occurring in those subgoals. *)
  val countable_datatype_tac: Proof.context -> tactic

  (* For datatype names from one mutually recursive group, proves that the generated
     nat-valued encoding function of each datatype in the group is injective. *)
  val derive_encode_injectives_thms: Proof.context -> string list -> thm list
end;
13
structure BNF_LFP_Countable : BNF_LFP_COUNTABLE =
struct

open BNF_FP_Rec_Sugar_Util
open BNF_Def
open BNF_Util
open BNF_Tactics
open BNF_FP_Util
open BNF_FP_Def_Sugar

(* The sort countable: required of every type encoded via to_nat, and added to the sort of
   each type argument of the datatypes in derive_encode_injectives_thms. *)
val countableS = \<^sort>\<open>countable\<close>;

(* Case analysis on the first subgoal with the constructor nchotomy theorem of a datatype:
   the theorem is composed with all_reg (premises rotated) and resolved with the goal.  Then,
   on every emerging subgoal, allI/impI are applied as introduction rules and exE/disjE as
   elimination rules, repeatedly. *)
fun nchotomy_tac ctxt nchotomy =
  HEADGOAL (resolve_tac ctxt [nchotomy RS @{thm all_reg[rotated]}] THEN'
    REPEAT_ALL_NEW (resolve_tac ctxt [allI, impI] ORELSE' eresolve_tac ctxt [exE, disjE]));

(* Applies meta_spec as a destruction rule depth times, then depth times applies meta_mp as
   a destruction rule and discharges the exposed premise by assumption.  For depth 0 it does
   nothing. *)
fun meta_spec_mp_tac _ 0 = K all_tac
  | meta_spec_mp_tac ctxt depth =
    dtac ctxt meta_spec THEN' meta_spec_mp_tac ctxt (depth - 1) THEN'
    dtac ctxt meta_mp THEN' assume_tac ctxt;

(* Tries to close a subgoal with an induction hypothesis.  DEEPEN retries the inner tactic
   with increasing depth (start 0, step 1, bound 64).  The inner tactic:
   - specializes a hypothesis with meta_spec_mp_tac;
   - eliminates its object-level quantifier and implication with allE and impE;
   - closes both resulting subgoals by assumption. *)
fun use_induction_hypothesis_tac ctxt =
  DEEPEN (1, 64 (* large number *))
    (fn depth => meta_spec_mp_tac ctxt depth THEN' etac ctxt allE THEN' etac ctxt impE THEN'
      assume_tac ctxt THEN' assume_tac ctxt) 0;

(* Rewrite rules for comparing two encodings: same_ctr_simps when both values are built with
   the same constructor, distinct_ctrs_simps when they use different constructors. *)
val same_ctr_simps = @{thms sum_encode_eq prod_encode_eq sum.inject prod.inject to_nat_split
  id_apply snd_conv simp_thms};
val distinct_ctrs_simps = @{thms sum_encode_eq sum.inject sum.distinct simp_thms};

(* Same-constructor case of the injectivity proof.  Steps:
   1. Simplify the first subgoal with the constructor injectivity rules, the recursor
      equations, the map_comp rules of nested BNFs and same_ctr_simps.
   2. Only if that does not reduce the number of subgoals (THEN_MAYBE'), try to split the
      conclusion into conjuncts.
   3. On every emerging subgoal, eliminate conjunctive premises and apply the inj_map_strong
      rules of nested BNFs.
   4. Close each remaining subgoal by assumption or by an induction hypothesis. *)
fun same_ctr_tac ctxt injects recs map_comps' inj_map_strongs' =
  HEADGOAL (asm_full_simp_tac
      (ss_only (injects @ recs @ map_comps' @ same_ctr_simps) ctxt) THEN_MAYBE'
    TRY o REPEAT_ALL_NEW (rtac ctxt conjI) THEN_ALL_NEW
    REPEAT_ALL_NEW (eresolve_tac ctxt (conjE :: inj_map_strongs')) THEN_ALL_NEW
    (assume_tac ctxt ORELSE' use_induction_hypothesis_tac ctxt));

(* Distinct-constructor case: simplifies the first subgoal with the recursor equations and
   distinct_ctrs_simps. *)
fun distinct_ctrs_tac ctxt recs =
  HEADGOAL (asm_full_simp_tac (ss_only (recs @ distinct_ctrs_simps) ctxt));

(* Injectivity proof steps for one datatype with n constructors.  For each k in 1..n it runs
   nchotomy_tac once, followed by n tactics for k' = 1..n: same_ctr_tac when k = k',
   distinct_ctrs_tac otherwise.  NOTE(review): presumably the two case splits range over the
   two sides of the encoding equation -- confirm. *)
fun mk_encode_injective_tac ctxt n nchotomy injects recs map_comps' inj_map_strongs' =
  let val ks = 1 upto n in
    EVERY (maps (fn k => nchotomy_tac ctxt nchotomy :: map (fn k' =>
      if k = k' then same_ctr_tac ctxt injects recs map_comps' inj_map_strongs'
      else distinct_ctrs_tac ctxt recs) ks) ks)
  end;

(* Whole injectivity proof for a mutually recursive group: applies the induction rule, then
   mk_encode_injective_tac for each datatype of the group in order.  ns, nchotomys, injectss
   and recss are parallel lists with one entry per datatype. *)
fun mk_encode_injectives_tac ctxt ns induct nchotomys injectss recss map_comps' inj_map_strongs' =
  HEADGOAL (rtac ctxt induct) THEN
  EVERY (@{map 4} (fn n => fn nchotomy => fn injects => fn recs =>
      mk_encode_injective_tac ctxt n nchotomy injects recs map_comps' inj_map_strongs')
    ns nchotomys injectss recss);

(* Last step of the countability proof: unfolds inj_on_def and ball_UNIV, then closes every
   subgoal with exI, allI and one of the given encoding injectivity theorems. *)
fun endgame_tac ctxt encode_injectives =
  unfold_thms_tac ctxt @{thms inj_on_def ball_UNIV} THEN
  ALLGOALS (rtac ctxt exI THEN' rtac ctxt allI THEN' resolve_tac ctxt encode_injectives);

(* Places the nat term t at position k of n in a balanced tree of sums, built with
   Balanced_Tree.access.  Each left step wraps with sum_encode (Inl _) and each right step
   with sum_encode (Inr _). *)
fun encode_sumN n k t =
  Balanced_Tree.access {init = t,
      left = fn t => \<^const>\<open>sum_encode\<close> $ (@{const Inl (nat, nat)} $ t),
      right = fn t => \<^const>\<open>sum_encode\<close> $ (@{const Inr (nat, nat)} $ t)}
    n k;

(* Combines a list of nat terms into one nat term.  The empty list yields 0; otherwise the
   terms are joined in a balanced tree using prod_encode applied to Pair terms. *)
fun encode_tuple [] = \<^term>\<open>0 :: nat\<close>
  | encode_tuple ts =
    Balanced_Tree.make (fn (t, u) => \<^const>\<open>prod_encode\<close> $ (@{const Pair (nat, nat)} $ u $ t)) ts;

(* Builds one nat-valued encoding term per datatype in fpTs.  Each term applies that
   datatype's recursor, instantiated at result type nat, to one encoding argument per
   constructor of the whole group.  Each argument maps the constructor's fields to
   encode_sumN (constructor index) (encode_tuple of the encoded fields). *)
fun mk_encode_funs ctxt fpTs ns ctrss0 recs0 =
  let
    val thy = Proof_Context.theory_of ctxt;

    (* Raises TYPE unless T is of sort countable. *)
    fun check_countable T =
      Sign.of_sort thy (T, countableS) orelse
      raise TYPE ("Type is not of sort " ^ Syntax.string_of_sort ctxt countableS, [T], []);

    (* The constant to_nat at type T => nat, after checking that T is countable. *)
    fun mk_to_nat_checked T =
      Const (\<^const_name>\<open>to_nat\<close>, tap check_countable T --> HOLogic.natT);

    val nn = length ns;
    val recs as rec1 :: _ = map2 (mk_co_rec thy Least_FP (replicate nn HOLogic.natT)) fpTs recs0;
    (* Types of the recursor's function arguments, grouped per constructor. *)
    val arg_Ts = binder_fun_types (fastype_of rec1);
    val arg_Tss = Library.unflat ctrss0 arg_Ts;

    (* Rewrites every product type whose first component is one of fpTs to its second
       component, recursing through all other type constructors.  NOTE(review): presumably
       these pairs are the (value, recursive result) pairs supplied by the recursor for
       nested occurrences -- confirm. *)
    fun mk_U (Type (\<^type_name>\<open>prod\<close>, [T1, T2])) =
        if member (op =) fpTs T1 then T2 else HOLogic.mk_prodT (mk_U T1, mk_U T2)
      | mk_U (Type (s, Ts)) = Type (s, map mk_U Ts)
      | mk_U T = T;

    (* Encoding of the bound variable with de Bruijn index j and type T:
       - a nat is used directly;
       - a direct recursive occurrence is dropped (NONE);
       - a type that nests a recursive occurrence is first mapped to its mk_U image and then
         encoded with to_nat;
       - any other type is encoded with to_nat. *)
    fun mk_nat (j, T) =
      if T = HOLogic.natT then
        SOME (Bound j)
      else if member (op =) fpTs T then
        NONE
      else if exists_subtype_in fpTs T then
        let val U = mk_U T in
          SOME (mk_to_nat_checked U $ (build_map ctxt [] [] (snd_const o fst) (T, U) $ Bound j))
        end
      else
        SOME (mk_to_nat_checked T $ Bound j);

    (* Recursor argument for constructor k (1-based) of a datatype with n constructors.
       It is a lambda over the argument's binder types, with the bound variables innermost
       first.  Its body is encode_sumN n k of the encoded tuple of the kept variables. *)
    fun mk_arg n (k, arg_T) =
      let
        val bound_Ts = rev (binder_types arg_T);
        val nats = map_filter mk_nat (tag_list 0 bound_Ts);
      in
        fold (fn T => fn t => Abs (Name.uu, T, t)) bound_Ts (encode_sumN n k (encode_tuple nats))
      end;

    val argss = map2 (map o mk_arg) ns (map (tag_list 1) arg_Tss);
  in
    map (fn recx => Term.list_comb (recx, flat argss)) recs
  end;

(* Proves injectivity of the encoding functions for a mutually recursive group of datatypes.
   - The first name in fpT_names0 determines the group.
   - An error is raised if that name is not a datatype, or if any given name lies outside the
     group.
   - For each datatype T of the whole group, in group order, the result contains one theorem:
     ALL y. enc_T x = enc_T y --> x = y, where enc_T is the term built by mk_encode_funs.
   - The type arguments are replaced by fresh type variables whose sort additionally includes
     countable.
   - Returns [] for an empty name list. *)
fun derive_encode_injectives_thms _ [] = []
  | derive_encode_injectives_thms ctxt fpT_names0 =
    let
      fun not_datatype_name s = error (quote s ^ " is not a datatype");
      fun not_mutually_recursive ss = error (commas ss ^ " are not mutually recursive datatypes");

      (* fp_sugar of s, which must be a least-fixpoint type with induction sugar. *)
      fun lfp_sugar_of s =
        (case fp_sugar_of ctxt s of
          SOME (fp_sugar as {fp = Least_FP, fp_co_induct_sugar = SOME _, ...}) => fp_sugar
        | _ => not_datatype_name s);

      (* All datatypes of the group of the first given name, each checked by lfp_sugar_of. *)
      val fpTs0 as Type (_, var_As) :: _ =
        map (#T o lfp_sugar_of o fst o dest_Type) (#Ts (#fp_res (lfp_sugar_of (hd fpT_names0))));
      val fpT_names = map (fst o dest_Type) fpTs0;

      (* Fresh type variables replacing the group's type arguments, with countable added to
         their sorts. *)
      val (As_names, _) = Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As) ctxt;
      val As =
        map2 (fn s => fn TVar (_, S) => TFree (s, union (op =) countableS S))
          As_names var_As;
      val fpTs = map (fn s => Type (s, As)) fpT_names;

      val _ = subset (op =) (fpT_names0, fpT_names) orelse not_mutually_recursive fpT_names0;

      (* ALL y. encode_fun x = encode_fun y --> x = y *)
      fun mk_conjunct fpT x encode_fun =
        HOLogic.all_const fpT $ Abs (Name.uu, fpT,
          HOLogic.mk_imp (HOLogic.mk_eq (encode_fun $ x, encode_fun $ Bound 0),
            HOLogic.eq_const fpT $ x $ Bound 0));

      val fp_sugars as
          {fp_nesting_bnfs, fp_co_induct_sugar = SOME {common_co_inducts = induct :: _, ...},
           ...} :: _ =
        map (the o fp_sugar_of ctxt o fst o dest_Type) fpTs0;
      val ctr_sugars = map (#ctr_sugar o #fp_ctr_sugar) fp_sugars;

      val ctrss0 = map #ctrs ctr_sugars;
      val ns = map length ctrss0;
      val recs0 = map (#co_rec o the o #fp_co_induct_sugar) fp_sugars;
      val nchotomys = map #nchotomy ctr_sugars;
      val injectss = map #injects ctr_sugars;
      val rec_thmss = map (#co_rec_thms o the o #fp_co_induct_sugar) fp_sugars;
      (* Facts about nested BNFs:
         - map_comp rules with comp_def unfolded;
         - inj_map_strong rules with premises rotated by Thm.permute_prems 0 ~1. *)
      val map_comps' = map (unfold_thms ctxt @{thms comp_def} o map_comp_of_bnf) fp_nesting_bnfs;
      val inj_map_strongs' = map (Thm.permute_prems 0 ~1 o inj_map_strong_of_bnf) fp_nesting_bnfs;

      val (xs, names_ctxt) = ctxt |> mk_Frees "x" fpTs;

      val conjuncts = @{map 3} mk_conjunct fpTs xs (mk_encode_funs ctxt fpTs ns ctrss0 recs0);
      val goal = HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj conjuncts);
    in
      (* Prove the conjunction of all statements at once.  Then split it into one theorem per
         datatype and export from names_ctxt back to ctxt. *)
      Goal.prove (*no sorry*) ctxt [] [] goal (fn {context = ctxt, prems = _} =>
        mk_encode_injectives_tac ctxt ns induct nchotomys injectss rec_thmss map_comps'
          inj_map_strongs')
      |> HOLogic.conj_elims ctxt
      |> Proof_Context.export names_ctxt ctxt
      |> map (Thm.close_derivation \<^here>)
    end;

(* Returns the type constructor name s of a goal Trueprop (EX f. inj_on f top), where the
   bound f has a type of the shape s(...) => _.  Any other goal shape is an error. *)
fun get_countable_goal_type_name (\<^const>\<open>Trueprop\<close> $ (Const (\<^const_name>\<open>Ex\<close>, _)
    $ Abs (_, Type (_, [Type (s, _), _]), Const (\<^const_name>\<open>inj_on\<close>, _) $ Bound 0
        $ Const (\<^const_name>\<open>top\<close>, _)))) = s
  | get_countable_goal_type_name _ = error "Wrong goal format for datatype countability tactic";

(* Reads the datatype names from all subgoals of st, derives the injectivity theorems for
   them, and finishes with endgame_tac. *)
fun core_countable_datatype_tac ctxt st =
  let val T_names = map get_countable_goal_type_name (Thm.prems_of st) in
    endgame_tac ctxt (derive_encode_injectives_thms ctxt T_names) st
  end;

(* Entry point.  Tries intro_classes first (it may fail), then runs the core tactic.
   NOTE(review): presumably intro_classes turns a countable instance goal into the
   EX f. inj_on f top shape expected by get_countable_goal_type_name -- confirm. *)
fun countable_datatype_tac ctxt =
  TRY (Class.intro_classes_tac ctxt []) THEN core_countable_datatype_tac ctxt;

end;
198