1(*  Title:      HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
2    Author:     Lorenz Panny, TU Muenchen
3    Author:     Jasmin Blanchette, TU Muenchen
4    Copyright   2013
5
6Recursor sugar ("primrec").
7*)
8
signature BNF_LFP_REC_SUGAR =
sig
  (* Options accepted by the "primrec" command. *)
  datatype rec_option =
      Plugins_Option of Proof.context -> Plugin_Name.filter
    | Nonexhaustive_Option
    | Transfer_Option

  (* How a constructor argument is passed on to the recursor. *)
  datatype rec_call =
      No_Rec of int * typ
    | Mutual_Rec of (int * typ) * (int * typ)
    | Nested_Rec of int * typ

  (* Recursor view of a single constructor. *)
  type rec_ctr_spec = {ctr: term, offset: int, calls: rec_call list, rec_thm: thm}

  (* Recursor view of one datatype: its recursor, the nesting theorems used in the proofs, and
     one entry per constructor. *)
  type rec_spec =
    {recx: term,
     fp_nesting_map_ident0s: thm list,
     fp_nesting_map_comps: thm list,
     fp_nesting_pred_maps: thm list,
     ctr_specs: rec_ctr_spec list}

  type basic_lfp_sugar =
    {T: typ,
     fp_res_index: int,
     C: typ,
     fun_arg_Tsss: typ list list list,
     ctr_sugar: Ctr_Sugar.ctr_sugar,
     recx: term,
     rec_thms: thm list}

  (* Hooks that may be registered via register_lfp_rec_extension; when none is registered,
     the defaults of this structure are used. *)
  type lfp_rec_extension =
    {nested_simps: thm list,
     special_endgame_tac: Proof.context -> thm list -> thm list -> thm list -> tactic,
     is_new_datatype: Proof.context -> string -> bool,
     basic_lfp_sugars_of: binding list -> typ list -> term list ->
       (term * term list list) list list -> local_theory ->
       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
       * Token.src list * bool * local_theory,
     rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
       term -> term -> term -> term) option}

  (** Infrastructure **)

  val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
  val default_basic_lfp_sugars_of: binding list -> typ list -> term list ->
    (term * term list list) list list -> local_theory ->
    typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
    * Token.src list * bool * local_theory
  val rec_specs_of: binding list -> typ list -> typ list -> term list ->
    (term * term list list) list list -> local_theory ->
    (bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory
  val lfp_rec_sugar_interpretation: string ->
    (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) -> theory -> theory

  (** Definitional entry points **)

  val primrec: bool -> rec_option list -> (binding * typ option * mixfix) list ->
    Specification.multi_specs -> local_theory ->
    (term list * thm list * thm list list) * local_theory
  val primrec_cmd: bool -> rec_option list -> (binding * string option * mixfix) list ->
    Specification.multi_specs_cmd -> local_theory ->
    (term list * thm list * thm list list) * local_theory
  val primrec_global: bool -> rec_option list -> (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list * thm list list) * theory
  val primrec_overloaded: bool -> rec_option list -> (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list * thm list list) * theory
  val primrec_simple: bool -> ((binding * typ) * mixfix) list -> term list -> local_theory ->
    ((string list * (binding -> binding) list)
     * (term list * thm list * (int list list * thm list list))) * local_theory
end;
81
82structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
83struct
84
85open Ctr_Sugar
86open Ctr_Sugar_Util
87open Ctr_Sugar_General_Tactics
88open BNF_FP_Rec_Sugar_Util
89
(* Base names of the facts noted by this package. *)
val simpsN = "simps";
val inductN = "induct";

(* Attributes attached to every user-visible primrec equation. *)
val simp_attrs = @{attributes [simp]};
val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val nitpicksimp_simp_attrs = nitpicksimp_attrs @ simp_attrs;

(* Raised by the new-style implementation when a recursion argument has an old-style-only
   datatype; the entry points catch it and delegate to Old_Primrec. *)
exception OLD_PRIMREC of unit;
98
(* Options accepted by the "primrec" command. *)
datatype rec_option =
    Plugins_Option of Proof.context -> Plugin_Name.filter
  | Nonexhaustive_Option
  | Transfer_Option;

(* How a constructor argument is passed on to the recursor. Each pair is an index into the
   recursor argument list together with a type. Mutual_Rec contributes two recursor arguments:
   the first pair is treated like No_Rec, and the second presumably stands for the recursive
   result (TODO confirm against build_rec_arg). *)
datatype rec_call =
    No_Rec of int * typ
  | Mutual_Rec of (int * typ) * (int * typ)
  | Nested_Rec of int * typ;

(* Recursor view of a single constructor; "offset" orders constructors globally. *)
type rec_ctr_spec = {ctr: term, offset: int, calls: rec_call list, rec_thm: thm};

(* Recursor view of one datatype: its recursor, the nesting theorems used in the proofs, and
   one entry per constructor. *)
type rec_spec =
  {recx: term,
   fp_nesting_map_ident0s: thm list,
   fp_nesting_map_comps: thm list,
   fp_nesting_pred_maps: thm list,
   ctr_specs: rec_ctr_spec list};

type basic_lfp_sugar =
  {T: typ,
   fp_res_index: int,
   C: typ,
   fun_arg_Tsss: typ list list list,
   ctr_sugar: ctr_sugar,
   recx: term,
   rec_thms: thm list};

(* Hooks that may be registered via register_lfp_rec_extension. *)
type lfp_rec_extension =
  {nested_simps: thm list,
   special_endgame_tac: Proof.context -> thm list -> thm list -> thm list -> tactic,
   is_new_datatype: Proof.context -> string -> bool,
   basic_lfp_sugars_of: binding list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
     * Token.src list * bool * local_theory,
   rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
     term -> term -> term -> term) option};
141
(* Theory data: the (optional) registered extension. *)
structure Data = Theory_Data
(
  type T = lfp_rec_extension option;
  val empty : T = NONE;
  val extend = I;
  val merge = merge_options;
);

(* Stores the given extension in the theory data. *)
fun register_lfp_rec_extension extension = Data.put (SOME extension);
151
local
  (* The registered extension, if any, of the context's theory. *)
  fun extension_of ctxt = Data.get (Proof_Context.theory_of ctxt);
in

(* Extra simplification rules for nested recursion; none without an extension. *)
fun nested_simps ctxt =
  (case extension_of ctxt of
    NONE => []
  | SOME {nested_simps = simps, ...} => simps);

(* Final proof step for nested recursion; fails (no_tac) without an extension. *)
fun special_endgame_tac ctxt =
  (case extension_of ctxt of
    NONE => (fn _ => fn _ => fn _ => no_tac)
  | SOME {special_endgame_tac = tac, ...} => tac ctxt);

(* Whether a type name denotes a new-style datatype; every name does without an extension. *)
fun is_new_datatype ctxt =
  (case extension_of ctxt of
    NONE => (fn _ => true)
  | SOME {is_new_datatype = is_new, ...} => is_new ctxt);

end;
166
(* Fallback used when no extension is registered. It supports exactly one argument type that
   has constructor sugar, and uses that type's case combinator and case theorems in place of a
   recursor and its equations. *)
fun default_basic_lfp_sugars_of _ [Type (arg_T_name, _)] _ _ ctxt =
    let
      val sugar =
        (case ctr_sugar_of ctxt arg_T_name of
          NONE => error ("Unsupported type " ^ quote arg_T_name ^ " at this stage")
        | SOME found => found);
      val {T, ctrs, casex, case_thms, ...} = sugar;

      (* Result type of the case combinator. *)
      val C = body_type (fastype_of casex);
      (* Each constructor argument type, wrapped in a singleton list. *)
      fun arg_Tss_of ctr = map single (binder_types (fastype_of ctr));
    in
      ([], [0],
       [{T = T, fp_res_index = 0, C = C, fun_arg_Tsss = map arg_Tss_of ctrs, ctr_sugar = sugar,
         recx = casex, rec_thms = case_thms}],
       [], [], [], TrueI (*dummy*), [], false, ctxt)
    end
  | default_basic_lfp_sugars_of _ [T] _ _ ctxt =
    error ("Cannot recurse through type " ^ quote (Syntax.string_of_typ ctxt T))
  | default_basic_lfp_sugars_of _ _ _ _ _ = error "Unsupported mutual recursion at this stage";
186
(* Uses the registered extension's implementation, or the default one if none is registered. *)
fun basic_lfp_sugars_of bs arg_Ts callers callssss lthy =
  let
    val impl =
      (case Data.get (Proof_Context.theory_of lthy) of
        NONE => default_basic_lfp_sugars_of
      | SOME {basic_lfp_sugars_of = f, ...} => f);
  in
    impl bs arg_Ts callers callssss lthy
  end;

(* Rewriting of nested recursive calls; only available through a registered extension. *)
fun rewrite_nested_rec_call ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {rewrite_nested_rec_call = SOME rewrite, ...} => rewrite ctxt
  | _ => error "Unsupported nested recursion");
196
(* Plugin registry for fp_rec_sugar values produced by this package. *)
structure LFP_Rec_Sugar_Plugin = Plugin(type T = fp_rec_sugar);

(* Registers "f" under "name". Each sugar is transferred to the current theory before "f"
   sees it. *)
fun lfp_rec_sugar_interpretation name f =
  let
    fun interpret sugar lthy =
      f (transfer_fp_rec_sugar (Proof_Context.theory_of lthy) sugar) lthy;
  in
    LFP_Rec_Sugar_Plugin.interpretation name interpret
  end;

val interpret_lfp_rec_sugar = LFP_Rec_Sugar_Plugin.data;
204
(* Computes one rec_spec per datatype involved in the recursion, with type variables instantiated
   to the argument types "arg_Ts" and result types "res_Ts" (padded with unit). It also returns
   the n2m flag, the missing argument types, the common induction rule and its per-type
   conjuncts, the induction attributes, and the datatypes' types. *)
fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (missing_arg_Ts, perm0_kks, sugars, map_ident0s, map_comps, pred_maps, common_induct,
         induct_attrs, n2m, lthy) =
      basic_lfp_sugars_of bs arg_Ts callers callssss0 lthy0;

    (* The sugars ordered by their index in the fixpoint result. *)
    val sorted_sugars = sort (int_ord o apply2 #fp_res_index) sugars;
    val orig_indices = map #fp_res_index sugars;
    val sorted_indices = map #fp_res_index sorted_sugars;
    val sorted_ctrss = map (#ctrs o #ctr_sugar) sorted_sugars;

    val num_args = length arg_Ts;
    val num_types = length sorted_ctrss;
    val type_ks = 0 upto num_types - 1;

    (* Position of each type's first constructor in the concatenated constructor list. *)
    val sorted_offsets = map (fn k => Integer.sum (map length (take k sorted_ctrss))) type_ks;
    val sorted_fpTs = map #T sorted_sugars;
    val sorted_Cs = map #C sorted_sugars;
    val sorted_fun_arg_Tssss = map #fun_arg_Tsss sorted_sugars;

    fun unpermute0 xs = permute_like_unique (op =) perm0_kks type_ks xs;
    fun unpermute xs = permute_like_unique (op =) sorted_indices orig_indices xs;

    val inducts = unpermute0 (conj_dests num_types common_induct);
    val fpTs = unpermute sorted_fpTs;
    val Cs = unpermute sorted_Cs;
    val ctr_offsets = unpermute sorted_offsets;

    val As_rho = tvar_subst thy (take num_args fpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT num_types res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;
    fun substACT T = substAT (substCT T);

    val sorted_Cs' = map substCT sorted_Cs;

    (* A single argument is nested-recursive if its type mentions one of the result type
       variables; two arguments form a mutually recursive pair. *)
    fun call_of [i] [T] =
        if exists_subtype_in Cs T then Nested_Rec (i, substACT T) else No_Rec (i, substACT T)
      | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        val hss = fst (indexedd fun_arg_Tss 0);
        val hs = flat_rec_arg_args hss;
      in
        {ctr = substA ctr, offset = offset,
         calls = map2 call_of (map (map (find_index_eq hs)) hss) fun_arg_Tss,
         rec_thm = rec_thm}
      end;

    fun mk_ctr_specs fp_res_index k ctrs rec_thms =
      let val offsets = k upto k + length ctrs - 1 in
        @{map 4} mk_ctr_spec ctrs offsets (nth sorted_fun_arg_Tssss fp_res_index) rec_thms
      end;

    fun mk_spec ctr_offset
        ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) =
      {recx = mk_co_rec thy Least_FP sorted_Cs' (substAT T) recx,
       fp_nesting_map_ident0s = map_ident0s, fp_nesting_map_comps = map_comps,
       fp_nesting_pred_maps = pred_maps,
       ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms};

    val specs = map2 mk_spec ctr_offsets sugars;
  in
    ((n2m, specs, missing_arg_Ts, common_induct, inducts, induct_attrs, map #T sugars), lthy)
  end;
276
(* Recursor argument used for constructors that have no defining equation. *)
val undef_const = Const (\<^const_name>\<open>undefined\<close>, dummyT);

(* Data extracted from one user equation
   "fun_name left_args (ctr ctr_args) right_args = rhs_term". *)
type eqn_data =
  {fun_name: string,
   rec_type: typ,        (* result type of the constructor, i.e. the pattern's type *)
   ctr: term,
   ctr_args: term list,
   left_args: term list,
   right_args: term list,
   res_type: typ,        (* type of the function once the pattern argument is removed *)
   rhs_term: term,
   user_eqn: term};      (* the equation as given by the user *)
290
(* Decomposes a user equation "f xs (C ys) zs = rhs" into an eqn_data record. It reports an
   error unless all of the following hold:
   - the head is one of "fun_names";
   - exactly one argument is not a free variable, and it is a fully applied constructor over
     distinct free variables;
   - the right-hand side contains no free variables other than the left-hand-side variables,
     the functions being defined, and variables fixed in "ctxt". *)
fun dissect_eqn ctxt fun_names eqn0 =
  let
    val eqn = drop_all eqn0 |> HOLogic.dest_Trueprop
      handle TERM _ => ill_formed_equation_lhs_rhs ctxt [eqn0];
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ => ill_formed_equation_lhs_rhs ctxt [eqn];
    val (fun_name, args) = strip_comb lhs
      |>> (fn x => if is_Free x then fst (dest_Free x) else ill_formed_equation_head ctxt [eqn]);

    (* Variables before and after the single constructor pattern. *)
    val (left_args, rest) = chop_prefix is_Free args;
    val (nonfrees, right_args) = chop_suffix is_Free rest;
    val num_nonfrees = length nonfrees;
    val _ = num_nonfrees = 1 orelse
      (if num_nonfrees = 0 then missing_pattern ctxt [eqn]
       else more_than_one_nonvar_in_lhs ctxt [eqn]);
    (* Like the other checks, the error function is called directly (it does not return). *)
    val _ = member (op =) fun_names fun_name orelse ill_formed_equation_head ctxt [eqn];

    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      partially_applied_ctr_in_pattern ctxt [eqn];

    (* All variables bound by the left-hand side, computed once. *)
    val lhs_vars = left_args @ ctr_args @ right_args;
    val _ = check_duplicate_variables_in_lhs ctxt [eqn] lhs_vars;
    val _ = forall is_Free ctr_args orelse nonprimitive_pattern_in_lhs ctxt [eqn];

    fun is_extra_var (x as Free (v, _)) =
        not (member (op =) lhs_vars x) andalso not (member (op =) fun_names v) andalso
        not (Variable.is_fixed ctxt v)
      | is_extra_var _ = false;
    val bads = fold_aterms (fn x => if is_extra_var x then cons x else I) rhs [];
    val _ = null bads orelse extra_variable_in_rhs ctxt [eqn] (hd bads);
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn0}
  end;
337
(* Returns a function that replaces recursive calls in a right-hand side:
   - calls "g ... y" on a mutually recursive constructor argument "y" become the associated
     recursor variable, applied to the (recursively substituted) arguments before the pattern
     position;
   - nested calls are handed to rewrite_nested_rec_call.
   Any call remaining afterwards is reported as not being applied to a constructor argument. *)
fun subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    (* If "arg" is a nested-recursive constructor argument, rewrite "t" through the
       extension; otherwise NONE. *)
    fun try_nested_rec bound_Ts arg t =
      (case AList.lookup (op =) nested_calls arg of
        NONE => NONE
      | SOME arg' => SOME (rewrite_nested_rec_call ctxt has_call get_ctr_pos bound_Ts arg arg' t));

    fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst bound_Ts (t as fun_part $ arg) =
        let
          (* Recurse into both sides of an application; leave other terms unchanged. *)
          fun descend (f $ x) = subst bound_Ts f $ subst bound_Ts x
            | descend u = u;

          val arg_head = head_of arg;

          (* Handles "t" when its last argument is headed by a constructor argument. *)
          fun rewrite_call () =
            let val (callee, callee_args) = strip_comb fun_part in
              (case try (get_ctr_pos o fst o dest_Free) callee of
                NONE => descend t
              | SOME ~1 => descend t
              | SOME ctr_pos =>
                let
                  val _ = length callee_args >= ctr_pos orelse too_few_args_in_rec_call ctxt [] t;
                in
                  (case AList.lookup (op =) mutual_calls arg of
                    NONE => descend t
                  | SOME arg' => list_comb (arg', map (subst bound_Ts) callee_args))
                end)
            end;
        in
          if member (op =) ctr_args arg_head then
            (case try_nested_rec bound_Ts arg_head t of
              SOME t' => descend t'
            | NONE => rewrite_call ())
          else
            descend t
        end
      | subst _ t = t;

    (* Any remaining reference to a function being defined is an error. Otherwise, the whole
       term is tried as a nested call whose head is a constructor argument. *)
    fun finish t =
      if has_call t then rec_call_not_apply_to_ctr_arg ctxt [] t
      else the_default t (try_nested_rec [] (head_of t) t);
  in
    finish o subst []
  end;
377
(* Builds the recursor argument for one constructor from its defining equation, or uses
   "undefined" if there is none. Each of the constructor's recursor arguments gets a variable:
   - non-recursive arguments reuse the constructor argument itself;
   - mutual and nested arguments get the constructor argument retyped.
   Recursive calls in the right-hand side are replaced via subst_rec_calls. The result is
   abstracted over these variables followed by the equation's extra left/right arguments. *)
fun build_rec_arg ctxt (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
    (eqn_data_opt : eqn_data option) =
  (case eqn_data_opt of
    NONE => undef_const
  | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
    let
      val calls = #calls ctr_spec;
      (* A mutually recursive constructor argument contributes two recursor arguments. *)
      val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;

      (* (constructor argument index, (recursor argument index, type)) for each kind of call,
         selected by total functions rather than by catching Match. *)
      val tagged_calls = tag_list 0 calls;
      val no_calls' = tagged_calls |> map_filter
        (fn (i, No_Rec p) => SOME (i, p) | (i, Mutual_Rec (p, _)) => SOME (i, p) | _ => NONE);
      val mutual_calls' = tagged_calls |> map_filter
        (fn (i, Mutual_Rec (_, p)) => SOME (i, p) | _ => NONE);
      val nested_calls' = tagged_calls |> map_filter
        (fn (i, Nested_Rec p) => SOME (i, p) | _ => NONE);

      (* Renames "t" to a fresh variant if it already occurs in "frees". *)
      fun ensure_unique frees t =
        if member (op =) frees t then Free (the_single (Term.variant_frees t [dest_Free t])) else t;

      val args = replicate n_args ("", dummyT)
        |> Term.rename_wrt_term t
        |> map Free
        |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
          no_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) => fn xs =>
            nth_map arg_idx (K (ensure_unique xs
              (retype_const_or_free T (nth ctr_args ctr_arg_idx)))) xs)
          mutual_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
            nth_map arg_idx (K (retype_const_or_free T (nth ctr_args ctr_arg_idx))))
          nested_calls';

      (* Position of the constructor pattern for each function being defined; ~1 for any
         other name. *)
      val fun_name_ctr_pos_list =
        map (fn eqns => let val eqn = hd eqns in (#fun_name eqn, length (#left_args eqn)) end)
          funs_data;
      val get_ctr_pos = AList.lookup (op =) fun_name_ctr_pos_list #> the_default ~1;
      val mutual_calls = map (map_prod (nth ctr_args) (nth args o fst)) mutual_calls';
      val nested_calls = map (map_prod (nth ctr_args) (nth args o fst)) nested_calls';
    in
      t
      |> subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls
      |> fold_rev lambda (args @ left_args @ right_args)
    end);
421
(* Builds one definition "f = recursor applied to the constructor arguments" per function,
   with the argument order adjusted to the position of the constructor pattern. Along the way
   it checks that:
   - every equation matches a constructor;
   - no constructor has more than one equation;
   - each function keeps its pattern at the same position in all of its equations.
   A warning is given for constructors without an equation unless the function is
   "nonexhaustive" or the context is invisible. *)
fun build_defs ctxt nonexhaustives bs mxs (funs_data : eqn_data list list)
    (rec_specs : rec_spec list) has_call =
  let
    val n_funs = length funs_data;

    (* Each (constructor spec, nonexhaustive flag) paired with that constructor's equations. *)
    val ctr_spec_eqn_data_list' =
      maps (fn ((ctr_specs, eqns), nonexhaustive) =>
        let
          val flags = replicate (length ctr_specs) nonexhaustive;
          val (matched, unmatched) =
            finds (fn ((spec, _), eqn) => #ctr spec = #ctr eqn) (ctr_specs ~~ flags) eqns;
          val _ = null unmatched orelse excess_equations ctxt (map #rhs_term unmatched);
        in matched end) (map #ctr_specs (take n_funs rec_specs) ~~ funs_data ~~ nonexhaustives);

    (* Performed only for its errors and warnings, so List.app rather than map. *)
    val _ = ctr_spec_eqn_data_list' |> List.app (fn (({ctr, ...}, nonexhaustive), eqns) =>
      (case eqns of
        [_] => ()
      | [] =>
        if nonexhaustive orelse not (Context_Position.is_visible ctxt) then ()
        else no_equation_for_ctr_warning ctxt [] ctr
      | _ => multiple_equations_for_ctr ctxt (map #user_eqn eqns)));

    val ctr_spec_eqn_data_list =
      map (apfst fst) ctr_spec_eqn_data_list' @
      (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));

    val recs = take n_funs rec_specs |> map #recx;
    (* One recursor argument per constructor, in constructor-offset order (compared with
       int_ord, as in rec_specs_of). *)
    val rec_args = ctr_spec_eqn_data_list
      |> sort (int_ord o apply2 (#offset o fst))
      |> map (uncurry (build_rec_arg ctxt funs_data has_call) o apsnd (try the_single));
    val ctr_poss = map (fn eqns =>
      if length (distinct (op = o apply2 (length o #left_args)) eqns) <> 1 then
        inconstant_pattern_pos_for_fun ctxt [] (#fun_name (hd eqns))
      else
        length (#left_args (hd eqns))) funs_data;
  in
    map2 (fn recx => fn ctr_pos => permute_args ctr_pos (list_comb (recx, rec_args))) recs ctr_poss
    |> Syntax.check_terms ctxt
    |> @{map 3} (fn b => fn mx => fn t =>
        ((b, mx), ((Binding.concealed (Thm.def_binding b), []), t)))
      bs mxs
  end;
464
(* For each constructor argument of the equation, collects the recursive calls applied to it
   in the right-hand side. Each call is returned as a partial composition sized by the
   argument's own arguments. The result pairs the constructor with these lists, or is NONE if
   the constructor has no arguments. *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts t ctr_arg =
      (case t of
        Abs (_, T, b) => find (T :: bound_Ts) b ctr_arg
      | _ $ _ =>
        let
          fun find_in u = find bound_Ts u ctr_arg;
          val (head, all_args) = strip_comb t;
          (* First argument whose head is the constructor argument, or ~1. *)
          val n = find_index (fn u => head_of u = ctr_arg) all_args;
        in
          if n < 0 then
            find_in head @ maps find_in all_args
          else
            let
              val (prefix_args, rest_args) = chop n all_args;
              val callee = list_comb (head, prefix_args);
              val (arg_head, arg_args) = Term.strip_comb (hd rest_args);
            in
              if has_call callee then
                mk_partial_compN (length arg_args) (fastype_of1 (bound_Ts, arg_head)) callee ::
                maps find_in rest_args
              else
                find_in callee @ maps find_in rest_args
            end
        end
      | _ => []);
  in
    (case map (find [] rhs_term) ctr_args of
      [] => NONE
    | callss => SOME (ctr, callss))
  end;
493
(* Proves a user equation in the following steps:
   1. unfold the function definitions;
   2. rewrite with the recursor equation, first applying fun_cong once per extra argument;
   3. unfold the nesting theorems;
   4. close the goal by reflexivity or by the extension's endgame tactic. *)
fun mk_primrec_tac ctxt num_extra_args fp_nesting_map_ident0s fp_nesting_map_comps
    fp_nesting_pred_maps fun_defs recx =
  let
    val rec_eq = funpow num_extra_args (fn thm => thm RS fun_cong) recx;
    val nesting_thms = fp_nesting_map_ident0s @ fp_nesting_map_comps @ fp_nesting_pred_maps;
  in
    unfold_thms_tac ctxt fun_defs THEN
    HEADGOAL (rtac ctxt (rec_eq RS trans)) THEN
    unfold_thms_tac ctxt (nested_simps ctxt @ nesting_thms) THEN
    REPEAT_DETERM (HEADGOAL (rtac ctxt refl) ORELSE
      special_endgame_tac ctxt fp_nesting_map_ident0s fp_nesting_map_comps fp_nesting_pred_maps)
  end;
502
(* Prepares a primrec specification. It:
   - dissects the equations and groups them per function (in the order of "fixes");
   - raises OLD_PRIMREC if an argument type is an old-style-only datatype;
   - obtains recursor specifications and builds the definitions;
   - notes induction rules if an n2m reduction was performed.
   It returns the data needed to define the constants, together with a continuation. Given
   the defined constants, the continuation proves the user equations and runs the plugins. *)
fun prepare_primrec plugins nonexhaustives transfers fixes specs lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val qualifys = map (fold_rev (uncurry Binding.qualify o swap) o Binding.path_of) bs;
    val eqns_data = map (dissect_eqn lthy0 fun_names) specs;
    (* Equations grouped per function; a function without equations is an error. *)
    val funs_data = eqns_data
      |> partition_eq (op = o apply2 #fun_name)
      |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
      |> map (fn (x, y) => the_single y
        handle List.Empty => missing_equations_for_fun x);

    val frees = map (fst #>> Binding.name_of #> Free) fixes;
    val has_call = exists_subterm (member (op =) frees);
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val callssss = funs_data
      |> map (partition_eq (op = o apply2 #ctr))
      |> map (maps (map_filter (find_rec_calls has_call)));

    fun is_only_old_datatype (Type (s, _)) =
        is_some (Old_Datatype_Data.get_info thy s) andalso not (is_new_datatype lthy0 s)
      | is_only_old_datatype _ = false;

    val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
    val _ = List.app (uncurry (check_top_sort lthy0)) (bs ~~ res_Ts);

    val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs, Ts), lthy) =
      rec_specs_of bs arg_Ts res_Ts frees callssss lthy0;

    val actual_nn = length funs_data;

    (* Every pattern head must be a constructor of one of the involved datatypes. *)
    val ctrs = maps (map #ctr o #ctr_specs) rec_specs;
    val _ = List.app (fn {ctr, user_eqn, ...} =>
        ignore (member (op =) ctrs ctr orelse not_constructor_in_pattern lthy0 [user_eqn] ctr))
      eqns_data;

    val defs = build_defs lthy nonexhaustives bs mxs funs_data rec_specs has_call;

    (* Proves the equations of one function. It returns their indices among all equations and
       the theorems. *)
    fun prove def_thms ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps,
        fp_nesting_pred_maps, ...} : rec_spec) (fun_data : eqn_data list) lthy' =
      let
        val js =
          find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr)))
            fun_data eqns_data;

        val simps = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
          |> fst
          |> map_filter (try (fn (x, [y]) =>
            (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
          |> map (fn (user_eqn, num_extra_args, rec_thm) =>
              Goal.prove_sorry lthy' [] [] user_eqn
                (fn {context = ctxt, prems = _} =>
                  mk_primrec_tac ctxt num_extra_args fp_nesting_map_ident0s fp_nesting_map_comps
                    fp_nesting_pred_maps def_thms rec_thm)
              |> Thm.close_derivation \<^here>);
      in
        ((js, simps), lthy')
      end;

    (* Per-function "induct" facts, only after an n2m reduction. *)
    val notes =
      (if n2m then
         @{map 3} (fn name => fn qualify => fn thm => (name, qualify, inductN, [thm], induct_attrs))
         fun_names qualifys (take actual_nn inducts)
       else
         [])
      |> map (fn (prefix, qualify, thmN, thms, attrs) =>
        ((qualify (Binding.qualify true prefix (Binding.name thmN)), attrs), [(thms, [])]));

    val common_name = mk_common_name fun_names;
    val common_qualify = fold_rev I qualifys;

    val common_notes =
      (if n2m then [(inductN, [common_induct], [])] else [])
      |> map (fn (thmN, thms, attrs) =>
        ((common_qualify (Binding.qualify true common_name (Binding.name thmN)), attrs),
          [(thms, [])]));
  in
    (((fun_names, qualifys, arg_Ts, defs),
      fn lthy => fn defs =>
        let
          val def_thms = map (snd o snd) defs;
          val ts = map fst defs;
          val phi = Local_Theory.target_morphism lthy;
          val fp_rec_sugar =
            {transfers = transfers, fun_names = fun_names, funs = map (Morphism.term phi) ts,
             fun_defs = Morphism.fact phi def_thms, fpTs = take actual_nn Ts};
        in
          (* Reuse def_thms rather than recomputing the same list from "defs". *)
          map_prod split_list (interpret_lfp_rec_sugar plugins fp_rec_sugar)
            (@{fold_map 2} (prove def_thms) (take actual_nn rec_specs) funs_data lthy)
        end),
      lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
  end;
598
(* Defines the functions, prints the defined constants, and proves the equations. The same
   "nonexhaustive" and "transfer" flags are used for every function. *)
fun primrec_simple0 int plugins nonexhaustive transfer fixes ts lthy =
  let
    val _ = check_duplicate_const_names (map (fst o fst) fixes);

    val nn = length fixes;
    val (((names, qualifys, arg_Ts, defs), prove), lthy1) =
      prepare_primrec plugins (replicate nn nonexhaustive) (replicate nn transfer) fixes ts lthy;

    val (defined, lthy2) = fold_map Local_Theory.define defs lthy1;
    val _ = print_def_consts int defined lthy2;
    val ((jss, simpss), lthy3) = prove lthy2 defined;
  in
    ({prefix = (names, qualifys),
      types = map (fst o dest_Type) arg_Ts,
      result = (map fst defined, map (snd o snd) defined, (jss, simpss))}, lthy3)
  end;
623
(* Programmatic entry point with default options. It falls back to Old_Primrec for
   old-style datatypes, in which case no equation indices are available. *)
fun primrec_simple int fixes ts lthy =
  let
    val (res, lthy') = primrec_simple0 int Plugin_Name.default_filter false false fixes ts lthy;
  in
    ((#prefix res, #result res), lthy')
  end
  handle OLD_PRIMREC () =>
    let
      val ({prefix, result = (old_ts, thms), ...}, lthy') =
        Old_Primrec.primrec_simple int fixes ts lthy;
    in
      ((([prefix], [I]), (old_ts, [], ([], [thms]))), lthy')
    end;
631
(* Common implementation of "primrec" and "primrec_cmd". It:
   - interprets the options (the last plugin filter wins);
   - defines the functions;
   - registers the equations as spec rules;
   - notes a "simps" fact per function plus one fact per equation with the user's attributes;
   - declares code equations if the code plugin is enabled.
   For old-style datatypes it delegates to "old_primrec". *)
fun gen_primrec old_primrec prep_spec int opts raw_fixes raw_specs lthy =
  let
    val plugins =
      (case get_first (fn Plugins_Option f => SOME (f lthy) | _ => NONE) (rev opts) of
        SOME filter => filter
      | NONE => Plugin_Name.default_filter);
    (* Total predicates instead of catching Match from partial functions via "can". *)
    val nonexhaustive = exists (fn Nonexhaustive_Option => true | _ => false) opts;
    val transfer = exists (fn Transfer_Option => true | _ => false) opts;

    val ((fixes, specs), _) = prep_spec raw_fixes raw_specs lthy;
    val spec_name = Binding.conglomerate (map (fst o fst) fixes);

    (* Notes for one function: its "simps" fact followed by one fact per user equation. *)
    fun notes_for_fun js prefix qualify thms =
      let
        val (bs, attrss) = map_split (fst o nth specs) js;
        val per_eqn_notes =
          @{map 3} (fn b => fn attrs => fn thm =>
              ((Binding.qualify false prefix b, nitpicksimp_simp_attrs @ attrs), [([thm], [])]))
            bs attrss thms;
        val simps_binding = qualify (Binding.qualify true prefix (Binding.name simpsN));
      in
        ((simps_binding, []), [(thms, [])]) :: per_eqn_notes
      end;

    fun mk_notes jss names qualifys simpss =
      flat (@{map 4} notes_for_fun jss names qualifys simpss);

    fun finish {prefix = (names, qualifys), types, result = (ts, defs, (jss, simpss))} =
      Spec_Rules.add spec_name (Spec_Rules.equational_primrec types) ts (flat simpss)
      #> Local_Theory.notes (mk_notes jss names qualifys simpss)
      #-> (fn notes =>
        plugins code_plugin ? Code.declare_default_eqns (map (rpair true) (maps snd notes))
        #> pair (ts, defs, map_filter (fn ("", _) => NONE | (_, thms) => SOME thms) notes));
  in
    lthy
    |> primrec_simple0 int plugins nonexhaustive transfer fixes (map snd specs)
    |-> finish
  end
  handle OLD_PRIMREC () =>
    old_primrec int raw_fixes raw_specs lthy
    |>> (fn {result = (ts, thms), ...} => (ts, [], [thms]));
667
(* "primrec" on checked (internal) and read (outer-syntax) specifications. Each one falls back
   on the corresponding Old_Primrec function. *)
fun primrec int opts fixes specs =
  gen_primrec Old_Primrec.primrec Specification.check_multi_specs int opts fixes specs;

fun primrec_cmd int opts fixes specs =
  gen_primrec Old_Primrec.primrec_cmd Specification.read_multi_specs int opts fixes specs;

(* Theory-level variants: enter a local theory (plain or overloading target), run "primrec",
   and exit back to the global theory. *)
fun primrec_global int opts fixes specs thy =
  thy
  |> Named_Target.theory_init
  |> primrec int opts fixes specs
  ||> Local_Theory.exit_global;

fun primrec_overloaded int opts ops fixes specs thy =
  thy
  |> Overloading.overloading ops
  |> primrec int opts fixes specs
  ||> Local_Theory.exit_global;
680
(* One "primrec" option: a plugin filter, "nonexhaustive", or "transfer". *)
val rec_option_parser =
  let
    val plugins = Plugin_Name.parse_filter >> Plugins_Option;
    val nonexhaustive = Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option;
    val transfer = Parse.reserved "transfer" >> K Transfer_Option;
  in
    Parse.group (K "option") (plugins || nonexhaustive || transfer)
  end;

(* Outer syntax of "primrec": an optional parenthesized option list, then a specification. *)
val _ =
  let
    val options =
      Scan.optional
        (\<^keyword>\<open>(\<close> |-- Parse.!!! (Parse.list1 rec_option_parser)
          --| \<^keyword>\<open>)\<close>)
        [];
    fun run_primrec (opts, (fixes, specs)) = snd o primrec_cmd true opts fixes specs;
  in
    Outer_Syntax.local_theory \<^command_keyword>\<open>primrec\<close>
      "define primitive recursive functions"
      (options -- Parse_Spec.specification >> run_primrec)
  end;
691
692end;
693