1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7(*
 * Tactic (exposed as the proof method "monad_eq") for solving monadic
 * equalities, such as:
 *
 *   liftE (return 3) = returnOk 3
11 *
12 * Theorems of the form:
13 *
14 *   ((a, s') \<in> fst (A s)) = P a s s'
15 *
16 * and
17 *
18 *   snd (A s) = P s
19 *
20 * are added to the "monad_eq" set.
21 *)
22theory MonadEq
23imports NonDetMonadVCG
24begin
25
(* Setup the "monad_eq" named theorem collection. *)
ML \<open>
(* Dynamic collection of theorems; the "monad_eq" attribute adds to or
   removes from it, and the monad_eq tactic uses its contents as simp rules. *)
structure MonadEqThms = Named_Thms
(
  val description = "monad equality-prover theorems"
  val name = Binding.name "monad_eq"
)
\<close>
33attribute_setup monad_eq = \<open>
34  Attrib.add_del
35    (Thm.declaration_attribute MonadEqThms.add_thm)
36    (Thm.declaration_attribute MonadEqThms.del_thm)\<close>
37  "Monad equality-prover theorems"
38
(* Setup tactic. *)

ML \<open>
(* monad_eq_tac ctxt: run clarsimp on subgoal 1 with the current
   "monad_eq" theorems added as simp rules.  Fails (via CHANGED) if the
   goal is left unchanged. *)
fun monad_eq_tac ctxt =
  let
    (* Work in a context whose position is invisible, so that warnings
       emitted while extending the simpset / simplifying are not printed. *)
    val quiet_ctxt = Context_Position.set_visible false ctxt
    val eq_thms = MonadEqThms.get quiet_ctxt
    val simp_ctxt = quiet_ctxt addsimps eq_thms
  in
    CHANGED (clarsimp_tac simp_ctxt 1)
  end
\<close>
50
(* Proof method "monad_eq": accepts the usual clasimp modifiers
   (as parsed by Clasimp.clasimp_modifiers) and then applies monad_eq_tac. *)
method_setup monad_eq = \<open>
  Method.sections Clasimp.clasimp_modifiers
    >> K (fn ctxt => SIMPLE_METHOD (monad_eq_tac ctxt))\<close>
  "prove equality on monads"
54
(* Equality of two monad applications, stated as mutual inclusion of their
   result sets together with equality of their snd components. *)
lemma monad_eq_simp_state [monad_eq]:
  "((A :: ('s, 'a) nondet_monad) s = B s') =
      ((\<forall>r t. (r, t) \<in> fst (A s) \<longrightarrow> (r, t) \<in> fst (B s'))
         \<and> (\<forall>r t. (r, t) \<in> fst (B s') \<longrightarrow> (r, t) \<in> fst (A s))
         \<and> (snd (A s) = snd (B s')))"
  by (auto intro!: set_eqI prod_eqI)
62
(* Extensional equality of two monads: for every initial state, mutual
   inclusion of the result sets and equality of the snd components. *)
lemma monad_eq_simp [monad_eq]:
  "((A :: ('s, 'a) nondet_monad) = B) =
      ((\<forall>r t s. (r, t) \<in> fst (A s) \<longrightarrow> (r, t) \<in> fst (B s))
         \<and> (\<forall>r t s. (r, t) \<in> fst (B s) \<longrightarrow> (r, t) \<in> fst (A s))
         \<and> (\<forall>x. snd (A x) = snd (B x)))"
  by (auto intro!: set_eqI prod_eqI)
70
(* Register the existing membership rules with the monad_eq collection. *)
declare
  in_monad [monad_eq]
  in_bindE [monad_eq]
73
(* Smoke test: checks that the monad_eq method makes progress on a simple
   goal (the method fails if the goal is unchanged); the proof attempt is
   then abandoned with oops. *)
lemma "returnOk 3 = liftE (return 3)"
  apply (monad_eq)
  oops
78
79end
80