(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(*
 * Optimise L2 fragments of code by using facts learnt earlier in the fragments
 * to simplify code afterwards.
 *)

structure L2Opt =
struct

(*
 * Map the given simpset to tweak it for L2Opt.
 *
 * Adds the congruence rules "if_cong", "split_cong" and "HOL.conj_cong", so
 * the simplifier can use context (the "if" condition, earlier conjuncts, ...)
 * while rewriting the remaining parts of a term.
 *
 * If "use_ugly_rules" is enabled, we also add rules that are useful for
 * discharging proofs but make the output ugly (currently "split_def").
 *)
fun map_opt_simpset use_ugly_rules =
  Simplifier.add_cong @{thm if_cong}
  #> Simplifier.add_cong @{thm split_cong}
  #> Simplifier.add_cong @{thm HOL.conj_cong}
  #> (if use_ugly_rules then
        (fn ctxt => ctxt addsimps [@{thm split_def}])
      else
        I)

(*
 * Solve a goal of the form:
 *
 *    simp_expr P A ?X
 *
 * This is done by simplifying "A" while assuming "P", and unifying the result
 * (usually instantiating "X") in the process.
 *
 * If the simplification result still contains schematic variables, we fall
 * back to the rule "simp_expr_triv" instead of using the simplified result.
 *)
val simp_expr_thm =
  @{lemma "(simp_expr P G G == simp_expr P G G') ==> simp_expr P G G'" by (clarsimp simp: simp_expr_def)}

fun solve_simp_expr_tac ctxt =
  Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
    (fn thm =>
      case Drule.cprems_of thm of
        [] => (no_tac thm)
      | (goal::_) =>
          (case Thm.term_of goal of
            (_ $ (Const (@{const_name "simp_expr"}, _) $ P $ L $ _)) =>
              let
                (* Build "simp_expr P L L" and rewrite it; the right-hand side of
                 * the resulting equation supplies the instantiation for "?X". *)
                val goal = @{mk_term "simp_expr ?P ?L ?L" (P, L)} (P, L)
                           |> Thm.cterm_of ctxt
                val simplified = Simplifier.asm_full_rewrite (map_opt_simpset false ctxt) goal

                (* Ensure that all schematics have been resolved. *)
                val schematic_remains = Term.exists_subterm Term.is_Var (Thm.prop_of simplified)
              in
                if schematic_remains then
                  (resolve_tac ctxt @{thms simp_expr_triv} 1) thm
                else
                  ((resolve_tac ctxt [simp_expr_thm] 1) THEN (resolve_tac ctxt [simplified] 1)) thm
              end
          | _ => no_tac thm)
    )) ctxt

(*
 * Solve a goal of the form:
 *
 *    simp_expr P A B
 *
 * where both "A" and "B" are constants (i.e., not schematics).
 *
 * The tactic fails (rather than raising an exception) if the goal state has no
 * subgoals, or if the first subgoal contains schematic variables. It only
 * succeeds if "clarsimp" closes the goal completely (SOLVES).
 *)
fun solve_simp_expr_const_tac ctxt thm =
  if Thm.nprems_of thm = 0 then
    no_tac thm
  else if Term.exists_subterm Term.is_Var (Thm.term_of (Thm.cprem_of thm 1)) then
    no_tac thm
  else
    SOLVES (
      (resolve_tac ctxt @{thms simp_expr_solve_constant} 1)
      THEN (Clasimp.clarsimp_tac (map_opt_simpset true ctxt) 1)) thm

(*
 * Given a goal of the form:
 *
 *    monad_equiv P L R Q E
 *
 * simplify "P" and weaken the goal's precondition to the simplified form
 * (via "monad_equiv_weaken_pre''"). Goals of any other shape are left
 * untouched (all_tac).
 *
 * NOTE(review): earlier documentation said this also trims parts of "P" that
 * grow too large, to avoid exponential blow-up; the code below only
 * simplifies "P" -- confirm whether trimming is still intended.
 *
 * Raises CTERM if "P" contains schematic variables, or if the weakening rule
 * cannot be applied.
 *)
fun simp_monad_equiv_pre_tac ctxt =
  Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
    (fn thm =>
      case Thm.term_of (Thm.cprem_of thm 1) of
        Const (@{const_name Trueprop}, _) $
            (Const (@{const_name monad_equiv}, _) $ P $ _ $ _ $ _ $ _) =>
          let
            (* If P is schematic, we could end up with flex-flex pairs that Isabelle refuses to solve.
             * Our monad_equiv rules should never allow this to happen. *)
            val _ = if not (exists_subterm is_Var P) then () else
                raise CTERM ("autocorres: bad schematic in monad_equiv_pre", [Thm.cprem_of thm 1])
            (* Perform basic simplification of the term. *)
            val simp_thm = Simplifier.asm_full_rewrite (map_opt_simpset false ctxt) (Thm.cterm_of ctxt P)
          in
            (resolve_tac ctxt [@{thm monad_equiv_weaken_pre''} OF [simp_thm]] 1
              ORELSE (fn t => raise (CTERM ("autocorres: monad_equiv_pre failed to prove goal", [Thm.cprem_of t 1])))) thm
          end
      | _ =>
          all_tac thm
    )) ctxt

(*
 * Recursively simplify a monadic expression, using information gleaned from
 * earlier in the program to simplify parts of the program further down.
 *
 * Given the cterm "L", returns a theorem "L == R" where "R" is the simplified
 * expression found by repeatedly applying the "L2flow" rules.
 *)
fun monad_equiv ctxt ct =
let
  (* Mark context as being "invisible" to reduce warnings being printed. *)
  val ctxt = Context_Position.set_visible false ctxt

  (* Generate our top-level "monad_equiv" goal. *)
  val goal = @{mk_term "?L == ?R" (L)} (Thm.term_of ct)
    |> Thm.cterm_of ctxt
    |> Goal.init
    |> Utils.apply_tac "Creating object-level equality." (resolve_tac ctxt @{thms eq_reflection} 1)
    |> Utils.apply_tac "Creating 'monad_equiv' goal." (resolve_tac ctxt @{thms monad_equiv_eq} 1)

  (* Print a diagnostic if this branch fails. Printing is disabled by the
   * "false andalso" guard; flip it to debug failing branches (at most 5
   * messages are printed). *)
  val num_failures = ref 0
  fun print_failure_tac t =
    if (false andalso !num_failures < 5) then
      (num_failures := !num_failures + 1; (print_tac ctxt "Branch failed" THEN no_tac) t)
    else
      (no_tac t)

  (* Fetch theorems used in the simplification process. *)
  val thms = Utils.get_rules ctxt @{named_theorems L2flow}

  (* Tactic to blindly apply simplification rules.
   *
   * The subgoal index argument is ignored: the tactic always works on subgoal
   * 1. When it is re-applied to every subgoal produced by an "L2flow" rule
   * (THEN_ALL_NEW), it therefore processes those new subgoals one at a time
   * through subgoal 1. *)
  fun solve_goal_tac _ =
    (simp_monad_equiv_pre_tac ctxt 1)
    THEN DETERM (
      SOLVES
        ((solve_simp_expr_const_tac ctxt)
        ORELSE
        ((solve_simp_expr_tac ctxt 1)
        ORELSE
        ((resolve_tac ctxt thms THEN_ALL_NEW solve_goal_tac) 1
        ORELSE
        ((print_failure_tac))))))

  (* Apply the rules. *)
  val thm =
    Utils.apply_tac "Simplifying L2" (solve_goal_tac 1) goal
    |> Goal.finish ctxt
in
  thm
end

(*
 * A simproc implementing the "L2_gets_bind" rule. The rule, unfortunately, has
 * the ability to cause exponential growth in the spec size in some cases;
 * thus, we can only selectively apply it in cases where this doesn't happen.
 *
 * In particular, we propagate a "gets" into its usage if it is used at most
 * once, or if its value is atomic (a variable or a constant), which is cheap
 * to duplicate.
 *
 * Or, if the user asks for "no_opt", we only erase the "gets" if it is never used.
 * (Even with "no optimisation", we still want to get rid of control flow variables
 * emitted by c-parser. Hopefully the user won't mind if their own unused variables
 * also disappear.)
 *
 * NOTE(review): no "no_opt"-specific branch is visible in this simproc;
 * cleanup_thm simply does not install it when fast_mode >= 2. Confirm where
 * the "erase only if unused" behaviour is implemented.
 *)
val l2_gets_bind_thm = mk_meta_eq @{thm L2_gets_bind}
fun l2_gets_bind_simproc' ctxt cterm =
let
  (* Terms that are cheap to duplicate at every use site. *)
  fun is_atomic (Bound _) = true
    | is_atomic (Free _) = true
    | is_atomic (Const _) = true
    | is_atomic _ = false

  (* The "gets" has the shape "L2_gets (%_. A) n": the value function is the
   * FIRST argument of L2_gets (the last argument is the name list "n"), so
   * inspect it via strip_comb rather than matching the outermost application. *)
  fun is_simple lhs =
    (case Term.strip_comb lhs of
       (Const (@{const_name "L2_gets"}, _), Abs (_, _, body) :: _) => is_atomic body
     | _ => false)
in
  case Thm.term_of cterm of
    (Const (@{const_name "L2_seq"}, _) $ lhs $ Abs (_, _, rhs)) =>
      let
        (* Count syntactic occurrences of the bound value in the continuation
         * by substituting a marker Free for it. *)
        fun count_var_usage (a $ b) = count_var_usage a + count_var_usage b
          | count_var_usage (Abs (_, _, x)) = count_var_usage x
          | count_var_usage (Free ("_dummy", _)) = 1
          | count_var_usage _ = 0
        val count = count_var_usage (subst_bounds ([Free ("_dummy", dummyT)], rhs))
      in
        if count <= 1 orelse is_simple lhs then
          SOME l2_gets_bind_thm
        else
          NONE
      end
  | _ => NONE
end
val l2_gets_bind_simproc =
  Utils.mk_simproc' @{context}
    ("L2_gets_bind_simproc", ["L2_seq (L2_gets (%_. ?A) ?n) ?B"], l2_gets_bind_simproc')

(*
 * Simproc to clean up guards.
 *
 * Rewrites "L2_guard G" with the simpset "ss" (plus "conj_cong"); returns NONE
 * when the rewrite leaves the term unchanged, so the simplifier does not loop.
 *)
fun l2_guard_simproc' ss ctxt cterm =
let
  val simp_thm = Simplifier.asm_full_rewrite
      (Simplifier.add_cong @{thm HOL.conj_cong} (put_simpset ss ctxt)) cterm
  val (lhs, rhs) = Logic.dest_equals (Thm.prop_of (Drule.eta_contraction_rule simp_thm))
in
  if Term_Ord.fast_term_ord (lhs, rhs) = EQUAL then
    NONE
  else
    SOME simp_thm
end
fun l2_guard_simproc ss =
  Utils.mk_simproc' @{context} ("L2_guard_simproc", ["L2_guard ?G"], l2_guard_simproc' ss)

(*
 * Adjust "case_prod" commands so that constructs such as:
 *
 *    while C (%x. gets (case x of (a, b) => %s. P a b)) ...
 *
 * are transformed into:
 *
 *    while C (%(a, b). gets (%s. P a b)) ...
 *
 * Uses only HOL_ss plus the "L2_split_fixups" rules and congruences.
 *)
fun fix_L2_while_loop_splits_conv ctxt =
  Simplifier.asm_full_rewrite (
    put_simpset HOL_ss ctxt
      addsimps @{thms L2_split_fixups}
      |> fold Simplifier.add_cong @{thms L2_split_fixups_congs})

(*
 * Carry out flow-sensitive optimisations on the given 'thm'.
 *
 * "n" is the argument number to cleanup, counting from 1. So for example, if
 * our input theorem was "corres P A B", an "n" of 2 would simplify "A".
 * If n < 0, then the cleanup is applied to the -n-th argument from the end.
 *
 * If "fast_mode" is 0, perform flow-sensitive optimisations (which tend to be
 * time-consuming). If 1, only apply L2Peephole and L2Opt simplification rules.
 * If 2, do not use AutoCorres' simplification rules at all.
 *
 * Returns the cleaned-up theorem together with the traces that were produced
 * (peephole trace first, then flow trace; any that are NONE are omitted).
 *)
fun cleanup_thm ctxt thm fast_mode n do_trace =
let
  (* Don't print out warning messages. *)
  val ctxt = Context_Position.set_visible false ctxt

  (* Setup basic simplifier. *)
  fun basic_ss ctxt =
    put_simpset AUTOCORRES_SIMPSET ctxt
    |> (fn ctxt => if fast_mode < 2 then ctxt addsimps (Utils.get_rules ctxt @{named_theorems L2opt}) else ctxt)
    |> (fn ctxt => if fast_mode < 2 then ctxt addsimprocs [l2_gets_bind_simproc] else ctxt)
    |> (fn ctxt => ctxt addsimprocs [l2_guard_simproc (simpset_of ctxt)])
    |> map_opt_simpset false
  fun simp_conv ctxt =
    Drule.beta_eta_conversion
    then_conv (fix_L2_while_loop_splits_conv ctxt)
    then_conv (Simplifier.rewrite (basic_ss ctxt))

  (* Apply a conversion to the n-th argument of the theorem's conclusion. *)
  fun l2conv conv =
    Utils.remove_meta_conv (fn ctxt => Utils.nth_arg_conv n (conv ctxt)) ctxt

  (* Apply peephole optimisations to the theorem. *)
  val (new_thm, peephole_trace) =
    AutoCorresTrace.fconv_rule_maybe_traced ctxt (l2conv simp_conv) thm do_trace
    |> apfst Drule.eta_contraction_rule

  (* Apply flow-sensitive optimisations, and then re-apply simple simplifications. *)
  (* TODO: trace monad_equiv using trace_solve_tac rather than fconv_rule_traced *)
  val (new_thm, flow_trace) =
    if fast_mode = 0 then
      AutoCorresTrace.fconv_rule_maybe_traced ctxt (
        l2conv (fn ctxt =>
          monad_equiv ctxt
          then_conv (simp_conv (put_simpset AUTOCORRES_SIMPSET ctxt))
        )) new_thm do_trace
    else
      (new_thm, NONE)

  (* Beta/Eta normalise. *)
  val new_thm = Conv.fconv_rule (l2conv (K Drule.beta_eta_conversion)) new_thm
in
  (new_thm, List.mapPartial I [peephole_trace, flow_trace])
end

(*
 * Also tag the traces in a suitable format to be stored in AutoCorresData.
 *
 * NOTE(review): labels are paired with traces by position after cleanup_thm
 * has dropped NONE traces. When fast_mode <> 0 only the peephole trace can be
 * present, so the trace list is shorter than the label list; confirm how
 * Utils.zip treats lists of unequal length.
 *)
fun cleanup_thm_tagged ctxt thm fast_mode n do_trace phase =
  cleanup_thm ctxt thm fast_mode n do_trace
  |> apsnd (map AutoCorresData.SimpTrace #> Utils.zip [phase ^ " peephole opt", phase ^ " flow opt"])

end