(*  Title:      Provers/Arith/cancel_numerals.ML
    Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   2000  University of Cambridge

Cancel common coefficients in balanced expressions:

     i + #m*u + j ~~ i' + #m'*u + j'  ==  #(m-m')*u + i + j ~~ i' + j'

where ~~ is an appropriate balancing operation (e.g. =, <=, <, -).
The result shown is the case m' <= m; the symmetric case is given below.

It works by (a) massaging both sides to bring the selected term to the front:

     #m*u + (i + j) ~~ #m'*u + (i' + j')

(b) then using bal_add1 or bal_add2 to reach

     #(m-m')*u + i + j ~~ i' + j'       (if m'<=m)

or

     i + j ~~ #(m'-m)*u + i' + j'       (otherwise)
*)

(*Operations an object logic must supply to instantiate CancelNumeralsFun.*)
signature CANCEL_NUMERALS_DATA =
sig
  (*Balanced expressions  lhs ~~ rhs*)
  val dest_bal: term -> term * term       (*split into (lhs, rhs)*)
  val mk_bal: term * term -> term         (*rebuild lhs ~~ rhs*)

  (*Sums*)
  val dest_sum: term -> term list         (*the summands of a sum*)
  val mk_sum: typ -> term list -> term    (*a sum of the given type*)

  (*Coefficient terms  #n*u*)
  val dest_coeff: term -> int * term      (*split #n*u into (n, u)*)
  val mk_coeff: int * term -> term        (*build #n*u from (n, u)*)
  val find_first_coeff: term -> term list -> int * term list
    (*as used by proc: the coefficient of atom u in a summand list, paired
      with the remaining summands; presumably 0 when u is absent -- confirm*)

  (*Cancellation rules: bal_add1 when the right coefficient is <= the left,
    bal_add2 otherwise*)
  val bal_add1: thm
  val bal_add2: thm

  (*Proof tools*)
  val prove_conv: tactic list -> Proof.context -> thm list -> term * term -> thm option
  val norm_tac: Proof.context -> tactic              (*proves the initial lemma*)
  val trans_tac: Proof.context -> thm option -> tactic  (*applies the initial lemma*)
  val numeral_simp_tac: Proof.context -> tactic      (*proves the final theorem*)
  val simplify_meta_eq: Proof.context -> thm -> thm  (*simplifies the final theorem*)
end;

(*Result of CancelNumeralsFun: the simproc entry point.*)
signature CANCEL_NUMERALS =
sig
  (*Cancel a common #n*u summand across a balanced term; NONE if inapplicable*)
  val proc : Proof.context -> cterm -> thm option
end;

functor CancelNumeralsFun(Data: CANCEL_NUMERALS_DATA): CANCEL_NUMERALS =
struct

(*Record the atom u of a summand #n*u in a table used as a set of atoms.*)
fun update_by_coeff t =
  let val (_, atom) = Data.dest_coeff t
  in Termtab.update (atom, ()) end;

(*Scan terms1 from left to right for a summand #n*u whose atom u also occurs
  as some #m*u among terms2, and return that u.  Raises TERM if none exists.*)
fun find_common (terms1, terms2) =
  let
    val atoms2 = fold update_by_coeff terms2 Termtab.empty
    fun shared t =
      let val (_, u) = Data.dest_coeff t
      in if Termtab.defined atoms2 u then SOME u else NONE end
  in
    (case get_first shared terms1 of
      SOME u => u
    | NONE => raise TERM ("find_common", []))
  end;

(*The simplification procedure: cancel the coefficient of an atom u that
  occurs on both sides of the balanced term ct.  Yields NONE when there is no
  common atom, when either coefficient is 0, or on any TERM/TYPE exception.*)
fun proc ctxt ct =
  let
    val prems = Simplifier.prems_of ctxt
    val ([goal], ctxt') = Variable.import_terms true [Thm.term_of ct] ctxt
    val export = singleton (Variable.export ctxt' ctxt)
    (* FIXME ctxt vs. ctxt' (!?) *)

    (*the summands on each side of the balance*)
    val (lhs, rhs) = Data.dest_bal goal
    val lhs_terms = Data.dest_sum lhs
    and rhs_terms = Data.dest_sum rhs

    (*the shared atom, its coefficient on each side and the other summands*)
    val u = find_common (lhs_terms, rhs_terms)
    val (n1, lhs_rest) = Data.find_first_coeff u lhs_terms
    and (n2, rhs_rest) = Data.find_first_coeff u rhs_terms
    and T = Term.fastype_of u

    (*the sum #i*u + rest, of type T*)
    fun with_coeff (i, rest) = Data.mk_sum T (Data.mk_coeff (i, u) :: rest)

    (*Lemma moving #n*u to the front on each side, e.g.
        i + #m*u + j + k == #m*u + (i + (j + k)).
      A zero coefficient means there is nothing to cancel.*)
    val reshape =
      if n1 = 0 orelse n2 = 0 then raise TERM ("cancel_numerals", [])
      else
        Data.prove_conv [Data.norm_tac ctxt] ctxt prems
          (goal, Data.mk_bal (with_coeff (n1, lhs_rest), with_coeff (n2, rhs_rest)))

    (*Prove goal == target: apply the reshaping lemma, then the cancellation
      rule, then settle the remaining numeral arithmetic.*)
    fun cancel_by rule target =
      Data.prove_conv
        [Data.trans_tac ctxt reshape, resolve_tac ctxt [rule] 1, Data.numeral_simp_tac ctxt]
        ctxt prems (goal, target)

    val result =
      if n1 >= n2 then
        cancel_by Data.bal_add1
          (Data.mk_bal (with_coeff (n1 - n2, lhs_rest), Data.mk_sum T rhs_rest))
      else
        cancel_by Data.bal_add2
          (Data.mk_bal (Data.mk_sum T lhs_rest, with_coeff (n2 - n1, rhs_rest)))
  in
    Option.map (export o Data.simplify_meta_eq ctxt) result
  end
  (* FIXME avoid handling of generic exceptions *)
  handle TERM _ => NONE
       | TYPE _ => NONE;   (*Typically (if thy doesn't include Numeral)
                             Undeclared type constructor "Numeral.bin"*)

end;
