1(*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 *)
6
7theory SplitRule
8imports Main
9begin
10
11ML \<open>
12
(* Render a term as a plain string. Pretty-printing always happens in
   @{context}, i.e. the context of this theory at ML compile time, not in
   any caller-supplied proof context. *)
fun str_of_term t =
  t |> Syntax.pretty_term @{context}
    |> Pretty.string_of
14
structure SplitSimps = struct

(* Fold a non-empty list of theorems [A1, A2, ..., An] into the single
   theorem A1 & (A2 & (... & An)) by repeated conjI.
   NOTE(review): foldr1 presumably raises List.Empty on [];
   better_split should never pass [] because TRY and REPEAT always
   leave at least one result. *)
val conjunct_rules = foldr1 (fn (a, b) => [a, b] MRS conjI);

(* Shape test for the right-hand side of a (frozen) split rule.
   t is expected to be a Trueprop-wrapped proposition. Returns true iff
   every conjunct (a lone proposition counts as a single conjunct), after
   stripping leading HOL.All binders, is an implication whose premise is
   an equation with a Free on its left, i.e. "v = pattern --> ...".
   Any TERM raised while destructing t means the shape is wrong. *)
fun was_split t = let
    val is_free_eq_imp = is_Free o fst o HOLogic.dest_eq
              o fst o HOLogic.dest_imp;
    val get_conjs = HOLogic.dest_conj o HOLogic.dest_Trueprop;
    (* Stripping the Abs leaves references to the bound variable as loose
       Bounds; is_Free is false for those, so only a genuine Free lhs passes. *)
    fun dest_alls (Const ("HOL.All", _) $ Abs (_, _, t)) = dest_alls t
      | dest_alls t = t;
  in forall (is_free_eq_imp o dest_alls) (get_conjs t) end
        handle TERM _ => false;

(* Apply one split rule, already in the form "P (case ...) ==> rhs", to t.
   t is frozen first so that a schematic case scrutinee becomes a Free;
   resolvents are then kept only if their conclusion passes was_split,
   which rejects splits whose scrutinee is not a plain variable.
   Survivors are thawed back to schematic form. *)
fun apply_split ctxt split t = let
    val (t', thaw) = Misc_Legacy.freeze_thaw_robust ctxt t;
    val resolvents = [t'] RL [split];
    val well_formed = filter (was_split o Thm.prop_of) resolvents;
  in Seq.of_list (map (thaw 0) well_formed) end;

(* Forward reasoning as a tactic: every result of resolving t against the
   first premise of one of the rules. NOTE: this shares its name with
   Isabelle's standard forward_tac and shadows it inside this structure. *)
fun forward_tac rules t = Seq.of_list ([t] RL rules);

(* "(?t = ?t --> ?Q) ==> ?Q": mp with its second premise closed by refl.
   Discharges an equational premise by unifying its two sides. *)
val refl_imp = refl RSN (2, mp);

(* Post-processing after one split step:
   1. break the conjunction into its conjuncts (conjunct1/conjunct2),
   2. instantiate leading HOL.All binders with schematic variables (spec),
   3. discharge "v = pattern" with refl_imp. When v is schematic, this
      instantiates the scrutinee v to the constructor pattern. *)
val get_rules_once_split =
  REPEAT (forward_tac [conjunct1, conjunct2])
    THEN REPEAT (forward_tac [spec])
    THEN (forward_tac [refl_imp]);

(* Build a tactic from a split rule "P (case x of ...) = rhs".
   The rule is first turned into "P (case ...) ==> rhs" via iffD1. Then rhs
   is checked for the expected shape, and the result is apply_split
   followed by get_rules_once_split.
   Raises TERM if the split rule's rhs does not pass was_split. The
   offending term is printed in ctxt (the caller's context) rather than a
   fixed compile-time context, so the calling theory's notation is used. *)
fun do_split ctxt split = let
    val split' = split RS iffD1;
    (* Freeze so the scrutinee variable appears as a Free, as was_split expects. *)
    val split_rhs = Thm.concl_of (fst (Misc_Legacy.freeze_thaw_robust ctxt split'));
  in if was_split split_rhs
     then apply_split ctxt split' THEN get_rules_once_split
     else raise TERM ("malformed split rule: " ^ Syntax.string_of_term ctxt split_rhs,
                      [split_rhs])
  end;

(* Turn a meta-level equation (==), such as a definition's _def theorem,
   into an object-level one (=). *)
val atomize_meta_eq = forward_tac [meta_eq_to_obj_eq];

(* Core of the attribute. Optionally atomize thm, then repeatedly apply the
   first rule in splitthms whose tactic succeeds, until none does. All
   resulting leaf theorems are conjoined into one theorem.
   The order of splitthms matters because FIRST tries them in order. *)
fun better_split ctxt splitthms thm = conjunct_rules
  (Seq.list_of ((TRY atomize_meta_eq
                 THEN (REPEAT (FIRST (map (do_split ctxt) splitthms)))) thm));

(* Attribute parser: reads a list of split theorems and maps the decorated
   theorem through better_split, using the proof context of the attribute's
   generic context. *)
val split_att
  = Attrib.thms >>
    (fn thms => Thm.rule_attribute thms (fn context => better_split (Context.proof_of context) thms));


end;
61
62\<close>
63
ML \<open>
(* Register SplitSimps.split_att under the attribute name "split_simps".
   The value is a theory transformer; the setup command below applies it. *)
val split_att_setup =
  let
    val description = "split rule involving case construct into multiple simp rules";
  in
    Attrib.setup @{binding split_simps} SplitSimps.split_att description
  end;
\<close>

setup \<open>split_att_setup\<close>
71
(* Example input for the split_simps attribute (see split_rule_test_simps
   below): a nested case over a sum, a pair and an option. *)
definition split_rule_test ::
  "((nat \<Rightarrow> 'a) + ('b \<times> ('b \<Rightarrow> 'a) option)) \<Rightarrow> ('a \<Rightarrow> nat) \<Rightarrow> nat"
where
  "split_rule_test x f =
     f (case x of
          Inl af \<Rightarrow> af 1
        | Inr (b, None) \<Rightarrow> inv f 0
        | Inr (b, Some g) \<Rightarrow> g b)"
78
(* Apply split_simps with the sum, prod and option split rules to the
   definition. The attribute yields a single theorem (a conjunction).
   NOTE(review): it is expected to contain one equation per case branch of
   split_rule_test; confirm with "thm split_rule_test_simps". *)
lemmas split_rule_test_simps =
  split_rule_test_def [split_simps sum.split prod.split option.split]
81
82end
83