structure constrFamiliesLib :> constrFamiliesLib =
struct

open HolKernel Parse boolLib Drule BasicProvers
open boolLib simpLib patternMatchesSyntax numLib

(***************************************************)
(* Auxiliary definitions                           *)
(***************************************************)

(* A simpset fragment consisting only of the given congruence
   theorems. *)
fun cong_ss thms = simpLib.SSFRAG {
  name = NONE,
  convs = [],
  rewrs = [],
  ac = [],
  filter = NONE,
  dprocs = [],
  congs = thms}

(* Raise a HOL_ERR attributed to this library. [f] names the function
   reporting the problem, [x] is the message. Inside this structure it
   shadows the one-argument Lib.failwith. *)
fun failwith f x =
  raise (mk_HOL_ERR "constrFamiliesLib" f x)

(* [variants used_vs vs] renames the variables of [vs] with [variant]
   so that none of them clashes with [used_vs] or with an earlier
   element of [vs]. The order of [vs] is preserved. *)
fun variants used_vs vs = let
  val (_, vs') = foldl (fn (v, (used_vs, vs')) =>
    let val v' = variant used_vs v
    in (v'::used_vs, v'::vs') end) (used_vs, []) vs
in
  List.rev vs'
end

(* list_mk_comb with built-in beta reduction: whenever the term built so
   far is an abstraction, the next argument is substituted for its bound
   variable instead of forming a beta-redex; otherwise mk_comb is used.

   Only dest_abs is guarded by the exception handler. A failure while
   processing a later argument is therefore reported directly instead of
   being retried through the mk_comb fallback at every enclosing level. *)
fun list_mk_comb_subst (c, []) = c
  | list_mk_comb_subst (c, a :: args') =
      (case (SOME (dest_abs c) handle HOL_ERR _ => NONE) of
           SOME (v, body) => list_mk_comb_subst (subst [v |-> a] body, args')
         | NONE => list_mk_comb_subst (mk_comb (c, a), args'))

(*-----------------------------------------*)
(* normalise free type variables in a type *)
(* in order to use it as a map key         *)
(*-----------------------------------------*)

(* The type variable following [ty] in the sequence used for
   normalisation, computed with Lexis.tyvar_vary. Fails if [ty] is not
   a type variable. *)
fun next_ty ty = mk_vartype(Lexis.tyvar_vary (dest_vartype ty));

(* Rename the type variables of [ty], in order of first occurrence
   (depth-first, left to right), to 'a and its successors under
   next_ty. Types that differ only in the names of their type variables
   are thus mapped to the same normalised type. *)
fun normalise_ty ty = let
  (* [acc] is the renaming built so far paired with the next unused
     type variable; [tylist] is the work list of types still to visit. *)
  fun recurse (acc as (dict, usethis)) tylist =
    case tylist of
        [] => acc
      | ty :: rest =>
          if is_vartype ty then
            (case Binarymap.peek (dict, ty) of
                 NONE => recurse (Binarymap.insert (dict, ty, usethis),
                                  next_ty usethis)
                                 rest
               | SOME _ => recurse acc rest)
          else let
            val {Args, ...} = dest_thy_type ty
          in
            recurse acc (Args @ rest)
          end
  val (inst0, _) = recurse (Binarymap.mkDict Type.compare, Type.alpha) [ty]
  val inst = Binarymap.foldl (fn (tyk, tyv, acc) => (tyk |-> tyv) :: acc)
                             []
                             inst0
in
  Type.type_subst inst ty
end


(* Replace the arguments of the type operator application [ty] by
   distinct type variables 'a, next_ty 'a, ..., one per argument;
   for example num list becomes 'a list. Raises HOL_ERR (from
   dest_type) if [ty] is a type variable. *)
fun base_ty ty = let
  val (tn, targs) = dest_type ty
  val targs' = List.rev (snd (List.foldl (fn (_, (v, l)) => (next_ty v, v::l))
                                         (Type.alpha, []) targs))
in
  mk_type (tn, targs')
end


(*------------------------*)
(* Encoding theorem lists *)
(*------------------------*)

(* Encode a list of optional proof obligations as a single goal. Every
   element is wrapped in a THM_PART label (NONE becoming T) and the
   labelled terms are conjoined into t. The result is the theorem
   |- t = t', where t' is obtained from t by removing the labels and
   rewriting with REWRITE_CONV []; t' is the goal to be proved. *)
fun encode_term_opt_list tl = let
  val tl' = List.map (fn t => markerSyntax.mk_label ("THM_PART", Option.getOpt (t, T))) tl
  val t = list_mk_conj tl'
in
  (markerLib.DEST_LABELS_CONV THENC REWRITE_CONV []) t
end

(* Inverse of encode_term_opt_list: split a theorem for the labelled
   conjunction into one theorem per label, remove the labels, and map
   conjuncts that are just T back to NONE. *)
fun decode_thm_opt_list combined_thm = let
  fun process_thm thm = let
    val thm' = CONV_RULE markerLib.DEST_LABEL_CONV thm
  in
    if (aconv (concl thm') T) then NONE else SOME thm'
  end
in
  List.map process_thm (CONJUNCTS combined_thm)
end

(* Make the combined goal of the obligations [tl] the current goal of
   the interactive proof manager. Presumably meant for developing the
   tactic later passed to prove_list. *)
fun set_goal_list tl = let
  val thm = encode_term_opt_list tl
in
  proofManagerLib.set_goal ([], rhs (concl thm))
end


(* Prove all obligations [tl] at once with the tactic [tac]. The
   result has one entry per element of [tl]: NONE for elements that
   were NONE (or the term T), SOME thm otherwise. *)
fun prove_list (tl, tac) = let
  val thm = encode_term_opt_list tl
  val thm2 = prove (rhs (concl thm), tac)
  val thm3 = EQ_MP (GSYM thm) thm2
in
  decode_thm_opt_list thm3
end



(***************************************************)
(* Constructors                                    *)
(***************************************************)

(* A constructor is a combination of a term with
   a list of names for all its arguments *)
datatype constructor = CONSTR of term * (string list)

fun mk_constructor t args = CONSTR (t, args)

(* A constructor without argument names is a constant. *)
fun constructor_is_const (CONSTR (_, sl)) = null sl

(* [mk_constructor_term vs cr] applies the constructor term of [cr] to
   variables named after its argument names. These variables are kept
   distinct from [vs] and from each other. Abstractions used as
   constructors are beta-reduced. Returns the applied term together
   with the list of argument variables. Fails if more argument names
   are given than the constructor's type has arguments. *)
fun mk_constructor_term vs (CONSTR (c, args)) = let
  val (arg_tys, _) = strip_fun (type_of c)
  val _ = if (length arg_tys < length args) then
     failwith "mk_constructor_term" "too many argument names given" else ()

  val typed_args = zip args (List.take (arg_tys, length args))
  val arg_vars = List.map mk_var typed_args
  val arg_vars' = variants vs arg_vars
  val t = list_mk_comb_subst (c, arg_vars')
in
  (t, arg_vars')
end

(* [match_constructor cr t] checks whether [t] is the constant of [cr]
   applied to exactly as many arguments as [cr] has argument names. If
   so it returns SOME (constant, [(name, argument), ...]), otherwise
   NONE. A partial application therefore yields NONE rather than an
   exception from zip (assumes strip_comb_bounded n strips at most n
   arguments). *)
fun match_constructor (CONSTR (cr, args)) t = let
  val (t', args') = strip_comb_bounded (List.length args) t
in
  if same_const t' cr andalso List.length args' = List.length args then
    SOME (t', zip args args')
  else NONE
end


(* Multiple constructors for a single type are usually
   grouped. These can be exhaustive or not. *)
type constructorList = {
  cl_type : hol_type,
  cl_constructors : constructor list,
  cl_is_exhaustive : bool
}

(* Build a constructorList. Fails if no constructors are given or if
   the applied constructor terms do not all have the same type; that
   common type becomes cl_type. *)
fun mk_constructorList is_exhaustive constrs = let
  val ts = List.map (fst o (mk_constructor_term [])) constrs
  val _ = if null ts then failwith "make_constructorList" "no constructors given" else ()
  val ty = type_of (hd ts)
  val _ = if (Lib.all (fn t => type_of t = ty) ts) then () else
          failwith "make_constructorList" "types of constructors don't match"
in
  { cl_type = ty,
    cl_constructors = constrs,
    cl_is_exhaustive = is_exhaustive }:constructorList
end

(* Same as mk_constructorList, taking (term, argument names) pairs. *)
fun make_constructorList is_exhaustive constrs =
  mk_constructorList is_exhaustive (List.map
    (uncurry mk_constructor) constrs)

(***************************************************)
(* Constructor Families                            *)
(***************************************************)

(* Constructor families are lists of constructors with
   a case-split constant + extra theorems.
*)

type constructorFamily = {
  constructors  : constructorList,
  case_const    : term,
  one_one_thm   : thm option,
  distinct_thm  : thm option,
  case_split_thm: thm,
  case_cong_thm : thm,
  nchotomy_thm  : thm option
}

(* The one-one and distinctness theorems of a family as a single
   theorem (TRUTH if there are neither). *)
fun constructorFamily_get_rewrites (cf : constructorFamily) =
  case (#one_one_thm cf, #distinct_thm cf) of
      (NONE, NONE) => TRUTH
    | (SOME thm1, NONE) => thm1
    | (NONE, SOME thm2) => thm2
    | (SOME thm1, SOME thm2) => CONJ thm1 thm2

(* The family's rewrites together with its case congruence theorem. *)
fun constructorFamily_get_ssfrag (cf : constructorFamily) =
  simpLib.merge_ss [simpLib.rewrites [constructorFamily_get_rewrites cf],
    cong_ss [#case_cong_thm cf]]

(* Returns (is_exhaustive, [(constructor term, argument names), ...]). *)
fun constructorFamily_get_constructors (cf : constructorFamily) = let
  val cl = #constructors cf
  val ts = List.map (fn (CONSTR (a, b)) => (a, b)) (#cl_constructors cl)
in
  (#cl_is_exhaustive cl, ts)
end

fun constructorFamily_get_case_split (cf: constructorFamily) =
  (#case_split_thm cf)

fun constructorFamily_get_case_cong (cf: constructorFamily) =
  (#case_cong_thm cf)

fun constructorFamily_get_nchotomy_thm_opt (cf: constructorFamily) =
  (#nchotomy_thm cf)

(* Test datatype
val _ = Datatype `test_ty =
    A
  | B 'b
  | C bool 'a bool
  | D num bool`

val SOME constrL = constructorList_of_typebase ``:('a, 'b) test_ty``
val case_const = TypeBase.case_const_of ``:('a, 'b) test_ty``

val constrL = make_constructorList false [(``{}:'a set``, []), (``\x:'a. {x}``, ["x"])]

val set_CASE_def = zDefine `
  set_CASE s c_emp c_sing c_else =
    (if s = {} then c_emp else (
     if (FINITE s /\ (CARD s = 1)) then c_sing (CHOICE s) else
     c_else))`

val case_const = ``set_CASE``
*)


(* Statement of the one-one theorem: for every constructor c with at
   least one argument, !vs vs'. (c vs = c vs') = (v1 = v1' /\ ...).
   NONE if all constructors are constants. *)
fun mk_one_one_thm_term_opt (constrL : constructorList) = let
  fun mk_one_one_single cr = let
    val (l, vl) = mk_constructor_term [] cr
    val (r, vr) = mk_constructor_term vl cr
    val lr = mk_eq (l, r)
    val eqs = list_mk_conj (List.map mk_eq (zip vl vr))
    val b = mk_eq (lr, eqs)
  in
    list_mk_forall (vl @ vr, b)
  end

  val constrs = filter (not o constructor_is_const) (#cl_constructors constrL)
  val eqs = map mk_one_one_single constrs
in
  if (null eqs) then NONE else SOME (list_mk_conj eqs)
end


(* Statement of the distinctness theorem: !vs vs'. ~(c1 vs = c2 vs')
   for every ordered pair of different constructors (both orders are
   included). NONE if there is only one constructor. *)
fun mk_distinct_thm_term_opt (constrL : constructorList) = let
  val constrs = #cl_constructors constrL
  val all_pairs = flatten (List.map (fn x =>
     List.map (fn y => (x, y)) constrs) constrs)
  val dist_pairs = List.filter (fn (CONSTR (c1, _), CONSTR (c2, _)) =>
     not (aconv c1 c2)) all_pairs
  fun mk_distinct_single (cr1, cr2) = let
    val (l, vl) = mk_constructor_term [] cr1
    val (r, vr) = mk_constructor_term vl cr2
    val lr = mk_neg (mk_eq (l, r))
  in
    list_mk_forall (vl @ vr, lr)
  end

  val eqs = map mk_distinct_single dist_pairs
in
  if (null eqs) then NONE else SOME (list_mk_conj eqs)
end


(* Statement of the case-split theorem
     !ff x. ff x = case_const x (\vs1. ff (c1 vs1)) ... (\vsn. ff (cn vsn))
   with an additional final argument (\x. ff x) if the family is not
   exhaustive. *)
fun mk_case_expand_thm_term case_const (constrL : constructorList) = let
  val (arg_tys, res_ty) = strip_fun (type_of case_const)
  val split_arg = mk_var ("x", hd arg_tys)
  val split_fun = mk_var ("ff", hd arg_tys --> res_ty)

  fun mk_arg cr = let
    val (b, vs) = mk_constructor_term [split_fun,split_arg] cr
    val b' = mk_comb (split_fun, b)
  in
    list_mk_abs (vs, b')
  end

  val args = List.map mk_arg (#cl_constructors constrL)
  val args = if (#cl_is_exhaustive constrL) then args else
     args@[(mk_abs (split_arg, mk_comb(split_fun, split_arg)))]

  val r = list_mk_comb (case_const, split_arg::args)
  val l = mk_comb (split_fun, split_arg)
in
  list_mk_forall ([split_fun, split_arg], mk_eq (l, r))
end


(* Statement of the congruence theorem of [case_const]. Two copies of
   its arguments are created: args_l (a variant of x followed by f1, ...)
   and args_r (x followed by fresh variants of f1, ...). The premises are
     - the two scrutinees are equal,
     - for every constructor c: !vs. (x = c vs) ==> (fl vs = fr vs),
     - for the extra argument of a non-exhaustive family: the two
       fallback functions agree under the conjunction of the negated
       constructor equations !vs. ~(x = c vs).
   NOTE(review): the fallback's bound variable is named x and only kept
   distinct from the two fallback functions, so it coincides with (and
   binds) the scrutinee x of the negated equations - confirm intended. *)
fun mk_case_const_cong_thm_term case_const (constrL : constructorList) = let
  val (arg_tys, _) = strip_fun (type_of case_const)

  val (args_l, args_r) = let
    fun mk_args avoid = let
      fun mk_arg (a_ty, (i, avoid, vs)) =
      let
        val v = variant avoid (mk_var ("f"^(int_to_string i), a_ty))
      in
        (i+1, v::avoid, v::vs)
      end
      val (_, _, vs_rev) = foldl mk_arg (1, avoid, []) (tl arg_tys)
    in
      List.rev vs_rev
    end

    val r0 = mk_var ("x", hd arg_tys)
    val l0 = variant [r0] r0
    val args_l = mk_args [r0, l0]
    val args_r = mk_args (r0::l0::args_l)
  in
    (l0::args_l, r0::args_r)
  end

  val cong_0 = mk_eq (hd args_l, hd args_r)
  val base = mk_eq (
    list_mk_comb (case_const, args_l),
    list_mk_comb (case_const, args_r))

  val congs_main = let
    (* variables named [vns] for the successive argument types of the
       function type [a_ty], distinct from [avoid] *)
    fun mk_arg_vars acc avoid (a_ty, vns) = case vns of
        [] => List.rev acc
      | (vn::vns') => let
          val (_, atys) = dest_type a_ty
          val v = variant avoid (mk_var (vn, hd atys))
        in
          mk_arg_vars (v::acc) (v::avoid) (el 2 atys, vns')
        end

    fun process_all acc neg_pres crs als ars =
      case (crs, als, ars) of
         ([], [], []) => List.rev acc
       | ([], [al], [ar]) => let
            (* fallback case of a non-exhaustive family *)
            val arg_ts = mk_arg_vars [] [al, ar] (type_of al, ["x"])
            val eq_t = mk_eq (list_mk_comb (al, arg_ts),
                              list_mk_comb (ar, arg_ts))

            val pre_t = list_mk_conj neg_pres
            val t_full = list_mk_forall (arg_ts, mk_imp (pre_t, eq_t))
          in
            List.rev (t_full :: acc)
          end
       | ((CONSTR (c, vns))::crs', al::als', ar::ars') => let
            val arg_ts = mk_arg_vars [] [al, ar] (type_of al, vns)
            val eq_t = mk_eq (list_mk_comb (al, arg_ts),
                              list_mk_comb (ar, arg_ts))

            val pre_t = mk_eq (hd args_r, list_mk_comb (c, arg_ts))
            val t_full = list_mk_forall (arg_ts, mk_imp (pre_t, eq_t))
            val t_exp = list_mk_forall (arg_ts, mk_neg pre_t)
          in
            process_all (t_full::acc) (t_exp::neg_pres) crs' als' ars'
          end
       | _ => failwith "" "Something is wrong with the constructors/case constant. Wrong arity somewhere?"
  in
    process_all [] [] (#cl_constructors constrL) (tl args_l) (tl args_r)
  end

in
  list_mk_forall (args_l @ args_r,
    list_mk_imp (cong_0 :: congs_main, base))
end


(* Statement of the nchotomy theorem !x. (?vs1. x = c1 vs1) \/ ...,
   or NONE if the family is not exhaustive. *)
fun mk_nchotomy_thm_term_opt (constrL : constructorList) =
  if not (#cl_is_exhaustive constrL) then NONE else let
    val v = mk_var ("x", #cl_type constrL)

    fun mk_disj cr = let
      val (b, vs) = mk_constructor_term [v] cr
    in
      list_mk_exists (vs, mk_eq (v, b))
    end

    val eqs_t = list_mk_disj (List.map mk_disj (#cl_constructors constrL))
  in
    SOME (mk_forall (v, eqs_t))
  end;


(* All proof obligations of a family, in the order
   [one_one, distinct, case_split, case_cong, nchotomy]. The
   positions are relied upon by mk_constructorFamily. *)
fun mk_constructorFamily_terms case_const constrL = let
  val t1 = mk_one_one_thm_term_opt constrL
  val t2 = mk_distinct_thm_term_opt constrL
  val t3 = SOME (mk_case_expand_thm_term case_const constrL)
  val t4 = SOME (mk_case_const_cong_thm_term case_const constrL)
  val t5 = mk_nchotomy_thm_term_opt constrL
in
  [t1, t2, t3, t4, t5]
end

(* The combined goal that the tactic given to mk_constructorFamily
   has to prove. *)
fun get_constructorFamily_proofObligations (constrL, case_const) = let
  val ts = mk_constructorFamily_terms case_const constrL
  val thm = encode_term_opt_list ts
in
  rhs (concl thm)
end

(* Set the combined goal of a family as the current interactive goal. *)
fun set_constructorFamily (constrL, case_const) =
  set_goal_list (mk_constructorFamily_terms case_const constrL)

(* Prove all obligations of the family with [tac] and package the
   resulting theorems as a constructorFamily. *)
fun mk_constructorFamily (constrL, case_const, tac) = let
  val thms = prove_list (mk_constructorFamily_terms case_const constrL, tac)
in
  {
    constructors = constrL,
    case_const = case_const,
    one_one_thm = el 1 thms,
    distinct_thm = el 2 thms,
    case_split_thm = valOf (el 3 thms),
    case_cong_thm = valOf (el 4 thms),
    nchotomy_thm = el 5 thms
  }:constructorFamily
end


(***************************************************)
(* Connection to typebase                          *)
(***************************************************)


(* given a type try to extract the constructors of a type
   from typebase. Do not use the default type-base functions
   for this but destruct the nchotomy_thm in order to get
   the default argument names as well. *)
fun constructorList_of_typebase ty =
  if null (TypeBase.constructors_of ty) then NONE else let
    val nchotomy_thm = TypeBase.nchotomy_of ty
    val eqs = strip_disj (snd (dest_forall (concl nchotomy_thm)))

    fun dest_eq eq = let
      val (_, b) = strip_exists eq
      val (c, args) = strip_comb (rhs b)
      val args' = List.map (fst o dest_var) args
    in
      CONSTR (c, args')
    end

    val constrs = List.map dest_eq eqs
in
  SOME ({ cl_type = ty,
          cl_constructors = constrs,
          cl_is_exhaustive = true }:constructorList)
end

(* Build (and prove) the constructor family of the datatype [ty] from
   the theorems stored in TypeBase. Fails if TypeBase lists no
   constructors for [ty]. *)
fun constructorFamily_of_typebase ty = let
  val crL = valOf (constructorList_of_typebase ty)
    handle Option => failwith "constructorFamily_of_typebase" "not a datatype"
  val case_split_tm = TypeBase.case_const_of ty
  val thm_distinct = TypeBase.distinct_of ty
  val thm_one_one = TypeBase.one_one_of ty handle HOL_ERR _ => TRUTH
  val thm_case = TypeBase.case_def_of ty
  val thm_case_cong = TypeBase.case_cong_of ty

  (* set_constructorFamily (crL, case_split_tm) *)
in
  mk_constructorFamily (crL, case_split_tm,
    SIMP_TAC std_ss [thm_distinct, thm_one_one, thm_case_cong] THEN
    REPEAT STRIP_TAC THEN (
      Cases_on `x` THEN
      SIMP_TAC std_ss [thm_distinct, thm_one_one, thm_case]
    )
  )
end


(***************************************************)
(* Collections of constructorFamilies +            *)
(* extra matching info                             *)
(***************************************************)

(* Datatype for representing how well a constructorFamily or
   a hand-written function matches a column. *)
type matchcol_stats = {
  colstat_missed_rows : int,
  (* how many rows of the col are not constructor applications
     or bound vars? *)

  colstat_cases : int,
  (* how many cases are covered ? *)

  colstat_missed_constr : int
  (* how many constructors of the family do not appear in the column *)
}

(* Ordering used to pick the best family: fewer missed rows first, then
   fewer cases, then (on ties) a larger colstat_missed_constr.
   NOTE(review): the last criterion prefers the family with MORE
   constructors absent from the column - confirm this is intended. *)
fun matchcol_stats_compare
  (st1 : matchcol_stats)
  (st2 : matchcol_stats) = let
  fun lex_ord (i1, i2) b =
    (i1 < i2) orelse ((i1 = i2) andalso b)
in
  lex_ord (#colstat_missed_rows st1, #colstat_missed_rows st2) (
    lex_ord (#colstat_cases st1, #colstat_cases st2) (
      op> (#colstat_missed_constr st1, #colstat_missed_constr st2)
    )
  )
end


type pmatch_compile_fun = (term list * term) list -> (thm * int * simpLib.ssfrag) option

type pmatch_nchotomy_fun = (term list * term) list -> (thm * int) option

(* cache of families built from TypeBase, keyed by base_ty *)
val typeConstrFamsDB = ref (TypeNet.empty : constructorFamily TypeNet.typenet)

type pmatch_compile_db = {
  pcdb_compile_funs  : pmatch_compile_fun list,
  pcdb_nchotomy_funs : pmatch_nchotomy_fun list,
  pcdb_constrFams    : (constructorFamily list) TypeNet.typenet,
  pcdb_ss            : simpLib.ssfrag
}

val empty : pmatch_compile_db = {
  pcdb_compile_funs = [],
  pcdb_nchotomy_funs = [],
  pcdb_constrFams = TypeNet.empty,
  pcdb_ss = (simpLib.rewrites [])
}

val thePmatchCompileDB = ref empty

(* Look up (building and caching it on first use) the TypeBase family
   for the type operator of [ty]. Returns SOME (base type, family), or
   NONE if any HOL_ERR is raised, e.g. for type variables or types
   that are not datatypes. Failures are not cached. *)
fun lookup_typeBase_constructorFamily ty = let
  val b_ty = base_ty ty
in
  SOME (b_ty, TypeNet.find (!typeConstrFamsDB, b_ty)) handle
    NotFound => let
      val cf = constructorFamily_of_typebase b_ty
      val _ = typeConstrFamsDB := TypeNet.insert (!typeConstrFamsDB, b_ty, cf)
    in
      SOME (b_ty, cf)
    end
end handle HOL_ERR _ => NONE


(* Compute the matchcol_stats of family [cf] for the column [col]
   (a list of (bound variables, pattern) rows). *)
fun measure_constructorFamily (cf : constructorFamily) col = let
  fun list_count p col =
    foldl (fn (r, c) => if (p r) then c+1 else c) 0 col

  (* extract the constructors of the family *)
  val crs = List.map (fn (CONSTR (c, _)) => c) (
    #cl_constructors (#constructors cf))

  fun row_is_missed (vs, p) =
    if (is_var p andalso mem p vs) then
      (* bound variables are fine *)
      false
    else let
      val (f, _) = strip_comb p
    in
      not (List.exists (same_const f) crs)
    end handle HOL_ERR _ => true

  fun constr_is_missed c =
    not (List.exists (fn (_, p) => let
       val (f, _) = strip_comb p
     in
       same_const f c
     end handle HOL_ERR _ => false) col)

  val cases_no = List.length (#cl_constructors (#constructors cf))
  (* a non-exhaustive family has an extra fallback case *)
  val cases_no' = if (#cl_is_exhaustive (#constructors cf)) then cases_no else (cases_no+1)
in
  {
    colstat_missed_rows = list_count row_is_missed col,
    colstat_missed_constr = list_count constr_is_missed crs,
    colstat_cases = cases_no'
  }
end

(* All (key type, family) pairs applicable to [ty]: those registered
   in [db] whose key matches [ty], followed by the TypeBase family (if
   any). Families whose case constant or constructors are constants
   with a name ending in <-old are dropped. *)
fun lookup_constructorFamilies_for_type (db : pmatch_compile_db) ty = let
  val cts_fams = let
    val cts_fams = TypeNet.match (#pcdb_constrFams db, ty)
    val cts_fams' = Lib.flatten (List.map (fn (ty, l) =>
         List.map (fn cf => (ty, cf)) l) cts_fams)
    val cty_l = case lookup_typeBase_constructorFamily ty of
        NONE => []
      | SOME (ty, cf) => [(ty, cf)]
  in cts_fams' @ cty_l end

  fun is_old_fam (_, cf) = let
    val (_, cl) = constructorFamily_get_constructors cf
    fun is_old_const c = let
      val (cn, _) = dest_const c
    in
      String.isSuffix "<-old" cn
    end handle HOL_ERR _ => false
  in
    (List.exists (fn (c, _) => is_old_const c) cl) orelse
    (is_old_const (#case_const cf))
  end
in
  List.filter (fn cf => not (is_old_fam cf)) cts_fams
end

(* Find the best family for column [col] according to
   matchcol_stats_compare among those that cover every row
   (colstat_missed_rows = 0). With [force_exh] only families with an
   nchotomy theorem are considered. Returns SOME ((key type, family),
   stats) or NONE. Fails if [col] is empty or consists only of bound
   variables. *)
fun lookup_constructorFamily force_exh (db : pmatch_compile_db) col = let
  val _ = if (List.null col) then
            failwith "lookup_constructorFamily" "null col" else ()

  val _ = if List.all (fn (vs, c) => is_var c andalso Lib.mem c vs) col then
            failwith "lookup_constructorFamily" "var col"
          else ()

  val ty = type_of (snd (hd col))
  val cts_fams = lookup_constructorFamilies_for_type db ty
  val cts_fams' = if not force_exh then
                     cts_fams
                  else
                     List.filter (fn (_, cf) => isSome (#nchotomy_thm cf)) cts_fams

  val weighted_fams = List.map (fn (ty, cf) =>
     ((ty, cf), measure_constructorFamily cf col)) cts_fams'

  val weighted_fams' = filter (fn (_, w) => (#colstat_missed_rows w = 0)) weighted_fams

  val weighted_fams_sorted = sort (fn (_, w1) => fn (_, w2) =>
     matchcol_stats_compare w1 w2) weighted_fams'
in
  case weighted_fams_sorted of
      [] => NONE
    | wcf::_ => SOME wcf
end;


(* Choose how to split column [col]: the first registered compile
   function that succeeds, or the best constructor family, whichever
   has fewer cases (the family wins ties). Returns the chosen (case
   split theorem, number of cases, ssfrag) together with the family
   used, if any. *)
fun pmatch_compile_db_compile_aux (db : pmatch_compile_db) col = (
  if (List.null col) then failwith "pmatch_compile_db_compile" "col 0" else let
    val fun_res = get_first (fn f => f col handle HOL_ERR _ => NONE) (#pcdb_compile_funs db)
    val cf_res = lookup_constructorFamily false db col

    (* Instantiate the family's case-split theorem to the type of the
       column. The theorem is stated in the type variables of the
       family's own cl_type. A TypeNet key, by contrast, was renamed by
       normalise_ty on registration, so matching against the key would
       instantiate the wrong type variables. *)
    fun process_cf_res (_, cf : constructorFamily) w = let
      val ty_s = match_type (#cl_type (#constructors cf)) (type_of (snd (hd col)))
      val thm' = INST_TYPE ty_s (constructorFamily_get_case_split cf)
    in
      (thm',#colstat_cases w, merge_ss [(#pcdb_ss db), simpLib.rewrites [
         (constructorFamily_get_rewrites cf)]])
    end
  in case (fun_res, cf_res) of
      (NONE, NONE) => (NONE, NONE)
    | (NONE, SOME (tycf, w)) => (SOME (process_cf_res tycf w), SOME tycf)
    | (SOME (thm, c_no, ss), NONE) => (SOME (thm, c_no, ss), NONE)
    | (SOME (thm, c_no, ss), SOME (tycf, w)) => if (c_no < #colstat_cases w) then
        (SOME (thm, c_no, ss), NONE)
      else (SOME (process_cf_res tycf w), SOME tycf)
  end
);

fun pmatch_compile_db_compile db col = (
  fst (pmatch_compile_db_compile_aux db col))

fun pmatch_compile_db_compile_cf db col = (
  case (snd (pmatch_compile_db_compile_aux db col)) of
     NONE => NONE
   | SOME (_, cf) => SOME cf
);

(* Find an nchotomy theorem for column [col], either from the
   registered nchotomy functions or from an exhaustive constructor
   family. *)
fun pmatch_compile_db_compile_nchotomy (db : pmatch_compile_db) col = (
  if (List.null col) then failwith "pmatch_compile_db_compile_nchotomy" "col 0" else let
    val fun_res = get_first (fn f => f col handle HOL_ERR _ => NONE) (#pcdb_nchotomy_funs db)
    val cf_res = lookup_constructorFamily true db col

    fun process_cf_res (_, cf : constructorFamily) = #nchotomy_thm cf

  in case (fun_res, cf_res) of
      (NONE, NONE) => NONE
    | (NONE, SOME (tycf, _)) => process_cf_res tycf
    | (SOME (thm, _), NONE) => SOME thm
    (* NOTE(review): lookup_constructorFamily only returns families with
       colstat_missed_rows = 0, so this test is always false and the
       family is always preferred; unlike pmatch_compile_db_compile_aux
       the case counts are not compared. Confirm which is intended. *)
    | (SOME (thm, _), SOME (tycf, w)) => if (0 < #colstat_missed_rows w) then
        (SOME thm) else (process_cf_res tycf)
  end
);

(* Try to view [t] as an application of one of the constructors of the
   families applicable to its type; see match_constructor. *)
fun pmatch_compile_db_dest_constr_term (db : pmatch_compile_db) t = let
  val cfs = lookup_constructorFamilies_for_type db (type_of t)
  val cstrs = flatten (List.map (#cl_constructors o #constructors o snd) cfs)
in
  first_opt (fn _ => fn cr => match_constructor cr t) cstrs
end


(***************************************************)
(* updating dbs                                    *)
(***************************************************)

fun pmatch_compile_db_add_ssfrag (db : pmatch_compile_db) ss = {
  pcdb_compile_funs = #pcdb_compile_funs db,
  pcdb_nchotomy_funs = #pcdb_nchotomy_funs db,
  pcdb_constrFams = #pcdb_constrFams db,
  pcdb_ss = (simpLib.merge_ss [ss, #pcdb_ss db])
} : pmatch_compile_db

fun pmatch_compile_db_add_congs db thms =
  pmatch_compile_db_add_ssfrag db (cong_ss thms);

fun pmatch_compile_db_register_ssfrag ss =
  thePmatchCompileDB := pmatch_compile_db_add_ssfrag (!thePmatchCompileDB) ss;

fun pmatch_compile_db_register_congs thms =
  pmatch_compile_db_register_ssfrag (cong_ss thms)

(* New compile functions are tried before the existing ones. *)
fun pmatch_compile_db_add_compile_fun (db : pmatch_compile_db) cf = {
  pcdb_compile_funs = cf::(#pcdb_compile_funs db),
  pcdb_nchotomy_funs = #pcdb_nchotomy_funs db,
  pcdb_constrFams = #pcdb_constrFams db,
  pcdb_ss = #pcdb_ss db
} : pmatch_compile_db

fun pmatch_compile_db_register_compile_fun cf =
  thePmatchCompileDB := pmatch_compile_db_add_compile_fun (!thePmatchCompileDB) cf

(* New nchotomy functions are tried before the existing ones. *)
fun pmatch_compile_db_add_nchotomy_fun (db : pmatch_compile_db) cf = {
  pcdb_compile_funs = #pcdb_compile_funs db,
  pcdb_nchotomy_funs = cf::(#pcdb_nchotomy_funs db),
  pcdb_constrFams = #pcdb_constrFams db,
  pcdb_ss = #pcdb_ss db
} : pmatch_compile_db

fun pmatch_compile_db_register_nchotomy_fun f =
  thePmatchCompileDB := pmatch_compile_db_add_nchotomy_fun (!thePmatchCompileDB) f

(* Add family [cf] under the normalised type of its constructors and
   merge its rewrites and congruence theorem into the db's ssfrag. *)
fun pmatch_compile_db_add_constrFam (db : pmatch_compile_db) (cf : constructorFamily) = {
  pcdb_compile_funs = #pcdb_compile_funs db,
  pcdb_nchotomy_funs = #pcdb_nchotomy_funs db,
  pcdb_constrFams = let
    val ty = normalise_ty (#cl_type (#constructors cf))
    val net = #pcdb_constrFams db
    val cfs = TypeNet.find (net, ty) handle NotFound => []
  in
    TypeNet.insert (net, ty, cf::cfs)
  end,
  pcdb_ss = merge_ss [constructorFamily_get_ssfrag cf, (#pcdb_ss db)]
} : pmatch_compile_db

fun pmatch_compile_db_register_constrFam cf =
  thePmatchCompileDB := pmatch_compile_db_add_constrFam (!thePmatchCompileDB) cf

(* Forget all families registered for (the normalised form of) [ty] by
   mapping it to the empty list. The ssfrag is left untouched. *)
fun pmatch_compile_db_remove_type (db : pmatch_compile_db) ty = {
  pcdb_compile_funs = #pcdb_compile_funs db,
  pcdb_nchotomy_funs = #pcdb_nchotomy_funs db,
  pcdb_constrFams = TypeNet.insert (#pcdb_constrFams db, normalise_ty ty, []),
  pcdb_ss = #pcdb_ss db
} : pmatch_compile_db

fun pmatch_compile_db_clear_type ty =
  thePmatchCompileDB := pmatch_compile_db_remove_type (!thePmatchCompileDB) ty



(***************************************************)
(* compilation funs                                *)
(***************************************************)

val COND_CONG_APPLY = prove (``(if (x:'a) = c then (ff x):'b else ff x) =
  (if x = c then ff c else ff x)``,
Cases_on `x = c` THEN ASM_REWRITE_TAC[])


(* Compile function for columns of literals. A row counts as a literal
   if its pattern mentions none of the row's bound variables, and as a
   wildcard if it is a bound variable; any other row makes it fail.
   Produces a case-split theorem of the form
     !ff x. ff x = literal_case (\x. if x = l1 then ff l1 else ... else ff x) x
   (literals in order of first occurrence), the number of cases
   (literals + 1) and an ssfrag with a matching congruence theorem. *)
fun literals_compile_fun (col:(term list * term) list) = let

  fun extract_literal ((vs, c), (tl, ts)) = let
    val vars = FVL [c] empty_tmset
    val is_lit = not (List.exists (fn v => HOLset.member (vars, v)) vs)
  in
    if is_lit then (
      if (HOLset.member(ts,c)) then
        (tl, ts)
      else
        ((c::tl), HOLset.add(ts,c))
    ) else
      (if is_var c then (tl, ts) else failwith "" "extract_literal")
  end

  val (lits_rev, _) = List.foldl extract_literal ([], empty_tmset) col
  val _ = if (List.null lits_rev) then (failwith "" "no lits") else ()
  val lits = List.rev lits_rev
  val cases_no = List.length lits + 1

  val rty = gen_tyvar ()
  val lit_ty = type_of (snd (List.hd col))
  (* Literals may contain free variables; keep the split variable
     distinct from them so that it cannot capture them. *)
  val lit_fvs = HOLset.listItems (FVL lits empty_tmset)
  val split_arg = variant lit_fvs (mk_var ("x", lit_ty))
  val split_fun = mk_var ("ff", lit_ty --> rty)
  val arg = mk_comb (split_fun, split_arg)

  (* |- ff x = if x = l1 then ff l1 else ... else ff x *)
  fun mk_expand_thm lits = case lits of
      [] => REFL arg
    | (l :: lits') => let
        val b = mk_eq (split_arg, l)
        val thm0 = GSYM (ISPEC arg (SPEC b COND_ID))
        val thm1 = CONV_RULE (RHS_CONV (REWR_CONV COND_CONG_APPLY)) thm0
        val thm2a = mk_expand_thm lits'
      in
        CONV_RULE (RHS_CONV (RAND_CONV (K thm2a))) thm1
      end

  val thm0 = mk_expand_thm lits
  val thm1 = let
    val thm0_rhs = rhs (concl thm0)
    val thm1a = GSYM (ISPECL [mk_abs(split_arg, thm0_rhs), split_arg] literal_case_THM)
  in
    CONV_RULE (LHS_CONV BETA_CONV) thm1a
  end
  val thm2 = TRANS thm0 thm1
  val thm3 = GEN split_fun (GEN split_arg thm2)


  (* Congruence theorem for the if-then-else chain: every branch
     (ff l, and the final ff x) is replaced by fresh variables va / vb
     on the two sides, with a premise va = vb under the condition that
     selects that branch. *)
  val cong_thm = let
    fun mk_lits_preconds (sua, sub, c_tms) pre lits =
      case lits of
          [] => let
            val negs = map (fn pl => mk_neg (mk_eq (split_arg, pl))) pre
            val a = list_mk_conj negs
            val sf = mk_comb (split_fun, split_arg)
            val va = genvar (type_of sf)
            val vb = genvar (type_of sf)
            val c = mk_eq (va, vb)

            val new_p = mk_imp (a, c)

          in ((sf |-> va)::sua, (sf |-> vb)::sub, new_p::c_tms) end
        | (l::lits') => let
            val negs = map (fn pl => mk_neg (mk_eq (l, pl))) pre
            val eq = mk_eq (split_arg, l)
            val a = list_mk_conj (eq::negs)

            val sf = mk_comb (split_fun, l)
            val va = genvar (type_of sf)
            val vb = genvar (type_of sf)
            val c = mk_eq (va, vb)

            val new_p = mk_imp (a, c)
          in
            (mk_lits_preconds ((sf |-> va)::sua, (sf |-> vb)::sub, new_p::c_tms) (l::pre) lits')
          end

    val (sua, sub, c_tms) = mk_lits_preconds ([], [], []) [] lits
    val tt00 = rhs (concl thm0)

    val tt0a = subst sua tt00
    val tt0b = subst sub tt00
    val tt0 = mk_eq (tt0a, tt0b)
    val tt1 = list_mk_imp (List.rev c_tms, tt0)
  in
    prove(tt1, metisLib.METIS_TAC[])
  end
in
  SOME (thm3, cases_no, cong_ss [cong_thm])
end

val _ = pmatch_compile_db_register_compile_fun literals_compile_fun


(***************************************************)
(* nchotomy funs                                   *)
(***************************************************)

(* Nchotomy function for columns of literals (same row classification
   as literals_compile_fun). Proves
     !x. x = l1 \/ ... \/ x = ln \/ ?y. x = y /\ ~(y = l1) /\ ... /\ ~(y = ln)
   and returns it with the number of cases (literals + 1). Returns NONE
   on any HOL_ERR. *)
fun literals_nchotomy_fun (col:(term list * term) list) = let
  fun extract_literal ((vs, c), ts) = let
    val vars = FVL [c] empty_tmset
    val is_lit = not (List.exists (fn v => HOLset.member (vars, v)) vs)
  in
    if is_lit then HOLset.add(ts,c) else
    (if is_var c then ts else failwith "" "extract_literal")
  end

  val ts = List.foldl extract_literal empty_tmset col
  val lits = HOLset.listItems ts
  val cases_no = List.length lits + 1
  val _ = if (List.null lits) then (failwith "" "no lits") else ()

  val lit_ty = type_of (snd (List.hd col))
  (* keep x and y distinct from free variables of the literals *)
  val lit_fvs = HOLset.listItems (FVL lits empty_tmset)
  val split_arg = variant lit_fvs (mk_var ("x", lit_ty))
  val wc_arg = variant (split_arg :: lit_fvs) (mk_var ("y", lit_ty))

  val lit_tms = List.map (fn l => mk_eq (split_arg, l)) lits
  val wc_tm = let
    val not_tms =
      List.map (fn l => mk_neg (mk_eq (wc_arg, l))) lits
    val eq_tm = mk_eq (split_arg, wc_arg)
    val b_tm = mk_conj (eq_tm, list_mk_conj not_tms)
  in
    mk_exists (wc_arg, b_tm)
  end

  val nchot_tm = list_mk_disj (lit_tms @ [wc_tm])
  val nchot_thm = prove(nchot_tm,
    CONV_TAC (DEPTH_CONV Unwind.UNWIND_EXISTS_CONV) THEN
    EVERY (List.map (fn t =>
      (BOOL_CASES_TAC t THEN REWRITE_TAC[])) lit_tms))
in
  SOME (GEN split_arg nchot_thm, cases_no)
end handle HOL_ERR _ => NONE

val _ = pmatch_compile_db_register_nchotomy_fun literals_nchotomy_fun



end