mpasbn.c revision 80529
1/*
2 * Copyright (c) 2001 Dima Dorfman.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24 * SUCH DAMAGE.
25 */
26
27/*
28 * This is the traditional Berkeley MP library implemented in terms of
29 * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
30 * is meant to be as compatible with the latter as feasible.
31 *
32 * There seems to be a lack of documentation for the Berkeley MP
33 * interface.  All I could find was libgmp documentation (which didn't
34 * talk about the semantics of the functions) and an old SunOS 4.1
35 * manual page from 1989.  The latter wasn't very detailed, either,
36 * but at least described what the functions' arguments were.  In
37 * general the interface seems to be archaic, somewhat poorly
38 * designed, and poorly, if at all, documented.  It is considered
39 * harmful.
40 *
41 * Miscellaneous notes on this implementation:
42 *
43 *  - The SunOS manual page mentioned above indicates that if an error
44 *  occurs, the library should "produce messages and core images."
45 *  Given that most of the functions don't have return values (and
46 *  thus no sane way of alerting the caller to an error), this seems
47 *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
48 *  respectively, then abort().
49 *
50 *  - All the functions which take an argument to be "filled in"
51 *  assume that the argument has been initialized by one of the *tom()
52 *  routines before being passed to it.  I never saw this documented
53 *  anywhere, but this seems to be consistent with the way this
54 *  library is used.
55 *
56 *  - msqrt() is the only routine which had to be implemented which
57 *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
58 *  It was implemented by hand using Newton's recursive formula.
59 *  Doing it this way, although more error-prone, has the positive
60 *  side effect of testing a lot of other functions; if msqrt()
61 *  produces the correct results, most of the other routines will as
62 *  well.
63 *
64 *  - Internal-use-only routines (i.e., those defined here statically
65 *  and not in mp.h) have an underscore prepended to their name (this
66 *  is more for aesthetical reasons than technical).  All such
67 *  routines take an extra argument, 'msg', that denotes what they
68 *  should call themselves in an error message.  This is so a user
69 *  doesn't get an error message from a function they didn't call.
70 */
71
#if !defined(lint)
static const char rcsid[] =
    "$FreeBSD: head/lib/libmp/mpasbn.c 80529 2001-07-29 08:49:15Z dd $";
#endif /* !lint */
76
77#include <ctype.h>
78#include <err.h>
79#include <errno.h>
80#include <stdio.h>
81#include <stdlib.h>
82#include <string.h>
83
84#include <openssl/bn.h>
85#include <openssl/crypto.h>
86#include <openssl/err.h>
87
88#include "mp.h"
89
/*
 * Fatal-error helpers.  MPERR and MPERRX take a complete, parenthesized
 * argument list for warn(3) and warnx(3) respectively (for example
 * MPERR(("%s", msg))), print it, and then abort().  BN_ERRCHECK
 * evaluates expr once and hands msg to _bnerr() when expr is zero/NULL,
 * which is how the BN_* functions signal failure.
 */
#define	MPERR(s)	do { warn s; abort(); } while (0)
#define	MPERRX(s)	do { warnx s; abort(); } while (0)
#define	BN_ERRCHECK(msg, expr)	do {				\
	if (!(expr))						\
		_bnerr(msg);					\
} while (0)
95
96static void _bnerr(const char *);
97static MINT *_dtom(const char *, const char *);
98static MINT *_itom(const char *, short);
99static void _madd(const char *, const MINT *, const MINT *, MINT *);
100static int _mcmpa(const char *, const MINT *, const MINT *);
101static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *);
102static void _mfree(const char *, MINT *);
103static void _moveb(const char *, const BIGNUM *, MINT *);
104static void _movem(const char *, const MINT *, MINT *);
105static void _msub(const char *, const MINT *, const MINT *, MINT *);
106static char *_mtod(const char *, const MINT *);
107static char *_mtox(const char *, const MINT *);
108static void _mult(const char *, const MINT *, const MINT *, MINT *);
109static void _sdiv(const char *, const MINT *, short, MINT *, short *);
110static MINT *_xtom(const char *, const char *);
111
/*
 * Report an error from one of the BN_* functions using MPERRX, then
 * abort.  msg names the public routine that failed.
 *
 * Some BN_* failures (e.g. BN_dec2bn() on a non-numeric string) do not
 * push anything onto the OpenSSL error queue.  In that case
 * ERR_get_error() returns 0 and ERR_reason_error_string() returns NULL.
 * Passing NULL to "%s" is undefined behavior, so substitute a generic
 * reason instead.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
122
123/*
124 * Convert a decimal string to an MINT.
125 */
126static MINT *
127_dtom(const char *msg, const char *s)
128{
129	MINT *mp;
130
131	mp = malloc(sizeof(*mp));
132	if (mp == NULL)
133		MPERR(("%s", msg));
134	mp->bn = BN_new();
135	if (mp->bn == NULL)
136		_bnerr(msg);
137	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
138	return (mp);
139}
140
141/*
142 * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
143 */
144void
145gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
146{
147	BIGNUM b;
148	BN_CTX c;
149
150	BN_CTX_init(&c);
151	BN_init(&b);
152	BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, &c));
153	_moveb("gcd", &b, rmp);
154	BN_free(&b);
155	BN_CTX_free(&c);
156}
157
158/*
159 * Make an MINT out of a short integer.  Return value must be mfree()'d.
160 */
161static MINT *
162_itom(const char *msg, short n)
163{
164	MINT *mp;
165	char *s;
166
167	asprintf(&s, "%x", n);
168	if (s == NULL)
169		MPERR(("%s", msg));
170	mp = _xtom(msg, s);
171	free(s);
172	return (mp);
173}
174
175MINT *
176itom(short n)
177{
178
179	return (_itom("itom", n));
180}
181
182/*
183 * Compute rmp=mp1+mp2.
184 */
185static void
186_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
187{
188	BIGNUM b;
189
190	BN_init(&b);
191	BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
192	_moveb(msg, &b, rmp);
193	BN_free(&b);
194}
195
196void
197madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
198{
199
200	_madd("madd", mp1, mp2, rmp);
201}
202
203/*
204 * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
205 */
206int
207mcmp(const MINT *mp1, const MINT *mp2)
208{
209
210	return (BN_cmp(mp1->bn, mp2->bn));
211}
212
213/*
214 * Same as mcmp but compares absolute values.
215 */
216static int
217_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
218{
219
220	return (BN_ucmp(mp1->bn, mp2->bn));
221}
222
223/*
224 * Compute qmp=nmp/dmp and rmp=nmp%dmp.
225 */
226static void
227_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
228{
229	BIGNUM q, r;
230	BN_CTX c;
231
232	BN_CTX_init(&c);
233	BN_init(&r);
234	BN_init(&q);
235	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, &c));
236	_moveb(msg, &q, qmp);
237	_moveb(msg, &r, rmp);
238	BN_free(&q);
239	BN_free(&r);
240	BN_CTX_free(&c);
241}
242
243void
244mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
245{
246
247	_mdiv("mdiv", nmp, dmp, qmp, rmp);
248}
249
250/*
251 * Free memory associated with an MINT.
252 */
253static void
254_mfree(const char *msg __unused, MINT *mp)
255{
256
257	BN_clear(mp->bn);
258	BN_free(mp->bn);
259	free(mp);
260}
261
262void
263mfree(MINT *mp)
264{
265
266	_mfree("mfree", mp);
267}
268
269/*
270 * Read an integer from standard input and stick the result in mp.
271 * The input is treated to be in base 10.  This must be the silliest
272 * API in existence; why can't the program read in a string and call
273 * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
274 * exported.)
275 */
276void
277min(MINT *mp)
278{
279	MINT *rmp;
280	char *line, *nline;
281	size_t linelen;
282
283	line = fgetln(stdin, &linelen);
284	if (line == NULL)
285		MPERR(("min"));
286	nline = malloc(linelen);
287	if (nline == NULL)
288		MPERR(("min"));
289	strncpy(nline, line, linelen);
290	nline[linelen] = '\0';
291	rmp = _dtom("min", nline);
292	_movem("min", rmp, mp);
293	_mfree("min", rmp);
294	free(nline);
295}
296
297/*
298 * Print the value of mp to standard output in base 10.  See blurb
299 * above min() for why this is so useless.
300 */
301void
302mout(const MINT *mp)
303{
304	char *s;
305
306	s = _mtod("mout", mp);
307	printf("%s", s);
308	free(s);
309}
310
311/*
312 * Set the value of tmp to the value of smp (i.e., tmp=smp).
313 */
314void
315move(const MINT *smp, MINT *tmp)
316{
317
318	_movem("move", smp, tmp);
319}
320
321
322/*
323 * Internal routine to set the value of tmp to that of sbp.
324 */
325static void
326_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
327{
328
329	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
330}
331
332/*
333 * Internal routine to set the value of tmp to that of smp.
334 */
335static void
336_movem(const char *msg, const MINT *smp, MINT *tmp)
337{
338
339	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
340}
341
342/*
343 * Compute the square root of nmp and put the result in xmp.  The
344 * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
345 *
346 * Note that the OpenSSL BIGNUM library does not have a square root
347 * function, so this had to be implemented by hand using Newton's
348 * recursive formula:
349 *
350 *		x = (x + (n / x)) / 2
351 *
352 * where x is the square root of the positive number n.  In the
353 * beginning, x should be a reasonable guess, but the value 1,
354 * although suboptimal, works, too; this is that is used below.
355 */
356void
357msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
358{
359	MINT *tolerance;
360	MINT *ox, *x;
361	MINT *z1, *z2, *z3;
362	short i;
363
364	tolerance = _itom("msqrt", 1);
365	x = _itom("msqrt", 1);
366	ox = _itom("msqrt", 0);
367	z1 = _itom("msqrt", 0);
368	z2 = _itom("msqrt", 0);
369	z3 = _itom("msqrt", 0);
370	do {
371		_movem("msqrt", x, ox);
372		_mdiv("msqrt", nmp, x, z1, z2);
373		_madd("msqrt", x, z1, z2);
374		_sdiv("msqrt", z2, 2, x, &i);
375		_msub("msqrt", ox, x, z3);
376	} while (_mcmpa("msqrt", z3, tolerance) == 1);
377	_movem("msqrt", x, xmp);
378	_mult("msqrt", x, x, z1);
379	_msub("msqrt", nmp, z1, z2);
380	_movem("msqrt", z2, rmp);
381	_mfree("msqrt", tolerance);
382	_mfree("msqrt", ox);
383	_mfree("msqrt", x);
384	_mfree("msqrt", z1);
385	_mfree("msqrt", z2);
386	_mfree("msqrt", z3);
387}
388
389/*
390 * Compute rmp=mp1-mp2.
391 */
392static void
393_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
394{
395	BIGNUM b;
396
397	BN_init(&b);
398	BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
399	_moveb(msg, &b, rmp);
400	BN_free(&b);
401}
402
403void
404msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
405{
406
407	_msub("msub", mp1, mp2, rmp);
408}
409
410/*
411 * Return a decimal representation of mp.  Return value must be
412 * free()'d.
413 */
414static char *
415_mtod(const char *msg, const MINT *mp)
416{
417	char *s, *s2;
418
419	s = BN_bn2dec(mp->bn);
420	if (s == NULL)
421		_bnerr(msg);
422	asprintf(&s2, "%s", s);
423	if (s2 == NULL)
424		MPERR(("%s", msg));
425	OPENSSL_free(s);
426	return (s2);
427}
428
429/*
430 * Return a hexadecimal representation of mp.  Return value must be
431 * free()'d.
432 */
433static char *
434_mtox(const char *msg, const MINT *mp)
435{
436	char *p, *s, *s2;
437	int len;
438
439	s = BN_bn2hex(mp->bn);
440	if (s == NULL)
441		_bnerr(msg);
442	asprintf(&s2, "%s", s);
443	if (s2 == NULL)
444		MPERR(("%s", msg));
445	OPENSSL_free(s);
446
447	/*
448	 * This is a kludge for libgmp compatibility.  The latter's
449	 * implementation of this function returns lower-case letters,
450	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
451	 * newkey(1)) are sensitive to this.  Although it's probably
452	 * their fault, it's nice to be compatible.
453	 */
454	len = strlen(s2);
455	for (p = s2; p < s2 + len; p++)
456		*p = tolower(*p);
457
458	return (s2);
459}
460
461char *
462mtox(const MINT *mp)
463{
464
465	return (_mtox("mtox", mp));
466}
467
468/*
469 * Compute rmp=mp1*mp2.
470 */
471static void
472_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
473{
474	BIGNUM b;
475	BN_CTX c;
476
477	BN_CTX_init(&c);
478	BN_init(&b);
479	BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, &c));
480	_moveb(msg, &b, rmp);
481	BN_free(&b);
482	BN_CTX_free(&c);
483}
484
485void
486mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
487{
488
489	_mult("mult", mp1, mp2, rmp);
490}
491
492/*
493 * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
494 * means 'raise to power', not 'bitwise XOR'.)
495 */
496void
497pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
498{
499	BIGNUM b;
500	BN_CTX c;
501
502	BN_CTX_init(&c);
503	BN_init(&b);
504	BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, &c));
505	_moveb("pow", &b, rmp);
506	BN_free(&b);
507	BN_CTX_free(&c);
508}
509
510/*
511 * Compute rmp=bmp^e.  (See note above pow().)
512 */
513void
514rpow(const MINT *bmp, short e, MINT *rmp)
515{
516	MINT *emp;
517	BIGNUM b;
518	BN_CTX c;
519
520	BN_CTX_init(&c);
521	BN_init(&b);
522	emp = _itom("rpow", e);
523	BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, &c));
524	_moveb("rpow", &b, rmp);
525	_mfree("rpow", emp);
526	BN_free(&b);
527	BN_CTX_free(&c);
528}
529
530/*
531 * Compute qmp=nmp/d and ro=nmp%d.
532 */
533static void
534_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro)
535{
536	MINT *dmp, *rmp;
537	BIGNUM q, r;
538	BN_CTX c;
539	char *s;
540
541	BN_CTX_init(&c);
542	BN_init(&q);
543	BN_init(&r);
544	dmp = _itom(msg, d);
545	rmp = _itom(msg, 0);
546	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, &c));
547	_moveb(msg, &q, qmp);
548	_moveb(msg, &r, rmp);
549	s = _mtox(msg, rmp);
550	errno = 0;
551	*ro = strtol(s, NULL, 16);
552	if (errno != 0)
553		MPERR(("%s underflow or overflow", msg));
554	free(s);
555	_mfree(msg, dmp);
556	_mfree(msg, rmp);
557	BN_free(&r);
558	BN_free(&q);
559	BN_CTX_free(&c);
560}
561
562void
563sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
564{
565
566	_sdiv("sdiv", nmp, d, qmp, ro);
567}
568
569/*
570 * Convert a hexadecimal string to an MINT.
571 */
572static MINT *
573_xtom(const char *msg, const char *s)
574{
575	MINT *mp;
576
577	mp = malloc(sizeof(*mp));
578	if (mp == NULL)
579		MPERR(("%s", msg));
580	mp->bn = BN_new();
581	if (mp->bn == NULL)
582		_bnerr(msg);
583	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
584	return (mp);
585}
586
587MINT *
588xtom(const char *s)
589{
590
591	return (_xtom("xtom", s));
592}
593