1/* Schoenhage's fast multiplication modulo 2^N+1. 2 3 Contributed by Paul Zimmermann. 4 5 THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES. IT IS ONLY 6 SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES. IN FACT, IT IS ALMOST 7 GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE. 8 9Copyright 1998-2010, 2012, 2013, 2018, 2020 Free Software Foundation, Inc. 10 11This file is part of the GNU MP Library. 12 13The GNU MP Library is free software; you can redistribute it and/or modify 14it under the terms of either: 15 16 * the GNU Lesser General Public License as published by the Free 17 Software Foundation; either version 3 of the License, or (at your 18 option) any later version. 19 20or 21 22 * the GNU General Public License as published by the Free Software 23 Foundation; either version 2 of the License, or (at your option) any 24 later version. 25 26or both in parallel, as here. 27 28The GNU MP Library is distributed in the hope that it will be useful, but 29WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 30or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 31for more details. 32 33You should have received copies of the GNU General Public License and the 34GNU Lesser General Public License along with the GNU MP Library. If not, 35see https://www.gnu.org/licenses/. */ 36 37 38/* References: 39 40 Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker 41 Strassen, Computing 7, p. 281-292, 1971. 42 43 Asymptotically fast algorithms for the numerical multiplication and division 44 of polynomials with complex coefficients, by Arnold Schoenhage, Computer 45 Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982. 46 47 Tapes versus Pointers, a study in implementing fast algorithms, by Arnold 48 Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986. 49 50 TODO: 51 52 Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and 53 Zimmermann. 
   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
				       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
				       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
				   mp_size_t, mp_size_t, mp_size_t, mp_ptr);


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE		/* When tuning this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

/* Tuned (n,k) breakpoint tables: index 0 for multiplication, 1 for squaring. */
FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

/* Third-generation lookup: scan the (n,k) table until the threshold
   n <= tab->n << last_k is reached, and return the k in force just before
   that entry.  */
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
	break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

/* Older-generation lookup over a simple threshold table; each table slot i
   gives the smallest size at which k = i + FFT_FIRST_K + 1 pays off.  */
int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}


/* Initialize l[i][j] with bitrev(j), for 0 <= i <= k, 0 <= j < 2^i.
   Row i is built from row i-1: bitrev of an (i)-bit index 2j is twice the
   bitrev of the i-1-bit index j, and bitrev(2j+1) adds the top bit.  */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
	{
	  li[j] = 2 * l[i - 1][j];
	  li[K + j] = 1 + li[j];
	}
    }
}


/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;	/* sub-limb shift */
  m = d / GMP_NUMB_BITS;	/* whole-limb shift */

  if (m >= n)			/* negate: 2^(n*GMP_NUMB_BITS) == -1 mod 2^(n*GMP_NUMB_BITS)+1 */
    {
      /* r[0..m-1]  <--  lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1], sh) */

      m -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  mpn_lshift (r, a + n - m, m + 1, sh);
	  rd = r[m];
	  cc = mpn_lshiftc (r + m, a, n - m, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - m, m);
	  rd = a[n];
	  mpn_com (r + m, a, n - m);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <--  lshift(a[0]..a[n-m-1], sh)  */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  mpn_lshiftc (r, a + n - m, m + 1, sh);
	  rd = ~r[m];
	  /* {r, m+1} = {a+n-m, m+1} << sh */
	  cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
	}
      else
	{
	  /* r[m] is not used below, but we save a test for m=0 */
	  mpn_com (r, a + n - m, m + 1);
	  rd = a[n];
	  MPN_COPY (r + m, a, n - m);
	  cc = 0;
	}

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[m] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, m, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}

#if HAVE_NATIVE_mpn_add_n_sub_n
/* Butterfly: simultaneously A0 <- A0 + tp and Ai <- A0 - tp, both
   mod 2^(n*GMP_NUMB_BITS)+1.  Inputs are semi-normalized (top limb <= 1)
   and so are the outputs.  cyas packs the add carry (bit 1) and the
   subtract borrow (bit 0) from mpn_add_n_sub_n.  */
static inline void
mpn_fft_add_sub_modF (mp_ptr A0, mp_ptr Ai, mp_srcptr tp, mp_size_t n)
{
  mp_limb_t cyas, c, x;

  cyas = mpn_add_n_sub_n (A0, Ai, A0, tp, n);

  c = A0[n] - tp[n] - (cyas & 1);
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  Ai[n] = x + c;
  MPN_INCR_U (Ai, n + 1, x);

  c = A0[n] + tp[n] + (cyas >> 1);
  x = (c - 1) & -(c != 0);
  A0[n] = c - x;
  MPN_DECR_U (A0, n + 1, x);
}

#else /* ! HAVE_NATIVE_mpn_add_n_sub_n */

/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
#endif /* HAVE_NATIVE_mpn_add_n_sub_n */

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      /* Base case: a single butterfly A[0],A[inc] <- A[0]+A[inc], A[0]-A[inc],
	 followed by renormalization of the top limbs back into [0,1].  */
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      /* Recurse on even and odd interleavings, then combine with twiddled
	 butterflies (decimation in time, bit-reversed output order).  */
      mpn_fft_fft (Ap, K2, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
	{
	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
	  mpn_fft_add_sub_modF (Ap[0], Ap[inc], tp, n);
#else
	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
	}
    }
}

/* input: A[0] ...
A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
   (This repeats the contract of mpn_fft_fft above, adding the scratch
   space requirement.)  */


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
	{
	  /* This happens with very low probability; we have yet to trigger it,
	     and thereby make sure this code is correct.  */
	  MPN_ZERO (ap, n);
	  ap[n] = 1;
	}
      else
	ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K.
   Above the MODF threshold each product recurses into a smaller FFT;
   below it, plain mpn_mul_n/mpn_sqr is used with an explicit reduction.  */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      /* Recursive case: set up a smaller FFT of size 2^k for the n-limb
	 products, mirroring the parameter computation in mpn_mul_fft.  */
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
	{
	  mp_size_t K3;
	  for (;;)
	    {
	      K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
	      if ((nprime2 & (K3 - 1)) == 0)
		break;
	      nprime2 = (nprime2 + K3 - 1) & -K3;
	      Nprime2 = nprime2 * GMP_LIMB_BITS;
	      /* warning: since nprime2 changed, K3 may change too! */
	    }
	}
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
	{
	  fft_l[i] = tmp;
	  tmp += (mp_size_t) 1 << i;
	}

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
	{
	  mp_limb_t cy;
	  mpn_fft_normalize (*ap, n);
	  if (!sqr)
	    mpn_fft_normalize (*bp, n);

	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
	  if (!sqr)
	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
				     l, Mp2, fft_l, T, sqr);
	  (*ap)[n] = cy;
	}
    }
  else
    {
      /* Base case: full 2n-limb product, then reduce mod 2^(n*GMP_NUMB_BITS)+1
	 as {tp,n} - {tpn,n}, with corrections for the top limbs a[n], b[n].  */
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
	{
	  a = *ap++;
	  b = *bp++;
	  if (sqr)
	    mpn_sqr (tp, a, n);
	  else
	    mpn_mul_n (tp, b, a, n);
	  if (a[n] != 0)
	    cc = mpn_add_n (tpn, tpn, b, n);
	  else
	    cc = 0;
	  if (b[n] != 0)
	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
	  if (cc != 0)
	    {
	      cc = mpn_add_1 (tp, tp, n2, cc);
	      /* If mpn_add_1 gives a carry (cc != 0),
		 the result (tp) is at most GMP_NUMB_MAX - 1,
		 so the following addition can't overflow.
	      */
	      tp[0] += cc;
	    }
	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
	}
    }
  TMP_FREE;
}


/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      /* Base case butterfly, identical in form to the one in mpn_fft_fft.  */
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
	{
	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
#if HAVE_NATIVE_mpn_add_n_sub_n
	  mpn_fft_add_sub_modF (Ap[0], Ap[K2], tp, n);
#else
	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
	  mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
#endif
	}
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1.
   Division by 2^k is multiplication by its inverse, i.e. by 2^(2nL-k)
   where L = GMP_NUMB_BITS; r and a must not overlap.  */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}


/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
		       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
		       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      /* Fold the input into Kl limbs using 2^(Kl*GMP_NUMB_BITS) == -1:
	 alternately subtract and add successive Kl-limb chunks, keeping a
	 signed limb carry cy which is folded back in at the end.  */
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
	{
	  int subp = 0;		/* parity: subtract next chunk when set */

	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
	  n += 2 * Kl;
	  dif -= Kl;

	  /* now dif > 0 */
	  while (dif > Kl)
	    {
	      if (subp)
		cy += mpn_sub_n (tmp, tmp, n, Kl);
	      else
		cy -= mpn_add_n (tmp, tmp, n, Kl);
	      subp ^= 1;
	      n += Kl;
	      dif -= Kl;
	    }
	  /* now dif <= Kl */
	  if (subp)
	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
	  else
	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
	  if (cy >= 0)
	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
	  else
	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
	}
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
	{
	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
	}
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
	{
	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
	  nl -= j;
	  MPN_COPY (T, n, j);
	  MPN_ZERO (T + j, nprime + 1 - j);
	  n += l;
	  /* weight the i-th coefficient by 2^(i*Mp) (the "cyclic" twist) */
	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
	}
      else
	MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft: each coefficient is divided by K
     (i.e. 2^k) and untwisted by the decompose-time weight 2^((K-i)*Mp).
     The Bp[] pointers are rotated onto the Ap[] storage to avoid copies.  */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			  pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  /* fold the accumulated signed carry cc (in {-1, 0, 1}) back into p */
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k.
   Writing a = 2^v * odd, lcm(a, 2^k) = 2^max(v,k) * odd; the loop strips
   at most k factors of two from a, and shifting back by the original k
   restores exactly the right power of two in either case.  */
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k --;
    }
  return a << l;
}


mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
	     mp_srcptr n, mp_size_t nl,
	     mp_srcptr m, mp_size_t ml,
	     int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  /* build the bit-reversal tables l[0..k], packed into one int array */
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k;	/* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
		 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
	{
	  K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
	  if ((nprime & (K2 - 1)) == 0)
	    break;
	  nprime = (nprime + K2 - 1) & -K2;
	  Nprime = nprime * GMP_LIMB_BITS;
	  /* warning: since nprime changed, K2 may change too! */
	}
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      /* squaring: B is only used as recombination scratch, so allocate
	 just the pla limbs mpn_mul_fft_internal needs */
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
		  mp_srcptr n, mp_size_t nl,
		  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
			    thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
		 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3); /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2); /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif