1structure Pmatch :> Pmatch =
2struct
3
4open HolKernel boolSyntax PmatchHeuristics;
5
(* A type-base lookup: given the theory and operator name of a type, it
   optionally yields that type's case constant and list of constructors. *)
type thry   = {Tyop : string, Thy : string} ->
              {case_const : term, constructors : term list} option

(* Builds HOL_ERR exceptions tagged as originating in "Pmatch". *)
val ERR = mk_HOL_ERR "Pmatch";

(* Consulted by mk_functional: when false, a specification whose pattern
   completion adds clauses is rejected with an error. *)
val allow_new_clauses = ref true;
12
13(*---------------------------------------------------------------------------
14      Miscellaneous support
15 ---------------------------------------------------------------------------*)
16
(* gtake f (n, l) splits off the first n elements of l, returning them with
   f applied to each, paired with the untouched remainder.  Fails when l
   runs out before n elements have been taken. *)
fun gtake f =
  let fun grab (count, items) =
        if count = 0 then ([], items)
        else
          case items of
            [] => raise ERR "gtake" "grab.empty list"
          | x :: rest =>
              (* recurse first, so a too-short list fails before f is run *)
              let val (taken, left) = grab (count - 1, rest)
              in (f x :: taken, left) end
  in grab
  end;
25
(* list_to_string f delim l renders each element of l with f and joins the
   results with delim; the empty list gives the empty string. *)
fun list_to_string f delim =
  fn l => String.concatWith delim (List.map f l);
33
(* Comma-separated rendering of a list of integers. *)
val stringize = list_to_string int_to_string ", ";

(* Pairs every element with its index, element first: the pairs produced by
   Lib.enumerate (starting from 0) are swapped. *)
fun enumerate l =
  map (fn (index, item) => (item, index)) (Lib.enumerate 0 l);

(* Term and type matching; the leading theory argument is ignored. *)
fun match_term _ pat ob = Term.match_term pat ob;
fun match_type _ pat ob = Type.match_type pat ob;

(* Look a key up in a type base represented as a function. *)
fun match_info lookup key = lookup key
42
43(* should probably be in somewhere like HolKernel *)
local val counter = ref 0
in
(* vary vlist returns a generator of variables whose names avoid the names
   of vlist and of every variable the generator has already handed out.
   Candidate names are "v", then "v1", "v2", ...  The numeric counter is
   shared by all generators and reset each time vary is called. *)
fun vary vlist =
  let val used = ref (map (fst o dest_var) vlist)
      val () = counter := 0
      fun fresh_name name =
         if Lib.mem name (!used)
         then let val () = counter := !counter + 1
              in fresh_name ("v" ^ int_to_string (!counter)) end
         else (used := name :: !used; name)
  in
    fn ty => mk_var (fresh_name "v", ty)
  end
end;
57
58
59(*---------------------------------------------------------------------------
60 * This datatype carries some information about the origin of a
61 * clause in a function definition.
62 *---------------------------------------------------------------------------*)
63
datatype pattern = GIVEN   of term * int
                 | OMITTED of term * int

(* Order two GIVEN patterns by their row numbers; any other combination
   is an error. *)
fun pattern_cmp p q =
  case (p, q) of
    (GIVEN (_, m), GIVEN (_, n)) => m <= n
  | _ => raise ERR "pattern_cmp" ""

(* Apply a term substitution underneath the tag. *)
fun psubst theta pat =
  case pat of
    GIVEN (tm, i)   => GIVEN (subst theta tm, i)
  | OMITTED (tm, i) => OMITTED (subst theta tm, i);

(* Split a pattern into (tag constructor, row number) and its term. *)
fun dest_pattern pat =
  case pat of
    GIVEN (tm, i)   => ((GIVEN, i), tm)
  | OMITTED (tm, i) => ((OMITTED, i), tm);

(* The term carried by a pattern, whatever its tag. *)
fun pat_of pat =
  case pat of
    GIVEN (tm, _)   => tm
  | OMITTED (tm, _) => tm

(* Row number of a GIVEN pattern; OMITTED patterns always report ~1. *)
fun row_of_pat pat =
  case pat of
    GIVEN (_, i) => i
  | OMITTED _    => ~1

(* The term of a GIVEN pattern; fails on OMITTED. *)
fun dest_given pat =
  case pat of
    GIVEN (tm, _) => tm
  | OMITTED _     => raise ERR "dest_given" ""

(* Tag a term as omitted, using the row number ~1. *)
fun mk_omitted tm = OMITTED (tm, ~1)

fun is_omitted pat =
  case pat of
    OMITTED _ => true
  | GIVEN _   => false;

(* Terms of the GIVEN patterns in a list, skipping the OMITTED ones. *)
val givens = mapfilter dest_given;
91
92(*---------------------------------------------------------------------------
93 * Produce an instance of a constructor, plus genvars for its arguments.
94 *---------------------------------------------------------------------------*)
95
(* Instantiate constructor c so that its result type matches colty, and
   create one variable (via gv) per argument type of c, instantiated with
   the same type substitution.  Returns (instantiated c, variables). *)
fun fresh_constr ty_match (colty:hol_type) gv c =
  let val (arg_tys, result_ty) = strip_fun (type_of c)
      val theta = ty_match result_ty colty
      val c_inst = inst theta c
      fun fresh_arg ty = inst theta (gv ty)
  in (c_inst, map fresh_arg arg_tys)
  end;
104
105
106(*---------------------------------------------------------------------------*
107 * Goes through a list of rows and picks out the ones beginning with a       *
108 * pattern = Literal, or all those beginning with a variable if the pattern  *
109 * is a variable.                                                            *
110 *---------------------------------------------------------------------------*)
111
(* Split rows into (selected, rest).  A row is selected when its first
   pattern equals literal, or when both are variables.  Selected rows lose
   their first pattern unless literal is a variable, in which case the row
   is kept whole.  Relative order is preserved in both results. *)
fun mk_groupl literal rows =
  let fun classify (row, (hits, misses)) =
        case row of
          ((prefix, p :: rest), rhs) =>
            if (is_var literal andalso is_var p) orelse p = literal
            then
              let val kept = if is_var literal then p :: rest else rest
              in (((prefix, kept), rhs) :: hits, misses) end
            else (hits, row :: misses)
        | _ => raise ERR "mk_groupl" ""
  in
    List.foldr classify ([], []) rows
  end;
123
124(*---------------------------------------------------------------------------*
125 * Goes through a list of rows and picks out the ones beginning with a       *
126 * pattern with constructor = c.                                             *
127 *---------------------------------------------------------------------------*)
128
(* Split rows into (selected, rest).  A row is selected when the head of
   its first pattern is the constant c; in a selected row that pattern is
   replaced by its arguments.  Relative order is preserved. *)
fun mk_group c rows =
  let fun classify (row, (hits, misses)) =
        case row of
          ((prefix, p :: rest), rhs) =>
            let val (head, args) = strip_comb p
            in if same_const head c
               then (((prefix, args @ rest), rhs) :: hits, misses)
               else (hits, row :: misses)
            end
        | _ => raise ERR "mk_group" ""
  in
    List.foldr classify ([], []) rows
  end;
140
141
142(*---------------------------------------------------------------------------*
143 * Partition the rows among literals. Not efficient.                         *
144 *---------------------------------------------------------------------------*)
145
(* For each literal in `constructors`, in order, pull out of the remaining
   rows the group selected by mk_groupl.  A literal with no rows gets a
   single OMITTED row whose right-hand side is ARB at res_ty.  new_formals
   is [c] when the literal c is a variable and [] otherwise.  The first two
   arguments and colty are not used. *)
fun partitionl _ _ (_,_,_,[]) = raise ERR"partitionl" "no rows"
  | partitionl gv ty_match
              (constructors, colty, res_ty, rows as (((prefix,_),_)::_)) =
let fun split ([], _, acc) = rev acc
      | split (c :: cs, remaining, acc) =
          let val (hits, misses) = mk_groupl c remaining
              val group =
                  if null hits  (* literal not given *)
                  then [((prefix, []), mk_omitted (mk_arb res_ty))]
                  else hits
              val formals = if is_var c then [c] else []
          in
            split (cs, misses,
                   {constructor = c,
                    new_formals = formals,
                    group = group} :: acc)
          end
in split (constructors, rows, [])
end;
166
167
168(*---------------------------------------------------------------------------*
169 * Partition the rows. Not efficient.                                        *
170 *---------------------------------------------------------------------------*)
171
(* For each constructor, in order, instantiate it with fresh argument
   variables and pull out of the remaining rows the group selected by
   mk_group.  A constructor with no rows gets a single OMITTED row whose
   patterns are a second, separately generated set of fresh variables and
   whose right-hand side is ARB at res_ty. *)
fun partition _ _ (_,_,_,[]) = raise ERR"partition" "no rows"
  | partition gv ty_match
              (constructors, colty, res_ty, rows as (((prefix:term list,_),_)::_)) =
let val fresh = fresh_constr ty_match colty gv
    fun split ([], _, acc) = rev acc
      | split (c :: cs, remaining, acc) =
          let val (c_inst, formals) = fresh c
              val (hits, misses) = mk_group c_inst remaining
              val group =
                  if null hits  (* constructor not given *)
                  then [((prefix, #2 (fresh c)), mk_omitted (mk_arb res_ty))]
                  else hits
          in
            split (cs, misses,
                   {constructor = c_inst,
                    new_formals = formals,
                    group = group} :: acc)
          end
in split (constructors, rows, [])
end;
191
192
193(*---------------------------------------------------------------------------
194 * Misc. routines used in mk_case
195 *---------------------------------------------------------------------------*)
196
(* Rebuild the leading pattern of each (prefix, tag, patterns) triple for a
   literal column.  If c is a variable, the first pattern of the triple is
   kept as the head; otherwise c itself is put in front. *)
fun mk_patl c =
  let val c_is_var = is_var c
      val arity = if c_is_var then 1 else 0
      fun rebuild (prefix, tag, pats) =
          let val (taken, rest) = gtake I (arity, pats)
              val head = if c_is_var then hd taken else c
          in (prefix, tag, head :: rest) end
  in map rebuild
  end;

(* Rebuild the leading pattern of each (prefix, tag, patterns) triple for a
   constructor column: as many patterns as c has arguments are folded back
   into a single application of c. *)
fun mk_pat c =
  let val arity = length (#1 (strip_fun (type_of c)))
      fun rebuild (prefix, tag, pats) =
          let val (args, rest) = gtake I (arity, pats)
          in (prefix, tag, list_mk_comb (c, args) :: rest) end
  in map rebuild
  end;
213
214
(* Move the first pattern of a row onto the front of its prefix. *)
fun v_to_prefix (prefix, pats) =
  case pats of
    v :: rest => (v :: prefix, rest)
  | [] => raise ERR "mk_case" "v_to_prefix"

(* Inverse step: move the front of the prefix back onto the patterns. *)
fun v_to_pats (prefix, tag, pats) =
  case prefix of
    v :: rest => (rest, tag, v :: pats)
  | [] => raise ERR "mk_case""v_to_pats";
220
221(* --------------------------------------------------------------
222   Literals include numeric, string, and character literals.
223   Boolean literals are the constructors of the bool type.
224   Currently, literals may be any expression without free vars.
225   These functions are not used at the moment, but may be someday.
226   -------------------------------------------------------------- *)
227
228(*
229val is_literal = Literal.is_literal
230
231fun is_lit_or_var tm = is_literal tm orelse is_var tm
232
233fun is_zero_emptystr_or_var tm =
234    Literal.is_zero tm orelse Literal.is_emptystring tm orelse is_var tm
235*)
236
(* True for a variable, and for any term without free variables. *)
fun is_closed_or_var tm =
  if is_var tm then true else null (free_vars tm)
238
239
240(* ---------------------------------------------------------------------------
241    Reconstructed code from TypeBasePure, to avoid circularity.
242   ---------------------------------------------------------------------------*)
243
(* Field projections for type-base entries. *)
fun case_const_of (info : {case_const : term, constructors : term list}) =
  #case_const info
fun constructors_of (info : {case_const : term, constructors : term list}) =
  #constructors info

(* Theory and operator name of a type, dropping its arguments. *)
fun type_names ty =
  let val parts = Type.dest_thy_type ty
  in {Thy = #Thy parts, Tyop = #Tyop parts}
  end;
251
252(*---------------------------------------------------------------------------*)
253(* Is a constant a constructor for some datatype.                            *)
254(*---------------------------------------------------------------------------*)
255
(* c is a constructor when the type base has an entry for c's result type
   and c is among the constructors listed there.  Any HOL_ERR raised along
   the way counts as "no". *)
fun is_constructor tybase c =
  (case tybase (type_names (#2 (strip_fun (type_of c)))) of
     SOME tyinfo => op_mem same_const c (constructors_of tyinfo)
   | NONE => false)
  handle HOL_ERR _ => false;

(* A term whose head (after stripping arguments) is a constructor. *)
fun is_constructor_pat tybase tm =
  let val (head, _) = strip_comb tm
  in is_constructor tybase head end;

(* A variable, or a constructor pattern. *)
fun is_constructor_var_pat ty_info tm =
  if is_var tm then true else is_constructor_pat ty_info tm
268
(* Build the skeleton of a literal switch on v.  One fresh variable is made
   per literal (a function from v's type when the literal is a variable, a
   plain result otherwise).  The switch is a chain of boolean cases testing
   the bound variable against each literal in turn, ending at `base` or, on
   reaching a variable literal, at its function applied to the bound
   variable.  The whole is wrapped in a literal case on v and abstracted
   over the fresh variables followed by v.  The bound variable is the last
   literal, or a fresh variable if that cannot be obtained. *)
fun mk_switch_tm gv v base literals =
    let val result_ty = type_of base
        val scrut_ty = type_of v
        val bound = last literals handle _ => gv scrut_ty
        fun placeholder lit =
            if is_var lit then gv (scrut_ty --> result_ty) else gv result_ty
        val holes = map placeholder literals
        open boolSyntax
        fun chain pairs =
            case pairs of
              [] => base
            | (lit, hole) :: rest =>
                if is_var lit then mk_comb (hole, bound)
                else mk_bool_case (hole, chain rest, mk_eq (bound, lit))
        val body = chain (zip literals holes)
    in list_mk_abs (holes @ [v], mk_literal_case (mk_abs (bound, body), v))
    end
283
284(* under_bool_case repairs a final beta_conv for literal switches. *)
285
(* Apply conv to the body of the abstraction inside a literal case; on any
   other term apply conv directly.  Where conv raises HOL_ERR the term it
   was applied to is left unchanged. *)
fun under_literal_case conv tm =
  let fun attempt t = conv t handle HOL_ERR _ => t
  in
    if not (is_literal_case tm) then attempt tm
    else
      let val (f, e) = dest_literal_case tm
          val (x, body) = dest_abs f
      in mk_literal_case (mk_abs (x, attempt body), e)
      end
  end

(* Walk down the third component of nested boolean cases (as returned by
   dest_bool_case) and apply conv to the first term reached that is not a
   boolean case, leaving it unchanged if conv raises HOL_ERR. *)
fun under_bool_case conv tm =
  let fun descend t =
        if not (is_bool_case t) then (conv t handle HOL_ERR _ => t)
        else
          let val (first, second, test) = dest_bool_case t
          in mk_bool_case (first, descend second, test)
          end
  in descend tm
  end

(* Composition of the two: under the literal case, then under the chain of
   boolean cases. *)
fun under_literal_bool_case conv tm =
    under_literal_case (fn t => under_bool_case conv t) tm
305
306
307(*----------------------------------------------------------------------------
308      Translation of pattern terms into nested case expressions.
309
310    This performs the translation and also builds the full set of patterns.
311    Thus it supports the construction of induction theorems even when an
312    incomplete set of patterns is given.
313 ----------------------------------------------------------------------------*)
314
(* Move the element at position n (counting from zero) to the front of l,
   keeping the other elements in order. *)
fun bring_to_front_list n l = let
   val (front, back) = Lib.split_after n l
  in hd back :: (front @ tl back) end

(* Inverse of bring_to_front_list: put the head of l back at position n. *)
fun undo_bring_to_front n l = let
   val first = hd l
   val (front, back) = Lib.split_after n (tl l)
 in front @ (first :: back) end
324
(*---------------------------------------------------------------------------
 * mk_case0_heu heu ty_info ty_match FV range_ty returns a function taking
 * {path, rows}, where path lists the terms being matched and each row is
 * ((prefix, patterns), rhs).  The function returns a pair:
 *   - the completed pattern rows, as (prefix, tag, patterns) triples;
 *   - the nested case term.
 * heu supplies the column choice (#col_fun) and the optional #skip_rows
 * and #collapse_cases shortcuts.  Fresh variables avoid the names in FV.
 * Fails with HOL_ERR on malformed input (no rows, a column mixing literals
 * with non-literals, a pattern that is neither constructor nor variable).
 *---------------------------------------------------------------------------*)
fun mk_case0_heu (heu : pmatch_heuristic) ty_info ty_match FV range_ty =
 let
 fun mk_case_fail s = raise ERR "mk_case" s
 val fresh_var = vary FV
 val dividel = partitionl fresh_var ty_match
 val divide = partition fresh_var ty_match
 (* A row starting with a variable is copied once per literal, with the
    variable replaced by that literal; other rows are left alone. *)
 fun expandl literals ty ((_,[]), _) = mk_case_fail "expandl_var_row"
   | expandl literals ty (row as ((prefix, p::rst), rhs)) =
       if is_var p
       then let fun expnd l =
                     ((prefix, l::rst), psubst[p |-> l] rhs)
            in map expnd literals  end
       else [row]
 (* Same idea for constructors: the variable is replaced by each
    constructor applied to fresh variables. *)
 fun expand constructors ty ((_,[]), _) = mk_case_fail "expand_var_row"
   | expand constructors ty (row as ((prefix, p::rst), rhs)) =
      (if is_var p
       then let val fresh = fresh_constr ty_match ty fresh_var
                fun expnd (c,gvs) =
                  let val capp = list_mk_comb(c,gvs)
                  in ((prefix, capp::rst), psubst[p |-> capp] rhs)
                  end
            in map expnd (map fresh constructors)  end
       else [row])
 fun mk{rows=[],...} = mk_case_fail "no rows"
   | mk{path=[], rows = ((prefix, []), rhs)::_} =  (* Done *)
        let val (tag,tm) = dest_pattern rhs
        in ([(prefix,tag,[])], tm)
        end
   | mk{path=[], rows = _::_} = mk_case_fail "blunder"
   | mk{path as u::rstp, rows as ((prefix, []), rhs)::rst} =
        (* first row has run out of patterns: pad it with a fresh variable *)
        mk{path = path,
           rows = ((prefix, [fresh_var(type_of u)]), rhs)::rst}
   | mk{path = rstp0, rows = rows0 as ((_, pL as (_ :: _)), _)::_} =
     (* first row is all variables: with skip_rows the others are dropped *)
     if ((#skip_rows heu) andalso length rows0 > 1 andalso all is_var pL)
     then mk {path = rstp0, rows = [hd rows0]}
     else
     (* bring the column chosen by the heuristic to the front *)
     let val col_index = (#col_fun heu) ty_info (map (fn ((_, pL), _) => pL) rows0)
         val u_rstp = bring_to_front_list col_index rstp0
         val (u, rstp) = (hd u_rstp, tl u_rstp)
         val rows = map (fn ((prefix, pL), rhs) => ((prefix, bring_to_front_list col_index pL), rhs)) rows0
         val ((_, pL), _) = hd rows
         val p = hd pL
         val (pat_rectangle,rights) = unzip rows
         val col0 = map(Lib.trye hd o #2) pat_rectangle
     in
     if all is_var col0
     (* variable column: substitute u for each variable and move on *)
     then let val rights' = map(fn(v,e) => psubst[v|->u] e) (zip col0 rights)
              val pat_rectangle' = map v_to_prefix pat_rectangle
              val (pref_patl,tm) = mk{path = rstp,
                                      rows = zip pat_rectangle' rights'}
              val pat_rect1 = map v_to_pats pref_patl
              val pat_rect1' = map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL)) pat_rect1
          in (pat_rect1', tm)
          end
     else
     let val pty = type_of p
         val thy_tyop =
             type_names pty
             handle HOL_ERR _ =>
                    raise ERR "mk_case0_heu"
                          ("Term "^Parse.term_to_string p^
                           " is a bad pattern (of var type?)")
     in
     if exists Literal.is_pure_literal col0 (* col0 has a literal *) then
       let val is_lit_col = all (fn t => Literal.is_literal t orelse is_var t) col0
           val _ = if is_lit_col then () else
                   mk_case_fail "case expression mixes literals with non-literals."
           val other_var = fresh_var pty
           (* distinct literals of the column, plus a catch-all variable *)
           val constructors = rev (mk_set (rev (filter (not o is_var) col0)))
                              @ [other_var]
           val arb = mk_arb range_ty
           val switch_tm = mk_switch_tm fresh_var u arb constructors
           val nrows = flatten (map (expandl constructors pty) rows)
           val subproblems = dividel(constructors, pty, range_ty, nrows)
           val groups        = map #group subproblems
           and new_formals   = map #new_formals subproblems
           and constructors' = map #constructor subproblems
           val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
                          (zip new_formals groups)
           val rec_calls = map mk news
           val (pat_rect,dtrees) = unzip rec_calls
           val case_functions = map list_mk_abs(zip new_formals dtrees)
           (* apply the switch skeleton to its branches and to u,
              beta-reducing after each application *)
           val tree = List.foldl (fn (a,tm) => beta_conv (mk_comb(tm,a)))
                                 switch_tm (case_functions@[u])
           val tree' = under_literal_bool_case beta_conv tree
           val pat_rect1 = flatten(map2 mk_patl constructors' pat_rect)
           val pat_rect1' = map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL)) pat_rect1
       in
           (pat_rect1',tree')
       end
     else
       case List.find (not o is_constructor_var_pat ty_info) col0 of
         NONE => let
           val {case_const,constructors} =
             Lib.with_exn Option.valOf (ty_info thy_tyop)
               (ERR "mk_case0" ("could not get case constructors for type " ^
                                Parse.type_to_string pty))
           val {Name = case_const_name, Thy,...} = dest_thy_const case_const
           val nrows = flatten (map (expand constructors pty) rows)
           val subproblems = divide(constructors, pty, range_ty, nrows)
           val groups       = map #group subproblems
           and new_formals  = map #new_formals subproblems
           and constructors' = map #constructor subproblems
           val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
                          (zip new_formals groups)
           val rec_calls = map mk news
           val (pat_rect,dtrees) = unzip rec_calls
           val tree =
             if ((#collapse_cases heu) andalso
                 (List.all (aconv (hd dtrees)) (tl dtrees)) andalso
                 (List.all (fn (vL, tree) =>
                    (List.all (fn v => not (free_in v tree)) vL)) (zip new_formals dtrees))) then
               (* If all cases lead to the same result, no split is necessary *)
               (hd dtrees)
             else let
               val case_functions = map list_mk_abs(zip new_formals dtrees)
               val types = map type_of (u::case_functions)
               (* rebuild the case constant at the type needed here *)
               val case_const' = mk_thy_const{Name = case_const_name, Thy = Thy,
                                              Ty = list_mk_fun(types, range_ty)}
               val tree = list_mk_comb(case_const', u::case_functions)
             in tree end
           val pat_rect1 = flatten(map2 mk_pat constructors' pat_rect)
           val pat_rect1' = map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL)) pat_rect1
       in
          (pat_rect1',tree)
         end
       | SOME t => mk_case_fail ("Pattern "^
                                 trace ("Unicode", 0) Parse.term_to_string t^
                                 " is not a constructor or variable")
     end
     end
 in mk
 end;
459
(* Run mk_case0_heu once for every heuristic produced by the current
   !pmatch_heuristic and return the result its comparison function ranks
   lowest.  On a tie, or anything other than GREATER, the earlier result
   is kept.  Only the prefix components of the pattern rows are shown to
   the comparison function. *)
fun mk_case0 ty_info ty_match FV range_ty rows =
let
  val (compare0, next_heu) = (!pmatch_heuristic) ()

  fun prefixes triples = map (fn (x, _, _) => x) triples
  fun compare ((pL1, dt1), (pL2, dt2)) =
    compare0 ((prefixes pL1, dt1), (prefixes pL2, dt2))

  fun better best candidate =
    case best of
      NONE => candidate
    | SOME current =>
        (case compare (current, candidate) of
           GREATER => candidate
         | _ => current)

  fun loop best =
    case next_heu () of
      SOME heu =>
        let val candidate = mk_case0_heu heu ty_info ty_match FV range_ty rows
        in loop (SOME (better best candidate)) end
    | NONE =>
        (case best of
           SOME result => result
         | NONE => (print "SHOULD NOT HAPPEN! EMPTY PMATCH-HEURISTIC!"; fail()))
in
  loop NONE
end
481
482(*---------------------------------------------------------------------------
483     Repeated variable occurrences in a pattern are not allowed.
484 ---------------------------------------------------------------------------*)
485
(* Free variable occurrences of a term, with repetitions, left to right.
   Under an abstraction the bound variable is subtracted from the result
   for the body. *)
fun FV_multiset tm =
   case dest_term tm
     of CONST _ => []
      | VAR v => [mk_var v]
      | LAMB(Bvar,Body) => Lib.subtract (FV_multiset Body) [Bvar]
      | COMB(Rator,Rand) =>
          let val left = FV_multiset Rator
          in left @ FV_multiset Rand end

(* Returns true when no variable occurs twice in pat; otherwise raises an
   error naming the first variable found to have a later repeat. *)
fun no_repeat_vars pat =
 let fun complain v =
       raise ERR"no_repeat_vars"
             (quote (#1 (dest_var v)) ^
              " occurs repeatedly in the pattern " ^
              quote (Hol_pp.term_to_string pat))
     fun scan vars =
       case vars of
         [] => true
       | v :: rest =>
           if Lib.op_mem aconv v rest then complain v else scan rest
 in scan (FV_multiset pat)
 end;
505
506
507(*---------------------------------------------------------------------------
508     Routines to repair the bound variable names found in cases
509 ---------------------------------------------------------------------------*)
(* Match pat against given_pat and return the term substitution.  Fails if
   any type instantiation is needed, or if some residue is not a variable
   or is (alpha-equivalent to) one of fvs. *)
fun pat_match1 fvs pat given_pat =
 let val (tm_sub, ty_sub) = Term.match_term pat given_pat
     fun clashes_with_free v = List.exists (fn fv => aconv fv v) fvs
     fun renames_to_bound_var {redex = _, residue} =
       is_var residue andalso not (clashes_with_free residue)
 in
   if not (null ty_sub)
   then raise ERR "pat_match1" "no type substitution expected"
   else if not (List.all renames_to_bound_var tm_sub)
   then raise ERR "pat_match1" "expected a bound variable renaming"
   else tm_sub
 end

(* Substitution from the first pattern in pat_exps that pat_match1 accepts
   for given_pat; the empty substitution if none does. *)
fun pat_match2 fvs pat_exps given_pat =
  tryfind (fn (p, _) => pat_match1 fvs p given_pat) pat_exps
  handle HOL_ERR _ => [];

(* Keep the redex, and only the name of the residue variable. *)
fun subst_to_renaming (s : (term, term) subst) : (term * string) list =
  map (fn {redex, residue} => (redex, fst (dest_var residue))) s;

(* Working from the end of the list, replace each residue by a variant that
   avoids fvs and the variants already chosen. *)
fun distinguish fvs pat_tm_mats =
  let fun step ({redex, residue}, (avoid, acc)) =
        let val fresh = variant avoid residue
        in (Lib.insert fresh avoid, {redex = redex, residue = fresh} :: acc)
        end
      val (_, result) = List.foldr step (fvs, []) pat_tm_mats
  in result
  end

(* Keep only the first binding for each redex; the result is in reverse
   order of first occurrence. *)
fun reduce_mats pat_tm_mats =
  let fun step (mat as {redex, residue}, (seen, acc)) =
        if mem redex seen then (seen, acc)
        else (redex :: seen, mat :: acc)
      val (_, result) = List.foldl step ([], []) pat_tm_mats
  in result
  end

(* Drop bindings whose residue name starts with an underscore, and any
   binding whose residue name cannot be inspected. *)
fun purge_wildcards term_sub =
  let fun keep {redex, residue} =
        (String.sub (fst (dest_var residue), 0) <> #"_")
        handle _ => false
  in filter keep term_sub
  end

(* Collect, for all given patterns, the variable renamings that map the
   generated patterns back onto the user's names. *)
fun pat_match3 fvs pat_exps given_pats =
  let val matches = flatten (map (pat_match2 fvs pat_exps) given_pats)
      val kept = reduce_mats (purge_wildcards matches)
  in subst_to_renaming (distinguish fvs kept)
  end;
547
548
549(*---------------------------------------------------------------------------*)
550(* Syntax operations on the (extensible) set of case expressions.            *)
551(*---------------------------------------------------------------------------*)
552
(* Constructor case: apply the case constant for exp's type to exp and to
   one function per clause, each abstracting the clause's right-hand side
   over the arguments of its pattern. *)
fun mk_case1 tybase (exp, plist) =
  case match_info tybase (type_names (type_of exp))
   of NONE => raise ERR "mk_case" "unable to analyze type"
    | SOME tyinfo =>
       let val case_c = case_const_of tyinfo
           fun clause_fn (p, rhs) = list_mk_abs (snd (strip_comb p), rhs)
           val branches = map clause_fn plist
           val arg_tys = type_of exp :: map type_of branches
           val result_ty = type_of (snd (hd plist))
           val theta = Type.match_type (type_of case_c)
                                       (list_mk_fun (arg_tys, result_ty))
       in list_mk_comb (inst theta case_c, exp :: branches)
       end

(* Literal case: a chain of boolean cases testing v against each pattern
   but the last, whose right-hand side ends the chain.  The chain is
   wrapped in a literal case binding v to exp unless v is exp itself. *)
fun mk_case2 v (exp, plist) =
  let fun chain clauses =
        case clauses of
          [] => raise ERR "mk_case" "null patterns"
        | [(_, rhs)] => rhs
        | (p, rhs) :: rest => mk_bool_case (rhs, chain rest, mk_eq (v, p))
      val switch = chain plist
  in if v = exp then switch
     else mk_literal_case (mk_abs (v, switch), exp)
  end;

(* Use the constructor form when every pattern is a variable or a
   constructor pattern and at least one is not a variable; otherwise use
   the literal form with the last pattern as the bound variable. *)
fun mk_case tybase (exp, plist) =
  let val col0 = map fst plist
      val constructor_style =
          all (is_constructor_var_pat tybase) col0
          andalso not (all is_var col0)
  in if constructor_style
     then mk_case1 tybase (exp, plist)
     else mk_case2 (last col0) (exp, plist)
  end
584
585(*---------------------------------------------------------------------------*)
586(* dest_case destructs one level of pattern matching. To deal with nested    *)
587(* patterns, use strip_case.                                                 *)
588(*---------------------------------------------------------------------------*)
589
(* Turn one case function back into (pattern, body): strip one abstraction
   per constructor argument and apply the constructor, instantiated to
   return type ty, to the stripped variables. *)
local fun build_case_clause ((ty, constr), rhs) =
 let val (arg_tys, _) = strip_fun (type_of constr)
     fun peel tys body =
       case tys of
         [] => ([], body)
       | _ :: more =>
           let val (v, inner) = dest_abs body
               val (vs, rest) = peel more inner
           in (v :: vs, rest)
           end
     val (vars, rhs') = peel arg_tys rhs
     val theta = Type.match_type (type_of constr)
                      (list_mk_fun (map type_of vars, ty))
 in
   (list_mk_comb (inst theta constr, vars), rhs')
 end
in
(* Destruct one level of a constructor case expression into
   (case constant, examined term, [(pattern, body), ...]). *)
fun dest_case1 tybase M =
  let val (c, args) = strip_comb M
      val (arg, cases) =
          case args of
            h :: t => (h, t)
          | [] => raise ERR "dest_case" "case exp has too few args"
      val arg_ty = type_of arg
  in case match_info tybase (type_names arg_ty)
      of NONE => raise ERR "dest_case" "unable to destruct case expression"
       | SOME tyinfo =>
          if same_const c (case_const_of tyinfo)
          then
            let val typed_constrs =
                    map (fn k => (arg_ty, k)) (constructors_of tyinfo)
            in (c, arg, map build_case_clause (zip typed_constrs cases))
            end
          else raise ERR "dest_case" "unable to destruct case expression"
  end
end
624
(* A literal case yields its constant, the examined term, and a single
   clause made of the bound variable and the body of the abstraction;
   everything else is handled by dest_case1. *)
fun dest_case tybase M =
  if not (is_literal_case M) then dest_case1 tybase M
  else
    let val (lcf, e) = dest_comb M
        val (lit_cs, f) = dest_comb lcf
        val (x, body) = dest_abs f
    in (lit_cs, e, [(x, body)])
    end
633
(* M is a constructor case expression when its head is the case constant
   recorded for the type of its first argument and it has exactly as many
   arguments as that constant takes.  Any HOL_ERR raised while checking
   (including the ones raised here) gives false. *)
fun is_case1 tybase M =
  (let val (head, args) = strip_comb M
       val tynames =
           type_names (type_of (hd args)) handle Empty => raise ERR "" ""
       (* will get caught later *)
   in
     case match_info tybase tynames of
       SOME tyinfo =>
         let val gconst = case_const_of tyinfo
             val arity = length (fst (strip_fun (type_of gconst)))
         in same_const head gconst andalso length args = arity
         end
     | NONE =>
         raise ERR "is_case"
                   ("unknown type operator: "^Lib.quote (#Tyop tynames))
   end)
  handle HOL_ERR _ => false;

(* Either kind of case expression. *)
fun is_case tybase M =
  if is_literal_case M then true else is_case1 tybase M
653
(* dest flattens nested case expressions on the right-hand side of a
   (pattern, rhs) clause into a list of clauses with refined patterns.
   A clause is returned unchanged when its rhs is not a case expression,
   or when the nested case cannot safely be merged into the pattern. *)
local fun dest tybase (pat,rhs) =
  let val patvars = free_vars pat
      val unchanged = [(pat,rhs)]
      fun recurse clauses = flatten (map (dest tybase) clauses)
  in if not (is_case tybase rhs) then unchanged
     else
       let val (_, exp, clauses) = dest_case tybase rhs
           val (pats, rhsides) = unzip clauses
       in if is_eq exp
          then
            (* test "v = e": the first clause gets pat with e for v, the
               second keeps pat *)
            let val (v, e) = dest_eq exp
                val fvs = free_vars v
                val mergeable =
                    null (subtract fvs patvars) andalso null (free_vars e)
                    andalso is_var v
            in if mergeable
               then recurse (zip [subst [v |-> e] pat, pat] rhsides)
               else unchanged
            end
          else
            (* case on a term: refine pat by matching the term against
               each clause pattern *)
            (let val fvs = free_vars exp
                 fun refine p = subst (fst (Term.match_term exp p)) pat
                 val mergeable =
                     null (subtract fvs patvars) andalso
                     null_intersection fvs (free_varsl rhsides)
             in if mergeable
                then recurse (zip (map refine pats) rhsides)
                else unchanged
             end
             handle HOL_ERR _ => unchanged) (* catch from match_term *)
       end
  end
in
(* Destruct a (possibly nested) case expression into the examined term and
   a flat list of clauses.  A term that dest_case rejects is returned with
   no clauses. *)
fun strip_case1 tybase M =
  case total (dest_case tybase) M of
    NONE => (M, [])
  | SOME (_, exp, cases) =>
      if is_eq exp
      then
        let val (v, e) = dest_eq exp
            val clauses = zip [e, v] (map snd cases)
        in (v, flatten (map (dest tybase) clauses))
        end
      else (exp, flatten (map (dest tybase) cases))
end;
697
(* Like strip_case1, but first looks through a literal case: the examined
   term is the literal case's argument, and the clauses come from the body
   of its abstraction when that body is a constructor case, or else are the
   single clause (bound variable, body). *)
fun strip_case tybase M =
  if not (is_literal_case M) then strip_case1 tybase M
  else
    let val (lcf, e) = dest_comb M
        val (_, f) = dest_comb lcf
        val (x, body) = dest_abs f
        val cases = if is_case1 tybase body
                    then snd (strip_case1 tybase body)
                    else [(x, body)]
    in (e, cases)
    end
709
(* Traverse the whole term and rename each bound variable that has an entry
   in ren (an association list from variables to names).  An abstraction
   whose variable has no entry, or whose renaming fails, is left as it is;
   its body is still traversed. *)
fun rename_top_bound_vars ren cs =
  let fun walk tm =
        case dest_term tm of
          COMB (t1, t2) => mk_comb (walk t1, walk t2)
        | LAMB (v, _) =>
            let val renamed =
                    rename_bvar (Lib.assoc v ren) tm handle HOL_ERR _ => tm
                val (v', body) = dest_abs renamed
            in mk_abs (v', walk body) end
        | _ => tm   (* variables and constants are returned unchanged *)
  in walk cs
  end;
721
local fun err s = raise ERR "mk_functional" s
      fun msg s = HOL_MESG ("mk_functional: "^s)
in
(*---------------------------------------------------------------------------
   mk_functional thy eqs

   eqs is a conjunction of (possibly universally quantified) equations
   "f pat = rhs", all for the same f.  The result is
     functional : \f a. <nested case expression on a>
     pats       : the completed list of patterns, GIVEN rows first (sorted
                  by original row number) followed by OMITTED rows, with
                  rows made inaccessible by earlier ones filtered out.
   Fails if the function name is not unique, if a pattern repeats a
   variable, or if pattern completion adds clauses while
   allow_new_clauses is false.  Messages are emitted when clauses are
   added and when input rows are inaccessible.
 ---------------------------------------------------------------------------*)
fun mk_functional thy eqs =
 let val clauses = strip_conj eqs
     val (L,R) = unzip (map (dest_eq o snd o strip_forall) clauses)
     val (funcs,pats) = unzip(map dest_comb L)
     val fs = Lib.op_mk_set aconv funcs
     val f0 = if length fs = 1 then hd fs else err "function name not unique"
     (* the function being defined is abstracted as a variable *)
     val f  = if is_var f0 then f0 else mk_var(dest_const f0)
     val _  = map no_repeat_vars pats
     (* one row per clause: empty prefix, a single pattern, rhs tagged with
        its row number *)
     val rows = zip (map (fn x => ([]:term list,[x])) pats) (map GIVEN (enumerate R))
     val avs = all_varsl (L@R)
     val a = variant avs (mk_var("a", type_of(Lib.trye hd pats)))
     val FV = a::avs
     val range_ty = type_of (Lib.trye hd R)
     val (patts, case_tm) = mk_case0 (match_info thy) (match_type thy)
                                     FV range_ty {path=[a], rows=rows}
     fun func (_,(tag,i),[pat]) = tag (pat,i)
       | func _ = err "error in pattern-match translation"
     val patts1 = map func patts
     val (omits,givens) = Lib.partition is_omitted patts1
     val givens' = sort pattern_cmp givens
     val patts2 = givens' @ omits
     val finals = map row_of_pat patts2
     val originals = map (row_of_pat o #2) rows
     val new_rows = length finals - length originals
     val clause_s = if new_rows = 1 then " clause " else " clauses "
     val _ = if new_rows > 0 then
               (msg ("\n  pattern completion has added "^
                     Int.toString new_rows^clause_s^
                     "to the original specification.");
                if !allow_new_clauses then ()
                else
                  err ("new clauses not allowed under current setting of "^
                       Lib.quote("Functional.allow_new_clauses")^" flag"))
             else ()
     fun int_eq i1 (i2:int) =  (i1=i2)
     (* input rows that do not appear among the final patterns *)
     val inaccessibles = filter(fn x => not(op_mem int_eq x finals)) originals
     fun accessible p = not(op_mem int_eq (row_of_pat p) inaccessibles)
     val patts3 = (case inaccessibles of [] => patts2
                        |  _ => filter accessible patts2)
     val _ = case inaccessibles of [] => ()
             | _ => msg("The following input rows (counting from zero) are" ^
                        " inaccessible: "^stringize inaccessibles^".\nThey have been ignored.")
     (* The next lines repair bound variable names in the nested case term. *)
     val case_tm' =
         let val (_,pat_exps) = strip_case thy case_tm
             val fvs = free_vars case_tm
             val ren = pat_match3 fvs pat_exps pats (* better pats than givens patts3 *)
         in (rename_top_bound_vars ren case_tm)
         end handle HOL_ERR _ =>
           (Feedback.HOL_WARNING "Pmatch" "mk_functional" "SHOULD NOT HAPPEN! RENAMING CASE_TM FAILED!";
            case_tm)
     (* Ensure that the case test variable is fresh for the rest of the case *)
     val avs = subtract (all_vars case_tm') [a]
     val a' = variant avs a
     val case_tm'' = if a' = a then case_tm'
                               else subst ([a |-> a']) case_tm'
 in
   {functional = list_mk_abs ([f,a'], case_tm''),
    pats = patts3}
 end
end;
788
789(*---------------------------------------------------------------------------
790   Given a list of (pattern,expression) pairs, mk_pattern_fn creates a term
791   as an abstraction containing a case expression on the function's argument.
792 ---------------------------------------------------------------------------*)
793
(* Build "\a. case a of p1 => e1 | ..." from the (pattern, expression)
   pairs: the pairs are turned into equations for a generated function
   variable, passed to mk_functional, and the abstraction over that
   variable is stripped from the result.  All patterns must share one type,
   as must all expressions, and the list must be non-empty. *)
fun mk_pattern_fn thy (pes: (term * term) list) =
  let fun err s = raise ERR "mk_pattern_fn" s
      fun require ok message = if ok then () else err message
      val (p0, e0) =
          case pes of
            first :: _ => first
          | [] => err "empty list of (pattern,expression) pairs"
      val pty = type_of p0
      val ety = type_of e0
      val (ps, es) = unzip pes
      val () = require (all (Lib.equal pty o type_of) ps)
                       "patterns have varying types"
      val () = require (all (Lib.equal ety o type_of) es)
                       "expressions have varying types"
      val fvar = genvar (pty --> ety)
      fun equation (p, e) = mk_eq (mk_comb (fvar, p), e)
      val eqs = list_mk_conj (map equation pes)
      val functional = #functional (mk_functional thy eqs)
   in
     snd (dest_abs functional)
  end
811
812end;
813