1(*  Title:      HOL/Tools/Function/induction_schema.ML
2    Author:     Alexander Krauss, TU Muenchen
3
4A method to prove induction schemas.
5*)
6
signature INDUCTION_SCHEMA =
sig
  (* Generic induction-schema tactic.  The three int -> tactic arguments are,
     in order, the tactics used on the completeness subgoals, on the
     preservation subgoals, and on the wf (termination) subgoal. *)
  val mk_ind_tac :
    (int -> tactic) ->   (* comp_tac *)
    (int -> tactic) ->   (* pres_tac *)
    (int -> tactic) ->   (* term_tac *)
    Proof.context -> thm list -> tactic

  (* mk_ind_tac with fixed default tactics; only the preservation subgoals
     are attempted (by assumption). *)
  val induction_schema_tac : Proof.context -> thm list -> tactic
end
13
14structure Induction_Schema : INDUCTION_SCHEMA =
15struct
16
17open Function_Lib
18
(* One recursive call of a case: index of the branch whose predicate is
   called, the variables bound locally around the call, the local hypotheses
   guarding it, and the arguments of the call. *)
type rec_call_info = int * (string * typ) list * term list * term list

(* One premise ("case") of the induction schema. *)
datatype scheme_case = SchemeCase of {
  bidx : int,                  (* index of the branch whose predicate concludes the case *)
  qs : (string * typ) list,    (* variables fixed by the case *)
  oqnames : string list,       (* names used when re-quantifying qs *)
  gs : term list,              (* guard premises (mention no branch predicate) *)
  lhs : term list,             (* arguments of the predicate in the case conclusion *)
  rs : rec_call_info list      (* recursive calls, i.e. induction hypotheses *)
}

(* One conjunct of the schema's conclusion: "!!ws. Cs ==> P xs". *)
datatype scheme_branch = SchemeBranch of {
  P : term,                    (* the predicate being proved *)
  xs : (string * typ) list,    (* its argument variables *)
  ws : (string * typ) list,    (* additionally fixed variables *)
  Cs : term list               (* premises of the branch *)
}

(* A complete schema: all branches, all cases, and the common argument type. *)
datatype ind_scheme = IndScheme of {
  T : typ,                     (* sum of products *)
  branches : scheme_branch list,
  cases : scheme_case list
}
39
(* Conversions rewriting with the induct_atomize resp. induct_rulify rule
   sets via Raw_Simplifier.rewrite (boolean flag set to true). *)
fun ind_atomize ctxt =
  let val rules = @{thms induct_atomize}
  in Raw_Simplifier.rewrite ctxt true rules end

fun ind_rulify ctxt =
  let val rules = @{thms induct_rulify}
  in Raw_Simplifier.rewrite ctxt true rules end

(* Resolve a theorem with eq_reflection, turning an HOL equation into a
   meta-equation usable as a rewrite rule. *)
fun meta th = op RS (th, eq_reflection)

(* Conversion rewriting with split_conv and the sum.case equations (as
   meta-equations); used to simplify sum-case / tupled-lambda applications. *)
fun sum_prod_conv ctxt =
  let
    val eqns = map meta (@{thm split_conv} :: @{thms sum.case})
  in
    Raw_Simplifier.rewrite ctxt true eqns
  end
47
(* Run the conversion cv on the term t (certified in ctxt) and return the
   right-hand side of the resulting meta-equation as a plain term. *)
fun term_conv ctxt cv t =
  let
    val eqn = Thm.prop_of (cv (Thm.cterm_of ctxt t))
    val (_, rhs) = Logic.dest_equals eqn
  in
    rhs
  end
51
(* Type of binary relations over T, represented as a set of pairs. *)
fun mk_relT T = (T, T) |> HOLogic.mk_prodT |> HOLogic.mk_setT
53
(* Open a goal term of the shape "!!params. prems ==> concl".  The params are
   fixed via Variable.focus.  Returns the extended context, the fixed params
   (second components of the focus result), the premises and the
   conclusion. *)
fun dest_hhf ctxt t =
  let
    val ((fixes, body), ctxt') = Variable.focus NONE t ctxt
    val params = map #2 fixes
    val prems = Logic.strip_imp_prems body
    val concl = Logic.strip_imp_concl body
  in
    (ctxt', params, prems, concl)
  end
60
(* Read an induction schema from the premises (cases) and the conclusion
   (concl) of a goal.  Each conjunct of concl yields one SchemeBranch.  Each
   remaining premise yields one SchemeCase, and the sum-of-products type T is
   built from the branch argument types. *)
fun mk_scheme' ctxt cases concl =
  let
    (* A branch "!!ws. Cs ==> _ $ (P xs)".  The leading "_ $" strips the outer
       judgment (presumably Trueprop).  dest_Free requires the xs to be free
       variables. *)
    fun mk_branch concl =
      let
        val (_, ws, Cs, _ $ Pxs) = dest_hhf ctxt concl
        val (P, xs) = strip_comb Pxs
      in
        SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs }
      end

    (* correction: with exactly one conclusion, the premises are split with
       chop_prefix on "mentions P".  The part returned first stays as cases.
       The remainder (conds) is moved into the single branch as extra
       premises "conds ==> conc".  With several conclusions, every premise is
       a case. *)
    val (branches, cases') = (* correction *)
      case Logic.dest_conjunctions concl of
        [conc] =>
        let
          val _ $ Pxs = Logic.strip_assums_concl conc
          val (P, _) = strip_comb Pxs
          val (cases', conds) =
            chop_prefix (Term.exists_subterm (curry op aconv P)) cases
          val concl' = fold_rev (curry Logic.mk_implies) conds conc
        in
          ([mk_branch concl'], cases')
        end
      | concls => (map mk_branch concls, cases)

    (* A case "!!qs. prems ==> _ $ (P lhs)".  Premises in the leading part
       that mentions no branch predicate become guards gs.  The rest are
       recursive-call hypotheses, decoded by mk_rcinfo. *)
    fun mk_case premise =
      let
        val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise
        val (P, lhs) = strip_comb Plhs

        (* position of the branch with predicate Q (~1 if none, per find_index) *)
        fun bidx Q =
          find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches

        (* a hypothesis "!!Gvs. Gas ==> _ $ (P' rcs)" for a recursive call *)
        fun mk_rcinfo pr =
          let
            val (_, Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr
            val (P', rcs) = strip_comb Phyp
          in
            (bidx P', Gvs, Gas, rcs)
          end

        fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches

        val (gs, rcprs) =
          chop_prefix (not o Term.exists_subterm is_pred) prems
      in
        (* oqnames simply reuses the names of qs *)
        SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*),
          gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
      end

    (* product of the argument types of a branch (foldr1: xs must be non-empty) *)
    fun PT_of (SchemeBranch { xs, ...}) =
      foldr1 HOLogic.mk_prodT (map snd xs)

    (* balanced sum over all branch product types *)
    val ST = Balanced_Tree.make (uncurry Sum_Tree.mk_sumT) (map PT_of branches)
  in
    IndScheme {T=ST, cases=map mk_case cases', branches=branches }
  end
117
(* Completeness obligation for branch bidx.
   Notation: xs' are fresh copies of the branch variables xs, and Cs' is Cs
   with xs renamed to xs'.  The result is

     !!P xs' ws. Cs' ==> case_1 ==> ... ==> case_k ==> P

   where case_j ranges over the cases of this branch, each of the form
     !!qs. gs ==> x'_1 = lhs_1 ==> ... ==> x'_m = lhs_m ==> P
   That is, the branch's cases cover every argument tuple satisfying Cs'. *)
fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx =
  let
    val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx
    val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases

    (* Despite its name, this is the list of distinct Free terms for all qs of
       the relevant cases.  It is used only as the set of variables that the
       fresh names below must avoid. *)
    val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases []
    (* fresh boolean variable "P" and fresh copies xs' of xs *)
    val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs))
    (* branch premises with xs replaced by xs' (identity pairs filtered out) *)
    val Cs' = map (Pattern.rewrite_term (Proof_Context.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs

    (* !!qs. gs ==> xs' = lhs ==> P  (qs re-bound under the names oqnames) *)
    fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) =
      HOLogic.mk_Trueprop Pbool
      |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l)))
           (xs' ~~ lhs)
      |> fold_rev (curry Logic.mk_implies) gs
      |> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
  in
    HOLogic.mk_Trueprop Pbool
    |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases
    |> fold_rev (curry Logic.mk_implies) Cs'
    |> fold_rev (Logic.all o Free) ws
    |> fold_rev mk_forall_rename (map fst xs ~~ xs')
    |> mk_forall_rename ("P", Pbool)
  end
141
(* The proposition "wf R" for a relation R over the scheme's sum type T. *)
fun mk_wf R (IndScheme {T, ...}) =
  let
    val wf_const = Const (\<^const_name>\<open>wf\<close>, mk_relT T --> HOLogic.boolT)
  in
    HOLogic.mk_Trueprop (wf_const $ R)
  end
144
(* Proof obligations for the recursive calls, as terms.  The result has one
   list per case, in case order, and one pair (descent, preservation) per
   recursive call of that case.  For a call of branch bidx' with arguments
   rcarg inside a case of branch bidx, both components are wrapped as
   "!!qs Gvs. gs ==> Gas ==> _":
     descent:      (inject bidx' rcarg, inject bidx lhs) : R
     preservation: !!thesis. (!!ws. Cs[rcarg/xs] ==> thesis) ==> thesis
                   where ws and Cs are those of branch bidx'.
   thesisn is the name used for the free boolean "thesis" variable. *)
fun mk_ineqs R thesisn (IndScheme {T, cases, branches}) =
  let
    (* inject the tuple of ts into summand i (0-based) of the sum type T *)
    fun inject i ts =
       Sum_Tree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts)

    val thesis = Free (thesisn, HOLogic.boolT)

    (* (!!ws. Cs[args/xs] ==> thesis) ==> thesis  for branch bdx *)
    fun mk_pres bdx args =
      let
        val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx
        (* substitute v for the free variable x (by abstraction + beta) *)
        fun replace (x, v) t = betapply (lambda (Free x) t, v)
        val Cs' = map (fold replace (xs ~~ args)) Cs
        val cse =
          HOLogic.mk_Trueprop thesis
          |> fold_rev (curry Logic.mk_implies) Cs'
          |> fold_rev (Logic.all o Free) ws
      in
        Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis)
      end

    fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) =
      let
        fun g (bidx', Gvs, Gas, rcarg) =
          (* close a statement over the call's context: Gas, then the case
             guards gs, then the local variables Gvs and the case variables qs *)
          let val export =
            fold_rev (curry Logic.mk_implies) Gas
            #> fold_rev (curry Logic.mk_implies) gs
            #> fold_rev (Logic.all o Free) Gvs
            #> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
          in
            (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R)
             |> HOLogic.mk_Trueprop
             |> export,
             mk_pres bidx' rcarg
             |> export
             |> Logic.all thesis)
          end
      in
        map g rs
      end
  in
    map f cases
  end
187
188
(* Combined induction predicate on the sum type.  Summand i is a tupled
   lambda over branch i's variables xs.  Its body is the atomized statement
   "!!ws. Cs ==> P xs" with the object-logic judgment dropped.  The summands
   are joined with Sum_Tree.mk_sumcases (result type bool). *)
fun mk_ind_goal ctxt branches =
  let
    fun branch_pred (SchemeBranch { P, xs, ws, Cs, ... }) =
      let
        val args = map Free xs
        val stmt =
          fold_rev (Logic.all o Free) ws
            (fold_rev (curry Logic.mk_implies) Cs
              (HOLogic.mk_Trueprop (list_comb (P, args))))
        val body =
          Object_Logic.drop_judgment ctxt (term_conv ctxt (ind_atomize ctxt) stmt)
      in
        HOLogic.tupled_lambda (foldr1 HOLogic.mk_prod args) body
      end
  in
    Sum_Tree.mk_sumcases HOLogic.boolT (map branch_pred branches)
  end
201
(* Derive the induction rule for a scheme by well-founded induction on R.

   Parameters:
     R              the relation on the sum type T
     x              a free variable of type T
     complete_thms  the assumed completeness statements, one per branch, in
                    branch order
     wf_thm         the assumed "wf R"
     ineqss         the assumed (descent, preservation) pairs, laid out
                    per case and per recursive call as produced by mk_ineqs

   Returns (steps_sorted, induct_rule).
     steps_sorted   the cterms of the per-case step statements, ordered by
                    case index.  prove_case introduces them with Thm.assume,
                    so they remain hypotheses of induct_rule.
     induct_rule    wf_induct_rule composed with wf_thm and the induction
                    step istep. *)
fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss
  (IndScheme {T, cases=scases, branches}) =
  let
    val n = length branches
    (* each case paired with its position in scases *)
    val scases_idx = map_index I scases

    (* inject the tuple of ts into summand i (0-based) of the sum type T *)
    fun inject i ts =
      Sum_Tree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts)
    (* predicate of branch i *)
    val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches)

    (* combined predicate on the sum type T *)
    val P_comp = mk_ind_goal ctxt branches

    (* Inductive Hypothesis: !!z. (z,x):R ==> P z *)
    val ihyp = Logic.all_const T $ Abs ("z", T,
      Logic.mk_implies
        (HOLogic.mk_Trueprop (
          Const (\<^const_name>\<open>Set.member\<close>, HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) 
          $ (HOLogic.pair_const T T $ Bound 0 $ x)
          $ R),
         HOLogic.mk_Trueprop (P_comp $ Bound 0)))
      |> Thm.cterm_of ctxt

    val aihyp = Thm.assume ihyp

    (* Rule for case splitting along the sum types *)
    val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches
    val pats = map_index (uncurry inject) xss
    val sum_split_rule =
      Pat_Completeness.prove_completeness ctxt [x] (P_comp $ x) xss (map single pats)

    (* For branch bidx with injection pattern pat, prove
         !!xs. x = pat ==> P_comp x
       from the completeness theorem of the branch and the case proofs.
       Also returns the (case index, assumed step statement) pairs. *)
    fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) =
      let
        val fxs = map Free xs
        val branch_hyp =
          Thm.assume (Thm.cterm_of ctxt (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat))))

        val C_hyps = map (Thm.cterm_of ctxt #> Thm.assume) Cs

        (* the cases of this branch, with their obligations from ineqss *)
        val (relevant_cases, ineqss') =
          (scases_idx ~~ ineqss)
          |> filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx)
          |> split_list

        (* For case cidx: assume its step statement and derive
             !!qs. gs ==> xs = lhs ==> P xs
           Returns that theorem together with (cidx, step). *)
        fun prove_case (cidx, SchemeCase {qs, gs, lhs, rs, ...}) ineq_press =
          let
            (* assumed equations xs = lhs *)
            val case_hyps =
              map (Thm.assume o Thm.cterm_of ctxt o HOLogic.mk_Trueprop o HOLogic.mk_eq)
                (fxs ~~ lhs)

            val cqs = map (Thm.cterm_of ctxt o Free) qs
            val ags = map (Thm.assume o Thm.cterm_of ctxt) gs

            (* the inductive hypothesis with x rewritten by branch_hyp and
               case_hyps *)
            val replace_x_simpset =
              put_simpset HOL_basic_ss ctxt addsimps (branch_hyp :: case_hyps)
            val sih = full_simplify replace_x_simpset aihyp

            (* From the descent obligation ineq, the inductive hypothesis yields
               the combined predicate at the recursive argument.  The
               preservation obligation pres then turns it into P_idx rcargs.
               Result: !!Gvs. Gas ==> P_idx rcargs. *)
            fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) =
              let
                val cGas = map (Thm.assume o Thm.cterm_of ctxt) Gas
                val cGvs = map (Thm.cterm_of ctxt o Free) Gvs
                (* instantiate the exported quantifiers/premises of an obligation *)
                val import = fold Thm.forall_elim (cqs @ cGvs)
                  #> fold Thm.elim_implies (ags @ cGas)
                val ipres = pres
                  |> Thm.forall_elim (Thm.cterm_of ctxt (list_comb (P_of idx, rcargs)))
                  |> import
              in
                sih
                |> Thm.forall_elim (Thm.cterm_of ctxt (inject idx rcargs))
                |> Thm.elim_implies (import ineq) (* Psum rcargs *)
                |> Conv.fconv_rule (sum_prod_conv ctxt)
                |> Conv.fconv_rule (ind_rulify ctxt)
                |> (fn th => th COMP ipres) (* P rs *)
                |> fold_rev (Thm.implies_intr o Thm.cprop_of) cGas
                |> fold_rev Thm.forall_intr cGvs
              end

            val P_recs = map2 mk_Prec rs ineq_press   (*  [P rec1, P rec2, ... ]  *)

            (* the case's step statement !!qs. gs ==> P_recs ==> P lhs *)
            val step = HOLogic.mk_Trueprop (list_comb (P, lhs))
              |> fold_rev (curry Logic.mk_implies o Thm.prop_of) P_recs
              |> fold_rev (curry Logic.mk_implies) gs
              |> fold_rev (Logic.all o Free) qs
              |> Thm.cterm_of ctxt

            (* rewrite each argument lhs_i back to x_i using the symmetric
               case_hyps, leaving the head P unchanged *)
            val Plhs_to_Pxs_conv =
              foldl1 (uncurry Conv.combination_conv)
                (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps)

            val res = Thm.assume step
              |> fold Thm.forall_elim cqs
              |> fold Thm.elim_implies ags
              |> fold Thm.elim_implies P_recs (* P lhs *)
              |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *)
              |> fold_rev (Thm.implies_intr o Thm.cprop_of) (ags @ case_hyps)
              |> fold_rev Thm.forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *)
          in
            (res, (cidx, step))
          end

        val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss')

        (* completeness instantiated with P xs, discharged by the case proofs:
           !!ws. Cs ==> P xs *)
        val bstep = complete_thm
          |> Thm.forall_elim (Thm.cterm_of ctxt (list_comb (P, fxs)))
          |> fold (Thm.forall_elim o Thm.cterm_of ctxt) (fxs @ map Free ws)
          |> fold Thm.elim_implies C_hyps
          |> fold Thm.elim_implies cases (* P xs *)
          |> fold_rev (Thm.implies_intr o Thm.cprop_of) C_hyps
          |> fold_rev (Thm.forall_intr o Thm.cterm_of ctxt o Free) ws

        (* P_comp x under x = pat, reduced to the branch statement and closed
           with bstep *)
        val Pxs =
          Thm.cterm_of ctxt (HOLogic.mk_Trueprop (P_comp $ x))
          |> Goal.init
          |> (Simplifier.rewrite_goals_tac ctxt
                (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.case}))
              THEN CONVERSION (ind_rulify ctxt) 1)
          |> Seq.hd
          |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep)
          |> Goal.finish ctxt
          |> Thm.implies_intr (Thm.cprop_of branch_hyp)
          |> fold_rev (Thm.forall_intr o Thm.cterm_of ctxt) fxs
      in
        (Pxs, steps)
      end

    val (branches, steps) =
      map_index prove_branch (branches ~~ (complete_thms ~~ pats))
      |> split_list |> apsnd flat

    (* combine the branch theorems via the sum split, then discharge the
       inductive hypothesis and generalize over x *)
    val istep = sum_split_rule
      |> fold (fn b => fn th => Drule.compose (b, 1, th)) branches
      |> Thm.implies_intr ihyp
      |> Thm.forall_intr (Thm.cterm_of ctxt x) (* "!!x. (!!y<x. P y) ==> P x" *)

    val induct_rule =
      @{thm "wf_induct_rule"}
      |> (curry op COMP) wf_thm
      |> (curry op COMP) istep

    val steps_sorted = map snd (sort (int_ord o apply2 fst) steps)
  in
    (steps_sorted, induct_rule)
  end
344
345
(* Generic tactic for induction schemas.
   1. Insert facts into all goals.
   2. Read the head goal as a scheme (premises = cases, conclusion =
      branches).
   3. Replace the head goal, by bicompose, with a rule derived via
      mk_induct_rule.
   The new subgoals follow the order of newgoals:
     completeness (one per branch), preservation, "wf R", descent.
   Tactics applied afterwards, by index relative to the head goal i:
     term_tac  on the subgoal at i + nbranches + npres (the wf goal), not
               wrapped in TRY
     pres_tac  tried (TRY) on the preservation subgoals
     comp_tac  tried (TRY) on the completeness subgoals
   pres_tac and comp_tac are applied from the highest index downwards. *)
fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts =
  (* FIXME proper use of facts!? *)
  (ALLGOALS (Method.insert_tac ctxt facts)) THEN HEADGOAL (SUBGOAL (fn (t, i) =>
  let
    val (ctxt', _, cases, concl) = dest_hhf ctxt t
    val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl
    (* fresh names for the relation, the induction variable and "thesis" *)
    val ([Rn, xn, thesisn], ctxt'') = Variable.variant_fixes ["R", "x", "thesis"] ctxt'
    val R = Free (Rn, mk_relT ST)
    val x = Free (xn, ST)

    (* all obligations are first assumed as hypotheses; they are discharged
       as premises of res below *)
    val ineqss =
      mk_ineqs R thesisn scheme
      |> map (map (apply2 (Thm.assume o Thm.cterm_of ctxt'')))
    val complete =
      map_range (mk_completeness ctxt'' scheme #> Thm.cterm_of ctxt'' #> Thm.assume)
        (length branches)
    val wf_thm = mk_wf R scheme |> Thm.cterm_of ctxt'' |> Thm.assume

    val (descent, pres) = split_list (flat ineqss)
    (* order fixes the subgoal indices used by the tactics below *)
    val newgoals = complete @ pres @ wf_thm :: descent

    val (steps, indthm) =
      mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme

    (* specialize the rule on the sum type to branch i: instantiate with the
       injection of the branch's variables, then simplify the sum cases and
       rulify *)
    fun project (i, SchemeBranch {xs, ...}) =
      let
        val inst = (foldr1 HOLogic.mk_prod (map Free xs))
          |> Sum_Tree.mk_inj ST (length branches) (i + 1)
          |> Thm.cterm_of ctxt''
      in
        indthm
        |> Thm.instantiate' [] [SOME inst]
        |> simplify (put_simpset Sum_Tree.sumcase_split_ss ctxt'')
        |> Conv.fconv_rule (ind_rulify ctxt'')
      end

    (* newgoals ==> steps ==> (conjunction of all branch statements).  R is
       generalized, presumably so that it becomes schematic and can be
       instantiated later (e.g. by term_tac). *)
    val res = Conjunction.intr_balanced (map_index project branches)
      |> fold_rev Thm.implies_intr (map Thm.cprop_of newgoals @ steps)
      |> Drule.generalize ([], [Rn])

    val nbranches = length branches
    val npres = length pres
  in
    Thm.bicompose (SOME ctxt'') {flatten = false, match = false, incremented = false}
      (false, res, length newgoals) i
    THEN term_tac (i + nbranches + npres)
    THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches))))
    THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i)))
  end))
395
396
(* mk_ind_tac with no-op completeness and termination tactics.  Preservation
   subgoals are tried with assume_tac, falling back to
   Goal.assume_rule_tac. *)
fun induction_schema_tac ctxt =
  let
    val pres_tac = assume_tac ctxt APPEND' Goal.assume_rule_tac ctxt
  in
    mk_ind_tac (K all_tac) pres_tac (K all_tac) ctxt
  end;
399
400end
401