1(*  Title:      HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
2    Author:     Lorenz Panny, TU Muenchen
3    Author:     Jasmin Blanchette, TU Muenchen
4    Copyright   2013
5
6More recursor sugar.
7*)
8
signature BNF_LFP_REC_SUGAR_MORE =
sig
  (* massage_nested_rec_call ctxt has_call massage_fun massage_nonfun bound_Ts y y' t:
     rewrites t so that the variable y is replaced by y', pushing the change through
     nested map/pred applications and function compositions. *)
  val massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
    (typ * typ -> term) -> typ list -> term -> term -> term -> term
end;
14
15structure BNF_LFP_Rec_Sugar_More : BNF_LFP_REC_SUGAR_MORE =
16struct
17
18open BNF_Util
19open BNF_Tactics
20open BNF_Def
21open BNF_FP_Util
22open BNF_FP_Def_Sugar
23open BNF_FP_N2M_Sugar
24open BNF_LFP_Rec_Sugar
25
26(* FIXME: remove "nat" cases throughout once it is registered as a datatype *)
27
(* Basic rewrite rules (composition, identity, pair splitting and pair projections) used
   by special_endgame_tac and registered as the nested_simps of the LFP recursion
   extension. *)
val nested_simps = @{thms o_def[abs_def] id_def split fst_conv snd_conv};
29
(* Closing tactic (registered as special_endgame_tac). It eta-expands every goal, then
   simplifies the first goal with pred_fun_True_id, with the NO_MATCH simproc enabled. It
   next unfolds nested_simps together with the nesting BNFs' map_ident0 / map_comp /
   pred_map facts, after unfolding id_def inside those facts. Finally it solves all goals
   by reflexivity. *)
fun special_endgame_tac ctxt fp_nesting_map_ident0s fp_nesting_map_comps fp_nesting_pred_maps =
  let
    val eta_expand_all = ALLGOALS (CONVERSION Thm.eta_long_conversion);
    val simp_pred_fun_True =
      HEADGOAL (simp_tac (ss_only @{thms pred_fun_True_id} ctxt
        addsimprocs [\<^simproc>\<open>NO_MATCH\<close>]));
    val nesting_facts =
      (fp_nesting_map_ident0s @ fp_nesting_map_comps @ fp_nesting_pred_maps)
      |> map (unfold_thms ctxt @{thms id_def});
    val unfold_nesting = unfold_thms_tac ctxt (nested_simps @ nesting_facts);
  in
    eta_expand_all THEN simp_pred_fun_True THEN unfold_nesting THEN ALLGOALS (rtac ctxt refl)
  end;
38
(* Whether the type constructor named s counts as a datatype here. nat is hard-wired;
   any other s needs a registered fp_sugar that is a least fixpoint and carries
   co-induction sugar. *)
fun is_new_datatype ctxt s =
  s = \<^type_name>\<open>nat\<close> orelse
  (case fp_sugar_of ctxt s of
    SOME {fp = Least_FP, fp_co_induct_sugar = SOME _, ...} => true
  | _ => false);
44
(* Repackages an fp_sugar as a basic LFP sugar record. It adds the recursor's result type C
   and the recursor argument types fun_arg_Tsss. Only sugars with
   fp_co_induct_sugar = SOME _ match; any other sugar raises Match. *)
fun basic_lfp_sugar_of C fun_arg_Tsss
    ({T, fp_res_index, fp_ctr_sugar = {ctr_sugar, ...},
      fp_co_induct_sugar = SOME {co_rec = rec_const, co_rec_thms = rec_facts, ...}, ...}
     : fp_sugar) =
  {C = C, T = T, ctr_sugar = ctr_sugar, fp_res_index = fp_res_index,
   fun_arg_Tsss = fun_arg_Tsss, rec_thms = rec_facts, recx = rec_const};
49
(* Computes the basic LFP sugars for the argument types arg_Ts. The result tuple is
   (missing_arg_Ts, perm0_kks, basic sugars, nesting map_ident0s, nesting map_comps,
   nesting pred_maps, common induction rule, its attributes,
   whether lfp_sugar_thms were produced, lthy).
   nat is special-cased with canned values. For every other input, nested_to_mutual_fps
   reduces nested recursion to mutual recursion first. *)
fun basic_lfp_sugars_of _ [\<^typ>\<open>nat\<close>] _ _ lthy =
    (* TrueI is only a placeholder for the common induction rule *)
    ([], [0], [nat_basic_lfp_sugar], [], [], [], TrueI, [], false, lthy)
  | basic_lfp_sugars_of bs arg_Ts callers callssss0 lthy0 =
    let
      val ((missing_arg_Ts, perm0_kks,
            fp_sugars as {fp_nesting_bnfs,
              fp_co_induct_sugar = SOME {common_co_inducts = [common_induct], ...}, ...} :: _,
            (lfp_sugar_thms, _)), lthy) =
        nested_to_mutual_fps (K true) Least_FP bs arg_Ts callers callssss0 lthy0;

      val induct_attrs =
        (case lfp_sugar_thms of
          NONE => []
        | SOME ((_, _, attrs), _) => attrs);

      val Ts = map #T fp_sugars;
      val Xs = map #X fp_sugars;
      val Cs = map (#fp_co_induct_sugar #> the #> #co_rec #> fastype_of #> body_type) fp_sugars;
      val X_to_TC = Xs ~~ (Ts ~~ Cs);

      (* each mutual type variable X turns into the pair [T, C]; type constructors are
         rebuilt with every argument zipped recursively and tupled *)
      fun zip_XrecT (Type (s, Us)) = [Type (s, map (HOLogic.mk_tupleT o zip_XrecT) Us)]
        | zip_XrecT U =
          (case AList.lookup (op =) X_to_TC U of
            NONE => [U]
          | SOME (T, C) => [T, C]);

      val fun_arg_Tssss = map (#fp_ctr_sugar #> #ctrXs_Tss #> map (map zip_XrecT)) fp_sugars;

      fun of_nesting_bnfs f = map f fp_nesting_bnfs;
      val map_ident0s = of_nesting_bnfs map_ident0_of_bnf;
      val map_comps = of_nesting_bnfs map_comp_of_bnf;
      val pred_maps = of_nesting_bnfs pred_map_of_bnf;

      val basic_sugars = @{map 3} basic_lfp_sugar_of Cs fun_arg_Tssss fp_sugars;
    in
      (missing_arg_Ts, perm0_kks, basic_sugars, map_ident0s, map_comps, pred_maps,
       common_induct, induct_attrs, is_some lfp_sugar_thms, lthy)
    end;
84
(* Raised by massage_map (inside massage_nested_rec_call) when a term is not a map or pred
   application of the expected type constructor. It carries the offending term. *)
exception NO_MAP of term;
86
(* Rewrites the term t0 into a term that uses y' instead of y. In t0, y occurs in argument
   position, possibly underneath map/pred applications and compositions.

   - has_call t: whether t contains a recursive call. Positions that must be call-free are
     checked with it, and offending terms are rejected via unexpected_rec_call_in.
   - massage_fun U T t: applied to non-composition function terms that do contain a call.
     U is the type on the y' side and T the type on the y side.
   - massage_nonfun: handed to build_map to construct call-free conversions between two
     types (massage_no_call below).
   - bound_Ts: types of the loose bound variables, used for fastype_of1.
   Malformed input is reported through unexpected_rec_call_in, invalid_map and
   ill_formed_rec_call. *)
fun massage_nested_rec_call ctxt has_call massage_fun massage_nonfun bound_Ts y y' t0 =
  let
    (* reject recursive calls where only call-free terms are allowed *)
    fun check_no_call t = if has_call t then unexpected_rec_call_in ctxt [t0] t else ();

    val typof = curry fastype_of1 bound_Ts;
    val massage_no_call = build_map ctxt [] [] massage_nonfun;

    val yT = typof y;
    val yU = typof y';

    (* replacement for y: y' passed through the call-free conversion from yU to yT;
       a thunk, so it is only built where y actually occurs *)
    fun y_of_y' () = massage_no_call (yU, yT) $ y';
    val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);

    (* Function argument whose type changes. Left operands of compositions must be
       call-free and are kept as they are. The remaining term goes to massage_fun if it has
       a call; otherwise it is composed with massage_no_call (U, T). *)
    fun massage_mutual_fun U T t =
      (case t of
        Const (\<^const_name>\<open>comp\<close>, _) $ t1 $ t2 =>
        mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
      | _ =>
        if has_call t then massage_fun U T t else mk_comp bound_Ts (t, massage_no_call (U, T)));

    (* t must be a map (or else a pred) of type constructor s applied to its function
       arguments. The map/pred is re-instantiated at the new argument types Us, and each
       argument is massaged recursively. Raises NO_MAP if t is neither, or if U/T are not
       type constructor applications. *)
    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
        (case try (dest_map ctxt s) t of
          SOME (map0, fs) =>
          let
            val Type (_, ran_Ts) = range_type (typof t);
            val map' = mk_map (length fs) Us ran_Ts map0;
            val fs' = map_flattened_map_args ctxt s (@{map 3} massage_map_or_map_arg Us Ts) fs;
          in
            Term.list_comb (map', fs')
          end
        | NONE =>
          (case try (dest_pred ctxt s) t of
            SOME (pred0, fs) =>
            let
              val pred' = mk_pred Us pred0;
              val fs' = map_flattened_map_args ctxt s (@{map 3} massage_map_or_map_arg Us Ts) fs;
            in
              Term.list_comb (pred', fs')
            end
          | NONE => raise NO_MAP t))
      | massage_map _ _ t = raise NO_MAP t
    (* A map argument whose type is unchanged must be call-free. Any other argument is
       rewritten as a nested map/pred, falling back to massage_mutual_fun when it is not
       one. *)
    and massage_map_or_map_arg U T t =
      if T = U then
        tap check_no_call t
      else
        massage_map U T t
        handle NO_MAP _ => massage_mutual_fun U T t;

    (* Entry point. For an application t1 $ t2 that contains a call:
       - if t2 is y, t1 is rewritten as a map/pred (invalid_map if it is not one);
       - if t2 is y applied to call-free arguments xs, t1 is composed with y (mk_compN),
         handled recursively, and reapplied to xs;
       - anything else is ill-formed.
       Call-free applications only get y eliminated. A non-application must be y itself. *)
    fun massage_outer_call (t as t1 $ t2) =
        if has_call t then
          if t2 = y then
            massage_map yU yT (elim_y t1) $ y'
            handle NO_MAP t' => invalid_map ctxt [t0] t'
          else
            let val (g, xs) = Term.strip_comb t2 in
              if g = y then
                if exists has_call xs then unexpected_rec_call_in ctxt [t0] t2
                else Term.list_comb (massage_outer_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
              else
                ill_formed_rec_call ctxt t
            end
        else
          elim_y t
      | massage_outer_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
  in
    massage_outer_call t0
  end;
154
(* Converts a map argument t from operating on T to operating on U. U must be a product
   type whose first component is T; otherwise invalid_map is reported. Presumably U pairs
   a value with its recursive result (TODO confirm with the callers).

   subst d walks t. d tracks the de Bruijn index of the variable bound by the outermost
   abstraction: SOME ~1 before that abstraction has been entered, and NONE once tracking
   has been abandoned.
   - That bound variable is replaced by fst_const U applied to it, and the outermost
     abstraction's type is changed to U.
   - Some applications are headed by a Free for which get_ctr_pos yields a position >= 0
     (presumably a recursive call). Such an application becomes snd_const U applied to
     the tracked variable when that variable sits at the constructor-argument position.
     At the top level, if it has exactly ctr_pos arguments, it becomes
     permute_args ctr_pos (snd_const U) applied to them. Any other such call is rejected
     with rec_call_not_apply_to_ctr_arg. *)
fun rewrite_map_fun ctxt get_ctr_pos U T t =
  let
    val _ =
      (case try HOLogic.dest_prodT U of
        SOME (U1, _) => U1 = T orelse invalid_map ctxt [] t
      | NONE => invalid_map ctxt [] t);

    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const U)
      | subst d (Abs (v, T, b)) =
        (* only the outermost abstraction gets its binder type changed to U *)
        Abs (v, if d = SOME ~1 then U else T, subst (Option.map (Integer.add 1) d) b)
      | subst d t =
        let
          val (u, vs) = strip_comb t;
          (* ~1 when the head is not a Free or get_ctr_pos fails on it *)
          val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1;
        in
          if ctr_pos >= 0 then
            if d = SOME ~1 andalso length vs = ctr_pos then
              Term.list_comb (permute_args ctr_pos (snd_const U), vs)
            else if length vs > ctr_pos andalso is_some d andalso
                d = try (fn Bound n => n) (nth vs ctr_pos) then
              Term.list_comb (snd_const U $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
            else
              rec_call_not_apply_to_ctr_arg ctxt [] t
          else
            (* arguments of a top-level non-call application are not tracked *)
            Term.list_comb (u, map (subst (if d = SOME ~1 then NONE else d)) vs)
        end;
  in
    subst (SOME ~1) t
  end;
184
(* Instantiates massage_nested_rec_call for nested recursive calls. Call-containing
   function arguments are rewritten by rewrite_map_fun. Call-free conversions from a
   type pair (U, _) are built from fst_const U. The result still awaits
   bound_Ts, y, y' and the term. *)
fun rewrite_nested_rec_call ctxt has_call get_ctr_pos =
  let
    val massage_fun = rewrite_map_fun ctxt get_ctr_pos;
    fun massage_nonfun (U, _) = fst_const U;
  in
    massage_nested_rec_call ctxt has_call massage_fun massage_nonfun
  end;
187
(* Registers this module's nested-recursion support with the LFP recursion sugar at
   theory setup time. *)
val _ =
  Theory.setup
    (register_lfp_rec_extension
      {basic_lfp_sugars_of = basic_lfp_sugars_of,
       is_new_datatype = is_new_datatype,
       nested_simps = nested_simps,
       rewrite_nested_rec_call = SOME rewrite_nested_rec_call,
       special_endgame_tac = special_endgame_tac});
192
193end;
194