1structure polytypicLib :> polytypicLib =
2struct
3
4open Binarymap List HolKernel boolLib bossLib Q Parse combinTheory computeLib
5     Conv Thm Tactical BasicProvers Tactic Drule Definition
6     listTheory numLib listLib pairLib Psyntax
7     pairTheory sumTheory Lib arithmeticTheory proofManagerLib;
8
9(*****************************************************************************)
10(* Error handling functions:                                                 *)
11(*                                                                           *)
12(* Standard,Fatal,Debug : exception_level                                    *)
13(* polyExn            : exception_level * string list * string -> exn        *)
14(* isFatal              : exn -> bool                                        *)
15(*     The different exception levels:                                       *)
16(*         Standard : General exceptions                                     *)
17(*         Fatal    : Input of the correct form, but cannot be computed      *)
18(*         Debug    : Input to a lower function was of the wrong form        *)
19(*     and an exception constructor for the traced, leveled exceptions.      *)
(*     isFatal considers Fatal, Debug and Interrupt exceptions to be fatal   *)
21(*                                                                           *)
22(* exn_to_string        : exn -> string                                      *)
(*     Converts a polyExn (and standard exceptions) to a string              *)
24(*                                                                           *)
25(* wrapException        : string -> exn -> 'a                                *)
26(* wrapExceptionHOL     : string -> exn -> 'a                                *)
27(*     wrapException adds a name to the trace in an exception and            *)
28(*     wrapExceptionHOL adds a name to the trace then converts to a HOL_ERR  *)
29(*                                                                           *)
30(* Raise                : exn -> 'a                                          *)
31(*     As Feedback.Raise but supports polyExn                                *)
32(*                                                                           *)
33(* mkStandardExn        : string -> string -> exn                            *)
34(* mkFatalExn           : string -> string -> exn                            *)
35(* mkDebugExn           : string -> string -> exn                            *)
36(*     Create the different levels of exception given a function and message *)
37(*                                                                           *)
38(* tryfind_e            : exn -> ('a -> 'b) -> 'a list -> 'b                 *)
39(* first_e              : exn -> ('a -> bool) -> 'a list -> 'a               *)
40(* can                  : ('a -> 'b) -> 'a -> bool                           *)
41(* total                : ('a -> 'b) -> 'a -> 'b option                      *)
42(* repeat               : ('a -> 'a) -> 'a -> 'a                             *)
43(*    Like the versions found in 'Lib' except will re-raise Fatal exceptions *)
44(*    and can have specific exceptions for the end of the list               *)
45(*                                                                           *)
46(* debug                : bool                                               *)
47(*    Determines whether inputs are tested                                   *)
48(*                                                                           *)
49(* assert               : string -> (string * ('a -> bool)) list -> 'a -> 'a *)
50(*    If the debug flag is set, applies each test to 'a raising a debug      *)
51(*    level exception if any fail.                                           *)
52(*                                                                           *)
53(* guarenteed           : ('a -> 'b) -> 'a -> 'b                             *)
54(*    Raises a debug exception if the application fails                      *)
55(*                                                                           *)
56(* check_standard_conv  : string -> term * thm -> thm                        *)
57(* check_matching_conv  : string -> term * thm -> thm                        *)
58(*    Checks the output of a conv to make sure it matches the input term     *)
59(*                                                                           *)
60(*****************************************************************************)
61
(* Severity attached to a polyExn. isFatal treats Fatal and Debug as fatal; *)
(* Standard exceptions may be recovered from by the combinators below.      *)
datatype exception_level =
    Standard
  | Fatal
  | Debug;

(* Leveled exception: (level, function trace, message). wrapException       *)
(* prepends caller names, so the outermost wrapper is first in the trace.  *)
exception polyExn of exception_level * string list * string;
64
local
(* Wraps a function name in single quotes. *)
fun tick_quote name = "'" ^ name ^ "'"

(* Renders a trace one name per line, joined by " ->", ending in a newline; *)
(* the empty trace renders as just "\n".                                    *)
fun render_trace names =
    String.concatWith " ->\n" (map tick_quote names) ^ "\n"

(* Heading printed in front of the message for each exception level. *)
fun level_prefix Standard = "Exception: "
  | level_prefix Fatal    = "Fatal exception: "
  | level_prefix Debug    = "Debug exception: "
in
(* Converts a polyExn into a multi-line report (heading, message, trace);  *)
(* any other exception is delegated to Feedback.exn_to_string.             *)
fun exn_to_string (polyExn (level, ftrace, msg)) =
      level_prefix level ^ msg ^ "\nRaised at:\n" ^ render_trace ftrace
  | exn_to_string other = Feedback.exn_to_string other
end;
78
(* Prints the rendered exception to stdout, then raises it unchanged. *)
fun Raise e =
let val () = print (exn_to_string e)
in  raise e
end
80
(* True for exceptions the recovery combinators must never swallow:        *)
(* Fatal or Debug level polyExns and Interrupt.                            *)
fun isFatal e =
    case e of
      polyExn (Standard, _, _) => false
    | polyExn _ => true
    | Interrupt => true
    | _ => false;
85
(* Re-raises exn as a polyExn with name added to its trace. Existing        *)
(* polyExns keep their level; HOL_ERRs and other exceptions become         *)
(* Standard; Interrupt is re-raised untouched. Never returns.               *)
fun wrapException name exn =
    case exn of
      polyExn (level, trace, msg) =>
        raise polyExn (level, name :: trace, msg)
    | Interrupt => raise Interrupt
    | HOL_ERR {origin_structure, origin_function, message} =>
        raise polyExn (Standard,
                       [name, origin_structure ^ "." ^ origin_function],
                       message)
    | other => raise polyExn (Standard, [name], exn_to_string other);
91
local
(* Prefixes a message with its level, leaving Standard messages unchanged. *)
fun set_level level msg =
    case level of
      Standard => msg
    | Debug => "Debug: " ^ msg
    | Fatal => "Fatal: " ^ msg
in
(* Re-raises exn as a HOL_ERR from "polyLib". A polyExn trace is replayed  *)
(* as nested Feedback.wrap_exn layers, with its last entry as the origin    *)
(* function. Interrupt is re-raised untouched. Never returns.               *)
fun wrapExceptionHOL name exn =
    case exn of
      polyExn (level, [], msg) =>
        raise mk_HOL_ERR "polyLib" name (set_level level msg)
    | polyExn (level, trace, msg) =>
        let val innermost = mk_HOL_ERR "polyLib" (last trace) (set_level level msg)
            val wrappers = name :: butlast trace
        in  raise foldr (fn (fname, e) => Feedback.wrap_exn "polyLib" fname e)
                        innermost wrappers
        end
    | Interrupt => raise Interrupt
    | other => raise mk_HOL_ERR "polyLib" name (exn_to_string other)
end
104
local
(* Builds a polyExn of the given level whose trace is just [name]. *)
fun leveled level name msg = polyExn (level, [name], msg)
in
fun mkStandardExn name msg = leveled Standard name msg
fun mkFatalExn    name msg = leveled Fatal name msg
fun mkDebugExn    name msg = leveled Debug name msg
end
108
(* Returns f applied to the first element on which it succeeds; raises exn *)
(* when the list is exhausted. Fatal exceptions from f propagate.          *)
fun tryfind_e exn f items =
    case items of
      [] => raise exn
    | x :: rest =>
        (f x handle e => if isFatal e then raise e else tryfind_e exn f rest);
111
(* Returns the first element satisfying p, treating a non-fatal exception  *)
(* from p as false; raises exn if no element qualifies.                    *)
fun first_e exn p [] = raise exn
  | first_e exn p (x::xs) =
      let val hit = p x handle e => if isFatal e then raise e else false
      in  if hit then x else first_e exn p xs
      end;
114
(* True iff f x returns; false on a non-fatal exception, re-raising fatal ones. *)
fun can f x =
    let val _ = f x in true end
    handle e => if isFatal e then raise e else false;
116
(* SOME (f x), or NONE if f raises a non-fatal exception; fatal ones propagate. *)
fun total f x =
    let val result = f x in SOME result end
    handle e => if isFatal e then raise e else NONE;
118
(* Applies f repeatedly until it raises a non-fatal exception and returns  *)
(* the last value reached; fatal exceptions (see isFatal) propagate.        *)
(* The handler guards only the single application of f, so the recursion   *)
(* is a tail call and long iteration chains do not grow the stack.          *)
fun repeat f x =
    case (SOME (f x) handle e => if isFatal e then raise e else NONE) of
      SOME y => repeat f y
    | NONE => x;
120
(* When set, assert and check_standard_conv / check_matching_conv validate *)
(* their inputs and results; when clear they pass data straight through.    *)
val debug : bool ref = ref true;
122
(* assert fname tests data : if !debug is set, runs each (msg, test) on    *)
(* data in order and raises polyExn (Debug, [fname], msg) for the first    *)
(* test that returns false or raises. Returns data unchanged otherwise.     *)
(* An Interrupt raised inside a test is re-raised rather than being          *)
(* misreported as a failed test.                                             *)
fun assert fname tests data =
    if not (!debug) then data
    else
      let
        fun passes test =
            test data handle Interrupt => raise Interrupt | _ => false
        fun run [] = data
          | run ((test_msg, test) :: rest) =
              if passes test then run rest
              else raise polyExn (Debug, [fname], test_msg)
      in
        run tests
      end;
128
(* Applies f to x; any failure is re-raised via wrapException with a        *)
(* "guarenteed" trace entry, polyExns being promoted to Debug level.        *)
(* Interrupt passes through unchanged.                                      *)
fun guarenteed f x =
    f x handle e =>
      (case e of
         polyExn (_, trace, msg) =>
           wrapException "guarenteed" (polyExn (Debug, trace, msg))
       | Interrupt => raise Interrupt
       | other => wrapException "guarenteed" other);
133
(* When !debug is set, checks that thm is an equation whose lhs is exactly *)
(* term, raising a Debug polyExn otherwise. Returns thm.                    *)
fun check_standard_conv name (term, thm) =
    if not (!debug) then thm
    else if not (is_eq (concl thm)) then
      raise polyExn (Debug, [name], "Standard conv did not return an equality")
    else if lhs (concl thm) = term then thm
    else raise polyExn (Debug, [name], "Standard conv returned a non-matching theorem")
142
(* When !debug is set, checks that thm is an equation whose lhs is matched *)
(* by term (via match_term), raising a Debug polyExn otherwise. Returns thm.*)
fun check_matching_conv name (term, thm) =
    if not (!debug) then thm
    else if not (is_eq (concl thm)) then
      raise polyExn (Debug, [name], "Matching conv did not return an equality")
    else if can (match_term term) (lhs (concl thm)) then thm
    else raise polyExn (Debug, [name], "Matching conv returned a non-matching theorem")
151
152(*****************************************************************************)
153(* Data Types required                                                       *)
154(*                                                                           *)
155(* translation_scheme :                                                      *)
156(*     Holds the theorems necessary for the creation of polytypic functions  *)
157(*     and optionally polytypic theorems.                                    *)
158(* function :                                                                *)
159(*     Represents a polytypic function, also holds its induction principle   *)
160(* functions, theorems :                                                     *)
161(*     Binary maps from types to strings to functions or theorems            *)
162(* translations :                                                            *)
163(*     The map from types to translating functions and theorems              *)
164(*                                                                           *)
165(*****************************************************************************)
166
(* Theorems and terms needed to build polytypic functions over a target    *)
(* type (and optionally polytypic theorems).                                *)
type translation_scheme =
     {target     : hol_type,
      induction  : thm,
      recursion  : thm,
      left       : term,
      right      : term,
      predicate  : term,
      bottom     : term,
      bottom_thm : thm};

(* A polytypic function: its constant, its definition and an optional      *)
(* induction principle.                                                     *)
type function =
     {const      : term,
      definition : thm,
      induction  : (thm * (term * (term * hol_type)) list) option};

(* Maps from types to (references to) maps from names to functions/theorems. *)
type functions = (hol_type, (string, function) dict ref) dict;
type theorems  = (hol_type, (string, thm) dict ref) dict;

(* Map from types to their function/theorem tables and translation scheme. *)
type translations =
     (hol_type, ((functions ref * theorems ref) * translation_scheme)) dict;
175
176(*****************************************************************************)
177(* Trace functionality:                                                      *)
178(*                                                                           *)
179(* type_trace : int -> string -> unit                                        *)
(*     Prints a trace message if the level supplied is less than or equal    *)
(*     to the currently registered trace level.                              *)
182(*                                                                           *)
(* Level 0 : No output                                                       *)
184(* Level 1 : Progress through adding / removing splits and theorem progress  *)
185(* Level 2 : Output important intermediate results                           *)
186(* Level 3 : Output most intermediate results and function calls             *)
187(*****************************************************************************)
188
(* Current trace verbosity (0-3), registered with HOL's trace system. *)
val Trace = ref 1;

val _ = register_trace ("polytypicLib.Trace", Trace, 3);

(* Prints s only when level does not exceed the current trace setting. *)
fun type_trace level s =
    if !Trace < level then () else print s;
194
195(*****************************************************************************)
196(* Input testing functions:                                                  *)
197(*                                                                           *)
198(* both              : (bool * bool) -> bool                                 *)
199(*     Returns true for (true,true)                                          *)
200(*                                                                           *)
201(* is_conjunction_of : (term -> bool) -> term -> bool                        *)
202(* is_disjunction_of : (term -> bool) -> term -> bool                        *)
203(* is_implication_of : (term -> bool) -> (term -> bool) -> term -> bool      *)
204(* is_anything       : term -> bool                                          *)
205(*    Recognisers for different terms, 'is_conjunction_of' and               *)
206(*    'is_disjunction_of' only work for right-associated strings             *)
207(*                                                                           *)
208(*****************************************************************************)
209
(* Conjunction of a pair of booleans. *)
fun both (a, b) = if a then b else false;
211
(* True if every conjunct of the right-associated conjunction x satisfies  *)
(* f (a non-conjunction is treated as a single conjunct).                   *)
fun is_conjunction_of f x =
    if is_conj x then
      let val (left, right) = dest_conj x
      in  f left andalso is_conjunction_of f right
      end
    else f x;
215
(* True if every disjunct of the right-associated disjunction x satisfies  *)
(* f (a non-disjunction is treated as a single disjunct).                   *)
fun is_disjunction_of f x =
    if is_disj x then
      let val (left, right) = dest_disj x
      in  f left andalso is_disjunction_of f right
      end
    else f x;
219
(* True if x is an implication whose antecedent satisfies f and whose      *)
(* consequent satisfies g. Both predicates are always evaluated.            *)
fun is_implication_of f g x =
    is_imp x andalso
    (let val (ante, conc) = dest_imp x
         val ante_ok = f ante
         val conc_ok = g conc
     in  ante_ok andalso conc_ok
     end)
222
(* Recogniser that accepts every term. *)
fun is_anything (_ : term) = true;
224
225(*****************************************************************************)
226(* Printing tools:                                                           *)
227(*                                                                           *)
228(* xlist_to_string : ('a -> string) -> 'a list -> string                     *)
(* xpair_to_string : ('a -> string) -> ('b -> string) -> 'a * 'b -> string   *)
230(*     Prints a list or a pair of items using supplied printing functions    *)
231(*                                                                           *)
232(*****************************************************************************)
233
(* Renders a list as "[e1,e2,...]" using f for each element (left to right). *)
fun xlist_to_string f list =
    "[" ^ String.concatWith "," (map f list) ^ "]"
    handle e => wrapException "xlist_to_string" e
242
(* Renders a pair as "(a,b)" using f for the first and g for the second. *)
fun xpair_to_string f g (a, b) =
    let val first = f a
        val second = g b
    in  "(" ^ first ^ "," ^ second ^ ")"
    end
    handle e => wrapException "xpair_to_string" e
245
246(*****************************************************************************)
247(* General list processing functions                                         *)
248(*                                                                           *)
249(* pick_e       : exn -> ('a -> 'b) -> 'a list -> 'b * 'a list               *)
250(*     Like tryfind_e but returns the rest of the list that cannot be used   *)
251(*                                                                           *)
252(* bucket_alist : (''a * 'b) list -> (''a * 'b list) list                    *)
253(*     Buckets together the first element with a list of the matching second *)
254(*                                                                           *)
255(* mappartition : ('a -> 'b) -> 'a list -> 'b list * 'a list                 *)
256(*     Like mapfilter except returns a list of failures as well              *)
257(*                                                                           *)
258(* reachable_graph  : (''a -> ''a list) -> ''a -> (''a * ''a) list           *)
259(*     Builds the graph of elements reachable from ''a under the function    *)
260(*                                                                           *)
(* TC, RTC             : (''a * ''a) list -> (''a * ''a) list                *)
262(*     Returns the [reflexive] transitive closure of a graph                 *)
263(*                                                                           *)
264(*****************************************************************************)
265
(* Like tryfind_e, but also returns the list with the chosen element       *)
(* removed (earlier, failed elements are kept in order).                    *)
fun pick_e exn f items =
    case items of
      [] => raise exn
    | x :: rest =>
        ((f x, rest)
         handle e =>
           if isFatal e then raise e
           else
             let val (picked, others) = pick_e exn f rest
             in  (picked, x :: others)
             end);
269
(* Groups an association list by key: each key is paired with all of its   *)
(* values, in first-occurrence order.                                       *)
fun bucket_alist [] = []
  | bucket_alist ((key, v) :: rest) =
      let val (matching, others) = List.partition (fn (k, _) => k = key) rest
      in  (key, v :: map (fn (_, w) => w) matching) :: bucket_alist others
      end
275
(* Maps f over the list, returning (results of successful applications,    *)
(* elements on which f raised a non-fatal exception), both in order.        *)
fun mappartition f [] = ([], [])
  | mappartition f (x::xs) =
      let val outcome = SOME (f x) handle e => if isFatal e then raise e else NONE
          val (goods, bads) = mappartition f xs
      in  case outcome of
            SOME y => (y :: goods, bads)
          | NONE => (goods, x :: bads)
      end;
279
(* reachable_graph f t : the edge list (a, b) for every node a reachable   *)
(* from t (including t) and every b in f a.                                 *)
(* The visited set is threaded through the whole traversal, so f is called *)
(* exactly once per reachable node and each edge appears once. (Keeping    *)
(* the visited set per path re-explored shared nodes, duplicating edges and *)
(* costing exponential time on DAGs.)                                       *)
fun reachable_graph f t =
let
  fun visit (node, (seen, edges)) =
      if List.exists (fn s => s = node) seen then (seen, edges)
      else
        let val succs = f node
        in  foldl visit (node :: seen, map (fn s => (node, s)) succs @ edges) succs
        end
in
  #2 (visit (t, ([], [])))
end
291
local
        (* Cartesian product, enumerated xs-major. *)
        fun cross xs ys = List.concat (map (fn x => map (fn y => (x, y)) ys) xs)

        (* All a with (a, z) in rel, and all b with (z, b) in rel. *)
        fun preds z rel = List.mapPartial (fn (a, b) => if b = z then SOME a else NONE) rel
        fun succs z rel = List.mapPartial (fn (a, b) => if a = z then SOME b else NONE) rel

        (* Adds edge (x, y) to an already transitively closed relation. *)
        fun extend ((x, y), closed) =
                union (cross (x :: preds x closed) (y :: succs y closed)) closed
in
        fun TC pairs = foldl extend [] pairs
        fun RTC pairs =
                let fun add_refl ((x, y), acc) = insert (x, x) (insert (y, y) acc)
                in  TC (foldl add_refl pairs pairs)
                end
end;
307
308(*****************************************************************************)
309(* Term and Thm tools:                                                       *)
310(*                                                                           *)
311(* list_mk_cond      : (term * term) list -> term -> term                    *)
312(*     [(P0,a0),..,(Pn,an)] b --> if P0 then a0 else if P1 then ... else b   *)
313(* imk_comb          : (term * term) -> term                                 *)
314(* list_imk_comb     : (term * term list) -> term                            *)
315(*     As mk_comb and list_mk_comb, except the left-hand term is             *)
316(*     instantiated to match, if possible.                                   *)
317(*                                                                           *)
318(* rimk_comb         : (term * term) -> term                                 *)
319(*     Like imk_comb, except instantiates the right term instead.            *)
320(*                                                                           *)
321(* full_beta_conv    : term -> term                                          *)
322(*      Like Term.beta_conv, but for a list of abstractions                  *)
323(* full_beta         : term -> term                                          *)
324(*      As above, except this does not error when given shorter lists.       *)
325(*                                                                           *)
326(* UNDISCH_ONLY      : thm -> thm                                            *)
327(* UNDISCH_ALL_ONLY  : thm -> thm                                            *)
328(*     Like Drule.UNDISCH_ALL,Drule.UNDISCH but avoids |- ~A                 *)
329(*                                                                           *)
330(* UNDISCH_EQ        : thm -> thm                                            *)
331(* UNDISCH_ALL_EQ    : thm -> thm                                            *)
332(*     |- (A ==> B) = (A ==> C)  --->  [A] |- B = C                          *)
333(*                                                                           *)
334(* UNDISCH_CONJ      : thm -> thm                                            *)
335(*     |- x0 /\ ... /\ xn ==> A  --->  [x0,...,xn] |- A                      *)
336(*                                                                           *)
337(* DISCH_LIST_CONJ   : term list -> thm -> thm                               *)
338(* DISCH_ALL_CONJ    : thm -> thm                                            *)
339(*     [x0,...,xn] |- A          --->  |- x0 /\ ... /\ xn ==> A              *)
340(*                                                                           *)
341(* CONJUNCTS_HYP     : term -> thm -> thm                                    *)
(*     Splits [``A ==> B /\ C``] |- D to [``A ==> B``,``A ==> C``] |- D      *)
343(*                                                                           *)
344(* CONV_HYP          : (term -> thm) -> thm -> thm                           *)
345(*     Applies a conversion to all hypotheses of a theorem                   *)
346(*                                                                           *)
347(* CHOOSE_L          : term list * thm -> thm -> thm                         *)
348(*     Performs a CHOOSE for a list of variables                             *)
349(*                                                                           *)
350(* PROVE_HYP_CHECK   : thm -> thm -> thm                                     *)
351(*     Like PROVE_HYP but checks for an effect first                         *)
352(*                                                                           *)
353(* GEN_THM           : term list -> thm -> thm                               *)
354(*    Like Drule.GENL except it fully specifies the thm first                *)
355(*                                                                           *)
356(* ADDR_AND_CONV     : term -> term -> thm                                   *)
357(* ADDL_AND_CONV     : term -> term -> thm                                   *)
358(*    Converts a term ``A:bool`` to [B] |- A = A /\ B or [B] |- A = B /\ A   *)
359(*                                                                           *)
360(* MATCH_CONV        : thm -> term -> thm                                    *)
361(*    Fully matches the theorem to the term, including instantiating         *)
362(*    variables in the hypothesis                                            *)
363(*                                                                           *)
364(* CASE_SPLIT_CONV   : term -> thm                                           *)
365(*     Converts a term of the form:  '!a. P a'  to perform a split case      *)
366(*     |- !a. P a =                                                          *)
367(*              (!a0 .. an. P (C0 a0 .. an)) /\ ... /\                       *)
368(*              (!a0 .. am. P (Cn a0 .. am))                                 *)
369(*                                                                           *)
370(* PUSH_COND_CONV    : term -> thm                                           *)
371(*    Pushes all function applications over a conditional                    *)
372(*    |- f (g (if a then b else c)) = if a then f (g b) else f (g c)         *)
373(*                                                                           *)
374(* ORDER_FORALL_CONV : term list -> term -> thm                              *)
375(*    Re-orders universally quantified variables to make the list given      *)
376(*                                                                           *)
377(* ORDER_EXISTS_CONV : term list -> term -> thm                              *)
378(*    Re-orders existentially quantified variables to make the list given    *)
379(*                                                                           *)
380(* FUN_EQ_CONV       : term -> thm                                           *)
381(*    Converts the term:  |- (!a b... f = g) = (!x a b... f x = g x)         *)
382(*                                                                           *)
383(* UNFUN_EQ_CONV     : term -> thm                                           *)
384(*    Converts the term:  |- (!a b x... f x = g x) = (!a b... f = g)         *)
385(*                                                                           *)
386(* UNBETA_LIST_CONV  : term list -> term -> thm                              *)
387(*    Like UNBETA_CONV but operates on a list of terms                       *)
388(*                                                                           *)
389(* NTH_CONJ_CONV     : int -> (term -> thm) -> term -> thm                   *)
390(*    Performs a conv on the nth term in a conjunction                       *)
391(*                                                                           *)
392(* MK_CONJ           : thm -> thm -> thm                                     *)
393(* LIST_MK_CONJ      : thm list -> thm                                       *)
394(*    Makes a theorem |- A /\ B = C /\ D from |- A = C, |- B = D             *)
395(*                                                                           *)
396(* TC_THMS           : thm list -> thm list                                  *)
397(*     Takes a list of theorems and repeatedly applies either TRANS or       *)
398(*     IMP_TRANS to get a new set of theorems                                *)
399(*                                                                           *)
400(* prove_rec_fn_exists : thm -> term -> thm                                  *)
401(*   Exactly the same as Prim.prove_rec_fn_exists but performs checking on   *)
402(*   the input if the !debug flag has been set                               *)
403(*                                                                           *)
404(*****************************************************************************)
405
(* [(P0,a0),..,(Pn,an)] b  -->  if P0 then a0 else ... if Pn then an else b *)
fun list_mk_cond cases default =
    foldr (fn ((guard, branch), rest) => mk_cond (guard, branch, rest)) default cases
    handle e => wrapException "list_mk_cond" e
412
(* mk_comb (a, b), first instantiating a's types so its domain matches b. *)
fun imk_comb (a, b) =
    let val (dom, _) = dom_rng (type_of a)
        val theta = match_type dom (type_of b)
    in  mk_comb (inst theta a, b)
    end
    handle e => wrapException "imk_comb" e
416
(* mk_comb (a, b), first instantiating b's types to match a's domain. *)
fun rimk_comb (a, b) =
    let val ty_b = type_of b
        val (dom, _) = dom_rng (type_of a)
        val theta = match_type ty_b dom
    in  mk_comb (a, inst theta b)
    end
    handle e => wrapException "rimk_comb" e
420
(* Applies f to args left to right with imk_comb. Failures are wrapped     *)
(* once with "list_imk_comb" (the handler previously sat inside the        *)
(* recursion and repeated the trace entry once per argument).               *)
fun list_imk_comb (f, args) =
    foldl (fn (arg, acc) => imk_comb (acc, arg)) f args
    handle e => wrapException "list_imk_comb" e;
424
(* Beta-reduces the head of term against each of its arguments in turn;   *)
(* fails (wrapped) if the head has fewer abstractions than arguments.     *)
fun full_beta_conv term =
let
  fun apply_all t [] = t
    | apply_all t (arg :: rest) = apply_all (beta_conv (mk_comb (t, arg))) rest
  val (head, args) = strip_comb term
in
  apply_all head args
end handle e => wrapException "full_beta_conv" e
430
(* Like full_beta_conv, but when there are more arguments than abstractions *)
(* it reduces as far as possible and re-applies the remaining arguments;    *)
(* returns x unchanged if nothing can be reduced.                           *)
(* Fatal exceptions (e.g. Interrupt) are re-raised instead of swallowed.    *)
fun full_beta x =
    full_beta_conv x
    handle e =>
      if isFatal e then raise e
      else (mk_comb (full_beta (rator x), rand x)
            handle e' => if isFatal e' then raise e' else x);
434
(* UNDISCH for genuine implications only (not |- ~A); raises a Standard    *)
(* polyExn otherwise.                                                       *)
fun UNDISCH_ONLY thm =
    if not (is_imp_only (concl thm))
    then raise (mkStandardExn "UNDISCH_ONLY" "Thm is not of the form: \"|- A ==> B\"")
    else guarenteed UNDISCH thm;
439
(* Undischarges genuine implications until the conclusion is not one. *)
fun UNDISCH_ALL_ONLY thm =
    let fun loop th =
            if is_imp_only (concl th) then loop (guarenteed UNDISCH th) else th
    in  loop thm
    end;
444
(* |- (P ==> A) = (P ==> B)  --->  [P] |- A = B                             *)
(* Non-fatal failures become a Standard polyExn. Fatal ones (Interrupt,     *)
(* Fatal, Debug) now propagate. Converting them to Standard made repeat, as *)
(* used by UNDISCH_ALL_EQ, silently stop.                                   *)
fun UNDISCH_EQ thm =
let     val a = fst (dest_imp_only (lhs (concl thm)))
        (* [a] |- a = T, used to rewrite both antecedents to T *)
        val b = REWRITE_CONV [ASSUME a] a;
in
        CONV_RULE (BINOP_CONV (LAND_CONV (REWR_CONV b) THENC
                FIRST_CONV (map REWR_CONV (CONJUNCTS (SPEC_ALL IMP_CLAUSES))))) thm
end     handle e =>
                if isFatal e then raise e
                else raise (mkStandardExn "UNDISCH_EQ" "Thm is not of the form: \"|- (P ==> A) = (P ==> B)\"");
452
(* Applies UNDISCH_EQ until it fails (non-fatally), returning the last result. *)
fun UNDISCH_ALL_EQ thm =
    case total UNDISCH_EQ thm of
      SOME thm' => UNDISCH_ALL_EQ thm'
    | NONE => thm
454
(* |- x0 /\ ... /\ xn ==> A  --->  [x0,...,xn] |- A                         *)
(* Splits the leading conjunct off with AND_IMP_INTRO and recurses; when    *)
(* that fails the theorem is undischarged once with UNDISCH_ONLY.           *)
(* Non-fatal failures become a Standard polyExn. Fatal exceptions           *)
(* (Interrupt, Fatal, Debug) are re-raised instead of being swallowed.      *)
fun UNDISCH_CONJ thm =
        (UNDISCH_CONJ (UNDISCH (CONV_RULE (REWR_CONV (GSYM AND_IMP_INTRO)) thm))
         handle e => if isFatal e then raise e else UNDISCH_ONLY thm)
        handle e =>
                if isFatal e then raise e
                else raise (mkStandardExn "UNDISCH_CONJ" "Thm is not of the form: \"|- A ==> B\"");
458
local
(* Discharges tms (last first); each new antecedent is merged into the     *)
(* existing implication with AND_IMP_INTRO when that rewrite applies.      *)
fun discharge_all tms th =
    case tms of
      [] => th
    | [t] => DISCH t th
    | t :: rest =>
        let val inner = discharge_all rest th
        in  CONV_RULE (TRY_CONV (REWR_CONV AND_IMP_INTRO)) (DISCH t inner)
        end
in
(* [x0,...,xn] |- A  --->  |- x0 /\ ... /\ xn ==> A  (for the listed hyps) *)
fun DISCH_LIST_CONJ l thm =
    discharge_all l thm handle e => wrapException "DISCH_LIST_CONJ" e
end;
467
(* DISCH_LIST_CONJ over every hypothesis of thm. *)
fun DISCH_ALL_CONJ thm =
    let val hyps = hyp thm
    in  DISCH_LIST_CONJ hyps thm
    end
    handle e => wrapException "DISCH_ALL_CONJ" e
469
(* CONJUNCTS_HYP h thm, with h = ``A1 ==> .. ==> An ==> B1 /\ .. /\ Bm`` a   *)
(* hypothesis of thm: replaces h by the hypotheses                          *)
(* ``A1 ==> .. ==> An ==> Bi``, one per conjunct.                           *)
(* Each assumed implication is undischarged exactly n times (once per      *)
(* stripped antecedent). UNDISCH_ALL_ONLY also stripped conjuncts that are *)
(* themselves implications, giving the wrong conclusion; PROVE_HYP then     *)
(* silently left h in place and added spurious hypotheses.                  *)
fun CONJUNCTS_HYP h thm =
let     val (imps,c) = strip_imp_only
                (assert "CONJUNCTS_HYP" [("Hypothesis supplied is not a hypothesis of theorem",C mem (hyp thm))] h)
        (* [tm, A1, .., An] |- Bi  from  tm = A1 ==> .. ==> An ==> Bi *)
        fun undisch_imps tm = foldl (fn (_, th) => UNDISCH th) (ASSUME tm) imps
in
        (PROVE_HYP (foldr (uncurry DISCH)
                (LIST_CONJ (map (undisch_imps o curry list_mk_imp imps) (strip_conj c))) imps) thm)
        handle e => wrapException "CONJUNCTS_HYP" e
end
478
(* Rewrites every hypothesis h of thm with the conversion c: each h is     *)
(* replaced by the rhs of c h (UNCHANGED leaves h as is). With !debug set, *)
(* c's result must be an equation whose lhs is h.                          *)
fun CONV_HYP c thm =
let
  fun conv_of h =
      (assert "CONV_HYP"
         [("CONV returned a non-equality for hypothesis: " ^ term_to_string h,
           is_eq o concl),
          ("lhs of returned theorem does not match hypothesis: " ^ term_to_string h,
           fn th => h = lhs (concl th))]
         (c h))
      handle UNCHANGED => REFL h
  fun replace (h, th) =
      let val (_, backwards) = EQ_IMP_RULE (conv_of h)
      in  PROVE_HYP (UNDISCH_ONLY backwards) th
      end
in
  foldl replace thm (hyp thm)
end;
488
local
(* get_exists ``?v1 .. vn. P`` returns                                       *)
(*   ([v1, .., vn], [``?v1 .. vn. P``, ``?v2 .. vn. P``, .., ``?vn. P``])     *)
(* i.e. the bound variables and the existential term at each level. It      *)
(* stops with ([],[]) as soon as dest_exists fails; note the catch-all      *)
(* handler also swallows any other exception, including Interrupt.          *)
fun get_exists x =
        let     val (v,b) = Psyntax.dest_exists x
                val (l,r) = get_exists b
        in      (v::l,x::r) end handle e => ([],[])
in
(* CHOOSE_L (vars, cthm) thm : applies CHOOSE once per variable in vars,    *)
(* using cthm |- ?x1 .. xk. P (k >= length vars), then discharges the       *)
(* remaining existential hypothesis with PROVE_HYP cthm.                    *)
(* NOTE(review): thm is presumably expected to carry the hypothesis         *)
(* P[v1/x1, .., vn/xn] that this eliminates; confirm with callers.          *)
(* An empty vars list returns thm unchanged. Failures are wrapped with      *)
(* "CHOOSE_L": a non-existential cthm gives a Debug polyExn (when !debug    *)
(* is set), more vars than quantifiers gives a Standard one (List.take).    *)
fun CHOOSE_L ([],cthm) thm = thm
  | CHOOSE_L (vars,cthm) thm =
let     val (xvars,bodies) = guarenteed get_exists (assert "CHOOSE_L" [
                                ("cthm is not existentially quantified",boolSyntax.is_exists)] (concl cthm))
        (* keep only as many quantifier levels as there are vars *)
        val (xvars',bodies') = (List.take(xvars,length vars),List.take(bodies,length vars))
in
        (* Pair each var with ASSUME of its existential level, with the bound  *)
        (* variables renamed to vars by subst (bound occurrences are           *)
        (* untouched, so outer levels are instantiated in the inner ones).     *)
        (* foldr applies CHOOSE innermost-first, so each step eliminates the   *)
        (* hypothesis introduced by the previous one.                          *)
        PROVE_HYP cthm (foldr (uncurry CHOOSE) thm
                (map2 (C pair o ASSUME o subst (map2 (curry op|->) xvars' vars)) bodies' vars))
end handle e => wrapException "CHOOSE_L" e;
end;
505
(* Specialises all leading universals of thm, then generalises over list.  *)
(* Bound variables named in list are first replaced by fresh genvars and   *)
(* renamed back afterwards, so list can re-order existing quantifiers.     *)
fun GEN_THM list thm =
let
  val (bound, _) = strip_forall (concl thm)
  val fresh = map (fn v => if mem v list then genvar (type_of v) else v) bound
  val _ = assert "GEN_THM" [("List is not a list of variables", all is_var)] list
  val renaming = zip bound fresh
  fun target v = assoc v renaming handle _ => v
in
  CONV_RULE (RENAME_VARS_CONV (map (fst o dest_var) list))
            (GENL (map target list) (SPECL fresh thm))
  handle e => wrapException "GEN_THM" e
end;
515
(* PROVE_HYP th1 th2, asserting (when !debug is set) that the conclusion  *)
(* of th1 is actually a hypothesis of th2.                                 *)
fun PROVE_HYP_CHECK th1 th2 =
    let val checked =
            assert "PROVE_HYP_CHECK"
              [("Conclusion of first argument is not a hypothesis of the second",
                fn th => mem (concl th) (hyp th2))]
              th1
    in  PROVE_HYP checked th2
    end;
519
local
        (* First two conjuncts of AND_CLAUSES, expected to be               *)
        (* |- T /\ t = t and |- t /\ T = t respectively.                     *)
        val (AND_L_T::AND_R_T::_) = CONJUNCTS (SPEC_ALL AND_CLAUSES)
        fun ass1 s = assert s [("First argument not of type :bool",curry op= bool o type_of)]
        fun ass2 s = assert s [("Second argument not of type :bool",curry op= bool o type_of)]
in
(* ADDR_AND_CONV term2 term1 : [term2] |- term1 = term1 /\ term2           *)
fun ADDR_AND_CONV term2 term1 =
        check_standard_conv "ADDR_AND_CONV"
                (term1,SYM (RIGHT_CONV_RULE (REWR_CONV AND_R_T)
                        (AP_TERM (mk_comb(conjunction,ass1 "ADDR_AND_CONV" term1))
                        (EQT_INTRO (ASSUME (ass2 "ADDR_AND_CONV" term2))))))
(* ADDL_AND_CONV term2 term1 : [term2] |- term1 = term2 /\ term1           *)
(* The theorem's lhs is term1 (not term2), so term1 is what               *)
(* check_standard_conv must compare against; checking term2 made every     *)
(* call with distinct terms fail while !debug was set.                     *)
fun ADDL_AND_CONV term2 term1 =
        check_standard_conv "ADDL_AND_CONV"
                (term1,SYM (RIGHT_CONV_RULE (REWR_CONV AND_L_T)
                        (AP_THM (AP_TERM conjunction (EQT_INTRO (ASSUME (ass2 "ADDL_AND_CONV" term2))))
                        (ass1 "ADDL_AND_CONV" term1))))
end;
536
(* Rewrites term with thm after fully instantiating thm (types and terms,  *)
(* including hypotheses) to match its lhs. Fails with NO_CONV if the match *)
(* would instantiate a variable named like the head of thm's lhs.          *)
fun MATCH_CONV thm term =
let
  val pattern = lhs (concl thm)
  val match = match_term pattern term
  val head = repeat rator pattern
  fun same_var_name x y = fst (dest_var x) = fst (dest_var y) handle _ => false
  val head_instantiated = op_mem same_var_name head (map #redex (fst match))
in
  if head_instantiated then NO_CONV term
  else check_standard_conv "MATCH_CONV"
         (term, REWR_CONV (INST_TY_TERM match thm) term)
end;
546
(* ORDER_FORALL_CONV list ``!a1 .. an rest. b`` where list is a permutation *)
(* of a1 .. an:  |- (!a1 .. an rest. b) = (!list rest. b).                  *)
(* Raises a Standard polyExn if list is not a permutation of the leading   *)
(* quantified variables (the message previously said "is equal to" for     *)
(* this failure case).                                                      *)
fun ORDER_FORALL_CONV list term =
let     val (a,b) = strip_forall term
        val (vars,body) = (List.take(a,length list),list_mk_forall(List.drop(a,length list),b))
                handle e => wrapException "ORDER_FORALL_CONV" e
        val _ = if set_eq vars list then () else
                raise (mkStandardExn "ORDER_FORALL_CONV"
                        ("Variable set: " ^ xlist_to_string term_to_string list ^
                         "\n is not equal to the quantifier set of: " ^ term_to_string term))
in
        check_standard_conv "ORDER_FORALL_CONV" (term,IMP_ANTISYM_RULE
                (DISCH_ALL (GENL list (SPECL vars (ASSUME term))))
                (DISCH_ALL (GENL vars (SPECL list (ASSUME (list_mk_forall(list,body))))))
                handle e => wrapException "ORDER_FORALL_CONV" e)
end;
561
(* ORDER_EXISTS_CONV l term                                                  *)
(*     Existential counterpart of ORDER_FORALL_CONV: proves                  *)
(*         |- (?r. body) = (?l. body)                                        *)
(*     where r are the first length(l) existentially bound variables of term *)
(*     and body is the remainder (which may still be existential).           *)
fun ORDER_EXISTS_CONV l term =
let     val (bound,matrix) = strip_exists term
        val nl = length l
        val (r,body) = (List.take(bound,nl),list_mk_exists(List.drop(bound,nl),matrix))
                        handle e => wrapException "ORDER_EXISTS_CONV" e
        (* |- (?src. body) ==> (?dst. body): choose a witness for src, then  *)
        (* re-introduce the existentials in the order given by dst.          *)
        fun reorder_imp src dst =
                DISCH_ALL (CHOOSE_L (src,ASSUME (list_mk_exists(src,body)))
                        (foldr (fn (v,th) => SIMPLE_EXISTS v th) (ASSUME body) dst))
in
        IMP_ANTISYM_RULE (reorder_imp r l) (reorder_imp l r)
        handle e => wrapException "ORDER_EXISTS_CONV" e
end;
574
local
(* Reorders the leading universal quantifiers of term with ORDER_FORALL_CONV.
   The binders v for which  flip (target = v)  holds come first and the
   rest follow.  Here target is the argument of the body's left-hand side,
   computed before any binder is examined. *)
fun order_conv flip term =
let     val (bound,body) = strip_forall term
        val target = (rand o lhs) body
        val (front,back) = partition (fn v => flip (target = v)) bound
in      ORDER_FORALL_CONV (front @ back) term
end
in
(* Expands  !vs. f = g  pointwise with FUN_EQ_THM and moves the new argument
   variable to the front of the quantifiers. *)
val FUN_EQ_CONV = STRIP_QUANT_CONV (REWR_CONV FUN_EQ_THM) THENC order_conv I
(* Moves the argument variable to the back, then folds the pointwise
   equation back into a function equation. *)
val UNFUN_EQ_CONV = order_conv not THENC ONCE_DEPTH_CONV (REWR_CONV (GSYM FUN_EQ_THM))
end
585
local
(* Applies UNBETA_CONV for each element of xs in turn.  Each later element
   is handled inside the operator produced by the previous step.  The
   recursive call is partially applied, so each conversion is only built
   when it is reached. *)
fun unbeta_all xs tm =
        case xs of
          [] => ALL_CONV tm
        | x::rest => (UNBETA_CONV x THENC RATOR_CONV (unbeta_all rest)) tm
in
(* UNBETA_LIST_CONV list term: abstracts the terms of list out of term, the
   last element of list being abstracted outermost. *)
fun UNBETA_LIST_CONV list term = check_standard_conv "UNBETA_LIST_CONV"
        (term,unbeta_all (rev list) term handle e => wrapException "UNBETA_LIST_CONV" e)
end;
594
local
(* Worker for NTH_CONJ_CONV.  Index 0 converts the left conjunct of a
   conjunction, or the whole term when it is not a conjunction.  A non-zero
   index descends into the right conjunct. *)
fun NTH_CONJ_CONV' 0 conv term =
        if is_conj term then AP_THM (AP_TERM conjunction (conv (fst (dest_conj term)))) (snd (dest_conj term))
                        else conv term
  | NTH_CONJ_CONV' n conv term =
        AP_TERM (mk_comb(conjunction,fst (dest_conj term))) (NTH_CONJ_CONV' (n - 1) conv (snd (dest_conj term)))
in
(* NTH_CONJ_CONV n conv term
     Applies conv to conjunct n (counting from 0) of the right-associated
     conjunction term.  conv must return an equation whose left-hand side
     is exactly that conjunct.  A negative n is rejected up front with a
     standard exception.  Otherwise the worker would walk to the last
     conjunct and fail there with an unrelated error. *)
fun NTH_CONJ_CONV n conv term =
let     val _ = if n < 0
                then raise (mkStandardExn "NTH_CONJ_CONV"
                                ("Negative conjunct index: " ^ Int.toString n))
                else ()
        fun conv' h =
                assert "NTH_CONJ_CONV" [
                        ("CONV returned a non-equality for conjunction: " ^ term_to_string h,is_eq o concl),
                        ("CONV returned a theorem that does not match conjunction: "  ^ term_to_string h,
                                curry op= h o lhs o concl)] (conv h)
in
        check_standard_conv "NTH_CONJ_CONV" (term,NTH_CONJ_CONV' n conv' term
                handle e => wrapException "NTH_CONJ_CONV" e)
end
end;
613
(* CASE_SPLIT_CONV term                                                      *)
(*     term must be  !x. body  where TypeBase has an nchotomy theorem for    *)
(*     the type of x.  Proves term equal to a conjunction with one conjunct  *)
(*     per disjunct of that nchotomy.  In each conjunct x is instantiated to *)
(*     the disjunct's right-hand side (a constructor application), whose    *)
(*     arguments are universally quantified.  Note that SPEC_ALL is used,    *)
(*     so any further leading universal quantifiers of body are stripped as  *)
(*     well and left free in the conjuncts.                                  *)
fun CASE_SPLIT_CONV term =
let     val (xvar,body) = dest_forall term handle e => wrapException "CASE_SPLIT_CONV" e
        val t = type_of xvar
        (* Case-exhaustion theorem for the type of the quantified variable. *)
        val nchot_thm = TypeBase.nchotomy_of t
                handle e => raise (mkStandardExn "CASE_SPLIT_CONV"
                                ("An nchotomy does not exist for the type of the " ^
                                 " universally quantified variable: " ^ type_to_string t))
        val nchot = ISPEC xvar nchot_thm
                handle e => raise (mkDebugExn "CASE_SPLIT_CONV"
                        ("TypeBase returned an nchotomy for type " ^ type_to_string t ^
                         " which was not universally quantified with a variable of the same type!"))
        (* Rename the existential variables of every disjunct away from all *)
        (* variables occurring in the input term.                            *)
        (* NOTE(review): each binder is varied independently against        *)
        (* all_vars only, so two binders of one disjunct could presumably   *)
        (* receive the same new name -- confirm.                             *)
        val all_vars = find_terms is_var term
        fun VARIANT_CONV term =
        let     val vars = fst (strip_exists term)
        in      RENAME_VARS_CONV (map (fst o dest_var o variant all_vars) vars) term
        end;
        val nchot' = CONV_RULE (EVERY_DISJ_CONV VARIANT_CONV) nchot handle e => wrapException "CASE_SPLIT_CONV" e
        val nchots = strip_disj (concl nchot')

        (* term ==> conjunction: instantiate x to each case and generalise  *)
        (* over that case's constructor arguments.                           *)
        val full_left = DISCH_ALL (LIST_CONJ (map (fn n => GENL (snd (strip_comb (rhs (snd (strip_exists n)))))
                                (INST [xvar |-> rhs (snd (strip_exists n))] (SPEC_ALL (ASSUME term)))) nchots))
                        handle e => wrapException "CASE_SPLIT_CONV" e

        (* The conjunction, and for each case the conjunct rewritten back to *)
        (* x under the assumption  x = constructor application.             *)
        val r_tm = snd (dest_imp_only (concl full_left)) handle e => wrapException "CASE_SPLIT_CONV" e
        val right = map2 (fn n => fn c => PURE_REWRITE_RULE [GSYM (ASSUME (snd (strip_exists n)))] (SPEC_ALL c)) nchots (CONJUNCTS (ASSUME r_tm))
                        handle e => wrapException "CASE_SPLIT_CONV" e
        (* Discharge the existential of each disjunct; a disjunct with no   *)
        (* existential variables makes CHOOSE_L fail and is kept as is.      *)
        val right' = map2 (fn n => fn r => CHOOSE_L (fst (strip_exists n),ASSUME n) r handle _ => r) nchots right
        (* conjunction ==> term, by case analysis on the nchotomy. *)
        val full_right = DISCH_ALL (GEN xvar (DISJ_CASESL nchot' right'))
                        handle e => wrapException "CASE_SPLIT_CONV" e
in
        IMP_ANTISYM_RULE full_left full_right handle e => wrapException "CASE_SPLIT_CONV" e
end
646
local
(* Pushes the enclosing function applications into a conditional with
   COND_RAND.  The conditional is first lifted out of the argument
   (recursively), then out of the current application.  PCC always
   succeeds, falling back to ALL_CONV. *)
fun PCC term =
        let     val lifted_arg = RAND_CONV PCC THENC REWR_CONV COND_RAND
        in      (REWR_CONV COND_RAND ORELSEC lifted_arg ORELSEC ALL_CONV) term
        end
in
(* PUSH_COND_CONV: UNCHANGED is passed through untouched; any other failure
   is wrapped with the function name. *)
fun PUSH_COND_CONV term =
        PCC term
        handle e => (case e of
                       UNCHANGED => raise UNCHANGED
                     | _ => wrapException "PUSH_COND_CONV" e)
end
654
local
(* From |- a = a' and |- b = b' build |- a /\ b = a' /\ b'. *)
fun MC thm1 thm2 = MK_COMB (AP_TERM conjunction thm1,thm2)
(* Right-associated conjunction of a non-empty list of equations.  It is
   folded from the last element outwards, so the innermost conjunction is
   built first. *)
fun LMC thms =
        case rev thms of
          [] => raise (mkStandardExn "LIST_MK_CONJ" "Empty list")
        | final::others => foldl (fn (th,acc) => MC th acc) final others
in
fun MK_CONJ thm1 thm2 = MC thm1 thm2 handle e => wrapException "MK_CONJ" e
fun LIST_MK_CONJ [] = LMC []
  | LIST_MK_CONJ thms = LMC thms handle e => wrapException "LIST_MK_CONJ" e
end;
665
local
(* etrans f l r thm1 thm2
     Chains two universally quantified theorems with the rule f (TRANS or
     IMP_TRANS, see mtrans).  The l-side of thm2's body is matched against
     the r-side of thm1's body, and thm2 is instantiated with that match.
     The combined theorem is then generalised over thm1's variables.
     Raises Empty when either side condition below fails.
   NOTE(review): the match maps variables of thm2's pattern (redexes) to
     terms from thm1 (residues).  Yet the redexes are checked against vars1
     and the residues against vars2, which looks swapped -- confirm intent
     before changing.  TC_THMS iterates to a fixed point, so a looser check
     could enlarge the closure. *)
fun etrans f l r thm1 thm2 =
let     val (vars1,body1) = strip_forall (concl thm1)
        val (vars2,body2) = strip_forall (concl thm2)
        val match = match_term (l body2) (r body1)
        val _ = if null (set_diff (map #redex (fst match)) vars1) then ()
                else raise Empty
        val _ = if null (set_diff (map #residue (fst match)) vars2) then ()
                else raise Empty
        val thm1' = SPEC_ALL thm1
        val thm2' = INST_TY_TERM match (SPEC_ALL thm2)
in
        GENL vars1 (f thm1' thm2')
end
(* mtrans t1 t2: chains t1 and t2 as equations (TRANS).  Failing that, it
   chains them as implications (IMP_TRANS).  Every exception of the first
   attempt, Interrupt included, falls through to the second. *)
fun mtrans t1 t2 =
        etrans TRANS lhs rhs t1 t2 handle e =>
        etrans IMP_TRANS (fst o dest_imp_only) (snd o dest_imp_only) t1 t2;
(* trans_all xs ys: every successful chaining of an element of xs with an
   element of ys.  Failed chainings are dropped by mapfilter. *)
fun trans_all _ [] = []
  | trans_all [] _ = []
  | trans_all (x::xs) ys = (mapfilter (mtrans x) ys) @ trans_all xs ys
in
(* TC_THMS thms
     Closure of thms under mtrans.  It repeatedly adds chained theorems whose
     conclusions are not yet present, until no new conclusion appears.
     Conclusions are compared with =, i.e. syntactically, not up to alpha.
     NOTE(review): this terminates only if that closure is finite. *)
fun TC_THMS thms =
let     val next = trans_all thms thms
        val diff = op_set_diff (fn a => fn b => concl a = concl b) next thms
in
        if null diff then thms else TC_THMS (diff @ thms)
end
end
694
local
fun assert' x = assert "prove_rec_fn_exists" x
val dexn = mkDebugExn "prove_rec_fn_exists"
(* PRFE axiom term
     Validates its inputs before delegating to Prim_rec.prove_rec_fn_exists,
     so that malformed input gives a descriptive exception:
       - term must be a conjunction of generalised equations;
       - no clause may have free variables on its right-hand side other
         than those of its left-hand side and the function heads;
       - under its universal quantifiers, axiom must be an existentially
         quantified conjunction of generalised equations;
       - the constructors used in the axiom clauses and in the term clauses
         must coincide, compared with same_const.  A constructor here is
         the head of the last argument of a left-hand side. *)
fun PRFE axiom term =
let     val _ = assert' [("Not a (right associative) conjunction of generalised equalities: " ^ term_to_string term,
                        is_conjunction_of (is_eq o snd o strip_forall))] term
        val conjuncts = map (snd o strip_forall) (strip_conj term)
        val funcs = map (fst o strip_comb o lhs) conjuncts
        (* Right-hand-side variables not bound by the left-hand side. *)
        fun fvs conj = (set_diff (set_diff (free_vars (rhs conj)) (free_vars (lhs conj))) funcs,conj)
        val _ = case (total (first (not o null o fst)) (map fvs conjuncts))
                of SOME (var_list,clause) => raise (dexn ("The variables; " ^ xlist_to_string term_to_string var_list ^
                        " are free in the clause: " ^ term_to_string clause))
                |  NONE => ()
        val ax_err = "Axiom is not an existentially quantified conjunction of equalities: " ^ thm_to_string axiom;
        val _ = assert' [(ax_err,can Psyntax.dest_exists o snd o strip_forall o concl),
                        (ax_err,is_conjunction_of (is_eq o snd o strip_forall) o snd o strip_exists o
                                snd o strip_forall o concl)] axiom;
        (* Constructors on each side; both set differences must be empty. *)
        val constructors_axiom = map (repeat rator o rand o lhs o snd o strip_forall)
                        ((strip_conj o snd o strip_exists o snd o strip_forall o concl) axiom)
        val constructors_term = map (repeat rator o rand o lhs o snd o strip_forall) (strip_conj term)
        val _ = case (op_set_diff same_const constructors_axiom constructors_term)
                of [] => ()
                |  list => raise (dexn ("The constructors; " ^ xlist_to_string term_to_string list ^
                                " are not used in the function"))
        val _ = case (op_set_diff same_const constructors_term constructors_axiom)
                of [] => ()
                |  list => raise (dexn ("The constructors; " ^ xlist_to_string term_to_string list ^
                        " are used in the function but not specified in the axiom"))
in
        Prim_rec.prove_rec_fn_exists axiom term handle e => wrapException "prove_rec_fn_exists" e
end
in
fun prove_rec_fn_exists axiom term = PRFE axiom term
end;
729
730(*****************************************************************************)
731(* Type tools:                                                               *)
732(*                                                                           *)
733(* constructors_of : hol_type -> term list                                   *)
734(*     Like TypeBase.constructors_of except ensures that the result types    *)
735(*     match exactly                                                         *)
736(*                                                                           *)
(* base26          : int -> char list                                        *)
(*     Converts a number n >= 0 to the characters of the nth string in the   *)
(*     sequence {a,b,..,z,aa,ab,...}                                         *)
739(*                                                                           *)
740(* base_type       : hol_type -> hol_type                                    *)
741(*     Changes a type (t0,t1,t2...) t to ('a,'b,'c,...) t                    *)
742(*                                                                           *)
743(* sub_types       : hol_type -> hol_type list                               *)
744(*     Returns direct sub-types of the type given                            *)
745(*                                                                           *)
(* uncurried_subtypes : hol_type -> hol_type list                            *)
747(*     Returns direct sub-types of the type given but treats constructors of *)
748(*     the type Cn:t0 -> .. -> tn -> t as :t0 * ... * tn -> t                *)
749(*                                                                           *)
750(* split_nested_recursive_set :                                              *)
751(*             hol_type -> (hol_type * (hol_type list * hol_type list)) list *)
752(*     Returns a list mapping a set of mutually recursive types to nested    *)
753(*     mutually recursive types and direct sub-types                         *)
754(*                                                                           *)
755(* zip_on_types    :  ('a -> hol_type) -> ('b -> hol_type) ->                *)
756(*                                      'a list -> 'b list -> ('a * 'b) list *)
757(*     Finds a mapping between two lists by matching types after applying    *)
758(*     a pair of functions                                                   *)
759(*                                                                           *)
(* get_type_string : hol_type -> string                                      *)
(*     Returns a sanitised name for a type: for a type variable its name     *)
(*     with every ' and % character removed, otherwise the name of the type  *)
(*     operator                                                              *)
762(*                                                                           *)
763(* SAFE_INST_TYPE, safe_inst, safe_type_subst                                *)
764(*     Like their standard counterparts, except they prevent capture of      *)
765(*     type variables already in the term                                    *)
766(*                                                                           *)
767(*****************************************************************************)
768
(* constructors_of t                                                         *)
(*     TypeBase's constructors for t, each type-instantiated so that its     *)
(*     final result type is exactly t.                                       *)
fun constructors_of t =
let     fun fit_to_t c = inst (match_type ((snd o strip_fun o type_of) c) t) c
in      map fit_to_t (TypeBase.constructors_of t)
end     handle e => wrapException "constructors_of" e;
773
local
val ord_a = Char.ord #"a"
(* The k-th lowercase letter, counting #"a" as 0. *)
fun letter k = Char.chr (ord_a + k)
(* Bijective base-26 digits of n, prepended to acc (most significant first). *)
fun digits n acc =
        if n >= 26
        then digits (n div 26 - 1) (letter (n mod 26) :: acc)
        else letter n :: acc
in
(* base26 n: the characters of the n-th word of a, b, ..., z, aa, ab, ... *)
fun base26 n = digits n []
end;
781
local
(* The type variable named ' followed by base26 n: 'a, 'b, ..., 'aa, ... *)
fun mk_nvartype n = mk_vartype (implode (#"'" :: base26 n));
(* One canonical type variable per argument of the non-variable type t. *)
fun get_type_params t =
        if is_vartype t
                then []
                else map (mk_nvartype o fst) (enumerate 0 (snd (dest_type t)))
        handle e => wrapException "get_type_params" e;
(* Type variables of t, prepended to A in reverse order of occurrence. *)
fun type_vars_cannonA (t,A) =
        if is_vartype t then t::A
        else if can dest_type t
        then foldl type_vars_cannonA A (snd (dest_type t))
        (* Defensive: keep what has been collected instead of discarding the
           accumulator (the original returned []). *)
        else A;
fun type_vars_cannon t = rev (mk_set (type_vars_cannonA (t,[])))
in
(* base_type t: the type operator of t applied to 'a, 'b, ... *)
fun base_type t =
        mk_type (fst (dest_type t),get_type_params t)
        handle e => wrapException "base_type" e;
(* cannon_type t: t with its type variables renamed to 'a, 'b, ... in an
   order determined by their left-to-right occurrence in t. *)
fun cannon_type t =
        type_subst (map (fn (a,b) => b |-> mk_nvartype a) (enumerate 0 (type_vars_cannon t))) t
end
802
(* sub_types t                                                               *)
(*     The argument types of all constructors of t, without duplicates.      *)
(*     Any failure (e.g. t has no known constructors) gives [].              *)
fun sub_types t =
let     val arg_types = map (fst o strip_fun o type_of) (constructors_of t)
in
        mk_set (flatten arg_types)
end     handle _ => []
808
(* uncurried_subtypes t                                                      *)
(*     Like sub_types, but each constructor contributes the product of its   *)
(*     argument types (nullary constructors contribute nothing).  For pair   *)
(*     types the plain sub_types are returned.  Any failure gives [].        *)
fun uncurried_subtypes t =
let     val cs = constructors_of t
        val is_product = can (match_type (mk_prod (alpha,beta))) t
in
        if is_product
        then sub_types t
        else mk_set (mapfilter (list_mk_prod o fst o strip_fun o type_of) cs)
end     handle _ => [];
815
(* split_nested_recursive_set t                                              *)
(*     graph is the sub-type graph of t (via reachable_graph, plus the edge  *)
(*     (t,t)).  The mutually recursive set is every node that reaches and    *)
(*     is reached by t in its reflexive-transitive closure.  A member whose  *)
(*     base type matches no member of the set counts as nested; the others   *)
(*     are primary.  Each primary x is paired with the types reached from x  *)
(*     when the primary types are removed from sub_types.  Those types are   *)
(*     split into (nested ones, the rest), both without duplicates.          *)
fun split_nested_recursive_set t =
let     val graph = (t,t) :: reachable_graph sub_types t
        val closure = RTC graph
        fun mutual a = mem (a,t) closure andalso mem (t,a) closure
        val mr_set = mk_set (filter mutual (map fst graph))
        fun is_nested ty =
        let     val bt = base_type ty
        in      not (exists (fn m => can (match_type m) bt) mr_set)
        end
        val (nmr,pmr) = partition is_nested mr_set
        fun classify x =
        let     val reached = map snd (reachable_graph (fn ty => set_diff (sub_types ty) pmr) x)
        in      (x,(mk_set ## mk_set) (partition (fn ty => mem ty nmr) reached))
        end
in
        map classify pmr
end     handle e => wrapException "split_nested_recursive_set" e
826
local
(* Every way of removing one element satisfying f from a list, as
   (element, remaining list) pairs, in list order. *)
fun pluck_all f [] = []
  | pluck_all f (x::xs) =
        let     val here = f x
                val later = map (I ## cons x) (pluck_all f xs)
        in      if here then (x,xs)::later else later
        end
in
(* zip_on_types f g xs ys
     Pairs each x in xs with a distinct y in ys such that g y is an
     instance of f x (match_type).  It backtracks over the possible choices
     and raises a standard exception when the lists differ in length or no
     complete pairing exists. *)
fun zip_on_types f g xs ys =
        case (xs,ys) of
          ([],[]) => []
        | (x::rest,_::_) =>
                let     val candidates = pluck_all (can (match_type (f x)) o g) ys
                                        handle e => wrapException "zip_on_types" e
                in
                        tryfind_e (mkStandardExn "zip_on_types" "No match found")
                                (fn (p,l) => (x,p)::zip_on_types f g rest l) candidates
                end
        | _ => raise (mkStandardExn "zip_on_types" "Lists of different length")
end;
842
local
(* Characters dropped from type-variable names. *)
fun unwanted c = c = #"'" orelse c = #"%"
val remove_primes = implode o filter (not o unwanted) o explode
in
(* get_type_string t: for a type variable, its name without any ' or %
   characters; for any other type, the name of its type operator. *)
fun get_type_string t =
        (if is_vartype t then remove_primes (dest_vartype t) else fst (dest_type t))
        handle e => wrapException "get_type_string" e
end
853
local
(* Renaming, to fresh type variables, of every type variable in tyvars that
   occurs in some residue of s but is not itself a redex of s.  Without it,
   such a variable would be identified with the copies that s introduces
   (capture).  Previously only variables equal to a whole residue were
   renamed, which missed occurrences inside compound residues such as
   'b list.  Redexes are left alone: s replaces them simultaneously anyway,
   and renaming them would stop s applying (e.g. ['a |-> 'b, 'b |-> 'a]). *)
fun mapsfrom s tyvars =
let     val redexes = map (fn {redex,residue} => redex) s
        val residue_tyvars = mk_set (flatten (map (fn {redex,residue} => type_vars residue) s))
        fun would_capture a = mem a residue_tyvars andalso not (mem a redexes)
in
        map (fn a => a |-> gen_tyvar()) (filter would_capture tyvars)
end
in
(* SAFE_INST_TYPE s thm: INST_TYPE s, after first renaming the type
   variables of the conclusion and hypotheses of thm that s would capture. *)
fun SAFE_INST_TYPE s thm =
let     val tms = concl thm :: hyp thm
        val map1 = mapsfrom s (mk_set (flatten (map type_vars_in_term tms)))
in
        INST_TYPE s (INST_TYPE map1 thm)
end     handle e => wrapException "SAFE_INST_TYPE" e
(* safe_inst s term: capture-avoiding inst. *)
fun safe_inst s term =
let     val map1 = mapsfrom s (type_vars_in_term term)
in
        inst s (inst map1 term)
end     handle e => wrapException "safe_inst" e
(* safe_type_subst s t: capture-avoiding type_subst. *)
fun safe_type_subst s t =
let     val map1 = mapsfrom s (type_vars t)
in
        type_subst s (type_subst map1 t)
end     handle e => wrapException "safe_type_subst" e
end;
873
874(*****************************************************************************)
875(* Checking functions for different kinds of output:                         *)
876(*                                                                           *)
877(* is_source_function : term -> bool                                         *)
878(*     Returns true if the term is a conjunction of terms of the form:       *)
879(*           (!x .. y. fni f0 .. fn (C a0 .. an) = A [fnj a0...])            *)
880(*     In particular:                                                        *)
881(*           a) there can be no free variables in any clause except fni      *)
882(*           b) the function must be defined for all constructors of all     *)
883(*              argument types.                                              *)
884(*                                                                           *)
885(* is_target_function : term -> bool                                         *)
886(*     Returns true if the term is a conjunction of terms of the form:       *)
887(*           (!x .. y. fni f0 .. fn x = if P x then A x else B x)            *)
888(*     In particular:                                                        *)
889(*           a) there can be no free variables in any clause except fni      *)
890(*           b) calls to any fni cannot occur inside B                       *)
891(*     also returns true if terms are singly constructed, eg:                *)
892(*           (!x .. y. fni f0 .. fn x = A x)                                 *)
893(*                                                                           *)
894(* is_expanded_function : term -> bool                                       *)
(*    Returns true if the term is a function term and every recursive call:  *)
896(*          fni ... x = ... f0 f1 f2 .. fk a ...                             *)
897(*          where a is free in x, or x is free in a                          *)
898(*          there is an fni free in {f0..fk}                                 *)
899(*    has the function (f0 in this case) defined in the conjunction          *)
900(*                                                                           *)
901(*****************************************************************************)
902
local
(* A function term: a conjunction of generalised equations whose left-hand
   sides are applications, and whose only free variables are the heads
   (rator-most terms) of those left-hand sides. *)
fun is_function term =
        all (fn x => x term) [
                is_conjunction_of (is_eq o snd o strip_forall),
                is_conjunction_of (is_comb o lhs o snd o strip_forall),
                null o set_diff (free_vars term) o map (repeat rator o lhs o snd o strip_forall) o strip_conj]
(* Guard of a conditional. *)
val pt = (fn (a,b,c) => a) o dest_cond
(* Alpha-equivalence after abstracting each term over its free variables
   (left to right), i.e. equality up to consistent renaming of those. *)
fun xaconv a b = aconv (list_mk_abs(free_vars_lr a,a)) (list_mk_abs(free_vars_lr b,b))
(* Clause c is a conditional whose guard is p, up to xaconv. *)
fun is_target_term p c =
        (is_cond o rhs o snd o strip_forall) c andalso
        (xaconv p o pt o rhs o snd o strip_forall) c
(* Clause c is not a conditional, or the argument of its left-hand side is
   not free in the else-branch. *)
fun is_target_term_single c =
        (not o is_cond o rhs o snd o strip_forall) c orelse
        let     val x = (rand o lhs o snd o strip_forall) c
                val (_,_,y) = (dest_cond o rhs o snd o strip_forall) c
        in
                not (free_in x y)
        end
(* The guard p of the first conditional clause is looked up once (total)
   rather than twice (can, then tryfind). *)
fun is_function_target term =
        is_function term andalso
        (is_conjunction_of is_target_term_single term orelse
         (case total (tryfind (pt o rhs o snd o strip_forall)) (strip_conj term) of
            NONE => false
          | SOME p => is_conjunction_of (fn x => is_target_term_single x orelse is_target_term p x) term))
(* Every clause pattern-matches on a constant from C. *)
fun encodes_constructors C term =
let     val cs = (map (repeat rator o rand o lhs o snd o strip_forall) o strip_conj) term
in
        all is_const cs andalso
        all (fn c => exists (same_const c) C) cs
end
(* Distinct types of the arguments matched by the clauses. *)
fun get_ftypes term = (mk_set o map (type_of o rand o lhs o snd o strip_forall) o strip_conj) term
fun constructors t = TypeBase.constructors_of t handle _ => []
in
fun is_source_function term =
        is_function term andalso
        encodes_constructors
                (flatten (map constructors (get_ftypes term)))
                term
(* is_function_target already requires is_function. *)
fun is_target_function term = is_function_target term
(* Every recursive call (a subterm whose operator mentions one of the
   function heads and whose argument is related by free_in to the clause's
   argument), taken at its outermost occurrence, has a head that is one of
   the functions defined in term. *)
fun is_expanded_function term =
let     val all_fns = map (repeat rator o lhs o snd o strip_forall) (strip_conj term)
        val rec_calls = flatten (map (fn c =>
                        find_terms (fn t => is_comb t andalso
                                        exists (C free_in (rator t)) all_fns andalso
                                        (free_in ((rand o lhs o snd o strip_forall) c) (rand t) orelse
                                         free_in (rand t) ((rand o lhs o snd o strip_forall) c)))
                                ((rhs o snd o strip_forall) c)) (strip_conj term))
        val shortened = filter (fn x => not (exists (fn t => not (x = t) andalso free_in t x) rec_calls))
                                rec_calls
in
        all (C mem all_fns o repeat rator) shortened
end
end
960
961(*****************************************************************************)
962(* SPLIT_FUNCTION_CONV: (term list -> term -> term -> bool) * thm ->         *)
963(*                                                 thm list -> term -> thm   *)
964(*     SPLIT_FUNCTION_CONV: (is_double_term,pair_def) ho_function_defs term  *)
965(*     replaces any higher order functions with their definitions including  *)
966(*     a hypothesis stating that the variable used to replace them is        *)
967(*     correct. It also rewrites pair_def whenever possible                  *)
968(*                                                                           *)
969(* RFUN_CONV : thm list -> conv                                              *)
970(*     Given a list of equalities, rewrites them in a term provided the      *)
971(*     rewrite occurs for a variable in the constructor of the function      *)
972(*     being rewritten.                                                      *)
973(*                                                                           *)
(* SPLIT_HFUN_CONV: thm -> term list -> term ->                              *)
(*                                      (thm list * term list * thm)         *)
(*     Replaces a term ...f G x... with f' G' x where f' is a variable       *)
(*     constrained in the hypothesis to follow the definition of f and G' is *)
(*     a subset of G containing only variables in the top definition.       *)
(*     Returns the rewrites used, the extended variable list and the theorem *)
978(*                                                                           *)
(* SPLIT_PAIR_CONV: (term list -> term -> term -> bool) -> term list ->      *)
(*                          thm -> term -> (thm list * term list * thm)      *)
(*     Replaces a pair term ...pair f g x... with the definition of pair; if *)
(*     the pair is of the form pair f0 f1 (L/R x) then a new function is     *)
(*     made. Returns the rewrites used, the extended variable list and the   *)
(*     theorem                                                               *)
984(*                                                                           *)
985(* is_single_constructor: translation_scheme -> term -> bool                 *)
986(*     Returns true if the term can be rewritten into the mutual recursion,  *)
987(*     and as such requires removal of (L x) and (R x) not just (L (R x))    *)
988(*                                                                           *)
989(* is_double_term_target : translation_scheme ->                             *)
990(*                                term list -> term -> term -> bool          *)
991(* is_double_term_source : term list -> term -> term -> bool                 *)
992(*     Checks to see if a pair term in a function requires splitting, the    *)
993(*     'target' function deals with decoding and detecting whereas the       *)
994(*     'source' function deals with encoding                                 *)
995(*                                                                           *)
996(*****************************************************************************)
997
(* is_single_constructor scheme term                                         *)
(*     term is a clause  !vs. f ... v = r .  Returns true when r does not    *)
(*     contain the scheme's predicate applied to v, or contains neither      *)
(*     left v nor right v (each application beta-reduced first).             *)
fun is_single_constructor (scheme:translation_scheme) term =
let     val err = mkStandardExn "is_single_constructor" "Term is not of the form: 'f a = b'"
        val {predicate = isP, left, right, ...} = scheme
        val (l,r) = (dest_eq o snd o strip_forall) term handle e => raise err
        val var = rand l handle e => raise err
        val _ = assert "is_single_constructor" [("Recursive variable is of type: " ^ type_to_string (type_of var) ^
                " however the predicate is of type: " ^ type_to_string (type_of isP),
                curry op= (type_of var) o fst o dom_rng o type_of)] isP
        fun mentions f = free_in (beta_conv (mk_comb (f,var))) r
in
        (not (mentions isP) orelse not (mentions left orelse mentions right))
        handle e => wrapException "is_single_constructor" e
end;
1013
(* RFUN_CONV rewrites term                                                   *)
(*     term is a conjunction of clauses  !vs. lhs = rhs .  In each clause,   *)
(*     the head and arguments of the last argument of lhs are collected.     *)
(*     The equations in rewrites (as REWR_CONVs, once-depth) are applied to  *)
(*     the outermost subterms whose argument mentions one of those parts     *)
(*     and which contain no conditional.                                     *)
(*     The per-clause parts are computed once per clause rather than once    *)
(*     per visited subterm.  The unused list of function heads is gone; the  *)
(*     parts computation fails on exactly the same malformed clauses.        *)
fun RFUN_CONV rewrites term =
let     fun conv clause =
        let     val pattern_parts = op:: (strip_comb (rand (lhs (snd (strip_forall clause)))))
                fun rewrite_here tm =
                        if      exists (C free_in (rand tm)) pattern_parts
                                andalso null (find_terms (same_const conditional) tm)
                        then    ONCE_DEPTH_CONV (FIRST_CONV (map REWR_CONV rewrites)) tm
                        else    NO_CONV tm
        in
                ONCE_DEPTH_CONV rewrite_here clause
        end
in
        EVERY_CONJ_CONV conv term
end;
1027
local
fun assert' x = assert "SPLIT_HFUN_CONV" x
(* NOTE(review): func_exn is not referenced in this block. *)
val func_exn = mkDebugExn "SPLIT_HFUN_CONV" "HO Function supplied is of the form \"(f x = A x) /\\ (g x =...\"";

fun wrap e = wrapException "SPLIT_HFUN_CONV" e
in
(* SPLIT_HFUN_CONV hfun_def fvs term
     hfun_def : definition of higher-order functions.  Its function terms
                are the operators of its clause left-hand sides (rator, so
                f G in  f G x = ...).
     fvs      : variables already in use.  The fresh split variables avoid
                them and are added to the returned list.
     term     : a conjunction of generalised function equations.
   Raises UNCHANGED when no function term occurs free in the clause bodies
   of term.  Otherwise each function term x gets a fresh variable split
   applied to the free variables of x's arguments not already in fvs.  The
   result is (rewrites, fvs', th) where
     rewrites : theorems  x = split ...  (with assumptions inherited from
                the assumed implication used to derive them), and
     th       : term = term' /\ hyp_assumption (right-associated, with
                hyp_assumption as an assumption).  term' is term rewritten
                by RFUN_CONV rewrites, and hyp_assumption restates the
                clauses of hfun_def in terms of the split variables. *)
fun SPLIT_HFUN_CONV hfun_def fvs term =
let     val _ = type_trace 3 "->SPLIT_HFUN_CONV\n"
        (* Shape checks on the input term. *)
        val _  = (assert' [
                        ("Term is not a conjunction of equalities",is_conjunction_of (is_eq o snd o strip_forall)),
                        ("Term is not a conjunction of function (not constant) definitions",
                                is_conjunction_of (can dest_comb o lhs o snd o strip_forall))] term)
        (* Distinct operators of the clause left-hand sides of hfun_def. *)
        val function_terms = (mk_set o map (rator o lhs o snd o strip_forall) o strip_conj o concl) hfun_def
        (* Nothing to do unless some function term occurs in a clause body. *)
        val _ = if exists (C free_in ((list_mk_conj o map (snd o strip_forall) o strip_conj) term)) function_terms
                        then () else raise UNCHANGED
        val _ = assert' [("Function list is not a list of variables",all is_var)] fvs
        val _ = assert' [("Constants specified in higher order function: " ^ thm_to_string hfun_def,
                         all (can dom_rng o type_of))] function_terms;

        (* One fresh split variable per function term.  foldr keeps         *)
        (* new_consts in the same order as function_terms, each pair being  *)
        (* (argument variables, split application).                          *)
        val (fvs',new_consts) =
                foldr (fn (x,(fvs,consts)) =>
                        let     val (arg_type,res_type) = dom_rng (type_of x)
                                val ftvs = set_diff (free_varsl (snd (strip_comb x))) fvs
                                val nc = variant fvs (mk_var("split",
                                        list_mk_fun (map type_of ftvs,arg_type --> res_type)))
                        in
                                (nc::fvs,(ftvs,list_mk_comb(nc,ftvs))::consts)
                        end) (fvs,[]) function_terms handle e => wrap e

        (* !ftvs. split ftvs = x  for each function term x. *)
        val concl_assumptions = map2 (fn (a,b) => curry list_mk_forall a o curry mk_eq b) new_consts function_terms
                                        handle e => wrap e;
        (* Clauses of hfun_def with the function terms replaced by their    *)
        (* split applications, universally quantified.                       *)
        val hyp_assumption = list_mk_conj (map (fn x =>
                                        list_mk_forall(free_vars_lr (rand (lhs x)),
                                        list_mk_forall(mk_set (flatten (map fst new_consts)),
                                                subst (map2 (curry op|->) function_terms (map snd new_consts)) x)))
                                        ((map (snd o strip_forall) o strip_conj o concl) hfun_def))
                                handle e => wrap e
        (* x = split ftvs, read off the assumed implication                 *)
        (* hyp_assumption ==> concl_assumptions.                             *)
        val rewrites = map (SYM o SPEC_ALL) (CONJUNCTS (UNDISCH_ONLY
                                (ASSUME (mk_imp(hyp_assumption,list_mk_conj concl_assumptions)))))
                                handle e => wrap e
in
        (rewrites,fvs',check_standard_conv "SPLIT_HFUN_CONV" (term,
                (RIGHT_CONV_RULE (ADDR_AND_CONV hyp_assumption THENC PURE_REWRITE_CONV [GSYM CONJ_ASSOC])
                (RFUN_CONV rewrites term)) handle e => wrap e))
end
end;
1074
local
val debug_exn = mkDebugExn "SPLIT_PAIR_CONV"
val func_exn = debug_exn "Term is not a conjunction of equalities";
val pair_exn = debug_exn "Pair theorem is not of the form \"pair f g x = A (f x) (g x)\""
(* Exception wrappers that pass UNCHANGED through untouched. *)
fun wrap UNCHANGED = raise UNCHANGED | wrap e = wrapException "SPLIT_PAIR_CONV" e
fun wrapd UNCHANGED = raise UNCHANGED | wrapd e = wrapException "SPLIT_PAIR_CONV (fix_double_term)" e

(* FUN_EQ_RULE turns  |- !vs. f a = g a  into  |- !vs'. f = g , where a is
   the argument of the left-hand side and vs' is vs without a.  It uses
   FUN_EQ_THM backwards. *)
fun FUN_EQ_RULE thm =
let     val (vars,body) = (strip_forall o concl) thm
        val a = (rand o lhs) body
in
        GENL (set_diff vars [a]) (CONV_RULE (REWR_CONV (GSYM FUN_EQ_THM)) (GEN a (SPEC_ALL thm)))
end handle e => wrapException "SPLIT_PAIR_CONV (FUN_EQ_RULE)" e

(* fix_double_term fvs funcs pair_def term
     term is an application  l_thm px .  new_term is a fresh split variable
     (avoiding fvs) applied to vars, the free variables of l_thm other than
     funcs.  p is the left-hand-side argument of pair_def, renamed apart
     from vars and type-instantiated to the domain of l_thm.  Then
       func : |- (!.. new_term p = pair_def rhs at p) ==>
                 (!.. new_term p = l_thm p)
     When p is a pair, pair induction (presumably) generalises func to a
     single variable before FUN_EQ_RULE.  Returns
       (split variable :: fvs, (|- l_thm = new_term, antecedent of func)),
     the rewrite carrying that antecedent as an assumption. *)
fun fix_double_term fvs funcs pair_def term =
let     val (l_thm,px) = with_exn dest_comb term pair_exn
        val vars = set_diff (free_vars l_thm) funcs
        val new_term = list_mk_comb(variant fvs (mk_var("split",
                                foldr (fn (a,t) => type_of a --> t) (type_of l_thm) vars)),vars) handle e => wrapd e

        (* Argument pattern of pair_def, renamed apart and type-fitted. *)
        val pvar1 = with_exn (rand o lhs o concl) pair_def pair_exn
        val pvar2 = subst (fst (foldl (fn (v,(s,fvs)) => let val x = variant fvs v in ((v |-> x) :: s,x::fvs) end)
                                ([],vars) (free_vars pvar1))) pvar1 handle e => wrapd e;
        val pvar3 = inst (match_type (type_of pvar2) (fst (dom_rng (type_of l_thm)))) pvar2 handle e => wrapd e;

        val func =  snd (EQ_IMP_RULE (STRIP_QUANT_CONV (RAND_CONV (REWR_CONV pair_def))
                        (list_mk_forall(free_vars_lr pvar3 @ vars,mk_eq(mk_comb(new_term,pvar3),mk_comb(l_thm,pvar3))))))
                        handle e => wrapd e

        val rewrite = if is_pair pvar3
                        then FUN_EQ_RULE (HO_MATCH_MP (TypeBase.induction_of (mk_prod(alpha,beta)))
                                        (UNDISCH_ONLY func) handle e => wrapd e)
                        else FUN_EQ_RULE (UNDISCH_ONLY func handle e => wrapd e)
in
        (repeat rator new_term :: fvs,(GSYM rewrite,fst (dest_imp_only (concl func)))) handle e => wrapd e
end;

in
(* SPLIT_PAIR_CONV is_double_term fvs pair_def term
     Finds, in each clause of term, the subterms x whose operator matches
     the operator of pair_def's left-hand side and which contain a free
     variable of the clause's last left-hand-side argument.  Those accepted
     by  is_double_term funcs clause x  are replaced via fix_double_term.
     Returns (rewrites, fvs', th) where th : term = term' /\ the new
     function definitions (right-associated), term' being term rewritten
     by RFUN_CONV rewrites.  UNCHANGED raised by inner calls is re-raised
     as is. *)
fun SPLIT_PAIR_CONV is_double_term fvs pair_def term =
let     val _ = type_trace 3 "->SPLIT_PAIR_CONV\n"
        val _ = type_trace 4 ("Term: " ^ term_to_string term ^ "\n")
        val _ = type_trace 4 ("FVS:  " ^ xlist_to_string term_to_string fvs ^ "\n")
        val pair_def_spec = SPEC_ALL pair_def
        val pair_left = with_exn (rator o lhs o concl) pair_def_spec pair_exn
        val clauses = with_exn strip_conj term func_exn
        val funcs = with_exn (mk_set o map (repeat rator o lhs o snd o strip_forall)) clauses func_exn
        (* (clause, candidate pair subterm) pairs. *)
        val split_terms = flatten (map (fn c => map (pair c)
                        (find_terms (fn x => is_comb x andalso exists (C free_in x) (free_vars (rand (lhs (snd (strip_forall c)))))
                                andalso can (match_term pair_left) (rator x)) c)) clauses);
        val double_terms = mk_set (map snd (filter (uncurry (is_double_term funcs)) split_terms)) handle e => wrap e

        (* Fix each double term, threading the used-variable list. *)
        val (fvs',(rewrites,new_funcs)) =
                foldl (fn (double,(fvs,(RWS,NFS))) =>
                        (I ## (C cons RWS ## C cons NFS))
                                (fix_double_term fvs funcs pair_def_spec double))
                (fvs,([],[])) double_terms
                handle e => wrap e
in
        (rewrites,fvs',check_standard_conv "SPLIT_PAIR_CONV" (term,
                foldr (fn (a,thm) =>
                                RIGHT_CONV_RULE (ADDR_AND_CONV a THENC PURE_REWRITE_CONV [GSYM CONJ_ASSOC]) thm)
                        (RFUN_CONV rewrites term) new_funcs)
        handle e => wrap e)
end
end;
1140
1141
local
(* Exceptions from this section are reported against SPLIT_FUNCTION_CONV. *)
val debug_exn = mkDebugExn "SPLIT_FUNCTION_CONV"
val func_exn = debug_exn "Term is not a conjunction of equalities";
fun wrap e = wrapException "SPLIT_FUNCTION_CONV" e

(* SFC (is_double_term,pair_def) hfuns (fvs,thm)                             *)
(*     thm : |- t0 = t.  Repeatedly splits t and chains the results with     *)
(*     TRANS, returning (fvs', |- t0 = t') once no conversion applies.       *)
(*     With no higher-order definitions left, only SPLIT_PAIR_CONV is used,  *)
(*     and UNCHANGED simply ends the iteration.                              *)
fun SFC (is_double_term,pair_def) [] (fvs,thm) =
        (let    val (rewrites,fvs',thm') = SPLIT_PAIR_CONV is_double_term fvs pair_def ((rhs o concl) thm)
        in      SFC (is_double_term,pair_def) [] (fvs',TRANS thm thm')
        end     handle UNCHANGED => (fvs,thm))
(*     Otherwise SPLIT_HFUN_CONV is tried on the definitions via pick_e      *)
(*     (presumably hfuns' holds the definitions not consumed - confirm), and *)
(*     the new rewrites are pushed into the remaining definitions.  On       *)
(*     UNCHANGED, SPLIT_PAIR_CONV is tried; if that is UNCHANGED as well,    *)
(*     debug_exn reports the stuck term.                                     *)
(*     Parsing note: the second handle (UNCHANGED / e => wrap e) is part of  *)
(*     the first handler's body, so it only covers the SPLIT_PAIR_CONV       *)
(*     attempt; other exceptions from the first attempt propagate unwrapped. *)
  | SFC (is_double_term,pair_def) hfuns (fvs,thm) =
        (let    val ((rewrites,fvs',thm'),hfuns') = pick_e
                        UNCHANGED (fn hfun => SPLIT_HFUN_CONV hfun fvs ((rhs o concl) thm)) hfuns
         in     SFC (is_double_term,pair_def)
                        (map (CONV_RULE (ONCE_DEPTH_CONV (FIRST_CONV (map REWR_CONV rewrites)))) hfuns')
                        (fvs',TRANS thm thm')
         end) handle UNCHANGED =>
        (let    val (rewrites,fvs',thm') = SPLIT_PAIR_CONV is_double_term fvs pair_def ((rhs o concl) thm)
         in     SFC (is_double_term,pair_def)
                        (map (CONV_RULE (ONCE_DEPTH_CONV (FIRST_CONV (map REWR_CONV rewrites)))) hfuns)
                        (fvs',TRANS thm thm')
         end) handle UNCHANGED =>
        raise (debug_exn        ("Unable to split function, neither conv applies to term:\n " ^
                                ((term_to_string o rhs o concl) thm) ^
                                "\n remaining function defs: " ^
                                (xlist_to_string thm_to_string hfuns)))
        | e => wrap e
in
(* SPLIT_FUNCTION_CONV pair_double ho_function_defs term                     *)
(*     term must be a source or target function (checked by assert).         *)
(*     pair_double is the (is_double_term, pair_def) pair handed to SFC.     *)
(*     Splitting starts from |- term = term, avoiding the clause heads as    *)
(*     fresh names.  The result's right-hand side must be a fully expanded   *)
(*     source or target function, otherwise the second assert fails.         *)
fun SPLIT_FUNCTION_CONV pair_double ho_function_defs term =
let     val _ = type_trace 2 "->SPLIT_FUNCTION_CONV\n";
        val _ = assert "SPLIT_FUNCTION_CONV" [(
                        "The term:\n" ^ term_to_string term ^
                        "\nis not a valid source or target function",
                        fn x => is_source_function x orelse is_target_function x)] term
        (* Initial avoid-list: the heads of the functions defined in term. *)
        val result = check_standard_conv "SPLIT_FUNCTION_CONV" (term,snd (SFC pair_double ho_function_defs
                ((with_exn (mk_set o map (repeat rator o lhs o snd o strip_forall) o strip_conj) term func_exn),
                REFL term)))
        val _ = assert "SPLIT_FUNCTION_CONV" [(
                        "Result of splitting:\n" ^ term_to_string ((rhs o concl) result) ^
                        "\nis not a fully expanded source or target function,\n" ^
                        "perhaps higher functions are missing from the function definitions given?",
                        fn x => (is_source_function x orelse is_target_function x) andalso is_expanded_function x)]
                ((rhs o concl) result)
in
        result
end
end;
1188
(* is_double_term_target scheme funcs clause term                            *)
(*     term is expected to be three nested applications  b lcall rcall x;    *)
(*     any other shape is reported through wrapException.  Holds when one of *)
(*     funcs occurs free in rcall or lcall, and either the clause is a       *)
(*     single-constructor clause of scheme or x is an application matching   *)
(*     the body of the scheme's left or right lambda.                        *)
fun is_double_term_target (scheme:translation_scheme) funcs clause term =
let     fun err e = wrapException "is_double_term_target" e
        val left_body  = guarenteed (snd o dest_abs o #left) scheme
        val right_body = guarenteed (snd o dest_abs o #right) scheme
        val (outer,arg)   = dest_comb term handle e => err e
        val (inner,rcall) = dest_comb outer handle e => err e
        val (_,lcall)     = dest_comb inner handle e => err e
        fun mentions_func tm = exists (C free_in tm) funcs
        fun matches_lr tm = can (match_term left_body) tm orelse can (match_term right_body) tm
in
        (mentions_func rcall orelse mentions_func lcall) andalso
        (is_single_constructor scheme clause orelse (is_comb arg andalso matches_lr arg))
        handle e => err e
end
1201
1202
(* is_double_term_source funcs clause term                                   *)
(*     term is expected to be three nested applications  b lcall rcall x.    *)
(*     Holds when x is not a pair and one of funcs occurs free in rcall or   *)
(*     lcall.  clause is unused; presumably it is kept so this predicate has *)
(*     the same shape as is_double_term_target.                              *)
fun is_double_term_source funcs (clause:term) term =
let     fun err e = wrapException "is_double_term_source" e
        val (outer,arg)   = dest_comb term handle e => err e
        val (inner,rcall) = dest_comb outer handle e => err e
        val (_,lcall)     = dest_comb inner handle e => err e
        fun mentions_func tm = exists (C free_in tm) funcs
in
        not (pairLib.is_pair arg) andalso (mentions_func rcall orelse mentions_func lcall)
        handle e => err e
end;
1211
1212(*****************************************************************************)
1213(* prove_recind_thms_mutual: translation_scheme -> term -> thm * thm         *)
1214(*                                                                           *)
1215(*     prove_recind_thms_mutual proves the existence of mutually recursive   *)
1216(*     functions when given a recursion theorem of the form:                 *)
1217(*                                                                           *)
1218(*       |- ?fn. fn x =                                                      *)
1219(*               if P x then f (L x) (R x) (fn (L x)) (fn (R x)) else c      *)
1220(*                                                                           *)
1221(*     and a theorem of induction given a theorem of the form:               *)
1222(*                                                                           *)
1223(*       |- !P0. (!x. isP x /\ P0 (L x) /\ P0 (R x) ==> P0 x) /\             *)
1224(*                                      (!x. ~isP x ==> P0 x) ==> !x. P0 x   *)
1225(*                                                                           *)
1226(* build_call_graph : term * term -> term list ->                            *)
1227(*                                     (int * (int list * int list)) list    *)
(*     Given a list of function clauses, returns for each function (by its   *)
(*     index) the indices of the functions it calls on L x and on R x        *)
1230(*                                                                           *)
1231(* create_mutual_theorem: (int * (int list * int list)) list -> thm -> thm   *)
1232(*     Given a call graph from build_call_graph and theorem of the type      *)
1233(*     specified earlier proves a mutually recursive function of the form:   *)
1234(*                                                                           *)
1235(*      |- ?fn0 fn1 fn2. fn0 x =                                             *)
1236(*                 (if P x then f0 (L x) (R x) (fi (L x)) (fj (R x)) else c) *)
1237(*              /\ (if P x then f1 ...                                       *)
1238(*                                                                           *)
1239(* instantiate_mutual_theorem: thm -> term list -> thm                       *)
1240(*     Given a mutually recursive theorem of the form output by the function *)
1241(*     create_mutual_theorem and a set of clauses instantiates the theorem   *)
1242(*     to prove the existence of the functions as defined by the clauses     *)
1243(*                                                                           *)
1244(* create_ind_theorem: (int * (int list * int list)) list ->                 *)
1245(*                                            translation_scheme -> thm      *)
1246(*     Creates a theorem of induction from a call graph and a theorem of the *)
1247(*     form given earlier.                                                   *)
1248(*                                                                           *)
1249(*****************************************************************************)
1250
local
        (* e_rev_assoc L xs: for each term x in xs, the index that L (a list of  *)
        (* (index, function) pairs) associates with the head operator of x.      *)
        (* Terms whose head is not listed are dropped.  Fatal exceptions (e.g.   *)
        (* Interrupt, see isFatal) are re-raised instead of being treated as     *)
        (* "not found".                                                          *)
        fun e_rev_assoc L xs =
                List.mapPartial
                        (fn x => SOME (rev_assoc (repeat rator x) L)
                                handle e => if isFatal e then raise e else NONE)
                        xs

        val type_exn = mkDebugExn "build_call_graph"
                        "The type of the L/R operators does not match the argument of type of one of the clauses"

        (* Follows rand while the operand is itself an application and returns *)
        (* the innermost application reached (x itself if x is not one).      *)
        fun snd_rand x = if is_comb x then (if not (is_comb (rand x)) then x else snd_rand (rand x)) else x;
in
(* build_call_graph (left,right) clauses                                     *)
(*     For clause n, returns (n,(ls,rs)) where ls (resp. rs) are the indices *)
(*     of the functions applied to L x (resp. R x), x being the clause's     *)
(*     last argument.  left/right must be lambdas \x. L x and \x. R x.       *)
fun build_call_graph (left,right) clauses =
let     val _ = type_trace 3 "->build_call_graph\n"
        (* Head operator f of each clause's left-hand side  f a1 .. an x. *)
        val names = map (fst o strip_comb o fst o dest_comb o lhs o snd o strip_forall)
                        (assert "build_call_graph" [
                                ("Second argument is not a list of function clauses",all (is_eq o snd o strip_forall)),
                                ("Second argument is not a list of function (not constant) definitions",
                                        all (can dest_comb o lhs o snd o strip_forall))] clauses)
        val bl = with_exn (snd o dest_abs) left (mkDebugExn "build_call_graph" "Left term is not of the form: \\x.P x")
        val br = with_exn (snd o dest_abs) right (mkDebugExn "build_call_graph" "Right term is not of the form: \\x.P x")
        val fdefs = map (dest_eq o snd o strip_forall) clauses;
        (* Computed once and shared by every clause. *)
        val index = enumerate 0 names

        (* An L/R term is an instance of the body of either lambda. *)
        fun is_lr tm = can (match_term bl) tm orelse can (match_term br) tm

        (* Subterms of def that apply something (not itself an L/R term) to an  *)
        (* L/R term whose innermost application is X applied to the argument.   *)
        fun find_names_of X (func,def) =
        let     val var = rand func
                val isX = curry op= (with_exn (beta_conv o mk_comb) (X,var) type_exn)
                fun t1 x = is_comb x andalso not (is_lr x) andalso is_lr (rand x) andalso isX (snd_rand x)
        in
                find_terms t1 def
        end
in
        map (fn (n,fdef) => (n,(e_rev_assoc index (find_names_of left fdef),
                                e_rev_assoc index (find_names_of right fdef))))
                (enumerate 0 fdefs)
end
end;
1288
local
(* Adds create_mutual_theorem (optionally naming a helper) to the trace. *)
fun wrap "" e = wrapException "create_mutual_theorem" e
  | wrap s  e = wrapException ("create_mutual_theorem (" ^ s ^ ")") e

(* make_out func_types n tm: projects component n out of tm, whose type is   *)
(* the right-nested sum of the types in func_types: one OUTR per preceding   *)
(* entry, then OUTL unless tm is no longer of sum type (last component).     *)
(* Raises Empty if n is not a key of func_types.                             *)
fun make_out [] n tm = raise Empty
  | make_out ((x,ft)::fts) n tm =
        if x = n
                then (if can (sumSyntax.dest_sum o type_of) tm then sumSyntax.mk_outl tm else tm)
                else make_out fts n (sumSyntax.mk_outr tm)
fun make_outp a b c = make_out a b c handle e => wrap "make_out" e;

(* make_in func_types n tm: injects tm as component n of that sum: one INR   *)
(* per preceding entry, then INL unless n is the last entry.                 *)
(* Raises Empty if n is not a key of func_types.                             *)
fun make_in [] n tm = raise Empty
  | make_in [(x,ft)] n tm = if x = n then tm else raise Empty
  | make_in ((x,ft)::fts) n tm =
        if x = n
                then sumSyntax.mk_inl (tm,sumSyntax.list_mk_sum(map snd fts))
                else sumSyntax.mk_inr (make_in fts n tm,ft)
fun make_inp a b c = make_in a b c handle e => wrap "make_in" e;

(* make_out_thm: applies the same projections as make_out to both sides of  *)
(* the equation thm (via AP_TERM).                                           *)
fun make_out_thm [] n thm = raise Empty
  | make_out_thm ((x,ft)::fts) n thm =
let     val l = (lhs o concl) thm
in
        if x = n
                then (if can (sumSyntax.dest_sum o type_of) l then
                        AP_TERM (rator (sumSyntax.mk_outl l)) thm else thm)
                else make_out_thm fts n
                        (AP_TERM (rator (sumSyntax.mk_outr l)) thm)
end
fun make_out_thmp a b c = make_out_thm a b c handle e => wrap "make_out_thm" e;

(* Component n of the recursive result  rt n. *)
fun make_rec_term func_types rt n = make_outp func_types n (mk_comb(rt,term_of_int n) handle e => wrap "make_rec_term" e);

(* For call-graph node (n,(xs,ys)) builds                                    *)
(*     INj_n (f<n> x (proj (rlt i)) .. (proj (rrt j)) ..)                    *)
(* with f<n> a fresh variable, i ranging over xs and j over ys.              *)
fun make_single_term func_types mk_var x_var (rlt,rrt) (n,(xs,ys)) =
        make_inp func_types n (list_mk_comb(mk_var("f"^(int_to_string n),
                        (type_of x_var) -->
                        (foldr (fn (a,t) => assoc a func_types --> t)
                                (foldr (fn (a,t) => assoc a func_types --> t) (assoc n func_types) ys) xs)),
                x_var::(map (make_rec_term func_types rlt) xs @ map (make_rec_term func_types rrt) ys))
                handle e => wrap "make_single_term" e);

(* Nested conditional  if var = n then <term for n> else ...; the last node *)
(* is the final else branch.  Raises Empty on an empty call graph.          *)
fun make_f_term func_types mk_var var x_var rlr [] = raise Empty
  | make_f_term func_types mk_var var x_var rlr [x] = make_single_term func_types mk_var x_var rlr x
  | make_f_term func_types mk_var var x_var rlr ((n,(x,y))::xs) =
let     val r = make_single_term func_types mk_var x_var rlr (n,(x,y))
        val f = make_f_term func_types mk_var var x_var rlr xs
in
        mk_cond(mk_eq(var,term_of_int n),r,f) handle e => wrap "make_f_term" e
end

(* Splits  f0 x rl rr  into its parts; raises exn unless there are exactly *)
(* three arguments.                                                        *)
fun extract_f_term exn fterm =
        case (strip_comb fterm)
        of (f0,[x_var,rlt,rrt]) => (f0,x_var,rlt,rrt)
        |  _ => raise exn;

(* True iff every index called in the call graph is itself a node of it. *)
fun check_call_graph cg =
        all (fn (n,(xs,ys)) =>
                all (fn x => exists (fn a => fst a = x) cg) xs andalso
                all (fn y => exists (fn a => fst a = y) cg) ys) cg

(* INj_n (c<n> x) with c<n> fresh.  NOTE(review): unlike make_single_term   *)
(* this calls make_in rather than make_inp, so Empty is not wrapped.        *)
fun make_c_term func_types mk_var x_var n =
        make_in func_types n
                (mk_comb(mk_var("c" ^ (int_to_string n),type_of x_var --> assoc n func_types),x_var)
                handle e => wrap "make_c_term" e)

(* Strips the OUTL/OUTR wrappers off term.  If term's type is one of the    *)
(* component types and the stripped term matches func (fn y k), abstracts y *)
(* out again, giving  (\x. OUT.. (fn x k)) y  with the bound variable named *)
(* like var.  Fails (NO_CONV) otherwise.                                    *)
fun FTERM_CONV func_types func var term =
let     val (outs,tm) = repeat (fn (l,x) =>
                                if      can (match_term sumSyntax.outl_tm) (rator x) orelse
                                        can (match_term sumSyntax.outr_tm) (rator x)
                                then (rator x::l,rand x) else raise Empty) ([],term)
in
        if mem (type_of term) (map snd func_types) andalso can (match_term func) tm then
                (UNBETA_CONV ((rand o rator) tm) THENC RATOR_CONV (RENAME_VARS_CONV [(fst (dest_var var))])) term
                handle UNCHANGED => raise UNCHANGED | e => wrap "FTERM_CONV" e
        else NO_CONV term
end

in
(* create_mutual_theorem call_graph thm                                      *)
(*     Instantiates the single recursion theorem thm so that its function    *)
(*     returns  num -> sum  (one summand per call-graph node), then projects *)
(*     out one function fn<n> per node to prove a mutually recursive         *)
(*     existence theorem  |- ?fn0 .. fnk. (!x. fn0 x = ...) /\ ...           *)
(*     NOTE(review): extract_f_term accepts  f0 x rl rr  (three arguments),  *)
(*     while the message in exn shows f0 applied to (L x) (R x) as well -    *)
(*     confirm which form the recursion theorems actually have.             *)
fun create_mutual_theorem call_graph thm =
let     val _ = type_trace 3 "->create_mutual_theorem\n"
        val _ = assert "create_mutual_theorem" [("Bad call graph!",check_call_graph)] call_graph
        val exn = mkDebugExn "create_mutual_theorem"
                        ("thm supplied for mutual recursion is not of the form: " ^
                         "\"?fn. !x. fn x = if P x then f0 (L x) (R x) (fn (L x)) (fn (R x)) else c0\"")
        val (fterm,body) = with_exn Psyntax.dest_exists (concl thm) exn;
        (* Result type of fn, and one fresh type variable per call-graph node. *)
        val res_t = type_of (with_exn (rhs o snd o dest_forall) body exn);
        val func_types = map (I ## gen_tyvar o K ()) call_graph
        val sum_t = sumSyntax.list_mk_sum (map snd func_types);

        (* fn now takes a node number and returns the sum of all results. *)
        val inst_it = inst [res_t |-> ``:num`` --> sum_t]
        val var = with_exn (rand o lhs o snd o dest_forall) body exn;

        (* Fresh-variable supply avoiding every variable already in play. *)
        val fvs = ref (free_varsl (fterm :: body :: hyp thm))
        fun mk_var (name,t) =
        let     val nv = variant (!fvs) (Term.mk_var (name,t)) in (fvs := nv :: (!fvs) ; nv) end

        val (_,f_term,c_term) = with_exn (dest_cond o rhs o snd o dest_forall) (inst_it body) exn;
        val (f0,x_var,rlt,rrt) = extract_f_term exn f_term;
        val c0 = rator c_term

        val (x_var',rlt',rrt') = (genvar (type_of x_var),genvar (type_of rlt),genvar (type_of rrt))

        (* f0term / c0term: the step function and base case, both dispatching *)
        (* on the node number v.                                              *)
        val v = mk_var("v",``:num``);
        val f0term =    let     val ft = make_f_term func_types mk_var v x_var' (rlt',rrt') call_graph
                        in      list_mk_abs([x_var',rlt',rrt',v],ft) handle e => wrap "" e end
        val c0term =    mk_abs(x_var',mk_abs(v,foldr (fn (a,t) => mk_cond(mk_eq(v,term_of_int (fst a)),
                                                        make_c_term func_types mk_var x_var' (fst a),t))
                                (make_c_term func_types mk_var x_var' (fst (last call_graph)))
                                (butlast call_graph))) handle e => wrap "" e;

        (* thm1: the assumed equation applied to v, with the application pushed *)
        (* into the conditional; thm2: f0/c0 instantiated and beta-reduced.     *)
        val thm1 = RIGHT_CONV_RULE (REWRITE_CONV [COND_RAND,COND_RATOR]) (AP_THM (SPEC_ALL (ASSUME (inst_it body))) v)
                                handle e => wrap "" e
        val thm2 = BETA_RULE (INST [f0 |-> f0term, c0 |-> c0term] thm1) handle e => wrap "" e
        (* thm3: for each node n, the equation for  fn x n  with the numeral    *)
        (* tests reduced and component n projected out, generalised over x.    *)
        val thm3 = LIST_CONJ (map
                        (fn (n,_) =>    (GEN var o RIGHT_CONV_RULE PUSH_COND_CONV o
                                        make_out_thm func_types n o
                                        (REWR_CONV thm2 THENC ONCE_DEPTH_CONV REDUCE_CONV) o
                                        curry mk_comb (mk_comb(inst_it fterm,var)) o term_of_int) n) call_graph)
                        handle e => wrap "" e
        (* thm4: cancel OUTL/OUTR of injections and expose  (\x. proj (fn x k)) *)
        (* so the projections can be abstracted below.                         *)
        val thm4 = CONV_RULE (  DEPTH_CONV (REWR_CONV sumTheory.OUTL ORELSEC REWR_CONV sumTheory.OUTR) THENC
                                ONCE_DEPTH_CONV (FTERM_CONV func_types (list_mk_comb(inst_it fterm,[var,v])) var)) thm3
                        handle e => wrap "" e;

        (* fn<n> and its witness  \x. proj_n (fn x n). *)
        fun make_var n = mk_var(fst (dest_var fterm) ^ (int_to_string n),type_of var --> assoc n func_types);
        fun make_term n = mk_abs(var,make_out func_types n (mk_comb(mk_comb(inst_it fterm,var),term_of_int n)));

        val thm5 = foldr (fn (a,thm) =>
                        let     val var = make_var (fst a)
                        in      EXISTS (Psyntax.mk_exists(var,subst [make_term (fst a) |-> var] (concl thm)),
                        make_term (fst a)) thm
                end) thm4 call_graph  handle e => wrap "" e
in
        (* Discharge the assumed equation with the instantiated input theorem. *)
        CHOOSE (inst_it fterm,INST [f0 |-> f0term, c0 |-> c0term]
                (INST_TYPE [res_t |-> ``:num`` --> sum_t] thm)) thm5  handle e => wrap "" e
end
end;
1425
local
fun debug_exn s = mkDebugExn "instantiate_mutual_theorem" s;

val exn1 = debug_exn "Function clauses are not all of the form \"!x x0 .. xn. f x = A x0 ... xn\""
val exn2 = debug_exn (  "Recursive theorem is not of the form: " ^
                        "\"?fn0 ... fnm. (!x. fn0 x = A (fn1 (L x)) ... (fnm (R x))) /\\ ... \"")

(* Adds instantiate_mutual_theorem (optionally naming a helper) to the trace. *)
fun wrap "" e = wrapException "instantiate_mutual_theorem" e
  | wrap s  e = wrapException ("instantiate_mutual_theorem (" ^ s ^ ")") e

(* Beta-normalises term (REFL if there is nothing to reduce) and, when list *)
(* is non-empty, abstracts the terms in list back out (UNBETA_LIST_CONV).    *)
fun convit [] term = (DEPTH_CONV BETA_CONV term handle UNCHANGED => REFL term)
  | convit list term = (DEPTH_CONV BETA_CONV THENC UNBETA_LIST_CONV list) term;

(* instantiate_clause term_subst ((n,(func,body)),(thm,mthm))                *)
(*     Matches the n-th conjunct of thm (if P x then rec else const) against *)
(*     the user's clause body (after term_subst), instantiates both thm and  *)
(*     mthm with the resulting type/term matches, and rewrites the n-th      *)
(*     conjunct of thm back into the user's form.                            *)
fun instantiate_clause term_subst ((n,(func,body)),(thm,mthm)) =
let     val thm_clause = with_exn List.nth ((strip_conj (concl thm)),n)
                (debug_exn "Recursion theorem has a different number of clauses than the function clauses supplied")
        val (_,thm_rec,thm_const) = with_exn (dest_cond o rhs o snd o strip_forall) thm_clause exn2
        val (_,term_rec,term_const) = with_exn (dest_cond o subst term_subst) body exn1

        fun DCBC x = DEPTH_CONV BETA_CONV x handle UNCHANGED => REFL x

        (* Re-express each branch of the user's body in the shape of the       *)
        (* corresponding branch of thm (abstracting thm's argument terms).     *)
        val term_const_thm = convit (snd (strip_comb thm_const)) term_const;
        fun drop x = List.drop x handle e => []
        (* Arguments of thm_rec from the third onwards must all occur in the  *)
        (* beta-normalised, substituted body.                                  *)
        val term_rec_thm = convit
                (assert "instantiate_mutual_theorem" [
                        ("Recursive call missing from mutual recursion theorem",
                        (all (C free_in ((rhs o concl o DCBC) (subst term_subst body))) o
                        C (curry drop) 2))] ((snd o strip_comb) thm_rec)) term_rec;

        (* The clause argument x may only occur inside the abstracted arguments. *)
        val _ = assert "instantiate_mutual_theorem"
                        [("x is free in the function body, should be either R x or L x, in function clause:\n" ^
                                (term_to_string (mk_eq (func,body))),
                         (not o free_in (rand func) o repeat rator o rhs o concl))] term_rec_thm

        (* Instantiation making thm's branches match the user's branches. *)
        val insttt =    INST_TY_TERM (match_term thm_const ((rhs o concl) term_const_thm)) o
                        INST_TY_TERM (match_term thm_rec ((rhs o concl) term_rec_thm))
                        handle e => wrap "instantiate_clause" e

in
        (CONV_RULE (NTH_CONJ_CONV n (
                        STRIP_QUANT_CONV (FORK_CONV (UNBETA_LIST_CONV (snd (strip_comb func)),
                                RAND_CONV (REWR_CONV (GSYM term_const_thm)) THENC
                                RATOR_CONV (RAND_CONV (REWR_CONV (GSYM term_rec_thm)))))))
                (insttt thm),
                insttt mthm) handle e => wrap "instantiate_clause" e
end;

in
(* instantiate_mutual_theorem mthm clauses                                   *)
(*     mthm : |- ?fn0 .. fnm. (!x. fn0 x = ...) /\ ..., as produced by        *)
(*     create_mutual_theorem; clauses are the user's function clauses, all   *)
(*     recursive on one argument type.  Instantiates mthm so that its        *)
(*     functions behave as the clauses describe and proves the existence of  *)
(*     functions satisfying the clauses, discharging the assumed body of     *)
(*     mthm with CHOOSE_L.                                                    *)
fun instantiate_mutual_theorem mthm clauses =
let     val _ = type_trace 3 "->instantiate_mutual_theorem\n"
        val split_term = with_exn (map (dest_eq o snd o strip_forall)) clauses exn1
        val (fterms_thm,thm_body) = with_exn (strip_exists o concl) mthm exn2;
        val thm_clauses = map SPEC_ALL (CONJUNCTS (ASSUME thm_body))
        (* Type of x in the condition of each clause; all must agree. *)
        val arg_types = with_exn (map (type_of o rand o (fn (a,b,c) => a) o dest_cond o snd)) split_term exn1;
        val arg_type = hd (assert "instantiate_mutual_theorem"
                                [("Function term is mutually recursive on different types",
                                all (curry op= (hd arg_types)))] arg_types);
        val thm_arg_types = with_exn (map (type_of o rand o lhs o concl)) thm_clauses exn2
        val thm_arg_type = hd (assert "instantiate_mutual_theorem"
                                [("Recursion thm is mutually recursive on different types",
                                all (curry op= (hd thm_arg_types)))] thm_arg_types);

        (* For each clause  f a1 .. an x = body: the extra arguments a1 .. an  *)
        (* and a type instantiation of the recursion result type to            *)
        (* ty(a1) -> .. -> ty(an) -> ty(f a1 .. an x).                         *)
        val (type_subst,args) =
                unzip (map2 (fn tc => fn (func,body) =>
                        let     val args = with_exn (snd o strip_comb o rator) func exn1
                                val res_t = with_exn (type_of o lhs o concl) tc exn2
                        in
                                (res_t |-> list_mk_fun(map type_of args,type_of func),args)
                        end) thm_clauses
                (assert "instantiate_mutual_theorem"
                        [("Recursion theorem has a different number of clauses than the function clauses supplied",
                        curry op= (length thm_clauses) o length)] split_term));

        (* Type-instantiate each recursion clause, apply it to a1 .. an, and *)
        (* push the applications into the conditional.                       *)
        val thm_clauses' =
                map2 (fn a => RIGHT_CONV_RULE (REWRITE_CONV [COND_RATOR]) o
                                C (foldl (uncurry (C AP_THM))) a o INST_TYPE ((thm_arg_type |-> arg_type)::type_subst))
                args thm_clauses handle e => wrap "" e;

        (* Replace each user function f by  \a1 .. an x. fnk x a1 .. an. *)
        val term_subst =
                map2 (fn tc => fn (func,body) =>
                        let     val args = snd (strip_comb (lhs (concl tc)))
                        in
                                (repeat rator func |-> list_mk_abs(tl args,(mk_abs(hd args,lhs (concl tc))))) end)
                thm_clauses' split_term handle e => wrap "" e;

        val (thm1,mthm') = foldl (instantiate_clause term_subst)
                        (LIST_CONJ (map2 (fn x => GEN_THM ((fst o strip_forall) x)) clauses thm_clauses'),
                                INST_TYPE ((thm_arg_type |-> arg_type)::type_subst) mthm)
                        (enumerate 0 split_term);


        (* Existentially quantify, naming each witness after the user's function. *)
        val thm2 = foldr (fn ((func1,func2),thm) =>
                        EXISTS (mk_exists(func1,subst [func2 |-> func1] (concl thm)),func2) thm) thm1
                        (zip    (map (repeat rator o fst) split_term)
                                (map (repeat rator o lhs o snd o strip_forall) (strip_conj (concl thm1))))
                handle e => wrap "" e

in
        CHOOSE_L (fst (strip_exists (concl mthm')),mthm') thm2 handle e => wrap "" e
end
end;
1527
local
        fun mkDebug e = mkDebugExn "create_ind_theorem" e
        fun wrap e = wrapException "create_ind_theorem" e
in
(* create_ind_theorem call_graph scheme                                      *)
(*     From the scheme's induction theorem (assumed to have the form in the  *)
(*     section header) and a call graph, proves                              *)
(*       |- !P0 .. Pk. (!x. isP x /\ Pi (L x) .. /\ Pj (R x) ==> Pn x) /\ .. *)
(*                     (!x. ~isP x ==> Pn x) /\ ..                           *)
(*                     ==> (!x. P0 x) /\ .. /\ (!x. Pk x)                    *)
(*     with one step premise and one base premise per call-graph node n.     *)
fun create_ind_theorem call_graph (scheme:translation_scheme) =
let     val _ = type_trace 3 "->create_ind_theorem\n"
        val ind_thm = #induction scheme
        val isP   = #predicate scheme
        val left  = #left scheme
        val right = #right scheme
        val target = #target scheme

        val x = mk_var("x",target)
        val isPcomb = beta_conv (mk_comb(isP,x)) handle e =>
                        raise (mkDebug ("Predicate for translation scheme " ^ type_to_string target ^
                                        " is not of the form: \\x.P x"))
        (* mkP y p = P<p> (y x) beta-reduced; mkP_var n = P<n> x. *)
        fun mkP y p = mk_comb(mk_var("P" ^ (int_to_string p),target --> ``:bool``),beta_conv (mk_comb(y,x)))
        fun mkP_var n = mk_comb(mk_var("P" ^ (int_to_string n),target --> ``:bool``),x)

        (* For node (n,(l,r)): isP x, then P_i (L x) for i in l, then P_j (R x) for j in r. *)
        val ind_terms_pre =
                map (fn (n,(l,r)) =>
                        (n,isPcomb :: (foldr (fn (a,l) => mkP left a :: l) (map (mkP right) r) l)))
                call_graph handle e => wrap e
        val non_ind_terms =
                map (fn (n,_) => mk_forall(x,mk_imp(mk_neg(isPcomb),mkP_var n))) call_graph handle e => wrap e

        (* Single-predicate induction instantiated with  \x. P0 x /\ .. /\ Pk x. *)
        val full_ind_thm = BETA_RULE (SPEC (mk_abs(x,list_mk_conj(map (mkP_var o fst) call_graph))) ind_thm)
                                handle e => wrap e

        (* thm1: the conjunction of all P_n x, each derived from an assumed step *)
        (* premise tmf and its assumed antecedents; ind_terms are those tmf.     *)
        val (thm1,ind_terms) =
                (LIST_CONJ ## I) (unzip (map (fn (n,tms) =>
                                let     val tmf = mk_forall(x,mk_imp(list_mk_conj tms,mkP_var n))
                                in      (MATCH_MP (ASSUME tmf) (LIST_CONJ (map ASSUME tms)),tmf) end)
                        ind_terms_pre)) handle e => wrap e
        (* Antecedent of the step case of full_ind_thm. *)
        val fi_term = (fst o dest_imp_only o snd o strip_forall o fst o dest_conj o fst o
                        dest_imp_only o concl o SPEC_ALL) full_ind_thm handle e => wrap e

        (* Both premises of full_ind_thm, proved from the assumed step/base terms. *)
        val thm2 = CONJ (GEN x (DISCH fi_term (foldl (uncurry PROVE_HYP) thm1 (CONJUNCTS (ASSUME fi_term)))))
                        (GEN x (DISCH (mk_neg (isPcomb))
                                (LIST_CONJ (map (UNDISCH_ONLY o SPEC_ALL o ASSUME) non_ind_terms))))
                handle e => wrap e
in
        (* Split the conclusion per predicate, discharge the assumptions into one *)
        (* conjunctive antecedent, and generalise over P0 .. Pk.                 *)
        GENL (map (fn (n,_) => mk_var("P" ^ (int_to_string n),target --> ``:bool``)) call_graph)
                (PURE_REWRITE_RULE [AND_IMP_INTRO,GSYM CONJ_ASSOC]
                (foldr (uncurry DISCH) (foldr (uncurry DISCH)
                        (CONV_RULE (TOP_DEPTH_CONV FORALL_AND_CONV) (MP full_ind_thm thm2)) non_ind_terms) ind_terms))
        handle e => wrap e
end
end;
1577
(* prove_recind_thms_mutual scheme term                                      *)
(*     term is a conjunction of mutually recursive function clauses.  Builds *)
(*     the call graph, proves the mutual recursion theorem from the scheme's *)
(*     recursion theorem, instantiates it to the clauses, and derives the    *)
(*     matching mutual induction theorem.  Every failure is traced under     *)
(*     prove_recind_thms_mutual.                                             *)
fun prove_recind_thms_mutual (scheme:translation_scheme) term =
let     val _ = type_trace 3 "->prove_recind_thms_mutual\n"
        fun wrap e = wrapException "prove_recind_thms_mutual" e
        val rec_thm = #recursion scheme
        val left = #left scheme
        val right = #right scheme
        val conjuncts = strip_conj term
        val call_graph = build_call_graph (left,right) conjuncts handle e => wrap e
        (* Wrapped like the other steps so the trace names this function. *)
        val mthm = create_mutual_theorem call_graph rec_thm handle e => wrap e
        (* Terminated with a newline like every other trace message. *)
        val _ = type_trace 2 ("Generated recursion theorem:\n" ^ thm_to_string mthm ^ "\n")
in
        (instantiate_mutual_theorem mthm conjuncts,
         create_ind_theorem call_graph scheme) handle e => wrap e
end;
1592
1593(*****************************************************************************)
1594(* LEQ_REWRITES : term -> term -> thm list -> thm                            *)
1595(*                                                                           *)
1596(* Rewrites the first term to match the second using the list of rewrites    *)
1597(* given.                                                                    *)
1598(*                                                                           *)
1599(*****************************************************************************)
1600
local
(* insert x ys: every list obtained by inserting x at some position of ys. *)
fun insert x [] = [[x]]
  | insert x (y::ys) = (x::y::ys) :: (map (cons y) (insert x ys));
(* perm xs: every permutation of xs. *)
fun perm [] = [[]]
  | perm (x::xs) = flatten (map (insert x) (perm xs))

(* LEQSTEP hit n term1 term2 rewrites                                        *)
(*     Depth-limited search for |- term1 = term2.  Along any path at most    *)
(*     n - 1 rewrite steps (REWR_CONV with one of rewrites) are chained;     *)
(*     otherwise the terms are compared structurally: foralls (trying every  *)
(*     type-compatible order of the bound variables), applications and      *)
(*     abstractions.  Raises Match on failure.  Reaching depth 0 sets hit,   *)
(*     which tells itdeep that a deeper search could still succeed.          *)
fun LEQSTEP hit 0 _ _ _ = (hit := true; raise Match)
  | LEQSTEP hit n term1 term2 rewrites =
        if aconv term1 term2 then ALPHA term1 term2
        else (tryfind_e Match (fn r =>
                let val thm1 = REWR_CONV r term1
                in  TRANS thm1 (LEQSTEP hit (n - 1) ((rhs o concl) thm1) term2 rewrites) end) rewrites)
        handle Match =>
                if is_forall term1 andalso is_forall term2 then
                        tryfind_e Match (fn x =>
                                let     val thm1 = ORDER_FORALL_CONV x term1
                                        val thm2 = RIGHT_CONV_RULE (RENAME_VARS_CONV (map (fst o dest_var) (fst (strip_forall term2)))) thm1
                                        val r = LEQSTEP hit n (snd (strip_forall (rhs (concl thm2)))) (snd (strip_forall term2)) rewrites
                                in
                                        RIGHT_CONV_RULE (STRIP_BINDER_CONV (SOME universal) (REWR_CONV r)) thm2
                                end)
                                (filter (fn x => (map type_of x) = (map type_of (fst (strip_forall term2)))) (perm (fst (strip_forall term1))))
                else if is_comb term1 andalso is_comb term2 then
                        MK_COMB (LEQSTEP hit n (rator term1) (rator term2) rewrites,LEQSTEP hit n (rand term1) (rand term2) rewrites)
                else if is_abs term1 andalso is_abs term2 then
                let     val v1 = bvar term1
                        val v2 = bvar term2
                        val nvar = genvar (type_of v1)
                in
                        CONV_RULE (FORK_CONV (RENAME_VARS_CONV [fst (dest_var v1)],RENAME_VARS_CONV [fst (dest_var v2)]))
                                (MK_ABS (GEN nvar (LEQSTEP hit n (beta_conv (mk_comb(term1,nvar))) (beta_conv (mk_comb(term2,nvar))) rewrites)))
                end else raise Match

(* itdeep f n: iterative deepening, trying f hit n, f hit (n + 1), ...       *)
(* If an attempt fails without reaching the depth bound, every deeper        *)
(* attempt would repeat exactly the same search, so a Standard polyExn is    *)
(* raised instead of looping forever.                                        *)
fun itdeep f n =
let     val hit = ref false
in
        f hit n handle Match =>
                if !hit then itdeep f (n + 1)
                else raise polyExn (Standard,["LEQ_REWRITES"],
                        "Unable to prove the terms equal using the rewrites supplied")
end
in
(* LEQ_REWRITES term1 term2 rwrs                                             *)
(*     Proves |- term1 = term2.  Both terms (and the rewrites) are first     *)
(*     normalised by rewriting with FUN_EQ_THM and beta-reducing (REFL when  *)
(*     that fails); LEQSTEP then searches with increasing depth.             *)
fun LEQ_REWRITES term1 term2 rwrs =
let     val thm1 = (PURE_REWRITE_CONV [FUN_EQ_THM] THENC DEPTH_CONV BETA_CONV) term1 handle e => REFL term1
        val thm2 = (PURE_REWRITE_CONV [FUN_EQ_THM] THENC DEPTH_CONV BETA_CONV) term2 handle e => REFL term2
        val rewrites = map (BETA_RULE o PURE_REWRITE_RULE [FUN_EQ_THM]) rwrs
in
        TRANS (TRANS thm1 (itdeep (fn hit => fn n =>
                        LEQSTEP hit n (rhs (concl thm1)) (rhs (concl thm2)) rewrites) 0))
                (GSYM thm2)
end
end;
1642
1643(*****************************************************************************)
1644(* prove_induction_recursion_thms:                                           *)
1645(*              translation_scheme -> term -> thm * (term * term) list * thm *)
1646(*                                                                           *)
1647(*     Given a function definition term with some clauses that are simply:   *)
1648(*     fn x = A x, with no recursive calls, rewrites the other clauses with  *)
1649(*     the non-recursive ones, if necessary proves the existence of the      *)
1650(*     mutual recursion theorem. Once complete, it rewrites the clauses back,*)
1651(*     and adds a proof of their existence to the overall proof.             *)
1652(*                                                                           *)
1653(*     Also returns a theorem of mutual induction, using (!x. P0 x ==> P1 x) *)
1654(*     when no recursion takes place, and provides a mapping from            *)
1655(*     predicates to clauses.                                                *)
1656(*                                                                           *)
1657(*****************************************************************************)
1658
1659local
(* Builds a Debug-level exception tagged with this function's name.         *)
fun debug_exn message = mkDebugExn "prove_induction_recursion_thms" message;

(* Exception passed to with_exn when the definition term (or one of its     *)
(* clauses) does not have the shape described in the message.               *)
val fun_exn = debug_exn (String.concat
        [       "Term supplied is not of the form: \n",
                "   |- ... (fni f0..fn x = \n",
                "               if isP x then fi x (decode (L x)) (decode (R x)) else ci)\n",
                "      ... (fnj f0..fn x = A (fn0 x) ... (fnm x))\n"]);

(* Exception passed to with_exn when the mutual induction theorem handed to *)
(* add_redundant_ind does not have the shape described in the message.      *)
val ind_mutual_exn = debug_exn (String.concat
        [       "Returned induction theorem is not of the form: \n",
                "   |- !P0 .. Pn.\n",
                "        ... (!x. isP x /\\ P0 (L x) ... /\\ Pn (R x) ==> Pi x) /\\ \n",
                "        ... (!x. ~isP x ==> Pn x) ==> \n",
                "        (!x. P0 x) ... !x. Pn x\n"]);

(* Re-raises e with this function's name added to its trace.                *)
fun wrap e = wrapException "prove_induction_recursion_thms" e;
1674
(* fix_nr_term tm                                                            *)
(*     tm is a clause  !vs. f a1 .. an = body.  The clause is re-quantified  *)
(*     over exactly the arguments a1 .. an of its left-hand side.  The       *)
(*     arguments are then abstracted out of the equation's right-hand side   *)
(*     (UNBETA_LIST_CONV), and the result is folded with GSYM FUN_EQ_THM.    *)
(*     Presumably this yields |- (!a1..an. f a1..an = body) = (f = ...);     *)
(*     exists_nr_term relies on the right-hand side being an equation on f.  *)
(*     Failures of the conversion are wrapped with this function's name.     *)
fun fix_nr_term tm =
let     val (_, body) = strip_forall tm
        val args = with_exn (snd o strip_comb o lhs) body fun_exn
in
        (let    val conv = STRIP_QUANT_CONV (RAND_CONV (UNBETA_LIST_CONV args))
                                THENC PURE_REWRITE_CONV [GSYM FUN_EQ_THM]
         in     conv (list_mk_forall (args, body))
         end)
        handle e => wrapException "prove_induction_recursion_thms (fix_nr_term)" e
end
1683
(* Applies conv to term.  If conv signals UNCHANGED, returns the reflexive  *)
(* theorem |- term = term instead of raising.                               *)
fun UCONV conv term = (conv term) handle UNCHANGED => REFL term
1685
(* exists_nr_term thm                                                        *)
(*     thm has a right-hand side of the form  f = defn.  Proves              *)
(*     defn = defn |- ?f. f = defn  by EXISTS with witness defn; more        *)
(*     precisely, it proves the existential on the reflexive theorem.        *)
(*     Raises a Debug exception (unwrapped) if f occurs free in defn, ie.    *)
(*     the clause refers to itself.  Failures of EXISTS are wrapped.         *)
fun exists_nr_term thm =
let     val eqn = (rhs o concl) thm
        val fvar = lhs eqn
        val defn = rhs eqn
in
        if free_in fvar defn
        then raise (mkDebugExn "prove_induction_recursion_thms (exists_nr_term)"
                        ("Direct call term: " ^ term_to_string eqn ^
                         "\n directly refers to itself!"))
        else (EXISTS (mk_exists (fvar, eqn), defn) (REFL defn)
                handle e => wrapException "prove_induction_recursion_thms (exists_nr_term)" e)
end
1697
(* Re-raises e with "prove_induction_recursion_thms (add_redundant_ind)"    *)
(* added to its trace.                                                      *)
fun wrapari e = wrapException "prove_induction_recursion_thms (add_redundant_ind)" e
(*---------------------------------------------------------------------------*)
(* add_redundant_ind clauses scheme ind_opt                                  *)
(*     Returns (induction theorem, mapping), where mapping pairs each        *)
(*     predicate variable Pi (one per clause, numbered in clause order)      *)
(*     with the function defined by clause i.                                *)
(*                                                                           *)
(*     NONE (no recursive clauses): for each clause i proves !x. Pi x by     *)
(*     cases on isP x from the assumptions !x. isP x ==> Pi x and            *)
(*     !x. ~isP x ==> Pi x, then discharges every assumption and             *)
(*     generalises over the predicates.                                      *)
(*                                                                           *)
(*     SOME ind (mutual induction for the recursive clauses): specialises    *)
(*     ind at the predicates of the non-single-constructor clauses.  For     *)
(*     each single-constructor clause i it assumes                           *)
(*         !x. Pj x /\ ... ==> Pi x                                          *)
(*     where the Pj belong to the functions occurring free in the clause's   *)
(*     right-hand side.  These assumptions (and their transitive             *)
(*     consequences, via TC_THMS) are used to rewrite the induction          *)
(*     hypotheses.  Finally all hypotheses are discharged in the order       *)
(*     given by order, and the result is generalised over every Pi.          *)
(*---------------------------------------------------------------------------*)
fun add_redundant_ind clauses (scheme:translation_scheme) NONE =
let     val target_type = #target scheme
        val var = mk_var("x",target_type)
        val mkf = curry mk_forall var
        (* (Pi x, clause i) for each clause, with Pi : target_type -> bool *)
        val l1 = map (fn (n,c) => (mk_comb(mk_var("P" ^ (int_to_string n),target_type --> ``:bool``),var),c))
                (enumerate 0 clauses)
        val isP = beta_conv(mk_comb(#predicate scheme,var))
        (* l2: [!x. isP x ==> Pi x, isP x] |- Pi x                           *)
        (* l3: [!x. ~isP x ==> Pi x, ~isP x] |- Pi x                         *)
        val l2 = map (UNDISCH_ONLY o SPEC var o ASSUME o mkf o curry mk_imp isP o fst) l1
        val l3 = map (UNDISCH_ONLY o SPEC var o ASSUME o mkf o curry mk_imp (mk_neg isP) o fst) l1
        (* (Pi, function defined by clause i) *)
        val mapping = map (rator ## rator o lhs o snd o strip_forall) l1
in
        (* Case split on isP x for each clause, discharge, generalise. *)
        (GENL (map fst mapping) (PURE_REWRITE_RULE [AND_IMP_INTRO]
                (DISCH_ALL (LIST_CONJ (map2 (fn x => (GEN var o DISJ_CASES (SPEC isP EXCLUDED_MIDDLE) x)) l2 l3)))),
        mapping)
        handle e => wrapari e
end
  | add_redundant_ind clauses scheme (SOME ind) =
let     (* NOTE(review): rec_thm is bound but never used below. *)
        val rec_thm = #recursion scheme
        (* Type of the predicate variables: that of ind's first quantifier. *)
        val Ptype = with_exn (type_of o fst o dest_forall o concl) ind ind_mutual_exn
        (* For each clause, whether it is single-constructor (non-recursive). *)
        val islist = map (is_single_constructor scheme) clauses handle e => wrapari e
        (* (is_single_i, Pi) for each clause *)
        val termsL = map (fn (n,b) => (b,mk_var("P" ^ (int_to_string n),Ptype))) (enumerate 0 islist)
        (* (Pi, function defined by clause i) *)
        val mapping = map2 (curry (snd ## rator o lhs o snd o strip_forall)) termsL clauses

        (* ind instantiated at the predicates of the recursive clauses. *)
        val ind1 = with_exn (SPECL (map snd (filter (not o fst) termsL))) ind ind_mutual_exn;
        (* ((is_single_i, Pi), (function_i, rhs_i)) for each clause *)
        val zipped = with_exn (zip termsL o map ((repeat rator ## I) o dest_eq o snd o strip_forall)) clauses fun_exn

        val x = mk_var("x",fst (dom_rng Ptype));

        (* Extra terms for inclusion from single constructed terms, ie: !x. P0 x ==> P1 x *)
        (* NOTE(review): list_mk_conj raises on an empty list, so a single-     *)
        (* constructor clause whose right-hand side mentions none of the        *)
        (* defined functions makes this fail (wrapped by wrapari); confirm      *)
        (* whether such clauses can reach here.                                 *)
        val extra_terms =
                foldr (fn (((single,pt),(_,right)),l) =>
                        if single then
                                mk_forall(x,mk_imp(list_mk_conj(map (fn p => mk_comb((snd o fst) p,x))
                                        (filter (C free_in right o fst o snd) zipped)),mk_comb(pt,x)))::l
                        else l) [] zipped handle e => wrapari e

        (* Extra theorems, ie: [!x. P0 x ==> P1 x,!x. P1 x ==> P2 x] |- !x. P0 x ==> P2x *)
        val extra_thms = TC_THMS (map ASSUME extra_terms)

        (* NOTE(review): all_clauses is bound but never used below. *)
        val all_clauses = (strip_conj o fst o dest_imp_only o concl) ind1
        val all_fns = map (fst o snd) zipped

        (* Given theorems of the form: [..] |- Pi (f x) ==> Pj (f x) and a clause, *)
        (* replaces the term Pj (f x) with Pi (f x) in the induction theorem.      *)
        (* Raises Empty when there is nothing to replace.                          *)
        fun fix_thms [] clause induction = raise Empty
          | fix_thms thms clause induction =
        let     val (ante,conc) = (dest_imp_only o snd o strip_forall) clause
                val terms = strip_conj ante
                val var = rand conc

                val thms' = map (fn t => tryfind_e Empty (C (PART_MATCH (fst o dest_imp_only)) t) thms handle Empty =>
                                DISCH_ALL (ASSUME t)) terms
                val final = foldr (fn (a,t) => MATCH_MP MONO_AND (CONJ a t)) (last thms') (butlast thms')
                val rthm = (GEN_ALL (IMP_TRANS final (SPEC var
                                (ASSUME (mk_forall(var,mk_imp(snd (dest_imp_only (concl final)),conc)))))))
        in
                (* If the clause would only be replaced by itself, report no change. *)
                if hyp rthm = [concl rthm] then raise Empty else PROVE_HYP_CHECK rthm induction
        end     handle Empty => raise Empty | e => wrapException
                        "prove_induction_recursion_thms (add_redundant_ind (fix_thms))" e

        (* Finds all terms in the clause such that there is a function call:          *)
        (* f (lr x) where f corresponds to Pi but the induction contains the          *)
        (* predicate Pj (lr x), and uses the theorem [..] |- Pi (lr x) ==> Pj (lr x)  *)
        (* to replace it.                                                             *)
        fun check_replace clause induction =
        let     val pred = guarenteed (rator o snd o dest_imp_only o snd o strip_forall) clause
                val ((single,_),(name,body)) = guarenteed (first (curry op= pred o snd o fst)) zipped
                (* Calls of defined functions in the body, keeping only innermost ones. *)
                val vars = find_terms (C mem all_fns o repeat rator) body
                val vars_fixed = filter (fn t => not (exists (fn t' => not (t = t') andalso free_in t t') vars)) vars
                val rv = mapfilter (fn vf => SPEC (rand vf) (
                        first (curry op= (snd (fst (first (curry op= (repeat rator vf) o fst o snd) zipped)))
                                o rator o snd o dest_imp_only o snd o strip_forall o concl) extra_thms)) vars_fixed
        in
                fix_thms rv clause induction
        end     handle Empty => raise Empty | e => wrapException
                        "prove_induction_recursion_thms (add_redundant_ind (check_replace))" e

        (* Applies f to the first element of the list for which it does not raise Empty. *)
        fun tf f [] = raise Empty
          | tf f (x::xs) = (f x) handle Empty => tf f xs

        (* Repeats check_replace on the hypotheses until no hypothesis changes. *)
        fun replace_all induction =
                replace_all (tf (C check_replace induction) (hyp induction)) handle Empty => induction

        val induction = replace_all (CONV_HYP (PURE_REWRITE_CONV [AND_IMP_INTRO,GSYM CONJ_ASSOC])
                                (UNDISCH_ALL_ONLY (PURE_REWRITE_RULE [GSYM AND_IMP_INTRO] ind1)))

        (* Repeatedly resolves the implications in list (the extra assumptions,      *)
        (* specialised at x) against the theorems already proved, generalising and   *)
        (* adding every non-implication; raises a Debug exception when no progress   *)
        (* can be made.                                                              *)
        fun assoc_list i_thms [] = i_thms
          | assoc_list i_thms list =
        let     val (imps,not_imps) = partition (is_imp_only o concl) list
                val (mped,not_mped) = mappartition (fn th => tryfind (MP th o guarenteed SPEC x) i_thms) imps
        in
                if null mped andalso not (null imps) andalso null not_imps then
                        raise (mkDebugExn "add_redundant_ind"
                                ("Extra terms cannot be resolved, no theorem in the set:\n  " ^
                                xlist_to_string thm_to_string not_mped ^
                                "\ncan be resolved by a conclusion in the set:\n  " ^
                                xlist_to_string thm_to_string i_thms))
                else assoc_list (map (GEN x) not_imps @ i_thms) (mped @ not_mped)
        end

        val all_thms = LIST_CONJ (assoc_list (CONJUNCTS induction)
                        (map (PURE_REWRITE_RULE [GSYM AND_IMP_INTRO] o SPEC x o ASSUME) extra_terms))

        (* Hypothesis ordering: those whose antecedent has no negated conjunct *)
        (* first, then by the index of the predicate they conclude.           *)
        val cneg = exists is_neg o strip_conj o fst o dest_imp_only  o snd o strip_forall
        val p = rator o snd o dest_imp_only o snd o strip_forall
        fun order h1 h2 =
                (not (cneg h1) andalso (cneg h2)) orelse
                (cneg h1 = cneg h2) andalso
                fst (valOf (assoc2 (p h1) (enumerate 0 (map snd termsL)))) <
                fst (valOf (assoc2 (p h2) (enumerate 0 (map snd termsL))))
in
        (* The first hypothesis in the sorted order becomes the outermost antecedent. *)
        (GENL (map snd termsL) (PURE_REWRITE_RULE [AND_IMP_INTRO,GSYM CONJ_ASSOC]
                (foldr (uncurry DISCH) all_thms (sort order (hyp all_thms)))),
        mapping)
end
1814
1815in
(* prove_induction_recursion_thms scheme term                                *)
(*     term is a conjunction of clauses defining functions f0 .. fn.         *)
(*     Returns (induction, mapping, existence):                              *)
(*       induction - mutual induction theorem built by add_redundant_ind     *)
(*       mapping   - pairs of (predicate variable, function it covers)       *)
(*       existence - |- ?f0 .. fn. clauses, with the witness hypotheses      *)
(*                   discharged by remove_witnesses                          *)
(*     Errors are wrapped with this function's name (see wrap).              *)
fun prove_induction_recursion_thms (scheme:translation_scheme) term =
let     val _ = type_trace 2 "->prove_induction_recursion_thms\n"

        (* The clauses, and the distinct functions they define. *)
        val clauses = with_exn strip_conj term fun_exn
        val f_terms = with_exn (mk_set o map (repeat rator o lhs o snd o strip_forall)) clauses fun_exn
        val (non_rec_terms,rec_terms) = partition (is_single_constructor scheme) clauses handle e => wrap e

        (* Rewrite the recursive clauses with the (assumed) non-recursive ones. *)
        val non_rec_terms' = map fix_nr_term non_rec_terms handle e => wrap e
        val nr_rewrites = with_exn (map (ASSUME o rhs o concl)) non_rec_terms' fun_exn
        val rec_terms' = map (UCONV (   REDEPTH_CONV (FIRST_CONV (map REWR_CONV nr_rewrites))
                                        THENC DEPTH_CONV BETA_CONV)) rec_terms;

        (* Existence theorems: trivial ones for the non-recursive clauses,   *)
        (* plus, if any recursive clauses remain, the mutual recursion       *)
        (* theorem from prove_recind_thms_mutual.                            *)
        val nr_terms_exist = map exists_nr_term non_rec_terms'
        val (all_exists,induction1) = case (map (rhs o concl) rec_terms')
                of [] => (nr_terms_exist,NONE)
                |  L  => (C cons nr_terms_exist ## SOME) (prove_recind_thms_mutual scheme (list_mk_conj L))
                handle e => wrap e
        val (induction,mapping) = add_redundant_ind clauses scheme induction1 handle e => wrap e

        (* For each function, in f_terms order, the conjunct of the assumed  *)
        (* existence bodies whose left-hand side is headed by that function. *)
        val all_clauses =
                map (fn f_term =>
                        first_e (mkDebugExn     "prove_induction_recursion_thms"
                                                "Function term missing from existence proof")
                                (curry op= f_term o repeat rator o lhs o snd o strip_forall o concl)
                                (flatten (map (CONJUNCTS o ASSUME o snd o strip_exists o concl) all_exists))) f_terms;

        (* Rewrite each assumed conjunct back into the original clause. *)
        val thm1 = LIST_CONJ (map2 (fn c => fn a => CONV_RULE (REWR_CONV (GSYM (LEQ_REWRITES c (concl a) nr_rewrites))) a) clauses all_clauses)

        val thm2 = foldr (fn (v,thm) => SIMPLE_EXISTS v thm) thm1 f_terms handle e => wrap e

        (* Functions defined by the conjuncts of a term. *)
        val lset = map (repeat rator o lhs o snd o strip_forall) o strip_conj;
        (* Discharges the hypotheses of thm with CHOOSE_L, using the matching *)
        (* existence theorem from list.  Each time it picks a hypothesis      *)
        (* whose defined functions occur free in no other remaining one.      *)
        fun remove_witnesses thm list [] = thm
          | remove_witnesses thm list hs =
        let     val h = first_e
                        (mkDebugExn "prove_induction_recursion_thms"
                                ("Hypothesis set:\n  " ^ xlist_to_string term_to_string hs ^
                                 "\ncontains circular dependencies!"))
                        (fn h => all (fn h' => h = h' orelse not (exists (C free_in h') (lset h))) hs) hs
                val match = first_e
                        (mkDebugExn "prove_induction_recursion_thms"
                                ("Could not find a match for hypothesis:\n  " ^ term_to_string h ^
                                 "\nin the witness set: " ^ xlist_to_string thm_to_string list))
                         (curry op= h o snd o strip_exists o concl) list
        in
                remove_witnesses (CHOOSE_L ((fst o strip_exists o concl) match,match) thm) list
                        (set_diff hs [h])
        end
in
        (induction,mapping,remove_witnesses thm2 all_exists (hyp thm2))
        handle e => wrap e
end
1883end;
1884
1885(*****************************************************************************)
1886(* The data structure holding the type translations                          *)
1887(*                                                                           *)
1888(* Note: The precise versions of functions will match the exact type given,  *)
1889(*       whilst the imprecise versions will return a precise version if it   *)
1890(*       exists, but the imprecise version if it does not.                   *)
1891(*                                                                           *)
1892(* exists_translation[_precise] : hol_type -> hol_type -> bool               *)
1893(* add_translation[_precise]    : hol_type -> hol_type -> unit               *)
1894(* get_translation[_precise]    : hol_type -> hol_type ->                    *)
1895(*                                                (string,function) dict ref *)
1896(* get_theorems[_precise]       : hol_type -> hol_type ->                    *)
1897(*                                                (string, thm) dict ref     *)
(*     Tests for the existence of a translation, creates a new translation   *)
(*     and returns dictionaries of translating functions or theorems         *)
1900(*                                                                           *)
1901(* exists_coding_function[_precise] : hol_type -> hol_type -> string -> bool *)
1902(* add_coding_function : hol_type -> hol_type -> string -> function -> unit  *)
1903(* get_coding_function[_precise]_def : hol_type -> hol_type -> string -> thm *)
1904(* get_coding_function[_precise]_const :                                     *)
1905(*                                    hol_type -> hol_type -> string -> term *)
1906(* get_coding_function[_precise]_induction :  hol_type -> hol_type ->        *)
1907(*                           string -> thm * (term * (term * hol_type)) list *)
1908(*     Tests for the existence of a coding function, adds a new coding       *)
1909(*     function and returns a function's definition, constant and principle  *)
1910(*     of induction.                                                         *)
1911(*                                                                           *)
1912(* exists_coding_theorem[_precise] : hol_type -> hol_type -> string -> bool  *)
1913(* add_coding_theorem[_precise] : hol_type -> hol_type ->                    *)
1914(*                                                     string -> thm -> unit *)
1915(* get_coding_theorem[_precise] : hol_type -> hol_type -> string -> thm      *)
1916(*     Tests for the existence of a coding theorem, adds a new coding        *)
1917(*     theorem and returns a coding theorem.                                 *)
1918(*                                                                           *)
1919(* exists_source_function[_precise] : hol_type -> string -> bool             *)
1920(* add_source_function : hol_type -> string -> function -> unit              *)
1921(* get_source_function[_precise]_def : hol_type -> string -> thm             *)
1922(* get_source_function[_precise]_const : hol_type -> string -> term          *)
1923(* get_source_function[_precise]_induction :  hol_type -> string ->          *)
1924(*                                     thm * (term * (term * hol_type)) list *)
1925(*     Tests for the existence of a source function, adds a new source       *)
1926(*     function and returns a function's definition, constant and principle  *)
1927(*     of induction.                                                         *)
1928(*                                                                           *)
1929(* exists_source_theorem[_precise] : hol_type -> string -> bool              *)
1930(* add_source_theorem[_precise] : hol_type -> string -> thm -> unit          *)
1931(* get_source_theorem[_precise] : hol_type -> string -> thm                  *)
1932(*     Tests for the existence of a source theorem, adds a new source        *)
1933(*     theorem and returns a source theorem.                                 *)
1934(*                                                                           *)
1935(* add_translation_scheme : hol_type -> thm -> thm -> unit                   *)
(*     Given theorems of the form:                                           *)
1937(*         |- P a ==> measure (L a) < measure a /\ measure (R a) < measure a *)
1938(*         |- P nil = F                                                      *)
1939(*     adds a translation scheme creating theorems of recursion and          *)
1940(*     induction from the theorem.                                           *)
1941(*                                                                           *)
1942(*****************************************************************************)
1943
(* Key ordering on HOL types used by every database below.                  *)
val type_less = Type.compare;

(* Coding database, keyed by target type.                                   *)
val codingBase : translations ref = ref (mkDict type_less);
(* Source database: (functions, theorems) dictionaries keyed by type.        *)
val sourceBase : functions ref * theorems ref =
        (ref (mkDict type_less), ref (mkDict type_less));

(* Empties the coding database.                                             *)
fun clearCoding () = codingBase := mkDict type_less;
(* Empties both halves of the source database (functions first).          *)
fun clearSource () =
let     val (source_functions, source_theorems) = sourceBase
in      source_functions := mkDict type_less ;
        source_theorems := mkDict type_less
end;
1951
1952local
(* Standard exception: no translation for type t is stored under target.    *)
fun translation_not_found target t =
        mkStandardExn "get_translation"
                ("The translation " ^ type_to_string (base_type target) ^ " --> " ^
                 type_to_string t ^ " was not found in the database");
(* Standard exception: no translation scheme is registered for ty.          *)
fun translation_scheme_not_found ty =
        mkStandardExn "get_translation_scheme"
                ("There is no translation scheme for type " ^ type_to_string ty);
(* The ((functions, theorems), scheme) entry registered for target.         *)
fun get_translations target =
        case Binarymap.peek (!codingBase, target)
        of SOME entry => entry
        |  NONE => raise (translation_scheme_not_found target)
(* base_type ty, or ty itself if base_type fails.                           *)
fun vbase_type ty = (base_type ty) handle _ => ty
1962in
(* The translation scheme registered for target.                            *)
fun get_translation_scheme target =
let     val (_, scheme) = get_translations target
in      scheme
end
(* True iff function dictionaries exist for exactly cannon_type t.          *)
fun exists_translation_precise target t =
let     val ((functions, _), _) = get_translations target
in      Option.isSome (Binarymap.peek (!functions, cannon_type t))
end
(* As exists_translation_precise, applied to vbase_type t.                  *)
fun exists_translation target t =
        exists_translation_precise target (vbase_type t)
(* add_translation target t                                                 *)
(*     Registers empty function and theorem dictionaries for vbase_type t   *)
(*     under target.  Raises a standard exception if exists_translation     *)
(*     target t already holds.  The entries are keyed by                    *)
(*     cannon_type (vbase_type t), the same key that exists_translation,    *)
(*     get_translation and get_theorems look up.  Keying by the raw         *)
(*     vbase_type t could store an entry those functions never find.        *)
fun add_translation target t =
        if exists_translation target t then
                raise (mkStandardExn "add_translation"
                        ("The translation " ^ type_to_string target ^ " --> " ^ type_to_string t ^
                         " already exists."))
        else    let     val ((fbase,tbase),_) = get_translations target
                        val key = cannon_type (vbase_type t)
                in      (fbase := Binarymap.insert(!fbase,key,
                                        ref (mkDict String.compare : (string,function) dict)) ;
                         tbase := Binarymap.insert(!tbase,key,
                                        ref (mkDict String.compare : (string,thm) dict)))
                end
(* Registers empty function and theorem dictionaries for exactly            *)
(* cannon_type t under target; raises a standard exception if they exist.   *)
fun add_translation_precise target t =
        if not (exists_translation_precise target t)
        then
        let     val ((functions_ref,theorems_ref),_) = get_translations target
                val key = cannon_type t
        in
                functions_ref := Binarymap.insert (!functions_ref, key,
                                ref (mkDict String.compare : (string,function) dict)) ;
                theorems_ref := Binarymap.insert (!theorems_ref, key,
                                ref (mkDict String.compare : (string,thm) dict))
        end
        else raise (mkStandardExn "add_translation"
                        ("The translation " ^ type_to_string target ^ " --> " ^ type_to_string t ^
                         " already exists."))
(* The function dictionary for exactly cannon_type t under target.          *)
fun get_translation_precise target t =
let     val ((functions, _), _) = get_translations target
in      case Binarymap.peek (!functions, cannon_type t)
        of SOME dict => dict
        |  NONE => raise (translation_not_found target t)
end
(* As get_translation_precise, applied to vbase_type t.                     *)
fun get_translation target t = get_translation_precise target (vbase_type t)
(* The theorem dictionary for exactly cannon_type t under target.           *)
fun get_theorems_precise target t =
let     val ((_, theorems), _) = get_translations target
in      case Binarymap.peek (!theorems, cannon_type t)
        of SOME dict => dict
        |  NONE => raise (translation_not_found target t)
end
(* As get_theorems_precise, applied to vbase_type t.                        *)
fun get_theorems target t = get_theorems_precise target (vbase_type t)
(* Every type with function dictionaries stored under target.               *)
fun get_translation_types target =
let     val ((functions, _), _) = get_translations target
in      map (fn (ty, _) => ty) (Binarymap.listItems (!functions))
end;
2003end;
2004
2005(*****************************************************************************)
2006(* This function performs a breadth-first search, finding the most precise   *)
2007(* type with an entry in the database.                                       *)
2008(*****************************************************************************)
2009
(* all_lists [l1, ..., ln] returns every list [x1, ..., xn] with each xi    *)
(* drawn from li, ordered by the position of x1 in l1, then x2 in l2, ...  *)
(* all_lists [] = [[]]; if any li is empty the result is [].               *)
fun all_lists [] = [[]]
  | all_lists (x::xs) =
    let (* The combinations of the tail do not depend on the head element, *)
        (* so compute them once instead of once per element of x.          *)
        val rest = all_lists xs
    in  foldr (fn (a,t) => map (fn l => a :: l) rest @ t) [] x
    end;
2013
(* explode_type t lists generalisations of t.  A type variable gives a      *)
(* single fresh type variable.  A compound type gives a fresh type          *)
(* variable followed by its operator applied to every combination           *)
(* (all_lists) of the generalisations of its arguments.                     *)
fun explode_type t =
    if is_vartype t then [gen_tyvar()]
    else
    let val fresh = gen_tyvar()
        val (tyop, args) = dest_type t
        val instances = all_lists (map explode_type args)
    in  fresh :: map (fn targs => mk_type (tyop, targs)) instances
    end;

(* The canonical forms of explode_type t, in reverse order.                 *)
fun ordered_list t =
let val candidates = rev (explode_type t)
in  map cannon_type candidates
end;

(* The first type in ordered_list t satisfying exists_function; raises a    *)
(* standard exception if none does.                                         *)
fun most_precise_type exists_function t =
let val none_found = mkStandardExn "most_precise_type"
                        "No sub-type exists such that the function given holds"
in  first_e none_found exists_function (ordered_list t)
end;
2029
2030(*-- functions --*)
2031
(* True iff function name is stored for exactly (the canonical form of) t   *)
(* under target.                                                            *)
fun exists_coding_function_precise target t name =
        exists_translation_precise target t andalso
        Option.isSome (Binarymap.peek (!(get_translation_precise target t), name));

(* True iff name is stored for some type in ordered_list t.                 *)
fun exists_coding_function target t name =
    can (most_precise_type
            (fn ty => exists_coding_function_precise target ty name)) t
2043
(* Instantiates every component of a stored function record with           *)
(* match_type (cannon_type t) t: the constant, the definition and, if       *)
(* present, the induction theorem and its (term, (term, type)) mapping.     *)
fun inst_function {const,definition,induction} t =
let     val sub = match_type (cannon_type t) t
        val inst_term = safe_inst sub
in
        {const = inst_term const,
         definition = SAFE_INST_TYPE sub definition,
         induction =
                Option.map (SAFE_INST_TYPE sub ##
                            map (inst_term ## (inst_term ## safe_type_subst sub))) induction}
end
2053
(* The coding function name stored for exactly (canonical) t under target,  *)
(* instantiated to t; raises a standard exception if it is absent.          *)
fun get_coding_function_precise target t name =
let     val functions = get_translation_precise target t
in
        case Binarymap.peek (!functions, name)
        of SOME function => inst_function function t
        |  NONE => raise (mkStandardExn "get_coding_function_precise"
                        ("The function " ^ name ^
                         " was not found for the translation " ^
                         type_to_string target ^ " --> " ^ type_to_string t))
end;

(* As get_coding_function_precise, for the first type in ordered_list t     *)
(* that has the function; any failure becomes a standard exception.         *)
fun get_coding_function target t name =
    (let val stored_type =
                 most_precise_type (fn ty => exists_coding_function_precise target ty name) t
     in  inst_function (get_coding_function_precise target stored_type name) t
     end)
    handle _ => raise (mkStandardExn "get_coding_function"
                ("The function " ^ name ^
                 " was not found for the translation " ^
                 type_to_string target ^ " --> " ^ type_to_string t))
2070
local
        (* Unwraps an optional induction principle; on NONE raises a       *)
        (* standard exception from origin naming the function and types.   *)
        fun coding_induction origin target t name NONE =
                raise (mkStandardExn origin
                        ("The function " ^ name ^ "(" ^ type_to_string t ^
                         " --> " ^ type_to_string target ^
                         ") does not have an induction principle defined for it."))
          | coding_induction _ _ _ _ (SOME principle) = principle
in
(* Definition, constant and induction principle of coding function name,    *)
(* via get_coding_function or, for the _precise variants,                   *)
(* get_coding_function_precise.                                             *)
fun get_coding_function_def target t name =
        #definition (get_coding_function target t name)
fun get_coding_function_const target t name =
        #const (get_coding_function target t name)
fun get_coding_function_induction target t name =
        coding_induction "get_coding_function_induction" target t name
                (#induction (get_coding_function target t name))
fun get_coding_function_precise_def target t name =
        #definition (get_coding_function_precise target t name)
fun get_coding_function_precise_const target t name =
        #const (get_coding_function_precise target t name)
fun get_coding_function_precise_induction target t name =
        coding_induction "get_coding_function_precise_induction" target t name
                (#induction (get_coding_function_precise target t name))
end
2093
(* Stores a coding function under name for exactly (canonical) t.  Its      *)
(* components are instantiated with match_type t (cannon_type t) before     *)
(* being stored.  The translation must already exist; otherwise the lookup  *)
(* error is wrapped with this function's name.                              *)
fun add_coding_function_precise target t name {const,definition,induction} =
let     val canonical = cannon_type t
        val _ = type_trace 1 ("Adding coding function, " ^ name ^ ", for type: "
                                ^ type_to_string canonical ^ "\n")
        val base = get_translation_precise target t
                        handle e => wrapException "add_coding_function_precise" e
        val sub = match_type t canonical
        val inst_term = safe_inst sub
        val stored = {const = inst sub const,
                      definition = SAFE_INST_TYPE sub definition,
                      induction = Option.map (SAFE_INST_TYPE sub ##
                                        map (inst_term ## (inst_term ##
                                                safe_type_subst sub))) induction}
in
        base := Binarymap.insert (!base, name, stored)
end;

(* As add_coding_function_precise, for base_type t (or t if that fails).    *)
fun add_coding_function target t name function =
let     val key_type = base_type t handle _ => t
in      add_coding_function_precise target key_type name function
end
2107
2108
2109(*-- theorems --*)
2110
(* True iff theorem name is stored for exactly (canonical) t under target.  *)
fun exists_coding_theorem_precise target t name =
        exists_translation_precise target t andalso
        Option.isSome (Binarymap.peek (!(get_theorems_precise target t), name));
(* True iff theorem name is stored for t itself or for base_type t.         *)
fun exists_coding_theorem target t name =
let     fun stored ty = exists_coding_theorem_precise target ty name
in      stored t orelse stored (base_type t handle _ => t)
end
2120
(* Stores thm under name for exactly (canonical) t, instantiated with       *)
(* match_type t (cannon_type t).  Creates the translation first if needed.  *)
fun add_coding_theorem_precise target t name thm =
let     val _ = type_trace 1 ("Adding coding theorem, " ^ name ^ ", for type: " ^
                        (type_to_string (cannon_type t)) ^ "\n")
        val () = if not (exists_translation_precise target t)
                 then add_translation_precise target t
                 else ()
        val base = get_theorems_precise target t
                        handle e => wrapException "add_coding_theorem_precise" e
        val stored = SAFE_INST_TYPE (match_type t (cannon_type t)) thm
in
        base := Binarymap.insert (!base, name, stored)
end;

(* As add_coding_theorem_precise, for base_type t (or t if that fails).     *)
fun add_coding_theorem target t name thm =
let     val key_type = base_type t handle _ => t
in      add_coding_theorem_precise target key_type name thm
end
2131
(* The coding theorem name stored for exactly (canonical) t under target,   *)
(* instantiated with match_type (cannon_type t) t.  Raises a standard       *)
(* exception if it is absent (error message grammar corrected).             *)
fun get_coding_theorem_precise target t name =
        case (Binarymap.peek(!(get_theorems_precise target t),name))
        of NONE => raise (mkStandardExn "get_coding_theorem_precise"
                ("The theorem " ^ name ^ " does not exist for the translation " ^ type_to_string target ^ " --> " ^
                type_to_string t))
        |  SOME x => SAFE_INST_TYPE (match_type (cannon_type t) t) x
2138
(* The coding theorem name for the first type in ordered_list t that has it. *)
fun get_coding_theorem target t name =
let     val stored_type =
                most_precise_type (fn ty => exists_coding_theorem_precise target ty name) t
in      get_coding_theorem_precise target stored_type name
end
2143
2144(*-- source functions and theorems --*)
2145
(* True iff source function name is stored for exactly (canonical) t.      *)
fun exists_source_function_precise t name =
        case Binarymap.peek (!(fst sourceBase), cannon_type t)
        of SOME functions => Option.isSome (Binarymap.peek (!functions, name))
        |  NONE => false;
(* True iff name is stored for some type in ordered_list t.                 *)
fun exists_source_function t name =
    can (most_precise_type (fn ty => exists_source_function_precise ty name)) t
2153
(* The source function name stored for exactly (canonical) t, instantiated  *)
(* to t.  Raises a standard exception (originating from this function in    *)
(* both cases) if the type has no source functions or lacks this one.       *)
fun get_source_function_precise t name =
        case (Binarymap.peek(!(fst sourceBase),cannon_type t))
        of NONE => raise (mkStandardExn "get_source_function_precise"
                        ("No source functions found for type " ^ type_to_string t))
        |  SOME x =>
                case (Binarymap.peek(!x,name))
                of NONE => raise (mkStandardExn "get_source_function_precise"
                                ("The function " ^ name ^ " has not been defined for the type " ^ type_to_string t))
                |  SOME function => inst_function function t
2163
2164(*              {const = safe_inst (match_type (cannon_type t) t) const,
2165                 definition = SAFE_INST_TYPE (match_type (cannon_type t) t) definition,
2166                 induction = Option.map (SAFE_INST_TYPE (match_type (cannon_type t) t) ##
2167                        map (safe_inst (match_type (cannon_type t) t) ## (safe_inst (match_type (cannon_type t) t) ##
2168                                safe_type_subst (match_type (cannon_type t) t)))) induction}*)
(* As get_source_function_precise, for the first type in ordered_list t     *)
(* that has the function; any failure becomes a standard exception.         *)
fun get_source_function t name =
    (let val stored_type =
                 most_precise_type (fn ty => exists_source_function_precise ty name) t
     in  inst_function (get_source_function_precise stored_type name) t
     end)
    handle _ => raise (mkStandardExn "get_source_function"
    ("The function " ^ name ^ " has not been defined for any sub-type of " ^
      type_to_string t))
2175
2176fun get_source_function_def t name =
2177        #definition (get_source_function t name)
2178fun get_source_function_const t name =
2179        #const (get_source_function t name)
(* Induction data stored with the source function called name, as returned
   by get_source_function at type t.
   Raises a standard polyExn if the function has no induction principle. *)
fun get_source_function_induction t name =
let     val {induction, ...} = get_source_function t name
in
        case induction
        of SOME principle => principle
        |  NONE => raise (mkStandardExn "get_source_function_induction"
                ("The function " ^ name ^ "(" ^ type_to_string t ^
                 ") does not have an induction principle defined for it."))
end
(* Definition theorem of the source function called name stored for exactly
   the (canonical) type of t. *)
fun get_source_function_precise_def t name =
let     val {definition, ...} = get_source_function_precise t name
in      definition
end
(* Constant of the source function called name stored for exactly the
   (canonical) type of t. *)
fun get_source_function_precise_const t name =
let     val {const, ...} = get_source_function_precise t name
in      const
end
(* Induction data stored with the source function called name for exactly
   the (canonical) type of t.
   Raises a standard polyExn if the function has no induction principle. *)
fun get_source_function_precise_induction t name =
let     val {induction, ...} = get_source_function_precise t name
in
        case induction
        of SOME principle => principle
        |  NONE => raise (mkStandardExn "get_source_function_precise_induction"
                ("The function " ^ name ^ "(" ^ type_to_string t ^
                 ") does not have an induction principle defined for it."))
end
2196
(* add_source_function_precise t name {const,definition,induction}:
   Stores a source function called name under exactly cannon_type t.
   The per-type dictionary is created on first use.  An existing entry with
   the same name is replaced, because Binarymap.insert overwrites.
   The const, definition and induction data are supplied at type t and are
   re-instantiated to cannon_type t before being stored. *)
fun add_source_function_precise t name {const,definition,induction} =
let     val ctype = cannon_type t
        (* Type substitution taking t onto its canonical form. *)
        val sub = match_type t ctype
        val _  = type_trace 1 ("Adding source function, " ^ name ^ ", for type: " ^
                                        (type_to_string ctype) ^ "\n")
        val base = case Binarymap.peek (!(fst sourceBase), ctype)
                   of SOME existing => existing
                   |  NONE =>
                      let  val fresh = ref (mkDict String.compare)
                      in   (fst sourceBase) := Binarymap.insert (!(fst sourceBase), ctype, fresh) ;
                           fresh
                      end
in
        base := Binarymap.insert (!base, name,
                {const = inst sub const,
                 definition = SAFE_INST_TYPE sub definition,
                 induction = Option.map (SAFE_INST_TYPE sub ##
                                map (safe_inst sub ## (safe_inst sub ##
                                        safe_type_subst sub))) induction})
end
(* Stores a source function under the base type of t, or under t itself when
   base_type fails.  See add_source_function_precise. *)
fun add_source_function t name x =
let     val ty = base_type t handle _ => t
in      add_source_function_precise ty name x
end
2213
(* true iff a source theorem called name is stored under exactly
   cannon_type t. *)
fun exists_source_theorem_precise t name =
        case Binarymap.peek (!(snd sourceBase), cannon_type t)
        of SOME theorems => isSome (Binarymap.peek (!theorems, name))
        |  NONE => false;
2220
(* true iff a source theorem called name is stored for t itself, or else for
   its base type (t when base_type fails). *)
fun exists_source_theorem t name =
        if exists_source_theorem_precise t name
        then true
        else exists_source_theorem_precise (base_type t handle _ => t) name
2224
(* get_source_theorem_precise t name:
   Returns the source theorem called name stored under exactly cannon_type t,
   with its types instantiated from the canonical type back to t.
   Raises a standard polyExn if there is no such theorem. *)
fun get_source_theorem_precise t name =
let     val missing = mkStandardExn "get_source_theorem_precise"
                        ("Theorem: " ^ name ^ " does not exist for type " ^ type_to_string t)
        val ctype = cannon_type t
        val stored = case Binarymap.peek (!(snd sourceBase), ctype)
                     of NONE => NONE
                     |  SOME theorems => Binarymap.peek (!theorems, name)
in
        case stored
        of SOME thm => SAFE_INST_TYPE (match_type ctype t) thm
        |  NONE => raise missing
end
2235
(* Fetches the source theorem called name from the type, chosen by
   most_precise_type starting at t, for which such a theorem is stored. *)
fun get_source_theorem t name =
let     fun stored ty = exists_source_theorem_precise ty name
        val precise = most_precise_type stored t
in
        get_source_theorem_precise precise name
end
2239
2240fun add_source_theorem_precise t name thm =
2241        (type_trace 1 ("Adding source theorem, " ^ name ^ ", for type: " ^
2242                                (type_to_string (cannon_type t)) ^ "\n") ;
2243        (case (Binarymap.peek(!(snd sourceBase),cannon_type t))
2244        of NONE => ((snd sourceBase) := Binarymap.insert(!(snd sourceBase),cannon_type t,ref (mkDict String.compare)))
2245        |  SOME x => ())
2246        ;
2247        let     val base = Binarymap.find(!(snd sourceBase),cannon_type t)
2248        in
2249                base := Binarymap.insert(!base,name,SAFE_INST_TYPE (match_type t (cannon_type t)) thm)
2250        end)
2251
(* Stores a source theorem under the base type of t, or under t itself when
   base_type fails.  See add_source_theorem_precise. *)
fun add_source_theorem t name thm =
let     val ty = base_type t handle _ => t
in      add_source_theorem_precise ty name thm
end;
2253
(* Removes the coding theorem called name from the theorem dictionary that
   get_theorems_precise returns for (target, t).  Failures of that lookup
   are wrapped with this function's name; Binarymap.remove raises NotFound
   if name is absent. *)
fun remove_coding_theorem_precise target t name =
let     val theorems = get_theorems_precise target t
                        handle e => wrapException "remove_coding_theorem_precise" e
        val (remaining, _) = Binarymap.remove (!theorems, name)
in
        theorems := remaining
end;
2259
(* remove_source_theorem_precise t name:
   Deletes the source theorem called name stored under cannon_type t.  The
   key is canonicalised here, as in add/get/exists_source_theorem_precise, so
   that non-canonical types find their entry.
   Does nothing if no theorems are stored for the type.  If the type has
   theorems but none called name, Binarymap.remove raises NotFound. *)
fun remove_source_theorem_precise t name =
        case Binarymap.peek (!(snd sourceBase), cannon_type t)
        of NONE => ()
        |  SOME base => base := fst (Binarymap.remove (!base, name));
2268
2269
2270(*-- the translation base itself --*)
2271
local
(* IMP_THM rewrites an implication using an equation that holds under its
   antecedent; derived from IMP_CONG. *)
val imp1 = CONV_RULE (REWR_CONV (GSYM AND_IMP_INTRO)) (SPEC_ALL IMP_CONG)
val eqt = fst (dest_imp_only (concl imp1))
val IMP_THM = MP (INST [(uncurry (C (curry op|->)) o dest_eq) eqt] imp1) (REFL (lhs eqt))
val size_format =
        "size_thm should be of the form: \n" ^
        "|- P x ==> size (left x) < size x /\\ size (right x) < size x"
fun size_err s = mkStandardExn "add_translation_scheme" (size_format ^ "\nhowever " ^ s)
fun wrap1 e = wrapException "add_translation_scheme (make_recursion)" e
fun wrap2 e = wrapException "add_translation_scheme (make_induction)" e
(* split_size_thm target size_thm:
   Checks that size_thm has the shape described by size_format and returns
   (specced, size, left x, right x, P x, x), where specced is SPEC_ALL
   size_thm.  x must be the only free variable of P x and have type target.
   Every malformed shape is reported through size_err. *)
fun split_size_thm target size_thm =
let     val specced = SPEC_ALL size_thm
        val (p_term,rest) = with_exn (dest_imp_only o concl) specced (size_err "theorem is not an implication")
        val ((l,lmeasure),(r,rmeasure)) = with_exn ((numLib.dest_less ## numLib.dest_less) o dest_conj) rest
                (size_err "result of theorem is not a conjunction of a < b terms")
        val ((ml,left),(mr,right)) = with_exn (dest_comb ## dest_comb) (l,r)
                (size_err "left of < terms are not measures '(size x)'")
        val ((mlr,_),(mrr,_)) = with_exn (dest_comb ## dest_comb) (lmeasure,rmeasure)
                (size_err "right of < terms are not measures '(size x)'")
        val _ = if all (curry op= ml) [mlr,mrr,mr] then () else raise (size_err
                ("measures '" ^ term_to_string ml ^ "' and '" ^
                        term_to_string (first (not o curry op= ml) [mlr,mrr,mr]) ^ "' are not equal"))
        val all_vars = free_vars p_term
        (* Check the variable count BEFORE taking hd, so that an antecedent
           with no free variables gets the descriptive error instead of a
           bare exception from hd. *)
        val _ = if length all_vars = 1 then () else raise (size_err
                ("the antecedent " ^ term_to_string p_term ^ " does not depend on only one variable"))
        val xvar = hd all_vars
        val _ = if type_of xvar = target then () else raise (size_err
                        ("variable in antecedent is of type " ^ type_to_string (type_of xvar) ^
                         " and not " ^ type_to_string target))
in
        (specced,ml,left,right,p_term,xvar)
end
(* make_recursion:
   Builds the recursion theorem for the scheme using WFREC over
   "measure size".  The recursion template is
     \f x. if P x then f0 x (f (left x)) (f (right x)) else c0 x
   and the result is presumably an existence theorem for a function, decode,
   satisfying that equation.
   All failures are wrapped with the make_recursion label. *)
fun make_recursion (specced,ml,left,right,p_term,xvar) target size_thm =
let     val f = mk_var("f",target --> beta) handle e => wrap1 e
        val f0 = mk_var("f0",target --> beta --> beta --> beta)  handle e => wrap1 e
        val template = list_mk_abs([f,xvar],mk_cond(p_term,
                        list_mk_comb(f0,[xvar,mk_comb(f,left),mk_comb(f,right)]),
                                mk_comb(mk_var("c0",target --> beta),xvar))) handle e => wrap1 e
        val measure = mk_comb(mk_const("measure",(target --> num) --> target --> target --> bool),ml)
                handle e => wrap1 e
        val decode_var = mk_var("decode",target --> beta) handle e => wrap1 e
        val rec_term = list_mk_comb(mk_const("WFREC",
                                type_of measure --> type_of template --> target --> beta),
                                [measure,template]) handle e => wrap1 e
        val def = mk_eq(decode_var,rec_term) handle e => wrap1 e
        val th0 = GEN xvar (SPEC xvar (MP (MATCH_MP relationTheory.WFREC_COROLLARY (ASSUME def))
                        (PART_MATCH rand prim_recTheory.WF_measure measure))) handle e => wrap1 e
        val th1 = CONV_RULE (REDEPTH_CONV BETA_CONV)
                        (PURE_REWRITE_RULE [relationTheory.RESTRICT_DEF, prim_recTheory.measure_thm] th0)
                handle e => wrap1 e
        val term = snd (strip_forall (concl th1)) handle e => wrap1 e
        val th2 = (REWR_CONV COND_RAND THENC REWR_CONV COND_EXPAND THENC
                        (NTH_CONJ_CONV 0 (HO_REWR_CONV (GSYM IMP_DISJ_THM)))) term handle e => wrap1 e
        (* Under P x the size theorem discharges the RESTRICT guards. *)
        val half = MATCH_MP IMP_THM (DISCH p_term ((PURE_REWRITE_CONV [ASSUME p_term,UNDISCH_ONLY specced] THENC
                                DEPTH_CONV (REWR_CONV (fst (CONJ_PAIR (SPEC_ALL COND_CLAUSES)))))
                        (snd (dest_imp_only (fst (dest_conj (rhs (concl th2)))))))) handle e => wrap1 e
        val th_l = CONV_RULE (STRIP_QUANT_CONV (REWR_CONV
                        (RIGHT_CONV_RULE (NTH_CONJ_CONV 0 (REWR_CONV half THENC REWR_CONV IMP_DISJ_THM) THENC
                                REWR_CONV (GSYM COND_EXPAND) THENC (REWR_CONV (GSYM COND_RAND))) th2))) th1
                handle e => wrap1 e
        val th_r = EXISTS (Psyntax.mk_exists(decode_var,def),rec_term) (REFL rec_term) handle e => wrap1 e
in
        SPEC_ALL (DISCH_ALL (GEN_ALL (CHOOSE (decode_var,th_r) (SIMPLE_EXISTS decode_var th_l)))) handle e => wrap1 e
end
(* make_induction:
   Proves, by well-founded induction on "measure size",
     |- (!x. P x /\ Q (left x) /\ Q (right x) ==> Q x) /\
        (!x. ~P x ==> Q x) ==> !x. Q x
   where Q is a fresh predicate variable (named "P" in the term).
   All failures are wrapped with the make_induction label. *)
fun make_induction (specced,ml,left,right,p_term,xvar) target size_thm =
let     val pred = mk_var("P",target --> bool) handle e => wrap2 e
        val ante_true = mk_forall(xvar,mk_imp(
                list_mk_conj [p_term,mk_comb(pred,left),mk_comb(pred,right)],mk_comb(pred,xvar))) handle e => wrap2 e
        val ante_false = mk_forall(xvar,mk_imp(mk_neg p_term,mk_comb(pred,xvar))) handle e => wrap2 e
        val measure = mk_comb(mk_const("measure",(target --> num) --> target --> target --> bool),ml)
                handle e => wrap2 e
        val th1 = SPEC_ALL (PURE_REWRITE_RULE [prim_recTheory.measure_thm]
                        (MP     (ISPEC measure relationTheory.WF_INDUCTION_THM)
                                (ISPEC ml prim_recTheory.WF_measure)))
                        handle e => wrap2 e
        val (th_true,th_false) = (CONJ_PAIR (ASSUME (mk_conj(ante_true,ante_false)))) handle e => wrap2 e
        val (wvar,pt1) = dest_forall(fst (dest_imp_only (concl th1))) handle e => wrap2 e
        val proof_term = subst [wvar |-> xvar] (fst (dest_imp_only pt1)) handle e => wrap2 e
        val th_false2 = DISCH proof_term (UNDISCH (SPEC_ALL th_false)) handle e => wrap2 e
        val th_true2 = DISCH proof_term (MP (REWRITE_RULE [AND_IMP_INTRO]
                (UNDISCH (REWRITE_RULE [GSYM AND_IMP_INTRO] (SPEC_ALL th_true))))
                        (CONJ
                                (MP (SPEC left (ASSUME proof_term)) (fst (CONJ_PAIR (UNDISCH specced))))
                                (MP (SPEC right (ASSUME proof_term)) (snd (CONJ_PAIR (UNDISCH specced))))))
                handle e => wrap2 e
in
        GEN_ALL (DISCH (mk_conj(ante_true,ante_false))
                (MP th1 (GEN xvar (DISJ_CASES (SPEC p_term EXCLUDED_MIDDLE) th_true2 th_false2)))) handle e => wrap2 e
end
in
(* add_translation_scheme target size_thm dead_thm:
   Registers target as a translation target in codingBase.
     size_thm : |- P x ==> size (left x) < size x /\ size (right x) < size x
                with x : target the only free variable of P x.
     dead_thm : a theorem whose left-hand side matches P x.  The term
                matched against x is stored as "bottom"; presumably of the
                form |- P bottom = F.
   The stored scheme holds the induction and recursion theorems derived from
   size_thm, plus left, right and P as abstractions over x.  Any previous
   entry for target, including its function and theorem dictionaries, is
   replaced, because Binarymap.insert overwrites.
   Raises a standard polyExn when size_thm does not have the expected form. *)
fun add_translation_scheme target size_thm dead_thm =
let     val (specced,ml,left,right,p_term,xvar) = split_size_thm target size_thm
        val dead_term = hd (map #residue (fst (match_term p_term (lhs (concl dead_thm)))))
in
        codingBase :=
        Binarymap.insert (!codingBase,target,
                ((ref (mkDict type_less) : functions ref,
                  ref (mkDict type_less) : theorems ref),
                {target = target,
                 induction = make_induction (specced,ml,left,right,p_term,xvar) target size_thm,
                 recursion = make_recursion (specced,ml,left,right,p_term,xvar) target size_thm,
                 bottom_thm = dead_thm,
                 bottom = dead_term,
                 left = mk_abs(xvar,left),
                 right = mk_abs(xvar,right),
                 predicate = mk_abs(xvar,p_term)}))
end
end
2379
2380(*****************************************************************************)
2381(* Loop checking for function generators                                     *)
2382(*                                                                           *)
2383(* Simply provides a function 'check_loop' that when giving an identifying   *)
2384(* string for a function can detect whether the function is looping.         *)
2385(*                                                                           *)
(* Looping is defined to occur if, for the same identifying string:          *)
(*     a) Exactly the same type is visited twice, or                         *)
(*     b) Types which are equal up to renaming of type variables are visited *)
(*        twice.                                                             *)
(* Only types still being processed (i.e. on the current call stack) count.  *)
2390(*                                                                           *)
2391(*****************************************************************************)
2392
local
(* One stack per identifying string, holding the canonical types currently
   being processed under that identifier (most recent first). *)
val stores = ref (mkDict String.compare);
(* The stack for s, created empty on first use. *)
fun get s =
        case Binarymap.peek (!stores, s)
        of SOME stack => stack
        |  NONE => let val stack = ref [] in stores := Binarymap.insert (!stores, s, stack) ; stack end
(* Push t onto the stack for s. *)
fun push s t = let val stack = get s in stack := t :: (!stack) end
(* Drop the top of the stack for s; no-op when it is empty. *)
fun pop s = let val stack = get s in case !stack of [] => () | _ :: rest => stack := rest end
(* Is t already on the stack for s? *)
fun on_stack s t = Lib.mem t (!(get s))
in
val cstores = stores
(* check_loop s t f fail:
   If a type with the same cannon_type as t is already being processed under
   identifier s, calls fail with the current stack for s.  Otherwise it
   pushes cannon_type t, runs f t, and pops again, whether f returns or
   raises (the exception is re-raised). *)
fun check_loop s t f fail =
let     val key = cannon_type t
in
        if on_stack s key
        then fail (!(get s))
        else let val result = (push s key ; f t handle e => (pop s ; raise e))
             in  (pop s ; result) end
end
end;
2410
2411(*****************************************************************************)
2412(* Function generators:                                                      *)
2413(*                                                                           *)
2414(* Allow the recursive definition of functions for a whole type              *)
2415(*                                                                           *)
2416(*****************************************************************************)
2417
(* Coding function generators, indexed by translation target and then by
   function name.  Each entry is a list of (predicate, generator) pairs;
   add_coding_function_generator conses new pairs onto the front. *)
val coding_function_generators
        : (hol_type,
           (string,
            ((hol_type -> bool) * (hol_type -> function)) list ref) dict ref) dict ref
        = ref (Binarymap.mkDict type_less)
2424
(* Registers (predicate, generator) for coding functions called name in the
   translation into target.  Missing dictionaries and lists are created on
   demand, and the new pair is placed in front of earlier ones. *)
fun add_coding_function_generator target (name:string) (predicate:hol_type -> bool)  (generator:hol_type -> function) =
let     val by_name =
                case Binarymap.peek (!coding_function_generators, target)
                of SOME d => d
                |  NONE =>
                   let  val d = ref (mkDict String.compare)
                   in   coding_function_generators :=
                                Binarymap.insert (!coding_function_generators, target, d) ;
                        d
                   end
        val generators =
                case Binarymap.peek (!by_name, name)
                of SOME gs => gs
                |  NONE =>
                   let  val gs = ref []
                   in   by_name := Binarymap.insert (!by_name, name, gs) ;
                        gs
                   end
in
        generators := (predicate,generator) :: (!generators)
end;
2437
(* Source function generators, indexed by function name.  Each entry is a
   list of (predicate, generator) pairs; add_source_function_generator conses
   new pairs onto the front. *)
val source_function_generators
        : (string,
           ((hol_type -> bool) * (hol_type -> function)) list ref) dict ref
        = ref (Binarymap.mkDict String.compare)
2443
(* Registers (predicate, generator) for source functions called name.  The
   list is created on demand, and the new pair goes in front. *)
fun add_source_function_generator name (predicate:hol_type -> bool) (generator : hol_type -> function) =
let     val generators =
                case Binarymap.peek (!source_function_generators, name)
                of SOME gs => gs
                |  NONE =>
                   let  val gs = ref []
                   in   source_function_generators :=
                                Binarymap.insert (!source_function_generators, name, gs) ;
                        gs
                   end
in
        generators := (predicate,generator) :: (!generators)
end;
2452
local
fun err name target t = mkStandardExn "get_coding_function_generator"
        ("No coding function generator exists for functions named " ^ name ^
         " in the translation: " ^ type_to_string target ^
         " --> " ^ type_to_string t)
(* Selects, with first_e, a registered generator for name (translation into
   target) whose predicate accepts t.  Raises err when nothing is registered;
   presumably first_e raises it too when no predicate accepts t. *)
fun get_coding_function_generator target name t =
let     val failure = err name target t
        val candidates =
                case Binarymap.peek (!coding_function_generators, target)
                of NONE => raise failure
                |  SOME by_name =>
                        (case Binarymap.peek (!by_name, name)
                         of NONE => raise failure
                         |  SOME gens => !gens)
in
        snd (first_e failure (fn (applies, _) => applies t) candidates)
end
(* Uses the stored coding function if present, otherwise generates one.
   Storage is re-checked before adding, rather than reusing the first test;
   presumably generation itself can store the function. *)
fun gcf target name t =
let     val function =
                if not (exists_coding_function_precise target t name)
                then (type_trace 1 (
                        "Generating function " ^ name ^ " for translation " ^
                        (type_to_string target) ^ " --> " ^ type_to_string t ^ "\n") ;
                      get_coding_function_generator target name t t)
                else get_coding_function_precise target t name
in
        if not (exists_coding_function_precise target t name)
        then add_coding_function_precise target t name function
        else ()
end
in
(* Generates (if necessary) and stores the coding function name for type t
   in the translation into target, guarding against generator loops. *)
fun generate_coding_function target name t =
let     val loop_key = "gcf" ^ name ^ type_to_string (cannon_type target)
        fun on_loop list = raise (mkDebugExn "generate_coding_function"
                        ("Experienced a loop whilst generating the coding function " ^ name ^
                         " for type " ^ type_to_string target ^
                         "\nTrace: " ^ xlist_to_string type_to_string list))
in
        check_loop loop_key t (gcf target name) on_loop
end
end;
2484
local
fun err name t = mkStandardExn "get_source_function_generator"
        ("No source function generator exists for functions named " ^ name ^
         " and the type " ^ type_to_string t)
(* Selects, with first_e, a registered generator for name whose predicate
   accepts t.  Raises err when nothing is registered under name. *)
fun get_source_function_generator name t =
let     val failure = err name t
in
        case Binarymap.peek (!source_function_generators, name)
        of SOME gens => snd (first_e failure (fn (applies, _) => applies t) (!gens))
        |  NONE => raise failure
end
(* Uses the stored source function if present, otherwise generates one.
   Storage is re-checked before adding (see gcf). *)
fun gsf name t =
let     val function =
                if not (exists_source_function_precise t name)
                then (type_trace 1 (
                        "Generating function " ^ name ^ " for type " ^ type_to_string t ^ "\n") ;
                      get_source_function_generator name t t)
                else get_source_function_precise t name
in
        if not (exists_source_function_precise t name)
        then add_source_function_precise t name function
        else ()
end
in
(* Generates (if necessary) and stores the source function name for type t,
   guarding against generator loops. *)
fun generate_source_function name t =
let     fun on_loop list = raise (mkDebugExn "generate_source_function"
                        ("Experienced a loop whilst generating the source function " ^ name ^
                         "\nTrace: " ^ xlist_to_string type_to_string list))
in
        check_loop ("gsf"^name) t (gsf name) on_loop
end
end;
2511
2512(*****************************************************************************)
2513(* Theorem generators:                                                       *)
2514(*                                                                           *)
2515(* Allow the proof of theorems recursively                                   *)
2516(*                                                                           *)
2517(*****************************************************************************)
2518
(* Coding theorem generators, indexed by translation target and then by
   theorem name.  Each entry pairs an optional function giving the expected
   conclusion at a type with a list of (predicate, generator) pairs. *)
val coding_theorem_generators
        : (hol_type,
           (string, ((hol_type -> term) option ref *
                     ((hol_type -> bool) * (hol_type -> thm)) list ref)) dict ref) dict ref
        = ref (Binarymap.mkDict type_less)
2525
local
(* The (conclusion, generators) entry for name in the translation into
   target, with dictionaries and entry created empty on first use. *)
fun setup target name =
let     val by_name =
                case Binarymap.peek (!coding_theorem_generators, target)
                of SOME d => d
                |  NONE =>
                   let  val d = ref (mkDict String.compare)
                   in   coding_theorem_generators :=
                                Binarymap.insert (!coding_theorem_generators, target, d) ;
                        d
                   end
in
        case Binarymap.peek (!by_name, name)
        of SOME entry => entry
        |  NONE =>
           let  val entry = (ref NONE, ref [])
           in   by_name := Binarymap.insert (!by_name, name, entry) ;
                entry
           end
end
in
(* Sets the function giving the expected conclusion of theorem name. *)
fun set_coding_theorem_conclusion target name mk_conc =
let     val (conclusion, _) = setup target name
in
        conclusion := SOME mk_conc
end
(* Has a conclusion function been set for theorem name? *)
fun exists_coding_theorem_conclusion target name =
let     val (conclusion, _) = setup target name
in      isSome (!conclusion)
end
(* The conclusion function for theorem name; raises Option if none is set. *)
fun get_coding_theorem_conclusion target name =
let     val (conclusion, _) = setup target name
in      valOf (!conclusion)
end
(* Registers (predicate, generator) in front of earlier generators. *)
fun add_coding_theorem_generator target (name:string) (predicate:hol_type -> bool)  (generator:hol_type -> thm) =
let     val (_, generators) = setup target name
in
        generators := (predicate,generator) :: (!generators)
end
end;
2552
(* Source theorem generators, indexed by theorem name.  Each entry pairs an
   optional function giving the expected conclusion at a type with a list of
   (predicate, generator) pairs. *)
val source_theorem_generators
        : (string, ((hol_type -> term) option ref *
                    ((hol_type -> bool) * (hol_type -> thm)) list ref)) dict ref
        = ref (Binarymap.mkDict String.compare)
2558
local
(* The (conclusion, generators) entry for name, created empty on first use. *)
fun setup name =
        case Binarymap.peek (!source_theorem_generators, name)
        of SOME entry => entry
        |  NONE =>
           let  val entry = (ref NONE, ref [])
           in   source_theorem_generators :=
                        Binarymap.insert (!source_theorem_generators, name, entry) ;
                entry
           end
in
(* Sets the function giving the expected conclusion of theorem name. *)
fun set_source_theorem_conclusion name mk_conc =
let     val (conclusion, _) = setup name
in      conclusion := SOME mk_conc
end
(* Has a conclusion function been set for theorem name? *)
fun exists_source_theorem_conclusion name =
let     val (conclusion, _) = setup name
in      isSome (!conclusion)
end
(* The conclusion function for theorem name; raises Option if none is set. *)
fun get_source_theorem_conclusion name =
let     val (conclusion, _) = setup name
in      valOf (!conclusion)
end
(* Registers (predicate, generator) in front of earlier generators. *)
fun add_source_theorem_generator name predicate generator =
let     val (_, generators) = setup name
in
        generators := (predicate,generator) :: (!generators)
end
end;
2580
(* MATCH_CONC thm conc:
   Instantiates the fully specialised thm so that its conclusion matches the
   body of conc, then generalises over conc's outer universally quantified
   variables.  Fails, via match_term, if the conclusions do not match. *)
fun MATCH_CONC thm conc =
let     val specialised = SPEC_ALL thm
        val (bound, body) = strip_forall conc
        val instantiation = match_term (concl specialised) body
in
        GENL bound (INST_TY_TERM instantiation specialised)
end;
2587
local
fun err name target t = mkStandardExn "get_coding_theorem_generator"
        ("No coding theorem generator exists for theorems named " ^ name ^
         " in the translation: " ^ type_to_string target ^
         " --> " ^ type_to_string t)
(* Selects, with first_e, a registered generator for name (translation into
   target) whose predicate accepts t. *)
fun get_coding_theorem_generator target name t =
    case (Binarymap.peek(!coding_theorem_generators,target))
    of NONE => raise (err name target t)
    |  SOME x => case (Binarymap.peek(!x,name))
       of NONE => raise (err name target t)
       |  SOME x => (snd (first_e (err name target t)
                                  (fn (x,y) => x t) (!(snd x))))
(* gct target name t:
   Returns the coding theorem name for t, generating it when it is not
   stored.  The theorem for the base type is ensured first.  When a
   conclusion function is registered, the theorem is matched against it.
   Theorems with hypotheses are rejected BEFORE being stored (as in gst), so
   a rejected theorem never enters the coding theorem database. *)
fun gct target name t =
let val _ = type_trace 2 ("->generate_coding_theorem(" ^ name ^ "," ^
                         (type_to_string t) ^ ")\n")
    val _ = if base_type t = t orelse
               exists_coding_theorem_precise target t name
               then () else (gct target name (base_type t) ; ())
    val theorem = if exists_coding_theorem_precise target t name
                then get_coding_theorem_precise target t name
                else (get_coding_theorem_generator target name t) t
    val mtheorem = if exists_coding_theorem_conclusion target name
                then MATCH_CONC theorem (get_coding_theorem_conclusion target name t)
                        handle e => raise (mkStandardExn "generate_coding_theorem"
                                ("Generator for " ^ name ^
                                 " returned the theorem:\n " ^ thm_to_string theorem ^
                                 "\nThis does not match the specified conclusion for type: " ^
                                 type_to_string t ^ ":\n" ^
                                 term_to_string (get_coding_theorem_conclusion target name t)))
                else theorem
    val _ = if null (hyp mtheorem) then ()
            else raise (mkStandardExn "generate_coding_theorem"
                   ("Generator for " ^ name ^
                    " returned the theorem:\n " ^ thm_to_string theorem ^
                    "\nwhich has the non-empty hypothesis set:\n" ^
                    xlist_to_string term_to_string (hyp mtheorem)))
    (* Re-checked rather than reusing the earlier test; presumably generation
       can store the theorem itself. *)
    val _ = if exists_coding_theorem_precise target t name
            then ()
            else add_coding_theorem_precise target t name mtheorem
in
    mtheorem
end
in
fun generate_coding_theorem target name t =
        check_loop ("gct" ^ name ^ type_to_string (cannon_type target)) t (gct target name)
                (fn list => raise (mkDebugExn "generate_coding_theorem"
                        ("Experienced a loop whilst generating the coding theorem " ^ name ^
                         " for type " ^ type_to_string target ^
                         "\nTrace: " ^ xlist_to_string type_to_string list)))
end;
2638
local
fun err name t = mkStandardExn "get_source_theorem_generator"
        ("No source theorem generator exists for theorems named " ^ name ^
         " and the type " ^ type_to_string t)
(* Selects, with first_e, a registered generator for name whose predicate
   accepts t. *)
fun get_source_theorem_generator name t =
        case (Binarymap.peek(!source_theorem_generators,name))
        of NONE => raise (err name t)
        |  SOME x => (snd (first_e (err name t) (fn (x,y) => x t) (!(snd x))))
(* gst name t:
   Returns the source theorem name for t, generating it when it is not
   stored.  The theorem for the base type is ensured first.  When a
   conclusion function is registered, the theorem is matched against it.
   Theorems with hypotheses are rejected before they are stored. *)
fun gst name t =
let     val _ = type_trace 2 ("->generate_source_theorem(" ^ name ^ "," ^ (type_to_string t) ^ ")\n")
        val _ = if base_type t = t orelse
                   (exists_source_theorem_precise t name)
                   then () else (gst name (base_type t) ; ())
        val theorem = if exists_source_theorem_precise t name
                then get_source_theorem_precise t name
                else (get_source_theorem_generator name t) t
        val mtheorem = if exists_source_theorem_conclusion name
                then MATCH_CONC theorem (get_source_theorem_conclusion name t)
                        handle e => raise (mkStandardExn "generate_source_theorem"
                                ("Generator for " ^ name ^
                                 " returned the theorem:\n " ^ thm_to_string theorem ^
                                 "\nThis does not match the specified conclusion for type: " ^
                                 type_to_string t ^ ":\n" ^
                                 term_to_string (get_source_theorem_conclusion name t)))
                else theorem
        val _ = if null (hyp mtheorem) then ()
              else raise (mkStandardExn "generate_source_theorem"
                   ("Generator for " ^ name ^
                    " returned the theorem:\n " ^ thm_to_string theorem ^
                    "\nwhich has the non-empty hypothesis set:\n" ^
                    xlist_to_string term_to_string (hyp mtheorem)))
        val _ = if exists_source_theorem_precise t name
                then ()
                else add_source_theorem_precise t name mtheorem
in
        mtheorem
end
in
fun generate_source_theorem name t =
        check_loop ("gst" ^ name) t (gst name)
                (fn list => raise (mkDebugExn "generate_source_theorem"
                        ("Experienced a loop whilst generating the source theorem " ^ name ^
                         "\nTrace: " ^ xlist_to_string type_to_string list)))
end;
2683
2684(*****************************************************************************)
2685(* Polytypic generation of functions:                                        *)
2686(*                                                                           *)
2687(* All the following higher order functions expect functions from the set:   *)
2688(*      make_term : Makes a function term of type :target -> 'a              *)
2689(*      get_def   : Returns the definition of a previously defined function  *)
2690(*      get_func  : Returns a term to be applied for the given type          *)
2691(*      conv      : Converts a term from 'make_term' to the acceptable form  *)
2692(*      get_ind   : Returns the induction theorem for a type                 *)
2693(* these will be used as abbreviations in the type signatures.               *)
2694(*                                                                           *)
2695(* inst_function_def       : get_def -> get_func -> hol_type -> thm          *)
2696(* expanded_function_def   : conv -> get_def -> hol_type -> term list -> thm *)
2697(*     Each function returns a fully instantiated definition, however        *)
2698(*     'expanded_function_def' returns one that applies to all types in the  *)
2699(*     recursive set and also applies 'conv' to them                         *)
2700(*                                                                           *)
2701(* mk_split_source_function :                                                *)
2702(*         make_term -> get_def -> get_func -> conv -> hol_type -> thm * thm *)
2703(* mk_split_target_function :                                                *)
2704(*         make_term -> get_def -> get_func -> conv -> translation_scheme -> *)
2705(*                        hol_type -> (thm * (term * term) list * thm) * thm *)
2706(*    These two functions return a theorem of mutual recursion, and an       *)
2707(*    equality theorem that maps this onto the term given by make_term.      *)
2708(*    Both functions use SPLIT_FUNCTION_CONV to split the term, 'source'     *)
2709(*    uses 'prove_rec_fn_exists' to form the mutual recursive function and   *)
2710(*    'target' uses 'prove_induction_recursion_thms', and as such also       *)
2711(*    returns a mapping and theorem of induction over the target type.       *)
2712(*                                                                           *)
2713(* MATCH_IND_TERM    : term -> thm -> thm                                    *)
2714(*     Matches a theorem to a term from the antecedents of an induction      *)
2715(*     theorem                                                               *)
2716(*                                                                           *)
2717(* strengthen_proof_term  : thm list -> term -> thm                          *)
2718(*      Takes a term of the form:                                            *)
2719(*           (f ... = a) /\ ... ==> (f = h)                                  *)
2720(*      and returns a theorem strengthened by adding in the rest of the      *)
2721(*      mutually recursive definitions:                                      *)
2722(*      |- (f ... = a) /\ ... /\ (g ... = b) /\ ... ==> (f = h) /\ (g = h')  *)
2723(*            ==> (f ... = a) /\ ... ==> (f = h)                             *)
2724(*                                                                           *)
2725(* prove_split_term      : (term * (term * hol_type)) list ->                *)
2726(*                                                 thm -> thm -> term -> thm *)
2727(*     Given a mapping from predicates to function constants and types,      *)
2728(*     a [mutual] induction theorem and a set of mutually recursive          *)
2729(*     functions, proves a term of the form:                                 *)
2730(*         |- (split (C a .. z) = A a .. z) /\ .... ==> (split = fn x y)     *)
2731(*                                                                           *)
2732(* prove_all_split_terms : get_ind * get_def * conv ->                       *)
2733(*                         (term * hol_type) list -> thm -> thm list * thm   *)
2734(*     Given a list mapping function terms to types and a theorem            *)
(*     prove_all_split_terms removes all split terms from the hypotheses     *)
2736(*     of the theorem and returns them as a list                             *)
2737(*                                                                           *)
2738(* remove_hyp_terms      : thm -> thm list -> thm list * thm -> thm          *)
(*     Given a pair function, the list of proved split terms and the         *)
(*     conjunctions from the theorem of mutual recursion, remove_hyp_terms   *)
(*     removes all hypotheses from the final mutual recursion theorem.       *)
2742(*                                                                           *)
2743(* match_mapping         : thm -> (term * term) list -> get_func -> thm ->   *)
2744(*                               hol_type -> (term * (term * hol_type)) list *)
2745(*     Given an equality theorem and a mapping as returned by mk_split...    *)
2746(*     and an induction theorem, match_mapping attempts to construct a full  *)
(*     mapping from predicates to function constants and types.              *)
2748(*                                                                           *)
2749(* unsplit_function      : get_ind -> get_def -> get_func -> conv ->         *)
2750(*                                              hol_type -> thm * thm -> thm *)
2751(*     Given a pair '(eq_thm,mrec_thm)' representing the equality theorem    *)
2752(*     and mutual recursion theorem returned by the mk_split_... functions   *)
2753(*     'unsplit_function' returns a theorem of mutual recursion that matches *)
2754(*     that given by eq_thm.                                                 *)
2755(*     Uses 'prove_all_split_terms' and 'remove_hyp_terms' to remove any     *)
2756(*     condition imposed by the equality theorem.                            *)
2757(*                                                                           *)
2758(* mk_source_functions : string -> mk_term -> get_func -> conv ->            *)
2759(*                                                          hol_type -> unit *)
2760(* mk_coding_functions, mk_target_functions : string -> mk_term ->           *)
2761(*                          get_func -> conv -> hol_type -> hol_type -> unit *)
2762(*    These functions combine 'unsplit_function' and 'mk_split_..._function' *)
(*    to generate functions; these functions are then defined through        *)
2764(*    'new_specification' and the relevant theorems and definitions are      *)
2765(*    stored in the coding base.                                             *)
2766(*                                                                           *)
2767(*****************************************************************************)
2768
(* inst_function_def get_def get_func t                                      *)
(*     Instantiates every clause of the definition 'get_def t' so that the   *)
(*     operator of its left-hand side matches the term 'get_func t', and     *)
(*     returns the instantiated clauses re-joined into one conjunction.      *)
(*     Any failure is wrapped with the name "inst_function_def".             *)
fun inst_function_def get_def get_func (t:hol_type) =
let     val _ = type_trace 3 "->inst_function_def\n"
in
        let     val target = get_func t
                val clauses = CONJUNCTS (get_def t)
                fun inst_clause clause = PART_MATCH (rator o lhs) clause target
        in
                LIST_CONJ (map inst_clause clauses)
        end
        handle e => wrapException "inst_function_def" e
end
2776
(* expanded_function_def conv create_conv get_def t term_list                *)
(*     For every group returned by split_nested_recursive_set (base_type t), *)
(*     instantiates the group's main definition to the first term of         *)
(*     'term_list' it matches, then conjoins the instantiated definitions of *)
(*     the sub-functions and of the pair function (get_def of alpha # beta)  *)
(*     that it calls. All groups are conjoined, re-associated with           *)
(*     GSYM CONJ_ASSOC, and finally converted with 'create_conv'.            *)
local
(* Wrap an exception with "expanded_function_def", optionally followed by    *)
(* the name 's' of the helper that raised it.                                *)
fun wrap "" e = wrapException ("expanded_function_def") e
  | wrap s  e = wrapException ("expanded_function_def (" ^ s ^ ")") e;
(* Quantify 'thm' over the members of 'vars' that are free in its conclusion *)
(* and then (outermost) over the free variables, left to right, of the       *)
(* argument of its left-hand side.                                           *)
fun gen1 vars thm =
let     val rvars = free_vars_lr ((rand o lhs o concl) thm)
        val vars' = intersect vars (free_vars (concl thm))
in
        GENL rvars (GENL vars' thm)
end
(* Apply gen1 to every fully specialised conjunct of 'thm' and re-conjoin.   *)
fun GEN_RAND vars thm =
        LIST_CONJ (map (gen1 vars o SPEC_ALL) (CONJUNCTS thm)) handle e => wrap "GEN_RAND" e

(* Instantiate every conjunct of 'thm' so that the operator of its left-hand *)
(* side matches 'term', then generalise the result with GEN_RAND.            *)
fun inst_it vars thm term =
        GEN_RAND vars (LIST_CONJ (map (C (PART_MATCH (rator o lhs)) term) (CONJUNCTS thm)) handle e => wrap "inst_it" e)
(* Conjoin to 'main' one new instance of the definition 'enc'. Instances are *)
(* built at every subterm of 'main' matching the operator of enc's first     *)
(* clause; the first one that is not already a conjunct of 'main' and that   *)
(* mentions one of 'recfns' is used. Fails (through 'hd') when none exists.  *)
fun single_inst recfns vars main enc =
let     val current = strip_conj (concl main)
        val next = mapfilter
                (inst_it vars enc)
                (find_terms (can (match_term
                        ((rator o lhs o snd o strip_forall o hd o strip_conj o concl) enc)))
                (concl main) handle e => wrap "single_inst" e)
        val candidates = filter (fn x => (not o C mem current o concl) x andalso exists (C free_in (concl x)) recfns) next
in
        CONJ main (hd candidates)
end
(* Repeatedly conjoin to 'main' a member of the list that single_inst can    *)
(* use, until the list is empty. pick_e is given a debug exception for the   *)
(* case where no remaining sub-function is used by the main function.        *)
(* NOTE(review): assumes pick_e returns (f x, remaining list) for the first  *)
(* x on which f succeeds - confirm against its definition.                   *)
fun inst_all recfns vars main [] = main
  | inst_all recfns vars main list =
        uncurry (inst_all recfns vars)
                (pick_e (mkDebugExn
                        ("expanded_function_def")
                        ("Sub-function returned by expanded_function_def not used in main function"))
                (single_inst recfns vars main) list)
(* Conjoin instances of the pair definition until single_inst fails.        *)
fun inst_pairs recfns vars main pair = repeat (C (single_inst recfns vars) pair) main
(* Expand one (main, sub-functions) group. 'vars' are the bound variables of *)
(* each conjunct of 'main' that are not free in the argument of its          *)
(* left-hand side. 'conv' is applied to main, every sub definition and the   *)
(* pair definition before the sub-function and pair instances are added.     *)
fun fix_function recfns conv pair (main,sub) =
let     val vars = flatten (
                        map (fn c => set_diff ((fst o strip_forall) c) ((free_vars o rand o lhs o snd o strip_forall) c))
                                ((strip_conj o concl) main))
in
        inst_pairs recfns vars (inst_all recfns vars (CONV_RULE conv main)
                                        (map (CONV_RULE conv) sub)) (CONV_RULE conv pair)
end
(* NOTE(review): 'subset' is not used by expanded_function_def.             *)
fun subset x y = set_eq x (intersect x y);
in
fun expanded_function_def conv create_conv get_def t term_list =
let     val _ = type_trace 3 "->expanded_function_def\n"
        val base_types = split_nested_recursive_set (base_type t) handle e => wrap "" e
        (* Each main definition paired with the definitions of its sub-types *)
        (* (presumably; depends on the shape of split_nested_recursive_set). *)
        val functions = map (get_def ## map get_def o fst) base_types handle e => wrap "" e
        val pair = get_def (mk_prod(alpha,beta))

        (* Instantiate each main definition to the first term of term_list   *)
        (* that it matches; no match raises a debug exception.               *)
        val matched = map (fn (main,sub) => (tryfind_e
                        (mkDebugExn "expanded_function_def"
                                ("Could not find a match for the function:\n" ^ thm_to_string main ^
                                 "\nin the term list: " ^ xlist_to_string term_to_string term_list))
                        (fn t => LIST_CONJ (map (C (PART_MATCH (rator o lhs)) t) (CONJUNCTS main))) term_list,sub)) functions
        (* Head constants of the instantiated main definitions               *)
        val recfns = map (repeat rator o lhs o snd o strip_forall o hd o strip_conj o concl o fst) matched
        val full = PURE_REWRITE_RULE [GSYM CONJ_ASSOC] (LIST_CONJ (map (fix_function recfns conv pair) matched))
in
        CONV_RULE create_conv full
end
end;
2837
local
(* wrap_conv source func conv term                                           *)
(*     Applies the user-supplied conversion 'conv' to 'term' and checks the  *)
(*     result: its left-hand side must match 'term', and its right-hand side *)
(*     must satisfy is_source_function when 'source' is true, or             *)
(*     is_target_function when 'source' is false. Exceptions from 'conv'     *)
(*     are wrapped, and failed checks raise a standard exception, under the  *)
(*     name "conv (supplied function)". 'func' is not used by the checks.    *)
fun wrap_conv source func conv term =
let     val name = ("conv (supplied function)")
        val mkExn = mkStandardExn name
        val result = conv term handle e => wrapException name e
        val _ = if can (match_term term) (lhs (concl result)) then () else
                        raise (mkExn (  "Left hand side of result theorem:\n" ^ thm_to_string result ^
                                        "\ndoes not match the term given:\n" ^ term_to_string term))
        (* Only checked when building source functions *)
        val _ = if not source orelse is_source_function ((rhs o concl) result) then () else
                        raise (mkExn (  "Right hand side of result theorem:\n" ^ thm_to_string result ^
                                        "\nis not a correct source function, see help for details"))
        (* Only checked when building target functions *)
        val _ = if source orelse is_target_function ((rhs o concl) result) then () else
                        raise (mkExn (  "Right hand side of result theorem:\n" ^ thm_to_string result ^
                                        "\nis not a correct target function, see help for details"))
in
        result
end
(* mk_eq_thm func mk_term get_def get_func conv t                            *)
(*     Conjoins 'mk_term ty' for every type ty of                            *)
(*     split_nested_recursive_set t and proves it equal to its split form    *)
(*     with 'conv' THENC SPLIT_FUNCTION_CONV. That conversion is given the   *)
(*     converted pair definition (get_def of alpha # beta) and the           *)
(*     instantiated definitions of the nested types other than pairs.        *)
(*     'func' ("source"/"target") only names the caller in exceptions.       *)
fun mk_eq_thm func mk_term get_def get_func conv t =
let     fun wrap e = wrapException ("mk_split_" ^ func ^ "_function (mk_eq_thm)") e
        val recursive_types = (map (I ## fst) (split_nested_recursive_set t))
        val tm = mk_prod(alpha,beta)
        val terms = map (mk_term o fst) recursive_types handle e => wrap e
        val pair = CONV_RULE conv (get_def tm) handle e => wrap e
        val get_hfuns = map (inst_function_def (CONV_RULE conv o get_def) get_func) o filter (not o can (match_type tm))
        val hfuns = get_hfuns (mk_set (flatten (map snd recursive_types)))
        val term = list_mk_conj (flatten (map strip_conj terms)) handle e => wrap e
        val _ = type_trace 3 ("Target function: \n" ^ term_to_string term ^ "\n")
        val eq_thm =
                (conv THENC
                 SPLIT_FUNCTION_CONV (is_double_term_source,pair) hfuns) term handle e => wrap e
        val _ = type_trace 3 ("Equivalence theorem: \n" ^ thm_to_string eq_thm ^ "\n");
in
        eq_thm
end
fun wraps e = wrapException "mk_split_source_function" e
fun wrapt e = wrapException "mk_split_target_function" e
in
(* mk_split_source_function mk_term get_def get_func conv create_conv t      *)
(*     Returns (r_thm', eq_thm1). eq_thm1 is the equality theorem from       *)
(*     mk_eq_thm (with 'conv' checked as a source conversion), eq_thm2 is    *)
(*     create_conv applied to its right-hand side, and r_thm' is the         *)
(*     prove_rec_fn_exists theorem for eq_thm2's right-hand side with its    *)
(*     body rewritten by GSYM eq_thm2. Failures are wrapped with the name    *)
(*     "mk_split_source_function".                                           *)
fun mk_split_source_function mk_term get_def get_func conv create_conv t =
let     val _ = type_trace 2 "->mk_split_source_function\n"
        val eq_thm1 = mk_eq_thm "source" mk_term get_def get_func (wrap_conv true "source" conv) t
        val eq_thm2 = create_conv (rhs (concl eq_thm1)) handle e => wraps e
        val r_thm = prove_rec_fn_exists (TypeBase.axiom_of t) ((rhs o concl) eq_thm2) handle e => wraps e
        val _ = assert "mk_split_source_function" [
                ("prove_rec_fn_exists returned a theorem with a non-empty hypothesis set!",
                 null o hyp)] r_thm
in
        (CONV_RULE (STRIP_QUANT_CONV (REWR_CONV (GSYM eq_thm2))) r_thm,eq_thm1) handle e => wraps e
end
(* mk_split_target_function mk_term get_def get_func conv create_conv        *)
(*                          scheme t                                         *)
(*     As mk_split_source_function, but 'conv' is checked as a target        *)
(*     conversion, and prove_induction_recursion_thms is used under the      *)
(*     translation scheme. Returns ((i_thm,mapping,r_thm'),eq_thm1).         *)
(*     Failures, including those of prove_induction_recursion_thms, are      *)
(*     wrapped with the name "mk_split_target_function".                     *)
fun mk_split_target_function mk_term get_def get_func conv create_conv (scheme:translation_scheme) t =
let     val _ = type_trace 2 "->mk_split_target_function\n"
        val eq_thm1 = mk_eq_thm "target" mk_term get_def get_func (wrap_conv false "target" conv) t
        val eq_thm2 = create_conv (rhs (concl eq_thm1)) handle e => wrapt e
        val (i_thm,mapping,r_thm) = prove_induction_recursion_thms scheme ((rhs o concl) eq_thm2)
                handle e => wrapt e
in
        ((i_thm,mapping,CONV_RULE (STRIP_QUANT_CONV (REWR_CONV (GSYM eq_thm2))) r_thm),eq_thm1) handle e => wrapt e
end
end
2895
2896
2897(* Matches a thm to a term:
2898Terms are:
2899        Either: !a...z. f (C a .. z) = g (C a .. z)
2900        Or    : !a...z. (a = b) /\ (c = d) /\ (e = f) ==> !a'...z'. f (C a' .. z') = g (C a' .. z')
2901Thms will then be:
2902        Either: [.] |- f (C a .. z) = g (C a .. z)
2903        Or    : [.,a = b,c = d,e = f] |- f (C a .. z) = g (C a .. z)
2904*)
local
(* Discharge the listed terms from 'thm' (the head of the list outermost),   *)
(* merging the nested implications into one conjunctive antecedent with     *)
(* AND_IMP_INTRO.                                                           *)
fun disch_and_conj_list thm hyps =
        case hyps
        of [] => thm
        |  [h] => DISCH h thm
        |  h :: rest =>
                CONV_RULE (REWR_CONV AND_IMP_INTRO)
                        (DISCH h (disch_and_conj_list thm rest))
(* Specialise every outer universal quantifier of 'thm' to a fresh genvar.  *)
fun SPECL_GEN thm =
let     val (bound,_) = strip_forall (concl thm)
in
        SPECL (map (fn v => genvar (type_of v)) bound) thm
end
in
(* MATCH_IND_TERM term assum                                                *)
(*     Instantiates 'assum' (specialised to fresh variables) to match an    *)
(*     antecedent 'term' of an induction theorem. When the body of 'term'   *)
(*     is a guarded implication, the instance of the guarded conclusion is  *)
(*     re-generalised and the guards are discharged. If that fails, or the  *)
(*     body is not an implication, the whole body is matched instead.       *)
(*     Outer quantifiers are then restored. Errors are wrapped with the     *)
(*     name "MATCH_IND_TERM".                                               *)
fun MATCH_IND_TERM term assum =
let     val (outer,body) = strip_forall term
        val guarded = total ((strip_conj ## strip_forall) o dest_imp) body
        val fresh = SPECL_GEN assum
        (* Fallback: instantiate the fresh assumption to the whole body *)
        fun match_body () = INST_TY_TERM (match_term (concl fresh) body) fresh
        val matched =
                case guarded
                of NONE => match_body ()
                |  SOME (conds,(inner,goal)) =>
                        (disch_and_conj_list
                                (GENL inner (INST_TY_TERM (match_term (concl fresh) goal) fresh))
                                conds
                         handle _ => match_body ())
in
        GENL outer matched
end     handle e => wrapException "MATCH_IND_TERM" e
end
2929
local
(* Wrap an exception with "strengthenProof", optionally followed by the     *)
(* name 's' of the helper that raised it.                                   *)
fun wrap "" e = wrapException "strengthenProof" e
  | wrap s e = wrapException ("strengthenProof (" ^ s ^ ")") e

(* True when some term of l1 occurs free in some term of l2.                *)
fun any_free_in [] l2 = false
  | any_free_in (x::xs) l2 = exists (free_in x) l2 orelse any_free_in xs l2;

(* Subterms of the right-hand side of the (quantified) equation 'term'      *)
(* whose operator is *not* in funcs but whose arguments mention a member of *)
(* funcs, i.e. calls that hand one of funcs to a function not in funcs.     *)
fun undef_hofs funcs term =
        find_terms (both o (not o C mem funcs ## any_free_in funcs) o strip_comb) ((rhs o snd o strip_forall) term)
        handle e => wrap "undef_hofs" e;

(* Generalise a theorem over the free variables (left to right) of the      *)
(* constructor argument on its left-hand side.                              *)
fun gen_const thm =
        GENL ((free_vars_lr o rand o lhs o snd o strip_forall o concl) thm) thm
        handle e => wrap "gen_const" e;

(* Instantiate every clause of the definition 'thm' so that the operator of *)
(* its left-hand side matches 'term' (e.g. enc1 (enc2 ...)).                *)
fun match_term_func term thm =
        LIST_CONJ (map (gen_const o C (PART_MATCH (rator o lhs)) term) (CONJUNCTS thm))
        handle e => wrap "match_term" e;

(* One strengthening step. 'thm' proves  ante = term  for the current       *)
(* conjunction 'term'. The higher-order calls in 'term' (undef_hofs) are    *)
(* each instantiated with a definition from 'functions'. Those whose        *)
(* operator is not yet in 'funcs' are added to the conjunction. Returns the *)
(* extended operator list and the extended equation. A (wrapped) Empty is   *)
(* raised when nothing new is added, which ends the 'repeat' loop in        *)
(* strengthen_proof_term.                                                   *)
fun add_defs_conv fvs functions (funcs,thm) =
let     val term = (rhs o concl) thm
        val hofs = flatten (map (undef_hofs funcs) (strip_conj term))
        val defs = mapfilter (fn h => tryfind_e Empty (match_term_func h) functions) hofs
        (* Generalise each clause over the conclusion variables it mentions *)
        val defs' = map (fn d => LIST_CONJ (map (GENL (intersect (free_vars (concl d)) fvs)) (CONJUNCTS d))) defs
        val sdefs = map (fn d => (d,(rator o lhs o snd o strip_forall o hd o strip_conj o concl) d)) defs'
        val adefs = filter (not o C mem funcs o snd) sdefs
        val _ = if null adefs then raise Empty else ()
        (* The fold starts from funcs, so new_funcs already contains every  *)
        (* member of funcs; it must not be appended to funcs a second time. *)
        val (new_funcs,defs'') = foldr (fn ((a,b),(nf,d)) => (b :: nf,a :: d)) (funcs,[]) adefs
in
        (new_funcs,TRANS thm (foldr (uncurry PROVE_HYP)
                ((foldr (fn (a,b) => ADDR_AND_CONV (concl a) THENC b) ALL_CONV defs'') term) defs''))
end     handle e => wrap "add_defs_conv" e;
in
(* strengthen_proof_term functions term                                      *)
(*     'term' has the form  defs ==> eqs. defs is a conjunction of          *)
(*     (quantified) function clauses and eqs a conjunction of function      *)
(*     equalities. The antecedent is repeatedly extended with the           *)
(*     definitions, taken from 'functions', of the functions it calls. The  *)
(*     conclusion is extended with an equality for every operator of the    *)
(*     extended antecedent that eqs does not already equate; its right-hand *)
(*     side is obtained by substituting eqs. Returns                        *)
(*         |- (defs' ==> eqs') ==> defs ==> eqs                             *)
fun strengthen_proof_term functions term =
let     val _ = type_trace 3 "->strengthen_proof_term\n"
        val _ = type_trace 3 ("Strengthening proof term: " ^ term_to_string term)
        val _ = assert "strengthen_proof_term" [
                ("Proof term is not an implication from a function definition to a conjunction of function equalities",
                 is_implication_of
                        (is_conjunction_of (is_eq o snd o strip_forall))
                        (is_conjunction_of (fn x => (is_eq o snd o strip_forall) x
                                        andalso (can dom_rng o type_of o lhs o snd o strip_forall) x)))] term
        val (ante,conc) = guarenteed dest_imp_only term handle e => wrap "" e;
        val clauses = strip_conj ante handle e => wrap "" e;
        val funcs = map (rator o lhs o snd o strip_forall) clauses handle e => wrap "" e;
        (* Variables bound by the equalities of the conclusion *)
        val fvs = mk_set (flatten (map (fst o strip_forall) (strip_conj conc)))
        (* |- (ante' ==> conc) ==> (ante ==> conc), ante' the extended antecedent *)
        val thm1 = snd (EQ_IMP_RULE ((LAND_CONV (REWR_CONV
                        (snd (repeat (add_defs_conv fvs functions) (funcs,REFL ante)))) THENC
                         PURE_REWRITE_CONV [GSYM CONJ_ASSOC]) term))
                        handle e => wrap "" e
        val new_term = guarenteed (fst o dest_imp_only o concl) thm1 handle e => wrap "" e
        (* Same shape as the entry assertion: the conclusion may be a        *)
        (* conjunction of function equalities, which the code handles.      *)
        val _ = assert "strengthen_proof_term" [
                ("Strengthen proof term is not of the correct form, should be impossible!",
                 is_implication_of (is_conjunction_of (is_eq o snd o strip_forall))
                                (is_conjunction_of (fn x => (is_eq o snd o strip_forall) x
                                        andalso (can dom_rng o type_of o lhs o snd o strip_forall) x)))] new_term
        val all_rators = mk_set ((map (rator o lhs o snd o strip_forall) o strip_conj o fst o dest_imp_only) new_term)
                        handle e => wrap "" e
        (* Substitution replacing each equated function by its right-hand side *)
        val subs = subst (map (op|-> o dest_eq o snd o strip_forall) (strip_conj conc)) handle e => wrap "" e
        val conc' = (strip_conj conc) @ (map (fn x => list_mk_forall(intersect fvs (free_vars (subs x)),mk_eq(x,subs x)))
                        (filter (not o C mem (map (lhs o snd o strip_forall) (strip_conj conc))) all_rators))
                        handle e => wrap "" e
        val final_term = mk_imp(fst (dest_imp_only new_term),list_mk_conj conc') handle e => wrap "" e
in
        DISCH_ALL (MP thm1 (DISCH (fst (dest_imp_only new_term))
                (LIST_CONJ (map (fn c => first (curry op= c o concl) (CONJUNCTS (UNDISCH_ONLY (ASSUME final_term))))
                (strip_conj conc))))) handle e => wrap "" e
end
end
3004
(* prove_split_term mapping induction function (dead_thm,dead_value) term    *)
(*     Proves a split term  defs ==> eqs  (see the header above) by          *)
(*     instantiating the induction theorem 'induction' with the function     *)
(*     equalities of eqs. 'mapping' pairs each induction predicate with a    *)
(*     (function term, type); 'function' is the expanded function definition *)
(*     whose clauses rewrite the antecedents; dead_thm / dead_value are used *)
(*     to discharge hypotheses about the dead value (target functions only). *)
(*     Raises a debug exception when any stage fails to match, or when the   *)
(*     final theorem still has hypotheses.                                   *)
fun prove_split_term mapping induction function (dead_thm,dead_value) term =
let     val _ = type_trace 3 "->prove_split_term\n"
        val _ = type_trace 3 ("Attempting to prove split term: " ^ term_to_string term ^ "\n")
        fun wrap e = wrapException "prove_split_term" e

        val _ = assert "prove_split_term" [
                ("Proof term is not an implication from a function definition to a conjunction of function equalities",
                 is_implication_of (is_conjunction_of (is_eq o snd o strip_forall))
                        (is_conjunction_of (fn x => (is_eq o snd o strip_forall) x
                                        andalso (can dom_rng o type_of o lhs o snd o strip_forall) x)))] term

        (* For each equality of the conclusion: (bound vars, (lhs, rhs)) *)
        val equivs = (map ((I ## dest_eq) o strip_forall) o strip_conj o snd o dest_imp_only) term handle e => wrap e

        (* Pattern !t. P t that each conclusion conjunct of the induction theorem must match *)
        val tt = mk_forall(mk_var("t",alpha),mk_comb(mk_var("P",alpha --> bool),mk_var("t",alpha))) handle e => wrap e
        val _ = assert "prove_split_term" [
                ("Induction theorem is not an implication to a conjunction of generalised predicates",
                 is_conjunction_of (can (match_term tt)) o snd o dest_imp_only o snd o strip_forall o concl)] induction
        (* The predicate operators in the conclusion of the induction theorem *)
        val predicates =
                (map (rator o snd o strip_forall) o strip_conj o snd o dest_imp_only o snd o strip_forall o concl)
                induction handle e => wrap e

        (* Antecedent clauses keyed by the operator of their left-hand side *)
        (* (presumably bucket_alist groups values with equal keys - confirm) *)
        val all_fns = bucket_alist (map ((rator ## I) o dest_eq o snd o strip_forall)
                                (strip_conj (fst (dest_imp_only term)))) handle e => wrap e;

        val _ = assert "prove_split_term" [
                ("Number of predicates, " ^ int_to_string (length predicates) ^
                 ", does not match number of functions, " ^ int_to_string (length all_fns),
                 curry op= (length predicates) o length)] all_fns

        (* Debug check: raise if some conclusion conjunct has unexpected free *)
        (* variables. first_e raises Empty when there is none, and that Empty *)
        (* is caught so that the check passes.                                *)
        val _ = (raise (mkDebugExn "prove_split_term"
                ("Free variables occur in the predicate term: " ^ term_to_string
                        (first_e Empty (fn t => not (null (set_diff (free_vars t)
                                (flatten (map (op:: o strip_comb o fst) all_fns) @ set_diff (free_vars (rhs t)) (snd (strip_comb (lhs t)))))))
                                ((strip_conj o snd o dest_imp_only o snd o strip_forall) term))))) handle Empty => ();

        (* Each predicate paired with the first equivalence whose right-hand side matches its function term *)
        val match = map (fn (a:term,(b,t:hol_type)) => (a,first (can (match_term b) o snd o snd) equivs)) mapping

        (* |- eqs = eqs', where eqs' states each equality pointwise (FUN_EQ_THM) *)
        (* with the last bound variable moved to the front of its quantifiers   *)
        val predicate = RIGHT_CONV_RULE (EVERY_CONJ_CONV (fn term =>
                                ORDER_FORALL_CONV (((fn a => last a :: butlast a) o fst o strip_forall) term) term))
                        (LIST_MK_CONJ ((map (fn (_,(a,b)) => STRIP_QUANT_CONV (REWR_CONV FUN_EQ_THM)
                                (list_mk_forall(a,mk_eq b))) match))) handle e => wrap e

        (* The instantiated predicate, rewritten to match the conclusion of the term *)
        val inst = CONV_RULE (RAND_CONV (REWR_CONV (GSYM predicate)))
                (HO_PART_MATCH (snd o dest_imp_only) induction ((rhs o concl) predicate))
                handle _ => raise (mkDebugExn "prove_split_term"
                                ("Term conclusion, " ^ (term_to_string o lhs o concl) predicate ^
                                 " does not match induction conclusion: " ^
                                        (term_to_string o snd o dest_imp_only o snd o strip_forall o concl) induction));

        (* [!a .. z. split_n (C a .. z) = body a .. z] |- split_n (C a .. z) = body a .. z *)
        val assums = map (SPEC_ALL o ASSUME) ((strip_conj o fst o dest_imp_only) term) handle e => wrap e

        (* Converts a rewrite with the assumption [split_n x = f_n a_n .. x] |- split_n x = f_n a_n .. x *)
        (* i.e. calls 'h x' of a split right-hand side h are rewritten back to 'split x' *)
        fun fix_rewrite thm =
        let     val terms = find_terms (fn t => (exists (curry op= (rator t) o snd o snd) equivs handle _ => false))
                        (rhs (concl thm));
                val rwrs = map (fn t => (list_mk_forall o (I ## (mk_eq o
                                (C (curry mk_comb) (rand t) ## C (curry mk_comb) (rand t)))))
                        (first (curry op= (rator t) o snd o snd) equivs)) terms

        in      PURE_REWRITE_RULE (map (SPEC_ALL o GSYM o ASSUME) rwrs) thm
        end     handle e => wrapException "prove_split_term (fix_rewrite)" e

        val clauses = (CONJUNCTS o SPEC_ALL) function
        val missing_exn =
                mkDebugExn "prove_split_term"
                "The function given does not exactly match the induction theorem"

        (* Rewrite theorems: should match the assumptions on the left and the antecedents of inst on the right *)
        val rewrites = map (GSYM o fix_rewrite o (fn x => tryfind_e missing_exn (C REWR_CONV x) clauses) o
                                rhs o snd o strip_forall o snd o strip_imp o snd o strip_forall)
                        ((strip_conj o fst o dest_imp_only o concl) inst)

        (* Rewritten assumptions using the rewrites, output should match inst *)
        val assums' = map (CONV_RULE (STRIP_QUANT_CONV (RAND_CONV (FIRST_CONV (map MATCH_CONV rewrites))))) assums
                handle _ =>     raise (mkDebugExn "prove_split_term"
                                "The term given does not match the induction theorem and function")

        (* Antecedents of the instantiation we wish to match *)
        val terms = (strip_conj o fst o dest_imp_only o concl) inst handle e => wrap e

        (* Extra theorems when P x is false (not required for encoding) *)
        (* Uses the first negated guard among inst's antecedents as an assumption *)
        fun mk_all_extra [] = []
          | mk_all_extra thms =
        case (total (first is_neg) (mapfilter (fst o dest_imp_only o snd o strip_forall) (strip_conj (fst (dest_imp_only (concl inst))))))
        of SOME assum => map (RIGHT_CONV_RULE (REPEATC (CHANGED_CONV (PURE_ONCE_REWRITE_CONV [dead_thm,ASSUME assum,COND_CLAUSES] THENC PURE_ONCE_REWRITE_CONV thms)))) thms
        |  NONE => []

        (* Chain each x with the first remaining y for which TRANS x y succeeds *)
        fun ttrans [] _ = []
          | ttrans (x::xs) ys =
        let val (y,ysr) = pluck (can (TRANS x)) ys
        in      TRANS x y::ttrans xs ysr end;

        (* Only built when some antecedent of inst is guarded by a negation *)
        val extra_assums =
                if can (tryfind (dest_neg o fst o dest_imp_only o snd o strip_forall)) terms
                then ttrans (mk_all_extra (filter (is_cond o rhs o snd o strip_forall o concl) assums))
                                (map SYM (mk_all_extra (map SPEC_ALL
                                        (filter (is_cond o rhs o snd o strip_forall o concl) clauses))))
                else []

        (* Dead value theorems (only used in making target functions) *)
        (* Hypotheses of assums' whose left-hand argument is dead_value *)
        val dead_terms = flatten (map (filter (fn x => (curry op= dead_value o rand o lhs o snd o strip_forall) x handle e => false) o hyp) assums');
        (* Prove each dead term by rewriting; any failure raises a debug exception *)
        val dead_thms = case (mappartition
                        (CONV_RULE bool_EQ_CONV o (REPEATC (ONCE_REWRITE_CONV (map GEN_ALL assums) THENC ONCE_REWRITE_CONV clauses THENC REWRITE_CONV [dead_thm]))) dead_terms)
                        of (x,[]) => x
                        |  (_,x::xs) => raise (mkDebugExn "prove_split_term"
                                ("Could not resolve the 'dead' term: " ^ (term_to_string x)))

        val final_assums = map (C (foldl (uncurry PROVE_HYP)) dead_thms) assums';

        (* Make sure we have exactly the same form as the term we were given *)
        val inst' = CONV_RULE (RAND_CONV (REWR_CONV (CONV_RULE bool_EQ_CONV (AC_CONV (CONJ_ASSOC,CONJ_COMM)
                (mk_eq((snd o dest_imp_only o concl) inst,(snd o dest_imp_only) term)))))) inst handle e => wrap e

        (* Final proof of the term *)
        (* Each antecedent of inst' is proved by MATCH_IND_TERM from some assumption; *)
        (* the clauses of the term are then discharged and merged into one antecedent *)
        val final = PURE_REWRITE_RULE [AND_IMP_INTRO,GSYM CONJ_ASSOC] (foldr (uncurry DISCH) (MP inst' (LIST_CONJ (map (fn t =>
                        tryfind_e (     mkDebugExn "prove_split_term"
                                        ("No matching antecedent found to match induction term " ^ term_to_string t))
                        (MATCH_IND_TERM t) (final_assums @ extra_assums)) terms)))
                ((strip_conj o fst o dest_imp_only) term))
in
        (* Any remaining hypothesis means the proof is incomplete *)
        if null (hyp final) then final else
                raise (mkDebugExn "prove_split_term" ("The exception: " ^ term_to_string (hd (hyp final)) ^ " exists in the final proof!"))

end
3131
local
(* Wrap an exception with "prove_all_split_terms", optionally followed by   *)
(* the name 's' of the helper that raised it.                               *)
fun wrap "" e = wrapException "prove_all_split_terms" e
  | wrap s e = wrapException ("prove_all_split_terms (" ^ s ^ ")") e

(* Prove a single split term by strengthening then inductive proof           *)
(* 'h' is the split term (a hypothesis of the theorem) and 't' its type.    *)
(* strengthen_proof_term gives |- (defs' ==> eqs') ==> h; the definition    *)
(* part of the antecedent is converted with create_conv, the strengthened   *)
(* term is proved with prove_split_term, and MP yields |- h.                *)
fun full_prove_split_term (get_ind,get_def,conv,create_conv,dead_thm,dead_value) t h =
let     val functions = flatten (map (map (CONV_RULE conv o get_def) o fst o snd)
                                (split_nested_recursive_set (base_type t)));
        val sh = strengthen_proof_term (CONV_RULE conv (get_def (mk_prod(alpha,beta))) :: functions) h
                handle e => raise (mkDebugExn "prove_all_split_terms (full_prove_split_term)"
                        ("Could not strengthen uniqueness proof: " ^ (term_to_string h) ^
                        "\nusing the function set: " ^ xlist_to_string thm_to_string functions ^
                        "\noriginal exception: " ^ exn_to_string e));
        val sh' = CONV_RULE (LAND_CONV (LAND_CONV create_conv)) sh
        val function = expanded_function_def conv create_conv get_def t
                                ((map (rhs o snd o strip_forall) o strip_conj o snd o dest_imp_only) h)
                handle e => wrap "full_prove_split_term" e
        val (induction,mapping) = get_ind t handle e => wrap "full_prove_split_term" e
        val th = prove_split_term mapping induction function (dead_thm,dead_value) ((fst o dest_imp_only o concl) sh')
                handle e => raise (mkDebugExn "prove_all_split_terms (full_prove_split_term)"
                        ("Could not prove uniqueness proof: " ^ ((term_to_string o fst o dest_imp_only o concl) sh') ^
                         "\nusing the expanded function definition: " ^ (thm_to_string function) ^
                         "\nand the induction theorem: " ^ (thm_to_string induction) ^
                         "\noriginal exception: " ^ exn_to_string e));
in
        MP sh' th handle e => raise (mkDebugExn "prove_all_split_terms (full_prove_split_term)"
                ("Proof returned by 'prove_split_term' does not exactly match its input term, " ^
                 "\ninput term: " ^ ((term_to_string o fst o dest_imp_only o concl) sh') ^
                 "\noutput thm: " ^ (thm_to_string th)))
end

(* Find the type of a term in the match list                                 *)
(* Looks up each right-hand side of the term's conclusion in 'list' and     *)
(* returns the first type found; raises Empty when none is present.         *)
fun get_type term list =
let     val rs = guarenteed (map (rhs o snd o strip_forall) o strip_conj o snd o dest_imp_only) term
in      tryfind_e Empty (C assoc list) rs
end
in
(* prove_all_split_terms gets matches thm                                    *)
(*     Proves every split term (implication hypothesis) of 'thm'. Terms     *)
(*     whose type is found in 'matches' are proved first; their equalities  *)
(*     are then substituted into 'matches' so that further terms can be     *)
(*     typed. Returns (proofs, thm with the proved hypotheses discharged).  *)
fun prove_all_split_terms gets matches thm =
let     val _ = type_trace 3 "->prove_all_split_terms\n"
        val terms = filter is_imp_only (hyp thm)

        val _ = map (fn term => assert "prove_all_split_terms" [
                (("Proof term: " ^ term_to_string term ^
                 " is not an implication from a function definition to a conjunction of function equalities"),
                 is_implication_of (is_conjunction_of (is_eq o snd o strip_forall))
                        (fn x => (is_eq o snd o strip_forall) x
                                andalso (can dom_rng o type_of o lhs o snd o strip_forall) x))] term) terms

        fun do_all matches [] = (type_trace 1 "0\n" ; [])
          | do_all matches terms =
        let     val (found,notfound) = mappartition (fn t => (get_type t matches,t)) terms
                val done = map (uncurry (full_prove_split_term gets)) found
                val _ = type_trace 1 (int_to_string (length terms) ^ "-")
                (* No progress: raise Empty rather than recurse on the same terms *)
                val _ = hd done
                (* Replace each proved function by its right-hand side in 'matches' *)
                val rwrs = map (op|-> o uncurry (C pair) o dest_eq o snd o
                        strip_forall o snd o dest_imp_only o concl) done
        in
                done @ do_all (map (subst rwrs ## I) matches) notfound
        end;

        val proofs =
                if null terms then [] else
                (type_trace 1 "Proving uniqueness terms: " ;
                 do_all matches terms
                        handle Empty =>
                        raise (mkDebugExn "prove_all_split_terms"
                                ("The type of one or more of the uniqueness proofs: " ^
                                 xlist_to_string term_to_string terms ^
                                 "\ncould not be matched to the list: " ^
                                 xlist_to_string (xpair_to_string term_to_string type_to_string) matches)))
in
        (proofs,foldl (uncurry PROVE_HYP_CHECK) thm proofs)
        handle e => wrap "" e
end
end
3207
local
(* Adds remove_hyp_terms, optionally qualified by the name of the helper that
   failed, to the trace of an exception and re-raises it via wrapException. *)
fun wrap "" e = wrapException "remove_hyp_terms" e
  | wrap s e = wrapException ("remove_hyp_terms (" ^ s ^ ")") e

(* filter_fold f a xs
     Performs a fold, but retains lists of passes and failures.
     Folds f over xs starting from accumulator a. An element for which f
     raises Empty is skipped (the accumulator is left unchanged) and recorded
     as a failure. Returns (final accumulator,(passes,failures)), both lists
     in the order the elements appear in xs. Only Empty is caught; any other
     exception raised by f propagates. *)
fun filter_fold f a [] = (a,([],[]))
  | filter_fold f a (x::xs) = (I ## (cons x ## I)) (filter_fold f (f x a) xs)
        handle Empty => (I ## (I ## cons x)) (filter_fold f a xs)

(* Given a pair thm: 'split (a,b) = f a b' returns a rewrite 'split = f'.
   fix_pair1 handles the case where the argument of the left-hand side is a
   pair: it folds the right-hand side back with GSYM pair_thm, generalises
   over the pair components and uses the pair induction theorem from
   TypeBase together with FUN_EQ_THM to obtain an equation between functions. *)
fun fix_pair1 pair_thm thm =
let     val thm' = SPEC_ALL (CONV_RULE (STRIP_QUANT_CONV (RAND_CONV (REWR_CONV (GSYM pair_thm)))) thm)
        val thm'' = GENL ((strip_pair o rand o rhs o concl) thm') thm'
in
        MATCH_MP (snd (EQ_IMP_RULE (SPEC_ALL FUN_EQ_THM))) (MP (HO_PART_MATCH (fst o dest_imp_only)
                (TypeBase.induction_of (mk_prod(alpha,beta))) (concl thm'')) thm'')
end handle e => wrap "fix_pair1" e
(* As fix_pair1, for a non-pair argument: the generalised theorem is turned
   into an equation between functions directly by rewriting with the
   reversed FUN_EQ_THM. *)
fun fix_pair2 pair_thm thm =
let     val thm' = SPEC_ALL (CONV_RULE (STRIP_QUANT_CONV (RAND_CONV (REWR_CONV (GSYM pair_thm)))) thm)
        val thm'' = GENL ((strip_pair o rand o rhs o concl) thm') thm'
in
        CONV_RULE (REWR_CONV (GEN_ALL (SYM (SPEC_ALL FUN_EQ_THM)))) thm''
end handle e => wrap "fix_pair2" e
(* Dispatches to fix_pair1 when the argument of the (quantifier-stripped)
   left-hand side of thm has a type matching 'a # 'b, else to fix_pair2. *)
fun fix_pair pair_thm thm =
        if (can (match_type (mk_prod(alpha,beta))) ((type_of o rand o lhs o snd o strip_forall o concl) thm))
        then    fix_pair1 pair_thm thm
        else    fix_pair2 pair_thm thm

(* PROVE_HYP_CONJ thm1 thm2
     Looks for a hypothesis of thm2 equal to concl thm1 modulo associativity
     and commutativity of conjunction, converts thm1 to exactly that
     hypothesis and discharges it from thm2 with PROVE_HYP_CHECK.
     Raises Empty when the converted conclusion is not a hypothesis of thm2.
     tryfind_e Empty presumably also raises Empty when no hypothesis matches;
     TODO confirm. filter_fold relies on Empty to record thm1 as a failure.
     Every other exception is wrapped. *)
fun PROVE_HYP_CONJ thm1 thm2 =
let     val thm' = EQ_MP (CONV_RULE bool_EQ_CONV
                        (tryfind_e Empty (AC_CONV (CONJ_ASSOC,CONJ_COMM) o curry mk_eq (concl thm1)) (hyp thm2))) thm1
in
        if mem (concl thm') (hyp thm2) then PROVE_HYP_CHECK thm' thm2 else raise Empty
end     handle Empty => raise Empty | e => wrap "PROVE_HYP_CONJ" e
(* remove_hyp_terms_pre min pair_thm proofs (mthms,thm)
     Repeatedly discharges hypotheses of thm using the theorems in mthms.
     The theorems that were used are turned into rewrites: pair rewrites via
     fix_pair, and split-term rewrites via MP with one of proofs. These
     rewrites are applied to the remaining mthms and to the hypotheses of thm,
     leaving the first hypothesis of the first mthm untouched. The function
     then recurses. It stops when nothing more can be discharged, or when no
     theorems remain. At that point at least length mthms - min theorems must
     have been used in this last round, otherwise a Debug exception is
     raised. *)
fun remove_hyp_terms_pre min pair_thm proofs (mthms,thm) =
let     val _ = type_trace 3 "->remove_hyp_terms\n"
        val to_remove = length mthms - min
        val _ = if to_remove = 0 then type_trace 1 "0\n" else type_trace 1 (int_to_string to_remove ^ "-")
        val (thm',(removed,kept)) = filter_fold PROVE_HYP_CONJ thm mthms
        val pair_rewrites = mapfilter (GEN_ALL o fix_pair pair_thm) removed
        val nonpair_rewrites = mapfilter (fn m => tryfind (C MP m) proofs) removed
        val _ = if length pair_rewrites + length nonpair_rewrites = length removed then ()
                else raise (mkDebugExn "remove_hyp_terms"
                        "Not all the hypotheses removed could be matched to a pair_theorem or a proved split term")
        fun conv term =
                if term = hd (hyp (hd mthms)) then ALL_CONV term
                else (PURE_REWRITE_CONV pair_rewrites THENC PURE_REWRITE_CONV nonpair_rewrites) term
in
        if null removed orelse null kept then
                (* Return thm', not thm. When nothing was removed the two are
                   the same theorem. When everything was removed, thm' is the
                   only theorem that keeps the discharges just performed. *)
                if to_remove <= length removed then thm'
                else raise (mkDebugExn "remove_hyp_terms"
                        ("The terms: " ^ xlist_to_string (term_to_string o concl) kept ^
                         " do not match terms in the hypothesis set"))
        else
        remove_hyp_terms_pre min pair_thm proofs (map (CONV_RULE conv) kept,PROVE_HYP TRUTH (CONV_HYP conv thm'))
end
in
(* remove_hyp_terms pair_thm proofs (mthms,thm)
     Removes from thm the hypotheses that can be discharged by the theorems
     in mthms. If there is one mthm per distinct function defined in
     concl thm, they are all discharged directly with PROVE_HYP. Otherwise the
     iterative remove_hyp_terms_pre is used, with the number of those
     functions as the count of theorems expected to remain unused. *)
fun remove_hyp_terms pair_thm proofs ([],thm) = thm
  | remove_hyp_terms pair_thm proofs (mthms,thm) =
let     val total = (length o mk_set o map (repeat rator o lhs o snd o strip_forall) o strip_conj o concl) thm
in
        if length mthms = total then
                foldl (uncurry PROVE_HYP) thm mthms
        else
                (type_trace 1 "Removing splits: " ;
                remove_hyp_terms_pre total pair_thm proofs (mthms,thm))
end
end;
3276
local
(* Applies the substitution subs to term repeatedly until the term no longer
   changes. NOTE(review): this may not terminate if subs is cyclic, for
   example [x |-> y, y |-> x]. *)
fun full_subst subs term =
        if subst subs term = term then term else full_subst subs (subst subs term)
in
(* match_mapping ethm mapping get_func pair_def t
     Rewrites the second component of every (a,b) in mapping. The rewriting
     uses equations read off the hypotheses of ethm:
       eq_fns1 - equations l = r found under the universal quantifiers of the
                 consequent of implicational hypotheses;
       eq_fns2 - for hypotheses whose quantified right-hand side can be folded
                 with GSYM pair_def, the equation between the heads (rator) of
                 the two sides.
     Each rewritten function term is then paired with a type from find_type.
     The result is a list of (a,(b',type)). *)
fun match_mapping ethm mapping get_func pair_def t =
let     val all_types = flatten (map (op:: o (I ## op@)) (split_nested_recursive_set t))
        val alist = map (fn t => (get_func t,t)) all_types
        val eq_fns1 = mapfilter (dest_eq o snd o strip_forall o snd o dest_imp_only) (hyp ethm)
        val eq_fns2 = mapfilter ((rator ## rator) o dest_eq o snd o strip_forall o rhs o
                        concl o STRIP_QUANT_CONV (RAND_CONV (REWR_CONV (GSYM pair_def)))) (hyp ethm)
        val mapping' = map ((I:term -> term) ## full_subst (map op|-> (eq_fns1 @ eq_fns2))) mapping

        (* Head of the left-hand side of pair_def: the pair function itself. *)
        val pt = (rator o lhs o snd o strip_forall o concl) pair_def

        (* Type handled by func. If func is not one of the get_func results
           and is an instance of the pair function, the type is the product of
           the types of its arguments (arguments whose type cannot be found are
           dropped by mapfilter). Otherwise a Debug exception is raised. *)
        fun find_type func =
                case (assoc1 func alist)
                of NONE => if can (match_term pt) func
                                then list_mk_prod(mapfilter find_type (snd (strip_comb func)))
                                else raise (mkDebugExn "match_mapping"
                                                ("Could not find type for function: " ^ term_to_string func))
                |  SOME (a,t) => t
in
        map (fn (a,b) => (a,(b,find_type b))) mapping'
end
end
3302
local
(* Adds unsplit_function to the trace of an exception and re-raises it. *)
fun wrap e = wrapException "unsplit_function" e
(* Error text used when the mutual recursion theorem has the wrong shape;
   the offending theorem is appended to the message. *)
fun err1 thm = "Mutual recursion theorem must be of the form: \n" ^
        "|- ?fn0 ... fnK. (!a... fn0 ... = A0) /\\ ... /\\ (!a ... fnK ... = AK)\n" ^
        "theorem supplied has the form: \n" ^
        thm_to_string thm
(* wrap_ind get_ind t
     Calls get_ind t and checks the returned induction theorem. It must:
       - have no hypotheses;
       - quantify only predicates whose type matches :'a -> bool;
       - have one predicate per type in the nested recursive set of t;
       - be an implication;
       - conclude exactly !x. P x for each quantified predicate P.
     Raises a Standard exception describing the first failed check;
     otherwise returns the theorem unchanged.
     NOTE(review): wrap_ind is not referenced anywhere in the visible code. *)
fun wrap_ind get_ind t =
let     fun mkExn s = raise (mkStandardExn "get_ind (supplied function)" s)
        val result = get_ind t
        val all_types = flatten (map (op:: o (I ## fst)) (split_nested_recursive_set (base_type t)))
        val preds = fst (strip_forall (concl result))
        val _ = if null (hyp result) then ()
                else mkExn ("Induction theorem returned contains a non-empty hypothesis set")
        val _ = if all (can (match_type (alpha --> bool)) o type_of) preds then ()
                else mkExn ("Not all predicates of returned induction theorem are of type :'a -> bool")
        val _ = if length preds = length all_types then ()
                else mkExn ("Induction theorem specifies " ^ int_to_string (length preds) ^
                            " predicates but type " ^ type_to_string t ^ " is a set of " ^
                            int_to_string (length all_types) ^ " mutually recursive types")
        val _ = if is_imp_only (snd (strip_forall (concl result))) then ()
                else mkExn ("Induction theorem returned is not an implication: " ^ thm_to_string result)
        val (hyps,conc) = (strip_conj ## strip_conj) (dest_imp_only (snd (strip_forall (concl result))))
        val my_conc = map (fn p => mk_forall(mk_var("x",fst (dom_rng (type_of p))),
                                mk_comb(p,mk_var("x",fst (dom_rng (type_of p)))))) preds
        val _ = if all (fn c => exists (aconv c) my_conc) conc andalso all (fn c => exists (aconv c) conc) my_conc then ()
                else mkExn ("Conclusion of induction theorem does not use exactly the predicates: " ^
                            xlist_to_string term_to_string my_conc)
in
        result
end
in
(* unsplit_function get_ind get_def get_func conv create_conv
                    (dead_thm,dead_value) t (mthm,ethm)
     mthm must be an existential theorem whose body is a conjunction of
     universally quantified equations, one group per existentially bound
     function; this is checked via assert. The steps are:
       1. assume that body;
       2. group its conjuncts per defined function (mthms);
       3. rewrite the assumed body with GSYM ethm;
       4. prove the split terms with prove_all_split_terms;
       5. remove the resulting hypotheses with remove_hyp_terms;
       6. check that exactly one hypothesis is left;
       7. existentially quantify the functions defined in the conclusion and
          eliminate the remaining assumption against mthm with CHOOSE_L. *)
fun unsplit_function get_ind get_def get_func conv create_conv (dead_thm,dead_value) t (mthm,ethm) =
let     val _ = type_trace 2 "->unsplit_function\n"
        val mterm = (snd o strip_exists) (
                assert "unsplit_function" [
                        (err1 mthm,boolSyntax.is_exists),
                        (err1 mthm,is_conjunction_of (is_eq o snd o strip_forall) o snd o strip_exists),
                        (err1 mthm,fn t => set_eq       ((map (repeat rator o lhs o snd o strip_forall) o
                                                        strip_conj o snd o strip_exists) t)
                                                ((fst o strip_exists) t))] (concl mthm));
        val mthms = map (LIST_CONJ o snd)
                        (bucket_alist (map (fn x => ((repeat rator o lhs o snd o strip_forall o concl) x,x))
                                (CONJUNCTS (ASSUME mterm)))) handle e => wrap e
        val thm = CONV_RULE (REWR_CONV (GSYM ethm)) (ASSUME mterm)
                        handle e => raise (mkDebugExn "unsplit_function"
                                "Equivalence theorem does not match theorem of mutual recursion")
        val prod = pairLib.mk_prod(alpha,beta) handle e => wrap e
        val htypes = mk_set (filter (not o can (match_type prod))
                (flatten (map (fst o snd) (split_nested_recursive_set t))))
        val pair_thm = get_def prod handle e => wrap e
        val matches = zip (map get_func htypes) htypes handle e => wrap e
        val (proofs,thm') = prove_all_split_terms (get_ind,get_def,conv,create_conv,dead_thm,dead_value) matches thm handle e => wrap e
        val thm'' = remove_hyp_terms pair_thm proofs (mthms,thm') handle e => wrap e
        (* The only hypothesis that may remain is the assumed body of mthm.
           The check also fails when no hypothesis is left, so the message
           reports the actual count rather than claiming "more than one". *)
        val _ = if length (hyp thm'') = 1 then () else
                        raise (mkDebugExn "unsplit_function"
                        ("remove_hyp_terms returned a theorem with " ^
                         int_to_string (length (hyp thm'')) ^
                         " hypotheses (exactly one was expected)"))
        val hyp_vars = fst (strip_exists (concl mthm)) handle e => wrap e
        val thm_vars = mk_set (map (repeat rator o lhs o snd o strip_forall) (strip_conj (concl thm'')))
                handle e => wrap e
in
        CHOOSE_L (hyp_vars,mthm) (foldl (uncurry SIMPLE_EXISTS) thm'' thm_vars)
        handle e => wrap e
end
end
3367
local
(* complete_function name mk_term get_ind get_def get_func conv create_conv
                     (dead_thm,dead_value) t (mthm,ethm)
     Proves the unsplit existence theorem with unsplit_function and
     introduces its functions as constants with new_specification, under the
     definition name  name ^ "_" ^ (type operator name of t).
     Returns one entry per type in the nested recursive set of t: the type
     paired with assoc1 of its constant in the definition's conjuncts. Those
     conjuncts are grouped by the function they define.
     The constant for each type has the name and type of the head variable
     returned by get_func. The mk_term argument is not used here. *)
fun complete_function name mk_term get_ind get_def get_func conv create_conv (dead_thm,dead_value) t (mthm,ethm) =
let     val unsplit = unsplit_function get_ind get_def get_func conv create_conv (dead_thm,dead_value) t (mthm,ethm)
        val all_types = map fst (split_nested_recursive_set t)
        val func_names = map (fst o dest_var o repeat rator o get_func) all_types
        val def = new_specification (name ^ "_" ^ (fst (dest_type t)),map (fst o dest_var) (fst (strip_exists (concl unsplit))),unsplit)
        val all_theorems = map (I ## LIST_CONJ) (bucket_alist
                        (map (fn x => ((repeat rator o lhs o snd o strip_forall o concl) x,x)) (CONJUNCTS def)))
        val all_consts = map2 (curry mk_const) func_names (map (type_of o repeat rator o get_func) all_types)
in
        map2 (fn t => fn ac => (t,assoc1 ac all_theorems)) all_types all_consts
end
(* check_defs func get_def t
     Returns () when get_def succeeds for every non-variable component type
     of the nested recursive set of t. Otherwise raises a Standard exception,
     tagged with func, naming the first component type without a definition.
     first_e presumably raises Empty when no such type exists, and that
     Empty is handled as success. *)
fun check_defs func get_def t =
let     val all_types = split_nested_recursive_set t
        val required = filter (not o is_vartype) (flatten (map (op@ o snd) all_types))
in
        (raise (mkStandardExn func
                ("Can't create function for type " ^ type_to_string t ^ " as this is dependent upon type " ^
                type_to_string (first_e Empty (not o can get_def) required) ^
                " for which no function is returned by get_def"))) handle Empty => ()
end
(* store_funcs name store err results
     For each (type,SOME (const,thm)) overloads name on const and calls
     store type (const,thm). Raises a Debug exception on (type,NONE).
     NOTE(review): the handler sits on the recursive clause, so an exception
     raised at the k-th element is wrapped with err once per element that
     was stored before it. *)
fun store_funcs name store err [] = ()
  | store_funcs name store err ((t,NONE)::xs) = raise (mkDebugExn err ("Functions were not created for type: " ^ type_to_string t))
  | store_funcs name store err ((t,SOME x)::xs) = (overload_on(name,(fst x)) ; store t x ; store_funcs name store err xs) handle e => wrapException err e;
(* get_source_ind get_func t
     Returns the TypeBase induction theorem for t together with a mapping
     built by zip_on_types. The mapping pairs each quantified predicate with
     (get_func ty,ty) for the domain type ty of a predicate. Presumably the
     predicate and the entry for its own domain type are matched up; TODO
     confirm against zip_on_types. *)
fun get_source_ind get_func t =
let     val thm = TypeBase.induction_of t
in
        (thm,zip_on_types (fst o dom_rng o type_of) snd
                ((fst o strip_forall o concl) thm) (map (fn t => (get_func t,t))
                        ((map (fst o dom_rng o type_of) o fst o strip_forall o concl) thm)))
end
(* check_const result tm
     If tm is a constant it is returned as is. Otherwise returns the constant
     from result (the output of complete_function) with the same name and
     type as the variable tm. On any failure tm itself is returned. *)
fun check_const result tm =
        if is_const tm then tm
        else
                fst (valOf (snd (first (fn (_,(SOME (c,_))) =>
                        (fst (dest_const c) = fst (dest_var tm)) andalso
                        (type_of c = type_of tm) | _ => false) result))) handle e => tm
(* fix_ind result (thm,mapping)
     Replaces the head of every function term in mapping by the matching
     constant created in result (via check_const), so the stored induction
     mapping refers to the defined constants rather than to variables. *)
fun fix_ind result (thm,mapping) =
        (thm,map (fn (P,(tm,t)) =>
                (P,(list_mk_comb((check_const result ## I) (strip_comb tm)),t))) mapping)
in
(* mk_source_functions name mk_term get_func conv create_conv t
     Defines and stores the source functions called name for t and for the
     other types of its nested recursive set. TRUTH and ARB are used as the
     dead theorem/value. The stored induction is the TypeBase induction for
     t with its mapping rewritten to the new constants. *)
fun mk_source_functions name mk_term get_func conv create_conv t =
let     val get_def = C get_source_function_def name
        fun wrap e = wrapException "mk_source_functions" e
        val _ = check_defs "mk_source_functions" get_def t
        val (mthm,ethm) = mk_split_source_function mk_term get_def get_func conv create_conv t handle e => wrap e
        val ind = get_source_ind get_func t
        val result = complete_function name mk_term (C get_source_function_induction name)
                        get_def get_func conv create_conv (TRUTH,mk_arb alpha) t (mthm,ethm) handle e => wrap e
in
        store_funcs name (fn t => fn (c,d) => add_source_function t name {const = c,definition = d,induction = SOME (fix_ind result ind)})
                "mk_source_functions (store)" result

end
(* mk_coding_functions name mk_term get_func conv create_conv target t
     As mk_source_functions, but for the coding functions of translation
     target. The split function is still built with
     mk_split_source_function. The dead theorem/value come from the
     translation scheme's bottom_thm/bottom, and the results are stored
     with add_coding_function. *)
fun mk_coding_functions name mk_term get_func conv create_conv target t =
let     val get_def = C (get_coding_function_def target) name
        fun wrap e = wrapException "mk_coding_functions" e
        val _ = check_defs "mk_coding_functions" get_def t
        val (mthm,ethm) = mk_split_source_function mk_term get_def get_func conv create_conv t handle e => wrap e
        val ind = get_source_ind get_func t
        val dead_thm = #bottom_thm (get_translation_scheme target)
        val dead_value = #bottom (get_translation_scheme target)
        val result = complete_function name mk_term (C (get_coding_function_induction target) name)
                        get_def get_func conv create_conv (dead_thm,dead_value) t (mthm,ethm) handle e => wrap e
in
        store_funcs name (fn t => fn (term,thm) => add_coding_function target t name
                        {const = term,definition = thm,induction = SOME (fix_ind result ind)})
                "mk_coding_functions (store)"
                result
end
(* mk_target_functions name mk_term get_func conv create_conv target t
     As mk_coding_functions, but the split function is built with
     mk_split_target_function. That also returns an induction theorem and a
     mapping; the mapping is completed with match_mapping (using the
     conv-converted pair definition) and stored as the induction. *)
fun mk_target_functions name mk_term get_func conv create_conv target t =
let     val get_def = C (get_coding_function_def target) name
        val get_ind = C (get_coding_function_induction target) name
        fun wrap e = wrapException "mk_target_functions" e
        val _ = check_defs "mk_target_functions" get_def t
        val dead_thm = #bottom_thm (get_translation_scheme target)
        val dead_value = #bottom (get_translation_scheme target)
        val ((ithm,mapping,mthm),ethm) = mk_split_target_function mk_term get_def get_func conv create_conv
                                                (get_translation_scheme target) t handle e => wrap e
        val complete_mapping = match_mapping ethm mapping get_func (CONV_RULE conv (get_def (mk_prod(alpha,beta)))) t
        val result = complete_function name mk_term get_ind get_def get_func conv create_conv (dead_thm,dead_value) t (mthm,ethm) handle e => wrap e
in
        store_funcs name (fn t => fn (term,thm) => add_coding_function target t name
                                {const = term,definition = thm,induction = SOME (fix_ind result (ithm,complete_mapping))})
                "mk_target_functions (store)"
                result
end
end
3456
3457
3458(*****************************************************************************)
3459(* Function generators                                                       *)
3460(*****************************************************************************)
3461
fun add_compound_coding_function_generator name mk_term get_func conv create_conv target =
let     (* Generator for a type t that has constructors. It first generates the
           coding function for every non-variable component type of t's nested
           recursive set, then builds and stores the functions for t itself.
           It returns the stored coding function for t. *)
        fun generate t =
        let     val components = flatten (map (op@ o snd) (split_nested_recursive_set t))
                val dependencies = filter (not o is_vartype) components
                val _ = map (generate_coding_function target name o base_type) dependencies
                val _ = mk_coding_functions name mk_term get_func conv create_conv target t
        in
                get_coding_function_precise target t name
        end
in
        add_coding_function_generator target name (can TypeBase.constructors_of) generate
end;
fun add_compound_target_function_generator name mk_term get_func conv create_conv target =
let     (* Generator for a type t that has constructors. It first generates the
           coding function for every non-variable component type of t's nested
           recursive set, then builds and stores the target functions for t
           with mk_target_functions. It returns the stored coding function
           for t. *)
        fun generate t =
        let     val components = flatten (map (op@ o snd) (split_nested_recursive_set t))
                val dependencies = filter (not o is_vartype) components
                val _ = map (generate_coding_function target name o base_type) dependencies
                val _ = mk_target_functions name mk_term get_func conv create_conv target t
        in
                get_coding_function_precise target t name
        end
in
        add_coding_function_generator target name (can TypeBase.constructors_of) generate
end;
fun add_compound_source_function_generator name mk_term get_func conv create_conv =
let     (* Generator for a type t that has constructors. It first generates the
           source function for every non-variable component type of t's nested
           recursive set, then builds and stores the functions for t itself.
           It returns the stored source function for t. *)
        fun generate t =
        let     val components = flatten (map (op@ o snd) (split_nested_recursive_set t))
                val dependencies = filter (not o is_vartype) components
                val _ = map (generate_source_function name o base_type) dependencies
                val _ = mk_source_functions name mk_term get_func conv create_conv t
        in
                get_source_function_precise t name
        end
in
        add_source_function_generator name (can TypeBase.constructors_of) generate
end;
3492
3493(*****************************************************************************)
3494(* Polytypic inductive proofs about functions                                *)
3495(*                                                                           *)
(* make_predicate_map : thm -> (term * term list) list                       *)
(*     Given an induction theorem, returns a mapping from predicates to the  *)
(*     predicates it relies upon.                                            *)
(*                                                                           *)
(* prove_inductive_source_theorem[_precise] : string -> string ->            *)
(*            (hol_type -> term) -> hol_type -> (term -> thm) ->             *)
(*            (hol_type -> tactic) -> unit                                   *)
(* prove_inductive_coding_theorem[_precise] : string -> string ->            *)
(*            (hol_type -> term) -> hol_type -> hol_type -> (term -> thm) -> *)
(*            (hol_type -> tactic) -> unit                                   *)
(*    Given the name of the function the induction is based around, the name *)
(*    of the theorem being proved, a function to generate conclusions, the   *)
(*    main type, a conversion that puts the conclusions into a form suitable *)
(*    for induction, and a tactic, these functions prove the conclusions for *)
(*    the given type by induction.                                           *)
3511(*                                                                           *)
3512(* prove_source_theorem[_precise]                                            *)
3513(*                      : string -> (hol_type -> term) -> hol_type ->        *)
3514(*                        (term -> thm) -> (hol_type * hol_type list ->      *)
3515(*                                          thm list -> tactic) -> unit      *)
3516(* prove_coding_theorem : string -> (hol_type -> term) -> hol_type ->        *)
3517(*               hol_type -> (term -> thm) -> (hol_type * hol_type list ->   *)
3518(*                                          thm list -> tactic) -> unit      *)
3519(*    Simply prove functions given a function to generate conclusions, a     *)
3520(*    conversion and a tactic.                                               *)
3521(*                                                                           *)
3522(*****************************************************************************)
3523
local
(* Splits an implication h1 /\ ... /\ hn ==> c into ([h1,...,hn],c).
   A term that is not an implication is returned as ([],term). *)
fun mstrip_imp term =
        if not (is_imp_only term) then ([],term)
        else let val (antecedent,consequent) = dest_imp_only term
             in (strip_conj antecedent,consequent) end;
in
(* make_predicate_map induction
     Takes each clause of the antecedent of the induction theorem. It pairs
     the predicate applied in the clause's conclusion with the quantified
     predicates applied in the clause's hypotheses. The pairs are grouped by
     conclusion predicate (bucket_alist), and each group's lists of required
     predicates are concatenated. *)
fun make_predicate_map induction =
let     val predicates = fst (strip_forall (concl induction))
        val clauses = (strip_conj o fst o dest_imp_only o snd o strip_forall o concl) induction
        fun clause_entry clause =
        let     val (assums,consequent) = mstrip_imp (snd (strip_forall clause))
                val required = filter (C mem predicates) (mapfilter rator assums)
        in
                (rator (snd (strip_forall consequent)),required)
        end
in
        map (I ## flatten) (bucket_alist (map clause_entry clauses))
end
end
3537
(* Rebuilds t bottom-up. Any (sub)type that matches an element of rset
   (tested with op_mem using match_type) is replaced by a fresh type
   variable. Type variables are returned unchanged. *)
fun delete_matching_types rset t =
let     fun matches a b = can (match_type a) b
in
        if op_mem matches t rset then gen_tyvar()
        else if can dest_type t then
                let     val (tyop,args) = dest_type t
                in      mk_type (tyop,map (delete_matching_types rset) args)
                end
        else t
end
3542
(* The non-variable types reachable from t, including t itself, as given by
   reachable_graph uncurried_subtypes, with duplicates removed. *)
fun all_types t =
let     val reachable = map snd (reachable_graph uncurried_subtypes t)
in
        filter (not o is_vartype) (mk_set (t :: reachable))
end;
3544
(* The types reachable from t with every occurrence of a type from t's
   nested recursive set replaced by a fresh type variable. Types that
   collapse to a type variable are dropped. *)
fun relevant_types t =
let     val candidates = all_types t
        val recursive_set = map fst (split_nested_recursive_set t)
        val stripped = map (delete_matching_types recursive_set) candidates
in
        filter (not o is_vartype) stripped
end
3551
local
(* check_concs targ target concs
     Checks that each generated conclusion theorem (the innermost thm of
     every entry) has a right-hand side whose first universally quantified
     variable has type target when targ is true, or type t (the entry's own
     type) otherwise. Raises a Standard exception on the first mismatch. *)
fun check_concs targ target [] = ()
  | check_concs targ target ((_,(_,(t,thm)))::rest) =
let     val var = type_of (hd (fst (strip_forall (rhs (concl thm)))))
in
        if (not targ andalso (var = t)) orelse (targ andalso (var = target))
                then check_concs targ target rest
                else raise (mkStandardExn "inductive_proof"
                        ("Conclusion returned does not match the form: \"!a" ^
                         type_to_string (if targ then target else t) ^ ".P a\""))
end
(* mk_thm (induction,mapping) mk_conc conv
     For every (P,(func,ty)) in mapping, builds the conclusion theorem
     UNDISCH_ALL_EQ (conv (mk_conc ty)) and checks its form. The conclusion
     theorems are conjoined in the order of the induction theorem's
     predicates (ithm). The induction theorem is then instantiated against
     the right-hand side of ithm, with its antecedent clauses and the
     hypotheses of ithm undischarged.
     Returns (all_concs,ithm,instantiated induction theorem). *)
fun mk_thm (induction,mapping) mk_conc conv =
let     val all_concs = map (I ## (I ## (fn t => (t,UNDISCH_ALL_EQ (conv (mk_conc t)))))) mapping
        val _ = type_trace 3 ("Conclusions:\n" ^ xlist_to_string
                        (thm_to_string o snd o snd o snd) all_concs ^ "\n")
        val preds = map (rator o snd o strip_forall)
                        ((strip_conj o snd o dest_imp_only o
                                snd o strip_forall o concl) induction)
        val all_types = map (fst o dom_rng o type_of) preds;
        val _ = check_concs (length (mk_set all_types) = 1) (hd (all_types)) all_concs
        val ithm = LIST_MK_CONJ (map (fn p => snd (snd (assoc p all_concs))) preds)
in
        (all_concs,ithm,foldl (fn (a,b) => UNDISCH_ONLY (DISCH a b))
                        (UNDISCH_ONLY (repeat (UNDISCH_ONLY o CONV_RULE
                                        (REWR_CONV (GSYM AND_IMP_INTRO)))
                        (HO_PART_MATCH (snd o dest_imp_only) induction
                                ((rhs o concl) ithm)))) (hyp ithm))
end
(* Sets up an interactive proofManagerLib goal for the same statement that
   proveit proves automatically. The induction is applied, leaving one
   subgoal per clause. *)
fun mkgoal (induction,mapping) mk_conc conv =
let     val (all_concs,ithm,thm) = mk_thm (induction,mapping) mk_conc conv
in
        (proofManagerLib.set_goal (hyp ithm,(lhs o concl) ithm) ;
         proofManagerLib.expand(
                MATCH_MP_TAC (snd (EQ_IMP_RULE ithm)) THEN
                MATCH_MP_TAC (foldl (fn (x,t) =>
                        CONV_RULE (REWR_CONV AND_IMP_INTRO) (DISCH x t))
                                (DISCH (hd (hyp thm)) thm) (tl (hyp thm))) THEN
                REPEAT CONJ_TAC))
end
(* proveit (induction,mapping) mk_conc conv tactic get_theorem
     Proves every clause of the instantiated induction theorem with
     tactic ty, where ty is the type whose conclusion matches the clause.
     The proved clauses are discharged from the induction theorem, the
     result is rewritten back with ithm, and all hypotheses are discharged.
     Raises a Standard exception if a clause cannot be proved. The
     get_theorem argument is not used. *)
fun proveit (induction,mapping) mk_conc conv (tactic:hol_type -> tactic) get_theorem =
let     val (all_concs,ithm,thm) = mk_thm (induction,mapping) mk_conc conv
        val _ = type_trace 3
                        ("Instantiated induction theorem: " ^ thm_to_string thm ^ "\n")
        val to_provea = mapfilter ((strip_conj ## I) o dest_imp_only o
                                snd o strip_forall) (hyp thm)
        val to_proveb = map (pair (hyp ithm) o snd o strip_forall) (hyp thm)
        val typed = mapfilter
                (fn tp => ((fst o snd o snd o first_e Empty
                        (can (C match_term (snd (strip_forall (snd tp)))) o snd o
                        strip_forall o rhs o concl o snd o snd o snd)) all_concs,tp))
                (to_provea @ to_proveb)
        val _ = if length typed = length (hyp thm) - length (hyp ithm) then () else
                        raise (mkDebugExn "prove_inductive"
                                ("Some clauses in the instantiated theorem do not match " ^
                                "conclusions generated from the mapping"))

        val proofs = map (fn (t,(assums:term list,goal:term)) =>
                        let     val clause_err = mkStandardExn "prove_inductive_theorem" o
                                        curry op^ ("Could not prove the clause:\n" ^
                                                xpair_to_string (xlist_to_string term_to_string) term_to_string
                                                (assums,goal))
                        in
                                (case (tactic t (assums @ hyp ithm,goal))
                                of ([],func) => foldl (uncurry PROVE_HYP) (func []) (map ASSUME assums)
                                |  (x::xs,func) => raise (clause_err ""))
                                handle e => raise (clause_err ("\nOriginal exception: " ^ exn_to_string e))
                        end) typed

        val sorted = map (fn c => tryfind (C MATCH_IND_TERM c) (hyp thm)) proofs
in
        DISCH_ALL (CONV_RULE (REWR_CONV (GSYM ithm)) (foldr (uncurry PROVE_HYP_CHECK) thm sorted))
end
(* Pairs every type of the nested recursive set of t with the first
   conjunct of thm whose conclusion matches mk_conc of that type. *)
fun split_theorem thm mk_conc t =
let     val main_types = map fst (split_nested_recursive_set t)
        val conjuncts = CONJUNCTS thm
in
        map (fn t => (t,first (can (match_term (mk_conc t)) o concl) conjuncts)) main_types
end     handle e => wrapException "(split_theorem)" e
(* Renders the conclusion mk_conc ty for an error message. If the conclusion
   itself cannot be generated (one reason a type may have no proved
   theorem), the non-fatal exception is described instead, so the caller's
   own diagnostic is still raised. *)
fun conc_to_string mk_conc ty =
        term_to_string (mk_conc ty)
        handle e => if isFatal e then raise e
                    else "<could not be generated: " ^ exn_to_string e ^ ">"
in
(* prove_inductive_coding_theorem fname name mk_conc target t conv tactic
     Proves the conclusions mk_conc for t and its nested recursive types by
     induction. The induction used is the one stored for the coding function
     fname of translation target, instantiated to t. Each resulting theorem
     is stored as coding theorem name. Raises a Standard exception naming
     the first type for which no proved conjunct matches its conclusion. *)
fun prove_inductive_coding_theorem fname name mk_conc target t conv tactic =
let     val _ = type_trace 1 ("Proving coding theorem: " ^ name ^ " for translation " ^
                        type_to_string target ^ " --> " ^ type_to_string t ^ "\n")
        val (induction,mapping) = get_coding_function_induction target t fname
            handle e => wrapException "prove_inductive_coding_theorem" e
        val tsub = tryfind (C match_type t o snd o snd) mapping
            handle e => wrapException "prove_inductive_coding_theorem" e
        val thm = proveit
                (INST_TYPE tsub induction,map (inst tsub ## (I ## type_subst tsub)) mapping)
                mk_conc conv tactic
                (CONV_RULE conv o C (get_coding_theorem target) name)
                handle e => wrapException "prove_inductive_coding_theorem" e
        val split_types = map (type_subst tsub o snd o snd) mapping
                handle e => wrapException "prove_inductive_coding_theorem" e
        val conjuncts = map DISCH_ALL (CONJUNCTS (UNDISCH_ALL thm))
        val (thms,failed) = mappartition (fn t =>
            (t,first (can (match_term (mk_conc t)) o concl) conjuncts))
            split_types
        (* Report the conclusion of the type that failed, not of the main
           type t. *)
        val _ = if null failed then () else
            raise (mkStandardExn "prove_inductive_coding_theorem"
                  ("The type: " ^ type_to_string (hd failed) ^
                   "\nwith conclusion: " ^ conc_to_string mk_conc (hd failed) ^
                   "\nhas no corresponding theorem in the proved set:\n"
                          ^ xlist_to_string thm_to_string conjuncts))
in
        app (fn (t,thm) => add_coding_theorem_precise target t name thm) thms
        handle e => wrapException "prove_inductive_coding_theorem" e
end
(* Interactive counterpart of prove_inductive_coding_theorem: sets up the
   induction goal without proving it. *)
fun inductive_coding_goal fname mk_conc target t (conv:term -> thm) =
let     val (induction,mapping) = get_coding_function_induction target t fname
        val tsub = tryfind (C match_type t o snd o snd) mapping
in
        mkgoal (INST_TYPE tsub induction,map (inst tsub ## (I ## type_subst tsub)) mapping)
                mk_conc conv
end handle e => wrapException "inductive_coding_goal" e
(* prove_inductive_source_theorem fname name mk_conc t conv tactic
     As prove_inductive_coding_theorem, but uses the induction stored for the
     source function fname and stores the results as source theorem name. *)
fun prove_inductive_source_theorem fname name mk_conc t conv tactic =
let     val _ = type_trace 1 ("Proving source theorem: " ^ name ^ " for type " ^ type_to_string t ^ "\n")
        val (induction,mapping) = get_source_function_induction t fname
        val tsub = tryfind (C match_type t o snd o snd) mapping
        val thm = proveit
                (INST_TYPE tsub induction,map (inst tsub ## (I ## type_subst tsub)) mapping)
                mk_conc conv tactic
                (CONV_RULE conv o C get_source_theorem name)
        val split_types = map (type_subst tsub o snd o snd) mapping
        val conjuncts = CONJUNCTS thm
        (* Types without a matching conjunct are collected in failed and
           reported below. There is no separate eager lookup that could raise
           a bare first error before this diagnostic. *)
        val (thms,failed) = mappartition (fn t =>
            (t,first (can (match_term (mk_conc t)) o concl) conjuncts))
            split_types
        val _ = if null failed then () else
            raise (mkStandardExn "prove_inductive_source_theorem"
                  ("The type: " ^ type_to_string (hd failed) ^
                   "\nwith conclusion: " ^ conc_to_string mk_conc (hd failed) ^
                   "\nhas no corresponding theorem in the proved set:\n"
                          ^ xlist_to_string thm_to_string conjuncts))

in
        app (fn (t,thm) => add_source_theorem_precise t name thm) thms
end     handle e => wrapException "prove_inductive_source_theorem" e
(* Interactive counterpart of prove_inductive_source_theorem: sets up the
   induction goal without proving it. *)
fun inductive_source_goal fname mk_conc t (conv:term -> thm) =
let     val (induction,mapping) = get_source_function_induction t fname
        val tsub = tryfind (C match_type t o snd o snd) mapping
in
        mkgoal
        (INST_TYPE tsub induction,map (inst tsub ## (I ## type_subst tsub)) mapping)
        mk_conc conv
end     handle e => wrapException "inductive_source_goal" e
end
3699
fun add_inductive_coding_theorem_generator fname name target conv tactic =
let     (* Conclusion registered for theorem name under target. Fails with a
           Standard exception if no conclusion has been registered yet. *)
        fun conclusion ty =
                if not (exists_coding_theorem_conclusion target name)
                        then raise (mkStandardExn ("inductive_coding_proof ("^name^")")
                                ("The conclusion has not yet been set!"))
                        else get_coding_theorem_conclusion target name ty
        (* Proves and stores the theorem for t by induction, then returns it. *)
        fun generate t =
                (prove_inductive_coding_theorem fname name conclusion target t conv tactic ;
                 get_coding_theorem_precise target t name)
in
        add_coding_theorem_generator target name (can TypeBase.constructors_of) generate
end;
3710
fun add_inductive_source_theorem_generator fname name conv tactic =
let     (* Conclusion registered for source theorem name. Fails with a
           Standard exception if no conclusion has been registered yet. *)
        fun conclusion ty =
                if not (exists_source_theorem_conclusion name)
                        then raise (mkStandardExn ("inductive_source_proof ("^name^")")
                                ("The conclusion has not yet been set!"))
                        else get_source_theorem_conclusion name ty
        (* Proves and stores the theorem for t by induction, then returns it. *)
        fun generate t =
                (prove_inductive_source_theorem fname name conclusion t conv tactic ;
                 get_source_theorem_precise t name)
in
        add_source_theorem_generator name (can TypeBase.constructors_of) generate
end;
3720
fun add_tactic_coding_theorem_generator name test (tactic:hol_type -> tactic) target =
let     (* Applies tactic t to the registered conclusion for t, as a goal with
           no assumptions. The tactic must solve the goal completely;
           otherwise a Standard exception is raised. tactic t is evaluated
           before the conclusion is looked up, as in a direct application. *)
        fun prove t =
        let     val tac = tactic t
                val goal = ([],get_coding_theorem_conclusion target name t)
        in
                case tac goal of
                        ([],validation) => validation []
                |       (_,_) =>
        (raise (mkStandardExn ("tactic_" ^ name ^ " (" ^ type_to_string t ^ ")") "Unsolved goals"))
        end
in
        add_coding_theorem_generator target name test prove
end
3727
fun add_tactic_source_theorem_generator name test (tactic:hol_type -> tactic) =
let     (* Applies tactic t to the registered source conclusion for t, as a
           goal with no assumptions. The tactic must solve the goal
           completely; otherwise a Standard exception is raised. tactic t is
           evaluated before the conclusion is looked up, as in a direct
           application. *)
        fun prove t =
        let     val tac = tactic t
                val goal = ([],get_source_theorem_conclusion name t)
        in
                case tac goal of
                        ([],validation) => validation []
                |       (_,_) =>
        (raise (mkStandardExn ("tactic_" ^ name ^ " (" ^ type_to_string t ^ ")") "Unsolved goals"))
        end
in
        add_source_theorem_generator name test prove
end
3734
fun add_rule_coding_theorem_generator name test rule target =
let     (* Produces the theorem for t with rule. Any failure is re-raised
           with rule:name (type) added to its trace. *)
        fun apply_rule t =
                rule t
                handle e => wrapException ("rule:" ^ name ^ " (" ^ type_to_string t ^ ")") e
in
        add_coding_theorem_generator target name test apply_rule
end;
3738
fun add_rule_source_theorem_generator name test rule =
let     (* Produces the source theorem for t with rule. Any failure is
           re-raised with rule:name (type) added to its trace. *)
        fun apply_rule t =
                rule t
                handle e => wrapException ("rule:" ^ name ^ " (" ^ type_to_string t ^ ")") e
in
        add_source_theorem_generator name test apply_rule
end;
3742
3743end