1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7section "Solving Word Equalities"
8
9theory Word_EqI
10  imports
11    Word_Next
12    "HOL-Eisbach.Eisbach_Tools"
13begin
14
15text \<open>
  Some word equalities can be solved by considering the problem bitwise for all
  @{prop "n < LENGTH('a::len)"}. This is different from running @{text word_bitwise},
  which expands the words into explicit lists of bits.
19\<close>
20
(* A bitwise OR of two words is zero exactly when both operands are zero.
   Each direction is proved bitwise: word_eqI reduces the goal to test_bit at an
   index n, and word_eqD instantiates the equality hypothesis at that same n. *)
lemma word_or_zero:
  "(a || b = 0) = (a = 0 \<and> b = 0)"
  by (safe; rule word_eqI, drule_tac x=n in word_eqD, simp)
24
(* Bits at indices at or beyond the word size are never set. *)
lemma test_bit_over:
  "n \<ge> size (x::'a::len0 word) \<Longrightarrow> (x !! n) = False"
  by (simp add: test_bit_bl word_size)
28
(* Bit m of the complement of mask n is set exactly for the indices in the
   range n \<le> m < LENGTH('a), i.e. the "high" bits above the mask. *)
lemma neg_mask_test_bit:
  "(~~(mask n) :: 'a :: len word) !! m = (n \<le> m \<and> m < LENGTH('a))"
  by (metis not_le nth_mask test_bit_bin word_ops_nth_size word_size)
32
(* If 2 * 2^n < 2 * 2^m holds in word arithmetic and Suc n still fits in the word,
   then 2^n < 2^m. Used in the inductive step of word_power_increasing.
   NOTE(review): power_Suc is listed twice in the smt fact list; this is harmless,
   but it was left untouched to avoid perturbing the smt proof. *)
lemma word_2p_mult_inc:
  assumes x: "2 * 2 ^ n < (2::'a::len word) * 2 ^ m"
  assumes suc_n: "Suc n < LENGTH('a::len)"
  shows "2^n < (2::'a::len word)^m"
  by (smt suc_n le_less_trans lessI nat_less_le nat_mult_less_cancel_disj p2_gt_0
          power_Suc power_Suc unat_power_lower word_less_nat_alt x)
39
(* Reflection of strict monotonicity of powers of two: if 2^x < 2^y as words and both
   exponents are below the word length, then x < y.
   Proof by induction on x, generalising y: the base case splits on y; the step
   splits on y and uses word_2p_mult_inc, then folds 2 * 2^x back into 2^(Suc x)
   and finishes with p2_eq_0 and simp. *)
lemma word_power_increasing:
  assumes x: "2 ^ x < (2 ^ y::'a::len word)" "x < LENGTH('a::len)" "y < LENGTH('a::len)"
  shows "x < y" using x
  apply (induct x arbitrary: y)
   apply (case_tac y; simp)
  apply (case_tac y; clarsimp simp: word_2p_mult_inc)
  apply (subst (asm) power_Suc [symmetric])
  apply (subst (asm) p2_eq_0)
  apply simp
  done
50
(* For n < LENGTH('a): all bits of p at indices n' with n \<le> n' < LENGTH('a) are
   clear exactly when p < 2 ^ n. *)
lemma upper_bits_unset_is_l2p:
  "n < LENGTH('a) \<Longrightarrow>
   (\<forall>n' \<ge> n. n' < LENGTH('a) \<longrightarrow> \<not> p !! n') = (p < 2 ^ n)" for p :: "'a :: len word"
  (* The degenerate case LENGTH('a) = 1 is discharged separately via word_eq_iff. *)
  apply (cases "Suc 0 < LENGTH('a)")
   prefer 2
   apply (subgoal_tac "LENGTH('a) = 1", auto simp: word_eq_iff)[1]
  (* Left-to-right: rewrite p < 2 ^ n with mask_eq_iff_w2p (presumably into a
     mask equation on p) and prove that equation bitwise, splitting on n' < n. *)
  apply (rule iffI)
   apply (subst mask_eq_iff_w2p [symmetric])
    apply (clarsimp simp: word_size)
   apply (rule word_eqI, rename_tac n')
   apply (case_tac "n' < n"; simp add: word_size)
  (* Right-to-left: a set bit at n' \<ge> n would contradict p < 2 ^ n, via bang_is_le
     and word_power_increasing. *)
  by (meson bang_is_le le_less_trans not_le word_power_increasing)
63
(* Unconditional form of upper_bits_unset_is_l2p: p < 2 ^ n holds exactly when n is
   below the word length and every bit of p from index n upwards is clear.
   This form is a member of word_eqI_simps. *)
lemma less_2p_is_upper_bits_unset:
  "p < 2 ^ n \<longleftrightarrow> n < LENGTH('a) \<and> (\<forall>n' \<ge> n. n' < LENGTH('a) \<longrightarrow> \<not> p !! n')" for p :: "'a :: len word"
  by (meson le_less_trans le_mask_iff_lt_2n upper_bits_unset_is_l2p word_zero_le)
67
(* For words, x < y implies x \<le> y - 1 (y cannot be 0 here, so no wraparound). *)
lemma word_le_minus_one_leq:
  "x < y \<Longrightarrow> x \<le> y - 1" for x :: "'a :: len word"
  by (simp add: plus_one_helper)
71
(* Simp rule: when 2 ^ n does not overflow (n < LENGTH('a)), the bound x \<le> 2 ^ n - 1
   is rewritten to the strict form x < 2 ^ n. *)
lemma word_less_sub_le[simp]:
  fixes x :: "'a :: len word"
  assumes nv: "n < LENGTH('a)"
  shows "(x \<le> 2 ^ n - 1) = (x < 2 ^ n)"
  using le_less_trans word_le_minus_one_leq nv power_2_ge_iff by blast
77
(* If x < y and both are n-aligned, then x is not the last n-aligned block of the
   address space: x + 2 ^ n does not wrap around to 0. *)
lemma not_greatest_aligned:
  "\<lbrakk> x < y; is_aligned x n; is_aligned y n \<rbrakk> \<Longrightarrow> x + 2 ^ n \<noteq> 0"
  by (metis NOT_mask add_diff_cancel_right' diff_0 is_aligned_neg_mask_eq not_le word_and_le1)
81
(* Clearing the low n bits (x && ~~(mask n)) is monotone with respect to \<le>.
   Proof by contradiction, assuming y && ~~(mask n) < x && ~~(mask n), with a case
   split on whether n is below the word length. *)
lemma neg_mask_mono_le:
  "x \<le> y \<Longrightarrow> x && ~~(mask n) \<le> y && ~~(mask n)" for x :: "'a :: len word"
proof (rule ccontr, simp add: linorder_not_le, cases "n < LENGTH('a)")
  case False
  (* n \<ge> LENGTH('a): 2 ^ n overflows (power_overflow), so unfolding mask_def makes
     the strict inequality impossible. *)
  then show "y && ~~(mask n) < x && ~~(mask n) \<Longrightarrow> False"
    by (simp add: mask_def linorder_not_less power_overflow)
next
  case True
  assume a: "x \<le> y" and b: "y && ~~(mask n) < x && ~~(mask n)"
  have word_bits: "n < LENGTH('a)" by fact
  (* Split y into its high (aligned) part and its low n bits. *)
  have "y \<le> (y && ~~(mask n)) + (y && mask n)"
    by (simp add: word_plus_and_or_coroll2 add.commute)
  (* The low part is below 2 ^ n, and adding 2 ^ n to the aligned high part of y does
     not overflow because a strictly larger aligned value exists (by b). *)
  also have "\<dots> \<le> (y && ~~(mask n)) + 2 ^ n"
    apply (rule word_plus_mono_right)
     apply (rule order_less_imp_le, rule and_mask_less_size)
     apply (simp add: word_size word_bits)
    apply (rule is_aligned_no_overflow'', simp add: is_aligned_neg_mask word_bits)
    apply (rule not_greatest_aligned, rule b; simp add: is_aligned_neg_mask)
    done
  (* Two distinct n-aligned values differ by at least 2 ^ n. *)
  also have "\<dots> \<le> x && ~~(mask n)"
    using b
    apply (subst add.commute)
    apply (rule le_plus)
     apply (rule aligned_at_least_t2n_diff; simp add: is_aligned_neg_mask)
    apply (rule ccontr, simp add: linorder_not_le)
    apply (drule aligned_small_is_0[rotated]; simp add: is_aligned_neg_mask)
    done
  also have "\<dots> \<le> x" by (rule word_and_le2)
  also have "x \<le> y" by fact
  finally
  show "False" using b by simp
qed
114
(* w has all of the high bits of ~~(mask n) set (masking w with it is the identity on
   the mask) exactly when ~~(mask n) \<le> w. *)
lemma and_neg_mask_eq_iff_not_mask_le:
  "w && ~~(mask n) = ~~(mask n) \<longleftrightarrow> ~~(mask n) \<le> w"
  by (metis eq_iff neg_mask_mono_le word_and_le1 word_and_le2 word_bw_same(1))
118
(* w \<le> mask n exactly when no bit of w in the index range [n, size w) is set.
   Member of word_eqI_simps. *)
lemma le_mask_high_bits:
  "w \<le> mask n \<longleftrightarrow> (\<forall>i \<in> {n ..< size w}. \<not> w !! i)"
  by (auto simp: word_size and_mask_eq_iff_le_mask[symmetric] word_eq_iff)
122
(* ~~(mask n) \<le> w exactly when every bit of w in the index range [n, size w) is set.
   Member of word_eqI_simps. *)
lemma neg_mask_le_high_bits:
  "~~(mask n) \<le> w \<longleftrightarrow> (\<forall>i \<in> {n ..< size w}. w !! i)"
  by (auto simp: word_size and_neg_mask_eq_iff_not_mask_le[symmetric] word_eq_iff neg_mask_test_bit)
126
(* A redundant length guard next to a test_bit can be dropped: a set bit is always at
   an index below LENGTH('a). Used by the final simp step of the word_eqI method. *)
lemma test_bit_conj_lt:
  "(x !! m \<and> m < LENGTH('a)) = x !! m" for x :: "'a :: len word"
  using test_bit_bin by blast
130
(* Bit n of the bitwise complement ~~x is set exactly when bit n of x is clear and n is
   a valid index (n < LENGTH('a)). Member of word_eqI_simps. *)
lemma neg_test_bit:
  "(~~ x) !! n = (\<not> x !! n \<and> n < LENGTH('a))" for x :: "'a::len word"
  by (cases "n < LENGTH('a)") (auto simp add: test_bit_over word_ops_nth_size word_size)
134
(* Named collection of rewrite rules that the word_eqI and word_eqI_solve methods pass
   to clarsimp/fastforce to push test_bit through word operations. Further rules can
   be added with the [word_eqI_simps] attribute. *)
named_theorems word_eqI_simps
136
(* Default members of word_eqI_simps: bit-level characterisations of the word
   operations, masks, casts, shifts and powers of two, together with the bounds and
   alignment lemmas proved above, phrased in terms of test_bit. *)
lemmas [word_eqI_simps] =
  word_ops_nth_size
  word_size
  word_or_zero
  neg_mask_test_bit
  nth_ucast
  is_aligned_nth
  nth_w2p nth_shiftl
  nth_shiftr
  less_2p_is_upper_bits_unset
  le_mask_high_bits
  neg_mask_le_high_bits
  bang_eq
  neg_test_bit
  is_up
  is_down
153
(* word_eqI with [rule_format] applied, so that any object-level quantifiers or
   implications become meta-level ones and "rule word_eqI_rule" leaves a bitwise
   subgoal directly. NOTE(review): assumes word_eqI is stated with object-level
   connectives; confirm against its definition. *)
lemmas word_eqI_rule = word_eqI[rule_format]
155
(* Destruction rule: a set bit x !! n additionally yields the index bound
   n < LENGTH('a). The word_eqI method uses it to attach size constraints to new
   indices. *)
lemma test_bit_lenD:
  "x !! n \<Longrightarrow> n < LENGTH('a) \<and> x !! n" for x :: "'a :: len word"
  by (fastforce dest: test_bit_size simp: word_size)
159
(* Reduce a word equality goal to a bitwise statement and simplify it as far as
   possible. Only the initial "rule word_eqI_rule" can fail; every later step is
   optional (wrapped in "?"), so the method may leave unsolved subgoals.
   The simp, simp_del, split, split_del, cong and flip arguments are forwarded to
   every clarsimp/simp call below. *)
method word_eqI uses simp simp_del split split_del cong flip =
  ((* reduce conclusion to test_bit: *)
   rule word_eqI_rule,
   (* make sure we're in clarsimp normal form: *)
   (clarsimp simp: simp simp del: simp_del simp flip: flip split: split split del: split_del cong: cong)?,
   (* turn x < 2^n assumptions into mask equations: *)
   ((drule less_mask_eq)+)?,
   (* expand and distribute test_bit everywhere: *)
   (clarsimp simp: word_eqI_simps simp simp del: simp_del simp flip: flip
             split: split split del: split_del cong: cong)?,
   (* add any additional word size constraints to new indices: *)
   ((drule test_bit_lenD)+)?,
   (* try to make progress (can't use +, would loop): *)
   (clarsimp simp: word_eqI_simps simp simp del: simp_del simp flip: flip
             split: split split del: split_del cong: cong)?,
   (* occasionally helps: a final plain simp that also drops redundant length guards
      via test_bit_conj_lt *)
   (simp add: simp test_bit_conj_lt del: simp_del flip: flip split: split split del: split_del cong: cong)?)
177
(* Terminal variant of word_eqI: run word_eqI with the given simp/split/cong/flip
   arguments, then try fastforce on every remaining subgoal. fastforce is given
   test_bit_size as a destruction rule and the word_eqI_simps rules. Because of
   "solves", the whole method fails unless the goal is closed completely.
   The flip facts are passed to word_eqI through its own "flip:" parameter. Writing
   "cong: cong simp flip: flip" here would instead parse the bare "simp" as
   further facts of the "cong:" section, i.e. pass the user's simp rules as
   congruence rules. *)
method word_eqI_solve uses simp simp_del split split_del cong flip =
  solves \<open>word_eqI simp: simp simp_del: simp_del split: split split_del: split_del
                   cong: cong flip: flip;
          (fastforce dest: test_bit_size simp: word_eqI_simps simp flip: flip
                     simp: simp simp del: simp_del split: split split del: split_del cong: cong)?\<close>
183
184end
185