structure patternMatchesLib :> patternMatchesLib =
struct

open HolKernel Parse boolLib Drule BasicProvers
open simpLib numLib metisLib
open patternMatchesTheory
open listTheory
open quantHeuristicsLib
open DatatypeSimps
open patternMatchesSyntax
open Traverse
open constrFamiliesLib
open unwindLib
open oneSyntax

structure Parse =
struct
  open Parse
  val (Type,Term) =
      parse_from_grammars patternMatchesTheory.patternMatches_grammars
end
open Parse

val list_ss = numLib.arith_ss ++ listSimps.LIST_ss

(***********************************************)
(* Auxiliary stuff                             *)
(***********************************************)

(* [make_gen_conv_ss c name ssl] packages the "generic" conversion [c]
   as a simpset fragment that holds a single reducer called [name].

   Whenever the simplifier applies the reducer to a term [tm], the
   reducer evaluates [c (ssl, SOME cb) tm], where [cb] is the conversion
   the simplifier hands to the reducer for the current context stack.
   [c] can therefore call back into the simplifier. The reducer fails
   unless [c] really changes [tm] (QCHANGED_CONV). The reducer keeps no
   context of its own: context theorems are ignored. *)
fun make_gen_conv_ss c name ssl = let
  exception genconv_reducer_exn

  fun apply_gen_conv {solver, conv, context, stack, relation} tm =
    QCHANGED_CONV (c (ssl, SOME (conv stack))) tm

  val reducer = REDUCER {name = SOME name,
                         addcontext = fn (ctxt, _) => ctxt,
                         apply = apply_gen_conv,
                         initial = genconv_reducer_exn}
in
  simpLib.dproc_ss reducer
end;

(* Often in the following, a single row needs extracting. Given a
   list l and an index n, we want to extract the element at position n
   together with the list of elements before it and the list of
   elements after it. So, we need an efficient way to compute:

     (List.take (l, n), List.nth (l, n), List.drop (l, n+1))

   For example:

     extract_element [0,1,2,3,4,5] 0 = ([], 0, [1,2,3,4,5])
     extract_element [0,1,2,3,4,5] 1 = ([0], 1, [2,3,4,5])
     extract_element [0,1,2,3,4,5] 3 = ([0,1,2], 3, [4,5])
     extract_element [0,1,2,3,4,5] 5 = ([0,1,2,3,4], 5, [])

   Fails if l has no element at position n. *)
fun extract_element l n =
  case Lib.split_after n l of
      (_, []) => failwith "index too large"
    | (front, x :: back) => (front, x, back)


(* Similarly, we often need to replace an element with
   a list of elements.
We need an efficient way to compute

     (List.take (l, n) @ new_elements @ List.drop (l, n+1),
      List.nth (l, n))

   Fails with "index too small" for a negative n and with
   "index too big" if l has no element at position n.
 *)
fun replace_element l n new =
  if n < 0 then failwith "index too small"
  else let
    fun aux _ (_, []) = failwith "index too big"
      | aux 0 (acc, x::xs) =
          (List.revAppend (acc, new @ xs), x)
      | aux n (acc, x::xs) =
          aux (n-1) (x::acc, xs)
  in
    aux n ([], l)
  end

(* We have a problem with conversions that loop in a fancy way.
   They add some pattern matching on the input variables and
   in the body the original term with renamed variables. The
   following function tries to detect this situation.

   Given a theorem |- l = r, it returns true iff r contains a
   subterm t that
   - equals l up to renaming of free variables (both are closed by
     abstracting over their free variables in left-to-right order
     and then compared with aconv), and
   - has a head that is the same constant as the head of l
     (checked with same_const). *)
fun does_conv_loop thm = let
  val (l, r) = dest_eq (concl thm)
  fun my_mk_abs t = list_mk_abs (free_vars_lr t, t)
  val l' = my_mk_abs l
  val const_check = let
    val (l_c, _) = strip_comb l
  in
    fn t => (same_const (fst (strip_comb t)) l_c)
  end handle HOL_ERR _ => (fn _ => true)
  fun is_similar t = const_check t andalso (aconv l' (my_mk_abs t))
  val i = ((find_term is_similar r; true) handle HOL_ERR _ => false)
in
  i
end


(***********************************************)
(* Simpset to evaluate PMATCH_ROWS             *)
(***********************************************)

(* |- ((FST x = a) /\ (SND x = b)) = (x = (a, b))
   (This theorem used to be proved twice in a row. The first proof
   was immediately shadowed by the second, so only one is kept.) *)
val PAIR_EQ_COLLAPSE = prove (
``(((FST x = (a:'a)) /\ (SND x = (b:'b))) = (x = (a, b)))``,
Cases_on `x` THEN SIMP_TAC std_ss [])

(* is_FST_eq x t holds iff t is an equation whose left-hand side is
   (alpha-equivalent to) FST x. Fails if t is not an equation. *)
fun is_FST_eq x t = let
  val (l, _) = dest_eq t
  val pred = aconv (pairSyntax.mk_fst x)
in
  pred l
end

(* FST_SND_CONJUNCT_COLLAPSE v conj moves the conjunct FST v = a of
   conj to the front, recursively collapses the remaining conjuncts
   about SND v, and finally rewrites FST v = a /\ SND v = b into
   v = (a, b) using PAIR_EQ_COLLAPSE. Raises UNCHANGED on failure. *)
fun FST_SND_CONJUNCT_COLLAPSE v conj = let
  val conj'_thm = markerLib.move_conj_left (is_FST_eq v) conj

  val v' = pairSyntax.mk_snd v

  val thm_coll = (TRY_CONV (RAND_CONV (FST_SND_CONJUNCT_COLLAPSE v')) THENC
                  (REWR_CONV PAIR_EQ_COLLAPSE))
                 (rhs (concl conj'_thm))
in
  TRANS conj'_thm thm_coll
end handle HOL_ERR _ => raise UNCHANGED

(* ELIM_FST_SND_SELECT_CONV rewrites a select term @v. conj whose body
   collapses (via FST_SND_CONJUNCT_COLLAPSE) to v = (a, b) and then
   removes the select with SELECT_REFL. Raises UNCHANGED on failure. *)
fun ELIM_FST_SND_SELECT_CONV t = let
  val (v, conj) = boolSyntax.dest_select t
  val thm0 = FST_SND_CONJUNCT_COLLAPSE v conj

  val thm1 = RAND_CONV (ABS_CONV (K thm0)) t
  val thm2 = CONV_RULE (RHS_CONV (REWR_CONV SELECT_REFL)) thm1
in
  thm2
end handle HOL_ERR _ => raise UNCHANGED


(*
val rc = DEPTH_CONV pairTools.PABS_ELIM_CONV THENC SIMP_CONV list_ss [pairTheory.EXISTS_PROD, pairTheory.FORALL_PROD, PMATCH_ROW_EQ_NONE, PAIR_EQ_COLLAPSE, oneTheory.one]
*)

(* Applies pairTools.PABS_ELIM_CONV to terms of the form UNCURRY f. *)
val pabs_elim_ss =
    simpLib.conv_ss
      {name  = "PABS_ELIM_CONV",
       trace = 2,
       key   = SOME ([],``UNCURRY (f:'a -> 'b -> bool)``),
       conv  = K (K pairTools.PABS_ELIM_CONV)}

(* Applies ELIM_FST_SND_SELECT_CONV to select terms. *)
val elim_fst_snd_select_ss =
    simpLib.conv_ss
      {name  = "ELIM_FST_SND_SELECT_CONV",
       trace = 2,
       key   = SOME ([],``$@ (f:'a -> bool)``),
       conv  = K (K ELIM_FST_SND_SELECT_CONV)}

(* Simplifies select terms with std_ss extended by CONJ_ss. *)
val select_conj_ss =
    simpLib.conv_ss
      {name  = "SELECT_CONJ_SS_CONV",
       trace = 2,
       key   = SOME ([],``$@ (f:'a -> bool)``),
       conv  = K (K (SIMP_CONV (std_ss++boolSimps.CONJ_ss) []))};

(* A basic simpset-fragment with a lot of useful stuff
   to automatically show the validity of preconditions
   as produced by functions in this library. *)
val static_ss = simpLib.merge_ss
  [pabs_elim_ss,
   pairSimps.paired_forall_ss,
   pairSimps.paired_exists_ss,
   pairSimps.gen_beta_ss,
   select_conj_ss,
   elim_fst_snd_select_ss,
   boolSimps.EQUIV_EXTRACT_ss,
   quantHeuristicsLib.SIMPLE_QUANT_INST_ss,
   simpLib.rewrites [
     some_var_bool_T, some_var_bool_F,
     GSYM boolTheory.F_DEF,
     pairTheory.EXISTS_PROD,
     pairTheory.FORALL_PROD,
     PMATCH_ROW_EQ_NONE,
     PMATCH_ROW_COND_def,
     PMATCH_ROW_COND_EX_def,
     PAIR_EQ_COLLAPSE,
     oneTheory.one]];

(* We add the stateful rewrite set (to simplify
   e.g. case-constants or constructors) and a
   custom component as well.
*)
fun rc_ss gl = let
  val base_ss = srw_ss ()
  val extra_frags = simpLib.merge_ss (static_ss :: gl)
in
  (* the fragment named "patternMatchesSimp" is explicitly excluded *)
  simpLib.remove_ssfrags (base_ss ++ extra_frags) ["patternMatchesSimp"]
end

(* finally we add a call-back component. This is an
   external conversion that is used at the end if
   everything else fails. This is used to have nested calls
   of the simplifier. The simplifier executes some conversion that
   uses rc_ss. At the end, we might want to use the external
   simplifier. This is realised with these call-backs.

   callback_CONV cb_opt t fails (NO_CONV) if there is no call-back
   or if t still contains a PMATCH; otherwise it applies the
   call-back to t. *)
fun callback_CONV NONE t = NO_CONV t
  | callback_CONV (SOME cb) t =
      if can (find_term is_PMATCH) t then NO_CONV t else cb t;

(* rc_conv_rws (gl, callback_opt) thms keeps simplifying with
   rc_ss gl plus the extra rewrites thms, each round followed by an
   optional attempt of the call-back, as long as REPEATC can make
   progress. *)
fun rc_conv_rws (gl, callback_opt) thms = let
  val simp_step = SIMP_CONV (rc_ss gl) thms
  val callback_step = TRY_CONV (callback_CONV callback_opt)
in
  REPEATC (simp_step THENC callback_step)
end

(* So, now combine it to get some convenient high-level
   functions. rc_arg is always a pair (gl, callback_opt). *)
fun rc_conv rc_arg = rc_conv_rws rc_arg []

fun rc_tac rc_arg = CONV_TAC (rc_conv rc_arg)

(* rc_elim_precond rc_arg thm expects thm to be an implication
   |- pre ==> c. It proves pre with rc_tac rc_arg and returns |- c. *)
fun rc_elim_precond rc_arg thm = let
  val precond = rand (rator (concl thm))
  val precond_thm = prove_attempt (precond, rc_tac rc_arg)
in
  MP thm precond_thm
end

(* fix_appends expects a theorem of the form
     PMATCH v rows = PMATCH v' rows'

   and a term l of form
     PMATCH v rows0.

   It tries to get the appends in rows and rows' in
   a nice form. To do this, it tries to prove that
   l and the lhs of the theorem are equal.
   Then it tries to simplify appends in rows'
   resulting in rows''.

   It returns a theorem of the form

     l = PMATCH v' rows''.
*)
fun fix_appends rc_arg l thm = let
  val t_eq_thm = prove (mk_eq (l, lhs (concl thm)),
    CONV_TAC (DEPTH_CONV listLib.APPEND_CONV) THEN
    rc_tac rc_arg)

  val thm2 = TRANS t_eq_thm thm

  (* normalise nested appends bottom-up with APPEND_CONV *)
  fun my_append_conv t = let
    val _ = if listSyntax.is_append t then () else raise UNCHANGED
  in
    (BINOP_CONV (TRY_CONV my_append_conv) THENC
     listLib.APPEND_CONV) t
  end

  val thm3 = CONV_RULE (RHS_CONV (RAND_CONV my_append_conv)) thm2
             handle HOL_ERR _ => thm2
                  | UNCHANGED => thm2
in
  thm3
end

(* Apply a conversion to all args of a PMATCH_ROW, i.e. given
   a term of the form ``PMATCH_ROW pat guard rhs i``
   it applies a conversion to ``pat`` ``guard`` and ``rhs``.
   Failures of the conversion on an argument are ignored (TRY_CONV). *)
fun PMATCH_ROW_ARGS_CONV c =
  RATOR_CONV (RAND_CONV (TRY_CONV c)) THENC
  RATOR_CONV (RATOR_CONV (RAND_CONV (TRY_CONV c))) THENC
  RATOR_CONV (RATOR_CONV (RATOR_CONV (RAND_CONV (TRY_CONV c))))


(***********************************************)
(* converting between case-splits to PMATCH    *)
(***********************************************)

(* ----------------------- *)
(* Auxiliary functions for *)
(* case2pmatch             *)
(* ----------------------- *)

(*
val t = ``case x of
   (NONE, []) => 0`` *)

(* type_names ty returns the theory and type-operator name of ty.
   Fails for type variables. *)
fun type_names ty =
  let val {Thy,Tyop,Args} = Type.dest_thy_type ty
  in {Thy=Thy,Tyop=Tyop}
  end;

(* destruct variant cases, see dest_case_fun.
   For an application of the case constant of a datatype, returns the
   scrutinee (the first argument) and, for every constructor, the pair
   (constructor applied to the variables bound in its branch,
    body of that branch). *)
fun dest_case_fun_aux1 t = let
  val (f, args) = strip_comb t
  val (tys, _) = strip_fun (type_of f)
  val _ = if (List.null args) then failwith "dest_case_fun" else ()
  val ty = case tys of
      [] => failwith "dest_case_fun"
    | (ty::_) => ty
  (* Fail early on a type variable. This used to happen implicitly
     through an unused call of type_names. *)
  val _ = if is_vartype ty then failwith "dest_case_fun" else ()
  val ti = case TypeBase.fetch ty of
      NONE => failwith "dest_case_fun"
    | SOME ti => ti

  val _ = if (same_const (TypeBasePure.case_const_of ti) f) then
            () else failwith "dest_case_fun"

  val ty_s = match_type (type_of (TypeBasePure.case_const_of ti)) (type_of f)
  val constrs = List.map (inst ty_s) (TypeBasePure.constructors_of ti)

  val a = hd args
  val ps = map2 (fn c => fn arg => let
      val (vars, res) = strip_abs arg in
      (list_mk_comb (c, vars), res) end) constrs (tl args)
in
  (a, ps)
end

(* destruct literal cases, see dest_case_fun.
   For literal_case (\v'. if v' = c1 then r1 else ... else e) x, returns
   (x, [(c1, r1), ..., (v', e)]), i.e. the tested literals in order,
   followed by the bound variable v' as catch-all for the else-branch. *)
fun dest_case_fun_aux2 t = let
  val _ = if is_literal_case t then () else failwith "dest_case_fun"

  val (_, args) = strip_comb t

  val v = (el 2 args)
  val (v', b) = dest_abs (el 1 args)

  fun strip_cond acc b = let
    val (c, t_t, t_f) = dest_cond b
    val (c_l, c_r) = dest_eq c
    val _ = if (aconv c_l v') then () else failwith "dest_case_fun"
  in
    strip_cond ((c_r, t_t)::acc) t_f
  end handle HOL_ERR _ => (acc, b)

  val (ps_rev, c_else) = strip_cond [] b
  val ps = List.rev ((v', c_else) :: ps_rev)
in
  (v, ps)
end


(* destruct a case-function.
   The top-most split is split into the input + a list of rows.
   Each row consists of a pattern + the right-hand side.
*)
fun dest_case_fun t = dest_case_fun_aux1 t handle HOL_ERR _ => dest_case_fun_aux2 t


(* try to collapse rows by introducing a catchall at end.
   A row (p, rh) yields a catch-all candidate if rh, after replacing
   p by the scrutinee a, no longer mentions p's variables. The candidate
   that leaves the fewest rows uncovered is chosen. It is only used if
   the resulting row list (uncovered rows + catch-all) is strictly
   shorter than the original one. *)
fun dest_case_fun_collapse (a, ps) = let

  (* find all possible catch-all clauses *)
  fun check_collapsable (p, rh) = let
    val p_vs = FVL [p] empty_tmset
    val rh' = if HOLset.isEmpty p_vs then rh else
                Term.subst [p |-> a] rh
    val ok = HOLset.isEmpty (HOLset.intersection (FVL [rh'] empty_tmset, p_vs))
  in
    if ok then SOME rh' else NONE
  end

  val catch_all_cands = List.foldl (fn (prh, cs) =>
     case check_collapsable prh of
         NONE => cs
       | SOME rh => rh::cs) [] ps

  (* really collapse: a row is not caught by the catch-all ca if its
     rhs differs from ca instantiated with the row's pattern *)
  fun is_not_caught ca (p, rh) =
    not (aconv rh (Term.subst [a |-> p] ca))

  val all_collapse_opts = List.map (fn ca => (ca, filter (is_not_caught ca) ps)) catch_all_cands

  val all_collapse_opts_sorted = sort (fn (_, l1) => fn (_, l2) => List.length l1 < List.length l2) all_collapse_opts

  (* could we collapse 2 cases? *)
in
  if (List.null all_collapse_opts) then (a, ps) else
  let
    val (ca', ps') = hd all_collapse_opts_sorted
  in if (List.length ps' + 1 < List.length ps) then
       (a, (ps' @ [(a, ca')]))
     else (a, ps)
  end
end

(* case2pmatch_aux optimise x t flattens the nested case-splits of t on
   (variables of) the pattern x into a list of (pattern, rhs) rows.
   Returns NONE if t is not a case-split on a variable of x. *)
fun case2pmatch_aux optimise x t = let
  val (a, ps) = dest_case_fun t
  val _ = if is_var a andalso free_in a x then () else failwith "case-split on non pattern var"
  val (a, ps) = if optimise then (dest_case_fun_collapse (a, ps)) else (a, ps)

  fun process_arg (p, rh) = let
    val x' = subst [a |-> p] x
  in
    (* recursive call *)
    case case2pmatch_aux optimise x' rh of
        NONE => [(x', rh)]
      | SOME resl => resl
  end

  val ps = flatten (map process_arg ps)
in
  SOME ps
end handle HOL_ERR _ => NONE;

(* Drops rows that are unnecessary because the first later row whose
   pattern overlaps them also subsumes them (same rhs after matching).
   A catch-all row with rhs ARB is temporarily appended, so rows whose
   rhs is ARB and that are not overlapped by an earlier-matching later
   row can be dropped as well. The appended row is removed again. *)
fun case2pmatch_remove_unnessary_rows ps = let
  fun mk_distinct_rows (p1, _) (p2, rh2) = let
    val avoid = free_vars p1
    val (s, _) = List.foldl (fn (v, (s, av)) =>
       let val v' = variant av v in
       ((v |-> v')::s, v'::av) end) ([], avoid) (free_vars p2)
    val p2' = Term.subst s p2
    val rh2' = Term.subst s rh2
  in
    (p2', rh2')
  end

  fun pats_unify (p1, _) (p2, _) = (
    (Unify.simp_unify_terms [] p1 p2; true) handle HOL_ERR _ => false
  )

  fun row_subsumed (p1, rh1) (p2, rh2) = let
    val (s, _) = match_term p2 p1
    val rh2' = Term.subst s rh2
  in aconv rh2' rh1 end handle HOL_ERR _ => false

  fun row_is_needed r1 rs = case rs of
      [] => true
    | r2::rs' => let
        val r2' = mk_distinct_rows r1 r2
      in
        if pats_unify r1 r2' then (
          not (row_subsumed r1 r2')
        ) else row_is_needed r1 rs'
      end

  fun check_rows acc rs = case rs of
      [] => List.rev acc
    | [_] => (* drop last one *) List.rev acc
    | r::rs' => check_rows (if row_is_needed r rs' then r::acc else acc)
                  rs'

  val ps' = case ps of
      [] => []
    | (p, rh)::_ => (ps @ [(genvar (type_of p), mk_arb (type_of rh))])

in
  check_rows [] ps'
end


(* ----------------------- *)
(* End Auxiliary functions *)
(* for case2pmatch         *)
(* ----------------------- *)

(*
val (p1, rh1) = el 5 ps
val (p2, rh2) = mk_distinct_rows (p1, rh1) (el 6 ps)
ps
*)

(* convert a case-term into a PMATCH-term, without any proofs.
   If opt is set, rows are collapsed into catch-alls, rows that later
   rows make unnecessary are removed and unused pattern variables
   become wildcards. Fails (typically with "not a case-split") if t is
   not a case-split. *)
fun case2pmatch opt t = let
  val (f, args) = strip_comb t
  val _ = if (List.null args) then failwith "not a case-split" else ()

  val (p,patterns) = if is_literal_case t then (el 2 args, [el 1 args]) else
                     (hd args, tl args)
  val v = genvar (type_of p)

  val t0 = if is_literal_case t then list_mk_comb (f, patterns @ [v]) else list_mk_comb (f, v::patterns)
  val ps = case case2pmatch_aux opt v t0 of
      NONE => failwith "not a case-split"
    | SOME ps => ps
  (* without rows the row type below could not be computed (hd would
     raise Empty instead of HOL_ERR) *)
  val _ = if List.null ps then failwith "not a case-split" else ()

  (* The row removal may drop every row, e.g. when all right-hand sides
     are ARB. Keep the unoptimised rows in that case instead of building
     an empty row list. *)
  val ps = if not opt then ps else
    (case case2pmatch_remove_unnessary_rows ps of
         [] => ps
       | ps' => ps')

  fun process_pattern (p, rh) = let
    val fvs = List.rev (free_vars p)
  in
    if opt then
      snd (mk_PMATCH_ROW_PABS_WILDCARDS fvs (p, T, rh))
    else
      mk_PMATCH_ROW_PABS fvs (p, T, rh)
  end
  val rows = List.map process_pattern ps
  val rows_tm = listSyntax.mk_list (rows, type_of (hd rows))

  val rows_tm_p = Term.subst [v |-> p] rows_tm
in
  mk_PMATCH p rows_tm_p
end

(* So far, we converted a classical case-expression
   to a PMATCH without any proof. The following is used
   to prove the equivalence of the result via repeated
   case-splits and evaluation. This allows to
   define some conversions then. *)

val COND_CONG_STOP = prove (``
  (c = c') ==> ((if c then x else y) = (if c' then x else y))``,
SIMP_TAC std_ss [])

(* proves t = t'; raises UNCHANGED if the proof fails *)
fun case_pmatch_eq_prove t t' = let
  val tm = mk_eq (t, t')

  (* very slow, simple approach. Just brute force.
     TODO: change implementation to get more runtime-speed *)
  val my_tac = (
    REPEAT (BasicProvers.TOP_CASE_TAC THEN
            ASM_REWRITE_TAC[]) THEN
    FULL_SIMP_TAC (rc_ss []) [PMATCH_EVAL, PMATCH_ROW_COND_def,
      PMATCH_INCOMPLETE_def]
  )
in
  (* set_goal ([], tm) *)
  prove (tm, REPEAT my_tac)
end handle HOL_ERR _ => raise UNCHANGED


(* Both conversions fail with HOL_ERR if t is not a case-split and
   raise UNCHANGED if the equivalence cannot be proved. *)
fun PMATCH_INTRO_CONV t =
  case_pmatch_eq_prove t (case2pmatch true t)

fun PMATCH_INTRO_CONV_NO_OPTIMISE t =
  case_pmatch_eq_prove t (case2pmatch false t)


(* ------------------------- *)
(* pmatch2case               *)
(* ------------------------- *)

(* convert a PMATCH-term into a case-term, without any proofs.
   Fails if a row has a guard other than T or if a pattern has
   free variables that are not bound by its row. *)
fun pmatch2case t = let
  val (v, rows) = dest_PMATCH t
  val fv = genvar (type_of v --> type_of t)

  fun process_row r = let
    val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS r
    val _ = if (aconv gt T) then () else
      failwith ("guard present in row " ^
         (term_to_string r))

    val vars = FVL [vars_tm] empty_tmset
    val used_vars = FVL [pt] empty_tmset
    val free_vars = HOLset.difference (used_vars, vars)
    val _ = if (HOLset.isEmpty free_vars) then () else
      failwith ("free variables in pattern " ^ (term_to_string pt))
  in
    mk_eq (mk_comb (fv, pt), rh)
  end

  val row_eqs = map process_row rows
  val rows_tm = list_mk_conj row_eqs

  (* compile patterns *)
  val case_tm0 = GrammarSpecials.compile_pattern_match rows_tm


  (* nearly there, now remove lambda's *)
  val (vs, case_tm1) = strip_abs case_tm0
  val case_tm = subst [el 2 vs |-> v] case_tm1
in
  case_tm
end

fun PMATCH_ELIM_CONV t =
  case_pmatch_eq_prove t (pmatch2case t)



(***********************************************)
(* removing redundant rows                     *)
(***********************************************)

(*
val rc_arg = ([], NONE)

val t = ``
  case l of
    | [] => 0
    | x::y::x::y::_ => (x +
y)
    | x::x::x::x::_ when (x > 10) => x
    | x::x::x::x::x::_ => 9
    | [] => 1
    | x::x::x::y::_ => (x + x + x)
    | x::_ => 1
    | x::y::z::_ => (x + x + x)
  ``

val (rows, _) = listSyntax.dest_list (rand t)
*)

(* For removing redundant rows we want to check whether
   the pattern of a row is overlapped by the pattern of a
   previous row. In preparation for this, we extract all
   patterns and generate fresh variables for them. Then we
   build, for every row, the pairs of its pattern with the
   patterns of all following rows. This allows for simple checks
   via matching later.

   Each entry is (i, (fresh_vars, pattern_body)), where i is the
   position of the row in [rows]. Rows whose pattern cannot be
   destructed are skipped, but the positions of the others are kept,
   because callers use them as indices into the full [rows] list. *)
fun compute_row_pat_pairs rows = let
  (* get pats with fresh vars to do a quick prefiltering; enumerate
     BEFORE filtering so that indices stay aligned with [rows] *)
  val pats_unique = Lib.mapfilter (fn (i, r) => let
      val (p, _, _) = dest_PMATCH_ROW r
      val (vars_tm, pb) = pairSyntax.dest_pabs p
      val vars = pairSyntax.strip_pair vars_tm
      val s = List.map (fn v => (v |-> genvar (type_of v))) vars
      val vars' = map (fn x => #residue x) s
      val pb' = subst s pb
    in
      (i, (vars', pb'))
    end) (Lib.enumerate 0 rows)

  (* get all pairs, first component always appears before second *)
  val candidates = let
    fun aux acc l = case l of
        [] => acc
      | (x::xs) => aux ((List.map (fn y => (x, y)) xs) @ acc) xs
  in
    aux [] pats_unique
  end
in
  candidates
end

(* Now do the real filtering.
   For every pair of rows (i, j), i < j, where row j's pattern is an
   instance of row i's pattern (only row i's pattern variables are
   instantiated, no type instantiation), try to remove row j via
   PMATCH_ROWS_DROP_REDUNDANT_PMATCH_ROWS. Its precondition is discharged
   with rc_elim_precond, and the appends in the result are cleaned up
   with fix_appends. The first pair that works is used. Raises
   UNCHANGED if no pair works. *)
fun PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL_SINGLE rc_arg t = let
  val (v, rows) = dest_PMATCH t
  val candidates = compute_row_pat_pairs rows

  (* quick filter on matching *)
  val candidates_match = let
    fun does_match ((_, (v1, p1)), (_, (v2, p2))) =
      let
        val (t_s, ty_s) = match_term p1 p2
      in
        (null ty_s) andalso
        (Lib.all (fn x => mem (#redex x) v1) t_s)
      end handle HOL_ERR _ => false
  in
    List.filter does_match candidates
  end

  (* filtering finished, now try it for real *)
  val cands = List.map (fn ((p1, _), (p2, _)) => (p1, p2)) candidates_match
  (* val (r_no1, r_no2) = el 1 cands *)
  fun try_pair (r_no1, r_no2) = let
    (* build rows1 ++ (r1 :: rows2) ++ (r2 :: rows3) *)
    val tm0 = let
      val (rows1, r1, rows_rest) = extract_element rows r_no1
      val (rows2, r2, rows3) = extract_element rows_rest (r_no2 - r_no1 - 1)

      val rows1_tm = listSyntax.mk_list (rows1, type_of r1)
      val rows2_tm = listSyntax.mk_list (rows2, type_of r1)
      val r1rows2_tm = listSyntax.mk_cons (r1, rows2_tm)
      val rows3_tm = listSyntax.mk_list (rows3, type_of r1)
      val r2rows3_tm = listSyntax.mk_cons (r2, rows3_tm)

      val arg = listSyntax.list_mk_append [rows1_tm, r1rows2_tm, r2rows3_tm]
    in
      mk_PMATCH v arg
    end

    val thm0 = FRESH_TY_VARS_RULE PMATCH_ROWS_DROP_REDUNDANT_PMATCH_ROWS
    val thm1 = PART_MATCH (lhs o rand) thm0 tm0

    val thm2 = rc_elim_precond rc_arg thm1
    val thm3 = fix_appends rc_arg t thm2
  in
    thm3
  end
in
  Lib.tryfind try_pair cands
end handle HOL_ERR _ => raise UNCHANGED

fun PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL rc_arg = REPEATC (PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL_SINGLE rc_arg)
fun PMATCH_REMOVE_FAST_REDUNDANT_CONV_GEN ssl = PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL (ssl, NONE)
val PMATCH_REMOVE_FAST_REDUNDANT_CONV = PMATCH_REMOVE_FAST_REDUNDANT_CONV_GEN []


(***********************************************)
(* removing subsumed rows                      *)
(***********************************************)

(*
val rc_arg = ([], NONE)

set_trace "parse deep cases" 0
val t = case2pmatch false ``case x of NONE => 0``

val t = case2pmatch false ``case (x, y, z) of
    (0, y, z) => 2
  | (x, NONE, []) => x
  | (x, SOME y, l) => x+y``

val t =
  ``case (x,y,z) of
    | (0,v1) => 2
    | (SUC v4,NONE,[]) => (SUC v4)
    | (SUC v4,NONE,v10::v11) => ARB
    | (v4,NONE,_) => v4
    | (0,SOME _ ,_) => ARB
    | (SUC v4,SOME v9,v8) => (SUC v4 + v9)
  ``

*)

(* When removing subsumed rows, i.e.
rows that can be dropped,
   because a following rule covers them, we can sometimes drop rows with
   right-hand-side ARB, because PMATCH v [] evaluates to ARB.
   This is semantically fine, but changes the users-view. The resulting
   case expression might e.g. not be exhaustive any more. This can
   also cause trouble for code generation. Therefore the parameter
   [exploit_match_exp] determines, whether this optimisation is performed.

   PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL_SINGLE exploit_match_exp rc_arg t
   tries to remove a single row from the PMATCH term t. The candidates are:
   - (i, SOME j): a later row j has a more general pattern, i.e. row i's
     pattern is an instance of row j's pattern, obtained by instantiating
     only row j's pattern variables. The removal is attempted via
     PMATCH_ROWS_DROP_SUBSUMED_PMATCH_ROWS.
   - (i, NONE), only if exploit_match_exp holds: row i has right-hand
     side ARB. The removal is attempted via PMATCH_REMOVE_ARB_NO_OVERLAP.
   The precondition of the chosen theorem is discharged with
   rc_elim_precond, and the appends in the result are cleaned up with
   fix_appends. The first candidate that succeeds (Lib.tryfind) is used;
   if none does, UNCHANGED is raised. *)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL_SINGLE
  exploit_match_exp rc_arg t = let
  val (v, rows) = dest_PMATCH t
  val candidates = compute_row_pat_pairs rows

  (* quick filter on matching: keep pairs where the later pattern p2
     matches the earlier pattern p1 by instantiating only p2's vars *)
  val candidates_match = let
    fun does_match ((_, (v1, p1)), (_, (v2, p2))) =
      let
        val (t_s, ty_s) = match_term p2 p1
      in
        (null ty_s) andalso
        (Lib.all (fn x => mem (#redex x) v2) t_s)
      end handle HOL_ERR _ => false
  in
    List.filter does_match candidates
  end

  val cands_sub = List.map (fn ((p1, _), (p2, _)) => (p1, SOME p2)) candidates_match

  (* positions of all rows whose right-hand side is ARB, paired with NONE;
     rows that cannot be destructed are skipped by Lib.mapfilter *)
  fun cands_arb () = Lib.mapfilter (fn (i, r) => let
      val (_, _, _, r) = dest_PMATCH_ROW_ABS r in
      (dest_arb r; (i, (NONE : int option))) end) (Lib.enumerate 0 rows)

  val cands = if exploit_match_exp then (cands_sub @ cands_arb ()) else
     cands_sub

  (* filtering finished, now try it for real *)
  (* val (r_no1, r_no2_opt) = el 2 cands_arb *)
  (* try_pair expects r_no1 (and the optional r_no2) to be positions in
     rows; it rebuilds the row list as rows1 ++ (r1 :: rest) so that
     the removal theorem can be matched against it *)
  fun try_pair (r_no1, r_no2_opt) = let
    fun mk_row_list rs = listSyntax.mk_list (rs, type_of (hd rows))

    (* split rs at position n; returns the rows after the element and a
       function that rebuilds "rows before ++ (element :: rest_tm)" *)
    fun extract_el_n n rs = let
      val (rows1,r1,rows_rest) = extract_element rs n
      val rows1_tm = mk_row_list rows1

      fun build_tm rest_tm =
        listSyntax.mk_append (rows1_tm,
          (listSyntax.mk_cons (r1, rest_tm)))
    in
      (rows_rest, build_tm)
    end

    val tm0 = let
      val (rs_rest, bf_1) = extract_el_n r_no1 rows

      (* r_no2 is relative to the full row list, hence the shift *)
      val rs2 = case r_no2_opt of
          NONE => mk_row_list rs_rest
        | SOME n => let
            val n' = n - r_no1 - 1
            val (rs_rest', bf_2) = extract_el_n n' rs_rest
          in
            bf_2 (mk_row_list rs_rest')
          end
    in
      mk_PMATCH v (bf_1 rs2)
    end

    val thm_base = case r_no2_opt of
        NONE => PMATCH_REMOVE_ARB_NO_OVERLAP
      | SOME _ => PMATCH_ROWS_DROP_SUBSUMED_PMATCH_ROWS
    val thm0 = FRESH_TY_VARS_RULE thm_base
    val thm1 = PART_MATCH (lhs o rand) thm0 tm0

    val thm2 = rc_elim_precond rc_arg thm1
    val thm3 = fix_appends rc_arg t thm2
  in
    thm3
  end
in
  Lib.tryfind try_pair cands
end handle HOL_ERR _ => raise UNCHANGED

(* Repeat the single-row removal as long as it succeeds. *)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL eme rc_arg = REPEATC (PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL_SINGLE eme rc_arg)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV_GEN eme ssl = PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL eme (ssl, NONE)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV eme = PMATCH_REMOVE_FAST_SUBSUMED_CONV_GEN eme []


(***********************************************)
(* Cleaning up unused vars in PMATCH_ROW       *)
(***********************************************)

(*val t = ``
PMATCH (SOME x, xz)
  [PMATCH_ROW (\x. (SOME 2,x,[])) (\x. T) (\x. x);
   PMATCH_ROW (\y:'a. ((SOME 2,3,[]))) (\y. T) (\y. x);
   PMATCH_ROW (\(z,x,yy). (z,x,[2])) (\(z,x,yy). T) (\(z,x,yy). x)]``
*)


(* Many simps depend on patterns being injective. This means
   in particular that no extra, unused vars occur in the patterns.
   The following removes such unused vars.
*)

(* PMATCH_CLEANUP_PVARS_CONV t visits (DEPTH_CONV) every PMATCH_ROW
   inside the PMATCH term t. After PMATCH_ROW_FORCE_SAME_VARS_CONV
   (presumably making pattern, guard and rhs bind the same variable
   tuple), it removes the bound variables that occur in none of the
   row's pattern, guard and right-hand side. The shortened row is proved
   equal to the original one via PMATCH_ROW_EQ_AUX and rc_tac.
   Raises UNCHANGED if t is not a PMATCH or nothing could be removed. *)
fun PMATCH_CLEANUP_PVARS_CONV t = let
  val _ = if is_PMATCH t then () else raise UNCHANGED

  fun row_conv row = let
    val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row
    (* a row that binds a single unit-typed variable has nothing to drop *)
    val _ = if (type_of vars_tm = one_ty) then raise UNCHANGED else ()
    val vars = pairSyntax.strip_pair vars_tm
    (* A variable is used if it occurs in the pattern, the guard or the
       rhs. Ignoring the guard (as was done before) would turn a
       guard-only variable into a free variable of the new guard; the
       equation below could then in general not be proved. *)
    val used_vars = FVL [pt, gt, rh] empty_tmset

    val filtered_vars = filter (fn v => HOLset.member (used_vars, v)) vars

    val _ = if (length vars = length filtered_vars) then
              raise UNCHANGED else ()

    val row' = mk_PMATCH_ROW_PABS filtered_vars (pt, gt, rh)

    val eq_tm = mk_eq (row, row')
    (* set_goal ([], eq_tm) *)
    val eq_thm = prove (eq_tm,
      MATCH_MP_TAC PMATCH_ROW_EQ_AUX THEN
      rc_tac ([], NONE)
    )
  in
    eq_thm
  end
in
  CHANGED_CONV (DEPTH_CONV (PMATCH_ROW_FORCE_SAME_VARS_CONV THENC row_conv)) t
end handle HOL_ERR _ => raise UNCHANGED


(***********************************************)
(* Cleaning up by removing rows that           *)
(* don't match or are redundant                *)
(* also remove the whole PMATCH, if first      *)
(* row matches                                 *)
(***********************************************)

(*
val t = ``
PMATCH (NONE,x,l)
  [PMATCH_ROW (\x. (NONE,x,[])) (\x. T) (\x. x);
   PMATCH_ROW (\x. (NONE,x,[2])) (\x. F) (\x. x);
   PMATCH_ROW (\x. (NONE,x,[2])) (\x. T) (\x. x);
   PMATCH_ROW (\(x,y). (y,x,[2])) (\(x, y). T) (\(x, y). x);
   PMATCH_ROW (\x. (SOME 3,x,[])) (\x. T) (\x. x)
  ]``

val t = ``PMATCH y [PMATCH_ROW (\_0_1. _0_1) (\_0_1. T) (\_0_1.
F)]``

val t = ``case (SUC x) of x => x + 3``

val rc_arg = ([], NONE)

val t' = rhs (concl (PMATCH_CLEANUP_CONV t))
*)

(* map_filter f l applies f to every element of l and keeps, in order,
   the values y of all results SOME y. *)
fun map_filter f l = case l of
    [] => []
  | x::xs => (case f x of
       NONE => map_filter f xs
     | SOME y => y :: (map_filter f xs));

(* remove redundant rows.

   PMATCH_CLEANUP_CONV_GENCALL rc_arg t simplifies t = PMATCH v rows by
   evaluating the rows on v. For each row r, the term ``r v = NONE`` is
   simplified with rc_conv rc_arg:
   - result T: the row does not match v, recorded as SOME (true, thm);
   - result F: the row matches v, recorded as SOME (false, thm);
   - anything else or failure: unknown, recorded as NONE.
   Checking stops after the first matching row; later rows get NONE.

   Three steps follow, chained with TRANS:
   - thm0: all rows after the first matching row are dropped;
   - thm1: all rows known not to match are dropped;
   - thm2: if the first remaining row matches, the PMATCH is evaluated
     to that row's right-hand side; if no rows remain, the term is
     rewritten with the first conjunct of PMATCH_def.
   Each step falls back to a reflexivity theorem if it fails.
   Raises UNCHANGED if t has no rows, if no step can make progress, or
   on any HOL_ERR (e.g. t is not a PMATCH). *)
fun PMATCH_CLEANUP_CONV_GENCALL rc_arg t = let
  val (v, rows) = dest_PMATCH t
  val _ = if (null rows) then raise UNCHANGED else ()

  fun check_row r = let
    val r_tm = mk_eq (mk_comb (r, v), optionSyntax.mk_none (type_of t))
    val r_thm = rc_conv rc_arg r_tm
    val res_tm = rhs (concl r_thm)
  in
    if (same_const res_tm T) then SOME (true, r_thm) else
    (if (same_const res_tm F) then SOME (false, r_thm) else NONE)
  end handle HOL_ERR _ => NONE

  (* check rows front to back; abort (mark remaining rows NONE) once a
     row is known to match *)
  val (rows_checked_rev, _) = foldl (fn (r, (acc, abort)) =>
    if abort then ((r, NONE)::acc, true) else (
    let
      val res = check_row r
      val abort = (case res of
          (SOME (false, _)) => true
        | _ => false)
    in
      ((r, res)::acc, abort)
    end)) ([], false) rows
  val rows_checked = List.rev rows_checked_rev

  (* did we get any results? *)
  fun check_row_exists v rows =
    exists (fn x => case x of (_, SOME (v', _)) => v = v' | _ => false) rows

  (* Continue only if some step can make progress: a row known not to
     match, a matching row that is not the last row (so the rows after
     it can be dropped), or a matching first row. *)
  val _ = if ((check_row_exists true rows_checked_rev) orelse (check_row_exists false (tl rows_checked_rev)) orelse (check_row_exists false [hd rows_checked])) then () else raise UNCHANGED

  val row_ty = type_of (hd rows)

  (* drop redundant rows: keep rows 0 .. n, where n is the (0-based)
     position of the first matching row *)
  val (thm0, rows_checked0) = let
    val n = index (fn x => case x of (_, SOME (false, _)) => true | _ => false) rows_checked
    val n_tm = numSyntax.term_of_int n

    val thma = ISPECL [v, listSyntax.mk_list (rows, row_ty), n_tm]
      (FRESH_TY_VARS_RULE PMATCH_ROWS_DROP_REDUNDANT_TRIVIAL_SOUNDNESS)

    (* the precondition is proved from the check theorem of the first
       matching row (Lib.el counts from 1, so el (n+1) is row n) *)
    val precond = fst (dest_imp (concl thma))
    val precond_thm = prove (precond,
      MP_TAC (snd(valOf (snd (el (n+1) rows_checked)))) THEN
      SIMP_TAC list_ss [quantHeuristicsTheory.IS_SOME_EQ_NOT_NONE])

    val thmb = MP thma precond_thm

    (* evaluate the take of the first SUC n rows *)
    val take_conv = RATOR_CONV (RAND_CONV reduceLib.SUC_CONV) THENC
                    listLib.FIRSTN_CONV
    val thmc = CONV_RULE (RHS_CONV (RAND_CONV take_conv)) thmb
  in
    (thmc, List.take (rows_checked, n+1))
  end handle HOL_ERR _ => (REFL t, rows_checked)

  (* drop false rows, i.e. rows known not to match; the equation is
     built row by row from the back, starting from PMATCH_EXTEND_BASE *)
  val (thm1, rows_checked1) = let
    val _ = if (exists (fn x => case x of (_, (SOME (true, _))) => true | _ => false) rows_checked0) then () else failwith "nothing to do"

    fun process_row ((r, r_thm_opt), thm) = (case r_thm_opt of
        (SOME (true, r_thm)) => let
          (* row known not to match: only the old side gets the row *)
          val thmA = FRESH_TY_VARS_RULE PMATCH_EXTEND_OLD
          val thmB = HO_MATCH_MP thmA (EQT_ELIM r_thm)
          val thmC = HO_MATCH_MP thmB thm
        in
          thmC
        end
      | _ => let
          (* any other row is kept on both sides *)
          val thmA = PMATCH_EXTEND_BOTH_ID
          val thmB = HO_MATCH_MP thmA thm
        in
          ISPEC r thmB
        end)

    val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v] PMATCH_EXTEND_BASE)
    val thma = foldl process_row base_thm (List.rev rows_checked0)

    val rows_checked1 = filter (fn (_, res_opt) =>
      case res_opt of
        SOME (true, thm) => false
      | _ => true) rows_checked0
  in
    (thma, rows_checked1)
  end handle HOL_ERR _ => (REFL (rhs (concl thm0)), rows_checked0)


  (* if first line matches, evaluate *)
  val thm2 = let
    val _ = if (not (List.null rows_checked1) andalso
                (case hd rows_checked1 of (_, (SOME (false, _))) => true | _ => false)) then () else failwith "nothing to do"

    val thm1_tm = rhs (concl thm1)
    val thm2a = PART_MATCH (lhs o rand) PMATCH_EVAL_MATCH thm1_tm
    val pre_thm = EQF_ELIM (snd (valOf(snd (hd rows_checked1))))
    val thm2b = MP thm2a pre_thm

    (* try to simplify the argument of the result and beta-reduce *)
    val thm2c = CONV_RULE (RHS_CONV
       (RAND_CONV (rc_conv rc_arg) THENC
        pairLib.GEN_BETA_CONV)) thm2b handle HOL_ERR _ => thm2b
  in
    thm2c
  end handle HOL_ERR _ => let
    (* no rows left: rewrite PMATCH v [] with the first clause of PMATCH_def *)
    val _ = if (List.null rows_checked1) then () else failwith "nothing to do"
  in
    (REWR_CONV (CONJUNCT1 PMATCH_def)) (rhs (concl thm1))
  end handle HOL_ERR _ => REFL (rhs (concl thm1))
in
  TRANS (TRANS thm0 thm1) thm2
end handle HOL_ERR _ => raise UNCHANGED


fun PMATCH_CLEANUP_CONV_GEN ssl = PMATCH_CLEANUP_CONV_GENCALL (ssl, NONE)
fun PMATCH_CLEANUP_GEN_ss ssl =
  make_gen_conv_ss PMATCH_CLEANUP_CONV_GENCALL "PMATCH_CLEANUP_REDUCER" ssl
val PMATCH_CLEANUP_ss = PMATCH_CLEANUP_GEN_ss []
val PMATCH_CLEANUP_CONV = PMATCH_CLEANUP_CONV_GEN [];
(* register PMATCH_CLEANUP_CONV with computeLib for PMATCH applied to
   2 arguments (presumably so that EVAL-style evaluation uses it) *)
val _ = computeLib.add_convs [(patternMatchesSyntax.PMATCH_tm, 2, QCHANGED_CONV PMATCH_CLEANUP_CONV)];


(***********************************************)
(* simplify a column                           *)
(***********************************************)

(* This can also be considered partial evaluation *)

(* pair_get_col col v splits the tuple v into (v', c_v), where c_v is
   the component at (0-based) position col and v' is the tuple of the
   remaining components. Fails if no components would remain. *)
fun pair_get_col col v = let
  val vs = pairSyntax.strip_pair v
  val (vs', c_v) = replace_element vs col []
  val _ = if (List.null vs') then failwith "pair_get_col"
          else ()
  val v' = pairSyntax.list_mk_pair vs'
in
  (v', c_v)
end;

(*----------------*)
(*
drop a column *)
(*----------------*)

(*
val t = ``
PMATCH (NONE,x,l)
   [PMATCH_ROW (\x. (NONE,x,[])) (\x. T) (\x. x);
    PMATCH_ROW (\z. (NONE,z,[2])) (\z. F) (\z. z);
    PMATCH_ROW (\x. (NONE,x,[2])) (\x. T) (\x. x);
    PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
   ]``

val t = ``
 PMATCH (x + y,ys)
   [PMATCH_ROW (\x. (x,[])) (\x. T) (\x. x);
    PMATCH_ROW (\ (x,y,ys). (x,y::ys)) (\ (x,y,ys). T)
      (\ (x,y,ys). my_d (x + y,ys))]``


val t = ``PMATCH (x,y)
   [PMATCH_ROW (\x. (x,x)) (\x. T) (\x. T);
    PMATCH_ROW (\ (z, y). (z, y)) (\ (z, y). T) (\ (z, y). F)]``


val rc_arg = []
val col = 0
*)

(* [PMATCH_REMOVE_COL_AUX rc_arg col t] removes column [col] from the PMATCH
   term [t].

   [pair_get_col] splits the input [v] of [t] into the remaining columns
   [v'] and the value [c_v] of the removed column; it fails (so this
   function raises UNCHANGED) if [v] has only one column.  In every row the
   pattern of column [col] must either be
   - one of the row's bound variables: it is dropped from the row's
     variable tuple and the removed column value [c_v] takes its place, or
   - a pattern matching [c_v] that instantiates only row-bound variables
     (a constant column): the instantiated variables are dropped.
   If no variable is left in a row, a fresh unit variable _uv is used.
   Side conditions and auxiliary equations are proved with [rc_tac rc_arg].
   The result theorem is assembled row by row (last row first), starting
   from PMATCH_EXTEND_BASE instantiated with [v] and [v'] and combining the
   row theorems with PMATCH_EXTEND_BOTH.  Any HOL_ERR becomes UNCHANGED. *)
fun PMATCH_REMOVE_COL_AUX rc_arg col t = let
  val (v, rows) = dest_PMATCH t
  val (v', c_v) = pair_get_col col v
  val vs = free_vars c_v

  (* row-level theorem, specialised to the old and the new input *)
  val thm_row = let
    val thm = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_FUN_VAR
    val thm = ISPEC v thm
    val thm = ISPEC v' thm
  in thm end

  (* Rewrites a single row: [row v] is shown equal to a row over [v']. *)
  fun PMATCH_ROW_REMOVE_FUN_VAR_COL_AUX row = let
    (* the row's variables are renamed apart from the free variables of
       c_v, presumably to avoid capture when c_v is substituted in *)
    val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS_VARIANT vs row
    val vars = pairSyntax.strip_pair vars_tm
    val avoid = free_varsl [pt, gt, rh]

    (* pv is the row's pattern in column col, pt' the remaining pattern
       with pv replaced by c_v *)
    val (pt0', pv) = pair_get_col col pt
    val pt' = subst [pv |-> c_v] pt0'

    (* [f] maps the new variable tuple back to the old one *)
    val pv_i_opt = SOME (index (aconv pv) vars) handle HOL_ERR _ => NONE
    val (vars'_tm, f) = case pv_i_opt of
        (SOME pv_i) => (let
          (* we eliminate a variable column *)
          val vars' = let
            val (vars', _) = replace_element vars pv_i []
          in
            if (List.null vars') then [variant avoid ``_uv:unit``] else vars'
          end

          val vars'_tm = pairSyntax.list_mk_pair vars'
          val g' = let
            val (vs, _) = replace_element vars pv_i [c_v]
            val vs_tm = pairSyntax.list_mk_pair vs
          in
            pairSyntax.mk_pabs (vars'_tm, vs_tm)
          end
        in
          (vars'_tm, g')
        end)
      | NONE => (let
          (* we eliminate a constant column: the pattern must match c_v
             instantiating row-bound variables only *)
          val (sub, _) = match_term pv c_v
          val _ = if List.all (fn x => List.exists (aconv (#redex x)) vars) sub then () else failwith "not a constant-col after all"

          val vars' = filter (fn v => not (List.exists (fn x => (aconv v (#redex x))) sub)) vars
          val vars' = if (List.null vars') then [variant avoid ``_uv:unit``] else vars'
          val vars'_tm = pairSyntax.list_mk_pair vars'

          val g' = pairSyntax.mk_pabs (vars'_tm, Term.subst sub vars_tm)
        in
          (vars'_tm, g')
        end)

(*    val f = pairSyntax.mk_pabs (vars_tm, pt)
    val f' = pairSyntax.mk_pabs (vars'_tm, pt')
    val g = pairSyntax.mk_pabs (vars_tm, rh)

*)
    val p = pairSyntax.mk_pabs (vars_tm, pt)
    val p' = pairSyntax.mk_pabs (vars'_tm, pt')
    val g = pairSyntax.mk_pabs (vars_tm, gt)
    val r = pairSyntax.mk_pabs (vars_tm, rh)

    val thm0 = let
      val thm = thm_row
      val thm = ISPEC f thm
      val thm = ISPEC p thm
      val thm = ISPEC g thm
      val thm = ISPEC r thm
      val thm = ISPEC p' thm

      (* re-introduce the paired abstraction over vs and beta-reduce *)
      fun elim_conv_aux vs = (
        (pairTools.PABS_INTRO_CONV vs) THENC
        (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))
      )

      fun elim_conv vs = PMATCH_ROW_ARGS_CONV (elim_conv_aux vs)
      val thm = CONV_RULE ((RAND_CONV o RHS_CONV) (elim_conv vars'_tm)) thm

      (* make the left-hand side syntactically equal to [row v] *)
      val tm_eq = mk_eq(lhs (rand (concl thm)), mk_comb (row, v))
      val eq_thm = prove (tm_eq, rc_tac rc_arg)

      val thm = CONV_RULE (RAND_CONV (LHS_CONV (K eq_thm))) thm
    in
      thm
    end

    (* discharge the precondition of the row theorem *)
    val pre_tm = fst (dest_imp (concl thm0))
(*  set_goal ([], pre_tm) *)
    val pre_thm = prove (pre_tm, rc_tac rc_arg)
    val thm1 = MP thm0 pre_thm
  in
    thm1
  end

  (* combine one row theorem with the theorem for the following rows *)
  fun process_row (row, thm) = let
    val row_thm = PMATCH_ROW_REMOVE_FUN_VAR_COL_AUX row
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v'] PMATCH_EXTEND_BASE)
  val thm0 = List.foldl process_row base_thm (List.rev rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED


(*------------------------------------*)
(* remove a constructor from a column *)
(*------------------------------------*)

(*
val t = ``
PMATCH (SOME y,x,l)
   [PMATCH_ROW (\x. (SOME 0,x,[])) (\x. T) (\x. x);
    PMATCH_ROW (\z. (SOME 1,z,[2])) (\z. F) (\z. z);
    PMATCH_ROW (\x. (SOME 3,x,[2])) (\x. T) (\x. x);
    PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
   ]``

val rc_arg = []
val col = 0
*)


(* [PMATCH_REMOVE_FUN_AUX rc_arg col t] replaces column [col] of the PMATCH
   term [t], whose input value is an application of some [c] to arguments,
   by one column per argument.

   [ff_tm] maps the expanded tuple back to the original one by re-applying
   [c]; its injectivity is proved with [rc_tac rc_arg].  The new input is
   [v' = ff_inv v].  Each row is rewritten by
   - PMATCH_ROW_REMOVE_FUN_COL_AUX if its pattern in column [col] is also
     headed by [c] (the arguments become separate columns), otherwise by
   - PMATCH_ROW_REMOVE_VAR_COL_AUX if that pattern is a row-bound variable
     (it is replaced by fresh variables named after it, one per argument).
   Any other row, and any HOL_ERR, makes the function raise UNCHANGED. *)
fun PMATCH_REMOVE_FUN_AUX rc_arg col t = let
  val (v, rows) = dest_PMATCH t

  val (ff_tm, ff_inv, ff_inv_var, c) = let
    val vs = pairSyntax.strip_pair v
    val c_args = List.nth(vs, col)
    val (c, args) = strip_comb c_args

    val vs_vars = List.map (fn t => genvar (type_of t)) vs
    val args_vars = List.map (fn t => genvar (type_of t)) args

    val (vars, _) = replace_element vs_vars col args_vars
    val (ff_res, _) = replace_element vs_vars col [list_mk_comb (c, args_vars)]
    val ff_tm = pairSyntax.mk_pabs (pairSyntax.list_mk_pair vars,
       pairSyntax.list_mk_pair ff_res)

    (* splits the constructor application in column col of a tuple;
       fails if the head is not c *)
    fun ff_inv tt = let
      val tts = pairSyntax.strip_pair tt
      val tt_args = List.nth(tts, col)

      val (c', args') = strip_comb tt_args
      val _ = if (aconv c c') then () else failwith "different constr"

      val (vars,_) = replace_element tts col args'
    in
      pairSyntax.list_mk_pair vars
    end

    (* replaces the variable in column col of a tuple by fresh variables,
       one per constructor argument; fails if that column is not a
       variable.  Returns the new tuple, the old variable and the new
       variables. *)
    fun ff_inv_var avoid tt = let
      val tts = pairSyntax.strip_pair tt
      val tt_col = List.nth(tts, col)

      val _ = if (is_var tt_col) then () else failwith "no var"

      val (var_basename, _) = dest_var (tt_col)
      val gen_fun = mk_var_gen (var_basename ^ "_") avoid;
      val args = map (fn t => gen_fun (type_of t)) args_vars

      val (vars, _) = replace_element tts col args
    in
      (pairSyntax.list_mk_pair vars, tt_col, args)
    end

  in
    (ff_tm, ff_inv, ff_inv_var, c)
  end

  (* injectivity of ff_tm *)
  val ff_thm_tm = ``!x y. (^ff_tm x = ^ff_tm y) ==> (x = y)``
  val ff_thm = prove (ff_thm_tm, rc_tac rc_arg)

  val v' = ff_inv v

  val PMATCH_ROW_REMOVE_FUN' = let
    val thm0 = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_FUN
    val thm1 = ISPEC ff_tm thm0
    val thm2 = ISPEC v' thm1
    val thm3 = MATCH_MP thm2 ff_thm

    (* replace (ff_tm v') by the original input v *)
    val thm_v' = prove (``^ff_tm ^v' = ^v``, rc_tac rc_arg)
    val thm4 = CONV_RULE (STRIP_QUANT_CONV (LHS_CONV (RAND_CONV (K thm_v')))) thm3
  in
    thm4
  end

  (* row whose pattern in column col is headed by c *)
  fun PMATCH_ROW_REMOVE_FUN_COL_AUX row = let
    val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row

    val pt' = ff_inv pt
    val vpt' = pairSyntax.mk_pabs (vars_tm, pt')
    val vgt = pairSyntax.mk_pabs (vars_tm, gt)
    val vrh = pairSyntax.mk_pabs (vars_tm, rh)

    val thm0 = ISPECL [vpt', vgt, vrh] PMATCH_ROW_REMOVE_FUN'
    val eq_thm_tm = mk_eq (lhs (concl thm0), mk_comb (row, v))
    val eq_thm = prove (eq_thm_tm, rc_tac rc_arg)

    val thm1 = CONV_RULE (LHS_CONV (K eq_thm)) thm0

    val vi_conv = (pairTools.PABS_INTRO_CONV vars_tm) THENC
                  (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))

    val thm2 = CONV_RULE (RHS_CONV (PMATCH_ROW_ARGS_CONV vi_conv)) thm1
  in
    thm2
  end

  val thm_row = let
    val thm = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_FUN_VAR
    val thm = ISPEC v thm
    val thm = ISPEC v' thm
  in thm end

  (* row whose pattern in column col is a row-bound variable *)
  fun PMATCH_ROW_REMOVE_VAR_COL_AUX row = let
    val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row
    val vars = pairSyntax.strip_pair vars_tm

    val avoid = vars @ free_vars pt @ free_vars rh @ free_vars gt
    val (pt', pv, new_vars) = ff_inv_var avoid pt

    (* fails (HOL_ERR) if pv is not bound by the row *)
    val pv_i = index (aconv pv) vars

    val vars' = let
      val (vars', _) = replace_element vars pv_i new_vars
    in
      if (List.null vars') then [variant avoid ``_uv:unit``] else vars'
    end

    (* maps the new variables back to the old ones, re-applying c *)
    val vars'_tm = pairSyntax.list_mk_pair vars'
    val f_tm = let
      val c_v = list_mk_comb (c, new_vars)
      val (vs, _) = replace_element vars pv_i [c_v]
      val vs_tm = pairSyntax.list_mk_pair vs
    in
      pairSyntax.mk_pabs (vars'_tm, vs_tm)
    end

    val vpt = pairSyntax.mk_pabs (vars_tm, pt)
    val vpt' = pairSyntax.mk_pabs (vars'_tm, pt')
    val vrh = pairSyntax.mk_pabs (vars_tm, rh)
    val vgt = pairSyntax.mk_pabs (vars_tm, gt)

    val thm0 = let
      val thm = ISPEC f_tm thm_row
      val thm = ISPEC vpt thm
      val thm = ISPEC vgt thm
      val thm = ISPEC vrh thm
      val thm = ISPEC vpt' thm

      fun elim_conv vs = PMATCH_ROW_ARGS_CONV (
        (pairTools.PABS_INTRO_CONV vs) THENC
        (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))
      )

      val thm = CONV_RULE ((RAND_CONV o RHS_CONV) (elim_conv vars'_tm)) thm

      val tm_eq = mk_eq(lhs (rand (concl thm)), mk_comb (row, v))
      val eq_thm = prove (tm_eq, rc_tac rc_arg)

      val thm = CONV_RULE (RAND_CONV (LHS_CONV (K eq_thm))) thm
    in
      thm
    end

    val pre_tm = fst (dest_imp (concl thm0))
    val pre_thm = prove (pre_tm, rc_tac rc_arg)

    val thm1 = MP thm0 pre_thm
  in
    thm1
  end


  (* try the constructor case first, fall back to the variable case *)
  fun process_row (row, thm) = let
    val row_thm = PMATCH_ROW_REMOVE_FUN_COL_AUX row handle HOL_ERR _ =>
                  PMATCH_ROW_REMOVE_VAR_COL_AUX row
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

(*
  val row = el 1 (List.rev rows)
  val thm = base_thm
  val thm = thm0
*)

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v'] PMATCH_EXTEND_BASE)
  val thm0 = foldl process_row base_thm (List.rev rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED


(*------------------------*)
(* Combine auxiliary funs *)
(*------------------------*)

(*
val t = ``
PMATCH (SOME y,x,l)
   [PMATCH_ROW (\x. (SOME 0,x,[])) (\x. T) (\x. x);
    PMATCH_ROW (\z. z) (\z. F) (\z. (FST (SND z)));
    PMATCH_ROW (\x. (SOME 3,x)) (\x. T) (\x. (FST x));
    PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
   ]``
val rc_arg = []
*)

(* [PMATCH_SIMP_COLS_CONV_GENCALL rc_arg t] simplifies one column of the
   PMATCH term [t].  The columns are tried in list order, and the first
   column that is
   - eliminable ([elim_col_ok]: every row pattern matches the column's input
     value, instantiating only row-bound variables) is handed to
     PMATCH_REMOVE_COL_AUX;
   - otherwise simplifiable ([simp_col_ok]: the input value is some [c]
     applied to at least one argument, and every row pattern is a row-bound
     variable or is headed by [c]) is handed to PMATCH_REMOVE_FUN_AUX.
   Raises UNCHANGED if no column qualifies.
   NOTE(review): an UNCHANGED raised by the chosen auxiliary function is not
   caught here, so later columns are not tried in that case. *)
fun PMATCH_SIMP_COLS_CONV_GENCALL rc_arg t = let
  val cols = dest_PMATCH_COLS t
(*
  val (col_v, col) = el 1 cols
  val (vars, col_pat) = el 3 col
*)
  (* true iff col_pat matches col_v instantiating only variables in vars *)
  fun do_match col_v (vars, col_pat) = let
    val (sub, _) = match_term col_pat col_v
    val vars_ok = List.all (fn x => (List.exists (aconv (#redex x)) vars)) sub
  in
    vars_ok
  end handle HOL_ERR _ => false

  fun elim_col_ok (col_v, col) =
    List.all (do_match col_v) col

  fun simp_col_ok (col_v, col) = let
    val (c, args) = strip_comb col_v
    val _ = if (List.null args) then failwith "elim_col instead" else ()

    fun check_line (vars, pt) =
      (List.exists (aconv pt) vars) orelse
      (aconv (fst (strip_comb pt)) c)
  in
    List.all check_line col
  end handle HOL_ERR _ => false

  fun process_col i col = if (elim_col_ok col) then
      SOME (PMATCH_REMOVE_COL_AUX rc_arg i t)
    else if (simp_col_ok col) then
      SOME (PMATCH_REMOVE_FUN_AUX rc_arg i t)
    else NONE

  val thm_opt = first_opt process_col cols
in
  case thm_opt of NONE => raise UNCHANGED
                | SOME thm => thm
end

fun PMATCH_SIMP_COLS_CONV_GEN ssl = PMATCH_SIMP_COLS_CONV_GENCALL (ssl, NONE)
val PMATCH_SIMP_COLS_CONV = PMATCH_SIMP_COLS_CONV_GEN [];


(***********************************************)
(* Resort and add dummy columns                *)
(***********************************************)

(*
val t = ``PMATCH (s:'a option,x : 'a option, l:num list)
   [PMATCH_ROW
(\_uv:unit. (NONE,NONE,[])) (\_uv. T) (\_uv. NONE);
    PMATCH_ROW (\z. (NONE,z,[2])) (\z. F) (\z. z);
    PMATCH_ROW (\(x, b). (SOME b,x,[2])) (\(x, b). T) (\(x, b). x);
    PMATCH_ROW (\(_0,y). (y,_0,[2])) (\(_0, y). IS_SOME y) (\(_0, y). y)
   ]``

val nv = ``((l:num list), x : 'a option, xx:'a, s:'a option, z:'b)``

val t = ``case (xs : num list) of [] => x | _ => HD xs``
val t = ``case (xs : num list) of [] => x | _::_ => HD xs``
val nv = ``(xs: num list, x:num)``
val rc_arg = ([], NONE)

*)

(* [PMATCH_EXTEND_INPUT_CONV_GENCALL rc_arg nv t] rewrites the PMATCH term
   [t] with input [v] into a PMATCH term with input [nv].

   Every component of the tuple [v] has to occur (up to alpha-equivalence)
   among the components of the tuple [nv].  Components of [nv] that do not
   correspond to a component of [v] become additional columns, matched in
   each row by fresh variables; these are afterwards turned into wildcards
   by PMATCH_ROW_INTRO_WILDCARDS_CONV where possible.  Side conditions are
   proved with [rc_tac rc_arg].  The result theorem is assembled row by row
   (last row first) from PMATCH_EXTEND_BASE and PMATCH_EXTEND_BOTH.
   Raises UNCHANGED if [v] and [nv] are alpha-equivalent, or on any HOL_ERR
   (for example when a component of [v] is missing from [nv]). *)
fun PMATCH_EXTEND_INPUT_CONV_GENCALL rc_arg nv t = let
  val (v, rows) = dest_PMATCH t
  val _ = if aconv v nv then raise UNCHANGED else ()

  val (new_pat_tm, new_col_vars, new_col_subst, old_col_vars, f'_tm, nv_vars_pair) = let
    val nv_parts = pairSyntax.strip_pair nv
    val v_parts = pairSyntax.strip_pair v
    (* every old input part gets a genvar standing for its column *)
    val v_l = map (fn t => (t, genvar (type_of t))) v_parts

    val avoid = all_varsl [t, nv]
    val gen_fun = mk_var_gen "_" avoid

    (* Walk over the parts of [nv], accumulating (in reverse) the variable
       used for each part in the first argument and the fresh variables of
       the new columns in [nv_vars].  A part alpha-equivalent to a not yet
       consumed part of [v] reuses that part's genvar; every other part gets
       a fresh variable.  All parts of [v] must be consumed at the end.

       The lookup failure is handled locally rather than by a handler
       around the recursive call: such a handler would also catch failures
       further down and backtrack, at exponential cost, through
       alternatives (treating a matching part as new) that can never
       succeed, since not consuming a part of [v] never helps. *)
    fun compute_nv_l res_l nv_vars v_l [] =
        if (null v_l) then (res_l, nv_vars) else failwith "PMATCH_EXTEND_INPUT_CONV_AUX: part missing"

      | compute_nv_l res nv_vars v_l (p::nv_parts) =
        (case (SOME (pluck (fn (t, _) => aconv p t) v_l)
               handle HOL_ERR _ => NONE) of
             SOME ((_, v_v), v_l') =>
               (* p is an old input part: reuse its column variable *)
               compute_nv_l (v_v::res) nv_vars v_l' nv_parts
           | NONE => let
               (* p is new: invent a fresh column variable for it *)
               val vg = gen_fun (type_of p)
             in
               compute_nv_l (vg::res) (vg::nv_vars) v_l nv_parts
             end)

    val (res_l, nv_vars) = compute_nv_l [] [] v_l nv_parts

    (* the new pattern tuple, one column variable per part of nv *)
    val t0 = pairSyntax.list_mk_pair (List.rev res_l)

    (* pairs (column variable, part of nv) for the parts of nv that are
       variables *)
    val nv_parts_vars = filter (fn (v, n) => is_var n) (zip (List.rev res_l) nv_parts)
    val t1 = pairSyntax.list_mk_pair (map fst nv_parts_vars)
    val f'_tm = pairSyntax.mk_pabs(t0, t1)
    val nv_vars_pair = pairSyntax.list_mk_pair (map snd nv_parts_vars)
    val nv_parts_subst = map (fn (v, n) => v |-> n) nv_parts_vars
  in
    (t0, List.rev nv_vars, nv_parts_subst, List.map snd v_l, f'_tm, nv_vars_pair)
  end

  val thm_row = let
    val thm = FRESH_TY_VARS_RULE PMATCH_ROW_EXTEND_INPUT
    val thm = ISPEC v thm
    val thm = ISPEC nv thm
    val thm = ISPEC f'_tm thm
  in thm end

  (* proves [row v] equal to the corresponding row over the new input *)
  fun process_row_aux row = let
    val (pt, gt, rh) = dest_PMATCH_ROW row
    val (vars_tm, pt_b) = pairSyntax.dest_pabs pt

    (* a row binding only a unit variable has no old variables *)
    val (old_vars_tm, new_vars) =
      if ((type_of vars_tm) = one_ty) then (
        if (List.null new_col_vars) then
          (one_tm, [vars_tm])
        else (one_tm, new_col_vars)
      ) else (
        let val ov = pairSyntax.strip_pair vars_tm
        in (vars_tm, ov @ new_col_vars) end
      )

    (* old column variables |-> this row's pattern components *)
    val pat_vars_s0 = let
      val pt_ps = pairSyntax.strip_pair pt_b
    in map2 (fn v => fn t => (v |-> t)) old_col_vars pt_ps end

    val new_vars_s0 = filter (fn vp => free_in (#residue vp) row
      andalso is_var (subst pat_vars_s0 (#redex vp))) new_col_subst

    val new_vars_s = map (fn vp => subst pat_vars_s0 (#redex vp) |-> #residue vp) new_vars_s0
    val pat_vars_s = map (fn vp => #redex vp |-> subst new_vars_s (#residue vp)) pat_vars_s0

    val new_vars_tm = subst new_vars_s (pairSyntax.list_mk_pair new_vars)
    val pt_b' = subst pat_vars_s (subst new_vars_s new_pat_tm)

    val f_tm = pairSyntax.mk_pabs (new_vars_tm, subst new_vars_s old_vars_tm)
    val pt' = pairSyntax.mk_pabs (new_vars_tm, pt_b')

    val thm0 = let
      val thm = ISPEC f_tm thm_row
      val thm = ISPEC pt thm
      val thm = ISPEC (pairSyntax.mk_pabs (nv_vars_pair, gt)) thm
      val thm = ISPEC (pairSyntax.mk_pabs (nv_vars_pair, rh)) thm
      val thm = ISPEC pt' thm

      fun elim_conv_aux vs = (
        (pairTools.PABS_INTRO_CONV vs) THENC
        (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))
      )

      fun elim_conv vs = PMATCH_ROW_ARGS_CONV (elim_conv_aux vs)
      val thm = CONV_RULE ((RAND_CONV o RHS_CONV) (elim_conv new_vars_tm)) thm
    in
      thm
    end

    val pre_tm = fst (dest_imp (concl thm0))
(*  set_goal ([], pre_tm) *)
    val pre_thm = prove (pre_tm, rc_tac rc_arg)
    val thm1 = MP thm0 pre_thm

    (* make the left-hand side syntactically equal to [row v] *)
    val eq_tm = mk_eq (mk_comb (row, v), lhs (concl thm1))
    val eq_thm = prove (eq_tm, SIMP_TAC std_ss [])
    val thm2 = TRANS eq_thm thm1

    (* fix wildcards *)
    val thm3 = CONV_RULE (RHS_CONV (RATOR_CONV (PMATCH_ROW_INTRO_WILDCARDS_CONV))) thm2
  in
    thm3
  end

  fun process_row (row, thm) = let
    val row_thm = process_row_aux row
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, nv] PMATCH_EXTEND_BASE)
  val thm0 = List.foldl process_row base_thm (List.rev rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED


fun PMATCH_EXTEND_INPUT_CONV_GEN ssl = PMATCH_EXTEND_INPUT_CONV_GENCALL (ssl, NONE)
val PMATCH_EXTEND_INPUT_CONV = PMATCH_EXTEND_INPUT_CONV_GEN [];


(***********************************************)
(* Expand columns                              *)
(***********************************************)

(* Sometimes not all rows of a PMATCH have the same number of
   explicit columns. This can happen if some patterns are
   explicit pairs, while others are not. The following tries
   to expand columns into explicit ones. *)

(*
val t = ``
PMATCH (SOME y,x,l)
   [PMATCH_ROW (\x. (SOME 0,x,[])) (\x. T) (\x. x);
    PMATCH_ROW (\z. z) (\z. F) (\z. (FST (SND z)));
    PMATCH_ROW (\x. (SOME 3,x)) (\x. T) (\x. (FST x));
    PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y).
y)
   ]``
*)

(* [PMATCH_EXPAND_COLS_CONV t] makes the number of explicit columns of the
   rows of the PMATCH term [t] uniform.

   [col_no] is the maximum of the number of components of the input tuple
   and the number of components of every row pattern.  Consider a row
   whose pattern has fewer components and whose last component is one of
   the row's bound variables [l].  In such a row, [l] is replaced (in
   pattern, guard and right-hand side) by a tuple of fresh variables named
   after [l], which splits the product type of [l] into the missing
   columns.  The equality of old and new row is proved with
   [rc_tac ([], NONE)]; rows that cannot be expanded are kept (via REFL).
   Raises UNCHANGED if no row is expanded, or on any HOL_ERR. *)
fun PMATCH_EXPAND_COLS_CONV t = let
  val (v, rows) = dest_PMATCH t

  val col_no_v = length (pairSyntax.strip_pair v)
  val col_no = foldl (fn (r, m) => let
     val (pt', _, _) = dest_PMATCH_ROW r
     val (_, pt) = pairSyntax.dest_pabs pt'
     val m' = length (pairSyntax.strip_pair pt)
     val m'' = if m' > m then m' else m
  in m'' end) col_no_v rows

  (* fresh variables for the (col_no - cols + 1) components obtained by
     splitting the product type of l *)
  fun split_var avoid cols l = let
    fun splits acc no ty = if (no = 0) then List.rev (ty::acc) else
      let
        val (ty_s, ty') = pairSyntax.dest_prod ty
      in
        splits (ty_s::acc) (no - 1) ty'
      end

    val types = splits [] (col_no - cols) (type_of l)

    val var_basename = fst (dest_var l) handle HOL_ERR _ => "v"
    val gen_fun = mk_var_gen (var_basename ^ "_") avoid;
    val new_vars = map gen_fun types
  in
    new_vars
  end

  (* SOME (row v = row' v) for an expandable row, NONE otherwise *)
  fun PMATCH_ROW_EXPAND_COLS row = let
    val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row

    val vars = pairSyntax.strip_pair vars_tm
    val pts = pairSyntax.strip_pair pt
    val cols = length pts

    val _ = if (cols < col_no) then () else failwith "nothing to do"
    val l = last pts

    val _ = if (List.exists (aconv l) vars) then () else failwith "nothing to do"

    val avoids = vars @ free_vars pt @ free_vars gt @ free_vars rh
    val new_vars = split_var avoids cols l

    val sub = [l |-> pairSyntax.list_mk_pair new_vars]
    val pt' = Term.subst sub pt
    val gt' = Term.subst sub gt
    val rh' = Term.subst sub rh
    val vars' = pairSyntax.strip_pair (Term.subst sub vars_tm)

    val row' = mk_PMATCH_ROW_PABS vars' (pt', gt', rh')

    val eq_tm = mk_eq(row, row')
    val eq_thm = prove (eq_tm, rc_tac ([], NONE))
    val thm = AP_THM eq_thm v
  in
    SOME thm
  end handle HOL_ERR _ => NONE

  val rows = List.rev rows
  val row_thms = map PMATCH_ROW_EXPAND_COLS rows
  val _ = if (exists isSome row_thms) then () else raise UNCHANGED

  fun process_row ((row_thm_opt, row), thm) = let
    val row_thm = case row_thm_opt of
        NONE => REFL (mk_comb (row, v))
      | SOME thm => thm
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v] PMATCH_EXTEND_BASE)
  val thm0 = foldl process_row base_thm (zip row_thms rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED;


(***********************************************)
(* PMATCH_SIMP_CONV                            *)
(***********************************************)

(*
val t = ``
PMATCH (SOME y,x,l)
   [PMATCH_ROW (\x. (SOME 0,x,[])) (\y. T) (\x. x);
    PMATCH_ROW (\z. z) (\z. F) (\z. (FST (SND z)));
    PMATCH_ROW (\x. (SOME 3,x)) (\x. T) (\x. (FST x));
    PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
   ]``
*)

(* Normalisation of a PMATCH term.  It tries, in this order,
   PMATCH_CLEANUP_PVARS_CONV, PMATCH_FORCE_SAME_VARS_CONV,
   PMATCH_EXPAND_COLS_CONV and PMATCH_INTRO_WILDCARDS_CONV.  Each one is
   skipped if it fails or leaves the term unchanged. *)
val PMATCH_NORMALISE_CONV_AUX =
EVERY_CONV [
  TRY_CONV (QCHANGED_CONV PMATCH_CLEANUP_PVARS_CONV),
  TRY_CONV (QCHANGED_CONV PMATCH_FORCE_SAME_VARS_CONV),
  TRY_CONV (QCHANGED_CONV PMATCH_EXPAND_COLS_CONV),
  TRY_CONV (QCHANGED_CONV PMATCH_INTRO_WILDCARDS_CONV)
];

(* PMATCH_NORMALISE_CONV_AUX restricted to PMATCH terms; raises UNCHANGED
   on any other term. *)
fun PMATCH_NORMALISE_CONV t =
  if (is_PMATCH t) then PMATCH_NORMALISE_CONV_AUX t else raise UNCHANGED;

(* simpset fragment applying PMATCH_NORMALISE_CONV to PMATCH terms *)
val PMATCH_NORMALISE_ss =
    simpLib.conv_ss
      {name  = "PMATCH_NORMALISE_CONV",
       trace = 2,
       key   = SOME ([],``PMATCH (p:'a) (rows : ('a -> 'b option) list)``),
       conv  = K (K PMATCH_NORMALISE_CONV)}


(* First normalises the term (if possible).  Then it repeatedly applies
   the first of the following conversions that changes the term, until
   none does: PMATCH_CLEANUP_CONV_GENCALL, PMATCH_SIMP_COLS_CONV_GENCALL,
   PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL and
   PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL. *)
fun PMATCH_SIMP_CONV_GENCALL_AUX rc_arg =
(TRY_CONV PMATCH_NORMALISE_CONV_AUX) THENC
REPEATC (FIRST_CONV [
  QCHANGED_CONV (PMATCH_CLEANUP_CONV_GENCALL rc_arg),
  QCHANGED_CONV (PMATCH_SIMP_COLS_CONV_GENCALL rc_arg),
  QCHANGED_CONV (PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL rc_arg),
  QCHANGED_CONV (PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL false rc_arg)
]);

(* PMATCH_SIMP_CONV_GENCALL_AUX restricted to PMATCH terms *)
fun PMATCH_SIMP_CONV_GENCALL rc_arg t =
  if (is_PMATCH t) then PMATCH_SIMP_CONV_GENCALL_AUX rc_arg t else
  raise UNCHANGED

fun PMATCH_SIMP_CONV_GEN ssl = PMATCH_SIMP_CONV_GENCALL (ssl, NONE)

val PMATCH_SIMP_CONV = PMATCH_SIMP_CONV_GEN [];

fun PMATCH_SIMP_GEN_ss ssl =
  make_gen_conv_ss PMATCH_SIMP_CONV_GENCALL "PMATCH_SIMP_REDUCER" ssl

(* Side effect at load time: PMATCH_SIMP_ss is added to the stateful
   simpset via augment_srw_ss. *)
val PMATCH_SIMP_ss = name_ss "patternMatchesSimp" (PMATCH_SIMP_GEN_ss [])
val _ = BasicProvers.augment_srw_ss [PMATCH_SIMP_ss];


(* Cheaper variant of PMATCH_SIMP_CONV_GENCALL_AUX: no normalisation, and
   only cleanup and column simplification are repeated. *)
fun PMATCH_FAST_SIMP_CONV_GENCALL_AUX rc_arg =
REPEATC (FIRST_CONV [
  QCHANGED_CONV (PMATCH_CLEANUP_CONV_GENCALL rc_arg),
  QCHANGED_CONV (PMATCH_SIMP_COLS_CONV_GENCALL rc_arg)
]);

fun PMATCH_FAST_SIMP_CONV_GENCALL rc_arg t =
  if (is_PMATCH t) then PMATCH_FAST_SIMP_CONV_GENCALL_AUX rc_arg t else
  raise UNCHANGED

fun PMATCH_FAST_SIMP_CONV_GEN ssl = PMATCH_FAST_SIMP_CONV_GENCALL (ssl, NONE)

val PMATCH_FAST_SIMP_CONV = PMATCH_FAST_SIMP_CONV_GEN [];

fun PMATCH_FAST_SIMP_GEN_ss ssl =
  make_gen_conv_ss PMATCH_FAST_SIMP_CONV_GENCALL "PMATCH_FAST_SIMP_REDUCER" ssl

val PMATCH_FAST_SIMP_ss = name_ss "patternMatchesFastSimp" (PMATCH_FAST_SIMP_GEN_ss [])


(***********************************************)
(* Remove double var bindings                  *)
(***********************************************)

(* [force_unique_vars s no_change avoid t] traverses [t] from left to right.
   Every variable occurrence that is not in [no_change] (the variables
   bound by lambdas inside [t]) is replaced by [variant avoid x], and the
   chosen name is added to [avoid].  Hence the first occurrence of a
   variable not already in [avoid] keeps its name, while later occurrences
   get fresh names.  Whenever the new variable differs from the old one,
   the pair (new, old) is prepended to [s].  Returns the updated [s], the
   updated [avoid] and the new term. *)
fun force_unique_vars s no_change avoid t =
  case Psyntax.dest_term t of
      Psyntax.VAR (_, _) =>
      if (List.exists (aconv t) no_change) then (s, avoid, t) else
      let
        val v' = variant avoid t
        val avoid' = v'::avoid
        val s' = if (aconv v' t) then s else ((v', t)::s)
      in (s', avoid', v') end
    | Psyntax.CONST _ => (s, avoid, t)
    | Psyntax.LAMB (v, t') => let
        val (s', avoid', t'') = force_unique_vars s (v::no_change)
            (v::avoid) t'
      in
        (s', avoid', mk_abs (v, t''))
      end
    | Psyntax.COMB (t1, t2) => let
        val (s', avoid', t1') = force_unique_vars s no_change avoid t1
        val (s'', avoid'', t2') = force_unique_vars s' no_change avoid' t2
      in
        (s'', avoid'', mk_comb (t1', t2'))
      end;

(*
val row = ``PMATCH_ROW (\ (x,y). (x, SOME y, SOME x, SOME z, (x+z)))
                      (\ (x, y). P x y) (\ (x, y). f x y)``

val row = ``PMATCH_ROW (\ (x,y). (x, SOME y, SOME z, SOME z, (z+z)))
                      (\ (x, y). P x y) (\ (x, y). f x y)``
*)

(* [PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg row] rewrites a
   PMATCH_ROW whose pattern uses a variable more than once.  Every
   occurrence after the first is renamed to a fresh variable by
   [force_unique_vars].  Because [avoid] starts with the free variables of
   the pattern abstraction, occurrences of variables free in the pattern
   are renamed too.  The fresh variables are added to the row's bound
   variables, and equations (fresh = old) are conjoined in front of the
   guard.  The precondition is discharged with [rc_elim_precond rc_arg],
   and the auxiliary equation with [rc_tac rc_arg].
   Raises UNCHANGED if [row] is not a PMATCH_ROW, if nothing has to be
   renamed, or on any HOL_ERR. *)
fun PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg row = let
  val _ = if not (is_PMATCH_ROW row) then raise UNCHANGED else ()
  val (p_t, g_t, r_t) = dest_PMATCH_ROW row
  val (vars_tm, p_tb) = pairSyntax.dest_pabs p_t
  val vars = pairSyntax.strip_pair vars_tm

  val (new_binds, _, p_tb') = force_unique_vars [] [] (free_vars p_t) p_tb
  val _ = if List.null new_binds then raise UNCHANGED else ()

  val vars' = vars @ (List.map fst new_binds)
  (* placeholders for guard and right-hand side, instantiated below *)
  val g_v = genvar (type_of g_t)
  val r_v = genvar (type_of r_t)


  val g_t' = list_mk_conj ((List.map mk_eq new_binds)@[mk_comb (g_v, vars_tm)])
  val r_t' = mk_comb (r_v, vars_tm)

  val row' = mk_PMATCH_ROW_PABS vars' (p_tb', g_t', r_t')
  val row0 = mk_PMATCH_ROW (p_t, g_v, r_v)

  val thm0_tm = mk_eq (row0, row')
  val thm0 = let
    val thm0 = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_DOUBLE_BINDS_THM
    (* maps the old bound variables to the extended variable tuple, with
       every fresh variable replaced by the variable it renames *)
    val g_tm = pairSyntax.mk_pabs (vars_tm,
       subst (List.map (fn (v, v') => (v |-> v')) new_binds)
         (pairSyntax.list_mk_pair vars'))
    val thm1 = ISPEC g_tm thm0
    val thm2 = PART_MATCH rand thm1 thm0_tm
    val thm3 = rc_elim_precond rc_arg thm2
  in
    thm3
  end

  val thm1 = INST [(g_v |-> g_t), (r_v |-> r_t)] thm0

  val thm1a_tm = mk_eq (row, lhs (concl thm1))
  val thm1a = prove (thm1a_tm, rc_tac rc_arg)

  val thm2 = TRANS thm1a thm1

  val thm3 = CONV_RULE (RHS_CONV (DEPTH_CONV pairLib.GEN_BETA_CONV)) thm2

  val thm4 = CONV_RULE (RHS_CONV (RATOR_CONV (RAND_CONV (REWRITE_CONV [])))) thm3
in
  thm4
end handle HOL_ERR _ => raise UNCHANGED

(* apply the row-level conversion to the rows of a PMATCH term *)
fun PMATCH_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg t =
  PMATCH_ROWS_CONV (PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL
    rc_arg) t

fun PMATCH_REMOVE_DOUBLE_BIND_CONV_GEN ssl =
  PMATCH_REMOVE_DOUBLE_BIND_CONV_GENCALL (ssl, NONE)

val PMATCH_REMOVE_DOUBLE_BIND_CONV = PMATCH_REMOVE_DOUBLE_BIND_CONV_GEN [];

(* The simpset uses the row-level conversion, so it fires directly on
   PMATCH_ROW subterms. *)
fun PMATCH_REMOVE_DOUBLE_BIND_GEN_ss ssl =
  make_gen_conv_ss PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL "PMATCH_REMOVE_DOUBLE_BIND_REDUCER" ssl

val PMATCH_REMOVE_DOUBLE_BIND_ss = PMATCH_REMOVE_DOUBLE_BIND_GEN_ss []


(***********************************************)
(* Remove a GUARD                              *)
(***********************************************)

(*
val t = ``case (x, y) of
  | (x, 2) when EVEN x => x + x
  | (SUC x, y) when ODD x => y + x + SUC x
  | (SUC x, 1) => x
  | (x, _) => x+3``

val rc_arg = ([], NONE)
val rows = 0
*)

(* [PMATCH_REMOVE_GUARD_AUX rc_arg t] eliminates the guard of the first row
   of the PMATCH term [t] whose guard is neither T nor F.  The rows are
   split into those before it, the row itself and those after it.
   GUARDS_ELIM_THM is instantiated accordingly, its precondition is
   discharged with [rc_elim_precond rc_arg], and the result is post-processed
   with [fix_appends].  Finally the variable names of one row on the
   right-hand side are normalised with PMATCH_ROW_FORCE_SAME_VARS_CONV
   (presumably the row produced from the split row; TODO confirm against
   GUARDS_ELIM_THM).  Raises UNCHANGED if no such row exists, or on any
   HOL_ERR. *)
fun PMATCH_REMOVE_GUARD_AUX rc_arg t = let
  val (v, rows) = dest_PMATCH t

  fun find_row_to_split rs1 rs = case rs of
      [] => raise UNCHANGED (* nothing found *)
    | (r:: rs') => let
        val (_, _, g, _) = dest_PMATCH_ROW_ABS r
        (* terms are compared with aconv, as everywhere else in this file *)
        val g_simple = ((aconv g T) orelse (aconv g F))
      in
        if g_simple then
          find_row_to_split (r::rs1) rs'
        else let
          val r_ty = type_of r
          val rs1_tm = listSyntax.mk_list (List.rev rs1, r_ty)
          val rs2_tm = listSyntax.mk_list (rs', r_ty)
        in
          (rs1_tm, r, rs2_tm)
        end
      end

  val (rs1, r, rs2) = find_row_to_split [] rows

  val thm = let
    val thm0 = FRESH_TY_VARS_RULE GUARDS_ELIM_THM
    val (p_tm, g_tm, r_tm) = dest_PMATCH_ROW r
    val thm1 = ISPECL [v, rs1, rs2, p_tm, g_tm, r_tm] thm0

    val thm2 = rc_elim_precond rc_arg thm1
    val thm3 = fix_appends rc_arg t thm2
  in
    thm3
  end

  val thm2 = CONV_RULE (RHS_CONV (RAND_CONV (RAND_CONV (RATOR_CONV (RAND_CONV PMATCH_ROW_FORCE_SAME_VARS_CONV))))) thm

in
  thm2
end handle HOL_ERR _ => raise UNCHANGED



(* [PMATCH_REMOVE_GUARDS_CONV_GENCALL rc_arg t] repeatedly applies
   PMATCH_REMOVE_GUARD_AUX.  It then simplifies the result with std_ss,
   the simpsets in [fst rc_arg] and PMATCH_SIMP_GEN_ss.  Any HOL_ERR
   becomes UNCHANGED. *)
fun PMATCH_REMOVE_GUARDS_CONV_GENCALL rc_arg t = let
  val thm0 = REPEATC (PMATCH_REMOVE_GUARD_AUX rc_arg) t
  val m_ss = simpLib.merge_ss (fst rc_arg)
  val c = SIMP_CONV (std_ss ++ m_ss ++
                     PMATCH_SIMP_GEN_ss (fst rc_arg)) []
  val thm1 = CONV_RULE (RHS_CONV c) thm0
in
  thm1
end handle HOL_ERR _ => raise UNCHANGED

fun PMATCH_REMOVE_GUARDS_CONV_GEN ssl = PMATCH_REMOVE_GUARDS_CONV_GENCALL (ssl, NONE)

val PMATCH_REMOVE_GUARDS_CONV = PMATCH_REMOVE_GUARDS_CONV_GEN [];

fun PMATCH_REMOVE_GUARDS_GEN_ss ssl =
  make_gen_conv_ss PMATCH_REMOVE_GUARDS_CONV_GENCALL "PMATCH_REMOVE_GUARDS_REDUCER" ssl

val PMATCH_REMOVE_GUARDS_ss = PMATCH_REMOVE_GUARDS_GEN_ss []



(***********************************************)
(* PATTERN COMPILATION                         *)
(***********************************************)

(* A column heuristic is a function that chooses the
   next column to perform a case split on, returning its
   (0-based) index.
   It gets the list of columns of the pattern match; each
   column consists of its input value and the list of the
   patterns of the column in each row.
   Each pattern is represented as a pair of the list of the
   row's pattern variables and the real pattern.
*)
type column = (term * (term list * term) list)
type column_heuristic = column list -> int

(* Heuristic: always split on the leftmost column. *)
val colHeu_first_col : column_heuristic = (fn _ => 0)

(* Heuristic: always split on the rightmost column. *)
val colHeu_last_col : column_heuristic = (fn cols => List.length cols - 1)

(* A ranking function scores a single column; higher scores are preferred. *)
type column_ranking_fun = (term * (term list * term) list) -> int

(* [colHeu_rank rankL] turns a list of ranking functions into a column
   heuristic.  Starting with all columns as candidates, each ranking
   function in turn keeps only the candidates with the maximal score; the
   narrowing stops early as soon as at most one candidate remains.  The
   leftmost surviving candidate is chosen; if none survives, 0 is returned. *)
fun colHeu_rank (rankL : column_ranking_fun list) = (fn colL =>
  let
    (* keep the (index, column) candidates scoring best under [rank] *)
    fun keep_best rank cands =
      let
        val scored = List.map (fn cand => (cand, rank (snd cand))) cands
        val best = List.foldl (fn ((_, r), m) => if r > m then r else m)
                              (snd (hd scored)) (tl scored)
      in
        List.map fst (List.filter (fn (_, r) => r = best) scored)
      end

    fun narrow [] cands = cands
      | narrow (rank :: ranks) cands =
          (case cands of
               [] => []
             | [_] => cands
             | _ => narrow ranks (keep_best rank cands))
  in
    case narrow rankL (Lib.enumerate 0 colL) of
        (i, _) :: _ => i
      | [] => 0 (* something went wrong, should not happen *)
  end) : column_heuristic


(* [pat_is_bound_var (vs, p)] holds if the pattern [p] is a variable
   occurring in the row's variable list [vs]. *)
fun pat_is_bound_var (vs, p) = is_var p andalso mem p vs

(* ranking functions *)

(* 1 if the column's pattern in the first row is not a row variable,
   0 otherwise (also for a column without rows). *)
fun colRank_first_row (_:term, rows) =
  (case rows of
       first :: _ => if pat_is_bound_var first then 0 else 1
     | [] => 0);

(* Like [colRank_first_row], but the first row only scores 1 if its pattern
   is headed by a constructor of the constructor family that [db] finds
   for the column.  Scores 0 if no family is found or if looking up the
   family's constructors fails. *)
fun colRank_first_row_constr db (_, rows) =
  case rows of
      [] => 0
    | first :: _ =>
        if pat_is_bound_var first then 0
        else
          (case pmatch_compile_db_compile_cf db rows of
               NONE => 0
             | SOME cf =>
                 (let
                    val (_, constrL) = constructorFamily_get_constructors cf
                    val head = fst (strip_comb (snd first))
                  in
                    if op_mem same_const head (List.map fst constrL) then 1 else 0
                  end handle HOL_ERR _ => 0));

(* The number of leading rows whose pattern in the column is not a
   variable (bound or not). *)
val colRank_constr_prefix : column_ranking_fun = (fn (_, rows) =>
  let
    fun count acc [] = acc
      | count acc ((_, p) :: rest) =
          if is_var p then acc else count (acc + 1) rest
  in
    count 0 rows
  end)


(* [col_get_constr_set db col] looks up a constructor family for the
   column's patterns in [db].  If one is found, it returns
   SOME (used, family, exh), where [family] lists the family's
   constructors, [used] is the duplicate-free list of those constructors
   that head some row pattern, and [exh] is the boolean flag returned by
   constructorFamily_get_constructors (presumably exhaustiveness; TODO
   confirm).  Returns NONE otherwise. *)
fun col_get_constr_set db (_, rows) =
  Option.map
    (fn cf =>
       let
         val (exh, constrL) = constructorFamily_get_constructors cf
         val family = List.map fst constrL
         val heads = List.map (fn (_, p) => fst (strip_comb p)) rows
         val used = Lib.mk_set (List.filter (fn c => op_mem same_const c family) heads)
       in
         (used, family, exh)
       end)
    (pmatch_compile_db_compile_cf db rows)

(* The duplicate-free list of the column's (variables, pattern) entries
   whose pattern is not a row variable. *)
fun col_get_nonvar_set (_, rows) =
  Lib.mk_set (List.filter (not o pat_is_bound_var) rows)

(* Prefers columns with few branches; the score is the negated branch
   count.  With a constructor family, the branch count is the number of
   used constructors, plus one unless [exh] holds, plus one if not all
   constructors of the family are used.  Without a family, it is the
   number of distinct non-variable entries plus two. *)
fun colRank_small_branching_factor db : column_ranking_fun = (fn col =>
  let
    fun penalty ok = if ok then 0 else 1
    val branches =
      case col_get_constr_set db col of
          NONE => length (col_get_nonvar_set col) + 2
        | SOME (used, family, exh) =>
            length used + penalty exh + penalty (length used = length family)
  in
    ~ branches
  end)

(* Prefers columns whose used constructors take few arguments in total
   (negated sum of their arities); columns without a constructor family
   score 0. *)
fun colRank_arity db : column_ranking_fun = (fn col =>
  case col_get_constr_set db col of
      NONE => 0
    | SOME (used, _, _) =>
        let
          fun arity c = length (fst (strip_fun (type_of c)))
        in
          ~ (List.foldl (fn (c, total) => total + arity c) 0 used)
        end)


(* heuristics defined using ranking functions *)
val colHeu_first_row = colHeu_rank [colRank_first_row]
val colHeu_constr_prefix = colHeu_rank [colRank_constr_prefix]

fun colHeu_qba db =
  colHeu_rank [colRank_constr_prefix,
               colRank_small_branching_factor db,
               colRank_arity db]

fun colHeu_cqba db =
  colHeu_rank [colRank_first_row_constr db,
               colRank_constr_prefix,
               colRank_small_branching_factor db,
               colRank_arity db]

(* The default heuristic: [colHeu_qba] on the global compile database,
   which is dereferenced anew at every call. *)
fun colHeu_default cols = colHeu_qba (!thePmatchCompileDB) cols


(* Now we can define a case-split function that
performs case-splits using such heuristics. *)

(*
val t = ``case (a,x,xs) of
    | (NONE,x,[]) when x > 5 => x
    | (NONE,x,_) => SUC x``

val t = ``case (a,x,xs) of
    | (NONE,x,[]) => x
    | (NONE,x,[2]) => x
    | (NONE,x,[v18]) => 3
    | (NONE,x,v12::v16::v17) => 3
    | (y,x,z,zs) .| (SOME y,x,[z]) => (x + 5 + z)
    | (y,v23,v24) .| (SOME y,0,v23::v24) => (v23 + y)
    | (y,z,v23) .| (SOME y,SUC z,[v23]) when (y > 5) => 3
    | (y,z) .| (SOME y,SUC z,[1; 2]) => (y + z)
  ``
*)

(* literal_case_CONV c tt
   If tt is a literal_case application, apply c underneath the lambda of
   its function argument (literal_case (\x. body) v: c is applied to body).
   Otherwise apply c to tt itself. *)
fun literal_case_CONV c tt =
  if boolSyntax.is_literal_case tt then
    RATOR_CONV (RAND_CONV (ABS_CONV c)) tt
  else
    c tt

(* Congruence rule for literal_case. It lets the simplifier rewrite only
   the scrutinee v and leaves the function argument f untouched. It is
   passed as a Cong to SIMP_CONV in PMATCH_CASE_SPLIT_AUX below. *)
val literal_cong_stop = prove(
  ``(v = v') ==> (literal_case (f:'a -> 'b) v = literal_case f v')``,
  SIMP_TAC std_ss [])

(* PMATCH_CASE_SPLIT_AUX rc_arg col_no expand_thm t

   Case-split the PMATCH term t on its column col_no (0-based).

   - The col_no-th component of the matched tuple is replaced by a fresh
     genvar arg_v, and t is abstracted over arg_v, giving ff.
   - expand_thm is specialised first to ff and then to the original
     component, and the left-hand side is beta-reduced. So expand_thm is
     expected to have the shape !f x. f x = ... (a case expansion).
   - On the right-hand side, every application of an abstraction over
     arg_v is beta-reduced and cleaned up with the PMATCH cleanup and
     column-simplification conversions, plus PMATCH_INCOMPLETE_def.
   - If does_conv_loop reports that the result just reproduces the input,
     a full simp run (with literal_cong_stop) is tried. If that still
     loops, UNCHANGED is raised. *)
fun PMATCH_CASE_SPLIT_AUX rc_arg col_no expand_thm t = let
  val (v, _) = dest_PMATCH t
  val vs = pairSyntax.strip_pair v

  (* the component we split on, and a fresh variable standing for it *)
  val arg = el (col_no+1) vs
  val arg_v = genvar (type_of arg)
  val vs' = pairSyntax.list_mk_pair (fst (
     replace_element vs col_no [arg_v]))

  (* ff = \arg_v. PMATCH (.., arg_v, ..) rows *)
  val ff = let
    val (x, xs) = strip_comb t
    val t' = list_mk_comb (x, vs' :: (tl xs))
  in
    mk_abs (arg_v, t')
  end

  val thm0 = ISPEC arg (ISPEC ff expand_thm)
  val thm1 = CONV_RULE (LHS_CONV BETA_CONV) thm0

  val c' = REPEATC (
    TRY_CONV (QCHANGED_CONV (PMATCH_CLEANUP_CONV_GENCALL rc_arg)) THENC
    TRY_CONV (QCHANGED_CONV (PMATCH_SIMP_COLS_CONV_GENCALL rc_arg)) THENC
    TRY_CONV (REWR_CONV PMATCH_INCOMPLETE_def)
  );

  (* Only fire on redexes (\arg_v. ...) e that were introduced via ff. *)
  fun c tt = let
    val _ = let
      val (t0, _) = dest_comb tt
      val (v', _) = dest_abs t0
    in
      if (aconv arg_v v') then () else failwith "not a new position"
    end
  in
    (BETA_CONV THENC c') tt
  end;

  val thm2 = CONV_RULE (RHS_CONV (TOP_SWEEP_CONV c)) thm1

  (* check whether it got simpler, if not try full simp including
     propagating case information *)
  val thm3 = if (does_conv_loop thm2) then let
      val thm3 = CONV_RULE (RHS_CONV (literal_case_CONV (SIMP_CONV (
         (std_ss ++ simpLib.merge_ss (fst rc_arg) ++
          PMATCH_SIMP_GEN_ss (fst rc_arg)))
         [PMATCH_INCOMPLETE_def, Cong literal_cong_stop]))) thm2
      val _ = if (does_conv_loop thm3) then raise UNCHANGED else ()
    in thm3 end
    else thm2
in
  thm3
end

(*
val t = t'
val col_no = 1
val rc_arg = ([], NONE)
val gl = []
val callback_opt = NONE
val db = !thePmatchCompileDB
val col_heu = colHeu_default
val t = ``case x of 3 => 1 | _ => 0``
*)

(* PMATCH_CASE_SPLIT_CONV_GENCALL_STEP (gl, callback_opt) db col_heu t

   Perform one case-split step on the PMATCH term t.
   - col_heu picks a column.
   - If db cannot compile that column, the column is dropped and the search
     is repeated on the rest. The index found is mapped back to a position
     in the full column list.
   - The simpset returned by db for the chosen column is added to gl before
     calling PMATCH_CASE_SPLIT_AUX.
   Raises UNCHANGED when t is not a PMATCH, when no column can be
   compiled, or when the result loops (see does_conv_loop). *)
fun PMATCH_CASE_SPLIT_CONV_GENCALL_STEP (gl, callback_opt) db col_heu t = let
  val _ = if (is_PMATCH t) then () else raise UNCHANGED

  fun find_col cols = if (List.null cols) then raise UNCHANGED else let
    val col_no = col_heu cols
    val (_, col) = el (col_no+1) cols
    val res = pmatch_compile_db_compile db col
  in
    case res of
        SOME (expand_thm, _, expand_ss) => (col_no, expand_thm, expand_ss)
      | NONE => let
          val (cols', _) = replace_element cols col_no []
          val (col_no', expand_thm, expand_ss) = find_col cols'
          (* cols' lacks column col_no, so later indices shift by one *)
          val col_no'' = if (col_no' < col_no) then col_no' else col_no'+1
        in
          (col_no'', expand_thm, expand_ss)
        end
  end

  val (col_no, expand_thm, expand_ss) = find_col (dest_PMATCH_COLS t)
  val thm1 = QCHANGED_CONV (PMATCH_CASE_SPLIT_AUX
    (expand_ss::gl, callback_opt) col_no expand_thm) t

  (* check whether it got simpler *)
  val _ = if (does_conv_loop thm1) then raise UNCHANGED else ()
in
  thm1
end


(* The polymorphic pair_CASE constant, instantiated by type below. *)
val pair_CASE_tm = mk_const ("pair_CASE", ``:'a # 'b -> ('a -> 'b -> 'c) -> 'c``)

(* PMATCH_CASE_SPLIT_CONV_GENCALL rc_arg db col_heu t

   Turn the PMATCH term t into nested case-splits.
   1. Simplify t with PMATCH_SIMP_CONV_GENCALL; if that fails or changes
      nothing, use reflexivity.
   2. Rewrite the matched value v as nested pair_CASE expressions that bind
      each of the column components to a fresh variable, so that the
      PMATCH is applied to a tuple of variables. This step is proved with
      pair_CASE_def.
   3. Apply PMATCH_CASE_SPLIT_CONV_GENCALL_STEP top-down with
      TOP_SWEEP_CONV.
   Raises UNCHANGED if the overall result loops. If no PMATCH subterm
   remains, REMOVE_REBIND_CONV is applied to the right-hand side. *)
fun PMATCH_CASE_SPLIT_CONV_GENCALL rc_arg db col_heu t = let
  val thm0 = PMATCH_SIMP_CONV_GENCALL rc_arg t
             handle HOL_ERR _ => REFL t
                  | UNCHANGED => REFL t
  val t' = rhs (concl thm0)

  val cols = dest_PMATCH_COLS t'
  val col_no = length cols
  val (v, _) = dest_PMATCH t'
  val rows_tm = rand t'

  (* mk_case_split_tm avoid acc n w
     Destructs w, a value of a right-nested product type with n
     components, through n-1 nested pair_CASE applications. Fresh
     variables are chosen away from avoid. acc collects the bound first
     components in reverse order. At the innermost level it builds the
     PMATCH over the tuple of collected variables. *)
  fun mk_case_split_tm avoid acc n w =
    if (n <= 1) then let
      val vs = List.rev (w::acc)
      val p = pairSyntax.list_mk_pair vs
    in
      mk_PMATCH p rows_tm
    end
    else let
      val (ty1, ty2) = pairSyntax.dest_prod (type_of w)
      val v1 = variant avoid (mk_var ("v", ty1))
      val v2 = variant (v1::avoid) (mk_var ("v", ty2))

      val case_c = inst [alpha |-> ty1, beta |-> ty2, gamma |-> type_of t]
                        pair_CASE_tm
      val inner = mk_case_split_tm (v1::v2::avoid) (v1::acc) (n-1) v2
    in
      mk_comb (mk_comb (case_c, w), list_mk_abs ([v1, v2], inner))
    end

  val t'' = mk_case_split_tm (free_vars t') [] col_no v
  val thm1_tm = mk_eq (t', t'')
  val thm1 = prove (thm1_tm, SIMP_TAC std_ss [pairTheory.pair_CASE_def])

  val thm2 = CONV_RULE (RHS_CONV (TOP_SWEEP_CONV (
       PMATCH_CASE_SPLIT_CONV_GENCALL_STEP rc_arg db col_heu))) thm1

  val thm3 = TRANS thm0 thm2

  (* check whether it got simpler *)
  val _ = if (does_conv_loop thm3) then raise UNCHANGED else ()

  val thm4 = if (has_subterm is_PMATCH (rhs (concl thm3))) then
      thm3
    else
      CONV_RULE (RHS_CONV REMOVE_REBIND_CONV) thm3
in
  thm4
end

(* Convenience wrappers: ssl is a list of extra simpsets. No callback is
   passed. The database and heuristic default to thePmatchCompileDB and
   colHeu_default. *)
fun PMATCH_CASE_SPLIT_CONV_GEN ssl = PMATCH_CASE_SPLIT_CONV_GENCALL (ssl, NONE)

fun PMATCH_CASE_SPLIT_CONV_HEU col_heu t =
  PMATCH_CASE_SPLIT_CONV_GEN [] (!thePmatchCompileDB) col_heu t

fun PMATCH_CASE_SPLIT_CONV t =
  PMATCH_CASE_SPLIT_CONV_HEU colHeu_default t

(* Simpset fragments that run the case-split conversion as a reducer. *)
fun PMATCH_CASE_SPLIT_GEN_ss ssl db col_heu =
  make_gen_conv_ss (fn rc_arg =>
    PMATCH_CASE_SPLIT_CONV_GENCALL rc_arg db col_heu)
    "PMATCH_CASE_SPLIT_REDUCER" ssl

fun PMATCH_CASE_SPLIT_HEU_ss col_heu =
  PMATCH_CASE_SPLIT_GEN_ss [] (!thePmatchCompileDB) col_heu

fun PMATCH_CASE_SPLIT_ss () =
  PMATCH_CASE_SPLIT_HEU_ss colHeu_default


(***********************************************)
(* COMPUTE CASE-DISTINCTION based on pats      *)
(***********************************************)

(*
val t = ``
  case (a,x,xs) of
    | (NONE,_,[]) => 0
    | (NONE,x,[]) when x < 10 => x
    | (NONE,x,[2]) => x
    | (NONE,x,[v18]) => 3
    | (NONE,_,[_;_]) => x
    | (NONE,x,v12::v16::v17) => 3
    | (SOME y,x,[z]) => x + 5 + z
    | (SOME y,0,v23::v24) => (v23 + y)
    | (SOME y,SUC z,[v23]) when y > 5 => 3
    | (SOME y,SUC z,[1; 2]) => y + z``;

  val (v, rows) = dest_PMATCH t
  val pats = List.map (#1 o dest_PMATCH_ROW) rows

  val col_heu = colHeu_default
  val db = !thePmatchCompileDB

  val pats = [``\(x:num). 2``]
  val pats = [``\(x:num). [2;3;4]``]
*)

local

  (* Used to refine an existential by a case distinction on its variable. *)
  val case_dist_exists_thm = prove (``!Q. (
     (!(x:'a). Q x) ==>
     !P. (?x. P x) = (?x. Q x /\ P x))``,
    SIMP_TAC std_ss []);

  (* Labels distribute over disjunction. *)
  val label_over_or_thm = prove (
    ``(lbl :- (t1 \/ t2)) <=> (lbl :- t1) \/ (lbl :- t2)``,
    REWRITE_TAC[markerTheory.label_def]);

  (* Pick a column with col_heu and look up an nchotomy theorem for it in
     db. If db has none, drop the column and retry on the rest.
     Returns the column variable v and the nchotomy specialised to v.
     Fails with HOL_ERR once no columns are left. *)
  fun find_nchotomy_for_cols db col_heu cols = let
    val _ = if (List.null cols) then failwith "compile failed" else ()
    val col_no = col_heu cols
    val (v, col) = el (col_no+1) cols
    val nchot_thm_opt = pmatch_compile_db_compile_nchotomy db col
  in
    case nchot_thm_opt of
        SOME nchot_thm => (v, ISPEC v nchot_thm)
      | NONE => let
          val (cols', _) = replace_element cols col_no []
        in
          find_nchotomy_for_cols db col_heu cols'
        end
  end


  (* Build the starting point of the compilation.
     - Fresh variables vs are made, one per component of the tuple in the
       first pattern.
     - pats are split into columns relative to the tuple of vs.
     - The theorem  !x. lbl :- ?vs. x = (vs)  is proved for a fresh label.
     NOTE(review): hd on an empty pats list raises the Basis exception
     Empty rather than HOL_ERR. *)
  fun mk_initial_state var_gen lbl_gen pats = let
    val (_, p) = pairSyntax.dest_pabs (hd pats)
    val cs = pairLib.strip_pair p
    val vs = List.map (fn p => var_gen (type_of p)) cs
    val initial_value = pairLib.list_mk_pair vs
    val cols = dest_PATLIST_COLS initial_value pats

    val lbl = lbl_gen ()
    val initial_thm = let
      val x_tm = mk_var ("x", type_of initial_value)
      val tm = mk_forall (x_tm, markerSyntax.mk_label (lbl,
                 list_mk_exists (vs, mk_eq (x_tm, initial_value))))
    in
      prove (tm,
        SIMP_TAC std_ss [pairTheory.FORALL_PROD, markerTheory.label_def])
    end
  in
    (initial_thm, cols, lbl)
  end


  (* Rename the existential variables of every disjunct of nthm to fresh
     ones and put a fresh label on each disjunct.
     For each disjunct with a conjunct whose left-hand side is v, record
     (label, head constant of that conjunct's right-hand side, existential
     variables).
     Returns the relabelled theorem and the records in the order the
     disjuncts were visited. *)
  fun compute_cases_info var_gen lbl_gen v nthm = let
    val disjuncts = ref ([] : (string * term * term list) list)

    (* val d = el 2 ds *)
    fun process_disj d = let
      val lbl = lbl_gen ()

      (* intro fresh vars *)
      val d_thm = let
        val (evs, d_b) = strip_exists d
        val s = List.map (fn ev => (ev |-> var_gen (type_of ev))) evs
        val evs = List.map (Term.subst s) evs
        val d_b = Term.subst s d_b
        val d' = list_mk_exists (evs, d_b)
      in
        ALPHA d d'
      end

      (* add label *)
      val ld_thm = RIGHT_CONV_RULE (add_labels_CONV [lbl]) d_thm

      (* figure out constructor and free variables and add them
         to list of disjuncts; best effort, disjuncts without a
         matching conjunct are simply not recorded *)
      val _ = let
        val d' = rhs (concl d_thm)
        val (evs, b) = strip_exists d'
        val b_conjs = strip_conj b
        val main_conj = first (fn c' =>
           aconv (lhs c') v handle HOL_ERR _ => false) b_conjs
        val r = rhs main_conj
        val (c, _) = strip_comb_bounded (List.length evs) r
      in
        disjuncts := (lbl, c, evs) :: !disjuncts
      end handle HOL_ERR _ => ()
    in
      ld_thm
    end handle HOL_ERR _ => raise UNCHANGED

    (* val ds = strip_disj (concl nthm) *)
    val nthm' = CONV_RULE (ALL_DISJ_CONV process_disj) nthm
  in
    (nthm', List.rev (!disjuncts))
  end

  (* For  ?x. A /\ B : remove all labels from A and from B and put the
     combined label list (those of A first, then those of B) on the whole
     term. *)
  fun exists_left_and_label_CONV t = let
    val (lbls_left, _) = (strip_labels o fst o dest_conj o snd o dest_exists) t
    val (lbls_right, _) = (strip_labels o snd o dest_conj o snd o dest_exists) t

    val c_remove = QUANT_CONV (BINOP_CONV (REPEATC markerLib.DEST_LABEL_CONV))
  in
    (c_remove THENC (add_labels_CONV (lbls_left @ lbls_right))) t
  end

  (* Expand the existential disjunct d_tm by the case distinction on v given
     by nthm_expand:
     - move v to the front of the existential prefix;
     - rewrite with nthm_expand;
     - distribute /\ over \/ and push the existentials into the disjuncts;
     - merge labels, then move and unwind the existentials in each
       disjunct. *)
  fun expand_disjunction_CONV v nthm_expand d_tm = let
    val thm00 = RESORT_EXISTS_CONV (fn vs =>
       let val (v', vs') = pick_element (aconv v) vs in
         (v'::vs')
       end) d_tm

    val thm01a = HO_PART_MATCH (lhs o snd o strip_forall) nthm_expand
                   (rhs (concl thm00))
    val thm01 = TRANS thm00 thm01a

    val thm02 = RIGHT_CONV_RULE (PURE_REWRITE_CONV [RIGHT_AND_OVER_OR]) thm01
    val thm03 =
       RIGHT_CONV_RULE (DESCEND_CONV BINOP_CONV (TRY_CONV EXISTS_OR_CONV))
         thm02

    val thm04 = RIGHT_CONV_RULE (ALL_DISJ_CONV exists_left_and_label_CONV) thm03

    val LEFT_RIGHT_AND_LIST_EXISTS_CONV =
      DESCEND_CONV QUANT_CONV
        (RIGHT_AND_EXISTS_CONV ORELSEC LEFT_AND_EXISTS_CONV)
    val thm05 = RIGHT_CONV_RULE (ALL_DISJ_CONV (strip_labels_CONV
                  (STRIP_QUANT_CONV LEFT_RIGHT_AND_LIST_EXISTS_CONV))) thm04
    val thm06 = RIGHT_CONV_RULE (ALL_DISJ_CONV (strip_labels_CONV
                  (Unwind.UNWIND_EXISTS_CONV))) thm05
  in
    thm06
  end

  (* Under the outer universal quantifier of thm, expand every disjunct
     labelled lbl using the case distinction nthm' on v. Then distribute
     labels over the new disjunctions and re-associate them. Best effort:
     on HOL_ERR, thm is returned unchanged. *)
  fun expand_cases_in_thm lbl (v, nthm') thm = let
    val nthm_expand = HO_MATCH_MP case_dist_exists_thm (GEN v nthm')

    val thm01 = CONV_RULE (QUANT_CONV (ALL_DISJ_CONV (
       guarded_strip_labels_CONV [lbl] (
       (expand_disjunction_CONV v nthm_expand))))) thm
  in
    CONV_RULE (PURE_REWRITE_CONV [label_over_or_thm, GSYM DISJ_ASSOC]) thm01
  end handle HOL_ERR _ => thm


  (* Columns for the sub-problem in which the current column's value is
     c applied to evs.
     - Rows whose current pattern is one of its own bound variables (a
       wildcard) are kept; they get variable patterns for the new columns
       evs.
     - Rows whose pattern has head c are kept and contribute its
       arguments.
     - All other rows are dropped. The same rows are then dropped from
       the remaining columns cols'.
     - Finally, columns consisting only of wildcards are removed. *)
  fun get_columns_for_constructor current_col (c, evs) cols' = let
    fun process_current_col (cs : (term list * term) list list, kl : bool list) ps =
      case ps of
          [] => (List.map List.rev cs, List.rev kl)
        | (vs, p)::ps' => let
            val (cs', kl') =
              if (Term.is_var p) andalso List.exists (aconv p) vs then
                (Lib.map2 (fn v => fn l => ([v], v)::l) evs cs,
                 true::kl)
              else let
                val (c', args) = strip_comb_bounded (List.length evs) p
              in
                if not (aconv c c') then (cs, false::kl) else
                (Lib.map2 (fn a => fn l => (vs, a)::l) args cs,
                 true::kl)
              end
          in process_current_col (cs', kl') ps' end

    val ps = snd current_col
    val (cs, kl) = process_current_col (List.map (K []) evs, []) ps
    val cols1 = zip evs cs

    (* kl marks which rows survive *)
    val cols2 = List.map (fn (v, rs) =>
       (v, List.map snd (Lib.filter fst (zip kl rs)))) cols'

    val cols'' = cols1 @ cols2

    (* remove cols consisting of only vars *)
    fun is_wildcard (vs, p) = is_var p andalso List.exists (aconv p) vs
  in
    filter (fn (_, ps) => not (List.all is_wildcard ps)) cols''
  end


  (* extract the column for variable v from the list of columns *)
  fun pick_current_column v cols =
    pick_element (fn (v', _) => aconv v v') cols

in (* in of local *)

  (* nchotomy_of_pats_GEN db col_heu pats

     Derive a case-distinction theorem for the list of patterns pats.
     Start from  !x. ?vs. x = (vs)  with one variable per tuple
     component. Then repeatedly:
     - pick a column (col_heu) that db has an nchotomy theorem for;
     - expand the disjuncts carrying the current label by that case
       distinction;
     - recurse into each resulting case with the columns restricted to it.
     Recursion stops when a step raises HOL_ERR, for example when no
     column can be compiled. Labels used during the construction are
     removed at the end. *)
  fun nchotomy_of_pats_GEN db col_heu pats = let
    val var_gen = mk_var_gen "v" []
    val lbl_gen = mk_new_label_gen "case_"

    (*
    val (thm, cols, lbl) = mk_initial_state var_gen lbl_gen pats
    val (thm, cols, lbl) = (thm1, cols'', lbl)
    val xxx = !args
    val (thm, cols, lbl) = el 3 xxx
    *)

    fun compile (thm, cols, lbl) = let
      val (v, nthm) = find_nchotomy_for_cols db col_heu cols
      val (current_col, cols_rest) = pick_current_column v cols
      val (nthm', cases_info) = compute_cases_info var_gen lbl_gen v nthm

      (* Expand all labeled with [lbl] cases *)
      val thm1 = expand_cases_in_thm lbl (v, nthm') thm

      (* Call recursively, once per new case *)
      (* val ((lbl, c, evs), current_thm) = ((el 2 cases_info, thm1)) *)
      fun process_case ((lbl, c, evs), current_thm) = let
        val cols' = get_columns_for_constructor current_col (c, evs) cols_rest
      in
        compile (current_thm, cols', lbl)
      end
    in
      List.foldl process_case thm1 cases_info
    end handle HOL_ERR _ => thm

    (* compile it *)
    val thm3 = compile (mk_initial_state var_gen lbl_gen pats)
  in
    (* get rid of labels *)
    CONV_RULE markerLib.DEST_LABELS_CONV thm3
  end

  (* nchotomy_of_pats_GEN with the global database and default heuristic. *)
  fun nchotomy_of_pats pats =
    nchotomy_of_pats_GEN (!thePmatchCompileDB) colHeu_default pats

end


(********************************************)
(* Prune disjunctions of PMATCH_ROW_COND_EX *)
(********************************************)

(* Given a list of disjunctions of PMATCH_ROW_COND_EX and
   a theorem stating that a certain PMATCH_ROW_COND_EX does not
   hold, prune the disjunction by removing all patterns
   covered by the one we know does not hold.
*) 2473 2474 2475fun PMATCH_ROW_COND_EX_ELIM_FALSE_GUARD_CONV tt = let 2476 val (_, _, g) = dest_PMATCH_ROW_COND_EX tt 2477 val (_, g_b) = pairLib.dest_pabs g 2478 val _ = if (aconv g_b F) then () else raise UNCHANGED 2479 2480 val thm00 = PART_MATCH (lhs o rand) PMATCH_ROW_COND_EX_FALSE tt 2481 val pre = (rand o rator o concl) thm00 2482 (* set_goal ([], pre) *) 2483 val pre_thm = prove (pre, 2484 SIMP_TAC (std_ss++pairSimps.gen_beta_ss) [pairTheory.FORALL_PROD] 2485 ) 2486 val thm01 = MP thm00 pre_thm 2487in 2488 thm01 2489end handle HOL_ERR _ => raise UNCHANGED 2490 2491(* 2492 2493val t = `` 2494 case (x,y,z) of 2495 | (NONE,_,[]) => 0 2496 | (NONE,x,[]) when x < 10 => x 2497 | (NONE,x,[2]) => x 2498 | (NONE,x,[v18]) => 3 2499 | (NONE,_,[_;_]) => 4 2500 | (NONE,x,v12::v16::v17) => 3 2501 | (SOME y,x,[z]) => x + 5 + z 2502 | (SOME y,0,v23::v24) => (v23 + y) 2503 | (SOME y,SUC z,[v23]) when y > 5 => 3 2504 | (SOME y,SUC z,[1; 2]) => y + z 2505 ``; 2506 2507 val (v, rows) = dest_PMATCH t 2508 val pats = List.map (#1 o dest_PMATCH_ROW) rows 2509 2510 2511val thm = CONV_RULE (nchotomy2PMATCH_ROW_COND_EX_CONV) (nchotomy_of_pats pats) 2512 2513val cs = (strip_disj o concl o SPEC v) thm 2514 2515val t = (concl o SPEC v) thm 2516 2517val row_cs = List.map (mk_PMATCH_ROW_COND_EX_ROW v) rows 2518 2519val weaken_ce = el 4 row_cs 2520val weaken_thm = ASSUME (mk_neg weaken_ce) 2521val ce = el 4 cs 2522 2523val rc_arg = ([], NONE) 2524*) 2525 2526(* apply thm PMATCH_ROW_COND_EX_WEAKEN *) 2527fun PMATCH_ROW_COND_EX_WEAKEN_CONV_GENCALL rc_arg (weaken_thm, v_w, p_w', vars_w') ce = let 2528 val (v, p_t, _) = dest_PMATCH_ROW_COND_EX ce 2529 val (vars, p) = pairLib.dest_pabs p_t 2530 val _ = if (aconv v v_w) then () else raise UNCHANGED 2531 2532 (* try to match *) 2533 val s = let 2534 val (s_tm, s_ty) = Term.match_term p_w' p 2535 val _ = if List.null s_ty then () else failwith "bound too much" 2536 val vars_w'_l = pairSyntax.strip_pair vars_w' 2537 val _ = if List.exists (fn s 
=> not (List.exists 2538 (aconv (#redex s)) vars_w'_l)) s_tm then 2539 failwith "bound too much" else () 2540 in s_tm end 2541 2542 (* construct f *) 2543 val f_tm = pairSyntax.mk_pabs (vars, subst s vars_w') 2544 2545 (* instantiate the thm *) 2546 val thm0 = let 2547 val thm00 = FRESH_TY_VARS_RULE PMATCH_ROW_COND_EX_WEAKEN 2548 val thm01 = MATCH_MP thm00 weaken_thm 2549 val thm02 = ISPEC f_tm thm01 2550 val thm03 = PART_MATCH (lhs o rand) thm02 ce 2551 val thm04 = rc_elim_precond rc_arg thm03 2552 in 2553 thm04 2554 end 2555 2556 (* Simplify guard *) 2557 val thm1 = let 2558 val c = TRY_CONV (rc_conv rc_arg) THENC 2559 pairTools.PABS_INTRO_CONV vars 2560 in 2561 RIGHT_CONV_RULE (RAND_CONV c) thm0 2562 end 2563 2564 (* elim false *) 2565 val thm2 = RIGHT_CONV_RULE 2566 PMATCH_ROW_COND_EX_ELIM_FALSE_GUARD_CONV thm1 2567 handle HOL_ERR _ => thm1 2568in 2569 thm2 2570end handle HOL_ERR _ => raise UNCHANGED 2571 2572 2573fun PMATCH_ROW_COND_EX_DISJ_WEAKEN_CONV_GENCALL rc_arg weaken_thm t = let 2574 val (v_w, p_tw, _) = 2575 dest_PMATCH_ROW_COND_EX (dest_neg (concl weaken_thm)) 2576 val (vars_w, p_w) = pairLib.dest_pabs p_tw 2577 2578 (* get fresh vars in p_w before matching *) 2579 val (p_w', vars_w') = let 2580 val vars'_l = pairSyntax.strip_pair vars_w 2581 val s = List.map (fn v => (v |-> genvar (type_of v))) vars'_l 2582 val p_w' = subst s p_w 2583 val vars_w' = subst s vars_w 2584 in 2585 (p_w', vars_w') 2586 end 2587 2588 val thm0 = ALL_DISJ_CONV (PMATCH_ROW_COND_EX_WEAKEN_CONV_GENCALL rc_arg (weaken_thm, v_w, p_w', vars_w')) t 2589 2590 2591 val thm1 = RIGHT_CONV_RULE (PURE_REWRITE_CONV [boolTheory.OR_CLAUSES]) thm0 2592in 2593 thm1 2594end 2595 2596 2597(*************************************) 2598(* Compute redundant rows info for a *) 2599(* PMATCH *) 2600(*************************************) 2601 2602(* val tt = el 3 cjs *) 2603 2604fun SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV rc_arg tt = let 2605 (* destruct everything *) 2606 val (v, vars', p', g', vars, p, 
g) = let 2607 val (pre, cl_neg) = dest_imp tt 2608 val (v', p', g') = dest_PMATCH_ROW_COND_EX pre 2609 val (vars', _) = pairSyntax.dest_pabs p' 2610 val cl = dest_neg cl_neg 2611 val (v, p, g) = dest_PMATCH_ROW_COND_EX cl 2612 val _ = if (aconv v v') then () else raise UNCHANGED 2613 val (vars, _) = pairSyntax.dest_pabs p 2614 in 2615 (v, vars', p', g', vars, p, g) 2616 end 2617 2618 val thm00 = FRESH_TY_VARS_RULE PMATCH_ROW_COND_EX_IMP_REWRITE 2619 val thm01 = ISPECL [v, p', g', p, g] thm00 2620 2621 val thm02 = RIGHT_CONV_RULE ( 2622 (QUANT_CONV (RAND_CONV (pairTools.PABS_INTRO_CONV vars))) THENC 2623 (RAND_CONV (pairTools.PABS_INTRO_CONV vars'))) thm01 2624 val thm03 = RIGHT_CONV_RULE (DEPTH_CONV pairLib.GEN_BETA_CONV) thm02 2625 val thm04 = RIGHT_CONV_RULE (TRY_CONV (pairTools.ELIM_TUPLED_QUANT_CONV) THENC 2626 TRY_CONV (STRIP_QUANT_CONV (pairTools.ELIM_TUPLED_QUANT_CONV))) thm03 2627 2628 fun imp_or_no_imp_CONV c t = 2629 if (is_imp t) then 2630 (RAND_CONV c) t 2631 else c t 2632 2633 val thm05 = RIGHT_CONV_RULE ( 2634 (STRIP_QUANT_CONV (imp_or_no_imp_CONV (RATOR_CONV (RAND_CONV (SIMP_CONV (rc_ss []) []))))) THENC 2635 REWRITE_CONV[]) thm04 2636 2637 val rr = rhs (concl thm05) 2638 val thm06 = if aconv rr T then thm05 else let 2639 val thm_rr = prove_attempt (rr, rc_tac rc_arg) 2640 in 2641 TRANS thm05 (EQT_INTRO thm_rr) 2642 end handle HOL_ERR _ => thm05 2643in 2644 thm06 2645end 2646 2647(* val ttts = strip_disj pre 2648 val ttt = el 1 ttts 2649 val rc_arg = ([], NONE) *) 2650 2651fun SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV rc_arg cc_thm v ttt = let 2652 2653 val (v', p, g) = dest_PMATCH_ROW_COND_EX ttt 2654 val _ = if (aconv v v') then () else raise UNCHANGED 2655 2656 val thm00 = FRESH_TY_VARS_RULE PMATCH_ROW_COND_EX_IMP_REWRITE 2657 val thm01 = MATCH_MP thm00 cc_thm 2658 val thm02 = ISPECL [p, g] thm01 2659 2660 val (x, pre, l) = let 2661 val (x, body) = (dest_forall o rand o rator o snd o strip_forall o concl) thm02 2662 val (pre, body') = dest_imp body 
2663 val l = lhs body' 2664 in 2665 (x, pre, l) 2666 end 2667 2668 val l_thm0 = rc_conv_rws rc_arg [ASSUME pre] l 2669 val r = rhs (concl l_thm0) 2670 val _ = if (aconv r T) orelse (aconv r F) then () else 2671 (* we don't want complicated intermediate results *) 2672 raise UNCHANGED 2673 val l_thm1 = GEN x (DISCH pre l_thm0) 2674 2675 val thm03 = ISPEC r thm02 2676 val thm04 = MP thm03 l_thm1 2677in 2678 thm04 2679end 2680 2681 2682(* val thm = it 2683 val (tts, _) = (listSyntax.dest_list o rand o concl) thm 2684 val tt = el 2 tts *) 2685 2686val simple_imp_thm = prove ( ``!X Y X'. ((Y ==> (X = X')) ==> ((X ==> ~Y) = (X' ==> ~Y)))``, 2687PROVE_TAC[]) 2688 2689fun SIMPLIFY_REDUNDANT_ROWS_INFO_AUX rc_arg tt = let 2690 val (pre, cc_neg) = dest_imp tt 2691 val cc = dest_neg cc_neg 2692 2693 val (v, _, _) = dest_PMATCH_ROW_COND_EX cc 2694 val cc_thm = ASSUME cc 2695 val pre_thm0 = 2696 (ALL_DISJ_TF_ELIM_CONV (SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV rc_arg cc_thm v)) pre handle UNCHANGED => REFL pre 2697 val pre_thm = DISCH cc pre_thm0 2698 2699 val thm0 = SPECL [pre, cc] simple_imp_thm 2700 val thm1 = MATCH_MP thm0 pre_thm 2701 2702 val thm2 = RIGHT_CONV_RULE (REWRITE_CONV [] THENC DEPTH_CONV 2703 PMATCH_ROW_COND_EX_ELIM_CONV) thm1 2704in 2705 thm2 2706end handle HOL_ERR _ => raise UNCHANGED 2707 2708 2709fun find_non_constructor_pattern db vs t = let 2710 fun aux l = case l of 2711 [] => NONE 2712 | (t::ts) => if (mem t vs) then aux ts else ( 2713 if (pairSyntax.is_pair t) then 2714 aux ((pairSyntax.strip_pair t)@ts) 2715 else ( 2716 case pmatch_compile_db_dest_constr_term db t of 2717 NONE => SOME t 2718 | SOME (_, args) => aux ((map snd args) @ ts) 2719 ) 2720 ) 2721in 2722 aux [t] 2723end 2724 2725 2726fun COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu t = 2727let 2728 val (v, rows) = dest_PMATCH t 2729 val rc_arg = case rc_arg of 2730 (sl, cb_opt) => ((#pcdb_ss db)::sl, cb_opt) 2731 2732 2733 (* compute initial enchotomy *) 2734 val nchot_thm = let 
2735 val pats = List.map (#1 o dest_PMATCH_ROW) rows 2736 val thm01 = nchotomy_of_pats_GEN db col_heu pats 2737 val thm02 = CONV_RULE (nchotomy2PMATCH_ROW_COND_EX_CONV_GEN 2738 (find_non_constructor_pattern db) 2739 ) thm01 2740 val thm03 = ISPEC v thm02 2741 in 2742 thm03 2743 end 2744 2745 (* get initial info *) 2746 val init_info = let 2747 val row_ty = listSyntax.dest_list_type (type_of (rand t)) 2748 val s_ty = match_type (``:'a -> 'b option``) row_ty 2749 val thm00 = INST_TYPE s_ty IS_REDUNDANT_ROWS_INFO_NIL 2750 val thm01 = SPEC v thm00 2751 2752 val nthm = GSYM (EQT_INTRO nchot_thm) 2753 val thm02 = CONV_RULE (RATOR_CONV (RAND_CONV (K nthm))) thm01 2754 in 2755 thm02 2756 end 2757 2758 (* add a row to the info *) 2759 fun add_row (r, info_thm) = let 2760 val (p, g, r) = dest_PMATCH_ROW r 2761 val thm00 = FRESH_TY_VARS_RULE IS_REDUNDANT_ROWS_INFO_SNOC_PMATCH_ROW 2762 val thm01 = MATCH_MP thm00 info_thm 2763 val thm02 = ISPECL [p, g, r] thm01 2764 2765 (* simplify the condition we carry around *) 2766 val c'_thm = let 2767 val pthm = ASSUME (mk_neg (mk_PMATCH_ROW_COND_EX (v, p, g))) 2768 val c_tm = (rand o rator o concl) info_thm 2769 val c'_thm0 = PMATCH_ROW_COND_EX_DISJ_WEAKEN_CONV_GENCALL rc_arg pthm c_tm handle UNCHANGED => REFL c_tm 2770 val c'_thm = DISCH (concl pthm) c'_thm0 2771 in 2772 c'_thm 2773 end 2774 2775 val thm03 = MATCH_MP thm02 c'_thm 2776 val thm04 = CONV_RULE (RATOR_CONV (RATOR_CONV (RAND_CONV listLib.SNOC_CONV))) thm03 2777 2778 val new_cond_CONV = SIMPLIFY_REDUNDANT_ROWS_INFO_AUX rc_arg 2779 val thm05 = CONV_RULE (RAND_CONV (RATOR_CONV (RAND_CONV new_cond_CONV))) thm04 2780 2781 val thm06 = CONV_RULE (RAND_CONV (listLib.SNOC_CONV)) thm05 2782 in 2783 thm06 2784 end 2785in 2786 List.foldl add_row init_info rows 2787end 2788 2789fun COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GEN ss db col_heu = 2790 COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL (ss, NONE) db col_heu 2791 2792fun COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH t = 2793 
COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL ([], NONE) 2794 (!thePmatchCompileDB) colHeu_default t 2795 2796 2797(* 2798val t = ``case (x, z) of 2799 | (NONE, NONE) => 0 2800 | (SOME _, _) => 1 2801 | (_, SOME _) => 2 2802`` 2803 2804val t = ``case (x, z) of 2805 | (NONE, NONE) => 0 2806 | (SOME _, _) => 1 2807 | (_, NONE) => 2 2808`` 2809 2810val t = ``case (x, z) of 2811 | (NONE, 1) => 0 2812 | (SOME _, 2) => 1 2813 | (_, x) when x > 5 => 2 2814`` 2815*) 2816 2817fun IS_REDUNDANT_ROWS_INFO_WEAKEN_RULE info_thm = let 2818 val (conds, _) = listSyntax.dest_list (rand (concl info_thm)) 2819 val conds' = List.map (fn c => if (aconv c T) then T else F) conds 2820 val _ = if exists (aconv T) conds' then () else raise UNCHANGED 2821 val conds'_tm = listSyntax.mk_list (conds', bool) 2822 2823 val thm00 = REDUNDANT_ROWS_INFOS_CONJ_THM 2824 val thm01 = MATCH_MP thm00 info_thm 2825 val thm02 = SPECL [F, conds'_tm] thm01 2826 2827 val thm03 = let 2828 val pre = rand (rator (concl thm02)) 2829 val pre_thm = prove (pre, SIMP_TAC list_ss []) 2830 in 2831 MP thm02 pre_thm 2832 end 2833 2834 val thm04 = CONV_RULE (RATOR_CONV (RAND_CONV (REWRITE_CONV []))) thm03 2835 2836 val thm05 = CONV_RULE (RAND_CONV (REWRITE_CONV [ 2837 REDUNDANT_ROWS_INFOS_CONJ_REWRITE])) thm04 2838in 2839 thm05 2840end 2841 2842fun IS_REDUNDANT_ROWS_INFO_TO_PMATCH_EQ_THM info_thm = let 2843 val info_thm' = IS_REDUNDANT_ROWS_INFO_WEAKEN_RULE info_thm 2844 val thm0 = MATCH_MP REDUNDANT_ROWS_INFO_TO_PMATCH_EQ info_thm' 2845 val c = PURE_REWRITE_CONV [APPLY_REDUNDANT_ROWS_INFO_THMS] 2846 val thm1 = RIGHT_CONV_RULE (RAND_CONV c) thm0 2847in 2848 thm1 2849end 2850 2851 2852fun PMATCH_REMOVE_REDUNDANT_CONV_GENCALL db col_heu rc_arg t = let 2853 val info_thm = COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu t 2854in 2855 IS_REDUNDANT_ROWS_INFO_TO_PMATCH_EQ_THM info_thm 2856end 2857 2858fun PMATCH_REMOVE_REDUNDANT_CONV_GEN db col_heu ssl = 2859 PMATCH_REMOVE_REDUNDANT_CONV_GENCALL db col_heu (ssl, 
NONE) 2860 2861fun PMATCH_REMOVE_REDUNDANT_CONV t = PMATCH_REMOVE_REDUNDANT_CONV_GEN 2862 (!thePmatchCompileDB) colHeu_default [] t; 2863 2864fun PMATCH_REMOVE_REDUNDANT_GEN_ss db col_heu ssl = 2865 make_gen_conv_ss (PMATCH_REMOVE_REDUNDANT_CONV_GENCALL db col_heu) "PMATCH_REMOVE_REDUNDANT_REDUCER" ssl 2866 2867fun PMATCH_REMOVE_REDUNDANT_ss () = 2868 PMATCH_REMOVE_REDUNDANT_GEN_ss (!thePmatchCompileDB) colHeu_default [] 2869 2870 2871fun IS_REDUNDANT_ROWS_INFO_SHOW_ROW_IS_REDUNDANT thm i tac = 2872 CONV_RULE (RAND_CONV (list_nth_CONV i (fn t => 2873 EQT_INTRO (prove (t, tac))))) thm 2874 2875fun IS_REDUNDANT_ROWS_INFO_SHOW_ROW_IS_REDUNDANT_set_goal thm i = let 2876 val (l, _) = (listSyntax.dest_list o rand o concl) thm 2877 val t = List.nth (l, i) 2878in 2879 proofManagerLib.set_goal ([], t) 2880end; 2881 2882 2883(*************************************) 2884(* Exhaustiveness *) 2885(*************************************) 2886 2887fun PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg t = let 2888 val (v, rows) = dest_PMATCH t 2889 2890 fun check_row r = let 2891 val r_tm = mk_eq (mk_comb (r, v), optionSyntax.mk_none (type_of t)) 2892 val r_thm = rc_conv rc_arg r_tm 2893 val res_tm = rhs (concl r_thm) 2894 in 2895 if (same_const res_tm T) then SOME (true, r_thm) else 2896 (if (same_const res_tm F) then SOME (false, r_thm) else NONE) 2897 end handle HOL_ERR _ => NONE 2898 2899 fun find_thms a thmL [] = (a, thmL) 2900 | find_thms a thmL (r::rows) = ( 2901 case (check_row r) of 2902 NONE => find_thms true thmL rows 2903 | SOME (true, r_thm) => find_thms a (r_thm :: thmL) rows 2904 | SOME (false, r_thm) => (false, [r_thm])) 2905 2906 val (abort, rewrite_thms) = find_thms false [] (List.rev rows) 2907 val _ = if abort then raise UNCHANGED else () 2908 2909 val t0 = mk_PMATCH_IS_EXHAUSTIVE v (rand t) 2910in 2911 REWRITE_CONV (PMATCH_IS_EXHAUSTIVE_REWRITES::rewrite_thms) t0 2912end; 2913 2914fun PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GEN ssl = 2915 
PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL (ssl, NONE) 2916 2917val PMATCH_IS_EXHAUSTIVE_FAST_CHECK = 2918 PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GEN []; 2919 2920(* 2921val db = !thePmatchCompileDB 2922val col_heu = colHeu_default 2923val rc_arg = ([], NONE) 2924*) 2925 2926 2927(* 2928val t = ``case (x, z) of 2929 | (NONE, NONE) => 0 2930 | (_, SOME _) => 2 2931`` 2932 2933val t = ``case (x, z) of 2934 | (NONE, NONE) => 0 2935 | (SOME _, _) => 1 2936 | (_, NONE) => 2 2937`` 2938 2939val t = ``case (x, z) of 2940 | (NONE, 1) => 0 2941 | (SOME _, 2) => 1 2942 | (_, x) when x > 5 => 2 2943`` 2944 2945val info_thm = COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH t 2946 2947*) 2948 2949 2950fun IS_REDUNDANT_ROWS_INFO_TO_PMATCH_IS_EXHAUSTIVE info_thm = let 2951 val thm0 = MATCH_MP IS_REDUNDANT_ROWS_INFO_EXTRACT_IS_EXHAUSTIVE 2952 info_thm 2953in 2954 thm0 2955end 2956 2957fun PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t = let 2958 val info_thm = COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu t 2959in 2960 IS_REDUNDANT_ROWS_INFO_TO_PMATCH_IS_EXHAUSTIVE info_thm 2961end 2962 2963fun PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GENCALL rc_arg t = 2964 PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN 2965 (!thePmatchCompileDB) colHeu_default rc_arg t 2966 2967fun PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GEN ssl = 2968 PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GENCALL (ssl, NONE) 2969 2970val PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK = 2971 PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GEN []; 2972 2973 2974val IMP_TO_EQ_THM = prove (``!P Q. 
(P ==> Q) ==> (~P ==> ~Q) ==> (Q <=> P)``, PROVE_TAC[]) 2975 (* PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db col_heu rc_arg t: exhaustiveness check for the PMATCH term t based on the pattern compilation database db and the column heuristic col_heu. It starts from the implication computed by PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN. If rc_elim_precond can discharge that precondition, the result is turned into an equation with T via EQT_INTRO. Otherwise IMP_TO_EQ_THM is applied and the resulting precondition is proved with prove_attempt, using the PMATCH rewrites below followed by rc_tac rc_arg. *) 2976fun PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db col_heu rc_arg t = let 2977 val thm0 = PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t 2978in 2979 let 2980 val thm = rc_elim_precond rc_arg thm0 2981 in 2982 EQT_INTRO thm 2983 end handle HOL_ERR _ => let (* fallback: the precondition could not be eliminated directly *) 2984 val thm1 = MATCH_MP IMP_TO_EQ_THM thm0 2985 2986 val (precond, _) = dest_imp_only (concl thm1) 2987 val pre_thm = prove_attempt (precond, 2988 REWRITE_TAC[PMATCH_IS_EXHAUSTIVE_REWRITES, PMATCH_ROW_EQ_NONE, PMATCH_ROW_COND_EX_def, 2989 DISJ_IMP_THM, GSYM LEFT_FORALL_IMP_THM] THEN 2990 SIMP_TAC (std_ss++pairSimps.gen_beta_ss) [PMATCH_ROW_COND_DEF_GSYM] THEN 2991 rc_tac rc_arg) 2992 2993 val thm2 = MP thm1 pre_thm 2994 in 2995 thm2 2996 end 2997end 2998 (* Same check, using the global database !thePmatchCompileDB and colHeu_default. *) 2999fun PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GENCALL rc_arg t = 3000 PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN 3001 (!thePmatchCompileDB) colHeu_default rc_arg t 3002 (* Same, with rc_arg = (ssl, NONE) built from the simpset fragment list ssl. *) 3003fun PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GEN ssl = 3004 PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GENCALL (ssl, NONE) 3005 (* Default instance without extra simpset fragments. *) 3006val PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK = 3007 PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GEN []; 3008 3009 (* PMATCH_IS_EXHAUSTIVE_CHECK_FULLGEN db col_heu rc_arg t: tries the fast exhaustiveness check first. QCHANGED_CONV makes an unchanged result count as a failure. On HOL_ERR it falls back to the compile-based check above. *) 3010fun PMATCH_IS_EXHAUSTIVE_CHECK_FULLGEN db col_heu rc_arg t = 3011 QCHANGED_CONV (PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg) t 3012 handle HOL_ERR _ => 3013 PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db col_heu rc_arg t; 3014 (* Wrappers: global database with the default heuristic, then rc_arg = (ssl, NONE), then no extra simpset fragments. *) 3015fun PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg t = 3016 PMATCH_IS_EXHAUSTIVE_CHECK_FULLGEN (!thePmatchCompileDB) colHeu_default rc_arg t 3017 3018fun PMATCH_IS_EXHAUSTIVE_CHECK_GEN ssl = 3019 PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL (ssl, NONE) 3020 3021val PMATCH_IS_EXHAUSTIVE_CHECK = PMATCH_IS_EXHAUSTIVE_CHECK_GEN [] 3022 3023 (* Lemmas that turn an equation ex_t = r, produced by the fast check, into an implication concluding ex_t. The case r = T gives ~F ==> ex_t. The case r = F gives F ==> ex_t. Any other r gives r ==> ex_t. *) 3024local 3025 val EQ_F_ELIM = prove (``!b. F ==> b``, PROVE_TAC[]) 3026 val EQ_T_ELIM = prove (``!b. (b = T) ==> ~F ==> b``, PROVE_TAC[]) 3027 val EQ_O_ELIM = prove (``!b1 b2.
(b1 = b2) ==> b2 ==> b1``, PROVE_TAC[]) 3028 3029in 3030 (* PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t: consequence-style exhaustiveness check that returns an implication concluding t. On the fast path, ex_t is the left-hand side of the fast conversion result, i.e. t itself. On HOL_ERR, including an unchanged fast check via QCHANGED_CONV, the compile-based consequence check is used instead. *) 3031fun PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t = let 3032 val thm0 = QCHANGED_CONV (PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg) t 3033 val (ex_t, r) = dest_eq (concl thm0) 3034 in 3035 if (r = T) then 3036 MP (SPEC ex_t EQ_T_ELIM) thm0 3037 else (if (r = F) then 3038 (SPEC ex_t EQ_F_ELIM) 3039 else 3040 (MP (SPEC r (SPEC ex_t EQ_O_ELIM)) thm0) 3041 ) 3042 end handle HOL_ERR _ => 3043 PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t; 3044end; 3045 3046 (* Wrappers analogous to the PMATCH_IS_EXHAUSTIVE_CHECK family above. *) 3047fun PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GENCALL rc_arg t = 3048 PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_FULLGEN (!thePmatchCompileDB) colHeu_default rc_arg t 3049 3050fun PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GEN ssl = 3051 PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GENCALL (ssl, NONE) 3052 3053val PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK = PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GEN [] 3054 3055 3056(*************************************) 3057(* Nchotomy *) 3058(*************************************) 3059 3060(* 3061val db = !thePmatchCompileDB 3062val col_heu = colHeu_default 3063val rc_arg = ([], NONE) 3064*) 3065 3066 (* Contraposition: (~A ==> B) = (~B ==> A). *) 3067val neg_imp_rewr = prove (``(~A ==> B) = (~B ==> A)``, 3068 Cases_on `A` THEN Cases_on `B` THEN REWRITE_TAC[]); 3069 (* nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN rc_arg db col_heu tt: tt must be a disjunction of PMATCH_ROW_COND_EX terms that all share the same input v; otherwise it fails with the message illformed input. The function first derives an nchotomy theorem for the patterns of these disjuncts. It then weakens that theorem using each negated disjunct of tt as an assumption. Finally it contraposes with neg_imp_rewr, giving a theorem of the form ~B ==> tt. Here B is what remains of the weakened nchotomy after REWRITE_CONV []. PMATCH_COMPLETE_CONV_GENCALL_AUX below reads the disjuncts of B as missing rows. *) 3070fun nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN rc_arg db col_heu tt = let 3071 (* destruct everything *) 3072 val (v, disjs) = let 3073 val disjs = strip_disj tt 3074 val (v, _, _) = dest_PMATCH_ROW_COND_EX (hd disjs) 3075 in 3076 (v, disjs) 3077 end 3078 3079 (* Sanity check: every disjunct must be about the same input v *) 3080 val _ = List.map (fn r => let 3081 val (v', _, _) = dest_PMATCH_ROW_COND_EX r 3082 val _ = if (aconv v v') then () else failwith "illformed input" 3083 in () end) disjs 3084 3085 (* derive nchot thm: an nchotomy for the patterns, converted to PMATCH_ROW_COND_EX form and specialised to v *) 3086 val nchot_thm = let 3087 val pats = List.map (#2 o dest_PMATCH_ROW_COND_EX) disjs 3088 val thm01 = nchotomy_of_pats_GEN db col_heu pats 3089 val thm02 = CONV_RULE (nchotomy2PMATCH_ROW_COND_EX_CONV_GEN
(find_non_constructor_pattern db)) thm01 3091 val thm03 = ISPEC v thm02 3092 in 3093 thm03 3094 end 3095 3096 (* prepare assumptions: one theorem per negated disjunct of tt, each under the assumption ~tt *) 3097 val neg_tt = mk_neg tt 3098 val pre_thms = let 3099 val thm00 = ASSUME neg_tt 3100 val thm01 = PURE_REWRITE_RULE [DE_MORGAN_THM] thm00 3101 in BODY_CONJUNCTS thm01 end 3102 3103 3104 (* apply these assumptions to the nchot_thm *) 3105 val nchot_thm' = let 3106 fun step (pre_thm, thm) = 3107 CONV_RULE (PMATCH_ROW_COND_EX_DISJ_WEAKEN_CONV_GENCALL rc_arg pre_thm) thm 3108 val thm00 = foldl step nchot_thm pre_thms 3109 val thm01 = DISCH neg_tt thm00 3110 in 3111 thm01 3112 end 3113 (* contrapose ~tt ==> B into ~B ==> tt and simplify the antecedent *) 3114 val nchot_thm'' = let 3115 val thm00 = CONV_RULE (REWR_CONV neg_imp_rewr) nchot_thm' 3116 val thm01 = CONV_RULE (RATOR_CONV (RAND_CONV (REWRITE_CONV []))) thm00 3117 in thm01 end 3118 3119in 3120 nchot_thm'' 3121end 3122 3123 (* SHOW_NCHOTOMY_CONSEQ_CONV_GEN ssl db col_heu tt: tt has the form !x. b, where b is a disjunction. The steps are: (1) rewrite each disjunct of b into PMATCH_ROW_COND_EX form over x; (2) apply the nchotomy conversion above; (3) convert the conclusion back to b; (4) eliminate PMATCH_ROW_COND_EX in the antecedent; (5) generalise over x. The result has the shape !x. (pre ==> b). *) 3124fun SHOW_NCHOTOMY_CONSEQ_CONV_GEN ssl db col_heu tt = let 3125 val (x, b) = dest_forall tt 3126 val b_thm = ALL_DISJ_CONV (PMATCH_ROW_COND_EX_INTRO_CONV_GEN 3127 (find_non_constructor_pattern db) x) b 3128 3129 val thm2 = nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN (ssl, NONE) db col_heu (rhs (concl b_thm)) 3130 3131 val thm3 = CONV_RULE (RAND_CONV (K (GSYM b_thm))) thm2 3132 3133 val thm4 = CONV_RULE (RATOR_CONV (RAND_CONV (DEPTH_CONV (PMATCH_ROW_COND_EX_ELIM_CONV)))) thm3 3134 3135 val thm5 = GEN x thm4 3136in 3137 thm5 3138end 3139 (* Default instance: no extra simpset fragments, global database, default heuristic. *) 3140fun SHOW_NCHOTOMY_CONSEQ_CONV tt = 3141 SHOW_NCHOTOMY_CONSEQ_CONV_GEN [] (!thePmatchCompileDB) colHeu_default tt 3142 3143 3144(*************************************) 3145(* Add missing patterns *) 3146(*************************************) 3147 3148(* 3149val use_guards = true 3150val rc_arg = ([], NONE) 3151val db = !thePmatchCompileDB 3152val col_heu = colHeu_default 3153val t = ``case (x, y) of ([], x::xs) => x | (_, _) => 2`` 3154val t = ``case (x, y) of ([], x::xs) => x`` 3155*) 3156 (* PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu use_guards t: returns a triple (changed, thm, exh). thm proves t = t2, where t2 is t with one extra row, with right-hand side ARB, appended for each possibly missing pattern. exh () proves PMATCH_IS_EXHAUSTIVE for the rows of t2. If the fast check already shows t exhaustive, the result is (false, REFL t, a thunk returning that theorem). Otherwise exh () runs prove each time it is called. The added rows keep the guards of the missing cases only if use_guards is true; otherwise their guard is T. *) 3157fun PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu use_guards t = 3158let 3159 val exh_thm
= EQT_ELIM (PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg t) 3160 handle UNCHANGED => failwith "NOT EXH" 3161in (false, REFL t, (fn () => exh_thm)) end handle HOL_ERR _ => (* the fast check did not prove exhaustiveness: compute the missing rows *) 3162let 3163 val (v, rows) = dest_PMATCH t 3164 fun row_to_cond_ex r = let 3165 val (vs_t, p, g, _) = dest_PMATCH_ROW_ABS r 3166 val vs = pairSyntax.strip_pair vs_t 3167 in 3168 mk_PMATCH_ROW_COND_EX_PABS_MOVE_TO_GUARDS (find_non_constructor_pattern db) vs (v, p, g) 3169 end 3170 val disjs = List.map row_to_cond_ex rows 3171 val disjs_tm = list_mk_disj disjs 3172 3173 val thm_nchot = nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN rc_arg db col_heu disjs_tm 3174 (* the antecedent of thm_nchot is expected to be ~(m1 \/ ... \/ mk), one disjunct per missing case; if it does not have this shape, no rows are added *) 3175 val missing_list = let 3176 val pre = fst (dest_imp (concl thm_nchot)) 3177 val disj = dest_neg pre 3178 in 3179 strip_disj disj 3180 end handle HOL_ERR _ => [] 3181 (* appends one row with right-hand side ARB for missing_t at the end of the current row list, using PMATCH_REMOVE_ARB right to left and SNOC_CONV *) 3182 fun add_missing_pat (missing_t, thm) = let 3183 val (_, vs, p_t0, g_t0) = dest_PMATCH_ROW_COND_EX_ABS missing_t 3184 val g_t1 = if use_guards then g_t0 else T 3185 val g_t = pairSyntax.mk_pabs (vs, g_t1) 3186 val p_t = pairSyntax.mk_pabs (vs, p_t0) 3187 val r_t = pairSyntax.mk_pabs (vs, mk_arb (type_of t)) 3188 3189 val thm00 = FRESH_TY_VARS_RULE PMATCH_REMOVE_ARB 3190 val rows_t = (rand o rhs o concl) thm 3191 val thm01 = ISPECL [p_t, g_t, r_t, v, rows_t] thm00 3192 val thm02 = rc_elim_precond rc_arg thm01 3193 val thm03 = GSYM thm02 3194 val thm04 = RIGHT_CONV_RULE (RAND_CONV (listLib.SNOC_CONV)) thm03 3195 in 3196 TRANS thm thm04 3197 end 3198 3199 val thm_expand = foldl add_missing_pat (REFL t) missing_list 3200 3201 (* set_goal ([], mk_PMATCH_IS_EXHAUSTIVE v (rand (rhs (concl thm_expand)))) *) 3202 fun exh_thm () = prove (mk_PMATCH_IS_EXHAUSTIVE v (rand (rhs (concl thm_expand))), 3203 ASSUME_TAC (thm_nchot) THEN 3204 PURE_REWRITE_TAC [PMATCH_IS_EXHAUSTIVE_REWRITES, PMATCH_ROW_NEQ_NONE] THEN 3205 PROVE_TAC[]) 3206in 3207 (not (List.null missing_list), thm_expand, exh_thm) 3208end 3209 (* Conversion version of the completion: raises UNCHANGED if no row had to be added. *) 3210fun PMATCH_COMPLETE_CONV_GENCALL rc_arg db col_heu use_guards t = 3211 let
3212 val (ch, thm, _) = (PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu use_guards t) 3213 val _ = if ch then () else raise UNCHANGED 3214 in thm end; 3215 (* Instances and simpset fragment versions of the completion conversion. The _ss variants wrap it as a reducer named PMATCH_COMPLETE_REDUCER via make_gen_conv_ss. *) 3216fun PMATCH_COMPLETE_CONV_GEN ssl = 3217 PMATCH_COMPLETE_CONV_GENCALL (ssl, NONE); 3218 3219fun PMATCH_COMPLETE_CONV use_guards = 3220 PMATCH_COMPLETE_CONV_GEN [] (!thePmatchCompileDB) colHeu_default use_guards; 3221 3222fun PMATCH_COMPLETE_GEN_ss ssl db colHeu use_guards = 3223 make_gen_conv_ss (fn rc_arg => 3224 PMATCH_COMPLETE_CONV_GENCALL rc_arg db colHeu use_guards) 3225 "PMATCH_COMPLETE_REDUCER" ssl; 3226 3227fun PMATCH_COMPLETE_ss use_guards = PMATCH_COMPLETE_GEN_ss [] (!thePmatchCompileDB) colHeu_default use_guards; 3228 3229 (* Returns (SOME thm if rows were added, otherwise NONE, together with the exhaustiveness theorem for the completed term). The exhaustiveness proof is always forced. *) 3230fun PMATCH_COMPLETE_CONV_GEN_WITH_EXH_PROOF ssl db col_heu use_guards t = 3231 let val (ch, mt, rt) = PMATCH_COMPLETE_CONV_GENCALL_AUX (ssl, NONE) db col_heu use_guards t in 3232 (if ch then SOME mt else NONE, rt ()) end 3233 3234fun PMATCH_COMPLETE_CONV_WITH_EXH_PROOF use_guards = 3235 PMATCH_COMPLETE_CONV_GEN_WITH_EXH_PROOF [] (!thePmatchCompileDB) colHeu_default use_guards; 3236 3237 3238 3239(***********************************************) 3240(* Lifting to lowest boolean level *) 3241(***********************************************) 3242 3243(* One can replace pattern matches with a big conjunction. 3244 Each row becomes one conjunct. Since the conjunction is 3245 of type bool, this needs to be done at a boolean level. 3246 So we can replace an arbitrary term 3247 3248 P (PMATCH i rows) with 3249 3250 (row_cond 1 i -> P (row_rhs 1)) /\ 3251 ... 3252 (row_cond n i -> P (row_rhs n)) /\ 3253 3254 The row condition states that the pattern does not overlap with 3255 any previous pattern and that the guard holds. 3256 3257 The most common use case of lifting are function definitions 3258 of the form 3259 3260 f x = PMATCH x rows 3261 3262 which can be turned into a conjunction of top-level 3263 rewrite rules for the function f.
3264*) 3265 3266(* 3267 3268val tm = ``(P2 /\ Q ==> ( 3269 (case xx of 3270 | (x, y::ys) => (x + y) 3271 | (0, []) => 9 3272 | (x, []) when x > 3 => x 3273 | (x, []) => 0) = 5))`` 3274 3275val tm = `` 3276 (case xx of 3277 | (x, y::ys) => (x + y) 3278 | (0, []) => 9 3279 | (x, []) when x > 3 => x 3280 | (x, []) => 0) = 5`` 3281 3282val _ = ENABLE_PMATCH_CASES() 3283val OPT_PAIR_def = TotalDefn.Define `OPT_PAIR xy = case xy of 3284 | (NONE, _) => NONE 3285 | (_, NONE) => NONE 3286 | (SOME x, SOME y) => SOME (x, y)` 3287val thm = OPT_PAIR_def 3288val tm = concl (hd (BODY_CONJUNCTS thm)) 3289val force_minimal = false 3290val rc_arg = ([], NONE) 3291val try_exh = true 3292*) 3293 3294 (* SIMPLE_IMP_COND_REWRITE_CONV thm tt: for tt = pre ==> post, rewrites post with MATCH_MP thm (ASSUME pre) and returns (pre ==> post) <=> (pre ==> post2), where post2 is the rewritten post. thm must be an implication whose antecedent matches pre. *) 3295local 3296val IMP_AUX_THM = prove (``(P ==> (X <=> Y)) <=> 3297 ((P ==> X) <=> (P ==> Y))``, PROVE_TAC[]) 3298in 3299fun SIMPLE_IMP_COND_REWRITE_CONV thm tt = let 3300 val (pre, post) = dest_imp tt 3301 val pre_thm = ASSUME pre 3302 val rewr_thm = MATCH_MP thm pre_thm 3303 val thm0 = REWRITE_CONV [rewr_thm] post 3304 val thm1 = DISCH pre thm0 3305 val thm2 = CONV_RULE (REWR_CONV IMP_AUX_THM) thm1 3306in 3307 thm2 3308end 3309end; 3310 (* rename_uscore_vars ren avoid vs: each variable in vs whose name starts with an underscore is mapped to a fresh variant of a variable named v of the same type. The fresh name avoids avoid and all previously chosen names, and the mapping is added to the substitution ren. Other elements of vs, including non-variables, are skipped. *) 3311fun rename_uscore_vars ren avoid [] = ren 3312 | rename_uscore_vars ren avoid (v::vs) = 3313 let 3314 val (v_n, v_ty) = dest_var v 3315 val _ = if (String.sub(v_n, 0) = #"_") then () else failwith "nothing to do" 3316 val v' = variant avoid (mk_var ("v", v_ty)) 3317 in 3318 rename_uscore_vars ((v |-> v')::ren) (v'::avoid) vs 3319 end handle HOL_ERR _ => rename_uscore_vars ren avoid vs 3320 3321 3322 (* PMATCH_LIFT_BOOL_CONV_GENCALL force_minimal try_exh rc_arg tm: for a boolean term tm containing a PMATCH subterm p_tm, proves tm equal to a conjunction as described in the comment above. It raises UNCHANGED if tm is not boolean, or if force_minimal holds and some proper boolean subterm of tm also contains a PMATCH. It fails via find_term if tm contains no PMATCH. The work is done on P_v p_tm for a fresh predicate variable P_v, which is instantiated with the abstracted context at the end. If try_exh is set and PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL succeeds on p_tm, its result is used to rewrite the conjunction. *) 3323fun PMATCH_LIFT_BOOL_CONV_GENCALL force_minimal try_exh rc_arg tm = let 3324 (* check whether we should really process tm *) 3325 val _ = if type_of tm = bool then () else raise UNCHANGED 3326 val p_tm = find_term is_PMATCH tm 3327 fun has_subterm p t = (find_term p t; true) handle HOL_ERR _ => false 3328 3329 val is_minimal = not force_minimal orelse not (has_subterm (fn t => 3330 (not (aconv t tm)) andalso 3331 (type_of t = bool) andalso 3332
(has_subterm is_PMATCH t)) tm) 3333 val _ = if is_minimal then () else raise UNCHANGED 3334 3335 (* prepare tm: abstract the PMATCH subterm out of tm *) 3336 val v = genvar (type_of p_tm) 3337 val P_tm = mk_abs (v, subst [p_tm |-> v] tm) 3338 val P_v = genvar (type_of P_tm) 3339 3340 (* do real work *) 3341 val thm0 = let 3342 val (p_tm', genvars_elim_s) = PMATCH_INTRO_GENVARS p_tm 3343 val t0 = (mk_comb (P_v, p_tm')) 3344 val c1 = SIMP_CONV std_ss [PMATCH_EXPAND_PRED_THM, 3345 PMATCH_EXPAND_PRED_def, PMATCH_ROW_NEQ_NONE, 3346 EVERY_DEF, PMATCH_ROW_EVAL_COND_EX, REVERSE_REV, REV_DEF] 3347 val c2 = REWRITE_CONV [PMATCH_ROW_COND_EX_def, 3348 PULL_EXISTS] 3349 3350 val thm00 = (c1 THENC c2) t0 3351 val thm01 = INST genvars_elim_s thm00 3352 in 3353 thm01 3354 end 3355 3356 (* Elim choice *) 3357 val thm1 = let 3358 val (v, rows) = dest_PMATCH p_tm (* rows for which PMATCH_COND_SELECT_UNIQUE cannot be applied leave the theorem unchanged *) 3359 fun process_row (r, thm') = let 3360 val (pt, gt, rt) = dest_PMATCH_ROW r 3361 val thm00 = ISPECL [pt, gt, v] (FRESH_TY_VARS_RULE PMATCH_COND_SELECT_UNIQUE) 3362 val thm01 = rc_elim_precond rc_arg thm00 3363 val thm'' = CONV_RULE (DEPTH_CONV (SIMPLE_IMP_COND_REWRITE_CONV thm01)) thm' 3364 in 3365 thm'' 3366 end handle HOL_ERR _ => thm' 3367 in 3368 foldl process_row thm0 rows 3369 end 3370 3371 (* get rid of exhaustiveness check *) 3372 val thm2 = let 3373 val _ = if try_exh then () else failwith "skip" 3374 val thm_ex = PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg p_tm 3375 val thm2 = CONV_RULE (RHS_CONV (REWRITE_CONV [thm_ex])) thm1 3376 in 3377 thm2 3378 end handle HOL_ERR _ => thm1 3379 | UNCHANGED => thm1 3380 3381 3382 (* Use the right variable names and simplify *) 3383 val thm3 = let 3384 fun special_CONV tt = let 3385 val (vars_tm0, row) = let 3386 val (_, tt0) = dest_forall tt 3387 val (tt1, _) = dest_imp_only tt0 3388 val (tt2, _, _, _) = dest_PMATCH_ROW_COND tt1 3389 in (fst (pairSyntax.dest_pabs tt2), tt1) end 3390 3391 val vars_tm = let 3392 val to_ren = free_vars vars_tm0 3393 val avoid = free_vars row @ to_ren 3394 val ren = rename_uscore_vars
[] avoid to_ren 3395 in 3396 subst ren vars_tm0 3397 end 3398 3399 val intro_marker = TRY_CONV (QUANT_CONV (RAND_CONV (RATOR_CONV (RAND_CONV markerLib.stmark_term)))) 3400 3401 val elim_COND_CONV = 3402 QUANT_CONV (RATOR_CONV (RAND_CONV (REWR_CONV PMATCH_ROW_COND_DEF_GSYM))) 3403 3404 val intro_CONV = RAND_CONV (pairTools.PABS_INTRO_CONV vars_tm) 3405 val elim_CONV = TRY_CONV (pairTools.ELIM_TUPLED_QUANT_CONV) (* NOTE(review): eval_preconds is defined but not used in the EVERY_CONV pipeline below *) 3406 val eval_preconds = STRIP_QUANT_CONV (RAND_CONV (fn t => let 3407 val _ = dest_imp_only t 3408 in RATOR_CONV (RAND_CONV (rc_conv rc_arg)) t end)) 3409 3410 val simp_CONV = TRY_CONV (SIMP_CONV std_ss []) 3411 3412 val elim_marker = 3413 (REWR_CONV markerTheory.stmarker_def) THENC 3414 TRY_CONV (rc_conv rc_arg) 3415 in 3416 EVERY_CONV [ 3417 intro_marker, 3418 elim_COND_CONV, 3419 intro_CONV, elim_CONV, 3420 simp_CONV, 3421 DEPTH_CONV elim_marker, 3422 TRY_CONV (REWRITE_CONV []) 3423 ] tt 3424 end 3425 in 3426 CONV_RULE (RHS_CONV (ALL_CONJ_CONV special_CONV)) thm2 3427 end 3428 3429 3430 (* restore original predicate: instantiate P_v with the abstracted context and beta-reduce; the assert checks that the lhs is tm again *) 3431 val thm4 = let 3432 val thm00 = INST [P_v |-> P_tm] thm3 3433 val thm01 = CONV_RULE (LHS_CONV BETA_CONV) thm00 3434 val thm02 = CONV_RULE (RHS_CONV (DEPTH_CONV BETA_CONV)) thm01 3435 val _ = assert (fn thm => aconv (lhs (concl thm)) tm) thm02 3436 in 3437 thm02 3438 end 3439in 3440 thm4 3441end 3442 (* Instances with force_minimal = true; the _ss variants wrap the conversion as a reducer named PMATCH_LIFT_BOOL_REDUCER. *) 3443fun PMATCH_LIFT_BOOL_CONV_GEN ssl try_exh = PMATCH_LIFT_BOOL_CONV_GENCALL true try_exh (ssl, NONE) 3444 3445val PMATCH_LIFT_BOOL_CONV = PMATCH_LIFT_BOOL_CONV_GEN []; 3446 3447fun PMATCH_LIFT_BOOL_GEN_ss ssl try_exh = 3448 make_gen_conv_ss (PMATCH_LIFT_BOOL_CONV_GENCALL true try_exh) "PMATCH_LIFT_BOOL_REDUCER" ssl 3449 3450val PMATCH_LIFT_BOOL_ss = PMATCH_LIFT_BOOL_GEN_ss [] 3451 3452 (* PMATCH_TO_TOP_RULE_SINGLE ssl thm: lifts the PMATCH in the body of the generalised theorem thm to the boolean level, with force_minimal = false and try_exh = false. The result is split into separately quantified conjuncts; the stmarker congruence keeps SIMP_RULE from rewriting inside the marked right-hand sides. The last conjunct is then dropped. NOTE(review): the dropped conjunct is presumably the one for the case that no row matches; confirm. *) 3453fun PMATCH_TO_TOP_RULE_SINGLE ssl thm = let 3454 val thm0 = GEN_ALL thm 3455 3456 val thm1 = CONV_RULE (STRIP_QUANT_CONV (PMATCH_LIFT_BOOL_CONV_GENCALL false false (ssl, NONE))) thm0 3457 val thm2 = CONV_RULE (STRIP_QUANT_CONV ( 3458 EVERY_CONJ_CONV
(STRIP_QUANT_CONV (TRY_CONV (RAND_CONV markerLib.stmark_term))))) thm1 3459 val thm3 = SIMP_RULE std_ss [FORALL_AND_THM, 3460 Cong (REFL ``stmarker (t:'a)``)] thm2 3461 val thm4 = PURE_REWRITE_RULE [markerTheory.stmarker_def] thm3 3462 3463 val thm5 = LIST_CONJ (butlast (CONJUNCTS thm4)) 3464in 3465 thm5 3466end 3467 (* Applies PMATCH_TO_TOP_RULE_SINGLE to every body conjunct of thm and flattens the resulting conjunction. *) 3468fun PMATCH_TO_TOP_RULE_GEN ssl thm = let 3469 val thms = BODY_CONJUNCTS thm 3470 val thms' = List.map (PMATCH_TO_TOP_RULE_SINGLE ssl) thms 3471 val thm0 = LIST_CONJ thms' 3472 val thm1 = CONV_RULE unwindLib.FLATTEN_CONJ_CONV thm0 3473in 3474 thm1 3475end 3476 (* Default instance without extra simpset fragments. *) 3477fun PMATCH_TO_TOP_RULE thm = PMATCH_TO_TOP_RULE_GEN [] thm; 3478 3479 3480(*************************************) 3481(* Lifting *) 3482(*************************************) 3483 3484(* 3485val tm = ``\y. SUC (SUC 3486 (case y of 3487 | (x, y::ys) => (x + y) 3488 | (0, []) => 0)) + 3489 (case xx of 3490 | (x, y::ys) => (x + y) 3491 | (0, []) => 0)`` 3492val tm = p_tm 3493*) 3494 (* PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu tm: lifts a PMATCH subterm p_tm of tm to the top. p_tm must not mention variables bound above it in tm. The function views tm as f_tm p_tm and first completes p_tm with PMATCH_COMPLETE_CONV_GENCALL_AUX, keeping guards. PMATCH_LIFT_THM then moves f_tm into every row. It fails if tm itself is a PMATCH or if no suitable subterm exists. It returns (thm_lift, exh2), where exh2 () proves exhaustiveness of the lifted PMATCH. NOTE(review): exh_thm () is already forced while building thm1 and is called again inside exh_thm2, so the exhaustiveness proof may be done twice. *) 3495fun PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu tm = let 3496 (* check whether we should really process tm *) 3497 val _ = if (is_PMATCH tm) then failwith "PMATCH_LIFT_CONV_GENCALL_AUX: nothing to do" else () 3498 3499 (* search subterm to lift *) 3500 fun search_pred (bvs, tt) = if is_PMATCH tt andalso 3501 HOLset.isEmpty (HOLset.intersection (HOLset.fromList Term.compare bvs, FVL [tt] empty_tmset)) then SOME tt else NONE 3502 val p_tm = case gen_find_term search_pred tm of SOME p_tm => p_tm | NONE => failwith "no_case" 3503 3504 (* Abstract context with f_tm *) 3505 val f_tm = let 3506 val nv = genvar (type_of p_tm) 3507 val tm' = subst [p_tm |-> nv] tm 3508 in 3509 mk_abs (nv, tm') 3510 end 3511 val (_, p_thm, exh_thm) = PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu true p_tm 3512 3513 (* Intro f_tm *) 3514 val thm0 = let 3515 val f_tm_thm = GSYM (BETA_CONV (mk_comb (f_tm, p_tm))) 3516 val f_tm_thm' = CONV_RULE (RHS_CONV (RAND_CONV (K p_thm))) f_tm_thm 3517 in f_tm_thm' end 3518 3519 (* Apply
lifting thm *) 3520 val (v, rows_tm, thm1) = let 3521 val p_tm' = rhs (concl p_thm) 3522 val (v, rows) = dest_PMATCH p_tm' 3523 val rows_tm = rand p_tm' 3524 val thm10 = ISPECL [f_tm, v, rows_tm] (FRESH_TY_VARS_RULE PMATCH_LIFT_THM) 3525 val thm11 = MP thm10 (exh_thm()) 3526 in (v, rows_tm, thm11) end 3527 3528 (* Simplify *) 3529 val thm2 = let 3530 fun c3 tt = let 3531 val (vt, _) = pairSyntax.dest_pabs (rator (rand (snd (dest_abs tt)))) 3532 val c30 = (pairTools.PABS_INTRO_CONV vt) 3533 val c31 = PairRules.PABS_CONV (RAND_CONV PairRules.PBETA_CONV) 3534 val c32 = PairRules.PABS_CONV BETA_CONV 3535 in 3536 (c30 THENC (TRY_CONV c31) THENC c32) tt 3537 end 3538 val c2 = REWR_CONV PMATCH_ROW_LIFT_THM THENC (RAND_CONV c3) 3539 val c = listLib.MAP_CONV c2 3540 val thm2 = c (rand (rhs (concl thm1))) 3541 in 3542 thm2 3543 end 3544 3545 val thm_lift = CONV_RULE (RHS_CONV (RAND_CONV (K thm2))) (TRANS thm0 thm1) 3546 3547 (* construct exhaustiveness result *) 3548 fun exh_thm' () = let 3549 val exh_thm = exh_thm () 3550 val thm00 = ISPECL [f_tm, v, rows_tm] PMATCH_IS_EXHAUSTIVE_LIFT 3551 val thm01 = MP thm00 exh_thm 3552 val thm02 = CONV_RULE (RAND_CONV (K thm2)) thm01 3553 in thm02 end 3554in 3555 (thm_lift, exh_thm') 3556end; 3557 3558 (* Wrappers around PMATCH_LIFT_CONV_GENCALL_AUX; the _WITH_EXH_PROOF variants also force the exhaustiveness proof. *) 3559fun PMATCH_LIFT_CONV_GENCALL rc_arg db col_heu t = 3560 let 3561 val (thm, _) = (PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu t) 3562 in thm end; 3563 3564fun PMATCH_LIFT_CONV_GENCALL_WITH_EXH_PROOF rc_arg db col_heu t = 3565 let 3566 val (thm, exh) = (PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu t) 3567 in (thm, exh()) end; 3568 3569fun PMATCH_LIFT_CONV_GEN ssl = 3570 PMATCH_LIFT_CONV_GENCALL (ssl, NONE); 3571 3572fun PMATCH_LIFT_CONV t = 3573 PMATCH_LIFT_CONV_GEN [] (!thePmatchCompileDB) colHeu_default t; 3574 3575fun PMATCH_LIFT_CONV_GEN_WITH_EXH_PROOF ssl = 3576 PMATCH_LIFT_CONV_GENCALL_WITH_EXH_PROOF (ssl, NONE); 3577 3578fun PMATCH_LIFT_CONV_WITH_EXH_PROOF t = 3579 PMATCH_LIFT_CONV_GEN_WITH_EXH_PROOF [] (!thePmatchCompileDB)
colHeu_default t; 3580 3581 3582(*************************************) 3583(* FLATTENING *) 3584(*************************************) 3585 3586(* 3587val do_lift = false 3588val use_guards = true 3589val rc_arg = ([], NONE) 3590val db = !thePmatchCompileDB 3591val col_heu = colHeu_default 3592 3593 3594val tm = ``case (x, y) of ([], x::xs) => ( 3595 case xs of [] => 0 | _ => 5) | (_, []) => 1 `` 3596 3597val tm = ``case (x, y) of (x::xs, []) => 2 | ([], x::xs) => ( 3598 SUC (case xs of [] => x | _ => HD xs)) | (_, []) => 1 `` 3599*) 3600 (* PMATCH_FLATTEN_CONV_GENCALL_AUX rc_arg db col_heu do_lift tm: tries the rows of the PMATCH term tm in order and flattens the first row for which this succeeds. The right-hand side of that row is turned into a PMATCH: it is lifted if do_lift is set and it is not a PMATCH already, and completed otherwise. The result is merged into the outer row list via PMATCH_FLATTEN_THM. The call fails, via tryfind, if no row can be flattened. *) 3601fun PMATCH_FLATTEN_CONV_GENCALL_AUX rc_arg db col_heu do_lift tm = let 3602 val (v, rows) = dest_PMATCH tm 3603 3604 (* Try to flatten row no i *) 3605 fun try_row i = let 3606 val (rows_b, row, rows_a) = extract_element rows i 3607 val (pt, gt, rt0) = dest_PMATCH_ROW row 3608 val (vs, rt) = pairSyntax.dest_pabs rt0 3609 3610 (* lift the rhs of row i to be PMATCH expression *) 3611 val thm0 = if do_lift andalso not (is_PMATCH rt) then 3612 PMATCH_LIFT_CONV_GENCALL rc_arg db col_heu rt 3613 else 3614 PMATCH_COMPLETE_CONV_GENCALL rc_arg db col_heu true rt handle UNCHANGED => REFL rt 3615 3616 (* extend the input to match the output of the outer PMATCH exactly *) 3617 val thm1 = let 3618 val thm1a = PMATCH_EXTEND_INPUT_CONV_GENCALL rc_arg vs (rhs (concl thm0)) 3619 val thm1 = TRANS thm0 thm1a 3620 in thm1 end 3621 3622 3623 (* Apply the flatten theorem, discard preconditions and show that rhs equals input *) 3624 val thm2 = let 3625 val rt' = rhs (concl thm1) 3626 val (v', rows') = dest_PMATCH rt' 3627 val rows'' = map (fn t => pairSyntax.mk_pabs (v', t)) rows' 3628 3629 3630 (* instantiate thm *) 3631 val thm2a = let 3632 val thm20 = FRESH_TY_VARS_RULE PMATCH_FLATTEN_THM 3633 val thm20 = ISPEC v thm20 3634 val thm20 = ISPEC pt thm20 3635 val thm20 = ISPEC gt thm20 3636 val thm20 = ISPEC (listSyntax.mk_list(rows_b, type_of row)) thm20 3637 val thm20 = ISPEC (listSyntax.mk_list(rows_a, type_of row)) thm20 3638 val thm20 = ISPEC
(listSyntax.mk_list(rows'', type_of (hd rows''))) thm20 3639 val thm21 = CONV_RULE (RATOR_CONV (RAND_CONV (RAND_CONV (pairTools.PABS_INTRO_CONV v')))) thm20 3640 val c = RATOR_CONV (RAND_CONV (RAND_CONV (pairTools.PABS_INTRO_CONV v'))) 3641 val thm22 = CONV_RULE (RAND_CONV (LHS_CONV (RAND_CONV (RAND_CONV c)))) thm21 3642 3643 in thm22 end 3644 3645 (* simplify MAP (\x. r x) rows'' = rows' *) 3646 val thm2b = let 3647 val map_tm = rand (snd (pairSyntax.dest_pforall (fst (dest_imp (concl thm2a))))) 3648 val map_tm_eq = mk_eq (map_tm, listSyntax.mk_list (rows', type_of (hd rows'))) 3649 val map_thm = prove (map_tm_eq, SIMP_TAC list_ss []) 3650 3651 val thm2b = CONV_RULE (DEPTH_CONV (REWR_CONV map_thm)) thm2a 3652 in thm2b end 3653 3654 (* elim precond: discharged with the exhaustiveness theorem of the inner PMATCH *) 3655 val thm2c = let 3656 val exh_thm = PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg rt' 3657 val (pre, _) = dest_imp (concl thm2b) 3658 val pre_thm = prove (pre, SIMP_TAC std_ss [exh_thm, GSYM pairTheory.PFORALL_THM]) 3659 val thm2c = MP thm2b pre_thm 3660 in thm2c end 3661 3662 (* use thm1 on lhs *) 3663 val thm2d = let 3664 val c = RATOR_CONV (RAND_CONV (RAND_CONV (PairRules.PABS_CONV (K (GSYM thm1))))) 3665 val thm20 = CONV_RULE (LHS_CONV (RAND_CONV (RAND_CONV c))) thm2c 3666 val l_eq = mk_eq (tm, lhs (concl thm20)) 3667 val l_thm = prove (l_eq, SIMP_TAC list_ss []) 3668 val thm2d = TRANS l_thm thm20 3669 in thm2d end 3670 in 3671 thm2d 3672 end 3673 3674 (* EVALUATE MAP PMATCH_FLATTEN_FUN on rhs *) 3675 val thm3 = let 3676 val flatten_thm = let 3677 val thm00 = FRESH_TY_VARS_RULE PMATCH_FLATTEN_FUN_PMATCH_ROW 3678 val thm01 = ISPEC pt thm00 3679 val thm02 = rc_elim_precond rc_arg thm01 3680 val thm03 = ISPEC gt thm02 3681 val c = pairTools.PABS_INTRO_CONV vs 3682 val thm04 = CONV_RULE (STRIP_QUANT_CONV (LHS_CONV (RAND_CONV c))) thm03 3683 in thm04 end 3684 3685 fun flatten_fun_conv tt = let 3686 val (_, row_d) = pairSyntax.dest_pabs (rand tt) 3687 val (pt_d, gt_d, rt_d) = dest_PMATCH_ROW row_d 3688 val thm00 =
ISPECL [pt_d, pairSyntax.mk_pabs(vs, gt_d), pairSyntax.mk_pabs(vs, rt_d)] flatten_thm 3689 val eq_tm = mk_eq (tt, lhs (concl thm00)) 3690 val eq_thm = prove (eq_tm, SIMP_TAC (std_ss++pairSimps.gen_beta_ss) []) 3691 3692 val thm01 = TRANS eq_thm thm00 3693 val (vs', _) = pairSyntax.dest_pabs pt_d 3694 val thm02 = CONV_RULE (RHS_CONV (PMATCH_ROW_PABS_INTRO_CONV vs')) thm01 3695 3696 val thm03 = CONV_RULE (RHS_CONV (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))) thm02 3697 val thm04 = CONV_RULE (RHS_CONV (REWRITE_CONV [])) thm03 3698 in thm04 end 3699 3700 val c = BETA_CONV THENC flatten_fun_conv 3701 val thm30 = CONV_RULE (RHS_CONV (RAND_CONV (RATOR_CONV ( 3702 RAND_CONV (RAND_CONV (listLib.MAP_CONV c)))))) thm2 3703 3704 val thm31 = CONV_RULE (RHS_CONV (RAND_CONV (RATOR_CONV (RAND_CONV 3705 listLib.APPEND_CONV)))) thm30 3706 3707 val thm32 = CONV_RULE (RHS_CONV (RAND_CONV 3708 listLib.APPEND_CONV)) thm31 3709 in thm32 end 3710 3711 (* Fix wildcards *) 3712 val thm4 = CONV_RULE (RHS_CONV PMATCH_INTRO_WILDCARDS_CONV) thm3 3713 in 3714 thm4 3715 end 3716 3717 val row_index_l = Lib.upto 0 (length rows - 1) 3718in 3719 tryfind try_row row_index_l 3720end 3721 3722 (* Flattens repeatedly (REPEATC) until no row can be flattened any more. *) 3723fun PMATCH_FLATTEN_CONV_GENCALL rc_arg db col_heu do_lift = 3724 REPEATC (PMATCH_FLATTEN_CONV_GENCALL_AUX rc_arg db col_heu do_lift) 3725 (* Instances and simpset fragment versions; the _ss variants wrap the conversion as a reducer named PMATCH_FLATTEN_REDUCER. *) 3726fun PMATCH_FLATTEN_CONV_GEN ssl = 3727 PMATCH_FLATTEN_CONV_GENCALL (ssl, NONE); 3728 3729fun PMATCH_FLATTEN_CONV do_lift = 3730 PMATCH_FLATTEN_CONV_GEN [] (!thePmatchCompileDB) colHeu_default do_lift; 3731 3732fun PMATCH_FLATTEN_GEN_ss ssl db col_heu do_lift = 3733 make_gen_conv_ss (fn rc_arg => PMATCH_FLATTEN_CONV_GENCALL rc_arg db col_heu do_lift) 3734 "PMATCH_FLATTEN_REDUCER" ssl 3735 3736fun PMATCH_FLATTEN_ss do_lift = 3737 PMATCH_FLATTEN_GEN_ss [] (!thePmatchCompileDB) colHeu_default do_lift; 3738 3739 3740(*************************************) 3741(* Analyse PMATCH expressions to *) 3742(* check whether they can be *) 3743(* translated to ML or OCAML
*) 3744(*************************************) 3745 (* Result of analyse_pmatch. Row numbers count from 0 in the order of the rows. *) 3746type pmatch_info = { 3747 pmi_is_well_formed : bool, (* set to true by analyse_pmatch; false in init_pmatch_info, which is also returned when the analysis fails *) 3748 pmi_ill_formed_rows : int list, (* rows reported via pmi_add_ill_formed_row *) 3749 pmi_has_free_pat_vars : (int * term list) list, (* per row: variables occurring in the pattern that the row does not bind *) 3750 pmi_has_unused_pat_vars : (int * term list) list, (* per row: bound variables that do not occur in the pattern *) 3751 pmi_has_double_bound_pat_vars : (int * term list) list, (* per row: bound variables occurring more than once in the pattern *) 3752 pmi_has_guards : int list, (* rows whose guard is not T *) 3753 pmi_has_lambda_in_pat : int list, (* rows whose pattern contains an abstraction *) 3754 pmi_has_non_contr_in_pat : (int * term list) list, (* per row: constants in the pattern that are neither constructors nor literals *) 3755 pmi_exhaustiveness_cond : thm option (* result of PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK, if it was computed *) 3756} 3757 (* True iff the stored exhaustiveness condition is an implication with the trivial antecedent ~F. *) 3758fun is_proven_exhaustive_pmatch (pmi : pmatch_info) = 3759 (case (#pmi_exhaustiveness_cond pmi) of 3760 NONE => false 3761 | SOME thm => let 3762 val (pre, _) = dest_imp_only (concl thm) 3763 in 3764 aconv pre ``~F`` 3765 end handle HOL_ERR _ => false 3766 ) 3767 (* Returns NONE if no exhaustiveness information is stored or it has an unexpected shape. Returns SOME [] if exhaustiveness is proven (antecedent ~F). Otherwise returns SOME of a list of (vars, pattern, guard) triples, one per disjunct d of the antecedent ~(d1 \/ ... \/ dn). *) 3768fun get_possibly_missing_patterns (pmi : pmatch_info) = 3769 (case (#pmi_exhaustiveness_cond pmi) of 3770 NONE => NONE 3771 | SOME thm => (let 3772 val (pre, _) = dest_imp_only (concl thm) 3773 in if aconv pre ``~F`` then SOME [] else 3774 let 3775 val ps = strip_disj (dest_neg pre) 3776 fun dest_p p = let 3777 val (_, vs, p, g) = dest_PMATCH_ROW_COND_EX_ABS p 3778 in (vs, p, g) end 3779 in 3780 SOME (List.map dest_p ps) 3781 end end) handle HOL_ERR _ => NONE 3782 ) 3783 (* Appends to the PMATCH term t one row with right-hand side ARB for each possibly missing pattern; pattern variables are built via mk_PMATCH_ROW_PABS_WILDCARDS. The guards of the missing patterns are kept only if t already has guarded rows; otherwise T is used. Returns t unchanged if nothing is missing, and fails if no exhaustiveness information is available. *) 3784fun extend_possibly_missing_patterns t (pmi : pmatch_info) = 3785 case get_possibly_missing_patterns pmi of 3786 NONE => failwith "no missing row info available" 3787 | SOME [] => t 3788 | SOME rs => let 3789 val use_guards = not (null (#pmi_has_guards pmi)) 3790 val arb_t = mk_arb (type_of t) 3791 fun mk_row (v, p, g) = let 3792 val vars = pairSyntax.strip_pair v 3793 in 3794 snd (mk_PMATCH_ROW_PABS_WILDCARDS vars (p, 3795 if use_guards then g else T, arb_t)) 3796 end 3797 val rows = List.map mk_row rs 3798 3799 val (i, rows_org) = dest_PMATCH t 3800 val rows_t = 3801 listSyntax.mk_list (rows_org @ rows, type_of (hd rows)) 3802 in 3803 mk_PMATCH i rows_t 3804 end; 3805 3806 (* Basic well-formedness: the analysis succeeded, and there are no ill-formed rows, no unused bound variables and no abstractions in patterns. *) 3807fun is_well_formed_pmatch (pmi : pmatch_info) = 3808
(#pmi_is_well_formed pmi) andalso 3809 (null (#pmi_ill_formed_rows pmi)) andalso 3810 (null (#pmi_has_unused_pat_vars pmi)) andalso 3811 (null (#pmi_has_lambda_in_pat pmi)); 3812 (* Well-formed, with patterns built only from constructors and literals, and with no free or repeated pattern variables. *) 3813fun is_ocaml_pmatch (pmi : pmatch_info) = 3814 (is_well_formed_pmatch pmi) andalso 3815 (null (#pmi_has_non_contr_in_pat pmi)) andalso 3816 (null (#pmi_has_free_pat_vars pmi)) andalso 3817 (null (#pmi_has_double_bound_pat_vars pmi)); 3818 (* As is_ocaml_pmatch, but additionally without guards. *) 3819fun is_sml_pmatch (pmi : pmatch_info) = 3820 (is_ocaml_pmatch pmi) andalso 3821 (null (#pmi_has_guards pmi)); 3822 (* Empty analysis result; also returned by analyse_pmatch when the analysis fails. *) 3823val init_pmatch_info : pmatch_info = { 3824 pmi_is_well_formed = false, 3825 pmi_ill_formed_rows = [], 3826 pmi_has_free_pat_vars = [], 3827 pmi_has_unused_pat_vars = [], 3828 pmi_has_double_bound_pat_vars = [], 3829 pmi_has_guards = [], 3830 pmi_has_lambda_in_pat = [], 3831 pmi_has_non_contr_in_pat = [], 3832 pmi_exhaustiveness_cond = NONE 3833} 3834 (* Functional record update: applies f1 ... f9 to the fields in declaration order. *) 3835fun pmi_genupdate f1 f2 f3 f4 f5 f6 f7 f8 f9 3836 (pmi : pmatch_info) = ({ 3837 pmi_is_well_formed = f1 (#pmi_is_well_formed pmi), 3838 pmi_ill_formed_rows = f2 (#pmi_ill_formed_rows pmi), 3839 pmi_has_free_pat_vars = f3 (#pmi_has_free_pat_vars pmi), 3840 pmi_has_unused_pat_vars = f4 (#pmi_has_unused_pat_vars pmi), 3841 pmi_has_double_bound_pat_vars = f5 (#pmi_has_double_bound_pat_vars pmi), 3842 pmi_has_guards = f6 (#pmi_has_guards pmi), 3843 pmi_has_lambda_in_pat = f7 (#pmi_has_lambda_in_pat pmi), 3844 pmi_has_non_contr_in_pat = f8 (#pmi_has_non_contr_in_pat pmi), 3845 pmi_exhaustiveness_cond = f9 (#pmi_exhaustiveness_cond pmi) 3846}:pmatch_info) 3847 (* Single-field setters and adders built on pmi_genupdate; the adders prepend to the respective list. *) 3848fun pmi_set_is_well_formed x = 3849 pmi_genupdate (K x) I I I I I I I I 3850 (* NOTE(review): this also sets pmi_is_well_formed to true (K true), not I; is_well_formed_pmatch still fails afterwards because the ill-formed row list is non-empty. Confirm this is intended. *) 3851fun pmi_add_ill_formed_row row_no = 3852 pmi_genupdate (K true) (cons row_no) I I I I I I I; 3853 3854fun pmi_add_has_free_pat_vars row_no vars = 3855 pmi_genupdate I I (cons (row_no, vars)) I I I I I I; 3856 3857fun pmi_add_has_unused_pat_vars row_no vars = 3858 pmi_genupdate I I I (cons (row_no, vars)) I I I I I; 3859 3860fun
pmi_add_has_double_bound_pat_vars row_no vars = 3861 pmi_genupdate I I I I (cons (row_no, vars)) I I I I; 3862 3863fun pmi_add_has_guards row_no = 3864 pmi_genupdate I I I I I (cons row_no) I I I; 3865 3866fun pmi_add_has_lambda_in_pat row_no = 3867 pmi_genupdate I I I I I I (cons row_no) I I; 3868 3869fun pmi_add_has_non_contr_in_pat row_no terms = 3870 pmi_genupdate I I I I I I I (cons (row_no, terms)) I; 3871 3872fun pmi_set_pmi_exhaustiveness_cond thm_opt = 3873 pmi_genupdate I I I I I I I I (K thm_opt); 3874 3875 3876local 3877 (* analyse_pat walks the pattern term t and threads an accumulator of four items: whether an abstraction was seen, the variables seen, the variables seen more than once, and the constants and literals seen. Literals are tested before the application case, so a literal is recorded as a whole rather than decomposed. *) 3878 fun analyse_pat (ls : bool (* has lambda been seen *), 3879 sv : term set (* set of all seen vars *), 3880 msv : term set (* set of all vars seen more than once *), 3881 sc : term set (* set of all seen constants *)) 3882 t = 3883 if is_var t then let 3884 val (sv, msv) = if HOLset.member (sv, t) then 3885 (sv, HOLset.add (msv, t)) 3886 else 3887 (HOLset.add (sv, t), msv) 3888 in (ls, sv, msv, sc) 3889 end else if (Literal.is_literal t orelse is_const t) then 3890 (ls, sv, msv, HOLset.add (sc,t)) 3891 else if (is_abs t) then 3892 (true, sv, msv, sc) 3893 else if (is_comb t) then let 3894 val (t1, t2) = dest_comb t 3895 val (ls, sv, msv, sc) = analyse_pat (ls, sv, msv, sc) t1 3896 val (ls, sv, msv, sc) = analyse_pat (ls, sv, msv, sc) t2 3897 in 3898 (ls, sv, msv, sc) 3899 end (* every HOL term is a variable, constant, application or abstraction, so the last branch is not expected to be reached *) 3900 else failwith "UNREACHABLE" 3901 3902 (* analyse_row ((row_num, row), pmi): records the findings for one numbered row in pmi *) 3903 fun analyse_row ((row_num, row),pmi) = let 3904 val (p_vars, p_body, g_body, _) = 3905 dest_PMATCH_ROW_ABS row 3906 3907 (* check guard *) 3908 val pmi = if aconv g_body T then pmi else 3909 pmi_add_has_guards row_num pmi 3910 3911 (* check pattern *) 3912 val vars_l = pairSyntax.strip_pair p_vars 3913 val vars = HOLset.fromList Term.compare vars_l 3914 val (ls, sv, msv, sc) = analyse_pat (false, 3915 HOLset.empty Term.compare, 3916 HOLset.empty Term.compare, 3917 HOLset.empty Term.compare) p_body 3918 3919 (* Take care of unit vars: a single bound variable of type one counts as used even if it does not occur in the pattern *) 3920 val sv = case vars_l of 3921 [v] => if type_of v = oneSyntax.one_ty then
HOLset.add (sv, v) else sv 3923 | _ => sv 3924 3925 (* check lambda *) 3926 val pmi = if ls then 3927 pmi_add_has_lambda_in_pat row_num pmi 3928 else pmi 3929 3930 (* check free_vars: seen in the pattern but not bound by the row *) 3931 val fv = HOLset.difference (sv, vars) 3932 val pmi = if HOLset.isEmpty fv then pmi else 3933 (pmi_add_has_free_pat_vars row_num 3934 (HOLset.listItems fv) pmi) 3935 3936 (* check unused vars: bound by the row but not seen in the pattern *) 3937 val uv = HOLset.difference (vars, sv) 3938 val pmi = if HOLset.isEmpty uv then pmi else 3939 (pmi_add_has_unused_pat_vars row_num 3940 (HOLset.listItems uv) pmi) 3941 3942 (* check double vars: bound variables seen more than once *) 3943 val dv = HOLset.intersection (msv, vars) 3944 val pmi = if HOLset.isEmpty dv then pmi else 3945 (pmi_add_has_double_bound_pat_vars row_num 3946 (HOLset.listItems dv) pmi) 3947 3948 (* check for constants that are neither constructors nor literals *) 3949 val c_l = HOLset.listItems sc 3950 val nc_l = List.filter (fn c => 3951 not (TypeBase.is_constructor c orelse Literal.is_literal c)) c_l 3952 val pmi = if null nc_l then pmi else 3953 (pmi_add_has_non_contr_in_pat row_num 3954 nc_l pmi) 3955 in 3956 pmi 3957 end 3958 3959in 3960 (* analyse_pmatch try_exh t: analyses the PMATCH term t row by row, numbering rows from 0. If try_exh is set and the result passes is_ocaml_pmatch, it also stores the result of PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK; a HOL_ERR from that check is ignored. It returns init_pmatch_info if t is not a PMATCH or a row cannot be analysed. *) 3961fun analyse_pmatch try_exh t = let 3962 val (_, rows) = dest_PMATCH t 3963 val nrows = enumerate 0 rows 3964 val pmi = pmi_set_is_well_formed true init_pmatch_info 3965 val pmi = List.foldl analyse_row pmi nrows 3966 3967 val pmi = (if (try_exh andalso is_ocaml_pmatch pmi) then 3968 pmi_set_pmi_exhaustiveness_cond (SOME (PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK t)) pmi else pmi) handle HOL_ERR _ => pmi 3969 3970in 3971 pmi 3972end handle HOL_ERR _ => init_pmatch_info 3973 3974end 3975 3976 3977end 3978