1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory WPEx
8imports
9  NonDetMonadVCG
10  Strengthen
11begin
12
13text \<open>WPEx - the WP Extension Experiment\<close>
14
(* mresults f: all triples (return value, final state, initial state) such
   that running f from the initial state can produce that return value and
   final state. Only the result set (fst) is used; the failure flag of f is
   ignored. *)
definition
  mresults :: "('s, 'a) nondet_monad \<Rightarrow> ('a \<times> 's \<times> 's) set"
where
 "mresults f = {(rv, s', s). (rv, s') \<in> fst (f s)}"
19
(* assert_value_exported (rv, s) f y behaves like f when the previous return
   value y equals rv and the current state equals s, and fails otherwise.
   The annotation machinery below inserts it after binds so that the
   intermediate return value and state get names. *)
definition
  assert_value_exported :: "'x \<times> 's \<Rightarrow> ('s, 'a) nondet_monad \<Rightarrow> ('x \<Rightarrow> ('s, 'a) nondet_monad)"
where
 "assert_value_exported x f y \<equiv>
    do s \<leftarrow> get; if x = (y, s) then f else fail od"
25
(* Concrete syntax `v =<= a` as a do-block binding (for both do and doE).
   It binds the result of a to assert_value_exported v applied to the rest
   of the block. The translations are bidirectional (==), so annotated
   terms are also printed this way. *)
syntax
  "_assert_bind" :: "['a, 'b] => dobind" ("_ =<= _")

translations
  "do v =<= a; e od" == "a >>= CONST assert_value_exported v e"
  "doE v =<= a; e odE" == "a >>=E CONST assert_value_exported v e"
32
(* A run of assert_value_exported (rv', s'') f rv'' is a run of f in which
   the exported value and state match the actual previous return value and
   the initial state. *)
lemma in_mresults_export:
  "(rv, s', s) \<in> mresults (assert_value_exported (rv', s'') f rv'')
      = ((rv, s', s) \<in> mresults f \<and> rv' = rv'' \<and> s'' = s)"
  by (simp add: assert_value_exported_def mresults_def in_monad)
37
(* A run of a >>= b splits into a run of a (ending in rv', s'') followed by
   a run of b rv' starting from s''. *)
lemma in_mresults_bind:
  "(rv, s', s) \<in> mresults (a >>= b)
       = (\<exists>rv' s''. (rv, s', s'') \<in> mresults (b rv') \<and> (rv', s'', s) \<in> mresults a)"
  apply (simp add: mresults_def bind_def)
  apply (auto elim: rev_bexI)
  done
44
(* Projections of a run of an annotated bind a >>= assert_value_exported
   (rv', s'') b: a run of b from the exported state s'', and a run of a that
   returns the exported value rv' in state s''. Used by dest_mresults_tac
   below. *)
lemma mresults_export_bindD:
  "(rv, s', s) \<in> mresults (a >>= assert_value_exported (rv', s'') b)
       \<Longrightarrow> (rv, s', s'') \<in> mresults b"
  "(rv, s', s) \<in> mresults (a >>= assert_value_exported (rv', s'') b)
       \<Longrightarrow> (rv', s'', s) \<in> mresults a"
  by (simp_all add: in_mresults_export in_mresults_bind)
51
(* Identity functions used only as markers. wpex_name_for_id works on HOL
   values; it tags the postcondition in valid_strengthen_with_mresults.
   wpex_name_for_id_prop is the same for meta-level propositions. *)
definition "wpex_name_for_id = id"

definition "wpex_name_for_id_prop p \<equiv> (p :: prop)"
55
(* Introduction and elimination rules for the wpex_name_for_id_prop marker. *)
lemma wpex_name_for_id_propI:
  "PROP p \<Longrightarrow> PROP wpex_name_for_id_prop p"
  by (simp add: wpex_name_for_id_prop_def)

lemma wpex_name_for_id_propE:
  "PROP wpex_name_for_id_prop p \<Longrightarrow> PROP p"
  by (simp add:  wpex_name_for_id_prop_def)
63
(* Discards its first premise. del_asm_tac (ML below) instantiates ?P to a
   specific hypothesis and uses this rule as an elimination rule to delete
   that hypothesis from a subgoal. *)
lemma del_asm_rule:
  "\<lbrakk> PROP P; PROP Q \<rbrakk> \<Longrightarrow> PROP Q"
  by assumption
67
68ML \<open>
69
(* The schematic variable ?P :: prop, as an (indexname, type) pair. It is
   the key used by del_asm_tac to instantiate del_asm_rule. *)
val p_prop_var = @{term "P :: prop"} |> Logic.varify_global |> Term.dest_Var;
71
(* del_asm_tac asm i: delete the hypothesis asm (a cterm) from subgoal i.
   It eresolves with del_asm_rule after instantiating the rule's ?P to asm. *)
fun del_asm_tac asm =
  let
    val rule = @{thm del_asm_rule}
      |> Thm.instantiate ([], [(p_prop_var, asm)]);
  in
    eresolve0_tac [rule]
  end;
74
(* subgoal_asm_as_thm tac ctxt i: for each hypothesis of subgoal i in turn,
   focus on the subgoal's premises and call `tac` with that hypothesis as a
   theorem. On the first hypothesis for which this succeeds, the original
   copy of the hypothesis is then deleted (del_asm_tac) from every subgoal
   the tactic produced.
   A hypothesis that cannot be found (aconv) among the focused premises is
   skipped with no_tac, so FIRST moves on to the next one. Previously this
   raised Bind from a non-exhaustive `val (x :: _) = ...` binding. This can
   presumably happen when FOCUS_PREMS imports schematic variables as
   fixed frees. *)
fun subgoal_asm_as_thm tac =
  Subgoal.FOCUS_PARAMS (fn focus => SUBGOAL (fn (t, _) => let
    val asms = Logic.strip_assums_hyp t;
    val ctxt = #context focus;
    fun asm_tac asm = (Subgoal.FOCUS_PREMS (fn focus => let
        fun is_asm asm' = asm aconv (Thm.concl_of asm');
      in
        case find_first is_asm (#prems focus) of
          SOME asm' => tac asm'
        | NONE => no_tac
      end) (#context focus)
        THEN_ALL_NEW del_asm_tac (Thm.cterm_of ctxt asm)) 1;
  in
    FIRST (map asm_tac asms)
  end) 1);
87
(* Signals "not applicable / nothing to do". It is raised by eta_flat and
   build_annotate, and caught in const_spine and annotate_tac. *)
exception SAME;
89
(* Eta-contract an abstraction as far as possible. It goes under nested
   abstractions and into an abstraction in argument position.
   Raises SAME when the required contraction is impossible: the term is not
   of one of the handled shapes, or the bound variable still occurs in the
   function part of the body. *)
fun eta_flat (Abs (name, tp, (Abs a)))
        = eta_flat (Abs (name, tp, eta_flat (Abs a)))
  (* %x. t x with x not loose in t contracts to t; subst_bound with Bound 0
     lowers the remaining loose bound indices of t by one *)
  | eta_flat (Abs (_, _, t $ Bound 0))
        = if member (=) (loose_bnos t) 0 then raise SAME
          else subst_bound (Bound 0, t)
  (* contract the argument abstraction first, then retry the outer one *)
  | eta_flat (Abs (name, tp, t $ Abs a))
        = eta_flat (Abs (name, tp, t $ eta_flat (Abs a)))
  | eta_flat _ = raise SAME;
98
(* Split t into its head constant and arguments, SOME ((name, type), args).
   A bare abstraction is first eta-contracted with eta_flat. The result is
   NONE if that fails or if the head is not a constant. An abstraction
   applied to arguments (a beta-redex) is an error. *)
fun const_spine t =
  let
    val (head, args) = strip_comb t;
  in
    case head of
      Const c => SOME (c, args)
    | Abs v =>
        if null args
        then (const_spine (eta_flat (Abs v)) handle SAME => NONE)
        else error "const_spine: term not beta expanded"
    | _ => NONE
  end;
104
(* Annotate every bind of the monadic term t.
   Only the wrapper context wr = "WPEx.mresults" is handled; for anything
   else t is returned unchanged.
   ps lists the (name, type) pairs of the variables introduced so far.
   Entry i is referred to as Bound i; build_annotate later binds the
   variables with existentials so that this numbering holds.
   For a bind a >>= b:
   - a is annotated recursively;
   - if b already is assert_value_exported rvs c, only c is annotated;
   - otherwise b is replaced by assert_value_exported (rv, s) (b rv), where
     rv and s are two new variables appended to ps, and b rv is annotated
     recursively.
   Returns the annotated term and the extended variable list. Terms that are
   not binds are returned unchanged. *)
fun build_annotate' t wr ps = case (const_spine t, wr) of
    (SOME (bd as ("NonDetMonad.bind", _), [a, b]),
     "WPEx.mresults") => let
           val (a', ps') = build_annotate' a "WPEx.mresults" ps;
         in case const_spine b of
             (* already annotated: keep the export, annotate its body *)
             SOME (ass as ("WPEx.assert_value_exported", _), [rvs, c])
                 => let
                      val (c', ps'') = build_annotate' c "WPEx.mresults" ps'
                    in (Const bd $ a' $ (Const ass $ rvs $ c'), ps'') end
          | _ => let
                  (* btp is the type of the continuation b, decomposed as
                     rtp => stp => mtp (return value, state, monad result) *)
                  val tp  = fastype_of (Const bd);
                  val btp = domain_type (range_type tp);
                  val rtp = domain_type btp;
                  val stp = domain_type (range_type btp);
                  val mtp = range_type (range_type btp);
                  (* assert_value_exported instantiated at 'x := rtp *)
                  val ass = Const ("WPEx.assert_value_exported",
                                   HOLogic.mk_prodT (rtp, stp) -->
                                    (stp --> mtp) --> rtp --> stp --> mtp);
                  (* the two new variables take the next free indices, which
                     are their positions in ps'' below *)
                  val rv  = Bound (length ps');
                  val s   = Bound (length ps' + 1);
                  val rvs = HOLogic.pair_const rtp stp $ rv $ s;
                  (* the continuation applied to the new return-value variable *)
                  val b'  = betapply (b, Bound (length ps'));
                  (* Names for rv/s: the continuation's own bound-variable
                     name unless it is uninformative, else <head constant of
                     a>_rv / _st, else rv / s. *)
                  val borings = ["x", "y", "rv"];
                  val rvnms = case b of
                      Abs (rvnm, _, _) =>
                          if member (=) borings rvnm then []
                          else [(rvnm, rvnm ^ "_st")]
                    | _ => [];
                  val cnms = case const_spine a' of
                      SOME ((cnm, _), _) => let
                          val cnm' = List.last (space_explode "." cnm);
                        in [(cnm' ^ "_rv", cnm' ^ "_st")] end
                    | _ => [];
                  val nms = hd (rvnms @ cnms @ [("rv", "s")]);
                  val ps'' = ps' @ [(fst nms, rtp), (snd nms, stp)];
                  val (b'', ps''') = build_annotate' b' "WPEx.mresults" ps'';
               in (Const bd $ a' $ (ass $ rvs $ b''), ps''') end
         end
  | _ => (t, ps);
144
(* Given the statement asm of a hypothesis of the form
     Trueprop (x : mresults m),
   build the annotated statement
     Trueprop (EX v1 .. vk. x : mresults m')
   Here m' is m with its binds annotated by build_annotate', and v1 .. vk
   are the variables build_annotate' introduced.
   Raises SAME (which annotate_tac turns into no_tac) if:
   - asm is not a HOL proposition, e.g. a meta-level quantified or
     implication premise (dest_Trueprop's TERM exception used to escape
     here and abort the whole method);
   - asm is not an mresults membership;
   - nothing was annotated. *)
fun build_annotate asm =
  let
    val prop = HOLogic.dest_Trueprop (Envir.beta_norm asm)
      handle TERM _ => raise SAME;
  in
    case const_spine prop of
      SOME (memb as ("Set.member", _), [x, st]) => (case const_spine st of
          SOME (mres as ("WPEx.mresults", _), [m]) => let
                val (m', ps) = build_annotate' m "WPEx.mresults" [];
                val _ = if null ps then raise SAME else ();
                val t = Const memb $ x $ (Const mres $ m');
                (* ps[i] is Bound i in t, so quantify from the last entry
                   outwards: ps[0] ends up as the innermost binder *)
                fun mk_exists ((s, tp), tm) = HOLogic.exists_const tp $ Abs (s, tp, tm);
              in HOLogic.mk_Trueprop (Library.foldr mk_exists (rev ps, t)) end
        | _ => raise SAME)
    | _ => raise SAME
  end;
155
156
(* put_Lib_simpset ctxt: replace the simpset of ctxt by the global simpset
   of theory Lib. That simpset is computed once, when this ML block is
   evaluated. *)
val put_Lib_simpset =
  @{theory Lib} |> Proof_Context.init_global |> Simplifier.simpset_of |> put_simpset;
158
159
(* ctxt with the Lib simpset, extended by the mresults rules
   in_mresults_export and in_mresults_bind, and with the if_split splitter
   removed. *)
fun in_mresults_ctxt ctxt =
  Splitter.del_split @{thm if_split}
    (put_Lib_simpset ctxt addsimps [@{thm in_mresults_export}, @{thm in_mresults_bind}]);
164
(* Prove the closed statement `term` (no fixed variables, no assumptions)
   with `tac`. The quick_and_dirty shortcut through cheat_tac is
   deliberately switched off by the `andalso false`, so `tac` is always
   run. *)
fun prove_qad ctxt term tac =
  let
    val cheat = Config.get ctxt quick_and_dirty andalso false;
    val proof = if cheat then ALLGOALS (Skip_Proof.cheat_tac ctxt) else tac;
  in
    Goal.prove ctxt [] [] term (fn _ => proof)
  end;
169
(* Simpset applied to a hypothesis before annotation. It is HOL_basic_ss
   plus K_bind_def, so it only unfolds K_bind. *)
fun preannotate_ss ctxt =
  simpset_of (put_simpset HOL_basic_ss ctxt addsimps [@{thm K_bind_def}]);
174
(* Simpset used by annotate_tac to prove its annotations: the Lib simpset
   plus in_mresults_export / in_mresults_bind, without if_split.
   It is taken from in_mresults_ctxt rather than repeating that pipeline,
   so the two definitions cannot drift apart. *)
fun in_mresults_ss ctxt = simpset_of (in_mresults_ctxt ctxt);
180
181
(* Classical rule set of theory Lib. It is paired with in_mresults_ss when
   annotate_tac proves its annotations. *)
val in_mresults_cs = @{theory Lib} |> Proof_Context.init_global |> Classical.claset_of;
183
(* Tactic built from the hypothesis asm (a theorem). Steps:
   1. normalise asm with preannotate_ss;
   2. compute its annotated form with build_annotate;
   3. prove  asm ==> annotated  with auto, then blast on any remaining
      goals;
   4. insert the resulting annotated fact into subgoal 1.
   Yields no_tac when build_annotate raises SAME (nothing to annotate). *)
fun annotate_tac ctxt asm =
  let
    val normalised = simplify (put_simpset (preannotate_ss ctxt) ctxt) asm;
    val prem = Thm.concl_of normalised;
    val annotated = build_annotate prem;
    val prover_ctxt = ctxt
      |> put_simpset (in_mresults_ss ctxt)
      |> Classical.put_claset in_mresults_cs;
    val prover = auto_tac prover_ctxt THEN ALLGOALS (TRY o blast_tac prover_ctxt);
    val imp = prove_qad ctxt (Logic.mk_implies (prem, annotated)) prover;
  in
    cut_facts_tac [normalised RS imp] 1
  end
  handle SAME => no_tac;
195
(* On subgoal 1, repeatedly (at least once, deterministically) do one of:
   - annotate one of its hypotheses with annotate_tac;
   - eliminate an existential hypothesis with exE.
   Stop when neither applies. *)
fun annotate_goal_tac ctxt =
  let
    val annotate_some_asm = subgoal_asm_as_thm (annotate_tac ctxt) ctxt 1;
    val elim_exists = eresolve_tac ctxt [exE] 1;
  in
    REPEAT_DETERM1 (annotate_some_asm ORELSE elim_exists)
  end;
199
(* Method parser for `annotate`: takes no arguments and runs
   annotate_goal_tac. *)
val annotate_method : (Proof.context -> Method.method) context_parser =
  Scan.succeed (Method.SIMPLE_METHOD o annotate_goal_tac);
203
204\<close>
205
(* annotate: annotate the mresults hypotheses of the first subgoal and
   eliminate the resulting existentials (annotate_goal_tac). *)
method_setup annotate = \<open>annotate_method\<close> "tries to annotate"
207
(* From a recorded run of f and a valid Hoare triple for f: the
   precondition at the initial state implies the postcondition at the
   returned value and final state. *)
lemma use_valid_mresults:
  "\<lbrakk> (rv, s', s) \<in> mresults f; \<lbrace>P\<rbrace> f \<lbrace>Q\<rbrace> \<rbrakk> \<Longrightarrow> P s \<longrightarrow> Q rv s'"
  by (auto simp: mresults_def valid_def)
211
(* Converse of use_valid_mresults: a Hoare triple holds if P s --> Q rv s'
   for every recorded run of f. wpx_tac tries this rule first. *)
lemma mresults_validI:
  "\<lbrakk> \<And>rv s' s. (rv, s', s) \<in> mresults f \<Longrightarrow> P s \<longrightarrow> Q rv s' \<rbrakk>
        \<Longrightarrow> \<lbrace>P\<rbrace> f \<lbrace>Q\<rbrace>"
  by (auto simp: mresults_def valid_def)
216
217ML \<open>
218
(* Theorems from above, bound as ML values for get_wp_simps_strgs and
   dest_mresults_tac. *)
val use_valid_mresults = @{thm use_valid_mresults};

val mresults_export_bindD = @{thms mresults_export_bindD};
222
(* Resolve an mresults fact about  a >>= assert_value_exported ...  with
   each mresults_export_bindD rule, giving facts about the two halves of
   the bind. Typed as a tactic so that REPEAT can iterate it. *)
fun dest_mresults_tac thm =
  let
    val halves = [thm] RL mresults_export_bindD;
  in
    Seq.of_list halves
  end;
224
(* Take a rule whose conclusion has the form p --> q and decide how to use it.
   - If p and q apply the same schematic variable to the same number of
     arguments (?P x1 .. xk --> ?P y1 .. yk), produce one rule per argument
     position n. Each instantiates ?P with %a1 .. ak. a_n = x_n and is
     simplified with HOL_ss, intended to yield y_n = x_n. These go in the
     first component (equations).
   - If q is headed by a schematic variable in any other way, drop the rule.
   - Otherwise return the rule itself in the second component, for use as an
     introduction (strengthening) rule.
   Errors if the shared schematic predicate does not have range type bool. *)
fun get_rule_uses ctxt rule =
  let
    val (p, q) = rule
      |> Thm.concl_of
      |> Envir.beta_eta_contract
      |> HOLogic.dest_Trueprop
      |> HOLogic.dest_imp;
    fun mk_eqthm pvar (n, (x, _)) =
      let
        val (pname, ptp) = Term.dest_Var pvar;
        val (argtps, restp) = strip_type ptp;
        val _ = if restp = @{typ bool} then ()
                else error "get_rule_uses: range type <> bool";
        val eq = HOLogic.eq_const (nth argtps (n - 1))
                   $ Bound (length argtps - n) $ x;
        val inst = fold_rev Term.abs (map (pair "x") argtps) eq;
      in
        rule
        |> Thm.instantiate ([], [((pname, ptp), Thm.cterm_of ctxt inst)])
        |> simplify (put_simpset HOL_ss ctxt)
      end;
  in
    case (strip_comb p, strip_comb q) of
      ((pv as Var _, pargs), (qv as Var _, qargs)) =>
        if pv = qv andalso length pargs = length qargs
        then (map (mk_eqthm pv) ((1 upto length pargs) ~~ (pargs ~~ qargs)), [])
        else ([], [])
    | (_, (Var _, _)) => ([], [])
    | _ => ([], [rule])
  end;
252
(* Collect (simps, strgs) from the mresults hypotheses asms:
   1. split them into per-bind facts with dest_mresults_tac;
   2. turn these into  Hoare triple ==> P s --> Q rv s'  via
      use_valid_mresults;
   3. resolve with the premise-free wp rules (the given `rules` plus the
      WeakestPre rule set);
   4. classify each result with get_rule_uses. *)
fun get_wp_simps_strgs ctxt rules asms =
  let
    val all_rules = rules @ WeakestPre.dest_rules (#rules (WeakestPre.debug_get ctxt));
    val unconditional = filter (null o Thm.prems_of) all_rules;
    val facts = maps (Seq.list_of o REPEAT dest_mresults_tac) asms;
    val instantiated = unconditional RL (facts RL [use_valid_mresults]);
    val (simps, strgs) = map_split (get_rule_uses ctxt) instantiated;
  in
    (flat simps, flat strgs)
  end;
262
(* tac_with_wp_simps_strgs ctxt rules tac i: for the first hypothesis of
   subgoal i on which this succeeds, compute (simps, strgs) from it with
   get_wp_simps_strgs. Then re-insert the hypothesis as a fact and run
   tac (simps, strgs). *)
fun tac_with_wp_simps_strgs ctxt rules tac =
  let
    fun with_asm asm =
      let
        val simps_strgs = get_wp_simps_strgs ctxt rules [asm];
      in
        cut_facts_tac [asm] 1 THEN tac simps_strgs
      end;
  in
    subgoal_asm_as_thm with_asm ctxt
  end;
269
(* Bound as an ML value for wpx_tac. *)
val mresults_validI = @{thm mresults_validI};
271
(* Simpset for normalising postconditions: HOL_basic_ss plus pred_conj_def. *)
fun postcond_ss ctxt =
  simpset_of (put_simpset HOL_basic_ss ctxt addsimps [@{thm pred_conj_def}]);
276
(* Base simpset for applying the derived wp simp rules: HOL_ss without the
   if_split splitter. *)
fun wp_default_ss ctxt =
  simpset_of (Splitter.del_split @{thm if_split} (put_simpset HOL_ss ctxt));
281
(* Debugging aid: a tactic that raises ERROR s when it is run, not when it
   is built. Not used within this theory. *)
fun raise_tac s =
  let
    fun fail_now _ = error s;
  in
    all_tac THEN fail_now
  end;
283
(* Experimental wp proof on subgoal 1. Steps:
   1. try mresults_validI to move from a Hoare triple to mresults runs;
   2. normalise the postcondition with postcond_ss;
   3. try annotate_goal_tac;
   4. for some hypothesis, repeatedly either simplify with the derived wp
      equations (making progress each time) or strengthen with the derived
      introduction rules. *)
fun wpx_tac ctxt rules =
  let
    fun wp_step (simps, strgs) =
      CHANGED (full_simp_tac (put_simpset (wp_default_ss ctxt) ctxt addsimps simps) 1)
      ORELSE Strengthen.default_strengthen ctxt strgs 1;
  in
    TRY (resolve_tac ctxt [mresults_validI] 1)
    THEN full_simp_tac (put_simpset (postcond_ss ctxt) ctxt) 1
    THEN TRY (annotate_goal_tac ctxt)
    THEN tac_with_wp_simps_strgs ctxt rules (REPEAT_DETERM1 o wp_step) 1
  end;
293
(* Method parser for `wpx`: an optional list of extra wp rules. *)
val wpx_method =
  Attrib.thms >> (fn ts => fn ctxt => Method.SIMPLE_METHOD (wpx_tac ctxt ts));
296
297\<close>
298
(* wpx: experimental wp method (wpx_tac); accepts extra wp rules. *)
method_setup wpx = \<open>wpx_method\<close> "experimental wp method"
300
(* Examples: annotate followed by wpx, and wpx on its own (followed by
   simp). *)
lemma foo:
  "(rv, s', s) \<in> mresults (do x \<leftarrow> get; y \<leftarrow> get; put (x + y :: nat); return () od)
        \<Longrightarrow> s' = s + s"
  apply annotate
  apply wpx
  done

lemma foo2:
  "(rv, s', s) \<in> mresults (do x \<leftarrow> get; y \<leftarrow> get; put (if z = Suc 0 then x + y else x + y + z); return () od)
        \<Longrightarrow> s' = s + s + (if z = Suc 0 then 0 else z)"
  apply wpx
  apply simp
  done
314
text \<open>Having experimented with this, the known issues are:
  1: Non-linear code still needs to be dealt with (known issue).
  2: The proof automation used in annotate is not good enough and needs to be
     improved. This note originally referred to fastforce; the code now uses
     auto followed by blast. Probably half of the problem is that too many simp
     rules are available.
  3: Related to (2), it is unclear whether code can be simplified enough once it
     has been annotated. This may bring back the idea of annotation on demand.
  4: It is hard to tell whether the method has worked or not.
  5: Structural rules don't really work. These are rules that want to transform
     the whole postcondition once a particular point is reached. Related to (4),
     it is hard to say when that point is reached.
  6: Computing the set of available rules has performance problems.
\<close>
327
(* Reduce a Hoare triple with precondition (%s. P s s) to two goals:
   - a triple whose pre- and postconditions are also parameterised by the
     initial state prev_s;
   - an obligation that, for every recorded run of f, the tagged
     postcondition wpex_name_for_id (Q' s rv s') implies Q rv s'.
   This is the first step of wps_tac. *)
lemma valid_strengthen_with_mresults:
  "\<lbrakk> \<And>s rv s'. \<lbrakk> (rv, s', s) \<in> mresults f;
           wpex_name_for_id (Q' s rv s') \<rbrakk> \<Longrightarrow> Q rv s';
       \<And>prev_s. \<lbrace>P prev_s\<rbrace> f \<lbrace>Q' prev_s\<rbrace> \<rbrakk>
     \<Longrightarrow> \<lbrace>\<lambda>s. P s s\<rbrace> f \<lbrace>Q\<rbrace>"
  apply atomize
  apply (clarsimp simp: valid_def mresults_def wpex_name_for_id_def)
  apply blast
  done
337
(* Strip the wpex_name_for_id marker; this is the last step of wps_tac. *)
lemma wpex_name_for_idE: "wpex_name_for_id P \<Longrightarrow> P"
  by (simp add: wpex_name_for_id_def)
340
341ML \<open>
342
(* Theorems from above, bound as ML values for wps_tac. *)
val valid_strengthen_with_mresults = @{thm valid_strengthen_with_mresults};
val wpex_name_for_idE = @{thm wpex_name_for_idE};
345
(* wps_tac ctxt rules i: experimental wp-by-simplification on subgoal i.
   1. Resolve with valid_strengthen_with_mresults; the Hoare triple for Q'
      remains as a further subgoal.
   2. Normalise the new subgoal i with postcond_ss (safe_simp_tac).
   3. In a focus on it, simplify with the wp equations derived from its
      premises; CHANGED requires progress.
   4. Eliminate the wpex_name_for_id marker.
   The context is made invisible to suppress duplicate simp rule warnings. *)
fun wps_tac ctxt rules =
  let
    val quiet = Context_Position.set_visible false ctxt;
    val simp_postcond = safe_simp_tac (put_simpset (postcond_ss quiet) quiet);
    fun simp_with_wp (focus : Subgoal.focus) =
      let
        val fctxt = #context focus;
        val (simps, _) = get_wp_simps_strgs fctxt rules (#prems focus);
      in
        CHANGED (simp_tac (put_simpset (wp_default_ss fctxt) fctxt addsimps simps) 1)
      end;
  in
    resolve_tac quiet [valid_strengthen_with_mresults]
    THEN' simp_postcond
    THEN' Subgoal.FOCUS simp_with_wp quiet
    THEN' eresolve_tac quiet [wpex_name_for_idE]
  end;
359
(* Method parser for `wps`: an optional list of extra wp rules. *)
val wps_method =
  Attrib.thms >> (fn ts => fn ctxt => Method.SIMPLE_METHOD' (wps_tac ctxt ts));
362
363\<close>
364
(* wps: experimental wp-by-simplification method (wps_tac); accepts extra
   wp rules. *)
method_setup wps = \<open>wps_method\<close> "experimental wp simp method"
366
(* Exploratory use of wps inside a manual wp proof; left unfinished (oops). *)
lemma foo3:
  "\<lbrace>P\<rbrace> do v \<leftarrow> return (Suc 0); return (Suc (Suc 0)) od \<lbrace>(=)\<rbrace>"
  apply (rule hoare_pre)
   apply (rule hoare_seq_ext)+
    apply (wps | rule hoare_vcg_prop)+
  oops
373
374end
375