(* ========================================================================= *)
(* ORDERED REWRITING                                                         *)
(* Copyright (c) 2003-2004 Joe Hurd.                                         *)
(* ========================================================================= *)

(*
app load ["mlibHeap", "mlibTerm", "mlibSubst", "mlibMatch", "mlibThm", "mlibTermorder"];
*)

(*
  A rewrite system built from unit equality theorems, each identified by an
  integer id. Equations are oriented with a user-supplied term order; an
  equation whose sides the order cannot compare is kept with orientation
  Both and is only applied to instances that the order strictly decreases.
  reduce inter-reduces the equations that have been added.
*)
structure mlibRewrite :> mlibRewrite =
struct

infix ## |-> ::>;

open mlibUseful mlibTerm mlibThm mlibMatch;

structure O = Option; local open Option in end;
structure M = Intmap; local open Intmap in end;
structure S = Intset; local open Intset in end;
structure T = mlibTermnet; local open mlibTermnet in end;

type 'a intmap = 'a M.intmap;
type intset = S.intset;
type subst = mlibSubst.subst;
type 'a termnet = 'a T.termnet;

val |<>| = mlibSubst.|<>|;
val op::> = mlibSubst.::>;
val term_subst = mlibSubst.term_subst;
val formula_subst = mlibSubst.formula_subst;

(* ------------------------------------------------------------------------- *)
(* Chatting.                                                                 *)
(* ------------------------------------------------------------------------- *)

val module = "mlibRewrite";
val () = add_trace {module = module, alignment = I}
fun chatting l = tracing {module = module, level = l};

(* Emit s on the trace. Always returns true, so that it can be guarded as
   "chatting n andalso chat s". *)
fun chat s = (trace s; true)

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Some element of the set (whichever S.find reaches first), or NONE. *)
val blind_pick = S.find (K true);

(* The (theorem, orientation) pair of equation i.
   Raises Error if i is no longer present in known. *)
fun retrieve known i =
    (case M.peek (known,i) of SOME rw_ort => rw_ort
     | NONE => raise Error "rewrite: rewr has been rewritten away!");

(* ------------------------------------------------------------------------- *)
(* Representing ordered rewrites.                                            *)
(* ------------------------------------------------------------------------- *)

(* LtoR / RtoL: the order strictly decreases in that direction, so the
   equation is always applied that way.
   Both: the order could not compare the two sides, so every application
   (in either direction) is checked against the order first. *)
datatype orient = LtoR | RtoL | Both;

(* order    - term order; SOME GREATER means the first argument is bigger,
              NONE means the two terms are incomparable.
   known    - equation id -> (unit equality theorem, orientation).
   rewrites - redexes indexed for matching: (i,true) is keyed by the left
              side of equation i, (i,false) by its right side.
   subterms - subterms of equations already processed by reduce, stored as
              (id, path) for use with literal_subterm. Entries may be stale
              until the next rebuild; readers re-validate them.
   waiting  - ids added by add that reduce has not yet processed. *)
datatype rewrs = REWRS of
  {order : term * term -> order option,
   known : (thm * orient) intmap,
   rewrites : (int * bool) termnet,
   subterms : (int * int list) termnet,
   waiting : intset};

(* Replace the waiting set, keeping every other field. *)
fun update_waiting waiting rw =
    let
      val REWRS {order, known, rewrites, subterms, waiting = _} = rw
    in
      REWRS {order = order, known = known, rewrites = rewrites,
             subterms = subterms, waiting = waiting}
    end;

(* Remove id i from the waiting set. *)
fun waiting_del i (rw as REWRS {waiting, ...}) =
    update_waiting (S.delete (waiting,i)) rw;

(* ------------------------------------------------------------------------- *)
(* Basic operations                                                          *)
(* ------------------------------------------------------------------------- *)

(* A system with no equations that will orient equations using order. *)
fun empty order =
    REWRS {order = order, known = M.empty (), rewrites = T.empty {fifo = false},
           subterms = T.empty {fifo = false}, waiting = S.empty};

(* Discard all equations but keep the order. *)
fun reset (REWRS {order, ...}) = empty order;

(* The (theorem, orientation) of equation i, if present. *)
fun peek (REWRS {known, ...}) i = M.peek (known,i);

(* Number of equations currently in the system. *)
fun size (REWRS {known, ...}) = M.numItems known;

(* The equation theorems currently in the system, in M.listItems order. *)
fun eqns (REWRS {known, ...}) =
    map (fn (_,(th,_)) => th) (M.listItems known);

(* ------------------------------------------------------------------------- *)
(* Pretty-printing                                                           *)
(* ------------------------------------------------------------------------- *)

local fun f LtoR = "LtoR" | f RtoL = "RtoL" | f Both = "Both";
in val pp_orient = pp_map f pp_string;
end;

(* Normally prints just the list of equations. At chatting level 3 it prints
   the internal state instead: known, waiting and the subterms net. *)
local
  val simple = pp_map eqns (pp_list pp_thm);

  fun kws (REWRS {known, waiting, subterms, ...}) =
      (M.listItems known,
       S.listItems waiting,
       subterms);

  val pp_kws =
      pp_triple
        (pp_list (pp_pair pp_int (pp_pair pp_thm pp_orient)))
        (pp_list pp_int)
        (T.pp_termnet (pp_pair pp_int (pp_list pp_int)));

  val complicated = pp_map kws pp_kws;
in
  fun pp_rewrs pp = (if chatting 3 then complicated else simple) pp;
end;

fun rewrs_to_string rw = PP.pp_to_string (!LINE_LENGTH) pp_rewrs rw;

(* Trace the whole system under the heading "mlibRewrite.<s>". *)
fun chatrewrs s rw =
    chat (module ^ "." ^ s ^ ":\n" ^ rewrs_to_string rw ^ "\n");

(* ------------------------------------------------------------------------- *)
(* Add an equation into the system                                           *)
(* ------------------------------------------------------------------------- *)

(* Orientation for an equation given the order's comparison of its sides.
   NONE when the order calls the sides EQUAL, in which case the equation
   cannot be used as a rewrite. *)
fun orient (SOME EQUAL) = NONE
  | orient (SOME GREATER) = SOME LtoR
  | orient (SOME LESS) = SOME RtoL
  | orient NONE = SOME Both;

(* Index the redex side(s) of equation i in the rewrites net, according to
   its orientation: LtoR keys the left side, RtoL the right, Both both. *)
fun add_rewrite i (th,ort) rewrites =
    let
      val (l,r) = dest_unit_eq th
    in
      case ort of
        LtoR => T.insert (l |-> (i,true)) rewrites
      | RtoL => T.insert (r |-> (i,false)) rewrites
      | Both => T.insert (l |-> (i,true)) (T.insert (r |-> (i,false)) rewrites)
    end;

(* Add the unit equation th under id i.
   Does nothing if id i is already present.
   Raises Bug if the order says the two sides are EQUAL.
   The equation can be used by rewrite straight away, and it is placed in
   waiting so that the next reduce processes it. *)
fun add (i,th) (rw as REWRS {known, ...}) =
    if Option.isSome (M.peek (known,i)) then rw else
      let
        val REWRS {order, rewrites, subterms, waiting, ...} = rw
        val ort =
            case orient (order (dest_unit_eq th)) of SOME x => x
            | NONE => raise Bug "mlibRewrite.add: can't add reflexive eqns"
        val known = M.insert (known, i, (th,ort))
        val rewrites = add_rewrite i (th,ort) rewrites
        val waiting = S.add (waiting,i)
        val rw = REWRS {order = order, known = known, rewrites = rewrites,
                        subterms = subterms, waiting = waiting}
        val _ = chatting 1 andalso chatrewrs "add" rw
      in
        rw
      end;

(* ------------------------------------------------------------------------- *)
(* Rewriting (the order must be a refinement of the initial order)           *)
(* ------------------------------------------------------------------------- *)

(* Debugging check used only when chatting at level 2.
   Returns true iff some literal of th contains, within the arguments of its
   atom, a subterm that an equation of known other than id i could rewrite.
   It uses the same conditions as rewrite: oriented equations apply whenever
   their redex matches, and Both equations only when the instance is
   decreasing under order. *)
fun thm_match known order (i,th) =
    let
      fun orw (l,r) tm =
          let val sub = match l tm
          in assert (order (tm, term_subst sub r) = SOME GREATER) (Error "orw")
          end
      fun rw ((l,_),LtoR) tm = can (match l) tm
        | rw ((_,r),RtoL) tm = can (match r) tm
        | rw ((l,r),Both) tm = can (orw (l,r)) tm orelse can (orw (r,l)) tm
      fun f (_,(th,ort)) = (dest_unit_eq th, ort)
      val eqs = (map f o List.filter (not o equal i o fst) o M.listItems) known
      fun can_rw tm = List.exists (fn eq => rw eq tm) eqs orelse can_depth tm
      and can_depth (Var _) = false
        | can_depth (Fn (_,tms)) = List.exists can_rw tms
      val lit_match = can_depth o dest_atom o literal_atom
    in
      List.exists lit_match (clause th)
    end;

local
  (* A rewrites-net entry with direction lr may be used with orientation ort
     unless they disagree: a right-side entry of an LtoR equation, or a
     left-side entry of an RtoL equation. *)
  fun agree false LtoR = false | agree true RtoL = false | agree _ _ = true;

  (* (redex, residue) of an equation theorem used in direction lr. *)
  fun redex_residue lr th = (if lr then I else swap) (dest_unit_eq th);

  (* Net entries retrieved for tm, sorted on id with the comparison reversed,
     so presumably highest (most recently added) id first. *)
  local val reorder = sort (fn ((i,_),(j,_)) => Int.compare (j,i));
  in fun get_rewrs rw tm = reorder (T.match rw tm);
  end;

  (* Turn oriented negative equalities ~(l = r) of the clause into extra
     rewrites (redex, (ASSUME (l = r), residue, direction)).
     Unorientable or reflexive ones (orientation other than SOME LtoR or
     SOME RtoL) are dropped. *)
  local
    fun compile_neq (SOME LtoR, lit) =
        let val lit' = dest_neg lit val (l,r) = dest_eq lit'
        in SOME (l, (ASSUME lit', r, true))
        end
      | compile_neq (SOME RtoL, lit) =
        let val lit' = dest_neg lit val (l,r) = dest_eq lit'
        in SOME (r, (ASSUME lit', l, false))
        end
      | compile_neq _ = NONE;
  in
    val compile_neqs = List.mapPartial compile_neq;
  end;

  (* Rewrite a theorem with the equations of known, never using id i, and
     with the oriented negative equalities of the theorem's own clause. *)
  fun rewr known rewrites order i =
      let
        (* Rewrite one literal with DEPTH1, trying the system's rewrites
           first and then the compiled negative equalities neqs. *)
        fun rewr_lit neqs =
            let
              fun f tm (j,lr) =
                  let
                    val () = assert (j <> i) (Error "rewrite: same theorem")
                    val (rw,ort) = retrieve known j
                    val () = assert (agree lr ort) (Error "rewrite: bad orientation")
                    val (l,r) = redex_residue lr rw
                    val sub = match l tm
                    val r' = term_subst sub r
                    (* Unoriented equations only apply to decreasing instances. *)
                    val () = assert
                               (ort <> Both orelse order (tm,r') = SOME GREATER)
                               (Error "rewrite: order violation")
                  in
                    (INST sub rw, r', lr)
                  end
              fun rewr_conv tm = first (total (f tm)) (get_rewrs rewrites tm)
              fun neq_conv tm = Option.map snd (List.find (equal tm o fst) neqs)
              fun conv tm =
                  case rewr_conv tm of SOME x => x
                  | NONE => (case neq_conv tm of SOME x => x
                             | NONE => raise Error "rewrite: no matching rewrites")
            in
              DEPTH1 conv
            end

        fun orient_neq neq = orient (order (dest_eq (negate neq)))

        fun orient_neqs neqs = map (fn neq => (orient_neq neq, neq)) neqs

        (* Rewrite each negative equality with all the others. dealt holds
           those already processed. A literal that no longer occurs in the
           clause is skipped. If rewriting yields a newly orientable literal,
           processing restarts with it as the only dealt literal, because
           it may now rewrite the others. *)
        fun rewr_neqs dealt [] th = (rev dealt, th)
          | rewr_neqs dealt ((ort,neq) :: neqs) th =
            if not (mem neq (clause th)) then rewr_neqs dealt neqs th else
              let
                val other_neqs = List.revAppend (dealt,neqs)
                val (th,neq') = rewr_lit (compile_neqs other_neqs) (th,neq)
              in
                if neq' = neq then rewr_neqs ((ort,neq) :: dealt) neqs th else
                  let
                    val ort = orient_neq neq'
                    val active = ort = SOME LtoR orelse ort = SOME RtoL
                  in
                    if active then rewr_neqs [(ort,neq')] other_neqs th
                    else rewr_neqs ((ort,neq') :: dealt) neqs th
                  end
              end

        (* First rewrite the negative equalities among themselves, then
           rewrite every other literal with the system plus those. *)
        fun rewr' th =
            let
              val lits = clause th
              val (neqs,rest) = List.partition (is_eq o negate) lits
              val (neqs,th) = rewr_neqs [] (orient_neqs neqs) th
              val neqs = compile_neqs neqs
            in
              if M.numItems known = 0 andalso null neqs then th
              else foldl (fst o rewr_lit neqs o swap) th rest
            end
      in
        fn th =>
        if not (chatting 2) then rewr' th else
          let
            val th' = rewr' th
            val m = thm_match known order (i,th')
            val _ = chat ("rewrite:\n" ^ thm_to_string th
                          ^ "\n ->\n" ^ thm_to_string th' ^ "\n")
            val () = assert (not m) (Bug "rewrite: should be normalized")
          in
            th'
          end
      end;
in
  (* Rewrite th, the theorem with id i. Equation i is never used on it, and
     order is used for the Both and negative-equality checks. *)
  fun rewrite (REWRS {known,rewrites,...}) order (i,th) =
      rewr known rewrites order i th;
end;

(* ------------------------------------------------------------------------- *)
(* Inter-reduce the equations in the system                                  *)
(* ------------------------------------------------------------------------- *)

(* Insert every (path |-> subterm) of the unit theorem th into the subterms
   net, tagged with id i. *)
fun add_subterms i =
    let fun f ((p |-> tm), subterms) = T.insert (tm |-> (i,p)) subterms
    in fn th => fn subterms => foldl f subterms (literal_subterms (dest_unit th))
    end;

(* Whether eq' keeps the redex side(s) of eq under orientation ort. *)
fun same_redex eq ort eq' =
    let
      val (l,r) = dest_eq eq
      val (l',r') = dest_eq eq'
    in
      case ort of
        LtoR => l = l'
      | RtoL => r = r'
      | Both => l = l' andalso r = r'
    end;

(* The (redex, residue, oriented) triples of eq under ort. oriented = true
   means no order check is needed when applying the triple. *)
fun redex_residues eq ort =
    let
      val (l,r) = dest_eq eq
    in
      case ort of
        LtoR => [(l,r,true)]
      | RtoL => [(r,l,true)]
      | Both => [(l,r,false),(r,l,false)]
    end;

(* find_rws order known subterms i todo lrs adds to todo every equation
   j <> i that has an indexed subterm which some (l,r,ord) in lrs can
   rewrite. Stale subterm entries are filtered out because valid_rw raises
   on them and is run under can. *)
fun find_rws order known subterms i =
    let
      fun valid_rw (l,r,ord) (j,p) =
          let
            val t = literal_subterm p (dest_unit (fst (retrieve known j)))
            val s = match l t
          in
            assert (ord orelse order (t, term_subst s r) = SOME GREATER)
                   (Error "valid: violates order")
          end

      fun check_subtm lr (jp as (j,_), todo) =
          if i <> j andalso not (S.member (todo,j)) andalso can (valid_rw lr) jp
          then S.add (todo,j) else todo

      fun find (lr as (l,_,_), todo) =
          foldl (check_subtm lr) todo (T.matched subterms l)
    in
      foldl find
    end;

(* Rewrite equation i with the rest of the system and update the indexes.
   new  - i comes from waiting, so its subterms are not indexed yet.
   rpl  - accumulated ids whose rewrites-net entries must be rebuilt.
   spl  - accumulated ids whose subterms-net entries must be rebuilt.
   todo - accumulated ids that must be re-reduced.
   If the rewritten equation is reflexive under the order, it is removed
   from known. Otherwise it is re-oriented, unless its redex is unchanged,
   and any equations it can now rewrite are added to todo. *)
fun reduce1 new i (rpl,spl,todo,rw) =
    let
      val REWRS {order, known, rewrites, subterms, waiting} = rw
      val (th0,ort0) = M.retrieve (known,i)
      val eq0 = dest_unit th0
      val th = rewrite rw order (i,th0)
      val eq = dest_unit th
      val identical = eq = eq0
      val same_red = identical orelse (ort0<>Both andalso same_redex eq0 ort0 eq)
      val rpl = if same_red then rpl else S.add (rpl,i)
      val spl = if new orelse identical then spl else S.add (spl,i)
    in
      case (if same_red then SOME ort0 else orient (order (dest_eq eq))) of
        NONE =>
        (rpl, spl, todo,
         REWRS {order = order, known = fst (M.remove (known,i)),
                rewrites = rewrites, subterms = subterms, waiting = waiting})
      | SOME ort =>
        let
          val known = if identical then known else M.insert (known,i,(th,ort))
          val rewrites =
              if same_red then rewrites else add_rewrite i (th,ort) rewrites
          val todo =
              if same_red andalso not new then todo
              else find_rws order known subterms i todo (redex_residues eq ort)
          val subterms =
              if identical andalso not new then subterms
              else add_subterms i th subterms
        in
          (rpl, spl, todo,
           REWRS {order = order, known = known, rewrites = rewrites,
                  subterms = subterms, waiting = waiting})
        end
    end;

(* Re-index equation i in a net, if it is still in known. *)
fun add_rewrs known (i,rewrs) =
    case M.peek (known,i) of NONE => rewrs
    | SOME th_ort => add_rewrite i th_ort rewrs;

fun add_stms known (i,stms) =
    case M.peek (known,i) of NONE => stms
    | SOME (th,_) => add_subterms i th stms;

(* Drop all net entries for ids in rpl (rewrites) and spl (subterms), then
   re-insert entries for those ids that are still in known. *)
fun rebuild rpl spl rw =
    let
      val REWRS {order, known, rewrites, subterms, waiting} = rw
      val rewrites =
          if S.isEmpty rpl then rewrites
          else T.filter (fn (i,_) => not (S.member (rpl,i))) rewrites
      val rewrites = S.foldl (add_rewrs known) rewrites rpl
      val subterms =
          if S.isEmpty spl then subterms
          else T.filter (fn (i,_) => not (S.member (spl,i))) subterms
      val subterms = S.foldl (add_stms known) subterms spl
    in
      REWRS {order = order, known = known, rewrites = rewrites,
             subterms = subterms, waiting = waiting}
    end;

(* Choose an id from s, preferring one whose orientation is not Both.
   Every id in s must be present in known, otherwise retrieve raises. *)
fun pick known s =
    case S.find (fn i => snd (retrieve known i) <> Both) s of SOME x => SOME x
    | NONE => blind_pick s;

(* Process todo ids first, then waiting ids, until both are empty.
   Returns the rebuilt system and the set of ids whose redexes changed. *)
fun reduce_acc (rpl, spl, todo, rw as REWRS {known, waiting, ...}) =
    case pick known todo of
      SOME i => reduce_acc (reduce1 false i (rpl, spl, S.delete (todo,i), rw))
    | NONE =>
      case pick known waiting of
        SOME i => reduce_acc (reduce1 true i (rpl, spl, todo, waiting_del i rw))
      | NONE => (rebuild rpl spl rw, rpl);

(* Inter-reduce the system. Also returns, in increasing order, the ids that
   are still present and were either waiting on entry or had their redex
   changed. *)
fun reduce_newr rw =
    let
      val REWRS {waiting, ...} = rw
      val (rw,changed) = reduce_acc (S.empty, S.empty, S.empty, rw)
      val newr = S.union (changed,waiting)
      val REWRS {known, ...} = rw
      fun filt (i,l) = if Option.isSome (M.peek (known,i)) then i :: l else l
      val newr = S.foldr filt [] newr
    in
      (rw,newr)
    end;

(* reduce_newr plus a debug check at chatting level 2: every equation left
   in the reduced system must be irreducible by the other equations of that
   same system. The check uses known' (the reduced map). The pre-reduction
   map still holds old versions of equations that were rewritten, so
   checking against it would test the result against stale rules. *)
fun reduce' rw =
    if not (chatting 2) then reduce_newr rw else
      let
        val REWRS {order, ...} = rw
        val res as (rw',_) = reduce_newr rw
        val REWRS {known = known', ...} = rw'
        val eqs = map (fn (i,(th,_)) => (i,th)) (M.listItems known')
        val m = List.exists (thm_match known' order) eqs
        val _ = chatrewrs "reduce before" rw
        val _ = chatrewrs "reduce after" rw'
        val () = assert (not m) (Bug "reduce: not fully reduced")
      in
        res
      end;

val reduce = fst o reduce';

(* True iff no added equation is waiting to be processed by reduce. *)
fun reduced (REWRS {waiting, ...}) = S.isEmpty waiting;

(* ------------------------------------------------------------------------- *)
(* Rewriting as a derived rule                                               *)
(* ------------------------------------------------------------------------- *)

(* ORD_REWRITE ord ths: a rule that rewrites a theorem with the equations
   ths, oriented by ord. The equations get ids 0, 1, ... after FRESH_VARS is
   applied to each, and the theorem being rewritten gets id ~1, so none of
   the equations is excluded. No reduce is performed. *)
local
  fun f (th,(n,rw)) = (n + 1, add (n, FRESH_VARS th) rw);
in
  fun ORD_REWRITE ord ths =
      let val (_,rw) = foldl f (0, empty ord) ths
      in rewrite rw ord o pair ~1
      end;
end;

(* Rewrite with every equation used left-to-right. *)
val REWRITE = ORD_REWRITE (K (SOME GREATER));

end