(*---------------------------------------------------------------------------*
 * Simplifications for datatypes. This library extracts information about    *
 * datatypes from TypeBase and provides theorems, conversions and simpset    *
 * fragments that are suitable for reasoning about these datatypes.          *
 *---------------------------------------------------------------------------*)
structure DatatypeSimps :> DatatypeSimps =
struct

open HolKernel Parse boolLib TypeBasePure
open simpLib

(* Basic boolean simpset used by the internal proofs below. *)
val std_ss = boolSimps.bool_ss


(* [map_option_filter f l] applies [f] to each element [x] of [l] and keeps,
   in order, every [fx] for which [f x] returns [SOME fx]. Elements for which
   [f] returns [NONE] or raises an exception are dropped; [Interrupt] is
   re-raised so that user interrupts are not swallowed. *)
fun map_option_filter f [] = []
  | map_option_filter f (x :: xs) =
      (case (f x handle Interrupt => raise Interrupt | _ => NONE) of
           NONE => map_option_filter f xs
         | SOME fx => fx :: map_option_filter f xs)

(* TypeBase entries of the types [tyl]; types for which [TypeBase.fetch]
   yields nothing (or raises) are skipped. *)
fun tyinfos_of_tys tyl = map_option_filter TypeBase.fetch tyl


(******************************************************************************)
(* Generating thms                                                            *)
(******************************************************************************)

(* [make_variant_list n s avoid tys] returns one variable per type in [tys].
   The i-th variable is named [s ^ Int.toString (n + i)], renamed with
   [variant] so that it differs from every variable in [avoid] and from all
   variables created before it. *)
fun make_variant_list n s avoid [] = []
  | make_variant_list n s avoid (h :: t) =
      let
        val v = variant avoid (mk_var (s ^ Int.toString n, h))
      in
        v :: make_variant_list (n + 1) s (v :: avoid) t
      end

(* [make_args_simple (ty :: tys)] returns a variable [M : ty] followed by
   variables [f0, f1, ...] of the types [tys], all distinct from [M].
   The empty list is mapped to the empty list. *)
fun make_args_simple [] = []
  | make_args_simple (ty :: tys) =
      let
        val arg0 = mk_var ("M", ty)
        val args = make_variant_list 0 "f" [arg0] tys
      in
        arg0 :: args
      end

(* [make_args_abs tyL] creates variables for the argument types [tyL] of a
   case constant. The head of [tyL] (which must exist: [hd] raises on an
   empty list) becomes the variable [M]. Every remaining type
   [t1 -> ... -> tk -> r] is turned into a pair [(xs, fn)], where [xs] are
   binder variables of types [t1 ... tk] (named x<m>, x<m+1>, ... up to
   renaming) and [fn] is a variable of the whole type (named f<n> up to
   renaming). Returns [(M, pairs)] with the pairs in the order of [tl tyL];
   all created variables are pairwise distinct. *)
fun make_args_abs tyL = let
  fun aux res n m avoid ty = let
    val (arg_tyL, _) = strip_fun ty
    val args = make_variant_list m "x" avoid arg_tyL
    val b = variant (args @ avoid) (mk_var ("f" ^ Int.toString n, ty))
  in
    ((args, b) :: res, n + 1, m + length arg_tyL, b :: (args @ avoid))
  end
  val arg0 = mk_var ("M", hd tyL)
  val (args, _, _, _) =
      foldl (fn (ty, (res, n, m, avoid)) => aux res n m avoid ty)
            ([], 0, 0, [arg0]) (tl tyL)
in
  (arg0, rev args)
end

(* [mk_type_forall_thm_tyinfo tyinfo] proves [|- !P. (!tt. P tt) = rhs] for
   the datatype described by [tyinfo]. [!tt. P tt] is first rewritten to
   [!tt. (disjunction from the nchotomy theorem) ==> P tt]; simplifying with
   [DISJ_IMP_THM], [LEFT_FORALL_IMP_THM] and [FORALL_AND_THM] then splits the
   quantifier into one quantification per constructor. *)
fun mk_type_forall_thm_tyinfo tyinfo = let
  val nchotomy_thm = nchotomy_of tyinfo
  val ty = type_of (fst (dest_forall (concl nchotomy_thm)))

  val P_tm = mk_var ("P", ty --> bool)
  val input_tm = mk_var ("tt", ty)
  val body_tm = mk_comb (P_tm, input_tm)

  (* |- P tt = (T ==> P tt) *)
  val thm_base = GSYM (CONJUNCT1 (SPEC body_tm IMP_CLAUSES))
  (* |- T = (nchotomy disjunction instantiated at tt) *)
  val true_expand_thm = GSYM (EQT_INTRO (SPEC input_tm nchotomy_thm))
  (* replace the antecedent T by the nchotomy disjunction *)
  val thm1 =
      CONV_RULE (RHS_CONV (RATOR_CONV (RAND_CONV (K true_expand_thm))))
                thm_base

  val thm2 = QUANT_CONV (K thm1) (mk_forall (input_tm, body_tm))
  val thm3 =
      CONV_RULE (RHS_CONV (SIMP_CONV std_ss
                   [DISJ_IMP_THM, GSYM LEFT_FORALL_IMP_THM, FORALL_AND_THM]))
                thm2
in
  GEN P_tm thm3
end

(* [mk_type_quant_thms_tyinfo tyinfo] returns [(forall_thm, exists_thm)].
   [forall_thm] is [mk_type_forall_thm_tyinfo tyinfo]. [exists_thm] is derived
   from it by instantiating [P] with [\x. ~P x], negating both sides and
   simplifying with the basic boolean simpset; this gives the corresponding
   case expansion of an existential quantifier over the datatype. *)
fun mk_type_quant_thms_tyinfo tyinfo = let
  val forall_thm = mk_type_forall_thm_tyinfo tyinfo

  val (P_tm, _) = dest_forall (concl forall_thm)
  val P_arg_tm = genvar (hd (fst (strip_fun (type_of P_tm))))
  val P_neg_tm = mk_abs (P_arg_tm, mk_neg (mk_comb (P_tm, P_arg_tm)))

  val thm0 = SPEC P_neg_tm forall_thm
  val thm1 = AP_TERM boolSyntax.negation thm0
  val thm2 = CONV_RULE (BINOP_CONV (SIMP_CONV std_ss [])) thm1
in
  (forall_thm, GEN P_tm thm2)
end

(* The existential expansion theorem: the second component of
   [mk_type_quant_thms_tyinfo tyinfo]. *)
fun mk_type_exists_thm_tyinfo tyinfo =
    snd (mk_type_quant_thms_tyinfo tyinfo)

(* [mk_case_elim_thm_tyinfo tyinfo] proves that a case expression whose
   branches all return the same value [c] (independent of the branch
   arguments) equals [c]:
     |- !M c. case_c M (\xs1. c) ... (\xsn. c) = c
   The proof expands [!M] with [mk_type_forall_thm_tyinfo] and rewrites with
   the case definition. *)
fun mk_case_elim_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo
  val (arg_tyL, base_ty) = strip_fun (type_of case_c)
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid =
      input_arg :: flatten (List.map (fn (args, b) => b :: args) case_args)
  val const = variant avoid (mk_var ("c", base_ty))

  val t0 = mk_comb (case_c, input_arg)
  val t1 = foldl (fn ((args, _), t) => mk_comb (t, list_mk_abs (args, const)))
                 t0 case_args
  val t2 = mk_eq (t1, const)
  val t3 = list_mk_forall ([input_arg, const], t2)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t3
  val thm1 = CONV_RULE (RHS_CONV (REWRITE_CONV [simp_thm])) thm0
in
  EQT_ELIM thm1
end

(* [mk_type_rewrites_tyinfo tyinfo] collects rewrites for the datatype: the
   case elimination theorem, the clauses of the case definition, the
   distinctness theorems in both orientations (if the datatype has them) and
   the injectivity theorems of the constructors (if it has them). *)
fun mk_type_rewrites_tyinfo tyinfo = let
  val thms_def = CONJUNCTS (case_def_of tyinfo)

  val thms_dist =
      case distinct_of tyinfo of
          NONE => []
        | SOME thm_dist0 =>
            let val thms_dist1 = CONJUNCTS thm_dist0
            in thms_dist1 @ List.map GSYM thms_dist1 end

  val thms_one_one =
      case one_one_of tyinfo of
          NONE => []
        | SOME thm => CONJUNCTS thm

  val elim_thms = [mk_case_elim_thm_tyinfo tyinfo]
in
  elim_thms @ thms_def @ thms_dist @ thms_one_one
end

(* [mk_case_cong_thm_tyinfo tyinfo] proves the congruence rule of the case
   constant:
     |- !M M' f1 .. fn f1' .. fn'.
          (M = M') ==>
          (!xs1. (M' = C1 xs1) ==> (f1 xs1 = f1' xs1)) ==> ... ==>
          (case_c M (\xs1. f1 xs1) ... = case_c M' (\xs1. f1' xs1) ...)
   so each branch may be rewritten under the assumption that the examined
   value is built with the corresponding constructor. *)
fun mk_case_cong_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo
  val (arg_tyL, _) = strip_fun (type_of case_c)
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid =
      input_arg :: flatten (List.map (fn (args, b) => b :: args) case_args)
  val input_arg' = variant avoid input_arg
  (* primed copies of the branch function variables, all distinct *)
  val (_, case_args') =
      Lib.foldl_map (fn (av, (args, v)) =>
                        let val v' = variant av v
                        in (v' :: av, (args, v')) end)
                    (input_arg' :: avoid, case_args)

  (* case_c m (\xs1. f1 xs1) ... (\xsn. fn xsn) *)
  fun mk_case_tm m branches =
      list_mk_icomb (case_c,
        m :: List.map (fn (args, c) =>
                          list_mk_abs (args, list_mk_comb (c, args)))
                      branches)
  val t1a = mk_case_tm input_arg case_args
  val t1b = mk_case_tm input_arg' case_args'
  val t2 = mk_eq (t1a, t1b)

  val constr = constructors_of tyinfo
  val M_eq = mk_eq (input_arg, input_arg')
  (* !xs. (M' = C xs) ==> (f xs = f' xs) *)
  fun mk_branch_imp args c c' cr = let
    val concl_tm = mk_eq (list_mk_icomb (c, args), list_mk_icomb (c', args))
    val hyp_tm = mk_eq (input_arg', list_mk_icomb (cr, args))
  in
    list_mk_forall (args, boolSyntax.mk_imp (hyp_tm, concl_tm))
  end
  val imps =
      List.map (fn (((args, c), (_, c')), cr) => mk_branch_imp args c c' cr)
               (zip (zip case_args case_args') constr)

  val t3 = boolSyntax.list_mk_imp (M_eq :: imps, t2)
  val t4 = list_mk_forall ([input_arg, input_arg'] @ List.map snd case_args
                           @ List.map snd case_args', t3)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t4
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
in
  EQT_ELIM thm1
end

(* [mk_case_rand_thm_tyinfo tyinfo] proves that a function applied to a case
   expression can be pushed into its branches:
     |- !M f f1 .. fn.
          f (case_c M (\xs1. f1 xs1) ...) = case_c M (\xs1. f (f1 xs1)) ...
   where [f] has a fresh result type variable. *)
fun mk_case_rand_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo
  val (arg_tyL, base_ty) = strip_fun (type_of case_c)
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid =
      input_arg :: flatten (List.map (fn (args, b) => b :: args) case_args)
  val res_ty = gen_tyvar ()
  val const = variant avoid (mk_var ("f", base_ty --> res_ty))

  val case_args0 =
      List.map (fn (args, c) => list_mk_abs (args, list_mk_comb (c, args)))
               case_args
  val t1 = list_mk_icomb (case_c, input_arg :: case_args0)
  val t2a = mk_comb (const, t1)

  val case_args1 =
      List.map (fn (args, c) =>
                   list_mk_abs (args, mk_comb (const, list_mk_comb (c, args))))
               case_args
  val t2b = list_mk_icomb (case_c, input_arg :: case_args1)

  val t3 = mk_eq (t2a, t2b)
  val consts = List.map snd case_args
  val t4 = list_mk_forall ([input_arg, const] @ consts, t3)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t4
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
in
  EQT_ELIM thm1
end

(* [mk_case_rator_thm_tyinfo tyinfo] proves that an argument applied to a
   function-valued case expression can be pushed into its branches:
     |- !M x f1 .. fn.
          (case_c M (\xs1. f1 xs1) ...) x = case_c M (\xs1. f1 xs1 x) ...
   The result type of the case constant is instantiated to the fresh
   function type [res_ty -> base_ty']. *)
fun mk_case_rator_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo
  val (arg_tyL, base_ty) = strip_fun (type_of case_c)
  val res_ty = gen_tyvar ()
  val base_ty' = gen_tyvar ()
  val inst_ty = inst [base_ty |-> (res_ty --> base_ty')]

  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid =
      input_arg :: flatten (List.map (fn (args, b) => b :: args) case_args)
  val const = variant avoid (mk_var ("x", res_ty))

  val case_args0 =
      List.map (fn (args, c) => list_mk_abs (args, list_mk_comb (c, args)))
               case_args
  val t1 = list_mk_icomb (case_c, input_arg :: case_args0)
  val t3a = mk_icomb (inst_ty t1, const)

  val case_args1 =
      List.map (fn (args, c) =>
                   list_mk_abs (args,
                     mk_comb (inst_ty (list_mk_comb (c, args)), const)))
               case_args
  val t3b = list_mk_icomb (case_c, input_arg :: case_args1)

  val t4 = mk_eq (t3a, t3b)
  val consts = List.map (fn (_, t) => inst_ty t) case_args
  val t5 = list_mk_forall ([input_arg, const] @ consts, t4)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t5
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
in
  EQT_ELIM thm1
end

(* [mk_case_abs_thm_tyinfo tyinfo] proves that an abstraction around a case
   expression can be pushed into its branches:
     |- !M x f1 .. fn.
          (\x. case_c M (\xs1. f1 xs1 x) ...) =
          case_c M (\xs1. \x. f1 xs1 x) ...
   The branch variables are instantiated to the fresh function type
   [res_ty -> base_ty']. *)
fun mk_case_abs_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo
  val (arg_tyL, base_ty) = strip_fun (type_of case_c)
  val res_ty = gen_tyvar ()
  val base_ty' = gen_tyvar ()
  val inst_ty = inst [base_ty |-> (res_ty --> base_ty')]
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid =
      input_arg :: flatten (List.map (fn (args, b) => b :: args) case_args)
  val const = variant avoid (mk_var ("x", res_ty))

  val case_args0 =
      List.map (fn (args, c) =>
                   list_mk_abs (args,
                     mk_comb (inst_ty (list_mk_comb (c, args)), const)))
               case_args
  val t1 = list_mk_icomb (case_c, input_arg :: case_args0)
  val t2a = mk_abs (const, t1)

  val case_args1 =
      List.map (fn (args, c) =>
                   list_mk_abs (args,
                     mk_abs (const,
                       mk_comb (inst_ty (list_mk_comb (c, args)), const))))
               case_args
  val t2b = list_mk_icomb (case_c, input_arg :: case_args1)

  val t3 = mk_eq (t2a, t2b)
  (* same instantiation as used for the branch bodies above *)
  val consts = List.map (fn (_, t) => inst_ty t) case_args
  val t4 = list_mk_forall ([input_arg, const] @ consts, t3)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t4
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
in
  EQT_ELIM thm1
end


(******************************************************************************)
(* Lifting                                                                    *)
(******************************************************************************)

(* [lift_case_const_CONV stop_consts rand_thms] rewrites a term once, at the
   top level, with the higher-order rewrites [rand_thms] (the case rand
   theorems). It raises [UNCHANGED] if no theorem applies, or if the head of
   the term is one of [stop_consts] applied to more than one argument
   (presumably so that a case expression is not pushed into the branches of
   another case expression). The stop check is done first so that no rewrite
   is attempted on terms that would be rejected anyway. *)
fun lift_case_const_CONV stop_consts rand_thms = let
  val conv = Ho_Rewrite.GEN_REWRITE_CONV I rand_thms
  fun is_stopped t = let
    val (c, args) = strip_comb t
  in
    List.length args > 1 andalso List.exists (same_const c) stop_consts
  end
in
  fn t => (if is_stopped t then raise UNCHANGED else conv t)
          handle HOL_ERR _ => raise UNCHANGED
end

(* Simpset fragment lifting case expressions of the datatypes [til] outwards:
   [f (case ...)], [(case ...) x] and [\x. case ...] become case expressions
   whose branches contain [f], [x] or the abstraction respectively. *)
fun lift_cases_typeinfos_ss til = let
  val rand_thms = Lib.mapfilter mk_case_rand_thm_tyinfo til
  val rator_thms = Lib.mapfilter mk_case_rator_thm_tyinfo til
  val abs_thms = Lib.mapfilter mk_case_abs_thm_tyinfo til
  val consts = Lib.mapfilter case_const_of til

  val conv_rand = lift_case_const_CONV consts rand_thms
  val conv_rand_ss = simpLib.std_conv_ss {
        name = "lift_case_const_CONV",
        pats = [``f x``],
        conv = conv_rand}

  val rewr_ss = simpLib.rewrites (abs_thms @ rator_thms)
in
  simpLib.merge_ss [rewr_ss, conv_rand_ss]
end

(* [lift_cases_typeinfos_ss] for the TypeBase entries of the types [tyL]. *)
fun lift_cases_ss tyL = lift_cases_typeinfos_ss (tyinfos_of_tys tyL)

(* [lift_cases_typeinfos_ss] for all types currently in the TypeBase. *)
fun lift_cases_stateful_ss () = lift_cases_typeinfos_ss (TypeBase.elts ())


(******************************************************************************)
(* Reverse Lifting                                                            *)
(******************************************************************************)

(* [unlift_case_const_CONV stop_consts rand_thms] rewrites a term once, at the
   top level, with the first-order rewrites [rand_thms]. It raises [UNCHANGED]
   if no theorem applies, or if the head of the resulting right-hand side is
   one of [stop_consts] applied to more than one argument. *)
fun unlift_case_const_CONV stop_consts rand_thms = let
  val conv = Rewrite.GEN_REWRITE_CONV I empty_rewrites rand_thms
in
  fn t => let
    val thm = conv t
    val (c, args) = strip_comb (rhs (concl thm))
  in
    if List.length args > 1 andalso List.exists (same_const c) stop_consts
    then raise UNCHANGED
    else thm
  end handle HOL_ERR _ => raise UNCHANGED
end

(* Simpset fragment performing the reverse of [lift_cases_typeinfos_ss]: the
   rand, rator and abs theorems are used right-to-left, pulling common
   functions, arguments and abstractions out of the branches of case
   expressions of the datatypes [til]. *)
fun unlift_cases_typeinfos_ss til = let
  val rand_thms = List.map GSYM (Lib.mapfilter mk_case_rand_thm_tyinfo til)
  val rator_thms = List.map GSYM (Lib.mapfilter mk_case_rator_thm_tyinfo til)
  val abs_thms = List.map GSYM (Lib.mapfilter mk_case_abs_thm_tyinfo til)
  val consts = Lib.mapfilter case_const_of til

  val conv_rand = unlift_case_const_CONV consts rand_thms
  val conv_rand_ss = simpLib.std_conv_ss {
        name = "unlift_case_const_CONV",
        pats = [``f x``],
        conv = conv_rand}

  (* NOTE(review): this fragment reuses the name of [conv_rand_ss] above,
     which may make simplifier traces ambiguous -- confirm whether a distinct
     name was intended. *)
  val conv_rator_ss = simpLib.std_conv_ss {
        name = "unlift_case_const_CONV",
        pats = [``f x``],
        conv = Rewrite.GEN_REWRITE_CONV I empty_rewrites rator_thms}

  val rewr_ss = simpLib.rewrites abs_thms
in
  simpLib.merge_ss [rewr_ss, conv_rator_ss, conv_rand_ss]
end

(* [unlift_cases_typeinfos_ss] for the TypeBase entries of the types [tyL]. *)
fun unlift_cases_ss tyL = unlift_cases_typeinfos_ss (tyinfos_of_tys tyL)

(* [unlift_cases_typeinfos_ss] for all types currently in the TypeBase. *)
fun unlift_cases_stateful_ss () = unlift_cases_typeinfos_ss (TypeBase.elts ())


(******************************************************************************)
(* Simpset fragments                                                          *)
(******************************************************************************)

(* Rewrites of [mk_type_rewrites_tyinfo] for all datatypes in [til]; entries
   for which generation raises are skipped by [Lib.mapfilter]. *)
fun type_rewrites_typeinfos_ss til =
    rewrites (flatten (Lib.mapfilter mk_type_rewrites_tyinfo til))

fun type_rewrites_ss tyL = type_rewrites_typeinfos_ss (tyinfos_of_tys tyL)

fun type_rewrites_stateful_ss () = type_rewrites_typeinfos_ss (TypeBase.elts ())

(* Simpset fragment containing only the congruence rules [thms]. *)
fun congs thms = SSFRAG
  {name = NONE,
   convs = [],
   rewrs = [],
   ac = [],
   filter = NONE,
   dprocs = [],
   congs = thms}

(* Case congruence rules of the datatypes [til], merged with their type
   rewrites so that the branch assumptions can be exploited. *)
fun case_cong_typeinfos_ss til =
    simpLib.merge_ss [congs (Lib.mapfilter mk_case_cong_thm_tyinfo til),
                      type_rewrites_typeinfos_ss til]

fun case_cong_ss tyL = case_cong_typeinfos_ss (tyinfos_of_tys tyL)

fun case_cong_stateful_ss () = case_cong_typeinfos_ss (TypeBase.elts ())


(* Rewrites expanding universal and existential quantifiers over the
   datatypes [til] into per-constructor quantifiers. *)
fun expand_type_quants_typeinfos_ss til =
    rewrites (flatten (List.map (fn (x, y) => [x, y])
                        (Lib.mapfilter mk_type_quant_thms_tyinfo til)))

fun expand_type_quants_ss tyL =
    expand_type_quants_typeinfos_ss (tyinfos_of_tys tyL)

fun expand_type_quants_stateful_ss () =
    expand_type_quants_typeinfos_ss (TypeBase.elts ())


(******************************************************************************)
(* Rule for eliminating case splits in equations                              *)
(******************************************************************************)

(* [cases_to_top_RULE thm] splits the conjuncts of [thm] (after stripping
   universal quantifiers) into equations and other theorems. Every equation
   whose right-hand side contains a case expression on a free variable of its
   left-hand side is replaced by one equation per constructor: the variable
   is instantiated with the constructor applied to fresh variables and the
   right-hand side is rewritten with the datatype's rewrites and
   beta-reduced. Results are processed again, so nested splits are handled
   as well. Equations that cannot be split are kept. The final theorem is
   the conjunction of the (generalised) processed equations followed by the
   other conjuncts. *)
fun cases_to_top_RULE thm = let
  val input_thmL = BODY_CONJUNCTS thm
  val (input_eqL, input_restL) =
      partition (fn thm => is_eq (concl thm)) input_thmL

  (* [pick_args tys bvs n] returns one variable per constructor argument type
     in [tys]. The binders [bvs] of the case branch are reused positionally
     while they last (their types agree with [tys] for a well-typed branch);
     remaining types get fresh variables x<n>, x<n+1>, ...; surplus binders
     (branches returning functions) are ignored. Using the constructor's
     arity rather than the branch binders makes eta-reduced branches and
     function-valued branches work. *)
  fun pick_args [] _ _ = []
    | pick_args (_ :: tys) (bv :: bvs) n = bv :: pick_args tys bvs (n + 1)
    | pick_args (ty :: tys) [] n =
        mk_var ("x" ^ Int.toString n, ty) :: pick_args tys [] (n + 1)

  (* Split one equation on the first suitable case expression of its
     right-hand side; NONE if there is none or splitting fails. *)
  fun process_eq eq_thm = let
    val free_vars_lhs = free_vars (lhs (concl eq_thm))
    (* a case constant applied to a free variable of the left-hand side *)
    fun search_pred t = let
      val (c, args) = strip_comb t
      val _ = if length args = 0 then fail () else ()
      val _ = if List.exists (term_eq (hd args)) free_vars_lhs then ()
              else fail ()
      val case_const = TypeBase.case_const_of (type_of (hd args))
    in
      same_const c case_const
    end handle HOL_ERR _ => false
    val case_term = find_term search_pred (rhs (concl eq_thm))
    val (_, split_args) = strip_comb case_term
    val split_var = hd split_args
    val tyinfo = valOf (TypeBase.fetch (type_of split_var))
                 handle Option => fail ()
    val free_vars_full = free_vars (concl eq_thm)
    (* constructor applied to variables fresh w.r.t. the equation *)
    fun mk_split_term (c_tm, cr_tm) = let
      val (_, cr_ret_type) = strip_fun (type_of cr_tm)
      val ty_inst = match_type cr_ret_type (type_of split_var)
      val cr_tm' = inst ty_inst cr_tm
      val (arg_tys, _) = strip_fun (type_of cr_tm')
      val (bvs, _) = strip_abs c_tm
      val args = pick_args arg_tys bvs 0
      val (_, args') =
          foldl_map (fn (av, v) => let val v' = variant av v
                                   in (v' :: av, v') end)
                    (free_vars_full, args)
    in
      list_mk_comb (cr_tm', args')
    end
    val split_terms =
        List.map mk_split_term
                 (zip (tl split_args) (TypeBasePure.constructors_of tyinfo))

    val rhs_conv = REWRITE_CONV (#rewrs (TypeBasePure.simpls_of tyinfo))
                   THENC DEPTH_CONV BETA_CONV
    fun process_thm split_tm = let
      val thm0 = INST [split_var |-> split_tm] eq_thm
    in
      CONV_RULE (RHS_CONV rhs_conv) thm0
    end
  in
    SOME (List.map process_thm split_terms)
  end handle HOL_ERR _ => NONE

  (* worklist: split results are pushed back to be processed again *)
  fun process_all acc [] = List.rev acc
    | process_all acc (eq_thm :: thms) =
        (case process_eq eq_thm of
             NONE => process_all (eq_thm :: acc) thms
           | SOME eq_thms => process_all acc (eq_thms @ thms))

  val processed_eq_thms = process_all [] input_eqL
  val all_thms = processed_eq_thms @ input_restL
in
  LIST_CONJ (List.map GEN_ALL all_thms)
end

end