1(* ========================================================================= *)
2(* KNUTH-BENDIX TERM ORDERING CONSTRAINTS                                    *)
3(* Copyright (c) 2002-2004 Joe Hurd.                                         *)
4(* ========================================================================= *)
5
6(*
7app load ["Binaryset", "mlibOmega", "mlibTerm", "mlibSubst"];
8*)
9
10(*
11*)
12structure mlibTermorder :> mlibTermorder =
13struct
14
15infix ## |-> ::>;
16
17open mlibUseful mlibTerm;
18
19structure O = Option; local open Option in end;
20structure S = Binaryset; local open Binaryset in end;
21structure B = Binarymap; local open Binarymap in end;
22structure M = mlibMultiset; local open mlibMultiset in end;
23
(* Short local names: the substitution type and operations come from
   mlibSubst, the multiset type from mlibMultiset (bound to M above). *)
type subst   = mlibSubst.subst;
type 'a mset = 'a M.mset;

(* Rebind the mlibSubst values so they can be used unqualified below. *)
val |<>|          = mlibSubst.|<>|;
val op::>         = mlibSubst.::>;
val term_subst    = mlibSubst.term_subst;
30
31(* ------------------------------------------------------------------------- *)
32(* Chatting.                                                                 *)
33(* ------------------------------------------------------------------------- *)
34
(* Name under which this structure registers itself with the tracing system. *)
val module = "mlibTermorder";

(* Register the module for tracing (runs once, when the structure is built). *)
val () = add_trace {module = module, alignment = I}

(* chatting lvl: is tracing enabled for this module at level lvl? *)
val chatting = fn lvl => tracing {module = module, level = lvl};

(* chat msg: emit msg on the trace channel; always returns true so that it
   can be sequenced after a guard with andalso. *)
fun chat msg = (trace msg; true)
39
40(* ------------------------------------------------------------------------- *)
41(* Parameters                                                                *)
42(* ------------------------------------------------------------------------- *)
43
(* Settings carried by every termorder:
     weight     - maps a (name, arity) pair to an integer weight;
     precedence - compares two (name, arity) pairs;
     precision  - integer level consulted by good_eqn when deciding which
                  equations are worth keeping. *)
type parameters =
  {weight     : string * int -> int,
   precedence : (string * int) * (string * int) -> order,
   precision  : int};
48
49(* Default weight = uniform *)
50
(* Every symbol, whatever its name or arity, gets weight 1. *)
fun uniform (_ : string * int) : int = 1;
52
53(* Default precedence = by arity *)
54
(* Compare two (name, arity) pairs: first by arity, then by the length of the
   name, and finally by String.compare on the names themselves. *)
fun arity (((f,m),(g,n)) : (string * int) * (string * int)) : order =
  case Int.compare (m,n) of
    EQUAL =>
      (case Int.compare (String.size f, String.size g) of
         EQUAL => String.compare (f,g)
       | by_length => by_length)
  | by_arity => by_arity;
62
(* Default settings: uniform weights, precedence by arity, precision 3. *)
val defaults =
  {precision  = 3,
   precedence = arity,
   weight     = uniform};
67
(* update_precision f parm: a copy of parm whose precision field has been
   replaced by f applied to the old precision; other fields are untouched. *)
fun update_precision f ({weight, precedence, precision} : parameters)
  : parameters =
  {weight = weight, precedence = precedence, precision = f precision};
72
73(* ------------------------------------------------------------------------- *)
74(* Helper functions.                                                         *)
75(* ------------------------------------------------------------------------- *)
76
(* Sum of every coefficient stored in an equation (all keys included). *)
fun eqn_sum eqn = M.foldl (fn (_,coeff,total) => coeff + total) 0 eqn;
78
(* Fold step over an equation's entries that visits only real variables:
   the entry keyed by "" leaves the accumulator alone, any other key v is
   passed to f together with the accumulator. *)
fun eqn_var f (v,_,acc) = if v = "" then acc else f v acc;
80
(* list_eqn vars: a function turning an equation into the list of its
   coefficients for each of vars, followed by the coefficient stored under
   "".  The column list is built once, when vars is supplied. *)
fun list_eqn vars =
  let
    val columns = vars @ [""]
    fun row eqn = map (fn v => M.count eqn v) columns
  in
    row
  end;
83
local
  val no_vars = M.empty String.compare;
  fun one_var v = M.insert (v,1) no_vars;

  (* kb_weight w tm: pair of the summed symbol weights in tm and the
     multiset of its variable occurrences. *)
  fun kb_weight w =
    let
      fun weight (Var v) = (0, one_var v)
        | weight (Fn (f,args)) =
          let
            (* Add one argument's contribution to the running totals. *)
            fun step (arg,(n,vs)) =
              let val (n',vs') = weight arg
              in (n + n', M.union vs vs')
              end
          in
            foldl step (w (f, length args), no_vars) args
          end
    in
      weight
    end;
in
  (* weight wf t: the variable multiset of t with the summed symbol weight
     inserted under the key "". *)
  fun weight wf t =
    let val (n,vs) = kb_weight wf t
    in M.insert ("",n) vs
    end;
end;
99
local
  val emptys = S.empty String.compare;
  fun add_var v seen = S.add (seen,v);
  (* Add every non-"" key of one equation to the set. *)
  fun eqn_vars (eqn,seen) = M.foldl (eqn_var add_var) seen eqn;
in
  (* calc_vars eqns: the variables (keys other than "") occurring in eqns,
     gathered in a Binaryset and returned as a list. *)
  fun calc_vars eqns = S.listItems (foldl eqn_vars emptys eqns);
end;
107
(* Render an order option for trace output. *)
fun partialorder_to_string NONE = "NONE"
  | partialorder_to_string (SOME ord) =
    "SOME " ^
    (case ord of LESS => "LESS" | EQUAL => "EQUAL" | GREATER => "GREATER");
112
113(* ------------------------------------------------------------------------- *)
114(* Normalizing equations means checking for trivial cases and tidying up     *)
115(* ------------------------------------------------------------------------- *)
116
(* Divide every coefficient of an equation by the gcd of all of them; the
   equation is returned unchanged when that gcd is at most 1. *)
fun divide_gcd eqn =
  let
    fun step (_,coeff,acc) = gcd coeff acc
    val divisor = M.foldl step 0 eqn
  in
    if 1 < divisor then M.map (fn (_,coeff) => coeff div divisor) eqn
    else eqn
  end;
121
(* If an equation satisfies this it's inconsistent: some var must be <= 0.
   The test: every variable coefficient is negative and the sum of all
   coefficients (the "" entry included) is negative. *)
fun inconsistent_eqn q =
  let
    fun negative (v,n) = v = "" orelse n < 0
  in
    M.all negative q andalso eqn_sum q < 0
  end;
125
local
  (* pos: no variable has a negative coefficient and the total of all
     coefficients is non-negative; such an equation is completely
     uninformative. *)
  fun pos q =
    let fun nonneg (v,n) = v = "" orelse 0 <= n
    in M.all nonneg q andalso 0 <= eqn_sum q
    end;

  (* bad: a weaker condition than pos, a compromise for efficiency.  The
     variable coefficients sum to at least 0, and the variables with a
     positive coefficient are at least as numerous as the others. *)
  fun bad q =
    let
      fun add_coeff (v,n,acc) = if v = "" then acc else n + acc
      fun vote (v,n,acc) =
        if v = "" then acc else if 0 < n then acc + 1 else acc - 1
    in
      0 <= M.foldl add_coeff 0 q andalso 0 <= M.foldl vote 0 q
    end;

  (* trivial: nothing non-zero at all, or a single non-zero entry which is a
     positive value under "". *)
  fun trivial q =
    case M.nonzero q of
      0 => true
    | 1 => 0 < M.count q ""
    | _ => false;

  (* unbounded: some variable has a positive coefficient (an incredibly weak
     condition). *)
  fun unbounded q = M.exists (fn (v,n) => v <> "" andalso 0 < n) q;
in
  (* good_eqn parm eqn: should eqn be kept at the precision in parm?
     Raises Error when eqn is inconsistent on its own.  Precision <= 0 keeps
     nothing; 1 rejects unbounded or trivial equations; 2 rejects bad ones;
     anything higher rejects only pos ones. *)
  fun good_eqn (parm : parameters) eqn =
    let
      val level = #precision parm
    in
      if inconsistent_eqn eqn then raise Error "good_eqn: inconsistent"
      else if level <= 0 then false
      else if level = 1 then not (unbounded eqn) andalso not (trivial eqn)
      else if level = 2 then not (bad eqn)
      else not (pos eqn)
    end;
end;
147
(* normalize parm eqns: divide each equation by its gcd and keep those that
   pass good_eqn.  The fold conses survivors onto the accumulator, so they
   come out in reverse input order.  good_eqn's Error propagates. *)
fun normalize parm eqns =
  let
    fun keep (eqn,acc) =
      let val reduced = divide_gcd eqn
      in if good_eqn parm reduced then reduced :: acc else acc
      end
  in
    foldl keep [] eqns
  end;
155
156(* ------------------------------------------------------------------------- *)
157(* Deriving an equation from a term comparison.                              *)
158(* ------------------------------------------------------------------------- *)
159
(* Outcome of mk_eqn on a pair of terms: Equal, Less or Greater when the
   comparison is settled outright, otherwise Equation carrying the
   coefficient multiset derived from the two term weights. *)
datatype eqn = Equal | Less | Greater | Equation of string mset;
161
(* mk_eqn parm (l,r): compare l with r.
   Pairs are taken from a work list.  Identical terms are skipped.  If the
   weight difference M.subtract (weight r) (weight l) has a non-zero entry,
   the result is that difference (gcd-reduced) as an Equation.  Otherwise
   the tie is broken: a variable on the left gives Less, a variable on the
   right gives Greater, and two function applications are compared by
   precedence, descending into their zipped arguments when it says EQUAL.
   An exhausted work list means Equal. *)
fun mk_eqn (parm : parameters) =
  let
    val wf = #weight parm
    val precedence = #precedence parm

    fun lex [] = Equal
      | lex ((l,r) :: rest) =
        if l = r then lex rest
        else
          let
            val diff = M.subtract (weight wf r) (weight wf l)
          in
            if M.nonzero diff <> 0 then Equation (divide_gcd diff)
            else tie l r rest
          end

    and tie (Var _) _ _ = Less
      | tie _ (Var _) _ = Greater
      | tie (Fn (f1,a1)) (Fn (f2,a2)) rest =
        (case precedence ((f1, length a1), (f2, length a2)) of
           EQUAL => lex (zip a1 a2 @ rest)
         | LESS => Less
         | GREATER => Greater);
  in
    fn lr => lex [lr]
  end;
180
181(* ------------------------------------------------------------------------- *)
182(* A partial order on equations, designed to be quick to check.              *)
183(* ------------------------------------------------------------------------- *)
184
local
  (* eqn1 is pointwise below eqn2 on every key other than "" (checked from
     both sides, so keys present in only one equation are covered), and cmp
     holds between the two coefficient sums. *)
  fun gen_stronger cmp eqn1 eqn2 =
    let
      fun below_in_2 (v,i) = v = "" orelse i <= M.count eqn2 v
      fun above_in_1 (v,i) = v = "" orelse M.count eqn1 v <= i
    in
      M.all below_in_2 eqn1 andalso
      M.all above_in_1 eqn2 andalso
      cmp (eqn_sum eqn1, eqn_sum eqn2)
    end;
in
  (* Non-strict version: the sums are compared with <=. *)
  fun stronger eqn1 eqn2 = gen_stronger op<= eqn1 eqn2;
  (* Strict version: the sums are compared with <. *)
  fun strictly_stronger eqn1 eqn2 = gen_stronger op< eqn1 eqn2;
end;
194
(* The converse relations: eqn1 is (strictly) weaker than eqn2 exactly when
   eqn2 is (strictly) stronger than eqn1. *)
val weaker = fn eqn1 => fn eqn2 => stronger eqn2 eqn1;
val strictly_weaker = fn eqn1 => fn eqn2 => strictly_stronger eqn2 eqn1;
197
(* eqn is superfluous with respect to eqns when some member of eqns is
   stronger than it; the strict variant asks for a strictly stronger one. *)
fun superfluous eqn eqns =
  List.exists (fn known => stronger known eqn) eqns;
fun strictly_superfluous eqn eqns =
  List.exists (fn known => strictly_stronger known eqn) eqns;
200
201(* ------------------------------------------------------------------------- *)
202(* The termorder type.                                                       *)
203(*                                                                           *)
204(* Invariants:                                                               *)
205(*                                                                           *)
206(* 1. The string list contains precisely the variables that appear (with     *)
207(*    non-zero coefficient) in the eqns.                                     *)
208(*                                                                           *)
209(* 2. All the equations satisfy good_eqn.                                    *)
210(*                                                                           *)
211(* 3. The boolean is true whenever there are no equations, and otherwise     *)
212(*    only if the termorder is known to be satisfiable.                      *)
213(* ------------------------------------------------------------------------- *)
214
(* Components, in order: the parameters, the variable list, the equations,
   and the boolean satisfiability flag. *)
datatype termorder = TO of parameters * string list * string mset list * bool;
216
217(* ------------------------------------------------------------------------- *)
218(* Pretty-printing.                                                          *)
219(* ------------------------------------------------------------------------- *)
220
(* pp_equation vars: printer for one equation over the columns vars.
   Zero coefficients are dropped; the negative ones are negated and printed
   on the left of " <=", the positive ones on the right.  The "" entry
   prints as a bare number, a coefficient of 1 is omitted, and an empty
   side prints as "0". *)
fun pp_equation vars =
  let
    fun show_tm ("",n) = int_to_string n
      | show_tm (v,1) = v
      | show_tm (v,n) = int_to_string n ^ "*" ^ v

    fun pp_tm tm = pp_string (show_tm tm)

    fun pp_tms [] = pp_string "0"
      | pp_tms [tm] = pp_tm tm
      | pp_tms (tm :: tms) = pp_binop " +" pp_tm pp_tms (tm,tms)

    val columns = vars @ [""]
    val row = list_eqn vars
  in
    fn eqn =>
    let
      val coeffs = zip columns (row eqn)
      val nonzero = List.filter (fn (_,n) => n <> 0) coeffs
      val (rhs,negs) = List.partition (fn (_,n) => 0 < n) nonzero
      val lhs = map (fn (v,n) => (v, ~n)) negs
    in
      pp_binop " <=" pp_tms pp_tms (lhs,rhs)
    end
  end;
240
(* Print a termorder as "{vars | eqn, eqn, ...}"; the closing bracket is
   "}*" when the satisfiability flag is set. *)
fun pp_termorder (TO (_,vars,eqns,sat)) =
  let
    val closer = if sat then "}*" else "}"
    val pp_vars = pp_sequence "" pp_string
    val pp_eqns = pp_sequence "," (pp_equation vars)
  in
    pp_bracket "{" closer (pp_binop " |" pp_vars pp_eqns) (vars,eqns)
  end;
245
(* Render a termorder as a string using pp_termorder.
   LINE_LENGTH is dereferenced on every call, so a change to the reference
   made after this structure was built takes effect (a val binding would
   freeze the width at load time). *)
fun termorder_to_string to = PP.pp_to_string (!LINE_LENGTH) pp_termorder to;
247
local
  (* Render an equation's entries as a list of "key |-> coefficient" pairs.
     LINE_LENGTH is read at call time rather than once at load time, so the
     current setting of the reference is used. *)
  fun q2s eqn =
    PP.pp_to_string (!LINE_LENGTH)
    (pp_list (pp_binop " |->" pp_string pp_int)) (M.to_list eqn);

  (* Every key of eqn other than "" must be listed in vars; otherwise raise
     Bug with the offending equation. *)
  fun wf_eqn vars eqn =
    if M.all (fn ("",_) => true | (v,_) => mem v vars) eqn then ()
    else raise Bug ("wf_eqn: malformed equation: " ^ q2s eqn);
in
  (* chatto n s to: when tracing is enabled at level n, trace the termorder
     under the heading s and then check each of its equations with wf_eqn.
     Does nothing when tracing is off at that level. *)
  fun chatto n s (to as TO (_,vars,eqns,_)) =
    if not (chatting n) then () else
      (chat (s ^ ":\n" ^ termorder_to_string to ^ "\n");
       app (wf_eqn vars) eqns);
end;
261
262(* ------------------------------------------------------------------------- *)
263(* Basic operations                                                          *)
264(* ------------------------------------------------------------------------- *)
265
(* The termorder with no variables and no equations; its flag is true. *)
fun empty parm =
  let val no_vars = [] and no_eqns = []
  in TO (parm, no_vars, no_eqns, true)
  end;
267
(* TON parm eqns: build a termorder from raw equations.  They are
   normalized first; the variable list is recomputed from what survives and
   the flag is true exactly when nothing survives. *)
fun TON parm eqns =
  let
    val kept = normalize parm eqns
    val fvs = calc_vars kept
  in
    TO (parm, fvs, kept, null kept)
  end;
272
(* True exactly when both the variable list and the equation list are empty. *)
fun tnull to =
  case to of
    TO (_,[],[],_) => true
  | _ => false;
274
(* The variable list stored in a termorder. *)
fun vars to = let val TO (_,fvs,_,_) = to in fvs end;
276
(* add_leq (l,r) to: add the constraint derived from comparing l with r.

   mk_eqn decides the comparison first:
     Equal / Less - nothing to record, to is returned unchanged;
     Greater      - raises Error;
     Equation eqn - eqn is added if it passes keep, otherwise to is returned.

   keep rejects an equation that fails good_eqn (which itself raises Error
   on an inconsistent equation) or that is already superfluous, and raises
   Error when the negated equation is strictly superfluous, i.e. the new
   constraint directly contradicts an existing one.

   When an equation is added, existing equations that it is stronger than
   are dropped.  The variable list is therefore recomputed from the
   resulting equations with calc_vars instead of only being extended: a
   variable that occurred solely in a dropped equation must not linger in
   the list, or invariant 1 of the termorder type would be broken.  The
   flag of the result is false. *)
fun add_leq lr (to as TO (parm,_,eqns,_)) =
  let
    fun keep eqn =
      good_eqn parm eqn andalso
      not (superfluous eqn eqns) andalso
      (if not (strictly_superfluous (M.compl eqn) eqns) then true
       else raise Error "add_leq: direct contradiction")

    fun inc eqn =
      let
        val () = chatto 1 "add_leq input" to
        (* Drop the existing equations made redundant by the new one. *)
        val eqns' = eqn :: List.filter (not o stronger eqn) eqns
        (* Recompute, so variables of dropped equations disappear too. *)
        val vars' = calc_vars eqns'
        val to = TO (parm,vars',eqns',false)
        val () = chatto 1 "add_leq result" to
      in
        to
      end
  in
    case mk_eqn parm lr of Equal => to
    | Less => to
    | Greater => raise Error "add_leq: violates order (weight)"
    | Equation eqn => if keep eqn then inc eqn else to
  end;
301
(* Add a list of term pairs one at a time, left to right, with add_leq. *)
fun add_leqs lrs to = foldl (fn (lr,acc) => add_leq lr acc) to lrs;
303
local
  (* Render the substitution table for tracing: one column per old
     variable, one row per new variable; the "" key is shown as "1". *)
  fun table_to_string vars vars' tab =
    let
      fun label v = if v = "" then "1" else v
      fun bracket cells = "[" :: map (fn c => " " ^ c) (cells @ ["]"])
      fun row v =
        let val entries = B.find (tab,v)
        in label v :: map (fn x => int_to_string (M.count entries x)) vars
        end
      val header = "" :: map label vars
      val rows = map bracket (header :: map row vars')
    in
      join "\n" (align_table {left = false, pad = #" "} rows) ^ "\n"
    end;

  (* Variables after the substitution: the old ones that are not being
     replaced, passed to FVTL together with the replacement terms. *)
  fun new_vars vars mapl =
    let
      val (domain,images) = unzip (map (fn x |-> y => (x,y)) mapl)
      fun untouched v = not (mem v domain)
    in
      FVTL (List.filter untouched vars) images
    end;

  val m0 = M.empty String.compare;
  fun m1 xi = M.insert xi m0;
  fun mn xis = foldl (fn (xi,m) => M.insert xi m) m0 xis;

  (* Account for one maplet v |-> tm in the table.  If v is still a new
     variable, ~1 is added to its own row's entry for v; then, for every
     entry (w,i) of the weight of tm, i is added to row w's entry for v. *)
  fun table_add (parm : parameters) vars' ((v |-> tm), tab) =
    let
      fun bump (row,coeff,tb) =
        B.insert (tb, row, M.insert (v,coeff) (B.find (tb,row)))
      val tab' = if mem v vars' then bump (v,~1,tab) else tab
    in
      M.foldl bump tab' (weight (#weight parm) tm)
    end;

  (* Initial table: each new variable that is also an old variable maps to
     itself with coefficient 1, every other new variable starts empty.  The
     maplets are then folded in with table_add. *)
  fun mk_table parm vars vars' mapl =
    let
      fun seed (v,tb) = B.insert (tb, v, if mem v vars then m1 (v,1) else m0)
      val initial = foldl seed (B.mkDict String.compare) vars'
    in
      foldl (table_add parm vars') initial mapl
    end;

  (* Rewrite one equation over the new variables: the coefficient of a new
     variable is the sum, over its table row, of the old coefficient times
     the table entry. *)
  fun new_eqn _ vars' tab eqn =
    let
      fun accumulate (old,factor,total) = total + M.count eqn old * factor
      fun column (v,acc) =
        M.insert (v, M.foldl accumulate 0 (B.find (tab,v))) acc
    in
      foldl column m0 vars'
    end;

  (* Apply a non-empty list of maplets and rebuild the termorder via TON. *)
  fun nontriv mapl (to as TO (parm,vars,eqns,_)) =
    let
      val () = chatto 1 "subst input" to
      val old_cols = "" :: vars
      val new_cols = "" :: new_vars vars mapl
      val tab = mk_table parm old_cols new_cols mapl
      val _ =
        chatting 1 andalso
        chat ("subst table:\n" ^ table_to_string old_cols new_cols tab)
      val result = TON parm (map (new_eqn old_cols new_cols tab) eqns)
      val () = chatto 1 "subst result" result
    in
      result
    end;
in
  (* subst sub to: apply the substitution sub to the termorder.  sub is
     first restricted to the termorder's variables and normalized; if no
     maplets remain the termorder is returned unchanged. *)
  fun subst sub (to as TO (_,vars,_,_)) =
    case
      mlibSubst.to_maplets (mlibSubst.norm (mlibSubst.restrict vars sub))
    of
      [] => to
    | mapl => nontriv mapl to;
end;
371
local
  (* Keep only the members of qs that are not superfluous w.r.t. against. *)
  fun cast_away against qs =
    List.filter (fn q => not (superfluous q against)) qs;
in
  (* merge to1 to2: combine the equations of two termorders.
     If either has no equations the other is returned as is.  Otherwise the
     equations of to1 that to2 makes superfluous are dropped, then the
     equations of to2 made superfluous by what is left of to1.  If one side
     is emptied the other original termorder is returned; otherwise a new
     termorder holds both remainders, with recomputed variables and a false
     flag.  The parameters are taken from to1. *)
  fun merge (TO (_,_,[],_)) to = to
    | merge to (TO (_,_,[],_)) = to
    | merge (to1 as TO (parm,_,all1,_)) (to2 as TO (_,_,all2,_)) =
      let
        val () = chatto 1 "merge input1" to1
        val () = chatto 1 "merge input2" to2
        val kept1 = cast_away all2 all1
        val kept2 = cast_away kept1 all2
        val merged =
          case (kept1,kept2) of
            ([],_) => to2
          | (_,[]) => to1
          | _ =>
            let val eqns = kept1 @ kept2
            in TO (parm, calc_vars eqns, eqns, false)
            end
        val () = chatto 1 "merge result" merged
      in
        merged
      end;
end;
395
396(* ------------------------------------------------------------------------- *)
397(* Interface to mlibOmega.                                                       *)
398(* ------------------------------------------------------------------------- *)
399
local
  (* Print each row of integers on its own line (used by the optional
     debugging code in the comments below). *)
  val raw_equations_to_string =
    String.concat o
    map (fn row => PP.pp_to_string (!LINE_LENGTH) (pp_list pp_int) row ^ "\n");

  (* pos_eqns n: n rows of length n+1; each row has a single 1 in a distinct
     variable column, zeros elsewhere, and ~1 in the last column. *)
  fun pos_eqns n =
    let
      fun grow (t,r) = (0 :: t, (1 :: t) :: map (cons 0) r)
    in
      snd (funpow n grow ([~1],[]))
    end;

  (* Remember that list_eqn does partial evaluation on vars *)
  fun omega_eqns vars eqns =
    let val row = list_eqn vars
    in pos_eqns (length vars) @ map row eqns
    end;

  open mlibOmega;

  (* Feed the rows into the database one at a time in continuation-passing
     style; normalc receives the database once every row has been added. *)
  fun mk_db normalc [] db _ = normalc db
    | mk_db normalc (c :: cs) db exc =
      add_check_factoid db (gcd_check_dfactoid (fromList c, ASM ()))
      (mk_db normalc cs) exc;

  (* Run the decision procedure on a non-empty list of rows; the database
     is sized from the length of the first row. *)
  fun check eqns =
    let val width = length (hd eqns)
    in mk_db (fn db => work db I) eqns (dbempty width) I
    end;

  (* Only a derived contradiction counts as inconsistent. *)
  fun inconsistent (CONTR _) = true
    | inconsistent (SATISFIABLE _) = false
    | inconsistent NO_CONCL = false;

  (* Uncomment this check function to print out the linear arithmetic problems
  val THRESHOLD = 1.0;

  fun result_to_string (SATISFIABLE _) = "satisfiable"
    | result_to_string (CONTR _) = "inconsistent"
    | result_to_string NO_CONCL = "no conclusion";

  val check = fn eqns =>
    let
      val (t,r) = timed check eqns
      val () =
        if t < THRESHOLD then () else
          print ("\n\nOMEGA: time = " ^ Real.fmt (StringCvt.FIX (SOME 3)) t ^
                 "s\n" ^ raw_equations_to_string eqns ^
                 "OMEGA: " ^ result_to_string r ^ "\n\n")
    in
      r
    end;
  *)
in
  (* consistent to: NONE if the equations are shown contradictory, otherwise
     SOME of the termorder with its flag set.  A termorder whose flag is
     already set is returned immediately without running the check. *)
  fun consistent (to as TO (parm,vars,eqns,sat)) =
    if sat then SOME to
    else
      let
        val () = chatto 1 "consistent" to
        val verdict = check (omega_eqns vars eqns)
      in
        if inconsistent verdict then NONE
        else SOME (TO (parm,vars,eqns,true))
      end;
(* This bug has now been fixed, but others may appear in the future :-)
    handle Option =>
      (print ("BUG in mlibOmega library: uncaught Option exception" ^
              "\ntermorder was:\n" ^ termorder_to_string to ^
              "\nsent to mlibOmega:\n" ^ raw_equations_to_string (omega_eqns to) ^
              "\n\n"); true)
*)
end;
462
463(* ------------------------------------------------------------------------- *)
464(* Query.                                                                    *)
465(* ------------------------------------------------------------------------- *)
466
(* subsumes to1 to2: true when no equation of to1 fails to be superfluous
   with respect to the equations of to2. *)
fun subsumes (TO (_,_,eqns1,_)) (TO (_,_,eqns2,_)) =
  not (List.exists (fn eqn => not (superfluous eqn eqns2)) eqns1);
469
local
  (* Turn the outcome of mk_eqn into an order option.  For an Equation the
     checks run in this order: the equation itself inconsistent gives
     GREATER, its negation inconsistent gives LESS, the equation strictly
     superfluous gives LESS, its negation strictly superfluous gives
     GREATER; otherwise the comparison is undetermined. *)
  fun decide _ Equal = SOME EQUAL
    | decide _ Less = SOME LESS
    | decide _ Greater = SOME GREATER
    | decide eqns (Equation eqn) =
      let
        val negated = M.compl eqn
      in
        if inconsistent_eqn eqn then SOME GREATER
        else if inconsistent_eqn negated then SOME LESS
        else if strictly_superfluous eqn eqns then SOME LESS
        else if strictly_superfluous negated eqns then SOME GREATER
        else NONE
      end;
in
  (* compare to (l,r): how l relates to r under the constraints held in to,
     or NONE when they do not settle it.  Input and result are traced when
     tracing is enabled at level 1. *)
  fun compare (to as TO (parm,_,eqns,_)) lr =
    let
      val () = chatto 1 "compare input" to
      val () =
        if chatting 1 then
          ignore
            (chat ("comparing " ^ term_to_string (fst lr) ^
                   " and " ^ term_to_string (snd lr) ^ "\n"))
        else ()
      val res = decide eqns (mk_eqn parm lr)
      val () =
        if chatting 1 then
          ignore
            (chat ("compare result = " ^ partialorder_to_string res ^ "\n"))
        else ()
    in
      res
    end;
end;
498
499(* ------------------------------------------------------------------------- *)
500(* Name binding.                                                             *)
501(* ------------------------------------------------------------------------- *)
502
(* Public name for tnull; bound last so the list version of null stays
   available to the definitions above. *)
fun null to = tnull to;
504
505(* Quick testing
506app load ["mlibThm"];
507val T = parse_term;
508val F = parse_formula;
509installPP pp_termorder;
510installPP mlibSubst.pp_subst;
511installPP mlibThm.pp_thm;
512
513val to = empty defaults;
514val to = try (C add_leq to) (T`f x (f y z)`, T`f (f x y) z`);
515val x = (total o try) (C add_leq to) (T`f (f x y) z`, T`f x (f y z)`);
516val to = try (C add_leq to) (T`P (f a b)`, T`P x`);
517val to = try (C add_leq to) (T`P y`, T`P (g a b c)`);
518val to = try (C add_leq to) (T`x + y`, T`y + x`);
519val c = consistent to;
520val to = try (subst (("x" |-> T`v`) ::> |<>|)) to;
521val to = try (subst (("v" |-> T`f x x x x a a a a`) ::> |<>|)) to;
522val c = consistent to;
523
524val to = empty defaults;
525val to = try (C add_leq to) (T`P y`, T`P (g a b c d e f)`);
526val to = try (C add_leq to) (T`x + y`, T`y + x`);
527val to = try (C add_leq to) (T`x + y`, T`y + x`);
528val to = try (subst (("x" |-> T`f x x x`) ::> |<>|)) to;
529val c = consistent to;
530val to = try (subst (("x" |-> T`f w v`) ::> |<>|)) to;
531val c = consistent to;
532
533val to = empty defaults;
534val to = try (C add_leq to) (T`f x y`, T`f y x`);
535val to = try (subst (("x" |-> T`f x`) ::> |<>|)) to;
536val x = compare to (T`f x`, T`g y`);
537val x = compare to (T`g x`, T`f y`);
538val x = compare to (T`g a`, T`f a`);
539val x = compare to (T`f a`, T`g a`);
540val th =
541  mlibThm.ORD_REWRITE (compare to)
542  (map (mlibThm.AXIOM o wrap o F)
543   [`x + (y + z) = y + (x + z)`, `(x + y) + z = x + (y + z)`])
544  (mlibThm.AXIOM [F`P (y + x + y + x + y + x + 0)`]);
545*)
546
547end
548