1(*  Title:      Provers/trancl.ML
2    Author:     Oliver Kutter, TU Muenchen
3
4Transitivity reasoner for transitive closures of relations
5*)
6
7(*
8
The package provides tactics trancl_tac and rtrancl_tac that prove
10goals of the form
11
12   (x,y) : r^+     and     (x,y) : r^* (rtrancl_tac only)
13
14from premises of the form
15
16   (x,y) : r,     (x,y) : r^+     and     (x,y) : r^* (rtrancl_tac only)
17
18by reflexivity and transitivity.  The relation r is determined by inspecting
19the conclusion.
20
21The package is implemented as an ML functor and thus not limited to
22particular constructs for transitive and reflexive-transitive
23closures, neither need relations be represented as sets of pairs.  In
24order to instantiate the package for transitive closure only, supply
25dummy theorems to the additional rules for reflexive-transitive
26closures, and don't use rtrancl_tac!
27
28*)
29
signature TRANCL_ARITH =
sig

  (* theorems for transitive closure *)

  val r_into_trancl : thm
      (* (a,b) : r ==> (a,b) : r^+ *)
  val trancl_trans : thm
      (* [| (a,b) : r^+ ; (b,c) : r^+ |] ==> (a,c) : r^+ *)

  (* additional theorems for reflexive-transitive closure *)

  val rtrancl_refl : thm
      (* (a,a): r^* *)
  val r_into_rtrancl : thm
      (* (a,b) : r ==> (a,b) : r^* *)
  val trancl_into_rtrancl : thm
      (* (a,b) : r^+ ==> (a,b) : r^* *)
  val rtrancl_trancl_trancl : thm
      (* [| (a,b) : r^* ; (b,c) : r^+ |] ==> (a,c) : r^+ *)
  val trancl_rtrancl_trancl : thm
      (* [| (a,b) : r^+ ; (b,c) : r^* |] ==> (a,c) : r^+ *)
  val rtrancl_trans : thm
      (* [| (a,b) : r^* ; (b,c) : r^* |] ==> (a,c) : r^* *)

  (* decomp: decompose a premise or conclusion

     Returns one of the following:

     NONE if not an instance of a relation,
     SOME (x, y, r, s) if instance of a relation, where
       x: left hand side argument, y: right hand side argument,
       r: the relation,
       s: the kind of closure, exactly one of the strings
            "r":   the relation itself,
            "r+":  transitive closure of the relation,
            "r*":  reflexive-transitive closure of the relation
          (note: "r+" and "r*", without a caret -- these are the literal
          strings the tactics match on).
          Any other string is rejected: in a premise over the goal's
          relation it raises an "unknown relation symbol" error, and in
          the conclusion it makes the tactic fail.
  *)

  val decomp: term ->  (term * term * term * string) option

end;
72
signature TRANCL_TAC =
sig
  (* Prove a subgoal (x,y) : r^+ from premises of the form
     (x,y) : r and (x,y) : r^+. *)
  val trancl_tac: Proof.context -> int -> tactic

  (* Prove a subgoal (x,y) : r^+ or (x,y) : r^* from premises of the
     form (x,y) : r, (x,y) : r^+ and (x,y) : r^*. *)
  val rtrancl_tac: Proof.context -> int -> tactic
end;
78
79functor Trancl_Tac(Cls: TRANCL_ARITH): TRANCL_TAC =
80struct
81
82
83datatype proof
84  = Asm of int
85  | Thm of proof list * thm;
86
87exception Cannot; (* internal exception: raised if no proof can be found *)
88
(* Wrapper around Cls.decomp that beta-eta-normalises both arguments
   and the relation, so that later comparisons with aconv are not
   confused by unreduced redexes. *)
fun decomp t =
  (case Cls.decomp t of
    NONE => NONE
  | SOME (x, y, rel, kind) =>
      let val norm = Envir.beta_eta_contract
      in SOME (norm x, norm y, norm rel, kind) end);
92
(* prove ctxt r asms: turn a proof object into a theorem.
   Asm i is looked up in asms (0-based); Thm (prfs, rule) replays the
   sub-proofs and resolves them with rule, after instantiating the
   relation variable in rule's conclusion to r.  The rule's conclusion
   must decompose with a schematic variable as relation (otherwise the
   pattern match fails with Bind). *)
fun prove ctxt r asms =
  let
    fun inst rule =
      let
        val SOME (_, _, Var (rvar, _), _) = decomp (Thm.concl_of rule)
        val cr = Thm.cterm_of ctxt r
      in infer_instantiate ctxt [(rvar, cr)] rule end
    fun replay prf =
      (case prf of
        Asm i => nth asms i
      | Thm (subprfs, rule) => map replay subprfs MRS inst rule)
  in replay end;
101
102
103(* Internal datatype for inequalities *)
(* Internal representation of a fact about the relation R:
     Trans (x, y, p)   stands for (x, y) : R^+
     RTrans (x, y, p)  stands for (x, y) : R^*
   together with a proof object p (for goals built by mkconcl_* the
   proof is only the placeholder Asm ~1). *)
datatype rel =
    Trans of term * term * proof   (* R^+ *)
  | RTrans of term * term * proof; (* R^* *)

(* Accessors: left argument, right argument and proof of a rel. *)
fun lower r = (case r of Trans (x, _, _) => x | RTrans (x, _, _) => x);

fun upper r = (case r of Trans (_, y, _) => y | RTrans (_, y, _) => y);

fun getprf r = (case r of Trans (_, _, p) => p | RTrans (_, _, p) => p);
117
118(* ************************************************************************ *)
119(*                                                                          *)
120(*  mkasm_trancl Rel (t,n): term -> (term , int) -> rel list                *)
121(*                                                                          *)
122(*  Analyse assumption t with index n with respect to relation Rel:         *)
123(*  If t is of the form "(x, y) : Rel" (or Rel^+), translate to             *)
124(*  an object (singleton list) of internal datatype rel.                    *)
125(*  Otherwise return empty list.                                            *)
126(*                                                                          *)
127(* ************************************************************************ *)
128
(* Translate assumption t (index n) into at most one Trans fact about
   Rel.  Plain "r" premises are lifted with r_into_trancl, "r+"
   premises are used directly, and "r*" premises are ignored because
   trancl_tac cannot use them. *)
fun mkasm_trancl Rel (t, n) =
  (case decomp t of
    NONE => []
  | SOME (x, y, rel, kind) =>
      if not (rel aconv Rel) then []
      else if kind = "r" then [Trans (x, y, Thm ([Asm n], Cls.r_into_trancl))]
      else if kind = "r+" then [Trans (x, y, Asm n)]
      else if kind = "r*" then []
      else error ("trancl_tac: unknown relation symbol"));
140
141(* ************************************************************************ *)
142(*                                                                          *)
143(*  mkasm_rtrancl Rel (t,n): term -> (term , int) -> rel list               *)
144(*                                                                          *)
145(*  Analyse assumption t with index n with respect to relation Rel:         *)
146(*  If t is of the form "(x, y) : Rel" (or Rel^+ or Rel^* ), translate to   *)
147(*  an object (singleton list) of internal datatype rel.                    *)
148(*  Otherwise return empty list.                                            *)
149(*                                                                          *)
150(* ************************************************************************ *)
151
(* Translate assumption t (index n) into at most one fact about Rel:
   "r" and "r+" premises become Trans facts ("r" lifted with
   r_into_trancl), "r*" premises become RTrans facts. *)
fun mkasm_rtrancl Rel (t, n) =
  (case decomp t of
    SOME (x, y, rel, kind) =>
      if rel aconv Rel then
        let val asm = Asm n in
          (case kind of
            "r" => [Trans (x, y, Thm ([asm], Cls.r_into_trancl))]
          | "r+" => [Trans (x, y, asm)]
          | "r*" => [RTrans (x, y, asm)]
          | _ => error ("rtrancl_tac: unknown relation symbol" ))
        end
      else []
  | NONE => []);
162
163(* ************************************************************************ *)
164(*                                                                          *)
165(*  mkconcl_trancl t: term -> (term, rel, proof)                            *)
166(*  mkconcl_rtrancl t: term -> (term, rel, proof)                           *)
167(*                                                                          *)
168(*  Analyse conclusion t:                                                   *)
(*    - must be of the form "(x, y) : r^+" (or "(x, y) : r^*" for rtrancl) *)
170(*    - returns r                                                           *)
171(*    - conclusion in internal form                                         *)
172(*    - proof object                                                        *)
173(*                                                                          *)
174(* ************************************************************************ *)
175
(* Decompose the goal of trancl_tac: only "r+" goals are accepted.
   Returns the relation, the goal as a rel (with placeholder proof
   Asm ~1) and the proof object Asm 0 that picks the single derived
   theorem.  Raises Cannot otherwise. *)
fun mkconcl_trancl t =
  (case decomp t of
    SOME (x, y, rel, "r+") => (rel, Trans (x, y, Asm ~1), Asm 0)
  | _ => raise Cannot);

(* Same as mkconcl_trancl, but additionally accepts "r*" goals. *)
fun mkconcl_rtrancl t =
  (case decomp t of
    SOME (x, y, rel, "r+") => (rel, Trans (x, y, Asm ~1), Asm 0)
  | SOME (x, y, rel, "r*") => (rel, RTrans (x, y, Asm ~1), Asm 0)
  | _ => raise Cannot);
190
191(* ************************************************************************ *)
192(*                                                                          *)
193(*  makeStep (r1, r2): rel * rel -> rel                                     *)
194(*                                                                          *)
195(*  Apply transitivity to r1 and r2, obtaining a new element of r^+ or r^*, *)
(*  according to the following rules:                                       *)
197(*                                                                          *)
198(* ( (a, b) : r^+ , (b,c) : r^+ ) --> (a,c) : r^+                           *)
199(* ( (a, b) : r^* , (b,c) : r^+ ) --> (a,c) : r^+                           *)
200(* ( (a, b) : r^+ , (b,c) : r^* ) --> (a,c) : r^+                           *)
201(* ( (a, b) : r^* , (b,c) : r^* ) --> (a,c) : r^*                           *)
202(*                                                                          *)
203(* ************************************************************************ *)
204
(* Chain two facts (a, _) and (_, c) into (a, c).  The result is an
   R^* fact only if both inputs are R^*; otherwise it is R^+.  The
   middle vertices are not compared. *)
fun makeStep (first, second) =
  let
    val a = lower first
    val c = upper second
    val prfs = [getprf first, getprf second]
  in
    (case (first, second) of
      (Trans _, Trans _) => Trans (a, c, Thm (prfs, Cls.trancl_trans))
    (* refl. + trans. cls. rules *)
    | (RTrans _, Trans _) => Trans (a, c, Thm (prfs, Cls.rtrancl_trancl_trancl))
    | (Trans _, RTrans _) => Trans (a, c, Thm (prfs, Cls.trancl_rtrancl_trancl))
    | (RTrans _, RTrans _) => RTrans (a, c, Thm (prfs, Cls.rtrancl_trans)))
  end;
210
211(* ******************************************************************* *)
212(*                                                                     *)
213(* transPath (Clslist, Cls): (rel  list * rel) -> rel                  *)
214(*                                                                     *)
(* Contracts a path, given as its first edge Cls and the remaining     *)
(* edges Clslist, to a single element of type rel by repeated         *)
(* transitivity steps (makeStep), left to right.  The edges are       *)
(* assumed to be consecutive; no validity check is performed.         *)
219(*                                                                     *)
220(* ******************************************************************* *)
221
(* Fold the edges xs onto acc from left to right with makeStep. *)
fun transPath (xs, acc) =
  List.foldl (fn (edge, sofar) => makeStep (sofar, edge)) acc xs;
224
225(* ********************************************************************* *)
226(* Graph functions                                                       *)
227(* ********************************************************************* *)
228
229(* *********************************************************** *)
230(* Functions for constructing graphs                           *)
231(* *********************************************************** *)
232
(* addEdge (v, d, g): prepend the adjacency entries d to the entry of
   vertex v in g (matched with aconv, entry re-keyed to v), or append a
   new entry (v, d) at the end if v is not yet present. *)
fun addEdge (v, d, g) =
  let
    fun insert [] = [(v, d)]
      | insert ((entry as (u, dl)) :: rest) =
          if v aconv u then (v, d @ dl) :: rest
          else entry :: insert rest
  in insert g end;
236
237(* ********************************************************************** *)
238(*                                                                        *)
239(* mkGraph constructs from a list of objects of type rel  a graph g       *)
240(* and a list of all edges with label r+.                                 *)
241(*                                                                        *)
242(* ********************************************************************** *)
243
(* Build the graph (lower x -> upper x labelled x, for every fact x)
   and collect the Trans facts (in reverse input order).  Both endpoints
   of every fact become vertices of the graph. *)
fun mkGraph rels =
  let
    fun step (x, (g, zs)) =
      let
        val g' = addEdge (upper x, [], addEdge (lower x, [(upper x, x)], g))
        val zs' = (case x of Trans _ => x :: zs | RTrans _ => zs)
      in (g', zs') end
  in List.foldl step ([], []) rels end;
253
254(* *********************************************************************** *)
255(*                                                                         *)
(* adjacent eq_comp g u :                                                  *)
(* ('a * 'a -> bool) -> ('a * 'b list) list -> 'a -> 'b list              *)
257(*                                                                         *)
258(* List of successors of u in graph g                                      *)
259(*                                                                         *)
260(* *********************************************************************** *)
261
(* Adjacency list of the first entry of g whose key matches u under
   eq_comp (called as eq_comp (u, key)); [] if there is none. *)
fun adjacent eq_comp g u =
  (case List.find (fn (key, _) => eq_comp (u, key)) g of
    SOME (_, succs) => succs
  | NONE => []);
265
266(* *********************************************************************** *)
267(*                                                                         *)
268(* dfs eq_comp g u v:                                                      *)
269(* ('a * 'a -> bool) -> ('a  *( 'a * rel) list) list ->                    *)
270(* 'a -> 'a -> (bool * ('a * rel) list)                                    *)
271(*                                                                         *)
272(* Depth first search of v from u.                                         *)
273(* Returns (true, path(u, v)) if successful, otherwise (false, []).        *)
274(*                                                                         *)
275(* *********************************************************************** *)
276
(* Depth-first search from u for v.  Every newly discovered vertex w is
   recorded as (w, edge) in a predecessor list (most recent first).
   Returns (true, predecessors) if v was reached, (false, []) otherwise. *)
fun dfs eq_comp g u v =
  let
    val preds = Unsynchronized.ref []
    val seen = Unsynchronized.ref []

    fun is_seen w = exists (fn w' => eq_comp (w', w)) (!seen)

    fun visit node =
      (seen := node :: !seen;
       if is_seen v then ()
       else
         List.app
           (fn (next, edge) =>
              if is_seen next then ()
              else (preds := (next, edge) :: !preds; visit next))
           (adjacent eq_comp g node))
  in
    visit u;
    if is_seen v then (true, !preds) else (false, [])
  end;
298
299(* *********************************************************************** *)
300(*                                                                         *)
301(* transpose g:                                                            *)
302(* (''a * ''a list) list -> (''a * ''a list) list                          *)
303(*                                                                         *)
304(* Computes transposed graph g' from g                                     *)
305(* by reversing all edges u -> v to v -> u                                 *)
306(*                                                                         *)
307(* *********************************************************************** *)
308
(* Reverse every edge u -> v (label l) of g into v -> u (label l).
   The result has one entry per entry of g, in the same order; the
   adjacency list of each vertex is gathered from the reversed edges
   using eq_comp, and gathered edges are not reconsidered. *)
fun transpose eq_comp g =
  let
    (* all reversed edges, as (new source, (new target, label)) *)
    val reversed =
      maps (fn (u, succs) => map (fn (v, l) => (v, (u, l))) succs) g

    fun build [] _ = []
      | build ((u, _) :: rest) edges =
          let
            val (mine, others) = List.partition (fn (v, _) => eq_comp (u, v)) edges
          in (u, map snd mine) :: build rest others end
  in build g reversed end;
340
341(* *********************************************************************** *)
342(*                                                                         *)
343(* dfs_reachable eq_comp g u:                                              *)
344(* (int * int list) list -> int -> int list                                *)
345(*                                                                         *)
346(* Computes list of all nodes reachable from u in g.                       *)
347(*                                                                         *)
348(* *********************************************************************** *)
349
(* All vertices reachable from u in g (u itself first), found by a
   depth-first traversal that visits successors right-to-left. *)
fun dfs_reachable eq_comp g u =
  let
    val marked = Unsynchronized.ref []

    fun is_marked w = exists (fn m => eq_comp (m, w)) (!marked)

    fun explore node =
      (marked := node :: !marked;
       List.foldr
         (fn ((next, _), acc) =>
            if is_marked next then acc
            else next :: explore next @ acc)
         [] (adjacent eq_comp g node))
  in u :: explore u end;
367
368(* *********************************************************************** *)
369(*                                                                         *)
370(* dfs_term_reachable g u:                                                  *)
371(* (term * term list) list -> term -> term list                            *)
372(*                                                                         *)
373(* Computes list of all nodes reachable from u in g.                       *)
374(*                                                                         *)
375(* *********************************************************************** *)
376
(* dfs_reachable on terms, comparing vertices up to alpha-equivalence. *)
fun dfs_term_reachable g u = dfs_reachable (fn (a, b) => a aconv b) g u;
378
379(* ************************************************************************ *)
380(*                                                                          *)
381(* findPath x y g: Term.term -> Term.term ->                                *)
382(*                  (Term.term * (Term.term * rel list) list) ->            *)
383(*                  (bool, rel list)                                        *)
384(*                                                                          *)
385(*  Searches a path from vertex x to vertex y in Graph g, returns true and  *)
386(*  the list of edges if path is found, otherwise false and nil.            *)
387(*                                                                          *)
388(* ************************************************************************ *)
389
(* Search a path from x to y in g.  On success return (true, edges) with
   the edges ordered from x to y; otherwise (false, []).  Raises Cannot
   if the predecessor chain cannot be followed back to x (e.g. when x
   and y coincide and no edge was traversed). *)
fun findPath x y g =
  let
    val (found, tmp) = dfs (op aconv) g x y
  in
    if not found then (false, [])
    else
      let
        val pred = map snd tmp

        (* the recorded edge that ends in v *)
        fun edge_into v =
          (case List.find (fn e => upper e aconv v) pred of
            SOME e => e
          | NONE => raise Cannot)

        (* walk backwards from v, accumulating edges in forward order *)
        fun walk v acc =
          let
            val e = edge_into v
            val w = lower e
          in
            if w aconv x then e :: acc else walk w (e :: acc)
          end
      in (true, walk y []) end
  end;
420
421(* ************************************************************************ *)
422(*                                                                          *)
(* findRtranclProof g tranclEdges subgoal:                                  *)
(* (Term.term * (Term.term * rel) list) list -> rel list -> rel ->         *)
(* proof list                                                               *)
425(*                                                                          *)
426(* Searches in graph g a proof for subgoal.                                 *)
427(*                                                                          *)
428(* ************************************************************************ *)
429
(* Find a proof for subgoal in graph g (whose R^+ edges are tranclEdges).

   - RTrans (x, y): reflexivity if x and y are alpha-equivalent, otherwise
     any path from x to y, contracted by transitivity and weakened with
     trancl_into_rtrancl if the result is an R^+ fact.
   - Trans (x, y): the path must contain at least one R^+ edge e.  Each
     candidate e whose endpoints lie both in the set reachable from x and
     in the set reaching y is tried in turn, and the path x ~> e ~> y is
     contracted.

   Vertex comparisons use aconv throughout, consistent with dfs/findPath,
   so terms differing only in bound-variable names are identified.
   Returns a singleton list of proofs; raises Cannot if none is found. *)
fun findRtranclProof g tranclEdges subgoal =
  let
    (* contract a non-empty path to a single rel *)
    fun contract path = transPath (tl path, hd path)
  in
    (case subgoal of
      RTrans (x, y, _) =>
        if x aconv y then [Thm ([], Cls.rtrancl_refl)]
        else
          let val (found, path) = findPath x y g
          in
            if not found then raise Cannot
            else
              (case contract path of
                Trans (_, _, p) => [Thm ([p], Cls.trancl_into_rtrancl)]
              | path' => [getprf path'])
          end
    | Trans (x, y, _) =>
        let
          val Vx = dfs_term_reachable g x
          val Vy = dfs_term_reachable (transpose (op aconv) g) y

          (* t lies on some path from x to y *)
          fun between t = member (op aconv) Vx t andalso member (op aconv) Vy t
          fun relevant e = between (upper e) andalso between (lower e)

          fun processTranclEdges [] = raise Cannot
            | processTranclEdges (e :: es) =
                if not (relevant e) then processTranclEdges es
                else if lower e aconv x then
                  (* x = u, e : u -> v; need a path v ~> y unless v = y *)
                  if upper e aconv y then [getprf e]
                  else
                    let val (found, path) = findPath (upper e) y g
                    in
                      if found then [getprf (transPath (path, e))]
                      else processTranclEdges es
                    end
                else if upper e aconv y then
                  (* v = y; need a path x ~> u *)
                  let val (found, path) = findPath x (lower e) g
                  in
                    if found then [getprf (makeStep (contract path, e))]
                    else processTranclEdges es
                  end
                else
                  (* general case: x ~> u, e : u -> v, v ~> y *)
                  let
                    val (xufound, xupath) = findPath x (lower e) g
                    val (vyfound, vypath) = findPath (upper e) y g
                  in
                    if xufound andalso vyfound then
                      [getprf (makeStep (makeStep (contract xupath, e), contract vypath))]
                    else processTranclEdges es
                  end
        in processTranclEdges tranclEdges end)
  end;
513
514
(* Prove the R^+ conclusion concl from the facts asms by a single path
   search; the path is contracted with transitivity.  Raises Cannot if
   no path exists. *)
fun solveTrancl (asms, concl) =
  let
    val (g, _) = mkGraph asms
    val (_, goal, _) = mkconcl_trancl concl
    val (found, path) = findPath (lower goal) (upper goal) g
  in
    if not found then raise Cannot
    else [getprf (transPath (tl path, hd path))]
  end;
525
(* Prove the R^+ or R^* conclusion concl from the facts asms. *)
fun solveRtrancl (asms, concl) =
  let
    val (graph, transEdges) = mkGraph asms
    val (_, goal, _) = mkconcl_rtrancl concl
  in findRtranclProof graph transEdges goal end;
532
533
(* Solve subgoal i of the form (x,y) : r^+ from its r / r^+ premises.
   The proof is searched on the raw subgoal and then replayed inside a
   focused context; yields no result if no proof is found. *)
fun trancl_tac ctxt = SUBGOAL (fn (goal, i) => fn st =>
  let
    val hyps = Logic.strip_assums_hyp goal
    val target = Logic.strip_assums_concl goal
    val (rel, _, goal_prf) = mkconcl_trancl target

    val facts = flat (map_index (fn (k, h) => mkasm_trancl rel (h, k)) hyps)
    val hyp_prfs = solveTrancl (facts, target)
  in
    Subgoal.FOCUS (fn {context = fctxt, prems = fprems, concl = fconcl, ...} =>
      let
        val SOME (_, _, frel, _) = decomp (Thm.term_of fconcl)
        val lemmas = map (prove fctxt frel fprems) hyp_prfs
      in resolve_tac fctxt [prove fctxt frel lemmas goal_prf] 1 end) ctxt i st
  end
  handle Cannot => Seq.empty);
550
551
(* Solve subgoal i of the form (x,y) : r^+ or (x,y) : r^* from its
   r / r^+ / r^* premises.  The proof is searched on the raw subgoal and
   then replayed inside a focused context; yields no result if no proof
   is found (Cannot) or on General.Subscript. *)
fun rtrancl_tac ctxt = SUBGOAL (fn (goal, i) => fn st =>
  let
    val hyps = Logic.strip_assums_hyp goal
    val target = Logic.strip_assums_concl goal
    val (rel, _, goal_prf) = mkconcl_rtrancl target

    val facts = flat (map_index (fn (k, h) => mkasm_rtrancl rel (h, k)) hyps)
    val hyp_prfs = solveRtrancl (facts, target)
  in
    Subgoal.FOCUS (fn {context = fctxt, prems = fprems, concl = fconcl, ...} =>
      let
        val SOME (_, _, frel, _) = decomp (Thm.term_of fconcl)
        val lemmas = map (prove fctxt frel fprems) hyp_prfs
      in resolve_tac fctxt [prove fctxt frel lemmas goal_prf] 1 end) ctxt i st
  end
  handle Cannot => Seq.empty | General.Subscript => Seq.empty);
568
569end;
570