(*  Title:      HOL/Tools/BNF/bnf_gfp_grec_sugar_util.ML
    Author:     Aymeric Bouzy, Ecole polytechnique
    Author:     Jasmin Blanchette, Inria, LORIA, MPII
    Copyright   2015, 2016

Library for generalized corecursor sugar.
*)

signature BNF_GFP_GREC_SUGAR_UTIL =
sig
  (* Result of corec_parse_info_of: the buffer specialized in two ways (outer and inner),
     plus a table mapping each constructor name of the codatatype to a guard term. *)
  type s_parse_info =
    {outer_buffer: BNF_GFP_Grec.buffer,
     ctr_guards: term Symtab.table,
     inner_buffer: BNF_GFP_Grec.buffer}

  (* Additional result of friend_parse_info_of:
     - pattern_ctrs: constructor name to (discriminator term, selector terms);
     - discs: discriminator constant name to discriminator term;
     - sels: selector constant name to selector term;
     - it: the term VLeaf0 composed with fst;
     - mk_case: builds a case-combinator term for a requested result type. *)
  type rho_parse_info =
    {pattern_ctrs: (term * term list) Symtab.table,
     discs: term Symtab.table,
     sels: term Symtab.table,
     it: term,
     mk_case: typ -> term}

  (* Raised by mk_pointful_natural_from_transfer when no naturality goal can be built. *)
  exception UNNATURAL of unit

  val generalize_types: int -> typ -> typ -> typ
  val mk_curry_uncurryN_balanced: Proof.context -> int -> thm
  val mk_const_transfer_goal: Proof.context -> string * typ -> term
  val mk_abs_transfer: Proof.context -> string -> thm
  val mk_rep_transfer: Proof.context -> string -> thm
  val mk_pointful_natural_from_transfer: Proof.context -> thm -> thm

  val corec_parse_info_of: Proof.context -> typ list -> typ -> BNF_GFP_Grec.buffer -> s_parse_info
  val friend_parse_info_of: Proof.context -> typ list -> typ -> BNF_GFP_Grec.buffer ->
    s_parse_info * rho_parse_info
end;

structure BNF_GFP_Grec_Sugar_Util : BNF_GFP_GREC_SUGAR_UTIL =
struct

open Ctr_Sugar
open BNF_Util
open BNF_Tactics
open BNF_Def
open BNF_Comp
open BNF_FP_Util
open BNF_FP_Def_Sugar
open BNF_GFP_Grec
open BNF_GFP_Grec_Tactics

(* Combines a list of functions into one balanced nest of case_sum applications
   (Balanced_Tree.make over mk_case_sum). *)
val mk_case_sumN_balanced = Balanced_Tree.make mk_case_sum;

(* Computes a common generalization of the types T and U.
   Where both are Type applications with the same constructor name, the constructor is
   kept and the arguments are generalized pairwise. Where they are otherwise equal, the
   type is kept unchanged. Every remaining mismatching pair is replaced by a schematic
   type variable of sort type, named Name.aT.
   The same pair (T, U) always receives the same variable. Variable indices start at
   max_j and grow by one for each new pair. *)
fun generalize_types max_j T U =
  let
    (* Association list from mismatching pairs (T, U) to their generalizing TVar. *)
    val vars = Unsynchronized.ref [];

    fun var_of T U =
      (case AList.lookup (op =) (!vars) (T, U) of
        SOME V => V
      | NONE =>
        let val V = TVar ((Name.aT, length (!vars) + max_j), \<^sort>\<open>type\<close>) in
          vars := ((T, U), V) :: !vars; V
        end);

    fun gen (T as Type (s, Ts)) (U as Type (s', Us)) =
        if s = s' then Type (s, map2 gen Ts Us) else var_of T U
      | gen T U = if T = U then T else var_of T U;
  in
    gen T U
  end;

(* Proves, for n argument types, a theorem of the form
     !!f g. (uncurried_f = g) = (f = curried_g)
   Here uncurried_f is built from f with mk_tupled_fun so that it takes the n arguments
   as one balanced tuple, and curried_g is built from g with abs_curried_balanced.
   The types are n + 1 fresh type variables: n argument types As and a result type B. *)
fun mk_curry_uncurryN_balanced_raw ctxt n =
  let
    val ((As, B), names_ctxt) = ctxt
      |> mk_TFrees (n + 1)
      |>> split_last;

    val tupled_As = mk_tupleT_balanced As;

    val f_T = As ---> B;
    val g_T = tupled_As --> B;

    val (((f, g), xs), _) = names_ctxt
      |> yield_singleton (mk_Frees "f") f_T
      ||>> yield_singleton (mk_Frees "g") g_T
      ||>> mk_Frees "x" As;

    val tupled_xs = mk_tuple1_balanced As xs;

    val uncurried_f = mk_tupled_fun f tupled_xs xs;
    val curried_g = abs_curried_balanced As g;

    val lhs = HOLogic.mk_eq (uncurried_f, g);
    val rhs = HOLogic.mk_eq (f, curried_g);
    val goal = fold_rev Logic.all [f, g] (mk_Trueprop_eq (lhs, rhs));

    (* Tactic steps:
       1. Split the equivalence with iffI and substitute the equation hypothesis.
       2. Unfold prod.case.
       3. Close the subgoals with refl, after repeatedly rewriting with unit_abs_eta_conv
          and case_prod_eta. *)
    fun mk_tac ctxt =
      HEADGOAL (rtac ctxt iffI THEN' dtac ctxt sym THEN' hyp_subst_tac ctxt) THEN
      unfold_thms_tac ctxt @{thms prod.case} THEN
      HEADGOAL (rtac ctxt refl THEN' hyp_subst_tac ctxt THEN'
        REPEAT_DETERM o subst_tac ctxt NONE @{thms unit_abs_eta_conv case_prod_eta} THEN'
        rtac ctxt refl);
  in
    Goal.prove_sorry ctxt [] [] goal (fn {context = ctxt, ...} => mk_tac ctxt)
    |> Thm.close_derivation \<^here>
  end;

(* The instances for n = 0 .. num_curry_uncurryN_balanced_precomp are proved once,
   in the compile-time context, and cached in this list (indexed by n). *)
val num_curry_uncurryN_balanced_precomp = 8;
val curry_uncurryN_balanced_precomp =
  map (mk_curry_uncurryN_balanced_raw \<^context>) (0 upto num_curry_uncurryN_balanced_precomp);

(* Returns the curry/uncurry theorem for n arguments.
   For small n the cached theorem is returned and ctxt is not used; otherwise the
   theorem is proved on demand in ctxt. *)
fun mk_curry_uncurryN_balanced ctxt n =
  if n <= num_curry_uncurryN_balanced_precomp then nth curry_uncurryN_balanced_precomp n
  else mk_curry_uncurryN_balanced_raw ctxt n;

(* Builds the parametricity goal for the constant named s at the schematic type var_T.
   - Each schematic type variable of var_T is instantiated once with a new type variable
     from As and once with one from Bs, keeping its sort; this yields types T and U.
   - Relation variables R, one per pair of A and B, are created.
   - The goal relating Const (s, T) to Const (s, U) is built by mk_parametricity_goal.
   Raises an error if that goal is not type-correct. *)
fun mk_const_transfer_goal ctxt (s, var_T) =
  let
    val var_As = Term.add_tvarsT var_T [];

    val ((As, Bs), names_ctxt) = ctxt
      |> Variable.declare_typ var_T
      |> mk_TFrees' (map snd var_As)
      ||>> mk_TFrees' (map snd var_As);

    val (Rs, _) = names_ctxt
      |> mk_Frees "R" (map2 mk_pred2T As Bs);

    val T = Term.typ_subst_TVars (map fst var_As ~~ As) var_T;
    val U = Term.typ_subst_TVars (map fst var_As ~~ Bs) var_T;
  in
    mk_parametricity_goal ctxt Rs (Const (s, T)) (Const (s, U))
    |> tap (fn goal => can type_of goal orelse
      error ("Cannot transfer constant " ^ quote (Syntax.string_of_term ctxt (Const (s, T))) ^
        " from type " ^ quote (Syntax.string_of_typ ctxt T) ^ " to " ^
        quote (Syntax.string_of_typ ctxt U)))
  end;

(* Proves a transfer rule for the abstraction function (abs) recorded in the absT_info
   of the fixpoint type named fpT_name.
   The constant is taken at the pre-BNF type, with the pre-BNF's live type variables
   frozen. The proof unfolds the pre-BNF's rel_def and then applies either refl or
   Abs_transfer instantiated with type_definition.
   Raises Fail when absT = repT. Raises Bind if fp_sugar_of returns NONE, since the
   SOME pattern below is not exhaustive. *)
fun mk_abs_transfer ctxt fpT_name =
  let
    val SOME {pre_bnf, absT_info = {absT, repT, abs, type_definition, ...}, ...} =
      fp_sugar_of ctxt fpT_name;
  in
    if absT = repT then
      raise Fail "no abs/rep"
    else
      let
        val rel_def = rel_def_of_bnf pre_bnf;

        (* Deliberately shadows the absT taken from absT_info above. *)
        val absT = T_of_bnf pre_bnf
          |> singleton (freeze_types ctxt (map dest_TVar (lives_of_bnf pre_bnf)));

        val goal = mk_const_transfer_goal ctxt (dest_Const (mk_abs absT abs))
      in
        Variable.add_free_names ctxt goal []
        |> (fn vars => Goal.prove_sorry ctxt vars [] goal (fn {context = ctxt, prems = _} =>
          unfold_thms_tac ctxt [rel_def] THEN
          HEADGOAL (rtac ctxt refl ORELSE'
            rtac ctxt (@{thm Abs_transfer} OF [type_definition, type_definition]))))
      end
  end;

(* Proves a transfer rule for the representation function (rep) recorded in the
   absT_info of the fixpoint type named fpT_name.
   This mirrors mk_abs_transfer, but the final step applies either refl or
   vimage2p_rel_fun.
   Raises Fail when absT = repT. Raises Bind if fp_sugar_of returns NONE. *)
fun mk_rep_transfer ctxt fpT_name =
  let
    val SOME {pre_bnf, absT_info = {absT, repT, rep, ...}, ...} = fp_sugar_of ctxt fpT_name;
  in
    if absT = repT then
      raise Fail "no abs/rep"
    else
      let
        val rel_def = rel_def_of_bnf pre_bnf;

        (* Deliberately shadows the absT taken from absT_info above. *)
        val absT = T_of_bnf pre_bnf
          |> singleton (freeze_types ctxt (map dest_TVar (lives_of_bnf pre_bnf)));

        val goal = mk_const_transfer_goal ctxt (dest_Const (mk_rep absT rep))
      in
        Variable.add_free_names ctxt goal []
        |> (fn vars => Goal.prove_sorry ctxt vars [] goal (fn {context = ctxt, prems = _} =>
          unfold_thms_tac ctxt [rel_def] THEN
          HEADGOAL (rtac ctxt refl ORELSE' rtac ctxt @{thm vimage2p_rel_fun})))
      end
  end;

exception UNNATURAL of unit;

(* Derives a pointful naturality theorem for a constant from its transfer rule.

   The proposition of transfer must have the shape  _ $ (_ $ Const (s, T0) $ Const (_, U0)),
   otherwise the pattern binding raises Bind. The two constant types are frozen and then
   generalized by generalize_types into a common type var_T. Its schematic variables are
   instantiated once with As (giving TA) and once with Bs (giving TB).

   Given variables x_1 .. x_m and functions f_i from A_i to B_i, the goal is
     s at TB applied to (map x_1) .. (map x_m)  =  map (s at TA applied to x_1 .. x_m)
   - Each map is built by build_map from the f_i. It is omitted wherever the A-side and
     B-side types coincide.
   - Trailing arguments of the left-hand side that are plain variables are then moved
     into lambdas on the right-hand side (unextensionalize).

   Raises UNNATURAL if a required map cannot be built or if the resulting goal is not
   type-correct. *)
fun mk_pointful_natural_from_transfer ctxt transfer =
  let
    val _ $ (_ $ Const (s, T0) $ Const (_, U0)) = Thm.prop_of transfer;
    val [T, U] = freeze_types ctxt [] [T0, U0];
    val var_T = generalize_types 0 T U;

    val var_As = map TVar (rev (Term.add_tvarsT var_T []));

    val ((As, Bs), names_ctxt) = ctxt
      |> mk_TFrees' (map Type.sort_of_atyp var_As)
      ||>> mk_TFrees' (map Type.sort_of_atyp var_As);

    val TA = typ_subst_atomic (var_As ~~ As) var_T;

    val ((xs, fs), _) = names_ctxt
      |> mk_Frees "x" (binder_types TA)
      ||>> mk_Frees "f" (map2 (curry (op -->)) As Bs);

    val AB_fs = (As ~~ Bs) ~~ fs;

    (* Applies to t the map from type fst TU to type snd TU, or returns t unchanged when
       the two types are equal. A failing build_map, including a failed lookup in AB_fs
       inside it, is caught by try and turned into UNNATURAL. *)
    fun build_applied_map TU t =
      if op = TU then
        t
      else
        (case try (build_map ctxt [] [] (the o AList.lookup (op =) AB_fs)) TU of
          SOME mapx => mapx $ t
        | NONE => raise UNNATURAL ());

    (* Turns  (f $ x, rhs)  with x a Free into  (f, lambda x rhs), repeatedly. *)
    fun unextensionalize (f $ (x as Free _), rhs) = unextensionalize (f, lambda x rhs)
      | unextensionalize tu = tu;

    val TB = typ_subst_atomic (var_As ~~ Bs) var_T;

    val (binder_TAs, body_TA) = strip_type TA;
    val (binder_TBs, body_TB) = strip_type TB;

    (* n: number of generalized type variables; m: number of arguments of the constant. *)
    val n = length var_As;
    val m = length binder_TAs;

    (* BNFs found by nesting_bnfs in the argument and result types; their map_id and
       rel_Grp facts are handed to the tactic. *)
    val A_nesting_bnfs = nesting_bnfs ctxt [[body_TA :: binder_TAs]] As;
    val A_nesting_map_ids = map map_id_of_bnf A_nesting_bnfs;
    val A_nesting_rel_Grps = map rel_Grp_of_bnf A_nesting_bnfs;

    val ta = Const (s, TA);
    val tb = Const (s, TB);
    val xfs = @{map 3} (curry build_applied_map) binder_TAs binder_TBs xs;

    val goal = (list_comb (tb, xfs), build_applied_map (body_TA, body_TB) (list_comb (ta, xs)))
      |> unextensionalize |> mk_Trueprop_eq;

    val _ = if can type_of goal then () else raise UNNATURAL ();

    val vars = map (fst o dest_Free) (xs @ fs);
  in
    (* NOTE(review): replicate n true passes one flag per generalized type variable; its
       meaning is defined by mk_natural_from_transfer_tac. TODO confirm. *)
    Goal.prove_sorry ctxt vars [] goal (fn {context = ctxt, prems = _} =>
      mk_natural_from_transfer_tac ctxt m (replicate n true) transfer A_nesting_map_ids
        A_nesting_rel_Grps [])
    |> Thm.close_derivation \<^here>
  end;

type s_parse_info =
  {outer_buffer: BNF_GFP_Grec.buffer,
   ctr_guards: term Symtab.table,
   inner_buffer: BNF_GFP_Grec.buffer};

type rho_parse_info =
  {pattern_ctrs: (term * term list) Symtab.table,
   discs: term Symtab.table,
   sels: term Symtab.table,
   it: term,
   mk_case: typ -> term};

(* Turns the friend term t, which takes its arguments as a single balanced tuple, into a
   curried lambda over new variables x0, x1, ...
   The number of arguments is num_binder_types T. The pair's first component T is
   returned unchanged. *)
fun curry_friend (T, t) =
  let
    val prod_T = domain_type (fastype_of t);
    val Ts = dest_tupleT_balanced (num_binder_types T) prod_T;
    val xs = map_index (fn (i, T) => Free ("x" ^ string_of_int i, T)) Ts;
    val body = mk_tuple_balanced xs;
  in
    (T, fold_rev Term.lambda xs (t $ body))
  end;

(* Applies curry_friend to every entry of the buffer's friends table; the other buffer
   fields are left unchanged. *)
fun curry_friends ({Oper, VLeaf, CLeaf, ctr_wrapper, friends} : buffer) =
  {Oper = Oper, VLeaf = VLeaf, CLeaf = CLeaf, ctr_wrapper = ctr_wrapper,
   friends = Symtab.map (K curry_friend) friends};

(* Returns the fp_sugar of T when T is a type constructor whose sugar has
   fp = Greatest_FP. Otherwise delegates to not_codatatype, which presumably reports an
   error. *)
fun checked_gfp_sugar_of lthy (T as Type (T_name, _)) =
    (case fp_sugar_of lthy T_name of
      SOME (sugar as {fp = Greatest_FP, ...}) => sugar
    | _ => not_codatatype lthy T)
  | checked_gfp_sugar_of lthy T = not_codatatype lthy T;

(* Shared worker for corec_parse_info_of and friend_parse_info_of.

   Parameters:
   - friend: when false, the inner buffer additionally has Y replaced by tupled_arg_T;
     when true, Y is kept there.
   - arg_Ts: argument types, combined into the balanced tuple type tupled_arg_T.
   - res_T: result type; must be a codatatype (checked_gfp_sugar_of).
   - raw_buffer0: the buffer to specialize; its VLeaf field is called VLeaf0 below.

   Type names used below:
   - A: the codatatype's schematic type variables, instantiated to match res_T (As_rho);
   - X: the type variable X recorded in the fp_sugar;
   - Y: the argument type of VLeaf0.

   Returns (s_parse_info, mk_friend_spec). mk_friend_spec is a thunk computing the
   rho_parse_info, so it is only built when forced. *)
fun generic_spec_of friend ctxt arg_Ts res_T (raw_buffer0 as {VLeaf = VLeaf0, ...}) =
  let
    val thy = Proof_Context.theory_of ctxt;

    val tupled_arg_T = mk_tupleT_balanced arg_Ts;

    val {T = fpT, X, fp_res_index, fp_res = {ctors = ctors0, ...},
         absT_info = {abs = abs0, rep = rep0, ...},
         fp_ctr_sugar = {ctrXs_Tss, ctr_sugar = {ctrs = ctrs0, casex = case0, discs = discs0,
           selss = selss0, sel_defs, ...}, ...}, ...} =
      checked_gfp_sugar_of ctxt res_T;

    val VLeaf0_T = fastype_of VLeaf0;
    val Y = domain_type VLeaf0_T;

    val raw_buffer = specialize_buffer_types raw_buffer0;

    (* Instantiation of the codatatype's type variables so that fpT matches res_T. *)
    val As_rho = tvar_subst thy [fpT] [res_T];

    val substAT = Term.typ_subst_TVars As_rho;
    val substA = Term.subst_TVars As_rho;
    val substYT = Tsubst Y tupled_arg_T;
    val substY = substT Y tupled_arg_T;

    (* In friend mode Y is kept in the inner buffer; otherwise it becomes tupled_arg_T. *)
    val Ys_rho_inner = if friend then [] else [(Y, tupled_arg_T)];
    val substYT_inner = substAT o Term.typ_subst_atomic Ys_rho_inner;
    val substY_inner = substA o Term.subst_atomic_types Ys_rho_inner;

    val mid_T = substYT_inner (range_type VLeaf0_T);

    val substXT_mid = Tsubst X mid_T;

    (* Replace occurrences of res_T by X or by Y, respectively. *)
    val XifyT = typ_subst_nonatomic [(res_T, X)];
    val YifyT = typ_subst_nonatomic [(res_T, Y)];

    val substXYT = Tsubst X Y;

    val ctor0 = nth ctors0 fp_res_index;
    val ctor = enforce_type ctxt range_type res_T ctor0;
    val preT = YifyT (domain_type (fastype_of ctor));

    (* Constructor indices 1 .. n. *)
    val n = length ctrs0;
    val ks = 1 upto n;

    (* Maps each constructor name to a lambda over variables x0, x1, ... (typed after the
       constructor's arguments, with A instantiated and X := mid_T). The lambda's body is
       built by mk_absumprod for constructor k of n, using abs0 instantiated so that its
       range is the X-ified pre-type with X := mid_T. *)
    fun mk_ctr_guards () =
      let
        val ctr_Tss = map (map (substXT_mid o substAT)) ctrXs_Tss;
        (* Unlike the enclosing preT, this one is built with XifyT, not YifyT. *)
        val preT = XifyT (domain_type (fastype_of ctor));
        val mid_preT = substXT_mid preT;
        val abs = enforce_type ctxt range_type mid_preT abs0;
        val absT = range_type (fastype_of abs);

        fun mk_ctr_guard k ctr_Ts (Const (s, _)) =
          let
            val xs = map_index (fn (i, T) => Free ("x" ^ string_of_int i, T)) ctr_Ts;
            val body = mk_absumprod absT abs n k xs;
          in
            (s, fold_rev Term.lambda xs body)
          end;
      in
        Symtab.make (@{map 3} mk_ctr_guard ks ctr_Tss ctrs0)
      end;

    val substYT_mid = substYT o Tsubst Y mid_T;

    val outer_T = substYT_mid preT;

    val substY_outer = substY o substT Y outer_T;

    (* Both buffers are specialized and then curried (curry_friends):
       - the outer buffer with substT Y outer_T followed by substY;
       - the inner buffer with substY_inner. *)
    val outer_buffer = curry_friends (map_buffer substY_outer raw_buffer);
    val ctr_guards = mk_ctr_guards ();
    val inner_buffer = curry_friends (map_buffer substY_inner raw_buffer);

    val s_parse_info =
      {outer_buffer = outer_buffer, ctr_guards = ctr_guards, inner_buffer = inner_buffer};

    (* Computes the rho_parse_info. The discriminator, selector and case terms built here
       all have the shape  tm o rep o snd, where snd is taken on pairs of type
       Y x preT (YpreT). *)
    fun mk_friend_spec () =
      let
        (* Applies to free the function that build_map constructs for the type pair
           (T, U). Components whose type is Y are wrapped with VLeaf0, all others with
           the identity. *)
        fun encapsulate_nested U T free =
          betapply (build_map ctxt [] [] (fn (T, _) =>
              if T = domain_type VLeaf0_T then Abs (Name.uu, T, VLeaf0 $ Bound 0)
              else Abs (Name.uu, T, Bound 0)) (T, U),
            free);

        (* Same definition as the preT of the enclosing generic_spec_of. *)
        val preT = YifyT (domain_type (fastype_of ctor));
        val YpreT = HOLogic.mk_prodT (Y, preT);

        val rep = rep0 |> enforce_type ctxt domain_type (substXT_mid (XifyT preT));

        (* Discriminator for constructor k: a balanced case_sum returning True on the
           k-th summand and False on all others, composed with rep and snd. *)
        fun mk_disc k =
          ctrXs_Tss
          |> map_index (fn (i, Ts) =>
            Abs (Name.uu, mk_tupleT_balanced Ts,
              if i + 1 = k then \<^const>\<open>HOL.True\<close> else \<^const>\<open>HOL.False\<close>))
          |> mk_case_sumN_balanced
          |> map_types substXYT
          |> (fn tm => Library.foldl1 HOLogic.mk_comp [tm, rep, snd_const YpreT])
          |> map_types substAT;

        val all_discs = map mk_disc ks;

        (* Only discriminators that are constants get an entry in the discs table. *)
        fun mk_pair (Const (disc_name, _)) disc = SOME (disc_name, disc)
          | mk_pair _ _ = NONE;

        val discs = @{map 2} mk_pair discs0 all_discs
          |> map_filter I |> Symtab.make;

        (* From a selector definition, extracts the selector's constant name and the
           case functions of its right-hand side. Returns that name paired with a
           balanced case_sum of the encapsulated case functions, composed with rep
           and snd. *)
        fun mk_sel sel_def =
          let
            val (sel_name, case_functions) =
              sel_def
              |> Object_Logic.rulify ctxt
              |> Thm.concl_of
              |> perhaps (try drop_all)
              |> perhaps (try HOLogic.dest_Trueprop)
              |> HOLogic.dest_eq
              |>> fst o strip_comb
              |>> fst o dest_Const
              ||> fst o dest_comb
              ||> snd o strip_comb
              ||> map (map_types (XifyT o substAT));

            (* Eta-expands case_function over all its binder types and applies
               encapsulate_nested to the fully applied body. *)
            fun encapsulate_case_function case_function =
              let
                fun encapsulate bound_Ts [] case_function =
                    let val T = fastype_of1 (bound_Ts, case_function) in
                      encapsulate_nested (substXT_mid T) (substXYT T) case_function
                    end
                  | encapsulate bound_Ts (T :: Ts) case_function =
                    Abs (Name.uu, T,
                      encapsulate (T :: bound_Ts) Ts
                        (betapply (incr_boundvars 1 case_function, Bound 0)));
              in
                encapsulate [] (binder_types (fastype_of case_function)) case_function
              end;
          in
            (sel_name, ctrXs_Tss
              |> map (map_index (fn (i, T) => Free ("x" ^ string_of_int (i + 1), T)))
              |> `(map mk_tuple_balanced)
              |> uncurry (@{map 3} mk_tupled_fun (map encapsulate_case_function case_functions))
              |> mk_case_sumN_balanced
              |> map_types substXYT
              |> (fn tm => Library.foldl1 HOLogic.mk_comp [tm, rep, snd_const YpreT])
              |> map_types substAT)
          end;

        val sels = Symtab.make (map mk_sel sel_defs);

        fun mk_disc_sels_pair disc sels =
          if forall is_some sels then SOME (disc, map the sels) else NONE;

        (* Constructor name to (discriminator, selectors). A constructor is included only
           if every one of its selectors is a constant with an entry in sels. *)
        val pattern_ctrs = (ctrs0, selss0)
          ||> map (map (try dest_Const #> Option.mapPartial (fst #> Symtab.lookup sels)))
          ||> @{map 2} mk_disc_sels_pair all_discs
          |>> map (dest_Const #> fst)
          |> op ~~
          |> map_filter (fn (s, opt) => if is_some opt then SOME (s, the opt) else NONE)
          |> Symtab.make;

        val it = HOLogic.mk_comp (VLeaf0, fst_const YpreT);

        (* Builds a lambda over case functions f1 .. fn whose body is a balanced case_sum
           of those functions (arguments encapsulated via encapsulate_nested), composed
           with rep and snd.
           The returned function takes a type U, replaces the term type's body_type by U,
           and type-checks the result in ctxt. *)
        val mk_case =
          let
            val abs_fun_tms = case0
              |> fastype_of
              |> substAT
              |> XifyT
              |> binder_fun_types
              |> map_index (fn (i, T) => Free ("f" ^ string_of_int (i + 1), T));
            val arg_Uss = abs_fun_tms
              |> map fastype_of
              |> map binder_types;
            val arg_Tss = arg_Uss
              |> map (map substXYT);
            (* Shadows the case constant case0 taken from the fp_sugar above. *)
            val case0 =
              arg_Tss
              |> map (map_index (fn (i, T) => Free ("x" ^ string_of_int (i + 1), T)))
              |> `(map mk_tuple_balanced)
              ||> @{map 3} (@{map 3} encapsulate_nested) arg_Uss arg_Tss
              |> uncurry (@{map 3} mk_tupled_fun abs_fun_tms)
              |> mk_case_sumN_balanced
              |> (fn tm => Library.foldl1 HOLogic.mk_comp [tm, rep, snd_const YpreT])
              |> fold_rev lambda abs_fun_tms
              |> map_types (substAT o substXT_mid);
          in
            fn U => case0
              |> substT (body_type (fastype_of case0)) U
              |> Syntax.check_term ctxt
          end;
      in
        {pattern_ctrs = pattern_ctrs, discs = discs, sels = sels, it = it, mk_case = mk_case}
      end;
  in
    (s_parse_info, mk_friend_spec)
  end;

(* Parse information for corecursive definitions. Only the s_parse_info is returned;
   the rho_parse_info thunk is dropped without being forced. *)
fun corec_parse_info_of ctxt =
  fst ooo generic_spec_of false ctxt;

(* Parse information for friend declarations. The rho_parse_info thunk is forced. *)
fun friend_parse_info_of ctxt =
  apsnd (fn f => f ()) ooo generic_spec_of true ctxt;

end;