/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008,
2009, 2010 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal
__GMP_PROTO ((mp_ptr, mp_size_t, int, mp_ptr *, mp_ptr *,
	      mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr, int));
static void mpn_mul_fft_decompose
__GMP_PROTO ((mp_ptr, mp_ptr *, int, int, mp_srcptr, mp_size_t, int, int, mp_ptr));


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files are not updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE		/* When tuning, this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  FFT_TABLE_ATTRS struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
	break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/

/* Returns the smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. the smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
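
/* Illustrative usage sketch (not part of the original source); the values
   below are worked by hand from the rounding rule above.  */
#if 0
static void
mpn_fft_next_size_example (void)	/* hypothetical helper, never compiled */
{
  /* smallest multiple of 2^4 = 16 that is >= 1000 is 63 * 16 = 1008 */
  ASSERT_ALWAYS (mpn_fft_next_size (1000, 4) == 1008);
  /* 1024 is already a multiple of 16, so it is returned unchanged */
  ASSERT_ALWAYS (mpn_fft_next_size (1024, 4) == 1024);
}
#endif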


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
	{
	  li[j] = 2 * l[i - 1][j];
	  li[K + j] = 1 + li[j];
	}
    }
}
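
/* For illustration (derived by hand from the recurrence above): with
   k = 3 the tables are l[1] = {0,1}, l[2] = {0,2,1,3} and
   l[3] = {0,4,2,6,1,5,3,7}, i.e. l[3][j] is j with its three low bits
   reversed.  */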


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
{
  int sh;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  d /= GMP_NUMB_BITS;

  if (d >= n)			/* negate */
    {
      /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1],  sh) */

      d -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - d, d + 1, sh);
	  rd = r[d];
	  cc = mpn_lshiftc (r + d, a, n - d, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - d, d);
	  rd = a[n];
	  mpn_com (r + d, a, n - d);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[d] */

      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + d + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
	 r[d..n-1]  <-- lshift(a[0]..a[n-d-1],  sh)  */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  mpn_lshiftc (r, a + n - d, d + 1, sh);
	  rd = ~r[d];
	  /* {r, d+1} = {a+n-d, d+1} << sh */
	  cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
	}
      else
	{
	  /* r[d] is not used below, but we save a test for d=0 */
	  mpn_com (r, a + n - d, d + 1);
	  rd = a[n];
	  MPN_COPY (r + d, a, n - d);
	  cc = 0;
	}

      /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */

      /* if d=0 we just have r[0]=a[n] << sh */
      if (d != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[d] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, d, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[d..n] */

      r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
      r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
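
/* Note added for clarity (not in the original source): with
   N = n*GMP_NUMB_BITS we have 2^N = -1 mod 2^N+1, so for d >= N,
   a*2^d = -(a*2^(d-N)) mod 2^N+1: limbs that wrap past position n
   re-enter negated, which is what the d >= n branch above implements.
   Toy check pretending a limb had 4 bits: mod 2^4+1 = 17,
   3*2^3 = 24 = 1*16 + 8 = 8 - 1 = 7 mod 17.  */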


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
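
/* Case check of the branchless reduction above (illustration only):
   c = a[n] + b[n] + carry lies in [0, 3], and c*2^N = 2^N - (c-1)
   mod 2^N+1 for c >= 1:
     c = 0: x = 0, r[n] = 0, nothing subtracted;
     c = 1: x = 0, r[n] = 1, nothing subtracted;
     c = 2: x = 1, r[n] = 1, MPN_DECR_U subtracts 1;
     c = 3: x = 2, r[n] = 1, MPN_DECR_U subtracts 2.  */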

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
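
/* Corresponding check for the subtraction (illustration only): here
   c = a[n] - b[n] - borrow lies in [-2, 1].  For c >= 0, x = 0 and
   r[n] = c.  For c < 0 the high bit of c (as a limb) is set, so
   x = -c (1 or 2), r[n] = 0, and MPN_INCR_U adds -c, which is correct
   since c*2^N = -c mod 2^N+1.  */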

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      int j;
      int *lk = *ll;

      mpn_fft_fft (Ap,     K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
	}
    }
}


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thereby make sure this code is correct.  */
	  MPN_ZERO (ap, n);
	  ap[n] = 1;
	}
      else
	ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int **fft_l;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK */
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
	{
	  unsigned long K3;
	  for (;;)
	    {
	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
	      if ((nprime2 & (K3 - 1)) == 0)
		break;
	      nprime2 = (nprime2 + K3 - 1) & -K3;
	      Nprime2 = nprime2 * GMP_LIMB_BITS;
	      /* warning: since nprime2 changed, K3 may change too! */
	    }
	}
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_ALLOC_MP_PTRS (K2);
      Bp = TMP_ALLOC_MP_PTRS (K2);
      A = TMP_ALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_ALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_ALLOC_TYPE (k + 1, int *);
      for (i = 0; i <= k; i++)
	fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
	{
	  mp_limb_t cy;
	  mpn_fft_normalize (*ap, n);
	  if (!sqr)
	    mpn_fft_normalize (*bp, n);

	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
	  if (!sqr)
	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
				     l, Mp2, fft_l, T, sqr);
	  (*ap)[n] = cy;
	}
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      int n2 = 2 * n;
      tp = TMP_ALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
	{
	  a = *ap++;
	  b = *bp++;
	  if (sqr)
	    mpn_sqr (tp, a, n);
	  else
	    mpn_mul_n (tp, b, a, n);
	  if (a[n] != 0)
	    cc = mpn_add_n (tpn, tpn, b, n);
	  else
	    cc = 0;
	  if (b[n] != 0)
	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
	  if (cc != 0)
	    {
	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
	      cc = mpn_add_1 (tp, tp, n2, cc);
	      ASSERT (cc == 0);
	    }
	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
	}
    }
  TMP_FREE;
}
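
/* Editorial note on the base case above: the 2n-limb product is
   {tp,2n} = L + H*2^N with L = {tp,n}, H = {tpn,n}, N = n*GMP_NUMB_BITS.
   Since 2^N = -1 mod 2^N+1, the residue is L - H; when mpn_sub_n
   borrows, adding back the modulus 2^N+1 is done by mpn_add_1 plus
   setting a[n] on its carry out, as in the last line of the loop.  */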


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      int j, K2 = K >> 1;

      mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
	}
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
{
  int i;

  ASSERT (r != a);
  i = 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   in which case {rp,n}=0.
*/
static int
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l;
  long int m;
  mp_limb_t cc;
  int rpn;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m}, propagating the carry */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
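
/* Editorial note: with N = n*GMP_NUMB_BITS, write {ap,an} as
   A0 + A1*2^N + A2*2^(2N) where A0 and A1 have n limbs and A2 has
   an-2n limbs (empty when an <= 2n).  Since 2^N = -1 mod 2^N+1 the
   residue is A0 - A1 + A2, which is exactly what the additions and
   subtractions above compute.  */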

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
		       mp_size_t nl, int l, int Mp, mp_ptr T)
{
  int i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_ALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
	{
	  int subp = 0;

	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
	  n += 2 * Kl;
	  dif -= Kl;

	  /* now dif > 0 */
	  while (dif > Kl)
	    {
	      if (subp)
		cy += mpn_sub_n (tmp, tmp, n, Kl);
	      else
		cy -= mpn_add_n (tmp, tmp, n, Kl);
	      subp ^= 1;
	      n += Kl;
	      dif -= Kl;
	    }
	  /* now dif <= Kl */
	  if (subp)
	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
	  else
	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
	  if (cy >= 0)
	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
	  else
	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
	}
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
	{
	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
	}
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
	{
	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
	  nl -= j;
	  MPN_COPY (T, n, j);
	  MPN_ZERO (T + j, nprime + 1 - j);
	  n += l;
	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
	}
      else
	MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
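
/* Worked toy case of the decomposition above (illustration only): take
   K = 4 pieces of l = 1 limb with nl = 4 = K*l, so no folding occurs.
   Then Ap[i] is the single limb n[i], zero-padded to nprime+1 limbs and
   multiplied by 2^(i*Mp) mod 2^(nprime*GMP_NUMB_BITS)+1; these 2^(i*Mp)
   factors are the weights that turn the cyclic transform into the
   negacyclic convolution needed for arithmetic mod 2^N+1.  */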

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  int K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			  pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static unsigned long int
mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
{
  unsigned long int l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
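
/* Hand-checked illustrations of the loop above (not from the original
   source): mpn_mul_fft_lcm (64, 4) halves a four times until k = 0 and
   returns 4 << 4 = 64 = lcm(64, 2^4); mpn_mul_fft_lcm (6, 3) halves a
   once, then a = 3 is odd, and it returns 3 << 3 = 24 = lcm(6, 2^3).  */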


mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
	     mp_srcptr n, mp_size_t nl,
	     mp_srcptr m, mp_size_t ml,
	     int k)
{
  int K, maxLK, i;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_ALLOC_TYPE (k + 1, int *);
  for (i = 0; i <= k; i++)
    fft_l[i] = TMP_ALLOC_TYPE (1 << i, int);
  mpn_fft_initl (fft_l, k);
  K = 1 << k;
  M = N >> k;	/* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
		 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      unsigned long K2;
      for (;;)
	{
	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
	  if ((nprime & (K2 - 1)) == 0)
	    break;
	  nprime = (nprime + K2 - 1) & -K2;
	  Nprime = nprime * GMP_LIMB_BITS;
	  /* warning: since nprime changed, K2 may change too! */
	}
      TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_ALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_ALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_ALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_ALLOC_LIMBS (pla);
      Bp = TMP_ALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_ALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_ALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  int k2, k3;
  int sqr = (n == m && nl == ml);
  int cc, c2, oldcc;

  pl = nl + ml; /* total number of limbs of the result */

  /* perform an FFT mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
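
  /* Sketch of the reconstruction performed below (editorial note, with
     B = GMP_NUMB_BITS and l = pl2/2, so pl2 = 2l and pl3 = 3l): the two
     calls to mpn_mul_fft give lambda = xy mod 2^(pl2*B)+1 and
     mu = xy mod 2^(pl3*B)+1.  Mod 2^(pl2*B)+1 we have 2^(pl2*B) = -1,
     hence 2^(pl3*B)+1 = 1 - 2^(l*B), so xy = mu + (2^(pl3*B)+1)*t with
     t = (lambda - mu)/(1 - 2^(l*B)) mod 2^(pl2*B)+1, the quantity
     computed below.  */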

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_ALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif
