1/* mpz_bin_ui(RESULT, N, K) -- Set RESULT to N over K.
2
3Copyright 1998-2002, 2012, 2013, 2015, 2017-2018 Free Software Foundation, Inc.
4
5This file is part of the GNU MP Library.
6
7The GNU MP Library is free software; you can redistribute it and/or modify
8it under the terms of either:
9
10  * the GNU Lesser General Public License as published by the Free
11    Software Foundation; either version 3 of the License, or (at your
12    option) any later version.
13
14or
15
16  * the GNU General Public License as published by the Free Software
17    Foundation; either version 2 of the License, or (at your option) any
18    later version.
19
20or both in parallel, as here.
21
22The GNU MP Library is distributed in the hope that it will be useful, but
23WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
24or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
25for more details.
26
27You should have received copies of the GNU General Public License and the
28GNU Lesser General Public License along with the GNU MP Library.  If not,
29see https://www.gnu.org/licenses/.  */
30
31#include "gmp-impl.h"
32
/* Compile-time tuning knobs.  (The identifiers appear to be Esperanto,
 * roughly "separate computations", "use bin_uiui" and "by fours".)
 *
 * APARTAJ_KALKULOJ: every k below this value is handled by mpz_bin_ui
 * with dedicated code instead of the general algorithm.
 * Minimum is 2, covering k = 0 and k = 1;
 * also 3 {0,1,2} and 5 {0,1,2,3,4} are implemented.
 */
#define APARTAJ_KALKULOJ 2

/* UZU_BIN_UIUI: whether mpz_bin_ui hands the computation over to
 * mpz_bin_uiui (1) or not (0) whenever both operands fit in an
 * unsigned long.
 */
#define UZU_BIN_UIUI 0

/* KVAROPE: whether to use a shortcut that precomputes the product of
 * four elements (1, selects mpz_raising_fac4), or to precompute only the
 * product of a couple (0, selects mpz_raising_fac).
 *
 * In both cases the precomputed product is then updated with some
 * linear operations to obtain the product of the next four (1)
 * [or two (0)] operands.
 */
#define KVAROPE 1
51
52static void
53posmpz_init (mpz_ptr r)
54{
55  mp_ptr rp;
56  ASSERT (SIZ (r) > 0);
57  rp = SIZ (r) + MPZ_REALLOC (r, SIZ (r) + 2);
58  *rp = 0;
59  *++rp = 0;
60}
61
62/* Equivalent to mpz_add_ui (r, r, in), but faster when
63   0 < SIZ (r) < ALLOC (r) and limbs above SIZ (r) contain 0. */
64static void
65posmpz_inc_ui (mpz_ptr r, unsigned long in)
66{
67#if BITS_PER_ULONG > GMP_NUMB_BITS
68  mpz_add_ui (r, r, in);
69#else
70  ASSERT (SIZ (r) > 0);
71  MPN_INCR_U (PTR (r), SIZ (r) + 1, in);
72  SIZ (r) += (PTR (r)[SIZ (r)] != 0);
73#endif
74}
75
76/* Equivalent to mpz_sub_ui (r, r, in), but faster when
77   0 < SIZ (r) and we know in advance that the result is positive. */
78static void
79posmpz_dec_ui (mpz_ptr r, unsigned long in)
80{
81#if BITS_PER_ULONG > GMP_NUMB_BITS
82  mpz_sub_ui (r, r, in);
83#else
84  ASSERT (mpz_cmp_ui (r, in) >= 0);
85  MPN_DECR_U (PTR (r), SIZ (r), in);
86  SIZ (r) -= (PTR (r)[SIZ (r)-1] == 0);
87#endif
88}
89
90/* Equivalent to mpz_tdiv_q_2exp (r, r, 1), but faster when
91   0 < SIZ (r) and we know in advance that the result is positive. */
92static void
93posmpz_rsh1 (mpz_ptr r)
94{
95  mp_ptr rp;
96  mp_size_t rn;
97
98  rn = SIZ (r);
99  rp = PTR (r);
100  ASSERT (rn > 0);
101  mpn_rshift (rp, rp, rn, 1);
102  SIZ (r) -= rp[rn - 1] == 0;
103}
104
105/* Computes r = n(n+(2*k-1))/2
106   It uses a sqare instead of a product, computing
107   r = ((n+k-1)^2 + n - (k-1)^2)/2
108   As a side effect, sets t = n+k-1
109 */
110static void
111mpz_hmul_nbnpk (mpz_ptr r, mpz_srcptr n, unsigned long int k, mpz_ptr t)
112{
113  ASSERT (k > 0 && SIZ(n) > 0);
114  --k;
115  mpz_add_ui (t, n, k);
116  mpz_mul (r, t, t);
117  mpz_add (r, r, n);
118  posmpz_rsh1 (r);
119  if (LIKELY (k <= (1UL << (BITS_PER_ULONG / 2))))
120    posmpz_dec_ui (r, (k + (k & 1))*(k >> 1));
121  else
122    {
123      mpz_t tmp;
124      mpz_init_set_ui (tmp, (k + (k & 1)));
125      mpz_mul_ui (tmp, tmp, k >> 1);
126      mpz_sub (r, r, tmp);
127      mpz_clear (tmp);
128    }
129}
130
131#if KVAROPE
132static void
133rek_raising_fac4 (mpz_ptr r, mpz_ptr p, mpz_ptr P, unsigned long int k, unsigned long int lk, mpz_ptr t)
134{
135  if (k - lk < 5)
136    {
137      do {
138	posmpz_inc_ui (p, 4*k+2);
139	mpz_addmul_ui (P, p, 4*k);
140	posmpz_dec_ui (P, k);
141	mpz_mul (r, r, P);
142      } while (--k > lk);
143    }
144  else
145    {
146      mpz_t lt;
147      unsigned long int m;
148
149      m = ((k + lk) >> 1) + 1;
150      rek_raising_fac4 (r, p, P, k, m, t);
151
152      posmpz_inc_ui (p, 4*m+2);
153      mpz_addmul_ui (P, p, 4*m);
154      posmpz_dec_ui (P, m);
155      if (t == NULL)
156	{
157	  mpz_init_set (lt, P);
158	  t = lt;
159	}
160      else
161	{
162	  ALLOC (lt) = 0;
163	  mpz_set (t, P);
164	}
165      rek_raising_fac4 (t, p, P, m - 1, lk, NULL);
166
167      mpz_mul (r, r, t);
168      mpz_clear (lt);
169    }
170}
171
172/* Computes (n+1)(n+2)...(n+k)/2^(k/2 +k/4) using the helper function
173   rek_raising_fac4, and exploiting an idea inspired by a piece of
   code that Fredrik Johansson wrote and by a comment by Niels Möller.
175
176   Assume k = 4i then compute:
177     p  = (n+1)(n+4i)/2 - i
178	  (n+1+1)(n+4i)/2 = p + i + (n+4i)/2
179	  (n+1+1)(n+4i-1)/2 = p + i + ((n+4i)-(n+1+1))/2 = p + i + (n-n+4i-2)/2 = p + 3i-1
180     P  = (p + i)*(p+3i-1)/2 = (n+1)(n+2)(n+4i-1)(n+4i)/8
181     n' = n + 2
182     i' = i - 1
183	  (n'-1)(n')(n'+4i'+1)(n'+4i'+2)/8 = P
184	  (n'-1)(n'+4i'+2)/2 - i' - 1 = p
185	  (n'-1+2)(n'+4i'+2)/2 - i' - 1 = p + (n'+4i'+2)
186	  (n'-1+2)(n'+4i'+2-2)/2 - i' - 1 = p + (n'+4i'+2) - (n'-1+2) =  p + 4i' + 1
187	  (n'-1+2)(n'+4i'+2-2)/2 - i' = p + 4i' + 2
188     p' = p + 4i' + 2 = (n'+1)(n'+4i')/2 - i'
189	  p' - 4i' - 2 = p
190	  (p' - 4i' - 2 + i)*(p' - 4i' - 2+3i-1)/2 = P
191	  (p' - 4i' - 2 + i' + 1)*(p' - 4i' - 2 + 3i' + 3 - 1)/2 = P
192	  (p' - 3i' - 1)*(p' - i')/2 = P
193	  (p' - 3i' - 1 + 4i' + 1)*(p' - i' + 4i' - 1)/2 = P + (4i' + 1)*(p' - i')/2 + (p' - 3i' - 1 + 4i' + 1)*(4i' - 1)/2
194	  (p' + i')*(p' + 3i' - 1)/2 = P + (4i')*(p' + p')/2 + (p' - i' - (p' + i'))/2
195	  (p' + i')*(p' + 3i' - 1)/2 = P + 4i'p' + (p' - i' - p' - i')/2
196	  (p' + i')*(p' + 3i' - 1)/2 = P + 4i'p' - i'
197     P' = P + 4i'p' - i'
198
199   And compute the product P * P' * P" ...
200 */
201
202static void
203mpz_raising_fac4 (mpz_ptr r, mpz_ptr n, unsigned long int k, mpz_ptr t, mpz_ptr p)
204{
205  ASSERT ((k >= APARTAJ_KALKULOJ) && (APARTAJ_KALKULOJ > 0));
206  posmpz_init (n);
207  posmpz_inc_ui (n, 1);
208  SIZ (r) = 0;
209  if (k & 1)
210    {
211      mpz_set (r, n);
212      posmpz_inc_ui (n, 1);
213    }
214  k >>= 1;
215  if (APARTAJ_KALKULOJ < 2 && k == 0)
216    return;
217
218  mpz_hmul_nbnpk (p, n, k, t);
219  posmpz_init (p);
220
221  if (k & 1)
222    {
223      if (SIZ (r))
224	mpz_mul (r, r, p);
225      else
226	mpz_set (r, p);
227      posmpz_inc_ui (p, k - 1);
228    }
229  k >>= 1;
230  if (APARTAJ_KALKULOJ < 4 && k == 0)
231    return;
232
233  mpz_hmul_nbnpk (t, p, k, n);
234  if (SIZ (r))
235    mpz_mul (r, r, t);
236  else
237    mpz_set (r, t);
238
239  if (APARTAJ_KALKULOJ > 8 || k > 1)
240    {
241      posmpz_dec_ui (p, k);
242      rek_raising_fac4 (r, p, t, k - 1, 0, n);
243    }
244}
245
246#else /* KVAROPE */
247
248static void
249rek_raising_fac (mpz_ptr r, mpz_ptr n, unsigned long int k, unsigned long int lk, mpz_ptr t1, mpz_ptr t2)
250{
251  /* Should the threshold depend on SIZ (n) ? */
252  if (k - lk < 10)
253    {
254      do {
255	posmpz_inc_ui (n, k);
256	mpz_mul (r, r, n);
257	--k;
258      } while (k > lk);
259    }
260  else
261    {
262      mpz_t t3;
263      unsigned long int m;
264
265      m = ((k + lk) >> 1) + 1;
266      rek_raising_fac (r, n, k, m, t1, t2);
267
268      posmpz_inc_ui (n, m);
269      if (t1 == NULL)
270	{
271	  mpz_init_set (t3, n);
272	  t1 = t3;
273	}
274      else
275	{
276	  ALLOC (t3) = 0;
277	  mpz_set (t1, n);
278	}
279      rek_raising_fac (t1, n, m - 1, lk, t2, NULL);
280
281      mpz_mul (r, r, t1);
282      mpz_clear (t3);
283    }
284}
285
286/* Computes (n+1)(n+2)...(n+k)/2^(k/2) using the helper function
287   rek_raising_fac, and exploiting an idea inspired by a piece of
288   code that Fredrik Johansson wrote.
289
290   Force an even k = 2i then compute:
291     p  = (n+1)(n+2i)/2
292     i' = i - 1
293     p == (n+1)(n+2i'+2)/2
294     p' = p + i' == (n+2)(n+2i'+1)/2
295     n' = n + 1
296     p'== (n'+1)(n'+2i')/2 == (n+1 +1)(n+2i -1)/2
297
298   And compute the product p * p' * p" ...
299*/
300
301static void
302mpz_raising_fac (mpz_ptr r, mpz_ptr n, unsigned long int k, mpz_ptr t, mpz_ptr p)
303{
304  unsigned long int hk;
305  ASSERT ((k >= APARTAJ_KALKULOJ) && (APARTAJ_KALKULOJ > 1));
306  mpz_add_ui (n, n, 1);
307  hk = k >> 1;
308  mpz_hmul_nbnpk (p, n, hk, t);
309
310  if ((k & 1) != 0)
311    {
312      mpz_add_ui (t, t, hk + 1);
313      mpz_mul (r, t, p);
314    }
315  else
316    {
317      mpz_set (r, p);
318    }
319
320  if ((APARTAJ_KALKULOJ > 3) || (hk > 1))
321    {
322      posmpz_init (p);
323      rek_raising_fac (r, p, hk - 1, 0, t, n);
324    }
325}
326#endif /* KVAROPE */
327
328/* This is a poor implementation.  Look at bin_uiui.c for improvement ideas.
329   In fact consider calling mpz_bin_uiui() when the arguments fit, leaving
330   the code here only for big n.
331
332   The identity bin(n,k) = (-1)^k * bin(-n+k-1,k) can be found in Knuth vol
333   1 section 1.2.6 part G. */
334
335void
336mpz_bin_ui (mpz_ptr r, mpz_srcptr n, unsigned long int k)
337{
338  mpz_t      ni;
339  mp_size_t  negate;
340
341  if (SIZ (n) < 0)
342    {
343      /* bin(n,k) = (-1)^k * bin(-n+k-1,k), and set ni = -n+k-1 - k = -n-1 */
344      mpz_init (ni);
345      mpz_add_ui (ni, n, 1L);
346      mpz_neg (ni, ni);
347      negate = (k & 1);   /* (-1)^k */
348    }
349  else
350    {
351      /* bin(n,k) == 0 if k>n
352	 (no test for this under the n<0 case, since -n+k-1 >= k there) */
353      if (mpz_cmp_ui (n, k) < 0)
354	{
355	  SIZ (r) = 0;
356	  return;
357	}
358
359      /* set ni = n-k */
360      mpz_init (ni);
361      mpz_sub_ui (ni, n, k);
362      negate = 0;
363    }
364
365  /* Now wanting bin(ni+k,k), with ni positive, and "negate" is the sign (0
366     for positive, 1 for negative). */
367
368  /* Rewrite bin(n,k) as bin(n,n-k) if that is smaller.  In this case it's
369     whether ni+k-k < k meaning ni<k, and if so change to denominator ni+k-k
370     = ni, and new ni of ni+k-ni = k.  */
371  if (mpz_cmp_ui (ni, k) < 0)
372    {
373      unsigned long  tmp;
374      tmp = k;
375      k = mpz_get_ui (ni);
376      mpz_set_ui (ni, tmp);
377    }
378
379  if (k < APARTAJ_KALKULOJ)
380    {
381      if (k == 0)
382	{
383	  SIZ (r) = 1;
384	  MPZ_NEWALLOC (r, 1)[0] = 1;
385	}
386#if APARTAJ_KALKULOJ > 2
387      else if (k == 2)
388	{
389	  mpz_add_ui (ni, ni, 1);
390	  mpz_mul (r, ni, ni);
391	  mpz_add (r, r, ni);
392	  posmpz_rsh1 (r);
393	}
394#endif
395#if APARTAJ_KALKULOJ > 3
396      else if (k > 2)
397	{ /* k = 3, 4 */
398	  mpz_add_ui (ni, ni, 2); /* n+1 */
399	  mpz_mul (r, ni, ni); /* (n+1)^2 */
400	  mpz_sub_ui (r, r, 1); /* (n+1)^2-1 */
401	  if (k == 3)
402	    {
403	      mpz_mul (r, r, ni); /* ((n+1)^2-1)(n+1) = n(n+1)(n+2) */
404	      /* mpz_divexact_ui (r, r, 6); /\* 6=3<<1; div_by3 ? *\/ */
405	      mpn_pi1_bdiv_q_1 (PTR(r), PTR(r), SIZ(r), 3, GMP_NUMB_MASK/3*2+1, 1);
406	      MPN_NORMALIZE_NOT_ZERO (PTR(r), SIZ(r));
407	    }
408	  else /* k = 4 */
409	    {
410	      mpz_add (ni, ni, r); /* (n+1)^2+n */
411	      mpz_mul (r, ni, ni); /* ((n+1)^2+n)^2 */
412	      mpz_sub_ui (r, r, 1); /* ((n+1)^2+n)^2-1 = n(n+1)(n+2)(n+3) */
413	      /* mpz_divexact_ui (r, r, 24); /\* 24=3<<3; div_by3 ? *\/ */
414	      mpn_pi1_bdiv_q_1 (PTR(r), PTR(r), SIZ(r), 3, GMP_NUMB_MASK/3*2+1, 3);
415	      MPN_NORMALIZE_NOT_ZERO (PTR(r), SIZ(r));
416	    }
417	}
418#endif
419      else
420	{ /* k = 1 */
421	  mpz_add_ui (r, ni, 1);
422	}
423    }
424#if UZU_BIN_UIUI
425  else if (mpz_cmp_ui (ni, ULONG_MAX - k) <= 0)
426    {
427      mpz_bin_uiui (r, mpz_get_ui (ni) + k, k);
428    }
429#endif
430  else
431    {
432      mp_limb_t count;
433      mpz_t num, den;
434
435      mpz_init (num);
436      mpz_init (den);
437
438#if KVAROPE
439      mpz_raising_fac4 (num, ni, k, den, r);
440      popc_limb (count, k);
441      ASSERT (k - (k >> 1) - (k >> 2) - count >= 0);
442      mpz_tdiv_q_2exp (num, num, k - (k >> 1) - (k >> 2) - count);
443#else
444      mpz_raising_fac (num, ni, k, den, r);
445      popc_limb (count, k);
446      ASSERT (k - (k >> 1) - count >= 0);
447      mpz_tdiv_q_2exp (num, num, k - (k >> 1) - count);
448#endif
449
450      mpz_oddfac_1(den, k, 0);
451
452      mpz_divexact(r, num, den);
453      mpz_clear (num);
454      mpz_clear (den);
455    }
456  mpz_clear (ni);
457
458  SIZ(r) = (SIZ(r) ^ -negate) + negate;
459}
460