1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory SubMonadLib
8imports
9  EmptyFailLib
10  Corres_UL
11begin
12
(* Parameters of the sub-monad construction: fetch reads a sub-state of type 'b
   out of a state of type 'a, replace writes a sub-state back into a state, and
   guard is a predicate on the full state. *)
locale submonad_args =
  fixes fetch :: "'a \<Rightarrow> 'b"
  fixes replace :: "'b \<Rightarrow> 'a \<Rightarrow> 'a"
  fixes guard :: "'a \<Rightarrow> bool"

  (* args(1): when the guard holds, fetch after replace yields the value written.
     args(2): a replace overwrites an earlier replace.
     args(3): writing back the fetched sub-state leaves the state unchanged. *)
  assumes args:
   "\<forall>x s. guard s \<longrightarrow> fetch (replace x s) = x"
   "\<forall>x y s. replace x (replace y s) = replace x s"
   "\<forall>s. replace (fetch s) s = s"

  (* replace does not change whether the guard holds. *)
  assumes replace_preserves_guard:
   "\<And>s x. guard (replace x s) = guard s"
25
(* submonad_fn fetch replace guard m turns a computation m over the sub-state
   into a computation over the full state. In order, it asserts guard with
   stateAssert, reads the sub-state with fetch, runs m on that sub-state through
   select_f, writes the resulting sub-state back with replace, and returns the
   return value of m. *)
definition
  submonad_fn :: "('a \<Rightarrow> 'b) \<Rightarrow> ('b \<Rightarrow> 'a \<Rightarrow> 'a) \<Rightarrow> ('a \<Rightarrow> bool) \<Rightarrow>
                  ('b, 'c) nondet_monad \<Rightarrow> ('a, 'c) nondet_monad"
where
 "submonad_fn fetch replace guard m \<equiv> do
    stateAssert guard [];
    substate \<leftarrow> gets fetch;
    (rv, substate') \<leftarrow> select_f (m substate);
    modify (replace substate');
    return rv
  od"
37
(* Extends submonad_args with a lifting function fn that is assumed to be equal
   to submonad_fn instantiated with the locale's fetch, replace and guard. *)
locale submonad = submonad_args +
  fixes fn :: "('b, 'c) nondet_monad \<Rightarrow> ('a, 'c) nondet_monad"

  assumes fn_is_sm: "fn = submonad_fn fetch replace guard"
42
(* Rule form of args(1): under the guard, fetch after replace returns the
   value that was written. *)
lemma (in submonad_args) argsD1:
  "\<And>x s. guard s \<Longrightarrow> fetch (replace x s) = x"
  apply (simp add: args)
  done
46
(* In a state where the guard holds, fn m behaves like the body of submonad_fn
   without its leading stateAssert. *)
lemma (in submonad) guarded_sm:
  "\<And>s. guard s \<Longrightarrow>
   fn m s = (do
     substate \<leftarrow> gets fetch;
     (rv, substate') \<leftarrow> select_f (m substate);
     modify (replace substate');
     return rv
   od) s"
  unfolding fn_is_sm submonad_fn_def
  apply (simp add: stateAssert_def get_def assert_def bind_def return_def)
  done
57
(* Two consecutive modify steps equal one modify with the composed function. *)
lemma modify_modify:
  "modify fn1 >>= (\<lambda>x. modify fn2) = modify (fn2 \<circ> fn1)"
  apply (simp add: bind_def modify_def get_def put_def)
  done
61
(* A select_f S step may be moved in front of a preceding computation m1,
   provided m1 is empty_fail and the second component of S is set whenever
   its first component is empty. *)
lemma select_f_walk:
  assumes m1: "empty_fail m1"
  assumes S: "fst S = {} \<Longrightarrow> snd S"
  shows "(do a \<leftarrow> m1; b \<leftarrow> select_f S; m2 a b od) = (do b \<leftarrow> select_f S; a \<leftarrow> m1; m2 a b od)"
  (* compare both sides pointwise, then component by component *)
  apply (rule ext)
  apply (rule prod.expand)
  apply (rule conjI)
   (* first components *)
   apply (simp add: select_f_def bind_def split_def)
   apply fastforce
  (* second components: case split on whether fst S is empty *)
  apply (simp add: select_f_def bind_def split_def)
  apply (case_tac "fst S = {}")
   apply clarsimp
   (* nested case split on whether m1 has any result in state x *)
   apply (case_tac "fst (m1 x) = {}")
    apply (simp add: empty_failD [OF m1] S)
   apply (frule S)
   apply force
  apply safe
     apply clarsimp
     apply force
    apply force
   apply clarsimp
   apply force
  apply clarsimp
  apply (case_tac "fst (m1 x) = {}", simp add: empty_failD [OF m1])
  apply force
  done
88
(* Two consecutive stateAsserts equal a single stateAssert of the conjunction. *)
lemma stateAssert_stateAssert:
  "(stateAssert g [] >>= (\<lambda>u. stateAssert g' [])) = stateAssert (g and g') []"
  apply (simp add: ext stateAssert_def bind_def get_def assert_def fail_def return_def)
  done
92
(* If r never changes the value of g, then modify (r x) and stateAssert g
   may be swapped. *)
lemma modify_stateAssert:
  "\<lbrakk> \<And>s x. g (r x s) = g s \<rbrakk> \<Longrightarrow>
   (modify (r x) >>= (\<lambda>u. stateAssert g []))
            = (stateAssert g [] >>= (\<lambda>u. modify (r x)))"
  apply (simp add: ext stateAssert_def bind_def get_def assert_def fail_def
                   return_def modify_def put_def)
  done
99
(* A stateAssert that follows a gets may be moved in front of the gets. *)
lemma gets_stateAssert:
  "(gets f >>= (\<lambda>x. stateAssert g' [] >>= (\<lambda>u. m x)))
            = (stateAssert g' [] >>= (\<lambda>u. gets f >>= (\<lambda>x. m x)))"
  apply (simp add: ext stateAssert_def bind_def gets_def get_def
                   assert_def fail_def return_def)
  done
105
(* For an empty_fail computation m, a stateAssert that follows select_f (m a)
   may be moved in front of the select_f. *)
lemma select_f_stateAssert:
  "empty_fail m \<Longrightarrow>
   (select_f (m a) >>= (\<lambda>x. stateAssert g [] >>= (\<lambda>u. n x))) =
   (stateAssert g [] >>= (\<lambda>u. select_f (m a) >>= (\<lambda>x. n x)))"
  apply (rule ext)
  (* unfold the monad operations, then expand set images via image_def *)
  apply (clarsimp simp: stateAssert_def bind_def select_f_def get_def
                        assert_def return_def fail_def split_def image_image)
  apply (simp only: image_def)
  apply (clarsimp simp: stateAssert_def bind_def select_f_def get_def
                        assert_def return_def fail_def split_def image_image)
  (* unfold the empty_fail assumption so that fastforce can use it *)
  apply (simp only: image_def mem_simps empty_fail_def simp_thms)
  apply fastforce
  done
119
(* select_f of m at s, followed by select_f of the uncurried continuation n,
   equals select_f of the bound computation m >>= n at s. *)
lemma bind_select_f_bind':
  "(select_f (m s) >>= (\<lambda>x. select_f (split n x))) = (select_f ((m >>= n) s))"
  by (rule ext) (force simp: select_f_def bind_def split_def)
125
(* Variant of bind_select_f_bind' with the continuation written using fst and
   snd instead of split. *)
lemma bind_select_f_bind:
  "(select_f (m1 s) >>= (\<lambda>x. select_f (m2 (fst x) (snd x)))) = (select_f ((m1 >>= m2) s))"
  using bind_select_f_bind' [where m=m1 and n=m2 and s=s]
  by (simp add: split_def)
130
(* select_f applied to gets f at state s is return of the pair (f s, s). *)
lemma select_from_gets: "select_f (gets f s) = return (f s, s)"
  by (rule ext) (simp add: select_f_def return_def simpler_gets_def)
135
(* Function-composition form of select_from_gets. *)
lemma select_from_gets':
  "(select_f \<circ> gets f) = (\<lambda>s. return (f s, s))"
  by (rule ext) (simp add: o_def select_from_gets)
141
(* Given an equation h for the composition f >>= g, rewrite the first two steps
   of a three-step bind to h. *)
lemma bind_subst_lift:
  "(f >>= g) = h \<Longrightarrow> (do x \<leftarrow> f; y \<leftarrow> g x; j y od) = (h >>= j)"
  apply (simp add: bind_assoc[symmetric])
  done
145
(* If r preserves g, and f reads back what r wrote whenever g holds, then
   writing x, asserting g and reading f equals asserting g first, writing x and
   returning x. *)
lemma modify_gets:
  "\<lbrakk> \<And>x s. g (r x s) = g s; \<And>x s. g s \<longrightarrow> f (r x s) = x \<rbrakk>
   \<Longrightarrow> (modify (r x) >>= (\<lambda>u. stateAssert g [] >>= (\<lambda>u'. gets f)))
            = (stateAssert g [] >>= (\<lambda>u'. modify (r x) >>= (\<lambda>u. return x)))"
  apply (simp add: ext stateAssert_def assert_def modify_def bind_def get_def
                   put_def gets_def return_def fail_def)
  done
152
(* In a state where the guard holds, fetching the sub-state and immediately
   writing it back is a no-op: the modify step can be dropped.
   The proof unfolds the monad operations and uses the args assumptions. The
   split rule for option that was previously passed to clarsimp has been
   removed, because no option case expression occurs in the goal or in the
   unfolded definitions. *)
lemma (in submonad_args) gets_modify:
  "\<And>s. guard s \<Longrightarrow>
   (do x \<leftarrow> gets fetch; u \<leftarrow> modify (replace x); f x od) s = ((gets fetch) >>= f) s"
  by (clarsimp simp: modify_def gets_def return_def bind_def
                     put_def args get_def)
159
(* Lifting distributes over bind. m, m' and m'' are three names for liftings
   that all satisfy submonad f r g; a and every b x must be empty_fail. *)
lemma submonad_bind:
  "\<lbrakk> submonad f r g m; submonad f r g m'; submonad f r g m'';
     empty_fail a; \<And>x. empty_fail (b x) \<rbrakk> \<Longrightarrow>
   m (a >>= b) = (m' a) >>= (\<lambda>rv. m'' (b rv))"
  (* replace each of m, m', m'' by submonad_fn f r g and unfold it *)
  apply (subst submonad.fn_is_sm, assumption)+
  apply (clarsimp simp: submonad_def bind_assoc split_def submonad_fn_def)
  (* rewrite with modify_gets; its two premises are discharged by the locale facts *)
  apply (subst bind_subst_lift [OF modify_gets, unfolded bind_assoc])
    apply (simp add: submonad_args.args submonad_args.replace_preserves_guard)+
  (* move the inner stateAssert to the front and merge it with the outer one *)
  apply (subst select_f_stateAssert, assumption)
  apply (subst gets_stateAssert)
  apply (subst bind_subst_lift [OF stateAssert_stateAssert])
  apply (clarsimp simp: pred_conj_def)
  apply (clarsimp simp: bind_assoc split_def select_f_walk
                empty_fail_stateAssert empty_failD
                bind_subst_lift[OF modify_modify] submonad_args.args o_def
                bind_subst_lift[OF bind_select_f_bind])
  done
177
(* The guard holds in every result state of fn m. *)
lemma (in submonad) guard_preserved:
  "\<And>s s'. \<lbrakk> (rv, s') \<in> fst (fn m s) \<rbrakk> \<Longrightarrow> guard s'"
  unfolding fn_is_sm submonad_fn_def
  apply (clarsimp simp: stateAssert_def gets_def get_def bind_def modify_def put_def
                        return_def select_f_def replace_preserves_guard in_monad)
  done
183
(* Any result of stateAssert g [] has an unchanged state in which g holds. *)
lemma fst_stateAssertD:
  "\<And>s s' v. (v, s') \<in> fst (stateAssert g [] s) \<Longrightarrow> s' = s \<and> g s"
  apply (clarsimp simp: stateAssert_def in_monad)
  done
187
(* In a state where the guard holds, lifting gets f is gets of f composed with
   fetch. *)
lemma (in submonad) guarded_gets:
  "\<And>s. guard s \<Longrightarrow> fn (gets f) s = gets (f \<circ> fetch) s"
  by (simp add: guarded_sm select_from_gets gets_modify)
     (simp add: gets_def)
193
(* In a state where the guard holds, lifting return x is return x. *)
lemma (in submonad) guarded_return:
  "\<And>s. guard s \<Longrightarrow> fn (return x) s = return x s"
  apply (insert args guarded_gets)
  apply (fastforce simp: gets_def bind_def get_def)
  done
198
(* Lifting gets f with submonad_fn equals asserting the guard and then reading
   f composed with fetch. *)
lemma (in submonad_args) submonad_fn_gets:
  "submonad_fn fetch replace guard (gets f) =
   (stateAssert guard [] >>= (\<lambda>u. gets (f \<circ> fetch)))"
  apply (simp add: ext select_from_gets submonad_fn_def)
  (* both sides start with the same stateAssert; compare the continuations *)
  apply (rule bind_cong [OF refl])
  (* fst_stateAssertD supplies the guard, which gets_modify needs *)
  apply (clarsimp simp: gets_modify dest!: fst_stateAssertD)
  apply (simp add: gets_def)
  done
207
(* Locale version of submonad_fn_gets, stated for fn. *)
lemma (in submonad) gets:
  "fn (gets f) = (stateAssert guard [] >>= (\<lambda>u. gets (f \<circ> fetch)))"
  unfolding fn_is_sm submonad_fn_gets
  apply (rule refl)
  done
212
(* Lifting return x equals asserting the guard and then returning x. *)
lemma (in submonad) return:
  "fn (return x) = (stateAssert guard [] >>= (\<lambda>u. return x))"
  apply (insert args gets)
  apply (fastforce simp: gets_def bind_def get_def)
  done
217
(* If the guard holds initially, it holds in every result state of
   mapM (fn o m) xs. Proved by induction on the list xs. *)
lemma (in submonad) mapM_guard_preserved:
  "\<And>s s'. \<lbrakk> guard s; \<exists>rv. (rv, s') \<in> fst (mapM (fn \<circ> m) xs s)\<rbrakk> \<Longrightarrow> guard s'"
proof (induct xs)
  case Nil
  thus ?case
    by (simp add: mapM_def sequence_def return_def)
  next
  case (Cons x xs)
  thus ?case
    (* the first step preserves the guard by guard_preserved; the rest follows
       from the induction hypothesis *)
    apply (clarsimp simp: o_def mapM_Cons return_def bind_def)
    apply (drule guard_preserved)
    apply fastforce
    done
qed
232
(* If the guard holds initially, it holds in every result state of
   mapM_x (fn o m) xs. Proved by induction on the list xs; the proof has the
   same shape as the one for mapM_guard_preserved. *)
lemma (in submonad) mapM_x_guard_preserved:
  "\<And>s s'. \<lbrakk> guard s; \<exists>rv. (rv, s') \<in> fst (mapM_x (fn \<circ> m) xs s)\<rbrakk> \<Longrightarrow> guard s'"
proof (induct xs)
  case Nil
  thus ?case
    by (simp add: mapM_x_def sequence_x_def return_def)
  next
  case (Cons x xs)
  thus ?case
    (* the first step preserves the guard by guard_preserved; the rest follows
       from the induction hypothesis *)
    apply (clarsimp simp: o_def mapM_x_Cons return_def bind_def)
    apply (drule guard_preserved)
    apply fastforce
    done
qed
247
(* Asserting the guard in front of fn m is redundant. *)
lemma (in submonad) stateAssert_fn:
  "stateAssert guard [] >>= (\<lambda>u. fn m) = fn m"
  apply (simp add: fn_is_sm submonad_fn_def pred_conj_def
                   bind_subst_lift [OF stateAssert_stateAssert])
  done
252
(* Asserting the guard directly after fn m is redundant. *)
lemma (in submonad) fn_stateAssert:
  "fn m >>= (\<lambda>x. stateAssert guard [] >>= (\<lambda>u. n x)) = (fn m >>= n)"
  apply (simp add: fn_is_sm submonad_fn_def bind_assoc split_def)
  apply (rule ext)
  (* peel off the common leading steps of both sides *)
  apply (rule bind_apply_cong [OF refl])+
  apply (clarsimp simp: stateAssert_def bind_assoc in_monad select_f_def)
  (* transfer the guard to the state after replace *)
  apply (drule iffD2 [OF replace_preserves_guard])
  apply (fastforce simp: bind_def assert_def get_def return_def)
  done
262
(* Lifting mapM m l with sm equals asserting the guard and then running mapM
   over the individually lifted computations. sm and sm' are two liftings for
   the same f, r and g; every m x must be empty_fail. Proved by induction on l. *)
lemma submonad_mapM:
  assumes sm: "submonad f r g sm" and sm': "submonad f r g sm'"
  assumes efm: "\<And>x. empty_fail (m x)"
  shows
  "(sm (mapM m l)) = (stateAssert g [] >>= (\<lambda>u. mapM (sm' \<circ> m) l))"
proof (induct l)
  case Nil
  thus ?case
    by (simp add: mapM_def sequence_def bind_def submonad.return [OF sm])
  next
  case (Cons x xs)
  thus ?case
    using sm sm' efm
    apply (simp add: mapM_Cons)
    apply (simp add: bind_subst_lift [OF submonad.stateAssert_fn])
    apply (simp add: bind_assoc submonad_bind submonad.return)
    apply (subst submonad.fn_stateAssert [OF sm'])
    apply (intro ext bind_apply_cong [OF refl])
    (* NOTE(review): sta is not bound in the statement, so it is presumably a
       name generated by the preceding steps; this script depends on it *)
    apply (subgoal_tac "g sta")
     apply (clarsimp simp: stateAssert_def bind_def get_def assert_def return_def)
    (* establish the guard from guard_preserved and mapM_guard_preserved *)
    apply (frule(1) submonad.guard_preserved)
    apply (erule(1) submonad.mapM_guard_preserved, fastforce simp: o_def)
    done
qed
287
(* Lifting mapM_x m l with sm equals asserting the guard and then running
   mapM_x over the individually lifted computations. sm and sm' are two
   liftings for the same f, r and g; every m x must be empty_fail. Proved by
   induction on l. *)
lemma submonad_mapM_x:
  assumes sm: "submonad f r g sm" and sm': "submonad f r g sm'"
  assumes efm: "\<And>x. empty_fail (m x)"
  shows
  "(sm (mapM_x m l)) = (stateAssert g [] >>= (\<lambda>u. mapM_x (sm' \<circ> m) l))"
proof (induct l)
  case Nil
  thus ?case
    by (simp add: mapM_x_def sequence_x_def bind_def submonad.return [OF sm])
  next
  case (Cons x xs)
  thus ?case
    using sm sm' efm
    apply (simp add: mapM_x_Cons)
    apply (simp add: bind_subst_lift [OF submonad.stateAssert_fn])
    apply (simp add: bind_assoc submonad_bind submonad.return)
    apply (subst submonad.fn_stateAssert [OF sm'])
    apply (intro ext bind_apply_cong [OF refl])
    (* NOTE(review): st is not bound in the statement, so it is presumably a
       name generated by the preceding steps; this script depends on it *)
    apply (subgoal_tac "g st")
     apply (clarsimp simp: stateAssert_def bind_def get_def assert_def return_def)
    (* establish the guard from guard_preserved *)
    apply (frule(1) submonad.guard_preserved, simp)
    done
qed
311
(* select S corresponds to select S' under trivial preconditions when every
   element of S' is related by rvr to some element of S. *)
lemma corres_select:
  "(\<forall>s' \<in> S'. \<exists>s \<in> S. rvr s s') \<Longrightarrow> corres_underlying sr nf nf' rvr \<top> \<top> (select S) (select S')"
  apply (clarsimp simp: select_def corres_underlying_def)
  done
315
(* select_f S corresponds to select_f S' under trivial preconditions when every
   element of fst S' is related by rvr to some element of fst S, and snd S' is
   false whenever nf' is set. *)
lemma corres_select_f:
  "\<lbrakk> \<forall>s' \<in> fst S'. \<exists>s \<in> fst S. rvr s s'; nf' \<Longrightarrow> \<not> snd S' \<rbrakk>
      \<Longrightarrow> corres_underlying sr nf nf' rvr \<top> \<top> (select_f S) (select_f S')"
  apply (clarsimp simp: select_f_def corres_underlying_def)
  done
320
(* modify f corresponds to modify f' under trivial preconditions when f and f'
   map related states to related states and r relates the unit results. *)
lemma corres_modify':
  "\<lbrakk> (\<forall>s s'. (s, s') \<in> sr \<longrightarrow> (f s, f' s') \<in> sr); r () () \<rbrakk>
      \<Longrightarrow> corres_underlying sr nf nf' r \<top> \<top> (modify f) (modify f')"
  apply (clarsimp simp: modify_def corres_underlying_def bind_def get_def put_def)
  done
325
(* FIXME: this should only be used for the lemma below *)
(* The statement is identical to corres_select_f, so it is derived from that
   lemma instead of repeating its proof. The name is kept because
   corres_submonad refers to it. *)
lemma corres_select_f_stronger:
  "\<lbrakk> \<forall>s' \<in> fst S'. \<exists>s \<in> fst S. rvr s s'; nf' \<Longrightarrow> \<not> snd S' \<rbrakk>
      \<Longrightarrow> corres_underlying sr nf nf' rvr \<top> \<top> (select_f S) (select_f S')"
  by (fact corres_select_f)
331
(* Hoare triple for stateAssert: the precondition P is kept and the asserted
   predicate Q holds afterwards. *)
lemma stateAssert_sp:
  "\<lbrace>P\<rbrace> stateAssert Q l \<lbrace>\<lambda>_. P and Q\<rbrace>"
  apply (clarsimp simp: valid_def stateAssert_def in_monad)
  done
335
(* Correspondence lifts through the sub-monad construction. Premises: fn and
   fn' are liftings; fetching from sr-related states that satisfy the guards
   gives ssr-related sub-states; replacing ssr-related sub-states in sr-related
   states gives sr-related states; and x corresponds to x' with respect to ssr.
   Conclusion: fn x corresponds to fn' x' with respect to sr, with the guards
   as preconditions. *)
lemma corres_submonad:
  "\<lbrakk> submonad f r g fn; submonad f' r' g' fn';
     \<forall>s s'. (s, s') \<in> sr \<and> g s \<and> g' s' \<longrightarrow> (f s, f' s') \<in> ssr;
     \<forall>s s' ss ss'. ((s, s') \<in> sr \<and> (ss, ss') \<in> ssr) \<longrightarrow> (r ss s, r' ss' s') \<in> sr;
     corres_underlying ssr False nf' rvr \<top> \<top> x x'\<rbrakk>
   \<Longrightarrow> corres_underlying sr False nf' rvr g g' (fn x) (fn' x')"
  apply (subst submonad.fn_is_sm, assumption)+
  apply (clarsimp simp: submonad_fn_def)
  (* split off the stateAssert steps *)
  apply (rule corres_split' [OF _ _ stateAssert_sp stateAssert_sp])
   apply (fastforce simp: corres_underlying_def stateAssert_def get_def
                         assert_def return_def bind_def)
  (* split off the gets steps; the fetched sub-states are related by ssr *)
  apply (rule corres_split' [where r'="\<lambda>x y. (x, y) \<in> ssr",
                             OF _ _ hoare_post_taut hoare_post_taut])
   apply clarsimp
  (* split off the select_f steps; the select_f goal itself is deferred *)
  apply (rule corres_split' [where r'="\<lambda>(x, x') (y, y'). rvr x y \<and> (x', y') \<in> ssr",
                             OF _ _ hoare_post_taut hoare_post_taut])
   defer
   apply clarsimp
   (* the modify steps followed by return *)
   apply (rule corres_split' [where r'=dc, OF _ _ hoare_post_taut hoare_post_taut])
    apply (simp add: corres_modify')
   apply clarsimp
  (* the deferred select_f goal *)
  apply (rule corres_select_f_stronger)
   apply (clarsimp simp: corres_underlying_def)
   apply (drule (1) bspec, clarsimp)
   apply (drule (1) bspec, simp)
   apply blast
  apply (clarsimp simp: corres_underlying_def)
  apply (drule (1) bspec, clarsimp)
  done
365
(* A stateAssert of the always-true predicate in front of f can be dropped.
   Declared as a simp rule. *)
lemma stateAssert_top [simp]:
  "stateAssert \<top> l >>= f = f ()"
  apply (clarsimp simp add: stateAssert_def get_def bind_def return_def)
  done
369
(* A stateAssert of the always-true predicate is return (). Declared as a simp
   rule. *)
lemma stateAssert_A_top [simp]:
  "stateAssert \<top> l = return ()"
  apply (simp add: stateAssert_def get_def bind_def return_def)
  done
373
374text \<open>Use of the submonad concept to demonstrate commutativity.\<close>
375
(* In a state s where f does not change the value of g, modify f and gets g may
   be swapped. *)
lemma gets_modify_comm:
  "\<And> s. \<lbrakk> g (f s) = g s \<rbrakk> \<Longrightarrow>
   (do x \<leftarrow> modify f; y \<leftarrow> gets g; m x y od) s =
   (do y \<leftarrow> gets g; x \<leftarrow> modify f; m x y od) s"
  apply (simp add: modify_def gets_def get_def bind_def put_def return_def)
  done
381
(* Rewrite the continuation of a: if f x >>= g x equals h x in every state
   satisfying P, a preserves P, and P holds in the start state s, then the two
   steps after a may be replaced by h. *)
lemma bind_subst_lhs_inv:
  "\<And>s. \<lbrakk> \<And>x s'. P s' \<Longrightarrow> (f x >>= g x) s' = h x s'; \<lbrace>P\<rbrace> a \<lbrace>\<lambda>_. P\<rbrace>; P s \<rbrakk> \<Longrightarrow>
   (do x \<leftarrow> a; y \<leftarrow> f x; g x y od) s = (a >>= h) s"
  apply (rule bind_apply_cong [OF refl])
  (* use_valid gives P in the state reached after a *)
  apply (drule(2) use_valid)
  apply simp
  done
389
(* Two gets steps may be swapped. *)
lemma gets_comm:
  "do x \<leftarrow> gets f; y \<leftarrow> gets g; m x y od = do y \<leftarrow> gets g; x \<leftarrow> gets f; m x y od"
  apply (simp add: gets_def get_def return_def bind_def)
  done
393
(* Two lifted computations commute. m and m' are liftings of im and im' for two
   parameter sets (f, r, g) and (f', r', g'). The premises require that the two
   replace functions commute (z), that each replace preserves the other guard
   (gp, gp'), and that im and im' are empty_fail. *)
lemma submonad_comm:
  assumes x1: "submonad_args f r g" and x2: "submonad_args f' r' g'"
  assumes y: "m = submonad_fn f r g im" "m' = submonad_fn f' r' g' im'"
  assumes z: "\<And>s x x'. r x (r' x' s) = r' x' (r x s)"
  assumes gp: "\<And>s x. g (r' x s) = g s" and gp': "\<And>s x. g' (r x s) = g' s"
  assumes efim: "empty_fail im" and efim': "empty_fail im'"
  shows      "(do x \<leftarrow> m; y \<leftarrow> m'; n x y od) = (do y \<leftarrow> m'; x \<leftarrow> m; n x y od)"
proof -
  (* P: under g, writing with r' does not change what f reads *)
  have P: "\<And>x s. g s \<Longrightarrow> f (r' x s) = f s"
    apply (subgoal_tac "f (r' x (r (f s) s)) = f s")
     apply (simp add: submonad_args.args[OF x1])
    apply (simp add: z[symmetric])
    apply (subst(asm) gp [symmetric])
    apply (fastforce dest: submonad_args.argsD1[OF x1])
    done
  (* Q: the symmetric fact, under g' writing with r does not change what f' reads *)
  have Q: "\<And>x s. g' s \<Longrightarrow> f' (r x s) = f' s"
    apply (subgoal_tac "f' (r x (r' (f' s) s)) = f' s")
     apply (simp add: submonad_args.args[OF x2])
    apply (simp add: z)
    apply (subst(asm) gp' [symmetric])
    apply (fastforce dest: submonad_args.argsD1[OF x2])
    done
  (* make the empty_fail consequences for im and im' available to simp *)
  note empty_failD [OF efim, simp]
  note empty_failD [OF efim', simp]
  show ?thesis
    apply (clarsimp simp: submonad_fn_def y bind_assoc split_def)
    (* move all stateAsserts to the front and merge them *)
    apply (subst bind_subst_lift [OF modify_stateAssert], rule gp gp')+
    apply (simp add: bind_assoc)
    apply (subst select_f_stateAssert, rule efim efim')+
    apply (subst gets_stateAssert bind_subst_lift [OF stateAssert_stateAssert])+
    apply (rule bind_cong)
     apply (simp add: pred_conj_def conj_comms)
    apply (simp add: bind_assoc select_f_walk[symmetric])
    apply (clarsimp dest!: fst_stateAssertD)
    (* swap modify and gets steps, using P and Q *)
    apply (subst bind_assoc[symmetric],
           subst bind_subst_lhs_inv [OF gets_modify_comm],
           erule P Q, wp, simp, simp)+
    apply (simp add: bind_assoc)
    apply (simp add: select_f_walk[symmetric])
    apply (subst gets_comm)
    apply (rule bind_apply_cong [OF refl])+
    apply (subst select_f_walk, simp, simp,
           subst select_f_walk, simp, simp,
           rule bind_apply_cong [OF refl])
    apply (subst select_f_walk, simp, simp, rule bind_apply_cong [OF refl])
    (* the remaining modify steps commute by z *)
    apply (clarsimp simp: simpler_gets_def select_f_def)
    apply (simp add: bind_def get_def put_def modify_def z)
    done
qed
443
(* Variant of submonad_comm in which the second lifting is given as a submonad
   locale instance m' applied to im', while the first is still given by an
   explicit submonad_fn equation. *)
lemma submonad_comm2:
  assumes x1: "submonad_args f r g" and x2: "m = submonad_fn f r g im"
  assumes y: "submonad f' r' g' m'"
  assumes z: "\<And>s x x'. r x (r' x' s) = r' x' (r x s)"
  assumes gp: "\<And>s x. g (r' x s) = g s" and gp': "\<And>s x. g' (r x s) = g' s"
  assumes efim: "empty_fail im" and efim': "empty_fail im'"
  shows      "do x \<leftarrow> m; y \<leftarrow> m' im'; n x y od = do y \<leftarrow> m' im'; x \<leftarrow> m; n x y od"
  apply (rule submonad_comm[where f'=f' and r'=r', OF x1 _ x2 _ z])
       apply (insert y)
       (* submonad_args for the second parameter set comes from y *)
       apply (fastforce simp add: submonad_def)
      (* the submonad_fn equation for m' comes from fn_is_sm *)
      apply (fastforce dest: submonad.fn_is_sm)
     apply (simp add: efim efim' gp gp')+
  done
457
(* Variant of submonad_bind stated with explicit submonad_fn equations: if a and
   every b rv are liftings of a' and b' rv, then a >>= b is the lifting of
   a' >>= b'. *)
lemma submonad_bind_alt:
  assumes x: "submonad_args f r g"
  assumes y: "a = submonad_fn f r g a'" "\<And>rv. b rv = submonad_fn f r g (b' rv)"
  assumes efa: "empty_fail a'" and efb: "\<And>x. empty_fail (b' x)"
  shows      "(a >>= b) = submonad_fn f r g (a' >>= b')"
proof -
  (* submonad_fn f r g itself satisfies the submonad locale *)
  have P: "submonad f r g (submonad_fn f r g)"
    by (simp add: x submonad_def submonad_axioms_def)
  (* the pointwise equation for b, as an equation between functions *)
  have Q: "b = (\<lambda>rv. submonad_fn f r g (b' rv))"
    by (rule ext) fact+
  show ?thesis
    by (simp add: y Q submonad_bind [OF P P P efa efb])
qed
471
(* With the always-true guard, lifting a computation that has exactly one
   result and does not fail gives a computation of the same shape on the full
   state. *)
lemma submonad_singleton:
  "submonad_fn fetch replace \<top> (\<lambda>s. ({(rv s, s' s)}, False))
     = (\<lambda>s. ({(rv (fetch s), replace (s' (fetch s)) s)}, False))"
  by (rule ext)
     (simp add: submonad_fn_def bind_def gets_def
                put_def get_def modify_def return_def
                select_f_def UNION_eq)
480
(* With the always-true guard, gets f is the lifting of gets f' whenever f
   equals f' applied to the fetched sub-state. *)
lemma gets_submonad:
  "\<lbrakk> submonad_args fetch replace \<top>; \<And>s. f s = f' (fetch s); m = gets f' \<rbrakk>
   \<Longrightarrow> gets f = submonad_fn fetch replace \<top> m"
  by (drule submonad_args.args(3))
     (clarsimp simp add: simpler_gets_def submonad_singleton)
487
(* With the always-true guard, modify f is the lifting of modify f' when f is
   given by replacing the sub-state with K_record of the f'-updated one. *)
lemma modify_submonad:
  "\<lbrakk> \<And>s. f s = replace (K_record (f' (fetch s))) s; m = modify f' \<rbrakk>
     \<Longrightarrow> modify f = submonad_fn fetch (replace o K_record) \<top> m"
  apply (simp add: simpler_modify_def submonad_singleton)
  done
492
(* With the always-true guard, fail is the lifting of fail. *)
lemma fail_submonad:
  "fail = submonad_fn fetch replace \<top> fail"
  apply (simp add: submonad_fn_def simpler_gets_def return_def
                   simpler_modify_def select_f_def bind_def fail_def)
  done
497
(* return v is the lifting of return v with the always-true guard. The premise
   may use any guard; only the submonad_args facts about fetch and replace are
   passed to simp. *)
lemma return_submonad:
  "submonad_args fetch replace guard \<Longrightarrow>
   return v = submonad_fn fetch replace \<top> (return v)"
  apply (simp add: return_def submonad_singleton submonad_args.args)
  done
502
(* With the always-true guard, assert_opt v is the lifting of assert_opt v. *)
lemma assert_opt_submonad:
  "submonad_args fetch replace \<top> \<Longrightarrow>
   assert_opt v = submonad_fn fetch replace \<top> (assert_opt v)"
  (* case split on v; the two cases are closed by fail_submonad and
     return_submonad respectively *)
  apply (case_tac v, simp_all add: assert_opt_def)
   apply (rule fail_submonad)
  apply (rule return_submonad)
  apply assumption
  done
511
(* Characterises f as stateAssert guard followed by gets fetch. The premises
   are: f leaves the state unchanged, f establishes guard, f is empty_fail, f
   does not fail under guard, and under guard f returns fetch of the state. *)
lemma is_stateAssert_gets:
  "\<lbrakk> \<forall>s. \<lbrace>(=) s\<rbrace> f \<lbrace>\<lambda>_. (=) s\<rbrace>; \<lbrace>\<top>\<rbrace> f \<lbrace>\<lambda>_. guard\<rbrace>;
     empty_fail f; no_fail guard f; \<lbrace>guard\<rbrace> f \<lbrace>\<lambda>rv s. fetch s = rv\<rbrace> \<rbrakk>
    \<Longrightarrow> f = do stateAssert guard []; gets fetch od"
  apply (rule ext)
  apply (clarsimp simp: bind_def empty_fail_def valid_def no_fail_def
                        stateAssert_def assert_def gets_def get_def
                        return_def fail_def image_def split_def)
  (* x is the state introduced by ext *)
  apply (case_tac "f x")
  apply (intro conjI impI)
   (* instantiate the quantified premises at x *)
   apply (drule_tac x=x in spec)+
   (* every result of f at x is the pair (fetch x, x) *)
   apply (subgoal_tac "\<forall>xa\<in>fst (f x). fst xa = fetch x \<and> snd xa = x")
    apply fastforce
   apply clarsimp
  apply (drule_tac x=x in spec)+
  apply fastforce
  done
529
(* In a state s satisfying guard, f behaves like modify replace, provided f
   maps state s exactly to replace s, is empty_fail, and does not fail under
   guard. Here replace and guard are variables of this lemma; it is not stated
   inside a locale. *)
lemma is_modify:
  "\<And>s. \<lbrakk> \<lbrace>(=) s\<rbrace> f \<lbrace>\<lambda>_. (=) (replace s)\<rbrace>; empty_fail f;
          no_fail guard f; guard s \<rbrakk>
    \<Longrightarrow> f s = modify replace s"
  apply (clarsimp simp: bind_def empty_fail_def valid_def no_fail_def
                        stateAssert_def assert_def modify_def get_def put_def
                        return_def fail_def image_def split_def)
  apply (case_tac "f s")
  apply force
  done
540
(* Variant of submonad_comm in which both liftings are given as submonad locale
   instances, m applied to im and m' applied to im'. *)
lemma submonad_comm':
  assumes sm1: "submonad f r g m" and sm2: "submonad f' r' g' m'"
  assumes z: "\<And>s x x'. r x (r' x' s) = r' x' (r x s)"
  assumes gp: "\<And>s x. g (r' x s) = g s" and gp': "\<And>s x. g' (r x s) = g' s"
  assumes efim: "empty_fail im" and efim': "empty_fail im'"
  shows      "(do x \<leftarrow> m im; y \<leftarrow> m' im'; n x y od) =
              (do y \<leftarrow> m' im'; x \<leftarrow> m im; n x y od)"
  apply (rule submonad_comm [where f'=f' and r'=r', OF _ _ _ _ z])
         apply (insert sm1 sm2)
         (* the submonad_args facts and submonad_fn equations come from sm1 and sm2 *)
         apply (fastforce dest: submonad.fn_is_sm simp: submonad_def)+
     apply (simp add: efim efim' gp gp')+
  done
553
554end
555