1(*---------------------------------------------------------------------------*
 * Simplifications for datatypes. This library extracts information about
 * datatypes from TypeBase and provides theorems and conversions that are
 * suitable for reasoning about these datatypes.
5 *---------------------------------------------------------------------------*)
6structure DatatypeSimps :> DatatypeSimps =
7struct
8
9open HolKernel Parse boolLib TypeBasePure
10open simpLib
(* Simpset used by the SIMP_CONV calls in this structure.  It is just
   boolSimps.bool_ss and shadows any other std_ss in scope for the rest of
   this structure. *)
val std_ss = boolSimps.bool_ss
12
13
(* map_option_filter f xs
 *
 * Applies f to the elements of xs from left to right and collects y for
 * every result SOME y.  Elements for which f returns NONE, or raises any
 * exception other than Interrupt, are dropped.  Interrupt is re-raised
 * rather than being turned into NONE.
 *)
fun map_option_filter f xs = let
  fun attempt x =
    f x handle Interrupt => raise Interrupt
             | _ => NONE
in
  List.mapPartial attempt xs
end
19
(* Looks each type of the list up in TypeBase.  Types with no entry (fetch
   returns NONE) or whose lookup raises (other than Interrupt) are dropped. *)
fun tyinfos_of_tys tyl =
  map_option_filter (fn ty => TypeBase.fetch ty) tyl
21
22
23(******************************************************************************)
24(* Generating thms                                                            *)
25(******************************************************************************)
26
(* make_variant_list n s avoid tys
 *
 * Creates one variable per type in tys.  The base names are s followed by
 * n, n+1, ...  Each variable is passed through variant so that it differs
 * from every variable in avoid and from the variables created before it.
 * The result is in the same order as tys.
 *)
fun make_variant_list n s avoid tys = let
  fun go _ _ [] acc = List.rev acc
    | go idx used (ty :: rest) acc = let
        val v = variant used (mk_var (s ^ Int.toString idx, ty))
      in
        go (idx + 1) (v :: used) rest (v :: acc)
      end
in
  go n avoid tys []
end
32
(* make_args_simple tys
 *
 * For a non-empty list ty :: rest, returns the variable M : ty followed by
 * variables based on f0, f1, ... with the types in rest.  These are renamed
 * apart from M and from each other by make_variant_list.  The empty list
 * yields [].
 *)
fun make_args_simple tys =
  case tys of
    [] => []
  | first_ty :: rest_tys =>
      let
        val m = mk_var ("M", first_ty)
      in
        m :: make_variant_list 0 "f" [m] rest_tys
      end
41
(* make_args_abs tyL
 *
 * tyL is expected to be the argument types of a case constant: the first
 * is the type being split on, and each later one is the type of a branch
 * function.  Returns (M, branches) where M : hd tyL and branches has one
 * entry (xs, f) per later type ty, in order:
 *   - xs are fresh variables (base name x<k>) for the argument types of ty,
 *     as given by strip_fun.  The x-counter k runs across all branches.
 *   - f is a fresh variable (base name f<n>, n = branch index) of type ty.
 * All created variables are renamed apart from each other and from M.
 * Fails (via hd) if tyL is empty.
 *)
fun make_args_abs tyL = let
  fun step (ty, (acc, fn_idx, var_idx, used)) = let
    val (arg_tys, _) = strip_fun ty
    val xs = make_variant_list var_idx "x" used arg_tys
    val f = variant (xs @ used) (mk_var ("f" ^ Int.toString fn_idx, ty))
  in
    ((xs, f) :: acc, fn_idx + 1, var_idx + length arg_tys, f :: (xs @ used))
  end
  val m = mk_var ("M", hd tyL)
  val (branches_rev, _, _, _) = foldl step ([], 0, 0, [m]) (tl tyL)
in
  (m, rev branches_rev)
end
56
(* mk_type_forall_thm_tyinfo tyinfo
 *
 * Turns the nchotomy theorem of the datatype described by tyinfo into a
 * case-split theorem for universal quantification over that type, roughly
 *   |- !P. (!tt. P tt) = <one conjunct per constructor, quantified over the
 *                         constructor's arguments>
 * NOTE(review): the exact right-hand side is whatever bool_ss produces in
 * the final simplification step; confirm on a concrete type.
 *
 * Proof outline:
 *   P tt = (T ==> P tt)                        from IMP_CLAUSES
 *        = (<nchotomy instance at tt> ==> P tt)  replacing T via EQT_INTRO
 * Then quantify over tt and simplify the case-disjunction antecedent away
 * with DISJ_IMP_THM, reversed LEFT_FORALL_IMP_THM and FORALL_AND_THM.
 *)
fun mk_type_forall_thm_tyinfo tyinfo = let
  val nchotomy_thm = nchotomy_of tyinfo;
  (* the datatype: type of the variable bound by the nchotomy theorem *)
  val ty = type_of (fst (dest_forall (concl nchotomy_thm)))

  val P_tm = mk_var ("P", ty --> bool)
  val input_tm = mk_var ("tt", ty)
  val body_tm = mk_comb (P_tm, input_tm)

  (* |- P tt = (T ==> P tt) *)
  val thm_base = GSYM (CONJUNCT1 (SPEC body_tm IMP_CLAUSES))
  (* |- T = <nchotomy instantiated at tt> *)
  val true_expand_thm = GSYM (EQT_INTRO (SPEC input_tm nchotomy_thm))
  (* replace the antecedent T on the right-hand side of thm_base *)
  val thm1 = CONV_RULE (RHS_CONV (RATOR_CONV (RAND_CONV (K true_expand_thm)))) thm_base

  (* |- (!tt. P tt) = (!tt. <nchotomy instance> ==> P tt) *)
  val thm2 = QUANT_CONV (K thm1) (mk_forall (input_tm, body_tm))
  (* split the case disjunction into one quantified conjunct per case *)
  val thm3 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [DISJ_IMP_THM, GSYM LEFT_FORALL_IMP_THM, FORALL_AND_THM])) thm2

  val thm4 = GEN P_tm thm3
in
  thm4
end
76
77
(* mk_type_quant_thms_tyinfo tyinfo
 *
 * Returns the pair (forall_thm, exists_thm).  forall_thm is the result of
 * mk_type_forall_thm_tyinfo.  exists_thm is derived from it by
 *   1. instantiating P with (\a. ~P a), where a is a fresh genvar,
 *   2. negating both sides (AP_TERM with negation),
 *   3. simplifying both sides with std_ss and no extra theorems,
 *   4. generalising over P again.
 * NOTE(review): presumably the result has the shape
 *   |- !P. (?tt. P tt) = <one disjunct per constructor>
 * but the exact form depends on what std_ss (bool_ss) produces; confirm.
 *)
fun mk_type_quant_thms_tyinfo tyinfo = let
  val forall_thm = mk_type_forall_thm_tyinfo tyinfo;

  val (P_tm, _) = dest_forall (concl forall_thm)
  (* fresh variable whose type is the domain of P (the datatype) *)
  val P_arg_tm = genvar (hd (fst (strip_fun (type_of P_tm))))
  (* \a. ~P a *)
  val P_neg_tm = mk_abs (P_arg_tm, mk_neg (mk_comb (P_tm, P_arg_tm)))

  (* forall_thm with P := \a. ~P a *)
  val thm0 = SPEC P_neg_tm forall_thm
  (* negate both sides *)
  val thm1 = AP_TERM boolSyntax.negation thm0
  (* simplify both sides (beta-reduction, negation handling by bool_ss) *)
  val thm3 = CONV_RULE (BINOP_CONV (SIMP_CONV std_ss [])) thm1
  val thm4 = GEN P_tm thm3
in
  (forall_thm, thm4)
end
92
93
(* The existential half of mk_type_quant_thms_tyinfo.  The universal theorem
   is still derived, because the existential one is built from it. *)
fun mk_type_exists_thm_tyinfo tyinfo = let
  val (_, exists_thm) = mk_type_quant_thms_tyinfo tyinfo
in
  exists_thm
end
96
97
(* mk_case_elim_thm_tyinfo tyinfo
 *
 * Proves that a case expression whose branches all return the same value c
 * equals c:
 *   |- !M c. case_const M (\args0. c) ... (\argsn. c) = c
 * For a constructor without arguments the branch is just c, since
 * list_mk_abs with no variables returns the body unchanged.
 * The goal is built explicitly and then proved by case-splitting on M
 * (HO_REWR_CONV with mk_type_forall_thm_tyinfo) and rewriting with the case
 * definition.  EQT_ELIM fails if rewriting does not reduce the goal to T.
 *)
fun mk_case_elim_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo;
  (* argument types of the case constant and its result type *)
  val (arg_tyL, base_ty) = strip_fun (type_of case_c);
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid = input_arg :: flatten (List.map (fn (args, b) => (b :: args)) case_args)
  (* the common branch value c, fresh w.r.t. all generated variables *)
  val const = variant avoid (mk_var ("c", base_ty))

  (* case_const M (\args0. c) ... (\argsn. c) *)
  val t0 = mk_comb (case_c, input_arg)
  val t1 = foldl (fn ((args, _), t) =>
     mk_comb (t, list_mk_abs (args, const))) t0 case_args
  val t2 = mk_eq (t1, const)
  val t3 = list_mk_forall ([input_arg, const], t2)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  (* split the outer !M into one goal per constructor *)
  val thm0 = HO_REWR_CONV forall_thm t3
  (* evaluate each case expression on a constructor *)
  val thm1 = CONV_RULE (RHS_CONV (REWRITE_CONV [simp_thm])) thm0
  val thm2 = EQT_ELIM thm1
in
  thm2
end
119
120
(* mk_type_rewrites_tyinfo tyinfo
 *
 * Collects rewrite theorems for the datatype described by tyinfo, in this
 * order:
 *   - the case-elimination theorem from mk_case_elim_thm_tyinfo,
 *   - each conjunct of the case definition,
 *   - each conjunct of the distinctness theorem, then each of those
 *     conjuncts again with equations reversed by GSYM,
 *   - each conjunct of the one-one theorem.
 * A missing distinctness or one-one theorem contributes nothing.
 *)
fun mk_type_rewrites_tyinfo tyinfo = let
  fun conjuncts_opt NONE = []
    | conjuncts_opt (SOME th) = CONJUNCTS th

  val case_defs = CONJUNCTS (case_def_of tyinfo)
  val distincts = let
    val ds = conjuncts_opt (distinct_of tyinfo)
  in
    ds @ List.map GSYM ds
  end
  val injectivity = conjuncts_opt (one_one_of tyinfo)
  val elim = mk_case_elim_thm_tyinfo tyinfo
in
  elim :: (case_defs @ distincts @ injectivity)
end
140
141
(* mk_case_cong_thm_tyinfo tyinfo
 *
 * Builds and proves a congruence rule for the case constant:
 *   |- !M M' f0 .. fn f0' .. fn'.
 *        (M = M') ==>
 *        (!args0. (M' = C0 args0) ==> (f0 args0 = f0' args0)) ==> ... ==>
 *        (case_const M (\args0. f0 args0) ... =
 *         case_const M' (\args0. f0' args0) ...)
 * C0 .. Cn are taken from constructors_of and paired with the branches by
 * position.  NOTE(review): this assumes the case constant's branch
 * arguments come in the same order as constructors_of; zip presumably
 * fails if the counts differ.
 * Primed variables are fresh variants of the unprimed ones.  The
 * constructor-argument variables argsi are shared by both sides.
 * Proof: case split on M (HO_REWR_CONV with mk_type_forall_thm_tyinfo),
 * then simplification with the case definition.  EQT_ELIM fails if the
 * goal does not reduce to T.
 *)
fun mk_case_cong_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo;
  val (arg_tyL, base_ty) = strip_fun (type_of case_c);
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid = input_arg :: flatten (List.map (fn (args, b) => (b :: args)) case_args)
  (* M' and the primed branch functions f0' .. fn' *)
  val input_arg' = variant avoid input_arg
  val (avoid, case_args') = Lib.foldl_map (fn (av, (args, v)) =>
      let val v' = variant av v in
      (v' :: av, (args, v')) end) (input_arg'::avoid, case_args)

  (* case_const M (\args. fi args) ... *)
  val case_args0 = List.map (fn (args, c) =>
     list_mk_abs (args, list_mk_comb (c, args)))
     case_args
  val t1a = list_mk_icomb (case_c, [input_arg] @ case_args0)

  (* case_const M' (\args. fi' args) ... *)
  val case_args1 = List.map (fn (args, c) =>
     list_mk_abs (args, list_mk_comb (c, args)))
     case_args'
  val t1b = list_mk_icomb (case_c, [input_arg'] @ case_args1)

  val t2 = mk_eq(t1a, t1b)

  val constr = constructors_of tyinfo
  val M_eq = mk_eq (input_arg, input_arg')
  (* !args. (M' = cr args) ==> (c args = c' args) *)
  fun mk_imp args c c' cr = let
     val t0 = list_mk_icomb (c, args)
     val t1 = list_mk_icomb (c', args)
     val t2 = mk_eq (t0, t1)
     val u0 = list_mk_icomb (cr, args)
     val u1 = mk_eq (input_arg', u0)
     val t3 = boolSyntax.mk_imp (u1, t2)
     val t4 = list_mk_forall (args, t3)
  in
    t4
  end
  (* one hypothesis per (branch, primed branch, constructor) triple *)
  val imps = List.map (fn (((args, c), (_, c')), cr) =>
     mk_imp args c c' cr)
     (zip (zip case_args case_args') constr)


  val t3 = boolSyntax.list_mk_imp  (M_eq::imps, t2)
  val t4 = list_mk_forall ([input_arg, input_arg']@(List.map snd case_args)@(List.map snd case_args'), t3)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  (* split the outer !M into one goal per constructor *)
  val thm0 = HO_REWR_CONV forall_thm t4
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
  val thm2 = EQT_ELIM thm1
in
  thm2
end
193
194
(* mk_case_rand_thm_tyinfo tyinfo
 *
 * Proves that a function applied to a case expression can be pushed into
 * the branches:
 *   |- !M f f0 .. fn.
 *        f (case_const M (\args0. f0 args0) ...) =
 *        case_const M (\args0. f (f0 args0)) ...
 * f has type base_ty -> 'r, where base_ty is the case result type and 'r
 * is a fresh type variable from gen_tyvar.  Because gen_tyvar is stateful,
 * the type variable names differ between calls.
 * Proof: case split on M, then simplification with the case definition.
 *)
fun mk_case_rand_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo;
  val (arg_tyL, base_ty) = strip_fun (type_of case_c);
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid = input_arg :: flatten (List.map (fn (args, b) => (b :: args)) case_args)
  val res_ty = gen_tyvar ()
  (* the function being pushed inside, fresh w.r.t. the branch variables *)
  val const = variant avoid (mk_var ("f", base_ty --> res_ty))

  (* f (case_const M (\args. fi args) ...) *)
  val case_args0 = List.map (fn (args, c) =>
     list_mk_abs (args, list_mk_comb (c, args)))
     case_args
  val t1 = list_mk_icomb (case_c, [input_arg] @ case_args0)
  val t2a = mk_comb (const, t1)

  (* case_const M (\args. f (fi args)) ... *)
  val case_args1 = List.map (fn (args, c) =>
     list_mk_abs (args, mk_comb (const, list_mk_comb (c, args))))
     case_args
  val t2b = list_mk_icomb (case_c, [input_arg] @ case_args1)

  val t3 = mk_eq (t2a, t2b)
  val consts = List.map snd case_args;
  val t4 = list_mk_forall ([input_arg, const]@consts, t3)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t4
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
  val thm2 = EQT_ELIM thm1
in
  thm2
end
226
227
(* mk_case_rator_thm_tyinfo tyinfo
 *
 * Proves that applying a function-valued case expression to an argument x
 * can be pushed into the branches:
 *   |- !M x f0 .. fn.
 *        case_const M (\args0. f0 args0) ... x =
 *        case_const M (\args0. f0 args0 x) ...
 * The case result type base_ty is instantiated to res_ty -> base_ty',
 * where both are fresh gen_tyvar type variables, so that the case
 * expression can be applied to x : res_ty.
 * Proof: case split on M, then simplification with the case definition.
 *)
fun mk_case_rator_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo;
  val (arg_tyL, base_ty) = strip_fun (type_of case_c);
  val res_ty = gen_tyvar ()
  val base_ty' = gen_tyvar ()
  (* specialise the case result type to a function type *)
  val inst_ty = inst [base_ty |-> (res_ty --> base_ty')]

  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid = input_arg :: flatten (List.map (fn (args, b) => (b :: args)) case_args)
  val const = variant avoid (mk_var ("x", res_ty))

  (* (case_const M (\args. fi args) ...) x, at the instantiated type *)
  val case_args0 = List.map (fn (args, c) =>
     list_mk_abs (args, list_mk_comb (c, args)))
     case_args
  val t1 = list_mk_icomb (case_c, [input_arg] @ case_args0)
  val t2 = inst_ty t1
  val t3a = mk_icomb (t2, const)

  (* case_const M (\args. fi args x) ... *)
  val case_args1 = List.map (fn (args, c) =>
     list_mk_abs (args, mk_comb (inst_ty (list_mk_comb (c, args)), const)))
     case_args
  val t3b = list_mk_icomb (case_c, [input_arg] @ case_args1)

  val t4 = mk_eq (t3a, t3b)
  (* quantify over the branch variables at their instantiated types *)
  val consts = List.map (fn (_, t) => inst_ty t) case_args;
  val t5 = list_mk_forall ([input_arg, const]@consts, t4)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t5
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
  val thm2 = EQT_ELIM thm1
in
  thm2
end
263
264
(* mk_case_abs_thm_tyinfo tyinfo
 *
 * Proves that an abstraction over a case expression can be pushed into the
 * branches:
 *   |- !M x f0 .. fn.
 *        (\x. case_const M (\args0. f0 args0 x) ...) =
 *        case_const M (\args0. \x. f0 args0 x) ...
 * The case result type base_ty is instantiated to res_ty -> base_ty',
 * where both are fresh gen_tyvar type variables.
 * NOTE(review): x is bound by the lambdas on both sides, so the outer
 * quantification over x is vacuous.  It is kept so the theorem's shape
 * stays as before.
 * Proof: case split on M, then simplification with the case definition.
 *)
fun mk_case_abs_thm_tyinfo tyinfo = let
  val case_c = case_const_of tyinfo;
  val (arg_tyL, base_ty) = strip_fun (type_of case_c);
  val res_ty = gen_tyvar ()
  val base_ty' = gen_tyvar ()
  (* specialise the case result type to a function type.  Every
     instantiation below goes through this one substitution. *)
  val inst_ty = inst [base_ty |-> (res_ty --> base_ty')]
  val (input_arg, case_args) = make_args_abs arg_tyL
  val avoid = input_arg :: flatten (List.map (fn (args, b) => (b :: args)) case_args)
  val const = variant avoid (mk_var ("x", res_ty))

  (* left-hand side:  \x. case_const M (\args. fi args x) ... *)
  val case_args0 = List.map (fn (args, c) =>
     list_mk_abs (args, mk_comb (inst_ty (list_mk_comb (c, args)), const)))
     case_args
  val t1 = list_mk_icomb (case_c, [input_arg] @ case_args0)
  val t2a = mk_abs (const, t1)

  (* right-hand side: case_const M (\args. \x. fi args x) ... *)
  val case_args1 = List.map (fn (args, c) =>
     list_mk_abs (args, mk_abs (const, mk_comb (inst_ty (list_mk_comb (c, args)), const))))
     case_args
  val t2b = list_mk_icomb (case_c, [input_arg] @ case_args1)

  val t3 = mk_eq (t2a, t2b)
  (* branch variables at their instantiated (function) types *)
  val consts = List.map (fn (_, t) => inst_ty t) case_args;
  val t4 = list_mk_forall ([input_arg, const]@consts, t3)

  val forall_thm = mk_type_forall_thm_tyinfo tyinfo
  val simp_thm = case_def_of tyinfo
  val thm0 = HO_REWR_CONV forall_thm t4
  val thm1 = CONV_RULE (RHS_CONV (SIMP_CONV std_ss [simp_thm])) thm0
  val thm2 = EQT_ELIM thm1
in
  thm2
end
298
299
300(******************************************************************************)
301(* Lifting                                                                    *)
302(******************************************************************************)
303
(* lift_case_const_CONV stop_consts rand_thms
 *
 * Conversion that rewrites a term once, at the top level, with the
 * higher-order rewrites rand_thms (Ho_Rewrite.GEN_REWRITE_CONV I).
 * Raises UNCHANGED if
 *   - the term is an application, to more than one argument, of a constant
 *     in stop_consts (i.e. it is itself a case expression), or
 *   - no theorem in rand_thms applies (HOL_ERR from the rewriter).
 * The cheap stop check only inspects the input term, so it is done before
 * the comparatively expensive higher-order rewrite.  The rewrite is no
 * longer computed just to be thrown away.
 *)
fun lift_case_const_CONV stop_consts rand_thms = let
  val conv = Ho_Rewrite.GEN_REWRITE_CONV I rand_thms
  fun is_stop_term tm = let
    val (hd_tm, args) = strip_comb tm
  in
    List.length args > 1 andalso List.exists (same_const hd_tm) stop_consts
  end
in
  fn t =>
    if is_stop_term t then raise UNCHANGED
    else (conv t handle HOL_ERR _ => raise UNCHANGED)
end
313
314
(* lift_cases_typeinfos_ss til
 *
 * Simpset fragment that lifts case expressions of the given datatypes
 * outwards:
 *   - the rand theorems ("f (case ...) = case ... f ...") are applied by
 *     lift_case_const_CONV on terms matching f x, guarded by the case
 *     constants so that case expressions themselves are not rewritten,
 *   - the abs and rator theorems are used as plain rewrites.
 * Tyinfos for which a theorem cannot be built are skipped (Lib.mapfilter).
 * The generators call gen_tyvar, so they are run in a fixed order.
 *)
fun lift_cases_typeinfos_ss til = let
   val rand_thms   = Lib.mapfilter mk_case_rand_thm_tyinfo til
   val rator_thms  = Lib.mapfilter mk_case_rator_thm_tyinfo til
   val abs_thms    = Lib.mapfilter mk_case_abs_thm_tyinfo til
   val case_consts = Lib.mapfilter case_const_of til

   val lift_conv = lift_case_const_CONV case_consts rand_thms
   val rand_frag = simpLib.std_conv_ss
     {name = "lift_case_const_CONV", pats = [``f x``], conv = lift_conv}
   val rewrite_frag = simpLib.rewrites (abs_thms @ rator_thms)
in
   simpLib.merge_ss [rewrite_frag, rand_frag]
end
331
(* Case-lifting fragment for the given types.  Types unknown to TypeBase
   are skipped. *)
fun lift_cases_ss tyL = let
  val infos = tyinfos_of_tys tyL
in
  lift_cases_typeinfos_ss infos
end

(* Case-lifting fragment for every tyinfo returned by TypeBase.elts () at
   the time of the call. *)
fun lift_cases_stateful_ss () = let
  val infos = TypeBase.elts ()
in
  lift_cases_typeinfos_ss infos
end
335
336
337(******************************************************************************)
338(* Reverse Lifting                                                            *)
339(******************************************************************************)
340
(* unlift_case_const_CONV stop_consts rand_thms
 *
 * Conversion that rewrites a term once, at the top level, with the
 * first-order rewrites rand_thms (Rewrite.GEN_REWRITE_CONV I).  Raises
 * UNCHANGED if no theorem applies (HOL_ERR), or if the rewritten right-hand
 * side is an application, to more than one argument, of a constant in
 * stop_consts.
 *)
fun unlift_case_const_CONV stop_consts rand_thms = let
  val rewrite_top = Rewrite.GEN_REWRITE_CONV I empty_rewrites rand_thms
  fun lands_on_stop_case th = let
    val (head_tm, args) = strip_comb (rhs (concl th))
  in
    List.length args > 1 andalso List.exists (same_const head_tm) stop_consts
  end
  fun apply_conv t = let
    val th = rewrite_top t
  in
    if lands_on_stop_case th then raise UNCHANGED else th
  end handle HOL_ERR _ => raise UNCHANGED
in
  apply_conv
end
350
(* unlift_cases_typeinfos_ss til
 *
 * Simpset fragment that is the reverse of lift_cases_typeinfos_ss.  It uses
 * the rand, rator and abs theorems flipped with GSYM:
 *   - abs theorems as plain rewrites,
 *   - rator theorems through a top-level first-order rewrite conversion,
 *   - rand theorems through unlift_case_const_CONV, guarded by the case
 *     constants.
 * Both conversions are keyed on terms matching f x.  Tyinfos for which a
 * theorem cannot be built are skipped (Lib.mapfilter).
 *)
fun unlift_cases_typeinfos_ss til = let
   val rand_thms = List.map GSYM (Lib.mapfilter mk_case_rand_thm_tyinfo til)
   val rator_thms = List.map GSYM (Lib.mapfilter mk_case_rator_thm_tyinfo til)
   val abs_thms = List.map GSYM (Lib.mapfilter mk_case_abs_thm_tyinfo til)
   val consts = Lib.mapfilter case_const_of til

   val conv_rand = unlift_case_const_CONV consts rand_thms
   val conv_rand_ss = simpLib.std_conv_ss {
      name = "unlift_case_const_CONV",
      pats = [``f x``],
      conv = conv_rand}

   (* The rator conversion previously reused the rand conversion's name.
      It gets its own name so the two fragments/conversions can be told
      apart. *)
   val conv_rator_ss = simpLib.std_conv_ss {
      name = "unlift_case_rator_CONV",
      pats = [``f x``],
      conv = Rewrite.GEN_REWRITE_CONV I empty_rewrites rator_thms}

   val rewr_ss = simpLib.rewrites abs_thms
in
   simpLib.merge_ss [rewr_ss, conv_rator_ss, conv_rand_ss]
end
372
(* Case-unlifting fragment for the given types.  Types unknown to
   TypeBase are skipped. *)
fun unlift_cases_ss tyL = let
  val infos = tyinfos_of_tys tyL
in
  unlift_cases_typeinfos_ss infos
end

(* Case-unlifting fragment for every tyinfo returned by TypeBase.elts () at
   the time of the call. *)
fun unlift_cases_stateful_ss () = let
  val infos = TypeBase.elts ()
in
  unlift_cases_typeinfos_ss infos
end
376
377
378(******************************************************************************)
379(* Simpset fragments                                                          *)
380(******************************************************************************)
381
(* Rewrite fragment containing mk_type_rewrites_tyinfo for each tyinfo.
   Tyinfos whose theorems cannot be built are skipped. *)
fun type_rewrites_typeinfos_ss til = let
  val per_type = Lib.mapfilter mk_type_rewrites_tyinfo til
in
  rewrites (List.concat per_type)
end
384
(* Datatype rewrite fragment for the given types.  Types unknown to
   TypeBase are skipped. *)
fun type_rewrites_ss tyL = let
  val infos = tyinfos_of_tys tyL
in
  type_rewrites_typeinfos_ss infos
end

(* Datatype rewrite fragment for every tyinfo returned by TypeBase.elts ()
   at the time of the call. *)
fun type_rewrites_stateful_ss () = let
  val infos = TypeBase.elts ()
in
  type_rewrites_typeinfos_ss infos
end
388
(* Anonymous simpset fragment that contributes only the given congruence
   theorems. *)
fun congs cong_thms =
  SSFRAG {congs  = cong_thms,
          name   = NONE,
          rewrs  = [],
          convs  = [],
          dprocs = [],
          ac     = [],
          filter = NONE}
397
(* Case-congruence rules for the given tyinfos, merged with their datatype
   rewrites.  Tyinfos whose theorems cannot be built are skipped. *)
fun case_cong_typeinfos_ss til = let
  val cong_frag = congs (Lib.mapfilter mk_case_cong_thm_tyinfo til)
  val rewr_frag = type_rewrites_typeinfos_ss til
in
  simpLib.merge_ss [cong_frag, rewr_frag]
end
401
(* Case-congruence fragment for the given types.  Types unknown to TypeBase
   are skipped. *)
fun case_cong_ss tyL = let
  val infos = tyinfos_of_tys tyL
in
  case_cong_typeinfos_ss infos
end

(* Case-congruence fragment for every tyinfo returned by TypeBase.elts ()
   at the time of the call. *)
fun case_cong_stateful_ss () = let
  val infos = TypeBase.elts ()
in
  case_cong_typeinfos_ss infos
end
405
406
407
(* Rewrite fragment with both the forall- and the exists-expansion theorem
   (mk_type_quant_thms_tyinfo) of each tyinfo, forall first.  Tyinfos
   whose theorems cannot be built are skipped. *)
fun expand_type_quants_typeinfos_ss til = let
  val pairs = Lib.mapfilter mk_type_quant_thms_tyinfo til
  val thms = List.foldr (fn ((all_th, ex_th), acc) => all_th :: ex_th :: acc)
                        [] pairs
in
  rewrites thms
end
411
(* Quantifier-expansion fragment for the given types.  Types unknown to
   TypeBase are skipped. *)
fun expand_type_quants_ss tyL = let
  val infos = tyinfos_of_tys tyL
in
  expand_type_quants_typeinfos_ss infos
end

(* Quantifier-expansion fragment for every tyinfo returned by
   TypeBase.elts () at the time of the call. *)
fun expand_type_quants_stateful_ss () = let
  val infos = TypeBase.elts ()
in
  expand_type_quants_typeinfos_ss infos
end
415
416
417(******************************************************************************)
418(* Rule for eliminating case splits in equations                              *)
419(******************************************************************************)
420
(* cases_to_top_RULE thm
 *
 * Eliminates case splits from the right-hand sides of the equations in thm.
 * thm is split with BODY_CONJUNCTS.  For each equation conjunct lhs = rhs:
 * if rhs contains a subterm  case_const v ...  where v is a variable free
 * in lhs and case_const is TypeBase's case constant for the type of v, the
 * equation is replaced by one instance per constructor.  In each instance,
 * v is instantiated to  C y1 .. yk  and the rhs is rewritten with the
 * rewrites from simpls_of for the type, followed by beta-reduction.  The
 * yi come from the binders of the corresponding case branch, renamed apart
 * from the free variables of the equation.
 * The new equations are processed again, so nested splits are also
 * eliminated.  An equation for which any step fails (HOL_ERR) is kept
 * unchanged.
 * Conjuncts that are not equations are kept unchanged but placed after the
 * equations.  The result is the conjunction of all resulting theorems,
 * each closed with GEN_ALL.
 *)
fun cases_to_top_RULE thm = let
  val input_thmL = BODY_CONJUNCTS thm
  val (input_eqL, input_restL) = partition (fn thm => is_eq (concl thm)) input_thmL

  (* Splits one equation on the first suitable case term in its rhs.
     Returns NONE if there is none, or if any step fails. *)
  fun process_eq eq_thm = let
     val free_vars_lhs = free_vars (lhs (concl eq_thm));
     (* true for  case_const v ...  with v a free variable of the lhs *)
     fun search_pred t = let
       val (c, args) = strip_comb t
       val _ = if length args = 0 then fail () else ();
       val _ = if (List.exists (term_eq (hd args)) free_vars_lhs) then () else fail();
       val case_const = TypeBase.case_const_of (type_of (hd args));
     in
       same_const c case_const
     end handle HOL_ERR _ => false;
     val case_term = find_term search_pred (rhs (concl eq_thm));
     val (_, split_args) = strip_comb case_term;
     val split_var = hd split_args;
     val tyinfo = valOf (TypeBase.fetch (type_of split_var)) handle Option => fail()
     val free_vars_full = free_vars (concl eq_thm)
     (* one term  C y1 .. yk  per constructor, paired positionally with the
        case branches (tl split_args) *)
     val split_terms = List.map (fn (c_tm, cr_tm) => let
        val (_, cr_ret_type) = strip_fun (type_of cr_tm);
        (* instantiate the generic constructor to the type of split_var *)
        val ty_inst = match_type cr_ret_type (type_of split_var);
        val cr_tm' = inst ty_inst cr_tm;
        (* reuse the branch's binders as constructor arguments *)
        val (args, _) = strip_abs c_tm
        val (_, args') = foldl_map (fn (av, v) => let val v' = variant av v in (v' :: av, v') end) (free_vars_full, args)
      in list_mk_comb (cr_tm', args') end)
        (zip (tl split_args) (TypeBasePure.constructors_of tyinfo))


     (* evaluate the case expression on a constructor, then beta-reduce *)
     val rhs_conv = REWRITE_CONV (#rewrs (TypeBasePure.simpls_of tyinfo)) THENC
                    DEPTH_CONV BETA_CONV
     fun process_thm split_tm = let
       val thm0 = INST [split_var |-> split_tm] eq_thm
       val thm1 = CONV_RULE (RHS_CONV rhs_conv) thm0
     in thm1 end
     val result = List.map process_thm split_terms
  in
    SOME result
  end handle HOL_ERR _ => NONE

  (* Worklist loop: split results are pushed back for further splitting.
     Equations that cannot be split are accumulated in order. *)
  fun process_all acc [] = List.rev acc
    | process_all acc (eq_thm :: thms) = (case process_eq eq_thm of
          NONE => process_all (eq_thm :: acc) thms
        | SOME eq_thms => process_all acc (eq_thms @ thms))

  val processed_eq_thms = process_all [] input_eqL
  val all_thms = processed_eq_thms @ input_restL
in
  LIST_CONJ (List.map GEN_ALL all_thms)
end
471
472end
473