1249261Sdim(* ========================================================================= *)
2249261Sdim(* ORDERED REWRITING                                                         *)
3249261Sdim(* Copyright (c) 2003-2004 Joe Hurd.                                         *)
4249261Sdim(* ========================================================================= *)
5249261Sdim
6249261Sdim(*
7249261Sdimapp load ["mlibHeap", "mlibTerm", "mlibSubst", "mlibMatch", "mlibThm", "mlibTermorder"];
8249261Sdim*)
9249261Sdim
10249261Sdim(*
11249261Sdim*)
12263509Sdimstructure mlibRewrite :> mlibRewrite =
13249261Sdimstruct
14249261Sdim
15263509Sdiminfix ## |-> ::>;
16249261Sdim
17249261Sdimopen mlibUseful mlibTerm mlibThm mlibMatch;
18249261Sdim
19249261Sdimstructure O = Option; local open Option in end;
20249261Sdimstructure M = Intmap; local open Intmap in end;
21249261Sdimstructure S = Intset; local open Intset in end;
22249261Sdimstructure T = mlibTermnet; local open mlibTermnet in end;
23249261Sdim
(* Short local names for the container and substitution types used below. *)
type 'a intmap = 'a M.intmap
and intset = S.intset
and subst = mlibSubst.subst
and 'a termnet = 'a T.termnet;

(* Substitution primitives re-bound so they can be used unqualified
   (::> is declared infix at the top of this structure). *)
val |<>| = mlibSubst.|<>|
and op ::> = mlibSubst.::>
and term_subst = mlibSubst.term_subst
and formula_subst = mlibSubst.formula_subst;
33249261Sdim
34249261Sdim(* ------------------------------------------------------------------------- *)
35249261Sdim(* Chatting.                                                                 *)
36249261Sdim(* ------------------------------------------------------------------------- *)
37249261Sdim
(* Name under which this structure registers its trace channel. *)
val module = "mlibRewrite";

(* Register the trace channel once, when the structure is elaborated. *)
val () = add_trace {module = module, alignment = I};

(* Is tracing for this structure enabled at verbosity [level]? *)
fun chatting level = tracing {module = module, level = level};

(* Emit [msg] on the trace channel.  Always answers true, so callers can
   write [chatting n andalso chat ...]. *)
fun chat msg =
  let val () = trace msg
  in true
  end;
42249261Sdim
43249261Sdim(* ------------------------------------------------------------------------- *)
44249261Sdim(* Helper functions.                                                         *)
45249261Sdim(* ------------------------------------------------------------------------- *)
46249261Sdim
(* Pick an arbitrary element of an integer set, or NONE if there is none. *)
val blind_pick = S.find (fn _ => true);

(* Fetch the (theorem, orientation) entry for index [i] from [known].
   Raises Error if index [i] is no longer present. *)
fun retrieve known i =
  case M.peek (known, i) of
    NONE => raise Error "rewrite: rewr has been rewritten away!"
  | SOME entry => entry;
52249261Sdim
53249261Sdim(* ------------------------------------------------------------------------- *)
54249261Sdim(* Representing ordered rewrites.                                            *)
55249261Sdim(* ------------------------------------------------------------------------- *)
56249261Sdim
(* How a unit equation l = r may be used for rewriting:
     LtoR -- instances of l are rewritten to r;
     RtoL -- instances of r are rewritten to l;
     Both -- either side may be rewritten, but each instance is accepted only
             when the term order says the redex is GREATER than the result. *)
datatype orient = LtoR | RtoL | Both;

(* An ordered rewrite system. *)
datatype rewrs = REWRS of
  {(* term order; orient maps NONE (incomparable) to Both *)
   order    : term * term -> order option,
   (* equations by index, each with its orientation *)
   known    : (thm * orient) intmap,
   (* redex term |-> (equation index, true if used left-to-right) *)
   rewrites : (int * bool) termnet,
   (* subterm |-> (equation index, path of the subterm in its literal) *)
   subterms : (int * int list) termnet,
   (* indices added but not yet inter-reduced *)
   waiting  : intset};
65249261Sdim
(* Copy of [rw] whose waiting set is replaced by [waiting]. *)
fun update_waiting waiting (REWRS {order, known, rewrites, subterms, ...}) =
  REWRS {order = order, known = known, rewrites = rewrites,
         subterms = subterms, waiting = waiting};

(* Copy of [rw] with index [i] removed from its waiting set. *)
fun waiting_del i rw =
  let
    val REWRS {waiting = pending, ...} = rw
  in
    update_waiting (S.delete (pending, i)) rw
  end;
76249261Sdim
77249261Sdim(* ------------------------------------------------------------------------- *)
78249261Sdim(* Basic operations                                                          *)
79249261Sdim(* ------------------------------------------------------------------------- *)
80249261Sdim
(* A rewrite system with no equations, using the term order [order]. *)
fun empty order =
  let
    val known = M.empty ()
    val rewrites = T.empty {fifo = false}
    val subterms = T.empty {fifo = false}
  in
    REWRS {order = order, known = known, rewrites = rewrites,
           subterms = subterms, waiting = S.empty}
  end;

(* Drop every equation but keep the term order. *)
fun reset (REWRS {order = ord, ...}) = empty ord;

(* The (theorem, orientation) entry stored under index [i], if any. *)
fun peek (REWRS {known = k, ...}) i = M.peek (k, i);

(* Number of equations currently held. *)
fun size (REWRS {known = k, ...}) = M.numItems k;

(* All equations, in the order produced by M.listItems. *)
fun eqns (REWRS {known = k, ...}) =
  map (fn (_, (th, _)) => th) (M.listItems k);
93249261Sdim
94249261Sdim(* ------------------------------------------------------------------------- *)
95249261Sdim(* Pretty-printing                                                           *)
96249261Sdim(* ------------------------------------------------------------------------- *)
97249261Sdim
(* Print an orientation as the name of its constructor. *)
val pp_orient =
  pp_map (fn LtoR => "LtoR" | RtoL => "RtoL" | Both => "Both") pp_string;
101249261Sdim
local
  (* Terse form: only the list of equations. *)
  val pp_simple = pp_map eqns (pp_list pp_thm);

  (* Verbose form: indexed equations with their orientations, the waiting
     indices, and the subterm index. *)
  val pp_detailed =
    let
      fun view (REWRS {known, waiting, subterms, ...}) =
        (M.listItems known, S.listItems waiting, subterms)
      val pp_known = pp_list (pp_pair pp_int (pp_pair pp_thm pp_orient))
      val pp_waiting = pp_list pp_int
      val pp_subterms = T.pp_termnet (pp_pair pp_int (pp_list pp_int))
    in
      pp_map view (pp_triple pp_known pp_waiting pp_subterms)
    end;
in
  (* The form is chosen on every call: verbose when trace level 3 is on. *)
  fun pp_rewrs pp =
    if chatting 3 then pp_detailed pp else pp_simple pp;
end;
120249261Sdim
(* Render [rw] as text, wrapped at the current value of LINE_LENGTH. *)
fun rewrs_to_string rw =
  let
    val width = !LINE_LENGTH
  in
    PP.pp_to_string width pp_rewrs rw
  end;

(* Trace [rw] under the heading "mlibRewrite.<label>:"; always answers true. *)
fun chatrewrs label rw =
  chat (String.concat [module, ".", label, ":\n", rewrs_to_string rw, "\n"]);
125249261Sdim
126249261Sdim(* ------------------------------------------------------------------------- *)
127249261Sdim(* Add an equation into the system                                           *)
128249261Sdim(* ------------------------------------------------------------------------- *)
129249261Sdim
(* Turn the order's verdict on (lhs, rhs) into an orientation:
   GREATER gives LtoR, LESS gives RtoL, incomparable (NONE) gives Both,
   and EQUAL gives NONE, meaning the equation is unusable. *)
fun orient cmp =
  case cmp of
    NONE => SOME Both
  | SOME GREATER => SOME LtoR
  | SOME LESS => SOME RtoL
  | SOME EQUAL => NONE;
134249261Sdim
(* Index the redex side(s) of unit equation [th] (index [i]) in [rewrites].
   The left side is stored with flag true (use left-to-right) and the right
   side with flag false.  Both orientations store both sides. *)
fun add_rewrite i (th, ort) rewrites =
  let
    val (lhs, rhs) = dest_unit_eq th
    fun forwards net = T.insert (lhs |-> (i, true)) net
    fun backwards net = T.insert (rhs |-> (i, false)) net
  in
    case ort of
      LtoR => forwards rewrites
    | RtoL => backwards rewrites
    | Both => forwards (backwards rewrites)
  end;
144249261Sdim
(* Add unit equation [th] under index [i], or return [rw] unchanged if [i]
   is already used.  The equation is oriented, indexed in the rewrite net
   and put in the waiting set.  Its subterms are NOT indexed yet.
   Raises Bug if the order deems the two sides EQUAL. *)
fun add (i, th) rw =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw
  in
    if Option.isSome (M.peek (known, i)) then rw
    else
      let
        val ort =
          case orient (order (dest_unit_eq th)) of
            NONE => raise Bug "mlibRewrite.add: can't add reflexive eqns"
          | SOME x => x
        val rw' =
          REWRS {order = order,
                 known = M.insert (known, i, (th, ort)),
                 rewrites = add_rewrite i (th, ort) rewrites,
                 subterms = subterms,
                 waiting = S.add (waiting, i)}
        val _ = chatting 1 andalso chatrewrs "add" rw'
      in
        rw'
      end
  end;
161249261Sdim
162249261Sdim(* ------------------------------------------------------------------------- *)
163249261Sdim(* Rewriting (the order must be a refinement of the initial order)           *)
164249261Sdim(* ------------------------------------------------------------------------- *)
165249261Sdim
(* Could some literal of [th] still be rewritten by an equation in [known]
   other than index [i]?  Atoms are searched below their top symbol.
   LtoR/RtoL equations need only a match.  Both equations also need the
   instance to be decreasing in [order]. *)
fun thm_match known order (i, th) =
  let
    (* Succeeds when [lhs] matches [tm] and the instantiated [rhs] is smaller. *)
    fun ordered_match (lhs, rhs) tm =
      let
        val sub = match lhs tm
      in
        assert (order (tm, term_subst sub rhs) = SOME GREATER) (Error "orw")
      end

    fun applies tm ((lhs, rhs), ort) =
      case ort of
        LtoR => can (match lhs) tm
      | RtoL => can (match rhs) tm
      | Both =>
          can (ordered_match (lhs, rhs)) tm orelse
          can (ordered_match (rhs, lhs)) tm

    (* Every equation except [i] itself, as ((lhs, rhs), orientation). *)
    val others =
      List.mapPartial
        (fn (j, (eqn, ort)) =>
            if j = i then NONE else SOME (dest_unit_eq eqn, ort))
        (M.listItems known)

    fun reducible tm =
      List.exists (applies tm) others orelse reducible_below tm
    and reducible_below (Var _) = false
      | reducible_below (Fn (_, args)) = List.exists reducible args

    fun lit_reducible lit = reducible_below (dest_atom (literal_atom lit))
  in
    List.exists lit_reducible (clause th)
  end;
184249261Sdim
local
  (* Is a net entry with direction flag [lr] (true = left-to-right) usable for
     an equation of orientation [ort]?  LtoR equations may only be used
     left-to-right, RtoL only right-to-left, Both in either direction. *)
  fun agree false LtoR = false | agree true RtoL = false | agree _ _ = true;

  (* (redex, residue) of unit equation [th] when used in direction [lr]. *)
  fun redex_residue lr th = (if lr then I else swap) (dest_unit_eq th);

  (* Net entries matching [tm], sorted with comparator Int.compare (j,i),
     i.e. by decreasing equation index (assuming [sort] sorts ascending). *)
  local val reorder = sort (fn ((i,_),(j,_)) => Int.compare (j,i));
  in fun get_rewrs rw tm = reorder (T.match rw tm);
  end;

  (* Turn the clause's own oriented negated equations ~(l = r) into
     (redex, (ASSUME (l = r), residue, lr)) entries used by neq_conv below.
     Entries oriented Both, or with no orientation, are dropped. *)
  local
    fun compile_neq (SOME LtoR, lit) =
      let val lit' = dest_neg lit val (l,r) = dest_eq lit'
      in SOME (l, (ASSUME lit', r, true))
      end
      | compile_neq (SOME RtoL, lit) =
      let val lit' = dest_neg lit val (l,r) = dest_eq lit'
      in SOME (r, (ASSUME lit', l, false))
      end
      | compile_neq _ = NONE;
  in
    val compile_neqs = List.mapPartial compile_neq;
  end;

  (* Build the function that rewrites a theorem with the equations in
     [known] (looked up through the net [rewrites]).  Equation [i] itself is
     never used. *)
  fun rewr known rewrites order i =
    let
      (* Rewrite one literal: maps (th, lit) to (th', lit') using DEPTH1 with
         [conv].  [conv] tries the known rewrites first and then the compiled
         negated equations [neqs]. *)
      fun rewr_lit neqs =
        let
          (* Rewrite [tm] with equation j in direction lr.  Raises Error if j
             is i, the direction disagrees with j's orientation, the redex
             does not match, or (for Both equations) the instance is not
             decreasing in [order]. *)
          fun f tm (j,lr) =
            let
              val () = assert (j <> i) (Error "rewrite: same theorem")
              val (rw,ort) = retrieve known j
              val () = assert (agree lr ort) (Error "rewrite: bad orientation")
              val (l,r) = redex_residue lr rw
              val sub = match l tm
              val r' = term_subst sub r
              val () = assert
                (ort <> Both orelse order (tm,r') = SOME GREATER)
                (Error "rewrite: order violation")
            in
              (INST sub rw, r', lr)
            end
          (* First candidate from get_rewrs for which [f] succeeds. *)
          fun rewr_conv tm = first (total (f tm)) (get_rewrs rewrites tm)
          (* A compiled negated equation whose redex is exactly [tm]. *)
          fun neq_conv tm = Option.map snd (List.find (equal tm o fst) neqs)
          fun conv tm =
            case rewr_conv tm of SOME x => x
            | NONE => (case neq_conv tm of SOME x => x
                       | NONE => raise Error "rewrite: no matching rewrites")
        in
          DEPTH1 conv
        end

      (* Orientation of the equation underneath a negated equation literal. *)
      fun orient_neq neq = orient (order (dest_eq (negate neq)))

      fun orient_neqs neqs = map (fn neq => (orient_neq neq, neq)) neqs

      (* Inter-reduce the clause's negated equations.  [dealt] holds the
         entries already processed (most recent first).  Each remaining neq
         is rewritten with the known rewrites plus every other neq.  A neq
         that is no longer a literal of [th] is dropped.  If a neq changes
         and becomes strictly oriented, all the others are requeued behind
         it, because it may now rewrite them.  Returns the surviving
         (orientation, neq) list and the rewritten theorem. *)
      fun rewr_neqs dealt [] th = (rev dealt, th)
        | rewr_neqs dealt ((ort,neq) :: neqs) th =
        if not (mem neq (clause th)) then rewr_neqs dealt neqs th else
          let
            val other_neqs = List.revAppend (dealt,neqs)
            val (th,neq') = rewr_lit (compile_neqs other_neqs) (th,neq)
          in
            if neq' = neq then rewr_neqs ((ort,neq) :: dealt) neqs th else
              let
                val ort = orient_neq neq'
                val active = ort = SOME LtoR orelse ort = SOME RtoL
              in
                if active then rewr_neqs [(ort,neq')] other_neqs th
                else rewr_neqs ((ort,neq') :: dealt) neqs th
              end
          end

      (* Full rewrite: first inter-reduce the negated equations, then rewrite
         every other literal with the known rewrites plus those compiled
         negated equations.  The literal pass is skipped when there is
         nothing to rewrite with. *)
      fun rewr' th =
        let
          val lits = clause th
          val (neqs,rest) = List.partition (is_eq o negate) lits
          val (neqs,th) = rewr_neqs [] (orient_neqs neqs) th
          val neqs = compile_neqs neqs
        in
          if M.numItems known = 0 andalso null neqs then th
          else foldl (fst o rewr_lit neqs o swap) th rest
        end
    in
      (* At trace level 2, also trace the step and assert the result is
         normalised.  NOTE(review): thm_match only consults [known]; it does
         not consider the clause's own negated equations. *)
      fn th =>
      if not (chatting 2) then rewr' th else
        let
          val th' = rewr' th
          val m = thm_match known order (i,th')
          val _ = chat ("rewrite:\n" ^ thm_to_string th
                        ^ "\n ->\n" ^ thm_to_string th' ^ "\n")
          val () = assert (not m) (Bug "rewrite: should be normalized")
        in
          th'
        end
    end;
in
  (* Rewrite theorem [th] with the equations of the system under [order].
     [i] is th's own index, which is excluded from the usable rewrites. *)
  fun rewrite (REWRS {known,rewrites,...}) order (i,th) =
    rewr known rewrites order i th;
end;
284249261Sdim
285249261Sdim(* ------------------------------------------------------------------------- *)
286249261Sdim(* Inter-reduce the equations in the system                                  *)
287249261Sdim(* ------------------------------------------------------------------------- *)
288249261Sdim
(* Index every subterm of unit theorem [th]'s literal in [subterms]. Each
   subterm is the key, and (i, path-to-subterm) is the value. *)
fun add_subterms i th subterms =
  let
    fun index ((path |-> tm), net) = T.insert (tm |-> (i, path)) net
  in
    foldl index subterms (literal_subterms (dest_unit th))
  end;
293249261Sdim
(* Does equation [eq'] keep the redex side(s) of [eq] for orientation [ort]?
   LtoR compares the left sides, RtoL compares the right sides, and Both
   compares both. *)
fun same_redex eq ort eq' =
  let
    val (l0, r0) = dest_eq eq
    and (l1, r1) = dest_eq eq'
  in
    case ort of
      Both => l0 = l1 andalso r0 = r1
    | LtoR => l0 = l1
    | RtoL => r0 = r1
  end;
304249261Sdim
(* The (redex, residue, unchecked) rules that equation [eq] provides under
   orientation [ort].  [unchecked] is true when no per-instance order check
   is needed (strictly oriented) and false for Both. *)
fun redex_residues eq ort =
  let
    val (lhs, rhs) = dest_eq eq
  in
    case ort of
      Both => [(lhs, rhs, false), (rhs, lhs, false)]
    | LtoR => [(lhs, rhs, true)]
    | RtoL => [(rhs, lhs, true)]
  end;
314249261Sdim
(* [find_rws order known subterms i todo rules] adds to [todo] the index of
   every equation other than [i] that has an indexed subterm which one of
   [rules] (as produced by redex_residues) can rewrite. *)
fun find_rws order known subterms i =
  let
    (* Can [rule] rewrite the subterm at [path] of equation [j]? *)
    fun rewrites_at (redex, residue, unchecked) (j, path) =
      let
        val t = literal_subterm path (dest_unit (fst (retrieve known j)))
        val sub = match redex t
        val ok =
          unchecked orelse order (t, term_subst sub residue) = SOME GREATER
      in
        assert ok (Error "valid: violates order")
      end

    fun note rule ((j, path), todo) =
      if j = i orelse S.member (todo, j) then todo
      else if can (rewrites_at rule) (j, path) then S.add (todo, j)
      else todo

    fun scan (rule as (redex, _, _), todo) =
      foldl (note rule) todo (T.matched subterms redex)
  in
    foldl scan
  end;
335249261Sdim
(* Inter-reduce equation [i] against the rest of the system.
   [new] is true when i comes from the waiting set (its subterms are not yet
   indexed) and false when it comes from todo.
   Accumulators:
     rpl  - indices whose rewrite-net entries must be rebuilt;
     spl  - indices whose subterm-index entries must be rebuilt;
     todo - indices of equations that the (new) version of i may rewrite.
   If the order deems the rewritten sides EQUAL (orient gives NONE), i is
   removed from [known]. *)
fun reduce1 new i (rpl, spl, todo, rw) =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw
    val (th0, ort0) = M.retrieve (known, i)
    val eq_old = dest_unit th0
    val th = rewrite rw order (i, th0)
    val eq_new = dest_unit th

    val unchanged = eq_new = eq_old
    (* Net entries for i stay valid while its redex side(s) are untouched. *)
    val redex_kept =
      unchanged orelse (ort0 <> Both andalso same_redex eq_old ort0 eq_new)

    val rpl' = if redex_kept then rpl else S.add (rpl, i)
    val spl' = if new orelse unchanged then spl else S.add (spl, i)

    val ort_opt =
      if redex_kept then SOME ort0 else orient (order (dest_eq eq_new))

    fun result (known', rewrites', subterms', todo') =
      (rpl', spl', todo',
       REWRS {order = order, known = known', rewrites = rewrites',
              subterms = subterms', waiting = waiting})
  in
    case ort_opt of
      NONE => result (fst (M.remove (known, i)), rewrites, subterms, todo)
    | SOME ort =>
        let
          val known' =
            if unchanged then known else M.insert (known, i, (th, ort))
          val rewrites' =
            if redex_kept then rewrites else add_rewrite i (th, ort) rewrites
          val todo' =
            if redex_kept andalso not new then todo
            else find_rws order known' subterms i todo (redex_residues eq_new ort)
          val subterms' =
            if unchanged andalso not new then subterms
            else add_subterms i th subterms
        in
          result (known', rewrites', subterms', todo')
        end
  end;
370249261Sdim
(* Fold step: re-index equation [i]'s redexes, if [i] is still known. *)
fun add_rewrs known (i, rewrs) =
  case M.peek (known, i) of
    SOME entry => add_rewrite i entry rewrs
  | NONE => rewrs;

(* Fold step: re-index equation [i]'s subterms, if [i] is still known. *)
fun add_stms known (i, stms) =
  case M.peek (known, i) of
    SOME (th, _) => add_subterms i th stms
  | NONE => stms;
378249261Sdim
(* Purge the stale index entries accumulated during reduction.  Every
   rewrite-net entry of an index in [rpl], and every subterm entry of an
   index in [spl], is removed and then re-added from the current [known]. *)
fun rebuild rpl spl rw =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw

    fun refresh stale readd net =
      let
        val kept =
          if S.isEmpty stale then net
          else T.filter (fn (i, _) => not (S.member (stale, i))) net
      in
        S.foldl (readd known) kept stale
      end
  in
    REWRS {order = order, known = known,
           rewrites = refresh rpl add_rewrs rewrites,
           subterms = refresh spl add_stms subterms,
           waiting = waiting}
  end;
394249261Sdim
(* Choose an index from [s], preferring one whose equation is not oriented
   Both; otherwise any element (NONE for an empty set). *)
fun pick known s =
  let
    fun directed i = snd (retrieve known i) <> Both
  in
    case S.find directed s of
      NONE => blind_pick s
    | found => found
  end;
398249261Sdim
(* Reduction driver.  Drain [todo] first, then the waiting set, reducing one
   equation per step.  Finish by rebuilding the indexes.  Returns the
   rebuilt system and the set of indices whose rewrite-net entries were
   replaced. *)
fun reduce_acc (rpl, spl, todo, rw) =
  let
    val REWRS {known, waiting, ...} = rw
    fun step new i acc = reduce_acc (reduce1 new i acc)
  in
    case pick known todo of
      SOME i => step false i (rpl, spl, S.delete (todo, i), rw)
    | NONE =>
        (case pick known waiting of
           SOME i => step true i (rpl, spl, todo, waiting_del i rw)
         | NONE => (rebuild rpl spl rw, rpl))
  end;
406249261Sdim
(* Fully inter-reduce [rw].  Also returns the indices of equations that were
   waiting or whose redexes changed, restricted to those still present
   after reduction (list built with S.foldr). *)
fun reduce_newr rw =
  let
    val REWRS {waiting = pending, ...} = rw
    val (reduced_rw, changed) = reduce_acc (S.empty, S.empty, S.empty, rw)
    val REWRS {known = survivors, ...} = reduced_rw
    fun alive i = Option.isSome (M.peek (survivors, i))
    val touched = S.union (changed, pending)
    val fresh =
      S.foldr (fn (i, acc) => if alive i then i :: acc else acc) [] touched
  in
    (reduced_rw, fresh)
  end;
418
(* Inter-reduce [rw] (see reduce_newr).  At trace level 2 it also traces the
   system before and after.  It then checks that no equation of the reduced
   system can still be rewritten by the other equations of the reduced
   system, raising Bug otherwise. *)
fun reduce' rw =
  if not (chatting 2) then reduce_newr rw else
    let
      val REWRS {order, ...} = rw
      val res as (rw', _) = reduce_newr rw
      val REWRS {known = known', ...} = rw'
      val eqs = map (fn (i, (th, _)) => (i, th)) (M.listItems known')
      (* The normal-form check must use the reduced equations [known'].  The
         pre-reduction [known] holds stale versions of rewritten equations
         and lacks their new redexes, so checking against it can miss
         equations that are still reducible in the reduced system. *)
      val m = List.exists (thm_match known' order) eqs
      val _ = chatrewrs "reduce before" rw
      val _ = chatrewrs "reduce after" rw'
      val () = assert (not m) (Bug "reduce: not fully reduced")
    in
      res
    end;
433
(* Inter-reduce [rw], discarding the list of new/changed indices. *)
fun reduce rw = fst (reduce' rw);

(* True when no equation is still waiting to be inter-reduced. *)
fun reduced (REWRS {waiting = pending, ...}) = S.isEmpty pending;
437
438(* ------------------------------------------------------------------------- *)
439(* Rewriting as a derived rule                                               *)
440(* ------------------------------------------------------------------------- *)
441
local
  (* Build a system under [ord] from [ths], numbering them 0, 1, 2, ... and
     renaming each with FRESH_VARS. *)
  fun load ord ths =
    let
      fun step (th, (n, acc)) = (n + 1, add (n, FRESH_VARS th) acc)
      val (_, rw) = foldl step (0, empty ord) ths
    in
      rw
    end;
in
  (* Rewrite with the equations [ths] under [ord].  The rewritten theorem
     gets index ~1, so it never coincides with any of the equations. *)
  fun ORD_REWRITE ord ths =
    let
      val rw = load ord ths
    in
      fn th => rewrite rw ord (~1, th)
    end;
end;

(* Rewriting where every equation is used strictly left-to-right. *)
val REWRITE = ORD_REWRITE (fn _ => SOME GREATER);
452
453end
454