1structure bagSimps :> bagSimps =
2struct
3
4open HolKernel Parse boolLib simpLib boolSimps bagSyntax bagTheory;
5
(* Local abbreviation for the cache type provided by the Cache structure. *)
type cache = Cache.cache;

(* Exception builder for this structure: ERR "function" "message" yields a
   HOL_ERR whose origin structure is recorded as "bagSimps". *)
val ERR = mk_HOL_ERR "bagSimps";
9
(* Simpset fragment that contributes nothing but AC-normalisation data for
   BAG_UNION: the (associativity, commutativity) pair below.  It has no
   rewrites, conversions, congruences or decision procedures of its own. *)
val BAG_AC_ss = let
  val union_assoc = SPEC_ALL ASSOC_BAG_UNION
  val union_comm = SPEC_ALL COMM_BAG_UNION
in
  simpLib.SSFRAG
    {name = SOME"BAG_AC",
     ac = [(union_assoc, union_comm)],
     convs = [],
     rewrs = [],
     dprocs = [],
     congs = [],
     filter = NONE}
end;
15
(* buac_prover ty
   Returns an AC_CONV built from the associativity and commutativity theorems
   of BAG_UNION, with the type variable alpha instantiated to ty in both. *)
fun buac_prover ty = let
  val at_element_type = INST_TYPE [alpha |-> ty]
  val assoc_at_ty = at_element_type ASSOC_BAG_UNION
  val comm_at_ty = at_element_type COMM_BAG_UNION
in
  AC_CONV (assoc_at_ty, comm_at_ty)
end
21
local
  (* Simplify the theorem's conclusion with FORALL_AND_THM (using bool_ss),
     split the result into its conjuncts and keep the first one. *)
  fun first_conjunct_of th =
    hd (CONJUNCTS (CONV_RULE (SIMP_CONV bool_ss [FORALL_AND_THM]) th))
in
  (* Cancellation theorems used by CANCEL_CONV for SUB_BAG and BAG_DIFF. *)
  val SUB_BAG_UNION_eliminate' = first_conjunct_of SUB_BAG_UNION_eliminate
  val BAG_DIFF_UNION_eliminate' = first_conjunct_of BAG_DIFF_UNION_eliminate
end
(* First conjunct of BAG_UNION_EMPTY; CANCEL_CONV applies it to a union whose
   right-hand argument is the empty bag. *)
val BU_EMPTY_R = hd (CONJUNCTS BAG_UNION_EMPTY)
31
(* ----------------------------------------------------------------------
    CANCEL_CONV tm

    tm must be a SUB_BAG term, a BAG_DIFF term or an equation whose two
    arguments are bags.  Both arguments are split into their BAG_UNION
    components; the components occurring on both sides are moved to the
    front of each side (justified by AC reasoning for BAG_UNION) and the
    cancellation theorem selected for the relation is then applied once
    with REWR_CONV.

    Raises HOL_ERR if tm has none of the three shapes, if the two sides
    share no component, or if any of the intermediate proof steps fails.
   ---------------------------------------------------------------------- *)
fun CANCEL_CONV tm = let
  (* Try SUB_BAG, then BAG_DIFF, then equality; each destructor raises
     HOL_ERR on a term of the wrong shape, which selects the next case. *)
  val (mk_rel, thm, (arg1, arg2)) =
    (mk_sub_bag, SUB_BAG_UNION_eliminate', dest_sub_bag tm)
    handle HOL_ERR _ =>
      (mk_diff, BAG_DIFF_UNION_eliminate', dest_diff tm)
      handle HOL_ERR _ => (mk_eq, BAG_UNION_LEFT_CANCEL, dest_eq tm)
  val basetype = base_type arg1
  val bag_type = basetype --> numSyntax.num
  (* The union components of each side, collected as sets of terms. *)
  val arg1_ts = listset (strip_union arg1)
  val arg2_ts = listset (strip_union arg2)
  val common_part = arg1_ts Isct arg2_ts
  val _ = not (HOLset.isEmpty common_part) orelse
    raise ERR "CANCEL_CONV" "No common parts to eliminate"
  val rem1 = HOLset.listItems (arg1_ts -- common_part)
  val rem2 = HOLset.listItems (arg2_ts -- common_part)
  val cpt = list_mk_union (HOLset.listItems common_part)
  (* Equations "argN = common (+ remainder)" to be established by the
     AC prover; the remainder is omitted when nothing is left over. *)
  val ac1 = mk_eq(arg1, if null rem1 then cpt
                        else mk_union (cpt, list_mk_union rem1))
  val ac2 = mk_eq(arg2, if null rem2 then cpt
                        else mk_union (cpt, list_mk_union rem2))
  val ac1thm = EQT_ELIM (buac_prover basetype ac1)
  val ac2thm = EQT_ELIM (buac_prover basetype ac2)
  (* When a side consists of the common part only, extend its equation to
     "arg = common + empty-bag" so that both sides are unions, the shape
     the final REWR_CONV is applied to. *)
  fun add_emptybag thm =
    TRANS thm
    (SYM (REWR_CONV BU_EMPTY_R (mk_union(cpt, mk_bag([], basetype)))))
  val thm1 = if null rem1 then add_emptybag ac1thm else ac1thm
  val thm2 = if null rem2 then add_emptybag ac2thm else ac2thm
  (* Template "rel v1 v2" used to substitute the rearranged arguments. *)
  val v1 = genvar bag_type and v2 = genvar bag_type
  val template = mk_rel (v1, v2)
in
  SUBST_CONV [v1 |-> thm1, v2 |-> thm2] template THENC
  REWR_CONV thm
end tm
68
(* Bag-typed variables used as the two argument positions of the simplifier
   key patterns built by mk_cancelconv. *)
val x = mk_var("x", bag_ty)
val y = mk_var("y", bag_ty)

(* mk_cancelconv (t, s)
   Builds the simplifier conversion record that runs CANCEL_CONV on terms
   matching "t x y".  s is only used to make the record's name readable.
   The two leading arguments of the conv field are ignored, and
   CHANGED_CONV turns a no-op result into a failure. *)
fun mk_cancelconv (t, s) = let
  val pattern = list_mk_comb(t, [x, y])
  val label = "CANCEL_CONV ("^s^")"
in
  {name = label,
   trace = 2,
   key = SOME ([], pattern),
   conv = fn _ => fn _ => CHANGED_CONV CANCEL_CONV}
end
75
(* The equality constant instantiated to compare two bags. *)
val BAG_EQ_tm = mk_const("=", bag_ty --> bag_ty --> bool);

(* Main simpset fragment for bags: the cancellation conversion registered
   for BAG_DIFF, SUB_BAG and bag equality, together with a handful of
   rewrite theorems fetched by name from theory "bag". *)
val BAG_ss = let
  (* Pair a theorem of theory "bag" with its name. *)
  fun named_rewrite s = (SOME{Thy = "bag", Name = s}, DB.fetch "bag" s)
  val cancel_convs =
    map mk_cancelconv
        [(BAG_DIFF_tm, "DIFF"), (SUB_BAG_tm, "SUB_BAG"), (BAG_EQ_tm, "=")]
  val bag_rewrites =
    map named_rewrite
        ["BAG_UNION_EMPTY", "BAG_DIFF_EMPTY", "SUB_BAG_REFL",
         "SUB_BAG_EMPTY", "FINITE_EMPTY_BAG", "NOT_IN_EMPTY_BAG"]
in
  SSFRAG
    {name = SOME"BAG",
     convs = cancel_convs,
     rewrs = bag_rewrites,
     ac = [],
     congs = [],
     dprocs = [],
     filter = NONE}
end;
91
(* transform t
   Conversion that prepares a bag statement for the arithmetic prover:
     1. a SUB_BAG term is rewritten with SUB_BAG_LEQ and an equation is
        expanded with FUN_EQ_CONV; any other term makes the conversion fail;
     2. BAG_UNION is unfolded with PURE_REWRITE_CONV;
     3. all beta-redexes are reduced.
   SBAG_SOLVE expects the result to be universally quantified. *)
fun transform t = let
  val to_pointwise =
    if is_sub_bag t then REWR_CONV SUB_BAG_LEQ
    else if is_eq t then FUN_EQ_CONV
    else NO_CONV
  val unfold_union = PURE_REWRITE_CONV [BAG_UNION]
  val beta_everywhere = DEPTH_CONV BETA_CONV
in
  (to_pointwise THENC unfold_union THENC beta_everywhere) t
end
99
(* SBAG_SOLVE thms tm
   Attempts to prove "tm = T" for a bag statement tm, using the theorems in
   thms as assumptions.

   tm is first rewritten by transform into a universally quantified
   statement.  Every theorem of thms that can be transformed the same way is
   specialised to the goal's bound variable (the others are silently
   dropped), and numLib.ARITH_PROVE is asked to prove the goal body from the
   conjunction of those assumptions.

   Raises HOL_ERR if tm cannot be transformed or the arithmetic proof fails.
   On success the theorem is also reported through Trace.trace. *)
fun SBAG_SOLVE thms tm = let
  val newgoal_thm = transform tm
  val newgoal_tm = rhs (concl newgoal_thm)
  val (gvar, gbody) = dest_forall newgoal_tm
  (* mapfilter drops the context theorems that transform or SPEC reject. *)
  val newasms = mapfilter (SPEC gvar o CONV_RULE transform) thms
  val goal_thm1 =
    if null newasms then
      (* list_mk_conj and LIST_CONJ both fail on an empty list, so with no
         usable assumptions the body is proved on its own instead. *)
      numLib.ARITH_PROVE gbody
    else let
      val newasms_tm = list_mk_conj (map concl newasms)
      val goal_thm0 = numLib.ARITH_PROVE (mk_imp(newasms_tm, gbody))
    in
      MP goal_thm0 (LIST_CONJ newasms)
    end
  (* Re-quantify, turn into "... = T" and chain with the transformation. *)
  val goal_thm2 = EQT_INTRO (GEN gvar goal_thm1)
  val thm = TRANS newgoal_thm goal_thm2
  val _  =  Trace.trace(1,Trace.PRODUCE(tm,"SBAG_SOLVE",thm))
in
  thm
end
114
(* True when no subterm of t satisfies is_diff. *)
fun diff_free t = not (can (find_term is_diff) t)

(* Terms the solver is willing to look at: SUB_BAG statements and equations
   between bags, provided they contain no BAG_DIFF subterm. *)
fun is_ok t = let
  fun is_bag_equation tm = is_eq tm andalso is_bag_ty (type_of (rand tm))
in
  (is_sub_bag t orelse is_bag_equation t) andalso diff_free t
end

(* SBAG_SOLVE wrapped by Cache.RCACHE, with free_vars and is_ok as its other
   two arguments; the cache itself is returned alongside the cached solver. *)
val (CACHED_SBAG_SOLVE, sbag_cache) =
    Cache.RCACHE(free_vars, is_ok, SBAG_SOLVE)
121
122
(* Decision procedure for the simplifier, built with Traverse.REDUCER.  Its
   context is a list of theorems carried inside the local exception CTXT:
   it starts empty, grows by the acceptable conjuncts of each batch of new
   context theorems, and is handed to CACHED_SBAG_SOLVE on application. *)
val SBAG_SOLVER = let
  exception CTXT of thm list;
  (* Unwrap a context value; any other exception is re-raised. *)
  fun get_ctxt (CTXT c) = c
    | get_ctxt e = raise e
  (* Split the new theorems into conjuncts and keep those whose conclusion
     passes is_ok, placing them in front of the existing context. *)
  fun add_ctxt(ctxt, newthms) = let
    val conjuncts = flatten (map CONJUNCTS newthms)
    val addthese = filter (fn th => is_ok (concl th)) conjuncts
  in
    CTXT (addthese @ get_ctxt ctxt)
  end
  fun apply_solver {context, ...} = CACHED_SBAG_SOLVE (get_ctxt context)
in
  Traverse.REDUCER
  {name=SOME"SBAG_SOLVER",
   initial = CTXT [],
   addcontext = add_ctxt,
   apply = apply_solver}
end;
138
(* Simpset fragment whose only content is the SBAG_SOLVER decision
   procedure. *)
val SBAG_SOLVE_ss = SSFRAG
  {name = SOME"SBAG_SOLVE",
   dprocs = [SBAG_SOLVER],
   ac = [],
   congs = [],
   convs = [],
   rewrs = [],
   filter = NONE}
143
144end
145