(*---------------------------------------------------------------------------
   Pmatch: translation of pattern-matching definitions into nested case
   expressions (with pattern completion), together with syntax operations
   for building, recognising and destructing case expressions.
 ---------------------------------------------------------------------------*)

structure Pmatch :> Pmatch =
struct

open HolKernel boolSyntax PmatchHeuristics;

(* A type base: looks up a type operator (by name and theory) and returns
   its case constant and constructor list, if the type is known. *)
type thry = {Tyop : string, Thy : string} ->
            {case_const : term, constructors : term list} option

val ERR = mk_HOL_ERR "Pmatch";

(* When false, mk_functional raises an error instead of accepting clauses
   added by pattern completion. *)
val allow_new_clauses = ref true;

(*---------------------------------------------------------------------------
      Miscellaneous support
 ---------------------------------------------------------------------------*)

(* gtake f (n, l) splits l after its first n elements and returns f applied
   to each of those elements, together with the rest of l.  Raises HOL_ERR
   if l runs out before n elements have been taken. *)
fun gtake f =
  let fun grab (0, rst) = ([], rst)
        | grab (n, x::rst) =
            let val (taken, left) = grab (n-1, rst)
            in (f x :: taken, left) end
        | grab _ = raise ERR "gtake" "grab.empty list"
  in grab
  end;

(* list_to_string f delim l renders every element of l with f and
   concatenates the results, separated by delim. *)
fun list_to_string f delim =
  let fun stringulate [] = []
        | stringulate [x] = [f x]
        | stringulate (h::t) = f h :: delim :: stringulate t
  in
    fn l => String.concat (stringulate l)
  end;

(* Renders an int list as a comma-separated string. *)
val stringize = list_to_string int_to_string ", ";

(* Pairs each element of l with its 0-based position, element first. *)
fun enumerate l = map (fn (x,y) => (y,x)) (Lib.enumerate 0 l);

(* The theory argument is currently ignored by both matchers. *)
fun match_term thry tm1 tm2 = Term.match_term tm1 tm2;
fun match_type thry ty1 ty2 = Type.match_type ty1 ty2;

(* Look up the type-base entry for the type names s. *)
fun match_info db s = db s

(*---------------------------------------------------------------------------
   vary vlist returns a generator of fresh variables.  Each call of the
   generator on a type ty yields a variable of type ty named v or v<n>
   (n increasing), chosen so that the name clashes neither with a name in
   vlist nor with a name already returned by this generator.  The counter
   is shared by all generators and reset whenever vary itself is called.
   Should probably be in somewhere like HolKernel.
 ---------------------------------------------------------------------------*)
local val counter = ref 0
in
fun vary vlist =
  let val slist = ref (map (fst o dest_var) vlist)
      val _ = counter := 0
      fun pass str =
          if Lib.mem str (!slist)
          then (counter := !counter + 1; pass ("v" ^ int_to_string (!counter)))
          else (slist := str :: !slist; str)
  in
    fn ty => mk_var (pass "v", ty)
  end
end;


(*---------------------------------------------------------------------------
 * This datatype carries some information about the origin of a
 * clause in a function definition: GIVEN (tm, i) comes from input row i,
 * while OMITTED marks a clause added by pattern completion (mk_omitted
 * gives it the index ~1).
 *---------------------------------------------------------------------------*)

datatype pattern = GIVEN   of term * int
                 | OMITTED of term * int

(* Orders GIVEN clauses by row number; an OMITTED argument is an error. *)
fun pattern_cmp (GIVEN (_,i)) (GIVEN (_,j)) = i <= j
  | pattern_cmp _ _ = raise ERR "pattern_cmp" ""

(* Apply a term substitution to the term carried by a clause. *)
fun psubst theta (GIVEN (tm,i))   = GIVEN (subst theta tm, i)
  | psubst theta (OMITTED (tm,i)) = OMITTED (subst theta tm, i);

(* Split a clause into its constructor paired with its index, and its term. *)
fun dest_pattern (GIVEN (tm,i))   = ((GIVEN,i),tm)
  | dest_pattern (OMITTED (tm,i)) = ((OMITTED,i),tm);

fun pat_of (GIVEN (tm,_))   = tm
  | pat_of (OMITTED (tm,_)) = tm

(* The input row number of a clause; ~1 for an OMITTED clause. *)
fun row_of_pat (GIVEN (_,i)) = i
  | row_of_pat (OMITTED _)   = ~1

fun dest_given (GIVEN (tm,_)) = tm
  | dest_given (OMITTED _)    = raise ERR "dest_given" ""

fun mk_omitted tm = OMITTED (tm,~1)

fun is_omitted (OMITTED _) = true
  | is_omitted _           = false;

(* The terms of the GIVEN clauses of a list, in order. *)
val givens = mapfilter dest_given;

(*---------------------------------------------------------------------------
 * Produce an instance of a constructor, plus genvars for its arguments.
 * fresh_constr ty_match colty gv c instantiates the type of c so that its
 * result type matches colty, and returns it with one fresh variable (made
 * by gv) per argument type.
 *---------------------------------------------------------------------------*)

fun fresh_constr ty_match (colty:hol_type) gv c =
  let val (argtys, rngty) = strip_fun (type_of c)
      val ty_theta = ty_match rngty colty
      val c' = inst ty_theta c
      val gvars = map (inst ty_theta o gv) argtys
  in (c', gvars)
  end;


(*---------------------------------------------------------------------------
 * Goes through a list of rows and picks out the ones beginning with a
 * pattern = literal, or all those beginning with a variable if the pattern
 * is a variable.  A matched row loses its first pattern, except when the
 * literal is a variable, in which case the row is kept whole.
 * Returns (in_group, not_in_group), both in the original row order.
 *---------------------------------------------------------------------------*)

fun mk_groupl literal rows =
  let fun func (row as ((prefix, p::rst), rhs)) (in_group, not_in_group) =
            if (is_var literal andalso is_var p) orelse aconv p literal
            then if is_var literal
                 then (((prefix, p::rst), rhs)::in_group, not_in_group)
                 else (((prefix, rst), rhs)::in_group, not_in_group)
            else (in_group, row::not_in_group)
        | func _ _ = raise ERR "mk_groupl" ""
  in
    itlist func rows ([],[])
  end;

(*---------------------------------------------------------------------------
 * Goes through a list of rows and picks out the ones beginning with a
 * pattern with constructor = c.  In a matched row the constructor
 * application is replaced by its arguments.
 *---------------------------------------------------------------------------*)

fun mk_group c rows =
  let fun func (row as ((prefix, p::rst), rhs)) (in_group, not_in_group) =
            let val (pc, args) = strip_comb p
            in if same_const pc c
               then (((prefix, args@rst), rhs)::in_group, not_in_group)
               else (in_group, row::not_in_group)
            end
        | func _ _ = raise ERR "mk_group" ""
  in
    itlist func rows ([],[])
  end;


(*---------------------------------------------------------------------------
 * Partition the rows among literals. Not efficient.
 * Returns one {constructor, new_formals, group} record per literal, in the
 * order of the literal list.  A variable literal is its own new formal.
 * A literal matched by no row gets a single row whose right-hand side is
 * an OMITTED ARB of type res_ty.  gv, ty_match and colty are unused.
 *---------------------------------------------------------------------------*)

fun partitionl _ _ (_,_,_,[]) = raise ERR "partitionl" "no rows"
  | partitionl gv ty_match
               (constructors, colty, res_ty, rows as (((prefix,_),_)::_)) =
    let fun part {constrs = [], rows, A} = rev A
          | part {constrs = c::crst, rows, A} =
            let val (in_group, not_in_group) = mk_groupl c rows
                val in_group' =
                    if null in_group  (* Constructor not given *)
                    then [((prefix, []), mk_omitted (mk_arb res_ty))]
                    else in_group
                val gvars = if is_var c then [c] else []
            in
              part {constrs = crst,
                    rows = not_in_group,
                    A = {constructor = c,
                         new_formals = gvars,
                         group = in_group'}::A}
            end
    in part {constrs = constructors, rows = rows, A = []}
    end;


(*---------------------------------------------------------------------------
 * Partition the rows. Not efficient.
 * Returns one {constructor, new_formals, group} record per constructor,
 * where constructor is instantiated to the column type colty and
 * new_formals are fresh variables for its arguments.  A constructor
 * matched by no row gets a single row of fresh variables whose right-hand
 * side is an OMITTED ARB of type res_ty.
 *---------------------------------------------------------------------------*)

fun partition _ _ (_,_,_,[]) = raise ERR "partition" "no rows"
  | partition gv ty_match
              (constructors, colty, res_ty, rows as (((prefix:term list,_),_)::_)) =
    let val fresh = fresh_constr ty_match colty gv
        fun part {constrs = [], rows, A} = rev A
          | part {constrs = c::crst, rows, A} =
            let val (c', gvars) = fresh c
                val (in_group, not_in_group) = mk_group c' rows
                val in_group' =
                    if null in_group  (* Constructor not given *)
                    then [((prefix, #2 (fresh c)), mk_omitted (mk_arb res_ty))]
                    else in_group
            in
              part {constrs = crst,
                    rows = not_in_group,
                    A = {constructor = c', new_formals = gvars,
                         group = in_group'}::A}
            end
    in part {constrs = constructors, rows = rows, A = []}
    end;


(*---------------------------------------------------------------------------
 * Misc. routines used in mk_case
 *---------------------------------------------------------------------------*)

(* Rebuild a literal column after a recursive call.  For each result row
   (prefix, tag, plist), a variable literal c keeps the first pattern of
   plist as it is, while any other literal c is pushed onto plist. *)
fun mk_patl c =
  let val L = if is_var c then 1 else 0
      fun build (prefix, tag, plist) =
          let val (args, plist') = gtake I (L, plist)
              val c' = if is_var c then hd args else c
          in (prefix, tag, c'::plist') end
  in map build
  end;

(* Rebuild a constructor column after a recursive call.  For each result
   row, pop as many patterns as c has arguments and push c applied to them. *)
fun mk_pat c =
  let val L = length (#1 (strip_fun (type_of c)))
      fun build (prefix, tag, plist) =
          let val (args, plist') = gtake I (L, plist)
          in (prefix, tag, list_mk_comb (c, args)::plist') end
  in map build
  end;


(* Move the first pattern of a row onto its prefix. *)
fun v_to_prefix (prefix, v::pats) = (v::prefix, pats)
  | v_to_prefix _ = raise ERR "mk_case" "v_to_prefix"

(* Inverse of v_to_prefix on result rows: move the prefix head back. *)
fun v_to_pats (v::prefix, tag, pats) = (prefix, tag, v::pats)
  | v_to_pats _ = raise ERR "mk_case" "v_to_pats";

(* --------------------------------------------------------------
   Literals include numeric, string, and character literals.
   Boolean literals are the constructors of the bool type.
   Currently, literals may be any expression without free vars.
   These functions are not used at the moment, but may be someday.
   -------------------------------------------------------------- *)

(*
val is_literal = Literal.is_literal

fun is_lit_or_var tm = is_literal tm orelse is_var tm

fun is_zero_emptystr_or_var tm =
    Literal.is_zero tm orelse Literal.is_emptystring tm orelse is_var tm
*)

(* True for variables and for terms with no free variables. *)
fun is_closed_or_var tm = is_var tm orelse null (free_vars tm)


(* ---------------------------------------------------------------------------
    Reconstructed code from TypeBasePure, to avoid circularity.
   ---------------------------------------------------------------------------*)

fun case_const_of {case_const : term, constructors : term list} = case_const
fun constructors_of {case_const : term, constructors : term list} = constructors

(* The operator name and theory of a type; Type.dest_thy_type raises
   HOL_ERR for type variables. *)
fun type_names ty =
  let val {Thy, Tyop, Args} = Type.dest_thy_type ty
  in {Thy = Thy, Tyop = Tyop}
  end;

(*---------------------------------------------------------------------------*)
(* Is a constant a constructor for some datatype.  True when the type base  *)
(* entry for c's result type lists c among its constructors; any HOL_ERR    *)
(* during the lookup yields false.                                          *)
(*---------------------------------------------------------------------------*)

fun is_constructor tybase c =
  let val (_, ty) = strip_fun (type_of c)
  in case tybase (type_names ty)
      of NONE => false
       | SOME tyinfo => op_mem same_const c (constructors_of tyinfo)
  end handle HOL_ERR _ => false;

(* Is the head of tm a constructor? *)
fun is_constructor_pat tybase tm =
    is_constructor tybase (fst (strip_comb tm));

fun is_constructor_var_pat ty_info tm =
    is_var tm orelse is_constructor_pat ty_info tm

(*---------------------------------------------------------------------------
   mk_switch_tm gv v base literals builds the literal switch
     \arg_1 ... arg_n v. literal_case (\v'. SW) v
   with one fresh argument (made by gv) per literal.  SW tests v' = lit
   for each proper literal in turn, selecting the corresponding arg; a
   variable literal ends the chain with arg v' (its arg then has type
   lty -> rty); if no variable literal occurs the chain ends in base.
   v' is the last literal, or a fresh variable when there are none.
 ---------------------------------------------------------------------------*)
fun mk_switch_tm gv v base literals =
    let val rty = type_of base
        val lty = type_of v
        val v' = last literals handle _ => gv lty
        fun mk_arg lit = if is_var lit then gv (lty --> rty) else gv rty
        val args = map mk_arg literals
        open boolSyntax
        fun mk_switch [] = base
          | mk_switch ((lit, arg)::litargs) =
              if is_var lit then mk_comb (arg, v')
              else mk_bool_case (arg, mk_switch litargs, mk_eq (v', lit))
        val switch = mk_switch (zip literals args)
    in list_mk_abs (args@[v], mk_literal_case (mk_abs (v', switch), v))
    end

(*---------------------------------------------------------------------------
   Helpers that apply a conversion (typically beta_conv) in the position
   left behind by instantiating a literal switch.  Failures of conv are
   ignored and leave the term unchanged.
   under_literal_case: apply conv under the abstraction of a literal_case,
   or to the whole term if it is not a literal_case.
   under_bool_case: follow the else-branches of nested bool_cases and
   apply conv to the final one.
 ---------------------------------------------------------------------------*)

fun under_literal_case conv tm =
    if is_literal_case tm then
      let val (f, e) = dest_literal_case tm
          val (x, bdy) = dest_abs f
          val bdy' = conv bdy handle HOL_ERR _ => bdy
      in mk_literal_case (mk_abs (x, bdy'), e)
      end
    else conv tm handle HOL_ERR _ => tm

fun under_bool_case conv tm =
    if is_bool_case tm then
      let val (t, f, tst) = dest_bool_case tm
          val f' = under_bool_case conv f
      in mk_bool_case (t, f', tst)
      end
    else conv tm handle HOL_ERR _ => tm

fun under_literal_bool_case conv tm =
    under_literal_case (under_bool_case conv) tm


(*----------------------------------------------------------------------------
    Translation of pattern terms into nested case expressions.

    This performs the translation and also builds the full set of patterns.
    Thus it supports the construction of induction theorems even when an
    incomplete set of patterns is given.
 ----------------------------------------------------------------------------*)

(* Move element n (counting from 0) of l to the front of the list. *)
fun bring_to_front_list n l =
  let val (l0, l1) = Lib.split_after n l
      val (x, l1') = (hd l1, tl l1)
  in x :: (l0 @ l1') end

(* Inverse of bring_to_front_list n: put the head back at position n. *)
fun undo_bring_to_front n l =
  let val (x, l') = (hd l, tl l)
      val (l0, l1) = Lib.split_after n l'
  in l0 @ x::l1 end

(*---------------------------------------------------------------------------
   mk_case0_heu heu ty_info ty_match FV range_ty returns the recursive
   pattern compiler mk.  mk {path, rows} takes
     path : the terms still to be matched, one per pattern column;
     rows : ((prefix, pats), rhs), pats being the patterns still to be
            matched against path and rhs a GIVEN/OMITTED right-hand side;
   and returns (rect, tree), where rect has one (prefix, (tag, index),
   pats) entry per leaf describing the completed pattern row, and tree is
   the nested case / literal-switch term of type range_ty.
   The heuristic heu chooses the column to split on (#col_fun), whether the
   rows after a row of variables are dropped (#skip_rows), and whether a
   case split whose branches all agree is collapsed (#collapse_cases).
   Fresh variables avoid the names in FV.
 ---------------------------------------------------------------------------*)
fun mk_case0_heu (heu : pmatch_heuristic) ty_info ty_match FV range_ty =
 let
  fun mk_case_fail s = raise ERR "mk_case" s
  val fresh_var = vary FV
  val dividel = partitionl fresh_var ty_match
  val divide = partition fresh_var ty_match

  (* Replace a row whose first pattern is a variable by one row per
     literal, substituting the literal for the variable in the rhs. *)
  fun expandl literals ty ((_,[]), _) = mk_case_fail "expandl_var_row"
    | expandl literals ty (row as ((prefix, p::rst), rhs)) =
        if is_var p
        then let fun expnd l = ((prefix, l::rst), psubst [p |-> l] rhs)
             in map expnd literals end
        else [row]

  (* Replace a row whose first pattern is a variable by one row per
     constructor, each applied to fresh argument variables. *)
  fun expand constructors ty ((_,[]), _) = mk_case_fail "expand_var_row"
    | expand constructors ty (row as ((prefix, p::rst), rhs)) =
        (if is_var p
         then let val fresh = fresh_constr ty_match ty fresh_var
                  fun expnd (c, gvs) =
                      let val capp = list_mk_comb (c, gvs)
                      in ((prefix, capp::rst), psubst [p |-> capp] rhs)
                      end
              in map expnd (map fresh constructors) end
         else [row])

  fun mk {rows = [], ...} = mk_case_fail "no rows"
    | mk {path = [], rows = ((prefix, []), rhs)::_} =  (* Done *)
        let val (tag, tm) = dest_pattern rhs
        in ([(prefix, tag, [])], tm)
        end
    | mk {path = [], rows = _::_} = mk_case_fail "blunder"
    | mk {path as u::rstp, rows as ((prefix, []), rhs)::rst} =
        (* pad an exhausted first row with a fresh variable *)
        mk {path = path,
            rows = ((prefix, [fresh_var (type_of u)]), rhs)::rst}
    | mk {path = rstp0, rows = rows0 as ((_, pL as (_ :: _)), _)::_} =
      if (#skip_rows heu) andalso length rows0 > 1 andalso all is_var pL
      then (* the first row matches everything: later rows are unreachable *)
        mk {path = rstp0, rows = [hd rows0]}
      else
      let
        (* move the column chosen by the heuristic to the front *)
        val col_index =
            (#col_fun heu) ty_info (map (fn ((_, pL), _) => pL) rows0)
        val u_rstp = bring_to_front_list col_index rstp0
        val (u, rstp) = (hd u_rstp, tl u_rstp)
        val rows =
            map (fn ((prefix, pL), rhs) =>
                    ((prefix, bring_to_front_list col_index pL), rhs))
                rows0
        val ((_, pL), _) = hd rows
        val p = hd pL
        val (pat_rectangle, rights) = unzip rows
        val col0 = map (Lib.trye hd o #2) pat_rectangle
        (* put the split column back in its original position *)
        fun restore_col rect =
            map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL))
                rect
      in
      if all is_var col0
      then (* variable column: bind each variable to u and continue *)
        let val rights' =
                map (fn (v, e) => psubst [v |-> u] e) (zip col0 rights)
            val pat_rectangle' = map v_to_prefix pat_rectangle
            val (pref_patl, tm) = mk {path = rstp,
                                      rows = zip pat_rectangle' rights'}
            val pat_rect1 = map v_to_pats pref_patl
        in (restore_col pat_rect1, tm)
        end
      else
      let val pty = type_of p
          val thy_tyop =
              type_names pty
              handle HOL_ERR _ =>
                raise ERR "mk_case0_heu"
                      ("Term "^Parse.term_to_string p^
                       " is a bad pattern (of var type?)")
      in
      if exists Literal.is_pure_literal col0 (* col0 has a literal *) then
        (* literal column: compile to a literal switch *)
        let val is_lit_col =
                all (fn t => Literal.is_literal t orelse is_var t) col0
            val _ = if is_lit_col then () else
                    mk_case_fail "case expression mixes literals with non-literals."
            val other_var = fresh_var pty
            (* the distinct literals of the column, then a catch-all variable *)
            val constructors =
                rev (op_mk_set aconv (rev (filter (not o is_var) col0))) @
                [other_var]
            val arb = mk_arb range_ty
            val switch_tm = mk_switch_tm fresh_var u arb constructors
            val nrows = flatten (map (expandl constructors pty) rows)
            val subproblems = dividel (constructors, pty, range_ty, nrows)
            val groups        = map #group subproblems
            and new_formals   = map #new_formals subproblems
            and constructors' = map #constructor subproblems
            val news = map (fn (nf, rows) => {path = nf@rstp, rows = rows})
                           (zip new_formals groups)
            val rec_calls = map mk news
            val (pat_rect, dtrees) = unzip rec_calls
            val case_functions = map list_mk_abs (zip new_formals dtrees)
            (* apply the switch to the branches and to u, then reduce the
               redexes left under the literal and bool cases *)
            val tree = List.foldl (fn (a, tm) => beta_conv (mk_comb (tm, a)))
                                  switch_tm (case_functions@[u])
            val tree' = under_literal_bool_case beta_conv tree
            val pat_rect1 = flatten (map2 mk_patl constructors' pat_rect)
        in
          (restore_col pat_rect1, tree')
        end
      else
        (* constructor column: compile to an application of the case constant *)
        case List.find (not o is_constructor_var_pat ty_info) col0 of
          NONE =>
            let
              (* Lib.with_exn re-raises a failure of Option.valOf as the
                 HOL_ERR given here, so no Option handler is needed *)
              val {case_const, constructors} =
                  Lib.with_exn Option.valOf (ty_info thy_tyop)
                    (ERR "mk_case0" ("could not get case constructors for type " ^
                                     Parse.type_to_string pty))
              val {Name = case_const_name, Thy, ...} = dest_thy_const case_const
              val nrows = flatten (map (expand constructors pty) rows)
              val subproblems = divide (constructors, pty, range_ty, nrows)
              val groups        = map #group subproblems
              and new_formals   = map #new_formals subproblems
              and constructors' = map #constructor subproblems
              val news = map (fn (nf, rows) => {path = nf@rstp, rows = rows})
                             (zip new_formals groups)
              val rec_calls = map mk news
              val (pat_rect, dtrees) = unzip rec_calls
              val tree =
                  if (#collapse_cases heu) andalso
                     List.all (aconv (hd dtrees)) (tl dtrees) andalso
                     List.all (fn (vL, tree) =>
                                  List.all (fn v => not (free_in v tree)) vL)
                              (zip new_formals dtrees)
                  then
                    (* If all cases lead to the same result, no split is necessary *)
                    hd dtrees
                  else
                    let
                      val case_functions =
                          map list_mk_abs (zip new_formals dtrees)
                      val types = map type_of (u::case_functions)
                      val case_const' =
                          mk_thy_const {Name = case_const_name, Thy = Thy,
                                        Ty = list_mk_fun (types, range_ty)}
                    in
                      list_mk_comb (case_const', u::case_functions)
                    end
              val pat_rect1 = flatten (map2 mk_pat constructors' pat_rect)
            in
              (restore_col pat_rect1, tree)
            end
        | SOME t => mk_case_fail ("Pattern "^
                                  trace ("Unicode", 0) Parse.term_to_string t^
                                  " is not a constructor or variable")
      end
      end
 in mk
 end;

(*---------------------------------------------------------------------------
   mk_case0 runs the compiler once for every heuristic supplied by the
   currently installed pmatch_heuristic and returns the best result
   according to its comparison function.  On ties the earlier result is
   kept.  Raises HOL_ERR if the heuristic supplies no heuristic at all.
 ---------------------------------------------------------------------------*)
fun mk_case0 ty_info ty_match FV range_ty rows =
let
  fun run_heu heu = mk_case0_heu heu ty_info ty_match FV range_ty rows

  val (min_fun0, heu_fun) = (!pmatch_heuristic) ()

  (* compare two results by the first components of their pattern rows
     and by their trees *)
  fun min_fun ((pL1, dt1), (pL2, dt2)) =
      min_fun0 ((map (fn (x, _, _) => x) pL1, dt1),
                (map (fn (x, _, _) => x) pL2, dt2))

  (* keep the earlier result unless the new one is strictly smaller *)
  fun res_min NONE res = res
    | res_min (SOME res1) res2 =
        (case min_fun (res1, res2) of GREATER => res2 | _ => res1)

  fun aux min =
      case heu_fun () of
          NONE =>
            (case min of
                 NONE => raise ERR "mk_case0"
                                 "SHOULD NOT HAPPEN! EMPTY PMATCH-HEURISTIC!"
               | SOME min' => min')
        | SOME heu =>
            let val res = run_heu heu
                val min' = res_min min res
            in aux (SOME min') end
in
  aux NONE
end

(*---------------------------------------------------------------------------
     Repeated variable occurrences in a pattern are not allowed.
 ---------------------------------------------------------------------------*)

(* Increment the count of k in the multiset d. *)
fun inc d k =
    case Binarymap.peek (d, k) of
        NONE => Binarymap.insert (d, k, 1)
      | SOME n => Binarymap.insert (d, k, n+1);

(* Count the free occurrences of every variable of tm.  Occurrences bound
   by an abstraction are not counted: the binder's count is restored once
   the body has been traversed. *)
fun FV_multiset tm =
  let
    datatype witem = TM of term | RESET of term * int
    fun recurse d wlist =
      case wlist of
          [] => d
        | TM tm :: rest =>
          (case dest_term tm of
               VAR _ => recurse (inc d tm) rest
             | CONST _ => recurse d rest
             | COMB (Rator, Rand) => recurse d (TM Rator :: TM Rand :: rest)
             | LAMB (Bvar, Body) =>
               let
                 val c0 = case Binarymap.peek (d, Bvar) of
                              NONE => 0
                            | SOME i => i
               in
                 recurse d (TM Body :: RESET (Bvar, c0) :: rest)
               end)
        | RESET (v, c) :: rest => recurse (Binarymap.insert (d, v, c)) rest
  in
    recurse (Binarymap.mkDict Term.compare) [TM tm]
  end

(* Returns true when no variable occurs free more than once in pat;
   otherwise raises HOL_ERR naming one of the repeated variables. *)
fun no_repeat_vars pat =
  let
    fun check d =
      let
        val repeats =
            Binarymap.foldl (fn (k, v, acc) => if v > 1 then k::acc else acc) [] d
      in
        if null repeats then true
        else
          raise ERR "no_repeat_vars"
                (quote (#1 (dest_var (hd repeats))) ^
                 " occurs repeatedly in the pattern " ^
                 quote (Hol_pp.term_to_string pat))
      end
  in
    check (FV_multiset pat)
  end;


(*---------------------------------------------------------------------------
      Routines to repair the bound variable names found in cases
 ---------------------------------------------------------------------------*)

(* Match pat against given_pat.  The match must need no type instantiation,
   and every residue must be a variable not among fvs; otherwise HOL_ERR. *)
fun pat_match1 fvs pat given_pat =
  let val (sub_tm, sub_ty) = Term.match_term pat given_pat
      val _ = if null sub_ty then ()
              else raise ERR "pat_match1" "no type substitution expected"
      fun is_valid_bound_var v =
          is_var v andalso not (List.exists (fn tm => aconv tm v) fvs)
      val _ = if List.all (fn m => is_valid_bound_var (#residue m)) sub_tm
              then ()
              else raise ERR "pat_match1" "expected a bound variable renaming"
  in sub_tm
  end

(* The renaming for the first pattern of pat_exps that pat_match1 accepts,
   or the empty substitution if there is none. *)
fun pat_match2 fvs pat_exps given_pat =
    tryfind ((C (pat_match1 fvs) given_pat) o fst) pat_exps
    handle HOL_ERR _ => ([]);

(* Turn a variable substitution into (redex, new name) pairs. *)
fun subst_to_renaming (s : (term, term) subst) : (term * string) list =
    map (fn m => (#redex m, fst (dest_var (#residue m)))) s;

(* Vary the residues so that they differ from fvs and from each other. *)
fun distinguish fvs pat_tm_mats =
    snd (List.foldr (fn ({redex, residue}, (vs, done)) =>
                       let val residue' = variant vs residue
                           val vs' = op_insert aconv residue' vs
                       in (vs', {redex = redex, residue = residue'} :: done)
                       end)
                    (fvs, []) pat_tm_mats)

(* Keep only one binding per redex. *)
fun reduce_mats pat_tm_mats =
    snd (List.foldl (fn (mat as {redex, residue}, (vs, done)) =>
                       if op_mem aconv redex vs then (vs, done)
                       else (redex :: vs, mat :: done))
                    ([], []) pat_tm_mats)

(* Drop bindings whose residue is a variable named with a leading
   underscore; bindings whose residue name cannot be inspected are also
   dropped. *)
fun purge_wildcards term_sub =
    filter (fn {redex, residue} =>
               not (String.sub (fst (dest_var residue), 0) = #"_")
               handle _ => false)
           term_sub

(* Renaming of the bound variables of the case patterns pat_exps towards
   the names used in the user-supplied patterns given_pats. *)
fun pat_match3 fvs pat_exps given_pats =
    (subst_to_renaming o distinguish fvs o reduce_mats o purge_wildcards o flatten)
      (map (pat_match2 fvs pat_exps) given_pats);


(*---------------------------------------------------------------------------*)
(* Syntax operations on the (extensible) set of case expressions.            *)
(*---------------------------------------------------------------------------*)

(* Constructor patterns: build case_const exp (\args_1. R_1) ... where each
   abstraction binds the arguments of the corresponding pattern.  The result
   type is taken from the first right-hand side. *)
fun mk_case1 tybase (exp, plist) =
  case match_info tybase (type_names (type_of exp))
   of NONE => raise ERR "mk_case" "unable to analyze type"
    | SOME tyinfo =>
       let val c = case_const_of tyinfo
           val fns = map (fn (p, R) => list_mk_abs (snd (strip_comb p), R)) plist
           val ty' = list_mk_fun (type_of exp::map type_of fns,
                                  type_of (snd (hd plist)))
           val theta = Type.match_type (type_of c) ty'
       in list_mk_comb (inst theta c, exp::fns)
       end

(* Literal patterns: a chain of bool_case tests v = p, whose last
   right-hand side is the default (its pattern is not tested).  The chain
   is wrapped in a literal_case on exp unless v is exp itself. *)
fun mk_case2 v (exp, plist) =
  let fun mk_switch [] = raise ERR "mk_case" "null patterns"
        | mk_switch [(p, R)] = R
        | mk_switch ((p, R)::rst) =
            mk_bool_case (R, mk_switch rst, mk_eq (v, p))
      val switch = mk_switch plist
  in if aconv v exp then switch
     else mk_literal_case (mk_abs (v, switch), exp)
  end;

(* Build a case expression over exp.  Uses the case constant when every
   pattern is a constructor pattern or variable (and not all are
   variables); otherwise a literal switch bound by the last pattern. *)
fun mk_case tybase (exp, plist) =
  let val col0 = map fst plist
  in if all (is_constructor_var_pat tybase) col0
        andalso not (all is_var col0)
     then (* constructor patterns *)
          mk_case1 tybase (exp, plist)
     else (* literal patterns *)
          mk_case2 (last col0) (exp, plist)
  end

(*---------------------------------------------------------------------------*)
(* dest_case destructs one level of pattern matching. To deal with nested    *)
(* patterns, use strip_case.                                                 *)
(*---------------------------------------------------------------------------*)

local
  (* Rebuild constructor pattern and body from one case function: peel one
     abstraction per constructor argument and instantiate the constructor
     to the scrutinee type ty. *)
  fun build_case_clause ((ty, constr), rhs) =
    let val (args, _) = strip_fun (type_of constr)
        fun peel [] N = ([], N)
          | peel (_::tys) N =
              let val (v, M) = dest_abs N
                  val (V, M') = peel tys M
              in (v::V, M')
              end
        val (V, rhs') = peel args rhs
        val theta = Type.match_type (type_of constr)
                                    (list_mk_fun (map type_of V, ty))
        val constr' = inst theta constr
    in
      (list_mk_comb (constr', V), rhs')
    end
in
(* Split case_const arg f_1 ... f_n into (case_const, arg,
   [(constructor pattern, body), ...]). *)
fun dest_case1 tybase M =
  let val (c, args) = strip_comb M
      val (cases, arg) =
          case args of h::t => (t, h)
                     | _ => raise ERR "dest_case" "case exp has too few args"
  in case match_info tybase (type_names (type_of arg))
      of NONE => raise ERR "dest_case" "unable to destruct case expression"
       | SOME tyinfo =>
          let val d = case_const_of tyinfo
          in if same_const c d
             then let val constrs = constructors_of tyinfo
                      val constrs_type = map (pair (type_of arg)) constrs
                  in (c, arg, map build_case_clause (zip constrs_type cases))
                  end
             else raise ERR "dest_case" "unable to destruct case expression"
          end
  end
end

(* A literal_case (\x. M') e yields (literal_case, e, [(x, M')]);
   anything else is handled by dest_case1. *)
fun dest_case tybase M =
  if is_literal_case M then
    let val (lcf, e) = dest_comb M
        val (lit_cs, f) = dest_comb lcf
        val (x, M') = dest_abs f
    in (lit_cs, e, [(x, M')])
    end
  else dest_case1 tybase M

(* Is M the case constant of its first argument's type, applied to exactly
   the number of arguments its type takes?  Any HOL_ERR gives false. *)
fun is_case1 tybase M =
  let val (c, args) = strip_comb M
      val (tynames as {Tyop = tyop, ...}) =
          type_names (type_of (hd args)) handle Empty => raise ERR "" ""
             (* will get caught later *)
  in
    case match_info tybase tynames of
        NONE => raise ERR "is_case" ("unknown type operator: "^Lib.quote tyop)
      | SOME tyinfo =>
          let
            val gconst = case_const_of tyinfo
            val gty = type_of gconst
            val argtys = fst (strip_fun gty)
          in
            same_const c gconst andalso length args = length argtys
          end
  end
  handle HOL_ERR _ => false;

fun is_case tybase M = is_literal_case M orelse is_case1 tybase M

(* True when no term of l1 is aconv to a term of l2. *)
fun tm_null_intersection l1 l2 =
  case (l1, l2) of
      ([], _) => true
    | (_, []) => true
    | (tm::tms, _) => not (op_mem aconv tm l2) andalso
                      tm_null_intersection tms l2

local
  (* Flatten a case in rhs under pattern pat into (instantiated pattern,
     rhs) pairs, provided its scrutinee is built from variables of pat;
     otherwise return [(pat, rhs)]. *)
  fun dest tybase (pat, rhs) =
    let val patvars = free_vars pat
    in if is_case tybase rhs then
         let
           val (case_tm, exp, clauses) = dest_case tybase rhs
           val (pats, rhsides) = unzip clauses
         in
           if is_eq exp then
             let
               val (v, e) = dest_eq exp
               val fvs = free_vars v
               (* val theta = fst (Term.match_term v e) handle HOL_ERR _ => [] *)
             in
               if null (op_set_diff aconv fvs patvars) andalso null (free_vars e)
                  andalso is_var v
                  (* andalso null_intersection fvs (free_vars (hd rhsides)) *)
               then flatten
                      (map (dest tybase) (zip [subst [v |-> e] pat, pat] rhsides))
               else [(pat, rhs)]
             end
           else
             let
               val fvs = free_vars exp
             in
               if null (op_set_diff aconv fvs patvars) andalso
                  tm_null_intersection fvs (free_varsl rhsides)
               then flatten
                      (map (dest tybase)
                           (zip (map (fn p =>
                                         subst (fst (Term.match_term exp p)) pat)
                                     pats)
                                rhsides))
               else [(pat, rhs)]
             end
             handle HOL_ERR _ => [(pat, rhs)] (* catch from match_term *)
         end
       else [(pat, rhs)]
    end
in
(* Returns the scrutinee of M and the flattened (pattern, rhs) pairs of
   its nested cases; (M, []) when M is not a case expression. *)
fun strip_case1 tybase M =
  (case total (dest_case tybase) M
    of NONE => (M, [])
     | SOME (case_tm, exp, cases) =>
         if is_eq exp
         then let val (v, e) = dest_eq exp
              in (v, flatten (map (dest tybase)
                                  (zip [e, v] (map snd cases))))
              end
         else (exp, flatten (map (dest tybase) cases)))
end;

(* Like strip_case1, but also looks through an outer literal_case. *)
fun strip_case tybase M =
  if is_literal_case M then
    let val (lcf, e) = dest_comb M
        val (lit_cs, f) = dest_comb lcf
        val (x, M') = dest_abs f
        val (_, cases) = if is_case1 tybase M'
                         then strip_case1 tybase M'
                         else (x, [(x, M')])
    in (e, cases)
    end
  else strip_case1 tybase M

(* Rename every lambda-bound variable v of cs that has an entry in ren
   (looked up with aconv) to the associated name; others are unchanged. *)
fun rename_top_bound_vars ren cs =
  case dest_term cs of
      VAR _ => cs
    | CONST _ => cs
    | COMB (t1, t2) =>
        mk_comb (rename_top_bound_vars ren t1, rename_top_bound_vars ren t2)
    | LAMB (v, t) =>
        let
          val cs' = rename_bvar (op_assoc aconv v ren) cs handle HOL_ERR _ => cs
          val (v', t') = dest_abs cs'
          val t'' = rename_top_bound_vars ren t'
        in
          mk_abs (v', t'')
        end;

(*---------------------------------------------------------------------------
   mk_functional thy eqs: eqs is a conjunction of (possibly universally
   quantified) equations f pat_i = rhs_i for a single function f.  Returns
   {functional, pats}: functional is \f a. case_tm, where case_tm matches
   a fresh variable a against the patterns; pats lists the accessible
   GIVEN clauses sorted by row, followed by the OMITTED clauses added by
   pattern completion.  If f is a constant it is replaced by a variable of
   the same name and type.
   Raises HOL_ERR if the function name is not unique, a pattern repeats a
   variable, or completion adds clauses while allow_new_clauses is false.
   Inaccessible input rows are reported and ignored.
 ---------------------------------------------------------------------------*)
local fun err s = raise ERR "mk_functional" s
      fun msg s = HOL_MESG ("mk_functional: "^s)
in
fun mk_functional thy eqs =
 let val clauses = strip_conj eqs
     val (L, R) = unzip (map (dest_eq o snd o strip_forall) clauses)
     val (funcs, pats) = unzip (map dest_comb L)
     val fs = Lib.op_mk_set aconv funcs
     val f0 = if length fs = 1 then hd fs else err "function name not unique"
     val f = if is_var f0 then f0 else mk_var (dest_const f0)
     val _ = map no_repeat_vars pats
     val rows = zip (map (fn x => ([]:term list, [x])) pats)
                    (map GIVEN (enumerate R))
     val avs = all_varsl (L@R)
     val a = variant avs (mk_var ("a", type_of (Lib.trye hd pats)))
     val FV = a::avs
     val range_ty = type_of (Lib.trye hd R)
     val (patts, case_tm) = mk_case0 (match_info thy) (match_type thy)
                                     FV range_ty {path = [a], rows = rows}
     fun func (_, (tag, i), [pat]) = tag (pat, i)
       | func _ = err "error in pattern-match translation"
     val patts1 = map func patts
     val (omits, givens) = Lib.partition is_omitted patts1
     val givens' = sort pattern_cmp givens
     val patts2 = givens' @ omits
     val finals = map row_of_pat patts2
     val originals = map (row_of_pat o #2) rows
     val new_rows = length finals - length originals
     val clause_s = if new_rows = 1 then " clause " else " clauses "
     val _ = if new_rows > 0
             then (msg ("\n pattern completion has added "^
                        Int.toString new_rows^clause_s^
                        "to the original specification.");
                   if !allow_new_clauses then ()
                   else
                     err ("new clauses not allowed under current setting of "^
                          Lib.quote ("Functional.allow_new_clauses")^" flag"))
             else ()
     fun int_eq i1 (i2:int) = (i1 = i2)
     val inaccessibles = filter (fn x => not (op_mem int_eq x finals)) originals
     fun accessible p = not (op_mem int_eq (row_of_pat p) inaccessibles)
     val patts3 = (case inaccessibles of [] => patts2
                                      | _  => filter accessible patts2)
     val _ = case inaccessibles of
                 [] => ()
               | _ => msg ("The following input rows (counting from zero) are\
                           \ inaccessible: "^stringize inaccessibles^
                           ".\nThey have been ignored.")
     (* The next lines repair bound variable names in the nested case term. *)
     val case_tm' =
         let val (_, pat_exps) = strip_case thy case_tm
             val fvs = free_vars case_tm
             val ren = pat_match3 fvs pat_exps pats (* better pats than givens patts3 *)
         in rename_top_bound_vars ren case_tm
         end handle HOL_ERR _ =>
           (Feedback.HOL_WARNING "Pmatch" "mk_functional"
                                 "SHOULD NOT HAPPEN! RENAMING CASE_TM FAILED!";
            case_tm)
     (* Ensure that the case test variable is fresh for the rest of the case *)
     val avs = op_set_diff aconv (all_vars case_tm') [a]
     val a' = variant avs a
     val case_tm'' = if aconv a' a then case_tm'
                     else subst ([a |-> a']) case_tm'
 in
   {functional = list_mk_abs ([f, a'], case_tm''),
    pats = patts3}
 end
end;

(*---------------------------------------------------------------------------
   Given a list of (pattern,expression) pairs, mk_pattern_fn creates a term
   as an abstraction containing a case expression on the function's argument.
   Raises HOL_ERR if the list is empty or if the patterns (or the
   expressions) do not all have the same type.
 ---------------------------------------------------------------------------*)

fun mk_pattern_fn thy (pes: (term * term) list) =
  let fun err s = raise ERR "mk_pattern_fn" s
      val (p0, e0) = Lib.trye hd pes
          handle HOL_ERR _ => err "empty list of (pattern,expression) pairs"
      val pty = type_of p0 and ety = type_of e0
      val (ps, es) = unzip pes
      val _ = if all (Lib.equal pty o type_of) ps then ()
              else err "patterns have varying types"
      val _ = if all (Lib.equal ety o type_of) es then ()
              else err "expressions have varying types"
      val fvar = genvar (pty --> ety)
      val eqs = list_mk_conj (map (fn (p, e) => mk_eq (mk_comb (fvar, p), e)) pes)
      val {functional, pats} = mk_functional thy eqs
  in
    (* functional is \f a. case_tm; drop the binder for f *)
    snd (dest_abs functional)
  end

end;