1structure Defn :> Defn =
2struct
3
4open HolKernel Parse boolLib;
5open pairLib Rules wfrecUtils Pmatch Induction DefnBase;
6
(* Local abbreviations for types used throughout this structure. *)
type thry   = TypeBasePure.typeBase
type proofs = Manager.proofs
type absyn  = Absyn.absyn;

(* Error constructors tagged with this structure's name. ERR takes the
   reporting function's name and a message, e.g. ERR "drop" "".
   NOTE(review): ERRloc additionally takes a source location — confirm the
   argument order against Feedback.mk_HOL_ERRloc. *)
val ERR = mk_HOL_ERR "Defn";
val ERRloc = mk_HOL_ERRloc "Defn";

(* NOTE(review): not read anywhere in the visible part of this file;
   presumably a tracing switch consulted elsewhere. *)
val monitoring = ref false;

(* Interactively:
  val const_eq_ref = ref (!Defn.const_eq_ref);
*)
(* Conversion deciding equality of literal constants, used by
   elim_triv_literal_CONV. It starts as NO_CONV (always fails); other
   libraries (reduceLib, stringLib, wordsLib) replace it. *)
val const_eq_ref = ref Conv.NO_CONV;
20
21(*---------------------------------------------------------------------------
22      Miscellaneous support
23 ---------------------------------------------------------------------------*)
24
(* Pair each element of l with its position, element first. Positions are
   those assigned by Lib.enumerate starting from 0. *)
fun enumerate l =
  let val indexed = Lib.enumerate 0 l
  in List.map (fn (i, x) => (x, i)) indexed end;
26
(* drop pfx lst removes as many leading elements from lst as pfx has
   (only pfx's length matters). It fails if lst is shorter than pfx. *)
fun drop pfx lst =
  case (pfx, lst) of
      ([], _) => lst
    | (_ :: pfx', _ :: lst') => drop pfx' lst'
    | _ => raise ERR "drop" "";
30
(* Rename each variable in vlist so that it avoids FV and every variable
   renamed before it. Because fresh names are accumulated by consing, the
   result is in REVERSE order relative to vlist. *)
fun variants FV vlist =
  let
    fun step (v, (fresh, avoid)) =
      let val v' = variant avoid v
      in (v' :: fresh, v' :: avoid) end
    val (result, _) = List.foldl step ([], FV) vlist
  in
    result
  end;
36
(* Introduce the definition tm under the name s via new_definition. The
   theory argument thry is passed back unchanged alongside the theorem. *)
fun make_definition thry s tm =
  let val def_th = new_definition (s, tm)
  in (def_th, thry) end
38
(* The operator at the head of a (possibly curried) application. A
   non-application is returned unchanged. *)
fun head tm = if is_comb tm then head (rator tm) else tm;
40
(* The distinct head constants/variables on the left-hand sides of the
   (possibly universally quantified) conjuncts of eqns. *)
fun all_fns eqns =
  let fun defined_fn clause = head (lhs (#2 (strip_forall clause)))
  in mk_set (List.map defined_fn (strip_conj eqns)) end;
43
(* Take the first conjunct of eqs (or eqs itself if it is not a
   conjunction) and strip its outer quantifiers. Returns
   ((fn, args), rhs) for that equation. *)
fun dest_hd_eqn eqs =
  let
    val first_eqn = if is_conj eqs then #1 (dest_conj eqs) else eqs
    val (_, body) = strip_forall first_eqn
    val (l, r) = dest_eq body
  in
    (strip_comb l, r)
  end;
49
(* Like dest_hd_eqn, but for a list of equation theorems: it examines the
   conclusion of the first one. Raises Match on an empty list. *)
fun dest_hd_eqnl [] = raise Match
  | dest_hd_eqnl (first :: _) =
      let val (l, r) = dest_eq (#2 (strip_forall (concl first)))
      in (strip_comb l, r) end
55
(* Collect the case-definition and case-congruence theorems of every type
   in db whose case constant occurs in constset. Both result lists are in
   reverse order of listItems db. *)
fun extract_info constset db =
    let open TypeBasePure
        fun mentions_case tyinfo =
            List.exists
              (fn c => same_const (case_const_of tyinfo) c
                       handle HOL_ERR _ => false)
              constset
        val relevant = List.rev (List.filter mentions_case (listItems db))
    in {case_congs = List.map case_cong_of relevant,
        case_rewrites = List.map case_def_of relevant}
    end;
67
68(*---------------------------------------------------------------------------
69    Support for automatically building names to store definitions
    (and the consequences thereof) within the current theory. Somewhat
71    ad hoc, but I don't know a better way!
72 ---------------------------------------------------------------------------*)
73
(* Suffix naming an induction theorem; definition theorems use boolLib's
   shared suffix reference. *)
val ind_suffix = ref "_ind";
val def_suffix = boolLib.def_suffix

(* Name builders. Both suffix references are dereferenced at call time. *)
fun indSuffix stem      = stem ^ !ind_suffix
fun defSuffix stem      = stem ^ !def_suffix
fun defPrim stem        = defSuffix (stem ^ "_primitive")
fun defExtract (stem,n) =
      defSuffix (String.concat [stem, "_extract", Lib.int_to_string n])
fun argMunge stem       = defSuffix (stem ^ "_curried")
fun auxStem stem        = stem ^ "_AUX"
fun unionStem stem      = stem ^ "_UNION"
84
(* imp_elim :
     |- !P Q R. (P ==> ((P ==> Q) = (P ==> R))) = ((P ==> Q) = (P ==> R))
   An equation between two implications with the common antecedent P
   holds under P exactly when it holds outright. *)
val imp_elim =
 let val P = mk_var("P",bool)
     val Q = mk_var("Q",bool)
     val R = mk_var("R",bool)
     val PimpQ = mk_imp(P,Q)
     val PimpR = mk_imp(P,R)
     val tm = mk_eq(PimpQ,PimpR)
     val tm1 = mk_imp(P,tm)
     (* Forward direction: from P ==> tm, derive tm. *)
     val th2 = ASSUME tm1
     val th2a = ASSUME P
     val th3 = MP th2 th2a
     val th3a = EQ_MP (SPECL[PimpQ, PimpR] boolTheory.EQ_IMP_THM) th3
     val (th4,th5) = (CONJUNCT1 th3a,CONJUNCT2 th3a)
     fun pmap f (x,y) = (f x, f y)
     (* Undischarge both implications, then re-discharge P. This leaves P
        as an antecedent rather than a hypothesis. *)
     val (th4a,th5a) = pmap (DISCH P o funpow 2 UNDISCH) (th4,th5)
     val th4b = DISCH PimpQ th4a
     val th5b = DISCH PimpR th5a
     val th6 = DISCH tm1 (IMP_ANTISYM_RULE th4b th5b)
     (* Backward direction: tm ==> (P ==> tm) is immediate. *)
     val th7 = DISCH tm (DISCH P (ASSUME tm))
 in GENL [P,Q,R]
         (IMP_ANTISYM_RULE th6 th7)
 end;
108
(* inject ty [v1,...,vn] embeds each vi into the right-nested sum type ty:
   INL v1, INR (INL v2), ..., and vn under n-1 INRs. A singleton list is
   returned as is. An empty list raises Match, and a non-binary type
   raises Bind. *)
fun inject _ [] = raise Match
  | inject _ [v] = [v]
  | inject ty (v :: vs) =
      let
        val (lty, rty) =
            case dest_type ty of
                (_, [l, r]) => (l, r)
              | _ => raise Bind
        val inl_v = mk_comb (mk_const ("INL", lty --> ty), v)
        val inr_c = mk_const ("INR", rty --> ty)
      in
        inl_v :: List.map (fn t => mk_comb (inr_c, t)) (inject rty vs)
      end
120
121
(* Companion of inject. For a right-nested sum type ty with one summand per
   element of the first list, it builds OUTL M, OUTL (OUTR M), ... The last
   component is the plain chain of OUTR applications. *)
fun project [] _ _ = raise ERR "project" "catastrophic invariant failure (eqns was empty!?)"
  | project [_] _ M = [M]
  | project (_ :: rest) ty M =
      let
        val (lty, rty) = sumSyntax.dest_sum ty
        val Mty = type_of M
        val outl = mk_comb (mk_const ("OUTL", Mty --> lty), M)
        val outr = mk_comb (mk_const ("OUTR", Mty --> rty), M)
      in
        outl :: project rest rty outr
      end
129
130(*---------------------------------------------------------------------------*
131 * We need a "smart" MP. th1 can be less quantified than th2, so th2 has     *
132 * to be specialized appropriately. We assume that all the "local"           *
133 * variables are quantified first.                                           *
134 *---------------------------------------------------------------------------*)
135
(* MP th1 th2, where th2 may carry outer universal quantifiers that th1's
   antecedent lacks. th2's leading quantifiers whose variable is not bound
   in th1's antecedent are specialised away (keeping their names) until the
   first one that is. Any HOL failure is reported as ModusPonens failing.
   Other exceptions, such as interrupts, are no longer swallowed. *)
fun ModusPonens th1 th2 =
  let val V1 = #1(strip_forall(fst(dest_imp(concl th1))))
      val V2 = #1(strip_forall(concl th2))
      (* variables quantified in th2 but not in th1's antecedent *)
      val diff = Lib.op_set_diff Term.aconv V2 V1
      fun loop th =
        if is_forall(concl th)
        then let val (Bvar,Body) = dest_forall (concl th)
             in if Lib.op_mem Term.aconv Bvar diff
                then loop (SPEC Bvar th) else th
             end
        else th
  in
    MP th1 (loop th2)
  end
  handle HOL_ERR _ => raise ERR "ModusPonens" "failed";
151
152
153(*---------------------------------------------------------------------------*)
154(* Version of PROVE_HYP that works modulo permuting outer universal quants.  *)
155(*---------------------------------------------------------------------------*)
156
(* Discharge the hypothesis of th2 that equals th1's conclusion up to a
   permutation of the outer universal quantifiers. th1 is specialised on
   its own prefix U, then re-generalised over the matching hypothesis'
   prefix V, before PROVE_HYP. Fails if no hypothesis matches. *)
fun ALPHA_PROVE_HYP th1 th2 =
  let
    val hyps = hyp th2
    val (U, body) = strip_forall (concl th1)
    fun same_body h = aconv body (#2 (strip_forall h))
    val (V, _) = strip_forall (Lib.first same_body hyps)
  in
    PROVE_HYP (GENL V (SPECL U th1)) th2
  end;
166
(* The name under which a defn is stored: bind for ABBREV/PRIMREC, stem
   for every recursive or non-recursive form. *)
fun name_of defn =
  case defn of
      ABBREV {bind, ...}  => bind
    | PRIMREC {bind, ...} => bind
    | NONREC {stem, ...}  => stem
    | STDREC {stem, ...}  => stem
    | MUTREC {stem, ...}  => stem
    | NESTREC {stem, ...} => stem
    | TAILREC {stem, ...} => stem
174
(* The defining equation theorems of a defn, always as a list. *)
fun eqns_of defn =
  case defn of
      ABBREV {eqn, ...}  => [eqn]
    | NONREC {eqs, ...}  => [eqs]
    | PRIMREC {eqs, ...} => [eqs]
    | STDREC {eqs, ...}  => eqs
    | NESTREC {eqs, ...} => eqs
    | MUTREC {eqs, ...}  => eqs
    | TAILREC {eqs, ...} => eqs;
182
183
(* The auxiliary definition of a nested recursion, if any. *)
fun aux_defn defn =
  case defn of
      NESTREC {aux, ...} => SOME aux
    | _ => NONE;
186
(* The union definition of a mutual recursion, if any. *)
fun union_defn defn =
  case defn of
      MUTREC {union, ...} => SOME union
    | _ => NONE;
189
(* The induction theorem of a defn. Abbreviations have none. *)
fun ind_of defn =
  case defn of
      ABBREV _           => NONE
    | NONREC {ind, ...}  => SOME ind
    | PRIMREC {ind, ...} => SOME ind
    | STDREC {ind, ...}  => SOME ind
    | NESTREC {ind, ...} => SOME ind
    | MUTREC {ind, ...}  => SOME ind
    | TAILREC {ind, ...} => SOME ind;
197
198
(* The schematic variables (SV) of a defn. ABBREV and PRIMREC have none. *)
fun params_of defn =
  case defn of
      NONREC {SV, ...}  => SV
    | STDREC {SV, ...}  => SV
    | NESTREC {SV, ...} => SV
    | MUTREC {SV, ...}  => SV
    | TAILREC {SV, ...} => SV
    | _ => []
206
(* True iff the defn has at least one schematic variable. *)
fun schematic defn =
  case params_of defn of
      [] => false
    | _ :: _ => true;
208
(* Outstanding termination conditions: the union of the hypotheses of the
   equations. Non-recursive forms have none. Tail recursions are rejected. *)
fun tcs_of defn =
  let fun hyps_of eqs = op_U aconv (map hyp eqs)
  in
    case defn of
        ABBREV _  => []
      | NONREC _  => []
      | PRIMREC _ => []
      | STDREC {eqs, ...}  => hyps_of eqs
      | NESTREC {eqs, ...} => hyps_of eqs
      | MUTREC {eqs, ...}  => hyps_of eqs
      | TAILREC _ => raise ERR "tcs_of" "Tail recursive definition"
  end
216
217
(* The termination relation R of a well-founded recursion, if any. *)
fun reln_of defn =
  case defn of
      STDREC {R, ...}  => SOME R
    | NESTREC {R, ...} => SOME R
    | MUTREC {R, ...}  => SOME R
    | TAILREC {R, ...} => SOME R
    | _ => NONE;
225
226
(* Undischarge the n outermost antecedents of th. Does nothing when n < 1. *)
fun nUNDISCH n th =
  case Int.compare (n, 1) of
      LESS => th
    | _ => nUNDISCH (n - 1) (UNDISCH th)
228
(* Apply the term/type instantiation theta to th, hypotheses included:
   1. discharge every hypothesis;
   2. instantiate the resulting hypothesis-free theorem;
   3. undischarge the same number of antecedents. *)
fun INST_THM theta th =
  let
    val hyps = hyp th
    val closed = List.foldl (fn (h, acc) => DISCH h acc) th hyps
  in
    nUNDISCH (List.length hyps) (INST_TY_TERM theta closed)
  end;
236
(* Instantiate types with tytheta first, then substitute terms with tmtheta. *)
fun isubst (tmtheta, tytheta) tm =
  let val tm_ty = inst tytheta tm
  in subst tmtheta tm_ty end;
238
(* Instantiate every theorem and term in a defn with theta, a pair
   (term substitution, type substitution). It recurses into the aux
   (NESTREC) and union (MUTREC) sub-definitions. *)
fun inst_defn defn theta =
  let
    val ith = INST_THM theta
    val itm = isubst theta
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                  SV = map itm SV, stem = stem}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                   SV = map itm SV, aux = inst_defn aux theta, stem = stem}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                  SV = map itm SV, union = inst_defn union theta,
                  stem = stem}
      | PRIMREC {eqs, ind, bind} =>
          PRIMREC {eqs = ith eqs, ind = ith ind, bind = bind}
      | NONREC {eqs, ind, SV, stem} =>
          NONREC {eqs = ith eqs, ind = ith ind, SV = map itm SV,
                  stem = stem}
      | ABBREV {eqn, bind} =>
          ABBREV {eqn = ith eqn, bind = bind}
      | TAILREC {eqs, ind, R, SV, stem} =>
          TAILREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                   SV = map itm SV, stem = stem}
  end;
269
270
(* Replace the termination relation of def with R. The defn is
   instantiated by matching its relation against R. Defns with no relation
   are returned unchanged. A match failure prints a message and is
   re-raised. *)
fun set_reln def R =
  let
    fun reinstantiate Rpat =
      inst_defn def (Term.match_term Rpat R)
      handle e => (HOL_MESG"set_reln: unable"; raise e)
  in
    case reln_of def of
        NONE => def
      | SOME Rpat => reinstantiate Rpat
  end;
276
(* Discharge the hypotheses of th proved by the theorems in thl (PROVE_HYP
   with each, last element first). *)
fun PROVE_HYPL thl th = List.foldr (fn (h, acc) => PROVE_HYP h acc) th thl
278
(* Discharge every hypothesis of th that rewriting with thms reduces to T.
   Hypotheses that fail to rewrite to T are kept. *)
fun MATCH_HYPL thms th =
  let
    val rewrite_to_true = EQT_ELIM o REWRITE_CONV thms
    val solved = mapfilter rewrite_to_true (hyp th)
  in
    List.foldr (fn (h, acc) => PROVE_HYP h acc) th solved
  end;
283
284
285(* We use PROVE_HYPL on induction theorems, since their tcs are fully
286   quantified. We use MATCH_HYPL on equations, since their tcs are
287   bare.
288*)
289
(* Eliminate termination conditions using thms: rewriting on the
   equations, PROVE_HYP on the induction theorem. It recurses into aux and
   union sub-definitions. Other forms are returned unchanged. *)
fun elim_tcs defn thms =
  let
    val prove_eqs = map (MATCH_HYPL thms)
    val prove_ind = PROVE_HYPL thms
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {R = R, SV = SV, stem = stem,
                  eqs = prove_eqs eqs, ind = prove_ind ind}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {R = R, SV = SV, stem = stem,
                   eqs = prove_eqs eqs, ind = prove_ind ind,
                   aux = elim_tcs aux thms}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {R = R, SV = SV, stem = stem,
                  eqs = prove_eqs eqs, ind = prove_ind ind,
                  union = elim_tcs union thms}
      | TAILREC {eqs, ind, R, SV, stem} =>
          TAILREC {R = R, SV = SV, stem = stem,
                   eqs = prove_eqs eqs, ind = prove_ind ind}
      | other => other
  end;
309
310
local
 (* lem : |- (M = M1) ==> (M ==> P) ==> (M1 ==> P) *)
 val lem =
   let val M  = mk_var("M",bool)
       val P  = mk_var("P",bool)
       val M1 = mk_var("M1",bool)
       val tm1 = mk_eq(M,M1)
       val tm2 = mk_imp(M,P)
   in DISCH tm1 (DISCH tm2 (SUBS [ASSUME tm1] (ASSUME tm2)))
   end
in
(* Rewrite the hypothesis tm of th with conv. If conv shows tm = T the
   hypothesis is dropped; otherwise it is replaced by the right-hand side
   of conv's result. *)
fun simp_assum conv tm th =
  let
    val discharged = DISCH tm th
    val tm_eq = conv tm
  in
    if rhs (concl tm_eq) = T then
      MP discharged (EQT_ELIM tm_eq)
    else
      UNDISCH (MATCH_MP (MATCH_MP lem tm_eq) discharged)
  end
end;
330
(* Simplify every hypothesis of th with conv (see simp_assum). *)
fun SIMP_HYPL conv th =
  List.foldr (fn (h, acc) => simp_assum conv h acc) th (hyp th);
332
(* Simplify the termination conditions and the relation R of a
   STDREC/NESTREC/MUTREC defn with conv, recursing into aux/union.
   All other forms, TAILREC included, are returned unchanged. *)
fun simp_tcs defn conv =
  let
    fun simp_R R = rhs (concl (conv R))
    val simp_th = SIMP_HYPL conv
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {R = simp_R R, SV = SV, stem = stem,
                  eqs = map simp_th eqs, ind = simp_th ind}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {R = simp_R R, SV = SV, stem = stem,
                   eqs = map simp_th eqs, ind = simp_th ind,
                   aux = simp_tcs aux conv}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {R = simp_R R, SV = SV, stem = stem,
                  eqs = map simp_th eqs, ind = simp_th ind,
                  union = simp_tcs union conv}
      | other => other
  end;
348
349
(* Try to prove each hypothesis of th with tac, then discharge the ones
   proved. Hypotheses tac cannot prove are kept. *)
fun TAC_HYPL tac th =
  let
    fun try_prove goal = Tactical.prove (goal, tac)
    val proved = mapfilter try_prove (hyp th)
  in
    PROVE_HYPL proved th
  end;
352
(* Try to prove the termination conditions of a STDREC/NESTREC/MUTREC
   defn with tac, recursing into aux/union. Other forms are unchanged. *)
fun prove_tcs defn tac =
  let val prove_th = TAC_HYPL tac
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {R = R, SV = SV, stem = stem,
                  eqs = map prove_th eqs, ind = prove_th ind}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {R = R, SV = SV, stem = stem,
                   eqs = map prove_th eqs, ind = prove_th ind,
                   aux = prove_tcs aux tac}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {R = R, SV = SV, stem = stem,
                  eqs = map prove_th eqs, ind = prove_th ind,
                  union = prove_tcs union tac}
      | other => other
  end;
368
369
370(*---------------------------------------------------------------------------*)
371(* Deal with basic definitions.                                              *)
372(*---------------------------------------------------------------------------*)
373
(* Abbreviations and primitive recursions need no termination proof. *)
fun triv_defn defn =
  case defn of
      ABBREV _  => true
    | PRIMREC _ => true
    | _ => false
377
(* The single equation theorem of a trivial (ABBREV/PRIMREC) defn. *)
fun fetch_eqns defn =
  case defn of
      ABBREV {eqn, ...}  => eqn
    | PRIMREC {eqs, ...} => eqs
    | _ => raise ERR "fetch_eqns" "shouldn't happen"
381
382(*---------------------------------------------------------------------------
383   Store definition information to disk. Currently, just writes out the
384   eqns and induction theorem. A more advanced implementation would
385   write things out so that, when the exported theory is reloaded, the
   defn data structure is rebuilt. This would give a seamless view of
387   things.
388
389   Note that we would need to save union and aux info only when
390   termination has not been proved for a nested recursion.
391
392   Another (easier) way to look at it would be to require termination
393   and suchlike to be taken care of in the current theory. That is
394   what is assumed at present.
395 ---------------------------------------------------------------------------*)
396
local
  (* Is tm the constant num$SUC? *)
  fun is_suc tm =
    case total dest_thy_const tm of
        SOME {Name, Thy, ...} => Name = "SUC" andalso Thy = "num"
      | NONE => false
in
val SUC_TO_NUMERAL_DEFN_CONV_hook =
      ref (fn _ => raise ERR "SUC_TO_NUMERAL_DEFN_CONV_hook" "not initialized")

(* When automatic import is enabled, add the named theorems in l to the
   compset. For each theorem with SUC on the left of some clause, it also
   saves and adds a numeral-converted copy. That copy is named
   s^"_compute", or s^"_compute"^n for the first n >= 0 not already used. *)
fun add_persistent_funs l =
  if !computeLib.auto_import_definitions then
    let
      fun clause_has_SUC clause =
        can (find_term is_suc) (lhs (#2 (strip_forall clause)))
      fun has_lhs_SUC th = List.exists clause_has_SUC (strip_conj (concl th))
      fun fresh_compute_name s =
        let
          val base = s ^ "_compute"
          val existing = #1 (Lib.unzip (current_theorems ()))
          fun taken x = Lib.mem x existing
          fun try_suffix n =
            let val candidate = base ^ Int.toString n
            in if taken candidate then try_suffix (n + 1) else candidate end
        in
          if taken base then try_suffix 0 else base
        end
      fun names_for (s, th) =
        if has_lhs_SUC th then
          let
            val name = fresh_compute_name s
            val th_compute = CONV_RULE (!SUC_TO_NUMERAL_DEFN_CONV_hook) th
            val _ = save_thm (name, th_compute)
          in
            [s, name]
          end
        else [s]
    in
      computeLib.add_persistent_funs (List.concat (map names_for l))
    end
  else ()
end;
425
(* Emit a message through HOL_MESG while MESG_to_string is temporarily set
   to the identity. NOTE(review): presumably this prints the string
   without HOL_MESG's usual formatting and restores the flag afterwards —
   confirm against Lib.with_flag / Feedback.MESG_to_string. *)
val mesg = with_flag(MESG_to_string, Lib.I) HOL_MESG
427
local
  (* Storage reports are printed only while this flag is true. It is
     exposed as the boolean trace "Define.storage_message". *)
  val chatting = ref true
  val _ = Feedback.register_btrace("Define.storage_message", chatting)
in
(* Register an already-saved definition (name s, theorem thm) with the
   compset and report the name it was stored under. *)
fun been_stored (s,thm) =
  (add_persistent_funs [(s,thm)];
   if !chatting then
     mesg ((if !Globals.interactive then
              "Definition has been stored under "
            else
              "Saved definition __ ") ^Lib.quote s^"\n")
   else ()
   )

(* Save eqs under defSuffix stem and ind under indSuffix stem, then add the
   equations to the global compset, then report both names. Adding to the
   compset is best effort: any failure only prints a warning. *)
fun store(stem,eqs,ind) =
  let val eqs_bind = defSuffix stem
      val ind_bind = indSuffix stem
      (* save_thm with the "Theory.save_thm_reporting" trace set to 0;
         presumably this silences save_thm's own report. *)
      fun save x = Feedback.trace ("Theory.save_thm_reporting", 0) save_thm x
      val _ = save (ind_bind, ind)
      val _ = save (eqs_bind, eqs)
      (* NOTE(review): this catch-all also catches interrupts. *)
      val _ = add_persistent_funs [(eqs_bind,eqs)]
         handle _ => HOL_MESG ("Unable to add "^eqs_bind^" to global compset")
  in
    if !chatting then
       mesg (String.concat
               (if !Globals.interactive then
                  [   "Equations stored under ", Lib.quote eqs_bind,
                   ".\nInduction stored under ", Lib.quote ind_bind, ".\n"]
                else
                  [  "Saved definition __ ", Lib.quote eqs_bind,
                   "\nSaved induction ___ ", Lib.quote ind_bind, "\n"]))
    else ()
  end
end
462
local
  (* Generalise each equation, conjoin them, and store under stem. *)
  fun store_clauses (stem, eqs, ind) =
    store (stem, LIST_CONJ (map GEN_ALL eqs), ind)
in
  (* Persist a defn: ABBREV/PRIMREC via been_stored, NONREC as is, and
     multi-clause forms as one generalised conjunction. *)
  fun save_defn defn =
    case defn of
        ABBREV {bind, eqn, ...}  => been_stored (bind, eqn)
      | PRIMREC {bind, eqs, ...} => been_stored (bind, eqs)
      | NONREC {eqs, ind, stem, ...}  => store (stem, eqs, ind)
      | STDREC {eqs, ind, stem, ...}  => store_clauses (stem, eqs, ind)
      | TAILREC {eqs, ind, stem, ...} => store_clauses (stem, eqs, ind)
      | MUTREC {eqs, ind, stem, ...}  => store_clauses (stem, eqs, ind)
      | NESTREC {eqs, ind, stem, ...} => store_clauses (stem, eqs, ind)
end
474
475
476(*---------------------------------------------------------------------------
477        Termination condition extraction
478 ---------------------------------------------------------------------------*)
479
(* Case rewrites relevant to constset, plus the relevant case congruences
   followed by the globally registered congruences (read_congs ()). *)
fun extraction_thms constset thy =
  let val info = extract_info constset thy
  in (#case_rewrites info, #case_congs info @ read_congs ()) end;
484
485(*---------------------------------------------------------------------------
486         Capturing termination conditions.
487 ----------------------------------------------------------------------------*)
488
489
local
  fun mk_all v M = mk_forall (v, M)
  fun not_in vs x = not (Lib.op_mem aconv x vs)
in
(* Side-condition solver used by the rewriter during TC extraction.
   For a side condition tm arising under context, it:
   - forms the TC "context ==> tm", universally closed over every free
     variable except f and G;
   - assumes that TC and uses it to discharge tm.
   nref is set when the TC mentions f. If the relation's argument contains
   restrf, the call is nested: nref is set and the solver fails. *)
fun solver (restrf, f, G, nref) _ context tm =
  let
    val globals = f :: G  (* not to be generalized *)
    fun close_over t =
      List.foldr (fn (v, acc) => mk_all v acc) t
                 (List.filter (not_in globals) (rev (free_vars t)))
    val rcontext = rev context
    val ants = if null rcontext then []
               else [list_mk_conj (map concl rcontext)]
    val TC = close_over (list_mk_imp (ants, tm))
    val (_, arg, _) = wfrecUtils.dest_relation tm
    fun mentions c t = can (find_term (aconv c)) t
  in
    if mentions restrf arg then
      (nref := true; raise ERR "solver" "nested function")
    else
      (if mentions f TC then nref := true else ();
       if null rcontext then SPEC_ALL (ASSUME TC)
       else MATCH_MP (SPEC_ALL (ASSUME TC)) (LIST_CONJ rcontext))
  end
end;
513
(* extract FV congs f (proto_def, WFR) returns a function on (p, th): a
   pattern and its clause theorem. That function rewrites th with the
   restriction lemma (RESTRICT_LEMMA for f and WFR's relation) under
   congs, using solver to record termination conditions. It returns the
   rewritten theorem, its hypotheses other than proto_def and WFR (the
   TCs), and whether a nested recursive call was seen. *)
fun extract FV congs f (proto_def, WFR) =
  let
    val R = rand WFR
    val CUT_LEM = ISPECL [f, R] relationTheory.RESTRICT_LEMMA
    val restr_fR =
        rator (rator (lhs (snd (dest_imp (concl (SPEC_ALL CUT_LEM))))))
    fun extract_clause (p, th) =
      let
        val nested = ref false
        val globals = FV @ free_vars (concl th)
        val solve = solver (mk_comb (restr_fR, p), f, globals, nested)
        val rewrite =
            RW.Rewrite RW.Fully
              (RW.Pure [CUT_LEM],
               RW.Context ([], RW.DONT_ADD),
               RW.Congs congs,
               RW.Solver solve)
        val th' = CONV_RULE rewrite th
      in
        (th', Lib.op_set_diff aconv (hyp th') [proto_def, WFR], !nested)
      end
  in
    extract_clause
  end;
531
532
533(*---------------------------------------------------------------------------*
534 * Perform TC extraction without making a definition.                        *
535 *---------------------------------------------------------------------------*)
536
(* Result of wfrec_eqns:
     WFR       - second antecedent of the instantiated WFREC_COROLLARY;
                 its rand is the relation variable
     SV        - schematic (free) variables of the functional, sorted
     proto_def - first antecedent; an equation whose lhs is the function
                 variable
     extracta  - one entry per user-given clause: the rewritten equation,
                 its termination conditions, and a nested-call flag
     pats      - all patterns of the definition, given and omitted *)
type wfrec_eqns_result = {WFR : term,
                          SV : term list,
                          proto_def : term,
                          extracta  : (thm * term list * bool) list,
                          pats  : pattern list}
542
(* Wrap the right-hand side of an unquantified equation in I (combin$I).
   Universally quantified equations are rejected. *)
fun protect_rhs eqn =
  if not (is_forall eqn) then
    mk_eq (lhs eqn, combinSyntax.mk_I (rhs eqn))
  else
    raise ERR "mk_defn"
          "Universally quantified equation as argument to well-founded \
          \recursion"
(* Apply protect_rhs to every conjunct of eqns. *)
fun protect eqns =
  let val clauses = strip_conj eqns
  in list_mk_conj (List.map protect_rhs clauses) end;
551
(* Rewrite away applications of I (such as those added by protect), either
   in a term (returning the rewritten term) or in a theorem. The rewrite is
   set up once here, not on every call. *)
val unprotect_term = rhs o concl o PURE_REWRITE_CONV [combinTheory.I_THM];
val unprotect_thm  = PURE_REWRITE_RULE [combinTheory.I_THM];
554
555(*---------------------------------------------------------------------------*)
556(* Once patterns are instantiated and the clauses are simplified, there can  *)
557(* still remain right-hand sides with occurrences of fully concrete tests on *)
558(* literals. Here we simplify those away.                                    *)
559(*                                                                           *)
560(* const_eq_ref is a reference cell that decides equality of literals such   *)
561(* as 23, "foo", #"G", 6w, 0b1000w. It gets updated in reduceLib, stringLib  *)
562(* and wordsLib.                                                             *)
563(*---------------------------------------------------------------------------*)
564
(* Simplify one literal test. Steps:
   - optionally unfold literal_case and beta-reduce;
   - decide the guard of the resulting conditional with the conversion in
     const_eq_ref;
   - reduce the conditional with COND_CLAUSES. *)
fun elim_triv_literal_CONV tm =
   let
      val decide_eq = !const_eq_ref
      val unfold_case = TRY_CONV (REWR_CONV literal_case_THM THENC BETA_CONV)
      val decide_guard = RATOR_CONV (RATOR_CONV (RAND_CONV decide_eq))
      val reduce_cond = PURE_ONCE_REWRITE_CONV [COND_CLAUSES]
   in
      (unfold_case THENC decide_guard THENC reduce_cond) tm
   end
574
(* Report the schematic variables SV, if any. Fail if any of them also
   occurs free in one of the patterns pats. *)
fun checkSV pats SV =
  let
    fun pattern_of (GIVEN (p, _)) = p
      | pattern_of (OMITTED (p, _)) = p
    fun names vs =
        Lib.commafy (List.map (Lib.quote o #1 o dest_var)
                              (Listsort.sort Term.compare vs))
    val _ =
        if null SV then ()
        else HOL_MESG (String.concat
               ["Definition is schematic in the following variables:\n    ",
                String.concat (names SV)])
  in
    case intersect (free_varsl (List.map pattern_of pats)) SV of
        [] => ()
      | probs =>
          raise ERR "wfrec_eqns"
            (String.concat
               (["the following variables occur both free (schematic) ",
                 "and bound in the definition: \n   "] @ names probs))
  end
595
596(*---------------------------------------------------------------------------*)
597(* Instantiate the recursion theorem and extract termination conditions,     *)
598(* but do not define the constant yet.                                       *)
599(*---------------------------------------------------------------------------*)
600
(* Steps, without defining any constant:
   - build the functional of the protected clauses tup_eqs;
   - instantiate WFREC_COROLLARY with it and specialise the result to each
     user-given pattern;
   - simplify the right-hand sides;
   - extract each clause's termination conditions.
   Returns a wfrec_eqns_result record. *)
fun wfrec_eqns facts tup_eqs =
 let val {functional,pats} =
        mk_functional (TypeBasePure.toPmatchThry facts) (protect tup_eqs)
     val SV = free_vars functional    (* schematic variables *)
     val _ = checkSV pats SV
     (* Several names bound below (x, Name, f_dty, f_rty) are unused. The
        destructors still run. NOTE(review): presumably they double as
        shape checks on functional (\f. \x. ... with f of function type). *)
     val (f, Body) = dest_abs functional
     val (x,_) = dest_abs Body
     val (Name, fty) = dest_var f
     val (f_dty, f_rty) = Type.dom_rng fty
     val WFREC_THM0 = ISPEC functional relationTheory.WFREC_COROLLARY
     (* relation variable renamed away from tup_eqs' free variables *)
     val R = variant (free_vars tup_eqs) (fst(dest_forall(concl WFREC_THM0)))
     val WFREC_THM = ISPECL [R, f] WFREC_THM0
     (* the first two antecedents: the prototype definition and the
        wellfoundedness assumption *)
     val tmp = fst(wfrecUtils.strip_imp(concl WFREC_THM))
     val proto_def = Lib.trye hd tmp
     val WFR = Lib.trye (hd o tl) tmp
     val R1 = rand WFR
     val corollary' = funpow 2 UNDISCH WFREC_THM
     val given_pats = givens pats
     val corollaries = map (C SPEC corollary') given_pats
     val eqns_consts = mk_set(find_terms is_const functional)
     val (case_rewrites,congs) = extraction_thms eqns_consts facts
     val RWcnv = REWRITES_CONV (add_rewrites empty_rewrites
                                (literal_case_THM::case_rewrites))
     (* On each right-hand side: beta-reduce; then repeatedly either
        rewrite with the case rewrites and beta-reduce, or decide a literal
        test; finally strip the I protection. *)
     val rule = unprotect_thm o
                RIGHT_CONV_RULE
                   (LIST_BETA_CONV
                    THENC REPEATC ((RWcnv THENC LIST_BETA_CONV) ORELSEC
                                   elim_triv_literal_CONV))
     val corollaries' = map rule corollaries
  in
     {proto_def=proto_def,
      SV=Listsort.sort Term.compare SV,
      WFR=WFR,
      pats=pats,
      extracta = map (extract [R1] congs f (proto_def,WFR))
                     (zip given_pats corollaries')}
  end
638
639(*---------------------------------------------------------------------------
640 * Pair patterns with termination conditions. The full list of patterns for
641 * a definition is merged with the TCs arising from the user-given clauses.
642 * There can be fewer clauses than the full list, if the user omitted some
643 * cases. This routine is used to prepare input for mk_induction.
644 *---------------------------------------------------------------------------*)
645
(* Pair every pattern in full_pats with its TCs.
   - TCs is a list of (pattern, tcs); each entry fills the first
     still-empty slot whose pattern is aconv-equal.
   - Patterns without an entry keep [].
   - Fails if an entry finds no such slot. *)
fun merge full_pats TCs =
  let
    fun place (p, tcs) slots =
      let
        fun go [] = raise ERR"merge.insert" "pat not found"
          | go ((slot as (q, [])) :: rest) =
              if aconv p q then (p, tcs) :: rest else slot :: go rest
          | go (slot :: rest) = slot :: go rest
      in
        go slots
      end
    val empty_slots = map (fn p => (p, [])) full_pats
  in
    List.foldl (fn (ptcs, acc) => place ptcs acc) empty_slots TCs
  end;
658
659(*---------------------------------------------------------------------------*
660 * Define the constant after extracting the termination conditions. The      *
661 * wellfounded relation used in the definition is computed by using the      *
662 * choice operator on the extracted conditions (plus the condition that      *
663 * such a relation must be wellfounded).                                     *
664 *                                                                           *
665 * There are three flavours of recursion: standard, nested, and mutual.      *
666 *                                                                           *
667 * A "standard" recursion is one that is not mutual or nested.               *
668 *---------------------------------------------------------------------------*)
669
(* Define a standard (non-mutual, non-nested) recursion from the output of
   wfrec_eqns:
   - The constant is defined (as defPrim bindstem) by proto_def, with f
     replaced by the new constant applied to SV, and the relation variable
     replaced by R2.
   - R2 is a choice (select) term: some relation satisfying WFR and every
     universally closed TC.
   - The extracted equations are then instantiated for the new constant.
     The chosen TCs are discharged using SELECT_AX, which leaves them as
     assumptions about the relation variable R1.
   Returns the rules, the TCs paired with all patterns (for induction),
   R1, SV, the patterns and the theory. *)
fun stdrec thy bindstem {proto_def,SV,WFR,pats,extracta} =
 let val R1 = rand WFR
     val f = lhs proto_def
     val (extractants,TCl_0,_) = unzip3 extracta
     (* universally quantify every free variable of tm not in away *)
     fun gen_all away tm =
        let val FV = free_vars tm
        in itlist (fn v => fn tm =>
              if mem v away then tm else mk_forall(v,tm)) FV tm
        end
     val TCs_0 = op_U aconv TCl_0
     val TCl = map (map (gen_all (R1::SV))) TCl_0
     val TCs = op_U aconv TCl
     val full_rqt = WFR::TCs
     val R2 = mk_select(R1, list_mk_conj full_rqt)
     val R2abs = rand R2
     (* the function variable, abstracted over the schematic variables *)
     val fvar = mk_var(fst(dest_var f),
                       itlist (curry op-->) (map type_of SV) (type_of f))
     val fvar_app = list_mk_comb(fvar,SV)
     val (def,theory) = make_definition thy (defPrim bindstem)
                          (subst [f |-> fvar_app, R1 |-> R2] proto_def)
     val fconst = fst(strip_comb(lhs(snd(strip_forall(concl def)))))
     val disch'd = itlist DISCH (proto_def::WFR::TCs_0) (LIST_CONJ extractants)
     val inst'd = SPEC (list_mk_comb(fconst,SV))
                       (SPEC R2 (GENL [R1, f] disch'd))
     val def' = MP inst'd (SPEC_ALL def)
     (* Assuming the requirements of R1, SELECT_AX shows R2 meets them;
        ModusPonens then discharges each requirement from def'. *)
     val var_wits = LIST_CONJ (map ASSUME full_rqt)
     val TC_choice_thm =
           MP (CONV_RULE(BINOP_CONV BETA_CONV)(ISPECL[R2abs, R1] boolTheory.SELECT_AX)) var_wits
 in
    {theory = theory, R=R1, SV=SV,
     rules = CONJUNCTS
              (rev_itlist (C ModusPonens) (CONJUNCTS TC_choice_thm) def'),
     full_pats_TCs = merge (map pat_of pats) (zip (givens pats) TCl),
     patterns = pats}
 end
705
706
707(*---------------------------------------------------------------------------
708      Nested recursion.
709 ---------------------------------------------------------------------------*)
710
711fun nestrec thy bindstem {proto_def,SV,WFR,pats,extracta} =
712 let val R1 = rand WFR
713     val (f,rhs_proto_def) = dest_eq proto_def
714     (* make parameterized definition *)
715     val (Name,Ty) = Lib.trye dest_var f
716     val aux_name = Name^"_aux"
717     val aux_fvar = mk_var(aux_name,itlist(curry(op-->)) (map type_of (R1::SV)) Ty)
718     val aux_bindstem = auxStem bindstem
719     val (def,theory) =
720           make_definition thy (defSuffix aux_bindstem)
721               (mk_eq(list_mk_comb(aux_fvar,R1::SV), rhs_proto_def))
722     val def' = SPEC_ALL def
723     val auxFn_capp = lhs(concl def')
724     val auxFn_const = #1(strip_comb auxFn_capp)
725     val (extractants,TCl_0,_) = unzip3 extracta
726     val TCs_0 = op_U aconv TCl_0
727     val disch'd = itlist DISCH (proto_def::WFR::TCs_0) (LIST_CONJ extractants)
728     val inst'd = GEN R1 (MP (SPEC auxFn_capp (GEN f disch'd)) def')
729     fun kdisch keep th =
730       itlist (fn h => fn th => if op_mem aconv h keep then th else DISCH h th)
731              (hyp th) th
732     val disch'dl_0 = map (DISCH proto_def o
733                           DISCH WFR o kdisch [proto_def,WFR])
734                        extractants
735     val disch'dl_1 = map (fn d => MP (SPEC auxFn_capp (GEN f d)) def')
736                          disch'dl_0
737     fun gen_all away tm =
738        let val FV = free_vars tm
739        in itlist (fn v => fn tm =>
740              if mem v away then tm else mk_forall(v,tm)) FV tm
741        end
742     val TCl = map (map (gen_all (R1::f::SV) o subst[f |-> auxFn_capp])) TCl_0
743     val TCs = op_U aconv TCl
744     val full_rqt = WFR::TCs
745     val R2 = mk_select(R1, list_mk_conj full_rqt)
746     val R2abs = rand R2
747     val R2inst'd = SPEC R2 inst'd
748     val fvar = mk_var(fst(dest_var f),
749                       itlist (curry op-->) (map type_of SV) (type_of f))
750     val fvar_app = list_mk_comb(fvar,SV)
751     val (def1,theory1) = make_definition thy (defPrim bindstem)
752               (mk_eq(fvar_app, list_mk_comb(auxFn_const,R2::SV)))
753     val var_wits = LIST_CONJ (map ASSUME full_rqt)
754     val TC_choice_thm =
755         MP (CONV_RULE(BINOP_CONV BETA_CONV)
756                      (ISPECL[R2abs, R1] boolTheory.SELECT_AX))
757            var_wits
758     val elim_chosenTCs =
759           rev_itlist (C ModusPonens) (CONJUNCTS TC_choice_thm) R2inst'd
760     val rules = simplify [GSYM def1] elim_chosenTCs
761     val pat_TCs_list = merge (map pat_of pats) (zip (givens pats) TCl)
762
763     (* and now induction *)
764
765     val aux_ind = Induction.mk_induction theory1
766                       {fconst=auxFn_const, R=R1, SV=SV,
767                        pat_TCs_list=pat_TCs_list}
768     val ics = strip_conj(fst(dest_imp(snd(dest_forall(concl aux_ind)))))
769     fun dest_ic tm = if is_imp tm then strip_conj (fst(dest_imp tm)) else []
770     val ihs = Lib.flatten (map (dest_ic o snd o strip_forall) ics)
771     val nested_ihs = filter (can (find_term (aconv auxFn_const))) ihs
772     (* a nested ih is of the form
773
774           !(c1/\.../\ck ==> R a pat ==> P a)
775
776        where "aux R N" occurs in "c1/\.../\ck" or "a". In the latter case,
777        we have a nested recursion; in the former, there's just a call
778        to aux in the context. In both cases, we want to eliminate "R a pat"
779        by assuming "c1/\.../\ck ==> R a pat" and doing some work. Really,
780        what we prove is something of the form
781
782          !(c1/\.../\ck ==> R a pat) |-
783             (!(c1/\.../\ck ==> R a pat ==> P a))
784               =
785             (!(c1/\.../\ck ==> P a))
786
787        where the c1/\.../\ck might not be there (when there is no
788        context for the recursive call), and where !( ... ) denotes
789        a universal prefix.
790     *)
791     fun fSPEC_ALL th =
792       case Lib.total dest_forall (concl th) of
793           SOME (v,_) => fSPEC_ALL (SPEC v th)
794         | NONE => th
795     fun simp_nested_ih nih =
796      let val (lvs,tm) = strip_forall nih
797          val (ants,Pa) = strip_imp_only tm
798          val P = rator Pa  (* keep R, P, and SV unquantified *)
799          val vs = op_set_diff aconv (free_varsl ants) (R1::P::SV)
800          val V = op_union aconv lvs vs
801          val has_context = (length ants = 2)
802          val ng = list_mk_forall(V,list_mk_imp (front_last ants))
803          val th1 = fSPEC_ALL (ASSUME ng)
804          val th1a = if has_context then UNDISCH th1 else th1
805          val th2 = fSPEC_ALL (ASSUME nih)
806          val th2a = if has_context then UNDISCH th2 else th2
807          val Rab = fst(dest_imp(concl th2a))
808          val th3 = MP th2a th1a
809          val th4 = if has_context
810                    then DISCH (fst(dest_imp(concl th1))) th3
811                    else th3
812          val th5 = GENL lvs th4
813          val th6 = DISCH nih th5
814          val tha = fSPEC_ALL(ASSUME (concl th5))
815          val thb = if has_context then UNDISCH tha else tha
816          val thc = DISCH Rab thb
817          val thd = if has_context
818                     then DISCH (fst(dest_imp(snd(strip_forall ng)))) thc
819                     else thc
820          val the = GENL lvs thd
821          val thf = DISCH_ALL the
822      in
823        MATCH_MP (MATCH_MP IMP_ANTISYM_AX th6) thf
824      end handle e => raise wrap_exn "nestrec.simp_nested_ih"
825                       "failed while trying to generated nested ind. hyp." e
826     val nested_ih_simps = map simp_nested_ih nested_ihs
827     val ind0 = simplify nested_ih_simps aux_ind
828     val ind1 = UNDISCH_ALL (SPEC R2 (GEN R1 (DISCH_ALL ind0)))
829     val ind2 = simplify [GSYM def1] ind1
830     val ind3 = itlist ALPHA_PROVE_HYP (CONJUNCTS TC_choice_thm) ind2
831 in
832    {rules = CONJUNCTS rules,
833     ind = ind3,
834     SV = SV,
835     R = R1,
836     theory = theory1, aux_def = def, def = def1,
837     aux_rules = map UNDISCH_ALL disch'dl_1,
838     aux_ind = aux_ind
839    }
840 end;
841
842
843(*---------------------------------------------------------------------------
844      Performs tupling and also eta-expansion.
845 ---------------------------------------------------------------------------*)
846
fun tuple_args alist =
 let
   (* Rewrite a term so that every occurrence (applied or not) of a function
      listed in alist becomes an application of its tupled stand-in to the
      tuple of the (already rewritten) arguments. Under-applied occurrences
      are first eta-expanded with fresh variables named from "a", so the
      stand-in always receives a complete tuple. Lambda bodies are rewritten
      recursively; the bound variable itself is left alone. *)
   fun tupelo tm =
     case dest_term tm of
        LAMB (bv, body) => mk_abs (bv, tupelo body)
      | _ =>
         let
           val (hd_tm, rands) = strip_comb tm
           val rands' = map tupelo rands
         in
           case assoc1 hd_tm alist of
              NONE => list_mk_comb (hd_tm, rands')
            | SOME (_, (stem', argtys)) =>
               if length rands >= length argtys
                 then mk_comb (stem', list_mk_pair rands')
               else (* partial application: abstract over the missing args *)
                 let
                   val missing = drop rands argtys
                   val fresh = variants (free_varsl rands')
                                        (map (curry mk_var "a") missing)
                 in
                   list_mk_abs (fresh,
                                mk_comb (stem', list_mk_pair (rands' @ fresh)))
                 end
         end
 in
   tupelo
 end;
875
876(*---------------------------------------------------------------------------
877     Mutual recursion. This is reduced to an ordinary definition by
878     use of sum types. The n mutually recursive functions are mapped
879     to a single function "mut" having domain and range be sums of
880     the domains and ranges of the given functions. The domain sum
881     has n components. The range sum has k <= n components, built from
882     the set of range types. The arguments of the left hand side of
883     the function are uniformly injected into the domain sum. On the
884     right hand side, every occurrence of a function "f a" is translated
885     to "OUT(mut (IN a))", where IN is the compound injection function,
886     and OUT brings the result back to the original type of "f a".
887     Finally, each rhs is injected into the range sum.
888
889     After that translation, "mut" is defined. And then the individual
890     functions are defined. Rewriting then brings them out.
891
892     After that, induction is easy to recover from the induction theorem
893     for mut.
894 ---------------------------------------------------------------------------*)
895
(* ndom_rng ty n splits n argument types off the front of the function type
   ty, returning them (outermost first) paired with the remaining type.
   Fails (via dom_rng) when ty has fewer than n arrows. *)
fun ndom_rng ty n =
  if n = 0 then ([], ty)
  else
    let val (d, r) = dom_rng ty
        val (ds, final_rng) = ndom_rng r (n - 1)
    in (d :: ds, final_rng)
    end;
902
(*---------------------------------------------------------------------------
   mutrec thy bindstem eqns: define the mutually recursive functions in eqns
   through a single union function named by unionStem bindstem, following
   the scheme described above. The result record holds
     rules  - equations for the original functions, obtained from the
              union's rules by applying the matching projection to both
              sides and rewriting with OUTL, OUTR and the per-function
              definitions;
     ind    - an induction theorem with one predicate per function (named
              from P0, P1, ... and varied away from existing variables),
              derived from the union's induction via sum_case;
     SV     - the SV field reported by wfrec_eqns;
     R      - rand of the WFR field reported by wfrec_eqns;
     union  - rules/ind/theory of the union definition, plus aux rules and
              induction when the union definition is nested;
     theory - the theory after the per-function definitions are made.
 ---------------------------------------------------------------------------*)
fun mutrec thy bindstem eqns =
  let val dom_rng = Type.dom_rng
      val genvar = Term.genvar
      val DEPTH_CONV = Conv.DEPTH_CONV
      val BETA_CONV = Thm.BETA_CONV
      val OUTL = sumTheory.OUTL
      val OUTR = sumTheory.OUTR
      val sum_case_def = sumTheory.sum_case_def
      val CONJ = Thm.CONJ   (* NOTE(review): appears unused below *)
      fun dest_atom tm = (dest_var tm handle HOL_ERR _ => dest_const tm)
      (* lhs_info: each distinct lhs head paired with its argument count;
         div_tys: the matching (argument types, range type) split. *)
      val eqnl = strip_conj eqns
      val lhs_info = mk_set(map ((I##length) o strip_comb o lhs) eqnl)
      val div_tys = map (fn (tm,i) => ndom_rng (type_of tm) i) lhs_info
      val lhs_info1 = zip (map fst lhs_info) div_tys
      (* Union domain: sum of each function's tupled domain.
         Union range: sum of the distinct range types. *)
      val dom_tyl = map (list_mk_prod_type o fst) div_tys
      val rng_tyl = mk_set (map snd div_tys)
      val mut_dom = end_itlist mk_sum_type dom_tyl
      val mut_rng = end_itlist mk_sum_type rng_tyl
      val mut_name = unionStem bindstem
      val mut = mk_var(mut_name, mut_dom --> mut_rng)
      (* Functions with more than one argument get a fresh variable, with
         suffix _TUPLED, over the product of their argument types; others
         keep their own name. *)
      fun inform (f,(doml,rng)) =
        let val s = fst(dest_atom f)
        in if 1<length doml
            then (f, (mk_var(s^"_TUPLED",list_mk_prod_type doml --> rng),doml))
            else (f, (f,doml))
         end
      val eqns' = tuple_args (map inform lhs_info1) eqns
      val eqnl' = strip_conj eqns'
      val (L,R) = unzip (map dest_eq eqnl')
      val fnl' = mk_set (map (fst o strip_comb o lhs) eqnl')
      val fnvar_map = zip lhs_info1 fnl'
      (* injmap: for each tupled function, an abstraction over a genvar whose
         body comes from inject mut_dom (presumably the sum injection of that
         function's domain into mut_dom). *)
      val gvl = map genvar dom_tyl
      val gvr = map genvar rng_tyl
      val injmap = zip fnl' (map2 (C (curry mk_abs)) (inject mut_dom gvl) gvl)
      (* mut applied to the beta-reduced injection of arg for function f. *)
      fun mk_lhs_mut_app (f,arg) =
          mk_comb(mut,beta_conv (mk_comb(assoc f injmap,arg)))
      val L1 = map (mk_lhs_mut_app o dest_comb) L
      (* outfns: one abstraction per range type, built by project
         (presumably projections out of mut_rng). *)
      val gv_mut_rng = genvar mut_rng
      val outfns = map (curry mk_abs gv_mut_rng) (project rng_tyl mut_rng gv_mut_rng)
      val ty_outs = zip rng_tyl outfns
      (* now replace each f by \x. outbar(mut(inbar x)) *)
      fun fout f = (f,assoc (#2(dom_rng(type_of f))) ty_outs)
      val RNG_OUTS = map fout fnl'
      fun mk_rhs_mut f v =
          (f |-> mk_abs(v,beta_conv (mk_comb(assoc f RNG_OUTS,
                                             mk_lhs_mut_app (f,v)))))
      val R1 = map (Term.subst (map2 mk_rhs_mut fnl' gvl)) R
      val eqnl1 = zip L1 R1   (* NOTE(review): appears unused below *)
      (* Inject each translated rhs into the range sum, then beta-normalise. *)
      val rng_injmap =
            zip rng_tyl (map2 (C (curry mk_abs)) (inject mut_rng gvr) gvr)
      fun f_rng_in f = (f,assoc (#2(dom_rng(type_of f))) rng_injmap)
      val RNG_INS = map f_rng_in fnl'
      val tmp = zip (map (#1 o dest_comb) L) R1
      val R2 = map (fn (f,r) => beta_conv(mk_comb(assoc f RNG_INS, r))) tmp
      val R3 = map (rhs o concl o QCONV (DEPTH_CONV BETA_CONV)) R2
      val mut_eqns = list_mk_conj(map mk_eq (zip L1 R3))
      (* Define the union: nestrec when wfrec_eqns reports a nested call,
         otherwise stdrec followed by Induction.mk_induction. *)
      val wfrec_res = wfrec_eqns thy mut_eqns
      val defn =
        if exists I (#3(unzip3 (#extracta wfrec_res)))   (* nested *)
        then let val {rules,ind,aux_rules, aux_ind, theory, def,aux_def,...}
                     = nestrec thy mut_name wfrec_res
             in {rules=rules, ind=ind, theory=theory,
                 aux=SOME{rules=aux_rules, ind=aux_ind}}
              end
        else let val {rules,R,SV,theory,full_pats_TCs,...}
                   = stdrec thy mut_name wfrec_res
             val f = #1(dest_comb(lhs (concl(Lib.trye hd rules))))
             val ind = Induction.mk_induction theory
                         {fconst=f, R=R, SV=SV, pat_TCs_list=full_pats_TCs}
             in {rules=rules, ind=ind, theory=theory, aux=NONE}
             end
      val theory1 = #theory defn
      val mut_rules = #rules defn
      val mut_constSV = #1(dest_comb(lhs(concl (hd mut_rules))))
      val (mut_const,params) = strip_comb mut_constSV   (* only params is used *)
      (* Define the n-th original function as
           fvar params x1 ... xk = outbar (mut_constSV (inbar (x1,...,xk)))
         under the binding name defExtract(mut_name,n). Also returns
         (Uapp, outbar) so union rules can later be mapped back. *)
      fun define_subfn (n,((fvar,(argtys,rng)),ftupvar)) thy =
         let val inbar  = assoc ftupvar injmap
             val outbar = assoc ftupvar RNG_OUTS
             val (fvarname,_) = dest_atom fvar
             val defvars = rev
                  (Lib.with_flag (Globals.priming, SOME"") (variants [fvar])
                     (map (curry mk_var "x") argtys))
             val tup_defvars = list_mk_pair defvars   (* NOTE(review): unused *)
             val newty = itlist (curry (op-->)) (map type_of params@argtys) rng
             val fvar' = mk_var(fvarname,newty)
             val dlhs  = list_mk_comb(fvar',params@defvars)
             val Uapp  = mk_comb(mut_constSV,
                            beta_conv(mk_comb(inbar,list_mk_pair defvars)))
             val drhs  = beta_conv (mk_comb(outbar,Uapp))
             val thybind = defExtract(mut_name,n)
         in
           (make_definition thy thybind (mk_eq(dlhs,drhs)) , (Uapp,outbar))
         end
      fun mk_def triple (defl,thy,Uout_map) =
            let val ((d,thy'),Uout) = define_subfn triple thy
            in (d::defl, thy', Uout::Uout_map)
            end
      val (defns,theory2,Uout_map) =
            itlist mk_def (Lib.enumerate 0 fnvar_map) ([],theory1,[])
      (* Apply to both sides of th the outbar whose union application
         pattern matches the lhs of th. *)
      fun apply_outmap th =
         let fun matches (pat,_) = Lib.can (Term.match_term pat)
                                           (lhs (concl th))
             val (_,outf) = Lib.first matches Uout_map
         in AP_TERM outf th
         end
      val mut_rules1 = map apply_outmap mut_rules
      val simp = Rules.simplify (OUTL::OUTR::map GSYM defns)
      (* finally *)
      val mut_rules2 = map simp mut_rules1

      (* induction *)
      val mut_ind0 = simp (#ind defn)
      val pindices = enumerate (map fst div_tys)
      val vary = Term.variant(Term.all_varsl
                    (concl (hd mut_rules2)::hyp (hd mut_rules2)))
      (* For function i with argument types tyl: a fresh predicate P<i> and
         the paired abstraction over the lhs arguments of def, after dropping
         as many leading arguments as there are SV. *)
      fun mkP (tyl,i) def =
          let val V0 = snd(strip_comb(lhs(snd(strip_forall(concl def)))))
              val V = drop (#SV wfrec_res) V0
              val P = vary (mk_var("P"^Lib.int_to_string i,
                                   list_mk_fun_type (tyl@[bool])))
           in (P, mk_pabs(list_mk_pair V,list_mk_comb(P, V)))
          end
      val (Plist,preds) = unzip (map2 mkP pindices defns)
      (* Combine the per-function predicates into one predicate on the union
         domain using nested sum_case. *)
      val Psum_case = end_itlist (fn P => fn tm =>
           let val Pty = type_of P
               val Pdom = #1(dom_rng Pty)
               val tmty = type_of tm
               val tmdom = #1(dom_rng tmty)
               val gv = genvar (sumSyntax.mk_sum(Pdom, tmdom))
           in
              mk_abs(gv, sumSyntax.mk_sum_case(P,tm,gv))
           end) preds
      val mut_ind1 = Rules.simplify [sum_case_def] (SPEC Psum_case mut_ind0)
      val (ant,_) = dest_imp (concl mut_ind1)
      (* Specialise the union induction at each function's injected tuple of
         fresh v-variables, giving one conjunct per function. *)
      fun mkv (i,ty) = mk_var("v"^Lib.int_to_string i,ty)
      val V = map (map mkv)
                  (map (Lib.enumerate 0 o fst) pindices)
      val Vinj = map2 (fn f => fn vlist =>
                        beta_conv(mk_comb(#2 f, list_mk_pair vlist))) injmap V
      val und_mut_ind1 = UNDISCH mut_ind1
      val tmpl = map (fn (vlist,v) =>
                         GENL vlist (Rules.simplify [sum_case_def]
                                     (SPEC v und_mut_ind1)))   (zip V Vinj)
      val mut_ind2 = GENL Plist (DISCH ant (LIST_CONJ tmpl))
  in
    { rules = mut_rules2,
      ind =  mut_ind2,
      SV = #SV wfrec_res,
      R = rand (#WFR wfrec_res),
      union = defn,
      theory = theory2
    }
  end;
1056
1057
1058(*---------------------------------------------------------------------------
1059       The purpose of pairf is to translate a prospective definition
1060       into a completely tupled format. On entry to pairf, we know that
1061       f is curried, i.e., of type
1062
1063              f : ty1 -> ... -> tyn -> rangety
1064
1065       We build a tupled version of f
1066
1067              f_tupled : ty1 * ... * tyn -> rangety
1068
1069       and then make a definition
1070
1071              f x1 ... xn = f_tupled (x1,...,xn)
1072
1073       We also need to remember how to revert an induction theorem
1074       into the original domain type. This function is not used for
1075       mutual recursion, since things are more complicated there.
1076
1077 ----------------------------------------------------------------------------*)
1078
(* move_arg tm: when tm is an equation f = \v. b, return SOME of the
   equation f v = b, moving the abstracted variable onto the left-hand side.
   Returns NONE when tm is not an equation or its right-hand side is not an
   abstraction. *)
fun move_arg tm =
  case Lib.total boolSyntax.dest_eq tm of
      NONE => NONE
    | SOME (f, body) =>
        (case Lib.total Term.dest_abs body of
             NONE => NONE
           | SOME (v, b) => SOME (boolSyntax.mk_eq (Term.mk_comb (f, v), b)))
1083
(* pairf (stem, eqs0) returns (tupled equations, name to define, untuple).
   - If the head equation has exactly one argument, the equations are only
     eta-expanded (via tuple_args with f mapped to itself); the stem is
     unchanged and untuple is the identity.
   - Otherwise every occurrence of f is replaced by stem_tupled applied to a
     tuple, and untuple_args is returned to translate the resulting rules
     and induction back to curried form.
   If tupling fails with an "incompatible types" HOL_ERR and eqs0 has the
   form f = \v. b, the whole translation is retried on f v = b. *)
fun pairf (stem, eqs0) =
 let
   val ((f, args), rhs) = dest_hd_eqn eqs0
   val argslen = length args
 in
   if argslen = 1   (* not curried ... do eta-expansion *)
     then (tuple_args [(f, (f, map type_of args))] eqs0, stem, I)
   else
     let
       val stem'name = stem ^ "_tupled"
       val argtys = map type_of args
       val rng_ty = type_of rhs
       val tuple_dom = list_mk_prod_type argtys
       val stem' = mk_var (stem'name, tuple_dom --> rng_ty)
       (* Given rules and induction for the tupled constant:
          1. define f SV x1 ... xn = tuplec SV (x1,...,xn) under argMunge stem;
          2. rewrite the rules right-to-left with that definition;
          3. specialise the induction's outer predicate at the paired
             abstraction over (x1,...,xn) of Q x1 ... xn, then generalise Q;
          4. finally call Theory.delete_const on stem'name. *)
       fun untuple_args (rules, induction) =
         let
           val eq1 = concl (hd rules)
           val (lhs, rhs) = dest_eq (snd (strip_forall eq1))
           (* SV: leading arguments of the tupled constant in the first rule;
              p, the final tupled argument, is not used here. *)
           val (tuplec, args) = strip_comb lhs
           val (SV, p) = front_last args
           val defvars = rev (Lib.with_flag (Globals.priming, SOME "")
                                 (variants (f :: SV))
                                 (map (curry mk_var "x") argtys))
           val tuplecSV = list_mk_comb (tuplec, SV)
           val def_args = SV @ defvars
           val fvar =
             mk_var (atom_name f,
                     list_mk_fun_type (map type_of def_args @ [rng_ty]))
           val def  = new_definition (argMunge stem,
                        mk_eq (list_mk_comb (fvar, def_args),
                               list_mk_comb (tuplecSV, [list_mk_pair defvars])))
           val rules' = map (Rewrite.PURE_REWRITE_RULE [GSYM def]) rules
           val induction' =
             let
               val P = fst (dest_var (fst (dest_forall (concl induction))))
               val Qty = itlist (curry Type.-->) argtys Type.bool
               val Q = mk_primed_var (P, Qty)
               val tm =
                 mk_pabs (list_mk_pair defvars, list_mk_comb (Q, defvars))
             in
               GEN Q (PairedLambda.GEN_BETA_RULE
                   (SPEC tm (Rewrite.PURE_REWRITE_RULE [GSYM def] induction)))
             end
         in
           (rules', induction') before Theory.delete_const stem'name
         end
     in
       (tuple_args [(f, (stem', argtys))] eqs0, stem'name, untuple_args)
     end
 end
 handle e as HOL_ERR {message = "incompatible types", ...} =>
   case move_arg eqs0 of SOME tm => pairf (stem, tm) | NONE => raise e
1136
1137(*---------------------------------------------------------------------------*)
1138(* Abbreviation or prim. rec. definitions.                                   *)
1139(*---------------------------------------------------------------------------*)
1140
local
  (* An argument counts as a constructor pattern unless it is a variable
     or a pair. *)
  fun is_constructor tm = not (is_var tm orelse is_pair tm)
in
(* non_wfrec_defn (facts,bind,eqns): if some argument of the head equation
   is a constructor pattern, look up the type of the first such argument in
   facts and make a primitive recursive definition with that type's
   recursion axiom, returning PRIMREC with the type's induction theorem.
   Fails if the type is not in facts. With no constructor argument, a plain
   definition is made and returned as ABBREV. *)
fun non_wfrec_defn (facts,bind,eqns) =
  let
    val ((_,args),_) = dest_hd_eqn eqns
  in
    case List.find is_constructor args of
      NONE => ABBREV {eqn=new_definition (bind, eqns), bind=bind}
    | SOME pat =>
        let
          val {Thy=cthy,Tyop=cty,...} = dest_thy_type (type_of pat)
        in
          case TypeBasePure.prim_get facts (cthy,cty) of
            NONE => raise ERR "non_wfrec_defn" "unexpected lhs in definition"
          | SOME tyinfo =>
              let
                val def = Prim_rec.new_recursive_definition
                            {name=bind, def=eqns,
                             rec_axiom=TypeBasePure.axiom_of tyinfo}
                val ind = TypeBasePure.induction_of tyinfo
              in
                PRIMREC {eqs=def, ind=ind, bind=bind}
              end
        end
  end
end;
1163
(* mutrec_defn (facts,stem,eqns): package the result of mutrec as a MUTREC.
   The underlying union definition becomes a STDREC, or a NESTREC carrying
   an auxiliary STDREC when mutrec reports aux rules. Any exception is
   wrapped as coming from Defn.mutrec_defn. *)
fun mutrec_defn (facts,stem,eqns) =
 let
   (* r and i are the union's rules and induction; aux is SOME only when the
      union definition was nested. *)
   val {rules, ind, SV, R, union = {rules = r, ind = i, aux, ...}, ...} =
         mutrec facts stem eqns
   val union_defn =
     case aux of
       SOME {rules = raux, ind = iaux} =>
         NESTREC {eqs = r, ind = i, R = R, SV = SV, stem = unionStem stem,
                  aux = STDREC {eqs = raux, ind = iaux, R = R, SV = SV,
                                stem = auxStem stem}}
     | NONE =>
         STDREC {eqs = r, ind = i, R = R, SV = SV, stem = unionStem stem}
 in
   MUTREC {eqs = rules, ind = ind, R = R, SV = SV, stem = stem,
           union = union_defn}
 end
 handle e => raise wrap_exn "Defn" "mutrec_defn" e;
1186
(* nestrec_defn (thy,(stem,stem'),wfrec_res,untuple): make a nested
   well-founded definition under the tupled name stem', translate its rules
   and induction back with untuple (as produced by pairf), and return a
   NESTREC whose auxiliary STDREC is named from auxStem stem'. *)
fun nestrec_defn (thy,(stem,stem'),wfrec_res,untuple) =
  let
    val {rules, ind, SV, R, aux_rules, aux_ind, ...} =
          nestrec thy stem' wfrec_res
    val (rules', ind') = untuple (rules, ind)
    val aux_defn = STDREC {eqs = aux_rules, ind = aux_ind,
                           R = R, SV = SV, stem = auxStem stem'}
  in
    NESTREC {eqs = rules', ind = ind', R = R, SV = SV, stem = stem,
             aux = aux_defn}
  end;
1197
(* stdrec_defn (facts,(stem,stem'),wfrec_res,untuple): make a non-nested
   well-founded definition under the tupled name stem', derive its
   induction theorem, and untuple both.
   - No hypotheses on the rules: fails ("Empty hypotheses").
   - Exactly one hypothesis: it is expected to be the WF condition on the
     relation. The definition is therefore not really recursive, and that
     hypothesis is discharged with WF_EMPTY_REL, giving a NONREC. A failure
     here now reports the underlying error instead of an empty message.
   - Otherwise: a STDREC.
   Any exception is wrapped as coming from Defn.stdrec_defn. *)
fun stdrec_defn (facts,(stem,stem'),wfrec_res,untuple) =
 let val {rules,R,SV,full_pats_TCs,...} = stdrec facts stem' wfrec_res
     val ((f,_),_) = dest_hd_eqnl rules
     val ind = Induction.mk_induction facts
                  {fconst=f, R=R, SV=SV, pat_TCs_list=full_pats_TCs}
 in
 case hyp (LIST_CONJ rules)
 of []     => raise ERR "stdrec_defn" "Empty hypotheses"
  | [WF_R] =>   (* non-recursive defn via complex patterns *)
       (let (* WF_R is WF applied to a relation; instantiate WF_EMPTY_REL at
               that relation's element type so it matches the hypothesis. *)
            val (_,Rel)   = dest_comb WF_R
            val theta     = [Type.alpha |-> hd(snd(dest_type (type_of Rel)))]
            val Empty_thm = INST_TYPE theta relationTheory.WF_EMPTY_REL
            val (r1, i1)  = untuple(rules, ind)
            val r2        = MATCH_MP (DISCH_ALL (LIST_CONJ r1)) Empty_thm
            val i2        = MATCH_MP (DISCH_ALL i1) Empty_thm
        in
           NONREC {eqs = r2,
                   ind = i2, SV=SV, stem=stem}
        end
        handle HOL_ERR {origin_structure,origin_function,message} =>
          raise ERR "stdrec_defn"
                  (String.concat ["non-recursive definition failed at ",
                                  origin_structure, ".", origin_function,
                                  ": ", message]))
  | _ =>
        let val (rules', ind') = untuple (rules, ind)
        in STDREC {eqs = rules',
                   ind = ind',
                   R = R, SV=SV, stem=stem}
        end
 end
 handle e => raise wrap_exn "Defn" "stdrec_defn" e
1225
1226(*---------------------------------------------------------------------------
1227    A general, basic, interface to function definitions. First try to
1228    use standard existing machinery to make a prim. rec. definition, or
1229    an abbreviation. If those attempts fail, try to use a wellfounded
1230    definition (with pattern matching, wildcard expansion, etc.). Note
1231    that induction is derived for all wellfounded definitions, but
1232    a termination proof is not attempted. For that, use the entrypoints
1233    in TotalDefn.
1234 ---------------------------------------------------------------------------*)
1235
(* Render an exception for user messages: a HOL_ERR is shown as
   Structure.function: message, and anything else via General.exnMessage. *)
fun holexnMessage e =
  case e of
      HOL_ERR {origin_structure, origin_function, message} =>
        String.concat [origin_structure, ".", origin_function, ": ", message]
    | _ => General.exnMessage e
1239
(* An argument is simple when it is a variable or a (nested) pair whose
   components are all simple. *)
fun is_simple_arg t =
  if is_var t then true
  else
    case Lib.total dest_pair t of
        SOME (l, r) => is_simple_arg l andalso is_simple_arg r
      | NONE => false
1245
(* prim_mk_defn stem eqns: make a definition from eqns, named from stem.
   In order it tries:
   1. an abbreviation or primitive recursive definition (non_wfrec_defn);
   2. for a single function with arguments, a well-founded (possibly nested)
      definition of its tupled form (via pairf). A non-recursive single
      equation with only simple arguments that nonetheless failed step 1 is
      reported with that failure's message instead;
   3. for a nullary definition whose rhs mentions the function, moving an
      rhs abstraction onto the lhs and retrying;
   4. a mutually recursive definition when eqns define several functions.
   If a Defn error originating in Induction.mk_induction escapes and the
   classic pattern-matching heuristic is not already active, the whole
   definition is retried under that heuristic. *)
fun prim_mk_defn stem eqns =
 let
   (* err never returns: it raises a prim_mk_defn HOL_ERR. *)
   fun err s = raise ERR "prim_mk_defn" s
   val _ = Lexis.ok_identifier stem orelse
           err (String.concat [Lib.quote stem, " is not alphanumeric"])
   val facts = TypeBase.theTypeBase ()
 in
   case verdict non_wfrec_defn (fn _ => ()) (facts, defSuffix stem, eqns) of
     PASS th => th
   | FAIL (_, e) =>
     case all_fns eqns of
       [] => err "no eqns"
     | [_] => (* one defn being made *)
        let val ((f, args), rhs) = dest_hd_eqn eqns
        in
          if not (null args) then
            if not (can dest_conj eqns) andalso not (free_in f rhs)  andalso
               List.all is_simple_arg args
            then
              (* not recursive, yet failed: report the original failure *)
              err ("Simple definition failed with message: " ^
                   holexnMessage e)
            else
             let
                val (tup_eqs, stem', untuple) = pairf (stem, eqns)
                  handle HOL_ERR _ =>
                    err "failure in internal translation to tupled format"
                val wfrec_res = wfrec_eqns facts tup_eqs
              in
                if exists I (#3 (unzip3 (#extracta wfrec_res)))   (* nested *)
                  then nestrec_defn (facts, (stem, stem'), wfrec_res, untuple)
                else stdrec_defn  (facts, (stem, stem'), wfrec_res, untuple)
              end
          else if free_in f rhs then
            case move_arg eqns of
               SOME tm => prim_mk_defn stem tm
             | NONE => err "Simple nullary definition recurses"
          else
            case free_vars rhs of
               [] => err "Nullary definition failed - giving up"
             | fvs =>
                err ("Free variables (" ^
                     String.concat (Lib.commafy (map (#1 o dest_var) fvs)) ^
                     ") on RHS of nullary definition")
        end
     | (_::_::_) => (* mutrec defns being made *)
        mutrec_defn (facts, stem, eqns)
 end
 handle e as HOL_ERR {origin_structure = "Defn", message = message, ...} =>
       if not (String.isPrefix "at Induction.mk_induction" message) orelse
          PmatchHeuristics.is_classic ()
          then raise wrap_exn "Defn" "prim_mk_defn" e
       else ( Feedback.HOL_MESG "Trying classic cases heuristic..."
            ; PmatchHeuristics.with_classic_heuristic (prim_mk_defn stem) eqns)
      | e => raise wrap_exn "Defn" "prim_mk_defn" e
1301
1302(*---------------------------------------------------------------------------*)
1303(* Version of mk_defn that restores the term signature and grammar if it     *)
1304(* fails.                                                                    *)
1305(*---------------------------------------------------------------------------*)
1306
(* mk_defn stem eqns: prim_mk_defn run under both theory-extension and
   grammar-extension guards, so a failed attempt leaves neither changed. *)
fun mk_defn stem eqns =
  let
    val guarded = Theory.try_theory_extension (fn (s, e) => prim_mk_defn s e)
  in
    Parse.try_grammar_extension guarded (stem, eqns)
  end
1310
(* mk_Rdefn stem R eqs: like mk_defn, but when the resulting definition has
   a termination relation variable, instantiate it to match R. *)
fun mk_Rdefn stem R eqs =
  let val defn = mk_defn stem eqs
  in
    case reln_of defn of
        SOME Rvar => inst_defn defn (Term.match_term Rvar R)
      | NONE => defn
  end;
1317
1318(*===========================================================================*)
(* Calculating dependencies among an unordered collection of definitions     *)
1320(*===========================================================================*)
1321
1322(*---------------------------------------------------------------------------*)
1323(* Transitive closure of relation rel given as list of pairs. The TC         *)
1324(* function operates on a list of elements [...,(x,(Y,fringe)),...] where    *)
1325(* Y is the set of "seen" fringe elements, and fringe is those elements of   *)
1326(* Y that we just "arrived at" (thinking of the relation as a directed       *)
1327(* graph) and which are not already in Y. Steps in the graph are made from   *)
1328(* the fringe, in a breadth-first manner.                                    *)
1329(*---------------------------------------------------------------------------*)
1330
fun TC rels_0 =
 let
   (* Direct successors of a; elements absent from rels_0 have none. *)
   fun step a = Lib.assoc a rels_0 handle HOL_ERR _ => []
   (* One breadth-first step out of the fringe of x. *)
   fun advance (x,(seen,fringe)) =
     let val reached = U (map step fringe)
     in (x, (union seen reached, set_diff reached seen))
     end
   (* Repeat until every element's fringe is empty. *)
   fun loop rels =
     case Lib.partition (null o (snd o snd)) rels of
       (_, []) => map (fn (x,(seen,_)) => (x,seen)) rels
     | (finished, active) => loop (map advance active @ finished)
 in
   loop (map (fn (x,Y) => (x,(Y,Y))) rels_0)
 end;
1346
1347(*---------------------------------------------------------------------------*)
1348(* Transitive closure of a list of pairs.                                    *)
1349(*---------------------------------------------------------------------------*)
1350
fun trancl rel =
 let
   (* Every element that occurs on either side of some pair. *)
   val field = U (map (fn (x,y) => [x,y]) rel)
   (* Pair x with the set of its direct successors in rel. *)
   fun successors x =
     (x, rev_itlist (fn (a,b) => fn acc => if a = x then insert b acc else acc)
                    rel [])
 in
   TC (map successors field)
 end;
1361
1362(*---------------------------------------------------------------------------*)
(* Strict dependency order: a depends on b, but b does not depend on a.      *)
1364(*---------------------------------------------------------------------------*)
1365
1366fun depends_on (a,adeps) (b,bdeps) = mem b adeps andalso not (mem a bdeps);
1367
1368(*---------------------------------------------------------------------------*)
1369(* Given a transitively closed relation, topsort it into dependency order,   *)
1370(* then build the mutually dependent chunks (cliques).                       *)
1371(*---------------------------------------------------------------------------*)
1372
fun cliques_of tcrel =
 let
   (* Walk the topologically sorted list. Each head element collects every
      later element whose dependency set contains it, and the chunks are
      accumulated in reverse order. *)
   fun gather acc [] = acc
     | gather acc ((entry as (a,_)) :: rest) =
         let val (mutuals, others) =
                   Lib.partition (fn (_,bdeps) => mem a bdeps) rest
         in gather ((entry :: mutuals) :: acc) others
         end
 in
   gather [] (Lib.topsort depends_on tcrel)
 end;
1382
1383(*---------------------------------------------------------------------------
1384 Examples.
1385val ex1 =
1386 [("a","b"), ("b","c"), ("b","d"),("d","c"),
1387  ("a","e"), ("e","f"), ("f","e"), ("d","a"),
1388  ("f","g"), ("g","g"), ("g","i"), ("g","h"),
1389  ("i","h"), ("i","k"), ("h","j")];
1390
1391val ex2 =  ("a","z")::ex1;
1392val ex3 =  ("z","a")::ex1;
1393val ex4 =  ("z","c")::ex3;
1394val ex5 =  ("c","z")::ex3;
1395val ex6 =  ("c","i")::ex3;
1396
1397cliques_of (trancl ex1);
1398cliques_of (trancl ex2);
1399cliques_of (trancl ex3);
1400cliques_of (trancl ex4);
1401cliques_of (trancl ex5);
1402cliques_of (trancl ex6);
1403  --------------------------------------------------------------------------*)
1404
1405(*---------------------------------------------------------------------------*)
1406(* Find the variables free in the rhs of an equation, that don't also occur  *)
1407(* on the lhs. This helps locate calls to other functions also being defined,*)
1408(* so that the dependencies can be calculated.                               *)
1409(*---------------------------------------------------------------------------*)
1410
fun free_calls tm =
 let val (lhs_tm, rhs_tm) = dest_eq tm
     val (f, pats) = strip_comb lhs_tm
 in
   (* rhs variables not bound by the lhs patterns *)
   (f, set_diff (free_vars rhs_tm) (free_varsl pats))
 end;
1419
1420(*---------------------------------------------------------------------------*)
1421(* Find all the dependencies between a collection of equations. Mutually     *)
1422(* dependent equations end up in the same clique.                            *)
1423(*---------------------------------------------------------------------------*)
1424
fun dependencies eqns =
  let val basic = map free_calls (strip_conj eqns)
      (* Merge every entry for the same head into one, keeping the order of
         first occurrence and accumulating dependencies with union. The
         clauses for one function need not be adjacent. Previously only
         adjacent clauses were merged, so interleaved clauses produced
         duplicate heads and hence duplicate cliques. *)
      fun agglom [] = []
        | agglom ((f,deps)::rest) =
            let val (same, others) = Lib.partition (fn (g,_) => g = f) rest
                val all_deps =
                      List.foldl (fn ((_,l), acc) => union acc l) deps same
            in (f, all_deps) :: agglom others
            end
  in
    cliques_of (TC (agglom basic))
  end;
1434
1435(*---------------------------------------------------------------------------
1436 Examples.
1437
1438val eqns = ``(f x = 0) /\ (g y = 1) /\ (h a = h (a-1) + g a)``;
1439dependencies eqns;
1440
1441val eqns = ``(f [] = 0) /\ (f (h1::t1) = g(h1) + f t1) /\
1442             (g y = 1) /\ (h a = h (a-1) + g a)``;
1443dependencies eqns;
1444
1445load"stringTheory";
1446Hol_datatype `term = Var of string | App of string => term list`;
1447Hol_datatype `pterm = pVar of string | pApp of string # pterm list`;
1448
1449val eqns = ``(f (pVar s) = sing s) /\
1450             (f (pApp (a,ptl)) = onion (sing a) (g ptl)) /\
1451             (g [] = {}) /\
1452             (g (t::tlist) = onion (f t) (g tlist)) /\
1453             (sing a = {a}) /\ (onion s1 s2 = s1 UNION s2)``;
1454dependencies eqns;
1455  ---------------------------------------------------------------------------*)
1456
1457val eqn_head = head o lhs;
1458
1459(*---------------------------------------------------------------------------*)
1460(* Take a collection of equations for a set of different functions, and      *)
1461(* untangle the dependencies so that functions are defined before use.       *)
1462(* Handles (mutually) recursive dependencies.                                *)
1463(*---------------------------------------------------------------------------*)
1464
fun sort_eqns eqns =
 let
   val clauses = strip_conj eqns
   (* Conjunction of the clauses, in original order, whose head is one of
      members. *)
   fun group members =
     list_mk_conj (filter (fn c => mem (eqn_head c) members) clauses)
 in
   map (group o map fst) (dependencies eqns)
 end;
1474
1475(*---------------------------------------------------------------------------
1476 Example (adapted from ex1 above).
1477
1478val eqns =
1479 ``(f1 x = f2 (x-1) + f5 (x-1)) /\
1480   (f2 x = f3 (x-1) + f4 (x-1)) /\
1481   (f3 x = x + 2) /\
1482   (f4 x = f3 (x-1) + f1(x-2)) /\
1483   (f5 x = f6 (x-1)) /\
1484   (f6 x = f5(x-1) + f7(x-2)) /\
1485   (f7 0 = f8 3 + f9 4) /\
1486   (f7 (SUC n) = SUC n * f7 n + f9 n) /\
1487   (f8 x = f10 (x + x)) /\
1488   (f9 x = f8 x + f11 (x-4)) /\
1489   (f10 x = x MOD 2) /\
1490   (f11 x = x DIV 2)``;
1491
1492sort_eqns eqns;
1493sort_eqns (list_mk_conj (rev(strip_conj eqns)));
1494 ---------------------------------------------------------------------------*)
1495
(* mk_defns stems eqnsl: define each conjunction of equations in eqnsl, in
   order, naming the i-th definition with the i-th element of stems. After
   each definition is made, the constants it introduced are substituted for
   the corresponding function variables in the equations still waiting to be
   defined. Any failure is re-raised through wrap_exn "Defn" "mk_defns". *)
fun mk_defns stems eqnsl =
 let (* Function variables at the heads of the left-hand sides. Outer
        universal quantifiers are stripped first, exactly as when the
        introduced constants are read back in mk_defn_theta below; without
        this, quantified equations made lhs fail here. *)
     fun lhs_atoms eqns =
       mk_set (map (head o lhs o snd o strip_forall) (strip_conj eqns))
     (* Define one clique under the name stem, and build the substitution
        that maps each of its function variables to the constant introduced
        for it. NOTE(review): the pairing by map2 assumes both mk_set results
        list the functions in corresponding order. *)
     fun mk_defn_theta stem eqns =
       let val Vs = lhs_atoms eqns
           val def = mk_defn stem eqns
           val Cs = mk_set (map (head o lhs o snd o strip_forall)
                                (flatten
                                   (map (strip_conj o concl) (eqns_of def))))
           val theta = map2 (curry (op |->)) Vs Cs
       in (def, theta)
       end
    (*----------------------------------------------------------------------*)
    (* Once a definition is made, substitute the introduced constant for    *)
    (* the corresponding variables in the rest of the equations waiting to  *)
    (* be defined.                                                          *)
    (*----------------------------------------------------------------------*)
    fun mapdefn [] = []
      | mapdefn ((s,e)::t) =
         let val (defn,theta) = mk_defn_theta s e
             val t' = map (fn (s,e) => (s,subst theta e)) t
         in defn :: mapdefn t'
         end
 in
   mapdefn (zip stems eqnsl)
 end
 handle e => raise wrap_exn "Defn" "mk_defns" e;
1522
1523(*---------------------------------------------------------------------------
1524     Quotation interface to definition. This includes a pass for
1525     expansion of wildcards in patterns.
1526 ---------------------------------------------------------------------------*)
1527
(* vary s S: the first of the names s^"0", s^"1", s^"2", ... that does not
   occur in S, paired with S extended (at the front) by that name. *)
fun vary s S =
 let fun search n =
       let val candidate = s ^ Lib.int_to_string n
       in if not (mem candidate S) then (candidate, candidate :: S)
          else search (n + 1)
       end
 in search 0 end;
1534
1535(*---------------------------------------------------------------------------
1536    A wildcard is a contiguous sequence of underscores. This is
1537    somewhat cranky, we realize, but restricting to only one
1538    is not great for readability at times.
1539 ---------------------------------------------------------------------------*)
1540
(* True iff s is non-empty and every character of s is an underscore. *)
fun wildcard s =
    not (s = "") andalso not (CharVector.exists (fn c => c <> #"_") s)
1543
local open Absyn
in
(* vnames_of v S: add (with union) to S the names of the variables occurring
   in the binding structure v. For an antiquoted term, the names of all of
   its variables (as returned by all_vars) are added. Type annotations are
   looked through. *)
fun vnames_of (VAQ(_,tm)) S = union (map (fst o Term.dest_var) (all_vars tm)) S
  | vnames_of (VIDENT(_,s)) S = union [s] S
  | vnames_of (VPAIR(_,v1,v2)) S = vnames_of v1 (vnames_of v2 S)
  | vnames_of (VTYPED(_,v,_)) S = vnames_of v S

(* names_of a S: add (with union) to S the names of the identifiers, bound
   variables and antiquoted-term variables occurring in the syntax a.
   Qualified identifiers contribute nothing. An application whose operator
   is the identifier "case_arrow__magic" contributes nothing from its
   argument, so names in the first argument of a case arrow (presumably the
   pattern of a case arm -- TODO confirm) are not collected.
   Used by elim_wildcards to seed the set of names that fresh wildcard
   replacements must avoid. *)
fun names_of (AQ(_,tm)) S = union (map (fst o Term.dest_var) (all_vars tm)) S
  | names_of (IDENT(_,s)) S = union [s] S
  | names_of (APP(_,IDENT(_,"case_arrow__magic"), _)) S = S
  | names_of (APP(_,M,N)) S = names_of M (names_of N S)
  | names_of (LAM(_,v,M)) S = names_of M (vnames_of v S)
  | names_of (TYPED(_,M,_)) S = names_of M S
  | names_of (QIDENT(_,_,_)) S = S
end;
1559
local
  (* Fresh-name generator: v_vary S gives the first of "v0", "v1", ... not
     in S, together with S extended by it. *)
  val v_vary = vary "v"
  (* tm_exp tm S: replace every variable of tm whose name is a wildcard by a
     fresh variable of the same type, threading the list S of names already
     in use. Constants are left alone; abstractions are rejected because
     they may not occur in patterns. *)
  fun tm_exp tm S =
    case dest_term tm
    of VAR(s,Ty) =>
         if wildcard s then
           let val (s',S') = v_vary S in (Term.mk_var(s',Ty),S') end
         else (tm,S)
     | CONST _  => (tm,S)
     | COMB(Rator,Rand) =>
        let val (Rator',S')  = tm_exp Rator S
            val (Rand', S'') = tm_exp Rand S'
        in (mk_comb(Rator', Rand'), S'')
        end
     | LAMB _ => raise ERR "tm_exp" "abstraction in pattern"
  open Absyn
in
(* exp a S: the Absyn counterpart of tm_exp. Wildcard identifiers are
   renamed to fresh "vN" names; antiquoted terms are processed by tm_exp;
   a wildcard as a qualified identifier, and any abstraction, is rejected
   with a located error. Returns the rewritten syntax and the extended list
   of used names. *)
fun exp (AQ(locn,tm)) S =
      let val (tm',S') = tm_exp tm S in (AQ(locn,tm'),S') end
  | exp (IDENT (p as (locn,s))) S =
      if wildcard s
        then let val (s',S') = v_vary S in (IDENT(locn,s'), S') end
        else (IDENT p, S)
  | exp (QIDENT (p as (locn,s,_))) S =
      if wildcard s
       then raise ERRloc "exp" locn "wildcard in long id. in pattern"
       else (QIDENT p, S)
  | exp (APP(locn,M,N)) S =
      let val (M',S')   = exp M S
          val (N', S'') = exp N S'
      in (APP (locn,M',N'), S'')
      end
  | exp (TYPED(locn,M,pty)) S =
      let val (M',S') = exp M S in (TYPED(locn,M',pty),S') end
  | exp (LAM(locn,_,_)) _ = raise ERRloc "exp" locn "abstraction in pattern"

(* Step function for rev_itlist: expand the wildcards of one pattern asy,
   cons the result onto the accumulator asyl, and thread the used names S. *)
fun expand_wildcards asy (asyl,S) =
   let val (asy',S') = exp asy S in (asy'::asyl, S') end
end;
1599
(* Split an Absyn equation written with either "=" or "<=>" into its two
   sides. If t is neither, raise an error located at t. *)
fun multi_dest_eq t =
    let
      fun not_an_eq () =
          raise ERRloc "multi_dest_eq" (Absyn.locn_of_absyn t)
                       "Expected an equality"
      fun as_iff () =
          Absyn.dest_binop "<=>" t handle HOL_ERR _ => not_an_eq ()
    in
      Absyn.dest_eq t handle HOL_ERR _ => as_iff ()
    end
1606
local
  (* Name bound by a plain universally quantified variable; any other kind
     of binder (tuple, type-annotated, antiquoted) is rejected. *)
  fun dest_pvar (Absyn.VIDENT(_,s)) = s
    | dest_pvar other = raise ERRloc "munge" (Absyn.locn_of_vstruct other)
                                     "dest_pvar"
  (* Name of a constant or variable. *)
  fun dest_atom tm = (dest_const tm handle HOL_ERR _ => dest_var tm);
  (* Name of the function at the head of a left-hand side, looking through
     type annotations. Anything but an identifier or an antiquoted atom is
     rejected with a located error. *)
  fun dest_head (Absyn.AQ(_,tm)) = fst(dest_atom tm)
    | dest_head (Absyn.IDENT(_,s)) = s
    | dest_head (Absyn.TYPED(_,a,_)) = dest_head a
    | dest_head (Absyn.QIDENT(locn,_,_)) =
            raise ERRloc "dest_head" locn "qual. ident."
    | dest_head (Absyn.APP(locn,_,_)) =
            raise ERRloc "dest_head" locn "app. node"
    | dest_head (Absyn.LAM(locn,_,_)) =
            raise ERRloc "dest_head" locn "lam. node"
  (* Peel the type annotations off absyn. Returns them outermost first,
     together with the unannotated syntax. *)
  fun strip_tyannote0 acc absyn =
      case absyn of
        Absyn.TYPED(locn, a, ty) => strip_tyannote0 ((ty,locn)::acc) a
      | x => (List.rev acc, x)
  val strip_tyannote = strip_tyannote0 []
  (* Inverse of strip_tyannote: re-wrap a with the annotations tyl, listed
     outermost first. foldr applies the last (innermost) annotation first, so
     the original nesting and the locations of the annotations are restored.
     The previous foldl re-nested them in reverse order. *)
  fun list_mk_tyannote(tyl,a) =
      List.foldr (fn ((ty,locn),t) => Absyn.TYPED(locn,t,ty)) a tyl
in
(* munge eq (eqs,fset,V): process one equation of a definition.
   - Strip the outer universal quantifiers and the equality (= or <=>).
   - Peel type annotations off the lhs.
   - Expand wildcards in the lhs arguments, using fresh names that avoid V
     and the quantified variable names.
   - Rebuild the equation (always with Absyn.mk_eq) and cons it onto eqs.
   - Record the name of the defined function in fset (with insert).
   Returns the new accumulators, including the extended list of used names. *)
fun munge eq (eqs,fset,V) =
 let val (vlist,body) = Absyn.strip_forall eq
     val (lhs0,rhs)   = multi_dest_eq body
(*     val   _          = if exists wildcard (names_of rhs []) then
                         raise ERRloc "munge" (Absyn.locn_of_absyn rhs)
                                      "wildcards on rhs" else () *)
     val (tys, lhs)   = strip_tyannote lhs0
     val (f,pats)     = Absyn.strip_app lhs
     val (pats',V')   = rev_itlist expand_wildcards pats
                            ([],Lib.union V (map dest_pvar vlist))
     val new_lhs0     = Absyn.list_mk_app(f,rev pats')
     val new_lhs      = list_mk_tyannote(tys, new_lhs0)
     val new_eq       = Absyn.list_mk_forall(vlist, Absyn.mk_eq(new_lhs, rhs))
     val fstr         = dest_head f
 in
    (new_eq::eqs, insert fstr fset, V')
 end
end;
1647
(* elim_wildcards eqs: expand the wildcards in the lhs patterns of every
   conjunct of eqs. Fresh names avoid every name already occurring in eqs.
   Returns the rewritten conjunction (conjuncts in their original order) and
   the names of the defined functions as collected by munge. *)
fun elim_wildcards eqs =
 let val used    = names_of eqs []
     val clauses = Absyn.strip_conj eqs
     val (rev_eqns, rev_fns, _) = rev_itlist munge clauses ([], [], used)
 in
   (Absyn.list_mk_conj (rev rev_eqns), rev rev_fns)
 end;
1654
1655(*---------------------------------------------------------------------------*)
1656(* To parse a purported definition, we have to convince the parser that the  *)
1657(* names to be defined aren't constants.  We can do this using "hide".       *)
1658(* After the parsing has been done, the grammar has to be put back the way   *)
1659(* it was.  If a definition is subsequently made, this will update the       *)
1660(* grammar further (ultimately using add_const).                             *)
1661(*---------------------------------------------------------------------------*)
1662
(* non_head_idents acc alist: add to the string set acc the identifiers that
   occur in argument positions of the syntax in alist.
   - For an application, only its arguments (as given by strip_app) are
     visited, so the head identifier is not collected.
   - Type annotations are looked through.
   - String and character literals are skipped.
   - Any other kind of syntax is ignored.
   A bare identifier that is itself an element of alist IS collected. *)
fun non_head_idents acc [] = acc
  | non_head_idents acc (item :: rest) =
    (case item of
       Absyn.IDENT (_, s) =>
         let
           val literal = Lexis.is_string_literal s orelse
                         Lexis.is_char_literal s
           val acc' = if literal then acc else HOLset.add (acc, s)
         in
           non_head_idents acc' rest
         end
     | Absyn.APP _ =>
         non_head_idents acc (#2 (Absyn.strip_app item) @ rest)
     | Absyn.TYPED (_, inner, _) => non_head_idents acc (inner :: rest)
     | _ => non_head_idents acc rest)
1681
(* The identifiers occurring in argument positions of the left-hand sides of
   the (possibly universally quantified) equations in the conjunction eqs_a,
   without duplicates (they are gathered in a string set). *)
fun get_param_names eqs_a =
  let
    fun lhs_of eq = #1 (multi_dest_eq (#2 (Absyn.strip_forall eq)))
    val lhss = map lhs_of (Absyn.strip_conj eqs_a)
    val names = non_head_idents (HOLset.empty String.compare) lhss
  in
    HOLset.listItems names
  end
1688
(* True iff, according to the overloading information oinfo, the name s may
   denote at least one constant that TypeBase reports as a constructor. *)
fun is_constructor_name oinfo s =
  case Overload.info_for_name oinfo s of
    SOME {actual_ops, ...} => List.exists TypeBase.is_constructor actual_ops
  | NONE => false
1697
(* Message reporting that the types of two occurrences pv1 and pv2 of the
   same head symbol could not be unified. Both must be preterm variables,
   since dest_ptvar is applied to each. *)
fun unify_error pv1 pv2 =
  let
    open Preterm
    val (name, pty1, loc1) = dest_ptvar pv1
    val (_, pty2, loc2) = dest_ptvar pv2
    fun show_ty pty = type_to_string (Pretype.toType pty)
  in
    String.concat
      ["Couldn't unify types of head symbol ", Lib.quote name,
       " at positions ", locn.toShortString loc1,
       " and ", locn.toShortString loc2,
       " with types ", show_ty pty1, " and ", show_ty pty2]
  end
1709
(* ptdefn_freevars pt: for one (possibly universally quantified) preterm
   equation pt, return the free variables whose types must agree across all
   the equations of a definition:
   - the head variable of the lhs (as returned by head_var), and
   - the free variables of the rhs that are neither free in the lhs
     arguments (the pattern variables of this clause) nor quantified by the
     outer pforall.
   Raises a located error if the body is not an application to exactly two
   arguments, i.e. not shaped like an equation. *)
fun ptdefn_freevars pt = let
  open Preterm
  val (uvars, body) = strip_pforall pt
  val (l,r) = case strip_pcomb body of
                  (_, [l,r]) => (l,r)
                | _ => raise ERRloc "ptdefn_freevars" (locn body)
                             "Couldn't see preterm as equality"
  val (f0, args) = strip_pcomb l
  val f = head_var f0
  val lfs = op_U eq (map ptfvs args)
  val rfs = ptfvs r
  (* set difference modulo Preterm.eq; left-associative, so
     rfs \\ lfs \\ uvars = (rfs \\ lfs) \\ uvars *)
  infix \\
  fun s1 \\ s2 = op_set_diff eq s1 s2
in
  op_union eq (rfs \\ lfs \\ uvars) [f]
end
1726
(* defn_absyn_to_term a: turn the conjunction of equations a into a term.
   1. Each conjunct is converted to a preterm and type-checked (phase 1) on
      its own.
   2. Across all conjuncts, every variable returned by ptdefn_freevars is
      type-unified with the first occurrence of the same name, so that (for
      example) a function gets one type in all of its clauses. Names
      beginning with "_" are exempt.
   3. The conjuncts are joined with conjunction, overloading is resolved
      (ambiguity is reported), the result is converted to a term, case magic
      is removed and !post_process_term is applied.
   The errormonad computation is run from an empty pretype environment by
   smash; presumably smash raises on an error result -- TODO confirm. *)
fun defn_absyn_to_term a = let
  val alist = Absyn.strip_conj a
  open errormonad
  val tycheck = Preterm.typecheck_phase1 (SOME (term_to_string, type_to_string))
  (* Convert and type-check each conjunct separately, keeping the preterms. *)
  val ptsM =
      mmap
        (fn a => absyn_to_preterm a >- (fn ptm => tycheck ptm >> return ptm))
        alist
  (* env maps a variable name to its first occurrence; later occurrences of
     the same name have their type unified with that first one. *)
  fun foldthis (pv as Preterm.Var{Name,Ty,Locn}, env) =
    if String.sub(Name,0) = #"_" then return env
    else
      (case Binarymap.peek(env,Name) of
           NONE => return (Binarymap.insert(env,Name,pv))
         | SOME pv' =>
             Preterm.ptype_of pv' >- (fn pty' => Pretype.unify Ty pty') >>
             return env)
    | foldthis (_, env) = raise Fail "defn_absyn_to_term: can't happen"
  open Preterm
  (* Conjoin the checked clauses and finish elaboration into a term. *)
  fun construct_final_term pts =
    let
      val ptm = plist_mk_rbinop
                  (Antiq {Tm=boolSyntax.conjunction,Locn=locn.Loc_None})
                  pts
    in
      overloading_resolution ptm >- (fn (pt,b) =>
      report_ovl_ambiguity b     >>
      to_term pt                 >- (fn t =>
      return (t |> remove_case_magic |> !post_process_term)))
    end
  val M =
    ptsM >-
    (fn pts =>
      let
        val all_frees = op_U Preterm.eq (map ptdefn_freevars pts)
      in
        foldlM foldthis (Binarymap.mkDict String.compare) all_frees >>
        construct_final_term pts
      end)
in
  smash M Pretype.Env.empty
end
1768
(* parse_absyn absyn0: expand wildcards in absyn0, then elaborate it as a set
   of equations. While this happens, the names of the functions being defined
   and the parameter names that are not constructors are hidden from the term
   grammar, so that they parse as variables.
   The term grammar in force on entry is restored afterwards whether or not
   elaboration succeeds. The hiding itself is inside the protected region, so
   a failure while hiding also restores the grammar.
   Returns the term together with the names of the defined functions. *)
fun parse_absyn absyn0 = let
  val (absyn,fn_names) = elim_wildcards absyn0
  val oldg = term_grammar()
  val oinfo = term_grammar.overload_info oldg
  val nonconstructor_parameter_names =
      List.filter (not o is_constructor_name oinfo) (get_param_names absyn)
  fun restore() = temp_set_grammars(type_grammar(), oldg)
  (* Everything that modifies the grammar happens in here. *)
  fun hide_and_elaborate () =
      (app (ignore o Parse.hide) (nonconstructor_parameter_names @ fn_names);
       defn_absyn_to_term absyn)
  val tm = hide_and_elaborate () handle e => (restore(); raise e)
in
  restore();
  (tm, fn_names)
end;
1783
1784(*---------------------------------------------------------------------------*)
1785(* Parse a quotation. Fail if parsing or type inference or overload          *)
1786(* resolution (etc) fail. Returns a list of equations; each element in the   *)
1787(* list is a separate mutually recursive clique.                             *)
1788(*---------------------------------------------------------------------------*)
1789
(* Parse the quotation q as equations and split them into cliques with
   sort_eqns. Any exception is re-raised through
   wrap_exn "Defn" "parse_quote". *)
fun parse_quote q =
  let val (eqns, _) = parse_absyn (Parse.Absyn q)
  in sort_eqns eqns end
  handle e => raise wrap_exn "Defn" "parse_quote" e;
1793
(* Hol_defn stem q: parse q and make a definition called stem from it.
   The quotation must yield exactly one clique. Any failure is re-raised
   through wrap_exn_loc, located at q. *)
fun Hol_defn stem q =
  let
    val eqns =
        case parse_quote q of
          [single] => single
        | [] => raise ERR "Hol_defn" "no definitions"
        | _ => raise ERR "Hol_defn" "multiple definitions"
  in
    mk_defn stem eqns
  end
  handle e =>
  raise wrap_exn_loc "Defn" "Hol_defn"
           (Absyn.locn_of_absyn (Parse.Absyn q)) e;
1802
(* Hol_defns stems q: parse q into cliques and define them in order, naming
   them with stems (see mk_defns). Any failure is re-raised through
   wrap_exn_loc, located at q. *)
fun Hol_defns stems q =
  let val cliques = parse_quote q
  in
    if null cliques then raise ERR "Hol_defns" "no definition"
    else mk_defns stems cliques
  end
  handle e => raise wrap_exn_loc "Defn" "Hol_defns"
                 (Absyn.locn_of_absyn (Parse.Absyn q)) e;
1809
local
  (* Name of the head variable of the left-hand side of the first equation
     of a clique (after stripping its universal quantifiers). *)
  fun stem_of eqns =
    let val first_eqn = hd (strip_conj eqns)
        val (f, _) = strip_comb (lhs (snd (strip_forall first_eqn)))
    in fst (dest_var f) end
in
  (* Like Hol_defns, but each clique is named after the function defined by
     its first equation. *)
  fun Hol_multi_defns q =
    (case parse_quote q of
       [] => raise ERR "Hol_multi_defns" "no definition"
     | cliques => mk_defns (List.map stem_of cliques) cliques)
    handle e => raise wrap_exn_loc "Defn" "Hol_multi_defns"
                   (Absyn.locn_of_absyn (Parse.Absyn q)) e
end
1822
(* Hol_Rdefn stem Rquote eqs_quote: like Hol_defn. In addition, if reln_of
   yields a relation variable for the new defn, that variable is instantiated
   with the term parsed from Rquote at the variable's type. *)
fun Hol_Rdefn stem Rquote eqs_quote =
  let
    val defn = Hol_defn stem eqs_quote
    fun with_relation Rvar =
      let val R = Parse.typedTerm Rquote (type_of Rvar)
      in inst_defn defn (Term.match_term Rvar R) end
  in
    case reln_of defn of
      SOME Rvar => with_relation Rvar
    | NONE => defn
  end;
1832
1833(*---------------------------------------------------------------------------
1834        Goalstack-based interface to termination proof.
1835 ---------------------------------------------------------------------------*)
1836
(* mangle th hs: discharge the terms hs from th, last one first. After each
   discharge beyond the first, AND_IMP_INTRO is applied once at the top, so
   the antecedents are combined into one conjunction (presumably giving
   h1 /\ ... /\ hn ==> concl th). The order of hs therefore determines the
   order of the conjuncts. *)
fun mangle th hs =
  case hs of
    [] => th
  | [h] => DISCH h th
  | h :: rest =>
      Rewrite.PURE_ONCE_REWRITE_RULE [boolTheory.AND_IMP_INTRO]
        (DISCH h (mangle th rest));
1842
1843
1844(*---------------------------------------------------------------------------
1845    Have to take care with how the assumptions are discharged. Hence mangle.
1846 ---------------------------------------------------------------------------*)
1847
(* The constant WF of theory "relation" (presumably the well-foundedness
   predicate). get_WF uses it to pick out the WF hypothesis among the
   termination conditions. *)
val WF_tm = prim_mk_const{Name="WF",Thy="relation"};
1849
(* Separate from tmlist an element that is an application of WF_tm, using
   pluck; the result is that element paired with the remaining terms. Any
   HOL_ERR (including no such element) is reported as
   "unexpected termination condition". *)
fun get_WF tmlist =
  let fun is_WF tm = same_const WF_tm (rator tm)
  in pluck is_WF tmlist end
  handle HOL_ERR _ => raise ERR "get_WF" "unexpected termination condition";
1853
(* TC_TAC0 E I: build the termination goal for equations E and induction
   theorem I.
   - th = CONJ E I carries the termination conditions as hypotheses.
   - The hypotheses are reordered so that a WF hypothesis, if get_WF finds
     one, comes first; otherwise they are kept as given.
   - They are discharged into one implication by mangle and generalised.
   - The result is used with MATCH_MP_TAC against the goal concl th.
   Exactly one subgoal with no assumptions is expected. Returns that goal
   and a function turning a theorem for it into a theorem of concl th. *)
fun TC_TAC0 E I =
 let val th = CONJ E I
     val asl = hyp th
     (* put the WF hypothesis first, if there is one *)
     val hyps' = let val (wfr,rest) = get_WF asl
                 in wfr::rest end handle HOL_ERR _ => asl
     val tac = MATCH_MP_TAC (GEN_ALL (mangle th hyps'))
     val goal = ([],concl th)
 in
   case tac goal
    of ([([],g)],validation) => (([],g), fn th => validation [th])
     | _  => raise ERR "TC_TAC" "unexpected output"
 end;
1866
(* TC_TAC defn: the termination goal of defn (see TC_TAC0), built from its
   equations and induction theorem. If defn has no induction theorem this
   raises a HOL_ERR rather than Option.Option. tgoal0 handles only HOL_ERR,
   so the previous Option.valOf failure escaped that handler. *)
fun TC_TAC defn =
 let val E = LIST_CONJ (eqns_of defn)
     val I = case ind_of defn of
               SOME ind => ind
             | NONE => raise ERR "TC_TAC" "definition has no induction theorem"
 in
   TC_TAC0 E I
 end;
1873
(* tgoal_no_defn0 (def,ind): push onto the proof manager a goalstack for the
   termination goal of the definition theorem def and induction theorem ind.
   Fails if def has no hypotheses (no termination conditions remain).
   A HOL_ERR while building or adding the goal is re-raised through wrap_exn.
   This keeps the underlying cause in the message instead of replacing it
   with an empty one. *)
fun tgoal_no_defn0 (def,ind) =
   if null (hyp def)
   then raise ERR "tgoal" "no termination conditions"
   else let val (g,validation) = TC_TAC0 def ind
        in proofManagerLib.add (Manager.new_goalstack g validation)
        end handle e as HOL_ERR _ => raise wrap_exn "Defn" "tgoal" e;
1880
(* tgoal_no_defn0 run with proof-manager chatting switched off. *)
fun tgoal_no_defn thms =
  Lib.with_flag (proofManagerLib.chatting, false) tgoal_no_defn0 thms;
1883
(* tgoal0 defn: push onto the proof manager a goalstack for the termination
   goal of defn (see TC_TAC). Fails if defn has no termination conditions.
   A HOL_ERR while building or adding the goal is re-raised through wrap_exn.
   This keeps the underlying cause in the message instead of replacing it
   with an empty one. *)
fun tgoal0 defn =
   if null (tcs_of defn)
   then raise ERR "tgoal" "no termination conditions"
   else let val (g,validation) = TC_TAC defn
        in proofManagerLib.add (Manager.new_goalstack g validation)
        end handle e as HOL_ERR _ => raise wrap_exn "Defn" "tgoal" e;
1890
(* tgoal0 run with proof-manager chatting switched off. *)
fun tgoal d =
  Lib.with_flag (proofManagerLib.chatting, false) tgoal0 d;
1892
1893(*---------------------------------------------------------------------------
1894     The error handling here is pretty coarse.
1895 ---------------------------------------------------------------------------*)
1896
(* tprove2 tgoal0 (defn,tactic): set up the termination goal of defn with
   tgoal0, run tactic on it (the tactic is expected to finish the proof),
   take the resulting theorem and drop the goal.
   Returns the two conjuncts of the theorem: the equations and the
   induction theorem. Any failure is re-raised through
   wrap_exn "Defn" "tprove".
   The goal is dropped only if it was actually pushed. Previously a failing
   tgoal0 still triggered drop(), which popped an unrelated goal or raised
   and masked the real error. A failure of that cleanup drop no longer hides
   the original exception. *)
fun tprove2 tgoal0 (defn,tactic) =
  let
    fun fail e = raise wrap_exn "Defn" "tprove" e
    (* If this fails, nothing has been pushed onto the goal stack. *)
    val _ = tgoal0 defn handle e => fail e
    (* From here on our goal is on the stack and must be dropped. *)
    val th = (ignore (proofManagerLib.expand tactic);
              proofManagerLib.top_thm ())
             handle e =>
               ((ignore (proofManagerLib.drop ()) handle HOL_ERR _ => ());
                fail e)
    val _ = proofManagerLib.drop () handle e => fail e
  in
    (CONJUNCT1 th, CONJUNCT2 th) handle e => fail e
  end
1908
(* tprove1 tgoal0 p: tprove2 run with proof-manager chatting off. When
   computeLib.auto_import_definitions is set, the resulting equations are
   also added to computeLib, both as proved and after conversion with
   !SUC_TO_NUMERAL_DEFN_CONV_hook. *)
fun tprove1 tgoal0 p =
  let
    val result as (eqns, _) =
        Lib.with_flag (proofManagerLib.chatting, false) (tprove2 tgoal0) p
  in
    if !computeLib.auto_import_definitions then
      computeLib.add_funs
        [eqns, CONV_RULE (!SUC_TO_NUMERAL_DEFN_CONV_hook) eqns]
    else ();
    result
  end
1918
(* Termination-proof entry points:
   tprove         -- goal from a defn; registers the result with computeLib
                     (see tprove1)
   tprove0        -- goal from a defn; no computeLib registration
   tprove_no_defn -- goal from a (definition, induction) theorem pair;
                     registers with computeLib *)
fun tprove (defn, tac) = tprove1 tgoal0 (defn, tac)
fun tprove0 (defn, tac) = tprove2 tgoal0 (defn, tac)
fun tprove_no_defn (thms, tac) = tprove1 tgoal_no_defn0 (thms, tac)
1922
(* Prove termination of d with tactic t (via tprove0, so no computeLib
   registration), store the equations and induction theorem under d's name,
   and return them. *)
fun tstore_defn (d,t) =
  let val result as (def, ind) = tprove0 (d, t)
      val _ = store (name_of d, def, ind)
  in result end;
1928
1929end (* Defn *)
1930