(*  Title:      HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
    Author:     Lorenz Panny, TU Muenchen
    Author:     Jasmin Blanchette, TU Muenchen
    Copyright   2013

Recursor sugar ("primrec").
*)

signature BNF_LFP_REC_SUGAR =
sig
  datatype rec_option =
    Plugins_Option of Proof.context -> Plugin_Name.filter |
    Nonexhaustive_Option |
    Transfer_Option

  datatype rec_call =
    No_Rec of int * typ |
    Mutual_Rec of (int * typ) * (int * typ) |
    Nested_Rec of int * typ

  type rec_ctr_spec =
    {ctr: term,
     offset: int,
     calls: rec_call list,
     rec_thm: thm}

  type rec_spec =
    {recx: term,
     fp_nesting_map_ident0s: thm list,
     fp_nesting_map_comps: thm list,
     fp_nesting_pred_maps: thm list,
     ctr_specs: rec_ctr_spec list}

  type basic_lfp_sugar =
    {T: typ,
     fp_res_index: int,
     C: typ,
     fun_arg_Tsss : typ list list list,
     ctr_sugar: Ctr_Sugar.ctr_sugar,
     recx: term,
     rec_thms: thm list};

  type lfp_rec_extension =
    {nested_simps: thm list,
     special_endgame_tac: Proof.context -> thm list -> thm list -> thm list -> tactic,
     is_new_datatype: Proof.context -> string -> bool,
     basic_lfp_sugars_of: binding list -> typ list -> term list ->
       (term * term list list) list list -> local_theory ->
       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
       * Token.src list * bool * local_theory,
     rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
       term -> term -> term -> term) option};

  val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
  val default_basic_lfp_sugars_of: binding list -> typ list -> term list ->
    (term * term list list) list list -> local_theory ->
    typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
    * Token.src list * bool * local_theory
  val rec_specs_of: binding list -> typ list -> typ list -> term list ->
    (term * term list list) list list -> local_theory ->
    (bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory

  val lfp_rec_sugar_interpretation: string ->
    (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) -> theory -> theory

  val primrec: bool -> rec_option list -> (binding * typ option * mixfix) list ->
    Specification.multi_specs -> local_theory ->
    (term list * thm list * thm list list) * local_theory
  val primrec_cmd: bool -> rec_option list -> (binding * string option * mixfix) list ->
    Specification.multi_specs_cmd -> local_theory ->
    (term list * thm list * thm list list) * local_theory
  val primrec_global: bool -> rec_option list -> (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list * thm list list) * theory
  val primrec_overloaded: bool -> rec_option list -> (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list * thm list list) * theory
  val primrec_simple: bool -> ((binding * typ) * mixfix) list -> term list -> local_theory ->
    ((string list * (binding -> binding) list)
     * (term list * thm list * (int list list * thm list list))) * local_theory
end;

structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
struct

open Ctr_Sugar
open Ctr_Sugar_Util
open Ctr_Sugar_General_Tactics
open BNF_FP_Rec_Sugar_Util

val inductN = "induct";
val simpsN = "simps";

val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val simp_attrs = @{attributes [simp]};
val nitpicksimp_simp_attrs = nitpicksimp_attrs @ simp_attrs;

(* Raised by prepare_primrec when some argument type is known only to the old datatype
   package; the handlers in primrec_simple and gen_primrec then fall back to Old_Primrec. *)
exception OLD_PRIMREC of unit;

datatype rec_option =
  Plugins_Option of Proof.context -> Plugin_Name.filter |
  Nonexhaustive_Option |
  Transfer_Option;

datatype rec_call =
  No_Rec of int * typ |
  Mutual_Rec of (int * typ) * (int * typ) |
  Nested_Rec of int * typ;

type rec_ctr_spec =
  {ctr: term,
   offset: int,
   calls: rec_call list,
   rec_thm: thm};

type rec_spec =
  {recx: term,
   fp_nesting_map_ident0s: thm list,
   fp_nesting_map_comps: thm list,
   fp_nesting_pred_maps: thm list,
   ctr_specs: rec_ctr_spec list};

type basic_lfp_sugar =
  {T: typ,
   fp_res_index: int,
   C: typ,
   fun_arg_Tsss : typ list list list,
   ctr_sugar: ctr_sugar,
   recx: term,
   rec_thms: thm list};

type lfp_rec_extension =
  {nested_simps: thm list,
   special_endgame_tac: Proof.context -> thm list -> thm list -> thm list -> tactic,
   is_new_datatype: Proof.context -> string -> bool,
   basic_lfp_sugars_of: binding list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
     * Token.src list * bool * local_theory,
   rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
     term -> term -> term -> term) option};

(* Theory data holding the optionally registered extension; NONE until
   register_lfp_rec_extension has been applied. *)
structure Data = Theory_Data
(
  type T = lfp_rec_extension option;
  val empty = NONE;
  val extend = I;
  val merge = merge_options;
);

val register_lfp_rec_extension = Data.put o SOME;

(* The accessors below read one field of the registered extension and supply a default
   when no extension is registered. *)

(* Default: no extra simp rules. *)
fun nested_simps ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {nested_simps, ...} => nested_simps
  | NONE => []);

(* Default: no_tac, whatever the three theorem lists are. *)
fun special_endgame_tac ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {special_endgame_tac, ...} => special_endgame_tac ctxt
  | NONE => K (K (K no_tac)));

(* Default: every type name counts as a new datatype. *)
fun is_new_datatype ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {is_new_datatype, ...} => is_new_datatype ctxt
  | NONE => K true);

(* Fallback used when no extension is registered.  Only a single argument type that is a
   type constructor with registered ctr_sugar is accepted; its case combinator and case
   theorems stand in for the recursor and the recursor theorems.  Any other input is
   rejected with an error. *)
fun default_basic_lfp_sugars_of _ [Type (arg_T_name, _)] _ _ ctxt =
    let
      val ctr_sugar as {T, ctrs, casex, case_thms, ...} =
        (case ctr_sugar_of ctxt arg_T_name of
          SOME ctr_sugar => ctr_sugar
        | NONE => error ("Unsupported type " ^ quote arg_T_name ^ " at this stage"));

      val C = body_type (fastype_of casex);
      val fun_arg_Tsss = map (map single o binder_types o fastype_of) ctrs;

      val basic_lfp_sugar =
        {T = T, fp_res_index = 0, C = C, fun_arg_Tsss = fun_arg_Tsss, ctr_sugar = ctr_sugar,
         recx = casex, rec_thms = case_thms};
    in
      ([], [0], [basic_lfp_sugar], [], [], [], TrueI (*dummy*), [], false, ctxt)
    end
  | default_basic_lfp_sugars_of _ [T] _ _ ctxt =
    error ("Cannot recurse through type " ^ quote (Syntax.string_of_typ ctxt T))
  | default_basic_lfp_sugars_of _ _ _ _ _ = error "Unsupported mutual recursion at this stage";

(* Dispatches to the extension's basic_lfp_sugars_of, or to the default above. *)
fun basic_lfp_sugars_of bs arg_Ts callers callssss lthy =
  (case Data.get (Proof_Context.theory_of lthy) of
    SOME {basic_lfp_sugars_of, ...} => basic_lfp_sugars_of
  | NONE => default_basic_lfp_sugars_of) bs arg_Ts callers callssss lthy;

(* Fails unless an extension providing rewrite_nested_rec_call is registered. *)
fun rewrite_nested_rec_call ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {rewrite_nested_rec_call = SOME f, ...} => f ctxt
  | _ => error "Unsupported nested recursion");

structure LFP_Rec_Sugar_Plugin = Plugin(type T = fp_rec_sugar);

(* Registers f as a plugin interpretation; the fp_rec_sugar is transferred to the theory
   of the local theory at hand before f is applied. *)
fun lfp_rec_sugar_interpretation name f =
  LFP_Rec_Sugar_Plugin.interpretation name (fn fp_rec_sugar => fn lthy =>
    f (transfer_fp_rec_sugar (Proof_Context.theory_of lthy) fp_rec_sugar) lthy);

val interpret_lfp_rec_sugar = LFP_Rec_Sugar_Plugin.data;

(* Obtains the basic sugars for arg_Ts via basic_lfp_sugars_of and assembles one rec_spec
   per sugar.  Type variables of the sugars are instantiated according to arg_Ts and
   res_Ts (res_Ts being padded with the unit type up to the number of sugars).  Values
   that are computed in fp_res_index order ("perm_" prefix) are permuted back to the
   order in which the sugars were returned. *)
fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, fp_nesting_map_ident0s, fp_nesting_map_comps,
         fp_nesting_pred_maps, common_induct, induct_attrs, n2m, lthy) =
      basic_lfp_sugars_of bs arg_Ts callers callssss0 lthy0;

    val perm_basic_lfp_sugars = sort (int_ord o apply2 #fp_res_index) basic_lfp_sugars;

    val indices = map #fp_res_index basic_lfp_sugars;
    val perm_indices = map #fp_res_index perm_basic_lfp_sugars;

    val perm_ctrss = map (#ctrs o #ctr_sugar) perm_basic_lfp_sugars;

    val nn0 = length arg_Ts;
    val nn = length perm_ctrss;
    val kks = 0 upto nn - 1;

    (* offset of a type's first constructor = number of constructors of the preceding types *)
    val perm_ctr_offsets = map (fn kk => Integer.sum (map length (take kk perm_ctrss))) kks;

    val perm_fpTs = map #T perm_basic_lfp_sugars;
    val perm_Cs = map #C perm_basic_lfp_sugars;
    val perm_fun_arg_Tssss = map #fun_arg_Tsss perm_basic_lfp_sugars;

    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;

    val inducts = unpermute0 (conj_dests nn common_induct);

    val fpTs = unpermute perm_fpTs;
    val Cs = unpermute perm_Cs;
    val ctr_offsets = unpermute perm_ctr_offsets;

    val As_rho = tvar_subst thy (take nn0 fpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;
    val substACT = substAT o substCT;

    val perm_Cs' = map substCT perm_Cs;

    (* one recursor argument: nested if its type mentions some C, plain otherwise;
       two recursor arguments: a mutual call *)
    fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
      | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
        val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
        val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
      in
        {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
         rec_thm = rec_thm}
      end;

    fun mk_ctr_specs fp_res_index k ctrs rec_thms =
      @{map 4} mk_ctr_spec ctrs (k upto k + length ctrs - 1) (nth perm_fun_arg_Tssss fp_res_index)
        rec_thms;

    fun mk_spec ctr_offset
        ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) =
      {recx = mk_co_rec thy Least_FP perm_Cs' (substAT T) recx,
       fp_nesting_map_ident0s = fp_nesting_map_ident0s, fp_nesting_map_comps = fp_nesting_map_comps,
       fp_nesting_pred_maps = fp_nesting_pred_maps,
       ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms};
  in
    ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts,
      induct_attrs, map #T basic_lfp_sugars), lthy)
  end;

val undef_const = Const (\<^const_name>\<open>undefined\<close>, dummyT);

(* One user equation "f ls (C xs) rs = rhs", split into its components. *)
type eqn_data = {
  fun_name: string,
  rec_type: typ,
  ctr: term,
  ctr_args: term list,
  left_args: term list,
  right_args: term list,
  res_type: typ,
  rhs_term: term,
  user_eqn: term
};

(* Splits the equation eqn0 into eqn_data.  Checked along the way: the equation has the
   shape "lhs = rhs", the head is a free variable listed in fun_names, exactly one
   argument is not a free variable, that argument is a fully applied head whose own
   arguments are free variables, no variable occurs twice on the left, and every free
   variable on the right is an argument, one of fun_names, or fixed in ctxt.  Each
   violation is reported through the matching error function. *)
fun dissect_eqn ctxt fun_names eqn0 =
  let
    val eqn = drop_all eqn0 |> HOLogic.dest_Trueprop
      handle TERM _ => ill_formed_equation_lhs_rhs ctxt [eqn0];
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ => ill_formed_equation_lhs_rhs ctxt [eqn];
    val (fun_name, args) = strip_comb lhs
      |>> (fn x => if is_Free x then fst (dest_Free x) else ill_formed_equation_head ctxt [eqn]);
    val (left_args, rest) = chop_prefix is_Free args;
    val (nonfrees, right_args) = chop_suffix is_Free rest;
    val num_nonfrees = length nonfrees;
    val _ = num_nonfrees = 1 orelse
      (if num_nonfrees = 0 then missing_pattern ctxt [eqn]
       else more_than_one_nonvar_in_lhs ctxt [eqn]);
    val _ = member (op =) fun_names fun_name orelse ill_formed_equation_head ctxt [eqn];

    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      partially_applied_ctr_in_pattern ctxt [eqn];

    val _ = check_duplicate_variables_in_lhs ctxt [eqn] (left_args @ ctr_args @ right_args);
    val _ = forall is_Free ctr_args orelse nonprimitive_pattern_in_lhs ctxt [eqn];
    val _ =
      let
        val bads =
          fold_aterms (fn x as Free (v, _) =>
              if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
                  not (member (op =) fun_names v) andalso not (Variable.is_fixed ctxt v)) then
                cons x
              else
                I
            | _ => I) rhs [];
      in
        null bads orelse extra_variable_in_rhs ctxt [eqn] (hd bads)
      end;
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn0}
  end;

(* Returns a term transformer that replaces recursive calls on constructor arguments.
   mutual_calls and nested_calls are association lists from constructor arguments to the
   terms that replace the corresponding calls; nested calls are rewritten through
   rewrite_nested_rec_call.  A call to one of the functions that survives the
   substitution is reported via rec_call_not_apply_to_ctr_arg. *)
fun subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    fun try_nested_rec bound_Ts y t =
      AList.lookup (op =) nested_calls y
      |> Option.map (fn y' => rewrite_nested_rec_call ctxt has_call get_ctr_pos bound_Ts y y' t);

    fun subst bound_Ts (t as g' $ y) =
        let
          fun subst_comb (h $ z) = subst bound_Ts h $ subst bound_Ts z
            | subst_comb t = t;

          val y_head = head_of y;
        in
          if not (member (op =) ctr_args y_head) then
            subst_comb t
          else
            (case try_nested_rec bound_Ts y_head t of
              SOME t' => subst_comb t'
            | NONE =>
              let val (g, g_args) = strip_comb g' in
                (case try (get_ctr_pos o fst o dest_Free) g of
                  (* ~1 is get_ctr_pos's answer for names that are not being defined *)
                  SOME ~1 => subst_comb t
                | SOME ctr_pos =>
                  (length g_args >= ctr_pos orelse too_few_args_in_rec_call ctxt [] t;
                   (case AList.lookup (op =) mutual_calls y of
                     SOME y' => list_comb (y', map (subst bound_Ts) g_args)
                   | NONE => subst_comb t))
                | NONE => subst_comb t)
              end)
        end
      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst _ t = t

    fun subst' t =
      if has_call t then rec_call_not_apply_to_ctr_arg ctxt [] t
      else try_nested_rec [] (head_of t) t |> the_default t;
  in
    subst' o subst []
  end;

(* Builds the recursor argument for one constructor.  Without an equation the argument is
   the undefined constant.  Otherwise it is the right-hand side with its recursive calls
   substituted, abstracted over the recursor's arguments for this constructor followed by
   the equation's left and right arguments. *)
fun build_rec_arg ctxt (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
    (eqn_data_opt : eqn_data option) =
  (case eqn_data_opt of
    NONE => undef_const
  | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
    let
      val calls = #calls ctr_spec;
      (* a mutual call occupies two argument positions, every other call one *)
      val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;

      (* "try" turns the partial lambdas into filters: entries of the wrong kind raise
         Match and are dropped by map_filter *)
      val no_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p)));
      val mutual_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p)));
      val nested_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Nested_Rec p => p)));

      fun ensure_unique frees t =
        if member (op =) frees t then Free (the_single (Term.variant_frees t [dest_Free t])) else t;

      val args = replicate n_args ("", dummyT)
        |> Term.rename_wrt_term t
        |> map Free
        |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
          no_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) => fn xs =>
            nth_map arg_idx (K (ensure_unique xs
              (retype_const_or_free T (nth ctr_args ctr_arg_idx)))) xs)
          mutual_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
            nth_map arg_idx (K (retype_const_or_free T (nth ctr_args ctr_arg_idx))))
          nested_calls';

      val fun_name_ctr_pos_list =
        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
      val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
      val mutual_calls = map (map_prod (nth ctr_args) (nth args o fst)) mutual_calls';
      val nested_calls = map (map_prod (nth ctr_args) (nth args o fst)) nested_calls';
    in
      t
      |> subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls
      |> fold_rev lambda (args @ left_args @ right_args)
    end);

(* Produces one definition specification ((binding, mixfix), ((def binding, []), rhs))
   per function, where rhs applies the function's recursor to the recursor arguments of
   all constructors, ordered by offset.  Equations for unknown constructors and several
   equations for one constructor are reported through the corresponding error functions;
   a missing equation triggers no_equation_for_ctr_warning unless the function is marked
   nonexhaustive or the context is not visible. *)
fun build_defs ctxt nonexhaustives bs mxs (funs_data : eqn_data list list)
    (rec_specs : rec_spec list) has_call =
  let
    val n_funs = length funs_data;

    val ctr_spec_eqn_data_list' =
      maps (fn ((xs, ys), z) =>
        let
          val zs = replicate (length xs) z;
          val (b, c) = finds (fn ((x, _), y) => #ctr x = #ctr y) (xs ~~ zs) ys;
          val _ = null c orelse excess_equations ctxt (map #rhs_term c);
        in b end) (map #ctr_specs (take n_funs rec_specs) ~~ funs_data ~~ nonexhaustives);

    val (_ : unit list) = ctr_spec_eqn_data_list' |> map (fn (({ctr, ...}, nonexhaustive), x) =>
      if length x > 1 then
        multiple_equations_for_ctr ctxt (map #user_eqn x)
      else if length x = 1 orelse nonexhaustive orelse not (Context_Position.is_visible ctxt) then
        ()
      else
        no_equation_for_ctr_warning ctxt [] ctr);

    (* constructors of the rec_specs beyond the first n_funs get no equations *)
    val ctr_spec_eqn_data_list =
      map (apfst fst) ctr_spec_eqn_data_list' @
        (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));

    val recs = take n_funs rec_specs |> map #recx;
    val rec_args = ctr_spec_eqn_data_list
      |> sort (op < o apply2 (#offset o fst) |> make_ord)
      |> map (uncurry (build_rec_arg ctxt funs_data has_call) o apsnd (try the_single));
    (* all equations of a function must have the pattern at the same argument position *)
    val ctr_poss = map (fn x =>
      if length (distinct (op = o apply2 (length o #left_args)) x) <> 1 then
        inconstant_pattern_pos_for_fun ctxt [] (#fun_name (hd x))
      else
        hd x |> #left_args |> length) funs_data;
  in
    (recs, ctr_poss)
    |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
    |> Syntax.check_terms ctxt
    |> @{map 3} (fn b => fn mx => fn t =>
        ((b, mx), ((Binding.concealed (Thm.def_binding b), []), t)))
      bs mxs
  end;

(* For each constructor argument of the equation, collects terms built with
   mk_partial_compN from the subterms of the right-hand side in which a term satisfying
   has_call is applied to that argument.  Returns NONE when the constructor has no
   arguments, and SOME (ctr, one list per argument) otherwise. *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
      | find bound_Ts (t as _ $ _) ctr_arg =
        let
          val typof = curry fastype_of1 bound_Ts;
          val (f', args') = strip_comb t;
          val n = find_index (equal ctr_arg o head_of) args';
        in
          if n < 0 then
            find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
          else
            let
              val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
              val (arg_head, arg_args) = Term.strip_comb arg;
            in
              if has_call f then
                mk_partial_compN (length arg_args) (typof arg_head) f ::
                maps (fn x => find bound_Ts x ctr_arg) args
              else
                find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
            end
        end
      | find _ _ _ = [];
  in
    map (find [] rhs_term) ctr_args
    |> (fn [] => NONE | callss => SOME (ctr, callss))
  end;

(* Tactic for one user equation: unfold the function definitions, rewrite with the
   recursor theorem (extended by fun_cong once per extra argument) via transitivity,
   unfold the nested simps and the map/pred-map theorems, then close the goal by
   reflexivity or the special endgame tactic. *)
fun mk_primrec_tac ctxt num_extra_args fp_nesting_map_ident0s fp_nesting_map_comps
    fp_nesting_pred_maps fun_defs recx =
  unfold_thms_tac ctxt fun_defs THEN
  HEADGOAL (rtac ctxt (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
  unfold_thms_tac ctxt (nested_simps ctxt @ fp_nesting_map_ident0s @ fp_nesting_map_comps @
    fp_nesting_pred_maps) THEN
  REPEAT_DETERM (HEADGOAL (rtac ctxt refl) ORELSE
    special_endgame_tac ctxt fp_nesting_map_ident0s fp_nesting_map_comps fp_nesting_pred_maps);

(* Analyses the specification and returns
     (((fun_names, qualifys, arg_Ts, defs), prove), lthy')
   where defs are the definition specifications from build_defs, prove is a continuation
   that takes a local theory and the results of the definitions, interprets the plugins
   and proves the user equations, and lthy' has the induction rules noted when n2m holds.
   Raises OLD_PRIMREC if some argument type is an old-style datatype that is not also a
   new one. *)
fun prepare_primrec plugins nonexhaustives transfers fixes specs lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val qualifys = map (fold_rev (uncurry Binding.qualify o swap) o Binding.path_of) bs;
    val eqns_data = map (dissect_eqn lthy0 fun_names) specs;
    val funs_data = eqns_data
      |> partition_eq (op = o apply2 #fun_name)
      |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
      |> map (fn (x, y) => the_single y
        handle List.Empty => missing_equations_for_fun x);

    val frees = map (fst #>> Binding.name_of #> Free) fixes;
    val has_call = exists_subterm (member (op =) frees);
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val callssss = funs_data
      |> map (partition_eq (op = o apply2 #ctr))
      |> map (maps (map_filter (find_rec_calls has_call)));

    fun is_only_old_datatype (Type (s, _)) =
        is_some (Old_Datatype_Data.get_info thy s) andalso not (is_new_datatype lthy0 s)
      | is_only_old_datatype _ = false;

    val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
    val _ = List.app (uncurry (check_top_sort lthy0)) (bs ~~ res_Ts);

    val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs, Ts), lthy) =
      rec_specs_of bs arg_Ts res_Ts frees callssss lthy0;

    val actual_nn = length funs_data;

    val ctrs = maps (map #ctr o #ctr_specs) rec_specs;
    val _ = List.app (fn {ctr, user_eqn, ...} =>
        ignore (member (op =) ctrs ctr orelse not_constructor_in_pattern lthy0 [user_eqn] ctr))
      eqns_data;

    val defs = build_defs lthy nonexhaustives bs mxs funs_data rec_specs has_call;

    (* proves the user equations of one function; js are the positions of these equations
       within the overall specification *)
    fun prove def_thms ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps,
        fp_nesting_pred_maps, ...} : rec_spec) (fun_data : eqn_data list) lthy' =
      let
        val js =
          find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr)))
            fun_data eqns_data;

        val simps = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
          |> fst
          |> map_filter (try (fn (x, [y]) =>
            (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
          |> map (fn (user_eqn, num_extra_args, rec_thm) =>
              Goal.prove_sorry lthy' [] [] user_eqn
                (fn {context = ctxt, prems = _} =>
                  mk_primrec_tac ctxt num_extra_args fp_nesting_map_ident0s fp_nesting_map_comps
                    fp_nesting_pred_maps def_thms rec_thm)
              |> Thm.close_derivation \<^here>);
      in
        ((js, simps), lthy')
      end;

    val notes =
      (if n2m then
         @{map 3} (fn name => fn qualify => fn thm => (name, qualify, inductN, [thm], induct_attrs))
           fun_names qualifys (take actual_nn inducts)
       else
         [])
      |> map (fn (prefix, qualify, thmN, thms, attrs) =>
        ((qualify (Binding.qualify true prefix (Binding.name thmN)), attrs), [(thms, [])]));

    val common_name = mk_common_name fun_names;
    val common_qualify = fold_rev I qualifys;

    val common_notes =
      (if n2m then [(inductN, [common_induct], [])] else [])
      |> map (fn (thmN, thms, attrs) =>
        ((common_qualify (Binding.qualify true common_name (Binding.name thmN)), attrs),
          [(thms, [])]));
  in
    (((fun_names, qualifys, arg_Ts, defs),
      fn lthy => fn defs =>
        let
          val def_thms = map (snd o snd) defs;
          val ts = map fst defs;
          val phi = Local_Theory.target_morphism lthy;
          val fp_rec_sugar =
            {transfers = transfers, fun_names = fun_names, funs = map (Morphism.term phi) ts,
             fun_defs = Morphism.fact phi def_thms, fpTs = take actual_nn Ts};
        in
          map_prod split_list (interpret_lfp_rec_sugar plugins fp_rec_sugar)
            (@{fold_map 2} (prove def_thms) (take actual_nn rec_specs) funs_data lthy)
        end),
      lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
  end;

(* Core of primrec: checks for duplicate constant names, prepares the specification,
   makes the definitions and proves the user equations.  The same nonexhaustive and
   transfer flags are applied to every function.  The result record carries the name
   prefixes, the names of the argument types and the (terms, definitions, (jss, simpss))
   triple. *)
fun primrec_simple0 int plugins nonexhaustive transfer fixes ts lthy =
  let
    val _ = check_duplicate_const_names (map (fst o fst) fixes);

    val actual_nn = length fixes;

    val nonexhaustives = replicate actual_nn nonexhaustive;
    val transfers = replicate actual_nn transfer;

    val (((names, qualifys, arg_Ts, defs), prove), lthy') =
      prepare_primrec plugins nonexhaustives transfers fixes ts lthy;
  in
    lthy'
    |> fold_map Local_Theory.define defs
    |> tap (uncurry (print_def_consts int))
    |-> (fn defs => fn lthy =>
      let
        val ((jss, simpss), lthy) = prove lthy defs;
        val res =
          {prefix = (names, qualifys),
           types = map (#1 o dest_Type) arg_Ts,
           result = (map fst defs, map (snd o snd) defs, (jss, simpss))};
      in (res, lthy) end)
  end;

(* primrec_simple0 with the default plugin filter and both flags off; falls back to
   Old_Primrec.primrec_simple on OLD_PRIMREC, adapting its result to the same shape. *)
fun primrec_simple int fixes ts lthy =
  primrec_simple0 int Plugin_Name.default_filter false false fixes ts lthy
    |>> (fn {prefix, result, ...} => (prefix, result))
  handle OLD_PRIMREC () =>
    Old_Primrec.primrec_simple int fixes ts lthy
    |>> (fn {prefix, result = (ts, thms), ...} =>
          (map_split (rpair I) [prefix], (ts, [], ([], [thms]))))

(* Shared implementation of primrec and primrec_cmd.  prep_spec checks or reads the raw
   specification; old_primrec is the fallback used on OLD_PRIMREC.  After the definitions
   and proofs, the equations are registered with Spec_Rules, noted under the names
   supplied by the user and collectively under "simps" (per function), and declared as
   default code equations if the code plugin is enabled. *)
fun gen_primrec old_primrec prep_spec int opts raw_fixes raw_specs lthy =
  let
    (* the last Plugins_Option given wins *)
    val plugins = get_first (fn Plugins_Option f => SOME (f lthy) | _ => NONE) (rev opts)
      |> the_default Plugin_Name.default_filter;
    val nonexhaustive = exists (can (fn Nonexhaustive_Option => ())) opts;
    val transfer = exists (can (fn Transfer_Option => ())) opts;

    val (fixes, specs) = fst (prep_spec raw_fixes raw_specs lthy);
    val spec_name = Binding.conglomerate (map (#1 o #1) fixes);

    val mk_notes =
      flat oooo @{map 4} (fn js => fn prefix => fn qualify => fn thms =>
        let
          val (bs, attrss) = map_split (fst o nth specs) js;
          val notes =
            @{map 3} (fn b => fn attrs => fn thm =>
                ((Binding.qualify false prefix b, nitpicksimp_simp_attrs @ attrs),
                 [([thm], [])]))
              bs attrss thms;
        in
          ((qualify (Binding.qualify true prefix (Binding.name simpsN)), []), [(thms, [])]) :: notes
        end);
  in
    lthy
    |> primrec_simple0 int plugins nonexhaustive transfer fixes (map snd specs)
    |-> (fn {prefix = (names, qualifys), types, result = (ts, defs, (jss, simpss))} =>
      Spec_Rules.add spec_name (Spec_Rules.equational_primrec types) ts (flat simpss)
      #> Local_Theory.notes (mk_notes jss names qualifys simpss)
      #-> (fn notes =>
        plugins code_plugin ? Code.declare_default_eqns (map (rpair true) (maps snd notes))
        #> pair (ts, defs, map_filter (fn ("", _) => NONE | (_, thms) => SOME thms) notes)))
  end
  handle OLD_PRIMREC () =>
    old_primrec int raw_fixes raw_specs lthy
    |>> (fn {result = (ts, thms), ...} => (ts, [], [thms]));

val primrec = gen_primrec Old_Primrec.primrec Specification.check_multi_specs;
val primrec_cmd = gen_primrec Old_Primrec.primrec_cmd Specification.read_multi_specs;

(* primrec in the named target of a global theory *)
fun primrec_global int opts fixes specs =
  Named_Target.theory_init
  #> primrec int opts fixes specs
  ##> Local_Theory.exit_global;

(* primrec inside an overloading target for ops *)
fun primrec_overloaded int opts ops fixes specs =
  Overloading.overloading ops
  #> primrec int opts fixes specs
  ##> Local_Theory.exit_global;

val rec_option_parser = Parse.group (K "option")
  (Plugin_Name.parse_filter >> Plugins_Option
   || Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option
   || Parse.reserved "transfer" >> K Transfer_Option);

val _ = Outer_Syntax.local_theory \<^command_keyword>\<open>primrec\<close>
  "define primitive recursive functions"
  ((Scan.optional (\<^keyword>\<open>(\<close> |-- Parse.!!! (Parse.list1 rec_option_parser)
      --| \<^keyword>\<open>)\<close>) []) -- Parse_Spec.specification
    >> (fn (opts, (fixes, specs)) => snd o primrec_cmd true opts fixes specs));

end;