1(*****************************************************************************)
2(* FILE          : term_coeffs.sml                                           *)
3(* DESCRIPTION   : Functions for converting between arithmetic terms and     *)
4(*                 their representation as bindings of variable names to     *)
5(*                 coefficients.                                             *)
6(*                                                                           *)
7(* READS FILES   : <none>                                                    *)
8(* WRITES FILES  : <none>                                                    *)
9(*                                                                           *)
10(* AUTHOR        : R.J.Boulton, University of Cambridge                      *)
11(* DATE          : 4th March 1991                                            *)
12(*                                                                           *)
13(* TRANSLATOR    : R.J.Boulton, University of Cambridge                      *)
14(* DATE          : 5th February 1993                                         *)
15(*                                                                           *)
16(* LAST MODIFIED : R.J.Boulton                                               *)
17(* DATE          : 15th February 1993                                        *)
18(*****************************************************************************)
19
20structure Term_coeffs :> Term_coeffs =
21struct
22  open Arbint HolKernel boolLib Arith_cons Rsyntax;
23
24  val op << = String.<
25  infix << ##;
26
27fun failwith function = raise (mk_HOL_ERR "Term_coeffs" function "");
28
29
30(*===========================================================================*)
31(* Manipulating coefficient representations of arithmetic expressions        *)
32(*===========================================================================*)
33
(*---------------------------------------------------------------------------*)
(* negate_coeffs : (int * ('a * int) list) -> (int * ('a * int) list)        *)
(*                                                                           *)
(* Flips the sign of the constant and of every coefficient in a binding by   *)
(* subtracting each from zero. Variable names and their order are unchanged. *)
(*---------------------------------------------------------------------------*)

fun negate_coeffs (const,bind) =
   let fun neg n = zero - n
   in  (neg const,map (fn (name,coeff) => (name,neg coeff)) bind)
   end;
42
(*---------------------------------------------------------------------------*)
(* merge_coeffs : (int * (string * int) list) ->                             *)
(*                (int * (string * int) list) ->                             *)
(*                (int * (string * int) list)                                *)
(*                                                                           *)
(* Adds two coefficient representations. The constants are summed and the    *)
(* bindings, both assumed sorted by variable name, are merged in name order. *)
(* A variable present in both gets the sum of its coefficients, and is       *)
(* dropped altogether when that sum is zero.                                 *)
(*---------------------------------------------------------------------------*)

fun merge_coeffs (const1,bind1) (const2,bind2) =
   let fun merge [] others = others
         | merge others [] = others
         | merge (left as (name1:string,coeff1)::rest1)
                 (right as (name2,coeff2)::rest2) =
              if (name1 = name2)
              then let val total = coeff1 + coeff2
                   in  (* A cancelled variable leaves no entry behind. *)
                       if (total = zero)
                       then merge rest1 rest2
                       else (name1,total)::(merge rest1 rest2)
                   end
              else if (name1 << name2)
                   then (name1,coeff1)::(merge rest1 right)
                   else (name2,coeff2)::(merge left rest2)
   in  ((const1 + const2:int),merge bind1 bind2)
   end;
72
(*---------------------------------------------------------------------------*)
(* lhs_coeffs : (int * ('a * int) list) -> (int * ('a * int) list)           *)
(*                                                                           *)
(* Keeps only the strictly negative parts of a binding and makes them        *)
(* positive: a constant that is not negative becomes zero, and variables     *)
(* whose coefficient is not negative are discarded.                          *)
(*---------------------------------------------------------------------------*)

fun lhs_coeffs (const,bind) =
   let fun is_neg n = n < zero
       fun negate n = zero - n
   in  (if is_neg const then negate const else zero,
        map (fn (name,coeff) => (name,negate coeff))
            (filter (fn (_,coeff) => is_neg coeff) bind))
   end;
85
(*---------------------------------------------------------------------------*)
(* rhs_coeffs : (int * ('a * int) list) -> (int * ('a * int) list)           *)
(*                                                                           *)
(* Keeps only the strictly positive parts of a binding: a constant that is   *)
(* not positive becomes zero, and variables with a non-positive coefficient  *)
(* are discarded.                                                            *)
(*---------------------------------------------------------------------------*)

fun rhs_coeffs (const,bind) =
   let fun is_pos n = n > zero
   in  (if is_pos const then const else zero,
        filter (fn (_,coeff) => is_pos coeff) bind)
   end;
96
(*---------------------------------------------------------------------------*)
(* diff_of_coeffs :                                                          *)
(*    ((int * (string * int) list) * (int * (string * int) list)) ->         *)
(*    ((int * (string * int) list) * (int * (string * int) list))            *)
(*                                                                           *)
(* Let d be the left-hand side of the second inequality minus the            *)
(* right-hand side of the first. The result pairs the positive part of d     *)
(* (to be added to the first inequality) with the negated negative part of d *)
(* (to be added to the second), so that the right-hand side of the first     *)
(* then equals the left-hand side of the second.                             *)
(*---------------------------------------------------------------------------*)

fun diff_of_coeffs (coeffs1,coeffs2) =
   let val right1 = rhs_coeffs coeffs1
       val left2 = lhs_coeffs coeffs2
       val diff = merge_coeffs (negate_coeffs right1) left2
   in  (rhs_coeffs diff,lhs_coeffs diff)
   end;
114
(*---------------------------------------------------------------------------*)
(* vars_of_coeffs : ('a * (''b * 'c) list) list -> ''b list                  *)
(*                                                                           *)
(* Collects the variable names bound in a list of coefficient bindings and   *)
(* removes duplicates (via Lib.mk_set).                                      *)
(*---------------------------------------------------------------------------*)

fun vars_of_coeffs coeffsl =
   let fun names (_,bind) = map fst bind
   in  Lib.mk_set (Lib.flatten (map names coeffsl))
   end;
123
124(*===========================================================================*)
125(* Extracting coefficients and variable names from normalized terms          *)
126(*===========================================================================*)
127
(*---------------------------------------------------------------------------*)
(* var_of_prod : term -> string                                              *)
(*                                                                           *)
(* Returns the name of the term if it is a variable, and otherwise the name  *)
(* of its operand when that is a variable, as in "const * var". Fails if     *)
(* neither yields a variable.                                                *)
(*---------------------------------------------------------------------------*)

fun var_of_prod tm =
   let fun name_of v = #Name (dest_var v)
   in  name_of tm
       handle HOL_ERR _ =>
          (name_of (rand tm) handle HOL_ERR _ => failwith "var_of_prod")
   end;
138
(*---------------------------------------------------------------------------*)
(* coeffs_of_arith : term -> (int * (string * int) list)                     *)
(*                                                                           *)
(* Reads a sorted linear sum and returns its constant together with a        *)
(* binding of variable names to coefficients, e.g.                           *)
(*                                                                           *)
(*    coeffs_of_arith `1 + (4 * x) + (10 * y)`  --->                         *)
(*    (1, [("x", 4), ("y", 10)])                                             *)
(*                                                                           *)
(* A sum without a leading constant gets the constant zero, and a lone       *)
(* constant gets an empty binding. Every variable must appear as a product   *)
(* with an explicit coefficient (`1 * x`, not `x`), and there should be no   *)
(* zero coefficients.                                                        *)
(*---------------------------------------------------------------------------*)

fun coeffs_of_arith tm =
   let val coeff = int_of_term o rand o rator
       fun binding prod = (var_of_prod prod,coeff prod)
       (* Bindings for a sum of products; a term that cannot be processed *)
       (* as a sum is treated as a single product.                        *)
       fun coeffs tm =
          (let val (prod,rest) = dest_plus tm
           in  (binding prod)::(coeffs rest)
           end
          ) handle HOL_ERR _ => [binding tm]
   in  (* Try, in turn: constant + products, constant only, products only. *)
       (let val (const,rest) = dest_plus tm
        in  (int_of_term const,coeffs rest)
        end)
       handle HOL_ERR _ =>
          ((int_of_term tm,[])
           handle HOL_ERR _ =>
              ((zero,coeffs tm)
               handle HOL_ERR _ => failwith "coeffs_of_arith"))
   end;
167
(*---------------------------------------------------------------------------*)
(* coeffs_of_leq : term -> (int * (string * int) list)                       *)
(*                                                                           *)
(* Converts an inequality lhs <= rhs between sorted linear sums into the     *)
(* coefficients of rhs - lhs, i.e. of the equivalent inequality with zero on *)
(* the left, e.g.                                                            *)
(*                                                                           *)
(*    coeffs_of_leq `((1 * x) + (1 * z)) <= (1 + (4 * x) + (10 * y))`  --->  *)
(*    (1, [("x", 3), ("y", 10), ("z", -1)])                                  *)
(*                                                                           *)
(* See coeffs_of_arith for the form each side must take. Any HOL_ERR is      *)
(* replaced by one naming "coeffs_of_leq".                                   *)
(*---------------------------------------------------------------------------*)

fun coeffs_of_leq tm =
   (let val (left,right) = dest_leq tm
        val neg_left = negate_coeffs (coeffs_of_arith left)
    in  merge_coeffs neg_left (coeffs_of_arith right)
    end
   ) handle HOL_ERR _ => failwith "coeffs_of_leq";
191
(*---------------------------------------------------------------------------*)
(* coeffs_of_leq_set : term -> (int * (string * int) list) list              *)
(*                                                                           *)
(* Splits a conjunction of normalised inequalities into its conjuncts and    *)
(* applies coeffs_of_leq to each, keeping their order.                       *)
(*---------------------------------------------------------------------------*)

fun coeffs_of_leq_set tm =
   (let val conjuncts = strip_conj tm
    in  map coeffs_of_leq conjuncts
    end
   ) handle HOL_ERR _ => failwith "coeffs_of_leq_set";
203
204(*===========================================================================*)
205(* Constructing terms from coefficients and variable names                   *)
206(*===========================================================================*)
207
(*---------------------------------------------------------------------------*)
(* build_arith : int * (string * int) list -> term                           *)
(*                                                                           *)
(* Builds the linear sum for a constant and a binding of variable names to   *)
(* coefficients, with the constant first and the products in binding order.  *)
(* Products with a zero coefficient are left out, and so is a zero constant  *)
(* unless nothing else remains. A coefficient of one is kept as `1 * x`.     *)
(*                                                                           *)
(* Examples:                                                                 *)
(*                                                                           *)
(*    (3,[("x",2),("y",1)]) ---> `3 + (2 * x) + (1 * y)`                     *)
(*    (3,[("x",2),("y",0)]) ---> `3 + (2 * x)`                               *)
(*    (0,[("x",2),("y",1)]) ---> `(2 * x) + (1 * y)`                         *)
(*    (0,[("x",0),("y",0)]) ---> `0`                                         *)
(*---------------------------------------------------------------------------*)

val zero_tm = term_of_int zero
fun build_arith (const,bind) =
   let (* Sum of the products with non-zero coefficients; zero_tm if none. *)
       fun build [] = zero_tm
         | build ((name,coeff)::others) =
              let val rest = build others
              in  if (coeff = zero) then rest
                  else let val prod =
                              mk_mult (term_of_int coeff,mk_num_var name)
                       in  if is_zero rest then prod else mk_plus (prod,rest)
                       end
              end
   in  (let val c = term_of_int const
            val rest = build bind
        in  if is_zero rest then c
            else if (const = zero) then rest
            else mk_plus (c,rest)
        end
       ) handle HOL_ERR _ => failwith "build_arith"
   end;
248
(*---------------------------------------------------------------------------*)
(* build_leq : (int * (string * int) list) -> term                           *)
(*                                                                           *)
(* Constructs the inequality whose left-hand side is built from the negated  *)
(* negative part of the coefficients (lhs_coeffs) and whose right-hand side  *)
(* is built from the positive part (rhs_coeffs), each via build_arith.       *)
(*---------------------------------------------------------------------------*)

fun build_leq coeffs =
   let val left = build_arith (lhs_coeffs coeffs)
       val right = build_arith (rhs_coeffs coeffs)
   in  mk_leq (left,right)
   end;
259
260end
261