structure Pmatch :> Pmatch =
struct

(*---------------------------------------------------------------------------
   Pattern-match compilation for HOL definitions.

   mk_functional translates a conjunction of clauses  f p_i = rhs_i  into a
   lambda term whose body is a nested case expression on a fresh argument,
   adding clauses for any patterns the input did not cover.  mk_pattern_fn
   does the same for a list of (pattern, expression) pairs.  The other
   operations build (mk_case) and take apart (dest_case, strip_case,
   is_case) case expressions, including literal_case / bool_case switches
   used for literal patterns.
 ---------------------------------------------------------------------------*)

open HolKernel boolSyntax PmatchHeuristics;

(* A type database: given the name of a type operator and its theory,
   return its case constant and constructors, or NONE if it is unknown. *)
type thry = {Tyop : string, Thy : string} ->
            {case_const : term, constructors : term list} option

val ERR = mk_HOL_ERR "Pmatch";

(* When false, mk_functional raises HOL_ERR if pattern completion had to add
   clauses that were not present in the input. *)
val allow_new_clauses = ref true;

(*---------------------------------------------------------------------------
      Miscellaneous support
 ---------------------------------------------------------------------------*)

(* gtake f (n, l) splits off the first n elements of l, maps f over them and
   returns (mapped_prefix, rest).  Raises HOL_ERR if l is too short. *)
fun gtake f =
  let fun grab (0, rst) = ([], rst)
        | grab (n, x::rst) =
            let val (taken, left) = grab (n-1, rst)
            in (f x :: taken, left) end
        | grab _ = raise ERR "gtake" "grab.empty list"
  in grab
  end;

(* list_to_string f delim l renders each element of l with f and puts delim
   between consecutive renderings. *)
fun list_to_string f delim =
  let fun stringulate [] = []
        | stringulate [x] = [f x]
        | stringulate (h::t) = f h :: delim :: stringulate t
  in
    fn l => String.concat (stringulate l)
  end;

(* Comma-separated rendering of an int list (used in diagnostics). *)
val stringize = list_to_string int_to_string ", ";

(* Pair each element of l with its position (starting at 0), element first. *)
fun enumerate l = map (fn (x,y) => (y,x)) (Lib.enumerate 0 l);

(* Term / type matching; the thry argument is ignored. *)
fun match_term thry tm1 tm2 = Term.match_term tm1 tm2;
fun match_type thry ty1 ty2 = Type.match_type ty1 ty2;

(* Look up the type information for type name s in database db. *)
fun match_info db s = db s

(*---------------------------------------------------------------------------
   vary vlist returns a generator of fresh variables.  Applied to a type ty,
   the generator returns a variable of type ty.  It first tries the name v
   and then v1, v2, ... (the numeric suffix only increases over the
   generator's lifetime).  It returns the first name that is neither the
   name of a variable in vlist nor a name this generator already returned.
   Each generator owns its counter, so generators do not disturb each other.
   (This should probably be somewhere like HolKernel.)
 ---------------------------------------------------------------------------*)
fun vary vlist =
  let val slist = ref (map (fst o dest_var) vlist)
      val counter = ref 0
      fun pass str =
        if Lib.mem str (!slist)
        then (counter := !counter + 1; pass ("v" ^ int_to_string (!counter)))
        else (slist := str :: !slist; str)
  in
    fn ty => mk_var (pass "v", ty)
  end;


(*---------------------------------------------------------------------------
 * This datatype carries some information about the origin of a clause in a
 * function definition.  GIVEN clauses come from the input, together with
 * their input row number.  OMITTED clauses are added by pattern completion
 * (mk_omitted gives them row number ~1).
 *---------------------------------------------------------------------------*)

datatype pattern = GIVEN   of term * int
                 | OMITTED of term * int

(* A <= ordering of GIVEN patterns by row number (as needed by sort in
   mk_functional).  OMITTED patterns are not comparable: HOL_ERR. *)
fun pattern_cmp (GIVEN(_,i)) (GIVEN(_,j)) = i <= j
  | pattern_cmp _ _ = raise ERR "pattern_cmp" ""

(* Apply a term substitution to the term carried by a pattern. *)
fun psubst theta (GIVEN (tm,i)) = GIVEN(subst theta tm, i)
  | psubst theta (OMITTED (tm,i)) = OMITTED(subst theta tm, i);

(* Split a pattern into ((constructor, row number), term). *)
fun dest_pattern (GIVEN (tm,i)) = ((GIVEN,i),tm)
  | dest_pattern (OMITTED (tm,i)) = ((OMITTED,i),tm);

fun pat_of (GIVEN (tm,_)) = tm
  | pat_of (OMITTED(tm,_)) = tm

(* Row number of a pattern; every OMITTED pattern reports ~1. *)
fun row_of_pat (GIVEN(_, i)) = i
  | row_of_pat (OMITTED _) = ~1

fun dest_given (GIVEN(tm,_)) = tm
  | dest_given (OMITTED _) = raise ERR "dest_given" ""

fun mk_omitted tm = OMITTED(tm,~1)

fun is_omitted (OMITTED _) = true
  | is_omitted _ = false;

(* Terms of the GIVEN patterns in a list; OMITTED ones are skipped. *)
val givens = mapfilter dest_given;

(*---------------------------------------------------------------------------
 * Produce an instance of constructor c whose range is the column type
 * colty (ty_match ty colty supplies the type instantiation).  Also produce
 * fresh variables (from gv) for its argument types.  Returns (c', vars).
 *---------------------------------------------------------------------------*)

fun fresh_constr ty_match (colty:hol_type) gv c =
  let val Ty = type_of c
      val (L,ty) = strip_fun Ty
      val ty_theta = ty_match ty colty
      val c' = inst ty_theta c
      val gvars = map (inst ty_theta o gv) L
  in (c', gvars)
  end;


(*---------------------------------------------------------------------------*
 * Goes through a list of rows and picks out the ones beginning with a       *
 * pattern = literal, or all those beginning with a variable if literal is   *
 * a variable.  A matched literal is removed from the front of its row; a    *
 * matched variable is kept.  Returns (in_group, not_in_group), both in the  *
 * original row order.                                                       *
 *---------------------------------------------------------------------------*)

fun mk_groupl literal rows =
  let fun func (row as ((prefix, p::rst), rhs)) (in_group,not_in_group) =
            if (is_var literal andalso is_var p) orelse p = literal
            then if is_var literal
                 then (((prefix,p::rst), rhs)::in_group, not_in_group)
                 else (((prefix,rst), rhs)::in_group, not_in_group)
            else (in_group, row::not_in_group)
        | func _ _ = raise ERR "mk_groupl" ""
  in
    itlist func rows ([],[])
  end;

(*---------------------------------------------------------------------------*
 * Goes through a list of rows and picks out the ones beginning with a       *
 * pattern with constructor = c.  In a matched row the constructor pattern   *
 * is replaced by its arguments.  Returns (in_group, not_in_group).          *
 *---------------------------------------------------------------------------*)

fun mk_group c rows =
  let fun func (row as ((prefix, p::rst), rhs)) (in_group,not_in_group) =
            let val (pc,args) = strip_comb p
            in if same_const pc c
               then (((prefix,args@rst), rhs)::in_group, not_in_group)
               else (in_group, row::not_in_group)
            end
        | func _ _ = raise ERR "mk_group" ""
  in
    itlist func rows ([],[])
  end;


(*---------------------------------------------------------------------------*
 * Partition the rows among literals.  Not efficient.                        *
 * Returns one {constructor, new_formals, group} record per literal, in the  *
 * order of the literal list.  Rows claimed by one literal are not offered   *
 * to later ones.  A literal that claims no row gets a single OMITTED row    *
 * whose right-hand side is ARB : res_ty.  new_formals is [c] for a variable *
 * literal c, and [] otherwise.                                              *
 *---------------------------------------------------------------------------*)

fun partitionl _ _ (_,_,_,[]) = raise ERR"partitionl" "no rows"
  | partitionl gv ty_match
       (constructors, colty, res_ty, rows as (((prefix,_),_)::_)) =
    let fun part {constrs = [],      rows, A} = rev A
          | part {constrs = c::crst, rows, A} =
            let val (in_group, not_in_group) = mk_groupl c rows
                val in_group' =
                    if (null in_group)  (* Constructor not given *)
                    then [((prefix, []), mk_omitted (mk_arb res_ty))]
                    else in_group
                val gvars = if is_var c then [c] else []
            in
              part{constrs = crst,
                   rows = not_in_group,
                   A = {constructor = c,
                        new_formals = gvars,
                        group = in_group'}::A}
            end
    in part{constrs=constructors, rows=rows, A=[]}
    end;


(*---------------------------------------------------------------------------*
 * Partition the rows among datatype constructors.  Not efficient.           *
 * Like partitionl, but each constructor is first instantiated to the column *
 * type with fresh argument variables (new_formals).  For a constructor that *
 * claims no row, the OMITTED row's pattern columns come from a second call  *
 * of fresh, so they are distinct from new_formals.                          *
 *---------------------------------------------------------------------------*)

fun partition _ _ (_,_,_,[]) = raise ERR"partition" "no rows"
  | partition gv ty_match
       (constructors, colty, res_ty, rows as (((prefix:term list,_),_)::_)) =
    let val fresh = fresh_constr ty_match colty gv
        fun part {constrs = [],      rows, A} = rev A
          | part {constrs = c::crst, rows, A} =
            let val (c',gvars) = fresh c
                val (in_group, not_in_group) = mk_group c' rows
                val in_group' =
                    if (null in_group)  (* Constructor not given *)
                    then [((prefix, #2(fresh c)), mk_omitted (mk_arb res_ty))]
                    else in_group
            in
              part{constrs = crst,
                   rows = not_in_group,
                   A = {constructor = c', new_formals = gvars,
                        group = in_group'}::A}
            end
    in part{constrs=constructors, rows=rows, A=[]}
    end;


(*---------------------------------------------------------------------------
 * Misc. routines used in mk_case
 *---------------------------------------------------------------------------*)

(* After a recursive call on a literal's group, restore the literal column of
   every returned (prefix, tag, pats) row.  A non-variable literal c is put
   back in front.  For a variable c, the pattern that stood for it is kept. *)
fun mk_patl c =
  let val L = if is_var c then 1 else 0
      fun build (prefix,tag,plist) =
          let val (args,plist') = gtake I (L, plist)
              val c' = if is_var c then hd args else c
          in (prefix,tag, c'::plist') end
  in map build
  end;

(* After a recursive call on a constructor's group, fold the leading argument
   patterns of every returned row back into c applied to them. *)
fun mk_pat c =
  let val L = length(#1(strip_fun(type_of c)))
      fun build (prefix,tag,plist) =
          let val (args,plist') = gtake I (L, plist)
          in (prefix,tag,list_mk_comb(c,args)::plist') end
  in map build
  end;


(* Move the first pattern of a row onto its prefix, and back again. *)
fun v_to_prefix (prefix, v::pats) = (v::prefix,pats)
  | v_to_prefix _ = raise ERR "mk_case" "v_to_prefix"

fun v_to_pats (v::prefix,tag, pats) = (prefix, tag, v::pats)
  | v_to_pats _ = raise ERR "mk_case""v_to_pats";

(* --------------------------------------------------------------
   Literals include numeric, string, and character literals.
   Boolean literals are the constructors of the bool type.
   Currently, literals may be any expression without free vars.
   These functions are not used at the moment, but may be someday.
   -------------------------------------------------------------- *)

(*
val is_literal = Literal.is_literal

fun is_lit_or_var tm = is_literal tm orelse is_var tm

fun is_zero_emptystr_or_var tm =
   Literal.is_zero tm orelse Literal.is_emptystring tm orelse is_var tm
*)

(* True iff tm is a variable or has no free variables. *)
fun is_closed_or_var tm = is_var tm orelse null (free_vars tm)


(* ---------------------------------------------------------------------------
    Reconstructed code from TypeBasePure, to avoid circularity.
   ---------------------------------------------------------------------------*)

fun case_const_of {case_const : term, constructors : term list} = case_const
fun constructors_of {case_const : term, constructors : term list} = constructors

(* The {Thy, Tyop} name of a type; raises HOL_ERR if dest_thy_type fails. *)
fun type_names ty =
  let val {Thy,Tyop,Args} = Type.dest_thy_type ty
  in {Thy=Thy,Tyop=Tyop}
  end;

(*---------------------------------------------------------------------------*)
(* Is a constant a constructor for some datatype?  True iff the range type   *)
(* of c is known to tybase and c is one of its constructors.  Any HOL_ERR    *)
(* raised along the way makes the answer false.                              *)
(*---------------------------------------------------------------------------*)

fun is_constructor tybase c =
  let val (_,ty) = strip_fun (type_of c)
  in case tybase (type_names ty)
      of NONE => false
       | SOME tyinfo => op_mem same_const c (constructors_of tyinfo)
  end handle HOL_ERR _ => false;

(* Is the head of tm a datatype constructor? *)
fun is_constructor_pat tybase tm =
    is_constructor tybase (fst (strip_comb tm));

(* Is tm a variable or a constructor-headed pattern? *)
fun is_constructor_var_pat ty_info tm =
    is_var tm orelse is_constructor_pat ty_info tm

(*---------------------------------------------------------------------------
   mk_switch_tm gv v base literals builds the template for a literal switch.
   It is an abstraction over one fresh argument per literal, then over v,
   whose body is  literal_case (\v'. switch) v.  switch is a chain of
   bool_case terms testing  v' = lit  for each non-variable literal, in
   order.  A variable literal applies its argument to v'.  If the chain runs
   out, it ends in base.  v' is the last literal (normally the catch-all
   variable), or a fresh variable if the list is empty.  A variable literal
   gets an argument of type (type of v) -> (type of base); the others get
   the type of base.
 ---------------------------------------------------------------------------*)
fun mk_switch_tm gv v base literals =
  let val rty = type_of base
      val lty = type_of v
      val v' = last literals handle _ => gv lty
      fun mk_arg lit = if is_var lit then gv (lty --> rty) else gv rty
      val args = map mk_arg literals
      open boolSyntax
      fun mk_switch [] = base
        | mk_switch ((lit,arg)::litargs) =
            if is_var lit then mk_comb(arg, v')
            else mk_bool_case(arg, mk_switch litargs, mk_eq(v', lit))
      val switch = mk_switch (zip literals args)
  in list_mk_abs(args@[v], mk_literal_case (mk_abs(v',switch), v))
  end

(* These functions repair a final beta_conv for literal switches.
   under_literal_case applies conv to the body of a literal_case's
   abstraction, or to tm itself if tm is not a literal_case.
   under_bool_case walks down the else-branches of nested bool_case terms
   and applies conv at the end of that spine.  In both, a conv that raises
   HOL_ERR leaves its argument unchanged. *)

fun under_literal_case conv tm =
  if is_literal_case tm then
    let val (f,e) = dest_literal_case tm
        val (x,bdy) = dest_abs f
        val bdy' = conv bdy handle HOL_ERR _ => bdy
    in mk_literal_case (mk_abs(x, bdy'), e)
    end
  else conv tm handle HOL_ERR _ => tm

fun under_bool_case conv tm =
  if is_bool_case tm then
    let val (t,f,tst) = dest_bool_case tm
        val f' = under_bool_case conv f
    in mk_bool_case (t,f',tst)
    end
  else conv tm handle HOL_ERR _ => tm

fun under_literal_bool_case conv tm =
  under_literal_case (under_bool_case conv) tm


(*----------------------------------------------------------------------------
      Translation of pattern terms into nested case expressions.

    This performs the translation and also builds the full set of patterns.
    Thus it supports the construction of induction theorems even when an
    incomplete set of patterns is given.
 ----------------------------------------------------------------------------*)

(* Move the element at (0-based) index n of l to the front. *)
fun bring_to_front_list n l = let
    val (l0, l1) = Lib.split_after n l
    val (x, l1') = (hd l1, tl l1)
  in x :: (l0 @ l1') end

(* Inverse of bring_to_front_list n: put the head back at index n. *)
fun undo_bring_to_front n l = let
    val (x, l') = (hd l, tl l)
    val (l0, l1) = Lib.split_after n l'
  in (l0 @ x::l1) end

(*---------------------------------------------------------------------------
   mk_case0_heu heu ty_info ty_match FV range_ty {path, rows}

   Compiles a pattern matrix into a case term, using heuristic heu to pick
   the column to split on (col_fun), whether to skip rows after a row of
   variables (skip_rows), and whether to collapse case splits whose branches
   are all the same (collapse_cases).

   path holds the terms being matched, one per pattern column.  Each row is
   ((prefix, pats), rhs), where rhs is a GIVEN/OMITTED pattern value holding
   the right-hand side.  New variables are fresh with respect to FV.

   Returns (rows', tree).  rows' is the completed list of rows, as
   (prefix, (GIVEN-or-OMITTED, row number), pats) triples, and tree is the
   resulting case term of type range_ty.  Failures raise HOL_ERR.
 ---------------------------------------------------------------------------*)
fun mk_case0_heu (heu : pmatch_heuristic) ty_info ty_match FV range_ty =
 let
 fun mk_case_fail s = raise ERR "mk_case" s
 val fresh_var = vary FV
 val dividel = partitionl fresh_var ty_match
 val divide = partition fresh_var ty_match
 (* Replace a row whose first pattern is a variable by one row per literal,
    instantiating the variable in the right-hand side. *)
 fun expandl literals ty ((_,[]), _) = mk_case_fail "expandl_var_row"
   | expandl literals ty (row as ((prefix, p::rst), rhs)) =
       if is_var p
       then let fun expnd l =
                      ((prefix, l::rst), psubst[p |-> l] rhs)
            in map expnd literals end
       else [row]
 (* Replace a row whose first pattern is a variable by one row per
    constructor (applied to fresh variables), instantiating the variable in
    the right-hand side. *)
 fun expand constructors ty ((_,[]), _) = mk_case_fail "expand_var_row"
   | expand constructors ty (row as ((prefix, p::rst), rhs)) =
       (if is_var p
        then let val fresh = fresh_constr ty_match ty fresh_var
                 fun expnd (c,gvs) =
                   let val capp = list_mk_comb(c,gvs)
                   in ((prefix, capp::rst), psubst[p |-> capp] rhs)
                   end
             in map expnd (map fresh constructors) end
        else [row])
 fun mk{rows=[],...} = mk_case_fail "no rows"
   | mk{path=[], rows = ((prefix, []), rhs)::_} =  (* Done: first row wins *)
        let val (tag,tm) = dest_pattern rhs
        in ([(prefix,tag,[])], tm)
        end
   | mk{path=[], rows = _::_} = mk_case_fail "blunder"
   | mk{path as u::rstp, rows as ((prefix, []), rhs)::rst} =
        (* first row has no pattern for column u: give it a fresh variable *)
        mk{path = path,
           rows = ((prefix, [fresh_var(type_of u)]), rhs)::rst}
   | mk{path = rstp0, rows = rows0 as ((_, pL as (_ :: _)), _)::_} =
     if ((#skip_rows heu) andalso length rows0 > 1 andalso all is_var pL)
     then (* first row matches everything: later rows are not needed *)
          mk {path = rstp0, rows = [hd rows0]}
     else
     let val col_index = (#col_fun heu) ty_info (map (fn ((_, pL), _) => pL) rows0)
         (* work on column col_index by moving it to the front everywhere *)
         val u_rstp = bring_to_front_list col_index rstp0
         val (u, rstp) = (hd u_rstp, tl u_rstp)
         val rows = map (fn ((prefix, pL), rhs) =>
                           ((prefix, bring_to_front_list col_index pL), rhs)) rows0
         val ((_, pL), _) = hd rows
         val p = hd pL
         val (pat_rectangle,rights) = unzip rows
         val col0 = map(Lib.trye hd o #2) pat_rectangle
     in
     if all is_var col0
     then (* variable column: bind each variable to u and move on *)
          let val rights' = map(fn(v,e) => psubst[v|->u] e) (zip col0 rights)
              val pat_rectangle' = map v_to_prefix pat_rectangle
              val (pref_patl,tm) = mk{path = rstp,
                                      rows = zip pat_rectangle' rights'}
              val pat_rect1 = map v_to_pats pref_patl
              val pat_rect1' = map (fn (x, y, pL) =>
                                      (x, y, undo_bring_to_front col_index pL)) pat_rect1
          in (pat_rect1', tm)
          end
     else
     let val pty = type_of p
         val thy_tyop =
             type_names pty
             handle HOL_ERR _ =>
             raise ERR "mk_case0_heu"
                   ("Term "^Parse.term_to_string p^
                    " is a bad pattern (of var type?)")
     in
     if exists Literal.is_pure_literal col0 (* col0 has a literal *) then
       (* literal column: build a literal_case / bool_case switch *)
       let val is_lit_col = all (fn t => Literal.is_literal t orelse is_var t) col0
           val _ = if is_lit_col then () else
                   mk_case_fail "case expression mixes literals with non-literals."
           val other_var = fresh_var pty
           (* distinct literals in order of first occurrence, then catch-all *)
           val constructors = rev (mk_set (rev (filter (not o is_var) col0)))
                              @ [other_var]
           val arb = mk_arb range_ty
           val switch_tm = mk_switch_tm fresh_var u arb constructors
           val nrows = flatten (map (expandl constructors pty) rows)
           val subproblems = dividel(constructors, pty, range_ty, nrows)
           val groups        = map #group subproblems
           and new_formals   = map #new_formals subproblems
           and constructors' = map #constructor subproblems
           val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
                          (zip new_formals groups)
           val rec_calls = map mk news
           val (pat_rect,dtrees) = unzip rec_calls
           val case_functions = map list_mk_abs(zip new_formals dtrees)
           (* instantiate the switch template with the branches and u *)
           val tree = List.foldl (fn (a,tm) => beta_conv (mk_comb(tm,a)))
                                 switch_tm (case_functions@[u])
           val tree' = under_literal_bool_case beta_conv tree
           val pat_rect1 = flatten(map2 mk_patl constructors' pat_rect)
           val pat_rect1' = map (fn (x, y, pL) =>
                                   (x, y, undo_bring_to_front col_index pL)) pat_rect1
       in
           (pat_rect1',tree')
       end
     else
       case List.find (not o is_constructor_var_pat ty_info) col0 of
         NONE => (* constructor column: build an application of the case constant *)
         let
           (* Lib.with_exn turns any failure of Option.valOf into this HOL_ERR *)
           val {case_const,constructors} =
               Lib.with_exn Option.valOf (ty_info thy_tyop)
                 (ERR "mk_case0" ("could not get case constructors for type " ^
                                  Parse.type_to_string pty))
           val {Name = case_const_name, Thy,...} = dest_thy_const case_const
           val nrows = flatten (map (expand constructors pty) rows)
           val subproblems = divide(constructors, pty, range_ty, nrows)
           val groups        = map #group subproblems
           and new_formals   = map #new_formals subproblems
           and constructors' = map #constructor subproblems
           val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
                          (zip new_formals groups)
           val rec_calls = map mk news
           val (pat_rect,dtrees) = unzip rec_calls
           val tree =
             if ((#collapse_cases heu) andalso
                 (List.all (aconv (hd dtrees)) (tl dtrees)) andalso
                 (List.all (fn (vL, tree) =>
                    (List.all (fn v => not (free_in v tree)) vL))
                  (zip new_formals dtrees))) then
               (* If all cases lead to the same result, no split is necessary *)
               (hd dtrees)
             else let
               val case_functions = map list_mk_abs(zip new_formals dtrees)
               val types = map type_of (u::case_functions)
               val case_const' = mk_thy_const{Name = case_const_name, Thy = Thy,
                                              Ty = list_mk_fun(types, range_ty)}
               val tree = list_mk_comb(case_const', u::case_functions)
             in tree end
           val pat_rect1 = flatten(map2 mk_pat constructors' pat_rect)
           val pat_rect1' = map (fn (x, y, pL) =>
                                   (x, y, undo_bring_to_front col_index pL)) pat_rect1
         in
           (pat_rect1',tree)
         end
       | SOME t => mk_case_fail ("Pattern "^
                                 trace ("Unicode", 0) Parse.term_to_string t^
                                 " is not a constructor or variable")
     end
     end
 in mk
 end;

(*---------------------------------------------------------------------------
   mk_case0 ty_info ty_match FV range_ty rows

   Runs mk_case0_heu once for every heuristic produced by the current
   pmatch_heuristic, until its stream returns NONE.  Keeps the result that
   the heuristic's comparison function considers smallest; on a tie the
   earlier result wins.  Raises HOL_ERR if the stream yields no heuristic.

   NOTE(review): min_fun hands the comparator the FIRST (prefix) component
   of each returned row triple.  For the top-level call made by
   mk_functional that component looks empty in every row (its func ignores
   it), so the comparator effectively sees only the row count and the tree.
   Confirm the prefix, rather than the pattern component, is intended.
 ---------------------------------------------------------------------------*)
fun mk_case0 ty_info ty_match FV range_ty rows =
let
  fun run_heu heu = mk_case0_heu heu ty_info ty_match FV range_ty rows

  val (min_fun0, heu_fun) = (!pmatch_heuristic) ()
  fun min_fun ((pL1, dt1), (pL2, dt2)) =
     min_fun0 ((map (fn (x, _, _) => x) pL1, dt1),
               (map (fn (x, _, _) => x) pL2, dt2))

  fun res_min NONE res = res
    | res_min (SOME res1) res2 =
        (case min_fun (res1, res2) of GREATER => res2 | _ => res1)

  fun aux min =
      case heu_fun () of
          NONE =>
            (case min of
                 NONE => raise ERR "mk_case0"
                                   "SHOULD NOT HAPPEN! EMPTY PMATCH-HEURISTIC!"
               | SOME min' => min')
        | SOME heu =>
            let val res = run_heu heu
                val min' = res_min min res
            in aux (SOME min') end
in
  aux NONE
end

(*---------------------------------------------------------------------------
     Repeated variable occurrences in a pattern are not allowed.
 ---------------------------------------------------------------------------*)

(* Free variables of tm with multiplicity.  Under a lambda, every occurrence
   of the bound variable is removed. *)
fun FV_multiset tm =
   case dest_term tm
     of VAR v => [mk_var v]
      | CONST _ => []
      | COMB(Rator,Rand) => FV_multiset Rator @ FV_multiset Rand
      | LAMB(Bvar,Body) => Lib.subtract (FV_multiset Body) [Bvar]
                           (* raise ERR"FV_multiset" "lambda"; *)

(* Returns true if no variable occurs twice in pat; otherwise raises HOL_ERR
   naming the repeated variable. *)
fun no_repeat_vars pat =
  let fun check [] = true
        | check (v::rst) =
            if Lib.op_mem aconv v rst
            then raise ERR"no_repeat_vars"
                       (strcat(quote(#1(dest_var v)))
                              (strcat" occurs repeatedly in the pattern "
                                     (quote(Hol_pp.term_to_string pat))))
            else check rst
  in check (FV_multiset pat)
  end;


(*---------------------------------------------------------------------------
      Routines to repair the bound variable names found in cases
 ---------------------------------------------------------------------------*)

(* Match pat against given_pat.  Succeeds only if no type instantiation is
   needed and every variable of pat is mapped to a variable that is not in
   fvs.  Returns the term substitution; otherwise raises HOL_ERR. *)
fun pat_match1 fvs pat given_pat =
  let val (sub_tm, sub_ty) = Term.match_term pat given_pat
      val _ = if null sub_ty then ()
              else (raise ERR "pat_match1" "no type substitution expected");

      fun is_valid_bound_var v =
          (is_var v andalso not (List.exists (fn tm => aconv tm v) fvs))
      val _ = if List.all (fn m => is_valid_bound_var (#residue m)) sub_tm then ()
              else (raise ERR "pat_match1" "expected a bound variable renaming");
  in sub_tm
  end

(* Substitution from the first pattern of pat_exps that matches given_pat
   under pat_match1, or [] if none does. *)
fun pat_match2 fvs pat_exps given_pat =
    tryfind ((C (pat_match1 fvs) given_pat) o fst) pat_exps
    handle HOL_ERR _ => ([]);

(* Turn a substitution of variables into a (variable, new name) list. *)
fun subst_to_renaming (s : (term, term) subst) : (term * string) list =
  map (fn m => (#redex m, fst (dest_var (#residue m)))) s;

(* Vary the residues apart from fvs and from each other (the list is
   processed right to left). *)
fun distinguish fvs pat_tm_mats =
    snd (List.foldr (fn ({redex,residue}, (vs,done)) =>
                        let val residue' = variant vs residue
                            val vs' = Lib.insert residue' vs
                        in (vs', {redex=redex, residue=residue'} :: done)
                        end)
                    (fvs,[]) pat_tm_mats)

(* Keep only the first binding for each redex. *)
fun reduce_mats pat_tm_mats =
    snd (List.foldl (fn (mat as {redex,residue}, (vs,done)) =>
                        if mem redex vs then (vs, done)
                        else (redex :: vs, mat :: done))
                    ([],[]) pat_tm_mats)

(* Drop bindings whose residue name starts with an underscore.  Bindings
   whose residue name cannot be inspected (dest_var or String.sub raises)
   are dropped too. *)
fun purge_wildcards term_sub =
    filter (fn {redex,residue} =>
               not (String.sub (fst (dest_var residue), 0) = #"_")
               handle _ => false) term_sub

(* Renaming of the bound variables of the case patterns (pat_exps) to the
   variable names used in the input patterns given_pats. *)
fun pat_match3 fvs pat_exps given_pats =
    ((subst_to_renaming o distinguish fvs o reduce_mats o purge_wildcards o flatten))
      (map (pat_match2 fvs pat_exps) given_pats);


(*---------------------------------------------------------------------------*)
(* Syntax operations on the (extensible) set of case expressions.            *)
(*---------------------------------------------------------------------------*)

(* Build  case_const exp (\args_1. R_1) ... (\args_n. R_n)  from constructor
   patterns; args_i are the arguments of pattern p_i.  The patterns must be
   listed in the order the case constant expects (not checked here). *)
fun mk_case1 tybase (exp, plist) =
  case match_info tybase (type_names (type_of exp))
   of NONE => raise ERR "mk_case" "unable to analyze type"
    | SOME tyinfo =>
       let val c = case_const_of tyinfo
           val fns = map (fn (p,R) => list_mk_abs(snd(strip_comb p),R)) plist
           val ty' = list_mk_fun (type_of exp::map type_of fns,
                                  type_of (snd (hd plist)))
           val theta = Type.match_type (type_of c) ty'
       in list_mk_comb(inst theta c,exp::fns)
       end

(* Build a literal switch on v: nested bool_case tests  v = p_i , with the
   last right-hand side as the fall-through.  Wrapped in literal_case unless
   v is exp itself. *)
fun mk_case2 v (exp, plist) =
  let fun mk_switch [] = raise ERR "mk_case" "null patterns"
        | mk_switch [(p,R)] = R
        | mk_switch ((p,R)::rst) =
            mk_bool_case(R, mk_switch rst, mk_eq(v,p))
      val switch = mk_switch plist
  in if v = exp then switch
     else mk_literal_case(mk_abs(v,switch),exp)
  end;

(* Constructor patterns (not all variables) go to mk_case1.  Anything else
   is treated as a literal switch on the last pattern. *)
fun mk_case tybase (exp, plist) =
  let val col0 = map fst plist
  in if all (is_constructor_var_pat tybase) col0
        andalso not (all is_var col0)
     then (* constructor patterns *)
          mk_case1 tybase (exp, plist)
     else (* literal patterns *)
          mk_case2 (last col0) (exp, plist)
  end

(*---------------------------------------------------------------------------*)
(* dest_case destructs one level of pattern matching.  To deal with nested   *)
(* patterns, use strip_case.                                                 *)
(*---------------------------------------------------------------------------*)

local
  (* Rebuild (constructor pattern, body) from a case branch rhs by peeling
     one abstraction per constructor argument. *)
  fun build_case_clause((ty,constr),rhs) =
    let val (args,tau) = strip_fun (type_of constr)
        fun peel [] N = ([],N)
          | peel (_::tys) N =
              let val (v,M) = dest_abs N
                  val (V,M') = peel tys M
              in (v::V,M')
              end
        val (V,rhs') = peel args rhs
        val theta = Type.match_type (type_of constr)
                                    (list_mk_fun (map type_of V, ty))
        val constr' = inst theta constr
    in
      (list_mk_comb(constr',V), rhs')
    end
in
(* Destruct  case_const arg f_1 ... f_n  into (case_const, arg, clauses). *)
fun dest_case1 tybase M =
  let val (c,args) = strip_comb M
      val (cases,arg) =
          case args of h::t => (t, h)
                     | _ => raise ERR "dest_case" "case exp has too few args"
  in case match_info tybase (type_names (type_of arg))
      of NONE => raise ERR "dest_case" "unable to destruct case expression"
       | SOME tyinfo =>
          let val d = case_const_of tyinfo
          in if same_const c d
             then let val constrs = constructors_of tyinfo
                      val constrs_type = map (pair (type_of arg)) constrs
                  in (c, arg, map build_case_clause (zip constrs_type cases))
                  end
             else raise ERR "dest_case" "unable to destruct case expression"
          end
  end
end

(* Like dest_case1, but  literal_case (\x. M') e  gives (literal_case, e,
   [(x, M')]). *)
fun dest_case tybase M =
  if is_literal_case M then
    let val (lcf, e) = dest_comb M
        val (lit_cs, f) = dest_comb lcf
        val (x, M') = dest_abs f
    in (lit_cs, e, [(x,M')])
    end
  else dest_case1 tybase M

(* Is M a fully applied case constant of a type known to tybase?  Any
   HOL_ERR (including an unknown type operator) gives false. *)
fun is_case1 tybase M =
  let val (c,args) = strip_comb M
      val (tynames as {Tyop=tyop, ...}) =
          type_names (type_of (hd args)) handle Empty => raise ERR "" ""
          (* will get caught later *)
  in
    case match_info tybase tynames of
      NONE => raise ERR "is_case" ("unknown type operator: "^Lib.quote tyop)
    | SOME tyinfo => let
        val gconst = case_const_of tyinfo
        val gty = type_of gconst
        val argtys = fst (strip_fun gty)
      in
        same_const c gconst andalso length args = length argtys
      end
  end
  handle HOL_ERR _ => false;

fun is_case tybase M = is_literal_case M orelse is_case1 tybase M

local
  (* Flatten nested case branches into (pattern, rhs) pairs, substituting
     sub-patterns into pat when the nested case scrutinises variables of
     pat.  Anything that cannot be flattened is returned as (pat, rhs). *)
  fun dest tybase (pat,rhs) =
    let val patvars = free_vars pat
    in if is_case tybase rhs
       then let val (case_tm,exp,clauses) = dest_case tybase rhs
                val (pats,rhsides) = unzip clauses
            in if is_eq exp
               then let val (v,e) = dest_eq exp
                        val fvs = free_vars v
                        (* val theta = fst (Term.match_term v e) handle HOL_ERR _ => [] *)
                    in if null (subtract fvs patvars) andalso null (free_vars e)
                          andalso is_var v
                          (* andalso null_intersection fvs (free_vars (hd rhsides)) *)
                       then flatten
                              (map (dest tybase)
                                 (zip [subst [v |-> e] pat, pat] rhsides))
                       else [(pat,rhs)]
                    end
               else let val fvs = free_vars exp
                    in if null (subtract fvs patvars) andalso
                          null_intersection fvs (free_varsl rhsides)
                       then flatten
                              (map (dest tybase)
                                 (zip (map (fn p =>
                                              subst (fst (Term.match_term exp p)) pat) pats)
                                      rhsides))
                       else [(pat,rhs)]
                    end
                    handle HOL_ERR _ => [(pat,rhs)] (* catch from match_term *)
            end
       else [(pat,rhs)]
    end
in
(* Strip nested case expressions: returns (scrutinee, (pattern, rhs) list).
   If M is not a case expression, returns (M, []). *)
fun strip_case1 tybase M =
  (case total (dest_case tybase) M
    of NONE => (M,[])
     | SOME(case_tm,exp,cases) =>
         if is_eq exp
         then let val (v,e) = dest_eq exp
              in (v, flatten (map (dest tybase)
                                  (zip [e, v] (map snd cases))))
              end
         else (exp, flatten (map (dest tybase) cases)))
end;

(* strip_case1, also looking through an outer literal_case (\x. M') e.  In
   that case the scrutinee returned is e. *)
fun strip_case tybase M =
  if is_literal_case M then
    let val (lcf, e) = dest_comb M
        val (lit_cs, f) = dest_comb lcf
        val (x, M') = dest_abs f
        val (exp, cases) = if is_case1 tybase M'
                           then strip_case1 tybase M'
                           else (x, [(x, M')])
    in (e, cases)
    end
  else strip_case1 tybase M

(* Rename bound variables throughout cs.  At every abstraction whose bound
   variable is a key of ren (looked up with Lib.assoc), the variable is
   renamed to the associated name.  If the lookup or renaming raises
   HOL_ERR, the abstraction keeps its variable. *)
fun rename_top_bound_vars ren cs = (
  case dest_term cs of
      VAR _ => cs
    | CONST _ => cs
    | COMB (t1, t2) =>
        mk_comb (rename_top_bound_vars ren t1, rename_top_bound_vars ren t2)
    | LAMB (v, t) =>
        let val cs' = rename_bvar (Lib.assoc v ren) cs handle HOL_ERR _ => cs
            val (v', t') = dest_abs cs'
            val t'' = rename_top_bound_vars ren t'
        in mk_abs (v', t'') end
);

local fun paired1{lhs,rhs} = (lhs,rhs)
      and paired2{Rator,Rand} = (Rator,Rand)
      fun err s = raise ERR "mk_functional" s
      fun msg s = HOL_MESG ("mk_functional: "^s)
in
(*---------------------------------------------------------------------------
   mk_functional thy eqs

   eqs is a conjunction of (possibly universally quantified) equations
   f p_i = rhs_i for a single function f.  If f is a constant, it is
   replaced by a variable with the same name and type.  Returns

     {functional = \f a. case_tm, pats}

   case_tm is the case expression obtained by pattern-match compilation on
   a fresh argument a, with its bound variables renamed after the input
   patterns where possible.  pats lists the completed patterns: the GIVEN
   ones sorted by input row, then the OMITTED ones added by completion.
   Input rows missing from the translation are reported as inaccessible and
   removed from pats.

   Raises HOL_ERR if the function name is not unique, if a pattern repeats
   a variable, or if clauses were added while allow_new_clauses is false.
 ---------------------------------------------------------------------------*)
fun mk_functional thy eqs =
  let val clauses = strip_conj eqs
      val (L,R) = unzip (map (dest_eq o snd o strip_forall) clauses)
      val (funcs,pats) = unzip(map dest_comb L)
      val fs = Lib.op_mk_set aconv funcs
      val f0 = if length fs = 1 then hd fs else err "function name not unique"
      val f = if is_var f0 then f0 else mk_var(dest_const f0)
      val _ = map no_repeat_vars pats
      val rows = zip (map (fn x => ([]:term list,[x])) pats) (map GIVEN (enumerate R))
      val avs = all_varsl (L@R)
      val a = variant avs (mk_var("a", type_of(Lib.trye hd pats)))
      val FV = a::avs
      val range_ty = type_of (Lib.trye hd R)
      val (patts, case_tm) = mk_case0 (match_info thy) (match_type thy)
                                      FV range_ty {path=[a], rows=rows}
      fun func (_,(tag,i),[pat]) = tag (pat,i)
        | func _ = err "error in pattern-match translation"
      val patts1 = map func patts
      val (omits,givens) = Lib.partition is_omitted patts1
      val givens' = sort pattern_cmp givens
      val patts2 = givens' @ omits
      val finals = map row_of_pat patts2
      val originals = map (row_of_pat o #2) rows
      val new_rows = length finals - length originals
      val clause_s = if new_rows = 1 then " clause " else " clauses "
      val _ = if new_rows > 0 then
                (msg ("\n pattern completion has added "^
                      Int.toString new_rows^clause_s^
                      "to the original specification.");
                 if !allow_new_clauses then ()
                 else
                   err ("new clauses not allowed under current setting of "^
                        Lib.quote("Functional.allow_new_clauses")^" flag"))
              else ()
      fun int_eq i1 (i2:int) = (i1=i2)
      (* input rows that do not survive in the translation *)
      val inaccessibles = filter(fn x => not(op_mem int_eq x finals)) originals
      fun accessible p = not(op_mem int_eq (row_of_pat p) inaccessibles)
      val patts3 = (case inaccessibles of [] => patts2
                                        | _ => filter accessible patts2)
      val _ = case inaccessibles of [] => ()
              | _ => msg("The following input rows (counting from zero) are\
                         \ inaccessible: "^stringize inaccessibles^".\nThey have been ignored.")
      (* The next lines repair bound variable names in the nested case term. *)
      val case_tm' =
          let val (_,pat_exps) = strip_case thy case_tm
              val fvs = free_vars case_tm
              val ren = pat_match3 fvs pat_exps pats (* better pats than givens patts3 *)
          in (rename_top_bound_vars ren case_tm)
          end handle HOL_ERR _ =>
            (Feedback.HOL_WARNING "Pmatch" "mk_functional"
               "SHOULD NOT HAPPEN! RENAMING CASE_TM FAILED!";
             case_tm)
      (* Ensure that the case test variable is fresh for the rest of the case *)
      val avs = subtract (all_vars case_tm') [a]
      val a' = variant avs a
      val case_tm'' = if a' = a then case_tm'
                      else subst ([a |-> a']) case_tm'
  in
    {functional = list_mk_abs ([f,a'], case_tm''),
     pats = patts3}
  end
end;

(*---------------------------------------------------------------------------
   Given a list of (pattern,expression) pairs, mk_pattern_fn creates a term
   that is an abstraction containing a case expression on the function's
   argument.  It does this by compiling  fvar p_i = e_i  for a fresh
   function variable fvar with mk_functional, and dropping the outer
   abstraction over fvar.  Raises HOL_ERR if the list is empty, or if the
   patterns or the expressions do not all have the same type.
 ---------------------------------------------------------------------------*)

fun mk_pattern_fn thy (pes: (term * term) list) =
  let fun err s = raise ERR "mk_pattern_fn" s
      val (p0,e0) = Lib.trye hd pes
          handle HOL_ERR _ => err "empty list of (pattern,expression) pairs"
      val pty = type_of p0 and ety = type_of e0
      val (ps,es) = unzip pes
      val _ = if all (Lib.equal pty o type_of) ps then ()
              else err "patterns have varying types"
      val _ = if all (Lib.equal ety o type_of) es then ()
              else err "expressions have varying types"
      val fvar = genvar (pty --> ety)
      val eqs = list_mk_conj (map (fn (p,e) => mk_eq(mk_comb(fvar,p), e)) pes)
      val {functional,pats} = mk_functional thy eqs
      val f = snd (dest_abs functional)
  in
    f
  end

end;