1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * vsock_test - vsock.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <errno.h>
15#include <unistd.h>
16#include <linux/kernel.h>
17#include <sys/types.h>
18#include <sys/socket.h>
19#include <time.h>
20#include <sys/mman.h>
21#include <poll.h>
22#include <signal.h>
23
24#include "vsock_test_zerocopy.h"
25#include "timeout.h"
26#include "control.h"
27#include "util.h"
28
29static void test_stream_connection_reset(const struct test_opts *opts)
30{
31	union {
32		struct sockaddr sa;
33		struct sockaddr_vm svm;
34	} addr = {
35		.svm = {
36			.svm_family = AF_VSOCK,
37			.svm_port = opts->peer_port,
38			.svm_cid = opts->peer_cid,
39		},
40	};
41	int ret;
42	int fd;
43
44	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
45
46	timeout_begin(TIMEOUT);
47	do {
48		ret = connect(fd, &addr.sa, sizeof(addr.svm));
49		timeout_check("connect");
50	} while (ret < 0 && errno == EINTR);
51	timeout_end();
52
53	if (ret != -1) {
54		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
55		exit(EXIT_FAILURE);
56	}
57	if (errno != ECONNRESET) {
58		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
59		exit(EXIT_FAILURE);
60	}
61
62	close(fd);
63}
64
65static void test_stream_bind_only_client(const struct test_opts *opts)
66{
67	union {
68		struct sockaddr sa;
69		struct sockaddr_vm svm;
70	} addr = {
71		.svm = {
72			.svm_family = AF_VSOCK,
73			.svm_port = opts->peer_port,
74			.svm_cid = opts->peer_cid,
75		},
76	};
77	int ret;
78	int fd;
79
80	/* Wait for the server to be ready */
81	control_expectln("BIND");
82
83	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
84
85	timeout_begin(TIMEOUT);
86	do {
87		ret = connect(fd, &addr.sa, sizeof(addr.svm));
88		timeout_check("connect");
89	} while (ret < 0 && errno == EINTR);
90	timeout_end();
91
92	if (ret != -1) {
93		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
94		exit(EXIT_FAILURE);
95	}
96	if (errno != ECONNRESET) {
97		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
98		exit(EXIT_FAILURE);
99	}
100
101	/* Notify the server that the client has finished */
102	control_writeln("DONE");
103
104	close(fd);
105}
106
107static void test_stream_bind_only_server(const struct test_opts *opts)
108{
109	union {
110		struct sockaddr sa;
111		struct sockaddr_vm svm;
112	} addr = {
113		.svm = {
114			.svm_family = AF_VSOCK,
115			.svm_port = opts->peer_port,
116			.svm_cid = VMADDR_CID_ANY,
117		},
118	};
119	int fd;
120
121	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
122
123	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
124		perror("bind");
125		exit(EXIT_FAILURE);
126	}
127
128	/* Notify the client that the server is ready */
129	control_writeln("BIND");
130
131	/* Wait for the client to finish */
132	control_expectln("DONE");
133
134	close(fd);
135}
136
137static void test_stream_client_close_client(const struct test_opts *opts)
138{
139	int fd;
140
141	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
142	if (fd < 0) {
143		perror("connect");
144		exit(EXIT_FAILURE);
145	}
146
147	send_byte(fd, 1, 0);
148	close(fd);
149}
150
151static void test_stream_client_close_server(const struct test_opts *opts)
152{
153	int fd;
154
155	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
156	if (fd < 0) {
157		perror("accept");
158		exit(EXIT_FAILURE);
159	}
160
161	/* Wait for the remote to close the connection, before check
162	 * -EPIPE error on send.
163	 */
164	vsock_wait_remote_close(fd);
165
166	send_byte(fd, -EPIPE, 0);
167	recv_byte(fd, 1, 0);
168	recv_byte(fd, 0, 0);
169	close(fd);
170}
171
172static void test_stream_server_close_client(const struct test_opts *opts)
173{
174	int fd;
175
176	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
177	if (fd < 0) {
178		perror("connect");
179		exit(EXIT_FAILURE);
180	}
181
182	/* Wait for the remote to close the connection, before check
183	 * -EPIPE error on send.
184	 */
185	vsock_wait_remote_close(fd);
186
187	send_byte(fd, -EPIPE, 0);
188	recv_byte(fd, 1, 0);
189	recv_byte(fd, 0, 0);
190	close(fd);
191}
192
193static void test_stream_server_close_server(const struct test_opts *opts)
194{
195	int fd;
196
197	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
198	if (fd < 0) {
199		perror("accept");
200		exit(EXIT_FAILURE);
201	}
202
203	send_byte(fd, 1, 0);
204	close(fd);
205}
206
207/* With the standard socket sizes, VMCI is able to support about 100
208 * concurrent stream connections.
209 */
210#define MULTICONN_NFDS 100
211
212static void test_stream_multiconn_client(const struct test_opts *opts)
213{
214	int fds[MULTICONN_NFDS];
215	int i;
216
217	for (i = 0; i < MULTICONN_NFDS; i++) {
218		fds[i] = vsock_stream_connect(opts->peer_cid, opts->peer_port);
219		if (fds[i] < 0) {
220			perror("connect");
221			exit(EXIT_FAILURE);
222		}
223	}
224
225	for (i = 0; i < MULTICONN_NFDS; i++) {
226		if (i % 2)
227			recv_byte(fds[i], 1, 0);
228		else
229			send_byte(fds[i], 1, 0);
230	}
231
232	for (i = 0; i < MULTICONN_NFDS; i++)
233		close(fds[i]);
234}
235
236static void test_stream_multiconn_server(const struct test_opts *opts)
237{
238	int fds[MULTICONN_NFDS];
239	int i;
240
241	for (i = 0; i < MULTICONN_NFDS; i++) {
242		fds[i] = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
243		if (fds[i] < 0) {
244			perror("accept");
245			exit(EXIT_FAILURE);
246		}
247	}
248
249	for (i = 0; i < MULTICONN_NFDS; i++) {
250		if (i % 2)
251			send_byte(fds[i], 1, 0);
252		else
253			recv_byte(fds[i], 1, 0);
254	}
255
256	for (i = 0; i < MULTICONN_NFDS; i++)
257		close(fds[i]);
258}
259
260#define MSG_PEEK_BUF_LEN 64
261
262static void test_msg_peek_client(const struct test_opts *opts,
263				 bool seqpacket)
264{
265	unsigned char buf[MSG_PEEK_BUF_LEN];
266	int fd;
267	int i;
268
269	if (seqpacket)
270		fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
271	else
272		fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
273
274	if (fd < 0) {
275		perror("connect");
276		exit(EXIT_FAILURE);
277	}
278
279	for (i = 0; i < sizeof(buf); i++)
280		buf[i] = rand() & 0xFF;
281
282	control_expectln("SRVREADY");
283
284	send_buf(fd, buf, sizeof(buf), 0, sizeof(buf));
285
286	close(fd);
287}
288
289static void test_msg_peek_server(const struct test_opts *opts,
290				 bool seqpacket)
291{
292	unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
293	unsigned char buf_normal[MSG_PEEK_BUF_LEN];
294	unsigned char buf_peek[MSG_PEEK_BUF_LEN];
295	int fd;
296
297	if (seqpacket)
298		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
299	else
300		fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
301
302	if (fd < 0) {
303		perror("accept");
304		exit(EXIT_FAILURE);
305	}
306
307	/* Peek from empty socket. */
308	recv_buf(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT,
309		 -EAGAIN);
310
311	control_writeln("SRVREADY");
312
313	/* Peek part of data. */
314	recv_buf(fd, buf_half, sizeof(buf_half), MSG_PEEK, sizeof(buf_half));
315
316	/* Peek whole data. */
317	recv_buf(fd, buf_peek, sizeof(buf_peek), MSG_PEEK, sizeof(buf_peek));
318
319	/* Compare partial and full peek. */
320	if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
321		fprintf(stderr, "Partial peek data mismatch\n");
322		exit(EXIT_FAILURE);
323	}
324
325	if (seqpacket) {
326		/* This type of socket supports MSG_TRUNC flag,
327		 * so check it with MSG_PEEK. We must get length
328		 * of the message.
329		 */
330		recv_buf(fd, buf_half, sizeof(buf_half), MSG_PEEK | MSG_TRUNC,
331			 sizeof(buf_peek));
332	}
333
334	recv_buf(fd, buf_normal, sizeof(buf_normal), 0, sizeof(buf_normal));
335
336	/* Compare full peek and normal read. */
337	if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
338		fprintf(stderr, "Full peek data mismatch\n");
339		exit(EXIT_FAILURE);
340	}
341
342	close(fd);
343}
344
/* SOCK_STREAM variant of the MSG_PEEK test (client side). */
static void test_stream_msg_peek_client(const struct test_opts *opts)
{
	/* Plain call: ISO C forbids 'return <expr>;' in a void function. */
	test_msg_peek_client(opts, false);
}
349
/* SOCK_STREAM variant of the MSG_PEEK test (server side). */
static void test_stream_msg_peek_server(const struct test_opts *opts)
{
	/* Plain call: ISO C forbids 'return <expr>;' in a void function. */
	test_msg_peek_server(opts, false);
}
354
355#define SOCK_BUF_SIZE (2 * 1024 * 1024)
356#define MAX_MSG_PAGES 4
357
358static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
359{
360	unsigned long curr_hash;
361	size_t max_msg_size;
362	int page_size;
363	int msg_count;
364	int fd;
365
366	fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
367	if (fd < 0) {
368		perror("connect");
369		exit(EXIT_FAILURE);
370	}
371
372	/* Wait, until receiver sets buffer size. */
373	control_expectln("SRVREADY");
374
375	curr_hash = 0;
376	page_size = getpagesize();
377	max_msg_size = MAX_MSG_PAGES * page_size;
378	msg_count = SOCK_BUF_SIZE / max_msg_size;
379
380	for (int i = 0; i < msg_count; i++) {
381		size_t buf_size;
382		int flags;
383		void *buf;
384
385		/* Use "small" buffers and "big" buffers. */
386		if (i & 1)
387			buf_size = page_size +
388					(rand() % (max_msg_size - page_size));
389		else
390			buf_size = 1 + (rand() % page_size);
391
392		buf = malloc(buf_size);
393
394		if (!buf) {
395			perror("malloc");
396			exit(EXIT_FAILURE);
397		}
398
399		memset(buf, rand() & 0xff, buf_size);
400		/* Set at least one MSG_EOR + some random. */
401		if (i == (msg_count / 2) || (rand() & 1)) {
402			flags = MSG_EOR;
403			curr_hash++;
404		} else {
405			flags = 0;
406		}
407
408		send_buf(fd, buf, buf_size, flags, buf_size);
409
410		/*
411		 * Hash sum is computed at both client and server in
412		 * the same way:
413		 * H += hash('message data')
414		 * Such hash "controls" both data integrity and message
415		 * bounds. After data exchange, both sums are compared
416		 * using control socket, and if message bounds wasn't
417		 * broken - two values must be equal.
418		 */
419		curr_hash += hash_djb2(buf, buf_size);
420		free(buf);
421	}
422
423	control_writeln("SENDDONE");
424	control_writeulong(curr_hash);
425	close(fd);
426}
427
428static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
429{
430	unsigned long sock_buf_size;
431	unsigned long remote_hash;
432	unsigned long curr_hash;
433	int fd;
434	struct msghdr msg = {0};
435	struct iovec iov = {0};
436
437	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
438	if (fd < 0) {
439		perror("accept");
440		exit(EXIT_FAILURE);
441	}
442
443	sock_buf_size = SOCK_BUF_SIZE;
444
445	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
446		       &sock_buf_size, sizeof(sock_buf_size))) {
447		perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
448		exit(EXIT_FAILURE);
449	}
450
451	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
452		       &sock_buf_size, sizeof(sock_buf_size))) {
453		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
454		exit(EXIT_FAILURE);
455	}
456
457	/* Ready to receive data. */
458	control_writeln("SRVREADY");
459	/* Wait, until peer sends whole data. */
460	control_expectln("SENDDONE");
461	iov.iov_len = MAX_MSG_PAGES * getpagesize();
462	iov.iov_base = malloc(iov.iov_len);
463	if (!iov.iov_base) {
464		perror("malloc");
465		exit(EXIT_FAILURE);
466	}
467
468	msg.msg_iov = &iov;
469	msg.msg_iovlen = 1;
470
471	curr_hash = 0;
472
473	while (1) {
474		ssize_t recv_size;
475
476		recv_size = recvmsg(fd, &msg, 0);
477
478		if (!recv_size)
479			break;
480
481		if (recv_size < 0) {
482			perror("recvmsg");
483			exit(EXIT_FAILURE);
484		}
485
486		if (msg.msg_flags & MSG_EOR)
487			curr_hash++;
488
489		curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
490	}
491
492	free(iov.iov_base);
493	close(fd);
494	remote_hash = control_readulong();
495
496	if (curr_hash != remote_hash) {
497		fprintf(stderr, "Message bounds broken\n");
498		exit(EXIT_FAILURE);
499	}
500}
501
502#define MESSAGE_TRUNC_SZ 32
503static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
504{
505	int fd;
506	char buf[MESSAGE_TRUNC_SZ];
507
508	fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
509	if (fd < 0) {
510		perror("connect");
511		exit(EXIT_FAILURE);
512	}
513
514	send_buf(fd, buf, sizeof(buf), 0, sizeof(buf));
515
516	control_writeln("SENDDONE");
517	close(fd);
518}
519
520static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
521{
522	int fd;
523	char buf[MESSAGE_TRUNC_SZ / 2];
524	struct msghdr msg = {0};
525	struct iovec iov = {0};
526
527	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
528	if (fd < 0) {
529		perror("accept");
530		exit(EXIT_FAILURE);
531	}
532
533	control_expectln("SENDDONE");
534	iov.iov_base = buf;
535	iov.iov_len = sizeof(buf);
536	msg.msg_iov = &iov;
537	msg.msg_iovlen = 1;
538
539	ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
540
541	if (ret != MESSAGE_TRUNC_SZ) {
542		printf("%zi\n", ret);
543		perror("MSG_TRUNC doesn't work");
544		exit(EXIT_FAILURE);
545	}
546
547	if (!(msg.msg_flags & MSG_TRUNC)) {
548		fprintf(stderr, "MSG_TRUNC expected\n");
549		exit(EXIT_FAILURE);
550	}
551
552	close(fd);
553}
554
/* Return a timestamp in nanoseconds, suitable for measuring elapsed time.
 *
 * CLOCK_MONOTONIC is used because it is not affected by wall-clock
 * adjustments (NTP, settimeofday). With CLOCK_REALTIME, such adjustments
 * could make a measured interval jump or go negative. Exits the process if
 * the clock cannot be read.
 */
static time_t current_nsec(void)
{
	struct timespec ts;

	if (clock_gettime(CLOCK_MONOTONIC, &ts)) {
		perror("clock_gettime(3) failed");
		exit(EXIT_FAILURE);
	}

	/* Computed in unsigned long long to avoid signed overflow. */
	return (ts.tv_sec * 1000000000ULL) + ts.tv_nsec;
}
566
567#define RCVTIMEO_TIMEOUT_SEC 1
568#define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
569
570static void test_seqpacket_timeout_client(const struct test_opts *opts)
571{
572	int fd;
573	struct timeval tv;
574	char dummy;
575	time_t read_enter_ns;
576	time_t read_overhead_ns;
577
578	fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
579	if (fd < 0) {
580		perror("connect");
581		exit(EXIT_FAILURE);
582	}
583
584	tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
585	tv.tv_usec = 0;
586
587	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
588		perror("setsockopt(SO_RCVTIMEO)");
589		exit(EXIT_FAILURE);
590	}
591
592	read_enter_ns = current_nsec();
593
594	if (read(fd, &dummy, sizeof(dummy)) != -1) {
595		fprintf(stderr,
596			"expected 'dummy' read(2) failure\n");
597		exit(EXIT_FAILURE);
598	}
599
600	if (errno != EAGAIN) {
601		perror("EAGAIN expected");
602		exit(EXIT_FAILURE);
603	}
604
605	read_overhead_ns = current_nsec() - read_enter_ns -
606			1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
607
608	if (read_overhead_ns > READ_OVERHEAD_NSEC) {
609		fprintf(stderr,
610			"too much time in read(2), %lu > %i ns\n",
611			read_overhead_ns, READ_OVERHEAD_NSEC);
612		exit(EXIT_FAILURE);
613	}
614
615	control_writeln("WAITDONE");
616	close(fd);
617}
618
619static void test_seqpacket_timeout_server(const struct test_opts *opts)
620{
621	int fd;
622
623	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
624	if (fd < 0) {
625		perror("accept");
626		exit(EXIT_FAILURE);
627	}
628
629	control_expectln("WAITDONE");
630	close(fd);
631}
632
633static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
634{
635	unsigned long sock_buf_size;
636	socklen_t len;
637	void *data;
638	int fd;
639
640	len = sizeof(sock_buf_size);
641
642	fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
643	if (fd < 0) {
644		perror("connect");
645		exit(EXIT_FAILURE);
646	}
647
648	if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
649		       &sock_buf_size, &len)) {
650		perror("getsockopt");
651		exit(EXIT_FAILURE);
652	}
653
654	sock_buf_size++;
655
656	data = malloc(sock_buf_size);
657	if (!data) {
658		perror("malloc");
659		exit(EXIT_FAILURE);
660	}
661
662	send_buf(fd, data, sock_buf_size, 0, -EMSGSIZE);
663
664	control_writeln("CLISENT");
665
666	free(data);
667	close(fd);
668}
669
670static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
671{
672	int fd;
673
674	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
675	if (fd < 0) {
676		perror("accept");
677		exit(EXIT_FAILURE);
678	}
679
680	control_expectln("CLISENT");
681
682	close(fd);
683}
684
685#define BUF_PATTERN_1 'a'
686#define BUF_PATTERN_2 'b'
687
688static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
689{
690	int fd;
691	unsigned char *buf1;
692	unsigned char *buf2;
693	int buf_size = getpagesize() * 3;
694
695	fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
696	if (fd < 0) {
697		perror("connect");
698		exit(EXIT_FAILURE);
699	}
700
701	buf1 = malloc(buf_size);
702	if (!buf1) {
703		perror("'malloc()' for 'buf1'");
704		exit(EXIT_FAILURE);
705	}
706
707	buf2 = malloc(buf_size);
708	if (!buf2) {
709		perror("'malloc()' for 'buf2'");
710		exit(EXIT_FAILURE);
711	}
712
713	memset(buf1, BUF_PATTERN_1, buf_size);
714	memset(buf2, BUF_PATTERN_2, buf_size);
715
716	send_buf(fd, buf1, buf_size, 0, buf_size);
717
718	send_buf(fd, buf2, buf_size, 0, buf_size);
719
720	close(fd);
721}
722
723static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
724{
725	int fd;
726	unsigned char *broken_buf;
727	unsigned char *valid_buf;
728	int page_size = getpagesize();
729	int buf_size = page_size * 3;
730	ssize_t res;
731	int prot = PROT_READ | PROT_WRITE;
732	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
733	int i;
734
735	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
736	if (fd < 0) {
737		perror("accept");
738		exit(EXIT_FAILURE);
739	}
740
741	/* Setup first buffer. */
742	broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
743	if (broken_buf == MAP_FAILED) {
744		perror("mmap for 'broken_buf'");
745		exit(EXIT_FAILURE);
746	}
747
748	/* Unmap "hole" in buffer. */
749	if (munmap(broken_buf + page_size, page_size)) {
750		perror("'broken_buf' setup");
751		exit(EXIT_FAILURE);
752	}
753
754	valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
755	if (valid_buf == MAP_FAILED) {
756		perror("mmap for 'valid_buf'");
757		exit(EXIT_FAILURE);
758	}
759
760	/* Try to fill buffer with unmapped middle. */
761	res = read(fd, broken_buf, buf_size);
762	if (res != -1) {
763		fprintf(stderr,
764			"expected 'broken_buf' read(2) failure, got %zi\n",
765			res);
766		exit(EXIT_FAILURE);
767	}
768
769	if (errno != EFAULT) {
770		perror("unexpected errno of 'broken_buf'");
771		exit(EXIT_FAILURE);
772	}
773
774	/* Try to fill valid buffer. */
775	res = read(fd, valid_buf, buf_size);
776	if (res < 0) {
777		perror("unexpected 'valid_buf' read(2) failure");
778		exit(EXIT_FAILURE);
779	}
780
781	if (res != buf_size) {
782		fprintf(stderr,
783			"invalid 'valid_buf' read(2), expected %i, got %zi\n",
784			buf_size, res);
785		exit(EXIT_FAILURE);
786	}
787
788	for (i = 0; i < buf_size; i++) {
789		if (valid_buf[i] != BUF_PATTERN_2) {
790			fprintf(stderr,
791				"invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
792				i, BUF_PATTERN_2, valid_buf[i]);
793			exit(EXIT_FAILURE);
794		}
795	}
796
797	/* Unmap buffers. */
798	munmap(broken_buf, page_size);
799	munmap(broken_buf + page_size * 2, page_size);
800	munmap(valid_buf, buf_size);
801	close(fd);
802}
803
804#define RCVLOWAT_BUF_SIZE 128
805
806static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
807{
808	int fd;
809	int i;
810
811	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
812	if (fd < 0) {
813		perror("accept");
814		exit(EXIT_FAILURE);
815	}
816
817	/* Send 1 byte. */
818	send_byte(fd, 1, 0);
819
820	control_writeln("SRVSENT");
821
822	/* Wait until client is ready to receive rest of data. */
823	control_expectln("CLNSENT");
824
825	for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
826		send_byte(fd, 1, 0);
827
828	/* Keep socket in active state. */
829	control_expectln("POLLDONE");
830
831	close(fd);
832}
833
834static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
835{
836	unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
837	char buf[RCVLOWAT_BUF_SIZE];
838	struct pollfd fds;
839	short poll_flags;
840	int fd;
841
842	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
843	if (fd < 0) {
844		perror("connect");
845		exit(EXIT_FAILURE);
846	}
847
848	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
849		       &lowat_val, sizeof(lowat_val))) {
850		perror("setsockopt(SO_RCVLOWAT)");
851		exit(EXIT_FAILURE);
852	}
853
854	control_expectln("SRVSENT");
855
856	/* At this point, server sent 1 byte. */
857	fds.fd = fd;
858	poll_flags = POLLIN | POLLRDNORM;
859	fds.events = poll_flags;
860
861	/* Try to wait for 1 sec. */
862	if (poll(&fds, 1, 1000) < 0) {
863		perror("poll");
864		exit(EXIT_FAILURE);
865	}
866
867	/* poll() must return nothing. */
868	if (fds.revents) {
869		fprintf(stderr, "Unexpected poll result %hx\n",
870			fds.revents);
871		exit(EXIT_FAILURE);
872	}
873
874	/* Tell server to send rest of data. */
875	control_writeln("CLNSENT");
876
877	/* Poll for data. */
878	if (poll(&fds, 1, 10000) < 0) {
879		perror("poll");
880		exit(EXIT_FAILURE);
881	}
882
883	/* Only these two bits are expected. */
884	if (fds.revents != poll_flags) {
885		fprintf(stderr, "Unexpected poll result %hx\n",
886			fds.revents);
887		exit(EXIT_FAILURE);
888	}
889
890	/* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
891	 * will be returned.
892	 */
893	recv_buf(fd, buf, sizeof(buf), MSG_DONTWAIT, RCVLOWAT_BUF_SIZE);
894
895	control_writeln("POLLDONE");
896
897	close(fd);
898}
899
900#define INV_BUF_TEST_DATA_LEN 512
901
902static void test_inv_buf_client(const struct test_opts *opts, bool stream)
903{
904	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
905	ssize_t expected_ret;
906	int fd;
907
908	if (stream)
909		fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
910	else
911		fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
912
913	if (fd < 0) {
914		perror("connect");
915		exit(EXIT_FAILURE);
916	}
917
918	control_expectln("SENDDONE");
919
920	/* Use invalid buffer here. */
921	recv_buf(fd, NULL, sizeof(data), 0, -EFAULT);
922
923	if (stream) {
924		/* For SOCK_STREAM we must continue reading. */
925		expected_ret = sizeof(data);
926	} else {
927		/* For SOCK_SEQPACKET socket's queue must be empty. */
928		expected_ret = -EAGAIN;
929	}
930
931	recv_buf(fd, data, sizeof(data), MSG_DONTWAIT, expected_ret);
932
933	control_writeln("DONE");
934
935	close(fd);
936}
937
938static void test_inv_buf_server(const struct test_opts *opts, bool stream)
939{
940	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
941	int fd;
942
943	if (stream)
944		fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
945	else
946		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
947
948	if (fd < 0) {
949		perror("accept");
950		exit(EXIT_FAILURE);
951	}
952
953	send_buf(fd, data, sizeof(data), 0, sizeof(data));
954
955	control_writeln("SENDDONE");
956
957	control_expectln("DONE");
958
959	close(fd);
960}
961
/* SOCK_STREAM variant of the invalid-buffer test (client side). */
static void test_stream_inv_buf_client(const struct test_opts *opts)
{
	const bool use_stream = true;

	test_inv_buf_client(opts, use_stream);
}
966
/* SOCK_STREAM variant of the invalid-buffer test (server side). */
static void test_stream_inv_buf_server(const struct test_opts *opts)
{
	const bool use_stream = true;

	test_inv_buf_server(opts, use_stream);
}
971
/* SOCK_SEQPACKET variant of the invalid-buffer test (client side). */
static void test_seqpacket_inv_buf_client(const struct test_opts *opts)
{
	const bool use_stream = false;

	test_inv_buf_client(opts, use_stream);
}
976
/* SOCK_SEQPACKET variant of the invalid-buffer test (server side). */
static void test_seqpacket_inv_buf_server(const struct test_opts *opts)
{
	const bool use_stream = false;

	test_inv_buf_server(opts, use_stream);
}
981
982#define HELLO_STR "HELLO"
983#define WORLD_STR "WORLD"
984
985static void test_stream_virtio_skb_merge_client(const struct test_opts *opts)
986{
987	int fd;
988
989	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
990	if (fd < 0) {
991		perror("connect");
992		exit(EXIT_FAILURE);
993	}
994
995	/* Send first skbuff. */
996	send_buf(fd, HELLO_STR, strlen(HELLO_STR), 0, strlen(HELLO_STR));
997
998	control_writeln("SEND0");
999	/* Peer reads part of first skbuff. */
1000	control_expectln("REPLY0");
1001
1002	/* Send second skbuff, it will be appended to the first. */
1003	send_buf(fd, WORLD_STR, strlen(WORLD_STR), 0, strlen(WORLD_STR));
1004
1005	control_writeln("SEND1");
1006	/* Peer reads merged skbuff packet. */
1007	control_expectln("REPLY1");
1008
1009	close(fd);
1010}
1011
1012static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
1013{
1014	size_t read = 0, to_read;
1015	unsigned char buf[64];
1016	int fd;
1017
1018	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
1019	if (fd < 0) {
1020		perror("accept");
1021		exit(EXIT_FAILURE);
1022	}
1023
1024	control_expectln("SEND0");
1025
1026	/* Read skbuff partially. */
1027	to_read = 2;
1028	recv_buf(fd, buf + read, to_read, 0, to_read);
1029	read += to_read;
1030
1031	control_writeln("REPLY0");
1032	control_expectln("SEND1");
1033
1034	/* Read the rest of both buffers */
1035	to_read = strlen(HELLO_STR WORLD_STR) - read;
1036	recv_buf(fd, buf + read, to_read, 0, to_read);
1037	read += to_read;
1038
1039	/* No more bytes should be there */
1040	to_read = sizeof(buf) - read;
1041	recv_buf(fd, buf + read, to_read, MSG_DONTWAIT, -EAGAIN);
1042
1043	if (memcmp(buf, HELLO_STR WORLD_STR, strlen(HELLO_STR WORLD_STR))) {
1044		fprintf(stderr, "pattern mismatch\n");
1045		exit(EXIT_FAILURE);
1046	}
1047
1048	control_writeln("REPLY1");
1049
1050	close(fd);
1051}
1052
/* SOCK_SEQPACKET variant of the MSG_PEEK test (client side). */
static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
{
	/* Plain call: ISO C forbids 'return <expr>;' in a void function. */
	test_msg_peek_client(opts, true);
}
1057
/* SOCK_SEQPACKET variant of the MSG_PEEK test (server side). */
static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
{
	/* Plain call: ISO C forbids 'return <expr>;' in a void function. */
	test_msg_peek_server(opts, true);
}
1062
/* Set to 1 by sigpipe() when SIGPIPE is delivered. It must be
 * 'volatile sig_atomic_t': C only guarantees well-defined behaviour for
 * objects of that type written from a signal handler, and without volatile
 * the compiler may cache the value across the send() call.
 */
static volatile sig_atomic_t have_sigpipe;

/* SIGPIPE handler: record that the signal was delivered. */
static void sigpipe(int signo)
{
	have_sigpipe = 1;
}
1069
1070static void test_stream_check_sigpipe(int fd)
1071{
1072	ssize_t res;
1073
1074	have_sigpipe = 0;
1075
1076	res = send(fd, "A", 1, 0);
1077	if (res != -1) {
1078		fprintf(stderr, "expected send(2) failure, got %zi\n", res);
1079		exit(EXIT_FAILURE);
1080	}
1081
1082	if (!have_sigpipe) {
1083		fprintf(stderr, "SIGPIPE expected\n");
1084		exit(EXIT_FAILURE);
1085	}
1086
1087	have_sigpipe = 0;
1088
1089	res = send(fd, "A", 1, MSG_NOSIGNAL);
1090	if (res != -1) {
1091		fprintf(stderr, "expected send(2) failure, got %zi\n", res);
1092		exit(EXIT_FAILURE);
1093	}
1094
1095	if (have_sigpipe) {
1096		fprintf(stderr, "SIGPIPE not expected\n");
1097		exit(EXIT_FAILURE);
1098	}
1099}
1100
1101static void test_stream_shutwr_client(const struct test_opts *opts)
1102{
1103	int fd;
1104
1105	struct sigaction act = {
1106		.sa_handler = sigpipe,
1107	};
1108
1109	sigaction(SIGPIPE, &act, NULL);
1110
1111	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
1112	if (fd < 0) {
1113		perror("connect");
1114		exit(EXIT_FAILURE);
1115	}
1116
1117	if (shutdown(fd, SHUT_WR)) {
1118		perror("shutdown");
1119		exit(EXIT_FAILURE);
1120	}
1121
1122	test_stream_check_sigpipe(fd);
1123
1124	control_writeln("CLIENTDONE");
1125
1126	close(fd);
1127}
1128
1129static void test_stream_shutwr_server(const struct test_opts *opts)
1130{
1131	int fd;
1132
1133	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
1134	if (fd < 0) {
1135		perror("accept");
1136		exit(EXIT_FAILURE);
1137	}
1138
1139	control_expectln("CLIENTDONE");
1140
1141	close(fd);
1142}
1143
1144static void test_stream_shutrd_client(const struct test_opts *opts)
1145{
1146	int fd;
1147
1148	struct sigaction act = {
1149		.sa_handler = sigpipe,
1150	};
1151
1152	sigaction(SIGPIPE, &act, NULL);
1153
1154	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
1155	if (fd < 0) {
1156		perror("connect");
1157		exit(EXIT_FAILURE);
1158	}
1159
1160	control_expectln("SHUTRDDONE");
1161
1162	test_stream_check_sigpipe(fd);
1163
1164	control_writeln("CLIENTDONE");
1165
1166	close(fd);
1167}
1168
1169static void test_stream_shutrd_server(const struct test_opts *opts)
1170{
1171	int fd;
1172
1173	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
1174	if (fd < 0) {
1175		perror("accept");
1176		exit(EXIT_FAILURE);
1177	}
1178
1179	if (shutdown(fd, SHUT_RD)) {
1180		perror("shutdown");
1181		exit(EXIT_FAILURE);
1182	}
1183
1184	control_writeln("SHUTRDDONE");
1185	control_expectln("CLIENTDONE");
1186
1187	close(fd);
1188}
1189
1190static void test_double_bind_connect_server(const struct test_opts *opts)
1191{
1192	int listen_fd, client_fd, i;
1193	struct sockaddr_vm sa_client;
1194	socklen_t socklen_client = sizeof(sa_client);
1195
1196	listen_fd = vsock_stream_listen(VMADDR_CID_ANY, opts->peer_port);
1197
1198	for (i = 0; i < 2; i++) {
1199		control_writeln("LISTENING");
1200
1201		timeout_begin(TIMEOUT);
1202		do {
1203			client_fd = accept(listen_fd, (struct sockaddr *)&sa_client,
1204					   &socklen_client);
1205			timeout_check("accept");
1206		} while (client_fd < 0 && errno == EINTR);
1207		timeout_end();
1208
1209		if (client_fd < 0) {
1210			perror("accept");
1211			exit(EXIT_FAILURE);
1212		}
1213
1214		/* Waiting for remote peer to close connection */
1215		vsock_wait_remote_close(client_fd);
1216	}
1217
1218	close(listen_fd);
1219}
1220
1221static void test_double_bind_connect_client(const struct test_opts *opts)
1222{
1223	int i, client_fd;
1224
1225	for (i = 0; i < 2; i++) {
1226		/* Wait until server is ready to accept a new connection */
1227		control_expectln("LISTENING");
1228
1229		/* We use 'peer_port + 1' as "some" port for the 'bind()'
1230		 * call. It is safe for overflow, but must be considered,
1231		 * when running multiple test applications simultaneously
1232		 * where 'peer-port' argument differs by 1.
1233		 */
1234		client_fd = vsock_bind_connect(opts->peer_cid, opts->peer_port,
1235					       opts->peer_port + 1, SOCK_STREAM);
1236
1237		close(client_fd);
1238	}
1239}
1240
1241#define RCVLOWAT_CREDIT_UPD_BUF_SIZE	(1024 * 128)
1242/* This define is the same as in 'include/linux/virtio_vsock.h':
1243 * it is used to decide when to send credit update message during
1244 * reading from rx queue of a socket. Value and its usage in
1245 * kernel is important for this test.
1246 */
1247#define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE	(1024 * 64)
1248
1249static void test_stream_rcvlowat_def_cred_upd_client(const struct test_opts *opts)
1250{
1251	size_t buf_size;
1252	void *buf;
1253	int fd;
1254
1255	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
1256	if (fd < 0) {
1257		perror("connect");
1258		exit(EXIT_FAILURE);
1259	}
1260
1261	/* Send 1 byte more than peer's buffer size. */
1262	buf_size = RCVLOWAT_CREDIT_UPD_BUF_SIZE + 1;
1263
1264	buf = malloc(buf_size);
1265	if (!buf) {
1266		perror("malloc");
1267		exit(EXIT_FAILURE);
1268	}
1269
1270	/* Wait until peer sets needed buffer size. */
1271	recv_byte(fd, 1, 0);
1272
1273	if (send(fd, buf, buf_size, 0) != buf_size) {
1274		perror("send failed");
1275		exit(EXIT_FAILURE);
1276	}
1277
1278	free(buf);
1279	close(fd);
1280}
1281
1282static void test_stream_credit_update_test(const struct test_opts *opts,
1283					   bool low_rx_bytes_test)
1284{
1285	size_t recv_buf_size;
1286	struct pollfd fds;
1287	size_t buf_size;
1288	void *buf;
1289	int fd;
1290
1291	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
1292	if (fd < 0) {
1293		perror("accept");
1294		exit(EXIT_FAILURE);
1295	}
1296
1297	buf_size = RCVLOWAT_CREDIT_UPD_BUF_SIZE;
1298
1299	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
1300		       &buf_size, sizeof(buf_size))) {
1301		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
1302		exit(EXIT_FAILURE);
1303	}
1304
1305	if (low_rx_bytes_test) {
1306		/* Set new SO_RCVLOWAT here. This enables sending credit
1307		 * update when number of bytes if our rx queue become <
1308		 * SO_RCVLOWAT value.
1309		 */
1310		recv_buf_size = 1 + VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
1311
1312		if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
1313			       &recv_buf_size, sizeof(recv_buf_size))) {
1314			perror("setsockopt(SO_RCVLOWAT)");
1315			exit(EXIT_FAILURE);
1316		}
1317	}
1318
1319	/* Send one dummy byte here, because 'setsockopt()' above also
1320	 * sends special packet which tells sender to update our buffer
1321	 * size. This 'send_byte()' will serialize such packet with data
1322	 * reads in a loop below. Sender starts transmission only when
1323	 * it receives this single byte.
1324	 */
1325	send_byte(fd, 1, 0);
1326
1327	buf = malloc(buf_size);
1328	if (!buf) {
1329		perror("malloc");
1330		exit(EXIT_FAILURE);
1331	}
1332
1333	/* Wait until there will be 128KB of data in rx queue. */
1334	while (1) {
1335		ssize_t res;
1336
1337		res = recv(fd, buf, buf_size, MSG_PEEK);
1338		if (res == buf_size)
1339			break;
1340
1341		if (res <= 0) {
1342			fprintf(stderr, "unexpected 'recv()' return: %zi\n", res);
1343			exit(EXIT_FAILURE);
1344		}
1345	}
1346
1347	/* There is 128KB of data in the socket's rx queue, dequeue first
1348	 * 64KB, credit update is sent if 'low_rx_bytes_test' == true.
1349	 * Otherwise, credit update is sent in 'if (!low_rx_bytes_test)'.
1350	 */
1351	recv_buf_size = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
1352	recv_buf(fd, buf, recv_buf_size, 0, recv_buf_size);
1353
1354	if (!low_rx_bytes_test) {
1355		recv_buf_size++;
1356
1357		/* Updating SO_RCVLOWAT will send credit update. */
1358		if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
1359			       &recv_buf_size, sizeof(recv_buf_size))) {
1360			perror("setsockopt(SO_RCVLOWAT)");
1361			exit(EXIT_FAILURE);
1362		}
1363	}
1364
1365	fds.fd = fd;
1366	fds.events = POLLIN | POLLRDNORM | POLLERR |
1367		     POLLRDHUP | POLLHUP;
1368
1369	/* This 'poll()' will return once we receive last byte
1370	 * sent by client.
1371	 */
1372	if (poll(&fds, 1, -1) < 0) {
1373		perror("poll");
1374		exit(EXIT_FAILURE);
1375	}
1376
1377	if (fds.revents & POLLERR) {
1378		fprintf(stderr, "'poll()' error\n");
1379		exit(EXIT_FAILURE);
1380	}
1381
1382	if (fds.revents & (POLLIN | POLLRDNORM)) {
1383		recv_buf(fd, buf, recv_buf_size, MSG_DONTWAIT, recv_buf_size);
1384	} else {
1385		/* These flags must be set, as there is at
1386		 * least 64KB of data ready to read.
1387		 */
1388		fprintf(stderr, "POLLIN | POLLRDNORM expected\n");
1389		exit(EXIT_FAILURE);
1390	}
1391
1392	free(buf);
1393	close(fd);
1394}
1395
/* Server variant: SO_RCVLOWAT is configured up front, so the credit update
 * must come from the rx queue dropping below it.
 */
static void test_stream_cred_upd_on_low_rx_bytes(const struct test_opts *opts)
{
	const bool set_rcvlowat_early = true;

	test_stream_credit_update_test(opts, set_rcvlowat_early);
}
1400
/* Server variant: the credit update is expected from setting SO_RCVLOWAT
 * after part of the rx queue has been dequeued.
 */
static void test_stream_cred_upd_on_set_rcvlowat(const struct test_opts *opts)
{
	const bool set_rcvlowat_early = false;

	test_stream_credit_update_test(opts, set_rcvlowat_early);
}
1405
1406static struct test_case test_cases[] = {
1407	{
1408		.name = "SOCK_STREAM connection reset",
1409		.run_client = test_stream_connection_reset,
1410	},
1411	{
1412		.name = "SOCK_STREAM bind only",
1413		.run_client = test_stream_bind_only_client,
1414		.run_server = test_stream_bind_only_server,
1415	},
1416	{
1417		.name = "SOCK_STREAM client close",
1418		.run_client = test_stream_client_close_client,
1419		.run_server = test_stream_client_close_server,
1420	},
1421	{
1422		.name = "SOCK_STREAM server close",
1423		.run_client = test_stream_server_close_client,
1424		.run_server = test_stream_server_close_server,
1425	},
1426	{
1427		.name = "SOCK_STREAM multiple connections",
1428		.run_client = test_stream_multiconn_client,
1429		.run_server = test_stream_multiconn_server,
1430	},
1431	{
1432		.name = "SOCK_STREAM MSG_PEEK",
1433		.run_client = test_stream_msg_peek_client,
1434		.run_server = test_stream_msg_peek_server,
1435	},
1436	{
1437		.name = "SOCK_SEQPACKET msg bounds",
1438		.run_client = test_seqpacket_msg_bounds_client,
1439		.run_server = test_seqpacket_msg_bounds_server,
1440	},
1441	{
1442		.name = "SOCK_SEQPACKET MSG_TRUNC flag",
1443		.run_client = test_seqpacket_msg_trunc_client,
1444		.run_server = test_seqpacket_msg_trunc_server,
1445	},
1446	{
1447		.name = "SOCK_SEQPACKET timeout",
1448		.run_client = test_seqpacket_timeout_client,
1449		.run_server = test_seqpacket_timeout_server,
1450	},
1451	{
1452		.name = "SOCK_SEQPACKET invalid receive buffer",
1453		.run_client = test_seqpacket_invalid_rec_buffer_client,
1454		.run_server = test_seqpacket_invalid_rec_buffer_server,
1455	},
1456	{
1457		.name = "SOCK_STREAM poll() + SO_RCVLOWAT",
1458		.run_client = test_stream_poll_rcvlowat_client,
1459		.run_server = test_stream_poll_rcvlowat_server,
1460	},
1461	{
1462		.name = "SOCK_SEQPACKET big message",
1463		.run_client = test_seqpacket_bigmsg_client,
1464		.run_server = test_seqpacket_bigmsg_server,
1465	},
1466	{
1467		.name = "SOCK_STREAM test invalid buffer",
1468		.run_client = test_stream_inv_buf_client,
1469		.run_server = test_stream_inv_buf_server,
1470	},
1471	{
1472		.name = "SOCK_SEQPACKET test invalid buffer",
1473		.run_client = test_seqpacket_inv_buf_client,
1474		.run_server = test_seqpacket_inv_buf_server,
1475	},
1476	{
1477		.name = "SOCK_STREAM virtio skb merge",
1478		.run_client = test_stream_virtio_skb_merge_client,
1479		.run_server = test_stream_virtio_skb_merge_server,
1480	},
1481	{
1482		.name = "SOCK_SEQPACKET MSG_PEEK",
1483		.run_client = test_seqpacket_msg_peek_client,
1484		.run_server = test_seqpacket_msg_peek_server,
1485	},
1486	{
1487		.name = "SOCK_STREAM SHUT_WR",
1488		.run_client = test_stream_shutwr_client,
1489		.run_server = test_stream_shutwr_server,
1490	},
1491	{
1492		.name = "SOCK_STREAM SHUT_RD",
1493		.run_client = test_stream_shutrd_client,
1494		.run_server = test_stream_shutrd_server,
1495	},
1496	{
1497		.name = "SOCK_STREAM MSG_ZEROCOPY",
1498		.run_client = test_stream_msgzcopy_client,
1499		.run_server = test_stream_msgzcopy_server,
1500	},
1501	{
1502		.name = "SOCK_SEQPACKET MSG_ZEROCOPY",
1503		.run_client = test_seqpacket_msgzcopy_client,
1504		.run_server = test_seqpacket_msgzcopy_server,
1505	},
1506	{
1507		.name = "SOCK_STREAM MSG_ZEROCOPY empty MSG_ERRQUEUE",
1508		.run_client = test_stream_msgzcopy_empty_errq_client,
1509		.run_server = test_stream_msgzcopy_empty_errq_server,
1510	},
1511	{
1512		.name = "SOCK_STREAM double bind connect",
1513		.run_client = test_double_bind_connect_client,
1514		.run_server = test_double_bind_connect_server,
1515	},
1516	{
1517		.name = "SOCK_STREAM virtio credit update + SO_RCVLOWAT",
1518		.run_client = test_stream_rcvlowat_def_cred_upd_client,
1519		.run_server = test_stream_cred_upd_on_set_rcvlowat,
1520	},
1521	{
1522		.name = "SOCK_STREAM virtio credit update + low rx_bytes",
1523		.run_client = test_stream_rcvlowat_def_cred_upd_client,
1524		.run_server = test_stream_cred_upd_on_low_rx_bytes,
1525	},
1526	{},
1527};
1528
/* Command-line options. There are no short options (empty optstring). Every
 * entry has a NULL .flag, so getopt_long() returns the entry's .val, which
 * main() dispatches on. Fields are positional: { name, has_arg, flag, val }.
 */
static const char optstring[] = "";
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "peer-port",    required_argument, NULL, 'q' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0   },
};
1573
1574static void usage(void)
1575{
1576	fprintf(stderr, "Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
1577		"\n"
1578		"  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
1579		"  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
1580		"\n"
1581		"Run vsock.ko tests.  Must be launched in both guest\n"
1582		"and host.  One side must use --mode=client and\n"
1583		"the other side must use --mode=server.\n"
1584		"\n"
1585		"A TCP control socket connection is used to coordinate tests\n"
1586		"between the client and the server.  The server requires a\n"
1587		"listen address and the client requires an address to\n"
1588		"connect to.\n"
1589		"\n"
1590		"The CID of the other side must be given with --peer-cid=<cid>.\n"
1591		"During the test, two AF_VSOCK ports will be used: the port\n"
1592		"specified with --peer-port=<port> (or the default port)\n"
1593		"and the next one.\n"
1594		"\n"
1595		"Options:\n"
1596		"  --help                 This help message\n"
1597		"  --control-host <host>  Server IP address to connect to\n"
1598		"  --control-port <port>  Server port to listen on/connect to\n"
1599		"  --mode client|server   Server or client mode\n"
1600		"  --peer-cid <cid>       CID of the other side\n"
1601		"  --peer-port <port>     AF_VSOCK port used for the test [default: %d]\n"
1602		"  --list                 List of tests that will be executed\n"
1603		"  --skip <test_id>       Test ID to skip;\n"
1604		"                         use multiple --skip options to skip more tests\n",
1605		DEFAULT_PEER_PORT
1606		);
1607	exit(EXIT_FAILURE);
1608}
1609
1610int main(int argc, char **argv)
1611{
1612	const char *control_host = NULL;
1613	const char *control_port = NULL;
1614	struct test_opts opts = {
1615		.mode = TEST_MODE_UNSET,
1616		.peer_cid = VMADDR_CID_ANY,
1617		.peer_port = DEFAULT_PEER_PORT,
1618	};
1619
1620	srand(time(NULL));
1621	init_signals();
1622
1623	for (;;) {
1624		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1625
1626		if (opt == -1)
1627			break;
1628
1629		switch (opt) {
1630		case 'H':
1631			control_host = optarg;
1632			break;
1633		case 'm':
1634			if (strcmp(optarg, "client") == 0)
1635				opts.mode = TEST_MODE_CLIENT;
1636			else if (strcmp(optarg, "server") == 0)
1637				opts.mode = TEST_MODE_SERVER;
1638			else {
1639				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1640				return EXIT_FAILURE;
1641			}
1642			break;
1643		case 'p':
1644			opts.peer_cid = parse_cid(optarg);
1645			break;
1646		case 'q':
1647			opts.peer_port = parse_port(optarg);
1648			break;
1649		case 'P':
1650			control_port = optarg;
1651			break;
1652		case 'l':
1653			list_tests(test_cases);
1654			break;
1655		case 's':
1656			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1657				  optarg);
1658			break;
1659		case '?':
1660		default:
1661			usage();
1662		}
1663	}
1664
1665	if (!control_port)
1666		usage();
1667	if (opts.mode == TEST_MODE_UNSET)
1668		usage();
1669	if (opts.peer_cid == VMADDR_CID_ANY)
1670		usage();
1671
1672	if (!control_host) {
1673		if (opts.mode != TEST_MODE_SERVER)
1674			usage();
1675		control_host = "0.0.0.0";
1676	}
1677
1678	control_init(control_host, control_port,
1679		     opts.mode == TEST_MODE_SERVER);
1680
1681	run_tests(test_cases, &opts);
1682
1683	control_cleanup();
1684	return EXIT_SUCCESS;
1685}
1686