// SPDX-License-Identifier: GPL-2.0
/*
 * memfd GUP test-case
 * This tests memfd interactions with get_user_pages(). We require the
 * fuse_mnt.c program to provide a fake direct-IO FUSE mount-point for us. This
 * file-system delays _all_ reads by 1s and forces direct-IO. This means, any
 * read() on files in that file-system will pin the receive-buffer pages for at
 * least 1s via get_user_pages().
 *
 * We use this trick to race ADD_SEALS against a write on a memfd object. The
 * ADD_SEALS must fail if the memfd pages are still pinned. Note that we use
 * the read() syscall with our memory-mapped memfd object as receive buffer to
 * force the kernel to write into our memfd object.
 */
1512048Speter
1612048Speter#define _GNU_SOURCE
1712048Speter#define __EXPORTED_HEADERS__
1812048Speter
1912048Speter#include <errno.h>
2012048Speter#include <inttypes.h>
2112048Speter#include <limits.h>
2212048Speter#include <linux/falloc.h>
2312048Speter#include <fcntl.h>
2412048Speter#include <linux/memfd.h>
2512048Speter#include <linux/types.h>
2612048Speter#include <sched.h>
2712048Speter#include <stdio.h>
2812048Speter#include <stdlib.h>
2912048Speter#include <signal.h>
3012048Speter#include <string.h>
3112048Speter#include <sys/mman.h>
3237001Scharnier#include <sys/stat.h>
3350476Speter#include <sys/syscall.h>
3412048Speter#include <sys/wait.h>
3512048Speter#include <unistd.h>
3674594Salfred
3712048Speter#include "common.h"
3837001Scharnier
/* Default size, in bytes, of the memfd object and of its mappings. */
#define MFD_DEF_SIZE (8 * 1024)
/* Size, in bytes, of the stack buffer given to the clone()d sealing thread. */
#define STACK_SIZE (64 * 1024)

/*
 * Size actually used for the memfd object, its mappings and the read()
 * buffer. main() raises it to two huge pages when run in "hugetlbfs" mode.
 */
static size_t mfd_def_size = MFD_DEF_SIZE;
4337001Scharnier
/*
 * Create a memfd called @name with @flags and resize it to @sz bytes.
 *
 * Returns the new file descriptor. A failure of either step is reported on
 * stdout and aborts the process, so the function only returns on success.
 */
static int mfd_assert_new(const char *name, loff_t sz, unsigned int flags)
{
	int fd = sys_memfd_create(name, flags);

	if (fd < 0) {
		printf("memfd_create(\"%s\", %u) failed: %m\n",
		       name, flags);
		abort();
	}

	if (ftruncate(fd, sz) < 0) {
		printf("ftruncate(%llu) failed: %m\n", (unsigned long long)sz);
		abort();
	}

	return fd;
}
6312048Speter
/*
 * Query the seals currently set on @fd with F_GET_SEALS.
 *
 * Returns the seal bitmask. If the fcntl() fails, the error is printed and
 * the process is aborted.
 */
static __u64 mfd_assert_get_seals(int fd)
{
	long seals = fcntl(fd, F_GET_SEALS);

	if (seals >= 0)
		return seals;

	printf("GET_SEALS(%d) failed: %m\n", fd);
	abort();
}
7689791Sgreen
/*
 * Assert that the seals set on @fd are exactly @seals.
 *
 * On a mismatch the expected and the actual bitmask are printed and the
 * process is aborted.
 */
static void mfd_assert_has_seals(int fd, __u64 seals)
{
	__u64 actual = mfd_assert_get_seals(fd);

	if (actual == seals)
		return;

	printf("%llu != %llu = GET_SEALS(%d)\n",
	       (unsigned long long)seals, (unsigned long long)actual, fd);
	abort();
}
8889791Sgreen
/*
 * Add @seals to @fd with F_ADD_SEALS and abort if that fails.
 *
 * The seals present before the call are read first; they are only used to
 * make the failure message more informative.
 */
static void mfd_assert_add_seals(int fd, __u64 seals)
{
	long r;
	__u64 s;

	s = mfd_assert_get_seals(fd);

	/*
	 * fcntl() is variadic and F_ADD_SEALS takes an int-sized argument.
	 * Passing the 64-bit value unconverted would hand over the wrong
	 * word on 32-bit big-endian ABIs, so narrow it explicitly.
	 */
	r = fcntl(fd, F_ADD_SEALS, (unsigned int)seals);
	if (r < 0) {
		printf("ADD_SEALS(%d, %llu -> %llu) failed: %m\n",
		       fd, (unsigned long long)s, (unsigned long long)seals);
		abort();
	}
}
10212048Speter
/*
 * Try to add @seals to @fd, tolerating EBUSY.
 *
 * Returns the fcntl() result: non-negative if the seals were added, negative
 * if the kernel refused with EBUSY. Any other error aborts the process.
 */
static int mfd_busy_add_seals(int fd, __u64 seals)
{
	int r;
	__u64 s;

	/* Best effort: the current seals only feed the diagnostic below. */
	r = fcntl(fd, F_GET_SEALS);
	if (r < 0)
		s = 0;
	else
		s = r;

	/*
	 * fcntl() is variadic and F_ADD_SEALS takes an int-sized argument.
	 * Passing the 64-bit value unconverted would hand over the wrong
	 * word on 32-bit big-endian ABIs, so narrow it explicitly.
	 */
	r = fcntl(fd, F_ADD_SEALS, (unsigned int)seals);
	if (r < 0 && errno != EBUSY) {
		printf("ADD_SEALS(%d, %llu -> %llu) didn't fail as expected with EBUSY: %m\n",
		       fd, (unsigned long long)s, (unsigned long long)seals);
		abort();
	}

	return r;
}
12312048Speter
12412048Speterstatic void *mfd_assert_mmap_shared(int fd)
12512048Speter{
12612048Speter	void *p;
12712048Speter
12812048Speter	p = mmap(NULL,
12912048Speter		 mfd_def_size,
13098542Smckusick		 PROT_READ | PROT_WRITE,
13112048Speter		 MAP_SHARED,
13212048Speter		 fd,
13312048Speter		 0);
13498542Smckusick	if (p == MAP_FAILED) {
13512048Speter		printf("mmap() failed: %m\n");
13612048Speter		abort();
13712048Speter	}
13898542Smckusick
13998542Smckusick	return p;
14098542Smckusick}
14198542Smckusick
14298542Smckusickstatic void *mfd_assert_mmap_private(int fd)
14398542Smckusick{
14498542Smckusick	void *p;
14598542Smckusick
14698542Smckusick	p = mmap(NULL,
14798542Smckusick		 mfd_def_size,
14898542Smckusick		 PROT_READ | PROT_WRITE,
14912048Speter		 MAP_PRIVATE,
15012048Speter		 fd,
15112048Speter		 0);
15212048Speter	if (p == MAP_FAILED) {
15312048Speter		printf("mmap() failed: %m\n");
15412048Speter		abort();
15512048Speter	}
15612048Speter
157122621Sjohan	return p;
158122621Sjohan}
159161558Sceri
/*
 * State handed from main() to the sealing thread. The thread is created with
 * CLONE_VM | CLONE_FILES, so it sees both these variables and the descriptor.
 */
static int global_mfd = -1;	/* memfd to seal; -1 until main() sets it */
static void *global_p;		/* shared mapping of the memfd; NULL until set */
162161558Sceri
163161558Sceristatic int sealing_thread_fn(void *arg)
164161558Sceri{
16598542Smckusick	int sig, r;
16698542Smckusick
16798542Smckusick	/*
16898542Smckusick	 * This thread first waits 200ms so any pending operation in the parent
16923854Sbde	 * is correctly started. After that, it tries to seal @global_mfd as
17012048Speter	 * SEAL_WRITE. This _must_ fail as the parent thread has a read() into
17198542Smckusick	 * that memory mapped object still ongoing.
17298542Smckusick	 * We then wait one more second and try sealing again. This time it
17398542Smckusick	 * must succeed as there shouldn't be anyone else pinning the pages.
17498542Smckusick	 */
17598542Smckusick
17623854Sbde	/* wait 200ms for FUSE-request to be active */
17712048Speter	usleep(200000);
17898542Smckusick
17998542Smckusick	/* unmount mapping before sealing to avoid i_mmap_writable failures */
18098542Smckusick	munmap(global_p, mfd_def_size);
18198542Smckusick
18298542Smckusick	/* Try sealing the global file; expect EBUSY or success. Current
18323854Sbde	 * kernels will never succeed, but in the future, kernels might
18412048Speter	 * implement page-replacements or other fancy ways to avoid racing
18598542Smckusick	 * writes. */
18612048Speter	r = mfd_busy_add_seals(global_mfd, F_SEAL_WRITE);
18798542Smckusick	if (r >= 0) {
18812048Speter		printf("HURRAY! This kernel fixed GUP races!\n");
18912048Speter	} else {
19098542Smckusick		/* wait 1s more so the FUSE-request is done */
19198542Smckusick		sleep(1);
19212048Speter
19312048Speter		/* try sealing the global file again */
19498542Smckusick		mfd_assert_add_seals(global_mfd, F_SEAL_WRITE);
19512048Speter	}
19698542Smckusick
19798542Smckusick	return 0;
198122621Sjohan}
199122621Sjohan
20012048Speterstatic pid_t spawn_sealing_thread(void)
20112048Speter{
20289827Sjoerg	uint8_t *stack;
20389827Sjoerg	pid_t pid;
20489827Sjoerg
20589827Sjoerg	stack = malloc(STACK_SIZE);
20689827Sjoerg	if (!stack) {
20789827Sjoerg		printf("malloc(STACK_SIZE) failed: %m\n");
20889827Sjoerg		abort();
20992881Simp	}
21089827Sjoerg
21189827Sjoerg	pid = clone(sealing_thread_fn,
21289827Sjoerg		    stack + STACK_SIZE,
21389827Sjoerg		    SIGCHLD | CLONE_FILES | CLONE_FS | CLONE_VM,
21489827Sjoerg		    NULL);
21589827Sjoerg	if (pid < 0) {
21689827Sjoerg		printf("clone() failed: %m\n");
21789827Sjoerg		abort();
21889827Sjoerg	}
21989827Sjoerg
22089827Sjoerg	return pid;
22189827Sjoerg}
22289827Sjoerg
/*
 * Wait until the sealing thread @pid has terminated and reap it.
 *
 * waitpid() is restarted when it is interrupted by a signal. Any other
 * failure aborts: returning would let the caller inspect the seals while
 * the thread may still be running.
 */
static void join_sealing_thread(pid_t pid)
{
	pid_t r;

	do {
		r = waitpid(pid, NULL, 0);
	} while (r < 0 && errno == EINTR);

	if (r < 0) {
		printf("waitpid(%d) failed: %m\n", (int)pid);
		abort();
	}
}
22789827Sjoerg
22889827Sjoergint main(int argc, char **argv)
229207141Sjeff{
23098542Smckusick	char *zero;
23189827Sjoerg	int fd, mfd, r;
23289827Sjoerg	void *p;
23398542Smckusick	int was_sealed;
23489827Sjoerg	pid_t pid;
23598542Smckusick
23689827Sjoerg	if (argc < 2) {
237207141Sjeff		printf("error: please pass path to file in fuse_mnt mount-point\n");
238207141Sjeff		abort();
239207141Sjeff	}
24089827Sjoerg
24189827Sjoerg	if (argc >= 3) {
24289827Sjoerg		if (!strcmp(argv[2], "hugetlbfs")) {
24389827Sjoerg			unsigned long hpage_size = default_huge_page_size();
24489827Sjoerg
24589827Sjoerg			if (!hpage_size) {
24689827Sjoerg				printf("Unable to determine huge page size\n");
24789827Sjoerg				abort();
24889827Sjoerg			}
24989827Sjoerg
25089827Sjoerg			hugetlbfs_test = 1;
25189827Sjoerg			mfd_def_size = hpage_size * 2;
25289827Sjoerg		} else {
25398542Smckusick			printf("Unknown option: %s\n", argv[2]);
25498542Smckusick			abort();
25598542Smckusick		}
25698542Smckusick	}
257207141Sjeff
258207141Sjeff	zero = calloc(sizeof(*zero), mfd_def_size);
259122621Sjohan
26089827Sjoerg	/* open FUSE memfd file for GUP testing */
26189827Sjoerg	printf("opening: %s\n", argv[1]);
26289827Sjoerg	fd = open(argv[1], O_RDONLY | O_CLOEXEC);
26389827Sjoerg	if (fd < 0) {
26489827Sjoerg		printf("cannot open(\"%s\"): %m\n", argv[1]);
26589827Sjoerg		abort();
26689827Sjoerg	}
26789827Sjoerg
26889827Sjoerg	/* create new memfd-object */
26989827Sjoerg	mfd = mfd_assert_new("kern_memfd_fuse",
27089827Sjoerg			     mfd_def_size,
27189827Sjoerg			     MFD_CLOEXEC | MFD_ALLOW_SEALING);
27289827Sjoerg
273207141Sjeff	/* mmap memfd-object for writing */
274207141Sjeff	p = mfd_assert_mmap_shared(mfd);
275207141Sjeff
27689827Sjoerg	/* pass mfd+mapping to a separate sealing-thread which tries to seal
27789827Sjoerg	 * the memfd objects with SEAL_WRITE while we write into it */
27889827Sjoerg	global_mfd = mfd;
27989827Sjoerg	global_p = p;
280207141Sjeff	pid = spawn_sealing_thread();
28189827Sjoerg
28289827Sjoerg	/* Use read() on the FUSE file to read into our memory-mapped memfd
28389827Sjoerg	 * object. This races the other thread which tries to seal the
28489827Sjoerg	 * memfd-object.
28589827Sjoerg	 * If @fd is on the memfd-fake-FUSE-FS, the read() is delayed by 1s.
28689827Sjoerg	 * This guarantees that the receive-buffer is pinned for 1s until the
28789827Sjoerg	 * data is written into it. The racing ADD_SEALS should thus fail as
28898542Smckusick	 * the pages are still pinned. */
28989827Sjoerg	r = read(fd, p, mfd_def_size);
29089827Sjoerg	if (r < 0) {
291122621Sjohan		printf("read() failed: %m\n");
29289827Sjoerg		abort();
29398542Smckusick	} else if (!r) {
29489827Sjoerg		printf("unexpected EOF on read()\n");
29589827Sjoerg		abort();
29689827Sjoerg	}
29798542Smckusick
298231574Struckman	was_sealed = mfd_assert_get_seals(mfd) & F_SEAL_WRITE;
29989827Sjoerg
30089827Sjoerg	/* Wait for sealing-thread to finish and verify that it
30198542Smckusick	 * successfully sealed the file after the second try. */
302122621Sjohan	join_sealing_thread(pid);
303231574Struckman	mfd_assert_has_seals(mfd, F_SEAL_WRITE);
304231574Struckman
305231574Struckman	/* *IF* the memfd-object was sealed at the time our read() returned,
306231574Struckman	 * then the kernel did a page-replacement or canceled the read() (or
30789827Sjoerg	 * whatever magic it did..). In that case, the memfd object is still
30889827Sjoerg	 * all zero.
30989827Sjoerg	 * In case the memfd-object was *not* sealed, the read() was successfull
31089827Sjoerg	 * and the memfd object must *not* be all zero.
31189827Sjoerg	 * Note that in real scenarios, there might be a mixture of both, but
312231574Struckman	 * in this test-cases, we have explicit 200ms delays which should be
31389827Sjoerg	 * enough to avoid any in-flight writes. */
31489827Sjoerg
31589827Sjoerg	p = mfd_assert_mmap_private(mfd);
31689827Sjoerg	if (was_sealed && memcmp(p, zero, mfd_def_size)) {
31789827Sjoerg		printf("memfd sealed during read() but data not discarded\n");
31889827Sjoerg		abort();
31989827Sjoerg	} else if (!was_sealed && !memcmp(p, zero, mfd_def_size)) {
320207141Sjeff		printf("memfd sealed after read() but data discarded\n");
32189827Sjoerg		abort();
32289827Sjoerg	}
32389827Sjoerg
32489827Sjoerg	close(mfd);
32512048Speter	close(fd);
32692881Simp
32712048Speter	printf("fuse: DONE\n");
32812048Speter	free(zero);
32912048Speter
33012048Speter	return 0;
33112048Speter}
33212048Speter