1/*	$NetBSD: rsa-ltm.c,v 1.1.1.1 2011/04/13 18:14:51 elric Exp $	*/
2
3/*
4 * Copyright (c) 2006 - 2007, 2010 Kungliga Tekniska Högskolan
5 * (Royal Institute of Technology, Stockholm, Sweden).
6 * All rights reserved.
7 *
8 * Redistribution and use in source and binary forms, with or without
9 * modification, are permitted provided that the following conditions
10 * are met:
11 *
12 * 1. Redistributions of source code must retain the above copyright
13 *    notice, this list of conditions and the following disclaimer.
14 *
15 * 2. Redistributions in binary form must reproduce the above copyright
16 *    notice, this list of conditions and the following disclaimer in the
17 *    documentation and/or other materials provided with the distribution.
18 *
19 * 3. Neither the name of the Institute nor the names of its contributors
20 *    may be used to endorse or promote products derived from this software
21 *    without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
24 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
26 * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
27 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
29 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
30 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
32 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
33 * SUCH DAMAGE.
34 */
35
36#include <config.h>
37
38#include <stdio.h>
39#include <stdlib.h>
40#include <krb5/krb5-types.h>
41#include <assert.h>
42
43#include <rsa.h>
44
45#include <krb5/roken.h>
46
47#include "tommath.h"
48
49static int
50random_num(mp_int *num, size_t len)
51{
52    unsigned char *p;
53
54    len = (len + 7) / 8;
55    p = malloc(len);
56    if (p == NULL)
57	return 1;
58    if (RAND_bytes(p, len) != 1) {
59	free(p);
60	return 1;
61    }
62    mp_read_unsigned_bin(num, p, len);
63    free(p);
64    return 0;
65}
66
67static void
68BN2mpz(mp_int *s, const BIGNUM *bn)
69{
70    size_t len;
71    void *p;
72
73    len = BN_num_bytes(bn);
74    p = malloc(len);
75    BN_bn2bin(bn, p);
76    mp_read_unsigned_bin(s, p, len);
77    free(p);
78}
79
80static void
81setup_blind(mp_int *n, mp_int *b, mp_int *bi)
82{
83    random_num(b, mp_count_bits(n));
84    mp_mod(b, n, b);
85    mp_invmod(b, n, bi);
86}
87
88static void
89blind(mp_int *in, mp_int *b, mp_int *e, mp_int *n)
90{
91    mp_int t1;
92    mp_init(&t1);
93    /* in' = (in * b^e) mod n */
94    mp_exptmod(b, e, n, &t1);
95    mp_mul(&t1, in, in);
96    mp_mod(in, n, in);
97    mp_clear(&t1);
98}
99
100static void
101unblind(mp_int *out, mp_int *bi, mp_int *n)
102{
103    /* out' = (out * 1/b) mod n */
104    mp_mul(out, bi, out);
105    mp_mod(out, n, out);
106}
107
108static int
109ltm_rsa_private_calculate(mp_int * in, mp_int * p,  mp_int * q,
110			  mp_int * dmp1, mp_int * dmq1, mp_int * iqmp,
111			  mp_int * out)
112{
113    mp_int vp, vq, u;
114
115    mp_init_multi(&vp, &vq, &u, NULL);
116
117    /* vq = c ^ (d mod (q - 1)) mod q */
118    /* vp = c ^ (d mod (p - 1)) mod p */
119    mp_mod(in, p, &u);
120    mp_exptmod(&u, dmp1, p, &vp);
121    mp_mod(in, q, &u);
122    mp_exptmod(&u, dmq1, q, &vq);
123
124    /* C2 = 1/q mod p  (iqmp) */
125    /* u = (vp - vq)C2 mod p. */
126    mp_sub(&vp, &vq, &u);
127    if (mp_isneg(&u))
128	mp_add(&u, p, &u);
129    mp_mul(&u, iqmp, &u);
130    mp_mod(&u, p, &u);
131
132    /* c ^ d mod n = vq + u q */
133    mp_mul(&u, q, &u);
134    mp_add(&u, &vq, out);
135
136    mp_clear_multi(&vp, &vq, &u, NULL);
137
138    return 0;
139}
140
/*
 * RSA_METHOD callbacks.  Only RSA_PKCS1_PADDING is supported.
 */
144
145static int
146ltm_rsa_public_encrypt(int flen, const unsigned char* from,
147			unsigned char* to, RSA* rsa, int padding)
148{
149    unsigned char *p, *p0;
150    int res;
151    size_t size, padlen;
152    mp_int enc, dec, n, e;
153
154    if (padding != RSA_PKCS1_PADDING)
155	return -1;
156
157    mp_init_multi(&n, &e, &enc, &dec, NULL);
158
159    size = RSA_size(rsa);
160
161    if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen) {
162	mp_clear_multi(&n, &e, &enc, &dec);
163	return -2;
164    }
165
166    BN2mpz(&n, rsa->n);
167    BN2mpz(&e, rsa->e);
168
169    p = p0 = malloc(size - 1);
170    if (p0 == NULL) {
171	mp_clear_multi(&e, &n, &enc, &dec, NULL);
172	return -3;
173    }
174
175    padlen = size - flen - 3;
176
177    *p++ = 2;
178    if (RAND_bytes(p, padlen) != 1) {
179	mp_clear_multi(&e, &n, &enc, &dec, NULL);
180	free(p0);
181	return -4;
182    }
183    while(padlen) {
184	if (*p == 0)
185	    *p = 1;
186	padlen--;
187	p++;
188    }
189    *p++ = 0;
190    memcpy(p, from, flen);
191    p += flen;
192    assert((p - p0) == size - 1);
193
194    mp_read_unsigned_bin(&dec, p0, size - 1);
195    free(p0);
196
197    res = mp_exptmod(&dec, &e, &n, &enc);
198
199    mp_clear_multi(&dec, &e, &n, NULL);
200
201    if (res != 0) {
202	mp_clear(&enc);
203	return -4;
204    }
205
206    {
207	size_t ssize;
208	ssize = mp_unsigned_bin_size(&enc);
209	assert(size >= ssize);
210	mp_to_unsigned_bin(&enc, to);
211	size = ssize;
212    }
213    mp_clear(&enc);
214
215    return size;
216}
217
218static int
219ltm_rsa_public_decrypt(int flen, const unsigned char* from,
220		       unsigned char* to, RSA* rsa, int padding)
221{
222    unsigned char *p;
223    int res;
224    size_t size;
225    mp_int s, us, n, e;
226
227    if (padding != RSA_PKCS1_PADDING)
228	return -1;
229
230    if (flen > RSA_size(rsa))
231	return -2;
232
233    mp_init_multi(&e, &n, &s, &us, NULL);
234
235    BN2mpz(&n, rsa->n);
236    BN2mpz(&e, rsa->e);
237
238#if 0
239    /* Check that the exponent is larger then 3 */
240    if (mp_int_compare_value(&e, 3) <= 0) {
241	mp_clear_multi(&e, &n, &s, &us, NULL);
242	return -3;
243    }
244#endif
245
246    mp_read_unsigned_bin(&s, rk_UNCONST(from), flen);
247
248    if (mp_cmp(&s, &n) >= 0) {
249	mp_clear_multi(&e, &n, &s, &us, NULL);
250	return -4;
251    }
252
253    res = mp_exptmod(&s, &e, &n, &us);
254
255    mp_clear_multi(&e, &n, &s, NULL);
256
257    if (res != 0) {
258	mp_clear(&us);
259	return -5;
260    }
261    p = to;
262
263
264    size = mp_unsigned_bin_size(&us);
265    assert(size <= RSA_size(rsa));
266    mp_to_unsigned_bin(&us, p);
267
268    mp_clear(&us);
269
270    /* head zero was skipped by mp_to_unsigned_bin */
271    if (*p == 0)
272	return -6;
273    if (*p != 1)
274	return -7;
275    size--; p++;
276    while (size && *p == 0xff) {
277	size--; p++;
278    }
279    if (size == 0 || *p != 0)
280	return -8;
281    size--; p++;
282
283    memmove(to, p, size);
284
285    return size;
286}
287
288static int
289ltm_rsa_private_encrypt(int flen, const unsigned char* from,
290			unsigned char* to, RSA* rsa, int padding)
291{
292    unsigned char *p, *p0;
293    int res;
294    int size;
295    mp_int in, out, n, e;
296    mp_int bi, b;
297    int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
298    int do_unblind = 0;
299
300    if (padding != RSA_PKCS1_PADDING)
301	return -1;
302
303    mp_init_multi(&e, &n, &in, &out, &b, &bi, NULL);
304
305    size = RSA_size(rsa);
306
307    if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
308	return -2;
309
310    p0 = p = malloc(size);
311    *p++ = 0;
312    *p++ = 1;
313    memset(p, 0xff, size - flen - 3);
314    p += size - flen - 3;
315    *p++ = 0;
316    memcpy(p, from, flen);
317    p += flen;
318    assert((p - p0) == size);
319
320    BN2mpz(&n, rsa->n);
321    BN2mpz(&e, rsa->e);
322
323    mp_read_unsigned_bin(&in, p0, size);
324    free(p0);
325
326    if(mp_isneg(&in) || mp_cmp(&in, &n) >= 0) {
327	size = -3;
328	goto out;
329    }
330
331    if (blinding) {
332	setup_blind(&n, &b, &bi);
333	blind(&in, &b, &e, &n);
334	do_unblind = 1;
335    }
336
337    if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
338	mp_int p, q, dmp1, dmq1, iqmp;
339
340	mp_init_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
341
342	BN2mpz(&p, rsa->p);
343	BN2mpz(&q, rsa->q);
344	BN2mpz(&dmp1, rsa->dmp1);
345	BN2mpz(&dmq1, rsa->dmq1);
346	BN2mpz(&iqmp, rsa->iqmp);
347
348	res = ltm_rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
349
350	mp_clear_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
351
352	if (res != 0) {
353	    size = -4;
354	    goto out;
355	}
356    } else {
357	mp_int d;
358
359	BN2mpz(&d, rsa->d);
360	res = mp_exptmod(&in, &d, &n, &out);
361	mp_clear(&d);
362	if (res != 0) {
363	    size = -5;
364	    goto out;
365	}
366    }
367
368    if (do_unblind)
369	unblind(&out, &bi, &n);
370
371    if (size > 0) {
372	size_t ssize;
373	ssize = mp_unsigned_bin_size(&out);
374	assert(size >= ssize);
375	mp_to_unsigned_bin(&out, to);
376	size = ssize;
377    }
378
379 out:
380    mp_clear_multi(&e, &n, &in, &out, &b, &bi, NULL);
381
382    return size;
383}
384
385static int
386ltm_rsa_private_decrypt(int flen, const unsigned char* from,
387			unsigned char* to, RSA* rsa, int padding)
388{
389    unsigned char *ptr;
390    int res, size;
391    mp_int in, out, n, e, b, bi;
392    int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
393    int do_unblind = 0;
394
395    if (padding != RSA_PKCS1_PADDING)
396	return -1;
397
398    size = RSA_size(rsa);
399    if (flen > size)
400	return -2;
401
402    mp_init_multi(&in, &n, &e, &out, &b, &bi, NULL);
403
404    BN2mpz(&n, rsa->n);
405    BN2mpz(&e, rsa->e);
406
407    mp_read_unsigned_bin(&in, rk_UNCONST(from), flen);
408
409    if(mp_isneg(&in) || mp_cmp(&in, &n) >= 0) {
410	size = -2;
411	goto out;
412    }
413
414    if (blinding) {
415	setup_blind(&n, &b, &bi);
416	blind(&in, &b, &e, &n);
417	do_unblind = 1;
418    }
419
420    if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
421	mp_int p, q, dmp1, dmq1, iqmp;
422
423	mp_init_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
424
425	BN2mpz(&p, rsa->p);
426	BN2mpz(&q, rsa->q);
427	BN2mpz(&dmp1, rsa->dmp1);
428	BN2mpz(&dmq1, rsa->dmq1);
429	BN2mpz(&iqmp, rsa->iqmp);
430
431	res = ltm_rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
432
433	mp_clear_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
434
435	if (res != 0) {
436	    size = -3;
437	    goto out;
438	}
439
440    } else {
441	mp_int d;
442
443	if(mp_isneg(&in) || mp_cmp(&in, &n) >= 0)
444	    return -4;
445
446	BN2mpz(&d, rsa->d);
447	res = mp_exptmod(&in, &d, &n, &out);
448	mp_clear(&d);
449	if (res != 0) {
450	    size = -5;
451	    goto out;
452	}
453    }
454
455    if (do_unblind)
456	unblind(&out, &bi, &n);
457
458    ptr = to;
459    {
460	size_t ssize;
461	ssize = mp_unsigned_bin_size(&out);
462	assert(size >= ssize);
463	mp_to_unsigned_bin(&out, ptr);
464	size = ssize;
465    }
466
467    /* head zero was skipped by mp_int_to_unsigned */
468    if (*ptr != 2) {
469	size = -6;
470	goto out;
471    }
472    size--; ptr++;
473    while (size && *ptr != 0) {
474	size--; ptr++;
475    }
476    if (size == 0)
477	return -7;
478    size--; ptr++;
479
480    memmove(to, ptr, size);
481
482 out:
483    mp_clear_multi(&e, &n, &in, &out, &b, &bi, NULL);
484
485    return size;
486}
487
488static BIGNUM *
489mpz2BN(mp_int *s)
490{
491    size_t size;
492    BIGNUM *bn;
493    void *p;
494
495    size = mp_unsigned_bin_size(s);
496    p = malloc(size);
497    if (p == NULL && size != 0)
498	return NULL;
499
500    mp_to_unsigned_bin(s, p);
501
502    bn = BN_bin2bn(p, size, NULL);
503    free(p);
504    return bn;
505}
506
/*
 * CHECK(f, v): evaluate f once and jump to the enclosing function's `out`
 * label unless it equals v.  Wrapped in do { } while (0) so the macro acts
 * as a single statement and is safe inside unbraced if/else.
 */
#define CHECK(f, v) do { if ((f) != (v)) { goto out; } } while (0)
508
509static int
510ltm_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
511{
512    mp_int el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
513    int counter, ret, bitsp;
514
515    if (bits < 789)
516	return -1;
517
518    bitsp = (bits + 1) / 2;
519
520    ret = -1;
521
522    mp_init_multi(&el, &p, &q, &n, &d,
523		  &dmp1, &dmq1, &iqmp,
524		  &t1, &t2, &t3, NULL);
525
526    BN2mpz(&el, e);
527
528    /* generate p and q so that p != q and bits(pq) ~ bits */
529    counter = 0;
530    do {
531	BN_GENCB_call(cb, 2, counter++);
532	CHECK(random_num(&p, bitsp), 0);
533	CHECK(mp_find_prime(&p), MP_YES);
534
535	mp_sub_d(&p, 1, &t1);
536	mp_gcd(&t1, &el, &t2);
537    } while(mp_cmp_d(&t2, 1) != 0);
538
539    BN_GENCB_call(cb, 3, 0);
540
541    counter = 0;
542    do {
543	BN_GENCB_call(cb, 2, counter++);
544	CHECK(random_num(&q, bits - bitsp), 0);
545	CHECK(mp_find_prime(&q), MP_YES);
546
547	if (mp_cmp(&p, &q) == 0) /* don't let p and q be the same */
548	    continue;
549
550	mp_sub_d(&q, 1, &t1);
551	mp_gcd(&t1, &el, &t2);
552    } while(mp_cmp_d(&t2, 1) != 0);
553
554    /* make p > q */
555    if (mp_cmp(&p, &q) < 0) {
556	mp_int c;
557	c = p;
558	p = q;
559	q = c;
560    }
561
562    BN_GENCB_call(cb, 3, 1);
563
564    /* calculate n,  		n = p * q */
565    mp_mul(&p, &q, &n);
566
567    /* calculate d, 		d = 1/e mod (p - 1)(q - 1) */
568    mp_sub_d(&p, 1, &t1);
569    mp_sub_d(&q, 1, &t2);
570    mp_mul(&t1, &t2, &t3);
571    mp_invmod(&el, &t3, &d);
572
573    /* calculate dmp1		dmp1 = d mod (p-1) */
574    mp_mod(&d, &t1, &dmp1);
575    /* calculate dmq1		dmq1 = d mod (q-1) */
576    mp_mod(&d, &t2, &dmq1);
577    /* calculate iqmp 		iqmp = 1/q mod p */
578    mp_invmod(&q, &p, &iqmp);
579
580    /* fill in RSA key */
581
582    rsa->e = mpz2BN(&el);
583    rsa->p = mpz2BN(&p);
584    rsa->q = mpz2BN(&q);
585    rsa->n = mpz2BN(&n);
586    rsa->d = mpz2BN(&d);
587    rsa->dmp1 = mpz2BN(&dmp1);
588    rsa->dmq1 = mpz2BN(&dmq1);
589    rsa->iqmp = mpz2BN(&iqmp);
590
591    ret = 1;
592
593out:
594    mp_clear_multi(&el, &p, &q, &n, &d,
595		   &dmp1, &dmq1, &iqmp,
596		   &t1, &t2, &t3, NULL);
597
598    return ret;
599}
600
601static int
602ltm_rsa_init(RSA *rsa)
603{
604    return 1;
605}
606
607static int
608ltm_rsa_finish(RSA *rsa)
609{
610    return 1;
611}
612
613const RSA_METHOD hc_rsa_ltm_method = {
614    "hcrypto ltm RSA",
615    ltm_rsa_public_encrypt,
616    ltm_rsa_public_decrypt,
617    ltm_rsa_private_encrypt,
618    ltm_rsa_private_decrypt,
619    NULL,
620    NULL,
621    ltm_rsa_init,
622    ltm_rsa_finish,
623    0,
624    NULL,
625    NULL,
626    NULL,
627    ltm_rsa_generate_key
628};
629
630const RSA_METHOD *
631RSA_ltm_method(void)
632{
633    return &hc_rsa_ltm_method;
634}
635