structure encodeLib :> encodeLib =
struct

open Thm Term Type boolSyntax Parse Conv Rewrite Drule
open Tactic Tactical pairLib numLib polytypicLib
open Binarymap List Lib
open boolTheory pairTheory listTheory combinTheory

open bossLib metisLib;

(*****************************************************************************)
(* construct_bottom_value : (hol_type -> bool) -> term -> hol_type -> term  *)
(*     Given a predicate to indicate stopping, and a stop term, constructs   *)
(*     a value of the type given. Sub-types satisfying the predicate are     *)
(*     filled with the stop term (its type instantiated to match). For any   *)
(*     other type, a registered "bottom-cons" constant is used if present,   *)
(*     otherwise the type's constructors are tried in order. The permitted   *)
(*     constructor nesting depth is increased one level at a time, so the    *)
(*     first constructor giving a value at the smallest depth is chosen.     *)
(*                                                                           *)
(*     Eg. construct_bottom_value (curry op= bool) F                         *)
(*             ``:bool # (bool + bool)`` = ``(F,INL F)``                     *)
(*                                                                           *)
(* target_bottom_value : hol_type -> term -> hol_type -> term                *)
(*     Constructs the term using 'ARB' values, then replaces them with the   *)
(*     given term once complete.                                             *)
(*                                                                           *)
(*     Eg. target_bottom_value bool F ``:'a # ('b + 'c)`` = ``(F,INL F)``    *)
(*                                                                           *)
(*****************************************************************************)

local
(* Raised when the depth bound is reached before a value could be built. *)
exception Bottom;
(* Return f applied to the first list element for which it succeeds.     *)
(* Empty from f means that element can never give a value; Bottom means  *)
(* it ran out of depth. When the list is exhausted, Bottom is re-raised  *)
(* if any element ran out of depth (b), since a deeper search may then   *)
(* succeed; otherwise Empty is raised.                                   *)
fun tryfind b f [] = if b then raise Bottom else raise Empty
  | tryfind b f (x::xs) =
      (f x) handle Empty => tryfind b f xs
                 | Bottom => tryfind true f xs
(* Instantiate term's type so that its final result type matches t. *)
fun match term t = inst (match_type (snd (strip_fun (type_of term))) t) term
(* Build a value of type t using at most n nested constructor levels. *)
fun construct_bottom_value_n 0 p xvar _ = raise Bottom
  | construct_bottom_value_n n p xvar t =
      if p t then match xvar t
      else let val cs = [match (get_source_function_const t "bottom-cons") t]
                        handle _ => (constructors_of t handle e => [])
           in  tryfind false
                 (fn c => full_beta (list_mk_comb(c,
                     map (construct_bottom_value_n (n - 1) p xvar)
                         (fst (strip_fun (type_of c)))))) cs
           end
(* Iterative deepening: retry with a bound one larger while Bottom is    *)
(* raised. NOTE(review): this loops forever if every constructor path    *)
(* hits the bound at every depth; presumably HOL's requirement that      *)
(* datatypes be non-empty rules this out -- confirm.                     *)
fun itdeep f n = f n handle Bottom => itdeep f (n + 1)
(* Sanity check: the constructed value must have exactly type t. *)
fun check value t =
    if type_of value = t then value
    else raise (mkDebugExn "construct_bottom_value"
                ("Constructed a term of type: " ^ type_to_string (type_of value) ^
                 "\nwhen a value of type: " ^ type_to_string t ^
                 "\nshould have been returned!!"))
in
fun construct_bottom_value p xvar t =
    check (itdeep (fn n => construct_bottom_value_n n p xvar t) 0
           handle Empty => raise (mkStandardExn "construct_bottom_value"
                                  ("Could not find bottom values for all sub-types of "
                                   ^ type_to_string t))
                | e => wrapException "construct_bottom_value" e) t
end;

(* Build the bottom value with ARB:alpha at every type-variable position, *)
(* instantiate the type of each such ARB to target, then replace          *)
(* ARB:target by bottom_target. Errors are wrapped with this function's   *)
(* name, like the other public functions in this structure.               *)
fun target_bottom_value target bottom_target t =
let val b1 = construct_bottom_value is_vartype (mk_arb alpha) t
    val arbs = HolKernel.find_terms is_arb b1
    val types = mk_set (map type_of arbs)
    val b2 = inst (map (fn t => t |-> target) types) b1
in
    subst [mk_arb target |-> bottom_target] b2
end handle e => wrapException "target_bottom_value" e

(*****************************************************************************)
(* set_bottom_value : hol_type -> term -> unit                               *)
(*     Set the bottom value for the type given, this is only required for    *)
(*     non-recursive types.                                                  *)
(*                                                                           *)
(*****************************************************************************)

(* Registers term as the "bottom-cons" source function for t, which      *)
(* construct_bottom_value consults before trying the type's constructors. *)
fun set_bottom_value t term =
    add_source_function_precise
        t "bottom-cons"
        {const = term, definition = TRUTH,induction = NONE}
    handle e => wrapException "set_bottom_value" e

(*****************************************************************************)
(* Generation of the encoding functions                                      *)
(*                                                                           *)
(* get_encode_type, get_decode_type, get_detect_type, get_fix_type           *)
(*       : hol_type -> hol_type -> hol_type                                  *)
(* get_map_type, get_all_type                                                *)
(*       : hol_type -> hol_type                                              *)
(* mk_encode_var, mk_decode_var, mk_detect_var, mk_fix_var                   *)
(*       : hol_type -> hol_type -> term                                      *)
(* mk_map_var, mk_all_var                                                    *)
(*       : hol_type -> term                                                  *)
(*     Returns the type of an encoding, decoding, detecting, fixing,         *)
(*     mapping or 'all' constant, and makes a variable for a prospective     *)
(*     constant                                                              *)
(*                                                                           *)
(* mk_encode_term : hol_type -> hol_type -> term                             *)
(*     Makes a full encoding term for the translation given:                 *)
(*     Single constructor: enc (C a0 a1) = P enc0 enc1 (a0,a1)               *)
(*     Label constructors: enc Cn = nat n                                    *)
(*     Otherwise         : enc (Ci a0 a1) = P nat (P enc0 enc1) (i,a0,a1)    *)
(*                         (a nullary Ci is encoded as P nat bool (i,F))     *)
(*                                                                           *)
(* mk_decode_term : hol_type -> hol_type -> term                             *)
(*     Makes a full decoding term for the translation given:                 *)
(*     Single constructor: dec x = if det x then                             *)
(*                                    let (a,b) = D dec0 dec1 x in (C a b)   *)
(*                                 else bottom                               *)
(*     Label constructors: dec x = if dnum x then                            *)
(*                                    (if dnat x = 0 then C0 else .... C0)   *)
(*                                 else C0                                   *)
(*     Otherwise         : dec x =                                           *)
(*                           if dpair x then                                 *)
(*                              let (l,r) = D dnat I x                       *)
(*                              in  if l = 0 then                            *)
(*                                     let (a,b) = D dec0 dec1 r in (C a b)  *)
(*                                  else ... else bottom                     *)
(*                           else bottom                                     *)
(*     where bottom = map dec0 dec1 (C nil nil)                              *)
(*                                                                           *)
(* mk_detect_term : hol_type -> hol_type -> term                             *)
(*     Makes a full detecting term for the translation given:                *)
(*     Single constructor: det x = P det0 det1 x                             *)
(*     Label constructors: det x = if dnum x then                            *)
(*                                    (if dnat x = 0 then T else ...         *)
(*                                     else if dnat x = n then T else F)     *)
(*                                 else F                                    *)
(*     Otherwise         : det x =                                           *)
(*                           if dpair x then                                 *)
(*                              let (l,r) = D dnat I x                       *)
(*                              in  if l = 0 then P det0 det1 r              *)
(*                                  else ... else F                          *)
(*                           else F                                          *)
(*                         (a nullary constructor is detected by r = enc F)  *)
(*                                                                           *)
(* mk_map_term : hol_type -> term                                            *)
(*     Makes a full map function for the given type:                         *)
(*     Label constructors: map Li = Li                                       *)
(*     Otherwise         : map (C a0 .. an) = (map0 # .. # mapn) (a0,..,an)  *)
(*                                                                           *)
(* mk_all_term : hol_type -> term                                            *)
(*     Makes a full all function for the given type:                         *)
(*     Label constructors: all Li = T                                        *)
(*     Otherwise         : all (C a0 .. an) = (all0 # .. # alln) (a0,..,an)  *)
(*                                                                           *)
(* mk_fix_term : hol_type -> hol_type -> term                                *)
(*     Single constructor: fix x = if det x then TP fix0 (TP ... fixn) .. ) x*)
(*                                 else bottom                               *)
(*     Label constructors: fix x = if dnum x then                            *)
(*                                    (if dnat x = 0 then x else ... nat 0)  *)
(*                                 else nat 0                                *)
(*     Otherwise         : fix x =                                           *)
(*                           if dpair x then                                 *)
(*                              let (l,r) = D dnat I x                       *)
(*                              in  if l = 0 then fix_(num # a0 # .. # an) x *)
(*                                  else ... else bottom                     *)
(*                           else bottom                                     *)
(*                         (a nullary Ci is fixed to P nat bool (i,F))       *)
(*     where bottom = enc fix0 .. fixn (C nil nil)                           *)
(*                                                                           *)
(* get_encode_function, get_decode_function, get_detect_function,            *)
(* get_fix_function : hol_type -> hol_type -> term                           *)
(* get_map_function, get_all_function : hol_type -> term                     *)
(*     Gets a fully instantiated term to translate the type                  *)
(*                                                                           *)
(* DECODE_CONV, DETECT_CONV, FIX_CONV : hol_type -> term -> thm              *)
(* ENCODE_CONV : thm -> term -> thm                                          *)
(*     Given the target type (for ENCODE_CONV: the pair encoding             *)
(*     definition), each conv rewrites to convert a term given to it by      *)
(*     mk_..._term to a form suitable for split_function.                    *)
(*                                                                           *)
(* CONSOLIDATE_CONV : (term -> term) -> term -> thm                          *)
(*     Given a conjunction of functions with instantiated bottom values, ie  *)
(*     of the form:                                                          *)
(*         f0 x = if .. then .. else B0 /\                                   *)
(*         ...                                                               *)
(*         fn x = if .. then .. else Bn                                      *)
(*     where B0...Bn may contain references to f0...fn, continually          *)
(*     rewrites with each definition to get B0'...Bn' that don't contain     *)
(*     references to f0 ... fn. Only works for 'decode' and 'fix'!           *)
(*                                                                           *)
(* mk_encodes, mk_decodes, mk_detects, mk_fixs : hol_type -> hol_type-> unit *)
(* mk_maps, mk_alls : hol_type -> unit                                       *)
(*     Generate the functions given. Shouldn't really be used, as they don't *)
(*     use the generator system, and will hence fail if functions are        *)
(*     missing.                                                              *)
(*                                                                           *)
(*****************************************************************************)

local
(* Type of a coding function for t: one argument per type parameter of t *)
(* (built by opr from that parameter and target), then the function for  *)
(* t itself, also built by opr.                                          *)
fun get_gen_type opr target t =
    foldr (fn (x,t) => (opr (x,target)) --> t) (opr (t,target))
          (if is_vartype t then [] else snd (dest_type t))
in
fun get_encode_type target t = get_gen_type op--> target t
    handle e => wrapException "get_encode_type" e
fun get_decode_type target t = get_gen_type (uncurry (C (curry op-->))) target t
    handle e => wrapException "get_decode_type" e
fun get_detect_type target t = get_gen_type (fn (_,a) => a --> bool) target t
    handle e => wrapException "get_detect_type" e
(* A map's result type replaces every type variable of t by a fresh     *)
(* variable named "'map_" followed by get_type_string of the original.  *)
fun get_map_type t =
let val tyvars = type_vars t
    fun gentvar t = (mk_vartype o curry op^ "'map_" o get_type_string) t
    val new_vars = map gentvar tyvars
    val t' = type_subst (map2 (curry op|->) tyvars new_vars) t
in
    foldr (fn ((x,y),t) => (x --> y) --> t) (t --> t')
          (if is_vartype t then []
           else zip (snd (dest_type t)) (snd (dest_type t')))
end handle e => wrapException "get_map_type" e
fun get_all_type t = get_gen_type (fn (a,_) => a --> bool) t t
    handle e => wrapException "get_all_type" e
fun get_fix_type target t = get_gen_type (fn (_,_) => target --> target) target t
    handle e => wrapException "get_fix_type" e
end

local
fun mk_encode_string t = "encode" ^ (get_type_string t)
fun mk_decode_string t = "decode" ^ (get_type_string t)
fun mk_detect_string t = "detect" ^ (get_type_string t)
fun mk_map_string t = "map" ^ (get_type_string t)
fun mk_fix_string t = "fix" ^ (get_type_string t)
fun mk_all_string t = "all" ^ (get_type_string t)
in
(* Variables standing for coding functions that have no constant yet. *)
fun mk_encode_var target t =
    mk_var(mk_encode_string t,get_encode_type target t)
    handle e => wrapException "mk_encode_var" e
fun mk_decode_var target t =
    mk_var(mk_decode_string t,get_decode_type target t)
    handle e => wrapException "mk_decode_var" e
fun mk_detect_var target t =
    mk_var(mk_detect_string t,get_detect_type target t)
    handle e => wrapException "mk_detect_var" e
fun mk_map_var t =
    mk_var(mk_map_string t,get_map_type t)
    handle e => wrapException "mk_map_var" e
fun mk_fix_var target t =
    mk_var(mk_fix_string t,get_fix_type target t)
    handle e => wrapException "mk_fix_var" e
fun mk_all_var t =
    mk_var(mk_all_string t,get_all_type t)
    handle e => wrapException "mk_all_var" e
end

local
(* With SOME f, instantiate const so the type selected by f from its   *)
(* type matches t; with NONE, return const unchanged.                  *)
fun new_const NONE const t = const
  | new_const (SOME match) const t =
      safe_inst (match_type (match (type_of const)) t) const
in
(* When t is already the target type the identity function is used. *)
fun get_encode_const target t =
    if t = target
       then mk_const("I",t --> target)
       else new_const (SOME (last o fst o strip_fun))
                      (get_coding_function_const target t "encode") t
            handle e => wrapException "get_encode_const" e
fun get_decode_const target t =
    if t = target
       then mk_const("I",target --> t)
       else new_const (SOME (snd o strip_fun))
                      (get_coding_function_const target t "decode") t
            handle e => wrapException "get_decode_const" e;
fun get_map_const t =
    new_const (SOME (last o fst o strip_fun))
              (get_source_function_const t "map") t
    handle e => wrapException "get_map_const" e;
fun get_fix_const target t =
    if t = target
       then mk_const("I",target --> t)
       else new_const NONE
                      (get_coding_function_const target t "fix") t
            handle e => wrapException "get_fix_const" e;
fun get_all_const t =
    new_const (SOME (last o fst o strip_fun))
              (get_source_function_const t "all") t
    handle e => wrapException "get_all_const" e;
(* Every term of the target type is accepted: K T. *)
fun get_detect_const target t =
    if t = target
       then mk_comb(mk_const("K",bool --> target --> bool),T)
       else new_const NONE (get_coding_function_const target t "detect") t
            handle e => wrapException "get_detect_const" e
end

local
(* Rename the type variables of term that do not occur in basetype to   *)
(* fresh type variables.                                                *)
fun fix_base basetype term =
let val types = set_diff (type_vars_in_term term) (type_vars basetype)
in  inst (map (fn x => x |-> gen_tyvar()) types) term
end;
(* mk_comb after instantiating main's domain type to match p. *)
fun imk_comb (main,p) =
    mk_comb(inst (match_type (fst (dom_rng (type_of main))) (type_of p)) main,p)
(* Type variables of t, left to right, with repeats. *)
fun typevars_lr t =
    if is_vartype t then [t]
    else flatten (map typevars_lr (snd (dest_type t)));
(* the_value instantiated at type t itself. *)
fun mk_the_value t =
    inst (match_type (type_of the_value) (mk_type("itself",[t]))) the_value;
(* Build the function term of one kind for type t:                      *)
(*   1. find the most precise type with a registered function (fexists), *)
(*      falling back to t itself for a type variable, else base_type t;  *)
(*   2. take its constant (fconst), or a fresh variable (mvar) if none;  *)
(*   3. for a non-variable t, instantiate it to t and apply it, left to  *)
(*      right, to the function for each remaining type parameter; if     *)
(*      that fails, to an itself value of the parameter type; if that    *)
(*      fails too, the parameter is skipped.                             *)
fun get_function fconst fexists mvar t =
let val basetype = (most_precise_type fexists t)
                   handle _ => (if is_vartype t then t else base_type t)
    val base = fconst basetype handle _ => mvar t
    val params = set_diff (typevars_lr basetype) [t]
    val match = match_type basetype t
    val param_list = map (type_subst match) params
    val insted = inst match (fix_base basetype base)
in
    if not (is_vartype t)
       then foldl (fn (a,t) =>
                     imk_comb(t,get_function fconst fexists mvar a) handle _ =>
                     imk_comb(t,mk_the_value a) handle _ => t)
                  insted param_list
       else base
end
in
fun get_encode_function target t =
    get_function (get_encode_const target)
                 (C (exists_coding_function_precise target) "encode")
                 (mk_encode_var target) t
    handle e => wrapException "get_encode_function" e
fun get_decode_function target t =
    get_function (get_decode_const target)
                 (C (exists_coding_function_precise target) "decode")
                 (mk_decode_var target) t
    handle e => wrapException "get_decode_function" e
fun get_detect_function target t =
    get_function (get_detect_const target)
                 (C (exists_coding_function_precise target) "detect")
                 (mk_detect_var target) t
    handle e => wrapException "get_detect_function" e
fun get_map_function t =
    get_function get_map_const
                 (C exists_source_function_precise "map")
                 mk_map_var t
    handle e => wrapException "get_map_function" e
fun get_fix_function target t =
    get_function (get_fix_const target)
                 (C (exists_coding_function_precise target) "fix")
                 (mk_fix_var target) t
    handle e => wrapException "get_fix_function" e
fun get_all_function t =
    get_function get_all_const
                 (C exists_source_function_precise "all")
                 mk_all_var t
    handle e => wrapException "get_all_function" e
end

local
(* Detect the argument tuple of constructor C held in rest. *)
fun mk_detect_constructor_ns rest target C =
let val (list,_) = strip_fun (type_of C)
in
    mk_comb(get_detect_function target (list_mk_prod list),rest)
end
(* For a nullary constructor the supplied term T is used instead. *)
fun mk_detect_constructor rest target C T =
    if can dom_rng (type_of C) then mk_detect_constructor_ns rest target C else T
(* if label = 0 then .. else if label = 1 then .. else F *)
fun mk_detect_res_term label rest target t constructors T =
    list_mk_cond
        (map (fn (a,b) => (mk_eq(label,numLib.term_of_int a),
                           mk_detect_constructor rest target b T))
             (enumerate 0 constructors))
        F
fun mk_detect_term_label (p,x) target t constructors =
let val dnat = get_decode_function target num
    val rnat = get_detect_function target num
in
    mk_forall(x,mk_eq(mk_comb(get_detect_function target t,x),
        mk_cond(mk_comb(rnat,x),
                mk_detect_res_term (mk_comb(dnat,x)) x target t constructors T,
                F)))
end
fun mk_detect_term_all (p,x) target t constructors =
let val label = mk_var("l",num)
    val rest = mk_var("r",target)
    (* A nullary constructor's payload must be the encoding of F. *)
    val null = mk_comb(get_encode_function target bool,F)
in
    list_mk_forall(snd (strip_comb (get_detect_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_detect_function target t,x),
            mk_cond(p,pairSyntax.mk_anylet (
                [(mk_pair(label,rest),
                  mk_comb(get_decode_function target (mk_prod(num,target)),x))],
                mk_detect_res_term label rest target t constructors (mk_eq(rest,null))),
                F))))
end
fun mk_detect_term_single (p,x) target t constructor =
    list_mk_forall(snd (strip_comb (get_detect_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_detect_function target t,x),
                          mk_detect_constructor x target constructor T)))
in
fun mk_detect_term target t =
let val t' = base_type t handle e => wrapException "mk_detect_term" e
    val constructors = constructors_of t' handle e => wrapException "mk_detect_term" e
    val x = mk_var("x",target)
    val p = mk_comb(get_detect_function target (mk_prod(num,target)),x)
in
    if all (not o can dom_rng o type_of) constructors
       then mk_detect_term_label (p,x) target t' constructors
            handle e => wrapException "mk_detect_term (label)" e
    else if length constructors = 1
       then mk_detect_term_single (p,x) target t' (hd constructors)
            handle e => wrapException "mk_detect_term (single)" e
    else mk_detect_term_all (p,x) target t' constructors
         handle e => wrapException "mk_detect_term" e
end
end

local
(* map (decoders of the type parameters) applied to the target-type     *)
(* bottom value, i.e. the value a decoder returns for invalid input.    *)
fun full_bottom_value target bottom_target t =
let val bottom = target_bottom_value target bottom_target t
    val mapf = get_map_const t
    val decodef = get_decode_function target t
    val map_function = list_imk_comb(mapf,snd (strip_comb decodef))
    val bottom' = inst (match_type (type_of bottom)
                                   (fst (dom_rng (type_of map_function)))) bottom
in
    mk_comb(map_function,bottom')
end handle e => wrapException "full_bottom_value" e
(* let (a,b,..) = decode_args rest in C a b .. *)
fun mk_decode_constructor_ns rest target C =
let val (list,_) = strip_fun (type_of C)
    val vars = map (mk_var o (implode o base26 ## I)) (enumerate 0 list)
in
    pairSyntax.mk_anylet (
        [(list_mk_pair vars,mk_comb(get_decode_function target (list_mk_prod list),rest))],
        (list_mk_comb(C,vars)))
end
fun mk_decode_constructor rest target C =
    if can dom_rng (type_of C) then mk_decode_constructor_ns rest target C else C
fun mk_decode_res_term label rest target t constructors bottom =
    list_mk_cond
        (map (fn (a,b) => (mk_eq(label,numLib.term_of_int a),
                           mk_decode_constructor rest target b))
             (enumerate 0 constructors))
        bottom
(* For label types the first constructor serves as the bottom value. *)
fun mk_decode_term_label (p,x) target t constructors =
let val dnat = get_decode_function target num
    val rnat = get_detect_function target num
    val bottom = hd constructors
in
    mk_forall(x,mk_eq(mk_comb(get_decode_function target t,x),
        mk_cond(mk_comb(rnat,x),
                mk_decode_res_term (mk_comb(dnat,x)) x target t constructors bottom,
                bottom)))
end
fun mk_decode_term_all (p,x) target t constructors =
let val label = mk_var("l",num)
    val rest = mk_var("r",target)
    val bottom = full_bottom_value target (#bottom(get_translation_scheme target)) t
in
    list_mk_forall(snd (strip_comb (get_decode_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_decode_function target t,x),
            mk_cond(p,pairSyntax.mk_anylet (
                [(mk_pair(label,rest),
                  mk_comb(get_decode_function target (mk_prod(num,target)),x))],
                mk_decode_res_term label rest target t constructors bottom),
                bottom))))
end
fun mk_decode_term_single (p,x) target t constructor =
let val t' = (mk_type o (I ## map (K target)) o dest_type) t
    val p = get_detect_function target t'
in
    list_mk_forall(snd (strip_comb (get_decode_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_decode_function target t,x),
            mk_cond(mk_comb(p,x),
                    mk_decode_constructor x target constructor,
                    full_bottom_value target (#bottom(get_translation_scheme target)) t))))
end
in
fun mk_decode_term target t =
let val t' = base_type t handle e => wrapException "mk_decode_term" e
    val constructors = constructors_of t' handle e => wrapException "mk_decode_term" e
    val x = mk_var("x",target)
    val p = mk_comb(get_detect_function target (mk_prod(num,target)),x)
in
    if all (not o can dom_rng o type_of) constructors
       then mk_decode_term_label (p,x) target t' constructors
            handle e => wrapException "mk_decode_term (label)" e
    else if length constructors = 1
       then mk_decode_term_single (p,x) target t' (hd constructors)
            handle e => wrapException "mk_decode_term (single)" e
    else mk_decode_term_all (p,x) target t' constructors
         handle e => wrapException "mk_decode_term" e
end
end

local
(* Fix the whole labelled tuple (num # args) held in all. *)
fun mk_fix_constructor_ns all target C =
let val (list,_) = strip_fun (type_of C)
in
    mk_comb(get_fix_function target (list_mk_prod (num::list)),all)
end
(* A nullary constructor with label n is fixed to the encoding of (n,F). *)
fun mk_fix_constructor all target dead n C =
    if can dom_rng (type_of C)
       then mk_fix_constructor_ns all target C
       else mk_comb(get_encode_function target (mk_prod(num,bool)),
                    mk_pair(numLib.term_of_int n,F))
fun mk_fix_res_term label all target constructors dead bottom =
    list_mk_cond
        (map (fn (a,b) => (mk_eq(label,numLib.term_of_int a),
                           mk_fix_constructor all target dead a b))
             (enumerate 0 constructors))
        bottom
(* The encoder of t, with each type-variable encoder replaced by the     *)
(* corresponding fix function, applied to the bottom value of t.         *)
(* Shared by the single- and multiple-constructor cases.                 *)
fun fix_bottom_value target dead t =
let val instit = inst (map (fn v => v |-> target) (type_vars t))
    val enc1 = instit (get_encode_function target t)
    val enc2 = subst (map (fn v => instit (get_encode_function target v) |->
                                   get_fix_function target v) (type_vars t)) enc1
in
    rimk_comb(enc2,target_bottom_value target dead t)
end
(* A valid label is left unchanged; anything else becomes encode 0. *)
fun mk_fix_term_label x target t constructors =
let val dnat = get_decode_function target num
    val rnat = get_detect_function target num
    val enat = get_encode_function target num
in
    mk_forall(x,mk_eq(mk_comb(get_fix_function target t,x),
        mk_cond(mk_comb(rnat,x),
                list_mk_cond (map (fn (a,b) => (mk_eq(mk_comb(dnat,x),numLib.term_of_int a),x))
                                  (enumerate 0 constructors))
                             (mk_comb(enat,zero_tm)),
                mk_comb(enat,zero_tm))))
end
fun mk_fix_term_all (p,x) target t constructors dead =
let val label = mk_var("l",num)
    val rest = mk_var("r",target)
    val bottom = fix_bottom_value target dead t
in
    list_mk_forall(snd (strip_comb (get_fix_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_fix_function target t,x),
            mk_cond(p,pairSyntax.mk_anylet (
                [(mk_pair(label,rest),
                  mk_comb(get_decode_function target (mk_prod(num,target)),x))],
                mk_fix_res_term label x target constructors dead bottom),
                bottom))))
end
fun mk_fix_term_single x target t constructor dead =
let val t' = (mk_type o (I ## map (K target)) o dest_type) t
    val p = get_detect_function target t'
    val bottom = fix_bottom_value target dead t
    val list = fst (strip_fun (type_of constructor))
in
    list_mk_forall(snd (strip_comb (get_fix_function target t)),
        mk_forall(x,mk_eq(mk_comb(get_fix_function target t,x),
            mk_cond(mk_comb(p,x),
                    mk_comb(get_fix_function target (list_mk_prod list),x),
                    bottom))))
end
in
fun mk_fix_term target t =
let val t' = base_type t handle e => wrapException "mk_fix_term" e
    val constructors = constructors_of t'
                       handle e => wrapException "mk_fix_term" e
    val x = mk_var("x",target)
    val p = mk_comb(get_detect_function target (mk_prod(num,target)),x)
    val dead = #bottom (get_translation_scheme target)
in
    if all (not o can dom_rng o type_of) constructors
       then mk_fix_term_label x target t' constructors
            handle e => wrapException "mk_fix_term (label)" e
    else if length constructors = 1
       then mk_fix_term_single x target t' (hd constructors) dead
            handle e => wrapException "mk_fix_term (single)" e
    else mk_fix_term_all (p,x) target t' constructors dead
         handle e => wrapException "mk_fix_term" e
end
end
local
(* Fresh variable a_n of type t, paired with t. *)
fun mk_avar (n,t) = (mk_var ("a_" ^ Int.toString n,t),t)
(* Encoding equation for constructor cnst:                               *)
(*   !a_0 .. a_k. !encoders. enc (cnst a_0 .. a_k) = enc_prod (args)     *)
(* With SOME n the label n is prepended to the tuple, and a nullary      *)
(* constructor is encoded as (n,F).                                      *)
fun single_pair num target t cnst =
let val ts = fst (strip_fun (type_of cnst))
    val tvs = map mk_avar (enumerate 0 ts)
    val (vars,types) =
        unzip (case num
                 of NONE => tvs
                  | SOME x => (case tvs
                                 of [] => [(numLib.term_of_int x,``:num``),
                                           (``F``,``:bool``)]
                                  | list => (numLib.term_of_int x,``:num``)::list))
in
    list_mk_forall(map fst tvs,
        list_mk_forall(map (get_encode_function target) (snd (dest_type t)),
            mk_eq(mk_comb(get_encode_function target t,
                          list_mk_comb(cnst,map fst tvs)),
                  if vars = []
                     then mk_comb(get_encode_function target ``:num``,``0n``)
                     else (mk_comb(get_encode_function target
                                       (pairLib.list_mk_prod types),
                                   pairLib.list_mk_pair vars)))))
end
fun mk_encode_term_single target t cnst = single_pair NONE target t cnst
(* Label types: the i-th constructor is encoded as the number i. *)
fun mk_encode_term_label target t cnsts =
let val num = get_encode_function target ``:num``
    val func = get_encode_function target t
in
    list_mk_conj (map (fn (n,c) =>
                          mk_eq(mk_comb(func,c),mk_comb(num,numLib.term_of_int n)))
                      (enumerate 0 cnsts))
end
fun mk_encode_term_all target t cnsts =
    list_mk_conj (map (fn (n,c) => single_pair (SOME n) target t c)
                      (enumerate 0 cnsts))
in
fun mk_encode_term target t =
let val t' = base_type t handle e => wrapException "mk_encode_term" e
    val constructors = constructors_of t'
                       handle e => wrapException "mk_encode_term" e
in
    if all (not o can dom_rng o type_of) constructors
       then mk_encode_term_label target t' constructors
            handle e => wrapException "mk_encode_term (label)" e
    else if length constructors = 1
       then mk_encode_term_single target t' (hd constructors)
            handle e => wrapException "mk_encode_term (single)" e
    else mk_encode_term_all target t' constructors
         handle e => wrapException "mk_encode_term" e
end
end;

(* One conjunct per constructor C:                                       *)
(*   !fs. !a b .. . map fs (C a b ..) = C (map_ta a) (map_tb b) ..       *)
fun mk_map_term t =
let val cs = constructors_of t
    val args = map (fn c => map (fn (n,t) =>
                        mk_var(implode (base26 n),t))
                        (enumerate 0 (fst (strip_fun (type_of c))))) cs
    val combs = map2 (curry list_mk_comb) cs args
    val func = get_map_function t
    val funs = snd (strip_comb func)
    (* mk_eq after instantiating b's type to match a's. *)
    fun imk_eq (a,b) = mk_eq(a,inst (match_type (type_of b) (type_of a)) b)
in
    list_mk_conj (map (fn c => list_mk_forall(funs,
        list_mk_forall(snd (strip_comb c),imk_eq(mk_comb(func,c),
            ((list_imk_comb o (I ## map (fn a =>
                  mk_comb(get_map_function (type_of a),a))))
             (strip_comb c))))))
        combs)
end handle e => wrapException "mk_map_term" e

(* One conjunct per constructor C: all fs (C a b ..) = all_prod (a,b,..), *)
(* or T for a nullary constructor.                                        *)
fun mk_all_term t =
let val cs = constructors_of t
    val args = map (fn c => map (fn (n,t) =>
                        mk_var(implode (base26 n),t))
                        (enumerate 0 (fst (strip_fun (type_of c))))) cs
    val combs = map2 (curry list_mk_comb) cs args
    val func = get_all_function t
    val funs = snd (strip_comb func)
in
    list_mk_conj (map2 (fn a => fn c => list_mk_forall(funs,
        list_mk_forall(a,mk_eq(mk_comb(func,c),
            case a of
                [] => T
              | a => mk_comb(get_all_function (list_mk_prod(map type_of a)),list_mk_pair a)))))
        args combs)
end handle e => wrapException "mk_all_term" e

(* pair_thm: the encoding definition for pairs. The right-hand side of   *)
(* every conjunct is rewritten with pair_thm, then again inside the rand *)
(* of the result, until pair_thm no longer applies. When debug is set it *)
(* first rejects a function whose clauses encode to a mixture of         *)
(* labelled (num # 'a) and unlabelled tuples. Failures of the rewrite    *)
(* itself are wrapped; UNCHANGED yields a reflexive theorem.             *)
fun ENCODE_CONV pair_thm term =
let val _ = type_trace 2 "->ENCODE_CONV\n"
    val fa_pairs = if (!debug) then
                      bucket_alist (zip (map (repeat rator o lhs o snd o strip_forall) (strip_conj term))
                                        (map (type_of o rand o rhs o snd o strip_forall) (strip_conj term)))
                      handle _ => []
                   else []
    val t = mk_prod(numLib.num,alpha)
    val _ = case (total (first (fn (a,b) => exists (can (match_type t)) b andalso
                                            exists (not o can (match_type t)) b)) fa_pairs)
              of SOME (fname,_) => raise (mkDebugExn "ENCODE_CONV"
                     ("Function clause: " ^ term_to_string fname ^
                      " converts to a mixture of labelled pairs and non-labelled pairs"))
               | NONE => ()
    fun drop_all term = ((REWR_CONV pair_thm THENC RAND_CONV drop_all) ORELSEC ALL_CONV) term
in
    (* Both handlers belong to the same match: a second 'handle' here     *)
    (* would attach only to REFL term and leave conversion errors         *)
    (* unwrapped.                                                         *)
    EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV drop_all)) term
    handle UNCHANGED => REFL term
         | e => wrapException "ENCODE_CONV" e
end;

local
(* Rewrite with l inside the guard of a conditional (if g then a else b). *)
fun rws l = RATOR_CONV (RATOR_CONV (RAND_CONV (REWRITE_CONV l)))
(* If the guard of the conditional on the rhs matches the LHS of the     *)
(* pair detect definition, case-split on that definition's own guard xp  *)
(* so that the pair detect is expanded in both branches.                 *)
fun DC target term =
let val r = rhs term
    val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "detect"
    val (p,a,b) = dest_cond r handle e => raise UNCHANGED
    val (left,right) = dest_eq (snd (strip_forall (concl pair_def)))
    val (xp,_,_) = dest_cond right
in
    if can (match_term left) p then
        let val thm1 = rws [ASSUME xp,pair_def,COND_EXPAND] (rhs term);
            val thm2 = (rws [ASSUME (mk_neg xp),pair_def,COND_EXPAND] THENC
                        PURE_REWRITE_CONV [COND_CLAUSES]) (rhs term);
        in
            AP_TERM (rator term) (SYM (RIGHT_CONV_RULE (REWR_CONV COND_ID)
                (MATCH_MP COND_CONG (LIST_CONJ [REFL xp,DISCH_ALL (SYM thm1),DISCH_ALL (SYM thm2)]))))
        end
    else REFL term
end handle UNCHANGED => REFL term | e => wrapException "DC" e
(* Under the guard p, rewrite the then-branch with the pair decode       *)
(* definition (specialised to p), simplify I and reduce the let. Fails   *)
(* with NO_CONV on any error, so callers wrap it in TRY_CONV.            *)
fun SC target term =
let val (p,a,b) = dest_cond (rhs term)
    val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "decode"
    val pthm = PURE_REWRITE_RULE [COND_CLAUSES,ASSUME p]
                   (PART_MATCH (rand o rator o rator o rhs) pair_def p)
in
    AP_TERM (rator term) (MATCH_MP COND_CONG (LIST_CONJ
      [REFL p,
       DISCH p (RATOR_CONV (RAND_CONV (RAND_CONV (REWR_CONV pthm) THENC
                PURE_REWRITE_CONV [I_THM] THENC pairLib.let_CONV)) a),
       DISCH (mk_neg p) (REFL b)]))
end handle e => NO_CONV term
(* Under the guard p, rewrite the then-branch with the pair fix          *)
(* definition for (num # 'a), specialised to p.                          *)
fun FIXC target term =
let val (p,a,b) = dest_cond (rhs term)
    val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "fix"
    val pthm = REWRITE_RULE [I_THM]
                 (PART_MATCH (rator o lhs)
                    (PURE_REWRITE_RULE [COND_CLAUSES,ASSUME p]
                       (PART_MATCH (rand o rator o rator o rhs) pair_def p))
                    (get_fix_function target (mk_prod(num,alpha))))
in
    AP_TERM (rator term) (MATCH_MP COND_CONG (LIST_CONJ
      [REFL p,DISCH p (PURE_REWRITE_CONV [pthm] a),DISCH (mk_neg p) (REFL b)]))
end
in
fun DETECT_CONV target term =
    (type_trace 2 "->DETECT_CONV\n" ;
     EVERY_CONJ_CONV (STRIP_QUANT_CONV (DC target THENC TRY_CONV (SC target))) term
     handle e => wrapException "DETECT_CONV" e)
fun DECODE_CONV target term =
    (type_trace 2 "->DECODE_CONV\n" ;
     EVERY_CONJ_CONV (STRIP_QUANT_CONV (DC target THENC TRY_CONV (SC target))) term
     handle e => wrapException "DECODE_CONV" e)
fun FIX_CONV target term =
    (type_trace 2 "->FIX_CONV\n" ;
     EVERY_CONJ_CONV (STRIP_QUANT_CONV (DC target THENC TRY_CONV (SC target) THENC
                                        REWRITE_CONV [K_THM] THENC TRY_CONV (FIXC target))) term
     handle e => wrapException "FIX_CONV" e)
end;

fun mk_encodes target t =
let val _ = if exists_coding_function target t "encode" then
                raise (mkStandardExn "mk_encodes"
                   ("Encoder function for translation: " ^ type_to_string t ^ " --> " ^ type_to_string target ^
                    " already exists.")) else ()
    val _ = type_trace 1
              ("Generating encoding function for: " ^ type_to_string t ^ " --> " ^ type_to_string target ^ "\n")
in
    mk_coding_functions
        "encode"
        (mk_encode_term target)
        (get_encode_function target)
        (ENCODE_CONV (get_coding_function_def target (mk_prod(alpha,beta)) "encode"))
        REFL
        target
        (base_type t) handle e => wrapException "mk_encodes" e
end;

(* Arguments of the most recent CONSOLIDATE_CONV call. Written on every *)
(* call; not read in this part of the file (presumably a debugging aid). *)
val CONSOLIDATE_CONV_data = ref (NONE:((term -> term) * term) option);

local
(* Reduce P A C = P B D to the two goals A = B and C = D. *)
val PTAC = HO_MATCH_MP_TAC (METIS_PROVE [] ``(A = B) /\ (C = D) ==> (P A C = P B D)``) THEN CONJ_TAC;
(* Strip matching applications from both sides of an equation goal,     *)
(* stopping once either side is headed by one of funcs.                  *)
fun AP_TAC funcs (ag as (assums,goal)) =
    if (C mem funcs o repeat rator o lhs) goal orelse (C mem funcs o repeat rator o rhs) goal
       then ALL_TAC ag
       else (TRY ((AP_TERM_TAC ORELSE AP_THM_TAC ORELSE PTAC) THEN AP_TAC funcs)) ag
in
(* rfix selects, from each else-branch, the term whose type determines   *)
(* which auxiliary definitions (map, encode, decode, fix, detect_dead)   *)
(* are needed. Each else-branch is rewritten with those definitions and  *)
(* with the clauses themselves until nothing changes; the rewritten      *)
(* function is then proved equivalent to the original by tactic.         *)
fun CONSOLIDATE_CONV rfix function =
let val _ = type_trace 2 "->CONSOLIDATE_CONV\n"
    val _ = CONSOLIDATE_CONV_data := SOME (rfix,function)
    val clauses = strip_conj function
    val ends = map ((fn (a,b,c) => c) o dest_cond o rhs o snd o strip_forall) clauses
    val target = (type_of o rand o lhs o snd o strip_forall o hd) clauses
    val dead_thm = #bottom_thm (get_translation_scheme target)
    val subs = (mk_type o (I ## map (fn a => if a = target then gen_tyvar() else a)) o dest_type)
    val ts = (mk_prod(num,target))::mk_set (filter (not o is_vartype)
               (flatten (mapfilter (map snd o reachable_graph (uncurried_subtypes) o subs o type_of o rfix) ends)))
    val maps = map (fn x => (generate_source_function "map" (base_type x) ;
                             C get_source_function_def "map" x)) ts
    val encs = map (fn x => (generate_coding_function target "encode" (base_type x) ;
                             get_coding_function_def target x "encode")) ts
    val decs = flatten (map CONJUNCTS (mapfilter (C (get_coding_function_def target) "decode") ts))
    val fixs = map (REWRITE_RULE encs)
                   (flatten (map CONJUNCTS (mapfilter (C (get_coding_function_def target) "fix") ts)))
    val deads = dead_thm::mapfilter (generate_coding_theorem target "detect_dead" o base_type) ts;
    val hos = map ASSUME clauses
    val results = map (fn term => REPEATC (CHANGED_CONV (REWRITE_CONV maps THENC REWRITE_CONV encs THENC
                                      ONCE_REWRITE_CONV decs THENC ONCE_REWRITE_CONV fixs
                                      THENC ONCE_REWRITE_CONV hos THENC REWRITE_CONV deads)) term
                                  handle UNCHANGED => REFL term) ends
    val full = ONCE_DEPTH_CONV (FIRST_CONV (map REWR_CONV results)) function
    val funcs = map (repeat rator o lhs o snd o strip_forall) (clauses @ map concl decs @ map concl fixs);
in
    prove(concl full,
          REWRITE_TAC encs THEN EQ_TAC THEN REPEAT STRIP_TAC THEN
          REPEAT (
            FIRST [FIRST_ASSUM (CONV_TAC o LAND_CONV o REWR_CONV),
                   CONV_TAC (LAND_CONV (FIRST_CONV (map REWR_CONV (fixs @ decs)))),ALL_TAC] THEN
            FIRST [FIRST_ASSUM (CONV_TAC o RAND_CONV o REWR_CONV),
                   CONV_TAC (RAND_CONV (FIRST_CONV (map REWR_CONV (fixs @ decs)))),ALL_TAC] THEN
            TRY (MATCH_MP_TAC COND_CONG THEN REPEAT STRIP_TAC) THEN REWRITE_TAC deads THEN AP_TAC funcs THEN
            REWRITE_TAC maps THEN REWRITE_TAC encs THEN REWRITE_TAC (dead_thm::deads) THEN
            REWRITE_TAC (mapfilter TypeBase.one_one_of ts) THEN REPEAT CONJ_TAC))
end handle e => wrapException "CONSOLIDATE_CONV" e
end

fun mk_decodes target t =
let val _ = if exists_coding_function target t "decode" then
                raise (mkStandardExn "mk_decodes"
                   ("Decoder function for translation: " ^ type_to_string t ^ " --> " ^ type_to_string target ^
                    " already exists.")) else ()
    val _ = type_trace 1
              ("Generating decoding function for: " ^ type_to_string t ^ " --> " ^ type_to_string target ^ "\n")
in
    mk_target_functions
        "decode"
        (mk_decode_term target)
        (get_decode_function target)
        (DECODE_CONV target)
        (CONSOLIDATE_CONV rand)
        target
        (base_type t) handle e => wrapException "mk_decodes" e
end;

fun mk_detects target t =
let val _ = if exists_coding_function target t "detect" then
                raise (mkStandardExn "mk_detects"
                   ("Detector function for translation: " ^ type_to_string t ^ " --> " ^ type_to_string target ^
                    " already exists.")) else ()
    val _ = type_trace 1
              ("Generating detecting function for: " ^ type_to_string t ^ " --> " ^ type_to_string target ^ "\n")
in
    mk_target_functions
        "detect"
        (mk_detect_term target)
        (get_detect_function target)
        (DETECT_CONV target)
        REFL
        target
        (base_type t) handle e => wrapException "mk_detects" e
end;

(* Errors from the generation itself are wrapped with the function name, *)
(* as in mk_encodes, mk_decodes, mk_detects and mk_fixs.                 *)
fun mk_maps t =
let val _ = if exists_source_function t "map" then
                raise (mkStandardExn "mk_maps"
                   ("Map function for type: " ^ type_to_string t ^ " already exists.")) else ()
    val _ = type_trace 1
              ("Generating map function for: " ^ type_to_string t ^ "\n")
in
    mk_source_functions
        "map"
        mk_map_term
        get_map_function
        REFL
        REFL
        (base_type t) handle e => wrapException "mk_maps" e
end;

fun mk_alls t =
let val _ = if exists_source_function t "all" then
                raise (mkStandardExn "mk_alls"
                   ("All function for type: " ^ type_to_string t ^ " already exists.")) else ()
    val _ = type_trace 1
              ("Generating all function for: " ^ type_to_string t ^ "\n")
in
    mk_source_functions
        "all"
        mk_all_term
        get_all_function
        (fn x => (EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV
                     (PURE_REWRITE_CONV [get_source_function_def (mk_prod(alpha,beta)) "all"])))) x
                 handle UNCHANGED => REFL x)
        REFL
        (base_type t) handle e => wrapException "mk_alls" e
end;

fun mk_fixs target t =
let val _ = if exists_coding_function target t "fix" then
                raise (mkStandardExn "mk_fixs"
                   ("fix function for translation: " ^ type_to_string t ^ " --> " ^ type_to_string target ^
                    " already exists.")) else ()
    val _ = type_trace 1
              ("Generating fix function for: " ^ type_to_string t ^ " --> " ^ type_to_string target ^ "\n")
in
    mk_target_functions
        "fix"
        (mk_fix_term target)
        (get_fix_function target)
        (FIX_CONV target)
        (CONSOLIDATE_CONV rand)
        target
        (base_type t) handle e => wrapException "mk_fixs" e
end;

(*****************************************************************************)
(* Generate conclusions for the various goals to be proven:                  *)
(*                                                                           *)
(* mk_encode_decode_map_conc : hol_type -> hol_type -> term                  *)
(* mk_encode_detect_all_conc : hol_type -> hol_type -> term                  *)
(* mk_decode_encode_fix_conc : hol_type -> hol_type -> term                  *)
(* mk_encode_map_encode_conc : hol_type -> hol_type -> term                  *)
(* mk_map_compose_conc       : hol_type -> term                              *)
(* mk_map_id_conc            : hol_type -> term                              *)
(* mk_all_id_conc            : hol_type -> term                              *)
(* mk_fix_id_conc            : hol_type -> hol_type -> term                  *)
(* mk_general_detect_conc    : hol_type -> hol_type -> term                  *)
(*                                                                           *)
(* mk_encode_decode_conc     : hol_type -> hol_type -> term                  *)
(* mk_decode_encode_conc     : hol_type -> hol_type -> term                  *)
(* mk_encode_detect_conc     : hol_type -> hol_type -> term                  *)
(*                                                                           *)
(*                                                                           *)
(* Make the conclusions for the various theorems:                            *)
(*     ?- (decode f o encode g) = 
map (f o g) *) 869(* ?- (encode f o decode g) = fix (f o g) *) 870(* ?- (detect f o encode g) = all (f o g) *) 871(* *) 872(* ?- (encode f o map g) = encode (f o g) *) 873(* ?- (map f o map g) = map (f o g) *) 874(* *) 875(* ?- map I = I *) 876(* ?- all (K T) = K T *) 877(* ?- (!x. f x = x) ==> (!x. fix f x = x) *) 878(* *) 879(* ?- !x. detect f g x ==> detect (K T) (K T) x *) 880(* *) 881(* ?- (!x. f (g x) = x) ==> !x. decode f (encode g x) = x *) 882(* ?- (!x. p x ==> g (f x) = x) ==> *) 883(* !x. detect p x ==> encode g (decode f x) = x *) 884(* ?- (!x. p (g x)) ==> !x. detect p (encode g x) *) 885(* *) 886(*****************************************************************************) 887 888fun get_hfuns term = 889 if is_comb term 890 then flatten (map get_hfuns (op:: (strip_comb term))) 891 else [term]; 892 893fun type_vars_avoiding_itself function t = 894 set_diff (type_vars t) 895 (map (hd o snd o dest_type o type_of) 896 (filter (can (match_term the_value)) (get_hfuns function))); 897 898fun check_function gf t = 899let val term = gf t 900 val hfuns = get_hfuns term 901 val vars = filter is_var hfuns 902 val values = filter (polymorphic o type_of) 903 (filter (can (match_term the_value)) hfuns) 904in 905 if length (mk_set (vars @ values)) > length (type_vars t) 906 then raise (mkDebugExn "check_function" 907 ("The function term: " ^ term_to_string term ^ 908 "\ncontains free-variables not derived from the type: " ^ 909 type_to_string t)) 910 else term 911end; 912 913local 914fun wrap e = wrapException "mk_encode_decode_map_conc" e 915fun err s = raise (mkDebugExn "mk_encode_decode_map_conc" 916("Unable to correctly instantiate type variables in " ^ s ^ " function")) 917in 918fun mk_encode_decode_map_conc target t = 919let val enc = check_function (get_encode_function target) t handle e => wrap e 920 val dec = check_function (get_decode_function target) t handle e => wrap e 921 val map_term = check_function get_map_function t handle e => wrap e 922 val 
safe_map_term = inst (map (fn v => v |-> gen_tyvar()) 923 (type_vars_in_term map_term)) map_term; 924 val tvs = type_vars_avoiding_itself enc t 925 val values = set_diff (type_vars t) tvs 926 927 fun inst_from term start types = 928 inst (map (fn (a,b) => 929 b |-> mk_vartype (String.implode(#"'" :: base26 (a + start)))) 930 (enumerate 0 types)) term; 931 val enc' = inst_from (inst_from enc 0 tvs) (length tvs) values 932 handle e => err "encode" 933 val dec' = inst_from (inst_from dec (length tvs + length values) tvs) 934 (length tvs) values handle e => err "decode" 935 val map' = inst (match_type (type_of safe_map_term) 936 (fst (dom_rng (type_of enc')) --> 937 snd (dom_rng (type_of dec')))) 938 safe_map_term handle e => err "map"; 939 940 val enc_vars = free_vars_lr enc' 941 val dec_vars = free_vars_lr dec' 942 val sub = map2 (curry op|->) 943 (free_vars_lr map') 944 (map2 (curry combinSyntax.mk_o) dec_vars enc_vars) 945 handle e => wrap e 946in 947 list_mk_forall(enc_vars, 948 list_mk_forall(dec_vars,mk_eq(combinSyntax.mk_o(dec',enc'),subst sub map'))) 949 handle e => wrap e 950end 951end 952 953local 954fun w s e = wrapException s e 955fun mk_ring_conc left func1 func2 = 956let val sub = map2 (curry op|->) 957 (free_vars_lr (if left then func1 else func2)) 958 (map2 (curry combinSyntax.mk_o) (free_vars_lr func1) (free_vars_lr func2)) handle e => w "mk_ring_conc" e 959 val tsubs = map (fn {redex,residue} => match_type (type_of redex) (type_of residue)) sub 960 handle e => w "mk_ring_conc" e 961 val ins = C (foldl (uncurry inst)) tsubs handle e => w "mk_ring_conc" e 962 val sub' = map (fn {redex,residue} => ins redex |-> residue) sub handle e => w "mk_ring_conc" e 963in 964 list_mk_forall(free_vars_lr func1, 965 list_mk_forall(free_vars_lr func2, 966 mk_eq((curry combinSyntax.mk_o) func1 func2,subst sub' (ins (if left then func1 else func2))))) 967 handle e => w "mk_ring_conc" e 968end 969in 970fun mk_encode_map_encode_conc target t = 971let val encf = 
check_function (get_encode_function target) t handle e => w "mk_encode_map_encode_conc" e 972 val mapf = check_function get_map_function t handle e => w "mk_encode_map_encode_conc" e 973 val encf' = inst (match_type (fst (dom_rng (type_of encf))) (snd (dom_rng (type_of mapf)))) encf 974 handle e => w "mk_encode_map_encode_conc" e 975in 976 mk_ring_conc true encf' mapf handle e => w "mk_encode_map_encode_conc" e 977end 978fun mk_map_compose_conc t = 979let val map1 = check_function get_map_function t handle e => w "mk_map_compose_conc" e 980 val map2 = inst (map (fn x => x |-> (mk_vartype o curry op^ "'map" o get_type_string) x) 981 (type_vars (type_of map1))) map1 handle e => w "mk_map_compose_conc" e 982 val map1' = inst (match_type (snd (dom_rng (type_of map1))) (fst (dom_rng (type_of map2)))) map1 983 handle e => w "mk_encode_map_encode_conc" e 984in 985 mk_ring_conc false map2 map1' handle e => w "mk_encode_map_encode_conc" e 986end 987end; 988 989fun mk_decode_encode_fix_conc target t = 990let val enc = check_function (get_encode_function target) t 991 val dec = check_function (get_decode_function target) t 992 val fix = check_function (get_fix_function target) t 993 994 val enc_vars = free_vars_lr enc 995 val dec_vars = free_vars_lr dec 996 val sub = map2 (curry op|->) (free_vars_lr fix) (map2 (curry combinSyntax.mk_o) enc_vars dec_vars) 997in 998 list_mk_forall(enc_vars, 999 list_mk_forall(dec_vars, 1000 mk_eq((curry combinSyntax.mk_o) enc dec,subst sub fix))) 1001end 1002 1003fun mk_encode_detect_all_conc target t = 1004let val enc = check_function (get_encode_function target) t 1005 val det = check_function (get_detect_function target) t 1006 val all = check_function get_all_function t 1007 val dbool = get_decode_function target bool 1008 1009 val enc_vars = free_vars_lr enc 1010 val det_vars = free_vars_lr det 1011 val sub = map2 (curry op|->) (free_vars_lr all) (map2 (curry combinSyntax.mk_o) det_vars enc_vars) 1012in 1013 list_mk_forall(det_vars, 1014 
list_mk_forall(enc_vars,mk_eq((curry combinSyntax.mk_o) det enc,subst sub all))) 1015end; 1016 1017fun mk_map_id_conc t = 1018let val map_term = check_function get_map_function t 1019 val fvs = free_vars map_term 1020 val ty_sub = map (fn fv => snd (dom_rng (type_of fv)) |-> fst (dom_rng (type_of fv))) fvs 1021 val map' = inst ty_sub map_term 1022 val dub = fn t => fst (dom_rng t) --> fst (dom_rng t) 1023 val tm_sub = map (fn fv => (mk_var o (I ## dub) o dest_var) fv |-> mk_const("I",dub (type_of fv))) fvs 1024in 1025 mk_eq(subst tm_sub map',mk_const("I",type_of map')) 1026end 1027 1028fun mk_all_id_conc t = 1029let val all_term = check_function get_all_function t 1030 val fvs = free_vars all_term 1031 fun mk_all_id t = mk_comb(mk_const("K",bool --> (t --> bool)),T) 1032in mk_eq(subst (map (fn x => x |-> 1033 mk_all_id (fst (dom_rng (type_of x)))) fvs) all_term, 1034 mk_all_id t) 1035end 1036 1037fun mk_fix_id_conc target t = 1038let val fix_term = check_function (get_fix_function target) t 1039 val det_term = check_function (get_detect_function target) t 1040 val tvs = type_vars_avoiding_itself fix_term t 1041 val hyps = map (mk_fix_id_conc target) (set_diff tvs [t]) 1042 val x = mk_var("x",target) 1043 val e = mk_forall(x,mk_imp(mk_comb(det_term,x), 1044 mk_eq(mk_comb(fix_term,x),x))) 1045in 1046 if null hyps then e 1047 else mk_imp(list_mk_conj hyps,e) 1048end; 1049 1050fun mk_general_detect_conc target t = 1051let val p1 = check_function (get_detect_function target) t 1052 val t' = type_subst (map (fn v => v |-> target) 1053 (type_vars_avoiding_itself p1 t)) t 1054 val p2 = check_function (get_detect_function target) t' 1055 val xvar = mk_var("x",target) 1056in 1057 mk_forall(xvar, 1058 list_mk_forall(free_vars p1,mk_imp(mk_comb(p1,xvar),mk_comb(p2,xvar)))) 1059end; 1060 1061local 1062fun wrap e = wrapException "mk_encode_decode_conc" e 1063in 1064fun mk_encode_decode_conc target t = 1065let val encode = check_function (get_encode_function target) t 1066 handle 
e => wrap e 1067 val decode = check_function (get_decode_function target) t 1068 handle e => wrap e 1069 val var = mk_var("x",t) 1070 val conc = mk_forall(var,mk_eq(mk_comb(decode,mk_comb(encode,var)),var)) 1071 handle e => wrap e 1072 val tvs = type_vars_avoiding_itself encode t 1073 val ante = map (snd o dest_imp o snd o strip_forall o 1074 mk_encode_decode_conc target) (set_diff tvs [t]) 1075in 1076 list_mk_forall(map (get_encode_function target) tvs, 1077 list_mk_forall(map (get_decode_function target) tvs, 1078 if is_vartype t then mk_imp(conc,conc) else 1079 if null ante then conc else mk_imp(list_mk_conj ante,conc))) 1080 handle e => wrap e 1081end 1082end 1083 1084local 1085fun wrap e = wrapException "mk_decode_encode_conc" e 1086in 1087fun mk_decode_encode_conc target t = 1088let val encode = check_function (get_encode_function target) t 1089 handle e => wrap e 1090 val detect = check_function (get_detect_function target) t 1091 handle e => wrap e 1092 val decode = check_function (get_decode_function target) t 1093 handle e => wrap e 1094 val var = mk_var("x",target) 1095 val conc = mk_forall(var,mk_imp(mk_comb(detect,var), 1096 mk_eq(mk_comb(encode,mk_comb(decode,var)),var))) 1097 handle e => wrap e 1098 val tvs = type_vars_avoiding_itself encode t 1099 val ante = map (snd o dest_imp o snd o strip_forall o 1100 mk_decode_encode_conc target) (set_diff tvs [t]) 1101in 1102 list_mk_forall(map (get_encode_function target) tvs, 1103 list_mk_forall(map (get_decode_function target) tvs, 1104 list_mk_forall(map (get_detect_function target) tvs, 1105 if is_vartype t then mk_imp(conc,conc) else 1106 if null ante then conc else mk_imp(list_mk_conj ante,conc)))) 1107 handle e => wrap e 1108end 1109end 1110 1111local 1112fun wrap e = wrapException "mk_encode_detect_conc" e 1113in 1114fun mk_encode_detect_conc target t = 1115let val encode = check_function (get_encode_function target) t 1116 handle e => wrap e 1117 val detect = check_function (get_detect_function 
target) t 1118 handle e => wrap e 1119 val var = mk_var("x",t) 1120 val conc = mk_forall(var,mk_comb(detect,mk_comb(encode,var))) 1121 handle e => wrap e 1122 val tvs = type_vars_avoiding_itself encode t 1123 val ante = map (snd o dest_imp o snd o strip_forall o 1124 mk_encode_detect_conc target) (set_diff tvs [t]) 1125in 1126 list_mk_forall(map (get_encode_function target) tvs, 1127 list_mk_forall(map (get_detect_function target) tvs, 1128 if is_vartype t then mk_imp(conc,conc) else 1129 if null ante then conc else mk_imp(list_mk_conj ante,conc))) 1130 handle e => wrap e 1131end 1132end 1133 1134(*****************************************************************************) 1135(* Rules to generate instantiated theorems from base-type theorems: *) 1136(* *) 1137(* FULL_ENCODE_DECODE_MAP_THM : hol_type -> hol_type -> thm *) 1138(* FULL_ENCODE_DETECT_ALL_THM : hol_type -> hol_type -> thm *) 1139(* FULL_ENCODE_MAP_ENCODE_THM : hol_type -> hol_type -> thm *) 1140(* FULL_DECODE_ENCODE_FIX_THM : hol_type -> hol_type -> thm *) 1141(* FULL_MAP_COMPOSE_THM : hol_type -> hol_type -> thm *) 1142(* Create the theorem, eg: *) 1143(* |- map map o encode encode = encode encode *) 1144(* from: *) 1145(* |- !f g. map f o encode g = encode (f o g) *) 1146(* and |- map o encode = encode *) 1147(* *) 1148(* FULL_MAP_ID_THM : hol_type -> thm *) 1149(* FULL_ALL_ID_THM : hol_type -> thm *) 1150(* Create the theorem, eg: *) 1151(* |- map (map I) (map I) = I *) 1152(* from: |- map I I = I |- map I = I |- map I = I *) 1153(* *) 1154(* FULL_FIX_ID_THM : hol_type -> hol_type -> thm *) 1155(* Create the theorem, eg: *) 1156(* |- fix fix x = x *) 1157(* from: |- (!x. f x = x) ==> (!x. fix f x = x) |- !x. fix x = x *) 1158(* *) 1159(* *) 1160(* FULL_ENCODE_DECODE_THM : hol_type -> hol_type -> thm *) 1161(* Create the theorem, eg: *) 1162(* |- !x. 
decode decode (encode encode x) = x *) 1163(* *) 1164(* FULL_DECODE_ENCODE_THM : hol_type -> hol_type -> thm *) 1165(* Create the theorem, eg: *) 1166(* |- !x. detect detect x ==> encode encode (decode decode x) = x)*) 1167(* *) 1168(* FULL_ENCODE_DETECT_THM : hol_type -> hol_type -> thm *) 1169(* Create the theorem, eg: *) 1170(* |- !x. detect detect (encode encode x) *) 1171(* *) 1172(*****************************************************************************) 1173 1174fun wrap_full s t e = 1175 wrapException (s ^ "(" ^ type_to_string t ^ ")") e 1176 1177fun get_sub_types basetype t = 1178 filter (not o is_vartype) 1179 (map #residue (match_type basetype t)) 1180 1181local 1182fun EMPTY_RING gconc name target t thms = 1183let val conc = gconc target t 1184in 1185 EQT_ELIM (REWRITE_CONV thms conc) 1186end 1187fun RING_MATCH_THM gconc name target t = 1188let val basetype = if t = target then t else 1189 most_precise_type 1190 (C (exists_coding_theorem_precise target) name) t 1191 handle _ => t 1192 val thm = SPEC_ALL (generate_coding_theorem target name basetype) 1193 val conc = gconc target t 1194 val thm' = PART_MATCH lhs thm (lhs (snd (strip_forall conc))) 1195 val sub_thms = map (RING_MATCH_THM gconc name target) 1196 (get_sub_types basetype t) 1197in 1198 RIGHT_CONV_RULE (PURE_REWRITE_CONV sub_thms) thm' 1199end; 1200fun CHECK_RING gconc name target t thms = 1201 if null (type_vars t) 1202 then (EMPTY_RING gconc name target t thms handle _ => 1203 RING_MATCH_THM gconc name target t) 1204 else RING_MATCH_THM gconc name target t 1205in 1206fun FULL_ENCODE_DECODE_MAP_THM target t = 1207 if target = t 1208 then CONJUNCT1 (ISPEC (mk_const("I",target --> target)) I_o_ID) 1209 else RING_MATCH_THM mk_encode_decode_map_conc 1210 "encode_decode_map" target t 1211 handle e => wrap_full "FULL_ENCODE_DECODE_MAP_THM" t e 1212fun FULL_ENCODE_DETECT_ALL_THM target t = 1213 if target = t 1214 then CONJUNCT2 (ISPEC 1215 (mk_comb(mk_const("K",bool --> target --> bool),T)) 1216 
I_o_ID) 1217 else RING_MATCH_THM mk_encode_detect_all_conc 1218 "encode_detect_all" target t 1219 handle e => wrap_full "FULL_ENCODE_DETECT_ALL_THM" t e 1220fun FULL_ENCODE_MAP_ENCODE_THM target t = 1221 if target = t 1222 then FULL_ENCODE_DECODE_MAP_THM target t 1223 else 1224 CHECK_RING mk_encode_map_encode_conc "encode_map_encode" target t 1225 [I_o_ID] 1226 handle e => wrap_full "FULL_ENCODE_MAP_ENCODE_THM" t e 1227fun FULL_DECODE_ENCODE_FIX_THM target t = 1228 if target = t 1229 then FULL_ENCODE_DECODE_MAP_THM target t 1230 else 1231 RING_MATCH_THM mk_decode_encode_fix_conc "decode_encode_fix" target t; 1232end 1233 1234fun FULL_MAP_COMPOSE_THM t = 1235let val basetype = most_precise_type 1236 (C exists_source_theorem_precise "map_compose") t 1237 handle e => t; 1238 val thm = SPEC_ALL (generate_source_theorem "map_compose" basetype) 1239 val conc = mk_map_compose_conc t 1240 val thm' = PART_MATCH lhs thm (lhs (snd (strip_forall conc))) 1241 val sub_thms = map FULL_MAP_COMPOSE_THM 1242 (get_sub_types basetype t) 1243in 1244 RIGHT_CONV_RULE (PURE_REWRITE_CONV sub_thms) thm' 1245end handle e => wrap_full "FULL_MAP_COMPOSE_THM" t e 1246 1247local 1248fun FMIDT getf t tname ename mk_const mk_conc = 1249let val basetype = most_precise_type (C exists_source_theorem_precise tname) t 1250 handle _ => t 1251 val thm = SPEC_ALL (generate_source_theorem tname basetype) 1252 val thm' = INST_TYPE (match_type (fst (dom_rng 1253 (type_of (lhs (concl thm))))) t) thm 1254 val left = lhs (concl thm') 1255 val subtypes = get_sub_types basetype t 1256 val sub_thms = map (fn x => 1257 if is_vartype x 1258 then NONE 1259 else SOME (FMIDT' getf x tname ename mk_const mk_conc)) 1260 subtypes 1261 1262 val conc = mk_conc t 1263 val thm1 = RAND_CONV (REWR_CONV (GSYM thm)) conc 1264 val sub_thms_filtered = 1265 filter (fn x => (not o op= o dest_eq o concl o valOf) x 1266 handle _ => true) sub_thms 1267 val thm2 = RIGHT_CONV_RULE 1268 (REWRITE_CONV (mapfilter Option.valOf 
sub_thms_filtered)) 1269 thm1 1270in 1271 CONV_RULE bool_EQ_CONV thm2 1272end handle e => wrap_full ename t e 1273and FMIDT' getf t tname ename mk_const mk_conc = 1274 if can (match_term (mk_const(alpha --> alpha))) (getf t) 1275 then REFL (getf t) 1276 else FMIDT getf t tname ename mk_const mk_conc 1277in 1278fun FULL_MAP_ID_THM t = 1279 FMIDT' (check_function get_map_function) t "map_id" "FULL_MAP_ID_THM" 1280 (curry mk_const "I") mk_map_id_conc 1281fun FULL_ALL_ID_THM t = 1282 FMIDT' (check_function get_all_function) t "all_id" "FULL_ALL_ID_THM" 1283 (fn t => mk_comb(mk_const("K",bool --> fst (dom_rng t) --> bool),T)) 1284 mk_all_id_conc 1285end 1286 1287fun FULL_FIX_ID_THM target t = 1288let fun wrap e = wrap_full "FULL_FIX_ID_THM" t e 1289 val basetype = if target = t then t else 1290 most_precise_type 1291 (C (exists_coding_theorem_precise target) "fix_id") t 1292 handle _ => t 1293 val thm = generate_coding_theorem target "fix_id" basetype 1294 handle e => wrap e 1295 fun mimp_only tm = if is_imp_only tm then snd (dest_imp tm) else tm 1296 val conc = mk_fix_id_conc target t 1297 handle e => wrap e 1298 val values = filter (can (match_term the_value)) 1299 (snd (strip_comb (lhs (snd (strip_imp (snd 1300 (strip_forall (mimp_only (snd 1301 (strip_forall conc)))))))))) 1302 handle e => wrap e 1303 val value_types = map (hd o snd o dest_type o type_of) values 1304 handle e => wrap e 1305 val tvs = set_diff (type_vars t) value_types 1306 val sub_thms = map (UNDISCH_ALL o PURE_REWRITE_RULE [GSYM AND_IMP_INTRO] o 1307 FULL_FIX_ID_THM target) 1308 (get_sub_types basetype t) 1309 handle e => wrap e 1310 val thm' = INST_TY_TERM 1311 (match_term (mimp_only (concl thm)) (mimp_only conc)) thm 1312 handle e => wrap e 1313 val disch_set = map (mk_fix_id_conc target) tvs 1314 handle e => wrap e 1315in 1316 PURE_REWRITE_RULE [AND_IMP_INTRO,GSYM CONJ_ASSOC] 1317 (foldr (uncurry DISCH) (foldl (uncurry PROVE_HYP) 1318 (UNDISCH_ALL (PURE_REWRITE_RULE [GSYM AND_IMP_INTRO] thm')) 
1319 sub_thms) disch_set) 1320 handle e => wrap e 1321end; 1322 1323local 1324val conv1 = STRIP_QUANT_CONV (RAND_CONV (REWR_CONV (GSYM I_THM)) THENC 1325 LAND_CONV (REWR_CONV (GSYM o_THM))) THENC 1326 REWR_CONV (GSYM FUN_EQ_THM) 1327val conv2 = REWR_CONV FUN_EQ_THM THENC 1328 STRIP_QUANT_CONV (RAND_CONV (REWR_CONV I_THM) THENC 1329 LAND_CONV (REWR_CONV o_THM)) 1330fun FEDT target t = 1331let fun wrap e = wrap_full "FULL_ENCODE_DECODE_THM" t e 1332 val ename = "FULL_ENCODE_DECODE_THM" ^ type_to_string t 1333 val thm1 = generate_coding_theorem target "encode_decode_map" t 1334 handle e => wrap e 1335 val thm1_safe = 1336 INST_TYPE (map (fn v => v |-> gen_tyvar()) 1337 (type_vars_in_term (concl thm1))) thm1 1338 val thm2 = generate_source_theorem "map_id" t handle e => wrap e 1339 val tvs = type_vars_avoiding_itself (get_encode_function target t) t 1340 val antes = map (CONV_RULE conv1 o ASSUME o snd o dest_imp o 1341 snd o strip_forall o mk_encode_decode_conc target) 1342 tvs handle e => 1343 raise (mkDebugExn ename 1344 ("mk_encode_decode_conc returned invalid conclusion for" ^ 1345 " type variable: " ^ type_to_string t)) 1346 val thm2a = PURE_REWRITE_RULE (map SYM antes) thm2 handle e => wrap e 1347 fun instit f thm = INST_TYPE (match_type (f (concl thm)) t) thm 1348 val thm1a = instit (snd o dom_rng o type_of o lhs) 1349 (instit (fst o dom_rng o type_of o lhs) 1350 (SPEC_ALL thm1_safe)) 1351 val thm3 = TRANS thm1a thm2a handle e => 1352 raise (mkDebugExn ename 1353 ("Generated encode_decode_map and map_id theorems do not match:\n" ^ 1354 thm_to_string thm1a ^ "\n" ^ thm_to_string thm2a)) 1355 val thm4 = CONV_RULE conv2 thm3 handle e => wrap e 1356 1357 val (vs,conc) = strip_forall (mk_encode_decode_conc target t) 1358 handle e => wrap e 1359 val (vars,list) = if is_imp_only conc 1360 then (vs,strip_conj (fst (dest_imp conc))) else ([],[]) 1361 val result = PURE_REWRITE_RULE [AND_IMP_INTRO] 1362 (foldr (uncurry DISCH) thm4 list) 1363in 1364 if null (hyp result) 
then GENL vars result else 1365 raise (mkDebugExn ename 1366 "Hypothesis remain in conclusion of theorem!") 1367end 1368in 1369fun FULL_ENCODE_DECODE_THM target t = 1370 if target = t then 1371 CONV_RULE bool_EQ_CONV (REWRITE_CONV [combinTheory.I_THM] 1372 (mk_encode_decode_conc target t)) 1373 else if is_vartype t 1374 then DECIDE (mk_encode_decode_conc target t) else FEDT target t 1375end; 1376 1377local 1378fun wrap e = wrapException "FULL_DECODE_ENCODE_THM" e 1379fun FDET target t = 1380let val thm1 = generate_coding_theorem target "decode_encode_fix" t handle e => wrap e 1381 val thm2 = generate_coding_theorem target "fix_id" t handle e => wrap e 1382 1383 val thm1a = CONV_RULE (STRIP_QUANT_CONV (REWR_CONV FUN_EQ_THM THENC 1384 STRIP_QUANT_CONV (LAND_CONV (REWR_CONV o_THM)))) thm1; 1385 val v1 = (lhs o snd o dest_imp_only o snd o strip_forall o snd o dest_imp_only) 1386 val v2 = (lhs o snd o dest_imp_only) 1387 val thm2a = PART_MATCH v1 thm2 (rhs (snd (strip_forall (concl thm1a)))) handle e => 1388 PART_MATCH v2 thm2 (rhs (snd (strip_forall (concl thm1a)))) 1389 1390 val thm2b = CONV_RULE (LAND_CONV (EVERY_CONJ_CONV (STRIP_QUANT_CONV 1391 (RAND_CONV (LAND_CONV (REWR_CONV o_THM)))))) thm2a handle e => thm2a 1392 val thm2c = UNDISCH (SPEC_ALL (UNDISCH_CONJ thm2b)) handle e => 1393 UNDISCH (SPEC_ALL thm2b) 1394 1395 val thm3 = GEN (rhs (concl thm2c)) (DISCH (first (not o is_forall) (hyp thm2c)) (TRANS (SPEC_ALL thm1a) thm2c)) 1396 1397 val conc = mk_decode_encode_conc target t 1398 val list = if is_imp_only (snd (strip_forall (snd (dest_imp_only (snd (strip_forall conc)))))) 1399 then strip_conj (fst (dest_imp_only (snd (strip_forall conc)))) else [] 1400 val r = DISCH_LIST_CONJ list thm3 1401in 1402 if null (hyp r) then r else 1403 raise (mkDebugExn "FULL_DECODE_ENCODE_THM" 1404 "Hypothesis remain in resultant theorem, mismatch between mk_decode_encode_conc and this?") 1405end 1406in 1407fun FULL_DECODE_ENCODE_THM target t = 1408let val conc = 
(mk_decode_encode_conc target t) 1409 val vlist = if is_imp_only (snd (strip_forall (snd (dest_imp_only (snd (strip_forall conc)))))) 1410 then fst (strip_forall conc) else [] 1411in 1412 GENL vlist (if is_vartype t 1413 then DISCH_ALL (ASSUME (snd (dest_imp_only (snd (strip_forall (mk_decode_encode_conc target t)))))) 1414 else FDET target t) 1415end 1416end 1417 1418local 1419fun wrap e = wrapException "FULL_ENCODE_DETECT_THM" e 1420val rthm = DISCH_ALL (CONV_HYP (REWRITE_CONV [FUN_EQ_THM,K_THM]) 1421 (ASSUME (mk_eq(mk_var("A",alpha --> bool),mk_comb(mk_const("K",bool --> alpha --> bool),T))))) 1422fun FEDT target t = 1423let val thm1 = FULL_ENCODE_DETECT_ALL_THM target t handle e => wrap e 1424 val thm2 = FULL_ALL_ID_THM t handle e => wrap e 1425 1426 val thm2a = CONV_RULE (REWR_CONV FUN_EQ_THM THENC 1427 STRIP_QUANT_CONV (RAND_CONV (REWR_CONV K_THM) THENC bool_EQ_CONV)) thm2 handle e => wrap e 1428 val thm1a = snd (EQ_IMP_RULE (SPEC_ALL (CONV_RULE (REWR_CONV FUN_EQ_THM) thm1))) handle e => wrap e 1429 1430 val conc = mk_encode_detect_conc target t handle e => wrap e 1431 val imps = map (MATCH_MP rthm o CONV_RULE (STRIP_QUANT_CONV (REWR_CONV (GSYM o_THM))) o ASSUME o 1432 snd o dest_imp_only o snd o strip_forall o mk_encode_detect_conc target) (type_vars t) 1433 handle e => wrap e 1434 val thm3 = GEN_ALL (CONV_RULE (REWR_CONV o_THM) (MATCH_MP (PURE_REWRITE_RULE imps thm1a) (SPEC_ALL thm2a))) 1435 handle e => wrap e 1436 val thm3' = GEN_ALL (UNDISCH_ALL (PART_MATCH (snd o strip_imp) (DISCH_ALL thm3) 1437 (snd (dest_imp (snd (strip_forall conc))) handle _ => snd (strip_forall conc)))) 1438 handle e => wrap e; 1439 1440 val (vars,body) = strip_forall conc 1441 val (vs,timps) = (vars,strip_conj (fst (dest_imp body))) handle _ => ([],[]) 1442 val r = GENL vs (DISCH_LIST_CONJ timps thm3') 1443in 1444 if null (hyp r) then r else 1445 raise (mkDebugExn "FULL_ENCODE_DETECT_THM" 1446 "Hypothesis remain in resultant theorem, mismatch between mk_encode_detect_conc and 
this?") 1447end 1448in 1449fun FULL_ENCODE_DETECT_THM target t = 1450 if is_vartype t then DECIDE (mk_encode_detect_conc target t) 1451 else FEDT target t 1452end; 1453 1454(*****************************************************************************) 1455(* Conversions to fully apply functions: *) 1456(* *) 1457(* ENCODER_CONV : term -> thm *) 1458(* APP_MAP_CONV : term -> thm *) 1459(* APP_ALL_CONV : term -> thm *) 1460(* DECODE_PAIR_CONV : term -> thm *) 1461(* DETECT_PAIR_CONV : hol_type -> term -> thm *) 1462(* *) 1463(* ENCODER_CONV : |- (encode (C a b)) = encode_pair (x,a,b) *) 1464(* APP_MAP_CONV : |- (map (C a b)) = C (map a) (map b) *) 1465(* APP_ALL_CONV : |- (all (C a b)) = (all a) /\ (all b) *) 1466(* DECODE_PAIR_CONV : |- (decode (encode_pair f g a)) = *) 1467(* C (decode (f (FST (SND a)))) ... *) 1468(* DETECT_PAIR_CONV : |- (detect (encode_pair f g a)) = *) 1469(* (detect (f (FST (SND a)))) /\ ... *) 1470(* *) 1471(*****************************************************************************) 1472 1473fun ENCODER_CONV term = 1474let val t = type_of (rand term) 1475 val target = type_of term 1476 val check = check_function (get_encode_function target) t 1477 val def = get_coding_function_def target t "encode" 1478in 1479 if can (match_term check) (rator term) then 1480 FIRST_CONV (map REWR_CONV (CONJUNCTS def)) term 1481 else 1482 NO_CONV term 1483end handle e => NO_CONV term 1484 1485fun APP_MAP_CONV term = 1486let val t = type_of (rand term) 1487 val check = check_function get_map_function t 1488 val def = get_source_function_def t "map" 1489in 1490 if can (match_term check) (rator term) then 1491 FIRST_CONV (map REWR_CONV (CONJUNCTS def)) term 1492 else NO_CONV term 1493end handle e => NO_CONV term 1494 1495fun APP_ALL_CONV term = 1496let val t = type_of (rand term) 1497 val check = check_function get_all_function t 1498 val def = get_source_function_def t "all" 1499in 1500 if can (match_term check) (rator term) then 1501 (FIRST_CONV (map 
REWR_CONV (CONJUNCTS def))) term
    else NO_CONV term
end handle e => NO_CONV term

(*****************************************************************************)
(* DECODE_PAIR_CONV : conv                                                   *)
(*     Unfolds one application  dec x  of a decoding function, where x has   *)
(*     the target type and t = type_of (dec x).  The operator must match the *)
(*     decoding function for t (via check_function), otherwise the           *)
(*     conversion fails through NO_CONV.                                     *)
(*     After unfolding the decode definition of t the result is simplified   *)
(*     with two rewrites chosen from the constructors of t:                  *)
(*       - every constructor nullary: the num encode_detect_all and          *)
(*         encode_decode_map theorems;                                       *)
(*       - exactly one constructor: the encode_detect_all theorem of t       *)
(*         instantiated at that constructor applied to fresh variables, and  *)
(*         the trivial REFL labelled;                                        *)
(*       - otherwise: the (num # target) labelled pair theorems.             *)
(*     Finally each argument of the resulting application is folded back     *)
(*     into composition form by re_o_conv.                                   *)
(*     Every failure, including the NO_CONV above, is re-raised through      *)
(*     wrapException with the name of this conversion.                       *)
(*****************************************************************************)

fun DECODE_PAIR_CONV term =
let val t = type_of term
    val target = type_of (rand term)
    val check = check_function (get_decode_function target) t
    val def = get_coding_function_def target t "decode"

    val pairp_pair = get_coding_theorem target (mk_prod(alpha,beta)) "encode_detect_all"
    val nump_num = get_coding_theorem target num "encode_detect_all"
    val labelled = mk_comb(get_detect_function target (mk_prod(num,target)),rand term)
    val pairp_id = get_source_theorem (mk_prod(alpha,beta)) "all_id"
    val pair_map = get_source_function_def (mk_prod(alpha,beta)) "map"
    (* Pair encode_decode_map theorem, instantiated at a pair of fresh       *)
    (* variables and with the pair map definition unfolded.                  *)
    val paird_pair = PURE_REWRITE_RULE [pair_map]
            (ISPEC (mk_pair(genvar (gen_tyvar()),genvar (gen_tyvar())))
                   (PURE_REWRITE_RULE [FUN_EQ_THM,o_THM]
                        (SPEC_ALL (get_coding_theorem target (mk_prod(alpha,beta)) "encode_decode_map"))));
    val numd_num = get_coding_theorem target num "encode_decode_map";
    val cs = constructors_of t
    (* Rewrite for the detect test inside the decode definition. *)
    val all_rwr = if all (not o can dom_rng o type_of) cs then
                     PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] nump_num,K_THM] (mk_comb(get_detect_function target num,rand term))
                  else if length cs = 1 then
                     PURE_REWRITE_RULE [get_coding_function_def target t "encode"]
                        (ISPEC (list_mk_comb(hd cs,map genvar (fst (strip_fun (type_of (hd cs))))))
                               (PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] (SPEC_ALL (generate_coding_theorem target "encode_detect_all" t))))
                  else PURE_REWRITE_RULE [K_THM,pairp_id,nump_num,K_o_THM] (PART_MATCH lhs (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM] pairp_pair) labelled)

    (* Rewrite for the first decoding step inside the decode definition. *)
    val first_decode = if all (not o can dom_rng o type_of) cs then
                          PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] numd_num,I_THM] (mk_comb(get_decode_function target num,rand term))
                       else if length cs = 1 then
                          REFL labelled
                       else PURE_REWRITE_CONV [paird_pair,numd_num,I_o_ID,I_THM] (mk_comb(get_decode_function target (mk_prod(num,target)),rand term))

    (* Rewrites every argument  f (g x)  of an application into  (f o g) x. *)
    fun re_o_conv term =
    let val (r,subs) = (REFL ## map (REWR_CONV (GSYM o_THM))) (strip_comb term)
    in
        foldl (fn (a,b) => MK_COMB(b,a)) r subs
    end;
in
    if can (match_term check) (rator term) then
       (REWR_CONV def THENC PURE_REWRITE_CONV (COND_CLAUSES::all_rwr::first_decode::K_THM::K_o_THM::[generate_source_theorem "all_id" t]) THENC
        TRY_CONV let_CONV THENC DEPTH_CONV (reduceLib.NEQ_CONV) THENC PURE_REWRITE_CONV [COND_CLAUSES,paird_pair,o_THM] THENC
        TRY_CONV let_CONV THENC re_o_conv) term
    else NO_CONV term
end handle e => wrapException "DECODE_PAIR_CONV" e

(*****************************************************************************)
(* DETECT_PAIR_CONV : hol_type -> conv                                       *)
(*     DETECT_PAIR_CONV t unfolds one application  det x  of the detect      *)
(*     function for the source type t, where x has the target type.  Fails   *)
(*     through NO_CONV when the operator does not match (via check_function).*)
(*     After unfolding the detect definition of t, the rewrites are chosen   *)
(*     from the constructors of t:                                           *)
(*       - every constructor nullary: the num encode_detect_all and          *)
(*         encode_decode_map theorems;                                       *)
(*       - exactly one constructor: both rewrites are the trivial            *)
(*         REFL labelled;                                                    *)
(*       - otherwise: the (num # target) labelled pair theorems.             *)
(*     Every failure is re-raised through wrapException.                     *)
(*****************************************************************************)

fun DETECT_PAIR_CONV t term =
let val target = type_of (rand term)
    val check = check_function (get_detect_function target) t
    val def = get_coding_function_def target t "detect"

    val pairp_pair = generate_coding_theorem target "encode_detect_all" (mk_prod(alpha,beta))
    val nump_num = generate_coding_theorem target "encode_detect_all" num
    val pairp_id = generate_source_theorem "all_id" (mk_prod(alpha,beta))
    val labelled = mk_comb(get_detect_function target (mk_prod(num,target)),rand term)
    val pair_map = get_source_function_def (mk_prod(alpha,beta)) "map"
    (* Pair encode_decode_map theorem, instantiated at a pair of fresh       *)
    (* variables and with the pair map definition unfolded.                  *)
    val paird_pair = PURE_REWRITE_RULE [pair_map]
            (ISPEC (mk_pair(genvar (gen_tyvar()),genvar (gen_tyvar())))
                   (PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] (SPEC_ALL (get_coding_theorem target (mk_prod(alpha,beta)) "encode_decode_map"))));
    val numd_num = get_coding_theorem target num "encode_decode_map";
    val cs = constructors_of t
    val all_rwr = if all (not o can dom_rng o type_of) cs then
                     PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] nump_num,K_THM] (mk_comb(get_detect_function target num,rand term))
                  else if length cs = 1 then
                     REFL labelled
                  else PURE_REWRITE_RULE [K_THM,pairp_id,nump_num,K_o_THM] (PART_MATCH lhs (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM] pairp_pair) labelled)

    val first_decode =
        if all (not o can dom_rng o type_of) cs then
           PURE_REWRITE_CONV [o_THM,PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] numd_num,I_THM] (mk_comb(get_decode_function target num,rand term))
        else if length cs = 1 then
           REFL labelled
        else PURE_REWRITE_CONV [paird_pair,numd_num,I_o_ID,I_THM] (mk_comb(get_decode_function target (mk_prod(num,target)),rand term))
in
    if can (match_term check) (rator term) then
       (REWR_CONV def THENC PURE_REWRITE_CONV [COND_CLAUSES,all_rwr,first_decode] THENC
        TRY_CONV let_CONV THENC DEPTH_CONV (reduceLib.NEQ_CONV ORELSEC (REWR_CONV REFL_CLAUSE)) THENC PURE_REWRITE_CONV [COND_CLAUSES] THENC
        TRY_CONV let_CONV THENC TRY_CONV (REWR_CONV (GSYM o_THM))) term
    else NO_CONV term
end handle e => wrapException "DETECT_PAIR_CONV" e

(*****************************************************************************)
(* Tactics to prove the goals described previously:                         *)
(*                                                                           *)
(* encode_decode_map_tactic : hol_type -> hol_type -> tactic                 *)
(* encode_detect_all_tactic : hol_type -> hol_type -> tactic                 *)
(* decode_encode_fix_tactic : hol_type -> hol_type -> tactic                 *)
(* encode_map_encode_tactic : hol_type -> hol_type -> tactic                 *)
(* map_compose_tactic : hol_type -> tactic                                   *)
(* map_id_tactic : hol_type -> tactic                                        *)
(* all_id_tactic : hol_type -> tactic                                        *)
(* fix_id_tactic : hol_type -> hol_type -> tactic                            *)
(* general_detect_tactic : hol_type -> hol_type -> tactic                    *)
(*                                                                           *)
(* These tactics solve the inductive clauses of the goals given previously. *)
(*                                                                           *)
(* detect_dead_rule : hol_type -> hol_type -> thm                            *)
(*     Generates a single application of detect to the bottom ('nil') value *)
(*     of the translation scheme. Used in                                    *)
(*     CONSOLIDATE_CONV to show that bottom values terminate.
*) 1603(* *) 1604(*****************************************************************************) 1605 1606fun encode_decode_map_tactic target (t:hol_type) (a,g) = 1607let val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g)))))))) 1608 val rts = relevant_types t 1609 val thms = map (generate_coding_theorem target "encode_decode_map") rts 1610 val map_defs = map (C get_source_function_def "map") rts 1611in 1612 (REPEAT STRIP_TAC THEN 1613 FIRST [ CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN 1614 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (map_defs @ thms), 1615 PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV ENCODER_CONV THENC DECODE_PAIR_CONV) THENC RAND_CONV APP_MAP_CONV) THEN 1616 ASM_REWRITE_TAC (get_source_function_def (mk_prod(alpha,beta)) "map"::thms) THEN 1617 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (map_defs @ thms)]) (a,g) 1618end handle e => wrapException "encode_decode_map_tactic" e 1619 1620local 1621fun fix_type tm ty = 1622 if is_pair tm then uncurry cons ((I ## fix_type (snd (dest_pair tm))) (dest_prod ty)) else [ty]; 1623fun PTAC target rset t (a,g) = 1624let val endt = rand (lhs g) 1625 val t' = delete_matching_types rset (cannon_type (type_of endt)) 1626 val thm1 = PURE_REWRITE_RULE [FUN_EQ_THM,o_THM] (SPEC_ALL (generate_coding_theorem target "encode_map_encode" t')); 1627 val thm2 = PURE_REWRITE_RULE [I_THM,I_o_ID,get_source_function_def t' "map"] (ISPEC (list_mk_pair(map genvar (fix_type endt t'))) thm1) 1628in 1629 (CONV_TAC (LAND_CONV (REWR_CONV thm2))) (a,g) 1630end 1631fun CTAC target rset t = 1632let val cs = constructors_of t 1633in if all (not o can dom_rng o type_of) cs then ALL_TAC 1634 else if length cs = 1 then ALL_TAC else PTAC target rset t 1635end 1636in 1637fun encode_map_encode_tactic target (t:hol_type) (a,g) = 1638let val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g)))))))) 1639 val rset = map fst (split_nested_recursive_set 
t) 1640 val rts = relevant_types t 1641 val thms = map (generate_coding_theorem target "encode_map_encode") rts 1642 val enc_defs = map (C (get_coding_function_def target) "encode") (mk_prod(alpha,beta)::all_types t); 1643 val all_thms = map (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM]) thms @ thms @ enc_defs 1644in 1645 (REPEAT STRIP_TAC THEN 1646 FIRST [ CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN 1647 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC all_thms, 1648 PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV APP_MAP_CONV THENC ENCODER_CONV) THENC RAND_CONV ENCODER_CONV) THEN 1649 PURE_REWRITE_TAC [I_THM,I_o_ID] THEN CTAC target rset t THEN ASM_REWRITE_TAC all_thms THEN 1650 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (o_THM::all_thms)]) (a,g) 1651end 1652end 1653 1654fun map_compose_tactic (t:hol_type) (a,g) = 1655let val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g)))))))) 1656 val rts = relevant_types t 1657 val thms = map (generate_source_theorem "map_compose") rts 1658 val map_defs = map (C get_source_function_def "map") (mk_prod(alpha,beta)::all_types t); 1659 val all_thms = map (PURE_REWRITE_RULE [o_THM,FUN_EQ_THM]) thms @ thms @ map_defs 1660in 1661 (REPEAT STRIP_TAC THEN 1662 FIRST [CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN 1663 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC all_thms, 1664 PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV APP_MAP_CONV THENC APP_MAP_CONV) THENC RAND_CONV APP_MAP_CONV) THEN 1665 REWRITE_TAC (mapfilter TypeBase.one_one_of (all_types t)) THEN REPEAT CONJ_TAC THEN 1666 CONV_TAC (LAND_CONV (REWR_CONV (GSYM o_THM))) THEN 1667 ASM_REWRITE_TAC all_thms]) (a,g) 1668end; 1669 1670fun encode_detect_all_tactic target (t:hol_type) (a,g) = 1671let val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g)))))))) 1672 val thms = map (generate_coding_theorem target "encode_detect_all") (relevant_types t) 1673 val 
all_defs = mapfilter (C get_source_function_def "all") (all_types t) 1674in 1675 (REPEAT STRIP_TAC THEN 1676 FIRST [ CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (map REWR_CONV thms)))) THEN 1677 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (all_defs @ thms), 1678 PURE_REWRITE_TAC [o_THM] THEN CONV_TAC (LAND_CONV (RAND_CONV ENCODER_CONV THENC (DETECT_PAIR_CONV t)) THENC RAND_CONV APP_ALL_CONV) THEN 1679 ASM_REWRITE_TAC (get_source_function_def (mk_prod(alpha,beta)) "all"::thms) THEN 1680 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (all_defs @ thms)]) (a,g) 1681end handle e => wrapException "encode_detect_all_tactic" e 1682 1683 1684fun LET_RAND_CONV match term = 1685let fun strip_pair x = if is_pair x then op:: ((I ## strip_pair) (dest_pair x)) else [x] 1686 val (func_tm,let_tm) = dest_comb term 1687 val (inputs,output) = pairSyntax.dest_anylet let_tm 1688 val ginput = map (fn (a,b) => (a,genvar (type_of a))) inputs 1689 val alpha = gen_tyvar() 1690 val beta = gen_tyvar() 1691 val goutput = list_mk_comb(mk_var("M",list_mk_fun(flatten (map (map type_of o strip_pair o fst) ginput),alpha)), 1692 flatten (map (strip_pair o fst) ginput)); 1693 val gterml = mk_comb(mk_var("f",alpha --> beta),pairSyntax.mk_anylet (ginput,goutput)); 1694 val gtermr = pairSyntax.mk_anylet(ginput,mk_comb(mk_var("f",alpha --> beta),goutput)); 1695 val thm = CONV_RULE bool_EQ_CONV (DEPTH_CONV (REWR_CONV LET_DEF ORELSEC GEN_BETA_CONV ORELSEC REWR_CONV REFL_CLAUSE) (mk_eq(gterml,gtermr))) 1696in 1697 if match = func_tm then HO_REWR_CONV thm term else NO_CONV term 1698end 1699 1700local 1701fun PCONV tm = if is_pair tm then RAND_CONV PCONV tm else TRY_CONV (REWR_CONV (GSYM PAIR) THENC PCONV) tm 1702fun PLET_CONV term = 1703let val len = strip_pair (fst (dest_pabs (rand (rator term)))); 1704 val full_pair = list_mk_prod(map (fn a => gen_tyvar ()) len) 1705 val thm = PCONV (mk_var("x",full_pair)) 1706in 1707 (RATOR_CONV (RATOR_CONV (REWR_CONV LET_DEF)) THENC 1708 RATOR_CONV BETA_CONV THENC BETA_CONV THENC 
1709 RAND_CONV (REWR_CONV thm) THENC 1710 PAIRED_BETA_CONV THENC 1711 REWRITE_CONV [GSYM thm]) term 1712end; 1713fun COND_CONG_TAC (a,g) = 1714 TRY (MATCH_MP_TAC COND_CONG THEN CONJ_TAC THENL [ALL_TAC,CONJ_TAC] THENL 1715 [ALL_TAC,DISCH_TAC THEN COND_CONG_TAC,DISCH_TAC THEN COND_CONG_TAC]) (a,g); 1716fun START_LABEL_TAC enc encoder target t = 1717 PURE_REWRITE_TAC [o_THM] THEN 1718 ONCE_REWRITE_TAC [get_coding_function_def target t "decode"] THEN 1719 ONCE_REWRITE_TAC [get_coding_function_def target t "fix"] THEN 1720 CONV_TAC (LAND_CONV (REDEPTH_CONV (FIRST_CONV [HO_REWR_CONV (ISPEC enc COND_RAND),LET_RAND_CONV enc]))) THEN 1721 MATCH_MP_TAC COND_CONG THEN REPEAT STRIP_TAC THEN 1722 REWRITE_TAC [REWRITE_RULE [FUN_EQ_THM,o_THM] (generate_coding_theorem target "encode_map_encode" t)] THEN 1723 ONCE_REWRITE_TAC [encoder] THEN 1724 CONV_TAC (REDEPTH_CONV (let_CONV ORELSEC PLET_CONV)) THEN COND_CONG_TAC THEN 1725 TRY (REWRITE_TAC [get_coding_function_def target (mk_prod(alpha,beta)) "encode",I_THM] THEN NO_TAC) 1726fun LABEL_TAC thms enc encoder target t = 1727 START_LABEL_TAC enc encoder target t THEN 1728 FIRST [FIRST_ASSUM (CONV_TAC o LAND_CONV o RAND_CONV o REWR_CONV o GSYM),REFL_TAC] THEN 1729 REWRITE_TAC (map (REWRITE_RULE [FUN_EQ_THM,o_THM]) thms) THEN 1730 MATCH_MP_TAC (get_coding_theorem target num "fix_id") THEN 1731 FIRST_ASSUM ACCEPT_TAC 1732fun XTAC fix_defs (a,g) = 1733 (if is_var (rand (rhs g)) then 1734 CONV_TAC (BINOP_CONV (FIRST_CONV (map REWR_CONV fix_defs)) THENC DEPTH_CONV (REWR_CONV LET_DEF) THENC DEPTH_CONV GEN_BETA_CONV) THEN 1735 COND_CONG_TAC else NO_TAC) (a,g) 1736in 1737fun decode_encode_fix_tactic target _ (a,g) = 1738let val term = (rator o lhs o snd o strip_forall o snd o strip_imp o snd o strip_forall) g 1739 val enc = rand (rator term); 1740 val t = fst (dom_rng (type_of enc)) 1741 val encoder = get_coding_function_def target t "encode"; 1742 val mt = first (not o C (exists_coding_theorem target) "decode_encode_fix") (all_types t) 1743 
val rts = mk_prod(alpha,beta)::num::relevant_types mt 1744 val thms = map (generate_coding_theorem target "decode_encode_fix" o base_type) rts @ 1745 mapfilter (generate_coding_theorem target "decode_encode_fix") rts 1746 val fix_defs = map (C (get_coding_function_def target) "fix") (mk_prod(alpha,beta)::num::all_types mt) 1747 val dead_thm = #bottom_thm (get_translation_scheme target) 1748 val all_defs = foldl (fn (a,b) => get_coding_function_def target a "encode"::get_source_function_def a "map":: 1749 get_coding_function_def target a "decode"::get_coding_function_def target a "fix":: 1750 get_coding_function_def target a "detect"::b) [o_THM,dead_thm] (all_types t) 1751 val thm = REWRITE_RULE [I_THM,get_source_function_def (mk_prod(alpha,beta)) "map"] 1752 (ISPEC (mk_pair(mk_var("x",num),mk_var("y",beta))) (REWRITE_RULE [FUN_EQ_THM,o_THM] 1753 (SPEC_ALL (generate_coding_theorem target "encode_map_encode" (mk_prod(num,beta)))))) 1754in 1755 (REPEAT STRIP_TAC THEN 1756 FIRST [ 1757 CONV_TAC (LAND_CONV (RATOR_CONV (FIRST_CONV (mapfilter REWR_CONV thms)))), 1758 LABEL_TAC thms enc encoder target t, 1759 START_LABEL_TAC enc encoder target t THEN 1760 TRY (FULL_SIMP_TAC std_ss [get_coding_function_def target (mk_prod(alpha,beta)) "detect"] THEN NO_TAC) THEN 1761 TRY (FIRST_ASSUM (CONV_TAC o LAND_CONV o RAND_CONV o RATOR_CONV o RAND_CONV o REWR_CONV o GSYM) THEN 1762 CONV_TAC (LAND_CONV (REWR_CONV thm THENC RAND_CONV (REWR_CONV PAIR)))) THEN 1763 CONV_TAC (LAND_CONV (FIRST_CONV (mapfilter (REWR_CONV o PURE_REWRITE_RULE [FUN_EQ_THM,o_THM]) thms))) THEN 1764 REWRITE_TAC [I_o_ID]] THEN 1765 REPEAT (XTAC fix_defs THEN RES_TAC THEN 1766 ASM_REWRITE_TAC (last (CONJUNCTS sexpTheory.sexp_11)::get_coding_function_def target (mk_prod(alpha,beta)) "encode"::thms) THEN 1767 RULE_ASSUM_TAC GSYM THEN ASM_REWRITE_TAC (map GSYM thms)) THEN 1768 REPEAT (CHANGED_TAC (ONCE_ASM_REWRITE_TAC (K_THM::all_defs) THEN 1769 REWRITE_TAC (translateTheory.DETDEAD_PAIR::I_THM::mapfilter (C 
get_source_theorem "map_id" o base_type) (all_types t)) THEN 1770 CONV_TAC (DEPTH_CONV (REWR_CONV LET_DEF) THENC DEPTH_CONV GEN_BETA_CONV)))) (a,g) 1771end 1772end; 1773 1774local 1775fun MATCH_FIX_TAC target rset all_types (a,g) = 1776let fun mcheck rset t = exists (can (C match_type t)) rset orelse 1777 (can dest_type t andalso exists (mcheck rset) (snd (dest_type t))) 1778 val ftypes = (filter (not o is_vartype) (filter (not o mcheck rset) (num::all_types))); 1779 val thms = map (fn a => (C (PART_MATCH I) (snd (strip_forall (mk_fix_id_conc target a))) o FULL_FIX_ID_THM target) a) ftypes; 1780 val thms' = map (CONV_RULE (DEPTH_CONV RIGHT_IMP_FORALL_CONV THENC REWRITE_CONV [AND_IMP_INTRO])) thms 1781in 1782 (MATCH_MP_TAC (first (curry op= ((rator o lhs) g) o rator o lhs o snd o strip_imp o 1783 snd o strip_forall o snd o strip_imp o snd o strip_forall o concl) thms')) (a,g) 1784end; 1785in 1786fun fix_id_tactic target t (a,g) = 1787let val all_types = all_types t 1788 val mt = first (not o C (exists_coding_theorem target) "fix_id" o base_type) all_types 1789 val defs = map (C (get_coding_function_def target) "fix" o base_type) (mk_prod(alpha,beta)::num::all_types) 1790 val pdefs = flatten (mapfilter (CONJUNCTS o C (get_coding_function_def target) "detect" o base_type) 1791 (mk_prod(alpha,beta)::num::all_types)) 1792 val split_thm = generate_coding_theorem target "fix_id" (mk_prod(num,target)) 1793 val rts = relevant_types mt 1794 val rset = map fst (split_nested_recursive_set mt) 1795 val def_thms = mapfilter (generate_coding_theorem target "decode_encode_fix" o base_type) 1796 (mk_prod(alpha,beta)::num::rts) 1797 val tsplit_thm = SIMP_RULE std_ss [get_coding_function_def target (mk_prod(alpha,beta)) "detect", 1798 get_coding_function_def target (mk_prod(alpha,beta)) "fix"] 1799 (GSYM (generate_coding_theorem target "fix_id" (mk_prod(target,target)))); 1800 val thms = mapfilter (GEN_ALL o (CONV_RULE (DEPTH_CONV RIGHT_IMP_FORALL_CONV THENC 1801 REWRITE_CONV 
[AND_IMP_INTRO])) o 1802 generate_coding_theorem target "fix_id") (rev (mk_prod(alpha,beta)::num::rts)) 1803 val cond_tm = mk_cond(mk_var("p",bool),mk_var("a",alpha),(mk_var("b",alpha))); 1804in 1805 (REPEAT (POP_ASSUM MP_TAC) THEN REPEAT STRIP_TAC THEN 1806 FIRST [ 1807 RULE_ASSUM_TAC (CONV_RULE (ONCE_REWRITE_CONV pdefs THENC REWRITE_CONV [COND_EXPAND])) THEN 1808 POP_ASSUM STRIP_ASSUME_TAC THEN ONCE_REWRITE_TAC defs THEN ASM_REWRITE_TAC [COND_ID] THEN NO_TAC, 1809 RULE_ASSUM_TAC (ONCE_REWRITE_RULE pdefs) THEN 1810 RULE_ASSUM_TAC (ONCE_REWRITE_RULE [get_coding_function_def target (mk_prod(alpha,beta)) "detect"]) THEN 1811 POP_ASSUM MP_TAC THEN ASM_REWRITE_TAC [] THEN NO_TAC, 1812 IMP_RES_TAC tsplit_thm THEN POP_ASSUM (CONV_TAC o RAND_CONV o REWR_CONV) THEN 1813 CONV_TAC (LAND_CONV (FIRST_CONV (map REWR_CONV defs))) THEN 1814 RULE_ASSUM_TAC (CONV_RULE (TRY_CONV (FIRST_CONV (map REWR_CONV pdefs) THENC REWR_CONV COND_EXPAND))) THEN 1815 POP_ASSUM STRIP_ASSUME_TAC THEN 1816 CONV_TAC (DEPTH_CONV (REWR_CONV LET_DEF ORELSEC GEN_BETA_CONV)) THEN 1817 RULE_ASSUM_TAC (CONV_RULE (DEPTH_CONV (REWR_CONV LET_DEF ORELSEC GEN_BETA_CONV))) THEN 1818 ASM_REWRITE_TAC [] THEN 1819 TRY ( REPEAT IF_CASES_TAC THEN 1820 PAT_ASSUM cond_tm MP_TAC THEN ASM_REWRITE_TAC [get_coding_function_def target (mk_prod(alpha,beta)) "encode"] THEN TRY STRIP_TAC THEN 1821 TRY (CONV_TAC (LAND_CONV (FIRST_CONV (map REWR_CONV defs)))) THEN ASM_REWRITE_TAC []) THEN 1822 TRY (MK_COMB_TAC THENL [MK_COMB_TAC,ALL_TAC]) THEN 1823 REWRITE_TAC [] THEN 1824 TRY ( (FIRST_ASSUM (CONV_TAC o LAND_CONV o REWR_CONV o GSYM) ORELSE 1825 FIRST_ASSUM (CONV_TAC o LAND_CONV o RAND_CONV o REWR_CONV o GSYM) ORELSE MATCH_MP_TAC tsplit_thm) THEN 1826 REWRITE_TAC [get_coding_function_def target (mk_prod(alpha,beta)) "decode"] THEN 1827 ASM_REWRITE_TAC (I_THM::FST::SND::map (REWRITE_RULE [o_THM,FUN_EQ_THM]) def_thms)) THEN 1828 (FIRST_ASSUM ACCEPT_TAC ORELSE FIRST_ASSUM MATCH_MP_TAC ORELSE MATCH_FIX_TAC target rset all_types) THEN 
1829 FULL_SIMP_TAC std_ss [get_coding_function_def target (mk_prod(alpha,beta)) "detect",get_coding_function_def target (mk_prod(alpha,beta)) "decode"] THEN 1830 RULE_ASSUM_TAC (REWRITE_RULE [ISPEC fst_tm COND_RAND,ISPEC snd_tm COND_RAND,I_THM]) THEN 1831 REPEAT (POP_ASSUM MP_TAC) THEN REPEAT IF_CASES_TAC THEN REPEAT STRIP_TAC THEN RULE_ASSUM_TAC (REWRITE_RULE []) THEN ASM_REWRITE_TAC [], 1832 ONCE_REWRITE_TAC defs THEN IF_CASES_TAC THEN 1833 TRY (FIRST_ASSUM MATCH_MP_TAC) THEN TRY (MATCH_FIX_TAC target rset all_types) THENL 1834 [RULE_ASSUM_TAC (ONCE_REWRITE_RULE pdefs),ALL_TAC] THEN 1835 REPEAT (FIRST_ASSUM (fn th => MP_TAC th THEN WEAKEN_TAC (curry op= (concl th)) THEN IF_CASES_TAC THEN DISCH_TAC)) THEN 1836 RES_TAC THEN IMP_RES_TAC (generate_coding_theorem target "general_detect" (base_type t)) THEN REPEAT CONJ_TAC THEN 1837 REPEAT ((FIRST_ASSUM ACCEPT_TAC ORELSE FIRST_ASSUM (ACCEPT_TAC o GSYM) ORELSE REPEAT (POP_ASSUM MP_TAC) THEN ASM_REWRITE_TAC []) THEN 1838 ONCE_REWRITE_TAC [get_coding_function_def target t "detect"] THEN REPEAT (IF_CASES_TAC THEN ASM_REWRITE_TAC []) THEN REPEAT STRIP_TAC THEN 1839 MAP_EVERY (IMP_RES_TAC o generate_coding_theorem target "general_detect" o base_type) (filter (not o is_vartype) all_types))]) (a,g) 1840end 1841end; 1842 1843local 1844fun FULL_MATCH_TAC thm (a,g) = 1845let val (tsub,_) = match_term (snd (strip_exists g)) (concl thm) 1846 val list = map (fn v => #residue (first (curry op= v o #redex) tsub) handle _ => v) (fst (strip_exists g)) 1847in 1848 (MAP_EVERY EXISTS_TAC list THEN MATCH_ACCEPT_TAC thm) (a,g) 1849end; 1850fun COMPLETE thm split thms (a,g) = 1851 (FIRST [REWRITE_TAC [K_THM] THEN NO_TAC, 1852 FIRST_ASSUM MATCH_MP_TAC, 1853 MAP_FIRST (MATCH_MP_TAC o GEN_ALL) thms THEN FIRST_ASSUM FULL_MATCH_TAC, 1854 CONV_TAC (UNDISCH o PART_MATCH (lhs o snd o dest_imp_only) thm) THEN REPEAT CONJ_TAC THEN COMPLETE thm split thms, 1855 IMP_RES_TAC thm THEN RES_TAC THEN COMPLETE thm split thms]) (a,g) 1856in 1857fun 
general_detect_tactic target t = 1858let val all_types = map base_type ((mk_prod(alpha,beta))::(filter (not o is_vartype) (flatten (map (op@ o snd) (split_nested_recursive_set t))))); 1859 val pair_def = get_coding_function_def target (mk_prod(alpha,beta)) "detect" 1860 val thms = [K_THM,get_coding_function_def target (mk_prod(alpha,beta)) "decode",I_THM]; 1861 val (thm,split) = CONJ_PAIR (CONV_RULE (REWR_CONV COND_RAND THENC 1862 REWRITE_CONV [COND_EXPAND,GSYM (SPEC (mk_neg(mk_var("a",bool))) DISJ_COMM),GSYM IMP_DISJ_THM]) (SPEC_ALL pair_def)); 1863 1864in REPEAT GEN_TAC THEN REWRITE_TAC [get_coding_function_def target t "detect"] THEN 1865 TRY (FULL_SIMP_TAC std_ss [pair_def] THEN NO_TAC) THEN 1866 REWRITE_TAC [LET_DEF] THEN GEN_BETA_TAC THEN ASM_REWRITE_TAC thms THEN 1867 REPEAT (IF_CASES_TAC THEN ASM_REWRITE_TAC []) THEN REPEAT STRIP_TAC THEN 1868 REPEAT (FIRST_ASSUM (fn th => MP_TAC th THEN WEAKEN_TAC (curry op= (concl th)) THEN IF_CASES_TAC)) THEN 1869 RES_TAC THEN ASM_REWRITE_TAC [FST,SND] THEN REPEAT STRIP_TAC THEN 1870 REPEAT (CHANGED_TAC (IMP_RES_TAC split THEN IMP_RES_TAC thm)) THEN 1871 COMPLETE thm split (mapfilter (generate_coding_theorem target "general_detect") all_types) 1872end 1873end 1874 1875local 1876fun DDR dead_value target t = 1877let val def = get_coding_function_def target t "detect" 1878 val den = get_coding_theorem target num "detect_dead" 1879 val t' = fst (strip_fun (type_of (hd (constructors_of t)))) 1880 val dead_thm = #bottom_thm (get_translation_scheme target) 1881in 1882 (ONCE_REWRITE_CONV [def] THENC 1883 REWRITE_CONV (dead_thm::den::mapfilter (generate_coding_theorem target "detect_dead" o list_mk_prod) [t'])) 1884 (mk_comb(check_function (get_detect_function target) t,dead_value)) 1885end 1886in 1887fun detect_dead_rule target t = 1888let val dead_value = #bottom (get_translation_scheme target) 1889in 1890 if t = target then 1891 REWRITE_CONV [get_coding_theorem target bool "encode_decode",K_THM] (mk_comb(get_decode_function 
target bool,mk_comb(get_detect_function target t,dead_value))) 1892 else DDR dead_value target t 1893end 1894end; 1895 1896fun all_id_tactic t (a,g) = 1897let val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g)))))))) 1898 val all_thms = mapfilter (generate_source_theorem "all_id") (relevant_types t) 1899 val def = get_source_function_def t "all" 1900 val pair_def = get_source_function_def (mk_prod(alpha,beta)) "all" 1901in 1902 (REPEAT STRIP_TAC THEN REWRITE_TAC [def,pair_def] THEN ASM_REWRITE_TAC (K_THM::all_thms)) (a,g) 1903end 1904 1905fun map_id_tactic t (a,g) = 1906let val t = type_of (rand (rhs (snd (strip_forall (snd (strip_imp (snd (strip_forall g)))))))) 1907 val all_thms = mapfilter (generate_source_theorem "map_id") (relevant_types t) 1908 val def = get_source_function_def t "map" 1909 val pair_def = get_source_function_def (mk_prod(alpha,beta)) "map" 1910in 1911 (REPEAT STRIP_TAC THEN REWRITE_TAC [def,pair_def] THEN ASM_REWRITE_TAC (I_THM::all_thms)) (a,g) 1912end; 1913 1914(*****************************************************************************) 1915(* Destructors: *) 1916(* mk_destructors : hol_type -> hol_type -> thm list *) 1917(* *) 1918(* Produces destructor theorems and conditional rewrites for resolving *) 1919(* them. This will also produce predicate theorems, eg: *) 1920(* |- (FST o sexp_to_pair nat f o pair g) (Ci ...) 
= i                                                              *)
(*                                                                           *)
(*****************************************************************************)

(* MK_FST thm : from  |- l = r  with l of product type, derives             *)
(* |- FST l = FST r.  MK_SND does the same with SND.                        *)
fun MK_FST thm =
    AP_TERM (mk_const("FST",
                      type_of (lhs (concl thm)) -->
                      fst (dest_prod (type_of (lhs (concl thm)))))) thm;
fun MK_SND thm =
    AP_TERM (mk_const("SND",
                      type_of (lhs (concl thm)) -->
                      snd (dest_prod (type_of (lhs (concl thm)))))) thm;

local
(* PRODUCTS n thm : splits  |- l = r  along a right-nested tuple into       *)
(* [|- FST l = FST r, |- FST (SND l) = FST (SND r), ...] plus the remaining *)
(* equation, taking at most n projections.  Where a projection fails, the   *)
(* theorem at that level is returned unsplit.                               *)
fun PRODUCTS 0 thm = [thm]
  | PRODUCTS n thm =
    MK_FST thm :: PRODUCTS (n - 1) (MK_SND thm) handle _ => [thm];
(* O_CONV c : folds nested applications  f (g (...))  into compositions     *)
(* (f o g o ...) x.  It stops (ALL_CONV) at the level whose operator, or    *)
(* whose argument's operator, contains c free, or when the term is not of   *)
(* that shape.                                                              *)
fun O_CONV c term =
    if free_in c (rator term) orelse free_in c (rator (rand term))
       handle _ => true
    then ALL_CONV term
    else (RAND_CONV (O_CONV c) THENC REWR_CONV (GSYM o_THM)) term
(* dest_filter : keeps the theorems  |- l = r  whose r is one of the        *)
(* arguments of the application  rand l; other shapes are dropped.          *)
fun dest_filter l =
    filter (fn thm =>
               mem (rhs (concl thm)) (snd (strip_comb (rand (lhs (concl thm)))))
               handle e => false) l
(* Raised when a constructor's encoding is not of product type. *)
exception NotAPair
(* mk_single_destructor target t c : for the constructor c of t, applied to *)
(* fresh variables a, b, ... , it takes the matching encode clause and      *)
(* decodes its right-hand side with FULL_ENCODE_DECODE_THM of the product   *)
(* type used by the encoding.  It returns a pair:                           *)
(*   - chosen: a theorem about a fresh x:t under the hypothesis that x is   *)
(*     built with c (hypotheses discharged by DISCH_ALL_CONJ);              *)
(*   - the FST/SND projection equations of the decoded product, kept by     *)
(*     dest_filter, with their left-hand sides folded into compositions.    *)
(* Raises NotAPair when the encoding of c is not a product.                 *)
(* NOTE(review): subtypes is bound but never used.                          *)
fun mk_single_destructor target t c =
let val (types,_) = strip_fun (type_of c)
    val args = map (fn (n,t) => mk_var((implode o base26) n,t))
                   (enumerate 0 types);
    val cons = list_mk_comb(c,args)
    val encoders = CONJUNCTS (get_coding_function_def target t "encode")
    val e = SPEC_ALL (tryfind (C (PART_MATCH (rand o lhs)) cons) encoders)
    val encoder = PART_MATCH (rator o lhs) e (get_encode_function target t)
    val product = type_of (rand (rhs (concl encoder)))
                  handle _ => raise NotAPair
    val _ = if can pairLib.dest_prod product then () else raise NotAPair
    val applied = AP_TERM (get_decode_function target
                              product) (SPEC_ALL encoder);
    val decoder = SPEC_ALL (FULL_ENCODE_DECODE_THM target product)
    val decoder' = PART_MATCH (lhs o snd o strip_forall o snd o strip_imp)
                              decoder (rhs (concl applied))
    val decoder'' = SPEC_ALL (UNDISCH decoder' handle _ => decoder')
    val rewritten = RIGHT_CONV_RULE (REWR_CONV decoder'') applied
    val product_encoder = get_encode_function target product;
    val apped = RIGHT_CONV_RULE (REWR_CONV (GSYM encoder))
                                (AP_TERM product_encoder rewritten);
    val var = variant (thm_frees apped) (mk_var("x",t));
    val eterm = list_mk_exists(args,mk_eq(var,cons))
    val rapped = PURE_REWRITE_RULE [GSYM (ASSUME (mk_eq(var,cons)))] apped;
    val chosen = DISCH_ALL_CONJ (CHOOSE_L (args,ASSUME eterm) rapped);
    val subtypes = sub_types t
in
    (chosen,dest_filter (map (CONV_RULE (LAND_CONV (O_CONV c)) o
                              RIGHT_CONV_RULE (REWRITE_CONV [FST,SND]))
                             (PRODUCTS (length args) rewritten)))
end
(* mapf f l : maps f over l, dropping the elements for which f raises       *)
(* NotAPair; other exceptions propagate.  The tail is processed before the  *)
(* head.                                                                    *)
fun mapf f [] = []
  | mapf f (x::xs) =
let val r = mapf f xs
in (f x :: r) handle NotAPair => r | e => raise e end
in
(* mk_destructors target t : runs mk_single_destructor on every constructor *)
(* of t, skipping those whose encoding is not a product.  It returns the    *)
(* list of chosen theorems and the concatenated destructor theorems.        *)
(* Failures are wrapped with wrapException.                                 *)
fun mk_destructors target t =
let val (chosen,destructors) =
        unzip (mapf (mk_single_destructor target t) (constructors_of t))
        handle e => wrapException "mk_destructors" e
in
    (chosen, flatten destructors)
end
end

(*****************************************************************************)
(* Initialisation:                                                           *)
(*     initialise_source_function_generators : unit -> unit                  *)
(*     initialise_coding_function_generators : hol_type -> unit              *)
(*     initialise_coding_theorem_generators  : hol_type -> unit              *)
(*                                                                           *)
(*****************************************************************************)

(* Registers the compound generators for the "map" and "all" source         *)
(* functions.  The "all" generator post-processes its definition by        *)
(* unfolding the pair "all" definition under the quantifiers of each        *)
(* conjunct.  It falls back to REFL when nothing changes.                   *)
fun initialise_source_function_generators () =
let val _ = add_compound_source_function_generator
                "map"
                mk_map_term
                get_map_function
                REFL REFL;
    val _ = add_compound_source_function_generator
                "all"
                mk_all_term
                get_all_function
                (fn x => (EVERY_CONJ_CONV (STRIP_QUANT_CONV (RAND_CONV
                             (PURE_REWRITE_CONV [get_source_function_def (mk_prod(alpha,beta)) "all"])))) x
                         handle UNCHANGED => REFL x)
                REFL;
in
    ()
end;

(* Registers, for the given target type, the generators of the encode,      *)
(* detect, decode and fix coding functions and the "detect_dead" rule.      *)
(* The decode generator first generates the map source function and the     *)
(* encode coding function of the base type.                                 *)
fun initialise_coding_function_generators target =
let val _ = add_compound_coding_function_generator
                "encode"
                (mk_encode_term target)
                (get_encode_function target)
                (ENCODE_CONV (get_coding_function_def target (mk_prod(alpha,beta)) "encode"))
                REFL target;
    val _ = add_compound_target_function_generator
                "detect"
                (mk_detect_term target)
                (get_detect_function target)
                (DETECT_CONV target)
                REFL target;
    val _ = add_rule_coding_theorem_generator "detect_dead" (can constructors_of)
                (detect_dead_rule target) target;
    val _ = add_compound_target_function_generator
                "decode"
                (fn t => ( generate_source_function "map" (base_type t) ;
                           generate_coding_function target "encode" (base_type t) ;
                           mk_decode_term target t))
                (get_decode_function target)
                (DECODE_CONV target)
                (CONSOLIDATE_CONV I)
                target;
    val _ = add_compound_target_function_generator
                "fix"
                (mk_fix_term target)
                (get_fix_function target)
                (FIX_CONV target)
                (CONSOLIDATE_CONV rand)
                target;
in
    ()
end

(* Registers, for the given target type:                                    *)
(*   1. the conclusion builders of every coding and source theorem;         *)
(*   2. the inductive generators proving them with the *_tactic functions   *)
(*      defined earlier in this file;                                       *)
(*   3. the rule generators, enabled through check_target_rule_use /        *)
(*      check_source_rule_use.                                              *)
fun initialise_coding_theorem_generators target =
let val _ = set_coding_theorem_conclusion
                target "encode_detect_all"
                (mk_encode_detect_all_conc target);
    val _ = set_source_theorem_conclusion
                "all_id" mk_all_id_conc;
    val _ = set_source_theorem_conclusion
                "map_id" mk_map_id_conc;
    val _ = set_source_theorem_conclusion
                "map_compose" mk_map_compose_conc;
    val _ = set_coding_theorem_conclusion
                target "encode_decode_map" (mk_encode_decode_map_conc target);
    val _ = set_coding_theorem_conclusion
                target "encode_map_encode" (mk_encode_map_encode_conc target);
    val _ = set_coding_theorem_conclusion
                target "general_detect" (mk_general_detect_conc target);
    val _ = set_coding_theorem_conclusion
                target "decode_encode_fix" (mk_decode_encode_fix_conc target);
    val _ = set_coding_theorem_conclusion
                target "fix_id" (mk_fix_id_conc target);

    val _ = add_inductive_coding_theorem_generator
                "encode" "encode_detect_all"
                target FUN_EQ_CONV
                (encode_detect_all_tactic target);
    val _ = add_inductive_source_theorem_generator
                "all" "all_id"
                FUN_EQ_CONV all_id_tactic;
    val _ = add_inductive_source_theorem_generator
                "map" "map_id"
                FUN_EQ_CONV map_id_tactic;
    val _ = add_inductive_source_theorem_generator
                "map" "map_compose"
                FUN_EQ_CONV map_compose_tactic;
    val _ = add_inductive_coding_theorem_generator
                "encode" "encode_decode_map"
                target FUN_EQ_CONV
                (encode_decode_map_tactic target);
    val _ = add_inductive_coding_theorem_generator
                "encode" "encode_map_encode"
                target FUN_EQ_CONV
                (encode_map_encode_tactic target);
    val _ = add_inductive_coding_theorem_generator
                "detect" "general_detect"
                target REFL
                (general_detect_tactic target);
    val _ = add_inductive_coding_theorem_generator
                "decode" "decode_encode_fix"
                target FUN_EQ_CONV
                (decode_encode_fix_tactic target);
    val _ = add_inductive_coding_theorem_generator
                "fix" "fix_id"
                target REFL
                (fix_id_tactic target);

    (* True when the theorem already exists for t, or when no induction     *)
    (* theorem can be obtained for the function at t (presumably so the     *)
    (* inductive generator cannot be used).                                 *)
    fun check_target_rule_use function_name theorem_name t =
        (exists_coding_theorem target t theorem_name) orelse
        not (can (C (get_coding_function_induction target) function_name) t)

    fun check_source_rule_use function_name theorem_name t =
        (exists_source_theorem t theorem_name) orelse
        not (can (C get_source_function_induction function_name) t)

    val _ = add_rule_coding_theorem_generator
                "encode_detect_all"
                (check_target_rule_use "encode" "encode_detect_all")
                (FULL_ENCODE_DETECT_ALL_THM target)
                target;
    val _ = add_rule_source_theorem_generator
                "all_id"
                (check_source_rule_use "all" "all_id")
                FULL_ALL_ID_THM;
    val _ = add_rule_source_theorem_generator
                "map_id"
                (check_source_rule_use "map" "map_id")
                FULL_MAP_ID_THM;
    val _ = add_rule_source_theorem_generator
                "map_compose"
                (check_source_rule_use "map" "map_compose")
                FULL_MAP_COMPOSE_THM;
    val _ = add_rule_coding_theorem_generator
                "encode_decode_map"
                (check_target_rule_use "encode" "encode_decode_map")
                (FULL_ENCODE_DECODE_MAP_THM target)
                target;
    val _ = add_rule_coding_theorem_generator
                "encode_map_encode"
                (check_target_rule_use "encode" "encode_map_encode")
                (FULL_ENCODE_MAP_ENCODE_THM target)
                target;
    val _ = add_rule_coding_theorem_generator
                "general_detect"
                (C (exists_coding_theorem target) "general_detect")
                (C (get_coding_theorem target) "general_detect")
                target;
    val _ = add_rule_coding_theorem_generator
                "decode_encode_fix"
                (check_target_rule_use "decode" "decode_encode_fix")
                (FULL_DECODE_ENCODE_FIX_THM target)
                target;
    val _ = add_rule_coding_theorem_generator
                "fix_id"
                (check_target_rule_use "fix" "fix_id")
                (FULL_FIX_ID_THM target)
                target;
in
    ()
end;

(* encode_type target t : t must be a base type (t matches base_type t).    *)
(* It adds the translation entry if missing and generates the map/all       *)
(* source functions and the encode/decode/detect/fix coding functions.      *)
(* It then generates the encode_detect_all, encode_map_encode,              *)
(* encode_decode_map, decode_encode_fix and fix_id theorems.  Any failure   *)
(* is wrapped with wrapException.                                           *)
fun encode_type target t =
let val _ = if can (match_type t) (base_type t) then () else
                raise (mkDebugExn "encode_type"
                           "encode_type should only be applied to base types")
    val _ = if exists_translation target t
               then ()
               else add_translation target t
    val _ = generate_source_function "map" t
    val _ = generate_source_function "all" t
    val _ = generate_coding_function target "encode" t
    val _ = generate_coding_function target "decode" t
    val _ = generate_coding_function target "detect" t
    val _ = generate_coding_function target "fix" t

    val _ = generate_coding_theorem target "encode_detect_all" t
    val _ = generate_coding_theorem target "encode_map_encode" t
    val _ = generate_coding_theorem target "encode_decode_map" t
    val _ = generate_coding_theorem target "decode_encode_fix" t

    val _ = generate_coding_theorem target "fix_id" t
in
    ()
end handle e => wrapException "encode_type" e

local
(* GENCF name f target t : returns f target t.  For t other than the target *)
(* type, it first runs encode_type on the base type when the named coding   *)
(* function does not exist yet.                                             *)
(* NOTE(review): as SML parses it, the handle clause belongs to the else    *)
(* branch only, so exceptions from  f target target  are not wrapped -      *)
(* confirm this is intended.                                                *)
fun GENCF name f target t =
    if (target = t) then f target t
    else
        (if exists_coding_function target t name
            then f target t
            else (encode_type target (base_type t) ; f target t))
        handle e => wrapException ("gen_" ^ name ^ "_function") e
in
val gen_encode_function = GENCF "encode" get_encode_function
val gen_decode_function = GENCF "decode" get_decode_function
val gen_detect_function = GENCF "detect" get_detect_function
end;

(*****************************************************************************)
(* predicate_equivalence : hol_type -> hol_type -> thm                       *)
(*     Returns a theorem of the form:                                        *)
(*         |- (!x. P x) = (!x. detect x ==> P (decode x))                    *)
(*     for a type t.                                                         *)
(*     This can then be used to derive a fully encoded theorem using rule    *)
(*     implication and the encoding of booleans.                             *)
(*     Hypotheses left by the encode/detect and encode/decode theorems used  *)
(*     in the proof are discharged by DISCH_ALL_CONJ, so the result may be   *)
(*     conditional (presumably on a conjunction of those hypotheses).        *)
(*                                                                           *)
(*****************************************************************************)

fun predicate_equivalence target t =
let val pred = mk_var("P",t --> bool)
    val var = mk_var("x",t);
    val target_var = mk_var("x",target);
    val detect = mk_comb(get_detect_function target t,target_var)
    val decode = mk_comb(get_decode_function target t,target_var)
    val encode = mk_comb(get_encode_function target t,var)

    val full_pred = mk_forall(var,mk_comb(pred,var))

    (* (!x. P x) ==> (!x. detect x ==> P (decode x)) *)
    val thm1 = GEN target_var (DISCH detect (SPEC decode (ASSUME full_pred)))

    val encdet = UNDISCH_ALL (SPEC_ALL (FULL_ENCODE_DETECT_THM target t))
    val decenc = UNDISCH_ALL (SPEC_ALL (FULL_ENCODE_DECODE_THM target t))
    (* the converse: instantiate at  encode x  and simplify *)
    val thm2 = GEN var (PURE_REWRITE_RULE [encdet,decenc,IMP_CLAUSES]
                           (SPEC encode (ASSUME (concl thm1))))
in
    DISCH_ALL_CONJ (IMP_ANTISYM_RULE (DISCH (concl thm2) thm1)
                                     (DISCH (concl thm1) thm2))
end handle e => wrapException "predicate_equivalence" e


end