1(*  Title:      HOL/Tools/BNF/bnf_gfp_grec_sugar_util.ML
2    Author:     Aymeric Bouzy, Ecole polytechnique
3    Author:     Jasmin Blanchette, Inria, LORIA, MPII
4    Copyright   2015, 2016
5
6Library for generalized corecursor sugar.
7*)
8
signature BNF_GFP_GREC_SUGAR_UTIL =
sig
  (* Parse information computed for every corecursive or friend specification: the raw buffer
     instantiated for the outer layer, one guard term per constructor (keyed by constructor
     name), and the raw buffer instantiated for the inner layer. *)
  type s_parse_info =
    {outer_buffer: BNF_GFP_Grec.buffer,
     ctr_guards: term Symtab.table,
     inner_buffer: BNF_GFP_Grec.buffer}

  (* Additional parse information needed for friend specifications: for each constructor name
     a discriminator term with its selector terms, discriminators and selectors keyed by their
     constant names, the "it" term, and a builder of case terms for a given result type. *)
  type rho_parse_info =
    {pattern_ctrs: (term * term list) Symtab.table,
     discs: term Symtab.table,
     sels: term Symtab.table,
     it: term,
     mk_case: typ -> term}

  (* Raised by mk_pointful_natural_from_transfer when no naturality goal can be stated. *)
  exception UNNATURAL of unit

  (* generalize_types max_j T U: a common generalization of T and U in which each distinct pair
     of differing subtypes becomes a schematic type variable with index >= max_j. *)
  val generalize_types: int -> typ -> typ -> typ
  (* Theorem relating an n-ary curried function to its uncurried form over a balanced tuple. *)
  val mk_curry_uncurryN_balanced: Proof.context -> int -> thm
  (* Parametricity goal relating two instances (over fresh type frees) of a constant. *)
  val mk_const_transfer_goal: Proof.context -> string * typ -> term
  (* Transfer theorems for the abs/rep functions attached to a (co)datatype's pre-BNF.
     Both raise Fail "no abs/rep" when the abstract and representation types coincide. *)
  val mk_abs_transfer: Proof.context -> string -> thm
  val mk_rep_transfer: Proof.context -> string -> thm
  (* Derives a pointful naturality theorem from a constant's transfer theorem; may raise
     UNNATURAL. *)
  val mk_pointful_natural_from_transfer: Proof.context -> thm -> thm

  (* Parse information for corecursive specifications (argument types, codatatype result type,
     raw buffer). *)
  val corec_parse_info_of: Proof.context -> typ list -> typ -> BNF_GFP_Grec.buffer -> s_parse_info
  (* Same as corec_parse_info_of, but for friend specifications, which also need the
     rho_parse_info. *)
  val friend_parse_info_of: Proof.context -> typ list -> typ -> BNF_GFP_Grec.buffer ->
    s_parse_info * rho_parse_info
end;
36
37structure BNF_GFP_Grec_Sugar_Util : BNF_GFP_GREC_SUGAR_UTIL =
38struct
39
40open Ctr_Sugar
41open BNF_Util
42open BNF_Tactics
43open BNF_Def
44open BNF_Comp
45open BNF_FP_Util
46open BNF_FP_Def_Sugar
47open BNF_GFP_Grec
48open BNF_GFP_Grec_Tactics
49
(* Combines a list of case functions into nested case_sum terms arranged as a balanced tree,
   mirroring the balanced sum types used throughout this file. *)
fun mk_case_sumN_balanced case_funs = Balanced_Tree.make mk_case_sum case_funs;
51
(* Anti-unifies two types. Type constructors that agree at the same position are kept. Every
   distinct pair of differing subtypes is replaced by a schematic type variable of sort "type",
   and the same pair always yields the same variable. Fresh variables are indexed max_j,
   max_j + 1, ... in order of first occurrence, scanning left to right. *)
fun generalize_types max_j T U =
  let
    (* acc: association list from already-seen (T', U') pairs to their variables *)
    fun fresh_var TU acc =
      (case AList.lookup (op =) acc TU of
        SOME V => (V, acc)
      | NONE =>
        let val V = TVar ((Name.aT, max_j + length acc), \<^sort>\<open>type\<close>) in
          (V, (TU, V) :: acc)
        end);

    fun gen (T1 as Type (s1, Ts1)) (T2 as Type (s2, Ts2)) acc =
        if s1 <> s2 then fresh_var (T1, T2) acc
        else acc |> fold_map (uncurry gen) (Ts1 ~~ Ts2) |>> (fn Vs => Type (s1, Vs))
      | gen T1 T2 acc = if T1 = T2 then (T1, acc) else fresh_var (T1, T2) acc;
  in
    fst (gen T U [])
  end;
70
(* Proves, for n argument types A1, ..., An and a result type B (all fresh type frees),
     !!f g. (uncurried_f = g) = (f = curried_g)
   where f :: A1 => ... => An => B, g takes the balanced tuple of A1, ..., An, uncurried_f is f
   applied through a pattern match on that tuple, and curried_g is g re-curried
   (abs_curried_balanced). The n = 0 case presumably relies on unit_abs_eta_conv below -- TODO
   confirm. *)
fun mk_curry_uncurryN_balanced_raw ctxt n =
  let
    (* n + 1 fresh type frees: the first n are the argument types, the last is the result. *)
    val ((As, B), names_ctxt) = ctxt
      |> mk_TFrees (n + 1)
      |>> split_last;

    val tupled_As = mk_tupleT_balanced As;

    val f_T = As ---> B;
    val g_T = tupled_As --> B;

    val (((f, g), xs), _) = names_ctxt
      |> yield_singleton (mk_Frees "f") f_T
      ||>> yield_singleton (mk_Frees "g") g_T
      ||>> mk_Frees "x" As;

    val tupled_xs = mk_tuple1_balanced As xs;

    val uncurried_f = mk_tupled_fun f tupled_xs xs;
    val curried_g = abs_curried_balanced As g;

    val lhs = HOLogic.mk_eq (uncurried_f, g);
    val rhs =  HOLogic.mk_eq (f, curried_g);
    val goal = fold_rev Logic.all [f, g] (mk_Trueprop_eq (lhs, rhs));

    (* Split the iff, substitute the hypothesis in each direction, then close the remaining
       eta-style equations by rewriting with unit/case_prod eta rules. *)
    fun mk_tac ctxt =
      HEADGOAL (rtac ctxt iffI THEN' dtac ctxt sym THEN' hyp_subst_tac ctxt) THEN
      unfold_thms_tac ctxt @{thms prod.case} THEN
      HEADGOAL (rtac ctxt refl THEN' hyp_subst_tac ctxt THEN'
        REPEAT_DETERM o subst_tac ctxt NONE @{thms unit_abs_eta_conv case_prod_eta} THEN'
        rtac ctxt refl);
  in
    Goal.prove_sorry ctxt [] [] goal (fn {context = ctxt, ...} => mk_tac ctxt)
    |> Thm.close_derivation \<^here>
  end;
106
(* Largest arity whose curry/uncurry theorem is proved once, when this file is loaded. *)
val num_curry_uncurryN_balanced_precomp = 8;
(* Theorems for arities 0 .. num_curry_uncurryN_balanced_precomp, indexed by arity and proved
   in the ML compile-time context. *)
val curry_uncurryN_balanced_precomp =
  map (mk_curry_uncurryN_balanced_raw \<^context>) (0 upto num_curry_uncurryN_balanced_precomp);
110
(* Returns the curry/uncurry theorem for arity n. It proves the theorem on demand for large n
   and otherwise reuses the instance precomputed at load time. *)
fun mk_curry_uncurryN_balanced ctxt n =
  if n > num_curry_uncurryN_balanced_precomp then mk_curry_uncurryN_balanced_raw ctxt n
  else nth curry_uncurryN_balanced_precomp n;
114
(* Builds the parametricity goal for constant s with schematic type var_T.
   - Each schematic type variable of var_T is instantiated once with fresh type frees As (giving
     T) and once with fresh type frees Bs (giving U).
   - Relation variables R :: A => B => bool relate each pair.
   - Fails with an error message if the resulting goal does not type-check. *)
fun mk_const_transfer_goal ctxt (s, var_T) =
  let
    val tvars = Term.add_tvarsT var_T [];
    val tvar_names = map fst tvars;
    val tvar_sorts = map snd tvars;

    val ((As, Bs), names_ctxt) = ctxt
      |> Variable.declare_typ var_T
      |> mk_TFrees' tvar_sorts
      ||>> mk_TFrees' tvar_sorts;

    val (Rs, _) = mk_Frees "R" (map2 mk_pred2T As Bs) names_ctxt;

    val T = Term.typ_subst_TVars (tvar_names ~~ As) var_T;
    val U = Term.typ_subst_TVars (tvar_names ~~ Bs) var_T;

    val goal = mk_parametricity_goal ctxt Rs (Const (s, T)) (Const (s, U));
  in
    if can type_of goal then
      goal
    else
      error ("Cannot transfer constant " ^ quote (Syntax.string_of_term ctxt (Const (s, T))) ^
        " from type " ^ quote (Syntax.string_of_typ ctxt T) ^ " to " ^
        quote (Syntax.string_of_typ ctxt U))
  end;
136
(* Proves the transfer (parametricity) theorem for the abs function of the (co)datatype named
   fpT_name, instantiated at its pre-BNF type with live type variables frozen. Raises
   Fail "no abs/rep" when the abstract and representation types coincide. A missing fp_sugar
   for fpT_name makes the SOME pattern fail (Bind). *)
fun mk_abs_transfer ctxt fpT_name =
  let
    val SOME {pre_bnf, absT_info = {absT, repT, abs, type_definition, ...}, ...} =
      fp_sugar_of ctxt fpT_name;
  in
    if absT <> repT then
      let
        val rel_def = rel_def_of_bnf pre_bnf;
        val frozen_absT = T_of_bnf pre_bnf
          |> singleton (freeze_types ctxt (map dest_TVar (lives_of_bnf pre_bnf)));
        val goal = mk_const_transfer_goal ctxt (dest_Const (mk_abs frozen_absT abs));
        val vars = Variable.add_free_names ctxt goal [];

        (* Unfold the relator; then either the goal is trivial or it follows from
           Abs_transfer applied to the type definition. *)
        fun tac {context = ctxt, prems = _} =
          unfold_thms_tac ctxt [rel_def] THEN
          HEADGOAL (rtac ctxt refl ORELSE'
            rtac ctxt (@{thm Abs_transfer} OF [type_definition, type_definition]));
      in
        Goal.prove_sorry ctxt vars [] goal tac
      end
    else
      raise Fail "no abs/rep"
  end;
160
(* Proves the transfer (parametricity) theorem for the rep function of the (co)datatype named
   fpT_name, instantiated at its pre-BNF type with live type variables frozen. Raises
   Fail "no abs/rep" when the abstract and representation types coincide. A missing fp_sugar
   for fpT_name makes the SOME pattern fail (Bind). *)
fun mk_rep_transfer ctxt fpT_name =
  let
    val SOME {pre_bnf, absT_info = {absT, repT, rep, ...}, ...} = fp_sugar_of ctxt fpT_name;
  in
    if absT <> repT then
      let
        val rel_def = rel_def_of_bnf pre_bnf;
        val frozen_absT = T_of_bnf pre_bnf
          |> singleton (freeze_types ctxt (map dest_TVar (lives_of_bnf pre_bnf)));
        val goal = mk_const_transfer_goal ctxt (dest_Const (mk_rep frozen_absT rep));
        val vars = Variable.add_free_names ctxt goal [];

        (* Unfold the relator; then either the goal is trivial or it is an instance of
           vimage2p_rel_fun. *)
        fun tac {context = ctxt, prems = _} =
          unfold_thms_tac ctxt [rel_def] THEN
          HEADGOAL (rtac ctxt refl ORELSE' rtac ctxt @{thm vimage2p_rel_fun});
      in
        Goal.prove_sorry ctxt vars [] goal tac
      end
    else
      raise Fail "no abs/rep"
  end;
182
(* Raised by mk_pointful_natural_from_transfer when build_map cannot produce a map between an
   argument or result type pair, or when the naturality goal does not type-check. *)
exception UNNATURAL of unit;
184
(* Derives a pointful naturality theorem from the transfer theorem of a constant s.
   - The two type instances of s in the transfer theorem are generalized with
     generalize_types.
   - Each resulting type variable gets a pair of fresh type frees A/B and a function
     f :: A => B.
   - The goal states that s at the B-instance, applied to the mapped arguments, equals the
     mapped result of s at the A-instance. Arguments and results whose A- and B-types agree
     are left unmapped.
   - Raises UNNATURAL if a required map cannot be built or the goal does not type-check. *)
fun mk_pointful_natural_from_transfer ctxt transfer =
  let
    (* Expects the proposition to have shape _ $ (_ $ Const (s, T0) $ Const (_, U0)),
       presumably Trueprop (R s s); any other shape fails with Bind. *)
    val _ $ (_ $ Const (s, T0) $ Const (_, U0)) = Thm.prop_of transfer;
    val [T, U] = freeze_types ctxt [] [T0, U0];
    val var_T = generalize_types 0 T U;

    val var_As = map TVar (rev (Term.add_tvarsT var_T []));

    val ((As, Bs), names_ctxt) = ctxt
      |> mk_TFrees' (map Type.sort_of_atyp var_As)
      ||>> mk_TFrees' (map Type.sort_of_atyp var_As);

    val TA = typ_subst_atomic (var_As ~~ As) var_T;

    (* xs: one variable per argument of s at the A-instance; fs: one function A => B per
       generalized type variable. *)
    val ((xs, fs), _) = names_ctxt
      |> mk_Frees "x" (binder_types TA)
      ||>> mk_Frees "f" (map2 (curry (op -->)) As Bs);

    val AB_fs = (As ~~ Bs) ~~ fs;

    (* Applies to t the map from fst TU to snd TU built from fs. Returns t unchanged when the
       two types are equal and raises UNNATURAL when build_map fails. *)
    fun build_applied_map TU t =
      if op = TU then
        t
      else
        (case try (build_map ctxt [] [] (the o AList.lookup (op =) AB_fs)) TU of
          SOME mapx => mapx $ t
        | NONE => raise UNNATURAL ());

    (* Turns "f x = rhs" with a free x into "f = (%x. rhs)", repeatedly. *)
    fun unextensionalize (f $ (x as Free _), rhs) = unextensionalize (f, lambda x rhs)
      | unextensionalize tu = tu;

    val TB = typ_subst_atomic (var_As ~~ Bs) var_T;

    val (binder_TAs, body_TA) = strip_type TA;
    val (binder_TBs, body_TB) = strip_type TB;

    val n = length var_As;
    val m = length binder_TAs;

    val A_nesting_bnfs = nesting_bnfs ctxt [[body_TA :: binder_TAs]] As;
    val A_nesting_map_ids = map map_id_of_bnf A_nesting_bnfs;
    val A_nesting_rel_Grps = map rel_Grp_of_bnf A_nesting_bnfs;

    val ta = Const (s, TA);
    val tb = Const (s, TB);
    val xfs = @{map 3} (curry build_applied_map) binder_TAs binder_TBs xs;

    val goal = (list_comb (tb, xfs), build_applied_map (body_TA, body_TB) (list_comb (ta, xs)))
      |> unextensionalize |> mk_Trueprop_eq;

    val _ = if can type_of goal then () else raise UNNATURAL ();

    val vars = map (fst o dest_Free) (xs @ fs);
  in
    Goal.prove_sorry ctxt vars [] goal (fn {context = ctxt, prems = _} =>
      mk_natural_from_transfer_tac ctxt m (replicate n true) transfer A_nesting_map_ids
        A_nesting_rel_Grps [])
    |> Thm.close_derivation \<^here>
  end;
244
(* Parse information common to corecursive and friend specifications (built by
   generic_spec_of): the outer-layer buffer, one guard term per constructor name, and the
   inner-layer buffer. *)
type s_parse_info =
  {outer_buffer: BNF_GFP_Grec.buffer,
   ctr_guards: term Symtab.table,
   inner_buffer: BNF_GFP_Grec.buffer};
249
(* Parse information specific to friend specifications (built lazily by generic_spec_of).
   Fields:
   - pattern_ctrs: for each constructor name, its discriminator term and selector terms
   - discs, sels: discriminators and selectors, keyed by constant name
   - it: the "it" term
   - mk_case: builds a case term for a requested result type *)
type rho_parse_info =
  {pattern_ctrs: (term * term list) Symtab.table,
   discs: term Symtab.table,
   sels: term Symtab.table,
   it: term,
   mk_case: typ -> term};
256
(* Converts a friend t, which takes its arguments as one balanced tuple, into a curried
   lambda over fresh variables x0, x1, ... The number of arguments is the number of binder
   types of T, and T itself is returned unchanged. *)
fun curry_friend (T, t) =
  let
    val arg_Ts = dest_tupleT_balanced (num_binder_types T) (domain_type (fastype_of t));
    val frees = map_index (fn (i, U) => Free ("x" ^ string_of_int i, U)) arg_Ts;
  in
    (T, fold_rev Term.lambda frees (t $ mk_tuple_balanced frees))
  end;
266
(* Applies curry_friend to every friend of the buffer; all other buffer components are passed
   through unchanged. *)
fun curry_friends (buffer : buffer) =
  let val {Oper, VLeaf, CLeaf, ctr_wrapper, friends} = buffer in
    {Oper = Oper, VLeaf = VLeaf, CLeaf = CLeaf, ctr_wrapper = ctr_wrapper,
     friends = Symtab.map (fn _ => curry_friend) friends}
  end;
270
(* Returns the fp_sugar of T if T is a registered codatatype (greatest fixpoint). Otherwise it
   reports the failure through not_codatatype. *)
fun checked_gfp_sugar_of lthy T =
  (case T of
    Type (T_name, _) =>
      (case fp_sugar_of lthy T_name of
        SOME (sugar as {fp = Greatest_FP, ...}) => sugar
      | _ => not_codatatype lthy T)
  | _ => not_codatatype lthy T);
276
(* Computes parse information for a corecursive specification (friend = false) or a friend
   specification (friend = true).
   - Inputs: argument types arg_Ts and result type res_T. res_T must be a codatatype;
     checked_gfp_sugar_of fails otherwise.
   - Returns the s_parse_info together with a thunk that builds the rho_parse_info.
   - corec_parse_info_of never forces the thunk; friend_parse_info_of does. *)
fun generic_spec_of friend ctxt arg_Ts res_T (raw_buffer0 as {VLeaf = VLeaf0, ...}) =
  let
    val thy = Proof_Context.theory_of ctxt;

    val tupled_arg_T = mk_tupleT_balanced arg_Ts;

    val {T = fpT, X, fp_res_index, fp_res = {ctors = ctors0, ...},
         absT_info = {abs = abs0, rep = rep0, ...},
         fp_ctr_sugar = {ctrXs_Tss, ctr_sugar = {ctrs = ctrs0, casex = case0, discs = discs0,
           selss = selss0, sel_defs, ...}, ...}, ...} =
      checked_gfp_sugar_of ctxt res_T;

    (* Y is the argument type of the raw buffer's VLeaf. *)
    val VLeaf0_T = fastype_of VLeaf0;
    val Y = domain_type VLeaf0_T;

    val raw_buffer = specialize_buffer_types raw_buffer0;

    (* Instantiation of the codatatype's schematic type variables that makes fpT match
       res_T. *)
    val As_rho = tvar_subst thy [fpT] [res_T];

    val substAT = Term.typ_subst_TVars As_rho;
    val substA = Term.subst_TVars As_rho;
    val substYT = Tsubst Y tupled_arg_T;
    val substY = substT Y tupled_arg_T;

    (* For corecursive specifications the inner layer also replaces Y by the tupled argument
       type; for friends Y is kept. *)
    val Ys_rho_inner = if friend then [] else [(Y, tupled_arg_T)];
    val substYT_inner = substAT o Term.typ_subst_atomic Ys_rho_inner;
    val substY_inner = substA o Term.subst_atomic_types Ys_rho_inner;

    val mid_T = substYT_inner (range_type VLeaf0_T);

    val substXT_mid = Tsubst X mid_T;

    (* Replace occurrences of res_T by X (resp. Y). *)
    val XifyT = typ_subst_nonatomic [(res_T, X)];
    val YifyT = typ_subst_nonatomic [(res_T, Y)];

    val substXYT = Tsubst X Y;

    val ctor0 = nth ctors0 fp_res_index;
    val ctor = enforce_type ctxt range_type res_T ctor0;
    val preT = YifyT (domain_type (fastype_of ctor));

    val n = length ctrs0;
    val ks = 1 upto n;

    (* Maps each constructor name to a lambda over fresh variables x0, x1, ... (one per
       constructor argument, with X instantiated to mid_T). Each body is built by mk_absumprod
       for the k-th of the n summands. *)
    fun mk_ctr_guards () =
      let
        val ctr_Tss = map (map (substXT_mid o substAT)) ctrXs_Tss;
        val preT = XifyT (domain_type (fastype_of ctor));
        val mid_preT = substXT_mid preT;
        val abs = enforce_type ctxt range_type mid_preT abs0;
        val absT = range_type (fastype_of abs);

        fun mk_ctr_guard k ctr_Ts (Const (s, _)) =
          let
            val xs = map_index (fn (i, T) => Free ("x" ^ string_of_int i, T)) ctr_Ts;
            val body = mk_absumprod absT abs n k xs;
          in
            (s, fold_rev Term.lambda xs body)
          end;
      in
        Symtab.make (@{map 3} mk_ctr_guard ks ctr_Tss ctrs0)
      end;

    val substYT_mid = substYT o Tsubst Y mid_T;

    val outer_T = substYT_mid preT;

    val substY_outer = substY o substT Y outer_T;

    (* The outer and inner buffers differ only in the type substitution applied to the raw
       buffer; in both, the friends are curried. *)
    val outer_buffer = curry_friends (map_buffer substY_outer raw_buffer);
    val ctr_guards = mk_ctr_guards ();
    val inner_buffer = curry_friends (map_buffer substY_inner raw_buffer);

    val s_parse_info =
      {outer_buffer = outer_buffer, ctr_guards = ctr_guards, inner_buffer = inner_buffer};

    (* Builds the friend-specific rho_parse_info: discriminators, selectors, constructor
       patterns, "it", and a case-term builder, all expressed through rep composed with snd on
       pairs of type Y * preT. *)
    fun mk_friend_spec () =
      let
        (* Builds, with build_map, a map from type T to type U whose leaf functions apply
           VLeaf0 where the source leaf type is domain_type VLeaf0_T (= Y) and are the
           identity otherwise. The map is then beta-applied to free. *)
        fun encapsulate_nested U T free =
          betapply (build_map ctxt [] [] (fn (T, _) =>
              if T = domain_type VLeaf0_T then Abs (Name.uu, T, VLeaf0 $ Bound 0)
              else Abs (Name.uu, T, Bound 0)) (T, U),
            free);

        val preT = YifyT (domain_type (fastype_of ctor));
        val YpreT = HOLogic.mk_prodT (Y, preT);

        val rep = rep0 |> enforce_type ctxt domain_type (substXT_mid (XifyT preT));

        (* Discriminator for the k-th constructor (1-based). It is a case_sum returning True
           on the k-th summand and False elsewhere, composed with rep and snd. *)
        fun mk_disc k =
          ctrXs_Tss
          |> map_index (fn (i, Ts) =>
            Abs (Name.uu, mk_tupleT_balanced Ts,
              if i + 1 = k then \<^const>\<open>HOL.True\<close> else \<^const>\<open>HOL.False\<close>))
          |> mk_case_sumN_balanced
          |> map_types substXYT
          |> (fn tm => Library.foldl1 HOLogic.mk_comp [tm, rep, snd_const YpreT])
          |> map_types substAT;

        val all_discs = map mk_disc ks;

        (* Only constructors whose discriminator is a constant get a "discs" entry. *)
        fun mk_pair (Const (disc_name, _)) disc = SOME (disc_name, disc)
          | mk_pair _ _ = NONE;

        val discs = @{map 2} mk_pair discs0 all_discs
          |> map_filter I |> Symtab.make;

        (* Rebuilds a selector from its definition. The selector name is the head constant of
           the lhs. The case functions are the arguments of the rhs minus its last one;
           presumably sel_def has the shape sel x = case_T f1 ... fn x -- TODO confirm. *)
        fun mk_sel sel_def =
          let
            val (sel_name, case_functions) =
              sel_def
              |> Object_Logic.rulify ctxt
              |> Thm.concl_of
              |> perhaps (try drop_all)
              |> perhaps (try HOLogic.dest_Trueprop)
              |> HOLogic.dest_eq
              |>> fst o strip_comb
              |>> fst o dest_Const
              ||> fst o dest_comb
              ||> snd o strip_comb
              ||> map (map_types (XifyT o substAT));

            (* Eta-expands a case function over all its binder types and wraps the fully
               applied body with encapsulate_nested. *)
            fun encapsulate_case_function case_function =
              let
                fun encapsulate bound_Ts [] case_function =
                    let val T = fastype_of1 (bound_Ts, case_function) in
                      encapsulate_nested (substXT_mid T) (substXYT T) case_function
                    end
                  | encapsulate bound_Ts (T :: Ts) case_function =
                    Abs (Name.uu, T,
                      encapsulate (T :: bound_Ts) Ts
                        (betapply (incr_boundvars 1 case_function, Bound 0)));
              in
                encapsulate [] (binder_types (fastype_of case_function)) case_function
              end;
          in
            (sel_name, ctrXs_Tss
              |> map (map_index (fn (i, T) => Free ("x" ^ string_of_int (i + 1), T)))
              |> `(map mk_tuple_balanced)
              |> uncurry (@{map 3} mk_tupled_fun (map encapsulate_case_function case_functions))
              |> mk_case_sumN_balanced
              |> map_types substXYT
              |> (fn tm => Library.foldl1 HOLogic.mk_comp [tm, rep, snd_const YpreT])
              |> map_types substAT)
          end;

        val sels = Symtab.make (map mk_sel sel_defs);

        fun mk_disc_sels_pair disc sels =
          if forall is_some sels then SOME (disc, map the sels) else NONE;

        (* Constructor name |-> (discriminator, selectors). Constructors with a selector
           missing from "sels" are dropped. *)
        val pattern_ctrs = (ctrs0, selss0)
          ||> map (map (try dest_Const #> Option.mapPartial (fst #> Symtab.lookup sels)))
          ||> @{map 2} mk_disc_sels_pair all_discs
          |>> map (dest_Const #> fst)
          |> op ~~
          |> map_filter (fn (s, opt) => if is_some opt then SOME (s, the opt) else NONE)
          |> Symtab.make;

        val it = HOLogic.mk_comp (VLeaf0, fst_const YpreT);

        (* A case combinator over the abstract case functions f1, f2, ..., with the branches
           wrapped by encapsulate_nested. The returned function instantiates its body type to
           the requested U and type-checks the result. *)
        val mk_case =
          let
            val abs_fun_tms = case0
              |> fastype_of
              |> substAT
              |> XifyT
              |> binder_fun_types
              |> map_index (fn (i, T) => Free ("f" ^ string_of_int (i + 1), T));
            val arg_Uss = abs_fun_tms
              |> map fastype_of
              |> map binder_types;
            val arg_Tss = arg_Uss
              |> map (map substXYT);
            val case0 =
              arg_Tss
              |> map (map_index (fn (i, T) => Free ("x" ^ string_of_int (i + 1), T)))
              |> `(map mk_tuple_balanced)
              ||> @{map 3} (@{map 3} encapsulate_nested) arg_Uss arg_Tss
              |> uncurry (@{map 3} mk_tupled_fun abs_fun_tms)
              |> mk_case_sumN_balanced
              |> (fn tm => Library.foldl1 HOLogic.mk_comp [tm, rep, snd_const YpreT])
              |> fold_rev lambda abs_fun_tms
              |> map_types (substAT o substXT_mid);
          in
            fn U => case0
              |> substT (body_type (fastype_of case0)) U
              |> Syntax.check_term ctxt
          end;
      in
        {pattern_ctrs = pattern_ctrs, discs = discs, sels = sels, it = it, mk_case = mk_case}
      end;
  in
    (s_parse_info, mk_friend_spec)
  end;
472
(* Parse information for a corecursive specification. Only the s_parse_info is needed, so the
   friend-specific thunk is discarded without being forced. *)
fun corec_parse_info_of ctxt arg_Ts res_T buffer =
  fst (generic_spec_of false ctxt arg_Ts res_T buffer);
475
(* Parse information for a friend specification. The friend-specific thunk is forced to
   obtain the rho_parse_info. *)
fun friend_parse_info_of ctxt arg_Ts res_T buffer =
  let val (s_info, mk_rho_info) = generic_spec_of true ctxt arg_Ts res_T buffer in
    (s_info, mk_rho_info ())
  end;
478
479end;
480