1(* ========================================================================= *)
2(* ORDERED REWRITING FOR FIRST ORDER TERMS                                   *)
3(* Copyright (c) 2003 Joe Hurd, distributed under the BSD License            *)
4(* ========================================================================= *)
5
6structure Rewrite :> Rewrite =
7struct
8
9open Useful;
10
11(* ------------------------------------------------------------------------- *)
12(* Orientations of equations.                                                *)
13(* ------------------------------------------------------------------------- *)
14
(* The direction in which an equation is used as a rewrite: LeftToRight
   replaces instances of the left side by the right side, RightToLeft the
   reverse. *)
datatype orient = LeftToRight | RightToLeft;

(* Render an orientation as an arrow. *)
fun toStringOrient LeftToRight = "-->"
  | toStringOrient RightToLeft = "<--";

val ppOrient = Print.ppMap toStringOrient Print.ppString;

(* Render an optional orientation; an unoriented equation (NONE) is shown
   as a two-headed arrow. *)
fun toStringOrientOption NONE = "<->"
  | toStringOrientOption (SOME ort) = toStringOrient ort;

val ppOrientOption = Print.ppMap toStringOrientOption Print.ppString;
30
31(* ------------------------------------------------------------------------- *)
32(* A type of rewrite systems.                                                *)
33(* ------------------------------------------------------------------------- *)
34
(* A reduction order compares two terms.  SOME GREATER means the first term
   is larger (it may be rewritten to the second), and NONE means the two
   terms are incomparable. *)
type reductionOrder = Term.term * Term.term -> order option;

(* Equations in a rewrite system are identified by integers. *)
type equationId = int;

(* An equation is a pair ((l,r),th) of its two sides and a theorem. *)
type equation = Rule.equation;

(* A rewrite system.
     order    - the reduction order used to orient equations;
     known    - every equation in the system with its orientation
                (NONE when its two sides are incomparable);
     redexes  - term net from each usable redex side to (id, direction);
     subterms - term net from every subterm of every processed equation
                to (id, true for the left side / false for the right, path);
     waiting  - ids that have been added but not yet inter-reduced. *)
datatype rewrite =
    Rewrite of
      {order : reductionOrder,
       known : (equation * orient option) IntMap.map,
       redexes : (equationId * orient) TermNet.termNet,
       subterms : (equationId * bool * Term.path) TermNet.termNet,
       waiting : IntSet.set};
48
(* Return the rewrite system with its waiting set replaced; every other
   component is carried over unchanged. *)
fun updateWaiting (Rewrite {order, known, redexes, subterms, ...}) waiting =
    Rewrite
      {waiting = waiting, subterms = subterms, redexes = redexes,
       known = known, order = order};
57
(* Remove one equation id from the waiting set. *)
fun deleteWaiting rw id =
    let
      val Rewrite {waiting = pending, ...} = rw
    in
      updateWaiting rw (IntSet.delete pending id)
    end;
60
61(* ------------------------------------------------------------------------- *)
62(* Basic operations                                                          *)
63(* ------------------------------------------------------------------------- *)
64
(* An empty rewrite system that will orient equations with order. *)
fun new order =
    let
      val netParams = {fifo = false}
    in
      Rewrite
        {order = order,
         known = IntMap.new (),
         redexes = TermNet.new netParams,
         subterms = TermNet.new netParams,
         waiting = IntSet.empty}
    end;
72
(* Look up an equation and its orientation by id. *)
fun peek rw id = case rw of Rewrite {known, ...} => IntMap.peek known id;
74
(* The number of equations currently in the system. *)
fun size rw = let val Rewrite {known, ...} = rw in IntMap.size known end;
76
(* All equations in the system, collected with IntMap.foldr. *)
fun equations (Rewrite {known,...}) =
    let
      fun consEqn (_, (eqn, _), acc) = eqn :: acc
    in
      IntMap.foldr consEqn [] known
    end;
79
(* Pretty-print a rewrite system as the list of its equations. *)
fun pp rw = Print.ppList Rule.ppEquation (equations rw);
81
(* Detailed pretty-printer for tracing.  It is commented out by the
   MetisTrace1 marker; when that marker is removed, it redefines pp to show
   the known equations and, with MetisTrace5 also enabled, the redexes,
   subterms and waiting components. *)
(*MetisTrace1
local
  fun ppEq ((x_y,_),ort) =
      Print.ppOp2 (" " ^ toStringOrientOption ort) Term.pp Term.pp x_y;

  fun ppField f ppA a =
      Print.inconsistentBlock 2
        [Print.ppString (f ^ " ="),
         Print.break,
         ppA a];

  val ppKnown =
      ppField "known"
        (Print.ppMap IntMap.toList
           (Print.ppList (Print.ppPair Print.ppInt ppEq)));

  val ppRedexes =
      ppField "redexes"
        (TermNet.pp (Print.ppPair Print.ppInt ppOrient));

  val ppSubterms =
      ppField "subterms"
        (TermNet.pp
           (Print.ppMap
              (fn (i,l,p) => (i, (if l then 0 else 1) :: p))
              (Print.ppPair Print.ppInt Term.ppPath)));

  val ppWaiting =
      ppField "waiting"
        (Print.ppMap (IntSet.toList) (Print.ppList Print.ppInt));
in
  fun pp (Rewrite {known,redexes,subterms,waiting,...}) =
      Print.inconsistentBlock 2
        [Print.ppString "Rewrite",
         Print.break,
         Print.inconsistentBlock 1
           [Print.ppString "{",
            ppKnown known,
(*MetisTrace5
            Print.ppString ",",
            Print.break,
            ppRedexes redexes,
            Print.ppString ",",
            Print.break,
            ppSubterms subterms,
            Print.ppString ",",
            Print.break,
            ppWaiting waiting,
*)
            Print.skip],
         Print.ppString "}"]
end;
*)
135
(* Render a rewrite system as a string using pp. *)
fun toString rw = Print.toString pp rw;
137
138(* ------------------------------------------------------------------------- *)
139(* Debug functions.                                                          *)
140(* ------------------------------------------------------------------------- *)
141
(* Debug check: can some subterm of a term be rewritten by a known equation
   other than the one numbered id?  An equation reduces a term in a
   permitted direction when that direction's redex side matches the term
   and the instantiated residue is smaller than the term under order. *)
fun termReducible order known id =
    let
      fun reducesBy ((lhs,rhs),_) tm =
          case total (Subst.match Subst.empty lhs) tm of
            SOME sub =>
            order (tm, Subst.subst (Subst.normalize sub) rhs) = SOME GREATER
          | NONE => false

      fun anyKnown tm (eqnId,(eqn,ort)) =
          if eqnId = id then false
          else
            (ort <> SOME RightToLeft andalso reducesBy eqn tm) orelse
            (ort <> SOME LeftToRight andalso reducesBy (Rule.symEqn eqn) tm)

      fun reducible tm =
          IntMap.exists (anyKnown tm) known orelse
          (case tm of
             Term.Var _ => false
           | Term.Fn (_,args) => List.exists reducible args)
    in
      reducible
    end;
161
(* Debug check: does any argument term of lit contain a reducible subterm? *)
fun literalReducible order known id lit =
    let
      val reducible = termReducible order known id
    in
      List.exists reducible (Literal.arguments lit)
    end;
164
(* Debug check: is any literal of the set reducible? *)
fun literalsReducible order known id lits =
    LiteralSet.exists (fn lit => literalReducible order known id lit) lits;
167
(* Debug check: is any literal of the theorem's clause reducible? *)
fun thmReducible order known id th =
    let
      val lits = Thm.clause th
    in
      literalsReducible order known id lits
    end;
170
171(* ------------------------------------------------------------------------- *)
172(* Add equations into the system.                                            *)
173(* ------------------------------------------------------------------------- *)
174
(* Turn the comparison of an equation's two sides into an orientation:
   GREATER orients left to right, LESS right to left, and NONE
   (incomparable) leaves the equation unoriented.  Equal sides raise Error. *)
fun orderToOrient cmp =
    case cmp of
      NONE => NONE
    | SOME GREATER => SOME LeftToRight
    | SOME LESS => SOME RightToLeft
    | SOME EQUAL => raise Error "Rewrite.orient: reflexive";
179
(* Index the redex side(s) of equation id in the redexes net.  An oriented
   equation contributes only its larger side; an unoriented one contributes
   both sides, each tagged with the direction it rewrites in. *)
fun addRedexes id (((l,r),_),ort) redexes =
    let
      fun ins net (redex,dir) = TermNet.insert net (redex,(id,dir))
    in
      case ort of
        SOME LeftToRight => ins redexes (l,LeftToRight)
      | SOME RightToLeft => ins redexes (r,RightToLeft)
      | NONE => ins (ins redexes (l,LeftToRight)) (r,RightToLeft)
    end;
189
(* Add equation eqn under identifier id.  The equation is oriented with the
   system's reduction order, its redexes are indexed, and it is marked as
   waiting to be inter-reduced.  An id that is already known leaves the
   system unchanged. *)
fun add rw (id,eqn) =
    let
      val Rewrite {order,known,redexes,subterms,waiting} = rw
    in
      if IntMap.inDomain id known then rw
      else
        let
          val ort = orderToOrient (order (fst eqn))

          val rw =
              Rewrite
                {order = order,
                 known = IntMap.insert known (id,(eqn,ort)),
                 redexes = addRedexes id (eqn,ort) redexes,
                 subterms = subterms,
                 waiting = IntSet.add waiting id}
(*MetisTrace5
          val () = Print.trace pp "Rewrite.add: result" rw
*)
        in
          rw
        end
    end;
214
(* Add a list of (id, equation) pairs, from left to right. *)
fun addList rw id_eqns =
    List.foldl (fn (id_eqn,acc) => add acc id_eqn) rw id_eqns;
220
221(* ------------------------------------------------------------------------- *)
222(* Rewriting (the order must be a refinement of the rewrite order).          *)
223(* ------------------------------------------------------------------------- *)
224
(* The (id, direction) entries whose redex matches tm, sorted by decreasing
   equation id. *)
fun matchingRedexes redexes tm =
    let
      fun byDecreasingId ((i,_),(j,_)) = Int.compare (j,i)
    in
      sort byDecreasingId (TermNet.match redexes tm)
    end;
230
(* May an equation with orientation ort be used in direction dir?
   Unoriented equations may be used in either direction. *)
fun wellOriented ort dir =
    case ort of
      NONE => true
    | SOME LeftToRight => dir = LeftToRight
    | SOME RightToLeft => dir = RightToLeft;
235
(* The (redex, residue) pair for using an equation in direction dir. *)
fun redexResidue dir (((l,r),_) : equation) =
    case dir of
      LeftToRight => (l,r)
    | RightToLeft => (r,l);
238
(* The equation, flipped with Rule.symEqn when it is used right to left. *)
fun orientedEquation dir eqn =
    case dir of
      LeftToRight => eqn
    | RightToLeft => Rule.symEqn eqn;
241
(* A conversion that rewrites tm once, at the top, using the first matching
   redex in decreasing-id order that passes all of these checks:
     - it does not belong to equation id;
     - it is usable in the matched direction;
     - if its equation is unoriented, order shows the result is smaller.
   Raises Error when no matching redex qualifies. *)
fun rewrIdConv' order known redexes id tm =
    let
      fun attempt (eqnId,dir) =
          let
            val () = if id = eqnId then raise Error "same theorem" else ()
            val (eqn,ort) = IntMap.get known eqnId
            val () =
                if wellOriented ort dir then () else raise Error "orientation"
            val (redex,residue) = redexResidue dir eqn
            val sub = Subst.normalize (Subst.match Subst.empty redex tm)
            val tm' = Subst.subst sub residue
            val () =
                if Option.isSome ort then ()
                else if order (tm,tm') = SOME GREATER then ()
                else raise Error "order"
            val (_,th) = orientedEquation dir eqn
          in
            (tm', Thm.subst sub th)
          end
    in
      case first (total attempt) (matchingRedexes redexes tm) of
        SOME result => result
      | NONE => raise Error "Rewrite.rewrIdConv: no matching rewrites"
    end;
264
(* Repeatedly apply rewrIdConv' top-down with the known equations other than
   id.  An empty system falls back to Rule.allConv. *)
fun rewriteIdConv' order known redexes id =
    case IntMap.null known of
      true => Rule.allConv
    | false => Rule.repeatTopDownConv (rewrIdConv' order known redexes id);
268
(* Turn the negated equation lit, with sides l and r, into a conversion that
   rewrites exactly the larger side into the smaller one.  Raises Error when
   order finds the sides incomparable or equal.  The justifying theorem is
   built eagerly, before the conversion is returned. *)
fun mkNeqConv order lit =
    let
      val (l,r) = Literal.destNeq lit

      fun convFrom (src,dst,th,msg) tm =
          if Term.equal tm src then (dst,th) else raise Error msg
    in
      case order (l,r) of
        SOME GREATER => convFrom (l, r, Thm.assume lit, "mkNeqConv: LR")
      | SOME LESS => convFrom (r, l, Rule.symmetryRule l r, "mkNeqConv: RL")
      | SOME EQUAL => raise Error "irreflexive"
      | NONE => raise Error "incomparable"
    end;
291
(* A collection of conversions built from negated-equation literals, keyed
   by the literal each one was built from. *)
datatype neqConvs = NeqConvs of Rule.conv LiteralMap.map;

val neqConvsEmpty =
    let
      val emptyMap = LiteralMap.new ()
    in
      NeqConvs emptyMap
    end;

fun neqConvsNull neqs = case neqs of NeqConvs m => LiteralMap.null m;
297
(* Try to add a conversion for lit, built with mkNeqConv.  Returns NONE when
   mkNeqConv raises Error, for example when the sides are incomparable or
   equal (and presumably when lit is not a negated equation). *)
fun neqConvsAdd order (NeqConvs m) lit =
    case total (mkNeqConv order) lit of
      NONE => NONE
    | SOME conv => SOME (NeqConvs (LiteralMap.insert m (lit,conv)));
302
(* Split a literal set into:
     - a neqConvs collection built from every literal neqConvsAdd accepts;
     - the set of the remaining literals. *)
fun mkNeqConvs order lits =
    let
      fun classify (lit,(neqs,rest)) =
          case neqConvsAdd order neqs lit of
            NONE => (neqs, LiteralSet.add rest lit)
          | SOME neqs' => (neqs',rest)
    in
      LiteralSet.foldl classify (neqConvsEmpty,LiteralSet.empty) lits
    end;
312
(* Drop the conversion keyed by lit. *)
fun neqConvsDelete neqs lit =
    case neqs of NeqConvs m => NeqConvs (LiteralMap.delete m lit);
314
(* Combine every conversion in the collection with Rule.firstConv. *)
fun neqConvsToConv (NeqConvs m) =
    let
      val convs = LiteralMap.foldr (fn (_,conv,acc) => conv :: acc) [] m
    in
      Rule.firstConv convs
    end;
317
(* Fold f over the literal keys of the collection, ignoring the conversions. *)
fun neqConvsFoldl f b (NeqConvs m) =
    let
      fun step (lit,_,acc) = f (lit,acc)
    in
      LiteralMap.foldl step b m
    end;
320
(* A literal rule that rewrites every argument of a literal, repeatedly and
   top-down.  At each step it tries the conversions in neq first, then the
   known equations other than id.  With nothing to rewrite with, it is
   Rule.allLiterule. *)
fun neqConvsRewrIdLiterule order known redexes id neq =
    if not (IntMap.null known) orelse not (neqConvsNull neq) then
      Rule.allArgumentsLiterule
        (Rule.repeatTopDownConv
           (Rule.orelseConv (neqConvsToConv neq)
              (rewrIdConv' order known redexes id)))
    else Rule.allLiterule;
332
(* Rewrite both sides of an equation with the known equations other than id,
   and with the negated equations of its own theorem.
   - An unchanged equation is returned as is.
   - If the theorem lacks the equation literal (Rule.equationLiteral gives
     NONE), the original theorem is kept.
   - Otherwise the rewrite theorem is resolved in when it contains the
     negated literal. *)
fun rewriteIdEqn' order known redexes id (eqn as (l_r,th)) =
    let
      val (neqs,_) = mkNeqConvs order (Thm.clause th)
      val rule = neqConvsRewrIdLiterule order known redexes id neqs
      val (strong,lit) =
          case Rule.equationLiteral eqn of
            SOME eqLit => (false,eqLit)
          | NONE => (true, Literal.mkEq l_r)
      val (lit',litTh) = rule lit

      fun justification () =
          if strong then th
          else if Thm.negateMember lit litTh then Thm.resolve lit th litTh
          else litTh
    in
      if Literal.equal lit lit' then eqn
      else (Literal.destEq lit', justification ())
    end
(*MetisDebug
    handle Error err => raise Error ("Rewrite.rewriteIdEqn':\n" ^ err);
*)
353
(* Rewrite the literals lits of theorem th with the known equations other
   than id.  It works in two phases.
   1. Negated equations among lits become conversions (mkNeqConvs).  Each is
      rewritten with the others, and passes repeat while some negated
      equation is rewritten into one that is still usable as a conversion.
   2. Every remaining literal that is still in the theorem is rewritten with
      the final conversions and the known equations. *)
fun rewriteIdLiteralsRule' order known redexes id lits th =
    let
      val mk_literule = neqConvsRewrIdLiterule order known redexes id

      (* Rewrite one negated equation lit with the other conversions.  Its
         own conversion is removed first so that it cannot rewrite itself. *)
      fun rewr_neq_lit (lit, acc as (changed,neq,lits,th)) =
          let
            val neq = neqConvsDelete neq lit
            val (lit',litTh) = mk_literule neq lit
          in
            if Literal.equal lit lit' then acc
            else
              let
                (* Resolve on lit with the rewrite theorem; presumably this
                   replaces lit by lit' in th. *)
                val th = Thm.resolve lit th litTh
              in
                (* If lit' still yields a conversion, another pass is
                   needed; otherwise lit' is kept as an ordinary literal. *)
                case neqConvsAdd order neq lit' of
                  SOME neq => (true,neq,lits,th)
                | NONE => (changed, neq, LiteralSet.add lits lit', th)
              end
          end

      (* Iterate rewr_neq_lit over all conversions until a pass changes
         nothing. *)
      fun rewr_neq_lits neq lits th =
          let
            val (changed,neq,lits,th) =
                neqConvsFoldl rewr_neq_lit (false,neq,lits,th) neq
          in
            if changed then rewr_neq_lits neq lits th
            else (neq,lits,th)
          end

      val (neq,lits) = mkNeqConvs order lits

      val (neq,lits,th) = rewr_neq_lits neq lits th

      val rewr_literule = mk_literule neq

      (* Skip literals that are no longer in th (presumably removed by the
         earlier steps). *)
      fun rewr_lit (lit,th) =
          if Thm.member lit th then Rule.literalRule rewr_literule lit th
          else th
    in
      LiteralSet.foldl rewr_lit th lits
    end;
395
(* Rewrite every literal of theorem th with the known equations other
   than id. *)
fun rewriteIdRule' order known redexes id th =
    let
      val lits = Thm.clause th
    in
      rewriteIdLiteralsRule' order known redexes id lits th
    end;
398
(* Debug wrapper, commented out by the MetisDebug marker.  It rewrites the
   theorem and then raises Bug if the result is still reducible. *)
(*MetisDebug
val rewriteIdRule' = fn order => fn known => fn redexes => fn id => fn th =>
    let
(*MetisTrace6
      val () = Print.trace Thm.pp "Rewrite.rewriteIdRule': th" th
*)
      val result = rewriteIdRule' order known redexes id th
(*MetisTrace6
      val () = Print.trace Thm.pp "Rewrite.rewriteIdRule': result" result
*)
      val _ = not (thmReducible order known id result) orelse
              raise Bug "rewriteIdRule: should be normalized"
    in
      result
    end
    handle Error err => raise Error ("Rewrite.rewriteIdRule:\n" ^ err);
*)
416
(* One top-level rewrite step with the system's equations, excluding
   equation id.  rewrConv passes ~1, presumably never a real equation id,
   so no equation is excluded. *)
fun rewrIdConv (Rewrite {known,redexes,...}) order id =
    rewrIdConv' order known redexes id;

fun rewrConv rewrite order =
    let
      val noExcludedId = ~1
    in
      rewrIdConv rewrite order noExcludedId
    end;
421
(* Exhaustive top-down rewriting with the system's equations, excluding
   equation id.  rewriteConv passes ~1 so that no equation is excluded. *)
fun rewriteIdConv (Rewrite {known,redexes,...}) order id =
    rewriteIdConv' order known redexes id;

fun rewriteConv rewrite order =
    let
      val noExcludedId = ~1
    in
      rewriteIdConv rewrite order noExcludedId
    end;
426
(* Rewrite the given literals of a theorem with the system's equations,
   excluding equation id.  rewriteLiteralsRule passes ~1 so that no equation
   is excluded. *)
fun rewriteIdLiteralsRule (Rewrite {known,redexes,...}) order id =
    rewriteIdLiteralsRule' order known redexes id;

fun rewriteLiteralsRule rewrite order =
    let
      val noExcludedId = ~1
    in
      rewriteIdLiteralsRule rewrite order noExcludedId
    end;
432
(* Rewrite every literal of a theorem with the system's equations, excluding
   equation id.  rewriteRule passes ~1 so that no equation is excluded. *)
fun rewriteIdRule (Rewrite {known,redexes,...}) order id =
    rewriteIdRule' order known redexes id;

fun rewriteRule rewrite order =
    let
      val noExcludedId = ~1
    in
      rewriteIdRule rewrite order noExcludedId
    end;
437
438(* ------------------------------------------------------------------------- *)
439(* Inter-reduce the equations in the system.                                 *)
440(* ------------------------------------------------------------------------- *)
441
(* Index every subterm of both sides of equation id in the subterms net.
   Each entry is tagged with true for the left side or false for the right,
   together with its path.  The left side is indexed first. *)
fun addSubterms id (((l,r),_) : equation) subterms =
    let
      fun insertSide isLeft (net,tm) =
          List.foldl
            (fn ((path,sub),acc) => TermNet.insert acc (sub,(id,isLeft,path)))
            net (Term.subterms tm)
    in
      insertSide false (insertSide true (subterms,l), r)
    end;
452
(* Do the old and new versions of an equation with orientation ort share the
   same redex side?  Unoriented equations never count as the same. *)
fun sameRedexes ort (l0,r0) (l,r) =
    case ort of
      SOME LeftToRight => Term.equal l0 l
    | SOME RightToLeft => Term.equal r0 r
    | NONE => false;
456
(* The (redex, residue, oriented) triples for using an equation with
   orientation ort.  An unoriented equation yields both directions, flagged
   false so that callers still check the order. *)
fun redexResidues ort (l,r) =
    case ort of
      SOME LeftToRight => [(l,r,true)]
    | SOME RightToLeft => [(r,l,true)]
    | NONE => [(l,r,false),(r,l,false)];
460
(* Add to todo every other equation indexed in subterms that has a subterm
   rewritable by one of the (redex, residue, oriented) triples in lrs.  For
   unoriented triples the rewrite must also make the subterm smaller under
   order.  Any Error raised while checking a candidate (for example a bad
   path or a failed match) rejects that candidate. *)
fun findReducibles order known subterms id todo lrs =
    let
      fun validRewrite (redex,residue,oriented) id' isLeft path =
          let
            val (((lhs,rhs),_),_) = IntMap.get known id'
            val tm = Term.subterm (if isLeft then lhs else rhs) path
            val sub = Subst.match Subst.empty redex tm
          in
            if oriented then ()
            else if order (tm, Subst.subst (Subst.normalize sub) residue) =
                    SOME GREATER
            then ()
            else raise Error "order"
          end

      fun addIfReducible lr ((id',isLeft,path),acc) =
          if id' <> id andalso not (IntSet.member id' acc) andalso
             can (validRewrite lr id' isLeft) path
          then IntSet.add acc id'
          else acc

      fun addAll (lr as (redex,_,_), acc) =
          List.foldl (addIfReducible lr) acc (TermNet.matched subterms redex)
    in
      List.foldl addAll todo lrs
    end;
490
(* Inter-reduce the single equation id, whose current entry is (eqn0,ort0),
   against the other known equations.  The accumulator
   (rpl,spl,todo,rw,changed) is threaded through:
     rpl     - ids whose redexes entries must be rebuilt by rebuild;
     spl     - ids whose subterms entries must be rebuilt by rebuild;
     todo    - already processed equations that may now be reducible;
     rw      - the rewrite system being updated;
     changed - ids of new or modified equations.
   new is true for an equation taken from the waiting set.  If the rewritten
   equation's sides are EQUAL under the order, orderToOrient raises and
   total yields NONE, and the equation is deleted from known. *)
fun reduce1 new id (eqn0,ort0) (rpl,spl,todo,rw,changed) =
    let
      val (eq0,_) = eqn0
      val Rewrite {order,known,redexes,subterms,waiting} = rw
      (* Normalize the equation with every known equation except itself. *)
      val eqn as (eq,_) = rewriteIdEqn' order known redexes id eqn0
      val identical =
          let
            val (l0,r0) = eq0
            and (l,r) = eq
          in
            Term.equal l l0 andalso Term.equal r r0
          end
      (* If the indexed redex side did not change, keep its orientation and
         leave the redexes net alone. *)
      val same_redexes = identical orelse sameRedexes ort0 eq0 eq
      val rpl = if same_redexes then rpl else IntSet.add rpl id
      (* New equations have no subterms entries yet, so nothing to clean. *)
      val spl = if new orelse identical then spl else IntSet.add spl id
      val changed =
          if not new andalso identical then changed else IntSet.add changed id
      (* NONE: reflexive, so delete; SOME o: keep with orientation o. *)
      val ort =
          if same_redexes then SOME ort0 else total orderToOrient (order eq)
    in
      case ort of
        NONE =>
        let
          val known = IntMap.delete known id
          val rw =
              Rewrite
                {order = order, known = known, redexes = redexes,
                 subterms = subterms, waiting = waiting}
        in
          (rpl,spl,todo,rw,changed)
        end
      | SOME ort =>
        let
          (* Presumably an old equation with unchanged redexes cannot reduce
             anything it could not reduce before, so the search is skipped. *)
          val todo =
              if not new andalso same_redexes then todo
              else
                findReducibles
                  order known subterms id todo (redexResidues ort eq)
          val known =
              if identical then known else IntMap.insert known (id,(eqn,ort))
          val redexes =
              if same_redexes then redexes
              else addRedexes id (eqn,ort) redexes
          (* Index the subterms of new or modified equations now; stale
             entries for modified ones are removed later via spl. *)
          val subterms =
              if new orelse not identical then addSubterms id eqn subterms
              else subterms
          val rw =
              Rewrite
                {order = order, known = known, redexes = redexes,
                 subterms = subterms, waiting = waiting}
        in
          (rpl,spl,todo,rw,changed)
        end
    end;
545
(* Choose an id from set that is still in known, preferring equations that
   are oriented.  Returns the id with its (equation, orientation) entry, or
   NONE if no id in set is known. *)
fun pick known set =
    let
      fun lookup wantOriented id =
          case IntMap.peek known id of
            NONE => NONE
          | SOME (entry as (_,ort)) =>
            if wantOriented andalso not (Option.isSome ort) then NONE
            else SOME (id,entry)
    in
      case IntSet.firstl (lookup true) set of
        NONE => IntSet.firstl (lookup false) set
      | found => found
    end;
560
local
  (* Remove every net entry whose id (read with entryId) is in ids.  Then
     re-index each id in ids that is still known, using reindex.  An empty
     ids leaves the net untouched. *)
  fun refresh entryId reindex known net ids =
      if IntSet.null ids then net
      else
        let
          val kept =
              TermNet.filter
                (fn entry => not (IntSet.member (entryId entry) ids)) net

          fun readd (id,acc) =
              case IntMap.peek known id of
                NONE => acc
              | SOME eqn_ort => reindex id eqn_ort acc
        in
          IntSet.foldl readd kept ids
        end;

  fun cleanRedexes known redexes rpl =
      refresh (fn (id,_) => id) addRedexes known redexes rpl;

  fun cleanSubterms known subterms spl =
      refresh (fn (id,_,_) => id)
        (fn id => fn (eqn,_) => fn net => addSubterms id eqn net)
        known subterms spl;
in
  (* Bring the redexes and subterms nets up to date for the ids recorded in
     rpl and spl during inter-reduction. *)
  fun rebuild rpl spl rw =
      let
(*MetisTrace5
        val ppPl = Print.ppMap IntSet.toList (Print.ppList Print.ppInt)
        val () = Print.trace ppPl "Rewrite.rebuild: rpl" rpl
        val () = Print.trace ppPl "Rewrite.rebuild: spl" spl
*)
        val Rewrite {order,known,redexes,subterms,waiting} = rw
      in
        Rewrite
          {order = order,
           known = known,
           redexes = cleanRedexes known redexes rpl,
           subterms = cleanSubterms known subterms spl,
           waiting = waiting}
      end;
end;
615
(* Inter-reduction loop.  Equations in todo (already processed ones that may
   now be reducible) are handled before new equations from waiting.  When
   both are exhausted, the indexes are rebuilt and the system is returned
   together with the ids of the equations that changed. *)
fun reduceAcc (rpl, spl, todo, rw, changed) =
    let
      val Rewrite {known,waiting,...} = rw
    in
      case pick known todo of
        SOME (id,eqn_ort) =>
        reduceAcc
          (reduce1 false id eqn_ort
             (rpl, spl, IntSet.delete todo id, rw, changed))
      | NONE =>
        (case pick known waiting of
           SOME (id,eqn_ort) =>
           reduceAcc
             (reduce1 true id eqn_ort
                (rpl, spl, todo, deleteWaiting rw id, changed))
         | NONE => (rebuild rpl spl rw, IntSet.toList changed))
    end;
633
(* A system with no waiting equations needs no inter-reduction. *)
fun isReduced rw = case rw of Rewrite {waiting,...} => IntSet.null waiting;
635
(* Inter-reduce the system, returning it together with the ids of new or
   modified equations.  A system with nothing waiting is returned with an
   empty list. *)
fun reduce' rw =
    if not (isReduced rw) then
      reduceAcc (IntSet.empty, IntSet.empty, IntSet.empty, rw, IntSet.empty)
    else (rw,[]);
639
(* Debug wrapper, commented out by the MetisDebug marker.  It runs reduce'
   and raises Bug if any resulting equation is still reducible by the other
   known equations. *)
(*MetisDebug
val reduce' = fn rw =>
    let
(*MetisTrace4
      val () = Print.trace pp "Rewrite.reduce': rw" rw
*)
      val Rewrite {known,order,...} = rw
      val result as (Rewrite {known = known', ...}, _) = reduce' rw
(*MetisTrace4
      val ppResult = Print.ppPair pp (Print.ppList Print.ppInt)
      val () = Print.trace ppResult "Rewrite.reduce': result" result
*)
      val ths = List.map (fn (id,((_,th),_)) => (id,th)) (IntMap.toList known')
      val _ =
          not (List.exists (uncurry (thmReducible order known')) ths) orelse
          raise Bug "Rewrite.reduce': not fully reduced"
    in
      result
    end
    handle Error err => raise Bug ("Rewrite.reduce': shouldn't fail\n" ^ err);
*)
661
(* Inter-reduce the system, discarding the list of changed ids. *)
fun reduce rw = let val (reduced,_) = reduce' rw in reduced end;
663
664(* ------------------------------------------------------------------------- *)
665(* Rewriting as a derived rule.                                              *)
666(* ------------------------------------------------------------------------- *)
667
(* Build a rewrite system over order from eqns, numbered with enumerate, and
   return a rule that rewrites theorems with it.  The system is not
   inter-reduced first.  This reuses addList instead of a private copy of
   its folding helper. *)
fun orderedRewrite order eqns =
    rewriteRule (addList (new order) (enumerate eqns)) order;
678
(* Rewrite with a trivial order that reports every pair of terms as GREATER,
   so every equation is oriented left to right. *)
val rewrite =
    let
      val leftToRight : reductionOrder = fn _ => SOME GREATER
    in
      orderedRewrite leftToRight
    end;
684
685end
686