structure countableLib :> countableLib = struct
open HolKernel bossLib Tactical Drule lcsymtacs
open pred_setTheory countable_initTheory
open boolSyntax numSyntax pairSyntax pred_setSyntax

(* uneta tm: for a term tm of function type ty1 -> ty2, return the
   abstraction \x. tm x, where x is a fresh genvar of type ty1.
   Raises HOL_ERR (from dom_rng) if tm does not have a function type. *)
fun uneta tm = let
  val (t,_) = dom_rng (type_of tm)
  val x = genvar t
in mk_abs(x,mk_comb(tm,x)) end

(* mk_count_aux_inj_rwt_ttac tys ttac

   For a list tys of datatypes (processed together, so they may be
   mutually recursive) this:
   1. defines one function count_<name>_aux per type.  Each function is
      parameterised by one helper counting variable count_'a : 'a -> num
      for every distinct type argument 'a of the types.  Constructor number
      i (in the order of the disjuncts of the type's nchotomy theorem) is
      sent to count_num2 (i, <encoding of the constructor's arguments>).
   2. saves the ETA-simplified definition as count_<name>_aux_thm, where
      <name> is the name of the first type in tys, and exports it with
      export_rewrites.
   3. proves, using the induction theorem of the first type, that each
      count_<name>_aux is injective whenever every helper is injective.
      It saves these theorems as count_<name>_aux_inj_rwt and exports them
      with export_rewrites.

   The definition is made with xDefine when ttac is NONE, and with tDefine
   and the given termination tactic when ttac is SOME tac.
   Returns the injectivity theorems in the same order as tys. *)
val mk_count_aux_inj_rwt_ttac = let
  val count_num2 = ``count_num2``
  (* Encode a list of argument terms as one term of type num:
       []      |-> 0
       [a]     |-> ctr a
       a :: xs |-> count_num2 (ctr a, <encoding of xs>)
     Here ctr ty is the counting function to use at type ty. *)
  fun count_args ctr = let
    fun f [] = term_of_int 0
      | f [a] = mk_comb(ctr (type_of a), a)
      | f (a::xs) = mk_comb(count_num2,mk_pair(mk_comb(ctr (type_of a), a),f xs))
  in f end
  (* Build the defining equation  lhs0 (C a1 .. ak) = count_num2 (n, args)
     for the constructor application c, and cons it onto eqs.
     The map d records every variable name used so far, across all
     equations, together with its type.  An argument whose name is already
     taken at a different type is renamed with variant.  This ensures the
     final conjunction of equations never has two free variables with the
     same name but different types.
     Returns the next constructor index, the extended equation list and the
     updated name map. *)
  fun mk_eqn ctr lhs0 (c,(n,eqs,d)) = let
    val (c,ars) = strip_comb c
    val (ars,d) = foldr
      (fn (a,(ars,d)) => let val (n,ty) = dest_var a in
        case Redblackmap.peek(d,n) of
          SOME ty' => if ty = ty' then (a::ars,d) else
            let val vs = Redblackmap.foldl (fn (n,ty,ls) => mk_var(n,ty)::ls) [] d
              val a' = variant vs a
            in (a'::ars,Redblackmap.insert(d,fst(dest_var a'),ty)) end
        | NONE => (a::ars,Redblackmap.insert(d,n,ty)) end)
      ([],d) ars
    val c = list_mk_comb(c,ars)
    val lhs = mk_comb(lhs0, c)
    val rhs = mk_comb(count_num2,mk_pair(term_of_int n,count_args ctr ars))
    val eq = mk_eq (lhs, rhs)
  in (n+1,eq::eqs,d) end
  (* For a variable v : ty and a counting function ctr : ty -> num, build
       !v v'. hyps ==> ((ctr v = ctr v') <=> (v = v'))
     where v' is v with a prime appended to its name.
     When hyps is empty, no implication is built. *)
  fun mk_inj_rwt_tm hyps (v,ctr) = let
    val (n,ty) = dest_var v
    val v' = mk_var (Lib.prime n, ty)
  in list_mk_forall([v,v'],list_mk_imp(hyps,mk_eq(mk_eq(mk_comb(ctr,v),mk_comb(ctr,v')),mk_eq(v,v')))) end
in fn tys => fn ttac => let
  val (names,argls) = unzip (map dest_type tys)
  val nchotomys = map (fn ty => SPEC_ALL (TypeBase.nchotomy_of ty)) tys
  (* Get the constructor applications of each type: take every disjunct of
     its nchotomy theorem, strip the existentials, and keep the right-hand
     side of the equation. *)
  val constructorls = map (fn th => map (rhs o snd o strip_exists) (strip_disj (concl th))) nchotomys
  (* Collect the distinct type arguments of all the types.  These are
     assumed to be type variables, since dest_vartype raises otherwise.
     Each one gets a helper counting variable count_'a : 'a -> num. *)
  val args = Lib.mk_set (flatten argls)
  val helpers = map (fn a => mk_var("count_"^(dest_vartype a),a --> num)) args
  val count_name_auxs = map (fn n => "count_"^n^"_aux") names
  (* Variables standing for the functions being defined.  Their types take
     the helpers first and then a value of the datatype, returning num. *)
  val count_ty_aux_vars = map (fn (ty,count_name_aux) =>
    mk_var(count_name_aux,
        foldr (fn (h,ty) => type_of h --> ty) (ty --> num) helpers))
    (zip tys count_name_auxs)
  val lhs0s = map (fn v => list_mk_comb (v,helpers)) count_ty_aux_vars
  val var_counters = zip tys lhs0s
  (* Fallback counting function for a type t = (a1,..,ak) tyname.
     Apply the existing constant count_<tyname>_aux to the eta-expanded
     counting functions for a1 .. ak, each obtained from c. *)
  fun counter_search c t = let
    val (n,ars) = dest_type t
    val ty = foldr (fn (x,t) => (x --> num) --> t) (t --> num) ars
    val ctr = mk_const("count_"^n^"_aux",ty)
  in list_mk_comb(ctr, map (uneta o c) ars) end
  (* Counting function used in the defining equations.  Try, in order:
     one of the functions being defined, then a helper variable, then an
     existing count_<tyname>_aux constant. *)
  fun var_counter t = Lib.assoc t var_counters
    handle HOL_ERR _ => Lib.first (fn h => fst(dom_rng(type_of h)) = t) helpers
    handle HOL_ERR _ => counter_search var_counter t
  (* Build all the equations.  The constructor index restarts at 0 for
     each type, while the name map is shared by all types. *)
  val (eqs,_) = foldl (fn ((lhs0,cl),(eqs,d)) =>
       let val (_,eqs,d) = foldl (mk_eqn var_counter lhs0) (0,eqs,d) cl in (eqs,d) end)
    ([],Redblackmap.mkDict String.compare)
    (zip lhs0s constructorls)
  val define = case ttac of NONE => xDefine | SOME ttac => (fn x => fn y => tDefine x y ttac)
  val aux_name = hd count_name_auxs
  val count_aux_def = define aux_name [ANTIQUOTE (list_mk_conj eqs)]
  val count_aux_thm = SIMP_RULE (srw_ss()++boolSimps.ETA_ss) [] count_aux_def
  val aux_name_thm = aux_name^"_thm"
  val _ = save_thm(aux_name_thm,count_aux_thm)
  val _ = export_rewrites[aux_name_thm]
  (* The constants just defined, with the same names and types as the
     variables used in the definition. *)
  val count_ty_aux_tms = map (fn (n,v) => mk_const(n, type_of v)) (zip count_name_auxs count_ty_aux_vars)
  (* Injectivity assumptions on the helpers.  For each type argument 'a:
       !'a 'a'. (count_'a 'a = count_'a 'a') <=> ('a = 'a') *)
  val hyps = map (mk_inj_rwt_tm []) (zip (map (fn a => mk_var(dest_vartype a, a)) args) helpers)
  val induction = TypeBase.induction_of (hd tys)
  (* The universally quantified variable of each conjunct in the
     conclusion of the induction theorem. *)
  val cvars = map (fst o dest_forall) (strip_conj(snd(strip_imp(snd(strip_forall(concl induction))))))
  val lhs1s = map (fn c => list_mk_comb (c,helpers)) count_ty_aux_tms
  val const_counters = zip tys lhs1s
  fun const_counter t = Lib.assoc t const_counters handle HOL_ERR _ => counter_search const_counter t
  val concls = map (fn v => mk_inj_rwt_tm hyps (v, const_counter (type_of v))) cvars
  (* After srw_tac expands the definition, each constructor case has the
     shape  (count_num2 (..) = f z) <=> (C .. = z).
     Rename z and case-split on it to finish the proof. *)
  val th = prove(list_mk_conj concls,
    ho_match_mp_tac induction >>
    srw_tac[boolSimps.ETA_ss][] >>
    qmatch_rename_tac `(_ = _ z) <=> _` >>
    Cases_on `z` >> rw[])
  (* For each type in tys, in order, pick out the conjunct of th whose
     outer quantified variable has that type.  Conjuncts about any other
     types are dropped. *)
  val (_,ths) = foldr
    (fn (ty,(remaining,acc)) => let
      val (picked,rest) = Lib.pluck (fn c => ty = type_of(fst(dest_forall(concl c)))) remaining
    in (rest,picked::acc) end)
    (CONJUNCTS th,[])
    tys
  val ths = map (GENL helpers o (SIMP_RULE (srw_ss()++boolSimps.ETA_ss) [])) ths
  val inj_names = map (fn n => n^"_inj_rwt") count_name_auxs
  val _ = List.app (ignore o save_thm) (zip inj_names ths)
  val _ = export_rewrites inj_names
in ths end
end

(* mk_count_aux_inj_rwt tys = mk_count_aux_inj_rwt_ttac tys NONE.
   The definition is made with xDefine and no explicit termination tactic. *)
val mk_count_aux_inj_rwt = Lib.C mk_count_aux_inj_rwt_ttac NONE

(* mk_countable ty builds the term  countable univ(:ty). *)
fun mk_countable a = ``countable ^(mk_univ a)``

(* inj_rwt_to_countable th: from a theorem whose outermost universally
   quantified variable x has type ty, and which states injectivity of a
   counting function on ty, prove  |- countable univ(:ty)  by unfolding
   countable_def and INJ_DEF.
   NOTE: this reads the type of the outermost quantifier.  A theorem
   returned by mk_count_aux_inj_rwt for a type with type parameters is
   first generalised over its helper variables count_'a.  For such a
   theorem the goal built here would mention the helper's type instead of
   the datatype. *)
fun inj_rwt_to_countable th = let
  val (_,t) = dest_var(fst(dest_forall(concl th)))
in prove(mk_countable t,
  rw[countable_def,INJ_DEF] >>
  prove_tac[th])
end

end
112