(* ========================================================================= *)
(* NORMALIZING ALGEBRAIC EXPRESSIONS                                         *)
(* Copyright (c) 2006 Joe Hurd, distributed under the GNU GPL version 2      *)
(* ========================================================================= *)

structure Algebra :> Algebra =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* exponential m n = m^n, computed by repeated squaring.                     *)
(* The exponent must be non-negative: for negative n the recursion never     *)
(* reaches a base case (~1 div 2 = ~1), so it is rejected explicitly rather  *)
(* than looping forever.                                                     *)

fun exponential _ 0 = 1
  | exponential m 1 = m
  | exponential m n =
    if n < 0 then raise Bug "exponential: negative exponent"
    else (if n mod 2 = 0 then 1 else m) * exponential (m * m) (n div 2);

(* primeDividesFactorial p n = the exponent of the prime p in n!, i.e.       *)
(* Legendre's formula (n div p) + (n div p^2) + ...                          *)

fun primeDividesFactorial p =
    let
      fun f k = k + (if k < p then 0 else f (k div p))
    in
      fn n => f (n div p)
    end;

(* multinomial n [k1,...,kr] = n! / (k1! * ... * kr!).                       *)
(* For each prime p <= n, the exponent of p in the result is its exponent in *)
(* n! minus its exponents in every ki!.                                      *)
(* NOTE(review): presumably callers ensure k1 + ... + kr = n; otherwise a    *)
(* prime exponent can go negative and exponential raises Bug.                *)

fun multinomial n l =
    let
      fun g p (k,z) = z - primeDividesFactorial p k

      fun f (p,z) =
          z * exponential p (foldl (g p) (primeDividesFactorial p n) l)
    in
      foldl f 1 (Useful.primes_up_to n)
    end;

(* The first composition of n into k parts: [0,...,0,n] (k elements).        *)

fun first_multinomial n k = Useful.funpow (k - 1) (fn l => 0 :: l) [n];

(* Step to the next composition with the same total and the same length, in  *)
(* the enumeration that starts at first_multinomial. Returns NONE when every *)
(* element after the head is zero, i.e. at the final composition             *)
(* [n,0,...,0]. Raises Bug on the empty list.                                *)

val next_multinomial =
    let
      fun f _ _ [] = NONE
        | f z h (0 :: l) = f (0 :: z) h l
        | f z h (k :: l) = SOME (List.revAppend (z, (h + 1) :: (k - 1) :: l))
    in
      fn [] => raise Bug "next_multinomial"
       | h :: t => f [] h t
    end;

(* All compositions of n into k parts. The list is in reverse enumeration    *)
(* order: the last composition generated is at the head.                     *)

fun all_multinomials n k =
    let
      fun f m l =
          let
            val l = m :: l
          in
            case next_multinomial m of NONE => l | SOME m => f m l
          end
    in
      f (first_multinomial n k) []
    end;

(* ------------------------------------------------------------------------- *)
(* A type of algebraic expressions.                                          *)
(* ------------------------------------------------------------------------- *)

(* Invariants:                                                               *)
(* 1. A Sum map:                                                             *)
(*    a) Does not map any expression to 0.                                   *)
(*    b) Does not map any Sum expressions.                                   *)
(*    c) Does not consist of a single expression mapped to 1.                *)
(* 2. A Prod map:                                                            *)
(*    a) Does not map any expression to 0.                                   *)
(*    b) Does not map any Prod expressions.                                  *)
(*    c) Does not consist of a single expression mapped to 1.                *)
(*                                                                           *)
(* A Sum map sends each summand to its integer coefficient and a Prod map    *)
(* sends each factor to its integer exponent. The empty Sum is zero and the  *)
(* empty Prod is one.                                                        *)

datatype expression =
    Var of string
  | Sum of (expression,int) Map.map
  | Prod of (expression,int) Map.map;

type expressionl = (expression,int) Map.map;

(* Total order on expression maps: Map.compare with the integer order        *)
(* reversed, so a larger count compares as LESS.                             *)

val explCompare : expressionl * expressionl -> order =
    let
      fun cmp (n1,n2) =
          case Int.compare (n1,n2) of
            LESS => GREATER
          | EQUAL => EQUAL
          | GREATER => LESS
    in
      Map.compare cmp
    end;

(* Total order on expressions: every Var < every Sum < every Prod, with ties *)
(* broken by variable name or by explCompare on the maps.                    *)

fun compare (e1,e2) =
    case (e1,e2) of
      (Var v1, Var v2) => String.compare (v1,v2)
    | (Var _, _) => LESS
    | (_, Var _) => GREATER
    | (Sum m1, Sum m2) => explCompare (m1,m2)
    | (Sum _, _) => LESS
    | (_, Sum _) => GREATER
    | (Prod m1, Prod m2) => explCompare (m1,m2);

fun equal e1 e2 = compare (e1,e2) = EQUAL;

fun destVar (Var n) = n
  | destVar _ = raise Error "destVar";

val isVar = can destVar;

fun destSum (Sum n) = n
  | destSum _ = raise Error "destSum";

val isSum = can destSum;

fun destProd (Prod n) = n
  | destProd _ = raise Error "destProd";

val isProd = can destProd;

(* Empty and singleton expression maps, keyed by compare.                    *)

val explEmpty : expressionl = Map.new compare;

fun explSingle e_n : expressionl = Map.singleton compare e_n;

(* The only entry of a one-element map; raises Error for any other size.     *)

fun destExplSingle m : expression * int =
    if Map.size m <> 1 then raise Error "destExplSingle"
    else Option.valOf (Map.findl (K true) m);

val isExplSingle = can destExplSingle;

fun explUnit e = explSingle (e,1);

(* The expression e of a map that is exactly {e |-> 1}; Error otherwise.     *)

fun destExplUnit m =
    case total destExplSingle m of
      SOME (e,1) => e
    | _ => raise Error "destExplUnit";

val isExplUnit = can destExplUnit;

(* Pointwise addition of two expression maps. combine returns NONE when a    *)
(* shared key's counts cancel to 0, which presumably makes Map.union drop    *)
(* that key (keeping invariants 1a/2a).                                      *)

local
  fun combine (a,b) =
      let
        val c = a + b
      in
        if c = 0 then NONE else SOME c
      end;
in
  val explCombine : expressionl -> expressionl -> expressionl =
      Map.union combine;
end;

(* Add the count n for expression e to map m, without flattening e.         *)
(* A zero count is a no-op. Without this guard, a key absent from m would be *)
(* inserted mapped to 0 (combine only sees keys present in both maps),       *)
(* breaking invariants 1a/2a.                                                *)

fun explCons (_,0) m = m
  | explCons e_n m = explCombine (explSingle e_n) m;

(* Multiply every count in the map by n.                                     *)

fun explScale 0 _ = explEmpty
  | explScale 1 m = m
  | explScale n m = Map.transform (fn k => k * n) m;

(* Multiply the product map m by e^n. A Prod e is flattened into its scaled  *)
(* factors, so no Prod key is introduced (invariant 2b).                     *)

fun explProdCons (Prod mp, n) m = explCombine m (explScale n mp)
  | explProdCons (e,n) m = explCons (e,n) m;

(* Add n * e to the sum map m. A Sum e is flattened into its scaled          *)
(* summands, so no Sum key is introduced (invariant 1b).                     *)

fun explSumCons (Sum ms, n) m = explCombine m (explScale n ms)
  | explSumCons (e,n) m = explCons (e,n) m;

val zero = Sum explEmpty;

val one = Prod explEmpty;

(* Integer constants: 0 and 1 are the empty Sum and Prod; any other n is the *)
(* sum n . one.                                                              *)

fun fromInt 0 = zero
  | fromInt 1 = one
  | fromInt n = Sum (explSingle (one,n));

val minusOne = fromInt (~1);

(* Inverse of fromInt: SOME n when the expression is an integer constant.    *)

fun toInt (Var _) = NONE
  | toInt (Prod m) = if Map.null m then SOME 1 else NONE
  | toInt (Sum m) =
    if Map.null m then SOME 0
    else
      case total destExplSingle m of
        SOME (Prod m, n) => if Map.null m then SOME n else NONE
      | _ => NONE;

fun isInt exp = Option.isSome (toInt exp);

(* Smart constructors: a map that is exactly {e |-> 1} collapses to e,       *)
(* enforcing invariants 1c/2c.                                               *)

fun mkSum m =
    case total destExplUnit m of SOME e => e | NONE => Sum m;

fun mkProd m =
    case total destExplUnit m of SOME e => e | NONE => Prod m;

fun add (a,b) =
    mkSum (explSumCons (a,1) (explSumCons (b,1) explEmpty));

fun multiply (a,b) =
    mkProd (explProdCons (a,1) (explProdCons (b,1) explEmpty));

fun negate e = multiply (minusOne,e);

fun subtract (a,b) = add (a, negate b);

(* power (e,n) = e^n.                                                        *)
(* A zero exponent gives one, instead of a Prod map with e mapped to 0       *)
(* (invariant 2a). Going through explProdCons flattens a Prod base into its  *)
(* factors with scaled exponents (invariant 2b), so power (multiply (x,y),2) *)
(* is equal to multiply (power (x,2), power (y,2)).                          *)

fun power (_,0) = one
  | power e_n = mkProd (explProdCons e_n explEmpty);

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

local
  val infixes : Parser.infixities =
      [{tok = " ** ", prec = 10, left_assoc = false},
       {tok = " / ", prec = 8, left_assoc = true },
       {tok = " * ", prec = 8, left_assoc = true },
       {tok = " . ", prec = 7, left_assoc = true },
       {tok = " + ", prec = 6, left_assoc = true },
       {tok = " - ", prec = 6, left_assoc = true }];

  (* Split a non-atomic expression into (operator, left, right) by removing  *)
  (* the entry returned by Map.findr:                                        *)
  (*  - a single-summand Sum n . e prints as ".";                            *)
  (*  - a single-factor Prod e^n prints as "**";                             *)
  (*  - otherwise the rest of the map is the left operand of "+" or "*".     *)
  (* Variables and integer constants are atoms (NONE).                       *)

  fun expDestInfix (Var _) = NONE
    | expDestInfix (e as Sum m) =
      if isInt e then NONE
      else
        let
          val (e,n) = Option.valOf (Map.findr (K true) m)
          val m = Map.delete m e
        in
          if Map.null m then SOME (".", fromInt n, e)
          else
            let
              val a = Option.getOpt (total destExplUnit m, Sum m)
              and b = if n = 1 then e else Sum (explSingle (e,n))
            in
              SOME ("+",a,b)
            end
        end
    | expDestInfix (e as Prod m) =
      if isInt e then NONE
      else
        let
          val (e,n) = Option.valOf (Map.findr (K true) m)
          val m = Map.delete m e
        in
          if Map.null m then SOME ("**", e, fromInt n)
          else
            let
              val a = Option.getOpt (total destExplUnit m, Prod m)
              and b = if n = 1 then e else Prod (explSingle (e,n))
            in
              SOME ("*",a,b)
            end
        end;

  (* Print an atom: a variable name or an integer constant. Any other        *)
  (* expression reaching here is an internal error.                          *)

  fun pp_basic_exp pp (Var v, _) = Useful.pp_string pp v
    | pp_basic_exp pp (e,_) =
      case toInt e of
        SOME n => Useful.pp_int pp n
      | NONE => raise Bug "pp_basic_exp";
in
  val pp =
      Useful.pp_map
        (fn e => (e,false))
        (Parser.pp_infixes infixes expDestInfix pp_basic_exp);

  val toString = PP.pp_to_string (!Useful.LINE_LENGTH) pp;
end;

(* ------------------------------------------------------------------------- *)
(* Normalization with a set of equations.                                    *)
(* ------------------------------------------------------------------------- *)

(* explSumExp m n expands (sum of m)^n by the multinomial theorem and        *)
(* returns the expanded sum map. For each composition [r1,...,rk] of n (one  *)
(* ri per summand of m, in Map.foldl order), it adds the term                *)
(*   multinomial n [r1,...,rk] * (product of ai^ri) . (product of ei^ri),    *)
(* where ai is the coefficient of summand ei. Factors with ri = 0 are        *)
(* skipped.                                                                  *)

fun explSumExp m 1 = m
  | explSumExp m n =
    let
      fun g (_,_,(_,_,[])) = raise Bug "explSumExp"
        | g (_, _, (c, mp, 0 :: l)) = (c,mp,l)
        | g (e, a, (c, mp, r :: l)) =
          (c * exponential a r, explProdCons (e,r) mp, l)

      fun f multi ms =
          let
            val np = multinomial n multi

            val (np,mp,_) = Map.foldl g (np,explEmpty,multi) m

            val ms = explSumCons (mkProd mp, np) ms
          in
            case next_multinomial multi of
              NONE => ms
            | SOME multi => f multi ms
          end
    in
      f (first_multinomial n (Map.size m)) explEmpty
    end;

(* Remove exact (base, exponent) entries from a product map. Raises Error if *)
(* a base is missing, or is present with a different exponent.               *)

local
  fun subtract (e,n,m) : expressionl =
      case Map.peek m e of
        SOME n' => if n = n' then Map.delete m e
                   else raise Error "explSubtract*: wrong exp"
      | NONE => raise Error "explSubtract*: no such base";
in
  fun explSubtract m1 m2 = Map.foldl subtract m1 m2;

  fun explSubtractSingle m (e,n) = subtract (e,n,m);

  fun explSubtractUnit m e = subtract (e,1,m);
end;

(* Apply one equation l = r to the product map m: remove the factors of l    *)
(* (all of a Prod l, or l itself with exponent 1) and multiply by r.         *)
(* Raises Error if l's factors do not occur exactly in m.                    *)

fun explProdRewr (Prod l_m, r) m = explProdCons (r,1) (explSubtract m l_m)
  | explProdRewr (l_e,r) m = explProdCons (r,1) (explSubtractUnit m l_e);

(* Rewrite m with the first applicable equation; Error if none applies.      *)

fun explProdRewrite eqns m =
    case Useful.first (total (C explProdRewr m)) eqns of
      NONE => raise Error "explProdRewrite"
    | SOME m => m;

local
  (* One normalization step. Raises Error "unchanged" (or another Error)     *)
  (* when no step applies:                                                   *)
  (*  - Sum: normalize each summand, re-flattening the results;              *)
  (*  - Prod with a Sum factor: distribute, expanding Sum^n via explSumExp;  *)
  (*  - other Prod: rewrite with the first applicable equation.              *)
  (* Keys are ordered Var < Sum < Prod and a Prod map has no Prod keys       *)
  (* (invariant 2b), so the entry found by Map.findr (presumably the         *)
  (* greatest key) is a Sum whenever the product has a Sum factor.           *)

  fun norm _ (Var _) = raise Error "unchanged"
    | norm eqns (Sum m) =
      let
        fun f (e,n,(changed,m)) =
            case total (norm eqns) e of
              SOME e => (true, explSumCons (e,n) m)
            | NONE => (changed, explSumCons (e,n) m)

        val (changed,m) = Map.foldl f (false,explEmpty) m
      in
        if not changed then raise Error "unchanged" else mkSum m
      end
    | norm eqns (Prod m) =
      case Map.findr (K true) m of
        SOME (e as Sum ms, n) =>
        if Map.null ms then zero
        else
          let
            val m = Map.delete m e

            fun f (e,n,ms) =
                explSumCons (mkProd (explProdCons (e,1) m), n) ms

            val ms = explSumExp ms n
          in
            mkSum (if Map.null m then ms else Map.foldl f explEmpty ms)
          end
      | _ => mkProd (explProdRewrite eqns m);

  (* Apply norm until it raises Error (no further step applies).             *)

  fun repeat_norm eqns e =
      case total (norm eqns) e of SOME e => repeat_norm eqns e | NONE => e;
in
  (* Normalize exp using the given list of (lhs, rhs) equations.             *)
  (* NOTE(review): this unconditionally prints its input and result to       *)
  (* stdout, which looks like leftover debug tracing; consider removing it   *)
  (* or guarding it with a trace flag.                                       *)

  fun normalize {equations} exp =
      let
        val _ = print ("normalize: input =\n" ^ toString exp ^ "\n")
        val exp = repeat_norm equations exp
        val _ = print ("normalize: result =\n" ^ toString exp ^ "\n")
      in
        exp
      end;
end;

(* Quick testing
installPP pp;

val Sum ms = add (add (Var "x", Var "y"), Var "z");

Sum (explSumExp ms 9);

val e = multiply (Var "x", add (fromInt 1, Var "y"));

val e = multiply (Var "x", fromInt (~1));

try (normalize {equations = []}) e;
*)

end