1// SPDX-License-Identifier: GPL-2.0
2// Copyright (C) 2020 ARM Limited
3
4#define _GNU_SOURCE
5
6#include <assert.h>
7#include <errno.h>
8#include <fcntl.h>
9#include <signal.h>
10#include <stdlib.h>
11#include <stdio.h>
12#include <string.h>
13#include <ucontext.h>
14#include <unistd.h>
15#include <sys/uio.h>
16#include <sys/mman.h>
17
18#include "kselftest.h"
19#include "mte_common_util.h"
20#include "mte_def.h"
21
/* System page size; set by main() from getpagesize() before any test runs. */
static size_t page_sz;

/* Capacity, in bytes, of the buffer holding one formatted test name. */
#define TEST_NAME_MAX 100

/* File syscall exercised against the partially re-tagged buffer. */
enum test_type {
	READ_TEST = 0,	/* read(2) into the buffer */
	WRITE_TEST,	/* write(2) from the buffer */
	READV_TEST,	/* readv(2) with a single iovec */
	WRITEV_TEST,	/* writev(2) with a single iovec */
	LAST_TEST,	/* count of test types; loop bound only */
};
33
34static int check_usermem_access_fault(int mem_type, int mode, int mapping,
35                                      int tag_offset, int tag_len,
36                                      enum test_type test_type)
37{
38	int fd, i, err;
39	char val = 'A';
40	ssize_t len, syscall_len;
41	void *ptr, *ptr_next;
42	int fileoff, ptroff, size;
43	int sizes[] = {1, 2, 3, 8, 16, 32, 4096, page_sz};
44
45	err = KSFT_PASS;
46	len = 2 * page_sz;
47	mte_switch_mode(mode, MTE_ALLOW_NON_ZERO_TAG);
48	fd = create_temp_file();
49	if (fd == -1)
50		return KSFT_FAIL;
51	for (i = 0; i < len; i++)
52		if (write(fd, &val, sizeof(val)) != sizeof(val))
53			return KSFT_FAIL;
54	lseek(fd, 0, 0);
55	ptr = mte_allocate_memory(len, mem_type, mapping, true);
56	if (check_allocated_memory(ptr, len, mem_type, true) != KSFT_PASS) {
57		close(fd);
58		return KSFT_FAIL;
59	}
60	mte_initialize_current_context(mode, (uintptr_t)ptr, len);
61	/* Copy from file into buffer with valid tag */
62	syscall_len = read(fd, ptr, len);
63	mte_wait_after_trig();
64	if (cur_mte_cxt.fault_valid || syscall_len < len)
65		goto usermem_acc_err;
66	/* Verify same pattern is read */
67	for (i = 0; i < len; i++)
68		if (*(char *)(ptr + i) != val)
69			break;
70	if (i < len)
71		goto usermem_acc_err;
72
73	if (!tag_len)
74		tag_len = len - tag_offset;
75	/* Tag a part of memory with different value */
76	ptr_next = (void *)((unsigned long)ptr + tag_offset);
77	ptr_next = mte_insert_new_tag(ptr_next);
78	mte_set_tag_address_range(ptr_next, tag_len);
79
80	for (fileoff = 0; fileoff < 16; fileoff++) {
81		for (ptroff = 0; ptroff < 16; ptroff++) {
82			for (i = 0; i < ARRAY_SIZE(sizes); i++) {
83				size = sizes[i];
84				lseek(fd, 0, 0);
85
86				/* perform file operation on buffer with invalid tag */
87				switch (test_type) {
88				case READ_TEST:
89					syscall_len = read(fd, ptr + ptroff, size);
90					break;
91				case WRITE_TEST:
92					syscall_len = write(fd, ptr + ptroff, size);
93					break;
94				case READV_TEST: {
95					struct iovec iov[1];
96					iov[0].iov_base = ptr + ptroff;
97					iov[0].iov_len = size;
98					syscall_len = readv(fd, iov, 1);
99					break;
100				}
101				case WRITEV_TEST: {
102					struct iovec iov[1];
103					iov[0].iov_base = ptr + ptroff;
104					iov[0].iov_len = size;
105					syscall_len = writev(fd, iov, 1);
106					break;
107				}
108				case LAST_TEST:
109					goto usermem_acc_err;
110				}
111
112				mte_wait_after_trig();
113				/*
114				 * Accessing user memory in kernel with invalid tag should fail in sync
115				 * mode without fault but may not fail in async mode as per the
116				 * implemented MTE userspace support in Arm64 kernel.
117				 */
118				if (cur_mte_cxt.fault_valid) {
119					goto usermem_acc_err;
120				}
121				if (mode == MTE_SYNC_ERR && syscall_len < len) {
122					/* test passed */
123				} else if (mode == MTE_ASYNC_ERR && syscall_len == size) {
124					/* test passed */
125				} else {
126					goto usermem_acc_err;
127				}
128			}
129		}
130	}
131
132	goto exit;
133
134usermem_acc_err:
135	err = KSFT_FAIL;
136exit:
137	mte_free_memory((void *)ptr, len, mem_type, true);
138	close(fd);
139	return err;
140}
141
142void format_test_name(char* name, int name_len, int type, int sync, int map, int len, int offset) {
143	const char* test_type;
144	const char* mte_type;
145	const char* map_type;
146
147	switch (type) {
148	case READ_TEST:
149		test_type = "read";
150		break;
151	case WRITE_TEST:
152		test_type = "write";
153		break;
154	case READV_TEST:
155		test_type = "readv";
156		break;
157	case WRITEV_TEST:
158		test_type = "writev";
159		break;
160	default:
161		assert(0);
162		break;
163	}
164
165	switch (sync) {
166	case MTE_SYNC_ERR:
167		mte_type = "MTE_SYNC_ERR";
168		break;
169	case MTE_ASYNC_ERR:
170		mte_type = "MTE_ASYNC_ERR";
171		break;
172	default:
173		assert(0);
174		break;
175	}
176
177	switch (map) {
178	case MAP_SHARED:
179		map_type = "MAP_SHARED";
180		break;
181	case MAP_PRIVATE:
182		map_type = "MAP_PRIVATE";
183		break;
184	default:
185		assert(0);
186		break;
187	}
188
189	snprintf(name, name_len,
190	         "test type: %s, %s, %s, tag len: %d, tag offset: %d\n",
191	         test_type, mte_type, map_type, len, offset);
192}
193
194int main(int argc, char *argv[])
195{
196	int err;
197	int t, s, m, l, o;
198	int mte_sync[] = {MTE_SYNC_ERR, MTE_ASYNC_ERR};
199	int maps[] = {MAP_SHARED, MAP_PRIVATE};
200	int tag_lens[] = {0, MT_GRANULE_SIZE};
201	int tag_offsets[] = {page_sz, MT_GRANULE_SIZE};
202	char test_name[TEST_NAME_MAX];
203
204	page_sz = getpagesize();
205	if (!page_sz) {
206		ksft_print_msg("ERR: Unable to get page size\n");
207		return KSFT_FAIL;
208	}
209	err = mte_default_setup();
210	if (err)
211		return err;
212
213	/* Register signal handlers */
214	mte_register_signal(SIGSEGV, mte_default_handler);
215
216	/* Set test plan */
217	ksft_set_plan(64);
218
219	for (t = 0; t < LAST_TEST; t++) {
220		for (s = 0; s < ARRAY_SIZE(mte_sync); s++) {
221			for (m = 0; m < ARRAY_SIZE(maps); m++) {
222				for (l = 0; l < ARRAY_SIZE(tag_lens); l++) {
223					for (o = 0; o < ARRAY_SIZE(tag_offsets); o++) {
224						int sync = mte_sync[s];
225						int map = maps[m];
226						int offset = tag_offsets[o];
227						int tag_len = tag_lens[l];
228						int res = check_usermem_access_fault(USE_MMAP, sync,
229						                                     map, offset,
230						                                     tag_len, t);
231						format_test_name(test_name, TEST_NAME_MAX,
232						                 t, sync, map, tag_len, offset);
233						evaluate_test(res, test_name);
234					}
235				}
236			}
237		}
238	}
239
240	mte_restore_setup();
241	ksft_print_cnts();
242	return ksft_get_fail_cnt() == 0 ? KSFT_PASS : KSFT_FAIL;
243}
244