1(*=====================================================================  *)
2(* FILE          : quantHeuristicsLibAbbrev.sml                          *)
3(* DESCRIPTION   : Abbreviate subterms                                   *)
4(*                                                                       *)
5(* AUTHORS       : Thomas Tuerk                                          *)
6(* DATE          : Oct 2012                                              *)
7(* ===================================================================== *)
8
9
10structure quantHeuristicsLibAbbrev :> quantHeuristicsLibAbbrev =
11struct
12
13(*
14quietdec := true;
15loadPath :=
16            (concat [Globals.HOLDIR, "/src/quantHeuristics"])::
17            !loadPath;
18
19map load ["quantHeuristicsTheory"];
20load "ConseqConv"
21show_assums := true;
22quietdec := true;
23*)
24
25open HolKernel Parse boolLib Drule
26     quantHeuristicsLibBase
27
(* Holds exactly for the terms that are neither variables nor constants,
   i.e. for the terms for which introducing an abbreviation makes sense. *)
fun no_var_const_filter t =
  not (is_var t) andalso not (is_const t)
30
local
   (* Record one more occurrence of tm in the count dictionary, then continue
      with the body of tm if it is an abstraction, or with its operator and
      its operand if it is a combination. Variables and constants end the
      traversal. *)
   fun find_tms (d:(term, int) Redblackmap.dict) tm =
   let
      val d' = Redblackmap.update (d, tm, fn NONE => 1 | SOME k => k + 1)
   in
      case Lib.total body tm of
         SOME tm_body => find_tms d' tm_body
       | NONE =>
           (case Lib.total dest_comb tm of
               SOME (rator_tm, rand_tm) => find_tms (find_tms d' rator_tm) rand_tm
             | NONE => d')
   end

   val empty_tm_dict : (term, int) Redblackmap.dict =
      Redblackmap.mkDict Term.compare
in
   (* find_terms_count P t returns a pair (tL, count):
      - tL lists the distinct subterms of t (t itself included) that satisfy
        P, ordered by decreasing number of appearances in t;
      - count u efficiently looks up how often u appears in t (0 if it does
        not appear), using a dictionary that is built only once.
      Terms are identified up to Term.compare. *)
   fun find_terms_count P t =
   let
      val counts = find_tms empty_tm_dict t
      fun count_of u =
         case Redblackmap.peek (counts, u) of
            SOME k => k
          | NONE => 0
      val satisfying = filter (P o fst) (Redblackmap.listItems counts)
      val most_frequent_first =
         Lib.sort (fn (_, a:int) => fn (_, b) => a > b) satisfying
   in
      (map fst most_frequent_first, count_of)
   end
end
62
63
64(*
65val t = ``P (FST (h y)) /\ !x. (Q (SND (g x)) /\ P (f (FST (g x))))``
66
67val t = ``P (FST (h y)) /\ (Q (SND (g x)) /\ P (f (FST (g x))))``
68
69val elim_abbrev_abort = false
70fun elim_conv _ = NO_CONV
71val intro_abbrev = true
72val only_top = true
73*)
74
(* A selection function is applied as  sf ctx cf t  where ctx is the whole
   term being processed, cf counts how often a term occurs in ctx and t is
   a candidate subterm. It returns pairs (tm, name) of subterms to abbreviate
   together with the name to use for the abbreviation variable. *)
type selection_fun = term -> (term -> int) -> term -> (term * string) list

(* Combine several selection functions into one that returns the
   concatenation of their results, in list order. A member that raises
   HOL_ERR on a candidate contributes nothing for that candidate. *)
fun select_funs_combine (sfL:selection_fun list) : selection_fun =
   fn ctx => fn cf =>
      let
         val specialised = map (fn sf => sf ctx cf) sfL
         fun results_of t sf = sf t handle HOL_ERR _ => []
      in
         fn t => List.concat (map (results_of t) specialised)
      end
81
(* INTRO_QUANT_ABBREVIATIONS_CONV_base elim_conv elim_abbrev_abort
     intro_abbrev only_top select_funs t

   Conversion that introduces abbreviations for subterms of t.

   Candidates: the combined select_funs are applied to every subterm of t
   (subterms ordered by decreasing number of occurrences, as returned by
   find_terms_count). Candidates that are variables or constants are
   dropped, and of several alpha-equivalent candidates only the first one
   is kept.

   For each candidate (st, ss) a fresh variable v (a variant of a variable
   named ss) is introduced, and a subterm t0 in which st gets replaced is
   rewritten to
     !v. (st = v) ==> t0[v/st]
   This term is then passed to elim_conv, which receives quantifier
   heuristic parameters that refer to v.

   - elim_abbrev_abort : fail for a candidate if elim_conv does not change
                         the abbreviated term.
   - intro_abbrev      : if elim_conv does not change the abbreviated term
                         and the abbreviation is at the top level, state the
                         equation as  Abbrev (v = st)  instead of  st = v.
   - only_top          : do not descend into the term except through
                         leading universal quantifiers.

   Candidates are tried one after another (EVERY_CONV); a failing candidate
   leaves the term unchanged (TRY_CONV). Terms are compared with aconv
   throughout, so alpha-equivalent terms are treated alike. *)
fun INTRO_QUANT_ABBREVIATIONS_CONV_base elim_conv elim_abbrev_abort intro_abbrev only_top (select_funs:selection_fun list) t =
let
   val select_terms = let
      val (tL, t_count_fun) = find_terms_count (K true) t
      val sf = select_funs_combine select_funs t t_count_fun
      val sfs = flatten (map sf tL)
      val fsfs = filter (fn (tt, _) => no_var_const_filter tt) sfs

      (* Keep only the first of several alpha-equivalent candidates. *)
      fun my_insert ((tt, n), l) =
         if exists (fn (tt', _) => aconv tt' tt) l then l else (tt, n) :: l
   in
      rev (foldl my_insert [] fsfs)
   end

   (* All variables of t (free and bound) plus every abbreviation variable
      introduced so far; new variables are chosen as variants of these. *)
   val fvL = ref (all_vars t)

(*
   val (st, ss) = hd select_terms
*)

   fun try_select (st, ss) t =
   let
      val new_v = variant (!fvL) (mk_var (ss, type_of st))
      val s = [st |-> new_v]
      val _ = fvL := new_v :: (!fvL)

      (* Reject instantiations that are alpha-equivalent to st: they would
         simply undo the abbreviation. NOTE(review): the argument order
         (v, t, i, fvL) is taken from the original code; see inst_filter_qp
         in quantHeuristicsLibBase. *)
      fun v_inst_filter _ _ i _ = not (aconv i st)
      (* Only the abbreviation variable itself is accepted. *)
      fun v_filter v _ = (v = new_v)
      val v_qps = [inst_filter_qp [v_inst_filter], filter_qp [v_filter]]

(*
      val (_, t0) = dest_abs (rand (rand t))
val t0 = t
val t0 = qasm_t
val is_top = true
val t0 = snd (strip_forall t)
*)
      (* Abbreviate st in t0. If nothing is replaced at this level, descend:
         through a universal quantifier while still at the top level, and
         into all immediate subterms otherwise (which fails if only_top). *)
      fun try_subst is_top t0 =
      let
         val t' = subst s t0
         val is_top' = is_top andalso is_forall t0
      in
         if aconv t0 t' then
            (if only_top andalso not is_top' then fail ()
             else if is_top' then QUANT_CONV (try_subst true) t0
             else SUB_CONV (try_subst false) t0)
         else
         let
            val abbrev_t = mk_forall (new_v, mk_imp (mk_eq (st, new_v), t'))
            val elim_thm = QCHANGED_CONV (elim_conv v_qps) abbrev_t
                           handle HOL_ERR _ => REFL abbrev_t
            val elim_rhs = rhs (concl elim_thm)
            (* Getting back the original term is no progress. *)
            val _ = if aconv t0 elim_rhs then fail () else ()
            val no_simp = aconv abbrev_t elim_rhs
            val _ = if elim_abbrev_abort andalso no_simp then fail () else ()
            (* Mark an unsimplified top-level abbreviation with Abbrev by
               rewriting the antecedent with pre_thm, intended to be
               (st = new_v) = Abbrev (new_v = st). *)
            val elim_thm =
               if no_simp andalso intro_abbrev andalso is_top then
                  let
                     val pre_thm = CONV_RULE (LHS_CONV SYM_CONV)
                        (GSYM (ISPEC (mk_eq (new_v, st)) markerTheory.Abbrev_def))
                  in
                     QUANT_CONV (RATOR_CONV (RAND_CONV (K pre_thm))) abbrev_t
                  end
               else elim_thm

            (* t0 = !new_v. (st = new_v) ==> t' *)
            val abbrev_thm = prove (mk_eq (t0, abbrev_t),
                                    Unwind.UNWIND_FORALL_TAC THEN REWRITE_TAC [])
         in
            TRANS abbrev_thm elim_thm
         end
      end
   in
      try_subst true t
   end
in
   EVERY_CONV (map (fn arg => TRY_CONV (try_select arg)) select_terms) t
end;
147
(* Abbreviation conversion without an elimination step: the elimination
   conversion is K NO_CONV and elim_abbrev_abort is false, so the selected
   subterms are just replaced by universally quantified variables. *)
fun GEN_SIMPLE_QUANT_ABBREV_CONV intro_abbrev only_top select_funs =
   let
      val no_elim = K NO_CONV
      val abort_if_unsimplified = false
   in
      INTRO_QUANT_ABBREVIATIONS_CONV_base no_elim abort_if_unsimplified
         intro_abbrev only_top select_funs
   end
150
(* Simple abbreviation conversion: no Abbrev markers, and abbreviations may
   be introduced anywhere in the term, not only at the top level. *)
fun SIMPLE_QUANT_ABBREV_CONV select_funs =
   let
      val intro_abbrev = false
      val only_top = false
   in
      GEN_SIMPLE_QUANT_ABBREV_CONV intro_abbrev only_top select_funs
   end
153
(* Tactic counterpart of the simple abbreviation conversion. It abbreviates
   only at the top level and marks abbreviations with Abbrev. The conversion
   is passed to ConseqConv.DISCH_ASM_CONV_TAC. *)
fun SIMPLE_QUANT_ABBREV_TAC (select_funs:selection_fun list) =
   let
      val abbrev_conv = GEN_SIMPLE_QUANT_ABBREV_CONV true true select_funs
   in
      ConseqConv.DISCH_ASM_CONV_TAC abbrev_conv
   end
156
(* Elimination conversion for QUANT_ABBREV_CONV. It runs
   FAST_QUANT_INSTANTIATE_CONV with the abbreviation-specific parameters qps
   followed by the user-supplied parameters qpL, and fails via QCHANGED_CONV
   if the term is unchanged. *)
fun quant_elim_conv qpL qps =
   let
      val all_qps = qps @ qpL
   in
      QCHANGED_CONV (FAST_QUANT_INSTANTIATE_CONV all_qps)
   end
158
(* Abbreviation conversion that tries to eliminate each new abbreviation
   using quantifier instantiation with the parameters qpL. It does not abort
   on unsimplified abbreviations, does not use Abbrev markers, and works
   anywhere in the term. *)
fun QUANT_ABBREV_CONV select_funs qpL =
   INTRO_QUANT_ABBREVIATIONS_CONV_base (quant_elim_conv qpL)
      false (* elim_abbrev_abort *)
      false (* intro_abbrev *)
      false (* only_top *)
      select_funs
161
(* Tactic counterpart of QUANT_ABBREV_CONV; the conversion is passed to
   ConseqConv.DISCH_ASM_CONV_TAC. *)
fun QUANT_ABBREV_TAC (select_funs:selection_fun list) qpL =
   let
      val abbrev_conv = QUANT_ABBREV_CONV select_funs qpL
   in
      ConseqConv.DISCH_ASM_CONV_TAC abbrev_conv
   end
164
165
(* Some selection functions *)
167
(* select_fun_constant c i name looks for applications of the constant c and
   selects their i-th argument (counting from 1) to be abbreviated with a
   variable called name; i = 0 selects the whole application instead.
   Terms whose head is not c, or that have fewer than i arguments, select
   nothing. *)
fun select_fun_constant c i (name:string) = (fn ctx => fn cf => fn t =>
  let
    val (head, args) = strip_comb t
  in
    if not (same_const c head) then []
    else if i = 0 then [(t, name)]
    else [(el i args, name)]
  end handle HOL_ERR _ => []) : selection_fun
177
(* Ready-made selection functions that abbreviate the (first) argument of
   some commonly used constants. *)
local
  fun first_arg_of c name = select_fun_constant c 1 name
in
val FST_select_fun = first_arg_of pairSyntax.fst_tm "p"
val SND_select_fun = first_arg_of pairSyntax.snd_tm "p"
val IS_SOME_select_fun = first_arg_of optionSyntax.is_some_tm "x"
val THE_select_fun = first_arg_of optionSyntax.the_tm "p"
end
183
(* For pairs, the pattern "(\(x,y). P) X" is useful as well.
   select_fun_pabs name matches an application of a paired abstraction and
   selects its argument X to be abbreviated by a variable called name. Like
   FST_select_fun and SND_select_fun, it selects the pair-typed argument,
   not the abstraction applied to it. The `not (is_abs p)` test skips
   applications of ordinary lambda abstractions. *)
fun select_fun_pabs name = (fn ctx => fn cf => (fn t =>
  let
    val (p, a) = dest_comb t;
    val _ = if pairSyntax.is_pabs p andalso not (is_abs p) then () else fail();
  in
    [(a, name)]
  end handle HOL_ERR _ => [])): selection_fun;
192
(* Selection function for pairs: combines select_fun_pabs "p" with
   FST_select_fun and SND_select_fun. *)
val PAIR_select_fun =
   let
      val pair_select_funs =
         [select_fun_pabs "p", FST_select_fun, SND_select_fun]
   in
      select_funs_combine pair_select_funs
   end;
194
195
(* In general, we might want pattern matching. The following function matches
   a pattern against each candidate term and abbreviates every matched pattern
   variable whose name does not start with _. The general version additionally
   lets you require that the matched term occurs at least n times. *)
199(*
200val q = `f (_, xx)`
201val ctx = ``Q (f (2, g(z))) /\ Y x``
202val t = ``f (2, g(z))``
203*)
(* select_fun_pattern_occ n q:
   - the quotation q is parsed in the context of all variables of ctx, and
     those variables are passed to raw_match as the set of fixed variables;
   - q is then matched against each candidate term t;
   - for every pattern variable whose name does not start with "_" and whose
     matched subterm occurs at least n times (according to cf), that subterm
     is selected, named after the pattern variable. *)
fun select_fun_pattern_occ n q = (fn ctx => fn cf =>
let
  val ctx_vars = all_vars ctx
  val pattern = Parse.parse_in_context ctx_vars q
  val fixed_vars = HOLset.fromList (Term.compare) ctx_vars
  fun wanted {redex, residue} =
     not (String.isPrefix "_" (fst (dest_var redex))) andalso cf residue >= n
  fun to_selection {redex, residue} = (residue, fst (dest_var redex))
in
  fn t =>
     let
        val ((tm_subst, _), _) = raw_match [] fixed_vars pattern t ([], [])
     in
        map to_selection (filter wanted tm_subst)
     end
end):selection_fun

(* select_fun_pattern q: select_fun_pattern_occ with a minimum occurrence
   count of 0. *)
fun select_fun_pattern pattern_q = select_fun_pattern_occ 0 pattern_q
218
(* You might also want to abbreviate the whole matched term. *)
220
(* select_fun_match_occ n q name_fun selects a candidate term t as a whole,
   named name_fun t, provided that
   - t occurs at least n times (according to cf), and
   - the pattern q matches t. q is parsed in the context of all variables
     of ctx, which are passed to raw_match as fixed variables.
   Otherwise it fails (fail () or the failing match). *)
fun select_fun_match_occ n q name_fun = (fn ctx => fn cf =>
let
  val ctx_vars = all_vars ctx
  val pattern = Parse.parse_in_context ctx_vars q
  val fixed_vars = HOLset.fromList (Term.compare) ctx_vars
in
  fn t =>
     if cf t < n then fail ()
     else
        let
           val _ = raw_match [] fixed_vars pattern t ([], [])
        in
           [(t, name_fun t)]
        end
end):selection_fun

(* select_fun_match q name: select_fun_match_occ with a minimum occurrence
   count of 0 and the fixed abbreviation name name. *)
fun select_fun_match q name = select_fun_match_occ 0 q (fn _ => name)
234
235
236(* Testing it
237
238open quantHeuristicsLibAbbrev
239open quantHeuristicsLib
240
241val t = ``P (FST (g x)) /\ (P (FST (g x))) /\ P (FST (f x)) /\ !y. Q p /\ P (SND (g y))``
242
243
244SIMPLE_QUANT_ABBREV_CONV [FST_select_fun]  t
245SIMPLE_QUANT_ABBREV_CONV [SND_select_fun]  t
246
247QUANT_ABBREV_CONV [FST_select_fun] [std_qp] t
248QUANT_ABBREV_CONV [SND_select_fun] [std_qp] t
249QUANT_ABBREV_CONV [FST_select_fun, SND_select_fun] [std_qp] t
250
251
252set_goal ([], t)
253e (SIMPLE_QUANT_ABBREV_TAC [SND_select_fun, FST_select_fun])
254e (QUANT_ABBREV_TAC [SND_select_fun, FST_select_fun] [std_qp])
255
256
257val select_funs = [FST_select_fun, SND_select_fun]
258
259set_goal ([], t)
260e (QUANT_ABBREV_TAC [FST_select_fun, SND_select_fun] [std_qp])
261e (SIMPLE_QUANT_ABBREV_TAC [FST_select_fun, SND_select_fun])
262
263Q.UNABBREV_TAC `p''`
264
265val t2 = ``Q x ==> (IS_SOME (g x)) ==> (IS_SOME (f x)) ==> P (THE (g x), THE (f x))``
266
267REPEAT STRIP_TAC
268QUANT_ABBREV_TAC [select_fun_pattern `IS_SOME dummy`] [std_qp]
269
270QUANT_ABBREV_TAC [select_fun_match `f x` "gx"] [std_qp]
271
272QUANT_ABBREV_CONV [select_fun_match `g x` "gx"] [std_qp] t2
273SIMPLE_QUANT_ABBREV_TAC [THE_select_fun]
274
275*)
276
277end
278