1/* Tests matrix22_mul.
2
3Copyright 2008 Free Software Foundation, Inc.
4
5This file is part of the GNU MP Library test suite.
6
7The GNU MP Library test suite is free software; you can redistribute it
8and/or modify it under the terms of the GNU General Public License as
9published by the Free Software Foundation; either version 3 of the License,
10or (at your option) any later version.
11
12The GNU MP Library test suite is distributed in the hope that it will be
13useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
14MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15Public License for more details.
16
17You should have received a copy of the GNU General Public License along with
18the GNU MP Library test suite.  If not, see https://www.gnu.org/licenses/.  */
19
20#include <stdio.h>
21#include <stdlib.h>
22
23#include "gmp-impl.h"
24#include "tests.h"
25
26struct matrix {
27  mp_size_t alloc;
28  mp_size_t n;
29  mp_ptr e00, e01, e10, e11;
30};
31
32static void
33matrix_init (struct matrix *M, mp_size_t n)
34{
35  mp_ptr p = refmpn_malloc_limbs (4*(n+1));
36  M->e00 = p; p += n+1;
37  M->e01 = p; p += n+1;
38  M->e10 = p; p += n+1;
39  M->e11 = p;
40  M->alloc = n + 1;
41  M->n = 0;
42}
43
44static void
45matrix_clear (struct matrix *M)
46{
47  refmpn_free_limbs (M->e00);
48}
49
50static void
51matrix_copy (struct matrix *R, const struct matrix *M)
52{
53  R->n = M->n;
54  MPN_COPY (R->e00, M->e00, M->n);
55  MPN_COPY (R->e01, M->e01, M->n);
56  MPN_COPY (R->e10, M->e10, M->n);
57  MPN_COPY (R->e11, M->e11, M->n);
58}
59
60/* Used with same size, so no need for normalization. */
61static int
62matrix_equal_p (const struct matrix *A, const struct matrix *B)
63{
64  return (A->n == B->n
65	  && mpn_cmp (A->e00, B->e00, A->n) == 0
66	  && mpn_cmp (A->e01, B->e01, A->n) == 0
67	  && mpn_cmp (A->e10, B->e10, A->n) == 0
68	  && mpn_cmp (A->e11, B->e11, A->n) == 0);
69}
70
71static void
72matrix_random(struct matrix *M, mp_size_t n, gmp_randstate_ptr rands)
73{
74  M->n = n;
75  mpn_random (M->e00, n);
76  mpn_random (M->e01, n);
77  mpn_random (M->e10, n);
78  mpn_random (M->e11, n);
79}
80
/* Multiply {ap,an} by {bp,bn} into rp, passing the longer operand first as
   mpn_mul requires.  Every argument is parenthesized so that expression
   arguments (e.g. "k & 3") cannot be re-associated by the comparison.  */
#define MUL(rp, ap, an, bp, bn) do {			\
    if ((an) >= (bn))					\
      mpn_mul ((rp), (ap), (an), (bp), (bn));		\
    else						\
      mpn_mul ((rp), (bp), (bn), (ap), (an));		\
  } while (0)
87
88static void
89ref_matrix22_mul (struct matrix *R,
90		  const struct matrix *A,
91		  const struct matrix *B, mp_ptr tp)
92{
93  mp_size_t an, bn, n;
94  mp_ptr r00, r01, r10, r11, a00, a01, a10, a11, b00, b01, b10, b11;
95
96  if (A->n >= B->n)
97    {
98      r00 = R->e00; a00 = A->e00; b00 = B->e00;
99      r01 = R->e01; a01 = A->e01; b01 = B->e01;
100      r10 = R->e10; a10 = A->e10; b10 = B->e10;
101      r11 = R->e11; a11 = A->e11; b11 = B->e11;
102      an = A->n, bn = B->n;
103    }
104  else
105    {
106      /* Transpose */
107      r00 = R->e00; a00 = B->e00; b00 = A->e00;
108      r01 = R->e10; a01 = B->e10; b01 = A->e10;
109      r10 = R->e01; a10 = B->e01; b10 = A->e01;
110      r11 = R->e11; a11 = B->e11; b11 = A->e11;
111      an = B->n, bn = A->n;
112    }
113  n = an + bn;
114  R->n = n + 1;
115
116  mpn_mul (r00, a00, an, b00, bn);
117  mpn_mul (tp, a01, an, b10, bn);
118  r00[n] = mpn_add_n (r00, r00, tp, n);
119
120  mpn_mul (r01, a00, an, b01, bn);
121  mpn_mul (tp, a01, an, b11, bn);
122  r01[n] = mpn_add_n (r01, r01, tp, n);
123
124  mpn_mul (r10, a10, an, b00, bn);
125  mpn_mul (tp, a11, an, b10, bn);
126  r10[n] = mpn_add_n (r10, r10, tp, n);
127
128  mpn_mul (r11, a10, an, b01, bn);
129  mpn_mul (tp, a11, an, b11, bn);
130  r11[n] = mpn_add_n (r11, r11, tp, n);
131}
132
133static void
134one_test (const struct matrix *A, const struct matrix *B, int i)
135{
136  struct matrix R;
137  struct matrix P;
138  mp_ptr tp;
139
140  matrix_init (&R, A->n + B->n + 1);
141  matrix_init (&P, A->n + B->n + 1);
142
143  tp = refmpn_malloc_limbs (mpn_matrix22_mul_itch (A->n, B->n));
144
145  ref_matrix22_mul (&R, A, B, tp);
146  matrix_copy (&P, A);
147  mpn_matrix22_mul (P.e00, P.e01, P.e10, P.e11, A->n,
148		    B->e00, B->e01, B->e10, B->e11, B->n, tp);
149  P.n = A->n + B->n + 1;
150  if (!matrix_equal_p (&R, &P))
151    {
152      fprintf (stderr, "ERROR in test %d\n", i);
153      gmp_fprintf (stderr, "A = (%Nx, %Nx\n      %Nx, %Nx)\n"
154		   "B = (%Nx, %Nx\n      %Nx, %Nx)\n"
155		   "R = (%Nx, %Nx (expected)\n      %Nx, %Nx)\n"
156		   "P = (%Nx, %Nx (incorrect)\n      %Nx, %Nx)\n",
157		   A->e00, A->n, A->e01, A->n, A->e10, A->n, A->e11, A->n,
158		   B->e00, B->n, B->e01, B->n, B->e10, B->n, B->e11, B->n,
159		   R.e00, R.n, R.e01, R.n, R.e10, R.n, R.e11, R.n,
160		   P.e00, P.n, P.e01, P.n, P.e10, P.n, P.e11, P.n);
161      abort();
162    }
163  refmpn_free_limbs (tp);
164  matrix_clear (&R);
165  matrix_clear (&P);
166}
167
/* Largest operand size (in limbs) exercised.  It is about twice the Strassen
   threshold, presumably so that both code paths of mpn_matrix22_mul are
   reached.  */
#define MAX_SIZE (2 * MATRIX22_STRASSEN_THRESHOLD + 2)
169
170int
171main (int argc, char **argv)
172{
173  struct matrix A;
174  struct matrix B;
175
176  gmp_randstate_ptr rands;
177  mpz_t bs;
178  int i;
179
180  tests_start ();
181  rands = RANDS;
182
183  matrix_init (&A, MAX_SIZE);
184  matrix_init (&B, MAX_SIZE);
185  mpz_init (bs);
186
187  for (i = 0; i < 1000; i++)
188    {
189      mp_size_t an, bn;
190      mpz_urandomb (bs, rands, 32);
191      an = 1 + mpz_get_ui (bs) % MAX_SIZE;
192      mpz_urandomb (bs, rands, 32);
193      bn = 1 + mpz_get_ui (bs) % MAX_SIZE;
194
195      matrix_random (&A, an, rands);
196      matrix_random (&B, bn, rands);
197
198      one_test (&A, &B, i);
199    }
200  mpz_clear (bs);
201  matrix_clear (&A);
202  matrix_clear (&B);
203
204  tests_end ();
205  return 0;
206}
207