1(*  Title:      HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML
2    Author:     Lorenz Panny, TU Muenchen
3    Author:     Jasmin Blanchette, TU Muenchen
4    Copyright   2013
5
6Library for recursor and corecursor sugar.
7*)
8
signature BNF_FP_REC_SUGAR_UTIL =
sig
  (* Kinds of fixpoints and the record describing (co)recursive function definitions. *)
  datatype fp_kind = Least_FP | Greatest_FP

  val case_fp: fp_kind -> 'a -> 'a -> 'a

  type fp_rec_sugar =
    {transfers: bool list,
     fun_names: string list,
     funs: term list,
     fun_defs: thm list,
     fpTs: typ list}

  val morph_fp_rec_sugar: morphism -> fp_rec_sugar -> fp_rec_sugar
  val transfer_fp_rec_sugar: theory -> fp_rec_sugar -> fp_rec_sugar

  (* Generic diagnostics, optionally followed by the offending terms. *)
  val error_at: Proof.context -> term list -> string -> 'a
  val warning_at: Proof.context -> term list -> string -> unit

  (* Specific error messages. *)
  val excess_equations: Proof.context -> term list -> 'a
  val extra_variable_in_rhs: Proof.context -> term list -> term -> 'a
  val ill_formed_corec_call: Proof.context -> term -> 'a
  val ill_formed_equation_head: Proof.context -> term list -> 'a
  val ill_formed_equation_lhs_rhs: Proof.context -> term list -> 'a
  val ill_formed_equation: Proof.context -> term -> 'a
  val ill_formed_formula: Proof.context -> term -> 'a
  val ill_formed_rec_call: Proof.context -> term -> 'a
  val inconstant_pattern_pos_for_fun: Proof.context -> term list -> string -> 'a
  val invalid_map: Proof.context -> term list -> term -> 'a
  val missing_args_to_fun_on_lhs: Proof.context -> term list -> 'a
  val missing_equations_for_const: string -> 'a
  val missing_equations_for_fun: string -> 'a
  val missing_pattern: Proof.context -> term list -> 'a
  val more_than_one_nonvar_in_lhs: Proof.context -> term list -> 'a
  val multiple_equations_for_ctr: Proof.context -> term list -> 'a
  val nonprimitive_corec: Proof.context -> term list -> 'a
  val nonprimitive_pattern_in_lhs: Proof.context -> term list -> 'a
  val not_codatatype: Proof.context -> typ -> 'a
  val not_datatype: Proof.context -> typ -> 'a
  val not_constructor_in_pattern: Proof.context -> term list -> term -> 'a
  val not_constructor_in_rhs: Proof.context -> term list -> term -> 'a
  val rec_call_not_apply_to_ctr_arg: Proof.context -> term list -> term -> 'a
  val partially_applied_ctr_in_pattern: Proof.context -> term list -> 'a
  val partially_applied_ctr_in_rhs: Proof.context -> term list -> 'a
  val too_few_args_in_rec_call: Proof.context -> term list -> term -> 'a
  val unexpected_rec_call_in: Proof.context -> term list -> term -> 'a
  val unexpected_corec_call_in: Proof.context -> term list -> term -> 'a
  val unsupported_case_around_corec_call: Proof.context -> term list -> term -> 'a

  (* Specific warnings. *)
  val no_equation_for_ctr_warning: Proof.context -> term list -> term -> unit

  (* Well-formedness checks; each either returns () or raises an error. *)
  val check_all_fun_arg_frees: Proof.context -> term list -> term list -> unit
  val check_duplicate_const_names: binding list -> unit
  val check_duplicate_variables_in_lhs: Proof.context -> term list -> term list -> unit
  val check_top_sort: Proof.context -> binding -> typ -> unit

  (* List utilities. *)
  val flat_rec_arg_args: 'a list list -> 'a list
  val indexed: 'a list -> int -> int list * int
  val indexedd: 'a list list -> int -> int list list * int
  val indexeddd: 'a list list list -> int -> int list list list * int
  val indexedddd: 'a list list list list -> int -> int list list list list * int
  val find_index_eq: ''a list -> ''a -> int
  val finds: ('a * 'b -> bool) -> 'a list -> 'b list -> ('a * 'b list) list * 'b list
  val find_indices: ('b * 'a -> bool) -> 'a list -> 'b list -> int list
  val order_strong_conn: ('a * 'a -> bool) -> ((('a * unit) * 'a list) list -> 'b) ->
    ('b -> 'a list) -> ('a * 'a list) list -> 'a list list -> 'a list list
  val mk_common_name: string list -> string

  (* Type utilities. *)
  val num_binder_types: typ -> int
  val exists_subtype_in: typ list -> typ -> bool
  val exists_strict_subtype_in: typ list -> typ -> bool
  val tvar_subst: theory -> typ list -> typ list -> ((string * int) * typ) list

  (* Term utilities. *)
  val retype_const_or_free: typ -> term -> term
  val drop_all: term -> term
  val permute_args: int -> term -> term
  val mk_partial_compN: int -> typ -> term -> term
  val mk_compN: int -> typ list -> term * term -> term
  val mk_comp: typ list -> term * term -> term
  val mk_co_rec: theory -> fp_kind -> typ list -> typ -> term -> term

  (* Theorem utilities. *)
  val mk_conjunctN: int -> int -> thm
  val conj_dests: int -> thm -> thm list

  (* Output. *)
  val print_def_consts: bool -> (term * (string * thm)) list -> Proof.context -> unit
end;
100
structure BNF_FP_Rec_Sugar_Util : BNF_FP_REC_SUGAR_UTIL =
struct

(** Diagnostics **)

(* Renders the terms "ats" one per line, each indented by two spaces. Shared by every message
   below that lists terms, so they are all formatted the same way. Not exported. *)
fun string_of_indented_terms ctxt ats =
  cat_lines (map (prefix "  " o Syntax.string_of_term ctxt) ats);

(* Location suffix " at\n<terms>" appended to a diagnostic; empty if there are no terms. *)
fun at_suffix ctxt ats =
  if null ats then "" else " at\n" ^ string_of_indented_terms ctxt ats;

(* Raise an error (resp. emit a warning) with message "str" followed by the terms "ats". *)
fun error_at ctxt ats str = error (str ^ at_suffix ctxt ats);
fun warning_at ctxt ats str = warning (str ^ at_suffix ctxt ats);

fun excess_equations ctxt ats =
  error ("Excess equation(s):\n" ^ string_of_indented_terms ctxt ats);
fun extra_variable_in_rhs ctxt ats var =
  error_at ctxt ats ("Extra variable " ^ quote (Syntax.string_of_term ctxt var) ^
    " in right-hand side");
fun ill_formed_corec_call ctxt t =
  error ("Ill-formed corecursive call " ^ quote (Syntax.string_of_term ctxt t));
fun ill_formed_equation_head ctxt ats =
  error_at ctxt ats "Ill-formed function equation (expected function name on left-hand side)";
fun ill_formed_equation_lhs_rhs ctxt ats =
  error_at ctxt ats "Ill-formed equation (expected \"lhs = rhs\")";
fun ill_formed_equation ctxt t =
  error_at ctxt [] ("Ill-formed equation:\n  " ^ Syntax.string_of_term ctxt t);
fun ill_formed_formula ctxt t =
  error_at ctxt [] ("Ill-formed formula:\n  " ^ Syntax.string_of_term ctxt t);
fun ill_formed_rec_call ctxt t =
  error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
fun inconstant_pattern_pos_for_fun ctxt ats fun_name =
  error_at ctxt ats ("Inconstant constructor pattern position for function " ^ quote fun_name);
fun invalid_map ctxt ats t =
  error_at ctxt ats ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
fun missing_args_to_fun_on_lhs ctxt ats =
  error_at ctxt ats "Expected more arguments to function on left-hand side";
fun missing_equations_for_const fun_name =
  error ("Missing equations for constant " ^ quote fun_name);
fun missing_equations_for_fun fun_name =
  error ("Missing equations for function " ^ quote fun_name);
fun missing_pattern ctxt ats =
  error_at ctxt ats "Constructor pattern missing in left-hand side";
fun more_than_one_nonvar_in_lhs ctxt ats =
  error_at ctxt ats "More than one non-variable argument in left-hand side";
fun multiple_equations_for_ctr ctxt ats =
  error ("Multiple equations for constructor:\n" ^ string_of_indented_terms ctxt ats);
fun nonprimitive_corec ctxt ats =
  error_at ctxt ats "Nonprimitive corecursive specification";
fun nonprimitive_pattern_in_lhs ctxt ats =
  error_at ctxt ats "Nonprimitive pattern in left-hand side";
fun not_codatatype ctxt T =
  error ("Not a codatatype: " ^ Syntax.string_of_typ ctxt T);
fun not_datatype ctxt T =
  error ("Not a datatype: " ^ Syntax.string_of_typ ctxt T);
fun not_constructor_in_pattern ctxt ats t =
  error_at ctxt ats ("Not a constructor " ^ quote (Syntax.string_of_term ctxt t) ^
    " in pattern");
fun not_constructor_in_rhs ctxt ats t =
  error_at ctxt ats ("Not a constructor " ^ quote (Syntax.string_of_term ctxt t) ^
    " in right-hand side");
fun rec_call_not_apply_to_ctr_arg ctxt ats t =
  error_at ctxt ats ("Recursive call not directly applied to constructor argument in " ^
    quote (Syntax.string_of_term ctxt t));
fun partially_applied_ctr_in_pattern ctxt ats =
  error_at ctxt ats "Partially applied constructor in pattern";
fun partially_applied_ctr_in_rhs ctxt ats =
  error_at ctxt ats "Partially applied constructor in right-hand side";
fun too_few_args_in_rec_call ctxt ats t =
  error_at ctxt ats ("Too few arguments in recursive call " ^ quote (Syntax.string_of_term ctxt t));
fun unexpected_rec_call_in ctxt ats t =
  error_at ctxt ats ("Unexpected recursive call in " ^ quote (Syntax.string_of_term ctxt t));
fun unexpected_corec_call_in ctxt ats t =
  error_at ctxt ats ("Unexpected corecursive call in " ^ quote (Syntax.string_of_term ctxt t));
fun unsupported_case_around_corec_call ctxt ats t =
  error_at ctxt ats ("Unsupported corecursive call under case expression " ^
    quote (Syntax.string_of_term ctxt t) ^
    "\n(Define datatype with discriminators and selectors to circumvent this limitation)");

fun no_equation_for_ctr_warning ctxt ats ctr =
  warning_at ctxt ats ("No equation for constructor " ^ quote (Syntax.string_of_term ctxt ctr));


(** Well-formedness checks (each returns () or raises an error) **)

(* Every function argument on the left-hand side must be a Free that is not fixed in "ctxt". *)
fun check_all_fun_arg_frees ctxt ats fun_args =
  (case find_first (not o is_Free) fun_args of
    SOME t => error_at ctxt ats ("Non-variable function argument on left-hand side " ^
      quote (Syntax.string_of_term ctxt t))
  | NONE =>
    (case find_first (Variable.is_fixed ctxt o fst o dest_Free) fun_args of
      SOME t => error_at ctxt ats ("Function argument " ^
        quote (Syntax.string_of_term ctxt t) ^ " is fixed in context")
    | NONE => ()));

(* Reports the first binding name that occurs more than once. *)
fun check_duplicate_const_names bs =
  let val dups = duplicates (op =) (map Binding.name_of bs) in
    ignore (null dups orelse error ("Duplicate constant name " ^ quote (hd dups)))
  end;

(* Reports the first variable (up to aconv) that occurs more than once in "vars". *)
fun check_duplicate_variables_in_lhs ctxt ats vars =
  let val dups = duplicates (op aconv) vars in
    ignore (null dups orelse
      error_at ctxt ats ("Duplicate variable " ^ quote (Syntax.string_of_term ctxt (hd dups)) ^
        " in left-hand side"))
  end;

(* Fails unless "T" belongs to sort "type" in the context's theory. *)
fun check_top_sort ctxt b T =
  ignore (Sign.of_sort (Proof_Context.theory_of ctxt) (T, \<^sort>\<open>type\<close>) orelse
    error ("Type of " ^ Binding.print b ^ " contains top sort"));


(** Fixpoint kinds and (co)recursion sugar records **)

datatype fp_kind = Least_FP | Greatest_FP;

(* Selects "l" for least and "g" for greatest fixpoints. *)
fun case_fp Least_FP l _ = l
  | case_fp Greatest_FP _ g = g;

type fp_rec_sugar =
  {transfers: bool list,
   fun_names: string list,
   funs: term list,
   fun_defs: thm list,
   fpTs: typ list};

(* Applies "phi" to the terms, theorems and types of the record; "transfers" and "fun_names"
   are left unchanged. *)
fun morph_fp_rec_sugar phi ({transfers, fun_names, funs, fun_defs, fpTs} : fp_rec_sugar) =
  {transfers = transfers,
   fun_names = fun_names,
   funs = map (Morphism.term phi) funs,
   fun_defs = map (Morphism.thm phi) fun_defs,
   fpTs = map (Morphism.typ phi) fpTs};

val transfer_fp_rec_sugar = morph_fp_rec_sugar o Morphism.transfer_morphism;


(** List utilities **)

fun flat_rec_arg_args xss =
  (* FIXME (once the old datatype package is completely phased out): The first line below gives the
     preferred order. The second line is for compatibility with the old datatype package. *)
  (* flat xss *)
  map hd xss @ maps tl xss;

(* Replace each element by consecutive integers starting at the given counter; the next unused
   counter is returned alongside. The "d" variants do the same for nested lists. *)
fun indexe _ h = (h, h + 1);
fun indexed xs = fold_map indexe xs;
fun indexedd xss = fold_map indexed xss;
fun indexeddd xsss = fold_map indexedd xsss;
fun indexedddd xssss = fold_map indexeddd xssss;

fun find_index_eq hs h = find_index (curry (op =) h) hs;

(* For each "x" in turn, takes from the remaining "ys" all "y" with "eq (x, y)"; also returns
   the "ys" that matched no "x". *)
fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);

(* Positions in "ys" of the elements that are members of "xs" modulo "eq". *)
fun find_indices eq xs ys =
  map_filter I (map_index (fn (i, y) => if member eq xs y then SOME i else NONE) ys);

(* Orders the strongly connected components "sccs" along the dependency relation "deps". Each
   SCC is represented by its head element; the other members are mapped to it before the
   quotient graph is built and sorted with "topological_order". *)
fun order_strong_conn eq make_graph topological_order deps sccs =
  let
    (* Pairs (member, head) for every non-head member of every SCC. *)
    val normals = maps (fn [] => [] | x :: xs => map (fn y => (y, x)) xs) sccs;
    fun normal s = AList.lookup eq normals s |> the_default s;

    val normal_deps = deps
      |> map (fn (x, xs) => let val x' = normal x in
          (x', fold (insert eq o normal) xs [] |> remove eq x')
        end)
      |> AList.group eq
      |> map (apsnd (fn xss => fold (union eq) xss []));

    val normal_G = make_graph (map (apfst (rpair ())) normal_deps);
    val ordered_normals = rev (topological_order normal_G);

    (* The SCC whose head is "x". *)
    fun scc_of x = the (find_first (fn [] => false | y :: _ => eq (y, x)) sccs);
  in
    map scc_of ordered_normals
  end;

val mk_common_name = space_implode "_";


(** Type utilities **)

fun num_binder_types (Type (\<^type_name>\<open>fun\<close>, [_, T])) = 1 + num_binder_types T
  | num_binder_types _ = 0;

val exists_subtype_in = Term.exists_subtype o member (op =);
fun exists_strict_subtype_in Ts T = exists_subtype_in (remove (op =) T Ts) T;

(* Type-variable instantiation obtained by matching each pattern in "Ts" against the
   corresponding type in "Us". *)
fun tvar_subst thy Ts Us =
  Vartab.fold (cons o apsnd snd) (fold (Sign.typ_match thy) (Ts ~~ Us) Vartab.empty) [];


(** Term utilities **)

fun retype_const_or_free T (Const (s, _)) = Const (s, T)
  | retype_const_or_free T (Free (s, _)) = Free (s, T)
  | retype_const_or_free _ t = raise TERM ("retype_const_or_free", [t]);

(* Strips the outer Pure.all quantifiers, turning the bound variables into free variables with
   the same names and types. *)
fun drop_all t =
  subst_bounds (strip_qnt_vars \<^const_name>\<open>Pure.all\<close> t |> map Free |> rev,
    strip_qnt_body \<^const_name>\<open>Pure.all\<close> t);

(* Abstracts over n + 1 arguments (typed dummyT) and passes the last one to "t" first, followed
   by the others in their original order. *)
fun permute_args n t =
  list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);

fun mk_partial_comp fT g = fst (Term.dest_comb (HOLogic.mk_comp (g, Free (Name.uu, fT))));

fun mk_partial_compN 0 _ g = g
  | mk_partial_compN n fT g = mk_partial_comp fT (mk_partial_compN (n - 1) (range_type fT) g);

fun mk_compN n bound_Ts (g, f) =
  let val typof = curry fastype_of1 bound_Ts in
    mk_partial_compN n (typof f) g $ f
  end;

val mk_comp = mk_compN 1;

(* Instantiates the type variables of the (co)recursor "t" so that its fixpoint type becomes
   "fpT" and its result types become "Cs". For least fixpoints the fixpoint type is the last
   argument type and the Cs are the body types of the functional arguments; for greatest
   fixpoints it is the result type and the Cs are their domain types. *)
fun mk_co_rec thy fp Cs fpT t =
  let
    val ((f_Cs, prebody), body) = strip_type (fastype_of t) |>> split_last;
    val fpT0 = case_fp fp prebody body;
    val Cs0 = distinct (op =) (map (case_fp fp body_type domain_type) f_Cs);
    val rho = tvar_subst thy (fpT0 :: Cs0) (fpT :: Cs);
  in
    Term.subst_TVars rho t
  end;


(** Theorem utilities **)

(* Rule extracting the m-th conjunct of a right-nested n-fold conjunction. Indices m < 1 are
   rejected; previously they made the recursion diverge (m never reaches the base case 1). *)
fun mk_conjunctN 1 1 = @{thm TrueE[OF TrueI]}
  | mk_conjunctN _ 1 = conjunct1
  | mk_conjunctN 2 2 = conjunct2
  | mk_conjunctN n m =
    if m < 1 then raise Fail ("mk_conjunctN: invalid conjunct index " ^ string_of_int m)
    else conjunct2 RS (mk_conjunctN (n - 1) (m - 1));

(* All n conjuncts of the n-fold conjunction "thm". *)
fun conj_dests n thm = map (fn k => thm RS mk_conjunctN n k) (1 upto n);


(** Output **)

(* Prints those defined constants whose term is a Free variable. *)
fun print_def_consts int defs ctxt =
  Proof_Display.print_consts int (Position.thread_data ()) ctxt (K false)
    (map_filter (try (dest_Free o fst)) defs);

end;
321