1structure patternMatchesLib :> patternMatchesLib =
2struct
3
4open HolKernel Parse boolLib Drule BasicProvers
5open simpLib numLib metisLib
6open patternMatchesTheory
7open listTheory
8open quantHeuristicsLib
9open DatatypeSimps
10open patternMatchesSyntax
11open Traverse
12open constrFamiliesLib
13open unwindLib
14open oneSyntax
15
(* Shadow Parse locally so that the Type and Term parsers used by the
   quotations in this file are built (via parse_from_grammars) from
   patternMatchesTheory.patternMatches_grammars, independently of the
   grammar that happens to be active when this file is loaded. *)
structure Parse =
struct
  open Parse
  val (Type,Term) =
      parse_from_grammars patternMatchesTheory.patternMatches_grammars
end
open Parse

(* Arithmetic simpset (numLib.arith_ss) extended with the list
   fragment listSimps.LIST_ss. *)
val list_ss  = numLib.arith_ss ++ listSimps.LIST_ss
25
26(***********************************************)
27(* Auxiliary stuff                             *)
28(***********************************************)
29
(* Package a "generalised" conversion [c] as a simpset fragment.

   When the simplifier calls the fragment, [c] is invoked with the pair
   (ssl, SOME cb): [ssl] is the list of simpset fragments given here and
   [cb] is the simplifier's own conversion for the current context
   stack.  This lets [c] call back into the enclosing simplifier.  The
   fragment is a decision procedure named [name]; its context is passed
   through untouched, and QCHANGED_CONV rejects results that leave the
   term unchanged. *)
fun make_gen_conv_ss c name ssl = let
  exception genconv_reducer_exn
  fun keep_context (ctxt, _) = ctxt
  fun run_conv {solver, conv, context, stack, relation} tm =
    QCHANGED_CONV (c (ssl, SOME (conv stack))) tm
  val reducer = REDUCER {name = SOME name,
                         addcontext = keep_context,
                         apply = run_conv,
                         initial = genconv_reducer_exn}
in
  simpLib.dproc_ss reducer
end;
40
(* Often in the following, a single row needs extracting.
   Given a list l, we want the element at index n, the list of
   elements before it and the list of elements after it.
   So, we need an efficient way to compute:
   (List.take (l, n), List.nth (l, n), List.drop (l, n+1))

   extract_element [0,1,2,3,4,5] 0  = ([], 0, [1,2,3,4,5])
   extract_element [0,1,2,3,4,5] 1  = ([0], 1, [2,3,4,5])
   extract_element [0,1,2,3,4,5] 3  = ([0,1,2], 3, [4,5])
   extract_element [0,1,2,3,4,5] 5  = ([0,1,2,3,4], 5, [])
 *)
52
(* Split [l] around its [n]-th element (0-based) and return
   (elements before it, the element, elements after it).
   Fails with "index too large" when [l] has at most [n] elements;
   other invalid [n] are left to Lib.split_after. *)
fun extract_element l n =
  case Lib.split_after n l of
      (_, []) => failwith "index too large"
    | (front, x :: back) => (front, x, back)
60
61
62(* Similarly, we often need to replace an element with
63   a list of elements. We need an efficient way to compute
64
65   (List.take (l, n) @ new_elements @ List.drop (l, n+1),
66    List.nth (l, n))
67 *)
fun replace_element l n new = let
  (* [walk (k, seen, rest)]: [seen] holds the elements already passed,
     in reverse order; [k] counts down to the element to replace *)
  fun walk (_, _, []) = failwith "index too big"
    | walk (k, seen, y :: ys) =
        if k = 0 then (List.revAppend (seen, new @ ys), y)
        else walk (k - 1, y :: seen, ys)
in
  (* negative indices are rejected before the list is inspected *)
  if n < 0 then failwith "index too small" else walk (n, [], l)
end
79
80(* We have a problem with conversions that loop in a fancy way.
81   They add some pattern matching on the input variables and
82   in the body the original term with renamed variables. The
83   following function tries to detect this situation. *)
(* Heuristic loop detector for a conversion result [thm] of the form
   |- l = r.  Returns true if [r] contains a subterm that, once all of
   its free variables are abstracted (left to right), is alpha-equivalent
   to [l] abstracted the same way; i.e. [r] contains a copy of [l] up to
   a renaming of variables.  A cheap head-constant comparison is done
   before the alpha-equivalence check. *)
fun does_conv_loop thm = let
  val (lhs_tm, rhs_tm) = dest_eq (concl thm)
  fun close_up tm = list_mk_abs (free_vars_lr tm, tm)
  val lhs_closed = close_up lhs_tm
  val head_ok = let
    val lhs_head = fst (strip_comb lhs_tm)
  in
    fn tm => same_const (fst (strip_comb tm)) lhs_head
  end handle HOL_ERR _ => (fn _ => true)
  fun looks_like_lhs tm =
    head_ok tm andalso aconv lhs_closed (close_up tm)
in
  (find_term looks_like_lhs rhs_tm; true) handle HOL_ERR _ => false
end
98
99
100(***********************************************)
101(* Simpset to evaluate PMATCH_ROWS             *)
102(***********************************************)
103
(* |- (FST x = a /\ SND x = b) = (x = (a, b))

   Collapses component-wise equations on a pair back into one pair
   equation; used by FST_SND_CONJUNCT_COLLAPSE (via REWR_CONV) and as a
   rewrite in static_ss.  A single proof suffices: after the case split
   on x, std_ss evaluates FST/SND and splits the pair equation. *)
val PAIR_EQ_COLLAPSE = prove (
``(((FST x = (a:'a)) /\ (SND x = (b:'b))) = (x = (a, b)))``,
Cases_on `x` THEN SIMP_TAC std_ss [])
111
(* [is_FST_eq x t] holds iff [t] is an equation whose left-hand side is
   alpha-equivalent to FST x.  Raises HOL_ERR if [t] is not an equation
   or if FST x cannot be built. *)
fun is_FST_eq x t = let
  val (eq_lhs, _) = dest_eq t
in
  aconv (pairSyntax.mk_fst x) eq_lhs
end
118
(* Try to turn a conjunction [conj] of component equations on the pair
   [v], e.g.  FST v = a /\ SND v = b,  into the single equation
   v = (a, b).  The conjunct of the form FST v = ... is first moved to
   the front (markerLib.move_conj_left).  The remaining conjuncts are
   then, if possible, collapsed recursively with respect to SND v.
   Finally PAIR_EQ_COLLAPSE is applied to the whole conjunction.
   Returns |- conj = (v = ...).  Every HOL_ERR is turned into UNCHANGED. *)
fun FST_SND_CONJUNCT_COLLAPSE v conj = let
  val conj'_thm = markerLib.move_conj_left (is_FST_eq v) conj

  (* the second component is collapsed with respect to SND v *)
  val v' = pairSyntax.mk_snd v

  val thm_coll = (TRY_CONV (RAND_CONV (FST_SND_CONJUNCT_COLLAPSE v')) THENC
   (REWR_CONV PAIR_EQ_COLLAPSE))
    (rhs (concl conj'_thm))
in
  TRANS conj'_thm thm_coll
end handle HOL_ERR _ => raise UNCHANGED
130
(* Simplify a select term  @v. (FST v = a /\ SND v = b ...)  to the
   tuple  (a, b ...).  The body of the select is first collapsed with
   FST_SND_CONJUNCT_COLLAPSE into  v = (a, b ...); the resulting
   @v. v = (a, b ...)  is then reduced with SELECT_REFL.
   Raises UNCHANGED if [t] does not have this shape. *)
fun ELIM_FST_SND_SELECT_CONV t = let
  val (v, conj) = boolSyntax.dest_select t
  val thm0 = FST_SND_CONJUNCT_COLLAPSE v conj

  (* rewrite the body of the select with thm0 *)
  val thm1 = RAND_CONV (ABS_CONV (K thm0)) t
  val thm2 = CONV_RULE (RHS_CONV (REWR_CONV SELECT_REFL)) thm1
in
  thm2
end handle HOL_ERR _ => raise UNCHANGED
140
141
142(*
143val rc = DEPTH_CONV pairTools.PABS_ELIM_CONV THENC SIMP_CONV list_ss [pairTheory.EXISTS_PROD, pairTheory.FORALL_PROD, PMATCH_ROW_EQ_NONE, PAIR_EQ_COLLAPSE, oneTheory.one]
144*)
145
(* Simpset fragment running pairTools.PABS_ELIM_CONV on terms that
   match the key  UNCURRY f. *)
val pabs_elim_ss =
    simpLib.conv_ss
      {name  = "PABS_ELIM_CONV",
       trace = 2,
       key   = SOME ([],``UNCURRY (f:'a -> 'b -> bool)``),
       conv  = K (K pairTools.PABS_ELIM_CONV)}

(* Simpset fragment running ELIM_FST_SND_SELECT_CONV on select terms
   (key  $@ f). *)
val elim_fst_snd_select_ss =
    simpLib.conv_ss
      {name  = "ELIM_FST_SND_SELECT_CONV",
       trace = 2,
       key   = SOME ([],``$@ (f:'a -> bool)``),
       conv  = K (K ELIM_FST_SND_SELECT_CONV)}

(* Simpset fragment that simplifies select terms (key  $@ f) with a
   separate simplifier call using std_ss extended by boolSimps.CONJ_ss. *)
val select_conj_ss =
    simpLib.conv_ss
      {name  = "SELECT_CONJ_SS_CONV",
       trace = 2,
       key   = SOME ([],``$@ (f:'a -> bool)``),
       conv  = K (K (SIMP_CONV (std_ss++boolSimps.CONJ_ss) []))};
166
(* A basic simpset-fragment with a lot of useful stuff
   to automatically show the validity of preconditions
   as produced by functions in this library.  It merges the
   conversion fragments defined above, the paired quantifier and
   generalised beta fragments from pairSimps, simple quantifier
   instantiation, and rewrites about pairs, unit and PMATCH rows. *)
val static_ss = simpLib.merge_ss
  [pabs_elim_ss,
   pairSimps.paired_forall_ss,
   pairSimps.paired_exists_ss,
   pairSimps.gen_beta_ss,
   select_conj_ss,
   elim_fst_snd_select_ss,
   boolSimps.EQUIV_EXTRACT_ss,
   quantHeuristicsLib.SIMPLE_QUANT_INST_ss,
   simpLib.rewrites [
     some_var_bool_T, some_var_bool_F,
     GSYM boolTheory.F_DEF,
     pairTheory.EXISTS_PROD,
     pairTheory.FORALL_PROD,
     PMATCH_ROW_EQ_NONE,
     PMATCH_ROW_COND_def,
     PMATCH_ROW_COND_EX_def,
     PAIR_EQ_COLLAPSE,
     oneTheory.one]];
189
(* We add the stateful rewrite set (to simplify
   e.g. case-constants or constructors) and a
   custom component [gl] (a list of simpset fragments) as well.
   The fragment named "patternMatchesSimp" is removed from the
   result.  NOTE(review): presumably this keeps the simplifier from
   calling back into this library's own PMATCH simprocs while
   proving side conditions -- confirm. *)
fun rc_ss gl = simpLib.remove_ssfrags (srw_ss() ++ simpLib.merge_ss (static_ss :: gl)) ["patternMatchesSimp"]
194
(* Finally we add a call-back component: an external conversion that is
   tried at the end, when everything else fails.  It enables nested calls
   of the simplifier: the simplifier runs a conversion of this library,
   which internally uses rc_ss, and at the end that conversion may hand
   the term back to the external simplifier.

   [callback_CONV cb_opt t] fails (NO_CONV) if there is no call-back or
   if [t] still contains a PMATCH; otherwise it applies the call-back. *)
fun callback_CONV NONE t = NO_CONV t
  | callback_CONV (SOME cb) t =
      if can (find_term is_PMATCH) t then NO_CONV t else cb t;
206
(* Simplify with rc_ss (extended by the fragments [gl]) and the theorems
   [thms], then try the optional call-back; REPEATC repeats this pass
   for as long as it succeeds and changes the term. *)
fun rc_conv_rws (gl, callback_opt) thms = let
  val one_pass = SIMP_CONV (rc_ss gl) thms THENC
                 TRY_CONV (callback_CONV callback_opt)
in
  REPEATC one_pass
end

(* So, now combine it to get some convenient high-level
   functions. *)

(* rc_conv_rws without extra rewrite theorems *)
fun rc_conv rc_arg = rc_conv_rws rc_arg []

(* tactic version of rc_conv *)
fun rc_tac (gl, callback_opt) = CONV_TAC (rc_conv (gl, callback_opt))

(* Given |- pre ==> c, prove [pre] with rc_tac and return |- c. *)
fun rc_elim_precond rc_arg thm = let
  val pre = rand (rator (concl thm))
in
  MP thm (prove_attempt (pre, rc_tac rc_arg))
end
225
226(* fix_appends expects a theorem of the form
227   PMATCH v rows = PMATCH v' rows'
228
229   and a term l of form
230   PMATCH v rows0.
231
232   It tries to get the appends in rows and rows' in
233   a nice form. To do this, it tries to prove that
234   l and the lhs of the theorem are equal.
235   Then it tries to simplify appends in rows'
236   resulting in rows''.
237
238   It returns a theorem of the form
239
240   l = PMATCH v' rows''.
241*)
fun fix_appends rc_arg l thm = let
  (* prove l = lhs of thm by evaluating all appends and then
     finishing with rc_tac *)
  val t_eq_thm = prove (mk_eq (l, lhs (concl thm)),
     CONV_TAC (DEPTH_CONV listLib.APPEND_CONV) THEN
     rc_tac rc_arg)

  val thm2 = TRANS t_eq_thm thm

  (* evaluate a (nested) append, operands first; raises UNCHANGED on
     terms that are not appends *)
  fun my_append_conv t = let
    val _ = if listSyntax.is_append t then () else raise UNCHANGED
  in
    (BINOP_CONV (TRY_CONV my_append_conv) THENC
     listLib.APPEND_CONV) t
  end

  (* simplify the appends in the rows of the rhs; if that fails,
     thm2 is returned as it is *)
  val thm3 = CONV_RULE (RHS_CONV (RAND_CONV my_append_conv)) thm2
    handle HOL_ERR _ => thm2
         | UNCHANGED => thm2
in
  thm3
end
262
(* Apply a conversion to all args of a PMATCH_ROW, i.e. given
   a term of the form ``PMATCH_ROW pat guard rhs i``
   it applies [c] to ``rhs``, then ``guard``, then ``pat``.
   Failures of [c] on an individual argument are ignored. *)
fun PMATCH_ROW_ARGS_CONV c = let
  val try_c = TRY_CONV c
  (* descend [k] times into the operator position before applying cnv *)
  fun nth_rator 0 cnv = cnv
    | nth_rator k cnv = RATOR_CONV (nth_rator (k - 1) cnv)
in
  nth_rator 1 (RAND_CONV try_c) THENC
  nth_rator 2 (RAND_CONV try_c) THENC
  nth_rator 3 (RAND_CONV try_c)
end
270
271
272(***********************************************)
273(* converting between case-splits to PMATCH    *)
274(***********************************************)
275
276(* ----------------------- *)
277(* Auxiliary functions for *)
278(* case2pmatch             *)
279(* ----------------------- *)
280
281(*
282val t = ``case x of
283  (NONE, []) => 0`` *)
284
(* Theory and type-operator name of the type [ty]; fails (via
   Type.dest_thy_type) if [ty] is a type variable. *)
fun type_names ty = let
  val {Thy = thy, Tyop = tyop, Args = _} = Type.dest_thy_type ty
in
  {Thy = thy, Tyop = tyop}
end;
289
290(* destruct variant cases, see dest_case_fun *)
(* destruct variant cases, see dest_case_fun.
   [t] must be an application of the case constant of the datatype of
   its first argument [a] (the scrutinee).  Each further argument is a
   (possibly nullary) abstraction  \x1 .. xn. rh  belonging to the
   corresponding constructor C, and becomes the row (C x1 .. xn, rh).
   Returns (a, rows).  Fails with "dest_case_fun" otherwise. *)
fun dest_case_fun_aux1 t = let
  val (f, args) = strip_comb t
  val _ = if (List.null args) then failwith "dest_case_fun" else ()
  val ty = case fst (strip_fun (type_of f)) of
      [] => failwith "dest_case_fun"
    | (ty::_) => ty
  val ti = case TypeBase.fetch ty of
      NONE => failwith "dest_case_fun"
    | SOME ti => ti

  val case_c = TypeBasePure.case_const_of ti
  val _ = if (same_const case_c f) then () else failwith "dest_case_fun"

  (* instantiate the generic constructors to the types used in [t] *)
  val ty_s = match_type (type_of case_c) (type_of f)
  val constrs = List.map (inst ty_s) (TypeBasePure.constructors_of ti)

  (* pair constructor C with  \x1 .. xn. res  to get  (C x1 .. xn, res) *)
  fun mk_row c arg = let
    val (vars, res) = strip_abs arg
  in
    (list_mk_comb (c, vars), res)
  end
in
  (hd args, map2 mk_row constrs (tl args))
end
316
317(* destruct literal cases, see dest_case_fun *)
(* destruct literal cases, see dest_case_fun.
   A literal case has the shape
     literal_case (\v'. if v' = c1 then r1 else ... else r_else) v
   and is destructed into
     (v, [(c1, r1), ..., (v', r_else)]),
   i.e. the bound variable v' serves as catch-all pattern of the final
   row.  Fails with "dest_case_fun" if [t] is no literal case. *)
fun dest_case_fun_aux2 t = let
  val _ = if is_literal_case t then () else failwith "dest_case_fun"

  val (f, args) = strip_comb t

  val v = (el 2 args)
  val (v', b) = dest_abs (el 1 args)

  (* peel off  if v' = c then r else ...  layers; the rows are
     accumulated in reverse order; stops at the first other term *)
  fun strip_cond acc b = let
    val (c, t_t, t_f) = dest_cond b
    val (c_l, c_r) = dest_eq c
    val _ = if (aconv c_l v') then () else failwith "dest_case_fun"
  in
    strip_cond ((c_r, t_t)::acc) t_f
  end handle HOL_ERR _ => (acc, b)

  val (ps_rev, c_else) = strip_cond [] b
  val ps = List.rev ((v', c_else) :: ps_rev)
in
  (v, ps)
end
339
340
(* destruct a case-function.
   The top-most split is split into the input + a list of rows.
   Each row consists of a pattern + the right-hand side.
   Datatype case constants are tried first (dest_case_fun_aux1);
   if that fails, literal cases are tried (dest_case_fun_aux2). *)
fun dest_case_fun t = dest_case_fun_aux1 t handle HOL_ERR _ => dest_case_fun_aux2 t
345
346
(* try to collapse rows by introducing a catchall at end.
   Input (a, ps) is as returned by dest_case_fun.  A row (p, rh) yields
   a catch-all candidate rh[a/p] if that no longer mentions any variable
   of p.  For each candidate, the rows it does not already cover are
   collected, and the candidate leaving the fewest such rows is chosen.
   The result is  (a, uncovered_rows @ [(a, candidate)])  if this saves
   at least one row, and (a, ps) unchanged otherwise. *)
fun dest_case_fun_collapse (a, ps) = let

  (* find all possible catch-all clauses *)
  fun check_collapsable (p, rh) = let
     val p_vs = FVL [p] empty_tmset
     val rh' = if HOLset.isEmpty p_vs then rh else
        Term.subst [p |-> a] rh
     val ok = HOLset.isEmpty (HOLset.intersection (FVL [rh'] empty_tmset, p_vs))
  in
    if ok then SOME rh' else NONE
  end

  val catch_all_cands = List.foldl (fn (prh, cs) =>
      case check_collapsable prh of
         NONE => cs
       | SOME rh => rh::cs) [] ps

  (* really collapse: a row (p, rh) is covered by the catch-all [ca] if
     instantiating [a] with p in [ca] yields rh *)
  fun is_not_cought ca (p, rh) =
     not (aconv rh (Term.subst [a |-> p] ca))

  val all_collapse_opts = List.map (fn ca => (ca, filter (is_not_cought ca) ps)) catch_all_cands

  (* fewest uncovered rows first *)
  val all_collapse_opts_sorted = sort (fn (_, l1) => fn (_, l2) => List.length l1 < List.length l2) all_collapse_opts

in
  if (List.null all_collapse_opts) then (a, ps) else
  let
     val (ca', ps') = hd all_collapse_opts_sorted
  (* could we collapse 2 cases?  i.e. only collapse if at least two
     rows are merged into the catch-all *)
  in if (List.length ps' + 1 < List.length ps) then
        (a, (ps' @ [(a, ca')]))
      else (a, ps)
  end
end
383
(* Flatten nested case-splits into a list of (pattern, rhs) rows.
   [x] is the pattern built so far (initially a fresh variable standing
   for the whole input).  [t] must be a case-split on a variable [a]
   occurring free in [x].  Each row (p, rh) of that split refines the
   pattern to x[p/a] and is processed recursively on [rh].  If [rh] is
   no such case-split, the row (x[p/a], rh) is emitted.  With [optimise]
   every level is first passed through dest_case_fun_collapse.
   Returns NONE if [t] is not a case-split on a pattern variable. *)
fun case2pmatch_aux optimise x t = let
  val (a, ps) = dest_case_fun t
  val _ = if is_var a andalso free_in a x then () else failwith "case-split on non pattern var"
  val (a, ps) = if optimise then (dest_case_fun_collapse (a, ps)) else (a, ps)

  (* refine the pattern with the row's pattern, then recurse into the rhs *)
  fun process_arg (p, rh) = let
    val x' = subst [a |-> p] x
  in
    (* recursive call *)
    case case2pmatch_aux optimise x' rh of
        NONE => [(x', rh)]
      | SOME resl => resl
  end

  val ps = flatten (map process_arg ps)
in
  SOME ps
end handle HOL_ERR _ => NONE;
402
(* Remove rows that are not needed from a list [ps] of (pattern, rhs)
   rows given in match order.  First a catch-all row (fresh variable,
   ARB) is appended; it is dropped again at the end.  A row r1 is
   dropped if the first later row whose pattern unifies with r1's
   pattern also subsumes r1, i.e. r1's pattern is an instance of it
   and the right-hand sides agree under that instantiation.  Hence ARB
   rows that only overlap the appended catch-all are removed as well.
   NOTE(review): row_subsumed ignores the type substitution returned by
   match_term -- presumably fine since all rows share their types. *)
fun case2pmatch_remove_unnessary_rows ps = let
  (* rename the free variables of the second row's pattern (in pattern
     and rhs) apart from the free variables of p1 *)
  fun mk_distinct_rows (p1, _) (p2, rh2) = let
     val avoid = free_vars p1
     val (s, _) = List.foldl (fn (v, (s, av)) =>
       let val v' = variant av v in
       ((v |-> v')::s, v'::av) end) ([], avoid) (free_vars p2)
     val p2' = Term.subst s p2
     val rh2' = Term.subst s rh2
  in
     (p2', rh2')
  end

  (* do the two patterns unify? *)
  fun pats_unify (p1, _) (p2, _) = (
    (Unify.simp_unify_terms [] p1 p2; true) handle HOL_ERR _ => false
  )

  (* is row 1 an instance of row 2, including the right-hand side? *)
  fun row_subsumed (p1, rh1) (p2, rh2) = let
     val (s, _) = match_term p2 p1
     val rh2' = Term.subst s rh2
  in aconv rh2' rh1 end handle HOL_ERR _ => false

  (* r1 is needed unless the first later row overlapping it subsumes it *)
  fun row_is_needed r1 rs = case rs of
      [] => true
    | r2::rs' => let
         val r2' = mk_distinct_rows r1 r2
      in
         if pats_unify r1 r2' then (
           not (row_subsumed r1 r2')
         ) else row_is_needed r1 rs'
      end

   fun check_rows acc rs = case rs of
       [] => List.rev acc
     | [_] => (* drop last one *) List.rev acc
     | r::rs' => check_rows (if row_is_needed r rs' then r::acc else acc)
                  rs'

   (* append the catch-all row with rhs ARB; check_rows drops it again *)
   val ps' = case ps of
                [] => []
              | (p, rh)::_ => (ps @ [(genvar (type_of p), mk_arb (type_of rh))])

in
  check_rows [] ps'
end
447
448
449(* ----------------------- *)
450(* End Auxiliary functions *)
451(* for case2pmatch         *)
452(* ----------------------- *)
453
454(*
455val (p1, rh1) = el 5 ps
456val (p2, rh2) = mk_distinct_rows (p1, rh1) (el 6 ps)
457ps
458*)
459
(* Convert a case-term into a PMATCH-term, without any proofs.
   [opt] enables optimisations: collapsing rows into catch-all rows,
   removing unnecessary rows and using wildcards for pattern variables.
   Fails with "not a case-split" if [t] is not a case expression.
   Optimisation may remove every row (e.g. when all right-hand sides
   are ARB).  In that case the unoptimised rows are used, so the
   resulting PMATCH always has at least one row. *)
fun case2pmatch opt t = let
  val (f, args) = strip_comb t
  val _ = if (List.null args) then failwith "not a case-split" else ()

  (* literal cases take the scrutinee as second argument, variant
     cases as first argument *)
  val (p,patterns) = if is_literal_case t then (el 2 args, [el 1 args]) else
      (hd args, tl args)
  val v = genvar (type_of p)

  val t0 = if is_literal_case t then list_mk_comb (f, patterns @ [v]) else list_mk_comb (f, v::patterns)
  val ps = case case2pmatch_aux opt v t0 of
      NONE => failwith "not a case-split"
    | SOME ps => ps

  (* row removal can drop all rows; fall back to the unoptimised rows
     then instead of building an (ill-typed, [hd]-crashing) empty list *)
  val ps = if opt then
             (case case2pmatch_remove_unnessary_rows ps of
                  [] => ps
                | ps' => ps')
           else ps
  val _ = if List.null ps then failwith "not a case-split" else ()

  fun process_pattern (p, rh) = let
    val fvs = List.rev (free_vars p)
  in
    if opt then
      snd (mk_PMATCH_ROW_PABS_WILDCARDS fvs (p, T, rh))
    else
      mk_PMATCH_ROW_PABS fvs (p, T, rh)
  end
  val rows = List.map process_pattern ps
  val rows_tm = listSyntax.mk_list (rows, type_of (hd rows))

  (* replace the placeholder input by the real scrutinee *)
  val rows_tm_p = Term.subst [v |-> p] rows_tm
in
  mk_PMATCH p rows_tm_p
end
491
492(* So far, we converted a classical case-expression
493   to a PMATCH without any proof. The following is used
494   to prove the equivalence of the result via repeated
495   case-splits and evaluation. This allows to
496   define some conversions then. *)
497
(* |- (c = c') ==> ((if c then x else y) = (if c' then x else y))
   When used as a congruence rule, only the condition of a conditional
   is rewritten and both branches are left untouched.
   NOTE(review): not referenced in this part of the file. *)
val COND_CONG_STOP = prove (``
  (c = c') ==> ((if c then x else y) = (if c' then x else y))``,
SIMP_TAC std_ss [])
501
(* Prove |- t = t' by brute force, where one side is a classical
   case-expression and the other the corresponding PMATCH-expression.
   The tactic repeatedly splits on the top-most case-split, rewrites
   with the assumptions and simplifies with PMATCH_EVAL.
   Any failure is reported as UNCHANGED. *)
fun case_pmatch_eq_prove t t' = let
  val tm = mk_eq (t, t')

  (* very slow, simple approach. Just brute force.
     TODO: change implementation to get more runtime-speed *)
  val my_tac = (
    REPEAT (BasicProvers.TOP_CASE_TAC THEN
            ASM_REWRITE_TAC[]) THEN
    FULL_SIMP_TAC (rc_ss []) [PMATCH_EVAL, PMATCH_ROW_COND_def,
      PMATCH_INCOMPLETE_def]
  )
in
  (* set_goal ([], tm) *)
  prove (tm, REPEAT my_tac)
end handle HOL_ERR _ => raise UNCHANGED
517
518
(* Prove  |- t = PMATCH ...  for a case-expression [t]: translate [t]
   with case2pmatch (with or without optimisations) and prove the
   equivalence with case_pmatch_eq_prove. *)
local
  fun intro_conv optimise t =
    case_pmatch_eq_prove t (case2pmatch optimise t)
in
  fun PMATCH_INTRO_CONV t = intro_conv true t
  fun PMATCH_INTRO_CONV_NO_OPTIMISE t = intro_conv false t
end
524
525
526(* ------------------------- *)
527(* pmatch2case               *)
528(* ------------------------- *)
529
(* convert a PMATCH-term into a case-term, without any proofs.
   Every row must have guard T, and its pattern must not contain free
   variables other than the variables bound by the row; otherwise this
   fails.  The rows are turned into equations  fv pat = rhs  for a
   fresh function variable fv, which are compiled by
   GrammarSpecials.compile_pattern_match. *)
fun pmatch2case t = let
  val (v, rows) = dest_PMATCH t
  val fv = genvar (type_of v --> type_of t)

  (* turn one row into the equation  fv pt = rh *)
  fun process_row r = let
     val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS r
     val _ = if (aconv gt T) then () else
       failwith ("guard present in row " ^
           (term_to_string r))

     (* note: this local free_vars (a term set) shadows the HOL function *)
     val vars = FVL [vars_tm] empty_tmset
     val used_vars = FVL [pt] empty_tmset
     val free_vars = HOLset.difference (used_vars, vars)
     val _ = if (HOLset.isEmpty free_vars) then () else
       failwith ("free variables in pattern " ^ (term_to_string pt))
  in
     mk_eq (mk_comb (fv, pt), rh)
  end

  val row_eqs = map process_row rows
  val rows_tm = list_mk_conj row_eqs

  (* compile patterns *)
  val case_tm0 = GrammarSpecials.compile_pattern_match rows_tm


  (* nearly there, now remove lambda's: the second bound variable of
     the compiled term is replaced by the PMATCH input v.
     NOTE(review): assumes the compiled term has the shape \fv x. ... *)
  val (vs, case_tm1) = strip_abs case_tm0
  val case_tm = subst [el 2 vs |-> v] case_tm1
in
  case_tm
end
563
(* Prove  |- t = case ...  for a PMATCH-term [t]: translate it with
   pmatch2case and prove the equivalence by brute force. *)
fun PMATCH_ELIM_CONV t = let
  val case_tm = pmatch2case t
in
  case_pmatch_eq_prove t case_tm
end
566
567
568
569(***********************************************)
570(* removing redundant rows                     *)
571(***********************************************)
572
573(*
574val rc_arg = ([], NONE)
575
576val t = ``
577   case l of
578     | [] => 0
579     | x::y::x::y::_ => (x + y)
580     | x::x::x::x::_ when (x > 10) => x
581     | x::x::x::x::x::_ => 9
582     | [] => 1
583     | x::x::x::y::_ => (x + x + x)
584     | x::_ => 1
585     | x::y::z::_ => (x + x + x)
586   ``
587
588val (rows, _) = listSyntax.dest_list (rand t)
589*)
590
(* For removing redundant rows we want to check whether
   the pattern of a row is overlapped by the pattern of a
   previous row. In preparation for this, we extract all
   patterns and generate fresh variables for them. Then we
   build, for every row, the pairs of its pattern with the
   patterns of all following rows. This allows for simple checks
   via matching later.

   Every pattern is represented as (i, (vars, pat)): [i] is the
   0-based position of the row in [rows], [pat] the pattern body with
   its bound variables replaced by fresh ones, and [vars] these fresh
   variables.  Rows that cannot be destructed are skipped, but their
   position is still counted, so [i] always refers to the original
   [rows] list (callers pass it to extract_element on [rows]). *)
fun compute_row_pat_pairs rows = let
  (* get pats with fresh vars to do a quick prefiltering *)
  fun fresh_pat (i, r) = let
    val (p, _, _) = dest_PMATCH_ROW r
    val (vars_tm, pb) = pairSyntax.dest_pabs p
    val vars = pairSyntax.strip_pair vars_tm
    val s = List.map (fn v => (v |-> genvar (type_of v))) vars
    val vars' = map (fn x => #residue x) s
    val pb' = subst s pb
  in
    (i, (vars', pb'))
  end

  (* enumerate before filtering, so that indices stay aligned with rows *)
  val pats_unique = Lib.mapfilter fresh_pat (Lib.enumerate 0 rows)

  (* get all pairs, first component always appears before second *)
  val candidates = let
    fun aux acc l = case l of
       [] => acc
     | (x::xs) => aux ((List.map (fn y => (x, y)) xs) @ acc) xs
  in
    aux [] pats_unique
  end
in
  candidates
end
622
623(* Now do the real filtering *)
(* Try to remove one redundant row from the PMATCH term [t].  Candidate
   rows r2 have a pattern that is an instance of the pattern of an
   earlier row r1; they are found cheaply by matching the fresh-variable
   patterns from compute_row_pat_pairs.  For each candidate pair,
   PMATCH_ROWS_DROP_REDUNDANT_PMATCH_ROWS is instantiated, its
   precondition is discharged with rc_elim_precond, and the appends in
   the result are cleaned up with fix_appends.  The first pair that
   succeeds is used.  Raises UNCHANGED if no pair succeeds. *)
fun PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL_SINGLE rc_arg t = let
  val (v, rows) = dest_PMATCH t
  val candidates = compute_row_pat_pairs rows

  (* quick filter on matching: p2 must be an instance of p1, obtained
     by instantiating only p1's (fresh) pattern variables *)
  val candidates_match = let
     fun does_match ((_, (v1, p1)), (_, (v2, p2))) =
     let
        val (t_s, ty_s) = match_term p1 p2
     in
        (null ty_s) andalso
        (Lib.all (fn x => mem (#redex x) v1) t_s)
     end handle HOL_ERR _ => false
  in
     List.filter does_match candidates
  end

  (* filtering finished, now try it for real; keep only row indices *)
  val cands = List.map (fn ((p1, _), (p2, _)) => (p1, p2)) candidates_match
  (* val (r_no1, r_no2) = el 1 cands *)
  fun try_pair (r_no1, r_no2) = let
    (* rebuild the PMATCH as  rows1 ++ (r1 :: rows2) ++ (r2 :: rows3) *)
    val tm0 = let
      val (rows1, r1, rows_rest) = extract_element rows r_no1
      val (rows2, r2, rows3) = extract_element rows_rest (r_no2 - r_no1 - 1)

      val rows1_tm = listSyntax.mk_list (rows1, type_of r1)
      val rows2_tm = listSyntax.mk_list (rows2, type_of r1)
      val r1rows2_tm = listSyntax.mk_cons (r1, rows2_tm)
      val rows3_tm = listSyntax.mk_list (rows3, type_of r1)
      val r2rows3_tm = listSyntax.mk_cons (r2, rows3_tm)

      val arg = listSyntax.list_mk_append [rows1_tm, r1rows2_tm, r2rows3_tm]
    in
      mk_PMATCH v arg
    end

    val thm0 = FRESH_TY_VARS_RULE PMATCH_ROWS_DROP_REDUNDANT_PMATCH_ROWS
    val thm1 = PART_MATCH (lhs o rand) thm0 tm0

    val thm2 = rc_elim_precond rc_arg thm1
    val thm3 = fix_appends rc_arg t thm2
  in
    thm3
  end
in
  Lib.tryfind try_pair cands
end handle HOL_ERR _ => raise UNCHANGED
671
(* remove redundant rows one at a time, as long as this succeeds *)
fun PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL rc_arg = REPEATC (PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL_SINGLE rc_arg)
(* variant without call-back, using the extra simpset fragments [ssl] *)
fun PMATCH_REMOVE_FAST_REDUNDANT_CONV_GEN ssl = PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL (ssl, NONE)
(* variant without call-back and without extra simpset fragments *)
val PMATCH_REMOVE_FAST_REDUNDANT_CONV = PMATCH_REMOVE_FAST_REDUNDANT_CONV_GEN []
675
676
677(***********************************************)
678(* removing subsumed rows                       *)
679(***********************************************)
680
681(*
682val rc_arg = ([], NONE)
683
684set_trace "parse deep cases" 0
685val t = case2pmatch false ``case x of NONE => 0``
686
687val t = case2pmatch false ``case (x, y, z) of
688   (0, y, z) => 2
689 | (x, NONE, []) => x
690 | (x, SOME y, l) => x+y``
691
692val t =
693   ``case (x,y,z) of
694    | (0,v1) => 2
695    | (SUC v4,NONE,[]) => (SUC v4)
696    | (SUC v4,NONE,v10::v11) => ARB
697    | (v4,NONE,_) => v4
698    | (0,SOME _ ,_) => ARB
699    | (SUC v4,SOME v9,v8) => (SUC v4 + v9)
700  ``
701
702*)
703
(* When removing subsumed rows, i.e. rows that can be dropped,
   because a following rule covers them, we can sometimes drop rows with
   right-hand-side ARB, because PMATCH v [] evaluates to ARB.
   This is semantically fine, but changes the user's view. The resulting
   case expression might e.g. not be exhaustive any more. This can
   also cause trouble for code generation. Therefore the parameter
   [exploit_match_exp] determines, whether this optimisation is performed.

   Each call removes at most one row; raises UNCHANGED if no candidate
   row can be removed. *)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL_SINGLE
  exploit_match_exp rc_arg t = let
  val (v, rows) = dest_PMATCH t
  val candidates = compute_row_pat_pairs rows

  (* quick filter on matching: the pattern of the earlier row must be an
     instance of the pattern of the later row *)
  val candidates_match = let
     fun does_match ((_, (v1, p1)), (_, (v2, p2))) =
     let
        val (t_s, ty_s) = match_term p2 p1
     in
        (null ty_s) andalso
        (Lib.all (fn x => mem (#redex x) v2) t_s)
     end handle HOL_ERR _ => false
  in
     List.filter does_match candidates
  end

  val cands_sub = List.map (fn ((p1, _), (p2, _)) => (p1, SOME p2)) candidates_match

  (* rows whose right-hand side is ARB, paired with NONE *)
  fun cands_arb () = Lib.mapfilter (fn (i, r) => let
     val (_, _, _, r) = dest_PMATCH_ROW_ABS r in
   (dest_arb r; (i, (NONE : int option))) end) (Lib.enumerate 0 rows)

  val cands = if exploit_match_exp then (cands_sub @ cands_arb ()) else
    cands_sub

  (* filtering finished, now try it for real *)
  (* val (r_no1, r_no2_opt) = el 2 cands_arb *)
  fun try_pair (r_no1, r_no2_opt) = let
    fun mk_row_list rs = listSyntax.mk_list (rs, type_of (hd rows))

    (* split rs at index n; return the rows after it and a function that
       builds  rows_before ++ (row_n :: rest)  from a term rest *)
    fun extract_el_n n rs = let
      val (rows1,r1,rows_rest) = extract_element rs n
      val rows1_tm = mk_row_list rows1

      fun build_tm rest_tm =
        listSyntax.mk_append (rows1_tm,
          (listSyntax.mk_cons (r1, rest_tm)))
    in
      (rows_rest, build_tm)
    end

    val tm0 = let
       val (rs_rest, bf_1) = extract_el_n r_no1 rows

       val rs2 = case r_no2_opt of
           NONE => mk_row_list rs_rest
         | SOME n => let
             val n' = n - r_no1 - 1
             val (rs_rest', bf_2) = extract_el_n n' rs_rest
           in
             bf_2 (mk_row_list rs_rest')
           end
    in
      mk_PMATCH v (bf_1 rs2)
    end

    (* ARB rows and subsumed rows are removed by different theorems *)
    val thm_base = case r_no2_opt of
        NONE => PMATCH_REMOVE_ARB_NO_OVERLAP
      | SOME _ => PMATCH_ROWS_DROP_SUBSUMED_PMATCH_ROWS
    val thm0 = FRESH_TY_VARS_RULE thm_base
    val thm1 = PART_MATCH (lhs o rand) thm0 tm0

    val thm2 = rc_elim_precond rc_arg thm1
    val thm3 = fix_appends rc_arg t thm2
  in
    thm3
  end
in
  Lib.tryfind try_pair cands
end handle HOL_ERR _ => raise UNCHANGED
784
(* remove subsumed rows one at a time, as long as this succeeds;
   [eme] is the exploit_match_exp flag of the _SINGLE version *)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL eme rc_arg = REPEATC (PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL_SINGLE eme rc_arg)
(* variant without call-back, using the extra simpset fragments [ssl] *)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV_GEN eme ssl = PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL eme (ssl, NONE)
(* variant without call-back and without extra simpset fragments *)
fun PMATCH_REMOVE_FAST_SUBSUMED_CONV eme = PMATCH_REMOVE_FAST_SUBSUMED_CONV_GEN eme []
788
789
790(***********************************************)
791(* Cleaning up unused vars in PMATCH_ROW       *)
792(***********************************************)
793
794(*val t = ``
795PMATCH (SOME x, xz)
796     [PMATCH_ROW (\x. (SOME 2,x,[])) (\x. T) (\x. x);
797      PMATCH_ROW (\y:'a. ((SOME 2,3,[]))) (\y. T) (\y. x);
798      PMATCH_ROW (\(z,x,yy). (z,x,[2])) (\(z,x,yy). T) (\(z,x,yy). x)]``
799*)
800
801
802(* Many simps depend on patterns being injective. This means
803   in particular that no extra, unused vars occur in the patterns.
804   The following removes such unused vars. *)
805
(* Rewrite the rows of the PMATCH term [t] so that every row binds only
   variables occurring in its pattern or right-hand side.  Each
   rewritten row is proved equal to the original one with
   PMATCH_ROW_EQ_AUX and rc_tac.  Raises UNCHANGED if [t] is no PMATCH
   or nothing changes.
   NOTE(review): variables occurring only in the guard are not counted
   as used; for such rows the equality proof presumably fails and the
   row is left unchanged -- confirm whether gt should be included. *)
fun PMATCH_CLEANUP_PVARS_CONV t = let
  val _ = if is_PMATCH t then () else raise UNCHANGED

  fun row_conv row = let
     val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row
     (* a row binding a unit-typed variable is left alone *)
     val _ = if (type_of vars_tm = one_ty) then raise UNCHANGED else ()
     val vars = pairSyntax.strip_pair vars_tm
     val used_vars = FVL [pt, rh] empty_tmset

     val filtered_vars = filter (fn v => HOLset.member (used_vars, v)) vars

     val _ = if (length vars = length filtered_vars) then
       raise UNCHANGED else ()

     val row' = mk_PMATCH_ROW_PABS filtered_vars (pt, gt, rh)

     val eq_tm = mk_eq (row, row')
     (* set_goal ([], eq_tm) *)
     val eq_thm = prove (eq_tm,
        MATCH_MP_TAC PMATCH_ROW_EQ_AUX THEN
        rc_tac ([], NONE)
     )
  in
     eq_thm
  end
in
  CHANGED_CONV (DEPTH_CONV (PMATCH_ROW_FORCE_SAME_VARS_CONV THENC row_conv)) t
end handle HOL_ERR _ => raise UNCHANGED
834
835
836(***********************************************)
837(* Cleaning up by removing rows that           *)
838(* don't match or are redundant                *)
839(* also remove the whole PMATCH, if first      *)
840(* row matches                                 *)
841(***********************************************)
842
843(*
844val t = ``
845PMATCH (NONE,x,l)
846     [PMATCH_ROW (\x. (NONE,x,[])) (\x. T) (\x. x);
847      PMATCH_ROW (\x. (NONE,x,[2])) (\x. F) (\x. x);
848      PMATCH_ROW (\x. (NONE,x,[2])) (\x. T) (\x. x);
849      PMATCH_ROW (\(x,y). (y,x,[2])) (\(x, y). T) (\(x, y). x);
850      PMATCH_ROW (\x. (SOME 3,x,[])) (\x. T) (\x. x)
851   ]``
852
853val t = ``PMATCH y [PMATCH_ROW (\_0_1. _0_1) (\_0_1. T) (\_0_1. F)]``
854
855val t = ``case (SUC x) of x => x + 3``
856
857val rc_arg = ([], NONE)
858
859val t' = rhs (concl (PMATCH_CLEANUP_CONV t))
860*)
861
(* Apply [f] to every element of [l] from left to right and keep the
   values of the SOME results, in order. *)
fun map_filter f l = List.mapPartial f l;
867
(* Clean up the PMATCH term [t]:
   - each row r is checked by deciding  r v = NONE  with rc_conv:
     T means the row certainly does not match, F that it certainly
     matches; checking stops after the first certainly matching row,
   - all rows after the first certainly matching row are dropped
     (PMATCH_ROWS_DROP_REDUNDANT_TRIVIAL_SOUNDNESS),
   - rows that certainly do not match are dropped (PMATCH_EXTEND_OLD),
   - if the first remaining row certainly matches, the PMATCH is
     evaluated (PMATCH_EVAL_MATCH); if no row remains, the first
     conjunct of PMATCH_def is used.
   Raises UNCHANGED if nothing useful could be decided. *)
fun PMATCH_CLEANUP_CONV_GENCALL rc_arg t = let
  val (v, rows) = dest_PMATCH t
  val _ = if (null rows) then raise UNCHANGED else ()

  (* SOME (true, thm): row does not match; SOME (false, thm): row
     matches; NONE: undecided *)
  fun check_row r = let
    val r_tm = mk_eq (mk_comb (r, v), optionSyntax.mk_none (type_of t))
    val r_thm = rc_conv rc_arg r_tm
    val res_tm = rhs (concl r_thm)
  in
    if (same_const res_tm T) then SOME (true, r_thm) else
    (if (same_const res_tm F) then SOME (false, r_thm) else NONE)
  end handle HOL_ERR _ => NONE

  (* check rows in order; rows after the first matching one are not
     checked any more and get NONE *)
  val (rows_checked_rev, _) = foldl (fn (r, (acc, abort)) =>
    if abort then ((r, NONE)::acc, true) else (
    let
      val res = check_row r
      val abort = (case res of
         (SOME (false, _)) => true
       | _ => false)
    in
      ((r, res)::acc, abort)
    end)) ([], false) rows
  val rows_checked = List.rev rows_checked_rev

  (* did we get any results? *)
  fun check_row_exists v rows =
     exists (fn x => case x of (_, SOME (v', _)) => v = v' | _ => false) rows

  (* continue only if some row does not match, some row other than the
     last one matches, or the first row matches *)
  val _ = if ((check_row_exists true rows_checked_rev) orelse (check_row_exists false (tl rows_checked_rev)) orelse (check_row_exists false [hd rows_checked])) then () else raise UNCHANGED

  val row_ty = type_of (hd rows)

  (* drop redundant rows: everything after the first matching row,
     which has 0-based index n *)
  val (thm0, rows_checked0) = let
    val n = index (fn x => case x of (_, SOME (false, _)) => true | _ => false) rows_checked
    val n_tm = numSyntax.term_of_int n

    val thma = ISPECL [v, listSyntax.mk_list (rows, row_ty), n_tm]
      (FRESH_TY_VARS_RULE PMATCH_ROWS_DROP_REDUNDANT_TRIVIAL_SOUNDNESS)

    val precond = fst (dest_imp (concl thma))
    (* el is 1-based, so this is the check result of row n *)
    val precond_thm = prove (precond,
      MP_TAC (snd(valOf (snd (el (n+1) rows_checked)))) THEN
      SIMP_TAC list_ss [quantHeuristicsTheory.IS_SOME_EQ_NOT_NONE])

    val thmb = MP thma precond_thm

    (* evaluate  FIRSTN (SUC n) rows *)
    val take_conv = RATOR_CONV (RAND_CONV reduceLib.SUC_CONV) THENC
                    listLib.FIRSTN_CONV
    val thmc = CONV_RULE (RHS_CONV (RAND_CONV take_conv)) thmb
  in
    (thmc, List.take (rows_checked, n+1))
  end handle HOL_ERR _ => (REFL t, rows_checked)

  (* drop false rows, i.e. rows that certainly do not match; the
     theorem is built from the last row towards the first *)
  val (thm1, rows_checked1) = let
     val _ = if (exists (fn x => case x of (_, (SOME (true, _))) => true | _ => false) rows_checked0) then () else failwith "nothing to do"

     fun process_row ((r, r_thm_opt), thm) = (case r_thm_opt of
         (SOME (true, r_thm)) => let
           val thmA = FRESH_TY_VARS_RULE PMATCH_EXTEND_OLD
           val thmB = HO_MATCH_MP thmA (EQT_ELIM r_thm)
           val thmC = HO_MATCH_MP thmB thm
        in
          thmC
        end
     | _ => let
           val thmA = PMATCH_EXTEND_BOTH_ID
           val thmB = HO_MATCH_MP thmA thm
        in
           ISPEC r thmB
        end)

    val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v] PMATCH_EXTEND_BASE)
    val thma = foldl process_row base_thm (List.rev rows_checked0)

    val rows_checked1 = filter (fn (_, res_opt) => case res_opt of
         SOME (true, thm) => false
     | _ => true) rows_checked0
  in
    (thma, rows_checked1)
  end handle HOL_ERR _ => (REFL (rhs (concl thm0)), rows_checked0)


  (* if first line matches, evaluate *)
  val thm2 = let
     val _ = if (not (List.null rows_checked1) andalso
                 (case hd rows_checked1 of (_, (SOME (false, _))) => true | _ => false)) then () else failwith "nothing to do"

     val thm1_tm = rhs (concl thm1)
     val thm2a = PART_MATCH (lhs o rand) PMATCH_EVAL_MATCH thm1_tm
     val pre_thm = EQF_ELIM (snd (valOf(snd (hd rows_checked1))))
     val thm2b = MP thm2a pre_thm

     val thm2c = CONV_RULE (RHS_CONV
        (RAND_CONV (rc_conv rc_arg) THENC
         pairLib.GEN_BETA_CONV)) thm2b handle HOL_ERR _ => thm2b
   in
     thm2c
   end handle HOL_ERR _ => let
     (* otherwise, if no rows are left, rewrite with the first conjunct
        of PMATCH_def (presumably the empty-row-list case) *)
     val _ = if (List.null rows_checked1) then () else failwith "nothing to do"
   in
     (REWR_CONV (CONJUNCT1 PMATCH_def)) (rhs (concl thm1))
   end handle HOL_ERR _ => REFL (rhs (concl thm1))
in
  TRANS (TRANS thm0 thm1) thm2
end handle HOL_ERR _ => raise UNCHANGED
977
978
(* Front ends for PMATCH_CLEANUP_CONV_GENCALL: a conversion parameterised by
   extra simpset fragments, a simpset fragment wrapping the same conversion,
   and the plain versions that use no extra fragments. *)
fun PMATCH_CLEANUP_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_CLEANUP_CONV_GENCALL rc_arg end

fun PMATCH_CLEANUP_GEN_ss ssl =
  make_gen_conv_ss PMATCH_CLEANUP_CONV_GENCALL "PMATCH_CLEANUP_REDUCER" ssl

val PMATCH_CLEANUP_ss = PMATCH_CLEANUP_GEN_ss []
val PMATCH_CLEANUP_CONV = PMATCH_CLEANUP_CONV_GEN [];

(* Register the cleanup conversion with computeLib for PMATCH applied to
   two arguments (input and row list). *)
val _ = let
  val cleanup_conv = QCHANGED_CONV PMATCH_CLEANUP_CONV
in
  computeLib.add_convs [(patternMatchesSyntax.PMATCH_tm, 2, cleanup_conv)]
end;
985
986
987(***********************************************)
988(* simplify a column                           *)
989(***********************************************)
990
991(* This can also be considered partial evaluation *)
992
(* Split off component [col] of the tuple [v].
   Returns the tuple built from the remaining components together with the
   removed component. Fails with "pair_get_col" if no component would
   remain. *)
fun pair_get_col col v = let
  val (remaining, column) =
    replace_element (pairSyntax.strip_pair v) col []
in
  if List.null remaining then failwith "pair_get_col"
  else (pairSyntax.list_mk_pair remaining, column)
end;
1002
1003(*----------------*)
1004(* drop a column  *)
1005(*----------------*)
1006
1007(*
1008val t = ``
1009PMATCH (NONE,x,l)
1010     [PMATCH_ROW (\x. (NONE,x,[])) (\x. T) (\x. x);
1011      PMATCH_ROW (\z. (NONE,z,[2])) (\z. F) (\z. z);
1012      PMATCH_ROW (\x. (NONE,x,[2])) (\x. T) (\x. x);
1013      PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
1014   ]``
1015
1016val t = ``
1017  PMATCH (x + y,ys)
1018    [PMATCH_ROW (\x. (x,[])) (\x. T) (\x. x);
1019     PMATCH_ROW (\ (x,y,ys). (x,y::ys)) (\ (x,y,ys). T)
1020       (\ (x,y,ys). my_d (x + y,ys))]``
1021
1022
1023val t = ``PMATCH (x,y)
1024    [PMATCH_ROW (\x. (x,x)) (\x. T) (\x. T);
1025     PMATCH_ROW (\ (z, y). (z, y)) (\ (z, y). T) (\ (z, y). F)]``
1026
1027
1028val rc_arg = []
1029val col = 0
1030*)
1031
(* Remove column [col] from the PMATCH term [t].

   Intended for columns where every row's pattern matches the input's
   component [c_v] by instantiating only pattern variables (see elim_col_ok
   in PMATCH_SIMP_COLS_CONV_GENCALL). Each row is rewritten separately using
   PMATCH_ROW_REMOVE_FUN_VAR, with side conditions discharged by
   [rc_tac rc_arg]. The row theorems are then combined, starting from the
   last row, with PMATCH_EXTEND_BASE / PMATCH_EXTEND_BOTH.

   Returns a theorem [t = t'] where [t'] matches on the input without
   column [col]. Any HOL_ERR is turned into UNCHANGED. *)
fun PMATCH_REMOVE_COL_AUX rc_arg col t = let
  val (v, rows) = dest_PMATCH t
  (* v' : the input without column col;  c_v : the removed component *)
  val (v', c_v) = pair_get_col col v
  val vs = free_vars c_v

  (* PMATCH_ROW_REMOVE_FUN_VAR instantiated with the old and new inputs *)
  val thm_row = let
    val thm = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_FUN_VAR
    val thm = ISPEC v thm
    val thm = ISPEC v' thm
  in thm end

  (* Prove [row v = row' v'] for a single row, where row' has column col
     removed from its pattern. *)
  fun PMATCH_ROW_REMOVE_FUN_VAR_COL_AUX row = let
     (* presumably renames the row's pattern variables apart from vs,
        the free variables of c_v -- TODO confirm *)
     val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS_VARIANT vs row
     val vars = pairSyntax.strip_pair vars_tm
     val avoid = free_varsl [pt, gt, rh]

     (* pv : this row's pattern in column col *)
     val (pt0', pv) = pair_get_col col pt
     (* NOTE(review): in the constant-column case below, pv is not a
        variable. Only whole occurrences of pv are replaced here, so
        variables of pv that also occur elsewhere in the pattern are not
        instantiated by the match substitution. Presumably such rows then
        fail to prove and the whole conversion raises UNCHANGED -- confirm. *)
     val pt' = subst [pv |-> c_v] pt0'

     val pv_i_opt = SOME (index (aconv pv) vars) handle HOL_ERR _ => NONE
     (* vars'_tm : remaining pattern variables;
        f : maps vars'_tm back to the original variable tuple *)
     val (vars'_tm, f) = case pv_i_opt of
         (SOME pv_i) => (let
           (* we eliminate a variable column *)
           val vars' = let
             val (vars', _) = replace_element vars pv_i []
           in
             if (List.null vars') then [variant avoid ``_uv:unit``] else vars'
           end

           val vars'_tm = pairSyntax.list_mk_pair vars'
           val g' = let
             val (vs, _) = replace_element vars pv_i [c_v]
             val vs_tm = pairSyntax.list_mk_pair vs
           in
             pairSyntax.mk_pabs (vars'_tm, vs_tm)
           end
         in
           (vars'_tm, g')
         end)
       | NONE => (let
           (* we eliminate a constant column: pv must match c_v by
              instantiating pattern variables only *)
           val (sub, _) = match_term pv c_v
           val _ = if List.all (fn x => List.exists (aconv (#redex x)) vars) sub then () else failwith "not a constant-col after all"

           val vars' = filter (fn v => not (List.exists (fn x => (aconv v (#redex x))) sub)) vars
           val vars' = if (List.null vars') then [variant avoid ``_uv:unit``] else vars'
           val vars'_tm = pairSyntax.list_mk_pair vars'

           val g' = pairSyntax.mk_pabs (vars'_tm, Term.subst sub vars_tm)
         in
           (vars'_tm, g')
         end)

(*   val f = pairSyntax.mk_pabs (vars_tm, pt)
     val f' = pairSyntax.mk_pabs (vars'_tm, pt')
     val g = pairSyntax.mk_pabs (vars_tm, rh)

*)
     (* paired abstractions over the original and the reduced variables *)
     val p = pairSyntax.mk_pabs (vars_tm, pt)
     val p' = pairSyntax.mk_pabs (vars'_tm, pt')
     val g = pairSyntax.mk_pabs (vars_tm, gt)
     val r = pairSyntax.mk_pabs (vars_tm, rh)

     val thm0 = let
       val thm = thm_row
       val thm = ISPEC f thm
       val thm = ISPEC p thm
       val thm = ISPEC g thm
       val thm = ISPEC r thm
       val thm = ISPEC p' thm

       fun elim_conv_aux vs = (
         (pairTools.PABS_INTRO_CONV vs) THENC
         (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))
       )

       (* tidy the new row's arguments as abstractions over vars'_tm *)
       fun elim_conv vs = PMATCH_ROW_ARGS_CONV (elim_conv_aux vs)
       val thm = CONV_RULE ((RAND_CONV o RHS_CONV) (elim_conv vars'_tm)) thm

       (* make the lhs syntactically the original row applied to v *)
       val tm_eq = mk_eq(lhs (rand (concl thm)), mk_comb (row, v))
       val eq_thm = prove (tm_eq, rc_tac rc_arg)

       val thm = CONV_RULE (RAND_CONV (LHS_CONV (K eq_thm))) thm
     in
       thm
     end

     (* discharge the precondition of the instantiated theorem *)
     val pre_tm = fst (dest_imp (concl thm0))
(* set_goal ([], pre_tm) *)
     val pre_thm = prove (pre_tm, rc_tac rc_arg)
     val thm1 = MP thm0 pre_thm
  in
     thm1
  end

  (* add one row theorem in front of the accumulated PMATCH theorem *)
  fun process_row (row, thm) = let
    val row_thm = PMATCH_ROW_REMOVE_FUN_VAR_COL_AUX row
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v'] PMATCH_EXTEND_BASE)
  (* rows are processed last-to-first so that each is prepended *)
  val thm0 = List.foldl process_row base_thm (List.rev rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED
1141
1142
1143(*------------------------------------*)
1144(* remove a constructor from a column *)
1145(*------------------------------------*)
1146
1147(*
1148val t = ``
1149PMATCH (SOME y,x,l)
1150     [PMATCH_ROW (\x. (SOME 0,x,[])) (\x. T) (\x. x);
1151      PMATCH_ROW (\z. (SOME 1,z,[2])) (\z. F) (\z. z);
1152      PMATCH_ROW (\x. (SOME 3,x,[2])) (\x. T) (\x. x);
1153      PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
1154   ]``
1155
1156val rc_arg = []
1157val col = 0
1158*)
1159
1160
(* Simplify column [col] of the PMATCH term [t] whose input component is an
   application [c a1 ... an] of some head term [c]. In the intended use this
   is a constructor; see simp_col_ok in PMATCH_SIMP_COLS_CONV_GENCALL.

   The column is replaced by the n columns a1, ..., an:
   - [ff_tm] maps a tuple with the expanded columns back to the original
     tuple shape by re-applying [c]. Its injectivity [ff_thm] is proved
     with [rc_tac rc_arg].
   - For each row, process_row first tries PMATCH_ROW_REMOVE_FUN_COL_AUX,
     which requires the row's pattern in that column to be headed by [c].
   - On HOL_ERR it falls back to PMATCH_ROW_REMOVE_VAR_COL_AUX, which
     requires a variable there and replaces it by fresh variables for the
     arguments.
   Row theorems are combined from the last row with PMATCH_EXTEND_BASE /
   PMATCH_EXTEND_BOTH. Returns [t = t']; any HOL_ERR becomes UNCHANGED.

   NOTE(review): List.nth raises Subscript, not HOL_ERR, if a tuple has
   fewer than col+1 components. That exception is not turned into
   UNCHANGED; presumably callers guarantee enough columns -- confirm. *)
fun PMATCH_REMOVE_FUN_AUX rc_arg col t = let
  val (v, rows) = dest_PMATCH t

  (* ff_tm      : expanded tuple -> original tuple
     ff_inv     : expand a tuple whose column col is headed by c
     ff_inv_var : expand a tuple whose column col is a variable, using
                  fresh variables
     c          : head of the input's column col *)
  val (ff_tm, ff_inv, ff_inv_var, c) = let
    val vs = pairSyntax.strip_pair v
    val c_args = List.nth(vs, col)
    val (c, args) = strip_comb c_args

    val vs_vars = List.map (fn t => genvar (type_of t)) vs
    val args_vars = List.map (fn t => genvar (type_of t)) args

    val (vars, _) = replace_element vs_vars col args_vars
    val (ff_res, _) = replace_element vs_vars col [list_mk_comb (c, args_vars)]
    val ff_tm = pairSyntax.mk_pabs (pairSyntax.list_mk_pair vars,
       pairSyntax.list_mk_pair ff_res)

    fun ff_inv tt = let
      val tts = pairSyntax.strip_pair tt
      val tt_args = List.nth(tts, col)

      val (c', args') = strip_comb tt_args
      val _ = if (aconv c c') then () else failwith "different constr"

      val (vars,_) = replace_element tts col args'
    in
      pairSyntax.list_mk_pair vars
    end

    fun ff_inv_var avoid tt = let
      val tts = pairSyntax.strip_pair tt
      val tt_col = List.nth(tts, col)

      val _ = if (is_var tt_col) then () else failwith "no var"

      (* fresh argument variables are named after the replaced variable *)
      val (var_basename, _) = dest_var (tt_col)
      val gen_fun = mk_var_gen (var_basename ^ "_") avoid;
      val args =  map (fn t => gen_fun (type_of t)) args_vars

      val (vars, _) = replace_element tts col args
    in
      (pairSyntax.list_mk_pair vars, tt_col, args)
    end

  in
    (ff_tm, ff_inv, ff_inv_var, c)
  end

  (* injectivity of ff_tm, needed by PMATCH_ROW_REMOVE_FUN *)
  val ff_thm_tm = ``!x y. (^ff_tm x = ^ff_tm y) ==> (x = y)``
  val ff_thm = prove (ff_thm_tm, rc_tac rc_arg)

  (* the new, expanded input *)
  val v' = ff_inv v

  (* PMATCH_ROW_REMOVE_FUN specialised to ff_tm and v', with ff_tm v'
     rewritten back to v *)
  val PMATCH_ROW_REMOVE_FUN' = let
    val thm0 =  FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_FUN
    val thm1 = ISPEC ff_tm  thm0
    val thm2 = ISPEC v'  thm1
    val thm3 = MATCH_MP thm2 ff_thm

    val thm_v' = prove (``^ff_tm ^v' = ^v``, rc_tac rc_arg)
    val thm4 = CONV_RULE (STRIP_QUANT_CONV (LHS_CONV (RAND_CONV (K thm_v')))) thm3
  in
    thm4
  end

  (* row whose pattern in column col is headed by c *)
  fun PMATCH_ROW_REMOVE_FUN_COL_AUX row = let
     val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row

     val pt' = ff_inv pt
     val vpt' = pairSyntax.mk_pabs (vars_tm, pt')
     val vgt = pairSyntax.mk_pabs (vars_tm, gt)
     val vrh = pairSyntax.mk_pabs (vars_tm, rh)

     val thm0 = ISPECL [vpt', vgt, vrh] PMATCH_ROW_REMOVE_FUN'
     val eq_thm_tm = mk_eq (lhs (concl thm0), mk_comb (row, v))
     val eq_thm = prove (eq_thm_tm, rc_tac rc_arg)

     val thm1 = CONV_RULE (LHS_CONV (K eq_thm)) thm0

     val vi_conv = (pairTools.PABS_INTRO_CONV vars_tm) THENC
         (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))

     val thm2 = CONV_RULE (RHS_CONV (PMATCH_ROW_ARGS_CONV vi_conv)) thm1
  in
     thm2
  end

  (* PMATCH_ROW_REMOVE_FUN_VAR instantiated with the old and new inputs *)
  val thm_row = let
     val thm = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_FUN_VAR
     val thm = ISPEC v thm
     val thm = ISPEC v' thm
  in thm end

  (* row whose pattern in column col is a pattern variable *)
  fun PMATCH_ROW_REMOVE_VAR_COL_AUX row = let
     val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row
     val vars = pairSyntax.strip_pair vars_tm

     val avoid = vars @ free_vars pt @ free_vars rh @ free_vars gt
     val (pt', pv, new_vars) = ff_inv_var avoid pt

     val pv_i = index (aconv pv) vars

     (* replace pv by the fresh argument variables *)
     val vars' = let
       val (vars', _) = replace_element vars pv_i new_vars
     in
       if (List.null vars') then [variant avoid ``_uv:unit``] else vars'
     end

     (* f_tm maps the new variables back to the old ones, pv := c new_vars *)
     val vars'_tm = pairSyntax.list_mk_pair vars'
     val f_tm = let
        val c_v = list_mk_comb (c, new_vars)
        val (vs, _) = replace_element vars pv_i [c_v]
        val vs_tm = pairSyntax.list_mk_pair vs
     in
        pairSyntax.mk_pabs (vars'_tm, vs_tm)
     end

     val vpt = pairSyntax.mk_pabs (vars_tm, pt)
     val vpt' = pairSyntax.mk_pabs (vars'_tm, pt')
     val vrh = pairSyntax.mk_pabs (vars_tm, rh)
     val vgt = pairSyntax.mk_pabs (vars_tm, gt)

     val thm0 = let
       val thm = ISPEC f_tm thm_row
       val thm = ISPEC vpt thm
       val thm = ISPEC vgt thm
       val thm = ISPEC vrh thm
       val thm = ISPEC vpt' thm

       fun elim_conv vs = PMATCH_ROW_ARGS_CONV (
         (pairTools.PABS_INTRO_CONV vs) THENC
         (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))
       )

       val thm = CONV_RULE ((RAND_CONV o RHS_CONV) (elim_conv vars'_tm)) thm

       val tm_eq = mk_eq(lhs (rand (concl thm)), mk_comb (row, v))
       val eq_thm = prove (tm_eq, rc_tac rc_arg)

       val thm = CONV_RULE (RAND_CONV (LHS_CONV (K eq_thm))) thm
     in
       thm
     end

     val pre_tm = fst (dest_imp (concl thm0))
     val pre_thm = prove (pre_tm, rc_tac rc_arg)

     val thm1 = MP thm0 pre_thm
  in
     thm1
  end


  (* try the constructor case first, the variable case on HOL_ERR *)
  fun process_row (row, thm) = let
    val row_thm = PMATCH_ROW_REMOVE_FUN_COL_AUX row handle HOL_ERR _ =>
                  PMATCH_ROW_REMOVE_VAR_COL_AUX row
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

(*
  val row = el 1 (List.rev rows)
  val thm = base_thm
  val thm = thm0
*)

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v'] PMATCH_EXTEND_BASE)
  val thm0 = foldl process_row base_thm (List.rev rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED
1334
1335
1336(*------------------------*)
1337(* Combine auxiliary funs *)
1338(*------------------------*)
1339
1340(*
1341val t = ``
1342PMATCH (SOME y,x,l)
1343     [PMATCH_ROW (\x. (SOME 0,x,[])) (\x. T) (\x. x);
1344      PMATCH_ROW (\z. z) (\z. F) (\z. (FST (SND z)));
1345      PMATCH_ROW (\x. (SOME 3,x)) (\x. T) (\x. (FST x));
1346      PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
1347   ]``
1348val rc_arg = []
1349*)
1350
(* Simplify one column of the PMATCH term [t].

   [dest_PMATCH_COLS t] yields one entry per column: the input value of
   that column paired with the (pattern variables, pattern) of every row.
   Columns are scanned left to right. For each column:
   - if every row's pattern matches the input value by instantiating only
     pattern variables, the column is removed (PMATCH_REMOVE_COL_AUX);
   - otherwise, if the input value is an application with arguments, and
     every row's pattern is either a pattern variable or has the same head,
     the column is replaced by the arguments (PMATCH_REMOVE_FUN_AUX).
   The first column for which this succeeds determines the result. A column
   whose auxiliary conversion raises UNCHANGED is skipped, so later columns
   are still tried instead of aborting the whole search.

   Raises UNCHANGED if no column can be simplified. *)
fun PMATCH_SIMP_COLS_CONV_GENCALL rc_arg t = let
  val cols = dest_PMATCH_COLS t

  (* does [col_pat] match [col_v], instantiating only its pattern vars? *)
  fun do_match col_v (vars, col_pat) = let
    val (sub, _) = match_term col_pat col_v
    val vars_ok = List.all (fn x => (List.exists (aconv (#redex x)) vars)) sub
  in
    vars_ok
  end handle HOL_ERR _ => false

  fun elim_col_ok (col_v, col) =
    List.all (do_match col_v) col

  fun simp_col_ok (col_v, col) = let
    val (c, args) = strip_comb col_v
    val _ = if (List.null args) then failwith "elim_col instead" else ()

    fun check_line (vars, pt) =
      (List.exists (aconv pt) vars) orelse
      (aconv (fst (strip_comb pt)) c)
  in
    List.all check_line col
  end handle HOL_ERR _ => false

  (* Run an auxiliary column conversion; UNCHANGED means "try the next
     column" rather than escaping from first_opt. *)
  fun try_col aux i = SOME (aux rc_arg i t) handle UNCHANGED => NONE

  fun process_col i col = if (elim_col_ok col) then
    try_col PMATCH_REMOVE_COL_AUX i
  else if (simp_col_ok col) then
    try_col PMATCH_REMOVE_FUN_AUX i
  else NONE

  val thm_opt = first_opt process_col cols
in
  case thm_opt of NONE => raise UNCHANGED
                | SOME thm => thm
end
1389
(* Column simplification given only extra simpset fragments; the plain
   PMATCH_SIMP_COLS_CONV uses none. *)
fun PMATCH_SIMP_COLS_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_SIMP_COLS_CONV_GENCALL rc_arg end
val PMATCH_SIMP_COLS_CONV = PMATCH_SIMP_COLS_CONV_GEN [];
1392
1393
1394(***********************************************)
1395(* Resort and add dummy columns                *)
1396(***********************************************)
1397
1398(*
1399val t = ``PMATCH (s:'a option,x : 'a option, l:num list)
1400     [PMATCH_ROW (\_uv:unit. (NONE,NONE,[])) (\_uv. T) (\_uv. NONE);
1401      PMATCH_ROW (\z. (NONE,z,[2])) (\z. F) (\z. z);
1402      PMATCH_ROW (\(x, b). (SOME b,x,[2])) (\(x, b). T) (\(x, b). x);
1403      PMATCH_ROW (\(_0,y). (y,_0,[2])) (\(_0, y). IS_SOME y) (\(_0, y). y)
1404   ]``
1405
1406val nv = ``((l:num list), x : 'a option, xx:'a, s:'a option, z:'b)``
1407
1408val t = ``case (xs : num list) of [] => x | _ => HD xs``
1409val t = ``case (xs : num list) of [] => x | _::_ => HD xs``
1410val nv = ``(xs: num list, x:num)``
1411val rc_arg = ([], NONE)
1412
1413*)
(* Rewrite [PMATCH v rows] into an equivalent PMATCH on the new input [nv].

   Every component of the tuple [v] must occur, up to aconv, as a component
   of the tuple [nv]. Otherwise the conversion fails with
   "PMATCH_EXTEND_INPUT_CONV_AUX: part missing", which becomes UNCHANGED.
   Components of [nv] that are not in [v] become additional columns. Every
   row matches them with fresh variables, named from "_" and avoiding all
   variables of [t] and [nv]. PMATCH_ROW_INTRO_WILDCARDS_CONV is then
   applied to each new row.

   Raises UNCHANGED if [v] and [nv] are alpha-equivalent, or on any
   HOL_ERR. *)
fun PMATCH_EXTEND_INPUT_CONV_GENCALL rc_arg nv t = let
  val (v, rows) = dest_PMATCH t
  val _ = if aconv v nv then raise UNCHANGED else ()

  (* new_pat_tm     : nv with each part replaced by a variable (genvar for
                      old columns, fresh "_" variable for new columns)
     new_col_vars   : the fresh variables of the new columns
     new_col_subst  : column variable |-> nv part, for nv parts that are
                      variables
     old_col_vars   : genvars standing for the parts of v, in order
     f'_tm          : new_pat_tm -> tuple of column variables whose nv part
                      is a variable
     nv_vars_pair   : tuple of the nv parts that are variables *)
  val (new_pat_tm, new_col_vars, new_col_subst, old_col_vars, f'_tm, nv_vars_pair) = let
    val nv_parts = pairSyntax.strip_pair nv
    val v_parts = pairSyntax.strip_pair v
    val v_l = map (fn t => (t, genvar (type_of t))) v_parts

    val avoid = all_varsl [t, nv]
    val gen_fun = mk_var_gen "_" avoid

    (* for each nv part: reuse the genvar of the matching v part (consuming
       it), or create a fresh variable for a new column *)
    fun compute_nv_l res_l nv_vars v_l [] =
      if (null v_l) then (res_l, nv_vars) else failwith "PMATCH_EXTEND_INPUT_CONV_AUX: part missing"

    | compute_nv_l res nv_vars v_l (p::nv_parts) = let
        val ((_, v_v), v_l') = pluck (fn (t, _) => aconv p t) v_l
      in
        compute_nv_l (v_v::res) nv_vars v_l' nv_parts
      end handle HOL_ERR _ => let
        val vg = gen_fun (type_of p)
      in
        compute_nv_l (vg::res) (vg::nv_vars) v_l nv_parts
      end

    val (res_l, nv_vars) = compute_nv_l [] [] v_l nv_parts

    val t0 = pairSyntax.list_mk_pair (List.rev res_l)

    val nv_parts_vars = filter (fn (v, n) => is_var n) (zip (List.rev res_l) nv_parts)
    val t1 = pairSyntax.list_mk_pair (map fst nv_parts_vars)
    val f'_tm = pairSyntax.mk_pabs(t0, t1)
    val nv_vars_pair = pairSyntax.list_mk_pair (map snd nv_parts_vars)
    val nv_parts_subst = map (fn (v, n) => v |-> n) nv_parts_vars
  in
    (t0, List.rev nv_vars, nv_parts_subst, List.map snd v_l, f'_tm, nv_vars_pair)
  end

  (* PMATCH_ROW_EXTEND_INPUT instantiated with v, nv and f'_tm *)
  val thm_row = let
    val thm = FRESH_TY_VARS_RULE PMATCH_ROW_EXTEND_INPUT
    val thm = ISPEC v thm
    val thm = ISPEC nv thm
    val thm = ISPEC f'_tm thm
  in thm end

  (* Prove [row v = row' nv] for a single row. *)
  fun process_row_aux row = let
     val (pt, gt, rh) = dest_PMATCH_ROW row
     val (vars_tm, pt_b) = pairSyntax.dest_pabs pt

     (* A row binding only a unit variable contributes no old variables. *)
     val (old_vars_tm, new_vars) =
        if ((type_of vars_tm) = one_ty) then (
          if (List.null new_col_vars) then
            (one_tm, [vars_tm])
          else (one_tm, new_col_vars)
        ) else (
          let val ov = pairSyntax.strip_pair vars_tm
          in (vars_tm, ov @ new_col_vars) end
        )

     (* old column genvar |-> this row's pattern component *)
     val pat_vars_s0 = let
         val pt_ps = pairSyntax.strip_pair pt_b
       in map2 (fn v => fn t => (v |-> t)) old_col_vars pt_ps end

     (* For nv parts that are variables occurring free in the row, rename
        the corresponding pattern/new-column variable to that nv variable.
        Presumably this keeps references to input variables in the guard
        and rhs meaningful -- TODO confirm. *)
     val new_vars_s0 = filter (fn vp => free_in (#residue vp) row
       andalso is_var (subst pat_vars_s0 (#redex vp))) new_col_subst

     val new_vars_s = map (fn vp => subst pat_vars_s0 (#redex vp) |-> #residue vp) new_vars_s0
     val pat_vars_s = map (fn vp => #redex vp |-> subst new_vars_s (#residue vp)) pat_vars_s0

     val new_vars_tm = subst new_vars_s (pairSyntax.list_mk_pair new_vars)
     val pt_b' = subst pat_vars_s (subst new_vars_s new_pat_tm)

     val f_tm = pairSyntax.mk_pabs (new_vars_tm, subst new_vars_s old_vars_tm)
     val pt' = pairSyntax.mk_pabs (new_vars_tm, pt_b')

     val thm0 = let
       val thm = ISPEC f_tm thm_row
       val thm = ISPEC pt thm
       (* guard and rhs are abstracted over the variable parts of nv *)
       val thm = ISPEC (pairSyntax.mk_pabs (nv_vars_pair, gt)) thm
       val thm = ISPEC (pairSyntax.mk_pabs (nv_vars_pair, rh)) thm
       val thm = ISPEC pt' thm

       fun elim_conv_aux vs = (
         (pairTools.PABS_INTRO_CONV vs) THENC
         (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))
       )

       fun elim_conv vs = PMATCH_ROW_ARGS_CONV (elim_conv_aux vs)
       val thm = CONV_RULE ((RAND_CONV o RHS_CONV) (elim_conv new_vars_tm)) thm
     in
       thm
     end

     val pre_tm = fst (dest_imp (concl thm0))
(* set_goal ([], pre_tm) *)
     val pre_thm = prove (pre_tm, rc_tac rc_arg)
     val thm1 = MP thm0 pre_thm

     (* make the lhs syntactically the original row applied to v *)
     val eq_tm = mk_eq (mk_comb (row, v), lhs (concl thm1))
     val eq_thm = prove (eq_tm, SIMP_TAC std_ss [])
     val thm2 = TRANS eq_thm thm1

     (* fix wildcards *)
     val thm3 = CONV_RULE (RHS_CONV (RATOR_CONV (PMATCH_ROW_INTRO_WILDCARDS_CONV))) thm2
  in
     thm3
  end

  fun process_row (row, thm) = let
    val row_thm = process_row_aux row
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, nv] PMATCH_EXTEND_BASE)
  (* rows are processed last-to-first so that each is prepended *)
  val thm0 = List.foldl process_row base_thm (List.rev rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED
1536
1537
(* Input extension given only extra simpset fragments; the plain
   PMATCH_EXTEND_INPUT_CONV uses none. *)
fun PMATCH_EXTEND_INPUT_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_EXTEND_INPUT_CONV_GENCALL rc_arg end
val PMATCH_EXTEND_INPUT_CONV = PMATCH_EXTEND_INPUT_CONV_GEN [];
1540
1541
1542(***********************************************)
1543(* Expand columns                              *)
1544(***********************************************)
1545
(* Sometimes not all rows of a PMATCH have the same number of
   explicit columns. This can happen if some patterns are
   explicit pairs while others are not. The following tries
   to expand such columns into explicit ones. *)
1550
1551(*
1552val t = ``
1553PMATCH (SOME y,x,l)
1554     [PMATCH_ROW (\x. (SOME 0,x,[])) (\x. T) (\x. x);
1555      PMATCH_ROW (\z. z) (\z. F) (\z. (FST (SND z)));
1556      PMATCH_ROW (\x. (SOME 3,x)) (\x. T) (\x. (FST x));
1557      PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
1558   ]``
1559*)
1560
(* Expand columns of the PMATCH term [t].

   [col_no] is the maximum number of explicit columns among the input and
   all row patterns. Consider a row that has fewer columns and whose last
   pattern component is one of its pattern variables. That variable is
   split into fresh variables following the product type, named after it
   with an "_" suffix. The row is re-proved equal with [rc_tac ([], NONE)].

   Raises UNCHANGED if no row can be expanded, or on any HOL_ERR. *)
fun PMATCH_EXPAND_COLS_CONV t = let
  val (v, rows) = dest_PMATCH t

  val col_no_v = length (pairSyntax.strip_pair v)
  val col_no = foldl (fn (r, m) => let
    val (pt', _, _) = dest_PMATCH_ROW r
    val (_, pt) = pairSyntax.dest_pabs pt'
    val m' = length (pairSyntax.strip_pair pt)
    val m'' = if m' > m then m' else m
  in m'' end) col_no_v rows

  (* fresh variables for the last (col_no - cols) + 1 components of the
     product type of l *)
  fun split_var avoid cols l = let
    fun splits acc no ty = if (no = 0) then List.rev (ty::acc) else
    let
      val (ty_s, ty') = pairSyntax.dest_prod ty
    in
      splits (ty_s::acc) (no - 1) ty'
    end

    val types = splits [] (col_no - cols) (type_of l)

    val var_basename = fst (dest_var l) handle HOL_ERR _ => "v"
    val gen_fun = mk_var_gen (var_basename ^ "_") avoid;
    val new_vars =  map gen_fun types
  in
    new_vars
  end

  (* SOME [row v = row' v] if the row can be expanded, NONE otherwise *)
  fun PMATCH_ROW_EXPAND_COLS row = let
     val (vars_tm, pt, gt, rh) = dest_PMATCH_ROW_ABS row

     val vars = pairSyntax.strip_pair vars_tm
     val pts = pairSyntax.strip_pair pt
     val cols = length pts

     val _ = if (cols < col_no) then () else failwith "nothing to do"
     val l = last pts

     val _ = if (List.exists (aconv l) vars) then () else failwith "nothing to do"

     val avoids = vars @ free_vars pt @ free_vars gt @ free_vars rh
     val new_vars = split_var avoids cols l

     (* replace l by a tuple of the new variables everywhere in the row *)
     val sub = [l |-> pairSyntax.list_mk_pair new_vars]
     val pt' = Term.subst sub pt
     val gt' = Term.subst sub gt
     val rh' = Term.subst sub rh
     val vars' = pairSyntax.strip_pair (Term.subst sub vars_tm)

     val row' = mk_PMATCH_ROW_PABS vars' (pt', gt', rh')

     val eq_tm = mk_eq(row, row')
     val eq_thm = prove (eq_tm, rc_tac ([], NONE))
     val thm = AP_THM eq_thm v
  in
     SOME thm
  end handle HOL_ERR _ => NONE

  (* rows are reversed so that foldl prepends them in the original order *)
  val rows = List.rev rows
  val row_thms = map PMATCH_ROW_EXPAND_COLS rows
  val _ = if (exists isSome row_thms) then () else raise UNCHANGED

  fun process_row ((row_thm_opt, row), thm) = let
    val row_thm = case row_thm_opt of
        NONE => REFL (mk_comb (row, v))
      | SOME thm => thm
    val thmA = PMATCH_EXTEND_BOTH
    val thmB = HO_MATCH_MP thmA row_thm
    val thmC = HO_MATCH_MP thmB thm
  in
    thmC
  end

  val base_thm = INST_TYPE [gamma |-> type_of t] (ISPECL [v, v] PMATCH_EXTEND_BASE)
  val thm0 = foldl process_row base_thm (zip row_thms rows)
in
  thm0
end handle HOL_ERR _ => raise UNCHANGED;
1639
1640
1641(***********************************************)
1642(* PMATCH_SIMP_CONV                            *)
1643(***********************************************)
1644
1645(*
1646val t = ``
1647PMATCH (SOME y,x,l)
1648     [PMATCH_ROW (\x. (SOME 0,x,[])) (\y. T) (\x. x);
1649      PMATCH_ROW (\z. z) (\z. F) (\z. (FST (SND z)));
1650      PMATCH_ROW (\x. (SOME 3,x)) (\x. T) (\x. (FST x));
1651      PMATCH_ROW (\(z,y). (y,z,[2])) (\(z, y). IS_SOME y) (\(z, y). y)
1652   ]``
1653*)
1654
(* Normalise a PMATCH term by running each of the following conversions
   once, in order. A step that fails or leaves the term unchanged is
   skipped. *)
local
  val normalisation_steps = [
    PMATCH_CLEANUP_PVARS_CONV,
    PMATCH_FORCE_SAME_VARS_CONV,
    PMATCH_EXPAND_COLS_CONV,
    PMATCH_INTRO_WILDCARDS_CONV
  ]
in
  val PMATCH_NORMALISE_CONV_AUX =
    EVERY_CONV (List.map (TRY_CONV o QCHANGED_CONV) normalisation_steps)
end;

(* Normalisation restricted to PMATCH terms; anything else is UNCHANGED. *)
fun PMATCH_NORMALISE_CONV t =
  if not (is_PMATCH t) then raise UNCHANGED
  else PMATCH_NORMALISE_CONV_AUX t;

val PMATCH_NORMALISE_ss =
    simpLib.conv_ss
      {name  = "PMATCH_NORMALISE_CONV",
       trace = 2,
       key   = SOME ([],``PMATCH (p:'a) (rows : ('a -> 'b option) list)``),
       conv  = K (K PMATCH_NORMALISE_CONV)}
1672
1673
(* Full PMATCH simplification:
   - first normalise once, if possible;
   - then repeatedly apply the first of PMATCH_CLEANUP, PMATCH_SIMP_COLS,
     PMATCH_REMOVE_FAST_REDUNDANT and PMATCH_REMOVE_FAST_SUBSUMED (with
     [false]) that makes progress. *)
fun PMATCH_SIMP_CONV_GENCALL_AUX rc_arg = let
  val simp_step = FIRST_CONV [
    QCHANGED_CONV (PMATCH_CLEANUP_CONV_GENCALL rc_arg),
    QCHANGED_CONV (PMATCH_SIMP_COLS_CONV_GENCALL rc_arg),
    QCHANGED_CONV (PMATCH_REMOVE_FAST_REDUNDANT_CONV_GENCALL rc_arg),
    QCHANGED_CONV (PMATCH_REMOVE_FAST_SUBSUMED_CONV_GENCALL false rc_arg)
  ]
in
  TRY_CONV PMATCH_NORMALISE_CONV_AUX THENC REPEATC simp_step
end;

(* Only PMATCH terms are simplified; anything else is UNCHANGED. *)
fun PMATCH_SIMP_CONV_GENCALL rc_arg t =
  if not (is_PMATCH t) then raise UNCHANGED
  else PMATCH_SIMP_CONV_GENCALL_AUX rc_arg t

fun PMATCH_SIMP_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_SIMP_CONV_GENCALL rc_arg end

val PMATCH_SIMP_CONV = PMATCH_SIMP_CONV_GEN [];

fun PMATCH_SIMP_GEN_ss ssl =
  make_gen_conv_ss PMATCH_SIMP_CONV_GENCALL "PMATCH_SIMP_REDUCER" ssl

(* The default fragment is added to the stateful simpset (srw_ss). *)
val PMATCH_SIMP_ss = name_ss "patternMatchesSimp" (PMATCH_SIMP_GEN_ss [])
val _ = BasicProvers.augment_srw_ss [PMATCH_SIMP_ss];
1696
1697
(* Variant of the PMATCH simplifier that uses only cleanup and column
   simplification. It repeats while either makes progress and has no
   normalisation pass. *)
fun PMATCH_FAST_SIMP_CONV_GENCALL_AUX rc_arg = let
  val fast_step = FIRST_CONV [
    QCHANGED_CONV (PMATCH_CLEANUP_CONV_GENCALL rc_arg),
    QCHANGED_CONV (PMATCH_SIMP_COLS_CONV_GENCALL rc_arg)
  ]
in
  REPEATC fast_step
end;

(* Only PMATCH terms are simplified; anything else is UNCHANGED. *)
fun PMATCH_FAST_SIMP_CONV_GENCALL rc_arg t =
  if not (is_PMATCH t) then raise UNCHANGED
  else PMATCH_FAST_SIMP_CONV_GENCALL_AUX rc_arg t

fun PMATCH_FAST_SIMP_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_FAST_SIMP_CONV_GENCALL rc_arg end

val PMATCH_FAST_SIMP_CONV = PMATCH_FAST_SIMP_CONV_GEN [];

fun PMATCH_FAST_SIMP_GEN_ss ssl =
  make_gen_conv_ss PMATCH_FAST_SIMP_CONV_GENCALL "PMATCH_FAST_SIMP_REDUCER" ssl

val PMATCH_FAST_SIMP_ss = name_ss "patternMatchesFastSimp" (PMATCH_FAST_SIMP_GEN_ss [])
1716
1717
1718(***********************************************)
1719(* Remove double var bindings                  *)
1720(***********************************************)
1721
(* Rename repeated variable occurrences in [t].

   Walks [t] left to right, function part before argument part. Each
   variable occurrence that is not in [no_change] is replaced by
   [variant avoid x]. If that differs from the original variable [x], the
   pair (fresh, x) is added to [s]. The resulting variable is added to
   [avoid] in either case. So the first occurrence of a variable not
   already in [avoid] is kept, and later occurrences are renamed.
   Variables bound by a lambda are added to [no_change] (and [avoid]) for
   the body, so they are never renamed.

   Returns (s', avoid', t'), where [t'] is the renamed term.

   NOTE(review): fresh names only avoid the names already in [avoid] at
   that point, not variables of [t] visited later. For example, a fresh x'
   for a repeated x could coincide with a variable x' occurring further
   right. Confirm whether callers rule this out. *)
fun force_unique_vars s no_change avoid t =
  case Psyntax.dest_term t of
      Psyntax.VAR (_, _) =>
      if (mem t no_change) then (s, avoid, t) else
      let
         val v' = variant avoid t
         val avoid' = v'::avoid
         val s' = if (v' = t) then s else ((v', t)::s)
      in (s', avoid', v') end
    | Psyntax.CONST _ => (s, avoid, t)
    | Psyntax.LAMB (v, t') => let
         (* the bound variable itself must not be renamed in the body *)
         val (s', avoid', t'') = force_unique_vars s (v::no_change)
           (v::avoid) t'
      in
         (s', avoid', mk_abs (v, t''))
      end
    | Psyntax.COMB (t1, t2) => let
         (* thread the substitution and avoid list through both parts *)
         val (s', avoid', t1') = force_unique_vars s no_change avoid t1
         val (s'', avoid'', t2') = force_unique_vars s' no_change avoid' t2
      in
         (s'', avoid'', mk_comb (t1', t2'))
      end;
1744
1745(*
1746val row = ``PMATCH_ROW (\ (x,y). (x, SOME y, SOME x, SOME z, (x+z)))
1747              (\ (x, y). P x y) (\ (x, y). f x y)``
1748
1749val row = ``PMATCH_ROW (\ (x,y). (x, SOME y, SOME z, SOME z, (z+z)))
1750              (\ (x, y). P x y) (\ (x, y). f x y)``
1751*)
1752
(* Remove repeated variable bindings from a single PMATCH_ROW [row].

   The pattern body is processed by force_unique_vars, with the free
   variables of the pattern abstraction as its initial avoid list. Repeated
   occurrences of a pattern variable, and occurrences of free variables of
   the pattern, are replaced by fresh pattern variables. For each such
   replacement, an equation [fresh = original] is conjoined in front of the
   guard.

   The equality is proved with PMATCH_ROW_REMOVE_DOUBLE_BINDS_THM using
   generic guard/rhs variables [g_v]/[r_v], which are then instantiated
   back. The result is beta-reduced and the new guard is rewritten with
   REWRITE_CONV [].

   Returns [row = row']. Raises UNCHANGED if [row] is not a PMATCH_ROW, if
   nothing needs renaming, or on any HOL_ERR. *)
fun PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg row = let
  val _ = if not (is_PMATCH_ROW row) then raise UNCHANGED else ()
  val (p_t, g_t, r_t) = dest_PMATCH_ROW row
  val (vars_tm, p_tb) = pairSyntax.dest_pabs p_t
  val vars = pairSyntax.strip_pair vars_tm

  (* new_binds : (fresh variable, variable it replaces) pairs *)
  val (new_binds, _, p_tb') = force_unique_vars [] [] (free_vars p_t) p_tb
  val _ = if List.null new_binds then raise UNCHANGED else ()

  val vars' = vars @ (List.map fst new_binds)
  (* generic guard and rhs, instantiated back to g_t / r_t below *)
  val g_v = genvar (type_of g_t)
  val r_v = genvar (type_of r_t)


  (* new guard: fresh = original equations, then the old guard *)
  val g_t' = list_mk_conj ((List.map mk_eq new_binds)@[mk_comb (g_v, vars_tm)])
  val r_t' = mk_comb (r_v, vars_tm)

  val row' = mk_PMATCH_ROW_PABS vars' (p_tb', g_t', r_t')
  val row0 = mk_PMATCH_ROW (p_t, g_v, r_v)

  val thm0_tm = mk_eq (row0, row')
  val thm0 = let
    val thm0 = FRESH_TY_VARS_RULE PMATCH_ROW_REMOVE_DOUBLE_BINDS_THM
    (* maps the original variables to the extended tuple vars', with each
       fresh variable replaced by the variable it stands for *)
    val g_tm = pairSyntax.mk_pabs (vars_tm,
      subst (List.map (fn (v, v') => (v |-> v')) new_binds)
        (pairSyntax.list_mk_pair vars'))
    val thm1 = ISPEC g_tm thm0
    val thm2 = PART_MATCH rand thm1 thm0_tm
    val thm3 = rc_elim_precond rc_arg thm2
  in
    thm3
  end

  val thm1 = INST [(g_v |-> g_t), (r_v |-> r_t)] thm0

  (* relate the original row to the instantiated lhs *)
  val thm1a_tm = mk_eq (row, lhs (concl thm1))
  val thm1a = prove (thm1a_tm, rc_tac rc_arg)

  val thm2 = TRANS thm1a thm1

  val thm3 = CONV_RULE (RHS_CONV (DEPTH_CONV pairLib.GEN_BETA_CONV)) thm2

   (* tidy the new guard *)
   val thm4 = CONV_RULE (RHS_CONV (RATOR_CONV (RAND_CONV (REWRITE_CONV [])))) thm3
in
  thm4
end handle HOL_ERR _ => raise UNCHANGED
1799
(* Apply the single-row double-binding removal to every row of a PMATCH. *)
fun PMATCH_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg t =
  let val row_conv = PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg
  in PMATCH_ROWS_CONV row_conv t end

fun PMATCH_REMOVE_DOUBLE_BIND_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_REMOVE_DOUBLE_BIND_CONV_GENCALL rc_arg end

val PMATCH_REMOVE_DOUBLE_BIND_CONV = PMATCH_REMOVE_DOUBLE_BIND_CONV_GEN [];

(* The simpset fragment is built from the row-level conversion, so it acts
   on PMATCH_ROW subterms directly rather than on whole PMATCH terms. *)
fun PMATCH_REMOVE_DOUBLE_BIND_GEN_ss ssl =
  make_gen_conv_ss PMATCH_ROW_REMOVE_DOUBLE_BIND_CONV_GENCALL
    "PMATCH_REMOVE_DOUBLE_BIND_REDUCER" ssl

val PMATCH_REMOVE_DOUBLE_BIND_ss = PMATCH_REMOVE_DOUBLE_BIND_GEN_ss []
1813
1814
1815(***********************************************)
1816(* Remove a GUARD                              *)
1817(***********************************************)
1818
1819(*
1820val t = ``case (x, y) of
1821  | (x, 2) when EVEN x => x + x
1822  | (SUC x, y) when ODD x => y + x + SUC x
1823  | (SUC x, 1) => x
1824  | (x, _) => x+3``
1825
1826val rc_arg = ([], NONE)
1827val rows = 0
1828*)
1829
(* Eliminate the guard of one row of the PMATCH term [t].

   Picks the first row whose guard, as returned by dest_PMATCH_ROW_ABS, is
   neither T nor F. The rows are split into those before it [rs1], the row
   itself [r], and those after it [rs2], and GUARDS_ELIM_THM is
   instantiated accordingly. Preconditions are discharged with
   rc_elim_precond, and the result is tidied with fix_appends.

   Finally, PMATCH_ROW_FORCE_SAME_VARS_CONV is applied to one row inside the
   right-hand side. Its position is fixed by the shape of GUARDS_ELIM_THM,
   which is not visible here -- TODO confirm which row.

   Raises UNCHANGED if every guard is T or F, or on any HOL_ERR. *)
fun PMATCH_REMOVE_GUARD_AUX rc_arg t = let
  val (v, rows) = dest_PMATCH t

  (* rs1 accumulates the skipped rows in reverse order *)
  fun find_row_to_split rs1 rs = case rs of
     [] => raise UNCHANGED (* nothing found *)
   | (r:: rs') => let
        val (_, _, g, _) = dest_PMATCH_ROW_ABS r
        val g_simple = ((g = T) orelse (g = F))
     in
        if g_simple then
           find_row_to_split (r::rs1) rs'
        else let
          val r_ty = type_of r
          val rs1_tm = listSyntax.mk_list (List.rev rs1, r_ty)
          val rs2_tm = listSyntax.mk_list (rs', r_ty)
        in
           (rs1_tm, r, rs2_tm)
        end
     end

  val (rs1, r, rs2) = find_row_to_split [] rows

  val thm = let
    val thm0 = FRESH_TY_VARS_RULE GUARDS_ELIM_THM
    val (p_tm, g_tm, r_tm) = dest_PMATCH_ROW r
    val thm1 = ISPECL [v, rs1, rs2, p_tm, g_tm, r_tm] thm0

    val thm2 = rc_elim_precond rc_arg thm1
    val thm3 = fix_appends rc_arg t thm2
  in
    thm3
  end

  val thm2 = CONV_RULE (RHS_CONV (RAND_CONV (RAND_CONV (RATOR_CONV (RAND_CONV PMATCH_ROW_FORCE_SAME_VARS_CONV))))) thm

in
  thm2
end handle HOL_ERR _ => raise UNCHANGED
1868
1869
1870
(* Remove guards from the PMATCH term [t] by repeatedly applying
   PMATCH_REMOVE_GUARD_AUX.  The right-hand side is then simplified with
   std_ss, the fragments in [fst rc_arg] and PMATCH_SIMP_GEN_ss.
   Any HOL_ERR is turned into UNCHANGED. *)
fun PMATCH_REMOVE_GUARDS_CONV_GENCALL rc_arg t = let
  val ssl = fst rc_arg
  val guards_removed = REPEATC (PMATCH_REMOVE_GUARD_AUX rc_arg) t
  val cleanup_c =
    SIMP_CONV (std_ss ++ simpLib.merge_ss ssl ++ PMATCH_SIMP_GEN_ss ssl) []
in
  CONV_RULE (RHS_CONV cleanup_c) guards_removed
end handle HOL_ERR _ => raise UNCHANGED
1880
(* Guard removal using the extra simpset fragments [ssl] and no callback. *)
fun PMATCH_REMOVE_GUARDS_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE) in
    PMATCH_REMOVE_GUARDS_CONV_GENCALL rc_arg
  end

(* Guard removal without extra simpset fragments. *)
val PMATCH_REMOVE_GUARDS_CONV = PMATCH_REMOVE_GUARDS_CONV_GEN [];

(* Simpset fragment running PMATCH_REMOVE_GUARDS_CONV_GENCALL as the
   reducer "PMATCH_REMOVE_GUARDS_REDUCER". *)
fun PMATCH_REMOVE_GUARDS_GEN_ss ssl = let
  val reducer_name = "PMATCH_REMOVE_GUARDS_REDUCER"
in
  make_gen_conv_ss PMATCH_REMOVE_GUARDS_CONV_GENCALL reducer_name ssl
end

val PMATCH_REMOVE_GUARDS_ss = PMATCH_REMOVE_GUARDS_GEN_ss []
1889
1890
1891
1892(***********************************************)
1893(* PATTERN COMPILATION                         *)
1894(***********************************************)
1895
(* A column heuristic picks the column of a pattern match on which the
   next case split is performed.  Its argument lists the columns of the
   match.  Each column pairs one component of the matched value with,
   for every row, the row's pattern at that position together with the
   list of that row's pattern variables.  The result is the 0-based
   index of the chosen column. *)
type column = (term * (term list * term) list)
type column_heuristic = column list -> int

(* Heuristic: always split on the first column. *)
val colHeu_first_col : column_heuristic = fn _ => 0

(* Heuristic: always split on the last column. *)
val colHeu_last_col : column_heuristic = fn cols => List.length cols - 1
1910
(* A ranking function scores a single column; higher scores win. *)
type column_ranking_fun = (term * (term list * term) list) -> int

(* [colHeu_rank rankL] builds a column heuristic from ranking functions
   applied lexicographically.  The first function keeps only the columns
   of maximal rank, the next one is applied to the survivors, and so on.
   Refinement stops early once at most one candidate is left.  The
   surviving candidate with the smallest index is returned; 0 if none
   survives. *)
fun colHeu_rank (rankL : column_ranking_fun list) = (fn colL => let
   (* keep the (index, column) pairs whose rank under [rank] is maximal *)
   fun keep_best rank cands = let
     val scored = List.map (fn (i, c) => ((i, c), rank c)) cands
     val best = List.foldl (fn ((_, r), m) => if r > m then r else m)
                  (snd (hd scored)) (tl scored)
   in
     List.map fst (List.filter (fn (_, r) => r = best) scored)
   end

   fun refine [] cands = cands
     | refine _ [] = []
     | refine _ [e] = [e]
     | refine (rank :: ranks) cands = refine ranks (keep_best rank cands)
in
   case refine rankL (Lib.enumerate 0 colL) of
      ((i, _) :: _) => i
    | [] => 0 (* something went wrong, should not happen *)
end) : column_heuristic
1934
1935
1936(* ranking functions *)
(* Rank 1 if the first row's pattern in this column is not one of that
   row's pattern variables; 0 otherwise (including for no rows). *)
fun colRank_first_row (_:term, rows) =
  case rows of
    (vs, p) :: _ => if is_var p andalso mem p vs then 0 else 1
  | [] => 0;
1942
(* Rank 1 if the first row's pattern in this column is headed by a
   constructor of the column's constructor family in [db]; 0 if it is a
   pattern variable, if there is no family, if the head is not one of
   the family's constructors, or if this check fails with HOL_ERR. *)
fun colRank_first_row_constr db (_, rows) =
  case rows of
      [] => 0
    | ((vs, p) :: _) =>
        if is_var p andalso mem p vs then 0
        else
          (case pmatch_compile_db_compile_cf db rows of
               NONE => 0
             | SOME cf =>
                 let
                   val (_, constrL) = constructorFamily_get_constructors cf
                   val head = fst (strip_comb p)
                   val family = List.map fst constrL
                 in
                   if op_mem same_const head family then 1 else 0
                 end handle HOL_ERR _ => 0);
1956
(* Rank = length of the longest prefix of rows whose pattern in this
   column is not a wildcard.  As in the other ranking functions, a
   variable only counts as a wildcard if it is one of the row's pattern
   variables ([mem p vs]).  Other variables are free in the pattern and
   are not treated as wildcards. *)
val colRank_constr_prefix : column_ranking_fun = (fn (_, rows) =>
  let fun aux n [] = n
        | aux n ((vs, p) :: pL) = if (is_var p andalso mem p vs)
             then n else aux (n+1) pL
  in aux 0 rows end)
1962
1963
(* Look up the constructor family of the column's patterns in [db].
   Returns NONE if there is none.  Otherwise it returns
   SOME (cs, all_cs, exh), where:
   - [cs] are the distinct head constants of the rows that belong to the
     family;
   - [all_cs] are all constructors of the family;
   - [exh] is the flag returned by constructorFamily_get_constructors. *)
fun col_get_constr_set db (_, rows) =
  case pmatch_compile_db_compile_cf db rows of
    NONE => NONE
  | SOME cf => let
      val (exh, constrL) = constructorFamily_get_constructors cf
      val heads = List.map (fn (_, p) => fst (strip_comb p)) rows
      val all_cs = List.map fst constrL
      fun in_family c = op_mem same_const c all_cs
    in
      SOME (Lib.mk_set (List.filter in_family heads), all_cs, exh)
    end
1977
(* The distinct (vars, pattern) entries of the column whose pattern is
   not one of the row's pattern variables. *)
fun col_get_nonvar_set (_, rows) =
  let
    fun is_wildcard (vs, p) = is_var p andalso mem p vs
  in
    Lib.mk_set (List.filter (not o is_wildcard) rows)
  end
1986
(* Prefer columns with a small branching factor.

   With a constructor family, the rank is minus the sum of:
   - the number of family constructors occurring in the column;
   - one more if the family's flag [exh] is false;
   - one more if not all family constructors occur.

   Without a family, the rank is minus (number of distinct non-wildcard
   entries + 2). *)
fun colRank_small_branching_factor db : column_ranking_fun = (fn col =>
  case col_get_constr_set db col of
      NONE => ~(length (col_get_nonvar_set col) + 2)
    | SOME (cL, full_constrL, exh) => let
        val exh_penalty = if exh then 0 else 1
        val missing_penalty =
          if length cL = length full_constrL then 0 else 1
      in
        ~(length cL + exh_penalty + missing_penalty)
      end)
1992
(* Prefer columns whose constructors take few arguments.  The rank is
   minus the sum, over the family constructors occurring in the column,
   of the number of argument types in each constructor's type.  Columns
   without a constructor family get rank 0. *)
fun colRank_arity db : column_ranking_fun = (fn col =>
  case col_get_constr_set db col of
     NONE => 0
   | SOME (cL, _, _) => let
       fun arity c = length (fst (strip_fun (type_of c)))
     in
       ~(List.foldl (fn (c, s) => s + arity c) 0 cL)
     end)
1998
1999
(* Heuristics built from the ranking functions above.  The ranking
   functions are applied in list order; later ones only break ties left
   by earlier ones. *)
val colHeu_first_row = colHeu_rank [colRank_first_row]
val colHeu_constr_prefix = colHeu_rank [colRank_constr_prefix]

fun colHeu_qba db = colHeu_rank
  [colRank_constr_prefix,
   colRank_small_branching_factor db,
   colRank_arity db]

fun colHeu_cqba db = colHeu_rank
  [colRank_first_row_constr db,
   colRank_constr_prefix,
   colRank_small_branching_factor db,
   colRank_arity db]

(* The default column heuristic: colHeu_qba over the current contents of
   thePmatchCompileDB.  The reference is read at every call. *)
fun colHeu_default cols = let
  val db = !thePmatchCompileDB
in
  colHeu_qba db cols
end
2009
2010
2011(* Now we can define a case-split function that performs
2012   case-splits using such heuristics. *)
2013
2014(*
2015val t = ``case (a,x,xs) of
2016    | (NONE,x,[]) when x > 5 => x
2017    | (NONE,x,_) => SUC x``
2018
2019val t = ``case (a,x,xs) of
2020    | (NONE,x,[]) => x
2021    | (NONE,x,[2]) => x
2022    | (NONE,x,[v18]) => 3
2023    | (NONE,x,v12::v16::v17) => 3
2024    | (y,x,z,zs) .| (SOME y,x,[z]) => (x + 5 + z)
2025    | (y,v23,v24) .| (SOME y,0,v23::v24) => (v23 + y)
2026    | (y,z,v23) .| (SOME y,SUC z,[v23]) when (y > 5) => 3
2027    | (y,z) .| (SOME y,SUC z,[1; 2]) => (y + z)
2028  ``
2029*)
2030
(* Apply conversion [c] to [tt].  If [tt] is a literal_case application,
   apply [c] instead to the body of its function argument (under the
   abstraction). *)
fun literal_case_CONV c tt =
  if not (boolSyntax.is_literal_case tt) then c tt
  else RATOR_CONV (RAND_CONV (ABS_CONV c)) tt
2033
(* Congruence rule for literal_case whose only hypothesis is about the
   scrutinee [v]; the function argument [f] is kept as it is.  Used as a
   Cong in PMATCH_CASE_SPLIT_AUX, presumably to keep the simplifier from
   rewriting inside the case branches. *)
val literal_cong_stop = prove(
   ``(v = v') ==> (literal_case (f:'a -> 'b) v = literal_case f v')``,
   SIMP_TAC std_ss [])
2037
(* Case-split the PMATCH term [t] on its column [col_no] (0-based).

   [expand_thm] is a case-expansion theorem.  It is specialised first to
   [ff] and then to the column's actual value [arg], where [ff] maps a
   value to [t] with that column replaced by the value.  After
   beta-reducing the left-hand side, every application of an
   abstraction over the fresh variable [arg_v] on the right-hand side is
   beta-reduced and cleaned up with:
   - PMATCH_CLEANUP_CONV_GENCALL;
   - PMATCH_SIMP_COLS_CONV_GENCALL;
   - PMATCH_INCOMPLETE_def.

   If does_conv_loop reports that this made no progress, the right-hand
   side is simplified once more with:
   - std_ss, the fragments in [fst rc_arg] and PMATCH_SIMP_GEN_ss;
   - PMATCH_INCOMPLETE_def;
   - literal_cong_stop as a congruence, looking under a literal_case if
     there is one.
   UNCHANGED is raised if even that makes no progress. *)
fun PMATCH_CASE_SPLIT_AUX rc_arg col_no expand_thm t = let
  (* fails with HOL_ERR if [t] is not a PMATCH *)
  val (v, rows) = dest_PMATCH t
  val vs = pairSyntax.strip_pair v

  val arg = el (col_no+1) vs
  val arg_v = genvar (type_of arg)
  val vs' = pairSyntax.list_mk_pair (fst (
    replace_element vs col_no [arg_v]))

  (* ff = \arg_v. t with column [col_no] of the matched value := arg_v *)
  val ff = let
    val (x, xs) = strip_comb t
    val t' = list_mk_comb(x, vs' :: (tl xs))
  in
    mk_abs (arg_v, t')
  end

  val thm0 = ISPEC arg (ISPEC ff expand_thm)
  val thm1 = CONV_RULE (LHS_CONV BETA_CONV) thm0

  (* cleanup applied to every beta-reduced instance of [ff] *)
  val c' = REPEATC (
    TRY_CONV (QCHANGED_CONV (PMATCH_CLEANUP_CONV_GENCALL rc_arg)) THENC
    TRY_CONV (QCHANGED_CONV (PMATCH_SIMP_COLS_CONV_GENCALL rc_arg)) THENC
    TRY_CONV (REWR_CONV PMATCH_INCOMPLETE_def)
  );

  (* only fire on applications of an abstraction over [arg_v]; fail
     (so TOP_SWEEP_CONV moves on) everywhere else *)
  fun c tt = let
    val _ = let
      val (t0, _) = dest_comb tt
      val (v', _) = dest_abs t0
    in
      if (aconv arg_v v') then () else failwith "not a new position"
    end
  in
    (BETA_CONV THENC c') tt
  end;

  val thm2 = CONV_RULE (RHS_CONV (TOP_SWEEP_CONV c)) thm1

  (* check whether it got simpler, if not try full simp including propagating
     case information *)
  val thm3 = if (does_conv_loop thm2) then let
       val thm3 = CONV_RULE (RHS_CONV (literal_case_CONV (SIMP_CONV (
           (std_ss++simpLib.merge_ss (fst rc_arg) ++ PMATCH_SIMP_GEN_ss (fst rc_arg))) [PMATCH_INCOMPLETE_def, Cong literal_cong_stop]))) thm2
       val _ = if  (does_conv_loop thm3) then raise UNCHANGED else ()
       in thm3 end
     else thm2
in
  thm3
end
2087
2088(*
2089val t = t'
2090val col_no = 1
2091val rc_arg = ([], NONE)
2092val gl = []
2093val callback_opt = NONE
2094val db = !thePmatchCompileDB
2095val col_heu = colHeu_default
2096val t = ``case x of 3 => 1 | _ => 0``
2097*)
2098
(* One compilation step on the PMATCH term [t]:
   - choose a column with [col_heu] for which [db] provides a
     case-expansion theorem;
   - split on it with PMATCH_CASE_SPLIT_AUX, putting the DB's simpset
     fragment in front of [gl].
   Columns without an expansion theorem are dropped and the heuristic is
   asked again on the rest.  Raises UNCHANGED if [t] is not a PMATCH, if
   no column can be split, or if the split makes no progress. *)
fun PMATCH_CASE_SPLIT_CONV_GENCALL_STEP (gl, callback_opt) db col_heu t = let
  val _ = if is_PMATCH t then () else raise UNCHANGED

  (* Returns (index into the given columns, expand_thm, expand_ss). *)
  fun find_col [] = raise UNCHANGED
    | find_col cols = let
        val idx = col_heu cols
        val (_, col) = el (idx + 1) cols
      in
        case pmatch_compile_db_compile db col of
            SOME (expand_thm, _, expand_ss) => (idx, expand_thm, expand_ss)
          | NONE => let
              val (remaining, _) = replace_element cols idx []
              val (idx', expand_thm, expand_ss) = find_col remaining
              (* undo the shift caused by removing column [idx] *)
              val orig_idx = if idx' < idx then idx' else idx' + 1
            in
              (orig_idx, expand_thm, expand_ss)
            end
      end

  val (col_no, expand_thm, expand_ss) = find_col (dest_PMATCH_COLS t)
  val split_thm = QCHANGED_CONV (PMATCH_CASE_SPLIT_AUX
    (expand_ss :: gl, callback_opt) col_no expand_thm) t
in
  (* reject steps that did not make the term simpler *)
  if does_conv_loop split_thm then raise UNCHANGED else split_thm
end
2127
2128
(* The constant pair_CASE at its general type
   'a # 'b -> ('a -> 'b -> 'c) -> 'c; instantiated with [inst] in
   PMATCH_CASE_SPLIT_CONV_GENCALL. *)
val pair_CASE_tm = mk_const ("pair_CASE", ``:'a # 'b -> ('a -> 'b -> 'c) -> 'c``)
2130
(* Compile the PMATCH term [t] into nested case splits, using [db] for
   case-expansion theorems and [col_heu] to choose columns.

   1. Simplify [t] with PMATCH_SIMP_CONV_GENCALL, keeping [t] if that
      fails or changes nothing; call the result [t'].
   2. Prove (by SIMP_TAC with pair_CASE_def) that [t'] equals a term in
      which the matched value, a tuple with [length cols] components, is
      taken apart with nested pair_CASE applications.  The PMATCH is
      then applied to a tuple of fresh variables.
   3. Split cases top-down with PMATCH_CASE_SPLIT_CONV_GENCALL_STEP.
   4. If no PMATCH remains in the result, apply REMOVE_REBIND_CONV.

   Raises UNCHANGED if does_conv_loop reports no progress. *)
fun PMATCH_CASE_SPLIT_CONV_GENCALL rc_arg db col_heu t = let
  val thm0 = PMATCH_SIMP_CONV_GENCALL rc_arg t handle
        HOL_ERR _ => REFL t
      | UNCHANGED => REFL t
  val t' = rhs (concl thm0)

  val cols = dest_PMATCH_COLS t'
  val col_no = length cols
  val (v, rows) = dest_PMATCH t'
  val rows_tm = rand t'

  (* [mk_pair avoid acc col_no v] takes apart [v], a tuple with [col_no]
     components, using pair_CASE with fresh variables not in [avoid].
     [acc] holds, in reverse order, the variables introduced so far.  At
     the innermost position the PMATCH over [rows_tm] is rebuilt on the
     tuple of all these variables. *)
  fun mk_pair avoid acc col_no v = if (col_no <= 1) then (
      let
        val vs = List.rev (v::acc)
        val p = pairSyntax.list_mk_pair vs
      in
        mk_PMATCH p rows_tm
      end
    ) else (
      let
         val (ty1, ty2) = pairSyntax.dest_prod (type_of v)
         val v1 = variant avoid (mk_var ("v", ty1))
         val v2 = variant (v1::avoid) (mk_var ("v", ty2))

         (* pair_CASE v (\v1 v2. <rest>) at result type type_of t *)
         val t0 = inst [alpha |-> ty1, beta |-> ty2, gamma |-> type_of t] pair_CASE_tm
         val t1 = mk_comb (t0, v)
         val t2a = mk_pair (v1::v2::avoid) (v1::acc) (col_no-1) v2
         val t2b = list_mk_abs ([v1, v2], t2a)
         val t2c = mk_comb (t1, t2b)
      in
        t2c
      end
    )

  val t'' = mk_pair (free_vars t') [] col_no v
  val thm1_tm = mk_eq (t', t'')
  val thm1 = prove (thm1_tm, SIMP_TAC std_ss [pairTheory.pair_CASE_def])

  (* perform the actual case splits *)
  val thm2 = CONV_RULE (RHS_CONV (
    (TOP_SWEEP_CONV (
      PMATCH_CASE_SPLIT_CONV_GENCALL_STEP rc_arg db col_heu
    )))) thm1

  val thm3 = TRANS thm0 thm2

  (* check whether it got simpler *)
  val _ = if (does_conv_loop thm3) then raise UNCHANGED else ()

  (* once no PMATCH is left, clean up with REMOVE_REBIND_CONV *)
  val thm4 = if (has_subterm is_PMATCH (rhs (concl thm3))) then
       thm3
     else
       CONV_RULE (RHS_CONV REMOVE_REBIND_CONV) thm3
in
  thm4
end
2186
(* Case-split compilation with extra simpset fragments [ssl] and no
   callback. *)
fun PMATCH_CASE_SPLIT_CONV_GEN ssl =
  let val rc_arg = (ssl, NONE) in
    PMATCH_CASE_SPLIT_CONV_GENCALL rc_arg
  end

(* Case-split compilation with column heuristic [col_heu], using the
   current global compile DB. *)
fun PMATCH_CASE_SPLIT_CONV_HEU col_heu t = let
  val db = !thePmatchCompileDB
in
  PMATCH_CASE_SPLIT_CONV_GEN [] db col_heu t
end

(* Case-split compilation with the default column heuristic. *)
fun PMATCH_CASE_SPLIT_CONV t = PMATCH_CASE_SPLIT_CONV_HEU colHeu_default t

(* Simpset fragment running the case-split compilation as the reducer
   "PMATCH_CASE_SPLIT_REDUCER". *)
fun PMATCH_CASE_SPLIT_GEN_ss ssl db col_heu = let
  fun split_conv rc_arg = PMATCH_CASE_SPLIT_CONV_GENCALL rc_arg db col_heu
in
  make_gen_conv_ss split_conv "PMATCH_CASE_SPLIT_REDUCER" ssl
end

fun PMATCH_CASE_SPLIT_HEU_ss col_heu = let
  val db = !thePmatchCompileDB
in
  PMATCH_CASE_SPLIT_GEN_ss [] db col_heu
end

fun PMATCH_CASE_SPLIT_ss () = PMATCH_CASE_SPLIT_HEU_ss colHeu_default
2205
2206
2207(***********************************************)
2208(* COMPUTE CASE-DISTINCTION based on pats      *)
2209(***********************************************)
2210
2211(*
2212val t = ``
2213  case (a,x,xs) of
2214    | (NONE,_,[]) => 0
2215    | (NONE,x,[]) when x < 10 => x
2216    | (NONE,x,[2]) => x
2217    | (NONE,x,[v18]) => 3
2218    | (NONE,_,[_;_]) => x
2219    | (NONE,x,v12::v16::v17) => 3
2220    | (SOME y,x,[z]) => x + 5 + z
2221    | (SOME y,0,v23::v24) => (v23 + y)
2222    | (SOME y,SUC z,[v23]) when y > 5 => 3
2223    | (SOME y,SUC z,[1; 2]) => y + z``;
2224
2225  val (v, rows) = dest_PMATCH t
2226  val pats = List.map (#1 o dest_PMATCH_ROW) rows
2227
2228  val col_heu = colHeu_default
2229  val db = !thePmatchCompileDB
2230
2231  val pats = [``\(x:num). 2``]
2232  val pats = [``\(x:num). [2;3;4]``]
2233
2234*)
2235
local

  (* |- !Q. (!x. Q x) ==> !P. (?x. P x) = (?x. Q x /\ P x)
     An existential may be restricted by a predicate that holds
     everywhere.  Combined with a (generalised) nchotomy theorem in
     expand_cases_in_thm. *)
  val case_dist_exists_thm = prove (``!Q. (
    (!(x:'a). Q x) ==>
    !P. (?x. P x) = (?x. Q x /\ P x))``,
  SIMP_TAC std_ss []);

  (* Labels distribute over disjunctions. *)
  val label_over_or_thm = prove (
    ``(lbl :- (t1 \/ t2)) <=> (lbl :- t1) \/ (lbl :- t2)``,
    REWRITE_TAC[markerTheory.label_def]);

  (* Choose a column with [col_heu] for which [db] provides an nchotomy
     theorem.  Columns without one are dropped and the heuristic is asked
     again on the rest.  Returns the column's value [v] together with the
     nchotomy theorem specialised to [v].  Fails with HOL_ERR
     ("compile failed") once no column is left. *)
  fun find_nchotomy_for_cols db col_heu cols = let
    val _ = if (List.null cols) then
       failwith "compile failed" else ()
    val col_no = col_heu cols
    val (v, col) = el (col_no+1) cols
    val nchot_thm_opt = pmatch_compile_db_compile_nchotomy db col
  in
    case nchot_thm_opt of
      SOME nchot_thm => (v, ISPEC v nchot_thm)
    | NONE => let
        val (cols', _) = replace_element cols col_no []
      in
        find_nchotomy_for_cols db col_heu cols'
      end
  end


  (* Initial state for compiling [pats], a list of paired abstractions
     [\vars. pat].  The tuple structure of the first pattern fixes the
     shape of the matched value: a tuple (v1, ..., vn) of fresh variables
     from [var_gen].  Returns a triple of:
     - |- !x. lbl :- ?v1 ... vn. x = (v1, ..., vn);
     - the columns of [pats] for that value;
     - [lbl], a fresh label from [lbl_gen].
     Fails with HOL_ERR (instead of List.Empty) if [pats] is empty. *)
  fun mk_initial_state var_gen lbl_gen pats = let
    val first_pat = case pats of
        [] => failwith "mk_initial_state: empty pattern list"
      | pt :: _ => pt
    val (_, p) = pairSyntax.dest_pabs first_pat
    val cs = pairLib.strip_pair p
    val vs = List.map (fn p => var_gen (type_of p)) cs
    val initial_value = pairLib.list_mk_pair vs
    val cols = dest_PATLIST_COLS initial_value pats

    val lbl = lbl_gen ()
    val initial_thm = let
      val x_tm = mk_var ("x", type_of initial_value)
      val tm = mk_forall (x_tm, markerSyntax.mk_label (lbl, list_mk_exists (vs, mk_eq (x_tm, initial_value))))
      val thm = prove (tm,
        SIMP_TAC std_ss [pairTheory.FORALL_PROD, markerTheory.label_def])
    in thm end
  in
    (initial_thm, cols, lbl)
  end


  (* Prepare the nchotomy theorem [nthm] about [v] for use as a case
     distinction.  In every disjunct, the existentially quantified
     variables are renamed to fresh ones from [var_gen] and a fresh label
     from [lbl_gen] is added.  For every disjunct with a conjunct of the
     form [v = c args], the triple (label, c, existential variables) is
     recorded.  Returns the relabelled theorem and the triples, in the
     order in which ALL_DISJ_CONV visited the disjuncts. *)
  fun compute_cases_info var_gen lbl_gen v nthm = let
    val disjuncts = ref ([] : (string * term * term list) list)

    (* val d = el 2 ds *)
    fun process_disj d = let
      val lbl = lbl_gen ()

      (* intro fresh vars *)
      val d_thm = let
        val (evs, d_b) = strip_exists d
        val s = List.map (fn v => (v |-> var_gen (type_of v))) evs
        val evs = List.map (Term.subst s) evs
        val d_b = Term.subst s d_b
        val d' = list_mk_exists (evs, d_b)
        val d_thm = ALPHA d d'
      in
        d_thm
      end

      (* add label *)
      val ld_thm = RIGHT_CONV_RULE (add_labels_CONV [lbl]) d_thm


      (* figure out constructor and free variables and add them
         to list of disjuncts *)
      val _ = let
        val d' = rhs (concl d_thm)
        val (evs, b) = strip_exists d'
        val b_conjs = strip_conj b
        val main_conj = first (fn c' =>
           aconv (lhs c') v handle HOL_ERR _ => false) b_conjs
        val r = rhs main_conj
        val (c, _) = strip_comb_bounded (List.length evs) r
        val _ = disjuncts := (lbl, c, evs) :: !disjuncts
      in () end handle HOL_ERR _ => ()
    in
      ld_thm
    end handle HOL_ERR _ => raise UNCHANGED

    (* val ds = strip_disj (concl nthm) *)
    val nthm' = CONV_RULE (ALL_DISJ_CONV process_disj) nthm
  in
    (nthm', List.rev (!disjuncts))
  end

  (* For [?x. A /\ B] where A and B may carry labels: strip the labels of
     both conjuncts and put the labels of A, followed by those of B, on
     the whole term. *)
  fun exists_left_and_label_CONV t = let
    val (lbls_left, _) = (strip_labels o fst o dest_conj o snd o dest_exists) t
    val (lbls_right, _) = (strip_labels o snd o dest_conj o snd o dest_exists) t

    val c_remove = QUANT_CONV (BINOP_CONV (REPEATC markerLib.DEST_LABEL_CONV))

    val thm0 = (c_remove THENC (add_labels_CONV (lbls_left @ lbls_right))) t
  in
    thm0
  end

  (* Expand the existential disjunct [d_tm] with the case distinction
     [nthm_expand] on [v]:
     - move [v] to the front of the existential quantifiers;
     - rewrite with [nthm_expand];
     - distribute the conjunction over the disjunction of cases and split
       the existential accordingly;
     - merge the labels;
     - push the remaining existentials inwards;
     - unwind the resulting equations with UNWIND_EXISTS_CONV. *)
  fun expand_disjunction_CONV v nthm_expand d_tm = let
    val thm00 = RESORT_EXISTS_CONV (fn vs =>
       let val (v', vs') = pick_element (aconv v) vs in
       (v'::vs') end) d_tm

    val thm01a = HO_PART_MATCH (lhs o snd o strip_forall) nthm_expand (rhs (concl thm00))
    val thm01 = TRANS thm00 thm01a

    val thm02 = RIGHT_CONV_RULE (PURE_REWRITE_CONV [RIGHT_AND_OVER_OR]) thm01
    val thm03 =
        RIGHT_CONV_RULE (DESCEND_CONV BINOP_CONV (TRY_CONV EXISTS_OR_CONV))
                        thm02

    val thm04 = RIGHT_CONV_RULE (ALL_DISJ_CONV exists_left_and_label_CONV) thm03

    val LEFT_RIGHT_AND_LIST_EXISTS_CONV =
        DESCEND_CONV QUANT_CONV
                     (RIGHT_AND_EXISTS_CONV ORELSEC LEFT_AND_EXISTS_CONV)
    val thm05 = RIGHT_CONV_RULE (ALL_DISJ_CONV (strip_labels_CONV (STRIP_QUANT_CONV LEFT_RIGHT_AND_LIST_EXISTS_CONV))) thm04
    val thm06 = RIGHT_CONV_RULE (ALL_DISJ_CONV (strip_labels_CONV (Unwind.UNWIND_EXISTS_CONV))) thm05
  in
    thm06
  end

  (* In [thm] (|- !x. disjunction of labelled cases), expand every
     disjunct labelled [lbl] with the case distinction [nthm'] on [v].
     Then distribute labels over the new disjunctions and re-associate
     them.  Returns [thm] unchanged if this fails with HOL_ERR. *)
  fun expand_cases_in_thm lbl (v, nthm') thm = let
    val nthm_expand = HO_MATCH_MP case_dist_exists_thm (GEN v nthm')

    val thm01 = CONV_RULE (QUANT_CONV (ALL_DISJ_CONV (
       guarded_strip_labels_CONV [lbl] (
       (expand_disjunction_CONV v nthm_expand))))) thm

    val thm02 = CONV_RULE (PURE_REWRITE_CONV [label_over_or_thm, GSYM DISJ_ASSOC]) thm01

   in
     thm02
   end handle HOL_ERR _ => thm


  (* Columns left after splitting [current_col] on constructor [c]
     applied to the fresh variables [evs].  Rows are treated as follows:
     - a pattern variable in [current_col]: the row is kept and gives a
       wildcard in every new column;
     - an application of [c]: the row is kept and gives its arguments
       (split off with strip_comb_bounded) as the new columns;
     - anything else: the row is dropped, also from the remaining
       columns [cols'].
     The new columns (one per variable in [evs]) come first, followed by
     the filtered [cols'].  Columns containing only pattern variables are
     removed. *)
  fun get_columns_for_constructor current_col (c, evs) cols' = let
    fun process_current_col (cs : (term list * term) list list, kl : bool list) ps = case ps of
        [] => (List.map List.rev cs, List.rev kl)
      | (vs, p)::ps' => let
           val (cs', kl') =
             if (Term.is_var p) andalso List.exists (aconv p) vs then
               (Lib.map2 (fn v => fn l => ([v], v)::l) evs cs,
                true::kl)
             else let
               val (c', args) = strip_comb_bounded (List.length evs) p
             in
               if not (aconv c c') then (cs, false::kl) else
               (Lib.map2 (fn a => fn l => (vs, a)::l) args cs,
                true::kl)
             end
         in process_current_col (cs', kl') ps' end

    val ps = (snd current_col)
    val (cs, kl) =  process_current_col (List.map (K []) evs, []) ps
    val cols1 = zip evs cs

    val cols2 = List.map (fn (v, rs) =>
       (v, List.map snd (Lib.filter fst (zip kl rs)))) cols'

    val cols'' = cols1 @ cols2

    (* remove cols consisting of only vars *)
    val cols''' = filter (fn (_, ps) => not (List.all (fn (vs, p) =>    is_var p andalso List.exists (aconv p) vs) ps)) cols''
  in
     cols'''
  end


  (* extract the column for variable v from the list of columns *)
  fun pick_current_column v cols =
    pick_element (fn (v', _) => aconv v v') cols

in (* in of local *)

  (* Compute a case-distinction (nchotomy) theorem for the patterns
     [pats], using [db] for constructor information and [col_heu] to
     choose the column to split on.

     Starting from |- !x. ?v1 ... vn. x = (v1, ..., vn), cases are split
     on the columns of [pats] for as long as a split is possible.  The
     labels used internally to track cases are removed at the end.
     Fails with HOL_ERR if [pats] is empty. *)
  fun nchotomy_of_pats_GEN db col_heu pats = let
    val var_gen = mk_var_gen "v" []
    val lbl_gen = mk_new_label_gen "case_"

    (*
      val (thm, cols, lbl) = mk_initial_state var_gen lbl_gen pats
      val (thm, cols, lbl) = (thm1, cols'', lbl)
      val xxx = !args
      val (thm, cols, lbl) = el 3 xxx
*)

    (* Split the cases labelled [lbl] in [thm] on a column chosen from
       [cols], then recursively split every resulting case on its own
       remaining columns.  Returns [thm] once no split applies. *)
    fun compile (thm, cols, lbl) = let
      val (v, nthm) = find_nchotomy_for_cols db col_heu cols
      val (current_col, cols_rest) = pick_current_column v cols
      val (nthm', cases_info) = compute_cases_info var_gen lbl_gen v nthm

      (* Expand all labeled with [lbl] cases *)
      val thm1 = expand_cases_in_thm lbl (v, nthm') thm

      (* Call recursively *)
      val thm2 = let
(*        val ((lbl, c, evs), current_thm) = ((el 2 cases_info, thm1)) *)
        fun process_case ((lbl, c, evs), current_thm) = let
          val cols' = get_columns_for_constructor current_col (c, evs) cols_rest
        in
          compile (current_thm, cols', lbl)
        end
      in
        List.foldl process_case thm1 cases_info
      end
    in
      thm2
    end handle HOL_ERR _ => thm

    (* compile it *)
    val thm3 = compile (mk_initial_state var_gen lbl_gen pats)

    (* get rid of labels *)
    val thm4 = CONV_RULE markerLib.DEST_LABELS_CONV thm3
  in
    thm4
  end

  (* nchotomy_of_pats_GEN with the current global compile DB and the
     default column heuristic. *)
  fun nchotomy_of_pats pats =
      nchotomy_of_pats_GEN (!thePmatchCompileDB) colHeu_default pats

end
2463
2464
2465(********************************************)
2466(* Prune disjunctions of PMATCH_ROW_COND_EX *)
2467(********************************************)
2468
(* Given a disjunction of PMATCH_ROW_COND_EX terms and
   a theorem stating that a certain PMATCH_ROW_COND_EX does not
   hold, prune the disjunction by removing all patterns
   covered by the one we know does not hold. *)
2473
2474
(* If the guard of the PMATCH_ROW_COND_EX term [tt] is a paired
   abstraction with body F, rewrite [tt] with PMATCH_ROW_COND_EX_FALSE.
   The side condition is proved by simplification.  Raises UNCHANGED
   otherwise, or if any step fails with HOL_ERR. *)
fun PMATCH_ROW_COND_EX_ELIM_FALSE_GUARD_CONV tt = let
  val (_, _, guard) = dest_PMATCH_ROW_COND_EX tt
  val _ = if aconv (snd (pairLib.dest_pabs guard)) F then ()
          else raise UNCHANGED

  val inst_thm = PART_MATCH (lhs o rand) PMATCH_ROW_COND_EX_FALSE tt
  val side_cond = rand (rator (concl inst_thm))
  val side_thm = prove (side_cond,
    SIMP_TAC (std_ss++pairSimps.gen_beta_ss) [pairTheory.FORALL_PROD])
in
  MP inst_thm side_thm
end handle HOL_ERR _ => raise UNCHANGED
2490
2491(*
2492
2493val t = ``
2494  case (x,y,z) of
2495    | (NONE,_,[]) => 0
2496    | (NONE,x,[]) when x < 10 => x
2497    | (NONE,x,[2]) => x
2498    | (NONE,x,[v18]) => 3
2499    | (NONE,_,[_;_]) => 4
2500    | (NONE,x,v12::v16::v17) => 3
2501    | (SOME y,x,[z]) => x + 5 + z
2502    | (SOME y,0,v23::v24) => (v23 + y)
2503    | (SOME y,SUC z,[v23]) when y > 5 => 3
2504    | (SOME y,SUC z,[1; 2]) => y + z
2505  ``;
2506
2507  val (v, rows) = dest_PMATCH t
2508  val pats = List.map (#1 o dest_PMATCH_ROW) rows
2509
2510
2511val thm = CONV_RULE (nchotomy2PMATCH_ROW_COND_EX_CONV) (nchotomy_of_pats pats)
2512
2513val cs = (strip_disj o concl o SPEC v) thm
2514
2515val t  = (concl o SPEC v) thm
2516
2517val row_cs = List.map (mk_PMATCH_ROW_COND_EX_ROW v) rows
2518
2519val weaken_ce = el 4 row_cs
2520val weaken_thm = ASSUME (mk_neg weaken_ce)
2521val ce = el 4 cs
2522
2523val rc_arg = ([], NONE)
2524*)
2525
2526(* apply thm PMATCH_ROW_COND_EX_WEAKEN *)
(* Weaken the disjunct [ce], a PMATCH_ROW_COND_EX term, using
   [weaken_thm] (|- ~PMATCH_ROW_COND_EX v_w p_w g_w).  [p_w'] and
   [vars_w'] are the pattern body and pattern variables of p_w, renamed
   apart by the caller.

   The weakening applies only if [ce] is about the same value [v_w] and
   its pattern body is an instance of [p_w'] that instantiates only
   variables of [vars_w'] and no types.  In that case:
   - PMATCH_ROW_COND_EX_WEAKEN is instantiated with the function mapping
     [ce]'s pattern variables to the matching instance of [vars_w'];
   - its precondition is discharged with rc_elim_precond;
   - the guard on the right-hand side is simplified with rc_conv and
     re-abstracted over [ce]'s pattern variables;
   - PMATCH_ROW_COND_EX_ELIM_FALSE_GUARD_CONV is then tried on the
     result.

   Raises UNCHANGED if the weakening does not apply or any step fails
   with HOL_ERR. *)
fun PMATCH_ROW_COND_EX_WEAKEN_CONV_GENCALL rc_arg (weaken_thm, v_w, p_w', vars_w') ce = let
  val (v, p_t, _) = dest_PMATCH_ROW_COND_EX ce
  val (vars, p) = pairLib.dest_pabs p_t
  val _ = if (aconv v v_w) then () else raise UNCHANGED

  (* try to match *)
  (* the match may only instantiate the (renamed) pattern variables of
     the known-false row, never types or other free variables *)
  val s = let
    val (s_tm, s_ty) = Term.match_term p_w' p
    val _ = if List.null s_ty then () else failwith "bound too much"
    val vars_w'_l = pairSyntax.strip_pair vars_w'
    val _ = if List.exists (fn s => not (List.exists
        (aconv (#redex s)) vars_w'_l)) s_tm then
         failwith "bound too much" else ()
  in s_tm end

  (* construct f *)
  val f_tm = pairSyntax.mk_pabs (vars, subst s vars_w')

  (* instantiate the thm *)
  val thm0 = let
    val thm00 = FRESH_TY_VARS_RULE PMATCH_ROW_COND_EX_WEAKEN
    val thm01 = MATCH_MP thm00 weaken_thm
    val thm02 = ISPEC f_tm thm01
    val thm03 = PART_MATCH (lhs o rand) thm02 ce
    val thm04 = rc_elim_precond rc_arg thm03
  in
    thm04
  end

  (* Simplify guard *)
  val thm1 = let
       val c = TRY_CONV (rc_conv rc_arg) THENC
               pairTools.PABS_INTRO_CONV vars
     in
       RIGHT_CONV_RULE (RAND_CONV c) thm0
     end

  (* elim false *)
  val thm2 = RIGHT_CONV_RULE
    PMATCH_ROW_COND_EX_ELIM_FALSE_GUARD_CONV thm1
    handle HOL_ERR _ => thm1
in
  thm2
end handle HOL_ERR _ => raise UNCHANGED
2571
2572
(* Given [weaken_thm] (|- ~PMATCH_ROW_COND_EX v_w p_w g_w):
   - weaken every disjunct of [t] with
     PMATCH_ROW_COND_EX_WEAKEN_CONV_GENCALL;
   - clean up the result with OR_CLAUSES.
   The pattern variables of p_w are renamed to fresh genvars before
   matching. *)
fun PMATCH_ROW_COND_EX_DISJ_WEAKEN_CONV_GENCALL rc_arg weaken_thm t = let
  val (v_w, p_tw, _) =
    dest_PMATCH_ROW_COND_EX (dest_neg (concl weaken_thm))
  val (vars_w, p_w) = pairLib.dest_pabs p_tw

  (* rename the pattern variables apart before matching *)
  val renaming = List.map (fn x => x |-> genvar (type_of x))
                          (pairSyntax.strip_pair vars_w)
  val p_w' = subst renaming p_w
  val vars_w' = subst renaming vars_w

  val weaken_c = PMATCH_ROW_COND_EX_WEAKEN_CONV_GENCALL rc_arg
                   (weaken_thm, v_w, p_w', vars_w')
  val weakened = ALL_DISJ_CONV weaken_c t
in
  RIGHT_CONV_RULE (PURE_REWRITE_CONV [boolTheory.OR_CLAUSES]) weakened
end
2595
2596
2597(*************************************)
2598(* Compute redundant rows info for a *)
2599(* PMATCH                            *)
2600(*************************************)
2601
2602(* val tt = el 3 cjs *)
2603
(* NOTE(review): this two-argument definition is shadowed by the
   four-argument SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV defined right
   after it.  Nothing in this file uses this version before it is
   redefined, so it looks like dead code; confirm before removing.

   Simplify an implication
     PMATCH_ROW_COND_EX v p' g' ==> ~PMATCH_ROW_COND_EX v p g
   where both conditions are about the same [v].  The steps are:
   - rewrite with PMATCH_ROW_COND_EX_IMP_REWRITE;
   - re-introduce paired abstractions over the pattern variables and
     beta-reduce them;
   - eliminate tupled quantifiers;
   - simplify part of the quantified body with SIMP_CONV (rc_ss []),
     followed by REWRITE_CONV [].
   If the result is not T, it is attempted with rc_tac and, on success,
   rewritten to T.  Raises UNCHANGED if the two conditions are about
   different values. *)
fun SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV rc_arg tt = let
  (* destruct everything *)
  val (v, vars', p', g', vars, p, g) = let
    val (pre, cl_neg) = dest_imp tt
    val (v', p', g') = dest_PMATCH_ROW_COND_EX pre
    val (vars', _) = pairSyntax.dest_pabs p'
    val cl = dest_neg cl_neg
    val (v, p, g) = dest_PMATCH_ROW_COND_EX cl
    val _ = if (aconv v v') then () else raise UNCHANGED
    val (vars, _) = pairSyntax.dest_pabs p
  in
    (v, vars', p', g', vars, p, g)
  end

  val thm00 = FRESH_TY_VARS_RULE PMATCH_ROW_COND_EX_IMP_REWRITE
  val thm01 = ISPECL [v, p', g', p, g] thm00

  val thm02 = RIGHT_CONV_RULE (
      (QUANT_CONV (RAND_CONV (pairTools.PABS_INTRO_CONV vars))) THENC
      (RAND_CONV (pairTools.PABS_INTRO_CONV vars'))) thm01
  val thm03 = RIGHT_CONV_RULE (DEPTH_CONV pairLib.GEN_BETA_CONV) thm02
  val thm04 = RIGHT_CONV_RULE (TRY_CONV (pairTools.ELIM_TUPLED_QUANT_CONV) THENC
               TRY_CONV (STRIP_QUANT_CONV (pairTools.ELIM_TUPLED_QUANT_CONV))) thm03

  (* apply [c] to the conclusion of an implication, or to the whole term *)
  fun imp_or_no_imp_CONV c t =
    if (is_imp t) then
      (RAND_CONV c) t
    else c t

  val thm05 = RIGHT_CONV_RULE (
      (STRIP_QUANT_CONV (imp_or_no_imp_CONV (RATOR_CONV (RAND_CONV (SIMP_CONV (rc_ss []) []))))) THENC
      REWRITE_CONV[]) thm04

  (* if not already T, try to prove the result outright *)
  val rr = rhs (concl thm05)
  val thm06 = if aconv rr T then thm05 else let
      val thm_rr = prove_attempt (rr, rc_tac rc_arg)
    in
      TRANS thm05 (EQT_INTRO thm_rr)
    end handle HOL_ERR _ => thm05
in
  thm06
end
2646
2647(* val ttts = strip_disj pre
2648   val ttt = el 1 ttts
2649   val rc_arg = ([], NONE) *)
2650
(* Decide the PMATCH_ROW_COND_EX disjunct [ttt] (about value [v]) under
   [cc_thm], an assumed PMATCH_ROW_COND_EX about the same value.

   PMATCH_ROW_COND_EX_IMP_REWRITE is instantiated with [cc_thm] and
   [ttt]'s pattern and guard.  Its per-value condition is rewritten with
   rc_conv_rws under its precondition.  The result is accepted only if
   the condition becomes exactly T or F.

   Raises UNCHANGED if [ttt] is about a different value, or if the
   condition does not reduce to T or F. *)
fun SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV rc_arg cc_thm v ttt = let
  val (v', p, g) = dest_PMATCH_ROW_COND_EX ttt
  val _ = if aconv v v' then () else raise UNCHANGED

  val rw_thm = ISPECL [p, g]
    (MATCH_MP (FRESH_TY_VARS_RULE PMATCH_ROW_COND_EX_IMP_REWRITE) cc_thm)

  val (x, body) =
    (dest_forall o rand o rator o snd o strip_forall o concl) rw_thm
  val (pre, eq_tm) = dest_imp body
  val cond = lhs eq_tm

  val cond_thm = rc_conv_rws rc_arg [ASSUME pre] cond
  val res = rhs (concl cond_thm)
  (* we don't want complicated intermediate results *)
  val _ = if aconv res T orelse aconv res F then () else raise UNCHANGED
  val gen_thm = GEN x (DISCH pre cond_thm)
in
  MP (ISPEC res rw_thm) gen_thm
end
2680
2681
2682(* val thm = it
2683   val (tts, _) = (listSyntax.dest_list o rand o concl) thm
2684   val tt = el 2 tts *)
2685
(* |- !X Y X'. (Y ==> (X = X')) ==> ((X ==> ~Y) = (X' ==> ~Y))
   The antecedent of an implication with conclusion ~Y may be rewritten
   under the assumption Y.  Used by SIMPLIFY_REDUNDANT_ROWS_INFO_AUX. *)
val simple_imp_thm  = prove ( ``!X Y X'. ((Y ==> (X = X')) ==> ((X ==> ~Y) = (X' ==> ~Y)))``,
PROVE_TAC[])
2688
(* Simplify an implication [pre ==> ~cc], where [cc] is a
   PMATCH_ROW_COND_EX:
   - simplify each disjunct of [pre] under the assumption [cc] with
     SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV (via ALL_DISJ_TF_ELIM_CONV);
   - rewrite the implication via simple_imp_thm;
   - clean up with REWRITE_CONV [] and PMATCH_ROW_COND_EX_ELIM_CONV.
   Any HOL_ERR becomes UNCHANGED. *)
fun SIMPLIFY_REDUNDANT_ROWS_INFO_AUX rc_arg tt = let
  val (pre, neg_cc) = dest_imp tt
  val cc = dest_neg neg_cc
  val (v, _, _) = dest_PMATCH_ROW_COND_EX cc
  val cc_assumed = ASSUME cc

  val pre_eq =
    (ALL_DISJ_TF_ELIM_CONV
       (SIMPLIFY_PMATCH_ROW_COND_EX_IMP_CONV rc_arg cc_assumed v)) pre
    handle UNCHANGED => REFL pre
  val cond_pre_eq = DISCH cc pre_eq

  val imp_thm = MATCH_MP (SPECL [pre, cc] simple_imp_thm) cond_pre_eq
  val cleanup_c = REWRITE_CONV [] THENC DEPTH_CONV PMATCH_ROW_COND_EX_ELIM_CONV
in
  RIGHT_CONV_RULE cleanup_c imp_thm
end handle HOL_ERR _ => raise UNCHANGED
2707
2708
(* Search [t] for a subterm that is none of the following:
   - a variable in [vs];
   - a pair;
   - a constructor application known to [db].
   Pairs and known constructor applications are searched inside,
   leftmost subterm first.  Returns SOME of the first such subterm, or
   NONE if there is none. *)
fun find_non_constructor_pattern db vs t = let
  fun search [] = NONE
    | search (x :: rest) =
        if mem x vs then search rest
        else if pairSyntax.is_pair x then
          search (pairSyntax.strip_pair x @ rest)
        else
          (case pmatch_compile_db_dest_constr_term db x of
               NONE => SOME x
             | SOME (_, args) => search (List.map snd args @ rest))
in
  search [t]
end
2724
2725
(* Compute an IS_REDUNDANT_ROWS_INFO theorem for the PMATCH term [t].

   The simpset fragment of [db] (#pcdb_ss) is added to the fragments in
   [rc_arg].  The computation then proceeds as follows:
   - An nchotomy for the patterns of [t] is computed with
     nchotomy_of_pats_GEN, converted into PMATCH_ROW_COND_EX disjuncts
     and specialised to the matched value [v].
   - IS_REDUNDANT_ROWS_INFO_NIL is instantiated with that nchotomy as
     its initial condition.
   - Each row is added in order with IS_REDUNDANT_ROWS_INFO_SNOC_PMATCH_ROW.
     The carried disjunction is weakened under the assumption that the
     row's PMATCH_ROW_COND_EX does not hold.  The condition appended for
     the row is simplified with SIMPLIFY_REDUNDANT_ROWS_INFO_AUX. *)
fun COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu t =
let
  val (v, rows) = dest_PMATCH t
  val rc_arg = case rc_arg of
    (sl, cb_opt) => ((#pcdb_ss db)::sl, cb_opt)


  (* compute initial nchotomy *)
  val nchot_thm = let
    val pats = List.map (#1 o dest_PMATCH_ROW) rows
    val thm01 = nchotomy_of_pats_GEN db col_heu pats
    val thm02 = CONV_RULE (nchotomy2PMATCH_ROW_COND_EX_CONV_GEN
       (find_non_constructor_pattern db)
    ) thm01
    val thm03 = ISPEC v thm02
  in
    thm03
  end

  (* get initial info *)
  val init_info = let
    val row_ty = listSyntax.dest_list_type (type_of (rand t))
    val s_ty = match_type (``:'a -> 'b option``) row_ty
    val thm00 = INST_TYPE s_ty IS_REDUNDANT_ROWS_INFO_NIL
    val thm01 = SPEC v thm00

    (* replace the initial condition by the (equal) nchotomy *)
    val nthm = GSYM (EQT_INTRO nchot_thm)
    val thm02 = CONV_RULE (RATOR_CONV (RAND_CONV (K nthm))) thm01
  in
    thm02
  end

  (* add a row to the info *)
  fun add_row (r, info_thm) = let
     val (p, g, r) = dest_PMATCH_ROW r
     val thm00 = FRESH_TY_VARS_RULE IS_REDUNDANT_ROWS_INFO_SNOC_PMATCH_ROW
     val thm01 = MATCH_MP thm00 info_thm
     val thm02 = ISPECL [p, g, r] thm01

     (* simplify the condition we carry around *)
     val c'_thm = let
       val pthm = ASSUME (mk_neg (mk_PMATCH_ROW_COND_EX (v, p, g)))
       val c_tm = (rand o rator o concl) info_thm
       val c'_thm0 = PMATCH_ROW_COND_EX_DISJ_WEAKEN_CONV_GENCALL rc_arg pthm c_tm handle UNCHANGED => REFL c_tm
       val c'_thm = DISCH (concl pthm) c'_thm0
     in
       c'_thm
     end

     val thm03 = MATCH_MP thm02 c'_thm
     val thm04 = CONV_RULE (RATOR_CONV (RATOR_CONV (RAND_CONV listLib.SNOC_CONV))) thm03

     (* simplify the condition appended for this row *)
     val new_cond_CONV = SIMPLIFY_REDUNDANT_ROWS_INFO_AUX rc_arg
     val thm05 = CONV_RULE (RAND_CONV (RATOR_CONV (RAND_CONV new_cond_CONV))) thm04

     val thm06 = CONV_RULE (RAND_CONV (listLib.SNOC_CONV)) thm05
  in
     thm06
  end
in
  List.foldl add_row init_info rows
end
2788
(* Redundant-rows info with extra simpset fragments [ss] and no
   callback. *)
fun COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GEN ss db col_heu =
  let val rc_arg = (ss, NONE) in
    COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu
  end

(* Redundant-rows info using the current global compile DB and the
   default column heuristic. *)
fun COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH t = let
  val db = !thePmatchCompileDB
in
  COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL ([], NONE) db colHeu_default t
end
2795
2796
2797(*
2798val t = ``case (x, z) of
2799  | (NONE, NONE) => 0
2800  | (SOME _, _) => 1
2801  | (_, SOME _) => 2
2802``
2803
2804val t = ``case (x, z) of
2805  | (NONE, NONE) => 0
2806  | (SOME _, _) => 1
2807  | (_, NONE) => 2
2808``
2809
2810val t = ``case (x, z) of
2811  | (NONE, 1) => 0
2812  | (SOME _, 2) => 1
2813  | (_, x) when x > 5 => 2
2814``
2815*)
2816
(* IS_REDUNDANT_ROWS_INFO_WEAKEN_RULE info_thm

   Weakens the list of redundancy conditions carried by [info_thm]:
   every condition that is literally T is kept, every other one is
   replaced by F. Raises UNCHANGED if no condition is T, i.e. if no row
   is known to be redundant. The side condition produced by
   REDUNDANT_ROWS_INFOS_CONJ_THM is discharged with list arithmetic. *)
fun IS_REDUNDANT_ROWS_INFO_WEAKEN_RULE info_thm = let
  (* the row conditions are the last argument of the info predicate *)
  val (row_conds, _) = listSyntax.dest_list (rand (concl info_thm))

  fun weaken c = if aconv c T then T else F
  val weak_conds = List.map weaken row_conds

  (* nothing to remove if no row is definitely redundant *)
  val _ = if List.exists (aconv T) weak_conds then () else raise UNCHANGED
  val weak_conds_tm = listSyntax.mk_list (weak_conds, bool)

  val conj_thm =
    SPECL [F, weak_conds_tm] (MATCH_MP REDUNDANT_ROWS_INFOS_CONJ_THM info_thm)

  val discharged = let
    val side_cond = rand (rator (concl conj_thm))
    val side_thm = prove (side_cond, SIMP_TAC list_ss [])
  in
    MP conj_thm side_thm
  end

  val cond_simplified =
    CONV_RULE (RATOR_CONV (RAND_CONV (REWRITE_CONV []))) discharged
in
  CONV_RULE (RAND_CONV (REWRITE_CONV [REDUNDANT_ROWS_INFOS_CONJ_REWRITE]))
    cond_simplified
end
2841
(* IS_REDUNDANT_ROWS_INFO_TO_PMATCH_EQ_THM info_thm

   Turns redundant-rows information into an equation between the
   original PMATCH and the PMATCH with the redundant rows removed.
   Propagates UNCHANGED from IS_REDUNDANT_ROWS_INFO_WEAKEN_RULE when no
   row is redundant. *)
fun IS_REDUNDANT_ROWS_INFO_TO_PMATCH_EQ_THM info_thm =
  let
    val weakened = IS_REDUNDANT_ROWS_INFO_WEAKEN_RULE info_thm
    val eq_thm = MATCH_MP REDUNDANT_ROWS_INFO_TO_PMATCH_EQ weakened
    val apply_info = RAND_CONV (PURE_REWRITE_CONV [APPLY_REDUNDANT_ROWS_INFO_THMS])
  in
    RIGHT_CONV_RULE apply_info eq_thm
  end
2850
2851
(* Removes redundant rows of the PMATCH term [t], returning |- t = t'.
   Raises UNCHANGED if no row is redundant. *)
fun PMATCH_REMOVE_REDUNDANT_CONV_GENCALL db col_heu rc_arg t =
  IS_REDUNDANT_ROWS_INFO_TO_PMATCH_EQ_THM
    (COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu t)

(* As above, parameterised by a plain simpset list. *)
fun PMATCH_REMOVE_REDUNDANT_CONV_GEN db col_heu ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_REMOVE_REDUNDANT_CONV_GENCALL db col_heu rc_arg end

(* Default instance: global compile database, default heuristic. *)
fun PMATCH_REMOVE_REDUNDANT_CONV t =
  let val db = !thePmatchCompileDB
  in PMATCH_REMOVE_REDUNDANT_CONV_GEN db colHeu_default [] t end;

(* Simpset fragment running the conversion as a reducer. *)
fun PMATCH_REMOVE_REDUNDANT_GEN_ss db col_heu ssl =
  let
    fun remove rc_arg = PMATCH_REMOVE_REDUNDANT_CONV_GENCALL db col_heu rc_arg
  in
    make_gen_conv_ss remove "PMATCH_REMOVE_REDUNDANT_REDUCER" ssl
  end

(* Default simpset fragment; the database is read on each call. *)
fun PMATCH_REMOVE_REDUNDANT_ss () =
  let val db = !thePmatchCompileDB
  in PMATCH_REMOVE_REDUNDANT_GEN_ss db colHeu_default [] end
2869
2870
(* Proves the i-th redundancy condition of [thm] with tactic [tac] and
   rewrites that condition to T. *)
fun IS_REDUNDANT_ROWS_INFO_SHOW_ROW_IS_REDUNDANT thm i tac = let
  fun prove_cond c = EQT_INTRO (prove (c, tac))
in
  CONV_RULE (RAND_CONV (list_nth_CONV i prove_cond)) thm
end

(* Sets up the i-th redundancy condition of [thm] as an interactive goal. *)
fun IS_REDUNDANT_ROWS_INFO_SHOW_ROW_IS_REDUNDANT_set_goal thm i =
  let
    val conds = #1 ((listSyntax.dest_list o rand o concl) thm)
  in
    proofManagerLib.set_goal ([], List.nth (conds, i))
  end;
2881
2882
2883(*************************************)
2884(* Exhaustiveness                    *)
2885(*************************************)
2886
(* PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg t

   Cheap exhaustiveness check for the PMATCH term [t]. Each row r is
   evaluated on the input v by deciding (r v = NONE) with rc_conv.
   - If some row certainly matches, the result is derived from that
     single row alone.
   - If every row can be decided and none matches, all the row theorems
     are used.
   - If some row is undecided and no row certainly matches, UNCHANGED is
     raised.
   Returns the result of REWRITE_CONV on PMATCH_IS_EXHAUSTIVE v rows. *)
fun PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg t = let
  val (v, rows) = dest_PMATCH t

  (* SOME (true, thm)  : row r does not match, thm proves (r v = NONE) = T
     SOME (false, thm) : row r matches,       thm proves (r v = NONE) = F
     NONE              : undecided *)
  fun classify_row r = let
    val eq_tm = mk_eq (mk_comb (r, v), optionSyntax.mk_none (type_of t))
    val eq_thm = rc_conv rc_arg eq_tm
    val result = rhs (concl eq_thm)
  in
    if same_const result T then SOME (true, eq_thm)
    else if same_const result F then SOME (false, eq_thm)
    else NONE
  end handle HOL_ERR _ => NONE

  fun scan undecided acc [] = (undecided, acc)
    | scan undecided acc (r :: rs) =
      (case classify_row r of
          NONE => scan true acc rs
        | SOME (false, thm) => (false, [thm])
        | SOME (true, thm) => scan undecided (thm :: acc) rs)

  val (undecided, rewrites) = scan false [] (List.rev rows)
  val _ = if undecided then raise UNCHANGED else ()
in
  REWRITE_CONV (PMATCH_IS_EXHAUSTIVE_REWRITES :: rewrites)
    (mk_PMATCH_IS_EXHAUSTIVE v (rand t))
end;
2913
(* Fast exhaustiveness check parameterised by a plain simpset list. *)
fun PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg end

(* Fast exhaustiveness check without additional simpsets. *)
val PMATCH_IS_EXHAUSTIVE_FAST_CHECK =
  PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GEN [];
2919
2920(*
2921val db = !thePmatchCompileDB
2922val col_heu = colHeu_default
2923val rc_arg = ([], NONE)
2924*)
2925
2926
2927(*
2928val t = ``case (x, z) of
2929  | (NONE, NONE) => 0
2930  | (_, SOME _) => 2
2931``
2932
2933val t = ``case (x, z) of
2934  | (NONE, NONE) => 0
2935  | (SOME _, _) => 1
2936  | (_, NONE) => 2
2937``
2938
2939val t = ``case (x, z) of
2940  | (NONE, 1) => 0
2941  | (SOME _, 2) => 1
2942  | (_, x) when x > 5 => 2
2943``
2944
2945val info_thm = COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH t
2946
2947*)
2948
2949
(* Extracts an exhaustiveness implication from redundant-rows info. *)
fun IS_REDUNDANT_ROWS_INFO_TO_PMATCH_IS_EXHAUSTIVE info_thm =
  MATCH_MP IS_REDUNDANT_ROWS_INFO_EXTRACT_IS_EXHAUSTIVE info_thm

(* Compilation-based exhaustiveness implication for the PMATCH term [t]. *)
fun PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t =
  IS_REDUNDANT_ROWS_INFO_TO_PMATCH_IS_EXHAUSTIVE
    (COMPUTE_REDUNDANT_ROWS_INFO_OF_PMATCH_GENCALL rc_arg db col_heu t)

fun PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GENCALL rc_arg t =
  let val db = !thePmatchCompileDB
  in PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db colHeu_default rc_arg t end

fun PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GENCALL rc_arg end

val PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK =
  PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_GEN [];
2972
2973
(* IMP_TO_EQ_THM: |- !P Q. (P ==> Q) ==> (~P ==> ~Q) ==> (Q <=> P)
   Upgrades an implication to an equation once the converse direction
   (in contrapositive form) has been proved. *)
val IMP_TO_EQ_THM = prove (``!P Q. (P ==> Q) ==> (~P ==> ~Q) ==> (Q <=> P)``, PROVE_TAC[])

(* PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db col_heu rc_arg t

   Compilation-based exhaustiveness check for the PMATCH term [t].
   thm0 is the implication  precond ==> concl  produced by
   PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN (concl is presumably
   the PMATCH_IS_EXHAUSTIVE statement for t).
   - If rc_elim_precond can discharge precond, the result is concl = T.
   - Otherwise (HOL_ERR) the result is the equation  concl <=> precond,
     obtained via IMP_TO_EQ_THM after proving  ~precond ==> ~concl  by
     rewriting with the exhaustiveness / row definitions, followed by
     rc_tac rc_arg. *)
fun PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db col_heu rc_arg t = let
  val thm0 = PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t
in
  let
    (* first try to discharge the precondition outright *)
    val thm = rc_elim_precond rc_arg thm0
  in
    EQT_INTRO thm
  end handle HOL_ERR _ => let
    (* thm1 : |- (~precond ==> ~concl) ==> (concl <=> precond) *)
    val thm1 = MATCH_MP IMP_TO_EQ_THM thm0

    (* the local name precond here is the whole antecedent
       ~precond ==> ~concl of thm1; how a failed proof is reported
       depends on prove_attempt (defined elsewhere) *)
    val (precond, _) = dest_imp_only (concl thm1)
    val pre_thm = prove_attempt (precond,
      REWRITE_TAC[PMATCH_IS_EXHAUSTIVE_REWRITES, PMATCH_ROW_EQ_NONE, PMATCH_ROW_COND_EX_def,
        DISJ_IMP_THM, GSYM LEFT_FORALL_IMP_THM] THEN
      SIMP_TAC (std_ss++pairSimps.gen_beta_ss) [PMATCH_ROW_COND_DEF_GSYM] THEN
      rc_tac rc_arg)

    val thm2 = MP thm1 pre_thm
  in
    thm2
  end
end
2998
(* Compilation-based exhaustiveness check with the global compile
   database (read when the term is supplied) and default heuristic. *)
fun PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GENCALL rc_arg t =
  let val db = !thePmatchCompileDB
  in PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db colHeu_default rc_arg t end

fun PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GENCALL rc_arg end

val PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK =
  PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_GEN [];
3008
3009
(* Decides exhaustiveness of the PMATCH term [t]. The fast check is tried
   first; if it fails or is inconclusive (UNCHANGED is turned into a
   failure by QCHANGED_CONV), the compilation-based check is used. *)
fun PMATCH_IS_EXHAUSTIVE_CHECK_FULLGEN db col_heu rc_arg t =
  let
    val fast_check = QCHANGED_CONV (PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg)
  in
    fast_check t
  end handle HOL_ERR _ =>
    PMATCH_IS_EXHAUSTIVE_COMPILE_CHECK_FULLGEN db col_heu rc_arg t;

fun PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg t =
  let val db = !thePmatchCompileDB
  in PMATCH_IS_EXHAUSTIVE_CHECK_FULLGEN db colHeu_default rc_arg t end

fun PMATCH_IS_EXHAUSTIVE_CHECK_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg end

val PMATCH_IS_EXHAUSTIVE_CHECK = PMATCH_IS_EXHAUSTIVE_CHECK_GEN []
3022
3023
local
  (* Helpers turning the equation ex_t = r produced by the fast check
     into an implication  precondition ==> ex_t. *)
  val EQ_F_ELIM = prove (``!b. F ==> b``, PROVE_TAC[])
  val EQ_T_ELIM = prove (``!b. (b = T) ==> ~F ==> b``, PROVE_TAC[])
  val EQ_O_ELIM = prove (``!b1 b2. (b1 = b2) ==> b2 ==> b1``, PROVE_TAC[])

in

(* PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t

   Returns an implication  pre ==> PMATCH-exhaustiveness statement  for
   the PMATCH term [t]. The fast check is tried first:
   - result T          gives  ~F ==> ex_t,
   - result F          gives the trivial  F ==> ex_t,
   - any other result r gives  r ==> ex_t.
   If the fast check fails (or is inconclusive), the compilation-based
   consequence check is used instead. Constants are compared with
   same_const, as in the fast check, rather than with polymorphic
   equality on terms. *)
fun PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t = let
    val thm0 = QCHANGED_CONV (PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg) t
    val (ex_t, r) = dest_eq (concl thm0)
  in
    if same_const r T then
      MP (SPEC ex_t EQ_T_ELIM) thm0
    else if same_const r F then
      SPEC ex_t EQ_F_ELIM
    else
      MP (SPEC r (SPEC ex_t EQ_O_ELIM)) thm0
  end handle HOL_ERR _ =>
    PMATCH_IS_EXHAUSTIVE_COMPILE_CONSEQ_CHECK_FULLGEN db col_heu rc_arg t;
end;
3045
3046
(* Consequence-style exhaustiveness check using the global compile
   database (read when the term is supplied) and default heuristic. *)
fun PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GENCALL rc_arg t =
  let val db = !thePmatchCompileDB
  in PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_FULLGEN db colHeu_default rc_arg t end

fun PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GEN ssl =
  let val rc_arg = (ssl, NONE)
  in PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GENCALL rc_arg end

val PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK = PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK_GEN []
3054
3055
3056(*************************************)
3057(* Nchotomy                          *)
3058(*************************************)
3059
3060(*
3061val db = !thePmatchCompileDB
3062val col_heu = colHeu_default
3063val rc_arg = ([], NONE)
3064*)
3065
3066
(* |- (~A ==> B) = (~B ==> A)
   Contraposition rewrite; used in the nchotomy computation below to turn
   an implication from the negated disjunction into one from the negated
   missing cases. *)
val neg_imp_rewr = prove (``(~A ==> B) = (~B ==> A)``,
  Cases_on `A` THEN   Cases_on `B` THEN REWRITE_TAC[]);
3069
(* nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN rc_arg db col_heu tt

   [tt] is a disjunction of PMATCH_ROW_COND_EX terms that all share the
   same input v. Using the nchotomy of their patterns, derives

     |- ~missing ==> tt

   where missing collects the cases of the nchotomy that the disjuncts
   of tt could not be shown to cover. The precondition is simplified with
   REWRITE_CONV [], so e.g. ~F becomes T when nothing is missing.
   Fails with illformed input if the disjuncts use different inputs. *)
fun nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN rc_arg db col_heu tt = let
  (* destruct everything *)
  val (v, disjs) = let
    val disjs = strip_disj tt
    val (v, _, _) = dest_PMATCH_ROW_COND_EX (hd disjs)
  in
    (v, disjs)
  end

  (* Sanity check: every disjunct must be about the same input v.
     Run for effect only, hence List.app rather than List.map. *)
  val _ = List.app (fn r => let
    val (v', _, _) = dest_PMATCH_ROW_COND_EX r
  in
    if (aconv v v') then () else failwith "illformed input"
  end) disjs

  (* derive nchot thm: the nchotomy of the patterns instantiated to v *)
  val nchot_thm = let
    val pats = List.map (#2 o dest_PMATCH_ROW_COND_EX) disjs
    val thm01 = nchotomy_of_pats_GEN db col_heu pats
    val thm02 = CONV_RULE (nchotomy2PMATCH_ROW_COND_EX_CONV_GEN
      (find_non_constructor_pattern db)) thm01
    val thm03 = ISPEC v thm02
  in
    thm03
  end

  (* prepare assumptions: each disjunct of tt is assumed false *)
  val neg_tt = mk_neg tt
  val pre_thms = let
     val thm00 = ASSUME neg_tt
     val thm01 = PURE_REWRITE_RULE [DE_MORGAN_THM] thm00
   in BODY_CONJUNCTS thm01 end


  (* apply these assumptions to the nchot_thm, weakening away the cases
     they rule out; the result is  ~tt ==> remaining cases *)
  val nchot_thm' = let
    fun step (pre_thm, thm) =
      CONV_RULE (PMATCH_ROW_COND_EX_DISJ_WEAKEN_CONV_GENCALL rc_arg pre_thm) thm
    val thm00 = foldl step nchot_thm pre_thms
    val thm01 = DISCH neg_tt thm00
    in
      thm01
    end

  (* contrapose to  ~remaining ==> tt  and simplify the precondition *)
  val nchot_thm'' = let
    val thm00 = CONV_RULE (REWR_CONV neg_imp_rewr) nchot_thm'
    val thm01 = CONV_RULE (RATOR_CONV (RAND_CONV (REWRITE_CONV []))) thm00
  in thm01 end

in
  nchot_thm''
end
3122
3123
(* SHOW_NCHOTOMY_CONSEQ_CONV_GEN ssl db col_heu tt

   [tt] has the form !x. d1 \/ ... \/ dn. Rewrites each disjunct into a
   PMATCH_ROW_COND_EX, computes the nchotomy implication for them and
   translates the result back. Returns |- !x. pre ==> body, where body is
   the original body of tt. *)
fun SHOW_NCHOTOMY_CONSEQ_CONV_GEN ssl db col_heu tt = let
  val (x, body) = dest_forall tt
  val intro_conv =
    PMATCH_ROW_COND_EX_INTRO_CONV_GEN (find_non_constructor_pattern db) x
  val body_thm = ALL_DISJ_CONV intro_conv body
  val cond_ex_tm = rhs (concl body_thm)

  val nchot_thm =
    nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN (ssl, NONE) db col_heu cond_ex_tm

  (* restore the original disjunction in the conclusion *)
  val back_thm = CONV_RULE (RAND_CONV (K (GSYM body_thm))) nchot_thm

  (* remove any remaining PMATCH_ROW_COND_EX from the precondition *)
  val elim_conv = DEPTH_CONV PMATCH_ROW_COND_EX_ELIM_CONV
  val clean_thm = CONV_RULE (RATOR_CONV (RAND_CONV elim_conv)) back_thm
in
  GEN x clean_thm
end

(* Default instance; the compile database is read when tt is supplied. *)
fun SHOW_NCHOTOMY_CONSEQ_CONV tt = let
  val db = !thePmatchCompileDB
in
  SHOW_NCHOTOMY_CONSEQ_CONV_GEN [] db colHeu_default tt
end
3142
3143
3144(*************************************)
3145(* Add missing patterns              *)
3146(*************************************)
3147
3148(*
3149val use_guards = true
3150val rc_arg = ([], NONE)
3151val db = !thePmatchCompileDB
3152val col_heu = colHeu_default
3153val t = ``case (x, y) of ([], x::xs) => x | (_, _) => 2``
3154val t = ``case (x, y) of ([], x::xs) => x``
3155*)
3156
(* PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu use_guards t

   Makes the PMATCH term [t] exhaustive by appending rows for the cases
   that could not be shown to be covered. Returns a triple
   (changed, eq_thm, exh_fun):
   - changed : true iff at least one row was added,
   - eq_thm  : |- t = t'  (REFL t if nothing was added),
   - exh_fun : thunk proving PMATCH_IS_EXHAUSTIVE for t'; the proof is
               only performed when the thunk is called.
   If the fast check already proves [t] exhaustive, (false, REFL t, _)
   is returned. Added rows have ARB as right-hand side; their guards
   come from the nchotomy if use_guards is true and are T otherwise. *)
fun PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu use_guards t =
let
  (* EQT_ELIM fails unless the fast check returned ... = T, in which case
     we fall through to the handler below *)
  val exh_thm = EQT_ELIM (PMATCH_IS_EXHAUSTIVE_FAST_CHECK_GENCALL rc_arg t)
                handle UNCHANGED => failwith "NOT EXH"
in (false, REFL t, (fn () => exh_thm)) end handle HOL_ERR _ =>
let
  val (v, rows) = dest_PMATCH t
  (* express each row as a PMATCH_ROW_COND_EX over the input v *)
  fun row_to_cond_ex r = let
    val (vs_t, p, g, _) = dest_PMATCH_ROW_ABS r
    val vs = pairSyntax.strip_pair vs_t
  in
    mk_PMATCH_ROW_COND_EX_PABS_MOVE_TO_GUARDS (find_non_constructor_pattern db) vs (v, p, g)
  end
  val disjs = List.map row_to_cond_ex rows
  val disjs_tm = list_mk_disj disjs

  (* thm_nchot : |- ~missing ==> disjs_tm *)
  val thm_nchot = nchotomy_PMATCH_ROW_COND_EX_CONSEQ_CONV_GEN rc_arg db col_heu disjs_tm

  (* the disjuncts of missing; empty if the precondition is not a
     negation (e.g. it was simplified to T) *)
  val missing_list = let
    val pre = fst (dest_imp (concl thm_nchot))
    val disj = dest_neg pre
  in
    strip_disj disj
  end handle HOL_ERR _ => []

  (* append one row  missing pattern => ARB  to the end of the rows of
     the right-hand side of thm, via PMATCH_REMOVE_ARB *)
  fun add_missing_pat (missing_t, thm) = let
    val (_, vs, p_t0, g_t0) = dest_PMATCH_ROW_COND_EX_ABS missing_t
    val g_t1 = if use_guards then g_t0 else T
    val g_t = pairSyntax.mk_pabs (vs, g_t1)
    val p_t = pairSyntax.mk_pabs (vs, p_t0)
    val r_t = pairSyntax.mk_pabs (vs, mk_arb (type_of t))

    val thm00 = FRESH_TY_VARS_RULE PMATCH_REMOVE_ARB
    val rows_t = (rand o rhs o concl) thm
    val thm01 = ISPECL [p_t, g_t, r_t, v, rows_t] thm00
    val thm02 = rc_elim_precond rc_arg thm01
    val thm03 = GSYM thm02
    val thm04 = RIGHT_CONV_RULE (RAND_CONV (listLib.SNOC_CONV)) thm03
  in
    TRANS thm thm04
  end

  val thm_expand = foldl add_missing_pat (REFL t) missing_list

  (* set_goal ([], mk_PMATCH_IS_EXHAUSTIVE v (rand (rhs (concl thm_expand)))) *)
  (* exhaustiveness of the completed PMATCH, proved lazily from thm_nchot *)
  fun exh_thm () = prove (mk_PMATCH_IS_EXHAUSTIVE v (rand (rhs (concl thm_expand))),
    ASSUME_TAC (thm_nchot) THEN
    PURE_REWRITE_TAC [PMATCH_IS_EXHAUSTIVE_REWRITES, PMATCH_ROW_NEQ_NONE] THEN
    PROVE_TAC[])
in
  (not (List.null missing_list), thm_expand, exh_thm)
end
3209
(* Conversion completing a PMATCH with its missing rows; raises UNCHANGED
   when no row needs to be added. *)
fun PMATCH_COMPLETE_CONV_GENCALL rc_arg db col_heu use_guards t =
  case PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu use_guards t of
      (true, eq_thm, _) => eq_thm
    | (false, _, _) => raise UNCHANGED;

fun PMATCH_COMPLETE_CONV_GEN ssl =
    let val rc_arg = (ssl, NONE)
    in PMATCH_COMPLETE_CONV_GENCALL rc_arg end;

fun PMATCH_COMPLETE_CONV use_guards =
    let val db = !thePmatchCompileDB
    in PMATCH_COMPLETE_CONV_GEN [] db colHeu_default use_guards end;

(* Simpset fragment running the completion conversion as a reducer. *)
fun PMATCH_COMPLETE_GEN_ss ssl db colHeu use_guards = let
  fun complete rc_arg =
    PMATCH_COMPLETE_CONV_GENCALL rc_arg db colHeu use_guards
in
  make_gen_conv_ss complete "PMATCH_COMPLETE_REDUCER" ssl
end;

fun PMATCH_COMPLETE_ss use_guards =
  let val db = !thePmatchCompileDB
  in PMATCH_COMPLETE_GEN_ss [] db colHeu_default use_guards end;


(* Completion together with an exhaustiveness proof of the result; the
   first component is NONE if nothing had to be added. *)
fun PMATCH_COMPLETE_CONV_GEN_WITH_EXH_PROOF ssl db col_heu use_guards t =
    let
      val (changed, eq_thm, mk_exh) =
        PMATCH_COMPLETE_CONV_GENCALL_AUX (ssl, NONE) db col_heu use_guards t
      val conv_res = if changed then SOME eq_thm else NONE
    in
      (conv_res, mk_exh ())
    end

fun PMATCH_COMPLETE_CONV_WITH_EXH_PROOF use_guards =
    let val db = !thePmatchCompileDB
    in PMATCH_COMPLETE_CONV_GEN_WITH_EXH_PROOF [] db colHeu_default use_guards end;
3236
3237
3238
3239(***********************************************)
3240(* Lifting to lowest boolean level             *)
3241(***********************************************)
3242
(* One can replace pattern matches with a big conjunction.
   Each row becomes one conjunct. Since the conjunction is
   of type bool, this needs to be done at a boolean level.
   So we can replace an arbitrary term

   P (PMATCH i rows) with

   (row_cond 1 i -> P (row_rhs 1)) /\
   ...
   (row_cond n i -> P (row_rhs n))

   The row condition expresses that the pattern does not overlap
   with any previous pattern and that the guard holds.

   The most common use case of lifting is function definitions
   of the form

   f x = PMATCH x rows

   which can be turned into a conjunction of top-level
   rewrite rules for the function f.
*)
3265
3266(*
3267
3268val tm = ``(P2 /\ Q ==> (
3269  (case xx of
3270    | (x, y::ys) => (x + y)
3271    | (0, []) => 9
3272    | (x, []) when x > 3 => x
3273    | (x, []) => 0) = 5))``
3274
3275val tm = ``
3276  (case xx of
3277    | (x, y::ys) => (x + y)
3278    | (0, []) => 9
3279    | (x, []) when x > 3 => x
3280    | (x, []) => 0) = 5``
3281
3282val _ = ENABLE_PMATCH_CASES()
3283val OPT_PAIR_def = TotalDefn.Define `OPT_PAIR xy = case xy of
3284 | (NONE, _) => NONE
3285 | (_, NONE) => NONE
3286 | (SOME x, SOME y) => SOME (x, y)`
3287val thm = OPT_PAIR_def
3288val tm = concl (hd (BODY_CONJUNCTS thm))
3289val force_minimal = false
3290val rc_arg = ([], NONE)
3291val try_exh = true
3292*)
3293
3294
local
(* |- (P ==> (X <=> Y)) <=> ((P ==> X) <=> (P ==> Y)) *)
val IMP_AUX_THM = prove (``(P ==> (X <=> Y)) <=>
   ((P ==> X) <=> (P ==> Y))``, PROVE_TAC[])
in
(* SIMPLE_IMP_COND_REWRITE_CONV thm tt

   [tt] must be an implication  pre ==> post. [thm] is an implication
   whose antecedent is matched against pre (MATCH_MP); the resulting
   rewrite is used to rewrite post under the assumption pre. Returns

     |- (pre ==> post) <=> (pre ==> post')

   dest_imp_only is used on purpose: dest_imp would also accept a
   negation ~pre as pre ==> F, and the returned theorem would then have
   pre ==> F rather than the input ~pre as its left-hand side, violating
   the conversion contract. Fails with HOL_ERR if tt is not an
   implication or thm does not match pre. *)
fun SIMPLE_IMP_COND_REWRITE_CONV thm tt = let
  val (pre, post) = dest_imp_only tt
  val pre_thm = ASSUME pre
  val rewr_thm = MATCH_MP thm pre_thm
  val thm0 = REWRITE_CONV [rewr_thm] post
  val thm1 = DISCH pre thm0
  val thm2 = CONV_RULE (REWR_CONV IMP_AUX_THM) thm1
in
  thm2
end
end;
3310
(* rename_uscore_vars ren avoid vs

   For every variable in [vs] whose name starts with an underscore,
   creates a fresh variable named v (made distinct from [avoid] by
   variant) of the same type, prepends the substitution v |-> v' to
   [ren] and adds v' to [avoid]. Non-variables and other variables are
   skipped. Returns the accumulated substitution.

   The exception handler only guards the inspection of a single element
   (so the recursion stays tail-recursive), and String.isPrefix is used
   so that a variable with an empty name is skipped instead of raising
   Subscript. *)
fun rename_uscore_vars ren avoid [] = ren
  | rename_uscore_vars ren avoid (v::vs) =
    let
      (* SOME v' if v is an underscore variable to be renamed to v' *)
      val fresh_opt =
        let
          val (v_n, v_ty) = dest_var v
        in
          if String.isPrefix "_" v_n then
            SOME (variant avoid (mk_var ("v", v_ty)))
          else NONE
        end handle HOL_ERR _ => NONE
    in
      case fresh_opt of
          NONE => rename_uscore_vars ren avoid vs
        | SOME v' => rename_uscore_vars ((v |-> v')::ren) (v'::avoid) vs
    end
3320
3321
3322
(* PMATCH_LIFT_BOOL_CONV_GENCALL force_minimal try_exh rc_arg tm

   Lifts the first PMATCH p_tm found in the boolean term tm (by
   find_term) to the boolean level: tm, viewed as P p_tm, is rewritten
   via PMATCH_EXPAND_PRED_THM into a conjunction which (presumably) has
   one conjunct per row of the form  !vars. row condition ==> P rhs.
   - force_minimal: if true, only fire when no proper boolean subterm of
     tm contains a PMATCH; otherwise UNCHANGED is raised.
   - try_exh: if true, additionally try to prove p_tm exhaustive and
     rewrite the result with that fact; failure of this step is ignored.
   Raises UNCHANGED if tm is not of type bool; fails with HOL_ERR if tm
   contains no PMATCH. The final assert checks that the result's
   left-hand side is tm. *)
fun PMATCH_LIFT_BOOL_CONV_GENCALL force_minimal try_exh rc_arg tm = let
  (* check whether we should really process tm *)
  val _ = if type_of tm = bool then () else raise UNCHANGED
  val p_tm = find_term is_PMATCH tm
  fun has_subterm p t = (find_term p t; true) handle HOL_ERR _ => false

  (* trivially true unless force_minimal is set *)
  val is_minimal = not force_minimal orelse not (has_subterm (fn t =>
    (not (aconv t tm)) andalso
    (type_of t = bool) andalso
    (has_subterm is_PMATCH t)) tm)
  val _ = if is_minimal then () else raise UNCHANGED

  (* prepare tm: P_tm = \v. tm[v/p_tm], P_v a placeholder for it *)
  val v = genvar (type_of p_tm)
  val P_tm = mk_abs (v, subst [p_tm |-> v] tm)
  val P_v = genvar (type_of P_tm)

  (* do real work: expand P_v p_tm' (p_tm with genvars introduced by
     PMATCH_INTRO_GENVARS) and instantiate the genvars back *)
  val thm0 = let
    val (p_tm', genvars_elim_s) = PMATCH_INTRO_GENVARS p_tm
    val t0 = (mk_comb (P_v, p_tm'))
    val c1 = SIMP_CONV std_ss [PMATCH_EXPAND_PRED_THM,
      PMATCH_EXPAND_PRED_def, PMATCH_ROW_NEQ_NONE,
      EVERY_DEF, PMATCH_ROW_EVAL_COND_EX, REVERSE_REV, REV_DEF]
    val c2 = REWRITE_CONV [PMATCH_ROW_COND_EX_def,
      PULL_EXISTS]

    val thm00 = (c1 THENC c2) t0
    val thm01 = INST genvars_elim_s thm00
  in
    thm01
  end

  (* Elim choice: for each row try to rewrite with
     PMATCH_COND_SELECT_UNIQUE under the row condition; rows for which
     this fails leave the theorem unchanged *)
  val thm1 = let
    val (v, rows) = dest_PMATCH p_tm
    fun process_row (r, thm') = let
      val (pt, gt, rt) = dest_PMATCH_ROW r
      val thm00 = ISPECL [pt, gt, v] (FRESH_TY_VARS_RULE PMATCH_COND_SELECT_UNIQUE)
      val thm01 = rc_elim_precond rc_arg thm00
      val thm'' = CONV_RULE (DEPTH_CONV (SIMPLE_IMP_COND_REWRITE_CONV thm01)) thm'
    in
      thm''
    end handle HOL_ERR _ => thm'
  in
    foldl process_row thm0 rows
  end

  (* get rid of exhaustiveness check (optional, best effort) *)
  val thm2 = let
    val _ = if try_exh then () else failwith "skip"
    val thm_ex = PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg p_tm
    val thm2 = CONV_RULE (RHS_CONV (REWRITE_CONV [thm_ex])) thm1
  in
    thm2
  end handle HOL_ERR _ => thm1
           | UNCHANGED => thm1


  (* Use the right variable names and simplify; applied to every
     conjunct of the right-hand side *)
  val thm3 = let
    fun special_CONV tt = let
      (* the pattern's bound-variable tuple and the row condition *)
      val (vars_tm0, row) = let
        val (_, tt0) = dest_forall tt
        val (tt1, _) = dest_imp_only tt0
        val (tt2, _, _, _) = dest_PMATCH_ROW_COND tt1
      in (fst (pairSyntax.dest_pabs tt2), tt1) end

      (* rename underscore variables so they get readable names *)
      val vars_tm = let
        val to_ren = free_vars vars_tm0
        val avoid = free_vars row @ to_ren
        val ren = rename_uscore_vars [] avoid to_ren
      in
        subst ren vars_tm0
      end

     val intro_marker = TRY_CONV (QUANT_CONV (RAND_CONV (RATOR_CONV (RAND_CONV markerLib.stmark_term))))

      val elim_COND_CONV =
        QUANT_CONV (RATOR_CONV (RAND_CONV (REWR_CONV PMATCH_ROW_COND_DEF_GSYM)))

      val intro_CONV = RAND_CONV (pairTools.PABS_INTRO_CONV vars_tm)
      val elim_CONV = TRY_CONV (pairTools.ELIM_TUPLED_QUANT_CONV)
      val eval_preconds = STRIP_QUANT_CONV (RAND_CONV (fn t => let
         val _ = dest_imp_only t
       in RATOR_CONV (RAND_CONV (rc_conv rc_arg)) t end))

      val simp_CONV = TRY_CONV (SIMP_CONV std_ss [])

      val elim_marker =
        (REWR_CONV markerTheory.stmarker_def) THENC
        TRY_CONV (rc_conv rc_arg)
    in
      EVERY_CONV [
        intro_marker,
        elim_COND_CONV,
        intro_CONV, elim_CONV,
        simp_CONV,
        DEPTH_CONV elim_marker,
        TRY_CONV (REWRITE_CONV [])
        ] tt
    end
  in
    CONV_RULE (RHS_CONV (ALL_CONJ_CONV special_CONV)) thm2
  end


  (* restore original predicate: instantiate P_v and beta-reduce *)
  val thm4 = let
    val thm00 = INST [P_v |-> P_tm] thm3
    val thm01 = CONV_RULE (LHS_CONV BETA_CONV) thm00
    val thm02 = CONV_RULE (RHS_CONV (DEPTH_CONV BETA_CONV)) thm01
    val _ = assert (fn thm => aconv (lhs (concl thm)) tm) thm02
  in
    thm02
  end
in
  thm4
end
3442
(* Boolean lifting restricted to minimal boolean subterms. *)
fun PMATCH_LIFT_BOOL_CONV_GEN ssl try_exh =
  let val rc_arg = (ssl, NONE)
  in PMATCH_LIFT_BOOL_CONV_GENCALL true try_exh rc_arg end

val PMATCH_LIFT_BOOL_CONV = PMATCH_LIFT_BOOL_CONV_GEN [];

(* Simpset fragment running the boolean lifting as a reducer. *)
fun PMATCH_LIFT_BOOL_GEN_ss ssl try_exh = let
  fun lift rc_arg = PMATCH_LIFT_BOOL_CONV_GENCALL true try_exh rc_arg
in
  make_gen_conv_ss lift "PMATCH_LIFT_BOOL_REDUCER" ssl
end

val PMATCH_LIFT_BOOL_ss = PMATCH_LIFT_BOOL_GEN_ss []
3451
3452
(* PMATCH_TO_TOP_RULE_SINGLE ssl thm

   Turns a single (equational) theorem containing a PMATCH into a
   conjunction of row-wise theorems: the PMATCH is lifted to the boolean
   level (without the minimality restriction and without the
   exhaustiveness step), the quantifiers are distributed over the
   conjunction, and the last conjunct is dropped. *)
fun PMATCH_TO_TOP_RULE_SINGLE ssl thm = let
  val closed_thm = GEN_ALL thm

  val lift_conv = PMATCH_LIFT_BOOL_CONV_GENCALL false false (ssl, NONE)
  val lifted_thm = CONV_RULE (STRIP_QUANT_CONV lift_conv) closed_thm

  (* wrap the conclusion of each conjunct in stmarker; presumably so the
     simplification below (with a REFL congruence for stmarker) only
     splits the conjunction and leaves the conclusions alone *)
  val mark_conv = STRIP_QUANT_CONV (TRY_CONV (RAND_CONV markerLib.stmark_term))
  val marked_thm =
    CONV_RULE (STRIP_QUANT_CONV (EVERY_CONJ_CONV mark_conv)) lifted_thm

  val split_thm = SIMP_RULE std_ss [FORALL_AND_THM,
    Cong (REFL ``stmarker (t:'a)``)] marked_thm
  val plain_thm = PURE_REWRITE_RULE [markerTheory.stmarker_def] split_thm
in
  (* the last conjunct is dropped; presumably it covers inputs not
     matched by any row (TODO confirm) *)
  LIST_CONJ (butlast (CONJUNCTS plain_thm))
end
3467
(* Applies PMATCH_TO_TOP_RULE_SINGLE to every conjunct of [thm] and
   flattens the resulting nested conjunction. *)
fun PMATCH_TO_TOP_RULE_GEN ssl thm =
  let
    val per_conj = List.map (PMATCH_TO_TOP_RULE_SINGLE ssl) (BODY_CONJUNCTS thm)
  in
    CONV_RULE unwindLib.FLATTEN_CONJ_CONV (LIST_CONJ per_conj)
  end

fun PMATCH_TO_TOP_RULE thm =
  let val no_extra_ss = []
  in PMATCH_TO_TOP_RULE_GEN no_extra_ss thm end;
3478
3479
3480(*************************************)
3481(* Lifting                           *)
3482(*************************************)
3483
3484(*
3485val tm = ``\y. SUC (SUC
3486  (case y of
3487    | (x, y::ys) => (x + y)
3488    | (0, []) => 0)) +
3489  (case xx of
3490    | (x, y::ys) => (x + y)
3491    | (0, []) => 0)``
3492val tm = p_tm
3493*)
3494
(* PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu tm

   Lifts a PMATCH subterm of [tm] to the top of [tm]:
   - finds a PMATCH subterm p_tm none of whose free variables is bound
     at its position in tm,
   - abstracts the context as f_tm = \x. tm[x/p_tm],
   - completes p_tm with its missing rows (guards kept),
   - pushes f_tm into every row via PMATCH_LIFT_THM and simplifies.
   Returns (|- tm = PMATCH v rows', exh) where exh () proves
   exhaustiveness of the new PMATCH. Fails if tm itself is a PMATCH or
   contains no liftable PMATCH.
   Both library theorems instantiated with ISPECL go through
   FRESH_TY_VARS_RULE, as elsewhere in this file, so their type
   variables cannot clash with those of the user's term. *)
fun PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu tm = let
  (* check whether we should really process tm *)
  val _ = if (is_PMATCH tm) then failwith "PMATCH_LIFT_CONV_GENCALL_AUX: nothing to do" else ()

  (* search subterm to lift: a PMATCH whose free variables are not bound
     at the place where it occurs *)
  fun search_pred (bvs, tt) = if is_PMATCH tt andalso
    HOLset.isEmpty (HOLset.intersection (HOLset.fromList Term.compare bvs, FVL [tt] empty_tmset)) then SOME tt else NONE
  val p_tm = case gen_find_term search_pred tm of SOME p_tm => p_tm | NONE => failwith "no_case"

  (* Abstract context with f_tm, so that tm is f_tm p_tm up to beta *)
  val f_tm = let
    val nv = genvar (type_of p_tm)
    val tm' = subst [p_tm |-> nv] tm
  in
    mk_abs (nv, tm')
  end
  (* p_thm : |- p_tm = p_tm' (completed); exh_thm () proves exhaustiveness *)
  val (_, p_thm, exh_thm) = PMATCH_COMPLETE_CONV_GENCALL_AUX rc_arg db col_heu true p_tm

  (* Intro f_tm: |- tm = f_tm p_tm' *)
  val thm0 = let
    val f_tm_thm = GSYM (BETA_CONV (mk_comb (f_tm, p_tm)))
    val f_tm_thm' = CONV_RULE (RHS_CONV (RAND_CONV (K p_thm))) f_tm_thm
  in f_tm_thm' end

  (* Apply lifting thm: f_tm (PMATCH v rows) = PMATCH v (MAP ... rows) *)
  val (v, rows_tm, thm1) = let
    val p_tm' = rhs (concl p_thm)
    val (v, _) = dest_PMATCH p_tm'
    val rows_tm = rand p_tm'
    val thm10 = ISPECL [f_tm, v, rows_tm] (FRESH_TY_VARS_RULE PMATCH_LIFT_THM)
    val thm11 = MP thm10 (exh_thm())
  in (v, rows_tm, thm11) end

  (* Simplify the lifted rows *)
  val thm2 = let
    fun c3 tt = let
      val (vt, _) = pairSyntax.dest_pabs (rator (rand (snd (dest_abs tt))))
      val c30 = (pairTools.PABS_INTRO_CONV vt)
      val c31 = PairRules.PABS_CONV (RAND_CONV PairRules.PBETA_CONV)
      val c32 = PairRules.PABS_CONV BETA_CONV
    in
      (c30 THENC (TRY_CONV c31) THENC c32) tt
    end
    val c2 = REWR_CONV PMATCH_ROW_LIFT_THM THENC (RAND_CONV c3)
    val c = listLib.MAP_CONV c2
    val thm2 = c (rand (rhs (concl thm1)))
  in
    thm2
  end

  val thm_lift = CONV_RULE (RHS_CONV (RAND_CONV (K thm2))) (TRANS thm0 thm1)

  (* construct exhaustiveness result for the lifted PMATCH (lazily) *)
  fun exh_thm' () = let
    val exh_thm = exh_thm ()
    val thm00 = ISPECL [f_tm, v, rows_tm] (FRESH_TY_VARS_RULE PMATCH_IS_EXHAUSTIVE_LIFT)
    val thm01 = MP thm00 exh_thm
    val thm02 = CONV_RULE (RAND_CONV (K thm2)) thm01
  in thm02 end
in
  (thm_lift, exh_thm')
end;
3557
3558
(* Lifting conversion, discarding the exhaustiveness proof. *)
fun PMATCH_LIFT_CONV_GENCALL rc_arg db col_heu t =
  #1 (PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu t);

(* Lifting conversion together with the (forced) exhaustiveness proof. *)
fun PMATCH_LIFT_CONV_GENCALL_WITH_EXH_PROOF rc_arg db col_heu t =
  case PMATCH_LIFT_CONV_GENCALL_AUX rc_arg db col_heu t of
      (eq_thm, mk_exh) => (eq_thm, mk_exh ());

fun PMATCH_LIFT_CONV_GEN ssl =
    let val rc_arg = (ssl, NONE)
    in PMATCH_LIFT_CONV_GENCALL rc_arg end;

fun PMATCH_LIFT_CONV t =
    let val db = !thePmatchCompileDB
    in PMATCH_LIFT_CONV_GEN [] db colHeu_default t end;

fun PMATCH_LIFT_CONV_GEN_WITH_EXH_PROOF ssl =
    let val rc_arg = (ssl, NONE)
    in PMATCH_LIFT_CONV_GENCALL_WITH_EXH_PROOF rc_arg end;

fun PMATCH_LIFT_CONV_WITH_EXH_PROOF t =
    let val db = !thePmatchCompileDB
    in PMATCH_LIFT_CONV_GEN_WITH_EXH_PROOF [] db colHeu_default t end;
3580
3581
3582(*************************************)
3583(* FLATTENING                        *)
3584(*************************************)
3585
3586(*
3587val do_lift = false
3588val use_guards = true
3589val rc_arg = ([], NONE)
3590val db = !thePmatchCompileDB
3591val col_heu = colHeu_default
3592
3593
3594val tm = ``case (x, y) of ([], x::xs) => (
3595           case xs of [] => 0 | _ => 5) | (_, []) => 1 ``
3596
3597val tm = ``case (x, y) of (x::xs, []) => 2 | ([], x::xs) => (
3598           SUC (case xs of [] => x | _ => HD xs)) | (_, []) => 1 ``
3599*)
3600
(* PMATCH_FLATTEN_CONV_GENCALL_AUX rc_arg db col_heu do_lift tm

   Flattens one nested PMATCH inside the PMATCH term [tm]. Rows are tried
   in order (tryfind); for row i the right-hand side is
   - lifted to a PMATCH first if do_lift is set and it is not one yet,
   - otherwise completed with its missing rows (it must be a PMATCH),
   its input is extended to the row's bound variables, and the inner rows
   are spliced into the outer PMATCH in place of row i using
   PMATCH_FLATTEN_THM. Returns |- tm = tm' for the first row where this
   succeeds; fails if no row can be flattened. *)
fun PMATCH_FLATTEN_CONV_GENCALL_AUX rc_arg db col_heu do_lift tm = let
  val (v, rows) = dest_PMATCH tm

  (* Try to flatten row no i *)
  fun try_row i = let
    (* rows before i, row i, rows after i *)
    val (rows_b, row, rows_a) = extract_element rows i
    val (pt, gt, rt0) = dest_PMATCH_ROW row
    val (vs, rt) = pairSyntax.dest_pabs rt0

    (* lift the rhs of row i to be PMATCH expression *)
    val thm0 = if do_lift andalso not (is_PMATCH rt) then
      PMATCH_LIFT_CONV_GENCALL rc_arg db col_heu rt
    else
      PMATCH_COMPLETE_CONV_GENCALL rc_arg db col_heu true rt handle UNCHANGED => REFL rt

    (* extend the input to match the output of the outer PMATCH exactly *)
    val thm1 = let
      val thm1a = PMATCH_EXTEND_INPUT_CONV_GENCALL rc_arg vs (rhs (concl thm0))
      val thm1 = TRANS thm0 thm1a
    in thm1 end


    (* Apply the flatten theorem, discard preconditions and show that rhs equals input *)
    val thm2 = let
      val rt' = rhs (concl thm1)
      val (v', rows') = dest_PMATCH rt'
      val rows'' = map (fn t => pairSyntax.mk_pabs (v', t)) rows'


      (* instantiate thm *)
      val thm2a = let
        val thm20 = FRESH_TY_VARS_RULE PMATCH_FLATTEN_THM
        val thm20 = ISPEC v thm20
        val thm20 = ISPEC pt thm20
        val thm20 = ISPEC gt thm20
        val thm20 = ISPEC (listSyntax.mk_list(rows_b, type_of row)) thm20
        val thm20 = ISPEC (listSyntax.mk_list(rows_a, type_of row)) thm20
        val thm20 = ISPEC (listSyntax.mk_list(rows'', type_of (hd rows''))) thm20
        val thm21 = CONV_RULE (RATOR_CONV (RAND_CONV (RAND_CONV (pairTools.PABS_INTRO_CONV v')))) thm20
        val c = RATOR_CONV (RAND_CONV (RAND_CONV (pairTools.PABS_INTRO_CONV v')))
        val thm22 = CONV_RULE (RAND_CONV (LHS_CONV (RAND_CONV (RAND_CONV c)))) thm21

      in thm22 end

      (* simplify MAP (\x. r x) rows'' = rows' *)
      val thm2b = let
        val map_tm = rand (snd (pairSyntax.dest_pforall (fst (dest_imp (concl thm2a)))))
        val map_tm_eq = mk_eq (map_tm, listSyntax.mk_list (rows', type_of (hd rows')))
        val map_thm = prove (map_tm_eq, SIMP_TAC list_ss [])

        val thm2b = CONV_RULE (DEPTH_CONV (REWR_CONV map_thm)) thm2a
      in thm2b end

      (* elim precond, using exhaustiveness of the inner PMATCH *)
      val thm2c = let
        val exh_thm = PMATCH_IS_EXHAUSTIVE_CHECK_GENCALL rc_arg rt'
        val (pre, _) = dest_imp (concl thm2b)
        val pre_thm = prove (pre, SIMP_TAC std_ss [exh_thm, GSYM pairTheory.PFORALL_THM])
        val thm2c = MP thm2b pre_thm
      in thm2c end

      (* use thm1 on lhs, then show that the lhs is tm *)
      val thm2d = let
        val c = RATOR_CONV (RAND_CONV (RAND_CONV (PairRules.PABS_CONV (K (GSYM thm1)))))
        val thm20 = CONV_RULE (LHS_CONV (RAND_CONV (RAND_CONV c))) thm2c
        val l_eq = mk_eq (tm, lhs (concl thm20))
        val l_thm = prove (l_eq, SIMP_TAC list_ss [])
        val thm2d = TRANS l_thm thm20
      in thm2d end
    in
      thm2d
    end

    (* EVALUATE MAP PMATCH_FLATTEN_FUN on rhs *)
    val thm3 = let
      val flatten_thm = let
        val thm00 = FRESH_TY_VARS_RULE PMATCH_FLATTEN_FUN_PMATCH_ROW
        val thm01 = ISPEC pt thm00
        val thm02 = rc_elim_precond rc_arg thm01
        val thm03 = ISPEC gt thm02
        val c = pairTools.PABS_INTRO_CONV vs
        val thm04 = CONV_RULE (STRIP_QUANT_CONV (LHS_CONV (RAND_CONV c))) thm03
      in thm04 end

      (* rewrite one inner row into the corresponding flattened row *)
      fun flatten_fun_conv tt = let
        val (_, row_d) = pairSyntax.dest_pabs (rand tt)
        val (pt_d, gt_d, rt_d) = dest_PMATCH_ROW row_d
        val thm00 = ISPECL [pt_d, pairSyntax.mk_pabs(vs, gt_d), pairSyntax.mk_pabs(vs, rt_d)] flatten_thm
        val eq_tm = mk_eq (tt, lhs (concl thm00))
        val eq_thm = prove (eq_tm, SIMP_TAC (std_ss++pairSimps.gen_beta_ss) [])

        val thm01 = TRANS eq_thm thm00
        val (vs', _) = pairSyntax.dest_pabs pt_d
        val thm02 = CONV_RULE (RHS_CONV (PMATCH_ROW_PABS_INTRO_CONV vs')) thm01

        val thm03 = CONV_RULE (RHS_CONV (DEPTH_CONV (pairLib.PAIRED_BETA_CONV ORELSEC BETA_CONV))) thm02
        val thm04 = CONV_RULE (RHS_CONV (REWRITE_CONV [])) thm03
      in thm04 end

      val c = BETA_CONV THENC flatten_fun_conv
      val thm30 = CONV_RULE (RHS_CONV (RAND_CONV (RATOR_CONV (
             RAND_CONV (RAND_CONV (listLib.MAP_CONV c)))))) thm2

      (* concatenate rows_b, the flattened rows and rows_a *)
      val thm31 = CONV_RULE (RHS_CONV (RAND_CONV (RATOR_CONV (RAND_CONV
        listLib.APPEND_CONV)))) thm30

      val thm32 = CONV_RULE (RHS_CONV (RAND_CONV
        listLib.APPEND_CONV)) thm31
    in thm32 end

    (* Fix wildcards *)
    val thm4 = CONV_RULE (RHS_CONV PMATCH_INTRO_WILDCARDS_CONV) thm3
  in
    thm4
  end

  val row_index_l = Lib.upto 0 (length rows - 1)
in
  tryfind try_row row_index_l
end
3721
3722
(* Flatten nested PMATCH expressions as far as possible: the single
   flattening step PMATCH_FLATTEN_CONV_GENCALL_AUX (which tries the rows
   of the PMATCH in turn) is applied repeatedly with REPEATC until it
   fails. *)
fun PMATCH_FLATTEN_CONV_GENCALL rc_arg db col_heu do_lift = let
  val single_step =
      PMATCH_FLATTEN_CONV_GENCALL_AUX rc_arg db col_heu do_lift
in
  REPEATC single_step
end
3725
(* Stand-alone version of PMATCH_FLATTEN_CONV_GENCALL: the simpset
   fragments [ssl] are used and no simplifier-provided conversion is
   passed (the second component of the rc_arg pair is NONE). *)
fun PMATCH_FLATTEN_CONV_GEN ssl db col_heu do_lift =
    PMATCH_FLATTEN_CONV_GENCALL (ssl, NONE) db col_heu do_lift;
3728
(* Flatten with no extra simpset fragments, the default column heuristic
   and the global compile database. The database reference is read when
   this function is applied to [do_lift]. *)
fun PMATCH_FLATTEN_CONV do_lift = let
  val current_db = !thePmatchCompileDB
in
  PMATCH_FLATTEN_CONV_GEN [] current_db colHeu_default do_lift
end;
3731
(* Simpset fragment running the PMATCH flattening as a reducer named
   "PMATCH_FLATTEN_REDUCER". Via make_gen_conv_ss, each invocation receives
   ([ssl], SOME c) as its rc_arg, where c is the simplifier's conversion
   for the current context stack. *)
fun PMATCH_FLATTEN_GEN_ss ssl db col_heu do_lift = let
  fun flatten_with rc_arg =
      PMATCH_FLATTEN_CONV_GENCALL rc_arg db col_heu do_lift
in
  make_gen_conv_ss flatten_with "PMATCH_FLATTEN_REDUCER" ssl
end
3735
(* Default flattening simpset fragment: no extra fragments, the global
   compile database (read when applied to [do_lift]) and the default
   column heuristic. *)
fun PMATCH_FLATTEN_ss do_lift = let
  val current_db = !thePmatchCompileDB
in
  PMATCH_FLATTEN_GEN_ss [] current_db colHeu_default do_lift
end;
3738
3739
3740(*************************************)
3741(* Analyse PMATCH expressions to     *)
3742(* check whether they can be         *)
3743(* translated to ML or OCAML         *)
3744(*************************************)
3745
(* Result of analysing a PMATCH term with [analyse_pmatch]. Row numbers
   are 0-based positions in the PMATCH's row list. The per-row lists are
   built by consing while rows are processed in order, so they hold the
   latest row first. *)
type pmatch_info = {
  (* false in [init_pmatch_info]; set to true by [analyse_pmatch] once the
     term has been destructed as a PMATCH and analysed without HOL_ERR *)
  pmi_is_well_formed            : bool,
  (* rows recorded via [pmi_add_ill_formed_row] *)
  pmi_ill_formed_rows           : int list,
  (* (row, vars): variables occurring in the row's pattern that are not
     bound by the row *)
  pmi_has_free_pat_vars         : (int * term list) list,
  (* (row, vars): variables bound by the row that do not occur in its
     pattern *)
  pmi_has_unused_pat_vars       : (int * term list) list,
  (* (row, vars): bound variables occurring more than once in the pattern *)
  pmi_has_double_bound_pat_vars : (int * term list) list,
  (* rows whose guard is not syntactically T *)
  pmi_has_guards                : int list,
  (* rows whose pattern contains a lambda abstraction *)
  pmi_has_lambda_in_pat         : int list,
  (* (row, consts): constants in the pattern that are neither constructors
     (TypeBase.is_constructor) nor literals *)
  pmi_has_non_contr_in_pat      : (int * term list) list,
  (* theorem from PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK, if requested and the
     check succeeded; an antecedent of ``~F`` means exhaustiveness was
     proven (see [is_proven_exhaustive_pmatch]) *)
  pmi_exhaustiveness_cond       : thm option
}
3757
(* [is_proven_exhaustive_pmatch pmi] holds iff [pmi] carries an
   exhaustiveness theorem whose conclusion is an implication with
   antecedent ``~F``. Returns false when there is no theorem or its
   conclusion is not an implication. *)
fun is_proven_exhaustive_pmatch (pmi : pmatch_info) =
  case #pmi_exhaustiveness_cond pmi of
      NONE => false
    | SOME exh_thm =>
        (let
           val (antecedent, _) = dest_imp_only (concl exh_thm)
         in
           aconv antecedent ``~F``
         end handle HOL_ERR _ => false)
3767
(* [get_possibly_missing_patterns pmi] lists the patterns that the
   exhaustiveness check could not rule out. The result is:
   - NONE when [pmi] has no exhaustiveness theorem, or destructing it
     raises HOL_ERR;
   - SOME [] when the antecedent is ``~F`` (proven exhaustive);
   - otherwise SOME with one (vars, pattern, guard) triple per disjunct of
     the negated antecedent. *)
fun get_possibly_missing_patterns (pmi : pmatch_info) = let
  fun triple_of_disjunct disj = let
    val (_, vs, pat, guard) = dest_PMATCH_ROW_COND_EX_ABS disj
  in
    (vs, pat, guard)
  end

  fun missing_of_thm exh_thm = let
    val (antecedent, _) = dest_imp_only (concl exh_thm)
  in
    if aconv antecedent ``~F`` then SOME []
    else SOME (List.map triple_of_disjunct
                 (strip_disj (dest_neg antecedent)))
  end handle HOL_ERR _ => NONE
in
  case #pmi_exhaustiveness_cond pmi of
      SOME exh_thm => missing_of_thm exh_thm
    | NONE => NONE
end
3783
(* [extend_possibly_missing_patterns t pmi] appends to the PMATCH term [t]
   one row per possibly missing pattern reported by [pmi]. Each new row
   returns ARB of [t]'s type. The missing patterns' guards are kept only
   if [pmi] records rows with guards; otherwise the guard is T. If no
   pattern is missing, [t] is returned unchanged. Fails (failwith) when
   [pmi] carries no usable exhaustiveness information. *)
fun extend_possibly_missing_patterns t (pmi : pmatch_info) =
  case get_possibly_missing_patterns pmi of
      NONE => failwith "no missing row info available"
    | SOME [] => t
    | SOME missing => let
        val keep_guards = not (null (#pmi_has_guards pmi))
        val arb_rhs = mk_arb (type_of t)

        fun row_for (v, p, g) = let
          val guard = if keep_guards then g else T
          val (_, new_row) =
              mk_PMATCH_ROW_PABS_WILDCARDS (pairSyntax.strip_pair v)
                (p, guard, arb_rhs)
        in
          new_row
        end

        val new_rows = List.map row_for missing
        val (scrutinee, old_rows) = dest_PMATCH t
        val row_ty = type_of (hd new_rows)
      in
        mk_PMATCH scrutinee
          (listSyntax.mk_list (old_rows @ new_rows, row_ty))
      end;
3805
3806
(* A PMATCH is well formed if its analysis succeeded and no row is ill
   formed, binds variables unused by its pattern, or has a lambda in its
   pattern. *)
fun is_well_formed_pmatch ({pmi_is_well_formed, pmi_ill_formed_rows,
                            pmi_has_unused_pat_vars, pmi_has_lambda_in_pat,
                            ...} : pmatch_info) =
  pmi_is_well_formed andalso
  null pmi_ill_formed_rows andalso
  null pmi_has_unused_pat_vars andalso
  null pmi_has_lambda_in_pat;
3812
(* OCaml-style match: well formed and, in every row, the pattern uses only
   constructors/literals, has no free variables and binds no variable
   twice. *)
fun is_ocaml_pmatch (pmi : pmatch_info) = let
  val {pmi_has_non_contr_in_pat = non_constr_rows,
       pmi_has_free_pat_vars = free_var_rows,
       pmi_has_double_bound_pat_vars = double_bound_rows, ...} = pmi
in
  is_well_formed_pmatch pmi andalso
  null non_constr_rows andalso
  null free_var_rows andalso
  null double_bound_rows
end;
3818
(* SML-style match: an OCaml-style match without any guards. *)
fun is_sml_pmatch (pmi : pmatch_info) =
  if is_ocaml_pmatch pmi then null (#pmi_has_guards pmi) else false;
3822
(* Empty analysis result: not well formed, no per-row findings and no
   exhaustiveness theorem. [analyse_pmatch] returns it when the analysis
   raises HOL_ERR. *)
val init_pmatch_info : pmatch_info = {
  pmi_is_well_formed            = false,
  pmi_exhaustiveness_cond       = NONE,
  (* per-row findings, all initially empty *)
  pmi_ill_formed_rows           = [],
  pmi_has_guards                = [],
  pmi_has_lambda_in_pat         = [],
  pmi_has_free_pat_vars         = [],
  pmi_has_unused_pat_vars       = [],
  pmi_has_double_bound_pat_vars = [],
  pmi_has_non_contr_in_pat      = []
}
3834
(* [pmi_genupdate f1 ... f9 pmi] returns a copy of [pmi] in which the i-th
   field (in declaration order of [pmatch_info]) is replaced by [fi]
   applied to its old value. The pmi_set_*/pmi_add_* helpers below are
   instances of it. *)
fun pmi_genupdate f1 f2 f3 f4 f5 f6 f7 f8 f9
  ({pmi_is_well_formed            = wf,
    pmi_ill_formed_rows           = ill_rows,
    pmi_has_free_pat_vars         = free_vs,
    pmi_has_unused_pat_vars       = unused_vs,
    pmi_has_double_bound_pat_vars = double_vs,
    pmi_has_guards                = guard_rows,
    pmi_has_lambda_in_pat         = lambda_rows,
    pmi_has_non_contr_in_pat      = non_constr,
    pmi_exhaustiveness_cond       = exh} : pmatch_info) : pmatch_info =
  {pmi_is_well_formed            = f1 wf,
   pmi_ill_formed_rows           = f2 ill_rows,
   pmi_has_free_pat_vars         = f3 free_vs,
   pmi_has_unused_pat_vars       = f4 unused_vs,
   pmi_has_double_bound_pat_vars = f5 double_vs,
   pmi_has_guards                = f6 guard_rows,
   pmi_has_lambda_in_pat         = f7 lambda_rows,
   pmi_has_non_contr_in_pat      = f8 non_constr,
   pmi_exhaustiveness_cond       = f9 exh}
3847
(* Overwrite the [pmi_is_well_formed] flag with [x]. *)
fun pmi_set_is_well_formed x pmi =
    pmi_genupdate (fn _ => x) I I I I I I I I pmi
3850
(* Record [row_no] as an ill-formed row. The top-level [pmi_is_well_formed]
   flag is deliberately left untouched: ill-formed rows are already taken
   into account by [is_well_formed_pmatch] via [pmi_ill_formed_rows], and
   forcing the flag to true could mark a failed analysis as well formed. *)
fun pmi_add_ill_formed_row row_no =
    pmi_genupdate I (cons row_no) I I I I I I I;
3853
(* Record that the pattern of row [row_no] contains the variables [vars],
   which the row does not bind. *)
fun pmi_add_has_free_pat_vars row_no vars pmi =
    pmi_genupdate I I (fn rs => (row_no, vars) :: rs) I I I I I I pmi;
3856
(* Record that row [row_no] binds the variables [vars] without using them
   in its pattern. *)
fun pmi_add_has_unused_pat_vars row_no vars pmi =
    pmi_genupdate I I I (fn rs => (row_no, vars) :: rs) I I I I I pmi;
3859
(* Record that the bound variables [vars] occur more than once in the
   pattern of row [row_no]. *)
fun pmi_add_has_double_bound_pat_vars row_no vars pmi =
    pmi_genupdate I I I I (fn rs => (row_no, vars) :: rs) I I I I pmi;
3862
(* Record that row [row_no] has a guard other than T. *)
fun pmi_add_has_guards row_no pmi =
    pmi_genupdate I I I I I (fn rs => row_no :: rs) I I I pmi;
3865
(* Record that the pattern of row [row_no] contains a lambda abstraction. *)
fun pmi_add_has_lambda_in_pat row_no pmi =
    pmi_genupdate I I I I I I (fn rs => row_no :: rs) I I pmi;
3868
(* Record the constants [terms] in the pattern of row [row_no] that are
   neither constructors nor literals. *)
fun pmi_add_has_non_contr_in_pat row_no terms pmi =
    pmi_genupdate I I I I I I I (fn rs => (row_no, terms) :: rs) I pmi;
3871
(* Overwrite the exhaustiveness theorem with [thm_opt]. *)
fun pmi_set_pmi_exhaustiveness_cond thm_opt pmi =
    pmi_genupdate I I I I I I I I (fn _ => thm_opt) pmi;
3874
3875
local

  (* [analyse_pat acc t] walks the pattern term [t], threading the
     accumulator (lambda_seen, seen_vars, multi_vars, consts):
       lambda_seen : becomes true once an abstraction is met (its body is
                     not explored)
       seen_vars   : every variable occurring in [t]
       multi_vars  : variables met again after already being in seen_vars
       consts      : constants and literals occurring in [t]
     Literal.is_literal is tested before is_comb, so a literal built from
     applications is recorded whole rather than decomposed. *)
  fun analyse_pat (acc as (ls : bool, sv : term set, msv : term set,
                           sc : term set)) t =
    if is_var t then
      if HOLset.member (sv, t)
      then (ls, sv, HOLset.add (msv, t), sc)
      else (ls, HOLset.add (sv, t), msv, sc)
    else if Literal.is_literal t orelse is_const t then
      (ls, sv, msv, HOLset.add (sc, t))
    else if is_abs t then
      (true, sv, msv, sc)
    else if is_comb t then let
        val (rator_t, rand_t) = dest_comb t
      in
        analyse_pat (analyse_pat acc rator_t) rand_t
      end
    else failwith "UNREACHABLE"


  (* [analyse_row ((row_num, row), pmi)] adds to [pmi] the findings for the
     PMATCH row [row] numbered [row_num]: a guard other than T, a lambda in
     the pattern, pattern variables that are free (not bound by the row),
     unused (bound but absent from the pattern) or occurring more than
     once, and pattern constants that are neither constructors nor
     literals. *)
  fun analyse_row ((row_num, row), pmi) = let
    val (p_vars, p_body, g_body, _) = dest_PMATCH_ROW_ABS row

    (* add [row_num] with the elements of [s] using [add], unless [s] is
       empty *)
    fun add_nonempty add s acc =
      if HOLset.isEmpty s then acc
      else add row_num (HOLset.listItems s) acc

    val pmi = if aconv g_body T then pmi
              else pmi_add_has_guards row_num pmi

    val bound_l = pairSyntax.strip_pair p_vars
    val bound = HOLset.fromList Term.compare bound_l
    val no_terms = HOLset.empty Term.compare
    val (lambda_seen, seen, seen_twice, consts) =
        analyse_pat (false, no_terms, no_terms, no_terms) p_body

    (* A sole bound variable of type one is treated as used even though it
       does not occur in the pattern (presumably the encoding of a row
       without variables -- NOTE(review): confirm). *)
    val seen = case bound_l of
                   [v] => if type_of v = oneSyntax.one_ty
                          then HOLset.add (seen, v) else seen
                 | _ => seen

    val pmi = if lambda_seen then pmi_add_has_lambda_in_pat row_num pmi
              else pmi
    val pmi = add_nonempty pmi_add_has_free_pat_vars
                (HOLset.difference (seen, bound)) pmi
    val pmi = add_nonempty pmi_add_has_unused_pat_vars
                (HOLset.difference (bound, seen)) pmi
    val pmi = add_nonempty pmi_add_has_double_bound_pat_vars
                (HOLset.intersection (seen_twice, bound)) pmi

    val non_constrs =
        List.filter
          (fn c => not (TypeBase.is_constructor c orelse Literal.is_literal c))
          (HOLset.listItems consts)
  in
    if null non_constrs then pmi
    else pmi_add_has_non_contr_in_pat row_num non_constrs pmi
  end

in

(* [analyse_pmatch try_exh t] analyses the PMATCH term [t] row by row (rows
   numbered from 0). If [try_exh] holds and the result satisfies
   [is_ocaml_pmatch], it also tries PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK; a
   HOL_ERR from that check is ignored. If [t] is not a PMATCH, or the row
   analysis raises HOL_ERR, [init_pmatch_info] is returned. *)
fun analyse_pmatch try_exh t = let
  val (_, rows) = dest_PMATCH t
  val row_info =
      List.foldl analyse_row
                 (pmi_set_is_well_formed true init_pmatch_info)
                 (enumerate 0 rows)

  fun with_exhaustiveness pmi =
      if try_exh andalso is_ocaml_pmatch pmi then
        pmi_set_pmi_exhaustiveness_cond
          (SOME (PMATCH_IS_EXHAUSTIVE_CONSEQ_CHECK t)) pmi
      else pmi
in
  with_exhaustiveness row_info handle HOL_ERR _ => row_info
end handle HOL_ERR _ => init_pmatch_info

end
3975
3976
3977end
3978