(* Author: Michael Norrish *)

(* very simple-minded implementation of arbitrary precision natural
   numbers *)

structure Arbnumcore :> Arbnumcore =
struct

(* app load ["Substring", "Random", "Int", "Process"] *)

(* extract (a, i, lenopt): the vector holding the slice of array a that
   starts at index i and has length lenopt (NONE = up to the end of a) *)
fun extract arg = ArraySlice.vector(ArraySlice.slice arg)

(* copy the slice (src, si, len) of vector src into array dst, starting at
   index di of dst *)
fun copyVec' {di,dst,len,src,si} = let
  val v = VectorSlice.vector(VectorSlice.slice(src,si,len));
in
  Array.copyVec {di = di, dst = dst, src = v}
end


(* base must be <= the sqrt of MaxInt *)
val BASE = 10000;
val BASEless1 = BASE - 1;

type num = int list

(* expose the underlying digit list (least significant digit first) *)
fun asList x = x

(* each element in the list is in the range 0 - (BASE - 1), least significant
   "digit" is the first element in the list.  A normalised number has no
   most-significant zero digits, except that zero itself is [0]; structural
   equality on nums relies on this. *)

val zero = [0];
val one = [1];
val two = [2];

(* successor; a carry ripples into the more significant digits *)
fun plus1 [] = raise Fail "Should never happen"
  | plus1 [n] = if n = BASEless1 then [0, 1] else [n + 1]
  | plus1 (n::xs) = if n = BASEless1 then 0::plus1 xs else (n + 1)::xs

(* predecessor; raises Fail on zero.
   When the borrow empties the most significant digit (e.g. [0,1], which is
   BASE), the recursive call returns [0]; that digit must be dropped,
   otherwise the result would be the denormalised [BASE-1, 0], which is not
   structurally equal to [BASE-1]. *)
fun less1 [] = raise Fail "arbnum invariant falsified"
  | less1 [x] = if x = 0 then raise Fail "Can't take one off zero"
                else [x - 1]
  | less1 (0::xs) =
      (case less1 xs of
         [0] => [BASEless1]
       | xs' => BASEless1::xs')
  | less1 (x::xs) = (x - 1)::xs
val less2 = less1 o less1

val plus2 = plus1 o plus1

(* doubling, digit by digit, with the carry pushed into the tail *)
fun times2 [] = []
  | times2 [x] = if 2 * x < BASE then [2 * x] else [2 * x - BASE, 1]
  | times2 (x::xs) =
    if 2 * x < BASE then (2 * x)::(times2 xs)
    else (2 * x - BASE)::(plus1 (times2 xs))

(* halve a number whose digits are given MOST significant first; the
   remainder of each digit is carried into the next, less significant, one *)
fun revdiv2 [] = []
  | revdiv2 [x] = [x div 2]
  | revdiv2 (x::y::xs) = let
      val dividend = x div 2
      val remainder = x mod 2
    in
      dividend::(revdiv2 ((remainder * BASE) + y::xs))
    end

(* floor (n / 2), with leading zero digits removed from the result *)
fun div2 n = let
  val n' = List.rev n
  (* drop most-significant zeroes, but always keep the last digit *)
  fun strip [] = []
    | strip [x] = [x]
    | strip (0::xs) = strip xs
    | strip xs = xs
in
  List.rev (strip (revdiv2 n'))
end

(* n mod 2; BASE is even, so only the least significant digit matters *)
fun mod2 (x::_) = [x mod 2]
  | mod2 [] = raise Fail "arbnum representation invariant violated"

(* convert a non-negative int; raises Overflow on a negative argument *)
fun fromInt n = let
  val _ = n >= 0 orelse raise Overflow
  val dividend = n div BASE
  val rem = n mod BASE
in
  rem::(if dividend > 0 then fromInt dividend else [])
end

(* convert back to int, using the Int structure's (possibly overflowing)
   arithmetic *)
fun toInt [] = 0
  | toInt (x::xs) = Int.+(x, Int.*(BASE, toInt xs))



(* From here on + denotes addition on nums; int arithmetic is written with
   explicit Int.+ etc. *)
fun (x + y) = addwc x y false
and
(* add with carry: b is the carry coming in from the previous digit *)
  addwc xn yn b =
  case xn of
    [] =>
      if b then
        if null yn then one else plus1 yn
      else yn
  | (x::xs) => let
    in
      case yn of
        [] =>
          if b then
            if null xn then one else plus1 xn
          else xn
      | (y::ys) => let
          val xy = Int.+(x, y)
          val xyc = if b then Int.+(xy, 1) else xy
          val (carry, rem) = if xyc >= BASE then (true, Int.-(xyc, BASE))
                             else (false, xyc)
        in
          rem::(addwc xs ys carry)
        end
    end

(* floor of a non-negative real as a num.
   Raises Overflow if r is negative, is not comparable with 0.0, or is
   infinite.  If r is too big for Real.floor, it is repeatedly divided by
   BASE until Real.floor succeeds; the integer obtained is shifted back up
   and the floor of what is left of r is added on. *)
fun floor r = let
  val _ = 0.0 <= r orelse raise Overflow
  fun count_base_powers r =
    (Real.floor r, 0)
    handle Overflow => let
             val r' = r / real BASE
             (* dividing an infinite r leaves it unchanged, so without this
                check the recursion would never terminate *)
             val _ = r' < r orelse raise Overflow
             val (res, n) = count_base_powers r'
           in
             (res, Int.+(n,1))
           end
  val (sig_part, exp) = count_base_powers r
  (* multiply a num by BASE^n by prepending n zero digits *)
  fun shl n a = if n = 0 then a else shl (n - 1) (0 :: a)
  (* multiply a real by BASE^n *)
  fun exp_r n r = if n = 0 then r else exp_r (n - 1) (r * real BASE)
  val sig_anum = shl exp (fromInt sig_part)
  val sig_real = exp_r exp (real sig_part)
in
  if exp > 0 then sig_anum + floor (r - sig_real)
  else sig_anum
end

(* convert to real by folding over the digits, most significant first *)
fun toReal anum = List.foldl (fn (i,acc) => Real.+(acc * real BASE, real i))
                             0.0 (List.rev anum)

(* x0 + BASE * x < y0 + BASE * y   (where x0, y0 < BASE)
     =
   x < y  orelse  (x = y andalso x0 < y0)

   A number with fewer digits is smaller; this relies on both arguments
   being normalised. *)
fun (xn < yn) =
  case (xn, yn) of
    (_, []) => false
  | ([], _) => true
  | (x::xs, y::ys) => xs < ys orelse (xs = ys andalso Int.<(x,y))
fun (xn <= yn) = xn = yn orelse xn < yn
fun (xn >= yn) = yn <= xn
fun (xn > yn) = yn < xn

fun compare(x,y) = if x < y then LESS else if y < x then GREATER else EQUAL

(* remove most-significant zero digits; the empty list and all-zero lists
   become [0] *)
fun normalise [] = [0]
  | normalise x = let
      fun strip_leading_zeroes [] = []
        | strip_leading_zeroes (list as n::ns) =
          if n = 0 then strip_leading_zeroes ns
          else list
      val x' = List.rev (strip_leading_zeroes (List.rev x))
    in
      if null x' then [0] else x'
    end

(* truncated subtraction: the result is zero when xn < yn *)
fun (xn - yn) =
  if xn < yn then zero else normalise (subwc xn yn false)
and
(* subtract with borrow: b is the borrow owed to the previous digit *)
  subwc xn yn b =
  case (xn, yn) of
    (_, []) => if b then less1 xn else xn
  | ([], _) => zero
  | (x::xs, y::ys) => let
      val (x', carry) =
          if b then
            if Int.<=(x,y) then (Int.-(Int.+(x,BASEless1), y), true)
            else (Int.-(Int.-(x,y), 1), false)
          else
            if Int.<(x,y) then (Int.-(Int.+(x, BASE), y), true)
            else (Int.-(x,y), false)
    in
      x'::subwc xs ys carry
    end

(* multiply the num xn by the single digit n (0 <= n < BASE):
   (x0 + BASEx) * y = x0 * y + BASE * x * y *)
fun single_digit(n, xn) =
  case n of
    0 => zero
  | 1 => xn
  | 2 => times2 xn
  | _ => let
      fun f [] = []
        | f (x::xs) = let
            val newx = Int.*(n,x)
            val (rem, carry) = (Int.mod(newx, BASE), Int.div(newx,BASE))
          in
            if carry = 0 then rem::f xs
            else rem::(f xs + [carry])
          end
    in
      f xn
    end


(* schoolbook multiplication, one digit of xn at a time *)
fun (xn * yn) =
  case (xn, yn) of
    ([], _) => zero
  | (_, []) => zero
  | (x::xs, _) => normalise(single_digit(x, yn) + (0::(xs * yn)))

(* xn to the power yn, by repeated squaring *)
fun pow (xn, yn) =
  if mod2 yn = one then
    (xn * pow(xn, less1 yn))
  else
    if yn = zero then
      one
    else
      pow((xn * xn), div2 yn);

fun replicate n el = List.tabulate(n, fn _ => el)

(* digit-wise subtraction of two equal-length digit lists (least significant
   first), threading a borrow.  Returns the digits prepended onto acc (so
   the result comes back in the wrong order, i.e. most significant first)
   together with the final borrow. *)
fun comp_sub acc carry (xn, yn) =
  case (xn, yn) of
    ([], []) => (acc, carry)
  | (x::xs, y::ys) => let
      val (res, newcarry) =
          if carry then
            if Int.>(x, y) then
              (Int.-(x, Int.+(y, 1)), false)
            else
              (Int.-(Int.+(x, BASEless1), y), true)
          else
            if Int.>=(x, y) then
              (Int.-(x, y), false)
            else
              (Int.-(Int.+(x,BASE), y), true)
    in
      comp_sub (res::acc) newcarry (xs, ys)
    end
  | _ => raise Fail "comp_sub : arguments of different length"


(* quotient and remainder of the num xn by the int y, where y < BASE.
   Raises Div if y = 0. *)
fun single_divmod (xn:int list) (y:int) = let
  val xnr = List.rev xn
  (* long division over the digits, most significant first; acc collects
     the quotient digits least significant first *)
  fun loop xn acc =
    case xn of
      [] => raise Fail "single_divmod: can't happen"
    | [x] => let
        val q = Int.div(x, y)
      in
        (q::acc, [Int.mod(x,y)])
      end
    | [x1, x2] => let
        val x = Int.+(Int.*(x1, BASE), x2)
        val q = Int.div(x, y)
      in
        ((fromInt q) @ acc, [Int.mod(x, y)])
      end
    | (x1::x2::xs) => let
        val x = Int.+(Int.*(x1, BASE), x2)
        val q = Int.div(x, y)
        val r = Int.mod(x, y)
        (* r < y, so r < BASE *)
      in
        loop (r::xs) ((fromInt q) @ acc)
      end
  val _ = y <> 0 orelse raise Div
  val (q, r) = loop xnr []
in
  (normalise q, normalise r)
end


(* (xn div yn, xn mod yn); raises Div if yn is zero *)
fun divmod (xn, yn) =
  if yn = zero then raise Div
  else let

    (* following algorithm from Knuth 4.3.1 Algorithm D *)
    (* we require
        - that length vn > 1
        - that length un >= length vn
        - that hd vn <> 0
    *)
    fun KnuthD un vn = let
      (* Knuth's algorithm is stateful so we mimic it here with
         two arrays u and v *)
      (* D1. normalise: scale both operands by d *)
      val d = Int.div(BASE, Int.+(List.last vn, 1))
      val n = length vn
      val m = Int.-(length un, length vn)
      val normalised_u = single_digit(d, un)
      val normalised_v = single_digit(d, vn)
      val _ = length normalised_v = length vn orelse
              raise Fail "normalised_v not same length as v"
      val norm_v_rev = List.rev normalised_v
      val v1 = hd norm_v_rev
      val v2 = hd (tl norm_v_rev)
      (* u is held most significant digit first, and always gets one more
         digit than un had *)
      val norm_u_rev =
        if (length normalised_u = length un) then 0::List.rev normalised_u
        else List.rev normalised_u
      val u = Array.fromList norm_u_rev
      val _ = Array.length u = Int.+(m, Int.+(n, 1)) orelse
              raise Fail "Array u of unexpected length"
      val q = Array.array(Int.+(m, 1), 0)
      fun inner_loop j = let
        infix 9 sub
        open Array
        (* D3. Calculate q hat *)
        fun qhat_test qhat = let
          open Int
        in
          v2 * qhat >
          (u sub j * BASE + u sub (j + 1) - qhat * v1) * BASE + u sub (j + 2)
        end
        val qhat0 = let
          open Int
        in
          if u sub j = v1 then BASEless1
          else ((u sub j) * BASE + (u sub (j + 1))) div v1
        end
        val qhat =
          if qhat_test qhat0 then
            if qhat_test (Int.-(qhat0, 1)) then Int.-(qhat0, 2)
            else Int.-(qhat0, 1)
          else
            qhat0
        (* D4. multiply and subtract *)
        val uslice_v = extract(u, j, SOME (Int.+(n, 1)))
        val uslice_l = List.rev (Vector.foldr (op::) [] uslice_v)
        val multiply_result0 = single_digit(qhat, normalised_v)
        (* pad with most-significant zeroes so that comp_sub sees two lists
           of equal length *)
        val multiply_result = let
          val mr_len = List.length multiply_result0
          val u_len = List.length uslice_l
        in
          if mr_len <> u_len then
            multiply_result0 @ replicate (Int.-(u_len, mr_len)) 0
          else
            multiply_result0
        end
        val (newu_l, d4carry) = comp_sub [] false (uslice_l, multiply_result)
        val newu_v = Vector.fromList newu_l
        val _ = copyVec' {di=j, dst=u, len=NONE, src=newu_v, si=0}
        (* D5. test remainder *)
        val () = update(q, j, qhat)
        val _ =
          if d4carry then let
              (* D6. Add back *)
              val uslice_v = extract(u, j, SOME (Int.+(n, 1)))
              val uslice_l = List.rev(Vector.foldr (op::) [] uslice_v)
              val newu0 = uslice_l + normalised_v
              (* have to ignore rightmost digit *)
              val newu = List.drop(List.rev newu0, 1);
              val newu_v = Vector.fromList newu
            in
              update(q, j, Int.-(q sub j, 1));
              copyVec' {di = j, dst = u, len = NONE, src = newu_v, si = 0}
            end
          else ()
        open Int
      in
        if j + 1 <= m then inner_loop (j + 1)
        else ()
      end
      val unnormal_result = inner_loop 0
      val qn = normalise (List.rev (Array.foldr (op::) [] q))
      val rn0 = let
        open Int
      in
        normalise (List.rev
                   (Vector.foldr (op::) [] (extract(u, m + 1, NONE))))
      end
      (* D8. unnormalise: the remainder left in u is still scaled by d *)
      val rn = #1 (single_divmod rn0 d)
    in
      (qn, rn)
    end
  in
    if length yn = 1 then
      single_divmod xn (hd yn)
    else
      if Int.>(length yn, length xn) then
        (zero, xn)
      else
        KnuthD xn yn
  end

fun (xn div yn) = #1 (divmod (xn, yn))
fun (xn mod yn) = #2 (divmod (xn, yn))

(* floor of the base-2 logarithm; raises Domain on zero *)
local
  fun recurse x n = if x = zero then n else recurse (div2 x) (plus1 n)
 in
   fun log2 x = if x = zero then raise Domain else recurse (div2 x) zero
end

(* greatest common divisor, by Euclid's algorithm *)
local
  (* requires j <> zero *)
  fun gcd' i j = let
    val r = i mod j
  in
    if r = zero then j else gcd' j r
  end
in
fun gcd(i,j) = if i = zero then j
               else if j = zero then i
               else if i < j then gcd' j i
               else gcd' i j
end

(* base to the power exponent on ints, by repeated squaring; an exponent
   <= 0 gives 1 *)
fun intexp(base,exponent) = let
  fun recurse acc b n =
      if Int.<=(n,0) then acc
      else if Int.mod(n,2) = 1 then
        recurse (Int.*(acc, b)) b (Int.-(n,1))
      else recurse acc (Int.*(b,b)) (Int.div(n,2))
in
  recurse 1 base exponent
end

(* build a string reader for the given radix.  The string is consumed in
   chunks of chunksize characters, each read with Int.scan; chunkshift is
   base to the power chunksize as a num.  The resulting reader raises
   Fail ("String not " ^ <radix name>) when a chunk cannot be read. *)
fun genFromString rdx = let
  open StringCvt
  val (base, chunksize, chunkshift, s_rdx) =
      case rdx of
        BIN => (2, 10, fromInt 1024, "BIN")
      | OCT => (8, 5, fromInt 32768, "OCT")
      | DEC => (10, 5, fromInt 100000, "DEC")
      | HEX => (16, 5, fromInt 1048576, "HEX")
  (* SOME i only if the whole of s is consumed by the scan *)
  fun readchunk s =
      case Int.scan rdx Substring.getc (Substring.full s) of
        SOME (i, r) => if Substring.size r = 0 then SOME i else NONE
      | _ => NONE
  fun recurse acc s = let
    val sz = size s
  in
    if Int.<=(sz, chunksize) then
      (* last (possibly short) chunk: shift by base to the power sz *)
      fromInt (intexp(base,sz)) * acc + fromInt (valOf (readchunk s))
    else let
        val pfx = substring(s, 0, chunksize)
        val sfx = String.extract(s, chunksize, NONE)
      in
        recurse (chunkshift * acc + fromInt (valOf (readchunk pfx))) sfx
      end
  end
in
  fn s => recurse zero s handle Option => raise Fail ("String not " ^ s_rdx)
end
val fromHexString = genFromString StringCvt.HEX
val fromOctString = genFromString StringCvt.OCT
val fromBinString = genFromString StringCvt.BIN
val fromString = genFromString StringCvt.DEC


(* decimal representation, produced one digit at a time by dividing by 10 *)
fun toString n =
  if n = zero then "0"
  else let
    fun nonzero_recurse n =
      if n = zero then ""
      else let
        val (q,r) = divmod(n, fromInt 10)
      in
        nonzero_recurse q^Int.toString (toInt r)
      end
  in
    nonzero_recurse n
  end


local
  (* digit character for 0 <= n < 16; upper-case letters for 10 and above *)
  fun toChar n =
      str (if Int.<(n, 10)
           then chr (Int.+(ord #"0", n))
           else chr (Int.-(Int.+(ord #"A", n), 10)))
  fun toBaseString base n =
      let val (q,r) = divmod(n, base)
          val s = toChar (toInt r)
      in
        if q = zero then s else toBaseString base q^s
      end
in
  val toBinString = toBaseString (fromInt 2)
  val toOctString = toBaseString (fromInt 8)
  val toHexString = toBaseString (fromInt 16)
end

(* integer square root: iterates a := (a * a + n) div a div 2, starting
   from one, until a * a <= n < (a + 1) * (a + 1) *)
fun isqrt n =
    if n < two
    then n
    else let
        fun iter a =
            if a * a <= n andalso n < (a + one) * (a + one)
            then a
            else iter (div2 ((a * a + n) div a))
      in
        iter one
      end

(*  useful test code follows
exception ArgsBad;


fun test_op (nf, origf, P, print_opn, testresult) arg1 arg2 = let
  val _ = P (arg1, arg2) orelse raise ArgsBad
  val orig_result = origf(arg1, arg2)
  val new_result = nf(fromInt arg1, fromInt arg2)
  val ok = testresult(new_result, orig_result)
in
  print (Int.toString arg1^print_opn^Int.toString arg2);
  if ok then print " agree\n"
  else (print " disagree\n"; raise Fail "Urk")
end

fun test n opdetails = let
  open Random
  val gen = newgen()
  fun do_test () = let
    val arg1 = range (0, 60000000) gen
    val arg2 = range (0, 60000000) gen
  in
    test_op opdetails arg1 arg2
  end
  fun doit_until_success f = f ()
    handle Fail s => raise Fail s
         | Interrupt => raise Interrupt
         | _ => doit_until_success f
  fun doit f n = if Int.<=(n, 0) then () else (doit_until_success f ;
                                               doit f (Int.-(n,1)))
in
  doit do_test n
end

fun testintresult (new, old) = toInt new = old
val test_addition = (op+, Int.+, (fn _ => true), " + ", testintresult);
val test_less = (op<, Int.<, (fn _ => true), " < ", op=);
val test_leq = (op<=, Int.<=, (fn _ => true), " <= ", op=);
val test_subtraction = (op-, Int.-, (fn (x,y) => Int.>=(x,y)), " - ", testintresult)
val test_mult = (op*, Int.*, (fn _ => true), " * ", testintresult)
val test_div = (op div, Int.div, (fn(x,y) => y <> 0), " div ", testintresult);
val test_mod = (op mod, Int.mod, (fn(x,y) => y <> 0), " mod ", testintresult);

val _ = test_op test_addition 78826 3251
val _ = test_op test_div 49772146 458048
val _ = test_op test_mod 34182186 2499
val _ = test_op test_mod 26708509 29912224
val _ = test_op test_div 6258 42171
val _ = test_op test_div 6 13766
val _ = test_op test_mod 6 13766
val _ = test_op test_mod 38294758 10769
val _ = test_op test_div 38294758 10769

val _ = test 200 test_addition
val _ = test 200 test_less
val _ = test 200 test_leq
val _ = test 200 test_subtraction
(* val _ = test 30 test_mult *)
val _ = test 200 test_div
val _ = test 200 test_mod

exception FailedProp
fun testproperty3 (f, printprop) = let
  val gen = Random.newgen()
  fun generate_arg () = let
    val size = Random.range(1,8) gen
  in
    normalise (Random.rangelist(0,BASE) (size, gen))
  end
  val x = generate_arg()
  val y = generate_arg()
  val z = generate_arg()
  val propstring =
    "Property "^printprop^" for x = "^toString x^", y = "^toString y^
    ", z = "^toString z^"..."
in
  print propstring;
  if f(x,y,z) then print "OK\n"
  else (print "FAILED\n"; raise FailedProp)
end

fun testprop n prop = if Int.<=(n, 0) then ()
                      else (testproperty3 prop; testprop (Int.-(n,1)) prop) handle FailedProp => Process.exit Process.failure

val addition_associative =
  ((fn (x,y,z) => (x + (y + z) = (x + y) + z)),
   "(x + (y + z) = (x + y) + z)")
val mult_assoc =
  ((fn (x,y,z) => x * (y * z) = (x * y) * z),
   "x * (y * z) = (x * y) * z")
val distrib =
  ((fn (x,y,z) => x * (y + z) = (x * y) + (x * z)),
   "x * (y + z) = (x * y) + (x * z)")
val divmod_test =
  ((fn (x,y,z) => if y <> zero then (x div y) * y + (x mod y) = x else true),
   "(x div y) * y + (x mod y) = x")

val xn = fromString "260309023368"
val yn = fromString "76734110"
val _ =
  if #1 divmod_test (xn, yn, zero) then print "OK\n"
  else (print "FAILED\n"; Process.exit Process.failure)

val _ = testprop 100 addition_associative
val _ = testprop 100 mult_assoc
val _ = testprop 100 distrib
val _ = testprop 100 divmod_test
*)

end