1// SPDX-License-Identifier: GPL-2.0-only
2/* Control socket for client/server test execution
3 *
4 * Copyright (C) 2017 Red Hat, Inc.
5 *
6 * Author: Stefan Hajnoczi <stefanha@redhat.com>
7 */
8
9/* The client and server may need to coordinate to avoid race conditions like
10 * the client attempting to connect to a socket that the server is not
11 * listening on yet.  The control socket offers a communications channel for
12 * such coordination tasks.
13 *
14 * If the client calls control_expectln("LISTENING"), then it will block until
15 * the server calls control_writeln("LISTENING").  This provides a simple
16 * mechanism for coordinating between the client and the server.
17 */
18
19#include <errno.h>
20#include <netdb.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24#include <unistd.h>
25#include <sys/types.h>
26#include <sys/socket.h>
27
28#include "timeout.h"
29#include "control.h"
30
/* File descriptor of the connected control socket.  It is -1 until
 * control_init() establishes a connection and is reset to -1 again by
 * control_cleanup().
 */
static int control_fd = -1;
32
/* Open the control socket, either in server or client mode.
 *
 * control_host and control_port are resolved with getaddrinfo() and each
 * returned address is tried in turn until one works:
 *
 *  - client mode (server == false): connect() to the address and keep the
 *    connected socket as the control socket.
 *  - server mode (server == true): bind() and listen() on the address, then
 *    accept() a single connection, which becomes the control socket.  The
 *    listening socket is closed once the connection has been accepted.
 *
 * The program terminates if name resolution fails, if setsockopt() or
 * accept() fail, or if no address yields a usable control socket.
 */
void control_init(const char *control_host,
		  const char *control_port,
		  bool server)
{
	struct addrinfo hints = {
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *addrs = NULL;
	struct addrinfo *cur;
	int err;

	err = getaddrinfo(control_host, control_port, &hints, &addrs);
	if (err) {
		fprintf(stderr, "%s\n", gai_strerror(err));
		exit(EXIT_FAILURE);
	}

	for (cur = addrs; cur != NULL; cur = cur->ai_next) {
		int reuse = 1;
		int sock;

		sock = socket(cur->ai_family, cur->ai_socktype,
			      cur->ai_protocol);
		if (sock < 0)
			continue;

		if (server) {
			if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR,
				       &reuse, sizeof(reuse)) < 0) {
				perror("setsockopt");
				exit(EXIT_FAILURE);
			}

			/* This address is unusable, try the next one */
			if (bind(sock, cur->ai_addr, cur->ai_addrlen) < 0 ||
			    listen(sock, 1) < 0) {
				close(sock);
				continue;
			}

			printf("Control socket listening on %s:%s\n",
			       control_host, control_port);
			fflush(stdout);

			/* Only one peer is expected, so the listening socket
			 * is no longer needed after accept() returns.
			 */
			control_fd = accept(sock, NULL, NULL);
			close(sock);

			if (control_fd < 0) {
				perror("accept");
				exit(EXIT_FAILURE);
			}
			printf("Control socket connection accepted...\n");
			break;
		}

		/* Client mode */
		if (connect(sock, cur->ai_addr, cur->ai_addrlen) < 0) {
			close(sock);
			continue;
		}

		control_fd = sock;
		printf("Control socket connected to %s:%s.\n",
		       control_host, control_port);
		break;
	}

	if (control_fd < 0) {
		fprintf(stderr, "Control socket initialization failed.  Invalid address %s:%s?\n",
			control_host, control_port);
		exit(EXIT_FAILURE);
	}

	freeaddrinfo(addrs);
}
105
106/* Free resources */
107void control_cleanup(void)
108{
109	close(control_fd);
110	control_fd = -1;
111}
112
113/* Write a line to the control socket */
114void control_writeln(const char *str)
115{
116	ssize_t len = strlen(str);
117	ssize_t ret;
118
119	timeout_begin(TIMEOUT);
120
121	do {
122		ret = send(control_fd, str, len, MSG_MORE);
123		timeout_check("send");
124	} while (ret < 0 && errno == EINTR);
125
126	if (ret != len) {
127		perror("send");
128		exit(EXIT_FAILURE);
129	}
130
131	do {
132		ret = send(control_fd, "\n", 1, 0);
133		timeout_check("send");
134	} while (ret < 0 && errno == EINTR);
135
136	if (ret != 1) {
137		perror("send");
138		exit(EXIT_FAILURE);
139	}
140
141	timeout_end();
142}
143
/* Write an unsigned long to the control socket as a decimal line.
 *
 * The value is formatted with "%lu" and sent using control_writeln().  The
 * program terminates if the value cannot be formatted into the local buffer.
 */
void control_writeulong(unsigned long value)
{
	char digits[32];
	int written;

	written = snprintf(digits, sizeof(digits), "%lu", value);

	/* A negative result is an error, a too large one means truncation */
	if (written < 0 || (size_t)written >= sizeof(digits)) {
		perror("snprintf");
		exit(EXIT_FAILURE);
	}

	control_writeln(digits);
}
155
/* Read a line from the control socket and parse it as a decimal
 * unsigned long.
 *
 * The whole line must be a valid number that fits in an unsigned long.
 * The program terminates with a diagnostic if the line is empty, has
 * trailing characters after the number, or is out of range, rather than
 * silently returning a bogus value.
 */
unsigned long control_readulong(void)
{
	unsigned long value;
	char *end;
	char *str;

	str = control_readln();

	if (!str)
		exit(EXIT_FAILURE);

	/* strtoul() reports overflow only through errno, so clear it first */
	errno = 0;
	value = strtoul(str, &end, 10);
	if (end == str || *end != '\0' || errno == ERANGE) {
		fprintf(stderr, "invalid unsigned long \"%s\" on control socket\n",
			str);
		exit(EXIT_FAILURE);
	}

	free(str);

	return value;
}
171
172/* Return the next line from the control socket (without the trailing newline).
173 *
174 * The program terminates if a timeout occurs.
175 *
176 * The caller must free() the returned string.
177 */
178char *control_readln(void)
179{
180	char *buf = NULL;
181	size_t idx = 0;
182	size_t buflen = 0;
183
184	timeout_begin(TIMEOUT);
185
186	for (;;) {
187		ssize_t ret;
188
189		if (idx >= buflen) {
190			char *new_buf;
191
192			new_buf = realloc(buf, buflen + 80);
193			if (!new_buf) {
194				perror("realloc");
195				exit(EXIT_FAILURE);
196			}
197
198			buf = new_buf;
199			buflen += 80;
200		}
201
202		do {
203			ret = recv(control_fd, &buf[idx], 1, 0);
204			timeout_check("recv");
205		} while (ret < 0 && errno == EINTR);
206
207		if (ret == 0) {
208			fprintf(stderr, "unexpected EOF on control socket\n");
209			exit(EXIT_FAILURE);
210		}
211
212		if (ret != 1) {
213			perror("recv");
214			exit(EXIT_FAILURE);
215		}
216
217		if (buf[idx] == '\n') {
218			buf[idx] = '\0';
219			break;
220		}
221
222		idx++;
223	}
224
225	timeout_end();
226
227	return buf;
228}
229
/* Read the next line from the control socket and require it to equal str.
 *
 * Blocks in control_readln() until a line is received or a timeout occurs.
 * The program terminates if the received line differs from str.
 */
void control_expectln(const char *str)
{
	char *received = control_readln();

	/* fail == true: a mismatch does not return */
	(void)control_cmpln(received, str, true);

	free(received);
}
241
/* Compare a received line against the expected string.
 *
 * Returns true if line and str are equal.  On a mismatch, returns false when
 * fail is false; when fail is true, prints both strings to stderr and
 * terminates the program instead.
 */
bool control_cmpln(char *line, const char *str, bool fail)
{
	bool matched = !strcmp(str, line);

	if (!matched && fail) {
		fprintf(stderr, "expected \"%s\" on control socket, got \"%s\"\n",
			str, line);
		exit(EXIT_FAILURE);
	}

	return matched;
}
255