1(* Simple-minded implementation of arbitrary precision natural numbers *)
2(* Copyright (c) Michael Norrish *)
3
4structure mlibArbnum :> mlibArbnum =
5struct
6
(* extract (arr, start, len): a fresh vector holding arr[start ..], limited
   to len elements when len = SOME k.  Presumably a stand-in for the old
   Basis Array.extract -- NOTE(review): confirm no caller relies on more. *)
fun extract (arr, start, len) =
  ArraySlice.vector (ArraySlice.slice (arr, start, len))
8
(* copyVec' {di, dst, len, src, si}: copy the part of vector src starting at
   index si (len elements, or through the end when len = NONE) into array
   dst beginning at index di. *)
fun copyVec' {di, dst, len, src, si} =
  Array.copyVec {di = di, dst = dst,
                 src = VectorSlice.vector (VectorSlice.slice (src, si, len))}
14
(* BASE is the radix of the digit representation: every list element is a
   digit in 0 .. BASEless1.  base must be <= the sqrt of MaxInt -- and in
   practice somewhat below it.  NOTE(review): the quotient-estimate test in
   divmod computes (u_j * BASE + u_(j+1) - qhat * v1) * BASE + u_(j+2),
   which can reach roughly 2 * BASE * BASE, so leave headroom when
   changing BASE. *)
val BASE = 10000;
val BASEless1 = BASE - 1;
18
(* A num is a list of base-BASE digits, least significant digit first. *)
type num = int list

(* asList exposes the underlying digit list unchanged. *)
val asList = fn digits => digits

(* Representation: each element lies in 0 .. BASE - 1 and the least
   significant "digit" is the head of the list.  Most operations also
   expect no zero digits above the most significant non-zero one (see
   normalise).  The constants below are single-digit nums. *)
val zero = [0]
val one = [1]
val two = [2]
29
(* plus1 n returns n + 1.  A carry ripples upward through digits equal to
   BASEless1 and appends a new top digit 1 when it runs off the end. *)
fun plus1 [] = raise Fail "Should never happen"
  | plus1 (d :: rest) =
      if d <> BASEless1 then (d + 1) :: rest
      else (case rest of
              [] => [0, 1]
            | _ => 0 :: plus1 rest)
33
(* less1 n returns n - 1 for a non-zero num n.
   Raises Fail "Can't take one off zero" when n is zero and
   Fail "arbnum invariant falsified" on the empty list.
   The result stays normalised: borrowing out of a top digit of 1
   (e.g. 10000 = [0, 1]) removes that digit instead of leaving a high
   zero behind, which would otherwise confuse =, < and toString. *)
fun less1 [] = raise Fail "arbnum invariant falsified"
  | less1 [x] = if x = 0 then raise Fail "Can't take one off zero"
                else [x - 1]
  | less1 (0::xs) =
      (* borrow from the higher digits; if they collapse to zero the
         answer is the single digit BASEless1 *)
      (case less1 xs of
         [0] => [BASEless1]
       | higher => BASEless1 :: higher)
  | less1 (x::xs) = (x - 1)::xs
(* less2 n = n - 2 *)
val less2 = less1 o less1
40
(* plus2 n = n + 2 *)
fun plus2 n = plus1 (plus1 n)
42
(* times2 n doubles a num digit by digit.  A doubled digit that reaches
   BASE keeps its low part and pushes a carry of one into the rest. *)
fun times2 [] = []
  | times2 (d :: ds) = let
      val doubled = Int.* (2, d)
    in
      if Int.< (doubled, BASE) then doubled :: times2 ds
      else if null ds then [Int.- (doubled, BASE), 1]
      else Int.- (doubled, BASE) :: plus1 (times2 ds)
    end
48
(* revdiv2 ds halves a digit list given most significant digit FIRST (the
   reverse of the num convention).  The result is in the same order and may
   start with a zero digit.  Each odd digit passes BASE down to the next. *)
fun revdiv2 [] = []
  | revdiv2 [d] = [Int.div (d, 2)]
  | revdiv2 (hi :: lo :: rest) =
      Int.div (hi, 2)
      :: revdiv2 (Int.+ (Int.* (Int.mod (hi, 2), BASE), lo) :: rest)
57
(* div2 n = n div 2.  The halving runs most significant digit first; the
   leading zeros that creates are dropped, keeping at least one digit. *)
fun div2 n = let
  fun dropHighZeros (0 :: (rest as _ :: _)) = dropHighZeros rest
    | dropHighZeros ds = ds
in
  List.rev (dropHighZeros (revdiv2 (List.rev n)))
end
67
(* mod2 n = n mod 2, determined by the least significant digit alone
   (BASE is even). *)
fun mod2 [] = raise Fail "arbnum representation invariant violated"
  | mod2 (low :: _) = [Int.mod (low, 2)]
70
(* fromInt n converts a non-negative int into a num by peeling off base-BASE
   digits, least significant first.  Raises Fail for negative n. *)
fun fromInt n =
  if Int.< (n, 0) then raise Fail "nums only work with positive numbers"
  else let
    val low = Int.mod (n, BASE)
    val high = Int.div (n, BASE)
  in
    low :: (if Int.> (high, 0) then fromInt high else [])
  end
78
(* toInt n converts a num back into an int by Horner evaluation from the
   most significant digit down.  May raise Overflow when the value does not
   fit in int. *)
fun toInt digits =
  List.foldr (fn (d, acc) => Int.+ (d, Int.* (BASE, acc))) 0 digits
81
82
(* Addition of nums, xn + yn, on little-endian digit lists.
   (78826 + 3251 = [8826, 7] + [3251] yields [2077, 8] = 82077.) *)
fun (xn + yn) = addwc xn yn false
(* addwc xn yn b: add with carry, computing xn + yn + (1 if b).  Once one
   list runs out, any pending carry is folded into the other with plus1. *)
and addwc [] yn b = if b then (if null yn then one else plus1 yn) else yn
  | addwc xn [] b = if b then plus1 xn else xn
  | addwc (x :: xs) (y :: ys) b = let
      val sum = Int.+ (Int.+ (x, y), if b then 1 else 0)
    in
      if Int.>= (sum, BASE) then Int.- (sum, BASE) :: addwc xs ys true
      else sum :: addwc xs ys false
    end
109
(* Ordering on nums.  A shorter digit list is the smaller number, and lists
   of equal length are ordered by their most significant differing digit.
   This is only the numeric order when both arguments are normalised (no
   zero digits above the most significant non-zero one).
   Single linear pass from the least significant digit, remembering the
   verdict of the highest differing digit seen so far. *)
fun (xn < yn) = let
  fun cmp ([], [], ord) = ord
    | cmp ([], _ :: _, _) = LESS
    | cmp (_ :: _, [], _) = GREATER
    | cmp (x :: xs, y :: ys, ord) =
        cmp (xs, ys, case Int.compare (x, y) of EQUAL => ord | diff => diff)
in
  cmp (xn, yn, EQUAL) = LESS
end
fun (xn <= yn) = xn = yn orelse xn < yn
fun (xn >= yn) = yn <= xn
fun (xn > yn) = yn < xn
121
(* normalise n drops zero digits above the most significant non-zero digit.
   An empty or all-zero list becomes [0]. *)
fun normalise digits = let
  fun dropZeros (0 :: rest) = dropZeros rest
    | dropZeros other = other
in
  case List.rev (dropZeros (List.rev digits)) of
    [] => [0]
  | trimmed => trimmed
end
131
(* Truncated subtraction: xn - yn when xn >= yn, otherwise zero. *)
fun (xn - yn) =
  if xn < yn then zero
  else normalise (subwc xn yn false)
(* subwc xn yn b: digit-wise xn - yn - (1 if b), propagating a borrow
   upward.  Expects xn >= yn; the result may still carry high zeros. *)
and subwc xn [] b = if b then less1 xn else xn
  | subwc [] _ _ = zero
  | subwc (x :: xs) (y :: ys) b = let
      val diff = Int.- (Int.- (x, y), if b then 1 else 0)
    in
      if Int.< (diff, 0) then Int.+ (diff, BASE) :: subwc xs ys true
      else diff :: subwc xs ys false
    end
149
(* single_digit (n, xn) multiplies the num xn by the int n, using
   (x0 + BASE * x) * n = x0 * n + BASE * (x * n).  0, 1 and 2 take
   shortcuts; otherwise each digit product is split into a low digit and a
   carry, and the carry is added (as a num) into the scaled higher digits. *)
fun single_digit (n, xn) = let
  fun scale [] = []
    | scale (x :: xs) = let
        val prod = Int.* (n, x)
        val lo = Int.mod (prod, BASE)
        val hi = Int.div (prod, BASE)
        val rest = scale xs
      in
        lo :: (if hi = 0 then rest else rest + [hi])
      end
in
  case n of
    0 => zero
  | 1 => xn
  | 2 => times2 xn
  | _ => scale xn
end
168
169
(* Schoolbook multiplication: with xn = x + BASE * xs,
   xn * yn = x * yn + BASE * (xs * yn); the BASE shift is a consed zero. *)
fun (xn * yn) =
  case xn of
    [] => zero
  | x :: xs =>
      (case yn of
         [] => zero
       | _ => normalise (single_digit (x, yn) + (0 :: (xs * yn))))
175
(* replicate n el: a list of n copies of el (negative n raises Size, as
   List.tabulate does). *)
fun replicate n el = let
  fun constant _ = el
in
  List.tabulate (n, constant)
end
(* comp_sub acc carry (xn, yn): subtract the equal-length digit lists
   (least significant first) yn from xn, propagating a borrow upward.
   Result digits are pushed onto acc, so they come back MOST significant
   first; the final borrow is returned alongside (true means the true
   difference was negative).  Raises Fail on lists of different length. *)
fun comp_sub acc carry (xn, yn) =
  case (xn, yn) of
    (x :: xs, y :: ys) => let
      val diff = Int.- (Int.- (x, y), if carry then 1 else 0)
      val borrowed = Int.< (diff, 0)
      val digit = if borrowed then Int.+ (diff, BASE) else diff
    in
      comp_sub (digit :: acc) borrowed (xs, ys)
    end
  | ([], []) => (acc, carry)
  | _ => raise Fail "comp_sub : arguments of different length"
197
198
(* single_divmod xn y divides the num xn by the ordinary int y (expected to
   satisfy y < BASE) and returns (quotient, remainder) as normalised nums.
   Raises Div when y = 0. *)
fun single_divmod (xn : int list) (y : int) = let
  (* Short division over the digits, most significant first.  Quotient
     digits are pushed onto qacc, so it ends least significant first.  Only
     the very first chunk can produce a quotient >= BASE; fromInt splits it
     into digits in the right order. *)
  fun step [] _ = raise Fail "single_divmod: can't happen"
    | step [x] qacc = (Int.div (x, y) :: qacc, [Int.mod (x, y)])
    | step (hi :: lo :: rest) qacc = let
        val chunk = Int.+ (Int.* (hi, BASE), lo)
        val qacc' = fromInt (Int.div (chunk, y)) @ qacc
        val r = Int.mod (chunk, y)
      in
        case rest of
          [] => (qacc', [r])
        | _ => step (r :: rest) qacc'
      end
in
  if y = 0 then raise Div
  else let
    val (q, r) = step (List.rev xn) []
  in
    (normalise q, normalise r)
  end
end
229
230
(* divmod (xn, yn) returns the pair (quotient, remainder) of dividing the
   num xn by the num yn.  Raises Div when yn is zero.  A one-digit divisor
   is handled by single_divmod, a divisor with more digits than xn yields
   (zero, xn), and everything else goes through KnuthD below. *)
fun divmod (xn, yn) =
  if yn = zero then raise Div
  else let

    (* following algorithm from Knuth 4.3.1 Algorithm D
       (The Art of Computer Programming, vol. 2) *)
    (* we require
       - that length vn > 1
       - that length un >= length vn
       - that List.last vn <> 0, i.e. the most significant digit of the
         divisor is non-zero (hd vn is the LEAST significant digit)
       The dispatch at the end of divmod establishes the first two for a
       non-empty yn; the third holds when yn is normalised.
       NOTE(review): the third is not checked here.
     *)
    fun KnuthD un vn = let
      (* Knuth's algorithm is stateful so we mimic it here with
       two arrays: u (the scaled dividend, most significant digit first,
       overwritten in place) and q (the quotient digits, most significant
       first) *)
      (* D1. Normalise: scale dividend and divisor by d *)
      val d = Int.div(BASE, Int.+(List.last vn, 1))
      val n = length vn
      val m = Int.-(length un, length vn)
      val normalised_u = single_digit(d, un)
      val normalised_v = single_digit(d, vn)
      (* scaling must not lengthen the divisor *)
      val _ = length normalised_v = length vn orelse
        raise Fail "normalised_v not same length as v"
      val norm_v_rev = List.rev normalised_v
      (* v1, v2: the two most significant digits of the scaled divisor *)
      val v1 = hd norm_v_rev
      val v2 = hd (tl norm_v_rev)
      (* u must hold m + n + 1 digits: when scaling did not add a digit,
         prepend an explicit leading zero *)
      val norm_u_rev =
        if (length normalised_u = length un) then 0::List.rev normalised_u
        else List.rev normalised_u
      val u = Array.fromList norm_u_rev
      val _ = Array.length u = Int.+(m, Int.+(n, 1)) orelse
        raise Fail "Array u of unexpected length"
      val q = Array.array(Int.+(m, 1), 0)
      (* inner_loop j: steps D3-D6 for quotient digit j (0 <= j <= m),
         working on the window u[j .. j + n], then D7: recurse on j + 1
         until j = m *)
      fun inner_loop j = let
        infix 9 sub
        open Array
        (* D3. Calculate q hat *)
        (* qhat_test qhat: Knuth's check that the estimate is too large,
           using the top three window digits and the top two divisor
           digits *)
        fun qhat_test qhat = let
          open Int
        in
          v2 * qhat >
          (u sub j * BASE + u sub (j + 1) - qhat * v1) * BASE + u sub (j + 2)
        end
        (* initial estimate from the top two window digits and v1 *)
        val qhat0 = let
          open Int
        in
          if u sub j = v1 then BASEless1
          else ((u sub j) * BASE + (u sub (j + 1)))  div v1
        end
        (* the test is applied at most twice, so the estimate drops by at
           most 2 *)
        val qhat =
          if qhat_test qhat0 then
            if qhat_test (Int.-(qhat0, 1)) then Int.-(qhat0, 2)
            else Int.-(qhat0, 1)
          else
            qhat0
        (* D4. multiply and subtract *)
        (* the window u[j .. j + n] as a least-significant-first list *)
        val uslice_v = extract(u, j, SOME (Int.+(n, 1)))
        val uslice_l = List.rev (Vector.foldr (op::) [] uslice_v)
        val multiply_result0 = single_digit(qhat, normalised_v)
        (* pad qhat * v with high zero digits to the window length, since
           comp_sub requires equal-length arguments *)
        val multiply_result = let
          val mr_len = List.length multiply_result0
          val u_len = List.length uslice_l
        in
          if  mr_len <> u_len then
            multiply_result0 @ replicate (Int.-(u_len, mr_len)) 0
          else
            multiply_result0
        end
        (* newu_l comes back most significant first, ready to be copied
           over u[j ..]; d4carry is the final borrow, set when the
           difference went negative *)
        val (newu_l, d4carry) = comp_sub [] false (uslice_l, multiply_result)
        val newu_v = Vector.fromList newu_l
        val _ = copyVec' {di=j, dst=u, len=NONE, src=newu_v, si=0}
        (* D5. test remainder *)
        val () = update(q, j, qhat)
        val _ =
          if d4carry then let
            (* D6. Add back: qhat was one too large, so add the divisor
               back into the window and decrement q[j] *)
            val uslice_v = extract(u, j, SOME (Int.+(n, 1)))
            val uslice_l = List.rev(Vector.foldr (op::) [] uslice_v)
            val newu0 = uslice_l + normalised_v
            (* have to ignore the carry out of the top digit: after
               List.rev the list is most significant first, so drop its
               first element *)
            val newu = List.drop(List.rev newu0, 1);
            val newu_v = Vector.fromList newu
          in
            update(q, j, Int.-(q sub j, 1));
            copyVec' {di = j, dst = u, len = NONE, src = newu_v, si = 0}
          end
          else ()
        (* integer arithmetic for the loop test below *)
        open Int
      in
        (* D7. loop on j *)
        if j + 1 <= m then inner_loop (j + 1)
        else ()
      end
      (* runs D3-D7 for j = 0 .. m; evaluates to () *)
      val unnormal_result = inner_loop 0
      (* q is most significant first; reverse into the num convention *)
      val qn = normalise (List.rev (Array.foldr (op::) [] q))
      (* the remainder is the low n digits of u (indices m + 1 ..),
         still scaled by d *)
      val rn0 = let
        open Int
      in
        normalise (List.rev
                   (Vector.foldr (op::) [] (extract(u, m + 1, NONE))))
      end
      (* D8. Unnormalise: undo the scaling by d *)
      val rn = #1 (single_divmod rn0 d)
    in
      (qn, rn)
    end
  in
    if length yn = 1 then
      single_divmod xn (hd yn)
    else
      if Int.>(length yn, length xn) then
        (zero, xn)
      else
        KnuthD xn yn
  end
341
(* Quotient and remainder of num division; both raise Div on a zero
   divisor (via divmod). *)
fun (xn div yn) = let val (quot, _) = divmod (xn, yn) in quot end
fun (xn mod yn) = let val (_, r) = divmod (xn, yn) in r end
344
(* fromSubstring s parses a non-empty run of decimal digits into a num.
   Raises Fail "String not numeric" if s is empty or contains any character
   other than '0' .. '9' -- including signs and whitespace, which
   Int.fromString would otherwise silently skip or stop at, producing
   wrong values for inputs like "1234 5678".
   Digits are consumed four decimal places at a time: the last four
   characters form one chunk and the rest is parsed recursively and scaled
   by 10^4 (a decimal factor, independent of BASE). *)
fun fromSubstring s = let
  (* every chunk is at most four already-validated digits, so valOf
     cannot fail *)
  fun chunk ss = fromInt (valOf (Int.fromString (Substring.string ss)))
  fun parse ss = let
    val sz = Substring.size ss
  in
    if Int.< (sz, 5) then chunk ss
    else let
      val (pfx, sfx) = Substring.splitAt (ss, Int.- (sz, 4))
    in
      fromInt 10000 * parse pfx + chunk sfx
    end
  end
  val all_digits = Substring.isEmpty (Substring.dropl Char.isDigit s)
in
  if Substring.isEmpty s orelse not all_digits then
    raise Fail "String not numeric"
  else parse s
end
358
(* fromString s parses the whole string s as a decimal num. *)
val fromString = fromSubstring o Substring.full
360
(* toString n renders a num in decimal.  Decimal digits are produced least
   significant first by repeated division by ten, collected onto the front
   of a list, and concatenated at the end. *)
fun toString n =
  if n = zero then "0"
  else let
    val ten = fromInt 10
    fun digits (m, acc) =
      if m = zero then String.concat acc
      else let
        val (q, r) = divmod (m, ten)
      in
        digits (q, Int.toString (toInt r) :: acc)
      end
  in
    digits (n, [])
  end
374
375(*  useful test code follows
376exception ArgsBad;
377
378
379fun test_op (nf, origf, P, print_opn, testresult) arg1 arg2 = let
380  val _ = P (arg1, arg2) orelse raise ArgsBad
381  val orig_result = origf(arg1, arg2)
382  val new_result = nf(fromInt arg1, fromInt arg2)
383  val ok = testresult(new_result, orig_result)
384in
385  print (Int.toString arg1^print_opn^Int.toString arg2);
386  if ok then print " agree\n"
387  else (print " disagree\n"; raise Fail "Urk")
388end
389
390fun test n opdetails = let
391  open Random
392  val gen = newgen()
393  fun do_test () = let
394    val arg1 = range (0, 60000000) gen
395    val arg2 = range (0, 60000000) gen
396  in
397    test_op opdetails arg1 arg2
398  end
399  fun doit_until_success f = f ()
400    handle Fail s => raise Fail s
401         | Interrupt => raise Interrupt
402         | _ => doit_until_success f
403  fun doit f n = if Int.<=(n, 0) then () else (doit_until_success f ;
404                                               doit f (Int.-(n,1)))
405in
406  doit do_test n
407end
408
409fun testintresult (new, old) = toInt new = old
410val test_addition = (op+, Int.+, (fn _ => true), " + ", testintresult);
411val test_less = (op<, Int.<, (fn _ => true), " < ", op=);
412val test_leq = (op<=, Int.<=, (fn _ => true), " <= ", op=);
413val test_subtraction = (op-, Int.-, (fn (x,y) => Int.>=(x,y)), " - ", testintresult)
414val test_mult = (op*, Int.*, (fn _ => true), " * ", testintresult)
415val test_div = (op div, Int.div, (fn(x,y) => y <> 0), " div ", testintresult);
416val test_mod = (op mod, Int.mod, (fn(x,y) => y <> 0), " mod ", testintresult);
417
418val _ = test_op test_addition 78826 3251
419val _ = test_op test_div 49772146 458048
420val _ = test_op test_mod 34182186 2499
421val _ = test_op test_mod 26708509 29912224
422val _ = test_op test_div 6258 42171
423val _ = test_op test_div 6 13766
424val _ = test_op test_mod 6 13766
425val _ = test_op test_mod 38294758 10769
426val _ = test_op test_div 38294758 10769
427
428val _ = test 200 test_addition
429val _ = test 200 test_less
430val _ = test 200 test_leq
431val _ = test 200 test_subtraction
432(* val _ = test 30 test_mult *)
433val _ = test 200 test_div
434val _ = test 200 test_mod
435
436exception FailedProp
437fun testproperty3 (f, printprop) = let
438  val gen = Random.newgen()
439  fun generate_arg () = let
440    val size = Random.range(1,8) gen
441  in
442    normalise (Random.rangelist(0,BASE) (size, gen))
443  end
444  val x = generate_arg()
445  val y = generate_arg()
446  val z = generate_arg()
447  val propstring =
448    "Property "^printprop^" for x = "^toString x^", y = "^toString y^
449    ", z = "^toString z^"..."
450in
451  print propstring;
452  if f(x,y,z) then print "OK\n"
453  else (print "FAILED\n"; raise FailedProp)
454end
455
456fun testprop n prop = if Int.<=(n, 0) then ()
457                      else (testproperty3 prop; testprop (Int.-(n,1)) prop) handle FailedProp => Process.exit Process.failure
458
459val addition_associative =
460  ((fn (x,y,z) => (x + (y + z) = (x + y) + z)),
461   "(x + (y + z) = (x + y) + z)")
462val mult_assoc =
463  ((fn (x,y,z) => x * (y * z) = (x * y) * z),
464   "x * (y * z) = (x * y) * z")
465val distrib =
466  ((fn (x,y,z) => x * (y + z) = (x * y) + (x * z)),
467   "x * (y + z) = (x * y) + (x * z)")
468val divmod_test =
469  ((fn (x,y,z) => if y <> zero then (x div y) * y + (x mod y) = x else true),
470   "(x div y) * y + (x mod y) = x")
471
472val xn = fromString "260309023368"
473val yn = fromString "76734110"
474val _ =
475  if #1 divmod_test (xn, yn, zero) then print "OK\n"
476  else (print "FAILED\n"; Process.exit Process.failure)
477
478val _ = testprop 100 addition_associative
479val _ = testprop 100 mult_assoc
480val _ = testprop 100 distrib
481val _ = testprop 100 divmod_test
482*)
483
484end
485