/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998-2010, 2012, 2013, 2018, 2020 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
				       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
				       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
				   mp_size_t, mp_size_t, mp_size_t, mp_ptr);


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files have not been updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE		/* When tuning this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

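/* The loop in mpn_fft_best_k below walks this table: starting from the first
   entry's k, each subsequent entry {n, k} defines a threshold n << previous_k;
   the previous k is used for operand sizes up to that threshold, and the
   entry's own k takes over beyond it.  */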
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
	break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns smallest possible number of limbs >= pl for an FFT of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
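/* For instance, with k = 4 (an FFT with 2^4 = 16 pieces) pl = 2000 is
   returned unchanged, while pl = 2001 is rounded up to 2016.  */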


/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
	{
	  li[j] = 2 * l[i - 1][j];
	  li[K + j] = 1 + li[j];
	}
    }
}
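/* For example, with k = 3 the last row is l[3] = {0, 4, 2, 6, 1, 5, 3, 7},
   the 3-bit reversals of 0, ..., 7.  */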


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
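/* Since 2^(n*GMP_NUMB_BITS) = -1 mod 2^(n*GMP_NUMB_BITS)+1, multiplying by
   2^d amounts to rotating a by d bits, where the part that wraps around the
   top re-enters at the bottom with its sign flipped (cf. the "negate" case
   below for d >= n*GMP_NUMB_BITS).  */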
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n)			/* negate */
    {
      /* r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1],  sh) */

      m -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - m, m + 1, sh);
	  rd = r[m];
	  cc = mpn_lshiftc (r + m, a, n - m, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - m, m);
	  rd = a[n];
	  mpn_com (r + m, a, n - m);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- lshift(a[0]..a[n-m-1],  sh)  */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  mpn_lshiftc (r, a + n - m, m + 1, sh);
	  rd = ~r[m];
	  /* {r, m+1} = {a+n-m, m+1} << sh */
	  cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
	}
      else
	{
	  /* r[m] is not used below, but we save a test for m=0 */
	  mpn_com (r, a + n - m, m + 1);
	  rd = a[n];
	  MPN_COPY (r + m, a, n - m);
	  cc = 0;
	}

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[m] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, m, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}

#if HAVE_NATIVE_mpn_add_n_sub_n
static inline void
mpn_fft_add_sub_modF (mp_ptr A0, mp_ptr Ai, mp_srcptr tp, mp_size_t n)
{
  mp_limb_t cyas, c, x;

  cyas = mpn_add_n_sub_n (A0, Ai, A0, tp, n);

  c = A0[n] - tp[n] - (cyas & 1);
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  Ai[n] = x + c;
  MPN_INCR_U (Ai, n + 1, x);

  c = A0[n] + tp[n] + (cyas >> 1);
  x = (c - 1) & -(c != 0);
  A0[n] = c - x;
  MPN_DECR_U (A0, n + 1, x);
}

#else /* ! HAVE_NATIVE_mpn_add_n_sub_n  */

/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
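/* Since 2^(n*GMP_NUMB_BITS) = -1, a carry c left in r[n] stands for -c.
   The branchless code below sets x = max(c-1, 0), stores r[n] = c - x
   (i.e. 0 or 1) and decrements {r, n+1} by x; the value changes by
   -x*2^(n*GMP_NUMB_BITS) - x = x - x = 0 mod 2^(n*GMP_NUMB_BITS)+1, so the
   residue is preserved while the result is semi-normalized again.
   mpn_fft_sub_modF below uses the symmetric trick for a borrow.  */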
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
#endif /* HAVE_NATIVE_mpn_add_n_sub_n */

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.  */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      mpn_fft_fft (Ap,     K2, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
	  mpn_fft_add_sub_modF (Ap[0], Ap[inc], tp, n);
#else
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
#endif
	}
    }
}

/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thereby make sure this code is correct.  */
	  MPN_ZERO (ap, n);
	  ap[n] = 1;
	}
      else
	ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
	{
	  mp_size_t K3;
	  for (;;)
	    {
	      K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
	      if ((nprime2 & (K3 - 1)) == 0)
		break;
	      nprime2 = (nprime2 + K3 - 1) & -K3;
	      Nprime2 = nprime2 * GMP_LIMB_BITS;
	      /* warning: since nprime2 changed, K3 may change too! */
	    }
	}
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
	{
	  fft_l[i] = tmp;
	  tmp += (mp_size_t) 1 << i;
	}

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
	{
	  mp_limb_t cy;
	  mpn_fft_normalize (*ap, n);
	  if (!sqr)
	    mpn_fft_normalize (*bp, n);

	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
	  if (!sqr)
	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
				     l, Mp2, fft_l, T, sqr);
	  (*ap)[n] = cy;
	}
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
	{
	  a = *ap++;
	  b = *bp++;
	  if (sqr)
	    mpn_sqr (tp, a, n);
	  else
	    mpn_mul_n (tp, b, a, n);
	  if (a[n] != 0)
	    cc = mpn_add_n (tpn, tpn, b, n);
	  else
	    cc = 0;
	  if (b[n] != 0)
	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
	  if (cc != 0)
	    {
	      cc = mpn_add_1 (tp, tp, n2, cc);
	      /* If mpn_add_1 gives a carry (cc != 0),
		 the result (tp) is at most GMP_NUMB_MAX - 1,
		 so the following addition can't overflow.
	      */
	      tp[0] += cc;
	    }
	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
	}
    }
  TMP_FREE;
}


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
	  mpn_fft_add_sub_modF (Ap[0], Ap[K2], tp, n);
#else
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
#endif
	}
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1, since
     2^(2nL) = (2^(nL))^2 = (-1)^2 = 1 with L = GMP_NUMB_BITS */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
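/* Splitting {ap, an} into chunks of n limbs, {ap, an} = a0 + a1*2^N + a2*2^(2N)
   with N = n*GMP_NUMB_BITS (a1 and a2 possibly shorter or empty), the result
   is a0 - a1 + a2 mod 2^N+1, since 2^N = -1 and 2^(2N) = 1; this is what the
   additions and subtraction below compute.  */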
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
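/* Each piece is multiplied by 2^(i*Mp) before the transform.  As set up by
   the callers, K*Mp = nprime*GMP_NUMB_BITS, so 2^Mp is a 2K-th root of unity
   mod 2^(nprime*GMP_NUMB_BITS)+1 with (2^Mp)^K = -1; this weighting is the
   usual trick that lets the cyclic convolution computed by the transforms
   yield the negacyclic convolution needed for a product mod 2^N+1.  */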
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
		       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
		       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
	{
	  int subp = 0;

	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
	  n += 2 * Kl;
	  dif -= Kl;

	  /* now dif > 0 */
	  while (dif > Kl)
	    {
	      if (subp)
		cy += mpn_sub_n (tmp, tmp, n, Kl);
	      else
		cy -= mpn_add_n (tmp, tmp, n, Kl);
	      subp ^= 1;
	      n += Kl;
	      dif -= Kl;
	    }
	  /* now dif <= Kl */
	  if (subp)
	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
	  else
	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
	  if (cy >= 0)
	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
	  else
	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
	}
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
	{
	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
	}
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
	{
	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
	  nl -= j;
	  MPN_COPY (T, n, j);
	  MPN_ZERO (T + j, nprime + 1 - j);
	  n += l;
	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
	}
      else
	MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/
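/* Steps: forward transforms of the K pieces, pointwise products mod
   2^(nprime*GMP_NUMB_BITS)+1 (which recurse into the FFT code again when
   nprime is large enough), an inverse transform, division of each
   coefficient by 2^k and by the weight applied in mpn_mul_fft_decompose,
   and finally the addition of the coefficients at offsets of l limbs,
   followed by a reduction mod 2^(pl*GMP_NUMB_BITS)+1.  */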

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			  pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
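/* E.g. for a = GMP_NUMB_BITS = 64 this returns 64 when k <= 6 and 2^k when
   k > 6.  */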
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k --;
    }
  return a << l;
}


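/* op <- {n, nl}*{m, ml} mod 2^N+1 with N = pl*GMP_NUMB_BITS, using an FFT of
   size 2^k; the pl low limbs of the residue are stored at op and the
   remaining high bit is returned.  pl must satisfy
   pl == mpn_fft_next_size (pl, k).  An illustrative calling sequence (the
   variable names are only examples, cf. mpn_mul_fft_full below for a real
   caller):

     k  = mpn_fft_best_k (pl, sqr);
     pl = mpn_fft_next_size (pl, k);
     cy = mpn_mul_fft (op, pl, a, an, b, bn, k);
*/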
mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
	     mp_srcptr n, mp_size_t nl,
	     mp_srcptr m, mp_size_t ml,
	     int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k;	/* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
		 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
	{
	  K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
	  if ((nprime & (K2 - 1)) == 0)
	    break;
	  nprime = (nprime + K2 - 1) & -K2;
	  Nprime = nprime * GMP_LIMB_BITS;
	  /* warning: since nprime changed, K2 may change too! */
	}
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform an FFT mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */
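  /* With FFT_FIRST_K = 4, this means pl >= 96 limbs.  The two FFT products
     below give mu = product mod 2^(pl3*GMP_NUMB_BITS)+1 and
     lambda = product mod 2^(pl2*GMP_NUMB_BITS)+1, and the exact product is
     recovered by a CRT-style recombination: with l = pl3 - pl2 we have
     2^(pl3*GMP_NUMB_BITS)+1 = 1 - 2^(l*GMP_NUMB_BITS) mod
     2^(pl2*GMP_NUMB_BITS)+1, so
     t = (lambda - mu) / (1 - 2^(l*GMP_NUMB_BITS)) mod 2^(pl2*GMP_NUMB_BITS)+1
     (computed in pad_op below), and the product equals
     mu + t * (2^(pl3*GMP_NUMB_BITS)+1).  */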

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif
