1(*  Title:      HOL/Tools/Function/mutual.ML
2    Author:     Alexander Krauss, TU Muenchen
3
Mutually recursive function definitions.
5*)
6
signature FUNCTION_MUTUAL =
sig
  (* Prepares a (possibly mutually recursive) function definition.

     Arguments, in order:
       - the function package configuration,
       - the binding used as definition name (defname),
       - the fixes: one ((binding, type), mixfix) entry per function,
       - the defining equations,
       - the local theory to work in.

     Result: the goal state paired with the proof continuation, together
     with the updated local theory. *)
  val prepare_function_mutual:
    Function_Common.function_config
    -> binding
    -> ((binding * typ) * mixfix) list
    -> term list
    -> local_theory
    -> (thm * (Proof.context -> thm -> Function_Common.function_result)) * local_theory
end
18
19structure Function_Mutual: FUNCTION_MUTUAL =
20struct
21
22open Function_Lib
23open Function_Common
24
(* One analysed equation.  Where such tuples are taken apart in this file
   the components are called f, qs, gs, args and rhs, in this order. *)
type qgar =
  string *                (* f: name of the function the equation belongs to *)
  (string * typ) list *   (* qs *)
  term list *             (* gs *)
  term list *             (* args *)
  term                    (* rhs *)
26
(* Data kept for one function of the mutual block. *)
datatype mutual_part = MutualPart of {
  (* position of the function, counted from 1; used as injection index into ST *)
  i: int,
  (* position (from 1) of its result type among the distinct result types;
     used as projection index out of RST *)
  i': int,
  (* binding and type of the function *)
  fname: binding,
  fT: typ,
  (* types of the curried arguments *)
  cargTs: typ list,
  (* definition body, abstracted over the sum function (see analyze_eqs) *)
  f_def: term,

  (* both NONE after analyze_eqs; filled in by define_projections *)
  f: term option,
  f_defthm: thm option
}
37
(* Data describing the whole mutual block, as produced by analyze_eqs. *)
datatype mutual_info = Mutual of {
  (* number of functions / number of distinct result types *)
  n: int,
  n': int,
  (* binding and type (ST --> RST) of the sum function *)
  fsum_name: binding,
  fsum_type: typ,

  (* sum type built from the argument types / from the distinct result types *)
  ST: typ,
  RST: typ,

  (* one entry per function *)
  parts: mutual_part list,
  (* the equations as split by split_def *)
  fqgars: qgar list,
  (* the equations restated for the sum function: (qs, gs, lhs, rhs) *)
  qglrs: ((string * typ) list * term list * term * term) list,

  (* NONE after analyze_eqs; set by define_projections *)
  fsum: term option
}
52
(* Names for the n induction predicates: a prefix of P, Q, R, S when n is
   below 5, otherwise the numbered names P1 ... Pn. *)
fun mutual_induct_Pnames n =
  let
    val letters = ["P", "Q", "R", "S"]
    fun numbered i = "P" ^ string_of_int i
  in
    if n >= 5 then map numbered (1 upto n)
    else #1 (chop n letters)
  end
56
(* Looks up the part whose binding has the name f.  The lookup goes through
   "the", so an unknown name is not handled gracefully. *)
fun get_part f =
  let
    fun is_named (MutualPart {fname, ...}) = (Binding.name_of fname = f)
  in
    fn parts => the (find_first is_named parts)
  end
59
60(* FIXME *)
(* Builds the pair (t1, t2).  The component types are computed with
   fastype_of1, using the reversed second components of e as the types of
   loose bound variables. *)
fun mk_prod_abs e (t1, t2) =
  let
    val bound_Ts = rev (map snd e)
    fun typ_of t = fastype_of1 (bound_Ts, t)
    val left_T = typ_of t1
    val right_T = typ_of t2
    val pair = HOLogic.pair_const left_T right_T
  in
    pair $ t1 $ t2
  end
69
(* Analyses the equations eqs of the functions fs (a list of binding/type
   pairs) and sets up the data for a single sum function:
     - every equation is split with split_def,
     - the argument types of each function are tupled and combined into the
       sum type ST, the distinct result types into the sum type RST,
     - every function is expressed via the sum function (field f_def),
     - every equation is restated over the sum types (field qglrs).
   The fields f, f_defthm and fsum of the result are still NONE. *)
fun analyze_eqs ctxt defname fs eqs =
  let
    val num = length fs
    val fqgars = eqs |> map (split_def ctxt (K true))

    (*number of arguments in the first equation found for the function*)
    fun arity_of fname =
      let
        val name = Binding.name_of fname
        fun args_count (f, _, _, args, _) =
          if f = name then SOME (length args) else NONE
      in
        the (get_first args_count fqgars)
      end

    (*curried argument types, and the type that remains after them*)
    fun curried_types (fname, fT) =
      binder_types fT
      |> chop (arity_of fname)
      ||> (fn uaTs => uaTs ---> body_type fT)

    val (caTss, resultTs) = fs |> map curried_types |> split_list
    val argTs = caTss |> map (foldr1 HOLogic.mk_prodT)

    val dresultTs = distinct (op =) resultTs
    val n' = length dresultTs

    val mk_sum_tree = Balanced_Tree.make (uncurry Sum_Tree.mk_sumT)
    val RST = mk_sum_tree dresultTs
    val ST = mk_sum_tree argTs

    val fsum_name = derived_name_suffix defname "_sum"
    val ([fsum_var_name], _) = ctxt |> Variable.add_fixes_binding [fsum_name]
    val fsum_type = ST --> RST
    val fsum_var = (fsum_var_name, fsum_type)
    val fsum_free = Free fsum_var

    (*the i-th function in terms of the sum function, plus the matching
      rewrite (function name, lambda term) used on right-hand sides*)
    fun define (fname, fT) caTs resultT i =
      let
        (* FIXME: Bind xs properly *)
        val xs = caTs |> map_index (fn (j, T) => Free ("x" ^ string_of_int j, T))
        val i' = find_index (fn T => T = resultT) dresultTs + 1

        val arg_tuple = foldr1 HOLogic.mk_prod xs
        val injected = Sum_Tree.mk_inj ST num i arg_tuple
        val projected = Sum_Tree.mk_proj RST n' i' (fsum_free $ injected)
        val abstracted = fold_rev lambda xs projected

        val part =
          MutualPart
            {i = i, i' = i', fname = fname, fT = fT, cargTs = caTs,
              f_def = Term.abstract_over (fsum_free, abstracted),
              f = NONE, f_defthm = NONE}
      in
        (part, (Binding.name_of fname, abstracted))
      end

    val (parts, rews) =
      @{map 4} define fs caTss resultTs (1 upto num)
      |> split_list

    (*replace a free variable that names one of the functions*)
    fun unfold_fun (t as Free (x, _)) = the_default t (AList.lookup (op =) rews x)
      | unfold_fun t = t

    (*restate one equation over the sum types*)
    fun convert_eqs (f, qs, gs, args, rhs) =
      let
        val MutualPart {i, i', ...} = get_part f parts
        val lhs_sum = Sum_Tree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args)
        val rhs_sum =
          Envir.beta_norm (Sum_Tree.mk_inj RST n' i' (map_aterms unfold_fun rhs))
      in
        (qs, gs, lhs_sum, rhs_sum)
      end

    val qglrs = map convert_eqs fqgars
  in
    Mutual
      {n = num, n' = n',
        fsum_name = fsum_name, fsum_type = fsum_type,
        ST = ST, RST = RST,
        parts = parts, fqgars = fqgars, qglrs = qglrs,
        fsum = NONE}
  end
128
(* Defines every individual function via Local_Theory.define, using its
   f_def with the sum function fsum substituted for the bound variable.
   Returns the mutual info with f, f_defthm and fsum filled in, together
   with the resulting local theory. *)
fun define_projections fixes mutual fsum lthy =
  let
    val Mutual {n, n', fsum_name, fsum_type, ST, RST, parts, fqgars, qglrs, ...} = mutual

    (*define one part; the mixfix comes from the corresponding fix*)
    fun define_part (MutualPart {i, i', fname, fT, cargTs, f_def, ...}, (_, mixfix)) ctxt =
      let
        val internal = Config.get ctxt function_internals
        val def_binding = Thm.make_def_binding internal fname
        val spec = ((def_binding, []), Term.subst_bound (fsum, f_def))
      in
        ctxt
        |> Local_Theory.define ((fname, mixfix), spec)
        |>> (fn (f, (_, f_defthm)) =>
              MutualPart
                {i = i, i' = i', fname = fname, fT = fT, cargTs = cargTs,
                  f_def = f_def, f = SOME f, f_defthm = SOME f_defthm})
      end

    fun rebuild parts' =
      Mutual
        {n = n, n' = n', fsum_name = fsum_name, fsum_type = fsum_type,
          ST = ST, RST = RST, parts = parts', fqgars = fqgars, qglrs = qglrs,
          fsum = SOME fsum}
  in
    lthy
    |> fold_map define_part (parts ~~ fixes)
    |>> rebuild
  end
148
(* Runs F on an equation whose bound variables have been replaced by free
   variables with variant names.  F also receives
     import: instantiates the quantifiers with these frees and discharges
             the premises using the assumed conditions,
     export: the reverse direction (implies_intr for the conditions, then
             forall_intr_rename with the original variable names).
   Note that F is given the original ctxt. *)
fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
  let
    val (oqnames, qTs) = split_list pre_qs
    val (fresh_names, _) = Variable.variant_fixes oqnames ctxt
    val qs = map2 (fn x => fn T => Free (x, T)) fresh_names qTs

    val rev_qs = rev qs
    fun inst t = subst_bounds (rev_qs, t)
    val gs = map inst pre_gs
    val args = map inst pre_args
    val rhs = inst pre_rhs

    val cert = Thm.cterm_of ctxt
    val cqs = map cert qs
    val ags = map (Thm.assume o cert) gs
    val named_cqs = oqnames ~~ cqs

    fun import th =
      th
      |> fold Thm.forall_elim cqs
      |> fold Thm.elim_implies ags

    fun export th =
      th
      |> fold_rev (Thm.implies_intr o Thm.cprop_of) ags
      |> fold_rev forall_intr_rename named_cqs
  in
    F ctxt (f, qs, gs, args, rhs) import export
  end
171
(* Derives the equation "f args = rhs" for the individual function named
   fname from the corresponding equation sum_psimp_eq of the sum function.
   The imported equation may carry at most one premise; it is assumed for
   the proof and reintroduced afterwards.  More premises raise
   Fail "Too many conditions". *)
fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
    import (export : thm -> thm) sum_psimp_eq =
  let
    val MutualPart {f = SOME f, ...} = get_part fname parts

    (*equation without premise, and how to put the premise back*)
    fun strip_cond th =
      (case cprems_of th of
        [] => (th, I)
      | [cond] => (Thm.implies_elim th (Thm.assume cond), Thm.implies_intr cond)
      | _ => raise General.Fail "Too many conditions")

    val (simp, restore_cond) = strip_cond (import sum_psimp_eq)

    val simp_ctxt = ctxt |> fold Thm.declare_hyps (Thm.chyps_of simp)
    val goal = HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs)

    (*unfold the definitions, rewrite once with simp, finish with the simplifier*)
    fun tac _ =
      Local_Defs.unfold0_tac ctxt all_orig_fdefs THEN
      EqSubst.eqsubst_tac ctxt [0] [simp] 1 THEN
      simp_tac ctxt 1
  in
    export (restore_cond (Goal.prove simp_ctxt [] [] goal tac))
  end
195
(* Applies both sides of the equation thm to the variables x0, x1, ... (one
   per type in caTs), beta-normalizes, and then turns these variables into
   schematic ones (forall_intr followed by forall_elim_vars). *)
fun mk_applied_form ctxt caTs thm =
  let
    fun arg_var (i, T) = Thm.cterm_of ctxt (Free ("x" ^ string_of_int i, T))
    val xs = map_index arg_var caTs (* FIXME: Bind xs properly *)

    fun apply_to x eq = Thm.combination eq (Thm.reflexive x)
  in
    thm
    |> fold apply_to xs
    |> Conv.fconv_rule (Thm.beta_conversion true)
    |> fold_rev Thm.forall_intr xs
    |> Thm.forall_elim_vars 0
  end
208
(* Produces one induction rule per part from the induction rule of the sum
   function: the predicate is instantiated with a sum-case expression over
   fresh predicates (named by mutual_induct_Pnames), the result is
   simplified, and then specialised to the injection of each part. *)
fun mutual_induct_rules ctxt induct all_f_defs (Mutual {n, ST, parts, ...}) =
  let
    val cert = Thm.cterm_of ctxt
    val sumcase_ctxt = put_simpset Sum_Tree.sumcase_split_ss ctxt

    (*one free predicate variable per part*)
    fun pred_free Pname (MutualPart {cargTs, ...}) =
      Free (Pname, cargTs ---> HOLogic.boolT)
    val newPs = map2 pred_free (mutual_induct_Pnames (length parts)) parts

    (*the predicate P as a function on the tuple of arguments*)
    fun mk_P (MutualPart {cargTs, ...}) P =
      let
        val avars = cargTs |> map_index (fn (i, T) => Var (("a", i), T))
      in
        HOLogic.tupled_lambda (foldr1 HOLogic.mk_prod avars) (list_comb (P, avars))
      end

    val case_exp = Sum_Tree.mk_sumcases HOLogic.boolT (map2 mk_P parts newPs)

    val induct_inst =
      induct
      |> Thm.forall_elim (cert case_exp)
      |> full_simplify sumcase_ctxt
      |> full_simplify (put_simpset HOL_basic_ss ctxt addsimps all_f_defs)

    (*rule for a single part; k is the running offset for variable names*)
    fun project rule (MutualPart {cargTs, i, ...}) k =
      let
        fun arg_free (j, T) = Free ("a" ^ string_of_int (j + k), T)
        val afs = map_index arg_free cargTs (* FIXME! *)
        val inj = Sum_Tree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
        val rule' =
          rule
          |> Thm.forall_elim (cert inj)
          |> full_simplify sumcase_ctxt
          |> fold_rev (Thm.forall_intr o cert) (afs @ newPs)
      in
        (rule', k + length cargTs)
      end
  in
    #1 (fold_map (project induct_inst) parts 0)
  end
246
(* Specialises the cases rule of the sum function to one part: it is
   instantiated with a fresh P and the injection of a fresh tuple variable
   x, every subgoal is processed by elimination with the injectivity and
   distinctness rules listed below, and x and P are quantified again. *)
fun mutual_cases_rule ctxt cases_rule n ST (MutualPart {i, cargTs = Ts, ...}) =
  let
    val cert = Thm.cterm_of ctxt

    val [P, x] =
      [("P", \<^typ>\<open>bool\<close>), ("x", HOLogic.mk_tupleT Ts)]
      |> Variable.variant_frees ctxt []
      |> map (cert o Free)
    val sumtree_inj = cert (Sum_Tree.mk_inj ST n i (Thm.term_of x))

    val inject_rules =
      @{thms Pair_inject Inl_inject [elim_format] Inr_inject [elim_format]}
    val distinct_rules =
      @{thms HOL.notE [OF Sum_Type.sum.distinct(1)] HOL.notE [OF Sum_Type.sum.distinct(2)]}

    fun prep_subgoal_tac j =
      REPEAT (eresolve_tac ctxt inject_rules j) THEN
      REPEAT (eresolve_tac ctxt distinct_rules j)
  in
    cases_rule
    |> fold Thm.forall_elim [P, sumtree_inj]
    |> Tactic.rule_by_tactic ctxt (ALLGOALS prep_subgoal_tac)
    |> fold Thm.forall_intr [x, P]
  end
267
268
(* Turns the result obtained for the sum function (inner_cont proof) into
   the result for the individual functions: psimps, induction and cases
   rules are recovered per function, termination and domintros are
   rewritten with the applied definitions; G, R, dom and pelims are passed
   on unchanged.  The inner result must contain exactly one cases rule and
   exactly one simple_pinduct rule. *)
fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, n, ST, ...}) proof =
  let
    val FunctionResult {G, R, cases = [cases_rule], psimps,
      simple_pinducts = [simple_pinduct], termination, domintros, dom, pelims, ...} =
      inner_cont proof

    (*symmetric definition in applied form, paired with the function itself*)
    fun applied_def (MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...}) =
      (mk_applied_form lthy cargTs (Thm.symmetric f_def), f)
    val (all_f_defs, fs) = parts |> map applied_def |> split_list

    fun orig_def (MutualPart {f_defthm = SOME f_def, ...}) = f_def
    val all_orig_fdefs = map orig_def parts

    fun mk_mpsimp fqgar sum_psimp =
      in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp

    val rew_ctxt = put_simpset HOL_basic_ss lthy addsimps all_f_defs
    val rewrite = full_simplify rew_ctxt

    val mpsimps = map2 mk_mpsimp fqgars psimps
    val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
    val mcases = parts |> map (mutual_cases_rule lthy cases_rule n ST)
    val mtermination = rewrite termination
    val mdomintros = domintros |> Option.map (map rewrite)
  in
    FunctionResult
      {fs = fs, G = G, R = R, dom = dom,
        psimps = mpsimps,
        simple_pinducts = minducts,
        cases = mcases,
        pelims = pelims,
        termination = mtermination,
        domintros = mdomintros}
  end
300
301
(* Entry point: analyses the (beta-eta-contracted) equations, hands the
   resulting sum function to Function_Core.prepare_function, defines the
   individual functions, and wraps the continuation so that its result is
   converted by mk_partial_rules_mutual. *)
fun prepare_function_mutual config defname fixes eqss lthy =
  let
    val bindings = map fst fixes
    val contracted_eqs = map Envir.beta_eta_contract eqss
    val mutual as Mutual {fsum_name, fsum_type, qglrs, ...} =
      analyze_eqs lthy defname bindings contracted_eqs

    (*the sum function itself is fixed without syntax*)
    val sum_fixes = [((fsum_name, fsum_type), NoSyn)]
    val ((fsum, goalstate, cont), lthy1) =
      lthy |> Function_Core.prepare_function config defname sum_fixes qglrs

    val (mutual', lthy2) = lthy1 |> define_projections fixes mutual fsum

    fun cont' ctxt = mk_partial_rules_mutual lthy2 (cont ctxt) mutual'
  in
    ((goalstate, cont'), lthy2)
  end
316
317end
318