1/*
2 * Copyright (c) 2006 - 2007 Kungliga Tekniska Högskolan
3 * (Royal Institute of Technology, Stockholm, Sweden).
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * 3. Neither the name of the Institute nor the names of its contributors
18 *    may be used to endorse or promote products derived from this software
19 *    without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31 * SUCH DAMAGE.
32 */
33
#include <config.h>

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <krb5-types.h>

#include <rsa.h>

#include <roken.h>
44
45#ifdef HAVE_GMP
46
47#include <gmp.h>
48
49static void
50BN2mpz(mpz_t s, const BIGNUM *bn)
51{
52    size_t len;
53    void *p;
54
55    len = BN_num_bytes(bn);
56    p = malloc(len);
57    BN_bn2bin(bn, p);
58    mpz_init(s);
59    mpz_import(s, len, 1, 1, 1, 0, p);
60
61    free(p);
62}
63
64
65static BIGNUM *
66mpz2BN(mpz_t s)
67{
68    size_t size;
69    BIGNUM *bn;
70    void *p;
71
72    mpz_export(NULL, &size, 1, 1, 1, 0, s);
73    p = malloc(size);
74    if (p == NULL && size != 0)
75	return NULL;
76    mpz_export(p, &size, 1, 1, 1, 0, s);
77    bn = BN_bin2bn(p, size, NULL);
78    free(p);
79    return bn;
80}
81
82static int
83rsa_private_calculate(mpz_t in, mpz_t p,  mpz_t q,
84		      mpz_t dmp1, mpz_t dmq1, mpz_t iqmp,
85		      mpz_t out)
86{
87    mpz_t vp, vq, u;
88    mpz_init(vp); mpz_init(vq); mpz_init(u);
89
90    /* vq = c ^ (d mod (q - 1)) mod q */
91    /* vp = c ^ (d mod (p - 1)) mod p */
92    mpz_fdiv_r(vp, in, p);
93    mpz_powm(vp, vp, dmp1, p);
94    mpz_fdiv_r(vq, in, q);
95    mpz_powm(vq, vq, dmq1, q);
96
97    /* C2 = 1/q mod p  (iqmp) */
98    /* u = (vp - vq)C2 mod p. */
99    mpz_sub(u, vp, vq);
100#if 0
101    if (mp_int_compare_zero(&u) < 0)
102	mp_int_add(&u, p, &u);
103#endif
104    mpz_mul(u, iqmp, u);
105    mpz_fdiv_r(u, u, p);
106
107    /* c ^ d mod n = vq + u q */
108    mpz_mul(u, q, u);
109    mpz_add(out, u, vq);
110
111    mpz_clear(vp);
112    mpz_clear(vq);
113    mpz_clear(u);
114
115    return 0;
116}
117
/*
 * PKCS #1 v1.5 RSA primitives (only RSA_PKCS1_PADDING is supported).
 */
121
122static int
123gmp_rsa_public_encrypt(int flen, const unsigned char* from,
124			unsigned char* to, RSA* rsa, int padding)
125{
126    unsigned char *p, *p0;
127    size_t size, padlen;
128    mpz_t enc, dec, n, e;
129
130    if (padding != RSA_PKCS1_PADDING)
131	return -1;
132
133    size = RSA_size(rsa);
134
135    if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
136	return -2;
137
138    BN2mpz(n, rsa->n);
139    BN2mpz(e, rsa->e);
140
141    p = p0 = malloc(size - 1);
142    if (p0 == NULL) {
143	mpz_clear(e);
144	mpz_clear(n);
145	return -3;
146    }
147
148    padlen = size - flen - 3;
149    assert(padlen >= 8);
150
151    *p++ = 2;
152    if (CCRandomCopyBytes(kCCRandomDefault,p, padlen) != 0) {
153	mpz_clear(e);
154	mpz_clear(n);
155	free(p0);
156	return -4;
157    }
158    while(padlen) {
159	if (*p == 0)
160	    *p = 1;
161	padlen--;
162	p++;
163    }
164    *p++ = 0;
165    memcpy(p, from, flen);
166    p += flen;
167    assert((p - p0) == size - 1);
168
169    mpz_init(enc);
170    mpz_init(dec);
171    mpz_import(dec, size - 1, 1, 1, 1, 0, p0);
172    free(p0);
173
174    mpz_powm(enc, dec, e, n);
175
176    mpz_clear(dec);
177    mpz_clear(e);
178    mpz_clear(n);
179    {
180	size_t ssize;
181	mpz_export(to, &ssize, 1, 1, 1, 0, enc);
182	assert(size >= ssize);
183	size = ssize;
184    }
185    mpz_clear(enc);
186
187    return size;
188}
189
190static int
191gmp_rsa_public_decrypt(int flen, const unsigned char* from,
192			 unsigned char* to, RSA* rsa, int padding)
193{
194    unsigned char *p;
195    size_t size;
196    mpz_t s, us, n, e;
197
198    if (padding != RSA_PKCS1_PADDING)
199	return -1;
200
201    if (flen > RSA_size(rsa))
202	return -2;
203
204    BN2mpz(n, rsa->n);
205    BN2mpz(e, rsa->e);
206
207#if 0
208    /* Check that the exponent is larger then 3 */
209    if (mp_int_compare_value(&e, 3) <= 0) {
210	mp_int_clear(&n);
211	mp_int_clear(&e);
212	return -3;
213    }
214#endif
215
216    mpz_init(s);
217    mpz_init(us);
218    mpz_import(s, flen, 1, 1, 1, 0, rk_UNCONST(from));
219
220    if (mpz_cmp(s, n) >= 0) {
221	mpz_clear(n);
222	mpz_clear(e);
223	return -4;
224    }
225
226    mpz_powm(us, s, e, n);
227
228    mpz_clear(s);
229    mpz_clear(n);
230    mpz_clear(e);
231
232    p = to;
233
234    mpz_export(p, &size, 1, 1, 1, 0, us);
235    assert(size <= RSA_size(rsa));
236
237    mpz_clear(us);
238
239    /* head zero was skipped by mp_int_to_unsigned */
240    if (*p == 0)
241	return -6;
242    if (*p != 1)
243	return -7;
244    size--; p++;
245    while (size && *p == 0xff) {
246	size--; p++;
247    }
248    if (size == 0 || *p != 0)
249	return -8;
250    size--; p++;
251
252    memmove(to, p, size);
253
254    return size;
255}
256
257static int
258gmp_rsa_private_encrypt(int flen, const unsigned char* from,
259			  unsigned char* to, RSA* rsa, int padding)
260{
261    unsigned char *p, *p0;
262    size_t size;
263    mpz_t in, out, n, e;
264
265    if (padding != RSA_PKCS1_PADDING)
266	return -1;
267
268    size = RSA_size(rsa);
269
270    if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
271	return -2;
272
273    p0 = p = malloc(size);
274    *p++ = 0;
275    *p++ = 1;
276    memset(p, 0xff, size - flen - 3);
277    p += size - flen - 3;
278    *p++ = 0;
279    memcpy(p, from, flen);
280    p += flen;
281    assert((p - p0) == size);
282
283    BN2mpz(n, rsa->n);
284    BN2mpz(e, rsa->e);
285
286    mpz_init(in);
287    mpz_init(out);
288    mpz_import(in, size, 1, 1, 1, 0, p0);
289    free(p0);
290
291#if 0
292    if(mp_int_compare_zero(&in) < 0 ||
293       mp_int_compare(&in, &n) >= 0) {
294	size = 0;
295	goto out;
296    }
297#endif
298
299    if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
300	mpz_t p, q, dmp1, dmq1, iqmp;
301
302	BN2mpz(p, rsa->p);
303	BN2mpz(q, rsa->q);
304	BN2mpz(dmp1, rsa->dmp1);
305	BN2mpz(dmq1, rsa->dmq1);
306	BN2mpz(iqmp, rsa->iqmp);
307
308	rsa_private_calculate(in, p, q, dmp1, dmq1, iqmp, out);
309
310	mpz_clear(p);
311	mpz_clear(q);
312	mpz_clear(dmp1);
313	mpz_clear(dmq1);
314	mpz_clear(iqmp);
315    } else {
316	mpz_t d;
317
318	BN2mpz(d, rsa->d);
319	mpz_powm(out, in, d, n);
320	mpz_clear(d);
321    }
322
323    {
324	size_t ssize;
325	mpz_export(to, &ssize, 1, 1, 1, 0, out);
326	assert(size >= ssize);
327	size = ssize;
328    }
329
330    mpz_clear(e);
331    mpz_clear(n);
332    mpz_clear(in);
333    mpz_clear(out);
334
335    return size;
336}
337
338static int
339gmp_rsa_private_decrypt(int flen, const unsigned char* from,
340			  unsigned char* to, RSA* rsa, int padding)
341{
342    unsigned char *ptr;
343    size_t size;
344    mpz_t in, out, n, e;
345
346    if (padding != RSA_PKCS1_PADDING)
347	return -1;
348
349    size = RSA_size(rsa);
350    if (flen > size)
351	return -2;
352
353    mpz_init(in);
354    mpz_init(out);
355
356    BN2mpz(n, rsa->n);
357    BN2mpz(e, rsa->e);
358
359    mpz_import(in, flen, 1, 1, 1, 0, from);
360
361    if(mpz_cmp_ui(in, 0) < 0 ||
362       mpz_cmp(in, n) >= 0) {
363	size = 0;
364	goto out;
365    }
366
367    if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
368	mpz_t p, q, dmp1, dmq1, iqmp;
369
370	BN2mpz(p, rsa->p);
371	BN2mpz(q, rsa->q);
372	BN2mpz(dmp1, rsa->dmp1);
373	BN2mpz(dmq1, rsa->dmq1);
374	BN2mpz(iqmp, rsa->iqmp);
375
376	rsa_private_calculate(in, p, q, dmp1, dmq1, iqmp, out);
377
378	mpz_clear(p);
379	mpz_clear(q);
380	mpz_clear(dmp1);
381	mpz_clear(dmq1);
382	mpz_clear(iqmp);
383    } else {
384	mpz_t d;
385
386#if 0
387	if(mp_int_compare_zero(&in) < 0 ||
388	   mp_int_compare(&in, &n) >= 0)
389	    return MP_RANGE;
390#endif
391
392	BN2mpz(d, rsa->d);
393	mpz_powm(out, in, d, n);
394	mpz_clear(d);
395    }
396
397    ptr = to;
398    {
399	size_t ssize;
400	mpz_export(ptr, &ssize, 1, 1, 1, 0, out);
401	assert(size >= ssize);
402	size = ssize;
403    }
404
405    /* head zero was skipped by mp_int_to_unsigned */
406    if (*ptr != 2)
407	return -3;
408    size--; ptr++;
409    while (size && *ptr != 0) {
410	size--; ptr++;
411    }
412    if (size == 0)
413	return -4;
414    size--; ptr++;
415
416    memmove(to, ptr, size);
417
418out:
419    mpz_clear(e);
420    mpz_clear(n);
421    mpz_clear(in);
422    mpz_clear(out);
423
424    return size;
425}
426
427static int
428random_num(mpz_t num, size_t len)
429{
430    unsigned char *p;
431
432    len = (len + 7) / 8;
433    p = malloc(len);
434    if (p == NULL)
435	return 1;
436    if (CCRandomCopyBytes(kCCRandomDefault, p, len) != 0) {
437	free(p);
438	return 1;
439    }
440    mpz_import(num, len, 1, 1, 1, 0, p);
441    free(p);
442    return 0;
443}
444
445
446static int
447gmp_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
448{
449    mpz_t el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
450    int counter, ret;
451
452    if (bits < 789)
453	return -1;
454
455    ret = -1;
456
457    mpz_init(el);
458    mpz_init(p);
459    mpz_init(q);
460    mpz_init(n);
461    mpz_init(d);
462    mpz_init(dmp1);
463    mpz_init(dmq1);
464    mpz_init(iqmp);
465    mpz_init(t1);
466    mpz_init(t2);
467    mpz_init(t3);
468
469    BN2mpz(el, e);
470
471    /* generate p and q so that p != q and bits(pq) ~ bits */
472
473    counter = 0;
474    do {
475	BN_GENCB_call(cb, 2, counter++);
476	random_num(p, bits / 2 + 1);
477	mpz_nextprime(p, p);
478
479	mpz_sub_ui(t1, p, 1);
480	mpz_gcd(t2, t1, el);
481    } while(mpz_cmp_ui(t2, 1) != 0);
482
483    BN_GENCB_call(cb, 3, 0);
484
485    counter = 0;
486    do {
487	BN_GENCB_call(cb, 2, counter++);
488	random_num(q, bits / 2 + 1);
489	mpz_nextprime(q, q);
490
491	mpz_sub_ui(t1, q, 1);
492	mpz_gcd(t2, t1, el);
493    } while(mpz_cmp_ui(t2, 1) != 0);
494
495    /* make p > q */
496    if (mpz_cmp(p, q) < 0)
497	mpz_swap(p, q);
498
499    BN_GENCB_call(cb, 3, 1);
500
501    /* calculate n,  		n = p * q */
502    mpz_mul(n, p, q);
503
504    /* calculate d, 		d = 1/e mod (p - 1)(q - 1) */
505    mpz_sub_ui(t1, p, 1);
506    mpz_sub_ui(t2, q, 1);
507    mpz_mul(t3, t1, t2);
508    mpz_invert(d, el, t3);
509
510    /* calculate dmp1		dmp1 = d mod (p-1) */
511    mpz_mod(dmp1, d, t1);
512    /* calculate dmq1		dmq1 = d mod (q-1) */
513    mpz_mod(dmq1, d, t2);
514    /* calculate iqmp 		iqmp = 1/q mod p */
515    mpz_invert(iqmp, q, p);
516
517    /* fill in RSA key */
518
519    rsa->e = mpz2BN(el);
520    rsa->p = mpz2BN(p);
521    rsa->q = mpz2BN(q);
522    rsa->n = mpz2BN(n);
523    rsa->d = mpz2BN(d);
524    rsa->dmp1 = mpz2BN(dmp1);
525    rsa->dmq1 = mpz2BN(dmq1);
526    rsa->iqmp = mpz2BN(iqmp);
527
528    ret = 1;
529
530    mpz_clear(el);
531    mpz_clear(p);
532    mpz_clear(q);
533    mpz_clear(n);
534    mpz_clear(d);
535    mpz_clear(dmp1);
536    mpz_clear(dmq1);
537    mpz_clear(iqmp);
538    mpz_clear(t1);
539    mpz_clear(t2);
540    mpz_clear(t3);
541
542    return ret;
543}
544
545static int
546gmp_rsa_init(RSA *rsa)
547{
548    return 1;
549}
550
551static int
552gmp_rsa_finish(RSA *rsa)
553{
554    return 1;
555}
556
/*
 * Method table that routes hcrypto RSA operations to the GMP-backed
 * functions in this file.  Entries are positional; the slot labels
 * below are inferred from the function names and the customary
 * RSA_METHOD layout.
 * NOTE(review): confirm the labels against struct RSA_METHOD in rsa.h.
 */
const RSA_METHOD _hc_rsa_gmp_method = {
    "hcrypto GMP RSA",		/* method name */
    gmp_rsa_public_encrypt,
    gmp_rsa_public_decrypt,
    gmp_rsa_private_encrypt,
    gmp_rsa_private_decrypt,
    NULL,			/* presumably mod_exp: not provided */
    NULL,			/* presumably bn_mod_exp: not provided */
    gmp_rsa_init,
    gmp_rsa_finish,
    0,				/* presumably flags */
    NULL,			/* presumably app_data */
    NULL,			/* presumably rsa_sign: not provided */
    NULL,			/* presumably rsa_verify: not provided */
    gmp_rsa_generate_key
};
573
574#endif /* HAVE_GMP */
575
576/**
577 * RSA implementation using Gnu Multipresistion Library.
578 */
579
580const RSA_METHOD *
581RSA_gmp_method(void)
582{
583#ifdef HAVE_GMP
584    return &_hc_rsa_gmp_method;
585#else
586    return NULL;
587#endif
588}
589