1// SPDX-License-Identifier: GPL-2.0-only
2/* MSG_ZEROCOPY feature tests for vsock
3 *
4 * Copyright (C) 2023 SberDevices.
5 *
6 * Author: Arseniy Krasnov <avkrasnov@salutedevices.com>
7 */
8
9#include <stdio.h>
10#include <stdlib.h>
11#include <string.h>
12#include <sys/mman.h>
13#include <unistd.h>
14#include <poll.h>
15#include <linux/errqueue.h>
16#include <linux/kernel.h>
17#include <errno.h>
18
19#include "control.h"
20#include "vsock_test_zerocopy.h"
21#include "msg_zerocopy_common.h"
22
/* Fallback page size for builds where the headers included above do not
 * define PAGE_SIZE.
 * NOTE(review): 4096 assumes 4 KiB pages; the "page aligned" cases in the
 * test table are only page aligned on such systems — confirm for archs
 * with larger pages.
 */
#ifndef PAGE_SIZE
#define PAGE_SIZE		4096
#endif

/* Maximum number of iovec elements a single test case can describe. */
#define VSOCK_TEST_DATA_MAX_IOV 3

/* One MSG_ZEROCOPY test case: how the socket is set up, which iovec
 * layout is passed to 'sendmsg()', and what outcome is expected.
 *
 * Initializers of this struct may rely on field order (e.g. 'vecs'
 * filled positionally right after '.vecs_cnt'), so do not reorder fields.
 */
struct vsock_test_data {
	/* This test case is for SOCK_STREAM only. */
	bool stream_only;
	/* Data must be zerocopied. This field is checked against
	 * field 'ee_code' of the 'struct sock_extended_err', which
	 * contains a bit indicating that the zerocopy transmission
	 * fell back to copy mode.
	 */
	bool zerocopied;
	/* Enable SO_ZEROCOPY option on the socket. Without enabled
	 * SO_ZEROCOPY, every MSG_ZEROCOPY transmission will behave
	 * like without MSG_ZEROCOPY flag.
	 */
	bool so_zerocopy;
	/* Expected 'errno' after the 'sendmsg()' call; 0 means the call
	 * is expected to succeed.
	 */
	int sendmsg_errno;
	/* Number of valid elements in 'vecs'. */
	int vecs_cnt;
	/* Template iovec: 'iov_len' is the size of each element, while
	 * 'iov_base' is a layout hint (NULL, a small offset or MAP_FAILED).
	 * NOTE(review): the hint is interpreted by alloc_test_iovec() —
	 * confirm its exact meaning there.
	 */
	struct iovec vecs[VSOCK_TEST_DATA_MAX_IOV];
};
49
50static struct vsock_test_data test_data_array[] = {
51	/* Last element has non-page aligned size. */
52	{
53		.zerocopied = true,
54		.so_zerocopy = true,
55		.sendmsg_errno = 0,
56		.vecs_cnt = 3,
57		{
58			{ NULL, PAGE_SIZE },
59			{ NULL, PAGE_SIZE },
60			{ NULL, 200 }
61		}
62	},
63	/* All elements have page aligned base and size. */
64	{
65		.zerocopied = true,
66		.so_zerocopy = true,
67		.sendmsg_errno = 0,
68		.vecs_cnt = 3,
69		{
70			{ NULL, PAGE_SIZE },
71			{ NULL, PAGE_SIZE * 2 },
72			{ NULL, PAGE_SIZE * 3 }
73		}
74	},
75	/* All elements have page aligned base and size. But
76	 * data length is bigger than 64Kb.
77	 */
78	{
79		.zerocopied = true,
80		.so_zerocopy = true,
81		.sendmsg_errno = 0,
82		.vecs_cnt = 3,
83		{
84			{ NULL, PAGE_SIZE * 16 },
85			{ NULL, PAGE_SIZE * 16 },
86			{ NULL, PAGE_SIZE * 16 }
87		}
88	},
89	/* Middle element has both non-page aligned base and size. */
90	{
91		.zerocopied = true,
92		.so_zerocopy = true,
93		.sendmsg_errno = 0,
94		.vecs_cnt = 3,
95		{
96			{ NULL, PAGE_SIZE },
97			{ (void *)1, 100 },
98			{ NULL, PAGE_SIZE }
99		}
100	},
101	/* Middle element is unmapped. */
102	{
103		.zerocopied = false,
104		.so_zerocopy = true,
105		.sendmsg_errno = ENOMEM,
106		.vecs_cnt = 3,
107		{
108			{ NULL, PAGE_SIZE },
109			{ MAP_FAILED, PAGE_SIZE },
110			{ NULL, PAGE_SIZE }
111		}
112	},
113	/* Valid data, but SO_ZEROCOPY is off. This
114	 * will trigger fallback to copy.
115	 */
116	{
117		.zerocopied = false,
118		.so_zerocopy = false,
119		.sendmsg_errno = 0,
120		.vecs_cnt = 1,
121		{
122			{ NULL, PAGE_SIZE }
123		}
124	},
125	/* Valid data, but message is bigger than peer's
126	 * buffer, so this will trigger fallback to copy.
127	 * This test is for SOCK_STREAM only, because
128	 * for SOCK_SEQPACKET, 'sendmsg()' returns EMSGSIZE.
129	 */
130	{
131		.stream_only = true,
132		.zerocopied = false,
133		.so_zerocopy = true,
134		.sendmsg_errno = 0,
135		.vecs_cnt = 1,
136		{
137			{ NULL, 100 * PAGE_SIZE }
138		}
139	},
140};
141
142#define POLL_TIMEOUT_MS		100
143
144static void test_client(const struct test_opts *opts,
145			const struct vsock_test_data *test_data,
146			bool sock_seqpacket)
147{
148	struct pollfd fds = { 0 };
149	struct msghdr msg = { 0 };
150	ssize_t sendmsg_res;
151	struct iovec *iovec;
152	int fd;
153
154	if (sock_seqpacket)
155		fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
156	else
157		fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
158
159	if (fd < 0) {
160		perror("connect");
161		exit(EXIT_FAILURE);
162	}
163
164	if (test_data->so_zerocopy)
165		enable_so_zerocopy(fd);
166
167	iovec = alloc_test_iovec(test_data->vecs, test_data->vecs_cnt);
168
169	msg.msg_iov = iovec;
170	msg.msg_iovlen = test_data->vecs_cnt;
171
172	errno = 0;
173
174	sendmsg_res = sendmsg(fd, &msg, MSG_ZEROCOPY);
175	if (errno != test_data->sendmsg_errno) {
176		fprintf(stderr, "expected 'errno' == %i, got %i\n",
177			test_data->sendmsg_errno, errno);
178		exit(EXIT_FAILURE);
179	}
180
181	if (!errno) {
182		if (sendmsg_res != iovec_bytes(iovec, test_data->vecs_cnt)) {
183			fprintf(stderr, "expected 'sendmsg()' == %li, got %li\n",
184				iovec_bytes(iovec, test_data->vecs_cnt),
185				sendmsg_res);
186			exit(EXIT_FAILURE);
187		}
188	}
189
190	fds.fd = fd;
191	fds.events = 0;
192
193	if (poll(&fds, 1, POLL_TIMEOUT_MS) < 0) {
194		perror("poll");
195		exit(EXIT_FAILURE);
196	}
197
198	if (fds.revents & POLLERR) {
199		vsock_recv_completion(fd, &test_data->zerocopied);
200	} else if (test_data->so_zerocopy && !test_data->sendmsg_errno) {
201		/* If we don't have data in the error queue, but
202		 * SO_ZEROCOPY was enabled and 'sendmsg()' was
203		 * successful - this is an error.
204		 */
205		fprintf(stderr, "POLLERR expected\n");
206		exit(EXIT_FAILURE);
207	}
208
209	if (!test_data->sendmsg_errno)
210		control_writeulong(iovec_hash_djb2(iovec, test_data->vecs_cnt));
211	else
212		control_writeulong(0);
213
214	control_writeln("DONE");
215	free_test_iovec(test_data->vecs, iovec, test_data->vecs_cnt);
216	close(fd);
217}
218
219void test_stream_msgzcopy_client(const struct test_opts *opts)
220{
221	int i;
222
223	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
224		test_client(opts, &test_data_array[i], false);
225}
226
227void test_seqpacket_msgzcopy_client(const struct test_opts *opts)
228{
229	int i;
230
231	for (i = 0; i < ARRAY_SIZE(test_data_array); i++) {
232		if (test_data_array[i].stream_only)
233			continue;
234
235		test_client(opts, &test_data_array[i], true);
236	}
237}
238
239static void test_server(const struct test_opts *opts,
240			const struct vsock_test_data *test_data,
241			bool sock_seqpacket)
242{
243	unsigned long remote_hash;
244	unsigned long local_hash;
245	ssize_t total_bytes_rec;
246	unsigned char *data;
247	size_t data_len;
248	int fd;
249
250	if (sock_seqpacket)
251		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
252	else
253		fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
254
255	if (fd < 0) {
256		perror("accept");
257		exit(EXIT_FAILURE);
258	}
259
260	data_len = iovec_bytes(test_data->vecs, test_data->vecs_cnt);
261
262	data = malloc(data_len);
263	if (!data) {
264		perror("malloc");
265		exit(EXIT_FAILURE);
266	}
267
268	total_bytes_rec = 0;
269
270	while (total_bytes_rec != data_len) {
271		ssize_t bytes_rec;
272
273		bytes_rec = read(fd, data + total_bytes_rec,
274				 data_len - total_bytes_rec);
275		if (bytes_rec <= 0)
276			break;
277
278		total_bytes_rec += bytes_rec;
279	}
280
281	if (test_data->sendmsg_errno == 0)
282		local_hash = hash_djb2(data, data_len);
283	else
284		local_hash = 0;
285
286	free(data);
287
288	/* Waiting for some result. */
289	remote_hash = control_readulong();
290	if (remote_hash != local_hash) {
291		fprintf(stderr, "hash mismatch\n");
292		exit(EXIT_FAILURE);
293	}
294
295	control_expectln("DONE");
296	close(fd);
297}
298
299void test_stream_msgzcopy_server(const struct test_opts *opts)
300{
301	int i;
302
303	for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
304		test_server(opts, &test_data_array[i], false);
305}
306
307void test_seqpacket_msgzcopy_server(const struct test_opts *opts)
308{
309	int i;
310
311	for (i = 0; i < ARRAY_SIZE(test_data_array); i++) {
312		if (test_data_array[i].stream_only)
313			continue;
314
315		test_server(opts, &test_data_array[i], true);
316	}
317}
318
319void test_stream_msgzcopy_empty_errq_client(const struct test_opts *opts)
320{
321	struct msghdr msg = { 0 };
322	char cmsg_data[128];
323	ssize_t res;
324	int fd;
325
326	fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
327	if (fd < 0) {
328		perror("connect");
329		exit(EXIT_FAILURE);
330	}
331
332	msg.msg_control = cmsg_data;
333	msg.msg_controllen = sizeof(cmsg_data);
334
335	res = recvmsg(fd, &msg, MSG_ERRQUEUE);
336	if (res != -1) {
337		fprintf(stderr, "expected 'recvmsg(2)' failure, got %zi\n",
338			res);
339		exit(EXIT_FAILURE);
340	}
341
342	control_writeln("DONE");
343	close(fd);
344}
345
346void test_stream_msgzcopy_empty_errq_server(const struct test_opts *opts)
347{
348	int fd;
349
350	fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
351	if (fd < 0) {
352		perror("accept");
353		exit(EXIT_FAILURE);
354	}
355
356	control_expectln("DONE");
357	close(fd);
358}
359