/* mpn_toom32_mul -- Multiply {ap,an} and {bp,bn} where an is nominally 1.5
   times as large as bn.  Or more accurately, bn < an < 3bn.

   Contributed to the GNU project by Torbjorn Granlund.
   Improvements by Marco Bodrato and Niels Möller.

   The idea of applying toom to unbalanced multiplication is due to Marco
   Bodrato and Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 2006-2010 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


#include "gmp-impl.h"

/* Evaluate in: -1, 0, +1, +inf

  <-s-><--n--><--n-->
   ___ ______ ______
  |a2_|___a1_|___a0_|
	|_b1_|___b0_|
	<-t--><--n-->

  v0  =  a0         * b0      #   A(0)*B(0)
  v1  = (a0+ a1+ a2)*(b0+ b1) #   A(1)*B(1)      ah  <= 2  bh <= 1
  vm1 = (a0- a1+ a2)*(b0- b1) #  A(-1)*B(-1)    |ah| <= 1  bh = 0
  vinf=          a2 *     b1  # A(inf)*B(inf)
*/
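
/* Writing the product as x3 B^3 + x2 B^2 + x1 B + x0, with
   B = 2^(GMP_NUMB_BITS * n), the four evaluations give

     x0      = v0
     x3      = vinf
     x0 + x2 = (v1 + vm1) / 2
     x1 + x3 = (v1 - vm1) / 2 = (x0 + x2) - vm1

   and the interpolation code below recovers x1 and x2 from these.  */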

#define TOOM32_MUL_N_REC(p, a, b, n, ws)				\
  do {									\
    mpn_mul_n (p, a, b, n);						\
  } while (0)
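
/* All recursive products in this file have balanced n-limb operands, so the
   wrapper above simply forwards to mpn_mul_n; the ws (scratch) argument is
   currently unused.  */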

void
mpn_toom32_mul (mp_ptr pp,
		mp_srcptr ap, mp_size_t an,
		mp_srcptr bp, mp_size_t bn,
		mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg;
  mp_limb_t cy;
  mp_limb_signed_t hi;
  mp_limb_t ap1_hi, bp1_hi;

#define a0  ap
#define a1  (ap + n)
#define a2  (ap + 2 * n)
#define b0  bp
#define b1  (bp + n)

  /* Required, to ensure that s + t >= n. */
  ASSERT (bn + 2 <= an && an + 6 <= 3*bn);

  n = 1 + (2 * an >= 3 * bn ? (an - 1) / (size_t) 3 : (bn - 1) >> 1);
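  /* That is, n = ceil(an/3) when an dominates (2*an >= 3*bn), and
     n = ceil(bn/2) otherwise, so that ap splits into three and bp into two
     pieces of at most n limbs each.  */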

  s = an - 2 * n;
  t = bn - n;

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= n);
  ASSERT (s + t >= n);

  /* Product area of size an + bn = 3*n + s + t >= 4*n + 2. */
#define ap1 (pp)		/* n, most significant limb in ap1_hi */
#define bp1 (pp + n)		/* n, most significant bit in bp1_hi */
#define am1 (pp + 2*n)		/* n, most significant bit in hi */
#define bm1 (pp + 3*n)		/* n */
#define v1 (scratch)		/* 2n + 1 */
#define vm1 (pp)		/* 2n + 1 */
#define scratch_out (scratch + 2*n + 1) /* Currently unused. */

  /* Scratch need: 2*n + 1 + scratch for the recursive multiplications. */

  /* FIXME: Keep v1[2*n] and vm1[2*n] in scalar variables? */

  /* Compute ap1 = a0 + a1 + a2, am1 = a0 - a1 + a2 */
  ap1_hi = mpn_add (ap1, a0, n, a2, s);
#if HAVE_NATIVE_mpn_add_n_sub_n
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      ap1_hi = mpn_add_n_sub_n (ap1, am1, a1, ap1, n) >> 1;
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      cy = mpn_add_n_sub_n (ap1, am1, ap1, a1, n);
      hi = ap1_hi - (cy & 1);
      ap1_hi += (cy >> 1);
      vm1_neg = 0;
    }
#else
  if (ap1_hi == 0 && mpn_cmp (ap1, a1, n) < 0)
    {
      ASSERT_NOCARRY (mpn_sub_n (am1, a1, ap1, n));
      hi = 0;
      vm1_neg = 1;
    }
  else
    {
      hi = ap1_hi - mpn_sub_n (am1, ap1, a1, n);
      vm1_neg = 0;
    }
  ap1_hi += mpn_add_n (ap1, ap1, a1, n);
#endif
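
  /* Now a0 + a1 + a2 = {ap1, n} + ap1_hi * B with ap1_hi <= 2, and
     a0 - a1 + a2 = +-({am1, n} + hi * B) with hi in {0, 1}; the sign is
     recorded in vm1_neg.  */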

  /* Compute bp1 = b0 + b1 and bm1 = b0 - b1. */
  if (t == n)
    {
#if HAVE_NATIVE_mpn_add_n_sub_n
      if (mpn_cmp (b0, b1, n) < 0)
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b1, b0, n);
	  vm1_neg ^= 1;
	}
      else
	{
	  cy = mpn_add_n_sub_n (bp1, bm1, b0, b1, n);
	}
      bp1_hi = cy >> 1;
#else
      bp1_hi = mpn_add_n (bp1, b0, b1, n);

      if (mpn_cmp (b0, b1, n) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, n));
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b0, b1, n));
	}
#endif
    }
  else
    {
      /* FIXME: Should still use mpn_add_n_sub_n for the main part. */
      bp1_hi = mpn_add (bp1, b0, n, b1, t);

      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
	{
	  ASSERT_NOCARRY (mpn_sub_n (bm1, b1, b0, t));
	  MPN_ZERO (bm1 + t, n - t);
	  vm1_neg ^= 1;
	}
      else
	{
	  ASSERT_NOCARRY (mpn_sub (bm1, b0, n, b1, t));
	}
    }
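
  /* Now b0 + b1 = {bp1, n} + bp1_hi * B with bp1_hi <= 1, and
     b0 - b1 = +-{bm1, n} with no high limb; vm1_neg now holds the sign of
     the product (a0 - a1 + a2) * (b0 - b1).  */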

  TOOM32_MUL_N_REC (v1, ap1, bp1, n, scratch_out);
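  /* {v1, 2n} = {ap1, n} * {bp1, n}.  Fold in the high parts: add
     ap1_hi * {bp1, n} and bp1_hi * {ap1, n} at v1 + n, and accumulate
     ap1_hi * bp1_hi together with the carries into v1[2n].  */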
  if (ap1_hi == 1)
    {
      cy = bp1_hi + mpn_add_n (v1 + n, v1 + n, bp1, n);
    }
  else if (ap1_hi == 2)
    {
#if HAVE_NATIVE_mpn_addlsh1_n
      cy = 2 * bp1_hi + mpn_addlsh1_n (v1 + n, v1 + n, bp1, n);
#else
      cy = 2 * bp1_hi + mpn_addmul_1 (v1 + n, bp1, n, CNST_LIMB(2));
#endif
    }
  else
    cy = 0;
  if (bp1_hi != 0)
    cy += mpn_add_n (v1 + n, v1 + n, ap1, n);
  v1[2 * n] = cy;

  TOOM32_MUL_N_REC (vm1, am1, bm1, n, scratch_out);
  if (hi)
    hi = mpn_add_n (vm1+n, vm1+n, bm1, n);

  vm1[2*n] = hi;
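
  /* {vm1, 2n+1} now holds |(a0 - a1 + a2) * (b0 - b1)| = |A(-1)*B(-1)|,
     with its sign in vm1_neg.  */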

  /* v1 <-- (v1 + vm1) / 2 = x0 + x2 */
  if (vm1_neg)
    {
#if HAVE_NATIVE_mpn_rsh1sub_n
      mpn_rsh1sub_n (v1, v1, vm1, 2*n+1);
#else
      mpn_sub_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }
  else
    {
#if HAVE_NATIVE_mpn_rsh1add_n
      mpn_rsh1add_n (v1, v1, vm1, 2*n+1);
#else
      mpn_add_n (v1, v1, vm1, 2*n+1);
      ASSERT_NOCARRY (mpn_rshift (v1, v1, 2*n+1, 1));
#endif
    }

  /* We get x1 + x3 = (x0 + x2) - (x0 - x1 + x2 - x3), and hence

     y = x1 + x3 + (x0 + x2) * B
       = (x0 + x2) * B + (x0 + x2) - vm1.

     y is 3*n + 1 limbs, y = y0 + y1 B + y2 B^2. We store them as
     follows: y0 at scratch, y1 at pp + 2*n, and y2 at scratch + n
     (already in place, except for carry propagation).

     We thus add

   B^3  B^2   B    1
    |    |    |    |
   +-----+----+
 + |  x0 + x2 |
   +----+-----+----+
 +      |  x0 + x2 |
	+----------+
 -      |  vm1     |
 --+----++----+----+-
   | y2  | y1 | y0 |
   +-----+----+----+

  Since we store y0 at the same location as the low half of x0 + x2, we
  need to do the middle sum first. */

  hi = vm1[2*n];
  cy = mpn_add_n (pp + 2*n, v1, v1 + n, n);
  MPN_INCR_U (v1 + n, n + 1, cy + v1[2*n]);
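
  /* pp + 2*n now holds the middle sum, i.e. the low and high halves of
     x0 + x2 added together (mod B); its carry and the top limb of x0 + x2
     have been propagated into the y2 area at v1 + n.  The vm1 adjustment
     below then yields y0 at scratch, y1 at pp + 2*n and y2 at scratch + n.  */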

  /* FIXME: Can we get rid of this second vm1_neg conditional by
     swapping the location of +1 and -1 values? */
  if (vm1_neg)
    {
      cy = mpn_add_n (v1, v1, vm1, n);
      hi += mpn_add_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_INCR_U (v1 + n, n+1, hi);
    }
  else
    {
      cy = mpn_sub_n (v1, v1, vm1, n);
      hi += mpn_sub_nc (pp + 2*n, pp + 2*n, vm1 + n, n, cy);
      MPN_DECR_U (v1 + n, n+1, hi);
    }
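
  /* y = y0 + y1 B + y2 B^2 is now complete: y0 = {scratch, n},
     y1 = {pp + 2*n, n}, y2 = {scratch + n, n+1}.  */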

  TOOM32_MUL_N_REC (pp, a0, b0, n, scratch_out);
  /* vinf, s+t limbs.  Use mpn_mul for now, to handle unbalanced operands */
  if (s > t)  mpn_mul (pp+3*n, a2, s, b1, t);
  else        mpn_mul (pp+3*n, b1, t, a2, s);
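
  /* The product area now has x0 = {pp, 2n} and x3 = {pp + 3*n, s+t} in
     place; {pp + 2*n, n} still holds y1 from above.  */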

  /* Remaining interpolation.

     y * B + x0 + x3 B^3 - x0 B^2 - x3 B
     = (x1 + x3) B + (x0 + x2) B^2 + x0 + x3 B^3 - x0 B^2 - x3 B
     = y0 B + y1 B^2 + y2 B^3 + Lx0 + H x0 B
       + L x3 B^3 + H x3 B^4 - Lx0 B^2 - H x0 B^3 - L x3 B - H x3 B^2
     = L x0 + (y0 + H x0 - L x3) B + (y1 - L x0 - H x3) B^2
       + (y2 - (H x0 - L x3)) B^3 + H x3 B^4

	  B^4       B^3       B^2        B         1
 |         |         |         |         |         |
   +-------+                   +---------+---------+
   |  Hx3  |                   | Hx0-Lx3 |    Lx0  |
   +------+----------+---------+---------+---------+
	  |    y2    |  y1     |   y0    |
	  ++---------+---------+---------+
	  -| Hx0-Lx3 | - Lx0   |
	   +---------+---------+
		      | - Hx3  |
		      +--------+

    We must take into account the carry from Hx0 - Lx3.
  */
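
  /* The code below follows the columns above: first Hx0 - Lx3 is formed in
     place at pp + n, then y1 - Lx0 at pp + 2*n and y2 - (Hx0 - Lx3) at
     pp + 3*n, then y0 is added in at pp + n.  The remaining -Hx3 term and
     the borrows/carries accumulated in hi are applied at the end, where
     {pp + 4*n, s+t-n} holds Hx3.  */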

  cy = mpn_sub_n (pp + n, pp + n, pp+3*n, n);
  hi = scratch[2*n] + cy;

  cy = mpn_sub_nc (pp + 2*n, pp + 2*n, pp, n, cy);
  hi -= mpn_sub_nc (pp + 3*n, scratch + n, pp + n, n, cy);

  hi += mpn_add (pp + n, pp + n, 3*n, scratch, n);

  /* FIXME: Is support for s + t == n needed? */
  if (LIKELY (s + t > n))
    {
      hi -= mpn_sub (pp + 2*n, pp + 2*n, 2*n, pp + 4*n, s+t-n);

      if (hi < 0)
	MPN_DECR_U (pp + 4*n, s+t-n, -hi);
      else
	MPN_INCR_U (pp + 4*n, s+t-n, hi);
    }
  else
    ASSERT (hi == 0);
}