1(* ===================================================================== *)
2(* FILE          : Ho_Rewrite.sml                                        *)
3(* DESCRIPTION   : Rewriting routines using higher-order matching. A     *)
(*                 straightforward adaptation of the first-order         *)
5(*                 rewriter found in Rewrite.sml.                        *)
6(*                                                                       *)
7(* AUTHOR        : Don Syme                                              *)
8(* ===================================================================== *)
9
10
11structure Ho_Rewrite :> Ho_Rewrite =
12struct
13
14open HolKernel boolTheory boolSyntax Abbrev
15     Drule Conv Tactic Tactical Ho_Net;
16
(* Predicates over terms. *)
type pred = term -> bool;

(* Fixity for the pair-mapping combinator (used by mk_rewrites) and the
   substitution-entry constructor (used by HIGHER_REWRITE_CONV). *)
infix ## |->;

(* ERR fname msg builds a HOL_ERR whose origin structure is "Ho_Rewrite". *)
fun ERR fname msg = mk_HOL_ERR "Ho_Rewrite" fname msg;

(* Local alias for the parser's term printer. *)
fun term_to_string tm = Parse.term_to_string tm;
24
(*-----------------------------------------------------------------------------
 * Split a theorem into a list of theorems suitable for rewriting:
 *
 *   1. Specialize all universally quantified variables (SPEC_ALL).
 *
 *   2. Keep an equation |- l = r as it is.
 *
 *   3. Split a conjunction, processing each conjunct recursively
 *      (steps 1-4 are applied again to each part):
 *
 *        |- t1 /\ t2     -->    [|- t1 ; |- t2]
 *
 *   4. Turn anything else into an equation with F or T:
 *
 *        |- ~t  -->  |- t = F        and        |- t  -->  |- t = T
 *
 *---------------------------------------------------------------------------*)
37
(* mk_rewrites th: the list of rewrite equations derived from th.
   The theorem is first fully specialised.  An equation is kept as is.  A
   conjunction is split and each half processed recursively, with the
   left half's rewrites first.  A negation ~t becomes t = F, and any
   other conclusion t becomes t = T.  Failures are re-raised wrapped
   with origin Ho_Rewrite.mk_rewrites. *)
fun mk_rewrites th =
  let
    val spec = SPEC_ALL th
    val c = concl spec
  in
    if is_eq c then [spec]
    else if is_conj c then
      let val (left, right) = CONJ_PAIR spec
      in mk_rewrites left @ mk_rewrites right
      end
    else if is_neg c then [EQF_INTRO spec]
    else [EQT_INTRO spec]
  end
  handle e => raise (wrap_exn "Ho_Rewrite" "mk_rewrites" e);
48
49
(* When set, REWRITES_CONV reports each rewrite it performs via HOL_MESG.
   Exposed to users as the boolean trace "Ho_Rewrite". *)
val monitoring : bool ref = ref false;

val _ = register_btrace ("Ho_Rewrite", monitoring);
53
54(*---------------------------------------------------------------------------
55    A datatype of rewrite rule sets.
56 ---------------------------------------------------------------------------*)
57
datatype rewrites = RW of {thms :thm list,  net :conv Ho_Net.net}

(* The theorems a rule set was built from, in the order they were added. *)
fun dest_rewrites (RW r) = #thms r

(* A rule set with no theorems and an empty net. *)
val empty_rewrites = RW {net = Ho_Net.empty, thms = []}

(* The rule set consulted by the non-PURE rewriting functions. *)
val implicit = ref empty_rewrites;
65
(* add_rewrites rws thl: extend the rule set rws with the theorems thl.
   The raw theorems are appended to the recorded list.  Each theorem is
   split by mk_rewrites, and every resulting equation is entered into the
   net under its left-hand side, as a HO_REWR_CONV.  The free variables of
   its hypotheses are passed to the net as the variables to keep fixed. *)
fun add_rewrites (RW {thms, net}) thl =
  let
    val eqns = List.concat (map mk_rewrites thl)
    fun entry th =
      (HOLset.listItems (hyp_frees th), lhs (concl th), Conv.HO_REWR_CONV th)
    val entries = map entry eqns
  in
    RW {thms = thms @ thl,
        net = List.foldr (fn (e, acc) => Ho_Net.enter e acc) net entries}
  end
72
(* Theorems currently held in the implicit rule set. *)
fun implicit_rewrites () =
  let val RW {thms, ...} = !implicit
  in thms
  end;

(* Replace the implicit rule set with one built from thl alone. *)
fun set_implicit_rewrites thl =
  let val fresh = add_rewrites empty_rewrites thl
  in implicit := fresh
  end;

(* Extend the current implicit rule set with thl. *)
fun add_implicit_rewrites thl =
  let val extended = add_rewrites (!implicit) thl
  in implicit := extended
  end;
76
(* stringulate f xs: render each element of xs with f (left to right) and
   put the separator ",\n" between consecutive results, with none after
   the last one. *)
fun stringulate f xs =
  let
    fun interleave [] = []
      | interleave [s] = [s]
      | interleave (s :: rest) = s :: ",\n" :: interleave rest
  in
    interleave (map f xs)
  end;
80
81
82(*---------------------------------------------------------------------------
83      Create a conversion from some rewrite rules
84 ---------------------------------------------------------------------------*)
85
(* REWRITES_CONV rws tm: rewrite tm with the first applicable conversion
   that the net of rws returns for tm.  The conversion fails (NO_CONV /
   FIRST_CONV) when none applies.  With the monitoring trace on, every
   candidate is tried.  The chosen rewrite is reported, and so are all
   successful alternatives when more than one applies. *)
fun REWRITES_CONV (RW {net, ...}) tm =
  let
    val candidates = Ho_Net.lookup tm net
  in
    if not (!monitoring) then Conv.FIRST_CONV candidates tm
    else
      case mapfilter (fn cnv => cnv tm) candidates of
        [] => Conv.NO_CONV tm
      | [only] =>
          (HOL_MESG (String.concat ["Rewrite:\n", Parse.thm_to_string only]);
           only)
      | results as (chosen :: _) =>
          (HOL_MESG (String.concat
             ["Multiple rewrites possible (first taken):\n",
              String.concat (stringulate Parse.thm_to_string results)]);
           chosen)
  end;
96
(* Seed the implicit rule set with the basic boolean clauses from
   boolTheory, in this order. *)
val _ =
  let open boolTheory
  in
    add_implicit_rewrites
      [REFL_CLAUSE, EQ_CLAUSES, NOT_CLAUSES, AND_CLAUSES,
       OR_CLAUSES, IMP_CLAUSES, FORALL_SIMP, EXISTS_SIMP,
       ABS_SIMP, SELECT_REFL, SELECT_REFL_2, COND_CLAUSES]
  end;
104
105(* =====================================================================*)
106(* Main rewriting conversion                                            *)
107(* =====================================================================*)
108
(* GEN_REWRITE_CONV' rw_func rws thl: extend the rule set rws with thl,
   turn the result into a single-step conversion with REWRITES_CONV, and
   hand that to the traversal strategy rw_func (e.g. TOP_DEPTH_CONV). *)
fun GEN_REWRITE_CONV' rw_func rws thl =
  let val rules = add_rewrites rws thl
  in rw_func (REWRITES_CONV rules)
  end;
111
112(* ---------------------------------------------------------------------*)
113(* Rewriting conversions.                                               *)
114(* ---------------------------------------------------------------------*)
115
(* The PURE_ variants use only the supplied theorems.  The others also use
   the implicit rule set, which is dereferenced on every call so that later
   additions are seen.  ONCE_ variants traverse with ONCE_DEPTH_CONV, the
   rest with TOP_DEPTH_CONV. *)
fun PURE_REWRITE_CONV thl      = GEN_REWRITE_CONV' TOP_DEPTH_CONV empty_rewrites thl
fun PURE_ONCE_REWRITE_CONV thl = GEN_REWRITE_CONV' ONCE_DEPTH_CONV empty_rewrites thl
fun REWRITE_CONV thl           = GEN_REWRITE_CONV' TOP_DEPTH_CONV (!implicit) thl
fun ONCE_REWRITE_CONV thl      = GEN_REWRITE_CONV' ONCE_DEPTH_CONV (!implicit) thl

(* Rule form over an explicit base rule set.  A later GEN_REWRITE_RULE in
   this structure, taking no base rule set, shadows this one. *)
fun GEN_REWRITE_RULE f rws thl = CONV_RULE (GEN_REWRITE_CONV' f rws thl);

fun PURE_REWRITE_RULE thl      = GEN_REWRITE_RULE TOP_DEPTH_CONV empty_rewrites thl
fun PURE_ONCE_REWRITE_RULE thl = GEN_REWRITE_RULE ONCE_DEPTH_CONV empty_rewrites thl
fun REWRITE_RULE thl           = GEN_REWRITE_RULE TOP_DEPTH_CONV (!implicit) thl
fun ONCE_REWRITE_RULE thl      = GEN_REWRITE_RULE ONCE_DEPTH_CONV (!implicit) thl;
125
(* Rewrite th using, in addition to the supplied theorems, its own
   hypotheses.  The hypotheses are turned into theorems with ASSUME and
   placed ahead of the supplied list. *)
local
  fun with_hyps th extra = map ASSUME (hyp th) @ extra
in
fun PURE_ASM_REWRITE_RULE L th = PURE_REWRITE_RULE (with_hyps th L) th
fun PURE_ONCE_ASM_REWRITE_RULE L th =
    PURE_ONCE_REWRITE_RULE (with_hyps th L) th
fun ASM_REWRITE_RULE thl th = REWRITE_RULE (with_hyps th thl) th
fun ONCE_ASM_REWRITE_RULE L th = ONCE_REWRITE_RULE (with_hyps th L) th
end
131
(* Tactic form over an explicit base rule set.  A later GEN_REWRITE_TAC in
   this structure, taking no base rule set, shadows this one. *)
fun GEN_REWRITE_TAC f rws thl = CONV_TAC (GEN_REWRITE_CONV' f rws thl)

fun PURE_REWRITE_TAC thl      = GEN_REWRITE_TAC TOP_DEPTH_CONV empty_rewrites thl
fun PURE_ONCE_REWRITE_TAC thl = GEN_REWRITE_TAC ONCE_DEPTH_CONV empty_rewrites thl
fun REWRITE_TAC thl           = GEN_REWRITE_TAC TOP_DEPTH_CONV (!implicit) thl
fun ONCE_REWRITE_TAC thl      = GEN_REWRITE_TAC ONCE_DEPTH_CONV (!implicit) thl

(* Assumption variants: the goal's assumptions (as theorems, ahead of the
   supplied list) are added to the rewrite rules. *)
fun PURE_ASM_REWRITE_TAC L =
    ASSUM_LIST (fn asms => PURE_REWRITE_TAC (asms @ L))
fun ASM_REWRITE_TAC thl =
    ASSUM_LIST (fn asms => REWRITE_TAC (asms @ thl))
fun PURE_ONCE_ASM_REWRITE_TAC L =
    ASSUM_LIST (fn asms => PURE_ONCE_REWRITE_TAC (asms @ L))
fun ONCE_ASM_REWRITE_TAC L =
    ASSUM_LIST (fn asms => ONCE_REWRITE_TAC (asms @ L));
141
(*---------------------------------------------------------------------------
   Filtered assumption variants.  The predicate f (term -> bool) selects
   which assumptions are added, ahead of thl, as extra rewrite rules:
     - _RULE forms: the hypotheses of th that satisfy f, via ASSUME;
     - _TAC forms:  the goal assumptions whose conclusions satisfy f.
 ---------------------------------------------------------------------------*)
local
  (* Hypotheses of th accepted by f, as assumed theorems. *)
  fun chosen_hyps f th = map ASSUME (filter f (hyp th))
  (* Assumption theorems whose conclusions are accepted by f. *)
  fun chosen_asms f asl = filter (f o concl) asl
in
fun FILTER_PURE_ASM_REWRITE_RULE f thl th =
    PURE_REWRITE_RULE (chosen_hyps f th @ thl) th
fun FILTER_ASM_REWRITE_RULE f thl th =
    REWRITE_RULE (chosen_hyps f th @ thl) th
fun FILTER_PURE_ONCE_ASM_REWRITE_RULE f thl th =
    PURE_ONCE_REWRITE_RULE (chosen_hyps f th @ thl) th
fun FILTER_ONCE_ASM_REWRITE_RULE f thl th =
    ONCE_REWRITE_RULE (chosen_hyps f th @ thl) th

fun FILTER_PURE_ASM_REWRITE_TAC f thl =
    ASSUM_LIST (fn asl => PURE_REWRITE_TAC (chosen_asms f asl @ thl))
fun FILTER_ASM_REWRITE_TAC f thl =
    ASSUM_LIST (fn asl => REWRITE_TAC (chosen_asms f asl @ thl))
fun FILTER_PURE_ONCE_ASM_REWRITE_TAC f L =
    ASSUM_LIST (fn l => PURE_ONCE_REWRITE_TAC (chosen_asms f l @ L))
fun FILTER_ONCE_ASM_REWRITE_TAC f thl =
    ASSUM_LIST (fn asl => ONCE_REWRITE_TAC (chosen_asms f asl @ thl))
end;
158
159
(* Rewriting from the supplied theorems only (no implicit rules), using
   the traversal strategy rw_func chosen by the caller.  These shadow the
   earlier GEN_REWRITE_RULE / GEN_REWRITE_TAC, which also took a base rule
   set. *)
fun GEN_REWRITE_CONV rw_func thl =
    GEN_REWRITE_CONV' rw_func empty_rewrites thl;

fun GEN_REWRITE_RULE rw_func thl =
    let val cnv = GEN_REWRITE_CONV rw_func thl
    in CONV_RULE cnv
    end;

fun GEN_REWRITE_TAC rw_func thl =
    let val cnv = GEN_REWRITE_CONV rw_func thl
    in CONV_TAC cnv
    end;
168
169
(***************************************************************************
 * SUBST_MATCH (|-u=v) th   searches for an instance of u in
 * (the conclusion of) th and then substitutes the corresponding
 * instance of v. Only the first instance found is used; the search tries
 * the whole conclusion first, then (recursively) its operator, operand
 * and abstraction body. Much faster than rewriting.
 ****************************************************************************)
175
local
  (* find_match u t: a higher-order match of pattern u against t itself,
     or else (searched depth-first) against t's operator, operand, and
     abstraction body, in that order.  The first success wins.  Raises
     HOL_ERR "no match" when nothing matches. *)
  fun find_match u =
    let
      val hom = ho_match_term [] empty_tmset u
      fun find_mt t =
        let
          fun attempt [] = raise ERR "SUBST_MATCH" "no match"
            | attempt (step :: rest) = step () handle HOL_ERR _ => attempt rest
        in
          attempt [fn () => hom t,
                   fn () => find_mt (rator t),
                   fn () => find_mt (rand t),
                   fn () => find_mt (body t)]
        end
    in
      find_mt
    end
in
fun SUBST_MATCH eqth th =
  let
    val pattern = lhs (concl eqth)
    val target = concl th
    val instantiated = Drule.INST_TY_TERM (find_match pattern target) eqth
  in
    SUBS [instantiated] th
  end
end;
190
191
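(* -------------------------------------------------------------------------
 * HIGHER_REWRITE_CONV: a useful special case of higher-order matching.
 * Each theorem should have the form |- P pat = rhs, where P is typically a
 * universally quantified predicate variable (as in the example below).
 * The conversion finds a subterm of its input that matches one of the
 * patterns, abstracts that subterm out of the input to determine P, and
 * instantiates the corresponding theorem.
 * Taken directly from the GTT source code (Don Syme).
 *
 * Example (written in the older  a => b | c  conditional syntax and with
 * (--`...`--) quotations):
 *
 * val LOCAL_COND_ELIM_THM1 = prove
 *     ((--`!P:'a->bool. P(a => b | c) = (~a \/ P(b)) /\ (a \/ P(c))`--),
 *      GEN_TAC THEN COND_CASES_TAC THEN ASM_REWRITE_TAC[]);
 *
 * val conv = HIGHER_REWRITE_CONV[LOCAL_COND_ELIM_THM1];
 * val x = conv (--`(P:'a -> bool) (a => b | c)`--);
 * val x = conv (--`(a + (f x => 0 | n) + m) = 0`--) handle e => Raise e;
 * ------------------------------------------------------------------------- *)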
204
205
(* HIGHER_REWRITE_CONV ths: a conversion that rewrites with theorems of the
   form |- pred pat = rhs, determining pred by abstracting a matched
   subterm out of the input term. *)
val HIGHER_REWRITE_CONV =
  (* GINST th: replace every free variable of th's conclusion that is not
     free in a hypothesis by a fresh genvar of the same type. *)
  let fun GINST th =
      let val fvs = HOLset.listItems
                       (HOLset.difference(FVL[concl th]empty_tmset,
                                          hyp_frees th))
          val gvs = map (genvar o type_of) fvs
      in INST (map2 (curry op |->) fvs gvs) th
      end
  in fn ths =>
      (* Set-up phase, run once per list of theorems. *)
      let val thl = map (GINST o SPEC_ALL) ths
          val concs = map concl thl
          val lefts = map lhs concs
          (* Each left-hand side must be an application  pred pat;
             dest_comb fails otherwise. *)
          val (preds,pats) = unzip(map dest_comb lefts)
          (* NOTE(review): BETA_VAR pred conc presumably yields a conversion
             that beta-reduces applications of pred once pred has been
             instantiated with an abstraction; confirm against its
             definition. *)
          val beta_fns = map2 BETA_VAR preds concs
          (* Association list: pattern |-> (pred, (theorem, beta conv)). *)
          val ass_list = zip pats (zip preds (zip thl beta_fns))
          fun insert p = Ho_Net.enter ([],p,p)
          (* Net indexing the patterns themselves. *)
          val mnet = itlist insert pats Ho_Net.empty
          (* Patterns returned by the net for t that really ho-match t. *)
          fun look_fn t = mapfilter
                    (fn p => if can (ho_match_term [] empty_tmset p) t then p
                             else fail())
                    (lookup t mnet)
      in fn tm =>
          (* Candidate subterms: those matching some pattern and occurring
             free in tm. *)
          let val ts = find_terms
                        (fn t => not (null (look_fn t)) andalso free_in t tm) tm
              (* The head of the candidates after sorting them with free_in
                 as the ordering relation. *)
              val stm = Lib.trye hd (sort free_in ts)
              val pat = Lib.trye hd (look_fn stm)
              (* NOTE(review): tyin is bound but never used below; only the
                 type instantiation tyin0 is applied to the theorem. *)
              val (tmin,tyin) = ho_match_term [] empty_tmset pat stm
              val (pred,(th,beta_fn)) = op_assoc aconv pat ass_list
              (* Abstract every occurrence of stm out of tm with a fresh
                 variable gv; matching pred against this abstraction
                 determines how to instantiate pred. *)
              val gv = genvar(type_of stm)
              val abs = mk_abs(gv,subst[stm |-> gv] tm)
              val (tmin0,tyin0) = ho_match_term [] empty_tmset pred abs
          in CONV_RULE beta_fn (INST tmin (INST tmin0 (INST_TYPE tyin0 th)))
          end
      end
      (* NOTE: this handler wraps exceptions raised while building the
         conversion from ths.  Exceptions raised later, when the returned
         conversion is applied to a term, propagate unwrapped. *)
      handle e => raise (wrap_exn "Ho_Rewrite" "HIGHER_REWRITE_CONV" e)
  end;
242
(* strip t: the rewrite left-hand side(s) contributed by a theorem with
   conclusion t, used as match patterns by num_matches.
     - universal quantifiers are stripped;
     - conjunctions are split, each conjunct handled recursively;
     - a negation ~p contributes p (such a theorem rewrites p to F);
     - implication antecedents are discarded (conditional rewrites);
     - an equation contributes its left-hand side;
     - any other term contributes itself (it rewrites to T). *)
fun strip t =
    if is_forall t then t |> strip_forall |> #2 |> strip
    else if is_conj t then t |> strip_conj |> map strip |> List.concat
    (* Negation must be tested before implication: HOL's is_imp/dest_imp
       also accept ~p, read as p ==> F, which would make F the pattern. *)
    else if is_neg t then [dest_neg t]
    (* Peel one antecedent at a time, so that a negated consequent
       (a ==> ~p) reaches the negation case above instead of becoming F. *)
    else if is_imp t then t |> dest_imp |> #2 |> strip
    else if is_eq t then [lhs t]
    else [t]
249
(* num_matches th t: how many subterms of t (as returned by find_terms)
   higher-order match at least one of the patterns strip (concl th).
   The free variables of th's hypotheses, and the type variables occurring
   in them, are passed to ho_match_term as the ones to keep fixed. *)
fun num_matches th t =
    let
      val fixed_vars = hyp_frees th
      val fixed_tyvars =
          HOLset.listItems
            (HOLset.foldl
               (fn (v, acc) => HOLset.addList (acc, type_vars_in_term v))
               (HOLset.empty Type.compare)
               fixed_vars)
      val patterns = strip (concl th)
      fun hits_pattern subterm =
          List.exists
            (fn pat => can (ho_match_term fixed_tyvars fixed_vars pat) subterm)
            patterns
    in
      length (find_terms hits_pattern t)
    end
265
(* Theorem printer run with the "Unicode" trace set to 0, used for the
   error messages below. *)
val thm_to_string = trace ("Unicode", 0) Parse.thm_to_string

(* REQUIRE0_TAC th: a check_delta wrapper measuring, via count0, the number
   of subterms matching th's left-hand side(s).  Per its error message it
   fails when that LHS remains in the goal.  NOTE(review): the exact
   acceptance condition lives in check_delta/count0; confirm there. *)
fun REQUIRE0_TAC th =
    let
      val msg = String.concat ["LHS of ", thm_to_string th, " remains in goal"]
    in
      check_delta (ERR "REQUIRE0_TAC" msg) (count0 (num_matches th))
    end

(* REQUIRE_DECREASE_TAC th: as above, but using count_decreases.  Per its
   error message it fails when the number of matching LHSes did not
   decrease. *)
fun REQUIRE_DECREASE_TAC th =
    let
      val msg =
          String.concat ["LHSes from ", thm_to_string th, " didn't decrease"]
    in
      check_delta (ERR "REQUIRE_DECREASE_TAC" msg)
                  (count_decreases (num_matches th))
    end
276
277end (* Ho_Rewrite *)
278