1/* $FreeBSD$ */
2/*
 * The bignum support is a bit broken at the moment and I have not yet fixed
 * it.  The symptom is that odd-sized bignums fail.  The test code below
 * demonstrates this (it currently exercises only modexp).
6 *
7 * --Jason L. Wright
8 */
#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <machine/endian.h>
#include <crypto/cryptodev.h>
#include <openssl/bn.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <paths.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
22
23int	crid = CRYPTO_FLAG_HARDWARE;
24int	verbose = 0;
25
/*
 * Return a descriptor for the crypto device node.  The node is opened
 * read/write on the first call and the descriptor is cached in a static
 * for every later call.  Exits via err(3) if the open or the fcntl fails.
 */
static int
devcrypto(void)
{
	static int fd = -1;

	if (fd < 0) {
		fd = open(_PATH_DEV "crypto", O_RDWR, 0);
		if (fd < 0)
			err(1, _PATH_DEV "crypto");
		/* Keep the device descriptor from leaking across exec(2). */
		if (fcntl(fd, F_SETFD, FD_CLOEXEC) == -1)
			err(1, "fcntl(F_SETFD) (devcrypto)");
	}
	return fd;
}
40
41static int
42crlookup(const char *devname)
43{
44	struct crypt_find_op find;
45
46	find.crid = -1;
47	strlcpy(find.name, devname, sizeof(find.name));
48	if (ioctl(devcrypto(), CIOCFINDDEV, &find) == -1)
49		err(1, "ioctl(CIOCFINDDEV)");
50	return find.crid;
51}
52
53static const char *
54crfind(int crid)
55{
56	static struct crypt_find_op find;
57
58	bzero(&find, sizeof(find));
59	find.crid = crid;
60	if (ioctl(devcrypto(), CIOCFINDDEV, &find) == -1)
61		err(1, "ioctl(CIOCFINDDEV)");
62	return find.name;
63}
64
65/*
66 * Convert a little endian byte string in 'p' that
67 * is 'plen' bytes long to a BIGNUM. If 'dst' is NULL,
68 * a new BIGNUM is allocated.  Returns NULL on failure.
69 *
70 * XXX there has got to be a more efficient way to do
71 * this, but I haven't figured out enough of the OpenSSL
72 * magic.
73 */
74BIGNUM *
75le_to_bignum(BIGNUM *dst, u_int8_t *p, int plen)
76{
77	u_int8_t *pd;
78	int i;
79
80	if (plen == 0)
81		return (NULL);
82
83	if ((pd = (u_int8_t *)malloc(plen)) == NULL)
84		return (NULL);
85
86	for (i = 0; i < plen; i++)
87		pd[i] = p[plen - i - 1];
88
89	dst = BN_bin2bn(pd, plen, dst);
90	free(pd);
91	return (dst);
92}
93
94/*
95 * Convert a BIGNUM to a little endian byte string.
96 * If 'rd' is NULL, allocate space for it, otherwise
97 * 'rd' is assumed to have room for BN_num_bytes(n)
98 * bytes.  Returns NULL on failure.
99 */
100u_int8_t *
101bignum_to_le(BIGNUM *n, u_int8_t *rd)
102{
103	int i, j, k;
104	int blen = BN_num_bytes(n);
105
106	if (blen == 0)
107		return (NULL);
108	if (rd == NULL)
109		rd = (u_int8_t *)malloc(blen);
110	if (rd == NULL)
111		return (NULL);
112
113	for (i = 0, j = 0; i < n->top; i++) {
114		for (k = 0; k < BN_BITS2 / 8; k++) {
115			if ((j + k) >= blen)
116				goto out;
117			rd[j + k] = n->d[i] >> (k * 8);
118		}
119		j += BN_BITS2 / 8;
120	}
121out:
122	return (rd);
123}
124
125int
126UB_mod_exp(BIGNUM *res, BIGNUM *a, BIGNUM *b, BIGNUM *c, BN_CTX *ctx)
127{
128	struct crypt_kop kop;
129	u_int8_t *ale, *ble, *cle;
130	static int crypto_fd = -1;
131
132	if (crypto_fd == -1 && ioctl(devcrypto(), CRIOGET, &crypto_fd) == -1)
133		err(1, "CRIOGET");
134
135	if ((ale = bignum_to_le(a, NULL)) == NULL)
136		err(1, "bignum_to_le, a");
137	if ((ble = bignum_to_le(b, NULL)) == NULL)
138		err(1, "bignum_to_le, b");
139	if ((cle = bignum_to_le(c, NULL)) == NULL)
140		err(1, "bignum_to_le, c");
141
142	bzero(&kop, sizeof(kop));
143	kop.crk_op = CRK_MOD_EXP;
144	kop.crk_iparams = 3;
145	kop.crk_oparams = 1;
146	kop.crk_crid = crid;
147	kop.crk_param[0].crp_p = ale;
148	kop.crk_param[0].crp_nbits = BN_num_bytes(a) * 8;
149	kop.crk_param[1].crp_p = ble;
150	kop.crk_param[1].crp_nbits = BN_num_bytes(b) * 8;
151	kop.crk_param[2].crp_p = cle;
152	kop.crk_param[2].crp_nbits = BN_num_bytes(c) * 8;
153	kop.crk_param[3].crp_p = cle;
154	kop.crk_param[3].crp_nbits = BN_num_bytes(c) * 8;
155
156	if (ioctl(crypto_fd, CIOCKEY2, &kop) == -1)
157		err(1, "CIOCKEY2");
158	if (verbose)
159		printf("device = %s\n", crfind(kop.crk_crid));
160
161	bzero(ale, BN_num_bytes(a));
162	free(ale);
163	bzero(ble, BN_num_bytes(b));
164	free(ble);
165
166	if (kop.crk_status != 0) {
167		printf("error %d\n", kop.crk_status);
168		bzero(cle, BN_num_bytes(c));
169		free(cle);
170		return (-1);
171	} else {
172		res = le_to_bignum(res, cle, BN_num_bytes(c));
173		bzero(cle, BN_num_bytes(c));
174		free(cle);
175		if (res == NULL)
176			err(1, "le_to_bignum");
177		return (0);
178	}
179	return (0);
180}
181
182void
183show_result(a, b, c, sw, hw)
184BIGNUM *a, *b, *c, *sw, *hw;
185{
186	printf("\n");
187
188	printf("A = ");
189	BN_print_fp(stdout, a);
190	printf("\n");
191
192	printf("B = ");
193	BN_print_fp(stdout, b);
194	printf("\n");
195
196	printf("C = ");
197	BN_print_fp(stdout, c);
198	printf("\n");
199
200	printf("sw= ");
201	BN_print_fp(stdout, sw);
202	printf("\n");
203
204	printf("hw= ");
205	BN_print_fp(stdout, hw);
206	printf("\n");
207
208	printf("\n");
209}
210
211void
212testit(void)
213{
214	BIGNUM *a, *b, *c, *r1, *r2;
215	BN_CTX *ctx;
216
217	ctx = BN_CTX_new();
218
219	a = BN_new();
220	b = BN_new();
221	c = BN_new();
222	r1 = BN_new();
223	r2 = BN_new();
224
225	BN_pseudo_rand(a, 1023, 0, 0);
226	BN_pseudo_rand(b, 1023, 0, 0);
227	BN_pseudo_rand(c, 1024, 0, 0);
228
229	if (BN_cmp(a, c) > 0) {
230		BIGNUM *rem = BN_new();
231
232		BN_mod(rem, a, c, ctx);
233		UB_mod_exp(r2, rem, b, c, ctx);
234		BN_free(rem);
235	} else {
236		UB_mod_exp(r2, a, b, c, ctx);
237	}
238	BN_mod_exp(r1, a, b, c, ctx);
239
240	if (BN_cmp(r1, r2) != 0) {
241		show_result(a, b, c, r1, r2);
242	}
243
244	BN_free(r2);
245	BN_free(r1);
246	BN_free(c);
247	BN_free(b);
248	BN_free(a);
249	BN_CTX_free(ctx);
250}
251
/*
 * Print a usage summary to stderr and exit with a failure status.
 * Replaces exit(-1), which yields the platform-dependent status 255.
 */
static void
usage(const char* cmd)
{
	fprintf(stderr, "usage: %s [-d dev] [-v] [count]\n", cmd);
	fprintf(stderr, "count is the number of bignum ops to do\n");
	fprintf(stderr, "\n");
	fprintf(stderr, "-d use specific device\n");
	fprintf(stderr, "-v be verbose\n");
	exit(EXIT_FAILURE);
}
262
263int
264main(int argc, char *argv[])
265{
266	int c, i;
267
268	while ((c = getopt(argc, argv, "d:v")) != -1) {
269		switch (c) {
270		case 'd':
271			crid = crlookup(optarg);
272			break;
273		case 'v':
274			verbose = 1;
275			break;
276		default:
277			usage(argv[0]);
278		}
279	}
280	argc -= optind, argv += optind;
281
282	for (i = 0; i < 1000; i++) {
283		fprintf(stderr, "test %d\n", i);
284		testit();
285	}
286	return (0);
287}
288