1(* Author: Michael Norrish *)
2
3(* very simple-minded implementation of arbitrary precision natural
4   numbers *)
5
6structure Arbnumcore :> Arbnumcore =
7struct
8
9(* app load ["Substring", "Random", "Int", "Process"] *)
10
(* extract (arr, start, len): a vector copy of the slice of arr beginning at
   start, of length len (or to the end when len is NONE). *)
fun extract (arr, start, len) =
  ArraySlice.vector (ArraySlice.slice (arr, start, len))
12
(* copyVec' {di, dst, len, src, si}: copy the slice of vector src that starts
   at si (length len, or to the end when NONE) into array dst at index di. *)
fun copyVec' {di, dst, len, src, si} =
  Array.copyVec {di = di, dst = dst,
                 src = VectorSlice.vector (VectorSlice.slice (src, si, len))}
18
19
(* The radix of the digit-list representation.  BASE * BASE must not exceed
   the largest int, because products of two digits are formed with ordinary
   int arithmetic. *)
val BASE = 10000
val BASEless1 = BASE - 1
23
(* A num is a list of "digits", each in the range 0 to BASE - 1, with the
   least significant digit first: [d0, d1, ..., dk] stands for
   d0 + d1 * BASE + ... + dk * BASE^k.  Normalised values carry no zero
   digit at the most significant end; zero itself is [0]. *)
type num = int list

(* Expose the underlying digit list. *)
fun asList digits = digits

val zero = [0]
val one = [1]
val two = [2]
34
(* plus1 n: n + 1, rippling a carry upward through digits equal to
   BASEless1. *)
fun plus1 digits =
  case digits of
    [] => raise Fail "Should never happen"
  | d :: rest =>
      if d <> BASEless1 then (d + 1) :: rest
      else if null rest then [0, 1]
      else 0 :: plus1 rest
38
(* less1 n: n - 1.  Raises Fail on zero, and on the empty list, which is not
   a valid num.  The result is normalised whenever n is: a borrow that uses
   up a most-significant digit of 1 does not leave a zero digit on top. *)
fun less1 [] = raise Fail "arbnum invariant falsified"
  | less1 [x] = if x = 0 then raise Fail "Can't take one off zero"
                else [x - 1]
  | less1 (0 :: xs) =
      (* Borrow from the higher digits.  If they were exactly [1] they become
         [0], which must be dropped: BASE - 1 is [BASEless1], not
         [BASEless1, 0]. *)
      (case less1 xs of
         [0] => [BASEless1]
       | higher => BASEless1 :: higher)
  | less1 (x :: xs) = (x - 1) :: xs

(* less2 n: n - 2. *)
val less2 = less1 o less1
45
(* plus2 n: n + 2. *)
fun plus2 n = plus1 (plus1 n)
47
(* times2 n: 2 * n.  A doubled digit that reaches BASE keeps its excess in
   place and carries one into the doubled higher digits. *)
fun times2 digits =
  case digits of
    [] => []
  | d :: rest => let
      val doubled = 2 * d
    in
      if doubled < BASE then doubled :: times2 rest
      else if null rest then [doubled - BASE, 1]
      else (doubled - BASE) :: plus1 (times2 rest)
    end
53
(* revdiv2 ds: halve a digit list given most-significant digit first.  The
   remainder of each digit is pushed down into the next lower digit. *)
fun revdiv2 digits =
  case digits of
    [] => []
  | [d] => [d div 2]
  | hi :: next :: rest =>
      (hi div 2) :: revdiv2 ((hi mod 2) * BASE + next :: rest)
62
(* div2 n: n div 2, with any zero digits created at the top removed. *)
fun div2 n = let
  (* Remove zeros from the front of a most-significant-first list, always
     keeping at least one digit. *)
  fun dropZeros (0 :: (rest as _ :: _)) = dropZeros rest
    | dropZeros ds = ds
in
  List.rev (dropZeros (revdiv2 (List.rev n)))
end
72
(* mod2 n: n mod 2 as a num.  BASE is even, so only the least significant
   digit matters. *)
fun mod2 digits =
  case digits of
    [] => raise Fail "arbnum representation invariant violated"
  | low :: _ => [low mod 2]
75
(* fromInt n: the num for a non-negative int n.  Raises Overflow when n is
   negative. *)
fun fromInt n = let
  fun split m =
    if m < BASE then [m]
    else m mod BASE :: split (m div BASE)
in
  if n < 0 then raise Overflow else split n
end
83
(* toInt n: the int value of n, evaluated Horner-style from the most
   significant digit; raises Overflow if it does not fit. *)
fun toInt digits =
  List.foldr (fn (d, acc) => Int.+(d, Int.*(BASE, acc))) 0 digits
86
87
88
(* Addition of nums: digits are added pairwise from the least significant
   end, propagating a carry.  Within this declaration "+" names the num
   addition being defined, so digit arithmetic uses Int.+ explicitly. *)
fun (xn + yn) = addwc xn yn false
(* addwc xn yn carry: xn + yn, plus one more when carry is true. *)
and addwc xn yn carry =
  case (xn, yn) of
    ([], []) => if carry then one else []
  | ([], _) => if carry then plus1 yn else yn
  | (_, []) => if carry then plus1 xn else xn
  | (x :: xs, y :: ys) => let
      val total = Int.+(Int.+(x, y), if carry then 1 else 0)
    in
      if Int.>=(total, BASE) then Int.-(total, BASE) :: addwc xs ys true
      else total :: addwc xs ys false
    end
114
(* floor r: the largest natural number not exceeding the real r.

   Raises Overflow if r is negative, NaN or infinite.  Large reals are handled
   by repeatedly dividing by BASE until Real.floor no longer overflows,
   converting that leading part, and then recursing on whatever the leading
   part did not account for. *)
fun floor r = let
  (* Reject infinities as well as negatives and NaN (for which 0.0 <= r is
     false).  Real.floor raises Overflow on an infinity, and dividing an
     infinity by BASE never makes it finite, so without this check
     count_base_powers would recurse forever. *)
  val _ = (Real.isFinite r andalso 0.0 <= r) orelse raise Overflow
  (* Returns (s, k) with s = Real.floor (r / BASE^k), for the least k at
     which that conversion does not overflow. *)
  fun count_base_powers r =
      (Real.floor r, 0)
      handle Overflow => let
               val (res, n) = count_base_powers (r / real BASE)
             in
               (res, Int.+(n,1))
             end
  val (sig_part, exp) = count_base_powers r
  (* shl n a: prepend n zero digits, i.e. multiply a by BASE^n. *)
  fun shl n a = if n = 0 then a else shl (n - 1) (0 :: a)
  (* exp_r n r: r * BASE^n, computed in floating point. *)
  fun exp_r n r = if n = 0 then r else exp_r (n - 1) (r * real BASE)
  val sig_anum = shl exp (fromInt sig_part)
  val sig_real = exp_r exp (real sig_part)
in
  if exp > 0 then sig_anum + floor (r - sig_real)
  else sig_anum
end
133
(* toReal n: n as a real, accumulated from the most significant digit down. *)
fun toReal anum =
  List.foldr (fn (digit, acc) => Real.+(Real.*(acc, real BASE), real digit))
             0.0 anum
136
(* Ordering on nums.

   A num [x0, x1, ..., xk] denotes x0 + BASE * x1 + ... + BASE^k * xk.  For
   normalised arguments a shorter list is the smaller number, and lists of
   equal length are decided by their most significant (last) differing
   digit.  compare walks both lists once, so the cost is linear in the length
   of the shorter argument.  The four relational operators are all defined
   in terms of compare. *)
fun compare (xn, yn) =
  case (xn, yn) of
    ([], []) => EQUAL
  | ([], _) => LESS
  | (_, []) => GREATER
  | (x :: xs, y :: ys) =>
      (case compare (xs, ys) of
         EQUAL => Int.compare (x, y)
       | higher => higher)

fun (xn < yn) = compare (xn, yn) = LESS
fun (xn <= yn) = compare (xn, yn) <> GREATER
fun (xn >= yn) = yn <= xn
fun (xn > yn) = yn < xn
150
(* normalise n: drop zero digits from the most significant end; an empty or
   all-zero list becomes [0]. *)
fun normalise digits = let
  fun dropZeros (0 :: rest) = dropZeros rest
    | dropZeros ds = ds
in
  case dropZeros (List.rev digits) of
    [] => [0]
  | msFirst => List.rev msFirst
end
160
(* Truncated subtraction: xn - yn is zero when xn < yn.  Within this
   declaration "-" names the num subtraction being defined, so digit
   arithmetic uses Int.- explicitly. *)
fun (xn - yn) =
  if xn < yn then zero else normalise (subwc xn yn false)
(* subwc xn yn borrow: digit-wise xn - yn, less one more when borrow is set. *)
and subwc xn yn borrow =
  case (xn, yn) of
    (_, []) => if borrow then less1 xn else xn
  | ([], _) => zero
  | (x :: xs, y :: ys) => let
      val diff = Int.-(x, y)
      val diff' = if borrow then Int.-(diff, 1) else diff
    in
      if Int.<(diff', 0) then Int.+(diff', BASE) :: subwc xs ys true
      else diff' :: subwc xs ys false
    end
178
(* single_digit (n, xn): the int n times the num xn.  Callers in this file
   pass n below BASE, so each digit product n * d fits in an int.

   Uses n * (x0 + BASE * xs) = n * x0 + BASE * (n * xs): the low part of
   each digit product stays in place and its high part is added one
   position further up. *)
fun single_digit (n, xn) =
  if n = 0 then zero
  else if n = 1 then xn
  else if n = 2 then times2 xn
  else let
    fun scale [] = []
      | scale (d :: ds) = let
          val product = Int.*(n, d)
          val low = Int.mod(product, BASE)
          val high = Int.div(product, BASE)
          val rest = scale ds
        in
          low :: (if high = 0 then rest else rest + [high])
        end
  in
    scale xn
  end
197
198
(* Schoolbook multiplication: with x0 the lowest digit of xn,
   xn * yn = x0 * yn + BASE * (rest of xn times yn), where multiplying by
   BASE is consing a zero digit. *)
fun (xn * yn) =
  case xn of
    [] => zero
  | x :: xs =>
      if null yn then zero
      else normalise (single_digit (x, yn) + (0 :: (xs * yn)))
204
(* pow (xn, yn): xn raised to the power yn, by repeated squaring;
   anything to the power zero is one. *)
fun pow (base, exponent) =
  if mod2 exponent = one then base * pow (base, less1 exponent)
  else if exponent = zero then one
  else pow (base * base, div2 exponent)
213
(* replicate count item: a list of count copies of item. *)
fun replicate count item = List.tabulate (count, fn _ => item)
(* comp_sub acc borrow (xn, yn): digit-wise xn - yn for two equal-length
   digit lists given least significant first, with an incoming borrow.  Each
   result digit is prepended to acc, so the digits come back most significant
   first (reversed relative to the inputs).  The second component is the
   outgoing borrow; when it is true the digits are those of
   BASE^length + (xn - yn).  Raises Fail if the lists differ in length. *)
fun comp_sub acc borrow (xn, yn) =
  case (xn, yn) of
    ([], []) => (acc, borrow)
  | (x :: xs, y :: ys) => let
      val raw = Int.-(x, if borrow then Int.+(y, 1) else y)
      val (digit, borrow') =
        if Int.<(raw, 0) then (Int.+(raw, BASE), true) else (raw, false)
    in
      comp_sub (digit :: acc) borrow' (xs, ys)
    end
  | _ => raise Fail "comp_sub : arguments of different length"
235
236
(* single_divmod xn y: (xn div y, xn mod y) for a one-digit divisor y, which
   must satisfy 0 < y < BASE.  Raises Div if y is zero.

   Schoolbook short division: digits are consumed from the most significant
   end, carrying the running remainder (always below y) into the next one. *)
fun single_divmod (xn:int list) (y:int) = let
  val _ = y <> 0 orelse raise Div
  val _ = not (null xn) orelse raise Fail "single_divmod: can't happen"
  (* List.foldr visits the least-significant-first list from its most
     significant end; prepending each quotient digit leaves the quotient
     least significant first. *)
  fun step (digit, (quot, remainder)) = let
    val cur = Int.+(Int.*(remainder, BASE), digit)
  in
    (Int.div(cur, y) :: quot, Int.mod(cur, y))
  end
  val (quot, remainder) = List.foldr step ([], 0) xn
in
  (normalise quot, normalise [remainder])
end
267
268
(* divmod (xn, yn): the pair (quotient, remainder) of dividing xn by yn.
   Raises Div if yn is zero.  A one-digit divisor goes to single_divmod; a
   divisor with more digits than the dividend gives quotient zero and
   remainder xn; everything else goes through KnuthD.  The length tests
   assume both arguments are normalised (no zero most-significant digit)
   -- NOTE(review): unnormalised inputs are not checked here. *)
fun divmod (xn, yn) =
  if yn = zero then raise Div
  else let

    (* Follows Knuth, TAOCP vol. 2, section 4.3.1, Algorithm D. *)
    (* KnuthD un vn requires
       - that length vn > 1
       - that length un >= length vn
       - that List.last vn <> 0, i.e. the most significant digit of vn
         is nonzero
     *)
    fun KnuthD un vn = let
      (* Knuth's algorithm is stateful, so it is mimicked here with two
         arrays: u holds the scaled dividend, most significant digit first,
         and is overwritten in place; q collects the quotient digits, most
         significant first.  The divisor stays a list. *)
      (* D1. Normalise: scale both operands by
         d = BASE div (top digit of vn + 1). *)
      val d = Int.div(BASE, Int.+(List.last vn, 1))
      val n = length vn
      val m = Int.-(length un, length vn)
      val normalised_u = single_digit(d, un)
      val normalised_v = single_digit(d, vn)
      (* Sanity check: scaling by d must not lengthen the divisor. *)
      val _ = length normalised_v = length vn orelse
        raise Fail "normalised_v not same length as v"
      (* v1, v2: the two most significant digits of the scaled divisor. *)
      val norm_v_rev = List.rev normalised_v
      val v1 = hd norm_v_rev
      val v2 = hd (tl norm_v_rev)
      (* u must hold m + n + 1 digits: add a leading zero digit when
         scaling did not produce an extra top digit. *)
      val norm_u_rev =
        if (length normalised_u = length un) then 0::List.rev normalised_u
        else List.rev normalised_u
      val u = Array.fromList norm_u_rev
      val _ = Array.length u = Int.+(m, Int.+(n, 1)) orelse
        raise Fail "Array u of unexpected length"
      val q = Array.array(Int.+(m, 1), 0)
      (* Steps D3-D7 for quotient position j (index 0 is the most
         significant), then recurse on j + 1 while j + 1 <= m. *)
      fun inner_loop j = let
        infix 9 sub
        open Array
        (* D3. Calculate q hat *)
        (* Knuth's test for an over-large estimate: compares v2 * qhat with
           the partial remainder built from u[j], u[j+1] and u[j+2]. *)
        fun qhat_test qhat = let
          open Int
        in
          v2 * qhat >
          (u sub j * BASE + u sub (j + 1) - qhat * v1) * BASE + u sub (j + 2)
        end
        (* First estimate: the top two digits of the window over v1, or
           BASEless1 when the top digit equals v1. *)
        val qhat0 = let
          open Int
        in
          if u sub j = v1 then BASEless1
          else ((u sub j) * BASE + (u sub (j + 1)))  div v1
        end
        (* Lower the estimate by at most two, as directed by qhat_test. *)
        val qhat =
          if qhat_test qhat0 then
            if qhat_test (Int.-(qhat0, 1)) then Int.-(qhat0, 2)
            else Int.-(qhat0, 1)
          else
            qhat0
        (* D4. multiply and subtract *)
        (* The window u[j .. j+n], turned least significant first. *)
        val uslice_v = extract(u, j, SOME (Int.+(n, 1)))
        val uslice_l = List.rev (Vector.foldr (op::) [] uslice_v)
        (* qhat times the scaled divisor, padded with zero digits to the
           window's length so comp_sub sees equal-length lists. *)
        val multiply_result0 = single_digit(qhat, normalised_v)
        val multiply_result = let
          val mr_len = List.length multiply_result0
          val u_len = List.length uslice_l
        in
          if  mr_len <> u_len then
            multiply_result0 @ replicate (Int.-(u_len, mr_len)) 0
          else
            multiply_result0
        end
        (* comp_sub yields the difference most significant first, matching
           u's layout, and a flag that is set when it borrowed past the top
           digit; in that case D6 below repairs the result. *)
        val (newu_l, d4carry) = comp_sub [] false (uslice_l, multiply_result)
        val newu_v = Vector.fromList newu_l
        val _ = copyVec' {di=j, dst=u, len=NONE, src=newu_v, si=0}
        (* D5. test remainder *)
        val () = update(q, j, qhat)
        val _ =
          if d4carry then let
            (* D6. Add back: decrement the quotient digit and add the scaled
               divisor back into the window. *)
            val uslice_v = extract(u, j, SOME (Int.+(n, 1)))
            val uslice_l = List.rev(Vector.foldr (op::) [] uslice_v)
            val newu0 = uslice_l + normalised_v
            (* newu0 is least significant first, so its rightmost element
               is the carry out of the window's top digit, which cancels the
               borrow from D4.  After List.rev it is the first element, and
               List.drop removes it. *)
            val newu = List.drop(List.rev newu0, 1);
            val newu_v = Vector.fromList newu
          in
            update(q, j, Int.-(q sub j, 1));
            copyVec' {di = j, dst = u, len = NONE, src = newu_v, si = 0}
          end
          else ()
        open Int
      in
        (* D7. Loop on j. *)
        if j + 1 <= m then inner_loop (j + 1)
        else ()
      end
      (* Run D3-D7 for j = 0 .. m; the value itself is just (). *)
      val unnormal_result = inner_loop 0
      val qn = normalise (List.rev (Array.foldr (op::) [] q))
      (* D8. Unnormalise: the remainder is the last n digits of u, divided
         by the scale factor d. *)
      val rn0 = let
        open Int
      in
        normalise (List.rev
                   (Vector.foldr (op::) [] (extract(u, m + 1, NONE))))
      end
      val rn = #1 (single_divmod rn0 d)
    in
      (qn, rn)
    end
  in
    if length yn = 1 then
      single_divmod xn (hd yn)
    else
      if Int.>(length yn, length xn) then
        (zero, xn)
      else
        KnuthD xn yn
  end
379
(* Quotient and remainder, each projected out of divmod. *)
fun (xn div yn) = let val (quot, _) = divmod (xn, yn) in quot end
fun (xn mod yn) = let val (_, remainder) = divmod (xn, yn) in remainder end
382
local
  (* Add to count one for every halving needed to bring x down to zero. *)
  fun halvings x count =
    if x = zero then count else halvings (div2 x) (plus1 count)
in
  (* log2 x: the floor of the base-2 logarithm of x; raises Domain on zero. *)
  fun log2 x =
    if x = zero then raise Domain
    else halvings (div2 x) zero
end
388
local
  (* Euclid's algorithm; the divisor b must be nonzero. *)
  fun euclid a b = let
    val r = a mod b
  in
    if r = zero then b else euclid b r
  end
in
  (* gcd (i, j): greatest common divisor, where gcd (zero, j) = j and
     gcd (i, zero) = i. *)
  fun gcd (i, j) =
    case (i = zero, j = zero) of
      (true, _) => j
    | (false, true) => i
    | (false, false) => if i < j then euclid j i else euclid i j
end
401
(* intexp (base, exponent): base raised to exponent in int arithmetic, by
   binary exponentiation.  A non-positive exponent gives 1.  May raise
   Overflow. *)
fun intexp (base, exponent) = let
  fun loop acc b n =
    if Int.<=(n, 0) then acc
    else if Int.mod(n, 2) = 0 then loop acc (Int.*(b, b)) (Int.div(n, 2))
    else loop (Int.*(acc, b)) b (Int.-(n, 1))
in
  loop 1 base exponent
end
411
(* genFromString rdx: a parser from strings of radix-rdx digits to nums.

   The string is consumed in chunks of at most chunksize characters, each
   small enough to be read as an ordinary int.  Before each new chunk is
   added, the accumulated num is multiplied by chunkshift = base^chunksize.

   Only non-empty strings made entirely of digits valid for rdx are
   accepted; anything else raises Fail ("String not " ^ s_rdx).  Chunks are
   checked explicitly because Int.scan on its own would tolerate leading
   whitespace, a sign, or a "0x" prefix inside a chunk.  Without the check,
   "12345 6789" would be read as 1234506789, and "12345~6789" would make
   fromInt raise Overflow. *)
fun genFromString rdx = let
  open StringCvt
  val (base, chunksize, chunkshift, s_rdx) =
      case rdx of
        BIN => (2, 10, fromInt 1024, "BIN")
      | OCT => (8, 5, fromInt 32768, "OCT")
      | DEC => (10, 5, fromInt 100000, "DEC")
      | HEX => (16, 5, fromInt 1048576, "HEX")
  (* The characters that are digits in radix rdx. *)
  val is_rdx_digit : char -> bool =
      case rdx of
        BIN => Char.contains "01"
      | OCT => Char.contains "01234567"
      | DEC => Char.isDigit
      | HEX => Char.isHexDigit
  (* readchunk s: SOME of the value of s when it is a non-empty string of
     radix-rdx digits, NONE otherwise. *)
  fun readchunk s =
      if List.all is_rdx_digit (explode s) then
        (case Int.scan rdx Substring.getc (Substring.full s) of
           SOME (i, r) => if Substring.size r = 0 then SOME i else NONE
         | NONE => NONE)
      else NONE
  fun recurse acc s = let
    val sz = size s
  in
    if Int.<=(sz, chunksize) then
      fromInt (intexp(base,sz)) * acc + fromInt (valOf (readchunk s))
    else let
        val pfx = substring(s, 0, chunksize)
        val sfx = String.extract(s, chunksize, NONE)
      in
        recurse (chunkshift * acc + fromInt (valOf (readchunk pfx))) sfx
      end
  end
in
  fn s => recurse zero s handle Option => raise Fail ("String not " ^ s_rdx)
end
local
  open StringCvt
in
  (* Parsers for digit strings in each supported radix. *)
  val fromHexString = genFromString HEX
  val fromOctString = genFromString OCT
  val fromBinString = genFromString BIN
  val fromString = genFromString DEC
end
444
445
(* toString n: the decimal representation of n, without leading zeros.

   BASE is a power of ten (BASEless1 consists only of nines), so every digit
   below the most significant one corresponds to exactly
   size (Int.toString BASEless1) decimal characters.  The digit list can
   therefore be printed directly, in time linear in its length.  If BASE is
   ever changed to something that is not a power of ten, the code falls back
   to peeling off one decimal digit at a time with divmod. *)
fun toString n = let
  val digits = normalise n
  val top = Int.toString BASEless1
  val width = size top
  val base_is_power_of_ten = List.all (fn c => c = #"9") (explode top)
  (* A non-leading digit, zero-padded to its full decimal width. *)
  fun pad d = StringCvt.padLeft #"0" width (Int.toString d)
  (* General fallback, valid for any BASE. *)
  fun via_divmod m =
    if m = zero then ""
    else let
      val (q, r) = divmod (m, fromInt 10)
    in
      via_divmod q ^ Int.toString (toInt r)
    end
in
  if digits = zero then "0"
  else if not base_is_power_of_ten then via_divmod digits
  else
    case List.rev digits of
      msd :: rest => String.concat (Int.toString msd :: List.map pad rest)
    | [] => "0"  (* unreachable: normalise never returns the empty list *)
end
459
460
local
  (* The character for a digit value 0-15, as a one-character string. *)
  fun toChar n =
      if Int.<(n, 10) then str (chr (Int.+(ord #"0", n)))
      else str (chr (Int.+(ord #"A", Int.-(n, 10))))
  (* toBaseString base n: n written in radix base, most significant digit
     first, accumulated from the least significant end. *)
  fun toBaseString base n = let
    fun build m acc = let
      val (q, r) = divmod (m, base)
      val acc' = toChar (toInt r) ^ acc
    in
      if q = zero then acc' else build q acc'
    end
  in
    build n ""
  end
in
  val toBinString = toBaseString (fromInt 2)
  val toOctString = toBaseString (fromInt 8)
  val toHexString = toBaseString (fromInt 16)
end
477
(* isqrt n: the integer square root of n.  Values below two are returned
   unchanged; otherwise Newton's iteration starts from one and stops at the
   a with a * a <= n < (a + 1) * (a + 1). *)
fun isqrt n =
  if n < two then n
  else let
    fun sq a = a * a
    fun settled a = sq a <= n andalso n < sq (a + one)
    fun newton a = div2 ((sq a + n) div a)
    fun iter a = if settled a then a else iter (newton a)
  in
    iter one
  end
489
490(*  useful test code follows
491exception ArgsBad;
492
493
494fun test_op (nf, origf, P, print_opn, testresult) arg1 arg2 = let
495  val _ = P (arg1, arg2) orelse raise ArgsBad
496  val orig_result = origf(arg1, arg2)
497  val new_result = nf(fromInt arg1, fromInt arg2)
498  val ok = testresult(new_result, orig_result)
499in
500  print (Int.toString arg1^print_opn^Int.toString arg2);
501  if ok then print " agree\n"
502  else (print " disagree\n"; raise Fail "Urk")
503end
504
505fun test n opdetails = let
506  open Random
507  val gen = newgen()
508  fun do_test () = let
509    val arg1 = range (0, 60000000) gen
510    val arg2 = range (0, 60000000) gen
511  in
512    test_op opdetails arg1 arg2
513  end
514  fun doit_until_success f = f ()
515    handle Fail s => raise Fail s
516         | Interrupt => raise Interrupt
517         | _ => doit_until_success f
518  fun doit f n = if Int.<=(n, 0) then () else (doit_until_success f ;
519                                               doit f (Int.-(n,1)))
520in
521  doit do_test n
522end
523
524fun testintresult (new, old) = toInt new = old
525val test_addition = (op+, Int.+, (fn _ => true), " + ", testintresult);
526val test_less = (op<, Int.<, (fn _ => true), " < ", op=);
527val test_leq = (op<=, Int.<=, (fn _ => true), " <= ", op=);
528val test_subtraction = (op-, Int.-, (fn (x,y) => Int.>=(x,y)), " - ", testintresult)
529val test_mult = (op*, Int.*, (fn _ => true), " * ", testintresult)
530val test_div = (op div, Int.div, (fn(x,y) => y <> 0), " div ", testintresult);
531val test_mod = (op mod, Int.mod, (fn(x,y) => y <> 0), " mod ", testintresult);
532
533val _ = test_op test_addition 78826 3251
534val _ = test_op test_div 49772146 458048
535val _ = test_op test_mod 34182186 2499
536val _ = test_op test_mod 26708509 29912224
537val _ = test_op test_div 6258 42171
538val _ = test_op test_div 6 13766
539val _ = test_op test_mod 6 13766
540val _ = test_op test_mod 38294758 10769
541val _ = test_op test_div 38294758 10769
542
543val _ = test 200 test_addition
544val _ = test 200 test_less
545val _ = test 200 test_leq
546val _ = test 200 test_subtraction
547(* val _ = test 30 test_mult *)
548val _ = test 200 test_div
549val _ = test 200 test_mod
550
551exception FailedProp
552fun testproperty3 (f, printprop) = let
553  val gen = Random.newgen()
554  fun generate_arg () = let
555    val size = Random.range(1,8) gen
556  in
557    normalise (Random.rangelist(0,BASE) (size, gen))
558  end
559  val x = generate_arg()
560  val y = generate_arg()
561  val z = generate_arg()
562  val propstring =
563    "Property "^printprop^" for x = "^toString x^", y = "^toString y^
564    ", z = "^toString z^"..."
565in
566  print propstring;
567  if f(x,y,z) then print "OK\n"
568  else (print "FAILED\n"; raise FailedProp)
569end
570
571fun testprop n prop = if Int.<=(n, 0) then ()
572                      else (testproperty3 prop; testprop (Int.-(n,1)) prop) handle FailedProp => Process.exit Process.failure
573
574val addition_associative =
575  ((fn (x,y,z) => (x + (y + z) = (x + y) + z)),
576   "(x + (y + z) = (x + y) + z)")
577val mult_assoc =
578  ((fn (x,y,z) => x * (y * z) = (x * y) * z),
579   "x * (y * z) = (x * y) * z")
580val distrib =
581  ((fn (x,y,z) => x * (y + z) = (x * y) + (x * z)),
582   "x * (y + z) = (x * y) + (x * z)")
583val divmod_test =
584  ((fn (x,y,z) => if y <> zero then (x div y) * y + (x mod y) = x else true),
585   "(x div y) * y + (x mod y) = x")
586
587val xn = fromString "260309023368"
588val yn = fromString "76734110"
589val _ =
590  if #1 divmod_test (xn, yn, zero) then print "OK\n"
591  else (print "FAILED\n"; Process.exit Process.failure)
592
593val _ = testprop 100 addition_associative
594val _ = testprop 100 mult_assoc
595val _ = testprop 100 distrib
596val _ = testprop 100 divmod_test
597*)
598
599end
600