mpasbn.c revision 110011
1/*
2 * Copyright (c) 2001 Dima Dorfman.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 *    notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 *    notice, this list of conditions and the following disclaimer in the
12 *    documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24 * SUCH DAMAGE.
25 */
26
27/*
28 * This is the traditional Berkeley MP library implemented in terms of
29 * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
30 * is meant to be as compatible with the latter as feasible.
31 *
32 * There seems to be a lack of documentation for the Berkeley MP
33 * interface.  All I could find was libgmp documentation (which didn't
34 * talk about the semantics of the functions) and an old SunOS 4.1
35 * manual page from 1989.  The latter wasn't very detailed, either,
36 * but at least described what the functions' arguments were.  In
37 * general the interface seems to be archaic, somewhat poorly
38 * designed, and poorly, if at all, documented.  It is considered
39 * harmful.
40 *
41 * Miscellaneous notes on this implementation:
42 *
43 *  - The SunOS manual page mentioned above indicates that if an error
44 *  occurs, the library should "produce messages and core images."
45 *  Given that most of the functions don't have return values (and
46 *  thus no sane way of alerting the caller to an error), this seems
47 *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
48 *  respectively, then abort().
49 *
50 *  - All the functions which take an argument to be "filled in"
51 *  assume that the argument has been initialized by one of the *tom()
52 *  routines before being passed to it.  I never saw this documented
53 *  anywhere, but this seems to be consistent with the way this
54 *  library is used.
55 *
56 *  - msqrt() is the only routine which had to be implemented which
57 *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
58 *  It was implemented by hand using Newton's recursive formula.
59 *  Doing it this way, although more error-prone, has the positive
60 *  side effect of testing a lot of other functions; if msqrt()
61 *  produces the correct results, most of the other routines will as
62 *  well.
63 *
64 *  - Internal-use-only routines (i.e., those defined here statically
65 *  and not in mp.h) have an underscore prepended to their name (this
66 *  is more for aesthetic reasons than technical).  All such
67 *  routines take an extra argument, 'msg', that denotes what they
68 *  should call themselves in an error message.  This is so a user
69 *  doesn't get an error message from a function they didn't call.
70 */
71
72#include <sys/cdefs.h>
73__FBSDID("$FreeBSD: head/lib/libmp/mpasbn.c 110011 2003-01-28 23:03:15Z markm $");
74
75#include <ctype.h>
76#include <err.h>
77#include <errno.h>
78#include <stdio.h>
79#include <stdlib.h>
80#include <string.h>
81
82#include <openssl/crypto.h>
83#include <openssl/err.h>
84
85#include "openssl/crypto/bn/bn_lcl.h"
86#include "mp.h"
87
88#define MPERR(s)	do { warn s; abort(); } while (0)
89#define MPERRX(s)	do { warnx s; abort(); } while (0)
90#define BN_ERRCHECK(msg, expr) do {		\
91	if (!(expr)) _bnerr(msg);		\
92} while (0)
93
94static void _bnerr(const char *);
95static MINT *_dtom(const char *, const char *);
96static MINT *_itom(const char *, short);
97static void _madd(const char *, const MINT *, const MINT *, MINT *);
98static int _mcmpa(const char *, const MINT *, const MINT *);
99static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *);
100static void _mfree(const char *, MINT *);
101static void _moveb(const char *, const BIGNUM *, MINT *);
102static void _movem(const char *, const MINT *, MINT *);
103static void _msub(const char *, const MINT *, const MINT *, MINT *);
104static char *_mtod(const char *, const MINT *);
105static char *_mtox(const char *, const MINT *);
106static void _mult(const char *, const MINT *, const MINT *, MINT *);
107static void _sdiv(const char *, const MINT *, short, MINT *, short *);
108static MINT *_xtom(const char *, const char *);
109
110/*
111 * Report an error from one of the BN_* functions using MPERRX.
112 */
113static void
114_bnerr(const char *msg)
115{
116
117	ERR_load_crypto_strings();
118	MPERRX(("%s: %s", msg, ERR_reason_error_string(ERR_get_error())));
119}
120
121/*
122 * Convert a decimal string to an MINT.
123 */
124static MINT *
125_dtom(const char *msg, const char *s)
126{
127	MINT *mp;
128
129	mp = malloc(sizeof(*mp));
130	if (mp == NULL)
131		MPERR(("%s", msg));
132	mp->bn = BN_new();
133	if (mp->bn == NULL)
134		_bnerr(msg);
135	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
136	return (mp);
137}
138
139/*
140 * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
141 */
142void
143gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
144{
145	BIGNUM b;
146	BN_CTX c;
147
148	BN_CTX_init(&c);
149	BN_init(&b);
150	BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, &c));
151	_moveb("gcd", &b, rmp);
152	BN_free(&b);
153	BN_CTX_free(&c);
154}
155
156/*
157 * Make an MINT out of a short integer.  Return value must be mfree()'d.
158 */
159static MINT *
160_itom(const char *msg, short n)
161{
162	MINT *mp;
163	char *s;
164
165	asprintf(&s, "%x", n);
166	if (s == NULL)
167		MPERR(("%s", msg));
168	mp = _xtom(msg, s);
169	free(s);
170	return (mp);
171}
172
173MINT *
174itom(short n)
175{
176
177	return (_itom("itom", n));
178}
179
180/*
181 * Compute rmp=mp1+mp2.
182 */
183static void
184_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
185{
186	BIGNUM b;
187
188	BN_init(&b);
189	BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
190	_moveb(msg, &b, rmp);
191	BN_free(&b);
192}
193
194void
195madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
196{
197
198	_madd("madd", mp1, mp2, rmp);
199}
200
201/*
202 * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
203 */
204int
205mcmp(const MINT *mp1, const MINT *mp2)
206{
207
208	return (BN_cmp(mp1->bn, mp2->bn));
209}
210
211/*
212 * Same as mcmp but compares absolute values.
213 */
214static int
215_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
216{
217
218	return (BN_ucmp(mp1->bn, mp2->bn));
219}
220
221/*
222 * Compute qmp=nmp/dmp and rmp=nmp%dmp.
223 */
224static void
225_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
226{
227	BIGNUM q, r;
228	BN_CTX c;
229
230	BN_CTX_init(&c);
231	BN_init(&r);
232	BN_init(&q);
233	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, &c));
234	_moveb(msg, &q, qmp);
235	_moveb(msg, &r, rmp);
236	BN_free(&q);
237	BN_free(&r);
238	BN_CTX_free(&c);
239}
240
241void
242mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
243{
244
245	_mdiv("mdiv", nmp, dmp, qmp, rmp);
246}
247
248/*
249 * Free memory associated with an MINT.
250 */
251static void
252_mfree(const char *msg __unused, MINT *mp)
253{
254
255	BN_clear(mp->bn);
256	BN_free(mp->bn);
257	free(mp);
258}
259
260void
261mfree(MINT *mp)
262{
263
264	_mfree("mfree", mp);
265}
266
267/*
268 * Read an integer from standard input and stick the result in mp.
269 * The input is treated to be in base 10.  This must be the silliest
270 * API in existence; why can't the program read in a string and call
271 * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
272 * exported.)
273 */
274void
275min(MINT *mp)
276{
277	MINT *rmp;
278	char *line, *nline;
279	size_t linelen;
280
281	line = fgetln(stdin, &linelen);
282	if (line == NULL)
283		MPERR(("min"));
284	nline = malloc(linelen);
285	if (nline == NULL)
286		MPERR(("min"));
287	strncpy(nline, line, linelen);
288	nline[linelen] = '\0';
289	rmp = _dtom("min", nline);
290	_movem("min", rmp, mp);
291	_mfree("min", rmp);
292	free(nline);
293}
294
295/*
296 * Print the value of mp to standard output in base 10.  See blurb
297 * above min() for why this is so useless.
298 */
299void
300mout(const MINT *mp)
301{
302	char *s;
303
304	s = _mtod("mout", mp);
305	printf("%s", s);
306	free(s);
307}
308
309/*
310 * Set the value of tmp to the value of smp (i.e., tmp=smp).
311 */
312void
313move(const MINT *smp, MINT *tmp)
314{
315
316	_movem("move", smp, tmp);
317}
318
319
320/*
321 * Internal routine to set the value of tmp to that of sbp.
322 */
323static void
324_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
325{
326
327	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
328}
329
330/*
331 * Internal routine to set the value of tmp to that of smp.
332 */
333static void
334_movem(const char *msg, const MINT *smp, MINT *tmp)
335{
336
337	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
338}
339
340/*
341 * Compute the square root of nmp and put the result in xmp.  The
342 * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
343 *
344 * Note that the OpenSSL BIGNUM library does not have a square root
345 * function, so this had to be implemented by hand using Newton's
346 * recursive formula:
347 *
348 *		x = (x + (n / x)) / 2
349 *
350 * where x is the square root of the positive number n.  In the
351 * beginning, x should be a reasonable guess, but the value 1,
352 * although suboptimal, works, too; this is that is used below.
353 */
354void
355msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
356{
357	MINT *tolerance;
358	MINT *ox, *x;
359	MINT *z1, *z2, *z3;
360	short i;
361
362	tolerance = _itom("msqrt", 1);
363	x = _itom("msqrt", 1);
364	ox = _itom("msqrt", 0);
365	z1 = _itom("msqrt", 0);
366	z2 = _itom("msqrt", 0);
367	z3 = _itom("msqrt", 0);
368	do {
369		_movem("msqrt", x, ox);
370		_mdiv("msqrt", nmp, x, z1, z2);
371		_madd("msqrt", x, z1, z2);
372		_sdiv("msqrt", z2, 2, x, &i);
373		_msub("msqrt", ox, x, z3);
374	} while (_mcmpa("msqrt", z3, tolerance) == 1);
375	_movem("msqrt", x, xmp);
376	_mult("msqrt", x, x, z1);
377	_msub("msqrt", nmp, z1, z2);
378	_movem("msqrt", z2, rmp);
379	_mfree("msqrt", tolerance);
380	_mfree("msqrt", ox);
381	_mfree("msqrt", x);
382	_mfree("msqrt", z1);
383	_mfree("msqrt", z2);
384	_mfree("msqrt", z3);
385}
386
387/*
388 * Compute rmp=mp1-mp2.
389 */
390static void
391_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
392{
393	BIGNUM b;
394
395	BN_init(&b);
396	BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
397	_moveb(msg, &b, rmp);
398	BN_free(&b);
399}
400
401void
402msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
403{
404
405	_msub("msub", mp1, mp2, rmp);
406}
407
408/*
409 * Return a decimal representation of mp.  Return value must be
410 * free()'d.
411 */
412static char *
413_mtod(const char *msg, const MINT *mp)
414{
415	char *s, *s2;
416
417	s = BN_bn2dec(mp->bn);
418	if (s == NULL)
419		_bnerr(msg);
420	asprintf(&s2, "%s", s);
421	if (s2 == NULL)
422		MPERR(("%s", msg));
423	OPENSSL_free(s);
424	return (s2);
425}
426
427/*
428 * Return a hexadecimal representation of mp.  Return value must be
429 * free()'d.
430 */
431static char *
432_mtox(const char *msg, const MINT *mp)
433{
434	char *p, *s, *s2;
435	int len;
436
437	s = BN_bn2hex(mp->bn);
438	if (s == NULL)
439		_bnerr(msg);
440	asprintf(&s2, "%s", s);
441	if (s2 == NULL)
442		MPERR(("%s", msg));
443	OPENSSL_free(s);
444
445	/*
446	 * This is a kludge for libgmp compatibility.  The latter's
447	 * implementation of this function returns lower-case letters,
448	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
449	 * newkey(1)) are sensitive to this.  Although it's probably
450	 * their fault, it's nice to be compatible.
451	 */
452	len = strlen(s2);
453	for (p = s2; p < s2 + len; p++)
454		*p = tolower(*p);
455
456	return (s2);
457}
458
459char *
460mtox(const MINT *mp)
461{
462
463	return (_mtox("mtox", mp));
464}
465
466/*
467 * Compute rmp=mp1*mp2.
468 */
469static void
470_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
471{
472	BIGNUM b;
473	BN_CTX c;
474
475	BN_CTX_init(&c);
476	BN_init(&b);
477	BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, &c));
478	_moveb(msg, &b, rmp);
479	BN_free(&b);
480	BN_CTX_free(&c);
481}
482
483void
484mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
485{
486
487	_mult("mult", mp1, mp2, rmp);
488}
489
490/*
491 * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
492 * means 'raise to power', not 'bitwise XOR'.)
493 */
494void
495pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
496{
497	BIGNUM b;
498	BN_CTX c;
499
500	BN_CTX_init(&c);
501	BN_init(&b);
502	BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, &c));
503	_moveb("pow", &b, rmp);
504	BN_free(&b);
505	BN_CTX_free(&c);
506}
507
508/*
509 * Compute rmp=bmp^e.  (See note above pow().)
510 */
511void
512rpow(const MINT *bmp, short e, MINT *rmp)
513{
514	MINT *emp;
515	BIGNUM b;
516	BN_CTX c;
517
518	BN_CTX_init(&c);
519	BN_init(&b);
520	emp = _itom("rpow", e);
521	BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, &c));
522	_moveb("rpow", &b, rmp);
523	_mfree("rpow", emp);
524	BN_free(&b);
525	BN_CTX_free(&c);
526}
527
528/*
529 * Compute qmp=nmp/d and ro=nmp%d.
530 */
531static void
532_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro)
533{
534	MINT *dmp, *rmp;
535	BIGNUM q, r;
536	BN_CTX c;
537	char *s;
538
539	BN_CTX_init(&c);
540	BN_init(&q);
541	BN_init(&r);
542	dmp = _itom(msg, d);
543	rmp = _itom(msg, 0);
544	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, &c));
545	_moveb(msg, &q, qmp);
546	_moveb(msg, &r, rmp);
547	s = _mtox(msg, rmp);
548	errno = 0;
549	*ro = strtol(s, NULL, 16);
550	if (errno != 0)
551		MPERR(("%s underflow or overflow", msg));
552	free(s);
553	_mfree(msg, dmp);
554	_mfree(msg, rmp);
555	BN_free(&r);
556	BN_free(&q);
557	BN_CTX_free(&c);
558}
559
560void
561sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
562{
563
564	_sdiv("sdiv", nmp, d, qmp, ro);
565}
566
567/*
568 * Convert a hexadecimal string to an MINT.
569 */
570static MINT *
571_xtom(const char *msg, const char *s)
572{
573	MINT *mp;
574
575	mp = malloc(sizeof(*mp));
576	if (mp == NULL)
577		MPERR(("%s", msg));
578	mp->bn = BN_new();
579	if (mp->bn == NULL)
580		_bnerr(msg);
581	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
582	return (mp);
583}
584
585MINT *
586xtom(const char *s)
587{
588
589	return (_xtom("xtom", s));
590}
591