1(*  Title:       HOL/Tools/Function/termination.ML
2    Author:      Alexander Krauss, TU Muenchen
3
4Context data for termination proofs.
5*)
6
signature TERMINATION =
sig
  (* Context data for a termination proof: the sum-type skeleton of the
     program points, their types, candidate measures per point, and cached
     provers for call chains and local descents. *)
  type data
  (* Outcome of a local descent proof attempt (see mk_desc in the structure). *)
  datatype cell = Less of thm | LessEq of thm * thm | None of thm * thm | False of thm

  (* mk_sumcases D T fs: combine the functions fs (one per program point,
     indexed by leaf number) into one case distinction over the sum type
     described by D's skeleton; T is the result type handed to
     Sum_Tree.mk_sumcase. *)
  val mk_sumcases : data -> typ -> term list -> term

  (* Number of program points (leaves of the skeleton, numbered from 0). *)
  val get_num_points : data -> int
  (* Type of the argument at a given program point. *)
  val get_types      : data -> int -> typ
  (* Candidate measure functions computed for a given program point. *)
  val get_measures   : data -> int -> term list

  (* get_chain D c1 c2: always SOME; the inner option is SOME thm iff
     "c1 O c2 = {}" was proved.  Results are cached. *)
  val get_chain      : data -> term -> term -> thm option option
  (* get_descent D c m1 m2: always SOME; the cell classifying whether
     "m2 r < m1 l" (or "<=") holds for call c, where (r, l) is the call pair.
     Results are cached. *)
  val get_descent    : data -> term -> term -> term -> cell option

  (* Decompose a call {x. EX vs. x = (r, l) & Gam} into
     (vs, index of l, l without injections, index of r, r without
     injections, Gam). *)
  val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term)

  (* CALLS tac i: if subgoal i has the form _ (_ rel), apply tac to the list of
     union components of rel together with i.  Succeeds trivially when the
     goal state has no premises; fails on other subgoal shapes. *)
  val CALLS : (term list * int -> tactic) -> int -> tactic

  (* Termination tactics *)
  type ttac = data -> int -> tactic

  (* TERMINATION ctxt atac tac: on a subgoal "wf rel", build the termination
     data (atac is used both for chain and for descent proofs) and run tac. *)
  val TERMINATION : Proof.context -> tactic -> ttac -> int -> tactic

  (* Instantiate a schematic relation with the union of the remaining
     premises' call comprehensions and discharge those premises. *)
  val wf_union_tac : Proof.context -> tactic

  (* Split the calls along the strongly connected components of the call
     graph induced by the available chain theorems. *)
  val decompose_tac : Proof.context -> ttac
end
34
35structure Termination : TERMINATION =
36struct
37
38open Function_Lib
39
(* Orderings and tables keying the caches built by create: pairs of calls
   (chain proofs) and (call, (measure, measure)) triples (descent proofs). *)
val term2_ord = prod_ord Term_Ord.fast_term_ord Term_Ord.fast_term_ord

structure Term2tab = Table
(
  type key = term * term
  val ord = term2_ord
);

structure Term3tab = Table
(
  type key = term * (term * term)
  val ord = prod_ord Term_Ord.fast_term_ord term2_ord
);
44
45(** Analyzing binary trees **)
46
47(* Skeleton of a tree structure *)
48
(* Shape of a nested sum type.  mk_skel numbers the leaves from 0 in
   left-to-right order; the left subtree of a branch is reached through Inl,
   the right subtree through Inr. *)
datatype skel =
  SLeaf of int (* index of the program point *)
| SBranch of (skel * skel) (* (Inl side, Inr side) *)
52
53
54(* abstract make and dest functions *)
(* mk_tree leaf branch sk: fold the skeleton bottom-up, mapping each leaf
   index with `leaf` and combining the results of both subtrees with
   `branch` (left subtree evaluated first). *)
fun mk_tree leaf branch =
  let
    fun build sk =
      (case sk of
        SLeaf i => leaf i
      | SBranch (s, t) => branch (build s, build t))
  in build end
59
60
(* dest_tree split sk x: walk the skeleton top-down, using `split` to break
   x into its left and right parts at every branch.  Returns the (index, part)
   pairs for all leaves in left-to-right order. *)
fun dest_tree split =
  let
    fun leaves sk x =
      (case sk of
        SLeaf i => [(i, x)]
      | SBranch (s, t) =>
          let
            val (lx, rx) = split x
            val left = leaves s lx
            val right = leaves t rx
          in left @ right end)
  in leaves end
67
68
69(* concrete versions for sum types *)
(* Strip one Inl injection, if present. *)
fun dest_inl (Const (\<^const_name>\<open>Sum_Type.Inl\<close>, _) $ t) = SOME t
  | dest_inl _ = NONE

(* Strip one Inr injection, if present. *)
fun dest_inr (Const (\<^const_name>\<open>Sum_Type.Inr\<close>, _) $ t) = SOME t
  | dest_inr _ = NONE

(* Is the term an application of Inl or Inr? *)
fun is_inj t = is_some (dest_inl t) orelse is_some (dest_inr t)
79
80
(* mk_skel ps: infer the sum-type skeleton from sample argument terms.
   A nonempty group in which every term is an Inl/Inr injection becomes a
   branch: Inl payloads form the left subtree, Inr payloads the right one.
   Any other group (including the empty group) becomes a leaf.  Leaves are
   numbered from 0 in left-to-right order. *)
fun mk_skel ps =
  let
    (* Returns the next unused leaf index and the skeleton for the group. *)
    fun build next group =
      if not (null group) andalso forall is_inj group then
        let
          val (after_left, left) = build next (map_filter dest_inl group)
          val (after_right, right) = build after_left (map_filter dest_inr group)
        in
          (after_right, SBranch (left, right))
        end
      else (next + 1, SLeaf next)
  in
    #2 (build 0 ps)
  end
93
94(* compute list of types for nodes *)
(* node_types sk T: split the sum type T along the skeleton sk and return
   the leaf types in index order.  Raises Match if T does not have the
   nested sum shape that sk describes. *)
fun node_types sk T =
  let
    fun split_sumT (Type (\<^type_name>\<open>Sum_Type.sum\<close>, [LT, RT])) = (LT, RT)
  in
    map snd (dest_tree split_sumT sk T)
  end
96
(* find the leaf index of an Inl/Inr-injected term and the term beneath its injections *)
(* dest_inj sk trm: follow the Inl/Inr injections of trm down the skeleton
   and return the reached leaf index with the uninjected term.  At a branch,
   a term that is not Inl must be Inr (otherwise `the` raises). *)
fun dest_inj (SLeaf i) trm = (i, trm)
  | dest_inj (SBranch (s, t)) trm =
      (case dest_inl trm of
        SOME inner => dest_inj s inner
      | NONE => dest_inj t (the (dest_inr trm)))
103
104
105
106(** Matrix cell datatype **)
107
(* Outcome of a local descent proof attempt, as produced by mk_desc for the
   goals  !!vs. Gam ==> m2 r < m1 l  and  !!vs. Gam ==> m2 r <= m1 l:
     Less thm             the strict inequality was proved;
     LessEq (thm, st)     only "<=" was proved; st is the stuck "<" attempt;
     None (st1, st2)      neither was proved; stuck "<=" and "<" attempts;
     False st             the "<=" attempt got stuck with the single
                          remaining premise False. *)
datatype cell = Less of thm | LessEq of thm * thm | None of thm * thm | False of thm;
109
110
(* Termination context data, as assembled by create.  The two function
   components are memoized with Cache.create. *)
type data =
  skel                            (* structure of the sum type encoding "program points" *)
  * (int -> typ)                  (* types of program points *)
  * (term list Inttab.table)      (* measures for program points *)
  * (term * term -> thm option)   (* which calls form chains? (cached) *)
  * (term * (term * term) -> cell)(* local descents (cached) *)
117
118
119(* Build case expression *)
(* mk_sumcases D T fs: build the case combinator over the sum type given by
   D's skeleton.  fs lists one function per program point (indexed by leaf
   number); T is the common result type. *)
fun mk_sumcases (sk, _, _, _, _) T fs =
  let
    fun leaf i =
      let val f = nth fs i
      in (f, domain_type (fastype_of f)) end
    fun branch ((f, fT), (g, gT)) =
      (Sum_Tree.mk_sumcase fT gT T f g, Sum_Tree.mk_sumT fT gT)
  in
    fst (mk_tree leaf branch sk)
  end
125
(* mk_sum_skel rel: compute the skeleton from all argument terms occurring
   in the call pairs of the union rel. *)
fun mk_sum_skel rel =
  let
    val calls = Function_Lib.dest_binop_list \<^const_name>\<open>Lattices.sup\<close> rel
    (* Prepend both components of the call pair to acc (r first, then l). *)
    fun add_pats (Const (\<^const_name>\<open>Collect\<close>, _) $ Abs (_, _, c)) acc =
      let
        val (Const (\<^const_name>\<open>HOL.conj\<close>, _) $ (Const (\<^const_name>\<open>HOL.eq\<close>, _) $ _ $ (Const (\<^const_name>\<open>Pair\<close>, _) $ r $ l)) $ _)
          = Term.strip_qnt_body \<^const_name>\<open>Ex\<close> c
      in
        r :: l :: acc
      end
  in
    mk_skel (fold add_pats calls [])
  end
137
(* prove_chain ctxt chain_tac (c1, c2): try to show "c1 O c2 = {}" with
   chain_tac; SOME thm on success, NONE otherwise. *)
fun prove_chain ctxt chain_tac (c1, c2) =
  let
    val comp = HOLogic.mk_binop \<^const_name>\<open>Relation.relcomp\<close> (c1, c2)
    val empty = Const (\<^const_abbrev>\<open>Set.empty\<close>, fastype_of c1)
    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (comp, empty)) (* "C1 O C2 = {}" *)
  in
    (case Function_Lib.try_proof ctxt (Thm.cterm_of ctxt goal) chain_tac of
      Function_Lib.Solved thm => SOME thm
    | _ => NONE)
  end
149
150
(* dest_call' sk call: decompose a call relation of the form
     Collect (%x. EX vs. x = Pair r l & Gam)
   into (vs, p, l', q, r', Gam).  p and q are the program-point indices of
   l and r in the skeleton sk; l' and r' are l and r with their Inl/Inr
   injections stripped.
   Raises ERROR "dest_call" for any term that does not have this shape,
   including argument terms whose injections do not match sk. *)
fun dest_call' sk (Const (\<^const_name>\<open>Collect\<close>, _) $ Abs (_, _, c)) =
      let
        val vs = Term.strip_qnt_vars \<^const_name>\<open>Ex\<close> c
      in
        (case Term.strip_qnt_body \<^const_name>\<open>Ex\<close> c of
          Const (\<^const_name>\<open>HOL.conj\<close>, _)
              $ (Const (\<^const_name>\<open>HOL.eq\<close>, _) $ _
                  $ (Const (\<^const_name>\<open>Pair\<close>, _) $ r $ l))
              $ Gam =>
            (* dest_inj uses `the`, which raises Option on an injection
               mismatch; report that as a malformed call as well *)
            (let
              val (p, l') = dest_inj sk l
              val (q, r') = dest_inj sk r
            in
              (vs, p, l', q, r', Gam)
            end
            handle Option.Option => error "dest_call")
        | _ => error "dest_call")
      end
  | dest_call' _ _ = error "dest_call"
164
(* dest_call D call: dest_call' using the skeleton stored in D. *)
fun dest_call (sk, _, _, _, _) call = dest_call' sk call
166
(* mk_desc ctxt tac vs Gam l r m1 m2: with tac, try to prove
     !!vs. Gam ==> m2 r < m1 l
   and, if that gets stuck,
     !!vs. Gam ==> m2 r <= m1 l,
   then classify the outcome as a cell.  Any try_proof outcome other than
   Solved/Stuck raises Match. *)
fun mk_desc ctxt tac vs Gam l r m1 m2 =
  let
    fun attempt rel =
      let
        val concl =
          HOLogic.mk_Trueprop (Const (rel, \<^typ>\<open>nat \<Rightarrow> nat \<Rightarrow> bool\<close>)
            $ (m2 $ r) $ (m1 $ l))
        val goal = Logic.list_all (vs, Logic.mk_implies (HOLogic.mk_Trueprop Gam, concl))
      in
        try_proof ctxt (Thm.cterm_of ctxt goal) tac
      end
  in
    (case attempt \<^const_name>\<open>Orderings.less\<close> of
      Solved lt_thm => Less lt_thm
    | Stuck lt_stuck =>
        (case attempt \<^const_name>\<open>Orderings.less_eq\<close> of
          Solved le_thm => LessEq (le_thm, lt_stuck)
        | Stuck le_stuck =>
            if Thm.prems_of le_stuck = [HOLogic.Trueprop $ \<^term>\<open>False\<close>]
            then False le_stuck
            else None (le_stuck, lt_stuck)
        | _ => raise Match) (* FIXME *)
    | _ => raise Match)
  end
187
(* prove_descent ctxt tac sk (c, (m1, m2)): descent cell for call c under the
   measures m1 and m2 (the cache function behind get_descent). *)
fun prove_descent ctxt tac sk (c, (m1, m2)) =
  (case dest_call' sk c of
    (vs, _, l, _, r, Gam) => mk_desc ctxt tac vs Gam l r m1 m2)
194
(* create ctxt chain_tac descent_tac T rel: assemble the termination data for
   the call relation rel over the argument sum type T.  Chain and descent
   proofs are computed lazily and memoized. *)
fun create ctxt chain_tac descent_tac T rel =
  let
    val sk = mk_sum_skel rel
    val Ts = node_types sk T
    val measures =
      Ts
      |> map_index (apsnd (Measure_Functions.get_measure_functions ctxt))
      |> Inttab.make
    val chains =
      prove_chain ctxt chain_tac
      |> Cache.create Term2tab.empty Term2tab.lookup Term2tab.update
    val descents =
      prove_descent ctxt descent_tac sk
      |> Cache.create Term3tab.empty Term3tab.lookup Term3tab.update
  in
    (sk, nth Ts, measures, chains, descents)
  end
209
(* Number of program points: leaves are numbered consecutively from 0 left
   to right, so this is the rightmost leaf index plus one. *)
fun get_num_points (sk, _, _, _, _) =
  let
    fun rightmost (SBranch (_, t)) = rightmost t
      | rightmost (SLeaf i) = i
  in
    rightmost sk + 1
  end
215
(* Type of the argument at a program point. *)
fun get_types (_, types_of, _, _, _) = types_of
(* Candidate measures stored for a program point. *)
fun get_measures (_, _, measures, _, _) i = Inttab.lookup_list measures i
218
(* Cached chain lookup; always SOME, inner option from prove_chain. *)
fun get_chain (_, _, _, chains, _) c1 c2 = (c1, c2) |> chains |> SOME

(* Cached descent lookup; always SOME. *)
fun get_descent (_, _, _, _, descents) c m1 m2 = (c, (m1, m2)) |> descents |> SOME
224
(* CALLS tac i: hand the union components of the relation in subgoal i,
   together with i, to tac.  Succeeds when there are no premises at all;
   fails when subgoal i is not of the form _ (_ rel). *)
fun CALLS tac i st =
  if Thm.no_prems st then all_tac st
  else
    let
      val prem = Thm.term_of (Thm.cprem_of st i)
    in
      (case prem of
        _ $ (_ $ rel) =>
          tac (Function_Lib.dest_binop_list \<^const_name>\<open>Lattices.sup\<close> rel, i) st
      | _ => no_tac st)
    end
230
(* A termination tactic: given the termination data, a tactic for subgoal i. *)
type ttac = data -> int -> tactic
232
(* TERMINATION ctxt atac tac i: on a subgoal of the form _ (wf rel), build the
   termination data for rel (atac serves both as chain and as descent prover)
   and run the termination tactic tac on subgoal i.  Any other subgoal shape
   makes the tactic fail instead of escaping with a Match exception. *)
fun TERMINATION ctxt atac tac =
  SUBGOAL (fn (_ $ (Const (\<^const_name>\<open>wf\<close>, wfT) $ rel), i) =>
        let
          val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT))
        in
          tac (create ctxt atac atac T rel) i
        end
    | _ => no_tac)
240
241
242(* A tactic to convert open to closed termination goals *)
243local
(* dest_term t: split a goal !!vars. prems ==> (lhs, rhs) : R into
   (vars, prems, lhs, rhs). *)
fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *)
  let
    val (vars, prop) = Function_Lib.dest_all_all t
    val (prems, concl) = Logic.strip_horn prop
    val (pair, _) = HOLogic.dest_mem (HOLogic.dest_Trueprop concl)
    val (lhs, rhs) = HOLogic.dest_prod pair
  in
    (vars, prems, lhs, rhs)
  end
255
(* mk_pair_compr (T, qs, l, r, conds): build
     {uu_. EX qs. uu_ = (l, r) & conds}
   with True standing in for an empty condition list. *)
fun mk_pair_compr (T, qs, l, r, conds) =
  let
    val pT = HOLogic.mk_prodT (T, T)
    (* after binding all of qs, the comprehension variable is Bound (length qs) *)
    val peq = HOLogic.eq_const pT $ Bound (length qs) $ (HOLogic.pair_const T T $ l $ r)
    val guards = (case conds of [] => [\<^term>\<open>True\<close>] | _ => conds)
    fun mk_ex v body = HOLogic.exists_const (fastype_of v) $ lambda v body
    val matrix = fold_rev mk_ex qs (foldr1 HOLogic.mk_conj (peq :: guards))
  in
    HOLogic.Collect_const pT $ Abs ("uu_", pT, matrix)
  end
268
(* AC rules and absorption for union, as meta-equalities for rewriting. *)
val Un_aci_simps =
  @{thms Un_ac Un_absorb}
  |> map mk_meta_eq
271
272in
273
(* wf_union_tac ctxt: expects the first premise of the goal state to have the
   form _ (_ rel) (presumably Trueprop (wf ?R) -- confirm with callers) and
   every further premise to be a call goal !!vars. prems ==> (lhs, rhs) : ?R.
   If rel is a schematic variable, it is instantiated with the union of one
   comprehension per call goal (or {} if there are none).  Subgoals 2..n are
   then closed by membership proofs in the matching union component, and
   subgoal 1 is normalised with Un_aci_simps.  Fails (no_tac) if rel is not
   schematic. *)
fun wf_union_tac ctxt st = SUBGOAL (fn _ =>
  let
    val ((_ $ (_ $ rel)) :: ineqs) = Thm.prems_of st

    (* comprehension {x. EX vars. x = (lhs, rhs) & atomized prems} *)
    fun mk_compr ineq =
      let
        val (vars, prems, lhs, rhs) = dest_term ineq
      in
        mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (Object_Logic.atomize_term ctxt) prems)
      end

    (* right-nested union C1 Un (C2 Un (... Un Cn)) *)
    val relation =
      if null ineqs
      then Const (\<^const_abbrev>\<open>Set.empty\<close>, fastype_of rel)
      else map mk_compr ineqs
        |> foldr1 (HOLogic.mk_binop \<^const_name>\<open>Lattices.sup\<close>)

    (* Subgoal i corresponds to comprehension C(i-1): skip the i-2 components
       to its left with UnI2, then select it with UnI1 (not applicable to the
       last component, hence TRY). *)
    fun solve_membership_tac i =
      (EVERY' (replicate (i - 2) (resolve_tac ctxt @{thms UnI2}))  (* pick the right component of the union *)
      THEN' (fn j => TRY (resolve_tac ctxt @{thms UnI1} j))
      THEN' (resolve_tac ctxt @{thms CollectI})                    (* unfold comprehension *)
      THEN' (fn i => REPEAT (resolve_tac ctxt @{thms exI} i))      (* Turn existentials into schematic Vars *)
      THEN' ((resolve_tac ctxt @{thms refl})                       (* unification instantiates all Vars *)
        ORELSE' ((resolve_tac ctxt @{thms conjI})
          THEN' (resolve_tac ctxt @{thms refl})
          THEN' (blast_tac ctxt)))    (* Solve rest of context... not very elegant *)
      ) i
  in
    if is_Var rel then
      PRIMITIVE (infer_instantiate ctxt [(#1 (dest_Var rel), Thm.cterm_of ctxt relation)])
        (* subgoal 1 is left for the caller; all others are membership goals *)
        THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i)
        THEN rewrite_goal_tac ctxt Un_aci_simps 1  (* eliminate duplicates *)
    else no_tac
  end) 1 st
308
309end
310
311
312
313(*** DEPENDENCY GRAPHS ***)
314
(* mk_dgraph D cs: dependency graph on the calls cs.  An edge c2 -> c1 is
   added whenever no theorem "c1 O c2 = {}" is available from D. *)
fun mk_dgraph D cs =
  let
    fun no_chain c1 c2 = is_none (the_default NONE (get_chain D c1 c2))
    fun add_dep c1 c2 =
      if no_chain c1 c2 then Term_Graph.add_edge (c2, c1) else I
  in
    Term_Graph.empty
    |> fold (fn c => Term_Graph.new_node (c, ())) cs
    |> fold_product add_dep cs cs
  end
322
(* ucomp_empty_tac ctxt T: break down goals about compositions of unions with
   union_comp_emptyR/L and close the remaining atomic goals with the chain
   theorem T c1 c2 for the two composed calls. *)
fun ucomp_empty_tac ctxt T =
  let
    fun chain_tac (_ $ (_ $ (_ $ c1 $ c2) $ _), i) = resolve_tac ctxt [T c1 c2] i
    val step =
      resolve_tac ctxt @{thms union_comp_emptyR}
      ORELSE' resolve_tac ctxt @{thms union_comp_emptyL}
      ORELSE' SUBGOAL chain_tac
  in
    REPEAT_ALL_NEW step
  end
327
(* regroup_calls_tac ctxt cs: reorder the union of calls in the subgoal
   according to the positions of the calls cs in it. *)
fun regroup_calls_tac ctxt cs =
  let
    fun regroup (cs', i) =
      let
        val positions = map (fn c => find_index (fn c' => c aconv c') cs') cs
        val conv = Function_Lib.regroup_union_conv ctxt positions
      in
        CONVERSION (Conv.arg_conv (Conv.arg_conv conv)) i
      end
  in
    CALLS regroup
  end
335
336
(* solve_trivial_tac ctxt D: a single call c with a theorem "c O c = {}" is
   well-founded by wf_no_loop; all other situations fail. *)
fun solve_trivial_tac ctxt D =
  let
    fun trivial ([c], i) =
          (case get_chain D c c of
            SOME (SOME no_loop) =>
              resolve_tac ctxt @{thms wf_no_loop} i
              THEN resolve_tac ctxt [no_loop] i
          | _ => no_tac)
      | trivial _ = no_tac
  in
    CALLS trivial
  end
345
(* decompose_tac ctxt D: split the calls of the subgoal along the strongly
   connected components of the dependency graph (mk_dgraph).  With more than
   one SCC, each SCC is peeled off in turn.  Its calls are regrouped to the
   front and wf_union_compatible is applied.  Subgoal i+2 is closed via
   less_by_empty and the cached chain theorems.  The rest is handled
   recursively on subgoal i+1, and the SCC itself is tried by
   solve_trivial_tac.  With a single SCC, only solve_trivial_tac is tried. *)
fun decompose_tac ctxt D = CALLS (fn (cs, i) =>
  let
    val G = mk_dgraph D cs
    val sccs = Term_Graph.strong_conn G

    (* Total over lists; the [] case is unreachable from the call below
       (recursion stops at a singleton) but keeps the match exhaustive. *)
    fun split [] _ = all_tac
      | split [SCC] i = TRY (solve_trivial_tac ctxt D i)
      | split (SCC :: rest) i =
          regroup_calls_tac ctxt SCC i
          THEN resolve_tac ctxt @{thms wf_union_compatible} i
          THEN resolve_tac ctxt @{thms less_by_empty} (i + 2)
          THEN ucomp_empty_tac ctxt (the o the oo get_chain D) (i + 2)
          THEN split rest (i + 1)
          THEN TRY (solve_trivial_tac ctxt D i)
  in
    if length sccs > 1 then split sccs i
    else solve_trivial_tac ctxt D i
  end)
363
364end
365