1/* Chi-squared test for mpfr_erandom
2
3Copyright 2011-2023 Free Software Foundation, Inc.
4Contributed by Charles Karney <charles@karney.com>, SRI International.
5
6This file is part of the GNU MPFR Library.
7
8The GNU MPFR Library is free software; you can redistribute it and/or modify
9it under the terms of the GNU Lesser General Public License as published by
10the Free Software Foundation; either version 3 of the License, or (at your
11option) any later version.
12
13The GNU MPFR Library is distributed in the hope that it will be useful, but
14WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
16License for more details.
17
18You should have received a copy of the GNU Lesser General Public License
19along with the GNU MPFR Library; see the file COPYING.LESSER.  If not, see
20https://www.gnu.org/licenses/ or write to the Free Software Foundation, Inc.,
2151 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA. */
22
23#include "mpfr-test.h"
24
25/* Return Phi(x) = 1 - exp(-x), the cumulative probability function for the
26 * exponential distribution.  We only take differences of this function so the
27 * offset doesn't matter; here Phi(0) = 0. */
28static void
29exponential_cumulative (mpfr_ptr z, mpfr_ptr x, mpfr_rnd_t rnd)
30{
31  mpfr_neg (z, x, rnd);
32  mpfr_expm1 (z, z, rnd);
33  mpfr_neg (z, z, rnd);
34}
35
36/* Given nu and chisqp, compute probability that chisq > chisqp.  This uses,
37 * A&S 26.4.16,
38 *
39 * Q(nu,chisqp) =
40 *     erfc( (3/2)*sqrt(nu) * ( cbrt(chisqp/nu) - 1 + 2/(9*nu) ) ) / 2
41 *
42 * which is valid for nu > 30.  This is the basis for the formula in Knuth,
43 * TAOCP, Vol 2, 3.3.1, Table 1.  It more accurate than the similar formula,
44 * DLMF 8.11.10. */
45static void
46chisq_prob (mpfr_ptr q, long nu, mpfr_ptr chisqp)
47{
48  mpfr_t t;
49  mpfr_rnd_t rnd;
50
51  rnd = MPFR_RNDN;  /* This uses an approx formula.  Might as well use RNDN. */
52  mpfr_init2 (t, mpfr_get_prec (q));
53
54  mpfr_div_si (q, chisqp, nu, rnd); /* chisqp/nu */
55  mpfr_cbrt (q, q, rnd);            /* (chisqp/nu)^(1/3) */
56  mpfr_sub_ui (q, q, 1, rnd);       /* (chisqp/nu)^(1/3) - 1 */
57  mpfr_set_ui (t, 2, rnd);
58  mpfr_div_si (t, t, 9*nu, rnd); /* 2/(9*nu) */
59  mpfr_add (q, q, t, rnd);       /* (chisqp/nu)^(1/3) - 1 + 2/(9*nu) */
60  mpfr_sqrt_ui (t, nu, rnd);     /* sqrt(nu) */
61  mpfr_mul_d (t, t, 1.5, rnd);   /* (3/2)*sqrt(nu) */
62  mpfr_mul (q, q, t, rnd);       /* arg to erfc */
63  mpfr_erfc (q, q, rnd);         /* erfc(...) */
64  mpfr_div_ui (q, q, 2, rnd);    /* erfc(...)/2 */
65
66  mpfr_clear (t);
67}
68
69/* The continuous chi-squared test on with a set of bins of equal width.
70 *
71 * A single precision is picked for sampling and the chi-squared calculation.
72 * This should picked high enough so that binning in test doesn't need to be
73 * accurately aligned with possible values of the deviates.  Also we need the
74 * precision big enough that chi-squared calculation itself is reliable.
75 *
76 * There's no particular benefit is testing with at very higher precisions;
77 * because of the way terandom samples, this just adds additional barely
78 * significant random bits to the deviates.  So this chi-squared test with
79 * continuous equal width bins isn't a good tool for finding problems here.
80 *
81 * The testing of low precision exponential deviates is done by
82 * test_erandom_chisq_disc. */
83static double
84test_erandom_chisq_cont (long num, mpfr_prec_t prec, int nu,
85                         double xmin, double xmax, int verbose)
86{
87  mpfr_t x, a, b, dx, z, pa, pb, ps, t;
88  long *counts;
89  int i, inexact;
90  long k;
91  mpfr_rnd_t rnd, rndd;
92  double Q, chisq;
93
94  rnd = MPFR_RNDN;              /* For chi-squared calculation */
95  rndd = MPFR_RNDD;             /* For sampling and figuring the bins */
96  mpfr_inits2 (prec, x, a, b, dx, z, pa, pb, ps, t, (mpfr_ptr) 0);
97
98  counts = (long *) tests_allocate ((nu + 1) * sizeof (long));
99  for (i = 0; i <= nu; i++)
100    counts[i] = 0;
101
102  /* a and b are bounds of nu equally spaced bins.  Set dx = (b-a)/nu */
103  mpfr_set_d (a, xmin, rnd);
104  mpfr_set_d (b, xmax, rnd);
105
106  mpfr_sub (dx, b, a, rnd);
107  mpfr_div_si (dx, dx, nu, rnd);
108
109  for (k = 0; k < num; ++k)
110    {
111      inexact = mpfr_erandom (x, RANDS, rndd);
112      if (inexact == 0)
113        {
114          /* one call in the loop pretended to return an exact number! */
115          printf ("Error: mpfr_erandom() returns a zero ternary value.\n");
116          exit (1);
117        }
118      if (mpfr_signbit (x))
119        {
120          printf ("Error: mpfr_erandom() returns a negative deviate.\n");
121          exit (1);
122        }
123      mpfr_sub (x, x, a, rndd);
124      mpfr_div (x, x, dx, rndd);
125      i = mpfr_get_si (x, rndd);
126      ++counts[i >= 0 && i < nu ? i : nu];
127    }
128
129  mpfr_set (x, a, rnd);
130  exponential_cumulative (pa, x, rnd);
131  mpfr_add_ui (ps, pa, 1, rnd);
132  mpfr_set_zero (t, 1);
133  for (i = 0; i <= nu; ++i)
134    {
135      if (i < nu)
136        {
137          mpfr_add (x, x, dx, rnd);
138          exponential_cumulative (pb, x, rnd);
139          mpfr_sub (pa, pb, pa, rnd); /* prob for this bin */
140        }
141      else
142        mpfr_sub (pa, ps, pa, rnd); /* prob for last bin, i = nu */
143
144      /* Compute z = counts[i] - num * p; t += z * z / (num * p) */
145      mpfr_mul_ui (pa, pa, num, rnd);
146      mpfr_ui_sub (z, counts[i], pa, rnd);
147      mpfr_sqr (z, z, rnd);
148      mpfr_div (z, z, pa, rnd);
149      mpfr_add (t, t, z, rnd);
150      mpfr_swap (pa, pb);       /* i.e., pa = pb */
151    }
152
153  chisq = mpfr_get_d (t, rnd);
154  chisq_prob (t, nu, t);
155  Q = mpfr_get_d (t, rnd);
156  if (verbose)
157    {
158      printf ("num = %ld, equal bins in [%.2f, %.2f], nu = %d: chisq = %.2f\n",
159              num, xmin, xmax, nu, chisq);
160      if (Q < 0.05)
161        printf ("    WARNING: probability (less than 5%%) = %.2e\n", Q);
162    }
163
164  tests_free (counts, (nu + 1) * sizeof (long));
165  mpfr_clears (x, a, b, dx, z, pa, pb, ps, t, (mpfr_ptr) 0);
166  return Q;
167}
168
169/* Return a sequential number for a positive low-precision x.  x is altered by
170 * this function.  low precision means prec = 2, 3, or 4.  High values of
171 * precision will result in integer overflow. */
172static long
173sequential (mpfr_ptr x)
174{
175  long expt, prec;
176
177  prec = mpfr_get_prec (x);
178  expt =  mpfr_get_exp (x);
179  mpfr_mul_2si (x, x, prec - expt, MPFR_RNDN);
180
181  return expt * (1 << (prec - 1)) + mpfr_get_si (x, MPFR_RNDN);
182}
183
184/* The chi-squared test on low precision exponential deviates.  wprec is the
185 * working precision for the chi-squared calculation.  prec is the precision
186 * for the sampling; choose this in [2,5].  The bins consist of all the
187 * possible deviate values in the range [xmin, xmax] coupled with the value of
188 * inexact.  Thus with prec = 2, the bins are
189 *   ...
190 *   (7/16, 1/2)  x = 1/2, inexact = +1
191 *   (1/2 , 5/8)  x = 1/2, inexact = -1
192 *   (5/8 , 3/4)  x = 3/4, inexact = +1
193 *   (3/4 , 7/8)  x = 3/4, inexact = -1
194 *   (7/8 , 1  )  x = 1  , inexact = +1
195 *   (1   , 5/4)  x = 1  , inexact = -1
196 *   (5/4 , 3/2)  x = 3/2, inexact = +1
197 *   (3/2 , 7/4)  x = 3/2, inexact = -1
198 *   ...
199 * In addition, two bins are allocated for [0,xmin) and (xmax,inf).
200 *
201 * The sampling is with MPFR_RNDN.  This is the rounding mode which elicits the
202 * most information.  trandom_deviate includes checks on the consistency of the
203 * results extracted from a random_deviate with other rounding modes.  */
204static double
205test_erandom_chisq_disc (long num, mpfr_prec_t wprec, int prec,
206                         double xmin, double xmax, int verbose)
207{
208  mpfr_t x, v, pa, pb, z, t;
209  mpfr_rnd_t rnd;
210  int i, inexact, nu;
211  long *counts;
212  long k, seqmin, seqmax, seq;
213  double Q, chisq;
214
215  rnd = MPFR_RNDN;
216  mpfr_init2 (x, prec);
217  mpfr_init2 (v, prec+1);
218  mpfr_inits2 (wprec, pa, pb, z, t, (mpfr_ptr) 0);
219
220  mpfr_set_d (x, xmin, rnd);
221  xmin = mpfr_get_d (x, rnd);
222  mpfr_set (v, x, rnd);
223  seqmin = sequential (x);
224  mpfr_set_d (x, xmax, rnd);
225  xmax = mpfr_get_d (x, rnd);
226  seqmax = sequential (x);
227
228  /* Two bins for each sequential number (for inexact = +/- 1), plus 1 for u <
229   * umin and 1 for u > umax, minus 1 for degrees of freedom */
230  nu = 2 * (seqmax - seqmin + 1) + 2 - 1;
231  counts = (long *) tests_allocate ((nu + 1) * sizeof (long));
232  for (i = 0; i <= nu; i++)
233    counts[i] = 0;
234
235  for (k = 0; k < num; ++k)
236    {
237      inexact = mpfr_erandom (x, RANDS, rnd);
238      if (mpfr_signbit (x))
239        {
240          printf ("Error: mpfr_erandom() returns a negative deviate.\n");
241          exit (1);
242        }
243      /* Don't call sequential with small args to avoid undefined behavior with
244       * zero and possibility of overflow. */
245      seq = mpfr_greaterequal_p (x, v) ? sequential (x) : seqmin - 1;
246      ++counts[seq < seqmin ? 0 :
247               seq <= seqmax ? 2 * (seq - seqmin) + 1 + (inexact > 0 ? 0 : 1) :
248               nu];
249    }
250
251  mpfr_set_zero (v, 1);
252  exponential_cumulative (pa, v, rnd);
253  /* Cycle through all the bin boundaries using mpfr_nextabove at precision
254   * prec + 1 starting at mpfr_nextbelow (xmin) */
255  mpfr_set_d (x, xmin, rnd);
256  mpfr_set (v, x, rnd);
257  mpfr_nextbelow (v);
258  mpfr_nextbelow (v);
259  mpfr_set_zero (t, 1);
260  for (i = 0; i <= nu; ++i)
261    {
262      if (i < nu)
263        mpfr_nextabove (v);
264      else
265        mpfr_set_inf (v, 1);
266      exponential_cumulative (pb, v, rnd);
267      mpfr_sub (pa, pb, pa, rnd);
268
269      /* Compute z = counts[i] - num * p; t += z * z / (num * p). */
270      mpfr_mul_ui (pa, pa, num, rnd);
271      mpfr_ui_sub (z, counts[i], pa, rnd);
272      mpfr_sqr (z, z, rnd);
273      mpfr_div (z, z, pa, rnd);
274      mpfr_add (t, t, z, rnd);
275      mpfr_swap (pa, pb);       /* i.e., pa = pb */
276    }
277
278  chisq = mpfr_get_d (t, rnd);
279  chisq_prob (t, nu, t);
280  Q = mpfr_get_d (t, rnd);
281  if (verbose)
282    {
283      printf ("num = %ld, discrete (prec = %d) bins in [%.6f, %.2f], "
284              "nu = %d: chisq = %.2f\n", num, prec, xmin, xmax, nu, chisq);
285      if (Q < 0.05)
286        printf ("    WARNING: probability (less than 5%%) = %.2e\n", Q);
287    }
288
289  tests_free (counts, (nu + 1) * sizeof (long));
290  mpfr_clears (x, v, pa, pb, z, t, (mpfr_ptr) 0);
291  return Q;
292}
293
294static void
295run_chisq (double (*f)(long, mpfr_prec_t, int, double, double, int),
296           long num, mpfr_prec_t prec, int bin,
297           double xmin, double xmax, int verbose)
298{
299  double Q, Qcum, Qbad, Qthresh;
300  int i;
301
302  Qcum = 1;
303  Qbad = 1.e-9;
304  Qthresh = 0.01;
305  for (i = 0; i < 3; ++i)
306    {
307      Q = (*f)(num, prec, bin, xmin, xmax, verbose);
308      Qcum *= Q;
309      if (Q > Qthresh)
310        return;
311      else if (Q < Qbad)
312        {
313          printf ("Error: mpfr_erandom chi-squared failure "
314                  "(prob = %.2e < %.2e)\n", Q, Qbad);
315          exit (1);
316        }
317      num *= 10;
318      Qthresh /= 10;
319    }
320  if (Qcum < Qbad)              /* Presumably this is true */
321    {
322      printf ("Error: mpfr_erandom combined chi-squared failure "
323              "(prob = %.2e)\n", Qcum);
324      exit (1);
325    }
326}
327
328int
329main (int argc, char *argv[])
330{
331  long nbtests;
332  int verbose;
333
334  tests_start_mpfr ();
335
336  verbose = 0;
337  nbtests = 100000;
338  if (argc > 1)
339    {
340      long a = atol (argv[1]);
341      verbose = 1;
342      if (a != 0)
343        nbtests = a;
344    }
345
346  run_chisq (test_erandom_chisq_cont, nbtests, 64, 60, 0, 7, verbose);
347  run_chisq (test_erandom_chisq_disc, nbtests, 64, 2, 0.002, 6, verbose);
348  run_chisq (test_erandom_chisq_disc, nbtests, 64, 3, 0.02, 7, verbose);
349  run_chisq (test_erandom_chisq_disc, nbtests, 64, 4, 0.04, 8, verbose);
350
351  tests_end_mpfr ();
352  return 0;
353}
354