1/*
2 * Copyright (c) 2001 Dima Dorfman.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24 * SUCH DAMAGE.
25 */
26
27/*
28 * This is the traditional Berkeley MP library implemented in terms of
29 * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
30 * is meant to be as compatible with the latter as feasible.
31 *
32 * There seems to be a lack of documentation for the Berkeley MP
33 * interface.  All I could find was libgmp documentation (which didn't
34 * talk about the semantics of the functions) and an old SunOS 4.1
35 * manual page from 1989.  The latter wasn't very detailed, either,
36 * but at least described what the function's arguments were.  In
37 * general the interface seems to be archaic, somewhat poorly
38 * designed, and poorly, if at all, documented.  It is considered
39 * harmful.
40 *
41 * Miscellaneous notes on this implementation:
42 *
43 *  - The SunOS manual page mentioned above indicates that if an error
44 *  occurs, the library should "produce messages and core images."
45 *  Given that most of the functions don't have return values (and
46 *  thus no sane way of alerting the caller to an error), this seems
47 *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
48 *  respectively, then abort().
49 *
50 *  - All the functions which take an argument to be "filled in"
51 *  assume that the argument has been initialized by one of the *tom()
52 *  routines before being passed to it.  I never saw this documented
53 *  anywhere, but this seems to be consistent with the way this
54 *  library is used.
55 *
56 *  - msqrt() is the only routine which had to be implemented which
57 *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
58 *  It was implemented by hand using Newton's recursive formula.
59 *  Doing it this way, although more error-prone, has the positive
 *  side effect of testing a lot of other functions; if msqrt()
61 *  produces the correct results, most of the other routines will as
62 *  well.
63 *
64 *  - Internal-use-only routines (i.e., those defined here statically
65 *  and not in mp.h) have an underscore prepended to their name (this
 *  is more for aesthetic reasons than technical).  All such
67 *  routines take an extra argument, 'msg', that denotes what they
68 *  should call themselves in an error message.  This is so a user
69 *  doesn't get an error message from a function they didn't call.
70 */
71
72#include <sys/cdefs.h>
73__FBSDID("$FreeBSD$");
74
75#include <ctype.h>
76#include <err.h>
77#include <errno.h>
78#include <stdio.h>
79#include <stdlib.h>
80#include <string.h>
81
82#include <openssl/crypto.h>
83#include <openssl/err.h>
84
85#include "mp.h"
86
/*
 * Fatal-error helpers.  The argument s is a complete, parenthesized
 * argument list for warn(3)/warnx(3), e.g. MPERR(("%s", msg)); the
 * extra parentheses let a single macro parameter carry a variable
 * number of arguments.  MPERR uses warn (which appends the errno
 * message), MPERRX uses warnx (which does not); both then abort().
 */
#define MPERR(s)	do { warn s; abort(); } while (0)
#define MPERRX(s)	do { warnx s; abort(); } while (0)
/*
 * Evaluate a BN_* call exactly once; if it yields 0 or NULL (the
 * OpenSSL failure convention), report the queued OpenSSL error under
 * the name msg via _bnerr(), which does not return.
 */
#define BN_ERRCHECK(msg, expr) do {		\
	if (!(expr)) _bnerr(msg);		\
} while (0)

/*
 * Internal helpers.  Each takes, as its first argument, the name it
 * should use in error messages (see the notes at the top of the file).
 */
static void _bnerr(const char *);
static MINT *_dtom(const char *, const char *);
static MINT *_itom(const char *, short);
static void _madd(const char *, const MINT *, const MINT *, MINT *);
static int _mcmpa(const char *, const MINT *, const MINT *);
static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
		BN_CTX *);
static void _mfree(const char *, MINT *);
static void _moveb(const char *, const BIGNUM *, MINT *);
static void _movem(const char *, const MINT *, MINT *);
static void _msub(const char *, const MINT *, const MINT *, MINT *);
static char *_mtod(const char *, const MINT *);
static char *_mtox(const char *, const MINT *);
static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
static MINT *_xtom(const char *, const char *);
109
/*
 * Report an error from one of the BN_* functions using MPERRX, then
 * abort.  msg names the public routine the user called.
 *
 * ERR_get_error() returns 0 when nothing was queued (some BN_*
 * failures, such as BN_dec2bn()/BN_hex2bn() rejecting their input, may
 * not queue an error).  In that case ERR_reason_error_string() returns
 * NULL, and NULL must not be passed to a "%s" conversion.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
120
121/*
122 * Convert a decimal string to an MINT.
123 */
124static MINT *
125_dtom(const char *msg, const char *s)
126{
127	MINT *mp;
128
129	mp = malloc(sizeof(*mp));
130	if (mp == NULL)
131		MPERR(("%s", msg));
132	mp->bn = BN_new();
133	if (mp->bn == NULL)
134		_bnerr(msg);
135	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
136	return (mp);
137}
138
139/*
140 * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
141 */
142void
143mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
144{
145	BIGNUM b;
146	BN_CTX *c;
147
148	c = BN_CTX_new();
149	if (c == NULL)
150		_bnerr("gcd");
151	BN_init(&b);
152	BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, c));
153	_moveb("gcd", &b, rmp);
154	BN_free(&b);
155	BN_CTX_free(c);
156}
157
158/*
159 * Make an MINT out of a short integer.  Return value must be mfree()'d.
160 */
161static MINT *
162_itom(const char *msg, short n)
163{
164	MINT *mp;
165	char *s;
166
167	asprintf(&s, "%x", n);
168	if (s == NULL)
169		MPERR(("%s", msg));
170	mp = _xtom(msg, s);
171	free(s);
172	return (mp);
173}
174
175MINT *
176mp_itom(short n)
177{
178
179	return (_itom("itom", n));
180}
181
182/*
183 * Compute rmp=mp1+mp2.
184 */
185static void
186_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
187{
188	BIGNUM b;
189
190	BN_init(&b);
191	BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
192	_moveb(msg, &b, rmp);
193	BN_free(&b);
194}
195
196void
197mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
198{
199
200	_madd("madd", mp1, mp2, rmp);
201}
202
203/*
204 * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
205 */
206int
207mp_mcmp(const MINT *mp1, const MINT *mp2)
208{
209
210	return (BN_cmp(mp1->bn, mp2->bn));
211}
212
213/*
214 * Same as mcmp but compares absolute values.
215 */
216static int
217_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
218{
219
220	return (BN_ucmp(mp1->bn, mp2->bn));
221}
222
223/*
224 * Compute qmp=nmp/dmp and rmp=nmp%dmp.
225 */
226static void
227_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
228    BN_CTX *c)
229{
230	BIGNUM q, r;
231
232	BN_init(&r);
233	BN_init(&q);
234	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
235	_moveb(msg, &q, qmp);
236	_moveb(msg, &r, rmp);
237	BN_free(&q);
238	BN_free(&r);
239}
240
241void
242mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
243{
244	BN_CTX *c;
245
246	c = BN_CTX_new();
247	if (c == NULL)
248		_bnerr("mdiv");
249	_mdiv("mdiv", nmp, dmp, qmp, rmp, c);
250	BN_CTX_free(c);
251}
252
253/*
254 * Free memory associated with an MINT.
255 */
256static void
257_mfree(const char *msg __unused, MINT *mp)
258{
259
260	BN_clear(mp->bn);
261	BN_free(mp->bn);
262	free(mp);
263}
264
265void
266mp_mfree(MINT *mp)
267{
268
269	_mfree("mfree", mp);
270}
271
272/*
273 * Read an integer from standard input and stick the result in mp.
274 * The input is treated to be in base 10.  This must be the silliest
275 * API in existence; why can't the program read in a string and call
276 * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
277 * exported.)
278 */
279void
280mp_min(MINT *mp)
281{
282	MINT *rmp;
283	char *line, *nline;
284	size_t linelen;
285
286	line = fgetln(stdin, &linelen);
287	if (line == NULL)
288		MPERR(("min"));
289	nline = malloc(linelen + 1);
290	if (nline == NULL)
291		MPERR(("min"));
292	memcpy(nline, line, linelen);
293	nline[linelen] = '\0';
294	rmp = _dtom("min", nline);
295	_movem("min", rmp, mp);
296	_mfree("min", rmp);
297	free(nline);
298}
299
300/*
301 * Print the value of mp to standard output in base 10.  See blurb
302 * above min() for why this is so useless.
303 */
304void
305mp_mout(const MINT *mp)
306{
307	char *s;
308
309	s = _mtod("mout", mp);
310	printf("%s", s);
311	free(s);
312}
313
314/*
315 * Set the value of tmp to the value of smp (i.e., tmp=smp).
316 */
317void
318mp_move(const MINT *smp, MINT *tmp)
319{
320
321	_movem("move", smp, tmp);
322}
323
324
325/*
326 * Internal routine to set the value of tmp to that of sbp.
327 */
328static void
329_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
330{
331
332	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
333}
334
335/*
336 * Internal routine to set the value of tmp to that of smp.
337 */
338static void
339_movem(const char *msg, const MINT *smp, MINT *tmp)
340{
341
342	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
343}
344
345/*
346 * Compute the square root of nmp and put the result in xmp.  The
347 * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
348 *
349 * Note that the OpenSSL BIGNUM library does not have a square root
350 * function, so this had to be implemented by hand using Newton's
351 * recursive formula:
352 *
353 *		x = (x + (n / x)) / 2
354 *
355 * where x is the square root of the positive number n.  In the
356 * beginning, x should be a reasonable guess, but the value 1,
357 * although suboptimal, works, too; this is that is used below.
358 */
359void
360mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
361{
362	BN_CTX *c;
363	MINT *tolerance;
364	MINT *ox, *x;
365	MINT *z1, *z2, *z3;
366	short i;
367
368	c = BN_CTX_new();
369	if (c == NULL)
370		_bnerr("msqrt");
371	tolerance = _itom("msqrt", 1);
372	x = _itom("msqrt", 1);
373	ox = _itom("msqrt", 0);
374	z1 = _itom("msqrt", 0);
375	z2 = _itom("msqrt", 0);
376	z3 = _itom("msqrt", 0);
377	do {
378		_movem("msqrt", x, ox);
379		_mdiv("msqrt", nmp, x, z1, z2, c);
380		_madd("msqrt", x, z1, z2);
381		_sdiv("msqrt", z2, 2, x, &i, c);
382		_msub("msqrt", ox, x, z3);
383	} while (_mcmpa("msqrt", z3, tolerance) == 1);
384	_movem("msqrt", x, xmp);
385	_mult("msqrt", x, x, z1, c);
386	_msub("msqrt", nmp, z1, z2);
387	_movem("msqrt", z2, rmp);
388	_mfree("msqrt", tolerance);
389	_mfree("msqrt", ox);
390	_mfree("msqrt", x);
391	_mfree("msqrt", z1);
392	_mfree("msqrt", z2);
393	_mfree("msqrt", z3);
394	BN_CTX_free(c);
395}
396
397/*
398 * Compute rmp=mp1-mp2.
399 */
400static void
401_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
402{
403	BIGNUM b;
404
405	BN_init(&b);
406	BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
407	_moveb(msg, &b, rmp);
408	BN_free(&b);
409}
410
411void
412mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
413{
414
415	_msub("msub", mp1, mp2, rmp);
416}
417
418/*
419 * Return a decimal representation of mp.  Return value must be
420 * free()'d.
421 */
422static char *
423_mtod(const char *msg, const MINT *mp)
424{
425	char *s, *s2;
426
427	s = BN_bn2dec(mp->bn);
428	if (s == NULL)
429		_bnerr(msg);
430	asprintf(&s2, "%s", s);
431	if (s2 == NULL)
432		MPERR(("%s", msg));
433	OPENSSL_free(s);
434	return (s2);
435}
436
437/*
438 * Return a hexadecimal representation of mp.  Return value must be
439 * free()'d.
440 */
441static char *
442_mtox(const char *msg, const MINT *mp)
443{
444	char *p, *s, *s2;
445	int len;
446
447	s = BN_bn2hex(mp->bn);
448	if (s == NULL)
449		_bnerr(msg);
450	asprintf(&s2, "%s", s);
451	if (s2 == NULL)
452		MPERR(("%s", msg));
453	OPENSSL_free(s);
454
455	/*
456	 * This is a kludge for libgmp compatibility.  The latter's
457	 * implementation of this function returns lower-case letters,
458	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
459	 * newkey(1)) are sensitive to this.  Although it's probably
460	 * their fault, it's nice to be compatible.
461	 */
462	len = strlen(s2);
463	for (p = s2; p < s2 + len; p++)
464		*p = tolower(*p);
465
466	return (s2);
467}
468
469char *
470mp_mtox(const MINT *mp)
471{
472
473	return (_mtox("mtox", mp));
474}
475
476/*
477 * Compute rmp=mp1*mp2.
478 */
479static void
480_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
481{
482	BIGNUM b;
483
484	BN_init(&b);
485	BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, c));
486	_moveb(msg, &b, rmp);
487	BN_free(&b);
488}
489
490void
491mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
492{
493	BN_CTX *c;
494
495	c = BN_CTX_new();
496	if (c == NULL)
497		_bnerr("mult");
498	_mult("mult", mp1, mp2, rmp, c);
499	BN_CTX_free(c);
500}
501
502/*
503 * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
504 * means 'raise to power', not 'bitwise XOR'.)
505 */
506void
507mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
508{
509	BIGNUM b;
510	BN_CTX *c;
511
512	c = BN_CTX_new();
513	if (c == NULL)
514		_bnerr("pow");
515	BN_init(&b);
516	BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, c));
517	_moveb("pow", &b, rmp);
518	BN_free(&b);
519	BN_CTX_free(c);
520}
521
522/*
523 * Compute rmp=bmp^e.  (See note above pow().)
524 */
525void
526mp_rpow(const MINT *bmp, short e, MINT *rmp)
527{
528	MINT *emp;
529	BIGNUM b;
530	BN_CTX *c;
531
532	c = BN_CTX_new();
533	if (c == NULL)
534		_bnerr("rpow");
535	BN_init(&b);
536	emp = _itom("rpow", e);
537	BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, c));
538	_moveb("rpow", &b, rmp);
539	_mfree("rpow", emp);
540	BN_free(&b);
541	BN_CTX_free(c);
542}
543
544/*
545 * Compute qmp=nmp/d and ro=nmp%d.
546 */
547static void
548_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
549    BN_CTX *c)
550{
551	MINT *dmp, *rmp;
552	BIGNUM q, r;
553	char *s;
554
555	BN_init(&q);
556	BN_init(&r);
557	dmp = _itom(msg, d);
558	rmp = _itom(msg, 0);
559	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
560	_moveb(msg, &q, qmp);
561	_moveb(msg, &r, rmp);
562	s = _mtox(msg, rmp);
563	errno = 0;
564	*ro = strtol(s, NULL, 16);
565	if (errno != 0)
566		MPERR(("%s underflow or overflow", msg));
567	free(s);
568	_mfree(msg, dmp);
569	_mfree(msg, rmp);
570	BN_free(&r);
571	BN_free(&q);
572}
573
574void
575mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
576{
577	BN_CTX *c;
578
579	c = BN_CTX_new();
580	if (c == NULL)
581		_bnerr("sdiv");
582	_sdiv("sdiv", nmp, d, qmp, ro, c);
583	BN_CTX_free(c);
584}
585
586/*
587 * Convert a hexadecimal string to an MINT.
588 */
589static MINT *
590_xtom(const char *msg, const char *s)
591{
592	MINT *mp;
593
594	mp = malloc(sizeof(*mp));
595	if (mp == NULL)
596		MPERR(("%s", msg));
597	mp->bn = BN_new();
598	if (mp->bn == NULL)
599		_bnerr(msg);
600	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
601	return (mp);
602}
603
604MINT *
605mp_xtom(const char *s)
606{
607
608	return (_xtom("xtom", s));
609}
610