(*****************************************************************************)
(* FILE          : term_coeffs.sml                                           *)
(* DESCRIPTION   : Functions for converting between arithmetic terms and     *)
(*                 their representation as bindings of variable names to     *)
(*                 coefficients.                                             *)
(*                                                                           *)
(* READS FILES   : <none>                                                    *)
(* WRITES FILES  : <none>                                                    *)
(*                                                                           *)
(* AUTHOR        : R.J.Boulton, University of Cambridge                      *)
(* DATE          : 4th March 1991                                            *)
(*                                                                           *)
(* TRANSLATOR    : R.J.Boulton, University of Cambridge                      *)
(* DATE          : 5th February 1993                                         *)
(*                                                                           *)
(* LAST MODIFIED : R.J.Boulton                                               *)
(* DATE          : 15th February 1993                                        *)
(*****************************************************************************)

structure Term_coeffs :> Term_coeffs =
struct
  open Arbint HolKernel boolLib Arith_cons Rsyntax;

  val op << = String.<
  infix << ##;

  (* Raise a HOL_ERR tagged with this structure's name and the given       *)
  (* function name (the message text is left empty).                       *)
  fun failwith function = raise (mk_HOL_ERR "Term_coeffs" function "");


  (*=========================================================================*)
  (* Manipulating coefficient representations of arithmetic expressions    *)
  (*                                                                         *)
  (* A representation is a pair (const, bind): an (arbitrary precision)    *)
  (* constant and a binding list of (variable name, coefficient) pairs.    *)
  (*=========================================================================*)

  (*-------------------------------------------------------------------------*)
  (* negate_coeffs : (int * ('a * int) list) -> (int * ('a * int) list)    *)
  (*                                                                         *)
  (* Negates the constant value and the coefficients of the variables in a *)
  (* binding. Variable order is preserved.                                 *)
  (*-------------------------------------------------------------------------*)

  fun negate_coeffs (const, bind) =
    (zero - const, map (fn (v, n) => (v, zero - n)) bind);

  (*-------------------------------------------------------------------------*)
  (* merge_coeffs : (int * (string * int) list) ->                         *)
  (*                (int * (string * int) list) ->                         *)
  (*                (int * (string * int) list)                            *)
  (*                                                                         *)
  (* Sums the constant values and merges the bindings, adding the          *)
  (* coefficients of any variable that appears in both. If that sum is     *)
  (* zero, the variable is omitted from the result.                        *)
  (*                                                                         *)
  (* The merge walks both bindings in step, comparing names with String.<. *)
  (* The bindings are therefore expected to be sorted by variable name     *)
  (* without duplicates; the result is then sorted in the same way.        *)
  (*-------------------------------------------------------------------------*)

  fun merge_coeffs (const1, bind1) (const2, bind2) =
    let
      fun merge [] ys = ys
        | merge xs [] = xs
        | merge (xs as ((x as (name1 : string, coeff1)) :: xs'))
                (ys as ((y as (name2, coeff2)) :: ys')) =
            if name1 = name2 then
              let val sum = coeff1 + coeff2
              in
                (* cancelled variables are dropped entirely *)
                if sum = zero then merge xs' ys'
                else (name1, sum) :: merge xs' ys'
              end
            else if name1 << name2 then x :: merge xs' ys
            else y :: merge xs ys'
    in
      (const1 + const2, merge bind1 bind2)
    end;

  (*-------------------------------------------------------------------------*)
  (* lhs_coeffs : (int * ('a * int) list) -> (int * ('a * int) list)       *)
  (*                                                                         *)
  (* Extracts the strictly negative coefficients and negates them (so the  *)
  (* results are positive). The constant is negated if it is negative and *)
  (* is zero otherwise.                                                   *)
  (*-------------------------------------------------------------------------*)

  fun lhs_coeffs (const, bind) =
    let
      fun neg_part n = if n < zero then zero - n else zero
      fun keep_neg (v, n) = if n < zero then SOME (v, zero - n) else NONE
    in
      (neg_part const, List.mapPartial keep_neg bind)
    end;

  (*-------------------------------------------------------------------------*)
  (* rhs_coeffs : (int * ('a * int) list) -> (int * ('a * int) list)       *)
  (*                                                                         *)
  (* Extracts the strictly positive coefficients. The constant is kept if  *)
  (* it is positive and is zero otherwise.                                 *)
  (*-------------------------------------------------------------------------*)

  fun rhs_coeffs (const, bind) =
    (if const > zero then const else zero,
     List.filter (fn (_, n) => n > zero) bind);

  (*-------------------------------------------------------------------------*)
  (* diff_of_coeffs :                                                      *)
  (*    ((int * (string * int) list) * (int * (string * int) list)) ->     *)
  (*    ((int * (string * int) list) * (int * (string * int) list))        *)
  (*                                                                         *)
  (* Given the coefficients representing two inequalities, computes the    *)
  (* terms (as coefficients) that have to be added to each in order to     *)
  (* make the right-hand side of the first equal to the left-hand side of *)
  (* the second.                                                           *)
  (*-------------------------------------------------------------------------*)

  fun diff_of_coeffs (coeffs1, coeffs2) =
    let val coeffs1' = rhs_coeffs coeffs1
        and coeffs2' = lhs_coeffs coeffs2
        (* difference (LHS of second) - (RHS of first) *)
        val coeffs = merge_coeffs (negate_coeffs coeffs1') coeffs2'
    in (rhs_coeffs coeffs, lhs_coeffs coeffs)
    end;

  (*-------------------------------------------------------------------------*)
  (* vars_of_coeffs : ('a * (''b * 'c) list) list -> ''b list              *)
  (*                                                                         *)
  (* Obtains the list of distinct variable names occurring in a list of    *)
  (* coefficient representations.                                         *)
  (*-------------------------------------------------------------------------*)

  fun vars_of_coeffs coeffsl =
    Lib.mk_set (Lib.flatten (map ((map fst) o snd) coeffsl));

  (*=========================================================================*)
  (* Extracting coefficients and variable names from normalized terms      *)
  (*=========================================================================*)

  (*-------------------------------------------------------------------------*)
  (* var_of_prod : term -> string                                          *)
  (*                                                                         *)
  (* Returns the variable name from terms of the form "var" or             *)
  (* "const * var". Only the operand is inspected (via rand), so the       *)
  (* operator itself is not checked.                                       *)
  (*-------------------------------------------------------------------------*)

  fun var_of_prod tm =
    (#Name (dest_var tm)) handle HOL_ERR _ =>
    (#Name (dest_var (rand tm))) handle HOL_ERR _ =>
    failwith "var_of_prod";

  (*-------------------------------------------------------------------------*)
  (* coeffs_of_arith : term -> (int * (string * int) list)                 *)
  (*                                                                         *)
  (* Takes an arithmetic term that has been sorted and returns the constant *)
  (* value and a binding of variable names to their coefficients, e.g.     *)
  (*                                                                         *)
  (*    coeffs_of_arith `1 + ((4 * x) + (10 * y))`  --->                   *)
  (*    (1, [("x", 4); ("y", 10)])                                         *)
  (*                                                                         *)
  (* The sum must be right-nested, as produced by build_arith. Assumes     *)
  (* that there are no zero coefficients in the argument term. Also        *)
  (* assumes that a variable with a coefficient of one appears as (for     *)
  (* example) `1 * x` rather than as `x`.                                  *)
  (*                                                                         *)
  (* Only the decomposition steps (dest_plus, int_of_term) are guarded.    *)
  (* A term that does not fit the expected shape therefore fails with      *)
  (* "coeffs_of_arith" rather than being silently misread (previously,     *)
  (* e.g. `5 + x` was read as `5 * x`).                                    *)
  (*-------------------------------------------------------------------------*)

  fun coeffs_of_arith tm =
    let
      fun try_dest_plus t = SOME (dest_plus t) handle HOL_ERR _ => NONE
      fun try_int t = SOME (int_of_term t) handle HOL_ERR _ => NONE
      (* numeric coefficient of a product term `c * v` *)
      fun coeff t = (int_of_term o rand o rator) t
      fun coeffs t =
        case try_dest_plus t of
          SOME (prod, rest) => (var_of_prod prod, coeff prod) :: coeffs rest
        | NONE => [(var_of_prod t, coeff t)]
    in
      (case try_dest_plus tm of
         SOME (first, rest) =>
           (case try_int first of
              SOME const => (const, coeffs rest)   (* constant at the head *)
            | NONE => (zero, coeffs tm))           (* sum of products only *)
       | NONE =>
           (case try_int tm of
              SOME const => (const, [])            (* lone constant *)
            | NONE => (zero, coeffs tm)))          (* lone product *)
      handle HOL_ERR _ => failwith "coeffs_of_arith"
    end;

  (*-------------------------------------------------------------------------*)
  (* coeffs_of_leq : term -> (int * (string * int) list)                   *)
  (*                                                                         *)
  (* Takes a less-than-or-equal-to inequality between two arithmetic terms *)
  (* that have been sorted. Returns the constant value and a binding of    *)
  (* variable names to coefficients for the equivalent inequality with     *)
  (* zero on the LHS (i.e. RHS minus LHS), e.g.                            *)
  (*                                                                         *)
  (*    coeffs_of_leq `((1 * x) + (1 * z)) <= (1 + ((4 * x) + (10 * y)))` *)
  (*    ---> (1, [("x", 3); ("y", 10); ("z", -1)])                         *)
  (*                                                                         *)
  (* The same assumptions as for coeffs_of_arith apply.                    *)
  (*-------------------------------------------------------------------------*)

  fun coeffs_of_leq tm =
    (let val (tm1, tm2) = dest_leq tm
         val coeffs1 = negate_coeffs (coeffs_of_arith tm1)
         and coeffs2 = coeffs_of_arith tm2
     in merge_coeffs coeffs1 coeffs2
     end
    ) handle HOL_ERR _ => failwith "coeffs_of_leq";

  (*-------------------------------------------------------------------------*)
  (* coeffs_of_leq_set : term -> (int * (string * int) list) list          *)
  (*                                                                         *)
  (* Obtains coefficients from a conjunction of normalised inequalities,   *)
  (* one representation per conjunct. See comments for coeffs_of_leq.     *)
  (*-------------------------------------------------------------------------*)

  fun coeffs_of_leq_set tm =
    map coeffs_of_leq (strip_conj tm)
    handle HOL_ERR _ => failwith "coeffs_of_leq_set";

  (*=========================================================================*)
  (* Constructing terms from coefficients and variable names               *)
  (*=========================================================================*)

  (*-------------------------------------------------------------------------*)
  (* build_arith : int * (string * int) list -> term                       *)
  (*                                                                         *)
  (* Takes an integer and a binding of variable names to coefficients, and *)
  (* returns a linear sum (as a term) with the constant at the head. The   *)
  (* sum is right-nested. Terms with a coefficient of zero are eliminated, *)
  (* as is a zero constant. Terms with a coefficient of one are not        *)
  (* simplified. A negative value that term_of_int rejects makes the call  *)
  (* fail with "build_arith".                                              *)
  (*                                                                         *)
  (* Examples:                                                             *)
  (*                                                                         *)
  (*    (3,[("x",2);("y",1)])  --->  `3 + ((2 * x) + (1 * y))`             *)
  (*    (3,[("x",2);("y",0)])  --->  `3 + (2 * x)`                         *)
  (*    (0,[("x",2);("y",1)])  --->  `(2 * x) + (1 * y)`                   *)
  (*    (0,[("x",0);("y",0)])  --->  `0`                                   *)
  (*-------------------------------------------------------------------------*)

  val zero_tm = term_of_int zero

  fun build_arith (const, bind) =
    let
      (* right-nested sum of the non-zero products; zero_tm when none *)
      fun build [] = zero_tm
        | build ((name, coeff) :: bind') =
            let val rest = build bind'
            in
              if coeff = zero then rest
              else
                let val prod = mk_mult (term_of_int coeff, mk_num_var name)
                in if is_zero rest then prod else mk_plus (prod, rest)
                end
            end
    in
      (let val c = term_of_int const
           and rest = build bind
       in
         if is_zero rest then c
         else if const = zero then rest
         else mk_plus (c, rest)
       end
      ) handle HOL_ERR _ => failwith "build_arith"
    end;

  (*-------------------------------------------------------------------------*)
  (* build_leq : (int * (string * int) list) -> term                       *)
  (*                                                                         *)
  (* Constructs a less-than-or-equal-to inequality from a constant and a   *)
  (* binding of variable names to coefficients. The negative parts         *)
  (* (negated) go on the LHS and the positive parts on the RHS. See        *)
  (* comments for build_arith.                                             *)
  (*-------------------------------------------------------------------------*)

  fun build_leq coeffs =
    mk_leq (build_arith (lhs_coeffs coeffs), build_arith (rhs_coeffs coeffs));

end