1(*
2 * Copyright 2014, NICTA
3 *
4 * This software may be distributed and modified according to the terms of
5 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
6 * See "LICENSE_BSD2.txt" for details.
7 *
8 * @TAG(NICTA_BSD)
9 *)
10
11theory SubMonadLib
12imports
13  EmptyFailLib
14  Corres_UL
15begin
16
(* Parameters of a submonad: a substate of type 'b inside a larger state of
   type 'a.  fetch extracts the substate, replace overwrites it, and guard is a
   predicate on the large state; args(1) is only required for states that
   satisfy guard. *)
locale submonad_args =
  fixes fetch :: "'a \<Rightarrow> 'b"
  fixes replace :: "'b \<Rightarrow> 'a \<Rightarrow> 'a"
  fixes guard :: "'a \<Rightarrow> bool"

  (* args(1): under guard, fetching right after a replace yields the stored value.
     args(2): a later replace overrides an earlier one.
     args(3): storing back the fetched substate leaves the state unchanged. *)
  assumes args:
   "\<forall>x s. guard s \<longrightarrow> fetch (replace x s) = x"
   "\<forall>x y s. replace x (replace y s) = replace x s"
   "\<forall>s. replace (fetch s) s = s"

  (* replace never changes whether guard holds. *)
  assumes replace_preserves_guard:
   "\<And>s x. guard (replace x s) = guard s"
29
(* Lift a nondeterministic computation m on the substate 'b to one on the full
   state 'a.  The lifted computation:
   - asserts guard (stateAssert guard []);
   - fetches the substate;
   - runs m on it through select_f;
   - stores the resulting substate with replace;
   - returns m's result. *)
definition
  submonad_fn :: "('a \<Rightarrow> 'b) \<Rightarrow> ('b \<Rightarrow> 'a \<Rightarrow> 'a) \<Rightarrow> ('a \<Rightarrow> bool) \<Rightarrow>
                  ('b, 'c) nondet_monad \<Rightarrow> ('a, 'c) nondet_monad"
where
 "submonad_fn fetch replace guard m \<equiv> do
    stateAssert guard [];
    substate \<leftarrow> gets fetch;
    (rv, substate') \<leftarrow> select_f (m substate);
    modify (replace substate');
    return rv
  od"
41
(* A submonad extends submonad_args with a lifting function fn.  fn is required
   to be exactly submonad_fn fetch replace guard. *)
locale submonad = submonad_args +
  fixes fn :: "('b, 'c) nondet_monad \<Rightarrow> ('a, 'c) nondet_monad"

  assumes fn_is_sm: "fn = submonad_fn fetch replace guard"
46
(* Meta-level form of args(1): when guard holds, fetch returns the value that
   replace has just stored. *)
lemma (in submonad_args) argsD1:
  "\<And>x s. guard s \<Longrightarrow> fetch (replace x s) = x"
  by (simp add: args(1))
50
(* When guard holds in the starting state, the leading stateAssert of fn m does
   nothing.  What remains is fetch, select_f, replace and return. *)
lemma (in submonad) guarded_sm:
  "\<And>s. guard s \<Longrightarrow>
   fn m s = (do
     substate \<leftarrow> gets fetch;
     (rv, substate') \<leftarrow> select_f (m substate);
     modify (replace substate');
     return rv
   od) s"
  unfolding fn_is_sm submonad_fn_def
  by (simp add: return_def bind_def assert_def get_def stateAssert_def)
61
(* Two consecutive modify steps collapse into one modify with the composed
   function (fn2 after fn1). *)
lemma modify_modify:
  "modify fn1 >>= (\<lambda>_. modify fn2) = modify (fn2 \<circ> fn1)"
  by (simp add: put_def get_def modify_def bind_def)
65
(* A select_f on a fixed pair S may be moved in front of an empty_fail
   computation m1 whose result it does not depend on.  Side condition S: if S
   offers no results, its second component (snd S) must hold. *)
lemma select_f_walk:
  assumes m1: "empty_fail m1"
  assumes S: "fst S = {} \<Longrightarrow> snd S"
  shows "(do a \<leftarrow> m1; b \<leftarrow> select_f S; m2 a b od) = (do b \<leftarrow> select_f S; a \<leftarrow> m1; m2 a b od)"
  apply (rule ext)
  (* compare the two result pairs componentwise *)
  apply (rule prod.expand)
  apply (rule conjI)
   apply (simp add: select_f_def bind_def split_def)
   apply fastforce
  apply (simp add: select_f_def bind_def split_def)
  (* case split on whether S offers any results *)
  apply (case_tac "fst S = {}")
   apply clarsimp
   apply (case_tac "fst (m1 x) = {}")
    apply (simp add: empty_failD [OF m1] S)
   apply (frule S)
   apply force
  apply safe
     apply clarsimp
     apply force
    apply force
   apply clarsimp
   apply force
  apply clarsimp
  apply (case_tac "fst (m1 x) = {}", simp add: empty_failD [OF m1])
  apply force
  done
92
(* Two back-to-back state assertions combine into a single assertion of the
   conjunction g and g'. *)
lemma stateAssert_stateAssert:
  "(stateAssert g [] >>= (\<lambda>_. stateAssert g' [])) = stateAssert (g and g') []"
  by (simp add: ext return_def fail_def assert_def get_def bind_def stateAssert_def)
96
(* When every replacement r x s preserves g, asserting g after modify (r x)
   is the same as asserting it before. *)
lemma modify_stateAssert:
  "\<lbrakk> \<And>s x. g (r x s) = g s \<rbrakk> \<Longrightarrow>
   (modify (r x) >>= (\<lambda>_. stateAssert g []))
            = (stateAssert g [] >>= (\<lambda>_. modify (r x)))"
  by (simp add: ext put_def modify_def return_def fail_def assert_def
                get_def bind_def stateAssert_def)
103
(* A gets followed by a stateAssert may be reordered so that the assertion
   comes first. *)
lemma gets_stateAssert:
  "(gets f >>= (\<lambda>x. stateAssert g' [] >>= (\<lambda>_. m x)))
            = (stateAssert g' [] >>= (\<lambda>_. gets f >>= (\<lambda>x. m x)))"
  by (simp add: ext return_def fail_def assert_def get_def gets_def
                bind_def stateAssert_def)
109
(* A stateAssert that follows select_f (m a) may be moved in front of it,
   provided m is empty_fail. *)
lemma select_f_stateAssert:
  "empty_fail m \<Longrightarrow>
   (select_f (m a) >>= (\<lambda>x. stateAssert g [] >>= (\<lambda>u. n x))) =
   (stateAssert g [] >>= (\<lambda>u. select_f (m a) >>= (\<lambda>x. n x)))"
  apply (rule ext)
  apply (clarsimp simp: stateAssert_def bind_def select_f_def get_def
                        assert_def return_def fail_def split_def image_image)
  apply (simp only: image_def)
  apply (clarsimp simp: stateAssert_def bind_def select_f_def get_def
                        assert_def return_def fail_def split_def image_image)
  (* expose empty_fail so that the empty-result case can be handled *)
  apply (simp only: image_def mem_simps empty_fail_def simp_thms)
  apply fastforce
  done
123
(* select_f on m's outcome followed by select_f on n's outcome is the same as a
   single select_f on the outcome of m >>= n. *)
lemma bind_select_f_bind':
  shows "(select_f (m s) >>= (\<lambda>x. select_f (split n x))) = (select_f ((m >>= n) s))"
  by (rule ext) (force simp: split_def bind_def select_f_def)
129
(* Variant of bind_select_f_bind' in which m2 takes the result pair's
   components as two separate arguments. *)
lemma bind_select_f_bind:
  "(select_f (m1 s) >>= (\<lambda>x. select_f (m2 (fst x) (snd x)))) = (select_f ((m1 >>= m2) s))"
  using bind_select_f_bind' [where m=m1 and n=m2 and s=s]
  by (simp add: split_def)
134
(* select_f applied to the outcome of gets f in state s simply returns the pair
   (f s, s). *)
lemma select_from_gets: "select_f (gets f s) = return (f s, s)"
  by (rule ext) (simp add: simpler_gets_def return_def select_f_def)
139
(* Point-free form of select_from_gets. *)
lemma select_from_gets':
  "(select_f \<circ> gets f) = (\<lambda>s. return (f s, s))"
  by (rule ext) (simp add: select_from_gets o_def)
145
(* Given the equation (f >>= g) = h, rewrite the first two steps f; g x of a
   longer do-block to h.  Used below as bind_subst_lift [OF ...] to apply
   two-step rewrite lemmas inside longer binds. *)
lemma bind_subst_lift:
  "(f >>= g) = h \<Longrightarrow> (do x \<leftarrow> f; y \<leftarrow> g x; j y od) = (h >>= j)"
  by (simp add: bind_assoc[symmetric])
149
(* Assume r preserves g, and under g reading f after r x yields x.  Then
   modify (r x) followed by asserting g and reading f returns x, and the
   assertion can be moved in front of the modify. *)
lemma modify_gets:
  "\<lbrakk> \<And>x s. g (r x s) = g s; \<And>x s. g s \<longrightarrow> f (r x s) = x \<rbrakk>
   \<Longrightarrow> (modify (r x) >>= (\<lambda>_. stateAssert g [] >>= (\<lambda>_. gets f)))
            = (stateAssert g [] >>= (\<lambda>_. modify (r x) >>= (\<lambda>_. return x)))"
  by (simp add: ext fail_def return_def gets_def put_def get_def bind_def
                modify_def assert_def stateAssert_def)
156
(* When guard holds, writing back the substate that was just fetched changes
   nothing, so that modify step can be dropped. *)
lemma (in submonad_args) gets_modify:
  "\<And>s. guard s \<Longrightarrow>
   (do x \<leftarrow> gets fetch; u \<leftarrow> modify (replace x); f x od) s = ((gets fetch) >>= f) s"
  by (clarsimp simp: get_def args put_def bind_def return_def gets_def modify_def
              split: option.split)
163
(* Lifting distributes over bind.  m, m' and m'' are submonads over the same
   f, r, g, so each of them equals submonad_fn f r g.  Lifting a >>= b with m
   equals lifting a with m' and then each b rv with m''.  Requires a and every
   b x to be empty_fail. *)
lemma submonad_bind:
  "\<lbrakk> submonad f r g m; submonad f r g m'; submonad f r g m'';
     empty_fail a; \<And>x. empty_fail (b x) \<rbrakk> \<Longrightarrow>
   m (a >>= b) = (m' a) >>= (\<lambda>rv. m'' (b rv))"
  (* replace m, m' and m'' by submonad_fn f r g *)
  apply (subst submonad.fn_is_sm, assumption)+
  apply (clarsimp simp: submonad_def bind_assoc split_def submonad_fn_def)
  (* move the second lifted computation's stateAssert to the front,
     past modify, select_f and gets *)
  apply (subst bind_subst_lift [OF modify_gets, unfolded bind_assoc])
    apply (simp add: submonad_args.args submonad_args.replace_preserves_guard)+
  apply (subst select_f_stateAssert, assumption)
  apply (subst gets_stateAssert)
  (* merge the two assertions *)
  apply (subst bind_subst_lift [OF stateAssert_stateAssert])
  apply (clarsimp simp: pred_conj_def)
  (* reorder the select_f steps and fuse the modify and select_f pairs *)
  apply (clarsimp simp: bind_assoc split_def select_f_walk
                empty_fail_stateAssert empty_failD
                bind_subst_lift[OF modify_modify] submonad_args.args o_def
                bind_subst_lift[OF bind_select_f_bind])
  done
181
(* Every final state reachable through fn m satisfies guard.  The initial
   stateAssert establishes guard, and replace preserves it. *)
lemma (in submonad) guard_preserved:
  "\<And>s s'. \<lbrakk> (rv, s') \<in> fst (fn m s) \<rbrakk> \<Longrightarrow> guard s'"
  unfolding fn_is_sm submonad_fn_def
  by (clarsimp simp: stateAssert_def gets_def get_def bind_def modify_def put_def
                     return_def select_f_def replace_preserves_guard in_monad)
187
(* Any result of stateAssert g [] from state s leaves the state equal to s,
   and g s must have held. *)
lemma fst_stateAssertD:
  "\<And>s s' v. (v, s') \<in> fst (stateAssert g [] s) \<Longrightarrow> s' = s \<and> g s"
  by (clarsimp simp: stateAssert_def in_monad)
191
(* When guard holds, lifting gets f is the same as reading f \<circ> fetch
   directly. *)
lemma (in submonad) guarded_gets:
  "\<And>s. guard s \<Longrightarrow> fn (gets f) s = gets (f \<circ> fetch) s"
  by (simp add: guarded_sm select_from_gets gets_modify, simp add: gets_def)
197
(* When guard holds, lifting return x behaves exactly like return x. *)
lemma (in submonad) guarded_return:
  "\<And>s. guard s \<Longrightarrow> fn (return x) s = return x s"
  using args guarded_gets
  by (fastforce simp: get_def bind_def gets_def)
202
(* Lifting gets f asserts guard and then reads f \<circ> fetch. *)
lemma (in submonad_args) submonad_fn_gets:
  "submonad_fn fetch replace guard (gets f) =
   (stateAssert guard [] >>= (\<lambda>u. gets (f \<circ> fetch)))"
  apply (simp add: ext select_from_gets submonad_fn_def)
  apply (rule bind_cong [OF refl])
  (* after the assertion, guard holds, so the write-back is a no-op *)
  apply (clarsimp simp: gets_modify dest!: fst_stateAssertD)
  apply (simp add: gets_def)
  done
211
(* submonad_fn_gets restated in terms of the locale's fn. *)
lemma (in submonad) gets:
  "fn (gets f) = (stateAssert guard [] >>= (\<lambda>_. gets (f \<circ> fetch)))"
  by (unfold fn_is_sm submonad_fn_gets) (rule refl)
216
(* Lifting return x asserts guard and then returns x. *)
lemma (in submonad) return:
  "fn (return x) = (stateAssert guard [] >>= (\<lambda>_. return x))"
  using args gets
  by (fastforce simp: get_def bind_def gets_def)
221
(* If guard holds initially, it holds in every final state of
   mapM (fn \<circ> m) xs.  Proof by induction on xs, with guard_preserved used
   for each lifted step. *)
lemma (in submonad) mapM_guard_preserved:
  "\<And>s s'. \<lbrakk> guard s; \<exists>rv. (rv, s') \<in> fst (mapM (fn \<circ> m) xs s)\<rbrakk> \<Longrightarrow> guard s'"
proof (induct xs)
  case Nil
  thus ?case
    by (simp add: mapM_def sequence_def return_def)
  next
  case (Cons x xs)
  thus ?case
    apply (clarsimp simp: o_def mapM_Cons return_def bind_def)
    (* the first lifted step re-establishes guard; the IH handles the rest *)
    apply (drule guard_preserved)
    apply fastforce
    done
qed
236
(* The mapM_x counterpart of mapM_guard_preserved: guard is preserved by
   mapM_x (fn \<circ> m) xs.  Proof by induction on xs. *)
lemma (in submonad) mapM_x_guard_preserved:
  "\<And>s s'. \<lbrakk> guard s; \<exists>rv. (rv, s') \<in> fst (mapM_x (fn \<circ> m) xs s)\<rbrakk> \<Longrightarrow> guard s'"
proof (induct xs)
  case Nil
  thus ?case
    by (simp add: mapM_x_def sequence_x_def return_def)
  next
  case (Cons x xs)
  thus ?case
    apply (clarsimp simp: o_def mapM_x_Cons return_def bind_def)
    (* the first lifted step re-establishes guard; the IH handles the rest *)
    apply (drule guard_preserved)
    apply fastforce
    done
qed
251
(* fn m already begins with stateAssert guard [], so asserting guard
   immediately before it is redundant. *)
lemma (in submonad) stateAssert_fn:
  "stateAssert guard [] >>= (\<lambda>_. fn m) = fn m"
  by (simp add: bind_subst_lift [OF stateAssert_stateAssert] pred_conj_def
                submonad_fn_def fn_is_sm)
256
(* Asserting guard immediately after fn m is redundant, because fn m only
   ends in states that satisfy guard. *)
lemma (in submonad) fn_stateAssert:
  "fn m >>= (\<lambda>x. stateAssert guard [] >>= (\<lambda>u. n x)) = (fn m >>= n)"
  apply (simp add: fn_is_sm submonad_fn_def bind_assoc split_def)
  apply (rule ext)
  apply (rule bind_apply_cong [OF refl])+
  apply (clarsimp simp: stateAssert_def bind_assoc in_monad select_f_def)
  (* guard held initially and replace preserves it *)
  apply (drule iffD2 [OF replace_preserves_guard])
  apply (fastforce simp: bind_def assert_def get_def return_def)
  done
266
(* Lifting mapM m l through sm equals asserting g once and then mapping the
   lifted sm' \<circ> m over l.  Requires every m x to be empty_fail.  Proof by
   induction on l. *)
lemma submonad_mapM:
  assumes sm: "submonad f r g sm" and sm': "submonad f r g sm'"
  assumes efm: "\<And>x. empty_fail (m x)"
  shows
  "(sm (mapM m l)) = (stateAssert g [] >>= (\<lambda>u. mapM (sm' \<circ> m) l))"
proof (induct l)
  case Nil
  thus ?case
    by (simp add: mapM_def sequence_def bind_def submonad.return [OF sm])
  next
  case (Cons x xs)
  thus ?case
    using sm sm' efm
    apply (simp add: mapM_Cons)
    apply (simp add: bind_subst_lift [OF submonad.stateAssert_fn])
    apply (simp add: bind_assoc submonad_bind submonad.return)
    apply (subst submonad.fn_stateAssert [OF sm'])
    apply (intro ext bind_apply_cong [OF refl])
    (* NOTE(review): sta appears to be a state name generated by the
       preceding step; this line is sensitive to that naming. *)
    apply (subgoal_tac "g sta")
     apply (clarsimp simp: stateAssert_def bind_def get_def assert_def return_def)
    apply (frule(1) submonad.guard_preserved)
    apply (erule(1) submonad.mapM_guard_preserved, fastforce simp: o_def)
    done
qed
291
(* The mapM_x counterpart of submonad_mapM: lifting mapM_x m l through sm
   equals asserting g once and then running mapM_x (sm' \<circ> m) l.  Requires
   every m x to be empty_fail. *)
lemma submonad_mapM_x:
  assumes sm: "submonad f r g sm" and sm': "submonad f r g sm'"
  assumes efm: "\<And>x. empty_fail (m x)"
  shows
  "(sm (mapM_x m l)) = (stateAssert g [] >>= (\<lambda>u. mapM_x (sm' \<circ> m) l))"
proof (induct l)
  case Nil
  thus ?case
    by (simp add: mapM_x_def sequence_x_def bind_def submonad.return [OF sm])
  next
  case (Cons x xs)
  thus ?case
    using sm sm' efm
    apply (simp add: mapM_x_Cons)
    apply (simp add: bind_subst_lift [OF submonad.stateAssert_fn])
    apply (simp add: bind_assoc submonad_bind submonad.return)
    apply (subst submonad.fn_stateAssert [OF sm'])
    apply (intro ext bind_apply_cong [OF refl])
    (* NOTE(review): st appears to be a state name generated by the
       preceding step; this line is sensitive to that naming. *)
    apply (subgoal_tac "g st")
     apply (clarsimp simp: stateAssert_def bind_def get_def assert_def return_def)
    apply (frule(1) submonad.guard_preserved, simp)
    done
qed
315
(* select S corresponds to select S' when every element of S' has an
   rvr-related element in S. *)
lemma corres_select:
  "(\<forall>s' \<in> S'. \<exists>s \<in> S. rvr s s') \<Longrightarrow> corres_underlying sr nf nf' rvr \<top> \<top> (select S) (select S')"
  by (clarsimp simp: corres_underlying_def select_def)
319
(* select_f S corresponds to select_f S' when every result in fst S' has an
   rvr-related result in fst S, and snd S' is false whenever nf' is required. *)
lemma corres_select_f:
  "\<lbrakk> \<forall>s' \<in> fst S'. \<exists>s \<in> fst S. rvr s s'; nf' \<Longrightarrow> \<not> snd S' \<rbrakk>
      \<Longrightarrow> corres_underlying sr nf nf' rvr \<top> \<top> (select_f S) (select_f S')"
  by (clarsimp simp: corres_underlying_def select_f_def)
324
(* modify f corresponds to modify f' when f and f' map sr-related states to
   sr-related states and r () () holds. *)
lemma corres_modify':
  "\<lbrakk> (\<forall>s s'. (s, s') \<in> sr \<longrightarrow> (f s, f' s') \<in> sr); r () () \<rbrakk>
      \<Longrightarrow> corres_underlying sr nf nf' r \<top> \<top> (modify f) (modify f')"
  by (clarsimp simp: put_def get_def bind_def corres_underlying_def modify_def)
329
(* FIXME: this lemma has exactly the same statement as corres_select_f.  It is
   kept only as an alias because corres_submonad below refers to it; new proofs
   should use corres_select_f.  It is derived from corres_select_f rather than
   re-proved, so the two cannot drift apart. *)
lemma corres_select_f_stronger:
  "\<lbrakk> \<forall>s' \<in> fst S'. \<exists>s \<in> fst S. rvr s s'; nf' \<Longrightarrow> \<not> snd S' \<rbrakk>
      \<Longrightarrow> corres_underlying sr nf nf' rvr \<top> \<top> (select_f S) (select_f S')"
  by (fact corres_select_f)
335
(* Hoare triple for stateAssert: any precondition P is preserved, and Q
   additionally holds afterwards. *)
lemma stateAssert_sp:
  "\<lbrace>P\<rbrace> stateAssert Q l \<lbrace>\<lambda>_. P and Q\<rbrace>"
  by (clarsimp simp: valid_def stateAssert_def in_monad)
339
(* Correspondence of two lifted computations.  Assume:
   - x and x' correspond under the substate relation ssr (with nf = False);
   - fetch maps sr-related states satisfying g and g' to ssr-related substates;
   - replace maps sr/ssr-related inputs back to sr-related states.
   Then fn x and fn' x' correspond under sr, with preconditions g and g'. *)
lemma corres_submonad:
  "\<lbrakk> submonad f r g fn; submonad f' r' g' fn';
     \<forall>s s'. (s, s') \<in> sr \<and> g s \<and> g' s' \<longrightarrow> (f s, f' s') \<in> ssr;
     \<forall>s s' ss ss'. ((s, s') \<in> sr \<and> (ss, ss') \<in> ssr) \<longrightarrow> (r ss s, r' ss' s') \<in> sr;
     corres_underlying ssr False nf' rvr \<top> \<top> x x'\<rbrakk>
   \<Longrightarrow> corres_underlying sr False nf' rvr g g' (fn x) (fn' x')"
  apply (subst submonad.fn_is_sm, assumption)+
  apply (clarsimp simp: submonad_fn_def)
  (* stateAssert step *)
  apply (rule corres_split' [OF _ _ stateAssert_sp stateAssert_sp])
   apply (fastforce simp: corres_underlying_def stateAssert_def get_def
                         assert_def return_def bind_def)
  (* gets fetch step: the fetched substates are related by ssr *)
  apply (rule corres_split' [where r'="\<lambda>x y. (x, y) \<in> ssr",
                             OF _ _ hoare_post_taut hoare_post_taut])
   apply clarsimp
  (* select_f step (deferred): results related by rvr, substates by ssr *)
  apply (rule corres_split' [where r'="\<lambda>(x, x') (y, y'). rvr x y \<and> (x', y') \<in> ssr",
                             OF _ _ hoare_post_taut hoare_post_taut])
   defer
   apply clarsimp
   (* modify (replace ...) step followed by return *)
   apply (rule corres_split' [where r'=dc, OF _ _ hoare_post_taut hoare_post_taut])
    apply (simp add: corres_modify')
   apply clarsimp
  (* the deferred select_f step, discharged with the correspondence of x and x' *)
  apply (rule corres_select_f_stronger)
   apply (clarsimp simp: corres_underlying_def)
   apply (drule (1) bspec, clarsimp)
   apply (drule (1) bspec, simp)
   apply blast
  apply (clarsimp simp: corres_underlying_def)
  apply (drule (1) bspec, clarsimp)
  done
369
(* Asserting the trivially true predicate before a bind does nothing. *)
lemma stateAssert_top [simp]:
  "stateAssert \<top> l >>= f = f ()"
  by (clarsimp simp add: return_def bind_def get_def stateAssert_def)
373
(* Asserting the trivially true predicate is just return (). *)
lemma stateAssert_A_top [simp]:
  "stateAssert \<top> l = return ()"
  by (simp add: return_def bind_def get_def stateAssert_def)
377
text {* Using the submonad construction to prove that computations lifted into independent substates commute. *}
379
(* If applying f does not change the value of g at s, reading g can be moved
   before modify f. *)
lemma gets_modify_comm:
  "\<And> s. \<lbrakk> g (f s) = g s \<rbrakk> \<Longrightarrow>
   (do x \<leftarrow> modify f; y \<leftarrow> gets g; m x y od) s =
   (do y \<leftarrow> gets g; x \<leftarrow> modify f; m x y od) s"
  by (simp add: return_def put_def bind_def get_def gets_def modify_def)
385
(* Rewrite the continuation of a under an invariant.  Assume:
   - f x >>= g x equals h x on every state satisfying P;
   - a preserves P.
   Then, starting from a P-state, a >>= (\<lambda>x. f x >>= g x) equals a >>= h. *)
lemma bind_subst_lhs_inv:
  "\<And>s. \<lbrakk> \<And>x s'. P s' \<Longrightarrow> (f x >>= g x) s' = h x s'; \<lbrace>P\<rbrace> a \<lbrace>\<lambda>_. P\<rbrace>; P s \<rbrakk> \<Longrightarrow>
   (do x \<leftarrow> a; y \<leftarrow> f x; g x y od) s = (a >>= h) s"
  apply (rule bind_apply_cong [OF refl])
  (* every intermediate state produced by a satisfies P *)
  apply (drule(2) use_valid)
  apply simp
  done
393
(* Two consecutive gets can be swapped. *)
lemma gets_comm:
  "do x \<leftarrow> gets f; y \<leftarrow> gets g; m x y od = do y \<leftarrow> gets g; x \<leftarrow> gets f; m x y od"
  by (simp add: bind_def return_def get_def gets_def)
397
(* Two computations lifted with submonad_fn commute, given that:
   - the replace functions r and r' commute (z);
   - each replace preserves the other's guard (gp, gp');
   - both inner computations are empty_fail. *)
lemma submonad_comm:
  assumes x1: "submonad_args f r g" and x2: "submonad_args f' r' g'"
  assumes y: "m = submonad_fn f r g im" "m' = submonad_fn f' r' g' im'"
  assumes z: "\<And>s x x'. r x (r' x' s) = r' x' (r x s)"
  assumes gp: "\<And>s x. g (r' x s) = g s" and gp': "\<And>s x. g' (r x s) = g' s"
  assumes efim: "empty_fail im" and efim': "empty_fail im'"
  shows      "(do x \<leftarrow> m; y \<leftarrow> m'; n x y od) = (do y \<leftarrow> m'; x \<leftarrow> m; n x y od)"
proof -
  (* P: under g, r' does not change what f fetches. *)
  have P: "\<And>x s. g s \<Longrightarrow> f (r' x s) = f s"
    apply (subgoal_tac "f (r' x (r (f s) s)) = f s")
     apply (simp add: submonad_args.args[OF x1])
    apply (simp add: z[symmetric])
    apply (subst(asm) gp [symmetric])
    apply (fastforce dest: submonad_args.argsD1[OF x1])
    done
  (* Q: under g', r does not change what f' fetches. *)
  have Q: "\<And>x s. g' s \<Longrightarrow> f' (r x s) = f' s"
    apply (subgoal_tac "f' (r x (r' (f' s) s)) = f' s")
     apply (simp add: submonad_args.args[OF x2])
    apply (simp add: z)
    apply (subst(asm) gp' [symmetric])
    apply (fastforce dest: submonad_args.argsD1[OF x2])
    done
  note empty_failD [OF efim, simp]
  note empty_failD [OF efim', simp]
  (* Unfold both lifted computations, then hoist both stateAssert steps to the
     front and merge them.  Finally reorder the remaining gets, select_f and
     modify steps, using P, Q and the commutation z of r and r'. *)
  show ?thesis
    apply (clarsimp simp: submonad_fn_def y bind_assoc split_def)
    apply (subst bind_subst_lift [OF modify_stateAssert], rule gp gp')+
    apply (simp add: bind_assoc)
    apply (subst select_f_stateAssert, rule efim efim')+
    apply (subst gets_stateAssert bind_subst_lift [OF stateAssert_stateAssert])+
    apply (rule bind_cong)
     apply (simp add: pred_conj_def conj_comms)
    apply (simp add: bind_assoc select_f_walk[symmetric])
    apply (clarsimp dest!: fst_stateAssertD)
    apply (subst bind_assoc[symmetric],
           subst bind_subst_lhs_inv [OF gets_modify_comm],
           erule P Q, wp, simp, simp)+
    apply (simp add: bind_assoc)
    apply (simp add: select_f_walk[symmetric])
    apply (subst gets_comm)
    apply (rule bind_apply_cong [OF refl])+
    apply (subst select_f_walk, simp, simp,
           subst select_f_walk, simp, simp,
           rule bind_apply_cong [OF refl])
    apply (subst select_f_walk, simp, simp, rule bind_apply_cong [OF refl])
    apply (clarsimp simp: simpler_gets_def select_f_def)
    apply (simp add: bind_def get_def put_def modify_def z)
    done
qed
447
(* Variant of submonad_comm in which the second computation is given as
   m' im' for a submonad instance m' (assumption y).  The proof instantiates
   submonad_comm. *)
lemma submonad_comm2:
  assumes x1: "submonad_args f r g" and x2: "m = submonad_fn f r g im"
  assumes y: "submonad f' r' g' m'"
  assumes z: "\<And>s x x'. r x (r' x' s) = r' x' (r x s)"
  assumes gp: "\<And>s x. g (r' x s) = g s" and gp': "\<And>s x. g' (r x s) = g' s"
  assumes efim: "empty_fail im" and efim': "empty_fail im'"
  shows      "do x \<leftarrow> m; y \<leftarrow> m' im'; n x y od = do y \<leftarrow> m' im'; x \<leftarrow> m; n x y od"
  apply (rule submonad_comm[where f'=f' and r'=r', OF x1 _ x2 _ z])
       apply (insert y)
       apply (fastforce simp add: submonad_def)
      apply (fastforce dest: submonad.fn_is_sm)
     apply (simp add: efim efim' gp gp')+
  done
461
(* The bind of two computations that are both submonad_fn f r g liftings is
   itself the lifting of the bind of the inner computations. *)
lemma submonad_bind_alt:
  assumes x: "submonad_args f r g"
  assumes y: "a = submonad_fn f r g a'" "\<And>rv. b rv = submonad_fn f r g (b' rv)"
  assumes efa: "empty_fail a'" and efb: "\<And>x. empty_fail (b' x)"
  shows      "(a >>= b) = submonad_fn f r g (a' >>= b')"
proof -
  (* submonad_fn f r g itself forms a submonad *)
  have P: "submonad f r g (submonad_fn f r g)"
    by (simp add: x submonad_def submonad_axioms_def)
  have Q: "b = (\<lambda>rv. submonad_fn f r g (b' rv))"
    by (rule ext) fact+
  show ?thesis
    by (simp add: y Q submonad_bind [OF P P P efa efb])
qed
475
(* With the trivial guard \<top>, lifting a computation that yields a single
   result with a False second component gives another such computation.  It
   reads the substate with fetch and writes the new substate with replace. *)
lemma submonad_singleton:
  "submonad_fn fetch replace \<top> (\<lambda>s. ({(rv s, s' s)}, False))
     = (\<lambda>s. ({(rv (fetch s), replace (s' (fetch s)) s)}, False))"
  by (rule ext)
     (simp add: submonad_fn_def select_f_def UNION_eq return_def modify_def
                put_def get_def gets_def bind_def)
484
(* With the trivial guard \<top>, a gets f whose function factors through fetch
   (f s = f' (fetch s)) is the lifting of gets f'. *)
lemma gets_submonad:
  "\<lbrakk> submonad_args fetch replace \<top>; \<And>s. f s = f' (fetch s); m = gets f' \<rbrakk>
   \<Longrightarrow> gets f = submonad_fn fetch replace \<top> m"
  (* only args(3), replace (fetch s) s = s, is needed *)
  apply (drule submonad_args.args(3))
  apply (clarsimp simp add: simpler_gets_def submonad_singleton)
  done
491
(* If f rewrites the state by replacing its substate with f' (fetch s), then
   modify f is the lifting of modify f' (trivial guard). *)
lemma modify_submonad:
  "\<lbrakk> \<And>s. f s = replace (K_record (f' (fetch s))) s; m = modify f' \<rbrakk>
     \<Longrightarrow> modify f = submonad_fn fetch (replace o K_record) \<top> m"
  by (simp add: submonad_singleton simpler_modify_def)
496
(* Lifting fail with the trivial guard gives fail again. *)
lemma fail_submonad:
  "fail = submonad_fn fetch replace \<top> fail"
  by (simp add: fail_def bind_def select_f_def simpler_modify_def
                return_def simpler_gets_def submonad_fn_def)
501
(* Lifting return v with the trivial guard gives return v again. *)
lemma return_submonad:
  "submonad_args fetch replace guard \<Longrightarrow>
   return v = submonad_fn fetch replace \<top> (return v)"
  by (simp add: submonad_args.args submonad_singleton return_def)
506
(* assert_opt v is unchanged by lifting with the trivial guard.  Case analysis
   on v: the first case uses fail_submonad, the second return_submonad. *)
lemma assert_opt_submonad:
  "submonad_args fetch replace \<top> \<Longrightarrow>
   assert_opt v = submonad_fn fetch replace \<top> (assert_opt v)"
  apply (case_tac v, simp_all add: assert_opt_def)
   apply (rule fail_submonad)
  apply (rule return_submonad)
  apply assumption
  done
515
(* Characterise f as stateAssert guard [] followed by gets fetch, given that f:
   - leaves the state unchanged;
   - ends only in states satisfying guard;
   - is empty_fail;
   - does not fail under guard;
   - under guard, returns fetch of the state. *)
lemma is_stateAssert_gets:
  "\<lbrakk> \<forall>s. \<lbrace>(=) s\<rbrace> f \<lbrace>\<lambda>_. (=) s\<rbrace>; \<lbrace>\<top>\<rbrace> f \<lbrace>\<lambda>_. guard\<rbrace>;
     empty_fail f; no_fail guard f; \<lbrace>guard\<rbrace> f \<lbrace>\<lambda>rv s. fetch s = rv\<rbrace> \<rbrakk>
    \<Longrightarrow> f = do stateAssert guard []; gets fetch od"
  apply (rule ext)
  apply (clarsimp simp: bind_def empty_fail_def valid_def no_fail_def
                        stateAssert_def assert_def gets_def get_def
                        return_def fail_def image_def split_def)
  apply (case_tac "f x")
  apply (intro conjI impI)
   apply (drule_tac x=x in spec)+
   (* every result of f x is the pair (fetch x, x) *)
   apply (subgoal_tac "\<forall>xa\<in>fst (f x). fst xa = fetch x \<and> snd xa = x")
    apply fastforce
   apply clarsimp
  apply (drule_tac x=x in spec)+
  apply fastforce
  done
533
(* Show f s = modify replace s, given that:
   - every result of f from s ends in state replace s;
   - f is empty_fail;
   - f does not fail under guard;
   - guard s holds.
   Here replace is an arbitrary state function 'a \<Rightarrow> 'a. *)
lemma is_modify:
  "\<And>s. \<lbrakk> \<lbrace>(=) s\<rbrace> f \<lbrace>\<lambda>_. (=) (replace s)\<rbrace>; empty_fail f;
          no_fail guard f; guard s \<rbrakk>
    \<Longrightarrow> f s = modify replace s"
  apply (clarsimp simp: bind_def empty_fail_def valid_def no_fail_def
                        stateAssert_def assert_def modify_def get_def put_def
                        return_def fail_def image_def split_def)
  apply (case_tac "f s")
  apply force
  done
544
(* Variant of submonad_comm for two submonad locale instances m and m',
   applied to the inner computations im and im'. *)
lemma submonad_comm':
  assumes sm1: "submonad f r g m" and sm2: "submonad f' r' g' m'"
  assumes z: "\<And>s x x'. r x (r' x' s) = r' x' (r x s)"
  assumes gp: "\<And>s x. g (r' x s) = g s" and gp': "\<And>s x. g' (r x s) = g' s"
  assumes efim: "empty_fail im" and efim': "empty_fail im'"
  shows      "(do x \<leftarrow> m im; y \<leftarrow> m' im'; n x y od) =
              (do y \<leftarrow> m' im'; x \<leftarrow> m im; n x y od)"
  apply (rule submonad_comm [where f'=f' and r'=r', OF _ _ _ _ z])
         apply (insert sm1 sm2)
         apply (fastforce dest: submonad.fn_is_sm simp: submonad_def)+
     apply (simp add: efim efim' gp gp')+
  done
557
558end
559