1(*----------------------------------------------------------------------------
2 * Rewriting splits into two parts:
3 *
4 *    1. Rewriting a subterm (M) by a set of rewrite rules. Conceptually,
5 *       we choose the first rewrite rule that matches M
6 *
7 *           R = |- lhs = rhs
8 *
9 *       from the set, and instantiate to get
10 *
11 *           R' = |- M = rhs'.
12 *
13 *    2. Traversing the term. For a contextual rewriter, like this one, this
14 *       involves adding new context at each node that introduces context
15 *       (like a conditional statement).
16 *--------------------------------------------------------------------------*)
17
18structure RW :> RW =
19struct
20
21open HolKernel Parse boolLib pairLib;
22
val RW_ERR = mk_HOL_ERR "RW";

(* Current verbosity of the rewriter's tracing output (0 = silent). It is
   registered just below as the user-settable trace
   "TFL rewrite monitoring", whose maximum value is 20. *)
val monitoring = ref 0

val _ = register_trace ("TFL rewrite monitoring", monitoring, 20);

(* lztrace (level, fname, msgf): when the monitoring level is at least
   level, print "RW.<fname>: <msg>" and a newline, where msg = msgf ().
   The message is passed as a thunk, so it is only built when it will
   actually be printed. *)
fun lztrace (level, fname, msgf) =
    if !monitoring < level then ()
    else Lib.say (String.concat ["RW.", fname, ": ", msgf (), "\n"])
32
33
(* Fix the grammar used by this file: remember the grammars that were in
   force when the file was loaded, then parse everything below with
   boolTheory's grammars. *)
val ambient_grammars = Parse.current_grammars ();
val _ = Parse.temp_set_grammars boolTheory.bool_grammars
37
38
39(*----------------------------------------------------------------------------
40 * |- !x y z. w   --->  |- w[x|->g1][y|->g2][z|->g3]
41 * This belongs in Drule.sml.
42 *---------------------------------------------------------------------------*)
43
(* GSPEC_ALL th: repeatedly specialise the outermost universal quantifier of
   th's conclusion to a fresh genvar of the bound variable's type. If the
   conclusion is not a universal quantification, or any HOL_ERR arises, th
   is returned unchanged. *)
fun GSPEC_ALL th =
  let
    val {Name, Thy, Ty} = dest_thy_const (rator (concl th))
  in
    if Name = "!" andalso Thy = "bool" then
      let
        (* Ty is (bty -> bool) -> bool; we need bty *)
        val (pred_ty, _) = dom_rng Ty
        val (bvar_ty, _) = dom_rng pred_ty
      in
        GSPEC_ALL (SPEC (genvar bvar_ty) th)
      end
    else th
  end
  handle HOL_ERR _ => th;
50
51
(* gvterm tm: rename every bound variable of tm, recursively, to a fresh
   genvar. The result is alpha-equivalent to tm. *)
fun gvterm tm =
  case dest_term tm of
    LAMB (bv, bdy) =>
      let val fresh = genvar (type_of bv)
      in mk_abs (fresh, gvterm (subst [bv |-> fresh] bdy))
      end
  | COMB (opr, opd) => mk_comb (gvterm opr, gvterm opd)
  | _ => tm;
61
(* GENVAR_THM th: the same theorem, with every bound variable of its
   conclusion renamed to a fresh genvar (see gvterm). *)
fun GENVAR_THM th =
  let val c = concl th
  in EQ_MP (ALPHA c (gvterm c)) th
  end;
68
69 (*--------------------------------------------------------------------------
70  * Support for constructing rewrite rule sets. The following routines
71  * are attempts at providing "not too restrictive" checks for whether
72  * a rewrite will loop or not. These have been arrived at by trial and
  * error, and can certainly be improved!
74  * A couple of old versions follow.
75  *
76  * fun embedded_in tm =
77  *   let val head = #1(strip_comb tm)
78  *   in if is_var head then can (find_term (aconv head)) else fn _ => false
79  *   end;
80  *
81  * fun embedded_in tm =
82  *   let val head = #1(strip_comb tm)
83  *   in if is_var head then can (find_term (can (match_term tm)))
84  *                     else fn _ => false
85  *   end;
86  *--------------------------------------------------------------------------*)
87
88 fun alike head tm1 tm2 =
89  (#1 (strip_comb tm2) = head)
90  andalso
91  can (match_term tm1) tm2;
92
93 fun embedded1 tm =
94  let val head = #1(strip_comb tm)
95  in if is_var head then alike head tm else K false
96  end;
97
98 (* For changing the notion of a looping rewrite. *)
99
100 val embedded_ref = ref embedded1;
101
102
103 (*---------------------------------------------------------------------------
104  * I could check that the lhs is not embedded in the rhs, but that wouldn't
105  * allow me to unroll recursive functions.
106  *--------------------------------------------------------------------------*)
107 fun might_loop th =
108    let val (ants,(lhs,rhs)) = (I##dest_eq)(strip_imp_only(concl th))
109        val embedded_in = !embedded_ref
110        val islooper = (aconv lhs rhs) orelse (exists (embedded_in lhs) ants)
111    in if (islooper  andalso !monitoring > 0)
112       then Lib.say ("excluding possibly looping rewrite:\n"
113                     ^thm_to_string th^"\n\n")
114       else ();
115       islooper
116    end;
117
118(* ---------------------------------------------------------------------------
119 * Split a theorem into a list of theorems suitable for rewriting:
120 *
121 *   Apply the following transformations:
122 *
123 *        |t1 /\ t2|     -->    |t1| @ |t2|
124 *        |t1 ==> t2|    -->    (t1 |- |t2|)
125 *        |!x.tm|        -->    |{x |-> newvar}tm|
126 *
127 *   Bottom-out with |- t --> |- t = T and |- ~t --> |- t = F
128 *
129 *---------------------------------------------------------------------------*)
130 fun mk_simpls SPECer =
131  let val istrue = boolSyntax.T
132      fun mk_rewrs th =
133      let val tm = Thm.concl th
134      in  if (is_eq tm) then [th] else
135          if (is_neg tm) then [EQF_INTRO th] else
136          if (is_conj tm)
137          then (op @ o (mk_rewrs ## mk_rewrs) o Drule.CONJ_PAIR) th else
138          if (is_imp tm)
139          then let val ant = list_mk_conj (fst(strip_imp_only tm))
140                   fun step imp cnj =
141                       step (MP imp (CONJUNCT1 cnj)) (CONJUNCT2 cnj)
142                       handle HOL_ERR _ => MP imp cnj
143               in EQT_INTRO th
144                  ::map (DISCH ant) (mk_rewrs (step th (ASSUME ant)))
145               end else
146          if (is_forall tm) then mk_rewrs (SPECer th) else
147          if (tm = istrue) then [] else
148          [EQT_INTRO th]
149      end
150      handle HOL_ERR _ => raise RW_ERR "mk_simpls" ""
151  in
152    filter (not o might_loop) o mk_rewrs
153  end;
154
155 fun mk_simplsl SPECer = flatten o map (mk_simpls SPECer);
156
157 local val MK_FRESH = mk_simpls GSPEC_ALL        (* partly apply *)
158       val MK_READABLE = mk_simpls SPEC_ALL      (* partly apply *)
159 in
160 fun MK_RULES_APART th = MK_FRESH (GEN_ALL th)
161 and MK_RULES th = MK_READABLE (GEN_ALL th)
162 end;
163
164
(* Whether to add context to the simplification set as the term is
   traversed. *)
datatype context_policy
  = ADD
  | DONT_ADD;
167
168
(* An instantiated rewrite rule, tagged so that conditional rules (COND) can
   quickly be told apart from unconditional ones (UNCOND). *)
datatype choice
  = COND of thm
  | UNCOND of thm;

(* Forget the tag. *)
fun dest_choice c =
  case c of
    COND th => th
  | UNCOND th => th;
174
175
176(*----------------------------------------------------------------------------
177 * Takes a rewrite rule and applies it to a term, which, if it is an instance
178 * of the left-hand side of the rule, results in the return of the
179 * instantiated rule. Handles conditional rules.
180 *---------------------------------------------------------------------------*)
fun PRIM_RW_CONV th =
  let
    val (conds, eqn) = strip_imp_only (concl th)
    (* A rule with antecedents is conditional *)
    val tag = if null conds then UNCOND else COND
    val matcher = Term.match_term (lhs eqn)
  in
    fn tm =>
       let val (tm_theta, ty_theta) = matcher tm
       in tag (INST tm_theta (INST_TYPE ty_theta th))
       end
  end;
193
194
195(*----------------------------------------------------------------------------
196 * Match and instantiate a congruence rule. A congruence rule looks like
197 *
 *        (c1 ==> (M1 = M1')) /\ ... /\ (cn ==> (Mn = Mn'))
199 *       -------------------------------------------------
200 *                    f M1...Mn = f M1'...Mn'
201 *
202 * The ci do not have to be there, i.e., unconditional antecedents can
203 * certainly exist.
204 *---------------------------------------------------------------------------*)
205
(* CONGR th: a function that instantiates the congruence rule th so that the
   lhs of its concluding equation becomes the given term. Raises HOL_ERR if
   the term does not match. *)
fun CONGR th =
  let
    (* TODO: Check that it is a congruence rule *)
    val (_, eqn) = strip_imp_only (concl th)
    val matcher = Term.match_term (lhs eqn)
    fun instantiate tm =
      let val (tm_theta, ty_theta) = matcher tm
      in INST tm_theta (INST_TYPE ty_theta th)
      end
  in
    instantiate
  end;
217
218
(* A simplification set. thms and congs record the supplied theorem lists,
   newest batch first. rw_net and cong_net index the derived rewrite
   functions and congruence instantiators by the terms they apply to. *)
datatype simpls =
  RW of {thms     : thm list list,
         congs    : thm list list,
         rw_net   : (term -> choice) Net.net,
         cong_net : (term -> thm) Net.net};

val empty_simpls =
    RW {thms = [[]], congs = [[]], rw_net = Net.empty, cong_net = Net.empty};

(* The recorded rewrites and congruences, flattened and then reversed. *)
fun dest_simpls (RW {thms, congs, ...}) =
    {rws = rev (flatten thms), congs = rev (flatten congs)};
231
232
(* add_rws ss thl: ss extended with the rewrite rules derived from thl (via
   MK_RULES_APART), each indexed by the lhs of its concluding equation. *)
fun add_rws (RW {thms, rw_net, congs, cong_net}) thl =
  let
    fun entry th = (lhs (#2 (strip_imp_only (concl th))), PRIM_RW_CONV th)
    val entries = map entry (flatten (map MK_RULES_APART thl))
  in
    RW {thms = thl :: thms,
        congs = congs,
        cong_net = cong_net,
        rw_net = itlist Net.insert entries rw_net}
  end
  handle HOL_ERR _ => raise RW_ERR "add_rws" "Unable to deal with input";
243
244
(* add_congs ss thl: ss extended with the congruence rules thl. Each rule
   is closed, specialised to fresh genvars, and indexed by the lhs of the
   equation it concludes. *)
fun add_congs (RW {cong_net, congs, thms, rw_net}) thl =
  let
    fun entry th =
      let
        val c = concl th
        val eqn = snd (dest_imp c) handle HOL_ERR _ => c
      in
        (lhs eqn, CONGR th)
      end
    val rules = map (GSPEC_ALL o GEN_ALL) thl
  in
    RW {thms = thms,
        rw_net = rw_net,
        congs = thl :: congs,
        cong_net = itlist Net.insert (map entry rules) cong_net}
  end
  handle HOL_ERR _ =>
    raise RW_ERR "add_congs" "Unable to deal with input"
258
259
260(*----------------------------------------------------------------------------
261 * In RW_STEP, we find the list of matching rewrites, and choose the first
262 * one that succeeds. Conditional rules succeed if they can solve their
263 * antecedent by applying the prover (it gets to use the context and the
264 * supplied simplifications).
265 * Note.
266 * "ant_vars_fixed" is true when the instantiated rewrite rule has no
267 * uninstantiated variables in its antecedent. If "ant_vars_fixed" is not
268 * true, we get the instantiation from the context.
269 *
270 * Note.
271 * "sys_var" could be more rigorous in its check, but we don't
272 * have a defined notion of the syntax of system variables.
273 *---------------------------------------------------------------------------*)
274
(* stringulate f xs: map f over xs, putting ",\n" between consecutive
   results. *)
fun stringulate f xs =
  case xs of
    [] => []
  | [x] => [f x]
  | x :: rest => f x :: ",\n" :: stringulate f rest;
278
(* Keep the contents of the SOME elements of a list, in order. *)
val drop_opt = List.mapPartial (fn x => x)
280
local
  (* Recognises system-generated variables (e.g. genvars) by their names not
     being legal identifiers. *)
  fun sys_var tm = is_var tm andalso
                   not (Lexis.ok_identifier (fst (dest_var tm)))
  val failed = RW_ERR "RW_STEP" "all applications failed"
in
(*---------------------------------------------------------------------------
 * RW_STEP {context, prover, simpls} tm: rewrite tm with the first candidate
 * from simpls' rewrite net that matches tm. Conditional candidates apply
 * only if prover establishes their antecedent. Raises HOL_ERR (RW.RW_STEP)
 * if no candidate applies.
 *
 * When monitoring is on, the chosen theorem is reported, together with any
 * other candidates that would also have applied. Theorems are printed with
 * the "assumptions" trace at 1; Feedback.trace restores the previous value
 * afterwards, even if printing raises.
 *---------------------------------------------------------------------------*)
fun RW_STEP {context=(cntxt,_),prover,simpls as RW{rw_net,...}} tm =
  let
    (* Apply one candidate f to tm; NONE if it fails. *)
    fun match f =
        (case f tm of
           UNCOND th => SOME th
         | COND th =>
             let
               val condition = fst (dest_imp (concl th))
               val cond_thm = prover simpls cntxt condition
               (* If the antecedent still contains system variables, get the
                  instantiation by matching against cond_thm. *)
               val ant_vars_fixed = not (can (find_term sys_var) condition)
             in
               SOME ((if ant_vars_fixed then MP else MATCH_MP) th cond_thm)
             end)
        handle HOL_ERR _ => NONE
    val thm_str = trace ("assumptions", 1) Parse.thm_to_string
    fun report th others =
        let
          val msg =
              case others of
                [] => "RW_STEP:\n" ^ thm_str th
              | _ => "RW_STEP: multiple rewrites possible (first taken):\n" ^
                     String.concat (stringulate thm_str (th :: others))
        in
          Lib.say (msg ^ "\n")
        end
    fun try [] = raise failed
      | try (f :: rst) =
          (case match f of
             NONE => try rst
           | SOME th =>
               (if !monitoring > 0
                then report th (drop_opt (map match rst))
                else ();
                th))
  in
    try (Net.match tm rw_net)
  end
end
319
320(*---------------------------------------------------------------------------
 * Having more than one applicable congruence rule for a constant is
 * presumably a mistake, but this is not currently checked: CONG_STEP just
 * uses the first rule that the net returns.
323 *---------------------------------------------------------------------------*)
324
(* CONG_STEP ss tm: instantiate to tm the first congruence rule that ss's
   congruence net returns for tm. Lib.trye turns an empty candidate list
   into a HOL_ERR. *)
fun CONG_STEP (RW {cong_net, ...}) tm =
  let val candidates = Net.match tm cong_net
  in Lib.trye hd candidates tm
  end;
326
327(*----------------------------------------------------------------------------
328 *                          Prettyprinting
329 *---------------------------------------------------------------------------*)
330
local open Portable PP
in
(* Prettyprint a simplification set. Output: its rewrite rules (re-derived
   with mk_simplsl SPEC_ALL) and their count, then any congruence rules and
   their count. *)
fun pp_simpls (RW{thms,congs,...}) =
  let
    val pp_thm = Parse.pp_thm
    val rules = mk_simplsl SPEC_ALL (rev (flatten thms))
    val cong_rules = rev (flatten congs)
    val num_rules = length rules
    val num_congs = length cong_rules
    val sep = [add_string ";", add_break (2, 0)]
    val rules_part =
        if num_rules = 0 then add_string "<empty simplification set>"
        else block CONSISTENT 0
               [add_string "Rewrite Rules:", NL,
                add_string "--------------", NL,
                block PP.INCONSISTENT 0 (pr_list pp_thm sep rules)]
    val congs_part =
        if num_congs = 0 then []
        else [NL,
              add_string "Congruence Rules", NL,
              add_string "----------------", NL,
              block CONSISTENT 0 (pr_list pp_thm sep cong_rules), NL,
              add_string ("Number of congruence rules = "
                          ^ Lib.int_to_string num_congs), NL]
  in
    block PP.CONSISTENT 0
      [rules_part,
       NL,
       add_string ("Number of rewrite rules = " ^ Lib.int_to_string num_rules),
       NL,
       block CONSISTENT 0 congs_part]
  end
end;
364
(* join_simpls s1 s2: s2 extended with the rewrites and congruences of s1. *)
fun join_simpls s1 s2 =
  let
    val {rws, congs} = dest_simpls s1
    val with_rws = add_rws s2 rws
  in
    add_congs with_rws congs
  end;
369
370 (* end implementation of simpls type *)
371
(* The standard simplification set: basic boolean clauses, plus the fact
   that every value equals some value (with the equation either way round). *)
val std_simpls =
  let
    val bool_clauses =
        [boolTheory.REFL_CLAUSE,
         boolTheory.EQ_CLAUSES,
         boolTheory.NOT_CLAUSES,
         boolTheory.AND_CLAUSES,
         boolTheory.OR_CLAUSES,
         boolTheory.IMP_CLAUSES,
         boolTheory.COND_CLAUSES,
         boolTheory.FORALL_SIMP,
         boolTheory.EXISTS_SIMP,
         boolTheory.ABS_SIMP]
    val exists_eq_thm =
        Q.prove(`(!x:'a. ?y. x = y) /\ !x:'a. ?y. y = x`,
          CONJ_TAC THEN GEN_TAC THEN EXISTS_TAC(Term`x:'a`) THEN REFL_TAC)
  in
    add_rws empty_simpls (bool_clauses @ [exists_eq_thm])
  end;
386
387(*----------------------------------------------------------------------------
388 *
389 *                             TERM TRAVERSAL
390 *
391 *---------------------------------------------------------------------------*)
392
(* Raised by the "Q" conversions to signal that a term was left unchanged,
   so that no reflexivity theorem needs to be built. *)
exception UNCHANGED;

(* QCONV cnv: turn a Q-conversion into one that returns |- tm = tm instead
   of raising UNCHANGED. *)
fun QCONV cnv cp tm =
  (cnv cp tm) handle UNCHANGED => REFL tm;

(* The Q-conversion that never changes anything. *)
fun ALL_QCONV _ = raise UNCHANGED;
398
(* THENQC cnv1 cnv2: apply cnv1 and then cnv2 to the result. UNCHANGED from
   either side is absorbed as long as the other side made a change. *)
fun THENQC cnv1 cnv2 cp tm =
  case (SOME (cnv1 cp tm) handle UNCHANGED => NONE) of
    NONE => cnv2 cp tm
  | SOME th1 =>
      (TRANS th1 (cnv2 cp (rhs (concl th1))) handle UNCHANGED => th1);
404
(* ORELSEQC cnv1 cnv2: cnv1, falling back to cnv2 only when cnv1 fails with
   HOL_ERR. UNCHANGED from cnv1 is passed on. *)
fun ORELSEQC cnv1 cnv2 cp tm =
  (cnv1 cp tm)
  handle e as UNCHANGED => raise e
       | HOL_ERR _ => cnv2 cp tm;

(* REPEATQC conv: apply conv repeatedly until it fails (HOL_ERR). *)
fun REPEATQC conv cp tm =
  let val step_then_repeat = THENQC conv (REPEATQC conv)
  in ORELSEQC step_then_repeat ALL_QCONV cp tm
  end;
411
local
  val CHANGED_QRW_ERR = RW_ERR "CHANGED_QRW" ""
in
(* CHANGED_QCONV cnv: like cnv, but fails with HOL_ERR when cnv raises
   UNCHANGED or returns an equation whose sides are alpha-equivalent. *)
fun CHANGED_QCONV cnv cp tm =
  let
    val th = (cnv cp tm) handle UNCHANGED => raise CHANGED_QRW_ERR
    val (l, r) = dest_eq (concl th)
  in
    if not (aconv l r) then th else raise CHANGED_QRW_ERR
  end
end;

(* TRY_QCONV cnv: like cnv, but a HOL_ERR failure becomes UNCHANGED. *)
fun TRY_QCONV cnv cp tm = ORELSEQC cnv ALL_QCONV cp tm;
422
(* Outcome of processing a congruence-rule antecedent: CHANGE when
   something was rewritten (or assumed), NO_CHANGE otherwise. *)
datatype delta
  = CHANGE of thm
  | NO_CHANGE of thm;

fun unchanged d =
  case d of
    NO_CHANGE _ => true
  | CHANGE _ => false;
426
427
428(*---------------------------------------------------------------------------
429 * And now, a whole bunch of support for rewriting with congruence rules.
430 *---------------------------------------------------------------------------*)
431
(* variants away0 vlist: rename each variable of vlist, left to right, to a
   variant avoiding away0 and the variants already chosen. The result is in
   vlist's order. *)
fun variants away0 vlist =
  let
    fun step (v, (acc, away)) =
      let val v' = variant away v
      in (v' :: acc, v' :: away)
      end
    val (rev_acc, _) = List.foldl step ([], away0) vlist
  in
    rev rev_acc
  end;
436
(* variant_theta away0 vlist: process vlist left to right, renaming each
   variable that clashes with the (growing) avoid list. Returns the renaming
   substitution, for changed variables only, together with the final avoid
   list. *)
fun variant_theta away0 vlist =
  let
    fun step (v, (theta, away)) =
      let val v' = variant away v
      in if v = v' then (theta, away)
         else ((v |-> v') :: theta, v' :: away)
      end
  in
    List.foldl step ([], away0) vlist
  end;
442
443(*---------------------------------------------------------------------------
 * Takes a list of variables to avoid and a list of varstructs (variables or
 * tuples of variables). Any free variable of the varstructs that is in the
 * avoid list is replaced throughout by a fresh variant. The resulting
 * varstructs are returned.
447 *---------------------------------------------------------------------------*)
448
(* vstrl_variants away0 vstrl: rename, throughout the varstructs vstrl, every
   free variable that also occurs in away0. The new names avoid both away0
   and all free variables of vstrl, so no fresh clash is introduced. Returns
   vstrl itself when there is no clash. *)
fun vstrl_variants away0 vstrl =
  let
    val fvl = free_varsl vstrl
    val clashes = op_intersect aconv away0 fvl
  in
    if null clashes then vstrl
    else
      let
        (* Shares the renaming logic of variant_theta instead of
           duplicating it. *)
        val (theta, _) = variant_theta (op_union aconv away0 fvl) clashes
      in
        map (subst theta) vstrl
      end
  end;
462
(* thml_fvs thl: the free variables of all assumptions and conclusions of
   thl, combined with Lib.op_U under aconv. *)
fun thml_fvs thl =
  let
    fun thm_fvs th =
      let val (asl, c) = dest_thm th
      in free_varsl (c :: asl)
      end
  in
    Lib.op_U aconv (map thm_fvs thl)
  end;
467
(* dest_combn tm n: peel n arguments off tm. Returns the remaining operator
   and the arguments, last argument first. Raises HOL_ERR (from dest_comb)
   if tm has fewer than n arguments. *)
fun dest_combn tm n =
  if n = 0 then (tm, [])
  else
    let
      val (opr, arg) = dest_comb tm
      val (head, args) = dest_combn opr (n - 1)
    in
      (head, arg :: args)
    end;
474
(* add_cntxt policy ss thl: under ADD, add thl to ss as rewrites; under
   DONT_ADD, return ss unchanged. *)
fun add_cntxt policy =
  case policy of
    ADD => add_rws
  | DONT_ADD => Lib.K;
476
477(*---------------------------------------------------------------------------*)
478(* A congruence rule can have two kinds of antecedent: universally           *)
479(* quantified or unquantified. A bare antecedent is fairly simple to deal    *)
480(* with: it has the form                                                     *)
481(*                                                                           *)
482(*    conditions ==> lhs = ?rhs                                              *)
483(*                                                                           *)
484(* The following congruence rule has only bare antecedents:                  *)
485(*                                                                           *)
486(*    |- !P Q x x' y y'.                                                     *)
(*           (P ⇔ Q) /\                                                  UOK *)
(*           (Q ⇒ (x = x')) /\                                           UOK *)
(*           (¬Q ⇒ (y = y'))                                             UOK *)
490(*           ==>                                                             *)
491(*           ((if P then x else y) = if Q then x' else y')                   *)
492(*                                                                           *)
493(* A bare antecedent is processed by assuming the conditions and rewriting   *)
494(* the LHS of the equation. This yields the value for ?rhs.                  *)
495(*                                                                           *)
496(* A quantified antecedent is more troublesome, since it usually implies     *)
497(* that something higher-order is going on, and so beta-conversion and       *)
498(* paired-lambdas may need to be handled. The following is the congruence    *)
499(* rule for LET:                                                             *)
500(*                                                                           *)
501(*  |- (M = M') /\ (!x. (x = M') ==> (f x = g x)) ==> LET f M = LET g M'     *)
502(*                                                                           *)
503(* In the second antecedent, f is a function that may be a constant, a       *)
504(* lambda term, a paired lambda term, or even the application of a higher-   *)
505(* order function. In order to extract termination conditions from f x,      *)
506(* a (paired) beta reduction will be done, which can completely transform    *)
507(* the term structure. After extraction, a corresponding function g is       *)
508(* obtained from the rhs of the equality theorem returned. This may not look *)
509(* like a function, so there is a step of "un-beta-expansion" on g.          *)
510(*                                                                           *)
511(* Note.                                                                     *)
512(* When doing rewriting of quantified antecedents to congruence rules, as    *)
513(* in the one for "let" statements above, the temptation is there to only    *)
514(* rewrite (in context) f to g, and use MK_COMB to get f x = g x. (Assume    *)
515(* that f is a lambda term.) However, the free variables in the context      *)
516(* (i.e., x) map to bound variables in f and the attempt to abstract on the  *)
517(* way out of the rewrite will fail, or isolate the free variables. Dealing  *)
518(* with this causes some clutter in the code.                                *)
519(*---------------------------------------------------------------------------*)
520
(* no_change V L tm: the outcome NO_CHANGE (|- !V. L1 ==> ... ==> (tm = tm)). *)
fun no_change V L tm =
  let
    val refl_th = REFL tm
    val discharged = itlist DISCH L refl_th
  in
    NO_CHANGE (itlist GEN V discharged)
  end
523
(* Map f pairwise over two lists, stopping at the end of the shorter one. *)
fun map2_total f xs ys =
  case (xs, ys) of
    (x :: xs', y :: ys') => f x y :: map2_total f xs' ys'
  | _ => [];
526
(*---------------------------------------------------------------------------*)
(* try_cong cnv cps tm: rewrite tm using the first congruence rule in cps's  *)
(* simplification set that matches tm (see CONG_STEP). The conjuncts of the  *)
(* instantiated rule's antecedent are processed left to right: "simple"      *)
(* handles unquantified ones and "complex" universally quantified ones. For  *)
(* an equational conjunct, the lhs is rewritten with cnv, and the rhs        *)
(* variable found is substituted into the remaining conjuncts. The results   *)
(* are finally fed to the rule with MATCH_MP. Raises UNCHANGED if no         *)
(* conjunct changed; a HOL_ERR propagates if no congruence rule applies.     *)
(*---------------------------------------------------------------------------*)
fun try_cong cnv (cps as {context,prover,simpls}) tm =
 let
 (* Unquantified antecedent "L ==> (lhs = ?rhs)", or a non-equation, which
    is simply assumed. rst holds the conjuncts still to be processed. *)
 fun simple cnv (cps as {context as (cntxt,b),prover,simpls}) (ant,rst) =
   case total ((I##dest_eq) o strip_imp_only) ant
    of NONE (* Not an equality, so just assume *)
         => (CHANGE(ASSUME ant), rst)
     | SOME (L,(lhs,rhs))
        => let val outcome =
             if aconv lhs rhs then no_change [] L lhs
             else let (* rewrite lhs with the conditions L as extra context *)
                      val cps' = if null L then cps else
                          {context = (map ASSUME L @ cntxt,b),
                           prover  = prover,
                           simpls  = add_cntxt b simpls (map ASSUME L)}
                  in CHANGE(cnv cps' lhs)
                     handle HOL_ERR _ => no_change [] L lhs
                          | UNCHANGED => no_change [] L lhs
                  end
           in
             case outcome
              of CHANGE th =>
                   let val Mnew = boolSyntax.rhs(concl th)
                   in (CHANGE (itlist DISCH L th),
                       map (subst [rhs |-> Mnew]) rst)
                   end
               | NO_CHANGE _ => (outcome, map (subst [rhs |-> lhs]) rst)
           end

 (* Universally quantified antecedent "!vlist. L ==> (f args = g args)",
    where f may be a constant, a (paired) lambda term, or another function
    (cf. the discussion of the LET congruence rule earlier in this file). *)
 fun complex cnv (cps as {context as (cntxt,b),prover,simpls}) (ant,rst) =
  let val ant_frees = free_vars ant
    val context_frees = free_varsl (map concl cntxt)
    val (vlist,ceqn) = strip_forall ant
    val (_,eq) = strip_imp_only ceqn
    val (lhs,rhs) = dest_eq eq
    val nvars = length (snd (strip_comb rhs))  (* guessing ... *)
    val (f,args) = (I##rev) (dest_combn lhs nvars)
    val (rhsv,_) = dest_combn rhs nvars
    val vstrl = #1(strip_pabs f)
    (* rename f's bound varstructs away from the antecedent and context *)
    val vstructs = vstrl_variants (union ant_frees context_frees) vstrl
    (* if f is a (paired) abstraction, put its varstructs in place of the
       arguments it is applied to *)
    val ceqn' = if null vstrl then ceqn
                else subst (map (op|->) (zip args vstructs)) ceqn
    val (L,(lhs,rhs)) = (I##dest_eq) (strip_imp_only ceqn')
    val outcome =
       if aconv lhs rhs
        then no_change vlist L lhs
       else
       let (* beta-reduce (possibly paired) redexes in lhs before rewriting *)
           val lhs_beta_maybe = Conv.QCONV (Conv.DEPTH_CONV GEN_BETA_CONV) lhs
           val lhs' = boolSyntax.rhs(concl lhs_beta_maybe)
           val cps' = case L
                       of [] => cps
                        | otherwise =>
                            {context = (map ASSUME L @ cntxt,b),
                             prover = prover,
                             simpls  = add_cntxt b simpls (map ASSUME L)}
       in CHANGE(TRANS lhs_beta_maybe (cnv cps' lhs'))
          handle HOL_ERR _ =>
                  if aconv lhs lhs'
                    then no_change vlist L lhs
                    else CHANGE lhs_beta_maybe
               | UNCHANGED => if aconv lhs lhs'
                              then no_change vlist L lhs
                              else CHANGE lhs_beta_maybe
       end
  in
  case outcome
   of NO_CHANGE _ => (outcome, map (subst [rhsv |-> f]) rst)
    | CHANGE th =>
      let (*------------------------------------------------------------*)
          (* Function eta_rhs packages up the new rhs, eta-expanding it *)
          (* if need be, i.e. if the lhs is an application of a lambda  *)
          (* or paired-lambda term f. In that case, the extraction has  *)
          (* first done a beta-reduction and then extraction, so the    *)
          (* derived rhs needs to be "un-beta-expanded" in order that   *)
          (* the existential var on the rhs (g) be filled in with a     *)
          (* thing that has function syntax. This will allow the final  *)
          (* MATCH_MP icong ... to succeed.                             *)
          (*------------------------------------------------------------*)
         fun eta_rhs th =
           let val r = boolSyntax.rhs(concl th)
               val not_lambda_app = null vstrl
               val rcore = if not_lambda_app
                            then fst(dest_combn r nvars)
                            else r
               (* Only vstructs are abstracted over here. An earlier
                  variant also abstracted over the args beyond the first
                  (length vstructs), for the case where f binds fewer
                  varstructs than it is applied to (rcore then has a
                  function type). *)
               val g = list_mk_pabs(vstructs,rcore)
               val gvstructs = list_mk_comb(g,vstructs)
               val rhs_eq = if not_lambda_app then REFL gvstructs
                            else SYM(Conv.QCONV (DEPTH_CONV GEN_BETA_CONV) gvstructs)
               val th1 = TRANS th rhs_eq (* |- f vstructs = g vstructs *)
                         handle HOL_ERR _ => th
            in (g,th1)
            end
         val (g,th1) = eta_rhs th
         val th2 = itlist DISCH L th1
         val pairs = zip args vstructs handle HOL_ERR _ => []
         (* Generalize over vlist, using a paired generalization for each
            variable that was replaced by a varstruct. *)
         fun generalize v thm =
              case assoc1 v pairs
               of SOME (_,tup) => pairTools.PGEN v tup thm
                | NONE => GEN v thm
         val result = itlist generalize vlist th2
      in
        (CHANGE result, map (subst [rhsv |-> g]) rst)
      end
  end (* complex *)

  val icong = CONG_STEP simpls tm
  val ants = strip_conj (fst(dest_imp (concl icong)))
  (* loop processes each antecedent in turn and propagates
     instantiations to the remainder. *)
  fun loop [] = []
    | loop (ant::rst) =
      let val (outcome,rst') =
           if is_forall ant
             then complex cnv cps (ant,rst)
             else simple cnv cps (ant,rst)
      in
        outcome::loop rst'
      end
  val outcomes = loop ants
  fun mk_ant (NO_CHANGE th) = th
    | mk_ant (CHANGE th) = th
 in
   if Lib.all unchanged outcomes
     then raise UNCHANGED
     else MATCH_MP icong (LIST_CONJ (map mk_ant outcomes))
 end
669
670
(*---------------------------------------------------------------------------*
 * SUB_QCONV cnv cps tm: apply cnv to the immediate subterms of tm.          *
 *  - Combination: first try a congruence rule (try_cong); if that fails     *
 *    with HOL_ERR, rewrite the operator and the operand separately.         *
 *  - Abstraction: rewrite the body. If ABS then fails (e.g. because the     *
 *    bound variable is free in hypotheses coming from the context), redo    *
 *    the rewrite on an alpha-variant with a fresh genvar, and rename back   *
 *    to a variant of the original bound variable.                           *
 *  - Constants and variables: raise UNCHANGED.                              *
 *---------------------------------------------------------------------------*)
fun SUB_QCONV cnv cps tm =
  case dest_term tm of
    COMB (Rator, Rand) =>
      let
        fun piecewise () =
            let val rator_th = cnv cps Rator
            in MK_COMB (rator_th, cnv cps Rand)
               handle UNCHANGED => AP_THM rator_th Rand
            end
            handle UNCHANGED => AP_TERM Rator (cnv cps Rand)
      in
        try_cong cnv cps tm
        handle UNCHANGED => raise UNCHANGED
             | HOL_ERR _ => piecewise ()
      end
  | LAMB (Bvar, Body) =>
      let
        val Bth = cnv cps Body
        fun retry_with_genvar () =
            let
              val _ = lztrace(6, "SUB_QCONV",
                              trace ("assumptions", 1)
                              (fn () => "ABS failure: " ^
                                        ppstring pp_term Bvar ^ "  " ^
                                        ppstring pp_thm Bth))
              val gv = genvar (type_of Bvar)
              val alpha_th = ALPHA_CONV gv tm
              val body_th = cnv cps (body (rhs (concl alpha_th)))
              val _ = lztrace(6, "SUB_QCONV",
                              trace ("assumptions", 1)
                              (fn () => "ABS 2nd call: "^
                                        ppstring pp_thm body_th))
              val abs_th = ABS gv body_th
              val new_abs = snd (dest_eq (concl abs_th))
              val back_th = ALPHA_CONV (variant (free_vars new_abs) Bvar) new_abs
            in
              TRANS (TRANS alpha_th abs_th) back_th
            end
      in
        ABS Bvar Bth handle HOL_ERR _ => retry_with_genvar ()
      end
  | _ => raise UNCHANGED     (* Constants and variables *)
706
707
(* DEPTH_QCONV cnv: bottom-up. Rewrite the subterms first, then apply cnv
   repeatedly at the top. *)
fun DEPTH_QCONV cnv cps tm =
  let val below = SUB_QCONV (DEPTH_QCONV cnv)
  in THENQC below (REPEATQC cnv) cps tm
  end;

(* REDEPTH_QCONV cnv: bottom-up. Whenever cnv succeeds at a node, the result
   is traversed again. *)
fun REDEPTH_QCONV cnv cps tm =
  let
    val below = SUB_QCONV (REDEPTH_QCONV cnv)
    val again = ORELSEQC (THENQC cnv (REDEPTH_QCONV cnv)) ALL_QCONV
  in
    THENQC below again cps tm
  end;

(* TOP_DEPTH_QCONV cnv: top-down. Apply cnv repeatedly at the top, then
   descend. If the subterms change, try again from the top. *)
fun TOP_DEPTH_QCONV cnv cps tm =
  let
    val recur = TOP_DEPTH_QCONV cnv
    val at_top = REPEATQC cnv
    val then_below =
        TRY_QCONV (THENQC (CHANGED_QCONV (SUB_QCONV recur))
                          (TRY_QCONV (THENQC cnv recur)))
  in
    THENQC at_top then_below cps tm
  end;

(* ONCE_DEPTH_QCONV cnv: apply cnv at the outermost positions where it
   succeeds, without re-traversing the results. *)
fun ONCE_DEPTH_QCONV cnv cps tm =
  let val here_or_below = ORELSEQC cnv (SUB_QCONV (ONCE_DEPTH_QCONV cnv))
  in TRY_QCONV here_or_below cps tm
  end;
727
728
(* What is threaded through rewriting: the context (with its policy), the
   simplification set, and a prover for side conditions. *)
type cntxt_solver =
     {context : thm list * context_policy,
      simpls  : simpls,
      prover  : simpls -> thm list -> term -> thm};

(* A traversal strategy maps a node-level rewriter to a term-level one. *)
type strategy =
     (cntxt_solver -> term -> thm) -> (cntxt_solver -> term -> thm)
734
(* Strategy builders. Each wraps a Q-traversal with QCONV, so that an
   unchanged term yields |- tm = tm instead of raising UNCHANGED. *)

fun DEPTH cnv = QCONV (DEPTH_QCONV cnv);
fun REDEPTH cnv = QCONV (REDEPTH_QCONV cnv);
fun TOP_DEPTH cnv = QCONV (TOP_DEPTH_QCONV cnv);
fun ONCE_DEPTH cnv = QCONV (ONCE_DEPTH_QCONV cnv);
741
(* RAND f cntxt tm: apply f (with cntxt) to the operand of the combination
   tm. Any HOL_ERR, including one raised by f, is reported as RW.RAND. *)
fun RAND f cntxt tm =
   let val (opr, opd) = dest_comb tm
   in AP_TERM opr (f cntxt opd)
   end
   handle HOL_ERR _ => raise RW_ERR "RAND" ""

(* RATOR f cntxt tm: as RAND, but applied to the operator. *)
fun RATOR f cntxt tm =
   let val (opr, opd) = dest_comb tm
   in AP_THM (f cntxt opr) opd
   end
   handle HOL_ERR _ => raise RW_ERR "RATOR" ""

(* ABST f cntxt tm: apply f to the body of the abstraction tm. *)
fun ABST f cntxt tm =
   let val (bv, bdy) = dest_abs tm
   in ABS bv (f cntxt bdy)
   end
   handle HOL_ERR _ => raise RW_ERR "ABST" "";
759
760
761(*---------------------------------------------------------------------------*
762 * This is the basis for all the high-level rewriting entrypoints. Basically,*
763 * the simpls get computed and after that the traverser moves around the     *
764 * term and applies RW_STEP at nodes.                                        *
765 *---------------------------------------------------------------------------*)
766
fun RW_STEPS traverser (simpls, context, congs, prover) thl =
  let
    val with_rws = add_rws simpls thl
    val solver = {context = context, prover = prover,
                  simpls = add_congs with_rws congs}
  in
    traverser RW_STEP solver
  end;
772
773(*---------------------------------------------------------------------------*
774 * Define an implicit set of rewrites, so that common rewrite rules don't    *
775 * need to be constantly given by the user.                                  *
776 *---------------------------------------------------------------------------*)
777
local
  (* Mutable cell holding the implicit simplification set; it starts out as
     std_simpls and is only reachable through the two accessors below. *)
  val implicit_cell = ref std_simpls
in
  fun implicit_simpls () = !implicit_cell
  fun set_implicit_simpls rws = implicit_cell := rws
end
783
(* Each of these reads the current implicit simpls, extends it, and stores
   the result back as the new implicit simpls. *)
fun add_implicit_rws thl =
    set_implicit_simpls (add_rws (implicit_simpls ()) thl)

fun add_implicit_congs thl =
    set_implicit_simpls (add_congs (implicit_simpls ()) thl)

fun add_implicit_simpls s =
    set_implicit_simpls (join_simpls s (implicit_simpls ()))
790
(* Traversal choice: ONCE_DEPTH, TOP_DEPTH, or a caller-supplied strategy. *)
datatype repetitions = Once | Fully | Special of strategy;

(* Rule sources: Default = implicit simpls + thms, Pure = empty simpls + thms,
   Simpls = the given simpls + thms. *)
datatype rules = Default of thm list
               | Pure    of thm list
               | Simpls  of simpls * thm list

(* Single-constructor wrappers that tag the remaining control arguments. *)
datatype context = Context of thm list * context_policy
datatype congs   = Congs   of thm list
datatype solver  = Solver  of simpls -> thm list -> term -> thm;
804
805
806(* Term rewriting *)
807
808(*---------------------------------------------------------------------------
809 * The basic choices are in the traversal strategy and whether or not to use
810 * a default set of simplifications.
811 *---------------------------------------------------------------------------*)
812
(* Rewrite style (rules, Context, Congs, Solver) builds a conversion.
   The traversal is chosen from style:
     Once -> ONCE_DEPTH, Fully -> TOP_DEPTH, Special f -> f.
   The base simplification set is chosen from rules:
     Simpls(ss,_) -> ss, Default _ -> implicit_simpls(), Pure _ -> empty_simpls.
   The theorems carried by rules are added on top by RW_STEPS. *)
fun Rewrite style (rules, Context cntxt, Congs congs, Solver solver) =
 let val traverser =
         case style
          of Once      => ONCE_DEPTH
           | Fully     => TOP_DEPTH
           | Special f => f
     val (base, thl) =
         case rules
          of Simpls (ss, thms) => (ss, thms)
           | Default thms      => (implicit_simpls (), thms)
           | Pure thms         => (empty_simpls, thms)
 in
   RW_STEPS traverser (base, cntxt, congs, solver) thl
 end;
841
842
843
844(*---------------------------------------------------------------------------
845 * Theorem rewriting
846 *---------------------------------------------------------------------------*)
847
(* Rewrite the conclusion of a theorem with the conversion built by Rewrite. *)
fun REWRITE_RULE style controls =
  let val conv = Rewrite style controls
  in CONV_RULE conv
  end;
849
(* add_hyps asl controls: ASSUME every term in asl and append the resulting
   theorems both to the rule list and to the context list of controls.
   Congruences and solver are passed through untouched. *)
fun add_hyps asl =
 let val assumed = map ASSUME asl
     fun extend_rules (Simpls (ss, thl)) = Simpls (ss, thl @ assumed)
       | extend_rules (Pure thl)         = Pure (thl @ assumed)
       | extend_rules (Default thl)      = Default (thl @ assumed)
 in
   fn (rules, Context (L, p), c, s) =>
      (extend_rules rules, Context (L @ assumed, p), c, s)
 end
860
(* Like REWRITE_RULE, but the theorem's own hypotheses are added (via
   add_hyps) as extra rewrites and context. *)
fun ASM_REWRITE_RULE style controls th =
  let val controls' = add_hyps (hyp th) controls
  in REWRITE_RULE style controls' th
  end;
863
864
865(*---------------------------------------------------------------------------
866 * Goal rewriting
867 *---------------------------------------------------------------------------*)
868
(* Rewrite the goal with the conversion built by Rewrite. *)
fun REWRITE_TAC style controls =
  let val conv = Rewrite style controls
  in CONV_TAC conv
  end;
870
(* Like REWRITE_TAC, but the goal's assumptions are added (via add_hyps)
   as extra rewrites and context before rewriting. *)
fun ASM_REWRITE_TAC style controls =
  fn (goal as (asl, _)) => REWRITE_TAC style (add_hyps asl controls) goal;
873
874
875(*---------------------------------------------------------------------------
876 * Some solvers. One just does minor checking in the context; the other
877 * makes a recursive invocation of the rewriter.
878 *---------------------------------------------------------------------------*)
879
(* Common failure raised by the solvers when they cannot prove a term. *)
fun solver_err () = raise RW_ERR "solver error" "";

(* A solver that ignores its arguments and always fails. *)
fun always_fails _ _ _ = solver_err ();
882
883(*---------------------------------------------------------------------------
884 * Just checks the context to see if it can find an instance of "tm".
885 *---------------------------------------------------------------------------*)
886
(* std_solver _ context tm: the first argument (simpls) is ignored. Each
   theorem of TRUTH::context is tried in order. A theorem with conclusion F
   proves tm by CCONTR. A conclusion alpha-equivalent to tm is returned as-is.
   Otherwise the conclusion is matched against tm and the theorem
   instantiated. Fails with solver_err if no theorem works. *)
fun std_solver _ context tm =
 let fun trace msgf = if !monitoring > 0 then Lib.say (msgf ()) else ()
     (* Try to turn one context theorem into a proof of tm. *)
     fun attempt x =
         let val c = concl x
         in if c = boolSyntax.F then SOME (CCONTR tm x)
            else if aconv tm c then SOME x
            else SOME (INST_TY_TERM (Term.match_term c tm) x)
                 handle HOL_ERR _ => NONE
         end
     fun search [] = (trace (fn () => "Solver: couldn't find it.\n");
                      solver_err ())
       | search (x::rest) =
           (case attempt x
             of SOME th => th
              | NONE => search rest)
     val _ = trace (fn () => "Solver: trying to lookup in context\n"
                             ^ term_to_string tm ^ "\n")
     val thm = search (boolTheory.TRUTH::context)
 in
   trace (fn () => "Solver: found it.\n");
   thm
 end;
908
909
910(*---------------------------------------------------------------------------*
911 * Make a recursive invocation of rewriting. Can be magically useful, but    *
912 * also can loop. In which case, use the std_solver.                         *
913 *---------------------------------------------------------------------------*)
914
(* rw_solver simpls context tm: rewrite tm with TOP_DEPTH RW_STEP (policy ADD,
   rw_solver itself as the prover), obtaining th : |- tm = tm'. Then search
   TRUTH::context for a proof of tm' and transport it back to tm via
   EQ_MP (SYM th). Fails with solver_err if nothing in the context proves tm'.
   Fix: the F case must prove tm' (not tm). EQ_MP (SYM th) expects a theorem
   whose conclusion is tm'; the old CCONTR tm x failed whenever rewriting
   changed the term. *)
fun rw_solver simpls context tm =
 let open boolSyntax
     val _ = if !monitoring > 0
             then Lib.say("Solver: attempting to prove (by rewriting)\n  "
                          ^term_to_string tm^"\n") else ()
     val th = TOP_DEPTH RW_STEP {context = (context,ADD),
                                  simpls = simpls,
                                  prover = rw_solver} tm
     (* th : |- tm = tm' *)
     val tm' = boolSyntax.rhs(concl th)
     val _ = if !monitoring > 0
             then if tm' = T
                  then Lib.say("Solver: proved\n"^thm_to_string th^"\n\n")
                  else Lib.say("Solver: unable to prove.\n\n")
             else ()
     (* Find a theorem in the list whose conclusion is (an instance of) tm'. *)
     fun loop [] = solver_err()
       | loop (x::rst) =
           let val c = concl x
           in if c = F then CCONTR tm' x
              else if aconv tm' c then x
                   else INST_TY_TERM (Term.match_term c tm') x
                      handle HOL_ERR _ => loop rst
           end
 in EQ_MP (SYM th) (loop (boolTheory.TRUTH::context))
 end;
941
942
943(*---------------------------------------------------------------------------*
944 * The following are all instantiations of the above routines, to make them  *
945 * easier to invoke. Some of these are holdovers from unconditional          *
946 * rewriting and may not make a whole lot of sense. The "C" versions stand   *
947 * for using context as rewrite rules, and proving conditions via            *
948 * recursive invocations of the rewriter.                                    *
949 *---------------------------------------------------------------------------*)
950
951(* Rewrite a term *)
952
local
  (* Implicit simpls plus thl, empty context with policy ADD, no extra
     congruences, the given solver. *)
  fun std_controls solver thl =
      (Default thl, Context([],ADD), Congs[], Solver solver)
  (* Only thl, empty context with policy DONT_ADD, no extra congruences,
     std_solver. *)
  fun pure_controls thl =
      (Pure thl, Context([],DONT_ADD), Congs[], Solver std_solver)
in
fun CRW_CONV thl          = Rewrite Fully (std_controls rw_solver thl)
fun RW_CONV thl           = Rewrite Fully (std_controls std_solver thl)
fun PURE_RW_CONV thl      = Rewrite Fully (pure_controls thl)
fun ONCE_RW_CONV thl      = Rewrite Once  (std_controls std_solver thl)
fun PURE_ONCE_RW_CONV thl = Rewrite Once  (pure_controls thl)
end;
967
968
969(* Rewrite a theorem *)
970
local
  (* Implicit simpls plus thl, empty context with policy ADD, no extra
     congruences, the given solver. *)
  fun std_controls solver thl =
      (Default thl, Context([],ADD), Congs[], Solver solver)
  (* Only thl, empty context with policy DONT_ADD, no extra congruences,
     std_solver. *)
  fun pure_controls thl =
      (Pure thl, Context([],DONT_ADD), Congs[], Solver std_solver)
in
fun CRW_RULE thl          = REWRITE_RULE Fully (std_controls rw_solver thl)
fun RW_RULE thl           = REWRITE_RULE Fully (std_controls std_solver thl)
fun ONCE_RW_RULE thl      = REWRITE_RULE Once  (std_controls std_solver thl)
fun PURE_RW_RULE thl      = REWRITE_RULE Fully (pure_controls thl)
fun PURE_ONCE_RW_RULE thl = REWRITE_RULE Once  (pure_controls thl)
end;
981
982
983(* Rewrite a theorem with the help of its assumptions *)
984
local
  (* Implicit simpls plus thl, empty context with policy ADD, no extra
     congruences, the given solver. *)
  fun std_controls solver thl =
      (Default thl, Context([],ADD), Congs[], Solver solver)
  (* Only thl, empty context with policy DONT_ADD, no extra congruences,
     std_solver. *)
  fun pure_controls thl =
      (Pure thl, Context([],DONT_ADD), Congs[], Solver std_solver)
in
fun ASM_CRW_RULE thl =
    ASM_REWRITE_RULE Fully (std_controls rw_solver thl)
fun ASM_RW_RULE thl =
    ASM_REWRITE_RULE Fully (std_controls std_solver thl)
fun ONCE_ASM_RW_RULE thl =
    ASM_REWRITE_RULE Once (std_controls std_solver thl)
fun PURE_ASM_RW_RULE thl =
    ASM_REWRITE_RULE Fully (pure_controls thl)
fun PURE_ONCE_ASM_RW_RULE thl =
    ASM_REWRITE_RULE Once (pure_controls thl)
end;
1001
1002
1003(* Rewrite a goal *)
1004
local
  (* Implicit simpls plus thl, empty context with policy ADD, no extra
     congruences, the given solver. *)
  fun std_controls solver thl =
      (Default thl, Context([],ADD), Congs[], Solver solver)
  (* Only thl, empty context with policy DONT_ADD, no extra congruences,
     std_solver. *)
  fun pure_controls thl =
      (Pure thl, Context([],DONT_ADD), Congs[], Solver std_solver)
in
fun CRW_TAC thl          = REWRITE_TAC Fully (std_controls rw_solver thl)
fun RW_TAC thl           = REWRITE_TAC Fully (std_controls std_solver thl)
fun ONCE_RW_TAC thl      = REWRITE_TAC Once  (std_controls std_solver thl)
fun PURE_RW_TAC thl      = REWRITE_TAC Fully (pure_controls thl)
fun PURE_ONCE_RW_TAC thl = REWRITE_TAC Once  (pure_controls thl)
end;
1019
1020
1021(* Rewrite a goal with the help of its assumptions *)
1022
local
  (* Implicit simpls plus thl, empty context with policy ADD, no extra
     congruences, the given solver. *)
  fun std_controls solver thl =
      (Default thl, Context([],ADD), Congs[], Solver solver)
  (* Only thl, empty context with policy DONT_ADD, no extra congruences,
     std_solver. *)
  fun pure_controls thl =
      (Pure thl, Context([],DONT_ADD), Congs[], Solver std_solver)
in
fun ASM_CRW_TAC thl =
    ASM_REWRITE_TAC Fully (std_controls rw_solver thl)
fun ASM_RW_TAC thl =
    ASM_REWRITE_TAC Fully (std_controls std_solver thl)
fun ONCE_ASM_RW_TAC thl =
    ASM_REWRITE_TAC Once (std_controls std_solver thl)
fun PURE_ASM_RW_TAC thl =
    ASM_REWRITE_TAC Fully (pure_controls thl)
fun PURE_ONCE_ASM_RW_TAC thl =
    ASM_REWRITE_TAC Once (pure_controls thl)
end;
1039
(* Simpl tac std_thms thl: rewrite fully with the implicit simpls extended by
   std_thms (plus thl). Then, if tac changes the goal, rewrite once more.
   The second phase is wrapped in TRY, so it never causes failure. *)
fun Simpl tac std_thms thl =
  let val base     = add_rws (implicit_simpls()) std_thms
      val controls = (Simpls(base,thl), Context([],ADD),
                      Congs[], Solver std_solver)
      val rewrite  = REWRITE_TAC Fully controls
      val follow_up = CHANGED_TAC tac THEN rewrite
  in
    rewrite THEN TRY follow_up
  end;
1046
(* Restore the grammars that were current before this file switched to
   boolTheory.bool_grammars near the top of the structure. *)
val _ = Parse.temp_set_grammars ambient_grammars;
1048
1049end (* structure RW *)
1050