1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
(*
 * Interface of the weakest-precondition (wp) proof method: the rule store it
 * keeps in the context, the tactics that use it, and the attributes used to
 * declare rules.
 *)
signature WP =
sig
  (* Rule store:
     trips        - triple-conversion theorems, paired with the term rewriter
                    built from them;
     rules        - (net of indexed rules keyed by the middle argument of their
                    triple_judgement conclusion, next free index,
                    indexed rules that have no such key);
     splits, combs, unsafe_rules - further theorem lists used by the tactics. *)
  type wp_rules = {trips: thm list * (theory -> term -> term),
    rules: (int * thm) Net.net * int * (int * thm) list,
    splits: thm list, combs: thm list, unsafe_rules: thm list};

  (* The rule store held by the given proof context. *)
  val debug_get: Proof.context -> wp_rules;

  (* The context's rule store extended with the given extra theorems as wp rules. *)
  val get_rules: Proof.context -> thm list -> wp_rules;

  (* Tactics.  The leading bool requests tracing of the theorems used; tracing
     is also switched on by the WP_Pre.wp_trace configuration option. *)
  val apply_rules_tac_n: bool -> Proof.context -> thm list -> int -> tactic;
  val apply_rules_tac: bool -> Proof.context -> thm list -> tactic;
  val apply_once_tac: bool -> Proof.context -> thm list -> tactic;
  (* Argument parser of the wp proof method. *)
  val apply_wp_args: (Proof.context -> Method.method) context_parser;

  (* Registers the wp, wp_trip, wp_split, wp_comb and wp_unsafe attributes. *)
  val setup: theory -> theory;

  (* Attributes adding/removing theorems in the corresponding rule lists. *)
  val wp_add: Thm.attribute;
  val wp_del: Thm.attribute;
  val splits_add: Thm.attribute;
  val splits_del: Thm.attribute;
  val combs_add: Thm.attribute;
  val combs_del: Thm.attribute;
  val wp_unsafe_add: Thm.attribute;
  val wp_unsafe_del: Thm.attribute;
end;
33
34structure WeakestPre =
35struct
36
(* Rule store kept in the generic context (see WPData below).
   trips: triple-conversion theorems and the rewriter derived from them
          (mk_trip_conv);
   rules: (net keyed by the triple's middle argument, next free index,
          indexed rules without a triple key); a higher index means the rule
          was added later;
   splits / combs / unsafe_rules: plain theorem lists. *)
type wp_rules = {trips: thm list * (theory -> term -> term),
    rules: (int * thm) Net.net * int * (int * thm) list,
    splits: thm list, combs: thm list, unsafe_rules: thm list};
40
(* Walk a list of (theorem, value) pairs and keep only the last occurrence of
   each theorem proposition (compared via Thm.prop_of), paired with every value
   attached to that proposition anywhere in the list.
   tt1 accumulates the values per proposition from the front of the list.
   The returned table records the propositions already kept further down the
   list; callers only use the first component of the result. *)
fun accum_last_occurence' [] _ = ([], Termtab.empty)
  | accum_last_occurence' ((t, v) :: ts) tt1 = let
      val tm = Thm.prop_of t;
      (* Record v for this proposition; the equality K false means no value is
         ever treated as a duplicate. *)
      val tt2 = Termtab.insert_list (K false) (tm, v) tt1;
      (* Process the tail first, so we know whether tm occurs again later. *)
      val (ts', tt3) = accum_last_occurence' ts tt2;
  in case Termtab.lookup tt3 tm of
        (* No later occurrence: keep this one, with all values seen for tm. *)
        NONE => ((t, Termtab.lookup_list tt2 tm)  :: ts',
                    Termtab.update (tm, ()) tt3)
        (* A later occurrence was already kept: drop this one. *)
      | SOME _ => (ts', tt3)
  end;
51
(* Keep only the last occurrence of each theorem proposition in the list of
   (theorem, value) pairs, paired with all values attached to it. *)
fun accum_last_occurence pairs =
  let val (kept, _) = accum_last_occurence' pairs Termtab.empty
  in kept end;

(* Remove duplicate theorems (compared by proposition), each surviving theorem
   sitting at the position of its last occurrence. *)
fun flat_last_occurence thms =
  thms
  |> map (fn th => (th, ()))
  |> accum_last_occurence
  |> map fst;
57
(* Flatten a rule store back into a plain theorem list.  Net entries and
   unkeyed entries are combined, ordered by index (order_list) and reversed. *)
fun dest_rules (net, _, unkeyed) =
  let val indexed = Net.entries net @ unkeyed
  in indexed |> order_list |> rev end;
60
(* Net key of a wp rule: after triple conversion and beta-eta contraction, the
   middle argument f of a conclusion `Trueprop (triple_judgement _ f _)`;
   NONE if the conclusion does not have that shape. *)
fun get_key trip_conv t =
  let
    val thy = Thm.theory_of_thm t
    val concl = Envir.beta_eta_contract (trip_conv thy (Thm.concl_of t))
  in
    (case concl of
      Const (@{const_name Trueprop}, _) $
        (Const (@{const_name triple_judgement}, _) $ _ $ f $ _) => SOME f
    | _ => NONE)
  end;
67
(* Insert theorem t with the next free index n: into the net when it has a
   triple key, otherwise onto the unkeyed list.  The index is bumped either way. *)
fun add_rule_inner trip_conv t (net, n, unkeyed) =
  let val entry = (n, t)
  in
    case get_key trip_conv t of
      NONE => (net, n + 1, entry :: unkeyed)
    | SOME k => (Net.insert_term (K false) (k, entry) net, n + 1, unkeyed)
  end;
74
(* Remove every stored entry whose theorem has the same proposition as t,
   whatever its index.  The next free index is left unchanged. *)
fun del_rule_inner trip_conv t (net, n, unkeyed) =
  let
    val same_prop = Thm.eq_thm_prop o apply2 snd
  in
    case get_key trip_conv t of
      NONE => (net, n, remove same_prop (n, t) unkeyed)
    | SOME k => (Net.delete_term_safe same_prop (k, (n, t)) net, n, unkeyed)
  end;
80
(* Empty rule store: empty net, next index 0, no unkeyed rules. *)
val no_rules = (Net.empty, 0, []);

(* Build a rule store from a theorem list.  fold_rev inserts from the end of
   the list, so earlier list elements receive higher indices. *)
fun mk_rules trip_conv rules =
  let fun add_one th store = add_rule_inner trip_conv th store
  in fold_rev add_one rules no_rules end;
84
(* Term rewriter from triple-conversion theorems: each theorem's conclusion
   `Trueprop (lhs = rhs)` becomes a rewrite rule lhs -> rhs. *)
fun mk_trip_conv trips thy =
  let
    val dest_rewrite = Thm.concl_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
    val rewrites = map dest_rewrite trips
  in Pattern.rewrite_term thy rewrites [] end
87
(* Merge two rule stores.  Triple conversions are merged first and every wp
   rule is re-keyed under the merged conversion.  Duplicates (same
   proposition) keep the position of their last occurrence, so rules of the
   first store come out with higher indices.  The other lists are merged with
   Thm.merge_thms. *)
fun rules_merge (left : wp_rules, right : wp_rules) =
  let
    val trips = Thm.merge_thms (fst (#trips left), fst (#trips right))
    val trip_conv = mk_trip_conv trips
    val combined = dest_rules (#rules left) @ dest_rules (#rules right)
    fun merged sel = Thm.merge_thms (sel left, sel right)
  in
    {trips = (trips, trip_conv),
     rules = mk_rules trip_conv (flat_last_occurence combined),
     splits = merged #splits,
     combs = merged #combs,
     unsafe_rules = merged #unsafe_rules}
  end
97
(* Generic context data holding the wp rule store.  Merging goes through
   rules_merge, which rebuilds the rule net. *)
structure WPData = Generic_Data
(struct
    type T = wp_rules;
    (* Initially no triple conversions: the rewriter ignores the theory and
       returns the term unchanged. *)
    val empty = {trips = ([], K I), rules = no_rules,
      splits = [], combs = [], unsafe_rules = []};
    (* NOTE(review): newer Isabelle releases removed `extend` from the
       Generic_Data argument signature; confirm against the Isabelle version
       this file targets. *)
    val extend = I;

    val merge = rules_merge;
end);
107
(* Whether a theorem with the same proposition as thm is stored as a wp rule
   in ctxt.  The lookup uses the net key when thm has one, and the unkeyed
   list otherwise. *)
fun is_wp_rule ctxt thm =
  let
    val {rules = (net, _, unkeyed), trips = (_, trip_conv), ...} =
      WPData.get (Context.Proof ctxt)
    val candidates =
      (case get_key trip_conv thm of
        NONE => unkeyed
      | SOME k => Net.lookup net (Net.key_of_term k))
    fun same_prop (_, th) = Thm.eq_thm_prop (thm, th)
  in exists same_prop candidates end
117
(* Add a wp rule to the store, keyed under the current triple conversion. *)
fun add_rule rule ({trips, rules, splits, combs, unsafe_rules} : wp_rules) =
  {trips = trips,
   rules = add_rule_inner (snd trips) rule rules,
   splits = splits,
   combs = combs,
   unsafe_rules = unsafe_rules}

(* Remove a wp rule (matched by proposition) from the store. *)
fun del_rule rule ({trips, rules, splits, combs, unsafe_rules} : wp_rules) =
  {trips = trips,
   rules = del_rule_inner (snd trips) rule rules,
   splits = splits,
   combs = combs,
   unsafe_rules = unsafe_rules}
133
(* Install a new list of triple-conversion theorems.  Builds the matching
   rewriter and re-keys every stored wp rule under it; splits, combs and
   unsafe rules are kept. *)
fun rebuild_with_trips trips ({rules, splits, combs, unsafe_rules, ...} : wp_rules) =
  let val trip_conv = mk_trip_conv trips
  in
    {trips = (trips, trip_conv),
     rules = mk_rules trip_conv (dest_rules rules),
     splits = splits,
     combs = combs,
     unsafe_rules = unsafe_rules}
  end

(* Add a triple-conversion theorem and rebuild the rule net. *)
fun add_trip rule (rs : wp_rules) =
  rebuild_with_trips (Thm.add_thm rule (fst (#trips rs))) rs

(* Remove a triple-conversion theorem and rebuild the rule net. *)
fun del_trip rule (rs : wp_rules) =
  rebuild_with_trips (Thm.del_thm rule (fst (#trips rs))) rs
153
(* Rebuild the store with one of the plain theorem lists transformed by f. *)
fun map_splits f ({trips, rules, splits, combs, unsafe_rules} : wp_rules) =
  {trips = trips, rules = rules, splits = f splits, combs = combs,
   unsafe_rules = unsafe_rules}

fun map_combs f ({trips, rules, splits, combs, unsafe_rules} : wp_rules) =
  {trips = trips, rules = rules, splits = splits, combs = f combs,
   unsafe_rules = unsafe_rules}

fun map_unsafe_rules f ({trips, rules, splits, combs, unsafe_rules} : wp_rules) =
  {trips = trips, rules = rules, splits = splits, combs = combs,
   unsafe_rules = f unsafe_rules}

(* Add/remove a theorem in the split, comb and unsafe-rule lists. *)
fun add_split rule rs = map_splits (Thm.add_thm rule) rs

fun add_comb rule rs = map_combs (Thm.add_thm rule) rs

fun del_split rule rs = map_splits (Thm.del_thm rule) rs

fun del_comb rule rs = map_combs (Thm.del_thm rule) rs

fun add_unsafe_rule rule rs = map_unsafe_rules (Thm.add_thm rule) rs

fun del_unsafe_rule rule rs = map_unsafe_rules (Thm.del_thm rule) rs
185
(* Declaration attribute that updates the wp store with m applied to the
   attributed theorem. *)
fun gen_att m =
  let fun decl thm context = WPData.map (m thm) context
  in Thm.declaration_attribute decl end;
187
(* Attributes adding/removing theorems in each part of the rule store:
   wp rules, triple conversions, splits, combs and unsafe rules. *)
val wp_add = gen_att add_rule;
val wp_del = gen_att del_rule;
val trip_add = gen_att add_trip;
val trip_del = gen_att del_trip;
val splits_add = gen_att add_split;
val splits_del = gen_att del_split;
val combs_add = gen_att add_comb;
val combs_del = gen_att del_comb;
val wp_unsafe_add = gen_att add_unsafe_rule;
val wp_unsafe_del = gen_att del_unsafe_rule;
198
(* Register the add/del attributes wp, wp_trip, wp_split, wp_comb and
   wp_unsafe. *)
val setup =
      Attrib.setup @{binding "wp"}
          (Attrib.add_del wp_add wp_del)
          "monadic weakest precondition rules"
      #> Attrib.setup @{binding "wp_trip"}
          (Attrib.add_del trip_add trip_del)
          "monadic triple conversion rules"
      #> Attrib.setup @{binding "wp_split"}
          (Attrib.add_del splits_add splits_del)
          "monadic split rules"
      #> Attrib.setup @{binding "wp_comb"}
          (Attrib.add_del combs_add combs_del)
          "monadic combination rules"
      #> Attrib.setup @{binding "wp_unsafe"}
          (Attrib.add_del wp_unsafe_add wp_unsafe_del)
          "unsafe monadic weakest precondition rules"
215
(* The rule store held by a proof context. *)
fun debug_get ctxt = ctxt |> Context.Proof |> WPData.get;

(* The context's store with extras added as wp rules.  fold_rev adds the last
   extra first, so earlier extras get higher indices. *)
fun get_rules ctxt extras =
  let val base = debug_get ctxt
  in fold_rev add_rule extras base end;
219
(* One wp step on subgoal n of t.
   If the subgoal conclusion (after triple conversion, beta-eta contraction and
   stripping of assumptions) is `Trueprop (triple_judgement _ f _)`:
     try the net rules whose key unifies with f, ordered by index, highest
     first (order_list |> rev).  Each rule is tried directly, then through each
     comb rule.  After that the split rules are tried.
   Otherwise: try the unkeyed rules, then the split rules.
   used_thms_ref is handed to the WP_Pre tracing helpers, presumably to record
   the rules actually used. *)
fun resolve_ruleset_tac' trace ctxt rs used_thms_ref n t =
  let
    val rtac = WP_Pre.rtac ctxt
    fun trace_rtac tag rule = WP_Pre.trace_rule trace ctxt used_thms_ref tag rtac rule
  in case
    (* If subgoal n does not exist, Thm.cprem_of raises THM.  The HOL constant
       True is then used, which falls into the non-triple branch below. *)
    Thm.cprem_of t n |> Thm.term_of |> snd (#trips rs) (Thm.theory_of_thm t)
        |> Envir.beta_eta_contract |> Logic.strip_assums_concl
     handle THM _ => @{const True}
  of Const (@{const_name Trueprop}, _) $
      (Const (@{const_name triple_judgement}, _) $ _ $ f $ _) =>
      let
        val rules = Net.unify_term (#1 (#rules rs)) f |> order_list |> rev;
        (* Apply comb rule combapp, then rule.  The comb's instantiation is
           stashed in insts_ref and only recorded, together with rule, once
           rule has applied. *)
        fun per_rule_combapp_tac rule combapp =
          let val insts_ref = Unsynchronized.ref (Trace_Schematic_Insts.empty_instantiations)
          in WP_Pre.trace_rule' trace ctxt
               (fn rule_insts => fn _ => insts_ref := rule_insts)
               rtac combapp
             THEN'
             WP_Pre.trace_rule' trace ctxt
               (fn rule_insts => fn _ =>
                  (WP_Pre.append_used_rule ctxt used_thms_ref "wp_comb" combapp (!insts_ref);
                   WP_Pre.append_used_rule ctxt used_thms_ref "wp" rule rule_insts))
               rtac rule
          end
        (* The rule directly, else the rule through any comb. *)
        fun per_rule_tac rule =
          trace_rtac "wp" rule ORELSE'
          FIRST' (map (per_rule_combapp_tac rule) (#combs rs))
      in (FIRST' (map per_rule_tac rules) ORELSE'
          FIRST' (map (trace_rtac "wp_split") (#splits rs))) n t
      end
    | _ => FIRST' (map (trace_rtac "wp") (map snd (#3 (#rules rs))) @
                   map (trace_rtac "wp_split") (#splits rs)) n t
  end;
253
(* A single wp step on subgoal n, preceded by an Apply_Debug breakpoint
   tagged "wp". *)
fun resolve_ruleset_tac trace ctxt rs used_thms_ref n =
  let
    val break_tac = Apply_Debug.break ctxt (SOME "wp")
    val step_tac = resolve_ruleset_tac' trace ctxt rs used_thms_ref n
  in break_tac THEN step_tac end
256
(* Pretty-print one used theorem as `name [tag]: prop`. *)
fun trace_used_thm ctxt (name, tag, prop) =
  let
    val adjusted_name = ThmExtras.adjust_thm_name ctxt (name, NONE) prop
    val header = ThmExtras.pretty_adjusted_name ctxt adjusted_name
    val tag_str = Pretty.str ("[" ^ tag ^ "]:")
  in
    Pretty.block [header, tag_str, Pretty.brk 1, Syntax.unparse_term ctxt prop]
  end
263
(* When trace is set, print every theorem recorded in used_thms_ref.  If the
   output raises Size, a warning is printed instead. *)
fun trace_used_thms trace ctxt used_thms_ref =
  if not trace then ()
  else
    (Pretty.writeln
       (Pretty.big_list "Theorems used by wp:"
          (map (trace_used_thm ctxt) (!used_thms_ref)))
     handle Size => warning ("WP tracing information was too large to print."))
271
(* Print the unsafe rules that would allow a wp step on subgoal n of t when
   added to the context's rules.  The check runs with tracing switched off and
   a throw-away used-theorems list. *)
fun warn_unsafe_rules unsafe_rules n ctxt t =
  let
    val used_thms_dummy = Unsynchronized.ref [] : (string * string * term) list Unsynchronized.ref
    val quiet_ctxt = Config.put WP_Pre.wp_trace false ctxt
    fun applies rule =
      is_some (SINGLE (resolve_ruleset_tac false quiet_ctxt (get_rules ctxt [rule])
                         used_thms_dummy n) t)
    val useful = filter applies unsafe_rules
  in
    if null useful then ()
    else
      Pretty.writeln (Pretty.list "Unsafe theorems that could be used: \n" ""
                        (map (ThmExtras.pretty_thm true ctxt) useful))
  end;
285
(* Main wp tactic on subgoal n.  Focused on that subgoal, it does the following:
     1. try WP_Pre.tac and WPFix.both_tac (each wrapped in TRY);
     2. repeatedly apply wp steps to the first subgoal, failing unless at
        least one step changed the goal (CHANGED);
     3. try to close leftovers with TrueI / conj_TrueI / conj_TrueI2 or an
        assumption.
   For each result, the used theorems are printed when tracing is on.
   Whether or not steps 1-3 succeed, the unsafe rules that could have made a
   step are reported. *)
fun apply_rules_tac_n trace ctxt extras n =
let
  val trace' = trace orelse Config.get ctxt WP_Pre.wp_trace
  (* Created once per call and shared by every application of the returned
     tactic; it is cleared after each trace print below. *)
  val used_thms_ref = Unsynchronized.ref [] : (string * string * term) list Unsynchronized.ref
  val rules = get_rules ctxt extras
  val wp_pre_tac = TRY (WP_Pre.tac trace' used_thms_ref ctxt 1)
  val wp_fix_tac = TRY (WPFix.both_tac ctxt 1)
  val cleanup_tac = TRY (REPEAT
                      (resolve_tac ctxt [@{thm TrueI}, @{thm conj_TrueI}, @{thm conj_TrueI2}] 1
                       ORELSE assume_tac ctxt 1))
  val steps_tac = (CHANGED (REPEAT_DETERM (resolve_ruleset_tac trace' ctxt rules used_thms_ref 1)))
                  THEN cleanup_tac
in
  SELECT_GOAL (
    (fn t => Seq.map (fn thm => (trace_used_thms trace' ctxt used_thms_ref;
                                 used_thms_ref := []; thm))
                     ((wp_pre_tac THEN wp_fix_tac THEN steps_tac) t))
    THEN_ELSE
    (fn t => (warn_unsafe_rules (#unsafe_rules rules) 1 ctxt t; all_tac t),
     fn t => (warn_unsafe_rules (#unsafe_rules rules) 1 ctxt t; no_tac t))) n
end
307
(* apply_rules_tac_n on the first subgoal. *)
fun apply_rules_tac trace ctxt extras =
  let val first_goal = 1
  in apply_rules_tac_n trace ctxt extras first_goal end;
309
(* Exactly one wp step on the first subgoal.  The used theorems are printed
   for each result when tracing is on. *)
fun apply_once_tac trace ctxt extras t =
  let
    val trace' = trace orelse Config.get ctxt WP_Pre.wp_trace
    val used_thms_ref = Unsynchronized.ref [] : (string * string * term) list Unsynchronized.ref
    val rules = get_rules ctxt extras
    val results = SELECT_GOAL (resolve_ruleset_tac trace' ctxt rules used_thms_ref 1) 1 t
    fun report thm = (trace_used_thms trace' ctxt used_thms_ref; thm)
  in Seq.map report results end
318
(* Drop all wp rules.  Triple conversions, splits, combs and unsafe rules are
   kept. *)
fun clear_rules {trips, splits, combs, unsafe_rules, rules = _} =
  {trips = trips, rules = no_rules, splits = splits, combs = combs,
   unsafe_rules = unsafe_rules}
321
(* Section keywords of the wp method: "add:", "del:", "comb:", "comb add:",
   "comb del:" and "only:".  Each pairs a context transformation with the
   attribute applied to the section's theorems.  "only:" first clears the wp
   rules (clear_rules keeps trips, splits, combs and unsafe rules). *)
val wp_modifiers =
 [Args.add -- Args.colon >> K (I, wp_add),
  Args.del -- Args.colon >> K (I, wp_del),
  Args.$$$ "comb" -- Args.colon >> K (I, combs_add),
  Args.$$$ "comb" -- Args.add -- Args.colon >> K (I, combs_add),
  Args.$$$ "comb" -- Args.del -- Args.colon >> K (I, combs_del),
  Args.$$$ "only" -- Args.colon >> K (Context.proof_map (WPData.map clear_rules), wp_add)];
329
(* Whether any of the tokens is the keyword ":". *)
fun has_colon toks = List.exists (Token.keyword_with (fn s => s = ":")) toks;

(* Choose a parser according to whether the remaining tokens contain a colon. *)
fun if_colon with_sections without_sections (input as (_, toks)) =
  if has_colon toks then with_sections input else without_sections input;
333
(* FIXME: It would be nice if we could just use Method.sections, but to maintain
   compatibility we require that the order of thms in each section is reversed. *)
(* Theorems of one section: parse theorems until one of the section keywords
   in ss matches, and flatten them into one list. *)
fun thms ss =
  let val at_keyword = Scan.lift (Scan.first ss)
  in Scan.repeat (Scan.unless at_keyword Attrib.multi_thm) >> flat end;

(* Apply the context transformation f, then attribute att to each of ths while
   threading the generic context.  Returns the attributed theorems and the
   final context. *)
fun app (f, att) ths context =
  let val context' = Context.map_proof f context
  in fold_map (Thm.apply_attribute att) ths context' end;
338
(* One "keyword: thms" section.  Parse a modifier keyword, then the theorems up
   to the next keyword, and apply the modifier to those theorems in reverse
   order, threading the updated context through Scan.depend. *)
fun section ss = Scan.depend (fn ctxt => (Scan.first ss -- Scan.pass ctxt (thms ss)) :|--
  (fn (m, thms) => Scan.succeed (swap (app m (rev thms) ctxt))));
(* Any number of consecutive sections. *)
fun sections ss = Scan.repeat (section ss);
342
(* Argument form without section keywords (used by apply_wp_args when the
   arguments contain no colon): every theorem is added as a wp rule, again in
   reverse order. *)
val add_section = Scan.depend (fn ctxt => (Scan.pass ctxt Attrib.thms) :|--
  (fn thms => Scan.succeed (swap (app (I, wp_add) (rev thms) ctxt))));
345
(* Optional parenthesised, comma-separated list of mode names drawn from ss.
   Yields one flag per element of ss, true iff that mode was given; all false
   when the list is absent. *)
fun modes ss =
  let
    val mode_names = Args.parens (Parse.list (Scan.first (map Args.$$$ ss)))
    fun flags_of given = map (fn s => member (op =) given s) ss
    val all_off = replicate (length ss) false
  in Scan.optional (mode_names >> flags_of) all_off end;
350
(* Parser for the wp method: optional "(trace, once)" modes followed by either
   keyword sections (when a colon occurs) or a plain theorem list.  The
   theorem arguments are declared into the context while parsing, so the
   tactic itself is called with no extra theorems.  "once" selects
   apply_once_tac, otherwise apply_rules_tac is used. *)
fun apply_wp_args xs =
  let fun apply_tac once = if once then apply_once_tac else apply_rules_tac;
  in
    Scan.lift (modes ["trace", "once"])
      --| if_colon (sections wp_modifiers >> flat) add_section
    >> curry (fn ([trace, once], ctxt) => SIMPLE_METHOD (apply_tac once trace ctxt []))
  end xs;
358
359end;
360
(* Ascribe WeakestPre to the WP signature, so the compiler checks that the
   structure implements the interface. *)
structure WeakestPreInst : WP = WeakestPre;
362