1/* $OpenBSD: tlstest.c,v 1.15 2022/07/16 07:46:08 tb Exp $ */
2/*
3 * Copyright (c) 2017 Joel Sing <jsing@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#include <sys/socket.h>
19
20#include <err.h>
21#include <fcntl.h>
22#include <stdio.h>
23#include <string.h>
24#include <unistd.h>
25
26#include <tls.h>
27
/* Size in bytes of each in-memory ring used by the callback transport. */
#define CIRCULAR_BUFFER_SIZE 512

/*
 * Ring holding bytes destined for the client: server_write() fills it and
 * client_read() drains it.
 */
unsigned char client_buffer[CIRCULAR_BUFFER_SIZE];
unsigned char *client_readptr;
unsigned char *client_writeptr;

/*
 * Ring holding bytes destined for the server: client_write() fills it and
 * server_read() drains it.
 */
unsigned char server_buffer[CIRCULAR_BUFFER_SIZE];
unsigned char *server_readptr;
unsigned char *server_writeptr;

/* Certificate material paths, taken from the command line in main(). */
char *cafile;
char *certfile;
char *keyfile;

/* When non-zero, the ring helpers log transfer sizes to stderr. */
int debug = 0;
39
40static void
41circular_init(void)
42{
43	client_readptr = client_writeptr = client_buffer;
44	server_readptr = server_writeptr = server_buffer;
45}
46
47static ssize_t
48circular_read(char *name, unsigned char *buf, size_t bufsize,
49    unsigned char **readptr, unsigned char *writeptr,
50    unsigned char *outbuf, size_t outlen)
51{
52	unsigned char *nextptr = *readptr;
53	size_t n = 0;
54
55	while (n < outlen) {
56		if (nextptr == writeptr)
57			break;
58		*outbuf++ = *nextptr++;
59		if ((size_t)(nextptr - buf) >= bufsize)
60			nextptr = buf;
61		*readptr = nextptr;
62		n++;
63	}
64
65	if (debug && n > 0)
66		fprintf(stderr, "%s buffer: read %zi bytes\n", name, n);
67
68	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN);
69}
70
71static ssize_t
72circular_write(char *name, unsigned char *buf, size_t bufsize,
73    unsigned char *readptr, unsigned char **writeptr,
74    const unsigned char *inbuf, size_t inlen)
75{
76	unsigned char *nextptr = *writeptr;
77	unsigned char *prevptr;
78	size_t n = 0;
79
80	while (n < inlen) {
81		prevptr = nextptr++;
82		if ((size_t)(nextptr - buf) >= bufsize)
83			nextptr = buf;
84		if (nextptr == readptr)
85			break;
86		*prevptr = *inbuf++;
87		*writeptr = nextptr;
88		n++;
89	}
90
91	if (debug && n > 0)
92		fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n);
93
94	return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT);
95}
96
97static ssize_t
98client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
99{
100	return circular_read("client", client_buffer, sizeof(client_buffer),
101	    &client_readptr, client_writeptr, buf, buflen);
102}
103
104static ssize_t
105client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
106{
107	return circular_write("server", server_buffer, sizeof(server_buffer),
108	    server_readptr, &server_writeptr, buf, buflen);
109}
110
111static ssize_t
112server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
113{
114	return circular_read("server", server_buffer, sizeof(server_buffer),
115	    &server_readptr, server_writeptr, buf, buflen);
116}
117
118static ssize_t
119server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
120{
121	return circular_write("client", client_buffer, sizeof(client_buffer),
122	    client_readptr, &client_writeptr, buf, buflen);
123}
124
125static int
126do_tls_handshake(char *name, struct tls *ctx)
127{
128	int rv;
129
130	rv = tls_handshake(ctx);
131	if (rv == 0)
132		return (1);
133	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
134		return (0);
135
136	errx(1, "%s handshake failed: %s", name, tls_error(ctx));
137}
138
139static int
140do_tls_close(char *name, struct tls *ctx)
141{
142	int rv;
143
144	rv = tls_close(ctx);
145	if (rv == 0)
146		return (1);
147	if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
148		return (0);
149
150	errx(1, "%s close failed: %s", name, tls_error(ctx));
151}
152
/*
 * Alternate handshake steps between client and server_cctx until both
 * finish, giving up after 101 rounds.
 *
 * Returns 0 on success; prints a FAIL line tagged with desc and returns 1
 * if either side did not complete in time.
 */
static int
do_client_server_handshake(char *desc, struct tls *client,
    struct tls *server_cctx)
{
	int client_done = 0;
	int server_done = 0;
	int round;

	for (round = 0; round <= 100; round++) {
		if (!client_done)
			client_done = do_tls_handshake("client", client);
		if (!server_done)
			server_done = do_tls_handshake("server", server_cctx);
		if (client_done && server_done)
			return (0);
	}

	printf("FAIL: %s TLS handshake did not complete\n", desc);
	return (1);
}
174
/*
 * Alternate close steps between client and server_cctx until both finish,
 * giving up after 101 rounds.
 *
 * Returns 0 on success; prints a FAIL line tagged with desc and returns 1
 * if either side did not complete in time.
 */
static int
do_client_server_close(char *desc, struct tls *client, struct tls *server_cctx)
{
	int client_done = 0;
	int server_done = 0;
	int round;

	for (round = 0; round <= 100; round++) {
		if (!client_done)
			client_done = do_tls_close("client", client);
		if (!server_done)
			server_done = do_tls_close("server", server_cctx);
		if (client_done && server_done)
			return (0);
	}

	printf("FAIL: %s TLS close did not complete\n", desc);
	return (1);
}
195
/*
 * Run a full handshake followed by a close between client and server_cctx,
 * printing an INFO line tagged with desc after each phase succeeds.
 * Returns 0 if both phases succeed, 1 otherwise.
 */
static int
do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx)
{
	if (do_client_server_handshake(desc, client, server_cctx))
		return (1);
	printf("INFO: %s TLS handshake completed successfully\n", desc);

	/* XXX - Do some reads and writes... */

	if (do_client_server_close(desc, client, server_cctx))
		return (1);
	printf("INFO: %s TLS close completed successfully\n", desc);

	return (0);
}
213
214static int
215test_tls_cbs(struct tls *client, struct tls *server)
216{
217	struct tls *server_cctx;
218	int failure;
219
220	circular_init();
221
222	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
223	    NULL) == -1)
224		errx(1, "failed to accept: %s", tls_error(server));
225
226	if (tls_connect_cbs(client, client_read, client_write, NULL,
227	    "test") == -1)
228		errx(1, "failed to connect: %s", tls_error(client));
229
230	failure = do_client_server_test("callback", client, server_cctx);
231
232	tls_free(server_cctx);
233
234	return (failure);
235}
236
/*
 * Exercise the file descriptor transport over two non-blocking pipes, one
 * per direction, so that neither side blocks while the driver alternates
 * between them.  Setup errors are fatal; returns non-zero if the
 * handshake/close test fails.
 */
static int
test_tls_fds(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int to_client[2];	/* [0] read by client, [1] written by server */
	int to_server[2];	/* [0] read by server, [1] written by client */
	int failure;

	if (pipe2(to_client, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");
	if (pipe2(to_server, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");

	if (tls_accept_fds(server, &server_cctx, to_server[0],
	    to_client[1]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_fds(client, to_client[0], to_server[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	failure = do_client_server_test("file descriptor", client,
	    server_cctx);

	tls_free(server_cctx);

	close(to_client[0]);
	close(to_client[1]);
	close(to_server[0]);
	close(to_server[1]);

	return (failure);
}
266
/*
 * Exercise the socket transport over a non-blocking AF_UNIX stream
 * socketpair; ends[0] is given to the server and ends[1] to the client.
 * Setup errors are fatal; returns non-zero if the handshake/close test
 * fails.
 */
static int
test_tls_socket(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int ends[2];
	int failure;

	if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, PF_UNSPEC,
	    ends) == -1)
		err(1, "failed to create socketpair");

	if (tls_accept_socket(server, &server_cctx, ends[0]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_socket(client, ends[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	failure = do_client_server_test("socket", client, server_cctx);

	tls_free(server_cctx);

	close(ends[0]);
	close(ends[1]);

	return (failure);
}
293
294static int
295test_tls(char *client_protocols, char *server_protocols, char *ciphers)
296{
297	struct tls_config *client_cfg, *server_cfg;
298	struct tls *client, *server;
299	uint32_t protocols;
300	int failure = 0;
301
302	if ((client = tls_client()) == NULL)
303		errx(1, "failed to create tls client");
304	if ((client_cfg = tls_config_new()) == NULL)
305		errx(1, "failed to create tls client config");
306	tls_config_insecure_noverifyname(client_cfg);
307	if (tls_config_parse_protocols(&protocols, client_protocols) == -1)
308		errx(1, "failed to parse protocols: %s", tls_config_error(client_cfg));
309	if (tls_config_set_protocols(client_cfg, protocols) == -1)
310		errx(1, "failed to set protocols: %s", tls_config_error(client_cfg));
311	if (tls_config_set_ciphers(client_cfg, ciphers) == -1)
312		errx(1, "failed to set ciphers: %s", tls_config_error(client_cfg));
313	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
314		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
315
316	if ((server = tls_server()) == NULL)
317		errx(1, "failed to create tls server");
318	if ((server_cfg = tls_config_new()) == NULL)
319		errx(1, "failed to create tls server config");
320	if (tls_config_parse_protocols(&protocols, server_protocols) == -1)
321		errx(1, "failed to parse protocols: %s", tls_config_error(server_cfg));
322	if (tls_config_set_protocols(server_cfg, protocols) == -1)
323		errx(1, "failed to set protocols: %s", tls_config_error(server_cfg));
324	if (tls_config_set_ciphers(server_cfg, ciphers) == -1)
325		errx(1, "failed to set ciphers: %s", tls_config_error(server_cfg));
326	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
327		errx(1, "failed to set keypair: %s",
328		    tls_config_error(server_cfg));
329
330	if (tls_configure(client, client_cfg) == -1)
331		errx(1, "failed to configure client: %s", tls_error(client));
332	tls_reset(server);
333	if (tls_configure(server, server_cfg) == -1)
334		errx(1, "failed to configure server: %s", tls_error(server));
335
336	tls_config_free(client_cfg);
337	tls_config_free(server_cfg);
338
339	failure |= test_tls_cbs(client, server);
340
341	tls_free(client);
342	tls_free(server);
343
344	return (failure);
345}
346
347static int
348do_tls_tests(void)
349{
350	struct tls_config *client_cfg, *server_cfg;
351	struct tls *client, *server;
352	int failure = 0;
353
354	printf("== TLS tests ==\n");
355
356	if ((client = tls_client()) == NULL)
357		errx(1, "failed to create tls client");
358	if ((client_cfg = tls_config_new()) == NULL)
359		errx(1, "failed to create tls client config");
360	tls_config_insecure_noverifyname(client_cfg);
361	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
362		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
363
364	if ((server = tls_server()) == NULL)
365		errx(1, "failed to create tls server");
366	if ((server_cfg = tls_config_new()) == NULL)
367		errx(1, "failed to create tls server config");
368	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
369		errx(1, "failed to set keypair: %s",
370		    tls_config_error(server_cfg));
371
372	tls_reset(client);
373	if (tls_configure(client, client_cfg) == -1)
374		errx(1, "failed to configure client: %s", tls_error(client));
375	tls_reset(server);
376	if (tls_configure(server, server_cfg) == -1)
377		errx(1, "failed to configure server: %s", tls_error(server));
378
379	failure |= test_tls_cbs(client, server);
380
381	tls_reset(client);
382	if (tls_configure(client, client_cfg) == -1)
383		errx(1, "failed to configure client: %s", tls_error(client));
384	tls_reset(server);
385	if (tls_configure(server, server_cfg) == -1)
386		errx(1, "failed to configure server: %s", tls_error(server));
387
388	failure |= test_tls_fds(client, server);
389
390	tls_reset(client);
391	if (tls_configure(client, client_cfg) == -1)
392		errx(1, "failed to configure client: %s", tls_error(client));
393	tls_reset(server);
394	if (tls_configure(server, server_cfg) == -1)
395		errx(1, "failed to configure server: %s", tls_error(server));
396
397	tls_config_free(client_cfg);
398	tls_config_free(server_cfg);
399
400	failure |= test_tls_socket(client, server);
401
402	tls_free(client);
403	tls_free(server);
404
405	printf("\n");
406
407	return (failure);
408}
409
410static int
411do_tls_ordering_tests(void)
412{
413	struct tls *client = NULL, *server = NULL, *server_cctx = NULL;
414	struct tls_config *client_cfg, *server_cfg;
415	int failure = 0;
416
417	printf("== TLS ordering tests ==\n");
418
419	if ((client = tls_client()) == NULL)
420		errx(1, "failed to create tls client");
421	if ((client_cfg = tls_config_new()) == NULL)
422		errx(1, "failed to create tls client config");
423	tls_config_insecure_noverifyname(client_cfg);
424	if (tls_config_set_ca_file(client_cfg, cafile) == -1)
425		errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
426
427	if ((server = tls_server()) == NULL)
428		errx(1, "failed to create tls server");
429	if ((server_cfg = tls_config_new()) == NULL)
430		errx(1, "failed to create tls server config");
431	if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
432		errx(1, "failed to set keypair: %s",
433		    tls_config_error(server_cfg));
434
435	if (tls_configure(client, client_cfg) == -1)
436		errx(1, "failed to configure client: %s", tls_error(client));
437	if (tls_configure(server, server_cfg) == -1)
438		errx(1, "failed to configure server: %s", tls_error(server));
439
440	tls_config_free(client_cfg);
441	tls_config_free(server_cfg);
442
443	if (tls_handshake(client) != -1) {
444		printf("FAIL: TLS handshake succeeded on unconnnected "
445		    "client context\n");
446		failure = 1;
447		goto done;
448	}
449
450	circular_init();
451
452	if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
453	    NULL) == -1)
454		errx(1, "failed to accept: %s", tls_error(server));
455
456	if (tls_connect_cbs(client, client_read, client_write, NULL,
457	    "test") == -1)
458		errx(1, "failed to connect: %s", tls_error(client));
459
460	if (do_client_server_handshake("ordering", client, server_cctx) != 0) {
461		failure = 1;
462		goto done;
463	}
464
465	if (tls_handshake(client) != -1) {
466		printf("FAIL: TLS handshake succeeded twice\n");
467		failure = 1;
468		goto done;
469	}
470
471	if (tls_handshake(server_cctx) != -1) {
472		printf("FAIL: TLS handshake succeeded twice\n");
473		failure = 1;
474		goto done;
475	}
476
477	if (do_client_server_close("ordering", client, server_cctx) != 0) {
478		failure = 1;
479		goto done;
480	}
481
482 done:
483	tls_free(client);
484	tls_free(server);
485	tls_free(server_cctx);
486
487	printf("\n");
488
489	return (failure);
490}
491
/* A pair of protocol strings: what the client and the server each allow. */
struct test_versions {
	char *client;
	char *server;
};

/*
 * Version matrix: each single version against "all", "all" against each
 * single version, and each version against itself.
 */
static struct test_versions tls_test_versions[] = {
	{ .client = "tlsv1.3", .server = "all" },
	{ .client = "tlsv1.2", .server = "all" },
	{ .client = "tlsv1.1", .server = "all" },
	{ .client = "tlsv1.0", .server = "all" },
	{ .client = "all", .server = "tlsv1.3" },
	{ .client = "all", .server = "tlsv1.2" },
	{ .client = "all", .server = "tlsv1.1" },
	{ .client = "all", .server = "tlsv1.0" },
	{ .client = "tlsv1.3", .server = "tlsv1.3" },
	{ .client = "tlsv1.2", .server = "tlsv1.2" },
	{ .client = "tlsv1.1", .server = "tlsv1.1" },
	{ .client = "tlsv1.0", .server = "tlsv1.0" },
};

/* Number of entries in tls_test_versions. */
#define N_TLS_VERSION_TESTS \
    (sizeof(tls_test_versions) / sizeof(tls_test_versions[0]))
514
515static int
516do_tls_version_tests(void)
517{
518	struct test_versions *tv;
519	int failure = 0;
520	size_t i;
521
522	printf("== TLS version tests ==\n");
523
524	for (i = 0; i < N_TLS_VERSION_TESTS; i++) {
525		tv = &tls_test_versions[i];
526		printf("INFO: version test %zu - client versions '%s' "
527		    "and server versions '%s'\n", i, tv->client, tv->server);
528		failure |= test_tls(tv->client, tv->server, "legacy");
529		printf("\n");
530	}
531
532	return failure;
533}
534
535int
536main(int argc, char **argv)
537{
538	int failure = 0;
539
540	if (argc != 4) {
541		fprintf(stderr, "usage: %s cafile certfile keyfile\n",
542		    argv[0]);
543		return (1);
544	}
545
546	cafile = argv[1];
547	certfile = argv[2];
548	keyfile = argv[3];
549
550	failure |= do_tls_tests();
551	failure |= do_tls_ordering_tests();
552	failure |= do_tls_version_tests();
553
554	return (failure);
555}
556