(* ========================================================================= *)
(* KNUTH-BENDIX TERM ORDERING CONSTRAINTS                                    *)
(* Copyright (c) 2002-2004 Joe Hurd.                                         *)
(* ========================================================================= *)

(*
app load ["Binaryset", "mlibOmega", "mlibTerm", "mlibSubst"];
*)

(*
*)
structure mlibTermorder :> mlibTermorder =
struct

infix ## |-> ::>;

open mlibUseful mlibTerm;

structure O = Option; local open Option in end;
structure S = Binaryset; local open Binaryset in end;
structure B = Binarymap; local open Binarymap in end;
structure M = mlibMultiset; local open mlibMultiset in end;

type subst = mlibSubst.subst;
type 'a mset = 'a M.mset;

val |<>| = mlibSubst.|<>|;
val op::> = mlibSubst.::>;
val term_subst = mlibSubst.term_subst;

(* ------------------------------------------------------------------------- *)
(* Chatting.                                                                 *)
(* ------------------------------------------------------------------------- *)

val module = "mlibTermorder";
val () = add_trace {module = module, alignment = I}
fun chatting l = tracing {module = module, level = l};
fun chat s = (trace s; true)

(* ------------------------------------------------------------------------- *)
(* Parameters                                                                *)
(*                                                                           *)
(*   weight     : the weight of a function symbol, given its name and arity. *)
(*   precedence : an ordering on function symbols, given as (name, arity).   *)
(*   precision  : selects which filter good_eqn applies to new equations;    *)
(*                a value <= 0 makes good_eqn reject every equation.         *)
(* ------------------------------------------------------------------------- *)

type parameters =
  {weight : string * int -> int,
   precedence : (string * int) * (string * int) -> order,
   precision : int};

(* Default weight = uniform: every function symbol weighs 1 *)

val uniform : string * int -> int = fn _ => 1;

(* Default precedence = by arity, then by name length, then alphabetically *)

val arity : (string * int) * (string * int) -> order =
  fn ((f,m),(g,n)) =>
  if m < n then LESS else if m > n then GREATER else
    let val p = String.size f
        and q = String.size g
    in if p < q then LESS else if p > q then GREATER else String.compare (f,g)
    end;

val defaults =
  {weight = uniform,
   precedence = arity,
   precision = 3};

(* Apply f to the precision field, leaving the other parameters unchanged *)

fun update_precision f (parm : parameters) : parameters =
  let val {weight = w, precedence = p, precision = r} = parm
  in {weight = w, precedence = p, precision = f r}
  end;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(*                                                                           *)
(* An equation is a string multiset mapping each variable name to an integer *)
(* coefficient, with the empty name "" holding the constant term.           *)
(* ------------------------------------------------------------------------- *)

(* Sum of every coefficient, constant included: the value of the equation *)
(* when every variable is set to 1.                                        *)
val eqn_sum = M.foldl (fn (_,n,m) => n + m) 0;

(* Multiset fold step applying f to an entry's variable, skipping the     *)
(* constant entry "".                                                      *)
fun eqn_var _ ("",_,vs) = vs | eqn_var f (v,_,vs) = f v vs;

(* The coefficients of an equation in the order of vars, followed by the  *)
(* constant. Applying list_eqn to vars alone builds vars @ [""] once, so   *)
(* the partial application can be reused for many equations.               *)
fun list_eqn vars =
  let val vars = vars @ [""] in fn eqn => map (M.count eqn) vars end;

(* weight wf t expresses the Knuth-Bendix weight of t as an equation: the *)
(* weights wf gives the function symbols of t are summed into the constant *)
(* "", and each occurrence of a variable contributes 1 to its coefficient  *)
(* (occurrences combined with mlibMultiset.union).                         *)
local
  val no_vars = mlibMultiset.empty String.compare;
  fun one_var v = mlibMultiset.insert (v,1) no_vars;

  fun kb_weight w =
    let
      fun weight (Var v) = (0, one_var v)
        | weight (Fn (f, a)) = foldl wght (w (f, length a), no_vars) a
      and wght (t, (n, v)) = (curry op+ n ## mlibMultiset.union v) (weight t)
    in
      weight
    end;
in
  fun weight wf t = let val (n,w) = kb_weight wf t in M.insert ("",n) w end;
end;

(* The variables appearing in a list of equations, as a duplicate-free    *)
(* list produced by Binaryset.listItems (ordered by String.compare).      *)
local
  val emptys = S.empty String.compare;
  fun inserts v vs = S.add (vs,v);
in
  val calc_vars =
    S.listItems o foldl (fn (q,v) => M.foldl (eqn_var inserts) v q) emptys;
end;

fun partialorder_to_string (SOME LESS) = "SOME LESS"
  | partialorder_to_string (SOME GREATER) = "SOME GREATER"
  | partialorder_to_string (SOME EQUAL) = "SOME EQUAL"
  | partialorder_to_string NONE = "NONE";

(* ------------------------------------------------------------------------- *)
(* Normalizing equations means checking for trivial cases and tidying up     *)
(* ------------------------------------------------------------------------- *)

(* Divide every coefficient, constant included, by their common divisor.  *)
(* Absolute values are passed to gcd so the divisor is nonnegative        *)
(* whatever sign convention gcd follows for negative arguments; a         *)
(* negative divisor would fail the g <= 1 test and leave the equation     *)
(* unreduced. Dividing by a positive common divisor preserves the         *)
(* constraint, and the division is exact.                                 *)
fun divide_gcd eqn =
  let val g = M.foldl (fn (_,m,n) => gcd (Int.abs m) n) 0 eqn
  in if g <= 1 then eqn else M.map (fn (_,n) => n div g) eqn
  end;

(* If an equation satisfies this it's inconsistent: every variable entry  *)
(* has a negative coefficient, so the equation is largest when all        *)
(* variables are 1, and even there (eqn_sum) it is negative. Satisfying   *)
(* it would need some var to be <= 0.                                     *)
fun inconsistent_eqn q =
  M.all (fn ("",_) => true | (_,n) => n < 0) q andalso eqn_sum q < 0;

local
  (* If an equation satisfies pos then it's completely uninformative *)
  fun pos q =
    M.all (fn ("",_) => true | (_,n) => 0 <= n) q andalso 0 <= eqn_sum q;

  (* bad is a weaker condition, a compromise for efficiency: the variable *)
  (* coefficients sum to >= 0 and positive ones are at least as numerous  *)
  (* as negative ones.                                                    *)
  fun bad q =
    0 <= M.foldl (fn ("",_,m) => m | (_,n,m) => n + m) 0 q andalso
    0 <= M.foldl (fn ("",_,m) => m | (_,n,m) => if 0<n then m+1 else m-1) 0 q;

  (* An equation being unbounded is an incredibly weak condition *)
  fun trivial q = M.nonzero q=0 orelse M.nonzero q=1 andalso 0<M.count q "";
  fun unbounded q = M.exists (fn ("",_) => false | (_,n) => 0 < n) q;
in
  (* Decide whether an equation is worth keeping, using the filter chosen *)
  (* by the precision parameter. Raises Error on inconsistent equations.  *)
  fun good_eqn (parm : parameters) eqn =
    if inconsistent_eqn eqn then raise Error "good_eqn: inconsistent"
    else if #precision parm <= 0 then false
    else if #precision parm <= 1 then not (unbounded eqn orelse trivial eqn)
    else if #precision parm <= 2 then not (bad eqn)
    else not (pos eqn);
end;

(* Reduce each equation by divide_gcd and keep those passing good_eqn.    *)
(* The result lists the kept equations in reverse input order.            *)
fun normalize parm =
  let
    fun g (q,l) = if good_eqn parm q then q :: l else l
    fun f (q,l) = g (divide_gcd q, l)
  in
    foldl f []
  end;

(* ------------------------------------------------------------------------- *)
(* Deriving an equation from a term comparison.                              *)
(* ------------------------------------------------------------------------- *)

datatype eqn = Equal | Less | Greater | Equation of string mset;

(* mk_eqn parm (l,r) analyses the requirement l <= r. Pairs of subterms   *)
(* are scanned left to right, skipping identical pairs. At the first      *)
(* differing pair (s,t):                                                  *)
(*   - if weight t - weight s is nonzero, the answer is Equation of that  *)
(*     difference (reduced by divide_gcd);                                *)
(*   - otherwise the precedence of the head symbols decides (Less or      *)
(*     Greater), equal heads push their argument pairs onto the front of  *)
(*     the scan, and a variable on the left (right) gives Less (Greater). *)
(* Equal is returned when every pair is identical.                        *)
fun mk_eqn (parm : parameters) =
  let
    val {weight = wf, precedence, ...} = parm
    fun f [] = Equal
      | f ((l,r) :: rest) =
      if l = r then f rest else
        let val w = M.subtract (weight wf r) (weight wf l)
        in if M.nonzero w = 0 then g l r rest else Equation (divide_gcd w)
        end
    and g (Fn (f1,a1)) (Fn (f2,a2)) rest =
      (case precedence ((f1, length a1), (f2, length a2)) of LESS => Less
       | GREATER => Greater
       | EQUAL => f (zip a1 a2 @ rest))
      | g (Var _) _ _ = Less
      | g _ (Var _) _ = Greater;
  in
    fn lr => f [lr]
  end;

(* ------------------------------------------------------------------------- *)
(* A partial order on equations, designed to be quick to check.              *)
(*                                                                           *)
(* stronger eqn1 eqn2 holds when each variable coefficient of eqn1 is at    *)
(* most the corresponding coefficient of eqn2 and eqn_sum eqn1 <= eqn_sum    *)
(* eqn2 (strictly < for strictly_stronger). Reading an equation as a         *)
(* constraint "expression >= 0" over variables that are at least 1, this    *)
(* makes eqn1 imply eqn2.                                                    *)
(* ------------------------------------------------------------------------- *)

local
  fun gen_stronger cmp eqn1 eqn2 =
    M.all (fn ("",_) => true | (v,i) => i <= M.count eqn2 v) eqn1 andalso
    M.all (fn ("",_) => true | (v,i) => M.count eqn1 v <= i) eqn2 andalso
    cmp (eqn_sum eqn1, eqn_sum eqn2);
in
  val stronger = gen_stronger op<=;
  val strictly_stronger = gen_stronger op<;
end;

fun weaker eqn1 eqn2 = stronger eqn2 eqn1;
fun strictly_weaker eqn1 eqn2 = strictly_stronger eqn2 eqn1;

(* An equation is superfluous if some equation in the list implies it *)
fun superfluous eqn eqns = List.exists (weaker eqn) eqns;
fun strictly_superfluous eqn eqns = List.exists (strictly_weaker eqn) eqns;

(* ------------------------------------------------------------------------- *)
(* The termorder type.                                                       *)
(*                                                                           *)
(* Fields: parameters, variable list, equation list, satisfiability flag.    *)
(*                                                                           *)
(* Invariants:                                                               *)
(*                                                                           *)
(* 1. The string list contains precisely the variables that appear (with     *)
(*    non-zero coefficient) in the eqns.                                     *)
(*                                                                           *)
(* 2. All the equations satisfy good_eqn.                                    *)
(*                                                                           *)
(* 3. The boolean is true whenever there are no equations, and otherwise     *)
(*    only if the termorder is known to be satisfiable.                      *)
(* ------------------------------------------------------------------------- *)

datatype termorder = TO of parameters * string list * string mset list * bool;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Print an equation as "negative terms <= positive terms", with the      *)
(* negative coefficients negated and each side 0 when empty.              *)
fun pp_equation vars =
  let
    fun pp_tm ("",n) = pp_string (int_to_string n)
      | pp_tm (v,n) =
      pp_string ((if n=1 then "" else (int_to_string n^"*")) ^ v)
    fun pp_tms [] = pp_string "0"
      | pp_tms [tm] = pp_tm tm
      | pp_tms (tm :: tms) = pp_binop " +" pp_tm pp_tms (tm,tms)
  in
    fn eqn =>
    let
      val eqn = zip (vars @ [""]) (list_eqn vars eqn)
      val tms = List.filter (fn (_,n) => n <> 0) eqn
      val (pos,neg) = List.partition (fn (_,n) => 0 < n) tms
      val neg = map (I ## ~) neg
    in
      pp_binop " <=" pp_tms pp_tms (neg,pos)
    end
  end;

(* Print "{vars | eqns}", with a trailing * when marked satisfiable *)
fun pp_termorder (TO (_,vars,eqns,sat)) =
  pp_bracket "{" (if sat then "}*" else "}")
  (pp_binop " |" (pp_sequence "" pp_string)
   (pp_sequence "," (pp_equation vars))) (vars,eqns);

val termorder_to_string = PP.pp_to_string (!LINE_LENGTH) pp_termorder;

(* When tracing at level n, print the termorder under heading s and check *)
(* that every equation mentions only listed variables (raising Bug        *)
(* otherwise).                                                            *)
local
  val q2s = PP.pp_to_string (!LINE_LENGTH)
    (pp_list (pp_binop " |->" pp_string pp_int)) o M.to_list;

  fun wf_eqn vars eqn =
    if M.all (fn ("",_) => true | (v,_) => mem v vars) eqn then ()
    else raise Bug ("wf_eqn: malformed equation: " ^ q2s eqn);
in
  fun chatto n s (to as TO (_,vars,eqns,_)) =
    if not (chatting n) then () else
      (chat (s ^ ":\n" ^ termorder_to_string to ^ "\n");
       app (wf_eqn vars) eqns);
end;

(* ------------------------------------------------------------------------- *)
(* Basic operations                                                          *)
(* ------------------------------------------------------------------------- *)

fun empty parm = TO (parm,[],[],true);

(* Build a termorder from raw equations, normalizing them first *)
fun TON parm eqns =
  let val eqns = normalize parm eqns
  in TO (parm, calc_vars eqns, eqns, null eqns)
  end;

fun tnull (TO (_,[],[],_)) = true | tnull _ = false;

fun vars (TO (_,fvs,_,_)) = fvs;

(* add_leq (l,r) to adds the constraint l <= r.                           *)
(* Raises Error if mk_eqn finds l > r, if the new equation directly       *)
(* contradicts an existing one, or (from good_eqn) if it is inconsistent  *)
(* by itself. An equation rejected by good_eqn or already implied by an   *)
(* existing one is not added; existing equations implied by the new one   *)
(* are dropped.                                                           *)
fun add_leq lr (to as TO (parm,_,eqns,_)) =
  let
    fun keep eqn =
      good_eqn parm eqn andalso
      not (superfluous eqn eqns) andalso
      (if not (strictly_superfluous (M.compl eqn) eqns) then true
       else raise Error "add_leq: direct contradiction")

    fun inc eqn =
      let
        val () = chatto 1 "add_leq input" to
        val eqns' = eqn :: List.filter (not o stronger eqn) eqns
        (* Recompute the variables from eqns': the filter may remove the   *)
        (* last equation mentioning some variable, and invariant 1 needs   *)
        (* the list to hold exactly the variables still in use.            *)
        val to = TO (parm, calc_vars eqns', eqns', false)
        val () = chatto 1 "add_leq result" to
      in
        to
      end
  in
    case mk_eqn parm lr of Equal => to
    | Less => to
    | Greater => raise Error "add_leq: violates order (weight)"
    | Equation eqn => if keep eqn then inc eqn else to
  end;

fun add_leqs lrs to = foldl (uncurry add_leq) to lrs;

(* Apply a substitution to the variables of a termorder. *)
local
  fun table_to_string vars vars' tab =
    let
      fun nicevar "" = "1" | nicevar v = v;
      fun nicerow l = "[" :: map (fn x => " " ^ x) (l @ ["]"])
      fun makerow v =
        nicevar v :: map (int_to_string o M.count (B.find (tab,v))) vars
    in
      join "\n"
      (align_table {left = false, pad = #" "}
       (map nicerow (("" :: map nicevar vars) :: map makerow vars'))) ^ "\n"
    end;

  (* Variables after substitution: the unmapped old variables plus the    *)
  (* free variables of the substituted terms.                             *)
  fun new_vars vars mapl =
    let val (ls,rs) = unzip (map (fn x |-> y => (x,y)) mapl)
    in FVTL (List.filter (not o C mem ls) vars) rs
    end;

  val m0 = M.empty String.compare;
  fun m1 xi = M.insert xi m0;
  fun mn xis = foldl (uncurry M.insert) m0 xis;

  (* Record maplet v |-> t: cancel the identity entry tab[v][v] when v    *)
  (* survives as a new variable, then accumulate, for each entry (w,i) of *)
  (* weight t, coefficient i into tab[w][v].                               *)
  fun table_add parm vars' ((v |-> t), tab) =
    let
      val {weight = wf, ...} : parameters = parm
      fun add (w,i,t) = B.insert (t, w, M.insert (v, i) (B.find (t, w)))
      val tab = if not (mem v vars') then tab else add (v,~1,tab)
    in
      M.foldl add tab (weight wf t)
    end;

  (* tab maps each new variable w (and "") to a multiset over the old     *)
  (* variables (and ""): tab[w][v] is the coefficient of w in what old    *)
  (* variable v becomes. Old variables that survive start as themselves.  *)
  fun mk_table parm vars vars' =
    let
      fun init (v,m) = B.insert (m, v, if mem v vars then m1 (v,1) else m0)
      val tab = foldl init (B.mkDict String.compare) vars'
    in
      foldl (table_add parm vars') tab
    end;

  (* The coefficient of new variable w is the sum over old variables v of *)
  (* tab[w][v] times the old coefficient of v.                             *)
  fun new_eqn vars vars' tab eqn =
    let
      fun g (v,i,n) = n + M.count eqn v * i
      fun f (v,m) = M.insert (v, M.foldl g 0 (B.find (tab,v))) m
    in
      foldl f m0 vars'
    end;

  fun nontriv mapl (to as TO (parm,vars,eqns,_)) =
    let
      val () = chatto 1 "subst input" to
      val vars1 = "" :: vars
      val vars2 = "" :: new_vars vars mapl
      val tab = mk_table parm vars1 vars2 mapl
      val _ = chatting 1 andalso
        chat ("subst table:\n"^table_to_string vars1 vars2 tab)
      val eqns' = map (new_eqn vars1 vars2 tab) eqns
      val to = TON parm eqns'
      val () = chatto 1 "subst result" to
    in
      to
    end;
in
  fun subst sub (to as TO (_,vars,_,_)) =
    let val mapl = mlibSubst.to_maplets (mlibSubst.norm (mlibSubst.restrict vars sub))
    in if null mapl then to else nontriv mapl to
    end;
end;

(* Conjoin two termorders, first dropping from each side the equations    *)
(* implied by the other. The result takes its parameters from the first   *)
(* argument.                                                               *)
local
  fun cast_away eqns = List.filter (fn eqn => not (superfluous eqn eqns));
in
  fun merge (TO (_,_,[],_)) to = to
    | merge to (TO (_,_,[],_)) = to
    | merge to1 to2 =
    let
      val () = chatto 1 "merge input1" to1
      val () = chatto 1 "merge input2" to2
      val TO (parm,_,eqns1,_) = to1
      val TO (_,_,eqns2,_) = to2
      val eqns1 = cast_away eqns2 eqns1
      (* Filter against the already-reduced eqns1, so an equation is not *)
      (* dropped from both sides when the two are mutually implied.      *)
      val eqns2 = cast_away eqns1 eqns2
      val to =
        if null eqns1 then to2 else if null eqns2 then to1 else
          let val eqns = eqns1 @ eqns2
          in TO (parm, calc_vars eqns, eqns, false)
          end
      val () = chatto 1 "merge result" to
    in
      to
    end;
end;

(* ------------------------------------------------------------------------- *)
(* Interface to mlibOmega.                                                   *)
(* ------------------------------------------------------------------------- *)

local
  (* Used only by the commented-out debugging check below *)
  val raw_equations_to_string =
    String.concat o
    map (fn x => PP.pp_to_string (!LINE_LENGTH) (pp_list pp_int) x ^ "\n");

  (* n rows, row i having 1 in column i and ~1 in the final (constant)    *)
  (* column. Presumably mlibOmega reads each row as "row . vars >= 0", so *)
  (* these say every variable is at least 1 -- confirm against mlibOmega. *)
  fun pos_eqns n =
    snd (funpow n (fn (t,r) => (0 :: t, (1 :: t) :: map (cons 0) r)) ([~1],[]));

  (* Remember that list_eqn does partial evaluation on vars *)
  fun omega_eqns vars eqns = pos_eqns (length vars) @ map (list_eqn vars) eqns;

  open mlibOmega;

  fun mk_db normalc eqns db exc =
    case eqns of [] => normalc db
    | c :: cs =>
      add_check_factoid db (gcd_check_dfactoid (fromList c, ASM ()))
      (mk_db normalc cs) exc;

  fun check eqns =
    mk_db (fn db => work db I) eqns (dbempty (length (hd eqns))) I;

  fun inconsistent (SATISFIABLE _) = false
    | inconsistent (CONTR _) = true
    | inconsistent NO_CONCL = false;

  (* Uncomment this check function to print out the linear arithmetic problems
  val THRESHOLD = 1.0;

  fun result_to_string (SATISFIABLE _) = "satisfiable"
    | result_to_string (CONTR _) = "inconsistent"
    | result_to_string NO_CONCL = "no conclusion";

  val check = fn eqns =>
    let
      val (t,r) = timed check eqns
      val () =
        if t < THRESHOLD then () else
          print ("\n\nOMEGA: time = " ^ Real.fmt (StringCvt.FIX (SOME 3)) t ^
                 "s\n" ^ raw_equations_to_string eqns ^
                 "OMEGA: " ^ result_to_string r ^ "\n\n")
    in
      r
    end;
  *)
in
  (* SOME of the termorder marked satisfiable, unless mlibOmega derives a *)
  (* contradiction (NONE). A NO_CONCL result also counts as satisfiable.  *)
  (* Termorders already marked satisfiable are returned as they are.      *)
  fun consistent (to as TO (_,_,_,true)) = SOME to
    | consistent (to as TO (parm,vars,eqns,false)) =
    let
      val () = chatto 1 "consistent" to
    in
      if inconsistent (check (omega_eqns vars eqns)) then NONE
      else SOME (TO (parm,vars,eqns,true))
    end;
(* This bug has now been fixed, but others may appear in the future :-)
    handle Option =>
      (print ("BUG in mlibOmega library: uncaught Option exception" ^
              "\ntermorder was:\n" ^ termorder_to_string to ^
              "\nsent to mlibOmega:\n" ^ raw_equations_to_string (omega_eqns to) ^
              "\n\n"); true)
*)
end;

(* ------------------------------------------------------------------------- *)
(* Query.                                                                    *)
(* ------------------------------------------------------------------------- *)

(* Every equation of the first termorder is implied by one of the second *)
fun subsumes (TO (_,_,eqns1,_)) (TO (_,_,eqns2,_)) =
  List.all (fn eqn => superfluous eqn eqns2) eqns1;

local
  (* Decide an Equation q (= weight r - weight l) from q alone, or from   *)
  (* an existing equation strictly implying q (LESS) or its complement    *)
  (* (GREATER); NONE when neither is known.                               *)
  fun cmp _ Equal = SOME EQUAL
    | cmp _ Less = SOME LESS
    | cmp _ Greater = SOME GREATER
    | cmp eqns (Equation eqn) =
    let
      val eqn' = M.compl eqn
    in
      if inconsistent_eqn eqn then SOME GREATER
      else if inconsistent_eqn eqn' then SOME LESS
      else if strictly_superfluous eqn eqns then SOME LESS
      else if strictly_superfluous eqn' eqns then SOME GREATER
      else NONE
    end;
in
  (* Compare the terms of lr under the constraints of the termorder *)
  fun compare (to as TO (parm,_,eqns,_)) lr =
    let
      val () = chatto 1 "compare input" to
      val _ = chatting 1 andalso
        chat ("comparing " ^ term_to_string (fst lr) ^
              " and " ^ term_to_string (snd lr) ^ "\n")
      val res = cmp eqns (mk_eqn parm lr)
      val _ = chatting 1 andalso
        chat ("compare result = " ^ partialorder_to_string res ^ "\n")
    in
      res
    end;
end;

(* ------------------------------------------------------------------------- *)
(* Name binding.                                                             *)
(* ------------------------------------------------------------------------- *)

(* Bound last because TON above relies on the list function null *)
val null = tnull;

(* Quick testing
app load ["mlibThm"];
val T = parse_term;
val F = parse_formula;
installPP pp_termorder;
installPP mlibSubst.pp_subst;
installPP mlibThm.pp_thm;

val to = empty defaults;
val to = try (C add_leq to) (T`f x (f y z)`, T`f (f x y) z`);
val x = (total o try) (C add_leq to) (T`f (f x y) z`, T`f x (f y z)`);
val to = try (C add_leq to) (T`P (f a b)`, T`P x`);
val to = try (C add_leq to) (T`P y`, T`P (g a b c)`);
val to = try (C add_leq to) (T`x + y`, T`y + x`);
val c = consistent to;
val to = try (subst (("x" |-> T`v`) ::> |<>|)) to;
val to = try (subst (("v" |-> T`f x x x x a a a a`) ::> |<>|)) to;
val c = consistent to;

val to = empty defaults;
val to = try (C add_leq to) (T`P y`, T`P (g a b c d e f)`);
val to = try (C add_leq to) (T`x + y`, T`y + x`);
val to = try (C add_leq to) (T`x + y`, T`y + x`);
val to = try (subst (("x" |-> T`f x x x`) ::> |<>|)) to;
val c = consistent to;
val to = try (subst (("x" |-> T`f w v`) ::> |<>|)) to;
val c = consistent to;

val to = empty defaults;
val to = try (C add_leq to) (T`f x y`, T`f y x`);
val to = try (subst (("x" |-> T`f x`) ::> |<>|)) to;
val x = compare to (T`f x`, T`g y`);
val x = compare to (T`g x`, T`f y`);
val x = compare to (T`g a`, T`f a`);
val x = compare to (T`f a`, T`g a`);
val th =
  mlibThm.ORD_REWRITE (compare to)
  (map (mlibThm.AXIOM o wrap o F)
   [`x + (y + z) = y + (x + z)`, `(x + y) + z = x + (y + z)`])
  (mlibThm.AXIOM [F`P (y + x + y + x + y + x + 0)`]);
*)

end