1structure patternMatchesSyntax :> patternMatchesSyntax =
2struct
3
4open HolKernel Parse boolLib Drule BasicProvers
5open simpLib numLib
6open patternMatchesTheory
7open pairSyntax
8open ConseqConv
9
10
11(***********************************************)
12(* Auxiliary stuff                             *)
13(***********************************************)
14
(* [varname_starts_with_uscore v] holds iff [v] is a variable whose name
   begins with an underscore; any non-variable term gives false. *)
fun varname_starts_with_uscore v =
  (String.sub (#1 (dest_var v), 0) = #"_")
  handle HOL_ERR _ => false
20
21
(* [mk_var_gen prefix avoid] returns a generator of fresh variables. Each
   call [gen ty] yields a variable of type [ty] named [prefix ^ n], where n is
   the next value of a counter private to this generator, skipping names
   that occur among the variable names of [avoid]. *)
fun mk_var_gen prefix avoid = let
  val counter = ref 0
  val taken = List.map (#1 o dest_var) avoid
  fun fresh_name () = let
    val candidate = prefix ^ int_to_string (!counter)
  in
    counter := !counter + 1;
    if mem candidate taken then fresh_name () else candidate
  end
in
  fn ty => mk_var (fresh_name (), ty)
end

(* generator of wildcard variables "_0", "_1", ... avoiding [avoid] *)
fun mk_wildcard_gen avoid = mk_var_gen "_" avoid
37
38
(* [pick_element p l] returns the first element of [l] satisfying [p],
   together with [l] minus that one occurrence (remaining order kept).
   Fails with "no element found" if no element satisfies [p]. *)
fun pick_element p l = let
  fun search _ [] = failwith "no element found"
    | search before (x :: after) =
        if p x then (x, List.revAppend (before, after))
        else search (x :: before) after
in
  search [] l
end
50
(* Like strip_comb, but strips off at most [n] arguments. The result is the
   remaining operator and the stripped arguments in application order
   (prepended to [acc] for the auxiliary function). *)
fun strip_comb_bounded_aux acc n t =
  if n <= 0 then (t, acc)
  else
    case (SOME (dest_comb t) handle HOL_ERR _ => NONE) of
        NONE => (t, acc)
      | SOME (f, x) => strip_comb_bounded_aux (x :: acc) (n - 1) f

fun strip_comb_bounded n t = strip_comb_bounded_aux [] n t
59
(* [ALL_DISJ_CONV c] applies [c] (under TRY_CONV) to every leaf of a possibly
   nested disjunction; a non-disjunction is itself the only leaf. *)
fun ALL_DISJ_CONV c t =
  if not (is_disj t) then TRY_CONV c t
  else BINOP_CONV (ALL_DISJ_CONV c) t
64
(* apply a conversion to all leafs of a disjunct
   and simplify the result by removing T and F.

   [ALL_DISJ_TF_ELIM_CONV c t]:
   - if [t] is not a disjunction (or any HOL_ERR occurs below), falls back
     to [TRY_CONV c t];
   - if either top-level disjunct is literally T, proves the equation by
     specialising ConseqConvTheory.OR_CLAUSES_TX / OR_CLAUSES_XT, without
     running [c] at all;
   - otherwise recurses into both disjuncts. The right disjunct is skipped
     when the left one already became T. T/F on either side of the new
     disjunction is then eliminated with the OR_CLAUSES_* theorems.
   Raises UNCHANGED when neither disjunct changed.
   NOTE(review): because of that, a disjunct that is already F is only
   removed when some other part of the disjunction changed. *)
fun ALL_DISJ_TF_ELIM_CONV c t = let
  val (t1, t2) = dest_disj t
in
  if (aconv t1 T) then
    SPEC t2 (ConseqConvTheory.OR_CLAUSES_TX)
  else if (aconv t2 T) then
    SPEC t1 (ConseqConvTheory.OR_CLAUSES_XT)
  else let
    (* recursive results; NONE means the disjunct was unchanged *)
    val thm1_opt = SOME (ALL_DISJ_TF_ELIM_CONV c t1) handle UNCHANGED => NONE
    val thm1_opt_eq_T = case thm1_opt of
        NONE => false
      | SOME thm1 => (aconv (rhs (concl thm1)) T)
    (* no need to look at t2 once the whole disjunction is known to be T *)
    val thm2_opt = if thm1_opt_eq_T then NONE else SOME (ALL_DISJ_TF_ELIM_CONV c t2) handle UNCHANGED => NONE

    (* rebuild |- t1 \/ t2 = t1' \/ t2' from the changed sides *)
    val thm0 = case (thm1_opt, thm2_opt) of
        (NONE, NONE) => raise UNCHANGED
      | (SOME thm1, NONE) => RATOR_CONV (RAND_CONV (K thm1)) t
      | (NONE, SOME thm2) => RAND_CONV (K thm2) t
      | (SOME thm1, SOME thm2) => (
           (RATOR_CONV (RAND_CONV (K thm1))) THENC
           (RAND_CONV (K thm2))) t

    val (t1', t2') = dest_disj (rhs (concl thm0))

    (* pick the OR_CLAUSES theorem that removes a T or F disjunct, if any *)
    val rewr_thm_opt = if (aconv t1' T) then
        SOME (ConseqConvTheory.OR_CLAUSES_TX)
      else if (aconv t1' F) then
        SOME (ConseqConvTheory.OR_CLAUSES_FX)
      else if (aconv t2' T) then
        SOME (ConseqConvTheory.OR_CLAUSES_XT)
      else if (aconv t2' F) then
        SOME (ConseqConvTheory.OR_CLAUSES_XF)
      else NONE

   val thm1 = case rewr_thm_opt of
      NONE => thm0
    | SOME thm_rw => RIGHT_CONV_RULE (REWR_CONV thm_rw) thm0
  in
    thm1
  end
end handle HOL_ERR _ => (TRY_CONV c) t
108
109
(* [ALL_CONJ_CONV c] applies [c] (under TRY_CONV) to every leaf of a possibly
   nested conjunction; a non-conjunction is itself the only leaf. *)
fun ALL_CONJ_CONV c t =
  if not (is_conj t) then TRY_CONV c t
  else BINOP_CONV (ALL_CONJ_CONV c) t
114
115
(* [DESCEND_CONV c_desc c] applies [c] to the term. If the descent via
   [c_desc] is possible, it then repeats the whole process on the subterms
   that [c_desc] visits. Failure of the descent is ignored (TRY_CONV);
   failure of [c] itself is not. *)
fun DESCEND_CONV c_desc c t = let
  val recurse = DESCEND_CONV c_desc c
in
  (c THENC TRY_CONV (c_desc recurse)) t
end
118
119
(* [STRIP_ABS_CONV conv] applies [conv] underneath all leading lambda
   abstractions of a term. *)
fun STRIP_ABS_CONV conv t =
  if not (is_abs t) then conv t
  else ABS_CONV (STRIP_ABS_CONV conv) t
123
(* [has_subterm p t] tests whether some subterm of [t] satisfies [p]. *)
fun has_subterm p t =
  (ignore (find_term p t); true) handle HOL_ERR _ => false

(* like [prove], but less verbose (informational messages switched off),
   since the proof attempt is expected to fail sometimes *)
val prove_attempt = Lib.with_flag (Feedback.emit_MESG, false) prove
128
129
130(***********************************************)
131(* Labels from markerLib                       *)
132(***********************************************)
133
(* [mk_new_label_gen prefix] returns a label generator: successive calls
   return prefix0, prefix1, prefix2, ... using a private counter. *)
fun mk_new_label_gen prefix = let
  val next = ref 0
  fun fresh () = let
    val n = !next
  in
    next := n + 1;
    prefix ^ int_to_string n
  end
in
  fresh
end
146
(* [add_labels_CONV lbls t] returns |- t = lt, where lt is [t] wrapped in
   the labels of [lbls]; the first label of [lbls] ends up innermost. The
   theorem is obtained by stripping the labels again with
   markerLib.DEST_LABEL_CONV and flipping the resulting equation. *)
fun add_labels_CONV lbls t = let
  val labelled =
      List.foldl (fn (lbl, acc) => markerSyntax.mk_label (lbl, acc)) t lbls
  val strip_thm = REPEATC markerLib.DEST_LABEL_CONV labelled
in
  GSYM strip_thm
end
152
(*
  val mk_new_label = mk_new_label_gen "disj"
  val lbl_tm = markerSyntax.mk_label (mk_new_label (), lbl_tm)
  val t = lbl_tm
*)
158
(* [strip_labels t] removes all outermost labels of [t]. It returns the
   labels (outermost first) and the unlabelled body. *)
fun strip_labels t = let
  fun peel lbls tm =
    case (SOME (markerSyntax.dest_label tm) handle HOL_ERR _ => NONE) of
        NONE => (List.rev lbls, tm)
      | SOME (lbl, body) => peel (lbl :: lbls) body
in
  peel [] t
end
168
(* [strip_labels_CONV c] applies [c] underneath all outermost labels. *)
fun strip_labels_CONV c t =
  if not (markerSyntax.is_label t) then c t
  else RAND_CONV (strip_labels_CONV c) t
175
(* [guarded_strip_labels_CONV lbls c t] applies [c] underneath the outermost
   labels of [t], but only if at least one of those labels occurs in [lbls];
   otherwise it raises UNCHANGED. *)
fun guarded_strip_labels_CONV lbls c t = let
  val (present, _) = strip_labels t
  fun wanted l = List.exists (fn l' => l = l') lbls
in
  if List.exists wanted present then strip_labels_CONV c t
  else raise UNCHANGED
end
185
186
187(***********************************************)
188(* Terms                                       *)
189(***********************************************)
190
(* [ty_var_subst] maps the type variables 'a, ..., 'j to type variables
   obtained from gen_tyvar. It is computed once, when this structure is
   loaded, so every user shares the same replacement variables. *)
val TyV = mk_vartype
val ty_var_subst =
    List.map (fn name => TyV name |-> gen_tyvar ())
             ["'a", "'b", "'c", "'d", "'e", "'f", "'g", "'h", "'i", "'j"]

(* [PC nm] is the constant [nm] of theory "patternMatches". *)
fun PC nm = prim_mk_const { Thy = "patternMatches", Name = nm }

(* Each constant comes in two flavours: [X_tm] carries the type variables of
   its definition, and [X_gtm] has them renamed via [ty_var_subst]. The
   _gtm versions are the ones handed to list_mk_icomb in this file. *)
local
  fun generic c = inst ty_var_subst c
in
val PMATCH_ROW_tm = PC "PMATCH_ROW"
val PMATCH_ROW_gtm = generic PMATCH_ROW_tm

val PMATCH_ROW_COND_tm = PC "PMATCH_ROW_COND"
val PMATCH_ROW_COND_gtm = generic PMATCH_ROW_COND_tm

val PMATCH_ROW_COND_EX_tm = PC "PMATCH_ROW_COND_EX"
val PMATCH_ROW_COND_EX_gtm = generic PMATCH_ROW_COND_EX_tm

val PMATCH_tm = PC "PMATCH"
val PMATCH_gtm = generic PMATCH_tm

val PMATCH_IS_EXHAUSTIVE_tm = PC "PMATCH_IS_EXHAUSTIVE"
val PMATCH_IS_EXHAUSTIVE_gtm = generic PMATCH_IS_EXHAUSTIVE_tm
end

(* instantiate the type variables 'a .. 'j of [thm] via [ty_var_subst] *)
fun FRESH_TY_VARS_RULE thm = INST_TYPE ty_var_subst thm
223
(* [REMOVE_REBIND_CONV_AUX avoid t] renames the binders of [t], traversing
   left to right with binders before their bodies. No binder reuses a name
   from [avoid] or from a binder processed earlier. Returns the renamed term
   and the extended avoid list. Any HOL_ERR while rebuilding a node leaves
   that node unchanged. *)
fun REMOVE_REBIND_CONV_AUX avoid t =
  case Psyntax.dest_term t of
      Psyntax.LAMB (v, body) =>
        (let
           val v' = variant avoid v
           val (body', avoid') =
               REMOVE_REBIND_CONV_AUX (v' :: avoid) (subst [v |-> v'] body)
         in
           (mk_abs (v', body'), avoid')
         end handle HOL_ERR _ => (t, avoid))
    | Psyntax.COMB (f, x) =>
        (let
           val (f', avoid1) = REMOVE_REBIND_CONV_AUX avoid f
           val (x', avoid2) = REMOVE_REBIND_CONV_AUX avoid1 x
         in
           (mk_comb (f', x'), avoid2)
         end handle HOL_ERR _ => (t, avoid))
    | _ => (t, avoid)
238
(* [REMOVE_REBIND_CONV t] proves |- t = t', where t' is alpha-equivalent to
   [t] and every binder of t' has its own name. That name is also distinct
   from the free variables of [t].

   The avoid list is seeded with the free variables of [t]. Without this, a
   binder could be renamed onto a name that occurs free elsewhere, e.g.
   f (\x. x) (\x. x') with x' free: the second binder becomes x', which
   captures the free x'. ALPHA then fails because the two terms are no
   longer alpha-equivalent. *)
fun REMOVE_REBIND_CONV t = let
  val (t', _) = REMOVE_REBIND_CONV_AUX (free_vars t) t
in
  ALPHA t t'
end
244
245
246(***********************************************)
247(* PMATCH_ROW                                  *)
248(***********************************************)
249
(* [dest_PMATCH_ROW row] splits PMATCH_ROW p g r into (p, g, r). It fails
   with "dest_PMATCH_ROW" on anything else. *)
fun dest_PMATCH_ROW row =
  case strip_comb row of
      (f, [p, g, r]) =>
        if same_const f PMATCH_ROW_tm then (p, g, r)
        else failwith "dest_PMATCH_ROW"
    | _ => failwith "dest_PMATCH_ROW"

fun is_PMATCH_ROW t = can dest_PMATCH_ROW t

(* builds PMATCH_ROW p g r, instantiating the constant's types as needed *)
fun mk_PMATCH_ROW (p_t, g_t, r_t) =
  list_mk_icomb (PMATCH_ROW_gtm, [p_t, g_t, r_t])
261
(* [mk_pabs_from_vars vars tl] returns a function abstracting a body over
   [vars]:
   - no variables: over a fresh unit-typed variable "_uv" (primed to avoid
     the free variables of [tl]);
   - one variable: a plain lambda;
   - several: a paired abstraction over their tuple. *)
fun mk_pabs_from_vars vars tl =
  case vars of
      [] =>
        let
          val unit_var =
              variant (free_varsl tl) (mk_var ("_uv", oneSyntax.one_ty))
        in
          fn body => mk_abs (unit_var, body)
        end
    | [v] => (fn body => mk_abs (v, body))
    | _ => (fn body => pairSyntax.mk_pabs (pairSyntax.list_mk_pair vars, body))

(* PMATCH_ROW (\vars. p) (\vars. g) (\vars. r) *)
fun mk_PMATCH_ROW_PABS vars (p_t, g_t, r_t) = let
  val abstract = mk_pabs_from_vars vars [p_t, g_t, r_t]
in
  mk_PMATCH_ROW (abstract p_t, abstract g_t, abstract r_t)
end
277
(* [MULTIPLE_FV_AUX dups seen t] walks [t] and returns the pair
   (dups, seen). [seen] is extended with the free variables of [t]. [dups]
   is extended with every free variable that occurs again, either a second
   time within [t] or after already being in [seen].

   An abstraction's body is analysed in isolation, starting from empty sets.
   The bound variable is then removed from the body's results before they
   are merged. As a result, a bound occurrence of [v] is never counted as a
   repeat of a free [v] seen earlier. For example, in f x (\x. x) the
   variable x occurs free only once. *)
fun MULTIPLE_FV_AUX (dups : term HOLset.set) (seen : term HOLset.set) (t : term) =
  case Psyntax.dest_term t of
      Psyntax.VAR (_, _) =>
      if (HOLset.member (seen, t)) then
        (HOLset.add (dups, t), seen)
      else
        (dups, HOLset.add (seen, t))
    | Psyntax.CONST _ => (dups, seen)
    | Psyntax.LAMB (v, t') => let
         fun drop_bound s = HOLset.delete (s, v) handle NotFound => s
         val (b_dups, b_seen) = MULTIPLE_FV_AUX empty_tmset empty_tmset t'
         val b_dups = drop_bound b_dups
         val b_seen = drop_bound b_seen
         (* a free variable of the body is repeated if it occurs twice in
            the body or was already seen outside of it *)
         val new_dups = HOLset.union (b_dups, HOLset.intersection (seen, b_seen))
       in
         (HOLset.union (dups, new_dups), HOLset.union (seen, b_seen))
       end
    | Psyntax.COMB (t1, t2) => let
         val (dups', seen') = MULTIPLE_FV_AUX dups seen t1
      in
         MULTIPLE_FV_AUX dups' seen' t2
      end;
299
(* [MULTIPLE_FV t] = (repeated variables, seen variables) of [t], as
   computed by MULTIPLE_FV_AUX starting from empty sets. *)
fun MULTIPLE_FV t = let
  val none = empty_tmset
in
  MULTIPLE_FV_AUX none none t
end;
301
(* [mk_PMATCH_ROW_PABS_WILDCARDS vars (p_t, g_t, r_t)] builds
   PMATCH_ROW (\vars'. p') (\vars'. g') (\vars'. r') after normalising the
   names of the bound variables [vars]:
   - a variable that is neither free in the guard/right-hand side nor
     repeated in the pattern gets a wildcard name "_<n>";
   - a variable that is used there, but has a name starting with "_", is
     renamed to a fresh ordinary name "v<n>".
   Returns (changed, row), where [changed] tells whether any variable was
   renamed. *)
fun mk_PMATCH_ROW_PABS_WILDCARDS vars (p_t, g_t, r_t) = let
    val (pm_s, p_s) = MULTIPLE_FV p_t
    (* variables that must keep a real (non-wildcard) name *)
    val grd_s = FVL [g_t, r_t] pm_s

    (* Fresh names must avoid every variable in sight, including the bound
       variables [vars]. A bound variable that occurs nowhere in the row is
       in neither set above. Reusing its name would yield a tuple binder
       with two equal names, so the renamed variable could end up referring
       to the wrong tuple component. *)
    val avoid_s =
        HOLset.addList (HOLset.union (grd_s, p_s), List.filter is_var vars)
    val avoid = HOLset.listItems avoid_s
    val mk_wc = mk_wildcard_gen avoid
    val mk_var = mk_var_gen "v" avoid

    fun apply (v, (vars', subst)) = let
      val should_be_uc = not (HOLset.member (grd_s, v))
      val is_uc = varname_starts_with_uscore v
    in
      if (should_be_uc = is_uc) then
         (v::vars', subst)
      else let
        val v' = if should_be_uc then
          mk_wc (type_of v) else mk_var (type_of v)
      in
        (v'::vars', (v |-> v')::subst)
      end
    end

    val (vars'_rev, subst) = List.foldl apply ([], []) vars
    val vars' = List.rev vars'_rev
    val p_t' = Term.subst subst p_t
    val g_t' = Term.subst subst g_t
    val r_t' = Term.subst subst r_t
    val changed_wc = not (List.null subst)
  in
    (changed_wc, mk_PMATCH_ROW_PABS vars' (p_t', g_t', r_t'))
  end
333
334
(* [dest_PMATCH_ROW_ABS row] destructs PMATCH_ROW (\vs. p) (\vs. g) (\vs. r)
   into (vs, p, g, r). It fails with "dest_PMATCH_ROW_ABS" unless all three
   abstractions bind alpha-equivalent variable tuples. *)
fun dest_PMATCH_ROW_ABS row = let
  val (p_t, g_t, r_t) = dest_PMATCH_ROW row
  val (vs, pat) = pairSyntax.dest_pabs p_t
  val (g_vs, guard) = pairSyntax.dest_pabs g_t
  val (r_vs, res) = pairSyntax.dest_pabs r_t
  val same_binders = aconv vs g_vs andalso aconv g_vs r_vs
in
  if same_binders then (vs, pat, guard, res)
  else failwith "dest_PMATCH_ROW_ABS"
end
347
348
(* Like dest_PMATCH_ROW_ABS, but renames the bound variables away from [vs]
   (via variant_of_term) in the tuple and in pattern, guard and rhs. *)
fun dest_PMATCH_ROW_ABS_VARIANT vs row = let
  val (vars, pat, guard, res) = dest_PMATCH_ROW_ABS row
  val (vars', theta) = variant_of_term vs vars
  val rename = subst theta
in
  (vars', rename pat, rename guard, rename res)
end;
355
(* K_elim: |- K x = (\y. x), from the definition of K with the right-hand
   side beta-reduced *)
val K_elim =
  CONV_RULE (RAND_CONV BETA_CONV)
    (AP_THM combinTheory.K_DEF (mk_var ("x", alpha)))
359
(* [PMATCH_ROW_PABS_ELIM_CONV row] returns the bound variable tuple of the
   row's pattern, together with a theorem rewriting each of the three
   arguments of the row. The rewriting turns K terms into lambdas and
   eliminates paired abstractions where possible. When nothing changes, the
   theorem is reflexivity. *)
fun PMATCH_ROW_PABS_ELIM_CONV row = let
  val (pat, _, _) = dest_PMATCH_ROW row
  val (vars, _) = pairSyntax.dest_pabs pat
  val elim = TRY_CONV (REWR_CONV K_elim) THENC TRY_CONV pairTools.PABS_ELIM_CONV
  (* PMATCH_ROW p g r: r is the operand, g and p are one and two
     RATOR_CONV steps further up *)
  val on_rhs = RAND_CONV elim
  val on_guard = RATOR_CONV (RAND_CONV elim)
  val on_pat = RATOR_CONV (RATOR_CONV (RAND_CONV elim))
  val thm = (on_rhs THENC on_guard THENC on_pat) row
            handle UNCHANGED => REFL row
in
  (vars, thm)
end;
373
374
(* [PMATCH_ROW_PABS_INTRO_CONV vars row] re-introduces paired abstractions
   over (a variant of) [vars] in pattern, guard and rhs of [row]. The
   variant avoids the free variables of [row]. *)
fun PMATCH_ROW_PABS_INTRO_CONV vars row =
  if not (is_PMATCH_ROW row) then failwith "PMATCH_ROW_PABS_INTRO_CONV"
  else let
    val (fresh_vars, _) = variant_of_term (free_vars row) vars
    val intro = pairTools.PABS_INTRO_CONV fresh_vars
    val conv = RAND_CONV intro THENC
               RATOR_CONV (RAND_CONV intro) THENC
               RATOR_CONV (RATOR_CONV (RAND_CONV intro))
  in
    conv row
  end;
386
(* Rewrites a PMATCH_ROW whose pattern, guard and rhs do not all bind the
   same variable tuple into one where they do. It does this by eliminating
   and then re-introducing the paired abstractions. Raises UNCHANGED on
   non-rows, on rows already in that form, and on any HOL_ERR. *)
fun PMATCH_ROW_FORCE_SAME_VARS_CONV row = let
  val already_ok = not (is_PMATCH_ROW row) orelse can dest_PMATCH_ROW_ABS row
  val _ = if already_ok then raise UNCHANGED else ()
  val (vars, elim_thm) = PMATCH_ROW_PABS_ELIM_CONV row
  val intro_thm = PMATCH_ROW_PABS_INTRO_CONV vars (rhs (concl elim_thm))
in
  TRANS elim_thm intro_thm
end handle HOL_ERR _ => raise UNCHANGED
395
(* Renames the bound variables of a row (see mk_PMATCH_ROW_PABS_WILDCARDS)
   and proves the renaming by ALPHA. Raises UNCHANGED if nothing was renamed
   or on any HOL_ERR. *)
fun PMATCH_ROW_INTRO_WILDCARDS_CONV row = let
  val (vars_tm, pat, guard, res) = dest_PMATCH_ROW_ABS row
  val (changed, row') =
      mk_PMATCH_ROW_PABS_WILDCARDS (pairSyntax.strip_pair vars_tm) (pat, guard, res)
in
  if changed then ALPHA row row' else raise UNCHANGED
end handle HOL_ERR _ => raise UNCHANGED
404
405(*
406val row = ``
407      PMATCH_ROW (\ (y,z). (SOME y,SUC z,[1; 2]))
408                 (\ x. T)
409                 (\ (y,z). y + z)``
410
411val (vars, thm) = PMATCH_ROW_PABS_ELIM_CONV row
412val thm2 = PMATCH_ROW_PABS_INTRO_CONV vars (rhs (concl thm))
413val row = rhs (concl thm2)
414*)
415
416(***********************************************)
417(* PMATCH                                      *)
418(***********************************************)
419
(* [mk_PMATCH v rows] builds PMATCH v rows, instantiating the type of PMATCH
   by matching its second argument type against the type of [rows]. *)
fun mk_PMATCH v rows = let
  val (arg_tys, _) = wfrecUtils.strip_fun_type (type_of PMATCH_tm)
  val theta = match_type (el 2 arg_tys) (type_of rows)
  val pmatch = inst theta PMATCH_tm
in
  mk_comb (mk_comb (pmatch, v), rows)
end
433
(* [dest_PMATCH t] splits PMATCH v [r1; ...; rn] into (v, [r1, ..., rn]).
   It fails with "dest_PMATCH" if [t] is not PMATCH applied to exactly two
   arguments, and also fails if the rows are not an explicit list. *)
fun dest_PMATCH t =
  case strip_comb t of
      (f, [v, rows]) =>
        if same_const f PMATCH_tm then (v, #1 (listSyntax.dest_list rows))
        else failwith "dest_PMATCH"
    | _ => failwith "dest_PMATCH"

fun is_PMATCH t = can dest_PMATCH t
443
(* [dest_PATLIST_COLS v ps] views a list of patterns [ps] as a table of
   columns. Each pattern is a paired abstraction \vars. pt, whose body pt
   is split into its tuple components. The number of columns is the
   maximum number of components of any pattern.

   Result: one entry (v_i, col_i) per column, where v_i is the i-th
   component of the matched value [v], and col_i lists (vars, p_i) for each
   pattern in the original order of [ps].
   Tuples with fewer components than the number of columns, and [v] itself,
   are split with dest_pair where possible, and with FST/SND otherwise.
   Fails with "dest_PATLIST_COLS" if an Empty exception arises. *)
fun dest_PATLIST_COLS v ps = let
  (* record (vars, body, components, #components) per pattern; the result
     list is in reverse order, [m] tracks the maximum #components *)
  fun split_pat (p, (m, l)) = let
    val (vars_tm, pt) = pairSyntax.dest_pabs p
    val vars = pairSyntax.strip_pair vars_tm
    val ps = pairSyntax.strip_pair pt
    val m' = length ps
  in
    (Int.max (m, m'), (vars, pt, ps, m')::l)
  end
  val (col_no, rows') = foldl split_pat (0, []) ps

  (* split [v] into exactly [col_no] components (a single one if
     col_no <= 1), using FST/SND when [v] is not syntactically a pair *)
  fun aux acc v col_no = if (col_no <= 1) then List.rev (v::acc) else (
    let
       val (v1, v2) = pairSyntax.dest_pair v handle HOL_ERR _ =>
          (pairSyntax.mk_fst v,  pairSyntax.mk_snd v)
    in
       aux (v1::acc) v2 (col_no-1)
    end
  )

  (* bring every row to [col_no] components; folding over the reversed
     list restores the original row order *)
  fun final_process ((vars, pt, ps, cols), l) =
  let
    val ps' = if (cols = col_no) then ps else aux [] pt col_no
  in
    (List.map (fn p => (vars, p)) ps')::l
  end

  val rows'' = foldl final_process [] rows'
  val vs = aux [] v col_no

  (* transpose: pair each component of [v] with the matching column *)
  fun get_cols acc vs rows = case vs of
      [] => List.rev acc
    | (v::vs') => let
        val col = map hd rows
        val rows' = map tl rows
      in
        get_cols ((v, col)::acc) vs' rows'
      end

  val cols = get_cols [] vs rows''
in
  cols
end handle Empty => failwith "dest_PATLIST_COLS"
487
488
(* column view (see dest_PATLIST_COLS) of the patterns of a PMATCH term *)
fun dest_PMATCH_COLS t = let
  val (v, rows) = dest_PMATCH t
  fun pattern_of row = #1 (dest_PMATCH_ROW row)
in
  dest_PATLIST_COLS v (List.map pattern_of rows)
end
495
(* [list_CONV c] applies [c] to every element of an explicit list term
   built with cons; raises UNCHANGED on a term that is not a cons. *)
fun list_CONV c t =
  if listSyntax.is_cons t then
    (RATOR_CONV (RAND_CONV c) THENC RAND_CONV (list_CONV c)) t
  else raise UNCHANGED

(* [list_nth_CONV n c] applies [c] to element [n] (0-based) of an explicit
   list term; raises UNCHANGED for negative [n] or if the list is too short. *)
fun list_nth_CONV n c t =
  if not (listSyntax.is_cons t) orelse n < 0 then raise UNCHANGED
  else if n = 0 then RATOR_CONV (RAND_CONV c) t
  else RAND_CONV (list_nth_CONV (n - 1) c) t
506
(* [PMATCH_ROWS_CONV c] applies [c] to every row of a PMATCH term; raises
   UNCHANGED on other terms. *)
fun PMATCH_ROWS_CONV c t =
  if is_PMATCH t then RAND_CONV (list_CONV c) t
  else raise UNCHANGED

val PMATCH_FORCE_SAME_VARS_CONV =
  PMATCH_ROWS_CONV PMATCH_ROW_FORCE_SAME_VARS_CONV

val PMATCH_INTRO_WILDCARDS_CONV =
  PMATCH_ROWS_CONV PMATCH_ROW_INTRO_WILDCARDS_CONV
518
519(* Introduce fresh variables *)
520(*
521
522val t = ``case f x of
523  | (x, z, SUC l) when cond z => gggg l x
524  | x.| (x, z, _) => g2
525  | y.| (y, z, _) => g2
526  | (ff a, _, _) => a`` *)
527
(* [PMATCH_INTRO_GENVARS t] replaces the following parts of the PMATCH term
   [t] by genvars: the matched expression, and every pattern, guard and
   right-hand side that is not already a variable. Alpha-equivalent parts
   share one genvar. Returns the resulting term and the substitution that
   maps the genvars back to the original parts. *)
fun PMATCH_INTRO_GENVARS t = let
  fun note (nt, acc as (intro, elim)) =
    if is_var nt orelse List.exists (aconv nt o #redex) intro then acc
    else let
      val gv = genvar (type_of nt)
    in
      ((nt |-> gv) :: intro, (gv |-> nt) :: elim)
    end

  val (v, rows) = dest_PMATCH t

  fun note_row (row, acc) = let
    val (pt, gt, rt) = dest_PMATCH_ROW row
  in
    List.foldl note acc [pt, gt, rt]
  end

  val (s_intro, s_elim) = List.foldl note_row (note (v, ([], []))) rows
in
  (subst s_intro t, s_elim)
end
554
555
556(***********************************************)
557(* PMATCH_ROW_COND                             *)
558(***********************************************)
559
(* [dest_PMATCH_ROW_COND rc] splits PMATCH_ROW_COND p g i x into
   (p, g, i, x); fails with "dest_PMATCH_ROW_COND" otherwise. *)
fun dest_PMATCH_ROW_COND rc =
  case strip_comb rc of
      (f, [p, g, i, x]) =>
        if same_const f PMATCH_ROW_COND_tm then (p, g, i, x)
        else failwith "dest_PMATCH_ROW_COND"
    | _ => failwith "dest_PMATCH_ROW_COND"

fun is_PMATCH_ROW_COND t = can dest_PMATCH_ROW_COND t

fun mk_PMATCH_ROW_COND (p_t, g_t, i, x) =
  list_mk_icomb (PMATCH_ROW_COND_gtm, [p_t, g_t, i, x])

(* PMATCH_ROW_COND (\vars. p) (\vars. g) i x *)
fun mk_PMATCH_ROW_COND_PABS vars (p_t, g_t, i, x) = let
  val abstract = mk_pabs_from_vars vars [p_t, g_t, x]
in
  mk_PMATCH_ROW_COND (abstract p_t, abstract g_t, i, x)
end
577
(* destructs PMATCH_ROW_COND (\vs. p) (\vs. g) i x into (vs, p, g, i, x);
   fails unless pattern and guard bind alpha-equivalent tuples *)
fun dest_PMATCH_ROW_COND_ABS rc = let
  val (p_t, g_t, i_t, x_t) = dest_PMATCH_ROW_COND rc
  val (vs, pat) = pairSyntax.dest_pabs p_t
  val (g_vs, guard) = pairSyntax.dest_pabs g_t
in
  if aconv vs g_vs then (vs, pat, guard, i_t, x_t)
  else failwith "dest_PMATCH_ROW_COND_ABS"
end
589
590
591(***********************************************)
592(* PMATCH_ROW_COND_EX                          *)
593(***********************************************)
594
(* [dest_PMATCH_ROW_COND_EX rc] splits PMATCH_ROW_COND_EX i p g into
   (i, p, g); fails with "dest_PMATCH_ROW_COND_EX" otherwise. *)
fun dest_PMATCH_ROW_COND_EX rc =
  case strip_comb rc of
      (f, [i, p, g]) =>
        if same_const f PMATCH_ROW_COND_EX_tm then (i, p, g)
        else failwith "dest_PMATCH_ROW_COND_EX"
    | _ => failwith "dest_PMATCH_ROW_COND_EX"

fun is_PMATCH_ROW_COND_EX t = can dest_PMATCH_ROW_COND_EX t

fun mk_PMATCH_ROW_COND_EX (i, p_t, g_t) =
  list_mk_icomb (PMATCH_ROW_COND_EX_gtm, [i, p_t, g_t])

(* PMATCH_ROW_COND_EX i (\vars. p) (\vars. g) *)
fun mk_PMATCH_ROW_COND_EX_PABS vars (i, p_t, g_t) = let
  val abstract = mk_pabs_from_vars vars [p_t, g_t]
in
  mk_PMATCH_ROW_COND_EX (i, abstract p_t, abstract g_t)
end
612
(* Like mk_PMATCH_ROW_COND_EX_PABS, but first moves subterms out of the
   pattern into the guard. [find vars pat] may return a subterm s of the
   pattern that is not one of the bound variables. That subterm is then
   replaced by a fresh variable x (named "x", primed to avoid the free
   variables of the row and the current bound variables). x is added to the
   bound variables, and x = s is conjoined to the front of the guard. This
   repeats until [find] returns NONE or a step fails with HOL_ERR. *)
fun mk_PMATCH_ROW_COND_EX_PABS_MOVE_TO_GUARDS find vars (i, p_t, g_t) = let
  val row_fvs = free_vars i @ free_vars p_t @ free_vars g_t
  fun step (vs, pat, guard) = let
    val sub = case find vs pat of
                  SOME s => s
                | NONE => failwith "not found"
    val _ = if mem sub vs then failwith "loop" else ()
    val fresh = variant (row_fvs @ vs) (mk_var ("x", type_of sub))
  in
    step (fresh :: vs,
          Term.subst [sub |-> fresh] pat,
          mk_conj (mk_eq (fresh, sub), guard))
  end handle HOL_ERR _ => (vs, pat, guard)

  val (vars', p_t', g_t') = step (vars, p_t, g_t)
in
  mk_PMATCH_ROW_COND_EX_PABS vars' (i, p_t', g_t')
end
633
634
(* PMATCH_ROW_COND_EX i p (\vars. T), where vars are the variables bound by
   the paired abstraction [p] *)
fun mk_PMATCH_ROW_COND_EX_pat i p = let
  val (vars, _) = pairSyntax.dest_pabs p
in
  mk_PMATCH_ROW_COND_EX (i, p, pairSyntax.mk_pabs (vars, T))
end

(* PMATCH_ROW_COND_EX i p g for the row PMATCH_ROW p g r *)
fun mk_PMATCH_ROW_COND_EX_ROW i r =
  case dest_PMATCH_ROW r of
      (p, g, _) => mk_PMATCH_ROW_COND_EX (i, p, g)

(* destructs PMATCH_ROW_COND_EX i (\vs. p) (\vs. g) into (i, vs, p, g);
   fails unless pattern and guard bind alpha-equivalent tuples *)
fun dest_PMATCH_ROW_COND_EX_ABS rc = let
  val (i_t, p_t, g_t) = dest_PMATCH_ROW_COND_EX rc
  val (vs, pat) = pairSyntax.dest_pabs p_t
  val (g_vs, guard) = pairSyntax.dest_pabs g_t
in
  if aconv vs g_vs then (i_t, vs, pat, guard)
  else failwith "dest_PMATCH_ROW_COND_EX_ABS"
end
659
660
661(*
662val t = (el 4 o strip_disj o snd o strip_forall o concl) thm
663val v = (fst o dest_forall o concl) thm
664val t = ``x = (NONE,c)``
665val v = lhs t
666fun P vs x = NONE
667*)
668
(* [PMATCH_ROW_COND_EX_INTRO_CONV_GEN P v t] converts
     ?vs. c_1 /\ ... /\ c_n
   where some conjunct c_k is an equation whose lhs is [v], say v = pat,
   into
     PMATCH_ROW_COND_EX v (\vs. pat) (\vs. <the other conjuncts>)
   and returns |- t = PMATCH_ROW_COND_EX ... . The guard is T when there
   are no other conjuncts.

   [P] is a "find" function of type term list -> term -> term option, as
   taken by mk_PMATCH_ROW_COND_EX_PABS_MOVE_TO_GUARDS. It may select
   subterms of the pattern to be replaced by fresh variables and moved into
   the guard as equations. If the equation for that moved form cannot be
   proved, the plain form is used instead, so a non-trivial [P] never makes
   the conversion fail where the plain form succeeds. With
   [fn _ => fn _ => NONE], the plain form is built and proved exactly as
   before. *)
fun PMATCH_ROW_COND_EX_INTRO_CONV_GEN P v t = let
  val (vs, b) = strip_exists t
  val b_conjs = strip_conj b
  val (peq_t, guards) = pick_element (fn c => (aconv (lhs c) v handle HOL_ERR _ => false)) b_conjs

  val p_t = rhs peq_t
  val g_t = if List.null guards then T else list_mk_conj guards

  (* prove |- t = rc with the given prover *)
  fun prove_with prover rc =
    prover (mk_eq (t, rc),
      SIMP_TAC std_ss [PMATCH_ROW_COND_EX_def, PMATCH_ROW_COND_def, pairTheory.EXISTS_PROD] THEN
      TRY (REDEPTH_CONSEQ_CONV_TAC (K EXISTS_EQ___CONSEQ_CONV)) THEN
      SIMP_TAC (std_ss++boolSimps.EQUIV_EXTRACT_ss) [])

  val rc_plain = mk_PMATCH_ROW_COND_EX_PABS vs (v, p_t, g_t)
  val rc_moved = mk_PMATCH_ROW_COND_EX_PABS_MOVE_TO_GUARDS P vs (v, p_t, g_t)
in
  if aconv rc_moved rc_plain then prove_with prove rc_plain
  else prove_with prove_attempt rc_moved
       handle HOL_ERR _ => prove_with prove rc_plain
end
689
(* instance of PMATCH_ROW_COND_EX_INTRO_CONV_GEN whose P always answers
   NONE *)
fun PMATCH_ROW_COND_EX_INTRO_CONV v t = let
  val never = fn _ => fn _ => NONE
in
  PMATCH_ROW_COND_EX_INTRO_CONV_GEN never v t
end;

(* For [t] of the form !v. body, this applies
   PMATCH_ROW_COND_EX_INTRO_CONV_GEN P v to every disjunct of the body.
   Disjuncts on which that fails are left unchanged. *)
fun nchotomy2PMATCH_ROW_COND_EX_CONV_GEN P t = let
  val (v, _) = dest_forall t
  val per_case = PMATCH_ROW_COND_EX_INTRO_CONV_GEN P v
in
  QUANT_CONV (ALL_DISJ_CONV per_case) t
end;

fun nchotomy2PMATCH_ROW_COND_EX_CONV t = let
  val never = fn _ => fn _ => NONE
in
  nchotomy2PMATCH_ROW_COND_EX_CONV_GEN never t
end;
701
(* [PMATCH_ROW_COND_EX_ELIM_CONV t] unfolds a PMATCH_ROW_COND_EX term with
   PMATCH_ROW_COND_EX_FULL_DEF and then tidies the right-hand side. Raises
   HOL_ERR if [t] is not a PMATCH_ROW_COND_EX whose pattern is a (paired)
   abstraction, or if an unguarded step fails. *)
fun PMATCH_ROW_COND_EX_ELIM_CONV t = let
  val (_, p_t, _) = dest_PMATCH_ROW_COND_EX t
  (* variable tuple bound by the pattern *)
  val (vars, _) = pairSyntax.dest_pabs p_t

  val thm0 = REWR_CONV PMATCH_ROW_COND_EX_FULL_DEF t
  (* re-express the operand of the unfolded rhs as a paired abstraction
     over [vars]; presumably that operand is the abstraction under an
     existential -- NOTE(review): confirm against PMATCH_ROW_COND_EX_FULL_DEF *)
  val thm1 = RIGHT_CONV_RULE (RAND_CONV (pairTools.PABS_INTRO_CONV vars)) thm0
  (* split a quantifier over a tuple, if possible *)
  val thm2 = RIGHT_CONV_RULE pairTools.ELIM_TUPLED_QUANT_CONV thm1 handle HOL_ERR _ => thm1
  (* reduce (paired) beta-redexes underneath the quantifiers *)
  val thm3 = RIGHT_CONV_RULE (STRIP_QUANT_CONV (DEPTH_CONV pairLib.GEN_BETA_CONV)) thm2
  val thm4 = RIGHT_CONV_RULE (PURE_REWRITE_CONV [AND_CLAUSES]) thm3
  (* try to drop a remaining quantifier with boolTheory.EXISTS_SIMP
     (presumably the vacuous-existential case) *)
  val thm5 = RIGHT_CONV_RULE (REWR_CONV boolTheory.EXISTS_SIMP) thm4 handle HOL_ERR _ => thm4
in
  thm5
end
715
716
717(***********************************************)
718(* EXHAUSTIVE                                  *)
719(***********************************************)
720
(* [mk_PMATCH_IS_EXHAUSTIVE v rows] builds PMATCH_IS_EXHAUSTIVE v rows. The
   constant's type is instantiated by matching its second argument type
   against the type of [rows]. *)
fun mk_PMATCH_IS_EXHAUSTIVE v rows = let
  val (arg_tys, _) = wfrecUtils.strip_fun_type (type_of PMATCH_IS_EXHAUSTIVE_tm)
  val theta = match_type (el 2 arg_tys) (type_of rows)
  val exh = inst theta PMATCH_IS_EXHAUSTIVE_tm
in
  mk_comb (mk_comb (exh, v), rows)
end

(* splits PMATCH_IS_EXHAUSTIVE v [r1; ...; rn] into (v, [r1, ..., rn]) *)
fun dest_PMATCH_IS_EXHAUSTIVE t =
  case strip_comb t of
      (f, [v, rows]) =>
        if same_const f PMATCH_IS_EXHAUSTIVE_tm
        then (v, #1 (listSyntax.dest_list rows))
        else failwith "dest_PMATCH_IS_EXHAUSTIVE"
    | _ => failwith "dest_PMATCH_IS_EXHAUSTIVE"

fun is_PMATCH_IS_EXHAUSTIVE t = can dest_PMATCH_IS_EXHAUSTIVE t
744
745
746(***********************************************)
747(* Pretty Printing                             *)
748(***********************************************)
749
(* Trace "use pmatch_pp": when false, pmatch_printer declines to print and
   PMATCH terms are left to the default printer. *)
val use_pmatch_pp : bool ref = ref true
val _ = Feedback.register_btrace ("use pmatch_pp", use_pmatch_pp)
752
(* Prepares a destructed row (vars, pat, guard, rh) for printing:
   - bound variables whose names start with "_" are replaced, in pattern,
     guard and rhs, by a fake constant printed as "_", and are dropped from
     the variable tuple;
   - if no other bound variable remains, the tuple becomes a unit-typed
     variable "_" (primed to avoid the row's free variables).
   On any HOL_ERR the row is returned unchanged. *)
fun pmatch_printer_fix_wildcards (vars, pat, guard, rh) = let
  val (wildcards, real_vars) =
      partition varname_starts_with_uscore (pairSyntax.strip_pair vars)
  val fake_name =
      GrammarSpecials.mk_fakeconst_name {fake = "_", original = NONE}
  val fake_subst =
      map (fn wc => wc |-> mk_var (fake_name, type_of wc)) wildcards
  val hide = Term.subst fake_subst
  val vars' =
      case real_vars of
          [] => variant (free_varsl [pat, guard, rh]) (mk_var ("_", oneSyntax.one_ty))
        | _ => pairSyntax.list_mk_pair real_vars
in
  (vars', hide pat, hide guard, hide rh)
end handle HOL_ERR _ => (vars, pat, guard, rh)
773
(* Wildcard munging turns "_" variables into "fake consts", so that the
   pretty-printer does not treat them as variables (giving them a colour,
   etc.). [is_uscV v] recognises such fake constants; it fails if [v] is not
   a variable. *)
fun is_uscV v = let
  val (name, _) = dest_var v
in
  isSome (GrammarSpecials.dest_fakeconst_name name)
end
779
(* User pretty-printer for PMATCH terms, printed as
     case v of row_1 | row_2 | ...
   where each row is printed as
     [vars .|] pat [when guard] => rhs
   - the "vars .|" prefix appears only when the non-wildcard variables of
     the tuple differ from the non-wildcard free variables of the pattern;
     "()" is printed in its place if the tuple has no such variables;
   - "when guard" is omitted when the guard is T.
   Parentheses are added when the right gravity is a precedence above 70,
   or the left gravity is function application.
   Raises UserPP_Failed (so the default printer is used) when the trace
   "use pmatch_pp" is off, or when a HOL_ERR occurs, e.g. because the term
   or one of its rows does not have the expected shape. *)
fun pmatch_printer
    (GS : type_grammar.grammar * term_grammar.grammar)
    (backend : term_grammar.grammar term_pp_types.ppbackend)
    sys
    (ppfns:term_pp_types.ppstream_funs)
    ((pgr,lgr,rgr) : term_pp_types.grav * term_pp_types.grav * term_pp_types.grav)
    d t =
  let
    open Portable term_pp_types smpp
    infix >>
    val _ = if (!use_pmatch_pp) then () else raise term_pp_types.UserPP_Failed
    val {add_string,add_break,ublock,add_newline,ustyle,...} = ppfns
    val (v,rows) = dest_PMATCH t;
    val rows' = map (pmatch_printer_fix_wildcards o dest_PMATCH_ROW_ABS) rows
    (* printer for binder positions; refers to the original [sys] *)
    val bsys =
     fn gravs => fn d => sys {gravs = gravs, depth = d, binderp = true}
    (* shadows [sys]: printer for ordinary (non-binder) positions *)
    val sys =
     fn gravs => fn d => sys {gravs = gravs, depth = d, binderp = false}
    val paren_required =
      (case rgr of
         Prec(p, _) => p > 70
       | _ => false) orelse
      (case lgr of
         Prec(_, n) => n = GrammarSpecials.fnapp_special
       | _ => false)
    val doparen = if paren_required then (fn c => add_string c)
                  else (fn c => nothing)

    fun pp_row (vars, pat, guard, rh) =
      let
        (* print_vars: the tuple's non-wildcard variables differ from those
           of the pattern. print_unit: [vs] is empty (it is already
           filtered), so "()" is printed instead of the tuple. *)
        val (print_vars, print_unit) =
            let fun get_real_vars t = HOLset.filter (fn v => not (is_uscV v orelse varname_starts_with_uscore v)) (FVL [t] empty_tmset)
                val vs =  get_real_vars vars
                val pvs = get_real_vars pat
            in
              if HOLset.isSubset(pvs,vs) andalso HOLset.isSubset(vs,pvs) then (false, false)
              else
                (true, HOLset.find (not o varname_starts_with_uscore) vs = NONE)
            end
        (* NOTE(review): [patsys] is computed but never used; the pattern
           below is always printed with [sys] (binderp = false). Confirm
           whether the pattern was meant to be printed with [patsys]. *)
        val patsys = if print_vars then sys else bsys
      in
        (* presumably registers the tuple's variables as bound while the
           row is printed -- confirm in term_pp_utils *)
        term_pp_utils.record_bvars (pairSyntax.strip_pair vars) (
          ublock PP.INCONSISTENT 5 (
            (if not print_vars then nothing
             else
               let val V = if print_unit then oneSyntax.one_tm else vars
               in
                 bsys (Top, Top, Top) (d - 1) V >>
                 add_string " " >>
                 add_string ".|" >>
                 add_break (1, 0)
               end) >>
            sys (Top, Top, Top) (d - 1) pat >>
            (if aconv guard T then nothing
             else
              add_string " " >> add_string "when" >> add_break (1, 0) >>
              sys (Top, Top, Top) (d - 1) guard) >>
            add_string " " >> add_string "=>" >> add_break (1, 0) >>
            sys (Top, Top, Top) (d - 1) rh))
      end
  in
    (* "case v of" header, then the rows separated by "| " *)
    doparen "(" >>
    ublock PP.CONSISTENT 0 (
       ublock PP.CONSISTENT 2
               (add_string "case" >> add_break(1,2) >>
                sys (Top, Top, Top) (d - 1) v >>
                add_break(1,0) >> add_string "of") >>
       add_break (1, 2) >>
       ublock PP.CONSISTENT 0 (
         smpp.pr_list
           pp_row
           (add_break(1,~2) >> add_string "|" >> add_string " ") rows'
       )
    ) >>
    doparen ")"
  end handle HOL_ERR _ => raise term_pp_types.UserPP_Failed;
856
(* ("PMATCH", pattern, printer) triple that registers pmatch_printer for
   PMATCH applied to one variable x<i> per argument type of PMATCH *)
val userprinter_info = let
  val (arg_tys, _) = strip_fun (type_of PMATCH_tm)
  fun arg_var i ty = mk_var ("x" ^ Int.toString i, ty)
  val pmatch_pattern = list_mk_comb (PMATCH_tm, Lib.mapi arg_var arg_tys)
in
  ("PMATCH", pmatch_pattern, pmatch_printer)
end
864
865(* Enabling pmatch *)
866open parsePMATCH
867
(* Enables the PMATCH "case" syntax and its pretty-printer in the current
   grammar, using the temp_ (non-persistent) grammar functions. *)
val ENABLE_PMATCH_CASES =
    add_pmatch {up = userprinter_info,
                get = term_grammar,
                arule = fn rule => K (Parse.temp_add_rule rule),
                rmtmtok = fn spec => K (Parse.temp_remove_termtok spec),
                add_ptmproc =
                  (fn s => fn pp => K (temp_add_preterm_processor s pp)),
                addup = fn upinfo => K (temp_add_user_printer upinfo)}

(* The same additions as a function on an explicit term grammar. *)
val grammar_add_pmatch =
    add_pmatch {up = userprinter_info,
                get = I,
                arule = term_grammar.add_rule,
                rmtmtok = fn spec => fn g => term_grammar.remove_form_with_tok g spec,
                add_ptmproc = term_grammar.new_preterm_processor,
                addup = term_grammar.add_user_printer}
884
885
886
887end
888