1(*  Title:      HOL/Tools/Function/pattern_split.ML
2    Author:     Alexander Krauss, TU Muenchen
3
4Fairly ad-hoc pattern splitting.
5*)
6
signature FUNCTION_SPLIT =
sig
  (* Split every equation: from each one, the left-hand-side patterns of all
     equations preceding it are subtracted.  One result list per input
     equation. *)
  val split_all_equations : Proof.context -> term list -> term list list

  (* Like split_all_equations, but only equations flagged true are split;
     equations flagged false come back unchanged as singleton lists (they
     are still subtracted from later flagged equations). *)
  val split_some_equations : Proof.context -> (bool * term) list -> term list list
end
15
structure Function_Split : FUNCTION_SPLIT =
struct

open Function_Lib

(* Create a free variable of type T whose name (based on "v") is fresh with
   respect to ctxt and the terms in vs.  Returns the extended variable list
   together with the new variable. *)
fun new_var ctxt vs T =
  let
    val v = singleton (Variable.variant_frees ctxt vs) ("v", T)
  in
    (Free v :: vs, Free v)
  end

(* Apply t to one fresh variable for each argument type in its type
   (binder_types), threading the list vs of used variables through. *)
fun saturate ctxt vs t =
  fold (fn T => fn (vs, t) => new_var ctxt vs T |> apsnd (curry op $ t))
    (binder_types (fastype_of t)) (vs, t)


(* Pattern difference t - t'.

   Returns a list of (vs', sigma) pairs.  sigma is a list of (variable,
   instance) rewrite pairs refining t, and vs' is the list of variables in
   use after fresh ones were introduced.  Each instance of t described by one
   sigma is part of the remainder.
   - []          : nothing of t remains (t' is at least as general).
   - [(vs, [])]  : t and t' were found disjoint, so all of t remains. *)
fun pattern_subtract_subst ctxt vs t t' =
  let
    (* Raised when the heads of corresponding subterms differ.  Declared
       locally, so every (recursive) call has its own exception; it never
       escapes the handler below. *)
    exception DISJ

    (* t' is a variable: it matches everything, so nothing of t remains. *)
    fun pattern_subtract_subst_aux _ _ (Free _) = []
      (* t is a variable but t' is not: case-split t over the constructors of
         its type, saturate each with fresh variables, and subtract t' from
         each instance; the split is recorded as (v, instance). *)
      | pattern_subtract_subst_aux vs (v as (Free (_, T))) t' =
          let
            fun aux constr =
              let
                val (vs', t) = saturate ctxt vs constr
                val substs = pattern_subtract_subst ctxt vs' t t'
              in
                map (fn (vs, subst) => (vs, (v, t) :: subst)) substs
              end
          in
            maps aux (inst_constrs_of ctxt T)
          end
      (* Both are applications: with equal heads, subtract argument-wise and
         collect the remainders of every position; otherwise disjoint.  A
         DISJ in any argument discards the partial results, since then the
         whole patterns are disjoint. *)
      | pattern_subtract_subst_aux vs t t' =
          let
            val (C, ps) = strip_comb t
            val (C', qs) = strip_comb t'
          in
            if C = C'
            then flat (map2 (pattern_subtract_subst_aux vs) ps qs)
            else raise DISJ
          end
  in
    pattern_subtract_subst_aux vs t t'
    handle DISJ => [(vs, [])]
  end

(* p - q: instances of equation eq1 whose left-hand side is not matched by
   the left-hand side of eq2.  Each result is the instantiated body of eq1,
   re-quantified over its remaining variables. *)
fun pattern_subtract ctxt eq2 eq1 =
  let
    val thy = Proof_Context.theory_of ctxt

    (* NOTE(review): the patterns below assume the bodies have the shape
       _ $ (_ $ lhs $ rhs), presumably Trueprop (lhs = rhs) -- confirm. *)
    val (vs, feq1 as (_ $ (_ $ lhs1 $ _))) = dest_all_all eq1
    val (_,  _ $ (_ $ lhs2 $ _)) = dest_all_all eq2

    val substs = pattern_subtract_subst ctxt vs lhs1 lhs2

    (* Apply sigma to the body of eq1, then universally quantify over those
       free variables of the result that belong to vs' and are not fixed in
       ctxt. *)
    fun instantiate (vs', sigma) =
      let
        val t = Pattern.rewrite_term thy sigma [] feq1
        val xs = fold_aterms
          (fn x as Free (a, _) =>
              if not (Variable.is_fixed ctxt a) andalso member (op =) vs' x
              then insert (op =) x else I
            | _ => I) t []
      in fold Logic.all xs t end
  in
    map instantiate substs
  end

(* ps - p': subtract equation p' from every equation in the list. *)
fun pattern_subtract_from_many ctxt p' =
  maps (pattern_subtract ctxt p')

(* Subtract every equation of ps' from the given list.  fold_rev processes
   ps' from its last element to its first; callers pass ps' with the most
   recent equation first. *)
fun pattern_subtract_many ctxt ps' =
  fold_rev (pattern_subtract_from_many ctxt) ps'

(* For each (flag, eq): if flag is true, subtract all preceding equations
   from eq; otherwise keep eq unchanged.  Every equation, split or not, is
   subtracted from the later flagged ones. *)
fun split_some_equations ctxt eqns =
  let
    fun split_aux _ [] = []
      | split_aux prev ((true, eq) :: es) =
          pattern_subtract_many ctxt prev [eq] :: split_aux (eq :: prev) es
      | split_aux prev ((false, eq) :: es) =
          [eq] :: split_aux (eq :: prev) es
  in
    split_aux [] eqns
  end

(* Split every equation against all preceding ones. *)
fun split_all_equations ctxt =
  split_some_equations ctxt o map (pair true)


end
115