(* ========================================================================= *)
(* ORDERED REWRITING                                                         *)
(* Copyright (c) 2003-2004 Joe Hurd.                                         *)
(* ========================================================================= *)

(*
app load ["mlibHeap", "mlibTerm", "mlibSubst", "mlibMatch", "mlibThm", "mlibTermorder"];
*)

(*
*)
structure mlibRewrite :> mlibRewrite =
struct

infix ## |-> ::>;

open mlibUseful mlibTerm mlibThm mlibMatch;

(* Short aliases for the library structures used below. *)
structure O = Option; local open Option in end;
structure M = Intmap; local open Intmap in end;
structure S = Intset; local open Intset in end;
structure T = mlibTermnet; local open mlibTermnet in end;

type 'a intmap = 'a M.intmap;
type intset = S.intset;
type subst = mlibSubst.subst;
type 'a termnet = 'a T.termnet;

val |<>| = mlibSubst.|<>|;
val op::> = mlibSubst.::>;
val term_subst = mlibSubst.term_subst;
val formula_subst = mlibSubst.formula_subst;

(* ------------------------------------------------------------------------- *)
(* Chatting.                                                                 *)
(* ------------------------------------------------------------------------- *)

val module = "mlibRewrite";
val () = add_trace {module = module, alignment = I}

(* chatting l: query the tracing facility for this module at level l. *)
fun chatting l = tracing {module = module, level = l};

(* chat s: emit the trace string s; always returns true so that it can be *)
(* used on the right of `andalso`.                                        *)
fun chat s = (trace s; true)

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Pick some element of an integer set using the always-true predicate *)
(* (K is presumably the constant combinator from mlibUseful).          *)
val blind_pick = S.find (K true);

(* retrieve known i: the (theorem, orientation) pair stored under index i. *)
(* Raises Error if index i is not present in `known`.                      *)
fun retrieve known i =
  (case M.peek (known,i) of SOME rw_ort => rw_ort
   | NONE => raise Error "rewrite: rewr has been rewritten away!");

(* ------------------------------------------------------------------------- *)
(* Representing ordered rewrites.                                            *)
(* ------------------------------------------------------------------------- *)

(* Which side(s) of an equation l = r may be used as a redex:             *)
(*   LtoR - only l is indexed as a redex                                  *)
(*   RtoL - only r is indexed as a redex                                  *)
(*   Both - both sides are indexed; each use is checked against the order *)
datatype orient = LtoR | RtoL | Both;

(* A rewrite system:                                                        *)
(*   order    - compares two terms (NONE when the order does not relate     *)
(*              them)                                                       *)
(*   known    - equation index |-> (theorem, orientation)                   *)
(*   rewrites - term net from redexes to (equation index, flag); the flag   *)
(*              is true when the redex is the left-hand side (add_rewrite)  *)
(*   subterms - term net from subterms of known equations to                *)
(*              (equation index, position of the subterm) (add_subterms)    *)
(*   waiting  - indices added by `add` but not yet processed by reduction   *)
datatype rewrs = REWRS of
  {order    : term * term -> order option,
   known    : (thm * orient) intmap,
   rewrites : (int * bool) termnet,
   subterms : (int * int list) termnet,
   waiting  : intset};

(* Replace the waiting set of rw, keeping every other field. *)
fun update_waiting waiting rw =
  let
    val REWRS {order, known, rewrites, subterms, waiting = _} = rw
  in
    REWRS {order = order, known = known, rewrites = rewrites,
           subterms = subterms, waiting = waiting}
  end;

(* Remove index i from the waiting set of rw. *)
fun waiting_del i (rw as REWRS {waiting, ...}) =
  update_waiting (S.delete (waiting,i)) rw;

(* ------------------------------------------------------------------------- *)
(* Basic operations                                                          *)
(* ------------------------------------------------------------------------- *)

(* An empty rewrite system using the given term order. *)
fun empty order =
  REWRS {order = order, known = M.empty (), rewrites = T.empty {fifo = false},
         subterms = T.empty {fifo = false}, waiting = S.empty};

(* Discard all equations, keeping only the term order. *)
fun reset (REWRS {order, ...}) = empty order;

(* Look up the (theorem, orientation) stored under index i, if any. *)
fun peek (REWRS {known, ...}) i = M.peek (known,i);

(* Number of equations currently known. *)
fun size (REWRS {known, ...}) = M.numItems known;

(* The theorems of all known equations (orientations dropped). *)
fun eqns (REWRS {known, ...}) =
  map (fn (i,(th,_)) => th) (M.listItems known);

(* ------------------------------------------------------------------------- *)
(* Pretty-printing                                                           *)
(* ------------------------------------------------------------------------- *)

local fun f LtoR = "LtoR" | f RtoL = "RtoL" | f Both = "Both";
in val pp_orient = pp_map f pp_string;
end;

local
  (* Simple view: just the list of equations. *)
  val simple = pp_map eqns (pp_list pp_thm);

  (* Detailed view: known equations, waiting indices and the subterm net. *)
  fun kws (REWRS {known, waiting, subterms, ...}) =
    (M.listItems known,
     S.listItems waiting,
     subterms);

  val pp_kws =
    pp_triple
    (pp_list (pp_pair pp_int (pp_pair pp_thm pp_orient)))
    (pp_list pp_int)
    (T.pp_termnet (pp_pair pp_int (pp_list pp_int)));

  val complicated = pp_map kws pp_kws;
in
  (* The detailed view is used when tracing at level 3 is enabled. *)
  fun pp_rewrs pp = (if chatting 3 then complicated else simple) pp;
end;

fun rewrs_to_string rw = PP.pp_to_string (!LINE_LENGTH) pp_rewrs rw;

(* Trace the rewrite system rw, labelled with the operation name s. *)
fun chatrewrs s rw =
  chat (module ^ "." ^ s ^ ":\n" ^ rewrs_to_string rw ^ "\n");

(* ------------------------------------------------------------------------- *)
(* Add an equation into the system                                           *)
(* ------------------------------------------------------------------------- *)

(* Turn the result of comparing the two sides (l,r) of an equation into an *)
(* orientation. Equal sides give NONE (the equation cannot be oriented);   *)
(* sides the order does not relate give Both.                              *)
fun orient (SOME EQUAL) = NONE
  | orient (SOME GREATER) = SOME LtoR
  | orient (SOME LESS) = SOME RtoL
  | orient NONE = SOME Both;

(* Index equation i in the `rewrites` net under the redex side(s) allowed *)
(* by its orientation; the stored flag is true for the left-hand side.    *)
fun add_rewrite i (th,ort) rewrites =
  let
    val (l,r) = dest_unit_eq th
  in
    case ort of
      LtoR => T.insert (l |-> (i,true)) rewrites
    | RtoL => T.insert (r |-> (i,false)) rewrites
    | Both => T.insert (l |-> (i,true)) (T.insert (r |-> (i,false)) rewrites)
  end;

(* add (i,th) rw: add the unit equation th under index i.                  *)
(* If index i is already known, rw is returned unchanged. Otherwise the    *)
(* equation is oriented, stored, indexed as a rewrite and marked waiting.  *)
(* Raises Bug when the order reports both sides of th as equal.            *)
fun add (i,th) (rw as REWRS {known, ...}) =
  if Option.isSome (M.peek (known,i)) then rw else
    let
      val REWRS {order, rewrites, subterms, waiting, ...} = rw
      val ort =
        case orient (order (dest_unit_eq th)) of SOME x => x
        | NONE => raise Bug "mlibRewrite.add: can't add reflexive eqns"
      val known = M.insert (known, i, (th,ort))
      val rewrites = add_rewrite i (th,ort) rewrites
      val waiting = S.add (waiting,i)
      val rw = REWRS {order = order, known = known, rewrites = rewrites,
                      subterms = subterms, waiting = waiting}
      val _ = chatting 1 andalso chatrewrs "add" rw
    in
      rw
    end;

(* ------------------------------------------------------------------------- *)
(* Rewriting (the order must be a refinement of the initial order)           *)
(* ------------------------------------------------------------------------- *)

(* thm_match known order (i,th): true when some proper subterm of an atom  *)
(* of th can be rewritten by an equation of `known` other than equation i. *)
(* Only used by the level-2 tracing sanity checks below.                   *)
fun thm_match known order (i,th) =
  let
    (* Ordered match: succeeds only if the rewritten term is smaller. *)
    fun orw (l,r) tm =
      let val sub = match l tm
      in assert (order (tm, term_subst sub r) = SOME GREATER) (Error "orw")
      end
    fun rw ((l,_),LtoR) tm = can (match l) tm
      | rw ((_,r),RtoL) tm = can (match r) tm
      | rw ((l,r),Both) tm = can (orw (l,r)) tm orelse can (orw (r,l)) tm
    fun f (_,(th,ort)) = (dest_unit_eq th, ort)
    (* All known equations except the one with index i. *)
    val eqs = (map f o List.filter (not o equal i o fst) o M.listItems) known
    fun can_rw tm = List.exists (fn eq => rw eq tm) eqs orelse can_depth tm
    and can_depth (Var _) = false
      | can_depth (Fn (_,tms)) = List.exists can_rw tms
    val lit_match = can_depth o dest_atom o literal_atom
  in
    List.exists lit_match (clause th)
  end;

local
  (* Does the direction flag stored in the net agree with the orientation? *)
  fun agree false LtoR = false | agree true RtoL = false | agree _ _ = true;

  (* (redex, residue) of th for the given direction flag. *)
  fun redex_residue lr th = (if lr then I else swap) (dest_unit_eq th);

  (* Candidate rewrites for tm, sorted by descending equation index. *)
  local val reorder = sort (fn ((i,_),(j,_)) => Int.compare (j,i));
  in fun get_rewrs rw tm = reorder (T.match rw tm);
  end;

  (* Turn each oriented negated equation ~(l = r) of the clause into a     *)
  (* local rewrite (redex, (ASSUME (l = r), residue, direction flag)).     *)
  (* Unoriented (Both) and unorientable (NONE) literals are dropped.       *)
  local
    fun compile_neq (SOME LtoR, lit) =
      let val lit' = dest_neg lit val (l,r) = dest_eq lit'
      in SOME (l, (ASSUME lit', r, true))
      end
      | compile_neq (SOME RtoL, lit) =
      let val lit' = dest_neg lit val (l,r) = dest_eq lit'
      in SOME (r, (ASSUME lit', l, false))
      end
      | compile_neq _ = NONE;
  in
    val compile_neqs = List.mapPartial compile_neq;
  end;

  (* rewr known rewrites order i: a function rewriting a theorem (treated  *)
  (* as having index i) with the known equations and with the oriented     *)
  (* negated equations of the theorem itself.                              *)
  fun rewr known rewrites order i =
    let
      (* Rewrite one literal, given the compiled negated equations neqs. *)
      fun rewr_lit neqs =
        let
          (* Try equation j in direction lr on tm; raises Error on failure. *)
          fun f tm (j,lr) =
            let
              val () = assert (j <> i) (Error "rewrite: same theorem")
              val (rw,ort) = retrieve known j
              val () = assert (agree lr ort) (Error "rewrite: bad orientation")
              val (l,r) = redex_residue lr rw
              val sub = match l tm
              val r' = term_subst sub r
              (* Unoriented equations may only be used to decrease tm. *)
              val () = assert
                (ort <> Both orelse order (tm,r') = SOME GREATER)
                (Error "rewrite: order violation")
            in
              (INST sub rw, r', lr)
            end
          fun rewr_conv tm = first (total (f tm)) (get_rewrs rewrites tm)
          fun neq_conv tm = Option.map snd (List.find (equal tm o fst) neqs)
          (* Known equations take priority over the clause's own neqs. *)
          fun conv tm =
            case rewr_conv tm of SOME x => x
            | NONE => (case neq_conv tm of SOME x => x
                       | NONE => raise Error "rewrite: no matching rewrites")
        in
          DEPTH1 conv
        end

      fun orient_neq neq = orient (order (dest_eq (negate neq)))

      fun orient_neqs neqs = map (fn neq => (orient_neq neq, neq)) neqs

      (* Rewrite each negated equation using the other negated equations.  *)
      (* `dealt` holds (in reverse) the literals already processed. When a *)
      (* literal changes and its new form is oriented, all other literals  *)
      (* are reprocessed, since the new form may now rewrite them.         *)
      fun rewr_neqs dealt [] th = (rev dealt, th)
        | rewr_neqs dealt ((ort,neq) :: neqs) th =
        if not (mem neq (clause th)) then rewr_neqs dealt neqs th else
          let
            val other_neqs = List.revAppend (dealt,neqs)
            val (th,neq') = rewr_lit (compile_neqs other_neqs) (th,neq)
          in
            if neq' = neq then rewr_neqs ((ort,neq) :: dealt) neqs th else
              let
                val ort = orient_neq neq'
                val active = ort = SOME LtoR orelse ort = SOME RtoL
              in
                if active then rewr_neqs [(ort,neq')] other_neqs th
                else rewr_neqs ((ort,neq') :: dealt) neqs th
              end
          end

      (* First normalize the negated equations, then rewrite the others. *)
      fun rewr' th =
        let
          val lits = clause th
          val (neqs,rest) = List.partition (is_eq o negate) lits
          val (neqs,th) = rewr_neqs [] (orient_neqs neqs) th
          val neqs = compile_neqs neqs
        in
          (* Nothing to rewrite with: skip the traversal entirely. *)
          if M.numItems known = 0 andalso null neqs then th
          else foldl (fst o rewr_lit neqs o swap) th rest
        end
    in
      fn th =>
      if not (chatting 2) then rewr' th else
        let
          (* With level-2 tracing, also check the result is normalized. *)
          val th' = rewr' th
          val m = thm_match known order (i,th')
          val _ = chat ("rewrite:\n" ^ thm_to_string th
                        ^ "\n ->\n" ^ thm_to_string th' ^ "\n")
          val () = assert (not m) (Bug "rewrite: should be normalized")
        in
          th'
        end
    end;
in
  (* rewrite rw order (i,th): rewrite theorem th (index i) with the known *)
  (* equations of rw, using `order` for the ordering side conditions.     *)
  fun rewrite (REWRS {known,rewrites,...}) order (i,th) =
    rewr known rewrites order i th;
end;

(* ------------------------------------------------------------------------- *)
(* Inter-reduce the equations in the system                                  *)
(* ------------------------------------------------------------------------- *)

(* Index every subterm reported by literal_subterms for unit theorem th in *)
(* the `subterms` net, tagged with the equation index i and its position.  *)
fun add_subterms i =
  let fun f ((p |-> tm), subterms) = T.insert (tm |-> (i,p)) subterms
  in fn th => fn subterms => foldl f subterms (literal_subterms (dest_unit th))
  end;

(* Do equations eq and eq' have the same redex side(s) under orientation *)
(* ort?                                                                  *)
fun same_redex eq ort eq' =
  let
    val (l,r) = dest_eq eq
    val (l',r') = dest_eq eq'
  in
    case ort of
      LtoR => l = l'
    | RtoL => r = r'
    | Both => l = l' andalso r = r'
  end;

(* The (redex, residue, oriented) triples of eq under orientation ort; the *)
(* flag is false when each use must still be checked against the order.    *)
fun redex_residues eq ort =
  let
    val (l,r) = dest_eq eq
  in
    case ort of
      LtoR => [(l,r,true)]
    | RtoL => [(r,l,true)]
    | Both => [(l,r,false),(r,l,false)]
  end;

(* find_rws order known subterms i todo rrs: add to `todo` the index of    *)
(* every equation other than i that has an indexed subterm which one of    *)
(* the (redex, residue, oriented) triples in rrs can rewrite.              *)
fun find_rws order known subterms i =
  let
    (* Raises an exception unless (l,r) rewrites the subterm at p of j. *)
    fun valid_rw (l,r,ord) (j,p) =
      let
        val t = literal_subterm p (dest_unit (fst (retrieve known j)))
        val s = match l t
      in
        assert (ord orelse order (t, term_subst s r) = SOME GREATER)
               (Error "valid: violates order")
      end

    fun check_subtm lr (jp as (j,_), todo) =
      if i <> j andalso not (S.member (todo,j)) andalso can (valid_rw lr) jp
      then S.add (todo,j) else todo

    fun find (lr as (l,_,_), todo) =
      foldl (check_subtm lr) todo (T.matched subterms l)
  in
    foldl find
  end;

(* reduce1 new i (rpl,spl,todo,rw): rewrite equation i with the rest of    *)
(* the system and update the state.                                        *)
(*   new  - true when equation i comes from the waiting set                *)
(*   rpl  - indices whose entries in the `rewrites` net are out of date    *)
(*   spl  - indices whose entries in the `subterms` net are out of date    *)
(*   todo - indices of equations that must be reduced again                *)
(* An equation whose two sides become equal is removed from `known`.       *)
fun reduce1 new i (rpl,spl,todo,rw) =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw
    val (th0,ort0) = M.retrieve (known,i)
    val eq0 = dest_unit th0
    val th = rewrite rw order (i,th0)
    val eq = dest_unit th
    val identical = eq = eq0
    (* The redex side is unchanged, so the old orientation still applies. *)
    val same_red = identical orelse (ort0<>Both andalso same_redex eq0 ort0 eq)
    val rpl = if same_red then rpl else S.add (rpl,i)
    val spl = if new orelse identical then spl else S.add (spl,i)
  in
    case (if same_red then SOME ort0 else orient (order (dest_eq eq))) of
      NONE =>
      (rpl, spl, todo,
       REWRS {order = order, known = fst (M.remove (known,i)),
              rewrites = rewrites, subterms = subterms, waiting = waiting})
    | SOME ort =>
      let
        val known = if identical then known else M.insert (known,i,(th,ort))
        val rewrites =
          if same_red then rewrites else add_rewrite i (th,ort) rewrites
        val todo =
          if same_red andalso not new then todo
          else find_rws order known subterms i todo (redex_residues eq ort)
        val subterms =
          if identical andalso not new then subterms
          else add_subterms i th subterms
      in
        (rpl, spl, todo,
         REWRS {order = order, known = known, rewrites = rewrites,
                subterms = subterms, waiting = waiting})
      end
  end;

(* Re-index equation i in the `rewrites` net, if it is still known. *)
fun add_rewrs known (i,rewrs) =
  case M.peek (known,i) of NONE => rewrs
  | SOME th_ort => add_rewrite i th_ort rewrs;

(* Re-index equation i in the `subterms` net, if it is still known. *)
fun add_stms known (i,stms) =
  case M.peek (known,i) of NONE => stms
  | SOME (th,_) => add_subterms i th stms;

(* rebuild rpl spl rw: drop every net entry belonging to an index in rpl   *)
(* (rewrites) or spl (subterms), then re-insert the current versions of    *)
(* those equations that are still known.                                   *)
fun rebuild rpl spl rw =
  let
    val REWRS {order, known, rewrites, subterms, waiting} = rw
    val rewrites =
      if S.isEmpty rpl then rewrites
      else T.filter (fn (i,_) => not (S.member (rpl,i))) rewrites
    val rewrites = S.foldl (add_rewrs known) rewrites rpl
    val subterms =
      if S.isEmpty spl then subterms
      else T.filter (fn (i,_) => not (S.member (spl,i))) subterms
    val subterms = S.foldl (add_stms known) subterms spl
  in
    REWRS {order = order, known = known, rewrites = rewrites,
           subterms = subterms, waiting = waiting}
  end;

(* Choose an index from s, preferring an equation that is oriented (not  *)
(* Both); falls back to any element, and gives NONE on the empty set.    *)
fun pick known s =
  case S.find (fn i => snd (retrieve known i) <> Both) s of SOME x => SOME x
  | NONE => blind_pick s;

(* Main reduction loop: process `todo` first, then `waiting`; when both   *)
(* are exhausted, rebuild the nets and return the system together with    *)
(* the set rpl of indices whose redexes changed.                          *)
fun reduce_acc (rpl, spl, todo, rw as REWRS {known, waiting, ...}) =
  case pick known todo of
    SOME i => reduce_acc (reduce1 false i (rpl, spl, S.delete (todo,i), rw))
  | NONE =>
    case pick known waiting of
      SOME i => reduce_acc (reduce1 true i (rpl, spl, todo, waiting_del i rw))
    | NONE => (rebuild rpl spl rw, rpl);

(* Reduce rw and also return the indices of equations that were waiting  *)
(* or had their redexes changed, restricted to those still known.        *)
fun reduce_newr rw =
  let
    val REWRS {waiting, ...} = rw
    val (rw,changed) = reduce_acc (S.empty, S.empty, S.empty, rw)
    val newr = S.union (changed,waiting)
    val REWRS {known, ...} = rw
    fun filt (i,l) = if Option.isSome (M.peek (known,i)) then i :: l else l
    val newr = S.foldr filt [] newr
  in
    (rw,newr)
  end;

(* reduce' rw: as reduce_newr, but with level-2 tracing enabled it also   *)
(* prints the system before and after and asserts the result is reduced.  *)
fun reduce' rw =
  if not (chatting 2) then reduce_newr rw else
    let
      val REWRS {known, order, ...} = rw
      val res as (rw',_) = reduce_newr rw
      val REWRS {known = known', ...} = rw'
      val eqs = map (fn (i,(th,_)) => (i,th)) (M.listItems known')
      (* NOTE(review): this matches the reduced equations against the     *)
      (* pre-reduction `known`, not `known'` - confirm this is intended.  *)
      val m = List.exists (thm_match known order) eqs
      val _ = chatrewrs "reduce before" rw
      val _ = chatrewrs "reduce after" rw'
      val () = assert (not m) (Bug "reduce: not fully reduced")
    in
      res
    end;

(* Inter-reduce the system, discarding the list of changed indices. *)
val reduce = fst o reduce';

(* A system is reduced when no equations are waiting to be processed. *)
fun reduced (REWRS {waiting, ...}) = Intset.isEmpty waiting;

(* ------------------------------------------------------------------------- *)
(* Rewriting as a derived rule                                               *)
(* ------------------------------------------------------------------------- *)

local
  (* Add theorem th (after FRESH_VARS) under the next free index n. *)
  fun f (th,(n,rw)) = (n + 1, add (n, FRESH_VARS th) rw);
in
  (* ORD_REWRITE ord ths: rewrite a theorem with the equations ths under   *)
  (* the order ord. The theorem being rewritten is given index ~1, which   *)
  (* cannot clash with the indices 0, 1, ... assigned to ths.              *)
  fun ORD_REWRITE ord ths =
    let val (_,rw) = foldl f (0, empty ord) ths
    in rewrite rw ord o pair ~1
    end;
end;

(* Unordered rewriting: every equation is treated as oriented left-to-right. *)
val REWRITE = ORD_REWRITE (K (SOME GREATER));

end