1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory FP_Eval_Tests
8imports
9  Lib.FP_Eval
10  "HOL-Library.Sublist"
11begin
12
13section \<open>Controlling evaluation\<close>
14
15subsection \<open>Skeletons\<close>
16
17text \<open>
  A "skeleton" is a copy of a given term in which some subterms have been
  replaced by placeholders. For example,
19    "map (\<lambda>x. x + 1) [1, 2, 3]"
20  could have the following skeleton (among others):
21    "map (\<lambda>x. x + _) _"
22
  Each "_" stands for a schematic variable with an arbitrary name,
  or for a dummy pattern (which is implemented as an unnamed schematic).
25
26  FP_Eval uses skeletons internally to keep track of which parts of
27  a term it has already evaluated. In other words, schematic variables
28  indicate already-normalised subterms.
29
30  There are two useful predefined skeletons:
31\<close>
ML_val \<open> (* directive: evaluate all subterms *) FP_Eval.skel0 \<close>
33text \<open>
34  which is a special directive to evaluate all subterms; and
35\<close>
ML_val \<open> (* directive: skip evaluation *) FP_Eval.skel_skip \<close>
37text \<open>
38  which tells FP_Eval to skip evaluation.
39\<close>
40
41text \<open>
42  If we use the full FP_Eval interface, we can input a skeleton manually
43  and get the final skeleton as output.
44
45  It's useful to input a nontrivial skeleton for the following reasons:
46   \<bullet> if most of the term is known to be normalised, this can
47      save unnecessary computation.
48   \<bullet> if a tool runs FP_Eval on behalf of an end user, it may
49      want to avoid evaluating function calls in the user's input terms.
50      Alternatively, use explicit quotation terms
51      (see "Preventing evaluation", below) if finer control is needed.
52
53  The partial skeleton should match the structure of the input term.
54  If there is any mismatch, FP_Eval tries to be conservative and
55  evaluates the whole subterm (as if "skel0" had been given).
56  However, this should not be relied upon.
57  (FIXME: maybe stricter check in eval')
58
59  By default, FP_Eval attempts full evaluation of the input, so it
60  usually returns "skel_skip".
61
  However, evaluation is not complete when:
   \<bullet> the input skeleton skips some subterms;
   \<bullet> the term contains un-applied lambdas, which FP_Eval doesn't descend into;
   \<bullet> evaluation is delayed due to cong rules.
  In these cases, FP_Eval returns a partial skeleton.
67\<close>
68
69(* TODO: add examples *)
70
71subsection \<open>Congruence rules\<close>
72
73text \<open>
74  Use FP_Eval.add_cong or the second argument of FP_Eval.make_rules.
75  These accept weak congruence rules, e.g.:
76\<close>
(* Examples of weak congruence rules that can be passed to
   FP_Eval.add_cong / FP_Eval.make_rules. *)
thm if_weak_cong option.case_cong_weak
78
79text \<open>
80  Note that @{thm let_weak_cong} contains a hidden eta expansion, which FP_Eval
81  currently doesn't understand. Use our alternative:
82\<close>
(* FP_Eval's alternative to let_weak_cong. *)
thm FP_Eval.let_weak_cong'
84
ML_val \<open>
  (* Sanity check: the alternative rule must not have the same proposition
     as the stock let_weak_cong. *)
  let
    val stock = @{thm let_weak_cong};
    val ours = @{thm FP_Eval.let_weak_cong'};
  in
    @{assert} (not (Thm.eq_thm_prop (stock, ours)))
  end;
\<close>
88
89text \<open>
90  Example: avoid evaluating both branches of an @{const If}
91\<close>
92
ML_val \<open>
local
  (* Term to evaluate, paired with the input skeleton. *)
  val input = (@{cterm "subseq [0::nat, 2, 4] [0, 1, 2, 3, 4, 5]"}, Bound 0);

  (* Run FP_Eval with a rule set built from the given equations and cong rules. *)
  fun eval_with eqs cngs =
    FP_Eval.eval @{context} (FP_Eval.make_rules eqs cngs);

  (* Projections of the evaluation output: the result theorem and the
     performance counters. *)
  fun result_of out = fst (fst out);
  fun counters_of out = snd out;
in
  (* Without a cong rule the result blows up *)
  val r1 = result_of (eval_with @{thms list_emb_code} @{thms} input);
  (* With if_weak_cong the branches are not evaluated early *)
  val r2 = result_of (eval_with @{thms list_emb_code} @{thms if_weak_cong} input);

  (* Same comparison with enough equations to finish; look at the counters: *)
  val eqns = @{thms list_emb_code rel_simps simp_thms if_True if_False};
  val p1 = counters_of (eval_with eqns @{thms} input);
  val p2 = counters_of (eval_with eqns @{thms if_weak_cong} input);
end
\<close>
112
113subsection \<open>Preventing evaluation\<close>
114
115text \<open>
  Sometimes it is useful to prevent a function's arguments from being
  evaluated at all. This can be done by adding a cong rule with no premises:
118\<close>
context FP_Eval begin

(* "quote" is the identity function; it only marks a term so that a
   cong rule can be attached to it. *)
definition "quote x \<equiv> x"

(* Cong rule without premises for "quote". *)
lemma quote_cong:
  "quote x = quote x"
  by (rule refl)

(* Wraps an arbitrary term in "quote". *)
lemma quote:
  "x \<equiv> quote x"
  by (rule quote_def[symmetric])

end
128
ML_val \<open>
local
  (* Run FP_Eval with a rule set built from the given equations and cong rules. *)
  fun run eqs cngs =
    FP_Eval.eval @{context} (FP_Eval.make_rules eqs cngs);
  (* A quoted function update applied to an argument, plus the input skeleton. *)
  val quoted_input = (@{cterm "FP_Eval.quote (fun_upd f a b) c"}, Bound 0);
  val upd_eqns = @{thms fun_upd_def};
in
  (* Without cong rules, fp_eval evaluates every subterm *)
  val r1 = run upd_eqns @{thms} quoted_input;
  (* quote_cong holds back all quoted subterms.
     The returned skeleton shows which subterms are still unevaluated. *)
  val r2 = run upd_eqns @{thms FP_Eval.quote_cong} quoted_input;
  (* Drop quote_cong again and feed the previous right-hand side together
     with the previous skeleton back in: fp_eval carries on from there. *)
  val r3 = run upd_eqns @{thms} (apfst Thm.rhs_of (fst r2));
end;
\<close>
147
148
149section \<open>Tests\<close>
150
151subsection \<open>Basic tests\<close>
152
ML_val \<open>
local
  val rules = FP_Eval.make_rules @{thms arith_simps} @{thms};
  val input = (@{cterm "2 + 2 :: nat"}, Bound 0);
in
  (* The Var pattern only matches if the returned skeleton is a schematic
     variable; anything else fails the binding and thereby the test. *)
  val ((result, Var _), counters) = FP_Eval.eval @{context} rules input;
  (* The result theorem must state exactly 2 + 2 \<equiv> 4. *)
  val _ = @{assert} (@{term "(2 + 2 :: nat) \<equiv> 4"} aconv Thm.prop_of result);
end;
\<close>
163
164text \<open>fp_eval does not rewrite under lambda abstractions\<close>
ML_val \<open>
local
  (* An un-applied abstraction whose body contains a reducible sum. *)
  val lam = @{cterm "(\<lambda>x. x + (2 + 2::nat))"};
  fun run eqs cngs =
    FP_Eval.eval @{context} (FP_Eval.make_rules eqs cngs);
in
  val ((result, skel), _) = run @{thms arith_simps} @{thms} (lam, Bound 0);
  (* The skeleton is not a plain schematic variable ... *)
  val _ = @{assert} (not (is_Var skel));
  (* ... and the result theorem is reflexive, i.e. nothing was rewritten. *)
  val _ = @{assert} (Thm.is_reflexive result);
end
\<close>
175
176
177subsection \<open>Cong rules\<close>
178
179text \<open>Test for @{thm if_weak_cong}\<close>
ML_val \<open>
local
  val eqns = @{thms list_emb_code rel_simps refl if_True if_False};
  val input = (@{cterm "subseq [2::int,3,5,7,11] [0,1,2,3,4,5,6,7,8,9,10,11,12,13]"}, Bound 0);

  (* Evaluate the fixed input with the fixed equations and the given cong rules. *)
  fun run congs = FP_Eval.eval @{context} (FP_Eval.make_rules eqns congs) input;
  (* Right-hand side of the result theorem, as a term. *)
  fun value_of out = Thm.term_of (Thm.rhs_of (fst (fst out)));
  (* The "rewrites" performance counter, if present. *)
  fun rewrites_of out = Symtab.lookup (Symtab.make (snd out)) "rewrites";
in
  val r1 = run @{thms};
  val r2 = run @{thms if_weak_cong};

  (* Both runs must reach the same answer.  (The expected term has no
     binders, so aconv coincides with structural equality here.) *)
  val _ = @{assert} (value_of r1 aconv @{term True});
  val _ = @{assert} (value_of r2 aconv @{term True});

  (* Compare performance counters: the cong rule cuts the number of
     rewrites by orders of magnitude. *)
  val SOME r1_rewrs = rewrites_of r1;
  val SOME r2_rewrs = rewrites_of r2;
  val _ = @{assert} (r1_rewrs > 10000);
  val _ = @{assert} (r2_rewrs < 100);
end
\<close>
199
200
201subsection \<open>Advanced usage\<close>
202
203subsubsection \<open>Triggering breakpoints\<close>
ML_val \<open>
local
  val eqns = @{thms list.map arith_simps};
  val input = (@{cterm "map Suc [2 + 2 :: nat, 2 + 3, 2 + 4, 2 + 5, 2 + 6]"}, Bound 0);

  (* Breakpoint predicate: true exactly on the numeral 4. *)
  fun is_four ct = (Thm.term_of ct = @{term "4 :: nat"});
  (* Evaluate the fixed input through the full eval' interface with the
     given breakpoint predicate. *)
  fun run break =
    FP_Eval.eval' @{context} (K (K ())) break false (FP_Eval.make_rules eqns @{thms}) input;
  (* Right-hand side of a result theorem, as a term. *)
  fun value_of thm = Thm.term_of (Thm.rhs_of thm);
in
  (* Without a breakpoint the list is evaluated completely *)
  val ((result, Var _), _) = run (K false);
  val _ = @{assert} (value_of result aconv
                        @{term "[Suc 4, Suc 5, Suc 6, Suc 7, Suc 8]"});

  (* With the breakpoint, evaluation stops after "4" is encountered *)
  val ((result2, skel), _) = run is_four;
  val _ = @{assert} (value_of result2 aconv
                        @{term "map Suc [4::nat, 5, 2 + 4, 2 + 5, 2 + 6]"});
  (* The skeleton is not a plain schematic: evaluation is unfinished *)
  val _ = @{assert} (not (is_Var skel));
end;
\<close>
224
225subsubsection \<open>Rule set manipulation\<close>
ML_val \<open>
local
  (* Rule sets under test: the empty set and two overlapping sets that
     both contain if_weak_cong. *)
  val rules0 = FP_Eval.empty_rules;
  val rules1 = FP_Eval.make_rules @{thms simp_thms arith_simps} @{thms if_weak_cong};
  val rules2 = FP_Eval.make_rules @{thms simp_thms if_False if_True fun_upd_apply}
                                @{thms if_weak_cong option.case_cong_weak};
in
  (* dest_rules returns rules *)
  val _ = @{assert} (apply2 length (FP_Eval.dest_rules rules1) <> (0, 0));
  (* test round-trip conversion.
     eq_list compares element-wise and simply yields false when the lists
     differ in length, so a lossy round trip fails the assertion instead of
     raising ListPair.UnequalLengths from a zip. *)
  val _ = @{assert} (let val (thms, congs) = FP_Eval.dest_rules rules2;
                         val (thms', congs') = FP_Eval.dest_rules (FP_Eval.make_rules thms congs);
                     in eq_list Thm.eq_thm_prop (thms, thms')
                        andalso eq_list Thm.eq_thm_prop (congs, congs') end);
  (* test that merging succeeds and actually merges rules:
     the merged set must contain exactly the union of both inputs,
     for equations and cong rules alike. *)
  fun test_merge r1 r2 =
    let val (r1_eqns, r1_congs) = FP_Eval.dest_rules r1;
        val (r2_eqns, r2_congs) = FP_Eval.dest_rules r2;
        val (r12_eqns, r12_congs) = FP_Eval.dest_rules (FP_Eval.merge_rules (r1, r2));
    in eq_set Thm.eq_thm_prop (r12_eqns, union Thm.eq_thm_prop r1_eqns r2_eqns)
       andalso eq_set Thm.eq_thm_prop (r12_congs, union Thm.eq_thm_prop r1_congs r2_congs)
    end;

  val _ = @{assert} (test_merge rules0 rules1);
  val _ = @{assert} (test_merge rules0 rules2);
  val _ = @{assert} (test_merge rules1 rules2);

  (* test that rules with conflicting arity are not allowed:
     rules2 has fun_upd_apply, the other set has fun_upd_def;
     merge_rules is expected to raise, which "try" turns into NONE. *)
  val conflict_arity = FP_Eval.make_rules @{thms fun_upd_def} @{thms};
  val _ = @{assert} (is_none (try FP_Eval.merge_rules (rules2, conflict_arity)));
  (* test that installing different cong rules is not allowed:
     an instantiated variant of if_weak_cong clashes with the one in rules2. *)
  val conflict_cong = FP_Eval.make_rules @{thms} @{thms if_weak_cong[OF refl]};
  val _ = @{assert} (is_none (try FP_Eval.merge_rules (rules2, conflict_cong)));
end
\<close>
261
262subsubsection \<open>Ordering of rules\<close>
263text \<open>
264  In the current implementation, equations are picked based on the default
265  Net ordering. This should be improved in the future.
266\<close>
ML_val \<open>
local
  val input = (@{cterm "list_all (\<lambda>x::nat. x \<le> x) [100000, 314159, 2718281845]"}, Bound 0);
  val basic_eqns = @{thms list_all_simps rel_simps simp_thms};
  (* Evaluate the fixed input with the given equations and no cong rules. *)
  fun run eqs = FP_Eval.eval @{context} (FP_Eval.make_rules eqs @{thms}) input;
  (* Value of the "rewrites" performance counter. *)
  fun rewrites cs = the (Symtab.lookup (Symtab.make cs) "rewrites");
in
  (* \<le> is evaluated the slow way *)
  val ((r1, _), counters1) = run basic_eqns;
  (* order.refl offered as a shortcut for \<le>, before and after the basic equations *)
  val ((r2, _), counters2) = run (@{thms order.refl} @ basic_eqns);
  val ((r3, _), counters3) = run (basic_eqns @ @{thms order.refl});

  (* Bug: the shortcut is never used -- all three runs have identical counters *)
  val _ = @{assert} (counters1 = counters2 andalso counters2 = counters3);

  (* desired outcome: without rel_simps the shortcut is taken and fewer
     rewrites are needed *)
  val ((r4, _), counters4) = run @{thms list_all_simps simp_thms order.refl};
  val _ = @{assert} (rewrites counters4 < rewrites counters1);
end
\<close>
289
290
291subsection \<open>Miscellaneous and regression tests\<close>
292
293text \<open>Test for partial arity and arg_conv\<close>
ML_val \<open>
local
  (* Run FP_Eval with a rule set built from the given equations and cong rules. *)
  fun eval eqns congs =
    FP_Eval.eval @{context} (FP_Eval.make_rules eqns congs);
  (* Right-hand side of a result theorem, as a term. *)
  fun rhs_term thm = Thm.term_of (Thm.rhs_of thm);

  val input = (@{cterm "(if 2 + 2 = 4 then Suc else id) x"}, FP_Eval.skel0);
  (* Need to build these manually, as @{cterm} itself does beta-eta normalisation *)
  val input_abs1 = (fold (fn x => fn f => Thm.apply f x)
                         [@{cterm "f::nat\<Rightarrow>nat"}, @{cterm "4::nat"}]
                         @{cterm "\<lambda>f. fun_upd (f::nat\<Rightarrow>nat) (2+2) y"},
                    FP_Eval.skel0);
  val input_abs2 = (fold (fn x => fn f => Thm.apply f x)
                         [@{cterm "f::nat\<Rightarrow>nat"}, @{cterm "4::nat"}]
                         @{cterm "\<lambda>f x z. z + fun_upd (f::nat\<Rightarrow>nat) (2+2) y x"},
                    FP_Eval.skel0);
  (* check: input_abs2 is still a triple abstraction applied to two
     arguments, i.e. it has not been beta-reduced while being built *)
  val Abs (_, _, Abs (_, _, Abs _)) $ _ $ _ = Thm.term_of (fst input_abs2);
in
  (* Results are compared with aconv (equality up to bound-variable names),
     as in the other tests, rather than with structural equality, so the
     checks do not depend on the names of bound variables. *)

  (* Head (= If) is rewritten *)
  val ((result, Var _), _) = eval @{thms arith_simps refl if_True} @{thms} input;
  val _ = @{assert} (rhs_term result aconv @{term "Suc x"});

  (* Head is not rewritten *)
  val ((result2, Var _ $ _), _) = eval @{thms arith_simps refl if_False} @{thms} input;
  val _ = @{assert} (rhs_term result2 aconv @{term "(if True then Suc else id) x"});

  (* Head is Abs *)
  val ((result3, Var _), _) = eval @{thms arith_simps refl fun_upd_apply if_True} @{thms} input_abs1;
  val _ = @{assert} (rhs_term result3 aconv @{term "y::nat"});

  (* Partially applied Abs *)
  val ((result4, Abs _), _) = eval @{thms arith_simps refl fun_upd_apply if_True} @{thms} input_abs2;
  val _ = @{assert} (rhs_term result4 aconv @{term "\<lambda>z. z + fun_upd (f::nat\<Rightarrow>nat) (2+2) y 4"});
end;
\<close>
327
328text \<open>Check that skel_skip is not returned for intermediate Abs\<close>
ML_val \<open>
local
  val rules = FP_Eval.make_rules @{thms list.map prod.case arith_simps} @{thms};
  val input = (@{cterm "map (\<lambda>((x, y), z). ((id x, id y), id z)) [((a::'a, b::'b), c::'c)]"}, Bound 0);
  (* Expected right-hand side; it has no binders, so aconv coincides with
     structural equality here. *)
  val expected = @{term "[((id a::'a, id b::'b), id c::'c)]"};
in
  (* The final skeleton must be a schematic variable, and the nested case
     abstractions must have been applied all the way through. *)
  val ((result, Var _), counters) = FP_Eval.eval @{context} rules input;
  val _ = @{assert} (Thm.term_of (Thm.rhs_of result) aconv expected);
end;
\<close>
338
339text \<open>Test conversion of some basic non-equation rules to equations\<close>
ML_val \<open>
let
  val thms = @{thms simp_thms rel_simps arith_simps};
  (* Rules for which maybe_convert_eqn yields no equation. *)
  val rejected = filter (is_none o FP_Eval.maybe_convert_eqn) thms;
in
  (* Every one of these basic rules must be convertible. *)
  @{assert} (null rejected)
end
\<close>
344
345end