structure Cond_rewr :> Cond_rewr =
struct

(* ------------------------------------------------------------------------
 * Cond_rewr: conditional rewriting support for the simplifier.
 *
 *   ac_term_ord     a total ordering on terms; COND_REWR_CONV uses it to
 *                   decide whether an unbounded permutative rewrite may
 *                   fire.
 *   COND_REWR_CONV  turns a (conditional) equation into a conversion that
 *                   discharges the instantiated side conditions with a
 *                   caller-supplied solver.
 *   mk_cond_rewrs   turns a (theorem, bound) pair into a list of
 *                   conditional rewrite rules by running IMP_CANON, then
 *                   IMP_EQ_CANON, then QUANTIFY_CONDITIONS.
 * ------------------------------------------------------------------------ *)

open HolKernel boolLib liteLib Trace;

type controlled_thm = BoundedRewrites.controlled_thm

(* Error helpers that tag failures with this structure's name.  Both are
   used in exception-handler position below, i.e. they raise rather than
   return a value. *)
fun WRAP_ERR x = STRUCT_WRAP "Cond_rewr" x;
fun ERR x      = STRUCT_ERR "Cond_rewr" x;

(* Upper bound on the number of pending side conditions: COND_REWR_CONV
   fails when (length of the condition stack + number of new conditions)
   exceeds !stack_limit. *)
val stack_limit = ref 4;

(* When !track_rewrites is true, every theorem that COND_REWR_CONV applies
   with an empty condition stack (i.e. not while solving a side condition)
   is prepended to !used_rewrites. *)
val track_rewrites = ref false;
val used_rewrites  = ref [] : thm list ref;

(* -----------------------------------------------------------------------*
 * A total ordering on terms.  The behaviour of the simplifier depends    *
 * on this, so don't change it without thinking: COND_REWR_CONV only      *
 * applies an unbounded permutative rewrite l = r when                    *
 * ac_term_ord (l, r) = GREATER.                                          *
 *                                                                        *
 * Based on some code in Isabelle.                                        *
 *                                                                        *
 * Intended to be a strict (not reflexive) linear well-founded            *
 * AC-compatible ordering for terms.                                      *
 *                                                                        *
 * The current implementation (ac_term_ord0 below) compares by:           *
 *   1. size (see size_of_term);                                          *
 *   2. for equal sizes, an abstraction is GREATER than a                 *
 *      non-abstraction;                                                  *
 *   3. two abstractions: by bound-variable type, then by body, with      *
 *      bound variables identified by binder depth rather than by name;   *
 *   4. otherwise: head symbol (hd_compare), then number of arguments,    *
 *      then the argument lists lexicographically.                        *
 *                                                                        *
 * Historical note: an earlier version, modified by DRS to have certain   *
 * AC properties, made variables always bigger than constants (hence      *
 * they moved to the right) and bigger than unary comb functions, but     *
 * not bigger than functions of 2 or more arguments, since AC rewriting   *
 * then loops (you need var < f(var2,var3)).  That version still looped   *
 * (see the commented-out code after ac_term_ord), and the current        *
 * ordering no longer has those variable/constant properties.             *
 * -----------------------------------------------------------------------*)

(* Number of variable/constant leaves in tm plus one for each abstraction;
   the bound variable at a binding site is not counted separately. *)
fun size_of_term tm =
     case dest_term tm
      of LAMB(Bvar,Body) => 1 + size_of_term Body
       | COMB(Rator,Rand) => size_of_term Rator + size_of_term Rand
       | _ => 1

(* Infix form of pair_compare: (c1 lex_cmp c2) compares pairs
   lexicographically, first with c1 and then with c2.  Being a plain infix,
   it associates to the left, so chains build nested left pairs. *)
val op lex_cmp = pair_compare
infix lex_cmp

(* Sort key for the head of an application, consumed by hd_compare:
     (((name, thy), rank), ty)
   where rank is
     0   for a free variable (name = its name, thy = ""),
     ~n  for a variable bound at binder depth n, as recorded in env,
     1   for a constant,
     2   for an abstraction (keyed only by the type of its bound variable),
     3   for an application (not expected, because callers pass heads
         returned by strip_comb).
   Bound variables carry Type.alpha, because their types were already
   compared at the binder in ac_term_ord0. *)
fun dest_hd env t =
    case dest_term t of
      VAR (s, ty) => let
      in
        case Binarymap.peek(env, t) of
          NONE => (((s, ""), 0), ty)
        | SOME n => ((("", ""), ~n), Type.alpha)
      end
    | CONST {Name, Thy, Ty} => (((Name, Thy), 1), Ty)
    | LAMB (bv, body) => ((("", ""), 2), type_of bv)
    | COMB _ => ((("", ""), 3), Type.alpha) (* should never happen *)

(* Compare two application heads by their dest_hd keys: name, then theory,
   then rank, then type.  t1 is looked up in env1 and t2 in env2. *)
fun hd_compare (env1, env2) (t1, t2) =
    (String.compare lex_cmp String.compare lex_cmp Int.compare
     lex_cmp Type.compare)
    (dest_hd env1 t1, dest_hd env2 t2)

(* Worker for ac_term_ord.
   n  : binder depth reached so far; the next bound variable is mapped to n.
   e  : (env1, env2), maps from the bound variables of tm1 and tm2 to the
        depth at which they were bound.
   NOTE(review): size_of_term is recomputed at every level of recursion,
   so deep terms are sized repeatedly.
   NOTE(review): an abstraction head is keyed only by its bound-variable
   type (see dest_hd).  Two distinct beta-redex applications with equal
   size, arity and arguments can therefore compare EQUAL, contrary to the
   strict linear intent stated above.  COND_REWR_CONV only tests for
   GREATER, so this just blocks such permutative rewrites. *)
fun ac_term_ord0 n (e as (env1, env2)) (tm1, tm2) = let
  val cmp = ac_term_ord0 n e
in
  case Int.compare (size_of_term tm1, size_of_term tm2) of
    EQUAL => let
    in
      if is_abs tm1 then
        if is_abs tm2 then let
            val (bv1, bdy1) = dest_abs tm1
            val (bv2, bdy2) = dest_abs tm2
          in
            case Type.compare(type_of bv1, type_of bv2) of
              EQUAL => let
                (* both binders get the same depth index, so the bodies are
                   compared up to renaming of the bound variables *)
                val env1' = Binarymap.insert(env1, bv1, n)
                val env2' = Binarymap.insert(env2, bv2, n)
              in
                ac_term_ord0 (n + 1) (env1', env2') (bdy1, bdy2)
              end
            | x => x
          end
        else GREATER
      else if is_abs tm2 then LESS
      else let
          val (f, xs) = strip_comb tm1
          val (g, ys) = strip_comb tm2
        in
          (* head, then arity, then the arguments lexicographically *)
          (hd_compare e lex_cmp Int.compare lex_cmp list_compare cmp)
          (((f, length xs), xs), ((g, length ys), ys))
        end
    end
  | x => x
end

(* The ordering itself: start at binder depth 0 with no bound variables. *)
val empty_dict = Binarymap.mkDict Term.compare
val ac_term_ord = ac_term_ord0 0 (empty_dict, empty_dict)

(* bad old implementation, has a loop between

  (x + y) + 1  >  x + (y + 1)  >  x + (1 + y)  >  1 + (x + y)  >  (x + y) + 1

 remembering that 1 is really NUMERAL (NUMERAL_BIT1 ALT_ZERO)

fun ac_term_ord(tm1,tm2) =
   case (dest_term tm1, dest_term tm2) of
      (VAR _,CONST _) => GREATER
    | (VAR _, COMB (Rator,Rand)) => if is_comb Rator then LESS else GREATER
    | (CONST _, VAR _) => LESS
    | (COMB (Rator,Rand), VAR _) => if is_comb Rator then GREATER else LESS
    | (VAR v1, VAR v2) => String.compare(fst v1, fst v2)
    | (CONST c1, CONST c2) =>
        (case String.compare(#Name c1,#Name c2)
          of EQUAL => String.compare(#Thy c1,#Thy c2)
           | other => other)
    | (dt1,dt2) =>
      (case Int.compare (size_of_term tm1,size_of_term tm2) of
       EQUAL =>
         (case (dt1,dt2) of
            (LAMB l1,LAMB l2) => ac_term_ord(snd l1, snd l2)
          | _ => let val (con,args) = strip_comb tm1
                     val (con2,args2) = strip_comb tm2
                 in case ac_term_ord (con,con2) of
                    EQUAL => list_ord ac_term_ord (args,args2)
                  | ord => ord
                 end)
       | ord => ord)

*)

   (* ---------------------------------------------------------------------
    * COND_REWR_CONV
    * ---------------------------------------------------------------------*)

   (* vperm (t1, t2): t1 and t2 have the same shape, where any two
      variables of the same type match each other.  Binders themselves are
      skipped (only the bodies are compared), and all other leaves must be
      equal. *)
   fun vperm(tm1,tm2) =
    case (dest_term tm1, dest_term tm2)
     of (VAR v1,VAR v2)   => (snd v1 = snd v2)
      | (LAMB t1,LAMB t2) => vperm(snd t1, snd t2)
      | (COMB t1,COMB t2) => vperm(fst t1,fst t2) andalso vperm(snd t1,snd t2)
      | (x,y) => (x = y)

   (* An equation l = r is permutative when l and r have the same shape up
      to variables (vperm) and the same set of free variables, e.g.
      x + y = y + x. *)
   fun is_var_perm(tm1,tm2) =
       vperm(tm1,tm2) andalso set_eq (free_vars tm1) (free_vars tm2)

   (* COND_REWR_CONV th bounded solver stack tm
    *
    * th      : |- c1 ==> ... ==> cn ==> (l = r), with n >= 0.
    * bounded : when true, the permutative-rewrite check is skipped.
    * solver  : called as (solver new_stack c) to prove each instantiated
    *           side condition c; it must return a theorem whose conclusion
    *           is c, so that MP can discharge it.
    * stack   : side conditions currently being solved by outer calls.
    *
    * Matches the lhs of th against tm (higher-order matching via
    * HO_PART_MATCH), solves the instantiated conditions and returns
    * |- tm = r'.  It fails (wrapped by WRAP_ERR) when:
    *   - an instantiated condition is already on the stack (aconv),
    *   - the stack plus the new conditions exceed !stack_limit,
    *   - the instantiated lhs and rhs are aconv,
    *   - th is permutative, not bounded, and ac_term_ord (l, r) is not
    *     GREATER,
    *   - the solver fails on some condition.
    * Setup errors, e.g. a theorem that is not a (conditional) equation,
    * are reported when COND_REWR_CONV th bounded is evaluated. *)
   fun COND_REWR_CONV th bounded =
      let val eqn = snd (strip_imp (concl th))
          val isperm = is_var_perm (dest_eq eqn)
          val instth = HO_PART_MATCH (lhs o snd o strip_imp) th
                       handle HOL_ERR _ => ERR("COND_REWR_CONV",
                         "bad theorem argument (not a conditional equation)")
      in
      fn solver => fn stack => fn tm =>
       (let val conditional_eqn = instth tm
            val (conditions,eqn) = strip_imp (concl conditional_eqn)
            (* cut: this condition is already being solved further out *)
            val _ = if exists (C (op_mem aconv) stack) conditions
                        then (trace(1, TEXT "looping - cut");
                              failwith "looping!") else ()
            val _ = if length stack + length conditions > (!stack_limit)
                    then (trace(1, TEXT "looping - stack limit reached");
                          failwith "stack limit") else ()
            val (l,r) = dest_eq eqn
            val _ =
              if Term.aconv l r then
                (trace(4, IGNORE ("Rewrite loops", conditional_eqn));
                 failwith "looping rewrite")
              else ()

            (* an unbounded permutative rewrite may only make the term
               smaller in the ac_term_ord sense *)
            val _ = if isperm andalso ac_term_ord(l, r) <> GREATER andalso
                       not bounded
                    then
                      (trace(4, IGNORE("possibly looping",conditional_eqn));
                       failwith "permutative rewr: not applied")
                    else ()
            (* conditional rewrites are traced before their side conditions
               are attempted; unconditional ones after success (below) *)
            val _ = if null conditions then ()
                    else trace(if isperm then 2 else 1, REWRITING(tm,th))
            val new_stack = conditions@stack
            fun solver' condition =
                 let val _   = trace(2,SIDECOND_ATTEMPT condition)
                     val res = solver new_stack condition
                      handle e as HOL_ERR _
                       =>  (trace(1,SIDECOND_NOT_SOLVED condition); raise e)
                 in trace(2,SIDECOND_SOLVED res);
                    res
                 end
            val condition_thms = map solver' conditions
            (* discharge the antecedents in order, outermost first *)
            val disch_eqn = rev_itlist (C MP) condition_thms conditional_eqn
            (* the matched lhs may differ from tm by bound-variable names *)
            val final_thm = if (l = tm) then disch_eqn
                            else TRANS (ALPHA tm l) disch_eqn
            val _ = if null conditions then
              trace(if isperm then 2 else 1, REWRITING(tm,th))
                    else ()
            val _ = if null stack andalso !track_rewrites
                      then used_rewrites := th :: !used_rewrites
                      else ()
        in trace(if isperm then 3 else 2,PRODUCE(tm,"rewrite",final_thm));
            final_thm
        end
        handle e => WRAP_ERR("COND_REWR_CONV (application)",e))
      end
      handle e  => WRAP_ERR("COND_REWR_CONV (construction) ",e);


(* The constant bool$BOUNDED : bool -> bool.  It is not referenced elsewhere
   in this file; presumably it is exported through the Cond_rewr signature
   (TODO confirm). *)
val BOUNDED_t = mk_thy_const {Thy = "bool", Name = "BOUNDED",
                              Ty = bool --> bool}

(* loops th: true when the lhs of the equation th occurs, up to
   alpha-conversion, somewhere inside its rhs, so that rewriting with th
   would reintroduce its own redex.  Fails with "loops" if th is not an
   equation. *)
fun loops th = let
  val (l,r) = dest_eq (concl th)
in
  can (find_term (aconv l)) r
end handle HOL_ERR _ => failwith "loops"


(*-------------------------------------------------------------------------
 * IMP_CONJ_THM
 * IMP_CONJ_RULE
 * CONJ_DISCH
 *
 * CONJ_DISCH discharges a list of assumptions, and conjoins them as
 * a single antecedent.
 *
 * IMP_CONJ_RULE turns |- p ==> q ==> r into |- p /\ q ==> r, using the
 * left-to-right direction of AND_IMP_INTRO.  When a discharged theorem
 * is not of that shape, CONJ_DISCH keeps it unchanged.  As a consequence,
 * if the conclusion of th is itself an implication, its antecedent is
 * also merged into the conjunction, e.g. CONJ_DISCH [P] (|- A ==> B)
 * gives |- P /\ A ==> B.
 *
 * EXAMPLE
 *
 * CONJ_DISCH [`P:bool`,`Q:bool`] (mk_thm([`P:bool`,`Q:bool`,`R:bool`],`T`));
 * val it = [R] |- P /\ Q ==> T : thm
 *------------------------------------------------------------------------*)


val CONJ_DISCH =
  let val IMP_CONJ_RULE =
      let val (t1,t2,t3) = triple_of_list(fst(strip_forall(concl AND_IMP_INTRO)))
          val IMP_CONJ_THM = fst(EQ_IMP_RULE (SPEC_ALL AND_IMP_INTRO))
      in fn th =>
        let val (p,qr) = dest_imp(concl th)
            val (q,r) = dest_imp qr
        in MP (INST [t1 |-> p, t2 |-> q, t3 |-> r] IMP_CONJ_THM) th
        end
      end;
  in fn asms => fn th =>
    (* itlist works from the last assumption to the first, so the first
       element of asms ends up as the leftmost conjunct *)
    itlist (fn tm => (fn th => IMP_CONJ_RULE th
                      handle HOL_ERR _ => th) o DISCH tm)
    asms th
  end;



  (* ----------------------------------------------------------------------
   * IMP_EQ_CANON
   *
   * Put a theorem into canonical form as a conditional equality.
   *
   * Takes a (theorem, bound) pair, as produced by IMP_CANON, and returns a
   * list of (theorem, bound) pairs suitable for rewriting.  Every result
   * carries the input's bound.  Below, A stands for the hypotheses and x
   * for a fixed boolean genvar.
   *   1. Undischarge all antecedents of the conclusion (UNDISCH_ALL).
   *   2. Turn the resulting conclusion into rewrites:
   *        A |- l = r        --> [A |- l = r], or the EQT version
   *                              [A |- (l = r) = T] when r has free
   *                              variables that are not free in l or the
   *                              hypotheses, or when l is a boolean
   *                              variable not free in the hypotheses
   *                              (unbounded case only).  If the lhs is
   *                              itself an equation and the rhs is not,
   *                              a copy with that lhs equation's sides
   *                              swapped is added as well.
   *        unbounded l = r where l occurs in r (see loops)
   *                          --> EQT versions of l = r and of r = l
   *        A |- T = F, F = T --> [A |- x = F]
   *        A |- T = t        --> [A |- t = T]
   *        A |- F = t        --> [A |- t = F]
   *        A |- ~(t1 = t2)   --> [A |- (t1 = t2) = F; A |- (t2 = t1) = F]
   *        A |- ~v           --> thrown away, for a variable v not free in
   *                              the hypotheses
   *        A |- ~t           --> [A |- t = F]
   *        A |- T            --> thrown away
   *        A |- F            --> [A |- x = F]
   *        A |- Abbrev (v = e)
   *                          --> rewrites from A |- e = v; when e is a
   *                              lambda abstraction, also pointwise
   *                              rewrites from the unfolded definition
   *        A |- Abbrev t     --> thrown away, when t is not an equation
   *        A |- v            --> thrown away, for a variable v not free in
   *                              the hypotheses
   *        A |- t            --> [A |- t = T]
   *   3. Discharge the undischarged antecedents again, as one single
   *      conjoined antecedent (CONJ_DISCH).
   *
   * Splitting conjunctions and specialising quantified variables happen
   * beforehand, in IMP_CANON.  Existentially quantifying variables that
   * are free in the conditions but not in the equation happens afterwards,
   * in QUANTIFY_CONDITIONS.
   *
   * EXAMPLES
   *
   * IMP_EQ_CANON (mk_thm([],`foo (s1,s2) ==> P s2`), BoundedRewrites.UNBOUNDED);
   * IMP_EQ_CANON (mk_thm([],`foo (s1,s2) ==> (v1 = v2)`), BoundedRewrites.UNBOUNDED);
   * ----------------------------------------------------------------------*)
(* new version of this due to is_imp/negation problem in hol90 *)

(* Repeatedly UNDISCH while the conclusion satisfies is_imp. *)
fun UNDISCH_ALL th =
  if is_imp (concl th) then UNDISCH_ALL (UNDISCH th)
  else th;;

val truth_tm = boolSyntax.T
val false_tm = boolSyntax.F
val Abbrev_tm = prim_mk_const {Name = "Abbrev", Thy = "marker"}
(* |- F ==> (x = F), for a boolean genvar x created once, when this
   structure is loaded *)
val x_eq_false = SPEC (mk_eq(genvar bool, false_tm)) FALSITY
(* [T = F] |- x = F and [F = T] |- x = F, with the same genvar x.  Used with
   PROVE_HYP to turn a theorem |- T = F (or |- F = T) into |- x = F. *)
val TF_EQ_F = PROVE_HYP (UNDISCH_ALL (NOT_ELIM (CONJUNCT1 BOOL_EQ_DISTINCT)))
                        (UNDISCH_ALL x_eq_false)
val FT_EQ_F = PROVE_HYP (UNDISCH_ALL (NOT_ELIM (CONJUNCT2 BOOL_EQ_DISTINCT)))
                        (UNDISCH_ALL x_eq_false)

fun IMP_EQ_CANON (thm,bnd) = let
  val conditions = #1 (strip_imp (concl thm))
  (* free variables of the original hypotheses, before undischarging *)
  val hypfvs = hyp_frees thm
  val undisch_thm = UNDISCH_ALL thm
  val conc = concl undisch_thm
  fun IMP_EQ_CANONb th = IMP_EQ_CANON (th, bnd)
  val undisch_rewrites =
      if (is_eq conc) then
        if loops undisch_thm andalso bnd = BoundedRewrites.UNBOUNDED then
          (trace(1,IGNORE("looping rewrite (but adding EQT versions)",thm));
           [(EQT_INTRO undisch_thm, bnd), (EQT_INTRO (SYM undisch_thm), bnd)])
        else
          let
            val (l,r) = dest_eq conc
          in
            if l = truth_tm then
              if r = false_tm then [(PROVE_HYP undisch_thm TF_EQ_F, bnd)]
              else [(CONV_RULE (REWR_CONV EQ_SYM_EQ) undisch_thm, bnd)]
            else if l = false_tm then
              if r = truth_tm then [(PROVE_HYP undisch_thm FT_EQ_F,bnd)]
              else [(CONV_RULE (REWR_CONV EQ_SYM_EQ) undisch_thm, bnd)]
            else
              let
                (* an unbounded rewrite whose lhs is a bare boolean variable
                   not free in the hypotheses is considered unsafe *)
                fun safelhs t =
                  not (is_var t) orelse type_of t <> bool orelse
                  t IN hypfvs orelse bnd <> BoundedRewrites.UNBOUNDED
                val base =
                    if null (subtract (free_vars r) (free_varsl (l::hyp thm)))
                       andalso safelhs l
                    then undisch_thm
                    else
                      (trace(1,IGNORE("rewrite with bad vars (adding \
                                      \EQT version(s))",thm));
                       EQT_INTRO undisch_thm)
                (* lhs is itself an equation but rhs is not: also add the
                   version with the lhs equation's sides swapped *)
                val flip_eqp = let val (l,r) = dest_eq (concl base)
                               in
                                 is_eq l andalso not (is_eq r)
                               end
              in
                if flip_eqp then
                  [(base, bnd),
                   (CONV_RULE (LAND_CONV (REWR_CONV EQ_SYM_EQ)) base, bnd)]
                else [(base,bnd)]
              end
          end
      else if is_neg conc then
        let
          val n = dest_neg conc
        in
          if is_eq n then
            [(EQF_INTRO undisch_thm, bnd), (EQF_INTRO (GSYM undisch_thm), bnd)]
          else if is_var n andalso not (n IN hypfvs) then
            (trace(1, IGNORE ("boolean variable conclusion", thm)); [])
          else
            [(EQF_INTRO undisch_thm, bnd)]
        end
      else if conc = truth_tm then
        (trace(2,IGNORE ("pointless rewrite",thm)); [])
      else if conc = false_tm then [(MP x_eq_false undisch_thm, bnd)]
      else if is_comb conc andalso same_const (rator conc) Abbrev_tm then let
          val rnd = rand conc
          (* for v = \y. b: FUN_EQ_THM + beta gives !y. v y = b (repeated
             for nested abstractions), then the innermost equation is
             flipped to b = v y *)
          fun funeqconv t =
              if is_abs (rhs t) then
                (REWR_CONV FUN_EQ_THM THENC
                 BINDER_CONV (RAND_CONV BETA_CONV THENC funeqconv)) t
              else REWR_CONV EQ_SYM_EQ t
        in
          if is_eq rnd then let
              (* unfold Abbrev, leaving the equation v = e *)
              val bth = undisch_thm
                          |> CONV_RULE (REWR_CONV markerTheory.Abbrev_def)
              val base_eqns = bth |> SYM |> IMP_EQ_CANONb
            in
              if is_abs (rhs rnd) then
                base_eqns @
                (bth |> CONV_RULE funeqconv |> SPEC_ALL |> IMP_EQ_CANONb)
              else
                base_eqns
            end
          else []
        end
      else if is_var conc andalso not (conc IN hypfvs) then
        (trace(1, IGNORE ("boolean variable conclusion", thm)); [])
      else
        [(EQT_INTRO undisch_thm, bnd)]
in
  map (fn (th,bnd) => (CONJ_DISCH conditions th, bnd)) undisch_rewrites
end handle e => WRAP_ERR("IMP_EQ_CANON",e);


(* For a theorem c ==> eqn (one antecedent, as left by CONJ_DISCH), each
   variable free in c but free neither in eqn nor in the hypotheses is
   universally generalised and then moved into the antecedent with
   LEFT_FORALL_IMP_THM, so that it becomes existentially quantified in the
   condition.  Theorems without an antecedent are returned unchanged.
   Always returns a one-element list. *)
fun QUANTIFY_CONDITIONS (thm, bnd) =
    if is_imp (concl thm) then let
        val free_in_eqn = (free_vars (snd(dest_imp (concl thm))))
        val free_in_thm = (free_vars (concl thm))
        val free_in_hyp = free_varsl (hyp thm)
        val free_in_conditions =
            subtract (subtract free_in_thm free_in_eqn) free_in_hyp
        fun quantify fv = CONV_RULE (HO_REWR_CONV LEFT_FORALL_IMP_THM) o GEN fv
        val quan_thm = itlist quantify free_in_conditions thm
      in
        [(quan_thm, bnd)]
      end
    else [(thm, bnd)]
 handle e => WRAP_ERR("QUANTIFY_CONDITIONS",e)

(* Finish IMP_CANON: re-discharge each theorem's accumulated antecedents
   and return the results in the order they were completed.  ants lists the
   most recently undischarged antecedent first, so foldl discharges it
   first and the original outermost antecedent ends up outermost again.
   Prepending to acc reverses the (reversed) finished list back into
   order. *)
fun imp_canon_munge acc antthlist =
    case antthlist of
      [] => acc
    | ((ants, th, bnd) :: rest) =>
      imp_canon_munge ((List.foldl (uncurry DISCH) th ants, bnd) :: acc) rest

(* Worklist normalisation of theorems into rewrite-shaped pieces.
   acc : finished items (most recent first).
   thl : pending items (ants, th, bnd), where ants are antecedents that
         have been undischarged from th so far.
   Each pending theorem is processed by the shape of its conclusion:
     t1 /\ t2             split into two items
     A /\ B ==> c         curried into A ==> B ==> c
     A \/ B ==> c         split into A ==> c and B ==> c
     (?x. P x) ==> c      P x' ==> c for a fresh variant x'
     a ==> F              finished as ~a (see comment below)
     a ==> c              a is undischarged and pushed onto ants
     !x. t                specialised (SPEC_ALL)
     restricted forall    unfolded with RES_FORALL_THM and beta-reduced
     anything else        finished *)
fun IMP_CANON acc thl =
    case thl of
      [] => imp_canon_munge [] acc
    | (ants, th, bnd)::ths => let
        val w = concl th
      in
        if is_conj w then let
            val (th1, th2) = CONJ_PAIR th
          in
            IMP_CANON acc ((ants, th1, bnd) :: (ants, th2, bnd) :: ths)
          end
        else if is_imp w then let
            val (ant,c) = dest_imp w
          in
            if is_conj ant then let
                val (conj1,conj2) = dest_conj ant
                val newth =
                    DISCH conj1 (DISCH conj2 (MP th (CONJ (ASSUME conj1)
                                                          (ASSUME conj2))))
              in
                IMP_CANON acc ((ants, newth, bnd) :: ths)
              end
            else if is_disj ant then let
                val (disj1,disj2) = dest_disj ant
                val newth1 = DISCH disj1 (MP th (DISJ1 (ASSUME disj1) disj2))
                val newth2 = DISCH disj2 (MP th (DISJ2 disj1 (ASSUME disj2)))
              in
                IMP_CANON acc
                          ((ants, newth1, bnd) :: (ants, newth2, bnd) :: ths)
              end
            else if is_exists ant then let
                val (Bvar,Body) = dest_exists ant
                (* rename the witness so it cannot capture a free variable
                   of th *)
                val bv' = variant (thm_frees th) Bvar
                val body' = subst [Bvar |-> bv'] Body
                val newth =
                    DISCH body' (MP th (EXISTS(ant, bv') (ASSUME body')))
              in
                IMP_CANON acc ((ants, newth, bnd) :: ths)
              end
            else if c = boolSyntax.F then
              IMP_CANON ((ants, NOT_INTRO th, bnd) :: acc) ths
              (* we want [.] |- F theorems to rewrite to [.] |- x = F,
                 done above in IMP_EQ_CANON, but we don't want this to
                 be done for |- P ==> F, which would set up a rewrite
                 of the form |- P ==> (x = F), which would match any
                 boolean term and force endless attempts to prove P.
                 Instead, convert to |- ~P *)
            else
              IMP_CANON acc ((ant::ants, UNDISCH th, bnd)::ths)
          end
        else if is_forall w then
          IMP_CANON acc ((ants, SPEC_ALL th, bnd) :: ths)
        else if is_res_forall w then let
            val newth = CONV_RULE (REWR_CONV RES_FORALL_THM THENC
                                   QUANT_CONV (RAND_CONV BETA_CONV)) th
          in
            IMP_CANON acc ((ants, newth, bnd) :: ths)
          end
        else IMP_CANON ((ants, th, bnd)::acc) ths
      end

(* Public entry point: normalise a single (theorem, bound) pair. *)
val IMP_CANON = (fn (th,bnd) => IMP_CANON [] [([], th, bnd)])

(* Composition of list-producing functions:
   (f oo g) x applies f to every result of g x and concatenates.
   Declared with plain infix, so it associates to the left. *)
infix oo;
fun f oo g = fn x => flatten (map f (g x));

(* Turn a (theorem, bound) pair into conditional rewrite rules: IMP_CANON,
   then IMP_EQ_CANON on each piece, then QUANTIFY_CONDITIONS on each
   result.  Only HOL_ERR exceptions are wrapped here. *)
fun mk_cond_rewrs l =
    (QUANTIFY_CONDITIONS oo IMP_EQ_CANON oo IMP_CANON) l
    handle e as HOL_ERR _ => WRAP_ERR("mk_cond_rewrs",e);

end;
460
461
462(* TESTS:
463 SIMP_CONV sum_ss [] (--`ISL y ==> (y = INL (OUTL y))`--);
464
465
466val th1 = ASSUME (--`!(x:num) (y:num). Q x y ==> R x`--);
467
468    SIMP_CONV (merge_ss [bool_ss,SATISFY_ss]) [th1] (--`Q 1 3 ==> R 1`--);
469
470
471
472*)
473