1/*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2001 Dima Dorfman.
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29/*
30 * This is the traditional Berkeley MP library implemented in terms of
31 * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
32 * is meant to be as compatible with the latter as feasible.
33 *
34 * There seems to be a lack of documentation for the Berkeley MP
35 * interface.  All I could find was libgmp documentation (which didn't
36 * talk about the semantics of the functions) and an old SunOS 4.1
37 * manual page from 1989.  The latter wasn't very detailed, either,
 * but at least described what the functions' arguments were.  In
39 * general the interface seems to be archaic, somewhat poorly
40 * designed, and poorly, if at all, documented.  It is considered
41 * harmful.
42 *
43 * Miscellaneous notes on this implementation:
44 *
45 *  - The SunOS manual page mentioned above indicates that if an error
46 *  occurs, the library should "produce messages and core images."
47 *  Given that most of the functions don't have return values (and
48 *  thus no sane way of alerting the caller to an error), this seems
49 *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
50 *  respectively, then abort().
51 *
52 *  - All the functions which take an argument to be "filled in"
53 *  assume that the argument has been initialized by one of the *tom()
54 *  routines before being passed to it.  I never saw this documented
55 *  anywhere, but this seems to be consistent with the way this
56 *  library is used.
57 *
58 *  - msqrt() is the only routine which had to be implemented which
59 *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
60 *  It was implemented by hand using Newton's recursive formula.
 *  Doing it this way, although more error-prone, has the positive
 *  side effect of testing a lot of other functions; if msqrt()
63 *  produces the correct results, most of the other routines will as
64 *  well.
65 *
66 *  - Internal-use-only routines (i.e., those defined here statically
67 *  and not in mp.h) have an underscore prepended to their name (this
 *  is more for aesthetic than technical reasons).  All such
69 *  routines take an extra argument, 'msg', that denotes what they
70 *  should call themselves in an error message.  This is so a user
71 *  doesn't get an error message from a function they didn't call.
72 */
73
74#include <sys/cdefs.h>
75__FBSDID("$FreeBSD$");
76
77#include <ctype.h>
78#include <err.h>
79#include <errno.h>
80#include <stdio.h>
81#include <stdlib.h>
82#include <string.h>
83
84#include <openssl/crypto.h>
85#include <openssl/err.h>
86
87#include "mp.h"
88
/*
 * MPERR and MPERRX take a parenthesized warn(3)/warnx(3) argument list,
 * e.g. MPERR(("%s", msg)).  They print the message (MPERR uses warn,
 * MPERRX uses warnx) and then abort().
 */
#define MPERR(s)	do { warn s; abort(); } while (0)
#define MPERRX(s)	do { warnx s; abort(); } while (0)
/*
 * Evaluate a BN_* call and, if it yields 0/NULL (failure), report the
 * OpenSSL error under the name msg via _bnerr().
 */
#define BN_ERRCHECK(msg, expr) do {		\
	if (!(expr)) _bnerr(msg);		\
} while (0)

/*
 * Internal helpers.  The leading 'msg' argument is the name used in
 * error messages (normally that of the public routine the user called).
 */
static void _bnerr(const char *);
static MINT *_dtom(const char *, const char *);
static MINT *_itom(const char *, short);
static void _madd(const char *, const MINT *, const MINT *, MINT *);
static int _mcmpa(const char *, const MINT *, const MINT *);
static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
		BN_CTX *);
static void _mfree(const char *, MINT *);
static void _moveb(const char *, const BIGNUM *, MINT *);
static void _movem(const char *, const MINT *, MINT *);
static void _msub(const char *, const MINT *, const MINT *, MINT *);
static char *_mtod(const char *, const MINT *);
static char *_mtox(const char *, const MINT *);
static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
static MINT *_xtom(const char *, const char *);
111
/*
 * Report an error from one of the BN_* functions using MPERRX (which
 * aborts).
 *
 * Some BN_* routines may fail without queueing an OpenSSL error (for
 * instance BN_hex2bn/BN_dec2bn on unparsable input).  In that case
 * ERR_get_error() yields 0 and ERR_reason_error_string() returns NULL,
 * which must not be handed to a "%s" conversion, so substitute a
 * generic reason.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
122
123/*
124 * Convert a decimal string to an MINT.
125 */
126static MINT *
127_dtom(const char *msg, const char *s)
128{
129	MINT *mp;
130
131	mp = malloc(sizeof(*mp));
132	if (mp == NULL)
133		MPERR(("%s", msg));
134	mp->bn = BN_new();
135	if (mp->bn == NULL)
136		_bnerr(msg);
137	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
138	return (mp);
139}
140
141/*
142 * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
143 */
144void
145mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
146{
147	BIGNUM *b;
148	BN_CTX *c;
149
150	b = NULL;
151	c = BN_CTX_new();
152	if (c != NULL)
153		b = BN_new();
154	if (c == NULL || b == NULL)
155		_bnerr("gcd");
156	BN_ERRCHECK("gcd", BN_gcd(b, mp1->bn, mp2->bn, c));
157	_moveb("gcd", b, rmp);
158	BN_free(b);
159	BN_CTX_free(c);
160}
161
162/*
163 * Make an MINT out of a short integer.  Return value must be mfree()'d.
164 */
165static MINT *
166_itom(const char *msg, short n)
167{
168	MINT *mp;
169	char *s;
170
171	asprintf(&s, "%x", n);
172	if (s == NULL)
173		MPERR(("%s", msg));
174	mp = _xtom(msg, s);
175	free(s);
176	return (mp);
177}
178
179MINT *
180mp_itom(short n)
181{
182
183	return (_itom("itom", n));
184}
185
186/*
187 * Compute rmp=mp1+mp2.
188 */
189static void
190_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
191{
192	BIGNUM *b;
193
194	b = BN_new();
195	if (b == NULL)
196		_bnerr(msg);
197	BN_ERRCHECK(msg, BN_add(b, mp1->bn, mp2->bn));
198	_moveb(msg, b, rmp);
199	BN_free(b);
200}
201
202void
203mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
204{
205
206	_madd("madd", mp1, mp2, rmp);
207}
208
209/*
210 * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
211 */
212int
213mp_mcmp(const MINT *mp1, const MINT *mp2)
214{
215
216	return (BN_cmp(mp1->bn, mp2->bn));
217}
218
219/*
220 * Same as mcmp but compares absolute values.
221 */
222static int
223_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
224{
225
226	return (BN_ucmp(mp1->bn, mp2->bn));
227}
228
229/*
230 * Compute qmp=nmp/dmp and rmp=nmp%dmp.
231 */
232static void
233_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
234    BN_CTX *c)
235{
236	BIGNUM *q, *r;
237
238	q = NULL;
239	r = BN_new();
240	if (r != NULL)
241		q = BN_new();
242	if (r == NULL || q == NULL)
243		_bnerr(msg);
244	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
245	_moveb(msg, q, qmp);
246	_moveb(msg, r, rmp);
247	BN_free(q);
248	BN_free(r);
249}
250
251void
252mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
253{
254	BN_CTX *c;
255
256	c = BN_CTX_new();
257	if (c == NULL)
258		_bnerr("mdiv");
259	_mdiv("mdiv", nmp, dmp, qmp, rmp, c);
260	BN_CTX_free(c);
261}
262
263/*
264 * Free memory associated with an MINT.
265 */
266static void
267_mfree(const char *msg __unused, MINT *mp)
268{
269
270	BN_clear(mp->bn);
271	BN_free(mp->bn);
272	free(mp);
273}
274
275void
276mp_mfree(MINT *mp)
277{
278
279	_mfree("mfree", mp);
280}
281
282/*
283 * Read an integer from standard input and stick the result in mp.
284 * The input is treated to be in base 10.  This must be the silliest
285 * API in existence; why can't the program read in a string and call
286 * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
287 * exported.)
288 */
289void
290mp_min(MINT *mp)
291{
292	MINT *rmp;
293	char *line, *nline;
294	size_t linelen;
295
296	line = fgetln(stdin, &linelen);
297	if (line == NULL)
298		MPERR(("min"));
299	nline = malloc(linelen + 1);
300	if (nline == NULL)
301		MPERR(("min"));
302	memcpy(nline, line, linelen);
303	nline[linelen] = '\0';
304	rmp = _dtom("min", nline);
305	_movem("min", rmp, mp);
306	_mfree("min", rmp);
307	free(nline);
308}
309
310/*
311 * Print the value of mp to standard output in base 10.  See blurb
312 * above min() for why this is so useless.
313 */
314void
315mp_mout(const MINT *mp)
316{
317	char *s;
318
319	s = _mtod("mout", mp);
320	printf("%s", s);
321	free(s);
322}
323
324/*
325 * Set the value of tmp to the value of smp (i.e., tmp=smp).
326 */
327void
328mp_move(const MINT *smp, MINT *tmp)
329{
330
331	_movem("move", smp, tmp);
332}
333
334
335/*
336 * Internal routine to set the value of tmp to that of sbp.
337 */
338static void
339_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
340{
341
342	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
343}
344
345/*
346 * Internal routine to set the value of tmp to that of smp.
347 */
348static void
349_movem(const char *msg, const MINT *smp, MINT *tmp)
350{
351
352	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
353}
354
355/*
356 * Compute the square root of nmp and put the result in xmp.  The
357 * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
358 *
359 * Note that the OpenSSL BIGNUM library does not have a square root
360 * function, so this had to be implemented by hand using Newton's
361 * recursive formula:
362 *
363 *		x = (x + (n / x)) / 2
364 *
365 * where x is the square root of the positive number n.  In the
366 * beginning, x should be a reasonable guess, but the value 1,
367 * although suboptimal, works, too; this is that is used below.
368 */
369void
370mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
371{
372	BN_CTX *c;
373	MINT *tolerance;
374	MINT *ox, *x;
375	MINT *z1, *z2, *z3;
376	short i;
377
378	c = BN_CTX_new();
379	if (c == NULL)
380		_bnerr("msqrt");
381	tolerance = _itom("msqrt", 1);
382	x = _itom("msqrt", 1);
383	ox = _itom("msqrt", 0);
384	z1 = _itom("msqrt", 0);
385	z2 = _itom("msqrt", 0);
386	z3 = _itom("msqrt", 0);
387	do {
388		_movem("msqrt", x, ox);
389		_mdiv("msqrt", nmp, x, z1, z2, c);
390		_madd("msqrt", x, z1, z2);
391		_sdiv("msqrt", z2, 2, x, &i, c);
392		_msub("msqrt", ox, x, z3);
393	} while (_mcmpa("msqrt", z3, tolerance) == 1);
394	_movem("msqrt", x, xmp);
395	_mult("msqrt", x, x, z1, c);
396	_msub("msqrt", nmp, z1, z2);
397	_movem("msqrt", z2, rmp);
398	_mfree("msqrt", tolerance);
399	_mfree("msqrt", ox);
400	_mfree("msqrt", x);
401	_mfree("msqrt", z1);
402	_mfree("msqrt", z2);
403	_mfree("msqrt", z3);
404	BN_CTX_free(c);
405}
406
407/*
408 * Compute rmp=mp1-mp2.
409 */
410static void
411_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
412{
413	BIGNUM *b;
414
415	b = BN_new();
416	if (b == NULL)
417		_bnerr(msg);
418	BN_ERRCHECK(msg, BN_sub(b, mp1->bn, mp2->bn));
419	_moveb(msg, b, rmp);
420	BN_free(b);
421}
422
423void
424mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
425{
426
427	_msub("msub", mp1, mp2, rmp);
428}
429
430/*
431 * Return a decimal representation of mp.  Return value must be
432 * free()'d.
433 */
434static char *
435_mtod(const char *msg, const MINT *mp)
436{
437	char *s, *s2;
438
439	s = BN_bn2dec(mp->bn);
440	if (s == NULL)
441		_bnerr(msg);
442	asprintf(&s2, "%s", s);
443	if (s2 == NULL)
444		MPERR(("%s", msg));
445	OPENSSL_free(s);
446	return (s2);
447}
448
449/*
450 * Return a hexadecimal representation of mp.  Return value must be
451 * free()'d.
452 */
453static char *
454_mtox(const char *msg, const MINT *mp)
455{
456	char *p, *s, *s2;
457	int len;
458
459	s = BN_bn2hex(mp->bn);
460	if (s == NULL)
461		_bnerr(msg);
462	asprintf(&s2, "%s", s);
463	if (s2 == NULL)
464		MPERR(("%s", msg));
465	OPENSSL_free(s);
466
467	/*
468	 * This is a kludge for libgmp compatibility.  The latter's
469	 * implementation of this function returns lower-case letters,
470	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
471	 * newkey(1)) are sensitive to this.  Although it's probably
472	 * their fault, it's nice to be compatible.
473	 */
474	len = strlen(s2);
475	for (p = s2; p < s2 + len; p++)
476		*p = tolower(*p);
477
478	return (s2);
479}
480
481char *
482mp_mtox(const MINT *mp)
483{
484
485	return (_mtox("mtox", mp));
486}
487
488/*
489 * Compute rmp=mp1*mp2.
490 */
491static void
492_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
493{
494	BIGNUM *b;
495
496	b = BN_new();
497	if (b == NULL)
498		_bnerr(msg);
499	BN_ERRCHECK(msg, BN_mul(b, mp1->bn, mp2->bn, c));
500	_moveb(msg, b, rmp);
501	BN_free(b);
502}
503
504void
505mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
506{
507	BN_CTX *c;
508
509	c = BN_CTX_new();
510	if (c == NULL)
511		_bnerr("mult");
512	_mult("mult", mp1, mp2, rmp, c);
513	BN_CTX_free(c);
514}
515
516/*
517 * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
518 * means 'raise to power', not 'bitwise XOR'.)
519 */
520void
521mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
522{
523	BIGNUM *b;
524	BN_CTX *c;
525
526	b = NULL;
527	c = BN_CTX_new();
528	if (c != NULL)
529		b = BN_new();
530	if (c == NULL || b == NULL)
531		_bnerr("pow");
532	BN_ERRCHECK("pow", BN_mod_exp(b, bmp->bn, emp->bn, mmp->bn, c));
533	_moveb("pow", b, rmp);
534	BN_free(b);
535	BN_CTX_free(c);
536}
537
538/*
539 * Compute rmp=bmp^e.  (See note above pow().)
540 */
541void
542mp_rpow(const MINT *bmp, short e, MINT *rmp)
543{
544	MINT *emp;
545	BIGNUM *b;
546	BN_CTX *c;
547
548	b = NULL;
549	c = BN_CTX_new();
550	if (c != NULL)
551		b = BN_new();
552	if (c == NULL || b == NULL)
553		_bnerr("rpow");
554	emp = _itom("rpow", e);
555	BN_ERRCHECK("rpow", BN_exp(b, bmp->bn, emp->bn, c));
556	_moveb("rpow", b, rmp);
557	_mfree("rpow", emp);
558	BN_free(b);
559	BN_CTX_free(c);
560}
561
562/*
563 * Compute qmp=nmp/d and ro=nmp%d.
564 */
565static void
566_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
567    BN_CTX *c)
568{
569	MINT *dmp, *rmp;
570	BIGNUM *q, *r;
571	char *s;
572
573	r = NULL;
574	q = BN_new();
575	if (q != NULL)
576		r = BN_new();
577	if (q == NULL || r == NULL)
578		_bnerr(msg);
579	dmp = _itom(msg, d);
580	rmp = _itom(msg, 0);
581	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
582	_moveb(msg, q, qmp);
583	_moveb(msg, r, rmp);
584	s = _mtox(msg, rmp);
585	errno = 0;
586	*ro = strtol(s, NULL, 16);
587	if (errno != 0)
588		MPERR(("%s underflow or overflow", msg));
589	free(s);
590	_mfree(msg, dmp);
591	_mfree(msg, rmp);
592	BN_free(r);
593	BN_free(q);
594}
595
596void
597mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
598{
599	BN_CTX *c;
600
601	c = BN_CTX_new();
602	if (c == NULL)
603		_bnerr("sdiv");
604	_sdiv("sdiv", nmp, d, qmp, ro, c);
605	BN_CTX_free(c);
606}
607
608/*
609 * Convert a hexadecimal string to an MINT.
610 */
611static MINT *
612_xtom(const char *msg, const char *s)
613{
614	MINT *mp;
615
616	mp = malloc(sizeof(*mp));
617	if (mp == NULL)
618		MPERR(("%s", msg));
619	mp->bn = BN_new();
620	if (mp->bn == NULL)
621		_bnerr(msg);
622	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
623	return (mp);
624}
625
626MINT *
627mp_xtom(const char *s)
628{
629
630	return (_xtom("xtom", s));
631}
632