1(*****************************************************************************)
2(* FILE          : solve_ineqs.sml                                           *)
3(* DESCRIPTION   : Functions for solving inequalities.                       *)
4(*                                                                           *)
5(* READS FILES   : <none>                                                    *)
6(* WRITES FILES  : <none>                                                    *)
7(*                                                                           *)
8(* AUTHOR        : R.J.Boulton, University of Cambridge                      *)
9(* DATE          : 4th March 1991                                            *)
10(*                                                                           *)
11(* TRANSLATOR    : R.J.Boulton, University of Cambridge                      *)
12(* DATE          : 5th February 1993                                         *)
13(*                                                                           *)
14(* LAST MODIFIED : R.J.Boulton                                               *)
15(* DATE          : 7th August 1996                                           *)
16(*****************************************************************************)
17
18structure Solve_ineqs :> Solve_ineqs =
19struct
20  open Arbint HolKernel boolLib
21       Int_extra Arith_cons Term_coeffs RJBConv Theorems Thm_convs
22       Norm_arith Norm_ineqs reduceLib;
23
(* `<<` is the strict ordering on strings (String.<).  It is declared infix, *)
(* together with THENC and ##, which are used infix throughout this file.    *)
val op << = String.<
infix << THENC ##;

(* Short local names for Num_conv.num_CONV and Drule.MATCH_MP.  Presumably   *)
(* bound explicitly so that these particular versions are used whatever the  *)
(* opened structures export -- TODO confirm.                                 *)
val num_CONV = Num_conv.num_CONV;
val MATCH_MP = Drule.MATCH_MP;
29
(* failwith : string -> 'a                                                   *)
(* Raises a HOL_ERR whose origin is this structure and the named function.   *)
(* The message field is always the empty string.                             *)
fun failwith function =
   raise HOL_ERR {message = "",
                  origin_function = function,
                  origin_structure = "Solve_ineqs"};
33
34
35(*===========================================================================*)
36(* Multiplying normalized arithmetic expressions by a constant               *)
37(*===========================================================================*)
38
39(*---------------------------------------------------------------------------*)
40(* CONST_TIMES_ARITH_CONV : conv                                             *)
41(*                                                                           *)
42(* Converts the product of a constant and a normalized arithmetic expression *)
43(* to a new normalized arithmetic expression.                                *)
44(*                                                                           *)
45(* Example:                                                                  *)
46(*                                                                           *)
47(*    CONST_TIMES_ARITH_CONV `3 * (1 + (3 * x) + (2 * y))`  --->             *)
48(*    |- 3 * (1 + ((3 * x) + (2 * y))) = 3 + ((9 * x) + (6 * y))             *)
49(*---------------------------------------------------------------------------*)
50
fun CONST_TIMES_ARITH_CONV tm =
 (let (* Apply a conversion to the left operand of a binary operator. *)
      fun LEFT_ARG_CONV conv = RATOR_CONV (RAND_CONV conv)

      (* For a term `c * (k * v)`: reassociate with MULT_ASSOC_CONV, then    *)
      (* evaluate the left operand of the result with FAST_MULT_CONV.        *)
      val SCALE_PRODUCT_CONV =
         MULT_ASSOC_CONV THENC LEFT_ARG_CONV FAST_MULT_CONV

      (* Handles `c * body` where body is a single product or a sum of       *)
      (* products; for a sum, distributes and recurses down the right        *)
      (* operand.                                                            *)
      fun SCALE_VARS_CONV t =
         let val step =
                if (is_mult (arg2 t))
                then SCALE_PRODUCT_CONV
                else LEFT_ADD_DISTRIB_CONV THENC
                     LEFT_ARG_CONV SCALE_PRODUCT_CONV THENC
                     RAND_CONV SCALE_VARS_CONV
         in  step t
         end

      val body = arg2 tm

      (* Select the conversion from the shape of the multiplicand: a bare    *)
      (* constant, a single product, a sum led by a constant, or a sum of    *)
      (* products.                                                           *)
      val conv =
         if (is_num_const body) then FAST_MULT_CONV
         else if (is_mult body) then SCALE_PRODUCT_CONV
         else if (is_num_const (arg1 body))
         then LEFT_ADD_DISTRIB_CONV THENC
              LEFT_ARG_CONV FAST_MULT_CONV THENC
              RAND_CONV SCALE_VARS_CONV
         else SCALE_VARS_CONV
  in  conv tm
  end
 ) handle (HOL_ERR _) => failwith "CONST_TIMES_ARITH_CONV";
74
75(*---------------------------------------------------------------------------*)
76(* MULT_LEQ_BY_CONST_CONV : term -> conv                                     *)
77(*                                                                           *)
78(* Multiplies both sides of a normalized inequality by a non-zero constant.  *)
79(*                                                                           *)
80(* Example:                                                                  *)
81(*                                                                           *)
82(*    MULT_LEQ_BY_CONST_CONV `3` `(1 + (3 * x) + (2 * y)) <= (3 * z)`  --->  *)
83(*    |- (1 + ((3 * x) + (2 * y))) <= (3 * z) =                              *)
84(*       (3 + ((9 * x) + (6 * y))) <= (9 * z)                                *)
85(*---------------------------------------------------------------------------*)
86
fun MULT_LEQ_BY_CONST_CONV constant tm =
 (let val (lhs_tm,rhs_tm) = dest_leq tm
      val n = int_of_term constant

      (* Apply a conversion to the left operand of a binary operator. *)
      fun LEFT_ARG_CONV conv = RATOR_CONV (RAND_CONV conv)

      (* General case (n is neither zero nor one). *)
      fun scale_by_constant () =
         let val pred_tm = term_of_int (n - one)
             (* num_CONV's theorem for the constant, turned round so that it *)
             (* can be used as a left-to-right rewrite.                      *)
             val fold_th = SYM (num_CONV constant)
             (* MULT_LEQ_SUC specialised to the two sides and to n - 1. *)
             val suc_th = SPEC pred_tm (SPEC rhs_tm (SPEC lhs_tm MULT_LEQ_SUC))
             (* Replace the multiplier on each side of the new inequality by *)
             (* the right-hand side of fold_th.                              *)
             val FOLD_CONV = LEFT_ARG_CONV (fn _ => fold_th)
             val refold_th =
                (LEFT_ARG_CONV FOLD_CONV THENC RAND_CONV FOLD_CONV)
                (rhs (concl suc_th))
         in  ((fn _ => TRANS suc_th refold_th) THENC
              ARGS_CONV CONST_TIMES_ARITH_CONV) tm
         end
  in  if (n = zero) then failwith "fail"
      else if (n = one) then ALL_CONV tm
      else scale_by_constant ()
  end
 ) handle (HOL_ERR _) => failwith "MULT_LEQ_BY_CONST_CONV";
106
107(*===========================================================================*)
108(* Solving inequalities between constants                                    *)
109(*===========================================================================*)
110
111(*---------------------------------------------------------------------------*)
112(* LEQ_CONV : conv                                                           *)
113(*                                                                           *)
114(* Given a term of the form `a <= b` where a and b are constants, returns    *)
115(* the theorem |- (a <= b) = T or the theorem |- (a <= b) = F depending on   *)
116(* the values of a and b.                                                    *)
117(*                                                                           *)
118(* Optimized for when one or both of the arguments is zero.                  *)
119(*---------------------------------------------------------------------------*)
120
(* LEQ_CONV is an alias for the conversion LE_CONV from reduceLib. *)
val LEQ_CONV = reduceLib.LE_CONV
122
123(*===========================================================================*)
124(* Eliminating variables from sets of inequalities                           *)
125(*===========================================================================*)
126
127(*---------------------------------------------------------------------------*)
128(* WEIGHTED_SUM :                                                            *)
129(*    string ->                                                              *)
130(*    ((int * (string * int) list) * (int * (string * int) list)) ->         *)
131(*    ((int * (string * int) list) * (unit -> thm))                          *)
132(*                                                                           *)
133(* Function to eliminate a specified variable from two inequalities by       *)
134(* forming their weighted sum. The inequalities must be given as bindings.   *)
135(* The result is a pair. The first component is a binding representing the   *)
136(* combined inequality, and the second component is a function. When applied *)
137(* to ():unit this function returns a theorem which states that under the    *)
138(* assumption that the two original inequalities are true, then the          *)
139(* resultant inequality is true.                                             *)
140(*                                                                           *)
141(* The variable to be eliminated should be on the right-hand side of the     *)
142(* first inequality, and on the left-hand side of the second.                *)
143(*                                                                           *)
144(* Example:                                                                  *)
145(*                                                                           *)
146(*    WEIGHTED_SUM `y` ((1,[(`x`, -3);(`y`, 1)]), (3,[(`x`, -3);(`y`, -1)])) *)
147(*    --->                                                                   *)
148(*    ((4, [(`x`, -6)]), -)                                                  *)
149(*                                                                           *)
150(*    snd it ()  --->                                                        *)
151(*    (3 * x) <= (1 + (1 * y)), ((3 * x) + (1 * y)) <= 3 |- (6 * x) <= 4     *)
152(*---------------------------------------------------------------------------*)
153
fun WEIGHTED_SUM name (coeffs1,coeffs2) =
 (let (* Coefficient of the variable in each inequality; the second is      *)
      (* negated before use.                                                 *)
      val coeff1 = assoc name (snd coeffs1)
      val coeff2 = zero - (assoc name (snd coeffs2))

      (* Multipliers that bring both coefficients up to their lcm. *)
      val mult = lcm (coeff1,coeff2)
      val mult1 = mult div coeff1
      val mult2 = mult div coeff2

      (* Multiply the constant and every coefficient of a binding by k. *)
      fun scale k (const,bind) = (const * k,map (fn (s,n) => (s,n * k)) bind)
      val scaled1 = scale mult1 coeffs1
      val scaled2 = scale mult2 coeffs2

      val (adds1,adds2) = diff_of_coeffs (scaled1,scaled2)
      val left1 = merge_coeffs adds1 (lhs_coeffs scaled1)
      val right2 = merge_coeffs adds2 (rhs_coeffs scaled2)
      val new_coeffs = merge_coeffs (negate_coeffs left1) right2

      (* The proof is delayed: nothing below runs until thf is applied. *)
      fun thf () =
         let val TIDY_CONV = SORT_AND_GATHER_CONV THENC NORM_ZERO_AND_ONE_CONV
             (* Apply a conversion to the left operand of a binary operator. *)
             fun LEFT_ARG_CONV conv = RATOR_CONV (RAND_CONV conv)
             fun unchanged adds = (adds = (zero,[]))

             (* Equation between the original inequality built from coeffs   *)
             (* and its scaled form with adds added in.  SIDE_CONV selects   *)
             (* the side of the inequality that is tidied afterwards.        *)
             fun normalise (m,adds,SIDE_CONV,coeffs) =
                let val SCALE_CONV =
                       if (m = one)
                       then ALL_CONV
                       else MULT_LEQ_BY_CONST_CONV (term_of_int m)
                    val SHIFT_CONV =
                       if (unchanged adds)
                       then ALL_CONV
                       else (ADD_COEFFS_TO_LEQ_CONV adds) THENC
                            (SIDE_CONV TIDY_CONV)
                in  RULE_OF_CONV (SCALE_CONV THENC SHIFT_CONV)
                                 (build_leq coeffs)
                end

             val th1 = normalise (mult1,adds1,RAND_CONV,coeffs1)
             val th2 = normalise (mult2,adds2,LEFT_ARG_CONV,coeffs2)

             (* Each original inequality becomes an assumption. *)
             val hyp1 = UNDISCH (fst (EQ_IMP_RULE th1))
             val hyp2 = UNDISCH (fst (EQ_IMP_RULE th2))
             val th = CONJ hyp1 hyp2

             val th1conv =
                if (unchanged adds1)
                then ALL_CONV
                else LEFT_ARG_CONV TIDY_CONV
             val th2conv =
                if (unchanged adds2)
                then ALL_CONV
                else RAND_CONV TIDY_CONV
         in  CONV_RULE (th1conv THENC th2conv THENC LESS_OR_EQ_GATHER_CONV)
                       (MATCH_MP LESS_EQ_TRANSIT th)
         end
  in  (new_coeffs,thf)
  end
 ) handle (HOL_ERR _) => failwith "WEIGHTED_SUM";
211
212(*---------------------------------------------------------------------------*)
213(* var_to_elim : ('a * (string * int) list) list -> string                   *)
214(*                                                                           *)
215(* Given a list of inequalities (as bindings), this function determines      *)
216(* which variable to eliminate. Such a variable must occur in two            *)
(* inequalities on different sides. The variable chosen is the one that      *)
(* gives rise to the least number of pairings.                               *)
219(*---------------------------------------------------------------------------*)
220
(* Chooses the variable to eliminate from a list of bindings.               *)
(*                                                                           *)
(* For each variable, the inequalities in which its coefficient is negative *)
(* and those in which it is positive are counted.  Among variables that     *)
(* occur with both signs, the one minimising                                 *)
(*    (negatives * positives) - (negatives + positives)                      *)
(* is returned; on a tie the variable met first in the merged lists wins.   *)
(*                                                                           *)
(* Fails with HOL_ERR (origin_function "var_to_elim") if no variable occurs *)
(* with both signs or if any helper fails.  Interrupt is deliberately       *)
(* re-raised and never converted into a HOL_ERR.                            *)
fun var_to_elim coeffsl =
 (let (* Split one binding into (names with negative coefficient,           *)
      (* names with positive coefficient), each tagged with a count of one. *)
      (* Zero coefficients are dropped; the original order is preserved.    *)
      fun occurrences bind =
         List.foldr
            (fn ((name,coeff),(occsl,occsr)) =>
                if (coeff < zero) then ((name, one)::occsl,occsr)
                else if (coeff > zero) then (occsl,(name, one)::occsr)
                else (occsl,occsr))
            ([],[])
            bind

      (* Walks the two count lists in step.  Both are assumed to be sorted  *)
      (* by name under `<<` (presumably guaranteed by merge_coeffs -- TODO  *)
      (* confirm).  Returns NONE when they have no name in common, rather   *)
      (* than signalling that through an exception.                         *)
      fun min_increase (bind1 as ((name1:string,num1:int)::rest1))
                       (bind2 as ((name2,num2)::rest2)) =
             if (name1 = name2) then
                let val increase = (num1 * num2) - (num1 + num2)
                in  case min_increase rest1 rest2
                    of SOME (name,best) =>
                          if (best < increase)
                          then SOME (name,best)
                          else SOME (name1,increase)
                     | NONE => SOME (name1,increase)
                end
             else if (name1 << name2) then min_increase rest1 bind2
             else min_increase bind1 rest2
        | min_increase _ _ = NONE

      (* Combine the per-inequality count lists into a single list. *)
      val merge =
         end_itlist (fn b1 => fn b2 => snd (merge_coeffs (zero,b1) (zero,b2)))
      val occs = map (occurrences o snd) coeffsl
      val (occsl,occsr) = (merge ## merge) (split occs)
  in  case min_increase occsl occsr
      of SOME (name,_) => name
       | NONE => failwith "var_to_elim"
  end
 ) handle Interrupt => raise Interrupt
        | _ => failwith "var_to_elim";
253
254(*---------------------------------------------------------------------------*)
255(* VAR_ELIM : (int * (string * int) list) list -> (int list * (unit -> thm)) *)
256(*                                                                           *)
257(* Given a list of inequalities represented by bindings, this function       *)
258(* returns a `lazy' theorem with false (actually an inequality between       *)
259(* constants that can immediately be shown to be false) as its conclusion,   *)
260(* and some of the inequalities as assumptions. A list of numbers is also    *)
261(* returned. These are the positions in the argument list of the             *)
262(* inequalities that are assumptions of the theorem. The function fails if   *)
263(* the set of inequalities is satisfiable.                                   *)
264(*                                                                           *)
265(* The function assumes that none of the inequalities given are false, that  *)
266(* is they either contain variables, or evaluate to true. Those that are     *)
267(* true are filtered out. The function then determines which variable to     *)
268(* eliminate and splits the remaining inequalities into three sets: ones in  *)
269(* which the variable occurs on the left-hand side, ones in which the        *)
270(* variable occurs on the right, and ones in which the variable does not     *)
271(* occur.                                                                    *)
272(*                                                                           *)
273(* Pairings of the `right' and `left' inequalities are then made, and the    *)
274(* weighted sum of each is determined, except that as soon as a pairing      *)
275(* yields false, the process is terminated. It may well be the case that no  *)
276(* pairing gives false. In this case, the new inequalities are added to the  *)
277(* inequalities that did not contain the variable, and a recursive call is   *)
278(* made.                                                                     *)
279(*                                                                           *)
280(* The list of numbers from the recursive call (representing assumptions) is *)
281(* split into two: those that point to inequalities that were produced by    *)
282(* doing weighted sums, and those that were not. The latter can be traced    *)
283(* back so that their positions in the original argument list can be         *)
284(* returned. The other inequalities have to be discharged from the theorem   *)
285(* using the theorems proved by performing weighted sums. Each assumption    *)
286(* thus gives rise to two new assumptions and the conclusion remains false.  *)
287(* The positions of the two new assumptions in the original argument list    *)
288(* are added to the list to be returned. Duplicates are removed from this    *)
289(* list before returning it.                                                 *)
290(*---------------------------------------------------------------------------*)
291
(* Attempts to derive a false inequality between constants from a list of   *)
(* inequalities given as bindings, by repeated variable elimination.        *)
(*                                                                           *)
(* Returns (positions, thf).  The positions are 1-based indices into        *)
(* coeffsl of the inequalities used, without duplicates; thf is a delayed   *)
(* proof that yields the theorem when applied to ().                        *)
(*                                                                           *)
(* Fails with HOL_ERR (origin_function "VAR_ELIM") when no contradiction is *)
(* found.  Interrupt is re-raised unchanged so that a user interrupt is not *)
(* mistaken for a failed proof attempt.  Exceptions raised later, when thf  *)
(* is applied, are outside this function's handler.                         *)
fun VAR_ELIM coeffsl =
 let (* The list from, from + 1, ..., to; empty when from > to. *)
     fun upto from to =
        if (from > to)
        then []
        else from::(upto (from + one) to)

     (* Does the indexed inequality mention var with a coefficient that     *)
     (* satisfies test?                                                     *)
     fun mentions test var (_,(_,bind)) =
        List.exists (fn (name,coeff) => (name = var) andalso (test coeff))
                    bind

     (* Partition helpers: coefficient of var negative, positive, or        *)
     (* absent/zero.                                                        *)
     fun left_ineqs var icoeffsl =
        filter (mentions (fn coeff => coeff < zero) var) icoeffsl
     fun right_ineqs var icoeffsl =
        filter (mentions (fn coeff => coeff > zero) var) icoeffsl
     fun no_var_ineqs var icoeffsl =
        filter (not o (mentions (fn coeff => not (coeff = zero)) var))
               icoeffsl

     (* Every pairing of a `right' inequality with a `left' one, grouped by *)
     (* the right inequality.                                               *)
     fun pair_ineqs (ricoeffs,licoeffs) =
        flatten (map (fn r => map (fn l => (r,l)) licoeffs) ricoeffs)

     (* Weighted sum of every pair.  The tail is processed before the head. *)
     (* As soon as some pair yields a false inequality (no variables and a  *)
     (* negative constant) the result is (true,[that entry]) and no further *)
     (* pairs are combined.                                                 *)
     fun weighted_sums var [] = (false,[])
       | weighted_sums var (((index1,coeffs1),(index2,coeffs2))::pairs) =
            (case weighted_sums var pairs
             of (true,found) => (true,found)
              | (false,rest) =>
                   let val (coeffs as (const,bind),f) =
                          WEIGHTED_SUM var (coeffs1,coeffs2)
                       val entry = ((index1,index2),(coeffs,f))
                   in  if ((null bind) andalso (const < zero))
                       then (true,[entry])
                       else (false,entry::rest)
                   end)

     (* For each index into ineqs, discharge the corresponding assumption   *)
     (* of the delayed theorem using the weighted-sum proof, and collect    *)
     (* the indices of the two inequalities that the sum was built from.    *)
     fun chain_assums ineqs thf [] = ([],thf)
       | chain_assums ineqs thf (index::indexl) =
            let val (prev_indexl,thf') = chain_assums ineqs thf indexl
                val ((index1,index2),(_,f)) = el (toInt index) ineqs
            in  (index1::index2::prev_indexl,
                 fn () => PROVE_HYP (f ()) (thf' ()))
            end
 in
 (let (* Number the inequalities from one and drop those with no variables. *)
      val icoeffsl = combine (upto one (fromInt (length coeffsl)), coeffsl)
      val icoeffsl' = filter (fn (_,(_,bind)) => not (null bind)) icoeffsl
      val var = var_to_elim (map snd icoeffsl')
      val ricoeffs = right_ineqs var icoeffsl'
      val licoeffs = left_ineqs var icoeffsl'
      val nicoeffs = no_var_ineqs var icoeffsl'
      val pairs = pair_ineqs (ricoeffs,licoeffs)
      val (success,new_ineqs) = weighted_sums var pairs
  in  if success
      then
        (case new_ineqs
         of [((index1,index2),(_,thf))] => ([index1,index2],thf)
          | _ => raise Match)
      else let (* The recursive call sees the new inequalities first        *)
               (* (positions 1..n), then those not mentioning var.          *)
               val n = fromInt (length new_ineqs)
               val new_coeffs =
                  (map (fst o snd) new_ineqs) @ (map snd nicoeffs)
               val (indexl,thf) = VAR_ELIM new_coeffs
               val (prev_indexl,these_indexl) =
                  Lib.partition (fn i => i > n) indexl
               (* Positions beyond n map straight back to original indices. *)
               val prev_indexl' =
                  map (fn i => fst (el (toInt (i - n)) nicoeffs)) prev_indexl
               val (these_indexl',thf') =
                  chain_assums new_ineqs thf these_indexl
           in  (Lib.mk_set (these_indexl' @ prev_indexl'),thf')
           end
  end
 ) handle Interrupt => raise Interrupt
        | _ => failwith "VAR_ELIM"
 end;
378
379end
380