/* mpz_powm(res,base,exp,mod) -- Set RES to (BASE^EXP) mod MOD.
2
3   Contributed to the GNU project by Torbjorn Granlund.
4
5Copyright 1991, 1993, 1994, 1996, 1997, 2000, 2001, 2002, 2005, 2008, 2009
6Free Software Foundation, Inc.
7
8This file is part of the GNU MP Library.
9
10The GNU MP Library is free software; you can redistribute it and/or modify
11it under the terms of the GNU Lesser General Public License as published by
12the Free Software Foundation; either version 3 of the License, or (at your
13option) any later version.
14
15The GNU MP Library is distributed in the hope that it will be useful, but
16WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
17or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
18License for more details.
19
20You should have received a copy of the GNU Lesser General Public License
21along with the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */
22
23
24#include "gmp.h"
25#include "gmp-impl.h"
26#include "longlong.h"
27#ifdef BERKELEY_MP
28#include "mp.h"
29#endif
30
31
32/* TODO
33
34 * Improve handling of buffers.  It is pretty ugly now.
35
36 * For even moduli, we compute a binvert of its odd part both here and in
37   mpn_powm.  How can we avoid this recomputation?
38*/
39
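/*
  Result by operand, where 0 means a zero operand, b/e/m a non-zero one,
  and ? a division by zero (DIVIDE_BY_ZERO):

  b ^ e mod m   res
  0   0     0    ?
  0   e     0    ?
  0   0     m    1 mod m
  0   e     m    0
  b   0     0    ?
  b   e     0    ?
  b   0     m    1 mod m
  b   e     m    b^e mod m

  A negative e is accepted only when b is invertible modulo m; otherwise
  (in particular for b = 0) it is a division by zero.
*/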
51
52#define HANDLE_NEGATIVE_EXPONENT 1
53
54void
55#ifndef BERKELEY_MP
56mpz_powm (mpz_ptr r, mpz_srcptr b, mpz_srcptr e, mpz_srcptr m)
57#else /* BERKELEY_MP */
58pow (mpz_srcptr b, mpz_srcptr e, mpz_srcptr m, mpz_ptr r)
59#endif /* BERKELEY_MP */
60{
61  mp_size_t n, nodd, ncnt;
62  int cnt;
63  mp_ptr rp, tp;
64  mp_srcptr bp, ep, mp;
65  mp_size_t rn, bn, es, en, itch;
66  TMP_DECL;
67
68  n = ABSIZ(m);
69  if (n == 0)
70    DIVIDE_BY_ZERO;
71
72  mp = PTR(m);
73
74  TMP_MARK;
75
76  es = SIZ(e);
77  if (UNLIKELY (es <= 0))
78    {
79      mpz_t new_b;
80      if (es == 0)
81	{
82	  /* b^0 mod m,  b is anything and m is non-zero.
83	     Result is 1 mod m, i.e., 1 or 0 depending on if m = 1.  */
84	  SIZ(r) = n != 1 || mp[0] != 1;
85	  PTR(r)[0] = 1;
86	  TMP_FREE;	/* we haven't really allocated anything here */
87	  return;
88	}
89#if HANDLE_NEGATIVE_EXPONENT
90      MPZ_TMP_INIT (new_b, n + 1);
91
92      if (! mpz_invert (new_b, b, m))
93	DIVIDE_BY_ZERO;
94      b = new_b;
95      es = -es;
96#else
97      DIVIDE_BY_ZERO;
98#endif
99    }
100  en = es;
101
102  bn = ABSIZ(b);
103
104  if (UNLIKELY (bn == 0))
105    {
106      SIZ(r) = 0;
107      TMP_FREE;
108      return;
109    }
110
111  ep = PTR(e);
112
113  /* Handle (b^1 mod m) early, since mpn_pow* do not handle that case.  */
114  if (UNLIKELY (en == 1 && ep[0] == 1))
115    {
116      rp = TMP_ALLOC_LIMBS (n);
117      bp = PTR(b);
118      if (bn >= n)
119	{
120	  mp_ptr qp = TMP_ALLOC_LIMBS (bn - n + 1);
121	  mpn_tdiv_qr (qp, rp, 0L, bp, bn, mp, n);
122	  rn = n;
123	  MPN_NORMALIZE (rp, rn);
124
125	  if (SIZ(b) < 0 && rn != 0)
126	    {
127	      mpn_sub (rp, mp, n, rp, rn);
128	      rn = n;
129	      MPN_NORMALIZE (rp, rn);
130	    }
131	}
132      else
133	{
134	  if (SIZ(b) < 0)
135	    {
136	      mpn_sub (rp, mp, n, bp, bn);
137	      rn = n;
138	      rn -= (rp[rn - 1] == 0);
139	    }
140	  else
141	    {
142	      MPN_COPY (rp, bp, bn);
143	      rn = bn;
144	    }
145	}
146      goto ret;
147    }
148
149  /* Remove low zero limbs from M.  This loop will terminate for correctly
150     represented mpz numbers.  */
151  ncnt = 0;
152  while (UNLIKELY (mp[0] == 0))
153    {
154      mp++;
155      ncnt++;
156    }
157  nodd = n - ncnt;
158  cnt = 0;
159  if (mp[0] % 2 == 0)
160    {
161      mp_ptr new = TMP_ALLOC_LIMBS (nodd);
162      count_trailing_zeros (cnt, mp[0]);
163      mpn_rshift (new, mp, nodd, cnt);
164      nodd -= new[nodd - 1] == 0;
165      mp = new;
166      ncnt++;
167    }
168
169  if (ncnt != 0)
170    {
171      /* We will call both mpn_powm and mpn_powlo.  */
172      /* rp needs n, mpn_powlo needs 4n, the 2 mpn_binvert might need more */
173      mp_size_t n_largest_binvert = MAX (ncnt, nodd);
174      mp_size_t itch_binvert = mpn_binvert_itch (n_largest_binvert);
175      itch = 3 * n + MAX (itch_binvert, 2 * n);
176    }
177  else
178    {
179      /* We will call just mpn_powm.  */
180      mp_size_t itch_binvert = mpn_binvert_itch (nodd);
181      itch = n + MAX (itch_binvert, 2 * n);
182    }
183  tp = TMP_ALLOC_LIMBS (itch);
184
185  rp = tp;  tp += n;
186
187  bp = PTR(b);
188  mpn_powm (rp, bp, bn, ep, en, mp, nodd, tp);
189
190  rn = n;
191
192  if (ncnt != 0)
193    {
194      mp_ptr r2, xp, yp, odd_inv_2exp;
195      unsigned long t;
196      int bcnt;
197
198      if (bn < ncnt)
199	{
200	  mp_ptr new = TMP_ALLOC_LIMBS (ncnt);
201	  MPN_COPY (new, bp, bn);
202	  MPN_ZERO (new + bn, ncnt - bn);
203	  bp = new;
204	}
205
206      r2 = tp;
207
208      if (bp[0] % 2 == 0)
209	{
210	  if (en > 1)
211	    {
212	      MPN_ZERO (r2, ncnt);
213	      goto zero;
214	    }
215
216	  ASSERT (en == 1);
217	  t = (ncnt - (cnt != 0)) * GMP_NUMB_BITS + cnt;
218
219	  /* Count number of low zero bits in B, up to 3.  */
220	  bcnt = (0x1213 >> ((bp[0] & 7) << 1)) & 0x3;
221	  /* Note that ep[0] * bcnt might overflow, but that just results
222	     in a missed optimization.  */
223	  if (ep[0] * bcnt >= t)
224	    {
225	      MPN_ZERO (r2, ncnt);
226	      goto zero;
227	    }
228	}
229
230      mpn_powlo (r2, bp, ep, en, ncnt, tp + ncnt);
231
232    zero:
233      if (nodd < ncnt)
234	{
235	  mp_ptr new = TMP_ALLOC_LIMBS (ncnt);
236	  MPN_COPY (new, mp, nodd);
237	  MPN_ZERO (new + nodd, ncnt - nodd);
238	  mp = new;
239	}
240
241      odd_inv_2exp = tp + n;
242      mpn_binvert (odd_inv_2exp, mp, ncnt, tp + 2 * n);
243
244      mpn_sub (r2, r2, ncnt, rp, nodd > ncnt ? ncnt : nodd);
245
246      xp = tp + 2 * n;
247      mpn_mullo_n (xp, odd_inv_2exp, r2, ncnt);
248
249      if (cnt != 0)
250	xp[ncnt - 1] &= (CNST_LIMB(1) << cnt) - 1;
251
252      yp = tp;
253      if (ncnt > nodd)
254	mpn_mul (yp, xp, ncnt, mp, nodd);
255      else
256	mpn_mul (yp, mp, nodd, xp, ncnt);
257
258      mpn_add (rp, yp, n, rp, nodd);
259
260      ASSERT (nodd + ncnt >= n);
261      ASSERT (nodd + ncnt <= n + 1);
262    }
263
264  MPN_NORMALIZE (rp, rn);
265
266  if ((ep[0] & 1) && SIZ(b) < 0 && rn != 0)
267    {
268      mpn_sub (rp, PTR(m), n, rp, rn);
269      rn = n;
270      MPN_NORMALIZE (rp, rn);
271    }
272
273 ret:
274  MPZ_REALLOC (r, rn);
275  SIZ(r) = rn;
276  MPN_COPY (PTR(r), rp, rn);
277
278  TMP_FREE;
279}
280