1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory NonDetMonadLemmas
8imports NonDetMonad
9begin
10
11section "General Lemmas Regarding the Nondeterministic State Monad"
12
13subsection "Congruence Rules for the Function Package"
14
(* Congruence rule for bind: the continuations g and g' only have to agree on
   the result/state pairs that f' can actually produce. *)
lemma bind_cong[fundef_cong]:
  "\<lbrakk> f = f'; \<And>v s s'. (v, s') \<in> fst (f' s) \<Longrightarrow> g v s' = g' v s' \<rbrakk> \<Longrightarrow> f >>= g = f' >>= g'"
  by (rule ext) (auto simp: bind_def Let_def split_def intro: rev_image_eqI)
20
(* Applied form of bind_cong: (f >>= g) s and (f' >>= g') s' coincide when
   f s = f' s' and the continuations agree on every result of f' s'. *)
lemma bind_apply_cong [fundef_cong]:
  "\<lbrakk> f s = f' s'; \<And>rv st. (rv, st) \<in> fst (f' s') \<Longrightarrow> g rv st = g' rv st \<rbrakk>
       \<Longrightarrow> (f >>= g) s = (f' >>= g') s'"
  by (simp add: bind_def)
     (auto simp: split_def intro: SUP_cong [OF refl] rev_image_eqI)
27
(* Congruence rule for bindE: the continuations N and N' only have to agree on
   the values v that M' returns in Inr. *)
lemma bindE_cong[fundef_cong]:
  "\<lbrakk> M = M' ; \<And>v s s'. (Inr v, s') \<in> fst (M' s) \<Longrightarrow> N v s' = N' v s' \<rbrakk> \<Longrightarrow> bindE M N = bindE M' N'"
  apply (simp add: bindE_def)
  (* reduce to the plain bind congruence; its first premise becomes trivial *)
  apply (rule bind_cong)
   apply (rule refl)
  (* the continuation is wrapped in lift: case-split the sum-typed result v *)
  apply (unfold lift_def)
  apply (case_tac v, simp_all)
  done
36
(* Applied form of bindE_cong: given f s = f' s', the continuations g and g'
   only have to agree on the values returned in Inr by f' s'. *)
lemma bindE_apply_cong[fundef_cong]:
  "\<lbrakk> f s = f' s'; \<And>rv st. (Inr rv, st) \<in> fst (f' s') \<Longrightarrow> g rv st = g' rv st \<rbrakk>
  \<Longrightarrow> (f >>=E g) s = (f' >>=E g') s'"
  apply (simp add: bindE_def)
  (* reduce to bind_apply_cong; its first premise is this lemma's first premise *)
  apply (rule bind_apply_cong)
   apply assumption
  (* the continuation is wrapped in lift: case-split the sum-typed result rv *)
  apply (case_tac rv, simp_all add: lift_def)
  done
45
(* Congruence for K_bind applied to a state: the discarded arguments arg and
   arg' may differ freely; only f st = f' st' is required. *)
lemma K_bind_apply_cong[fundef_cong]:
  "\<lbrakk> f st = f' st' \<rbrakk> \<Longrightarrow> K_bind f arg st = K_bind f' arg' st'"
  apply simp
  done
49
(* Congruence rule for `when` applied to a state, for the function package:
   the guarded computations m and m' only have to agree when the condition C'
   holds. (whenE_apply_cong is the corresponding rule for whenE.) *)
lemma when_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> when C m s = when C' m' s'"
  by (simp add: when_def)
53
(* Congruence rule for `unless` applied to a state, for the function package:
   the guarded computations m and m' only have to agree when the condition C'
   is false. unless is unfolded to when, so both definitions are needed.
   (unlessE_apply_cong is the corresponding rule for unlessE.) *)
lemma unless_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; \<not> C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> unless C m s = unless C' m' s'"
  by (simp add: unless_def when_def)
57
(* Congruence rule for whenE applied to a state: m and m' only have to agree
   when the condition C' holds. *)
lemma whenE_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> whenE C m s = whenE C' m' s'"
  unfolding whenE_def by simp
61
(* Congruence rule for unlessE applied to a state: m and m' only have to agree
   when the condition C' is false. *)
lemma unlessE_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; \<not> C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> unlessE C m s = unlessE C' m' s'"
  unfolding unlessE_def by simp
65
66subsection "Simplifying Monads"
67
(* Flatten a bind whose first computation ends in `return (g y)`: the pure
   function g is applied directly to the result of f before continuing with h. *)
lemma nested_bind [simp]:
  "do x <- do y <- f; return (g y) od; h x od =
   do y <- f; h (g y) od"
  apply (clarsimp simp add: bind_def)
  (* compare the two monads pointwise on the initial state *)
  apply (rule ext)
  apply (clarsimp simp add: Let_def split_def return_def)
  done
75
(* fail is a left zero for bind ... *)
lemma fail_bind [simp]:
  "fail >>= f = fail"
  apply (simp add: bind_def fail_def)
  done

(* ... and for the exception-monad bind bindE. *)
lemma fail_bindE [simp]:
  "fail >>=E f = fail"
  apply (simp add: bindE_def bind_def fail_def)
  done
83
(* Binding after an assertion with a literal condition: a false assertion makes
   the whole computation fail, a true one continues with (). *)
lemma assert_False [simp]:
  "assert False >>= f = fail"
  apply (simp add: assert_def)
  done

lemma assert_True [simp]:
  "assert True >>= f = f ()"
  apply (simp add: assert_def)
  done

(* The same two facts for assertE and the exception-monad bind. *)
lemma assertE_False [simp]:
  "assertE False >>=E f = fail"
  apply (simp add: assertE_def)
  done

lemma assertE_True [simp]:
  "assertE True >>=E f = f ()"
  apply (simp add: assertE_def)
  done
99
(* Binding after `when` with a literal condition: False skips straight to the
   continuation with (), True binds the guarded computation g. *)
lemma when_False_bind [simp]:
  "when False g >>= f = f ()"
  apply (rule ext)
  apply (simp add: when_def bind_def return_def)
  done

lemma when_True_bind [simp]:
  "when True g >>= f = g >>= f"
  apply (simp add: when_def bind_def return_def)
  done
107
(* Exception-monad counterparts of when_False_bind / when_True_bind. *)
lemma whenE_False_bind [simp]:
  "whenE False g >>=E f = f ()"
  apply (simp add: whenE_def bindE_def returnOk_def lift_def)
  done

lemma whenE_True_bind [simp]:
  "whenE True g >>=E f = g >>=E f"
  apply (simp add: whenE_def bindE_def returnOk_def lift_def)
  done
115
(* when/unless (and their exception-monad variants) with a literal condition
   either run the guarded computation or do nothing. The when_* simp rules
   must come first: the unless_* proofs reduce to them. *)
lemma when_True [simp]: "when True X = X"
  by (simp add: when_def)

lemma when_False [simp]: "when False X = return ()"
  by (simp add: when_def)

lemma unless_False [simp]: "unless False X = X"
  by (simp add: unless_def)

lemma unlessE_False [simp]: "unlessE False f = f"
  by (simp add: unlessE_def)

lemma unless_True [simp]: "unless True X = return ()"
  by (simp add: unless_def)

lemma unlessE_True [simp]: "unlessE True f = returnOk ()"
  by (simp add: unlessE_def)
133
(* unless/unlessE are when/whenE with the negated condition. Both sides are
   compared extensionally (on the computation and then on the state). *)
lemma unlessE_whenE:
  "unlessE P = whenE (~P)"
  by (intro ext) (simp add: unlessE_def whenE_def)

lemma unless_when:
  "unless P = when (~P)"
  by (intro ext) (simp add: unless_def when_def)
141
(* gets with a function that ignores the state is just return. *)
lemma gets_to_return [simp]: "gets (\<lambda>s. v) = return v"
  apply (clarsimp simp: return_def get_def gets_def bind_def put_def)
  done
144
(* assert_opt on a Some value simply returns the wrapped value. *)
lemma assert_opt_Some:
  "assert_opt (Some x) = return x"
  apply (simp add: assert_opt_def)
  done

(* assertE is assert lifted into the exception monad. *)
lemma assertE_liftE:
  "assertE P = liftE (assert P)"
  apply (simp add: assertE_def assert_def liftE_def returnOk_def)
  done
152
(* A lifted computation never reaches the handler, for both handler forms. *)
lemma liftE_handleE' [simp]: "((liftE a) <handle2> b) = liftE a"
  by (clarsimp simp: liftE_def handleE'_def)

(* <handle> is defined via handleE_def; the <handle2> rule above then applies. *)
lemma liftE_handleE [simp]: "((liftE a) <handle> b) = liftE a"
  unfolding handleE_def by simp
161
(* Split rules for `condition C a b s`: case-distinguish on C s, in the
   conclusion (condition_split) and in assumptions (condition_split_asm). *)
lemma condition_split:
  "P (condition C a b s) = ((((C s) \<longrightarrow> P (a s)) \<and> (\<not> (C s) \<longrightarrow> P (b s))))"
  by (clarsimp simp: condition_def)

lemma condition_split_asm:
  "P (condition C a b s) = (\<not> (C s \<and> \<not> P (a s) \<or> \<not> C s \<and> \<not> P (b s)))"
  by (clarsimp simp: condition_def)

lemmas condition_splits = condition_split condition_split_asm
173
(* A condition whose test is constant selects one branch outright. *)
lemma condition_true_triv [simp]:
  "condition (\<lambda>_. True) A B = A"
  by (rule ext) (clarsimp split: condition_splits)

lemma condition_false_triv [simp]:
  "condition (\<lambda>_. False) A B = B"
  by (rule ext) (clarsimp split: condition_splits)

(* Per-state selection: when P s is known, condition picks the matching branch. *)
lemma condition_true: "\<lbrakk> P s \<rbrakk> \<Longrightarrow> condition P A B s = A s"
  by (clarsimp simp: condition_def)

lemma condition_false: "\<lbrakk> \<not> P s \<rbrakk> \<Longrightarrow> condition P A B s = B s"
  by (clarsimp simp: condition_def)
193
(* arg_cong2 specialised to bind: equal first computations and equal
   continuations give equal binds. arg_cong_bind1 fixes the first computation
   (refl) and lets the continuations be compared pointwise (ext). *)
lemmas arg_cong_bind = arg_cong2[where f=bind]
lemmas arg_cong_bind1 = arg_cong_bind[OF refl ext]
196
197section "Low-level monadic reasoning"
198
(* Extensionality for nondet_monad: two monads are equal if, from every state,
   they yield the same result/state pairs (both inclusions) and the same second
   (snd) component. *)
lemma monad_eqI [intro]:
  "\<lbrakk> \<And>r t s. (r, t) \<in> fst (A s) \<Longrightarrow> (r, t) \<in> fst (B s);
     \<And>r t s. (r, t) \<in> fst (B s) \<Longrightarrow> (r, t) \<in> fst (A s);
     \<And>x. snd (A x) = snd (B x) \<rbrakk>
  \<Longrightarrow> (A :: ('s, 'a) nondet_monad) = B"
  by (fastforce intro!: set_eqI prod_eqI)

(* Pointwise version: equality of the single runs A s and B s'. *)
lemma monad_state_eqI [intro]:
  "\<lbrakk> \<And>r t. (r, t) \<in> fst (A s) \<Longrightarrow> (r, t) \<in> fst (B s');
     \<And>r t. (r, t) \<in> fst (B s') \<Longrightarrow> (r, t) \<in> fst (A s);
     snd (A s) = snd (B s') \<rbrakk>
  \<Longrightarrow> (A :: ('s, 'a) nondet_monad) s = B s'"
  by (fastforce intro!: set_eqI prod_eqI)
214
215subsection "General whileLoop reasoning"
216
(* Termination predicate for the exception-monad loop, phrased through
   whileLoop_terminates on sum-typed loop values: the loop condition is C on Inr
   values and False on anything else, the body is lift B, and the loop starts
   from Inr r. The inner binder is named r' so it does not shadow the outer r. *)
definition
  "whileLoop_terminatesE C B \<equiv> (\<lambda>r.
     whileLoop_terminates (\<lambda>r' s. case r' of Inr v \<Rightarrow> C v s | _ \<Rightarrow> False) (lift B) (Inr r))"
220
(* If the loop condition is false on the initial value, the loop returns that
   value unchanged. *)
lemma whileLoop_cond_fail:
    "\<lbrakk> \<not> C x s \<rbrakk> \<Longrightarrow> (whileLoop C B x s) = (return x s)"
  by (auto simp: return_def whileLoop_def
           intro: whileLoop_results.intros
                  whileLoop_terminates.intros
           elim!: whileLoop_results.cases)

(* Exception-monad counterpart: the loop returns the initial value via returnOk. *)
lemma whileLoopE_cond_fail:
    "\<lbrakk> \<not> C x s \<rbrakk> \<Longrightarrow> (whileLoopE C B x s) = (returnOk x s)"
  by (clarsimp simp: whileLoopE_def returnOk_def) (auto intro: whileLoop_cond_fail)
234
(* A pair (Some x, Some x) is in whileLoop_results exactly when the loop
   condition fails on x, i.e. the loop stops without moving. *)
lemma whileLoop_results_simps_no_move [simp]:
  shows "((Some x, Some x) \<in> whileLoop_results C B) = (\<not> C (fst x) (snd x))"
    (is "?LHS x = ?RHS x")
proof (rule iffI)
  (* forward: restate the goal through Some x / the (Some x) so that
     whileLoop_results.induct applies, then simplify back *)
  assume "?LHS x"
  then have "(\<exists>a. Some x = Some a) \<longrightarrow> ?RHS (the (Some x))"
   by (induct rule: whileLoop_results.induct, auto)
  thus "?RHS x"
    by clarsimp
next
  (* backward: the first introduction rule of whileLoop_results *)
  assume "?RHS x"
  thus "?LHS x"
    by (metis surjective_pairing whileLoop_results.intros(1))
qed
249
(* One-step unfolding of whileLoop: if C r holds, run B r and continue the loop
   on its result; otherwise return r. *)
lemma whileLoop_unroll:
  "(whileLoop C B r) =  ((condition (C r) (B r >>= (whileLoop C B)) (return r)))"
  (is "?LHS r = ?RHS r")
proof -
  (* condition false: the loop returns immediately (whileLoop_cond_fail) *)
  have cond_fail: "\<And>r s. \<not> C r s \<Longrightarrow> ?LHS r s = ?RHS r s"
    apply (subst whileLoop_cond_fail, simp)
    apply (clarsimp simp: condition_def bind_def return_def)
    done

  (* condition true: via monad_state_eqI, show both result-set inclusions and
     equality of the snd components *)
  have cond_pass: "\<And>r s. C r s \<Longrightarrow> whileLoop C B r s = (B r >>= (whileLoop C B)) s"
    apply (rule monad_state_eqI)
      (* results of the loop are results of the unrolled form *)
      apply (clarsimp simp: whileLoop_def bind_def split_def)
      apply (subst (asm) whileLoop_results_simps_valid)
      apply fastforce
     (* results of the unrolled form are results of the loop *)
     apply (clarsimp simp: whileLoop_def bind_def split_def)
     apply (subst whileLoop_results.simps)
     apply fastforce
    (* snd components agree *)
    apply (clarsimp simp: whileLoop_def bind_def split_def)
    apply (subst whileLoop_results.simps)
    apply (subst whileLoop_terminates.simps)
    apply fastforce
    done

  (* combine the two cases pointwise *)
  show ?thesis
    apply (rule ext)
    apply (metis cond_fail cond_pass condition_def)
    done
qed
278
(* Variant of whileLoop_unroll with the condition moved inside the first bind:
   run B r if C r holds (otherwise return r), then continue with the loop. *)
lemma whileLoop_unroll':
    "(whileLoop C B r) = ((condition (C r) (B r) (return r)) >>= (whileLoop C B))"
  apply (rule ext)
  apply (subst whileLoop_unroll)
  apply (clarsimp simp: condition_def bind_def return_def split_def)
  (* remaining goal is closed using whileLoop_cond_fail *)
  apply (subst whileLoop_cond_fail, simp)
  apply (clarsimp simp: return_def)
  done
287
(* One-step unfolding of whileLoopE: if C r holds, run B r and continue the loop
   with bindE; otherwise return r via returnOk. *)
lemma whileLoopE_unroll:
  "(whileLoopE C B r) =  ((condition (C r) (B r >>=E (whileLoopE C B)) (returnOk r)))"
  apply (rule ext)
  (* reduce to the underlying whileLoop and unroll it once *)
  apply (unfold whileLoopE_def)
  apply (subst whileLoop_unroll)
  apply (clarsimp simp: whileLoopE_def bindE_def returnOk_def split: condition_splits)
  apply (clarsimp simp: lift_def)
  (* both sides now bind B r; compare the continuations pointwise *)
  apply (rule_tac f="\<lambda>a. (B r >>= a) x" in arg_cong)
  apply (rule ext)+
  apply (clarsimp simp: lift_def split: sum.splits)
  (* remaining case: unroll once more; the loop condition is false there, so
     condition_false leaves a throwError *)
  apply (subst whileLoop_unroll)
  apply (subst condition_false)
   apply fastforce
  apply (clarsimp simp: throwError_def)
  done
303
(* Variant of whileLoopE_unroll with the condition moved inside the first bindE:
   run B r if C r holds (otherwise returnOk r), then continue with the loop. *)
lemma whileLoopE_unroll':
  "(whileLoopE C B r) =  ((condition (C r) (B r) (returnOk r)) >>=E (whileLoopE C B))"
  apply (rule ext)
  apply (subst whileLoopE_unroll)
  apply (clarsimp simp: condition_def bindE_def bind_def returnOk_def return_def lift_def split_def)
  (* remaining goal is closed using whileLoopE_cond_fail *)
  apply (subst whileLoopE_cond_fail, simp)
  apply (clarsimp simp: returnOk_def return_def)
  done
312
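(* These lemmas convert a valid rule that holds for every value of an auxiliary
 * variable s0 into a form suitable for wp: the resulting rule has an arbitrary
 * postcondition Q' (and E' for the validE forms), and the original
 * postcondition moves into the precondition as an implication. *)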
315
(* Partial-correctness triple: generalise the postcondition to an arbitrary Q'. *)
lemma valid_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>"
  apply (auto simp add: valid_def no_fail_def split: prod.splits)
  done

(* Same for the non-failing (validNF) triple. *)
lemma validNF_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>!) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>!"
  apply (auto simp add: valid_def validNF_def no_fail_def split: prod.splits)
  done

(* Exception-monad triple: generalise both the normal (Q') and error (E')
   postconditions. *)
lemma validE_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>, \<lbrace> \<lambda>rv s. E s0 rv s \<rbrace>) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s')
        \<and> (\<forall>rv s'. E s0 rv s' \<longrightarrow> E' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>, \<lbrace> E' \<rbrace>"
  apply (auto simp add: validE_def valid_def no_fail_def split: prod.splits sum.splits)
  done

(* Same for the non-failing exception-monad triple. *)
lemma validE_NF_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>, \<lbrace> \<lambda>rv s. E s0 rv s \<rbrace>!) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s')
        \<and> (\<forall>rv s'. E s0 rv s' \<longrightarrow> E' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>, \<lbrace> E' \<rbrace>!"
  apply (auto simp add: validE_NF_def validE_def valid_def no_fail_def split: prod.splits sum.splits)
  done
337
(* Projections of a conjunctive postcondition in a validNF triple: each conjunct
   holds on its own. *)
lemma validNF_conjD1: "\<lbrace> P \<rbrace> f \<lbrace> \<lambda>rv s. Q rv s \<and> Q' rv s \<rbrace>! \<Longrightarrow> \<lbrace> P \<rbrace> f \<lbrace> Q \<rbrace>!"
  apply (fastforce simp: validNF_def valid_def no_fail_def)
  done

lemma validNF_conjD2: "\<lbrace> P \<rbrace> f \<lbrace> \<lambda>rv s. Q rv s \<and> Q' rv s \<rbrace>! \<Longrightarrow> \<lbrace> P \<rbrace> f \<lbrace> Q' \<rbrace>!"
  apply (fastforce simp: validNF_def valid_def no_fail_def)
  done
343
344end
345