1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * vsock_diag_test - vsock_diag.ko test suite
4 *
5 * Copyright (C) 2017 Red Hat, Inc.
6 *
7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
8 */
9
10#include <getopt.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <errno.h>
15#include <unistd.h>
16#include <sys/stat.h>
17#include <sys/types.h>
18#include <linux/list.h>
19#include <linux/net.h>
20#include <linux/netlink.h>
21#include <linux/sock_diag.h>
22#include <linux/vm_sockets_diag.h>
23#include <netinet/tcp.h>
24
25#include "timeout.h"
26#include "control.h"
27#include "util.h"
28
29/* Per-socket status */
30struct vsock_stat {
31	struct list_head list;
32	struct vsock_diag_msg msg;
33};
34
/*
 * Return a short label for a socket type (SOCK_STREAM, SOCK_DGRAM or
 * SOCK_SEQPACKET); any other value yields "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	const char *name = "INVALID TYPE";

	if (type == SOCK_STREAM)
		name = "STREAM";
	else if (type == SOCK_DGRAM)
		name = "DGRAM";
	else if (type == SOCK_SEQPACKET)
		name = "SEQPACKET";

	return name;
}
48
49static const char *sock_state_str(int state)
50{
51	switch (state) {
52	case TCP_CLOSE:
53		return "UNCONNECTED";
54	case TCP_SYN_SENT:
55		return "CONNECTING";
56	case TCP_ESTABLISHED:
57		return "CONNECTED";
58	case TCP_CLOSING:
59		return "DISCONNECTING";
60	case TCP_LISTEN:
61		return "LISTEN";
62	default:
63		return "INVALID STATE";
64	}
65}
66
/*
 * Decode the vdiag_shutdown bitmask: bit 0 is labelled RCV_SHUTDOWN and
 * bit 1 SEND_SHUTDOWN.  0 and any value outside 1..3 are shown as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const labels[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "0";

	return labels[shutdown];
}
80
/*
 * Write "cid:port" to @fp, printing VMADDR_CID_ANY / VMADDR_PORT_ANY as "*".
 * No trailing separator or newline is emitted.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
93
94static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
95{
96	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
97	fprintf(fp, " ");
98	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
99	fprintf(fp, " %s %s %s %u\n",
100		sock_type_str(st->msg.vdiag_type),
101		sock_state_str(st->msg.vdiag_state),
102		sock_shutdown_str(st->msg.vdiag_shutdown),
103		st->msg.vdiag_ino);
104}
105
106static void print_vsock_stats(FILE *fp, struct list_head *head)
107{
108	struct vsock_stat *st;
109
110	list_for_each_entry(st, head, list)
111		print_vsock_stat(fp, st);
112}
113
114static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115{
116	struct vsock_stat *st;
117	struct stat stat;
118
119	if (fstat(fd, &stat) < 0) {
120		perror("fstat");
121		exit(EXIT_FAILURE);
122	}
123
124	list_for_each_entry(st, head, list)
125		if (st->msg.vdiag_ino == stat.st_ino)
126			return st;
127
128	fprintf(stderr, "cannot find fd %d\n", fd);
129	exit(EXIT_FAILURE);
130}
131
/*
 * Fail the test unless @head is empty.  On failure the unexpected sockets
 * are dumped to stderr and the process exits with EXIT_FAILURE, like every
 * other check in this file.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		exit(EXIT_FAILURE);
	}
}
140
141static void check_num_sockets(struct list_head *head, int expected)
142{
143	struct list_head *node;
144	int n = 0;
145
146	list_for_each(node, head)
147		n++;
148
149	if (n != expected) {
150		fprintf(stderr, "expected %d sockets, found %d\n",
151			expected, n);
152		print_vsock_stats(stderr, head);
153		exit(EXIT_FAILURE);
154	}
155}
156
157static void check_socket_state(struct vsock_stat *st, __u8 state)
158{
159	if (st->msg.vdiag_state != state) {
160		fprintf(stderr, "expected socket state %#x, got %#x\n",
161			state, st->msg.vdiag_state);
162		exit(EXIT_FAILURE);
163	}
164}
165
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request for AF_VSOCK sockets in every
 * state (vdiag_states has all bits set) on netlink socket @fd.  The
 * destination address has nl_pid 0, i.e. the kernel.  sendmsg() is retried
 * on EINTR; any other failure terminates the process.
 */
static void send_req(int fd)
{
	struct sockaddr_nl kernel_addr = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec vec = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr hdr = {
		.msg_name = &kernel_addr,
		.msg_namelen = sizeof(kernel_addr),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &hdr, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
208
/*
 * Receive one netlink datagram from @fd into @buf (at most @len bytes),
 * retrying on EINTR, and return the number of bytes received.
 *
 * Exits the process if recvmsg() fails, or if the datagram did not fit in
 * @buf: recvmsg() reports that via MSG_TRUNC, and silently accepting a
 * truncated dump reply would drop socket entries from the results.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &nladdr,
		.msg_namelen = sizeof(nladdr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t ret;

	do {
		ret = recvmsg(fd, &msg, 0);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
		perror("recvmsg");
		exit(EXIT_FAILURE);
	}

	if (msg.msg_flags & MSG_TRUNC) {
		fprintf(stderr, "netlink message truncated (buffer %zu bytes)\n",
			len);
		exit(EXIT_FAILURE);
	}

	return ret;
}
237
238static void add_vsock_stat(struct list_head *sockets,
239			   const struct vsock_diag_msg *resp)
240{
241	struct vsock_stat *st;
242
243	st = malloc(sizeof(*st));
244	if (!st) {
245		perror("malloc");
246		exit(EXIT_FAILURE);
247	}
248
249	st->msg = *resp;
250	list_add_tail(&st->list, sockets);
251}
252
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends an AF_VSOCK dump request and
 * appends one struct vsock_stat per SOCK_DIAG_BY_FAMILY reply message to
 * @sockets.  The caller owns the appended entries and releases them with
 * free_sock_stat().  Reading stops at NLMSG_DONE or an empty read.  A
 * netlink error, an unexpected message type, or a short reply terminates
 * the process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* long[] rather than char[] so the buffer is aligned for nlmsghdr. */
	long buf[8192 / sizeof(long)];
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	for (;;) {
		const struct nlmsghdr *h;
		ssize_t ret;

		ret = recv_resp(fd, buf, sizeof(buf));
		if (ret == 0)
			goto done;
		/*
		 * recv_resp() exits instead of returning a negative value, so
		 * converting to size_t is safe and avoids a signed/unsigned
		 * comparison.
		 */
		if ((size_t)ret < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", ret);
			exit(EXIT_FAILURE);
		}

		h = (struct nlmsghdr *)buf;

		/* One datagram may carry several netlink messages. */
		while (NLMSG_OK(h, ret)) {
			if (h->nlmsg_type == NLMSG_DONE)
				goto done;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
					fprintf(stderr, "NLMSG_ERROR\n");
				else {
					/* err->error is a negative errno value. */
					errno = -err->error;
					perror("NLMSG_ERROR");
				}

				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}
			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));

			/* Advances h and decrements ret by the aligned length. */
			h = NLMSG_NEXT(h, ret);
		}
	}

done:
	close(fd);
}
320
321static void free_sock_stat(struct list_head *sockets)
322{
323	struct vsock_stat *st;
324	struct vsock_stat *next;
325
326	list_for_each_entry_safe(st, next, sockets, list)
327		free(st);
328}
329
330static void test_no_sockets(const struct test_opts *opts)
331{
332	LIST_HEAD(sockets);
333
334	read_vsock_stat(&sockets);
335
336	check_no_sockets(&sockets);
337}
338
339static void test_listen_socket_server(const struct test_opts *opts)
340{
341	union {
342		struct sockaddr sa;
343		struct sockaddr_vm svm;
344	} addr = {
345		.svm = {
346			.svm_family = AF_VSOCK,
347			.svm_port = opts->peer_port,
348			.svm_cid = VMADDR_CID_ANY,
349		},
350	};
351	LIST_HEAD(sockets);
352	struct vsock_stat *st;
353	int fd;
354
355	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358		perror("bind");
359		exit(EXIT_FAILURE);
360	}
361
362	if (listen(fd, 1) < 0) {
363		perror("listen");
364		exit(EXIT_FAILURE);
365	}
366
367	read_vsock_stat(&sockets);
368
369	check_num_sockets(&sockets, 1);
370	st = find_vsock_stat(&sockets, fd);
371	check_socket_state(st, TCP_LISTEN);
372
373	close(fd);
374	free_sock_stat(&sockets);
375}
376
377static void test_connect_client(const struct test_opts *opts)
378{
379	int fd;
380	LIST_HEAD(sockets);
381	struct vsock_stat *st;
382
383	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384	if (fd < 0) {
385		perror("connect");
386		exit(EXIT_FAILURE);
387	}
388
389	read_vsock_stat(&sockets);
390
391	check_num_sockets(&sockets, 1);
392	st = find_vsock_stat(&sockets, fd);
393	check_socket_state(st, TCP_ESTABLISHED);
394
395	control_expectln("DONE");
396	control_writeln("DONE");
397
398	close(fd);
399	free_sock_stat(&sockets);
400}
401
402static void test_connect_server(const struct test_opts *opts)
403{
404	struct vsock_stat *st;
405	LIST_HEAD(sockets);
406	int client_fd;
407
408	client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
409	if (client_fd < 0) {
410		perror("accept");
411		exit(EXIT_FAILURE);
412	}
413
414	read_vsock_stat(&sockets);
415
416	check_num_sockets(&sockets, 1);
417	st = find_vsock_stat(&sockets, client_fd);
418	check_socket_state(st, TCP_ESTABLISHED);
419
420	control_writeln("DONE");
421	control_expectln("DONE");
422
423	close(client_fd);
424	free_sock_stat(&sockets);
425}
426
427static struct test_case test_cases[] = {
428	{
429		.name = "No sockets",
430		.run_server = test_no_sockets,
431	},
432	{
433		.name = "Listen socket",
434		.run_server = test_listen_socket_server,
435	},
436	{
437		.name = "Connect",
438		.run_client = test_connect_client,
439		.run_server = test_connect_server,
440	},
441	{},
442};
443
/* Only long options are accepted; there are no short-option letters. */
static const char optstring[] = "";

/* Each entry is { name, has_arg, flag, val }; val is the getopt_long() result. */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "peer-port",    required_argument, NULL, 'q' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },
};
488
489static void usage(void)
490{
491	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
492		"\n"
493		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
495		"\n"
496		"Run vsock_diag.ko tests.  Must be launched in both\n"
497		"guest and host.  One side must use --mode=client and\n"
498		"the other side must use --mode=server.\n"
499		"\n"
500		"A TCP control socket connection is used to coordinate tests\n"
501		"between the client and the server.  The server requires a\n"
502		"listen address and the client requires an address to\n"
503		"connect to.\n"
504		"\n"
505		"The CID of the other side must be given with --peer-cid=<cid>.\n"
506		"\n"
507		"Options:\n"
508		"  --help                 This help message\n"
509		"  --control-host <host>  Server IP address to connect to\n"
510		"  --control-port <port>  Server port to listen on/connect to\n"
511		"  --mode client|server   Server or client mode\n"
512		"  --peer-cid <cid>       CID of the other side\n"
513		"  --peer-port <port>     AF_VSOCK port used for the test [default: %d]\n"
514		"  --list                 List of tests that will be executed\n"
515		"  --skip <test_id>       Test ID to skip;\n"
516		"                         use multiple --skip options to skip more tests\n",
517		DEFAULT_PEER_PORT
518		);
519	exit(EXIT_FAILURE);
520}
521
522int main(int argc, char **argv)
523{
524	const char *control_host = NULL;
525	const char *control_port = NULL;
526	struct test_opts opts = {
527		.mode = TEST_MODE_UNSET,
528		.peer_cid = VMADDR_CID_ANY,
529		.peer_port = DEFAULT_PEER_PORT,
530	};
531
532	init_signals();
533
534	for (;;) {
535		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536
537		if (opt == -1)
538			break;
539
540		switch (opt) {
541		case 'H':
542			control_host = optarg;
543			break;
544		case 'm':
545			if (strcmp(optarg, "client") == 0)
546				opts.mode = TEST_MODE_CLIENT;
547			else if (strcmp(optarg, "server") == 0)
548				opts.mode = TEST_MODE_SERVER;
549			else {
550				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551				return EXIT_FAILURE;
552			}
553			break;
554		case 'p':
555			opts.peer_cid = parse_cid(optarg);
556			break;
557		case 'q':
558			opts.peer_port = parse_port(optarg);
559			break;
560		case 'P':
561			control_port = optarg;
562			break;
563		case 'l':
564			list_tests(test_cases);
565			break;
566		case 's':
567			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568				  optarg);
569			break;
570		case '?':
571		default:
572			usage();
573		}
574	}
575
576	if (!control_port)
577		usage();
578	if (opts.mode == TEST_MODE_UNSET)
579		usage();
580	if (opts.peer_cid == VMADDR_CID_ANY)
581		usage();
582
583	if (!control_host) {
584		if (opts.mode != TEST_MODE_SERVER)
585			usage();
586		control_host = "0.0.0.0";
587	}
588
589	control_init(control_host, control_port,
590		     opts.mode == TEST_MODE_SERVER);
591
592	run_tests(test_cases, &opts);
593
594	control_cleanup();
595	return EXIT_SUCCESS;
596}
597