1structure Pmatch :> Pmatch =
2struct
3
4open HolKernel boolSyntax PmatchHeuristics;
5
(* A type-base lookup: given the theory and name of a type operator, return
   that type's case constant and constructors, or NONE if it is unknown. *)
type thry   = {Thy : string, Tyop : string} ->
              {constructors : term list, case_const : term} option

(* Exception builder for this structure, applied as ERR function message. *)
val ERR = mk_HOL_ERR "Pmatch";

(* When false, mk_functional fails instead of accepting the extra clauses
   produced by pattern completion. *)
val allow_new_clauses = ref true;
12
13(*---------------------------------------------------------------------------
14      Miscellaneous support
15 ---------------------------------------------------------------------------*)
16
(* gtake f (n, l): split l after its first n elements, applying f to each of
   those elements; returns (mapped front, untouched rest).  Raises ERR when l
   has fewer than n elements (or n is negative, which exhausts l). *)
fun gtake f =
  let
    fun take_n (k, xs) =
      case (k, xs) of
          (0, _) => ([], xs)
        | (_, y :: ys) =>
            let val (front, back) = take_n (k - 1, ys)
            in (f y :: front, back) end
        | (_, []) => raise ERR "gtake" "grab.empty list"
  in
    take_n
  end;
25
(* list_to_string f delim l: render every element of l with f and join the
   pieces, placing delim between consecutive elements ("" for []). *)
fun list_to_string f delim =
  fn l => String.concatWith delim (List.map f l)
33
(* Comma-separated rendering of a list of row numbers, for messages. *)
fun stringize ns = list_to_string int_to_string ", " ns;

(* Pair each element with its position, index second: (x, i).  Indices come
   from Lib.enumerate 0, presumably counting from 0 -- mk_functional's
   "counting from zero" message relies on that. *)
fun enumerate l = List.map (fn (i, x) => (x, i)) (Lib.enumerate 0 l);

(* Matching wrappers; the theory argument is accepted but ignored. *)
fun match_term _ = Term.match_term;
fun match_type _ = Type.match_type;

(* Look key up with the type-base lookup function. *)
fun match_info lookup key = lookup key
42
(* ----------------------------------------------------------------------
   vary vlist: return a generator of fresh variables (should probably be
   in somewhere like HolKernel).

   Each call `gen ty` yields a variable of type ty.  Its name is "v" if
   that is still unused; otherwise it is "v<k>" for the next counter value
   k whose name is unused.  A name counts as used if it belongs to a
   variable of vlist or was produced earlier by the same generator.  The
   counter is not reset between calls.

   The used-name list and the counter are created per call of vary.  The
   original kept the counter in a single global ref that each new vary
   call reset underneath all previously created generators.
   ---------------------------------------------------------------------- *)
fun vary vlist =
  let val used = ref (map (fst o dest_var) vlist)
      val counter = ref 0
      fun pass str =
         if Lib.mem str (!used)
         then (counter := !counter + 1; pass ("v" ^ int_to_string (!counter)))
         else (used := str :: !used; str)
  in
    fn ty => mk_var (pass "v", ty)
  end;
57
58
59(*---------------------------------------------------------------------------
60 * This datatype carries some information about the origin of a
61 * clause in a function definition.
62 *---------------------------------------------------------------------------*)
63
datatype pattern = GIVEN   of term * int
                 | OMITTED of term * int

(* Ordering (less-or-equal) of GIVEN patterns by row number, for sorting;
   applying it to an OMITTED pattern is an error. *)
fun pattern_cmp p1 p2 =
  case (p1, p2) of
      (GIVEN (_, i), GIVEN (_, j)) => i <= j
    | _ => raise ERR "pattern_cmp" ""

(* Apply a term substitution to the tagged term, keeping tag and row. *)
fun psubst theta p =
  case p of
      GIVEN (tm, i) => GIVEN (subst theta tm, i)
    | OMITTED (tm, i) => OMITTED (subst theta tm, i);

(* Split into ((constructor, row), term); the constructor rebuilds the tag. *)
fun dest_pattern p =
  case p of
      GIVEN (tm, i) => ((GIVEN, i), tm)
    | OMITTED (tm, i) => ((OMITTED, i), tm);

(* The tagged term, whatever its origin. *)
fun pat_of p =
  case p of
      GIVEN (tm, _) => tm
    | OMITTED (tm, _) => tm

(* Row number of a GIVEN pattern; ~1 for OMITTED ones. *)
fun row_of_pat p =
  case p of
      GIVEN (_, i) => i
    | OMITTED _ => ~1

(* The term of a GIVEN pattern; fails on OMITTED. *)
fun dest_given p =
  case p of
      GIVEN (tm, _) => tm
    | OMITTED _ => raise ERR "dest_given" ""

(* Tag tm as an invented clause (row ~1). *)
fun mk_omitted tm = OMITTED (tm, ~1)

fun is_omitted p =
  case p of
      OMITTED _ => true
    | GIVEN _ => false;

(* Terms of the GIVEN patterns, dropping OMITTED ones. *)
fun givens ps = mapfilter dest_given ps;
91
92(*---------------------------------------------------------------------------
93 * Produce an instance of a constructor, plus genvars for its arguments.
94 *---------------------------------------------------------------------------*)
95
(* fresh_constr ty_match colty gv c: instantiate constructor c so that its
   range type is colty (the type substitution comes from ty_match), and
   pair it with variables made by gv for each of its argument types, put
   through the same instantiation. *)
fun fresh_constr ty_match (colty:hol_type) gv c =
  let
    val (arg_tys, range_ty) = strip_fun (type_of c)
    val theta = ty_match range_ty colty
  in
    (inst theta c, map (fn argty => inst theta (gv argty)) arg_tys)
  end;
104
105
106(*---------------------------------------------------------------------------*
107 * Goes through a list of rows and picks out the ones beginning with a       *
108 * pattern = Literal, or all those beginning with a variable if the pattern  *
109 * is a variable.                                                            *
110 *---------------------------------------------------------------------------*)
111
(* mk_groupl literal rows = (in_group, not_in_group), both in input order.
   If literal is a variable, in_group holds the rows whose first pattern is a
   variable, unchanged.  Otherwise it holds the rows whose first pattern is
   alpha-equivalent to literal, with that pattern removed.  A row with no
   patterns left is an error. *)
fun mk_groupl literal rows =
  let
    fun classify (row as ((prefix, p :: rst), rhs), (hits, misses)) =
          if is_var literal then
            if is_var p then (row :: hits, misses)
            else (hits, row :: misses)
          else if aconv p literal then (((prefix, rst), rhs) :: hits, misses)
          else (hits, row :: misses)
      | classify _ = raise ERR "mk_groupl" ""
  in
    List.foldr classify ([], []) rows
  end;
123
124(*---------------------------------------------------------------------------*
125 * Goes through a list of rows and picks out the ones beginning with a       *
126 * pattern with constructor = c.                                             *
127 *---------------------------------------------------------------------------*)
128
(* mk_group c rows = (in_group, not_in_group), both in input order.
   in_group holds the rows whose first pattern is headed by constant c,
   with c's arguments spliced in front of the remaining patterns.  A row
   with no patterns left is an error. *)
fun mk_group c rows =
  let
    fun classify (row as ((prefix, p :: rst), rhs), (hits, misses)) =
          let val (head, args) = strip_comb p
          in
            if same_const head c
            then (((prefix, args @ rst), rhs) :: hits, misses)
            else (hits, row :: misses)
          end
      | classify _ = raise ERR "mk_group" ""
  in
    List.foldr classify ([], []) rows
  end;
140
141
142(*---------------------------------------------------------------------------*
143 * Partition the rows among literals. Not efficient.                         *
144 *---------------------------------------------------------------------------*)
145
(* partitionl gv ty_match (literals, colty, res_ty, rows): for each literal
   (in order), collect its rows with mk_groupl, passing only the leftover
   rows on to the next literal.  A literal without rows gets a single
   OMITTED row with rhs ARB : res_ty and the first input row's prefix.
   new_formals is [lit] for a variable literal, [] otherwise.  gv, ty_match
   and colty are not used here.  Raises HOL_ERR on an empty row list. *)
fun partitionl _ _ (_,_,_,[]) = raise ERR"partitionl" "no rows"
  | partitionl gv ty_match
              (constructors, colty, res_ty, rows as (((prefix,_),_)::_)) =
    let
      fun go [] _ acc = rev acc
        | go (lit :: lits) remaining acc =
            let
              val (hits, misses) = mk_groupl lit remaining
              val group =
                  if null hits
                  then [((prefix, []), mk_omitted (mk_arb res_ty))]
                  else hits
              val formals = if is_var lit then [lit] else []
            in
              go lits misses
                 ({constructor = lit, new_formals = formals, group = group} :: acc)
            end
    in
      go constructors rows []
    end;
166
167
168(*---------------------------------------------------------------------------*
169 * Partition the rows. Not efficient.                                        *
170 *---------------------------------------------------------------------------*)
171
(* partition gv ty_match (constructors, colty, res_ty, rows): split rows by
   the constructor heading their first pattern.  Returns one record per
   constructor c, in the order of constructors:
     constructor : c instantiated at colty (via fresh_constr),
     new_formals : variables made by gv for c's arguments,
     group       : the rows headed by c, with c's arguments spliced in front
                   of their remaining patterns (see mk_group).  If no row is
                   headed by c, this is a single OMITTED row whose rhs is
                   ARB : res_ty, whose prefix is the first input row's
                   prefix, and whose pattern list holds only fresh argument
                   variables.
   Each constructor only sees the rows not claimed by earlier ones.  Rows
   left after the last constructor are dropped.  Raises HOL_ERR on an empty
   row list. *)
fun partition _ _ (_,_,_,[]) = raise ERR"partition" "no rows"
  | partition gv ty_match
              (constructors, colty, res_ty, rows as (((prefix:term list,_),_)::_)) =
let val fresh = fresh_constr ty_match colty gv
     fun part {constrs = [],      rows, A} = rev A
       | part {constrs = c::crst, rows, A} =
         let val (c',gvars) = fresh c
             val (in_group, not_in_group) = mk_group c' rows
             (* NOTE(review): the omitted-row case below calls fresh a
                second time, so its argument variables differ from gvars and
                extra fresh names are consumed.  It is left as is because
                changing it would alter the generated variable names. *)
             val in_group' =
                 if (null in_group)  (* Constructor not given *)
                 then [((prefix, #2(fresh c)), mk_omitted (mk_arb res_ty))]
                 else in_group
         in
          part{constrs = crst,
               rows = not_in_group,
               A = {constructor = c', new_formals = gvars, group = in_group'}::A}
         end
in part{constrs=constructors, rows=rows, A=[]}
end;
191
192
193(*---------------------------------------------------------------------------
194 * Misc. routines used in mk_case
195 *---------------------------------------------------------------------------*)
196
(* mk_patl c: rebuild the literal column of (prefix, tag, pats) triples.
   For a variable c, the first pattern already stands for the column and is
   kept in place (there must be one).  For any other literal, c itself is
   pushed onto the front of pats. *)
fun mk_patl c =
  if is_var c then
    map (fn (prefix, tag, plist) =>
            let val (taken, rest) = gtake I (1, plist)
            in (prefix, tag, hd taken :: rest) end)
  else
    map (fn (prefix, tag, plist) => (prefix, tag, c :: plist))

(* mk_pat c: in each (prefix, tag, pats) triple, replace the first n patterns
   (n = number of curried arguments of c) with c applied to them. *)
fun mk_pat c =
  let
    val arity = length (#1 (strip_fun (type_of c)))
    fun rebuild (prefix, tag, plist) =
      let val (args, rest) = gtake I (arity, plist)
      in (prefix, tag, list_mk_comb (c, args) :: rest) end
  in
    fn triples => map rebuild triples
  end;
213
214
(* Move a row's first pattern onto its prefix. *)
fun v_to_prefix arg =
  case arg of
      (prefix, v :: pats) => (v :: prefix, pats)
    | _ => raise ERR "mk_case" "v_to_prefix"

(* Reverse of v_to_prefix on (prefix, tag, pats) triples: move the head of
   the prefix back onto the front of the pattern list. *)
fun v_to_pats arg =
  case arg of
      (v :: prefix, tag, pats) => (prefix, tag, v :: pats)
    | _ => raise ERR "mk_case""v_to_pats";
220
221(* --------------------------------------------------------------
222   Literals include numeric, string, and character literals.
223   Boolean literals are the constructors of the bool type.
224   Currently, literals may be any expression without free vars.
225   These functions are not used at the moment, but may be someday.
226   -------------------------------------------------------------- *)
227
228(*
229val is_literal = Literal.is_literal
230
231fun is_lit_or_var tm = is_literal tm orelse is_var tm
232
233fun is_zero_emptystr_or_var tm =
234    Literal.is_zero tm orelse Literal.is_emptystring tm orelse is_var tm
235*)
236
(* True when tm is a variable or has no free variables. *)
fun is_closed_or_var tm =
  if is_var tm then true else null (free_vars tm)
238
239
240(* ---------------------------------------------------------------------------
241    Reconstructed code from TypeBasePure, to avoid circularity.
242   ---------------------------------------------------------------------------*)
243
(* Field accessors for the records returned by a type-base lookup. *)
fun case_const_of (info : {case_const : term, constructors : term list}) =
  #case_const info
fun constructors_of (info : {case_const : term, constructors : term list}) =
  #constructors info

(* The {Thy, Tyop} lookup key of ty's outermost type operator.  This fails
   (HOL_ERR) when ty is not an operator application -- presumably a type
   variable. *)
fun type_names ty =
  let val {Thy = thy, Tyop = tyop, ...} = Type.dest_thy_type ty
  in {Tyop = tyop, Thy = thy} end;
251
252(*---------------------------------------------------------------------------*)
253(* Is a constant a constructor for some datatype.                            *)
254(*---------------------------------------------------------------------------*)
255
(* True when c is one of the constructors that tybase lists for the range
   type of c.  Any HOL_ERR (e.g. a range type with no operator) gives
   false. *)
fun is_constructor tybase c =
  (let val range = #2 (strip_fun (type_of c))
   in
     case tybase (type_names range) of
         SOME tyinfo => op_mem same_const c (constructors_of tyinfo)
       | NONE => false
   end)
  handle HOL_ERR _ => false;

(* True when the head of tm (after stripping applications) is a constructor. *)
fun is_constructor_pat tybase tm =
  let val (head, _) = strip_comb tm
  in is_constructor tybase head end;

(* Variables and constructor-headed terms are acceptable datatype patterns. *)
fun is_constructor_var_pat ty_info tm =
  if is_var tm then true else is_constructor_pat ty_info tm
268
(*---------------------------------------------------------------------------
 * mk_switch_tm gv v base literals builds the term
 *     \a1 ... ak v. literal_case (\v'. switch) v
 * with one fresh argument ai (made by gv) per element of literals.  The
 * argument has type (type_of v -> type_of base) for a variable literal and
 * type_of base otherwise.
 *
 * switch tests v' = lit for each non-variable literal in turn and yields
 * that literal's argument when the test succeeds.  A variable literal ends
 * the chain by applying its argument to v'.  base is the fall-through when
 * no variable literal occurs.  v' is the last element of literals, or a
 * fresh variable of v's type when literals is empty.
 *---------------------------------------------------------------------------*)
fun mk_switch_tm gv v base literals =
    let val rty = type_of base
        val lty = type_of v
        (* Explicit emptiness test instead of `last literals handle _ => ...`:
           that catch-all handler would also swallow Interrupt. *)
        val v' = if null literals then gv lty else List.last literals
        fun mk_arg lit = if is_var lit then gv (lty --> rty) else gv rty
        val args = map mk_arg literals
        open boolSyntax
        fun mk_switch [] = base
          | mk_switch ((lit,arg)::litargs) =
                 if is_var lit then mk_comb(arg, v')
                 else mk_bool_case(arg, mk_switch litargs, mk_eq(v', lit))
        val switch = mk_switch (zip literals args)
    in list_mk_abs(args@[v], mk_literal_case (mk_abs(v',switch), v))
    end
283
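(* under_literal_case and under_bool_case apply a conversion beneath the
   literal_case / bool_case structure of a literal switch (used to repair a
   final beta_conv); wherever the conversion fails, the term is left as is. *)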
285
(* Apply conv to the body of the abstraction inside literal_case (\x. body) e,
   or to tm itself when it is not a literal_case.  A HOL_ERR from conv leaves
   that subterm unchanged. *)
fun under_literal_case conv tm =
  if not (is_literal_case tm) then
    (conv tm handle HOL_ERR _ => tm)
  else
    let
      val (fabs, scrutinee) = dest_literal_case tm
      val (bv, body) = dest_abs fabs
      val body' = (conv body handle HOL_ERR _ => body)
    in
      mk_literal_case (mk_abs (bv, body'), scrutinee)
    end

(* Follow the else-branches of a chain of bool_case terms and apply conv to
   the final non-bool_case term (unchanged on HOL_ERR). *)
fun under_bool_case conv tm =
  if not (is_bool_case tm) then
    (conv tm handle HOL_ERR _ => tm)
  else
    let val (then_tm, else_tm, test) = dest_bool_case tm
    in mk_bool_case (then_tm, under_bool_case conv else_tm, test) end

(* Combination: under the literal_case, then along the bool_case chain. *)
fun under_literal_bool_case conv t =
    under_literal_case (fn s => under_bool_case conv s) t
305
306
307(*----------------------------------------------------------------------------
308      Translation of pattern terms into nested case expressions.
309
310    This performs the translation and also builds the full set of patterns.
311    Thus it supports the construction of induction theorems even when an
312    incomplete set of patterns is given.
313 ----------------------------------------------------------------------------*)
314
(* Move to the front of l the first element after the prefix that
   Lib.split_after n takes off (presumably the element at 0-based index n). *)
fun bring_to_front_list n l =
  let val (front, back) = Lib.split_after n l
  in hd back :: front @ tl back end

(* Inverse of bring_to_front_list n: put the head of l back after the first
   n elements of the rest. *)
fun undo_bring_to_front n l =
  let val (front, back) = Lib.split_after n (tl l)
  in front @ hd l :: back end
324
(*---------------------------------------------------------------------------
 * mk_case0_heu heu ty_info ty_match FV range_ty {path, rows}: compile a
 * pattern-matching problem into a nested case term, guided by heuristic heu.
 *
 *   path     : terms still to be scrutinised, one per pattern column.
 *   rows     : ((prefix, pats), rhs) rows.  pats are the remaining column
 *              patterns, prefix collects patterns already consumed, and rhs
 *              is a GIVEN/OMITTED-tagged right-hand side.
 *   FV       : variables whose names generated variables must avoid.
 *   range_ty : type of the right-hand sides (used for ARB fillers).
 *
 * Returns (pattern rows, tree).  tree is the case term.  Each pattern row
 * is a (prefix, (tag, row), pats) triple describing the completed set of
 * patterns, including OMITTED rows invented for missing cases.
 *
 * heu picks the column to split on (#col_fun).  It may also drop every
 * row after an all-variable first row (#skip_rows), and may avoid a split
 * whose branches are all the same and use none of their new variables
 * (#collapse_cases).
 *
 * Fix: removed a leftover debug handler
 *   `handle Option.Option => (print "hello\n"; raise Option)`
 * on the case-constructor lookup.  Lib.with_exn already turns a failed
 * lookup into the ERR given there.
 *---------------------------------------------------------------------------*)
fun mk_case0_heu (heu : pmatch_heuristic) ty_info ty_match FV range_ty =
 let
 fun mk_case_fail s = raise ERR "mk_case" s
 val fresh_var = vary FV
 val dividel = partitionl fresh_var ty_match
 val divide = partition fresh_var ty_match
 (* Replace a row whose first pattern is a variable by one row per literal,
    substituting the literal for that variable in the right-hand side. *)
 fun expandl literals ty ((_,[]), _) = mk_case_fail "expandl_var_row"
   | expandl literals ty (row as ((prefix, p::rst), rhs)) =
       if is_var p
       then let fun expnd l =
                     ((prefix, l::rst), psubst[p |-> l] rhs)
            in map expnd literals  end
       else [row]
 (* Replace a row whose first pattern is a variable by one row per
    constructor (applied to fresh argument variables), substituting that
    application for the variable in the right-hand side. *)
 fun expand constructors ty ((_,[]), _) = mk_case_fail "expand_var_row"
   | expand constructors ty (row as ((prefix, p::rst), rhs)) =
      (if is_var p
       then let val fresh = fresh_constr ty_match ty fresh_var
                fun expnd (c,gvs) =
                  let val capp = list_mk_comb(c,gvs)
                  in ((prefix, capp::rst), psubst[p |-> capp] rhs)
                  end
            in map expnd (map fresh constructors)  end
       else [row])
 fun mk{rows=[],...} = mk_case_fail "no rows"
   | mk{path=[], rows = ((prefix, []), rhs)::_} =  (* Done *)
        let val (tag,tm) = dest_pattern rhs
        in ([(prefix,tag,[])], tm)
        end
   | mk{path=[], rows = _::_} = mk_case_fail "blunder"
   | mk{path as u::rstp, rows as ((prefix, []), rhs)::rst} =
        (* The first row has no patterns left but the path is non-empty:
           pad the row with a fresh variable for the next path term. *)
        mk{path = path,
           rows = ((prefix, [fresh_var(type_of u)]), rhs)::rst}
   | mk{path = rstp0, rows = rows0 as ((_, pL as (_ :: _)), _)::_} =
     (* With skip_rows, an all-variable first row matches everything, so
        later rows are unreachable and are dropped. *)
     if ((#skip_rows heu) andalso length rows0 > 1 andalso all is_var pL)
     then mk {path = rstp0, rows = [hd rows0]}
     else
     let (* The heuristic chooses the column to split on.  That column is
            moved to the front of the path and of every row, and is moved
            back (undo_bring_to_front) in the returned pattern rows. *)
         val col_index = (#col_fun heu) ty_info (map (fn ((_, pL), _) => pL) rows0)
         val u_rstp = bring_to_front_list col_index rstp0
         val (u, rstp) = (hd u_rstp, tl u_rstp)
         val rows = map (fn ((prefix, pL), rhs) => ((prefix, bring_to_front_list col_index pL), rhs)) rows0
         val ((_, pL), _) = hd rows
         val p = hd pL
         val (pat_rectangle,rights) = unzip rows
         val col0 = map(Lib.trye hd o #2) pat_rectangle
     in
     if all is_var col0
     then (* Variable column: no split.  Each row's variable is replaced by
             u in its rhs, and the column moves into the prefix. *)
          let val rights' = map(fn(v,e) => psubst[v|->u] e) (zip col0 rights)
              val pat_rectangle' = map v_to_prefix pat_rectangle
              val (pref_patl,tm) = mk{path = rstp,
                                      rows = zip pat_rectangle' rights'}
              val pat_rect1 = map v_to_pats pref_patl
              val pat_rect1' = map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL)) pat_rect1
          in (pat_rect1', tm)
          end
     else
     let val pty = type_of p
         val thy_tyop =
             type_names pty
             handle HOL_ERR _ =>
                    raise ERR "mk_case0_heu"
                          ("Term "^Parse.term_to_string p^
                           " is a bad pattern (of var type?)")
     in
     if exists Literal.is_pure_literal col0 (* col0 has a literal *) then
       (* Literal column: switch on the distinct non-variable literals plus a
          final catch-all variable, using mk_switch_tm. *)
       let val is_lit_col = all (fn t => Literal.is_literal t orelse is_var t) col0
           val _ = if is_lit_col then () else
                   mk_case_fail "case expression mixes literals with non-literals."
           val other_var = fresh_var pty
           val constructors =
               rev (op_mk_set aconv (rev (filter (not o is_var) col0))) @
               [other_var]
           val arb = mk_arb range_ty
           val switch_tm = mk_switch_tm fresh_var u arb constructors
           val nrows = flatten (map (expandl constructors pty) rows)
           val subproblems = dividel(constructors, pty, range_ty, nrows)
           val groups        = map #group subproblems
           and new_formals   = map #new_formals subproblems
           and constructors' = map #constructor subproblems
           val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
                          (zip new_formals groups)
           val rec_calls = map mk news
           val (pat_rect,dtrees) = unzip rec_calls
           val case_functions = map list_mk_abs(zip new_formals dtrees)
           val tree = List.foldl (fn (a,tm) => beta_conv (mk_comb(tm,a)))
                                 switch_tm (case_functions@[u])
           val tree' = under_literal_bool_case beta_conv tree
           val pat_rect1 = flatten(map2 mk_patl constructors' pat_rect)
           val pat_rect1' = map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL)) pat_rect1
       in
           (pat_rect1',tree')
       end
     else
       (* Constructor column: every pattern must be a variable or be headed
          by a constructor; split with the type's case constant. *)
       case List.find (not o is_constructor_var_pat ty_info) col0 of
         NONE => let
           val {case_const,constructors} =
             Lib.with_exn Option.valOf (ty_info thy_tyop)
               (ERR "mk_case0" ("could not get case constructors for type " ^
                                Parse.type_to_string pty))
           val {Name = case_const_name, Thy,...} = dest_thy_const case_const
           val nrows = flatten (map (expand constructors pty) rows)
           val subproblems = divide(constructors, pty, range_ty, nrows)
           val groups       = map #group subproblems
           and new_formals  = map #new_formals subproblems
           and constructors' = map #constructor subproblems
           val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
                          (zip new_formals groups)
           val rec_calls = map mk news
           val (pat_rect,dtrees) = unzip rec_calls
           val tree =
             if ((#collapse_cases heu) andalso
                 (List.all (aconv (hd dtrees)) (tl dtrees)) andalso
                 (List.all (fn (vL, tree) =>
                    (List.all (fn v => not (free_in v tree)) vL)) (zip new_formals dtrees))) then
               (* If all cases lead to the same result, no split is necessary *)
               (hd dtrees)
             else let
               val case_functions = map list_mk_abs(zip new_formals dtrees)
               val types = map type_of (u::case_functions)
               val case_const' = mk_thy_const{Name = case_const_name, Thy = Thy,
                                              Ty = list_mk_fun(types, range_ty)}
               val tree = list_mk_comb(case_const', u::case_functions)
             in tree end
           val pat_rect1 = flatten(map2 mk_pat constructors' pat_rect)
           val pat_rect1' = map (fn (x, y, pL) => (x, y, undo_bring_to_front col_index pL)) pat_rect1
       in
          (pat_rect1',tree)
         end
       | SOME t => mk_case_fail ("Pattern "^
                                 trace ("Unicode", 0) Parse.term_to_string t^
                                 " is not a constructor or variable")
     end
     end
 in mk
 end;
460
(*---------------------------------------------------------------------------
 * mk_case0 ty_info ty_match FV range_ty rows: run mk_case0_heu once for each
 * heuristic supplied by the current !pmatch_heuristic, and return the result
 * ranked smallest by that heuristic's comparison function.  On ties the
 * earlier result is kept.  The comparison sees only the first component of
 * each pattern-row triple, together with the tree.
 *
 * Fix: if the heuristic supplies no heuristic at all, this now raises a
 * descriptive HOL_ERR.  Previously it printed an unterminated message and
 * called fail() with no context.
 *---------------------------------------------------------------------------*)
fun mk_case0 ty_info ty_match FV range_ty rows =
let
  fun run_heu heu = mk_case0_heu heu ty_info ty_match FV range_ty rows

  val (min_fun0, heu_fun) = (!pmatch_heuristic) ()
  (* Drop the tag and pattern components before comparing results. *)
  fun min_fun ((pL1, dt1), (pL2, dt2)) =
    min_fun0 ((map (fn (x, _, _) => x) pL1, dt1), (map (fn (x, _, _) => x) pL2, dt2))

  fun res_min NONE res = res
    | res_min (SOME res1) res2 =
        (case min_fun (res1, res2) of GREATER => res2 | _ => res1)

  (* heu_fun () yields the next heuristic to try, or NONE when exhausted. *)
  fun aux min = case (heu_fun ()) of
     NONE => (case min of
                NONE => raise ERR "mk_case0"
                          "empty pmatch-heuristic: no heuristic was supplied"
              | SOME min' => min')
   | SOME heu => let
       val res = run_heu heu
       val min' = res_min min res
     in aux (SOME min') end
in
  aux NONE
end
482
483(*---------------------------------------------------------------------------
484     Repeated variable occurrences in a pattern are not allowed.
485 ---------------------------------------------------------------------------*)
486
(* Increment the count stored under key k in dictionary d (absent = 0). *)
fun inc d k =
  let val n = Option.getOpt (Binarymap.peek (d, k), 0)
  in Binarymap.insert (d, k, n + 1) end;
489
(* FV_multiset tm: count the free occurrences of each variable in tm.  The
   result is a Binarymap (ordered by Term.compare) from variable to count.
   The traversal uses an explicit work list instead of recursion on the term
   structure.  On entering an abstraction, the bound variable's current
   count c0 is saved.  A RESET item processed after the body restores it, so
   occurrences bound by that abstraction are not counted.  RESET inserts the
   variable even when c0 = 0, so the map may contain zero counts. *)
fun FV_multiset tm =
  let
    (* TM t: a term still to visit; RESET (v, c): set v's count back to c *)
    datatype witem = TM of term | RESET of term * int
    fun recurse d wlist =
      case wlist of
          [] => d
        | TM tm :: rest =>
          (case dest_term tm of
               VAR _ => recurse (inc d tm) rest
             | CONST _ => recurse d rest
             | COMB(Rator,Rand) => recurse d (TM Rator :: TM Rand :: rest)
             | LAMB(Bvar,Body) =>
               let
                 (* count of Bvar before the body, restored afterwards *)
                 val c0 = case Binarymap.peek(d,Bvar) of
                              NONE => 0
                            | SOME i => i
               in
                 recurse d (TM Body :: RESET(Bvar,c0) :: rest)
               end)
        | RESET (v,c) :: rest => recurse (Binarymap.insert(d,v,c)) rest
  in
    recurse (Binarymap.mkDict Term.compare) [TM tm]
  end
513
(* Check that no variable occurs free more than once in pat.  Returns true;
   otherwise raises HOL_ERR naming one of the repeated variables. *)
fun no_repeat_vars pat =
  let
    val counts = FV_multiset pat
    (* foldl consing onto a list: the variable reported is the head of this
       accumulated list, exactly as before *)
    val repeated =
        Binarymap.foldl (fn (v, n, acc) => if n > 1 then v :: acc else acc)
                        [] counts
  in
    case repeated of
        [] => true
      | v :: _ =>
          raise ERR"no_repeat_vars"
              (quote(#1(dest_var v)) ^
               " occurs repeatedly in the pattern " ^
               quote(Hol_pp.term_to_string pat))
  end;
531
532
533(*---------------------------------------------------------------------------
534     Routines to repair the bound variable names found in cases
535 ---------------------------------------------------------------------------*)
(* pat_match1 fvs pat given_pat: match pat against given_pat and return the
   term substitution.  It must need no type instantiation, and every residue
   must be a variable not alpha-equivalent to any term of fvs; otherwise
   HOL_ERR is raised. *)
fun pat_match1 fvs pat given_pat =
  let
    val (tm_sub, ty_sub) = Term.match_term pat given_pat
    fun clashes v = List.exists (fn tm => aconv tm v) fvs
    fun renames_to_var {redex = _, residue} =
      is_var residue andalso not (clashes residue)
  in
    if not (null ty_sub)
    then raise ERR "pat_match1" "no type substitution expected"
    else if not (List.all renames_to_var tm_sub)
    then raise ERR "pat_match1" "expected a bound variable renaming"
    else tm_sub
  end

(* Try the pattern of each (pattern, _) pair of pat_exps against given_pat
   and return the first renaming pat_match1 accepts, or [] if none does. *)
fun pat_match2 fvs pat_exps given_pat =
  tryfind (fn (pat, _) => pat_match1 fvs pat given_pat) pat_exps
  handle HOL_ERR _ => [];
548
(* Turn a substitution into (redex, name of the residue variable) pairs. *)
fun subst_to_renaming (s : (term, term) subst) : (term * string) list =
  map (fn {redex, residue} => (redex, #1 (dest_var residue))) s;

(* Working from the right end of the list, replace each residue by a variant
   that avoids fvs and every residue already chosen. *)
fun distinguish fvs pat_tm_mats =
  let
    fun step ({redex, residue}, (avoid, acc)) =
      let val fresh = variant avoid residue
      in (op_insert aconv fresh avoid, {redex = redex, residue = fresh} :: acc) end
  in
    #2 (List.foldr step (fvs, []) pat_tm_mats)
  end

(* Keep only the first binding for each redex (compared up to aconv); the
   surviving bindings come out in reverse input order. *)
fun reduce_mats pat_tm_mats =
  let
    fun step (mat as {redex, residue = _}, (seen, kept)) =
      if op_mem aconv redex seen then (seen, kept)
      else (redex :: seen, mat :: kept)
  in
    #2 (List.foldl step ([], []) pat_tm_mats)
  end
565
(* Drop the bindings whose residue is a variable whose name starts with '_'.
   A binding is also dropped when its residue is not a variable (dest_var
   raises HOL_ERR) or has an empty name (String.sub raises Subscript).  Only
   those two exceptions are handled: the former catch-all `handle _` would
   also have swallowed Interrupt. *)
fun purge_wildcards term_sub = filter (fn {redex,residue} =>
        not (String.sub (fst (dest_var residue), 0) = #"_")
        handle HOL_ERR _ => false
             | Subscript => false) term_sub
569
(* Build the bound-variable renaming that makes the case patterns pat_exps
   read like the user's given_pats.  The steps are: match each given pattern
   (pat_match2), drop '_'-prefixed names, keep one binding per redex, make
   the new names distinct from fvs and from each other, and return
   (redex, new name) pairs. *)
fun pat_match3 fvs pat_exps given_pats =
  let
    val matches = flatten (map (pat_match2 fvs pat_exps) given_pats)
    val cleaned = reduce_mats (purge_wildcards matches)
  in
    subst_to_renaming (distinguish fvs cleaned)
  end;
573
574
575(*---------------------------------------------------------------------------*)
576(* Syntax operations on the (extensible) set of case expressions.            *)
577(*---------------------------------------------------------------------------*)
578
(* Build a datatype case expression.  Look up exp's type in tybase, turn each
   clause (pat, rhs) into rhs abstracted over pat's arguments, and apply the
   type-instantiated case constant to exp followed by those functions.  The
   result type is taken from the first clause's rhs. *)
fun mk_case1 tybase (exp, plist) =
  let
    val tyinfo =
        case match_info tybase (type_names (type_of exp)) of
            SOME info => info
          | NONE => raise ERR "mk_case" "unable to analyze type"
    val case_c = case_const_of tyinfo
    val branches =
        map (fn (pat, rhs) => list_mk_abs (#2 (strip_comb pat), rhs)) plist
    val target_ty = list_mk_fun (type_of exp :: map type_of branches,
                                 type_of (#2 (hd plist)))
  in
    list_mk_comb (inst (Type.match_type (type_of case_c) target_ty) case_c,
                  exp :: branches)
  end
590
(* Build a literal case expression: a chain of bool_case tests v = p, one per
   clause, where the last clause's rhs is the unconditional default.  The
   chain is wrapped as literal_case (\v. chain) exp unless v is exp itself. *)
fun mk_case2 v (exp, plist) =
  let
    fun chain clauses =
      case clauses of
          [] => raise ERR "mk_case" "null patterns"
        | [(_, rhs)] => rhs
        | (p, rhs) :: more => mk_bool_case (rhs, chain more, mk_eq (v, p))
    val switch = chain plist
  in
    if aconv v exp then switch
    else mk_literal_case (mk_abs (v, switch), exp)
  end;
600
(* Dispatch on the clause patterns.  If all are constructor patterns or
   variables (and not all variables), use mk_case1.  Otherwise treat them as
   literal patterns and use mk_case2, switching on the last pattern. *)
fun mk_case tybase (exp, plist) =
  let
    val col0 = map fst plist
    val constructor_style =
        all (is_constructor_var_pat tybase) col0 andalso not (all is_var col0)
  in
    if constructor_style then mk_case1 tybase (exp, plist)
    else mk_case2 (last col0) (exp, plist)
  end
610
611(*---------------------------------------------------------------------------*)
612(* dest_case destructs one level of pattern matching. To deal with nested    *)
613(* patterns, use strip_case.                                                 *)
614(*---------------------------------------------------------------------------*)
615
(* build_case_clause ((ty, constr), rhs): rebuild one clause of a datatype
   case.  rhs is the branch function.  One leading abstraction is peeled off
   per argument type of constr, and constr is type-instantiated so that
   applying it to the peeled variables gives a term of type ty.  Returns
   (constr applied to those variables, remaining body). *)
local fun build_case_clause((ty,constr),rhs) =
 let val (args,tau) = strip_fun (type_of constr)
     (* strip one abstraction from N for each element of the list *)
     fun peel  [] N = ([],N)
       | peel (_::tys) N =
           let val (v,M) = dest_abs N
               val (V,M') = peel tys M
           in (v::V,M')
           end
     val (V,rhs') = peel args rhs
     val theta = Type.match_type (type_of constr)
                      (list_mk_fun (map type_of V, ty))
     val constr' = inst theta constr
 in
   (list_mk_comb(constr',V), rhs')
  end
in
(* dest_case1 tybase M: destruct M = case_const arg f1 ... fn into
   (case_const, arg, [(pattern_i, body_i)]).  The type's constructors, in
   the order tybase lists them, are paired positionally with the branch
   functions.  Raises HOL_ERR if M has no arguments, if arg's type is not in
   tybase, or if M's head is not that type's case constant. *)
fun dest_case1 tybase M =
  let val (c,args) = strip_comb M
      val (cases,arg) =
          case args of h::t => (t, h)
                     | _ => raise ERR "dest_case" "case exp has too few args"
  in case match_info tybase (type_names (type_of arg))
      of NONE => raise ERR "dest_case" "unable to destruct case expression"
       | SOME tyinfo =>
          let val d = case_const_of tyinfo
          in if same_const c d
           then let val constrs = constructors_of tyinfo
                    val constrs_type = map (pair (type_of arg)) constrs
                in (c, arg, map build_case_clause (zip constrs_type cases))
                end
           else raise ERR "dest_case" "unable to destruct case expression"
          end
  end
end
650
(* Destruct one level of case analysis.  For literal_case (\x. M') e this
   returns (the literal_case constant, e, [(x, M')]); any other term is
   handed to dest_case1. *)
fun dest_case tybase M =
  if not (is_literal_case M) then dest_case1 tybase M
  else
    let
      val (rator, scrut) = dest_comb M
      val (lit_const, f) = dest_comb rator
      val (bv, body) = dest_abs f
    in
      (lit_const, scrut, [(bv, body)])
    end
659
(* True when M is the case constant of its first argument's type (according
   to tybase), fully applied.  Any HOL_ERR on the way gives false. *)
fun is_case1 tybase M =
  (let
     val (head, args) = strip_comb M
     (* no arguments: fail here; the handler below turns it into false *)
     val scrut = case args of a :: _ => a | [] => raise ERR "" ""
     val (tynames as {Tyop = tyop, ...}) = type_names (type_of scrut)
   in
     case match_info tybase tynames of
         NONE => raise ERR "is_case" ("unknown type operator: "^Lib.quote tyop)
       | SOME tyinfo =>
           let val gconst = case_const_of tyinfo
           in
             same_const head gconst andalso
             length args = length (#1 (strip_fun (type_of gconst)))
           end
   end)
  handle HOL_ERR _ => false;

(* Recognise a literal_case or a datatype case expression. *)
fun is_case tybase M =
  if is_literal_case M then true else is_case1 tybase M
679
(* True when no term of l1 is alpha-equivalent to a term of l2. *)
fun tm_null_intersection l1 l2 =
  null l2 orelse List.all (fn tm => not (op_mem aconv tm l2)) l1
686
(* dest tybase (pat, rhs): flatten nested case expressions in rhs into a list
   of (pattern, rhs) pairs.  This is done only when the scrutinee's free
   variables all occur in pat.  Otherwise, or when rhs is not a case, the
   result is [(pat, rhs)]. *)
local fun dest tybase (pat,rhs) =
  let val patvars = free_vars pat
  in if is_case tybase rhs then
       let
         val (case_tm,exp,clauses) = dest_case tybase rhs
         val (pats,rhsides) = unzip clauses
       in
         if is_eq exp then
           (* scrutinee of the form v = e (e.g. a bool_case test from a
              literal switch).  Presumably the two clauses are the true and
              false branches: the first gets pat with e substituted for v,
              the second gets pat unchanged. *)
           let
             val (v,e) = dest_eq exp
             val fvs = free_vars v
             (* val theta = fst (Term.match_term v e) handle HOL_ERR _ => [] *)
           in
             if null (op_set_diff aconv fvs patvars) andalso null (free_vars e)
                andalso is_var v
                (* andalso null_intersection fvs (free_vars (hd rhsides)) *)
             then flatten
                    (map (dest tybase) (zip [subst [v |-> e] pat, pat] rhsides))
             else [(pat,rhs)]
           end
         else
           (* ordinary scrutinee: each clause's pattern is obtained by
              matching exp against that clause pattern and instantiating pat;
              this also requires exp's free vars to avoid every rhs *)
           let
             val fvs = free_vars exp
           in
             if null (op_set_diff aconv fvs patvars) andalso
                tm_null_intersection fvs (free_varsl rhsides)
             then flatten
                    (map (dest tybase)
                         (zip (map (fn p =>
                                subst (fst (Term.match_term exp p)) pat) pats)
                              rhsides))
             else [(pat,rhs)]
           end
           handle HOL_ERR _ => [(pat,rhs)] (* catch from match_term *)
       end
     else [(pat,rhs)]
  end
in
(* strip_case1 tybase M: flatten a (possibly nested) case expression into
   (scrutinee, [(pattern, rhs), ...]).  Returns (M, []) when M cannot be
   destructed as a case expression.  For a v = e scrutinee, v is returned
   and the two clauses get e and v as their patterns. *)
fun strip_case1 tybase M =
 (case total (dest_case tybase) M
   of NONE => (M,[])
    | SOME(case_tm,exp,cases) =>
         if is_eq exp
         then let val (v,e) = dest_eq exp
              in (v, flatten (map (dest tybase)
                               (zip [e, v] (map snd cases))))
              end
         else (exp, flatten (map (dest tybase) cases)))
end;
736
(* Flatten a case expression into (scrutinee, [(pattern, rhs), ...]).  For
   literal_case (\x. M') e the scrutinee is e.  The clauses are those of
   M' when M' is itself a datatype case, and [(x, M')] otherwise. *)
fun strip_case tybase M =
  if not (is_literal_case M) then strip_case1 tybase M
  else
    let
      val (rator, scrut) = dest_comb M
      val (_, f) = dest_comb rator
      val (bv, body) = dest_abs f
      val cases =
          if is_case1 tybase body then #2 (strip_case1 tybase body)
          else [(bv, body)]
    in
      (scrut, cases)
    end
748
(* Walk the whole term cs and rename the bound variable of every abstraction
   that ren maps, where ren is a list of (variable, new name) pairs looked up
   up to aconv.  Abstractions whose variable is not in ren, or that cannot
   be renamed, keep their variable. *)
fun rename_top_bound_vars ren cs =
  case dest_term cs of
      COMB (rator, rand) =>
        mk_comb (rename_top_bound_vars ren rator, rename_top_bound_vars ren rand)
    | LAMB (bv, _) =>
        let
          val renamed = (rename_bvar (op_assoc aconv bv ren) cs
                         handle HOL_ERR _ => cs)
          val (bv', body) = dest_abs renamed
        in
          mk_abs (bv', rename_top_bound_vars ren body)
        end
    | _ => cs;
763
(*---------------------------------------------------------------------------
 * mk_functional thy eqs: eqs is a conjunction of (possibly universally
 * quantified) clauses  f pat_i = rhs_i,  each applying f to one argument.
 * It is translated with mk_case0 into
 *     {functional = \f a'. <nested case on a'>, pats = completed patterns}
 * pats lists the GIVEN patterns sorted by row number, followed by the
 * OMITTED patterns added by pattern completion.
 *
 * Fails (HOL_ERR) if the clauses define more than one function, if a
 * pattern repeats a variable, or if completion adds clauses while
 * !allow_new_clauses is false.  Inaccessible input rows are reported.
 *---------------------------------------------------------------------------*)
(* NOTE(review): paired1 and paired2 appear unused in this local block. *)
local fun paired1{lhs,rhs} = (lhs,rhs)
      and paired2{Rator,Rand} = (Rator,Rand)
      fun err s = raise ERR "mk_functional" s
      fun msg s = HOL_MESG ("mk_functional: "^s)
in
fun mk_functional thy eqs =
 let val clauses = strip_conj eqs
     val (L,R) = unzip (map (dest_eq o snd o strip_forall) clauses)
     val (funcs,pats) = unzip(map dest_comb L)
     val fs = Lib.op_mk_set aconv funcs
     val f0 = if length fs = 1 then hd fs else err "function name not unique"
     (* a constant function name is replaced by a variable of the same name *)
     val f  = if is_var f0 then f0 else mk_var(dest_const f0)
     val _  = map no_repeat_vars pats
     (* one single-column row per clause; rhs tagged GIVEN with its row index *)
     val rows = zip (map (fn x => ([]:term list,[x])) pats) (map GIVEN (enumerate R))
     val avs = all_varsl (L@R)
     (* a: the function's argument variable, fresh w.r.t. all clause vars *)
     val a = variant avs (mk_var("a", type_of(Lib.trye hd pats)))
     val FV = a::avs
     val range_ty = type_of (Lib.trye hd R)
     val (patts, case_tm) = mk_case0 (match_info thy) (match_type thy)
                                     FV range_ty {path=[a], rows=rows}
     (* rebuild GIVEN/OMITTED patterns from the (prefix, (tag,i), [pat]) rows *)
     fun func (_,(tag,i),[pat]) = tag (pat,i)
       | func _ = err "error in pattern-match translation"
     val patts1 = map func patts
     val (omits,givens) = Lib.partition is_omitted patts1
     val givens' = sort pattern_cmp givens
     val patts2 = givens' @ omits
     val finals = map row_of_pat patts2
     val originals = map (row_of_pat o #2) rows
     (* extra rows in the completed match = clauses added by completion *)
     val new_rows = length finals - length originals
     val clause_s = if new_rows = 1 then " clause " else " clauses "
     (* NOTE(review): the message names "Functional.allow_new_clauses", but
        the flag consulted here is this structure's allow_new_clauses *)
     val _ = if new_rows > 0 then
               (msg ("\n  pattern completion has added "^
                     Int.toString new_rows^clause_s^
                     "to the original specification.");
                if !allow_new_clauses then ()
                else
                  err ("new clauses not allowed under current setting of "^
                       Lib.quote("Functional.allow_new_clauses")^" flag"))
             else ()
     fun int_eq i1 (i2:int) =  (i1=i2)
     (* input rows that no longer appear in the compiled match *)
     val inaccessibles = filter(fn x => not(op_mem int_eq x finals)) originals
     fun accessible p = not(op_mem int_eq (row_of_pat p) inaccessibles)
     val patts3 = (case inaccessibles of [] => patts2
                        |  _ => filter accessible patts2)
     val _ = case inaccessibles of [] => ()
             | _ => msg("The following input rows (counting from zero) are\
       \ inaccessible: "^stringize inaccessibles^".\nThey have been ignored.")
     (* The next lines repair bound variable names in the nested case term. *)
     val case_tm' =
         let val (_,pat_exps) = strip_case thy case_tm
             val fvs = free_vars case_tm
             val ren = pat_match3 fvs pat_exps pats (* better pats than givens patts3 *)
         in (rename_top_bound_vars ren case_tm)
         end handle HOL_ERR _ =>
           (Feedback.HOL_WARNING "Pmatch" "mk_functional" "SHOULD NOT HAPPEN! RENAMING CASE_TM FAILED!";
            case_tm)
     (* Ensure that the case test variable is fresh for the rest of the case *)
     val avs = op_set_diff aconv (all_vars case_tm') [a]
     val a' = variant avs a
     val case_tm'' = if aconv a' a then case_tm'
                     else subst ([a |-> a']) case_tm'
 in
   {functional = list_mk_abs ([f,a'], case_tm''),
    pats = patts3}
 end
end;
830
831(*---------------------------------------------------------------------------
832   Given a list of (pattern,expression) pairs, mk_pattern_fn creates a term
833   as an abstraction containing a case expression on the function's argument.
834 ---------------------------------------------------------------------------*)
835
(* mk_pattern_fn thy pes: from (pattern, expression) pairs build
   \a. case a of pattern_1 => expr_1 | ...  using mk_functional on the
   clauses  f p_i = e_i  for a fresh function variable f.  Fails on an
   empty list, or when the patterns (or expressions) differ in type. *)
fun mk_pattern_fn thy (pes: (term * term) list) =
  let
    fun err s = raise ERR "mk_pattern_fn" s
    val (p0, e0) =
        Lib.trye hd pes
        handle HOL_ERR _ => err "empty list of (pattern,expression) pairs"
    val (pty, ety) = (type_of p0, type_of e0)
    val (ps, es) = unzip pes
    fun has_type ty tm = Lib.equal ty (type_of tm)
    val () = if List.all (has_type pty) ps then ()
             else err "patterns have varying types"
    val () = if List.all (has_type ety) es then ()
             else err "expressions have varying types"
    val fvar = genvar (pty --> ety)
    val eqs = list_mk_conj (map (fn (p, e) => mk_eq (mk_comb (fvar, p), e)) pes)
    val {functional, pats = _} = mk_functional thy eqs
  in
    (* strip the outer \f, leaving the abstraction over the argument *)
    #2 (dest_abs functional)
  end
853
854end;
855