1// SPDX-License-Identifier: GPL-2.0 OR MIT
2#define _GNU_SOURCE
3#include <error.h>
4#include <limits.h>
5#include <stddef.h>
6#include <stdio.h>
7#include <stdlib.h>
8#include <sys/socket.h>
9#include <linux/socket.h>
10#include <unistd.h>
11#include <string.h>
12#include <errno.h>
13#include <sys/un.h>
14#include <sys/signal.h>
15#include <sys/types.h>
16#include <sys/wait.h>
17
18#include "../../kselftest_harness.h"
19
/* Text for the current errno value, or "None" when errno is 0. */
#define clean_errno() (errno != 0 ? strerror(errno) : "None")

/*
 * Print a diagnostic line on stderr. The line is prefixed with the source
 * file, the line number and the current errno text; MSG is a printf-style
 * format literal and a trailing newline is appended to it.
 */
#define log_err(MSG, ...)                                      \
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n",        \
		__FILE__, __LINE__, clean_errno(), ##__VA_ARGS__)
24
/* Fallback for socket headers that do not define SCM_PIDFD themselves. */
#if !defined(SCM_PIDFD)
#define SCM_PIDFD 0x04
#endif
28
/*
 * Terminate the calling process with exit status 1.
 *
 * Declared with an explicit (void) parameter list so the definition is a
 * real prototype and stray arguments are rejected at compile time.
 */
static void child_die(void)
{
	exit(1);
}
33
/*
 * Convert the string numstr to an int using strtol() with base 0 (so the
 * usual 0x / 0 prefixes are honoured) and store it in *converted.
 *
 * Returns 0 on success, -ERANGE when the value does not fit in a long or
 * an int, and -EINVAL when no digits were found or trailing characters
 * remain. *converted is written only on success.
 */
static int safe_int(const char *numstr, int *converted)
{
	char *end = NULL;
	long value;

	errno = 0;
	value = strtol(numstr, &end, 0);

	if (errno != 0) {
		if (errno == ERANGE &&
		    (value == LONG_MAX || value == LONG_MIN))
			return -ERANGE;

		if (value == 0)
			return -EINVAL;
	}

	/* nothing was parsed, or the number is followed by other text */
	if (end == numstr || *end != '\0')
		return -EINVAL;

	if (value < INT_MIN || value > INT_MAX)
		return -ERANGE;

	*converted = (int)value;
	return 0;
}
56
/*
 * Return the offset of the first character in buffer[0..len) that is
 * neither a space nor a tab. Returns 0 when every character is blank or
 * when len is 0.
 */
static int char_left_gc(const char *buffer, size_t len)
{
	size_t pos = 0;

	while (pos < len && (buffer[pos] == ' ' || buffer[pos] == '\t'))
		pos++;

	/* ran off the end: nothing but blanks, report offset 0 */
	if (pos == len)
		return 0;

	return pos;
}
70
/*
 * Return the length of buffer[0..len) once trailing spaces, tabs, newlines
 * and NUL bytes are dropped, i.e. one past the index of the last character
 * that is none of those. Returns 0 when no such character exists.
 */
static int char_right_gc(const char *buffer, size_t len)
{
	int end = len;

	while (end > 0) {
		char c = buffer[end - 1];

		if (c != ' ' && c != '\t' && c != '\n' && c != '\0')
			break;

		end--;
	}

	return end;
}
85
/*
 * Strip leading blanks and trailing whitespace from the NUL-terminated
 * string in buffer. The string is modified in place (a terminator is
 * written after the last kept character) and the returned pointer refers
 * to the first kept character inside buffer.
 */
static char *trim_whitespace_in_place(char *buffer)
{
	char *start = buffer + char_left_gc(buffer, strlen(buffer));
	int kept = char_right_gc(start, strlen(start));

	start[kept] = '\0';
	return start;
}
92
/*
 * borrowed (with all helpers) from pidfd/pidfd_open_test.c
 *
 * Look up an integer field in /proc/self/fdinfo/<pidfd>.
 *
 * @pidfd:  file descriptor whose fdinfo file is read
 * @key:    field name the wanted line starts with, e.g. "Pid:"
 * @keylen: number of characters of @key to match; must not exceed
 *          strlen(key)
 *
 * Returns the value parsed from the first line that starts with @key, or
 * -1 when the fdinfo file cannot be opened, no line matches, or the value
 * is not a valid int.
 */
static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
{
	char path[512];
	FILE *f;
	size_t n = 0;
	pid_t result = -1;
	char *line = NULL;

	snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);

	f = fopen(path, "re");
	if (!f)
		return -1;

	while (getline(&line, &n, f) != -1) {
		char *numstr;

		if (strncmp(line, key, keylen))
			continue;

		/*
		 * The value starts right after the key. Skip exactly keylen
		 * characters rather than a fixed 4 so keys of any length work.
		 * safe_int() leaves result at -1 when parsing fails.
		 */
		numstr = trim_whitespace_in_place(line + keylen);
		safe_int(numstr, &result);
		break;
	}

	/* getline() may have allocated line even if nothing matched */
	free(line);
	fclose(f);
	return result;
}
128
129static int cmsg_check(int fd)
130{
131	struct msghdr msg = { 0 };
132	struct cmsghdr *cmsg;
133	struct iovec iov;
134	struct ucred *ucred = NULL;
135	int data = 0;
136	char control[CMSG_SPACE(sizeof(struct ucred)) +
137		     CMSG_SPACE(sizeof(int))] = { 0 };
138	int *pidfd = NULL;
139	pid_t parent_pid;
140	int err;
141
142	iov.iov_base = &data;
143	iov.iov_len = sizeof(data);
144
145	msg.msg_iov = &iov;
146	msg.msg_iovlen = 1;
147	msg.msg_control = control;
148	msg.msg_controllen = sizeof(control);
149
150	err = recvmsg(fd, &msg, 0);
151	if (err < 0) {
152		log_err("recvmsg");
153		return 1;
154	}
155
156	if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
157		log_err("recvmsg: truncated");
158		return 1;
159	}
160
161	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
162	     cmsg = CMSG_NXTHDR(&msg, cmsg)) {
163		if (cmsg->cmsg_level == SOL_SOCKET &&
164		    cmsg->cmsg_type == SCM_PIDFD) {
165			if (cmsg->cmsg_len < sizeof(*pidfd)) {
166				log_err("CMSG parse: SCM_PIDFD wrong len");
167				return 1;
168			}
169
170			pidfd = (void *)CMSG_DATA(cmsg);
171		}
172
173		if (cmsg->cmsg_level == SOL_SOCKET &&
174		    cmsg->cmsg_type == SCM_CREDENTIALS) {
175			if (cmsg->cmsg_len < sizeof(*ucred)) {
176				log_err("CMSG parse: SCM_CREDENTIALS wrong len");
177				return 1;
178			}
179
180			ucred = (void *)CMSG_DATA(cmsg);
181		}
182	}
183
184	/* send(pfd, "x", sizeof(char), 0) */
185	if (data != 'x') {
186		log_err("recvmsg: data corruption");
187		return 1;
188	}
189
190	if (!pidfd) {
191		log_err("CMSG parse: SCM_PIDFD not found");
192		return 1;
193	}
194
195	if (!ucred) {
196		log_err("CMSG parse: SCM_CREDENTIALS not found");
197		return 1;
198	}
199
200	/* pidfd from SCM_PIDFD should point to the parent process PID */
201	parent_pid =
202		get_pid_from_fdinfo_file(*pidfd, "Pid:", sizeof("Pid:") - 1);
203	if (parent_pid != getppid()) {
204		log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
205		return 1;
206	}
207
208	return 0;
209}
210
/* A UNIX socket address together with the name it was built from. */
struct sock_addr {
	/* socket name as text; also the path to unlink for pathname sockets */
	char sock_name[32];
	/* address structure handed to bind(), connect() and sendto() */
	struct sockaddr_un listen_addr;
	/* number of bytes of listen_addr that are in use */
	socklen_t addrlen;
};
216
217FIXTURE(scm_pidfd)
218{
219	int server;
220	pid_t client_pid;
221	int startup_pipe[2];
222	struct sock_addr server_addr;
223	struct sock_addr *client_addr;
224};
225
226FIXTURE_VARIANT(scm_pidfd)
227{
228	int type;
229	bool abstract;
230};
231
232FIXTURE_VARIANT_ADD(scm_pidfd, stream_pathname)
233{
234	.type = SOCK_STREAM,
235	.abstract = 0,
236};
237
238FIXTURE_VARIANT_ADD(scm_pidfd, stream_abstract)
239{
240	.type = SOCK_STREAM,
241	.abstract = 1,
242};
243
244FIXTURE_VARIANT_ADD(scm_pidfd, dgram_pathname)
245{
246	.type = SOCK_DGRAM,
247	.abstract = 0,
248};
249
250FIXTURE_VARIANT_ADD(scm_pidfd, dgram_abstract)
251{
252	.type = SOCK_DGRAM,
253	.abstract = 1,
254};
255
256FIXTURE_SETUP(scm_pidfd)
257{
258	self->client_addr = mmap(NULL, sizeof(*self->client_addr), PROT_READ | PROT_WRITE,
259				 MAP_SHARED | MAP_ANONYMOUS, -1, 0);
260	ASSERT_NE(MAP_FAILED, self->client_addr);
261}
262
263FIXTURE_TEARDOWN(scm_pidfd)
264{
265	close(self->server);
266
267	kill(self->client_pid, SIGKILL);
268	waitpid(self->client_pid, NULL, 0);
269
270	if (!variant->abstract) {
271		unlink(self->server_addr.sock_name);
272		unlink(self->client_addr->sock_name);
273	}
274}
275
276static void fill_sockaddr(struct sock_addr *addr, bool abstract)
277{
278	char *sun_path_buf = (char *)&addr->listen_addr.sun_path;
279
280	addr->listen_addr.sun_family = AF_UNIX;
281	addr->addrlen = offsetof(struct sockaddr_un, sun_path);
282	snprintf(addr->sock_name, sizeof(addr->sock_name), "scm_pidfd_%d", getpid());
283	addr->addrlen += strlen(addr->sock_name);
284	if (abstract) {
285		*sun_path_buf = '\0';
286		addr->addrlen++;
287		sun_path_buf++;
288	} else {
289		unlink(addr->sock_name);
290	}
291	memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
292}
293
294static void client(FIXTURE_DATA(scm_pidfd) *self,
295		   const FIXTURE_VARIANT(scm_pidfd) *variant)
296{
297	int cfd;
298	socklen_t len;
299	struct ucred peer_cred;
300	int peer_pidfd;
301	pid_t peer_pid;
302	int on = 0;
303
304	cfd = socket(AF_UNIX, variant->type, 0);
305	if (cfd < 0) {
306		log_err("socket");
307		child_die();
308	}
309
310	if (variant->type == SOCK_DGRAM) {
311		fill_sockaddr(self->client_addr, variant->abstract);
312
313		if (bind(cfd, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen)) {
314			log_err("bind");
315			child_die();
316		}
317	}
318
319	if (connect(cfd, (struct sockaddr *)&self->server_addr.listen_addr,
320		    self->server_addr.addrlen) != 0) {
321		log_err("connect");
322		child_die();
323	}
324
325	on = 1;
326	if (setsockopt(cfd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
327		log_err("Failed to set SO_PASSCRED");
328		child_die();
329	}
330
331	if (setsockopt(cfd, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
332		log_err("Failed to set SO_PASSPIDFD");
333		child_die();
334	}
335
336	close(self->startup_pipe[1]);
337
338	if (cmsg_check(cfd)) {
339		log_err("cmsg_check failed");
340		child_die();
341	}
342
343	/* skip further for SOCK_DGRAM as it's not applicable */
344	if (variant->type == SOCK_DGRAM)
345		return;
346
347	len = sizeof(peer_cred);
348	if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &peer_cred, &len)) {
349		log_err("Failed to get SO_PEERCRED");
350		child_die();
351	}
352
353	len = sizeof(peer_pidfd);
354	if (getsockopt(cfd, SOL_SOCKET, SO_PEERPIDFD, &peer_pidfd, &len)) {
355		log_err("Failed to get SO_PEERPIDFD");
356		child_die();
357	}
358
359	/* pid from SO_PEERCRED should point to the parent process PID */
360	if (peer_cred.pid != getppid()) {
361		log_err("peer_cred.pid != getppid(): %d != %d", peer_cred.pid, getppid());
362		child_die();
363	}
364
365	peer_pid = get_pid_from_fdinfo_file(peer_pidfd,
366					    "Pid:", sizeof("Pid:") - 1);
367	if (peer_pid != peer_cred.pid) {
368		log_err("peer_pid != peer_cred.pid: %d != %d", peer_pid, peer_cred.pid);
369		child_die();
370	}
371}
372
373TEST_F(scm_pidfd, test)
374{
375	int err;
376	int pfd;
377	int child_status = 0;
378
379	self->server = socket(AF_UNIX, variant->type, 0);
380	ASSERT_NE(-1, self->server);
381
382	fill_sockaddr(&self->server_addr, variant->abstract);
383
384	err = bind(self->server, (struct sockaddr *)&self->server_addr.listen_addr, self->server_addr.addrlen);
385	ASSERT_EQ(0, err);
386
387	if (variant->type == SOCK_STREAM) {
388		err = listen(self->server, 1);
389		ASSERT_EQ(0, err);
390	}
391
392	err = pipe(self->startup_pipe);
393	ASSERT_NE(-1, err);
394
395	self->client_pid = fork();
396	ASSERT_NE(-1, self->client_pid);
397	if (self->client_pid == 0) {
398		close(self->server);
399		close(self->startup_pipe[0]);
400		client(self, variant);
401		exit(0);
402	}
403	close(self->startup_pipe[1]);
404
405	if (variant->type == SOCK_STREAM) {
406		pfd = accept(self->server, NULL, NULL);
407		ASSERT_NE(-1, pfd);
408	} else {
409		pfd = self->server;
410	}
411
412	/* wait until the child arrives at checkpoint */
413	read(self->startup_pipe[0], &err, sizeof(int));
414	close(self->startup_pipe[0]);
415
416	if (variant->type == SOCK_DGRAM) {
417		err = sendto(pfd, "x", sizeof(char), 0, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen);
418		ASSERT_NE(-1, err);
419	} else {
420		err = send(pfd, "x", sizeof(char), 0);
421		ASSERT_NE(-1, err);
422	}
423
424	close(pfd);
425	waitpid(self->client_pid, &child_status, 0);
426	ASSERT_EQ(0, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
427}
428
429TEST_HARNESS_MAIN
430