1/* SPDX-License-Identifier: GPL-2.0 */
2#define _GNU_SOURCE
3#include <linux/membarrier.h>
4#include <syscall.h>
5#include <stdio.h>
6#include <errno.h>
7#include <string.h>
8#include <pthread.h>
9
10#include "../kselftest.h"
11
/*
 * Running bitmask of the MEMBARRIER_CMD_REGISTER_* commands this process
 * has successfully issued so far. test_membarrier_get_registrations() ORs
 * each new registration command into it and then expects
 * MEMBARRIER_CMD_GET_REGISTRATIONS to return exactly this value.
 */
static int registrations;
13
/*
 * Invoke the membarrier(2) system call directly through syscall().
 *
 * @cmd:   MEMBARRIER_CMD_* command to issue.
 * @flags: command flags, passed through unchanged.
 *
 * Returns the raw syscall() result narrowed to int: -1 with errno set on
 * failure, otherwise a command-specific non-negative value (e.g. the
 * supported-command mask for MEMBARRIER_CMD_QUERY).
 */
static int sys_membarrier(int cmd, int flags)
{
	long rc = syscall(__NR_membarrier, cmd, flags);

	return rc;
}
18
19static int test_membarrier_get_registrations(int cmd)
20{
21	int ret, flags = 0;
22	const char *test_name =
23		"sys membarrier MEMBARRIER_CMD_GET_REGISTRATIONS";
24
25	registrations |= cmd;
26
27	ret = sys_membarrier(MEMBARRIER_CMD_GET_REGISTRATIONS, 0);
28	if (ret < 0) {
29		ksft_exit_fail_msg(
30			"%s test: flags = %d, errno = %d\n",
31			test_name, flags, errno);
32	} else if (ret != registrations) {
33		ksft_exit_fail_msg(
34			"%s test: flags = %d, ret = %d, registrations = %d\n",
35			test_name, flags, ret, registrations);
36	}
37	ksft_test_result_pass(
38		"%s test: flags = %d, ret = %d, registrations = %d\n",
39		test_name, flags, ret, registrations);
40
41	return 0;
42}
43
/*
 * Negative test: command -1 is not a valid membarrier command. The call
 * must return -1 with errno == EINVAL.
 *
 * Any other outcome fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_cmd_fail(void)
{
	const char *test_name = "sys membarrier invalid command";
	const int cmd = -1;
	const int flags = 0;
	int rc, err;

	rc = sys_membarrier(cmd, flags);
	err = errno;	/* snapshot right after the call */

	if (rc != -1) {
		ksft_exit_fail_msg(
			"%s test: command = %d, flags = %d. Should fail, but passed\n",
			test_name, cmd, flags);
	}
	if (err != EINVAL) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
			test_name, flags, EINVAL, strerror(EINVAL),
			err, strerror(err));
	}

	ksft_test_result_pass(
		"%s test: command = %d, flags = %d, errno = %d. Failed as expected\n",
		test_name, cmd, flags, err);
	return 0;
}
66
/*
 * Negative test: MEMBARRIER_CMD_QUERY with a non-zero flags value (1)
 * must be rejected with -1 and errno == EINVAL.
 *
 * Any other outcome fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_flags_fail(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_QUERY invalid flags";
	const int flags = 1;
	int rc, err;

	rc = sys_membarrier(MEMBARRIER_CMD_QUERY, flags);
	err = errno;	/* snapshot right after the call */

	if (rc != -1) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should fail, but passed\n",
			test_name, flags);
	}
	if (err != EINVAL) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
			test_name, flags, EINVAL, strerror(EINVAL),
			err, strerror(err));
	}

	ksft_test_result_pass(
		"%s test: flags = %d, errno = %d. Failed as expected\n",
		test_name, flags, err);
	return 0;
}
89
/*
 * MEMBARRIER_CMD_GLOBAL must return 0.
 *
 * A non-zero return fails the run through ksft_exit_fail_msg(), reporting
 * errno. Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_global_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_GLOBAL";
	const int flags = 0;

	if (sys_membarrier(MEMBARRIER_CMD_GLOBAL, flags)) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);
	return 0;
}
105
/*
 * Negative test: MEMBARRIER_CMD_PRIVATE_EXPEDITED issued without a prior
 * MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED (per the test name) must
 * fail with -1 and errno == EPERM.
 *
 * Any other outcome fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_private_expedited_fail(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED not registered failure";
	const int flags = 0;
	int rc, err;

	rc = sys_membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED, flags);
	err = errno;	/* snapshot right after the call */

	if (rc != -1) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should fail, but passed\n",
			test_name, flags);
	}
	if (err != EPERM) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
			test_name, flags, EPERM, strerror(EPERM),
			err, strerror(err));
	}

	ksft_test_result_pass(
		"%s test: flags = %d, errno = %d\n",
		test_name, flags, err);
	return 0;
}
128
/*
 * MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED must return 0. The new
 * registration is then cross-checked with
 * test_membarrier_get_registrations().
 *
 * A non-zero return fails the run through ksft_exit_fail_msg().
 *
 * Returns 0.
 */
static int test_membarrier_register_private_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED";
	const int reg_cmd = MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED;
	const int flags = 0;

	if (sys_membarrier(reg_cmd, flags)) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}
	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);

	/* The registration must now show up in GET_REGISTRATIONS. */
	test_membarrier_get_registrations(reg_cmd);
	return 0;
}
147
/*
 * MEMBARRIER_CMD_PRIVATE_EXPEDITED must return 0. test_membarrier_success()
 * calls this after the private-expedited registration.
 *
 * A non-zero return fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_private_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED";
	const int flags = 0;

	if (sys_membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED, flags)) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);
	return 0;
}
164
/*
 * Negative test: MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE issued
 * without the matching registration (per the test name) must fail with
 * -1 and errno == EPERM.
 *
 * Any other outcome fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_private_expedited_sync_core_fail(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE not registered failure";
	const int flags = 0;
	int rc, err;

	rc = sys_membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE, flags);
	err = errno;	/* snapshot right after the call */

	if (rc != -1) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should fail, but passed\n",
			test_name, flags);
	}
	if (err != EPERM) {
		ksft_exit_fail_msg(
			"%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
			test_name, flags, EPERM, strerror(EPERM),
			err, strerror(err));
	}

	ksft_test_result_pass(
		"%s test: flags = %d, errno = %d\n",
		test_name, flags, err);
	return 0;
}
187
/*
 * MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE must return 0. The
 * new registration is then cross-checked with
 * test_membarrier_get_registrations().
 *
 * A non-zero return fails the run through ksft_exit_fail_msg().
 *
 * Returns 0.
 */
static int test_membarrier_register_private_expedited_sync_core_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE";
	const int reg_cmd = MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE;
	const int flags = 0;

	if (sys_membarrier(reg_cmd, flags)) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}
	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);

	/* The registration must now show up in GET_REGISTRATIONS. */
	test_membarrier_get_registrations(reg_cmd);
	return 0;
}
206
/*
 * MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE must return 0.
 * test_membarrier_success() calls this right after
 * MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE has been registered.
 *
 * A non-zero return fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_private_expedited_sync_core_success(void)
{
	/*
	 * Issue the SYNC_CORE variant. Using plain PRIVATE_EXPEDITED here
	 * would report a SYNC_CORE pass without ever exercising it.
	 */
	int cmd = MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE, flags = 0;
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE";

	if (sys_membarrier(cmd, flags) != 0) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}

	ksft_test_result_pass(
		"%s test: flags = %d\n",
		test_name, flags);
	return 0;
}
223
/*
 * MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED must return 0. The new
 * registration is then cross-checked with
 * test_membarrier_get_registrations().
 *
 * A non-zero return fails the run through ksft_exit_fail_msg().
 *
 * Returns 0.
 */
static int test_membarrier_register_global_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED";
	const int reg_cmd = MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED;
	const int flags = 0;

	if (sys_membarrier(reg_cmd, flags)) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}
	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);

	/* The registration must now show up in GET_REGISTRATIONS. */
	test_membarrier_get_registrations(reg_cmd);
	return 0;
}
242
/*
 * MEMBARRIER_CMD_GLOBAL_EXPEDITED must return 0. test_membarrier_success()
 * runs this both before and after MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED.
 *
 * A non-zero return fails the run through ksft_exit_fail_msg().
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_global_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_GLOBAL_EXPEDITED";
	const int flags = 0;

	if (sys_membarrier(MEMBARRIER_CMD_GLOBAL_EXPEDITED, flags)) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);
	return 0;
}
259
/*
 * Run the negative (expected-failure) checks in order, stopping at the
 * first non-zero status.
 *
 * The SYNC_CORE "not registered" check runs only when
 * MEMBARRIER_CMD_QUERY advertises MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE.
 *
 * Returns 0 when every check passes, the first non-zero sub-test status,
 * or the negative MEMBARRIER_CMD_QUERY result (after reporting a failure).
 */
static int test_membarrier_fail(void)
{
	int supported;
	int rc;

	rc = test_membarrier_cmd_fail();
	if (rc != 0)
		return rc;

	rc = test_membarrier_flags_fail();
	if (rc != 0)
		return rc;

	rc = test_membarrier_private_expedited_fail();
	if (rc != 0)
		return rc;

	supported = sys_membarrier(MEMBARRIER_CMD_QUERY, 0);
	if (supported < 0) {
		ksft_test_result_fail("sys_membarrier() failed\n");
		return supported;
	}

	/* Kernel lacks SYNC_CORE: nothing more to check. */
	if (!(supported & MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE))
		return 0;

	return test_membarrier_private_expedited_sync_core_fail();
}
285
286static int test_membarrier_success(void)
287{
288	int status;
289
290	status = test_membarrier_global_success();
291	if (status)
292		return status;
293	status = test_membarrier_register_private_expedited_success();
294	if (status)
295		return status;
296	status = test_membarrier_private_expedited_success();
297	if (status)
298		return status;
299	status = sys_membarrier(MEMBARRIER_CMD_QUERY, 0);
300	if (status < 0) {
301		ksft_test_result_fail("sys_membarrier() failed\n");
302		return status;
303	}
304	if (status & MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE) {
305		status = test_membarrier_register_private_expedited_sync_core_success();
306		if (status)
307			return status;
308		status = test_membarrier_private_expedited_sync_core_success();
309		if (status)
310			return status;
311	}
312	/*
313	 * It is valid to send a global membarrier from a non-registered
314	 * process.
315	 */
316	status = test_membarrier_global_expedited_success();
317	if (status)
318		return status;
319	status = test_membarrier_register_global_expedited_success();
320	if (status)
321		return status;
322	status = test_membarrier_global_expedited_success();
323	if (status)
324		return status;
325	return 0;
326}
327
/*
 * Probe membarrier support with MEMBARRIER_CMD_QUERY.
 *  - ENOSYS (the syscall is absent, e.g. CONFIG_MEMBARRIER=n) skips the
 *    run via ksft_exit_skip().
 *  - Any other error fails the run via ksft_exit_fail_msg().
 *  - A supported-command mask without MEMBARRIER_CMD_GLOBAL also skips.
 * Otherwise a pass is reported.
 *
 * Returns 0.
 */
static int test_membarrier_query(void)
{
	const int flags = 0;
	int supported = sys_membarrier(MEMBARRIER_CMD_QUERY, flags);

	if (supported < 0 && errno == ENOSYS) {
		/*
		 * A kernel built with CONFIG_MEMBARRIER=n is valid; there is
		 * simply nothing to test.
		 */
		ksft_exit_skip(
			"sys membarrier (CONFIG_MEMBARRIER) is disabled.\n");
	}
	if (supported < 0)
		ksft_exit_fail_msg("sys_membarrier() failed\n");

	if ((supported & MEMBARRIER_CMD_GLOBAL) == 0)
		ksft_exit_skip(
			"sys_membarrier unsupported: CMD_GLOBAL not found.\n");

	ksft_test_result_pass("sys_membarrier available\n");
	return 0;
}
351