1(* ===================================================================== *)
2(* FILE          : prim_rec.sml                                          *)
3(* DESCRIPTION   : Primitive recursive definitions on arbitrary recursive*)
4(*                 types.  Assumes the type is defined by an axiom of    *)
5(*                 the form proved by the recursive types package.       *)
6(*                 Translated from hol88.                                *)
7(*                                                                       *)
8(* AUTHOR        : (c) T. F. Melham, University of Cambridge             *)
9(* DATE          : 87.08.23                                              *)
10(* TRANSLATOR    : Konrad Slind, University of Calgary                   *)
11(* DATE          : September 11, 1991                                    *)
12(* REVISED       : 17.1.98                                               *)
13(* REVISION      : Added Induct_then and prove_induction_thm and         *)
14(*                 prove_cases_thm from former Rec_type_support.         *)
15(*                                                                       *)
16(* REVISED       : December 1999                                         *)
17(* BY            : Michael Norrish                                       *)
18(* REVISION      : Re-implemented new_prim_rec_defn using John           *)
19(*                 Harrison's HOL Light code, in conjunction with the    *)
20(*                 wide-ranging revisions to the datatype package.       *)
21(*                                                                       *)
22(* ===================================================================== *)
23
24
25structure Prim_rec :> Prim_rec =
26struct
27
28open HolKernel Parse boolTheory boolSyntax
29     Drule Tactical Tactic Conv Thm_cont Rewrite Abbrev;
30
(* ERR func msg: a HOL_ERR exception whose origin structure is "Prim_rec". *)
fun ERR func msg = mk_HOL_ERR "Prim_rec" func msg
32
(* Shadow Parse with a copy whose Type and Term quotation parsers are built
   from boolTheory's grammars rather than the ambient ones (presumably so
   that quotations in this file parse independently of the grammar current
   at build time). *)
structure Parse =
struct
  open Parse
  val (Type, Term) = Parse.parse_from_grammars boolTheory.bool_grammars
end
open Parse
38
39
40(*---------------------------------------------------------------------------
     Utilities adapted from John Harrison's HOL Light code
42 ---------------------------------------------------------------------------*)
43
(* lhand tm: for tm = "f l r", the left operand l. *)
fun lhand tm = rand (rator tm)

(* conjuncts tm: the conjuncts of tm, as split by strip_conj. *)
fun conjuncts tm = strip_conj tm
46
(* strip_vars tm: peel variable arguments off the right of an application.
   For tm = "h a1 ... ak v1 ... vn", where v1..vn are variables and ak is
   not, returns ("h a1 ... ak", [v1, ..., vn]).  Stops at the first
   non-variable argument or when tm is not a combination. *)
fun strip_vars tm =
  let
    fun loop (t, vs) =
      if is_comb t andalso is_var (rand t)
      then loop (rator t, rand t :: vs)
      else (t, vs)
  in
    loop (tm, [])
  end;
54
(* REPEATNC n c: the conversion c applied n times in sequence; REFL when
   n < 1. *)
fun REPEATNC n c =
  if n >= 1 then c THENC REPEATNC (n - 1) c else REFL
56
(* HEAD_BETA_CONV tm: repeatedly beta-reduce at the head of tm.  Each step
   walks down the rator spine to the innermost application "f x" (where f
   is not itself an application) and applies BETA_CONV there.  Steps repeat
   until that fails, e.g. "(\x y. b) a1 a2" reduces to "b[a1/x, a2/y]". *)
fun HEAD_BETA_CONV tm =
  let
    fun head_beta t =
      if is_comb t andalso is_comb (rator t)
      then RATOR_CONV head_beta t
      else BETA_CONV t
  in
    REPEATC head_beta tm
  end;
65
(* CONJS_CONV c tm: apply the conversion c to every maximal non-conjunction
   subterm of the (arbitrarily nested) conjunction tm. *)
fun CONJS_CONV c tm =
  if not (is_conj tm) then c tm
  else BINOP_CONV (CONJS_CONV c) tm;
68
(* mymatch_and_instantiate axth pattern instance

   pattern is "?f1 ... fn. c1 /\ ... /\ cm" built from clauses of the
   recursion theorem axth.  instance is "?g1 ... gn. d1 /\ ... /\ dm" built
   from the user's clauses, in the same clause order.

   The existential variables are type-matched positionally.  Then, for each
   clause pair, the head of the pattern clause's right-hand side (a variable
   of axth; the caller passes SPEC_ALL of the axiom) is mapped to an
   abstraction that rebuilds the user's right-hand side.

   Returns axth with that type and term instantiation applied, with every
   clause's rhs head-beta-reduced.

   Raises HOL_ERR if a clause pair's left-hand sides do not match, or if
   the two bodies have different numbers of conjuncts. *)
fun mymatch_and_instantiate axth pattern instance = let
  val (patvars, patbody) = strip_exists pattern
  val (instvars, instbody) = strip_exists instance
  (* encode each variable list as one type, so a single match_type call
     aligns all the existential variables at once *)
  fun tmlist_type (tm,ty) = type_of tm --> ty
  val pat_type = List.foldr tmlist_type Type.bool patvars
  val inst_type = List.foldr tmlist_type Type.bool instvars
  val tyinst = Type.match_type pat_type inst_type
  val new_patbody0 = Term.inst tyinst patbody
  val new_patvars = map (Term.inst tyinst) patvars
  val initial_env = ListPair.map op|-> (new_patvars, instvars)
  val new_patbody = Term.subst initial_env new_patbody0
  fun match_eqn cnum pat inst = let
    (* both terms of the form: !v1 .. vn. f (C1 ..) = rhs *)
    val (patvars, pateqn0) = strip_forall pat
    val (instvars, insteqn) = strip_forall inst
    val forall_env = ListPair.map op|-> (patvars, instvars)
    val pateqn = Term.subst forall_env pateqn0
    val _ = aconv (lhs pateqn) (lhs insteqn) orelse
            raise ERR "prove_raw_recursive_functions_exist.mymatch_and_instantiate"
                      ("Failed to match LHSes in clause " ^ Int.toString cnum)
    val instrhs = rhs insteqn and patrhs = rhs pateqn
    (* last arguments in the pattern will be instances of a function symbol
       being applied to recursive arguments *)
    val (pathd, patargs) = strip_comb patrhs
    val (vars, others) = partition is_var patargs
    val gvars_for_others = map (genvar o type_of) others
    val others_away =
      Term.subst (ListPair.map op|-> (others, gvars_for_others)) instrhs
    (* here we rely on the assumption that the combs we split off were the
       back half of the list of arguments *)
    val answer = list_mk_abs(vars @ gvars_for_others, others_away)
  in
    (* the substitution to perform: the axiom's rhs function variable
       becomes an abstraction rebuilding the user's rhs *)
    pathd |-> answer
  end
  fun match_eqns acc n patconjs instconjs =
    case (patconjs, instconjs) of
      ([], []) => List.rev acc
    | (p::ps, i::is) => match_eqns (match_eqn n p i :: acc) (n + 1) ps is
    | _ => raise ERR "prove_raw_recursive_functions_exist"
                     "Number of conjuncts not even the same"
  val tmsubst = match_eqns [] 1 (strip_conj new_patbody) (strip_conj instbody)
  val axth1 = Thm.INST_TYPE tyinst axth
  val axth2 = Thm.INST tmsubst axth1
in
  (* the instantiated abstractions now sit at the head of each rhs;
     beta-reduce them away *)
  CONV_RULE (STRIP_QUANT_CONV
             (CONJS_CONV (STRIP_QUANT_CONV (RHS_CONV HEAD_BETA_CONV))))
  axth2
end
123
(* findax c table: find the first entry (p, ax) of table whose pattern p
   matches the term c (via match_term).  Returns ax together with table
   minus that entry, keeping the other entries in their original order.
   Raises HOL_ERR if no pattern matches. *)
fun findax c table =
  let
    fun matches p = (ignore (match_term p c); true) handle HOL_ERR _ => false
    fun go _ [] = raise ERR "prove_raw_recursive_functions_exist" "findax"
      | go skipped ((entry as (p, ax)) :: rest) =
          if matches p then (ax, List.revAppend (skipped, rest))
          else go (entry :: skipped) rest
  in
    go [] table
  end;
128
(* prove_raw_recursive_functions_exist ax tm

   tm is a conjunction of clauses "!vs. fn (C args) = rhs".  The first
   argument of each function on a left-hand side is a constructor
   application, and the function's remaining arguments are already folded
   into rhs.  ax is a recursion theorem.

   Each clause is paired (via findax, i.e. match_term) with the clause of
   ax for the same constructor.  The axiom's function variables are renamed
   to the user's functions, and mymatch_and_instantiate instantiates the
   axiom.

   Returns a theorem whose conclusion is "?f1 ... fn. clauses", the
   clauses being those of tm re-conjoined in their original order. *)
fun prove_raw_recursive_functions_exist ax tm = let
  val rawcls = conjuncts tm
  val spcls = map (snd o strip_forall) rawcls
  val lpats = map (strip_comb o lhand) spcls
  (* the user's functions being defined, without duplicates *)
  val ufns = itlist (op_insert aconv o fst) lpats []
  val axth = SPEC_ALL ax
  val (exvs,axbody) = strip_exists (concl axth)
  val axcls = conjuncts axbody
  (* key each axiom clause by the constructor on its lhs *)
  val f = repeat rator o rand o lhand o snd o strip_forall
  val table = map (fn t => (f t,t)) axcls
  fun gax c (axs,tabl) =
       let val (axcl,tabl') = findax c tabl
       in (axcl::axs, tabl')
       end
  (* one axiom clause per user clause, chosen by the constructor of the
     first argument *)
  val raxs0 = rev_itlist gax (map (repeat rator o hd o snd) lpats) ([],table)
  val raxs = List.rev (fst raxs0)
  val axfns = map (repeat rator o lhand o snd o strip_forall) raxs
  (* axiom function variable |-> user function *)
  val dict = ListPair.foldl (fn (a,(b,_),d) => Binarymap.insert(d,a,b))
                            (Binarymap.mkDict Term.compare)
                            (axfns, lpats)
  val urfns =
      map (fn v => Binarymap.find(dict,v) handle Binarymap.NotFound => v) exvs
  val axtm = list_mk_exists(exvs,list_mk_conj raxs)
  and urtm = list_mk_exists(urfns,tm)
  val ixth = mymatch_and_instantiate axth axtm urtm
  val (ixvs,ixbody) = strip_exists (concl ixth)
  val ixtm = Term.subst (map2 (curry op|->) ixvs urfns) ixbody
  val ixths = CONJUNCTS (ASSUME ixtm)
  (* recover each user clause from the instantiated axiom; report a
     HOL_ERR rather than an anonymous Option exception if one is missing *)
  fun clause_thm t =
      case List.find (aconv t o concl) ixths of
        SOME th => th
      | NONE => raise ERR "prove_raw_recursive_functions_exist"
                          "instantiated axiom does not yield a requested clause"
  val rixths = map clause_thm rawcls
  val rixth = itlist SIMPLE_EXISTS ufns (end_itlist CONJ rixths)
in
  PROVE_HYP ixth (itlist SIMPLE_CHOOSE urfns rixth)
end
162
163(* ------------------------------------------------------------------------ *)
164(* Prove existence when PR argument always comes first in argument lists.   *)
165(* ------------------------------------------------------------------------ *)
166
(* prove_canon_recursive_functions_exist ax tm

   Existence proof for clauses "!avs. fn (C xs) a1 ... ak = rhs" whose
   first argument is the constructor pattern.

   Each clause is derived from an assumed definition
   "!xs. fn (C xs) = \a1 ... ak. rhs".  The conjunction of those assumed
   definitions is proved by prove_raw_recursive_functions_exist, and the
   assumptions are then discharged.

   Returns a theorem whose conclusion is "?fns. tm". *)
val prove_canon_recursive_functions_exist =
  let
    (* apply th to each of args in turn, beta-reducing the new rhs *)
    fun right_betas args th =
      rev_itlist
        (fn a => fn acc => CONV_RULE (RAND_CONV BETA_CONV) (AP_THM acc a))
        args th
    fun canonize clause =
      let
        val (avs, eqn) = strip_forall clause
        val (lhs_tm, rhs_tm) = dest_eq eqn
        val (fnn, args) = strip_comb lhs_tm
        val (rec_arg, extra_args) = (hd args, tl args)
        val def_lhs = mk_comb (fnn, rec_arg)
        val def_rhs = list_mk_abs (extra_args, rhs_tm)
        val ctor_vars = snd (strip_comb rec_arg)
        val def = ASSUME (list_mk_forall (ctor_vars, mk_eq (def_lhs, def_rhs)))
      in
        GENL avs (right_betas extra_args (SPECL ctor_vars def))
      end
  in
    fn ax => fn tm =>
      let
        val canon_ths = map canonize (conjuncts tm)
        val canon_tm = list_mk_conj (map (hd o hyp) canon_ths)
        val exists_th = prove_raw_recursive_functions_exist ax canon_tm
        val canon_asms = CONJUNCTS (ASSUME canon_tm)
        val body_th = end_itlist CONJ (map2 PROVE_HYP canon_asms canon_ths)
        val witnesses = fst (strip_exists (concl exists_th))
      in
        PROVE_HYP exists_th
          (itlist SIMPLE_CHOOSE witnesses (itlist SIMPLE_EXISTS witnesses body_th))
      end
  end
195
196(* ------------------------------------------------------------------------ *)
197(* General version to prove existence.                                      *)
198(* ------------------------------------------------------------------------ *)
199
(* universalise_clauses tm: universally quantify each conjunct of tm over
   its free variables, in free_vars_lr order, excluding the functions being
   defined (the heads of the left-hand sides). *)
fun universalise_clauses tm =
  let
    val clauses = conjuncts tm
    fun head_of cl = fst (strip_comb (lhand (snd (strip_forall cl))))
    val fns = itlist (op_insert aconv o head_of) clauses []
    fun close cl = list_mk_forall (op_set_diff aconv (free_vars_lr cl) fns, cl)
  in
    list_mk_conj (map close clauses)
  end
210
(* prove_recursive_functions_exist ax tm

   General existence proof for the recursive clauses tm with respect to
   the recursion theorem ax.  The clauses are first universally closed.

   For any function whose arguments are not ordered "non-variables first",
   an assumed reshuffling definition is introduced.  The clauses are
   rewritten into the canonical form that prove_canon_... requires.  The
   reshuffling assumptions are then discharged.

   Returns a theorem whose conclusion is "?fns. clauses". *)
val prove_recursive_functions_exist =
 let
   (* reshuffle fnn args acc: args are fnn's arguments in one clause.  If
      they are not already ordered "non-variable arguments first, then
      variables", prepend to acc the assumption
          |- fnn = \gvs. fn' gvs'
      where gvs are fresh variables for args, gvs' is gvs permuted into that
      order, and fn' is a fresh variable.  Otherwise return acc. *)
   fun reshuffle fnn args acc =
     let val args' = uncurry (C (curry op@)) (Lib.partition is_var args)
     in if ListPair.allEq (uncurry aconv) (args, args') then acc
        else
         let val gvs = map (genvar o type_of) args
             val gvs' = map (C (op_assoc aconv) (zip args gvs)) args'
             (* result type of fnn once all of its arguments are supplied *)
             val lty = itlist (curry (op -->) o type_of) gvs'
                        (funpow (length gvs)
                                (hd o tl o snd o dest_type) (type_of fnn))
             val fn' = genvar lty
             val def = mk_eq(fnn,list_mk_abs(gvs,list_mk_comb(fn',gvs')))
         in
           (ASSUME def)::acc
         end
     end
   (* scrub_def (fnn = rhs) th: remove the assumption "fnn = rhs" from th by
      instantiating fnn to rhs and discharging with |- rhs = rhs. *)
   fun scrub_def t th =
     let val (lhs, rhs) = dest_eq t
     in MP (Thm.INST [lhs |-> rhs] (DISCH t th)) (REFL rhs)
     end
   fun prove_once_universalised ax tm =
     let val rawcls = conjuncts tm
         val spcls = map (snd o strip_forall) rawcls
         val lpats = map (strip_comb o lhand) spcls
         val ufns = itlist (op_insert aconv o fst) lpats []
         (* argument list of each function, taken from its first clause *)
         val uxargs = map (C (op_assoc aconv) lpats) ufns
         val trths = itlist2 reshuffle ufns uxargs []
         (* rewrite with the reshuffling assumptions so that non-variable
            (pattern) arguments come first *)
         val tth = QCONV
                     (REPEATC (CHANGED_CONV
                                 (PURE_REWRITE_CONV trths THENC
                                  DEPTH_CONV BETA_CONV))) tm
         val eth = prove_canon_recursive_functions_exist ax (rand(concl tth))
         val (evs,ebod) = strip_exists(concl eth)
         val fth = itlist SIMPLE_EXISTS ufns (EQ_MP (SYM tth) (ASSUME ebod))
         val gth = itlist scrub_def (map concl trths) fth
     in
       PROVE_HYP eth (itlist SIMPLE_CHOOSE evs gth)
     end
 in
   fn ax => fn tm => prove_once_universalised ax (universalise_clauses tm)
 end
253
(* Short alias for prove_recursive_functions_exist. *)
fun prove_rec_fn_exists ax tm = prove_recursive_functions_exist ax tm
255
256(* ------------------------------------------------------------------------ *)
257(* Version that defines function(s).                                        *)
258(* ------------------------------------------------------------------------ *)
259
(* new_recursive_definition0 ax name tm: prove that functions satisfying
   the clauses tm exist (w.r.t. the recursion theorem ax), then introduce
   them as constants with Rsyntax.new_specification under the given name.
   Each constant is named after its existentially quantified variable and
   has no fixity. *)
fun new_recursive_definition0 ax name tm =
  let
    val eth = prove_recursive_functions_exist ax tm
    val (evs, _) = strip_exists (concl eth)
  in
    Rsyntax.new_specification
      {sat_thm = eth, name = name,
       consts = map (fn t => {const_name = fst (dest_var t),
                              fixity = NONE}) evs}
  end;
269
270(* test with:
271     load "listTheory";
272     val ax =
273       mk_thm([], ``!n c. ?f. (f [] = n) /\
274                              (!x xs. f (CONS x xs) = c x xs (f xs))``);
275
276     hide "map";
277     val tm =
278       ``(map f [] = []) /\ (map f (CONS x xs) = CONS (f x) (map f xs))``;
279     prove_recursive_functions_exist ax tm;
280     new_recursive_definition0 ax "map" tm;
281
282
283     also
284       new_type 0 "foo";
285       new_type 0 "bar";
286       load "arithmeticTheory";
287       new_constant ("C1", ``:num -> foo``);
288       new_constant ("C2", ``:bar -> foo``);
289       new_constant ("D1", ``:bar``);
290       new_constant ("D2", ``:foo -> num -> bar``);
291       val ax = mk_thm([],
292                       ``!C1' C2' D1' D2'.
293                            ?fn0 fn1.
294                               (!n. fn0 (C1 n) = C1' n) /\
295                               (!b. fn0 (C2 b) = C2' b (fn1 b)) /\
296                               (fn1 D1 = D1') /\
297                               (!f n. fn1 (D2 f n) = D2' n f (fn0 f))``);
298       app hide ["FF1", "FF2"];
299       val tm1 = ``(FF1 f (C1 n) = f n) /\ (FF1 f (C2 b) = FF2 f b) /\
300                   (FF2 f D1 = f 0) /\ (FF2 f (D2 fo n) = f n + FF1 f fo)``;
301
302
303       prove_recursive_functions_exist ax tm1;
304       new_recursive_definition0 ax "FF1" tm1;
305
306       hide "FF3";
307       val tm2 = ``(FF3 (C1 n) = T) /\ (FF3 (C2 b) = F)``;
308       prove_recursive_functions_exist ax tm2;
309       new_recursive_definition0 ax "FF3" tm2;
310
311       hide "FF4";
312       val tm3 = ``(FF4 D1 = 0) /\ (FF4 (D2 f n) = n)``;
313       prove_recursive_functions_exist ax tm3;
314       new_recursive_definition0 ax "FF4" tm3;
315
316       hide "FF5";
317       val tm4 = ``(FF5 D1 = 0)``;
318       prove_recursive_functions_exist ax tm4;
319       new_recursive_definition0 ax "FF5" tm4;
320
321       hide "FF6";
322       val tm5 = ``(FF6 (D2 fo n) = n) /\ (FF7 (C1 n) = n)``;
323       prove_recursive_functions_exist ax tm5;
324       new_recursive_definition0 ax "FF6" tm5;
325
326       (* HOL Light equivalent *)
327       new_type("foo", 0);;
328       new_type("bar", 0);;
329       new_constant("C1", `:num -> foo`);;
330       new_constant("C2", `:bar -> foo`);;
331       new_constant("D1", `:bar`);;
332       new_constant("D2", `:foo -> num -> bar`);;
333       let ax = mk_thm([], `!C1' C2' D1' D2'.
334                            ?fn0 fn1.
335                               (!n. fn0 (C1 n) = C1' n) /\
336                               (!b. fn0 (C2 b) = C2' b (fn1 b)) /\
337                               (fn1 D1 = D1') /\
338                               (!f n. fn1 (D2 f n) = D2' n f (fn0 f))`);;
339
340       let tm2 = `(FF3 (C1 n) = T) /\ (FF3 (C2 b) = F)`;;
341       let fns = prove_recursive_functions_exist ax tm2;;
342
343*)
344
345(*---------------------------------------------------------------------------
346     Make a new recursive function definition.
347 ---------------------------------------------------------------------------*)
348
(* Record-argument front end to new_recursive_definition0. *)
fun new_recursive_definition {rec_axiom = ax, name = n, def = d} =
  new_recursive_definition0 ax n d;
351
352(*---------------------------------------------------------------------------
353   Given axiom and the name of the type, return a list of terms
354   corresponding to that type's constructors with their arguments.
355 ---------------------------------------------------------------------------*)
356
(* type_constructors_with_args ax name: the terms "C x1 ... xn" appearing as
   the argument of the left-hand side of the clauses of the recursion
   theorem ax, keeping only those whose type's operator is called name. *)
fun type_constructors_with_args ax name =
  let
    val (_, clauses) = strip_exists (#2 (strip_forall (concl ax)))
    fun constructor_arg clause =
      let
        val arg = rand (#1 (dest_eq (#2 (strip_forall clause))))
      in
        if #1 (dest_type (type_of arg)) = name then SOME arg else NONE
      end
  in
    List.mapPartial constructor_arg (strip_conj clauses)
  end
370
(* type_constructors ax name: as type_constructors_with_args, but returning
   just the constructor heads, without their arguments. *)
fun type_constructors ax name =
  List.map (fn c => #1 (strip_comb c)) (type_constructors_with_args ax name)
374
375
local
  (* true when every argument of the type operator ty is a type variable *)
  fun only_variable_args ty =
    not (List.exists is_type (#2 (dest_type ty)))
  fun only_new tys = List.filter only_variable_args tys
in
  (* doms_of_tyaxiom ax: for a recursion theorem |- !... . ?fns. body, the
     domain types of the existentially quantified functions, keeping only
     those whose type arguments are all type variables.
     Formerly "new_types". *)
  fun doms_of_tyaxiom ax =
    let
      val fns = #1 (strip_exists (#2 (strip_forall (concl ax))))
    in
      only_new (map (fn f => #1 (dom_rng (type_of f))) fns)
    end

  (* doms_of_ind_thm ind: similarly, for an induction theorem of the form

        !P1 .. PN.
           c1 /\ ... /\ cn ==> (!x. P1 x) /\ (!x. P2 x) ... /\ (!x. PN x)

     the types of the x's quantified in the conclusion, again keeping only
     those whose type arguments are all type variables.
     Formerly "new_types_from_ind". *)
  fun doms_of_ind_thm ind =
    let
      val body = #2 (strip_forall (concl ind))
      val conclusions = strip_conj (#2 (strip_imp body))
    in
      only_new (map (fn c => type_of (#1 (dest_forall c))) conclusions)
    end
end
404
405(*---------------------------------------------------------------------------*
 * Define a case constant for a datatype. Case constants are used by TFL's   *
 * pattern-matching translation, and are also generally useful as            *
 * replacements for "destructor" operations.                                 *
409 *---------------------------------------------------------------------------*)
410
(* num_variant vlist v: a variable of v's type whose name is v's name, or
   that name followed by 1, 2, ... (the first that is not the name of any
   variable in vlist). *)
fun num_variant vlist v =
  let
    val (base, ty) = dest_var v
    val taken = map (fst o dest_var) vlist
    fun candidate 0 = base
      | candidate k = Lib.strcat base (Lib.int_to_string k)
    fun first_free k =
      let val name = candidate k
      in if mem name taken then first_free (k + 1) else name end
  in
    mk_var (first_free 0, ty)
  end;
423
(* Name of the case constant for datatype type_name, e.g. "num_CASE". *)
fun case_constant_name (r : {type_name : string}) = #type_name r ^ "_CASE"

(* Name given to the case constant's defining theorem. *)
fun case_constant_defn_name (r : {type_name : string}) =
  #type_name r ^ "_case_def"
426
(* generate_case_constant_eqns ty clist

   ty is "dty -> rty", and clist lists the constructor applications
   "C x1 ... xn" of dty.  The result is the conjunction of defining
   equations

       TY_CASE (C x1 ... xn) v1 ... vm = vi x1 ... xn

   one per constructor.  Here TY_CASE is a variable named by
   case_constant_name.  Each vi is a fresh variable named "v" (nullary
   constructor) or "f" (otherwise), numbered apart from clist's free
   variables and from each other. *)
fun generate_case_constant_eqns ty clist =
  let
    val (dty, rty) = Type.dom_rng ty
    val Tyop = #1 (dest_type dty)
    (* allocate the case function argument for one constructor *)
    fun mk_cfun ctm (nv, away) =
      let
        val (_, args) = strip_comb ctm
        val fty = itlist (curry (op -->)) (map type_of args) rty
        val vname = if null args then "v" else "f"
        val v = num_variant away (mk_var (vname, fty))
      in
        (v :: nv, v :: away)
      end
    val arg_list = rev (#1 (rev_itlist mk_cfun clist ([], free_varsl clist)))
    val case_var = mk_var (case_constant_name {type_name = Tyop},
                           list_mk_fun (dty :: map type_of arg_list, rty))
    fun clause (a, c) = mk_eq (list_mk_comb (case_var, c :: arg_list),
                               list_mk_comb (a, #2 (strip_comb c)))
  in
    list_mk_conj (ListPair.map clause (arg_list, clist))
  end
445
(* define_case_constant ax: for every function of the recursion theorem ax
   whose domain is one of the types ax defines (doms_of_tyaxiom), define a
   case constant over that type by primitive recursion on ax.  Returns the
   list of defining theorems. *)
fun define_case_constant ax =
  let
    val oktypes = doms_of_tyaxiom ax
    val conjs = strip_conj (#2 (strip_exists (#2 (strip_forall (concl ax)))))
    val newfns = map (rator o lhs o #2 o strip_forall) conjs
    val newtypes = map type_of newfns
    val usethese =
      mk_set (List.filter (fn ty => Lib.mem (#1 (dom_rng ty)) oktypes) newtypes)
    fun mk_defn ty =
      let
        val name = fst (dest_type (#1 (dom_rng ty)))
        val cs = type_constructors_with_args ax name
        val eqns = generate_case_constant_eqns ty cs
      in
        new_recursive_definition
          {name = case_constant_defn_name {type_name = name},
           rec_axiom = ax,
           def = eqns}
      end
  in
    map mk_defn usethese
  end
466
467(*---------------------------------------------------------------------------*)
468(*       INDUCT_THEN                                                         *)
469(*---------------------------------------------------------------------------*)
470
471
472(* ---------------------------------------------------------------------*)
473(* Internal function:                                                   *)
474(*                                                                      *)
475(* BETAS "f" tm : returns a conversion that, when applied to a term with*)
476(*               the same structure as the input term tm, will do a     *)
477(*               beta reduction at all top-level subterms of tm which   *)
478(*               are of the form "f <arg>", for some argument <arg>.    *)
479(*                                                                      *)
480(* ---------------------------------------------------------------------*)
481
(* BETAS fnn body: a conversion shaped like body.  It beta-reduces, in a
   term with the same structure as body, each position where body has an
   application "fnn <arg>" (fnn compared with aconv).  Everything else is
   left unchanged (REFL), descending through abstractions and combinations. *)
fun BETAS fnn body =
  if is_var body orelse is_const body then REFL
  else if is_abs body then ABS_CONV (BETAS fnn (snd (dest_abs body)))
  else
    let
      val (f, x) = dest_comb body
    in
      if aconv f fnn then BETA_CONV
      else
        let
          val on_rator = BETAS fnn f
          val on_rand = BETAS fnn x
        in
          fn tm => let val (r, a) = dest_comb tm
                   in MK_COMB (on_rator r, on_rand a)
                   end
        end
    end;
493
494(* ---------------------------------------------------------------------*)
495(* Internal function: GTAC                                              *)
496(*                                                                      *)
497(*   !x. tm[x]                                                          *)
498(*  ------------  GTAC "y"   (primes the "y" if necessary).             *)
499(*     tm[y]                                                            *)
500(*                                                                      *)
501(* NB: the x is always a genvar, so optimized for this case.            *)
502(* ---------------------------------------------------------------------*)
503
(* GTAC y: strip one universal quantifier from the goal, naming the bound
   variable y (primed away from the goal's and assumptions' free variables).
   Falls back to GEN_TAC when y's type differs from the bound variable's. *)
fun GTAC y (asl, w) =
  let
    val (bv, bod) = dest_forall w
    val y' = Term.variant (free_varsl (w :: asl)) y
  in
    if type_of bv <> type_of y' then GEN_TAC (asl, w)
    else
      let
        fun justify [th] = GEN bv (INST [y' |-> bv] th)
          | justify _ = raise Match
      in
        ([(asl, subst [bv |-> y'] bod)], justify)
      end
  end;
513
514(* ---------------------------------------------------------------------*)
515(* Internal function: TACF                                              *)
516(*                                                                      *)
517(* TACF is used to generate the subgoals for each case in an inductive  *)
(* proof.  The argument tm is a formula which states one generalized    *)
519(* case in the induction. For example, the induction theorem for num is:*)
520(*                                                                      *)
521(*   |- !P. P 0 /\ (!n. P n ==> P(SUC n)) ==> !n. P n                   *)
522(*                                                                      *)
523(* In this case, the argument tm will be one of:                        *)
524(*                                                                      *)
525(*   1:  "P 0"   or   2: !n. P n ==> P(SUC n)                           *)
526(*                                                                      *)
(* TACF is applied to each of these terms to construct a parameterized tactic *)
528(* which will be used to further break these terms into subgoals.  The  *)
529(* resulting tactic takes a variable name x and a user supplied theorem *)
530(* continuation ttac.  For a base case, like case 1 above, the resulting*)
531(* tactic just throws these parameters away and passes the goal on      *)
532(* unchanged (i.e. \x ttac. ALL_TAC).  For a step case, like case 2, the*)
533(* tactic applies GTAC x as many times as required.  It then strips off *)
534(* the induction hypotheses and applies ttac to each one.  For example, *)
535(* if tac is the tactic generated by:                                   *)
536(*                                                                      *)
537(*    TACF "!n. P n ==> P(SUC n)" "x:num" ASSUME_TAC                    *)
538(*                                                                      *)
(* then applying tac to the goal A,"!n. P[n] ==> P[SUC n]" has the same *)
540(* effect as applying:                                                  *)
541(*                                                                      *)
542(*    GTAC "x:num" THEN DISCH_THEN ASSUME_TAC                           *)
543(*                                                                      *)
544(* TACF is a strictly local function, used only to define TACS, below.  *)
545(* ---------------------------------------------------------------------*)
546
local
  (* ctacs tm ttac: a theorem-tactic applying ttac to each conjunct of a
     theorem shaped like the conjunction tm *)
  fun ctacs tm =
    if not (is_conj tm) then I
    else
      let val rest = ctacs (snd (dest_conj tm))
      in fn ttac => CONJUNCTS_THEN2 ttac (rest ttac)
      end
in
(* TACF tm x ttac: for a step case "!vs. hyps ==> concl", strip vs with
   GTAC x and pass each induction hypothesis to ttac.  For a case with no
   implication, the tactic is ALL_TAC. *)
fun TACF tm =
  let
    val (vs, body) = strip_forall tm
  in
    if not (is_imp body) then (fn _ => fn _ => Tactical.ALL_TAC)
    else
      let val strip_hyps = ctacs (fst (dest_imp body))
      in fn x => fn ttac =>
           MAP_EVERY (fn _ => GTAC x) vs THEN DISCH_THEN (strip_hyps ttac)
      end
  end
end;
564
565(* ---------------------------------------------------------------------*)
566(* Internal function: TACS                                              *)
567(*                                                                      *)
568(* TACS uses TACF to generate a parameterized list of tactics, one for  *)
569(* each conjunct in the hypothesis of an induction theorem.             *)
570(*                                                                      *)
571(* For example, if tm is the hypothesis of the induction theorem for the*)
572(* natural numbers---i.e. if:                                           *)
573(*                                                                      *)
574(*   tm = "P 0 /\ (!n. P n ==> P(SUC n))"                               *)
575(*                                                                      *)
576(* then TACS tm yields the parameterized list of tactics:               *)
577(*                                                                      *)
578(*   \x ttac. [TACF "P 0" x ttac; TACF "!n. P n ==> P(SUC n)" x ttac]   *)
579(*                                                                      *)
580(* TACS is a strictly local function, used only in INDUCT_THEN.         *)
581(* ---------------------------------------------------------------------*)
582
(* f splits a conjunction into (TACF of its first conjunct, TACS of the
   rest).  It is mutually recursive with TACS, which yields the list of
   per-case tactics for every conjunct of tm. *)
fun f (conj1, conj2) = (TACF conj1, TACS conj2)
and TACS tm =
  let
    val (head_tac, tail_tacs) =
        f (dest_conj tm)
        handle HOL_ERR _ => (TACF tm, fn _ => fn _ => [])
  in
    fn x => fn ttac => head_tac x ttac :: tail_tacs x ttac
  end;
588
589(* ---------------------------------------------------------------------*)
590(* Internal function: GOALS                                             *)
591(*                                                                      *)
592(* GOALS generates the subgoals (and proof functions) for all the cases *)
593(* in an induction. The argument A is the common assumption list for all*)
594(* the goals, and tacs is a list of tactics used to generate subgoals   *)
595(* from these goals.                                                    *)
596(*                                                                      *)
597(* GOALS is a strictly local function, used only in INDUCT_THEN.        *)
598(* ---------------------------------------------------------------------*)
599
(* GOALS asl tacs w: apply the i-th tactic of tacs to (asl, i-th conjunct of
   w).  The last tactic takes all remaining conjuncts.  Returns the list of
   subgoal lists and the list of validations.  The tail is processed before
   the head, as in the original. *)
fun GOALS _ [] _ = raise ERR "GOALS" "empty list"
  | GOALS asl [tac] w =
      let val (gl, vf) = tac (asl, w) in ([gl], [vf]) end
  | GOALS asl (tac :: tacs) w =
      let
        val (left, right) = dest_conj w
        val (gls, vfs) = GOALS asl tacs right
        val (gl, vf) = tac (asl, left)
      in
        (gl :: gls, vf :: vfs)
      end;
608
609(* --------------------------------------------------------------------- *)
610(* Internal function: GALPH                                             *)
611(*                                                                      *)
(* GALPH vty "!x1 ... xn. A ==> B": alpha-converts those x's whose type *)
(* is vty to genvars.                                                    *)
613(* --------------------------------------------------------------------- *)
614
local
  (* rule vty v eq: quantify both sides of the equation eq over v.  If v
     has type vty, additionally alpha-convert the right-hand side so that
     its bound variable is a fresh genvar. *)
  fun rule vty v =
      if type_of v = vty then
        let
          val gv = genvar (type_of v)
        in
          fn eq => let val th = FORALL_EQ v eq
                   in TRANS th (GEN_ALPHA_CONV gv (rhs (concl th)))
                   end
        end
      else
        (* v must still be re-quantified: returning I here would drop "!v"
           from both sides, so the final equation's lhs would no longer be
           the input term and later SUBSTs against it would fail. *)
        FORALL_EQ v
in
(* GALPH vty tm: for tm = "!x1 ... xn. A ==> B", return |- tm = tm', where
   tm' renames to genvars those bound xi whose type is vty.  For any other
   shape of tm, return |- tm = tm. *)
fun GALPH vty tm =
   let val (vs, hy) = strip_forall tm
   in if is_imp hy then Lib.itlist (rule vty) vs (REFL hy) else REFL tm
   end
end;
631
632(* ---------------------------------------------------------------------*)
633(* Internal function: GALPHA                                            *)
634(*                                                                      *)
635(* Applies the conversion GALPH to each conjunct in a sequence.         *)
636(* ---------------------------------------------------------------------*)
637
(* GALPHA vty tm: apply GALPH vty to each conjunct of tm and combine the
   resulting equations with MK_COMB.  Any HOL_ERR along the way (including
   tm not being a conjunction) falls back to GALPH vty tm. *)
fun GALPHA vty tm =
  (let
     val (c1, c2) = dest_conj tm
     val th1 = GALPH vty c1
     val th2 = GALPHA vty c2
   in
     MK_COMB (AP_TERM boolSyntax.conjunction th1, th2)
   end)
  handle HOL_ERR _ => GALPH vty tm
645
646(* --------------------------------------------------------------------- *)
647(* INDUCT_THEN : general induction tactic for concrete recursive types.  *)
648(* --------------------------------------------------------------------- *)
649
(* INDUCT_THEN th ttac: induction tactic built from th, of the shape
   |- !P. hyps ==> concl.

   Applied to a goal "$! lam" (i.e. "!x. t[x]"), it specialises th at lam
   and beta-reduces the instantiated cases.  It produces one subgoal per
   conjunct of hyps, with each step case's induction hypotheses passed to
   ttac.

   Failures are re-raised wrapped as Prim_rec.INDUCT_THEN. *)
local
  val boolvar = genvar Type.bool
in
fun INDUCT_THEN th =
  let
    val (P, body) = dest_forall (concl th)
    val ty = #1 (dom_rng (type_of P))
    val (hyp_tm, _) = dest_imp body
    val beta_hyps = BETAS P hyp_tm
    val case_tacs = TACS hyp_tm
    val v = genvar (type_of P)
    (* th specialised at a fresh predicate variable, hypotheses undischarged,
       with the conclusion eta-reduced *)
    val eta_th = CONV_RULE (RAND_CONV ETA_CONV) (UNDISCH (SPEC v th))
    val (asm, con) =
        case dest_thm eta_th of
          ([a], c) => (a, c)
        | _ => raise Match
    (* alpha-convert the case hypotheses (GALPHA) and re-generalise over v *)
    val ind = GEN v (SUBST [boolvar |-> GALPHA ty asm]
                           (mk_imp (boolvar, con))
                           (DISCH asm eta_th))
    fun induct ttac (asl, w) =
      let
        val lam = snd (dest_comb w)
        val spec = SPEC lam (INST_TYPE (snd (Term.match_term v lam)) ind)
        val (ant, conseq) = dest_imp (concl spec)
        val beta = SUBST [boolvar |-> beta_hyps ant]
                         (mk_imp (boolvar, conseq)) spec
        val tacs = case_tacs (fst (dest_abs lam)) ttac
        val (gll, pl) = GOALS asl tacs (fst (dest_imp (concl beta)))
        val pf = MP beta o LIST_CONJ o mapshape (map length gll) pl
      in
        (Lib.flatten gll, pf)
      end
      handle e => raise wrap_exn "Prim_rec" "INDUCT_THEN" e
  in
    induct
  end
  handle e => raise wrap_exn "Prim_rec" "INDUCT_THEN" e
end;
682
683(*--------------------------------------------------------------------------
684 * Now prove_induction_thm and prove_cases_thm.
685 *--------------------------------------------------------------------------*)
686
(* Infix syntax used below: term builders for =, ==>, \/ and /\, plus
   fixities for the type arrow and the conversional combinators. *)
infixr 3 == --> THENC ORELSEC
infixr 4 ==>
infixr 5 \/
infixr 6 /\

fun x == y = mk_eq (x, y)
fun x ==> y = mk_imp (x, y)
fun x /\ y = mk_conj (x, y)
fun x \/ y = mk_disj (x, y)
699
700
701(* =====================================================================*)
702(* STRUCTURAL INDUCTION                               (c) T Melham 1990 *)
703(* =====================================================================*)
704
705(* ---------------------------------------------------------------------*)
706(* Internal function: UNIQUENESS                                        *)
707(*                                                                      *)
708(* This function derives uniqueness from unique existence:              *)
709(*                                                                      *)
710(*        |- ?!x. P[x]                                                  *)
711(* ---------------------------------------                              *)
712(*  |- !v1 v2. P[v1] /\ P[v2] ==> (v1=v2)                               *)
713(*                                                                      *)
714(* The variables v1 and v2 are genvars.                                 *)
715(* ---------------------------------------------------------------------*)
716
(* AP_AND (|- a = b) yields |- $/\ a = $/\ b, i.e. it applies the
   conjunction constant to both sides of an equation. *)
fun AP_AND th = AP_TERM boolSyntax.conjunction th
718
(* Precomputed, type-generic pieces for UNIQUENESS.  [imp] is
     |- !P. $?! P ==> P v1 /\ P v2 ==> (v1 = v2)
   obtained by unfolding EXISTS_UNIQUE_DEF, with v1 and v2 genvars of type
   alpha.  [conv] beta-reduces both conjuncts of a term of the form
   (\x. t) a /\ (\x. t) b, giving |- ... = t[a/x] /\ t[b/x]. *)
local val P = mk_var("P", alpha --> bool)
      val v = genvar Type.bool
      and v1 = genvar alpha
      and v2 = genvar alpha
      val ex1P = mk_comb(boolSyntax.exists1,P)
      val th1 = SPEC P (CONV_RULE (X_FUN_EQ_CONV P) EXISTS_UNIQUE_DEF)
      (* th2 : $?! P |- !x y. P x /\ P y ==> (x = y) *)
      val th2 = CONJUNCT2(UNDISCH(fst(EQ_IMP_RULE(RIGHT_BETA th1))))
      val imp = GEN P (DISCH ex1P (SPECL [v1, v2] th2))
      fun AND (e1,e2) = MK_COMB(AP_AND e1, e2)
      fun beta_conj(conj1,conj2) = (BETA_CONV conj1, BETA_CONV conj2)
      fun conv tm = AND (beta_conj (dest_conj tm))
in
(* UNIQUENESS (|- ?!x. P[x]) returns |- !v1 v2. P[v1] /\ P[v2] ==> (v1 = v2),
   where v1 and v2 are the genvars above instantiated to the type of x.
   Any HOL_ERR raised on the way is re-raised as an ERR for UNIQUENESS. *)
fun UNIQUENESS th =
  let val _ = assert boolSyntax.is_exists1 (concl th)
      val (Rator,Rand) = dest_comb(concl th)
      (* Rand is the predicate \x. P[x]; specialise alpha to the type of x *)
      val theta = [alpha |-> type_of (bvar Rand)]
      (* uniq : |- (\x. P[x]) v1 /\ (\x. P[x]) v2 ==> (v1 = v2) *)
      val uniq = MP (SPEC Rand (INST_TYPE theta imp)) th
      val red = conv (fst(dest_imp(concl uniq)))
      val (V1,V2) = let val i = Term.inst theta in (i v1,i v2) end
  in
    (* swap the unreduced antecedent for its beta-reduced form *)
    GEN V1 (GEN V2 (SUBST[v |-> red] (v ==> (V1 == V2)) uniq))
  end
  handle HOL_ERR _ => raise ERR "UNIQUENESS" ""
end;
743
744(* ---------------------------------------------------------------------*)
745(* Internal function: DEPTH_FORALL_CONV                                 *)
746(*                                                                      *)
747(* DEPTH_FORALL_CONV conv `!x1...xn. tm` applies the conversion conv to *)
748(* the term tm to yield |- tm = tm', and then returns:                  *)
749(*                                                                      *)
750(*    |- (!x1...xn. tm)  =  (!x1...xn. tm')                             *)
751(*                                                                      *)
752(* ---------------------------------------------------------------------*)
753
(* Strip every leading universal quantifier, rewrite the body with conv,
   then re-quantify both sides of the resulting equation, innermost
   variable first. *)
fun DEPTH_FORALL_CONV conv tm =
   let val (bound, body) = strip_forall tm
       val body_th = conv body
   in List.foldr (fn (x, th) => FORALL_EQ x th) body_th bound
   end;
758
759(* ---------------------------------------------------------------------*)
760(* Internal function: CONJS_CONV                                        *)
761(*                                                                      *)
762(* CONJS_CONV conv `t1 /\ t2 /\ ... /\ tn` applies conv to each of the  *)
763(* n conjuncts t1,t2,...,tn and then rebuilds the conjunction from the  *)
764(* results.                                                             *)
765(*                                                                      *)
766(* ---------------------------------------------------------------------*)
767
(* Rewrite each conjunct of a right-nested conjunction with conv and
   rebuild the conjunction.  If tm is not a conjunction, or if any step of
   the attempt fails with HOL_ERR, conv is applied to tm as a whole. *)
fun CONJS_CONV conv tm =
   let val (c1, cs) = dest_conj tm
       val c1_th = conv c1
   in MK_COMB (AP_AND c1_th, CONJS_CONV conv cs)
   end handle HOL_ERR _ => conv tm;
772
773
774(* ---------------------------------------------------------------------*)
775(* Internal function: CONJS_SIMP                                        *)
776(*                                                                      *)
777(* CONJS_SIMP conv `t1 /\ t2 /\ ... /\ tn` applies conv to each of the  *)
778(* n conjuncts t1,t2,...,tn.  This should reduce each ti to `T`.  I.e.  *)
779(* executing conv ti should return |- ti = T.  The result returned by   *)
780(* CONJS_SIMP is then: |- (t1 /\ t2 /\ ... /\ tn) = T                   *)
781(*                                                                      *)
782(* ---------------------------------------------------------------------*)
783
(* T_AND_T : |- T /\ T = T *)
local val T_AND_T = CONJUNCT1 (SPEC boolSyntax.T AND_CLAUSES)
in
(* Each conjunct is expected to be reduced to T by conv; the pieces are
   glued back together and each T /\ T collapsed, so the result is
   |- (t1 /\ ... /\ tn) = T.  A non-conjunction, or any HOL_ERR in the
   attempt, falls back to conv tm. *)
fun CONJS_SIMP conv tm =
   let val (c1, cs) = dest_conj tm
       val c1_th = conv c1
       val glued = MK_COMB (AP_AND c1_th, CONJS_SIMP conv cs)
   in TRANS glued T_AND_T
   end handle HOL_ERR _ => conv tm
end;
795
796(* ---------------------------------------------------------------------*)
797(* Internal function: T_AND_CONV                                        *)
798(*                                                                      *)
799(* T_AND_CONV `T /\ t` returns |- T /\ t = t                            *)
800(*                                                                      *)
801(* ---------------------------------------------------------------------*)
802
(* T_AND : |- !t. T /\ t = t *)
local val T_AND = GEN_ALL (CONJUNCT1 (SPEC_ALL AND_CLAUSES))
in
(* T_AND_CONV (T /\ t) returns |- T /\ t = t.
   The first conjunct is checked.  On any other conjunction the conversion
   fails with HOL_ERR; previously it returned |- T /\ t = t even when the
   input was A /\ t, i.e. a theorem whose left-hand side is not the input. *)
fun T_AND_CONV tm =
   let val (c1, c2) = dest_conj tm
   in if aconv c1 boolSyntax.T then SPEC c2 T_AND
      else raise ERR "T_AND_CONV" "first conjunct is not T"
   end
end;
807
808(* ---------------------------------------------------------------------*)
809(* Internal function: GENL_T                                            *)
810(*                                                                      *)
811(* GENL_T [x1;...;xn] returns |- (!x1...xn.T) = T                       *)
812(*                                                                      *)
813(* ---------------------------------------------------------------------*)
814
(* t_eq_t : |- T = T, the answer for an empty variable list *)
local val t_eq_t = REFL T
in
(* GENL_T [x1,...,xn] proves |- (!x1...xn. T) = T by deriving each
   direction of the implication and combining them. *)
fun GENL_T vars =
   case vars of
     [] => t_eq_t
   | _ =>
       let val all_t = list_mk_forall (vars, T)
           val forward = DISCH all_t (SPECL vars (ASSUME all_t))
           val backward = DISCH T (GENL vars (ASSUME T))
       in IMP_ANTISYM_RULE forward backward
       end
end;
825
826(* ---------------------------------------------------------------------*)
827(* Internal function: SIMP_CONV                                         *)
828(*                                                                      *)
829(* SIMP_CONV is used by prove_induction_thm to simplify to `T` terms of *)
830(* the following two forms:                                             *)
831(*                                                                      *)
832(*   1: !x1...xn. (\x.T)v = (\x1...xn.T) x1 ... xn                      *)
833(*                                                                      *)
834(*   2: !x1...xn. (\x.T)v =                                             *)
835(*      (\y1...ym x1..xn. (y1 /\.../\ ym) \/ t) ((\x.T)u1)...((\x.T)um) *)
836(*                                                     x1 ... xn        *)
837(*                                                                      *)
838(* If tm, a term of one of these two forms, is the argument to SIMP_CONV*)
839(* then the theorem returned is |- tm = T.                              *)
840(* ---------------------------------------------------------------------*)
841
(* Precomputed pieces for SIMP_CONV.
   T_EQ_T : |- (T = T) = T.
   T_OR   : first conjunct of OR_CLAUSES, used below as |- !v. T \/ v = T.
   DISJ_SIMP proves |- (c1 /\ ... /\ cn) \/ d = T when every ci
   beta-reduces to T (CONJS_SIMP BETA_CONV gives |- c1 /\ ... /\ cn = T). *)
local val v = genvar Type.bool
      val eq = inst [alpha |-> bool] boolSyntax.equality
      val T_EQ_T = EQT_INTRO(REFL T)
      val T_OR = GEN v (CONJUNCT1 (SPEC v OR_CLAUSES))
      fun DISJ_SIMP tm =
         let val (disj1,disj2) = dest_disj tm
             val eqn = SYM(CONJS_SIMP BETA_CONV disj1)
         in SUBST[v |-> eqn] ((v \/ disj2) == T) (SPEC disj2 T_OR)
         end

in
(* SIMP_CONV (!xs. lhs = rhs) returns |- (!xs. lhs = rhs) = T, for the two
   term shapes described in the header comment above. *)
fun SIMP_CONV tm =
   let val (vs,(lhs,rhs)) = (I ## dest_eq) (strip_forall tm)
       (* rsimp : |- rhs = T, either directly after LIST_BETA_CONV (form 1,
          closed by REFL) or via DISJ_SIMP (form 2) *)
       val rsimp = (LIST_BETA_CONV THENC (DISJ_SIMP ORELSEC REFL)) rhs
       (* lsimp : |- $= lhs = $= lhs', lhs beta-reduced (expected to be T) *)
       and lsimp = AP_TERM eq (BETA_CONV lhs)
       (* gent : |- (!xs. T) = T *)
       and gent  = GENL_T vs
       val eqsimp = TRANS (MK_COMB(lsimp,rsimp)) T_EQ_T
   in
   TRANS (itlist FORALL_EQ vs eqsimp) gent
   end
end;
863
864(* ---------------------------------------------------------------------*)
865(* Internal function: HYP_SIMP                                          *)
866(*                                                                      *)
867(* HYP_SIMP is used by prove_induction_thm to simplify induction        *)
868(* hypotheses according to the following scheme:                        *)
869(*                                                                      *)
870(*   1: !x1...xn. P t = (\x1...xn.T) x1...xn                            *)
871(*                                                                      *)
872(*         simplifies to                                                *)
873(*                                                                      *)
874(*      !x1...xn. P t                                                   *)
875(*                                                                      *)
876(*   2: !x1...xn. P t =                                                 *)
877(*        ((\y1..ym x1..xn. y1 /\ ... /\ ym) \/ P t) v1 ... vm x1 ... xn*)
878(*                                                                      *)
879(*         simplifies to                                                *)
880(*                                                                      *)
881(*      !x1...xn. (v1 /\ ... /\ vm) ==> P t                             *)
882(*                                                                      *)
883(* ---------------------------------------------------------------------*)
884
(* Precomputed pieces for HYP_SIMP.
   EQ_T : second conjunct of EQ_CLAUSES, used as |- !v. (v = T) = v.
   R_SIMP simplifies lhs = rhs: when rhs is T it uses EQ_T; otherwise rhs
   is expected to be a disjunction b \/ lhs, rewritten via OR_IMP_THM
   (presumably |- (A = B \/ A) = (B ==> A) -- confirm against boolTheory). *)
local val v = genvar Type.bool
      val eq = inst [alpha |-> bool] boolSyntax.equality
      val EQ_T = GEN v (CONJUNCT1 (CONJUNCT2 (SPEC v EQ_CLAUSES)))
      fun R_SIMP tm =
         let val (lhs,rhs) = dest_eq tm
         in if aconv rhs T
            then SPEC lhs EQ_T
            else SPECL [lhs, fst(dest_disj rhs)] OR_IMP_THM
         end
in
(* HYP_SIMP (!xs. P t = rhs) beta-reduces rhs and then applies R_SIMP
   under the quantifiers (see the header comment above for the shapes). *)
fun HYP_SIMP tm =
   let val (vs,(lhs,rhs)) = (I##dest_eq) (strip_forall tm)
       (* eqsimp : |- (lhs = rhs) = (lhs = rhs'), rhs' the beta-reduct *)
       val eqsimp = AP_TERM (mk_comb(eq,lhs)) (LIST_BETA_CONV rhs)
       val rsimp = CONV_RULE (RAND_CONV R_SIMP) eqsimp
   in itlist FORALL_EQ vs rsimp
   end
end;
902
903(* ---------------------------------------------------------------------*)
904(* Internal function: ANTE_ALL_CONV                                     *)
905(*                                                                      *)
906(* ANTE_ALL_CONV `!x1...xn. P ==> Q` restricts the scope of as many of  *)
907(* the quantified x's as possible to the term Q.                        *)
908(* ---------------------------------------------------------------------*)
909
(* For !xs. ant ==> body, the quantified variables not free in ant are
   pushed past the implication:
     |- (!xs. ant ==> body) = (!outer. ant ==> !inner. body)
   where outer are the xs free in ant and inner the rest.  Both directions
   are proved and combined with IMP_ANTISYM_RULE. *)
fun ANTE_ALL_CONV tm =
   let val (bound, body) = strip_forall tm
       val (ant, _) = dest_imp body
       val (outer, inner) = partition (C free_in ant) bound
       (* narrowed : tm |- !outer. ant ==> !inner. body *)
       val narrowed =
           GENL outer
             (DISCH ant (GENL inner (UNDISCH (SPECL bound (ASSUME tm)))))
       val narrowed_tm = concl narrowed
       (* widened : narrowed_tm |- tm *)
       val widened =
           GENL bound
             (DISCH ant
                (SPECL inner (UNDISCH (SPECL outer (ASSUME narrowed_tm)))))
   in
   IMP_ANTISYM_RULE (DISCH tm narrowed) (DISCH narrowed_tm widened)
   end;
921
922(* ---------------------------------------------------------------------*)
923(* Internal function: CONCL_SIMP                                        *)
924(*                                                                      *)
925(* CONCL_SIMP `\x.T = P` returns: |- (\x.T = P) = (!y. P y) where y is  *)
926(* an appropriately chosen variable.                                    *)
927(* ---------------------------------------------------------------------*)
928
(* T_EQ : first conjunct of EQ_CLAUSES, used as |- !v. (T = v) = v *)
local val v = genvar Type.bool
      val T_EQ = GEN v (CONJUNCT1 (SPEC v EQ_CLAUSES))
in
(* CONCL_SIMP ((\x.T) = P) returns |- ((\x.T) = P) = (!y. P y):
   extensionality first, then beta-reduce (\x.T) y under the binder and
   drop the resulting T = ... *)
fun CONCL_SIMP tm =
   let val ext = FUN_EQ_CONV tm
       val (y, pointwise) = dest_forall (rhs (concl ext))
       val beta_left = RATOR_CONV (RAND_CONV BETA_CONV) pointwise
       val drop_T = SPEC (rhs pointwise) T_EQ
       val under = FORALL_EQ y (TRANS beta_left drop_T)
   in
   TRANS ext under
   end
end;
941
942(* ---------------------------------------------------------------------*)
943(* prove_induction_thm: prove a structural induction theorem from a type*)
944(* axiom of the form returned by define_type.                           *)
945(*                                                                      *)
946(* EXAMPLE:                                                             *)
947(*                                                                      *)
948(* Input:                                                               *)
949(*                                                                      *)
950(*    |- !x f. ?! fn. (fn[] = x) /\ (!h t. fn(CONS h t) = f(fn t)h t)   *)
951(*                                                                      *)
952(* Output:                                                              *)
953(*                                                                      *)
954(*    |- !P. P [] /\ (!t. P t ==> (!h. P(CONS h t))) ==> (!l. P l)      *)
955(*                                                                      *)
956(* ---------------------------------------------------------------------*)
957
(* Helpers for prove_induction_thm.
   gen n       : n fresh boolean genvars.
   mk_fn P ty  : for one clause !xs. fn (C args) = f ... of the recursion
                 axiom, builds the term used to instantiate f.  If no
                 variable argument of the rhs has type ty it is
                 \vars. T; otherwise it is
                 \b1..bn vars. (b1 /\ ... /\ bn) \/ P (C args)
                 with n = number of variable arguments of type ty.
   LCONV       : rewrite the antecedent of an implication.
   conv1/conv2 : simplify the antecedent; see SIMP_CONV, T_AND_CONV,
                 HYP_SIMP and ANTE_ALL_CONV. *)
local val B = Type.bool
      fun gen 0 = []
        | gen n = genvar B::gen (n-1)
      fun mk_fn P ty tm =
         let val (lhs,rhs) = dest_eq(snd(strip_forall tm))
             val c = rand lhs
             val args = snd(strip_comb rhs)
             val vars = filter is_var args
             val n = length(filter (fn t => type_of t = ty) vars)
         in if (n=0) then list_mk_abs (vars, T)
            else let val bools = gen n
                     val term = list_mk_conj bools \/ mk_comb(P,c)
                 in list_mk_abs((bools@vars),term)
                 end
         end
      val LCONV = RATOR_CONV o RAND_CONV
      val conv1 = LCONV(CONJS_SIMP SIMP_CONV) THENC T_AND_CONV
      and conv2 = CONJS_CONV (HYP_SIMP THENC TRY_CONV ANTE_ALL_CONV)
in
(* Derive a structural induction theorem from a recursion axiom
   |- !fs. ?!fn. clauses, by applying uniqueness of fn to the two
   candidate functions \v.T and P (with the range type instantiated to
   bool) and simplifying.  HOL_ERR failures are reported as ERR for
   prove_induction_thm; the raise Match below is not converted. *)
fun prove_induction_thm th =
   let val (Bvar,Body) = dest_abs(rand(snd(strip_forall(concl th))))
       (* Bvar is the unique function; its type is ty -> rty *)
       val (ty,rty) = case dest_type (type_of Bvar)
                      of (_,[ty, rty]) => (ty,rty)
                       | _ => raise Match
       val inst = INST_TYPE [rty |-> B] th
       val P = mk_primed_var("P", ty --> B)
       and v = genvar ty
       and cases = strip_conj Body
       (* uniq : |- !fs. clauses[\v.T] /\ clauses[P] ==> ((\v.T) = P) *)
       val uniq = let val (vs,tm) = strip_forall(concl inst)
                      val thm = UNIQUENESS(SPECL vs inst)
                  in GENL vs (SPECL [mk_abs(v,T), P] thm)
                  end
      val spec = SPECL (map (mk_fn P ty) cases) uniq
      val simp =  CONV_RULE (LCONV(conv1 THENC conv2)) spec
   in
     (* turn (\v.T) = P into !y. P y *)
     GEN P (CONV_RULE (RAND_CONV CONCL_SIMP) simp)
   end
   handle HOL_ERR _ => raise ERR "prove_induction_thm" ""
end;
997
998
999(* ---------------------------------------------------------------------*)
1000(* Internal function: NOT_ALL_THENC                                     *)
1001(*                                                                      *)
1002(* This conversion first moves negation inwards through an arbitrary    *)
1003(* number of nested universal quantifiers. It then applies the supplied *)
1004(* conversion to the resulting inner negation.  For example if:         *)
1005(*                                                                      *)
1006(*      conv "~tm" ---> |- ~tm = tm'                                    *)
1007(* then                                                                 *)
1008(*                                                                      *)
1009(*       NOT_ALL_THENC conv "~(!x1 ... xn. tm)"                         *)
1010(*                                                                      *)
1011(* yields:                                                              *)
1012(*                                                                      *)
1013(*       |- ~(!x1...xn.tm) = ?x1...xn.tm'                               *)
1014(* ---------------------------------------------------------------------*)
1015
(* Push the negation through one universal quantifier and recurse under
   the new existential binder; once no quantifier remains (or any step
   fails with HOL_ERR), conv is applied to the term as it stands. *)
fun NOT_ALL_THENC conv tm =
   let val push_in =
           NOT_FORALL_CONV THENC RAND_CONV (ABS_CONV (NOT_ALL_THENC conv))
   in push_in tm
   end handle HOL_ERR _ => conv tm;
1020
1021(* ---------------------------------------------------------------------*)
1022(* Internal function: BASE_CONV                                         *)
1023(*                                                                      *)
1024(* This conversion does the following simplification:                   *)
1025(*                                                                      *)
1026(*    BASE_CONV "~((\x.~tm)y)"  --->  |- ~((\x.~tm)y) = tm[y/x]         *)
1027(*                                                                      *)
1028(* ---------------------------------------------------------------------*)
1029
(* NOT_NOT : first conjunct of NOT_CLAUSES, used as |- !t. ~~t = t *)
local val NOT_NOT = CONJUNCT1 NOT_CLAUSES
      and neg = boolSyntax.negation
in
(* BASE_CONV (~((\x.~tm) y)) : beta-reduce under the outer negation, then
   cancel the double negation, giving |- ~((\x.~tm) y) = tm[y/x]. *)
fun BASE_CONV tm =
   let val reduced = BETA_CONV (dest_neg tm)
       val inner = rand (rhs (concl reduced))
       val cancel = SPEC inner NOT_NOT
       val negated = AP_TERM neg reduced
   in TRANS negated cancel
   end
end;
1039
1040(* ---------------------------------------------------------------------*)
1041(* Internal function: STEP_CONV                                         *)
1042(*                                                                      *)
1043(* This conversion does the following simplification:                   *)
1044(*                                                                      *)
(*    STEP_CONV "~(tm' ==> !x1..xn.(\x.~tm)z)"                          *)
(*                                                                      *)
(* yields:                                                              *)
(*                                                                      *)
(*   |- ~(tm' ==> !x1..xn.(\x.~tm)z) = tm' /\ ?x1..xn.tm[z/x]           *)
1050(* ---------------------------------------------------------------------*)
1051
(* v1 is not referenced below; it is still created so that the sequence of
   genvars produced at load time is unchanged. *)
local val v1 = genvar Type.bool
      and v2 = genvar Type.bool
in
(* STEP_CONV (~(ant ==> conseq)) : rewrite with NOT_IMP to ant /\ ~conseq,
   then simplify ~conseq with NOT_ALL_THENC BASE_CONV, substituting the
   result into the right conjunct. *)
fun STEP_CONV tm =
   let val (ant, conseq) = dest_imp (dest_neg tm)
       val not_imp_th = SPEC conseq (SPEC ant NOT_IMP)
       val neg_conseq_th = NOT_ALL_THENC BASE_CONV (mk_neg conseq)
       val template = mk_eq (tm, mk_conj (ant, v2))
   in
   SUBST [v2 |-> neg_conseq_th] template not_imp_th
   end
end;
1063
1064(* ---------------------------------------------------------------------*)
1065(* Internal function: NOT_IN_CONV                                       *)
1066(*                                                                      *)
1067(* This first conversion moves negation inwards through conjunction and *)
1068(* universal quantification:                                            *)
1069(*                                                                      *)
1070(*   NOT_IN_CONV  "~(!x1..xn.c1 /\ ... /\ !x1..xm.cn)"                  *)
1071(*                                                                      *)
1072(* to transform the input term into:                                    *)
1073(*                                                                      *)
1074(*   ?x1..xn. ~c1 \/ ... \/ ?x1..xm. ~cn                                *)
1075(*                                                                      *)
1076(* It then applies either BASE_CONV or STEP_CONV to each subterm ~ci.   *)
1077(* ---------------------------------------------------------------------*)
1078
(* Precomputed pieces for NOT_IN_CONV.
   DE_MORG : first conjunct of DE_MORGAN_THM, used below as
             |- !A B. ~(A /\ B) = ~A \/ ~B (the SUBST template requires it).
   cnv     : simplifier for one negated conjunct, BASE_CONV or STEP_CONV. *)
local val A = mk_var("A",Type.bool)
      val B = mk_var("B",Type.bool)
      val DE_MORG = GENL [A,B] (CONJUNCT1(SPEC_ALL DE_MORGAN_THM))
      and cnv = BASE_CONV ORELSEC STEP_CONV
      and v1 = genvar Type.bool
      and v2 = genvar Type.bool
in
(* Distribute the negation over a right-nested conjunction and simplify
   each negated conjunct with NOT_ALL_THENC cnv.  The handler covers the
   whole let: if tm is not a negated conjunction, or if simplifying either
   side fails with HOL_ERR, the whole term is handled as a single conjunct
   by NOT_ALL_THENC cnv. *)
fun NOT_IN_CONV tm =
   let val (conj1,conj2) = dest_conj(dest_neg tm)
       val thm = SPEC conj2 (SPEC conj1 DE_MORG)
       val cth = NOT_ALL_THENC cnv (mk_neg conj1)
       and csth = NOT_IN_CONV (mk_neg conj2)
   in
     SUBST[v1 |-> cth, v2 |-> csth] (tm == (v1 \/ v2)) thm
   end
   handle HOL_ERR _ => NOT_ALL_THENC cnv tm
end;
1096
1097
1098(* ---------------------------------------------------------------------*)
1099(* Internal function: STEP_SIMP                                         *)
1100(*                                                                      *)
1101(* This rule does the following simplification:                         *)
1102(*                                                                      *)
(*    STEP_SIMP "?x1..xi. tm1 /\ ?xj..xn. tm2"                          *)
(*                                                                      *)
(* yields:                                                              *)
(*                                                                      *)
(*   ?x1..xi.tm1 /\ ?xj..xn.tm2 |- ?x1..xn.tm2                          *)
(*                                                                      *)
(* If the body under the leading existentials is not a conjunction,     *)
(* the rule yields:                                                     *)
(*                                                                      *)
(*   STEP_SIMP "tm" ---> tm |- tm                                       *)
1112(* ---------------------------------------------------------------------*)
1113
(* intro_ex x th  : from A |- t derive A |- ?x. t (witness x itself).
   choose_ex x th : replace the single hypothesis h of th by ?x. h. *)
local fun intro_ex x th = EXISTS (mk_exists (x, concl th), x) th
      fun choose_ex x th = CHOOSE (x, ASSUME (mk_exists (x, hd (hyp th)))) th
in
(* STEP_SIMP (?xs. c1 /\ c2) returns (?xs. c1 /\ c2) |- ?xs. c2, rebuilding
   the quantifiers innermost first.  If the body is not a conjunction the
   HOL_ERR is caught and tm |- tm is returned. *)
fun STEP_SIMP tm =
   let val (bound, body) = strip_exists tm
       val right = CONJUNCT2 (ASSUME body)
       fun requantify (x, th) = choose_ex x (intro_ex x th)
   in List.foldr requantify right bound
   end handle HOL_ERR _ => ASSUME tm
end;
1122
1123
1124(* ---------------------------------------------------------------------*)
1125(* Internal function: DISJ_CHAIN                                        *)
1126(*                                                                      *)
1127(* Suppose that                                                         *)
1128(*                                                                      *)
1129(*    rule "tmi"  --->   tmi |- tmi'            (for 1 <= i <= n)       *)
1130(*                                                                      *)
1131(* then:                                                                *)
1132(*                                                                      *)
1133(*       |- tm1 \/ ... \/ tmn                                           *)
1134(*    ---------------------------   DISJ_CHAIN rule                     *)
1135(*      |- tm1' \/ ... \/ tmn'                                          *)
1136(* ---------------------------------------------------------------------*)
1137
(* DISJS_CHAIN rule (|- t1 \/ ... \/ tn) maps rule (ti |- ti') over every
   disjunct of a right-nested disjunction and rejoins the results with
   DISJ_CASES.  The handler covers the whole inner let: if the conclusion
   is not a disjunction, or if rule or the recursive call fails with
   HOL_ERR, rule is applied to the entire conclusion instead. *)
fun DISJS_CHAIN rule th =
   let val concl_th = concl th
   in let val (disj1,disj2) = dest_disj concl_th
          val i1 = rule disj1
          (* i2 has hypothesis disj2 so it can feed the second DISJ_CASES arm *)
          and i2 = DISJS_CHAIN rule (ASSUME disj2)
      in DISJ_CASES th (DISJ1 i1 (concl i2)) (DISJ2 (concl i1) i2)
      end
      handle HOL_ERR _ => MP (DISCH concl_th (rule concl_th)) th
   end
1147
1148
1149(* --------------------------------------------------------------------- *)
1150(* prove_cases_thm: prove a cases or "exhaustion" theorem for a concrete *)
1151(* recursive type from a structural induction theorem of the form        *)
1152(* returned by prove_induction_thm.                                      *)
1153(*                                                                       *)
1154(* EXAMPLE:                                                              *)
1155(*                                                                       *)
1156(* Input:                                                                *)
1157(*                                                                       *)
1158(*    |- !P. P[] /\ (!t. P t ==> (!h. P(CONS h t))) ==> (!l. P l)        *)
1159(*                                                                       *)
1160(* Output:                                                               *)
1161(*                                                                       *)
1162(*    |- !l. (l = []) \/ (?t h. l = CONS h t)                            *)
1163(*                                                                       *)
1164(* --------------------------------------------------------------------- *)
1165
(* EXISTS_EQUATION (l = r) th, for a variable l and a theorem
   A u {l = r} |- t, proves A |- ?l. t by instantiating JRH_INDUCT_UTIL
   with the predicate \l. t and the witness r; the assumption l = r is
   discharged. *)
fun EXISTS_EQUATION tm th =
  let val (lvar, witness) = boolSyntax.dest_eq tm
      val pred = mk_abs (lvar, concl th)
      val unbeta = SYM (BETA_CONV (mk_comb (pred, lvar)))
      val util = ISPECL [pred, witness] JRH_INDUCT_UTIL
      val premise = GEN lvar (DISCH tm (EQ_MP unbeta th))
  in
    MP util premise
  end
1176
1177
(* Helpers for prove_cases_thm; only prove_cases_thm is visible outside
   this local block. *)
local
  (* margs n s avoid tys: one variable per type in tys, named s followed by
     n, n+1, ...; each name goes through variant so the variable differs
     from avoid and from those generated before it. *)
  fun margs n s avoid [] = []
    | margs n s avoid (h::t) =
      let val v = variant avoid
                          (mk_var(s^(Int.toString n),h))
      in
        v::margs (n + 1) s (v::avoid) t
      end
  (* make_args s avoid tys: as margs, except that a single type gets the
     unsuffixed name s. *)
  fun make_args s avoid tys =
      if length tys = 1 then
        [variant avoid (mk_var(s, hd tys))]
        handle _ => raise ERR "make_args" ""
      else margs 0 s avoid tys


  (* mk_exclauses x [p1,...,pn] builds the predicate
       \x. (?vs1. x = p1) \/ ... \/ (?vsn. x = pn)
     where vsi are the arguments of pattern pi (list_mk_exists needs these
     to be variables). *)
  fun mk_exclauses x rpats = let
    (* order of existentially quantified variables is same order
       as they appear as arguments to the constructor *)
    val xts = map (fn t =>
                      list_mk_exists(#2 (strip_comb t), mk_eq(x,t))) rpats
  in
    mk_abs(x,list_mk_disj xts)
  end
  (* prove_triv proves ?vs. C a1 .. an = C b1 .. bn when both sides have the
     same head C.  Each bi is eliminated with EXISTS_EQUATION using ai as
     the witness, so the bi are expected to be the existentially bound
     variables.  Raises HOL_ERR if the heads differ. *)
  fun prove_triv tm = let
    val (evs,bod) = strip_exists tm
    val (l,r) = dest_eq bod
    val (lf,largs) = strip_comb l
    val (rf,rargs) = strip_comb r
    val _ = (aconv lf rf) orelse raise ERR "prove_triv" ""
    (* ths : the assumptions bi = ai *)
    val ths = map (ASSUME o mk_eq) (zip rargs largs)
    (* th1 : {bi = ai} |- C b1 .. bn = C a1 .. an *)
    val th1 = rev_itlist (C (curry MK_COMB)) ths (REFL lf)
  in
    itlist EXISTS_EQUATION (map concl ths) (SYM th1)
  end
  (* prove_disj proves a right-nested disjunction via the first disjunct
     that prove_triv succeeds on. *)
  fun prove_disj tm =
      if is_disj tm then let
          val (l,r) = dest_disj tm
        in
          DISJ1 (prove_triv l) r
          handle HOL_ERR _ => DISJ2 l (prove_disj r)
        end
      else prove_triv tm
  (* prove_eclause proves a clause !avs. body or !avs. hyp ==> body, using
     prove_disj on body and discharging hyp when present. *)
  fun prove_eclause tm = let
    val (avs,bod) = strip_forall tm
    val ctm = if is_imp bod then rand bod else bod
    val cth = prove_disj ctm
    val dth = if is_imp bod then DISCH (lhand bod) cth else cth
  in
    GENL avs dth
  end
  (* CONJUNCTS_CONV c applies c to every leaf of a conjunction tree. *)
  fun CONJUNCTS_CONV c tm =
      if is_conj tm then BINOP_CONV (CONJUNCTS_CONV c) tm else c tm
  (* prove_cases_thm0: instantiate every induction predicate with the
     mk_exclauses predicate for its constructor patterns, beta-reduce, and
     discharge all hypotheses with prove_eclause.  The result is the
     conclusion of the induction theorem with each predicate replaced by
     its case disjunction. *)
  fun prove_cases_thm0 th = let
    val (_,bod) = strip_forall(concl th)
    (* clause conclusions P (C args), one per hypothesis conjunct *)
    val cls = map (snd o strip_forall) (conjuncts(lhand bod))
    val pats = map (fn t => if is_imp t then rand t else t) cls
    val spats = map dest_comb pats
    (* distinct predicates, and for each one the list of its patterns *)
    val preds = itlist (op_insert aconv o fst) spats []
    val rpatlist = map
                     (fn pr => map snd (filter (fn (p,x) => aconv p pr) spats))
                     preds
    (* one fresh x per predicate, typed like its patterns *)
    val xs = make_args "x" (free_varsl pats) (map (type_of o hd) rpatlist)
    val xpreds = map2 mk_exclauses xs rpatlist
    (* rename the bound variable v of each conclusion conjunct !v. ... to
       a doubled name (vv); presumably to keep it apart from the names
       introduced above -- NOTE(review): confirm intent *)
    fun double_name th = let
      fun rename t = let
        val (v, _) = dest_forall t
        val (vnm,ty) = dest_var v
      in
        RAND_CONV (ALPHA_CONV (mk_var(vnm^vnm, ty))) t
      end
    in
      DISCH_ALL (CONV_RULE (CONJUNCTS_CONV rename) (UNDISCH th))
    end
    val ith = BETA_RULE
                (Thm.INST (ListPair.map (fn (x,p) => p |-> x) (xpreds, preds))
                          (double_name (SPEC_ALL th)))
    val eclauses = conjuncts(fst(dest_imp(concl ith)))
  in
    MP ith (end_itlist CONJ (map prove_eclause eclauses))
  end
in
(* prove_cases_thm: first normalise each hypothesis conjunct of the
   induction theorem with RIGHT_IMP_FORALL_CONV (moving universal
   quantifiers out of implication conclusions), derive the cases theorems
   with prove_cases_thm0, and keep only those conjuncts whose outermost
   quantified variable has a type listed by doms_of_ind_thm ind. *)
fun prove_cases_thm ind0 = let
  val ind = CONV_RULE
              (STRIP_QUANT_CONV (RATOR_CONV (RAND_CONV
                (CONJUNCTS_CONV (REDEPTH_CONV RIGHT_IMP_FORALL_CONV))))) ind0
  val basic_thm = prove_cases_thm0 ind
  val oktypes = doms_of_ind_thm ind
in
  List.filter
    (fn th => Lib.mem (type_of (fst(dest_forall (concl th)))) oktypes)
    (CONJUNCTS basic_thm)
end

end; (* prove_cases_thm *)
1272
1273(*---------------------------------------------------------------------------
1274    Proving case congruence:
1275
1276     |- (M = M') /\
1277        (!x1,...,xk. (M' = C1 x1..xk) ==> (f1 x1..xk = f1' x1..xk))
1278         /\ ... /\
1279        (!x1,...,xj. (M' = Cn x1..xj) ==> (fn x1..xj = fn' x1..xj))
1280        ==>
1281       (ty_case f1..fn M = ty_case f1'..fn' M')
1282
1283 ---------------------------------------------------------------------------*)
1284
(* Build (as a term, not a theorem) the case congruence statement
     (M = M') /\ (!xs. (M' = C1 xs) ==> (f1 xs = f1' xs)) /\ ...
       ==> (case M f1 .. fn = case M' f1' .. fn')
   from the case-constant definition case_def.  The first argument of the
   case constant is taken as the scrutinee; the remaining arguments are
   the per-constructor functions, each read from the head of a clause's
   right-hand side. *)
fun case_cong_term case_def =
 let val clauses = (strip_conj o concl) case_def
     val clause1 = Lib.trye hd clauses
     val left = (#1 o dest_eq o #2 o strip_forall) clause1
     val (c, allargs) = strip_comb left
     val (tyarg, nonty_args) =
         case allargs of h::t => (h,t)
                       | _ => raise Fail "case_cong_term: should never happen"
     val ty = type_of tyarg
     val allvars = all_varsl clauses
     (* M and M' are fresh and, because M is added to the avoid list,
        distinct from each other *)
     val M = variant allvars (mk_var("M", ty))
     val M' = variant (M::allvars) (mk_var("M",ty))
     val lhsM = list_mk_comb(c, M::nonty_args)
     (* one antecedent clause per case_def clause, plus the primed
        function variable f' it introduces *)
     fun mk_clause clause =
       let val (lhs,rhs) = (dest_eq o #2 o strip_forall) clause
           val func = (#1 o strip_comb) rhs
           val (Name,Ty) = dest_var func
           (* NOTE(review): f' is only made distinct from allvars, not from
              M, M' or other f'; a clash would only be in name (types
              differ), but confirm this is intended *)
           val func' = variant allvars (mk_var(Name^"'", Ty))
           val capp = hd (#2 (strip_comb lhs))
           val (constr,xbar) = strip_vars capp
       in (func',
           list_mk_forall
           (xbar, mk_imp(mk_eq(M',capp),
                         mk_eq(list_mk_comb(func,xbar),
                               list_mk_comb(func',xbar)))))
       end
     val (funcs',clauses') = unzip (map mk_clause clauses)
 in
 mk_imp(list_mk_conj(mk_eq(M, M')::clauses'),
        mk_eq(lhsM, list_mk_comb(c,M'::funcs')))
 end;
1316
1317(*---------------------------------------------------------------------------*
1318 *                                                                           *
1319 *        A, v = M[x1,...,xn] |- N                                           *
1320 *  ------------------------------------------                               *
1321 *     A, ?x1...xn. v = M[x1,...,xn] |- N                                    *
1322 *                                                                           *
1323 *---------------------------------------------------------------------------*)
1324
(* thm must have exactly one equational hypothesis v = M[x1..xn]; each
   variable of vlist (mapped through theta when it has an entry there) is
   existentially quantified in that hypothesis, innermost last, using
   CHOOSE.  raise Match if the hypothesis is not unique. *)
fun EQ_EXISTS_LINTRO (thm,(vlist,theta)) =
  let val veq =
        case filter (can dest_eq) (hyp thm) of
          [only] => only
        | _ => raise Match
      fun renamed v = Option.getOpt (subst_assoc (aconv v) theta, v)
      fun quantify (v, (hyp_tm, th)) =
        let val w = renamed v
            val ex_tm = mk_exists (w, hyp_tm)
        in (ex_tm, CHOOSE (w, ASSUME ex_tm) th)
        end
      val (_, result) = List.foldr quantify (veq, thm) vlist
  in result
  end;
1338
1339
(* A case definition is in the expected form when, in clause order, the
   head of each clause's right-hand side is the corresponding function
   argument of the case constant (its arguments after the first). *)
fun OKform case_def =
  let val clauses = strip_conj (concl case_def)
      val first_clause = Lib.trye hd clauses
      val (case_app, _) = dest_eq (#2 (strip_forall first_clause))
      val fn_args = tl (#2 (strip_comb case_app))
      fun head_of_rhs cl = #1 (strip_comb (rhs (#2 (strip_forall cl))))
      val heads = map head_of_rhs clauses
  in
     List.all (fn (x, y) => aconv x y) (zip fn_args heads)
  end
1351
(* case_cong_thm nchotomy case_def proves the congruence goal built by
   case_cong_term case_def, generalised over M, M' and the case-function
   variables of the first clause of case_def.  The proof splits on M'
   with nchotomy.  For each constructor it chains the matching case_def
   clause through the corresponding antecedent of the goal, and then
   existentially quantifies that constructor's assumption
   (EQ_EXISTS_LINTRO) so that DISJ_CASESL can discharge it.
   Raises HOL_ERR "case_cong_thm" if case_def fails OKform or any step of
   the construction fails. *)
fun case_cong_thm nchotomy case_def =
 let open Psyntax
     val _ = assert OKform case_def
     val clause1 =
       let val c = concl case_def in fst(dest_conj c) handle HOL_ERR _ => c end
     (* V: the case-function arguments on the first clause's left side *)
     val V = tl (snd (strip_comb (lhs (#2 (strip_forall clause1)))))
     val gl = case_cong_term case_def
     val (ant,conseq) = dest_imp gl
     (* imps: M = M' first, then one antecedent per constructor *)
     val imps = CONJUNCTS (ASSUME ant)
     val M_eq_M' = hd imps
     val (M, M') = dest_eq (concl M_eq_M')
     (* Antecedent of a (possibly quantified) implication, or the term itself
        when it is not one.  Only dest_imp can fail here, and it raises
        HOL_ERR; a wildcard handler would also swallow Interrupt. *)
     fun get_asm tm = (fst o dest_imp o #2 o strip_forall) tm
                      handle HOL_ERR _ => tm
     val case_assms = map (ASSUME o get_asm o concl) imps
     val (lconseq, rconseq) = dest_eq conseq
     (* rewrite M to M' in the left-hand side of the goal's conclusion *)
     val lconseq_thm = SUBST_CONV [M |-> M_eq_M'] lconseq lconseq
     val nchotomy' = ISPEC M' nchotomy
     (* per nchotomy disjunct: its existential variables and the term on
        the right of its equation *)
     val disjrl = map ((I##rhs) o strip_exists) (strip_disj (concl nchotomy'))
     val V' = tl(snd(strip_comb rconseq))
     val theta = map2 (fn v => fn v' => {redex=v,residue=v'}) V V'
     (* For one constructor: icase_thm is the assumed antecedent of iimp,
        and case_def_clause is the matching clause of case_def.  Returns
        case M' V = case M' V' (under assumptions) together with the
        nchotomy variables and the term substitution that EQ_EXISTS_LINTRO
        needs. *)
     fun zot (icase_thm, case_def_clause) (iimp,(vlist,disjrhs)) =
         let
           val c = case_def_clause |> concl |> lhs |> strip_comb |> #1
           fun AP_THMl tl th = List.foldl (fn (v,th) => AP_THM th v) th tl
           val lth =
                 icase_thm |> AP_TERM c |> AP_THMl V |> C TRANS case_def_clause
           (* uses the outer theta (V |-> V'); the inner theta below is
              bound after this point *)
           val rth = TRANS (icase_thm |> AP_TERM c |> AP_THMl V')
                           (INST theta case_def_clause)
           val theta = Term.match_term disjrhs
                     ((rhs o fst o dest_imp o #2 o strip_forall o concl) iimp)
           val th = MATCH_MP iimp icase_thm
           val th1 = TRANS lth th
         in
           (TRANS th1 (SYM rth), (vlist, #1 theta))
         end
     val thm_substs = map2 zot
                       (zip (Lib.trye tl case_assms)
                            (map SPEC_ALL (CONJUNCTS case_def)))
                       (zip (Lib.trye tl imps) disjrl)
     val aag = map (TRANS lconseq_thm o EQ_EXISTS_LINTRO) thm_substs
 in
   GENL (M::M'::V) (DISCH_ALL (DISJ_CASESL nchotomy' aag))
 end
 handle HOL_ERR _ => raise ERR "case_cong_thm" "construction failed";
1396
1397
1398
(* Local replacements for Conv.LHS_CONV and Conv.RHS_CONV.  The library
   versions first check that the term is an equation; these skip that
   check and simply descend structurally.  LHS_CONV c applies c to the
   operand of the operator (the l in ((=) l) r), and RHS_CONV c applies it
   to the final operand (the r). *)
fun LHS_CONV c = RATOR_CONV (RAND_CONV c)
fun RHS_CONV c = RAND_CONV c
1404
1405
1406(* =====================================================================*)
1407(* PROOF THAT CONSTRUCTORS OF RECURSIVE TYPES ARE ONE-TO-ONE            *)
1408(* =====================================================================*)
1409
1410(* ---------------------------------------------------------------------*)
1411(* Internal function: list_variant                                      *)
1412(*                                                                      *)
1413(* makes variants of the variables in l2 such that they are all not in  *)
1414(* l1 and are all different.                                            *)
1415(* ---------------------------------------------------------------------*)
1416
(* list_variant avoid vars renames each variable of vars, left to right, so
   that it differs from everything in avoid and from every variable already
   produced; the results are returned in the order of vars. *)
fun list_variant l1 vars =
  let
    fun step (v, (seen, acc)) =
      let val fresh = variant seen v
      in (fresh :: seen, fresh :: acc)
      end
  in
    List.rev (#2 (List.foldl step (l1, []) vars))
  end;
1422
(* mk_subst2 [a1,...,an] [b1,...,bn] is the substitution
   [b1 |-> a1, ..., bn |-> an]: each element of the SECOND list is mapped
   to the corresponding element of the first.  Raises Match when the two
   lists have different lengths. *)
fun mk_subst2 l1 l2 =
  if length l1 <> length l2 then raise Match
  else ListPair.map (fn (a, b) => b |-> a) (l1, l2);
1426
1427
1428(* ----------------------------------------------------------------------*)
1429(* Internal function: prove_const_one_one.                               *)
1430(*                                                                       *)
1431(* This function proves that a single constructor of a recursive type is *)
1432(* one-to-one (it is called once for each appropriate constructor). The  *)
1433(* theorem input, th, is the characterizing theorem for the recursive    *)
1434(* type in question.  The term, tm, is the defining equation for the     *)
1435(* constructor in question, taken from the body of the theorem th.       *)
1436(*                                                                       *)
1437(* For example, if:                                                      *)
1438(*                                                                       *)
1439(*  th = |- !x f. ?! fn. (fn[] = x) /\ (!h t. fn(CONS h t) = f(fn t)h t) *)
1440(*                                                                       *)
1441(* and                                                                   *)
1442(*                                                                       *)
1443(*  tm = "!h t. fn(CONS h t) = f(fn t)h t"                               *)
1444(*                                                                       *)
1445(* then prove_const_one_one th tm yields:                                *)
1446(*                                                                       *)
1447(*  |- !h t h' t'. (CONS h t = CONS h' t') = (h = h') /\ (t = t')        *)
1448(*                                                                       *)
1449(* ----------------------------------------------------------------------*)
1450
1451(* Basic strategy is to use a function
1452      f h' t' (C h t) = (h = h') /\ (t = t')
1453   Then, if we assume
1454               C h t = C h' t'
1455       f h t (C h t) = f h t (C h' t')
1456  (h = h) /\ (t = t) = (h = h') /\ (t = t')
1457                   T = (h = h') /\ (t = t')
1458  so
1459        (C h t = C h' t') ==> (h = h') /\ (t = t')
1460  in the other direction, we just rewrite (C h t) with the equalities and
1461  get the desired equation.
1462*)
1463
fun prove_const_one_one th tm = let
  (* vs: the constructor's argument variables; lhs is  fn (C vs) *)
  val (vs,(lhs,_)) = (I ## dest_eq)(strip_forall tm)
  val C = rand lhs
  (* funtype = ty(v1) -> ... -> ty(vn) -> ty(C vs) -> bool *)
  val funtype =
    List.foldr (fn (tm, ty) => Type.-->(type_of tm, ty))
    (Type.-->(type_of C, Type.bool)) vs
  val f = genvar funtype
  (* renamed copies of vs, distinct from vs and from each other *)
  val vvs = list_variant vs vs
  (* (v1 = v1') /\ ... /\ (vn = vn') *)
  val fn_body = list_mk_conj(ListPair.map op== (vs, vvs))
  val f_ap_vs = list_mk_comb(f, vs)
  (* C' is the constructor applied to the renamed variables vvs *)
  val C' = subst (mk_subst2 vvs vs) C
  (* !vs vvs. f vs (C vvs) = (v1 = v1') /\ ... /\ (vn = vn') *)
  val eqn =
    list_mk_forall(vs @ vvs, mk_comb(f_ap_vs, C') == fn_body)
  val fn_exists_thm = prove_recursive_functions_exist th eqn
  (* work under the assumption that f satisfies eqn; CHOOSE below
     discharges it *)
  val eqn_thm = ASSUME (snd (dest_exists (concl fn_exists_thm)))
  val C_eq_C'_t = C == C'
  val C_eq_C' = ASSUME C_eq_C'_t
  (* C vs = C vvs |- f vs (C vs) = f vs (C vvs) *)
  val fC_eq_fC' = AP_TERM f_ap_vs C_eq_C'
  (* unfold both sides with the defining equation of f, giving
     (vs = vs, componentwise) = (vs = vvs, componentwise) *)
  val expandedfs = CONV_RULE (LHS_CONV (REWR_CONV eqn_thm) THENC
                              RHS_CONV (REWR_CONV eqn_thm)) fC_eq_fC'
  (* REWRITE_RULE [] is expected to reduce the reflexive left side to T and
     then T = X to X, leaving (C vs = C vvs) ==> fn_body *)
  val imp1 =
    CHOOSE(f, fn_exists_thm) (DISCH C_eq_C'_t (REWRITE_RULE [] expandedfs))

  (* converse: rewriting C vs with the assumed equations vi = vi' *)
  val eqns = CONJUNCTS (ASSUME fn_body)
  val rewritten = REWRITE_CONV eqns C
  val imp2 = DISCH fn_body rewritten
in
  GENL vs (GENL vvs (IMP_ANTISYM_RULE imp1 imp2))
end
1493
1494(* ----------------------------------------------------------------------*)
1495(* prove_constructors_one_one : prove that the constructors of a given   *)
1496(* concrete recursive type are one-to-one. The input is a theorem of the *)
1497(* form returned by define_type.                                         *)
1498(*                                                                       *)
1499(* EXAMPLE:                                                              *)
1500(*                                                                       *)
1501(* Input:                                                                *)
1502(*                                                                       *)
1503(*    |- !x f. ?! fn. (fn[] = x) /\ (!h t. fn(CONS h t) = f(fn t)h t)    *)
1504(*                                                                       *)
1505(* Output:                                                               *)
1506(*                                                                       *)
1507(*    |- !h t h' t'. (CONS h t = CONS h' t') = (h = h') /\ (t = t')      *)
1508(* ----------------------------------------------------------------------*)
1509
1510local
1511  (* given an equivalence relation R, partition a list into a list of lists
1512     such that everything in each list is related to each other *)
1513  (* preserves the order of the elements within each partition with
1514     respect to the order they were given in the original list *)
1515  fun partition R l = let
1516    fun partition0 parts [] = parts
1517      | partition0 parts (x::xs) = let
1518          fun srch_parts [] = [[x]]
1519            | srch_parts (p::ps) = if R x (hd p) then (x::p)::ps
1520                                   else p::(srch_parts ps)
1521        in
1522          partition0 (srch_parts parts) xs
1523        end
1524  in
1525    map List.rev (partition0 [] l)
1526  end
1527in
1528
1529  fun prove_constructors_one_one th = let
1530    val all_eqns =
1531      strip_conj (snd (strip_exists(snd(strip_forall(concl th)))))
1532    val axtypes = doms_of_tyaxiom th
1533    fun eqn_type eq = type_of (rand (lhs (#2 (strip_forall eq))))
1534    fun same_domain eq1 eq2 = eqn_type eq1 = eqn_type eq2
1535    fun prove_c11_for_type eqns = let
1536      val funs =
1537        List.filter (fn tm => is_comb(rand(lhs(snd(strip_forall tm)))))
1538        eqns
1539    in
1540      if null funs then NONE
1541      else
1542        SOME (LIST_CONJ (map (prove_const_one_one th) funs))
1543        handle HOL_ERR _ =>
1544          raise ERR "prove_constructors_one_one" ""
1545    end
1546    fun maybe_prove eqns =
1547      if Lib.mem (eqn_type (hd eqns)) axtypes then
1548        SOME (prove_c11_for_type eqns)
1549      else NONE
1550  in
1551    List.mapPartial maybe_prove (partition same_domain all_eqns)
1552  end
1553
1554
1555(* =====================================================================*)
1556(* DISTINCTNESS OF VALUES FOR EACH CONSTRUCTOR                          *)
1557(* =====================================================================*)
1558
1559(* ---------------------------------------------------------------------*)
1560(* prove_constructors_distinct : prove that the constructors of a given *)
1561(* recursive type yield distinct (non-equal) values.                    *)
1562(*                                                                      *)
1563(* EXAMPLE:                                                             *)
1564(*                                                                      *)
1565(* Input:                                                               *)
1566(*                                                                      *)
1567(*    |- !x f. ?! fn. (fn[] = x) /\ (!h t. fn(CONS h t) = f(fn t)h t)   *)
1568(*                                                                      *)
1569(* Output:                                                              *)
1570(*                                                                      *)
1571(*    |- !h t. ~([] = CONS h t)                                         *)
1572(* ---------------------------------------------------------------------*)
1573
(* Basic strategy is to define a function over the type such that
      f (C1 ...) = 0
      f (C2 ...) = 1
      f (C3 ...) = 2
      ...
      f (Cn ...) = n - 1
   However, we want to do this without using numbers.  So, we encode the
   numbers on the RHS above as functions over booleans.  In particular,
   the type of the function will be
      bool ^ k -> bool        where k = ceiling (log2 n)
   The encoding of the function will be such that it is true iff the
   arguments form the encoding of the number it is supposed to represent.
   If we have 10 constructors, k will be 4, and the encoding of 5 will
   be
       \b4 b3 b2 b1. b4 /\ ~b3 /\ b2 /\ ~b1
   The encoding is LSB to the left: the first bound variable (b4 here)
   carries the least significant bit, so the pattern T, F, T, F read from
   the left denotes 1 + 4 = 5.
1590
1591   When function f is defined, it is then easy to distinguish any
1592   two constructors Ci and Cj.
1593       Assume           (Ci xn) = (Cj yn)
       then           f (Ci xn) = f (Cj yn)            (Leibniz)
1595       so     f (Ci xn) [| i |] = f (Cj yn) [| i |]      (ditto)
1596   But f is constructed in such a way that the term on the left will be
1597   true, while that on the right will be false.  We derive a contradiction
1598   and conclude that the original assumption was false.
1599*)
1600
local
  (* b0 .. b9 are built once and shared; other indices are built on demand *)
  val cached_bvars =
      Vector.tabulate (10, fn i => mk_var ("b" ^ Int.toString i, Type.bool))
in
  (* bn n is the boolean variable whose name is b followed by n *)
  fun bn n =
    if n >= 0 andalso n < Vector.length cached_bvars then
      Vector.sub (cached_bvars, n)
    else
      mk_var ("b" ^ Int.toString n, Type.bool)
end
1627
1628
(* encode bv numbits n
     returns numbits literals describing the low numbits bits of n, least
     significant bit first.  Literal k (counting from 0) is built from an
     atom: the boolean variable b(numbits - k) when bv is true, or the
     constant T when bv is false.  The atom is negated exactly when the
     corresponding bit of n is 0.  The result is [] when numbits <= 0.
   encode true nb n
     is used to produce the body of the function encoding n.
   encode false nb n
     is used to produce the arguments that the functions are applied to;
     every element is then either T or ~T.
*)
fun encode bv numbits n =
  if numbits > 0 then
    let
      val atom = if bv then bn numbits else T
      val literal = if n mod 2 = 1 then atom else mk_neg atom
    in
      literal :: encode bv (numbits - 1) (n div 2)
    end
  else []
1646
1647(*
1648  mk_num generates the function corresponding to number n.  The abstraction
1649  will have numbits bound variables.
1650*)
1651fun mk_num numbits n = let
1652  val vars = List.tabulate(numbits, (fn n => bn (n + 1)))
1653in
1654  list_mk_abs(List.rev vars, list_mk_conj(encode true numbits n))
1655end
1656
(* rounded_log n = ceiling (log2 n) for n >= 1: the number of bits needed
   to give each of n values a distinct code.  Returns 0 when n <= 1. *)
fun rounded_log n =
  let
    fun loop (m, bits) =
      if m <= 1 then bits else loop ((m + 1) div 2, bits + 1)
  in
    loop (n, 0)
  end
1659
(* RATORn_CONV n c applies c after descending n times into the operator
   (rator) of an application; with n <= 0 it is just c. *)
fun RATORn_CONV n c t =
  let
    fun wrap (k, cnv) = if k <= 0 then cnv else wrap (k - 1, RATOR_CONV cnv)
  in
    wrap (n, c) t
  end
1662
(* nBETA_CONV dpth n chains n beta-reduction steps.  Step i (from 1)
   applies BETA_CONV at rator depth dpth - i.  With n <= 0 it is REFL. *)
fun nBETA_CONV dpth n =
  if n > 0 then
    let
      val here = RATORn_CONV (dpth - 1) BETA_CONV
      val rest = nBETA_CONV (dpth - 1) (n - 1)
    in
      here THENC rest
    end
  else REFL
1667
(* Small boolean rewrite lemmas used by the distinctness proofs below.
   In the first three, the free variable b1 (bn 1) is universally
   quantified by gen_all; the comments write it as x. *)
(* !x. ~T /\ x = ~T       -- used by simp_conjs to collapse at a ~T *)
val notT_and = prove(gen_all ((mk_neg T /\ bn 1) == mk_neg T),
                              REWRITE_TAC []);
(* !x. ~~T /\ x = x       -- used by simp_conjs to drop a leading ~~T *)
val notnotT_and = prove(gen_all (((mk_neg (mk_neg T)) /\ bn 1) == bn 1),
                        REWRITE_TAC []);
(* !x. T /\ x = x         -- used by simp_conjs to drop a leading T *)
val T_and = prove(gen_all (T /\ bn 1 == bn 1), REWRITE_TAC []);
(* (T = ~T) = F           -- used by prove_ineq to reach a contradiction *)
val T_eqF = prove((T == mk_neg T) == F, REWRITE_TAC []);
(* ~~T = T                -- used by to_true *)
val notnotT = prove(mk_neg (mk_neg T) == T, REWRITE_TAC []);
(* ~T = F                 -- NOTE(review): not referenced in this part of
                             the file; confirm whether it is still needed *)
val notT = prove(mk_neg T == F, REWRITE_TAC []);
1682
(* A special purpose conv that walks along a right-nested conjunction
   whose conjuncts are T, ~T or ~~T, reducing it to a single atom:
     - a leading T (T_and) or ~~T (notnotT_and) is dropped and the walk
       continues on the remaining conjunction;
     - a leading ~T collapses the whole conjunction to ~T (notT_and).
   A term that is not a conjunction is returned unchanged (REFL).
   Might be possible to improve it by looking for an instance of ~T, and
   then doing the two rewrites required to push this to the top. *)
fun simp_conjs t =
  if not (is_conj t) then REFL t
  else
    let
      val c1 = #1 (dest_conj t)
    in
      if not (is_neg c1) then
        (REWR_CONV T_and THENC simp_conjs) t
      else if is_neg (dest_neg c1) then
        (REWR_CONV notnotT_and THENC simp_conjs) t
      else
        REWR_CONV notT_and t
    end
1700
(* A negated term is rewritten with notnotT, so it must be ~~T (becoming
   T); any non-negated term is returned unchanged. *)
fun to_true t =
  let val cnv = if is_neg t then REWR_CONV notnotT else REFL
  in cnv t end
1702
(* prove_ineq nb f fc1 fc2 c1 c20 c1n proves  |- !vars. ~(c1 = c2),  where
   c2 is c20 with its argument variables renamed away from those of c1.
     nb   number of bits in the encoding
     f    the discriminating function variable
     fc1  defining equation of f at c1 (presumably f c1 = mk_num nb i)
     fc2  defining equation of f at c20 (presumably f c20 = mk_num nb j)
     c1n  the T/~T argument list encoding c1's number (encode false nb i)
   From c1 = c2 it derives f c1 c1n = f c2 c1n, unfolds both sides with
   fc1 and fc2, beta-reduces, and simplifies to T = ~T, which is F. *)
fun prove_ineq nb f fc1 fc2 c1 c20 c1n = let
  val c1_vars = #2 (strip_comb c1)
  val (c2_t, c20_vars) = strip_comb c20
  (* rename c20's arguments so that they cannot clash with c1's *)
  val c2_vars = list_variant c1_vars c20_vars
  val c2 = list_mk_comb (c2_t, c2_vars)
  val c1c2_eqt = c1 == c2
  val c1_eq_c2 = ASSUME c1c2_eqt
  val fc1_eq_fc2 = AP_TERM f c1_eq_c2
  fun fold (arg, thm) = AP_THM thm arg
  (* apply both sides to the encoding of c1's number *)
  val fc1_args_eq_fc2_args = List.foldl fold fc1_eq_fc2 c1n
  (* RATORn_CONV nb reaches  f c1  (resp.  f c2) beneath the nb arguments *)
  val expand_left =
    CONV_RULE (LHS_CONV (RATORn_CONV nb (REWR_CONV fc1))) fc1_args_eq_fc2_args
  val expand_right =
    CONV_RULE (RHS_CONV (RATORn_CONV nb (REWR_CONV fc2))) expand_left
  val beta_left = CONV_RULE (LHS_CONV (nBETA_CONV nb nb)) expand_right
  val beta_right = CONV_RULE (RHS_CONV (nBETA_CONV nb nb)) beta_left
  (* the left side must reduce to T (every literal matches), the right
     side to ~T (some literal mismatches), as T_eqF requires *)
  val result0 =
    CONV_RULE (LHS_CONV (simp_conjs THENC to_true) THENC RHS_CONV simp_conjs)
    beta_right
  (* c1 = c2 ==> F *)
  val result1 = DISCH c1c2_eqt (EQ_MP T_eqF result0)
  (* IMP_F turns (t ==> F) into ~t *)
  val result = GEN_ALL (MATCH_MP IMP_F result1)
in
  result
end
1727
(* The type of numbers represented using nb many bits:
   bool -> ... -> bool with max(nb, 1) argument positions. *)
fun numtype nb = bool --> (if nb > 1 then numtype (nb - 1) else bool)
1731
(* A fresh (genvar) function variable from ty to the nb-bit number type. *)
fun generate_fn_term nb ty =
  let val range = numtype nb
  in genvar (ty --> range) end
1733
(* generate_eqns nb [C0, C1, ...] f = [f C0 = mk_num nb 0,
                                      f C1 = mk_num nb 1, ...] *)
fun generate_eqns nb ctrs f =
  let
    fun step (ctr, (n, acc)) = (n + 1, (mk_comb (f, ctr) == mk_num nb n) :: acc)
  in
    List.rev (#2 (List.foldl step (0, []) ctrs))
  end
1740
(* number nb [x0, x1, ...] pairs each xi with encode false nb i, the T/~T
   argument list encoding i. *)
fun number nb lst =
  let
    fun step (x, (n, acc)) = (n + 1, (encode false nb n, x) :: acc)
  in
    List.rev (#2 (List.foldl step (0, []) lst))
  end
1747
(* app_triangle f [x1, ..., xn] applies f to every pair (xi, xj) with
   i < j, ordered by i and then by j. *)
fun app_triangle f lst =
  case lst of
    [] => []
  | x :: rest => List.map (fn y => f (x, y)) rest @ app_triangle f rest
1751
(* For each (possibly universally quantified) clause, the operand of its
   left-hand side, presumably the constructor application C x1 .. xk. *)
fun ctrs_with_args clauses = map (rand o lhs o #2 o strip_forall) clauses
1756
(* prove_constructors_distinct thm proves that distinct constructors give
   distinct values.  The clauses of the recursion axiom thm are grouped by
   the type of the constructor term on their left-hand side.  For each
   group whose type is among doms_of_tyaxiom thm, the result holds one
   entry: SOME of the conjunction of ~(Ci .. = Cj ..) for i < j, or NONE
   when the type has a single constructor. *)
fun prove_constructors_distinct thm = let
  (* the clauses of the recursion axiom, below its quantifiers *)
  val all_eqns = strip_conj(snd(strip_exists(snd(strip_forall(concl thm)))))
  val axtypes = doms_of_tyaxiom thm
  (* type of the constructor application on the left of a clause *)
  fun eqn_type eq = type_of (rand (lhs (#2 (strip_forall eq))))
  fun same_domain eq1 eq2 = eqn_type eq1 = eqn_type eq2
  fun prove_cd_for_type eqns = let
    val ctrs = ctrs_with_args eqns
    val nb = rounded_log (length ctrs)
  in
    (* nb = 0: only one constructor, nothing to distinguish *)
    if nb = 0 then NONE
    else let
      (* f maps each constructor Ci to the characteristic function of i *)
      val f = generate_fn_term nb (type_of (hd ctrs))
      val eqns = generate_eqns nb ctrs f
      val fn_defn = list_mk_conj eqns
      val fn_exists = prove_recursive_functions_exist thm fn_defn
      (* work under the assumption that f satisfies fn_defn; the CHOOSE
         below discharges it *)
      val fn_thm = ASSUME (snd (dest_exists (concl fn_exists)))
      val eqn_thms = CONJUNCTS fn_thm
      (* NOTE(review): ListPair.zip truncates silently; this relies on
         prove_recursive_functions_exist returning one conjunct per
         constructor, in the same order as ctrs *)
      val ctrs_with_eqns_and_numbers =
        number nb (ListPair.zip (ctrs, eqn_thms))
      fun prove_result ((c1n, (c1, fc1)), (c2n, (c2, fc2))) =
        prove_ineq nb f fc1 fc2 c1 c2 c1n
      (* one inequality per unordered pair of constructors *)
      val thms = app_triangle prove_result ctrs_with_eqns_and_numbers
      val thm = LIST_CONJ thms
    in
      SOME (CHOOSE (f, fn_exists) thm)
    end
  end
  (* only groups whose type is a domain of the axiom are processed *)
  fun maybe_prove_cd_for_type eqns =
    let val ctrs = ctrs_with_args eqns
    in if Lib.mem (type_of (hd ctrs)) axtypes
          then SOME (prove_cd_for_type eqns)
          else NONE
    end
in
  List.mapPartial maybe_prove_cd_for_type (partition same_domain all_eqns)
end
1793
1794end (* local where partition is defined *)
1795
1796(*---------------------------------------------------------------------------
1797
1798          Test routines for distinctness proofs.
1799
1800  load "Define_type";
1801  fun gen_type n = let
1802    val name = "foo"^Int.toString n
1803    val fixities = List.tabulate(n, fn _ => Prefix)
1804    fun clause n = let
1805      val C = "C"^Int.toString n
1806    in
1807      if n = 0 then  "C0 of bool => 'a"
1808      else C ^ " of bool => 'a => " ^name
1809    end
1810    fun sepby sep [] = []
1811      | sepby sep [x]= [x]
1812      | sepby sep (x::xs) = x::sep::sepby sep xs
1813    val clauses = sepby " | " (List.tabulate(n, clause))
1814    val spec = String.concat (name::" = "::clauses)
1815  in
1816    Define_type.define_type { fixities = fixities,
1817                              type_spec = [QUOTE spec],
1818                              name = name }
1819  end;
1820  val foo5 = gen_type 5;
1821  val foo10 = gen_type 10;
1822  val foo20 = gen_type 20;
1823  Lib.time prove_constructors_distinct foo5;
1824  Lib.time prove_constructors_distinct foo10;
1825  Lib.time prove_constructors_distinct foo20;
1826  (* tests seem to indicate that the code above is roughly 1.5 times
1827     slower than the original code by Tom Melham.  This is probably
1828     acceptable given that it is now independent of the theory of
1829     numbers *)
1830
1831 ---------------------------------------------------------------------------*)
1832
(* usefuls cs analyses the clauses cs of a case-constant definition and
   returns (const, const x f1 .. fn, [f1, .., fn], shared, x):
     const   the case constant (head of the first clause's lhs);
     f1..fn  the arguments after the first on that lhs;
     shared  the universally bound variables common to all clauses;
     x       a variable named x (renamed away from shared) whose type is
             that of the first lhs argument.
   Raises HOL_ERR if the first lhs has no arguments. *)
fun usefuls cs = let
  fun bound_vars c = FVL (#1 (strip_forall c)) empty_tmset
  val shared = foldl (fn (c, acc) => HOLset.intersection (bound_vars c, acc))
                     (bound_vars (hd cs))
                     (tl cs)
  val shared_l = HOLset.listItems shared
  val first_eqn = #2 (strip_forall (hd cs))
  val (const, args) = strip_comb (lhs first_eqn)
  val (arg1_ty, casefs) =
      case args of
          [] => raise mk_HOL_ERR "Prim_rec" "prove_case_rand_thm"
                                 "Case constant theorem has too few arguments"
        | h :: t => (type_of h, t)
  val x = variant shared_l (mk_var ("x", arg1_ty))
in
  (const, list_mk_comb (const, x :: casefs), casefs, shared_l, x)
end
1849
(* prove_case_rand_thm {nchotomy, case_def} proves that a function applied
   to a case expression can be pushed into the branches:
     |- f (ty_CASE x f1 .. fn) =
        ty_CASE' x (\a1 .. ai. f (f1 a1 .. ai)) .. (\b1 .. bj. f (fn b1 .. bj))
   Here ty_CASE' is the case constant with its result type instantiated to
   fresh_tyvar.  fresh_tyvar starts from alpha and gets "1" appended
   whenever it meets a type variable of x or of the shared clause
   variables.  f is a variable named f (renamed away from those clause
   variables).  The proof splits on x with nchotomy and rewrites with
   case_def.
   NOTE(review): the freshness of fresh_tyvar relies on HOLset.foldl
   visiting type variables in name order; confirm this for Type.compare. *)
fun prove_case_rand_thm {nchotomy, case_def} = let
  val cs = strip_conj (concl case_def)
  val (const,t,casefs,vs_l,v) = usefuls cs
  val tyvs = foldl (fn (v,s) => HOLset.addList(s, type_vars_in_term v))
                   (HOLset.empty Type.compare)
                   (v::vs_l)
  fun foldthis (sv, acc) = if acc = sv then mk_vartype(dest_vartype acc ^ "1")
                           else acc
  val fresh_tyvar = HOLset.foldl foldthis alpha tyvs
  val f = variant vs_l (mk_var("f", type_of t --> fresh_tyvar))
  (* the argument variables of the constructor in each clause *)
  val ctor_args = map (fn c => c |> strip_forall |> #2 |> lhs |> strip_comb
                                 |> #2 |> hd |> strip_comb |> #2)
                      cs
  val new_lhs = mk_comb(f, t)
  (* each case function fi becomes  \args. f (fi args) *)
  val cfs' =
      ListPair.map
        (fn (cf, args) => list_mk_abs(args, mk_comb(f, list_mk_comb(cf, args))))
        (casefs, ctor_args)
  val const' = inst [#2 (strip_fun (type_of const)) |-> fresh_tyvar] const
  val rhs' = list_mk_comb(const', v::cfs')
in
  prove(mk_eq(new_lhs,rhs'),
        STRUCT_CASES_TAC (ISPEC v nchotomy) THEN
        PURE_REWRITE_TAC [case_def] THEN BETA_TAC THEN
        PURE_REWRITE_TAC [EQT_INTRO (SPEC_ALL EQ_REFL)])
end
1876
(* strip_exists' avds t strips the leading existential quantifiers of t.
   Each bound variable is renamed, when necessary, away from avds and from
   the variables already stripped, and the body is substituted to match.
   Returns the (possibly renamed) variables in order together with the
   body. *)
fun strip_exists' avds t =
  let
    fun loop (acc, avoid, tm) =
      if not (is_exists tm) then (List.rev acc, tm)
      else
        let
          val (v, body) = dest_exists tm
          val v' = variant avoid v
          val body' = if v ~~ v' then body else subst [v |-> v'] body
        in
          loop (v' :: acc, v' :: avoid, body')
        end
  in
    loop ([], avds, t)
  end
1892
1893(* prove a theorem of the form
1894     ty_CASE x f1 .. fn :bool <=>
1895       (?a1 .. ai. x = ctor1 a1 .. ai /\ f1 a1 .. ai) \/
1896       (?b1 .. bj. x = ctor2 b1 .. bj /\ f2 b1 .. bj) \/
1897       ...
1898*)
fun prove_case_elim_thm {case_def,nchotomy} = let
  val cs = strip_conj (concl case_def)
  (* t0 = const v f1 .. fn, v the scrutinee, casefs0 = [f1, .., fn] *)
  val (const,t0,casefs0,_,v) = usefuls cs
  (* specialise the case result type to bool.  NOTE(review): inst needs
     type_of t0 to be a type variable; presumably always true for the
     result type of a case constant *)
  val casefs = map (inst [type_of t0 |-> bool]) casefs0
  val t = inst [type_of t0 |-> bool] t0
  val t_thm = ASSUME t
  (* nch: one disjunct  ?args. v = Ci args  per constructor *)
  val nch = SPEC v nchotomy
  val disjs0 = strip_disj (concl nch)
  (* builds  ?args'. v = Ci args' /\ fi args'  with the bound variables
     renamed away from the case-function variables *)
  fun mapthis (ex_eqn, cf) = let
    val (vs, eqn) = strip_exists' casefs ex_eqn
  in
    list_mk_exists(vs, mk_conj(eqn, list_mk_comb(cf,vs)))
  end
  val _ = length casefs = length disjs0 orelse
          raise mk_HOL_ERR "Prim_rec" "prove_case_elim_thm"
                "case_def and nchotomy theorem don't match up"
  val disjs = ListPair.map mapthis (disjs0,casefs)
  (* right-to-left direction for one disjunct:  ext |- t *)
  fun prove_case_from_eqn ext = let
    (* ext is of the form ``?a1 .. ai. x = ctor1 a1 .. ai /\ f1 a1 .. ai`` *)
    val (vs,body) = strip_exists ext
    val body_th = ASSUME body
    (* rewriting with the constructor equation and case_def reduces t to
       the case-function application, which body_th rewrites to T *)
    val res0 = EQT_ELIM (PURE_REWRITE_CONV [body_th,case_def] t)
    (* existentially quantify the hypothesis, innermost variable first *)
    fun foldthis (v,(ext,th)) = let
      val ex' = mk_exists(v,ext)
    in
      (ex', CHOOSE (v,ASSUME ex') th)
    end
  in
    List.foldr foldthis (body,res0) vs
  end
  val ths = map (#2 o prove_case_from_eqn) disjs
  (* NOTE(review): ds and d are not used below *)
  val (ds, d) = front_last disjs
  (* disj_recurse [d1,..,dn] [th1,..,thn], where each thi has hypothesis
     di, returns (d1 \/ .. \/ dn, theorem with that disjunction as
     hypothesis) via DISJ_CASES *)
  fun disj_recurse ds ths =
      case (ds, ths) of
          ([d], [th]) => (d, th)
        | (d::ds, th::ths) =>
          let
            val (d2, th2) = disj_recurse ds ths
            val t = mk_disj(d,d2)
          in
            (t, DISJ_CASES (ASSUME t) th th2)
          end
        | _ => raise Fail "prove_case_elim_thm: can't happen"
  val (_, case1) = disj_recurse disjs ths

  (* left-to-right direction: for a disjunct ?vs. c1 /\ c2 (c1 the
     constructor equation), proves  c1, t |- ?vs. c1 /\ c2;  under c1,
     rewriting t with case_def yields c2 *)
  fun prove_exists d = let
    val (vs, body) = strip_exists d
    val (c1, c2) = dest_conj body
    val body_th = CONJ (ASSUME c1) (ASSUME c2)
    val ex_thm =
        List.foldr (fn (v, th) => EXISTS(mk_exists(v,concl th),v) th)
                   body_th vs
    val elim =
        EQ_MP (PURE_REWRITE_CONV [ASSUME c1, case_def] t) t_thm
  in
    PROVE_HYP elim ex_thm
  end
  val exists_thms = map prove_exists disjs
  (* mkholes [d1,..,dn] gives, for each position i, the list with NONE at
     position i and SOME dj at every other position j *)
  fun mkholes ds =
      case ds of
          [] => []
        | d::t => (NONE::map SOME t) ::
                  (map (cons (SOME d)) (mkholes t))
  val holes = mkholes (map concl exists_thms)
  (* injects th, whose conclusion belongs at the NONE position, into the
     full disjunction using DISJ1 / DISJ2 *)
  fun holemerge (th, holedlist) =
      case holedlist of
          [NONE] => th
        | NONE::rest => DISJ1 th (list_mk_disj (map valOf rest))
        | SOME t::rest =>
          DISJ2 t (holemerge (th, rest))
        | [] => raise Fail "Can't happen"
  (* one  c1 |- t ==> (d1 \/ .. \/ dn)  per constructor *)
  val eqconcls =
      ListPair.map (DISCH t o holemerge) (exists_thms, holes)
  (* replace the remaining hypothesis  x = C a..  by  ?a... x = C a.. *)
  fun existentialify_asm th = let
    val eq = hd (hyp th)
    val (_, vs) = eq |> rhs |> strip_comb
    fun foldthis (v, (ext, th)) = let
      val ext' = mk_exists(v,ext)
    in
      (ext', CHOOSE(v, ASSUME ext') th)
    end
  in
    #2 (List.foldr foldthis (eq, th) vs)
  end
  val existentialified_asms = map existentialify_asm eqconcls
in
  (* case-split on the nchotomy disjunction for t ==> disjs, and combine
     with the other direction *)
  IMP_ANTISYM_RULE
    (PROVE_HYP nch (#2 (disj_recurse disjs0 existentialified_asms)))
    (DISCH_ALL case1)
end
1989
(* prove_case_eq_thm {case_def, nchotomy} proves a theorem of the form
     (ty_CASE x f1 .. fn = v) <=>
       (?a1 .. ai. x = ctor1 a1 .. ai /\ f1 a1 .. ai = v) \/ ...
   It instantiates the function variable of prove_case_rand_thm with
   \v'. v' = v (after specialising its range to bool), beta-reduces, then
   rewrites the resulting boolean case expression with
   prove_case_elim_thm and beta-reduces again. *)
fun prove_case_eq_thm (rcd as {case_def, nchotomy}) = let
  val elim = prove_case_elim_thm rcd
  val rand_thm = prove_case_rand_thm rcd
  val f = rand_thm |> concl |> lhs |> rator
  val (dom, rng) = dom_rng (type_of f)
  val avoid = free_vars (concl rand_thm)
  val v = variant avoid (mk_var("v", dom))
  val v' = variant (v :: avoid) v
  val f_at_bool = inst [rng |-> bool] f
  val eq_v = mk_abs (v', mk_eq (v', v))
  val instantiated =
      BETA_RULE (INST [f_at_bool |-> eq_v] (INST_TYPE [rng |-> bool] rand_thm))
in
  BETA_RULE (CONV_RULE (RAND_CONV (REWR_CONV elim)) instantiated)
end
2006
2007
2008
2009
2010end; (* Prim_rec *)
2011