1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory ProveGraphRefine
8imports GraphRefine GlobalsSwap FieldAccessors AsmSemanticsRespects CommonOpsLemmas
9begin
10
11ML \<open>
(* Global switch for trace_tac below: set to true to print the proof
   state at the tracing points of the tactics in this theory. *)
val do_trace = false;

(* trace_tac ctxt msg i: when do_trace is set, print the proof state with
   a label naming subgoal index i and msg; otherwise leave the state
   unchanged (all_tac). *)
fun trace_tac ctxt msg i =
    if not do_trace then all_tac
    else let
        val label = "(tracing subgoal " ^ Int.toString i ^ "): " ^ msg
      in print_tac ctxt label end;
18\<close>
19
(* A heap_update through a pointer that is ptr_safe for a heap type
   description satisfying htd_safe preserves const_globals_in_memory.
   decompose_mem_goals_init resolves the globals_list_distinct premise
   from the context and then applies the rule with eresolve_tac, so
   const_globals_in_memory is the premise being eliminated. *)
lemma const_globals_in_memory_heap_updateE:
  "\<lbrakk> globals_list_distinct D symtab gs;
    const_globals_in_memory symtab gs hmem;
    htd_safe D htd;
    ptr_safe (p :: ('a :: wf_type) ptr) htd \<rbrakk>
     \<Longrightarrow> const_globals_in_memory symtab gs (heap_update p val hmem)"
  by (simp add: const_globals_in_memory_heap_update)
27
(* Reading through a ptr_safe pointer yields the same value with or
   without the globals swap applied to the heap. The ML helper
   get_disjoint_h_val_globals_swap instantiates this rule with
   global_acc_valid and globals_list_distinct from the context. *)
lemma disjoint_h_val_globals_swap_insert:
  "\<lbrakk> global_acc_valid g_hrs g_hrs_upd;
     globals_list_distinct D symtab xs;
     htd_safe D htd;
     ptr_safe (p :: ('a :: wf_type) ptr) htd \<rbrakk>
     \<Longrightarrow> h_val (hrs_mem (g_hrs (globals s))) p
         = h_val (hrs_mem (g_hrs (globals_swap g_hrs g_hrs_upd symtab xs (globals s)))) p"
  (* The current apparatus produces goals where the Simpl-derived
     h_vals are applied to a globals swap and the graph-derived
     h_vals lack it. We therefore *add* a globals swap, since that is
     the case in which we can prove ptr_safe. *)
  apply (rule disjoint_h_val_globals_swap[symmetric], assumption+)
  apply (clarsimp simp: ptr_safe_def htd_safe_def del: subsetI)
  apply blast
  done
43
(* Variant of disjoint_heap_update_globals_swap in which the heap_update
   sits inside g_hrs_upd on the left-hand side. It commutes the update
   with the globals swap so that the update appears outermost. Used by
   clean_heap_upd_swap (after resolving global_acc_valid and
   globals_list_distinct from the context). *)
lemma disjoint_heap_update_globals_swap_rearranged:
  "\<lbrakk> global_acc_valid g_hrs g_hrs_upd;
     globals_list_distinct D symtab xs;
     htd_safe D htd;
     ptr_safe (p :: ('a :: wf_type) ptr) htd \<rbrakk>
     \<Longrightarrow> hrs_mem (g_hrs (globals_swap g_hrs g_hrs_upd symtab xs (g_hrs_upd (hrs_mem_update (heap_update p v)) gs)))
         = heap_update p v (hrs_mem (g_hrs (globals_swap g_hrs g_hrs_upd symtab xs gs)))"
  apply (subst disjoint_heap_update_globals_swap[symmetric], assumption+)
   apply (clarsimp simp: ptr_safe_def htd_safe_def del: subsetI)
   apply blast
  apply (simp add: global_acc_valid_def hrs_mem_update)
  done
56
(* Setting the memory component of a heap-raw-state to its own memory is
   the identity. *)
lemma hrs_mem_update_triv:
  "hrs_mem_update (\<lambda>_. hrs_mem x) x = x"
  by (cases x, simp_all add: hrs_mem_update_def hrs_mem_def)

(* Strengthen an h_t_valid fact with the ptr_safe fact derivable from it,
   keeping the original fact. *)
lemma h_t_valid_orig_and_ptr_safe:
  "h_t_valid d g p \<Longrightarrow> h_t_valid d g p \<and> ptr_safe p d"
  by (simp add: h_t_valid_ptr_safe)

(* For an in-bounds index the coerce flag of array_ptr_index does not
   matter; prove_ptr_safe substitutes with this to normalise the flag. *)
lemma array_ptr_index_coerce:
  fixes p :: "(('a :: c_type)['b :: finite]) ptr"
  shows "n < CARD ('b)
    \<Longrightarrow> array_ptr_index p False n = array_ptr_index p True n"
  by (simp add: array_ptr_index_def)
70
(* Upper bounds for unat of a word sum/product by the corresponding
   natural-number sum/product of the operands' unats. *)
lemma unat_mono_thms:
  "unat (a + b :: ('a :: len) word) \<le> unat a + unat b"
  "unat (a * b) \<le> unat a * unat b"
  by (simp_all add: unat_word_ariths)

(* Transitivity-style introduction rules: bound unat a by an
   intermediate x, then bound x. *)
lemma unat_mono_intro:
  "unat a \<le> x \<Longrightarrow> x < b \<Longrightarrow> unat a < b"
  "unat a \<le> x \<Longrightarrow> x \<le> b \<Longrightarrow> unat a \<le> b"
  "unat a \<le> x \<Longrightarrow> x \<le> 0 \<Longrightarrow> unat a = 0"
  by simp_all

(* For words, "not greater than zero" means "equal to zero". *)
lemma word_neq_0_conv_neg_conv:
  "(\<not> 0 < (n :: ('a :: len) word)) = (n = 0)"
  by (cases "n = 0", simp_all)

(* unat_ucast_upcast specialised to the 32-bit to 64-bit upcast. *)
lemmas unat_ucast_upcasts =
  unat_ucast_upcast[OF is_up[where 'a=32 and 'b=64, simplified]]
88
(* Forget the signedness tag of a word: a ucast between a signed word
   and the unsigned word of the same length 'a. *)
definition
  drop_sign :: "('a :: len) signed word \<Rightarrow> 'a word"
where
  "drop_sign = ucast"

(* The signed interpretation of the bits is unchanged by drop_sign. *)
lemma sint_drop_sign_isomorphism:
  "sint (drop_sign x) = sint x"
  by (simp add: drop_sign_def word_sint_msb_eq uint_up_ucast is_up_def
                source_size_def target_size_def word_size msb_ucast_eq)
98
(* drop_sign commutes with equality, (signed and unsigned) comparisons and
   arithmetic, so statements about signed words can be restated on the
   corresponding unsigned words. *)
lemma drop_sign_isomorphism_ariths:
  "(x = y) = (drop_sign x = drop_sign y)"
  "(x < y) = (drop_sign x < drop_sign y)"
  "(x \<le> y) = (drop_sign x \<le> drop_sign y)"
  "(x <s y) = (drop_sign x <s drop_sign y)"
  "(x <=s y) = (drop_sign x <=s drop_sign y)"
  "drop_sign (x + y) = drop_sign x + drop_sign y"
  "drop_sign (x - y) = drop_sign x - drop_sign y"
  "drop_sign (x * y) = drop_sign x * drop_sign y"
  "drop_sign (x div y) = drop_sign x div drop_sign y"
  "drop_sign (x sdiv y) = drop_sign x sdiv drop_sign y"
  "drop_sign (- y) = - drop_sign y"
  "drop_sign (if P then x else y) = (if P then drop_sign x else drop_sign y)"
  "drop_sign (w ^ n) = drop_sign w ^ n"
  by (simp_all add: drop_sign_def word_less_def
                    word_le_def word_sless_def word_sle_def
                    sint_drop_sign_isomorphism[unfolded drop_sign_def]
                    word_uint.Rep_inject[symmetric]
                    uint_up_ucast is_up_def source_size_def
                    target_size_def word_size
                    uint_word_arith_bintrs
                    word_arith_power_alt
                    uint_word_of_int
                    uint_div_alt sdiv_word_def sdiv_int_def
               del: word_uint.Rep_inject)
124
(* drop_sign commutes with the bitwise operations, shifts and casts.
   Each equation is proved bit by bit (word_eqI plus test_bit lemmas). *)
lemma drop_sign_isomorphism_bitwise:
  "drop_sign (x AND y) = drop_sign x AND drop_sign y"
  "drop_sign (bitOR x y) = bitOR (drop_sign x) (drop_sign y)"
  "drop_sign (x XOR y) = drop_sign x XOR drop_sign y"
  "drop_sign (~~ y) = ~~ drop_sign y"
  "drop_sign (shiftl x n) = shiftl (drop_sign x) n"
  "drop_sign (shiftr x n) = shiftr (drop_sign x) n"
  "drop_sign (sshiftr x n) = sshiftr (drop_sign x) n"
  "drop_sign (ucast z) = ucast z"
  "drop_sign (scast z) = scast z"
  "ucast x = ucast (drop_sign x)"
  "scast x = scast (drop_sign x)"
  by (rule word_eqI
          | simp add: word_size drop_sign_def nth_ucast nth_shiftl
                      nth_shiftr nth_sshiftr word_ops_nth_size
                      nth_scast
          | safe
          | simp add: test_bit_bin)+
143
(* drop_sign on natural-number literals. *)
lemma drop_sign_of_nat:
  "drop_sign (of_nat n) = of_nat n"
  by (simp add: drop_sign_def ucast_of_nat is_down_def
                target_size_def source_size_def word_size)

(* drop_sign preserves the bit-list representation. *)
lemma drop_sign_to_bl:
  "to_bl (drop_sign w) = to_bl w"
  by (simp add: drop_sign_def to_bl_ucast)

(* drop_sign commutes with the clz/ctz/popcount operations. *)
lemma drop_sign_extra_bl_ops:
  "drop_sign (bv_clz w) = bv_clz (drop_sign w)"
  "drop_sign (bv_ctz w) = bv_ctz (drop_sign w)"
  "drop_sign (bv_popcount w) = bv_popcount (drop_sign w)"
  by (simp_all add: bv_clz_def bv_ctz_def bv_popcount_def drop_sign_of_nat
                    word_ctz_def word_clz_def pop_count_def drop_sign_to_bl)

(* drop_sign on numerals, negated numerals, 0 and 1 (simp rules). *)
lemma drop_sign_number[simp]:
  "drop_sign (numeral n) = numeral n"
  "drop_sign (- numeral n) = - numeral n"
  "drop_sign 0 = 0" "drop_sign 1 = 1"
  by (simp_all add: drop_sign_def ucast_def)

lemma drop_sign_minus_1[simp]:
  "drop_sign (-1) = (-1)"
  by (clarsimp simp add: drop_sign_def ucast_def uint_word_ariths)

(* The integer/nat projections are unchanged by drop_sign; oriented so
   that they introduce drop_sign. *)
lemma drop_sign_projections:
  "unat x = unat (drop_sign x)"
  "uint x = uint (drop_sign x)"
  "sint x = sint (drop_sign x)"
  apply (simp_all add: sint_drop_sign_isomorphism)
  apply (auto simp: unat_def uint_up_ucast drop_sign_def is_up_def
                    source_size_def target_size_def word_size)
  done

(* All drop_sign transfer rules together; part of the unpack simpset in
   prove_mem_equality_unpack_simpset. *)
lemmas drop_sign_isomorphism
    = drop_sign_isomorphism_ariths drop_sign_projections
        drop_sign_isomorphism_bitwise drop_sign_of_nat
        drop_sign_extra_bl_ops ucast_id
183
(* Reading a signed word from memory, then dropping the sign, is the same
   as reading an unsigned word through the coerced pointer. *)
lemma drop_sign_h_val[simp]:
  "drop_sign (h_val hp p :: ('a :: len8) signed word) = h_val hp (ptr_coerce p)"
  using len8_dv8[where 'a='a]
  apply (simp add: h_val_def drop_sign_def)
  apply (simp add: from_bytes_ucast_isom word_size size_of_def typ_info_word)
  done

(* Writing a signed word is the same as writing its unsigned counterpart
   through the coerced pointer (simp rule; the right-hand side is at the
   unsigned type). *)
lemma drop_sign_heap_update[simp]:
  "heap_update p v = heap_update (ptr_coerce p) (drop_sign v)"
  using len8_dv8[where 'a='a]
  apply (simp add: heap_update_def drop_sign_def fun_eq_iff)
  apply (simp add: to_bytes_ucast_isom word_size size_of_def typ_info_word)
  done

(* Signed and unsigned words of the same length have the same
   (normalised) type information, hence the same alignment and size. *)
lemma typ_uinfo_t_signed_word:
  "typ_uinfo_t TYPE (('a :: len8) signed word) = typ_uinfo_t TYPE ('a word)"
  using len8_dv8[where 'a='a]
  apply (simp add: typ_uinfo_t_def typ_info_word)
  apply (clarsimp simp: field_norm_def fun_eq_iff)
  apply (simp add: word_rsplit_rcat_size word_size)
  done

lemma align_td_signed_word:
  "align_td (typ_info_t TYPE (('a :: len8) signed word))
    = align_td (typ_info_t TYPE (('a :: len8) word))"
  using arg_cong[where f=align_td, OF typ_uinfo_t_signed_word[where 'a='a]]
  by (simp add: typ_uinfo_t_def)

lemma size_td_signed_word:
  "size_td (typ_info_t TYPE (('a :: len8) signed word))
    = size_td (typ_info_t TYPE (('a :: len8) word))"
  by (simp add: typ_info_word)

(* ptr_inverse_safe for a signed word pointer reduces to the unsigned
   case; used by get_globals_rewrites and normalise_mem_accs. *)
lemma pointer_inverse_safe_sign:
  "ptr_inverse_safe (ptr :: (('a :: len8) signed word ptr))
    = ptr_inverse_safe (ptr_coerce ptr :: 'a word ptr)"
  by (simp add: fun_eq_iff ptr_inverse_safe_def s_footprint_def
                c_guard_def ptr_aligned_def c_null_guard_def
                typ_uinfo_t_signed_word align_td_signed_word align_of_def
                size_of_def size_td_signed_word)
224
(* Turn equations with a literal Ptr on either side into equations on
   ptr_val. *)
lemma ptr_equalities_to_ptr_val:
  "(Ptr addr = p) = (addr = ptr_val p)"
  "(p = Ptr addr) = (ptr_val p = addr)"
  by (simp | cases p)+

(* unat of a ucast: unchanged when the target is at least as wide as the
   source, otherwise reduced modulo 2 ^ (target length). *)
lemma unat_ucast_if_up:
  "unat (ucast (x :: ('a :: len) word) :: ('b :: len) word)
    = (if len_of TYPE('a) \<le> len_of TYPE('b) then unat x else unat x mod 2 ^ len_of TYPE ('b))"
  apply (simp, safe intro!: unat_ucast unat_ucast_upcast)
  apply (simp add: is_up_def source_size_def target_size_def word_size)
  done

(* FIXME: these 2 duplicated from crefine *)
lemma Collect_const_mem:
  "(x \<in> (if P then UNIV else {})) = P"
  by simp

lemma typ_uinfo_t_diff_from_typ_name:
  "typ_name (typ_info_t TYPE ('a :: c_type)) \<noteq> typ_name (typ_info_t TYPE('b :: c_type))
    \<Longrightarrow> typ_uinfo_t (aty :: 'a itself) \<noteq> typ_uinfo_t (bty :: 'b itself)"
  by (clarsimp simp: typ_uinfo_t_def td_diff_from_typ_name)

(* ptr_add_assertion_def instantiated at literal offsets (numerals,
   negated numerals, 0 and 1) and simplified. *)
lemmas ptr_add_assertion_unfold_numeral
    = ptr_add_assertion_def[where offs="numeral n" for n, simplified]
      ptr_add_assertion_def[where offs="uminus (numeral n)" for n, simplified]
      ptr_add_assertion_def[where offs=0, simplified]
      ptr_add_assertion_def[where offs=1, simplified]
251      ptr_add_assertion_def[where offs=1, simplified]
252
(* Clamp x to at most n: x itself when unat x \<le> n, otherwise of_nat n. *)
definition machine_word_truncate_nat :: "nat => machine_word \<Rightarrow> machine_word"
where
  "machine_word_truncate_nat n x = (if unat x \<le> n then x else of_nat n)"

lemma machine_word_truncate_noop:
  "unat x < Suc n \<Longrightarrow> machine_word_truncate_nat n x = x"
  by (simp add: machine_word_truncate_nat_def)

(* General form: a foldr-built If-chain testing x against each of ns
   equals z when every branch and the default agree with z. *)
lemma fold_of_nat_eq_Ifs_proof:
  "\<lbrakk> unat (x :: machine_word) \<notin> set ns \<Longrightarrow> y = z;
      \<And>n. n \<in> set ns \<Longrightarrow> x = of_nat n \<Longrightarrow> f n = z \<rbrakk>
    \<Longrightarrow> foldr (\<lambda>n v. if x = of_nat n then f n else v) ns y = z"
  apply (induct ns)
   apply simp
  apply (atomize(full))
  apply clarsimp
  done

(* An If-chain testing x against 0 .. m - 1 with default f m is f applied
   to x clamped to m. The ML function fold_of_nat_eq_Ifs instantiates f
   and m and expands the foldr. word_bits is replaced via
   word_bits_conv. *)
lemma fold_of_nat_eq_Ifs[simplified word_bits_conv]:
  "m < 2 ^ word_bits
    \<Longrightarrow> foldr (\<lambda>n v. if x = of_nat n then f n else v) [0 ..< m] (f m)
        = f (unat (machine_word_truncate_nat m x))"
  apply (rule fold_of_nat_eq_Ifs_proof)
   apply (simp_all add: machine_word_truncate_nat_def unat_of_nat word_bits_def)
  done
278
lemma of_int_sint_scast:
  "of_int (sint x) = scast x"
  by (simp add: scast_def word_of_int)

(* Primed variant of less_is_non_zero_p1 with the addition commuted. *)
lemma less_is_non_zero_p1':
  fixes a :: "'a :: len word"
  shows "a < k \<Longrightarrow> 1 + a \<noteq> 0"
  by (metis less_is_non_zero_p1 add.commute)

(* Swap the last two operands of a left-nested sum/product; used for the
   final AC normalisation in prove_mem_equality_unchecked. *)
lemma(in comm_semiring_1) add_mult_comms:
  "a + b + c = a + c + b"
  "a * b * c = a * c * b"
  by (rule semiring_normalization_rules)+

(* Indexing into an updated array, as a single If. *)
lemma array_index_update_If:
  "i < CARD ('b :: finite)
    \<Longrightarrow> Arrays.index (Arrays.update arr j x) i
        = (if i = j then x else Arrays.index (arr :: ('a['b])) i)"
  by simp
298
\<comment> \<open>Of the assumptions, only pos is needed to prove the conclusion.
    The guard assumptions are there to ensure that when used as a simp rule,
    the RHS array pointer gets an appropriate type.\<close>
lemma ptr_safe_ptr_add_array_ptr_index_int:
  assumes guard: "ptr_safe (Ptr p::('a['b]) ptr) htd" (* "nat i < CARD('b)" *)
  assumes pos: "0 \<le> i"
  shows "(Ptr p::'a::c_type ptr) +\<^sub>p i = array_ptr_index (Ptr p::('a['b::finite]) ptr) False (nat i)"
  using pos by (simp add: array_ptr_index_def)

(* Same as above for a signed-word index, converted with unat. *)
lemma ptr_safe_ptr_add_array_ptr_index_sint:
  assumes guard: "ptr_safe (Ptr p::('a['b]) ptr) htd" "i <s of_nat CARD('b)"
  assumes pos: "0 <=s i"
  shows "(Ptr p::'a::c_type ptr) +\<^sub>p sint i = array_ptr_index (Ptr p::('a['b::finite]) ptr) False (unat i)"
  using pos by (simp add: array_ptr_index_def int_unat sint_eq_uint word_sle_msb_le)

(* Both forms; used as simp rules by prove_ptr_safe and
   normalise_mem_accs. *)
lemmas ptr_safe_ptr_add_array_ptr_index =
  ptr_safe_ptr_add_array_ptr_index_int
  ptr_safe_ptr_add_array_ptr_index_sint

(* A ptr_safe array pointer is also ptr_safe as a pointer to its first
   element. *)
lemma ptr_safe_Array_element_0:
  "ptr_safe (PTR('a::mem_type['b::finite]) p) htd \<Longrightarrow> ptr_safe (PTR('a) p) htd"
  by (drule ptr_safe_Array_element[where coerce=False and n=0]; simp add: array_ptr_index_def)
321
322ML \<open>
(* preserve_skel_conv consts arg_conv: a conversion that keeps the
   application skeleton built from the constants named in consts and
   applies arg_conv to every maximal subterm whose head is not one of
   those constants. Applications headed by a skeleton constant are split
   with combination_conv and converted recursively; a bare skeleton
   constant is left unchanged. *)
fun preserve_skel_conv consts arg_conv ct = let
    val (head, args) = strip_comb (Thm.term_of ct)
    val recur = preserve_skel_conv consts arg_conv
    val in_skeleton = (case head of
        Const (c, _) => member (op =) consts c
      | _ => false)
  in
    if not in_skeleton then arg_conv ct
    else if null args then Conv.all_conv ct
    else Conv.combination_conv recur recur ct
  end
331
(* fold_of_nat_eq_Ifs ctxt tm: recognise an If-chain over a machine word
     If (x = 0) v0 (If (x = 1) v1 (... (If (x = k - 1) v(k-1) vk)))
   whose tested constants are exactly 0 .. k - 1 in order.
   A pattern f is abstracted by comparing v0 with v1: positions holding 0
   in v0 and Suc 0 in v1 become the bound index, and everything else must
   agree. The lemma fold_of_nat_eq_Ifs is then instantiated with f and
   m = k, its foldr over [0 ..< k] is expanded, and it is turned into a
   meta-equation. Finally, everything outside the ==>/Trueprop/==/If
   skeleton is normalised with ctxt's simplifier (preserve_skel_conv).
   The result may still carry the lemma's bound premise on m.
   Raises TERM when tm does not have the expected shape; the simproc
   fold_of_nat_eq_Ifs_simproc calls this under try. *)
fun fold_of_nat_eq_Ifs ctxt tm = let
    fun recr (Const (@{const_name If}, _)
            $ (@{term "(=) :: machine_word => _"} $ _ $ n) $ y $ z)
        = (SOME n, y) :: recr z
      | recr t = [(NONE, t)]
    val (ns, vs) = recr tm |> map_split I
    val ns = map_filter I ns |> map (HOLogic.dest_number #> snd)
    (* the tested constants must be exactly 0, 1, ..., k - 1 *)
    val _ = (ns = (0 upto (length ns - 1)))
        orelse raise TERM ("fold_of_nat_eq_Ifs: ns", [tm])
    val _ = length vs > 1
        orelse raise TERM ("fold_of_nat_eq_Ifs: no If", [tm])
    val n = @{term "n_hopefully_uniq :: nat"}
    (* abstract over the index by comparing the values for 0 and 1 *)
    fun get_pat @{term "0 :: nat"} @{term "Suc 0 :: nat"} = n
      | get_pat (f $ x) (g $ y) = get_pat f g $ get_pat x y
      | get_pat (a as Abs (s, T, t)) (a' as Abs (_, T', t'))
          = ((T = T' orelse raise TERM ("fold_of_nat_eq_Ifs: get_pat", [a, a']))
              ; Abs (s, T, get_pat t t'))
      | get_pat t t' = (t aconv t' orelse raise TERM ("fold_of_nat_eq_Ifs: get_pat", [t, t'])
              ; t)
    val pat = lambda n (get_pat (nth vs 0) (nth vs 1))
    val m = HOLogic.mk_number @{typ nat} (length vs - 1)
    val conv = preserve_skel_conv [fst (dest_Const @{term "(==>)"}),
            @{const_name Trueprop}, fst (dest_Const @{term "(==)"}),
            @{const_name If}]
        (Simplifier.rewrite ctxt)
    val thm = @{thm fold_of_nat_eq_Ifs}
      |> infer_instantiate ctxt [(("f",0), Thm.cterm_of ctxt pat),
          (("m",0), Thm.cterm_of ctxt m)]
      |> simplify (put_simpset HOL_basic_ss ctxt
          addsimprocs [Word_Bitwise_Tac.expand_upt_simproc]
          addsimps @{thms foldr.simps id_apply o_apply})
      |> mk_meta_eq
      |> Conv.fconv_rule conv
  in thm end
366
(* Simproc folding machine-word If-chains (matched from If (x = 0) y z)
   with fold_of_nat_eq_Ifs. Any exception from the construction is caught
   by try, in which case the simproc does not fire. *)
val fold_of_nat_eq_Ifs_simproc = let
    val thy_ctxt = Proof_Context.init_global @{theory}
    fun proc _ ctxt ct = try (fold_of_nat_eq_Ifs ctxt) (Thm.term_of ct)
  in
    Simplifier.make_simproc thy_ctxt "fold_of_nat_eq_Ifs"
      { lhss = [@{term "If (x = 0) y z"}]
      , proc = proc
      }
  end
372
(* Rewrite tm (an application of a constant) with the _def theorem of its
   head constant together with the _def theorems of every constant
   argument whose name ends in "_'proc". Raises if the head is not a
   constant or a required _def theorem is missing. *)
fun unfold_assertion_data_get_set_conv ctxt tm = let
    val (head, args) = strip_comb tm
    val head_name = fst (dest_Const head)
    fun proc_name (Const (s, _)) =
          if String.isSuffix "_'proc" s then SOME s else NONE
      | proc_name _ = NONE
    val def_names = map (suffix "_def") (head_name :: map_filter proc_name args)
    val defs = map (Proof_Context.get_thm ctxt) def_names
  in Simplifier.rewrite (ctxt addsimps defs) (Thm.cterm_of ctxt tm) end
380
(* Simproc unfolding ghost_assertion_data_get / ghost_assertion_data_set
   applications via unfold_assertion_data_get_set_conv. It always returns
   SOME; exceptions from the conversion (e.g. a missing _def theorem) are
   not caught here and propagate to the simplifier. *)
val unfold_assertion_data_get_set = Simplifier.make_simproc
  (Proof_Context.init_global @{theory}) "unfold_assertion_data_get"
  { lhss = [@{term "ghost_assertion_data_get k acc s"}, @{term "ghost_assertion_data_set k v upd"}]
  , proc = fn _ => fn ctxt => SOME o (unfold_assertion_data_get_set_conv ctxt) o Thm.term_of
  }
386
387\<close>
388
389ML \<open>
(* wrap_tac tac i: run tac on subgoal i in isolation and keep only its
   first result. Subgoal i is restricted (Goal.restrict) to a one-subgoal
   state, tac is applied to subgoal 1 of that state, and the first outcome
   is unrestricted back into the full proof state. Later results are
   discarded, so wrap_tac is deterministic.
   It fails (no results) when tac fails or when i does not name an
   existing subgoal. Goal.restrict raises THM for an out-of-range i; the
   bounds check turns that into ordinary tactic failure, which ORELSE
   fallbacks can handle. *)
fun wrap_tac tac i t =
  if i < 1 orelse i > Thm.nprems_of t then no_tac t
  else let
    val t' = Goal.restrict i 1 t
    val r = tac 1 t'
  in case Seq.pull r of NONE => Seq.empty
    | SOME (t'', _) => Seq.single (Goal.unrestrict i t'')
  end
396
(* EqSubst substitution (occurrence list [0], as with the plain subst
   method) in the conclusion or in an assumption, run through wrap_tac
   so it acts on subgoal i only and keeps just its first result. *)
fun eqsubst_wrap_tac ctxt thms = wrap_tac (EqSubst.eqsubst_tac ctxt [0] thms)
fun eqsubst_asm_wrap_tac ctxt thms = wrap_tac (EqSubst.eqsubst_asm_tac ctxt [0] thms)
(* Try substituting in an assumption first, then in the conclusion. *)
fun eqsubst_either_wrap_tac ctxt thms = (eqsubst_asm_wrap_tac ctxt thms
    ORELSE' eqsubst_wrap_tac ctxt thms)
401\<close>
402
403
404
405ML \<open>
406structure ProveSimplToGraphGoals = struct
407
(* Two subgoals are considered equal when their hypotheses and conclusions
   are alpha-equivalent and their bound parameters have the same types
   (parameter names are ignored). *)
fun goal_eq (g, g') = let
    fun same_hyps () = eq_list (op aconv)
        (Logic.strip_assums_hyp g, Logic.strip_assums_hyp g')
    fun same_concl () =
        Logic.strip_assums_concl g aconv Logic.strip_assums_concl g'
    fun same_param_types () =
        map snd (Logic.strip_params g) = map snd (Logic.strip_params g')
  in same_hyps () andalso same_concl () andalso same_param_types () end
412
(* tactic_check s tac: debugging wrapper around tac. Runs tac on subgoal
   i and requires at most one result (raising THM "nondeterministic"
   otherwise). For a single result it checks that only subgoal i changed:
   the goals before i and the goals after the replacement segment must be
   unchanged according to goal_eq; otherwise it raises THM "broke the
   rules!". s names the tactic in these messages. *)
fun tactic_check s tac i t =
  (case Seq.list_of (tac i t) of
    [] => Seq.empty
  | [t'] => let
      val orig_goals = Thm.prems_of t
      val new_goals = Thm.prems_of t'
      (* net number of subgoals added by tac *)
      val growth = length new_goals - length orig_goals
      fun prefix_kept () =
          eq_list goal_eq (take (i - 1) orig_goals, take (i - 1) new_goals)
      fun suffix_kept () =
          eq_list goal_eq (drop i orig_goals, drop (i + growth) new_goals)
    in
      if prefix_kept () andalso suffix_kept () then Seq.single t'
      else raise THM ("tactic " ^ s ^ " broke the rules!", i, [t, t'])
    end
  | _ => raise THM ("tactic " ^ s ^ " nondeterministic", i, [t]))
427
(* FIXME: shadows SimplExport *)
(* get_c_type_size ctxt T: size in bytes of the C object represented by
   the Isabelle type T:
   - arrays: element size times the length encoded in the index type;
   - 8/16/32/64-bit words: 1/2/4/8;
   - pointers: the size of machine_word;
   - signed words: the size of the unsigned word of the same length;
   - any other type constructor s: the numeral on the right-hand side of
     the theorem s_size from the context.
   Raises TYPE when no size can be determined. *)
fun get_c_type_size ctxt (Type (@{type_name array}, [elT, nT])) =
    get_c_type_size ctxt elT * Word_Lib.dest_binT nT
  | get_c_type_size _ @{typ word8} = 1
  | get_c_type_size _ @{typ word16} = 2
  | get_c_type_size _ @{typ word32} = 4
  | get_c_type_size _ @{typ word64} = 8
  | get_c_type_size ctxt (Type (@{type_name ptr}, [_])) =
        get_c_type_size ctxt @{typ machine_word}
  | get_c_type_size ctxt (Type (@{type_name word}, [Type (@{type_name signed}, [t])]))
    = get_c_type_size ctxt (Type (@{type_name word}, [t]))
  | get_c_type_size ctxt (T as Type (s, _)) = let
    val thm = Proof_Context.get_thm ctxt (s ^ "_size")
      handle ERROR _ => raise TYPE ("get_c_type_size: couldn't get size", [T], [])
  in (Thm.rhs_of thm |> Thm.term_of |> HOLogic.dest_number |> snd)
    handle TERM (s, ts) => raise TYPE ("get_c_type_size: " ^ s, [T], ts)
  end
  | get_c_type_size _ T = raise TYPE ("get_c_type_size:", [T], [])
446
(* The _def theorems of every name in the enumeration environment
   (enumenv) of the environment derived from csenv. *)
fun enum_simps csenv ctxt = let
    val Absyn.CE ecenv = ProgramAnalysis.cse2ecenv csenv
    val enum_names = map fst (Symtab.dest (#enumenv ecenv))
  in map (fn name => Proof_Context.get_thm ctxt (name ^ "_def")) enum_names end
453
(* Apply classical safe steps to a subgoal and, recursively, to every
   subgoal they produce. Each step must change the state and is committed
   to its first result. *)
fun safe_goal_tac ctxt = let
    fun step i = DETERM (CHANGED (safe_steps_tac ctxt i))
  in REPEAT_ALL_NEW step end
456
(* res_from_ctxt tac_name thm_name ctxt thm: fetch the fact thm_name from
   ctxt and resolve it against the first premise of thm (fact RS thm).
   Raises THM, prefixed with tac_name, if the fact is missing or does not
   resolve. *)
fun res_from_ctxt tac_name thm_name ctxt thm = let
    val missing = THM (tac_name ^ ": need thm " ^ thm_name, 1, [])
    val fact = Proof_Context.get_thm ctxt thm_name
      handle ERROR _ => raise missing
    fun unresolvable () =
      THM (tac_name ^ ": need thm to resolve: " ^ thm_name, 1, [fact, thm])
  in fact RS thm handle THM _ => raise unresolvable () end
464
(* Re-exported from SimplToGraphProof. Called throughout as
   except_tac ctxt msg i to report a proof step that could not be
   completed; its exact failure behaviour is defined there. *)
val except_tac = SimplToGraphProof.except_tac
466
(* warn_schem_tac msg ctxt tac i: emit a warning (tagged with msg) if
   subgoal i contains schematic term variables, then run tac i. *)
fun warn_schem_tac msg ctxt tac = SUBGOAL (fn (t, i) =>
  (if null (Term.add_var_names t []) then ()
   else warning ("schematic in goal: " ^ msg ^ ": "
     ^ Pretty.string_of (Syntax.pretty_term ctxt t));
   tac i))
472
(* prove_ptr_safe reason ctxt i: prove a ptr_safe / h_t_valid side goal at
   subgoal i, committing to the first result (DETERM). In order:
   1. warn if the goal contains schematic variables;
   2. substitute array_ptr_index_coerce / nat_uint_less_helper (in an
      assumption or the conclusion) as often as possible;
   3. simplify with the pointer-arithmetic-to-array_ptr_index rules and
      the signed comparison lemmas word_sle_msb_le / word_sless_msb_less;
   4. simplify with the ptr_safe / h_t_valid rules for fields and array
      elements;
   5. hand any subgoal still left to except_tac with a message naming
      reason. *)
fun prove_ptr_safe reason ctxt = DETERM o
    warn_schem_tac "prove_ptr_safe" ctxt
    (TRY o REPEAT_ALL_NEW (eqsubst_either_wrap_tac ctxt
                @{thms array_ptr_index_coerce nat_uint_less_helper}
            )
        THEN_ALL_NEW asm_full_simp_tac (ctxt addsimps
            @{thms ptr_safe_ptr_add_array_ptr_index
                   word_sle_msb_le word_sless_msb_less
                   nat_uint_less_helper})
        THEN_ALL_NEW asm_simp_tac (ctxt addsimps
            @{thms ptr_safe_field[unfolded typ_uinfo_t_def]
                   ptr_safe_Array_element unat_less_helper unat_def[symmetric]
                   ptr_safe_Array_element_0
                   h_t_valid_Array_element' h_t_valid_field
                   nat_uint_less_helper upcast_less_unat_less})
        THEN_ALL_NEW except_tac ctxt
            ("prove_ptr_safe: failed for " ^ reason)
    )
491
(* disjoint_h_val_globals_swap_insert with its global_acc_valid and
   globals_list_distinct premises discharged from the context. The
   remaining premises are htd_safe and ptr_safe. *)
fun get_disjoint_h_val_globals_swap ctxt = let
    val resolve_with = res_from_ctxt "prove_heap_update_id"
    val with_acc_valid = resolve_with "global_acc_valid" ctxt
        @{thm disjoint_h_val_globals_swap_insert}
  in resolve_with "globals_list_distinct" ctxt with_acc_valid end
496
(* prove_heap_update_id ctxt i: prove that subgoal i, presumably an
   equation of the heap_update_id_nonsense shape (writing back a value
   just read), holds. Commits to the first result (DETERM).
   The goal is resolved with heap_update_id_Array or heap_update_id, then
   simplified. Any goal that survives is expected to involve a globals
   swap and is closed in three steps:
   - the disjoint h_val / globals_swap rule, applied directly or after sym;
   - the htd_safe premise, by assumption;
   - the ptr_safe premise, by prove_ptr_safe. *)
fun prove_heap_update_id ctxt = DETERM o let
    val thm = get_disjoint_h_val_globals_swap ctxt
  in fn i => (resolve_tac ctxt @{thms heap_update_id_Array heap_update_id} i
        ORELSE except_tac ctxt "prove_heap_update_id: couldn't init" i)
    THEN (simp_tac ctxt
    THEN_ALL_NEW (* simp_tac will solve goal unless globals swap involved *)
    ((resolve0_tac [thm]
      ORELSE' (resolve0_tac [@{thm sym}] THEN' resolve0_tac [thm])
      ORELSE' except_tac ctxt "prove_heap_update_id: couldn't rtac")
    THEN' (assume_tac ctxt (* htd_safe assumption *)
      ORELSE' except_tac ctxt "prove_heap_update_id: couldn't atac")
    THEN' prove_ptr_safe "prove_heap_update" ctxt)) i
  end
510
(* The field_h_val_rewrites theorems registered in ctxt. If they are
   missing, raises THM naming add_field_h_val_rewrites, the setup step
   that registers them. *)
fun get_field_h_val_rewrites ctxt = let
    val missing = THM ("run add_field_h_val_rewrites on ctxt", 1, [])
  in Proof_Context.get_thms ctxt "field_h_val_rewrites"
    handle ERROR _ => raise missing
  end

(* The field_offset_rewrites theorems registered in ctxt. If they are
   missing, raises THM naming add_field_offset_rewrites. *)
fun get_field_offset_rewrites ctxt = let
    val missing = THM ("run add_field_offset_rewrites on ctxt", 1, [])
  in Proof_Context.get_thms ctxt "field_offset_rewrites"
    handle ERROR _ => raise missing
  end
520
(* Returns a triple of rule lists, all fetched from ctxt:
   1. globals_swap_rewrites;
   2. const_globals_rewrites_with_swap;
   3. pointer_inverse_safe_global_rules, followed by copies of those rules
      simplified with pointer_inverse_safe_sign / ptr_coerce.simps (the
      signed-word variants).
   Raises THM naming add_globals_swap_rewrites if a collection is
   missing. *)
fun get_globals_rewrites ctxt = let
    val thms = Proof_Context.get_thms ctxt
    val swap_rewrites = thms "globals_swap_rewrites"
    val const_rewrites = thms "const_globals_rewrites_with_swap"
    val pinv_rules = thms "pointer_inverse_safe_global_rules"
    val sign_ss = put_simpset HOL_basic_ss ctxt
        addsimps @{thms pointer_inverse_safe_sign ptr_coerce.simps}
    val pinv_unsigned = map (simplify sign_ss) pinv_rules
  in (swap_rewrites, const_rewrites, pinv_rules @ pinv_unsigned) end
  handle ERROR _ => raise THM
      ("run add_globals_swap_rewrites on ctxt", 1, [])
530
(* add_symbols t xs: prepend to xs every HOL string literal occurring as
   the first argument of an application headed by a free variable
   (presumably symbol-table lookups such as symtab ''name'' -- confirm).
   The argument of such an application is not searched any further,
   whether or not it is a string literal. *)
fun add_symbols (Free (_, _) $ s) xs = (case try HOLogic.dest_string s
        of SOME str => str :: xs | _ => xs)
  | add_symbols (f $ x) xs = add_symbols f (add_symbols x xs)
  | add_symbols (Abs (_, _, t)) xs = add_symbols t xs
  | add_symbols _ xs = xs

(* The symbols of t as a sorted, duplicate-free Ord_List. *)
fun get_symbols t = add_symbols t [] |> Ord_List.make fast_string_ord
538
(* The _def theorems of the constant globals (the left-hand-side constants
   of the const_globals_rewrites_with_swap rules) whose rule mentions a
   symbol that does not occur in goal, as determined by get_symbols. *)
fun get_expand_const_globals ctxt goal = let
    val goal_symbs = get_symbols goal
    val (_, const_rewrites, _) = get_globals_rewrites ctxt
    fun within_goal thm = Ord_List.subset fast_string_ord
        (get_symbols (Thm.concl_of thm), goal_symbs)
    fun lhs_const_name thm = thm |> Thm.concl_of |> HOLogic.dest_Trueprop
        |> HOLogic.dest_eq |> fst |> dest_Const |> fst
  in
    const_rewrites
    |> filter_out within_goal
    |> map lhs_const_name
    |> map (Proof_Context.get_thm ctxt o suffix "_def")
  end
549
(* normalise_mem_accs reason ctxt i: normalise the typed memory reads in
   subgoal i, committing to the first result (DETERM). In order:
   1. asm_full_simp with init_simps, which contains:
      - the array/field access lemmas;
      - the field h_val rewrites;
      - the globals-swap and const-globals rewrites from the context;
      - the disjoint h_val / globals-swap rule;
   2. warn if schematic variables remain;
   3. repeatedly (while something applies) substitute
      heap_access_Array_element' in an assumption or the conclusion, or
      insert a globals swap around an h_val (disjoint_h_val_tac, which
      discharges htd_safe by assumption), re-simplifying after each step;
   4. prove any h_t_valid / ptr_safe subgoals with prove_ptr_safe;
   5. finish with full_simp using the word h_val lemmas.
   reason only appears in warning and failure messages. *)
fun normalise_mem_accs reason ctxt = DETERM o let
    val gr = get_globals_rewrites ctxt
    val msg' = "normalise_mem_accs: " ^ reason
    fun msg str = msg' ^ ": " ^ str
    val init_simps = @{thms hrs_mem_update
                       heap_access_Array_element'
                       o_def fupdate_def
                       pointer_inverse_safe_sign
                       ptr_safe_ptr_add_array_ptr_index
                       unat_less_helper nat_uint_less_helper
            } @ get_field_h_val_rewrites ctxt
        @ #1 gr @ #2 gr
    val h_val = get_disjoint_h_val_globals_swap ctxt
    val disjoint_h_val_tac
    = (eqsubst_asm_wrap_tac ctxt [h_val] ORELSE' eqsubst_wrap_tac ctxt [h_val])
         THEN' (assume_tac ctxt ORELSE' except_tac ctxt (msg "couldn't atac"))
  in
    asm_full_simp_tac (ctxt addsimps init_simps addsimps [h_val])
    THEN_ALL_NEW warn_schem_tac msg' ctxt (K all_tac)
    THEN_ALL_NEW
        (TRY o REPEAT_ALL_NEW ((eqsubst_asm_wrap_tac ctxt
                    @{thms heap_access_Array_element'}
                ORELSE' eqsubst_wrap_tac ctxt
                    @{thms heap_access_Array_element'}
                ORELSE' disjoint_h_val_tac)
            THEN_ALL_NEW asm_full_simp_tac (ctxt addsimps init_simps)))
    THEN_ALL_NEW
        SUBGOAL (fn (t, i) => case
            Envir.beta_eta_contract (Logic.strip_assums_concl t)
          of @{term Trueprop} $ (Const (@{const_name h_t_valid}, _) $ _ $ _ $ _)
              => prove_ptr_safe msg' ctxt i
            | @{term Trueprop} $ (Const (@{const_name ptr_safe}, _) $ _ $ _)
              => prove_ptr_safe msg' ctxt i
            | _ => all_tac)
    THEN_ALL_NEW full_simp_tac (ctxt addsimps @{thms h_val_word_simps nat_uint_less_helper})
  end
586
(* The trivial theorem P ==> P for
     P = (heap_update ?p (h_val ?hp' ?p) (hrs_mem ?hrs) = hrs_mem ?hrs).
   Used as a conditional rewrite rule with eqsubst in
   prove_mem_equality_unchecked. It rewrites the write-back of a just-read
   value to the unchanged memory and defers the justification to the
   equation itself as a new premise. prove_heap_update_id then proves that
   premise. *)
val heap_update_id_nonsense
    = Thm.trivial (Thm.cterm_of @{context} (Proof_Context.read_term_pattern @{context}
        "Trueprop (heap_update ?p (h_val ?hp' ?p) (hrs_mem ?hrs) = hrs_mem ?hrs)"))
590
(* Simpset for the first phase of prove_mem_equality. It adds to ctxt the
   array update/access rules, hrs_mem_update and o_def, plus the field
   h_val rewrites registered in the context. *)
fun prove_mem_equality_init_simpset ctxt = let
    val field_rewrites = get_field_h_val_rewrites ctxt
    val base_rules =
      @{thms hrs_mem_update heap_update_Array_update heap_access_Array_element' o_def}
  in ctxt addsimps (base_rules @ field_rewrites) end
595
(* Simpset that unpacks typed heap writes and reads into byte-level
   operations: heap_update_def, to_bytes, the word heap lemmas, pointer
   arithmetic and field offsets, and the field rewrites from the context.
   Signed words are transferred to unsigned ones via
   drop_sign_isomorphism. One_nat_def is reversed so that Suc 0 is
   rewritten to 1. Raises THM naming add_field_to_bytes_rewrites if the
   field_to_bytes_rewrites collection is missing. *)
fun prove_mem_equality_unpack_simpset ctxt =
    ctxt addsimps
      @{thms heap_update_def to_bytes_array
             heap_update_list_append
             h_val_word_simps
             heap_update_word_simps
             heap_list_update_word_simps
             to_bytes_sword
             drop_sign_isomorphism
             field_lvalue_offset_eq ptr_add_def
             array_ptr_index_def
             take_heap_list_min drop_heap_list_general
             ucast_nat_def of_int_sint_scast of_int_uint_ucast
             heap_access_Array_element field_lvalue_def}
        @ Proof_Context.get_thms ctxt "field_to_bytes_rewrites"
        @ (get_field_h_val_rewrites ctxt)
      addsimprocs [Word_Bitwise_Tac.expand_upt_simproc]
      delsimps @{thms One_nat_def}
      addsimps @{thms One_nat_def[symmetric]}
    handle ERROR _ => raise THM
      ("prove_mem_equality: run add_field_to_bytes_rewrites on ctxt", 1, [])
617
(* prove_mem_equality_unchecked ctxt i: attempt an equation between heap
   memories (subgoal i) by progressively unfolding typed memory operations:
   - simplify with the init simpset, then substitute
     heap_access_Array_element' / heap_update_Array_update as often as
     possible, re-simplifying after each step;
   - remove heap_updates that write back a just-read value
     (heap_update_id_nonsense + prove_heap_update_id);
   - refuse, via except_tac, to unpack a heap_update whose value type is
     larger than 256 bytes;
   - normalise typed reads with normalise_mem_accs;
   - unpack to byte level with the unpack simpset;
   - finish with AC normalisation of + and *.
   Unlike prove_mem_equality, this does not check that all memory updates
   were eliminated. *)
fun prove_mem_equality_unchecked ctxt = let
    fun heap_update_id_proofs ctxt =
        REPEAT_ALL_NEW (eqsubst_wrap_tac ctxt [heap_update_id_nonsense]
            THEN' prove_heap_update_id ctxt)

  in
    (trace_tac ctxt "prove_mem_equality: initial state")
    THEN_ALL_NEW (simp_tac (prove_mem_equality_init_simpset ctxt))
    THEN_ALL_NEW warn_schem_tac "prove_mem_equality: before subst" ctxt (K all_tac)
    THEN_ALL_NEW (TRY o REPEAT_ALL_NEW ((eqsubst_wrap_tac ctxt
            @{thms heap_access_Array_element' heap_update_Array_update})
        THEN_ALL_NEW simp_tac (prove_mem_equality_init_simpset ctxt)))
    THEN_ALL_NEW TRY o heap_update_id_proofs ctxt
    (* the value type of heap_update is the domain of its range type *)
    THEN_ALL_NEW SUBGOAL (fn (t, i) => if
        exists_Const (fn (s, T) => s = @{const_name heap_update}
            andalso get_c_type_size ctxt (domain_type (range_type T)) > 256
        ) t
        then except_tac ctxt "prove_mem_equality: unfolding large heap_update" i
        else all_tac)
    (* need to normalise_mem_accs first as it operates on typed pointer ops
       and won't function after we unpack them *)
    THEN_ALL_NEW normalise_mem_accs "prove_mem_equality" ctxt
    THEN_ALL_NEW asm_lr_simp_tac (prove_mem_equality_unpack_simpset ctxt)
    THEN_ALL_NEW simp_tac (ctxt addsimps @{thms add_ac mult_ac add_mult_comms ucast_id})
  end
643
(* prove_mem_equality ctxt i: run prove_mem_equality_unchecked, then
   report (via except_tac) any resulting subgoal that still mentions a
   memory-update constant (store_word8/32/64, heap_update,
   heap_update_list). Commits to the first result (DETERM). *)
fun prove_mem_equality ctxt = let
    val mem_upd_consts =
      [@{const_name store_word8}, @{const_name store_word32},
       @{const_name store_word64}, @{const_name heap_update},
       @{const_name heap_update_list}]
    fun is_mem_upd (s, _) = member (op =) mem_upd_consts s
    val check_no_mem_upds = SUBGOAL (fn (t, i) =>
      if exists_Const is_mem_upd t
      then except_tac ctxt "prove_mem_equality: remaining mem upds" i
      else all_tac)
  in DETERM o (prove_mem_equality_unchecked ctxt THEN_ALL_NEW check_no_mem_upds) end
655
(* Simplify with the globals_swap rewrites, then prove the resulting
   memory equation with prove_mem_equality. *)
fun prove_global_equality ctxt = let
    val (swap_rewrites, _, _) = get_globals_rewrites ctxt
  in simp_tac (ctxt addsimps swap_rewrites) THEN' prove_mem_equality ctxt end
659
(* clean_heap_upd_swap ctxt i: for a memory equation whose left side is a
   heap update hidden under a globals swap, split the goal with trans and
   rewrite the left side with disjoint_heap_update_globals_swap_rearranged.
   That rule's htd_safe premise is closed by assumption and its ptr_safe
   premise by prove_ptr_safe. The remaining subgoal i then has the
   heap_update outermost. Commits to the first result (DETERM).
   The reason passed to prove_ptr_safe now names this tactic correctly
   (it previously read "clean_upd_upd_swap"). *)
fun clean_heap_upd_swap ctxt = DETERM o let
    val thm = @{thm disjoint_heap_update_globals_swap_rearranged}
    val thm = res_from_ctxt "clean_heap_upd_swap" "global_acc_valid" ctxt thm
    val thm = res_from_ctxt "clean_heap_upd_swap" "globals_list_distinct" ctxt thm
  in fn i => resolve_tac ctxt [@{thm trans}]  i
    THEN (resolve_tac ctxt [thm] i
      ORELSE except_tac ctxt "clean_heap_upd_swap: couldn't rtac" i)
    THEN (assume_tac ctxt i (* htd_safe assumption *)
      ORELSE except_tac ctxt "clean_heap_upd_swap: couldn't atac" i)
    THEN prove_ptr_safe "clean_heap_upd_swap" ctxt i
  end
671
(* clean_htd_upd_swap ctxt i: simplify with
   globals_swap_hrs_htd_update[symmetric], instantiated with the
   global_acc_valid and globals_list_valid facts from the context. This
   presumably moves an hrs_htd_update out from under the globals swap.
   Any subgoal left unsolved is reported via except_tac. *)
fun clean_htd_upd_swap ctxt = let
    val thm = @{thm globals_swap_hrs_htd_update[symmetric]}
    val thm = res_from_ctxt "clean_htd_upd_swap" "global_acc_valid" ctxt thm
    val thm = res_from_ctxt "clean_htd_upd_swap" "globals_list_valid" ctxt thm
  in simp_tac (ctxt addsimps [thm])
    THEN_ALL_NEW (except_tac ctxt "clean_htd_upd_swap: not finished")
  end
679
(* Classify one side of a heap_mem equation:
   - "HeapUpd" for a direct heap_update application;
   - for hrs_mem v, where v must mention globals_swap (else TERM):
     "HeapUpdWithSwap" if v mentions heap_update, otherwise
     "HTDUpdateWithSwap" if v mentions hrs_htd_update, otherwise
     "GlobalUpd";
   - any other term raises TERM. *)
fun heap_upd_kind (Const (@{const_name heap_update}, _) $ _ $ _ $ _)
    = "HeapUpd"
  | heap_upd_kind (Const (@{const_name hrs_mem}, _) $ v) = let
    fun mentions c = exists_Const (fn (s, _) => s = c) v
  in
    if not (mentions @{const_name globals_swap})
    then raise TERM ("heap_upd_kind: hrs_mem but no globals_swap", [v])
    else if mentions @{const_name heap_update} then "HeapUpdWithSwap"
    else if mentions @{const_name hrs_htd_update} then "HTDUpdateWithSwap"
    else "GlobalUpd"
  end
  | heap_upd_kind t = raise TERM ("heap_upd_kind: unknown", [t])
692
(* decompose_mem_goals_init post trace ctxt i: dispatch on the shape of
   subgoal i:
   - const_globals_in_memory: eliminate with
     const_globals_in_memory_heap_updateE, then prove htd_safe by
     assumption and ptr_safe with prove_ptr_safe;
   - pglobal_valid: simplify with its definition and the
     pointer_inverse_safe global rules;
   - an equation between heap memories: classify both sides with
     heap_upd_kind, strip globals swaps / htd updates as needed, and hand
     the rest to post (an unsupported combination raises TERM);
   - anything else is left unchanged.
   When trace is set, the classification pair is printed. The pair is
   computed once and reused for the dispatch and the error message. *)
fun decompose_mem_goals_init post trace ctxt = warn_schem_tac "decompose_mem_goals" ctxt
  (trace_tac ctxt "decompose_mem_goals: init" THEN' SUBGOAL (fn (t, i) =>
  (case Envir.beta_eta_contract (Logic.strip_assums_concl t) of
    @{term Trueprop} $ (Const (@{const_name const_globals_in_memory}, _) $ _ $ _ $ _)
        => let val thm = res_from_ctxt "decompose_mem_goals"
                        "globals_list_distinct" ctxt
                        @{thm const_globals_in_memory_heap_updateE}
        in (eresolve_tac ctxt [thm] THEN' assume_tac ctxt THEN' prove_ptr_safe "const_globals" ctxt)
            ORELSE' except_tac ctxt "decompose_mem_goals: const globals"
        end
    | @{term Trueprop} $ (Const (@{const_name pglobal_valid}, _) $ _ $ _ $ _)
        => asm_full_simp_tac (ctxt addsimps @{thms pglobal_valid_def}
              addsimps #3 (get_globals_rewrites ctxt))
    | @{term Trueprop} $ (@{term "(=) :: heap_mem \<Rightarrow> _"} $ x $ y) => let
        val query as (kind_x, kind_y) = (heap_upd_kind x, heap_upd_kind y)
        val _ = if trace then writeln ("decompose_mem_goals: " ^ @{make_string} query)
            else ()
      in case query of
          ("HeapUpd", "HeapUpd") => post ctxt
        | ("HeapUpdWithSwap", "HeapUpd")
            => clean_heap_upd_swap ctxt THEN' post ctxt
        | ("HeapUpd", "HeapUpdWithSwap") =>
            resolve_tac ctxt [@{thm sym}]
              THEN' clean_heap_upd_swap ctxt THEN' post ctxt
        | ("HeapUpd", "GlobalUpd") =>
            simp_tac (ctxt addsimps (#1 (get_globals_rewrites ctxt)))
              THEN_ALL_NEW (post ctxt)
        | ("GlobalUpd", "HeapUpd") =>
            simp_tac (ctxt addsimps (#1 (get_globals_rewrites ctxt)))
              THEN_ALL_NEW (post ctxt)
        | ("HTDUpdateWithSwap", _)
            => clean_htd_upd_swap ctxt
        | (_, "HTDUpdateWithSwap")
            => resolve_tac ctxt [@{thm sym}] THEN' clean_htd_upd_swap ctxt
        | _ => raise TERM ("decompose_mem_goals: mixed up "
            ^ kind_x ^ "," ^ kind_y, [x, y])
      end THEN_ALL_NEW warn_schem_tac "decompose_mem_goals: after"
        ctxt (K all_tac)
    | _ => K all_tac) i)
    THEN' trace_tac ctxt "decompose_mem_goals: end")
733
(* Standard memory-goal decomposition: prove_mem_equality is used as the
   `post` tactic for the heap_mem equalities decompose_mem_goals_init hands on. *)
fun decompose_mem_goals trace ctxt =
  decompose_mem_goals_init prove_mem_equality trace ctxt
735
(* Tactic for goals to which unat_mono_intro applies. The exact goal shape is
   fixed by that lemma -- NOTE(review): not visible here. It runs in three
   stages:
   1. Resolve with unat_mono_intro.
   2. Repeatedly resolve with unat_mono_thms and close the resulting subgoals
      with order_refl. If this stage fails, it is reported through
      except_tac ("couldn't get started"). Subgoals left after order_refl are
      reported as "escaped order_refl".
   3. Simplify with the signed-to-unsigned comparison rewrites, then with
      the unat-based comparison rewrites. Anything still open is reported
      as "unsolved". *)
fun unat_mono_tac ctxt = resolve_tac ctxt @{thms unat_mono_intro}
    THEN' ((((TRY o REPEAT_ALL_NEW (resolve_tac ctxt @{thms unat_mono_thms}))
                THEN_ALL_NEW resolve_tac ctxt [@{thm order_refl}])
            THEN_ALL_NEW except_tac ctxt "unat_mono_tac: escaped order_refl")
        ORELSE' except_tac ctxt "unat_mono_tac: couldn't get started")
    (* turn signed word comparisons into unsigned ones first ... *)
    THEN' (asm_full_simp_tac (ctxt addsimps @{thms
            word_sless_to_less word_sle_to_le
        })
        (* ... then express word comparisons via unat *)
        THEN_ALL_NEW asm_full_simp_tac (ctxt addsimps @{thms
            word_less_nat_alt word_le_nat_alt
            unat_ucast_if_up
        })
        THEN_ALL_NEW except_tac ctxt "unat_mono_tac: unsolved")
749
(* Expand ptr_add_assertion facts in subgoals that mention parray_valid.
   Such a subgoal is:
   1. full_simp_tac'd with the ptr_add_assertion / parray_valid unfolding
      rules (One_nat_def removed from the simpset);
   2. has ptr_add_assertion premises destructed with the uint/sint
      destruction rules, as often as they apply;
   3. is split with safe_goal_tac.
   Any resulting subgoal that still mentions ptr_add_assertion or
   ptr_add_assertion' is reported via except_tac. Subgoals without
   parray_valid are left unchanged. *)
fun dest_ptr_add_assertion ctxt = SUBGOAL (fn (t, i) =>
    if Term.exists_Const (fn (s, _) => s = @{const_name parray_valid}) t
        then (full_simp_tac (ctxt addsimps @{thms ptr_add_assertion'
            typ_uinfo_t_diff_from_typ_name parray_valid_def
            ptr_add_assertion_unfold_numeral} delsimps @{thms One_nat_def})
          THEN_ALL_NEW TRY o REPEAT_ALL_NEW (dresolve_tac ctxt
            @{thms ptr_add_assertion_uintD[rule_format]
                   ptr_add_assertion_sintD[rule_format]})
          THEN_ALL_NEW TRY o safe_goal_tac ctxt
          (* anything still talking about ptr_add_assertion was not expanded *)
          THEN_ALL_NEW SUBGOAL (fn (t, i) =>
            if Term.exists_Const (fn (s, _) => s = @{const_name ptr_add_assertion}
                orelse s = @{const_name ptr_add_assertion'}) t
            then except_tac ctxt "dest_ptr_add_assertion" i
            else all_tac)
        ) i
    else all_tac)
766
(* Wrap the tactic of a (description lines, tactic) step with tactic_check,
   passing the first description line as tactic_check's first argument.
   The description list must be non-empty (hd). *)
fun tactic_check' (descr, tac) =
  let val label = hd descr
  in (descr, tactic_check label tac) end
768
(* The SimplToGraph refinement proof pipeline, as an ordered list of
   (description lines, tactic) pairs. graph_refine_proof_full_tac,
   graph_refine_proof_full_goal_tac, debug_tac and debug_step_tac run these
   steps in list order. The description lines appear in their failure and
   warning messages. *)
fun graph_refine_proof_tacs csenv ctxt = let
    (* FIXME: fix shiftr_no and sshiftr_no in Word *)
    (* All steps share this context: the listed shift simp rules are removed,
       if_split is removed from the splitter, and if_weak_cong is removed from
       the congruence rules. *)
    val ctxt = ctxt delsimps @{thms shiftr_no sshiftr_no shiftl_numeral}
        |> Splitter.del_split @{thm if_split}
        |> Simplifier.del_cong @{thm if_weak_cong}

  in [
        (["step 1: normalise some word arithmetic. this needs",
            "to be done before any general simplification.",
            "also unfold some things that may be in assumptions",
            "and should be unfolded"],
        full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms
              guard_arith_simps
              mex_def meq_def}
              addsimps [Proof_Context.get_thm ctxt "simpl_invariant_def"])),
        (["step 2: normalise a lot of things that occur in",
            "simpl->graph that are extraneous"],
        (* SUBGOAL is needed because the constant-global unfolding below
           depends on the goal term t *)
        SUBGOAL (fn (t, i) =>
            asm_full_simp_tac (ctxt addsimps @{thms eq_impl_def
                    var_word32_def var_word8_def var_mem_def
                    var_word64_def var_word16_def var_ghoststate_def
                    var_htd_def var_acc_var_upd
                    var_ms_def init_vars_def
                    return_vars_def upd_vars_def save_vals_def
                    asm_args_to_list_def asm_rets_to_list_def
                    mem_upd_def mem_acc_def hrs_mem_update
                    hrs_htd_update
                    fupdate_def
                    hrs_mem_update_triv

                    (* this includes wrappers for word arithmetic
                       and other simpl actions*)
                    bvlshr_def bvashr_def bvshl_def bv_clz_def

                    (* and some stupidity *)
                    Collect_const_mem
                    }
                (* we should also unfold enumerations, since the graph
                   representation does this, and we need to normalise
                   word arithmetic the same way on both sides. *)
                addsimps (enum_simps csenv ctxt)
                (* unfold constant globals unless we can see their symbols
                   somewhere else in the goal *)
                addsimps (get_expand_const_globals ctxt t)
                (* and fold up expanded array accesses, and clean up assertion_data get/set *)
                addsimprocs [fold_of_nat_eq_Ifs_simproc, unfold_assertion_data_get_set]
            ) i)),
        (["step 3: split into goals with safe steps",
            "also derive ptr_safe assumptions from h_t_valid",
            "and adjust ptr_add_assertion facts",
            "also work on some asm_semantics problems"],
        trace_tac ctxt "step 3: init" THEN' (TRY o safe_goal_tac ctxt)
            THEN_ALL_NEW (TRY o DETERM o resolve_tac ctxt @{thms TrueI})
            THEN_ALL_NEW warn_schem_tac "step 3" ctxt (K all_tac)
            THEN_ALL_NEW (TRY o DETERM
                o REPEAT_ALL_NEW (dresolve_tac ctxt [@{thm h_t_valid_orig_and_ptr_safe}]))
            THEN_ALL_NEW (TRY o DETERM o (eresolve_tac ctxt [@{thm asm_semantics_protects_globs_revD[rule_format]}]
                THEN_ALL_NEW asm_full_simp_tac ctxt))
            THEN_ALL_NEW (TRY o safe_goal_tac ctxt)),
        (["step 4: split up memory write problems",
          "and expand ptr_add_assertion if needed."],
        trace_tac ctxt "step 4: init" THEN' decompose_mem_goals false ctxt
          THEN_ALL_NEW dest_ptr_add_assertion ctxt),
        (["step 5: normalise memory reads"],
        normalise_mem_accs "step 5" ctxt),
        (["step 6: explicitly apply some inequalities"],
        TRY o DETERM o REPEAT_ALL_NEW
            (eqsubst_either_wrap_tac ctxt @{thms machine_word_truncate_noop})),
        (["step 7: try to simplify out all remaining word logic"],
        (* ptr_val_inj is used only in its symmetric orientation here *)
        asm_full_simp_tac (ctxt addsimps @{thms
                        pvalid_def pweak_valid_def palign_valid_def
                        field_lvalue_offset_eq array_ptr_index_def ptr_add_def
                        mask_def unat_less_helper
                        word_sle_def[THEN iffD2] word_sless_alt[THEN iffD2]
                        drop_sign_isomorphism max_word_minus
                        ptr_equalities_to_ptr_val
                        word_neq_0_conv_neg_conv
                        ucast_nat_def of_int_sint_scast of_int_uint_ucast
                        unat_ucast_upcasts
                        ptr_val_inj[symmetric]
                        fold_all_htd_updates
                        array_assertion_shrink_right
                        sdiv_word_def sdiv_int_def
                        signed_ge_zero_scast_eq_ucast
                        unatSuc[OF less_is_non_zero_p1'] unatSuc2[OF less_is_non_zero_p1]
                        less_shift_targeted_cast_convs
                } delsimps @{thms ptr_val_inj})),
        (["step 8: try rewriting ring equalitites",
            "this must be done after general simplification",
            "because of a bug in the simpset for 2 ^ n",
            "(unfolded in Suc notation if addition is commuted)"],
        asm_full_simp_tac (ctxt addsimps @{thms field_simps})),
        (["step 9: attack unat less-than properties explicitly"],
        TRY o unat_mono_tac ctxt)

    ]

  end
867
(* Run every pipeline step, in order, on all subgoals. If a step fails on a
   subgoal, except_tac reports it with the step's description lines
   (newline-separated) prefixed by "FAILED: ". *)
fun graph_refine_proof_full_tac csenv ctxt = let
    fun run_step (descr, step) =
      ALLGOALS (step ORELSE' except_tac ctxt ("FAILED: " ^ space_implode "\n" descr))
  in
    graph_refine_proof_tacs csenv ctxt |> map run_step |> EVERY
  end
872
(* Apply all pipeline steps to subgoal i, chained with THEN_ALL_NEW so each
   step runs on every subgoal produced by the previous one. Only the first
   resulting state is kept. The result is the empty sequence if there is no
   result, or if an exception is raised while pulling it. *)
fun graph_refine_proof_full_goal_tac csenv ctxt i t = let
    val chained = graph_refine_proof_tacs csenv ctxt
      |> map snd
      |> foldr1 (op THEN_ALL_NEW)
  in
    case try Seq.hd (chained i t) of
      SOME t' => Seq.single t'
    | NONE => Seq.empty
  end
877
(* Debugging variant of the pipeline on subgoal i. The steps of
   graph_refine_proof_tacs run in order, keeping only the first result of
   each. If a step yields no result or raises (caught by `try`), a warning
   naming that step is printed and the current state is returned unchanged;
   later steps are then not attempted. *)
fun debug_tac csenv ctxt = let
    val tacs = graph_refine_proof_tacs csenv ctxt
    (* (fn _ => fn _ => all_tac t') replays the already-computed state t',
       so THEN_ALL_NEW applies the remaining steps to every subgoal this
       step produced from subgoal i. *)
    fun wrap_tacs [] _ t = all_tac t
      | wrap_tacs ((nms, tac) :: tacs) i t = case try ((tac) i #> Seq.hd) t
        of NONE => (warning ("step failed: " ^ commas nms); all_tac t)
         | SOME t' => ((fn _ => fn _ => all_tac t') THEN_ALL_NEW wrap_tacs tacs) i t
  in wrap_tacs tacs end
885
(* Run only pipeline step number `step` (1-based) on subgoal i, keeping its
   first result. If it yields no result or raises, warn with the step's
   description and return the state unchanged. *)
fun debug_step_tac csenv ctxt step = let
    val (descr, step_tac) = nth (graph_refine_proof_tacs csenv ctxt) (step - 1)
  in
    fn i => fn st =>
      (case try (step_tac i #> Seq.hd) st of
        SOME st' => all_tac st'
      | NONE => (warning ("step failed: " ^ commas descr); all_tac st))
  end
892
(* Build the SimplToGraph refinement goal for function nm and discharge it
   with the full proof pipeline, returning the first resulting theorem.
   Raises THM if any subgoal remains. TERM and THM exceptions are re-raised
   with "simpl_to_graph_thm: <nm>: " prefixed to their message. *)
fun simpl_to_graph_thm funs csenv ctxt nm =
  let
    val hints = SimplToGraphProof.mk_hints funs ctxt nm
    val goal = SimplToGraphProof.simpl_to_graph_upto_subgoals funs hints nm ctxt
    val proved = Seq.hd (graph_refine_proof_full_tac csenv ctxt goal)
    (* FIXME: make the hidden assumptions of the thm appear again *)
  in
    if Thm.nprems_of proved <> 0
    then raise THM ("simpl_to_graph_thm: unsolved subgoals", 1, [proved])
    else proved
  end
  handle
    TERM (msg, ts) => raise TERM ("simpl_to_graph_thm: " ^ nm ^ ": " ^ msg, ts)
  | THM (msg, k, ths) => raise THM ("simpl_to_graph_thm: " ^ nm ^ ": " ^ msg, k, ths)
907
(* Attempt the refinement proof for nm goal-by-goal without failing hard.
   Returns a status message paired with the resulting theorem; the theorem
   may still carry unsolved subgoals. Functions whose Symtab entry has NONE
   as third component are reported as skipped, paired with TrueI. TERM
   exceptions from the proof are re-raised with a
   "test_graph_refine_proof: <nm>: " prefix. *)
fun test_graph_refine_proof funs csenv ctxt nm =
  (case Symtab.lookup funs nm of
    SOME (_, _, NONE) => ("skipped " ^ nm, @{thm TrueI})
  | _ =>
      let
        val hints = SimplToGraphProof.mk_hints funs ctxt nm
        val goal = SimplToGraphProof.simpl_to_graph_upto_subgoals funs hints nm ctxt
        val result = Seq.hd (ALLGOALS (graph_refine_proof_full_goal_tac csenv ctxt) goal)
        val status =
          (case Thm.nprems_of result of
            0 => "success on "
          | failed => string_of_int failed ^ " failed goals: ")
      in (status ^ nm, result) end
      handle TERM (msg, ts) =>
        raise TERM ("test_graph_refine_proof: " ^ nm ^ ": " ^ msg, ts))
919
\<comment>\<open>
  Utility for configuring SimplToGraphProof with debugging features:
  which functions to test, and an optional per-proof time limit.
\<close>
type debug_config = {
  \<comment>\<open>
    Functions with these names won't be tested. In filter_fns, `skips` is
    applied after `only`, so a name listed in both is skipped.
  \<close>
  skips: string list,
  \<comment>\<open> If non-empty, *only* functions with these names will be tested. \<close>
  only: string list,

  \<comment>\<open>
    Timeout for proofs. Any individual proof that takes longer
    than this will be aborted and logged. NONE means no time limit.
  \<close>
  timeout: Time.time option
};
935
\<comment>\<open>
  A debug configuration together with mutable logs of test outcomes.
  Logs are extended by consing onto the front (see insert), so the most
  recent entry comes first.
\<close>
type debug = {
  config: debug_config,

  \<comment>\<open>
    Logs the names of functions when they pass or fail tests, or timeout,
    or are skipped because they don't have a definition.
  \<close>
  successes: (string list) Unsynchronized.ref,
  failures: (string list) Unsynchronized.ref,
  timeouts: (string list) Unsynchronized.ref,
  new_skips: (string list) Unsynchronized.ref
};
948
(* Fresh debug state for the given configuration, with all logs empty.
   Each log gets its own ref. *)
fun new_debug (config: debug_config): debug =
  let
    fun empty_log () = Unsynchronized.ref []
  in
    {config = config,
     successes = empty_log (),
     failures = empty_log (),
     timeouts = empty_log (),
     new_skips = empty_log ()}
  end
956
(* Debug state that filters nothing and imposes no proof timeout. *)
fun no_debug () : debug =
  new_debug {only = [], skips = [], timeout = NONE};
958
(* Lock serialising updates to the debug logs. test_all_graph_refine_proofs_parallel
   calls insert from concurrent Par_List.map tasks, and an unguarded
   read-modify-write on an Unsynchronized.ref can lose entries. *)
val debug_log_lock = Synchronized.var "ProveGraphRefine.debug_log" ();

(* Prepend x to the log selected by `field` (e.g. #failures). *)
fun insert (dbg: debug) field x =
  Synchronized.change debug_log_lock
    (fn () => change (field dbg) (curry (op ::) x))
960
(* Restrict a list of function names by the debug configuration. Names are
   kept only if they are in #only (when that list is non-empty); then any
   name listed in #skips is dropped. *)
fun filter_fns (dbg: debug) = let
    val {only, skips, ...} = #config dbg
    val restrict_to_only =
        if null only then I else filter (member (op =) only)
    val drop_skipped =
        if null skips then I else filter_out (member (op =) skips)
  in restrict_to_only #> drop_skipped end
964
(* True iff the log selected by `field` (e.g. #failures) is non-empty. *)
fun has (dbg: debug) field =
  (case ! (field dbg) of
    [] => false
  | _ :: _ => true)
966
(* Put sep between consecutive elements: interleave s [a, b, c] = [a, s, b, s, c].
   Lists with fewer than two elements are returned unchanged. *)
fun interleave sep (x :: (rest as _ :: _)) = x :: sep :: interleave sep rest
  | interleave _ xs = xs;
970
\<comment>\<open>
  Render names as an ML string-list literal, one element per line, for
  copy-pasting into debug lists. Names are wrapped in double quotes without
  escaping. An empty list renders as "(none)".
\<close>
fun render_ML_string_list [] = "(none)"
  | render_ML_string_list xs =
      "[\n" ^ String.concatWith ",\n" (map (fn x => "\"" ^ x ^ "\"") xs) ^ "\n]";
982
(* Print msg, then the selected debug log rendered as an ML string list. *)
fun print (dbg: debug) msg field =
  let
    val entries = ! (field dbg)
  in
    writeln msg;
    writeln (render_ML_string_list entries)
  end
988
(* Wrap f with Timeout.apply when the configuration sets a time limit;
   otherwise return f unchanged. *)
fun timeout (dbg: debug) f =
  (case #timeout (#config dbg) of
    NONE => f
  | SOME limit => Timeout.apply limit f);
993
(* Define the graph function for nm in the context, then prove its
   refinement, recording the outcome in dbg.
   - Functions whose Symtab entry has NONE as third component are logged in
     #new_skips and reported as "skipped".
   - A successful proof is logged in #successes, and a message with the
     elapsed time is returned.
   - TERM/THM failures are logged in #failures and re-raised with a
     "failure for <nm>: " prefix. Timeouts are logged in #timeouts and
     re-raised.
   Exceptions raised by define_graph_fun_short itself are not logged. *)
fun test_graph_refine_proof_with_def funs csenv ctxt dbg nm =
  (case Symtab.lookup funs nm of
    SOME (_, _, NONE) => (insert dbg #new_skips nm; "skipped " ^ nm)
  | _ =>
      let
        val graph_ctxt = define_graph_fun_short funs nm ctxt
        fun prove name =
            (simpl_to_graph_thm funs csenv graph_ctxt name; insert dbg #successes name)
        fun prove_logged name =
            timeout dbg prove name
            handle
              TERM (msg, ts) =>
                (insert dbg #failures name;
                 raise TERM ("failure for " ^ name ^ ": " ^ msg, ts))
            | THM (msg, idx, ths) =>
                (insert dbg #failures name;
                 raise THM ("failure for " ^ name ^ ": " ^ msg, idx, ths))
            | Timeout.TIMEOUT t =>
                (insert dbg #timeouts name; raise Timeout.TIMEOUT t)
        val (elapsed, _) = Timing.timing prove_logged nm
      in "success on " ^ nm ^ "  [" ^ Timing.message elapsed ^ "]" end)
1012
(* Test functions sequentially in Symtab.keys order, starting after nm. When
   nm is NONE, or is not a key, all functions are tested. The list is
   restricted by the debug configuration (filter_fns), applied after the
   starting point is located, so `skips`/`only` are honoured here as in the
   parallel driver. The first TERM/TYPE/THM exception aborts with an ERROR
   that names the function and keeps the original exception message. *)
fun test_all_graph_refine_proofs_after funs csenv ctxt dbg nm = let
    val ss = Symtab.keys funs
    val n = case nm of NONE => ~1 | SOME nm' => find_index (fn s => s = nm') ss
    val ss = (if n = ~1 then ss else drop (n + 1) ss) |> filter_fns dbg
    fun err s msg = error ("ERROR for: " ^ s ^ "\n" ^ msg)
    fun test s = (writeln ("testing: " ^ s);
        writeln (test_graph_refine_proof_with_def funs csenv ctxt dbg s))
      handle TERM (msg, _) => err s msg
           | TYPE (msg, _, _) => err s msg
           | THM (msg, _, _) => err s msg
    val _ = List.app test ss
  in "success" end
1022
(* Run test_graph_refine_proof_with_def, in parallel (Par_List.map), on
   every function name accepted by filter_fns. TERM/THM failures and
   timeouts of individual functions become warnings. Afterwards, raises
   ERROR (via error) if dbg's #failures or #timeouts log is non-empty;
   otherwise returns a success message with the total elapsed time. *)
fun test_all_graph_refine_proofs_parallel funs csenv ctxt dbg = let
    val names = filter_fns dbg (Symtab.keys funs)
    fun test_and_log nm =
        writeln (test_graph_refine_proof_with_def funs csenv ctxt dbg nm)
        handle
          TERM (msg, _) => warning msg
        | THM (msg, _, _) => warning msg
        | Timeout.TIMEOUT _ => warning ("Timeout for " ^ nm)
    val (elapsed, _) = Timing.timing (Par_List.map test_and_log) names
    val time_msg = "[" ^ Timing.message elapsed ^ "]"
    val problems =
        (if has dbg #failures
         then ["Failures! Check the `#failures` field of the debug parameter.\n"]
         else [])
        @ (if has dbg #timeouts
           then ["Timeouts! Check the `#timeouts` field of the debug parameter.\n"]
           else [])
  in
    if null problems
    then "success! " ^ time_msg
    else error (String.concat problems ^ time_msg)
  end
1050
1051end
1052\<close>
1053
1054end
1055