1// SPDX-License-Identifier: GPL-2.0
2
3#define _GNU_SOURCE
4
5#include <assert.h>
6#include <errno.h>
7#include <fcntl.h>
8#include <limits.h>
9#include <string.h>
10#include <stdarg.h>
11#include <stdbool.h>
12#include <stdint.h>
13#include <inttypes.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <strings.h>
17#include <unistd.h>
18#include <time.h>
19
20#include <sys/ioctl.h>
21#include <sys/random.h>
22#include <sys/socket.h>
23#include <sys/types.h>
24#include <sys/wait.h>
25
26#include <netdb.h>
27#include <netinet/in.h>
28
29#include <linux/tcp.h>
30#include <linux/sockios.h>
31
/*
 * Fallback values for toolchains whose headers do not provide the MPTCP
 * protocol number and socket option level.
 */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif

#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif

/* Address family for both endpoints; parse_opts() sets AF_INET6 on -6. */
static int pf = AF_INET;
/* Protocol handed to socket() by the connecting side (option -t). */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol handed to socket() by the listening side (option -r). */
static int proto_rx = IPPROTO_MPTCP;
42
/*
 * Print @msg followed by the description of the current errno value on
 * stderr, then terminate the process with exit status 1. Does not return.
 */
static void die_perror(const char *msg)
{
	perror(msg);

	exit(1);
}
48
/*
 * Print the command line synopsis on stderr and terminate the process
 * with exit status @r. Does not return.
 */
static void die_usage(int r)
{
	/* constant text, no formatting needed */
	fputs("Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n", stderr);

	exit(r);
}
54
/*
 * Print a printf-style formatted message on stderr, append a newline and
 * terminate the process with exit status 1. Does not return.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	fputs("\n", stderr);
	exit(1);
}
65
/*
 * Translate a getaddrinfo() error code into a human readable string.
 *
 * EAI_SYSTEM means the real cause is in errno, so that one is described
 * via strerror(); every other code goes through gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err != EAI_SYSTEM ? gai_strerror(err) : strerror(errno);
}
73
/*
 * getaddrinfo() wrapper that never returns an error to the caller.
 *
 * On success *@res holds the result list. On failure a "Fatal:" line
 * naming @node and @service (NULL is printed as an empty string) is
 * written to stderr and the process exits with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(rc));
	exit(1);
}
88
/*
 * Create a listening stream socket bound to @listenaddr:@port.
 *
 * @listenaddr must be a numeric address (AI_NUMERICHOST) of family pf.
 * The lookup is done with IPPROTO_TCP hints, but the socket itself is
 * created with proto_rx. Each resolved address is tried in turn until
 * bind() succeeds.
 *
 * Returns the listening descriptor (backlog 20). Exits the process if no
 * address could be bound or listen() fails.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* best effort: a failure here is reported but not fatal */
		if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
			       sizeof(one)) == -1)
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
134
/*
 * Open a stream socket of protocol @proto and connect it to
 * @remoteaddr:@port (address family pf).
 *
 * Resolved addresses for which socket() fails are skipped. The first
 * address that yields a socket must also connect: a connect() failure
 * terminates the process right away.
 *
 * Returns the connected descriptor; exits if no socket could be created.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *results, *cur;
	int fd = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &results);

	cur = results;
	while (cur) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd >= 0) {
			if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0)
				die_perror("connect");
			break; /* connected */
		}
		cur = cur->ai_next;
	}

	if (fd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(results);
	return fd;
}
165
/*
 * Map a protocol name given on the command line to its protocol number.
 *
 * "tcp" and "mptcp" are recognised case-insensitively. Any other string
 * prints the usage text and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} known[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(known) / sizeof(known[0]); idx++) {
		if (strcasecmp(s, known[idx].name) == 0)
			return known[idx].proto;
	}

	die_usage(1);
	return 0;
}
176
177static void parse_opts(int argc, char **argv)
178{
179	int c;
180
181	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
182		switch (c) {
183		case 'h':
184			die_usage(0);
185			break;
186		case '6':
187			pf = AF_INET6;
188			break;
189		case 't':
190			proto_tx = protostr_to_num(optarg);
191			break;
192		case 'r':
193			proto_rx = protostr_to_num(optarg);
194			break;
195		default:
196			die_usage(1);
197			break;
198		}
199	}
200}
201
/*
 * Wait until the transmit queue of @fd has drained.
 *
 * Polls TIOCOUTQ and SIOCOUTQNSD up to @timeout times, sleeping 1 ms
 * between polls, and returns as soon as TIOCOUTQ reports 0.
 *
 * @total is the largest number of bytes that may legitimately still be
 * queued. A larger TIOCOUTQ value, an ioctl failure, or data still queued
 * after the last poll terminates the process.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd = -1, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* queued is an int: print it with %d, not %u */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		/* not-yet-sent bytes are a subset of everything queued */
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
234
/*
 * Client side of the test: send data over @fd in lock-step with the
 * server, which is coordinated through the datagram socket @unixfd.
 *
 * Sequence:
 *  1. wait for "xmit", announce a random length (128..4094 bytes) and
 *     write that many bytes to @fd;
 *  2. wait for "huge", announce a random total (1 MiB up to just under
 *     17 MiB), wait until the first chunk is acked, then write the bulk
 *     data in buffer-sized pieces;
 *  3. wait for "shut", wait until the bulk data is acked, write one final
 *     byte, close @fd and report "closed" on @unixfd.
 *
 * Both descriptors are closed on return. Any failure aborts the process.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	/* i == len here; len <= sizeof(buf) - 2, so this stays in bounds */
	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* check the length like every other handshake read does */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	/* at most the first chunk may still sit in the send queue */
	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		/* short writes are fine, only the byte count matters */
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	/* one trailing byte so the peer sees data followed by hangup */
	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
311
/*
 * Extract the TCP_CM_INQ value from the control data of @msgh.
 *
 * Walks the control message headers, copies the payload of the first
 * IPPROTO_TCP/TCP_CM_INQ entry into *@inqv and returns. Terminates the
 * process if no such entry is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr = CMSG_FIRSTHDR(msgh);

	while (hdr) {
		if (hdr->cmsg_level == IPPROTO_TCP &&
		    hdr->cmsg_type == TCP_CM_INQ) {
			/* memcpy: CMSG_DATA() is not guaranteed to be aligned */
			memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
			return;
		}

		hdr = CMSG_NXTHDR(msgh, hdr);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
325
/*
 * Server side of the test: receive data on @fd and verify the TCP_CM_INQ
 * value delivered with every recvmsg(), in lock-step with the client via
 * the datagram socket @unixfd.
 *
 * Sequence:
 *  1. send "xmit", read the announced length, poll FIONREAD until exactly
 *     that many bytes are pending; read one byte and expect inq to be
 *     length - 1; read the rest and expect inq 0;
 *  2. send "huge", read the announced total, receive until the total is
 *     reached while checking that inq never exceeds what is still due;
 *  3. send "shut", wait for "closed", then read the trailing byte (inq 1)
 *     and the EOF (return 0, inq 1).
 *
 * @fd is closed on return. Any mismatch aborts the process.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* wait until the whole first chunk is pending in the receive queue */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	/*
	 * recvmsg() overwrites msg_controllen with the length actually used,
	 * so the full buffer size is restored before every further call.
	 */
	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* specifiers match the argument types: unsigned, size_t, size_t */
		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
458
/*
 * accept() wrapper: returns the descriptor of the next connection on the
 * listening socket @s, or terminates the process if accept() fails. The
 * peer address is not requested.
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
468
469static int server(int unixfd)
470{
471	int fd = -1, r, on = 1;
472
473	switch (pf) {
474	case AF_INET:
475		fd = sock_listen_mptcp("127.0.0.1", "15432");
476		break;
477	case AF_INET6:
478		fd = sock_listen_mptcp("::1", "15432");
479		break;
480	default:
481		xerror("Unknown pf %d\n", pf);
482		break;
483	}
484
485	r = write(unixfd, "conn", 4);
486	assert(r == 4);
487
488	alarm(15);
489	r = xaccept(fd);
490
491	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
492		die_perror("setsockopt");
493
494	process_one_client(r, unixfd);
495
496	return 0;
497}
498
499static int client(int unixfd)
500{
501	int fd = -1;
502
503	alarm(15);
504
505	switch (pf) {
506	case AF_INET:
507		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
508		break;
509	case AF_INET6:
510		fd = sock_connect_mptcp("::1", "15432", proto_tx);
511		break;
512	default:
513		xerror("Unknown pf %d\n", pf);
514	}
515
516	connect_one_server(fd, unixfd);
517
518	return 0;
519}
520
/*
 * Seed rand() with an unsigned int obtained from getrandom().
 * Terminates the process with status 1 if getrandom() fails.
 */
static void init_rng(void)
{
	unsigned int seed;

	if (getrandom(&seed, sizeof(seed), 0) != -1) {
		srand(seed);
		return;
	}

	perror("getrandom");
	exit(1);
}
532
/*
 * fork() wrapper: terminates the process if fork() fails and re-seeds the
 * random number generator in the child.
 *
 * Returns 0 in the child and the child's pid in the parent.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid == 0) {
		/* child: pick its own seed */
		init_rng();
		return pid;
	}

	if (pid < 0)
		die_perror("fork");

	return pid;
}
544
/*
 * Evaluate a waitpid() status for the child described by @what.
 *
 * Normal exit: returns the child's exit status, logging it to stderr when
 * it is non-zero. A child killed or stopped by a signal terminates the
 * calling process through xerror(). Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
560
/*
 * Test driver.
 *
 * Parses the options, creates an AF_UNIX datagram socket pair for
 * coordination, forks the server, waits for its "conn" notification,
 * forks the client and finally reaps both children.
 *
 * Returns the server's result if it is non-zero, otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char token[4];
	ssize_t n;

	parse_opts(argc, argv);

	/* name the call that actually failed */
	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], token, sizeof(token));
	assert(n == (ssize_t)sizeof(token));
	assert(strncmp(token, "conn", sizeof(token)) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
600