1structure constrFamiliesLib :> constrFamiliesLib =
2struct
3
4open HolKernel Parse boolLib Drule BasicProvers
5open boolLib simpLib patternMatchesSyntax numLib
6
7(***************************************************)
8(* Auxiliary definitions                           *)
9(***************************************************)
10
(* A simpset fragment that contributes exactly the congruence theorems
   [thms] and nothing else. *)
fun cong_ss thms =
  simpLib.SSFRAG {name = NONE, convs = [], rewrs = [], ac = [],
                  filter = NONE, dprocs = [], congs = thms}
19
(* Raise a HOL_ERR attributed to this library, reporting function [f] and
   message [x]. *)
fun failwith f x = raise (mk_HOL_ERR "constrFamiliesLib" f x)

(* Rename every variable of [vs] so that it clashes neither with the
   variables in [used_vs] nor with a variable renamed earlier in [vs].
   The results are returned in the same order as [vs]. *)
fun variants used_vs vs = let
  fun rename ([], _, renamed) = List.rev renamed
    | rename (v :: rest, avoid, renamed) = let
        val fresh = variant avoid v
      in
        rename (rest, fresh :: avoid, fresh :: renamed)
      end
in
  rename (vs, used_vs, [])
end
30
(* list_mk_comb with built-in beta reduction.
   Applies [c] to [args] one argument at a time. Whenever the current head
   is a lambda abstraction, the argument is substituted for the bound
   variable instead of building an application. Otherwise a plain
   combination is built. Raises HOL_ERR if an argument's type does not fit
   (from subst or mk_comb). *)
fun list_mk_comb_subst (c, args) =
  case args of
      [] => c
    | a :: args' =>
        (* Only the destructuring is guarded. Errors raised further down the
           recursion are propagated as they are, instead of being caught
           here and retried with the unreduced term. *)
        (case (SOME (dest_abs c) handle HOL_ERR _ => NONE) of
             SOME (v, body) => list_mk_comb_subst (subst [v |-> a] body, args')
           | NONE => list_mk_comb_subst (mk_comb (c, a), args'))
41
(*--------------------------------------------------------------*)
(* Normalise the type variables of a type so that the type can *)
(* be used as a map key.                                         *)
(*--------------------------------------------------------------*)

(* The type variable obtained from [ty] by Lexis.tyvar_vary. *)
fun next_ty ty = mk_vartype (Lexis.tyvar_vary (dest_vartype ty));

(* Rename the type variables of [ty] to Type.alpha, next_ty Type.alpha, ...
   in the order in which a depth-first, left-to-right walk over [ty] first
   meets them. *)
fun normalise_ty ty = let
  fun walk (state as (seen, fresh)) todo =
    case todo of
        [] => state
      | t :: rest =>
          if not (is_vartype t) then
            walk state (#Args (dest_thy_type t) @ rest)
          else
            (case Binarymap.peek (seen, t) of
                 SOME _ => walk state rest
               | NONE => walk (Binarymap.insert (seen, t, fresh), next_ty fresh) rest)
  val (renaming, _) = walk (Binarymap.mkDict Type.compare, Type.alpha) [ty]
  val theta = Binarymap.foldl (fn (old, new, acc) => (old |-> new) :: acc) [] renaming
in
  Type.type_subst theta ty
end
74
75
(* Keep the outermost type operator of [ty], but replace its arguments by
   the type variables Type.alpha, next_ty Type.alpha, ... (one per
   argument). Raises HOL_ERR if [ty] is a type variable (from dest_type). *)
fun base_ty ty = let
  val (tyop, actual_args) = dest_type ty
  fun generic_args (_, []) = []
    | generic_args (v, _ :: rest) = v :: generic_args (next_ty v, rest)
in
  mk_type (tyop, generic_args (Type.alpha, actual_args))
end
82
83
84(*------------------------*)
85(* Encoding theorem lists *)
86(*------------------------*)
87
(* Encode a list of optional terms as one conjunction.
   Each entry becomes a conjunct labelled "THM_PART", and NONE is encoded
   as T. The result is the theorem equating that labelled conjunction with
   what remains after removing the labels and rewriting with REWRITE_CONV []. *)
fun encode_term_opt_list tl = let
  fun label_part t_opt =
    markerSyntax.mk_label ("THM_PART", Option.getOpt (t_opt, T))
  val labelled = list_mk_conj (List.map label_part tl)
in
  (markerLib.DEST_LABELS_CONV THENC REWRITE_CONV []) labelled
end
95
(* Inverse of encode_term_opt_list at the theorem level.
   Splits a theorem whose conclusion is a conjunction of labelled parts,
   strips the label of each part, and maps parts whose conclusion is T
   back to NONE. *)
fun decode_thm_opt_list combined_thm = let
  fun decode_part part = let
    val unlabelled = CONV_RULE markerLib.DEST_LABEL_CONV part
  in
    if aconv (concl unlabelled) T then NONE else SOME unlabelled
  end
in
  List.map decode_part (CONJUNCTS combined_thm)
end
108
(* Start an interactive proof (proofManagerLib.set_goal) of all the optional
   terms in [tl], encoded as by encode_term_opt_list. *)
fun set_goal_list tl =
  proofManagerLib.set_goal ([], rhs (concl (encode_term_opt_list tl)))

(* Prove all the optional terms in [tl] with a single tactic [tac].
   The tactic works on the label-free, rewritten conjunction produced by
   encode_term_opt_list. The result has one entry per input; entries whose
   term was NONE (or T) come back as NONE. *)
fun prove_list (tl, tac) = let
  val encoding = encode_term_opt_list tl
  val proved = prove (rhs (concl encoding), tac)
in
  decode_thm_opt_list (EQ_MP (GSYM encoding) proved)
end
123
124
125
126(***************************************************)
127(* Constructors                                    *)
128(***************************************************)
129
(* A constructor is a term paired with a list of names, one for each of
   its arguments. *)
datatype constructor = CONSTR of term * (string list)

(* Pair a term with its argument names. *)
fun mk_constructor t args = CONSTR (t, args)

(* A constructor is constant when it has no argument names. *)
fun constructor_is_const (CONSTR (_, names)) = List.null names
137
(* Apply constructor [cr] to fresh argument variables.
   The argument names of the constructor are paired with the leading
   argument types of its term. The resulting variables are renamed away
   from [vs] and from each other. Returns the applied term (beta-reduced by
   list_mk_comb_subst) together with the variables used.
   Fails if more argument names are given than the term takes arguments. *)
fun mk_constructor_term vs (CONSTR (c, args)) = let
  (* only the argument types are needed; the result type is ignored *)
  val (arg_tys, _) = strip_fun (type_of c)
  val _ = if length arg_tys < length args then
            failwith "mk_constructor_term" "too many argument names given"
          else ()
  val typed_args = zip args (List.take (arg_tys, length args))
  val arg_vars = variants vs (List.map mk_var typed_args)
in
  (list_mk_comb_subst (c, arg_vars), arg_vars)
end
150
(* Try to read [t] as an application of constructor [cr]. At most as many
   arguments as [cr] has argument names are stripped from [t]. If the head
   is [cr]'s constant, the head is returned together with the argument
   names paired with the stripped arguments; otherwise NONE. *)
fun match_constructor (CONSTR (cr, args)) t = let
  val (head, actuals) = strip_comb_bounded (List.length args) t
in
  if not (same_const head cr) then NONE
  else SOME (head, zip args actuals)
end
158
159
(* The constructors of one type are usually grouped together. Such a group
   may or may not be exhaustive, i.e. cover every value of the type. *)
type constructorList = {
  cl_type          : hol_type,
  cl_constructors  : constructor list,
  cl_is_exhaustive : bool
}

(* Group [constrs] into a constructorList. All constructors, once applied
   to their arguments, must have the same type, which becomes cl_type.
   Fails if [constrs] is empty or if the types differ. *)
fun mk_constructorList is_exhaustive constrs = let
  val applied = List.map (fn cr => fst (mk_constructor_term [] cr)) constrs
  val ty = case applied of
               [] => failwith "make_constructorList" "no constructors given"
             | t :: _ => type_of t
  val _ = if List.all (fn t => type_of t = ty) applied then ()
          else failwith "make_constructorList" "types of constructors don't match"
in
  { cl_type          = ty,
    cl_constructors  = constrs,
    cl_is_exhaustive = is_exhaustive } : constructorList
end

(* Like mk_constructorList, but takes (term, argument names) pairs. *)
fun make_constructorList is_exhaustive constrs =
  mk_constructorList is_exhaustive
    (List.map (fn (t, names) => mk_constructor t names) constrs)
183
184(***************************************************)
185(* Constructor Families                            *)
186(***************************************************)
187
(* Constructor families are lists of constructors together with a
   case-split constant and additional theorems.
*)

type constructorFamily = {
  constructors  : constructorList,
  case_const    : term,
  one_one_thm   : thm option,
  distinct_thm  : thm option,
  case_split_thm: thm,
  case_cong_thm : thm,
  nchotomy_thm  : thm option
}

(* The rewrites of a family: its injectivity and distinctness theorems,
   conjoined when both are present; TRUTH when neither is. *)
fun constructorFamily_get_rewrites (cf : constructorFamily) =
  case List.mapPartial (fn thm_opt => thm_opt) [#one_one_thm cf, #distinct_thm cf] of
      [] => TRUTH
    | thms => LIST_CONJ thms

(* The simpset fragment of a family: its rewrites plus its case congruence. *)
fun constructorFamily_get_ssfrag (cf : constructorFamily) = let
  val rewrite_frag = simpLib.rewrites [constructorFamily_get_rewrites cf]
  val cong_frag = cong_ss [#case_cong_thm cf]
in
  simpLib.merge_ss [rewrite_frag, cong_frag]
end

(* Whether the family is exhaustive, and its constructors as
   (term, argument names) pairs. *)
fun constructorFamily_get_constructors (cf : constructorFamily) = let
  val {cl_is_exhaustive, cl_constructors, ...} = #constructors cf
in
  (cl_is_exhaustive, List.map (fn CONSTR pair => pair) cl_constructors)
end

fun constructorFamily_get_case_split ({case_split_thm, ...} : constructorFamily) =
  case_split_thm

fun constructorFamily_get_case_cong ({case_cong_thm, ...} : constructorFamily) =
  case_cong_thm

fun constructorFamily_get_nchotomy_thm_opt ({nchotomy_thm, ...} : constructorFamily) =
  nchotomy_thm
229
230(* Test datatype
231val _ = Datatype `test_ty =
232    A
233  | B 'b
234  | C bool 'a bool
235  | D num bool`
236
237val SOME constrL = constructorList_of_typebase ``:('a, 'b) test_ty``
238val case_const = TypeBase.case_const_of ``:('a, 'b) test_ty``
239
240val constrL = make_constructorList false [(``{}:'a set``, []), (``\x:'a. {x}``, ["x"])]
241
242val set_CASE_def = zDefine `
243  set_CASE s c_emp c_sing c_else =
244    (if s = {} then c_emp else (
245     if (FINITE s /\ (CARD s = 1)) then c_sing (CHOICE s) else
246     c_else))`
247
248val case_const = ``set_CASE``
249*)
250
251
(* Term of the injectivity theorem of a constructor list. For every
   non-constant constructor C it contains the conjunct
     !xs ys. (C xs = C ys) = (x1 = y1 /\ ... /\ xn = yn).
   NONE if all constructors are constants. *)
fun mk_one_one_thm_term_opt (constrL : constructorList) = let
  fun one_one cr = let
    val (left_t, left_vs) = mk_constructor_term [] cr
    val (right_t, right_vs) = mk_constructor_term left_vs cr
    val arg_eqs = list_mk_conj (List.map mk_eq (zip left_vs right_vs))
  in
    list_mk_forall (left_vs @ right_vs, mk_eq (mk_eq (left_t, right_t), arg_eqs))
  end
in
  case List.filter (not o constructor_is_const) (#cl_constructors constrL) of
      [] => NONE
    | crs => SOME (list_mk_conj (List.map one_one crs))
end
268
269
(* Term of the distinctness theorem. For every ordered pair (C1, C2) of
   constructors whose terms are not alpha-equivalent it contains
     !xs ys. ~(C1 xs = C2 ys).
   Both orientations of each pair are included. NONE if there is no such
   pair. *)
fun mk_distinct_thm_term_opt (constrL : constructorList) = let
  val crs = #cl_constructors constrL
  fun differ (CONSTR (c1, _), CONSTR (c2, _)) = not (aconv c1 c2)
  val pairs = List.concat (List.map (fn x => List.map (fn y => (x, y)) crs) crs)
  fun distinct (cr1, cr2) = let
    val (t1, vs1) = mk_constructor_term [] cr1
    val (t2, vs2) = mk_constructor_term vs1 cr2
  in
    list_mk_forall (vs1 @ vs2, mk_neg (mk_eq (t1, t2)))
  end
in
  case List.map distinct (List.filter differ pairs) of
      [] => NONE
    | ts => SOME (list_mk_conj ts)
end
288
289
(* Term of the case-split theorem
     !ff x. ff x = case_const x (\xs1. ff (C1 xs1)) ... (\xsn. ff (Cn xsn))
   For a non-exhaustive constructor list, one extra final argument
   (\x. ff x) handles all remaining values. *)
fun mk_case_expand_thm_term case_const (constrL : constructorList) = let
  val (arg_tys, res_ty) = strip_fun (type_of case_const)
  val split_ty = hd arg_tys
  val x = mk_var ("x", split_ty)
  val ff = mk_var ("ff", split_ty --> res_ty)
  fun branch cr = let
    val (applied, vs) = mk_constructor_term [ff, x] cr
  in
    list_mk_abs (vs, mk_comb (ff, applied))
  end
  val catch_all = if #cl_is_exhaustive constrL then []
                  else [mk_abs (x, mk_comb (ff, x))]
  val branches = List.map branch (#cl_constructors constrL) @ catch_all
in
  list_mk_forall ([ff, x],
    mk_eq (mk_comb (ff, x), list_mk_comb (case_const, x :: branches)))
end
313
314
(* Build the term of the congruence theorem for [case_const] with respect
   to the constructor list [constrL].
   The right-hand split argument is a variable x. The left-hand one is a
   variant of x (presumably x'). The remaining left arguments are fresh
   variables f1, f2, ...; the right ones are variants of those (e.g. f1').
   Schematically, for constructors C1 .. Cn:

     !x' f1 .. fn x f1' .. fn'.
       (x' = x) ==>
       (!xs1. (x = C1 xs1) ==> (f1 xs1 = f1' xs1)) ==>
       ... ==>
       (!xsn. (x = Cn xsn) ==> (fn xsn = fn' xsn)) ==>
       (case_const x' f1 .. fn = case_const x f1' .. fn')

   For a non-exhaustive list the case constant has one more (catch-all)
   argument, giving the extra premise
       !x. (!xsn. ~(x = Cn xsn)) /\ ... /\ (!xs1. ~(x = C1 xs1)) ==>
           (f x = f' x)
   Raises HOL_ERR if the arities of case_const and constrL do not fit. *)
fun mk_case_const_cong_thm_term case_const (constrL : constructorList) = let
  val (arg_tys, res_ty) = strip_fun (type_of case_const)

  (* Variables for both sides: (split :: branch functions). *)
  val (args_l, args_r) = let
    (* fresh variables "f1", "f2", ... for the branch arguments of
       case_const, renamed away from [avoid] and from each other *)
    fun mk_args avoid = let
      fun mk_arg (a_ty, (i, avoid, vs)) =
        let
          val v = variant avoid (mk_var ("f"^(int_to_string i), a_ty))
        in
          (i+1, v::avoid, v::vs)
        end
      val (_, _, vs_rev) = foldl mk_arg (1, avoid, []) (tl arg_tys)
    in
     List.rev vs_rev
    end

    val r0 = mk_var ("x", hd arg_tys)
    val l0 = variant [r0] r0
    val args_l = mk_args [r0, l0]
    val args_r = mk_args (r0::l0::args_l)
  in
    (l0::args_l, r0::args_r)
  end

  (* premise relating the two split arguments *)
  val cong_0 = mk_eq (hd args_l, hd args_r)
  (* conclusion: both case expressions are equal *)
  val base = mk_eq (
               list_mk_comb (case_const, args_l),
               list_mk_comb (case_const, args_r))

  (*
    fun extract n =
      (el n (#cl_constructors constrL),
       el (n+1) args_l,
       el (n+1) args_r)

    val (CONSTR (c, vns), al, ar) = extract 2

  *)
  (* one premise per branch argument, in constructor order *)
  val congs_main = let
    (* Variables named [vns] whose types are the successive domain types of
       the function type [a_ty], renamed away from [avoid]. *)
    fun mk_arg_vars acc avoid (a_ty, vns) = case vns of
        [] => List.rev acc
      | (vn::vns') => let
          val (_, atys) = dest_type a_ty
          val v = variant avoid (mk_var (vn, hd atys))
        in
           mk_arg_vars (v::acc) (v::avoid) (el 2 atys, vns')
        end

   (* [neg_pres] collects "x is not built by C" facts for the constructors
      already processed; they form the premise of the catch-all case. *)
   fun process_all acc neg_pres crs als ars =
     case (crs, als, ars) of
        ([], [], []) => List.rev acc
      | ([], [al], [ar]) => let
          (* NOTE(review): the bound variable here is named "x" and avoids
             only al and ar. If its type is the split type (presumably),
             it is the same variable as the right-hand split argument, so
             list_mk_forall below also binds the x occurring in neg_pres. *)
          val arg_ts = mk_arg_vars [] [al, ar] (type_of al, ["x"])
          val eq_t = mk_eq (list_mk_comb (al, arg_ts),
                            list_mk_comb (ar, arg_ts))


          val pre_t = list_mk_conj neg_pres
          val t_full = list_mk_forall (arg_ts,  mk_imp (pre_t, eq_t))
        in
          List.rev (t_full :: acc)
        end
      | ((CONSTR (c, vns))::crs', al::als', ar::ars') => let
          val arg_ts =  mk_arg_vars [] [al, ar] (type_of al, vns)
          val eq_t = mk_eq (list_mk_comb (al, arg_ts),
                            list_mk_comb (ar, arg_ts))

          (* the right-hand split argument is built by this constructor *)
          val pre_t = mk_eq (hd args_r, list_mk_comb (c, arg_ts))
          val t_full = list_mk_forall (arg_ts,  mk_imp (pre_t, eq_t))
          val t_exp = list_mk_forall (arg_ts,  mk_neg pre_t)
        in
          process_all (t_full::acc) (t_exp::neg_pres) crs' als' ars'
        end
       | _ => failwith "" "Something is wrong with the constructors/case constant. Wrong arity somewhere?"
    in
      process_all [] [] (#cl_constructors constrL) (tl args_l) (tl args_r)
    end

in
  list_mk_forall (args_l @ args_r,
     list_mk_imp (cong_0 :: congs_main, base))
end
397
398
(* Term of the exhaustiveness theorem
     !x. (?xs1. x = C1 xs1) \/ ... \/ (?xsn. x = Cn xsn)
   Only produced for exhaustive constructor lists; NONE otherwise. *)
fun mk_nchotomy_thm_term_opt (constrL : constructorList) =
  if #cl_is_exhaustive constrL then let
    val x = mk_var ("x", #cl_type constrL)
    fun disjunct cr = let
      val (applied, vs) = mk_constructor_term [x] cr
    in
      list_mk_exists (vs, mk_eq (x, applied))
    end
  in
    SOME (mk_forall (x, list_mk_disj (List.map disjunct (#cl_constructors constrL))))
  end
  else NONE;
415
416
(* The proof obligations of a constructor family, in the order
   [one_one, distinct, case_split, case_cong, nchotomy]. Obligations that
   do not apply are NONE. *)
fun mk_constructorFamily_terms case_const constrL =
  [mk_one_one_thm_term_opt constrL,
   mk_distinct_thm_term_opt constrL,
   SOME (mk_case_expand_thm_term case_const constrL),
   SOME (mk_case_const_cong_thm_term case_const constrL),
   mk_nchotomy_thm_term_opt constrL]

(* All proof obligations of a family as one label-free conjunction. *)
fun get_constructorFamily_proofObligations (constrL, case_const) =
  rhs (concl (encode_term_opt_list (mk_constructorFamily_terms case_const constrL)))

(* Set up the proof obligations of a family as an interactive goal. *)
fun set_constructorFamily (constrL, case_const) =
  set_goal_list (mk_constructorFamily_terms case_const constrL)

(* Build a constructor family, discharging all its obligations with the
   single tactic [tac]. The case-split and congruence theorems are
   mandatory (valOf). *)
fun mk_constructorFamily (constrL, case_const, tac) = let
  val proved = prove_list (mk_constructorFamily_terms case_const constrL, tac)
  fun nth n = el n proved
in
  { constructors   = constrL,
    case_const     = case_const,
    one_one_thm    = nth 1,
    distinct_thm   = nth 2,
    case_split_thm = valOf (nth 3),
    case_cong_thm  = valOf (nth 4),
    nchotomy_thm   = nth 5 } : constructorFamily
end
450
451
452(***************************************************)
453(* Connection to typebase                          *)
454(***************************************************)
455
456
(* Extract the constructors of [ty] from the TypeBase, or NONE if the
   TypeBase lists no constructors for it. Instead of using
   TypeBase.constructors_of directly, the nchotomy theorem is taken apart,
   so that its variable names become the default argument names. *)
fun constructorList_of_typebase ty =
  if null (TypeBase.constructors_of ty) then NONE
  else let
    val nchotomy_body = snd (dest_forall (concl (TypeBase.nchotomy_of ty)))
    fun constr_of_disjunct d = let
      val (c, args) = strip_comb (rhs (snd (strip_exists d)))
    in
      CONSTR (c, List.map (fst o dest_var) args)
    end
  in
    SOME ({ cl_type          = ty,
            cl_constructors  = List.map constr_of_disjunct (strip_disj nchotomy_body),
            cl_is_exhaustive = true } : constructorList)
  end
480
(* Build the constructor family of a datatype from the TypeBase.
   The constructors come from constructorList_of_typebase and the case
   constant from TypeBase.case_const_of. All proof obligations are
   discharged with the TypeBase's distinctness, injectivity,
   case-definition and case-congruence theorems.
   Raises HOL_ERR if [ty] has no constructors in the TypeBase, or if the
   proof fails. *)
fun constructorFamily_of_typebase ty = let
  val crL = valOf (constructorList_of_typebase ty)
    handle Option => failwith "constructorList_of_typebase" "not a datatype"
  val case_split_tm = TypeBase.case_const_of ty
  val thm_distinct = TypeBase.distinct_of ty
  (* if TypeBase.one_one_of raises HOL_ERR, fall back to TRUTH *)
  val thm_one_one = TypeBase.one_one_of ty handle HOL_ERR _ => TRUTH
  val thm_case = TypeBase.case_def_of ty
  val thm_case_cong = TypeBase.case_cong_of ty

  (*  set_constructorFamily (crL, case_split_tm) *)
  (* Simplify all obligations first. Then split the remaining goals on x,
     the variable name used for the split value by the term builders of
     this library, and finish with the case definition. *)
  val cf = mk_constructorFamily (crL, case_split_tm,
    SIMP_TAC std_ss [thm_distinct, thm_one_one, thm_case_cong] THEN
    REPEAT STRIP_TAC THEN (
      Cases_on `x` THEN
      SIMP_TAC std_ss [thm_distinct, thm_one_one, thm_case]
    )
  )
in
  cf
end
501
502
503(***************************************************)
504(* Collections of constructorFamilies +            *)
505(* extra matching info                             *)
506(***************************************************)
507
(* How well a constructor family (or a hand-written function) fits a
   column of patterns. *)
type matchcol_stats = {
  colstat_missed_rows : int,
     (* number of rows in the column that are neither constructor
        applications nor bound variables *)

  colstat_cases : int,
     (* number of cases produced *)

  colstat_missed_constr : int
     (* number of constructors of the family that do not occur in the
        column *)
}

(* Strict preference between two stats, used as the ordering for sorting:
   fewer missed rows first, then fewer cases, and finally more missed
   constructors. *)
fun matchcol_stats_compare
  (st1 : matchcol_stats)
  (st2 : matchcol_stats) =
  if #colstat_missed_rows st1 <> #colstat_missed_rows st2 then
    #colstat_missed_rows st1 < #colstat_missed_rows st2
  else if #colstat_cases st1 <> #colstat_cases st2 then
    #colstat_cases st1 < #colstat_cases st2
  else
    #colstat_missed_constr st1 > #colstat_missed_constr st2
534
535
(* A compile function receives a column of patterns, each paired with the
   list of its bound variables. It may return a case-split theorem, the
   number of cases it introduces and a simpset fragment. NONE (or a
   HOL_ERR, which callers in this file catch) means it does not apply. *)
type pmatch_compile_fun = (term list * term) list -> (thm * int * simpLib.ssfrag) option

(* Like pmatch_compile_fun, but produces an nchotomy theorem together with
   its number of cases. *)
type pmatch_nchotomy_fun = (term list * term) list -> (thm * int) option

(* Cache of constructor families built from the TypeBase, keyed by base
   type (see lookup_typeBase_constructorFamily). *)
val typeConstrFamsDB = ref (TypeNet.empty : constructorFamily TypeNet.typenet)

(* Database driving pattern-match compilation. *)
type pmatch_compile_db = {
  (* compile functions, tried in list order *)
  pcdb_compile_funs  : pmatch_compile_fun list,
  (* nchotomy functions, tried in list order *)
  pcdb_nchotomy_funs : pmatch_nchotomy_fun list,
  (* user-registered families, keyed by normalised type *)
  pcdb_constrFams    : (constructorFamily list) TypeNet.typenet,
  (* extra simplification support used during compilation *)
  pcdb_ss            : simpLib.ssfrag
}

(* The database with no functions, no families and an empty rewrite fragment. *)
val empty : pmatch_compile_db = {
  pcdb_compile_funs = [],
  pcdb_nchotomy_funs = [],
  pcdb_constrFams = TypeNet.empty,
  pcdb_ss = (simpLib.rewrites [])
}

(* The global database updated by the register_ functions below. *)
val thePmatchCompileDB = ref empty
557
(* The TypeBase constructor family for the base type of [ty] (see base_ty).
   A family is built on first request and cached in typeConstrFamsDB.
   Returns SOME (base type, family), or NONE if any HOL_ERR occurs (for
   example when [ty] is a type variable or not a datatype). *)
fun lookup_typeBase_constructorFamily ty = let
  val b_ty = base_ty ty
  fun build_and_cache () = let
    val cf = constructorFamily_of_typebase b_ty
  in
    typeConstrFamsDB := TypeNet.insert (!typeConstrFamsDB, b_ty, cf);
    cf
  end
  val cf = TypeNet.find (!typeConstrFamsDB, b_ty)
           handle NotFound => build_and_cache ()
in
  SOME (b_ty, cf)
end handle HOL_ERR _ => NONE
571
572
(* Measure how well the constructor family [cf] fits the pattern column
   [col], a list of (bound variables, pattern) pairs. *)
fun measure_constructorFamily (cf : constructorFamily) col = let
  val {cl_constructors, cl_is_exhaustive, ...} = #constructors cf
  val crs = List.map (fn (CONSTR (c, _)) => c) cl_constructors

  fun count p xs = List.foldl (fn (x, n) => if p x then n + 1 else n) 0 xs
  fun head_of p = fst (strip_comb p)

  (* A row is missed unless it is a bound variable or is headed by one of
     the family's constructor constants. *)
  fun row_is_missed (vs, p) =
    not (is_var p andalso mem p vs) andalso
    ((not (List.exists (same_const (head_of p)) crs)) handle HOL_ERR _ => true)

  (* A constructor is missed if no row of the column is headed by it. *)
  fun constr_is_missed c =
    not (List.exists
           (fn (_, p) => same_const (head_of p) c handle HOL_ERR _ => false)
           col)

  (* a non-exhaustive family needs one extra catch-all case *)
  val cases = List.length cl_constructors + (if cl_is_exhaustive then 0 else 1)
in
  { colstat_missed_rows   = count row_is_missed col,
    colstat_missed_constr = count constr_is_missed crs,
    colstat_cases         = cases }
end
607
(* All constructor families applicable to [ty], each paired with its key
   type: first those registered in [db] whose key matches [ty], then the
   TypeBase family (if any). Families whose case constant or constructors
   are constants with names ending in "<-old" are dropped. *)
fun lookup_constructorFamilies_for_type (db : pmatch_compile_db) ty = let
  val registered =
    List.concat
      (List.map (fn (key_ty, cfs) => List.map (fn cf => (key_ty, cf)) cfs)
                (TypeNet.match (#pcdb_constrFams db, ty)))
  val from_typebase =
    case lookup_typeBase_constructorFamily ty of
        NONE => []
      | SOME entry => [entry]

  fun is_old_const c =
    String.isSuffix "<-old" (fst (dest_const c)) handle HOL_ERR _ => false

  fun is_old_fam (_, cf : constructorFamily) =
    List.exists (is_old_const o fst) (snd (constructorFamily_get_constructors cf))
    orelse is_old_const (#case_const cf)
in
  List.filter (not o is_old_fam) (registered @ from_typebase)
end
635
(* Pick the best constructor family for the pattern column [col].
   Fails if [col] is empty or consists only of bound variables. When
   [force_exh] is set, only families with an nchotomy theorem qualify.
   Families that miss some row are discarded. The best remaining one
   according to matchcol_stats_compare is returned with its stats. *)
fun lookup_constructorFamily force_exh (db : pmatch_compile_db) col = let
  fun is_bound_var (vs, c) = is_var c andalso Lib.mem c vs
  val _ = if List.null col then
            failwith "constructorFamiliesLib" "lookup_constructorFamilies: null col"
          else if List.all is_bound_var col then
            failwith "constructorFamiliesLib" "lookup_constructorFamilies: var col"
          else ()

  val candidates = lookup_constructorFamilies_for_type db (type_of (snd (hd col)))
  val eligible =
    if force_exh then List.filter (fn (_, cf) => isSome (#nchotomy_thm cf)) candidates
    else candidates

  val weighted =
    List.map (fn (tycf as (_, cf)) => (tycf, measure_constructorFamily cf col)) eligible
  val usable = List.filter (fn (_, w) => #colstat_missed_rows w = 0) weighted
  fun better (_, w1) (_, w2) = matchcol_stats_compare w1 w2
in
  case sort better usable of
      [] => NONE
    | best :: _ => SOME best
end;
662
663
(* Core of column compilation. Consults the registered compile functions
   (the first SOME wins; a HOL_ERR counts as failure) and the best
   constructor family. A compile function is preferred only if it yields
   strictly fewer cases than the family. Returns the chosen
   (case-split theorem, number of cases, simpset fragment), if any,
   together with the family used (NONE when a compile function was chosen). *)
fun pmatch_compile_db_compile_aux (db : pmatch_compile_db) col =
  if List.null col then failwith "pmatch_compile_db_compile" "col 0"
  else let
    val fun_res =
      get_first (fn f => f col handle HOL_ERR _ => NONE) (#pcdb_compile_funs db)
    val cf_res = lookup_constructorFamily false db col

    (* Use the family: instantiate its case-split theorem to the column's
       type and add its rewrites to the database's simpset fragment. *)
    fun from_family (tycf as (ty, cf)) w = let
      val theta = match_type ty (type_of (snd (hd col)))
      val split_thm = INST_TYPE theta (constructorFamily_get_case_split cf)
      val ss = merge_ss [#pcdb_ss db,
                         simpLib.rewrites [constructorFamily_get_rewrites cf]]
    in
      (SOME (split_thm, #colstat_cases w, ss), SOME tycf)
    end
  in
    case (fun_res, cf_res) of
        (NONE, NONE) => (NONE, NONE)
      | (SOME r, NONE) => (SOME r, NONE)
      | (NONE, SOME (tycf, w)) => from_family tycf w
      | (SOME (r as (_, c_no, _)), SOME (tycf, w)) =>
          if c_no < #colstat_cases w then (SOME r, NONE) else from_family tycf w
  end

(* Only the compilation result of pmatch_compile_db_compile_aux. *)
fun pmatch_compile_db_compile db col =
  #1 (pmatch_compile_db_compile_aux db col)

(* Only the constructor family used by pmatch_compile_db_compile_aux. *)
fun pmatch_compile_db_compile_cf db col =
  Option.map snd (#2 (pmatch_compile_db_compile_aux db col));
694
695(*
696fun pmatch_compile_db_compile_nchotomy db col = (
697  if (List.null col) then failwith "pmatch_compile_db_compile_cf" "col 0" else
698  case (get_first (fn f => f col handle HOL_ERR _ => NONE) (#pcdb_nchotomy_funs db)) of
699    SOME r => r | NONE => (
700      case (lookup_constructorFamilies true db col) of
701          NONE => NONE
702        | SOME (_, cf) => #nchotomy_thm cf))
703*)
704
(* Find an nchotomy theorem for the pattern column [col]. Consults the
   registered nchotomy functions (the first SOME wins; a HOL_ERR counts as
   failure) and the best family that has an nchotomy theorem. *)
fun pmatch_compile_db_compile_nchotomy (db : pmatch_compile_db) col =
  if List.null col then failwith "pmatch_compile_db_compile_nchotomy" "col 0"
  else let
    val fun_res =
      get_first (fn f => f col handle HOL_ERR _ => NONE) (#pcdb_nchotomy_funs db)
    val cf_res = lookup_constructorFamily true db col
    fun family_thm ((_, cf : constructorFamily), _) = #nchotomy_thm cf
  in
    case (fun_res, cf_res) of
        (NONE, NONE) => NONE
      | (SOME (thm, _), NONE) => SOME thm
      | (NONE, SOME fam) => family_thm fam
      | (SOME (thm, _), SOME (fam as (_, w))) =>
          (* NOTE(review): lookup_constructorFamily only returns families
             with colstat_missed_rows = 0, so this test appears never to
             hold and the family's theorem is always chosen here. *)
          if 0 < #colstat_missed_rows w then SOME thm else family_thm fam
  end;
720
(* Try to read [t] as a constructor application, using the constructors of
   all families applicable to its type. The first constructor that
   match_constructor accepts determines the result. *)
fun pmatch_compile_db_dest_constr_term (db : pmatch_compile_db) t = let
  val fams = lookup_constructorFamilies_for_type db (type_of t)
  val constrs =
    List.concat (List.map (fn (_, cf) => #cl_constructors (#constructors cf)) fams)
in
  first_opt (fn _ => fn cr => match_constructor cr t) constrs
end
728
729
730(***************************************************)
731(* updating dbs                                    *)
732(***************************************************)
733
(* Extend the simpset fragment of [db] with [ss]. *)
fun pmatch_compile_db_add_ssfrag (db : pmatch_compile_db) ss = let
  val {pcdb_compile_funs, pcdb_nchotomy_funs, pcdb_constrFams, pcdb_ss} = db
in
  { pcdb_compile_funs  = pcdb_compile_funs,
    pcdb_nchotomy_funs = pcdb_nchotomy_funs,
    pcdb_constrFams    = pcdb_constrFams,
    pcdb_ss            = simpLib.merge_ss [ss, pcdb_ss] } : pmatch_compile_db
end

(* Extend the simpset fragment of [db] with congruence theorems [thms]. *)
fun pmatch_compile_db_add_congs db thms =
  pmatch_compile_db_add_ssfrag db (cong_ss thms);

(* Like pmatch_compile_db_add_ssfrag, applied to thePmatchCompileDB. *)
fun pmatch_compile_db_register_ssfrag ss =
  thePmatchCompileDB := pmatch_compile_db_add_ssfrag (!thePmatchCompileDB) ss;

(* Like pmatch_compile_db_add_congs, applied to thePmatchCompileDB. *)
fun pmatch_compile_db_register_congs thms =
  pmatch_compile_db_register_ssfrag (cong_ss thms)
749
(* Add a compile function to [db]. It is placed first and is therefore
   tried before the existing ones. *)
fun pmatch_compile_db_add_compile_fun (db : pmatch_compile_db) cf = let
  val {pcdb_compile_funs, pcdb_nchotomy_funs, pcdb_constrFams, pcdb_ss} = db
in
  { pcdb_compile_funs  = cf :: pcdb_compile_funs,
    pcdb_nchotomy_funs = pcdb_nchotomy_funs,
    pcdb_constrFams    = pcdb_constrFams,
    pcdb_ss            = pcdb_ss } : pmatch_compile_db
end

(* Like pmatch_compile_db_add_compile_fun, applied to thePmatchCompileDB. *)
fun pmatch_compile_db_register_compile_fun cf =
  thePmatchCompileDB := pmatch_compile_db_add_compile_fun (!thePmatchCompileDB) cf

(* Add an nchotomy function to [db]. It is placed first and is therefore
   tried before the existing ones. *)
fun pmatch_compile_db_add_nchotomy_fun (db : pmatch_compile_db) nf = let
  val {pcdb_compile_funs, pcdb_nchotomy_funs, pcdb_constrFams, pcdb_ss} = db
in
  { pcdb_compile_funs  = pcdb_compile_funs,
    pcdb_nchotomy_funs = nf :: pcdb_nchotomy_funs,
    pcdb_constrFams    = pcdb_constrFams,
    pcdb_ss            = pcdb_ss } : pmatch_compile_db
end

(* Like pmatch_compile_db_add_nchotomy_fun, applied to thePmatchCompileDB. *)
fun pmatch_compile_db_register_nchotomy_fun f =
  thePmatchCompileDB := pmatch_compile_db_add_nchotomy_fun (!thePmatchCompileDB) f
769
(* Register constructor family [cf] in [db] under the normalised type of
   its constructors, in front of any family already stored for that type.
   The family's rewrites and congruence are also added to the simpset
   fragment. *)
fun pmatch_compile_db_add_constrFam (db : pmatch_compile_db) (cf : constructorFamily) = let
  val {pcdb_compile_funs, pcdb_nchotomy_funs, pcdb_constrFams, pcdb_ss} = db
  val key = normalise_ty (#cl_type (#constructors cf))
  val existing = TypeNet.find (pcdb_constrFams, key) handle NotFound => []
in
  { pcdb_compile_funs  = pcdb_compile_funs,
    pcdb_nchotomy_funs = pcdb_nchotomy_funs,
    pcdb_constrFams    = TypeNet.insert (pcdb_constrFams, key, cf :: existing),
    pcdb_ss            = merge_ss [constructorFamily_get_ssfrag cf, pcdb_ss] } : pmatch_compile_db
end

(* Like pmatch_compile_db_add_constrFam, applied to thePmatchCompileDB. *)
fun pmatch_compile_db_register_constrFam cf =
  thePmatchCompileDB := pmatch_compile_db_add_constrFam (!thePmatchCompileDB) cf

(* Forget the families registered in [db] for the normalised form of [ty]
   by mapping it to the empty list. Only pcdb_constrFams is changed; the
   TypeBase cache typeConstrFamsDB is not touched. *)
fun pmatch_compile_db_remove_type (db : pmatch_compile_db) ty = let
  val {pcdb_compile_funs, pcdb_nchotomy_funs, pcdb_constrFams, pcdb_ss} = db
in
  { pcdb_compile_funs  = pcdb_compile_funs,
    pcdb_nchotomy_funs = pcdb_nchotomy_funs,
    pcdb_constrFams    = TypeNet.insert (pcdb_constrFams, normalise_ty ty, []),
    pcdb_ss            = pcdb_ss } : pmatch_compile_db
end

(* Like pmatch_compile_db_remove_type, applied to thePmatchCompileDB. *)
fun pmatch_compile_db_clear_type ty =
  thePmatchCompileDB := pmatch_compile_db_remove_type (!thePmatchCompileDB) ty
803
804
805
806(***************************************************)
(* compilation funs                                *)
808(***************************************************)
809
(* |- (if x = c then ff x else ff x) = (if x = c then ff c else ff x)
   literals_compile_fun uses this to replace the argument of the "then"
   branch by the literal being tested. Proved by case analysis on x = c. *)
val COND_CONG_APPLY = prove (``(if (x:'a) = c then (ff x):'b else ff x) =
  (if x = c then ff c else ff x)``,
Cases_on `x = c` THEN ASM_REWRITE_TAC[])
813
814
(* Compile function for columns of literals.
   Every row pattern must be either a variable or a "literal", i.e. a term
   containing none of the row's bound variables. Otherwise HOL_ERR
   ("extract_literal") is raised. HOL_ERR ("no lits") is also raised if
   the column contains no literal.
   For the distinct literals l1 .. ln, in order of first occurrence, the
   result is SOME (thm, n + 1, ss) where thm is
     |- !ff x. ff x = literal_case (\x. if x = l1 then ff l1 else
                                        ... if x = ln then ff ln else ff x) x
   and ss holds a congruence theorem, proved by METIS, for the right-hand
   side of that expansion. *)
fun literals_compile_fun (col:(term list * term) list) = let

  (* Accumulate the distinct literals of the column (newest first). *)
  fun extract_literal ((vs, c), (tl, ts)) = let
     val vars = FVL [c] empty_tmset
     val is_lit = not (List.exists (fn v => HOLset.member (vars, v)) vs)
  in
    if is_lit then (
         if (HOLset.member(ts,c)) then
            (tl, ts)
         else
            ((c::tl), HOLset.add(ts,c))
    ) else
      (if is_var c then (tl, ts) else failwith "" "extract_literal")
  end

  val (lits_rev, _) = List.foldl extract_literal ([], empty_tmset) col
  val _ = if (List.null lits_rev) then (failwith "" "no lits") else ()
  val lits = List.rev lits_rev
  (* one case per literal plus the catch-all *)
  val cases_no = List.length lits + 1

  (* the result type of ff is a fresh type variable *)
  val rty = gen_tyvar ()
  val lit_ty = type_of (snd (List.hd col))
  val split_arg = mk_var ("x", lit_ty)
  val split_fun = mk_var ("ff", lit_ty --> rty)
  val arg = mk_comb (split_fun, split_arg)

  (* |- ff x = if x = l then ff l else <expansion for the remaining lits>,
     ending in ff x *)
  fun mk_expand_thm lits = case lits of
      [] => REFL arg
    | (l :: lits') => let
         val b = mk_eq (split_arg, l)
         val thm0 = GSYM (ISPEC arg (SPEC b COND_ID))
         val thm1 = CONV_RULE (RHS_CONV (REWR_CONV COND_CONG_APPLY)) thm0
         val thm2a = mk_expand_thm lits'
         val thm2 = CONV_RULE (RHS_CONV (RAND_CONV (K thm2a))) thm1
      in
         thm2
      end

  val thm0 = mk_expand_thm lits
  (* wrap the expansion into literal_case: |- rhs = literal_case (\x. rhs) x *)
  val thm1 = let
    val thm0_rhs = rhs (concl thm0)
    val thm1a = GSYM (ISPECL [mk_abs(split_arg, thm0_rhs), split_arg] literal_case_THM)
    val thm1 = CONV_RULE (LHS_CONV BETA_CONV) thm1a
  in
    thm1
  end
  val thm2 = TRANS thm0 thm1
  val thm3 = GEN split_fun (GEN split_arg thm2)


  (* Congruence for the expansion: each "ff li" (and the final "ff x") is
     replaced by genvars a_i / b_i on the two sides. The premises are
       (x = li /\ ~(li = l(i-1)) /\ ... /\ ~(li = l1)) ==> (a_i = b_i)
     for each literal, and
       (~(x = ln) /\ ... /\ ~(x = l1)) ==> (a = b)
     for the catch-all. *)
  val cong_thm = let
     fun mk_lits_preconds (sua, sub, c_tms) pre lits =
       case lits of
           [] => let
             val negs = map (fn pl => mk_neg (mk_eq (split_arg, pl))) pre
             val a = list_mk_conj negs
             val sf = mk_comb (split_fun, split_arg)
             val va = genvar (type_of sf)
             val vb = genvar (type_of sf)
             val c = mk_eq (va, vb)

             val new_p = mk_imp (a, c)

           in ((sf |-> va)::sua, (sf |-> vb)::sub, new_p::c_tms) end
         | (l::lits') => let
             val negs = map (fn pl => mk_neg (mk_eq (l, pl))) pre
             val eq = mk_eq (split_arg, l)
             val a = list_mk_conj (eq::negs)

             val sf = mk_comb (split_fun, l)
             val va = genvar (type_of sf)
             val vb = genvar (type_of sf)
             val c = mk_eq (va, vb)

             val new_p = mk_imp (a, c)
         in
            (mk_lits_preconds ((sf |-> va)::sua, (sf |-> vb)::sub, new_p::c_tms) (l::pre) lits')
         end

    val (sua, sub, c_tms) = mk_lits_preconds ([], [], []) [] lits
    val tt00 = rhs (concl thm0)

    val tt0a = subst sua tt00
    val tt0b = subst sub tt00
    val tt0 = mk_eq (tt0a, tt0b)
    (* premises in literal order, catch-all last *)
    val tt1 = list_mk_imp (List.rev c_tms, tt0)
    val thm1 = prove(tt1, metisLib.METIS_TAC[])
  in
    thm1
  end
in
  SOME (thm3, cases_no, cong_ss [cong_thm])
end

val _ = pmatch_compile_db_register_compile_fun literals_compile_fun
910
911
912(***************************************************)
913(* nchotomy funs                                   *)
914(***************************************************)
915
(* Nchotomy function for columns of literals.
   Rows must be variables or literals (terms without the row's bound
   variables). For the distinct literals l1 .. ln, in the order returned
   by HOLset.listItems, it proves
     |- !x. (x = l1) \/ ... \/ (x = ln) \/
            ?y. (x = y) /\ ~(y = l1) /\ ... /\ ~(y = ln)
   and returns SOME (thm, n + 1). Any HOL_ERR (a row that is neither a
   literal nor a variable, no literals at all, or a failed proof) gives NONE.
   NOTE(review): a literal with a free variable named x or y would clash
   with the quantified variables; presumably the proof then fails and NONE
   is returned. *)
fun literals_nchotomy_fun (col:(term list * term) list) = let
  fun extract_literal ((vs, c), ts) = let
     val vars = FVL [c] empty_tmset
     val is_lit = not (List.exists (fn v => HOLset.member (vars, v)) vs)
  in
    if is_lit then HOLset.add(ts,c) else
      (if is_var c then ts else failwith "" "extract_literal")
  end

  val ts = List.foldl extract_literal empty_tmset col
  val lits = HOLset.listItems ts
  (* one case per literal plus the catch-all *)
  val cases_no = List.length lits + 1
  val _ = if (List.null lits) then (failwith "" "no lits") else ()

  val lit_ty = type_of (snd (List.hd col))
  val split_arg = mk_var ("x", lit_ty)
  val wc_arg = mk_var ("y", lit_ty)

  (* x = li for every literal *)
  val lit_tms = List.map (fn l => mk_eq (split_arg, l)) lits
  (* catch-all disjunct: ?y. (x = y) /\ ~(y = l1) /\ ... /\ ~(y = ln) *)
  val wc_tm = let
    val not_tms =
        List.map (fn l => mk_neg (mk_eq (wc_arg, l))) lits
    val eq_tm = mk_eq (split_arg, wc_arg)
    val b_tm = mk_conj (eq_tm, list_mk_conj not_tms)
  in
    mk_exists (wc_arg, b_tm)
  end

  val nchot_tm = list_mk_disj (lit_tms @ [wc_tm])
  (* eliminate the existential, then case-split on each "x = li" *)
  val nchot_thm = prove(nchot_tm,
    CONV_TAC (DEPTH_CONV Unwind.UNWIND_EXISTS_CONV) THEN
    EVERY (List.map (fn t =>
       (BOOL_CASES_TAC t THEN REWRITE_TAC[])) lit_tms))
  val nchot_thm' = GEN split_arg nchot_thm
in
  SOME (nchot_thm', cases_no)
end handle HOL_ERR _ => NONE

val _ = pmatch_compile_db_register_nchotomy_fun literals_nchotomy_fun
955
956
957
958end
959