(*****************************************************************************)
(* FILE          : solve_ineqs.sml                                           *)
(* DESCRIPTION   : Functions for solving inequalities.                       *)
(*                                                                           *)
(* READS FILES   : <none>                                                    *)
(* WRITES FILES  : <none>                                                    *)
(*                                                                           *)
(* AUTHOR        : R.J.Boulton, University of Cambridge                      *)
(* DATE          : 4th March 1991                                            *)
(*                                                                           *)
(* TRANSLATOR    : R.J.Boulton, University of Cambridge                      *)
(* DATE          : 5th February 1993                                         *)
(*                                                                           *)
(* LAST MODIFIED : R.J.Boulton                                               *)
(* DATE          : 7th August 1996                                           *)
(*****************************************************************************)

structure Solve_ineqs :> Solve_ineqs =
struct
  open Arbint HolKernel boolLib
       Int_extra Arith_cons Term_coeffs RJBConv Theorems Thm_convs
       Norm_arith Norm_ineqs reduceLib;

(* `<<' is the strict ordering on strings; it is used below to walk two      *)
(* lists of variable names in step.                                          *)
val op << = String.<
infix << THENC ##;

(* Local short names for entry points of other structures.                   *)
val num_CONV = Num_conv.num_CONV;
val MATCH_MP = Drule.MATCH_MP;

(* Raises a HOL_ERR that names this structure as its origin and the given    *)
(* function as the failing function.  The message field is always empty.     *)
fun failwith fname =
   raise HOL_ERR {origin_structure = "Solve_ineqs",
                  origin_function = fname,
                  message = ""};


(*===========================================================================*)
(* Multiplying normalized arithmetic expressions by a constant               *)
(*===========================================================================*)

(*---------------------------------------------------------------------------*)
(* CONST_TIMES_ARITH_CONV : conv                                             *)
(*                                                                           *)
(* Converts the product of a constant and a normalized arithmetic expression *)
(* to a new normalized arithmetic expression.
*)
(*                                                                           *)
(* Example:                                                                  *)
(*                                                                           *)
(*    CONST_TIMES_ARITH_CONV `3 * (1 + (3 * x) + (2 * y))` --->              *)
(*    |- 3 * (1 + ((3 * x) + (2 * y))) = 3 + ((9 * x) + (6 * y))             *)
(*                                                                           *)
(* Case analysis on the second argument of the product:                      *)
(*    - a numeric constant : FAST_MULT_CONV is applied directly;             *)
(*    - a product          : MULT_ASSOC_CONV, then FAST_MULT_CONV on the     *)
(*                           resulting left factor;                          *)
(*    - a sum whose first summand is a numeric constant :                    *)
(*                           LEFT_ADD_DISTRIB_CONV, FAST_MULT_CONV on the    *)
(*                           first summand, then the remaining summands are  *)
(*                           processed recursively;                          *)
(*    - anything else      : treated as a sum of products, recursively.      *)
(*                                                                           *)
(* Any HOL_ERR raised on the way is re-raised with this function's name.     *)
(*---------------------------------------------------------------------------*)

fun CONST_TIMES_ARITH_CONV tm =
 (let (* For a term whose second argument is itself a product: reassociate,  *)
      (* then evaluate the new left-hand factor with FAST_MULT_CONV.         *)
      fun SCALE_PRODUCT_CONV t =
         (MULT_ASSOC_CONV THENC
          (RATOR_CONV (RAND_CONV FAST_MULT_CONV))) t
      (* Second argument is either a single product or a sum of products.    *)
      fun SCALE_VARS_CONV t =
         if (is_mult (arg2 t))
         then SCALE_PRODUCT_CONV t
         else (LEFT_ADD_DISTRIB_CONV THENC
               (RATOR_CONV (RAND_CONV SCALE_PRODUCT_CONV)) THENC
               (RAND_CONV SCALE_VARS_CONV)) t
      val body = arg2 tm
  in  if (is_num_const body) then FAST_MULT_CONV tm
      else if (is_mult body) then SCALE_PRODUCT_CONV tm
      else if (is_num_const (arg1 body)) then
         (LEFT_ADD_DISTRIB_CONV THENC
          (RATOR_CONV (RAND_CONV FAST_MULT_CONV)) THENC
          (RAND_CONV SCALE_VARS_CONV)) tm
      else SCALE_VARS_CONV tm
  end
 ) handle (HOL_ERR _) => failwith "CONST_TIMES_ARITH_CONV";

(*---------------------------------------------------------------------------*)
(* MULT_LEQ_BY_CONST_CONV : term -> conv                                     *)
(*                                                                           *)
(* Multiplies both sides of a normalized inequality by a non-zero constant.
*)
(*                                                                           *)
(* Example:                                                                  *)
(*                                                                           *)
(*    MULT_LEQ_BY_CONST_CONV `3` `(1 + (3 * x) + (2 * y)) <= (3 * z)` --->   *)
(*    |- (1 + ((3 * x) + (2 * y))) <= (3 * z) =                              *)
(*       (3 + ((9 * x) + (6 * y))) <= (9 * z)                                *)
(*                                                                           *)
(* A constant of zero is rejected; a constant of one leaves the term alone   *)
(* (ALL_CONV).  Any HOL_ERR is re-raised with this function's name.          *)
(*---------------------------------------------------------------------------*)

fun MULT_LEQ_BY_CONST_CONV constant tm =
 (let val (tm1,tm2) = dest_leq tm
      and n = int_of_term constant
      (* General case (n is neither zero nor one). *)
      fun scale () =
         let val pred_tm = term_of_int (n - one)
             (* NOTE(review): presumably num_CONV yields                     *)
             (* |- constant = SUC pred_tm, so that this is                   *)
             (* |- SUC pred_tm = constant -- confirm against Num_conv.       *)
             val suc_eq = SYM (num_CONV constant)
             val th1 = SPEC pred_tm (SPEC tm2 (SPEC tm1 MULT_LEQ_SUC))
             (* Replaces whatever term it is given by the rhs of suc_eq.     *)
             fun FOLD_SUC_CONV _ = suc_eq
             (* Applies FOLD_SUC_CONV to the first argument of a binary      *)
             (* operator application.                                        *)
             fun FACTOR_CONV t = RATOR_CONV (RAND_CONV FOLD_SUC_CONV) t
             val th2 =
                ((RATOR_CONV (RAND_CONV FACTOR_CONV)) THENC
                 (RAND_CONV FACTOR_CONV))
                (rhs (concl th1))
         in  ((fn _ => TRANS th1 th2) THENC
              (ARGS_CONV CONST_TIMES_ARITH_CONV)) tm
         end
  in  if (n = zero) then failwith "fail"
      else if (n = one) then ALL_CONV tm
      else scale ()
  end
 ) handle (HOL_ERR _) => failwith "MULT_LEQ_BY_CONST_CONV";

(*===========================================================================*)
(* Solving inequalities between constants                                    *)
(*===========================================================================*)

(*---------------------------------------------------------------------------*)
(* LEQ_CONV : conv                                                           *)
(*                                                                           *)
(* Given a term of the form `a <= b` where a and b are constants, returns    *)
(* the theorem |- (a <= b) = T or the theorem |- (a <= b) = F depending on   *)
(* the values of a and b.                                                    *)
(*                                                                           *)
(* Optimized for when one or both of the arguments is zero.
*)
(*                                                                           *)
(* NOTE(review): the conversion is now simply reduceLib.LE_CONV; whether the *)
(* remark about zero arguments still applies depends on that implementation  *)
(* and cannot be confirmed from this file.                                   *)
(*---------------------------------------------------------------------------*)

fun LEQ_CONV tm = reduceLib.LE_CONV tm

(*===========================================================================*)
(* Eliminating variables from sets of inequalities                           *)
(*===========================================================================*)

(*---------------------------------------------------------------------------*)
(* WEIGHTED_SUM :                                                            *)
(*    string ->                                                              *)
(*    ((int * (string * int) list) * (int * (string * int) list)) ->         *)
(*    ((int * (string * int) list) * (unit -> thm))                          *)
(*                                                                           *)
(* Function to eliminate a specified variable from two inequalities by       *)
(* forming their weighted sum. The inequalities must be given as bindings.   *)
(* The result is a pair. The first component is a binding representing the   *)
(* combined inequality, and the second component is a function. When applied *)
(* to ():unit this function returns a theorem which states that under the    *)
(* assumption that the two original inequalities are true, then the          *)
(* resultant inequality is true.                                             *)
(*                                                                           *)
(* The variable to be eliminated should be on the right-hand side of the     *)
(* first inequality, and on the left-hand side of the second.
*)
(*                                                                           *)
(* Example:                                                                  *)
(*                                                                           *)
(*    WEIGHTED_SUM `y` ((1,[(`x`, -3);(`y`, 1)]), (3,[(`x`, -3);(`y`, -1)])) *)
(*    --->                                                                   *)
(*    ((4, [(`x`, -6)]), -)                                                  *)
(*                                                                           *)
(*    snd it () --->                                                         *)
(*    (3 * x) <= (1 + (1 * y)), ((3 * x) + (1 * y)) <= 3 |- (6 * x) <= 4     *)
(*                                                                           *)
(* The new binding is computed eagerly; all theorem proving is deferred to   *)
(* the returned function.  A HOL_ERR raised while computing the binding is   *)
(* re-raised with this function's name; errors raised later, when the        *)
(* returned function is applied, are not wrapped.                            *)
(*---------------------------------------------------------------------------*)

fun WEIGHTED_SUM name (coeffs1,coeffs2) =
 (let val coeff1 = assoc name (snd coeffs1)
      and coeff2 = zero - (assoc name (snd coeffs2))
      (* Scale both inequalities so that the coefficients of `name' agree.   *)
      val mult = lcm (coeff1,coeff2)
      val mult1 = mult div coeff1
      and mult2 = mult div coeff2
      val coeffs1' =
         ((fn n => n * mult1) ## (map (fn (s,n) => (s,n * mult1)))) coeffs1
      and coeffs2' =
         ((fn n => n * mult2) ## (map (fn (s,n) => (s,n * mult2)))) coeffs2
      val (adds1,adds2) = diff_of_coeffs (coeffs1',coeffs2')
      val coeffs1'' = merge_coeffs adds1 (lhs_coeffs coeffs1')
      and coeffs2'' = merge_coeffs adds2 (rhs_coeffs coeffs2')
      val new_coeffs = merge_coeffs (negate_coeffs coeffs1'') coeffs2''
      fun thf () =
         let (* Re-normalizes a sum to which extra terms have been added.    *)
             fun TIDY_CONV t =
                (SORT_AND_GATHER_CONV THENC NORM_ZERO_AND_ONE_CONV) t
             (* Multiplies an inequality by m; the identity when m is one.   *)
             fun SCALE_CONV m =
                if (m = one)
                then ALL_CONV
                else MULT_LEQ_BY_CONST_CONV (term_of_int m)
             fun nothing_to_add adds = (adds = (zero,[]))
             (* One direction of an equation, with its antecedent moved to   *)
             (* the assumptions.                                             *)
             fun forward eq_th = UNDISCH (fst (EQ_IMP_RULE eq_th))
             val th1 =
                RULE_OF_CONV
                   ((SCALE_CONV mult1) THENC
                    (if (nothing_to_add adds1)
                     then ALL_CONV
                     else (ADD_COEFFS_TO_LEQ_CONV adds1) THENC
                          (RAND_CONV TIDY_CONV)))
                   (build_leq coeffs1)
             val th2 =
                RULE_OF_CONV
                   ((SCALE_CONV mult2) THENC
                    (if (nothing_to_add adds2)
                     then ALL_CONV
                     else (ADD_COEFFS_TO_LEQ_CONV adds2) THENC
                          (RATOR_CONV (RAND_CONV TIDY_CONV))))
                   (build_leq coeffs2)
             val th = CONJ (forward th1) (forward th2)
             (* After the transitivity step the side that received extra     *)
             (* terms but was not tidied above still needs normalizing.      *)
             val th1conv =
                if (nothing_to_add adds1)
                then ALL_CONV
                else RATOR_CONV (RAND_CONV TIDY_CONV)
             and th2conv =
                if (nothing_to_add adds2)
                then ALL_CONV
                else RAND_CONV TIDY_CONV
         in  CONV_RULE (th1conv THENC th2conv THENC LESS_OR_EQ_GATHER_CONV)
                (MATCH_MP LESS_EQ_TRANSIT th)
         end
  in  (new_coeffs,thf)
  end
 ) handle (HOL_ERR _) => failwith "WEIGHTED_SUM";

(*---------------------------------------------------------------------------*)
(* var_to_elim : ('a * (string * int) list) list -> string                   *)
(*                                                                           *)
(* Given a list of inequalities (as bindings), this function determines      *)
(* which variable to eliminate. Such a variable must occur in two            *)
(* inequalities on different sides. The variable chosen is the one that      *)
(* gives rise to the least number of pairings.                               *)
(*                                                                           *)
(* Fails with a HOL_ERR naming "var_to_elim" if no variable occurs on both   *)
(* sides, or if any other exception is raised while searching.  A user       *)
(* interrupt is passed through unchanged rather than being turned into a     *)
(* HOL_ERR.                                                                  *)
(*---------------------------------------------------------------------------*)

fun var_to_elim coeffsl =
 (let (* Splits one binding into the names occurring with a negative         *)
      (* coefficient and those occurring with a positive one, each tagged    *)
      (* with an occurrence count of one.                                    *)
      fun var_to_elim' [] = ([],[])
        | var_to_elim' ((name,coeff)::rest) =
             let val (occsl,occsr) = var_to_elim' rest
             in  if (coeff < zero) then ((name, one)::occsl,occsr)
                 else if (coeff > zero) then (occsl,(name, one)::occsr)
                 else (occsl,occsr)
             end
      (* Walks the two occurrence lists in step (they are assumed to be      *)
      (* ordered by `<<', as the original stepping logic requires) and       *)
      (* returns the common name whose elimination changes the number of     *)
      (* inequalities by the least amount, or NONE if there is no common     *)
      (* name.  On ties the earlier name wins.                               *)
      fun min_increase (bind1 as ((name1:string,num1:int)::rest1))
                       (bind2 as ((name2,num2)::rest2)) =
             if (name1 = name2) then
                let val increase = (num1 * num2) - (num1 + num2)
                in  case min_increase rest1 rest2
                    of SOME (name,min) =>
                          if (min < increase)
                          then SOME (name,min)
                          else SOME (name1,increase)
                     | NONE => SOME (name1,increase)
                end
             else if (name1 << name2) then min_increase rest1 bind2
             else min_increase bind1 rest2
        | min_increase _ _ = NONE
      val merge =
         end_itlist (fn b1 => fn b2 => snd (merge_coeffs (zero,b1) (zero,b2)))
      val occs = map (var_to_elim' o snd) coeffsl
      val (occsl,occsr) = (merge ## merge) (split occs)
  in  case min_increase occsl occsr
      of SOME (name,_) => name
       | NONE => failwith "var_to_elim"
  end
 ) handle Interrupt => raise Interrupt
        | _ => failwith "var_to_elim";
(*---------------------------------------------------------------------------*)
(* VAR_ELIM : (int * (string * int) list) list -> (int list * (unit -> thm)) *)
(*                                                                           *)
(* Given a list of inequalities represented by bindings, this function       *)
(* returns a `lazy' theorem with false (actually an inequality between       *)
(* constants that can immediately be shown to be false) as its conclusion,   *)
(* and some of the inequalities as assumptions. A list of numbers is also    *)
(* returned. These are the positions in the argument list of the             *)
(* inequalities that are assumptions of the theorem. The function fails if   *)
(* the set of inequalities is satisfiable.                                   *)
(*                                                                           *)
(* The function assumes that none of the inequalities given are false, that  *)
(* is they either contain variables, or evaluate to true. Those that are     *)
(* true are filtered out. The function then determines which variable to     *)
(* eliminate and splits the remaining inequalities into three sets: ones in  *)
(* which the variable occurs on the left-hand side, ones in which the        *)
(* variable occurs on the right, and ones in which the variable does not     *)
(* occur.                                                                    *)
(*                                                                           *)
(* Pairings of the `right' and `left' inequalities are then made, and the    *)
(* weighted sum of each is determined, except that as soon as a pairing      *)
(* yields false, the process is terminated. It may well be the case that no  *)
(* pairing gives false. In this case, the new inequalities are added to the  *)
(* inequalities that did not contain the variable, and a recursive call is   *)
(* made.                                                                     *)
(*                                                                           *)
(* The list of numbers from the recursive call (representing assumptions) is *)
(* split into two: those that point to inequalities that were produced by    *)
(* doing weighted sums, and those that were not. The latter can be traced    *)
(* back so that their positions in the original argument list can be         *)
(* returned. The other inequalities have to be discharged from the theorem   *)
(* using the theorems proved by performing weighted sums. Each assumption    *)
(* thus gives rise to two new assumptions and the conclusion remains false.  *)
(* The positions of the two new assumptions in the original argument list    *)
(* are added to the list to be returned. Duplicates are removed from this    *)
(* list before returning it.                                                 *)
(*                                                                           *)
(* Every failure is reported as a HOL_ERR naming "VAR_ELIM"; a user          *)
(* interrupt is passed through unchanged.                                    *)
(*---------------------------------------------------------------------------*)

fun VAR_ELIM coeffsl =
 let (* The list [from, from + 1, ..., to]; empty when from > to.            *)
     fun upto from to =
        if (from > to)
        then []
        else from::(upto (from + one) to)
     (* Tests whether `var' occurs in an indexed inequality                  *)
     (* (index,(const,bind)) with a coefficient satisfying `test'.           *)
     fun occurs_with test var (_,(_,bind)) =
        List.exists
           (fn (name,coeff) => (name = var) andalso (test coeff)) bind
     (* Negative coefficient: the variable is on the left-hand side.         *)
     fun left_ineqs var icoeffsl =
        filter (occurs_with (fn coeff => coeff < zero) var) icoeffsl
     (* Positive coefficient: the variable is on the right-hand side.        *)
     fun right_ineqs var icoeffsl =
        filter (occurs_with (fn coeff => coeff > zero) var) icoeffsl
     fun no_var_ineqs var icoeffsl =
        filter (not o occurs_with (fn coeff => not (coeff = zero)) var)
           icoeffsl
     (* Every (right,left) pairing, grouped by the right-hand inequality.    *)
     fun pair_ineqs (ricoeffs,licoeffs) =
        flatten (map (fn r => map (fn l => (r,l)) licoeffs) ricoeffs)
     (* Forms the weighted sum of each pairing.  The pairs are processed     *)
     (* from the end of the list towards the front.  As soon as one sum is a *)
     (* false inequality between constants the result is (true,[that one]);  *)
     (* otherwise it is (false, all the sums in the order of `pairs').       *)
     fun weighted_sums var [] = (false,[])
       | weighted_sums var (((rindex,rcoeffs),(lindex,lcoeffs))::later) =
          (case weighted_sums var later
           of (true,found) => (true,found)
            | (false,rest) =>
                 let val ((const,bind),f) =
                        WEIGHTED_SUM var (rcoeffs,lcoeffs)
                     val entry = ((rindex,lindex),((const,bind),f))
                 in  if ((null bind) andalso (const < zero))
                     then (true,[entry])
                     else (false,entry::rest)
                 end)
     (* For each index into `ineqs', records the two parent positions and    *)
     (* extends the lazy theorem so that the corresponding assumption is     *)
     (* discharged with the weighted-sum theorem (via PROVE_HYP).            *)
     fun chain_assums ineqs thf [] = ([],thf)
       | chain_assums ineqs thf (index::rest) =
          let val (prev_indexl,thf') = chain_assums ineqs thf rest
              val ((rindex,lindex),(_,f)) = el (toInt index) ineqs
          in  (rindex::lindex::prev_indexl,
               fn () => PROVE_HYP (f ()) (thf' ()))
          end
 in
 (let (* Number the inequalities from one, then drop those with no           *)
      (* variables.                                                          *)
      val icoeffsl = combine (upto one (fromInt (length coeffsl)), coeffsl)
      val icoeffsl' = filter (fn (i,(const,bind)) => not (null bind)) icoeffsl
      val var = var_to_elim (map snd icoeffsl')
      val ricoeffs = right_ineqs var icoeffsl'
      and licoeffs = left_ineqs var icoeffsl'
      and nicoeffs = no_var_ineqs var icoeffsl'
      val pairs = pair_ineqs (ricoeffs,licoeffs)
      val (success,new_ineqs) = weighted_sums var pairs
  in  if success
      then
         case new_ineqs
         of [((rindex,lindex),(coeffs,thf))] => ([rindex,lindex],thf)
          | _ => raise Match
      else let val n = fromInt (length new_ineqs)
               and new_coeffs =
                  (map (fst o snd) new_ineqs) @ (map snd nicoeffs)
               val (indexl,thf) = VAR_ELIM new_coeffs
               (* Positions above n refer to the inequalities that did not   *)
               (* contain the variable; the rest refer to weighted sums.     *)
               val (prev_indexl,these_indexl) =
                  Lib.partition (fn i => i > n) indexl
               val prev_indexl' =
                  map (fn i => fst (el (toInt (i - n)) nicoeffs)) prev_indexl
               val (these_indexl',thf') =
                  chain_assums new_ineqs thf these_indexl
           in  (Lib.mk_set (these_indexl' @ prev_indexl'),thf')
           end
  end
 ) handle Interrupt => raise Interrupt
        | _ => failwith "VAR_ELIM"
 end;

end