/* mpfr_digamma -- digamma function of a floating-point number

Copyright 2009-2023 Free Software Foundation, Inc.
Contributed by the AriC and Caramba projects, INRIA.

This file is part of the GNU MPFR Library.

The GNU MPFR Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MPFR Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MPFR Library; see the file COPYING.LESSER.  If not, see
https://www.gnu.org/licenses/ or write to the Free Software Foundation, Inc.,
51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA. */

#include "mpfr-impl.h"

/* FIXME: Check that MPFR_GET_EXP can only be called on regular values
   (in r14025, this is not the case) and that there cannot be integer
   overflows. */

29/* Put in s an approximation of digamma(x).
30   Assumes x >= 2.
31   Assumes s does not overlap with x.
32   Returns an integer e such that the error is bounded by 2^e ulps
33   of the result s.
34*/
35static mpfr_exp_t
36mpfr_digamma_approx (mpfr_ptr s, mpfr_srcptr x)
37{
38  mpfr_prec_t p = MPFR_PREC (s);
39  mpfr_t t, u, invxx;
40  mpfr_exp_t e, exps, f, expu;
41  unsigned long n;
42
43  MPFR_ASSERTN (MPFR_IS_POS (x) && MPFR_GET_EXP (x) >= 2);
44
45  mpfr_init2 (t, p);
46  mpfr_init2 (u, p);
47  mpfr_init2 (invxx, p);
48
49  mpfr_log (s, x, MPFR_RNDN);         /* error <= 1/2 ulp */
50  mpfr_ui_div (t, 1, x, MPFR_RNDN);   /* error <= 1/2 ulp */
51  mpfr_div_2ui (t, t, 1, MPFR_RNDN); /* exact */
52  mpfr_sub (s, s, t, MPFR_RNDN);
53  /* error <= 1/2 + 1/2*2^(EXP(olds)-EXP(s)) + 1/2*2^(EXP(t)-EXP(s)).
54     For x >= 2, log(x) >= 2*(1/(2x)), thus olds >= 2t, and olds - t >= olds/2,
55     thus 0 <= EXP(olds)-EXP(s) <= 1, and EXP(t)-EXP(s) <= 0, thus
56     error <= 1/2 + 1/2*2 + 1/2 <= 2 ulps. */
57  e = 2; /* initial error */
58  mpfr_sqr (invxx, x, MPFR_RNDZ);     /* invxx = x^2 * (1 + theta)
59                                         for |theta| <= 2^(-p) */
60  mpfr_ui_div (invxx, 1, invxx, MPFR_RNDU); /* invxx = 1/x^2 * (1 + theta)^2 */
61
62  /* in the following we note err=xxx when the ratio between the approximation
63     and the exact result can be written (1 + theta)^xxx for |theta| <= 2^(-p),
64     following Higham's method */
65  mpfr_set_ui (t, 1, MPFR_RNDN); /* err = 0 */
66  for (n = 1;; n++)
67    {
68      /* The main term is Bernoulli[2n]/(2n)/x^(2n) = B[n]/(2n+1)!(2n)/x^(2n)
69         = B[n]*t[n]/(2n) where t[n]/t[n-1] = 1/(2n)/(2n+1)/x^2. */
70      mpfr_mul (t, t, invxx, MPFR_RNDU);        /* err = err + 3 */
71      mpfr_div_ui (t, t, 2 * n, MPFR_RNDU);     /* err = err + 1 */
72      mpfr_div_ui (t, t, 2 * n + 1, MPFR_RNDU); /* err = err + 1 */
73      /* we thus have err = 5n here */
74      mpfr_div_ui (u, t, 2 * n, MPFR_RNDU);     /* err = 5n+1 */
75      mpfr_mul_z (u, u, mpfr_bernoulli_cache(n), MPFR_RNDU);/* err = 5n+2, and the
76                                                   absolute error is bounded
77                                                   by 10n+4 ulp(u) [Rule 11] */
78      /* if the terms 'u' are decreasing by a factor two at least,
79         then the error coming from those is bounded by
80         sum((10n+4)/2^n, n=1..infinity) = 24 */
81      exps = MPFR_GET_EXP (s);
82      expu = MPFR_GET_EXP (u);
83      if (expu < exps - (mpfr_exp_t) p)
84        break;
85      mpfr_sub (s, s, u, MPFR_RNDN); /* error <= 24 + n/2 */
86      if (MPFR_GET_EXP (s) < exps)
87        e <<= exps - MPFR_GET_EXP (s);
88      e ++; /* error in mpfr_sub */
89      f = 10 * n + 4;
90      while (expu < exps)
91        {
92          f = (1 + f) / 2;
93          expu ++;
94        }
95      e += f; /* total rounding error coming from 'u' term */
96    }
97
98  mpfr_clear (t);
99  mpfr_clear (u);
100  mpfr_clear (invxx);
101
102  f = 0;
103  while (e > 1)
104    {
105      f++;
106      e = (e + 1) / 2;
107      /* Invariant: 2^f * e does not decrease */
108    }
109  return f;
110}
111
112/* Use the reflection formula Digamma(1-x) = Digamma(x) + Pi * cot(Pi*x),
113   i.e., Digamma(x) = Digamma(1-x) - Pi * cot(Pi*x).
114   Assume x < 1/2. */
115static int
116mpfr_digamma_reflection (mpfr_ptr y, mpfr_srcptr x, mpfr_rnd_t rnd_mode)
117{
118  mpfr_prec_t p = MPFR_PREC(y) + 10;
119  mpfr_t t, u, v;
120  mpfr_exp_t e1, expv, expx, q;
121  int inex;
122  MPFR_ZIV_DECL (loop);
123
124  MPFR_LOG_FUNC
125    (("x[%Pd]=%.*Rg rnd=%d", mpfr_get_prec(x), mpfr_log_prec, x, rnd_mode),
126     ("y[%Pd]=%.*Rg inexact=%d", mpfr_get_prec(y), mpfr_log_prec, y, inex));
127
128  /* we want that 1-x is exact with precision q: if 0 < x < 1/2, then
129     q = PREC(x)-EXP(x) is ok, otherwise if -1 <= x < 0, q = PREC(x)-EXP(x)
130     is ok, otherwise for x < -1, PREC(x) is ok if EXP(x) <= PREC(x),
131     otherwise we need EXP(x) */
132  expx = MPFR_GET_EXP (x);
133  if (expx < 0)
134    q = MPFR_PREC(x) + 1 - expx;
135  else if (expx <= MPFR_PREC(x))
136    q = MPFR_PREC(x) + 1;
137  else
138    q = expx;
139  MPFR_ASSERTN (q <= MPFR_PREC_MAX);
140  mpfr_init2 (u, q);
141  MPFR_DBGRES(inex = mpfr_ui_sub (u, 1, x, MPFR_RNDN));
142  MPFR_ASSERTN(inex == 0);
143
144  /* if x is half an integer, cot(Pi*x) = 0, thus Digamma(x) = Digamma(1-x) */
145  mpfr_mul_2ui (u, u, 1, MPFR_RNDN);
146  inex = mpfr_integer_p (u);
147  mpfr_div_2ui (u, u, 1, MPFR_RNDN);
148  if (inex)
149    {
150      inex = mpfr_digamma (y, u, rnd_mode);
151      goto end;
152    }
153
154  mpfr_init2 (t, p);
155  mpfr_init2 (v, p);
156
157  MPFR_ZIV_INIT (loop, p);
158  for (;;)
159    {
160      mpfr_const_pi (v, MPFR_RNDN);  /* v = Pi*(1+theta) for |theta|<=2^(-p) */
161      mpfr_mul (t, v, x, MPFR_RNDN); /* (1+theta)^2 */
162      e1 = MPFR_GET_EXP(t) - (mpfr_exp_t) p + 1; /* bound for t: err(t) <= 2^e1 */
163      mpfr_cot (t, t, MPFR_RNDN);
164      /* cot(t * (1+h)) = cot(t) - theta * (1 + cot(t)^2) with |theta|<=t*h */
165      if (MPFR_GET_EXP(t) > 0)
166        e1 = e1 + 2 * MPFR_EXP(t) + 1;
167      else
168        e1 = e1 + 1;
169      /* now theta * (1 + cot(t)^2) <= 2^e1 */
170      e1 += (mpfr_exp_t) p - MPFR_EXP(t); /* error is now 2^e1 ulps */
171      mpfr_mul (t, t, v, MPFR_RNDN);
172      e1 ++;
173      mpfr_digamma (v, u, MPFR_RNDN);   /* error <= 1/2 ulp */
174      expv = MPFR_GET_EXP (v);
175      mpfr_sub (v, v, t, MPFR_RNDN);
176      if (MPFR_NOTZERO(v))
177        {
178          if (MPFR_GET_EXP (v) < MPFR_GET_EXP (t))
179            e1 += MPFR_EXP(t) - MPFR_EXP(v); /* scale error for t wrt new v */
180          /* now take into account the 1/2 ulp error for v */
181          if (expv - MPFR_EXP(v) - 1 > e1)
182            e1 = expv - MPFR_EXP(v) - 1;
183          else
184            e1 ++;
185          e1 ++; /* rounding error for mpfr_sub */
186          if (MPFR_CAN_ROUND (v, p - e1, MPFR_PREC(y), rnd_mode))
187            break;
188        }
189      MPFR_ZIV_NEXT (loop, p);
190      mpfr_set_prec (t, p);
191      mpfr_set_prec (v, p);
192    }
193  MPFR_ZIV_FREE (loop);
194
195  inex = mpfr_set (y, v, rnd_mode);
196
197  mpfr_clear (t);
198  mpfr_clear (v);
199 end:
200  mpfr_clear (u);
201
202  return inex;
203}
204
205/* we have x >= 1/2 here */
206static int
207mpfr_digamma_positive (mpfr_ptr y, mpfr_srcptr x, mpfr_rnd_t rnd_mode)
208{
209  mpfr_prec_t p = MPFR_PREC(y) + 10, q;
210  mpfr_t t, u, x_plus_j;
211  int inex;
212  mpfr_exp_t errt, erru, expt;
213  unsigned long j = 0, min;
214  MPFR_ZIV_DECL (loop);
215
216  MPFR_LOG_FUNC
217    (("x[%Pd]=%.*Rg rnd=%d", mpfr_get_prec(x), mpfr_log_prec, x, rnd_mode),
218     ("y[%Pd]=%.*Rg inexact=%d", mpfr_get_prec(y), mpfr_log_prec, y, inex));
219
220  /* For very large x, use |digamma(x) - log(x)| < 1/x < 2^(1-EXP(x)).
221     However, for a fixed value of GUARD, MPFR_CAN_ROUND() might fail
222     with probability 1/2^GUARD, in which case the default code will
223     fail since it requires x+1 to be exact, thus a huge precision if
224     x is huge. There are two workarounds:
225     * either perform a Ziv's loop, by increasing GUARD at each step.
226       However, this might fail if x is moderately large, in which case
227       more terms of the asymptotic expansion would be needed.
228     * implement a full asymptotic expansion (with Ziv's loop). */
229#define GUARD 30
230  if (MPFR_PREC(y) + GUARD < MPFR_EXP(x))
231    {
232      /* this ensures EXP(x) >= 3, thus x >= 4, thus log(x) > 1 */
233      mpfr_init2 (t, MPFR_PREC(y) + GUARD);
234      mpfr_log (t, x, MPFR_RNDN);
235      /* |t - digamma(x)| <= 1/2*ulp(t) + |digamma(x) - log(x)|
236                          <= 1/2*ulp(t) + 2^(1-EXP(x))
237                          <= 1/2*ulp(t) + 2^(-PREC(y)-GUARD)
238                          <= ulp(t)
239         since |t| >= 1 thus ulp(t) >= 2^(1-PREC(y)-GUARD) */
240      if (MPFR_CAN_ROUND (t, MPFR_PREC(y) + GUARD, MPFR_PREC(y), rnd_mode))
241        {
242          inex = mpfr_set (y, t, rnd_mode);
243          mpfr_clear (t);
244          return inex;
245        }
246      mpfr_clear (t);
247    }
248
249  /* compute a precision q such that x+1 is exact */
250  if (MPFR_PREC(x) < MPFR_GET_EXP(x))
251    {
252      /* The goal of the first assertion is to let the compiler ignore
253         the second one when MPFR_EMAX_MAX <= MPFR_PREC_MAX. */
254      MPFR_ASSERTD (MPFR_EXP(x) <= MPFR_EMAX_MAX);
255      MPFR_ASSERTN (MPFR_EXP(x) <= MPFR_PREC_MAX);
256      q = MPFR_EXP(x);
257    }
258  else
259    q = MPFR_PREC(x) + 1;
260
261  /* FIXME: q can be much too large, e.g. equal to the maximum exponent! */
262  MPFR_LOG_MSG (("q=%Pd\n", q));
263
264  mpfr_init2 (x_plus_j, q);
265
266  mpfr_init2 (t, p);
267  mpfr_init2 (u, p);
268  MPFR_ZIV_INIT (loop, p);
269  for(;;)
270    {
271      /* Lower bound for x+j in mpfr_digamma_approx call: since the smallest
272         term of the divergent series for Digamma(x) is about exp(-2*Pi*x), and
273         we want it to be less than 2^(-p), this gives x > p*log(2)/(2*Pi)
274         i.e., x >= 0.1103 p.
275         To be safe, we ensure x >= 0.25 * p.
276      */
277      min = (p + 3) / 4;
278      if (min < 2)
279        min = 2;
280
281      mpfr_set (x_plus_j, x, MPFR_RNDN);
282      mpfr_set_ui (u, 0, MPFR_RNDN);
283      j = 0;
284      while (mpfr_cmp_ui (x_plus_j, min) < 0)
285        {
286          j ++;
287          mpfr_ui_div (t, 1, x_plus_j, MPFR_RNDN); /* err <= 1/2 ulp */
288          mpfr_add (u, u, t, MPFR_RNDN);
289          inex = mpfr_add_ui (x_plus_j, x_plus_j, 1, MPFR_RNDZ);
290          if (inex != 0) /* we lost one bit */
291            {
292              q ++;
293              mpfr_prec_round (x_plus_j, q, MPFR_RNDZ);
294              mpfr_nextabove (x_plus_j);
295            }
296          /* since all terms are positive, the error is bounded by j ulps */
297        }
298      for (erru = 0; j > 1; erru++, j = (j + 1) / 2);
299      errt = mpfr_digamma_approx (t, x_plus_j);
300      expt = MPFR_GET_EXP (t);
301      mpfr_sub (t, t, u, MPFR_RNDN);
302      /* Warning! t may be zero (more likely in small precision). Note
303         that in this case, this is an exact zero, not an underflow. */
304      if (MPFR_NOTZERO(t))
305        {
306          if (MPFR_GET_EXP (t) < expt)
307            errt += expt - MPFR_EXP(t);
308          /* Warning: if u is zero (which happens when x_plus_j >= min at the
309             beginning of the while loop above), EXP(u) is not defined.
310             In this case we have no error from u. */
311          if (MPFR_NOTZERO(u) && MPFR_GET_EXP (t) < MPFR_GET_EXP (u))
312            erru += MPFR_EXP(u) - MPFR_EXP(t);
313          if (errt > erru)
314            errt = errt + 1;
315          else if (errt == erru)
316            errt = errt + 2;
317          else
318            errt = erru + 1;
319          if (MPFR_CAN_ROUND (t, p - errt, MPFR_PREC(y), rnd_mode))
320            break;
321        }
322      MPFR_ZIV_NEXT (loop, p);
323      mpfr_set_prec (t, p);
324      mpfr_set_prec (u, p);
325    }
326  MPFR_ZIV_FREE (loop);
327  inex = mpfr_set (y, t, rnd_mode);
328  mpfr_clear (t);
329  mpfr_clear (u);
330  mpfr_clear (x_plus_j);
331  return inex;
332}
333
334int
335mpfr_digamma (mpfr_ptr y, mpfr_srcptr x, mpfr_rnd_t rnd_mode)
336{
337  int inex;
338  MPFR_SAVE_EXPO_DECL (expo);
339
340  MPFR_LOG_FUNC
341    (("x[%Pd]=%.*Rg rnd=%d", mpfr_get_prec(x), mpfr_log_prec, x, rnd_mode),
342     ("y[%Pd]=%.*Rg inexact=%d", mpfr_get_prec(y), mpfr_log_prec, y, inex));
343
344  if (MPFR_UNLIKELY(MPFR_IS_SINGULAR(x)))
345    {
346      if (MPFR_IS_NAN(x))
347        {
348          MPFR_SET_NAN(y);
349          MPFR_RET_NAN;
350        }
351      else if (MPFR_IS_INF(x))
352        {
353          if (MPFR_IS_POS(x)) /* Digamma(+Inf) = +Inf */
354            {
355              MPFR_SET_SAME_SIGN(y, x);
356              MPFR_SET_INF(y);
357              MPFR_RET(0);
358            }
359          else                /* Digamma(-Inf) = NaN */
360            {
361              MPFR_SET_NAN(y);
362              MPFR_RET_NAN;
363            }
364        }
365      else /* Zero case */
366        {
367          /* the following works also in case of overlap */
368          MPFR_SET_INF(y);
369          MPFR_SET_OPPOSITE_SIGN(y, x);
370          MPFR_SET_DIVBY0 ();
371          MPFR_RET(0);
372        }
373    }
374
375  /* Digamma is undefined for negative integers */
376  if (MPFR_IS_NEG(x) && mpfr_integer_p (x))
377    {
378      MPFR_SET_NAN(y);
379      MPFR_RET_NAN;
380    }
381
382  /* now x is a normal number */
383
384  MPFR_SAVE_EXPO_MARK (expo);
385  /* for x very small, we have Digamma(x) = -1/x - gamma + O(x), more precisely
386     -1 < Digamma(x) + 1/x < 0 for -0.2 < x < 0.2, thus:
387     (i) either x is a power of two, then 1/x is exactly representable, and
388         as long as 1/2*ulp(1/x) > 1, we can conclude;
389     (ii) otherwise assume x has <= n bits, and y has <= n+1 bits, then
390   |y + 1/x| >= 2^(-2n) ufp(y), where ufp means unit in first place.
391   Since |Digamma(x) + 1/x| <= 1, if 2^(-2n) ufp(y) >= 2, then
392   |y - Digamma(x)| >= 2^(-2n-1)ufp(y), and rounding -1/x gives the correct result.
393   If x < 2^E, then y > 2^(-E), thus ufp(y) > 2^(-E-1).
394   A sufficient condition is thus EXP(x) <= -2 MAX(PREC(x),PREC(Y)). */
395  if (MPFR_GET_EXP (x) < -2)
396    {
397      if (MPFR_EXP(x) <= -2 * (mpfr_exp_t) MAX(MPFR_PREC(x), MPFR_PREC(y)))
398        {
399          int signx = MPFR_SIGN(x);
400          inex = mpfr_si_div (y, -1, x, rnd_mode);
401          if (inex == 0) /* x is a power of two */
402            { /* result always -1/x, except when rounding down */
403              if (rnd_mode == MPFR_RNDA)
404                rnd_mode = (signx > 0) ? MPFR_RNDD : MPFR_RNDU;
405              if (rnd_mode == MPFR_RNDZ)
406                rnd_mode = (signx > 0) ? MPFR_RNDU : MPFR_RNDD;
407              if (rnd_mode == MPFR_RNDU)
408                inex = 1;
409              else if (rnd_mode == MPFR_RNDD)
410                {
411                  mpfr_nextbelow (y);
412                  inex = -1;
413                }
414              else /* nearest */
415                inex = 1;
416            }
417          MPFR_SAVE_EXPO_UPDATE_FLAGS (expo, __gmpfr_flags);
418          goto end;
419        }
420    }
421
422  /* if x < 1/2 we use the reflection formula */
423  if (MPFR_IS_NEG(x) || MPFR_EXP(x) < 0)
424    inex = mpfr_digamma_reflection (y, x, rnd_mode);
425  else
426    inex = mpfr_digamma_positive (y, x, rnd_mode);
427
428 end:
429  MPFR_SAVE_EXPO_FREE (expo);
430  return mpfr_check_range (y, inex, rnd_mode);
431}
432