1/*
2 * Copyright (c) 2006 Kungliga Tekniska Högskolan
3 * (Royal Institute of Technology, Stockholm, Sweden).
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright
14 *    notice, this list of conditions and the following disclaimer in the
15 *    documentation and/or other materials provided with the distribution.
16 *
17 * 3. Neither the name of the Institute nor the names of its contributors
18 *    may be used to endorse or promote products derived from this software
19 *    without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31 * SUCH DAMAGE.
32 */
33
34#include <config.h>
35
36
37#include <stdio.h>
38#include <stdlib.h>
39#include <string.h>
40#include <limits.h>
41
42#include <krb5-types.h>
43#include <roken.h>
44#include <rfc2459_asn1.h> /* XXX */
45#include <der.h>
46
47#include <bn.h>
48#include <rand.h>
49#include <hex.h>
50
51BIGNUM *
52BN_new(void)
53{
54    heim_integer *hi;
55    hi = calloc(1, sizeof(*hi));
56    return (BIGNUM *)hi;
57}
58
59void
60BN_free(BIGNUM *bn)
61{
62    BN_clear(bn);
63    free(bn);
64}
65
66void
67BN_clear(BIGNUM *bn)
68{
69    heim_integer *hi = (heim_integer *)bn;
70    if (hi->data) {
71	memset(hi->data, 0, hi->length);
72	free(hi->data);
73    }
74    memset(hi, 0, sizeof(*hi));
75}
76
77void
78BN_clear_free(BIGNUM *bn)
79{
80    BN_free(bn);
81}
82
83BIGNUM *
84BN_dup(const BIGNUM *bn)
85{
86    BIGNUM *b = BN_new();
87    if (der_copy_heim_integer((const heim_integer *)bn, (heim_integer *)b)) {
88	BN_free(b);
89	return NULL;
90    }
91    return b;
92}
93
94/*
95 * If the caller really want to know the number of bits used, subtract
96 * one from the length, multiply by 8, and then lookup in the table
97 * how many bits the hightest byte uses.
98 */
99int
100BN_num_bits(const BIGNUM *bn)
101{
102    static unsigned char num2bits[256] = {
103	0,1,2,2,3,3,3,3,4,4,4,4,4,4,4,4,  5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,
104	6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,  6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,
105	7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,  7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
106	7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,  7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
107	8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,  8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
108	8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,  8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
109	8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,  8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
110	8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,  8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,
111    };
112    const heim_integer *i = (const void *)bn;
113    if (i->length == 0)
114	return 0;
115    return (i->length - 1) * 8 + num2bits[((unsigned char *)i->data)[0]];
116}
117
118int
119BN_num_bytes(const BIGNUM *bn)
120{
121    return ((const heim_integer *)bn)->length;
122}
123
124/*
125 * Ignore negative flag.
126 */
127
128BIGNUM *
129BN_bin2bn(const void *s, int len, BIGNUM *bn)
130{
131    heim_integer *hi = (void *)bn;
132
133    if (len < 0)
134	return NULL;
135
136    if (hi == NULL) {
137	hi = (heim_integer *)BN_new();
138	if (hi == NULL)
139	    return NULL;
140    }
141    if (hi->data)
142	BN_clear((BIGNUM *)hi);
143    hi->negative = 0;
144    hi->data = malloc(len);
145    if (hi->data == NULL && len != 0) {
146	if (bn == NULL)
147	    BN_free((BIGNUM *)hi);
148	return NULL;
149    }
150    hi->length = len;
151    memcpy(hi->data, s, len);
152    return (BIGNUM *)hi;
153}
154
155int
156BN_bn2bin(const BIGNUM *bn, void *to)
157{
158    const heim_integer *hi = (const void *)bn;
159    memcpy(to, hi->data, hi->length);
160    return hi->length;
161}
162
163int
164BN_hex2bn(BIGNUM **bnp, const char *in)
165{
166    int negative;
167    ssize_t ret;
168    size_t len;
169    void *data;
170
171    len = strlen(in);
172    data = malloc(len);
173    if (data == NULL)
174	return 0;
175
176    if (*in == '-') {
177	negative = 1;
178	in++;
179    } else
180	negative = 0;
181
182    ret = hex_decode(in, data, len);
183    if (ret < 0) {
184	free(data);
185	return 0;
186    }
187
188    *bnp = BN_bin2bn(data, ret, NULL);
189    free(data);
190    if (*bnp == NULL)
191	return 0;
192    BN_set_negative(*bnp, negative);
193    return 1;
194}
195
196char *
197BN_bn2hex(const BIGNUM *bn)
198{
199    ssize_t ret;
200    size_t len;
201    void *data;
202    char *str;
203
204    len = BN_num_bytes(bn);
205    data = malloc(len);
206    if (data == NULL)
207	return 0;
208
209    len = BN_bn2bin(bn, data);
210
211    ret = hex_encode(data, len, &str);
212    free(data);
213    if (ret < 0)
214	return 0;
215
216    return str;
217}
218
219int
220BN_cmp(const BIGNUM *bn1, const BIGNUM *bn2)
221{
222    return der_heim_integer_cmp((const heim_integer *)bn1,
223				(const heim_integer *)bn2);
224}
225
226void
227BN_set_negative(BIGNUM *bn, int flag)
228{
229    ((heim_integer *)bn)->negative = (flag ? 1 : 0);
230}
231
232int
233BN_is_negative(const BIGNUM *bn)
234{
235    return ((const heim_integer *)bn)->negative ? 1 : 0;
236}
237
/* is_set[n] is the mask selecting bit n within a byte (bit 0 = LSB). */
static const unsigned char is_set[8] = {
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80
};
239
240int
241BN_is_bit_set(const BIGNUM *bn, int bit)
242{
243    heim_integer *hi = (heim_integer *)bn;
244    unsigned char *p = hi->data;
245
246    if ((bit / 8) > hi->length || hi->length == 0)
247	return 0;
248
249    return p[hi->length - 1 - (bit / 8)] & is_set[bit % 8];
250}
251
252int
253BN_set_bit(BIGNUM *bn, int bit)
254{
255    heim_integer *hi = (heim_integer *)bn;
256    unsigned char *p;
257
258    if ((bit / 8) > hi->length || hi->length == 0) {
259	size_t len = (bit + 7) / 8;
260	void *d = realloc(hi->data, len);
261	if (d == NULL)
262	    return 0;
263	hi->data = d;
264	p = hi->data;
265	memset(&p[hi->length], 0, len);
266	hi->length = len;
267    } else
268	p = hi->data;
269
270    p[hi->length - 1 - (bit / 8)] |= is_set[bit % 8];
271    return 1;
272}
273
274int
275BN_clear_bit(BIGNUM *bn, int bit)
276{
277    heim_integer *hi = (heim_integer *)bn;
278    unsigned char *p = hi->data;
279
280    if ((bit / 8) > hi->length || hi->length == 0)
281	return 0;
282
283    p[hi->length - 1 - (bit / 8)] &= (unsigned char)(~(is_set[bit % 8]));
284
285    return 1;
286}
287
288int
289BN_set_word(BIGNUM *bn, unsigned long num)
290{
291    unsigned char p[sizeof(num)];
292    unsigned long num2;
293    int i, len;
294
295    for (num2 = num, i = 0; num2 > 0; i++)
296	num2 = num2 >> 8;
297
298    len = i;
299    for (; i > 0; i--) {
300	p[i - 1] = (num & 0xff);
301	num = num >> 8;
302    }
303
304    bn = BN_bin2bn(p, len, bn);
305    return bn != NULL;
306}
307
308unsigned long
309BN_get_word(const BIGNUM *bn)
310{
311    heim_integer *hi = (heim_integer *)bn;
312    unsigned long num = 0;
313    int i;
314
315    if (hi->negative || hi->length > sizeof(num))
316	return ULONG_MAX;
317
318    for (i = 0; i < hi->length; i++)
319	num = ((unsigned char *)hi->data)[i] | (num << 8);
320    return num;
321}
322
323int
324BN_rand(BIGNUM *bn, int bits, int top, int bottom)
325{
326    size_t len = (bits + 7) / 8;
327    heim_integer *i = (heim_integer *)bn;
328
329    BN_clear(bn);
330
331    i->negative = 0;
332    i->data = malloc(len);
333    if (i->data == NULL && len != 0)
334	return 0;
335    i->length = len;
336
337    if (RAND_bytes(i->data, i->length) != 1) {
338	free(i->data);
339	i->data = NULL;
340	return 0;
341    }
342
343    {
344	size_t j = len * 8;
345	while(j > bits) {
346	    BN_clear_bit(bn, j - 1);
347	    j--;
348	}
349    }
350
351    if (top == -1) {
352	;
353    } else if (top == 0 && bits > 0) {
354	BN_set_bit(bn, bits - 1);
355    } else if (top == 1 && bits > 1) {
356	BN_set_bit(bn, bits - 1);
357	BN_set_bit(bn, bits - 2);
358    } else {
359	BN_clear(bn);
360	return 0;
361    }
362
363    if (bottom && bits > 0)
364	BN_set_bit(bn, 0);
365
366    return 1;
367}
368
369/*
370 *
371 */
372
373int
374BN_uadd(BIGNUM *res, const BIGNUM *a, const BIGNUM *b)
375{
376    const heim_integer *ai = (const heim_integer *)a;
377    const heim_integer *bi = (const heim_integer *)b;
378    const unsigned char *ap, *bp;
379    unsigned char *cp;
380    heim_integer ci;
381    int carry = 0;
382    ssize_t len;
383
384    if (ai->negative && bi->negative)
385	return 0;
386    if (ai->length < bi->length) {
387	const heim_integer *si = bi;
388	bi = ai; ai = si;
389    }
390
391    ci.negative = 0;
392    ci.length = ai->length + 1;
393    ci.data = malloc(ci.length);
394    if (ci.data == NULL)
395	return 0;
396
397    ap = &((const unsigned char *)ai->data)[ai->length - 1];
398    bp = &((const unsigned char *)bi->data)[bi->length - 1];
399    cp = &((unsigned char *)ci.data)[ci.length - 1];
400
401    for (len = bi->length; len > 0; len--) {
402	carry = *ap + *bp + carry;
403	*cp = carry & 0xff;
404	carry = (carry & ~0xff) ? 1 : 0;
405	ap--; bp--; cp--;
406    }
407    for (len = ai->length - bi->length; len > 0; len--) {
408	carry = *ap + carry;
409	*cp = carry & 0xff;
410	carry = (carry & ~0xff) ? 1 : 0;
411	ap--; cp--;
412    }
413    if (!carry)
414	memmove(cp, cp + 1, --ci.length);
415    else
416	*cp = carry;
417
418    BN_clear(res);
419    *((heim_integer *)res) = ci;
420
421    return 1;
422}
423
424
425/*
426 * Callback when doing slow generation of numbers, like primes.
427 */
428
429void
430BN_GENCB_set(BN_GENCB *gencb, int (*cb_2)(int, int, BN_GENCB *), void *ctx)
431{
432    gencb->ver = 2;
433    gencb->cb.cb_2 = cb_2;
434    gencb->arg = ctx;
435}
436
437int
438BN_GENCB_call(BN_GENCB *cb, int a, int b)
439{
440    if (cb == NULL || cb->cb.cb_2 == NULL)
441	return 1;
442    return cb->cb.cb_2(a, b, cb);
443}
444
445/*
446 *
447 */
448
449struct BN_CTX {
450    struct {
451	BIGNUM **val;
452	size_t used;
453	size_t len;
454    } bn;
455    struct {
456	size_t *val;
457	size_t used;
458	size_t len;
459    } stack;
460};
461
462BN_CTX *
463BN_CTX_new(void)
464{
465    struct BN_CTX *c;
466    c = calloc(1, sizeof(*c));
467    return c;
468}
469
470void
471BN_CTX_free(BN_CTX *c)
472{
473    size_t i;
474    for (i = 0; i < c->bn.len; i++)
475	BN_free(c->bn.val[i]);
476    free(c->bn.val);
477    free(c->stack.val);
478}
479
480BIGNUM *
481BN_CTX_get(BN_CTX *c)
482{
483    if (c->bn.used == c->bn.len) {
484	void *ptr;
485	size_t i;
486	c->bn.len += 16;
487	ptr = realloc(c->bn.val, c->bn.len * sizeof(c->bn.val[0]));
488	if (ptr == NULL)
489	    return NULL;
490	c->bn.val = ptr;
491	for (i = c->bn.used; i < c->bn.len; i++) {
492	    c->bn.val[i] = BN_new();
493	    if (c->bn.val[i] == NULL) {
494		c->bn.len = i;
495		return NULL;
496	    }
497	}
498    }
499    return c->bn.val[c->bn.used++];
500}
501
502void
503BN_CTX_start(BN_CTX *c)
504{
505    if (c->stack.used == c->stack.len) {
506	void *ptr;
507	c->stack.len += 16;
508	ptr = realloc(c->stack.val, c->stack.len * sizeof(c->stack.val[0]));
509	if (ptr == NULL)
510	    abort();
511	c->stack.val = ptr;
512    }
513    c->stack.val[c->stack.used++] = c->bn.used;
514}
515
516void
517BN_CTX_end(BN_CTX *c)
518{
519    const size_t prev = c->stack.val[c->stack.used - 1];
520    size_t i;
521
522    if (c->stack.used == 0)
523	abort();
524
525    for (i = prev; i < c->bn.used; i++)
526	BN_clear(c->bn.val[i]);
527
528    c->stack.used--;
529    c->bn.used = prev;
530}
531
532