1structure Defn :> Defn =
2struct
3
4open HolKernel Parse boolLib;
5open pairLib Rules wfrecUtils Pmatch Induction DefnBase;
6
(* Abbreviations for types used throughout this structure. *)
type thry   = TypeBasePure.typeBase
type proofs = Manager.proofs
type absyn  = Absyn.absyn;

(* Error constructor: failures raised through ERR are tagged with "Defn". *)
val ERR = mk_HOL_ERR "Defn";
(* Location-carrying variant of ERR, also tagged "Defn". *)
val ERRloc = mk_HOL_ERRloc "Defn";

(* Boolean flag, initially false; presumably toggles diagnostic output --
   confirm against its uses elsewhere in the file. *)
val monitoring = ref false;
15
16(* Interactively:
17  val const_eq_ref = ref (!Defn.const_eq_ref);
18*)
19(*---------------------------------------------------------------------------
20      Miscellaneous support
21 ---------------------------------------------------------------------------*)
22
(* Pair every element of l with its zero-based position:
   enumerate [a,b,c] = [(a,0),(b,1),(c,2)]. *)
fun enumerate l =
  let fun number _ [] = []
        | number i (x :: xs) = (x, i) :: number (i + 1) xs
  in number 0 l end;
24
(* drop prefix l: discard as many leading elements of l as prefix has.
   Raises ERR "drop" "" when l is shorter than prefix. *)
fun drop prefix l =
  case (prefix, l) of
      ([], _) => l
    | (_ :: ps, _ :: ls) => drop ps ls
    | _ => raise ERR "drop" "";
28
(* Rename each variable of vlist (left to right) away from FV and from the
   variables already renamed, using `variant`.  Note that the fresh variables
   are accumulated by consing, so the result is in the REVERSE of vlist's
   order. *)
fun variants FV vlist =
  let fun step (v, (fresh, avoid)) =
        let val v' = variant avoid v in (v' :: fresh, v' :: avoid) end
  in #1 (List.foldl step ([], FV) vlist) end;
34
(* As variants, but renaming with `numvariant`.  Each variable of vlist
   (left to right) is renamed away from FV and from earlier results; the
   result list is in the REVERSE of vlist's order. *)
fun numvariants FV vlist =
  let fun step (v, (fresh, avoid)) =
        let val v' = numvariant avoid v in (v' :: fresh, v' :: avoid) end
  in #1 (List.foldl step ([], FV) vlist) end;
40
(* Make a definition of tm named s; the theory argument is passed through. *)
fun make_definition thry s tm =
  let val def = new_definition (s, tm) in (def, thry) end
42
(* The head of an application: tm with all of its arguments stripped off. *)
fun head tm = #1 (strip_comb tm);
44
(* The distinct (modulo alpha) function heads found on the left-hand sides of
   the conjuncts of eqns; each conjunct may be universally quantified. *)
fun all_fns eqns =
  let fun fn_of clause = head (lhs (#2 (strip_forall clause)))
  in op_mk_set aconv (map fn_of (strip_conj eqns)) end;
47
(* Take the first equation of eqs (its left conjunct if eqs is a conjunction),
   strip its universal quantifiers, and return
   ((head, args) of its lhs, its rhs). *)
fun dest_hd_eqn eqs =
  let val first = case Lib.total dest_conj eqs of
                      SOME (c, _) => c
                    | NONE => eqs
      val (l, r) = dest_eq (#2 (strip_forall first))
  in (strip_comb l, r) end;
53
(* As dest_hd_eqn, but for a list of equation theorems: destructs the
   conclusion of the first one.  Raises Match on an empty list. *)
fun dest_hd_eqnl [] = raise Match
  | dest_hd_eqnl (first :: _) =
      let val (l, r) = dest_eq (#2 (strip_forall (concl first)))
      in (strip_comb l, r) end
59
(* Collect the case-constant definitions and case congruence theorems of every
   type in db whose case constant is one of constset.  Lists are in the
   reverse of listItems db order. *)
fun extract_info constset db =
    let open TypeBasePure
        fun relevant tyinfo =
            List.exists
              (fn c => same_const (case_const_of tyinfo) c
                       handle HOL_ERR _ => false)
              constset
        val hits = List.rev (List.filter relevant (listItems db))
    in {case_congs = map case_cong_of hits,
        case_rewrites = map case_def_of hits}
    end;
71
72(*---------------------------------------------------------------------------
73    Support for automatically building names to store definitions
    (and the consequences thereof) within the current theory. Somewhat
75    ad hoc, but I don't know a better way!
76 ---------------------------------------------------------------------------*)
77
(* Suffix used to name induction theorems; see indSuffix for the special
   meanings of "" and of a leading space. *)
val ind_suffix = ref "_ind";
val def_suffix = boolLib.def_suffix

(* Name builders for theorems and auxiliary constants derived from a stem. *)
fun defSuffix s     = String.concat [s, !def_suffix];
fun defPrim s       = defSuffix (String.concat [s, "_primitive"]);
fun defExtract(s,n) = defSuffix (String.concat [s, "_extract", Lib.int_to_string n]);
fun argMunge s      = defSuffix (String.concat [s, "_curried"]);
fun auxStem stem    = String.concat [stem, "_AUX"];
fun unionStem stem  = String.concat [stem, "_UNION"];
87
(* imp_elim :
     |- !P Q R. (P ==> ((P ==> Q) = (P ==> R))) = ((P ==> Q) = (P ==> R))
   Proved by two implications, th6 (left to right) and th7 (right to left),
   which are then combined with IMP_ANTISYM_RULE. *)
val imp_elim =
 let val P = mk_var("P",bool)
     val Q = mk_var("Q",bool)
     val R = mk_var("R",bool)
     val PimpQ = mk_imp(P,Q)
     val PimpR = mk_imp(P,R)
     val tm = mk_eq(PimpQ,PimpR)
     val tm1 = mk_imp(P,tm)
     (* tm1, P |- (P ==> Q) = (P ==> R) *)
     val th3 = MP (ASSUME tm1) (ASSUME P)
     (* split the equation into its two implications *)
     val th3a = EQ_MP (SPECL[PimpQ, PimpR] boolTheory.EQ_IMP_THM) th3
     val (th4,th5) = (CONJUNCT1 th3a,CONJUNCT2 th3a)
     fun pmap f (x,y) = (f x, f y)
     (* tm1, P==>Q |- P ==> R   and   tm1, P==>R |- P ==> Q *)
     val (th4a,th5a) = pmap (DISCH P o funpow 2 UNDISCH) (th4,th5)
     val th4b = DISCH PimpQ th4a
     val th5b = DISCH PimpR th5a
     (* |- (P ==> tm) ==> tm *)
     val th6 = DISCH tm1 (IMP_ANTISYM_RULE th4b th5b)
     (* |- tm ==> (P ==> tm) *)
     val th7 = DISCH tm (DISCH P (ASSUME tm))
 in GENL [P,Q,R]
         (IMP_ANTISYM_RULE th6 th7)
 end;
111
(* inject ty [v1,...,vn]: with ty a right-nested binary type
   ty1 + (ty2 + ...), build INL v1, INR (INL v2), ..., INR (... (INR vn)).
   A single value is returned unchanged.  Raises Bind if ty does not have
   two type arguments, and Match on an empty list. *)
fun inject ty vs =
  case vs of
      [] => raise Match
    | [v] => [v]
    | v :: rest =>
        let val (lty, rty) =
                case dest_type ty of
                    (_, [l, r]) => (l, r)
                  | _ => raise Bind
            val first = mk_comb (mk_const ("INL", lty --> ty), v)
            val inr_tm = mk_const ("INR", rty --> ty)
        in
          first :: map (fn t => mk_comb (inr_tm, t)) (inject rty rest)
        end
123
124
(* project pats ty M: the inverse of inject.  M has sum type ty; for each
   element of pats produce the OUTL/OUTR chain selecting that disjunct (the
   last element receives the remaining OUTR chain itself).  An empty pats
   list is an internal error. *)
fun project pats ty M =
  case pats of
      [] => raise ERR "project" "catastrophic invariant failure (eqns was empty!?)"
    | [_] => [M]
    | _ :: rest =>
        let val (lty, rty) = sumSyntax.dest_sum ty
            val mty = type_of M
            val outl = mk_comb (mk_const ("OUTL", mty --> lty), M)
            val outr = mk_comb (mk_const ("OUTR", mty --> rty), M)
        in outl :: project rest rty outr end
133
134(*---------------------------------------------------------------------------*
135 * We need a "smart" MP. th1 can be less quantified than th2, so th2 has     *
136 * to be specialized appropriately. We assume that all the "local"           *
137 * variables are quantified first.                                           *
138 *---------------------------------------------------------------------------*)
139
(* ModusPonens th1 th2: th1 has the form |- (!U. A) ==> B and th2 the form
   |- !V. A', where V may bind more variables than U.  Leading quantifiers of
   th2 whose variables are in V but not in U are specialized away, stopping
   at the first quantifier that is not.  The result is then fed to MP.
   Any HOL failure is reported as ERR "ModusPonens" "failed".  Only HOL_ERR
   is caught, so interrupts and other exceptions still propagate. *)
fun ModusPonens th1 th2 =
  let val V1 = #1(strip_forall(fst(dest_imp(concl th1))))
      val V2 = #1(strip_forall(concl th2))
      (* variables bound by th2 but not by th1's antecedent *)
      val diff = Lib.op_set_diff Term.aconv V2 V1
      fun loop th =
        if is_forall(concl th)
        then let val (Bvar,Body) = dest_forall (concl th)
             in if Lib.op_mem Term.aconv Bvar diff
                then loop (SPEC Bvar th) else th
             end
        else th
  in
    MP th1 (loop th2)
  end
  handle HOL_ERR _ => raise ERR "ModusPonens" "failed";
155
156
157(*---------------------------------------------------------------------------*)
158(* Version of PROVE_HYP that works modulo permuting outer universal quants.  *)
159(*---------------------------------------------------------------------------*)
160
(* Discharge from th2 the hypothesis proved by th1, where the two may differ
   in the order of their outer universal quantifiers.  th1 is respecialized
   and regeneralized in the hypothesis' quantifier order before PROVE_HYP.
   Fails (via Lib.first) if no hypothesis of th2 matches. *)
fun ALPHA_PROVE_HYP th1 th2 =
  let val (U, body) = strip_forall (concl th1)
      fun matches h = aconv body (#2 (strip_forall h))
      val target = Lib.first matches (hyp th2)
      val (V, _) = strip_forall target
  in PROVE_HYP (GENL V (SPECL U th1)) th2 end;
170
(* The name under which a defn is stored: its bind for ABBREV/PRIMREC, its
   stem otherwise. *)
fun name_of defn =
  case defn of
      ABBREV {bind, ...}  => bind
    | PRIMREC {bind, ...} => bind
    | NONREC {stem, ...}  => stem
    | STDREC {stem, ...}  => stem
    | MUTREC {stem, ...}  => stem
    | NESTREC {stem, ...} => stem
    | TAILREC {stem, ...} => stem
178
(* The recursion equations of a defn as a list of theorems. *)
fun eqns_of defn =
  case defn of
      ABBREV {eqn, ...}  => [eqn]
    | NONREC {eqs, ...}  => [eqs]
    | PRIMREC {eqs, ...} => [eqs]
    | STDREC {eqs, ...}  => eqs
    | NESTREC {eqs, ...} => eqs
    | MUTREC {eqs, ...}  => eqs
    | TAILREC {eqs, ...} => eqs;
186
187
(* The auxiliary defn of a nested recursion, if any. *)
fun aux_defn defn =
  case defn of
      NESTREC {aux, ...} => SOME aux
    | _ => NONE;
190
(* The union defn of a mutual recursion, if any. *)
fun union_defn defn =
  case defn of
      MUTREC {union, ...} => SOME union
    | _ => NONE;
193
(* The induction theorem of a defn; abbreviations have none. *)
fun ind_of defn =
  case defn of
      ABBREV _ => NONE
    | NONREC {ind, ...}  => SOME ind
    | PRIMREC {ind, ...} => SOME ind
    | STDREC {ind, ...}  => SOME ind
    | NESTREC {ind, ...} => SOME ind
    | MUTREC {ind, ...}  => SOME ind
    | TAILREC {ind, ...} => SOME ind;
201
202
(* The schematic variables (SV) of a defn; ABBREV and PRIMREC have none. *)
fun params_of defn =
  case defn of
      ABBREV _ => []
    | PRIMREC _ => []
    | NONREC {SV, ...}  => SV
    | STDREC {SV, ...}  => SV
    | NESTREC {SV, ...} => SV
    | MUTREC {SV, ...}  => SV
    | TAILREC {SV, ...} => SV
210
(* True when the defn has at least one schematic variable. *)
fun schematic defn =
  case params_of defn of
      [] => false
    | _ => true;
212
(* The outstanding termination conditions of a defn, i.e. the (alpha-distinct)
   union of the hypotheses of its equations.  Non-recursive kinds have none;
   tail-recursive definitions are rejected with an error. *)
fun tcs_of defn =
  let fun hyps_of eqs = op_U aconv (map hyp eqs)
  in
    case defn of
        ABBREV _ => []
      | NONREC _ => []
      | PRIMREC _ => []
      | STDREC {eqs, ...}  => hyps_of eqs
      | NESTREC {eqs, ...} => hyps_of eqs
      | MUTREC {eqs, ...}  => hyps_of eqs
      | TAILREC _ => raise ERR "tcs_of" "Tail recursive definition"
  end
220
221
(* The termination relation R of a recursive defn, if it has one. *)
fun reln_of defn =
  case defn of
      ABBREV _ => NONE
    | NONREC _ => NONE
    | PRIMREC _ => NONE
    | STDREC {R, ...}  => SOME R
    | NESTREC {R, ...} => SOME R
    | MUTREC {R, ...}  => SOME R
    | TAILREC {R, ...} => SOME R;
229
230
(* Apply UNDISCH n times (no-op for n < 1). *)
fun nUNDISCH n th = if n >= 1 then nUNDISCH (n - 1) (UNDISCH th) else th
232
(* Instantiate th, hypotheses included, with theta (a term/type substitution
   pair as taken by INST_TY_TERM).  The hypotheses are discharged, the closed
   theorem instantiated, and the antecedents undischarged again. *)
fun INST_THM theta th =
  let val hyps = hyp th
      val closed = List.foldl (fn (h, t) => DISCH h t) th hyps
  in nUNDISCH (length hyps) (INST_TY_TERM theta closed) end;
240
(* Apply the type substitution, then the term substitution, to tm. *)
fun isubst (tmtheta,tytheta) tm =
  let val tm' = inst tytheta tm in subst tmtheta tm' end;
242
(* Instantiate every theorem and term of a defn with theta, a pair of a term
   substitution and a type substitution (as used by INST_THM and isubst).
   Auxiliary (NESTREC) and union (MUTREC) defns are instantiated
   recursively; stems and binds are unchanged. *)
fun inst_defn defn theta =
  let val ith = INST_THM theta
      val itm = isubst theta
  in
    case defn of
        STDREC {eqs,ind,R,SV,stem} =>
          STDREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                  SV = map itm SV, stem = stem}
      | NESTREC {eqs,ind,R,SV,aux,stem} =>
          NESTREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                   SV = map itm SV, aux = inst_defn aux theta, stem = stem}
      | MUTREC {eqs,ind,R,SV,union,stem} =>
          MUTREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                  SV = map itm SV, union = inst_defn union theta, stem = stem}
      | PRIMREC {eqs,ind,bind} =>
          PRIMREC {eqs = ith eqs, ind = ith ind, bind = bind}
      | NONREC {eqs,ind,SV,stem} =>
          NONREC {eqs = ith eqs, ind = ith ind, SV = map itm SV, stem = stem}
      | ABBREV {eqn,bind} =>
          ABBREV {eqn = ith eqn, bind = bind}
      | TAILREC {eqs,ind,R,SV,stem} =>
          TAILREC {eqs = map ith eqs, ind = ith ind, R = itm R,
                   SV = map itm SV, stem = stem}
  end;
274
275
(* Instantiate def so that its termination relation becomes R.  The relation
   recorded in def (if any) is matched against R with Term.match_term.  The
   resulting substitution is applied to the whole defn by inst_defn.  Defns
   without a relation are returned unchanged.  If matching or instantiation
   fails, a diagnostic is printed and the original exception re-raised. *)
fun set_reln def R =
   case reln_of def
    of NONE => def
     | SOME Rpat =>
         inst_defn def (Term.match_term Rpat R)
         handle e =>
           (HOL_MESG "set_reln: unable to instantiate the definition's \
                     \relation with the given one";
            raise e);
281
(* Discharge from th every hypothesis proved by a theorem in thl. *)
fun PROVE_HYPL thl th = List.foldr (fn (h, t) => PROVE_HYP h t) th thl
283
(* Discharge each hypothesis of th that rewrites to T under thms; other
   hypotheses are kept. *)
fun MATCH_HYPL thms th =
  let fun discharge h = EQT_ELIM (REWRITE_CONV thms h)
      val proved = mapfilter discharge (hyp th)
  in List.foldr (fn (p, t) => PROVE_HYP p t) th proved end;
288
289
290(* We use PROVE_HYPL on induction theorems, since their tcs are fully
291   quantified. We use MATCH_HYPL on equations, since their tcs are
292   bare.
293*)
294
(* Remove termination conditions proved by thms from a recursive defn.
   Equations use MATCH_HYPL (their TCs are bare); induction theorems use
   PROVE_HYPL (their TCs are fully quantified).  Nested auxiliary and mutual
   union defns are processed recursively; other kinds are unchanged. *)
fun elim_tcs defn thms =
  let val eqn_rule = MATCH_HYPL thms
      val ind_rule = PROVE_HYPL thms
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {R = R, SV = SV, stem = stem,
                  eqs = map eqn_rule eqs,
                  ind = ind_rule ind}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {R = R, SV = SV, stem = stem,
                   eqs = map eqn_rule eqs,
                   ind = ind_rule ind,
                   aux = elim_tcs aux thms}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {R = R, SV = SV, stem = stem,
                  eqs = map eqn_rule eqs,
                  ind = ind_rule ind,
                  union = elim_tcs union thms}
      | TAILREC {eqs, ind, R, SV, stem} =>
          TAILREC {R = R, SV = SV, stem = stem,
                   eqs = map eqn_rule eqs,
                   ind = ind_rule ind}
      | other => other
  end;
314
315
local
 (* lem : |- (M = M1) ==> (M ==> P) ==> (M1 ==> P) *)
 val lem =
   let val m   = mk_var("M",bool)
       val p   = mk_var("P",bool)
       val m1  = mk_var("M1",bool)
       val eq  = mk_eq(m,m1)
       val imp = mk_imp(m,p)
       val rewritten = SUBS [ASSUME eq] (ASSUME imp)
   in DISCH eq (DISCH imp rewritten)
   end
in
(* Simplify hypothesis tm of th with conv.  If tm simplifies to T the
   hypothesis is removed; otherwise it is replaced by its simplified form. *)
fun simp_assum conv tm th =
  let val discharged = DISCH tm th
      val eq_th = conv tm
  in
    if not (aconv (rhs (concl eq_th)) T)
    then UNDISCH (MATCH_MP (MATCH_MP lem eq_th) discharged)
    else MP discharged (EQT_ELIM eq_th)
  end
end;
335
(* Simplify every hypothesis of th with conv (see simp_assum). *)
fun SIMP_HYPL conv th =
  List.foldr (fn (h, t) => simp_assum conv h t) th (hyp th);
337
(* Simplify the termination conditions (hypotheses) and the relation of a
   STDREC/NESTREC/MUTREC defn with conv, recursing into aux/union defns.
   Other kinds, including TAILREC, are returned unchanged. *)
fun simp_tcs defn conv =
  let fun simp_R tm = rhs (concl (conv tm))
      val rule = SIMP_HYPL conv
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {R = simp_R R, SV = SV, stem = stem,
                  eqs = map rule eqs,
                  ind = rule ind}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {R = simp_R R, SV = SV, stem = stem,
                   eqs = map rule eqs,
                   ind = rule ind,
                   aux = simp_tcs aux conv}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {R = simp_R R, SV = SV, stem = stem,
                  eqs = map rule eqs,
                  ind = rule ind,
                  union = simp_tcs union conv}
      | other => other
  end;
353
354
(* Try to prove each hypothesis of th with tac; discharge those that succeed. *)
fun TAC_HYPL tac th =
  let fun attempt h = Tactical.prove (h, tac)
  in PROVE_HYPL (mapfilter attempt (hyp th)) th end;
357
(* Attempt to prove the termination conditions of a STDREC/NESTREC/MUTREC
   defn with tac, recursing into aux/union defns.  Other kinds are returned
   unchanged. *)
fun prove_tcs defn tac =
  let val rule = TAC_HYPL tac
  in
    case defn of
        STDREC {eqs, ind, R, SV, stem} =>
          STDREC {R = R, SV = SV, stem = stem,
                  eqs = map rule eqs,
                  ind = rule ind}
      | NESTREC {eqs, ind, R, SV, aux, stem} =>
          NESTREC {R = R, SV = SV, stem = stem,
                   eqs = map rule eqs,
                   ind = rule ind,
                   aux = prove_tcs aux tac}
      | MUTREC {eqs, ind, R, SV, union, stem} =>
          MUTREC {R = R, SV = SV, stem = stem,
                  eqs = map rule eqs,
                  ind = rule ind,
                  union = prove_tcs union tac}
      | other => other
  end;
373
374
375(*---------------------------------------------------------------------------*)
376(* Deal with basic definitions.                                              *)
377(*---------------------------------------------------------------------------*)
378
(* True for the defn kinds that need no termination proof. *)
fun triv_defn defn =
  case defn of
      ABBREV _ => true
    | PRIMREC _ => true
    | _ => false
382
(* The single equation theorem of an ABBREV or PRIMREC defn; an error for
   any other kind. *)
fun fetch_eqns defn =
  case defn of
      ABBREV {eqn, ...} => eqn
    | PRIMREC {eqs, ...} => eqs
    | _ => raise ERR "fetch_eqns" "shouldn't happen"
386
387(*---------------------------------------------------------------------------
388   Store definition information to disk. Currently, just writes out the
389   eqns and induction theorem. A more advanced implementation would
390   write things out so that, when the exported theory is reloaded, the
391   defn datastructure is rebuilt. This would give a seamless view of
392   things.
393
394   Note that we would need to save union and aux info only when
395   termination has not been proved for a nested recursion.
396
397   Another (easier) way to look at it would be to require termination
398   and suchlike to be taken care of in the current theory. That is
399   what is assumed at present.
400 ---------------------------------------------------------------------------*)
401
(* is_suc tm: true iff tm is the constant SUC from theory num.
   undef s: s with a trailing "_DEF" or "_def" removed (s unchanged otherwise). *)
local fun is_suc tm =
       case total dest_thy_const tm
        of NONE => false
         | SOME{Name,Thy,...} => Name="SUC" andalso Thy="num"
      fun undef s =
          if String.isSuffix "_DEF" s orelse String.isSuffix "_def" s then
            String.substring(s,0,size s - 4)
          else s
in
(* Conversion applied to definitions whose left-hand sides mention SUC.
   It must be installed from outside; the default raises an error. *)
val SUC_TO_NUMERAL_DEFN_CONV_hook =
      ref (fn _ => raise ERR "SUC_TO_NUMERAL_DEFN_CONV_hook" "not initialized")
(* add_persistent_funs l: l is a list of (name, theorem) pairs.  Does nothing
   unless computeLib.auto_import_definitions is set.  Otherwise, for each pair
   whose theorem has a conjunct with SUC on its left-hand side, a copy
   converted by the hook is saved under "<name-without-_def>_compute".  If
   that name is already used, a numeric suffix 0, 1, ... is appended.  All
   names (original and _compute) are then passed to
   computeLib.add_persistent_funs. *)
fun add_persistent_funs l =
  if not (!computeLib.auto_import_definitions) then () else
    let val has_lhs_SUC = List.exists
              (can (find_term is_suc) o lhs o #2 o strip_forall)
                                      o strip_conj o concl
      fun f (s, th) =
        [s] @
        (if has_lhs_SUC th then let
            val name = undef s^"_compute"
            (* pick a name not among the current theory's theorems *)
            val name = let
              val used = Lib.C Lib.mem (#1 (Lib.unzip (current_theorems())))
              fun loop n = let val x = (name^(Int.toString n))
                           in if used x then loop (n+1) else x end
              in if used name then loop 0 else name end
            val th_compute = CONV_RULE (!SUC_TO_NUMERAL_DEFN_CONV_hook) th
            val _ = save_thm(name,th_compute)
            in [name] end
         else [])
    in
      computeLib.add_persistent_funs (List.concat (map f l))
    end
end;
435
(* Print s via HOL_MESG with MESG_to_string temporarily set to the identity,
   so the text is emitted without HOL_MESG's usual decoration. *)
fun mesg s = with_flag (MESG_to_string, Lib.I) HOL_MESG s
437
local
  (* Storage messages are printed only while this flag is set; it is exposed
     as the trace "Define.storage_message". *)
  val chatting = ref true
  val _ = Feedback.register_btrace("Define.storage_message", chatting)
in
(* Register the theorem thm, known under name s, with add_persistent_funs.
   If chatting, report the name.  This function does not itself save thm. *)
fun been_stored (s,thm) =
  (add_persistent_funs [(s,thm)];
   if !chatting then
     mesg ((if !Globals.interactive then
              "Definition has been stored under "
            else
              "Saved definition __ ") ^Lib.quote s^"\n")
   else ()
   )


(* The name of the induction theorem for a definition with the given stem.
   The value of !ind_suffix controls the result:
     - a string starting with a space: the rest of that string is the name of
       the induction theorem exactly;
     - "": a trailing "_def" on stem is replaced by "_ind" (and "_DEF" by
       "_IND", preserving case); otherwise "_ind" is appended;
     - anything else: the string is appended to stem.
 *)
fun indSuffix stem =
    let
      fun munge s =
        if String.isSuffix "_def" s then
          String.extract(s,0,SOME(size s - 4)) ^ "_ind"
        else if String.isSuffix "_DEF" s then
          String.extract(s,0,SOME(size s - 4)) ^ "_IND"
        else s ^ "_ind"
    in
      case !ind_suffix of
          "" => munge stem
        | s => if String.sub(s,0) = #" " then String.extract(s,1,NONE)
               else stem ^ s
    end


(* Save the equations eqs and the induction theorem ind under names derived
   from stem, and register the equations with add_persistent_funs.  If
   chatting, report both names. *)
fun store(stem,eqs,ind) =
  let val eqs_bind = defSuffix stem
      val ind_bind = indSuffix stem
      (* save_thm with the "Theory.save_thm_reporting" trace set to 0;
         presumably silences its own report, since one is printed below *)
      fun save x = Feedback.trace ("Theory.save_thm_reporting", 0) save_thm x
      val _ = save (ind_bind, ind)
      val _ = save (eqs_bind, eqs)
      (* best effort: a registration failure is reported, not fatal *)
      val _ = add_persistent_funs [(eqs_bind,eqs)]
         handle _ => HOL_MESG ("Unable to add "^eqs_bind^" to global compset")
  in
    if !chatting then
       mesg (String.concat
               (if !Globals.interactive then
                  [   "Equations stored under ", Lib.quote eqs_bind,
                   ".\nInduction stored under ", Lib.quote ind_bind, ".\n"]
                else
                  [  "Saved definition __ ", Lib.quote eqs_bind,
                   "\nSaved induction ___ ", Lib.quote ind_bind, "\n"]))
    else ()
  end
end
495
local
  (* Conjoin a list of equations, each universally closed. *)
  fun conj_closed eqs = LIST_CONJ (map GEN_ALL eqs)
in
  (* Save a defn into the current theory.  Abbreviations and primitive
     recursions are only registered (been_stored); other kinds have their
     equations and induction theorem saved via store. *)
  fun save_defn defn =
    case defn of
        ABBREV {bind, eqn, ...}    => been_stored (bind, eqn)
      | PRIMREC {bind, eqs, ...}   => been_stored (bind, eqs)
      | NONREC {eqs, ind, stem, ...}  => store (stem, eqs, ind)
      | STDREC {eqs, ind, stem, ...}  => store (stem, conj_closed eqs, ind)
      | TAILREC {eqs, ind, stem, ...} => store (stem, conj_closed eqs, ind)
      | MUTREC {eqs, ind, stem, ...}  => store (stem, conj_closed eqs, ind)
      | NESTREC {eqs, ind, stem, ...} => store (stem, conj_closed eqs, ind)
end
507
508
509(*---------------------------------------------------------------------------
510        Termination condition extraction
511 ---------------------------------------------------------------------------*)
512
(* Case rewrites for the case constants in constset, and the matching case
   congruences followed by the globally registered congruences. *)
fun extraction_thms constset thy =
  let val {case_rewrites = rws, case_congs = congs} = extract_info constset thy
  in (rws, congs @ read_congs ()) end;
517
518(*---------------------------------------------------------------------------
519         Capturing termination conditions.
520 ----------------------------------------------------------------------------*)
521
522
local
  fun quantify v body = mk_forall (v, body)
  (* elements of xs not alpha-equal to any member of ys, order preserved *)
  fun without xs ys = List.filter (fn x => not (Lib.op_mem aconv x ys)) xs
in
(* Side-condition solver handed to the RW rewriter by extract.
   restrf is the restricted-function term, f the function variable, G more
   variables to keep free, and nref a flag set to true when a nested call is
   detected.  The second argument is ignored.  tm is destructured by
   wfrecUtils.dest_relation, and two cases arise:
   - if restrf occurs in its argument, nref is set and
     ERR "solver" "nested function" is raised;
   - otherwise the termination condition TC is assumed and tm derived from
     it.  TC is tm under the conjunction of the context theorems, closed over
     its free variables other than f::G.  nref is also set if f occurs in TC. *)
fun solver (restrf,f,G,nref) _ context tm =
  let val keep_free = f :: G
      fun close t = itlist quantify (without (rev (free_vars t)) keep_free) t
      val ctxt = rev context
      val ants = if null ctxt then [] else [list_mk_conj (map concl ctxt)]
      val TC = close (list_mk_imp (ants, tm))
      val (_, arg, _) = wfrecUtils.dest_relation tm
      fun occurs c t = can (find_term (aconv c)) t
  in
    if occurs restrf arg then
      (nref := true; raise ERR "solver" "nested function")
    else
      (if occurs f TC then nref := true else ();
       if null ctxt then SPEC_ALL (ASSUME TC)
       else MATCH_MP (SPEC_ALL (ASSUME TC)) (LIST_CONJ ctxt))
  end
end;
546
(* Build a TC extractor for the function variable f.  The returned function
   takes (p, th), with p a pattern and th its instantiated recursion equation.
   It rewrites th fully with RESTRICT_LEMMA, the congruences congs and
   `solver`, then returns (rewritten theorem, its hypotheses other than
   proto_def and WFR, whether a nested call was seen). *)
fun extract FV congs f (proto_def,WFR) =
  let val R = rand WFR
      val CUT_LEM = ISPECL [f,R] relationTheory.RESTRICT_LEMMA
      val restr_fR = rator(rator(lhs(snd(dest_imp (concl (SPEC_ALL CUT_LEM))))))
  in
    fn (p, th) =>
      let val nested = ref false
          val keep_free = FV @ free_vars (concl th)
          val args = (RW.Pure [CUT_LEM],
                      RW.Context ([], RW.DONT_ADD),
                      RW.Congs congs,
                      RW.Solver (solver (mk_comb (restr_fR, p), f,
                                         keep_free, nested)))
          val th' = CONV_RULE (RW.Rewrite RW.Fully args) th
          val tcs = Lib.op_set_diff aconv (hyp th') [proto_def, WFR]
      in (th', tcs, !nested) end
  end;
564
565
566(*---------------------------------------------------------------------------*
567 * Perform TC extraction without making a definition.                        *
568 *---------------------------------------------------------------------------*)
569
(* Result of wfrec_eqns, before any constant has been defined:
     WFR       -- well-foundedness antecedent; its rand is the relation
     SV        -- schematic variables of the definition, sorted
     proto_def -- prototype defining equation for the function variable
     extracta  -- per user-given clause, extract's result:
                  (equation, extracted TCs, nested-call flag)
     pats      -- patterns produced by mk_functional *)
type wfrec_eqns_result = {WFR : term,
                          SV : term list,
                          proto_def : term,
                          extracta  : (thm * term list * bool) list,
                          pats  : pattern list}
575
(* Wrap the right-hand side of an unquantified equation in the identity
   combinator I.  Presumably this shields the rhs from the rewriting done
   during TC extraction; unprotect_thm removes the wrapper afterwards.
   Universally quantified equations are rejected. *)
fun protect_rhs eqn =
  if not (is_forall eqn) then
    mk_eq (lhs eqn, combinSyntax.mk_I (rhs eqn))
  else
    raise ERR "mk_defn"
          "Universally quantified equation as argument to well-founded \
          \recursion"

(* Apply protect_rhs to every conjunct of eqns. *)
fun protect eqns =
  let val clauses = strip_conj eqns
  in list_mk_conj (map protect_rhs clauses) end;

(* Remove I-wrappers, from a term or from a theorem. *)
val unprotect_term = rhs o concl o PURE_REWRITE_CONV [combinTheory.I_THM];
val unprotect_thm  = PURE_REWRITE_RULE [combinTheory.I_THM];
587
(* Report the schematic variables SV, if any.  Raise ERR "wfrec_eqns" if any
   of them also occurs free in a pattern, i.e. is both schematic and bound in
   the definition. *)
fun checkSV pats SV =
 let fun clause_pat (GIVEN (p, _)) = p
       | clause_pat (OMITTED (p, _)) = p
     fun names vs =
         Lib.commafy (map (Lib.quote o #1 o dest_var)
                          (Listsort.sort Term.compare vs))
     val () =
         if null SV then ()
         else HOL_MESG (String.concat
                ["Definition is schematic in the following variables:\n    ",
                 String.concat (names SV)])
     val clash = op_intersect aconv (free_varsl (map clause_pat pats)) SV
 in
   case clash of
       [] => ()
     | probs =>
         raise ERR "wfrec_eqns"
           (String.concat
              (["the following variables occur both free (schematic) ",
                "and bound in the definition: \n   "] @ names probs))
 end
608
609(*---------------------------------------------------------------------------*)
610(* Instantiate the recursion theorem and extract termination conditions,     *)
611(* but do not define the constant yet.                                       *)
612(*---------------------------------------------------------------------------*)
613
(* wfrec_eqns facts tup_eqs: build the functional for the equations tup_eqs
   (rhs protected with I), instantiate WFREC_COROLLARY with it, and extract
   TCs from each user-given clause.  No constant is defined here; see
   wfrec_eqns_result for the fields returned. *)
fun wfrec_eqns facts tup_eqs =
 let val {functional,pats} =
        mk_functional (TypeBasePure.toPmatchThry facts) (protect tup_eqs)
     val SV = free_vars functional    (* schematic variables *)
     val _ = checkSV pats SV
     (* The next four bindings are not used below.  They still fail if the
        functional is not of the form \f. \x. ... with f of function type. *)
     val (f, Body) = dest_abs functional
     val (x,_) = dest_abs Body
     val (Name, fty) = dest_var f
     val (f_dty, f_rty) = Type.dom_rng fty
     val WFREC_THM0 = ISPEC functional relationTheory.WFREC_COROLLARY
     (* relation variable renamed away from the equations' free variables *)
     val R = variant (free_vars tup_eqs) (fst(dest_forall(concl WFREC_THM0)))
     val WFREC_THM = ISPECL [R, f] WFREC_THM0
     (* the first two antecedents: the prototype definition and WF R *)
     val tmp = fst(wfrecUtils.strip_imp(concl WFREC_THM))
     val proto_def = Lib.trye hd tmp
     val WFR = Lib.trye (hd o tl) tmp
     val R1 = rand WFR
     val corollary' = funpow 2 UNDISCH WFREC_THM
     val given_pats = givens pats
     (* one instance of the corollary per user-given pattern *)
     val corollaries = map (C SPEC corollary') given_pats
     val eqns_consts = op_mk_set aconv (find_terms is_const functional)
     val (case_rewrites,congs) = extraction_thms eqns_consts facts
     val RWcnv = REWRITES_CONV (add_rewrites empty_rewrites
                                (literal_case_THM::case_rewrites))
     (* beta-reduce and unfold case constants on the rhs, then drop I *)
     val rule = unprotect_thm o
                RIGHT_CONV_RULE
                   (LIST_BETA_CONV
                    THENC REPEATC ((RWcnv THENC LIST_BETA_CONV) ORELSEC
                                   elim_triv_literal_CONV))
     val corollaries' = map rule corollaries
  in
     {proto_def=proto_def,
      SV=Listsort.sort Term.compare SV,
      WFR=WFR,
      pats=pats,
      extracta = map (extract [R1] congs f (proto_def,WFR))
                     (zip given_pats corollaries')}
  end
651
652(*---------------------------------------------------------------------------
653 * Pair patterns with termination conditions. The full list of patterns for
654 * a definition is merged with the TCs arising from the user-given clauses.
655 * There can be fewer clauses than the full list, if the user omitted some
656 * cases. This routine is used to prepare input for mk_induction.
657 *---------------------------------------------------------------------------*)
658
(* merge full_pats TCs: pair every pattern of full_pats with its TC list.
   Each (p, tcs) in TCs, taken in order, is attached to the first entry
   whose pattern is alpha-equal to p and has no TCs yet.  Patterns never
   mentioned keep [].  Raises ERR "merge.insert" if no such entry exists. *)
fun merge full_pats TCs =
  let fun attach (p, tcs) entries =
        case entries of
            [] => raise ERR"merge.insert" "pat not found"
          | (entry as (h, [])) :: rest =>
              if aconv p h then (p, tcs) :: rest
              else entry :: attach (p, tcs) rest
          | entry :: rest => entry :: attach (p, tcs) rest
      val initial = map (fn p => (p, [])) full_pats
  in
    List.foldl (fn (ptcs, acc) => attach ptcs acc) initial TCs
  end;
671
672(*---------------------------------------------------------------------------*
673 * Define the constant after extracting the termination conditions. The      *
674 * wellfounded relation used in the definition is computed by using the      *
675 * choice operator on the extracted conditions (plus the condition that      *
676 * such a relation must be wellfounded).                                     *
677 *                                                                           *
678 * There are three flavours of recursion: standard, nested, and mutual.      *
679 *                                                                           *
680 * A "standard" recursion is one that is not mutual or nested.               *
681 *---------------------------------------------------------------------------*)
682
(* stdrec thy bindstem result: define the function of a standard recursion.
   The new constant, named defPrim bindstem, is the prototype definition with
   f replaced by the constant applied to SV.  R1 is replaced by
   R2 = @R1. WFR /\ TCs.  Returns the recursion rules with the choice of R2
   eliminated: they carry WFR and the closed TCs (over R1) as hypotheses.
   Also returns the pattern/TC pairing for induction. *)
fun stdrec thy bindstem {proto_def,SV,WFR,pats,extracta} =
 let val R1 = rand WFR
     val f = lhs proto_def
     val (extractants,TCl_0,_) = unzip3 extracta
     (* universally close tm over its free variables not in `away` *)
     fun gen_all away tm =
        let val FV = free_vars tm
        in itlist (fn v => fn tm =>
              if op_mem aconv v away then tm else mk_forall(v,tm)) FV tm
        end
     val TCs_0 = op_U aconv TCl_0
     (* TCs closed over everything except the relation and schematic vars *)
     val TCl = map (map (gen_all (R1::SV))) TCl_0
     val TCs = op_U aconv TCl
     val full_rqt = WFR::TCs
     (* the chosen relation: some R1 satisfying WF and all the TCs *)
     val R2 = mk_select(R1, list_mk_conj full_rqt)
     val R2abs = rand R2
     (* the new constant takes the schematic variables as extra arguments *)
     val fvar = mk_var(fst(dest_var f),
                       itlist (curry op-->) (map type_of SV) (type_of f))
     val fvar_app = list_mk_comb(fvar,SV)
     val (def,theory) = make_definition thy (defPrim bindstem)
                          (subst [f |-> fvar_app, R1 |-> R2] proto_def)
     val fconst = fst(strip_comb(lhs(snd(strip_forall(concl def)))))
     (* extracted equations with proto_def, WFR and raw TCs as antecedents,
        instantiated to R2 and the new constant; MP discharges proto_def *)
     val disch'd = itlist DISCH (proto_def::WFR::TCs_0) (LIST_CONJ extractants)
     val inst'd = SPEC (list_mk_comb(fconst,SV))
                       (SPEC R2 (GENL [R1, f] disch'd))
     val def' = MP inst'd (SPEC_ALL def)
     (* full_rqt |- requirements hold of R2, via SELECT_AX *)
     val var_wits = LIST_CONJ (map ASSUME full_rqt)
     val TC_choice_thm =
           MP (CONV_RULE(BINOP_CONV BETA_CONV)
                        (ISPECL[R2abs, R1] boolTheory.SELECT_AX)) var_wits
 in
    {theory = theory, R=R1, SV=SV,
     (* discharge def's remaining antecedents with the chosen-relation facts *)
     rules = CONJUNCTS
              (rev_itlist (C ModusPonens) (CONJUNCTS TC_choice_thm) def'),
     full_pats_TCs = merge (map pat_of pats) (zip (givens pats) TCl),
     patterns = pats}
 end
719
720
721(*---------------------------------------------------------------------------
722      Nested recursion.
723 ---------------------------------------------------------------------------*)
724
725fun nestrec thy bindstem {proto_def,SV,WFR,pats,extracta} =
726 let val R1 = rand WFR
727     val (f,rhs_proto_def) = dest_eq proto_def
728     (* make parameterized definition *)
729     val (Name,Ty) = Lib.trye dest_var f
730     val aux_name = Name^"_aux"
731     val aux_fvar =
732         mk_var(aux_name,itlist(curry(op-->)) (map type_of (R1::SV)) Ty)
733     val aux_bindstem = auxStem bindstem
734     val (def,theory) =
735           make_definition thy (defSuffix aux_bindstem)
736               (mk_eq(list_mk_comb(aux_fvar,R1::SV), rhs_proto_def))
737     val def' = SPEC_ALL def
738     val auxFn_capp = lhs(concl def')
739     val auxFn_const = #1(strip_comb auxFn_capp)
740     val (extractants,TCl_0,_) = unzip3 extracta
741     val TCs_0 = op_U aconv TCl_0
742     val disch'd = itlist DISCH (proto_def::WFR::TCs_0) (LIST_CONJ extractants)
743     val inst'd = GEN R1 (MP (SPEC auxFn_capp (GEN f disch'd)) def')
744     fun kdisch keep th =
745       itlist (fn h => fn th => if op_mem aconv h keep then th else DISCH h th)
746              (hyp th) th
747     val disch'dl_0 = map (DISCH proto_def o
748                           DISCH WFR o kdisch [proto_def,WFR])
749                        extractants
750     val disch'dl_1 = map (fn d => MP (SPEC auxFn_capp (GEN f d)) def')
751                          disch'dl_0
752     fun gen_all away tm =
753        let val FV = free_vars tm
754        in itlist (fn v => fn tm =>
755              if op_mem aconv v away then tm else mk_forall(v,tm)) FV tm
756        end
757     val TCl = map (map (gen_all (R1::f::SV) o subst[f |-> auxFn_capp])) TCl_0
758     val TCs = op_U aconv TCl
759     val full_rqt = WFR::TCs
760     val R2 = mk_select(R1, list_mk_conj full_rqt)
761     val R2abs = rand R2
762     val R2inst'd = SPEC R2 inst'd
763     val fvar = mk_var(fst(dest_var f),
764                       itlist (curry op-->) (map type_of SV) (type_of f))
765     val fvar_app = list_mk_comb(fvar,SV)
766     val (def1,theory1) = make_definition thy (defPrim bindstem)
767               (mk_eq(fvar_app, list_mk_comb(auxFn_const,R2::SV)))
768     val var_wits = LIST_CONJ (map ASSUME full_rqt)
769     val TC_choice_thm =
770         MP (CONV_RULE(BINOP_CONV BETA_CONV)
771                      (ISPECL[R2abs, R1] boolTheory.SELECT_AX))
772            var_wits
773     val elim_chosenTCs =
774           rev_itlist (C ModusPonens) (CONJUNCTS TC_choice_thm) R2inst'd
775     val rules = simplify [GSYM def1] elim_chosenTCs
776     val pat_TCs_list = merge (map pat_of pats) (zip (givens pats) TCl)
777
778     (* and now induction *)
779
780     val aux_ind = Induction.mk_induction theory1
781                       {fconst=auxFn_const, R=R1, SV=SV,
782                        pat_TCs_list=pat_TCs_list}
783     val ics = strip_conj(fst(dest_imp(snd(dest_forall(concl aux_ind)))))
784     fun dest_ic tm = if is_imp tm then strip_conj (fst(dest_imp tm)) else []
785     val ihs = Lib.flatten (map (dest_ic o snd o strip_forall) ics)
786     val nested_ihs = filter (can (find_term (aconv auxFn_const))) ihs
787     (* a nested ih is of the form
788
789           !(c1/\.../\ck ==> R a pat ==> P a)
790
791        where "aux R N" occurs in "c1/\.../\ck" or "a". In the latter case,
792        we have a nested recursion; in the former, there's just a call
793        to aux in the context. In both cases, we want to eliminate "R a pat"
794        by assuming "c1/\.../\ck ==> R a pat" and doing some work. Really,
795        what we prove is something of the form
796
797          !(c1/\.../\ck ==> R a pat) |-
798             (!(c1/\.../\ck ==> R a pat ==> P a))
799               =
800             (!(c1/\.../\ck ==> P a))
801
802        where the c1/\.../\ck might not be there (when there is no
803        context for the recursive call), and where !( ... ) denotes
804        a universal prefix.
805     *)
806     fun fSPEC_ALL th =
807       case Lib.total dest_forall (concl th) of
808           SOME (v,_) => fSPEC_ALL (SPEC v th)
809         | NONE => th
810     fun simp_nested_ih nih =
811      let val (lvs,tm) = strip_forall nih
812          val (ants,Pa) = strip_imp_only tm
813          val P = rator Pa  (* keep R, P, and SV unquantified *)
814          val vs = op_set_diff aconv (free_varsl ants) (R1::P::SV)
815          val V = op_union aconv lvs vs
816          val has_context = (length ants = 2)
817          val ng = list_mk_forall(V,list_mk_imp (front_last ants))
818          val th1 = fSPEC_ALL (ASSUME ng)
819          val th1a = if has_context then UNDISCH th1 else th1
820          val th2 = fSPEC_ALL (ASSUME nih)
821          val th2a = if has_context then UNDISCH th2 else th2
822          val Rab = fst(dest_imp(concl th2a))
823          val th3 = MP th2a th1a
824          val th4 = if has_context
825                    then DISCH (fst(dest_imp(concl th1))) th3
826                    else th3
827          val th5 = GENL lvs th4
828          val th6 = DISCH nih th5
829          val tha = fSPEC_ALL(ASSUME (concl th5))
830          val thb = if has_context then UNDISCH tha else tha
831          val thc = DISCH Rab thb
832          val thd = if has_context
833                     then DISCH (fst(dest_imp(snd(strip_forall ng)))) thc
834                     else thc
835          val the = GENL lvs thd
836          val thf = DISCH_ALL the
837      in
838        MATCH_MP (MATCH_MP IMP_ANTISYM_AX th6) thf
839      end handle e => raise wrap_exn "nestrec.simp_nested_ih"
840                       "failed while trying to generated nested ind. hyp." e
841     val nested_ih_simps = map simp_nested_ih nested_ihs
842     val ind0 = simplify nested_ih_simps aux_ind
843     val ind1 = UNDISCH_ALL (SPEC R2 (GEN R1 (DISCH_ALL ind0)))
844     val ind2 = simplify [GSYM def1] ind1
845     val ind3 = itlist ALPHA_PROVE_HYP (CONJUNCTS TC_choice_thm) ind2
846 in
847    {rules = CONJUNCTS rules,
848     ind = ind3,
849     SV = SV,
850     R = R1,
851     theory = theory1, aux_def = def, def = def1,
852     aux_rules = map UNDISCH_ALL disch'dl_1,
853     aux_ind = aux_ind
854    }
855 end;
856
857
858(*---------------------------------------------------------------------------
859      Performs tupling and also eta-expansion.
860 ---------------------------------------------------------------------------*)
861
fun tuple_args alist =
 let
   (* lookup g = the (new head, expected argument types) recorded for g *)
   val lookup = Lib.C (op_assoc1 aconv) alist
   (* Rebuild tm bottom-up, leaving abstractions in place and rewriting
      every application whose head is listed in alist. *)
   fun walk tm =
     case dest_term tm of
        LAMB (bv, body) => mk_abs (bv, walk body)
      | _ => rebuild (strip_comb tm)
   and rebuild (hd_tm, args) =
     let val args' = map walk args
     in
       case lookup hd_tm of
          NONE => list_mk_comb (hd_tm, args')
        | SOME (stem', argtys) =>
            if length args >= length argtys then
              mk_comb (stem', list_mk_pair args')
            else
              (* partial application: eta-expand over the missing
                 arguments, using fresh names that avoid the free
                 variables of the supplied arguments *)
              let val missing = map (curry mk_var "a") (drop args argtys)
                  val fresh = variants (free_varsl args') missing
              in
                list_mk_abs (fresh,
                             mk_comb (stem', list_mk_pair (args' @ fresh)))
              end
     end
 in
   walk
 end;
890
891(*---------------------------------------------------------------------------
892     Mutual recursion. This is reduced to an ordinary definition by
893     use of sum types. The n mutually recursive functions are mapped
     to a single function "mut" whose domain and range are sums of
895     the domains and ranges of the given functions. The domain sum
896     has n components. The range sum has k <= n components, built from
897     the set of range types. The arguments of the left hand side of
898     the function are uniformly injected into the domain sum. On the
899     right hand side, every occurrence of a function "f a" is translated
900     to "OUT(mut (IN a))", where IN is the compound injection function,
901     and OUT brings the result back to the original type of "f a".
902     Finally, each rhs is injected into the range sum.
903
904     After that translation, "mut" is defined. And then the individual
905     functions are defined. Rewriting then brings them out.
906
907     After that, induction is easy to recover from the induction theorem
908     for mut.
909 ---------------------------------------------------------------------------*)
910
(* Peel n argument types off a function type: returns the list of the first
   n domain types together with the remaining (result) type. *)
fun ndom_rng ty n =
  if n = 0 then ([], ty)
  else
    let val (d, r) = dom_rng ty
        val (ds, final) = ndom_rng r (n-1)
    in (d :: ds, final)
    end;
917
(* Equality on (term, arity) pairs: same arity and alpha-convertible terms. *)
fun tmi_eq (tm1,i1:int) (tm2,i2) = if i1 <> i2 then false else aconv tm1 tm2
(*---------------------------------------------------------------------------
   mutrec thy bindstem eqns

   Defines the mutually recursive functions of eqns through a single
   function named (unionStem bindstem) whose domain and range are sum
   types. Returns the original functions' equations (mut_rules2) and
   induction theorem (mut_ind2), the SV of the union definition, R (the
   operand of the WF hypothesis returned by wfrec_eqns), the union
   definition itself ({rules,ind,theory,aux}), and the final theory.
 ---------------------------------------------------------------------------*)
fun mutrec thy bindstem eqns =
  let val dom_rng = Type.dom_rng
      val genvar = Term.genvar
      val DEPTH_CONV = Conv.DEPTH_CONV
      val BETA_CONV = Thm.BETA_CONV
      val OUTL = sumTheory.OUTL
      val OUTR = sumTheory.OUTR
      val sum_case_def = sumTheory.sum_case_def
      val CONJ = Thm.CONJ
      fun dest_atom tm = (dest_var tm handle HOL_ERR _ => dest_const tm)
      val eqnl = strip_conj eqns
      (* distinct (function, number of lhs arguments) pairs *)
      val lhs_info =
          op_mk_set tmi_eq (map ((I##length) o strip_comb o lhs) eqnl)
      (* for each function: (argument types, result type) *)
      val div_tys = map (fn (tm,i) => ndom_rng (type_of tm) i) lhs_info
      val lhs_info1 = zip (map fst lhs_info) div_tys
      val dom_tyl = map (list_mk_prod_type o fst) div_tys
      (* distinct result types: the range sum may have fewer summands *)
      val rng_tyl = mk_set (map snd div_tys)
      val mut_dom = end_itlist mk_sum_type dom_tyl
      val mut_rng = end_itlist mk_sum_type rng_tyl
      val mut_name = unionStem bindstem
      val mut = mk_var(mut_name, mut_dom --> mut_rng)
      (* functions of more than one argument get a "_TUPLED" stand-in
         taking a tuple; single-argument functions are kept as is *)
      fun inform (f,(doml,rng)) =
        let val s = fst(dest_atom f)
        in if 1<length doml
            then (f, (mk_var(s^"_TUPLED",list_mk_prod_type doml --> rng),doml))
            else (f, (f,doml))
         end
      val eqns' = tuple_args (map inform lhs_info1) eqns
      val eqnl' = strip_conj eqns'
      val (L,R) = unzip (map dest_eq eqnl')
      val fnl' = op_mk_set aconv (map (fst o strip_comb o lhs) eqnl')
      val fnvar_map = zip lhs_info1 fnl'
      val gvl = map genvar dom_tyl
      val gvr = map genvar rng_tyl
      (* each (tupled) function is paired with \x. IN x, the injection of
         its argument into the domain sum *)
      val injmap = zip fnl' (map2 (C (curry mk_abs)) (inject mut_dom gvl) gvl)
      fun mk_lhs_mut_app (f,arg) =
          mk_comb(mut,beta_conv (mk_comb(op_assoc aconv f injmap,arg)))
      (* new left-hand sides: mut (IN args) *)
      val L1 = map (mk_lhs_mut_app o dest_comb) L
      val gv_mut_rng = genvar mut_rng
      (* projections out of the range sum, one per distinct result type *)
      val outfns = map (curry mk_abs gv_mut_rng)
                       (project rng_tyl mut_rng gv_mut_rng)
      val ty_outs = zip rng_tyl outfns
      (* now replace each f by \x. outbar(mut(inbar x)) *)
      fun fout f = (f,assoc (#2(dom_rng(type_of f))) ty_outs)
      val RNG_OUTS = map fout fnl'
      fun mk_rhs_mut f v =
          (f |-> mk_abs(v,beta_conv (mk_comb(op_assoc aconv f RNG_OUTS,
                                             mk_lhs_mut_app (f,v)))))
      val R1 = map (Term.subst (map2 mk_rhs_mut fnl' gvl)) R
      val eqnl1 = zip L1 R1
      (* injections of each result type into the range sum *)
      val rng_injmap =
            zip rng_tyl (map2 (C (curry mk_abs)) (inject mut_rng gvr) gvr)
      fun f_rng_in f = (f,assoc (#2(dom_rng(type_of f))) rng_injmap)
      val RNG_INS = map f_rng_in fnl'
      val tmp = zip (map (#1 o dest_comb) L) R1
      (* inject each translated rhs into the range sum, then beta-normalise *)
      val R2 = map (fn (f,r) => beta_conv(mk_comb(op_assoc aconv f RNG_INS, r)))
                   tmp
      val R3 = map (rhs o concl o QCONV (DEPTH_CONV BETA_CONV)) R2
      val mut_eqns = list_mk_conj(map mk_eq (zip L1 R3))
      val wfrec_res = wfrec_eqns thy mut_eqns
      (* define mut itself, by nested or ordinary wellfounded recursion *)
      val defn =
        if exists I (#3(unzip3 (#extracta wfrec_res)))   (* nested *)
        then let val {rules,ind,aux_rules, aux_ind, theory, def,aux_def,...}
                     = nestrec thy mut_name wfrec_res
             in {rules=rules, ind=ind, theory=theory,
                 aux=SOME{rules=aux_rules, ind=aux_ind}}
              end
        else let val {rules,R,SV,theory,full_pats_TCs,...}
                   = stdrec thy mut_name wfrec_res
             val f = #1(dest_comb(lhs (concl(Lib.trye hd rules))))
             val ind = Induction.mk_induction theory
                         {fconst=f, R=R, SV=SV, pat_TCs_list=full_pats_TCs}
             in {rules=rules, ind=ind, theory=theory, aux=NONE}
             end
      val theory1 = #theory defn
      val mut_rules = #rules defn
      (* mut applied to its schematic parameters (SV), and those parameters *)
      val mut_constSV = #1(dest_comb(lhs(concl (hd mut_rules))))
      val (mut_const,params) = strip_comb mut_constSV
      (* Define the n-th original function as
           fvar params x1 ... xk = OUT (mut params (IN (x1,...,xk)))
         Returns the definition together with (mut application, OUT). *)
      fun define_subfn (n,((fvar,(argtys,rng)),ftupvar)) thy =
         let val inbar  = op_assoc aconv ftupvar injmap
             val outbar = op_assoc aconv ftupvar RNG_OUTS
             val (fvarname,_) = dest_atom fvar
             val defvars = rev
                             (numvariants [fvar]
                                          (map (curry mk_var "x") argtys))
             val tup_defvars = list_mk_pair defvars
             val newty = itlist (curry (op-->)) (map type_of params@argtys) rng
             val fvar' = mk_var(fvarname,newty)
             val dlhs  = list_mk_comb(fvar',params@defvars)
             val Uapp  = mk_comb(mut_constSV,
                            beta_conv(mk_comb(inbar,list_mk_pair defvars)))
             val drhs  = beta_conv (mk_comb(outbar,Uapp))
             val thybind = defExtract(mut_name,n)
         in
           (make_definition thy thybind (mk_eq(dlhs,drhs)) , (Uapp,outbar))
         end
      fun mk_def triple (defl,thy,Uout_map) =
            let val ((d,thy'),Uout) = define_subfn triple thy
            in (d::defl, thy', Uout::Uout_map)
            end
      val (defns,theory2,Uout_map) =
            itlist mk_def (Lib.enumerate 0 fnvar_map) ([],theory1,[])
      (* apply the OUT function whose mut-application pattern matches the
         lhs of th to both sides of th *)
      fun apply_outmap th =
         let fun matches (pat,_) = Lib.can (Term.match_term pat)
                                           (lhs (concl th))
             val (_,outf) = Lib.first matches Uout_map
         in AP_TERM outf th
         end
      val mut_rules1 = map apply_outmap mut_rules
      (* OUTL/OUTR cancel OUT-of-IN; the folded definitions turn
         OUT (mut (IN ...)) back into calls of the original functions *)
      val simp = Rules.simplify (OUTL::OUTR::map GSYM defns)
      (* finally *)
      val mut_rules2 = map simp mut_rules1

      (* induction *)
      val mut_ind0 = simp (#ind defn)
      val pindices = enumerate (map fst div_tys)
      val vary = Term.variant(Term.all_varsl
                    (concl (hd mut_rules2)::hyp (hd mut_rules2)))
      (* a fresh predicate Pi for the i-th function, and the paired
         abstraction \(x1,...,xk). Pi x1 ... xk over its non-SV arguments *)
      fun mkP (tyl,i) def =
          let val V0 = snd(strip_comb(lhs(snd(strip_forall(concl def)))))
              val V = drop (#SV wfrec_res) V0
              val P = vary (mk_var("P"^Lib.int_to_string i,
                                   list_mk_fun_type (tyl@[bool])))
           in (P, mk_pabs(list_mk_pair V,list_mk_comb(P, V)))
          end
      val (Plist,preds) = unzip (map2 mkP pindices defns)
      (* combine the per-function predicates into one predicate on the
         domain sum, using sum_case *)
      val Psum_case = end_itlist (fn P => fn tm =>
           let val Pty = type_of P
               val Pdom = #1(dom_rng Pty)
               val tmty = type_of tm
               val tmdom = #1(dom_rng tmty)
               val gv = genvar (sumSyntax.mk_sum(Pdom, tmdom))
           in
              mk_abs(gv, sumSyntax.mk_sum_case(P,tm,gv))
           end) preds
      val mut_ind1 = Rules.simplify [sum_case_def] (SPEC Psum_case mut_ind0)
      val (ant,_) = dest_imp (concl mut_ind1)
      fun mkv (i,ty) = mk_var("v"^Lib.int_to_string i,ty)
      val V = map (map mkv)
                  (map (Lib.enumerate 0 o fst) pindices)
      (* IN (v0,...,vk) for each function's argument variables *)
      val Vinj = map2 (fn f => fn vlist =>
                        beta_conv(mk_comb(#2 f, list_mk_pair vlist))) injmap V
      val und_mut_ind1 = UNDISCH mut_ind1
      (* one conclusion conjunct per original function *)
      val tmpl = map (fn (vlist,v) =>
                         GENL vlist (Rules.simplify [sum_case_def]
                                     (SPEC v und_mut_ind1)))   (zip V Vinj)
      val mut_ind2 = GENL Plist (DISCH ant (LIST_CONJ tmpl))
  in
    { rules = mut_rules2,
      ind =  mut_ind2,
      SV = #SV wfrec_res,
      R = rand (#WFR wfrec_res),
      union = defn,
      theory = theory2
    }
  end;
1075
1076
1077(*---------------------------------------------------------------------------
1078       The purpose of pairf is to translate a prospective definition
       into a completely tupled format. (If f takes a single argument,
       pairf only eta-expands the equations.) Otherwise f is curried,
       i.e., of type
1081
1082              f : ty1 -> ... -> tyn -> rangety
1083
1084       We build a tupled version of f
1085
1086              f_tupled : ty1 * ... * tyn -> rangety
1087
1088       and then make a definition
1089
1090              f x1 ... xn = f_tupled (x1,...,xn)
1091
1092       We also need to remember how to revert an induction theorem
1093       into the original domain type. This function is not used for
1094       mutual recursion, since things are more complicated there.
1095
1096 ----------------------------------------------------------------------------*)
1097
(* Turn an equation  f = \v. b  into  f v = b . Returns NONE when tm is not
   an equation or its right-hand side is not an abstraction. *)
fun move_arg tm =
  case Lib.total boolSyntax.dest_eq tm of
      NONE => NONE
    | SOME (f, body) =>
        (case Lib.total Term.dest_abs body of
             NONE => NONE
           | SOME (v, b) => SOME (boolSyntax.mk_eq (Term.mk_comb (f, v), b)))
1102
(*---------------------------------------------------------------------------
   pairf (stem, eqs0) returns a triple (tupled_eqns, stem_used, untuple).
   For a single-argument function, the equations are only eta-expanded
   by tuple_args, the stem is returned unchanged, and untuple is I.
   Otherwise, f is replaced by a variable named stem ^ "_tupled" of type
   ty1 * ... * tyn -> rangety, and untuple maps the (rules, induction)
   proved for the tupled function back to curried form. It does this by
   defining f in terms of the tupled constant, folding that definition
   into the theorems, and then calling Theory.delete_const on the
   "_tupled" name. If a HOL_ERR with message "incompatible types" is
   raised, a lambda on the rhs is moved to the lhs (move_arg) and
   pairf is retried.
 ---------------------------------------------------------------------------*)
fun pairf (stem, eqs0) =
 let
   val ((f, args), rhs) = dest_hd_eqn eqs0
   val argslen = length args
 in
   if argslen = 1   (* not curried ... do eta-expansion *)
     then (tuple_args [(f, (f, map type_of args))] eqs0, stem, I)
   else
     let
       val stem'name = stem ^ "_tupled"
       val argtys = map type_of args
       val rng_ty = type_of rhs
       val tuple_dom = list_mk_prod_type argtys
       val stem' = mk_var (stem'name, tuple_dom --> rng_ty)
       (* Map theorems about the tupled function back to the curried f. *)
       fun untuple_args (rules, induction) =
         let
           val eq1 = concl (hd rules)
           val (lhs, rhs) = dest_eq (snd (strip_forall eq1))
           val (tuplec, args) = strip_comb lhs
           (* the last argument is the tuple; the rest are schematic vars *)
           val (SV, p) = front_last args
           val defvars = rev (numvariants (f :: SV)
                                          (map (curry mk_var "x") argtys))
           val tuplecSV = list_mk_comb (tuplec, SV)
           val def_args = SV @ defvars
           val fvar =
             mk_var (atom_name f,
                     list_mk_fun_type (map type_of def_args @ [rng_ty]))
           (* f SV x1 ... xn = tupled SV (x1,...,xn) *)
           val def  = new_definition (argMunge stem,
                        mk_eq (list_mk_comb (fvar, def_args),
                               list_mk_comb (tuplecSV, [list_mk_pair defvars])))
           val rules' = map (Rewrite.PURE_REWRITE_RULE [GSYM def]) rules
           (* specialise the induction predicate to a paired abstraction
              over a curried predicate Q, then beta-reduce it away *)
           val induction' =
             let
               val P = fst (dest_var (fst (dest_forall (concl induction))))
               val Qty = itlist (curry Type.-->) argtys Type.bool
               val Q = mk_primed_var (P, Qty)
               val tm =
                 mk_pabs (list_mk_pair defvars, list_mk_comb (Q, defvars))
             in
               GEN Q (PairedLambda.GEN_BETA_RULE
                   (SPEC tm (Rewrite.PURE_REWRITE_RULE [GSYM def] induction)))
             end
         in
           (rules', induction') before Theory.delete_const stem'name
         end
     in
       (tuple_args [(f, (stem', argtys))] eqs0, stem'name, untuple_args)
     end
 end
 handle e as HOL_ERR {message = "incompatible types", ...} =>
   case move_arg eqs0 of SOME tm => pairf (stem, tm) | NONE => raise e
1154
1155(*---------------------------------------------------------------------------*)
1156(* Abbreviation or prim. rec. definitions.                                   *)
1157(*---------------------------------------------------------------------------*)
1158
local
  (* An lhs argument that is neither a variable nor a tuple is treated as
     a constructor pattern. *)
  fun is_constructor tm = not (is_var tm) andalso not (is_pair tm)
in
(*---------------------------------------------------------------------------
   Try the non-wellfounded routes. If some argument of the first equation
   is a constructor pattern, use Prim_rec with the recursion axiom of that
   argument's type. Otherwise make a plain abbreviation with new_definition.
 ---------------------------------------------------------------------------*)
fun non_wfrec_defn (facts,bind,eqns) =
 let val ((_,args),_) = dest_hd_eqn eqns
 in
   case List.find is_constructor args of
     NONE => ABBREV {eqn=new_definition (bind, eqns), bind=bind}
   | SOME pat =>
       let val {Thy=cthy,Tyop=cty,...} = dest_thy_type (type_of pat)
       in
         case TypeBasePure.prim_get facts (cthy,cty) of
           NONE => raise ERR "non_wfrec_defn" "unexpected lhs in definition"
         | SOME tyinfo =>
             PRIMREC{eqs = Prim_rec.new_recursive_definition
                             {name=bind, def = eqns,
                              rec_axiom = TypeBasePure.axiom_of tyinfo},
                     ind = TypeBasePure.induction_of tyinfo,
                     bind = bind}
       end
 end
end;
1181
(* Package the result of mutrec as a MUTREC defn. The union function is
   recorded as STDREC, or as NESTREC (with its auxiliary STDREC) when mutrec
   reported an auxiliary definition. *)
fun mutrec_defn (facts,stem,eqns) =
 let val {rules, ind, SV, R, union, ...} = mutrec facts stem eqns
     val {rules=r, ind=i, aux, ...} = union
     fun std eqs th stm = STDREC{eqs = eqs, ind = th, R = R, SV = SV, stem = stm}
     val union' =
       case aux of
         NONE => std r i (unionStem stem)
       | SOME {rules=raux, ind=iaux} =>
           NESTREC{eqs = r, ind = i, R = R, SV = SV, stem = unionStem stem,
                   aux = std raux iaux (auxStem stem)}
 in
   MUTREC{eqs = rules, ind = ind, R = R, SV = SV, stem = stem, union = union'}
 end
 handle e => raise wrap_exn "Defn" "mutrec_defn" e;
1204
(* Build a NESTREC defn from a nested wellfounded recursion: the main rules
   and induction are translated back by untuple; the auxiliary function is
   kept in tupled form as an STDREC. *)
fun nestrec_defn (thy,(stem,stem'),wfrec_res,untuple) =
  let val {rules, ind, SV, R, aux_rules, aux_ind, ...} =
          nestrec thy stem' wfrec_res
      val (eqs, induction) = untuple (rules, ind)
      val aux_defn = STDREC {eqs = aux_rules, ind = aux_ind,
                             R = R, SV = SV, stem = auxStem stem'}
  in
    NESTREC {eqs = eqs, ind = induction, R = R, SV = SV, stem = stem,
             aux = aux_defn}
  end;
1215
(*---------------------------------------------------------------------------
   Build a defn from an ordinary (non-nested) wellfounded recursion.
   If the rules carry exactly one hypothesis (the WF condition), then the
   definition was non-recursive but used complex patterns. In that case the
   hypothesis is discharged with WF_EMPTY_REL, giving a NONREC. Otherwise
   the result is a STDREC. All theorems are translated back by untuple.
   Raises HOL_ERR (wrapped as Defn.stdrec_defn) on failure.
 ---------------------------------------------------------------------------*)
fun stdrec_defn (facts,(stem,stem'),wfrec_res,untuple) =
 let val {rules,R,SV,full_pats_TCs,...} = stdrec facts stem' wfrec_res
     val ((f,_),_) = dest_hd_eqnl rules
     val ind = Induction.mk_induction facts
                  {fconst=f, R=R, SV=SV, pat_TCs_list=full_pats_TCs}
 in
 case hyp (LIST_CONJ rules)
 of []     => raise ERR "stdrec_defn" "Empty hypotheses"
  | [WF_R] =>   (* non-recursive defn via complex patterns *)
       (let (* WF_R is "WF R'"; instantiate the empty relation at the
               element type of R' so it can discharge the hypothesis *)
            val (_,wfR)   = dest_comb WF_R
            val theta     = [Type.alpha |-> hd(snd(dest_type (type_of wfR)))]
            val Empty_thm = INST_TYPE theta relationTheory.WF_EMPTY_REL
            val (r1, i1)  = untuple(rules, ind)
            val r2        = MATCH_MP (DISCH_ALL (LIST_CONJ r1)) Empty_thm
            val i2        = MATCH_MP (DISCH_ALL i1) Empty_thm
        in
           NONREC {eqs = r2,
                   ind = i2, SV=SV, stem=stem}
        end
        (* keep the underlying diagnostic instead of an empty message *)
        handle HOL_ERR {message, ...} =>
          raise ERR "stdrec_defn"
                  ("non-recursive definition: could not discharge the WF \
                   \hypothesis: " ^ message))
  | _ =>
        let val (rules', ind') = untuple (rules, ind)
        in STDREC {eqs = rules',
                   ind = ind',
                   R = R, SV=SV, stem=stem}
        end
 end
 handle e => raise wrap_exn "Defn" "stdrec_defn" e
1243
1244(*---------------------------------------------------------------------------
1245    A general, basic, interface to function definitions. First try to
1246    use standard existing machinery to make a prim. rec. definition, or
1247    an abbreviation. If those attempts fail, try to use a wellfounded
1248    definition (with pattern matching, wildcard expansion, etc.). Note
1249    that induction is derived for all wellfounded definitions, but
1250    a termination proof is not attempted. For that, use the entrypoints
1251    in TotalDefn.
1252 ---------------------------------------------------------------------------*)
1253
(* Render an exception as text: HOL errors as "Structure.function: message",
   anything else via General.exnMessage. *)
fun holexnMessage e =
  case e of
      HOL_ERR {origin_structure, origin_function, message} =>
        String.concat [origin_structure, ".", origin_function, ": ", message]
    | _ => General.exnMessage e
1257
(* An argument is simple when it is a variable or a (nested) tuple whose
   leaves are all variables. *)
fun is_simple_arg t =
  case Lib.total dest_pair t of
      SOME (l, r) => is_simple_arg l andalso is_simple_arg r
    | NONE => is_var t
1263
(*---------------------------------------------------------------------------
   prim_mk_defn stem eqns

   First tries non_wfrec_defn (primitive recursion or abbreviation). If that
   fails, then:
    - a single function with arguments is tupled (pairf) and defined by
      nested or ordinary wellfounded recursion. If it is non-recursive with
      only simple (variable/tuple) arguments, the original failure is
      reported instead.
    - a single nullary function that mentions itself is retried after
      moving a lambda from the rhs to the lhs (move_arg); otherwise an
      error is raised.
    - several functions are handled by mutrec_defn.
   Raises HOL_ERR (origin prim_mk_defn) for a non-alphanumeric stem, empty
   equations, and the failures above.
 ---------------------------------------------------------------------------*)
fun prim_mk_defn stem eqns =
 let
   (* err never returns: it raises, so it may be used at any type *)
   fun err s = raise ERR "prim_mk_defn" s
   val _ = Lexis.ok_identifier stem orelse
           err (String.concat [Lib.quote stem, " is not alphanumeric"])
   val facts = TypeBase.theTypeBase ()
 in
   case verdict non_wfrec_defn (fn _ => ()) (facts, defSuffix stem, eqns) of
     PASS th => th
   | FAIL (_, e) =>
     case all_fns eqns of
       [] => err "no eqns"
     | [_] => (* one defn being made *)
        let val ((f, args), rhs) = dest_hd_eqn eqns
        in
          if not (null args) then
            if not (can dest_conj eqns) andalso not (free_in f rhs)  andalso
               List.all is_simple_arg args
            then
              (* not recursive, yet failed *)
              err ("Simple definition failed with message: "^
                   holexnMessage e)
            else
             let
                val (tup_eqs, stem', untuple) = pairf (stem, eqns)
                  handle HOL_ERR _ =>
                    err "failure in internal translation to tupled format"
                val wfrec_res = wfrec_eqns facts tup_eqs
              in
                if exists I (#3 (unzip3 (#extracta wfrec_res)))   (* nested *)
                  then nestrec_defn (facts, (stem, stem'), wfrec_res, untuple)
                else stdrec_defn  (facts, (stem, stem'), wfrec_res, untuple)
              end
          else if free_in f rhs then
            case move_arg eqns of
               SOME tm => prim_mk_defn stem tm
             | NONE => err "Simple nullary definition recurses"
          else
            case free_vars rhs of
               [] => err "Nullary definition failed - giving up"
             | fvs =>
                err ("Free variables (" ^
                     String.concat (Lib.commafy (map (#1 o dest_var) fvs)) ^
                     ") on RHS of nullary definition")
        end
     | (_::_::_) => (* mutrec defns being made *)
        mutrec_defn (facts, stem, eqns)
 end
1312
1313(*---------------------------------------------------------------------------*)
1314(* Version of mk_defn that restores the term signature and grammar if it     *)
1315(* fails.                                                                    *)
1316(*---------------------------------------------------------------------------*)
1317
(* prim_mk_defn wrapped so that a failure rolls back both the theory
   signature and the parser grammar. *)
fun mk_defn stem eqns =
  let val guarded = Theory.try_theory_extension (uncurry prim_mk_defn)
  in Parse.try_grammar_extension guarded (stem, eqns)
  end
1321
(* Make a definition and, if it has a termination relation variable,
   instantiate that variable to the supplied relation R. *)
fun mk_Rdefn stem R eqs =
  let val defn = mk_defn stem eqs
  in case reln_of defn of
       SOME Rvar => inst_defn defn (Term.match_term Rvar R)
     | NONE => defn
  end;
1328
1329(*===========================================================================*)
(* Calculating dependencies among an unordered collection of definitions     *)
1331(*===========================================================================*)
1332
1333(*---------------------------------------------------------------------------*)
(* Transitive closure of a relation given as an adjacency list              *)
(* [...,(x,Y),...], where Y lists the direct successors of x. Internally,   *)
(* TC works on elements (x,(Y,fringe)), where Y is the set of nodes reached *)
(* from x so far and fringe is the subset of Y first reached in the most    *)
(* recent step (thinking of the relation as a directed graph). New steps    *)
(* are taken only from the fringe, so the search is breadth-first.          *)
1340(*---------------------------------------------------------------------------*)
1341
fun TC rels_0 =
 let (* direct successors of a node; [] if it has no outgoing edges *)
     fun succs a = op_assoc aconv a rels_0 handle HOL_ERR _ => []
     (* advance one node's search by one breadth-first layer *)
     fun advance (x,(seen,frontier)) =
       let val reached = op_U aconv (map succs frontier)
       in (x, (op_union aconv seen reached, op_set_diff aconv reached seen))
       end
     fun finished (_,(_,frontier)) = null frontier
     fun loop rels =
       let val (done, active) = Lib.partition finished rels
       in if null active then map (fn (x,(seen,_)) => (x,seen)) rels
          else loop (map advance active @ done)
       end
 in
   loop (map (fn (x,Y) => (x,(Y,Y))) rels_0)
 end;
1358
1359(*---------------------------------------------------------------------------*)
1360(* Transitive closure of a list of pairs.                                    *)
1361(*---------------------------------------------------------------------------*)
1362
(* Build the adjacency list of every node mentioned in rel, then close it. *)
fun trancl rel =
 let val nodes = op_U aconv (map (fn (x,y) => [x,y]) rel)
     fun successors x =
       List.foldl (fn ((a,b), acc) =>
                     if aconv a x then op_insert aconv b acc else acc)
                  [] rel
 in
   TC (map (fn x => (x, successors x)) nodes)
 end;
1373
1374(*---------------------------------------------------------------------------*)
(* a depends on b but not vice versa: the order used for topological sort. *)
1376(*---------------------------------------------------------------------------*)
1377
(* a depends on b, but b does not depend on a. *)
fun depends_on (a,adeps) (b,bdeps) =
  let val a_needs_b = op_mem aconv b adeps
  in a_needs_b andalso not (op_mem aconv a bdeps)
  end;
1380
1381(*---------------------------------------------------------------------------*)
1382(* Given a transitively closed relation, topsort it into dependency order,   *)
1383(* then build the mutually dependent chunks (cliques).                       *)
1384(*---------------------------------------------------------------------------*)
1385
fun cliques_of tcrel =
 let (* Take the head element, collect every remaining element whose
        dependencies include it, emit them as one clique, and continue. *)
     fun group acc [] = acc
       | group acc ((a,adeps)::rest) =
           let val (mutual, others) =
                   Lib.partition (fn (_,bdeps) => op_mem aconv a bdeps) rest
           in group (((a,adeps)::mutual)::acc) others
           end
 in
   group [] (Lib.topsort depends_on tcrel)
 end;
1396
1397(*---------------------------------------------------------------------------
1398 Examples.
1399val ex1 =
1400 [("a","b"), ("b","c"), ("b","d"),("d","c"),
1401  ("a","e"), ("e","f"), ("f","e"), ("d","a"),
1402  ("f","g"), ("g","g"), ("g","i"), ("g","h"),
1403  ("i","h"), ("i","k"), ("h","j")];
1404
1405val ex2 =  ("a","z")::ex1;
1406val ex3 =  ("z","a")::ex1;
1407val ex4 =  ("z","c")::ex3;
1408val ex5 =  ("c","z")::ex3;
1409val ex6 =  ("c","i")::ex3;
1410
1411cliques_of (trancl ex1);
1412cliques_of (trancl ex2);
1413cliques_of (trancl ex3);
1414cliques_of (trancl ex4);
1415cliques_of (trancl ex5);
1416cliques_of (trancl ex6);
1417  --------------------------------------------------------------------------*)
1418
1419(*---------------------------------------------------------------------------*)
1420(* Find the variables free in the rhs of an equation, that don't also occur  *)
1421(* on the lhs. This helps locate calls to other functions also being defined,*)
1422(* so that the dependencies can be calculated.                               *)
1423(*---------------------------------------------------------------------------*)
1424
fun free_calls tm =
 let val (l,r) = dest_eq tm
     val (f,pats) = strip_comb l
 in
   (f, op_set_diff aconv (free_vars r) (free_varsl pats))
 end;
1433
1434(*---------------------------------------------------------------------------*)
1435(* Find all the dependencies between a collection of equations. Mutually     *)
1436(* dependent equations end up in the same clique.                            *)
1437(*---------------------------------------------------------------------------*)
1438
fun dependencies eqns =
  let (* merge consecutive entries for the same function head *)
      fun merge_adjacent ((f,deps1)::(g,deps2)::rest) =
            if aconv f g
            then merge_adjacent ((f, op_union aconv deps1 deps2)::rest)
            else (f,deps1) :: merge_adjacent ((g,deps2)::rest)
        | merge_adjacent short = short
      val per_eqn = map free_calls (strip_conj eqns)
  in
    cliques_of (TC (merge_adjacent per_eqn))
  end;
1448
1449(*---------------------------------------------------------------------------
1450 Examples.
1451
1452val eqns = ``(f x = 0) /\ (g y = 1) /\ (h a = h (a-1) + g a)``;
1453dependencies eqns;
1454
1455val eqns = ``(f [] = 0) /\ (f (h1::t1) = g(h1) + f t1) /\
1456             (g y = 1) /\ (h a = h (a-1) + g a)``;
1457dependencies eqns;
1458
1459load"stringTheory";
1460Hol_datatype `term = Var of string | App of string => term list`;
1461Hol_datatype `pterm = pVar of string | pApp of string # pterm list`;
1462
1463val eqns = ``(f (pVar s) = sing s) /\
1464             (f (pApp (a,ptl)) = onion (sing a) (g ptl)) /\
1465             (g [] = {}) /\
1466             (g (t::tlist) = onion (f t) (g tlist)) /\
1467             (sing a = {a}) /\ (onion s1 s2 = s1 UNION s2)``;
1468dependencies eqns;
1469  ---------------------------------------------------------------------------*)
1470
(* The function being defined by an (unquantified) equation. *)
fun eqn_head tm = head (lhs tm);
1472
1473(*---------------------------------------------------------------------------*)
1474(* Take a collection of equations for a set of different functions, and      *)
1475(* untangle the dependencies so that functions are defined before use.       *)
1476(* Handles (mutually) recursive dependencies.                                *)
1477(*---------------------------------------------------------------------------*)
1478
fun sort_eqns eqns =
 let val all_eqns = strip_conj eqns
     (* the conjunction of the equations defining any function in fns *)
     fun eqns_for fns =
       list_mk_conj
         (List.filter (fn eqn => op_mem aconv (eqn_head eqn) fns) all_eqns)
 in
   map (eqns_for o map fst) (dependencies eqns)
 end;
1488
1489(*---------------------------------------------------------------------------
1490 Example (adapted from ex1 above).
1491
1492val eqns =
1493 ``(f1 x = f2 (x-1) + f5 (x-1)) /\
1494   (f2 x = f3 (x-1) + f4 (x-1)) /\
1495   (f3 x = x + 2) /\
1496   (f4 x = f3 (x-1) + f1(x-2)) /\
1497   (f5 x = f6 (x-1)) /\
1498   (f6 x = f5(x-1) + f7(x-2)) /\
1499   (f7 0 = f8 3 + f9 4) /\
1500   (f7 (SUC n) = SUC n * f7 n + f9 n) /\
1501   (f8 x = f10 (x + x)) /\
1502   (f9 x = f8 x + f11 (x-4)) /\
1503   (f10 x = x MOD 2) /\
1504   (f11 x = x DIV 2)``;
1505
1506sort_eqns eqns;
1507sort_eqns (list_mk_conj (rev(strip_conj eqns)));
1508 ---------------------------------------------------------------------------*)
1509
(* Make one definition per (stem, equations) pair, in order.  After each    *)
(* definition is made, the constants it introduced are substituted for the  *)
(* corresponding head variables in the equations still waiting to be       *)
(* defined.  Any failure is wrapped as Defn.mk_defns.                       *)
(* NOTE(review): head variables and new constants are paired positionally   *)
(* (map2); this assumes eqns_of yields the constants in the same order as   *)
(* the input equations mention the variables -- confirm.                    *)
fun mk_defns stems eqnsl =
 let
   (* distinct head variables of the equations in one clique *)
   fun lhs_atoms eqns = op_mk_set aconv (map (head o lhs) (strip_conj eqns))
   (* distinct head constants of the equations of a finished definition *)
   fun defined_consts def =
       op_mk_set aconv
         (map (head o lhs o snd o strip_forall)
              (flatten (map (strip_conj o concl) (eqns_of def))))
   fun mk_defn_theta stem eqns =
       let val vars = lhs_atoms eqns
           val def  = mk_defn stem eqns
       in (def, map2 (curry (op |->)) vars (defined_consts def))
       end
   fun define_all [] = []
     | define_all ((stem,eqns)::pending) =
         let val (defn,theta) = mk_defn_theta stem eqns
             val pending' = map (fn (s,e) => (s, subst theta e)) pending
         in defn :: define_all pending'
         end
 in
   define_all (zip stems eqnsl)
 end
 handle e => raise wrap_exn "Defn" "mk_defns" e;
1537
1538(*---------------------------------------------------------------------------
1539     Quotation interface to definition. This includes a pass for
1540     expansion of wildcards in patterns.
1541 ---------------------------------------------------------------------------*)
1542
(* Find the first name s0, s1, s2, ... not in the list [S]; return it      *)
(* together with [S] extended by that name.                                *)
fun vary s S =
 let fun try n =
       let val candidate = s ^ Lib.int_to_string n
       in if mem candidate S then try (n+1) else (candidate, candidate::S)
       end
 in try 0 end;
1549
1550(*---------------------------------------------------------------------------
    A wildcard is a non-empty, contiguous sequence of underscores
    (e.g. "_" or "___"). This is somewhat idiosyncratic, we realize,
    but restricting wildcards to a single underscore is not always
    great for readability.
1554 ---------------------------------------------------------------------------*)
1555
(* True iff [s] is non-empty and made up solely of underscores. *)
fun wildcard s =
    size s > 0 andalso CharVector.all (fn c => c = #"_") s
1558
local open Absyn
in
(* Add every variable name occurring in a binding vstruct to [acc]. *)
fun vnames_of vs acc =
    case vs of
      VAQ(_,tm) => union (map (fst o Term.dest_var) (all_vars tm)) acc
    | VIDENT(_,s) => union [s] acc
    | VPAIR(_,v1,v2) => vnames_of v1 (vnames_of v2 acc)
    | VTYPED(_,v,_) => vnames_of v acc

(* Add every identifier / variable name occurring in an absyn to [acc].    *)
(* Applications of case_arrow__magic are skipped entirely, as are         *)
(* qualified identifiers.                                                  *)
fun names_of a acc =
    case a of
      AQ(_,tm) => union (map (fst o Term.dest_var) (all_vars tm)) acc
    | IDENT(_,s) => union [s] acc
    | APP(_,IDENT(_,"case_arrow__magic"),_) => acc
    | APP(_,M,N) => names_of M (names_of N acc)
    | LAM(_,v,M) => names_of M (vnames_of v acc)
    | TYPED(_,M,_) => names_of M acc
    | QIDENT _ => acc
end;
1574
(* Wildcard expansion for patterns.  Every wildcard (see [wildcard]) is     *)
(* replaced by a fresh name "v<n>" drawn by [vary] from the name set S,     *)
(* which is threaded left-to-right through the traversal so that each       *)
(* wildcard gets a distinct name.                                           *)
local
  val v_vary = vary "v"
  (* Expand wildcards inside an antiquoted term: a variable with a wildcard *)
  (* name is renamed (keeping its type), constants are untouched, and      *)
  (* applications are traversed rator first, then rand.                    *)
  fun tm_exp tm S =
    case dest_term tm
    of VAR(s,Ty) =>
         if wildcard s then
           let val (s',S') = v_vary S in (Term.mk_var(s',Ty),S') end
         else (tm,S)
     | CONST _  => (tm,S)
     | COMB(Rator,Rand) =>
        let val (Rator',S')  = tm_exp Rator S
            val (Rand', S'') = tm_exp Rand S'
        in (mk_comb(Rator', Rand'), S'')
        end
     (* lambda abstractions are not allowed in patterns *)
     | LAMB _ => raise ERR "tm_exp" "abstraction in pattern"
  open Absyn
in
(* Expand wildcards in an absyn pattern, returning the new pattern and the *)
(* extended name set.                                                      *)
fun exp (AQ(locn,tm)) S =
      let val (tm',S') = tm_exp tm S in (AQ(locn,tm'),S') end
  | exp (IDENT (p as (locn,s))) S =
      if wildcard s
        then let val (s',S') = v_vary S in (IDENT(locn,s'), S') end
        else (IDENT p, S)
  (* a wildcard cannot be a qualified (long) identifier *)
  | exp (QIDENT (p as (locn,s,_))) S =
      if wildcard s
       then raise ERRloc "exp" locn "wildcard in long id. in pattern"
       else (QIDENT p, S)
  | exp (APP(locn,M,N)) S =
      let val (M',S')   = exp M S
          val (N', S'') = exp N S'
      in (APP (locn,M',N'), S'')
      end
  | exp (TYPED(locn,M,pty)) S =
      let val (M',S') = exp M S in (TYPED(locn,M',pty),S') end
  | exp (LAM(locn,_,_)) _ = raise ERRloc "exp" locn "abstraction in pattern"

(* Folding step for rev_itlist: expand [asy] and cons the result onto      *)
(* [asyl].  The accumulated list therefore comes out reversed; the caller  *)
(* (munge) reverses it back.                                               *)
fun expand_wildcards asy (asyl,S) =
   let val (asy',S') = exp asy S in (asy'::asyl, S') end
end;
1614
(* Split an absyn equation "l = r" or "l <=> r" into (l, r); anything else *)
(* raises a located "Expected an equality" error.                          *)
fun multi_dest_eq t =
    let
      fun no_eq () = raise ERRloc "multi_dest_eq" (Absyn.locn_of_absyn t)
                                  "Expected an equality"
    in
      Absyn.dest_eq t
      handle HOL_ERR _ =>
             (Absyn.dest_binop "<=>" t handle HOL_ERR _ => no_eq ())
    end
1621
local
  fun dest_pvar (Absyn.VIDENT(_,s)) = s
    | dest_pvar other = raise ERRloc "munge" (Absyn.locn_of_vstruct other)
                                     "dest_pvar"
  fun dest_atom tm = (dest_const tm handle HOL_ERR _ => dest_var tm);
  (* Name of the function symbol at the head of an equation's lhs. *)
  fun dest_head (Absyn.AQ(_,tm)) = fst(dest_atom tm)
    | dest_head (Absyn.IDENT(_,s)) = s
    | dest_head (Absyn.TYPED(_,a,_)) = dest_head a
    | dest_head (Absyn.QIDENT(locn,_,_)) =
            raise ERRloc "dest_head" locn "qual. ident."
    | dest_head (Absyn.APP(locn,_,_)) =
            raise ERRloc "dest_head" locn "app. node"
    | dest_head (Absyn.LAM(locn,_,_)) =
            raise ERRloc "dest_head" locn "lam. node"
  (* Peel off the type annotations wrapping an absyn.  The annotations are *)
  (* returned outermost-first, together with the unannotated core.        *)
  fun strip_tyannote0 acc absyn =
      case absyn of
        Absyn.TYPED(locn, a, ty) => strip_tyannote0 ((ty,locn)::acc) a
      | x => (List.rev acc, x)
  val strip_tyannote = strip_tyannote0 []
  (* Inverse of strip_tyannote.  The list is outermost-first, so it must be *)
  (* folded from the right: the last (innermost) annotation wraps the core  *)
  (* first.  (A left fold would reverse the nesting and the locations.)     *)
  fun list_mk_tyannote(tyl,a) =
      List.foldr (fn ((ty,locn),t) => Absyn.TYPED(locn,t,ty)) a tyl
in
(* Folding step over the conjuncts of a definition.  Rewrites one          *)
(* (possibly universally quantified) equation so that wildcards in its lhs *)
(* patterns become fresh variables.  Any type annotations on the lhs are   *)
(* kept.  Accumulates:                                                     *)
(*   eqs  - rewritten equations (in reverse order),                        *)
(*   fset - names of the functions being defined,                          *)
(*   V    - names already in use (avoided when inventing new variables).   *)
fun munge eq (eqs,fset,V) =
 let val (vlist,body) = Absyn.strip_forall eq
     val (lhs0,rhs)   = multi_dest_eq body
(*     val   _          = if exists wildcard (names_of rhs []) then
                         raise ERRloc "munge" (Absyn.locn_of_absyn rhs)
                                      "wildcards on rhs" else () *)
     val (tys, lhs)   = strip_tyannote lhs0
     val (f,pats)     = Absyn.strip_app lhs
     (* quantified variables count as used names too *)
     val (pats',V')   = rev_itlist expand_wildcards pats
                            ([],Lib.union V (map dest_pvar vlist))
     val new_lhs0     = Absyn.list_mk_app(f,rev pats')
     val new_lhs      = list_mk_tyannote(tys, new_lhs0)
     val new_eq       = Absyn.list_mk_forall(vlist, Absyn.mk_eq(new_lhs, rhs))
     val fstr         = dest_head f
 in
    (new_eq::eqs, insert fstr fset, V')
 end
end;
1662
(* Expand the wildcards in every equation of a definition.  Fresh names    *)
(* avoid all names already occurring anywhere in [eqs].  Returns the       *)
(* rewritten conjunction and the names of the functions being defined, in  *)
(* the order the fold produced them.                                       *)
fun elim_wildcards eqs =
 let val (eql,fset,_) =
         rev_itlist munge (Absyn.strip_conj eqs) ([],[],names_of eqs [])
 in
   (Absyn.list_mk_conj (rev eql), rev fset)
 end;
1669
1670(*---------------------------------------------------------------------------*)
1671(* To parse a purported definition, we have to convince the parser that the  *)
1672(* names to be defined aren't constants.  We can do this using "hide".       *)
1673(* After the parsing has been done, the grammar has to be put back the way   *)
1674(* it was.  If a definition is subsequently made, this will update the       *)
1675(* grammar further (ultimately using add_const).                             *)
1676(*---------------------------------------------------------------------------*)
1677
(* Add to [acc] every identifier that occurs in an argument position of    *)
(* the absyns in [alist].  The head of each application is skipped; only  *)
(* its arguments are explored.  String and character literals are ignored, *)
(* type annotations are looked through, and all other node kinds (AQ,      *)
(* QIDENT, LAM) are skipped.                                               *)
fun non_head_idents acc alist =
    case alist of
      [] => acc
    | Absyn.IDENT(_, s)::rest => let
        val acc' =
            if Lexis.is_string_literal s orelse Lexis.is_char_literal s
            then acc
            else HOLset.add(acc,s)
      in
        non_head_idents acc' rest
      end
    | (a as Absyn.APP _)::rest => let
        val (_, args) = Absyn.strip_app a
      in
        non_head_idents acc (args @ rest)
      end
    | Absyn.TYPED(_, a, _)::rest => non_head_idents acc (a::rest)
    | _ :: rest => non_head_idents acc rest
1696
(* Sorted, duplicate-free list of identifiers occurring in argument        *)
(* positions of the left-hand sides of the equations in [eqs_a].           *)
fun get_param_names eqs_a =
  let
    fun lhs_of eq = #1 (multi_dest_eq (#2 (Absyn.strip_forall eq)))
    val lhses = map lhs_of (Absyn.strip_conj eqs_a)
  in
    HOLset.listItems (non_head_idents (HOLset.empty String.compare) lhses)
  end
1703
(* True iff some constant that [s] is overloaded to (in [oinfo]) is a      *)
(* datatype constructor known to TypeBase.                                 *)
fun is_constructor_name oinfo s =
    case Overload.info_for_name oinfo s of
      NONE => false
    | SOME {actual_ops, ...} => List.exists TypeBase.is_constructor actual_ops
1712
(* Error message for two occurrences of the same head variable whose types *)
(* failed to unify.                                                        *)
fun unify_error pv1 pv2 = let
  open Preterm
  val (nm,ty1,l1) = dest_ptvar pv1
  val (_,ty2,l2) = dest_ptvar pv2
  fun tystr ty = type_to_string (Pretype.toType ty)
in
  String.concat
    ["Couldn't unify types of head symbol ", Lib.quote nm,
     " at positions ", locn.toShortString l1, " and ",
     locn.toShortString l2, " with types ", tystr ty1, " and ", tystr ty2]
end
1724
(* For a preterm equation "!uvars. f args = r": the free variables of r    *)
(* that are bound neither by the lhs arguments nor by the outer            *)
(* quantifiers, plus the head variable f.                                  *)
fun ptdefn_freevars pt = let
  open Preterm
  val (uvars, body) = strip_pforall pt
  val (lhs_pt, rhs_pt) =
      case strip_pcomb body of
          (_, [l,r]) => (l,r)
        | _ => raise ERRloc "ptdefn_freevars" (locn body)
                     "Couldn't see preterm as equality"
  val (f0, args) = strip_pcomb lhs_pt
  val fvar = head_var f0
  val lhs_frees = op_U eq (map ptfvs args)
  fun minus (s1, s2) = op_set_diff eq s1 s2
in
  op_union eq (minus (minus (ptfvs rhs_pt, lhs_frees), uvars)) [fvar]
end
1741
(* Turn the absyn of a (possibly multi-clause) definition into a term.     *)
(* Each conjunct is converted to a preterm and run through                 *)
(* typecheck_phase1 separately.  Then all occurrences of the same free     *)
(* name across conjuncts are forced to have unifiable types.  Finally the  *)
(* preterms are conjoined, overloads are resolved (reporting ambiguity),   *)
(* and the result is converted to a term.  Case magic is removed and the   *)
(* !post_process_term hook is applied.                                     *)
fun defn_absyn_to_term a = let
  val alist = Absyn.strip_conj a
  open errormonad
  val tycheck = Preterm.typecheck_phase1 (SOME (term_to_string, type_to_string))
  (* convert and typecheck each conjunct; keep the preterm itself *)
  val ptsM =
      mmap
        (fn a => absyn_to_preterm a >- (fn ptm => tycheck ptm >> return ptm))
        alist
  (* Fold over the free variables, keyed by name: the first occurrence of a *)
  (* name is recorded; later ones must unify with its type.  Names        *)
  (* starting with "_" are skipped.                                       *)
  (* NOTE(review): String.sub(Name,0) would raise Subscript on an empty   *)
  (* variable name -- presumably impossible here; confirm.                *)
  fun foldthis (pv as Preterm.Var{Name,Ty,Locn}, env) =
    if String.sub(Name,0) = #"_" then return env
    else
      (case Binarymap.peek(env,Name) of
           NONE => return (Binarymap.insert(env,Name,pv))
         | SOME pv' =>
             Preterm.ptype_of pv' >- (fn pty' => Pretype.unify Ty pty') >>
             return env)
    | foldthis (_, env) = raise Fail "defn_absyn_to_term: can't happen"
  open Preterm
  (* conjoin the clauses, resolve overloading and produce the final term *)
  fun construct_final_term pts =
    let
      val ptm = plist_mk_rbinop
                  (Antiq {Tm=boolSyntax.conjunction,Locn=locn.Loc_None})
                  pts
    in
      overloading_resolution ptm >- (fn (pt,b) =>
      report_ovl_ambiguity b     >>
      to_term pt                 >- (fn t =>
      return (t |> remove_case_magic |> !post_process_term)))
    end
  val M =
    ptsM >-
    (fn pts =>
      let
        val all_frees = op_U Preterm.eq (map ptdefn_freevars pts)
      in
        foldlM foldthis (Binarymap.mkDict String.compare) all_frees >>
        construct_final_term pts
      end)
in
  (* run the whole computation starting from an empty type environment *)
  smash M Pretype.Env.empty
end
1783
(* Parse the absyn of a definition into a term, returning it with the      *)
(* names of the functions being defined.  Wildcards are expanded first.    *)
(* The function names and the non-constructor parameter names are then     *)
(* temporarily hidden so that they parse as variables.  The term grammar   *)
(* in force on entry is always restored, whether parsing succeeds or       *)
(* fails.                                                                  *)
fun parse_absyn absyn0 = let
  val (absyn,fn_names) = elim_wildcards absyn0
  val oldg = term_grammar()
  val oinfo = term_grammar.overload_info oldg
  val nonconstructor_parameter_names =
      List.filter (not o is_constructor_name oinfo) (get_param_names absyn)
  fun restore() = temp_set_grammars(type_grammar(), oldg)
  (* The hides modify the global grammar, so they run inside the protected *)
  (* region: a failure part-way through still restores the grammar.        *)
  val tm =
      (app (ignore o Parse.hide) (nonconstructor_parameter_names @ fn_names);
       defn_absyn_to_term absyn)
      handle e => (restore(); raise e)
in
  restore();
  (tm, fn_names)
end;
1798
1799(*---------------------------------------------------------------------------*)
1800(* Parse a quotation. Fail if parsing or type inference or overload          *)
1801(* resolution (etc) fail. Returns a list of equations; each element in the   *)
1802(* list is a separate mutually recursive clique.                             *)
1803(*---------------------------------------------------------------------------*)
1804
fun parse_quote q =
  let val (tm, _) = parse_absyn (Parse.Absyn q)
  in sort_eqns tm end
  handle e => raise wrap_exn "Defn" "parse_quote" e;
1808
(* Build a single definition named [stem] from quotation [q].  Fails if the *)
(* quotation yields no equations, or splits into more than one clique.      *)
(* Errors are wrapped with the location of the parsed quotation.            *)
fun Hol_defn stem q =
 (case parse_quote q
   of [] => raise ERR "Hol_defn" "no definitions"
    | [eqns] => mk_defn stem eqns
    | _ => raise ERR "Hol_defn" "multiple definitions")
  handle e =>
  raise wrap_exn_loc "Defn" "Hol_defn"
           (Absyn.locn_of_absyn (Parse.Absyn q)) e;
1817
(* Build one definition per clique of [q], pairing cliques with [stems] in *)
(* order.  Errors are wrapped with the location of the parsed quotation.   *)
fun Hol_defns stems q =
  let
    fun build [] = raise ERR "Hol_defns" "no definition"
      | build eqnl = mk_defns stems eqnl
  in
    build (parse_quote q)
  end
  handle e => raise wrap_exn_loc "Defn" "Hol_defns"
                 (Absyn.locn_of_absyn (Parse.Absyn q)) e;
1824
local
  (* the name of the head variable of a clique's first equation *)
  fun stem_of eqns =
      let val first_eqn = hd (strip_conj eqns)
          val (f, _) = strip_comb (lhs (snd (strip_forall first_eqn)))
      in fst (dest_var f) end
in
  (* Like Hol_defns, but each definition is named after its first function. *)
  fun Hol_multi_defns q =
    (case parse_quote q of
       [] => raise ERR "Hol_multi_defns" "no definition"
     | eqnsl => mk_defns (List.map stem_of eqnsl) eqnsl)
    handle e => raise wrap_exn_loc "Defn" "Hol_multi_defns"
                   (Absyn.locn_of_absyn (Parse.Absyn q)) e
end
1837
(* Like Hol_defn, but if the definition has a termination relation variable, *)
(* instantiate it with the relation parsed from [Rquote] at the variable's   *)
(* type.                                                                     *)
fun Hol_Rdefn stem Rquote eqs_quote =
  let val defn = Hol_defn stem eqs_quote
      fun instantiate Rvar =
          inst_defn defn
            (Term.match_term Rvar (Parse.typedTerm Rquote (type_of Rvar)))
  in
    case reln_of defn of
      NONE => defn
    | SOME Rvar => instantiate Rvar
  end;
1847
1848(*---------------------------------------------------------------------------
1849        Goalstack-based interface to termination proof.
1850 ---------------------------------------------------------------------------*)
1851
(* Discharge the hypotheses [hyps] from [th] as one conjoined antecedent: *)
(* [h1,...,hn] gives h1 /\ ... /\ hn ==> concl th.                       *)
fun mangle th hyps =
  case hyps of
    [] => th
  | [h] => DISCH h th
  | h :: rest =>
      Rewrite.PURE_ONCE_REWRITE_RULE [boolTheory.AND_IMP_INTRO]
        (DISCH h (mangle th rest));
1857
1858
1859(*---------------------------------------------------------------------------
1860    Have to take care with how the assumptions are discharged. Hence mangle.
1861 ---------------------------------------------------------------------------*)
1862
(* The constant relation$WF. *)
val WF_tm = prim_mk_const {Thy = "relation", Name = "WF"};
1864
(* Split [tmlist] into its first "WF R" term and the remaining terms.      *)
(* Raises HOL_ERR if there is no such term.  Terms that are not            *)
(* combinations are simply skipped: previously, [rator] raised on them     *)
(* inside the predicate, which aborted the search before a later WF term   *)
(* could be found.                                                         *)
fun get_WF tmlist =
 let fun is_WF tm = is_comb tm andalso same_const WF_tm (rator tm)
 in pluck is_WF tmlist
 end
 handle HOL_ERR _ => raise ERR "get_WF" "unexpected termination condition";
1868
(* Set up the termination goal for equations [E] and induction theorem [I]. *)
(* Returns the single remaining goal and a validation function.             *)
fun TC_TAC0 E I =
 let val th = CONJ E I
     val asl = hyp th
     (* discharge the WF hypothesis first (outermost), if there is one *)
     val ordered = (case get_WF asl of (wfr, rest) => wfr :: rest)
                   handle HOL_ERR _ => asl
     val (subgoals, validation) =
         MATCH_MP_TAC (GEN_ALL (mangle th ordered)) ([], concl th)
 in
   case subgoals of
     [([], g)] => (([], g), fn thm => validation [thm])
   | _ => raise ERR "TC_TAC" "unexpected output"
 end;
1881
(* TC_TAC0 applied to a defn's equations and (required) induction theorem. *)
fun TC_TAC defn =
  TC_TAC0 (LIST_CONJ (eqns_of defn)) (Option.valOf (ind_of defn));
1888
(* Push the termination goal for an explicit (equations, induction) pair.  *)
(* "No termination conditions" here means that [def] has no hypotheses.   *)
fun tgoal_no_defn0 (def,ind) =
  if null (hyp def) then raise ERR "tgoal" "no termination conditions"
  else
    (let val (g,validation) = TC_TAC0 def ind
     in proofManagerLib.add (Manager.new_goalstack g validation) end
     handle HOL_ERR _ => raise ERR "tgoal" "")
1895
fun tgoal_no_defn p =
  Lib.with_flag (proofManagerLib.chatting, false) tgoal_no_defn0 p;
1898
(* Push the termination goal of [defn] onto the proof manager. *)
fun tgoal0 defn =
  if null (tcs_of defn) then raise ERR "tgoal" "no termination conditions"
  else
    (let val (g,validation) = TC_TAC defn
     in proofManagerLib.add (Manager.new_goalstack g validation) end
     handle HOL_ERR _ => raise ERR "tgoal" "")
1905
(* tgoal0 with proof-manager chatter switched off. *)
fun tgoal d = Lib.with_flag (proofManagerLib.chatting, false) tgoal0 d;
1907
1908(*---------------------------------------------------------------------------
1909     The error handling here is pretty coarse.
1910 ---------------------------------------------------------------------------*)
1911
(* Prove the termination goal of [defn] with [tactic], using [tgoal0] to   *)
(* push the goal.  Returns (equations, induction theorem).  Every failure  *)
(* is wrapped as Defn.tprove.  The pushed goal is dropped afterwards, but  *)
(* only if it was actually pushed: when [tgoal0] fails, nothing is on the  *)
(* stack for this proof, and dropping would pop the user's unrelated       *)
(* goalstack (or raise and hide the real error).                           *)
fun tprove2 tgoal0 (defn,tactic) =
  let
    fun fail e = raise wrap_exn "Defn" "tprove" e
    val _ = tgoal0 defn handle e => fail e
    val th = (proofManagerLib.expand tactic;  (* should finish proof off *)
              proofManagerLib.top_thm ())
             handle e => (proofManagerLib.drop(); fail e)
    val _ = proofManagerLib.drop() handle e => fail e
  in
    (CONJUNCT1 th, CONJUNCT2 th) handle e => fail e
  end
1923
(* tprove2 run quietly.  When computeLib.auto_import_definitions is set,   *)
(* the proved equations, plus a copy converted with the                    *)
(* SUC_TO_NUMERAL_DEFN_CONV hook, are added to the compset.                *)
fun tprove1 tgoal0 p =
  let
    val (eqns, ind) =
        Lib.with_flag (proofManagerLib.chatting, false) (tprove2 tgoal0) p
  in
    if !computeLib.auto_import_definitions then
      computeLib.add_funs
        [eqns, CONV_RULE (!SUC_TO_NUMERAL_DEFN_CONV_hook) eqns]
    else ();
    (eqns, ind)
  end
1934
(* tprove: quiet, imports into computeLib; tprove0: raw; tprove_no_defn:  *)
(* like tprove but for an explicit (equations, induction) pair.           *)
val tprove = tprove1 tgoal0
val tprove0 = tprove2 tgoal0
val tprove_no_defn = tprove1 tgoal_no_defn0
1938
(* Prove termination of [d] with tactic [t], store the resulting theorems  *)
(* under the defn's name, and return them.                                 *)
fun tstore_defn (d,t) =
  let val result as (def,ind) = tprove0 (d,t)
      val _ = store (name_of d, def, ind)
  in result end;
1944
1945end (* Defn *)
1946