1// SPDX-License-Identifier: GPL-2.0
2#define _GNU_SOURCE
3#include <errno.h>
4#include <fcntl.h>
5#include <limits.h>
6#include <sched.h>
7#include <stdarg.h>
8#include <stdbool.h>
9#include <stdio.h>
10#include <stdlib.h>
11#include <string.h>
12#include <sys/mount.h>
13#include <sys/stat.h>
14#include <sys/types.h>
15#include <sys/vfs.h>
16#include <unistd.h>
17
/*
 * Fallback values, used only when the included system headers do not
 * already define these names (presumably older libc headers — not
 * verified here).
 */

/* mount(2) flag; passed to the remount in main(). */
#ifndef MS_NOSYMFOLLOW
# define MS_NOSYMFOLLOW 256     /* Do not follow symlinks */
#endif

/* statfs(2) f_flags bit; checked by test_statfs(). */
#ifndef ST_NOSYMFOLLOW
# define ST_NOSYMFOLLOW 0x2000  /* Do not follow symlinks */
#endif

/*
 * Fixture paths. main() mounts a ramfs over TMP before these are created,
 * so DATA and LINK live on that ramfs, not on the pre-existing /tmp.
 */
#define DATA "/tmp/data"
#define LINK "/tmp/symlink"
#define TMP  "/tmp"
29
/*
 * Report a fatal test failure: format the printf-style message to stderr
 * and terminate the process with EXIT_FAILURE. Does not return.
 */
static void die(char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	(void)vfprintf(stderr, fmt, args);
	va_end(args);

	exit(EXIT_FAILURE);
}
39
/*
 * Format a message and write it to @filename with a single write(2).
 *
 * @enoent_ok: when true and @filename cannot be opened because it does not
 *             exist (ENOENT), return silently instead of failing.
 * @filename:  file to open O_WRONLY (it is not created).
 * @fmt, @ap:  vsnprintf-style format and arguments; the result must fit in
 *             a 4096-byte buffer.
 *
 * Every other failure (formatting, truncation, open, failed or short write,
 * close) terminates the test via die().
 */
static void vmaybe_write_file(bool enoent_ok, char *filename, char *fmt,
		va_list ap)
{
	char msg[4096];
	int len, fd;
	ssize_t nwritten;

	len = vsnprintf(msg, sizeof(msg), fmt, ap);
	if (len < 0)
		die("vsnprintf failed: %s\n", strerror(errno));
	if (len >= sizeof(msg))
		die("vsnprintf output truncated\n");

	fd = open(filename, O_WRONLY);
	if (fd < 0 && enoent_ok && errno == ENOENT)
		return;
	if (fd < 0)
		die("open of %s failed: %s\n", filename, strerror(errno));

	nwritten = write(fd, msg, len);
	if (nwritten < 0)
		die("write to %s failed: %s\n", filename, strerror(errno));
	if (nwritten != len)
		die("short write to %s\n", filename);

	if (close(fd) != 0)
		die("close of %s failed: %s\n", filename, strerror(errno));
}
75
/*
 * Best-effort variant of write_file(): if @filename does not exist the
 * call is a silent no-op; any other error is fatal.
 */
static void maybe_write_file(char *filename, char *fmt, ...)
{
	const bool missing_is_ok = true;
	va_list args;

	va_start(args, fmt);
	vmaybe_write_file(missing_is_ok, filename, fmt, args);
	va_end(args);
}
84
/*
 * Write a printf-style formatted message to the existing file @filename.
 * Every error, including a missing file, is fatal.
 */
static void write_file(char *filename, char *fmt, ...)
{
	const bool missing_is_ok = false;
	va_list args;

	va_start(args, fmt);
	vmaybe_write_file(missing_is_ok, filename, fmt, args);
	va_end(args);
}
93
94static void create_and_enter_ns(void)
95{
96	uid_t uid = getuid();
97	gid_t gid = getgid();
98
99	if (unshare(CLONE_NEWUSER) != 0)
100		die("unshare(CLONE_NEWUSER) failed: %s\n", strerror(errno));
101
102	maybe_write_file("/proc/self/setgroups", "deny");
103	write_file("/proc/self/uid_map", "0 %d 1", uid);
104	write_file("/proc/self/gid_map", "0 %d 1", gid);
105
106	if (setgid(0) != 0)
107		die("setgid(0) failed %s\n", strerror(errno));
108	if (setuid(0) != 0)
109		die("setuid(0) failed %s\n", strerror(errno));
110
111	if (unshare(CLONE_NEWNS) != 0)
112		die("unshare(CLONE_NEWNS) failed: %s\n", strerror(errno));
113}
114
115static void setup_symlink(void)
116{
117	int data, err;
118
119	data = creat(DATA, O_RDWR);
120	if (data < 0)
121		die("creat failed: %s\n", strerror(errno));
122
123	err = symlink(DATA, LINK);
124	if (err < 0)
125		die("symlink failed: %s\n", strerror(errno));
126
127	if (close(data) != 0)
128		die("close of %s failed: %s\n", DATA, strerror(errno));
129}
130
131static void test_link_traversal(bool nosymfollow)
132{
133	int link;
134
135	link = open(LINK, 0, O_RDWR);
136	if (nosymfollow) {
137		if ((link != -1 || errno != ELOOP)) {
138			die("link traversal unexpected result: %d, %s\n",
139					link, strerror(errno));
140		}
141	} else {
142		if (link < 0)
143			die("link traversal failed: %s\n", strerror(errno));
144
145		if (close(link) != 0)
146			die("close of link failed: %s\n", strerror(errno));
147	}
148}
149
150static void test_readlink(void)
151{
152	char buf[4096];
153	ssize_t ret;
154
155	bzero(buf, sizeof(buf));
156
157	ret = readlink(LINK, buf, sizeof(buf));
158	if (ret < 0)
159		die("readlink failed: %s\n", strerror(errno));
160	if (strcmp(buf, DATA) != 0)
161		die("readlink strcmp failed: '%s' '%s'\n", buf, DATA);
162}
163
164static void test_realpath(void)
165{
166	char *path = realpath(LINK, NULL);
167
168	if (!path)
169		die("realpath failed: %s\n", strerror(errno));
170	if (strcmp(path, DATA) != 0)
171		die("realpath strcmp failed\n");
172
173	free(path);
174}
175
176static void test_statfs(bool nosymfollow)
177{
178	struct statfs buf;
179	int ret;
180
181	ret = statfs(TMP, &buf);
182	if (ret)
183		die("statfs failed: %s\n", strerror(errno));
184
185	if (nosymfollow) {
186		if ((buf.f_flags & ST_NOSYMFOLLOW) == 0)
187			die("ST_NOSYMFOLLOW not set on %s\n", TMP);
188	} else {
189		if ((buf.f_flags & ST_NOSYMFOLLOW) != 0)
190			die("ST_NOSYMFOLLOW set on %s\n", TMP);
191	}
192}
193
/*
 * Run every check against the current state of the TMP mount.
 *
 * @nosymfollow: whether TMP is currently mounted with MS_NOSYMFOLLOW.
 */
static void run_tests(bool nosymfollow)
{
	/* Opening through the link is the check whose outcome flips. */
	test_link_traversal(nosymfollow);

	/* These two are expected to pass in both modes. */
	test_readlink();
	test_realpath();

	/* The flag must also be visible through statfs(2). */
	test_statfs(nosymfollow);
}
201
202int main(int argc, char **argv)
203{
204	create_and_enter_ns();
205
206	if (mount("testing", TMP, "ramfs", 0, NULL) != 0)
207		die("mount failed: %s\n", strerror(errno));
208
209	setup_symlink();
210	run_tests(false);
211
212	if (mount("testing", TMP, "ramfs", MS_REMOUNT|MS_NOSYMFOLLOW, NULL) != 0)
213		die("remount failed: %s\n", strerror(errno));
214
215	run_tests(true);
216
217	return EXIT_SUCCESS;
218}
219