1(* ========================================================================= *)
2(* NORMALIZING ALGEBRAIC EXPRESSIONS                                         *)
3(* Copyright (c) 2006 Joe Hurd, distributed under the GNU GPL version 2      *)
4(* ========================================================================= *)
5
6structure Algebra :> Algebra =
7struct
8
9open Useful;
10
11(* ------------------------------------------------------------------------- *)
12(* Helper functions.                                                         *)
13(* ------------------------------------------------------------------------- *)
14
(* exponential m n computes m raised to the power n by repeated squaring,  *)
(* using O(log n) multiplications.                                          *)
(* The exponent must be non-negative. A negative n used to recurse forever, *)
(* because SML's div rounds toward negative infinity (~1 div 2 = ~1). It    *)
(* now raises Domain instead. Int multiplication may still raise Overflow   *)
(* for large results.                                                       *)
fun exponential m n =
    if n < 0 then raise Domain
    else
      let
        fun pow _ 0 = 1
          | pow b 1 = b  (* stop here to avoid a needless, overflow-prone b * b *)
          | pow b k = (if k mod 2 = 0 then 1 else b) * pow (b * b) (k div 2)
      in
        pow m n
      end;
19
(* primeDividesFactorial p n is the sum of the quotients                   *)
(* n div p, n div p^2, ... (Legendre's formula), i.e. the exponent of the  *)
(* prime p in n!. Partially applying it to p does no work.                 *)
fun primeDividesFactorial p =
    let
      (* Accumulates each successive quotient q, stopping after the first  *)
      (* one that is smaller than p.                                       *)
      fun sumQuotients acc q =
          if q < p then acc + q
          else sumQuotients (acc + q) (q div p)
    in
      fn n => sumQuotients 0 (n div p)
    end;
26
(* multinomial n ks computes n! / (k1! * ... * kr!) for ks = [k1,...,kr]   *)
(* without building any factorial. For each prime p <= n, the exponent of   *)
(* p in the quotient is its exponent in n! minus its exponents in each ki!. *)
(* In explSumExp the ks are a composition of n, so every such exponent is   *)
(* non-negative.                                                            *)
fun multinomial n ks =
    let
      (* The contribution p^e of a single prime p to the result. *)
      fun primeFactor p =
          let
            val inNumerator = primeDividesFactorial p n
            val inDenominator =
                foldl (fn (k,acc) => acc + primeDividesFactorial p k) 0 ks
          in
            exponential p (inNumerator - inDenominator)
          end
    in
      foldl (fn (p,acc) => acc * primeFactor p) 1 (Useful.primes_up_to n)
    end;
36
(* first_multinomial n k is the first composition of n into k parts in the  *)
(* order used by next_multinomial: k - 1 zeros followed by n.               *)
(* Requires k >= 1. With k = 0 the old funpow-based version passed a        *)
(* negative count to Useful.funpow. NOTE(review): with the usual            *)
(* "funpow 0 _ x = x" definition that never terminates. List.tabulate       *)
(* instead raises Size for the negative count.                              *)
fun first_multinomial n k = List.tabulate (k - 1, fn _ => 0) @ [n];
38
(* next_multinomial c returns the composition that follows c in the order   *)
(* starting at first_multinomial. It returns NONE once c is [n,0,...,0].    *)
(* One step works as follows:                                               *)
(* - Write c as h :: 0 :: ... :: 0 :: k :: rest, where k <> 0 is the first  *)
(*   non-zero part after the head.                                          *)
(* - The result is 0 :: ... :: 0 :: (h + 1) :: (k - 1) :: rest, keeping the *)
(*   same number of zeros, now moved to the front.                          *)
(* Raises Bug on the empty list.                                            *)
val next_multinomial =
    let
      fun scan _ _ [] = NONE
        | scan zeros h (k :: rest) =
          if k = 0 then scan (0 :: zeros) h rest
          else SOME (List.revAppend (zeros, (h + 1) :: (k - 1) :: rest))
    in
      fn c =>
         case c of
           [] => raise Bug "next_multinomial"
         | h :: t => scan [] h t
    end;
48
(* all_multinomials n k lists every composition of n into k non-negative    *)
(* parts. They appear in the reverse of next_multinomial's order, so the    *)
(* last composition generated comes first.                                  *)
fun all_multinomials n k =
    let
      fun collect acc c =
          case next_multinomial c of
            NONE => c :: acc
          | SOME c' => collect (c :: acc) c'
    in
      collect [] (first_multinomial n k)
    end;
60
61(* ------------------------------------------------------------------------- *)
62(* A type of algebraic expressions.                                          *)
63(* ------------------------------------------------------------------------- *)
64
(* Invariants expected of every Sum and Prod value: *)
(* 1. A Sum map (summand |-> integer coefficient): *)
(*    a) Does not map any expression to 0 (no zero coefficients). *)
(*    b) Does not map any Sum expression (nested sums are flattened). *)
(*    c) Does not consist of a single expression mapped to 1 *)
(*       (such a sum is represented by that expression itself). *)
(* 2. A Prod map (factor |-> integer exponent): *)
(*    a) Does not map any expression to 0 (no zero exponents). *)
(*    b) Does not map any Prod expression (nested products are flattened). *)
(*    c) Does not consist of a single expression mapped to 1 *)
(*       (such a product is represented by that expression itself). *)

(* An algebraic expression is a variable, a sum, or a product. Integer *)
(* constants have no constructor of their own; fromInt/toInt below encode *)
(* them. 0 is the empty Sum, 1 is the empty Prod, and any other n is the *)
(* Sum mapping the empty Prod to n. *)
datatype expression =
    Var of string
  | Sum of (expression,int) Map.map
  | Prod of (expression,int) Map.map;

(* Shorthand for the map type carried by both Sum and Prod. *)
type expressionl = (expression,int) Map.map;
81
(* Orders two expression maps using Map.compare. The integer values are     *)
(* compared in reverse of Int.compare, so a larger coefficient or exponent  *)
(* sorts first.                                                             *)
val explCompare : expressionl * expressionl -> order =
    Map.compare (fn (n1,n2) => Int.compare (n2,n1));
92
(* Total order on expressions:                                              *)
(* - Variables come first, then sums, then products.                        *)
(* - Two expressions built with the same constructor are compared by their  *)
(*   names (String.compare) or by their maps (explCompare).                 *)
fun compare (Var v1, Var v2) = String.compare (v1,v2)
  | compare (Sum m1, Sum m2) = explCompare (m1,m2)
  | compare (Prod m1, Prod m2) = explCompare (m1,m2)
  | compare (e1,e2) =
    let
      fun rank (Var _) = 0
        | rank (Sum _) = 1
        | rank (Prod _) = 2
    in
      Int.compare (rank e1, rank e2)
    end;
102
(* equal e1 e2 holds exactly when compare judges e1 and e2 EQUAL. *)
fun equal e1 e2 =
    case compare (e1,e2) of
      EQUAL => true
    | _ => false;
104
(* Destructors for each constructor. They raise Error for any other  *)
(* constructor; the matching is* predicates test the constructor.     *)
fun destVar e =
    case e of
      Var n => n
    | _ => raise Error "destVar";

fun isVar (Var _) = true
  | isVar _ = false;

fun destSum e =
    case e of
      Sum n => n
    | _ => raise Error "destSum";

fun isSum (Sum _) = true
  | isSum _ = false;

fun destProd e =
    case e of
      Prod n => n
    | _ => raise Error "destProd";

fun isProd (Prod _) = true
  | isProd _ = false;
119
(* The empty expression map, keyed by the expression order compare. *)
val explEmpty : expressionl = Map.new compare;

(* The map holding the single entry e_n = (e,n). *)
fun explSingle (e_n : expression * int) : expressionl =
    Map.singleton compare e_n;

(* The unique entry of a one-entry map; raises Error for any other size. *)
fun destExplSingle (m : expressionl) : expression * int =
    if Map.size m = 1 then Option.valOf (Map.findl (K true) m)
    else raise Error "destExplSingle";

val isExplSingle = can destExplSingle;

(* The map {e |-> 1}. *)
fun explUnit e = explSingle (e,1);

(* The expression e of a map that is exactly {e |-> 1}; raises Error otherwise. *)
fun destExplUnit m =
    case total destExplSingle m of
      SOME (e,n) => if n = 1 then e else raise Error "destExplUnit"
    | NONE => raise Error "destExplUnit";

val isExplUnit = can destExplUnit;
138
local
  (* Adds two integers, returning NONE when they cancel so that the merged *)
  (* map never holds a 0. *)
  fun addNonZero (a,b) =
      case a + b of
        0 => NONE
      | c => SOME c;
in
  (* explCombine m1 m2 merges two maps with Map.union. Presumably           *)
  (* addNonZero is applied only to keys present in both maps and other      *)
  (* entries are kept unchanged; NOTE(review) confirm against Map.union.    *)
  val explCombine : expressionl -> expressionl -> expressionl =
      Map.union addNonZero;
end;
150
(* explCons (e,n) m adds n to the integer stored for e in m, inserting e if *)
(* it is absent. Entries that cancel to 0 are dropped by explCombine.       *)
(* A zero n now leaves m unchanged. Previously explCons (e,0) m, with e not *)
(* in m, inserted e |-> 0: combining is only triggered for shared keys, so  *)
(* nothing removed the entry, which breaks invariant a) of both Sum and     *)
(* Prod maps.                                                               *)
fun explCons (_,0) m = m
  | explCons e_n m = explCombine (explSingle e_n) m;
152
(* explScale n m multiplies every integer in m by n.              *)
(* Scaling by 0 gives the empty map; scaling by 1 returns m itself. *)
fun explScale n m =
    case n of
      0 => explEmpty
    | 1 => m
    | _ => Map.transform (fn k => n * k) m;
156
(* explProdCons (e,n) m multiplies the product map m by e^n.                *)
(* A Prod base is flattened: its exponents are scaled by n and merged into  *)
(* m. Any other base is added as a single factor.                           *)
fun explProdCons (e,n) m =
    case e of
      Prod mp => explCombine m (explScale n mp)
    | _ => explCons (e,n) m;

(* explSumCons (e,n) m adds n * e to the sum map m.                         *)
(* A Sum summand is flattened: its coefficients are scaled by n and merged  *)
(* into m. Any other summand is added as a single term.                     *)
fun explSumCons (e,n) m =
    case e of
      Sum ms => explCombine m (explScale n ms)
    | _ => explCons (e,n) m;
162
(* Integer constants: 0 is the empty sum and 1 is the empty product. *)
val zero = Sum explEmpty;

val one = Prod explEmpty;

(* fromInt n encodes n; values other than 0 and 1 become the sum n * one. *)
fun fromInt n =
    case n of
      0 => zero
    | 1 => one
    | _ => Sum (explSingle (one,n));

val minusOne = fromInt ~1;
172
(* toInt e recovers n when e has exactly the shape produced by fromInt n; *)
(* otherwise it returns NONE.                                             *)
fun toInt e =
    case e of
      Var _ => NONE
    | Prod m => if Map.null m then SOME 1 else NONE
    | Sum m =>
      if Map.null m then SOME 0
      else
        case total destExplSingle m of
          SOME (Prod p, n) => if Map.null p then SOME n else NONE
        | _ => NONE;

val isInt = Option.isSome o toInt;
183
local
  (* A map that is exactly {e |-> 1} collapses to e itself (invariant c); *)
  (* any other map is wrapped with the given constructor.                 *)
  fun collapse wrap m =
      case total destExplUnit m of
        SOME e => e
      | NONE => wrap m;
in
  fun mkSum m = collapse Sum m;

  fun mkProd m = collapse Prod m;
end;
189
(* a + b, with sum operands flattened into a single Sum map. *)
fun add (a,b) =
    let
      val withB = explSumCons (b,1) explEmpty
    in
      mkSum (explSumCons (a,1) withB)
    end;

(* a * b, with product operands flattened into a single Prod map. *)
fun multiply (a,b) =
    let
      val withB = explProdCons (b,1) explEmpty
    in
      mkProd (explProdCons (a,1) withB)
    end;

(* -e, represented as the product (-1) * e; normalization expands it. *)
fun negate e = multiply (minusOne, e);

(* a - b, represented as a + (-b). *)
fun subtract (a,b) =
    let
      val negB = negate b
    in
      add (a, negB)
    end;
199
(* power (e,n) builds e raised to the integer exponent n.                  *)
(* - n = 0 gives one, so 0 ** 0 = 1 by convention. Previously the result    *)
(*   stored e |-> 0, which violates invariant 2a.                           *)
(* - A product base is flattened with explProdCons, which multiplies each   *)
(*   of its exponents by n. Previously explCons stored the whole product as *)
(*   one factor, violating invariant 2b; e.g. power (one,2) was not one.    *)
fun power (_,0) = one
  | power (e,n) = mkProd (explProdCons (e,n) explEmpty);
201
202(* ------------------------------------------------------------------------- *)
203(* Pretty printing.                                                          *)
204(* ------------------------------------------------------------------------- *)
205
local
  val infixes : Parser.infixities =
      [{tok = " ** ", prec = 10, left_assoc = false},
       {tok = " / ",  prec = 8,  left_assoc = true },
       {tok = " * ",  prec = 8,  left_assoc = true },
       {tok = " . ",  prec = 7,  left_assoc = true },
       {tok = " + ",  prec = 6,  left_assoc = true },
       {tok = " - ",  prec = 6,  left_assoc = true }];

  (* Splits a non-empty map into the entry (e,n) found by Map.findr and *)
  (* the map with e removed.                                            *)
  fun splitLast m =
      let
        val (e,n) = Option.valOf (Map.findr (K true) m)
      in
        (e, n, Map.delete m e)
      end;

  (* Left operand built from the remaining entries: a lone e |-> 1 is *)
  (* shown as e; anything else keeps the given constructor.           *)
  fun rejoin wrap rest = Option.getOpt (total destExplUnit rest, wrap rest);

  (* Right operand built from the split-off entry: e itself when n = 1, *)
  (* otherwise a singleton map under the given constructor.             *)
  fun lastOperand wrap (e,n) = if n = 1 then e else wrap (explSingle (e,n));

  (* Views a non-integer Sum or Prod as a binary infix application:     *)
  (* - a one-entry sum {e |-> n} is "n . e";                            *)
  (* - a one-entry product {e |-> n} is "e ** n";                       *)
  (* - larger maps split off one entry as the right operand of "+"/"*". *)
  (* Variables and integer constants are atoms.                         *)
  fun expDestInfix exp =
      case exp of
        Var _ => NONE
      | Sum m =>
        if isInt exp then NONE
        else
          let
            val (e,n,rest) = splitLast m
          in
            if Map.null rest then SOME (".", fromInt n, e)
            else SOME ("+", rejoin Sum rest, lastOperand Sum (e,n))
          end
      | Prod m =>
        if isInt exp then NONE
        else
          let
            val (e,n,rest) = splitLast m
          in
            if Map.null rest then SOME ("**", e, fromInt n)
            else SOME ("*", rejoin Prod rest, lastOperand Prod (e,n))
          end;

  (* Prints an atom: a variable name or an integer constant. Any other *)
  (* expression reaching here is an internal error.                    *)
  fun pp_basic_exp pp (e,_) =
      case e of
        Var v => Useful.pp_string pp v
      | _ =>
        (case toInt e of
           SOME n => Useful.pp_int pp n
         | NONE => raise Bug "pp_basic_exp");
in
  val pp =
      Useful.pp_map
        (fn e => (e,false))
        (Parser.pp_infixes infixes expDestInfix pp_basic_exp);

  val toString = PP.pp_to_string (!Useful.LINE_LENGTH) pp;
end;
262
263(* ------------------------------------------------------------------------- *)
264(* Normalization with a set of equations.                                    *)
265(* ------------------------------------------------------------------------- *)
266
(* explSumExp m n expands the sum represented by m, raised to the power n,  *)
(* into a Sum map using the multinomial theorem.                            *)
(* Each composition [r1,...,rk] of n into k = Map.size m parts contributes  *)
(* the monomial e1^r1 ... ek^rk with coefficient                            *)
(* multinomial n [r1,...,rk] * a1^r1 * ... * ak^rk, where ei |-> ai are the *)
(* entries of m.                                                            *)
(* Compositions come from first_multinomial/next_multinomial. Their parts   *)
(* are paired with the entries of m in Map.foldl order.                     *)
(* Guards added:                                                            *)
(* - A negative n raises Error. The composition enumeration and exponential *)
(*   never terminate on negative exponents, and Error lets callers that use *)
(*   total treat the term as unnormalizable.                                *)
(* - An empty m (the sum 0) gives 1 for n = 0 and 0 otherwise, because      *)
(*   first_multinomial cannot enumerate compositions into 0 parts.          *)
fun explSumExp m n =
    if n < 0 then raise Error "explSumExp: negative exponent"
    else if n = 1 then m
    else if Map.null m then
      (if n = 0 then explSingle (one,1) else explEmpty)
    else
      let
        (* Fold step: consumes the next part r of the composition for the *)
        (* entry e |-> a. It multiplies the coefficient c by a^r and the  *)
        (* monomial mp by e^r; a zero part leaves both unchanged.         *)
        fun g (_,_,(_,_,[])) = raise Bug "explSumExp"
          | g (_, _, (c, mp, 0 :: l)) = (c,mp,l)
          | g (e, a, (c, mp, r :: l)) =
            (c * exponential a r, explProdCons (e,r) mp, l)

        (* Adds the term for composition multi to ms, then moves on to the *)
        (* next composition.                                               *)
        fun f multi ms =
            let
              val np = multinomial n multi

              val (np,mp,_) = Map.foldl g (np,explEmpty,multi) m

              val ms = explSumCons (mkProd mp, np) ms
            in
              case next_multinomial multi of
                NONE => ms
              | SOME multi => f multi ms
            end
      in
        f (first_multinomial n (Map.size m)) explEmpty
      end;
290
local
  (* Removes base e from m. Raises Error if e is absent, or if its exponent *)
  (* in m is not exactly n.                                                 *)
  fun removeExact (e,n,m) : expressionl =
      case Map.peek m e of
        NONE => raise Error "explSubtract*: no such base"
      | SOME n' =>
        if n <> n' then raise Error "explSubtract*: wrong exp"
        else Map.delete m e;
in
  (* Removes every entry of m2 from m1, with exact exponent matches. *)
  fun explSubtract m1 m2 = Map.foldl removeExact m1 m2;

  (* Removes the single entry (e,n) from m. *)
  fun explSubtractSingle m (e,n) = removeExact (e,n,m);

  (* Removes the entry e |-> 1 from m. *)
  fun explSubtractUnit m e = removeExact (e,1,m);
end;
304
(* explProdRewr (l,r) m rewrites the product map m with the equation l = r. *)
(* It removes the factors of l from m (exact exponents required) and then   *)
(* multiplies in r. Raises Error if m does not contain l.                   *)
fun explProdRewr (l,r) m =
    let
      val remaining =
          case l of
            Prod l_m => explSubtract m l_m
          | _ => explSubtractUnit m l
    in
      explProdCons (r,1) remaining
    end;
307
(* Applies the first equation in eqns that rewrites the product map m. *)
(* Raises Error if none of them applies.                               *)
fun explProdRewrite eqns m =
    case Useful.first (fn eqn => total (explProdRewr eqn) m) eqns of
      SOME m' => m'
    | NONE => raise Error "explProdRewrite";
312
local
  (* norm eqns e performs one normalization step on e. It raises Error      *)
  (* "unchanged" (or another Error) when no step applies.                   *)
  (* - Var: never changes.                                                  *)
  (* - Sum: normalizes each summand once and rebuilds the sum with          *)
  (*   explSumCons, which flattens nested sums and cancels zero             *)
  (*   coefficients. It fails if no summand changed.                        *)
  (* - Prod whose factor found by Map.findr is a Sum: an empty Sum makes    *)
  (*   the product zero. Otherwise the sum power is expanded by explSumExp  *)
  (*   and the remaining factors are distributed over each resulting term.  *)
  (*   compare orders Var < Sum < Prod and invariant 2b excludes Prod keys, *)
  (*   so this finds a Sum factor whenever one exists (assuming Map.findr   *)
  (*   yields the greatest key).                                            *)
  (* - Any other Prod: rewritten with the first applicable equation.        *)
  fun norm _ (Var _) = raise Error "unchanged"
    | norm eqns (Sum m) =
      let
        fun f (e,n,(changed,m)) =
            case total (norm eqns) e of
              SOME e => (true, explSumCons (e,n) m)
            | NONE => (changed, explSumCons (e,n) m)

        val (changed,m) = Map.foldl f (false,explEmpty) m
      in
        if not changed then raise Error "unchanged" else mkSum m
      end
    | norm eqns (Prod m) =
      case Map.findr (K true) m of
        SOME (e as Sum ms, n) =>
        if Map.null ms then zero
        else
          let
            val m = Map.delete m e

            (* Multiplies one expanded term e |-> n by the other factors m. *)
            fun f (e,n,ms) =
                explSumCons (mkProd (explProdCons (e,1) m), n) ms

            val ms = explSumExp ms n
          in
            mkSum (if Map.null m then ms else Map.foldl f explEmpty ms)
          end
      | _ => mkProd (explProdRewrite eqns m);

  (* Applies norm until it fails. This terminates only if the equations *)
  (* eventually stop applying.                                          *)
  fun repeat_norm eqns e =
      case total (norm eqns) e of SOME e => repeat_norm eqns e | NONE => e;
in
  (* normalize {equations} exp normalizes exp. It expands products of sums *)
  (* and rewrites products with the (lhs, rhs) pairs in equations, as used *)
  (* by explProdRewr. Leftover debug prints that wrote the input and the   *)
  (* result to stdout on every call have been removed.                     *)
  fun normalize {equations} exp = repeat_norm equations exp;
end;
355
(* Quick testing
installPP pp;

val Sum ms = add (add (Var "x", Var "y"), Var "z");

Sum (explSumExp ms 9);

val e = multiply (Var "x", add (fromInt 1, Var "y"));

val e = multiply (Var "x", fromInt (~1));

try (normalize {equations = []}) e;
*)
369
370end
371