1/* $OpenBSD: bn_word.c,v 1.2 2024/08/23 12:56:26 anton Exp $ */
2/*
3 * Copyright (c) 2023 Joel Sing <jsing@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#include <err.h>
19#include <string.h>
20
21#include <openssl/bn.h>
22
/*
 * A single test vector for the BN_*_word() functions.  Only the fields an
 * operation's driver reads need to be set; the rest default to zero.
 */
struct bn_word_test {
	/* Input value, parsed with BN_hex2bn(). */
	const char *in_hex;
	/* Word operand passed to the operation under test. */
	BN_ULONG in_word;
	/*
	 * Expected return value of BN_div_word()/BN_mod_word(); only read by
	 * the div and mod drivers.
	 */
	BN_ULONG mod_word;
	/* Expected BN_get_word() of the resulting BIGNUM. */
	BN_ULONG out_word;
	/* Expected BN_bn2hex() of the resulting BIGNUM. */
	const char *out_hex;
	/* Expected BN_is_negative() of the resulting BIGNUM. */
	int out_is_negative;
};
31
32static int
33check_bn_word_test(const char *op_name, const BIGNUM *bn,
34    const struct bn_word_test *bwt)
35{
36	char *out_hex = NULL;
37	BN_ULONG out_word;
38	int failed = 1;
39
40	if ((out_word = BN_get_word(bn)) != bwt->out_word) {
41		fprintf(stderr, "FAIL %s: Got word %lx, want %lx\n",
42		    op_name, (unsigned long)out_word,
43		    (unsigned long)bwt->out_word);
44		goto failure;
45	}
46
47	if (BN_is_negative(bn) != bwt->out_is_negative) {
48		fprintf(stderr, "FAIL %s: Got is negative %d, want %d\n",
49		    op_name, BN_is_negative(bn), bwt->out_is_negative);
50		goto failure;
51	}
52
53	if ((out_hex = BN_bn2hex(bn)) == NULL)
54		errx(1, "BN_bn2hex() failed");
55
56	if (strcmp(out_hex, bwt->out_hex) != 0) {
57		fprintf(stderr, "FAIL %s: Got hex %s, want %s\n",
58		    op_name, out_hex, bwt->out_hex);
59		goto failure;
60	}
61
62	if (BN_is_zero(bn) && BN_is_negative(bn) != 0) {
63		fprintf(stderr, "FAIL %s: Got negative zero\n", op_name);
64		goto failure;
65	}
66
67	failed = 0;
68
69 failure:
70	free(out_hex);
71
72	return failed;
73}
74
75static int
76test_bn_word(int (*bn_word_op)(BIGNUM *, BN_ULONG), const char *op_name,
77    const struct bn_word_test *bwts, size_t num_tests)
78{
79	const struct bn_word_test *bwt;
80	BIGNUM *bn;
81	size_t i;
82	int failed = 0;
83
84	if ((bn = BN_new()) == NULL)
85		errx(1, "BN_new() failed");
86
87	for (i = 0; i < num_tests; i++) {
88		bwt = &bwts[i];
89
90		if (!BN_hex2bn(&bn, bwt->in_hex)) {
91			fprintf(stderr, "FAIL: BN_hex2bn(\"%s\") failed\n",
92			    bwt->in_hex);
93			failed = 1;
94			continue;
95		}
96
97		if (!bn_word_op(bn, bwt->in_word)) {
98			fprintf(stderr, "FAIL: %s(%lx) failed\n", op_name,
99			     (unsigned long)bwt->in_word);
100			failed = 1;
101			continue;
102		}
103
104		failed |= check_bn_word_test(op_name, bn, bwt);
105	}
106
107	BN_free(bn);
108
109	return failed;
110}
111
/*
 * Test vectors for BN_add_word(): in_word is added to the value parsed from
 * in_hex, and the result is compared against out_word, out_hex and
 * out_is_negative.  mod_word is not used by this operation.
 */
static const struct bn_word_test bn_add_word_tests[] = {
	/* Adding zero leaves the value unchanged. */
	{
		.in_hex = "1",
		.in_word = 0,
		.out_word = 1,
		.out_hex = "01",
	},
	/* Adding to zero. */
	{
		.in_hex = "0",
		.in_word = 1,
		.out_word = 1,
		.out_hex = "01",
	},
	/* 1 + 1 = 2. */
	{
		.in_hex = "1",
		.in_word = 1,
		.out_word = 2,
		.out_hex = "02",
	},
	/* -1 + 2 = 1: a negative input becomes positive. */
	{
		.in_hex = "-1",
		.in_word = 2,
		.out_word = 1,
		.out_hex = "01",
	},
	/* -1 + 1 = 0: the result must be a non-negative zero. */
	{
		.in_hex = "-1",
		.in_word = 1,
		.out_word = 0,
		.out_hex = "0",
	},
	/* -3 + 2 = -1: the result stays negative. */
	{
		.in_hex = "-3",
		.in_word = 2,
		.out_word = 1,
		.out_hex = "-01",
		.out_is_negative = 1,
	},
	/* Result fills exactly 32 bits. */
	{
		.in_hex = "1",
		.in_word = 0xfffffffeUL,
		.out_word = 0xffffffffUL,
		.out_hex = "FFFFFFFF",
	},
	/*
	 * Carry out of a 64-bit value.  The result does not fit in a
	 * BN_ULONG, so BN_get_word() is expected to return BN_MASK2.
	 */
	{
		.in_hex = "FFFFFFFFFFFFFFFF",
		.in_word = 1,
		.out_word = BN_MASK2,
		.out_hex = "010000000000000000",
	},
};
163
164#define N_BN_ADD_WORD_TESTS \
165    (sizeof(bn_add_word_tests) / sizeof(bn_add_word_tests[0]))
166
167static int
168test_bn_add_word(void)
169{
170	return test_bn_word(BN_add_word, "BN_add_word", bn_add_word_tests,
171	    N_BN_ADD_WORD_TESTS);
172}
173
174static const struct bn_word_test bn_sub_word_tests[] = {
175	{
176		.in_hex = "1",
177		.in_word = 0,
178		.out_word = 1,
179		.out_hex = "01",
180	},
181	{
182		.in_hex = "0",
183		.in_word = 1,
184		.out_word = 1,
185		.out_hex = "-01",
186		.out_is_negative = 1,
187	},
188	{
189		.in_hex = "1",
190		.in_word = 1,
191		.out_word = 0,
192		.out_hex = "0",
193	},
194	{
195		.in_hex = "2",
196		.in_word = 1,
197		.out_word = 1,
198		.out_hex = "01",
199	},
200	{
201		.in_hex = "-1",
202		.in_word = 2,
203		.out_word = 3,
204		.out_hex = "-03",
205		.out_is_negative = 1,
206	},
207	{
208		.in_hex = "1",
209		.in_word = 1,
210		.out_word = 0,
211		.out_hex = "0",
212	},
213	{
214		.in_hex = "3",
215		.in_word = 2,
216		.out_word = 1,
217		.out_hex = "01",
218	},
219	{
220		.in_hex = "-3",
221		.in_word = 2,
222		.out_word = 5,
223		.out_hex = "-05",
224		.out_is_negative = 1,
225	},
226	{
227		.in_hex = "-1",
228		.in_word = 0xfffffffeUL,
229		.out_word = 0xffffffffUL,
230		.out_hex = "-FFFFFFFF",
231		.out_is_negative = 1,
232	},
233	{
234		.in_hex = "010000000000000000",
235		.in_word = 1,
236		.out_word = BN_MASK2,
237		.out_hex = "FFFFFFFFFFFFFFFF",
238	},
239};
240
241#define N_BN_SUB_WORD_TESTS \
242    (sizeof(bn_sub_word_tests) / sizeof(bn_sub_word_tests[0]))
243
244static int
245test_bn_sub_word(void)
246{
247	return test_bn_word(BN_sub_word, "BN_sub_word", bn_sub_word_tests,
248	    N_BN_SUB_WORD_TESTS);
249}
250
/*
 * Test vectors for BN_mul_word(): the value parsed from in_hex is multiplied
 * by in_word, and the result is compared against out_word, out_hex and
 * out_is_negative.  mod_word is not used by this operation.
 */
static const struct bn_word_test bn_mul_word_tests[] = {
	/* Multiplying by zero yields zero. */
	{
		.in_hex = "1",
		.in_word = 0,
		.out_word = 0,
		.out_hex = "0",
	},
	/* Zero times anything is zero. */
	{
		.in_hex = "0",
		.in_word = 1,
		.out_word = 0,
		.out_hex = "0",
	},
	/* Multiplying by one is the identity. */
	{
		.in_hex = "1",
		.in_word = 1,
		.out_word = 1,
		.out_hex = "01",
	},
	/* -1 * 0 = 0: the result must be a non-negative zero. */
	{
		.in_hex = "-1",
		.in_word = 0,
		.out_word = 0,
		.out_hex = "0",
	},
	/* -1 * 1 = -1: the sign is preserved. */
	{
		.in_hex = "-1",
		.in_word = 1,
		.out_word = 1,
		.out_hex = "-01",
		.out_is_negative = 1,
	},
	/* -3 * 2 = -6. */
	{
		.in_hex = "-3",
		.in_word = 2,
		.out_word = 6,
		.out_hex = "-06",
		.out_is_negative = 1,
	},
	/* Product that fills 32 bits. */
	{
		.in_hex = "1",
		.in_word = 0xfffffffeUL,
		.out_word = 0xfffffffeUL,
		.out_hex = "FFFFFFFE",
	},
	/*
	 * Multi-word operand.  The result does not fit in a BN_ULONG, so
	 * BN_get_word() is expected to return BN_MASK2.
	 */
	{
		.in_hex = "010000000000000000",
		.in_word = 2,
		.out_word = BN_MASK2,
		.out_hex = "020000000000000000",
	},
};
303
304#define N_BN_MUL_WORD_TESTS \
305    (sizeof(bn_mul_word_tests) / sizeof(bn_mul_word_tests[0]))
306
307static int
308test_bn_mul_word(void)
309{
310	return test_bn_word(BN_mul_word, "BN_mul_word", bn_mul_word_tests,
311	    N_BN_MUL_WORD_TESTS);
312}
313
/*
 * Test vectors for BN_div_word(): the value parsed from in_hex is divided
 * in place by in_word.  mod_word is the expected return value (the
 * remainder), and out_word/out_hex/out_is_negative describe the quotient
 * left in the BIGNUM.
 */
static const struct bn_word_test bn_div_word_tests[] = {
	/*
	 * Division by zero: the vector expects BN_MASK2 to be returned and the
	 * input to be left unchanged.
	 */
	{
		.in_hex = "1",
		.in_word = 0,
		.mod_word = BN_MASK2,
		.out_word = 1,
		.out_hex = "01",
	},
	/* 0 / 1 = 0 remainder 0. */
	{
		.in_hex = "0",
		.in_word = 1,
		.mod_word = 0,
		.out_word = 0,
		.out_hex = "0",
	},
	/* 4 / 2 = 2 remainder 0. */
	{
		.in_hex = "4",
		.in_word = 2,
		.mod_word = 0,
		.out_word = 2,
		.out_hex = "02",
	},
	/* 7 / 3 = 2 remainder 1. */
	{
		.in_hex = "7",
		.in_word = 3,
		.mod_word = 1,
		.out_word = 2,
		.out_hex = "02",
	},
	/* 1 / 1 = 1 remainder 0. */
	{
		.in_hex = "1",
		.in_word = 1,
		.mod_word = 0,
		.out_word = 1,
		.out_hex = "01",
	},
	/* -2 / 1 = -2 remainder 0: the quotient keeps the sign. */
	{
		.in_hex = "-2",
		.in_word = 1,
		.mod_word = 0,
		.out_word = 2,
		.out_hex = "-02",
		.out_is_negative = 1,
	},
	/* -1 / 2: the quotient must be a non-negative zero; remainder 1. */
	{
		.in_hex = "-1",
		.in_word = 2,
		.mod_word = 1,
		.out_word = 0,
		.out_hex = "0",
	},
	/* -3 / 2 = -1; the expected remainder is 1. */
	{
		.in_hex = "-3",
		.in_word = 2,
		.mod_word = 1,
		.out_word = 1,
		.out_hex = "-01",
		.out_is_negative = 1,
	},
	/* Divisor larger than the dividend: quotient 0, remainder 1. */
	{
		.in_hex = "1",
		.in_word = 0xffffffffUL,
		.mod_word = 1,
		.out_word = 0,
		.out_hex = "0",
	},
	/* 32-bit all-ones divided by one. */
	{
		.in_hex = "FFFFFFFF",
		.in_word = 1,
		.mod_word = 0,
		.out_word = 0xffffffffUL,
		.out_hex = "FFFFFFFF",
	},
	/* Dividend one less than the divisor: the remainder is the dividend. */
	{
		.in_hex = "FFFFFFFE",
		.in_word = 0xffffffffUL,
		.mod_word = 0xfffffffeUL,
		.out_word = 0,
		.out_hex = "0",
	},
	/* 64-bit all-ones divided by one; BN_get_word() is expected to give BN_MASK2. */
	{
		.in_hex = "FFFFFFFFFFFFFFFF",
		.in_word = 1,
		.mod_word = 0,
		.out_word = BN_MASK2,
		.out_hex = "FFFFFFFFFFFFFFFF",
	},
	/* 0xffffffff / 0xff = 0x01010101 exactly. */
	{
		.in_hex = "FFFFFFFF",
		.in_word = 0xff,
		.mod_word = 0,
		.out_word = 0x1010101UL,
		.out_hex = "01010101",
	},
	/* 0xffffffff / 0x10 = 0x0fffffff remainder 0xf. */
	{
		.in_hex = "FFFFFFFF",
		.in_word = 0x10,
		.mod_word = 0xf,
		.out_word = 0xfffffffUL,
		.out_hex = "0FFFFFFF",
	},
};
416
417#define N_BN_DIV_WORD_TESTS \
418    (sizeof(bn_div_word_tests) / sizeof(bn_div_word_tests[0]))
419
420static int
421test_bn_div_word(void)
422{
423	const char *op_name = "BN_div_word";
424	const struct bn_word_test *bwt;
425	BN_ULONG mod_word;
426	BIGNUM *bn;
427	size_t i;
428	int failed = 0;
429
430	if ((bn = BN_new()) == NULL)
431		errx(1, "BN_new() failed");
432
433	for (i = 0; i < N_BN_DIV_WORD_TESTS; i++) {
434		bwt = &bn_div_word_tests[i];
435
436		if (!BN_hex2bn(&bn, bwt->in_hex)) {
437			fprintf(stderr, "FAIL: BN_hex2bn(\"%s\") failed\n",
438			    bwt->in_hex);
439			failed = 1;
440			continue;
441		}
442
443		if ((mod_word = BN_div_word(bn, bwt->in_word)) != bwt->mod_word) {
444			fprintf(stderr, "FAIL %s: Got mod word %lx, want %lx\n",
445			    op_name, (unsigned long)mod_word,
446			    (unsigned long)bwt->mod_word);
447			failed = 1;
448			continue;
449		}
450
451		failed |= check_bn_word_test(op_name, bn, bwt);
452	}
453
454	BN_free(bn);
455
456	return failed;
457}
458
459static const struct bn_word_test bn_mod_word_tests[] = {
460	{
461		.in_hex = "1",
462		.in_word = 0,
463		.mod_word = BN_MASK2,
464		.out_word = 1,
465		.out_hex = "01",
466	},
467	{
468		.in_hex = "0",
469		.in_word = 1,
470		.mod_word = 0,
471		.out_word = 0,
472		.out_hex = "0",
473	},
474	{
475		.in_hex = "4",
476		.in_word = 2,
477		.mod_word = 0,
478		.out_word = 4,
479		.out_hex = "04",
480	},
481	{
482		.in_hex = "7",
483		.in_word = 3,
484		.mod_word = 1,
485		.out_word = 7,
486		.out_hex = "07",
487	},
488	{
489		.in_hex = "1",
490		.in_word = 1,
491		.mod_word = 0,
492		.out_word = 1,
493		.out_hex = "01",
494	},
495	{
496		.in_hex = "-2",
497		.in_word = 1,
498		.mod_word = 0,
499		.out_word = 2,
500		.out_hex = "-02",
501		.out_is_negative = 1,
502	},
503	{
504		.in_hex = "-1",
505		.in_word = 2,
506		.mod_word = 1,
507		.out_word = 1,
508		.out_hex = "-01",
509		.out_is_negative = 1,
510	},
511	{
512		.in_hex = "-3",
513		.in_word = 2,
514		.mod_word = 1,
515		.out_word = 3,
516		.out_hex = "-03",
517		.out_is_negative = 1,
518	},
519	{
520		.in_hex = "1",
521		.in_word = 0xffffffffUL,
522		.mod_word = 1,
523		.out_word = 1,
524		.out_hex = "01",
525	},
526	{
527		.in_hex = "FFFFFFFF",
528		.in_word = 1,
529		.mod_word = 0,
530		.out_word = 0xffffffffUL,
531		.out_hex = "FFFFFFFF",
532	},
533	{
534		.in_hex = "FFFFFFFE",
535		.in_word = 0xffffffffUL,
536		.mod_word = 0xfffffffeUL,
537		.out_word = 0xfffffffeUL,
538		.out_hex = "FFFFFFFE",
539	},
540	{
541		.in_hex = "FFFFFFFFFFFFFFFF",
542		.in_word = 1,
543		.mod_word = 0,
544		.out_word = BN_MASK2,
545		.out_hex = "FFFFFFFFFFFFFFFF",
546	},
547	{
548		.in_hex = "FFFFFFFF",
549		.in_word = 0xff,
550		.mod_word = 0,
551		.out_word = 0xffffffff,
552		.out_hex = "FFFFFFFF",
553	},
554	{
555		.in_hex = "FFFFFFFF",
556		.in_word = 0x10,
557		.mod_word = 0xf,
558		.out_word = 0xffffffffUL,
559		.out_hex = "FFFFFFFF",
560	},
561};
562
563#define N_BN_MOD_WORD_TESTS \
564    (sizeof(bn_mod_word_tests) / sizeof(bn_mod_word_tests[0]))
565
566static int
567test_bn_mod_word(void)
568{
569	const char *op_name = "BN_mod_word";
570	const struct bn_word_test *bwt;
571	BN_ULONG mod_word;
572	BIGNUM *bn;
573	size_t i;
574	int failed = 0;
575
576	if ((bn = BN_new()) == NULL)
577		errx(1, "BN_new() failed");
578
579	for (i = 0; i < N_BN_MOD_WORD_TESTS; i++) {
580		bwt = &bn_mod_word_tests[i];
581
582		if (!BN_hex2bn(&bn, bwt->in_hex)) {
583			fprintf(stderr, "FAIL: BN_hex2bn(\"%s\") failed\n",
584			    bwt->in_hex);
585			failed = 1;
586			continue;
587		}
588
589		if ((mod_word = BN_mod_word(bn, bwt->in_word)) != bwt->mod_word) {
590			fprintf(stderr, "FAIL %s: Got mod word %lx, want %lx\n",
591			    op_name, (unsigned long)mod_word,
592			    (unsigned long)bwt->mod_word);
593			failed = 1;
594			continue;
595		}
596
597		failed |= check_bn_word_test(op_name, bn, bwt);
598	}
599
600	BN_free(bn);
601
602	return failed;
603}
604
605int
606main(int argc, char **argv)
607{
608	int failed = 0;
609
610	failed |= test_bn_add_word();
611	failed |= test_bn_sub_word();
612	failed |= test_bn_mul_word();
613	failed |= test_bn_div_word();
614	failed |= test_bn_mod_word();
615
616	return failed;
617}
618