1/* $OpenBSD: recordtest.c,v 1.5 2022/06/10 22:00:15 tb Exp $ */
2/*
3 * Copyright (c) 2019 Joel Sing <jsing@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#include <err.h>
19#include <string.h>
20
21#include <openssl/ssl.h>
22
23#include "tls13_internal.h"
24#include "tls13_record.h"
25
/*
 * Valid record. The 5-byte record header is content type 0x16 (handshake),
 * version 0x0303 and length 0x007a (122), followed by exactly 122 bytes of
 * payload. The payload starts with handshake type 0x02 and a 3-byte length
 * of 0x000076 (118), so it appears to be a single ServerHello message.
 */
static uint8_t test_record_1[] = {
	0x16, 0x03, 0x03, 0x00, 0x7a, 0x02, 0x00, 0x00,
	0x76, 0x03, 0x03, 0x14, 0xae, 0x2b, 0x6d, 0x58,
	0xe9, 0x79, 0x9d, 0xd4, 0x90, 0x52, 0x90, 0x13,
	0x1c, 0x08, 0xaa, 0x3f, 0x5b, 0xfb, 0x64, 0xfe,
	0x9a, 0xca, 0x73, 0x6d, 0x87, 0x8d, 0x8b, 0x3b,
	0x70, 0x14, 0xa3, 0x20, 0xd7, 0x50, 0xa4, 0xe5,
	0x17, 0x42, 0x5d, 0xce, 0xe6, 0xfe, 0x1b, 0x59,
	0x27, 0x6b, 0xff, 0xc8, 0x40, 0xc7, 0xac, 0x16,
	0x32, 0xe6, 0x5b, 0xd2, 0xd9, 0xd4, 0xb5, 0x3f,
	0x8f, 0x74, 0x6e, 0x7d, 0x13, 0x02, 0x00, 0x00,
	0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00,
	0x20, 0x72, 0xb0, 0xaf, 0x7f, 0xf5, 0x89, 0x0f,
	0xcd, 0x6e, 0x45, 0xb1, 0x51, 0xa0, 0xbd, 0x1e,
	0xee, 0x7e, 0xf1, 0xa5, 0xc5, 0xc6, 0x7e, 0x5f,
	0x6a, 0xca, 0xc9, 0xe4, 0xae, 0xb9, 0x50, 0x76,
	0x0a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04,
};

/*
 * Truncated record. The header declares content type 0x17 (application
 * data), version 0x0303 and a length of 0x4100 (16640) bytes, but only
 * 3 bytes of payload follow the header.
 * NOTE(review): 16640 is presumably the TLS 1.3 maximum ciphertext length
 * (2^14 + 256); confirm against tls13_record.h.
 */
static uint8_t test_record_2[] = {
	0x17, 0x03, 0x03, 0x41, 0x00, 0x02, 0x00, 0x00,
};

/*
 * Oversized and truncated record. Identical to test_record_2 except that
 * the declared length is 0x4101 (16641), one byte more. The receive tests
 * expect tls13_record_recv() to reject it with TLS13_IO_RECORD_OVERFLOW
 * once the header has been read.
 */
static uint8_t test_record_3[] = {
	0x17, 0x03, 0x03, 0x41, 0x01, 0x02, 0x00, 0x00,
};
55
/*
 * Write len bytes of buf to stderr as " 0xNN," items, eight per line,
 * so the output can be pasted straight into a C array initialiser.
 * A final newline is added when the last line is only partially filled.
 */
static void
hexdump(const unsigned char *buf, size_t len)
{
	size_t idx;

	for (idx = 0; idx < len; idx++) {
		fprintf(stderr, " 0x%02x,", buf[idx]);
		if ((idx + 1) % 8 == 0)
			fprintf(stderr, "\n");
	}

	if (len % 8 != 0)
		fprintf(stderr, "\n");
}
66
/*
 * State shared between a test driver and read_cb/write_cb.
 */
struct rw_state {
	uint8_t *buf;		/* Buffer the callbacks read from / write to. */
	size_t offset;		/* Bytes transferred so far. */
	size_t len;		/* Bytes of buf usable in the current step. */
	uint8_t eof;		/* Non-zero: callbacks return TLS13_IO_EOF. */
};
73
74static ssize_t
75read_cb(void *buf, size_t buflen, void *cb_arg)
76{
77	struct rw_state *rs = cb_arg;
78	ssize_t n;
79
80	if (rs->eof)
81		return TLS13_IO_EOF;
82
83	if ((size_t)(n = buflen) > (rs->len - rs->offset))
84		n = rs->len - rs->offset;
85
86	if (n == 0)
87		return TLS13_IO_WANT_POLLIN;
88
89	memcpy(buf, &rs->buf[rs->offset], n);
90	rs->offset += n;
91
92	return n;
93}
94
95static ssize_t
96write_cb(const void *buf, size_t buflen, void *cb_arg)
97{
98	struct rw_state *ws = cb_arg;
99	ssize_t n;
100
101	if (ws->eof)
102		return TLS13_IO_EOF;
103
104	if ((size_t)(n = buflen) > (ws->len - ws->offset))
105		n = ws->len - ws->offset;
106
107	if (n == 0)
108		return TLS13_IO_WANT_POLLOUT;
109
110	memcpy(&ws->buf[ws->offset], buf, n);
111	ws->offset += n;
112
113	return n;
114}
115
/*
 * One step of a record test: the total number of bytes the I/O callback
 * may transfer, whether it reports EOF instead, and the value the record
 * send/receive call is expected to return for this step.
 */
struct record_test {
	size_t rw_len;		/* Total bytes the callback may transfer. */
	ssize_t want_ret;	/* Expected return value for this step. */
	int eof;		/* Non-zero: the callback reports EOF. */
};

/*
 * A receive test case: the bytes served to tls13_record_recv(), up to ten
 * steps (terminated by an all-zero entry), and the expected record.
 */
struct record_recv_test {
	uint8_t *read_buf;		/* Bytes served by the read callback. */
	struct record_test rt[10];	/* Steps to run. */
	uint8_t *want_data;		/* Expected record bytes, or NULL. */
	size_t want_len;		/* Length of want_data. */
	uint8_t want_content_type;	/* Expected record content type. */
};
129
130struct record_recv_test record_recv_tests[] = {
131	{
132		.read_buf = test_record_1,
133		.rt = {
134			{
135				.rw_len = sizeof(test_record_1),
136				.want_ret = sizeof(test_record_1),
137			},
138		},
139		.want_content_type = SSL3_RT_HANDSHAKE,
140		.want_data = test_record_1,
141		.want_len = sizeof(test_record_1),
142	},
143	{
144		.read_buf = test_record_1,
145		.rt = {
146			{
147				.rw_len = 0,
148				.want_ret = TLS13_IO_WANT_POLLIN,
149			},
150			{
151				.rw_len = sizeof(test_record_1),
152				.want_ret = sizeof(test_record_1),
153			},
154		},
155		.want_content_type = SSL3_RT_HANDSHAKE,
156		.want_data = test_record_1,
157		.want_len = sizeof(test_record_1),
158	},
159	{
160		.read_buf = test_record_1,
161		.rt = {
162			{
163				.rw_len = 0,
164				.want_ret = TLS13_IO_WANT_POLLIN,
165			},
166			{
167				.rw_len = 5,
168				.want_ret = TLS13_IO_WANT_POLLIN,
169			},
170			{
171				.rw_len = sizeof(test_record_1),
172				.want_ret = sizeof(test_record_1),
173			},
174		},
175		.want_content_type = SSL3_RT_HANDSHAKE,
176		.want_data = test_record_1,
177		.want_len = sizeof(test_record_1),
178	},
179	{
180		.read_buf = test_record_1,
181		.rt = {
182			{
183				.rw_len = 0,
184				.want_ret = TLS13_IO_WANT_POLLIN,
185			},
186			{
187				.rw_len = 2,
188				.want_ret = TLS13_IO_WANT_POLLIN,
189			},
190			{
191				.rw_len = 6,
192				.want_ret = TLS13_IO_WANT_POLLIN,
193			},
194			{
195				.rw_len = sizeof(test_record_1),
196				.want_ret = sizeof(test_record_1),
197			},
198		},
199		.want_content_type = SSL3_RT_HANDSHAKE,
200		.want_data = test_record_1,
201		.want_len = sizeof(test_record_1),
202	},
203	{
204		.read_buf = test_record_1,
205		.rt = {
206			{
207				.rw_len = 4,
208				.want_ret = TLS13_IO_WANT_POLLIN,
209			},
210			{
211				.eof = 1,
212				.want_ret = TLS13_IO_EOF,
213			},
214		},
215	},
216	{
217		.read_buf = test_record_1,
218		.rt = {
219			{
220				.eof = 1,
221				.want_ret = TLS13_IO_EOF,
222			},
223		},
224	},
225	{
226		.read_buf = test_record_2,
227		.rt = {
228			{
229				.rw_len = sizeof(test_record_2),
230				.want_ret = TLS13_IO_WANT_POLLIN,
231			},
232			{
233				.eof = 1,
234				.want_ret = TLS13_IO_EOF,
235			},
236		},
237		.want_content_type = SSL3_RT_APPLICATION_DATA,
238	},
239	{
240		.read_buf = test_record_3,
241		.rt = {
242			{
243				.rw_len = sizeof(test_record_3),
244				.want_ret = TLS13_IO_RECORD_OVERFLOW,
245			},
246		},
247	},
248};
249
250#define N_RECORD_RECV_TESTS (sizeof(record_recv_tests) / sizeof(record_recv_tests[0]))
251
252struct record_send_test {
253	uint8_t *data;
254	size_t data_len;
255	struct record_test rt[10];
256	uint8_t *want_data;
257	size_t want_len;
258};
259
260struct record_send_test record_send_tests[] = {
261	{
262		.data = test_record_1,
263		.data_len = sizeof(test_record_1),
264		.rt = {
265			{
266				.rw_len = sizeof(test_record_1),
267				.want_ret = sizeof(test_record_1),
268			},
269		},
270		.want_data = test_record_1,
271		.want_len = sizeof(test_record_1),
272	},
273	{
274		.data = test_record_1,
275		.data_len = sizeof(test_record_1),
276		.rt = {
277			{
278				.rw_len = 0,
279				.want_ret = TLS13_IO_WANT_POLLOUT,
280			},
281			{
282				.rw_len = sizeof(test_record_1),
283				.want_ret = sizeof(test_record_1),
284			},
285		},
286		.want_data = test_record_1,
287		.want_len = sizeof(test_record_1),
288	},
289	{
290		.data = test_record_1,
291		.data_len = sizeof(test_record_1),
292		.rt = {
293			{
294				.rw_len = 0,
295				.want_ret = TLS13_IO_WANT_POLLOUT,
296			},
297			{
298				.rw_len = 5,
299				.want_ret = TLS13_IO_WANT_POLLOUT,
300			},
301			{
302				.rw_len = sizeof(test_record_1),
303				.want_ret = sizeof(test_record_1),
304			},
305		},
306		.want_data = test_record_1,
307		.want_len = sizeof(test_record_1),
308	},
309	{
310		.data = test_record_1,
311		.data_len = sizeof(test_record_1),
312		.rt = {
313			{
314				.rw_len = 0,
315				.want_ret = TLS13_IO_WANT_POLLOUT,
316			},
317			{
318				.rw_len = 2,
319				.want_ret = TLS13_IO_WANT_POLLOUT,
320			},
321			{
322				.rw_len = 6,
323				.want_ret = TLS13_IO_WANT_POLLOUT,
324			},
325			{
326				.rw_len = sizeof(test_record_1),
327				.want_ret = sizeof(test_record_1),
328			},
329		},
330		.want_data = test_record_1,
331		.want_len = sizeof(test_record_1),
332	},
333	{
334		.data = test_record_1,
335		.data_len = sizeof(test_record_1),
336		.rt = {
337			{
338				.rw_len = 4,
339				.want_ret = TLS13_IO_WANT_POLLOUT,
340			},
341			{
342				.eof = 1,
343				.want_ret = TLS13_IO_EOF,
344			},
345		},
346		.want_data = test_record_1,
347		.want_len = 4,
348	},
349	{
350		.data = test_record_1,
351		.data_len = sizeof(test_record_1),
352		.rt = {
353			{
354				.rw_len = 0,
355				.want_ret = TLS13_IO_WANT_POLLOUT,
356			},
357			{
358				.eof = 1,
359				.want_ret = TLS13_IO_EOF,
360			},
361		},
362		.want_data = NULL,
363		.want_len = 0,
364	},
365};
366
367#define N_RECORD_SEND_TESTS (sizeof(record_send_tests) / sizeof(record_send_tests[0]))
368
369static int
370test_record_recv(size_t test_no, struct record_recv_test *rrt)
371{
372	struct tls13_record *rec;
373	struct rw_state rs;
374	int failed = 1;
375	ssize_t ret;
376	size_t i;
377	CBS cbs;
378
379	rs.buf = rrt->read_buf;
380	rs.offset = 0;
381
382	if ((rec = tls13_record_new()) == NULL)
383		errx(1, "tls13_record_new");
384
385	for (i = 0; rrt->rt[i].rw_len != 0 || rrt->rt[i].want_ret != 0; i++) {
386		rs.eof = rrt->rt[i].eof;
387		rs.len = rrt->rt[i].rw_len;
388
389		ret = tls13_record_recv(rec, read_cb, &rs);
390		if (ret != rrt->rt[i].want_ret) {
391			fprintf(stderr, "FAIL: Test %zu/%zu - tls_record_recv "
392			    "returned %zd, want %zd\n", test_no, i, ret,
393			    rrt->rt[i].want_ret);
394			goto failure;
395		}
396	}
397
398	if (tls13_record_content_type(rec) != rrt->want_content_type) {
399		fprintf(stderr, "FAIL: Test %zu - got content type %u, "
400		    "want %u\n", test_no, tls13_record_content_type(rec),
401		    rrt->want_content_type);
402		goto failure;
403	}
404
405	tls13_record_data(rec, &cbs);
406	if (rrt->want_data == NULL) {
407		if (CBS_data(&cbs) != NULL || CBS_len(&cbs) != 0) {
408			fprintf(stderr, "FAIL: Test %zu - got CBS with data, "
409			    "want NULL\n", test_no);
410			goto failure;
411		}
412		goto done;
413	}
414	if (!CBS_mem_equal(&cbs, rrt->want_data, rrt->want_len)) {
415		fprintf(stderr, "FAIL: Test %zu - data mismatch\n", test_no);
416		fprintf(stderr, "Got record data:\n");
417		hexdump(CBS_data(&cbs), CBS_len(&cbs));
418		fprintf(stderr, "Want record data:\n");
419		hexdump(rrt->want_data, rrt->want_len);
420		goto failure;
421	}
422
423	if (!tls13_record_header(rec, &cbs)) {
424		fprintf(stderr, "FAIL: Test %zu - fail to get record "
425		    "header", test_no);
426		goto failure;
427	}
428	if (!CBS_mem_equal(&cbs, rrt->want_data, TLS13_RECORD_HEADER_LEN)) {
429		fprintf(stderr, "FAIL: Test %zu - header mismatch\n", test_no);
430		fprintf(stderr, "Got record header:\n");
431		hexdump(CBS_data(&cbs), CBS_len(&cbs));
432		fprintf(stderr, "Want record header:\n");
433		hexdump(rrt->want_data, rrt->want_len);
434		goto failure;
435	}
436
437	if (!tls13_record_content(rec, &cbs)) {
438		fprintf(stderr, "FAIL: Test %zu - fail to get record "
439		    "content", test_no);
440		goto failure;
441	}
442	if (!CBS_mem_equal(&cbs, rrt->want_data + TLS13_RECORD_HEADER_LEN,
443	    rrt->want_len - TLS13_RECORD_HEADER_LEN)) {
444		fprintf(stderr, "FAIL: Test %zu - content mismatch\n", test_no);
445		fprintf(stderr, "Got record content:\n");
446		hexdump(CBS_data(&cbs), CBS_len(&cbs));
447		fprintf(stderr, "Want record content:\n");
448		hexdump(rrt->want_data, rrt->want_len);
449		goto failure;
450	}
451
452 done:
453	failed = 0;
454
455 failure:
456	tls13_record_free(rec);
457
458	return failed;
459}
460
461static int
462test_record_send(size_t test_no, struct record_send_test *rst)
463{
464	uint8_t *data = NULL;
465	struct tls13_record *rec;
466	struct rw_state ws;
467	int failed = 1;
468	ssize_t ret;
469	size_t i;
470
471	if ((ws.buf = malloc(TLS13_RECORD_MAX_LEN)) == NULL)
472		errx(1, "malloc");
473
474	ws.offset = 0;
475
476	if ((rec = tls13_record_new()) == NULL)
477		errx(1, "tls13_record_new");
478
479	if ((data = malloc(rst->data_len)) == NULL)
480		errx(1, "malloc");
481	memcpy(data, rst->data, rst->data_len);
482
483	if (!tls13_record_set_data(rec, data, rst->data_len)) {
484		fprintf(stderr, "FAIL: Test %zu - failed to set record data\n",
485		    test_no);
486		goto failure;
487	}
488	data = NULL;
489
490	for (i = 0; rst->rt[i].rw_len != 0 || rst->rt[i].want_ret != 0; i++) {
491		ws.eof = rst->rt[i].eof;
492		ws.len = rst->rt[i].rw_len;
493
494		ret = tls13_record_send(rec, write_cb, &ws);
495		if (ret != rst->rt[i].want_ret) {
496			fprintf(stderr, "FAIL: Test %zu/%zu - tls_record_send "
497			    "returned %zd, want %zd\n", test_no, i, ret,
498			    rst->rt[i].want_ret);
499			goto failure;
500		}
501	}
502
503	if (rst->want_data != NULL &&
504	    memcmp(ws.buf, rst->want_data, rst->want_len) != 0) {
505		fprintf(stderr, "FAIL: Test %zu - content mismatch\n", test_no);
506		fprintf(stderr, "Got record data:\n");
507		hexdump(rst->data, rst->data_len);
508		fprintf(stderr, "Want record data:\n");
509		hexdump(rst->want_data, rst->want_len);
510		goto failure;
511	}
512
513	failed = 0;
514
515 failure:
516	tls13_record_free(rec);
517	free(ws.buf);
518
519	return failed;
520}
521
522static int
523test_recv_records(void)
524{
525	int failed = 0;
526	size_t i;
527
528	for (i = 0; i < N_RECORD_RECV_TESTS; i++)
529		failed |= test_record_recv(i, &record_recv_tests[i]);
530
531	return failed;
532}
533
534static int
535test_send_records(void)
536{
537	int failed = 0;
538	size_t i;
539
540	for (i = 0; i < N_RECORD_SEND_TESTS; i++)
541		failed |= test_record_send(i, &record_send_tests[i]);
542
543	return failed;
544}
545
/*
 * Run the receive tests, then the send tests. The exit status is non-zero
 * if any case failed.
 */
int
main(int argc, char **argv)
{
	int recv_failed, send_failed;

	recv_failed = test_recv_records();
	send_failed = test_send_records();

	return recv_failed | send_failed;
}
556