1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6(*
7 * Optimise L2 fragments of code by using facts learnt earlier in the fragments
8 * to simplify code afterwards.
9 *)
10
11structure L2Opt =
12struct
13
14(*
15 * Map the given simpset to tweak it for L2Opt.
16 *
17 * If "use_ugly_rules" is enabled, we will use rules that are useful for
18 * discharging proofs, but make the output ugly.
19 *)
fun map_opt_simpset use_ugly_rules ctxt =
  let
    (* Congruence rules shared by every L2Opt simplification. *)
    val with_congs =
      ctxt
      |> Simplifier.add_cong @{thm if_cong}
      |> Simplifier.add_cong @{thm split_cong}
      |> Simplifier.add_cong @{thm HOL.conj_cong}
  in
    (* "split_def" helps discharge proofs but produces ugly output, so it is
     * only added on request. *)
    if use_ugly_rules then with_congs addsimps [@{thm split_def}] else with_congs
  end
28
29(*
30 * Solve a goal of the form:
31 *
32 *   simp_expr P A ?X
33 *
34 * This is done by simplifying "A" while assuming "P", and unifying the result
35 * (usually instantiating "X") in the process.
36 *)
(*
 * Bridge rule: to prove "simp_expr P G G'" it suffices to prove the meta-equality
 * "simp_expr P G G == simp_expr P G G'", which is exactly the shape of the
 * theorem produced by rewriting "simp_expr P G G" below.
 *)
val simp_expr_thm =
    @{lemma "(simp_expr P G G == simp_expr P G G') ==> simp_expr P G G'" by (clarsimp simp: simp_expr_def)}
(*
 * Tactic (int -> tactic) acting on the focused subgoal "simp_expr P L ?X".
 * The subgoal's parameters are fixed by FOCUS_PARAMS; the inner "ctxt" is the
 * focused context and shadows the outer one.
 *)
fun solve_simp_expr_tac ctxt =
  Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
  (fn thm =>
    case Drule.cprems_of thm of
      [] => (no_tac thm)
    | (goal::_) =>
      (case Thm.term_of goal of
        (* Outer "_ $" skips the Trueprop wrapper; the third argument (?X) is ignored. *)
        (_ $ (Const (@{const_name "simp_expr"}, _) $ P $ L $ _)) =>
        let
          (* Build "simp_expr P L L" and rewrite it, using P as an assumption. *)
          val goal = @{mk_term "simp_expr ?P ?L ?L" (P, L)} (P, L)
              |> Thm.cterm_of ctxt
          val simplified = Simplifier.asm_full_rewrite (map_opt_simpset false ctxt) goal

          (* Ensure that all schematics have been resolved. *)
          val schematic_remains = Term.exists_subterm Term.is_Var (Thm.prop_of simplified)
        in
          if schematic_remains then
            (* Fall back to "simp_expr_triv"; presumably this leaves L
             * unsimplified -- TODO confirm against the lemma. *)
            (resolve_tac ctxt @{thms simp_expr_triv} 1) thm
          else
            (* Reduce to the meta-equality via simp_expr_thm, then close it with
             * the rewrite result, which instantiates ?X. *)
            ((resolve_tac ctxt [simp_expr_thm] 1) THEN (resolve_tac ctxt [simplified] 1)) thm
        end
      | _ => no_tac thm)
    )) ctxt
62
63(*
 * Solve a goal of the form:
65 *
66 *   simp_expr P A B
67 *
68 * where both "A" and "B" are constants (i.e., not schematics).
69 *)
fun solve_simp_expr_const_tac ctxt thm =
  if Thm.nprems_of thm = 0 then
    (* No subgoal to work on: fail with an empty result sequence rather than
     * letting "Thm.cprem_of" raise THM out of the tactic. *)
    no_tac thm
  else if Term.exists_subterm Term.is_Var (Thm.term_of (Thm.cprem_of thm 1)) then
    (* Only schematic-free instances of subgoal 1 are handled here. *)
    no_tac thm
  else
    (* Apply "simp_expr_solve_constant" and finish with clarsimp using the
     * "ugly" simpset (which includes split_def). The whole attempt is wrapped in
     * SOLVES -- presumably it fails unless the goal is closed; see SOLVES. *)
    SOLVES (
      (resolve_tac ctxt @{thms simp_expr_solve_constant} 1)
      THEN (Clasimp.clarsimp_tac (map_opt_simpset true ctxt) 1)) thm
77
78(*
 * Given a goal whose first subgoal has the form:
80 *
81 *   monad_equiv P L R Q E
82 *
83 * simplify "P", possibly trimming parts of it that are too large.
84 *
85 * The idea here is to avoid exponential blow-up by trimming off terms that get
86 * too large.
87 *)
fun simp_monad_equiv_pre_tac ctxt =
  ctxt |> Subgoal.FOCUS_PARAMS (fn {context = fctxt, ...} => fn st =>
    let
      (* The focused goal state has exactly the one subgoal we inspect. *)
      val prem = Thm.cprem_of st 1
    in
      case Thm.term_of prem of
        Const (@{const_name Trueprop}, _) $
          (Const (@{const_name monad_equiv}, _) $ pre $ _ $ _ $ _ $ _) =>
          let
            (* A schematic precondition could lead to flex-flex pairs that
             * Isabelle refuses to solve; the monad_equiv rules are expected
             * never to produce one, so treat it as an internal error. *)
            val _ =
              if exists_subterm is_Var pre
              then raise CTERM ("autocorres: bad schematic in monad_equiv_pre", [prem])
              else ()

            (* Rewrite the precondition, then weaken the goal to use the result. *)
            val pre_eq =
              Simplifier.asm_full_rewrite (map_opt_simpset false fctxt) (Thm.cterm_of fctxt pre)
            val weaken_rule = @{thm monad_equiv_weaken_pre''} OF [pre_eq]

            fun fail t =
              raise CTERM ("autocorres: monad_equiv_pre failed to prove goal", [Thm.cprem_of t 1])
          in
            (resolve_tac fctxt [weaken_rule] 1 ORELSE fail) st
          end
      | _ => all_tac st
    end)
108
109(*
110 * Recursively simplify a monadic expression, using information gleaned from
111 * earlier in the program to simplify parts of the program further down.
112 *)
(*
 * Conversion: given "ct", builds the goal "ct == ?R", turns it into a
 * "monad_equiv" goal (via eq_reflection and monad_equiv_eq) and solves it with
 * the L2flow rules, instantiating ?R along the way. Returns the finished
 * meta-equality "ct == R".
 *)
fun monad_equiv ctxt ct =
let
  (* Mark context as being "invisible" to reduce warnings being printed. *)
  val ctxt = Context_Position.set_visible false ctxt

  (* Generate our top-level "monad_equiv" goal. Only ?L is instantiated (with ct);
   * ?R stays schematic so the proof search can choose it. *)
  val goal = @{mk_term "?L == ?R" (L)} (Thm.term_of ct)
    |> Thm.cterm_of ctxt
    |> Goal.init
    |> Utils.apply_tac "Creating object-level equality." (resolve_tac ctxt @{thms eq_reflection} 1)
    |> Utils.apply_tac "Creating 'monad_equiv' goal." (resolve_tac ctxt @{thms monad_equiv_eq} 1)

  (* Print a diagnostic if this branch fails (at most 5 times).
   * NOTE: currently disabled -- the "false andalso" guard means this always just
   * fails with no_tac; change "false" to "true" to enable the diagnostic. *)
  val num_failures = ref 0
  fun print_failure_tac t =
    if (false andalso !num_failures < 5) then
      (num_failures := !num_failures + 1; (print_tac ctxt "Branch failed" THEN no_tac) t)
    else
      (no_tac t)

  (* Fetch theorems used in the simplification process. *)
  val thms = Utils.get_rules ctxt @{named_theorems L2flow}

  (* Tactic to blindly apply simplification rules.
   * The integer argument is ignored: every step works on subgoal 1. First the
   * precondition is simplified; then, committing to the first success (DETERM),
   * try in order: the constant-case solver, the simp_expr solver, or an L2flow
   * rule whose new subgoals are handled recursively. *)
  fun solve_goal_tac _ =
    (simp_monad_equiv_pre_tac ctxt 1)
    THEN DETERM (
      SOLVES
      ((solve_simp_expr_const_tac ctxt)
      ORELSE
      ((solve_simp_expr_tac ctxt 1)
      ORELSE
      ((resolve_tac ctxt thms THEN_ALL_NEW solve_goal_tac) 1
      ORELSE
      ((print_failure_tac))))))

  (* Apply the rules. *)
  val thm =
    Utils.apply_tac "Simplifying L2" (solve_goal_tac 1) goal
    |> Goal.finish ctxt
in
  thm
end
156
157(*
158 * A simproc implementing the "L2_gets_bind" rule. The rule, unfortunately, has
159 * the ability to cause exponential growth in the spec size in some cases;
160 * thus, we can only selectively apply it in cases where this doesn't happen.
161 *
162 * In particular, we propagate a "gets" into its usage if it is used at most once.
163 *
164 * Or, if the user asks for "no_opt", we only erase the "gets" if it is never used.
165 * (Even with "no optimisation", we still want to get rid of control flow variables
166 * emitted by c-parser. Hopefully the user won't mind if their own unused variables
167 * also disappear.)
168 *)
val l2_gets_bind_thm = mk_meta_eq @{thm L2_gets_bind}
fun l2_gets_bind_simproc' _ cterm =
let
  (*
   * Is "lhs" of the form "L2_gets (%_. a) n" where "a" is a single atom
   * (bound variable, free variable or constant)? Propagating such an atom
   * duplicates no larger expression, so it is allowed however often the
   * result is used.
   *
   * "L2_gets f n" is the term "(Const L2_gets $ f) $ n": the abstraction is the
   * second-to-last argument, while the last one is the name "?n" from the
   * simproc pattern. Matching only "_ $ Abs ..." would test "n" instead, so
   * this check could never succeed.
   *)
  fun is_simple (_ $ Abs (_, _, Bound _) $ _) = true
    | is_simple (_ $ Abs (_, _, Free _) $ _) = true
    | is_simple (_ $ Abs (_, _, Const _) $ _) = true
    | is_simple _ = false
in
  case Thm.term_of cterm of
    (Const (@{const_name "L2_seq"}, _) $ lhs $ Abs (_, _, rhs)) =>
      let
        (* Count uses of the bound result in "rhs": substitute a marker Free for
         * the loose Bound 0 and count the marker's occurrences. *)
        fun count_var_usage (a $ b) = count_var_usage a + count_var_usage b
          | count_var_usage (Abs (_, _, x)) = count_var_usage x
          | count_var_usage (Free ("_dummy", _)) = 1
          | count_var_usage _ = 0
        val count = count_var_usage (subst_bounds ([Free ("_dummy", dummyT)], rhs))
      in
        (* Rewrite with L2_gets_bind only when it cannot blow up the term. *)
        if count <= 1 orelse is_simple lhs then
          SOME l2_gets_bind_thm
        else
          NONE
      end
    | _ => NONE
end
(* Simproc triggered on "L2_seq (L2_gets ...) ..." terms; the decision of whether
 * to rewrite is made by l2_gets_bind_simproc'. *)
val l2_gets_bind_simproc =
  let
    val pattern = "L2_seq (L2_gets (%_. ?A) ?n) ?B"
  in
    Utils.mk_simproc' @{context} ("L2_gets_bind_simproc", [pattern], l2_gets_bind_simproc')
  end
196
197(* Simproc to clean up guards. *)
fun l2_guard_simproc' ss ctxt cterm =
let
  (* Rewrite the whole guard term with "ss", extended with HOL.conj_cong as a
   * congruence rule, inside the simproc's context. *)
  val simp_thm = Simplifier.asm_full_rewrite
      (Simplifier.add_cong @{thm HOL.conj_cong} (put_simpset ss ctxt)) cterm
  (* Compare the two sides up to eta-contraction. Logic.dest_equals replaces the
   * former non-exhaustive "val [lhs, rhs] = ... strip_comb" binding; the
   * rewrite result is always a meta-equality. *)
  val (lhs, rhs) = Logic.dest_equals (Thm.prop_of (Drule.eta_contraction_rule simp_thm))
in
  (* Report "no change" so the simplifier does not loop on a trivial rewrite;
   * otherwise return the original (uncontracted) equation. *)
  if Term_Ord.fast_term_ord (lhs, rhs) = EQUAL then
    NONE
  else
    SOME simp_thm
end
(* Build the guard-cleanup simproc for "L2_guard ?G" terms, rewriting guards
 * with the given simpset "ss". *)
fun l2_guard_simproc ss =
  let
    fun proc ctxt ct = l2_guard_simproc' ss ctxt ct
  in
    Utils.mk_simproc' @{context} ("L2_guard_simproc", ["L2_guard ?G"], proc)
  end
211
212(*
 * Adjust "case_prod" commands so that constructs such as:
214 *
215 *    while C (%x. gets (case x of (a, b) => %s. P a b)) ...
216 *
217 * are transformed into:
218 *
219 *    while C (%(a, b). gets (%s. P a b)) ...
220 *)
fun fix_L2_while_loop_splits_conv ctxt =
  let
    (* Plain HOL simpset plus only the split-fixup rules and their congruences. *)
    val fixup_ctxt =
      put_simpset HOL_ss ctxt
      |> (fn c => c addsimps @{thms L2_split_fixups})
      |> fold Simplifier.add_cong @{thms L2_split_fixups_congs}
  in
    Simplifier.asm_full_rewrite fixup_ctxt
  end
226
227(*
228 * Carry out flow-sensitive optimisations on the given 'thm'.
229 *
230 * "n" is the argument number to cleanup, counting from 1. So for example, if
231 * our input theorem was "corres P A B", an "n" of 2 would simplify "A".
232 * If n < 0, then the cleanup is applied to the -n-th argument from the end.
233 *
234 * If "fast_mode" is 0, perform flow-sensitive optimisations (which tend to be
235 * time-consuming). If 1, only apply L2Peephole and L2Opt simplification rules.
236 * If 2, do not use AutoCorres' simplification rules at all.
237 *)
(*
 * Returns the cleaned-up theorem together with the list of traces produced;
 * NONE traces are dropped, so the list has between 0 and 2 entries.
 *)
fun cleanup_thm ctxt thm fast_mode n do_trace =
let
  (* Don't print out warning messages. *)
  val ctxt = Context_Position.set_visible false ctxt

  (* Setup basic simplifier. Order matters: the guard simproc captures the
   * simpset as it is at that point (including the L2opt rules and the
   * L2_gets_bind simproc when fast_mode < 2, but not the congruences added
   * afterwards by map_opt_simpset). *)
  fun basic_ss ctxt =
      put_simpset AUTOCORRES_SIMPSET ctxt
      |> (fn ctxt => if fast_mode < 2 then ctxt addsimps (Utils.get_rules ctxt @{named_theorems L2opt}) else ctxt)
      |> (fn ctxt => if fast_mode < 2 then ctxt addsimprocs [l2_gets_bind_simproc] else ctxt)
      |> (fn ctxt => ctxt addsimprocs [l2_guard_simproc (simpset_of ctxt)])
      |> map_opt_simpset false
  (* Beta/eta-normalise, fix up case_prod splits, then run the basic simpset. *)
  fun simp_conv ctxt =
    Drule.beta_eta_conversion
    then_conv (fix_L2_while_loop_splits_conv ctxt)
    then_conv (Simplifier.rewrite (basic_ss ctxt))

  (* Lift a context-dependent conversion to act on argument "n" of the theorem. *)
  fun l2conv conv =
    Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv n (conv ctxt)) ctxt

  (* Apply peephole optimisations to the theorem. *)
  val (new_thm, peephole_trace) =
      AutoCorresTrace.fconv_rule_maybe_traced ctxt (l2conv simp_conv) thm do_trace
      |> apfst Drule.eta_contraction_rule

  (* Apply flow-sensitive optimisations, and then re-apply simple simplifications.
   * (The put_simpset here is redundant: basic_ss resets the simpset to
   * AUTOCORRES_SIMPSET anyway.) *)
  (* TODO: trace monad_equiv using trace_solve_tac rather than fconv_rule_traced *)
  val (new_thm, flow_trace) =
    if fast_mode = 0 then
      AutoCorresTrace.fconv_rule_maybe_traced ctxt (
        l2conv (fn ctxt =>
          monad_equiv ctxt
          then_conv (simp_conv (put_simpset AUTOCORRES_SIMPSET ctxt))
        )) new_thm do_trace
    else
      (new_thm, NONE)

  (* Beta/Eta normalise. *)
  val new_thm = Conv.fconv_rule (l2conv (K Drule.beta_eta_conversion)) new_thm
in
  (new_thm, List.mapPartial I [peephole_trace, flow_trace])
end
280
281(* Also tag the traces in a suitable format to be stored in AutoCorresData. *)
fun cleanup_thm_tagged ctxt thm fast_mode n do_trace phase =
  let
    val (new_thm, traces) = cleanup_thm ctxt thm fast_mode n do_trace
    (* Labels pair positionally with the traces returned by cleanup_thm. *)
    val labels = [phase ^ " peephole opt", phase ^ " flow opt"]
  in
    (new_thm, Utils.zip labels (map AutoCorresData.SimpTrace traces))
  end
285
286end
287