1// SPDX-License-Identifier: GPL-2.0
2
3#define _GNU_SOURCE
4
#include <arpa/inet.h>
#include <errno.h>
#include <error.h>
#include <fcntl.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
13
14#include <linux/tls.h>
15#include <linux/tcp.h>
16#include <linux/socket.h>
17
18#include <sys/epoll.h>
19#include <sys/types.h>
20#include <sys/sendfile.h>
21#include <sys/socket.h>
22#include <sys/stat.h>
23
24#include "../kselftest_harness.h"
25
/* Payload size (2^14 bytes) the tests treat as one full TLS record. */
#define TLS_PAYLOAD_MAX_LEN 16384
/* kTLS socket option level; NOTE(review): presumably defined here for libc
 * headers that lack it — confirm.
 */
#define SOL_TLS 282

/* Nonzero in FIPS mode; FIXTURE_SETUP(tls) skips fips_non_compliant variants.
 * NOTE(review): not assigned in the visible code — presumably set in main().
 */
static int fips_enabled;
30
/*
 * Storage big enough for the crypto_info of any cipher the tests use.
 * All union members start at offset 0, so a pointer to this struct can be
 * handed directly to setsockopt(SOL_TLS, TLS_TX/TLS_RX).
 */
struct tls_crypto_info_keys {
	union {
		struct tls_crypto_info crypto_info;	/* common header view */
		struct tls12_crypto_info_aes_gcm_128 aes128;
		struct tls12_crypto_info_chacha20_poly1305 chacha20;
		struct tls12_crypto_info_sm4_gcm sm4gcm;
		struct tls12_crypto_info_sm4_ccm sm4ccm;
		struct tls12_crypto_info_aes_ccm_128 aesccm128;
		struct tls12_crypto_info_aes_gcm_256 aesgcm256;
		struct tls12_crypto_info_aria_gcm_128 ariagcm128;
		struct tls12_crypto_info_aria_gcm_256 ariagcm256;
	};
	/* Size of the active cipher-specific struct; the setsockopt optlen. */
	size_t len;
};
45
46static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
47				 struct tls_crypto_info_keys *tls12)
48{
49	memset(tls12, 0, sizeof(*tls12));
50
51	switch (cipher_type) {
52	case TLS_CIPHER_CHACHA20_POLY1305:
53		tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
54		tls12->chacha20.info.version = tls_version;
55		tls12->chacha20.info.cipher_type = cipher_type;
56		break;
57	case TLS_CIPHER_AES_GCM_128:
58		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
59		tls12->aes128.info.version = tls_version;
60		tls12->aes128.info.cipher_type = cipher_type;
61		break;
62	case TLS_CIPHER_SM4_GCM:
63		tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
64		tls12->sm4gcm.info.version = tls_version;
65		tls12->sm4gcm.info.cipher_type = cipher_type;
66		break;
67	case TLS_CIPHER_SM4_CCM:
68		tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
69		tls12->sm4ccm.info.version = tls_version;
70		tls12->sm4ccm.info.cipher_type = cipher_type;
71		break;
72	case TLS_CIPHER_AES_CCM_128:
73		tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
74		tls12->aesccm128.info.version = tls_version;
75		tls12->aesccm128.info.cipher_type = cipher_type;
76		break;
77	case TLS_CIPHER_AES_GCM_256:
78		tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
79		tls12->aesgcm256.info.version = tls_version;
80		tls12->aesgcm256.info.cipher_type = cipher_type;
81		break;
82	case TLS_CIPHER_ARIA_GCM_128:
83		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_128);
84		tls12->ariagcm128.info.version = tls_version;
85		tls12->ariagcm128.info.cipher_type = cipher_type;
86		break;
87	case TLS_CIPHER_ARIA_GCM_256:
88		tls12->len = sizeof(struct tls12_crypto_info_aria_gcm_256);
89		tls12->ariagcm256.info.version = tls_version;
90		tls12->ariagcm256.info.cipher_type = cipher_type;
91		break;
92	default:
93		break;
94	}
95}
96
/*
 * Fill @n bytes at @s with rand() output: one rand() value per int-sized
 * word, then one per trailing byte. Words are copied with memcpy so @s need
 * not be int-aligned, and the (char) buffer is never written through an int
 * lvalue (which would break strict aliasing).
 */
static void memrnd(void *s, size_t n)
{
	unsigned char *byte = s;

	for (; n >= sizeof(int); n -= sizeof(int), byte += sizeof(int)) {
		int word = rand();

		memcpy(byte, &word, sizeof(word));
	}
	while (n--)
		*byte++ = rand();
}
108
/*
 * Create a connected loopback TCP pair: *fd is the connecting socket and
 * *cfd the accepted one. The "tls" ULP is then attached to both ends. If
 * attaching fails with ENOENT, *notls is set and both sockets stay plain
 * TCP; any other failure fails the test.
 */
static void ulp_sock_pair(struct __test_metadata *_metadata,
			  int *fd, int *cfd, bool *notls)
{
	struct sockaddr_in addr;
	socklen_t len;
	int sfd, ret;

	*notls = false;
	len = sizeof(addr);

	/* Zero the whole address so sin_zero is not left uninitialized. */
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = 0;	/* let the kernel pick an ephemeral port */

	*fd = socket(AF_INET, SOCK_STREAM, 0);
	ASSERT_GE(*fd, 0);
	sfd = socket(AF_INET, SOCK_STREAM, 0);
	ASSERT_GE(sfd, 0);

	ret = bind(sfd, (struct sockaddr *)&addr, sizeof(addr));
	ASSERT_EQ(ret, 0);
	ret = listen(sfd, 10);
	ASSERT_EQ(ret, 0);

	/* Learn the port the kernel assigned. */
	ret = getsockname(sfd, (struct sockaddr *)&addr, &len);
	ASSERT_EQ(ret, 0);

	ret = connect(*fd, (struct sockaddr *)&addr, sizeof(addr));
	ASSERT_EQ(ret, 0);

	*cfd = accept(sfd, (struct sockaddr *)&addr, &len);
	ASSERT_GE(*cfd, 0);

	close(sfd);

	ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	if (ret != 0) {
		/* ENOENT: the tls ULP is not available in this kernel. */
		ASSERT_EQ(errno, ENOENT);
		*notls = true;
		printf("Failure setting TCP_ULP, testing without tls\n");
		return;
	}

	ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
	ASSERT_EQ(ret, 0);
}
153
154/* Produce a basic cmsg */
155static int tls_send_cmsg(int fd, unsigned char record_type,
156			 void *data, size_t len, int flags)
157{
158	char cbuf[CMSG_SPACE(sizeof(char))];
159	int cmsg_len = sizeof(char);
160	struct cmsghdr *cmsg;
161	struct msghdr msg;
162	struct iovec vec;
163
164	vec.iov_base = data;
165	vec.iov_len = len;
166	memset(&msg, 0, sizeof(struct msghdr));
167	msg.msg_iov = &vec;
168	msg.msg_iovlen = 1;
169	msg.msg_control = cbuf;
170	msg.msg_controllen = sizeof(cbuf);
171	cmsg = CMSG_FIRSTHDR(&msg);
172	cmsg->cmsg_level = SOL_TLS;
173	/* test sending non-record types. */
174	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
175	cmsg->cmsg_len = CMSG_LEN(cmsg_len);
176	*CMSG_DATA(cmsg) = record_type;
177	msg.msg_controllen = cmsg->cmsg_len;
178
179	return sendmsg(fd, &msg, flags);
180}
181
182static int tls_recv_cmsg(struct __test_metadata *_metadata,
183			 int fd, unsigned char record_type,
184			 void *data, size_t len, int flags)
185{
186	char cbuf[CMSG_SPACE(sizeof(char))];
187	struct cmsghdr *cmsg;
188	unsigned char ctype;
189	struct msghdr msg;
190	struct iovec vec;
191	int n;
192
193	vec.iov_base = data;
194	vec.iov_len = len;
195	memset(&msg, 0, sizeof(struct msghdr));
196	msg.msg_iov = &vec;
197	msg.msg_iovlen = 1;
198	msg.msg_control = cbuf;
199	msg.msg_controllen = sizeof(cbuf);
200
201	n = recvmsg(fd, &msg, flags);
202
203	cmsg = CMSG_FIRSTHDR(&msg);
204	EXPECT_NE(cmsg, NULL);
205	EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
206	EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
207	ctype = *((unsigned char *)CMSG_DATA(cmsg));
208	EXPECT_EQ(ctype, record_type);
209
210	return n;
211}
212
FIXTURE(tls_basic)
{
	int fd;		/* connecting end of the TCP pair */
	int cfd;	/* accepted end of the TCP pair */
	bool notls;	/* set when the tls ULP could not be attached */
};
218
219FIXTURE_SETUP(tls_basic)
220{
221	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
222}
223
224FIXTURE_TEARDOWN(tls_basic)
225{
226	close(self->fd);
227	close(self->cfd);
228}
229
230/* Send some data through with ULP but no keys */
231TEST_F(tls_basic, base_base)
232{
233	char const *test_str = "test_read";
234	int send_len = 10;
235	char buf[10];
236
237	ASSERT_EQ(strlen(test_str) + 1, send_len);
238
239	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
240	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
241	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
242};
243
244TEST_F(tls_basic, bad_cipher)
245{
246	struct tls_crypto_info_keys tls12;
247
248	tls12.crypto_info.version = 200;
249	tls12.crypto_info.cipher_type = TLS_CIPHER_AES_GCM_128;
250	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
251
252	tls12.crypto_info.version = TLS_1_2_VERSION;
253	tls12.crypto_info.cipher_type = 50;
254	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
255
256	tls12.crypto_info.version = TLS_1_2_VERSION;
257	tls12.crypto_info.cipher_type = 59;
258	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
259
260	tls12.crypto_info.version = TLS_1_2_VERSION;
261	tls12.crypto_info.cipher_type = 10;
262	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
263
264	tls12.crypto_info.version = TLS_1_2_VERSION;
265	tls12.crypto_info.cipher_type = 70;
266	EXPECT_EQ(setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, sizeof(struct tls12_crypto_info_aes_gcm_128)), -1);
267}
268
FIXTURE(tls)
{
	int fd;		/* connecting end; TLS TX is configured here */
	int cfd;	/* accepted end; TLS RX is configured here */
	bool notls;	/* set when the tls ULP could not be attached */
};
274
FIXTURE_VARIANT(tls)
{
	uint16_t tls_version;		/* TLS_1_2_VERSION or TLS_1_3_VERSION */
	uint16_t cipher_type;		/* TLS_CIPHER_* id */
	bool nopad;			/* also set TLS_RX_EXPECT_NO_PAD on cfd */
	bool fips_non_compliant;	/* skipped when fips_enabled is set */
};
281
282FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
283{
284	.tls_version = TLS_1_2_VERSION,
285	.cipher_type = TLS_CIPHER_AES_GCM_128,
286};
287
288FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
289{
290	.tls_version = TLS_1_3_VERSION,
291	.cipher_type = TLS_CIPHER_AES_GCM_128,
292};
293
294FIXTURE_VARIANT_ADD(tls, 12_chacha)
295{
296	.tls_version = TLS_1_2_VERSION,
297	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
298	.fips_non_compliant = true,
299};
300
301FIXTURE_VARIANT_ADD(tls, 13_chacha)
302{
303	.tls_version = TLS_1_3_VERSION,
304	.cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
305	.fips_non_compliant = true,
306};
307
308FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
309{
310	.tls_version = TLS_1_3_VERSION,
311	.cipher_type = TLS_CIPHER_SM4_GCM,
312	.fips_non_compliant = true,
313};
314
315FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
316{
317	.tls_version = TLS_1_3_VERSION,
318	.cipher_type = TLS_CIPHER_SM4_CCM,
319	.fips_non_compliant = true,
320};
321
322FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
323{
324	.tls_version = TLS_1_2_VERSION,
325	.cipher_type = TLS_CIPHER_AES_CCM_128,
326};
327
328FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
329{
330	.tls_version = TLS_1_3_VERSION,
331	.cipher_type = TLS_CIPHER_AES_CCM_128,
332};
333
334FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
335{
336	.tls_version = TLS_1_2_VERSION,
337	.cipher_type = TLS_CIPHER_AES_GCM_256,
338};
339
340FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
341{
342	.tls_version = TLS_1_3_VERSION,
343	.cipher_type = TLS_CIPHER_AES_GCM_256,
344};
345
346FIXTURE_VARIANT_ADD(tls, 13_nopad)
347{
348	.tls_version = TLS_1_3_VERSION,
349	.cipher_type = TLS_CIPHER_AES_GCM_128,
350	.nopad = true,
351};
352
353FIXTURE_VARIANT_ADD(tls, 12_aria_gcm)
354{
355	.tls_version = TLS_1_2_VERSION,
356	.cipher_type = TLS_CIPHER_ARIA_GCM_128,
357};
358
359FIXTURE_VARIANT_ADD(tls, 12_aria_gcm_256)
360{
361	.tls_version = TLS_1_2_VERSION,
362	.cipher_type = TLS_CIPHER_ARIA_GCM_256,
363};
364
365FIXTURE_SETUP(tls)
366{
367	struct tls_crypto_info_keys tls12;
368	int one = 1;
369	int ret;
370
371	if (fips_enabled && variant->fips_non_compliant)
372		SKIP(return, "Unsupported cipher in FIPS mode");
373
374	tls_crypto_info_init(variant->tls_version, variant->cipher_type,
375			     &tls12);
376
377	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
378
379	if (self->notls)
380		return;
381
382	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
383	ASSERT_EQ(ret, 0);
384
385	ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
386	ASSERT_EQ(ret, 0);
387
388	if (variant->nopad) {
389		ret = setsockopt(self->cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
390				 (void *)&one, sizeof(one));
391		ASSERT_EQ(ret, 0);
392	}
393}
394
395FIXTURE_TEARDOWN(tls)
396{
397	close(self->fd);
398	close(self->cfd);
399}
400
401TEST_F(tls, sendfile)
402{
403	int filefd = open("/proc/self/exe", O_RDONLY);
404	struct stat st;
405
406	EXPECT_GE(filefd, 0);
407	fstat(filefd, &st);
408	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
409}
410
411TEST_F(tls, send_then_sendfile)
412{
413	int filefd = open("/proc/self/exe", O_RDONLY);
414	char const *test_str = "test_send";
415	int to_send = strlen(test_str) + 1;
416	char recv_buf[10];
417	struct stat st;
418	char *buf;
419
420	EXPECT_GE(filefd, 0);
421	fstat(filefd, &st);
422	buf = (char *)malloc(st.st_size);
423
424	EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
425	EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
426	EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
427
428	EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
429	EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
430}
431
432static void chunked_sendfile(struct __test_metadata *_metadata,
433			     struct _test_data_tls *self,
434			     uint16_t chunk_size,
435			     uint16_t extra_payload_size)
436{
437	char buf[TLS_PAYLOAD_MAX_LEN];
438	uint16_t test_payload_size;
439	int size = 0;
440	int ret;
441	char filename[] = "/tmp/mytemp.XXXXXX";
442	int fd = mkstemp(filename);
443	off_t offset = 0;
444
445	unlink(filename);
446	ASSERT_GE(fd, 0);
447	EXPECT_GE(chunk_size, 1);
448	test_payload_size = chunk_size + extra_payload_size;
449	ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
450	memset(buf, 1, test_payload_size);
451	size = write(fd, buf, test_payload_size);
452	EXPECT_EQ(size, test_payload_size);
453	fsync(fd);
454
455	while (size > 0) {
456		ret = sendfile(self->fd, fd, &offset, chunk_size);
457		EXPECT_GE(ret, 0);
458		size -= ret;
459	}
460
461	EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
462		  test_payload_size);
463
464	close(fd);
465}
466
467TEST_F(tls, multi_chunk_sendfile)
468{
469	chunked_sendfile(_metadata, self, 4096, 4096);
470	chunked_sendfile(_metadata, self, 4096, 0);
471	chunked_sendfile(_metadata, self, 4096, 1);
472	chunked_sendfile(_metadata, self, 4096, 2048);
473	chunked_sendfile(_metadata, self, 8192, 2048);
474	chunked_sendfile(_metadata, self, 4096, 8192);
475	chunked_sendfile(_metadata, self, 8192, 4096);
476	chunked_sendfile(_metadata, self, 12288, 1024);
477	chunked_sendfile(_metadata, self, 12288, 2000);
478	chunked_sendfile(_metadata, self, 15360, 100);
479	chunked_sendfile(_metadata, self, 15360, 300);
480	chunked_sendfile(_metadata, self, 1, 4096);
481	chunked_sendfile(_metadata, self, 2048, 4096);
482	chunked_sendfile(_metadata, self, 2048, 8192);
483	chunked_sendfile(_metadata, self, 4096, 8192);
484	chunked_sendfile(_metadata, self, 1024, 12288);
485	chunked_sendfile(_metadata, self, 2000, 12288);
486	chunked_sendfile(_metadata, self, 100, 15360);
487	chunked_sendfile(_metadata, self, 300, 15360);
488}
489
490TEST_F(tls, recv_max)
491{
492	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
493	char recv_mem[TLS_PAYLOAD_MAX_LEN];
494	char buf[TLS_PAYLOAD_MAX_LEN];
495
496	memrnd(buf, sizeof(buf));
497
498	EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
499	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
500	EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
501}
502
503TEST_F(tls, recv_small)
504{
505	char const *test_str = "test_read";
506	int send_len = 10;
507	char buf[10];
508
509	send_len = strlen(test_str) + 1;
510	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
511	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
512	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
513}
514
515TEST_F(tls, msg_more)
516{
517	char const *test_str = "test_read";
518	int send_len = 10;
519	char buf[10 * 2];
520
521	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
522	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
523	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
524	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
525		  send_len * 2);
526	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
527}
528
529TEST_F(tls, msg_more_unsent)
530{
531	char const *test_str = "test_read";
532	int send_len = 10;
533	char buf[10];
534
535	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
536	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
537}
538
539TEST_F(tls, msg_eor)
540{
541	char const *test_str = "test_read";
542	int send_len = 10;
543	char buf[10];
544
545	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len);
546	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
547	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
548}
549
550TEST_F(tls, sendmsg_single)
551{
552	struct msghdr msg;
553
554	char const *test_str = "test_sendmsg";
555	size_t send_len = 13;
556	struct iovec vec;
557	char buf[13];
558
559	vec.iov_base = (char *)test_str;
560	vec.iov_len = send_len;
561	memset(&msg, 0, sizeof(struct msghdr));
562	msg.msg_iov = &vec;
563	msg.msg_iovlen = 1;
564	EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
565	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
566	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
567}
568
569#define MAX_FRAGS	64
570#define SEND_LEN	13
571TEST_F(tls, sendmsg_fragmented)
572{
573	char const *test_str = "test_sendmsg";
574	char buf[SEND_LEN * MAX_FRAGS];
575	struct iovec vec[MAX_FRAGS];
576	struct msghdr msg;
577	int i, frags;
578
579	for (frags = 1; frags <= MAX_FRAGS; frags++) {
580		for (i = 0; i < frags; i++) {
581			vec[i].iov_base = (char *)test_str;
582			vec[i].iov_len = SEND_LEN;
583		}
584
585		memset(&msg, 0, sizeof(struct msghdr));
586		msg.msg_iov = vec;
587		msg.msg_iovlen = frags;
588
589		EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
590		EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
591			  SEND_LEN * frags);
592
593		for (i = 0; i < frags; i++)
594			EXPECT_EQ(memcmp(buf + SEND_LEN * i,
595					 test_str, SEND_LEN), 0);
596	}
597}
598#undef MAX_FRAGS
599#undef SEND_LEN
600
601TEST_F(tls, sendmsg_large)
602{
603	void *mem = malloc(16384);
604	size_t send_len = 16384;
605	size_t sends = 128;
606	struct msghdr msg;
607	size_t recvs = 0;
608	size_t sent = 0;
609
610	memset(&msg, 0, sizeof(struct msghdr));
611	while (sent++ < sends) {
612		struct iovec vec = { (void *)mem, send_len };
613
614		msg.msg_iov = &vec;
615		msg.msg_iovlen = 1;
616		EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
617	}
618
619	while (recvs++ < sends) {
620		EXPECT_NE(recv(self->cfd, mem, send_len, 0), -1);
621	}
622
623	free(mem);
624}
625
626TEST_F(tls, sendmsg_multiple)
627{
628	char const *test_str = "test_sendmsg_multiple";
629	struct iovec vec[5];
630	char *test_strs[5];
631	struct msghdr msg;
632	int total_len = 0;
633	int len_cmp = 0;
634	int iov_len = 5;
635	char *buf;
636	int i;
637
638	memset(&msg, 0, sizeof(struct msghdr));
639	for (i = 0; i < iov_len; i++) {
640		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
641		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
642		vec[i].iov_base = (void *)test_strs[i];
643		vec[i].iov_len = strlen(test_strs[i]) + 1;
644		total_len += vec[i].iov_len;
645	}
646	msg.msg_iov = vec;
647	msg.msg_iovlen = iov_len;
648
649	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
650	buf = malloc(total_len);
651	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
652	for (i = 0; i < iov_len; i++) {
653		EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
654				 strlen(test_strs[i])),
655			  0);
656		len_cmp += strlen(buf + len_cmp) + 1;
657	}
658	for (i = 0; i < iov_len; i++)
659		free(test_strs[i]);
660	free(buf);
661}
662
663TEST_F(tls, sendmsg_multiple_stress)
664{
665	char const *test_str = "abcdefghijklmno";
666	struct iovec vec[1024];
667	char *test_strs[1024];
668	int iov_len = 1024;
669	int total_len = 0;
670	char buf[1 << 14];
671	struct msghdr msg;
672	int len_cmp = 0;
673	int i;
674
675	memset(&msg, 0, sizeof(struct msghdr));
676	for (i = 0; i < iov_len; i++) {
677		test_strs[i] = (char *)malloc(strlen(test_str) + 1);
678		snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
679		vec[i].iov_base = (void *)test_strs[i];
680		vec[i].iov_len = strlen(test_strs[i]) + 1;
681		total_len += vec[i].iov_len;
682	}
683	msg.msg_iov = vec;
684	msg.msg_iovlen = iov_len;
685
686	EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
687	EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
688
689	for (i = 0; i < iov_len; i++)
690		len_cmp += strlen(buf + len_cmp) + 1;
691
692	for (i = 0; i < iov_len; i++)
693		free(test_strs[i]);
694}
695
696TEST_F(tls, splice_from_pipe)
697{
698	int send_len = TLS_PAYLOAD_MAX_LEN;
699	char mem_send[TLS_PAYLOAD_MAX_LEN];
700	char mem_recv[TLS_PAYLOAD_MAX_LEN];
701	int p[2];
702
703	ASSERT_GE(pipe(p), 0);
704	EXPECT_GE(write(p[1], mem_send, send_len), 0);
705	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
706	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
707	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
708}
709
710TEST_F(tls, splice_more)
711{
712	unsigned int f = SPLICE_F_NONBLOCK | SPLICE_F_MORE | SPLICE_F_GIFT;
713	int send_len = TLS_PAYLOAD_MAX_LEN;
714	char mem_send[TLS_PAYLOAD_MAX_LEN];
715	int i, send_pipe = 1;
716	int p[2];
717
718	ASSERT_GE(pipe(p), 0);
719	EXPECT_GE(write(p[1], mem_send, send_len), 0);
720	for (i = 0; i < 32; i++)
721		EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, send_pipe, f), 1);
722}
723
724TEST_F(tls, splice_from_pipe2)
725{
726	int send_len = 16000;
727	char mem_send[16000];
728	char mem_recv[16000];
729	int p2[2];
730	int p[2];
731
732	memrnd(mem_send, sizeof(mem_send));
733
734	ASSERT_GE(pipe(p), 0);
735	ASSERT_GE(pipe(p2), 0);
736	EXPECT_EQ(write(p[1], mem_send, 8000), 8000);
737	EXPECT_EQ(splice(p[0], NULL, self->fd, NULL, 8000, 0), 8000);
738	EXPECT_EQ(write(p2[1], mem_send + 8000, 8000), 8000);
739	EXPECT_EQ(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 8000);
740	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
741	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
742}
743
744TEST_F(tls, send_and_splice)
745{
746	int send_len = TLS_PAYLOAD_MAX_LEN;
747	char mem_send[TLS_PAYLOAD_MAX_LEN];
748	char mem_recv[TLS_PAYLOAD_MAX_LEN];
749	char const *test_str = "test_read";
750	int send_len2 = 10;
751	char buf[10];
752	int p[2];
753
754	ASSERT_GE(pipe(p), 0);
755	EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
756	EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
757	EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
758
759	EXPECT_GE(write(p[1], mem_send, send_len), send_len);
760	EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
761
762	EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
763	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
764}
765
766TEST_F(tls, splice_to_pipe)
767{
768	int send_len = TLS_PAYLOAD_MAX_LEN;
769	char mem_send[TLS_PAYLOAD_MAX_LEN];
770	char mem_recv[TLS_PAYLOAD_MAX_LEN];
771	int p[2];
772
773	memrnd(mem_send, sizeof(mem_send));
774
775	ASSERT_GE(pipe(p), 0);
776	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
777	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), send_len);
778	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
779	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
780}
781
782TEST_F(tls, splice_cmsg_to_pipe)
783{
784	char *test_str = "test_read";
785	char record_type = 100;
786	int send_len = 10;
787	char buf[10];
788	int p[2];
789
790	if (self->notls)
791		SKIP(return, "no TLS support");
792
793	ASSERT_GE(pipe(p), 0);
794	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
795	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
796	EXPECT_EQ(errno, EINVAL);
797	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
798	EXPECT_EQ(errno, EIO);
799	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
800				buf, sizeof(buf), MSG_WAITALL),
801		  send_len);
802	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
803}
804
805TEST_F(tls, splice_dec_cmsg_to_pipe)
806{
807	char *test_str = "test_read";
808	char record_type = 100;
809	int send_len = 10;
810	char buf[10];
811	int p[2];
812
813	if (self->notls)
814		SKIP(return, "no TLS support");
815
816	ASSERT_GE(pipe(p), 0);
817	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
818	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
819	EXPECT_EQ(errno, EIO);
820	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
821	EXPECT_EQ(errno, EINVAL);
822	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
823				buf, sizeof(buf), MSG_WAITALL),
824		  send_len);
825	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
826}
827
828TEST_F(tls, recv_and_splice)
829{
830	int send_len = TLS_PAYLOAD_MAX_LEN;
831	char mem_send[TLS_PAYLOAD_MAX_LEN];
832	char mem_recv[TLS_PAYLOAD_MAX_LEN];
833	int half = send_len / 2;
834	int p[2];
835
836	ASSERT_GE(pipe(p), 0);
837	EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
838	/* Recv hald of the record, splice the other half */
839	EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
840	EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
841		  half);
842	EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
843	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
844}
845
846TEST_F(tls, peek_and_splice)
847{
848	int send_len = TLS_PAYLOAD_MAX_LEN;
849	char mem_send[TLS_PAYLOAD_MAX_LEN];
850	char mem_recv[TLS_PAYLOAD_MAX_LEN];
851	int chunk = TLS_PAYLOAD_MAX_LEN / 4;
852	int n, i, p[2];
853
854	memrnd(mem_send, sizeof(mem_send));
855
856	ASSERT_GE(pipe(p), 0);
857	for (i = 0; i < 4; i++)
858		EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
859			  chunk);
860
861	EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
862		       MSG_WAITALL | MSG_PEEK),
863		  chunk * 5 / 2);
864	EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
865
866	n = 0;
867	while (n < send_len) {
868		i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
869		EXPECT_GT(i, 0);
870		n += i;
871	}
872	EXPECT_EQ(n, send_len);
873	EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
874	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
875}
876
877TEST_F(tls, recvmsg_single)
878{
879	char const *test_str = "test_recvmsg_single";
880	int send_len = strlen(test_str) + 1;
881	char buf[20];
882	struct msghdr hdr;
883	struct iovec vec;
884
885	memset(&hdr, 0, sizeof(hdr));
886	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
887	vec.iov_base = (char *)buf;
888	vec.iov_len = send_len;
889	hdr.msg_iovlen = 1;
890	hdr.msg_iov = &vec;
891	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
892	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
893}
894
895TEST_F(tls, recvmsg_single_max)
896{
897	int send_len = TLS_PAYLOAD_MAX_LEN;
898	char send_mem[TLS_PAYLOAD_MAX_LEN];
899	char recv_mem[TLS_PAYLOAD_MAX_LEN];
900	struct iovec vec;
901	struct msghdr hdr;
902
903	memrnd(send_mem, sizeof(send_mem));
904
905	EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
906	vec.iov_base = (char *)recv_mem;
907	vec.iov_len = TLS_PAYLOAD_MAX_LEN;
908
909	hdr.msg_iovlen = 1;
910	hdr.msg_iov = &vec;
911	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
912	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
913}
914
915TEST_F(tls, recvmsg_multiple)
916{
917	unsigned int msg_iovlen = 1024;
918	struct iovec vec[1024];
919	char *iov_base[1024];
920	unsigned int iov_len = 16;
921	int send_len = 1 << 14;
922	char buf[1 << 14];
923	struct msghdr hdr;
924	int i;
925
926	memrnd(buf, sizeof(buf));
927
928	EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
929	for (i = 0; i < msg_iovlen; i++) {
930		iov_base[i] = (char *)malloc(iov_len);
931		vec[i].iov_base = iov_base[i];
932		vec[i].iov_len = iov_len;
933	}
934
935	hdr.msg_iovlen = msg_iovlen;
936	hdr.msg_iov = vec;
937	EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
938
939	for (i = 0; i < msg_iovlen; i++)
940		free(iov_base[i]);
941}
942
943TEST_F(tls, single_send_multiple_recv)
944{
945	unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
946	unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
947	char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
948	char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
949
950	memrnd(send_mem, sizeof(send_mem));
951
952	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
953	memset(recv_mem, 0, total_len);
954
955	EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
956	EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
957	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
958}
959
960TEST_F(tls, multiple_send_single_recv)
961{
962	unsigned int total_len = 2 * 10;
963	unsigned int send_len = 10;
964	char recv_mem[2 * 10];
965	char send_mem[10];
966
967	memrnd(send_mem, sizeof(send_mem));
968
969	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
970	EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
971	memset(recv_mem, 0, total_len);
972	EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
973
974	EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
975	EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
976}
977
978TEST_F(tls, single_send_multiple_recv_non_align)
979{
980	const unsigned int total_len = 15;
981	const unsigned int recv_len = 10;
982	char recv_mem[recv_len * 2];
983	char send_mem[total_len];
984
985	memrnd(send_mem, sizeof(send_mem));
986
987	EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
988	memset(recv_mem, 0, total_len);
989
990	EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
991	EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
992	EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
993}
994
995TEST_F(tls, recv_partial)
996{
997	char const *test_str = "test_read_partial";
998	char const *test_str_first = "test_read";
999	char const *test_str_second = "_partial";
1000	int send_len = strlen(test_str) + 1;
1001	char recv_mem[18];
1002
1003	memset(recv_mem, 0, sizeof(recv_mem));
1004	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1005	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
1006		       MSG_WAITALL), strlen(test_str_first));
1007	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
1008	memset(recv_mem, 0, sizeof(recv_mem));
1009	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
1010		       MSG_WAITALL), strlen(test_str_second));
1011	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
1012		  0);
1013}
1014
1015TEST_F(tls, recv_nonblock)
1016{
1017	char buf[4096];
1018	bool err;
1019
1020	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
1021	err = (errno == EAGAIN || errno == EWOULDBLOCK);
1022	EXPECT_EQ(err, true);
1023}
1024
1025TEST_F(tls, recv_peek)
1026{
1027	char const *test_str = "test_read_peek";
1028	int send_len = strlen(test_str) + 1;
1029	char buf[15];
1030
1031	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1032	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_PEEK), send_len);
1033	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1034	memset(buf, 0, sizeof(buf));
1035	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), send_len);
1036	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1037}
1038
1039TEST_F(tls, recv_peek_multiple)
1040{
1041	char const *test_str = "test_read_peek";
1042	int send_len = strlen(test_str) + 1;
1043	unsigned int num_peeks = 100;
1044	char buf[15];
1045	int i;
1046
1047	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1048	for (i = 0; i < num_peeks; i++) {
1049		EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
1050		EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1051		memset(buf, 0, sizeof(buf));
1052	}
1053	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1054	EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
1055}
1056
/*
 * Exercises MSG_PEEK around record boundaries. Round 1 sends two separate
 * records, peeks the first, then consumes both with one MSG_WAITALL read.
 * Round 2 sends the first half with MSG_MORE and checks that one peek sees
 * the whole string.
 */
TEST_F(tls, recv_peek_multiple_records)
{
	char const *test_str = "test_read_peek_mult_recs";
	char const *test_str_first = "test_read_peek";
	char const *test_str_second = "_mult_recs";
	int len;
	char buf[64];

	/* Round 1, first send: "test_read_peek" without its NUL. */
	len = strlen(test_str_first);
	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);

	/* Round 1, second send: "_mult_recs" including its NUL. */
	len = strlen(test_str_second) + 1;
	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);

	len = strlen(test_str_first);
	memset(buf, 0, len);
	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);

	/* MSG_PEEK can only peek into the current record. */
	len = strlen(test_str_first);
	EXPECT_EQ(memcmp(test_str_first, buf, len), 0);

	/* Consume both sends at once; together they spell test_str. */
	len = strlen(test_str) + 1;
	memset(buf, 0, len);
	EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);

	/* Non-MSG_PEEK will advance strparser (and therefore record)
	 * however.
	 */
	len = strlen(test_str) + 1;
	EXPECT_EQ(memcmp(test_str, buf, len), 0);

	/* MSG_MORE will hold current record open, so later MSG_PEEK
	 * will see everything.
	 */
	len = strlen(test_str_first);
	EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);

	len = strlen(test_str_second) + 1;
	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);

	/* One peek covering the full string, NUL included. */
	len = strlen(test_str) + 1;
	memset(buf, 0, len);
	EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);

	len = strlen(test_str) + 1;
	EXPECT_EQ(memcmp(test_str, buf, len), 0);
}
1105
1106TEST_F(tls, recv_peek_large_buf_mult_recs)
1107{
1108	char const *test_str = "test_read_peek_mult_recs";
1109	char const *test_str_first = "test_read_peek";
1110	char const *test_str_second = "_mult_recs";
1111	int len;
1112	char buf[64];
1113
1114	len = strlen(test_str_first);
1115	EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1116
1117	len = strlen(test_str_second) + 1;
1118	EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1119
1120	len = strlen(test_str) + 1;
1121	memset(buf, 0, len);
1122	EXPECT_NE((len = recv(self->cfd, buf, len,
1123			      MSG_PEEK | MSG_WAITALL)), -1);
1124	len = strlen(test_str) + 1;
1125	EXPECT_EQ(memcmp(test_str, buf, len), 0);
1126}
1127
1128TEST_F(tls, recv_lowat)
1129{
1130	char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1131	char recv_mem[20];
1132	int lowat = 8;
1133
1134	EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1135	EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1136
1137	memset(recv_mem, 0, 20);
1138	EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1139			     &lowat, sizeof(lowat)), 0);
1140	EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1141	EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1142	EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1143
1144	EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1145	EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1146}
1147
1148TEST_F(tls, bidir)
1149{
1150	char const *test_str = "test_read";
1151	int send_len = 10;
1152	char buf[10];
1153	int ret;
1154
1155	if (!self->notls) {
1156		struct tls_crypto_info_keys tls12;
1157
1158		tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1159				     &tls12);
1160
1161		ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1162				 tls12.len);
1163		ASSERT_EQ(ret, 0);
1164
1165		ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1166				 tls12.len);
1167		ASSERT_EQ(ret, 0);
1168	}
1169
1170	ASSERT_EQ(strlen(test_str) + 1, send_len);
1171
1172	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1173	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1174	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1175
1176	memset(buf, 0, sizeof(buf));
1177
1178	EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1179	EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1180	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1181};
1182
1183TEST_F(tls, pollin)
1184{
1185	char const *test_str = "test_poll";
1186	struct pollfd fd = { 0, 0, 0 };
1187	char buf[10];
1188	int send_len = 10;
1189
1190	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1191	fd.fd = self->cfd;
1192	fd.events = POLLIN;
1193
1194	EXPECT_EQ(poll(&fd, 1, 20), 1);
1195	EXPECT_EQ(fd.revents & POLLIN, 1);
1196	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1197	/* Test timing out */
1198	EXPECT_EQ(poll(&fd, 1, 20), 0);
1199}
1200
1201TEST_F(tls, poll_wait)
1202{
1203	char const *test_str = "test_poll_wait";
1204	int send_len = strlen(test_str) + 1;
1205	struct pollfd fd = { 0, 0, 0 };
1206	char recv_mem[15];
1207
1208	fd.fd = self->cfd;
1209	fd.events = POLLIN;
1210	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1211	/* Set timeout to inf. secs */
1212	EXPECT_EQ(poll(&fd, 1, -1), 1);
1213	EXPECT_EQ(fd.revents & POLLIN, 1);
1214	EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1215}
1216
1217TEST_F(tls, poll_wait_split)
1218{
1219	struct pollfd fd = { 0, 0, 0 };
1220	char send_mem[20] = {};
1221	char recv_mem[15];
1222
1223	fd.fd = self->cfd;
1224	fd.events = POLLIN;
1225	/* Send 20 bytes */
1226	EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1227		  sizeof(send_mem));
1228	/* Poll with inf. timeout */
1229	EXPECT_EQ(poll(&fd, 1, -1), 1);
1230	EXPECT_EQ(fd.revents & POLLIN, 1);
1231	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1232		  sizeof(recv_mem));
1233
1234	/* Now the remaining 5 bytes of record data are in TLS ULP */
1235	fd.fd = self->cfd;
1236	fd.events = POLLIN;
1237	EXPECT_EQ(poll(&fd, 1, -1), 1);
1238	EXPECT_EQ(fd.revents & POLLIN, 1);
1239	EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1240		  sizeof(send_mem) - sizeof(recv_mem));
1241}
1242
1243TEST_F(tls, blocking)
1244{
1245	size_t data = 100000;
1246	int res = fork();
1247
1248	EXPECT_NE(res, -1);
1249
1250	if (res) {
1251		/* parent */
1252		size_t left = data;
1253		char buf[16384];
1254		int status;
1255		int pid2;
1256
1257		while (left) {
1258			int res = send(self->fd, buf,
1259				       left > 16384 ? 16384 : left, 0);
1260
1261			EXPECT_GE(res, 0);
1262			left -= res;
1263		}
1264
1265		pid2 = wait(&status);
1266		EXPECT_EQ(status, 0);
1267		EXPECT_EQ(res, pid2);
1268	} else {
1269		/* child */
1270		size_t left = data;
1271		char buf[16384];
1272
1273		while (left) {
1274			int res = recv(self->cfd, buf,
1275				       left > 16384 ? 16384 : left, 0);
1276
1277			EXPECT_GE(res, 0);
1278			left -= res;
1279		}
1280	}
1281}
1282
1283TEST_F(tls, nonblocking)
1284{
1285	size_t data = 100000;
1286	int sendbuf = 100;
1287	int flags;
1288	int res;
1289
1290	flags = fcntl(self->fd, F_GETFL, 0);
1291	fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1292	fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1293
1294	/* Ensure nonblocking behavior by imposing a small send
1295	 * buffer.
1296	 */
1297	EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1298			     &sendbuf, sizeof(sendbuf)), 0);
1299
1300	res = fork();
1301	EXPECT_NE(res, -1);
1302
1303	if (res) {
1304		/* parent */
1305		bool eagain = false;
1306		size_t left = data;
1307		char buf[16384];
1308		int status;
1309		int pid2;
1310
1311		while (left) {
1312			int res = send(self->fd, buf,
1313				       left > 16384 ? 16384 : left, 0);
1314
1315			if (res == -1 && errno == EAGAIN) {
1316				eagain = true;
1317				usleep(10000);
1318				continue;
1319			}
1320			EXPECT_GE(res, 0);
1321			left -= res;
1322		}
1323
1324		EXPECT_TRUE(eagain);
1325		pid2 = wait(&status);
1326
1327		EXPECT_EQ(status, 0);
1328		EXPECT_EQ(res, pid2);
1329	} else {
1330		/* child */
1331		bool eagain = false;
1332		size_t left = data;
1333		char buf[16384];
1334
1335		while (left) {
1336			int res = recv(self->cfd, buf,
1337				       left > 16384 ? 16384 : left, 0);
1338
1339			if (res == -1 && errno == EAGAIN) {
1340				eagain = true;
1341				usleep(10000);
1342				continue;
1343			}
1344			EXPECT_GE(res, 0);
1345			left -= res;
1346		}
1347		EXPECT_TRUE(eagain);
1348	}
1349}
1350
1351static void
1352test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1353	       bool sendpg, unsigned int n_readers, unsigned int n_writers)
1354{
1355	const unsigned int n_children = n_readers + n_writers;
1356	const size_t data = 6 * 1000 * 1000;
1357	const size_t file_sz = data / 100;
1358	size_t read_bias, write_bias;
1359	int i, fd, child_id;
1360	char buf[file_sz];
1361	pid_t pid;
1362
1363	/* Only allow multiples for simplicity */
1364	ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1365	read_bias = n_writers / n_readers ?: 1;
1366	write_bias = n_readers / n_writers ?: 1;
1367
1368	/* prep a file to send */
1369	fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1370	ASSERT_GE(fd, 0);
1371
1372	memset(buf, 0xac, file_sz);
1373	ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1374
1375	/* spawn children */
1376	for (child_id = 0; child_id < n_children; child_id++) {
1377		pid = fork();
1378		ASSERT_NE(pid, -1);
1379		if (!pid)
1380			break;
1381	}
1382
1383	/* parent waits for all children */
1384	if (pid) {
1385		for (i = 0; i < n_children; i++) {
1386			int status;
1387
1388			wait(&status);
1389			EXPECT_EQ(status, 0);
1390		}
1391
1392		return;
1393	}
1394
1395	/* Split threads for reading and writing */
1396	if (child_id < n_readers) {
1397		size_t left = data * read_bias;
1398		char rb[8001];
1399
1400		while (left) {
1401			int res;
1402
1403			res = recv(self->cfd, rb,
1404				   left > sizeof(rb) ? sizeof(rb) : left, 0);
1405
1406			EXPECT_GE(res, 0);
1407			left -= res;
1408		}
1409	} else {
1410		size_t left = data * write_bias;
1411
1412		while (left) {
1413			int res;
1414
1415			ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1416			if (sendpg)
1417				res = sendfile(self->fd, fd, NULL,
1418					       left > file_sz ? file_sz : left);
1419			else
1420				res = send(self->fd, buf,
1421					   left > file_sz ? file_sz : left, 0);
1422
1423			EXPECT_GE(res, 0);
1424			left -= res;
1425		}
1426	}
1427}
1428
/*
 * Multi-process stress wrappers around test_mutliproc().
 * Arguments: (sendpg, n_readers, n_writers). sendpg = false uses send(),
 * true uses sendfile().
 */

/* send(): 6 readers, 6 writers. */
TEST_F(tls, mutliproc_even)
{
	test_mutliproc(_metadata, self, false, 6, 6);
}

/* send(): 4 readers, 12 writers. */
TEST_F(tls, mutliproc_readers)
{
	test_mutliproc(_metadata, self, false, 4, 12);
}

/* send(): 10 readers, 2 writers. */
TEST_F(tls, mutliproc_writers)
{
	test_mutliproc(_metadata, self, false, 10, 2);
}

/* sendfile(): 6 readers, 6 writers. */
TEST_F(tls, mutliproc_sendpage_even)
{
	test_mutliproc(_metadata, self, true, 6, 6);
}

/* sendfile(): 4 readers, 12 writers. */
TEST_F(tls, mutliproc_sendpage_readers)
{
	test_mutliproc(_metadata, self, true, 4, 12);
}

/* sendfile(): 10 readers, 2 writers. */
TEST_F(tls, mutliproc_sendpage_writers)
{
	test_mutliproc(_metadata, self, true, 10, 2);
}
1458
1459TEST_F(tls, control_msg)
1460{
1461	char *test_str = "test_read";
1462	char record_type = 100;
1463	int send_len = 10;
1464	char buf[10];
1465
1466	if (self->notls)
1467		SKIP(return, "no TLS support");
1468
1469	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1470		  send_len);
1471	/* Should fail because we didn't provide a control message */
1472	EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1473
1474	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1475				buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1476		  send_len);
1477	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1478
1479	/* Recv the message again without MSG_PEEK */
1480	memset(buf, 0, sizeof(buf));
1481
1482	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1483				buf, sizeof(buf), MSG_WAITALL),
1484		  send_len);
1485	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1486}
1487
1488TEST_F(tls, control_msg_nomerge)
1489{
1490	char *rec1 = "1111";
1491	char *rec2 = "2222";
1492	int send_len = 5;
1493	char buf[15];
1494
1495	if (self->notls)
1496		SKIP(return, "no TLS support");
1497
1498	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec1, send_len, 0), send_len);
1499	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1500
1501	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1502	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1503
1504	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), MSG_PEEK), send_len);
1505	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1506
1507	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1508	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
1509
1510	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, 100, buf, sizeof(buf), 0), send_len);
1511	EXPECT_EQ(memcmp(buf, rec2, send_len), 0);
1512}
1513
1514TEST_F(tls, data_control_data)
1515{
1516	char *rec1 = "1111";
1517	char *rec2 = "2222";
1518	char *rec3 = "3333";
1519	int send_len = 5;
1520	char buf[15];
1521
1522	if (self->notls)
1523		SKIP(return, "no TLS support");
1524
1525	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
1526	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
1527	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);
1528
1529	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1530	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
1531}
1532
1533TEST_F(tls, shutdown)
1534{
1535	char const *test_str = "test_read";
1536	int send_len = 10;
1537	char buf[10];
1538
1539	ASSERT_EQ(strlen(test_str) + 1, send_len);
1540
1541	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1542	EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1543	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1544
1545	shutdown(self->fd, SHUT_RDWR);
1546	shutdown(self->cfd, SHUT_RDWR);
1547}
1548
1549TEST_F(tls, shutdown_unsent)
1550{
1551	char const *test_str = "test_read";
1552	int send_len = 10;
1553
1554	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1555
1556	shutdown(self->fd, SHUT_RDWR);
1557	shutdown(self->cfd, SHUT_RDWR);
1558}
1559
1560TEST_F(tls, shutdown_reuse)
1561{
1562	struct sockaddr_in addr;
1563	int ret;
1564
1565	shutdown(self->fd, SHUT_RDWR);
1566	shutdown(self->cfd, SHUT_RDWR);
1567	close(self->cfd);
1568
1569	addr.sin_family = AF_INET;
1570	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1571	addr.sin_port = 0;
1572
1573	ret = bind(self->fd, &addr, sizeof(addr));
1574	EXPECT_EQ(ret, 0);
1575	ret = listen(self->fd, 10);
1576	EXPECT_EQ(ret, -1);
1577	EXPECT_EQ(errno, EINVAL);
1578
1579	ret = connect(self->fd, &addr, sizeof(addr));
1580	EXPECT_EQ(ret, -1);
1581	EXPECT_EQ(errno, EISCONN);
1582}
1583
1584TEST_F(tls, getsockopt)
1585{
1586	struct tls_crypto_info_keys expect, get;
1587	socklen_t len;
1588
1589	/* get only the version/cipher */
1590	len = sizeof(struct tls_crypto_info);
1591	memrnd(&get, sizeof(get));
1592	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1593	EXPECT_EQ(len, sizeof(struct tls_crypto_info));
1594	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1595	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1596
1597	/* get the full crypto_info */
1598	tls_crypto_info_init(variant->tls_version, variant->cipher_type, &expect);
1599	len = expect.len;
1600	memrnd(&get, sizeof(get));
1601	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), 0);
1602	EXPECT_EQ(len, expect.len);
1603	EXPECT_EQ(get.crypto_info.version, variant->tls_version);
1604	EXPECT_EQ(get.crypto_info.cipher_type, variant->cipher_type);
1605	EXPECT_EQ(memcmp(&get, &expect, expect.len), 0);
1606
1607	/* short get should fail */
1608	len = sizeof(struct tls_crypto_info) - 1;
1609	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1610	EXPECT_EQ(errno, EINVAL);
1611
1612	/* partial get of the cipher data should fail */
1613	len = expect.len - 1;
1614	EXPECT_EQ(getsockopt(self->fd, SOL_TLS, TLS_TX, &get, &len), -1);
1615	EXPECT_EQ(errno, EINVAL);
1616}
1617
/*
 * Two 10-byte records are queued and one byte is consumed first. Then
 * recvmsg() is called with a 12-byte iovec followed by a NULL iovec
 * entry. The call must return between 9 and 12 bytes: the 9 remaining
 * bytes of rec1, possibly followed by a prefix of rec2. Any rec2 bytes
 * returned must be correct.
 */
TEST_F(tls, recv_efault)
{
	char *rec1 = "1111111111";
	char *rec2 = "2222222222";
	struct msghdr hdr = {};
	struct iovec iov[2];
	char recv_mem[12];
	int ret;

	if (self->notls)
		SKIP(return, "no TLS support");

	EXPECT_EQ(send(self->fd, rec1, 10, 0), 10);
	EXPECT_EQ(send(self->fd, rec2, 10, 0), 10);

	iov[0].iov_base = recv_mem;
	iov[0].iov_len = sizeof(recv_mem);
	iov[1].iov_base = NULL; /* broken iov to make process_rx_list fail */
	iov[1].iov_len = 1;

	hdr.msg_iovlen = 2;
	hdr.msg_iov = iov;

	/* Consume one byte so the next read starts mid-record. */
	EXPECT_EQ(recv(self->cfd, recv_mem, 1, 0), 1);
	EXPECT_EQ(recv_mem[0], rec1[0]);

	/* ret == -1 also fails the first check (int vs size_t compare). */
	ret = recvmsg(self->cfd, &hdr, 0);
	EXPECT_LE(ret, sizeof(recv_mem));
	EXPECT_GE(ret, 9);
	EXPECT_EQ(memcmp(rec1, recv_mem, 9), 0);
	if (ret > 9)
		EXPECT_EQ(memcmp(rec2, recv_mem + 9, ret - 9), 0);
}
1651
/*
 * Fixture for TLS error-path tests: two socket pairs, fd/cfd and
 * fd2/cfd2. notls is filled in by ulp_sock_pair(); tests skip when it is
 * set.
 */
FIXTURE(tls_err)
{
	int fd, cfd;
	int fd2, cfd2;
	bool notls;
};

/* Only the TLS version varies; the setup always uses AES-GCM-128. */
FIXTURE_VARIANT(tls_err)
{
	uint16_t tls_version;
};

FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
{
	.tls_version = TLS_1_2_VERSION,
};

FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
{
	.tls_version = TLS_1_3_VERSION,
};
1673
/*
 * Create both socket pairs and install AES-GCM-128 keys for the variant's
 * TLS version: TLS_TX on fd and TLS_RX on cfd2. No keys are installed on
 * cfd or fd2. Tests use this to read ciphertext from cfd, optionally
 * corrupt it, and replay it through fd2 into the decrypting cfd2.
 */
FIXTURE_SETUP(tls_err)
{
	struct tls_crypto_info_keys tls12;
	int ret;

	tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
			     &tls12);

	ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
	ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
	/* Sockets stay open (teardown closes them); tests SKIP on notls. */
	if (self->notls)
		return;

	ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);

	ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
	ASSERT_EQ(ret, 0);
}
1693
1694FIXTURE_TEARDOWN(tls_err)
1695{
1696	close(self->fd);
1697	close(self->cfd);
1698	close(self->fd2);
1699	close(self->cfd2);
1700}
1701
/*
 * Write 64 bytes of 0x55 directly into the TLS receiver (cfd2) as if they
 * were a record. recv() must fail with EMSGSIZE. A following non-blocking
 * recv() must report EAGAIN (no further data is queued).
 * NOTE(review): presumably EMSGSIZE comes from the bogus header length
 * field (0x5555) exceeding the maximum record size - confirm.
 */
TEST_F(tls_err, bad_rec)
{
	char buf[64];

	if (self->notls)
		SKIP(return, "no TLS support");

	memset(buf, 0x55, sizeof(buf));
	EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
	EXPECT_EQ(errno, EMSGSIZE);
	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
	EXPECT_EQ(errno, EAGAIN);
}
1716
1717TEST_F(tls_err, bad_auth)
1718{
1719	char buf[128];
1720	int n;
1721
1722	if (self->notls)
1723		SKIP(return, "no TLS support");
1724
1725	memrnd(buf, sizeof(buf) / 2);
1726	EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
1727	n = recv(self->cfd, buf, sizeof(buf), 0);
1728	EXPECT_GT(n, sizeof(buf) / 2);
1729
1730	buf[n - 1]++;
1731
1732	EXPECT_EQ(send(self->fd2, buf, n, 0), n);
1733	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1734	EXPECT_EQ(errno, EBADMSG);
1735	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1736	EXPECT_EQ(errno, EBADMSG);
1737}
1738
1739TEST_F(tls_err, bad_in_large_read)
1740{
1741	char txt[3][64];
1742	char cip[3][128];
1743	char buf[3 * 128];
1744	int i, n;
1745
1746	if (self->notls)
1747		SKIP(return, "no TLS support");
1748
1749	/* Put 3 records in the sockets */
1750	for (i = 0; i < 3; i++) {
1751		memrnd(txt[i], sizeof(txt[i]));
1752		EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
1753			  sizeof(txt[i]));
1754		n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
1755		EXPECT_GT(n, sizeof(txt[i]));
1756		/* Break the third message */
1757		if (i == 2)
1758			cip[2][n - 1]++;
1759		EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
1760	}
1761
1762	/* We should be able to receive the first two messages */
1763	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
1764	EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
1765	EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
1766	/* Third mesasge is bad */
1767	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1768	EXPECT_EQ(errno, EBADMSG);
1769	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1770	EXPECT_EQ(errno, EBADMSG);
1771}
1772
1773TEST_F(tls_err, bad_cmsg)
1774{
1775	char *test_str = "test_read";
1776	int send_len = 10;
1777	char cip[128];
1778	char buf[128];
1779	char txt[64];
1780	int n;
1781
1782	if (self->notls)
1783		SKIP(return, "no TLS support");
1784
1785	/* Queue up one data record */
1786	memrnd(txt, sizeof(txt));
1787	EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
1788	n = recv(self->cfd, cip, sizeof(cip), 0);
1789	EXPECT_GT(n, sizeof(txt));
1790	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1791
1792	EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
1793	n = recv(self->cfd, cip, sizeof(cip), 0);
1794	cip[n - 1]++; /* Break it */
1795	EXPECT_GT(n, send_len);
1796	EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1797
1798	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
1799	EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
1800	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1801	EXPECT_EQ(errno, EBADMSG);
1802	EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1803	EXPECT_EQ(errno, EBADMSG);
1804}
1805
1806TEST_F(tls_err, timeo)
1807{
1808	struct timeval tv = { .tv_usec = 10000, };
1809	char buf[128];
1810	int ret;
1811
1812	if (self->notls)
1813		SKIP(return, "no TLS support");
1814
1815	ret = setsockopt(self->cfd2, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
1816	ASSERT_EQ(ret, 0);
1817
1818	ret = fork();
1819	ASSERT_GE(ret, 0);
1820
1821	if (ret) {
1822		usleep(1000); /* Give child a head start */
1823
1824		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1825		EXPECT_EQ(errno, EAGAIN);
1826
1827		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1828		EXPECT_EQ(errno, EAGAIN);
1829
1830		wait(&ret);
1831	} else {
1832		EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1833		EXPECT_EQ(errno, EAGAIN);
1834		exit(0);
1835	}
1836}
1837
1838TEST_F(tls_err, poll_partial_rec)
1839{
1840	struct pollfd pfd = { };
1841	ssize_t rec_len;
1842	char rec[256];
1843	char buf[128];
1844
1845	if (self->notls)
1846		SKIP(return, "no TLS support");
1847
1848	pfd.fd = self->cfd2;
1849	pfd.events = POLLIN;
1850	EXPECT_EQ(poll(&pfd, 1, 1), 0);
1851
1852	memrnd(buf, sizeof(buf));
1853	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1854	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1855	EXPECT_GT(rec_len, sizeof(buf));
1856
1857	/* Write 100B, not the full record ... */
1858	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1859	/* ... no full record should mean no POLLIN */
1860	pfd.fd = self->cfd2;
1861	pfd.events = POLLIN;
1862	EXPECT_EQ(poll(&pfd, 1, 1), 0);
1863	/* Now write the rest, and it should all pop out of the other end. */
1864	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
1865	pfd.fd = self->cfd2;
1866	pfd.events = POLLIN;
1867	EXPECT_EQ(poll(&pfd, 1, 1), 1);
1868	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
1869	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
1870}
1871
1872TEST_F(tls_err, epoll_partial_rec)
1873{
1874	struct epoll_event ev, events[10];
1875	ssize_t rec_len;
1876	char rec[256];
1877	char buf[128];
1878	int epollfd;
1879
1880	if (self->notls)
1881		SKIP(return, "no TLS support");
1882
1883	epollfd = epoll_create1(0);
1884	ASSERT_GE(epollfd, 0);
1885
1886	memset(&ev, 0, sizeof(ev));
1887	ev.events = EPOLLIN;
1888	ev.data.fd = self->cfd2;
1889	ASSERT_GE(epoll_ctl(epollfd, EPOLL_CTL_ADD, self->cfd2, &ev), 0);
1890
1891	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
1892
1893	memrnd(buf, sizeof(buf));
1894	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1895	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1896	EXPECT_GT(rec_len, sizeof(buf));
1897
1898	/* Write 100B, not the full record ... */
1899	EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1900	/* ... no full record should mean no POLLIN */
1901	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 0);
1902	/* Now write the rest, and it should all pop out of the other end. */
1903	EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0), rec_len - 100);
1904	EXPECT_EQ(epoll_wait(epollfd, events, 10, 0), 1);
1905	EXPECT_EQ(recv(self->cfd2, rec, sizeof(rec), 0), sizeof(buf));
1906	EXPECT_EQ(memcmp(buf, rec, sizeof(buf)), 0);
1907
1908	close(epollfd);
1909}
1910
1911TEST_F(tls_err, poll_partial_rec_async)
1912{
1913	struct pollfd pfd = { };
1914	ssize_t rec_len;
1915	char rec[256];
1916	char buf[128];
1917	char token;
1918	int p[2];
1919	int ret;
1920
1921	if (self->notls)
1922		SKIP(return, "no TLS support");
1923
1924	ASSERT_GE(pipe(p), 0);
1925
1926	memrnd(buf, sizeof(buf));
1927	EXPECT_EQ(send(self->fd, buf, sizeof(buf), 0), sizeof(buf));
1928	rec_len = recv(self->cfd, rec, sizeof(rec), 0);
1929	EXPECT_GT(rec_len, sizeof(buf));
1930
1931	ret = fork();
1932	ASSERT_GE(ret, 0);
1933
1934	if (ret) {
1935		int status, pid2;
1936
1937		close(p[1]);
1938		usleep(1000); /* Give child a head start */
1939
1940		EXPECT_EQ(send(self->fd2, rec, 100, 0), 100);
1941
1942		EXPECT_EQ(read(p[0], &token, 1), 1); /* Barrier #1 */
1943
1944		EXPECT_EQ(send(self->fd2, rec + 100, rec_len - 100, 0),
1945			  rec_len - 100);
1946
1947		pid2 = wait(&status);
1948		EXPECT_EQ(pid2, ret);
1949		EXPECT_EQ(status, 0);
1950	} else {
1951		close(p[0]);
1952
1953		/* Child should sleep in poll(), never get a wake */
1954		pfd.fd = self->cfd2;
1955		pfd.events = POLLIN;
1956		EXPECT_EQ(poll(&pfd, 1, 20), 0);
1957
1958		EXPECT_EQ(write(p[1], &token, 1), 1); /* Barrier #1 */
1959
1960		pfd.fd = self->cfd2;
1961		pfd.events = POLLIN;
1962		EXPECT_EQ(poll(&pfd, 1, 20), 1);
1963
1964		exit(!__test_passed(_metadata));
1965	}
1966}
1967
1968TEST(non_established) {
1969	struct tls12_crypto_info_aes_gcm_256 tls12;
1970	struct sockaddr_in addr;
1971	int sfd, ret, fd;
1972	socklen_t len;
1973
1974	len = sizeof(addr);
1975
1976	memset(&tls12, 0, sizeof(tls12));
1977	tls12.info.version = TLS_1_2_VERSION;
1978	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1979
1980	addr.sin_family = AF_INET;
1981	addr.sin_addr.s_addr = htonl(INADDR_ANY);
1982	addr.sin_port = 0;
1983
1984	fd = socket(AF_INET, SOCK_STREAM, 0);
1985	sfd = socket(AF_INET, SOCK_STREAM, 0);
1986
1987	ret = bind(sfd, &addr, sizeof(addr));
1988	ASSERT_EQ(ret, 0);
1989	ret = listen(sfd, 10);
1990	ASSERT_EQ(ret, 0);
1991
1992	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1993	EXPECT_EQ(ret, -1);
1994	/* TLS ULP not supported */
1995	if (errno == ENOENT)
1996		return;
1997	EXPECT_EQ(errno, ENOTCONN);
1998
1999	ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2000	EXPECT_EQ(ret, -1);
2001	EXPECT_EQ(errno, ENOTCONN);
2002
2003	ret = getsockname(sfd, &addr, &len);
2004	ASSERT_EQ(ret, 0);
2005
2006	ret = connect(fd, &addr, sizeof(addr));
2007	ASSERT_EQ(ret, 0);
2008
2009	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2010	ASSERT_EQ(ret, 0);
2011
2012	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2013	EXPECT_EQ(ret, -1);
2014	EXPECT_EQ(errno, EEXIST);
2015
2016	close(fd);
2017	close(sfd);
2018}
2019
2020TEST(keysizes) {
2021	struct tls12_crypto_info_aes_gcm_256 tls12;
2022	int ret, fd, cfd;
2023	bool notls;
2024
2025	memset(&tls12, 0, sizeof(tls12));
2026	tls12.info.version = TLS_1_2_VERSION;
2027	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2028
2029	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2030
2031	if (!notls) {
2032		ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
2033				 sizeof(tls12));
2034		EXPECT_EQ(ret, 0);
2035
2036		ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
2037				 sizeof(tls12));
2038		EXPECT_EQ(ret, 0);
2039	}
2040
2041	close(fd);
2042	close(cfd);
2043}
2044
2045TEST(no_pad) {
2046	struct tls12_crypto_info_aes_gcm_256 tls12;
2047	int ret, fd, cfd, val;
2048	socklen_t len;
2049	bool notls;
2050
2051	memset(&tls12, 0, sizeof(tls12));
2052	tls12.info.version = TLS_1_3_VERSION;
2053	tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
2054
2055	ulp_sock_pair(_metadata, &fd, &cfd, &notls);
2056
2057	if (notls)
2058		exit(KSFT_SKIP);
2059
2060	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
2061	EXPECT_EQ(ret, 0);
2062
2063	ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
2064	EXPECT_EQ(ret, 0);
2065
2066	val = 1;
2067	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2068			 (void *)&val, sizeof(val));
2069	EXPECT_EQ(ret, 0);
2070
2071	len = sizeof(val);
2072	val = 2;
2073	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2074			 (void *)&val, &len);
2075	EXPECT_EQ(ret, 0);
2076	EXPECT_EQ(val, 1);
2077	EXPECT_EQ(len, 4);
2078
2079	val = 0;
2080	ret = setsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2081			 (void *)&val, sizeof(val));
2082	EXPECT_EQ(ret, 0);
2083
2084	len = sizeof(val);
2085	val = 2;
2086	ret = getsockopt(cfd, SOL_TLS, TLS_RX_EXPECT_NO_PAD,
2087			 (void *)&val, &len);
2088	EXPECT_EQ(ret, 0);
2089	EXPECT_EQ(val, 0);
2090	EXPECT_EQ(len, 4);
2091
2092	close(fd);
2093	close(cfd);
2094}
2095
2096TEST(tls_v6ops) {
2097	struct tls_crypto_info_keys tls12;
2098	struct sockaddr_in6 addr, addr2;
2099	int sfd, ret, fd;
2100	socklen_t len, len2;
2101
2102	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
2103
2104	addr.sin6_family = AF_INET6;
2105	addr.sin6_addr = in6addr_any;
2106	addr.sin6_port = 0;
2107
2108	fd = socket(AF_INET6, SOCK_STREAM, 0);
2109	sfd = socket(AF_INET6, SOCK_STREAM, 0);
2110
2111	ret = bind(sfd, &addr, sizeof(addr));
2112	ASSERT_EQ(ret, 0);
2113	ret = listen(sfd, 10);
2114	ASSERT_EQ(ret, 0);
2115
2116	len = sizeof(addr);
2117	ret = getsockname(sfd, &addr, &len);
2118	ASSERT_EQ(ret, 0);
2119
2120	ret = connect(fd, &addr, sizeof(addr));
2121	ASSERT_EQ(ret, 0);
2122
2123	len = sizeof(addr);
2124	ret = getsockname(fd, &addr, &len);
2125	ASSERT_EQ(ret, 0);
2126
2127	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2128	if (ret) {
2129		ASSERT_EQ(errno, ENOENT);
2130		SKIP(return, "no TLS support");
2131	}
2132	ASSERT_EQ(ret, 0);
2133
2134	ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
2135	ASSERT_EQ(ret, 0);
2136
2137	ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
2138	ASSERT_EQ(ret, 0);
2139
2140	len2 = sizeof(addr2);
2141	ret = getsockname(fd, &addr2, &len2);
2142	ASSERT_EQ(ret, 0);
2143
2144	EXPECT_EQ(len2, len);
2145	EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
2146
2147	close(fd);
2148	close(sfd);
2149}
2150
2151TEST(prequeue) {
2152	struct tls_crypto_info_keys tls12;
2153	char buf[20000], buf2[20000];
2154	struct sockaddr_in addr;
2155	int sfd, cfd, ret, fd;
2156	socklen_t len;
2157
2158	len = sizeof(addr);
2159	memrnd(buf, sizeof(buf));
2160
2161	tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_256, &tls12);
2162
2163	addr.sin_family = AF_INET;
2164	addr.sin_addr.s_addr = htonl(INADDR_ANY);
2165	addr.sin_port = 0;
2166
2167	fd = socket(AF_INET, SOCK_STREAM, 0);
2168	sfd = socket(AF_INET, SOCK_STREAM, 0);
2169
2170	ASSERT_EQ(bind(sfd, &addr, sizeof(addr)), 0);
2171	ASSERT_EQ(listen(sfd, 10), 0);
2172	ASSERT_EQ(getsockname(sfd, &addr, &len), 0);
2173	ASSERT_EQ(connect(fd, &addr, sizeof(addr)), 0);
2174	ASSERT_GE(cfd = accept(sfd, &addr, &len), 0);
2175	close(sfd);
2176
2177	ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
2178	if (ret) {
2179		ASSERT_EQ(errno, ENOENT);
2180		SKIP(return, "no TLS support");
2181	}
2182
2183	ASSERT_EQ(setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len), 0);
2184	EXPECT_EQ(send(fd, buf, sizeof(buf), MSG_DONTWAIT), sizeof(buf));
2185
2186	ASSERT_EQ(setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")), 0);
2187	ASSERT_EQ(setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, tls12.len), 0);
2188	EXPECT_EQ(recv(cfd, buf2, sizeof(buf2), MSG_WAITALL), sizeof(buf2));
2189
2190	EXPECT_EQ(memcmp(buf, buf2, sizeof(buf)), 0);
2191
2192	close(fd);
2193	close(cfd);
2194}
2195
2196static void __attribute__((constructor)) fips_check(void) {
2197	int res;
2198	FILE *f;
2199
2200	f = fopen("/proc/sys/crypto/fips_enabled", "r");
2201	if (f) {
2202		res = fscanf(f, "%d", &fips_enabled);
2203		if (res != 1)
2204			ksft_print_msg("ERROR: Couldn't read /proc/sys/crypto/fips_enabled\n");
2205		fclose(f);
2206	}
2207}
2208
/* kselftest harness entry point: expands to main() and runs the tests. */
TEST_HARNESS_MAIN
2210