1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2022 ARM Limited.
4 */
5#include <errno.h>
6#include <stdbool.h>
7#include <stddef.h>
8#include <stdio.h>
9#include <stdlib.h>
10#include <string.h>
11#include <unistd.h>
12#include <sys/auxv.h>
13#include <sys/prctl.h>
14#include <sys/ptrace.h>
15#include <sys/types.h>
16#include <sys/uio.h>
17#include <sys/wait.h>
18#include <asm/sigcontext.h>
19#include <asm/ptrace.h>
20
21#include "../../kselftest.h"
22
enum {
	/* 7 results from test_tpidr() plus 2 from each of the two test_hw_debug() calls */
	EXPECTED_TESTS = 11,
	/* TPIDR and TPIDR2: the most registers the NT_ARM_TLS tests look at */
	MAX_TPIDRS = 2,
};
26
27static bool have_sme(void)
28{
29	return getauxval(AT_HWCAP2) & HWCAP2_SME;
30}
31
32static void test_tpidr(pid_t child)
33{
34	uint64_t read_val[MAX_TPIDRS];
35	uint64_t write_val[MAX_TPIDRS];
36	struct iovec read_iov, write_iov;
37	bool test_tpidr2 = false;
38	int ret, i;
39
40	read_iov.iov_base = read_val;
41	write_iov.iov_base = write_val;
42
43	/* Should be able to read a single TPIDR... */
44	read_iov.iov_len = sizeof(uint64_t);
45	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
46	ksft_test_result(ret == 0, "read_tpidr_one\n");
47
48	/* ...write a new value.. */
49	write_iov.iov_len = sizeof(uint64_t);
50	write_val[0] = read_val[0]++;
51	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_TLS, &write_iov);
52	ksft_test_result(ret == 0, "write_tpidr_one\n");
53
54	/* ...then read it back */
55	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
56	ksft_test_result(ret == 0 && write_val[0] == read_val[0],
57			 "verify_tpidr_one\n");
58
59	/* If we have TPIDR2 we should be able to read it */
60	read_iov.iov_len = sizeof(read_val);
61	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
62	if (ret == 0) {
63		/* If we have SME there should be two TPIDRs */
64		if (read_iov.iov_len >= sizeof(read_val))
65			test_tpidr2 = true;
66
67		if (have_sme() && test_tpidr2) {
68			ksft_test_result(test_tpidr2, "count_tpidrs\n");
69		} else {
70			ksft_test_result(read_iov.iov_len % sizeof(uint64_t) == 0,
71					 "count_tpidrs\n");
72		}
73	} else {
74		ksft_test_result_fail("count_tpidrs\n");
75	}
76
77	if (test_tpidr2) {
78		/* Try to write new values to all known TPIDRs... */
79		write_iov.iov_len = sizeof(write_val);
80		for (i = 0; i < MAX_TPIDRS; i++)
81			write_val[i] = read_val[i] + 1;
82		ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_TLS, &write_iov);
83
84		ksft_test_result(ret == 0 &&
85				 write_iov.iov_len == sizeof(write_val),
86				 "tpidr2_write\n");
87
88		/* ...then read them back */
89		read_iov.iov_len = sizeof(read_val);
90		ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
91
92		if (have_sme()) {
93			/* Should read back the written value */
94			ksft_test_result(ret == 0 &&
95					 read_iov.iov_len >= sizeof(read_val) &&
96					 memcmp(read_val, write_val,
97						sizeof(read_val)) == 0,
98					 "tpidr2_read\n");
99		} else {
100			/* TPIDR2 should read as zero */
101			ksft_test_result(ret == 0 &&
102					 read_iov.iov_len >= sizeof(read_val) &&
103					 read_val[0] == write_val[0] &&
104					 read_val[1] == 0,
105					 "tpidr2_read\n");
106		}
107
108		/* Writing only TPIDR... */
109		write_iov.iov_len = sizeof(uint64_t);
110		memcpy(write_val, read_val, sizeof(read_val));
111		write_val[0] += 1;
112		ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_TLS, &write_iov);
113
114		if (ret == 0) {
115			/* ...should leave TPIDR2 untouched */
116			read_iov.iov_len = sizeof(read_val);
117			ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS,
118				     &read_iov);
119
120			ksft_test_result(ret == 0 &&
121					 read_iov.iov_len >= sizeof(read_val) &&
122					 memcmp(read_val, write_val,
123						sizeof(read_val)) == 0,
124					 "write_tpidr_only\n");
125		} else {
126			ksft_test_result_fail("write_tpidr_only\n");
127		}
128	} else {
129		ksft_test_result_skip("tpidr2_write\n");
130		ksft_test_result_skip("tpidr2_read\n");
131		ksft_test_result_skip("write_tpidr_only\n");
132	}
133}
134
135static void test_hw_debug(pid_t child, int type, const char *type_name)
136{
137	struct user_hwdebug_state state;
138	struct iovec iov;
139	int slots, arch, ret;
140
141	iov.iov_len = sizeof(state);
142	iov.iov_base = &state;
143
144	/* Should be able to read the values */
145	ret = ptrace(PTRACE_GETREGSET, child, type, &iov);
146	ksft_test_result(ret == 0, "read_%s\n", type_name);
147
148	if (ret == 0) {
149		/* Low 8 bits is the number of slots, next 4 bits the arch */
150		slots = state.dbg_info & 0xff;
151		arch = (state.dbg_info >> 8) & 0xf;
152
153		ksft_print_msg("%s version %d with %d slots\n", type_name,
154			       arch, slots);
155
156		/* Zero is not currently architecturally valid */
157		ksft_test_result(arch, "%s_arch_set\n", type_name);
158	} else {
159		ksft_test_result_skip("%s_arch_set\n");
160	}
161}
162
/*
 * Child side of the test: ask to be traced by the parent, then stop itself
 * with SIGSTOP so the parent can inspect it. Exits via ksft on failure.
 */
static int do_child(void)
{
	/* Both messages need a %s, otherwise strerror() output is dropped */
	if (ptrace(PTRACE_TRACEME, -1, NULL, NULL))
		ksft_exit_fail_msg("PTRACE_TRACEME: %s\n", strerror(errno));

	if (raise(SIGSTOP))
		ksft_exit_fail_msg("raise(SIGSTOP): %s\n", strerror(errno));

	return EXIT_SUCCESS;
}
173
174static int do_parent(pid_t child)
175{
176	int ret = EXIT_FAILURE;
177	pid_t pid;
178	int status;
179	siginfo_t si;
180
181	/* Attach to the child */
182	while (1) {
183		int sig;
184
185		pid = wait(&status);
186		if (pid == -1) {
187			perror("wait");
188			goto error;
189		}
190
191		/*
192		 * This should never happen but it's hard to flag in
193		 * the framework.
194		 */
195		if (pid != child)
196			continue;
197
198		if (WIFEXITED(status) || WIFSIGNALED(status))
199			ksft_exit_fail_msg("Child died unexpectedly\n");
200
201		if (!WIFSTOPPED(status))
202			goto error;
203
204		sig = WSTOPSIG(status);
205
206		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
207			if (errno == ESRCH)
208				goto disappeared;
209
210			if (errno == EINVAL) {
211				sig = 0; /* bust group-stop */
212				goto cont;
213			}
214
215			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
216					      strerror(errno));
217			goto error;
218		}
219
220		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
221		    si.si_pid == pid)
222			break;
223
224	cont:
225		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
226			if (errno == ESRCH)
227				goto disappeared;
228
229			ksft_test_result_fail("PTRACE_CONT: %s\n",
230					      strerror(errno));
231			goto error;
232		}
233	}
234
235	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
236
237	test_tpidr(child);
238	test_hw_debug(child, NT_ARM_HW_WATCH, "NT_ARM_HW_WATCH");
239	test_hw_debug(child, NT_ARM_HW_BREAK, "NT_ARM_HW_BREAK");
240
241	ret = EXIT_SUCCESS;
242
243error:
244	kill(child, SIGKILL);
245
246disappeared:
247	return ret;
248}
249
250int main(void)
251{
252	int ret = EXIT_SUCCESS;
253	pid_t child;
254
255	srandom(getpid());
256
257	ksft_print_header();
258
259	ksft_set_plan(EXPECTED_TESTS);
260
261	child = fork();
262	if (!child)
263		return do_child();
264
265	if (do_parent(child))
266		ret = EXIT_FAILURE;
267
268	ksft_print_cnts();
269
270	return ret;
271}
272