1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright 2018 Google Inc.
4 * Author: Soheil Hassas Yeganeh (soheil@google.com)
5 *
 * Simple example of how to use TCP_INQ and TCP_CM_INQ.
7 */
8#define _GNU_SOURCE
9
10#include <error.h>
11#include <netinet/in.h>
12#include <netinet/tcp.h>
13#include <pthread.h>
14#include <stdio.h>
15#include <errno.h>
16#include <stdlib.h>
17#include <string.h>
18#include <sys/socket.h>
19#include <unistd.h>
20
/* Fallback values for system headers that do not define them. The
 * control-message type returned by recvmsg() reuses the sockopt value.
 */
#ifndef TCP_INQ
#define TCP_INQ 36
#endif

#ifndef TCP_CM_INQ
#define TCP_CM_INQ TCP_INQ
#endif

enum {
	BUF_SIZE = 8192,	/* bytes the server sends per connection */
	CMSG_SIZE = 32,		/* capacity of the recvmsg() control buffer */
};

/* Defaults; main() overrides them from the -4, -6 and -p options. */
static int family = AF_INET6;
static socklen_t addr_len = sizeof(struct sockaddr_in6);
static int port = 4974;
35
36static void setup_loopback_addr(int family, struct sockaddr_storage *sockaddr)
37{
38	struct sockaddr_in6 *addr6 = (void *) sockaddr;
39	struct sockaddr_in *addr4 = (void *) sockaddr;
40
41	switch (family) {
42	case PF_INET:
43		memset(addr4, 0, sizeof(*addr4));
44		addr4->sin_family = AF_INET;
45		addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
46		addr4->sin_port = htons(port);
47		break;
48	case PF_INET6:
49		memset(addr6, 0, sizeof(*addr6));
50		addr6->sin6_family = AF_INET6;
51		addr6->sin6_addr = in6addr_loopback;
52		addr6->sin6_port = htons(port);
53		break;
54	default:
55		error(1, 0, "illegal family");
56	}
57}
58
59void *start_server(void *arg)
60{
61	int server_fd = (int)(unsigned long)arg;
62	struct sockaddr_in addr;
63	socklen_t addrlen = sizeof(addr);
64	char *buf;
65	int fd;
66	int r;
67
68	buf = malloc(BUF_SIZE);
69
70	for (;;) {
71		fd = accept(server_fd, (struct sockaddr *)&addr, &addrlen);
72		if (fd == -1) {
73			perror("accept");
74			break;
75		}
76		do {
77			r = send(fd, buf, BUF_SIZE, 0);
78		} while (r < 0 && errno == EINTR);
79		if (r < 0)
80			perror("send");
81		if (r != BUF_SIZE)
82			fprintf(stderr, "can only send %d bytes\n", r);
83		/* TCP_INQ can overestimate in-queue by one byte if we send
84		 * the FIN packet. Sleep for 1 second, so that the client
85		 * likely invoked recvmsg().
86		 */
87		sleep(1);
88		close(fd);
89	}
90
91	free(buf);
92	close(server_fd);
93	pthread_exit(0);
94}
95
96int main(int argc, char *argv[])
97{
98	struct sockaddr_storage listen_addr, addr;
99	int c, one = 1, inq = -1;
100	pthread_t server_thread;
101	char cmsgbuf[CMSG_SIZE];
102	struct iovec iov[1];
103	struct cmsghdr *cm;
104	struct msghdr msg;
105	int server_fd, fd;
106	char *buf;
107
108	while ((c = getopt(argc, argv, "46p:")) != -1) {
109		switch (c) {
110		case '4':
111			family = PF_INET;
112			addr_len = sizeof(struct sockaddr_in);
113			break;
114		case '6':
115			family = PF_INET6;
116			addr_len = sizeof(struct sockaddr_in6);
117			break;
118		case 'p':
119			port = atoi(optarg);
120			break;
121		}
122	}
123
124	server_fd = socket(family, SOCK_STREAM, 0);
125	if (server_fd < 0)
126		error(1, errno, "server socket");
127	setup_loopback_addr(family, &listen_addr);
128	if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR,
129		       &one, sizeof(one)) != 0)
130		error(1, errno, "setsockopt(SO_REUSEADDR)");
131	if (bind(server_fd, (const struct sockaddr *)&listen_addr,
132		 addr_len) == -1)
133		error(1, errno, "bind");
134	if (listen(server_fd, 128) == -1)
135		error(1, errno, "listen");
136	if (pthread_create(&server_thread, NULL, start_server,
137			   (void *)(unsigned long)server_fd) != 0)
138		error(1, errno, "pthread_create");
139
140	fd = socket(family, SOCK_STREAM, 0);
141	if (fd < 0)
142		error(1, errno, "client socket");
143	setup_loopback_addr(family, &addr);
144	if (connect(fd, (const struct sockaddr *)&addr, addr_len) == -1)
145		error(1, errno, "connect");
146	if (setsockopt(fd, SOL_TCP, TCP_INQ, &one, sizeof(one)) != 0)
147		error(1, errno, "setsockopt(TCP_INQ)");
148
149	msg.msg_name = NULL;
150	msg.msg_namelen = 0;
151	msg.msg_iov = iov;
152	msg.msg_iovlen = 1;
153	msg.msg_control = cmsgbuf;
154	msg.msg_controllen = sizeof(cmsgbuf);
155	msg.msg_flags = 0;
156
157	buf = malloc(BUF_SIZE);
158	iov[0].iov_base = buf;
159	iov[0].iov_len = BUF_SIZE / 2;
160
161	if (recvmsg(fd, &msg, 0) != iov[0].iov_len)
162		error(1, errno, "recvmsg");
163	if (msg.msg_flags & MSG_CTRUNC)
164		error(1, 0, "control message is truncated");
165
166	for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm))
167		if (cm->cmsg_level == SOL_TCP && cm->cmsg_type == TCP_CM_INQ)
168			inq = *((int *) CMSG_DATA(cm));
169
170	if (inq != BUF_SIZE - iov[0].iov_len) {
171		fprintf(stderr, "unexpected inq: %d\n", inq);
172		exit(1);
173	}
174
175	printf("PASSED\n");
176	free(buf);
177	close(fd);
178	return 0;
179}
180