1/*
2 * Copyright (c) 2018 Thomas Pornin <pornin@bolet.org>
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining
5 * a copy of this software and associated documentation files (the
6 * "Software"), to deal in the Software without restriction, including
7 * without limitation the rights to use, copy, modify, merge, publish,
8 * distribute, sublicense, and/or sell copies of the Software, and to
9 * permit persons to whom the Software is furnished to do so, subject to
10 * the following conditions:
11 *
12 * The above copyright notice and this permission notice shall be
13 * included in all copies or substantial portions of the Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
19 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
20 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "inner.h"
26
/*
 * In this file, we handle big integers with a custom format, i.e.
 * without the usual one-word header. The value is split into 31-bit
 * words, each stored in a 32-bit slot (top bit is zero) in little-endian
 * order. The length (in words) is provided explicitly. In some cases,
 * the value can be negative (using two's complement representation). In
 * some cases, the top word is allowed to have a 32nd bit.
 */
35
/*
 * Conditionally negate a big integer in place.
 *
 * 'a' holds 'len' 31-bit words (little-endian, one per 32-bit slot);
 * bit 31 of each word is expected to be 0, except possibly in the last
 * word. 'ctl' must be 0 or 1:
 *   ctl = 1: a is replaced with the two's complement negation (modulo
 *            2^(31*len)) of the value formed by the low 31 bits of
 *            each word;
 *   ctl = 0: the value is kept.
 * In both cases, bit 31 of every output word is cleared.
 *
 * The negation is computed as (~a + 1) over 31-bit limbs, with the
 * complement selected by a mask derived from ctl rather than a branch.
 */
static void
cond_negate(uint32_t *a, size_t len, uint32_t ctl)
{
	uint32_t flip, carry;

	/* flip is 0x7FFFFFFF when ctl = 1, and 0 when ctl = 0. */
	flip = -ctl >> 1;

	/* The "+1" of the two's complement enters as the initial carry. */
	carry = ctl;
	while (len -- > 0) {
		uint32_t w;

		w = (*a ^ flip) + carry;
		carry = w >> 31;
		*a = w & 0x7FFFFFFF;
		a ++;
	}
}
59
/*
 * Finish a modular reduction in place, bringing a into the range [0, m).
 *
 * Requirements on inputs ('len' 31-bit words each):
 *   if neg = 1, then -m <= a < 0   (a in two's complement)
 *   if neg = 0, then 0 <= a < 2*m
 * When neg = 0, the top word of a[] may use a 32nd bit.
 * The modulus m must be odd.
 *
 * Outcome: if neg = 1, m is added; if neg = 0 and a >= m, m is
 * subtracted; otherwise a keeps its value. Both passes always walk
 * every word, and the choice is applied through masks, not branches.
 */
static void
finish_mod(uint32_t *a, size_t len, const uint32_t *m, uint32_t neg)
{
	size_t i;
	uint32_t borrow, flip, keep, carry;

	/*
	 * Pass 1: run the borrow chain of a - m without storing the
	 * difference; a final borrow of 1 means a < m (a being read as
	 * nonnegative). If the top word carries a 32nd bit, then a - m
	 * is below 2^(31*len) because a < 2*m, so no spurious borrow
	 * shows up in bit 31 of the last word.
	 */
	borrow = 0;
	for (i = 0; i < len; i ++) {
		borrow = (a[i] - m[i] - borrow) >> 31;
	}

	/*
	 * Pass 2: apply the correction selected by (neg, borrow):
	 *   neg = 1              -> add m       (flip = 0x7FFFFFFF, keep = ~0)
	 *   neg = 0, borrow = 0  -> subtract m  (flip = 0,          keep = ~0)
	 *   neg = 0, borrow = 1  -> nothing     (keep = 0)
	 * Adding m is performed as subtracting its 31-bit one's complement
	 * with an initial borrow of 1: a - ~m - 1 = a + m mod 2^(31*len).
	 */
	flip = -neg >> 1;
	keep = -(neg | (borrow ^ 1));
	carry = neg;
	for (i = 0; i < len; i ++) {
		uint32_t d;

		d = a[i] - ((m[i] ^ flip) & keep) - carry;
		a[i] = d & 0x7FFFFFFF;
		carry = d >> 31;
	}
}
110
/*
 * Compute, in place:
 *   a <- (a*pa+b*pb)/(2^31)
 *   b <- (a*qa+b*qb)/(2^31)
 * (both right-hand sides use the original values of a and b).
 * The division is assumed to be exact (i.e. the low word is dropped).
 * If the final a is negative, then it is negated. Similarly for b.
 * Returned value is the combination of two bits:
 *   bit 0: 1 if a had to be negated, 0 otherwise
 *   bit 1: 1 if b had to be negated, 0 otherwise
 *
 * Factors pa, pb, qa and qb must be at most 2^31 in absolute value.
 * Source integers a and b must be nonnegative; top word is not allowed
 * to contain an extra 32nd bit. 'len' must be at least 1, since
 * a[len - 1] and b[len - 1] are written unconditionally.
 *
 * The one-word shift (division by 2^31) is performed in place: word k
 * of each input is read before word k - 1 of the output is written,
 * so the order of the statements in the loop matters.
 */
static uint32_t
co_reduce(uint32_t *a, uint32_t *b, size_t len,
	int64_t pa, int64_t pb, int64_t qa, int64_t qb)
{
	size_t k;
	int64_t cca, ccb;
	uint32_t nega, negb;

	/* Signed carries propagated from one word to the next. */
	cca = 0;
	ccb = 0;
	for (k = 0; k < len; k ++) {
		uint32_t wa, wb;
		uint64_t za, zb;
		uint64_t tta, ttb;

		/*
		 * Since:
		 *   |pa| <= 2^31
		 *   |pb| <= 2^31
		 *   0 <= wa <= 2^31 - 1
		 *   0 <= wb <= 2^31 - 1
		 *   |cca| <= 2^32 - 1
		 * Then:
		 *   |za| <= (2^31-1)*(2^32) + (2^32-1) = 2^63 - 1
		 *
		 * Thus, the new value of cca is such that |cca| <= 2^32 - 1.
		 * The same applies to ccb.
		 *
		 * The products are computed in uint64_t (wrapping
		 * arithmetic), which yields the two's complement encoding
		 * of the signed result without signed overflow.
		 */
		wa = a[k];
		wb = b[k];
		za = wa * (uint64_t)pa + wb * (uint64_t)pb + (uint64_t)cca;
		zb = wa * (uint64_t)qa + wb * (uint64_t)qb + (uint64_t)ccb;

		/*
		 * The low 31 bits of the k = 0 sums are dropped (the exact
		 * division by 2^31); each later word lands one slot lower.
		 */
		if (k > 0) {
			a[k - 1] = za & 0x7FFFFFFF;
			b[k - 1] = zb & 0x7FFFFFFF;
		}

		/*
		 * For the new values of cca and ccb, we need a signed
		 * right-shift; since, in C, right-shifting a signed
		 * negative value is implementation-defined, we use a
		 * custom portable sign extension expression.
		 *
		 * za >> 31 keeps 33 significant bits; XOR-ing and then
		 * subtracting M = 2^32 sign-extends from bit 32. Reading
		 * the resulting uint64_t through an int64_t lvalue is
		 * permitted by C's aliasing rules (signed/unsigned variants
		 * of the same type), and int64_t is two's complement.
		 */
#define M   ((uint64_t)1 << 32)
		tta = za >> 31;
		ttb = zb >> 31;
		tta = (tta ^ M) - M;
		ttb = (ttb ^ M) - M;
		cca = *(int64_t *)&tta;
		ccb = *(int64_t *)&ttb;
#undef M
	}

	/* The last carry (sign included) becomes the top word. */
	a[len - 1] = (uint32_t)cca;
	b[len - 1] = (uint32_t)ccb;

	/*
	 * The sign bit of each final carry tells whether the result is
	 * negative; cond_negate() then makes it nonnegative and clears
	 * bit 31 of every word.
	 */
	nega = (uint32_t)((uint64_t)cca >> 63);
	negb = (uint32_t)((uint64_t)ccb >> 63);
	cond_negate(a, len, nega);
	cond_negate(b, len, negb);
	return nega | (negb << 1);
}
186
/*
 * Compute, in place:
 *   a <- (a*pa+b*pb)/(2^31) mod m
 *   b <- (a*qa+b*qb)/(2^31) mod m
 * (both right-hand sides use the original values of a and b).
 *
 * m0i is equal to -1/m[0] mod 2^31. Here m is header-less: it consists
 * of 'len' 31-bit words. finish_mod() requires m to be odd.
 *
 * Factors pa, pb, qa and qb must be at most 2^31 in absolute value.
 * Source integers a and b must be nonnegative; top word is not allowed
 * to contain an extra 32nd bit. 'len' must be at least 1, since
 * a[len - 1] and b[len - 1] are written unconditionally.
 *
 * The division by 2^31 is made exact with a Montgomery-style step: a
 * multiple fa*m (resp. fb*m) is added so that the low 31 bits of the
 * sum vanish, which amounts to multiplying by 2^(-31) modulo m.
 */
static void
co_reduce_mod(uint32_t *a, uint32_t *b, size_t len,
	int64_t pa, int64_t pb, int64_t qa, int64_t qb,
	const uint32_t *m, uint32_t m0i)
{
	size_t k;
	int64_t cca, ccb;
	uint32_t fa, fb;

	cca = 0;
	ccb = 0;

	/*
	 * Since m0i = -1/m[0] mod 2^31, the values below make
	 * a[0]*pa + b[0]*pb + m[0]*fa (and likewise for b with qa, qb
	 * and fb) a multiple of 2^31.
	 */
	fa = ((a[0] * (uint32_t)pa + b[0] * (uint32_t)pb) * m0i) & 0x7FFFFFFF;
	fb = ((a[0] * (uint32_t)qa + b[0] * (uint32_t)qb) * m0i) & 0x7FFFFFFF;
	for (k = 0; k < len; k ++) {
		uint32_t wa, wb;
		uint64_t za, zb;
		uint64_t tta, ttb;

		/*
		 * In this loop, carries 'cca' and 'ccb' always fit on
		 * 33 bits (in absolute value).
		 */
		wa = a[k];
		wb = b[k];
		za = wa * (uint64_t)pa + wb * (uint64_t)pb
			+ m[k] * (uint64_t)fa + (uint64_t)cca;
		zb = wa * (uint64_t)qa + wb * (uint64_t)qb
			+ m[k] * (uint64_t)fb + (uint64_t)ccb;

		/*
		 * In-place shift by one word: word k is read above before
		 * word k - 1 is written here. The (zero) low 31 bits of the
		 * k = 0 sums are dropped.
		 */
		if (k > 0) {
			a[k - 1] = (uint32_t)za & 0x7FFFFFFF;
			b[k - 1] = (uint32_t)zb & 0x7FFFFFFF;
		}

		/*
		 * Portable arithmetic right shift by 31: sign-extend the
		 * 33-bit value za >> 31 from bit 32 (same technique as in
		 * co_reduce()).
		 */
#define M   ((uint64_t)1 << 32)
		tta = za >> 31;
		ttb = zb >> 31;
		tta = (tta ^ M) - M;
		ttb = (ttb ^ M) - M;
		cca = *(int64_t *)&tta;
		ccb = *(int64_t *)&ttb;
#undef M
	}
	a[len - 1] = (uint32_t)cca;
	b[len - 1] = (uint32_t)ccb;

	/*
	 * At this point:
	 *   -m <= a < 2*m
	 *   -m <= b < 2*m
	 * (this is a case of Montgomery reduction)
	 * The top word of 'a' and 'b' may have a 32nd bit set.
	 * We may have to add or subtract the modulus; the sign bit of
	 * the final carry selects between the two cases of finish_mod().
	 */
	finish_mod(a, len, m, (uint32_t)((uint64_t)cca >> 63));
	finish_mod(b, len, m, (uint32_t)((uint64_t)ccb >> 63));
}
254
/* see inner.h */
/*
 * Contract summary, as visible from the code below:
 *
 *   x     in/out. Words x[1..len] hold the dividend on input and are
 *         overwritten with the result (they are used in place as 'u').
 *         x[0] is neither read nor written.
 *   y     divisor. Words y[1..len] are copied into the scratch area.
 *         (Presumably y must already be reduced modulo m -- TODO
 *         confirm against inner.h.)
 *   m     odd modulus: m[0] is its header (encoded bit length) and
 *         m[1..len] its 31-bit words, with len = (m[0] + 31) >> 5.
 *   m0i   -1/m[1] mod 2^31 (forwarded to co_reduce_mod(), which works
 *         on the header-less modulus m + 1).
 *   t     scratch area of at least 3*len words (holds a, b and v).
 *
 * Returns EQ0(r), where r is 0 exactly when the final GCD is 1.
 * Assuming EQ0() follows the inner.h convention (1 when its argument
 * is 0), this is 1 when the division succeeded and 0 when y is not
 * invertible modulo m. Words x[1..len] are overwritten in both cases.
 */
uint32_t
br_i31_moddiv(uint32_t *x, const uint32_t *y, const uint32_t *m, uint32_t m0i,
	uint32_t *t)
{
	/*
	 * Algorithm is an extended binary GCD. We maintain four values
	 * a, b, u and v, with the following invariants:
	 *
	 *   a * x = y * u mod m
	 *   b * x = y * v mod m
	 *
	 * Starting values are:
	 *
	 *   a = y
	 *   b = m
	 *   u = x
	 *   v = 0
	 *
	 * The formal definition of the algorithm is a sequence of steps:
	 *
	 *   - If a is even, then a <- a/2 and u <- u/2 mod m.
	 *   - Otherwise, if b is even, then b <- b/2 and v <- v/2 mod m.
	 *   - Otherwise, if a > b, then a <- (a-b)/2 and u <- (u-v)/2 mod m.
	 *   - Otherwise, b <- (b-a)/2 and v <- (v-u)/2 mod m.
	 *
	 * Algorithm stops when a = b. At that point, they both are equal
	 * to GCD(y,m); the modular division succeeds if that value is 1.
	 * The result of the modular division is then u (or v: both are
	 * equal at that point).
	 *
	 * Each step makes either a or b shrink by at least one bit; hence,
	 * if m has bit length k bits, then 2k-2 steps are sufficient.
	 *
	 *
	 * Though complexity is quadratic in the size of m, the bit-by-bit
	 * processing is not very efficient. We can speed up processing by
	 * remarking that the decisions are taken based only on observation
	 * of the top and low bits of a and b.
	 *
	 * In the loop below, at each iteration, we use the two top words
	 * of a and b, and the low words of a and b, to compute reduction
	 * parameters pa, pb, qa and qb such that the new values for a
	 * and b are:
	 *
	 *   a' = (a*pa + b*pb) / (2^31)
	 *   b' = (a*qa + b*qb) / (2^31)
	 *
	 * the division being exact.
	 *
	 * Since the choices are based on the top words, they may be slightly
	 * off, requiring an optional correction: if a' < 0, then we replace
	 * pa with -pa, and pb with -pb. The total length of a and b is
	 * thus reduced by at least 30 bits at each iteration.
	 *
	 * The stopping conditions are still the same, though: when a
	 * and b become equal, they must be both odd (since m is odd,
	 * the GCD cannot be even), therefore the next operation is a
	 * subtraction, and one of the values becomes 0. At that point,
	 * nothing else happens, i.e. one value is stuck at 0, and the
	 * other one is the GCD.
	 */
	size_t len, k;
	uint32_t *a, *b, *u, *v;
	uint32_t num, r;

	/*
	 * Scratch layout: t = [ a | b | v ], len words each. u aliases
	 * the value words of x, so the result is produced in place.
	 */
	len = (m[0] + 31) >> 5;
	a = t;
	b = a + len;
	u = x + 1;
	v = b + len;
	memcpy(a, y + 1, len * sizeof *y);
	memcpy(b, m + 1, len * sizeof *m);
	memset(v, 0, len * sizeof *v);

	/*
	 * Loop below ensures that a and b are reduced by some bits each,
	 * for a total of at least 30 bits.
	 *
	 * m[0] - (m[0] >> 5) is presumably the true bit length k of m
	 * (i31 header encoding -- see inner.h); the budget of 2k + 30
	 * bits, consumed 30 at a time, gives enough iterations to cover
	 * the 2k bits of a and b. The iteration count depends only on
	 * m[0], not on the values of x or y.
	 */
	for (num = ((m[0] - (m[0] >> 5)) << 1) + 30; num >= 30; num -= 30) {
		size_t j;
		uint32_t c0, c1;
		uint32_t a0, a1, b0, b1;
		uint64_t a_hi, b_hi;
		uint32_t a_lo, b_lo;
		int64_t pa, pb, qa, qb;
		int i;

		/*
		 * Extract top words of a and b. If j is the highest
		 * index >= 1 such that a[j] != 0 or b[j] != 0, then we want
		 * (a[j] << 31) + a[j - 1], and (b[j] << 31) + b[j - 1].
		 * If a and b are down to one word each, then we use a[0]
		 * and b[0].
		 *
		 * c0 stays all-ones until the first nonzero word (scanning
		 * from the top) has been seen; c1 lags one word behind c0.
		 * Every word is visited, and selection is done with masks.
		 */
		c0 = (uint32_t)-1;
		c1 = (uint32_t)-1;
		a0 = 0;
		a1 = 0;
		b0 = 0;
		b1 = 0;
		j = len;
		while (j -- > 0) {
			uint32_t aw, bw;

			aw = a[j];
			bw = b[j];
			a0 ^= (a0 ^ aw) & c0;
			a1 ^= (a1 ^ aw) & c1;
			b0 ^= (b0 ^ bw) & c0;
			b1 ^= (b1 ^ bw) & c1;
			c1 = c0;
			c0 &= (((aw | bw) + 0x7FFFFFFF) >> 31) - (uint32_t)1;
		}

		/*
		 * If c1 = 0, then we grabbed two words for a and b.
		 * If c1 != 0 but c0 = 0, then we grabbed one word. It
		 * is not possible that c1 != 0 and c0 != 0, because that
		 * would mean that both integers are zero.
		 */
		a1 |= a0 & c1;
		a0 &= ~c1;
		b1 |= b0 & c1;
		b0 &= ~c1;
		a_hi = ((uint64_t)a0 << 31) + a1;
		b_hi = ((uint64_t)b0 << 31) + b1;
		a_lo = a[0];
		b_lo = b[0];

		/*
		 * Compute reduction factors:
		 *
		 *   a' = a*pa + b*pb
		 *   b' = a*qa + b*qb
		 *
		 * such that a' and b' are both multiples of 2^31, but are
		 * only marginally larger than a and b.
		 */
		pa = 1;
		pb = 0;
		qa = 0;
		qb = 1;
		for (i = 0; i < 31; i ++) {
			/*
			 * At each iteration:
			 *
			 *   a <- (a-b)/2 if: a is odd, b is odd, a_hi > b_hi
			 *   b <- (b-a)/2 if: a is odd, b is odd, a_hi <= b_hi
			 *   a <- a/2 if: a is even
			 *   b <- b/2 if: a is odd, b is even
			 *
			 * We multiply a_lo and b_lo by 2 at each
			 * iteration, thus a division by 2 really is a
			 * non-multiplication by 2.
			 */
			uint32_t gt, oa, ob, cAB, cBA, cA;
			uint64_t rz;

			/*
			 * gt = GT(a_hi, b_hi)
			 * But the GT() function works on uint32_t operands,
			 * so we inline a 64-bit version here: gt is the
			 * borrow (sign bit) of b_hi - a_hi, corrected for
			 * wrap-around. Named 'gt' so as not to shadow the
			 * function-scope 'r'.
			 */
			rz = b_hi - a_hi;
			gt = (uint32_t)((rz ^ ((a_hi ^ b_hi)
				& (a_hi ^ rz))) >> 63);

			/*
			 * cAB = 1 if b must be subtracted from a
			 * cBA = 1 if a must be subtracted from b
			 * cA = 1 if a is divided by 2, 0 otherwise
			 *
			 * Rules:
			 *
			 *   cAB and cBA cannot be both 1.
			 *   if a is not divided by 2, b is.
			 */
			oa = (a_lo >> i) & 1;
			ob = (b_lo >> i) & 1;
			cAB = oa & ob & gt;
			cBA = oa & ob & NOT(gt);
			cA = cAB | NOT(oa);

			/*
			 * Conditional subtractions.
			 */
			a_lo -= b_lo & -cAB;
			a_hi -= b_hi & -(uint64_t)cAB;
			pa -= qa & -(int64_t)cAB;
			pb -= qb & -(int64_t)cAB;
			b_lo -= a_lo & -cBA;
			b_hi -= a_hi & -(uint64_t)cBA;
			qa -= pa & -(int64_t)cBA;
			qb -= pb & -(int64_t)cBA;

			/*
			 * Shifting: the value that is "divided by 2" keeps
			 * its low-side factors and has its high part
			 * shifted right; the other one has its low-side
			 * factors doubled.
			 */
			a_lo += a_lo & (cA - 1);
			pa += pa & ((int64_t)cA - 1);
			pb += pb & ((int64_t)cA - 1);
			a_hi ^= (a_hi ^ (a_hi >> 1)) & -(uint64_t)cA;
			b_lo += b_lo & -cA;
			qa += qa & -(int64_t)cA;
			qb += qb & -(int64_t)cA;
			b_hi ^= (b_hi ^ (b_hi >> 1)) & ((uint64_t)cA - 1);
		}

		/*
		 * Replace a and b with new values a' and b'. If co_reduce()
		 * had to negate a (bit 0 of r) or b (bit 1 of r), the
		 * corresponding factors are negated too, so that u and v
		 * follow the same transformation (invariants preserved).
		 */
		r = co_reduce(a, b, len, pa, pb, qa, qb);
		pa -= pa * ((r & 1) << 1);
		pb -= pb * ((r & 1) << 1);
		qa -= qa * (r & 2);
		qb -= qb * (r & 2);
		co_reduce_mod(u, v, len, pa, pb, qa, qb, m + 1, m0i);
	}

	/*
	 * Now one of the arrays should be 0, and the other contains
	 * the GCD. If a is 0, then u is 0 as well, and v contains
	 * the division result. OR-ing v into u therefore leaves the
	 * result in u, i.e. in x[1..len].
	 * Result is correct if and only if GCD is 1: r is built so that
	 * it is 0 exactly when (a | b) equals 1.
	 */
	r = (a[0] | b[0]) ^ 1;
	u[0] |= v[0];
	for (k = 1; k < len; k ++) {
		r |= a[k] | b[k];
		u[k] |= v[k];
	}
	return EQ0(r);
}
489