1// SPDX-License-Identifier: GPL-2.0-only
2/* io_uring tests for vsock
3 *
4 * Copyright (C) 2023 SberDevices.
5 *
6 * Author: Arseniy Krasnov <avkrasnov@salutedevices.com>
7 */
8
9#include <getopt.h>
10#include <stdio.h>
11#include <stdlib.h>
12#include <string.h>
13#include <liburing.h>
14#include <unistd.h>
15#include <sys/mman.h>
16#include <linux/kernel.h>
17#include <error.h>
18
19#include "util.h"
20#include "control.h"
21#include "msg_zerocopy_common.h"
22
/* Fallback page size used to size the test buffers below when no header defines it. */
#ifndef PAGE_SIZE
#define PAGE_SIZE		4096
#endif

/* Number of entries requested from io_uring_queue_init() for each ring. */
#define RING_ENTRIES_NUM	4

/* Capacity of the 'vecs' array in struct vsock_io_uring_test. */
#define VSOCK_TEST_DATA_MAX_IOV 3
30
/*
 * One transmission test case.
 *
 * The client allocates buffers from 'vecs' with alloc_test_iovec() and sends
 * all of them in a single io_uring sendmsg request. The server sizes its
 * receive buffer with iovec_bytes() over the same 'vecs'.
 */
struct vsock_io_uring_test {
	/* Number of valid elements in 'vecs'. */
	int vecs_cnt;
	/*
	 * Buffer layout template. NOTE(review): a non-NULL iov_base appears
	 * to request a non-page-aligned buffer (see test_data_array); confirm
	 * against alloc_test_iovec().
	 */
	struct iovec vecs[VSOCK_TEST_DATA_MAX_IOV];
};
36
37static struct vsock_io_uring_test test_data_array[] = {
38	/* All elements have page aligned base and size. */
39	{
40		.vecs_cnt = 3,
41		{
42			{ NULL, PAGE_SIZE },
43			{ NULL, 2 * PAGE_SIZE },
44			{ NULL, 3 * PAGE_SIZE },
45		}
46	},
47	/* Middle element has both non-page aligned base and size. */
48	{
49		.vecs_cnt = 3,
50		{
51			{ NULL, PAGE_SIZE },
52			{ (void *)1, 200  },
53			{ NULL, 3 * PAGE_SIZE },
54		}
55	}
56};
57
58static void vsock_io_uring_client(const struct test_opts *opts,
59				  const struct vsock_io_uring_test *test_data,
60				  bool msg_zerocopy)
61{
62	struct io_uring_sqe *sqe;
63	struct io_uring_cqe *cqe;
64	struct io_uring ring;
65	struct iovec *iovec;
66	struct msghdr msg;
67	int fd;
68
69	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
70	if (fd < 0) {
71		perror("connect");
72		exit(EXIT_FAILURE);
73	}
74
75	if (msg_zerocopy)
76		enable_so_zerocopy(fd);
77
78	iovec = alloc_test_iovec(test_data->vecs, test_data->vecs_cnt);
79
80	if (io_uring_queue_init(RING_ENTRIES_NUM, &ring, 0))
81		error(1, errno, "io_uring_queue_init");
82
83	if (io_uring_register_buffers(&ring, iovec, test_data->vecs_cnt))
84		error(1, errno, "io_uring_register_buffers");
85
86	memset(&msg, 0, sizeof(msg));
87	msg.msg_iov = iovec;
88	msg.msg_iovlen = test_data->vecs_cnt;
89	sqe = io_uring_get_sqe(&ring);
90
91	if (msg_zerocopy)
92		io_uring_prep_sendmsg_zc(sqe, fd, &msg, 0);
93	else
94		io_uring_prep_sendmsg(sqe, fd, &msg, 0);
95
96	if (io_uring_submit(&ring) != 1)
97		error(1, errno, "io_uring_submit");
98
99	if (io_uring_wait_cqe(&ring, &cqe))
100		error(1, errno, "io_uring_wait_cqe");
101
102	io_uring_cqe_seen(&ring, cqe);
103
104	control_writeulong(iovec_hash_djb2(iovec, test_data->vecs_cnt));
105
106	control_writeln("DONE");
107	io_uring_queue_exit(&ring);
108	free_test_iovec(test_data->vecs, iovec, test_data->vecs_cnt);
109	close(fd);
110}
111
112static void vsock_io_uring_server(const struct test_opts *opts,
113				  const struct vsock_io_uring_test *test_data)
114{
115	unsigned long remote_hash;
116	unsigned long local_hash;
117	struct io_uring ring;
118	size_t data_len;
119	size_t recv_len;
120	void *data;
121	int fd;
122
123	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
124	if (fd < 0) {
125		perror("accept");
126		exit(EXIT_FAILURE);
127	}
128
129	data_len = iovec_bytes(test_data->vecs, test_data->vecs_cnt);
130
131	data = malloc(data_len);
132	if (!data) {
133		perror("malloc");
134		exit(EXIT_FAILURE);
135	}
136
137	if (io_uring_queue_init(RING_ENTRIES_NUM, &ring, 0))
138		error(1, errno, "io_uring_queue_init");
139
140	recv_len = 0;
141
142	while (recv_len < data_len) {
143		struct io_uring_sqe *sqe;
144		struct io_uring_cqe *cqe;
145		struct iovec iovec;
146
147		sqe = io_uring_get_sqe(&ring);
148		iovec.iov_base = data + recv_len;
149		iovec.iov_len = data_len;
150
151		io_uring_prep_readv(sqe, fd, &iovec, 1, 0);
152
153		if (io_uring_submit(&ring) != 1)
154			error(1, errno, "io_uring_submit");
155
156		if (io_uring_wait_cqe(&ring, &cqe))
157			error(1, errno, "io_uring_wait_cqe");
158
159		recv_len += cqe->res;
160		io_uring_cqe_seen(&ring, cqe);
161	}
162
163	if (recv_len != data_len) {
164		fprintf(stderr, "expected %zu, got %zu\n", data_len,
165			recv_len);
166		exit(EXIT_FAILURE);
167	}
168
169	local_hash = hash_djb2(data, data_len);
170
171	remote_hash = control_readulong();
172	if (remote_hash != local_hash) {
173		fprintf(stderr, "hash mismatch\n");
174		exit(EXIT_FAILURE);
175	}
176
177	control_expectln("DONE");
178	io_uring_queue_exit(&ring);
179	free(data);
180}
181
182void test_stream_uring_server(const struct test_opts *opts)
183{
184	int i;
185
186	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
187		vsock_io_uring_server(opts, &test_data_array[i]);
188}
189
190void test_stream_uring_client(const struct test_opts *opts)
191{
192	int i;
193
194	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
195		vsock_io_uring_client(opts, &test_data_array[i], false);
196}
197
/*
 * Receiver for the zerocopy test. vsock_io_uring_server() does not depend
 * on how the client sent the data, so this is the same run as the plain
 * receiver.
 */
void test_stream_uring_msg_zc_server(const struct test_opts *opts)
{
	test_stream_uring_server(opts);
}
205
206void test_stream_uring_msg_zc_client(const struct test_opts *opts)
207{
208	int i;
209
210	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
211		vsock_io_uring_client(opts, &test_data_array[i], true);
212}
213
214static struct test_case test_cases[] = {
215	{
216		.name = "SOCK_STREAM io_uring test",
217		.run_server = test_stream_uring_server,
218		.run_client = test_stream_uring_client,
219	},
220	{
221		.name = "SOCK_STREAM io_uring MSG_ZEROCOPY test",
222		.run_server = test_stream_uring_msg_zc_server,
223		.run_client = test_stream_uring_msg_zc_client,
224	},
225	{},
226};
227
/* No short options are accepted; everything goes through longopts[]. */
static const char optstring[] = "";

/*
 * Long options parsed by main(). With a NULL flag pointer, getopt_long()
 * returns the entry's 'val'. "--help" maps to '?' so that main() handles it
 * like an unknown option and calls usage().
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "peer-port",    required_argument, NULL, 'q' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0   },
};
262
263static void usage(void)
264{
265	fprintf(stderr, "Usage: vsock_uring_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>]\n"
266		"\n"
267		"  Server: vsock_uring_test --control-port=1234 --mode=server --peer-cid=3\n"
268		"  Client: vsock_uring_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
269		"\n"
270		"Run transmission tests using io_uring. Usage is the same as\n"
271		"in ./vsock_test\n"
272		"\n"
273		"Options:\n"
274		"  --help                 This help message\n"
275		"  --control-host <host>  Server IP address to connect to\n"
276		"  --control-port <port>  Server port to listen on/connect to\n"
277		"  --mode client|server   Server or client mode\n"
278		"  --peer-cid <cid>       CID of the other side\n"
279		"  --peer-port <port>     AF_VSOCK port used for the test [default: %d]\n",
280		DEFAULT_PEER_PORT
281		);
282	exit(EXIT_FAILURE);
283}
284
285int main(int argc, char **argv)
286{
287	const char *control_host = NULL;
288	const char *control_port = NULL;
289	struct test_opts opts = {
290		.mode = TEST_MODE_UNSET,
291		.peer_cid = VMADDR_CID_ANY,
292		.peer_port = DEFAULT_PEER_PORT,
293	};
294
295	init_signals();
296
297	for (;;) {
298		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
299
300		if (opt == -1)
301			break;
302
303		switch (opt) {
304		case 'H':
305			control_host = optarg;
306			break;
307		case 'm':
308			if (strcmp(optarg, "client") == 0) {
309				opts.mode = TEST_MODE_CLIENT;
310			} else if (strcmp(optarg, "server") == 0) {
311				opts.mode = TEST_MODE_SERVER;
312			} else {
313				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
314				return EXIT_FAILURE;
315			}
316			break;
317		case 'p':
318			opts.peer_cid = parse_cid(optarg);
319			break;
320		case 'q':
321			opts.peer_port = parse_port(optarg);
322			break;
323		case 'P':
324			control_port = optarg;
325			break;
326		case '?':
327		default:
328			usage();
329		}
330	}
331
332	if (!control_port)
333		usage();
334	if (opts.mode == TEST_MODE_UNSET)
335		usage();
336	if (opts.peer_cid == VMADDR_CID_ANY)
337		usage();
338
339	if (!control_host) {
340		if (opts.mode != TEST_MODE_SERVER)
341			usage();
342		control_host = "0.0.0.0";
343	}
344
345	control_init(control_host, control_port,
346		     opts.mode == TEST_MODE_SERVER);
347
348	run_tests(test_cases, &opts);
349
350	control_cleanup();
351
352	return 0;
353}
354