1(*****************************************************************************)
2(* functionEncodeLib.sml :                                                   *)
3(*     Translates functions between HOL and an embedded language using the   *)
4(*     coding functions defined in encodeLib.                                *)
5(*                                                                           *)
6(*****************************************************************************)
7
8structure functionEncodeLib :> functionEncodeLib =
9struct
10
11(*****************************************************************************)
12(* Bugs:                                                                     *)
13(*     1) Datatypes using fcps, eg:                                          *)
14(*            wordlist = Nil | Cons of 'a word => wordlist                   *)
15(*        currently fail. This is due to problems in the method of           *)
16(*        derivation of function type for encode/decode et al...             *)
17(*        This shouldn't take too long to fix...                             *)
18(*                                                                           *)
19(*****************************************************************************)
20
21(*****************************************************************************)
22(* Suggested Improvements:                                                   *)
23(*     1) (Minor) If is_target_function et al return false, the *reason* is  *)
24(*        not collected. A function should be created to log this and use it *)
25(*        in the error message.                                              *)
26(*     2) (Minor) When a coding function is added its form is not checked    *)
27(*     3) (Minor) When a coding theorem is added its form is not checked     *)
28(*     4) (Minor) Some functions & theorems had to be added for :sexp, these *)
29(*        should not be necessary.                                           *)
30(*     5) (Minor) general_detect can be automatically calculated for         *)
31(*        monomorphic types, but this is not currently done.                 *)
32(*     6) (Major) The backtracking rewriter does not currently cross         *)
33(*        encoding boundaries (eg. decode (encode x)). The suggested fix is  *)
34(*        in append_detector. However, this currently breaks other pieces of *)
35(*        code.                                                              *)
36(*     7) (Minor) 'Cannot resolve head term:                                 *)
37(*                  !x. %%genvar%%11776 x ==> (I (I x) = x)' appears quite   *)
38(*        often in the trace. First of all: why, secondly, it can be proved! *)
39(*     8) (Minor) When a polytypic theorem is added it should be confirmed   *)
40(*        to be of the correct conditional form.                             *)
41(*     9) (Minor) The polytypic functions I've created add LISTS of theorems *)
42(*        These then output a 'check' theorem to prevent looping. This is    *)
43(*        quite bad form... Hence:                                           *)
44(*             Either (a) change polytypic rewrites to use lists or          *)
45(*                    (b) amend the functions so that a single rewrite is    *)
46(*                        not used                                           *)
47(*    10) (Minor) The generators for map_compose seem to be slightly flaky,  *)
48(*        they require the theorem for simple types, and don't perform many  *)
49(*        checks.                                                            *)
50(*    11) (Major) Addition of propagation theorems for definitions should    *)
51(*        remove old propagation theorems. This may have been implemented by *)
52(*        the addition of scrubbing, however...                              *)
53(*    12) (Major) An assumption manager should be used, with maps to perform *)
54(*        (and cache!) the various conversions required. Should help speed   *)
55(*        things up... Especially in include_assumption_list                 *)
56(*    13) (Minor) You can't flatten fix functions. They could be helpful in  *)
57(*        some circumstances.                                                *)
58(*                                                                           *)
59(*****************************************************************************)
60
61(* -- Interactive use only --
62
63local
64val traces = ref [];
65fun clear_trace
66    (trace : {default : int, max : int, name : string, trace_level : int}) =
67    (traces := trace::(!traces) ;
68     Feedback.set_trace (#name trace) 0)
69fun reset_trace
70    (trace : {default : int, max : int, name : string, trace_level : int}) =
71    (Feedback.set_trace (#name trace) (#trace_level trace));
72in
73fun clear_traces () =
74    (traces := [] ; app clear_trace (Feedback.traces ()));
75fun reset_traces () =
76    (app reset_trace (!traces) ; traces := []);
77end;
78
79
80quietdec := true;
81clear_traces();
82
83load "encodeLib";
84Feedback.set_trace "polytypicLib.Trace" 0;
85
86use "functionEncodeLib.sml";
87reset_traces();
88Feedback.set_trace "polytypicLib.Trace" 1;
89
90*)
91
92open HolKernel bossLib proofManagerLib boolSyntax Parse Lib Term Thm Drule Type
93open Conv Tactic Tactical Rewrite
94open boolTheory combinTheory
95open polytypicLib encodeLib Feedback;
96
97
98(*****************************************************************************)
99(* Trace functionality:                                                      *)
100(*                                                                           *)
101(* trace : int -> string -> unit                                             *)
(*     Prints a trace message if the trace level supplied is less than or    *)
(*     equal to the level registered.                                        *)
104(*                                                                           *)
105(*****************************************************************************)
(* Verbosity level for this library's diagnostic output. It is registered    *)
(* below under the name "functionEncodeLib.Trace" with a maximum of 4.       *)
val Trace = ref 0;

val _ = register_trace ("functionEncodeLib.Trace",Trace,4);

(* trace level s                                                             *)
(*     Prints s unless level exceeds the current value of Trace.             *)
fun trace level s =
    if !Trace < level then () else print s;
111
112(*****************************************************************************)
113(* is_encoded_term: term -> bool                                             *)
114(*     Checks to determine whether a term is an encoder applied to a value.  *)
115(*     The term must be of one of the three forms:                           *)
116(*     a) (encode : t -> target) x    where encode is a valid encoder        *)
117(*     b) (f : t -> target) x        where target is a known translation     *)
118(*     c) (f : t -> 'b) x                                                    *)
119(*                                                                           *)
120(* diagnose_encoded_term : term -> unit                                      *)
(*     This informs the user of why a particular term is not considered to   *)
(*     be an encoded term.                                                   *)
123(*                                                                           *)
124(*****************************************************************************)
125
(* is_encoded_term term                                                      *)
(*     True when term is an application (f x) such that either               *)
(*       - the encoder looked up for (type_of term, type_of x) is headed by  *)
(*         a constant and matches f, or                                      *)
(*       - f is a variable and the result type is a type variable or has a   *)
(*         translation scheme.                                               *)
(*     Any exception raised along the way (term is not an application, the   *)
(*     encoder lookup fails, ...) yields false.                              *)
fun is_encoded_term term =
    (let val (encoder,argument) = dest_comb term
         val argument_type = type_of argument
         val target = type_of term
         val expected = get_encode_function target argument_type
         fun known_encoder () =
             is_const (fst (strip_comb expected)) andalso
             can (match_term expected) encoder
         fun encoding_variable () =
             is_var encoder andalso
             (is_vartype target orelse can get_translation_scheme target)
     in
         known_encoder () orelse encoding_variable ()
     end) handle _ => false
136
(* diagnose_encoded_term term                                                *)
(*     Prints the term followed by an explanation of why it is, or is not,   *)
(*     accepted by is_encoded_term.                                          *)
(*                                                                           *)
(*     The encoder lookup is only attempted once the term is known to be an  *)
(*     application, and a failing lookup is reported rather than raised.     *)
(*     is_encoded_term answers false when the lookup fails, so the failure   *)
(*     is reported here as a reason for rejection.                           *)
fun diagnose_encoded_term term =
let val _ = print "Diagnosing encoded term: "
    val _ = print_term term
in
    if not (is_comb term) then
       print "The term is not an instance of function application."
    else
       let val (ratort,randt) = dest_comb term
           val target = type_of term
           (* Used when the encoder found is not headed by a constant: the   *)
           (* term is then only acceptable as an encoding variable.          *)
           fun no_known_encoder () =
               if (is_var ratort andalso
                   (is_vartype target orelse
                    can get_translation_scheme target))
               then print "Encoded term"
               else print ("No encoder exists for the term given and\n" ^
                   "and the term does not specify a valid, encoding\n" ^
                   "variable. (Ie. no translation exists for the\n" ^
                   "output type or the output type is not a type variable")
       in
           case total (get_encode_function target) (type_of randt)
           of NONE =>
                 print ("No encoder could be constructed for the types " ^
                        "of the term and its argument.")
           |  SOME enc =>
                 if is_const (fst (strip_comb enc)) then
                    if can (match_term enc) ratort then print "Encoded term"
                    else (print "The encoder: " ; print_term enc ;
                          print "\n does not match the term's encoder:" ;
                          print_term ratort)
                 else no_known_encoder ()
       end
end
160
161(*****************************************************************************)
162(* The rewrite database                                                      *)
163(*                                                                           *)
164(* conditionize_rewrite: thm -> thm                                          *)
165(*     Converts a theorem from standard form (as used by add_standard)       *)
166(*     to that used by add_conditional_rewrite.                              *)
167(*                                                                           *)
168(* add_standard_rewrite:                                                     *)
169(*     Adds rewrites of the form  |- P ==> (enc a = b)   or   |- (enc a = b) *)
170(*                                                                           *)
171(* add_conditional_rewrite:                                                  *)
172(*     Adds rewrites of the form:                                            *)
173(*          |- P0 /\ ... /\ Pn ==>                                           *)
174(*                (Q0 ==> encode a0 = A0) /\ ... /\ (Qm ==> encode am = Am)  *)
175(*                ==> (encode (F a0 ... an) = F A0 ... Am)                   *)
176(*                                                                           *)
177(*     No encoders may occur in {A0...Am} and when an empty list is required *)
178(*     (either for predicates or recursing encoders) T should be used.       *)
179(*                                                                           *)
180(* exists_rewrite: string -> bool                                            *)
181(*     Returns true if a rewrite of a given name exists.                     *)
182(*                                                                           *)
183(*****************************************************************************)
184
(* The rewrite database: a term net whose entries are                        *)
(* (priority, name, theorem) triples. add_conditional_rewrite inserts each   *)
(* entry under the left-hand side of the theorem's final equation.           *)
val rewrites = ref (Net.empty : (int * string * thm) Net.net)

local
(* Exception describing the required theorem shape; a non-empty s is         *)
(* appended as the specific reason for rejection.                            *)
fun nomatch s = mkStandardExn "add_conditional_rewrite"
        ("Theorem must be of the form: \n" ^
         "|- P0 /\\ ... /\\ Pn ==> \n" ^
         "      (Q0 ==> encode a0 = A0) /\\ ... /\\ (Qm ==> encode am = Am)\n" ^
         "      ==> (encode (F a0 ... an) = F A0 ... Am)\n" ^
         "  where no encoders are present in A0 .. Am" ^
        (if s = "" then "" else "\nHowever, " ^ s))
val err1 = "some terms in the conclusion and second antecedent\n are not encoded terms"
(* rall f l : true iff f holds of every element; an exception counts as      *)
(* false.                                                                    *)
fun rall f [] = true | rall f (x::xs) = f x andalso rall f xs handle _ => false
(* choices l [] [] : every sub-list of l (each accumulated in reverse).      *)
fun choices [] current A = current::A
  | choices (x::xs) current A = choices xs (x::current) (choices xs current A)
in
(* add_conditional_rewrite priority name thm                                 *)
(*     Checks that thm has the conditional form described by nomatch and     *)
(*     adds it to the rewrite database under the given name and priority.    *)
(*     An existing entry with an identical conclusion is removed first       *)
(*     (with a level 1 trace warning); other entries that the new theorem    *)
(*     supersedes are only reported.                                         *)
(*     For every sub-list of the function-typed free variables of the        *)
(*     encoded argument, a variant of the theorem is generated by fix and    *)
(*     inserted as well (duplicates by conclusion are dropped).              *)
(*     Raises the nomatch exception if the theorem is not of the required    *)
(*     form.                                                                 *)
fun add_conditional_rewrite priority name thm =
let     val _ = trace 2 "->add_conditional_rewrite\n";
        (* Right-associate the conjunction in the first antecedent.          *)
        val thm' = CONV_RULE (LAND_CONV (PURE_REWRITE_CONV [GSYM CONJ_ASSOC]))
                             (SPEC_ALL thm)
        val (predicates,r) = with_exn (dest_imp_only o concl) thm' (nomatch "")
        val (recrewrites,final) = with_exn dest_imp_only r (nomatch "")
        (* Drop a leading condition "Q ==>" from a recursive rewrite.        *)
        fun mimp x = if is_imp_only x then snd (dest_imp x) else x
        val _ = if is_eq final then () else raise (nomatch "the conclusion is not an equality")
        val stripped = if recrewrites = T then [] else strip_conj recrewrites
        val _ = (if rall (is_encoded_term o lhs o mimp) (final::stripped) then () else raise Empty)
                        handle _ => raise (nomatch err1)
        (* gf picks out the final equation of a conditional rewrite.         *)
        val gf = snd o dest_imp_only o snd o dest_imp_only
        fun subset a b = set_eq (intersect (strip_conj a) (strip_conj b)) (strip_conj a)
        (* An existing entry (a,b,c) is superseded when the new final        *)
        (* equation matches its final equation, the new predicates (after    *)
        (* PART_MATCH against that equation) are a subset of its predicates  *)
        (* and its priority does not exceed the new one.                     *)
        fun supercedes (a,b,c) = can (match_term final) (gf (concl c)) andalso
                subset  ((fst o dest_imp_only o concl o C (PART_MATCH gf) (gf (concl c))) thm')
                        ((fst o dest_imp_only o concl) c) andalso
                a <= priority handle e => wrapException "add_conditional_rewrite" e
        val s = filter supercedes (Net.match (lhs final) (!rewrites))
        (* Split superseded entries into exact duplicates and the rest.      *)
        val (matches,sups) = partition (curry op= (concl thm') o concl o (fn (a,b,c) => c)) s
        fun ismatch (a,b,c) = mem (a,b,concl c) (map (fn (a,b,c) => (a,b,concl c)) matches)
        val _ = if null matches then () else
                        (trace 1 "<<Encoding Warning: Exact match found, removing previous rewrite>>\n" ;
                         rewrites := Net.delete (lhs final,ismatch) (!rewrites))
        fun p (a,b,c) = "Rewrite: " ^ b ^ " with priority: " ^ int_to_string a ^ ":\n" ^ thm_to_string c ^ "\n"
        val _ = if null sups then () else
                trace 1 ("<<Encoding Warning: New rewrite matches other rewrites>>\n" ^ String.concat (map p sups))
        val _ = trace 1 ("Adding rewrite:\n" ^ thm_to_string thm ^ "\n")

        (* Free variables of function type in the encoded argument.          *)
        val ffs = filter (can dom_rng o type_of) (free_vars (rand (lhs final)))
        val _ = trace 3 ("Free functions : " ^ xlist_to_string term_to_string ffs ^ "\n")

        (* fixs removes one outer universal quantifier. It first tries to    *)
        (* instantiate the bound function variable with I (after identifying *)
        (* its range type with its domain type) and to rewrite the resulting *)
        (* applications with I_THM; if any step of that fails it simply      *)
        (* specialises the quantifier to its own bound variable.             *)
        fun fixs thm =
        let     val a = fst (dest_forall (concl thm))
                val mt = fst (dom_rng (type_of a))
                val thm' = INST_TYPE [snd (dom_rng (type_of a)) |-> fst (dom_rng (type_of a))] thm
                val (a',body) = dest_forall (concl thm')
                val terms = find_terms (fn x => (curry op= a' o rator) x handle _ => false) (concl thm')
                val t = mk_const("I",mt --> mt)
        in
                CONV_RULE (CHANGED_CONV (RAND_CONV (RAND_CONV
                        (PURE_REWRITE_CONV (map (fn t => ISPEC (rand t) I_THM) terms))))) (ISPEC t thm)
        end     handle e => SPEC (fst (dest_forall (concl thm))) thm

        (* Generalise over list, then strip the quantifiers again via fixs.  *)
        fun fix thm list = repeat fixs (GENL list thm)

        val all_thms = op_mk_set (fn a => fn b => concl a = concl b) (map (fix thm') (choices ffs [] []))

        val _ = trace 3 ("Adjusted theorems added: " ^ int_to_string (length all_thms) ^ "\n")
in
        (* Every variant is stored under the same key.                       *)
        app (fn thm' => rewrites := Net.insert (lhs final,(priority:int,name:string,thm')) (!rewrites)) all_thms
end
end;
253
local
(* Exception for a theorem that is not of one of the two accepted forms.     *)
fun bad_form th =
    mkStandardExn "conditionize_rewrite"
    ("Theorem:\n" ^ thm_to_string th ^
     "\nis not of the form: |- P ==> (encode a = b) or |- encode a = b")
fun rewrap e = wrapException "conditionize_rewrite" e
(* Rewrite with any of the given equations, at any depth.                    *)
fun anywhere eqs = DEPTH_CONV (FIRST_CONV (map REWR_CONV eqs))
(* Apply the equations to both hypotheses and conclusion of a theorem.       *)
fun everywhere eqs th = CONV_RULE (anywhere eqs) (CONV_HYP (anywhere eqs) th)
(* Push each equation, in turn, through the equations still pending and      *)
(* through those already processed.                                          *)
fun fix_rewrites [] processed = processed
  | fix_rewrites (eq::pending) processed =
    let val push = map (everywhere [eq])
    in
        fix_rewrites (push pending) (eq :: push processed)
    end;
in
(* conditionize_rewrite thm                                                  *)
(*     Converts |- P ==> (encode a = b) or |- encode a = b into the          *)
(*     conditional form taken by add_conditional_rewrite: each encoded       *)
(*     sub-term of b is replaced by a fresh variable of the target type, and *)
(*     the equations relating them become the second antecedent. T is used   *)
(*     for P when the theorem has no hypothesis after undischarging.         *)
(*     Raises a "conditionize_rewrite" exception if the theorem's conclusion *)
(*     is not an equation whose left-hand side is an encoded term.           *)
fun conditionize_rewrite thm =
let val _ = trace 2 "->conditionize_rewrite\n"
    val stripped = UNDISCH (SPEC_ALL thm) handle _ => SPEC_ALL thm
    val P = hd (hyp stripped) handle Empty => T
    val (encoded,body) =
        (dest_eq o concl) stripped handle e => raise (bad_form thm)
    val _ = if is_encoded_term encoded then () else raise (bad_form thm)
    val inner = mk_set (find_terms is_encoded_term body)
    val target = type_of encoded
    (* A variable named s of the target type, renamed away from the free     *)
    (* variables of the encoded left-hand side.                              *)
    fun fresh s = variant (free_vars encoded) (Term.mk_var(s,target))
    val vars = map (fresh o implode o base26 o fst) (enumerate 0 inner)
               handle e => rewrap e
    val assums = mapfilter (ASSUME o mk_eq) (zip inner vars)
    val eqns = fix_rewrites (filter (not o equal T o concl) assums) [];
in
    (let val abstracted = RIGHT_CONV_RULE (anywhere eqns) stripped
         val discharged =
             DISCH_LIST_CONJ (map concl (TRUTH::eqns)) abstracted
         val tidy = CONV_RULE (LAND_CONV (PURE_REWRITE_CONV [AND_CLAUSES]))
     in
         DISCH P (tidy discharged)
     end) handle e => rewrap e
end
end
288
(* add_standard_rewrite priority name thm                                    *)
(*     Converts thm with conditionize_rewrite and registers the result via   *)
(*     add_conditional_rewrite. A failure of the second step is reported as  *)
(*     a debug exception, since the converted theorem should always be       *)
(*     acceptable.                                                           *)
fun add_standard_rewrite priority name thm =
let val conditional = conditionize_rewrite thm
                      handle e => wrapException "add_standard_rewrite" e
    fun impossible e =
        mkDebugExn "add_standard_rewrite"
          ("add_conditional_rewrite did not like the output of " ^
           "conditionize_rewrite,\n this should never happen!! Original exception:\n" ^
           exn_to_string e)
in
    add_conditional_rewrite priority name conditional
    handle e => raise (impossible e)
end
300
(* exists_rewrite name : true iff some stored rewrite carries this name.     *)
fun exists_rewrite name =
let fun named (_,stored,_) = name = stored
in  exists named (Net.listItems (!rewrites))
end;
303
(* remove_rewrite name : drops every stored rewrite carrying this name.      *)
fun remove_rewrite name =
let fun other (_,stored,_) = not (name = stored)
in  rewrites := Net.filter other (!rewrites)
end;
307
(* scrub_rewrites () : keeps only the rewrites whose theorem satisfies       *)
(* Theory.uptodate_thm.                                                      *)
fun scrub_rewrites () =
let fun current (_,_,th) = Theory.uptodate_thm th
in  rewrites := Net.filter current (!rewrites)
end;
312
313(*****************************************************************************)
314(* Resolution procedures:                                                    *)
315(*                                                                           *)
316(*    Main rewrite procedures are 'partial_resolve' and 'full_resolve':      *)
317(*      ..._resolve [flag] functions  |- A /\ B /\ C ... ==> P               *)
318(*      Will resolve {A,B,C,...} using the functions provided. In the case   *)
319(*      of partial_resolve, the flag indicates whether terms in P can be     *)
320(*      instantiated.                                                        *)
321(*                                                                           *)
322(*    Adding an assumption (include_assumption_list):                        *)
323(*      Assumptions are stored as a list of disjunctions (ie. CNF):   Assums *)
(*      Auxiliary theorems are provided in the form:  |- P ==> Q :    Extras *)
325(*      -- where P is in CNF.                                                *)
326(*                                                                           *)
327(*      When a new assumption, A |- t, is added to Assums:                   *)
328(*          A |- t is resolved against all P for |- P ==> Q in extras:       *)
329(*          to get new extras: |- P' \ t ==> Q'. Where P',Q' are             *)
330(*          instantiated such that t matches a disjunction, D, in P. Ie:     *)
331(*               P = ... /\ D /\ ... and   |- t ==> D                        *)
332(*          A |- t is also resolved against all other assumptions, A' |- t': *)
333(*             A' |- t' = A' |- t0 \/ ... \/ tn  is converted to:            *)
(*             |- ~t0 /\ ... /\ ~tn ==> ~A'. Which is then resolved against  *)
(*             t and converted back to:                                      *)
336(*             A |- (t0 \/ ... \/ tn) / t                                    *)
337(*          Any new assumptions generated through these procedures undergo   *)
338(*          the same procedure.                                              *)
339(*                                                                           *)
340(*    Resolving a theorem (full_resolve):                                    *)
341(*      A theorem of the form: |- A /\ B /\ C ... ==> P is resolved using    *)
342(*      full resolve with the following functions:                           *)
343(*                                                                           *)
344(*           A |- t     t0 \/ ... t ... \/ tn                                *)
345(*           -------------------------------- match_disjunction              *)
346(*               A |- t0 \/ ... t ... \/ tn                                  *)
347(*                                                                           *)
348(*             !x. detect x ==> encode (decode x) = x                        *)
349(*           ----------------------------------------- match_decenc          *)
350(*           |- !x. detect x ==> encode (decode x) = x                       *)
351(*                                                                           *)
352(*             !x. decode (encode x) = x                                     *)
353(*           ---------------------------- match_encdec                       *)
354(*           |- !x. decode (encode x) = x                                    *)
355(*                                                                           *)
356(*             !x. detect (encode x)                                         *)
357(*           ------------------------ match_encdet                           *)
358(*           |- !x. detect (encode x)                                        *)
359(*                                                                           *)
360(*****************************************************************************)
361
362(*****************************************************************************)
(* NNF_CONV / CNF_CONV : term -> thm                                         *)
364(*                                                                           *)
365(*     Like those in normalForms, but ONLY deals with disjunction,           *)
366(*     conjunction and negation.                                             *)
367(*                                                                           *)
368(*****************************************************************************)
369
local
(* Rewrite rules used to push negations inwards: the first conjunct of       *)
(* NOT_CLAUSES together with the conjuncts of DE_MORGAN_THM.                 *)
val negation_rules =
    CONJUNCT1 NOT_CLAUSES :: CONJUNCTS (SPEC_ALL DE_MORGAN_THM)
val push_negation = FIRST_CONV (map REWR_CONV negation_rules)
in
(* NNF_CONV applies the negation rules top-down until none applies.          *)
val NNF_CONV = TOP_DEPTH_CONV push_negation;
(* CNF_CONV runs NNF_CONV and then normalForms.PURE_CNF_CONV.                *)
fun CNF_CONV term =
    (NNF_CONV THENC normalForms.PURE_CNF_CONV) term
end;
378
379(*****************************************************************************)
380(* SAFE_MATCH_MP : thm -> thm -> thm                                         *)
381(*                                                                           *)
382(*     Like MATCH_MP but doesn't mind variable instantiation                 *)
383(*     in the assumptions.                                                   *)
384(*                                                                           *)
385(*****************************************************************************)
386
(* SAFE_MATCH_MP thm1 thm2                                                   *)
(*     Matches the antecedent of thm1's conclusion against the conclusion of *)
(*     thm2, instantiates thm1 (hypotheses included, via INST_TY_TERM) and   *)
(*     applies MP.                                                           *)
(*     The match is rejected with a "Double bind" exception when one of the  *)
(*     terms being substituted in is itself among the free variables of      *)
(*     thm1. Failures of matching or of MP are re-raised through             *)
(*     wrapException.                                                        *)
fun SAFE_MATCH_MP thm1 thm2 =
let val match = match_term (fst (dest_imp (concl thm1))) (concl thm2)
        handle e => wrapException "SAFE_MATCH_MP" e
    val vars1 = thm_frees thm1
in
    if null (intersect (map #residue (fst match)) vars1)
       then MP (INST_TY_TERM match thm1) thm2
                handle e => wrapException "SAFE_MATCH_MP" e
       else raise (mkStandardExn "SAFE_MATCH_MP"
                  ("Double bind between:\n" ^ thm_to_string thm1 ^
                  "\nand\n" ^ thm_to_string thm2))
end
400
401(*****************************************************************************)
402(* match_disjunction : thm -> term -> thm                                    *)
403(*                                                                           *)
404(*     A0 |- a0 \/ a1 \/ ...   b0 \/ b1 \/ b2 ...                            *)
405(*     ------------------------------------------ match_disjunction          *)
406(*                A0' |- b0 \/ b1 \/ b2 ...                                  *)
407(*                                                                           *)
408(*    Where {a0,a1,...} is a subset of {b0,b1,...} up to instantiation and   *)
409(*    alpha conversion.                                                      *)
410(*                                                                           *)
411(*****************************************************************************)
412
local
(* mappluck f l : for every element x of l on which f succeeds, yields       *)
(* ((f x, x), rest) where rest is l with that occurrence of x removed.       *)
(* Elements on which f raises a non-fatal exception are skipped; fatal       *)
(* exceptions (as judged by isFatal) are re-raised.                          *)
fun mappluck f [] = []
  | mappluck f (x::xs) =
  ((f x,x),xs) :: (map (I ## cons x) (mappluck f xs)) handle e =>
     if isFatal e then raise e else (map (I ## cons x) (mappluck f xs));
(* Apply a (term substitution, type substitution) pair to a term.            *)
fun instsubst (m1,m2) = subst m1 o inst m2;
(* Re-express thm as term, where the two are equal up to associativity and   *)
(* commutativity of disjunction.                                             *)
fun MATCH_DISJ_CONV thm term =
    EQ_MP (CONV_RULE bool_EQ_CONV (AC_CONV (DISJ_ASSOC,DISJ_COMM)
                                    (mk_eq(concl thm,term)))) thm
(* match d1s d2s : every way of pairing each disjunct in d1s with a distinct *)
(* disjunct of d2s that matches it (the d2 disjunct is the pattern). The     *)
(* result pairs the chosen d2 disjuncts with the leftover ones; leftovers    *)
(* are instantiated by each match as it is made.                             *)
fun match [] leftover = [([],leftover)]
  | match (d1::d1s) d2s =
let val m = mappluck (C match_term d1) d2s
    val passed = mapfilter (fn ((m,d2),d2s') => (d2,map (instsubst m) d2s')) m
in
    flatten (map (fn (d2,d2s) => map (cons d2 ## I) (match d1s d2s)) passed)
end;
(* Build the implication from the matched disjuncts to fterm (adding the     *)
(* leftover disjuncts with DISJ2) and discharge it against thm.              *)
fun prove_term thm fterm (list,leftover) =
    SAFE_MATCH_MP (DISCH_ALL (MATCH_DISJ_CONV
     (foldl (uncurry DISJ2) (ASSUME (list_mk_disj list)) leftover)
     fterm)) thm;
in
(* match_disjunction thm term                                                *)
(*     Derives term (a disjunction) from thm, whose disjuncts must match a   *)
(*     subset of term's disjuncts. Candidate proofs whose conclusion has     *)
(*     free variables outside those of term and of thm's conclusion are      *)
(*     discarded. The first remaining proof is returned, with a level 1      *)
(*     trace warning if there were several.                                  *)
(*     Raises a "match_disjunction" exception when no proof is found.        *)
fun match_disjunction thm term =
let val thm_disjs = strip_disj (concl thm)
    val term_disjs = strip_disj term
    val matches = match thm_disjs term_disjs
    val fvs = free_vars term @ free_vars (concl thm)
in
    case (filter (null o C set_diff fvs o free_vars o concl)
                 (mapfilter (prove_term thm term) matches))
    of [] => raise (mkStandardExn "match_disjunction"
                   ("Theorem:\n" ^ thm_to_string thm ^ "\n" ^
                   "cannot match the term: " ^ term_to_string term))
    |  [x] => x
    |  (x::xs) => (trace 1
       "<<Encoding Warning: Multiple matches in disjunction>>\n" ; x)
end
end;
450
451(*****************************************************************************)
452(* head_term         : thm -> term                                           *)
453(* tail_thm          : thm -> thm                                            *)
454(* resolve_head_term : bool -> thm -> thm -> term list -> term list * thm    *)
455(*                                                                           *)
456(*     For a theorem, thm,  of the form: A |- t0 /\ t1 /\ t2 ... ==> P       *)
457(*         head_term thm  returns t0 and                                     *)
458(*         tail_thm  thm  returns A u {t0} |- t1 /\ t2 ... ==> P             *)
459(*     For a theorem, thm, of the form: A |- T ==> P                         *)
460(*         head_term thm  returns T and                                      *)
461(*         tail_thm  thm  returns A |- T ==> P                               *)
462(*                                                                           *)
463(*     A0 |- a0 \/ a1 \/ ..  A1 |- t0 /\ t1 /\ t2 .. ==> P                   *)
464(*     --------------------------------------------------- resolve_head_term *)
465(*                 A0 u A1' |- t1 /\ t2 .. ==> P'                            *)
466(*                                                                           *)
467(*     where t0 = {b0,b1,...} and {a0,a1,...} is a subset of {b0,b1,...}.    *)
468(*     If the protect flag, the first argument, is true, then P must equal   *)
469(*     P'. If not, then P may be instantiated to achieve a match.            *)
470(*     A list of terms is provided, and instantiated in the same manner as P *)
471(*                                                                           *)
472(*****************************************************************************)
473
(* head_term thm                                                             *)
(*     For a theorem whose conclusion is an implication, returns the first   *)
(*     conjunct of the antecedent, or the whole antecedent when it is not a  *)
(*     conjunction. A conclusion that is not an implication is reported      *)
(*     through wrapException.                                                *)
fun head_term thm =
let val antecedent = fst (dest_imp_only (concl thm))
                     handle e => wrapException "head_term" e
in
    if is_conj antecedent then
       let val first = fst (dest_conj antecedent)
       in
           (* The first conjunct may itself be a conjunction; descend to    *)
           (* its first component.                                          *)
           if is_conj first then hd (strip_conj first) else first
       end
    else antecedent
end
481
(* resolve_head_term protect rthm thm assumptions                            *)
(*     Resolves the head term of thm (|- t0 /\ t1 /\ ... ==> P) with rthm    *)
(*     and returns the instantiated assumptions together with the remaining  *)
(*     theorem.                                                              *)
(*     When protect is true, the resolution is rejected if it changes the    *)
(*     final conclusion of thm.                                              *)
(*     Hypotheses of the result that match a conjunct of rthm's antecedent   *)
(*     are discharged again and folded back into the antecedent.             *)
fun resolve_head_term protect rthm thm assumptions =
let     val antecedent = fst (dest_imp_only (concl thm))
        (* Put thm into the form  t0 ==> rest ==> P.  Note that only the     *)
        (* single-antecedent branch is covered by the exception wrapper.     *)
        val curried =
            if is_conj antecedent
            then CONV_RULE (REWR_CONV (GSYM AND_IMP_INTRO)) thm
            else (DISCH (head_term thm) (DISCH T (UNDISCH thm))
                  handle e => wrapException "resolve_head_term" e)
        val resolvent = UNDISCH_CONJ rthm handle e => rthm
        val resolved = SAFE_MATCH_MP curried resolvent handle e =>
                raise (mkStandardExn "resolve_head_term"
                        ("Theorem to resolve:\n" ^ thm_to_string rthm ^
                         "\ndoes not match head term:"
                         ^ term_to_string (head_term thm)))
        (* The same match, recomputed so that it can be applied to the       *)
        (* assumptions supplied by the caller.                               *)
        val (term_subst,type_subst) =
            match_term (fst (dest_imp_only (concl curried))) (concl resolvent)
        fun conclusion_kept () =
            snd (strip_imp (concl resolved)) = snd (strip_imp (concl thm))
        val _ = if protect andalso not (conclusion_kept ())
                   then raise (mkStandardExn "resolve_head_term"
                        ("Matching the theorem:\n" ^ thm_to_string thm ^
                         "\nto resolve:\n" ^
                         thm_to_string rthm ^
                         "\nrequires modification of the theorems conclusion."))
                   else ()
        (* Hypotheses that came from rthm's antecedent; none if rthm is not  *)
        (* an implication.                                                   *)
        val reintroduce =
            (let val conditions =
                     strip_conj (fst (dest_imp_only (concl rthm)))
                 fun from_rthm h = exists (can (C match_term h)) conditions
             in
                 filter from_rthm (hyp resolved)
             end) handle e => []
        fun discharge (h,th) =
            CONV_RULE (REWR_CONV AND_IMP_INTRO) (DISCH h th)
in
        (map (subst term_subst o inst type_subst) assumptions,
         foldr discharge resolved reintroduce)
end
511
local
(* Remove the antecedent: by MP with TRUTH when that succeeds, otherwise by  *)
(* moving it into the hypotheses.                                            *)
fun drop_antecedent th = MP th TRUTH handle _ => UNDISCH th
(* |- a /\ b ==> c   becomes   |- a ==> b ==> c                              *)
fun split_antecedent th = CONV_RULE (REWR_CONV (GSYM AND_IMP_INTRO)) th
in
(* tail_thm thm                                                              *)
(*     For  A |- t0 /\ t1 /\ ... ==> P  removes t0 from the antecedent.      *)
(*     For  A |- t ==> P  with t not a conjunction, t is replaced by T as    *)
(*     the antecedent; a theorem whose antecedent is already T is returned   *)
(*     unchanged.                                                            *)
(*     Failures of the rules applied are re-raised through wrapException.    *)
fun tail_thm thm =
let val antecedent = fst (dest_imp_only (concl thm))
    fun step () =
        if antecedent = T then thm
        else if is_conj antecedent
             then drop_antecedent (split_antecedent thm)
             else DISCH T (drop_antecedent thm)
in
    step () handle e => wrapException "tail_thm" e
end
end;
523
524(*****************************************************************************)
525(* dest_encdec  : term -> hol_type * hol_type                                *)
526(* match_encdec : term -> thm                                                *)
527(*     Returns a theorem matching a term of the form:                        *)
528(*     |- !a. decode (encode a) = a                                          *)
529(*                                                                           *)
530(* dest_decenc  : term -> hol_type * hol_type                                *)
531(* match_decenc : term -> thm                                                *)
532(*     Returns a theorem matching a term of the form:                        *)
533(*     |- !a. detect a ==> encode (decode a) = a                             *)
534(*                                                                           *)
535(* dest_encdet  : term -> hol_type * hol_type                                *)
536(* match_encdet : term -> thm                                                *)
537(*     Returns a theorem matching a term of the form:                        *)
538(*     |- !a. detect (encode a)                                              *)
539(*                                                                           *)
540(*****************************************************************************)
541
local
val malformed =
    mkStandardExn "dest_encdec"
                  "Not a term of the form: \"!a. decode (encode a) = a\""
fun require ok = if ok then () else raise malformed
in
(* dest_encdec term                                                          *)
(*     For a term  !a. f (g a) = a  returns (type of a, type of (g a)).      *)
(*     Raises the dest_encdec exception when the term does not have this     *)
(*     shape. Whether g is a registered encoder is not checked.              *)
fun dest_encdec term =
let     val (bound,equation) = with_exn dest_forall term malformed
        val (_,encoded) = with_exn (dest_comb o lhs) equation malformed
        val () = require (bound = rhs equation)
in
        (type_of bound,type_of encoded)
end
end
555
(* match_encdec term                                                         *)
(*     Destructs the term with dest_encdec and passes the resulting types to *)
(*     FULL_ENCODE_DECODE_THM. Any failure, from either step, is re-raised   *)
(*     through wrapException under the name "match_encdec".                  *)
fun match_encdec term =
    (let val (t,target) = dest_encdec term
     in  FULL_ENCODE_DECODE_THM target t
     end)
    handle e => wrapException "match_encdec" e
561
local
val err = mkStandardExn "dest_decenc" "Not a term of the form: \"!a. detect a ==> encode (decode a) = a\""
in
(* dest_decenc term                                                          *)
(*     Destructs a term of the form  !a. detect a ==> encode (decode a) = a  *)
(*     and returns (t,target) where target is the type of the bound variable *)
(*     and t is the type of the decoded term.                                *)
(*                                                                           *)
(*     The term is accepted only if a translation exists between the two     *)
(*     types and the encoder, decoder and detector found in the term can be  *)
(*     matched (as patterns) against the generic functions built by          *)
(*     gen_encode_function, gen_decode_function and gen_detect_function.     *)
(*                                                                           *)
(*     Every failure raises the single standard exception above. The         *)
(*     destructors and generator calls are all run inside with_exn, so a     *)
(*     malformed term (eg. a detector or decoded term that is not an         *)
(*     application) no longer escapes as a raw destructor error.             *)
fun dest_decenc term =
let     val (var,body) = with_exn dest_forall term err
        val (detect,body2) = with_exn dest_imp_only body err
        val (encode,decoded_term) = with_exn (dest_comb o lhs) body2 err
        val _ = if var = rhs body2 then () else raise err
        (* Operators of `decode a` and `detect a`; guarded destructors       *)
        val (decode,_) = with_exn dest_comb decoded_term err
        val (detector,_) = with_exn dest_comb detect err
        val target = type_of var
        val t = type_of decoded_term
        val _ = if exists_translation target t then () else raise err
        (* The generator is called inside the thunk so that its failures are *)
        (* also reported as err.                                             *)
        fun check pattern generator =
            with_exn (fn () => match_term pattern (generator target t)) () err
        val _ = check encode gen_encode_function
        val _ = check decode gen_decode_function
        val _ = check detector gen_detect_function
in
        (t,target)
end
end;
580
(* match_decenc term                                                         *)
(*     Destructs the term with dest_decenc and passes the resulting types to *)
(*     FULL_DECODE_ENCODE_THM. Any failure, from either step, is re-raised   *)
(*     through wrapException under the name "match_decenc".                  *)
fun match_decenc term =
    (let val (t,target) = dest_decenc term
     in  FULL_DECODE_ENCODE_THM target t
     end)
    handle e => wrapException "match_decenc" e;
586
local
val err = mkStandardExn "dest_encdet" "Not a term of the form: \"!a. detect (encode a)\""
in
(* dest_encdet term                                                          *)
(*     Destructs a term of the form  !a. detect (encode a)  and returns      *)
(*     (t,target) where t is the type of the bound variable and target is    *)
(*     the type of the encoded term.                                         *)
(*                                                                           *)
(*     The encoder must be applied to exactly the bound variable, a          *)
(*     translation must exist between the two types, and the encoder and     *)
(*     detector must match (as patterns) the generic functions built by      *)
(*     gen_encode_function and gen_detect_function.                          *)
(*                                                                           *)
(*     Every failure raises the single standard exception above. The encoded *)
(*     term is destructed under with_exn, so a body such as `detect c` with  *)
(*     c atomic no longer escapes as a raw destructor error.                 *)
fun dest_encdet term =
let     val (var,body) = with_exn dest_forall term err
        val (detect,encoded_term) = with_exn dest_comb body err
        val (encoder,encoded_arg) = with_exn dest_comb encoded_term err
        val _ = if var = encoded_arg then () else raise err
        val target = type_of encoded_term
        val t = type_of var
        val _ = if exists_translation target t then () else raise err
        (* The generator is called inside the thunk so that its failures are *)
        (* also reported as err.                                             *)
        fun check pattern generator =
            with_exn (fn () => match_term pattern (generator target t)) () err
        val _ = check encoder gen_encode_function
        val _ = check detect gen_detect_function
in
        (t,target)
end
end;
603
(* match_encdet term                                                         *)
(*     Destructs the term with dest_encdet and passes the resulting types to *)
(*     FULL_ENCODE_DETECT_THM. Any failure, from either step, is re-raised   *)
(*     through wrapException under the name "match_encdet".                  *)
fun match_encdet term =
    (let val (t,target) = dest_encdet term
     in  FULL_ENCODE_DETECT_THM target t
     end)
    handle e => wrapException "match_encdet" e;
609
local
(* presolve protect funcs thm A                                              *)
(*     Worker for partial_resolve. Walks the antecedents of thm one head     *)
(*     term at a time; A accumulates (in reverse) the head terms that could  *)
(*     not be resolved. Returns (unresolved antecedents, resulting theorem). *)
fun presolve protect funcs thm A =
let val head = head_term thm
        handle e => raise (mkDebugExn "partial_resolve"
                 ("Theorem:\n " ^ thm_to_string thm  ^
                  "\nis not of the form: |- A /\\ B ... ==> P"))
in
    (* head = T marks the end of the antecedents: either nothing was left    *)
    (* unresolved (thm is returned as is), or the `T ==> P` form is          *)
    (* rewritten away with IMP_CLAUSES and the unresolved terms are          *)
    (* discharged again, in their original order.                            *)
    if head = T
       then if null A
               then ([],thm)
               else (rev A,DISCH_LIST_CONJ (rev A)
                     (CONV_RULE (FIRST_CONV (map REWR_CONV
                      (CONJUNCTS (SPEC_ALL IMP_CLAUSES)))) thm))
       (* Otherwise the first function in funcs that succeeds on head        *)
       (* supplies the argument for resolve_head_term, and the walk          *)
       (* continues on the pair it returns.                                  *)
       else uncurry (C (presolve protect funcs))
                    (resolve_head_term protect
                                    (tryfind (fn x => x head) funcs) thm A)
            (* This handler belongs to the else branch only. It also covers  *)
            (* the recursive call above, so a non-fatal failure anywhere     *)
            (* further down the walk makes this head count as unresolved:    *)
            (* it is pushed onto A and the walk restarts on tail_thm thm.    *)
            handle e =>
                   if isFatal e
                      then raise e
                      else presolve protect funcs (tail_thm thm) (head::A)
end
in
(* partial_resolve protect funcs thm                                         *)
(*     Resolves as many antecedents of thm as funcs allow and returns the    *)
(*     resulting theorem. Raises UNCHANGED when the unresolved antecedents   *)
(*     are exactly the conjuncts of the original antecedent, ie. nothing     *)
(*     was resolved. A failure to destruct the conclusion of thm as an       *)
(*     implication is re-raised through wrapException.                       *)
fun partial_resolve protect funcs thm =
let val (A,r) = presolve protect funcs thm []
in
    if (A = strip_conj (fst (dest_imp_only (concl thm)))
          handle e => wrapException "partial_resolve" e)
       then raise UNCHANGED
       else r
end
end;
641
(* full_resolve funcs thm                                                    *)
(*     Repeatedly resolves the head term of the antecedents of thm until the *)
(*     head term is T, at which point the remaining `T ==> P` is rewritten   *)
(*     using IMP_CLAUSES and returned. At each step every function in funcs  *)
(*     is applied to the head term and the first one for which               *)
(*     resolve_head_term (in protected mode) succeeds is used. Unlike        *)
(*     partial_resolve, a head term that no function resolves is an error.   *)
fun full_resolve funcs thm =
let     val head =
            head_term thm
            handle _ => raise (mkDebugExn "full_resolve"
                        "Theorem is not of the form: |- A /\\ B ... ==> P")
        (* Resolve the current head with the theorem produced by one         *)
        (* candidate function, keeping only the resulting theorem.           *)
        fun resolve_with candidate =
            snd (resolve_head_term true (candidate head) thm [])
in
        if not (head = T) then
                full_resolve funcs
                        (tryfind_e (mkStandardExn "full_resolve"
                                        ("Cannot resolve head term: "
                                         ^ term_to_string head))
                                   resolve_with funcs)
        else
                CONV_RULE (FIRST_CONV (map REWR_CONV
                                (CONJUNCTS (SPEC_ALL IMP_CLAUSES)))) thm
end
656
(* resolve_functions assums                                                  *)
(*     Builds the list of functions used to resolve a head term: first       *)
(*     match_disjunction applied to each assumption (in order), then the     *)
(*     three coding-theorem matchers and finally DECIDE. The list is         *)
(*     consumed by partial_resolve and full_resolve, which try the           *)
(*     functions in order.                                                   *)
(*     (An unused local helper, rmatch, has been removed.)                   *)
fun resolve_functions assums =
        map match_disjunction assums @
        [match_encdec,match_decenc,match_encdet,DECIDE]
664
665(*****************************************************************************)
666(* include_assumption_list : thm list -> (thm list * thm list) ->            *)
667(*                                                  (thm list * thm list)    *)
(*     Adds a list of assumptions into the current list of assumptions       *)
(*     and auxiliary theorems.                                               *)
(*                                                                           *)
(*     Assumptions are always stored in CNF (conjunction is implicit)        *)
(*     Auxiliary theorems are stored as |- P ==> Q where P is in CNF         *)
(*                                                                           *)
(*     Additional assumptions are first converted to CNF then resolved       *)
(*     against the antecedents of auxiliary theorems. Additional             *)
(*     assumptions may be generated at this stage.                           *)
677(*                                                                           *)
678(*****************************************************************************)
679
(* Debugging aid: when the debug flag is set, each call of the internal      *)
(* worker of include_assumption_list stores its arguments here as            *)
(* (original hypotheses, (new assumptions, (assumptions, extras))).          *)
val IAL_data : (term list * (thm list * (thm list * thm list))) option ref =
    ref NONE;
682
683
local
(* First conjunct of NOT_CLAUSES; used below to rewrite the left of an       *)
(* implication after contraposition (presumably |- ~~t = t -- confirm).      *)
val NOT_NOT = el 1 (CONJUNCTS NOT_CLAUSES)
(* CONTR1: discharge all hypotheses as a conjunction, contrapose, and put    *)
(* the new antecedent into CNF.                                              *)
val CONTR1 = CONV_RULE (CONTRAPOS_CONV THENC LAND_CONV CNF_CONV) o
             DISCH_ALL_CONJ
(* CONTR2: the return trip for CONTR1 -- contrapose back, rewrite the        *)
(* antecedent with NOT_NOT, undischarge it and put the conclusion into CNF.  *)
val CONTR2 = CONV_RULE CNF_CONV o UNDISCH_CONJ o
             CONV_RULE (CONTRAPOS_CONV THENC LAND_CONV (REWR_CONV NOT_NOT));
(* Set operations on theorems that compare conclusions only.                 *)
val thm_set = op_mk_set (fn a => fn b => concl a = concl b)
val thm_diff = op_set_diff (fn a => fn b => concl a = concl b)
(* As match_disjunction, but fails if matching term against the conclusion   *)
(* of the result instantiates any of the variables in vs.                    *)
fun rematch_disjunction vs thm term =
let val result = match_disjunction thm term
    val match = match_term term (concl result)
in
    if exists (C mem ((map #redex o fst) match)) vs then
       raise (mkStandardExn "rematch_disjunction"
                        "Variable in assumptions bound")
                        else result
end
(* Maps f over a list, dropping the elements for which f raises UNCHANGED.   *)
fun mapchanged f [] = []
  | mapchanged f (x::xs) =
  (f x :: mapchanged f xs) handle UNCHANGED => mapchanged f xs
(* IAL oset L (assums,extras)                                                *)
(*     One round of the inclusion loop. oset is the hypothesis set that      *)
(*     results are allowed to depend upon, L the assumptions being added.    *)
(*     Recurses until a round yields neither new assumptions nor new         *)
(*     auxiliary theorems.                                                   *)
fun IAL oset L (assums,extras) =
let     val to_avoid = free_varsl oset
        val _ = trace 4 "I" ;
        val _ = if !debug then IAL_data := SOME (oset,(L,(assums,extras)))
                         else ();
        val rematched = map (rematch_disjunction to_avoid) L
        (* Resolve the new assumptions against the antecedents of those      *)
        (* auxiliary theorems that are implications; unchanged ones drop out *)
        val l = mapchanged (partial_resolve false rematched) (filter (is_imp_only o concl) extras) ;
        val e = thm_diff l extras ;
        (* Then resolve whatever the standard resolve functions can solve    *)
        val full_e = map (fn x => partial_resolve false (resolve_functions []) x
                                  handle UNCHANGED => x) e
                        handle e => wrapException "include_assumption_list" e
        (* Resolve existing assumptions in contraposed form; assumptions for *)
        (* which any step fails are silently dropped by mapfilter.           *)
        val full_a = mapfilter (CONTR2 o
                     partial_resolve true
                        (map (rematch_disjunction to_avoid o SPEC_ALL) L) o
                        CONTR1) assums

        val _ = trace 4 ("New theorems: " ^
                         xlist_to_string thm_to_string (full_a @ full_e) ^ "\n")
        (* Split full_e on whether a `T ==> P` antecedent could be rewritten *)
        (* away with IMP_CLAUSES (newa) or not (newe). NOTE(review): this    *)
        (* reading of mappartition is inferred from its use -- confirm.      *)
        val (newa,newe) = mappartition (CONV_RULE (STRIP_QUANT_CONV (FIRST_CONV
            (map REWR_CONV (CONJUNCTS (SPEC_ALL IMP_CLAUSES)))))) full_e
                        handle e => wrapException "include_assumption_list" e

        (* Genuinely new assumptions: CNF conjuncts not already present      *)
        val new_assums =
            thm_diff (flatten (map (CONJUNCTS o CONV_RULE CNF_CONV)
                              (newa @ full_a)))
                     (L @ assums)
        val _ = trace 1 ("#(" ^ int_to_string (length new_assums))
        val _ = trace 3 (":" ^ int_to_string (length L) ^
                        ":"   ^ int_to_string (length newe))
        val _ = trace 1 ")"

        (* Debug-only sanity check: no result may carry a hypothesis that    *)
        (* is outside oset.                                                  *)
        val _ = if !debug andalso
                   exists (not o null o C set_diff oset o hyp)
                          (new_assums @ newe)
                then raise (mkDebugExn "include_assumption_list"
                     ("Adding the following assumption\n" ^
                      "(derived from an auxillary theorem):\n" ^
                      thm_to_string
                       (first (not o null o C set_diff oset o hyp) (new_assums @ newe)) ^
                      "\nwill add the unwanted hypothesis to the set:\n" ^
                      term_to_string
                       (tryfind (hd o C set_diff oset o hyp)
                                (new_assums @ newe))))
                 else ()
in
        (* Fixed point reached when nothing new was produced; otherwise the  *)
        (* new material is fed back in as the next round's additions.        *)
        case (new_assums,thm_diff newe extras)
        of ([],[]) => (thm_set (L @ assums),thm_set extras)
        |  (NA,NE) => IAL oset (thm_set (NA @ L)) (NA @ assums,NE @ extras)
end
in
(* include_assumption_list L (assums,extras)                                 *)
(*     Adds the theorems L to the assumption / auxiliary theorem pair and    *)
(*     returns the updated pair. With nothing to add the pair is returned    *)
(*     untouched. Otherwise L is split into its CNF conjuncts and handed to  *)
(*     IAL, with the hypotheses of L and of the current assumptions as the   *)
(*     permitted hypothesis set.                                             *)
fun include_assumption_list [] AE =
    (trace 2 "->include_assumption_list\n" ; AE)
  | include_assumption_list L AE =
let val oset = mk_set (flatten (map hyp (L @ fst AE)))
    val new_assums =
        flatten (map (CONJUNCTS o CONV_RULE CNF_CONV) L)
in
    (trace 2 "->include_assumption_list\n" ;
     trace 3 ("Including: " ^ xlist_to_string thm_to_string L ^ "\n") ;
     (IAL oset new_assums AE)
          before (trace 3 "\n"))
end
end;
767
768(*****************************************************************************)
769(* Tests                                                                     *)
770(*
771fun rinclude [] AE = AE
772  | rinclude (x::xs) AE =
773    rinclude xs (include_assumption_list (map ASSUME x) AE)
774
775fun test AE tst =
776    full_resolve (resolve_functions (fst AE))
777                 (ASSUME (mk_imp(tst,mk_var("SUCCESS",bool))));
778
779val mat = mk_affirmation_theorems;
780
781rinclude [[``~(arg = [])``]] ([],list_case_proofs);
782
783rinclude [[``~((?c. a = SUC (SUC c)) /\ (?d e. b = d::e))``],
784          [``~(a = 0n)``],[``~(PRE a = 0n)``]]
785         ([],mat ``:num`` @ mat ``:'a list``);
786
787
788*)
789
790
791(*****************************************************************************)
792(* return_matches : thm list -> term                                         *)
793(*                        -> (string * thm list) * (string * thm) list       *)
794(*                                                                           *)
795(*     return_matches assumptions term  returns a list of instantiated       *)
796(*     !rewrites that match a term. It does this in three stages:            *)
797(*        1) Use HO_PART_MATCH to match the term                             *)
798(*        2) Match any uninstantiated encoders and instantiate               *)
799(*        3) Attempt to resolve any conditions using the assumptions         *)
800(*     Fails if any of the steps fail. This can include double-bind          *)
801(*     problems instantiating encoders.                                      *)
802(*                                                                           *)
803(*     Stage 2 is returned as the left part of the tuple for debugging       *)
804(*     purposes as it allows tracking of partial matches.                    *)
805(*                                                                           *)
806(* match_single_rewrite : thm list -> term ->                                *)
807(*                                        int * string * thm -> string * thm *)
808(*     As return_matches, except it matches a single rewrite.                *)
809(*                                                                           *)
810(*****************************************************************************)
811
(* Debugging aid: holds the (assumptions, term) arguments of the most recent *)
(* call to return_matches.                                                   *)
val return_matches_data = ref NONE;
813
local
(* Applies f to the theorem component of a (priority,name,thm) rewrite.      *)
fun mapthm f (a,b,c) = (a,b,f c)
(* compile match1 match2                                                     *)
(*     Merges the substitution match1 into match2. A binding whose redex is  *)
(*     already bound in match2 must agree on the residue, otherwise Match is *)
(*     raised (a "double bind"); agreeing duplicates are dropped.            *)
fun compile [] match2 = match2
  | compile ({redex,residue}::xs) match2 =
let val match1 = filter (curry op= redex o #redex) match2
in
    if all (curry op= residue o #residue) match1
    then (if null match1
          then (redex |-> residue)::(compile xs match2)
          else compile xs match2)
    else raise Match
end;
(* Merges a list of substitutions into one, raising Match on a conflict.     *)
fun compile_matches [] = []
  | compile_matches (x::xs) = compile x (compile_matches xs)
(* Raises the standard "wrong shape" debug exception for a rewrite theorem;  *)
(* s names the stage that detected the problem (may be empty).               *)
fun nomatch s thm =
let val name = if s = "" then "return_matches" else "return_matches (" ^ s ^ ")"
in
    raise (mkDebugExn name
("Theorem: " ^ thm_to_string thm ^ "\nis not of the form: \n" ^
"|- P0 /\\ ... /\\ Pn ==> \n" ^
"      (Q0 ==> encode a0 = A0) /\\ ... /\\ (Qm ==> encode am = Am) ==>\n" ^
"      (encode (F a0 ... an) = F A0 ... Am)"))
end
(* Splits `Q0 /\ ... ==> c` into ([Q0,...],c); a non-implication x gives     *)
(* ([],x).                                                                   *)
fun mimp x = if is_imp x then ((strip_conj ## I) (dest_imp x)) else ([],x)
(* instantiate_encoders (priority,name,thm)                                  *)
(*     Stage 2 of matching: for each encoding antecedent of the rewrite, the *)
(*     encoder appearing in it is matched against the generic encoder for    *)
(*     the type of its argument, and the combined substitution is applied to *)
(*     the theorem. Conflicting instantiations are reported as a double      *)
(*     bind.                                                                 *)
(*     NOTE(review): encoders is built with mapfilter and may therefore be   *)
(*     shorter than encoding_terms, in which case the map2 below pairs the   *)
(*     lists up incorrectly or fails -- confirm whether every antecedent is  *)
(*     guaranteed to have a generic encoder.                                 *)
fun instantiate_encoders (priority,name,thm) =
let     val (l,final) = (dest_imp_only o snd o dest_imp_only o concl) thm
                handle e => nomatch "instantiate_encoders" thm
        val encoding_terms = filter (not o curry op= T) (strip_conj l)
        val target = (type_of o lhs) final
                handle e => nomatch "instantiate_encoders" thm
        val encoders = mapfilter (gen_encode_function target o type_of o rand o
            lhs o snd o mimp) encoding_terms
        val (match1,match2) = unzip (map2 (fn e => fn t =>
            match_term (rator (lhs (snd (mimp t)))) e) encoders encoding_terms)
in
        (priority,name,
        INST_TY_TERM (compile_matches match1, compile_matches match2) thm)
        handle Match =>
        raise (mkDebugExn "return_matches (instantiate_encoders)"
                ("Theorem: " ^ thm_to_string thm ^ "\n" ^
                 "cannot have its encoders instantiated as they double bind"))
end
(* Ensure variables are not captured in any translations                     *)
fun adjust thm =
let     val vars = thm_frees thm
in
        INST (map (fn v => v |-> genvar (type_of v)) vars) thm
end;
(* Runs f, letting fatal exceptions through and turning every other failure  *)
(* into Empty after tracing it.                                              *)
fun pfs f x = f x handle e =>
    if isFatal e then raise e else
       (trace 4 ("Exception: " ^ polytypicLib.exn_to_string e) ; raise Empty)
(* Stage 3 of matching: put the antecedent into CNF and resolve all of it    *)
(* using the assumptions.                                                    *)
fun resolveit assums (p,s,thm) =
    (p,s,(full_resolve (resolve_functions assums) o
                CONV_RULE (LAND_CONV CNF_CONV)) thm)
in
(* match_single_rewrite assums term rewrite                                  *)
(*     Runs all three stages (higher-order match, encoder instantiation,     *)
(*     resolution) on a single (priority,name,thm) rewrite and returns       *)
(*     (name,thm). Any failure is re-raised through wrapException.           *)
fun match_single_rewrite assums term rewrite =
    (fn (a,b,c) => (b,c)) (resolveit assums
        (instantiate_encoders (mapthm (C (HO_PART_MATCH
              (lhs o snd o dest_imp_only o snd o dest_imp_only))
                           term o adjust) rewrite)))
    handle e => wrapException "match_single_rewrite" e
(* return_matches assums term                                                *)
(*     Looks up candidate rewrites for term in the !rewrites net and runs    *)
(*     the three stages on each, dropping the candidates that fail a stage.  *)
(*     Returns (stage 2 survivors, stage 3 survivors), the latter sorted     *)
(*     using the comparison p1 > p2 on priorities. The arguments are         *)
(*     recorded in return_matches_data for debugging.                        *)
fun return_matches assums term =
let     val _ = return_matches_data := SOME (assums,term);
        val _ = trace 2 "return_matches->\n"
        val _ = trace 3 (term_to_string (repeat rator (rand term)))
        val matches = Net.match term (!rewrites)
        val _ = trace 3 (":" ^ int_to_string (length matches))
        val matched = mapfilter (mapthm
                       (C (HO_PART_MATCH
                           (lhs o snd o dest_imp_only o snd o dest_imp_only))
                           term o adjust)) matches
        val _ = trace 3 (":" ^ int_to_string (length matched))
        val ematched = mapfilter (instantiate_encoders) matched
        val hmatched = mapfilter (pfs (resolveit assums)) ematched
        val _ = trace 3 (":" ^ int_to_string (length hmatched) ^ "\n")
in
        (map (fn (a,b,c) => (b,c)) ematched,
         map (fn (a,b,c) => (b,c))
             (sort (fn (p1,_,_) => fn (p2,_,_) => p1 > p2) hmatched))
end
end;
903
904(*****************************************************************************)
905(* PROPAGATE_THENC : thm list * thm list ->                                  *)
906(*                            (thm list * thm list -> term -> thm) ->        *)
907(*                                                     (string * thm) -> thm *)
908(*     PROPAGATE_THENC (assumptions,extras) next_conv (name,thm)             *)
909(*     Applies the rewrite given as (name,thm) by encoding all of the        *)
910(*     sub-encoders of the rewrite using next_conv                           *)
911(*                                                                           *)
912(*     When encoding a theorem of the form:                                  *)
913(*          |- (Q0 ==> enc a0 = A0) /\ ... /\ (Qm ==> enc am = Am) ==>       *)
914(*                enc (f a0 ... am) = F A0 ... Am                            *)
915(*     with a list of assumptions, A, and extra theorems, E, PROPAGATE_THENC *)
916(*     progresses by encoding each ai under the assumptions <(A u {Q0}),E>   *)
917(*     where <a,b> resolves extra assumptions using include_assumption_list. *)
918(*                                                                           *)
919(*     It is possible for Ai to be present in any aj, in which case, enc aj  *)
920(*     is processed after enc ai. It is also possible for Ai to be of the    *)
921(*     form: Fi x y z, where x y z appear in Ai. In such cases, a higher-    *)
922(*     order match is performed.                                             *)
923(*                                                                           *)
924(*     If Qi is empty, then [enc ai = Ai] |- detect Ai is added to the       *)
925(*     assumption list of all subsequent encodings and resolved at the end.  *)
926(*                                                                           *)
927(*****************************************************************************)
928
(* Debugging aid: stack of the argument triples of the PROPAGATE_THENC calls *)
(* currently in progress. An entry is pushed on entry and popped only when   *)
(* the call completes normally, so after a failure the stack still holds the *)
(* calls that were active.                                                   *)
val PROPAGATE_THENC_data = ref [];
930
931(*****************************************************************************)
(* append_detector : hol_type -> (term list * term) -> thm list -> thm list  *)
(*     append_detector target (L,e) A takes an encoding term representing    *)
(*     'L ==> encode x = X' and, provided L is null, derives the theorem:    *)
(*         |- (encode x = X) ==> detect X                                    *)
(*     and conses it onto the list of assumptions A. Callers undischarge it  *)
(*     to obtain [encode x = X] |- detect X, which is then provided as an    *)
(*     assumption for further encodings. If L is not null, or the theorem    *)
(*     cannot be derived, A is returned unchanged.                           *)
(*                                                                           *)
(*****************************************************************************)
940
(* Only an unconditional encoding (empty condition list) contributes a       *)
(* detector theorem. The derivation is best-effort: if any step fails the    *)
(* assumption list is returned untouched.                                    *)
fun append_detector target ([],e) A =
    (let val encoded_arg = rand (lhs e)
         (* detect (encode x), for the x that e encodes                      *)
         val detect_encoded =
             ISPEC encoded_arg
                   (FULL_ENCODE_DETECT_THM target (type_of encoded_arg))
         (* replace `encode x` by the right-hand side of e, assuming e       *)
         val detect_result =
             CONV_RULE (RAND_CONV (REWR_CONV (ASSUME e))) detect_encoded
     in  DISCH e detect_result :: A
     end
(* possible fix for backtracking:   (DISCH e (ASSUME e)) :: A*)
     handle _ => A)
  | append_detector target _ A = A;
948
949(*****************************************************************************)
950(* remove_head : thm -> thm -> thm                                           *)
951(*    remove_head M N takes theorems M and N, of the form:                   *)
952(*         |- A,    |- {A} u Q ==> P                                         *)
953(*    and returns the theorem:                                               *)
954(*         |- Q ==> P                                                        *)
955(*    If Q is empty, {}, then |- T ==> P is returned.                        *)
956(*                                                                           *)
957(*****************************************************************************)
958
(* remove_head r thm                                                         *)
(*     r is |- A and thm is |- A /\ Q ==> P (or |- A ==> P). thm is first    *)
(*     instantiated so that the first conjunct of its antecedent matches the *)
(*     conclusion of r; that conjunct is then rewritten to T using r. When   *)
(*     the antecedent is a conjunction the resulting `T /\ Q` is simplified  *)
(*     to Q, otherwise the result is |- T ==> P.                             *)
fun remove_head r thm =
let val antecedent = fst (dest_imp (concl thm))
                     handle e => wrapException "remove_head" e
    (* first conjunct, or the whole antecedent if it is not a conjunction    *)
    val first_conjunct = fst (dest_conj antecedent) handle _ => antecedent
    val instantiated =
        INST_TY_TERM (match_term first_conjunct (concl r)) thm
    (* rewrites an instance of the conclusion of r to T                      *)
    val to_truth = REWR_CONV (EQT_INTRO r)
in
    (* conjunction case first; any failure falls back to the single-term one *)
    CONV_RULE (LAND_CONV (LAND_CONV to_truth THENC
               REWR_CONV (CONJUNCT1 (SPEC_ALL AND_CLAUSES)))) instantiated
    handle _ =>
    CONV_RULE (LAND_CONV to_truth) instantiated
end
970
971(*****************************************************************************)
972(* HO_INST_TY_TERM :                                                         *)
973(*        {redex : term, residue : term} list *                              *)
974(*        {redex : hol_type, residue : hol_type} list -> thm -> thm          *)
975(* ho_inst_ty_term :                                                         *)
976(*        {redex : term, residue : term} list *                              *)
977(*        {redex : hol_type, residue : hol_type} list -> term -> term        *)
978(*     Takes a higher-order match, as returned from ho_match_term, used to   *)
979(*     instantiate the theorem or term from the term given and beta-converts *)
980(*     any higher-order terms. Eg.:                                          *)
981(*         HO_INST_TY_TERM [a |-> \b. A b] `t (a b)`                         *)
982(*                                         |- t ((\b. A b) a) = |- t (A a)   *)
983(*     and similar for ho_inst_ty_term.                                      *)
984(*                                                                           *)
985(*     Note: This is not technically correct, ALL lambda abstractions will   *)
986(*           be beta converted, so the following will be incorrect:          *)
987(*           ho_inst_ty_term (match_term ``f a`` ``(\b. K c b) a``) ``f a``  *)
988(*                                                                           *)
989(*****************************************************************************)
990
local
(* FIX match term mkc mka beta refl tm                                       *)
(*     Walks the un-instantiated term and the instantiated tm in parallel.   *)
(*     Wherever term is an application whose operator is bound by match to   *)
(*     a lambda abstraction, beta is applied to the corresponding part of    *)
(*     tm; other applications and abstractions are rebuilt with mkc / mka    *)
(*     and leaves are passed to refl. Supplying term-level or theorem-level  *)
(*     operations lets the same walk build either a term or an equation.     *)
fun FIX (match : {redex : term, residue : term} list) term
        mkc mka (beta:term -> 'a) (refl:term -> 'a) tm =
    if is_comb term
       (* `first` fails when the operator is not a redex of the match; that  *)
       (* failure simply means "no beta step here".                          *)
       then if (is_abs (#residue
                       (first (curry op= (rator term) o #redex) match))
                handle e => false)
               then beta tm
               else mkc (FIX match (rator term) mkc mka beta refl (rator tm),
                         FIX match (rand term) mkc mka beta refl (rand tm))
       else if is_abs term
               then mka (bvar tm)
                       (FIX match (body term) mkc mka beta refl (body tm))
               else refl tm;
(* First-order instantiation of a term: types first, then the substitution.  *)
fun inst_ty_term match term =
    subst (fst match) (inst (snd match) term);
in
(* HO_INST_TY_TERM (term_match,type_match) thm                               *)
(*     Instantiates thm with the match and beta-reduces the applications     *)
(*     introduced by instantiating a variable with an abstraction.           *)
(*     Hypotheses are discharged as a conjunction beforehand and             *)
(*     undischarged afterwards, so they are instantiated as well.            *)
fun HO_INST_TY_TERM (term_match,type_match) thm =
let val hyps = hyp thm
    val thm' = DISCH_ALL_CONJ thm
    (* thmb is fully instantiated; thma is type-instantiated only and        *)
    (* supplies the term in which the redexes of the match are still visible *)
    val thmb = INST_TY_TERM (term_match,type_match) thm'
    val thma = INST_TYPE type_match thm'
    (* an equation from the conclusion of thmb to its beta-reduced form      *)
    val rewrite = FIX term_match (concl thma)
            MK_COMB (fn bvar => fn body => MK_ABS (GEN bvar body))
            BETA_CONV REFL (concl thmb)
    val complete = EQ_MP rewrite thmb
in
    case hyps
    of [] => complete
    |  L => UNDISCH_CONJ complete
end handle e => wrapException "HO_INST_TY_TERM" e
(* ho_inst_ty_term (term_match,type_match) term                              *)
(*     The term-level counterpart of HO_INST_TY_TERM.                        *)
fun ho_inst_ty_term (term_match,type_match) term =
    FIX term_match (inst type_match term) mk_comb (curry mk_abs) beta_conv I
        (inst_ty_term (term_match,type_match) term)
    handle e => wrapException "ho_inst_ty_term" e
end;
1027
local
(* Exception raised when no remaining encoder can be processed because each  *)
(* one depends on the result of another.                                     *)
fun loop_exn name =
(mkDebugExn "PROPAGATE_THENC"
            ("The rewrite theorem " ^ name ^
             " appears to contain an encoding loop:\n" ^
             " ie. it has antecedents: \"(encode (f X) = Y) ..." ^
             " /\\ (encode (g Y) = X)\""))
(* Exception raised when the rewrite theorem does not have the shape         *)
(* expected by PROPAGATE_THENC.                                              *)
fun match_exn name = (mkDebugExn "PROPAGATE_THENC"
 ("Rewrite theorem: " ^ name ^ " is not of the form: \n" ^
  "[] |- P0 /\\ ... /\\ Pn ==> \n" ^
  "      (Q0 ==> encode a0 = A0) /\\ ... /\\ (Qm ==> encode am = Am)\n" ^
  "      ==> (encode (F a0 ... an) = F A0 ... Am)\n" ^
  "  where no encoders are present in A0 .. Am."))
(* Checks to determine whether the encoder requires results from other *)
(* encoders.                                                           *)
(* fvsr holds the right-hand sides still to be produced; Empty is raised if  *)
(* any of them occurs free in the left-hand side of e or in a condition.     *)
fun clear fvsr (L,e) =
    if e = T then (L,e)
    else (if all (fn tm => not (exists (C free_in tm) fvsr))
                     (lhs e :: L)
    then (L,e) else raise Empty)
(* Splits `Q0 /\ ... ==> c` into ([Q0,...],c); a non-implication x gives     *)
(* ([],x).                                                                   *)
fun mimp x = if is_imp x then ((strip_conj ## I) (dest_imp x)) else ([],x)
(* Applies remove_head continuously to remove detects from the theorem       *)
(* The list argument holds the theorems used to discharge the head terms.    *)
(* Once it is exhausted a remaining `T ==> P` is rewritten to P.             *)
fun remove_detect_hyps [] thm =
    (CONV_RULE (REWR_CONV (CONJUNCT1 (SPEC_ALL IMP_CLAUSES))) thm handle e =>
    if concl thm = T then TRUTH else
       raise (mkDebugExn "remove_detect_hyps" (
               "Could not remove all hypothesese from the theorem:\n" ^
                thm_to_string thm)))
  | remove_detect_hyps (r::rs) thm =
    (if head_term thm = T
        then if not (is_conj (fst (dest_imp_only (concl thm))))
                then remove_detect_hyps [] thm
                else remove_detect_hyps (r::rs) (remove_head TRUTH thm)
        else remove_detect_hyps rs (remove_head r thm))
    (* any failure while removing a head falls back to the final clean-up    *)
    handle e => remove_detect_hyps [] thm
in
(* PROPAGATE_THENC (assums,extras) conv (name,thm_pre)                       *)
(*     Applies the rewrite thm_pre by using conv to encode the left-hand     *)
(*     side of each encoding antecedent, then discharging the antecedents    *)
(*     one at a time against the results. The hypotheses of the result must  *)
(*     all come from assums. Any failure is re-raised through wrapException. *)
fun PROPAGATE_THENC (assums,extras) conv (name,thm_pre) =
let val _ = trace 2 "->PROPAGATE_THENC\n"
    val _ = trace 1 ("R(" ^ name ^ ")");
    (* record this call for debugging; popped again on normal completion     *)
    val _ = PROPAGATE_THENC_data :=
            ((assums,extras),conv,(name,thm_pre)) :: !PROPAGATE_THENC_data;
    (* rewrite the antecedent with GSYM CONJ_ASSOC to normalise the nesting  *)
    (* of its conjunctions                                                   *)
    val thm = CONV_RULE (LAND_CONV (PURE_REWRITE_CONV
                        [GSYM CONJ_ASSOC])) thm_pre
    (* encoders: one (conditions, encoding equation) pair per antecedent     *)
    val (encoders,final) =
        (map mimp o strip_conj ## I) (dest_imp_only (concl thm))
        handle e => raise (match_exn name);
    (* NOTE(review): fvs_right is not used below.                            *)
    val fvs_right = mapfilter (rhs o snd) encoders
    val target = type_of (rhs final)
        handle e => raise (match_exn name);

    (* The result may only depend on hypotheses of the given assumptions     *)
    fun check_hyp thmb =
        if null (set_diff (hyp thmb) (flatten (map hyp assums))) then thmb
        else raise (mkDebugExn "check_hyp"
                    ("PROPAGATE_THENC has altered the hypothesis set," ^
                     " the following terms have been added:\n" ^
                     xlist_to_string term_to_string
                       (set_diff (hyp thmb) (flatten (map hyp assums)))))

    (* enc A encs: encodes the numbered antecedents encs, each time picking  *)
    (* one that passes `clear` against the right-hand sides still pending.   *)
    (* A collects the detector theorems produced so far (see                 *)
    (* append_detector); they are supplied as extra assumptions and          *)
    (* discharged again on the result, ahead of the conditions L.            *)
    fun enc A [] = []
      | enc A encs =
    let val ((n,(L,e)),rest) =
            pick_e (loop_exn name)
                   (I ## clear (mapfilter (rhs o snd o snd) encs)) encs
        val conved = DISCH_LIST_CONJ (T::map (fst o dest_imp o concl) A)
                         (DISCH_LIST_CONJ (T::L)
                            (conv (include_assumption_list (map ASSUME L)
                                        (assums @ (map UNDISCH A),extras))
                                                  (lhs e)))
                     (* an antecedent that is just T needs no encoding       *)
                     handle E =>
                     if e = T then TRUTH else raise E
    in  (n,conved):: (enc (append_detector target (L,e) A) rest)
    end

    (* antecedents are numbered from 1 so that results can be matched back   *)
    (* to their position whatever order they were processed in               *)
    val recs = enc [] (enumerate 1 encoders)

    val list = strip_conj (fst (dest_imp_only (concl thm)))
                   handle e => raise (match_exn name);

    (* only results that are not implications are remembered for later       *)
    (* removal of detect hypotheses                                          *)
    fun check_cons x rs =
        if is_imp_only (concl x) then rs else x::rs;

    (* matchit: discharges antecedent n of the rewrite using the encoded     *)
    (* result x. The antecedent is higher-order matched against the result,  *)
    (* the match is applied to the theorem and to the remaining antecedents, *)
    (* and the instantiated antecedent is removed with HO_MATCH_MP.          *)
    fun matchit ((n,x),((list,removed),thm)) =
    let val x' = remove_detect_hyps removed x
        val x'' = CONV_RULE (FIRST_CONV [
                   REWR_CONV (CONJUNCT1 (SPEC_ALL IMP_CLAUSES)),
                   LAND_CONV (REWR_CONV (CONJUNCT1 (SPEC_ALL AND_CLAUSES))),
                   ALL_CONV]) x'
                  handle e => (if concl x' = T then x' else raise e)
        val match = ho_match_term [] Term.empty_tmset (el n list) (concl x'');
        val thm' = HO_INST_TY_TERM match thm
        val list' = map (ho_inst_ty_term match) list
    in  ((list',check_cons x'' removed),
                 HO_MATCH_MP (DISCH (el n list') thm') x'')
    end handle e => wrapException "matchit" e
in
    (check_hyp (snd (foldl matchit ((list,[]),UNDISCH_CONJ thm) recs)) before
     (PROPAGATE_THENC_data := tl (!PROPAGATE_THENC_data)))
    handle e => wrapException "PROPAGATE_THENC" e
end
end
1128
1129(*****************************************************************************)
1130(* backchain_rewrite : int -> thm list -> term -> thm                        *)
1131(*     Attempts to prove (or disprove) the term given by repeatedly applying *)
1132(*     rewrites in the list:                                                 *)
1133(*         backchain_rewrite n RR P =                                        *)
1134(*            a) |- A ==> P = (Q0 /\ Q1 ... ==> P0 \/ ...) /\ ...            *)
1135(*               ==> !a in A. backchain_rewrite (n + 1) RR a                 *)
1136(*                   !i. ?j. backchain_rewrite (n + 1) (RR u {Q0,Q1..}) Pi   *)
1137(*                                                                           *)
1138(*     This ONLY operates on boolean valued theorems. It was designed to     *)
1139(*     solve very simple problems relatively quickly. It can certainly be    *)
1140(*     improved, but does the job for now...                                 *)
1141(*                                                                           *)
1142(*****************************************************************************)
1143
(* DISCHL_CONJ [h1,...,hn] thm                                               *)
(*     Discharges the (non-empty) list of terms from thm as one conjunctive  *)
(*     antecedent: the last term is discharged first, and each earlier term  *)
(*     is discharged and then folded into the antecedent with AND_IMP_INTRO. *)
fun DISCHL_CONJ hs thm =
let val innermost = DISCH (last hs) thm
    fun add_conjunct (t,acc) =
        CONV_RULE (REWR_CONV AND_IMP_INTRO) (DISCH t acc)
in
    foldr add_conjunct innermost (butlast hs)
end;
1147
(* SOME_CONJ_CONV conv term                                                  *)
(*     Applies conv to term itself or, failing that, to the first place      *)
(*     within the conjunction structure of term at which it succeeds (left   *)
(*     conjunct before right). Fails via NO_CONV if conv applies nowhere.    *)
fun SOME_CONJ_CONV conv term =
    conv term
    handle _ =>
        if not (is_conj term)
           then NO_CONV term
           else (LAND_CONV (SOME_CONJ_CONV conv) term
                 handle _ => RAND_CONV (SOME_CONJ_CONV conv) term);
1154
(* Recursion limit for backchain_rewrite, which raises Empty once its depth  *)
(* argument exceeds this value.                                              *)
val max_depth = ref 8;
1156
local
(* The term left after stripping outer universal quantifiers and then outer  *)
(* implications.                                                             *)
fun matrix tm = snd (strip_imp_only (snd (strip_forall tm)))
in
(* Left- and right-hand sides of the equation at the core of a term of the   *)
(* form  !x.. A ==> ... ==> (l = r).                                         *)
fun LHS tm = lhs (matrix tm)
fun RHS tm = rhs (matrix tm)
end
(* Simplifies a conclusion by rewriting it, at the top, with the first       *)
(* clause of EQ_CLAUSES that applies.                                        *)
val BOOL_RULE =
    let val eq_clause_convs =
            map REWR_CONV (CONJUNCTS (SPEC_ALL EQ_CLAUSES))
    in  CONV_RULE (FIRST_CONV eq_clause_convs)
    end;
(* Turns a theorem into an equation with a boolean constant: EQF_INTRO is    *)
(* tried first and EQT_INTRO is used if that fails.                          *)
fun UNBOOL_RULE thm =
    (EQF_INTRO thm) handle _ => EQT_INTRO thm
(* As strip_imp, but with every antecedent split into its conjuncts.         *)
fun full_strip_imp tm =
let val (antecedents,consequent) = strip_imp tm
in  (map strip_conj antecedents,consequent)
end
1164
(*****************************************************************************)
(* perform_rewrite, apply_rewrite, backchain_rewrite :                       *)
(*     Mutually recursive backchaining search over a list of rewrites RR.    *)
(*                                                                           *)
(* perform_rewrite depth RR term rewrite                                     *)
(*     Instantiates 'rewrite' so that its left hand side matches 'term',     *)
(*     proves every conjunct of every antecedent of the instance with        *)
(*     backchain_rewrite (at depth + 1) and discharges the antecedents with  *)
(*     MP. The conjuncts of the resulting right hand side are then tried in  *)
(*     turn with apply_rewrite until BOOL_RULE succeeds on the result.       *)
(*     Raises Empty when no conjunct is left to try.                         *)
(*                                                                           *)
(* apply_rewrite depth RR conj                                               *)
(*     Splits 'conj' into antecedents (as conjunct lists) and the disjuncts  *)
(*     of its consequent, and tries to solve a disjunct (smallest first, by  *)
(*     term_size of its consequent) with the antecedents assumed. If no      *)
(*     disjunct is solved, it tries to solve one of the antecedents instead  *)
(*     and rewrites 'conj' with the result.                                  *)
(*                                                                           *)
(* backchain_rewrite depth RR term                                           *)
(*     Entry point of the search. Raises Empty if depth exceeds !max_depth.  *)
(*     An implication is solved by assuming the antecedent (its conjuncts    *)
(*     are added to RR) and discharging it afterwards. F and T are solved    *)
(*     directly. If no rewrite matches a non-negated term, its negation is   *)
(*     tried. Otherwise the matching rewrites are tried, those whose right   *)
(*     hand side is T first, each group ordered by number of hypotheses.     *)
(*****************************************************************************)

fun perform_rewrite depth RR term rewrite =
let val rewrite_thm = HO_PART_MATCH LHS rewrite term
    val (hyp_set,_) = (map (map (backchain_rewrite (depth + 1) RR)) ## I)
                      (full_strip_imp (concl rewrite_thm))
    (* Discharge the antecedents outermost first: for                       *)
    (* |- H1 ==> H2 ==> l = r the proof of H1 must be given to MP before    *)
    (* the proof of H2, hence foldl. A foldr would give MP the last proof   *)
    (* first and fail for any rewrite with more than one antecedent.        *)
    val finished =
        if null hyp_set
           then rewrite_thm
           else foldl (uncurry (C MP)) rewrite_thm
                      (map LIST_CONJ hyp_set)
    val poss = strip_conj (rhs (concl finished))
    (* Rewrite one conjunct p on the right hand side of thm, then simplify. *)
    fun single p thm =
            RIGHT_CONV_RULE
            (SOME_CONJ_CONV (REWR_CONV (UNBOOL_RULE
                        (apply_rewrite depth RR p))) THENC
             REWRITE_CONV []) thm
    (* Try each conjunct in turn; a conjunct that was rewritten but did not *)
    (* finish the proof is kept (x), one that failed is skipped (thm).      *)
    fun all [] thm = raise Empty
      | all (p::ps) thm =
    (let val x = single p thm in (BOOL_RULE x  handle _ => all ps x) end)
    handle _ => all ps thm;
in
    all poss finished
end
and apply_rewrite depth RR conj =
let val (terms,disjs) = (map strip_conj ## strip_disj) (strip_imp_only conj)
    val sortf = sort (fn a => fn b => term_size (snd (strip_imp_only a))
                                    < term_size (snd (strip_imp_only b)))
    (* If no disjunct can be solved, fall back to the theorem |- ~F.        *)
    val solved = tryfind_e Empty
                         (backchain_rewrite (depth + 1)
                         (map ASSUME (flatten terms) @ RR)) (sortf disjs)
                         handle e => DECIDE ``~F``
in
    if not (exists (curry op= (concl solved)) (disjs)) then
       if dest_neg (concl solved) = conj andalso null (flatten terms)
          then solved
          else BOOL_RULE ((PURE_ONCE_REWRITE_CONV [
               UNBOOL_RULE (tryfind_e Empty (backchain_rewrite (depth + 1) RR)
                                    (flatten terms))] THENC
           REWRITE_CONV []) conj)
       else foldr (uncurry DISCHL_CONJ)
              (BOOL_RULE ((ONCE_REWRITE_CONV [UNBOOL_RULE solved] THENC
                    REWRITE_CONV []) (list_mk_disj disjs))) terms
end
and backchain_rewrite depth RR term =
let val _ = trace 3 ("#" ^ int_to_string depth)
    val _ = if depth > !max_depth then raise Empty else ()
    val applicable = filter (can (C match_term term o LHS o concl)) RR
    val (quick,slow) = partition (curry op= T o RHS o concl) applicable
    val sortf = sort (fn a => fn b => length (hyp a) < length (hyp b))
    val quick_ordered = sortf quick
    val slow_ordered = sortf slow
in
    if is_imp_only term
       then DISCH (fst (dest_imp term))
                  (backchain_rewrite depth
                      (map UNBOOL_RULE (CONJUNCTS
                           (ASSUME (fst (dest_imp term))))
                           @ RR)
                      (snd (dest_imp term))) else
    if term = F then (trace 3 "F" ; DECIDE ``~F``) else
    if term = T then (trace 3 "T" ; TRUTH) else
    if null applicable andalso not (is_neg term)
       then backchain_rewrite depth RR (mk_neg term)
       else
    tryfind_e Empty (perform_rewrite depth RR term) quick_ordered handle _ =>
    tryfind_e Empty (perform_rewrite depth RR term) slow_ordered
end
1231
1232(*****************************************************************************)
1233(* ATTEMPT_REWRITE_PROOF : (thm list * thm list) ->                          *)
1234(*                                       (string * thm) -> (string * thm)    *)
1235(*    Attempts to prove a term by rewriting using the assumptions and extras *)
1236(*    The assumptions are converted as follows:                              *)
1237(*        A |- encode Y = y  --> A |- Y = decode y                           *)
1238(*        A |- X = Y         --> A |- X = Y : bool                           *)
1239(*        A |- X = Y         --> A |- P X = P Y                              *)
1240(*        A |- P             --> A |- P = T                                  *)
1241(*    The extras are converted as follows:                                   *)
1242(*        A |- P ==> Q       ==> A |- P ==> Q = T                            *)
1243(*        A |- P ==> X = Y   ==> A |- P ==> X = Y : bool                     *)
1244(*        A |- P ==> X = Y   ==> A |- P ==> Q X = Q Y                        *)
1245(*                                                                           *)
1246(*****************************************************************************)
1247
(* fix_extra (thm',A) : converts the extra theorem thm' into backchaining    *)
(* rewrites and conses them onto A. Boolean equations are added in both      *)
(* directions; other equations are first lifted under a fresh predicate      *)
(* variable; non-equations are turned into an equation with UNBOOL_RULE.     *)
(* If the conversion fails, A is returned unchanged.                         *)
fun fix_extra (thm',A) =
let val thm = SPEC_ALL thm'
    val (imps,eq) = strip_imp_only (concl thm)
    (* Put the stripped antecedents back in front of th. *)
    fun redischarge th = foldr (uncurry DISCH) th imps
    fun convert () =
        if not (is_eq eq)
           then [redischarge (UNBOOL_RULE (UNDISCH_ALL_ONLY thm))]
        else if type_of (lhs eq) = bool
           then [thm,GSYM thm]
        else let val r = redischarge
                             (AP_TERM (genvar (type_of (lhs eq) --> bool))
                                      (UNDISCH_ALL_ONLY thm))
             in  [r,GSYM r]
             end
in
    (convert () @ A) handle _ => A
end;
1262
(* ap_decode thm : for thm of the form |- encode Y = y, applies the decode   *)
(* function for the types involved to both sides and rewrites the left hand  *)
(* side with the corresponding encode/decode theorem.                        *)
fun ap_decode thm =
let val encoded = lhs (concl thm)
    val target = type_of encoded
    val t = type_of (rand encoded)
    val decoded = AP_TERM (get_decode_function target t) thm
in
    CONV_RULE (LAND_CONV (REWR_CONV (FULL_ENCODE_DECODE_THM target t)))
              decoded
end;
1271
(* fix_assum (thm,A) : converts the assumption thm into backchaining         *)
(* rewrites and conses them onto A. An equation between an encoded term and  *)
(* a variable is decoded with ap_decode; a boolean equation is used as is;   *)
(* any other equation is lifted under a fresh predicate variable. Equations  *)
(* are added in both directions. Non-equations go through UNBOOL_RULE.       *)
(* If the conversion fails, A is returned unchanged.                         *)
fun fix_assum (thm,A) =
let val c = concl thm
    fun as_rewrite () =
        if is_encoded_term (LHS c) andalso is_var (RHS c)
           then ap_decode thm
        else if type_of (lhs c) = bool
           then thm
        else AP_TERM (genvar (type_of (lhs c) --> bool)) thm
in
    (if is_eq c
        then let val r = as_rewrite () in r::GSYM r::A end
        else UNBOOL_RULE thm::A)
    handle _ => A
end
1283
(* Rewrites that are always available to the backchainer:                    *)
(* ATTEMPT_REWRITE_PROOF starts from this list and adds the converted        *)
(* extras and assumptions to it.                                             *)
val standard_backchain_thms =
    ref [COND_RAND,COND_RATOR,
         DECIDE ``(if a then b else c) = (a ==> b) /\ (~a ==> c)``,
         DECIDE ``~a = (a ==> F)``];
1288
(* ATTEMPT_REWRITE_PROOF (assums,extras) (string,thm) : thm is expected to   *)
(* be an implication. Every conjunct of its antecedent is proved by          *)
(* backchaining over the rewrites built from the standard theorems, the      *)
(* extras and the assumptions; the antecedent is then eliminated with        *)
(* MATCH_MP. The name is passed through unchanged. Any exception raised by   *)
(* the search is re-raised after tracing "!!".                               *)
fun ATTEMPT_REWRITE_PROOF (assums,extras) (string,thm) =
let val extra_rewrites = foldl fix_extra (!standard_backchain_thms) extras
    val RR = foldl fix_assum extra_rewrites assums
    val goals = strip_conj (fst (dest_imp_only (concl thm)))
    val _ = (trace 3 "backtracking...\n" ;
             trace 4 (xlist_to_string thm_to_string RR) ;
             trace 4 "\n")
    fun prove goal = backchain_rewrite 0 RR goal before trace 3 "\n"
    val proofs = map prove goals
                 handle e => (trace 3 "!!\n" ; raise e)
in
   (string:string,MATCH_MP thm (LIST_CONJ proofs))
end
1301
1302(*****************************************************************************)
1303(* PROPAGATE_ENCODERS_CONV : (thm list * thm list) -> term -> thm            *)
1304(*                                                                           *)
1305(*    PROPAGATE_ENCODERS_CONV (assumptions,extras) ``encode M``              *)
1306(*    propagates encoders through the term M under the assumptions given     *)
1307(*    with additional theorems to aid resolution given in extras. Theorems   *)
1308(*    for rewriting are found using 'return_matches' and the reference       *)
1309(*    functions !conversions.                                                *)
1310(*                                                                           *)
1311(*    If no rewrite matches a term then:                                     *)
1312(*       a) If the term is matched by function in !terminals then REFL term  *)
1313(*          is returned.                                                     *)
1314(*       b) If the term is matched by a polytypic theorem, ie. a function    *)
1315(*          in !polytypic_rewrites returns a theorem, then if that theorem   *)
1316(*          is not already present in rewrites another attempt is made,      *)
1317(*          otherwise, an exception is generated                             *)
1318(*       c) If none of the above occurs, an exception is generated.          *)
1319(*                                                                           *)
1320(*    Controlled through the references:                                     *)
1321(*            !rewrites  : thm list                                          *)
1322(*                List of propagation theorems in standard form.             *)
1323(*            !conversions : (int * string * (term -> thm)) list             *)
1324(*                List of conversions (results in conditional form)          *)
1325(*            !polytypic_rewrites : (int * string * (term -> thm)) list      *)
1326(*                List of polytypic propagation theorems                     *)
(*            !terminals : (string * (thm list -> term -> bool)) list        *)
1328(*                List of functions indicating stoppage.                     *)
1329(*    These are controlled through the functions:                            *)
1330(*            add_polytypic_rewrite, add_standard_conversion,                *)
1331(*            add_conditional_conversion, add_terminal                       *)
1332(*            and the corresponding remove_... functions                     *)
1333(*        The add_extended_terminal function can also determine the list of  *)
1334(*        assumptions when deciding whether to stop encoding.                *)
1335(*                                                                           *)
1336(*    Note: Free variables in 'extras' are instantiated to avoid variable    *)
1337(*    capture.                                                               *)
1338(*                                                                           *)
1339(*****************************************************************************)
1340
local
(* A rule of 78 dashes, used to separate the individual failure reports. *)
val rule_chars = for 1 78 (K #"-");
(* The rule with its first characters overwritten by "Failure: n". *)
fun banner n =
let val label = explode ("Failure: " ^ int_to_string n)
    val padding = List.take(rule_chars,length rule_chars - length label)
in  implode (label @ padding)
end;
(* Joins numbered failure descriptions, each under its own banner. *)
fun mconcat [] = String.implode rule_chars
  | mconcat ((n,s)::rest) = banner n ^ "\n" ^ s ^ "\n" ^ mconcat rest;
(* Partially resolves each partially matched theorem against the functions *)
(* resolved from the assumptions; an UNCHANGED theorem is kept as it is.   *)
fun check_failure ((ematched,assums),term) =
let fun resolve (name,thm) =
        (name,partial_resolve true (resolve_functions assums) thm
              handle UNCHANGED => thm)
in  (map resolve ematched,assums,term)
end;
(* Renders one failure: the term, the assumptions and any partial matches. *)
fun describe_single_failure (pmatched,assums,term) =
    String.concat
      ["  Term: ",term_to_string term,"\n",
       "  Assumptions: ",xlist_to_string thm_to_string assums,
       "\n",
       if null pmatched then "" else
       "  ... However, the following list of theorems partially matched:\n" ^
       xlist_to_string (fn (x,y) => x ^ ": " ^ thm_to_string y) pmatched]
(* Set equality and inclusion, with terms compared up to mutual matching. *)
fun op_set_eq f a1 a2 =
    null (op_set_diff f a1 a2) andalso null (op_set_diff f a2 a1)
fun eq_term t1 t2 = can (match_term t1) t2 andalso can (match_term t2) t1
fun eq_thm thm1 thm2 =
    eq_term (concl thm1) (concl thm2) andalso
    op_set_eq eq_term (hyp thm1) (hyp thm2)
fun subset a1 a2 = null (op_set_diff eq_thm a1 a2)
fun ssubset a1 a2 = subset a1 a2 andalso not (subset a2 a1)
(* Relation used by reduce to discard redundant failures on the same term. *)
fun supercedes (p1,a1,t1) (p2,a2,t2) =
let val thms1 = map snd p1
    val thms2 = map snd p2
in  (t2 = t1) andalso
    (ssubset thms1 thms2 orelse
     (subset thms1 thms2 andalso subset a1 a2))
end
(* Drops x when supercedes x holds for a later failure; otherwise keeps x *)
(* and removes every later failure y for which supercedes y x holds.      *)
fun reduce failures =
    case failures
    of [] => []
    |  (x::rest) =>
       if exists (supercedes x) rest
          then reduce rest
          else x::reduce (filter (fn y => not (supercedes y x)) rest)
in
(* describe_match_failure L : always raises. Raises Empty when no failure *)
(* remains after reduction, otherwise a standard exception listing them.  *)
fun describe_match_failure L =
let val header = "No rewrite matched the following:\n"
    fun fail msg = raise (mkStandardExn "PROPAGATE_ENCODERS_CONV" msg)
in
    case map describe_single_failure (reduce (map check_failure L))
    of [] => raise Empty
    |  [x] => fail (header ^ x)
    |  descriptions => fail (header ^ mconcat (enumerate 0 descriptions))
end
end;
1391
1392
(* Registries consulted by PROPAGATE_ENCODERS_CONV.                          *)
(*   terminals          : named predicates over (assumptions, term); when    *)
(*                        one holds, encoding stops at that term.            *)
(*   polytypic_rewrites : (priority, name, generator) triples; a generator   *)
(*                        maps a term to a propagation theorem.              *)
(*   conversions        : (priority, name, conversion) triples.              *)
val terminals = ref ([] : (string * (thm list -> term -> bool)) list);
val polytypic_rewrites = ref ([] : (int * string * (term -> thm)) list);
val conversions = ref ([] : (int * string * (term -> thm)) list);
(* Empties the rewrite net and all three registries above. *)
fun clear_rewrites () =
    (rewrites := Net.empty ;
     polytypic_rewrites := [] ;
     conversions := [] ;
     terminals := []);

(* Holds the (assumptions,extras) pair and the term most recently given to   *)
(* the internal worker of PROPAGATE_ENCODERS_CONV. It is reset to NONE when  *)
(* a call completes normally, so after a failure it still holds the last     *)
(* term that was attempted.                                                  *)
val propagate_encoders_conv_data =
       ref (NONE : ((thm list * thm list) * term) option);
1404
local
(* Carries the accumulated failures: each is a list of partially matched,   *)
(* named theorems together with the assumptions and the term in question.   *)
exception MatchFailure of (((string * thm) list * thm list) * term) list
val this_function = "PROPAGATE_ENCODERS_CONV"
(* True if 'previous' (a net of (priority,name,theorem) entries) already    *)
(* holds an entry with the same priority and an alpha-equivalent conclusion *)
fun exists_polytypic_theorem previous (priority,name,theorem) =
let val matches = Net.match (lhs (snd (strip_imp (concl theorem)))) previous
in
    exists (fn (p,n,t) => (p = priority) andalso
                       (aconv (concl theorem) (concl t))) matches
end handle e => wrapException "exists_polytypic_theorem" e
(* Runs every polytypic generator on 'term', discards the theorems that are *)
(* already in !rewrites, and adds the remaining theorem ranked first by the *)
(* priority sort as a conditional rewrite before calling 'success'. Calls   *)
(* 'failure' when no new theorem was generated.                             *)
fun tryadd_polytypic_theorem failure success term =
let val previous = !rewrites
    val polys = mapfilter (fn (p,s,f) => (p,s,f term)) (!polytypic_rewrites)
        handle e => wrapException "tryadd_polytypic_theorem" e
    val new_polys = filter (not o exists_polytypic_theorem previous) polys
        handle e => wrapException "tryadd_polytypic_theorem" e
    val sorted = sort (fn (p1,_,_) => fn (p2,_,_) => p1 > p2) new_polys
in
    case sorted
    of [] => failure ()
    |  ((priority,name,thm)::_) =>
       ((trace 1 ("A(" ^ name ^ ")") ;
         trace 2 ("Polytypic theorem: " ^ thm_to_string thm) ;
         (add_conditional_rewrite priority name thm
         handle e => wrapException "tryadd_polytypic_theorem" e) ;
         success ()))

end
(* Replaces the free variables of an extra theorem by fresh variables (see  *)
(* the note on variable capture in the header comment).                     *)
fun fix_extra e =
let val vs = thm_frees (SPEC_ALL e)
    val vs' = map (fn x => x |-> (genvar o type_of) x) vs
in
    INST vs' e
end
(* SOME (REFL term) if a registered terminal accepts the term under the     *)
(* assumptions, NONE otherwise (including when a terminal raises).          *)
fun terminate (assums,extras) term =
let val (s,_) = first (fn (x,y) => y assums term) (!terminals)
    val _ = trace 1 ("T(" ^ s ^ ")")
in
    SOME (REFL term)
end handle _ => NONE
(* Tries the fully matched rewrites in order; the failures of each attempt  *)
(* are accumulated and raised together if every match fails.                *)
fun try_all_matches AE [] exns = raise (MatchFailure exns)
  | try_all_matches AE (match::matches) exns =
    PROPAGATE_THENC AE PEC match
    handle MatchFailure L =>
    (try_all_matches AE matches (exns @ L))
(* Tries the partially matched rewrites, proving their conditions with      *)
(* ATTEMPT_REWRITE_PROOF. A rewrite whose attempt raised Empty is recorded  *)
(* in the failure that is reported for this term.                           *)
and try_backchain_matches AE [] failure exns
    = raise (MatchFailure (failure :: exns))
  | try_backchain_matches AE (ematch::matches) ((fails,assums),term) exns =
    PROPAGATE_THENC AE PEC
        (ATTEMPT_REWRITE_PROOF AE ematch)
    handle Empty => try_backchain_matches AE matches
                    ((ematch::fails,assums),term) exns
         | MatchFailure L => try_backchain_matches AE matches
                    ((fails,assums),term) (exns @ L)
(* The worker: stops at terminals, otherwise collects the matching rewrites *)
(* and conversions. With no full match it first tries to add a polytypic    *)
(* theorem (and then retries the term), falling back to backchaining over   *)
(* the partial matches.                                                     *)
and PEC (AE as (assums,extras)) term =
    case (terminate AE term)
    of SOME thm => thm
    |  NONE =>
    let val _ = propagate_encoders_conv_data := SOME (AE,term)
        val (ematched,matches) = return_matches assums term
        val cmatches =
            mapfilter (fn (p,n,func) => match_single_rewrite assums
                                        term (p,n,func term)) (!conversions)
    in
        case (matches @ cmatches)
        of [] => (tryadd_polytypic_theorem
                  (fn () => (try_backchain_matches AE ematched
                                 (([],assums),term) []))
                  (fn () => PEC AE term)
                  term)
        |  L => try_all_matches AE L []
    end
in
(* Entry point: scrubs the rewrites, renames the free variables of the      *)
(* extras and runs the worker. A MatchFailure is turned into a descriptive  *)
(* exception by describe_match_failure.                                     *)
fun PROPAGATE_ENCODERS_CONV AE term =
    ((scrub_rewrites() ;
     PEC ((I ## map fix_extra) AE) term) before
      (trace 1 "\n" ; propagate_encoders_conv_data := NONE))
    handle (MatchFailure L) => describe_match_failure L
end;
1483
(* Registers a terminal predicate that also receives the assumption list.    *)
(* Raises a standard exception if a terminal with this name already exists.  *)
fun add_extended_terminal (s,func) =
let fun same_name (existing,_) = (s = existing)
in
    if not (exists same_name (!terminals))
       then terminals := (s,func) :: (!terminals)
       else raise (mkStandardExn "add_terminal"
                    ("Terminal " ^ s ^ " already exists!"))
end;
1489
(* Registers a terminal predicate that ignores the assumption list. *)
fun add_terminal (s,func) = add_extended_terminal (s,fn _ => func);
1491
(* Removes every terminal registered under the name s. *)
fun remove_terminal s =
let fun other_name (existing,_) = not (s = existing)
in  terminals := filter other_name (!terminals)
end
1494
(* Registers a polytypic rewrite generator under a unique name.              *)
(* Raises a standard exception if the name is already in use.                *)
fun add_polytypic_rewrite priority name func =
let fun same_name (_,existing,_) = (name = existing)
in
    if not (exists same_name (!polytypic_rewrites))
       then polytypic_rewrites := (priority,name,func) :: (!polytypic_rewrites)
       else raise (mkStandardExn "add_polytypic_rewrite"
                    ("Polytypic rewrite " ^ name ^ " already exists!"))
end
1500
(* Removes every polytypic rewrite registered under the name s. *)
fun remove_polytypic_rewrite s =
let fun other_name (_,existing,_) = not (s = existing)
in  polytypic_rewrites := filter other_name (!polytypic_rewrites)
end
1504
(* Registers a conversion under a unique name; its result is passed through  *)
(* conditionize_rewrite before it is used.                                   *)
(* Raises a standard exception if the name is already in use.                *)
fun add_standard_conversion priority name func =
let fun same_name (_,existing,_) = (name = existing)
in
    if not (exists same_name (!conversions))
       then conversions := (priority,name,conditionize_rewrite o func) ::
                           (!conversions)
       else raise (mkStandardExn "add_standard_conversion"
                    ("Conversion " ^ name ^ " already exists!"))
end
1511
(* Registers a conversion under a unique name; its result is used as given.  *)
(* Raises a standard exception if the name is already in use.                *)
fun add_conditional_conversion priority name func =
let fun same_name (_,existing,_) = (name = existing)
in
    if not (exists same_name (!conversions))
       then conversions := (priority,name,func) ::
                           (!conversions)
       else raise (mkStandardExn "add_conditional_conversion"
                    ("Conversion " ^ name ^ " already exists!"))
end
1518
(* Removes every conversion registered under the name s. *)
fun remove_conversion s =
let fun other_name (_,existing,_) = not (s = existing)
in  conversions := filter other_name (!conversions)
end
1521
1522
1523(*****************************************************************************)
1524(* Case processing:                                                          *)
1525(*                                                                           *)
1526(* find_comb : int list -> term -> term                                      *)
1527(*     Repeatedly finds the nth sub-term, eg:                                *)
1528(*        find_comb [1,2] ``f (g a b) c`` = ``b``                            *)
1529(*                                                                           *)
1530(* outermost_constructor      : term -> thm list -> (term -> term) option    *)
1531(*     Returns a function that finds the outermost leftmost constructed term *)
1532(*     such that the corresponding value in the template term is a variable  *)
1533(*     Eg. outermost_constructor ``SEG (SUC 0) 0 a`` (tl (CONJUNCTS SEG))    *)
1534(*         returns that function that returns the last argument from terms   *)
1535(*         of the form: SEG a b c = d                                        *)
1536(*     Fails if the theorems supplied are not function clauses, or if the    *)
1537(*     the function constants found in the term and clauses are not          *)
1538(*     equivalent up to the renaming of type variables                       *)
1539(*                                                                           *)
1540(* group_by_constructor                                                      *)
1541(*  :term -> (term -> term) -> thm list -> hol_type * (term * thm list) list *)
1542(*     Groups clauses to lists matching the outer most left most constructor *)
1543(*     The function supplied strips the constructor out of the term          *)
1544(*     Returns a list of clause * thm where if:                              *)
1545(*        function  = ``f a b c`` then                                       *)
(*        clause(i) = ``f (Ci x y z) a b c``                                 *)
1547(*     Fails if no clauses are given, the function used fails or returns     *)
1548(*     a term that is not a compound type, or the left hand sides of the     *)
1549(*     clauses and the function term supplied are inconsistent.              *)
1550(*                                                                           *)
1551(* alpha_match_clauses :                                                     *)
1552(*        : (term -> term) -> (thm * term list) list -> term list * thm list *)
1553(*     Takes a list of function clauses and a list of missing clauses and    *)
1554(*     alpha converts them to match each other                               *)
1555(*     -- The list of missing clauses is simply concatenated, the clauses    *)
1556(*        are matched using their left hand sides with the sub-term          *)
1557(*        indicated by the function skipped.                                 *)
1558(*     Fails if the function given fails or the clauses are not of the       *)
1559(*     correct form, or if the left hand side of the clauses differ by more  *)
1560(*     than the sub-term located.                                            *)
1561(*                                                                           *)
1562(* condense_missing : (term -> term) -> term list -> term list               *)
1563(*     Takes a list of missing clauses and determines whether the set of     *)
1564(*     constructors determined by applying the function is complete, ie.     *)
1565(*     contains all the constructors for that type:                          *)
1566(*         condense_missing (rand o lhs) [``f (SUC n)``,``f 0``] = [``f n``] *)
1567(*     Notes: All arguments to constructors must be free variables           *)
1568(*            The function given should always perform 'lhs' (historical)    *)
1569(*                                                                           *)
1570(* clause_to_case : thm -> thm * term list                                   *)
1571(*     Converts a function defined using clause structure to use case        *)
1572(*     Returns a list of missing clauses                                     *)
1573(*                                                                           *)
1574(*     Algorithm:                                                            *)
1575(*       0. Group function clauses by the leftmost outermost constructor.    *)
1576(*       1. Recursively apply this algorithm to these groups                 *)
1577(*          --> For a type with n constructors should have n clauses         *)
1578(*       2. Alpha convert the clauses so that bound variables all match up   *)
1579(*       3. Fully specialise then generalise for bound variables in the      *)
1580(*          constructor in question, in left to right fashion.               *)
1581(*       4. Use Modus Ponens and the "func_case" theorem:                    *)
1582(*              |-   (!.. f (C0 ..) = A0) /\ .. /\ (!.. f (Cn ..) = An)      *)
1583(*                 = !x. f x = case A0 ... An x                              *)
1584(*                                                                           *)
1585(*     When grouping by constructor, a function term is instantiated to      *)
1586(*     match the constructor in use to be passed to the recursive call.      *)
1587(*     Eg. clause_to_case (f [] a) [|- f [] 0 = A, |- f [] (SUC n) = B]      *)
1588(*     The algorithm terminates if this clause does not have a free variable *)
1589(*     at the left most outermost constructor.                               *)
1590(*                                                                           *)
1591(*     If the grouping by constructor procedure returns an empty group       *)
1592(*     (this case is missing from the definition) the algorithm returns the  *)
1593(*     reflection of this clause and adds it to the list of missing calls.   *)
1594(*                                                                           *)
1595(*     Fails if the theorem given is not a proper function (conjunction of   *)
1596(*     universally quantified equalities) or two clauses exist which have    *)
(*     left hand sides that are alpha convertible.                           *)
(*     May also fail if a theorem returned by "func_case" is of the wrong    *)
(*     form, which may happen if the user supplies such a theorem.           *)
1600(*                                                                           *)
(* clause_to_case_list : int list list -> thm -> (thm * term list)           *)
1602(*     Exactly as above, but a list is used to indicate the order in which   *)
1603(*     constructors are processed. Eg.                                       *)
1604(*         clause_to_case_list [[2],[1]]                                     *)
1605(*                  |- (f 0 0 = A) /\ (f (SUC n) 0 = B) /\                   *)
1606(*                     (f (SUC n) 0 = C) /\ (f (SUC n) (SUC m) = C)          *)
1607(*     will process the second argument first.                               *)
1608(*     If the list given does not correctly split the constructors then      *)
1609(*     this function will fail.                                              *)
1610(*                                                                           *)
1611(*****************************************************************************)
1612
(* find_comb l tm : follows the path l into tm, at each step selecting the   *)
(* n'th argument (counted from 1) of the current combination. Any failure    *)
(* is reported as a standard exception that names the path.                  *)
fun find_comb l tm =
    foldl (fn (n,sub) => el n (snd (strip_comb sub))) tm l
    handle e =>
        raise (mkStandardExn "find_comb"
                ("Could not find the term specified by the list: " ^
                 xlist_to_string int_to_string l))
1622
local
(* Returns the path (argument positions counted from 1) to the first        *)
(* position at which tm1 and tm2 have different heads and the head of tm1   *)
(* has a type for which constructors are known. Raises Empty if both heads  *)
(* at that position are variables, or if no such position is found.         *)
fun find_mismatch tm1 tm2 =
let     val (f1,l1) = strip_comb tm1
        val (f2,l2) = strip_comb tm2
        val comb = enumerate 1 (zip l1 l2) handle _ => []
in
        if f1 = f2 orelse not (can polytypicLib.constructors_of (type_of f1))
        then tryfind_e Empty (uncurry cons o (I ## uncurry find_mismatch)) comb
        else    if is_var f1 andalso is_var f2
                then raise Empty else []
end
(* Strict lexicographic order on integer lists; a proper prefix is smaller. *)
fun lex_less _ [] = false
  | lex_less [] _ = true
  | lex_less (x::xs) (y::ys) = x < y orelse x = y andalso lex_less xs ys
in
(* Returns NONE if no clause differs from 'function' by a constructor, and  *)
(* otherwise SOME of a function that extracts, from a (quantified) clause,  *)
(* the sub-term of its left hand side at the smallest mismatch path.        *)
fun outermost_constructor function clauses =
let     val stripped = map (lhs o snd o strip_forall o concl) clauses
                handle e => raise (mkStandardExn "outermost_constructor"
                                        "Theorems are not all of the form: \"|- !a0 .. an. F = X\"")
        (* Instantiate the types of each clause to those of 'function'. *)
        val normalised = map (fn x => inst (snd (match_term (repeat rator x) (repeat rator function))) x) stripped
                handle e => raise (mkStandardExn "outermost_constructor"
                                        "Theorems and function term supplied use different function constants")
        val lists = mapfilter (find_mismatch function) normalised
        val sorted = sort lex_less lists
        val list = hd sorted handle _ => []
in
        case list
        of [] => NONE
        |  list  => (trace 3 ("CP:" ^  (xlist_to_string int_to_string list) ^ "\n") ;
                        SOME (find_comb list o lhs o snd o strip_forall))
end
end;
1655
(* group_by_constructor function outermost clauses : groups the clauses by  *)
(* the constructor at the head of the term selected by 'outermost'. Returns *)
(* the constructed type and, for every constructor of that type, the        *)
(* function term with the selected position replaced by that constructor    *)
(* applied to fresh variables, paired with the clauses in its group.        *)
fun group_by_constructor _ _ [] =
    raise (mkStandardExn "group_by_constructor" "No function clauses given!")
  | group_by_constructor function outermost clauses =
let     fun om x = outermost x handle e => wrapException "group_by_constructor (outermost)" e
        (* Pair every clause with the head of its selected sub-term. *)
        val terms = map (fn y => (repeat rator (om (concl y)),y)) clauses
        val t = snd (strip_fun (type_of (fst (hd terms))))
        val cs = polytypicLib.constructors_of t handle e =>
                raise (mkStandardExn "group_by_constructor"
                        ("Function to find constructed terms returned a term with a non-compound type: " ^
                         type_to_string t))
        val matched = map (fn c => (c,map snd (filter (same_const c o fst) terms))) cs
        (* Every clause must have landed in exactly one constructor group. *)
        val _ = if foldl op+ 0 (map (length o snd) matched) = length clauses then ()
                        else raise (mkStandardExn "group_by_constructor"
                                ("The argument pointed to by the function given contains values " ^
                                 "which are not constructed terms: " ^ xlist_to_string (term_to_string o fst) terms))
        (* The sub-term of 'function' that is to be replaced. *)
        val replace = om (mk_eq(function,function))
        (* Variable names already in use; extended whenever a name is made. *)
        val fvs = ref (map (fst o dest_var) (free_vars (mk_abs(replace,function))))
        (* Local variable generator (shadows the global genvar): picks a    *)
        (* base26 name that is not yet in !fvs and records it there.        *)
        fun genvar t =
        let     val s = first (not o C mem (!fvs)) (map (implode o base26 o fst) (enumerate (length (!fvs)) (""::(!fvs))))
        in      (fvs := s :: (!fvs) ; mk_var(s,t))
        end
        (* Constructor c at result type t, applied to fresh variables. *)
        fun mk_cons t c =
        let     val c' = inst [snd (strip_fun (type_of c)) |-> t] c
        in      list_mk_comb(c',map genvar (fst (strip_fun (type_of c')))) end
        fun sub a = subst [replace |-> mk_cons (type_of replace) a]
in
        (t,map (fn (a,b) => (sub a function,b)) matched) handle e => wrapException "group_by_constructor" e
end;
1684
local
(* Substitutions (fresh variable |-> original term) recorded by gv; reset   *)
(* at the start of every call to subst_om.                                  *)
val newvars = ref [];
(* Returns a fresh variable of the type of t and records the substitution.  *)
fun gv t =
let val v = genvar (type_of t)
in  (newvars := (v |-> t) :: (!newvars) ; v) end;
(* Replaces every occurrence of x in term by its own fresh variable.        *)
fun subst_all_x x term =
    if x = term then gv term
    else if is_comb term
            then mk_comb (subst_all_x x (rator term),subst_all_x x (rand term))
            else if is_abs term
                    then mk_abs(bvar term,subst_all_x x (body term))
                    else term;
(* subst_om substitutes ONLY the term matching the function om (f term). *)
(* It does this by substituting everything, then replacing the incorrect *)
(* things.                                                               *)
fun subst_om om f g term =
let val x = om (g term)
    val _ = newvars := [];
    val all_subst = g (subst_all_x x term)
    val y = om all_subst
in
    subst (filter (not o curry op= y o #redex) (!newvars)) (f all_subst)
end;
in
(* alpha_match_clauses outermost clauses : takes (theorem, missing clauses) *)
(* pairs. Every theorem is specialised and, working from the end of the     *)
(* list, instantiated so that its left hand side agrees with that of the    *)
(* theorem after it once the sub-term selected by 'outermost' has been      *)
(* replaced by a fresh variable. The missing clauses are concatenated.      *)
fun alpha_match_clauses outermost [] = ([],[])
  | alpha_match_clauses outermost [(thm,missing)] = (missing,[SPEC_ALL thm])
  | alpha_match_clauses outermost ((lthm,lmissing)::clauses) =
let fun om x = outermost x
        handle e => wrapException "alpha_match_clauses (outermost)" e
    val (rmissing,thms) = alpha_match_clauses om clauses
    val rthm = hd thms
    fun split x = (lhs o snd o strip_forall o concl) x
        handle e => raise (mkStandardExn "alpha_match_clauses"
             "Clauses must all be of the form: \"|- !a0..an. f X = Y\"")
    (* Left hand sides with the selected sub-term abstracted away. *)
    val l = subst_om om lhs (fn x => mk_eq(x,x)) (split lthm)
    val r = subst_om om lhs (fn x => mk_eq(x,x)) (split rthm)
    val match = match_term l r
        handle e => raise (mkStandardExn "alpha_match_clauses"
             ("The left hand side of two or more clauses differs " ^
              "outside of the term indicated by 'outermost'"))
in
 (rmissing @ lmissing, (INST_TY_TERM match (SPEC_ALL lthm))::thms)
 handle e => wrapException "alpha_match_clauses" e
end
end;
1730
local
(* Pairs every x with a distinct element y of L for which f x y holds.      *)
(* Raises Empty if some x has no partner or if elements of L are left over. *)
fun match_lists f [] [] = []
  | match_lists f []  _ = raise Empty
  | match_lists f (x::xs) L =
let val (y,ys) = pluck (f x) L handle e => raise Empty
in  (x,y)::match_lists f xs ys
end;
(* The first base26 name, counting up from n, that does not occur in vars.  *)
fun freebase26 n vars =
let val var = implode (base26 n)
in  if mem var vars then (freebase26 (n + 1) vars) else var
end
in
(* condense_missing outermost missing : if the constructed terms selected   *)
(* from the missing clauses are applied to variables only and cover every   *)
(* constructor of their (single) base type, the clauses are replaced by one *)
(* clause in which the constructed terms are replaced by a variable.        *)
(* Otherwise the list is returned unchanged. Raises a debug exception if    *)
(* the selected terms have more than one base type.                         *)
fun condense_missing outermost missing =
let val constructors = map (fn x => outermost (mk_eq (x,x))) missing
    handle e => wrapException "condense_missing (outermost)" e
in
   (case (mk_set (map (base_type o type_of) constructors))
    of [] => missing
    | (x::y::ys) => raise (mkDebugExn "condense_missing"
                    "Types of constructors do not match!")
    | [x] =>
    if all (all is_var o snd o strip_comb) constructors andalso
       can (match_lists (fn b => can (match_term (fst (strip_comb b))))
                         constructors) (constructors_of x)
       (* Reuse a variable of one of the constructed terms if there is one, *)
       (* otherwise make a name that is not free in the missing clauses.    *)
       then let val var = case (total (tryfind (hd o free_vars)) constructors)
                          of SOME var => var
                          |  NONE =>
                             mk_var(freebase26 0 (map (fst o dest_var)
                                                      (free_varsl missing)),x)
                val _ = trace 1 ("M:[" ^ int_to_string (length missing) ^ "]")
            in [subst (map (fn c => c |-> var) constructors) (hd missing)]
            end
       else missing)
   handle e => wrapException "condense_missing" e
end
end;
1767
local
(* wrap s e : passes e to wrapException under the label                      *)
(*     "clause_to_case_list" followed by the suffix s.                       *)
fun wrap s = wrapException ("clause_to_case_list" ^ s)
(* ctc omc function clauses : thm * term list                                *)
(*     'omc function clauses' yields NONE, or SOME outermost where outermost *)
(*     selects from a clause the sub-term the clauses are to be split on.    *)
(*     With no clauses the result is REFL function, with 'function' itself   *)
(*     returned in the second component. With NONE exactly one clause must   *)
(*     remain and it is returned as is. Otherwise the clauses are grouped by *)
(*     constructor, each group is processed recursively, and the results are *)
(*     combined through the "func_case" source theorem generated for the     *)
(*     type returned by group_by_constructor.                                *)
fun ctc omc function [] = (REFL function,[function])
  | ctc omc (function:term) (clauses:thm list) : thm * term list =
    case (omc function clauses)
    of NONE => if length clauses = 1
                  then (hd clauses,[])
                  else raise (mkStandardExn "clause_to_case_list"
                "Two or more function clauses do not differ by constructors")
    | SOME outermost =>
     let val grouped = (group_by_constructor function outermost clauses)
                       handle e => wrap "" e
         (* Recurse into every (function,clauses) group; omc may be stateful *)
         (* so the order in which the groups are visited matters.            *)
         val (t,split_clauses) = (I ## map (uncurry (ctc omc))) grouped
         val (missing,next_thm) = alpha_match_clauses outermost split_clauses
                       handle e => wrap "" e
         val missing' = condense_missing outermost missing
         (* Generalise a clause over the arguments of its selected sub-term. *)
         fun gen thm =
         let val list = snd (strip_comb (outermost (concl thm)))
                        handle e => wrap " (outermost)" e
         in GENL list thm end
         val thm = LIST_CONJ (map gen next_thm)
         (* Left-to-right direction of the "func_case" equivalence for t.    *)
         val rule = fst (EQ_IMP_RULE (SPEC_ALL
                        (generate_source_theorem "func_case" t)))
                    handle e => wrap "" e
 in
  (HO_MATCH_MP rule thm,missing')
  handle e => raise (mkStandardExn "clause_to_case_list"
   ("Could not match the normalised group clauses:\n " ^
    thm_to_string thm ^ "\n" ^
    "with the \"func_case\" theorem generated for type: " ^
    type_to_string t ^ ":\n" ^
    thm_to_string rule))
 end
in
(* clause_to_case_list list thm : thm * term list                            *)
(*     'thm' is split into its conjuncts, each expected to be a universally  *)
(*     quantified equality. 'list' is a list of term paths (printed with     *)
(*     int_to_string, so presumably int lists): each split consumes the next *)
(*     path and uses find_comb on it to pick the term to split on; once the  *)
(*     paths run out outermost_constructor chooses instead.                  *)
fun clause_to_case_list list thm =
let     val clauses = CONJUNCTS thm
        (* Head constant/variable of the left-hand side of the first clause. *)
        val left = (repeat rator o lhs o snd o strip_forall o concl o hd) clauses
                handle e => raise (mkStandardExn "clause_to_case_list"
                                "Theorem given is not a conjunction of universally quantified equalities")
        (* The function applied to fresh variables named by base26 of their  *)
        (* argument position.                                                *)
        val function = list_mk_comb(left,map (fn (a,b) => (mk_var(implode (base26 a),b)))
                (enumerate 0 (fst (strip_fun (type_of left)))))
                handle e => wrapException "clause_to_case_list" e
        (* Remaining user-supplied paths; updated by every call to omc.      *)
        val rlist = ref list
        fun omc function clauses =
                case (!rlist)
                of [] => outermost_constructor function clauses
                |  (x::xs) =>
                let     val a = find_comb x function
                                handle e => raise (mkStandardExn "clause_to_case_list"
                                                ("The term path: " ^ xlist_to_string int_to_string x ^
                                                 " cannot find a sub-term in the current function: "
                                                ^ term_to_string function))
                        (* A path may only select a position that is still a variable. *)
                        val _ = if is_var a then () else
                                raise (mkStandardExn "clause_to_case_list"
                                                ("The term path: " ^ xlist_to_string int_to_string x ^
                                                 " is selecting a constructed term: " ^ term_to_string function))
                in      (rlist := xs ; SOME (find_comb x o lhs o snd o strip_forall)) end

in
        ctc omc function clauses
end
(* clause_to_case thm : clause_to_case_list with no user-supplied paths.     *)
fun clause_to_case thm =
    clause_to_case_list [] thm handle e => wrapException "clause_to_case" e
end;
1832
1833(*****************************************************************************)
1834(* mk_func_case_thm : hol_type -> thm                                        *)
1835(*     Generates a theorem of the form:                                      *)
1836(*     |-   (!.. f (C0 ..) = A0 ..) /\ (!.. f (Cn ..) = An ..)               *)
1837(*        = !x. f x = case A0 .. An x                                        *)
1838(*     Simply generates the right hand side, applies CASE_SPLIT_CONV and     *)
1839(*     rewrites the left hand side to remove the case statements.            *)
1840(*     Fails if the type in question does not have a case definition or      *)
1841(*     constant supplied for it.                                             *)
1842(*                                                                           *)
1843(*     A theorem generator called "func_case" is added at load time.         *)
1844(*                                                                           *)
1845(*****************************************************************************)
1846
local
fun wrap e = wrapException "mk_func_case_thm" e
(* One variable per type in 'tys', named prefix ^ base26 of its position.    *)
fun get_consts prefix tys =
    map (fn (n,ty) => mk_var(prefix ^ implode (base26 n),ty)) (enumerate 0 tys)
in
(* Builds the term  !arg. F arg = case f_a .. arg  for the case constant of  *)
(* t, splits it with CASE_SPLIT_CONV, rewrites the resulting conjuncts with  *)
(* the clauses of the case definition and finally reverses the equation.     *)
fun mk_func_case_thm t =
let val case_const = TypeBase.case_const_of t handle e => wrap e
    val case_def = TypeBase.case_def_of t handle e => wrap e
    val (arg_types,rtype) = strip_fun (type_of case_const)
    val branch_vars = get_consts "f_" arg_types
    (* The case constant applied to all but the last of the variables. *)
    val (case_term,_) = dest_comb (list_mk_comb(case_const,branch_vars))
        handle e => wrap e
    val xvar = mk_var("arg",fst (dom_rng (type_of case_term)))
        handle e => wrap e
    val fvar = mk_var("F",type_of xvar --> rtype)
    val full_term =
        mk_forall(xvar,mk_eq(mk_comb(fvar,xvar),mk_comb(case_term,xvar)))
        handle e => wrap e
in
    let val rewrite_clause = FIRST_CONV (map REWR_CONV (CONJUNCTS case_def))
        val simplify =
            EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV rewrite_clause))
        val split = CASE_SPLIT_CONV full_term
    in
        GSYM (RIGHT_CONV_RULE simplify split)
    end
    handle e => wrap e
end
end;
1871
(* Load-time registration: mk_func_case_thm becomes the generator for the    *)
(* "func_case" source theorem. The predicate (can constructors_of) is the    *)
(* generator's applicability test -- presumably it restricts the generator   *)
(* to types for which constructors_of succeeds; confirm against              *)
(* add_rule_source_theorem_generator.                                        *)
val _ = add_rule_source_theorem_generator "func_case"
            (can constructors_of) mk_func_case_thm;
1874
1875(*****************************************************************************)
1876(* create_lambda_propagation_term : term -> term                             *)
1877(*     Given a term of the form: \x (y,z) ... . A  returns the conclusion of *)
1878(*     a propagation  theorem for lambda abstractions.                       *)
1879(*                                                                           *)
1880(* prove_lambda_propagation_term  : term -> thm                              *)
1881(*     Proves the conclusion generated by the previous function.             *)
1882(*                                                                           *)
1883(* make_lambda_propagation_theorem : term -> thm                             *)
1884(*     Creates a lambda propagation theorem to match the term given.         *)
1885(*                                                                           *)
1886(*****************************************************************************)
1887
local
open pairSyntax
(* compn [n1,n2,..] L : cuts L into consecutive pieces of n1, n2, ..         *)
(*     elements (using split_after); anything left over is dropped.          *)
fun compn [] _ = []
  | compn (n::xs) L = op:: ((I ## compn xs) (split_after n L));
fun wrap_general e = wrapException "general_lambda" e
(* general_lambda tm                                                         *)
(*     Builds a paired abstraction with the same tuple structure as the      *)
(*     binders of tm, but over fresh variables named and typed by base26     *)
(*     ("a":'a, "b":'b, ...). The body is the last generated variable, given *)
(*     a function type over the others, applied to all the bound variables.  *)
fun general_lambda tm =
let     val terms = map (length o strip_pair) (fst (strip_pabs tm))
                    handle e => wrap_general e
        val names = for 0 (foldl op+ 0 terms) (String.implode o base26)
        val types = for 0 (foldl op+ 0 terms)
                          (mk_vartype o String.implode o cons #"'" o base26)
        val vars = map2 (curry mk_var) names types handle e => wrap_general e
        val pairs = map list_mk_pair (compn terms vars)
                    handle e => wrap_general e

        val out = mk_var(last names,foldr (op-->) (last types) (butlast types))
in
        list_mk_pabs (pairs,list_mk_comb(out,butlast vars))
        handle e => wrap_general e
end
(* split_pairs : replaces every term x on which mk_fst/mk_snd succeed by its *)
(*     two projections, recursively, and keeps all other terms as they are.  *)
fun split_pairs [] = []
  | split_pairs (x::xs) =
        split_pairs (mk_fst x::mk_snd x::xs) handle e => x::split_pairs xs;
(* Exceptions are labelled with the name of the function exported below;     *)
(* the label used to be "general_lambda_propagation_term", which names no    *)
(* function and made the exception trace misleading.                         *)
fun wrap e = wrapException "create_lambda_propagation_term" e
(* imp_conj hs c : c itself when hs is empty, otherwise (conj of hs) ==> c.  *)
fun imp_conj [] term = term
  | imp_conj xs term = mk_imp(list_mk_conj xs,term);
in
(* create_lambda_propagation_term term                                       *)
(*     Returns the statement of a propagation theorem for the generalised    *)
(*     form of the paired abstraction 'term'. It is phrased over the free    *)
(*     variables "enc", "dec" and "det" and the type variable 'output.       *)
fun create_lambda_propagation_term term =
let     val gterm = general_lambda term handle e => wrap e
        val vt = mk_vartype "'output";
        fun enc tm = mk_comb(mk_var("enc",type_of tm --> vt),tm);

        val (pairs,body) = strip_pabs gterm
        (* One argument variable per binder, named by concatenating the      *)
        (* names of the variables in that binder's tuple.                    *)
        val vns = map (String.concat o map (fst o dest_var) o strip_pair) pairs
        val vts = map type_of pairs
        (* Left-hand side: the encoded application of the abstraction.       *)
        val lterm = enc (list_mk_comb(gterm,map2 (curry mk_var) vns vts));

        (* Right-hand side: an abstraction over primed variables of the      *)
        (* output type, applied to those same variables.                     *)
        val rvars = map (C (curry mk_var) vt o prime) vns;
        val rterm = list_mk_comb(list_mk_abs(rvars,list_mk_comb(mk_var(prime (fst(dest_var(fst(strip_comb body)))),
                                foldl (fn (a,b) => (type_of a --> b)) vt rvars),rvars)),rvars)

        (* encs: enc v = v'    decs: dec v'    dets: det v'                  *)
        val encs = map2 (curry mk_eq) (map enc (map2 (curry mk_var) vns vts)) rvars
        val decs = map (fn e => mk_comb(mk_var("dec",vt --> type_of (rand (lhs e))),rhs e)) encs;
        val dets = map (curry mk_comb (mk_var("det",vt --> bool)) o rhs) encs;


        (* Under the dets, the encoded body over the decoded arguments       *)
        (* equals the body of the right-hand abstraction.                    *)
        val eterm = imp_conj dets (mk_eq(enc (list_mk_comb(fst (strip_comb body),split_pairs decs)),
                                        snd (strip_pabs (fst (strip_comb rterm)))));

        (* eds: !v. dec (enc v) = v        eps: !v. det (enc v)              *)
        val eds = map2 (fn a => fn b => mk_forall(rand(lhs b),mk_eq(mk_comb(rator a,lhs b),rand (lhs b)))) decs encs;
        val eps = map2 (fn a => fn b => mk_forall(rand(lhs b),mk_comb(rator a,lhs b))) dets encs;
in
        imp_conj (eds @ eps) (mk_imp(foldr mk_conj eterm encs,mk_eq(lterm,rterm)))
end
end;
1941
(* Attempts to prove 'term' (as built by create_lambda_propagation_term)     *)
(* with a fixed tactic and returns the theorem; fails if any subgoal is      *)
(* left over.                                                                *)
fun prove_lambda_propagation_term term =
let val label = "prove_lambda_propagation_term"
    (* An assumption of the form x = x. *)
    fun is_trivial_eq a = is_eq a andalso lhs a = rhs a
    (* Substitute assumptions right-to-left, discarding the reflexive        *)
    (* equation selected by is_trivial_eq after each substitution.           *)
    val substitute_all =
        REPEAT (FIRST_ASSUM (SUBST_ALL_TAC o GSYM) THEN
                WEAKEN_TAC is_trivial_eq)
    (* Rewrite the unquantified assumptions with the quantified ones. *)
    fun rewrite_assumptions asms =
        let val quantified = filter (is_forall o concl) asms
        in
            RULE_ASSUM_TAC
                (fn th => if is_forall (concl th) then th
                          else PURE_REWRITE_RULE quantified th)
        end
    val tac =
        REPEAT STRIP_TAC THEN
        pairLib.GEN_BETA_TAC THEN
        substitute_all THEN
        ASSUM_LIST rewrite_assumptions THEN
        FIRST_ASSUM MATCH_MP_TAC THEN
        REPEAT CONJ_TAC THEN
        ACCEPT_TAC TRUTH
    val (goals,validate) = tac ([],term) handle e => wrapException label e
in
    if null goals
    then (validate [] handle e => wrapException label e)
    else raise (mkStandardExn "prove_lambda_propagation_term"
                ("Tactic used by this function did not fully solve the term:\n" ^ term_to_string term))
end;
1954
local
(* Peels arguments off an application until a paired abstraction is found;   *)
(* rator fails if the head of the term is reached without finding one.       *)
fun strip_to_abs term =
    case pairLib.is_pabs term
    of true => term
    |  false => strip_to_abs (rator term)
in
(* Takes the operand of 'term', finds the abstraction at its head and        *)
(* proves the propagation theorem created for that abstraction.              *)
fun make_lambda_propagation_theorem term =
    let val abstraction = strip_to_abs (rand term)
        val goal = create_lambda_propagation_term abstraction
    in
        prove_lambda_propagation_term goal
    end
    handle e => wrapException "make_lambda_propagation_theorem" e
end;
1964
1965(*****************************************************************************)
1966(* polytypic_let_conv : term -> thm                                          *)
1967(*     Proves a theorem similar to the lambda conversion, but for let        *)
1968(*     constructions.                                                        *)
1969(*****************************************************************************)
1970
(* mk_simplified_let let_term                                                *)
(*     let_term is a list of pairs whose first components are lists (one     *)
(*     list per let binding); only their lengths are used. A let expression  *)
(*     of the same shape is built over fresh variables and type variables,   *)
(*     and the result is a fresh "encoder" variable for that expression      *)
(*     together with the theorem obtained by rewriting it with LET_THM.      *)
fun mk_simplified_let let_term =
let val ty_vars = map (fn (bound,_) => map (fn _ => gen_tyvar()) bound) let_term
    val prod_types = map pairSyntax.list_mk_prod ty_vars
    val result_type = gen_tyvar()

    (* Variables are named base26 0, base26 1, ... in order of creation. *)
    val counter = ref 0
    fun next_var ty =
        let val var = mk_var(implode (base26 (!counter)),ty)
        in
            counter := !counter + 1; var
        end

    (* The bound tuples are created before the variables they are bound to. *)
    val bound_tuples = map (pairSyntax.list_mk_pair o map next_var) ty_vars
    val bound_values = map next_var prod_types
    val vars = zip bound_tuples bound_values
    val all_vars = flatten (map (fn (tuple,_) => pairSyntax.strip_pair tuple) vars)
    val result_var =
        next_var (foldr (op -->) result_type (map type_of all_vars))
    val result_term = list_mk_comb(result_var,all_vars)
    val simplified_let = pairSyntax.mk_anylet (vars,result_term)
    val let_type = type_of simplified_let
    val encoder = next_var (let_type --> gen_tyvar())
in
    (encoder,PURE_REWRITE_CONV [LET_THM] simplified_let)
end handle e => wrapException "mk_simplified_let" e;
1990
(* The operand of 'term' must be a let expression. A lambda propagation      *)
(* theorem is made for its simplified form and then rewritten back into let  *)
(* expressions on both sides.                                                *)
fun polytypic_let_conv term =
let val _ = trace 2 "->polytypic_let_conv\n"
    val (bindings,_) = pairSyntax.dest_anylet (rand term)
        handle _ =>
        raise (mkStandardExn "polytypic_let_conv"
              ("Term: " ^ term_to_string term ^
               "\nis not an encoded let term"))
    (* Each bound tuple is flattened into the list of its components. *)
    val let_term =
        map (fn (bound,value) => (pairSyntax.strip_pair bound,value)) bindings
    val let_thm1 =
        let val (encoder,thm) = mk_simplified_let let_term
        in
            AP_TERM encoder thm
        end
    (* The same let, with every binding cut down to its first variable. *)
    val (_,let_thm2) =
        mk_simplified_let (map (fn (bound,value) => ([hd bound],value)) let_term)

    val lambda_thm = make_lambda_propagation_theorem (rhs (concl let_thm1))
        handle e => wrapException "polytypic_let_conv" e
in
    let val restore_lets =
            LAND_CONV (REWR_CONV (GSYM let_thm1)) THENC
            RAND_CONV (REWR_CONV (GSYM let_thm2))
    in
        CONV_RULE (RAND_CONV (RAND_CONV restore_lets)) lambda_thm
    end
    handle e => raise (mkDebugExn "polytypic_let_conv"
             ("Could not rewrite lambda theorem:\n " ^
              thm_to_string lambda_thm ^ "\nto a let expression using:\n" ^
              thm_to_string let_thm1 ^ "\nand\n" ^
              thm_to_string let_thm2))
end;
2015
2016(*****************************************************************************)
2017(* mk_affirmation_theorems : hol_type -> thm list                            *)
2018(*                                                                           *)
2019(*    Returns a list of theorems of the form:                                *)
2020(*    |- ~(?a... x = C0 a ..) /\ ~(?a... x = C1 a ..) ==> (?a... x = Cn a ..)*)
2021(*                                                                           *)
2022(*****************************************************************************)
2023
local
fun wrap e = wrapException "mk_affirmation_theorems" e
(* Discharges 'term' from thm, resolves the result against IMP_F, removes    *)
(* double negations and discharges the remaining hypotheses as a conjunction.*)
fun fix_term term thm =
let val contradiction = MATCH_MP IMP_F (DISCH term thm)
in
    DISCH_ALL_CONJ (PURE_REWRITE_RULE [satTheory.NOT_NOT] contradiction)
end
(* For a term ~(?vs. body): reorders the quantified variables that occur in  *)
(* the body into the order given by free_vars_lr; if none occur the term is  *)
(* left alone.                                                               *)
fun ORDER_CONV term =
let val (bound,body) = strip_exists (dest_neg term)
    val order = intersect (free_vars_lr body) bound
in
    if null order then REFL term
    else RAND_CONV (ORDER_EXISTS_CONV order) term
end;
in
(* One theorem is produced for every hypothesis of the contraposed,          *)
(* negation-normalised nchotomy theorem of the type t.                       *)
fun mk_affirmation_theorems t =
let val nchotomy = SPEC_ALL (TypeBase.nchotomy_of t) handle e => wrap e
    val contraposed = CONTRAPOS (DISCH T nchotomy)
    val normalised = CONV_RULE (LAND_CONV NNF_CONV) contraposed
    val undischarged = UNDISCH_CONJ normalised
    val negated = CONV_HYP ORDER_CONV (REWRITE_RULE [] undischarged)
in
    map (fn tm => fix_term tm negated) (hyp negated)
end
end;
2047
2048(*****************************************************************************)
2049(* EXISTS_REFL_CONV : term -> thm                                            *)
2050(*    Proves the theorem  |- (?a. b = a) = T  using b as a witness.          *)
2051(*                                                                           *)
2052(*****************************************************************************)
2053
(* The left-hand side of the quantified equation serves as the witness for   *)
(* the existential; the term is then rewritten with the resulting theorem.   *)
fun EXISTS_REFL_CONV term =
let val (_,equation) = dest_exists term
    val witness = lhs equation
    val exists_thm = EXISTS (term,witness) (REFL witness)
in
    REWRITE_CONV [exists_thm] term
end
2060
2061(*****************************************************************************)
2062(* set_destructors : hol_type -> thm list -> unit                            *)
2063(*    Sets the destructors for a type and redefines the initial theorem.     *)
2064(*    Destructors must be of the following form:                             *)
2065(*       |- !a ... . F (C a ...) = a, ... |- !a b ... . G (C a b ..) = b     *)
2066(*                                                                           *)
(* nested_constructor_rewrite : hol_type -> term -> thm                      *)
2068(*     Returns a polytypic rewrite that converts a nested constructor:       *)
2069(*        |- bool (?a b. x = C' (C a) (C b)) =                               *)
2070(*           bool (?a b. x = C' a b) /\ (?a. D x = C a) /\ (?b. D' x = C b)  *)
2071(*                                                                           *)
(* nested_constructor_theorem : hol_type -> term -> thm                      *)
2073(*     Returns a theorem that resolves nested constructors:                  *)
2074(*     |- (?a b. x = C' a b) /\ (?a. D x = C a) /\ (?b. D' x = C b) ==>      *)
2075(*        ?a b. x = C' (C a) (C b)                                           *)
2076(*                                                                           *)
2077(*****************************************************************************)
2078
local
(* mk_initial dest_list                                                      *)
(*     dest_list holds the destructor theorems for one constructor, each of  *)
(*     the form  |- D (C a b ..) = a.  Proves the "initial" theorem          *)
(*        |- !arg. (arg = C a b ..) =                                        *)
(*                 (?a b .. . arg = C a b ..) /\ (D arg = a) /\ ...          *)
(*     Raises a debug exception if the list is empty, if the theorems do not *)
(*     all concern the same constructor term, or if the proof fails.         *)
fun mk_initial [] = raise (mkDebugExn "mk_initial" "Empty list supplied!")
  | mk_initial dest_list =
let val list = map SPEC_ALL dest_list
    val cs = with_exn (map (rand o lhs o concl)) list
                      (mkDebugExn "mk_initial" "Invalid list of constructors")
    (* Re-instantiate every destructor so that all use the first theorem's   *)
    (* constructor term.                                                     *)
    val normalized = with_exn (map (C (PART_MATCH (rand o lhs)) (hd cs))) list
                      (mkDebugExn "mk_initial"
                       "Constructors in list are not identical")
    val constructor = hd cs
    val vars = snd (strip_comb constructor)
    val _ = if exists (not o is_var) vars then
            raise (mkStandardExn "mk_initial"
                   "Destructor not applied to a ground constructed term!")
            else ();
    val t = type_of constructor
    val arg = variant vars (mk_var("arg",t))
    val left = mk_eq(arg,constructor)
    val term = mk_forall(arg,mk_eq(left,list_mk_conj(list_mk_exists(vars,left)::
                map (subst [constructor |-> arg] o concl) normalized)));
    (* Whichever of distinctness / injectivity the TypeBase can supply. *)
    val thms = mapfilter (fn f => f t) [TypeBase.distinct_of,GSYM o TypeBase.distinct_of,
                TypeBase.one_one_of];
in
    (case ((Cases THEN
     REPEAT (CHANGED_TAC (REWRITE_TAC (thms @ list))) THEN
     EQ_TAC THEN STRIP_TAC THEN ASM_REWRITE_TAC [] THEN
     CONV_TAC (TOP_DEPTH_CONV EXISTS_AND_CONV THENC
               EVERY_CONJ_CONV EXISTS_REFL_CONV) THEN
     REWRITE_TAC [])
     ([],term))
    of ([],func) => func []
    |  _ => raise Empty)
    handle e => raise (mkDebugExn "mk_initial"
                              "Failed to prove initial theorem!")
end
(* generalise_term_type t term                                               *)
(*     Rebuilds a constructor application at type t: the head is replaced by *)
(*     the matching member of constructors_of t and the arguments are        *)
(*     rebuilt at that constructor's argument types. A variable is simply    *)
(*     re-typed to t; any other failure is re-raised.                        *)
fun generalise_term_type t term =
let val (c,args) = strip_comb term
    val c' = first (can (C match_term c)) (constructors_of t)
in
    list_mk_comb(c',map2 generalise_term_type
                         (fst (strip_fun (type_of c'))) args)
end handle e =>
    if is_var term then mk_var(fst (dest_var term),t) else raise e;
(* rewrite_thm name initials body                                            *)
(*     'body' is  ?a.. x = C' (C a) ..  and 'initials' the conjunction of    *)
(*     initial theorems. Instantiates the first initial theorem that matches *)
(*     the generalised constructor term, re-quantifies its variables and     *)
(*     simplifies the right-hand side. Raises a standard exception labelled  *)
(*     'name' when the constructor term has no variables or the simplified   *)
(*     right-hand side is not a conjunction.                                 *)
fun rewrite_thm name initials body =
let val na =  mkStandardExn name "Not a nested constructor"
    val initial_type = type_of (hd (fst (strip_forall (concl (hd
                               (CONJUNCTS initials))))));
    val (_,eq) = strip_exists body
    val new_term = generalise_term_type initial_type (rhs eq)
    val nvars = free_vars_lr new_term
    val var = variant nvars (mk_var("argument",initial_type));
    val equality = mk_eq(var,new_term);
    val _ = if null nvars then raise na else ()
    val thm = tryfind (fn initial => SPEC_ALL (UNDISCH_ALL (
               PART_MATCH (lhs o snd o strip_forall o snd o strip_imp)
                          (DISCH_ALL initial) equality))) (CONJUNCTS initials)
    val thm' = RIGHT_CONV_RULE
               (TOP_DEPTH_CONV EXISTS_AND_CONV THENC
                    EVERY_CONJ_CONV (TRY_CONV EXISTS_REFL_CONV) THENC
                                PURE_REWRITE_CONV [AND_CLAUSES])
                               (foldr (MK_EXISTS o uncurry GEN) thm nvars);
    val _ = if not (is_conj (rhs (concl thm'))) then raise na else ()
in
    thm'
end
(* single_term term thm                                                      *)
(*     'term' is a hypothesis of the form  !v. d (e v) = ..  ; the variables *)
(*     d and e are instantiated in thm with the decode and encode functions  *)
(*     looked up for the two types involved.                                 *)
fun single_term term thm =
let val (var,body) = dest_forall term
    val d = rator (lhs body)
    val e = rator (rand (lhs body))
    val enc = get_encode_function (type_of (rand (lhs body))) (type_of var)
    val dec = get_decode_function (type_of (rand (lhs body))) (type_of var)
in
    INST [d |-> dec, e |-> enc] thm
end
(* Applies single_term for every hypothesis of the theorem. *)
fun fix_hyp thm =
    foldl (uncurry single_term) thm (hyp thm);
in
(* set_destructors target destructors                                        *)
(*     Groups the destructor theorems by the constructor they are applied    *)
(*     to, proves one initial theorem per group and stores the coding        *)
(*     theorems "destructors" and "initial" for the constructed type,        *)
(*     replacing any stored previously.                                      *)
fun set_destructors target destructors =
let val _ = trace 2 "set_destructors"
    (* (constructor constant, specialised destructor theorem) pairs. *)
    val list1 = map (fn x => ((fst o strip_comb o rand o lhs o
                               snd o strip_forall o concl) x,
                             SPEC_ALL x)) destructors
    val t = snd (strip_fun (type_of (fst (hd list1))))
    (* Bring every pair to the type instance of the first constructor. *)
    fun inst1 (term,thm) =
    let val m = match_type (snd (strip_fun (type_of term))) t
    in  (inst m term,INST_TYPE m thm) end;
    val bucketed = bucket_alist (map inst1 list1);

    val initials = map (mk_initial o snd) bucketed
        handle e => wrapException "set_destructors" e

    (* Best effort: there may be nothing stored yet to remove. *)
    val _ = remove_coding_theorem_precise target t "destructors" handle e => ()
    val _ = remove_coding_theorem_precise target t "initial" handle e => ()
    val _ = add_coding_theorem_precise target t "destructors"
                                       (LIST_CONJ destructors)
    val _ = add_coding_theorem_precise target t "initial" (LIST_CONJ initials)
in
    ()
end
(* nested_constructor_theorem target term                                    *)
(*     'term' has the form  ?a.. x = C a..  ; the right-to-left direction of *)
(*     the rewrite produced by rewrite_thm is taken, its encode/decode       *)
(*     hypotheses are instantiated and discharged using match_encdec, and    *)
(*     the remaining hypotheses are re-discharged as a conjunction.          *)
fun nested_constructor_theorem target term =
let val t = with_exn (type_of o lhs o snd o strip_exists) term
                     (mkStandardExn "nested_constructor_theorem"
                                "Not a term of the form: ?a.. x = C a..")
    val initial = get_coding_theorem target t "initial"
    val result =  snd (EQ_IMP_RULE
                     (rewrite_thm "nested_constructor_theorem" initial term))
    val result' = fix_hyp result
    (* Labelled with this function's own name (was "nested_constructor"). *)
    val hs = map match_encdec (hyp result')
    handle e => wrapException "nested_constructor_theorem" e
in
     DISCH_ALL_CONJ (UNDISCH_CONJ (foldl (uncurry PROVE_HYP) result' hs))
end
(* nested_constructor_rewrite target term                                    *)
(*     'term' has the form  enc (?a.. x = C a..)  ; the rewrite produced by  *)
(*     rewrite_thm for the body is lifted under the encoder with AP_TERM.    *)
fun nested_constructor_rewrite target term =
    (* This error is now reported under nested_constructor_rewrite; it used  *)
    (* to carry the label of nested_constructor_theorem.                     *)
let val t = with_exn (type_of o lhs o snd o strip_exists o rand) term
                     (mkStandardExn "nested_constructor_rewrite"
                                "Not a term of the form: bool (?a.. x = C a..)")
    val (enc,body) = with_exn dest_comb term
                     (mkStandardExn "nested_constructor_rewrite"
                      "Not an encoded term")
    val initial = get_coding_theorem target t "initial"
in
    DISCH_ALL_CONJ (AP_TERM enc
                      (rewrite_thm "nested_constructor_rewrite" initial body))
end
end
2204
2205(*****************************************************************************)
2206(* mk_destructor_rewrites : hol_type -> thm list -> thm list                 *)
2207(*                                                                           *)
2208(*     Destructors for regular datatypes are generated as:                   *)
2209(*        |- (FST o decode_pair f g o encode_type f' g') (C x ...) = x       *)
2210(*     When propagating over such theorems it is *very* important not to     *)
(*     encode 'encode_type f g A' separately. This is because the final      *)
2212(*     theorem resolving the encoding must be of the form:                   *)
2213(*        |- ?x ... . A = C x ... ==>                                        *)
2214(*           (encode_type f' g' A = X)                                       *)
2215(*           (encode_pair f' g' (decode_pair f g (encode_type f' g' A)) = X) *)
2216(*     We therefore produce a theorem of the following form:                 *)
2217(*        |- (FST (decode_pair f g (encode_type f' g' x)) = X) ==>           *)
2218(*           ((FST o decode_pair f g o encode_type f' g') x = X)             *)
2219(*                                                                           *)
2220(*****************************************************************************)
2221
(* For every destructor theorem  |- D x = r  this assumes                    *)
(*     enc (D arg) = <fresh variable of the target type>                     *)
(* rewrites the assumption with o_THM and discharges it, followed by T.      *)
fun mk_destructor_rewrites target destructors =
let fun left thm = lhs (concl thm)
    val result_types = map (fn d => type_of (left d)) destructors
    val encoders = map (get_encode_function target) result_types
    val operators = map (fn d => rator (left d)) destructors
    (* All destructors are applied to one shared argument variable. *)
    val (arg_type,_) = dom_rng (type_of (hd operators))
    val arg = mk_var("arg",arg_type)
    fun assumption dtor encoder =
        let val encoded = mk_comb(encoder,mk_comb(dtor,arg))
        in
            ASSUME (mk_eq(encoded,genvar target))
        end
    val assumed = map2 assumption operators encoders
    val unfold_o = PURE_REWRITE_CONV [o_THM]
    fun discharge thm = DISCH T (DISCH_ALL (CONV_HYP unfold_o thm))
in
    map discharge assumed
end
2232
2233(*****************************************************************************)
(* mk_predicate_rewrites : hol_type -> hol_type -> thm list * thm list       *)
2235(*                                                                           *)
2236(*    Predicates for regular datatypes are generated as:                     *)
2237(*             bool (?a b .. . x = Ci a b ..) =                              *)
2238(*             bool ((FST o decode_pair decode_num I o encode .. x)) = i)    *)
2239(*                                                                           *)
2240(*****************************************************************************)
2241
local
(* mk_predicates target t                                                    *)
(*     Starts from the clauses of the "encode" coding function definition    *)
(*     for t, applies the decoder for (num # target) to both sides and then  *)
(*     MK_FST, normalising both sides with the conversions below. From the   *)
(*     way the results are used in pred_term, each theorem has an operator   *)
(*     applied to a constructor term on the left.                            *)
fun mk_predicates target t =
let val encoders = CONJUNCTS (get_coding_function_def target t "encode")
    val decode_pair = gen_decode_function target
                                          (pairLib.mk_prod(numLib.num,target));
    val map_thm = FULL_ENCODE_DECODE_MAP_THM target
                                          (pairLib.mk_prod(numLib.num,alpha));
in
    (* The composition below is applied right to left: SPEC_ALL first. *)
    map (CONV_RULE (FORK_CONV (REWR_CONV (GSYM o_THM),
                     REWR_CONV pairTheory.FST_PAIR_MAP THENC
                     REWR_CONV I_THM THENC REWR_CONV pairTheory.FST)) o
         MK_FST o
         CONV_RULE (BINOP_CONV (REWR_CONV (GSYM o_THM)) THENC
                                RAND_CONV (RATOR_CONV (REWR_CONV map_thm))) o
         AP_TERM decode_pair o SPEC_ALL) encoders
end
(* pred_term var thm                                                         *)
(*     For thm of the form  |- P (C a ..) = i  builds the term               *)
(*        (?a .. . var = C a ..) = (P var = i)                               *)
fun pred_term var thm =
let val cons = rand (lhs (concl thm))
    val term = mk_eq(var,cons)
    val vars = snd (strip_comb cons)
in
    mk_eq(list_mk_exists(vars,term),
          mk_eq(mk_comb(rator (lhs (concl thm)),lhs term),rhs (concl thm)))
end;
in
(* mk_predicate_resolve target t                                             *)
(*     Derives one theorem per clause of the "encode" definition by applying *)
(*     the decoder and then the encoder for (num # target) to it, and        *)
(*     combines the results by case analysis over the nchotomy theorem of t, *)
(*     generalising over the analysed variable.                              *)
fun mk_predicate_resolve target t =
let val encoders = CONJUNCTS (get_coding_function_def target t "encode")
    val p1 = pairLib.mk_prod(numLib.num,target);
    val p2 = pairLib.mk_prod(numLib.num,alpha);
    val decode_pair = gen_decode_function target p1
    val encode_pair = gen_encode_function target p1
    val thms =
        map (fn thm =>
              (CONV_RULE (RAND_CONV (RAND_CONV
                (REWR_CONV (GSYM o_THM) THENC
                 RATOR_CONV (REWR_CONV
                  (FULL_ENCODE_DECODE_MAP_THM target p2))) THENC
                   REWR_CONV (GSYM o_THM) THENC
                   RATOR_CONV (REWR_CONV
                    (FULL_ENCODE_MAP_ENCODE_THM target p2) THENC
                     PURE_REWRITE_CONV [I_o_ID]) THENC
                     REWR_CONV (GSYM thm))) o
              AP_TERM encode_pair o AP_TERM decode_pair o SPEC_ALL) thm)
             encoders;
    val var = genvar t
    val nchot = ISPEC var (TypeBase.nchotomy_of t);
    (* Finds the nchotomy disjunct whose constructor matches the theorem,    *)
    (* instantiates the theorem to it, folds the constructor term back to    *)
    (* the analysed variable and eliminates the existential with CHOOSE_L.   *)
    fun fix_thm thm =
    let val rnd = rand (rhs (concl thm))
        val term = first (can (match_term rnd) o rhs o snd o strip_exists)
                         (strip_disj (concl nchot))
        val (vars,tm) = strip_exists term
        val m = match_term rnd (rhs tm)
        val rconv = REWR_CONV (GSYM (ASSUME tm))
    in  CHOOSE_L (vars,ASSUME (list_mk_exists(vars,tm)))
         (CONV_RULE (RAND_CONV (RAND_CONV rconv) THENC
                     LAND_CONV (RAND_CONV (RAND_CONV (RAND_CONV rconv))))
          (INST_TY_TERM m thm))
    end
in
    GEN var (DISJ_CASESL nchot (map fix_thm thms))
end
(* mk_predicate_rewrites target t                                            *)
(*     Returns a pair: the destructor rewrites made from the predicate       *)
(*     theorems of mk_predicates, and the conjuncts of the proved predicate  *)
(*     theorem (one pred_term per predicate), each wrapped in the encoder    *)
(*     for bool.                                                             *)
fun mk_predicate_rewrites target t =
let val predicates = mk_predicates target t
        handle e => raise (mkStandardExn "mk_predicate_rewrites"
               ("Could not prove initial predicate theorems...\nis " ^
                type_to_string t ^ " encoded as a labelled product?\n"))
    val var = genvar t
    val thm =
    (* Run the tactic on the goal directly; any remaining subgoal is a       *)
    (* failure, reported through the handler below.                          *)
    (case ((Cases THEN REWRITE_TAC (TypeBase.one_one_of t::
                             GSYM (TypeBase.distinct_of t)::
                             TypeBase.distinct_of t::predicates) THEN
     REPEAT (CHANGED_TAC (CONV_TAC (
            EVERY_CONJ_CONV (TRY_CONV EXISTS_REFL_CONV) THENC
            EVERY_CONJ_CONV (TOP_DEPTH_CONV EXISTS_AND_CONV)))) THEN
     CONV_TAC reduceLib.REDUCE_CONV)
    ([],mk_forall(var,list_mk_conj (map (pred_term var) predicates))))
    of ([],func) => func [] | _ => raise Empty)
    handle _ => raise (mkDebugExn "mk_predicate_rewrites"
                ("Could not prove the predicate theorem for the type: " ^
                 type_to_string t))
in
    (mk_destructor_rewrites target predicates,
     map (AP_TERM (get_encode_function target bool))
        (CONJUNCTS (SPEC_ALL thm)))
end
end
2328
2329(*****************************************************************************)
2330(* mk_case_propagation_theorem : hol_type -> hol_type -> term                *)
2331(*                                                                           *)
2332(*    Makes a standard propagation theorem, provided destructors have been   *)
2333(*    provided using set_destructors.                                        *)
2334(*                                                                           *)
2335(*****************************************************************************)
2336
2337local
2338fun mk_result out var destructors funcname c vars =
2339let val filtered = filter (can (match_term c) o fst o strip_comb o
2340                   rand o lhs o snd o strip_forall o concl) destructors
2341    fun pos thm = index (curry op= (rhs (concl thm)))
2342                        (snd (strip_comb (rand (lhs (concl thm)))));
2343    val mapped = map (fn x => (pos x,x)) filtered;
2344    val sorted = map (fn (n,_) => snd (first (curry op= n o fst) mapped))
2345                     (enumerate 0 (fst (strip_fun (type_of c))));
2346    val rfunc = mk_var(funcname,foldr (op-->) out (map type_of vars));
2347 in
2348    (list_imk_comb(rfunc,
2349        map (fn x => imk_comb(rator (lhs (concl x)),var)) sorted),sorted)
2350end handle e => wrapException "mk_result" e
2351fun wrap e = wrapException "mk_case_propagation_theorem" e
2352fun set_type t thm =
2353let val types = map (type_of o rand o lhs o concl) (CONJUNCTS thm)
2354in
2355    INST_TYPE (tryfind_e Empty (C match_type t) types) thm
2356end handle Empty => thm
2357in
(* mk_case_propagation_theorem target t                                      *)
(*    Builds and proves a propagation theorem for the case constant of t.    *)
(*    The goal states that the case term applied to a variable "arg" equals  *)
(*    a nested conditional which tests "?args. arg = C args" for each        *)
(*    constructor C (the last constructor is the final else branch) and      *)
(*    applies that branch's function to the destructors of arg. After the    *)
(*    proof, a variable "encode" : 'out -> 'target is applied to both sides  *)
(*    and the hypotheses of the destructor theorem are discharged.           *)
(*    Requires the "destructors" coding theorem for (target,t).              *)
(*    If the proof fails a debug exception is raised; when the debug flag    *)
(*    is set the goal is first installed with set_goal.                      *)
fun mk_case_propagation_theorem target t =
let val _ = trace 2 "->mk_case_propagation_theorem";
    val destructor_thm = set_type t (get_coding_theorem target t "destructors")
        handle e => wrap e
    val hs = hyp destructor_thm
    val destructors = map SPEC_ALL (CONJUNCTS destructor_thm)
    val constructors = constructors_of t handle e =>
        raise (mkStandardExn "mk_case_propagation_theorem"
                             ("The type: " ^ type_to_string t ^
                              " does not appear to be a regular datatype."))
    val var = mk_var("arg",t)
    (* One list of fresh argument variables per constructor *)
    val vars = map (fn c => map (fn (n,a) => mk_var(implode (base26 n),a))
                   (enumerate 0 (fst (strip_fun (type_of c))))) constructors;
    (* Guards of the conditional: ?args. arg = C args *)
    val cases = map2 (fn v => fn c => list_mk_exists(v,
                         mk_eq(var,list_mk_comb(c,v))))
                vars constructors
    (* Branch results, each a function variable applied to destructors of arg *)
    val (results,thms) = unzip (map2 (fn (n,c) => fn v =>
                       mk_result (mk_vartype "'out") var destructors
                                 (implode (base26 n)) c v)
                       (enumerate 0 constructors) vars) handle e => wrap e
    val conditional = list_mk_cond (butlast (zip cases results)) (last results)
                      handle e => wrap e
    (* Case definitions, instantiated so the result type is 'out and the *)
    (* scrutinee type is t, then put in constructor order.               *)
    val case_defs1 = map SPEC_ALL (CONJUNCTS (TypeBase.case_def_of t));
    val case_defs2 = map (fn thm => INST_TYPE
                         (match_type (type_of (rhs (concl thm)))
                                        (mk_vartype "'out")) thm) case_defs1
                     handle e => wrap e;
    val case_defs3 = map (fn thm => INST_TYPE
                         (match_type (type_of (rand (lhs (concl thm)))) t) thm)
                         case_defs2
                     handle e => wrap e;
    val ordered = map (fn c => first (can (match_term c) o fst o
                               strip_comb o rand o lhs o concl)
                               case_defs3) constructors
                     handle e => wrap e;
    (* Make every clause use the same case term, then substitute the branch *)
    (* function variables chosen in `results` for those of the definition.  *)
    val case_term1 = rator (lhs (concl (hd ordered))) handle e => wrap e
    val normalized = map (C (PART_MATCH (rator o lhs)) case_term1) ordered
                     handle e => wrap e
    val inst = map2 (fn nml => fn r =>
             (fst (strip_comb (rhs (concl nml)))) |-> (fst (strip_comb r)))
                    normalized results handle e => wrap e
    val case_defs = map (INST inst) normalized handle e => wrap e
    val case_term = rator (lhs (concl (hd case_defs))) handle e => wrap e

    val term = mk_forall(var,mk_eq(mk_comb(case_term,var),conditional))
               handle e => wrap e
    val enc = mk_var("encode",mk_vartype "'out" --> mk_vartype "'target")
    (* Rewrites for the proof; note this rebinds `thms` from above. The    *)
    (* TypeBase theorems are optional (mapfilter drops those that fail).   *)
    val thms = mapfilter (fn f => f t)
                         [TypeBase.one_one_of,TypeBase.distinct_of,
                          GSYM o TypeBase.distinct_of]
               @ (case_defs @ destructors);
in
    (* The tactic must leave no subgoals; anything else raises Empty, *)
    (* which is reported by the handler below.                        *)
    (DISCH_ALL_CONJ (AP_TERM enc (SPEC_ALL
         (case ((Cases THEN REPEAT (CHANGED_TAC (REWRITE_TAC thms)) THEN
                 REPEAT (
                        CONV_TAC (DEPTH_CONV (EXISTS_REFL_CONV ORELSEC
                                              EXISTS_AND_CONV)) THEN
                        REWRITE_TAC [])) (hs,term))
          of ([],func) => func []
          |  _ => raise Empty))))
     handle e =>
     (if !debug
         then (set_goal(hs,term) ;
               raise (mkDebugExn "mk_case_propagation_theorem"
                "Unable to prove propagation theorem! (Now set as goal)"))
         else raise (mkDebugExn "mk_case_propagation_theorem"
                 "Unable to prove propagation theorem!"))
end
2426end
2427
2428(*****************************************************************************)
2429(* Propagation theorem for a case constructor for a single constructor.      *)
2430(*****************************************************************************)
2431
(* mk_single_case_propagation_theorem t                                      *)
(*    Takes the (specialised) case definition of t and applies a variable    *)
(*    "encode", mapping the type of the case expression to a fresh type      *)
(*    variable, to both sides of it.                                         *)
fun mk_single_case_propagation_theorem t =
    let val case_def = SPEC_ALL (TypeBase.case_def_of t)
        val source_type = type_of (lhs (concl case_def))
        val encoder = mk_var("encode",source_type --> gen_tyvar())
    in
        AP_TERM encoder case_def
    end;
2438
2439(*****************************************************************************)
2440(* Propagation theorem for a label type.                                     *)
2441(*****************************************************************************)
2442
(* mk_label_case_propagation_theorem t                                       *)
(*    Proves that a case expression over the constructors of t, with one     *)
(*    result variable of type alpha per constructor, equals a chain of       *)
(*    conditionals "if argument = C then result else ...", the result of     *)
(*    the last constructor being the final else branch. The equation is      *)
(*    proved separately under each disjunct of the nchotomy theorem and the  *)
(*    cases are combined; an "encode" variable is then applied to both       *)
(*    sides. Failures are re-raised through wrapException.                   *)
fun mk_label_case_propagation_theorem t =
    let val case_def = SPEC_ALL (TypeBase.case_def_of t)
        val labels = constructors_of t
        fun result_var (n,_) = mk_var(implode (base26 n),alpha)
        val results = map result_var (enumerate 0 labels)
        val arg = mk_var("argument",t)
        fun branch result label = (mk_eq(arg,label),result)
        val if_chain =
            list_mk_cond (map2 branch (butlast results) (butlast labels))
                         (last results)
        val enc = mk_var("encode",alpha --> gen_tyvar())
        val goal = mk_eq(TypeBase.mk_case(arg,zip labels results),if_chain)

        val nchotomy = ISPEC arg (TypeBase.nchotomy_of t)
        val distinct = TypeBase.distinct_of t
        (* Proves the goal assuming one disjunct of the nchotomy theorem *)
        fun prove_under assumption =
            CONV_RULE bool_EQ_CONV
                (REWRITE_CONV
                     [GSYM distinct,distinct,case_def,ASSUME assumption]
                     goal)
        val proved_cases = map prove_under (strip_disj (concl nchotomy))
    in
        AP_TERM enc (DISJ_CASESL nchotomy proved_cases)
    end
    handle e => wrapException "mk_label_case_propagation_theorem" e
2464
2465(*****************************************************************************)
2466(* target_function_conv : hol_type -> term -> thm                            *)
(*    If a term is of the form: I M and all subterms of M are functions over *)
(*    the type given, then the theorem |- I M = M is returned.               *)
2469(*****************************************************************************)
2470
(* check_term target M                                                       *)
(*    True when M is built only from: variables of type target, constants    *)
(*    whose argument and result types are all target, paired abstractions    *)
(*    whose bound variables have type target and whose body passes the       *)
(*    check, and applications whose operator and operand pass the check.     *)
fun check_term target M =
    let fun is_target ty = (target = ty)
        fun has_target_type tm = is_target (type_of tm)
        fun const_ok () =
            let val (arg_types,result_type) = strip_fun (type_of M)
            in  all is_target (result_type :: arg_types)
            end
        fun pabs_ok () =
            let val (bound,body) = pairSyntax.dest_pabs M
            in  both (all has_target_type (pairSyntax.strip_pair bound),
                      check_term target body)
            end
        fun comb_ok () =
            check_term target (rator M) andalso check_term target (rand M)
    in
        (is_var M andalso has_target_type M) orelse
        (is_const M andalso const_ok ()) orelse
        (pairSyntax.is_pabs M andalso pabs_ok ()) orelse
        (is_comb M andalso comb_ok ())
    end;
2481
(* target_function_conv target term                                          *)
(*    Rewrites "I M" to "M" with I_THM when I is the identity at type        *)
(*    target --> target and M satisfies check_term; otherwise fails via      *)
(*    NO_CONV.                                                               *)
fun target_function_conv target term =
    let val operator = rator term
        val identity = mk_const("I",target --> target)
        val applicable =
            same_const operator identity andalso check_term target (rand term)
    in
        if applicable then REWR_CONV I_THM term else NO_CONV term
    end
2487
2488(*****************************************************************************)
2489(* dummy_encoder_conv : hol_type -> term -> thm                              *)
2490(*****************************************************************************)
2491
(* dummy_encoder_conv target term                                            *)
(*    Rewrites "I M" to "M" with I_THM when I is the identity at type        *)
(*    target --> target, M has type target and is_encoded_term M holds;      *)
(*    otherwise fails via NO_CONV.                                           *)
fun dummy_encoder_conv target term =
    let val operator = rator term
        val identity = mk_const("I",target --> target)
    in
        if not (same_const operator identity) then NO_CONV term
        else if type_of (rand term) = target andalso
                is_encoded_term (rand term)
        then REWR_CONV I_THM term
        else NO_CONV term
    end
2497
2498(*****************************************************************************)
2499(* add_standard_coding_rewrites : hol_type -> hol_type -> unit               *)
2500(*                                                                           *)
2501(* Add rewrites to encode constructors, case theorems, and encoding of       *)
2502(* decoded terms:                                                            *)
2503(*      |- encode (C x y z) = cons 0 (encode x) (encode y) (encode z)        *)
2504(*      |- E (if ?a b c. X = C a b c then A (f0 (R x)) ... = Y) ==>          *)
2505(*         (E (case x of C a b c -> A a b c ....) = Y                        *)
2506(*      |- (!x. f0' (f0 x) = x) /\ (encode f0 f1 ... x = X) /\               *)
2507(*         (?a b c. X = C a b c) ==> T ==>                                   *)
2508(*              (f0 (f0' (R X)) = R X)                                       *)
2509(*      |- detect a ==> (encode (decode a) = a)                              *)
2510(*                                                                           *)
2511(*****************************************************************************)
2512
2513local
(* Best-effort UNDISCH_CONJ: the theorem is returned unchanged on failure. *)
fun U thm = (UNDISCH_CONJ thm) handle _ => thm
(* check_case_thm target t                                                   *)
(*    Scans the stored rewrites: takes the left-hand side of each rewrite    *)
(*    conclusion (rewrites of another shape are skipped), keeps those whose  *)
(*    type matches target, and tests whether the head of the argument of     *)
(*    any of them matches the case constant of t.                            *)
fun check_case_thm target t =
    let fun rewrite_lhs (_,_,thm) = lhs (snd (strip_imp (concl thm)))
        fun has_target_type tm =
            can (fn ty => match_type ty target) (type_of tm)
        fun argument_head tm = fst (strip_comb (rand tm))
        val lhss = mapfilter rewrite_lhs (Net.listItems (!rewrites))
        val heads = map argument_head (filter has_target_type lhss)
        val case_const = TypeBase.case_const_of t
    in
        exists (can (fn head => match_term head case_const)) heads
    end;
(* encode_names target t                                                     *)
(*    Pairs each conjunct of the "encode" function definition for t with a   *)
(*    rewrite name of the form C<index>_<type name>, numbering from 0.       *)
fun encode_names target t =
    let val definition = get_coding_function_def target t "encode"
        val numbered = enumerate 0 (CONJUNCTS definition)
        val type_name = fst (dest_type t)
        fun label (n,clause) =
            ("C" ^ int_to_string n ^ "_" ^ type_name,clause)
    in
        map label numbered
    end
(* Adds a standard rewrite unless a rewrite with this name already exists. *)
fun add_ifnew_standard_rewrite priority name thm =
    if not (exists_rewrite name)
       then add_standard_rewrite priority name thm
       else ()
(* Adds a conditional rewrite unless a rewrite with this name already exists. *)
fun add_ifnew_conditional_rewrite priority name thm =
    if not (exists_rewrite name)
       then add_conditional_rewrite priority name thm
       else ()
2540in
(* add_decodeencode_rewrites target t                                        *)
(*    Derives the decode/encode theorem for t from FULL_DECODE_ENCODE_THM,   *)
(*    moving any antecedents into a single discharged conjunction, and       *)
(*    registers it as the standard rewrite "DE_<type name>" unless a rewrite *)
(*    of that name already exists. Returns the theorem.                      *)
(*    Failures while deriving the theorem are labelled with this function's  *)
(*    own name (the label used to be the non-existent                        *)
(*    "add_decodeencode_function", which made traces misleading).            *)
fun add_decodeencode_rewrites target t =
let val decenc_thm = DISCH_ALL_CONJ (U (SPEC_ALL (U (SPEC_ALL
                                (FULL_DECODE_ENCODE_THM target t)))))
        handle e => wrapException "add_decodeencode_rewrites" e
    val name = fst (dest_type t)
in  (add_ifnew_standard_rewrite 0 ("DE_" ^ name) decenc_thm ;
     decenc_thm)
end
(* add_encode_rewrites target t                                              *)
(*    Registers, as standard rewrites, those encoder clauses of t whose      *)
(*    names are not yet in use, and returns the first newly added clause.    *)
(*    Raises Empty (from hd) when every clause was already registered.       *)
fun add_encode_rewrites target t =
    let fun is_new (name,_) = not (exists_rewrite name)
        val fresh = filter is_new (encode_names target t)
        fun register (name,clause) = add_ifnew_standard_rewrite 0 name clause
        val _ = app register fresh
    in
        snd (hd fresh)
    end;
(* add_case_rewrites target t                                                *)
(*    Generates and registers the rewrites needed to translate case          *)
(*    expressions over t, and returns the case propagation theorem as an     *)
(*    option (NONE when check_case_thm finds a suitable rewrite already).    *)
(*    Rewrite names registered here: Case_<t>, EDp_<t><n>, P_<t><n>,         *)
(*    PED_<t>, Po_<t><n> and Do_<t><n>.                                      *)
fun add_case_rewrites target t =
let (* Destructors are only built when none are stored yet and t has more *)
    (* than one constructor, at least one of which takes arguments.       *)
    val (pair_rewrites,destructor_rewrites) =
        if exists_coding_theorem target t "destructors"
           orelse length (constructors_of t handle _ => [T]) = 1
           orelse all (not o can dom_rng o type_of) (constructors_of t)
           then ([],[])
           else let val (p,d) = mk_destructors target t
                    val _ = set_destructors target d
                in   (p,mk_destructor_rewrites target d)
                end handle e => wrapException "add_case_rewrites" e
    (* Choose the propagation theorem by the shape of the datatype:        *)
    (* single constructor, label type (no constructor takes arguments),    *)
    (* or the general case.                                                *)
    val case_thm =
        if check_case_thm target t then NONE else
        if length (constructors_of t) = 1 then
           SOME (mk_single_case_propagation_theorem t)
        else if all (not o can dom_rng o type_of) (constructors_of t) then
           SOME (mk_label_case_propagation_theorem t)
        else SOME (mk_case_propagation_theorem target t
                   handle e => wrapException "add_case_rewrites" e)
    val name = fst (dest_type t)
    (* Predicate rewrites are optional: a non-fatal failure is traced and *)
    (* replaced by empty lists.                                           *)
    val (predicate_rewrites,predicates) =
        if length (constructors_of t) = 1 orelse
           all (not o can dom_rng o type_of) (constructors_of t) then ([],[])
           else    mk_predicate_rewrites target t
                   handle e => if isFatal e then raise e
                                 else (trace 1 (exn_to_string e) ; ([],[]))
    val predicate_resolve = total (mk_predicate_resolve target) t;
in
    (* The EDp_ rewrites are added unconditionally; all others only when *)
    (* no rewrite of the same name exists.                               *)
    (Option.map (add_ifnew_standard_rewrite 0 ("Case_" ^ name)) case_thm ;
     app (fn (n,c) => add_standard_rewrite 0
                      ("EDp_" ^ name ^ int_to_string n) c)
                      (enumerate 0 pair_rewrites) ;
     app (fn (n,c) => add_ifnew_standard_rewrite 0
                      ("P_" ^ name ^ int_to_string n) c)
                      (enumerate 0 predicates) ;
     Option.map (add_ifnew_standard_rewrite 0 ("PED_" ^ name))
                 predicate_resolve ;
     app (fn (n,c) => add_ifnew_conditional_rewrite 0
                      ("Po_" ^ name ^ int_to_string n) c)
                      (enumerate 0 predicate_rewrites) ;
     app (fn (n,c) => add_ifnew_conditional_rewrite 0
                      ("Do_" ^ name ^ int_to_string n) c)
                      (enumerate 0 destructor_rewrites) ;
     case_thm)
     handle e => wrapException "add_case_rewrites" e
end
(* add_standard_coding_rewrites target t                                     *)
(*    Checks that t is a base type with TypeBase constructors, encodes the   *)
(*    type, then adds the encoder, decode/encode and case rewrites for it.   *)
(*    An Empty exception from add_encode_rewrites is tolerated; any other    *)
(*    failure is re-raised under this function's name.                       *)
fun add_standard_coding_rewrites target t =
    let val this = "add_standard_coding_rewrites"
        fun relabel e = wrapException this e
    in
        (if not (can (match_type t) (base_type t)) then
            raise (mkStandardExn this
                   ("The type " ^ type_to_string t ^ " is not a base type"))
         else () ;
         if not (can TypeBase.constructors_of t) then
            raise (mkStandardExn this
                   ("The type supplied: " ^ type_to_string t ^
                    " does not appear to be a regular datatype.\n"))
         else () ;
         encode_type target t ;
         (add_encode_rewrites target t
              handle Empty => TRUTH
                   | e => relabel e) ;
         (add_decodeencode_rewrites target t handle e => relabel e) ;
         (add_case_rewrites target t handle e => relabel e) ;
         ())
    end
2619end;
2620
2621(*****************************************************************************)
(* polytypic_decodeencode, polytypic_casestatement, polytypic_encodes:       *)
(*    If a statement is of the form: encode (decode x), encode (case ...) or *)
(*    an encoded constructor, the corresponding function adds the standard   *)
(*    rewrites for the type involved, then returns the relevant theorem.     *)
2626(*                                                                           *)
2627(*****************************************************************************)
2628
(* polytypic_decodeencode term                                               *)
(*    For a term "encode (decode x)": checks that the operator of the        *)
(*    operand matches the decode function for its type, then adds the        *)
(*    decode/encode rewrite for that type and returns it passed through      *)
(*    conditionize_rewrite. Raises a standard exception otherwise.           *)
fun polytypic_decodeencode term =
    let val (_,decoded_term) = dest_comb term
        val target = type_of term
        val var_type = type_of decoded_term
        val decoder = get_decode_function target var_type
        (* Result unused; dest_type is kept only because it fails when *)
        (* var_type is not a type operator application.                *)
        val _ = dest_type var_type
        val is_decoded = can (match_term decoder) (rator decoded_term)
    in
        if not is_decoded
           then raise (mkStandardExn "polytypic_decodeencode"
                                     "Not a term: encode (decode x)")
           else conditionize_rewrite
                    (add_decodeencode_rewrites target var_type)
    end
2640
(* polytypic_casestatement term                                              *)
(*    For a term "encode (case x of ...)": adds the case rewrites for the    *)
(*    base type of x and returns the case propagation theorem passed         *)
(*    through conditionize_rewrite. Raises a standard exception when the     *)
(*    operand is not a case expression; Option.valOf fails if no new case    *)
(*    theorem was produced.                                                  *)
fun polytypic_casestatement term =
    if not (TypeBase.is_case (rand term))
       then raise (mkStandardExn "polytypic_casestatement"
                                 "Not an encoded case statement")
       else
           let val target = type_of term
               val scrutinee_type = base_type (type_of (rand (rand term)))
           in
               conditionize_rewrite
                   (Option.valOf (add_case_rewrites target scrutinee_type))
           end;
2647
(* polytypic_encodes term                                                    *)
(*    When the operand of term is related by match_term to one of the        *)
(*    constructors of its type, adds the encoder rewrites for the base type  *)
(*    and returns the result passed through conditionize_rewrite; otherwise  *)
(*    raises a standard exception.                                           *)
fun polytypic_encodes term =
    let val encoded = rand term
        val candidates = constructors_of (type_of encoded)
        fun matches a b = can (match_term a) b
    in
        if op_mem matches encoded candidates
           then conditionize_rewrite
                    (add_encode_rewrites (type_of term)
                                         (base_type (type_of (rand term))))
           else raise (mkStandardExn "polytypic_encodes"
                                     "Not an encoded constructor")
    end;
2656
2657(*****************************************************************************)
2658(* prove_propagation_theorem : thm list -> thm -> thm                        *)
2659(*                                                                           *)
2660(*    Takes a definition term, D,  of the form:                              *)
2661(*      ``!a b. <f> a b =                                                    *)
2662(*        encode (if (detect a /\ detect b /\                                *)
2663(*                    ... /\ t (decode a) (decode b) /\ ....                 *)
2664(*                    ... /\ P (decode a) (decode b) /\ ....)                *)
2665(*                   then f (decode a) (decode b)                            *)
2666(*                   else bottom)``                                          *)
2667(*    and proves a propagation theorem:                                      *)
2668(*    [D] |- P a b ... ==> (encode (f a b) = <f> (encode a) (encode b))      *)
2669(*                                                                           *)
2670(*****************************************************************************)
2671
(* Arguments (map theorem, tautologies, definition) of the propagation proof *)
(* currently in progress; reset to NONE once a proof completes.              *)
val prove_propagation_theorem_data
    : (thm option * thm list * term) option ref
    = ref NONE;
2674
2675local
(* Exception reported when the supplied definition has the wrong shape. *)
fun exn1 caller =
    mkStandardExn caller
 "Definition is not of the form: \"|- !a b. f a b ... = encode (...)\"";
(* ECC c x                                                                   *)
(*    Applies the conversion c to x; if that fails, x is split as a          *)
(*    conjunction, ECC is applied to both conjuncts and the results are      *)
(*    recombined with MK_CONJ. If that fails too, the result is NO_CONV x.   *)
fun ECC c x =
    c x handle _ =>
        (let val left = ECC c (fst (dest_conj x))
             val right = ECC c (snd (dest_conj x))
         in
             MK_CONJ left right
         end
         handle _ => NO_CONV x);
(* TCONV conv term                                                           *)
(*    Applies conv and succeeds only when the resulting right-hand side is   *)
(*    T; otherwise fails via NO_CONV.                                        *)
fun TCONV conv term =
    let val result = conv term
        val proves_truth = (rhs (concl result) = T)
    in
        if proves_truth then result else NO_CONV term
    end;
(* map_thm_vars thm                                                          *)
(*    Takes the equation concluding thm (after specialisation and stripping  *)
(*    antecedents). If both sides are headed by the same constant, returns   *)
(*    the variable arguments of the right-hand side; otherwise returns the   *)
(*    variable arguments of the operand of the right-hand side.              *)
fun map_thm_vars thm =
    let val equation = snd (strip_imp (concl (SPEC_ALL thm)))
        val (l,r) = dest_eq equation
        val same_head =
            same_const (repeat rator l) (repeat rator r) handle _ => false
        val carrier = if same_head then r else rand r
    in
        filter is_var (snd (strip_comb carrier))
    end
(* reduce_hyp (thm1,thm2)                                                    *)
(*    Looks for a hypothesis of thm2 that matches the consequent of the      *)
(*    implication thm1 and replaces it using the instantiated implication.   *)
(*    thm2 is returned unchanged if any step fails.                          *)
fun reduce_hyp (thm1,thm2) =
    let val match_consequent = PART_MATCH (snd o dest_imp) thm1
        val implication = tryfind_e Empty match_consequent (hyp thm2)
    in
        PROVE_HYP (UNDISCH_ONLY implication) thm2
    end
    handle _ => thm2
(* prove_propagation_theorem_local this_function map_thm tautologies         *)
(*                                 definition                                *)
(*    Shared worker for prove_propagation_theorem (map_thm = NONE) and       *)
(*    prove_polymorphic_propagation_theorem (map_thm = SOME thm).            *)
(*    definition is a term "!a b. <f> a b = encode (...)"; it is ASSUMEd,    *)
(*    instantiated with encoded arguments, its guard (if any) is proved      *)
(*    using the tautologies and coding theorems, and the remaining           *)
(*    "encode (f (decode (encode a)) ...)" is reduced. The final theorem     *)
(*    must have the definition as its only hypothesis.                       *)
(*    this_function is the name used when labelling exceptions.              *)
(*    The arguments are stored in prove_propagation_theorem_data for the     *)
(*    duration of the call; the reference is reset to NONE only on success.  *)
fun prove_propagation_theorem_local
    this_function map_thm tautologies definition =
let val _ = prove_propagation_theorem_data
          := SOME (map_thm,tautologies,definition);
    val (left,right) = with_exn (dest_eq o snd o strip_forall) definition
                       (exn1 this_function)
    val _ = trace 1 "Proving propagation theorem...\n"
    val target = type_of right
    val rright = with_exn rand right (exn1 this_function)
    val args = (snd o strip_comb) left
    (* The translated body: the "then" branch when the body is guarded *)
    val function_term = if is_cond rright then (rand o rator) rright else rright
    fun test x = is_var (rand x) handle _ => false
    (* Arguments of the form "<decoder> v". With a map theorem the decoders *)
    (* are rebuilt from the types of the variables in that theorem.         *)
    (* Note: the handler below is attached to the SOME branch only.         *)
    val decoded_args =
        case map_thm
        of NONE => filter test ((snd o strip_comb) function_term)
        |  SOME thm =>
           map2 (fn a => fn b =>
                  mk_comb(gen_decode_function target (type_of a),b))
                (map_thm_vars thm)
                (map rand
                     (filter test ((snd o strip_comb) function_term)))
        handle e => wrapException this_function e

    (* "encode a" for each formal argument, re-typed at the decoded type *)
    val encoded_args = map2 (fn d => fn a =>
            mk_comb(gen_encode_function target (type_of d),
                    mk_var(fst(dest_var a),type_of d)))
                    decoded_args args
            handle e => wrapException this_function e
    val rule = PURE_REWRITE_RULE [o_THM,FUN_EQ_THM,K_THM,I_THM];

    (* Used to use FULL.... changed 09/10/2010 *)
    val encdetall_thms = map (
           rule o generate_coding_theorem target "encode_detect_all" o type_of) decoded_args
           handle e => wrapException this_function e
    val encdecmap_thms = map (
           rule o generate_coding_theorem target "encode_decode_map" o type_of) decoded_args
           handle e => wrapException this_function e
    val allid_thms = map (
           rule o FULL_ALL_ID_THM o type_of) decoded_args
           handle e => wrapException this_function e
    val mapid_thms = map (
           rule o FULL_MAP_ID_THM o type_of) decoded_args
           handle e => wrapException this_function e
    (* Only available when a map theorem was supplied (NONE otherwise) *)
    val encmap_thm =
           total (rule o FULL_ENCODE_MAP_ENCODE_THM target o type_of o
            rhs o snd o strip_imp o concl o SPEC_ALL o valOf) map_thm;

    (* The assumed definition with every argument replaced by "encode arg" *)
    val instantiated =
           INST
           (map2 (fn arg => fn ed => arg |-> ed) args encoded_args)
           (SPEC_ALL (ASSUME definition))
           handle e => wrapException this_function e
    val _ = trace 3 ("Instantiated definition:\n" ^
                     thm_to_string instantiated ^ "\n")
    (* Each tautology as a theorem rewriting its own conclusion *)
    val thms_as_rwrs =
           map (fn t => REWRITE_CONV [t] (snd (strip_forall (concl t))))
               tautologies
    (* Drops theorems containing a conjunct of the form x = x *)
    val filter_refl = filter
        (not o exists ((fn x => (lhs x = rhs x)) o snd o strip_forall) o
         strip_conj o concl o SPEC_ALL)
    (* Proves one conjunct of the guard: first by matching a tautology, *)
    (* otherwise by rewriting with the detect/all-identity theorems.    *)
    fun rwr_conv term =
        (tryfind_e Empty
                  (UNDISCH_CONJ o
                   CONV_RULE (LAND_CONV (PURE_REWRITE_CONV
                             (filter_refl (encdecmap_thms @ mapid_thms)))) o
                   C (PART_MATCH (lhs o snd o strip_imp)) term o
                   DISCH_ALL_CONJ)
                  thms_as_rwrs
        handle Empty =>
        REPEATC (CHANGED_CONV (PURE_REWRITE_CONV (filter_refl
                (encdetall_thms @ allid_thms @ [o_THM,K_o_THM])))) term)
        handle e => wrapException this_function e;
    (* Reduce the guard to T and select the "then" branch *)
    val cond_proved =
           (if is_cond rright then
              RIGHT_CONV_RULE (RAND_CONV (RATOR_CONV (RATOR_CONV (RAND_CONV (
                  ECC (TCONV rwr_conv) THENC
                  REWRITE_CONV [AND_CLAUSES]))) THENC
                  REWR_CONV (CONJUNCT1 (SPEC_ALL COND_CLAUSES))))
              instantiated
           else instantiated)
           handle e => raise (mkStandardExn this_function
                  ("Could not prove all terms in the conjunction:\n" ^
                   term_to_string (rand (rator (rator (rand
                                  (rhs (concl instantiated)))))) ^
                   "\nusing the list of theorems:\n" ^
                   xlist_to_string thm_to_string thms_as_rwrs));
    val cond_proved' = foldl reduce_hyp cond_proved tautologies;
    (* Hypotheses other than the definition become antecedents of the result *)
    val list = set_diff (hyp cond_proved') [definition]
    fun MCONV NONE term = REFL term
      | MCONV (SOME map_thm) term =
        UNDISCH_ALL (PART_MATCH (lhs o snd o strip_imp) map_thm term)
    (* Terms that are not already encoded are wrapped as "I term" *)
    fun fix_I term =
        if is_encoded_term term then REFL term
           else SYM (ISPEC term I_THM)
    (* Removes the decode/encode pairs beneath the outer encoder; the result *)
    (* must only mention variables or original arguments of the function.   *)
    fun rwr_ed_conv term =
    let val r =
        (REPEATC (CHANGED_CONV (RAND_CONV (PURE_REWRITE_CONV (filter_refl
                (encdecmap_thms @ mapid_thms @ [I_THM,I_o_ID]))))) THENC
        RAND_CONV (MCONV map_thm) THENC
        PURE_REWRITE_CONV [I_THM,I_o_ID] THENC
        PURE_ONCE_REWRITE_CONV (mapfilter valOf [encmap_thm]) THENC
        PURE_REWRITE_CONV [I_o_ID])
        term handle e => wrapException this_function e
    in
        if all (fn x => is_var x orelse mem x (snd (strip_comb function_term)))
               (snd (strip_comb (rand (rhs (concl r)))))
               then RIGHT_CONV_RULE fix_I r else
        raise (mkDebugExn this_function
              ("Could not fully reduce the term: " ^ term_to_string term))
    end
    (* The definition must be the only remaining hypothesis *)
    fun check x =
        if length (hyp x) = 1 then x else
           raise (mkDebugExn this_function
                 ("The resulting propagation theorem contains an unwanted " ^
                  "hypothesis:\n" ^ thm_to_string x ^
                  "\nIf this is a polymorphic theorem, check that theorems" ^
                  "\ndemonstrating the limits are map-invariant are included" ^
                  "\neg: (?a b. x = a :: b) ==> (?a b. MAP f x = a :: b)"))
in  check
    (DISCH_LIST_CONJ list (SYM (RIGHT_CONV_RULE rwr_ed_conv cond_proved')))
    before (prove_propagation_theorem_data := NONE)
end
2819in
(* Monomorphic entry point: runs the shared prover without a map theorem. *)
fun prove_propagation_theorem tautologies definition =
    let val label = "prove_propagation_theorem"
    in
        prove_propagation_theorem_local label NONE tautologies definition
    end
(* Polymorphic entry point: runs the shared prover with the map theorem. *)
fun prove_polymorphic_propagation_theorem map_thm tautologies definition =
    let val label = "prove_polymorphic_propagation_theorem"
    in
        prove_propagation_theorem_local
            label (SOME map_thm) tautologies definition
    end
2827end;
2828
2829(*****************************************************************************)
(* mk_analogue_definition                                                    *)
(*         : hol_type -> string -> thm list -> term list -> thm -> thm       *)
(* mk_polymorphic_analogue_definition                                        *)
(*         : hol_type -> string -> thm -> thm list -> term list ->           *)
(*                                               thm list -> thm -> thm      *)
(* mk_analogue_definition_term                                               *)
(*         : hol_type -> string -> term list -> thm -> term                  *)
2835(*                                                                           *)
2836(*     mk_analogue_definition target name                                    *)
2837(*            [... |- !a b. t ...]   (|- f a b = A a b, [\a b. P a b...])    *)
2838(*     creates a definition term, D, of the following form:                  *)
2839(*     |- !a b. <f> a b =                                                    *)
2840(*        encode (if (detect a /\ detect b /\                                *)
2841(*                    t (decode a) (decode b) /\ ... /\                      *)
2842(*                    P (decode a) (decode b) /\ ...)                        *)
2843(*                   then f (decode a) (decode b)                            *)
2844(*                   else bottom)                                            *)
2845(*     and returns the propagation theorem:                                  *)
2846(*     [D] |- P a b /\ ... ==> (encode (f a b) = <f> (encode a) (encode b)) *)
2847(*                                                                           *)
2848(*****************************************************************************)
2849
2850local
(* Name under which MK_DEFINITION labels its failures. The identifier is     *)
(* rebound locally inside each of the public functions defined further on.   *)
val this_function = "MK_DEFINITION";
val exn1 = mkStandardExn this_function
           "Function is not of the form: \"|- !a b. P ==> f a b ... = \"";
(* MK_DEFINITION tvs target limits function                                  *)
(*    Computes the ingredients of an analogue definition for the theorem     *)
(*    `function` (an equation "f a b ... = ...", possibly under              *)
(*    antecedents), after instantiating the type variables tvs to target.    *)
(*    Returns the tuple                                                      *)
(*      (body, new_args, detected_args, specced_limits, target,              *)
(*       result_type, bottom)                                                *)
(*    where new_args are the variable arguments re-typed at target, body is  *)
(*    f applied to the decoded new arguments (non-variable arguments are     *)
(*    left as they are), detected_args are the detect functions applied to   *)
(*    the new arguments, specced_limits are the limits applied to the        *)
(*    decoded arguments and beta-reduced, and bottom comes from the          *)
(*    translation scheme of target.                                          *)
fun MK_DEFINITION tvs target limits function =
let val tfunction = INST_TYPE (map (fn x => x |-> target) tvs) function
    val sfunction = SPEC_ALL tfunction
    val STRIP = snd o strip_imp o concl
    val (fconst,args) = with_exn (strip_comb o lhs o STRIP) sfunction exn1
    val result_type = with_exn (type_of o rhs o STRIP) sfunction exn1
    (* Only variable arguments are translated *)
    val (normal_args,higher_args) = partition is_var args
    val new_args = map (C (curry mk_var) target o fst o dest_var) normal_args
    val decoded_args = map2 (fn arg => fn new_arg =>
            mk_comb(gen_decode_function target (type_of arg),new_arg))
            normal_args new_args handle e => wrapException this_function e
    val specced_limits =
        map (full_beta_conv o C (curry list_imk_comb) decoded_args) limits
            handle e => wrapException this_function e
    val detected_args = map2 (fn arg => fn new_arg =>
            mk_comb(gen_detect_function target (type_of arg),new_arg))
            normal_args new_args handle e => wrapException this_function e
    (* Original argument order: variables decoded, other arguments unchanged *)
    val decode_map =
        map (C assoc (zip normal_args decoded_args @
                      zip higher_args higher_args)) args
    val body = list_mk_comb(fconst,decode_map)
            handle e => wrapException this_function e
    val bottom = #bottom (get_translation_scheme target)
            handle e => wrapException this_function e
in
     (body,new_args,detected_args,specced_limits,target,result_type,bottom)
end
2881in
(* mk_analogue_definition_term target name limits function                   *)
(*    Builds the definition term                                             *)
(*        !args. name args = encode (if guards then body else bottom)        *)
(*    where the guards are the detect terms followed by the specialised      *)
(*    limits returned by MK_DEFINITION. The conditional is omitted when      *)
(*    there are no guards.                                                   *)
(*                                                                           *)
(*    Every failure is labelled with this function's own name. Previously    *)
(*    the final handler stood outside the let, so `this_function` referred   *)
(*    to the enclosing binding "MK_DEFINITION"; moreover the handler written *)
(*    after the case expression guarded only its last branch. All            *)
(*    construction now happens in `build`, under a single correctly named    *)
(*    handler.                                                               *)
fun mk_analogue_definition_term target name limits function =
let val this_function = "mk_analogue_definition_term"
    fun build () =
    let val _ = trace 2 "->mk_analogue_definition_term\n"
        val _ = trace 1 ("Creating analogue definition of:\n" ^
                          thm_to_string function ^ "\n")
        val (body,new_args,detected_args,specced_limits,
                    target,result_type,bottom) =
             MK_DEFINITION [] target limits function
        val conditional =
            case (detected_args @ specced_limits)
            of [] => body
            |  guards => mk_cond(list_mk_conj guards,body,
                   construct_bottom_value
                   (fn ty => is_vartype ty orelse ty = target)
                   bottom result_type)
        val right = mk_comb(gen_encode_function target result_type,conditional)
        (* Every new argument has type target, as does the result *)
        val new_fconst = mk_var(name,foldl op--> target (map type_of new_args))
        val function_term = mk_eq(list_mk_comb(new_fconst,new_args),right)
    in
        list_mk_forall(new_args,function_term)
    end
in
    build () handle e => wrapException this_function e
end
(* get_tautologies tautologies specced_limits new_args                       *)
(*    Instantiates, for each specialised limit, the first tautology that     *)
(*    can be matched against it (limits with no matching tautology are       *)
(*    skipped). Raises a standard exception if an instantiated tautology     *)
(*    mentions a free variable that is not one of new_args.                  *)
fun get_tautologies tautologies specced_limits new_args =
    let fun instantiate limit =
            tryfind (fn taut => PART_MATCH I taut limit) tautologies
        val instantiated = mapfilter instantiate specced_limits
        fun has_stray_vars tm =
            not (null (set_diff (free_vars tm) new_args))
        val offender = total (first has_stray_vars) (map concl instantiated)
    in
        case offender
        of SOME y => raise (mkStandardExn "get_tautologies"
                            ("The tautology:\n" ^ term_to_string y ^
                             "\ncontains free variables after instantiation"))
        |  NONE => instantiated
    end
(* mk_analogue_definition target name tautologies limits function            *)
(*    Builds the analogue definition term for `function` (see                *)
(*    mk_analogue_definition_term for its shape) and proves the              *)
(*    corresponding propagation theorem with prove_propagation_theorem,      *)
(*    using the instantiated tautologies together with the specialised       *)
(*    limits as assumptions.                                                 *)
(*                                                                           *)
(*    Every failure is labelled with this function's own name. Previously    *)
(*    the final handler stood outside the let, so `this_function` referred   *)
(*    to the enclosing binding "MK_DEFINITION"; moreover the handler written *)
(*    after the case expression guarded only its last branch.                *)
fun mk_analogue_definition target name tautologies limits function =
let val this_function = "mk_analogue_definition"
    fun build () =
    let val _ = trace 2 "->mk_analogue_definition\n"
        val _ = trace 1 ("Creating analogue definition of:\n" ^
                          thm_to_string function ^ "\n")
        val (body,new_args,detected_args,specced_limits
                    ,target,result_type,bottom) =
             MK_DEFINITION [] target limits function
        val checked_tauts = get_tautologies tautologies specced_limits new_args
        val conditional =
            case (detected_args @ specced_limits)
            of [] => body
            |  guards => mk_cond(list_mk_conj guards,body,
                   construct_bottom_value
                   (fn ty => is_vartype ty orelse ty = target)
                   bottom result_type)
        val right = mk_comb(gen_encode_function target result_type,conditional)
        (* Every new argument has type target, as does the result *)
        val new_fconst = mk_var(name,foldl op--> target (map type_of new_args))
        val function_term = mk_eq(list_mk_comb(new_fconst,new_args),right)
    in
        prove_propagation_theorem (checked_tauts @ map ASSUME specced_limits)
             (list_mk_forall(new_args,function_term))
    end
in
    build () handle e => wrapException this_function e
end
(* mk_polymorphic_analogue_definition                                        *)
(*     target name map_thm tautologies limits extras function                *)
(*    Polymorphic variant of mk_analogue_definition. The type variables      *)
(*    reachable from the map functions on the left-hand side of map_thm are  *)
(*    instantiated to target; matching the function constant against the     *)
(*    instantiated map constant determines which type variables of           *)
(*    `function` are instantiated to target before the definition term is    *)
(*    built. The propagation theorem is then proved with                     *)
(*    prove_polymorphic_propagation_theorem, using the instantiated          *)
(*    tautologies, the assumed limits and those extras whose conclusion is   *)
(*    an implication.                                                        *)
(*                                                                           *)
(*    Every failure is labelled with this function's own name. Previously    *)
(*    the final handler stood outside the let, so `this_function` referred   *)
(*    to the enclosing binding "MK_DEFINITION"; moreover the handler written *)
(*    after the case expression guarded only its last branch.                *)
fun mk_polymorphic_analogue_definition
    target name map_thm tautologies limits extras function =
let val this_function = "mk_polymorphic_analogue_definition"
    fun build () =
    let val _ = trace 2 "->mk_polymorphic_analogue_definition\n"
        val _ = trace 1 ("Creating analogue definition of:\n" ^
                          thm_to_string function ^ "\n")
        val map_lhs = lhs o snd o strip_imp o snd o strip_forall o concl
        val maps = snd (strip_comb (map_lhs map_thm))
        val tvs = filter is_vartype (flatten
                  (map (map snd o reachable_graph sub_types o type_of) maps))
        val map_thm1 = INST_TYPE (map (fn x => x |-> target) tvs) map_thm
        (* A mismatch gets a dedicated, more informative message *)
        val match = match_term
                    ((fst o strip_comb o map_lhs) function)
                    ((fst o strip_comb o map_lhs) map_thm1)
                    handle e => raise (mkStandardExn this_function
                     ("The instantiated map theorem does not match function." ^
                      "\nMap theorem uses the constant: " ^
                      term_to_string ((fst o strip_comb o map_lhs) map_thm1) ^
                      "\nwhich does not match the function constant: " ^
                      term_to_string ((fst o strip_comb o map_lhs) function)))
        (* Type variables of the function that the match sends to target *)
        val type_vars = map #redex (filter (curry op= target o #residue)
                            (snd match))
        val (body,new_args,detected_args,specced_limits
                    ,target,result_type,bottom) =
             MK_DEFINITION type_vars target limits function
        val checked_tauts = get_tautologies tautologies specced_limits new_args
        val conditional =
            case (detected_args @ specced_limits)
            of [] => body
            |  guards => mk_cond(list_mk_conj guards,body,
                   construct_bottom_value
                   (fn ty => is_vartype ty orelse ty = target)
                   bottom result_type)
        val right = mk_comb(gen_encode_function target result_type,conditional)
        (* Every new argument has type target, as does the result *)
        val new_fconst = mk_var(name,foldl op--> target (map type_of new_args))
        val function_term = mk_eq(list_mk_comb(new_fconst,new_args),right)
        val taut_limits = map ASSUME specced_limits
        val extra_limits = filter (is_imp o snd o strip_forall o concl) extras
    in
        prove_polymorphic_propagation_theorem
             map_thm
             (checked_tauts @ taut_limits @ extra_limits)
             (list_mk_forall(new_args,function_term))
    end
in
    build () handle e => wrapException this_function e
end
2988end;
2989
2990(*****************************************************************************)
2991(* clause_to_limit : term -> term                                            *)
2992(*                                                                           *)
2993(*    Converts a missing clause to a limit                                   *)
2994(*    clause_to_limit ``f (C a b c) (C d)`` =                                *)
2995(*                       ``\x y. ~((?a b c. x = C a b c) /\ (?d. y = C d))`` *)
2996(*                                                                           *)
2997(*****************************************************************************)
2998
(* Builds, for a missing clause 'f p1 .. pn', the abstraction over fresh     *)
(* arguments stating that the arguments do not all have the shapes p1 .. pn. *)
(* Patterns that are plain variables contribute no conjunct; if every        *)
(* pattern is a variable a standard exception is raised.                     *)
fun clause_to_limit missing =
let val (_,patterns) = strip_comb missing
    val pattern_vars = free_varsl patterns
    (* Argument names are generated from the position of the pattern and    *)
    (* kept distinct from the variables occurring in the patterns.          *)
    fun fresh_arg (index,pattern) =
        variant pattern_vars
                (mk_var(implode (base26 index),type_of pattern))
    val args = map fresh_arg (enumerate 0 patterns)
    fun shape_clause (arg,pattern) =
        if is_var pattern
           then raise Empty
           else list_mk_exists(free_vars_lr pattern,mk_eq(arg,pattern))
    val clauses = mapfilter shape_clause (zip args patterns)
in
    if null clauses
       then raise (mkStandardExn "clause_to_limit"
                   ("Missing clause: " ^ term_to_string missing ^
                    " has no constructors."))
       else list_mk_abs(args,mk_neg(list_mk_conj clauses))
end;
3016
3017(*****************************************************************************)
(* limit_to_theorems : hol_type -> term -> thm list                          *)
(*                                                                           *)
(*    Calculates the nested theorems required for a missing clause limit:    *)
(*    limit_to_theorems ``SEG 0 (SUC n) (a::b::c)`` =                        *)
3022(*    |- (?x l. c' = x::l) /\ (?b c. TL c' = b::c) ==> ?a b c. c' = a::b::c  *)
3023(*                                                                           *)
3024(*****************************************************************************)
3025
(* Strips the abstraction and the negation from a limit term and collects   *)
(* the results of nested_constructor_theorem on each conjunct, skipping the *)
(* conjuncts for which it fails.                                            *)
fun limit_to_theorems target term =
let val nested_theorem = nested_constructor_theorem target
    val (_,negated) = strip_abs term
    val conjuncts = strip_conj (dest_neg negated)
in
    mapfilter nested_theorem conjuncts
end handle e => wrapException "limit_to_theorems" e
3030
3031(*****************************************************************************)
3032(* group_function_clauses : thm -> thm list                                  *)
3033(*                                                                           *)
3034(*    Groups a set of mutually recursive equations by function symbol        *)
3035(*                                                                           *)
3036(*****************************************************************************)
3037
(* Splits the theorem into conjuncts, keys each conjunct by the head of the *)
(* left-hand side of its equation and conjoins every bucket again.          *)
fun group_function_clauses function =
let val head_constant = fst o strip_comb o lhs o snd o strip_imp o
                        snd o strip_forall o concl
    fun keyed clause = (head_constant clause,clause)
    val buckets = bucket_alist (map keyed (CONJUNCTS function))
in
    map (LIST_CONJ o snd) buckets
end handle e => wrapException "group_function_clauses" e;
3044
3045(*****************************************************************************)
3046(* define_analogue    : string -> thm -> thm * thm                           *)
3047(* complete_analogues : thm list -> thm list -> thm list -> thm list -> thm  *)
3048(*                                                                           *)
3049(*     define_analogue  name  [D] |- P a b ==>                               *)
3050(*                              (encode (f a b) = <f> (encode a) (encode b)) *)
3051(*     Defines the definition term D with the name given and removes it from *)
3052(*     the theorem to return:                                                *)
3053(*           |- D         and                                                *)
3054(*           |- P a b ==> (encode (f a b) = <f> (encode a) (encode b))       *)
3055(*                                                                           *)
3056(*     complete_analogue extras |- D   |- P a b ==> ...    defn              *)
3057(*     Inserts the propagation theorem into the rewrite set, then calls      *)
3058(*     PROPAGATE_ENCODERS_CONV after rewriting using the definition          *)
3059(*                                                                           *)
3060(*****************************************************************************)
3061
(* Expects a theorem with exactly one hypothesis, the definition term.      *)
(* Makes the definition under the supplied name and returns it together     *)
(* with the theorem obtained by discharging the hypothesis against it.      *)
fun define_analogue name thm =
let val _ = trace 2 "->define_analogue\n"
    val definition =
        case dest_thm thm
        of ([D],_) => D
        |  _ => raise (mkStandardExn "define_analogue"
              ("Theorem should be of the form [Definition] |- Rewrite:" ^
               "\nHowever, the theorem supplied is not: " ^ thm_to_string thm))
    val defined = new_definition (name,definition)
                  handle e => wrapException "define_analogue" e
    val propagation = MATCH_MP (DISCH_ALL thm) defined
                      handle e => wrapException "define_analogue" e
in
    (defined,propagation)
end;
3076
local
(* MAYBEIF_RWR_CONV thm term:                                                *)
(*    If term is not a conditional it is rewritten with thm using REWR_CONV. *)
(*    Otherwise thm is matched against the 'then' branch only; when the      *)
(*    instantiated theorem is an implication its antecedent is rewritten     *)
(*    under the assumption of the guard. The conditional is then rebuilt     *)
(*    with COND_CONG, leaving the guard and the 'else' branch unchanged.     *)
fun MAYBEIF_RWR_CONV thm term =
    if is_cond term then
    let val (guard,then_term,else_term) = dest_cond term
        val matched = PART_MATCH (lhs o snd o strip_imp) thm then_term
                      handle e => wrapException "MAYBEIF_RWR_CONV" e
        val else_thm = DISCH (mk_neg guard) (REFL else_term)
        val guard_thm = REFL guard
        val then_thm =
            DISCH guard
              (if is_imp_only (concl matched)
                  then CONV_RULE
                         (LAND_CONV (REWRITE_CONV [ASSUME guard]) THENC
                          REWR_CONV (hd (CONJUNCTS (SPEC_ALL IMP_CLAUSES))))
                         matched
                  else matched)
            handle e =>
              raise (mkStandardExn "MAYBEIF_RWR_CONV"
                    ("Unable to prove the rewrite term: " ^
                     term_to_string (fst (dest_imp (concl matched))) ^
                     "\nFrom the condition: " ^ term_to_string guard))
    in
        MATCH_MP COND_CONG (LIST_CONJ [guard_thm,then_thm,else_thm])
        handle e => wrapException "MAYBEIF_RWR_CONV" e
    end
    else (REWR_CONV thm term handle e => wrapException "MAYBEIF_RWR_CONV" e)
in
(* complete_analogues extras rwrs functions definitions:                     *)
(*    Registers every propagation theorem in rwrs as a standard rewrite and  *)
(*    saves it, rewrites each analogue function with the corresponding       *)
(*    original definition, and finally propagates the encoders through the   *)
(*    right-hand sides. The results are returned as a single conjunction.    *)
fun complete_analogues extras rwrs functions definitions =
let val this_function = "complete_analogues"
    val _ = trace 2 "->complete_analogue\n"
    (* Name of the constant f in |- ... ==> (encode (f a ..) = ...) *)
    val encoded_name = fst o dest_const o repeat rator o
                       rand o lhs o snd o strip_imp o snd o strip_forall o
                       concl
    (* Name of the head constant on the right-hand side of a rewrite *)
    val analogue_name = fst o dest_const o fst o strip_comb o rhs o
                        snd o strip_imp_only o snd o strip_forall o
                        concl
    val names = with_exn (map encoded_name)
                        rwrs
                        (mkStandardExn this_function
                        "Rewrite is not of the form: |- !a... encode (f a..) =")
    val _ = map2 (fn name => fn rwr =>
                     add_standard_rewrite 0 ("PROP-" ^ name) rwr)
                 names rwrs
            handle e => wrapException this_function e
    val _ = map (fn rwr => save_thm("prop_" ^ analogue_name rwr,rwr)) rwrs
            handle e => wrapException this_function e
    fun unfold_original func defn =
        CONV_RULE (STRIP_QUANT_CONV (RAND_CONV (RAND_CONV
                  (MAYBEIF_RWR_CONV defn)))) func
    val rewritten = map2 unfold_original functions definitions
                    handle e => wrapException this_function e
    val propagate = RIGHT_CONV_RULE (PROPAGATE_ENCODERS_CONV ([],extras)) o
                    SPEC_ALL
in
    LIST_CONJ (map propagate rewritten)
    handle e => wrapException this_function e
end
end
3130
3131(*****************************************************************************)
3132(* convert_definition :                                                      *)
3133(*        hol_type -> (term * string) list -> (term * term list) list        *)
3134(*                                               -> thm list -> thm -> thm   *)
3135(*                                                                           *)
3136(*    Usage: convert_definition target_type                                  *)
3137(*                              [name map]                                   *)
3138(*                              [limit terms]                                *)
3139(*                              [extra theorems]                             *)
3140(*                              definition                                   *)
3141(*        The name map maps function clauses to new names                    *)
3142(*        Limits map function constants (from the original definition) to    *)
3143(*        limits applied to these functions, eg:                             *)
3144(*            [(``SEG``,``\a b c. a + b <= LENGTH c``)]                      *)
3145(*            [(``LOG``,``\a. 0 < a ==> 0 < a DIV 2``)]                      *)
3146(*        -- The predicate is applied to the arguments in order, as such,    *)
3147(*           the abstraction must exactly match the function arguments.      *)
3148(*        Extra theorems are theorems supplied to assist rewriting, examples *)
3149(*        of such functions are:                                             *)
3150(*           |- ~(x = 0) ==> (?d. x = SUC d)                                 *)
3151(*           |- a + b <= LENGTH c ==> ~((a = 0) /\ (c = []))                 *)
3152(*           |- 0 < a ==> 0 < a DIV 2                                        *)
3153(*                                                                           *)
3154(*    Performs the full conversion:                                          *)
3155(*         1) Conversion to case form using 'clause_to_case'                 *)
3156(*         2) Generation of the coding functions using 'encode_type'         *)
3157(*         3) Generation of the analogous definition using                   *)
3158(*            'mk_analogue_definition' and 'define_analogue'                 *)
3159(*         4) Generation of the affirmation theorems using                   *)
3160(*            'mk_affirmation_theorems'                                      *)
3161(*            -- This ensures that theorems are predicated on the presence   *)
(*               of constructors, rather than their absence, e.g.:           *)
3163(*               |- (?a. x = SUC a) ==> (encode (PRE a) = ...)               *)
3164(*         5) Generation of extra limit theorems for nested constructors     *)
3165(*            using 'limit_to_theorems'                                      *)
3166(*            -- Theorems are predicated on '?a b c. x = a::b::c' rather     *)
3167(*               than two clauses using TL                                   *)
3168(*         6) The propagation of the encoder through the definition using    *)
3169(*            'complete_analogue', which uses PROPAGATE_ENCODERS_CONV.       *)
3170(*                                                                           *)
3171(*    If successful this stores the definition in the current theory,        *)
3172(*    converted definitions can be loaded back in using load_definitions     *)
3173(*                                                                           *)
3174(*****************************************************************************)
3175
3176(*****************************************************************************)
3177(* Given a list of functions and their limits, returns the extra theorems    *)
3178(* required for forward-chaining:                                            *)
3179(*     Limit theorems:       |- destructed clause ==> full clause            *)
3180(*     Affirmation theorems: |- ~(C0 a) /\ ~(C1 a) ==> C2 a                  *)
3181(*                                                                           *)
3182(*****************************************************************************)
3183
(* 'list' pairs each function theorem with its missing clauses. The types   *)
(* of the function arguments are collected, coding functions are generated  *)
(* for those that have constructors, and the affirmation theorems of every  *)
(* reachable type are returned followed by the nested-constructor theorems  *)
(* derived from the missing clauses.                                        *)
fun calculate_extra_theorems target list =
let val this_function = "calculate_extra_theorems"
    val clause_body = snd o strip_imp o snd o strip_forall o concl o fst
    val arguments_of = snd o strip_comb o lhs o clause_body
    val arg_types =
        flatten (map (mapfilter type_of o arguments_of) list)
        handle e => wrapException this_function e
    val concrete_arg_types = filter (not o is_vartype) arg_types
    val _ = map (encode_type target o base_type)
                (filter (can constructors_of) concrete_arg_types)
            handle e => wrapException this_function e
    (* Type variables are kept as they are; other types use their base type *)
    fun vbase_type ty = if is_vartype ty then ty else base_type ty
    fun reachable_types ty =
        mk_set (map (vbase_type o fst) (RTC (reachable_graph sub_types ty)))
    val all_types = mk_set (flatten (map reachable_types arg_types))
                    handle e => wrapException this_function e
    val affirmation_theorems =
        flatten (mapfilter mk_affirmation_theorems all_types)
        handle e => wrapException this_function e
    val missing_clauses = flatten (map snd list)
    val clause_limits = mapfilter clause_to_limit missing_clauses
    val nested_theorems = map (limit_to_theorems target) clause_limits
                          handle e => wrapException this_function e
in
    affirmation_theorems @ flatten nested_theorems
end;
3208
local
(* assoc alist a : SOME value of the first entry whose key term matches a   *)
(* (using match_term), or NONE when no key matches.                          *)
fun assoc [] a = NONE
  | assoc ((b,c)::xs) a =
    if can (match_term b) a then SOME c else assoc xs a;

(* Head of the left-hand side of a (quantified, conditional) equation *)
val clause_head = fst o strip_comb o lhs o snd o strip_imp o
                  snd o strip_forall o concl;

(* Shared driver for convert_definition and convert_polymorphic_definition. *)
(*    error                  : name reported in wrapped/raised exceptions   *)
(*    mk_analogue_definition : builds the analogue for one function, called *)
(*                             as  f target name tautologies limits thm     *)
(* The definition is grouped by function symbol and converted to case form; *)
(* each group is given its name from name_map and its limits from limits    *)
(* (no limits when absent); the analogues are defined and completed, and    *)
(* every conjunct of the result is saved as "translated_<constant>".        *)
fun convert_definition_local error mk_analogue_definition
                       target name_map limits extras definition =
let val list = map ((DISCH_ALL ## I) o clause_to_case o UNDISCH_ALL)
                   (group_function_clauses definition)
               handle e => wrapException error e
    val terms = map (clause_head o fst) list
                handle e => wrapException error e
    (* The failure is reported under 'error', like every other failure in   *)
    (* this function, so that the polymorphic entry point is named          *)
    (* correctly.                                                           *)
    val name_list = map (Option.valOf o assoc name_map) terms
        handle e => raise (mkStandardExn error
                    ("No name has been supplied for the function clause: " ^
                     term_to_string (first
                                    (not o can (Option.valOf o assoc name_map))
                                     terms)))
    val limit_list = map ((fn NONE => [] | SOME y => y) o assoc limits) terms
    val rlist = map2 (fn (name,limit) => fn (thm,_) =>
                    mk_analogue_definition target name (map SPEC_ALL extras)
                    limit thm) (zip name_list limit_list) list
                handle e => wrapException error e
    val rrlist = map2 define_analogue name_list rlist
                 handle e => wrapException error e
    val  (functions,rwrs) = unzip rrlist
    val definitions = map fst list
    val extra_theorems = calculate_extra_theorems target list
                       handle e => wrapException error e
    (* Drop implications whose consequent matches their own antecedent;     *)
    (* theorems that are not implications are kept.                         *)
    val extras_filtered =
        (filter (fn x =>
                not (can (match_term (fst (dest_imp (concl x))))
                         (snd (dest_imp (concl x))))
                handle _ => true) (map SPEC_ALL extras));
    val completed = complete_analogues
                    (extras_filtered @ extra_theorems)
                    rwrs functions definitions
                    handle e => wrapException error e
    val _ = trace 1 ("Definition(s) converted: \n")
    val _ = app (fn x => trace 1 (thm_to_string x ^ "\n")) (CONJUNCTS completed)

    val _ = map (fn x =>
                save_thm
                    ("translated_" ^
                    fst (dest_const (fst (strip_comb (lhs
                        (snd (strip_forall (concl x))))))),x))
            (CONJUNCTS completed)
            handle e => wrapException error e
in  completed
end
in
fun convert_definition target name_map limits extras definition =
    convert_definition_local "convert_definition" mk_analogue_definition
                             target name_map
                             limits extras definition
(* As convert_definition, except that a function whose head matches a key   *)
(* of map_thms is converted with mk_polymorphic_analogue_definition using   *)
(* the associated map theorem.                                              *)
fun convert_polymorphic_definition
    target name_map limits map_thms extras definition =
let fun mad t s l1 l2 function =
    case (assoc map_thms (clause_head function))
    of SOME map_thm =>
       mk_polymorphic_analogue_definition t s map_thm l1 l2 extras function
    |  NONE => mk_analogue_definition t s l1 l2 function
in
    convert_definition_local "convert_polymorphic_definition"
                             mad target name_map limits extras definition
end
end;
3277
3278(*****************************************************************************)
3279(* load_definitions : string -> thm list                                     *)
3280(*     Loads function definitions in from the theory given.                  *)
3281(*                                                                           *)
3282(*     Definitions are always stored with a corresponding propagation        *)
3283(*     theorem. We can therefore reload by ensuring that we have both.       *)
3284(*                                                                           *)
3285(*****************************************************************************)
3286
(* Looks up the "translated_" and "prop_" theorems stored for a constant in *)
(* the given theory and returns them together with the constant's name.     *)
fun get_definition theory constant =
let val (name,_) = dest_const constant
    val stored = DB.theorems theory
    val definition = assoc ("translated_" ^ name) stored
    val theorem = assoc ("prop_" ^ name) stored
in
    (name,(definition,theorem))
end;
3294
(* Collects the stored definition/propagation pairs for the constants of    *)
(* the theory (constants without both are skipped), re-registers each       *)
(* propagation theorem as a standard rewrite and returns the definitions.   *)
fun load_definitions theory =
let val loaded = mapfilter (get_definition theory) (Theory.constants theory)
    fun register (name,(_,theorem)) =
        add_standard_rewrite 0 ("PROP-" ^ name) theorem
    val _ = map register loaded
in
    map (fn (_,(definition,_)) => definition) loaded
end;
3303
3304(*****************************************************************************)
3305(* encode_until : (term -> bool) list -> thm list * thm list -> term ->      *)
3306(*                                                     term list list * thm  *)
3307(*     Performs encoding until one of the function terminals is reached.     *)
3308(*     Returns the encoded term, along with a list of terminal terms for     *)
3309(*     each terminal functions used.                                         *)
3310(*                                                                           *)
3311(*****************************************************************************)
3312
(* One extended terminal is installed per predicate in funcs. Whenever a    *)
(* predicate accepts a term, the term (paired with the theorem list handed  *)
(* to the terminal) is recorded in that predicate's list. The terminals are *)
(* removed again before returning, and also when the conversion fails.      *)
fun encode_until funcs AE term =
let val _ = trace 2 "->encode_until\n"
    fun terminal_name (index,_) =
        "encode_until_terminal_" ^ implode (base26 index)
    val terminals = map terminal_name (enumerate 0 funcs)
    val lists = map (fn _ => ref []) funcs
    fun remove () = app remove_terminal terminals
    val _ = remove ()
    (* n is the position of func in funcs, counted from 1 as 'el' expects *)
    fun mk_terminal (n,func) thm_list x =
        func x andalso
        (el n lists := (thm_list,x) :: !(el n lists) ; true)
    val _ = map2 (curry add_extended_terminal) terminals
                 (map mk_terminal (enumerate 1 funcs))
    val result = PROPAGATE_ENCODERS_CONV AE term
                 handle e => (remove () ; wrapException "encode_until" e)
    val _ = remove ()
in
    (map (fn cell => !cell) lists,result)
end;
3334
(* Runs encode_until with a single terminal predicate that never accepts:   *)
(* it reads one character from standard input, prints the term it was       *)
(* offered and returns false.                                               *)
fun step_PROPAGATE_ENCODERS_CONV AE term =
let fun pause_and_show x =
        (TextIO.input1 TextIO.stdIn ; print_term x ; print "\n" ; false)
    val (_,result) = encode_until [pause_and_show] AE term
in
    result
end;
3340
local
(* The encoder for the type of x, applied to x *)
fun mk_encoder target x = mk_comb(get_encode_function target (type_of x),x);
(* Applies the decoder for the encoded value to both sides of |- enc v = y *)
fun ap_decoder target x =
    AP_TERM (get_decode_function target (type_of (rand (lhs (concl x))))) x;
(* Rewrites the left-hand side with the encode/decode theorem of the type *)
fun left_encdec target x =
    CONV_RULE (LAND_CONV (REWR_CONV (FULL_ENCODE_DECODE_THM target
              (type_of (rand (rand (lhs (concl x)))))))) x;
(* Folds MK_COMB over L from the left, starting from the theorem a *)
fun LIST_MK_COMB (a,L) = foldl (uncurry (C (curry MK_COMB))) a L;
(* For an argument, assumes 'encode arg = v' for a fresh v and derives an   *)
(* equation for arg from it; if any step fails the argument is left alone.  *)
fun mk_pres target arg =
let val p = mk_eq(mk_encoder target arg,genvar target)
in
    ((left_encdec target o ap_decoder target o ASSUME) p,SOME p)
    handle _ => (REFL arg,NONE)
end
(* Builds the conditional rewrite for one term of the form encode (f args) *)
fun prewrite term =
let val (f,args) = strip_comb (rand term)
    val target = type_of term
    val (rwrs,pres) = unzip (map (mk_pres target) args)
    val result = DISCH_LIST_CONJ (mapfilter Option.valOf pres)
                 (AP_TERM (rator term) (LIST_MK_COMB(REFL f,rwrs)))
in
    DISCH T result
end
in
(* As encode_until, but with a conditional rewrite installed for every term *)
(* in fterms while the encoding runs. The rewrites are removed again on     *)
(* success and on failure.                                                  *)
fun encode_until_recursive funcs AE fterms term =
let val new_rewrites = map prewrite fterms
        handle e => wrapException "encode_until_recursive" e
    fun rewrite_name index =
        "encode_until_recursive_rwr_" ^ int_to_string index
    (* Returns unit explicitly: the standard 'before' requires a unit       *)
    (* right-hand operand.                                                  *)
    fun remove() =
        (map (fn (a,_) => remove_rewrite (rewrite_name a))
             (enumerate 0 new_rewrites) ; ())
    val _ = remove()
    fun wrap e = (remove () ; wrapException "encode_until_recursive" e)
    val _ = map (fn (a,b) =>
        add_conditional_rewrite 100 (rewrite_name a) b)
        (enumerate 0 new_rewrites) handle e => wrap e
in
    ((encode_until funcs AE term handle e => wrap e)
    before remove())
end
end;
3383
3384(*****************************************************************************)
(* get_all_detect_types : hol_type -> hol_type -> hol_type list              *)
3386(*     Returns the list of all types in recursion, or nested recursion,      *)
3387(*     with the type t.                                                      *)
3388(*                                                                           *)
3389(*****************************************************************************)
3390
(* Starting from the most precise type that has a "detect" coding function, *)
(* follows constructor argument types (or, failing that, type arguments)    *)
(* and returns, instantiated for t, every type from which the base type is  *)
(* reachable, the base type itself included.                                *)
fun get_all_detect_types target t =
let val has_detector = C (exists_coding_function_precise target) "detect"
    val basetype = most_precise_type has_detector t
    fun st ty =
        flatten (map (fst o strip_fun o type_of) (constructors_of ty))
        handle _ => (snd (dest_type ty) handle _ => [])
    val alltypes = (basetype,basetype)::TC (reachable_graph st basetype)
    val match = match_type basetype t
    val reaching = filter (fn (_,reached) => basetype = reached) alltypes
in
    mk_set (map (type_subst match o fst) reaching)
end;
3402
(* True when t is, after instantiation, one of the types that can reach     *)
(* itself in the transitive closure of the constructor-argument graph of    *)
(* its most precise "detect" type.                                          *)
fun is_recursive_detect_type target t =
let val has_detector = C (exists_coding_function_precise target) "detect"
    val basetype = most_precise_type has_detector t
    fun st ty =
        flatten (map (fst o strip_fun o type_of) (constructors_of ty))
        handle _ => (snd (dest_type ty) handle _ => [])
    val alltypes = TC (reachable_graph st basetype)
    val match = match_type basetype t
    val self_reaching = filter (fn (from,reached) => from = reached) alltypes
in
    mem t (map (type_subst match o fst) self_reaching)
end;
3413
3414(*****************************************************************************)
(* flatten_recognizers :                                                     *)
(*     (hol_type -> string) -> hol_type -> hol_type -> thm list              *)
3416(*****************************************************************************)
3417
(* Name for a type: "ANY" for type variables and the target type, otherwise *)
(* the type operator followed by the names of its arguments.                *)
fun mk_fullname target t =
    if is_vartype t orelse t = target then "ANY"
    else
        let val (operator,arguments) = dest_type t
        in  String.concat (operator :: map (mk_fullname target) arguments)
        end
3421
(* Reverses the equation and wraps the operand of its right-hand side in    *)
(* the identity combinator I.                                               *)
fun SET_CODER thm =
let val wrap_in_I = RIGHT_CONV_RULE
                      (RAND_CONV (REWR_CONV (GSYM combinTheory.I_THM)))
    val reversed = SPEC_ALL (GSYM thm)
in
    wrap_in_I reversed
end handle e => wrapException "SET_CODER" e
3426
(* For every type in full_types builds the defining equation                *)
(*     <namef t> x = encode_bool (detect_t x)                               *)
(* where free variables of the detector are replaced by K T.                *)
fun generate_recognizer_terms namef target full_types =
let val _ = map (encode_type target)
                (filter (not o is_vartype) (map base_type full_types))
    val KT = mk_comb(mk_const("K",bool --> target --> bool),
                     mk_const("T",bool))
    val detectors = map (get_detect_function target) full_types
    val to_KT = map (fn v => v |-> KT) (free_varsl detectors)
    val closed_detectors = map (subst to_KT) detectors
    val var = mk_var("x",target)
    val bool_enc = get_encode_function target bool
    fun recognizer_equation ty detector =
        mk_eq(mk_comb(mk_var(namef ty,target --> target),var),
              mk_comb(bool_enc,mk_comb(detector,var)))
in
    map2 recognizer_equation full_types closed_detectors
end;
3444
(* Defines a recognition function for every type returned by                *)
(* get_all_detect_types, registers and saves the corresponding propagation  *)
(* theorems ("prop_<name>"), then unfolds the detector definitions,         *)
(* propagates the encoders and saves the results ("translated_<name>").     *)
fun flatten_recognizers namef target t =
let val _ = trace 2 "->flatten_recognizers\n"
    val _ = scrub_rewrites()
    val full_types = get_all_detect_types target t
    val funcs = generate_recognizer_terms namef target full_types
    val _ = trace 1 "Creating new recognition functions:\n"
    val _ = app (fn func => trace 1 (term_to_string func ^ "\n")) funcs
    val defns = map2 (fn ty => fn func => new_definition (namef ty,func))
                     full_types funcs
    val _ = map2 (fn ty => fn defn =>
                     add_standard_rewrite 1 (namef ty) (SET_CODER defn))
                 full_types defns
    val _ = map2 (fn ty => fn defn =>
                     save_thm("prop_" ^ (namef ty),SET_CODER defn))
                 full_types defns
    val all_detectors =
        map (fn ty => get_coding_function_def target ty "detect") full_types
    (* Unfold the detector definition underneath the boolean encoder *)
    val rewrites =
        map2 (fn detector => fn defn =>
                 RIGHT_CONV_RULE (RAND_CONV (REWR_CONV detector))
                                 (SPEC_ALL defn))
             all_detectors defns
    val general_detects =
        mapfilter (generate_coding_theorem target "general_detect" o base_type)
                  full_types
    val propagate =
        RIGHT_CONV_RULE (PROPAGATE_ENCODERS_CONV ([],general_detects))
    val finished = map propagate rewrites
    val _ = map2 (fn ty => fn thm => save_thm("translated_" ^ (namef ty),thm))
                 full_types finished
in
   finished
end handle e => wrapException "flatten_recognizers" e
3476
(* Finds the type whose "detect" constant is matched by the term. The       *)
(* target type is taken to be the last argument type of the term.           *)
fun detect_const_type term =
let val target = last (fst (strip_fun (type_of term)))
    fun detector_entry ty =
        (get_coding_function_const target ty "detect",ty)
    val const_map =
        (get_detect_function target target,target) ::
        mapfilter detector_entry (get_translation_types target)
    val (_,detected) = first (can (match_term term) o fst) const_map
in
   detected
end;
3486
(* Type recognised by a detector term. When the term is not a known         *)
(* detector constant but is an application, the type is rebuilt from the    *)
(* type operator of its head and the types of its arguments.                *)
fun get_detect_type term =
let fun from_application () =
        if not (is_comb term) then raise Empty
        else
            let val (operator,operands) = strip_comb term
                val ft = get_detect_type operator
                val ts = map get_detect_type operands
            in  mk_type(fst (dest_type ft),ts)
            end
in
    detect_const_type term handle _ => from_application ()
end;
3496
(* Builds a conditional rewrite for the detector of t from its definition:  *)
(* the detected variable is assumed to equal a fresh target variable (via   *)
(* I), the boolean encoder is applied, and the encoded right-hand side is   *)
(* assumed to equal another fresh variable. Both assumptions are discharged *)
(* as a conjunction and the result is discharged on T.                      *)
fun recognizer_rewrite target t =
let val detector = SPEC_ALL (get_coding_function_def target t "detect")
    val var = (rand o lhs o concl) detector
    val prior = mk_eq(imk_comb(mk_const("I",alpha --> alpha),var),genvar target)
    val rwr = REWRITE_RULE [combinTheory.I_THM] (ASSUME prior)
    val fold_back = CONV_RULE (LAND_CONV (RAND_CONV (REWR_CONV (GSYM rwr))))
    val instantiated = fold_back (INST [var |-> (rhs (concl rwr))] detector)
    val encoded = AP_TERM (get_encode_function target bool) instantiated
    val rrwr = ASSUME (mk_eq(rhs (concl encoded),genvar target))
    val renamed = RIGHT_CONV_RULE (REWR_CONV rrwr) encoded
    val discharged = DISCH prior (DISCH (concl rrwr) renamed)
in
    DISCH T (CONV_RULE (REWR_CONV AND_IMP_INTRO) discharged)
end handle e => wrapException "recognizer_rewrite" e
3511
(* Checks that the term has the shape encode_bool (detect_t x). For a       *)
(* non-recursive detect type a rewrite is built directly; for a recursive   *)
(* one the recognizers are flattened first and the rewrite is taken from    *)
(* the first match returned by return_matches. Any failure is reported as   *)
(* "Not an encoder recognition predicate".                                  *)
fun polytypic_recognizer term =
let val target = type_of term
    val t = get_detect_type (rator (rand term))
    val detector = get_detect_function target t
    val encoder = get_encode_function target bool
    val var = mk_var("x",target)
    val expected = mk_comb(encoder,mk_comb(detector,var))
    val _ = match_term expected term
in
    if not (is_recursive_detect_type target t)
       then recognizer_rewrite target t
       else (flatten_recognizers (mk_fullname target) target t ;
             DISCH T (snd (hd (snd (return_matches [] term)))))
end handle e => raise (mkStandardExn "polytypic_recognizer"
           "Not an encoder recognition predicate");
3526
3527(*****************************************************************************)
3528(* Like subst, except it acts like REWR_CONV, not SUBST_CONV                 *)
3529(*****************************************************************************)
3530
(* Substitutes at the top of the term, then recurses into the result: first *)
(* as an application, failing that as an abstraction, failing that the      *)
(* substituted term is returned as it is.                                   *)
fun exact_subst tmap term =
let val replaced = subst tmap term
    fun as_application () =
        mk_comb(exact_subst tmap (rator replaced),
                exact_subst tmap (rand replaced))
    fun as_abstraction () =
        mk_abs(bvar replaced,exact_subst tmap (body replaced))
in
    as_application ()
    handle _ => (as_abstraction () handle _ => replaced)
end handle e => wrapException "exact_subst" e
3538
(* Finds every subterm matched by a redex of tmap, maps each such subterm   *)
(* to the residue of the first matching entry (instantiated by the match)   *)
(* and applies the resulting map with exact_subst.                          *)
fun subst_all tmap term =
let fun matched_by occurrence {redex,residue} =
        can (match_term redex) occurrence
    val all_terms =
        find_terms (fn occurrence => exists (matched_by occurrence) tmap) term
    fun instantiate occurrence {redex,residue} =
        occurrence |-> subst (fst (match_term redex occurrence)) residue
    val new_map =
        map (fn occurrence => tryfind (instantiate occurrence) tmap) all_terms
 in exact_subst new_map term
 end handle e => wrapException "subst_all" e
3549
3550(*****************************************************************************)
3551(* Takes the induction theorem from a translation scheme, and proves a       *)
3552(* corresponding relation well-founded:                                      *)
3553(*      |- (!x. isPair x /\ (left x) /\ (right x) ==> P x) /\                *)
3554(*         (!x. ~(isPair x) ==> P x) ==>                                     *)
3555(*              !x. P x                                                      *)
3556(*                      ===>                                                 *)
3557(*      |- WF (\y x. isPair x /\ ((y = left x) \/ (y = right x)))            *)
3558(*                                                                           *)
3559(*****************************************************************************)
3560
fun get_wf_relation target =
let fun wrap e = wrapException "get_wf_relation" e
    val scheme = get_translation_scheme target
                 handle e => wrap e
    val x = mk_var("x",target)
    val y = mk_var("y",target)
    (* \y x. predicate x /\ ((y = left x) \/ (y = right x))                  *)
    val relation =
        let val is_pair = mk_comb(#predicate scheme,x)
            val from_left = mk_eq(y,mk_comb(#left scheme,x))
            val from_right = mk_eq(y,mk_comb(#right scheme,x))
        in  list_mk_abs([y,x],
                        mk_conj(is_pair,mk_disj(from_left,from_right)))
        end handle e => wrap e
    val goal = imk_comb(``WF:('a -> 'a -> bool) -> bool``,relation)
               handle e => wrap e
    (* The tactic is built inside this thunk so that any failure while       *)
    (* constructing it is reported by the same handler as a failed proof.    *)
    fun prove_wf () =
        BETA_RULE (prove(goal,
            REWRITE_TAC [relationTheory.WF_EQ_INDUCTION_THM] THEN
            NTAC 2 STRIP_TAC THEN
            MATCH_MP_TAC (#induction scheme) THEN REPEAT STRIP_TAC THEN
            RULE_ASSUM_TAC BETA_RULE THEN FIRST_ASSUM MATCH_MP_TAC THEN
            REPEAT STRIP_TAC THEN
            FIRST_ASSUM SUBST_ALL_TAC THEN
            ASM_REWRITE_TAC []))
in
    prove_wf ()
    handle e => raise (mkStandardExn "get_wf_relation"
           ("Unable to prove the relation: " ^ term_to_string goal ^
            " well-founded"))
end;
3589
3590(*****************************************************************************)
(* ALLOW_CONV : conv -> conv                                                 *)
(*     Applies a conversion, returning |- t = t if it raises UNCHANGED.      *)
3593(*****************************************************************************)
3594
(* A conversion that signals UNCHANGED yields the reflexivity theorem.       *)
fun ALLOW_CONV conv term =
    conv term
    handle UNCHANGED => REFL term;
3596
3597(*****************************************************************************)
3598(* make_abstract_funcs ...                                                   *)
3599(*****************************************************************************)
3600
(* make_abstract_funcs target abstract_terms funcs input_terms               *)
(*     Returns a pair (props,output_terms):                                  *)
(*       props        : for each element of funcs, an equation between its   *)
(*                      (mrhs) right-hand side and the new function applied  *)
(*                      to the original arguments followed by the            *)
(*                      abstracted terms;                                    *)
(*       output_terms : input_terms rewritten to use the new functions, with *)
(*                      each abstracted term replaced by its ABS variable.   *)
fun make_abstract_funcs target abstract_terms funcs input_terms =
let (* Pair every abstracted term with a variable of the target type, named  *)
    (* ABS followed by the base26 rendering of its index.                    *)
    val var_map = map (fn (i,a) => a |->
                          (mk_var("ABS" ^ implode (base26 i),target)))
                      (enumerate 0 abstract_terms)
    val new_args = map #residue var_map
    fun reverse x = map (fn {redex,residue} => residue |-> redex) x
    (* Right-hand side of a definition term; when its operand is a           *)
    (* conditional only the then-branch is kept, otherwise the rhs is        *)
    (* returned unchanged.                                                   *)
    fun mrhs x = mk_comb(rator (rhs x),
                         ((fn (a,b,c) => b) o dest_cond o rand o rhs) x)
                 handle _ => rhs x
    fun mlhs x = lhs (snd (strip_forall x))
    val new_funcs = map (fn func =>
                     let val (var,args) = strip_comb (mlhs func)
                         val all_args = args @ new_args
                         (* foldr builds t1 --> ... --> tn --> target, the   *)
                         (* order list_mk_comb needs; foldl would nest the   *)
                         (* argument types in reverse.                       *)
                         val func_type =
                             foldr (op-->) target (map type_of all_args)
                     in  list_mk_comb(mk_var(fst (dest_var var),func_type),
                                      all_args)
                     end) funcs
   val func_map1 = map2 (curry op|->) (map mrhs funcs) new_funcs
   val func_map2 = map2 (curry op|->) (map mlhs funcs) new_funcs
in
   (map (fn {redex,residue} => mk_eq(redex,subst (reverse var_map) residue))
        func_map1,
    map (subst var_map o subst func_map2 o subst_all func_map1) input_terms)
end handle e => wrapException "make_abstract_funcs" e
3624
(* create_abstract_recognizers namef f target t                              *)
(*     Generates recognizer terms for the detect types of t, partially       *)
(*     encodes each one (stopping at terms whose operand satisfies f), and   *)
(*     abstracts the terminal terms reached into extra function arguments.   *)
(*     Returns (props,thms,output_terms) where props and output_terms come   *)
(*     from make_abstract_funcs and thms are the partially encoded theorems. *)
fun create_abstract_recognizers namef f target t =
let fun wrap e = wrapException "create_abstract_recognizers" e
    val _ = trace 2 "->create_abstract_recognizers\n";
    val full_types = get_all_detect_types target t
    val funcs = map (snd o strip_forall)
                (generate_recognizer_terms namef target full_types)
                handle e => wrap e
    val general_detects =
        mapfilter (SPEC_ALL o
                   generate_coding_theorem target "general_detect" o base_type)
                  [``:'a list``]
    (* Unfold one detector definition in the operand of the recognizer, then *)
    (* encode the result; returns the terminal terms and the rewritten thm.  *)
    fun rc detector func =
    let val thm = RAND_CONV (ALLOW_CONV (ONCE_REWRITE_CONV [detector]))
                            (rhs func)
        val (terms,r) = encode_until_recursive [f o rand] ([],general_detects)
                        (map rhs funcs) (rhs (concl thm))
    in  (map snd (hd terms),RIGHT_CONV_RULE (ALLOW_CONV (REWR_CONV r)) thm)
    end
    val all_detectors =
        map (C (get_coding_function_def target) "detect") full_types
        handle e => wrap e
    val (terminals,thms) = unzip (map2 rc all_detectors funcs)
                           handle e => wrap e
    (* The handler belongs on this call: it previously guarded only the      *)
    (* final tuple of values, which cannot raise.                            *)
    val (props,output_terms) =
        make_abstract_funcs target (mk_set (flatten terminals))
                            funcs (map concl thms)
        handle e => wrap e
in
    (props,thms,output_terms)
end;
3656
3657(*****************************************************************************)
(* WF_TC_FINISH_TAC : tactic                                                 *)
(*     Proves a goal of the form:                                            *)
(*       A ?- TC Rel (L .. R .. x) x                                         *)
(*     where L and R are the left and right destructors of the translation   *)
(*     scheme for the type of x.                                             *)
3661(*                                                                           *)
3662(*****************************************************************************)
3663
fun WF_TC_FINISH_TAC (a,g) =
let (* The two conjuncts of TC_RULES, in the order the theorem states them.  *)
    val (base_rule,trans_rule) =
        CONJ_PAIR (SPEC_ALL relationTheory.TC_RULES)
    val upper = rand g
    val lower = rand (rator g)
    val scheme = get_translation_scheme (type_of upper)
    (* Intermediate point supplied as the witness when chaining.             *)
    val witness = rand lower
    fun component destructor = beta_conv (mk_comb(destructor,upper))
    val finish =
        if   lower = component (#left scheme) orelse
             lower = component (#right scheme)
        then (* lower is a direct component of upper: first rule suffices.   *)
             MATCH_MP_TAC base_rule THEN pairLib.GEN_BETA_TAC
        else (* otherwise split through the witness and recurse.             *)
             MATCH_MP_TAC trans_rule THEN EXISTS_TAC witness THEN
             CONJ_TAC THEN WF_TC_FINISH_TAC
in
    finish (a,g)
end;
3678
3679(*****************************************************************************)
3680(* mk_summap     : thm -> thm                                                *)
3681(*     |- WF R ----> |- WF (inv_image R summap)                              *)
3682(* mk_sumstart   : thm -> thm                                                *)
3683(*     |- WF R ----> |- WF (inv_image R (\a. (a,())))                        *)
3684(* mk_lex        : thm -> thm -> thm                                         *)
3685(*     |- WF R1, |- WF R2 ---> |- WF (R1 LEX R2)                             *)
3686(* mk_nested_rel : term list -> thm                                          *)
3687(*     [R (INL _) (INR _), R (INL (INR a)) (INR (INL x))...]                 *)
3688(*     ---> |- WF (\a b. (?a'. a = INL a') /\ (?b'. b = INR b') /\ ... )     *)
3689(*                                                                           *)
3690(*     Used to relate terms:                                                 *)
3691(*           R (INL (a,_)) (INR (b,_)) as:                                   *)
3692(*           a <_sexp b \/ (a = b) /\ (l = INL /\ r = INR)                   *)
3693(*                                                                           *)
3694(*****************************************************************************)
3695
local
open sumSyntax pairSyntax
val sum_ty = ``:'a + 'b``;
val sum_case = TypeBase.case_const_of sum_ty
val WF_inv = Q.SPEC `R` (relationTheory.WF_inv_image);
(* The sum type x + x + ... + x with n summands.                             *)
fun sumt n x = list_mk_sum (for 1 n (K x))
(* The term b wrapped in injections, one result per nesting depth.           *)
fun injections 0 t b = [b]
  | injections n t b =
    mk_inl(b,sumt n t)
    :: map (fn tm => mk_inr(tm,t)) (injections (n - 1) t b);
(* The map \(a,b). sum_case ... a, which pairs the content of the sum a     *)
(* with an injection of b that records which summand a was in.               *)
fun mk_sum_tm n =
let val sum_arg = mk_var("a",sumt n alpha)
    val tag_arg = mk_var("b",beta)
    val elem = mk_var("a",alpha)
    fun chain f xs = foldr f (last xs) (butlast xs) handle _ => (hd xs)
    fun nest (a,b) = list_imk_comb(sum_case,[a,b])
    fun tag_with tag = mk_abs(elem,mk_pair(elem,tag))
    val branches = map tag_with (injections (n - 1) beta tag_arg)
in
    mk_pabs(mk_pair(sum_arg,tag_arg),
        mk_comb(chain nest branches,sum_arg))
end;
fun WF_sum arity =
    Q.GEN `R` (ISPEC (mk_sum_tm arity) WF_inv);
val WF_ssum = Q.GEN `R` (ISPEC ``(\(a:'a). (a,()))`` WF_inv);
in
(* |- WF R  --->  |- WF (inv_image R summap), for a sum of n summands.       *)
fun mk_summap n WF_R =
let val rel_ty = type_of (rand (concl WF_R))
    val tag_ty = snd (dest_prod (hd (fst (strip_fun rel_ty))))
    val lifting = WF_sum n
    val theta = match_type tag_ty (sumt n (gen_tyvar()))
in
    MATCH_MP lifting (INST_TYPE theta WF_R)
end handle e => wrapException "mk_summap" e
(* |- WF R  --->  |- WF (inv_image R (\a. (a,())))                           *)
fun mk_sumstart WF_R =
let val rel_ty = type_of (rand (concl WF_R))
    val tag_ty = snd (pairSyntax.dest_prod (hd (fst (strip_fun rel_ty))))
in
    MATCH_MP WF_ssum (INST_TYPE (match_type tag_ty ``:unit``) WF_R)
end handle e => wrapException "mk_sumstart" e
end;
3735
local
(* All type variables of a theorem: hypotheses and conclusion.               *)
fun all_tyvars thm =
let val from_hyps = hyp_tyvars thm
    val from_concl = type_vars_in_term (concl thm)
in
    HOLset.listItems (HOLset.addList (from_hyps,from_concl))
end
(* Replace every type variable of the theorem by a newly generated one.      *)
fun freshen thm =
    INST_TYPE (map (fn v => v |-> gen_tyvar()) (all_tyvars thm)) thm
in
(* |- WF R1, |- WF R2  --->  |- WF (R1 LEX R2)                               *)
(* Both theorems are renamed apart before being combined.                    *)
fun mk_lex WFRa WFRb =
let val fresh_a = freshen WFRa
    val fresh_b = freshen WFRb
in
    MATCH_MP pairTheory.WF_LEX (CONJ fresh_a fresh_b)
end handle e => wrapException "mk_lex" e
end;
3751
local
open sumSyntax
(* match_dub f tm: instantiates tm with the type substitution f d1 d2, where *)
(* d1 is the domain of tm's type and d2 is the domain of that type's range   *)
(* (the types of the first and second arguments of tm).                      *)
fun match_dub f tm =
    inst (f (fst (dom_rng (type_of tm)))
         ((fst o dom_rng o snd o dom_rng o type_of) tm)) tm;
(* True when the first and second argument types of tm are equal.            *)
fun is_dub tm =
    (fst o dom_rng o type_of) tm = (fst o dom_rng o snd o dom_rng o type_of) tm;
(* mk_wf_nested_case term: term has the shape R a b (R itself is not used).  *)
(* Builds \a' b'. (?a. a' = A) /\ (?b. b' = B), where A and B reproduce the  *)
(* INL/INR nesting of a and b around the variables a and b, and instantiates *)
(* the result until both arguments have the same type.                       *)
fun mk_wf_nested_case term =
let val ((R,a),b) = (dest_comb ## I) (dest_comb term);
    val avar = mk_var("a",alpha)
    val bvar = mk_var("b",alpha)
    (* Copy the injections wrapped around x onto y; the other summand of     *)
    (* each injection is given a newly generated type variable.              *)
    fun strip x y =
        if can (match_term (mk_inl (avar,beta))) x
           then mk_inl (strip (rand x) y ,gen_tyvar())
           else if can (match_term (mk_inr (avar,beta))) x
                   then mk_inr (strip (rand x) y ,gen_tyvar())
                   else y;
    val a' = strip a avar
    val b' = strip b bvar
    val avar' = mk_var("a'",type_of a')
    val bvar' = mk_var("b'",type_of b')
    val a'' = mk_exists(avar,mk_eq(avar',a'));
    val b'' = mk_exists(bvar,mk_eq(bvar',b'));
    (* Repeatedly match one argument type against the other (trying both     *)
    (* directions) until is_dub holds; the handler covers the recursive call *)
    (* and reports failure with the standard exception below.                *)
    fun rpt x = if is_dub x then x
                   else rpt (match_dub match_type x handle _ =>
                             match_dub (C match_type) x) handle _ =>
raise(mkStandardExn "mk_wf_nested_case"
     ("Could not instantiate type of: " ^ term_to_string x ^
      "\nto be of the form: 'a -> 'a -> bool"))
in
    rpt (list_mk_abs([avar',bvar'],list_mk_conj [a'',b'']))
end
(* make_all_terms terms: the disjunction of the nested-case relations for    *)
(* all the terms; the empty list gives REMPTY.                               *)
fun make_all_terms [] = ``REMPTY:'a -> 'a -> bool``
  | make_all_terms [x] = mk_wf_nested_case x
  | make_all_terms (x::xs) =
let val r1 = mk_wf_nested_case x
    val r2 = make_all_terms xs
    val nty_var1 = gen_tyvar()
    val nty_var2 = gen_tyvar()
    val v1 = fst (dest_abs r1)
    val v2 = fst (dest_abs r2)
    (* Number of summands in the last product component of each variable.    *)
    val l1 = length (strip_sum (last (pairSyntax.strip_prod (type_of v1))))
    val l2 = length (strip_sum (last (pairSyntax.strip_prod (type_of v2))))
    val l = Int.max(l1,l2)
    (* Common argument type: the leading product components of v1 followed   *)
    (* by a sum of l - 1 copies of nty_var1 and a final nty_var2.            *)
    val match = pairSyntax.list_mk_prod(
                        (butlast (pairSyntax.strip_prod (type_of v1))) @
                        [list_mk_sum (rev (nty_var2 ::
                                     for 1 (l - 1) (K nty_var1)))])
    val r1i = inst (match_type (type_of v1) match) r1
    val r2i = inst (match_type (type_of v2) match) r2
    val vars = fst (strip_abs r1i)
in
    list_mk_abs(vars,mk_disj(list_imk_comb(r1i,vars),list_mk_comb(r2i,vars)))
end handle e => wrapException "(make_all_terms)" e
(* Proves the given term with a fixed tactic and beta-reduces the theorem.   *)
fun prove_nested_case term =
    BETA_RULE ((prove(term,
    pairLib.GEN_BETA_TAC THEN
    REWRITE_TAC [relationTheory.WF_EMPTY_REL,
                 relationTheory.WF_EQ_INDUCTION_THM] THEN
    REPEAT STRIP_TAC THEN FIRST_ASSUM MATCH_MP_TAC THEN
    REPEAT (POP_ASSUM MP_TAC) THEN pairLib.GEN_BETA_TAC THEN
    REPEAT STRIP_TAC THEN
    FIRST_ASSUM MATCH_MP_TAC THEN
    METIS_TAC [TypeBase.nchotomy_of ``:'a + 'b``,
              sumTheory.sum_distinct,TypeBase.one_one_of ``:'a + 'b``])))
              handle e => wrapException "prove_nested_case" e
in
(* mk_nested_rel terms: proves WF of the relation built by make_all_terms.   *)
fun mk_nested_rel terms =
    prove_nested_case(imk_comb(``WF:('a -> 'a -> bool) -> bool``,
                                        make_all_terms terms))
    handle e => wrapException "mk_nested_rel" e
end;
3824
3825(*****************************************************************************)
3826(* WF_RECOGNIZER_TAC : tactic                                                *)
3827(*     Solves goals of the form: ?- ?R. WF R /\ ... where                    *)
3828(*     R : target # 'a + target # 'a + .... -> ... -> bool    or             *)
3829(*     R : target + target + ... -> ... -> bool                              *)
3830(*                                                                           *)
3831(*****************************************************************************)
3832
fun WF_RECOGNIZER_TAC (a,g) =
let (* The goal is pushed onto the proof manager and dropped again only if   *)
    (* the tactic below succeeds.                                            *)
    val _ = proofManagerLib.set_goal (a,g);
    val rel_ty = type_of (fst (dest_exists g))
    val summands = sumSyntax.strip_sum (hd (fst (strip_fun rel_ty)))
    val target = hd (pairSyntax.strip_prod (hd summands));
    val WF_R = MATCH_MP relationTheory.WF_TC (get_wf_relation target)
    val WF_LEXR =
        MATCH_MP (pairTheory.WF_LEX) (CONJ WF_R relationTheory.WF_EMPTY_REL);

    (* Conclusions R a b of the conditions that follow WF R in the goal.     *)
    val conditions = strip_conj (snd (dest_conj (snd (strip_exists g))))
    val conclusions = map (snd o strip_imp o snd o strip_forall) conditions

    (* Remove any leading INL/INR applications from a term.                  *)
    val injection_consts = [sumSyntax.inl_tm,sumSyntax.inr_tm]
    fun is_injection c =
        exists (fn inj => can (match_term inj) c) injection_consts
    fun peel x =
        if is_comb x andalso is_injection (rator x)
           then peel (rand x) else x;
    val key = hd o pairSyntax.strip_pair o peel

    (* Conditions whose two arguments carry the same leading component.      *)
    fun same_key tm =
    let val (applied,second) = dest_comb tm
        val first_arg = snd (dest_comb applied)
    in  key first_arg = key second
    end
    val same_terms = filter same_key conclusions;

    fun lift R = mk_sumstart(mk_summap (length summands)
                     (mk_lex R (mk_nested_rel same_terms)));
    val base = if can pairSyntax.dest_prod (hd summands)
                  then WF_LEXR else WF_R
    val relation = lift base
    val relation_matched =
        INST_TYPE (match_type (type_of (rand (concl relation))) rel_ty)
                  relation
    val witness = rand (concl relation_matched)
in
    ((EXISTS_TAC witness THEN
     REWRITE_TAC [relation_matched] THEN
     REPEAT STRIP_TAC THEN
     REPEAT (CHANGED_TAC (REWRITE_TAC
            [relationTheory.EMPTY_REL_DEF,pairTheory.LEX_DEF,
             sumTheory.sum_case_def,relationTheory.inv_image_def,
             pairTheory.FST,pairTheory.SND,combinTheory.I_THM,
             sumTheory.sum_distinct] THEN
            pairLib.GEN_BETA_TAC)) THEN
     RW_TAC std_ss [] THEN
     WF_TC_FINISH_TAC THEN
     CCONTR_TAC THEN FULL_SIMP_TAC std_ss [] THEN NO_TAC) (a,g))
     before proofManagerLib.drop()
end handle e => wrapException "WF_RECOGNIZER_TAC" e
3875
3876(*****************************************************************************)
3877(* TARGET_INDUCT_TAC : tactic                                                *)
3878(*    Performs an induction using a well founded relation derived from the   *)
3879(*    target type induction scheme.                                          *)
3880(*                                                                           *)
3881(*****************************************************************************)
3882
(* Induct on the universally quantified variable of the goal, using the      *)
(* transitive closure of the well-founded relation that get_wf_relation      *)
(* derives for the variable's type.                                          *)
fun TARGET_INDUCT_TAC (a,g) =
let val var = fst (dest_forall g)
    (* get_wf_relation looks up the translation scheme itself, so no         *)
    (* separate lookup is needed here.                                       *)
    val WF_R = MATCH_MP relationTheory.WF_TC (get_wf_relation (type_of var))
in
    (recInduct (REWRITE_RULE [relationTheory.WF_EQ_INDUCTION_THM] WF_R) THEN
    NTAC 2 STRIP_TAC) (a,g)
end;
3893
3894(*****************************************************************************)
(* find_conditional_thm target : hol_type -> thm                             *)
3896(*    Returns a rewrite for conditionals of the form:                        *)
3897(*        |- IF (P x) f g = if (isPair x) then f else g                      *)
3898(*                                                                           *)
3899(*****************************************************************************)
3900
fun find_conditional_thm target =
let val scheme = get_translation_scheme target
    val identity = mk_const("I",target --> target);
    val then_var = mk_var("a",target)
    val else_var = mk_var("b",target)
    (* The scheme's predicate applied to a variable p of the target type.    *)
    val guard = beta_conv (mk_comb(#predicate scheme,mk_var("p",target)))
    (* I (if guard then a else b): encoding stops at the variables.          *)
    val wrapped = imk_comb(identity,mk_cond(guard,then_var,else_var));
    val encoded = snd (encode_until [is_var o rand] ([],[]) wrapped)
in
    (* Drop the I wrapper and orient the equation the other way round.       *)
    GSYM (REWRITE_RULE [combinTheory.I_THM] encoded)
end;
3911
3912(*****************************************************************************)
3913(* PROPAGATE_RECOGNIZERS_TAC : thm list -> thm list -> tactic                *)
3914(*    Fully solves a goal of the form:                                       *)
3915(*      ?- !x. bool (detect ... x) = concrete_detect x (f ...) /\ bool ...   *)
3916(*    Given a list of theorems of the form:                                  *)
3917(*     |- bool (detect ... x) = if isPair x then ... else ...                *)
3918(*    Derived from partially encoding, then removing conditionals            *)
3919(*    and a list of definitions, representing the fully encoded theorems:    *)
3920(*     |- concrete_detect x (f ...) = if isPair x then ... else ...          *)
3921(*                                                                           *)
3922(*****************************************************************************)
3923
local
(* Given an implication whose conclusion is a conjunction, tries to solve    *)
(* the goal by MATCH_MP_TAC with each conjunct in turn, keeping the          *)
(* original quantifiers and antecedent.                                      *)
fun MATCH_CONJ_TAC thm =
let val (quantified,body) = strip_forall (concl thm)
    val antecedent = fst (dest_imp_only body)
    fun as_rule conjunct =
        MATCH_MP_TAC (GENL quantified (DISCH antecedent conjunct))
in
    MAP_FIRST as_rule (CONJUNCTS (UNDISCH (SPEC_ALL thm)))
end;
(* Repeats the first applicable step: reflexivity, an assumption, a          *)
(* congruence rule, the supplied tactic (if it makes progress), or           *)
(* splitting an equation between applications.                               *)
fun PR_FINISH STAC =
let val by_assumption = FIRST_ASSUM MATCH_CONJ_TAC
    val by_congruence =
        MAP_FIRST MATCH_MP_TAC (DefnBase.read_congs())
        THEN REPEAT STRIP_TAC
    val by_singles = CHANGED_TAC STAC
in
    REPEAT (FIRST [REFL_TAC,
           by_assumption,
           by_congruence,
           by_singles,
           MK_COMB_TAC])
end;
val cond_pattern = ``COND:bool -> 'a -> 'a -> 'a``
in
fun PROPAGATE_RECOGNIZERS_TAC theorems definitions x =
let (* A theorem is single when its right-hand side has no conditional.      *)
    fun is_single thm =
        null (find_terms (can (match_term cond_pattern)) (rhs (concl thm)));
    val (singles,_) =
        partition (is_single o fst) (zip theorems definitions);
    (* Rewrite the left of an equation with a and the right with b.          *)
    fun conv (a,b) = LAND_CONV (REWR_CONV a) THENC RAND_CONV (REWR_CONV b)
    val STAC = CONV_TAC (STRIP_QUANT_CONV
                        (EVERY_CONJ_CONV (TRY_CONV
                                         (FIRST_CONV (map conv singles)))));
    (* Built on demand so that failures while constructing the tactic are    *)
    (* reported by the handler below.                                        *)
    fun script () =
        REPEAT (CHANGED_TAC STAC) THEN
        TARGET_INDUCT_TAC THEN REPEAT STRIP_TAC THEN
        ONCE_REWRITE_TAC definitions THEN
        ONCE_REWRITE_TAC theorems THEN
        REWRITE_TAC [find_conditional_thm
                    (type_of (fst (dest_forall (snd x))))] THEN
        TRY IF_CASES_TAC THEN ASM_REWRITE_TAC [] THEN
        RW_TAC (std_ss ++ boolSimps.LET_ss) [] THEN
        PR_FINISH STAC THEN
        WF_TC_FINISH_TAC THEN CCONTR_TAC THEN FULL_SIMP_TAC std_ss []
in
     (script () x)
     handle e => wrapException "PROPAGATE_RECOGNIZERS_TAC" e
end
end
3963
3964(*****************************************************************************)
3965(* fix_definition_terms : thm list -> term list -> term list                 *)
3966(*    Given a list of definitions, replaces the variables matching the       *)
3967(*    defined constants with the constants.                                  *)
3968(*                                                                           *)
3969(*****************************************************************************)
3970
fun fix_definition_terms defns props =
let (* The constant at the head of the left-hand side of a definition.       *)
    fun defined_const defn =
        fst (strip_comb (lhs (snd (strip_forall (concl defn)))))
    (* A variable with the same name and type as the constant.               *)
    fun var_for c = mk_var(fst (dest_const c),type_of c)
    val consts = map defined_const defns
    val var_to_const = map (fn c => var_for c |-> c) consts
in
    map (subst var_to_const) props
end handle e => wrapException "fix_definition_terms" e
3977
3978(*****************************************************************************)
(* define_with_tactic : (term -> bool) -> conv -> tactic -> term list ->     *)
(*                          thm list * thm option                            *)
3980(*     Attempts to define a list of terms as functions using the tactic      *)
3981(*     given.                                                                *)
3982(*                                                                           *)
3983(*****************************************************************************)
3984
fun make_singles_definitions L =
    case L
    of [] => []
     | _ =>
let (* Define an equation under the name of the variable heading its lhs.    *)
    fun define s =
        new_definition (fst (dest_var (fst (strip_comb (lhs s)))),s)
    (* pick_e chooses one term and returns it with the rest of the list.     *)
    val (defn,remaining) = pick_e Empty define L
in
    (* The new constant replaces its variable in the terms still to do.      *)
    defn::make_singles_definitions (fix_definition_terms [defn] remaining)
end;
3992
fun fix_rewrite defs thm =
let val hyps = hyp thm
    (* The hypotheses with defined constants in place of their variables.    *)
    val fixed = fix_definition_terms defs hyps
    (* Instantiate the theorem so that hypothesis h becomes h'.              *)
    fun retarget ((h,h'),th) = INST_TY_TERM (match_term h h') th
    val instantiated = foldl retarget thm (zip hyps fixed)
    (* Remove each hypothesis that one of the definitions proves.            *)
    fun discharge (def,th) = PROVE_HYP def th
in
    foldl discharge instantiated defs
end;
3999
(* define_with_tactic is_single conv tactic terms                            *)
(*     Defines the equations in terms as functions and returns, for each     *)
(*     input term, the theorem whose defined constant has the same name as   *)
(*     the variable heading that term, paired with an optional induction     *)
(*     theorem (NONE when new_definition succeeded on the conjunction).      *)
fun define_with_tactic is_single conv tactic terms =
let val (singles,not_singles) = partition is_single terms
    (* Each single, quantified over the arguments of its left-hand side, is  *)
    (* assumed and used to rewrite the remaining terms.                      *)
    val srws = map (fn x => list_mk_forall (snd (strip_comb (lhs x)),x)) singles
    val dterms = map (ALLOW_CONV (REWRITE_CONV (map ASSUME srws))) not_singles
    (* conv is applied to each rewritten term.                               *)
    val rewrites = map (ALLOW_CONV conv o rhs o concl) dterms
    (* The definition takes its name from the first non-single term; hd      *)
    (* raises (and the handler at the end reports it) if there is none.      *)
    val name = fst (dest_var (fst (strip_comb (lhs (hd not_singles)))))
    val def_term = list_mk_conj (map (rhs o concl) rewrites)
    (* Try a plain definition first; if that fails, build a Defn and prove   *)
    (* it with the supplied tactic, which also yields an induction theorem.  *)
    val (definition,induction) =
        case (total new_definition) (name,def_term)
        of NONE => (I ## SOME) (Defn.tprove(
                    (Defn.mk_defn name def_term),tactic))
         | SOME defn => (defn,NONE)
   (* The singles are defined next, in terms of the new constants.           *)
   val singles' = fix_definition_terms (CONJUNCTS definition) singles
   val sdefs = make_singles_definitions singles'
   (* NOTE(review): this rebinding of srws is not referenced again below.    *)
   val srws = fix_definition_terms (sdefs @ CONJUNCTS definition) srws
   val dterms' = map (fix_rewrite (sdefs @ CONJUNCTS definition)) dterms
   (* Undo conv, then undo the rewriting with the singles, so that each      *)
   (* clause is stated in the form of the corresponding input term.          *)
   val penultimate =
       map2 (fn r => CONV_RULE (STRIP_QUANT_CONV (REWR_CONV (GSYM r))))
       rewrites (CONJUNCTS definition)
   val ultimate =
       map2 (fn r => GEN_ALL o CONV_RULE (STRIP_QUANT_CONV
                             (REWR_CONV (GSYM r))))
       dterms' penultimate
   (* Name of the variable heading a term / constant heading a theorem.      *)
   val gvarname = fst o dest_var o fst o strip_comb o lhs o snd o strip_forall
   val gconstname = fst o dest_const o fst o strip_comb o
                  lhs o snd o strip_forall o concl
in
    (map (fn t => first (curry op= (gvarname t) o gconstname)
               (sdefs @ ultimate)) terms,induction)
end handle e => wrapException "define_with_tactic" e
4030
4031(*****************************************************************************)
4032(* flatten_abstract_recognizers :                                            *)
4033(*             (term -> bool) -> hol_type -> hol_type -> thm list            *)
4034(*                                                                           *)
4035(*    Flattens the recognizers for the type given, abstracting terms that    *)
4036(*    match the function given as inputs to the recognizer. For example:     *)
4037(*                                                                           *)
4038(*****************************************************************************)
4039
fun flatten_abstract_recognizers fname f target t =
let val conditional = find_conditional_thm target
    val (props,thms,terms) = create_abstract_recognizers fname f target t
    val conv = SIMP_CONV (std_ss ++ boolSimps.LET_ss) [] THENC
               REWRITE_CONV [conditional]
    val funcs = map lhs terms
    fun head tm = fst (strip_comb tm)
    (* A term is single when the head of its rhs is one of the functions     *)
    (* being defined.                                                        *)
    fun is_single x =
    let val rhs_head = head (rhs x)
    in  exists (fn func => rhs_head = head func) funcs
    end
    val (definitions,induction) = define_with_tactic is_single conv
                              WF_RECOGNIZER_TAC terms
    val props' = fix_definition_terms definitions props
    val var = rand (rand (lhs (hd props')))
    val goal = mk_forall(var,list_mk_conj props');
    val theorems = map (CONV_RULE conv) thms;
    val props_thms = CONJUNCTS (SPEC_ALL (prove(goal,
                   PROPAGATE_RECOGNIZERS_TAC theorems definitions)));
    (* Name of the constant heading the chosen side of a theorem.            *)
    fun const_name side thm =
        fst (dest_const (head (side (snd (strip_forall (concl thm))))))
    (* Apply an action to every theorem together with that name.             *)
    fun for_each side action thms =
        map (fn thm => action (const_name side thm,thm)) thms
    val _ = for_each rhs (fn (a,b) => save_thm("prop_" ^ a,b)) props_thms
    val _ = for_each rhs (uncurry (add_standard_rewrite 1)) props_thms
    val _ = for_each lhs (fn (a,b) => save_thm("translated_" ^ a,b))
                         definitions
in
    definitions
end handle e => wrapException "flatten_abstract_recognizers" e
4065
4066(*****************************************************************************)
4067(* Theorem tools to get rid of the detector stuff ....                       *)
4068(*****************************************************************************)
4069
fun generalize_abstract_recognizer_term target t =
let val var = mk_var("x",target)
    val detector = get_detect_function target t
    val boolenc = get_encode_function target bool
    val booldec = get_decode_function target bool
    (* A limit predicate that is false on its first call and true after.     *)
    val seen = ref false
    fun once _ =
    let val already = !seen
    in  seen := true; already
    end
    val prop = snd (encode_until [once] ([],[])
                   (mk_comb(boolenc,mk_comb(detector,var))));
    (* Keep the recognised variable; every other argument of the encoded     *)
    (* recognizer is replaced by a generated variable of the same type.      *)
    val (encoded_head,encoded_args) = strip_comb (rhs (concl prop))
    fun generalise arg = if arg = var then arg else genvar (type_of arg)
    val left = mk_comb(booldec,
                       list_mk_comb(encoded_head,map generalise encoded_args));
    (* The type t with all of its type arguments replaced by the target.     *)
    val (tyop,tyargs) = dest_type t
    val basetype = mk_type(tyop,map (K target) tyargs)
    val right = mk_comb(get_detect_function target basetype,var);
in
    mk_forall(var,mk_imp(left,right))
end handle e => wrapException "generalize_abstract_recognizer_term" e;
4087
fun ENCODE_BOOL_UNTIL_CONV target terms term =
let (* The boolean encode/decode theorem for this term, reversed.            *)
    val round_trip = GSYM (SPEC term (FULL_ENCODE_DECODE_THM target bool))
    (* Encoding stops where the operand matches one of the given patterns.   *)
    fun stop_at pattern tm = can (match_term pattern) (rand tm)
    val encode = snd o encode_until (map stop_at terms) ([],[])
in
    RIGHT_CONV_RULE (RAND_CONV encode) round_trip
end;
4095
fun GENERALIZE_ABSTRACT_RECOGNIZER_TAC target t thm (a,g) =
let (* Conclusion of the implication under the goal's quantifiers.           *)
    val (_,conclusion) = dest_imp_only (snd (strip_forall g))
    val unfold_abstract = ONCE_REWRITE_TAC [thm]
    val unfold_detect =
        ONCE_REWRITE_TAC [get_coding_function_def target t "detect"]
    (* Encode the right of the goal, stopping at that conclusion.            *)
    val encode_goal =
        CONV_TAC (RAND_CONV (ENCODE_BOOL_UNTIL_CONV target [conclusion]))
in
   (TARGET_INDUCT_TAC THEN
    unfold_abstract THEN
    unfold_detect THEN
    IF_CASES_TAC THEN ASM_REWRITE_TAC [] THEN
    encode_goal) (a,g)
end
4105
fun generalize_abstract_recognizer target t pre_rewrites thm =
let val goal = generalize_abstract_recognizer_term target t
    val simplify = ASM_REWRITE_TAC pre_rewrites
    (* Move quantifiers out of consequents and merge nested antecedents in   *)
    (* every assumption.                                                     *)
    val normalise_assumptions =
        RULE_ASSUM_TAC (CONV_RULE (REPEATC (STRIP_QUANT_CONV
            (RIGHT_IMP_FORALL_CONV ORELSEC REWR_CONV AND_IMP_INTRO))))
in
   prove(goal,
      GENERALIZE_ABSTRACT_RECOGNIZER_TAC target t thm THEN
      simplify THEN
      normalise_assumptions THEN
      REPEAT STRIP_TAC THEN FIRST_ASSUM MATCH_MP_TAC THEN
      simplify THEN
      WF_TC_FINISH_TAC)
end;
4115
4116(*****************************************************************************)
4117(* create_abstracted_definition                                              *)
4118(*              : (term -> bool) -> hol_type -> string -> thm list ->        *)
(*                                  term list -> thm -> ...                  *)
(*****************************************************************************)
4120
fun encode_all_avoiding f func (assums,extras) term (tset,thmset) =
let (* Encode the term, stopping where the operand satisfies f (first        *)
    (* limit) or matches func (second limit).                                *)
    val (ends,converted) = encode_until [f o rand,can (match_term func) o rand]
                                        (assums,extras) term
    val terminals = map snd (el 1 ends)
    val recursive_hits = el 2 ends
    val target = type_of term
    fun mk_encoder x = imk_comb(get_encode_function target (type_of x),x)
        handle e => raise (mkDebugExn "mk_encoder"
               ("Could not encode the value: " ^ term_to_string x));
    (* Every argument of a call that matched func is encoded and queued      *)
    (* together with the first component of that hit.                        *)
    fun encoded_args (context,call) =
        map (fn arg => (context,mk_encoder arg))
            (snd (strip_comb (rand call)))
    val pending =
        foldl (fn (hit,acc) => encoded_args hit @ acc) [] recursive_hits;
    (* Process a queued term, threading the terminals and theorems found.    *)
    fun visit ((context,tm),acc) =
        encode_all_avoiding f func (context,extras) tm acc
in
    foldl visit (mk_set(terminals @ tset),converted::thmset) pending
end handle e => wrapException "encode_all_avoiding" e
4137
(* create_abstracted_definition_term f target name limits extras function    *)
(*     Builds the analogue definition term for function, encodes its body    *)
(*     (stopping at operands satisfying f and at recursive calls), and       *)
(*     abstracts the terminal terms into extra arguments.                    *)
(*     Returns (propagation term, encoding theorem of the body, abstracted   *)
(*     definition term).                                                     *)
fun create_abstracted_definition_term
    f target name limits extras function =
let val term = snd (strip_forall
             (mk_analogue_definition_term target name
                                           limits function))
    (* Unfold function once at a fixed position inside the term.             *)
    val term_thm =
        STRIP_QUANT_CONV (RAND_CONV (RAND_CONV (RATOR_CONV (
                          RAND_CONV (ONCE_REWRITE_CONV [function]))))) term;
    (* NOTE(review): left is not used below.                                 *)
    val (tfunc,left) = dest_eq (snd (strip_forall (concl function)));

    val extra_theorems = calculate_extra_theorems target [(function,limits)];
    (* NOTE(review): encoder, body, default and assums are not used below;   *)
    (* the destructors still raise if the term is not of the expected shape  *)
    (* -- presumably intended as a check, confirm before removing.           *)
    val (encoder,encoded) = dest_comb (rhs (snd (strip_forall term)))
    val (pred,body,default) = dest_cond encoded
    val assums = map ASSUME (strip_conj pred);

    val (terminals,converted) =
        encode_all_avoiding f tfunc ([],extras @ extra_theorems)
                              (rand (snd (strip_forall (rhs (concl term_thm)))))
                              ([],[])

    (* The last theorem collected is the encoding of the body itself; the    *)
    (* ones before it were collected at the recursion points.                *)
    val full_thm = RIGHT_CONV_RULE (STRIP_QUANT_CONV (RAND_CONV
                                   (REWR_CONV (last converted)))) term_thm
    val recursions = butlast converted

    val _ = trace 1 ("Unabstracted conversion:\n" ^
                    term_to_string (rhs (concl full_thm)) ^ "\n")
    val _ = trace 1 ("\nRecusion points:\n")
    val _ = map (fn x => (trace 1 (thm_to_string x) ; trace 1 "\n")) recursions
    val _ = trace 1 ("\nTerminals:\n")
    val _ = map (fn x => (trace 1 (term_to_string x) ; trace 1 "\n")) terminals

    (* Apply the decoder to both sides of a recursion theorem and simplify   *)
    (* with the encode/decode theorem for the type of the encoded value.     *)
    fun fixr x =
        (fn t => PURE_REWRITE_RULE [FULL_ENCODE_DECODE_THM target t]
            (AP_TERM (get_decode_function target t) x))
        (type_of (rand (lhs (concl x))));

    val recursions' = map fixr recursions;
    val rsubsts = map (op|-> o dest_eq o concl) recursions'

    val (props,output_terms) =
        make_abstract_funcs target terminals [snd (strip_forall term)]
            [subst rsubsts (snd (strip_forall (rhs (concl full_thm))))];

    val _ = trace 1 ("\nAbstracted definition:\n" ^
                     term_to_string (hd output_terms));
    val _ = trace 1 ("\nPropagation term:\n" ^
                     term_to_string (hd props) ^ "\n\n");
in
    (hd props,last converted,hd output_terms)
end;
4188
(*****************************************************************************)
(* convert_abstracted_definition :                                           *)
(*     Builds the abstracted definition term for thm (via                    *)
(*     create_abstracted_definition_term), defines it with                   *)
(*     define_with_tactic using tactic1, then states and proves (using       *)
(*     tactic2) a propagation theorem relating the new constant to the       *)
(*     original function.                                                    *)
(*                                                                           *)
(*     f, target, name, extras : passed straight through to                  *)
(*                    create_abstracted_definition_term                      *)
(*     limits       : limit terms; extended with one limit per missing       *)
(*                    clause reported by clause_to_case                      *)
(*     thm          : the clausal definition to convert                      *)
(*     pre_rewrites : theorems used to build the REWRITE_CONV handed to      *)
(*                    define_with_tactic                                     *)
(*     tactic1      : tactic handed to define_with_tactic                    *)
(*     tactic2      : applied as 'tactic2 definition conversion_thm' to      *)
(*                    obtain the tactic proving the propagation theorem      *)
(*                                                                           *)
(*     Side effects: registers the propagation theorem with                  *)
(*     add_standard_rewrite and saves "prop_<name>" and "translated_<name>", *)
(*     where <name> is the constant at the head of the right-hand side of    *)
(*     the proved theorem (NOT the name argument).                           *)
(*                                                                           *)
(*     Returns the new definition theorem.                                   *)
(*     Failures of create_abstracted_definition_term are re-raised through   *)
(*     wrapException; 'hd o find_terms' below fails if some free variable    *)
(*     of the propagation term never occurs under a decode function.         *)
(*****************************************************************************)

fun convert_abstracted_definition
    f target name limits extras thm pre_rewrites tactic1 tactic2 =
let val (function,missing) = clause_to_case thm
    (* Clauses missing from the original definition contribute extra limits *)
    val limits' = map clause_to_limit missing @ limits
    (* NB: from here on 'thm' is the conversion theorem returned by         *)
    (*     create_abstracted_definition_term, not the argument              *)
    val (prop,thm,term) = create_abstracted_definition_term
                 f target name limits' extras function
                 handle e => wrapException "convert_abstracted_definition" e
    val conv = REWRITE_CONV pre_rewrites
    (* Only one term is defined, so only the first definition is kept *)
    val (definition,ind) = (hd ## I)
                           (define_with_tactic (K false) conv tactic1 [term])
    val prop' = hd (fix_definition_terms [definition] [prop])
    val vars = free_vars_lr prop'
    (* True iff term is the decode function for its own type applied to v.  *)
    (* Structural failures (term not an application, no decode function for *)
    (* the type, ...) mean "not a match"; Interrupt is re-raised so that    *)
    (* the find_terms traversal below remains interruptible.                *)
    fun is_decoded v term =
        (can (match_term (get_decode_function target (type_of term)))
            (rator term)
        andalso rand term = v)
        handle Interrupt => raise Interrupt
             | _ => false;
    (* For each free variable, the first subterm of the form 'decode v' *)
    val decoded_vars = map (fn v => (hd o find_terms (is_decoded v)) prop')
                           vars
    (* 'encode v'' where v' has the same name as v but the decoded type *)
    val encoded_vars = map (fn dv => mk_comb(
                       get_encode_function target (type_of dv),
                       mk_var(fst (dest_var (rand dv)),type_of dv)))
                       decoded_vars;
    (* The fresh variables v' themselves *)
    val nvars = map rand encoded_vars
    (* Guards t with the conjunction of every limit applied to nvars (built *)
    (* with list_imk_comb, then full_beta); no limits leaves t unguarded.   *)
    fun lmkimp [] t = t
      | lmkimp L t = mk_imp(list_mk_conj
        (map (full_beta o C (curry list_imk_comb) nvars) L),t);
    (* Goal: !nvars. limits ==> prop', with each 'decode v' replaced by v'  *)
    (* first and every remaining v replaced by 'encode v'' afterwards.      *)
    val term = list_mk_forall(nvars,
        (lmkimp limits' ((subst (map2 (curry op|->) vars encoded_vars) o
         subst (map2 (curry op|->) decoded_vars nvars)) prop')));
    (* NOTE(review): theorem, encdets and encdecs are not referenced again  *)
    (* in this function. They are kept because it is not visible here       *)
    (* whether the FULL_ENCODE_* lookups have side effects or are relied on *)
    (* to fail early - confirm before removing.                             *)
    val theorem = REWRITE_RULE pre_rewrites thm;
    val encdets = map (FULL_ENCODE_DETECT_THM target o type_of) nvars;
    val encdecs = map (FULL_ENCODE_DECODE_THM target o type_of) nvars;

    val prop_thm = prove(term,tactic2 definition thm)

    (* Shadows the 'name' argument with the name of the constant heading    *)
    (* the right-hand side of the proved propagation theorem.               *)
    val name = (fst o dest_const o fst o strip_comb o rhs o
                                    snd o strip_imp_only o snd o strip_forall o
                                    concl) prop_thm;

    val _ = add_standard_rewrite 0 name prop_thm;

    val _ = save_thm("prop_" ^ name,prop_thm);
    val _ = save_thm("translated_" ^ name,definition);

in
    definition
end;
4236
(*****************************************************************************)
(* convert_abstracted_nonrec_definition :                                    *)
(*     Specialisation of convert_abstracted_definition with no               *)
(*     pre-rewrites, NO_TAC as the definition tactic, and a fixed tactic     *)
(*     for the propagation theorem.                                          *)
(*                                                                           *)
(*     The propagation tactic unfolds the new definition, rewrites with the  *)
(*     conversion theorem right-to-left, rewrites with the encode-detect     *)
(*     and encode-decode theorems for the types of the goal's universally    *)
(*     quantified variables, strips the goal and finishes with               *)
(*     RW_TAC std_ss on the original thm.                                    *)
(*****************************************************************************)

fun convert_abstracted_nonrec_definition f target name limits extras thm =
let fun propagation_tac definition rewrite (goal as (_,goal_tm)) =
        let val bound_types = map type_of (fst (strip_forall goal_tm))
            (* Built in this order: unfold, then detect, then decode *)
            val unfold_tac = REWRITE_TAC [definition,GSYM rewrite]
            val detect_tac =
                REWRITE_TAC (map (FULL_ENCODE_DETECT_THM target) bound_types)
            val decode_tac =
                REWRITE_TAC (map (FULL_ENCODE_DECODE_THM target) bound_types)
        in
            (unfold_tac THEN
             detect_tac THEN
             decode_tac THEN
             REPEAT STRIP_TAC THEN
             RW_TAC std_ss [thm]) goal
        end
in
    convert_abstracted_definition f target name limits extras thm []
        NO_TAC
        propagation_tac
end;
4249
4250end