1(* ========================================================================= *)
2(* ORDERED REWRITING                                                         *)
3(* Copyright (c) 2003-2004 Joe Hurd.                                         *)
4(* ========================================================================= *)
5
6(*
7app load ["mlibHeap", "mlibTerm", "mlibSubst", "mlibMatch", "mlibThm", "mlibTermorder"];
8*)
9
10(*
11*)
12structure mlibRewrite :> mlibRewrite =
13struct
14
15infix ## |-> ::>;
16
17open mlibUseful mlibTerm mlibThm mlibMatch;
18
(* Short aliases for the library structures used throughout this file.
   NOTE(review): the empty `local open X in end` after each alias binds
   nothing here; presumably it forces the structure to be loaded under the
   interactive system -- TODO confirm. *)
structure O = Option; local open Option in end;
structure M = Intmap; local open Intmap in end;
structure S = Intset; local open Intset in end;
structure T = mlibTermnet; local open mlibTermnet in end;
23
(* Unqualified abbreviations for the container and substitution types. *)
type 'a intmap  = 'a M.intmap;
type intset     = S.intset;
type subst      = mlibSubst.subst;
type 'a termnet = 'a T.termnet;

(* Re-bind the substitution primitives of mlibSubst under unqualified names
   so they can be used directly (and infix, where declared so above). *)
val |<>|          = mlibSubst.|<>|;
val op::>         = mlibSubst.::>;
val term_subst    = mlibSubst.term_subst;
val formula_subst = mlibSubst.formula_subst;
33
34(* ------------------------------------------------------------------------- *)
35(* Chatting.                                                                 *)
36(* ------------------------------------------------------------------------- *)
37
(* Name under which this module registers its trace output. *)
val module = "mlibRewrite";
val () = add_trace {alignment = I, module = module}

(* chatting level: is tracing enabled for this module at the given level? *)
fun chatting level = tracing {level = level, module = module};

(* chat msg: emit msg on the trace and return true, so that it can be used
   on the right of `andalso`. *)
fun chat msg = let val _ = trace msg in true end
42
43(* ------------------------------------------------------------------------- *)
44(* Helper functions.                                                         *)
45(* ------------------------------------------------------------------------- *)
46
(* Pick whichever element S.find reports first for a predicate that accepts
   everything; NONE when the set is empty. *)
fun blind_pick s = S.find (fn _ => true) s;
48
(* Look up equation i in the known map; raises Error if it is absent. *)
fun retrieve known i =
  case M.peek (known,i) of
    NONE => raise Error "rewrite: rewr has been rewritten away!"
  | SOME entry => entry;
52
53(* ------------------------------------------------------------------------- *)
54(* Representing ordered rewrites.                                            *)
55(* ------------------------------------------------------------------------- *)
56
(* Direction(s) in which an equation l = r is used as a rewrite rule:
   LtoR uses l as the redex, RtoL uses r as the redex, and Both registers
   both sides as redexes, with each individual use guarded by an ordering
   check. *)
datatype orient = LtoR | RtoL | Both;
58
(* A system of ordered rewrites.
     order    : term comparison used to orient equations
     known    : equation id -> (theorem, orientation)
     rewrites : net keyed on redex terms; each entry is (equation id, flag)
                where the flag is true when the redex is the left-hand side
     subterms : net keyed on subterms of known equations; each entry is
                (equation id, path to the subterm)
     waiting  : ids that have been added but not yet picked up by reduction *)
datatype rewrs = REWRS of
  {order    : term * term -> order option,
   known    : (thm * orient) intmap,
   rewrites : (int * bool) termnet,
   subterms : (int * int list) termnet,
   waiting  : intset};
65
(* Return a copy of the rewrite system with its waiting set replaced. *)
fun update_waiting waiting (REWRS {order, known, rewrites, subterms, ...}) =
  REWRS {order = order,
         known = known,
         rewrites = rewrites,
         subterms = subterms,
         waiting = waiting};
73
(* Remove id i from the waiting set of the rewrite system. *)
fun waiting_del i rw =
  let
    val REWRS {waiting = pending, ...} = rw
  in
    update_waiting (S.delete (pending,i)) rw
  end;
76
77(* ------------------------------------------------------------------------- *)
78(* Basic operations                                                          *)
79(* ------------------------------------------------------------------------- *)
80
(* The rewrite system over the given term order containing no equations. *)
fun empty order =
  let
    val net_params = {fifo = false}
  in
    REWRS {order    = order,
           known    = M.empty (),
           rewrites = T.empty net_params,
           subterms = T.empty net_params,
           waiting  = S.empty}
  end;
84
(* Discard every equation, keeping only the term order. *)
fun reset rw =
  let val REWRS {order, ...} = rw
  in empty order
  end;
86
(* Look up equation i, returning SOME (theorem, orientation) or NONE. *)
fun peek rw i =
  let val REWRS {known, ...} = rw
  in M.peek (known,i)
  end;
88
(* Number of equations currently stored in the system. *)
fun size rw =
  let val REWRS {known, ...} = rw
  in M.numItems known
  end;
90
(* The theorems of all stored equations, in the order M.listItems yields. *)
fun eqns (REWRS {known, ...}) = map (fst o snd) (M.listItems known);
93
94(* ------------------------------------------------------------------------- *)
95(* Pretty-printing                                                           *)
96(* ------------------------------------------------------------------------- *)
97
(* Pretty-printer for orientations: prints the constructor name. *)
local
  fun orient_name ort =
    case ort of
      LtoR => "LtoR"
    | RtoL => "RtoL"
    | Both => "Both";
in
  val pp_orient = pp_map orient_name pp_string;
end;
101
(* Pretty-printer for rewrite systems.  At trace level 3 it shows the known
   map, the waiting set and the subterm net; otherwise just the equations. *)
local
  (* The internal components shown by the verbose printer. *)
  fun internals (REWRS {known, waiting, subterms, ...}) =
    (M.listItems known, S.listItems waiting, subterms);

  val pp_known_entry = pp_pair pp_int (pp_pair pp_thm pp_orient);

  val pp_subterm_entry = pp_pair pp_int (pp_list pp_int);

  val pp_internals =
    pp_triple
    (pp_list pp_known_entry)
    (pp_list pp_int)
    (T.pp_termnet pp_subterm_entry);

  val verbose = pp_map internals pp_internals;

  val brief = pp_map eqns (pp_list pp_thm);
in
  fun pp_rewrs pp = (if chatting 3 then verbose else brief) pp;
end;
120
(* Render a rewrite system as a string at the current line length. *)
fun rewrs_to_string rw =
  let val width = !LINE_LENGTH
  in PP.pp_to_string width pp_rewrs rw
  end;
122
(* Trace a rewrite system under the label "<module>.<s>"; returns true. *)
fun chatrewrs s rw =
  let
    val header = module ^ "." ^ s ^ ":\n"
    val body = rewrs_to_string rw ^ "\n"
  in
    chat (header ^ body)
  end;
125
126(* ------------------------------------------------------------------------- *)
127(* Add an equation into the system                                           *)
128(* ------------------------------------------------------------------------- *)
129
(* Turn the result of comparing the two sides of an equation into an
   orientation.  Equal sides give NONE (nothing to rewrite); incomparable
   sides give Both. *)
fun orient cmp =
  case cmp of
    NONE => SOME Both
  | SOME LESS => SOME RtoL
  | SOME GREATER => SOME LtoR
  | SOME EQUAL => NONE;
134
(* Register the redex(es) of equation i in the rewrites net.  The boolean
   stored with the id is true when the redex is the left-hand side. *)
fun add_rewrite i (th,ort) rewrites =
  let
    val (lhs,rhs) = dest_unit_eq th
    fun forward net = T.insert (lhs |-> (i,true)) net
    fun backward net = T.insert (rhs |-> (i,false)) net
  in
    case ort of
      LtoR => forward rewrites
    | RtoL => backward rewrites
    | Both => forward (backward rewrites)
  end;
144
(* Add equation th under id i.  An id that is already present leaves the
   system unchanged.  The new equation is oriented, entered in the rewrites
   net and marked as waiting.  Raises Bug when both sides compare equal. *)
fun add (i,th) (rw as REWRS {known, ...}) =
  case M.peek (known,i) of
    SOME _ => rw
  | NONE =>
    let
      val REWRS {order, rewrites, subterms, waiting, ...} = rw
      val ort =
        (case orient (order (dest_unit_eq th)) of
           NONE => raise Bug "mlibRewrite.add: can't add reflexive eqns"
         | SOME direction => direction)
      val rw' =
        REWRS {order = order,
               known = M.insert (known, i, (th,ort)),
               rewrites = add_rewrite i (th,ort) rewrites,
               subterms = subterms,
               waiting = S.add (waiting,i)}
      val _ = chatting 1 andalso chatrewrs "add" rw'
    in
      rw'
    end;
161
162(* ------------------------------------------------------------------------- *)
163(* Rewriting (the order must be a refinement of the initial order)           *)
164(* ------------------------------------------------------------------------- *)
165
(* Does any known equation other than i apply to some term strictly below
   the atom of a literal of th?  Unoriented equations only count when the
   instantiated rewrite is decreasing in the order. *)
fun thm_match known order (i,th) =
  let
    (* Match lhs against tm and insist that the rewrite decreases tm. *)
    fun ordered_match (lhs,rhs) tm =
      let val theta = match lhs tm
      in assert (order (tm, term_subst theta rhs) = SOME GREATER) (Error "orw")
      end

    fun applies ((lhs,rhs),ort) tm =
      case ort of
        LtoR => can (match lhs) tm
      | RtoL => can (match rhs) tm
      | Both =>
        can (ordered_match (lhs,rhs)) tm orelse can (ordered_match (rhs,lhs)) tm

    val others = List.filter (fn (j,_) => j <> i) (M.listItems known)
    val eqs = map (fn (_,(eqth,ort)) => (dest_unit_eq eqth, ort)) others

    fun reducible tm =
      List.exists (fn eq => applies eq tm) eqs orelse below tm
    and below (Var _) = false
      | below (Fn (_,args)) = List.exists reducible args

    fun lit_match lit = below (dest_atom (literal_atom lit))
  in
    List.exists lit_match (clause th)
  end;
184
(* Rewriting a theorem with the equations of a rewrite system.  Everything
   before `in` is private; only `rewrite` is exported from this block. *)
local
  (* agree lr ort: may an equation with orientation ort be applied in the
     direction lr (true = the left side is the redex, false = the right)? *)
  fun agree false LtoR = false | agree true RtoL = false | agree _ _ = true;

  (* The two sides of unit equation th as (redex, residue) for direction lr. *)
  fun redex_residue lr th = (if lr then I else swap) (dest_unit_eq th);

  (* Candidate rewrites for tm: whatever the net returns for tm, sorted on
     equation id using the comparison Int.compare (j,i). *)
  local val reorder = sort (fn ((i,_),(j,_)) => Int.compare (j,i));
  in fun get_rewrs rw tm = reorder (T.match rw tm);
  end;

  (* Turn oriented negated equalities into local rewrites of the form
     (redex, (ASSUME of the equality, residue, direction flag)).
     Unoriented or unorientable ones are dropped. *)
  local
    fun compile_neq (SOME LtoR, lit) =
      let val lit' = dest_neg lit val (l,r) = dest_eq lit'
      in SOME (l, (ASSUME lit', r, true))
      end
      | compile_neq (SOME RtoL, lit) =
      let val lit' = dest_neg lit val (l,r) = dest_eq lit'
      in SOME (r, (ASSUME lit', l, false))
      end
      | compile_neq _ = NONE;
  in
    val compile_neqs = List.mapPartial compile_neq;
  end;

  (* rewr known rewrites order i: a function rewriting a theorem with the
     known equations (never with equation i itself) and with the oriented
     negated equalities of the theorem being rewritten. *)
  fun rewr known rewrites order i =
    let
      (* Build the literal rewriter for a given list of compiled neqs. *)
      fun rewr_lit neqs =
        let
          (* Try to apply candidate (j,lr) at tm; any failed check raises
             Error, which rewr_conv turns into "try the next candidate". *)
          fun f tm (j,lr) =
            let
              val () = assert (j <> i) (Error "rewrite: same theorem")
              val (rw,ort) = retrieve known j
              val () = assert (agree lr ort) (Error "rewrite: bad orientation")
              val (l,r) = redex_residue lr rw
              val sub = match l tm
              val r' = term_subst sub r
              (* Unoriented equations must strictly decrease the term. *)
              val () = assert
                (ort <> Both orelse order (tm,r') = SOME GREATER)
                (Error "rewrite: order violation")
            in
              (INST sub rw, r', lr)
            end
          fun rewr_conv tm = first (total (f tm)) (get_rewrs rewrites tm)
          (* Local rewrites apply only when the redex is exactly tm. *)
          fun neq_conv tm = Option.map snd (List.find (equal tm o fst) neqs)
          (* Known equations take precedence over the local neq rewrites. *)
          fun conv tm =
            case rewr_conv tm of SOME x => x
            | NONE => (case neq_conv tm of SOME x => x
                       | NONE => raise Error "rewrite: no matching rewrites")
        in
          DEPTH1 conv
        end

      (* Orientation of the equality underneath a negated-equality literal. *)
      fun orient_neq neq = orient (order (dest_eq (negate neq)))

      fun orient_neqs neqs = map (fn neq => (orient_neq neq, neq)) neqs

      (* Rewrite each negated equality using the others.  `dealt` holds the
         ones already processed, in reverse order.  Returns the final list
         of (orientation, literal) pairs together with the new theorem. *)
      fun rewr_neqs dealt [] th = (rev dealt, th)
        | rewr_neqs dealt ((ort,neq) :: neqs) th =
        (* Skip literals that are no longer part of the clause. *)
        if not (mem neq (clause th)) then rewr_neqs dealt neqs th else
          let
            val other_neqs = List.revAppend (dealt,neqs)
            val (th,neq') = rewr_lit (compile_neqs other_neqs) (th,neq)
          in
            if neq' = neq then rewr_neqs ((ort,neq) :: dealt) neqs th else
              let
                val ort = orient_neq neq'
                val active = ort = SOME LtoR orelse ort = SOME RtoL
              in
                (* A changed literal that is still oriented may rewrite the
                   others, so all of them are put back on the pending list. *)
                if active then rewr_neqs [(ort,neq')] other_neqs th
                else rewr_neqs ((ort,neq') :: dealt) neqs th
              end
          end

      (* Normalize the negated equalities first, then use them together
         with the known equations on the remaining literals. *)
      fun rewr' th =
        let
          val lits = clause th
          val (neqs,rest) = List.partition (is_eq o negate) lits
          val (neqs,th) = rewr_neqs [] (orient_neqs neqs) th
          val neqs = compile_neqs neqs
        in
          (* Nothing can rewrite if there are no equations of either kind. *)
          if M.numItems known = 0 andalso null neqs then th
          else foldl (fst o rewr_lit neqs o swap) th rest
        end
    in
      fn th =>
      if not (chatting 2) then rewr' th else
        (* Trace level 2: print the step and check that the result cannot
           be matched by any known equation. *)
        let
          val th' = rewr' th
          val m = thm_match known order (i,th')
          val _ = chat ("rewrite:\n" ^ thm_to_string th
                        ^ "\n ->\n" ^ thm_to_string th' ^ "\n")
          val () = assert (not m) (Bug "rewrite: should be normalized")
        in
          th'
        end
    end;
in
  (* rewrite rw order (i,th): rewrite th using the equations of rw, except
     equation i, checking unoriented rewrites against the supplied order. *)
  fun rewrite (REWRS {known,rewrites,...}) order (i,th) =
    rewr known rewrites order i th;
end;
284
285(* ------------------------------------------------------------------------- *)
286(* Inter-reduce the equations in the system                                  *)
287(* ------------------------------------------------------------------------- *)
288
(* Enter every subterm of unit theorem th into the subterm net, tagged with
   the equation id i and the path of the subterm. *)
fun add_subterms i th subterms =
  let
    fun insert_one ((path |-> tm), net) = T.insert (tm |-> (i,path)) net
  in
    foldl insert_one subterms (literal_subterms (dest_unit th))
  end;
293
(* Do equations eq and eq' have the same redex side(s) under orientation
   ort?  For Both, the two sides must each coincide. *)
fun same_redex eq ort eq' =
  let
    val (l,r) = dest_eq eq
    val (l',r') = dest_eq eq'
    fun lhs_same () = l = l'
    fun rhs_same () = r = r'
  in
    case ort of
      Both => lhs_same () andalso rhs_same ()
    | LtoR => lhs_same ()
    | RtoL => rhs_same ()
  end;
304
(* The (redex, residue, oriented) triples by which equation eq may rewrite
   under orientation ort.  The flag is false for unoriented equations. *)
fun redex_residues eq ort =
  let
    val (lhs,rhs) = dest_eq eq
    fun rule oriented (redex,residue) = (redex,residue,oriented)
  in
    case ort of
      Both => map (rule false) [(lhs,rhs),(rhs,lhs)]
    | LtoR => [rule true (lhs,rhs)]
    | RtoL => [rule true (rhs,lhs)]
  end;
314
(* find_rws order known subterms i todo rules: extend todo with the ids of
   the equations (other than i) that have a subterm to which one of the
   given (redex, residue, oriented) rules applies. *)
fun find_rws order known subterms i =
  let
    (* Succeeds iff the rule really applies at path of equation j; an
       unoriented rule must also decrease the subterm.  Raises otherwise. *)
    fun valid_rw (redex,residue,oriented) (j,path) =
      let
        val target = literal_subterm path (dest_unit (fst (retrieve known j)))
        val theta = match redex target
        val decreasing =
          oriented orelse order (target, term_subst theta residue) = SOME GREATER
      in
        assert decreasing (Error "valid: violates order")
      end

    fun check_subtm rule (hit as (j,_), todo) =
      let
        val wanted =
          i <> j andalso not (S.member (todo,j)) andalso can (valid_rw rule) hit
      in
        if wanted then S.add (todo,j) else todo
      end

    fun find (rule as (redex,_,_), todo) =
      foldl (check_subtm rule) todo (T.matched subterms redex)
  in
    fn todo => fn rules => foldl find todo rules
  end;
335
(* reduce1 new i (rpl,spl,todo,rw): rewrite equation i with the rest of the
   system and update the bookkeeping.
     new  : true when i was taken from the waiting set, false when it was
            taken from todo (see reduce_acc)
     rpl  : ids whose entries in the rewrites net must be rebuilt
     spl  : ids whose entries in the subterm net must be rebuilt
     todo : ids of equations that still have to be re-examined
   Returns the updated (rpl, spl, todo, rw). *)
fun reduce1 new i (rpl,spl,todo,rw) =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw
    val (th0,ort0) = M.retrieve (known,i)
    val eq0 = dest_unit th0
    val th = rewrite rw order (i,th0)
    val eq = dest_unit th
    (* identical: rewriting did not change the equation at all. *)
    val identical = eq = eq0
    (* same_red: the redex side of an oriented equation is unchanged, so its
       entry in the rewrites net is still valid. *)
    val same_red = identical orelse (ort0<>Both andalso same_redex eq0 ort0 eq)
    val rpl = if same_red then rpl else S.add (rpl,i)
    (* A new equation has no subterm entries yet, so none can be stale. *)
    val spl = if new orelse identical then spl else S.add (spl,i)
  in
    (* Re-orient only when the redex changed. *)
    case (if same_red then SOME ort0 else orient (order (dest_eq eq))) of
      (* Both sides became equal: drop the equation from the known map. *)
      NONE =>
      (rpl, spl, todo,
       REWRS {order = order, known = fst (M.remove (known,i)),
              rewrites = rewrites, subterms = subterms, waiting = waiting})
    | SOME ort =>
      let
        val known = if identical then known else M.insert (known,i,(th,ort))
        val rewrites =
          if same_red then rewrites else add_rewrite i (th,ort) rewrites
        (* A new or changed redex may rewrite other equations: queue them. *)
        val todo =
          if same_red andalso not new then todo
          else find_rws order known subterms i todo (redex_residues eq ort)
        val subterms =
          if identical andalso not new then subterms
          else add_subterms i th subterms
      in
        (rpl, spl, todo,
         REWRS {order = order, known = known, rewrites = rewrites,
                subterms = subterms, waiting = waiting})
      end
  end;
370
(* Fold step: re-enter equation i in the rewrites net if it still exists. *)
fun add_rewrs known (i,rewrs) =
  (case M.peek (known,i) of
     SOME entry => add_rewrite i entry rewrs
   | NONE => rewrs);
374
(* Fold step: re-enter the subterms of equation i if it still exists. *)
fun add_stms known (i,stms) =
  (case M.peek (known,i) of
     SOME (th,_) => add_subterms i th stms
   | NONE => stms);
378
(* Rebuild the two nets: discard every entry belonging to an id in rpl
   (rewrites) or spl (subterms), then re-enter the surviving equations
   among those ids from the known map. *)
fun rebuild rpl spl rw =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw

    (* Remove the net entries whose id is in ids; no-op for an empty set. *)
    fun drop_ids ids net =
      if S.isEmpty ids then net
      else T.filter (fn (i,_) => not (S.member (ids,i))) net

    val rewrites' = S.foldl (add_rewrs known) (drop_ids rpl rewrites) rpl
    val subterms' = S.foldl (add_stms known) (drop_ids spl subterms) spl
  in
    REWRS {order = order, known = known, rewrites = rewrites',
           subterms = subterms', waiting = waiting}
  end;
394
(* Choose an id from set s, preferring an oriented equation; falls back to
   any element, and gives NONE only when s is empty. *)
fun pick known s =
  let
    fun oriented i = snd (retrieve known i) <> Both
  in
    case S.find oriented s of
      NONE => blind_pick s
    | found => found
  end;
398
(* Main reduction loop: process todo ids first, then waiting ids, until both
   sets are exhausted; finally rebuild the nets.  Returns the rebuilt
   system together with the accumulated rpl set. *)
fun reduce_acc (rpl, spl, todo, rw as REWRS {known, waiting, ...}) =
  let
    fun from_todo i =
      reduce_acc (reduce1 false i (rpl, spl, S.delete (todo,i), rw))
    fun from_waiting i =
      reduce_acc (reduce1 true i (rpl, spl, todo, waiting_del i rw))
  in
    case pick known todo of
      SOME i => from_todo i
    | NONE =>
      (case pick known waiting of
         SOME i => from_waiting i
       | NONE => (rebuild rpl spl rw, rpl))
  end;
406
(* Inter-reduce the system.  Returns the reduced system and the list of ids
   that were waiting beforehand or had their redex changed, restricted to
   those still present afterwards. *)
fun reduce_newr rw =
  let
    val REWRS {waiting = initially_waiting, ...} = rw
    val (rw',changed) = reduce_acc (S.empty, S.empty, S.empty, rw)
    val REWRS {known, ...} = rw'
    fun survives i = Option.isSome (M.peek (known,i))
    fun keep (i,acc) = if survives i then i :: acc else acc
    val candidates = S.union (changed,initially_waiting)
  in
    (rw', S.foldr keep [] candidates)
  end;
418
(* reduce' rw: inter-reduce the system, returning the reduced system and the
   list of new/changed equation ids (see reduce_newr).

   At trace level 2 and above it also prints the system before and after,
   and raises Bug if some equation of the reduced system can still be
   matched by another equation of the reduced system.  The check must use
   the reduced map known': checking against the original map would compare
   with stale or deleted equations and could report a spurious failure. *)
fun reduce' rw =
  if not (chatting 2) then reduce_newr rw else
    let
      val REWRS {order, ...} = rw
      val res as (rw',_) = reduce_newr rw
      val REWRS {known = known', ...} = rw'
      val eqs = map (fn (i,(th,_)) => (i,th)) (M.listItems known')
      val m = List.exists (thm_match known' order) eqs
      val _ = chatrewrs "reduce before" rw
      val _ = chatrewrs "reduce after" rw'
      val () = assert (not m) (Bug "reduce: not fully reduced")
    in
      res
    end;
433
(* Inter-reduce the system, discarding the list of changed ids. *)
fun reduce rw = fst (reduce' rw);
435
(* A system is reduced when no equation is waiting to be processed. *)
fun reduced rw =
  let val REWRS {waiting, ...} = rw
  in S.isEmpty waiting
  end;
437
438(* ------------------------------------------------------------------------- *)
439(* Rewriting as a derived rule                                               *)
440(* ------------------------------------------------------------------------- *)
441
(* ORD_REWRITE ord ths: a rule that rewrites a theorem with the equations
   ths under the order ord.  The equations get ids 0,1,2,... after variable
   renaming; the theorem being rewritten uses id ~1, so none is excluded. *)
local
  fun number_and_add (th,(n,rw)) = (n + 1, add (n, FRESH_VARS th) rw);
in
  fun ORD_REWRITE ord ths =
    let
      val (_,rw) = foldl number_and_add (0, empty ord) ths
    in
      fn th => rewrite rw ord (pair ~1 th)
    end;
end;
450
(* Rewriting with an order that declares every pair of terms GREATER, so
   every equation is used left to right. *)
val REWRITE = ORD_REWRITE (fn _ => SOME GREATER);
452
453end
454