1(* =====================================================================*)
2(* FILE: cond_rewr.sml                                                  *)
3(* AUTHOR: Wai Wong     DATE: 10 Feb 1993                               *)
4(* TRANSLATOR: Paul Curzon DATE: 27 May 1993                            *)
5(* CONDITIONAL REWRITING                                                *)
6(* ---------------------------------------------------------------------*)
7
8structure Cond_rewrite :> Cond_rewrite =
9struct
10
11open HolKernel boolLib Rsyntax;
12
13infix ## THEN THENL ORELSEC;
14
(* Build a HOL_ERR attributed to this structure, for the given function
   name and message. *)
fun COND_REWR_ERR {function = fname, message = msg} =
    HOL_ERR {origin_structure = "Cond_rewrite",
             origin_function = fname,
             message = msg};
19
(* Free variables of a term, in the reverse of the order produced by
   Term.free_vars. *)
fun frees tm = rev (Term.free_vars tm);

(* List difference and list intersection, comparing terms up to
   alpha-equivalence (aconv). *)
fun subtract l1 l2 = op_set_diff aconv l1 l2
fun intersect l1 l2 = op_intersect aconv l1 l2
24
(* Match pattern tm1 against tm2. Returns a singleton list holding the
   (term, type) substitution pair on success, or [] when match_term
   fails. Only HOL_ERR (match failure) means "no match"; any other
   exception, notably Interrupt, is propagated instead of being
   silently swallowed. *)
fun match_aa tm1 tm2 = [match_term tm1 tm2] handle HOL_ERR _ => [];
26
(* Decide whether a match result l (as produced by match_aa) is usable:
   it must hold exactly one match, that match must need no type
   instantiation, and the set of instantiated term variables must be
   either empty or exactly the set vs. *)
fun match_ok vs (l:((term,term) subst * (hol_type,hol_type) subst) list) =
  case l of
    [(tm_subst, ty_subst)] =>
      null ty_subst andalso
      let val bound = listset (map #redex tm_subst)
      in
        HOLset.isEmpty bound orelse HOLset.equal (listset vs, bound)
      end
  | _ => false;
39
(* Scan the assumption list for the first term that ante matches
   acceptably (see match_ok). On success, return a singleton list
   [(term_subst, asm)] pairing the term part of the match with the
   matching assumption; otherwise return [].
   Only the variables of fvars that are free in ante are checked by
   match_ok. That set does not depend on the assumption, so it is
   computed once rather than once per assumption. *)
fun match_aal fvars ante asml =
  let
    val relevant = intersect fvars (frees ante)
    fun scan [] = []
      | scan (asm :: rest) =
          let val ml = match_aa ante asm
          in
            if match_ok relevant ml
            then [(fst (hd ml), asm)]
            else scan rest
          end
  in
    scan asml
  end;
48
49
(* Lexicographic order on {redex, residue} bindings: compare redexes
   first, then residues. *)
fun rr_compare ({redex = l1, residue = r1}, {redex = l2, residue = r2}) =
  case Term.compare (l1, l2) of
    EQUAL => Term.compare (r1, r2)
  | ord => ord

(* Turn a term substitution into a HOLset of its bindings, so that two
   substitutions can be compared with subset tests. *)
fun substset binds =
  List.foldl (fn (b, acc) => HOLset.add (acc, b))
             (HOLset.empty rr_compare) binds
53
(* Walk the antecedents left to right, threading an accumulated term
   substitution ml and the list pl of antecedents that matched some
   assumption in asm (most recent first). For each antecedent that has
   an acceptable match (match_aal):
   - if the accumulated substitution is contained in the new one, the
     new one becomes the accumulated substitution;
   - else if the new one is contained in the accumulated one, the
     accumulated substitution is kept;
   in both cases the antecedent is recorded in pl. Any other match
   (neither contains the other) is treated as inconsistent and
   ignored, as is an antecedent with no match. *)
fun match_asm fvars asm (ml, pl) antes =
  let
    fun step (ante, (cur, used)) =
      let val curset = substset cur
      in
        case match_aal fvars ante asm of
          [] => (cur, used)
        | (found, _) :: _ =>
            let val foundset = substset found
            in
              if HOLset.isSubset (curset, foundset)
              then (found, ante :: used)
              else if HOLset.isSubset (foundset, curset)
              then (cur, ante :: used)
              else (cur, used)
            end
      end
  in
    List.foldl step (ml, pl) antes
  end;
73
(* ---------------------------------------------------------------------*)
(* MATCH_SUBS1 = -                                                      *)
(* : thm -> term list -> term list ->                                   *)
(*   ((term,term) subst * (hol_type,hol_type) subst) -> (term list # thm)*)
(* MATCH_SUBS1 th fvs asm m1 = [tm1,...], th'                           *)
(* INPUTS:                                                              *)
(*  th is the theorem whose instances are required.                     *)
(*  m1 is a (term, type) substitution pair indicating an instantiation  *)
(*   of th.                                                             *)
(*  asm is a list of terms from which instances of the antecedents of th*)
(*   are found.                                                         *)
(*  fvs is a set of variables which do not appear in the lhs of the     *)
(*   conclusion of th.                                                  *)
(* RETURNS:                                                             *)
(*  tmi's are instances of the antecedents of th which do not appear in *)
(*   the assumptions asm.                                               *)
(*  th' is an instance of th resulting from substituting the matches in *)
(*   m1 and the matches newly found in the assumptions asm.             *)
(* ---------------------------------------------------------------------*)
92
(* Walk the outermost universal quantifiers of th. Every bound variable
   v that is in fv is collected into the returned list. If v is also in
   sv, it is first alpha-renamed (GEN_ALPHA_CONV) to a variant whose name
   avoids every variable of fv @ sv, and the new name is collected
   instead. Bound variables not in fv are left alone and not collected.
   The collected variables come outermost first, followed by newvs.
   Returns (collected_vars, rebuilt_theorem). *)
fun var_cap th fv sv newvs =
  if is_forall (concl th) then
    let
      val v = #Bvar (dest_forall (concl th))
      val (collected, inner) = var_cap (SPEC v th) fv sv newvs
      val regen = GEN v inner
    in
      if not (tmem v fv) then (collected, regen)
      else if tmem v sv then
        let val fresh = variant (fv @ sv) v
        in (fresh :: collected, CONV_RULE (GEN_ALPHA_CONV fresh) regen) end
      else (v :: collected, regen)
    end
  else (newvs, th);
109
(* Instantiate th with the match m1, then (when fvs is non-empty) try to
   instantiate the remaining variables of fvs by matching the
   antecedents against the assumption list asm.
   Returns (sgl, th') where sgl lists the instantiated antecedents not
   already present in asm, and th' is the instantiated theorem with all
   antecedents moved into the hypotheses (UNDISCH_ALL). *)
fun MATCH_SUBS1 th fvs asm (m1:((term,term)subst * (hol_type,hol_type)subst)) =
    let
      val (tm_inst, ty_inst) = m1
      (* fvs was computed from the uninstantiated theorem, but var_cap
         compares it with the bound variables of the type-instantiated
         theorem thm1. Apply the type part of m1 to fvs first; otherwise,
         for a polymorphic theorem, no y-variable is ever recognised and
         no antecedent is ever matched against the assumptions. *)
      val fvs' = map (inst ty_inst) fvs
      (* Keep the instantiated antecedent a' of each pair (a, a') whose
         uninstantiated form a mentions at least one variable of vs. *)
      fun afilter l vs =
          List.mapPartial
            (fn (a, a') =>
                if null (intersect vs (frees a)) then NONE else SOME a') l
      (* Variables occurring in the matched subterms of the goal; bound
         variables of th must not clash with these. *)
      val subfrees = flatten (map (frees o #residue) tm_inst)
      val thm1 = INST_TYPE ty_inst th
      val (nv, thm1') = var_cap thm1 fvs' subfrees []
      val thm2 = INST tm_inst (SPEC_ALL thm1')
      val antes = fst (strip_imp (snd (strip_forall (concl thm1'))))
      val antes' = fst (strip_imp (concl thm2))
      (* Antecedent instances that are not already among the assumptions. *)
      fun missing tms = subtract tms (intersect tms asm)
    in
      if null fvs
      then (* no free variables: nothing left to instantiate *)
        (missing antes', UNDISCH_ALL thm2)
      else
        let
          val (extra_inst, _) =
              match_asm nv asm ([], []) (afilter (combine (antes, antes')) nv)
          val thm2a = INST extra_inst thm2
          val (new_antes, _) = strip_imp (concl thm2a)
        in
          (missing new_antes, UNDISCH_ALL thm2a)
        end
    end;
138
(* Apply MATCH_SUBS1 to every match in mlist. Returns a pair:
   - either [] (no outstanding antecedents) or a one-element list holding
     the conjunction of all outstanding antecedent instances, with
     duplicates removed up to alpha-equivalence;
   - the list of instantiated theorems, one per match. *)
fun MATCH_SUBS th fvs asm mlist =
  let
    val results = map (MATCH_SUBS1 th fvs asm) mlist
    val pending = flatten (map fst results)
    val instances = map snd results
    val goals =
        if null pending then []
        else [list_mk_conj (op_mk_set aconv pending)]
  in
    (goals, instances)
  end;
145
(* --------------------------------------------------------------------- *)
(* COND_REWR_TAC = -                                                     *)
(* : ((term -> term -> ((term # term) list # (type # type) list) list) ->*)
(*   thm_tactic)                                                         *)
(* COND_REWR_TAC fn th                                                   *)
(* fn is a search function which returns a list of matches in the format *)
(* of the list returned by the system function match.                    *)
(* th is the theorem used for rewriting, which should be in the following*)
(* form: |- !x1...xn y1..ym. P1[xi yj] ==> ...==> Pl[xi,yj] ==>          *)
(*   (Q[x1...xn] = R[xi yj])                                             *)
(* The variables x's appear in the left-hand side of the conclusion, and *)
(* the variables y's do not appear in the left-hand side of the          *)
(* conclusion.                                                           *)
(* This tactic uses the search function fn to find any matches of Q in   *)
(* the goal. It fails if no match is found; otherwise it:                *)
(* 1) instantiates the input theorem (using both the matched types and   *)
(*    the matched terms),                                                *)
(* 2) searches the assumptions to find any matches of the antecedents,   *)
(* 3) puts up any antecedents that have no match in the assumptions as   *)
(*    a subgoal, and also adds them to the assumption list of the        *)
(*    original goal,                                                     *)
(* 4) substitutes these instances into the goal.                         *)
(* Complications: if {yi} is not empty, step 2 will try to find instances*)
(* of Pk and instantiate the y's using these matches.                    *)
(* --------------------------------------------------------------------- *)
170
(* Conditional rewriting tactic; see the description above. Any HOL_ERR
   raised while building or applying the rewrite is re-raised as a
   COND_REWR_TAC error carrying the same message. *)
fun COND_REWR_TAC f th =
  (fn (asm, gl) =>
    (let
       val (vars, body) = strip_forall (concl th)
       val (_, eqn) = strip_imp body
       val {lhs = pattern, ...} = dest_eq eqn
       (* quantified variables that do not occur in the lhs *)
       val ys = subtract vars (frees pattern)
       val matches = f pattern gl
       val _ = if null matches
               then raise COND_REWR_ERR{function="COND_REWR_TAC", message="no match"}
               else ()
       val (side_goals, rewrites) = MATCH_SUBS th ys asm matches
     in
       case side_goals of
         [] => SUBST_TAC rewrites (asm, gl)
       | side :: _ =>
           (SUBGOAL_THEN side STRIP_ASSUME_TAC
              THENL [REPEAT CONJ_TAC, SUBST_TAC rewrites]) (asm, gl)
     end
     handle HOL_ERR {message = message,...} =>
       raise COND_REWR_ERR {function="COND_REWR_TAC", message=message}));
191
192(* ---------------------------------------------------------------------*)
193(* search_top_down carries out a top-down left-to-right search for      *)
194(* matches of the term tm1 in the term tm2.                             *)
195(* ---------------------------------------------------------------------*)
196
(* Top-down, left-to-right search for matches of pattern tm1 in tm2.
   If tm1 matches tm2 itself, the single match is returned and subterms
   are not searched. Otherwise the search descends into both sides of
   an application (rator first) or into the body of an abstraction.
   Returns the list of match_term results found; [] if there are none.
   Only HOL_ERR (a failed match) triggers the descent; other exceptions,
   notably Interrupt, are propagated rather than swallowed. *)
fun search_top_down tm1 tm2 =
    [match_term tm1 tm2]
    handle HOL_ERR _ =>
      if is_comb tm2 then
        let val {Rator = rator, Rand = rand} = dest_comb tm2
        in
          search_top_down tm1 rator @ search_top_down tm1 rand
        end
      else if is_abs tm2 then
        search_top_down tm1 (#Body (dest_abs tm2))
      else [];
210
(*---------------------------------------------------------------*)
(* COND_REWR_CANON: thm -> thm                                   *)
(* Converts a theorem to a canonical form for conditional        *)
(* rewriting. The input theorem should be in the following form: *)
(* !x1 ... xn. P1[xi] ==> ... ==> !y1 ... ym. Pr[xi,yi] ==>      *)
(* (!z1 ... zk. u[xi,yi,zi] = v[xi,yi,zi])                       *)
(* The output theorem is in the form accepted by COND_REWR_TAC   *)
(* and COND_REWR_CONV.                                           *)
(* It first moves all universal quantifications to the outermost *)
(* level (possibly doing alpha conversion to avoid making free   *)
(* variables become bound), then converts any Pj which is itself *)
(* a conjunction using ANTE_CONJ_CONV repeatedly.                *)
(*---------------------------------------------------------------*)
223
(* Canonicalise th for conditional rewriting: pull universal quantifiers
   out past implications, then split conjunctive antecedents with
   ANTE_CONJ_CONV, and finally re-quantify over the outer variables. *)
fun COND_REWR_CANON th =
  let
    (* Try conv at tm; if it fails and tm is an implication, try again on
       its consequent, recursively. *)
    fun in_consequents conv tm =
      let val descend =
            if is_imp tm then RAND_CONV (in_consequents conv) else NO_CONV
      in
        (conv ORELSEC descend) tm
      end
    val split_antes = CONV_RULE (REPEATC (in_consequents ANTE_CONJ_CONV))
    val pulled = CONV_RULE (TOP_DEPTH_CONV RIGHT_IMP_FORALL_CONV) th
    val (outer_vars, _) = strip_forall (concl pulled)
  in
    GENL outer_vars (split_antes (SPEC_ALL pulled))
  end;
234
235(*---------------------------------------------------------------*)
236(* COND_REWRITE1_TAC : thm_tactic                                *)
237(*---------------------------------------------------------------*)
238
(* Canonicalise the theorem, then rewrite conditionally with it, using a
   top-down search for matches of its lhs in the goal. *)
val COND_REWRITE1_TAC : thm_tactic =
    COND_REWR_TAC search_top_down o COND_REWR_CANON;
241
(* --------------------------------------------------------------------- *)
(* COND_REWR_CONV =                                                      *)
(* : ((term -> term -> ((term # term) list # (type # type) list) list) ->*)
(*   thm -> conv)                                                        *)
(* COND_REWR_CONV fn th tm                                               *)
(* th is a theorem in the usual format for conditional rewriting.        *)
(* tm is a term on which conditional rewriting is performed.             *)
(* fn is a function attempting to match the lhs of the conclusion of th  *)
(* to the input term tm or any subterm(s) of it. The list returned by    *)
(* this function is used to instantiate the input theorem. The resulting *)
(* instance(s) are used in a REWRITE_CONV to produce the final theorem.  *)
(* --------------------------------------------------------------------- *)
254
(* Conditional rewriting conversion; see the description above. Each
   match returned by f instantiates SPEC_ALL th; the antecedents of each
   instance are undischarged into its hypotheses before the instances
   are handed to REWRITE_CONV. Any HOL_ERR is re-raised as a
   COND_REWR_CONV error with the same message. *)
fun COND_REWR_CONV f th = (fn tm =>
    let
      val (_, body) = strip_forall (concl th)
      val pattern = lhs (snd (strip_imp body))
      val matches = f pattern tm
    in
      case matches of
        [] => raise COND_REWR_ERR{function="COND_REWR_CONV", message="no match"}
      | _ =>
          let
            val spec = SPEC_ALL th
            fun instantiate m = UNDISCH_ALL (INST_TY_TERM m spec)
          in
            REWRITE_CONV (map instantiate matches) tm
          end
    end
    handle HOL_ERR {message = message,...} =>
      raise COND_REWR_ERR {function="COND_REWR_CONV", message=message}):conv;
271
272
273(* ---------------------------------------------------------------------*)
274(* COND_REWRITE1_CONV : thm list -> thm -> conv                         *)
275(* COND_REWRITE1_CONV thms thm tm                                       *)
276(* ---------------------------------------------------------------------*)
(* Canonicalise th, rewrite tm conditionally with it, and then discharge
   hypotheses of the result using the supplied theorems via PROVE_HYP. *)
fun COND_REWRITE1_CONV thms th = (fn tm =>
    let
      val canon = COND_REWR_CANON th
      val rewritten = COND_REWR_CONV search_top_down canon tm
    in
      List.foldr (fn (t, acc) => PROVE_HYP t acc) rewritten thms
    end):conv;
282
283end (* structure Cond_rewrite *)
284