1structure encodeLib :> encodeLib =
2struct
3
4open Thm Term Type boolSyntax Parse Conv Rewrite Drule
5open Tactic Tactical pairLib numLib polytypicLib
6open Binarymap List Lib
7open boolTheory pairTheory listTheory combinTheory
8
9open bossLib metisLib;
10
11(*****************************************************************************)
12(* construct_bottom_value : (hol_type -> bool) -> term -> hol_type -> term   *)
13(*     Given a predicate to indicate stopping, and a stop term, constructs   *)
14(*     the first non-recursive constructor of the type given.                *)
15(*                                                                           *)
16(*     Eg. construct_bottom_value (curry op= bool) F                         *)
17(*                   ``:bool # (bool + bool)`` = ``(F,INL F)``               *)
18(*                                                                           *)
19(* target_bottom_value : hol_type -> term -> hol_type -> term                *)
20(*     Constructs the term using 'ARB' values, then replaces them with the   *)
21(*     given term once complete.                                             *)
22(*                                                                           *)
23(*     Eg. target_bottom_value bool F ``:'a # ('b + 'c)`` = ``(F,INL F)``    *)
24(*                                                                           *)
25(*****************************************************************************)
26
local
exception Bottom;

(* Apply [f] to each candidate in order and return the first result that does
   not raise Empty or Bottom.  When the candidates run out, raise Bottom if
   some candidate raised Bottom (depth bound hit), otherwise raise Empty.
   Any other exception raised by [f] propagates unchanged. *)
fun first_success hit_bound f candidates =
    case candidates of
      [] => if hit_bound then raise Bottom else raise Empty
    | c :: rest =>
        (f c handle Empty => first_success hit_bound f rest
                  | Bottom => first_success true f rest)

(* Instantiate [tm] so that the final range of its curried type is [ty]. *)
fun inst_range tm ty = inst (match_type (snd (strip_fun (type_of tm))) ty) tm

(* Build a value of type [ty] with at most [depth] nested constructor
   applications.  A type accepted by [stop] becomes [xvar], re-typed to [ty].
   Otherwise the registered "bottom-cons" constant for [ty] is used if
   lookup/instantiation succeeds, else each constructor of [ty] is tried in
   turn.  Raises Bottom when the depth budget is exhausted. *)
fun bottom_at_depth depth stop xvar ty =
    if depth = 0 then raise Bottom
    else if stop ty then inst_range xvar ty
    else
      let
        val candidates =
            [inst_range (get_source_function_const ty "bottom-cons") ty]
            handle _ => (constructors_of ty handle _ => [])
        fun saturate c =
            full_beta
              (list_mk_comb
                 (c, map (bottom_at_depth (depth - 1) stop xvar)
                         (fst (strip_fun (type_of c)))))
      in
        first_success false saturate candidates
      end

(* Iterative deepening: rerun [f] with depth + 1 while it raises Bottom.
   NOTE(review): never terminates if every depth raises Bottom. *)
fun deepen f depth = f depth handle Bottom => deepen f (depth + 1)

(* The constructed value must have exactly the requested type. *)
fun check value t =
    if type_of value <> t
    then raise (mkDebugExn "construct_bottom_value"
                  ("Constructed a term of type: " ^ type_to_string (type_of value) ^
                   "\nwhen a value of type: " ^ type_to_string t ^
                   "\nshould have been returned!!"))
    else value
in
(* Smallest-depth value of type [t] built from constructors, with types
   satisfying [p] replaced by (an instance of) [xvar].  Empty (no usable
   constructors somewhere) becomes a standard error; any other failure is
   wrapped with this function's name. *)
fun construct_bottom_value p xvar t =
    let
      val value =
          deepen (fn n => bottom_at_depth n p xvar t) 0
          handle Empty => raise (mkStandardExn "construct_bottom_value"
                                   ("Could not find bottom values for all sub-types of "
                                    ^ type_to_string t))
               | e => wrapException "construct_bottom_value" e
    in
      check value t
    end
end;
59
(* Build a bottom value for [t] whose leaves are ARB at fresh type variables,
   then re-type every such ARB to [target] and substitute [bottom_target]
   for it. *)
fun target_bottom_value target bottom_target t =
    let
      val skeleton = construct_bottom_value is_vartype (mk_arb alpha) t
      val arb_types = mk_set (map type_of (HolKernel.find_terms is_arb skeleton))
      val retyped = inst (map (fn ty => ty |-> target) arb_types) skeleton
    in
      subst [mk_arb target |-> bottom_target] retyped
    end
68
69(*****************************************************************************)
70(* set_bottom_value : hol_type -> term -> unit                               *)
71(*    Set the bottom value for the type given, this is only required for     *)
72(*    non-recursive types.                                                   *)
73(*                                                                           *)
74(*****************************************************************************)
75
(* Register [term] as the "bottom-cons" source function for type [t]. *)
fun set_bottom_value t term =
    let val entry = {const = term, definition = TRUTH, induction = NONE}
    in  add_source_function_precise t "bottom-cons" entry
    end
    handle e => wrapException "set_bottom_value" e
81
82(*****************************************************************************)
83(* Generation of the encoding functions                                      *)
84(*                                                                           *)
85(* get_encode_type, get_decode_type, get_detect_type, get_fix_type           *)
86(*                   : hol_type -> hol_type -> hol_type                      *)
87(* get_map_type, get_all_type                                                *)
88(*                   : hol_type -> hol_type                                  *)
89(* mk_encode_var, mk_decode_var, mk_detect_var, mk_fix_var                   *)
90(*                   : hol_type -> hol_type -> term                          *)
91(* mk_map_var, mk_all_var                                                    *)
92(*                   : hol_type -> term                                      *)
(*     Returns the type of an encoding, decoding, detecting, fixing,         *)
(*     mapping or all constant, and makes a variable for a prospective       *)
(*     constant.                                                             *)
95(*                                                                           *)
96(* mk_encode_term        : hol_type -> hol_type -> term                      *)
97(*     Makes a full encoding term for the translation given:                 *)
98(*     Single constructor: enc (C a0 a1) = P enc0 enc1 (a0,a1)               *)
99(*     Label constructors: enc Cn = nat n                                    *)
100(*     Otherwise         : enc (Ci a0 a1) = P nat (P enc0 enc1) (i,a0,a1)    *)
101(*                                                                           *)
102(* mk_decode_term        : hol_type -> hol_type -> term                      *)
103(*     Makes a full decoding term for the translation given:                 *)
104(*     Single constructor: dec x = let (a,b) = D dec0 dec1 x in (C a b)      *)
105(*     Label constructors: dec x = if dnat x = 0 then C0 else ....           *)
106(*     Otherwise         : dec x =                                           *)
107(*                             let (l,r) = D dnat I x                        *)
108(*                             in  if l = 0 then                             *)
109(*                                    let (a,b) = D dec0 dec1 r in (C a b)   *)
110(*                                 else map dec0 dec1 (C nil nil)            *)
111(*                                                                           *)
112(* mk_detect_term        : hol_type -> hol_type -> term                      *)
(*     Makes a full detecting term for the translation given:                *)
114(*     Single constructor: dec x = P dec0 dec1 x                             *)
115(*     Label constructors: dec x = bool (dnat x = 0 \/ ... \/ dnat x = n)    *)
116(*     Otherwise         : dec x =                                           *)
117(*                             bool (                                        *)
118(*                                dbool (P dnat (K (bool T)) x)              *)
119(*                             /\ let (l,r) = D dnat I x                     *)
120(*                                in  (l = 0) /\ dbool (P det0 det1 r)       *)
121(*                                    \/  ...)                               *)
122(*                                                                           *)
123(* mk_map_term           : hol_type -> term                                  *)
124(*     Makes a full map function for the given type:                         *)
125(*     Label constructors: map Li = Li                                       *)
(*     Otherwise         : map (C a0 .. an) = C (map0 a0) .. (mapn an)       *)
127(*                                                                           *)
128(* mk_all_term           : hol_type -> term                                  *)
129(*     Makes a full all function for the given type:                         *)
130(*     Label constructors: all Li = T                                        *)
131(*     Otherwise         : all (C a0 .. an) = (all0 # .. # alln) (a0,..,an)  *)
132(*                                                                           *)
133(* mk_fix_term          : hol_type -> hol_type -> term                       *)
134(*     Single constructor: fix x = TP fix0 (TP ... fixn) .. ) x              *)
135(*     Label constructors: fix x = x                                         *)
136(*     Otherwise         : fix x =                                           *)
137(*                             let (l,r) = D dnat I x                        *)
138(*                             in  if l = 0 then                             *)
139(*                                    pair nat I (TP fix0 (TP .. fixn) ) r   *)
140(*                                 else enc fix0 fix1 (C nil nil)            *)
141(*                                                                           *)
142(* get_encode_function, get_decode_function, get_detect_function,            *)
143(* get_fix_function     : hol_type -> hol_type -> term                       *)
144(* get_map_function      : hol_type -> term                                  *)
145(*     Gets a fully instantiated term to translate the type                  *)
146(*                                                                           *)
147(* ENCODE_CONV, DECODE_CONV, DETECT_CONV, FIX_CONV : hol_type -> term -> thm *)
148(*     Given the target type, each conv rewrites to convert a term given to  *)
149(*     it by mk_..._term to a form suitable for split_function.              *)
150(*                                                                           *)
151(* CONSOLIDATE_CONV : (term -> term) -> term -> thm                          *)
152(*     Given a conjunction of functions with instantiated bottom values, ie  *)
153(*     of the form:                                                          *)
154(*        f0 x = if .. then .. else B0 /\                                    *)
155(*                  ...                                                      *)
156(*        fn x = if .. then .. else Bn                                       *)
157(*     Where B0...Bn may contain references to f0...fn and continually       *)
158(*     rewrites with each definition to get B0'...Bn' that don't contain     *)
159(*     references to f0 ... fn. Only works for 'decode' and 'fix'!           *)
160(*                                                                           *)
161(*                                                                           *)
162(* mk_encodes, mk_decodes, mk_detects, mk_fixs : hol_type -> hol_type-> unit *)
163(* mk_maps, mk_alls : hol_type -> unit                                       *)
164(*     Generate the functions given. Shouldn't really be used, as it doesn't *)
165(*     use the generator system, and will hence fail if functions are        *)
166(*     missing.                                                              *)
167(*                                                                           *)
168(*****************************************************************************)
169
local
(* Curried type  opr (a1,target) --> ... --> opr (an,target) --> opr (t,target)
   where a1..an are the type arguments of [t] (none when [t] is a type
   variable). *)
fun get_gen_type opr target t =
        foldr (fn (x,acc) => (opr (x,target)) --> acc) (opr (t,target))
                (if is_vartype t then [] else snd (dest_type t))
in
fun get_encode_type target t = get_gen_type op--> target t
                handle e => wrapException "get_encode_type" e
fun get_decode_type target t = get_gen_type (uncurry (C (curry op-->))) target t
                handle e => wrapException "get_decode_type" e
fun get_detect_type target t = get_gen_type (fn (_,a) => a --> bool) target t
                handle e => wrapException "get_detect_type" e
(* Map type: every type variable of [t] is renamed to a new type variable
   prefixed with 'map_, giving t'; the result is
   (a1 -> b1) --> ... --> (an -> bn) --> t --> t'
   where ai / bi are the corresponding type arguments of t / t'. *)
fun get_map_type t =
let     val tyvars = type_vars t
        fun gentvar v = (mk_vartype o curry op^ "'map_" o get_type_string) v
        val new_vars = map gentvar tyvars
        val t' = type_subst (map2 (curry op|->) tyvars new_vars) t
in
        foldr (fn ((x,y),acc) => (x --> y) --> acc) (t --> t')
                (if is_vartype t then [] else zip (snd (dest_type t)) (snd (dest_type t')))
end
                handle e => wrapException "get_map_type" e
(* all : (a1 -> bool) --> ... --> (an -> bool) --> t --> bool *)
fun get_all_type t = get_gen_type (fn (a,_) => a --> bool) t t
                handle e => wrapException "get_all_type" e
fun get_fix_type target t = get_gen_type (fn (_,_) => target --> target) target t
                handle e => wrapException "get_fix_type" e
end
194
local
(* Variable names: a kind prefix followed by the type's string form. *)
fun mk_encode_string t = "encode" ^ (get_type_string t)
fun mk_decode_string t = "decode" ^ (get_type_string t)
fun mk_detect_string t = "detect" ^ (get_type_string t)
fun mk_map_string t    = "map"    ^ (get_type_string t)
fun mk_fix_string t    = "fix"    ^ (get_type_string t)
fun mk_all_string t    = "all"    ^ (get_type_string t)
in
(* Each mk_*_var makes a variable standing for a prospective constant, with
   the type given by the matching get_*_type. *)
fun mk_encode_var target t =
        mk_var(mk_encode_string t,get_encode_type target t)
        handle e => wrapException "mk_encode_var" e
fun mk_decode_var target t =
        mk_var(mk_decode_string t,get_decode_type target t)
        handle e => wrapException "mk_decode_var" e
fun mk_detect_var target t =
        mk_var(mk_detect_string t,get_detect_type target t)
        handle e => wrapException "mk_detect_var" e
fun mk_map_var t =
        mk_var(mk_map_string t,get_map_type t)
        handle e => wrapException "mk_map_var" e
fun mk_fix_var target t =
        mk_var(mk_fix_string t,get_fix_type target t)
        handle e => wrapException "mk_fix_var" e
fun mk_all_var t =
        mk_var(mk_all_string t,get_all_type t)
        handle e => wrapException "mk_all_var" e
end
226
local
(* Instantiate [const] so that the component of its type chosen by [select]
   matches [t]; with NONE the constant is returned unchanged. *)
fun new_const NONE const _ = const
  | new_const (SOME select) const t =
    safe_inst (match_type (select (type_of const)) t) const
(* Last argument type of a curried function type. *)
val last_arg = last o fst o strip_fun
(* Final range of a curried function type. *)
val final_range = snd o strip_fun
in
(* When the source type is the target itself, I is used; otherwise the
   registered constant is looked up and instantiated.  Only the lookup branch
   is wrapped with the function name. *)
fun get_encode_const target t =
    if t = target then mk_const ("I", t --> target)
    else (new_const (SOME last_arg) (get_coding_function_const target t "encode") t
          handle e => wrapException "get_encode_const" e)
fun get_decode_const target t =
    if t = target then mk_const ("I", target --> t)
    else (new_const (SOME final_range) (get_coding_function_const target t "decode") t
          handle e => wrapException "get_decode_const" e)
fun get_map_const t =
    new_const (SOME last_arg) (get_source_function_const t "map") t
    handle e => wrapException "get_map_const" e
fun get_fix_const target t =
    if t = target then mk_const ("I", target --> t)
    else (get_coding_function_const target t "fix"
          handle e => wrapException "get_fix_const" e)
fun get_all_const t =
    new_const (SOME last_arg) (get_source_function_const t "all") t
    handle e => wrapException "get_all_const" e
(* For the target type itself, detection is the constant function K T. *)
fun get_detect_const target t =
    if t = target then mk_comb (mk_const ("K", bool --> target --> bool), T)
    else (get_coding_function_const target t "detect"
          handle e => wrapException "get_detect_const" e)
end
264
local
(* Rename each type variable of [term] that does not occur in [basetype] to a
   fresh one from gen_tyvar, so only basetype's variables are affected by the
   subsequent instantiation. *)
fun fix_base basetype term =
let val types = set_diff (type_vars_in_term term) (type_vars basetype)
in  inst (map (fn x => x |-> gen_tyvar()) types) term
end;
(* Apply [main] to [p] after instantiating main's domain to p's type. *)
fun imk_comb (main,p) =
    mk_comb(inst (match_type (fst (dom_rng (type_of main))) (type_of p)) main,p)
(* Type variables of [t], left to right, duplicates kept. *)
fun typevars_lr t =
    if is_vartype t then [t]
       else flatten (map typevars_lr (snd (dest_type t)));
(* the_value instantiated to type ``:t itself``. *)
fun mk_the_value t =
    inst (match_type (type_of the_value) (mk_type("itself",[t]))) the_value;
(* Build the fully instantiated function for [t]:
   - basetype: most_precise_type for which [fexists] holds, falling back to
     [t] (type variable) or base_type t;
   - base: fconst basetype, or the variable mvar t if that fails;
   - for a non-variable [t], base (instantiated to t) is applied in order to
     the function for each type parameter; if that fails, to the_value for
     that parameter; if that fails too, the parameter is skipped.
   A type variable [t] yields base unchanged. *)
fun get_function fconst fexists mvar t =
let val basetype = (most_precise_type fexists t)
        handle _ => (if is_vartype t then t else base_type t)
    val base = fconst basetype handle _ => mvar t
    val params = set_diff (typevars_lr basetype) [t]
    val match = match_type basetype t
    val param_list = map (type_subst match) params
    val insted = inst match (fix_base basetype base)
 in
    if not (is_vartype t)
       then foldl (fn (a,t) =>
                  imk_comb(t,get_function fconst fexists mvar a) handle _ =>
                  imk_comb(t,mk_the_value a) handle _ => t)
          insted param_list
       else base
end
in
fun get_encode_function target t =
    get_function (get_encode_const target)
                 (C (exists_coding_function_precise target) "encode")
                 (mk_encode_var target) t
        handle e => wrapException "get_encode_function" e
fun get_decode_function target t =
    get_function (get_decode_const target)
                 (C (exists_coding_function_precise target) "decode")
                 (mk_decode_var target) t
        handle e => wrapException "get_decode_function" e
fun get_detect_function target t =
    get_function (get_detect_const target)
                 (C (exists_coding_function_precise target) "detect")
                 (mk_detect_var target) t
        handle e => wrapException "get_detect_function" e
fun get_map_function t =
    get_function get_map_const
                 (C exists_source_function_precise "map")
                 mk_map_var t
        handle e => wrapException "get_map_function" e
fun get_fix_function target t =
    get_function (get_fix_const target)
                 (C (exists_coding_function_precise target) "fix")
                 (mk_fix_var target) t
    handle e => wrapException "get_fix_function" e
fun get_all_function t =
     get_function get_all_const
                 (C exists_source_function_precise "all")
                 mk_all_var t
        handle e => wrapException "get_all_function" e
end
326
local
(* Constructor with arguments: the detect function for the product of C's
   argument types, applied to [rest]. *)
fun mk_detect_constructor_ns rest target C =
let     val (list,_) = strip_fun (type_of C)
in
        mk_comb(get_detect_function target (list_mk_prod list),rest)
end
(* A nullary constructor (C not a function) is detected by [nullary]. *)
fun mk_detect_constructor rest target C nullary =
        if can dom_rng (type_of C) then mk_detect_constructor_ns rest target C else nullary
(* if label = 0 then <C0> else if label = 1 then <C1> ... else F *)
fun mk_detect_res_term label rest target t constructors nullary =
        list_mk_cond
                (map (fn (a,b) => (mk_eq(label,numLib.term_of_int a),
                        mk_detect_constructor rest target b nullary)) (enumerate 0 constructors))
                F
(* All-nullary types: x must be a number naming one of the constructors. *)
fun mk_detect_term_label (p,x) target t constructors =
let     val dnat = get_decode_function target num
        val rnat = get_detect_function target num
in
        mk_forall(x,mk_eq(mk_comb(get_detect_function target t,x),
                mk_cond(mk_comb(rnat,x),
                        mk_detect_res_term (mk_comb(dnat,x)) x target t constructors T,
                        F)))
end
(* Several constructors: x must be a (label,rest) pair; nullary constructors
   require rest to be the encoding of F. *)
fun mk_detect_term_all (p,x) target t constructors =
let     val label = mk_var("l",num)
        val rest = mk_var("r",target)
        val encoded_false = mk_comb(get_encode_function target bool,F)
in
        list_mk_forall(snd (strip_comb (get_detect_function target t)),
                mk_forall(x,mk_eq(mk_comb(get_detect_function target t,x),
                mk_cond(p,pairSyntax.mk_anylet (
                                [(mk_pair(label,rest),mk_comb(get_decode_function target (mk_prod(num,target)),x))],
                                mk_detect_res_term label rest target t constructors (mk_eq(rest,encoded_false))),
                        F))))
end
(* Single constructor: detect the product of its argument types on x. *)
fun mk_detect_term_single (p,x) target t constructor =
        list_mk_forall(snd (strip_comb (get_detect_function target t)),
                mk_forall(x,mk_eq(mk_comb(get_detect_function target t,x),mk_detect_constructor x target constructor T)))
in
fun mk_detect_term target t =
let     val t' = base_type t  handle e => wrapException "mk_detect_term" e
        val constructors = constructors_of t'  handle e => wrapException "mk_detect_term" e
        val x = mk_var("x",target)
        val p = mk_comb(get_detect_function target (mk_prod(num,target)),x)
in
        if      all (not o can dom_rng o type_of) constructors
        then    mk_detect_term_label (p,x) target t' constructors
                handle e => wrapException "mk_detect_term (label)" e
        else    if      length constructors = 1
                then    mk_detect_term_single (p,x) target t' (hd constructors)
                        handle e => wrapException "mk_detect_term (single)" e
                else    mk_detect_term_all (p,x) target t' constructors
                        handle e => wrapException "mk_detect_term" e
end
end
387
local
(* The map function for [t] (applied to the parameter decoders taken from
   t's decode function) applied to the target bottom value of [t], re-typed
   to that map function's domain. *)
fun full_bottom_value target bottom_target t =
let val bottom = target_bottom_value target bottom_target t
    val mapf = get_map_const t
    val decodef = get_decode_function target t
    val map_function = list_imk_comb(mapf,snd (strip_comb decodef))
    val bottom' = inst (match_type (type_of bottom)
                                (fst (dom_rng (type_of map_function)))) bottom
in
    mk_comb(map_function,bottom')
end handle e => wrapException "full_bottom_value" e
(* let (a,b,..) = decode_(a1 # a2 # ..) rest in C a b .. *)
fun mk_decode_constructor_ns rest target C =
let     val (list,_) = strip_fun (type_of C)
        val vars = map (mk_var o (implode o base26 ## I)) (enumerate 0 list)
in
        pairSyntax.mk_anylet (
                [(list_mk_pair vars,mk_comb(get_decode_function target (list_mk_prod list),rest))],
                (list_mk_comb(C,vars)))
end
(* Nullary constructors decode to themselves. *)
fun mk_decode_constructor rest target C =
        if can dom_rng (type_of C) then mk_decode_constructor_ns rest target C else C
(* if label = 0 then <C0> else if label = 1 then <C1> ... else bottom *)
fun mk_decode_res_term label rest target t constructors bottom =
        list_mk_cond
                (map (fn (a,b) => (mk_eq(label,numLib.term_of_int a),
                        mk_decode_constructor rest target b)) (enumerate 0 constructors))
                bottom
(* All-nullary types: the first constructor is the default value. *)
fun mk_decode_term_label (_,x) target t constructors =
let     val dnat = get_decode_function target num
        val rnat = get_detect_function target num
        val bottom = hd constructors
in
        mk_forall(x,mk_eq(mk_comb(get_decode_function target t,x),
                mk_cond(mk_comb(rnat,x),
                        mk_decode_res_term (mk_comb(dnat,x)) x target t constructors bottom,
                        bottom)))
end
(* Several constructors: split x into (label,rest) and dispatch on label. *)
fun mk_decode_term_all (p,x) target t constructors =
let     val label = mk_var("l",num)
        val rest = mk_var("r",target)
        val bottom = full_bottom_value target (#bottom(get_translation_scheme target)) t
in
        list_mk_forall(snd (strip_comb (get_decode_function target t)),
                mk_forall(x,mk_eq(mk_comb(get_decode_function target t,x),
                mk_cond(p,pairSyntax.mk_anylet (
                                [(mk_pair(label,rest),mk_comb(get_decode_function target (mk_prod(num,target)),x))],
                                mk_decode_res_term label rest target t constructors bottom),
                        bottom))))
end
(* Single constructor: guarded by the detect function for [t] with every
   type argument replaced by [target]. *)
fun mk_decode_term_single (_,x) target t constructor =
let     val t' = (mk_type o (I ## map (K target)) o dest_type) t
        val p = get_detect_function target t'
in
        list_mk_forall(snd (strip_comb (get_decode_function target t)),
                mk_forall(x,mk_eq(mk_comb(get_decode_function target t,x),
                mk_cond(mk_comb(p,x),
                        mk_decode_constructor x target constructor,
                        full_bottom_value target (#bottom(get_translation_scheme target)) t))))
end
in
fun mk_decode_term target t =
let     val t' = base_type t  handle e => wrapException "mk_decode_term" e
        val constructors = constructors_of t'  handle e => wrapException "mk_decode_term" e
        val x = mk_var("x",target)
        val p = mk_comb(get_detect_function target (mk_prod(num,target)),x)
in
        if      all (not o can dom_rng o type_of) constructors
        then    mk_decode_term_label (p,x) target t' constructors
                handle e => wrapException "mk_decode_term (label)" e
        else    if      length constructors = 1
                then    mk_decode_term_single (p,x) target t' (hd constructors)
                        handle e => wrapException "mk_decode_term (single)" e
                else    mk_decode_term_all (p,x) target t' constructors
                        handle e => wrapException "mk_decode_term" e
end
end
464
local
(* Constructor with arguments: the fix function for num # a1 # .. # an
   (label followed by C's argument types) applied to the whole value [all]. *)
fun mk_fix_constructor_ns all target C =
let     val (list,_) = strip_fun (type_of C)
in
        mk_comb(get_fix_function target (list_mk_prod (num::list)),all)
end
(* Nullary constructor number [n] is fixed to the encoding of (n,F), the same
   shape single_pair in mk_encode_term produces for nullary constructors. *)
fun mk_fix_constructor all target n C =
        if can dom_rng (type_of C)
                then mk_fix_constructor_ns all target C
                else mk_comb(get_encode_function target (mk_prod(num,bool)),
                        mk_pair(numLib.term_of_int n,F))
(* if label = 0 then <fix C0> else if label = 1 then <fix C1> ... else bottom *)
fun mk_fix_res_term label all target constructors bottom =
        list_mk_cond
                (map (fn (a,b) => (mk_eq(label,numLib.term_of_int a),
                        mk_fix_constructor all target a b)) (enumerate 0 constructors))
                bottom
(* All-nullary types: a valid label is kept, anything else becomes enc 0. *)
fun mk_fix_term_label x target t constructors =
let     val dnat = get_decode_function target num
        val rnat = get_detect_function target num
        val enat = get_encode_function target num
in
        mk_forall(x,mk_eq(mk_comb(get_fix_function target t,x),
                mk_cond(mk_comb(rnat,x),
                        list_mk_cond (map (fn (a,b) => (mk_eq(mk_comb(dnat,x),numLib.term_of_int a),x)) (enumerate 0 constructors))
                                        (mk_comb(enat,zero_tm)),
                        mk_comb(get_encode_function target num,zero_tm))))
end
(* Several constructors.  bottom is the encoding of the target bottom value,
   where the encoder for each type variable v (instantiated at target) is
   replaced by the fix function for v. *)
fun mk_fix_term_all (p,x) target t constructors dead =
let val label = mk_var("l",num)
    val rest = mk_var("r",target)
    val instit = inst (map (fn v => v |-> target) (type_vars t))
    val enc1 = instit (get_encode_function target t)
    val enc2 = subst (map (fn v => instit (get_encode_function target v) |->
                              get_fix_function target v) (type_vars t)) enc1
    val bottom = rimk_comb(enc2,target_bottom_value target dead t)
in
    list_mk_forall(snd (strip_comb (get_fix_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_fix_function target t,x),
            mk_cond(p,pairSyntax.mk_anylet (
                [(mk_pair(label,rest),
                  mk_comb(get_decode_function target (mk_prod(num,target)),x))],
                  mk_fix_res_term label x target constructors bottom),
                  bottom))))
end
(* Single constructor: guarded by the detect function for [t] with every type
   argument replaced by [target]; bottom is built as in mk_fix_term_all. *)
fun mk_fix_term_single x target t constructor dead =
let val t' = (mk_type o (I ## map (K target)) o dest_type) t
    val p = get_detect_function target t'
    val instit = inst (map (fn v => v |-> target) (type_vars t))
    val enc1 = instit (get_encode_function target t)
    val enc2 = subst (map (fn v =>
                          instit (get_encode_function target v) |->
                          get_fix_function target v) (type_vars t)) enc1
    val bottom = rimk_comb(enc2,target_bottom_value target dead t)
    val list = fst (strip_fun (type_of constructor))
in
    list_mk_forall(snd (strip_comb (get_fix_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_fix_function target t,x),
            mk_cond(mk_comb(p,x),
                mk_comb(get_fix_function target (list_mk_prod list),x),
                    bottom))))
end
in
fun mk_fix_term target t =
let val t' = base_type t  handle e => wrapException "mk_fix_term" e
    val constructors = constructors_of t'
        handle e => wrapException "mk_fix_term" e
    val x = mk_var("x",target)
    val p = mk_comb(get_detect_function target (mk_prod(num,target)),x)
    val dead = #bottom (get_translation_scheme target)
in
    if all (not o can dom_rng o type_of) constructors
       then mk_fix_term_label x target t' constructors
            handle e => wrapException "mk_fix_term (label)" e
       else if length constructors = 1
               then mk_fix_term_single x target t' (hd constructors) dead
                    handle e => wrapException "mk_fix_term (single)" e
               else mk_fix_term_all (p,x) target t' constructors dead
                    handle e => wrapException "mk_fix_term" e
end
end
548
local
(* Argument variable a_n of type ty, paired with ty. *)
fun mk_avar (n,ty) = (mk_var ("a_" ^ Int.toString n, ty), ty)

(* Encoding equation for one constructor:
     !a_0 .. a_k. !enc_params. enc_t (cnst a_0 .. a_k) = enc_prod (tuple)
   With a label (SOME n), the tuple is (n, a_0, .., a_k), or (n, F) when
   the constructor has no arguments.  An empty tuple encodes as num 0. *)
fun single_pair label_opt target t cnst =
    let
      val arg_vars = map mk_avar (enumerate 0 (fst (strip_fun (type_of cnst))))
      val tagged =
          case (label_opt, arg_vars) of
            (NONE, _) => arg_vars
          | (SOME n, []) => [(numLib.term_of_int n,``:num``),
                             (``F``,``:bool``)]
          | (SOME n, _) => (numLib.term_of_int n,``:num``) :: arg_vars
      val (vars, types) = unzip tagged
      val bound = map fst arg_vars
      val param_encoders = map (get_encode_function target) (snd (dest_type t))
      val lhs_tm = mk_comb (get_encode_function target t, list_mk_comb (cnst, bound))
      val rhs_tm =
          case vars of
            [] => mk_comb (get_encode_function target ``:num``, ``0n``)
          | _ => mk_comb (get_encode_function target (pairLib.list_mk_prod types),
                          pairLib.list_mk_pair vars)
    in
      list_mk_forall (bound, list_mk_forall (param_encoders, mk_eq (lhs_tm, rhs_tm)))
    end

fun mk_encode_term_single target t cnst = single_pair NONE target t cnst

(* All-nullary types: constructor i encodes as the number i. *)
fun mk_encode_term_label target t cnsts =
    let
      val enc_num = get_encode_function target ``:num``
      val enc_t = get_encode_function target t
      fun clause (n, c) =
          mk_eq (mk_comb (enc_t, c), mk_comb (enc_num, numLib.term_of_int n))
    in
      list_mk_conj (map clause (enumerate 0 cnsts))
    end

(* Several constructors: each is tagged with its index. *)
fun mk_encode_term_all target t cnsts =
    let fun clause (n, c) = single_pair (SOME n) target t c
    in  list_mk_conj (map clause (enumerate 0 cnsts))
    end
in
fun mk_encode_term target t =
    let
      val base_ty = base_type t handle e => wrapException "mk_encode_term" e
      val cnsts = constructors_of base_ty
                  handle e => wrapException "mk_encode_term" e
      val is_nullary = not o can dom_rng o type_of
    in
      if all is_nullary cnsts then
        (mk_encode_term_label target base_ty cnsts
         handle e => wrapException "mk_encode_term (label)" e)
      else
        case cnsts of
          [only] => (mk_encode_term_single target base_ty only
                     handle e => wrapException "mk_encode_term (single)" e)
        | _ => (mk_encode_term_all target base_ty cnsts
                handle e => wrapException "mk_encode_term" e)
    end
end;
600
(* mk_map_term t                                                             *)
(*     Builds one clause per constructor c of t:                             *)
(*         !fs. !args. map fs (c args) = c (map_a1 a1) ... (map_an an)       *)
(*     where fs are the function arguments of t's map function, args are     *)
(*     fresh variables a, b, c, ... for the constructor's arguments, and     *)
(*     each argument is mapped with the map function for its own type. The  *)
(*     right-hand side is type-instantiated to match the left-hand side.     *)
fun mk_map_term t =
let fun arg_vars c =
        map (fn (n,ty) => mk_var(implode (base26 n),ty))
            (enumerate 0 (fst (strip_fun (type_of c))))
    val constrs = constructors_of t
    val applied = map2 (curry list_mk_comb) constrs (map arg_vars constrs)
    val map_fn = get_map_function t
    val fn_vars = snd (strip_comb map_fn)
    fun inst_eq (l,r) = mk_eq(l,inst (match_type (type_of r) (type_of l)) r)
    fun clause c =
        let val (head,args) = strip_comb c
            val lhs_tm = mk_comb(map_fn,c)
            val mapped = map (fn a => mk_comb(get_map_function (type_of a),a)) args
        in
            list_mk_forall(fn_vars,
                list_mk_forall(args,inst_eq(lhs_tm,list_imk_comb(head,mapped))))
        end
in
    list_mk_conj (map clause applied)
end handle e => wrapException "mk_map_term" e
618
(* mk_all_term t                                                             *)
(*     Builds one clause per constructor c of t:                             *)
(*         !fs. !args. all fs (c args) = all_prod (args as a tuple)          *)
(*     where all_prod is the all function for the product of the argument    *)
(*     types; a nullary constructor gets T as its right-hand side.           *)
fun mk_all_term t =
let fun arg_vars c =
        map (fn (n,ty) => mk_var(implode (base26 n),ty))
            (enumerate 0 (fst (strip_fun (type_of c))))
    val constrs = constructors_of t
    val arg_lists = map arg_vars constrs
    val applied = map2 (curry list_mk_comb) constrs arg_lists
    val all_fn = get_all_function t
    val fn_vars = snd (strip_comb all_fn)
    fun body [] = T
      | body vs = mk_comb(get_all_function (list_mk_prod (map type_of vs)),
                          list_mk_pair vs)
    fun clause vs c =
        list_mk_forall(fn_vars,
            list_mk_forall(vs,mk_eq(mk_comb(all_fn,c),body vs)))
in
    list_mk_conj (map2 clause arg_lists applied)
end handle e => wrapException "mk_all_term" e
634
(*****************************************************************************)
(* ENCODE_CONV pair_thm term                                                 *)
(*     For every conjunct of term (under its universal quantifiers),         *)
(*     repeatedly rewrites the right-hand side with pair_thm, descending     *)
(*     into the operand (RAND) after each successful rewrite.                *)
(*                                                                           *)
(*     When !debug is set, first groups the clauses by the function on their *)
(*     left-hand side and fails (mkDebugExn) if a function has clauses       *)
(*     whose encoded argument type matches num # 'a alongside clauses whose  *)
(*     type does not.                                                        *)
(*                                                                           *)
(*     Returns REFL term if nothing changes; any other failure of the        *)
(*     conversion is wrapped with wrapException.                             *)
(*****************************************************************************)
fun ENCODE_CONV pair_thm term =
let     val _ = type_trace 2 "->ENCODE_CONV\n"
        val fa_pairs = if (!debug) then
                        bucket_alist (zip       (map (repeat rator o lhs o snd o strip_forall) (strip_conj term))
                                                (map (type_of o rand o rhs o snd o strip_forall) (strip_conj term)))
                        handle _ => []
                        else []
        val t = mk_prod(numLib.num,alpha)
        val _ = case (total (first (fn (a,b) => exists (can (match_type t)) b andalso
                        exists (not o can (match_type t)) b)) fa_pairs)
                of SOME (fname,_) => raise (mkDebugExn "ENCODE_CONV"
                                        ("Function clause: " ^ term_to_string fname ^
                                         " converts to a mixture of labelled pairs and non-labelled pairs"))
                |  NONE => ()
        fun drop_all term = ((REWR_CONV pair_thm THENC RAND_CONV drop_all) ORELSEC ALL_CONV) term
in
        (* Both handlers must be alternatives of the SAME handle: writing     *)
        (* "handle UNCHANGED => REFL term handle e => ..." would attach the   *)
        (* second handler to REFL term only, letting real conversion failures *)
        (* escape unwrapped.                                                  *)
        EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV drop_all)) term
        handle UNCHANGED => REFL term
             | e => wrapException "PAIR_CONV (encode)" e
end;
654
(*****************************************************************************)
(* Conversions applied to generated detect / decode / fix definitions.       *)
(* Each works on a single (quantifier-stripped) clause  lhs = rhs  and,      *)
(* like every conversion, returns a theorem  |- clause = clause'.            *)
(*****************************************************************************)
local
(* rws l: REWRITE_CONV l applied to the condition p of a conditional         *)
(* (if p then a else b) -- RATOR twice, then RAND, reaches p.                *)
fun rws l = RATOR_CONV (RATOR_CONV (RAND_CONV (REWRITE_CONV l)))
(* DC target term                                                            *)
(*     If the rhs of the clause is a conditional whose condition p matches   *)
(*     the lhs of the pair "detect" definition, case-splits the rhs on that  *)
(*     definition's own condition xp:                                        *)
(*         rhs = if xp then rhs1 else rhs2                                   *)
(*     where rhs1 / rhs2 are rhs with its condition rewritten (pair_def,     *)
(*     COND_EXPAND) under the assumption xp / ~xp respectively.              *)
(*     Returns REFL term when the rhs is not a conditional (dest_cond fails  *)
(*     -> UNCHANGED) or its condition does not match; other failures are     *)
(*     wrapped with "DC".                                                    *)
fun DC target term =
let     val r = rhs term
        val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "detect"
        val (p,a,b) = dest_cond r handle e => raise UNCHANGED
        val (left,right) = dest_eq (snd (strip_forall (concl pair_def)))
        val (xp,_,_) = dest_cond right
in
        if can (match_term left) p then
                let     val thm1 = rws [ASSUME xp,pair_def,COND_EXPAND] (rhs term);
                        val thm2 = (rws [ASSUME (mk_neg xp),pair_def,COND_EXPAND] THENC PURE_REWRITE_CONV [COND_CLAUSES]) (rhs term);
                in
                        (* COND_CONG gives (if xp then rhs1 else rhs2) =      *)
                        (* (if xp then rhs else rhs); COND_ID collapses the    *)
                        (* right side to rhs, SYM flips it, and AP_TERM lifts  *)
                        (* it to the whole equation.                          *)
                        AP_TERM (rator term) (SYM (RIGHT_CONV_RULE (REWR_CONV COND_ID) (MATCH_MP COND_CONG (LIST_CONJ [REFL xp,DISCH_ALL (SYM thm1),DISCH_ALL (SYM thm2)]))))
                end
        else    REFL term
end handle UNCHANGED => REFL term | e => wrapException "DC" e
(* SC target term                                                            *)
(*     For a clause whose rhs is  if p then a else b : instantiates the pair *)
(*     "decode" definition at condition p, simplifies it under p with        *)
(*     COND_CLAUSES, and uses it (followed by I_THM and let_CONV) to rewrite *)
(*     a subterm of the then-branch a; b is left unchanged. Any failure      *)
(*     makes SC fail (NO_CONV), so callers wrap it in TRY_CONV.              *)
fun SC target term =
let     val (p,a,b) = dest_cond (rhs term)
        val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "decode"
        val pthm = PURE_REWRITE_RULE [COND_CLAUSES,ASSUME p] (PART_MATCH (rand o rator o rator o rhs) pair_def p)
in
        AP_TERM (rator term) (MATCH_MP COND_CONG (LIST_CONJ
                [REFL p,DISCH p (RATOR_CONV (RAND_CONV (RAND_CONV (REWR_CONV pthm) THENC PURE_REWRITE_CONV [I_THM] THENC pairLib.let_CONV)) a),DISCH (mk_neg p) (REFL b)]))
end handle e => NO_CONV term
(* FIXC target term                                                          *)
(*     Like SC, but based on the pair "fix" definition: the instantiated     *)
(*     then-branch of that definition is matched against the fix function   *)
(*     for num # 'a and used (after I_THM) to rewrite the then-branch a.     *)
(*     Has no handler of its own: failures propagate to the caller, which    *)
(*     wraps it in TRY_CONV.                                                 *)
fun FIXC target term =
let     val (p,a,b) = dest_cond (rhs term)
        val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "fix"
        val pthm = REWRITE_RULE [I_THM] (PART_MATCH (rator o lhs) (PURE_REWRITE_RULE [COND_CLAUSES,ASSUME p] (PART_MATCH (rand o rator o rator o rhs) pair_def p))
                                (get_fix_function target (mk_prod(num,alpha))))
in
        AP_TERM (rator term) (MATCH_MP COND_CONG (LIST_CONJ
                [REFL p,DISCH p (PURE_REWRITE_CONV [pthm] a),DISCH (mk_neg p) (REFL b)]))
end
in
(* DETECT_CONV target: on every conjunct (under its quantifiers) apply DC,   *)
(* then SC where it succeeds. Failures are wrapped with "DETECT_CONV".       *)
fun DETECT_CONV target term =
        (type_trace 2 "->DETECT_CONV\n" ;
        EVERY_CONJ_CONV (STRIP_QUANT_CONV (DC target THENC TRY_CONV (SC target))) term handle e => wrapException "DETECT_CONV" e)
(* DECODE_CONV target: currently the same conversion as DETECT_CONV (DC then *)
(* optional SC); only the trace text and the error label differ.             *)
(* NOTE(review): confirm this duplication is intended.                       *)
fun DECODE_CONV target term =
        (type_trace 2 "->DECODE_CONV\n" ;
        EVERY_CONJ_CONV (STRIP_QUANT_CONV (DC target THENC TRY_CONV (SC target))) term handle e => wrapException "DECODE_CONV" e)
(* FIX_CONV target: DC, optional SC, rewriting with K_THM, then optional     *)
(* FIXC, on every conjunct. Failures are wrapped with "FIX_CONV".            *)
fun FIX_CONV target term =
        (type_trace 2 "->FIX_CONV\n" ;
        EVERY_CONJ_CONV (STRIP_QUANT_CONV (DC target THENC TRY_CONV (SC target) THENC REWRITE_CONV [K_THM] THENC TRY_CONV (FIXC target))) term handle e => wrapException "FIX_CONV" e)
end;
700
(* mk_encodes target t                                                       *)
(*     Generates the encoder functions for the base type of t into target,   *)
(*     post-processing the definitions with ENCODE_CONV (using the pair      *)
(*     encode definition). Fails with a standard exception if an encoder for *)
(*     this translation already exists.                                      *)
fun mk_encodes target t =
let     fun translation () = type_to_string t ^ " --> " ^ type_to_string target
        val _ = if not (exists_coding_function target t "encode") then ()
                else raise (mkStandardExn "mk_encodes"
                        ("Encoder function for translation: " ^ translation () ^
                         " already exists."))
        val _ = type_trace 1
                ("Generating encoding function for: " ^ translation () ^ "\n")
        val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "encode"
                       handle e => wrapException "mk_encodes" e
in
        mk_coding_functions "encode" (mk_encode_term target)
                (get_encode_function target) (ENCODE_CONV pair_def)
                REFL target (base_type t)
        handle e => wrapException "mk_encodes" e
end;
718
(* The most recent (rfix, function) pair passed to CONSOLIDATE_CONV; it is    *)
(* only written there -- presumably kept so a failing call can be inspected. *)
val CONSOLIDATE_CONV_data : ((term -> term) * term) option ref = ref NONE;
720
local
(* PTAC: reduces a goal  P A C = P B D  to the two goals  A = B  and  C = D. *)
val PTAC = HO_MATCH_MP_TAC (METIS_PROVE [] ``(A = B) /\ (C = D) ==> (P A C = P B D)``) THEN CONJ_TAC;
(* AP_TAC funcs: strips matching applications from both sides of an         *)
(* equational goal (AP_TERM_TAC / AP_THM_TAC / PTAC), recursively, stopping  *)
(* as soon as the head of either side is one of funcs.                       *)
fun AP_TAC funcs (ag as (assums,goal)) =
        if (C mem funcs o repeat rator o lhs) goal orelse (C mem funcs o repeat rator o rhs) goal then ALL_TAC ag else
                (TRY ((AP_TERM_TAC ORELSE AP_THM_TAC ORELSE PTAC) THEN AP_TAC funcs)) ag
in
(*****************************************************************************)
(* CONSOLIDATE_CONV rfix function                                            *)
(*     function is a conjunction of clauses  !vs. f x = if p then a else e.  *)
(*     The else-branch e of every clause is simplified by repeatedly         *)
(*     rewriting with the map, encode, decode and fix definitions of the     *)
(*     types involved, the clauses themselves (as assumptions) and the       *)
(*     "dead" theorems, until nothing changes. The resulting equation        *)
(*     |- function = function'  is then re-proved by tactic so that it      *)
(*     carries no assumptions. rfix is applied to each else-branch to find   *)
(*     the types whose definitions are needed. Failures are wrapped with     *)
(*     "CONSOLIDATE_CONV".                                                   *)
(*****************************************************************************)
fun CONSOLIDATE_CONV rfix function =
let     val _ = type_trace 2 "->CONSOLIDATE_CONV\n"
        (* Record the arguments of this call (see CONSOLIDATE_CONV_data). *)
        val _ = CONSOLIDATE_CONV_data := SOME (rfix,function)
        val clauses = strip_conj function
        (* The else-branch of every clause. *)
        val ends = map ((fn (a,b,c) => c) o dest_cond o rhs o snd o strip_forall) clauses
        (* Type of the argument of the first clause's left-hand side. *)
        val target = (type_of o rand o lhs o snd o strip_forall o hd) clauses
        val dead_thm = #bottom_thm (get_translation_scheme target)
        (* Replace any occurrence of target among a type's arguments by a     *)
        (* fresh type variable.                                               *)
        val subs = (mk_type o (I ## map (fn a => if a = target then gen_tyvar() else a)) o dest_type)
        (* num # target plus every non-variable type reachable from the       *)
        (* (generalised) types of rfix applied to the else-branches.          *)
        val ts = (mk_prod(num,target))::mk_set (filter (not o is_vartype)
                (flatten (mapfilter (map snd o reachable_graph (uncurried_subtypes) o subs o type_of o rfix) ends)))
        (* Map and encode definitions are generated on demand if missing. *)
        val maps = map (fn x => (generate_source_function "map" (base_type x) ; C get_source_function_def "map" x)) ts
        val encs = map (fn x => (generate_coding_function target "encode" (base_type x) ;
                                get_coding_function_def target x "encode")) ts
        (* Decode / fix definitions only for the types where they exist. *)
        val decs = flatten (map CONJUNCTS (mapfilter (C (get_coding_function_def target) "decode") ts))
        val fixs = map (REWRITE_RULE encs) (flatten (map CONJUNCTS (mapfilter (C (get_coding_function_def target) "fix") ts)))
        val deads = dead_thm::mapfilter (generate_coding_theorem target "detect_dead" o base_type) ts;
        (* The clauses themselves, assumed so they can be used as rewrites. *)
        val hos = map ASSUME clauses
        val results = map (fn term => REPEATC (CHANGED_CONV (REWRITE_CONV maps THENC REWRITE_CONV encs THENC
                                ONCE_REWRITE_CONV decs THENC ONCE_REWRITE_CONV fixs
                                THENC ONCE_REWRITE_CONV hos THENC REWRITE_CONV deads)) term
                                handle UNCHANGED => REFL term) ends
        (* |- function = function' (with the clause assumptions in hos). *)
        val full = ONCE_DEPTH_CONV (FIRST_CONV (map REWR_CONV results)) function
        (* Heads at which AP_TAC must stop decomposing equations. *)
        val funcs = map (repeat rator o lhs o snd o strip_forall) (clauses @ map concl decs @ map concl fixs);
in
        (* Re-prove the same statement without assumptions: split the         *)
        (* equivalence, then repeatedly unfold assumptions / fix / decode     *)
        (* definitions on each side and strip matching structure.            *)
        prove(concl full,
        REWRITE_TAC encs THEN EQ_TAC THEN REPEAT STRIP_TAC THEN
        REPEAT (
                FIRST [FIRST_ASSUM (CONV_TAC o LAND_CONV o REWR_CONV),CONV_TAC (LAND_CONV (FIRST_CONV (map REWR_CONV (fixs @ decs)))),ALL_TAC] THEN
                FIRST [FIRST_ASSUM (CONV_TAC o RAND_CONV o REWR_CONV),CONV_TAC (RAND_CONV (FIRST_CONV (map REWR_CONV (fixs @ decs)))),ALL_TAC] THEN
                TRY (MATCH_MP_TAC COND_CONG THEN REPEAT STRIP_TAC) THEN REWRITE_TAC deads THEN AP_TAC funcs THEN
                REWRITE_TAC maps THEN REWRITE_TAC encs THEN REWRITE_TAC (dead_thm::deads) THEN REWRITE_TAC (mapfilter TypeBase.one_one_of ts) THEN REPEAT CONJ_TAC))
end     handle e => wrapException "CONSOLIDATE_CONV" e
end
760
(* mk_decodes target t                                                       *)
(*     Generates the decoder functions from target back to the base type of  *)
(*     t, post-processed by DECODE_CONV and consolidated by                  *)
(*     CONSOLIDATE_CONV rand. Fails with a standard exception if a decoder   *)
(*     for this translation already exists.                                  *)
fun mk_decodes target t =
let     fun translation () = type_to_string t ^ " --> " ^ type_to_string target
        val _ = if not (exists_coding_function target t "decode") then ()
                else raise (mkStandardExn "mk_decodes"
                        ("Decoder function for translation: " ^ translation () ^
                         " already exists."))
        val _ = type_trace 1
                ("Generating decoding function for: " ^ translation () ^ "\n")
in
        mk_target_functions "decode" (mk_decode_term target)
                (get_decode_function target) (DECODE_CONV target)
                (CONSOLIDATE_CONV rand) target (base_type t)
        handle e => wrapException "mk_decodes" e
end;
778
(* mk_detects target t                                                       *)
(*     Generates the detector functions (predicates over target) for the     *)
(*     base type of t, post-processed by DETECT_CONV. Fails with a standard  *)
(*     exception if a detector for this translation already exists.          *)
fun mk_detects target t =
let     fun translation () = type_to_string t ^ " --> " ^ type_to_string target
        val _ = if not (exists_coding_function target t "detect") then ()
                else raise (mkStandardExn "mk_detects"
                        ("Detector function for translation: " ^ translation () ^
                         " already exists."))
        val _ = type_trace 1
                ("Generating detecting function for: " ^ translation () ^ "\n")
in
        mk_target_functions "detect" (mk_detect_term target)
                (get_detect_function target) (DETECT_CONV target)
                REFL target (base_type t)
        handle e => wrapException "mk_detects" e
end;
796
(* mk_maps t                                                                 *)
(*     Generates the map functions for the base type of t (no              *)
(*     post-processing: both conversions are REFL). Fails with a standard   *)
(*     exception if a map function already exists; other failures are       *)
(*     wrapped with "mk_maps", like the other mk_*s generators.              *)
fun mk_maps t =
let     val _ = if exists_source_function t "map" then
                raise (mkStandardExn "mk_maps"
                        ("Map function for type: " ^ type_to_string t ^ " already exists.")) else ()
        val _ = type_trace 1
                ("Generating map function for: " ^ type_to_string t ^ "\n")
in
        mk_source_functions
                        "map"
                        mk_map_term
                        get_map_function
                        REFL
                        REFL
                        (base_type t) handle e => wrapException "mk_maps" e
end;
812
(* mk_alls t                                                                 *)
(*     Generates the "all" predicates for the base type of t. The first      *)
(*     conversion rewrites each clause's right-hand side with the pair all   *)
(*     definition (REFL when nothing changes). Fails with a standard         *)
(*     exception if an all function already exists; other failures are      *)
(*     wrapped with "mk_alls", like the other mk_*s generators.              *)
fun mk_alls t =
let     val _ = if exists_source_function t "all" then
                raise (mkStandardExn "mk_alls"
                        ("All function for type: " ^ type_to_string t ^ " already exists.")) else ()
        val _ = type_trace 1
                ("Generating all function for: " ^ type_to_string t ^ "\n")
in
        mk_source_functions
                        "all"
                        mk_all_term
                        get_all_function
                        (fn x => (EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV
                                (PURE_REWRITE_CONV [get_source_function_def (mk_prod(alpha,beta)) "all"])))) x
                                handle UNCHANGED => REFL x)
                        REFL
                        (base_type t) handle e => wrapException "mk_alls" e
end;
830
(* mk_fixs target t                                                          *)
(*     Generates the fix functions over target for the base type of t,       *)
(*     post-processed by FIX_CONV and consolidated by CONSOLIDATE_CONV rand. *)
(*     Fails with a standard exception if a fix function for this            *)
(*     translation already exists.                                           *)
fun mk_fixs target t =
let     fun translation () = type_to_string t ^ " --> " ^ type_to_string target
        val _ = if not (exists_coding_function target t "fix") then ()
                else raise (mkStandardExn "mk_fixs"
                        ("fix function for translation: " ^ translation () ^
                         " already exists."))
        val _ = type_trace 1
                ("Generating fix function for: " ^ translation () ^ "\n")
in
        mk_target_functions "fix" (mk_fix_term target)
                (get_fix_function target) (FIX_CONV target)
                (CONSOLIDATE_CONV rand) target (base_type t)
        handle e => wrapException "mk_fixs" e
end;
848
849(*****************************************************************************)
850(* Generate conclusions for the various goals to be proven:                  *)
851(*                                                                           *)
852(* mk_encode_decode_map_conc  : hol_type -> hol_type -> term                 *)
853(* mk_encode_detect_all_conc  : hol_type -> hol_type -> term                 *)
854(* mk_decode_encode_fix_conc  : hol_type -> hol_type -> term                 *)
855(* mk_encode_map_encode_conc  : hol_type -> hol_type -> term                 *)
856(* mk_map_compose_conc        : hol_type -> term                             *)
857(* mk_map_id_conc             : hol_type -> term                             *)
858(* mk_all_id_conc             : hol_type -> term                             *)
859(* mk_fix_id_conc             : hol_type -> hol_type -> term                 *)
860(* mk_general_detect_conc     : hol_type -> hol_type -> term                 *)
861(*                                                                           *)
862(* mk_encode_decode_conc      : hol_type -> hol_type -> term                 *)
863(* mk_decode_encode_conc      : hol_type -> hol_type -> term                 *)
864(* mk_encode_detect_conc      : hol_type -> hol_type -> term                 *)
865(*                                                                           *)
866(*                                                                           *)
867(*     Make the conclusions for the various theorems:                        *)
868(*     ?- (decode f o encode g) = map (f o g)                                *)
869(*     ?- (encode f o decode g) = fix (f o g)                                *)
870(*     ?- (detect f o encode g) = all (f o g)                                *)
871(*                                                                           *)
872(*     ?- (encode f o map g) = encode (f o g)                                *)
873(*     ?- (map f o map g) = map (f o g)                                      *)
874(*                                                                           *)
875(*     ?- map I = I                                                          *)
876(*     ?- all (K T) = K T                                                    *)
877(*     ?- (!x. f x = x) ==> (!x. fix f x = x)                                *)
878(*                                                                           *)
879(*     ?- !x. detect f g x ==> detect (K T) (K T) x                          *)
880(*                                                                           *)
881(*     ?- (!x. f (g x) = x) ==> !x. decode f (encode g x) = x                *)
882(*     ?- (!x. p x ==> g (f x) = x) ==>                                      *)
883(*                    !x. detect p x ==> encode g (decode f x) = x           *)
884(*     ?- (!x. p (g x)) ==> !x. detect p (encode g x)                        *)
885(*                                                                           *)
886(*****************************************************************************)
887
(* get_hfuns term: the heads and non-application leaves of term, obtained by *)
(* recursively stripping every application (left to right).                  *)
fun get_hfuns term =
    if not (is_comb term) then [term]
    else let val (head,args) = strip_comb term
         in  flatten (map get_hfuns (head :: args)) end;
892
(* type_vars_avoiding_itself function t: the type variables of t, minus the  *)
(* first type argument of every subterm of function that matches the_value. *)
fun type_vars_avoiding_itself function t =
let val value_terms = filter (can (match_term the_value)) (get_hfuns function)
    val value_args = map (hd o snd o dest_type o type_of) value_terms
in
    set_diff (type_vars t) value_args
end;
897
(* check_function gf t                                                       *)
(*     Returns the function term gf t, after checking that it has no more    *)
(*     distinct free variables plus polymorphic the_value subterms than t    *)
(*     has type variables; otherwise raises a debug exception.               *)
fun check_function gf t =
let val term = gf t
    val pieces = get_hfuns term
    val free_fns = filter is_var pieces
    val poly_values = filter (fn x => can (match_term the_value) x andalso
                                      polymorphic (type_of x)) pieces
    val distinct = length (mk_set (free_fns @ poly_values))
in
    if distinct <= length (type_vars t) then term
    else raise (mkDebugExn "check_function"
                  ("The function term: " ^ term_to_string term ^
                  "\ncontains free-variables not derived from the type: " ^
                  type_to_string t))
end;
912
local
(* Wrap any failure with this function's name. *)
fun wrap e = wrapException "mk_encode_decode_map_conc" e
(* Report that instantiating the type variables of the named function failed. *)
fun err s = raise (mkDebugExn "mk_encode_decode_map_conc"
("Unable to correctly instantiate type variables in " ^ s ^ " function"))
in
(*****************************************************************************)
(* mk_encode_decode_map_conc target t                                        *)
(*     Builds  !enc_vars dec_vars. dec' o enc' = map' [f_i |-> d_i o e_i]    *)
(*     where enc' / dec' are the encode / decode functions for t with their  *)
(*     type variables renamed apart, and map' is the map function for t      *)
(*     instantiated to (domain of enc') -> (range of dec').                  *)
(*****************************************************************************)
fun mk_encode_decode_map_conc target t =
let val enc = check_function (get_encode_function target) t handle e => wrap e
    val dec = check_function (get_decode_function target) t handle e => wrap e
    val map_term = check_function get_map_function t  handle e => wrap e
    (* Give the map term fresh type variables so it cannot clash with enc'/dec'. *)
    val safe_map_term = inst (map (fn v => v |-> gen_tyvar())
                                  (type_vars_in_term map_term)) map_term;
    val tvs = type_vars_avoiding_itself enc t
    (* Type variables of t that are removed by type_vars_avoiding_itself. *)
    val values = set_diff (type_vars t) tvs

    (* Rename the i-th type variable in types to 'base26(i + start). *)
    fun inst_from term start types =
        inst (map (fn (a,b) =>
                  b |-> mk_vartype (String.implode(#"'" :: base26 (a + start))))
                  (enumerate 0 types)) term;
    (* enc': tvs named from index 0, values named after them. *)
    val enc' = inst_from (inst_from enc 0 tvs) (length tvs) values
               handle e => err "encode"
    (* dec': tvs named after all of enc's names, values named as in enc'. *)
    val dec' = inst_from (inst_from dec (length tvs + length values) tvs)
                         (length tvs) values handle e => err "decode"
    val map' = inst (match_type (type_of safe_map_term)
                                (fst (dom_rng (type_of enc')) -->
                                 snd (dom_rng (type_of dec'))))
                        safe_map_term handle e => err "map";

    val enc_vars = free_vars_lr enc'
    val dec_vars = free_vars_lr dec'
    (* The i-th free variable of map' becomes  dec_var_i o enc_var_i. *)
    val sub = map2 (curry op|->)
                   (free_vars_lr map')
                   (map2 (curry combinSyntax.mk_o) dec_vars enc_vars)
              handle e => wrap e
in
    list_mk_forall(enc_vars,
    list_mk_forall(dec_vars,mk_eq(combinSyntax.mk_o(dec',enc'),subst sub map')))
    handle e => wrap e
end
end
952
local
(* w s e: re-raise e wrapped with the function name s. *)
fun w s e = wrapException s e
(*****************************************************************************)
(* mk_ring_conc left func1 func2                                             *)
(*     Builds  !vs1 vs2. func1 o func2 = f'  where vs1 / vs2 are the free    *)
(*     variables (left to right) of func1 / func2, and f' is func1 (if left) *)
(*     or func2 (otherwise) with its i-th free variable replaced by the      *)
(*     composition (i-th of vs1) o (i-th of vs2). f' is first type-         *)
(*     instantiated so that each replaced variable has its replacement's    *)
(*     type.                                                                 *)
(*****************************************************************************)
fun mk_ring_conc left func1 func2 =
let val sub = map2 (curry op|->)
                   (free_vars_lr (if left then func1 else func2))
                   (map2 (curry combinSyntax.mk_o) (free_vars_lr func1) (free_vars_lr func2)) handle e => w "mk_ring_conc" e
        val tsubs = map (fn {redex,residue} => match_type (type_of redex) (type_of residue)) sub
                handle e => w "mk_ring_conc" e
        (* Apply each type substitution in turn. *)
        val ins = C (foldl (uncurry inst)) tsubs  handle e => w "mk_ring_conc" e
        val sub' = map (fn {redex,residue} => ins redex |-> residue) sub  handle e => w "mk_ring_conc" e
in
        list_mk_forall(free_vars_lr func1,
        list_mk_forall(free_vars_lr func2,
                mk_eq((curry combinSyntax.mk_o) func1 func2,subst sub' (ins (if left then func1 else func2)))))
                 handle e => w "mk_ring_conc" e
end
in
(* mk_encode_map_encode_conc target t:  encode f o map g = encode (f o g)    *)
fun mk_encode_map_encode_conc target t =
let     val encf = check_function (get_encode_function target) t handle e => w "mk_encode_map_encode_conc" e
        val mapf = check_function get_map_function t handle e => w "mk_encode_map_encode_conc" e
        val encf' = inst (match_type (fst (dom_rng (type_of encf))) (snd (dom_rng (type_of mapf)))) encf
                         handle e => w "mk_encode_map_encode_conc" e
in
        mk_ring_conc true encf' mapf  handle e => w "mk_encode_map_encode_conc" e
end
(* mk_map_compose_conc t:  map f o map g = map (f o g)                       *)
(* The outer map gets its type variables renamed with a 'map prefix so the  *)
(* two copies do not share type variables. Errors are reported under this   *)
(* function's own name.                                                      *)
fun mk_map_compose_conc t =
let     val map_inner = check_function get_map_function t handle e => w "mk_map_compose_conc" e
        val map_outer = inst (map (fn x => x |-> (mk_vartype o curry op^ "'map" o get_type_string) x)
                        (type_vars (type_of map_inner))) map_inner handle e => w "mk_map_compose_conc" e
        val map_inner' = inst (match_type (snd (dom_rng (type_of map_inner))) (fst (dom_rng (type_of map_outer)))) map_inner
                         handle e => w "mk_map_compose_conc" e
in
        mk_ring_conc false map_outer map_inner' handle e => w "mk_map_compose_conc" e
end
end;
988
(* mk_decode_encode_fix_conc target t                                        *)
(*     Builds  !enc_vars dec_vars. encode o decode = fix [f_i |-> e_i o d_i] *)
(*     for the encode, decode and fix functions of t.                        *)
fun mk_decode_encode_fix_conc target t =
let     val enc = check_function (get_encode_function target) t
        val dec = check_function (get_decode_function target) t
        val fix_fn = check_function (get_fix_function target) t
        val enc_vars = free_vars_lr enc
        val dec_vars = free_vars_lr dec
        val fix_vars = free_vars_lr fix_fn
        val composed = map2 (curry combinSyntax.mk_o) enc_vars dec_vars
        val sub = map2 (curry op|->) fix_vars composed
        val goal = mk_eq(combinSyntax.mk_o(enc,dec),subst sub fix_fn)
in
        list_mk_forall(enc_vars,list_mk_forall(dec_vars,goal))
end
1002
(* mk_encode_detect_all_conc target t                                        *)
(*     Builds  !det_vars enc_vars. detect o encode = all [f_i |-> p_i o e_i] *)
(*     for the encode, detect and all functions of t.                        *)
(*     (An unused lookup of the bool decoder has been removed: it could only *)
(*     make this function fail spuriously; all_fn no longer shadows Lib.all.)*)
fun mk_encode_detect_all_conc target t =
let     val enc = check_function (get_encode_function target) t
        val det = check_function (get_detect_function target) t
        val all_fn = check_function get_all_function t

        val enc_vars = free_vars_lr enc
        val det_vars = free_vars_lr det
        val sub = map2 (curry op|->) (free_vars_lr all_fn) (map2 (curry combinSyntax.mk_o) det_vars enc_vars)
in
        list_mk_forall(det_vars,
        list_mk_forall(enc_vars,mk_eq((curry combinSyntax.mk_o) det enc,subst sub all_fn)))
end;
1016
(* mk_map_id_conc t                                                          *)
(*     Builds  map I ... I = I : every function variable of t's map function *)
(*     is made an endomorphism on its domain type and replaced by I.         *)
fun mk_map_id_conc t =
let     val map_term = check_function get_map_function t
        val fvs = free_vars map_term
        fun endo ty = let val (d,_) = dom_rng ty in d --> d end
        fun collapse fv = let val (d,r) = dom_rng (type_of fv) in r |-> d end
        val map_inst = inst (map collapse fvs) map_term
        fun to_id fv =
            let val (name,ty) = dest_var fv
            in  mk_var(name,endo ty) |-> mk_const("I",endo (type_of fv)) end
        val id_sub = map to_id fvs
in
        mk_eq(subst id_sub map_inst,mk_const("I",type_of map_inst))
end
1027
(* mk_all_id_conc t                                                          *)
(*     Builds  all (K T) ... (K T) = K T : every free variable of t's all    *)
(*     function is replaced by K T at its domain type.                       *)
fun mk_all_id_conc t =
let     val all_term = check_function get_all_function t
        val fvs = free_vars all_term
        fun k_true ty = mk_comb(mk_const("K",bool --> (ty --> bool)),T)
        val sub = map (fn v => v |-> k_true (fst (dom_rng (type_of v)))) fvs
in
        mk_eq(subst sub all_term,k_true t)
end
1036
(* mk_fix_id_conc target t                                                   *)
(*     Builds  !x. detect x ==> fix x = x  for t, preceded (as an            *)
(*     implication) by the same statement for each relevant type variable   *)
(*     of t other than t itself.                                             *)
fun mk_fix_id_conc target t =
let val fix_term = check_function (get_fix_function target) t
    val det_term = check_function (get_detect_function target) t
    val tvs = type_vars_avoiding_itself fix_term t
    val hyps = map (mk_fix_id_conc target) (set_diff tvs [t])
    val x = mk_var("x",target)
    val conc = mk_forall(x,mk_imp(mk_comb(det_term,x),
                                  mk_eq(mk_comb(fix_term,x),x)))
in
    case hyps of
      [] => conc
    | _ => mk_imp(list_mk_conj hyps,conc)
end;
1049
(* mk_general_detect_conc target t                                           *)
(*     Builds  !x fvs. detect_t x ==> detect_t' x  where t' is t with its    *)
(*     relevant type variables replaced by target.                           *)
fun mk_general_detect_conc target t =
let val det = check_function (get_detect_function target) t
    val to_target = map (fn v => v |-> target) (type_vars_avoiding_itself det t)
    val det_general = check_function (get_detect_function target)
                                     (type_subst to_target t)
    val x = mk_var("x",target)
    val body = list_mk_forall(free_vars det,
                              mk_imp(mk_comb(det,x),mk_comb(det_general,x)))
in
    mk_forall(x,body)
end;
1060
local
fun wrap e = wrapException "mk_encode_decode_conc" e
in
(* mk_encode_decode_conc target t                                            *)
(*     Builds  !x. decode (encode x) = x  for t, quantified over the encode  *)
(*     and decode functions of t's relevant type variables. Non-variable     *)
(*     types get the corresponding statements for those type variables as   *)
(*     antecedents; a type variable t yields  conc ==> conc.                 *)
fun mk_encode_decode_conc target t =
let val enc = check_function (get_encode_function target) t handle e => wrap e
    val dec = check_function (get_decode_function target) t handle e => wrap e
    val x = mk_var("x",t)
    val conc = mk_forall(x,mk_eq(mk_comb(dec,mk_comb(enc,x)),x))
               handle e => wrap e
    val tvs = type_vars_avoiding_itself enc t
    fun sub_conc ty =
        snd (dest_imp (snd (strip_forall (mk_encode_decode_conc target ty))))
    val ante = map sub_conc (set_diff tvs [t])
in
    (list_mk_forall(map (get_encode_function target) tvs,
     list_mk_forall(map (get_decode_function target) tvs,
       if is_vartype t then mk_imp(conc,conc)
       else case ante of
              [] => conc
            | _ => mk_imp(list_mk_conj ante,conc))))
    handle e => wrap e
end
end
1083
local
fun wrap e = wrapException "mk_decode_encode_conc" e
in
(* mk_decode_encode_conc target t                                            *)
(*     Builds  !x. detect x ==> encode (decode x) = x  for t, quantified     *)
(*     over the encode, decode and detect functions of t's relevant type     *)
(*     variables, with the same statements for those variables as            *)
(*     antecedents (a type variable t yields  conc ==> conc).                *)
fun mk_decode_encode_conc target t =
let val enc = check_function (get_encode_function target) t handle e => wrap e
    val det = check_function (get_detect_function target) t handle e => wrap e
    val dec = check_function (get_decode_function target) t handle e => wrap e
    val x = mk_var("x",target)
    val conc = mk_forall(x,mk_imp(mk_comb(det,x),
                                  mk_eq(mk_comb(enc,mk_comb(dec,x)),x)))
               handle e => wrap e
    val tvs = type_vars_avoiding_itself enc t
    fun sub_conc ty =
        snd (dest_imp (snd (strip_forall (mk_decode_encode_conc target ty))))
    val ante = map sub_conc (set_diff tvs [t])
in
    (list_mk_forall(map (get_encode_function target) tvs,
     list_mk_forall(map (get_decode_function target) tvs,
     list_mk_forall(map (get_detect_function target) tvs,
       if is_vartype t then mk_imp(conc,conc)
       else case ante of
              [] => conc
            | _ => mk_imp(list_mk_conj ante,conc)))))
    handle e => wrap e
end
end
1110
local
fun wrap e = wrapException "mk_encode_detect_conc" e
in
(* mk_encode_detect_conc target t                                            *)
(*     Builds  !x. detect (encode x)  for t, quantified over the encode and  *)
(*     detect functions of t's relevant type variables, with the same        *)
(*     statements for those variables as antecedents (a type variable t      *)
(*     yields  conc ==> conc).                                               *)
fun mk_encode_detect_conc target t =
let val enc = check_function (get_encode_function target) t handle e => wrap e
    val det = check_function (get_detect_function target) t handle e => wrap e
    val x = mk_var("x",t)
    val conc = mk_forall(x,mk_comb(det,mk_comb(enc,x))) handle e => wrap e
    val tvs = type_vars_avoiding_itself enc t
    fun sub_conc ty =
        snd (dest_imp (snd (strip_forall (mk_encode_detect_conc target ty))))
    val ante = map sub_conc (set_diff tvs [t])
in
    (list_mk_forall(map (get_encode_function target) tvs,
     list_mk_forall(map (get_detect_function target) tvs,
       if is_vartype t then mk_imp(conc,conc)
       else case ante of
              [] => conc
            | _ => mk_imp(list_mk_conj ante,conc))))
    handle e => wrap e
end
end
1133
1134(*****************************************************************************)
1135(* Rules to generate instantiated theorems from base-type theorems:          *)
1136(*                                                                           *)
1137(* FULL_ENCODE_DECODE_MAP_THM : hol_type -> hol_type -> thm                  *)
1138(* FULL_ENCODE_DETECT_ALL_THM : hol_type -> hol_type -> thm                  *)
1139(* FULL_ENCODE_MAP_ENCODE_THM : hol_type -> hol_type -> thm                  *)
1140(* FULL_DECODE_ENCODE_FIX_THM : hol_type -> hol_type -> thm                  *)
1141(* FULL_MAP_COMPOSE_THM : hol_type -> hol_type -> thm                        *)
1142(*     Create the theorem, eg:                                               *)
(*            |- encode encode o map map = encode encode                     *)
(*     from:                                                                 *)
(*            |- !f g. encode f o map g = encode (f o g)                     *)
(*     and    |- encode o map = encode                                       *)
1147(*                                                                           *)
1148(* FULL_MAP_ID_THM : hol_type -> thm                                         *)
1149(* FULL_ALL_ID_THM : hol_type -> thm                                         *)
1150(*     Create the theorem, eg:                                               *)
1151(*            |- map (map I) (map I) = I                                     *)
1152(*     from:  |- map I I = I    |- map I = I    |- map I = I                 *)
1153(*                                                                           *)
1154(* FULL_FIX_ID_THM : hol_type -> hol_type -> thm                             *)
1155(*     Create the theorem, eg:                                               *)
1156(*            |- fix fix x = x                                               *)
1157(*     from:  |- (!x. f x = x) ==> (!x. fix f x = x)     |- !x. fix x = x    *)
1158(*                                                                           *)
1159(*                                                                           *)
1160(* FULL_ENCODE_DECODE_THM : hol_type -> hol_type -> thm                      *)
1161(*     Create the theorem, eg:                                               *)
1162(*            |- !x. decode decode (encode encode x) = x                     *)
1163(*                                                                           *)
1164(* FULL_DECODE_ENCODE_THM : hol_type -> hol_type -> thm                      *)
1165(*     Create the theorem, eg:                                               *)
(*            |- !x. detect detect x ==> encode encode (decode decode x) = x *)
1167(*                                                                           *)
1168(* FULL_ENCODE_DETECT_THM : hol_type -> hol_type -> thm                      *)
1169(*     Create the theorem, eg:                                               *)
1170(*            |- !x. detect detect (encode encode x)                         *)
1171(*                                                                           *)
1172(*****************************************************************************)
1173
(* wrap_full s t e: re-raise e wrapped with the label  s(<t as a string>).  *)
fun wrap_full s t e =
    wrapException (String.concat [s, "(", type_to_string t, ")"]) e
1176
(* get_sub_types basetype t: the non-variable types that basetype's type     *)
(* variables must be instantiated to in order to obtain t.                   *)
fun get_sub_types basetype t =
let val residues = map #residue (match_type basetype t)
in  filter (not o is_vartype) residues
end
1180
local
(* EMPTY_RING gconc name target t thms: prove the conclusion gconc target t  *)
(* outright by rewriting it to T with thms. (name is unused.)                *)
fun EMPTY_RING gconc name target t thms =
let val conc = gconc target t
in
    EQT_ELIM (REWRITE_CONV thms conc)
end
(* RING_MATCH_THM gconc name target t                                        *)
(*     Instantiates the stored theorem "name" for the most precise base type *)
(*     of t that has one (or t itself when t = target or no such type is     *)
(*     found) so that its lhs matches the lhs of gconc target t; the rhs is  *)
(*     then rewritten with the same construction applied recursively to the  *)
(*     non-variable types the base type is instantiated with.                *)
fun RING_MATCH_THM gconc name target t =
let val basetype = if t = target then t else
                   most_precise_type
                   (C (exists_coding_theorem_precise target) name) t
                   handle _ => t
    val thm = SPEC_ALL (generate_coding_theorem target name basetype)
    val conc = gconc target t
    val thm' = PART_MATCH lhs thm (lhs (snd (strip_forall conc)))
    val sub_thms = map (RING_MATCH_THM gconc name target)
                       (get_sub_types basetype t)
in
    RIGHT_CONV_RULE (PURE_REWRITE_CONV sub_thms) thm'
end;
(* CHECK_RING: for a monomorphic t first try EMPTY_RING, falling back to     *)
(* RING_MATCH_THM; for a polymorphic t use RING_MATCH_THM directly.          *)
fun CHECK_RING gconc name target t thms =
    if null (type_vars t)
       then (EMPTY_RING gconc name target t thms handle _ =>
             RING_MATCH_THM gconc name target t)
       else RING_MATCH_THM gconc name target t
in
(* |- decode f o encode g = map (f o g), instantiated for t. *)
fun FULL_ENCODE_DECODE_MAP_THM target t =
    if target = t
       then CONJUNCT1 (ISPEC (mk_const("I",target --> target)) I_o_ID)
       else RING_MATCH_THM mk_encode_decode_map_conc
                           "encode_decode_map" target t
            handle e => wrap_full "FULL_ENCODE_DECODE_MAP_THM" t e
(* |- detect f o encode g = all (f o g), instantiated for t. *)
fun FULL_ENCODE_DETECT_ALL_THM target t =
    if target = t
       then CONJUNCT2 (ISPEC
                      (mk_comb(mk_const("K",bool --> target --> bool),T))
                      I_o_ID)
       else RING_MATCH_THM mk_encode_detect_all_conc
                           "encode_detect_all" target t
        handle e => wrap_full "FULL_ENCODE_DETECT_ALL_THM" t e
(* |- encode f o map g = encode (f o g), instantiated for t. *)
fun FULL_ENCODE_MAP_ENCODE_THM target t =
    if target = t
       then FULL_ENCODE_DECODE_MAP_THM target t
       else
        CHECK_RING mk_encode_map_encode_conc "encode_map_encode" target t
                   [I_o_ID]
        handle e => wrap_full "FULL_ENCODE_MAP_ENCODE_THM" t e
(* |- encode f o decode g = fix (f o g), instantiated for t. Failures are    *)
(* wrapped with wrap_full, like the sibling FULL_*_THM functions.            *)
fun FULL_DECODE_ENCODE_FIX_THM target t =
    if target = t
       then FULL_ENCODE_DECODE_MAP_THM target t
       else
        RING_MATCH_THM mk_decode_encode_fix_conc "decode_encode_fix" target t
        handle e => wrap_full "FULL_DECODE_ENCODE_FIX_THM" t e;
end
1233
(* FULL_MAP_COMPOSE_THM t                                                  *)
(*   Proves the map_compose conclusion for t (mk_map_compose_conc t).      *)
(*   It takes the map_compose source theorem for the base type chosen by   *)
(*   most_precise_type (t itself if none is found) and matches it against  *)
(*   the goal.  The theorem's right-hand side is then rewritten with the   *)
(*   theorems obtained by recursing on each non-variable sub-type.         *)
(*   Failures are wrapped with wrap_full.                                  *)
fun FULL_MAP_COMPOSE_THM t =
let val basetype = most_precise_type
                   (C exists_source_theorem_precise "map_compose") t
                   handle e => t;
    val thm = SPEC_ALL (generate_source_theorem "map_compose" basetype)
    val conc = mk_map_compose_conc t
    (* Instantiate thm so that its left-hand side is that of the goal. *)
    val thm' = PART_MATCH lhs thm (lhs (snd (strip_forall conc)))
    (* Theorems for the sub-types, used to rewrite the right-hand side. *)
    val sub_thms = map FULL_MAP_COMPOSE_THM
                       (get_sub_types basetype t)
in
    RIGHT_CONV_RULE (PURE_REWRITE_CONV sub_thms) thm'
end handle e => wrap_full "FULL_MAP_COMPOSE_THM" t e
1246
local
(* FMIDT getf t tname ename mk_const mk_conc                               *)
(*   Shared worker behind FULL_MAP_ID_THM and FULL_ALL_ID_THM.             *)
(*     getf     - gives the function term (map or all) for a type          *)
(*     tname    - name of the source theorem to generate                   *)
(*     ename    - label used when wrapping exceptions                      *)
(*     mk_const - builds the identity-like constant for a type; this       *)
(*                parameter shadows the global mk_const in FMIDT/FMIDT'    *)
(*     mk_conc  - builds the goal conclusion for a type                    *)
(*   The generated theorem for the base type rewrites the right operand    *)
(*   of the goal right-to-left.  The result is simplified with the         *)
(*   non-reflexive sub-type theorems, and bool_EQ_CONV closes it.          *)
fun FMIDT getf t tname ename mk_const mk_conc =
let val basetype = most_precise_type (C exists_source_theorem_precise tname) t
                   handle _ => t
    val thm = SPEC_ALL (generate_source_theorem tname basetype)
    (* NOTE(review): thm' and left do not contribute to the result, but    *)
    (* they are still evaluated, so a failure here aborts via the handler. *)
    val thm' = INST_TYPE (match_type (fst (dom_rng
                     (type_of (lhs (concl thm))))) t) thm
    val left = lhs (concl thm')
    val subtypes = get_sub_types basetype t
    (* get_sub_types already drops type variables, so the NONE case is     *)
    (* only defensive.                                                     *)
    val sub_thms = map (fn x =>
             if is_vartype x
                then NONE
                else SOME (FMIDT' getf x tname ename mk_const mk_conc))
             subtypes

    val conc = mk_conc t
    val thm1 = RAND_CONV (REWR_CONV (GSYM thm)) conc
    (* Keep only non-reflexive theorems.  NONE entries survive here        *)
    (* (valOf raises, so the handler yields true); mapfilter drops them.   *)
    val sub_thms_filtered =
        filter (fn x => (not o op= o dest_eq o concl o valOf) x
                        handle _ => true) sub_thms
    val thm2 = RIGHT_CONV_RULE
                   (REWRITE_CONV (mapfilter Option.valOf sub_thms_filtered))
                   thm1
in
    CONV_RULE bool_EQ_CONV thm2
end handle e => wrap_full ename t e
(* FMIDT': when getf t already matches the constant that mk_const builds   *)
(* for type 'a -> 'a, the result is just REFL of it.  Otherwise it         *)
(* performs the full FMIDT proof.                                          *)
and FMIDT' getf t tname ename mk_const mk_conc =
    if can (match_term (mk_const(alpha --> alpha))) (getf t)
       then REFL (getf t)
       else FMIDT getf t tname ename mk_const mk_conc
in
(* FULL_MAP_ID_THM t : map_id theorem for t; identity constant is I.       *)
fun FULL_MAP_ID_THM t =
    FMIDT' (check_function get_map_function) t "map_id" "FULL_MAP_ID_THM"
           (curry mk_const "I") mk_map_id_conc
(* FULL_ALL_ID_THM t : all_id theorem for t; identity constant is K T.     *)
fun FULL_ALL_ID_THM t =
    FMIDT' (check_function get_all_function) t "all_id" "FULL_ALL_ID_THM"
           (fn t => mk_comb(mk_const("K",bool --> fst (dom_rng t) --> bool),T))
           mk_all_id_conc
end
1286
(* FULL_FIX_ID_THM target t                                                *)
(*   Proves mk_fix_id_conc target t in four steps:                         *)
(*   1. Instantiate the fix_id coding theorem for the chosen base type     *)
(*      (t itself when t = target or no better type is found) so that it   *)
(*      matches the goal.                                                  *)
(*   2. Undischarge its antecedents.                                       *)
(*   3. Prove those hypotheses with theorems obtained recursively for the  *)
(*      non-variable sub-types.                                            *)
(*   4. Re-discharge the fix_id conclusions for those type variables of t  *)
(*      that are not the type argument of a the_value argument.            *)
(*   Failures are wrapped with wrap_full.                                  *)
fun FULL_FIX_ID_THM target t =
let fun wrap e = wrap_full "FULL_FIX_ID_THM" t e
    val basetype = if target = t then t else
                   most_precise_type
                   (C (exists_coding_theorem_precise target) "fix_id") t
                   handle _ => t
    val thm = generate_coding_theorem target "fix_id" basetype
              handle e => wrap e
    (* Strip one implication from tm when is_imp_only holds. *)
    fun mimp_only tm = if is_imp_only tm then snd (dest_imp tm) else tm
    val conc = mk_fix_id_conc target t
              handle e => wrap e
    (* Arguments of the goal's left-hand side that match the_value. *)
    val values = filter (can (match_term the_value))
                        (snd (strip_comb (lhs (snd (strip_imp (snd
                             (strip_forall (mimp_only (snd
                             (strip_forall conc))))))))))
              handle e => wrap e
    (* The first type argument of each such value's type. *)
    val value_types = map (hd o snd o dest_type o type_of) values
              handle e => wrap e
    val tvs = set_diff (type_vars t) value_types
    (* Sub-type theorems, with antecedents split and undischarged. *)
    val sub_thms = map (UNDISCH_ALL o PURE_REWRITE_RULE [GSYM AND_IMP_INTRO] o
                        FULL_FIX_ID_THM target)
                       (get_sub_types basetype t)
              handle e => wrap e
    val thm' = INST_TY_TERM
               (match_term (mimp_only (concl thm)) (mimp_only conc)) thm
              handle e => wrap e
    (* Antecedents to re-introduce, one per remaining type variable. *)
    val disch_set = map (mk_fix_id_conc target) tvs
              handle e => wrap e
in
    PURE_REWRITE_RULE [AND_IMP_INTRO,GSYM CONJ_ASSOC]
         (foldr (uncurry DISCH) (foldl (uncurry PROVE_HYP)
                (UNDISCH_ALL (PURE_REWRITE_RULE [GSYM AND_IMP_INTRO] thm'))
                sub_thms) disch_set)
              handle e => wrap e
end;
1322
local
(* conv1 : rewrites  !x. f (g x) = x  into  f o g = I.                     *)
val conv1 = STRIP_QUANT_CONV (RAND_CONV (REWR_CONV (GSYM I_THM)) THENC
            LAND_CONV (REWR_CONV (GSYM o_THM))) THENC
            REWR_CONV (GSYM FUN_EQ_THM)
(* conv2 : the reverse, rewriting  f o g = I  into  !x. f (g x) = x.       *)
val conv2 = REWR_CONV FUN_EQ_THM THENC
            STRIP_QUANT_CONV (RAND_CONV (REWR_CONV I_THM) THENC
            LAND_CONV (REWR_CONV o_THM))
(* FEDT target t : the general case of FULL_ENCODE_DECODE_THM.             *)
(*   Chains the encode_decode_map coding theorem for t with the map_id     *)
(*   source theorem.  It assumes the encode/decode conclusions for the     *)
(*   type variables of t, then discharges them as the goal's antecedent.   *)
fun FEDT target t =
let fun wrap e = wrap_full "FULL_ENCODE_DECODE_THM" t e
    val ename = "FULL_ENCODE_DECODE_THM" ^ type_to_string t
    val thm1 = generate_coding_theorem target "encode_decode_map" t
               handle e => wrap e
    (* Rename type variables to fresh ones before instantiating. *)
    val thm1_safe =
        INST_TYPE (map (fn v => v |-> gen_tyvar())
                       (type_vars_in_term (concl thm1))) thm1
    val thm2 = generate_source_theorem "map_id" t handle e => wrap e
    val tvs = type_vars_avoiding_itself (get_encode_function target t) t
    (* One assumption per type variable: the consequent of its own         *)
    (* encode/decode conclusion, put into  f o g = I  form by conv1.       *)
    (* NOTE(review): the error message reports t, not the particular type  *)
    (* variable whose conclusion was rejected.                             *)
    val antes = map (CONV_RULE conv1 o ASSUME o snd o dest_imp o
                    snd o strip_forall o mk_encode_decode_conc target)
                                tvs handle e =>
                raise (mkDebugExn ename
                      ("mk_encode_decode_conc returned invalid conclusion for" ^
                       " type variable: " ^ type_to_string t))
    (* Rewrite map_id with those assumptions used right-to-left. *)
    val thm2a = PURE_REWRITE_RULE (map SYM antes) thm2 handle e => wrap e
    (* Instantiate the types of thm so that f (concl thm) matches t. *)
    fun instit f thm = INST_TYPE (match_type (f (concl thm)) t) thm
    val thm1a = instit (snd o dom_rng o type_of o lhs)
                       (instit (fst o dom_rng o type_of o lhs)
                       (SPEC_ALL thm1_safe))
    val thm3 = TRANS thm1a thm2a handle e =>
               raise (mkDebugExn ename
         ("Generated encode_decode_map and map_id theorems do not match:\n" ^
                        thm_to_string thm1a ^ "\n" ^ thm_to_string thm2a))
    (* Back to pointwise form. *)
    val thm4 = CONV_RULE conv2 thm3 handle e => wrap e

    (* Antecedent conjuncts of the goal (if it is an implication), which   *)
    (* are discharged from the assumptions made above.                     *)
    val (vs,conc) = strip_forall (mk_encode_decode_conc target t)
                    handle e => wrap e
    val (vars,list) = if is_imp_only conc
                         then (vs,strip_conj (fst (dest_imp conc))) else ([],[])
    val result = PURE_REWRITE_RULE [AND_IMP_INTRO]
                     (foldr (uncurry DISCH) thm4 list)
in
    if null (hyp result) then GENL vars result else
       raise (mkDebugExn ename
                "Hypothesis remain in conclusion of theorem!")
end
in
(* FULL_ENCODE_DECODE_THM target t : proves mk_encode_decode_conc target t *)
(*   - when t = target, by rewriting with I_THM;                           *)
(*   - when t is a type variable, by DECIDE;                               *)
(*   - otherwise via FEDT.                                                 *)
fun FULL_ENCODE_DECODE_THM target t =
    if target = t then
       CONV_RULE bool_EQ_CONV (REWRITE_CONV [combinTheory.I_THM]
                 (mk_encode_decode_conc target t))
    else if is_vartype t
            then DECIDE (mk_encode_decode_conc target t) else FEDT target t
end;
1376
local
fun wrap e = wrapException "FULL_DECODE_ENCODE_THM" e
(* FDET target t : general case of FULL_DECODE_ENCODE_THM.                 *)
(*   Puts the decode_encode_fix coding theorem into pointwise form and     *)
(*   chains it with the fix_id coding theorem.  It then discharges the     *)
(*   antecedent conjuncts required by mk_decode_encode_conc.               *)
fun FDET target t =
let     val thm1 = generate_coding_theorem target "decode_encode_fix" t handle e => wrap e
        val thm2 = generate_coding_theorem target "fix_id" t handle e => wrap e

        (* Turn  f o g = h  into  !x. f (g x) = h x  under the quantifiers. *)
        val thm1a = CONV_RULE (STRIP_QUANT_CONV (REWR_CONV FUN_EQ_THM THENC
                                STRIP_QUANT_CONV (LAND_CONV (REWR_CONV o_THM)))) thm1;
        (* Two ways to locate the left-hand side of fix_id's conclusion:   *)
        (* with an inner quantified implication (v1) or without (v2).       *)
        val v1 = (lhs o snd o dest_imp_only o snd o strip_forall o snd o dest_imp_only)
        val v2 = (lhs o snd o dest_imp_only)
        (* Match fix_id against the right-hand side of thm1a. *)
        val thm2a =     PART_MATCH v1 thm2 (rhs (snd (strip_forall (concl thm1a)))) handle e =>
                        PART_MATCH v2 thm2 (rhs (snd (strip_forall (concl thm1a))))

        (* Unfold o inside the antecedent's conjuncts, where that applies. *)
        val thm2b = CONV_RULE (LAND_CONV (EVERY_CONJ_CONV (STRIP_QUANT_CONV
                        (RAND_CONV (LAND_CONV (REWR_CONV o_THM)))))) thm2a handle e => thm2a
        (* Move the antecedent(s) into the hypotheses. *)
        val thm2c =     UNDISCH (SPEC_ALL (UNDISCH_CONJ thm2b)) handle e =>
                        UNDISCH (SPEC_ALL thm2b)

        (* Chain the two theorems, discharge the first non-quantified       *)
        (* hypothesis, and generalise over the right-hand side of thm2c.    *)
        val thm3 = GEN (rhs (concl thm2c)) (DISCH (first (not o is_forall) (hyp thm2c)) (TRANS (SPEC_ALL thm1a) thm2c))

        (* Antecedent conjuncts of the goal, present only when its inner   *)
        (* body is itself an implication.                                   *)
        val conc = mk_decode_encode_conc target t
        val list = if is_imp_only (snd (strip_forall (snd (dest_imp_only (snd (strip_forall conc))))))
                        then strip_conj (fst (dest_imp_only (snd (strip_forall conc)))) else []
        val r = DISCH_LIST_CONJ list thm3
in
        if null (hyp r) then r else
                raise (mkDebugExn "FULL_DECODE_ENCODE_THM"
                        "Hypothesis remain in resultant theorem, mismatch between mk_decode_encode_conc and this?")
end
in
(* FULL_DECODE_ENCODE_THM target t : proves mk_decode_encode_conc target t *)
(*   For a type variable t the result is the trivial c ==> c, where c is   *)
(*   the consequent of the goal.  Otherwise FDET is used.  When the inner  *)
(*   body is an implication, the outer variables are re-generalised.       *)
fun FULL_DECODE_ENCODE_THM target t =
let     val conc  = (mk_decode_encode_conc target t)
        val vlist = if is_imp_only (snd (strip_forall (snd (dest_imp_only (snd (strip_forall conc))))))
                        then fst (strip_forall conc) else []
in
        GENL vlist (if is_vartype t
                then DISCH_ALL (ASSUME (snd (dest_imp_only (snd (strip_forall (mk_decode_encode_conc target t))))))
                else FDET target t)
end
end
1417
local
fun wrap e = wrapException "FULL_ENCODE_DETECT_THM" e
(* rthm: built by assuming  A = K T  and rewriting that hypothesis with    *)
(* FUN_EQ_THM and K_THM via CONV_HYP, then discharging everything.         *)
(* Presumably  |- (!x. A x) ==> (A = K T).  Confirm against CONV_HYP.      *)
val rthm = DISCH_ALL (CONV_HYP (REWRITE_CONV [FUN_EQ_THM,K_THM])
                (ASSUME (mk_eq(mk_var("A",alpha --> bool),mk_comb(mk_const("K",bool --> alpha --> bool),T)))))
(* FEDT target t : general case of FULL_ENCODE_DETECT_THM.                 *)
(*   Combines the pointwise encode_detect_all theorem with the all_id      *)
(*   theorem for t.  The detect assumptions for each type variable of t    *)
(*   are turned into  = K T  equations via rthm.                           *)
fun FEDT target t =
let     val thm1 = FULL_ENCODE_DETECT_ALL_THM target t handle e => wrap e
        val thm2 = FULL_ALL_ID_THM t handle e => wrap e

        (* all_id in pointwise form, with K T reduced to T. *)
        val thm2a = CONV_RULE (REWR_CONV FUN_EQ_THM THENC
                        STRIP_QUANT_CONV (RAND_CONV (REWR_CONV K_THM) THENC bool_EQ_CONV)) thm2 handle e => wrap e
        (* Right-to-left implication of the pointwise encode_detect_all. *)
        val thm1a = snd (EQ_IMP_RULE (SPEC_ALL (CONV_RULE (REWR_CONV FUN_EQ_THM) thm1))) handle e => wrap e

        val conc = mk_encode_detect_conc target t handle e => wrap e
        (* One rewrite per type variable, from its assumed detect theorem. *)
        val imps = map (MATCH_MP rthm o CONV_RULE (STRIP_QUANT_CONV (REWR_CONV (GSYM o_THM))) o ASSUME o
                        snd o dest_imp_only o snd o strip_forall o mk_encode_detect_conc target) (type_vars t)
                        handle e => wrap e
        val thm3 = GEN_ALL (CONV_RULE (REWR_CONV o_THM) (MATCH_MP (PURE_REWRITE_RULE imps thm1a) (SPEC_ALL thm2a)))
                        handle e => wrap e
        (* Match against the goal's consequent, or its whole body when the *)
        (* goal is not an implication.                                     *)
        val thm3' = GEN_ALL (UNDISCH_ALL (PART_MATCH (snd o strip_imp) (DISCH_ALL thm3)
                        (snd (dest_imp (snd (strip_forall conc))) handle _ => snd (strip_forall conc))))
                        handle e => wrap e;

        (* Discharge the goal's antecedent conjuncts, if any. *)
        val (vars,body) = strip_forall conc
        val (vs,timps) = (vars,strip_conj (fst (dest_imp body))) handle _ => ([],[])
        val r = GENL vs (DISCH_LIST_CONJ timps thm3')
in
        if null (hyp r) then r else
                raise (mkDebugExn "FULL_ENCODE_DETECT_THM"
                        "Hypothesis remain in resultant theorem, mismatch between mk_encode_detect_conc and this?")
end
in
(* FULL_ENCODE_DETECT_THM target t : proves mk_encode_detect_conc target t *)
(* by DECIDE for a type variable t, otherwise via FEDT.                    *)
fun FULL_ENCODE_DETECT_THM target t =
        if is_vartype t then DECIDE (mk_encode_detect_conc target t)
                else FEDT target t
end;
1453
1454(*****************************************************************************)
1455(* Conversions to fully apply functions:                                     *)
1456(*                                                                           *)
1457(* ENCODER_CONV : term -> thm                                                *)
1458(* APP_MAP_CONV : term -> thm                                                *)
1459(* APP_ALL_CONV : term -> thm                                                *)
1460(* DECODE_PAIR_CONV : term -> thm                                            *)
1461(* DETECT_PAIR_CONV : hol_type -> term -> thm                                *)
1462(*                                                                           *)
1463(*        ENCODER_CONV : |- (encode (C a b)) = encode_pair (x,a,b)           *)
1464(*        APP_MAP_CONV : |- (map (C a b)) = C (map a) (map b)                *)
1465(*        APP_ALL_CONV : |- (all (C a b)) = (all a) /\ (all b)               *)
1466(*        DECODE_PAIR_CONV : |- (decode (encode_pair f g a)) =               *)
1467(*                              C (decode (f (FST (SND a)))) ...             *)
1468(*        DETECT_PAIR_CONV : |- (detect (encode_pair f g a)) =               *)
1469(*                              (detect (f (FST (SND a)))) /\ ...            *)
1470(*                                                                           *)
1471(*****************************************************************************)
1472
(* ENCODER_CONV term : if the operator of term is the encoder for the      *)
(* argument's type, rewrite with the first applicable clause of that       *)
(* encoder's definition.  Any failure becomes NO_CONV.                     *)
fun ENCODER_CONV term =
let     val source_ty = type_of (rand term)
        val tgt = type_of term
        val pattern = check_function (get_encode_function tgt) source_ty
        val encode_def = get_coding_function_def tgt source_ty "encode"
        val is_encoder = can (match_term pattern) (rator term)
in
        if not is_encoder then NO_CONV term
        else FIRST_CONV (map REWR_CONV (CONJUNCTS encode_def)) term
end     handle _ => NO_CONV term
1484
(* APP_MAP_CONV term : if the operator of term is the map function for the *)
(* argument's type, rewrite with the first applicable clause of the map    *)
(* definition.  Any failure becomes NO_CONV.                               *)
fun APP_MAP_CONV term =
let     val source_ty = type_of (rand term)
        val pattern = check_function get_map_function source_ty
        val map_def = get_source_function_def source_ty "map"
        val is_map = can (match_term pattern) (rator term)
in
        if not is_map then NO_CONV term
        else FIRST_CONV (map REWR_CONV (CONJUNCTS map_def)) term
end     handle _ => NO_CONV term
1494
(* APP_ALL_CONV term : if the operator of term is the all function for the *)
(* argument's type, rewrite with the first applicable clause of the all    *)
(* definition.  Any failure becomes NO_CONV.                               *)
fun APP_ALL_CONV term =
let     val source_ty = type_of (rand term)
        val pattern = check_function get_all_function source_ty
        val all_def = get_source_function_def source_ty "all"
        val is_all = can (match_term pattern) (rator term)
in
        if not is_all then NO_CONV term
        else FIRST_CONV (map REWR_CONV (CONJUNCTS all_def)) term
end     handle _ => NO_CONV term
1504
(* DECODE_PAIR_CONV term                                                   *)
(*   Unfolds an application of the decode function for t (the type of      *)
(*   term) to an argument of type target.  It then simplifies with the     *)
(*   detect/decode theorems for num and pairs, and finally folds           *)
(*   f (g x) arguments back into (f o g) x.  Three shapes of t are         *)
(*   distinguished: all constructors constant, exactly one constructor,    *)
(*   or several constructors.  Every failure, including the NO_CONV raised *)
(*   when the operator is not the decoder, is wrapped under the name       *)
(*   DECODE_PAIR_CONV.  It was previously mislabelled DETECT_PAIR_CONV.    *)
fun DECODE_PAIR_CONV term =
let     val t = type_of term
        val target = type_of (rand term)
        val check = check_function (get_decode_function target) t
        val def = get_coding_function_def target t "decode"

        val pairp_pair = get_coding_theorem target (mk_prod(alpha,beta)) "encode_detect_all"
        val nump_num = get_coding_theorem target num "encode_detect_all"
        val labelled = mk_comb(get_detect_function target (mk_prod(num,target)),rand term)
        val pairp_id = get_source_theorem (mk_prod(alpha,beta)) "all_id"
        val pair_map = get_source_function_def (mk_prod(alpha,beta)) "map"
        (* Pair encode_decode_map, pointwise, at a pair of fresh variables. *)
        val paird_pair = PURE_REWRITE_RULE [pair_map]
                                (ISPEC (mk_pair(genvar (gen_tyvar()),genvar (gen_tyvar())))
                                        (PURE_REWRITE_RULE [FUN_EQ_THM,o_THM]
                                (SPEC_ALL (get_coding_theorem target (mk_prod(alpha,beta)) "encode_decode_map"))));
        val numd_num = get_coding_theorem target num "encode_decode_map";
        val cs = constructors_of t
        (* Rewrite for the detect test, chosen by the constructor shape. *)
        val all_rwr =   if all (not o can dom_rng o type_of) cs then
                                PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] nump_num,K_THM] (mk_comb(get_detect_function target num,rand term))
                        else if length cs = 1 then
                                PURE_REWRITE_RULE [get_coding_function_def target t "encode"]
                                        (ISPEC (list_mk_comb(hd cs,map genvar (fst (strip_fun (type_of (hd cs))))))
                                                (PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] (SPEC_ALL (generate_coding_theorem target "encode_detect_all" t))))
                        else    PURE_REWRITE_RULE [K_THM,pairp_id,nump_num,K_o_THM] (PART_MATCH lhs (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM] pairp_pair) labelled)

        (* Rewrite for the first decode step, chosen the same way. *)
        val first_decode = if all (not o can dom_rng o type_of) cs then
                                PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] numd_num,I_THM] (mk_comb(get_decode_function target num,rand term))
                        else if length cs = 1 then
                                REFL labelled
                        else    PURE_REWRITE_CONV [paird_pair,numd_num,I_o_ID,I_THM] (mk_comb(get_decode_function target (mk_prod(num,target)),rand term))

        (* Rebuild the application spine, folding every argument of shape  *)
        (* f (g x) into (f o g) x.  Fails if any argument has another shape. *)
        fun re_o_conv term =
        let     val (r,subs) = (REFL ## map (REWR_CONV (GSYM o_THM))) (strip_comb term)
        in
                foldl (fn (a,b) => MK_COMB(b,a)) r subs
        end;
in
        if can (match_term check) (rator term) then
                (REWR_CONV def THENC PURE_REWRITE_CONV (COND_CLAUSES::all_rwr::first_decode::K_THM::K_o_THM::[generate_source_theorem "all_id" t]) THENC
                 TRY_CONV let_CONV THENC DEPTH_CONV (reduceLib.NEQ_CONV) THENC PURE_REWRITE_CONV [COND_CLAUSES,paird_pair,o_THM] THENC
                 TRY_CONV let_CONV THENC re_o_conv) term
        else    NO_CONV term
end handle e => wrapException "DECODE_PAIR_CONV" e
1548
(* DETECT_PAIR_CONV t term                                                 *)
(*   Unfolds an application of the detect function for type t to an       *)
(*   argument of type target.  It then simplifies with the detect/decode   *)
(*   theorems for num and pairs, removing conditionals, lets and trivial   *)
(*   equalities, and finally tries to fold f (g x) into (f o g) x.  Three  *)
(*   shapes of t are distinguished: all constructors constant, exactly one *)
(*   constructor, or several constructors.  All failures, including the    *)
(*   NO_CONV raised when the operator is not the detector, are wrapped.    *)
fun DETECT_PAIR_CONV t term =
let     val target = type_of (rand term)
        val check = check_function (get_detect_function target) t
        val def = get_coding_function_def target t "detect"

        val pairp_pair = generate_coding_theorem target "encode_detect_all" (mk_prod(alpha,beta))
        val nump_num = generate_coding_theorem target "encode_detect_all" num
        val pairp_id = generate_source_theorem "all_id" (mk_prod(alpha,beta))
        val labelled = mk_comb(get_detect_function target (mk_prod(num,target)),rand term)
        (* NOTE(review): pair_all is not used below. *)
        val pair_all = get_source_function_def (mk_prod(alpha,beta)) "all"
        val pair_map = get_source_function_def (mk_prod(alpha,beta)) "map"
        (* Pair encode_decode_map, pointwise, at a pair of fresh variables. *)
        val paird_pair = PURE_REWRITE_RULE [pair_map]
                                (ISPEC (mk_pair(genvar (gen_tyvar()),genvar (gen_tyvar())))
                                        (PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] (SPEC_ALL (get_coding_theorem target (mk_prod(alpha,beta)) "encode_decode_map"))));
        val numd_num = get_coding_theorem target num "encode_decode_map";
        val cs = constructors_of t
        (* Rewrite for the label's detect test, chosen by constructor shape. *)
        val all_rwr =   if all (not o can dom_rng o type_of) cs then
                                PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] nump_num,K_THM] (mk_comb(get_detect_function target num,rand term))
                        else if length cs = 1 then
                                REFL labelled
                        else    PURE_REWRITE_RULE [K_THM,pairp_id,nump_num,K_o_THM] (PART_MATCH lhs (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM] pairp_pair) labelled)

        (* Rewrite for the first decode step, chosen the same way. *)
        val first_decode =
                        if all (not o can dom_rng o type_of) cs then
                                PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] numd_num,I_THM] (mk_comb(get_decode_function target num,rand term))
                        else if length cs = 1 then
                                REFL labelled
                        else    PURE_REWRITE_CONV [paird_pair,numd_num,I_o_ID,I_THM] (mk_comb(get_decode_function target (mk_prod(num,target)),rand term))
in
        if can (match_term check) (rator term) then
                (REWR_CONV def THENC PURE_REWRITE_CONV [COND_CLAUSES,all_rwr,first_decode] THENC
                 TRY_CONV let_CONV THENC DEPTH_CONV (reduceLib.NEQ_CONV ORELSEC (REWR_CONV REFL_CLAUSE)) THENC PURE_REWRITE_CONV [COND_CLAUSES] THENC
                 TRY_CONV let_CONV THENC TRY_CONV (REWR_CONV (GSYM o_THM))) term
        else    NO_CONV term
end     handle e => wrapException "DETECT_PAIR_CONV" e
1584
1585(*****************************************************************************)
1586(* Tactics to prove the goals described previously:                          *)
1587(*                                                                           *)
1588(* encode_decode_map_tactic : hol_type -> hol_type -> tactic                 *)
1589(* encode_detect_all_tactic : hol_type -> hol_type -> tactic                 *)
1590(* decode_encode_fix_tactic : hol_type -> hol_type -> tactic                 *)
1591(* encode_map_encode_tactic : hol_type -> hol_type -> tactic                 *)
1592(* map_compose_tactic       : hol_type -> tactic                             *)
1593(* map_id_tactic            : hol_type -> tactic                             *)
1594(* all_id_tactic            : hol_type -> tactic                             *)
1595(* fix_id_tactic            : hol_type -> hol_type -> tactic                 *)
1596(* general_detect_tactic    : hol_type -> hol_type -> tactic                 *)
1597(*                                                                           *)
1598(*    Tactics to solve inductive clauses for the goals given previously.     *)
1599(*                                                                           *)
1600(* detect_dead_rule         : hol_type -> hol_type -> thm                    *)
1601(*    Generates a single application of detect to 'nil'. Used in             *)
1602(*    CONSOLIDATE_CONV to show that bottom values terminate.                 *)
1603(*                                                                           *)
1604(*****************************************************************************)
1605
(* encode_decode_map_tactic target t (a,g)                                 *)
(*   The t argument is ignored.  The type is recomputed from the goal as   *)
(*   the type of the argument of the right-hand side of its body.  Two     *)
(*   strategies are tried in order:                                        *)
(*   1. rewrite the left operator with an encode_decode_map theorem for a  *)
(*      relevant type, then finish with the flipped assumptions, the map   *)
(*      definitions and those theorems;                                    *)
(*   2. unfold o, expand encode/decode on the left and map on the right    *)
(*      (ENCODER_CONV, DECODE_PAIR_CONV, APP_MAP_CONV), then rewrite.      *)
fun encode_decode_map_tactic target (t:hol_type) (a,g) =
let     val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g))))))))
        val rts = relevant_types t
        val thms = map (generate_coding_theorem target "encode_decode_map") rts
        val map_defs = map (C get_source_function_def "map") rts
in
        (REPEAT STRIP_TAC THEN
        FIRST [ CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (map_defs @ thms),
                PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV ENCODER_CONV THENC DECODE_PAIR_CONV) THENC RAND_CONV APP_MAP_CONV) THEN
                ASM_REWRITE_TAC (get_source_function_def (mk_prod(alpha,beta)) "map"::thms) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (map_defs @ thms)]) (a,g)
end     handle e => wrapException "encode_decode_map_tactic" e
1619
local
(* fix_type tm ty : follows the pair structure of tm through the product   *)
(* type ty, returning the left component at each pair and the remaining   *)
(* type at the end.                                                        *)
fun fix_type tm ty =
        if is_pair tm then uncurry cons ((I ## fix_type (snd (dest_pair tm))) (dest_prod ty)) else [ty];
(* PTAC : rewrite the goal's left-hand side with the encode_map_encode     *)
(* theorem for the argument's type, after cannon_type and removal of the   *)
(* types in rset.  The theorem is specialised to a tuple of fresh          *)
(* variables shaped like the argument.  The t parameter is unused.         *)
fun PTAC target rset t (a,g) =
let     val endt = rand (lhs g)
        val t' = delete_matching_types rset (cannon_type (type_of endt))
        val thm1 = PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] (SPEC_ALL (generate_coding_theorem target "encode_map_encode" t'));
        val thm2 = PURE_REWRITE_RULE [I_THM,I_o_ID,get_source_function_def t' "map"] (ISPEC (list_mk_pair(map genvar (fix_type endt t'))) thm1)
in
        (CONV_TAC (LAND_CONV (REWR_CONV thm2))) (a,g)
end
(* CTAC : does nothing when every constructor of t is a constant or when   *)
(* t has exactly one constructor.  Otherwise it applies PTAC.              *)
fun CTAC target rset t =
let     val cs = constructors_of t
in      if all (not o can dom_rng o type_of) cs then ALL_TAC
        else if length cs = 1 then ALL_TAC else PTAC target rset t
end
in
(* encode_map_encode_tactic target t (a,g)                                 *)
(*   The t argument is ignored and recomputed from the goal.  The tactic   *)
(*   first tries an encode_map_encode theorem on the left operator.  If    *)
(*   that fails, it unfolds map and encode on both sides, applies CTAC,    *)
(*   and rewrites with the theorems (pointwise and plain) and the encode   *)
(*   definitions.                                                          *)
fun encode_map_encode_tactic target (t:hol_type) (a,g) =
let     val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g))))))))
        val rset = map fst (split_nested_recursive_set t)
        val rts = relevant_types t
        val thms = map (generate_coding_theorem target "encode_map_encode") rts
        val enc_defs = map (C (get_coding_function_def target) "encode") (mk_prod(alpha,beta)::all_types t);
        val all_thms = map (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM]) thms @ thms @ enc_defs
in
        (REPEAT STRIP_TAC THEN
        FIRST [ CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC all_thms,
                PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV APP_MAP_CONV THENC ENCODER_CONV) THENC RAND_CONV ENCODER_CONV) THEN
                PURE_REWRITE_TAC [I_THM,I_o_ID] THEN CTAC target rset t THEN ASM_REWRITE_TAC all_thms THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (o_THM::all_thms)]) (a,g)
end
end
1653
(* map_compose_tactic t (a,g)                                              *)
(*   The t argument is ignored and recomputed from the goal.  The tactic   *)
(*   first tries a map_compose theorem on the left operator.  If that      *)
(*   fails, it unfolds map on both sides and rewrites with the one_one     *)
(*   theorems of the types involved.  It then splits the conjuncts, folds  *)
(*   each left-hand side back into o form, and rewrites with the theorems  *)
(*   and map definitions.                                                  *)
fun map_compose_tactic (t:hol_type) (a,g) =
let     val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g))))))))
        val rts = relevant_types t
        val thms = map (generate_source_theorem "map_compose") rts
        val map_defs = map (C get_source_function_def "map") (mk_prod(alpha,beta)::all_types t);
        val all_thms = map (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM]) thms @ thms @ map_defs
in
        (REPEAT STRIP_TAC THEN
         FIRST [CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC all_thms,
                PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV APP_MAP_CONV THENC APP_MAP_CONV) THENC RAND_CONV APP_MAP_CONV) THEN
                REWRITE_TAC (mapfilter TypeBase.one_one_of (all_types t)) THEN REPEAT CONJ_TAC THEN
                CONV_TAC (LAND_CONV (REWR_CONV (GSYM o_THM))) THEN
                ASM_REWRITE_TAC all_thms]) (a,g)
end;
1669
(* encode_detect_all_tactic target t (a,g)                                 *)
(*   The t argument is ignored and recomputed from the goal.  The tactic   *)
(*   first tries an encode_detect_all theorem on the left operator.  If    *)
(*   that fails, it expands encode and detect on the left (ENCODER_CONV,   *)
(*   DETECT_PAIR_CONV) and all on the right (APP_ALL_CONV), then rewrites  *)
(*   with the all definitions and the theorems.                            *)
fun encode_detect_all_tactic target (t:hol_type) (a,g) =
let     val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g))))))))
        val thms = map (generate_coding_theorem target "encode_detect_all") (relevant_types t)
        val all_defs = mapfilter (C get_source_function_def "all") (all_types t)
in
        (REPEAT STRIP_TAC THEN
        FIRST [ CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (all_defs @ thms),
                PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV ENCODER_CONV THENC (DETECT_PAIR_CONV t)) THENC RAND_CONV APP_ALL_CONV) THEN
                ASM_REWRITE_TAC (get_source_function_def (mk_prod(alpha,beta)) "all"::thms) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (all_defs @ thms)]) (a,g)
end     handle e => wrapException "encode_detect_all_tactic" e
1682
1683
(* LET_RAND_CONV match term                                                *)
(*   For term = f (let v = e in body) where f equals match, proves         *)
(*     f (let v = e in body) = let v = e in f body.                        *)
(*   It first proves this equation over fresh variables, with a variable   *)
(*   M applied to the bound variables standing for the body.  It then      *)
(*   applies the equation to term with HO_REWR_CONV.  When the operator    *)
(*   differs from match, NO_CONV term is returned.                         *)
fun LET_RAND_CONV match term =
let     fun strip_pair x = if is_pair x then op:: ((I ## strip_pair) (dest_pair x)) else [x]
        val (func_tm,let_tm) = dest_comb term
        val (inputs,output) = pairSyntax.dest_anylet let_tm
        (* Same bound patterns, with fresh variables for the bound values. *)
        val ginput = map (fn (a,b) => (a,genvar (type_of a))) inputs
        (* Fresh type variables; these shadow the global alpha and beta. *)
        val alpha = gen_tyvar()
        val beta = gen_tyvar()
        val goutput = list_mk_comb(mk_var("M",list_mk_fun(flatten (map (map type_of o strip_pair o fst) ginput),alpha)),
                                flatten (map (strip_pair o fst) ginput));
        val gterml = mk_comb(mk_var("f",alpha --> beta),pairSyntax.mk_anylet (ginput,goutput));
        val gtermr = pairSyntax.mk_anylet(ginput,mk_comb(mk_var("f",alpha --> beta),goutput));
        (* Prove gterml = gtermr by unfolding LET and beta-reducing. *)
        val thm = CONV_RULE bool_EQ_CONV (DEPTH_CONV (REWR_CONV LET_DEF ORELSEC GEN_BETA_CONV ORELSEC REWR_CONV REFL_CLAUSE) (mk_eq(gterml,gtermr)))
in
        if match = func_tm then HO_REWR_CONV thm term else NO_CONV term
end
1699
1700local
1701fun PCONV tm = if is_pair tm then RAND_CONV PCONV tm else TRY_CONV (REWR_CONV (GSYM PAIR) THENC PCONV) tm
1702fun PLET_CONV term =
1703let     val len = strip_pair (fst (dest_pabs (rand (rator term))));
1704        val full_pair = list_mk_prod(map (fn a => gen_tyvar ()) len)
1705        val thm = PCONV (mk_var("x",full_pair))
1706in
1707        (RATOR_CONV (RATOR_CONV (REWR_CONV LET_DEF)) THENC
1708        RATOR_CONV BETA_CONV THENC BETA_CONV  THENC
1709        RAND_CONV (REWR_CONV thm) THENC
1710        PAIRED_BETA_CONV THENC
1711        REWRITE_CONV [GSYM thm]) term
1712end;
1713fun COND_CONG_TAC (a,g) =
1714        TRY (MATCH_MP_TAC COND_CONG THEN CONJ_TAC THENL [ALL_TAC,CONJ_TAC] THENL
1715                [ALL_TAC,DISCH_TAC THEN COND_CONG_TAC,DISCH_TAC THEN COND_CONG_TAC]) (a,g);
1716fun START_LABEL_TAC enc encoder target t =
1717        PURE_REWRITE_TAC [o_THM] THEN
1718        ONCE_REWRITE_TAC [get_coding_function_def target t "decode"] THEN
1719        ONCE_REWRITE_TAC [get_coding_function_def target t "fix"] THEN
1720        CONV_TAC (LAND_CONV (REDEPTH_CONV (FIRST_CONV [HO_REWR_CONV (ISPEC enc COND_RAND),LET_RAND_CONV enc]))) THEN
1721        MATCH_MP_TAC COND_CONG THEN REPEAT STRIP_TAC THEN
1722        REWRITE_TAC [REWRITE_RULE [FUN_EQ_THM,o_THM] (generate_coding_theorem target "encode_map_encode" t)] THEN
1723        ONCE_REWRITE_TAC [encoder] THEN
1724        CONV_TAC (REDEPTH_CONV (let_CONV ORELSEC PLET_CONV)) THEN COND_CONG_TAC THEN
1725        TRY (REWRITE_TAC [get_coding_function_def target (mk_prod(alpha,beta)) "encode",I_THM] THEN NO_TAC)
1726fun LABEL_TAC thms enc encoder target t =
1727        START_LABEL_TAC enc encoder target t THEN
1728        FIRST [FIRST_ASSUM (CONV_TAC o LAND_CONV o RAND_CONV o REWR_CONV o GSYM),REFL_TAC] THEN
1729        REWRITE_TAC (map (REWRITE_RULE [FUN_EQ_THM,o_THM]) thms) THEN
1730        MATCH_MP_TAC (get_coding_theorem target num "fix_id") THEN
1731        FIRST_ASSUM ACCEPT_TAC
1732fun XTAC fix_defs (a,g) =
1733        (if is_var (rand (rhs g)) then
1734        CONV_TAC (BINOP_CONV (FIRST_CONV (map REWR_CONV fix_defs)) THENC DEPTH_CONV (REWR_CONV LET_DEF) THENC DEPTH_CONV GEN_BETA_CONV) THEN
1735        COND_CONG_TAC else NO_TAC) (a,g)
1736in
(* decode_encode_fix_tactic target _ : the tactic registered in              *)
(* initialise_coding_theorem_generators to prove "decode_encode_fix"         *)
(* theorems.  The second (type) argument is ignored; the type is recovered   *)
(* from the goal instead.                                                    *)
fun decode_encode_fix_tactic target _ (a,g) =
let     (* operator of the LHS of the innermost equation of the goal *)
        val term = (rator o lhs o snd o strip_forall o snd o strip_imp o snd o strip_forall) g
        (* first argument of that operator (presumably the encoder - confirm) *)
        val enc = rand (rator term);
        (* the type being proved about: the domain of enc *)
        val t = fst (dom_rng (type_of enc))
        val encoder = get_coding_function_def target t "encode";
        (* first type in all_types t lacking a stored decode_encode_fix theorem *)
        val mt = first (not o C (exists_coding_theorem target) "decode_encode_fix") (all_types t)
        val rts = mk_prod(alpha,beta)::num::relevant_types mt
        (* decode_encode_fix theorems for the base types of rts (all must   *)
        (* succeed, via map), followed by those for rts themselves where    *)
        (* generation succeeds (mapfilter).                                 *)
        val thms = map (generate_coding_theorem target "decode_encode_fix" o base_type) rts @
                   mapfilter (generate_coding_theorem target "decode_encode_fix") rts
        val fix_defs = map (C (get_coding_function_def target) "fix") (mk_prod(alpha,beta)::num::all_types mt)
        val dead_thm = #bottom_thm (get_translation_scheme target)
        (* encode/map/decode/fix/detect definitions of every type in        *)
        (* all_types t, plus o_THM and the translation's bottom theorem     *)
        val all_defs = foldl (fn (a,b) => get_coding_function_def target a "encode"::get_source_function_def a "map"::
                                get_coding_function_def target a "decode"::get_coding_function_def target a "fix"::
                                get_coding_function_def target a "detect"::b) [o_THM,dead_thm] (all_types t)
        (* pointwise encode_map_encode for num # beta, specialised to the   *)
        (* pair (x,y) and simplified with I_THM and the pair map definition *)
        val thm = REWRITE_RULE [I_THM,get_source_function_def (mk_prod(alpha,beta)) "map"]
                (ISPEC (mk_pair(mk_var("x",num),mk_var("y",beta))) (REWRITE_RULE [FUN_EQ_THM,o_THM]
                        (SPEC_ALL (generate_coding_theorem target "encode_map_encode" (mk_prod(num,beta))))))
in
        (* After stripping, try in order:                                   *)
        (*   1. rewrite the LHS operator with one of thms;                  *)
        (*   2. LABEL_TAC;                                                  *)
        (*   3. START_LABEL_TAC plus pair-specific rewriting with thm.      *)
        (* Then repeatedly use XTAC with the induction hypotheses, and      *)
        (* finally rewrite with all_defs until nothing changes.             *)
        (REPEAT STRIP_TAC THEN
        FIRST [
                CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (mapfilter REWR_CONV thms)))),
                LABEL_TAC thms enc encoder target t,
                START_LABEL_TAC enc encoder target t THEN
                TRY (FULL_SIMP_TAC std_ss [get_coding_function_def target (mk_prod(alpha,beta)) "detect"] THEN NO_TAC) THEN
                TRY (FIRST_ASSUM (CONV_TAC o LAND_CONV o RAND_CONV o RATOR_CONV o RAND_CONV o REWR_CONV o GSYM) THEN
                CONV_TAC (LAND_CONV (REWR_CONV thm THENC RAND_CONV (REWR_CONV PAIR)))) THEN
                CONV_TAC (LAND_CONV (FIRST_CONV (mapfilter (REWR_CONV o PURE_REWRITE_RULE [FUN_EQ_THM,o_THM]) thms))) THEN
                REWRITE_TAC [I_o_ID]] THEN
        REPEAT (XTAC fix_defs THEN RES_TAC THEN
                ASM_REWRITE_TAC (last (CONJUNCTS sexpTheory.sexp_11)::get_coding_function_def target (mk_prod(alpha,beta)) "encode"::thms) THEN
                RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (map GSYM thms)) THEN
        REPEAT (CHANGED_TAC (ONCE_ASM_REWRITE_TAC (K_THM::all_defs) THEN
                REWRITE_TAC (translateTheory.DETDEAD_PAIR::I_THM::mapfilter (C get_source_theorem "map_id" o base_type) (all_types t)) THEN
                CONV_TAC (DEPTH_CONV (REWR_CONV LET_DEF) THENC DEPTH_CONV GEN_BETA_CONV)))) (a,g)
end
1772end;
1773
(* fix_id_tactic target t: the tactic registered in                          *)
(* initialise_coding_theorem_generators to prove "fix_id" theorems.          *)
local
(* MATCH_FIX_TAC target rset all_types: builds full fix_id theorems for num  *)
(* and for every non-variable type in all_types that does not contain an     *)
(* instance of a type in rset.  Each is instantiated to the shape of its     *)
(* fix_id conclusion and rewritten so that nested implications form a single *)
(* conjunctive antecedent.  It then backward-chains (MATCH_MP_TAC) with the  *)
(* first one whose conclusion has the same LHS operator as the goal, and     *)
(* fails (through first) if there is none.                                   *)
fun MATCH_FIX_TAC target rset all_types (a,g) =
let     (* true if t is an instance of a type in rset, or has a type argument that is *)
        fun mcheck rset t = exists (can (C match_type t)) rset orelse
                        (can dest_type t andalso exists (mcheck rset) (snd (dest_type t)))
        val ftypes = (filter (not o is_vartype) (filter (not o mcheck rset) (num::all_types)));
        val thms = map (fn a => (C (PART_MATCH I) (snd (strip_forall (mk_fix_id_conc target a))) o FULL_FIX_ID_THM target) a) ftypes;
        val thms' = map (CONV_RULE (DEPTH_CONV RIGHT_IMP_FORALL_CONV THENC REWRITE_CONV [AND_IMP_INTRO])) thms
in
        (MATCH_MP_TAC (first (curry op= ((rator o lhs) g) o rator o lhs o snd o strip_imp o
                        snd o strip_forall o snd o strip_imp o snd o strip_forall o concl) thms')) (a,g)
end;
in
fun fix_id_tactic target t (a,g) =
let     val all_types = all_types t
        (* first type whose base type has no stored fix_id theorem *)
        val mt = first (not o C (exists_coding_theorem target) "fix_id" o base_type) all_types
        (* fix definitions for pairs, num and every type in all_types *)
        val defs = map (C (get_coding_function_def target) "fix" o base_type) (mk_prod(alpha,beta)::num::all_types)
        (* conjuncts of the detect definitions that can be fetched *)
        val pdefs = flatten (mapfilter (CONJUNCTS o C (get_coding_function_def target) "detect" o base_type)
                                (mk_prod(alpha,beta)::num::all_types))
        (* NOTE(review): split_thm is not referenced below.  Generating it  *)
        (* may still matter for its side effect of creating/storing the     *)
        (* theorem - confirm before removing.                               *)
        val split_thm = generate_coding_theorem target "fix_id" (mk_prod(num,target))
        val rts = relevant_types mt
        (* the types of the (nested) recursive set of mt *)
        val rset = map fst (split_nested_recursive_set mt)
        val def_thms = mapfilter (generate_coding_theorem target "decode_encode_fix" o base_type)
                                (mk_prod(alpha,beta)::num::rts)
        (* reversed pair fix_id theorem, with the pair detect and fix       *)
        (* definitions simplified away                                      *)
        val tsplit_thm = SIMP_RULE std_ss [get_coding_function_def target (mk_prod(alpha,beta)) "detect",
                                get_coding_function_def target (mk_prod(alpha,beta)) "fix"]
                        (GSYM (generate_coding_theorem target "fix_id" (mk_prod(target,target))));
        (* NOTE(review): thms is not referenced below either; as above, the *)
        (* generate_coding_theorem calls may be needed only for their side  *)
        (* effects - confirm before removing.                               *)
        val thms = mapfilter (GEN_ALL o (CONV_RULE (DEPTH_CONV RIGHT_IMP_FORALL_CONV THENC
                        REWRITE_CONV [AND_IMP_INTRO])) o
                        generate_coding_theorem target "fix_id") (rev (mk_prod(alpha,beta)::num::rts))
        (* pattern for PAT_ASSUM: a conditional assumption *)
        val cond_tm = mk_cond(mk_var("p",bool),mk_var("a",alpha),(mk_var("b",alpha)));
in
        (REPEAT (POP_ASSUM MP_TAC) THEN REPEAT STRIP_TAC THEN
        FIRST [
                (* 1. expand detect definitions in the assumptions, unfold fix and close with COND_ID *)
                RULE_ASSUM_TAC (CONV_RULE (ONCE_REWRITE_CONV pdefs THENC REWRITE_CONV [COND_EXPAND])) THEN
                POP_ASSUM STRIP_ASSUME_TAC THEN ONCE_REWRITE_TAC defs THEN ASM_REWRITE_TAC [COND_ID] THEN NO_TAC,
                (* 2. unfold detect (incl. the pair detect) in the assumptions; must close the goal *)
                RULE_ASSUM_TAC (ONCE_REWRITE_RULE pdefs) THEN
                RULE_ASSUM_TAC (ONCE_REWRITE_RULE [get_coding_function_def target (mk_prod(alpha,beta)) "detect"]) THEN
                POP_ASSUM MP_TAC THEN ASM_REWRITE_TAC [] THEN NO_TAC,
                (* 3. pair case: resolve with tsplit_thm, unfold fix, expand lets and conditionals *)
                IMP_RES_TAC tsplit_thm THEN POP_ASSUM (CONV_TAC o RAND_CONV o REWR_CONV) THEN
                CONV_TAC (LAND_CONV (FIRST_CONV (map REWR_CONV defs))) THEN
                RULE_ASSUM_TAC (CONV_RULE (TRY_CONV (FIRST_CONV (map REWR_CONV pdefs) THENC REWR_CONV COND_EXPAND))) THEN
                POP_ASSUM STRIP_ASSUME_TAC THEN
                CONV_TAC (DEPTH_CONV (REWR_CONV LET_DEF ORELSEC GEN_BETA_CONV)) THEN
                RULE_ASSUM_TAC (CONV_RULE (DEPTH_CONV (REWR_CONV LET_DEF ORELSEC GEN_BETA_CONV))) THEN
                ASM_REWRITE_TAC [] THEN
                TRY (   REPEAT IF_CASES_TAC THEN
                        PAT_ASSUM cond_tm MP_TAC THEN ASM_REWRITE_TAC [get_coding_function_def target (mk_prod(alpha,beta)) "encode"] THEN TRY STRIP_TAC THEN
                        TRY (CONV_TAC (LAND_CONV (FIRST_CONV (map REWR_CONV defs)))) THEN ASM_REWRITE_TAC []) THEN
                TRY (MK_COMB_TAC THENL [MK_COMB_TAC,ALL_TAC]) THEN
                REWRITE_TAC [] THEN
                TRY (   (FIRST_ASSUM (CONV_TAC o LAND_CONV o REWR_CONV o GSYM) ORELSE
                                FIRST_ASSUM (CONV_TAC o LAND_CONV o RAND_CONV o REWR_CONV o GSYM) ORELSE MATCH_MP_TAC tsplit_thm) THEN
                        REWRITE_TAC [get_coding_function_def target (mk_prod(alpha,beta)) "decode"] THEN
                        ASM_REWRITE_TAC (I_THM::FST::SND::map (REWRITE_RULE [o_THM,FUN_EQ_THM]) def_thms)) THEN
                (FIRST_ASSUM ACCEPT_TAC ORELSE FIRST_ASSUM MATCH_MP_TAC ORELSE MATCH_FIX_TAC target rset all_types) THEN
                FULL_SIMP_TAC std_ss [get_coding_function_def target (mk_prod(alpha,beta)) "detect",get_coding_function_def target (mk_prod(alpha,beta)) "decode"] THEN
                RULE_ASSUM_TAC (REWRITE_RULE [ISPEC fst_tm COND_RAND,ISPEC snd_tm COND_RAND,I_THM]) THEN
                REPEAT (POP_ASSUM MP_TAC) THEN REPEAT IF_CASES_TAC THEN REPEAT STRIP_TAC THEN RULE_ASSUM_TAC (REWRITE_RULE []) THEN ASM_REWRITE_TAC [],
                (* 4. general case: unfold fix, split on the condition, use an induction hypothesis *)
                (*    or MATCH_FIX_TAC, then discharge detect side conditions via general_detect   *)
                ONCE_REWRITE_TAC defs THEN IF_CASES_TAC THEN
                TRY (FIRST_ASSUM MATCH_MP_TAC) THEN TRY (MATCH_FIX_TAC target rset all_types) THENL
                        [RULE_ASSUM_TAC (ONCE_REWRITE_RULE pdefs),ALL_TAC] THEN
                REPEAT (FIRST_ASSUM (fn th => MP_TAC th THEN WEAKEN_TAC (curry op= (concl th)) THEN IF_CASES_TAC THEN DISCH_TAC)) THEN
                RES_TAC THEN IMP_RES_TAC (generate_coding_theorem target "general_detect" (base_type t)) THEN REPEAT CONJ_TAC THEN
                REPEAT ((FIRST_ASSUM ACCEPT_TAC ORELSE FIRST_ASSUM (ACCEPT_TAC o GSYM) ORELSE REPEAT (POP_ASSUM MP_TAC) THEN ASM_REWRITE_TAC []) THEN
                        ONCE_REWRITE_TAC [get_coding_function_def target t "detect"] THEN REPEAT (IF_CASES_TAC THEN ASM_REWRITE_TAC []) THEN REPEAT STRIP_TAC THEN
                        MAP_EVERY (IMP_RES_TAC o generate_coding_theorem target "general_detect" o base_type) (filter (not o is_vartype) all_types))]) (a,g)
end
end;
1842
(* general_detect_tactic target t: the tactic registered in                  *)
(* initialise_coding_theorem_generators to prove "general_detect" theorems.  *)
local
(* FULL_MATCH_TAC thm: for a goal ?v1 ... vn. body, matches body against the *)
(* conclusion of thm, supplies each vi's instantiation as its witness (or vi *)
(* itself if the match leaves it unbound) and accepts the goal with thm.     *)
fun FULL_MATCH_TAC thm (a,g) =
let     val (tsub,_) = match_term (snd (strip_exists g)) (concl thm)
        val list = map (fn v => #residue (first (curry op= v o #redex) tsub) handle _ => v) (fst (strip_exists g))
in
        (MAP_EVERY EXISTS_TAC list THEN MATCH_ACCEPT_TAC thm) (a,g)
end;
(* COMPLETE thm split thms: tries, in order,                                 *)
(*   - closing the goal with K_THM;                                          *)
(*   - backward-chaining with an assumption;                                 *)
(*   - backward-chaining with one of thms, then FULL_MATCH_TAC against an    *)
(*     assumption;                                                           *)
(*   - rewriting the goal with an instance of thm's consequent (its          *)
(*     antecedent kept as a hypothesis), splitting conjuncts and recursing;  *)
(*   - resolving thm against the assumptions and recursing.                  *)
(* The split argument is only passed on to the recursive calls.              *)
fun COMPLETE thm split thms (a,g) =
        (FIRST [REWRITE_TAC [K_THM] THEN NO_TAC,
                FIRST_ASSUM MATCH_MP_TAC,
                MAP_FIRST (MATCH_MP_TAC o GEN_ALL) thms THEN FIRST_ASSUM FULL_MATCH_TAC,
                CONV_TAC (UNDISCH o PART_MATCH (lhs o snd o dest_imp_only) thm) THEN REPEAT CONJ_TAC THEN COMPLETE thm split thms,
                IMP_RES_TAC thm THEN RES_TAC THEN COMPLETE thm split thms]) (a,g)
in
fun general_detect_tactic target t =
let     (* base types of pairs and of every non-variable type in the nested *)
        (* recursive set of t                                               *)
        val all_types = map base_type ((mk_prod(alpha,beta))::(filter (not o is_vartype) (flatten (map (op@ o snd) (split_nested_recursive_set t)))));
        val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "detect"
        val thms = [K_THM,get_coding_function_def target (mk_prod(alpha,beta)) "decode",I_THM];
        (* the pair detect definition with COND_RAND pushed through,        *)
        (* COND_EXPAND applied and disjunctions turned into implications,   *)
        (* then split into its two conjuncts                                *)
        val (thm,split) = CONJ_PAIR (CONV_RULE (REWR_CONV COND_RAND THENC
                REWRITE_CONV [COND_EXPAND,GSYM (SPEC (mk_neg(mk_var("a",bool))) DISJ_COMM),GSYM IMP_DISJ_THM]) (SPEC_ALL pair_def));

in      REPEAT GEN_TAC THEN REWRITE_TAC [get_coding_function_def target t "detect"] THEN
        TRY (FULL_SIMP_TAC std_ss [pair_def] THEN NO_TAC) THEN
        REWRITE_TAC [LET_DEF] THEN GEN_BETA_TAC THEN ASM_REWRITE_TAC thms THEN
        REPEAT (IF_CASES_TAC THEN ASM_REWRITE_TAC []) THEN REPEAT STRIP_TAC THEN
        REPEAT (FIRST_ASSUM (fn th => MP_TAC th THEN WEAKEN_TAC (curry op= (concl th)) THEN IF_CASES_TAC)) THEN
        RES_TAC THEN ASM_REWRITE_TAC [FST,SND] THEN REPEAT STRIP_TAC THEN
        REPEAT (CHANGED_TAC (IMP_RES_TAC split THEN IMP_RES_TAC thm)) THEN
        COMPLETE thm split (mapfilter (generate_coding_theorem target "general_detect") all_types)
end
end
1874
local
(* Proves a theorem for "detect applied to the bottom value" at a type t     *)
(* other than the target.  The detect definition of t is unfolded once, and *)
(* the result is rewritten with the translation's bottom theorem, the        *)
(* "detect_dead" theorem for num and, when it can be generated, the          *)
(* "detect_dead" theorem for the product of the argument types of the first  *)
(* constructor of t.                                                         *)
fun detect_dead_datatype bottom target t =
let     val detect_def = get_coding_function_def target t "detect"
        val num_dead = get_coding_theorem target num "detect_dead"
        val first_constructor = hd (constructors_of t)
        val arg_types = fst (strip_fun (type_of first_constructor))
        val bottom_thm = #bottom_thm (get_translation_scheme target)
        val unfold_and_rewrite =
                ONCE_REWRITE_CONV [detect_def] THENC
                REWRITE_CONV (bottom_thm :: num_dead ::
                        mapfilter (generate_coding_theorem target "detect_dead" o list_mk_prod) [arg_types])
        val detect_bottom = mk_comb (check_function (get_detect_function target) t, bottom)
in
        unfold_and_rewrite detect_bottom
end
in
(* detect_dead_rule target t: theorem about the detect function of t applied *)
(* to the translation scheme's bottom value.  For the target type itself the *)
(* term is wrapped in the boolean decoder and simplified with the bool       *)
(* "encode_decode" theorem and K_THM.                                        *)
fun detect_dead_rule target t =
let     val bottom = #bottom (get_translation_scheme target)
in
        if t <> target then detect_dead_datatype bottom target t
        else
        let     val simplify = REWRITE_CONV [get_coding_theorem target bool "encode_decode",K_THM]
                val decode_bool = get_decode_function target bool
                val detect_bottom = mk_comb (get_detect_function target t, bottom)
        in
                simplify (mk_comb (decode_bool, detect_bottom))
        end
end
end;
1895
(* all_id_tactic _ : proves an "all_id" goal.  The type is recovered from    *)
(* the goal (the operand of the RHS of its innermost equation), so the type  *)
(* argument is not used.  The proof strips the goal, unfolds the "all"       *)
(* definitions of that type and of pairs, and rewrites with K_THM and every  *)
(* "all_id" theorem that can be generated for the relevant types.            *)
fun all_id_tactic _ (a,g) =
let     val goal_type =
                (type_of o rand o rhs o snd o strip_forall o snd o strip_imp o snd o strip_forall) g
        val all_id_thms = mapfilter (generate_source_theorem "all_id") (relevant_types goal_type)
        val all_def = get_source_function_def goal_type "all"
        val pair_all_def = get_source_function_def (mk_prod(alpha,beta)) "all"
        val tac = REPEAT STRIP_TAC THEN
                  REWRITE_TAC [all_def,pair_all_def] THEN
                  ASM_REWRITE_TAC (K_THM :: all_id_thms)
in
        tac (a,g)
end
1904
(* map_id_tactic _ : proves a "map_id" goal.  The type is recovered from the *)
(* goal (the operand of the RHS of its innermost equation), so the type      *)
(* argument is not used.  The proof strips the goal, unfolds the "map"       *)
(* definitions of that type and of pairs, and rewrites with I_THM and every  *)
(* "map_id" theorem that can be generated for the relevant types.            *)
fun map_id_tactic _ (a,g) =
let     val goal_type =
                (type_of o rand o rhs o snd o strip_forall o snd o strip_imp o snd o strip_forall) g
        val map_id_thms = mapfilter (generate_source_theorem "map_id") (relevant_types goal_type)
        val map_def = get_source_function_def goal_type "map"
        val pair_map_def = get_source_function_def (mk_prod(alpha,beta)) "map"
        val tac = REPEAT STRIP_TAC THEN
                  REWRITE_TAC [map_def,pair_map_def] THEN
                  ASM_REWRITE_TAC (I_THM :: map_id_thms)
in
        tac (a,g)
end;
1913
1914(*****************************************************************************)
1915(* Destructors:                                                              *)
(*     mk_destructors : hol_type -> hol_type -> thm list * thm list          *)
1917(*                                                                           *)
1918(*     Produces destructor theorems and conditional rewrites for resolving   *)
1919(*     them. This will also produce predicate theorems, eg:                  *)
1920(*         |- (FST o sexp_to_pair nat f o pair g) (Ci ...) = i               *)
1921(*                                                                           *)
1922(*****************************************************************************)
1923
(* MK_FST: from |- l = r with l of a product type, derives                   *)
(* |- FST l = FST r.                                                         *)
fun MK_FST thm =
let     val pair_ty = type_of (lhs (concl thm))
        val first_ty = fst (dest_prod pair_ty)
in
        AP_TERM (mk_const("FST", pair_ty --> first_ty)) thm
end;
(* MK_SND: from |- l = r with l of a product type, derives                   *)
(* |- SND l = SND r.                                                         *)
fun MK_SND thm =
let     val pair_ty = type_of (lhs (concl thm))
        val second_ty = snd (dest_prod pair_ty)
in
        AP_TERM (mk_const("SND", pair_ty --> second_ty)) thm
end;
1932
local
(* PRODUCTS n thm: from |- l = r with r of a right-nested pair type, builds  *)
(*     [|- FST l = FST r, |- FST (SND l) = FST (SND r), ..., remainder]      *)
(* using at most n FST projections.  If a projection fails (the current side *)
(* is not a pair), the current theorem is returned as the final element;     *)
(* the catch-all handler is what stops the recursion in that case.           *)
fun PRODUCTS 0 thm = [thm]
  | PRODUCTS n thm =
    MK_FST thm :: PRODUCTS (n - 1) (MK_SND thm) handle _ => [thm];
(* O_CONV c: folds a chain of applications f1 (f2 (... (fn (c args)))) into *)
(* a composition (f1 o f2 o ... o fn) (c args), using o_THM backwards.  It   *)
(* stops as soon as the operator, or the operator of the argument, mentions  *)
(* c, or when the term does not have that shape.                             *)
fun O_CONV c term =
    if free_in c (rator term) orelse free_in c (rator (rand term))
       handle _ => true
       then ALL_CONV term
       else (RAND_CONV (O_CONV c) THENC REWR_CONV (GSYM o_THM)) term
(* dest_filter: keeps only the theorems whose right-hand side is one of the  *)
(* arguments of the (constructor) application under the left-hand side;      *)
(* theorems that do not have this shape are dropped.                         *)
fun dest_filter l =
    filter (fn thm =>
           mem (rhs (concl thm)) (snd (strip_comb (rand (lhs (concl thm)))))
           handle e => false) l
(* Raised when a constructor's encoding does not go through a pair type;    *)
(* such constructors are skipped by mapf.                                    *)
exception NotAPair
(* mk_single_destructor target t c: for constructor c of t applied to fresh  *)
(* variables a, b, ..., uses the matching clause of t's encode definition    *)
(* (which must encode through a pair type, else NotAPair) to produce         *)
(*   - chosen: encode_prod (decode_prod (encode_t x)) = encode_t x, with its *)
(*     hypotheses (including ?a b ... . x = c a b ...) discharged by         *)
(*     DISCH_ALL_CONJ;                                                       *)
(*   - destructor theorems such as |- (FST o decode_prod o encode_t) (c ...) *)
(*     = a, one for each argument that can be projected out.                 *)
fun mk_single_destructor target t c =
let val (types,_) = strip_fun (type_of c)
    val args = map (fn (n,t) => mk_var((implode o base26) n,t))
                   (enumerate 0 types);
    val cons = list_mk_comb(c,args)
    val encoders = CONJUNCTS (get_coding_function_def target t "encode")
    val e = SPEC_ALL (tryfind (C (PART_MATCH (rand o lhs)) cons) encoders)
    val encoder = PART_MATCH (rator o lhs) e (get_encode_function target t)
    val product = type_of (rand (rhs (concl encoder)))
                  handle _ => raise NotAPair
    val _ = if can pairLib.dest_prod product then () else raise NotAPair
    (* decode_prod (encode_t cons) = tuple *)
    val applied = AP_TERM (get_decode_function target
                          product) (SPEC_ALL encoder);
    val decoder = SPEC_ALL (FULL_ENCODE_DECODE_THM target product)
    val decoder' = PART_MATCH (lhs o snd o strip_forall o snd o strip_imp)
                              decoder (rhs (concl applied))
    val decoder'' = SPEC_ALL (UNDISCH decoder' handle _ => decoder')
    val rewritten = RIGHT_CONV_RULE (REWR_CONV decoder'') applied
    val product_encoder = get_encode_function target product;
    (* encode_prod (decode_prod (encode_t cons)) = encode_t cons *)
    val apped = RIGHT_CONV_RULE (REWR_CONV (GSYM encoder))
                               (AP_TERM product_encoder rewritten);
    val var = variant (thm_frees apped) (mk_var("x",t));
    val eterm = list_mk_exists(args,mk_eq(var,cons))
    val rapped = PURE_REWRITE_RULE [GSYM (ASSUME (mk_eq(var,cons)))] apped;
    val chosen = DISCH_ALL_CONJ (CHOOSE_L (args,ASSUME eterm) rapped);
in
    (chosen,dest_filter (map (CONV_RULE (LAND_CONV (O_CONV c)) o
                RIGHT_CONV_RULE (REWRITE_CONV [FST,SND]))
                    (PRODUCTS (length args) rewritten)))
end
(* mapf f l: maps f over l, dropping the elements for which f raises        *)
(* NotAPair; any other exception propagates.                                 *)
fun mapf f [] = []
  | mapf f (x::xs) =
let val r = mapf f xs
in  (f x :: r) handle NotAPair => r end
in
(* mk_destructors target t: the (chosen, destructor) theorems of every      *)
(* constructor of t that is encoded through a pair, with all destructor      *)
(* theorems flattened into one list.                                         *)
fun mk_destructors target t =
let val (chosen,destructors) =
        unzip (mapf (mk_single_destructor target t) (constructors_of t))
        handle e => wrapException "mk_destructors" e
in
    (chosen, flatten destructors)
end
end
1991
1992(*****************************************************************************)
1993(* Initialisation:                                                           *)
1994(*     initialise_source_function_generators : unit -> unit                  *)
(*     initialise_coding_function_generators : hol_type -> unit              *)
(*     initialise_coding_theorem_generators : hol_type -> unit               *)
1996(*                                                                           *)
1997(*****************************************************************************)
1998
(* Registers the compound source-function generators "map" and "all".       *)
fun initialise_source_function_generators () =
let     (* Conversion for "all" terms: rewrites with the pair "all"          *)
        (* definition on the operand of every (quantifier-stripped)          *)
        (* conjunct, returning a reflexive theorem if nothing changes.       *)
        fun pair_all_conv tm =
                EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV
                        (PURE_REWRITE_CONV [get_source_function_def (mk_prod(alpha,beta)) "all"]))) tm
                handle UNCHANGED => REFL tm
        val _ = add_compound_source_function_generator "map"
                        mk_map_term get_map_function REFL REFL
        val _ = add_compound_source_function_generator "all"
                        mk_all_term get_all_function pair_all_conv REFL
in
        ()
end;
2016
(* Registers, for the given target type, the generators for the "encode",   *)
(* "detect", "decode" and "fix" coding functions and the rule generator for  *)
(* "detect_dead" theorems.                                                   *)
fun initialise_coding_function_generators target =
let     (* Term builder for "decode": first makes sure the "map" source      *)
        (* function and the "encode" coding function of the base type exist. *)
        fun decode_term_with_deps t =
                (generate_source_function "map" (base_type t) ;
                 generate_coding_function target "encode" (base_type t) ;
                 mk_decode_term target t)
        (* "encode": its conversion unfolds the pair encode definition *)
        val _ = add_compound_coding_function_generator "encode"
                        (mk_encode_term target) (get_encode_function target)
                        (ENCODE_CONV (get_coding_function_def target (mk_prod(alpha,beta)) "encode"))
                        REFL target
        val _ = add_compound_target_function_generator "detect"
                        (mk_detect_term target) (get_detect_function target)
                        (DETECT_CONV target) REFL target
        (* "detect_dead" theorems, for any type on which constructors_of succeeds *)
        val _ = add_rule_coding_theorem_generator "detect_dead"
                        (can constructors_of) (detect_dead_rule target) target
        val _ = add_compound_target_function_generator "decode"
                        decode_term_with_deps (get_decode_function target)
                        (DECODE_CONV target) (CONSOLIDATE_CONV I) target
        val _ = add_compound_target_function_generator "fix"
                        (mk_fix_term target) (get_fix_function target)
                        (FIX_CONV target) (CONSOLIDATE_CONV rand) target
in
        ()
end
2051
(* initialise_coding_theorem_generators target: registers, for the target   *)
(* type, (1) the conclusion builders of every source and coding theorem,    *)
(* (2) the inductive generators that prove them with the tactics above, and *)
(* (3) rule generators that derive them by other means.  Registration order *)
(* is kept as written.                                                       *)
fun initialise_coding_theorem_generators target =
let     (* (1) conclusion builders *)
        val _ = set_coding_theorem_conclusion
                target "encode_detect_all"
                (mk_encode_detect_all_conc target);
        val _ = set_source_theorem_conclusion
                "all_id" mk_all_id_conc;
        val _ = set_source_theorem_conclusion
                "map_id" mk_map_id_conc;
        val _ = set_source_theorem_conclusion
                "map_compose" mk_map_compose_conc;
        val _  = set_coding_theorem_conclusion
                target "encode_decode_map" (mk_encode_decode_map_conc target);
        val _ = set_coding_theorem_conclusion
                target "encode_map_encode" (mk_encode_map_encode_conc target);
        val _ = set_coding_theorem_conclusion
                target "general_detect" (mk_general_detect_conc target);
        val _ = set_coding_theorem_conclusion
                target "decode_encode_fix" (mk_decode_encode_fix_conc target);
        val _ = set_coding_theorem_conclusion
                target "fix_id" (mk_fix_id_conc target);

        (* (2) inductive generators: function name, theorem name, the      *)
        (* conversion applied to the conclusion (FUN_EQ_CONV or REFL) and   *)
        (* the proving tactic                                               *)
        val _ = add_inductive_coding_theorem_generator
                "encode" "encode_detect_all"
                target FUN_EQ_CONV
                (encode_detect_all_tactic target);
        val _ = add_inductive_source_theorem_generator
                "all" "all_id"
                FUN_EQ_CONV all_id_tactic;
        val _ = add_inductive_source_theorem_generator
                "map" "map_id"
                FUN_EQ_CONV map_id_tactic;
        val _ = add_inductive_source_theorem_generator
                "map" "map_compose"
                FUN_EQ_CONV map_compose_tactic;
        val _ = add_inductive_coding_theorem_generator
                "encode" "encode_decode_map"
                target FUN_EQ_CONV
                (encode_decode_map_tactic target);
        val _ = add_inductive_coding_theorem_generator
                "encode" "encode_map_encode"
                target FUN_EQ_CONV
                (encode_map_encode_tactic target);
        val _ = add_inductive_coding_theorem_generator
                "detect" "general_detect"
                target REFL
                (general_detect_tactic target);
        val _ = add_inductive_coding_theorem_generator
                "decode" "decode_encode_fix"
                target FUN_EQ_CONV
                (decode_encode_fix_tactic target);
        val _ = add_inductive_coding_theorem_generator
                "fix" "fix_id"
                target REFL
                (fix_id_tactic target);

        (* true when the theorem already exists for t, or when no induction *)
        (* is available for the named coding function at t                 *)
        fun check_target_rule_use function_name theorem_name t =
            (exists_coding_theorem target t theorem_name) orelse
            not (can (C (get_coding_function_induction target) function_name) t)

        (* source-function counterpart of check_target_rule_use *)
        fun check_source_rule_use function_name theorem_name t =
            (exists_source_theorem t theorem_name) orelse
            not (can (C get_source_function_induction function_name) t)

        (* (3) rule generators, guarded by the checks above *)
        val _ = add_rule_coding_theorem_generator
                "encode_detect_all"
                (check_target_rule_use "encode" "encode_detect_all")
                (FULL_ENCODE_DETECT_ALL_THM target)
                target;
        val _ = add_rule_source_theorem_generator
                "all_id"
                (check_source_rule_use "all" "all_id")
                FULL_ALL_ID_THM;
        val _ = add_rule_source_theorem_generator
                "map_id"
                (check_source_rule_use "map" "map_id")
                FULL_MAP_ID_THM;
        val _ = add_rule_source_theorem_generator
                "map_compose"
                (check_source_rule_use "map" "map_compose")
                FULL_MAP_COMPOSE_THM;
        val _ = add_rule_coding_theorem_generator
                "encode_decode_map"
                (check_target_rule_use "encode" "encode_decode_map")
                (FULL_ENCODE_DECODE_MAP_THM target)
                target;
        val _ = add_rule_coding_theorem_generator
                "encode_map_encode"
                (check_target_rule_use "encode" "encode_map_encode")
                (FULL_ENCODE_MAP_ENCODE_THM target)
                target;
        (* general_detect has no FULL_* rule: the rule only returns a       *)
        (* theorem that is already stored                                   *)
        val _ = add_rule_coding_theorem_generator
                "general_detect"
                (C (exists_coding_theorem target) "general_detect")
                (C (get_coding_theorem target) "general_detect")
                target;
        val _ = add_rule_coding_theorem_generator
                "decode_encode_fix"
                (check_target_rule_use "decode" "decode_encode_fix")
                (FULL_DECODE_ENCODE_FIX_THM target)
                target;
        val _ = add_rule_coding_theorem_generator
                "fix_id"
                (check_target_rule_use "fix" "fix_id")
                (FULL_FIX_ID_THM target)
                target;
in
        ()
end;
2160
(* encode_type target t: makes sure the base type t has a translation to     *)
(* target and generates, in order, its "map" and "all" source functions,     *)
(* its encode/decode/detect/fix coding functions and its encode_detect_all,  *)
(* encode_map_encode, encode_decode_map, decode_encode_fix and fix_id        *)
(* theorems.  Every failure is reported through wrapException.               *)
fun encode_type target t =
let     fun gen_source name = let val _ = generate_source_function name t in () end
        fun gen_coding name = let val _ = generate_coding_function target name t in () end
        fun gen_theorem name = let val _ = generate_coding_theorem target name t in () end
in
        (if can (match_type t) (base_type t) then ()
         else raise (mkDebugExn "encode_type"
                        "encode_type should only be applied to base types")) ;
        (if exists_translation target t then () else add_translation target t) ;
        List.app gen_source ["map","all"] ;
        List.app gen_coding ["encode","decode","detect","fix"] ;
        List.app gen_theorem ["encode_detect_all","encode_map_encode",
                              "encode_decode_map","decode_encode_fix","fix_id"]
end     handle e => wrapException "encode_type" e
2184
local
(* GENCF name f target t: returns f target t.  When t is not the target type *)
(* and no coding function called name exists for it yet, encode_type is run  *)
(* on the base type of t first.  The handler deliberately covers the whole   *)
(* body, including the target = t case, so every failure is reported as     *)
(* "gen_<name>_function".  Previously, by SML precedence, the handler        *)
(* attached only to the else branch.                                         *)
fun GENCF name f target t =
    (if target = t orelse exists_coding_function target t name
        then f target t
        else (encode_type target (base_type t) ; f target t))
    handle e => wrapException ("gen_" ^ name ^ "_function") e
in
val gen_encode_function = GENCF "encode" get_encode_function
val gen_decode_function = GENCF "decode" get_decode_function
val gen_detect_function = GENCF "detect" get_detect_function
end;
2198
2199(*****************************************************************************)
2200(* predicate_equivalence : hol_type -> hol_type -> thm                       *)
2201(*     Returns a theorem of the form:                                        *)
2202(*         |- (!x. P x) = (!x. detect x ==> P (decode x))                    *)
2203(*     for a type t.                                                         *)
(*     This can then be used to derive a fully encoded theorem, using        *)
(*     implication rules and the encoding of booleans.                       *)
2206(*                                                                           *)
2207(*****************************************************************************)
2208
fun predicate_equivalence target t =
let (* Both variables are named "x" but live at different types: t for the  *)
    (* source side and target for the encoded side.                         *)
    val pred_var = mk_var("P",t --> bool)
    val src_x = mk_var("x",t);
    val tgt_x = mk_var("x",target);
    val detect_tm = mk_comb(get_detect_function target t,tgt_x)
    val decode_tm = mk_comb(get_decode_function target t,tgt_x)
    val encode_tm = mk_comb(get_encode_function target t,src_x)
    val forall_pred = mk_forall(src_x,mk_comb(pred_var,src_x))

    (* (!x. P x) |- !x. detect x ==> P (decode x) *)
    val to_detect =
        GEN tgt_x (DISCH detect_tm (SPEC decode_tm (ASSUME forall_pred)))

    (* (!x. detect x ==> P (decode x)) |- !x. P x: specialise to encode x  *)
    (* and rewrite with the encode/detect and encode/decode theorems.      *)
    (* Their side conditions stay as hypotheses and are discharged below   *)
    (* by DISCH_ALL_CONJ.                                                  *)
    val enc_detect = UNDISCH_ALL (SPEC_ALL (FULL_ENCODE_DETECT_THM target t))
    val enc_decode = UNDISCH_ALL (SPEC_ALL (FULL_ENCODE_DECODE_THM target t))
    val from_detect =
        GEN src_x (PURE_REWRITE_RULE [enc_detect,enc_decode,IMP_CLAUSES]
                       (SPEC encode_tm (ASSUME (concl to_detect))))
in
    DISCH_ALL_CONJ (IMP_ANTISYM_RULE (DISCH (concl from_detect) to_detect)
                                     (DISCH (concl to_detect) from_detect))
end handle e => wrapException "predicate_equivalence" e
2229
2230
2231end