1structure bagSimps :> bagSimps =
2struct
3
4open HolKernel Parse boolLib simpLib boolSimps bagSyntax bagTheory;
5
(* Local alias for Cache.cache (presumably required by this structure's
   signature -- confirm against bagSimps.sig). *)
type cache = Cache.cache;

(* Fixities for the HOL operators used infix in this structure.  Fixity is
   not carried by structures in SML, so it has to be restated here. *)
infix THENC
infix |->
infixr -->

(* ERR function message: a HOL error tagged with this structure's name. *)
fun ERR function message = mk_HOL_ERR "bagSimps" function message
12
(* Simpset fragment whose only content is the associativity/commutativity
   pair for BAG_UNION, supplied through the "ac" field. *)
val BAG_AC_ss =
  simpLib.SSFRAG
    {name = SOME "BAG_AC",
     ac = [(SPEC_ALL ASSOC_BAG_UNION, SPEC_ALL COMM_BAG_UNION)],
     rewrs = [], convs = [], congs = [], dprocs = [],
     filter = NONE};
18
(* remove x ys: drop the first element of ys that is equal (=) to x, keeping
   the others in order.  Raises ERR "remove" "no such element" when x does
   not occur in ys. *)
fun remove x ys =
  case ys of
      [] => raise ERR "remove" "no such element"
    | y :: rest => if y = x then rest else y :: remove x rest
22
(* remove_list xs l2: remove one occurrence of each element of xs from l2,
   working through xs from left to right.  Fails as remove does if some
   element of xs is no longer present. *)
fun remove_list xs l2 = List.foldl (fn (x, acc) => remove x acc) l2 xs
25
(* buac_prover ty: AC_CONV for BAG_UNION, with ASSOC_BAG_UNION and
   COMM_BAG_UNION type-instantiated so that alpha becomes ty. *)
fun buac_prover ty = let
  val at_ty = INST_TYPE [alpha |-> ty]
in
  AC_CONV (at_ty ASSOC_BAG_UNION, at_ty COMM_BAG_UNION)
end
31
local
  (* Push universal quantifiers through the conjunctions of th (using
     FORALL_AND_THM) and keep the first resulting conjunct. *)
  fun first_conjunct th =
    hd (CONJUNCTS (CONV_RULE (SIMP_CONV bool_ss [FORALL_AND_THM]) th))
in
(* Cancellation theorems used by CANCEL_CONV for SUB_BAG and bag difference. *)
val SUB_BAG_UNION_eliminate' = first_conjunct SUB_BAG_UNION_eliminate
val BAG_DIFF_UNION_eliminate' = first_conjunct BAG_DIFF_UNION_eliminate
end

(* First conjunct of BAG_UNION_EMPTY; CANCEL_CONV applies it with REWR_CONV
   to terms of the form  cpt + {||}. *)
val BU_EMPTY_R = hd (CONJUNCTS BAG_UNION_EMPTY)
41
(* CANCEL_CONV tm: cancel the BAG_UNION summands shared by the two sides of
   a SUB_BAG, bag-difference (-) or equation term.

   Both arguments are flattened with strip_union.  Their multiset
   intersection (the "common part", cpt) is moved to the front of each side
   by AC reasoning on BAG_UNION.  The cancellation theorem matching the shape
   of tm is then applied with REWR_CONV:
     SUB_BAG  : SUB_BAG_UNION_eliminate'
     -        : BAG_DIFF_UNION_eliminate'
     =        : BAG_UNION_LEFT_CANCEL
   A side made up entirely of common summands is first rewritten to
   cpt + {||} (via BU_EMPTY_R), so it still has the union shape the
   cancellation theorem expects.

   Fails with the HOL_ERR from dest_eq if tm has none of the three shapes.
   Raises ERR "CANCEL_CONV" "No common parts to eliminate" when the two
   sides share no summand. *)
fun CANCEL_CONV tm = let
  (* Choose the relation constructor and cancellation theorem by shape. *)
  val (mk_rel, thm, (arg1, arg2)) =
    (mk_sub_bag, SUB_BAG_UNION_eliminate', dest_sub_bag tm)
    handle HOL_ERR _ =>
      (mk_diff, BAG_DIFF_UNION_eliminate', dest_diff tm)
      handle HOL_ERR _ => (mk_eq, BAG_UNION_LEFT_CANCEL, dest_eq tm)
  val basetype = base_type arg1
  val bag_type = basetype --> numSyntax.num
  val arg1_ts = strip_union arg1 and arg2_ts = strip_union arg2
  (* Multiset intersection of two lists, keeping the order of the first.
     Membership is tested explicitly instead of catching remove's exception:
     a catch-all "handle _" would also swallow Interrupt. *)
  fun common [] _ = []
    | common _ [] = []
    | common (x::xs) ys =
        if List.exists (fn y => y = x) ys then x :: common xs (remove x ys)
        else common xs ys
  val common_part = common arg1_ts arg2_ts
  val _ = not (null common_part) orelse
    raise ERR "CANCEL_CONV" "No common parts to eliminate"
  val rem1 = remove_list common_part arg1_ts
  val rem2 = remove_list common_part arg2_ts
  val cpt = list_mk_union common_part
  (* The AC goal  arg = cpt  or  arg = cpt + (remaining summands). *)
  fun ac_goal arg rem =
    mk_eq (arg, if null rem then cpt else mk_union (cpt, list_mk_union rem))
  (* Build the AC conversion for this element type once, use it twice. *)
  val ac_prove = EQT_ELIM o buac_prover basetype
  val ac1thm = ac_prove (ac_goal arg1 rem1)
  val ac2thm = ac_prove (ac_goal arg2 rem2)
  (* Extend |- arg = cpt to |- arg = cpt + {||}. *)
  fun add_emptybag th =
    TRANS th (SYM (REWR_CONV BU_EMPTY_R (mk_union (cpt, mk_bag ([], basetype)))))
  fun pad rem th = if null rem then add_emptybag th else th
  val thm1 = pad rem1 ac1thm
  val thm2 = pad rem2 ac2thm
  (* Rewrite both arguments of the relation at once, then cancel. *)
  val v1 = genvar bag_type and v2 = genvar bag_type
  val template = mk_rel (v1, v2)
in
  SUBST_CONV [v1 |-> thm1, v2 |-> thm2] template THENC
  REWR_CONV thm
end tm
81
(* Bag-typed pattern variables for the simplifier keys built below. *)
val x = mk_var("x", bag_ty)
and y = mk_var("y", bag_ty)

(* mk_cancelconv (t, s): a simplifier conversion record that fires
   CANCEL_CONV on terms matching  t x y.  The record is named after s and
   has trace level 2.  CHANGED_CONV makes the conversion fail rather than
   return an unchanged result. *)
fun mk_cancelconv (t, s) = let
  val pattern = list_mk_comb(t, [x, y])
  val cancel = CHANGED_CONV CANCEL_CONV
in
  {name = "CANCEL_CONV (" ^ s ^ ")",
   trace = 2,
   key = SOME ([], pattern),
   conv = K (K cancel)}
end
88
(* Equality at the generic bag type bag_ty; the key constant for cancelling
   bag equations in BAG_ss. *)
val BAG_EQ_tm = mk_const("=", bag_ty --> bag_ty --> bool);

(* BAG_ss: summand cancellation (CANCEL_CONV) for bag difference, SUB_BAG
   and bag equality, together with a few basic bag rewrites. *)
val BAG_ss = let
  val cancellers =
    map mk_cancelconv [(BAG_DIFF_tm, "DIFF"),
                       (SUB_BAG_tm, "SUB_BAG"),
                       (BAG_EQ_tm, "=")]
  val basic_rewrites =
    [BAG_UNION_EMPTY, BAG_DIFF_EMPTY, SUB_BAG_REFL,
     SUB_BAG_EMPTY, FINITE_EMPTY_BAG, NOT_IN_EMPTY_BAG]
in
  SSFRAG {name = SOME "BAG",
          convs = cancellers, rewrs = basic_rewrites,
          ac = [], congs = [], dprocs = [], filter = NONE}
end;
101
(* transform t: turn a SUB_BAG or equation term into a statement about bag
   applications.  SUB_BAG is rewritten with SUB_BAG_LEQ, an equation is made
   extensional with FUN_EQ_CONV, and any other term makes the conversion
   fail (NO_CONV).  BAG_UNION is then unfolded and beta-redexes reduced. *)
fun transform t = let
  val unfold_relation =
    if is_sub_bag t then REWR_CONV SUB_BAG_LEQ
    else if is_eq t then FUN_EQ_CONV
    else NO_CONV
in
  (unfold_relation THENC PURE_REWRITE_CONV [BAG_UNION]
                   THENC DEPTH_CONV BETA_CONV) t
end
109
(* SBAG_SOLVE thms tm: prove |- tm = T for a SUB_BAG or bag-equation goal
   tm by reducing it to arithmetic.

   transform rewrites tm into a universally quantified statement, and the
   body is proved for an arbitrary value of the bound variable.  Each
   context theorem in thms is transformed the same way and specialised to
   that variable; theorems that fail to transform or specialise are dropped
   by mapfilter.  numLib.ARITH_PROVE then proves the body from the
   conjunction of the surviving instances.  If none survive, the body is
   proved on its own: list_mk_conj and LIST_CONJ fail on an empty list, so
   goals needing no assumptions would otherwise be rejected.

   The produced theorem is reported through Trace at level 1. *)
fun SBAG_SOLVE thms tm = let
  val newgoal_thm = transform tm
  val newgoal_tm = rhs (concl newgoal_thm)
  val (gvar, gbody) = dest_forall newgoal_tm
  val newasms = mapfilter (SPEC gvar o CONV_RULE transform) thms
  val goal_thm1 =
    if null newasms then numLib.ARITH_PROVE gbody
    else let
      val newasms_tm = list_mk_conj (map concl newasms)
      val goal_thm0 = numLib.ARITH_PROVE (mk_imp (newasms_tm, gbody))
    in
      MP goal_thm0 (LIST_CONJ newasms)
    end
  val goal_thm2 = EQT_INTRO (GEN gvar goal_thm1)
  val thm = TRANS newgoal_thm goal_thm2
  val _ = Trace.trace (1, Trace.PRODUCE (tm, "SBAG_SOLVE", thm))
in
  thm
end
124
(* diff_free t: true when no subterm of t satisfies is_diff. *)
fun diff_free t = not (can (find_term is_diff) t)

(* is_ok t: t is a SUB_BAG term, or an equation whose right-hand side has
   bag type, and t contains no bag difference. *)
fun is_ok t = let
  val bag_relation =
    is_sub_bag t orelse (is_eq t andalso is_bag_ty (type_of (rand t)))
in
  bag_relation andalso diff_free t
end

(* SBAG_SOLVE wrapped by Cache.RCACHE, with free_vars and is_ok as the
   first two arguments (see Cache for the exact caching policy). *)
val (CACHED_SBAG_SOLVE, sbag_cache) =
  Cache.RCACHE (free_vars, is_ok, SBAG_SOLVE)
131
132
(* SBAG_SOLVER: a Traverse reducer that collects the is_ok conjuncts of
   context theorems and hands them to CACHED_SBAG_SOLVE.  The accumulated
   context is carried inside the local exception CTXT. *)
val SBAG_SOLVER = let
  exception CTXT of thm list
  (* Recover the theorem list packed into a CTXT value. *)
  fun get_ctxt e = (raise e) handle CTXT c => c
  fun relevant th = is_ok (concl th)
  (* New relevant conjuncts go in front of those already collected. *)
  fun add_ctxt (ctxt, newthms) = let
    val conjuncts = flatten (map CONJUNCTS newthms)
  in
    CTXT (filter relevant conjuncts @ get_ctxt ctxt)
  end
in
  Traverse.REDUCER
    {name = SOME "SBAG_SOLVER",
     initial = CTXT [],
     addcontext = add_ctxt,
     apply = fn args => CACHED_SBAG_SOLVE (get_ctxt (#context args))}
end;
148
(* Simpset fragment whose only component is the SBAG_SOLVER decision
   procedure. *)
val SBAG_SOLVE_ss =
  SSFRAG {name = SOME "SBAG_SOLVE",
          dprocs = [SBAG_SOLVER],
          rewrs = [], convs = [], congs = [], ac = [],
          filter = NONE}
153
154end
155