(*  Title:      ZF/int_arith.ML
    Author:     Larry Paulson

Simprocs for linear arithmetic.

This file defines three families of simprocs over ZF integers and installs
them in the theory simpset at the end:
  - cancel_numerals: cancel common terms on both sides of =, $< and zle;
  - combine_numerals: collect like terms in a sum built with $+ and $-;
  - combine_numerals_prod: multiply out the numeric literals in a product.
*)

signature INT_NUMERAL_SIMPROCS =
sig
  val cancel_numerals: simproc list
  val combine_numerals: simproc
  val combine_numerals_prod: simproc
end

structure Int_Numeral_Simprocs: INT_NUMERAL_SIMPROCS =
struct

(* abstract syntax operations *)

(*Map the binary digit 0 or 1 to the ZF term 0 or succ(0);
  any other integer raises TERM.*)
fun mk_bit 0 = \<^term>\<open>0\<close>
  | mk_bit 1 = \<^term>\<open>succ(0)\<close>
  | mk_bit _ = raise TERM ("mk_bit", []);

(*Inverse of mk_bit: the ZF term 0 or succ(0) back to the digit 0 or 1;
  any other term raises TERM carrying that term.*)
fun dest_bit \<^term>\<open>0\<close> = 0
  | dest_bit \<^term>\<open>succ(0)\<close> = 1
  | dest_bit t = raise TERM ("dest_bit", [t]);

(*Encode an ML integer as a binary-numeral term built from Pls, Min and Bit.
  The digit list comes from Numeral_Syntax.make_binary; its head becomes the
  digit of the outermost Bit, and the rest of the list is encoded recursively
  as that Bit's first argument.*)
fun mk_bin i =
  let
    fun term_of [] = \<^term>\<open>Pls\<close>
      | term_of [~1] = \<^term>\<open>Min\<close>
      | term_of (b :: bs) = \<^term>\<open>Bit\<close> $ term_of bs $ mk_bit b;
  in term_of (Numeral_Syntax.make_binary i) end;

(*Inverse of mk_bin.  Any term that is not built from Pls, Min, Bit and the
  digits 0 / succ(0) raises TERM carrying the whole input term.*)
fun dest_bin tm =
  let
    fun bin_of \<^term>\<open>Pls\<close> = []
      | bin_of \<^term>\<open>Min\<close> = [~1]
      | bin_of (\<^term>\<open>Bit\<close> $ bs $ b) = dest_bit b :: bin_of bs
      | bin_of _ = raise TERM ("dest_bin", [tm]);
  in Numeral_Syntax.dest_binary (bin_of tm) end;


(*Utilities*)

(*The ZF integer literal for an ML integer: integ_of applied to its binary form.*)
fun mk_numeral i = \<^const>\<open>integ_of\<close> $ mk_bin i;

(*Inverse of mk_numeral; raises TERM on anything that is not integ_of of a
  well-formed binary numeral.*)
fun dest_numeral (Const(\<^const_name>\<open>integ_of\<close>, _) $ w) = dest_bin w
  | dest_numeral t = raise TERM ("dest_numeral", [t]);

(*Return the value of the first numeral in the list, together with the other
  terms in their original order.  past accumulates, reversed, the terms
  already skipped.  Raises TERM if the list contains no numeral.*)
fun find_first_numeral past (t::terms) =
        ((dest_numeral t, rev past @ terms)
         handle TERM _ => find_first_numeral (t::past) terms)
  | find_first_numeral past [] = raise TERM("find_first_numeral", []);

val zero = mk_numeral 0;
val mk_plus = FOLogic.mk_binop \<^const_name>\<open>zadd\<close>;

(*Right-nested sum of the terms.  The empty list gives #0 and a singleton
  [t] gives t $+ #0; a list of two or more terms ends in a plain binary sum,
  so it has no trailing zero.*)
fun mk_sum []        = zero
  | mk_sum [t,u]     = mk_plus (t, u)
  | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

(*Like mk_sum, but the result is never a bare term.  The empty list gives #0,
  [t] gives t $+ #0 and [t,u] gives t $+ (u $+ #0).  Lists of three or more
  terms are built exactly as mk_sum would build them, with no trailing zero.*)
fun long_mk_sum []        = zero
  | long_mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

(*Decompose additions AND subtractions as a sum.  pos records whether the
  current subterm occurs positively; the subtrahend of a zdiff flips it, and
  a negatively occurring summand is returned wrapped in zminus.  The summands
  are prepended, in left-to-right order, onto ts.*)
fun dest_summing (pos, Const (\<^const_name>\<open>zadd\<close>, _) $ t $ u, ts) =
        dest_summing (pos, t, dest_summing (pos, u, ts))
  | dest_summing (pos, Const (\<^const_name>\<open>zdiff\<close>, _) $ t $ u, ts) =
        dest_summing (pos, t, dest_summing (not pos, u, ts))
  | dest_summing (pos, t, ts) =
        if pos then t::ts else \<^const>\<open>zminus\<close> $ t :: ts;

(*The list of signed summands of t.*)
fun dest_sum t = dest_summing (true, t, []);

val one = mk_numeral 1;
val mk_times = FOLogic.mk_binop \<^const_name>\<open>zmult\<close>;

(*Right-nested product of the terms.  The empty list gives #1 and a singleton
  is returned unchanged.  In a longer list, factors equal to #1 are dropped
  except when one of them is the last remaining factor.*)
fun mk_prod [] = one
  | mk_prod [t] = t
  | mk_prod (t :: ts) = if t = one then mk_prod ts
                        else mk_times (t, mk_prod ts);

val dest_times = FOLogic.dest_bin \<^const_name>\<open>zmult\<close> \<^typ>\<open>i\<close>;

(*Flatten nested zmult applications into the list of factors, from left to
  right.  A term that is not a product yields the singleton list [t].*)
fun dest_prod t =
      let val (t,u) = dest_times t
      in  dest_prod t @ dest_prod u  end
      handle TERM _ => [t];

(*DON'T do the obvious simplifications; that would create special cases*)
fun mk_coeff (k, t) = mk_times (mk_numeral k, t);

(*Express t as a product of (possibly) a numeral with other sorted terms.
  Each outer zminus negates sign.  The factors are sorted by Term_Ord.term_ord,
  and the first numeral among them, if any, becomes the coefficient.  The
  result is sign times that coefficient (default 1), paired with mk_prod of
  the remaining factors.*)
fun dest_coeff sign (Const (\<^const_name>\<open>zminus\<close>, _) $ t) = dest_coeff (~sign) t
  | dest_coeff sign t =
    let val ts = sort Term_Ord.term_ord (dest_prod t)
        val (n, ts') = find_first_numeral [] ts
                          handle TERM _ => (1, ts)
    in (sign*n, mk_prod ts') end;

(*Find first coefficient-term THAT MATCHES u: the first term whose
  non-numeral part, as computed by dest_coeff 1, is aconv to u.  Returns that
  term's coefficient together with the remaining terms in their original
  order.  A term on which dest_coeff raises TERM is skipped.  If no term
  matches, raises TERM "find_first_coeff".
  The TERM handler deliberately covers ONLY the dest_coeff call.  If it also
  covered the recursive call, every failed search of the tail would be caught
  and the same search rerun, so an unsuccessful search would cost time
  exponential in the length of the list.*)
fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
  | find_first_coeff past u (t::terms) =
        let
          val matched =
            (case (SOME (dest_coeff 1 t) handle TERM _ => NONE) of
              SOME (n, u') => if u aconv u' then SOME n else NONE
            | NONE => NONE)
        in
          (case matched of
            SOME n => (n, rev past @ terms)
          | NONE => find_first_coeff (t::past) u terms)
        end;


(*Simplify #0 $+ n and n $+ #0 to n, modulo intify*)
val add_0s = [@{thm zadd_0_intify}, @{thm zadd_0_right_intify}];

(*Simplify #1 $* n and n $* #1 to n, modulo intify, plus the zmult_minus1
  rules for the coefficient #-1*)
val mult_1s = [@{thm zmult_1_intify}, @{thm zmult_1_right_intify},
               @{thm zmult_minus1}, @{thm zmult_minus1_right}];

(*Typing rules, used to discharge membership side conditions in int*)
val tc_rules = [@{thm integ_of_type}, @{thm intify_in_int},
                @{thm int_of_type}, @{thm zadd_type}, @{thm zdiff_type}, @{thm zmult_type}] @
               @{thms bin.intros};

(*Rules that eliminate intify around integer operations and relations*)
val intifys = [@{thm intify_ident}, @{thm zadd_intify1}, @{thm zadd_intify2},
               @{thm zdiff_intify1}, @{thm zdiff_intify2}, @{thm zmult_intify1}, @{thm zmult_intify2},
               @{thm zless_intify1}, @{thm zless_intify2}, @{thm zle_intify1}, @{thm zle_intify2}];

(*To perform binary arithmetic*)
val bin_simps = [@{thm add_integ_of_left}] @ @{thms bin_arith_simps} @ @{thms bin_rel_simps};

(*To evaluate binary negations of coefficients*)
val zminus_simps = @{thms NCons_simps} @
                   [@{thm integ_of_minus} RS @{thm sym},
                    @{thm bin_minus_1}, @{thm bin_minus_0}, @{thm bin_minus_Pls}, @{thm bin_minus_Min},
                    @{thm bin_pred_1}, @{thm bin_pred_0}, @{thm bin_pred_Pls}, @{thm bin_pred_Min}];

(*To let us treat subtraction as addition*)
val diff_simps = [@{thm zdiff_def}, @{thm zminus_zadd_distrib}, @{thm zminus_zminus}];

(*push the unary minus down*)
val int_minus_mult_eq_1_to_2 = @{lemma "$- w $* z = w $* $- z" by simp};

(*to extract again any uncancelled minuses*)
val int_minus_from_mult_simps =
    [@{thm zminus_zminus}, @{thm zmult_zminus}, @{thm zmult_zminus_right}];

(*combine unary minus with numeric literals, however nested within a product*)
val int_mult_minus_simps =
    [@{thm zmult_assoc}, @{thm zmult_zminus} RS @{thm sym}, int_minus_mult_eq_1_to_2];

(*Parameters shared by the three CancelNumeralsFun instances below, for
  equality, zless and zle.  norm_tac runs three simp passes over all goals
  in sequence, with norm_ss1, norm_ss2 and norm_ss3 respectively.*)
structure CancelNumeralsCommon =
  struct
  val mk_sum = (fn _ : typ => mk_sum)
  val dest_sum = dest_sum
  val mk_coeff = mk_coeff
  val dest_coeff = dest_coeff 1
  val find_first_coeff = find_first_coeff []
  fun trans_tac ctxt = ArithData.gen_trans_tac ctxt @{thm iff_trans}

  val norm_ss1 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps add_0s @ mult_1s @ diff_simps @ zminus_simps @ @{thms zadd_ac})
  val norm_ss2 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps bin_simps @ int_mult_minus_simps @ intifys)
  val norm_ss3 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps int_minus_from_mult_simps @ @{thms zadd_ac} @ @{thms zmult_ac} @ tc_rules @ intifys)
  fun norm_tac ctxt =
    ALLGOALS (asm_simp_tac (put_simpset norm_ss1 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss2 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss3 ctxt))

  val numeral_simp_ss =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps add_0s @ bin_simps @ tc_rules @ intifys)
  fun numeral_simp_tac ctxt =
    ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
    THEN ALLGOALS (asm_simp_tac ctxt)
  val simplify_meta_eq = ArithData.simplify_meta_eq (add_0s @ mult_1s)
  end;


(*Cancellation for equations*)
structure EqCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "inteq_cancel_numerals"
  val mk_bal   = FOLogic.mk_eq
  val dest_bal = FOLogic.dest_eq
  val bal_add1 = @{thm eq_add_iff1} RS @{thm iff_trans}
  val bal_add2 = @{thm eq_add_iff2} RS @{thm iff_trans}
);

(*Cancellation for zless*)
structure LessCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "intless_cancel_numerals"
  val mk_bal   = FOLogic.mk_binrel \<^const_name>\<open>zless\<close>
  val dest_bal = FOLogic.dest_bin \<^const_name>\<open>zless\<close> \<^typ>\<open>i\<close>
  val bal_add1 = @{thm less_add_iff1} RS @{thm iff_trans}
  val bal_add2 = @{thm less_add_iff2} RS @{thm iff_trans}
);

(*Cancellation for zle*)
structure LeCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "intle_cancel_numerals"
  val mk_bal   = FOLogic.mk_binrel \<^const_name>\<open>zle\<close>
  val dest_bal = FOLogic.dest_bin \<^const_name>\<open>zle\<close> \<^typ>\<open>i\<close>
  val bal_add1 = @{thm le_add_iff1} RS @{thm iff_trans}
  val bal_add2 = @{thm le_add_iff2} RS @{thm iff_trans}
);

(*The cancellation simprocs, triggered when either side of the relation is
  a sum, difference or product*)
val cancel_numerals =
 [Simplifier.make_simproc \<^context> "inteq_cancel_numerals"
   {lhss =
     [\<^term>\<open>l $+ m = n\<close>, \<^term>\<open>l = m $+ n\<close>,
      \<^term>\<open>l $- m = n\<close>, \<^term>\<open>l = m $- n\<close>,
      \<^term>\<open>l $* m = n\<close>, \<^term>\<open>l = m $* n\<close>],
    proc = K EqCancelNumerals.proc},
  Simplifier.make_simproc \<^context> "intless_cancel_numerals"
   {lhss =
     [\<^term>\<open>l $+ m $< n\<close>, \<^term>\<open>l $< m $+ n\<close>,
      \<^term>\<open>l $- m $< n\<close>, \<^term>\<open>l $< m $- n\<close>,
      \<^term>\<open>l $* m $< n\<close>, \<^term>\<open>l $< m $* n\<close>],
    proc = K LessCancelNumerals.proc},
  Simplifier.make_simproc \<^context> "intle_cancel_numerals"
   {lhss =
     [\<^term>\<open>l $+ m $\<le> n\<close>, \<^term>\<open>l $\<le> m $+ n\<close>,
      \<^term>\<open>l $- m $\<le> n\<close>, \<^term>\<open>l $\<le> m $- n\<close>,
      \<^term>\<open>l $* m $\<le> n\<close>, \<^term>\<open>l $\<le> m $* n\<close>],
    proc = K LeCancelNumerals.proc}];


(*version without the hyps argument: ArithData.prove_conv with [] supplied
  as its final argument*)
fun prove_conv_nohyps name tacs sg = ArithData.prove_conv name tacs sg [];

(*Parameters for collecting like terms in a sum: coefficients are added.
  norm_tac runs three simp passes over all goals in sequence, with norm_ss1,
  norm_ss2 and norm_ss3 respectively.*)
structure CombineNumeralsData =
  struct
  type coeff = int
  val iszero = (fn x => x = 0)
  val add = op +
  val mk_sum = (fn _ : typ => long_mk_sum) (*to work for #2*x $+ #3*x *)
  val dest_sum = dest_sum
  val mk_coeff = mk_coeff
  val dest_coeff = dest_coeff 1
  val left_distrib = @{thm left_zadd_zmult_distrib} RS @{thm trans}
  val prove_conv = prove_conv_nohyps "int_combine_numerals"
  fun trans_tac ctxt = ArithData.gen_trans_tac ctxt @{thm trans}

  val norm_ss1 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps add_0s @ mult_1s @ diff_simps @ zminus_simps @ @{thms zadd_ac} @ intifys)
  val norm_ss2 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps bin_simps @ int_mult_minus_simps @ intifys)
  val norm_ss3 =
    simpset_of (put_simpset ZF_ss \<^context>
      addsimps int_minus_from_mult_simps @ @{thms zadd_ac} @ @{thms zmult_ac} @ tc_rules @ intifys)
  fun norm_tac ctxt =
    ALLGOALS (asm_simp_tac (put_simpset norm_ss1 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss2 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss3 ctxt))

  val numeral_simp_ss =
    simpset_of (put_simpset ZF_ss \<^context> addsimps add_0s @ bin_simps @ tc_rules @ intifys)
  fun numeral_simp_tac ctxt =
    ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
  val simplify_meta_eq = ArithData.simplify_meta_eq (add_0s @ mult_1s)
  end;

structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);

val combine_numerals =
  Simplifier.make_simproc \<^context> "int_combine_numerals"
    {lhss = [\<^term>\<open>i $+ j\<close>, \<^term>\<open>i $- j\<close>],
     proc = K CombineNumerals.proc};



(** Constant folding for integer multiplication **)

(*The trick is to regard products as sums, e.g. #3 $* x $* #4 as
  the "sum" of #3, x, #4; the literals are then multiplied*)


(*Parameters for folding the literals of a product: the "sum" is a product,
  the "addition" of coefficients is multiplication, and only pure numerals
  count as coefficient terms.*)
structure CombineNumeralsProdData =
struct
  type coeff = int
  val iszero = (fn x => x = 0)
  val add = op *
  val mk_sum = (fn _ : typ => mk_prod)
  val dest_sum = dest_prod
  (*Only the unit "term" #1 can carry a coefficient; anything else raises TERM*)
  fun mk_coeff(k,t) =
    if t = one then mk_numeral k
    else raise TERM("mk_coeff", [])
  fun dest_coeff t = (dest_numeral t, one)  (*We ONLY want pure numerals.*)
  val left_distrib = @{thm zmult_assoc} RS @{thm sym} RS @{thm trans}
  val prove_conv = prove_conv_nohyps "int_combine_numerals_prod"
  fun trans_tac ctxt = ArithData.gen_trans_tac ctxt @{thm trans}

  val norm_ss1 =
    simpset_of (put_simpset ZF_ss \<^context> addsimps mult_1s @ diff_simps @ zminus_simps)
  val norm_ss2 =
    simpset_of (put_simpset ZF_ss \<^context> addsimps [@{thm zmult_zminus_right} RS @{thm sym}] @
      bin_simps @ @{thms zmult_ac} @ tc_rules @ intifys)
  fun norm_tac ctxt =
    ALLGOALS (asm_simp_tac (put_simpset norm_ss1 ctxt))
    THEN ALLGOALS (asm_simp_tac (put_simpset norm_ss2 ctxt))

  val numeral_simp_ss =
    simpset_of (put_simpset ZF_ss \<^context> addsimps bin_simps @ tc_rules @ intifys)
  fun numeral_simp_tac ctxt =
    ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
  val simplify_meta_eq = ArithData.simplify_meta_eq (mult_1s);
end;


structure CombineNumeralsProd = CombineNumeralsFun(CombineNumeralsProdData);

val combine_numerals_prod =
  Simplifier.make_simproc \<^context> "int_combine_numerals_prod"
    {lhss = [\<^term>\<open>i $* j\<close>], proc = K CombineNumeralsProd.proc};

end;

(*Install all of the above simprocs in the theory's simpset*)
val _ =
  Theory.setup (Simplifier.map_theory_simpset (fn ctxt =>
    ctxt addsimprocs
      (Int_Numeral_Simprocs.cancel_numerals @
       [Int_Numeral_Simprocs.combine_numerals,
        Int_Numeral_Simprocs.combine_numerals_prod])));