1(*  Title:      ZF/int_arith.ML
2    Author:     Larry Paulson
3
4Simprocs for linear arithmetic.
5*)
6
(*Interface of the integer numeral simprocs; all of them are added to the
  theory simpset by the Theory.setup at the end of this file.*)
signature INT_NUMERAL_SIMPROCS =
sig
  (*simproc for sums and differences (patterns i $+ j and i $- j)*)
  val combine_numerals: simproc
  (*simproc for products (pattern i $* j)*)
  val combine_numerals_prod: simproc
  (*one simproc each for =, $< and $\<le>*)
  val cancel_numerals: simproc list
end
13
structure Int_Numeral_Simprocs: INT_NUMERAL_SIMPROCS =
struct

(* abstract syntax operations *)

(*Encode one binary digit: 0 as the term 0 and 1 as succ(0);
  any other integer raises TERM.*)
fun mk_bit 0 = \<^term>\<open>0\<close>
  | mk_bit 1 = \<^term>\<open>succ(0)\<close>
  | mk_bit _ = raise TERM ("mk_bit", []);

(*Inverse of mk_bit: decode the term 0 or succ(0); raises TERM otherwise.*)
fun dest_bit \<^term>\<open>0\<close> = 0
  | dest_bit \<^term>\<open>succ(0)\<close> = 1
  | dest_bit t = raise TERM ("dest_bit", [t]);

(*Build the ZF binary term (Pls / Min / Bit) for the integer i.  The digit
  list from Numeral_Syntax.make_binary is consumed head first, the head
  becoming the digit of the outermost Bit (presumably the least significant
  bit -- confirm against Numeral_Syntax); [] ends in Pls and [~1] in Min.*)
fun mk_bin i =
  let
    fun term_of [] = \<^term>\<open>Pls\<close>
      | term_of [~1] = \<^term>\<open>Min\<close>
      | term_of (b :: bs) = \<^term>\<open>Bit\<close> $ term_of bs $ mk_bit b;
  in term_of (Numeral_Syntax.make_binary i) end;

(*Inverse of mk_bin: read a Pls / Min / Bit term back into an integer.
  Raises TERM ("dest_bin", [tm]) on a malformed spine and TERM ("dest_bit", _)
  on a malformed digit.*)
fun dest_bin tm =
  let
    fun bin_of \<^term>\<open>Pls\<close> = []
      | bin_of \<^term>\<open>Min\<close> = [~1]
      | bin_of (\<^term>\<open>Bit\<close> $ bs $ b) = dest_bit b :: bin_of bs
      | bin_of _ = raise TERM ("dest_bin", [tm]);
  in Numeral_Syntax.dest_binary (bin_of tm) end;


(*Utilities*)

(*The integer literal for i: integ_of applied to the binary term for i.*)
fun mk_numeral i = \<^const>\<open>integ_of\<close> $ mk_bin i;

(*Value of an integer literal integ_of(w); raises TERM for any other term.*)
fun dest_numeral (Const(\<^const_name>\<open>integ_of\<close>, _) $ w) = dest_bin w
  | dest_numeral t = raise TERM ("dest_numeral", [t]);

(*Scan the terms for the first integer literal.  Returns its value together
  with all the other terms in their original order (past holds the terms
  already skipped, in reverse).  Raises TERM if there is no literal.*)
fun find_first_numeral past (t::terms) =
        ((dest_numeral t, rev past @ terms)
         handle TERM _ => find_first_numeral (t::past) terms)
  | find_first_numeral past [] = raise TERM("find_first_numeral", []);

val zero = mk_numeral 0;
val mk_plus = FOLogic.mk_binop \<^const_name>\<open>zadd\<close>;

(*Build a right-nested $+ sum.  mk_sum [] is #0 and mk_sum [t] is t $+ #0;
  sums of two or more terms have no trailing zero.*)
fun mk_sum []        = zero
  | mk_sum [t,u]     = mk_plus (t, u)
  | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

(*Like mk_sum, but short lists always get a trailing zero:
  [t] gives t $+ #0 and [t,u] gives t $+ (u $+ #0).  With three or more
  terms the tail is built by mk_sum, so no trailing zero is added.*)
fun long_mk_sum []        = zero
  | long_mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

(*Decompose additions AND subtractions as a sum: flatten nested $+ and $-
  into a list of summands, prepended to ts.  pos records whether the current
  position is added (true) or subtracted (false); subtracted summands are
  wrapped in zminus.*)
fun dest_summing (pos, Const (\<^const_name>\<open>zadd\<close>, _) $ t $ u, ts) =
        dest_summing (pos, t, dest_summing (pos, u, ts))
  | dest_summing (pos, Const (\<^const_name>\<open>zdiff\<close>, _) $ t $ u, ts) =
        dest_summing (pos, t, dest_summing (not pos, u, ts))
  | dest_summing (pos, t, ts) =
        if pos then t::ts else \<^const>\<open>zminus\<close> $ t :: ts;

(*Summands of t, with subtracted terms wrapped in zminus.*)
fun dest_sum t = dest_summing (true, t, []);

val one = mk_numeral 1;
val mk_times = FOLogic.mk_binop \<^const_name>\<open>zmult\<close>;

(*Build a right-nested $* product.  [] gives #1, a singleton is returned
  unchanged, and any #1 factor that is not the last element is dropped.*)
fun mk_prod [] = one
  | mk_prod [t] = t
  | mk_prod (t :: ts) = if t = one then mk_prod ts
                        else mk_times (t, mk_prod ts);

(*Split t $* u (at type i) into (t, u); dest_prod relies on it raising TERM
  for any other term.*)
val dest_times = FOLogic.dest_bin \<^const_name>\<open>zmult\<close> \<^typ>\<open>i\<close>;

(*Flatten nested $* into the list of its factors, left to right; a term that
  is not a product is a single factor.*)
fun dest_prod t =
      let val (t,u) = dest_times t
      in  dest_prod t @ dest_prod u  end
      handle TERM _ => [t];

(*DON'T do the obvious simplifications; that would create special cases*)
fun mk_coeff (k, t) = mk_times (mk_numeral k, t);

(*Express t as a product of (possibly) a numeral with other sorted terms.
  Leading zminus signs are folded into sign.  The factors are sorted with
  Term_Ord.term_ord, and the first literal among them (default 1) becomes the
  coefficient.  Returns (sign * coefficient, product of the other factors).*)
fun dest_coeff sign (Const (\<^const_name>\<open>zminus\<close>, _) $ t) = dest_coeff (~sign) t
  | dest_coeff sign t =
    let val ts = sort Term_Ord.term_ord (dest_prod t)
        val (n, ts') = find_first_numeral [] ts
                          handle TERM _ => (1, ts)
    in (sign*n, mk_prod ts') end;

(*Find first coefficient-term THAT MATCHES u: the first t in terms whose
  non-literal part (per dest_coeff) equals u up to alpha-conversion (aconv).
  Returns that term's coefficient and the remaining terms in their original
  order.  Terms that dest_coeff cannot decompose are skipped.  Raises TERM if
  no term matches.
  Only dest_coeff is guarded by the handler.  If the recursive call were
  also inside it, every failed search of the tail would be repeated once by
  the handler, making a failing search exponential in the number of terms.*)
fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
  | find_first_coeff past u (t::terms) =
        (case (SOME (dest_coeff 1 t) handle TERM _ => NONE) of
          SOME (n, u') =>
            if u aconv u' then (n, rev past @ terms)
            else find_first_coeff (t::past) u terms
        | NONE => find_first_coeff (t::past) u terms);


(*Simplify #0 $+ n and n $+ #0 (the lemma names suggest the result is
  intify(n) -- confirm)*)
val add_0s = [@{thm zadd_0_intify}, @{thm zadd_0_right_intify}];

(*Simplify #1*n and n*#1 to n; zmult_minus1(_right) presumably rewrite
  multiplication by #-1 into negation -- confirm against the lemmas*)
val mult_1s = [@{thm zmult_1_intify}, @{thm zmult_1_right_intify},
               @{thm zmult_minus1}, @{thm zmult_minus1_right}];

(*Presumably typing rules (the ..._type lemmas and the bin introduction rules)
  used to discharge membership side conditions -- confirm*)
val tc_rules = [@{thm integ_of_type}, @{thm intify_in_int},
                @{thm int_of_type}, @{thm zadd_type}, @{thm zdiff_type}, @{thm zmult_type}] @ 
               @{thms bin.intros};
(*Rewrites concerning intify around the integer operations and relations
  (zadd/zdiff/zmult, zless/zle) -- see the individual lemmas*)
val intifys = [@{thm intify_ident}, @{thm zadd_intify1}, @{thm zadd_intify2},
               @{thm zdiff_intify1}, @{thm zdiff_intify2}, @{thm zmult_intify1}, @{thm zmult_intify2},
               @{thm zless_intify1}, @{thm zless_intify2}, @{thm zle_intify1}, @{thm zle_intify2}];

(*To perform binary arithmetic*)
val bin_simps = [@{thm add_integ_of_left}] @ @{thms bin_arith_simps} @ @{thms bin_rel_simps};

(*To evaluate binary negations of coefficients*)
val zminus_simps = @{thms NCons_simps} @
                   [@{thm integ_of_minus} RS @{thm sym},
                    @{thm bin_minus_1}, @{thm bin_minus_0}, @{thm bin_minus_Pls}, @{thm bin_minus_Min},
                    @{thm bin_pred_1}, @{thm bin_pred_0}, @{thm bin_pred_Pls}, @{thm bin_pred_Min}];

(*To let us treat subtraction as addition*)
val diff_simps = [@{thm zdiff_def}, @{thm zminus_zadd_distrib}, @{thm zminus_zminus}];

(*push the unary minus down*)
val int_minus_mult_eq_1_to_2 = @{lemma "$- w $* z = w $* $- z" by simp};

(*to extract again any uncancelled minuses*)
val int_minus_from_mult_simps =
    [@{thm zminus_zminus}, @{thm zmult_zminus}, @{thm zmult_zminus_right}];

(*combine unary minus with numeric literals, however nested within a product*)
val int_mult_minus_simps =
    [@{thm zmult_assoc}, @{thm zmult_zminus} RS @{thm sym}, int_minus_mult_eq_1_to_2];

(*Parameters shared by the CancelNumeralsFun instances for =, $< and $\<le>
  below; each instance adds prove_conv, mk_bal, dest_bal and bal_add1/2.*)
structure CancelNumeralsCommon =
  struct
  val mk_sum = (fn _ : typ => mk_sum)
  val dest_sum = dest_sum
  val mk_coeff = mk_coeff
  val dest_coeff = dest_coeff 1
  val find_first_coeff = find_first_coeff []
  fun trans_tac ctxt = ArithData.gen_trans_tac ctxt @{thm iff_trans}

  (*Normalization runs three simp passes in order:
    norm_ss1 treats subtraction as addition, evaluates negated literals and
    applies zadd_ac; norm_ss2 does binary arithmetic and combines unary minus
    with literals inside products; norm_ss3 extracts uncancelled minuses and
    applies zadd_ac / zmult_ac together with tc_rules and intifys.*)
  val norm_ss1 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps add_0s @ mult_1s @ diff_simps @ zminus_simps @ @{thms zadd_ac})
  val norm_ss2 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps bin_simps @ int_mult_minus_simps @ intifys)
  val norm_ss3 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps int_minus_from_mult_simps @ @{thms zadd_ac} @ @{thms zmult_ac} @ tc_rules @ intifys)
  fun norm_tac ctxt =
    ALLGOALS (asm_simp_tac (put_simpset norm_ss1 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss2 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss3 ctxt))

  (*Evaluate numeral arithmetic with numeral_simp_ss, then try the caller's
    own simpset (with assumptions) on any goals that remain.*)
  val numeral_simp_ss =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps add_0s @ bin_simps @ tc_rules @ intifys)
  fun numeral_simp_tac ctxt =
    ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
    THEN ALLGOALS (asm_simp_tac ctxt)
  (*Tidy the resulting equation with add_0s and mult_1s.*)
  val simplify_meta_eq  = ArithData.simplify_meta_eq (add_0s @ mult_1s)
  end;


(*Instance for =, using eq_add_iff1/2 chained with iff_trans*)
structure EqCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "inteq_cancel_numerals"
  val mk_bal   = FOLogic.mk_eq
  val dest_bal = FOLogic.dest_eq
  val bal_add1 = @{thm eq_add_iff1} RS @{thm iff_trans}
  val bal_add2 = @{thm eq_add_iff2} RS @{thm iff_trans}
);

(*Instance for $<, using less_add_iff1/2 chained with iff_trans*)
structure LessCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "intless_cancel_numerals"
  val mk_bal   = FOLogic.mk_binrel \<^const_name>\<open>zless\<close>
  val dest_bal = FOLogic.dest_bin \<^const_name>\<open>zless\<close> \<^typ>\<open>i\<close>
  val bal_add1 = @{thm less_add_iff1} RS @{thm iff_trans}
  val bal_add2 = @{thm less_add_iff2} RS @{thm iff_trans}
);

(*Instance for $\<le>, using le_add_iff1/2 chained with iff_trans*)
structure LeCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "intle_cancel_numerals"
  val mk_bal   = FOLogic.mk_binrel \<^const_name>\<open>zle\<close>
  val dest_bal = FOLogic.dest_bin \<^const_name>\<open>zle\<close> \<^typ>\<open>i\<close>
  val bal_add1 = @{thm le_add_iff1} RS @{thm iff_trans}
  val bal_add2 = @{thm le_add_iff2} RS @{thm iff_trans}
);

(*One simproc per relation (=, $<, $\<le>), each triggered when either side is
  a sum, a difference or a product.*)
val cancel_numerals =
 [Simplifier.make_simproc \<^context> "inteq_cancel_numerals"
   {lhss =
     [\<^term>\<open>l $+ m = n\<close>, \<^term>\<open>l = m $+ n\<close>,
      \<^term>\<open>l $- m = n\<close>, \<^term>\<open>l = m $- n\<close>,
      \<^term>\<open>l $* m = n\<close>, \<^term>\<open>l = m $* n\<close>],
    proc = K EqCancelNumerals.proc},
  Simplifier.make_simproc \<^context> "intless_cancel_numerals"
   {lhss =
     [\<^term>\<open>l $+ m $< n\<close>, \<^term>\<open>l $< m $+ n\<close>,
      \<^term>\<open>l $- m $< n\<close>, \<^term>\<open>l $< m $- n\<close>,
      \<^term>\<open>l $* m $< n\<close>, \<^term>\<open>l $< m $* n\<close>],
    proc = K LessCancelNumerals.proc},
  Simplifier.make_simproc \<^context> "intle_cancel_numerals"
   {lhss =
     [\<^term>\<open>l $+ m $\<le> n\<close>, \<^term>\<open>l $\<le> m $+ n\<close>,
      \<^term>\<open>l $- m $\<le> n\<close>, \<^term>\<open>l $\<le> m $- n\<close>,
      \<^term>\<open>l $* m $\<le> n\<close>, \<^term>\<open>l $\<le> m $* n\<close>],
    proc = K LeCancelNumerals.proc}];


(*version without the hyps argument: ArithData.prove_conv with an empty list*)
fun prove_conv_nohyps name tacs sg = ArithData.prove_conv name tacs sg [];

(*Parameters for CombineNumeralsFun on sums: coefficients are added
  (add = op +) and distributed back with left_zadd_zmult_distrib.*)
structure CombineNumeralsData =
  struct
  type coeff = int
  val iszero = (fn x => x = 0)
  val add = op + 
  val mk_sum = (fn _ : typ => long_mk_sum) (*to work for #2*x $+ #3*x *)
  val dest_sum = dest_sum
  val mk_coeff = mk_coeff
  val dest_coeff = dest_coeff 1
  val left_distrib = @{thm left_zadd_zmult_distrib} RS @{thm trans}
  val prove_conv = prove_conv_nohyps "int_combine_numerals"
  fun trans_tac ctxt = ArithData.gen_trans_tac ctxt @{thm trans}

  (*Same three-pass normalization as CancelNumeralsCommon, except that
    norm_ss1 also includes intifys.*)
  val norm_ss1 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps add_0s @ mult_1s @ diff_simps @ zminus_simps @ @{thms zadd_ac} @ intifys)
  val norm_ss2 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps bin_simps @ int_mult_minus_simps @ intifys)
  val norm_ss3 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps int_minus_from_mult_simps @ @{thms zadd_ac} @ @{thms zmult_ac} @ tc_rules @ intifys)
  fun norm_tac ctxt =
    ALLGOALS (asm_simp_tac (put_simpset norm_ss1 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss2 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss3 ctxt))

  val numeral_simp_ss =
    simpset_of (put_simpset ZF_ss \<^context> addsimps add_0s @ bin_simps @ tc_rules @ intifys)
  fun numeral_simp_tac ctxt =
    ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
  val simplify_meta_eq  = ArithData.simplify_meta_eq (add_0s @ mult_1s)
  end;

structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);

(*Simproc for sums and differences, triggered on i $+ j and i $- j.*)
val combine_numerals =
  Simplifier.make_simproc \<^context> "int_combine_numerals"
    {lhss = [\<^term>\<open>i $+ j\<close>, \<^term>\<open>i $- j\<close>],
     proc = K CombineNumerals.proc};



(** Constant folding for integer multiplication **)

(*The trick is to regard products as sums, e.g. #3 $* x $* #4 as
  the "sum" of #3, x, #4; the literals are then multiplied*)


(*Parameters for CombineNumeralsFun on products: "summands" are factors,
  coefficients are combined with op *, and mk_coeff / dest_coeff accept only
  pure literals, so only the literal factors are folded together.*)
structure CombineNumeralsProdData =
struct
  type coeff = int
  val iszero = (fn x => x = 0)
  val add = op *
  val mk_sum = (fn _ : typ => mk_prod)
  val dest_sum = dest_prod
  fun mk_coeff(k,t) =
    if t = one then mk_numeral k
    else raise TERM("mk_coeff", [])
  fun dest_coeff t = (dest_numeral t, one)  (*We ONLY want pure numerals.*)
  val left_distrib = @{thm zmult_assoc} RS @{thm sym} RS @{thm trans}
  val prove_conv = prove_conv_nohyps "int_combine_numerals_prod"
  fun trans_tac ctxt = ArithData.gen_trans_tac ctxt @{thm trans}

  val norm_ss1 =
    simpset_of (put_simpset ZF_ss \<^context> addsimps mult_1s @ diff_simps @ zminus_simps)
  val norm_ss2 =
    simpset_of (put_simpset ZF_ss \<^context> addsimps [@{thm zmult_zminus_right} RS @{thm sym}] @
    bin_simps @ @{thms zmult_ac} @ tc_rules @ intifys)
  fun norm_tac ctxt =
    ALLGOALS (asm_simp_tac (put_simpset norm_ss1 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss2 ctxt))

  val numeral_simp_ss =
    simpset_of (put_simpset ZF_ss \<^context> addsimps bin_simps @ tc_rules @ intifys)
  fun numeral_simp_tac ctxt =
    ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
  val simplify_meta_eq  = ArithData.simplify_meta_eq (mult_1s);
end;


structure CombineNumeralsProd = CombineNumeralsFun(CombineNumeralsProdData);

(*Simproc for products, triggered on i $* j.*)
val combine_numerals_prod =
  Simplifier.make_simproc \<^context> "int_combine_numerals_prod"
   {lhss = [\<^term>\<open>i $* j\<close>], proc = K CombineNumeralsProd.proc};

end;
321
(*Register every integer numeral simproc in the theory's default simpset.*)
local
  val int_simprocs =
    Int_Numeral_Simprocs.cancel_numerals @
      [Int_Numeral_Simprocs.combine_numerals, Int_Numeral_Simprocs.combine_numerals_prod];
in
  val _ =
    Theory.setup
      (Simplifier.map_theory_simpset (fn ctxt => ctxt addsimprocs int_simprocs));
end;
328