1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2023 ARM Limited.
4 * Original author: Mark Brown <broonie@kernel.org>
5 */
6
7#define _GNU_SOURCE
8
#include <errno.h>
#include <signal.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
16
17#include <sys/auxv.h>
18#include <sys/prctl.h>
19#include <sys/ptrace.h>
20#include <sys/types.h>
21#include <sys/uio.h>
22#include <sys/wait.h>
23
24#include <linux/kernel.h>
25
26#include <asm/sigcontext.h>
27#include <asm/sve_context.h>
28#include <asm/ptrace.h>
29
30#include "../../kselftest.h"
31
32#include "fp-ptrace.h"
33
/*
 * ELF note types for the regsets accessed with PTRACE_GETREGSET and
 * PTRACE_SETREGSET below. <linux/elf.h> and <sys/auxv.h> don't like each
 * other, so supply the values here when no included header defined them.
 */
#ifndef NT_ARM_SVE
#define NT_ARM_SVE 0x405
#endif

#ifndef NT_ARM_SSVE
#define NT_ARM_SSVE 0x40b
#endif

#ifndef NT_ARM_ZA
#define NT_ARM_ZA 0x40c
#endif

#ifndef NT_ARM_ZT
#define NT_ARM_ZT 0x40d
#endif

/*
 * Largest VQ (the unit returned by __sve_vq_from_vl()) that the static
 * register buffers below are sized for.
 */
#define ARCH_VQ_MAX 256

/* VL 128..2048 in powers of 2 */
#define MAX_NUM_VLS 5

/*
 * Register image buffers. For each register file:
 *   *_in       - values the child starts with (filled by set_initial_values())
 *   *_expected - values the test expects to read back at the end
 *   *_out      - values copied back from the child by read_child_regs()
 * They are not static: presumably load_and_save(), which is defined
 * outside this file, accesses them by symbol name - TODO confirm.
 */
#define NUM_FPR 32
__uint128_t v_in[NUM_FPR];
__uint128_t v_expected[NUM_FPR];
__uint128_t v_out[NUM_FPR];

char z_in[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
char z_expected[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];
char z_out[__SVE_ZREGS_SIZE(ARCH_VQ_MAX)];

char p_in[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
char p_expected[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];
char p_out[__SVE_PREGS_SIZE(ARCH_VQ_MAX)];

char ffr_in[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
char ffr_expected[__SVE_PREG_SIZE(ARCH_VQ_MAX)];
char ffr_out[__SVE_PREG_SIZE(ARCH_VQ_MAX)];

char za_in[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
char za_expected[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];
char za_out[ZA_SIG_REGS_SIZE(ARCH_VQ_MAX)];

char zt_in[ZT_SIG_REG_BYTES];
char zt_expected[ZT_SIG_REG_BYTES];
char zt_out[ZT_SIG_REG_BYTES];

/*
 * Vector lengths and SVCR. The *_out values are read back from the
 * child; svcr_in/svcr_expected are copied from the test_config by
 * set_initial_values(), and svcr_expected may then be adjusted by a
 * test's set_expected_values() hook (see za_write_expected()).
 */
uint64_t sve_vl_out;
uint64_t sme_vl_out;
uint64_t svcr_in, svcr_expected, svcr_out;

/*
 * Defined outside this file and called by the child from run_child()
 * with flags saying which extensions are available. NOTE(review): going
 * by the parent's logic it loads the *_in values, hits a BRK, saves
 * into the *_out buffers and hits a second BRK - confirm against the
 * implementation.
 */
void load_and_save(int sve, int sme, int sme2, int fa64);
86
87static bool got_alarm;
88
89static void handle_alarm(int sig, siginfo_t *info, void *context)
90{
91	got_alarm = true;
92}
93
/*
 * Convert a native-endian 128-bit value to the little-endian layout used
 * for the Z register images (fpsimd_to_sve() stores V registers this
 * way). A byte swap is its own inverse, so the same function also
 * converts back.
 *
 * CONFIG_CPU_BIG_ENDIAN is a kernel Kconfig symbol that is never defined
 * for userspace builds, and u64/swab64 are kernel-internal names. Key
 * off the compiler's predefined byte-order macros so a big-endian build
 * actually takes, and can compile, the swapping path.
 */
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
static __uint128_t arm64_cpu_to_le128(__uint128_t x)
{
	/* Swap each 64-bit half and exchange the halves */
	uint64_t a = __builtin_bswap64((uint64_t)x);
	uint64_t b = __builtin_bswap64((uint64_t)(x >> 64));

	return ((__uint128_t)a << 64) | b;
}
#else
static __uint128_t arm64_cpu_to_le128(__uint128_t x)
{
	return x;
}
#endif

#define arm64_le128_to_cpu(x) arm64_cpu_to_le128(x)
110
111static bool sve_supported(void)
112{
113	return getauxval(AT_HWCAP) & HWCAP_SVE;
114}
115
116static bool sme_supported(void)
117{
118	return getauxval(AT_HWCAP2) & HWCAP2_SME;
119}
120
121static bool sme2_supported(void)
122{
123	return getauxval(AT_HWCAP2) & HWCAP2_SME2;
124}
125
126static bool fa64_supported(void)
127{
128	return getauxval(AT_HWCAP2) & HWCAP2_SME_FA64;
129}
130
/*
 * Compare @size bytes of @out against @expected. Returns true on a
 * match. On a mismatch, logs which buffer (@name) differed and, as a
 * debugging hint, whether @out came back entirely zero; returns false.
 *
 * The zero check scans @out in place rather than allocating and
 * comparing against a zeroed copy, so it cannot fail for lack of memory.
 */
static bool compare_buffer(const char *name, const void *out,
			   const void *expected, size_t size)
{
	const unsigned char *bytes = out;
	size_t i;

	if (memcmp(out, expected, size) == 0)
		return true;

	ksft_print_msg("Mismatch in %s\n", name);

	/* Did we just get zeros back? Stop at the first non-zero byte. */
	for (i = 0; i < size; i++) {
		if (bytes[i])
			break;
	}
	if (i == size)
		ksft_print_msg("%s is zero\n", name);

	return false;
}
157
/*
 * Per-run configuration: the SVE/SME vector lengths and SVCR value the
 * child starts with (*_in) and the values expected once the test's
 * ptrace writes have been applied (*_expected).
 */
struct test_config {
	int sve_vl_in;
	int sve_vl_expected;
	int sme_vl_in;
	int sme_vl_expected;
	int svcr_in;
	int svcr_expected;
};

/*
 * One test case. All hooks are optional (run_test()/run_parent() check
 * each for NULL before calling it).
 */
struct test_definition {
	/* Printed in the test result line */
	const char *name;
	/* NOTE(review): not read anywhere in the visible code - confirm use */
	bool sve_vl_change;
	/* Returns false to skip the test for this configuration */
	bool (*supported)(struct test_config *config);
	/* Adjusts the *_expected buffers after set_initial_values() */
	void (*set_expected_values)(struct test_config *config);
	/* Performs ptrace writes while the child is stopped at its first BRK */
	void (*modify_values)(pid_t child, struct test_config *test_config);
};
174
175static int vl_in(struct test_config *config)
176{
177	int vl;
178
179	if (config->svcr_in & SVCR_SM)
180		vl = config->sme_vl_in;
181	else
182		vl = config->sve_vl_in;
183
184	return vl;
185}
186
187static int vl_expected(struct test_config *config)
188{
189	int vl;
190
191	if (config->svcr_expected & SVCR_SM)
192		vl = config->sme_vl_expected;
193	else
194		vl = config->sve_vl_expected;
195
196	return vl;
197}
198
199static void run_child(struct test_config *config)
200{
201	int ret;
202
203	/* Let the parent attach to us */
204	ret = ptrace(PTRACE_TRACEME, 0, 0, 0);
205	if (ret < 0)
206		ksft_exit_fail_msg("PTRACE_TRACEME failed: %s (%d)\n",
207				   strerror(errno), errno);
208
209	/* VL setup */
210	if (sve_supported()) {
211		ret = prctl(PR_SVE_SET_VL, config->sve_vl_in);
212		if (ret != config->sve_vl_in) {
213			ksft_print_msg("Failed to set SVE VL %d: %d\n",
214				       config->sve_vl_in, ret);
215		}
216	}
217
218	if (sme_supported()) {
219		ret = prctl(PR_SME_SET_VL, config->sme_vl_in);
220		if (ret != config->sme_vl_in) {
221			ksft_print_msg("Failed to set SME VL %d: %d\n",
222				       config->sme_vl_in, ret);
223		}
224	}
225
226	/* Load values and wait for the parent */
227	load_and_save(sve_supported(), sme_supported(),
228		      sme2_supported(), fa64_supported());
229
230	exit(0);
231}
232
/*
 * Copy one buffer from the child's address space into ours with
 * process_vm_readv(). Failures and short reads are only logged; the
 * later comparison of the *_out buffers is what fails the test.
 *
 * @name:       label used in diagnostics
 * @iov_parent: destination in this process
 * @iov_child:  source in the child
 */
static void read_one_child_regs(pid_t child, const char *name,
				struct iovec *iov_parent,
				struct iovec *iov_child)
{
	size_t len = iov_parent->iov_len;
	/* process_vm_readv() returns ssize_t; don't truncate it to int */
	ssize_t ret;

	ret = process_vm_readv(child, iov_parent, 1, iov_child, 1, 0);
	if (ret == -1)
		ksft_print_msg("%s read failed: %s (%d)\n",
			       name, strerror(errno), errno);
	else if ((size_t)ret != len)
		ksft_print_msg("Short read of %s: %zd\n", name, ret);
}
247
248static void read_child_regs(pid_t child)
249{
250	struct iovec iov_parent, iov_child;
251
252	/*
253	 * Since the child fork()ed from us the buffer addresses are
254	 * the same in parent and child.
255	 */
256	iov_parent.iov_base = &v_out;
257	iov_parent.iov_len = sizeof(v_out);
258	iov_child.iov_base = &v_out;
259	iov_child.iov_len = sizeof(v_out);
260	read_one_child_regs(child, "FPSIMD", &iov_parent, &iov_child);
261
262	if (sve_supported() || sme_supported()) {
263		iov_parent.iov_base = &sve_vl_out;
264		iov_parent.iov_len = sizeof(sve_vl_out);
265		iov_child.iov_base = &sve_vl_out;
266		iov_child.iov_len = sizeof(sve_vl_out);
267		read_one_child_regs(child, "SVE VL", &iov_parent, &iov_child);
268
269		iov_parent.iov_base = &z_out;
270		iov_parent.iov_len = sizeof(z_out);
271		iov_child.iov_base = &z_out;
272		iov_child.iov_len = sizeof(z_out);
273		read_one_child_regs(child, "Z", &iov_parent, &iov_child);
274
275		iov_parent.iov_base = &p_out;
276		iov_parent.iov_len = sizeof(p_out);
277		iov_child.iov_base = &p_out;
278		iov_child.iov_len = sizeof(p_out);
279		read_one_child_regs(child, "P", &iov_parent, &iov_child);
280
281		iov_parent.iov_base = &ffr_out;
282		iov_parent.iov_len = sizeof(ffr_out);
283		iov_child.iov_base = &ffr_out;
284		iov_child.iov_len = sizeof(ffr_out);
285		read_one_child_regs(child, "FFR", &iov_parent, &iov_child);
286	}
287
288	if (sme_supported()) {
289		iov_parent.iov_base = &sme_vl_out;
290		iov_parent.iov_len = sizeof(sme_vl_out);
291		iov_child.iov_base = &sme_vl_out;
292		iov_child.iov_len = sizeof(sme_vl_out);
293		read_one_child_regs(child, "SME VL", &iov_parent, &iov_child);
294
295		iov_parent.iov_base = &svcr_out;
296		iov_parent.iov_len = sizeof(svcr_out);
297		iov_child.iov_base = &svcr_out;
298		iov_child.iov_len = sizeof(svcr_out);
299		read_one_child_regs(child, "SVCR", &iov_parent, &iov_child);
300
301		iov_parent.iov_base = &za_out;
302		iov_parent.iov_len = sizeof(za_out);
303		iov_child.iov_base = &za_out;
304		iov_child.iov_len = sizeof(za_out);
305		read_one_child_regs(child, "ZA", &iov_parent, &iov_child);
306	}
307
308	if (sme2_supported()) {
309		iov_parent.iov_base = &zt_out;
310		iov_parent.iov_len = sizeof(zt_out);
311		iov_child.iov_base = &zt_out;
312		iov_child.iov_len = sizeof(zt_out);
313		read_one_child_regs(child, "ZT", &iov_parent, &iov_child);
314	}
315}
316
317static bool continue_breakpoint(pid_t child,
318				enum __ptrace_request restart_type)
319{
320	struct user_pt_regs pt_regs;
321	struct iovec iov;
322	int ret;
323
324	/* Get PC */
325	iov.iov_base = &pt_regs;
326	iov.iov_len = sizeof(pt_regs);
327	ret = ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov);
328	if (ret < 0) {
329		ksft_print_msg("Failed to get PC: %s (%d)\n",
330			       strerror(errno), errno);
331		return false;
332	}
333
334	/* Skip over the BRK */
335	pt_regs.pc += 4;
336	ret = ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov);
337	if (ret < 0) {
338		ksft_print_msg("Failed to skip BRK: %s (%d)\n",
339			       strerror(errno), errno);
340		return false;
341	}
342
343	/* Restart */
344	ret = ptrace(restart_type, child, 0, 0);
345	if (ret < 0) {
346		ksft_print_msg("Failed to restart child: %s (%d)\n",
347			       strerror(errno), errno);
348		return false;
349	}
350
351	return true;
352}
353
354static bool check_ptrace_values_sve(pid_t child, struct test_config *config)
355{
356	struct user_sve_header *sve;
357	struct user_fpsimd_state *fpsimd;
358	struct iovec iov;
359	int ret, vq;
360	bool pass = true;
361
362	if (!sve_supported())
363		return true;
364
365	vq = __sve_vq_from_vl(config->sve_vl_in);
366
367	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
368	iov.iov_base = malloc(iov.iov_len);
369	if (!iov.iov_base) {
370		ksft_print_msg("OOM allocating %lu byte SVE buffer\n",
371			       iov.iov_len);
372		return false;
373	}
374
375	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_SVE, &iov);
376	if (ret != 0) {
377		ksft_print_msg("Failed to read initial SVE: %s (%d)\n",
378			       strerror(errno), errno);
379		pass = false;
380		goto out;
381	}
382
383	sve = iov.iov_base;
384
385	if (sve->vl != config->sve_vl_in) {
386		ksft_print_msg("Mismatch in initial SVE VL: %d != %d\n",
387			       sve->vl, config->sve_vl_in);
388		pass = false;
389	}
390
391	/* If we are in streaming mode we should just read FPSIMD */
392	if ((config->svcr_in & SVCR_SM) && (sve->flags & SVE_PT_REGS_SVE)) {
393		ksft_print_msg("NT_ARM_SVE reports SVE with PSTATE.SM\n");
394		pass = false;
395	}
396
397	if (sve->size != SVE_PT_SIZE(vq, sve->flags)) {
398		ksft_print_msg("Mismatch in SVE header size: %d != %lu\n",
399			       sve->size, SVE_PT_SIZE(vq, sve->flags));
400		pass = false;
401	}
402
403	/* The registers might be in completely different formats! */
404	if (sve->flags & SVE_PT_REGS_SVE) {
405		if (!compare_buffer("initial SVE Z",
406				    iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
407				    z_in, SVE_PT_SVE_ZREGS_SIZE(vq)))
408			pass = false;
409
410		if (!compare_buffer("initial SVE P",
411				    iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
412				    p_in, SVE_PT_SVE_PREGS_SIZE(vq)))
413			pass = false;
414
415		if (!compare_buffer("initial SVE FFR",
416				    iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
417				    ffr_in, SVE_PT_SVE_PREG_SIZE(vq)))
418			pass = false;
419	} else {
420		fpsimd = iov.iov_base + SVE_PT_FPSIMD_OFFSET;
421		if (!compare_buffer("initial V via SVE", &fpsimd->vregs[0],
422				    v_in, sizeof(v_in)))
423			pass = false;
424	}
425
426out:
427	free(iov.iov_base);
428	return pass;
429}
430
431static bool check_ptrace_values_ssve(pid_t child, struct test_config *config)
432{
433	struct user_sve_header *sve;
434	struct user_fpsimd_state *fpsimd;
435	struct iovec iov;
436	int ret, vq;
437	bool pass = true;
438
439	if (!sme_supported())
440		return true;
441
442	vq = __sve_vq_from_vl(config->sme_vl_in);
443
444	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
445	iov.iov_base = malloc(iov.iov_len);
446	if (!iov.iov_base) {
447		ksft_print_msg("OOM allocating %lu byte SSVE buffer\n",
448			       iov.iov_len);
449		return false;
450	}
451
452	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_SSVE, &iov);
453	if (ret != 0) {
454		ksft_print_msg("Failed to read initial SSVE: %s (%d)\n",
455			       strerror(errno), errno);
456		pass = false;
457		goto out;
458	}
459
460	sve = iov.iov_base;
461
462	if (sve->vl != config->sme_vl_in) {
463		ksft_print_msg("Mismatch in initial SSVE VL: %d != %d\n",
464			       sve->vl, config->sme_vl_in);
465		pass = false;
466	}
467
468	if ((config->svcr_in & SVCR_SM) && !(sve->flags & SVE_PT_REGS_SVE)) {
469		ksft_print_msg("NT_ARM_SSVE reports FPSIMD with PSTATE.SM\n");
470		pass = false;
471	}
472
473	if (sve->size != SVE_PT_SIZE(vq, sve->flags)) {
474		ksft_print_msg("Mismatch in SSVE header size: %d != %lu\n",
475			       sve->size, SVE_PT_SIZE(vq, sve->flags));
476		pass = false;
477	}
478
479	/* The registers might be in completely different formats! */
480	if (sve->flags & SVE_PT_REGS_SVE) {
481		if (!compare_buffer("initial SSVE Z",
482				    iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
483				    z_in, SVE_PT_SVE_ZREGS_SIZE(vq)))
484			pass = false;
485
486		if (!compare_buffer("initial SSVE P",
487				    iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
488				    p_in, SVE_PT_SVE_PREGS_SIZE(vq)))
489			pass = false;
490
491		if (!compare_buffer("initial SSVE FFR",
492				    iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
493				    ffr_in, SVE_PT_SVE_PREG_SIZE(vq)))
494			pass = false;
495	} else {
496		fpsimd = iov.iov_base + SVE_PT_FPSIMD_OFFSET;
497		if (!compare_buffer("initial V via SSVE",
498				    &fpsimd->vregs[0], v_in, sizeof(v_in)))
499			pass = false;
500	}
501
502out:
503	free(iov.iov_base);
504	return pass;
505}
506
507static bool check_ptrace_values_za(pid_t child, struct test_config *config)
508{
509	struct user_za_header *za;
510	struct iovec iov;
511	int ret, vq;
512	bool pass = true;
513
514	if (!sme_supported())
515		return true;
516
517	vq = __sve_vq_from_vl(config->sme_vl_in);
518
519	iov.iov_len = ZA_SIG_CONTEXT_SIZE(vq);
520	iov.iov_base = malloc(iov.iov_len);
521	if (!iov.iov_base) {
522		ksft_print_msg("OOM allocating %lu byte ZA buffer\n",
523			       iov.iov_len);
524		return false;
525	}
526
527	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_ZA, &iov);
528	if (ret != 0) {
529		ksft_print_msg("Failed to read initial ZA: %s (%d)\n",
530			       strerror(errno), errno);
531		pass = false;
532		goto out;
533	}
534
535	za = iov.iov_base;
536
537	if (za->vl != config->sme_vl_in) {
538		ksft_print_msg("Mismatch in initial SME VL: %d != %d\n",
539			       za->vl, config->sme_vl_in);
540		pass = false;
541	}
542
543	/* If PSTATE.ZA is not set we should just read the header */
544	if (config->svcr_in & SVCR_ZA) {
545		if (za->size != ZA_PT_SIZE(vq)) {
546			ksft_print_msg("Unexpected ZA ptrace read size: %d != %lu\n",
547				       za->size, ZA_PT_SIZE(vq));
548			pass = false;
549		}
550
551		if (!compare_buffer("initial ZA",
552				    iov.iov_base + ZA_PT_ZA_OFFSET,
553				    za_in, ZA_PT_ZA_SIZE(vq)))
554			pass = false;
555	} else {
556		if (za->size != sizeof(*za)) {
557			ksft_print_msg("Unexpected ZA ptrace read size: %d != %lu\n",
558				       za->size, sizeof(*za));
559			pass = false;
560		}
561	}
562
563out:
564	free(iov.iov_base);
565	return pass;
566}
567
568static bool check_ptrace_values_zt(pid_t child, struct test_config *config)
569{
570	uint8_t buf[512];
571	struct iovec iov;
572	int ret;
573
574	if (!sme2_supported())
575		return true;
576
577	iov.iov_base = &buf;
578	iov.iov_len = ZT_SIG_REG_BYTES;
579	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_ZT, &iov);
580	if (ret != 0) {
581		ksft_print_msg("Failed to read initial ZT: %s (%d)\n",
582			       strerror(errno), errno);
583		return false;
584	}
585
586	return compare_buffer("initial ZT", buf, zt_in, ZT_SIG_REG_BYTES);
587}
588
589
590static bool check_ptrace_values(pid_t child, struct test_config *config)
591{
592	bool pass = true;
593	struct user_fpsimd_state fpsimd;
594	struct iovec iov;
595	int ret;
596
597	iov.iov_base = &fpsimd;
598	iov.iov_len = sizeof(fpsimd);
599	ret = ptrace(PTRACE_GETREGSET, child, NT_PRFPREG, &iov);
600	if (ret == 0) {
601		if (!compare_buffer("initial V", &fpsimd.vregs, v_in,
602				    sizeof(v_in))) {
603			pass = false;
604		}
605	} else {
606		ksft_print_msg("Failed to read initial V: %s (%d)\n",
607			       strerror(errno), errno);
608		pass = false;
609	}
610
611	if (!check_ptrace_values_sve(child, config))
612		pass = false;
613
614	if (!check_ptrace_values_ssve(child, config))
615		pass = false;
616
617	if (!check_ptrace_values_za(child, config))
618		pass = false;
619
620	if (!check_ptrace_values_zt(child, config))
621		pass = false;
622
623	return pass;
624}
625
626static bool run_parent(pid_t child, struct test_definition *test,
627		       struct test_config *config)
628{
629	int wait_status, ret;
630	pid_t pid;
631	bool pass;
632
633	/* Initial attach */
634	while (1) {
635		pid = waitpid(child, &wait_status, 0);
636		if (pid < 0) {
637			if (errno == EINTR)
638				continue;
639			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
640					   strerror(errno), errno);
641		}
642
643		if (pid == child)
644			break;
645	}
646
647	if (WIFEXITED(wait_status)) {
648		ksft_print_msg("Child exited loading values with status %d\n",
649			       WEXITSTATUS(wait_status));
650		pass = false;
651		goto out;
652	}
653
654	if (WIFSIGNALED(wait_status)) {
655		ksft_print_msg("Child died from signal %d loading values\n",
656			       WTERMSIG(wait_status));
657		pass = false;
658		goto out;
659	}
660
661	/* Read initial values via ptrace */
662	pass = check_ptrace_values(child, config);
663
664	/* Do whatever writes we want to do */
665	if (test->modify_values)
666		test->modify_values(child, config);
667
668	if (!continue_breakpoint(child, PTRACE_CONT))
669		goto cleanup;
670
671	while (1) {
672		pid = waitpid(child, &wait_status, 0);
673		if (pid < 0) {
674			if (errno == EINTR)
675				continue;
676			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
677					   strerror(errno), errno);
678		}
679
680		if (pid == child)
681			break;
682	}
683
684	if (WIFEXITED(wait_status)) {
685		ksft_print_msg("Child exited saving values with status %d\n",
686			       WEXITSTATUS(wait_status));
687		pass = false;
688		goto out;
689	}
690
691	if (WIFSIGNALED(wait_status)) {
692		ksft_print_msg("Child died from signal %d saving values\n",
693			       WTERMSIG(wait_status));
694		pass = false;
695		goto out;
696	}
697
698	/* See what happened as a result */
699	read_child_regs(child);
700
701	if (!continue_breakpoint(child, PTRACE_DETACH))
702		goto cleanup;
703
704	/* The child should exit cleanly */
705	got_alarm = false;
706	alarm(1);
707	while (1) {
708		if (got_alarm) {
709			ksft_print_msg("Wait for child timed out\n");
710			goto cleanup;
711		}
712
713		pid = waitpid(child, &wait_status, 0);
714		if (pid < 0) {
715			if (errno == EINTR)
716				continue;
717			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
718					   strerror(errno), errno);
719		}
720
721		if (pid == child)
722			break;
723	}
724	alarm(0);
725
726	if (got_alarm) {
727		ksft_print_msg("Timed out waiting for child\n");
728		pass = false;
729		goto cleanup;
730	}
731
732	if (pid == child && WIFSIGNALED(wait_status)) {
733		ksft_print_msg("Child died from signal %d cleaning up\n",
734			       WTERMSIG(wait_status));
735		pass = false;
736		goto out;
737	}
738
739	if (pid == child && WIFEXITED(wait_status)) {
740		if (WEXITSTATUS(wait_status) != 0) {
741			ksft_print_msg("Child exited with error %d\n",
742				       WEXITSTATUS(wait_status));
743			pass = false;
744		}
745	} else {
746		ksft_print_msg("Child did not exit cleanly\n");
747		pass = false;
748		goto cleanup;
749	}
750
751	goto out;
752
753cleanup:
754	ret = kill(child, SIGKILL);
755	if (ret != 0) {
756		ksft_print_msg("kill() failed: %s (%d)\n",
757			       strerror(errno), errno);
758		return false;
759	}
760
761	while (1) {
762		pid = waitpid(child, &wait_status, 0);
763		if (pid < 0) {
764			if (errno == EINTR)
765				continue;
766			ksft_exit_fail_msg("waitpid() failed: %s (%d)\n",
767					   strerror(errno), errno);
768		}
769
770		if (pid == child)
771			break;
772	}
773
774out:
775	return pass;
776}
777
778static void fill_random(void *buf, size_t size)
779{
780	int i;
781	uint32_t *lbuf = buf;
782
783	/* random() returns a 32 bit number regardless of the size of long */
784	for (i = 0; i < size / sizeof(uint32_t); i++)
785		lbuf[i] = random();
786}
787
/*
 * Fill an FFR image for vector length @vq with a random valid value.
 *
 * Only values with a contiguous run of set bits starting at bit 0 are
 * valid for FFR. The buffer is cleared and then a random number of low
 * bits, from 0 up to and including every bit, is set.
 */
static void fill_random_ffr(void *buf, size_t vq)
{
	uint8_t *lbuf = buf;
	size_t size = __SVE_FFR_SIZE(vq);
	size_t bits, i;

	memset(buf, 0, size);

	/*
	 * size * 8 + 1 possible values, so the all-ones FFR (every bit set)
	 * can also be generated; the original modulus excluded it.
	 */
	bits = random() % (size * 8 + 1);
	for (i = 0; i < bits / 8; i++)
		lbuf[i] = 0xff;

	/* Partial top byte, unless every byte is already full */
	if (bits / 8 != size)
		lbuf[i] = (1 << (bits % 8)) - 1;
}
806
807static void fpsimd_to_sve(__uint128_t *v, char *z, int vl)
808{
809	int vq = __sve_vq_from_vl(vl);
810	int i;
811	__uint128_t *p;
812
813	if (!vl)
814		return;
815
816	for (i = 0; i < __SVE_NUM_ZREGS; i++) {
817		p = (__uint128_t *)&z[__SVE_ZREG_OFFSET(vq, i)];
818		*p = arm64_cpu_to_le128(v[i]);
819	}
820}
821
822static void set_initial_values(struct test_config *config)
823{
824	int vq = __sve_vq_from_vl(vl_in(config));
825	int sme_vq = __sve_vq_from_vl(config->sme_vl_in);
826
827	svcr_in = config->svcr_in;
828	svcr_expected = config->svcr_expected;
829	svcr_out = 0;
830
831	fill_random(&v_in, sizeof(v_in));
832	memcpy(v_expected, v_in, sizeof(v_in));
833	memset(v_out, 0, sizeof(v_out));
834
835	/* Changes will be handled in the test case */
836	if (sve_supported() || (config->svcr_in & SVCR_SM)) {
837		/* The low 128 bits of Z are shared with the V registers */
838		fill_random(&z_in, __SVE_ZREGS_SIZE(vq));
839		fpsimd_to_sve(v_in, z_in, vl_in(config));
840		memcpy(z_expected, z_in, __SVE_ZREGS_SIZE(vq));
841		memset(z_out, 0, sizeof(z_out));
842
843		fill_random(&p_in, __SVE_PREGS_SIZE(vq));
844		memcpy(p_expected, p_in, __SVE_PREGS_SIZE(vq));
845		memset(p_out, 0, sizeof(p_out));
846
847		if ((config->svcr_in & SVCR_SM) && !fa64_supported())
848			memset(ffr_in, 0, __SVE_PREG_SIZE(vq));
849		else
850			fill_random_ffr(&ffr_in, vq);
851		memcpy(ffr_expected, ffr_in, __SVE_PREG_SIZE(vq));
852		memset(ffr_out, 0, __SVE_PREG_SIZE(vq));
853	}
854
855	if (config->svcr_in & SVCR_ZA)
856		fill_random(za_in, ZA_SIG_REGS_SIZE(sme_vq));
857	else
858		memset(za_in, 0, ZA_SIG_REGS_SIZE(sme_vq));
859	if (config->svcr_expected & SVCR_ZA)
860		memcpy(za_expected, za_in, ZA_SIG_REGS_SIZE(sme_vq));
861	else
862		memset(za_expected, 0, ZA_SIG_REGS_SIZE(sme_vq));
863	if (sme_supported())
864		memset(za_out, 0, sizeof(za_out));
865
866	if (sme2_supported()) {
867		if (config->svcr_in & SVCR_ZA)
868			fill_random(zt_in, ZT_SIG_REG_BYTES);
869		else
870			memset(zt_in, 0, ZT_SIG_REG_BYTES);
871		if (config->svcr_expected & SVCR_ZA)
872			memcpy(zt_expected, zt_in, ZT_SIG_REG_BYTES);
873		else
874			memset(zt_expected, 0, ZT_SIG_REG_BYTES);
875		memset(zt_out, 0, sizeof(zt_out));
876	}
877}
878
879static bool check_memory_values(struct test_config *config)
880{
881	bool pass = true;
882	int vq, sme_vq;
883
884	if (!compare_buffer("saved V", v_out, v_expected, sizeof(v_out)))
885		pass = false;
886
887	vq = __sve_vq_from_vl(vl_expected(config));
888	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
889
890	if (svcr_out != svcr_expected) {
891		ksft_print_msg("Mismatch in saved SVCR %lx != %lx\n",
892			       svcr_out, svcr_expected);
893		pass = false;
894	}
895
896	if (sve_vl_out != config->sve_vl_expected) {
897		ksft_print_msg("Mismatch in SVE VL: %ld != %d\n",
898			       sve_vl_out, config->sve_vl_expected);
899		pass = false;
900	}
901
902	if (sme_vl_out != config->sme_vl_expected) {
903		ksft_print_msg("Mismatch in SME VL: %ld != %d\n",
904			       sme_vl_out, config->sme_vl_expected);
905		pass = false;
906	}
907
908	if (!compare_buffer("saved Z", z_out, z_expected,
909			    __SVE_ZREGS_SIZE(vq)))
910		pass = false;
911
912	if (!compare_buffer("saved P", p_out, p_expected,
913			    __SVE_PREGS_SIZE(vq)))
914		pass = false;
915
916	if (!compare_buffer("saved FFR", ffr_out, ffr_expected,
917			    __SVE_PREG_SIZE(vq)))
918		pass = false;
919
920	if (!compare_buffer("saved ZA", za_out, za_expected,
921			    ZA_PT_ZA_SIZE(sme_vq)))
922		pass = false;
923
924	if (!compare_buffer("saved ZT", zt_out, zt_expected, ZT_SIG_REG_BYTES))
925		pass = false;
926
927	return pass;
928}
929
930static bool sve_sme_same(struct test_config *config)
931{
932	if (config->sve_vl_in != config->sve_vl_expected)
933		return false;
934
935	if (config->sme_vl_in != config->sme_vl_expected)
936		return false;
937
938	if (config->svcr_in != config->svcr_expected)
939		return false;
940
941	return true;
942}
943
944static bool sve_write_supported(struct test_config *config)
945{
946	if (!sve_supported() && !sme_supported())
947		return false;
948
949	if ((config->svcr_in & SVCR_ZA) != (config->svcr_expected & SVCR_ZA))
950		return false;
951
952	if (config->svcr_expected & SVCR_SM) {
953		if (config->sve_vl_in != config->sve_vl_expected) {
954			return false;
955		}
956
957		/* Changing the SME VL disables ZA */
958		if ((config->svcr_expected & SVCR_ZA) &&
959		    (config->sme_vl_in != config->sme_vl_expected)) {
960			return false;
961		}
962	} else {
963		if (config->sme_vl_in != config->sme_vl_expected) {
964			return false;
965		}
966	}
967
968	return true;
969}
970
971static void fpsimd_write_expected(struct test_config *config)
972{
973	int vl;
974
975	fill_random(&v_expected, sizeof(v_expected));
976
977	/* The SVE registers are flushed by a FPSIMD write */
978	vl = vl_expected(config);
979
980	memset(z_expected, 0, __SVE_ZREGS_SIZE(__sve_vq_from_vl(vl)));
981	memset(p_expected, 0, __SVE_PREGS_SIZE(__sve_vq_from_vl(vl)));
982	memset(ffr_expected, 0, __SVE_PREG_SIZE(__sve_vq_from_vl(vl)));
983
984	fpsimd_to_sve(v_expected, z_expected, vl);
985}
986
987static void fpsimd_write(pid_t child, struct test_config *test_config)
988{
989	struct user_fpsimd_state fpsimd;
990	struct iovec iov;
991	int ret;
992
993	memset(&fpsimd, 0, sizeof(fpsimd));
994	memcpy(&fpsimd.vregs, v_expected, sizeof(v_expected));
995
996	iov.iov_base = &fpsimd;
997	iov.iov_len = sizeof(fpsimd);
998	ret = ptrace(PTRACE_SETREGSET, child, NT_PRFPREG, &iov);
999	if (ret == -1)
1000		ksft_print_msg("FPSIMD set failed: (%s) %d\n",
1001			       strerror(errno), errno);
1002}
1003
1004static void sve_write_expected(struct test_config *config)
1005{
1006	int vl = vl_expected(config);
1007	int sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1008
1009	fill_random(z_expected, __SVE_ZREGS_SIZE(__sve_vq_from_vl(vl)));
1010	fill_random(p_expected, __SVE_PREGS_SIZE(__sve_vq_from_vl(vl)));
1011
1012	if ((svcr_expected & SVCR_SM) && !fa64_supported())
1013		memset(ffr_expected, 0, __SVE_PREG_SIZE(sme_vq));
1014	else
1015		fill_random_ffr(ffr_expected, __sve_vq_from_vl(vl));
1016
1017	/* Share the low bits of Z with V */
1018	fill_random(&v_expected, sizeof(v_expected));
1019	fpsimd_to_sve(v_expected, z_expected, vl);
1020
1021	if (config->sme_vl_in != config->sme_vl_expected) {
1022		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1023		memset(zt_expected, 0, sizeof(zt_expected));
1024	}
1025}
1026
1027static void sve_write(pid_t child, struct test_config *config)
1028{
1029	struct user_sve_header *sve;
1030	struct iovec iov;
1031	int ret, vl, vq, regset;
1032
1033	vl = vl_expected(config);
1034	vq = __sve_vq_from_vl(vl);
1035
1036	iov.iov_len = SVE_PT_SVE_OFFSET + SVE_PT_SVE_SIZE(vq, SVE_PT_REGS_SVE);
1037	iov.iov_base = malloc(iov.iov_len);
1038	if (!iov.iov_base) {
1039		ksft_print_msg("Failed allocating %lu byte SVE write buffer\n",
1040			       iov.iov_len);
1041		return;
1042	}
1043	memset(iov.iov_base, 0, iov.iov_len);
1044
1045	sve = iov.iov_base;
1046	sve->size = iov.iov_len;
1047	sve->flags = SVE_PT_REGS_SVE;
1048	sve->vl = vl;
1049
1050	memcpy(iov.iov_base + SVE_PT_SVE_ZREG_OFFSET(vq, 0),
1051	       z_expected, SVE_PT_SVE_ZREGS_SIZE(vq));
1052	memcpy(iov.iov_base + SVE_PT_SVE_PREG_OFFSET(vq, 0),
1053	       p_expected, SVE_PT_SVE_PREGS_SIZE(vq));
1054	memcpy(iov.iov_base + SVE_PT_SVE_FFR_OFFSET(vq),
1055	       ffr_expected, SVE_PT_SVE_PREG_SIZE(vq));
1056
1057	if (svcr_expected & SVCR_SM)
1058		regset = NT_ARM_SSVE;
1059	else
1060		regset = NT_ARM_SVE;
1061
1062	ret = ptrace(PTRACE_SETREGSET, child, regset, &iov);
1063	if (ret != 0)
1064		ksft_print_msg("Failed to write SVE: %s (%d)\n",
1065			       strerror(errno), errno);
1066
1067	free(iov.iov_base);
1068}
1069
1070static bool za_write_supported(struct test_config *config)
1071{
1072	if (config->svcr_expected & SVCR_SM) {
1073		if (!(config->svcr_in & SVCR_SM))
1074			return false;
1075
1076		/* Changing the SME VL exits streaming mode */
1077		if (config->sme_vl_in != config->sme_vl_expected) {
1078			return false;
1079		}
1080	}
1081
1082	/* Can't disable SM outside a VL change */
1083	if ((config->svcr_in & SVCR_SM) &&
1084	    !(config->svcr_expected & SVCR_SM))
1085		return false;
1086
1087	return true;
1088}
1089
1090static void za_write_expected(struct test_config *config)
1091{
1092	int sme_vq, sve_vq;
1093
1094	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1095
1096	if (config->svcr_expected & SVCR_ZA) {
1097		fill_random(za_expected, ZA_PT_ZA_SIZE(sme_vq));
1098	} else {
1099		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1100		memset(zt_expected, 0, sizeof(zt_expected));
1101	}
1102
1103	/* Changing the SME VL flushes ZT, SVE state and exits SM */
1104	if (config->sme_vl_in != config->sme_vl_expected) {
1105		svcr_expected &= ~SVCR_SM;
1106
1107		sve_vq = __sve_vq_from_vl(vl_expected(config));
1108		memset(z_expected, 0, __SVE_ZREGS_SIZE(sve_vq));
1109		memset(p_expected, 0, __SVE_PREGS_SIZE(sve_vq));
1110		memset(ffr_expected, 0, __SVE_PREG_SIZE(sve_vq));
1111		memset(zt_expected, 0, sizeof(zt_expected));
1112
1113		fpsimd_to_sve(v_expected, z_expected, vl_expected(config));
1114	}
1115}
1116
1117static void za_write(pid_t child, struct test_config *config)
1118{
1119	struct user_za_header *za;
1120	struct iovec iov;
1121	int ret, vq;
1122
1123	vq = __sve_vq_from_vl(config->sme_vl_expected);
1124
1125	if (config->svcr_expected & SVCR_ZA)
1126		iov.iov_len = ZA_PT_SIZE(vq);
1127	else
1128		iov.iov_len = sizeof(*za);
1129	iov.iov_base = malloc(iov.iov_len);
1130	if (!iov.iov_base) {
1131		ksft_print_msg("Failed allocating %lu byte ZA write buffer\n",
1132			       iov.iov_len);
1133		return;
1134	}
1135	memset(iov.iov_base, 0, iov.iov_len);
1136
1137	za = iov.iov_base;
1138	za->size = iov.iov_len;
1139	za->vl = config->sme_vl_expected;
1140	if (config->svcr_expected & SVCR_ZA)
1141		memcpy(iov.iov_base + ZA_PT_ZA_OFFSET, za_expected,
1142		       ZA_PT_ZA_SIZE(vq));
1143
1144	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_ZA, &iov);
1145	if (ret != 0)
1146		ksft_print_msg("Failed to write ZA: %s (%d)\n",
1147			       strerror(errno), errno);
1148
1149	free(iov.iov_base);
1150}
1151
1152static bool zt_write_supported(struct test_config *config)
1153{
1154	if (!sme2_supported())
1155		return false;
1156	if (config->sme_vl_in != config->sme_vl_expected)
1157		return false;
1158	if (!(config->svcr_expected & SVCR_ZA))
1159		return false;
1160	if ((config->svcr_in & SVCR_SM) != (config->svcr_expected & SVCR_SM))
1161		return false;
1162
1163	return true;
1164}
1165
1166static void zt_write_expected(struct test_config *config)
1167{
1168	int sme_vq;
1169
1170	sme_vq = __sve_vq_from_vl(config->sme_vl_expected);
1171
1172	if (config->svcr_expected & SVCR_ZA) {
1173		fill_random(zt_expected, sizeof(zt_expected));
1174	} else {
1175		memset(za_expected, 0, ZA_PT_ZA_SIZE(sme_vq));
1176		memset(zt_expected, 0, sizeof(zt_expected));
1177	}
1178}
1179
1180static void zt_write(pid_t child, struct test_config *config)
1181{
1182	struct iovec iov;
1183	int ret;
1184
1185	iov.iov_len = ZT_SIG_REG_BYTES;
1186	iov.iov_base = zt_expected;
1187	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_ZT, &iov);
1188	if (ret != 0)
1189		ksft_print_msg("Failed to write ZT: %s (%d)\n",
1190			       strerror(errno), errno);
1191}
1192
1193/* Actually run a test */
1194static void run_test(struct test_definition *test, struct test_config *config)
1195{
1196	pid_t child;
1197	char name[1024];
1198	bool pass;
1199
1200	if (sve_supported() && sme_supported())
1201		snprintf(name, sizeof(name), "%s, SVE %d->%d, SME %d/%x->%d/%x",
1202			 test->name,
1203			 config->sve_vl_in, config->sve_vl_expected,
1204			 config->sme_vl_in, config->svcr_in,
1205			 config->sme_vl_expected, config->svcr_expected);
1206	else if (sve_supported())
1207		snprintf(name, sizeof(name), "%s, SVE %d->%d", test->name,
1208			 config->sve_vl_in, config->sve_vl_expected);
1209	else if (sme_supported())
1210		snprintf(name, sizeof(name), "%s, SME %d/%x->%d/%x",
1211			 test->name,
1212			 config->sme_vl_in, config->svcr_in,
1213			 config->sme_vl_expected, config->svcr_expected);
1214	else
1215		snprintf(name, sizeof(name), "%s", test->name);
1216
1217	if (test->supported && !test->supported(config)) {
1218		ksft_test_result_skip("%s\n", name);
1219		return;
1220	}
1221
1222	set_initial_values(config);
1223
1224	if (test->set_expected_values)
1225		test->set_expected_values(config);
1226
1227	child = fork();
1228	if (child < 0)
1229		ksft_exit_fail_msg("fork() failed: %s (%d)\n",
1230				   strerror(errno), errno);
1231	/* run_child() never returns */
1232	if (child == 0)
1233		run_child(config);
1234
1235	pass = run_parent(child, test, config);
1236	if (!check_memory_values(config))
1237		pass = false;
1238
1239	ksft_test_result(pass, "%s\n", name);
1240}
1241
1242static void run_tests(struct test_definition defs[], int count,
1243		      struct test_config *config)
1244{
1245	int i;
1246
1247	for (i = 0; i < count; i++)
1248		run_test(&defs[i], config);
1249}
1250
1251static struct test_definition base_test_defs[] = {
1252	{
1253		.name = "No writes",
1254		.supported = sve_sme_same,
1255	},
1256	{
1257		.name = "FPSIMD write",
1258		.supported = sve_sme_same,
1259		.set_expected_values = fpsimd_write_expected,
1260		.modify_values = fpsimd_write,
1261	},
1262};
1263
1264static struct test_definition sve_test_defs[] = {
1265	{
1266		.name = "SVE write",
1267		.supported = sve_write_supported,
1268		.set_expected_values = sve_write_expected,
1269		.modify_values = sve_write,
1270	},
1271};
1272
1273static struct test_definition za_test_defs[] = {
1274	{
1275		.name = "ZA write",
1276		.supported = za_write_supported,
1277		.set_expected_values = za_write_expected,
1278		.modify_values = za_write,
1279	},
1280};
1281
1282static struct test_definition zt_test_defs[] = {
1283	{
1284		.name = "ZT write",
1285		.supported = zt_write_supported,
1286		.set_expected_values = zt_write_expected,
1287		.modify_values = zt_write,
1288	},
1289};
1290
1291static int sve_vls[MAX_NUM_VLS], sme_vls[MAX_NUM_VLS];
1292static int sve_vl_count, sme_vl_count;
1293
1294static void probe_vls(const char *name, int vls[], int *vl_count, int set_vl)
1295{
1296	unsigned int vq;
1297	int vl;
1298
1299	*vl_count = 0;
1300
1301	for (vq = ARCH_VQ_MAX; vq > 0; vq /= 2) {
1302		vl = prctl(set_vl, vq * 16);
1303		if (vl == -1)
1304			ksft_exit_fail_msg("SET_VL failed: %s (%d)\n",
1305					   strerror(errno), errno);
1306
1307		vl &= PR_SVE_VL_LEN_MASK;
1308
1309		if (*vl_count && (vl == vls[*vl_count - 1]))
1310			break;
1311
1312		vq = sve_vq_from_vl(vl);
1313
1314		vls[*vl_count] = vl;
1315		*vl_count += 1;
1316	}
1317
1318	if (*vl_count > 2) {
1319		/* Just use the minimum and maximum */
1320		vls[1] = vls[*vl_count - 1];
1321		ksft_print_msg("%d %s VLs, using %d and %d\n",
1322			       *vl_count, name, vls[0], vls[1]);
1323		*vl_count = 2;
1324	} else {
1325		ksft_print_msg("%d %s VLs\n", *vl_count, name);
1326	}
1327}
1328
1329static struct {
1330	int svcr_in, svcr_expected;
1331} svcr_combinations[] = {
1332	{ .svcr_in = 0, .svcr_expected = 0, },
1333	{ .svcr_in = 0, .svcr_expected = SVCR_SM, },
1334	{ .svcr_in = 0, .svcr_expected = SVCR_ZA, },
1335	/* Can't enable both SM and ZA with a single ptrace write */
1336
1337	{ .svcr_in = SVCR_SM, .svcr_expected = 0, },
1338	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_SM, },
1339	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_ZA, },
1340	{ .svcr_in = SVCR_SM, .svcr_expected = SVCR_SM | SVCR_ZA, },
1341
1342	{ .svcr_in = SVCR_ZA, .svcr_expected = 0, },
1343	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_SM, },
1344	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_ZA, },
1345	{ .svcr_in = SVCR_ZA, .svcr_expected = SVCR_SM | SVCR_ZA, },
1346
1347	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = 0, },
1348	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_SM, },
1349	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_ZA, },
1350	{ .svcr_in = SVCR_SM | SVCR_ZA, .svcr_expected = SVCR_SM | SVCR_ZA, },
1351};
1352
1353static void run_sve_tests(void)
1354{
1355	struct test_config test_config;
1356	int i, j;
1357
1358	if (!sve_supported())
1359		return;
1360
1361	test_config.sme_vl_in = sme_vls[0];
1362	test_config.sme_vl_expected = sme_vls[0];
1363	test_config.svcr_in = 0;
1364	test_config.svcr_expected = 0;
1365
1366	for (i = 0; i < sve_vl_count; i++) {
1367		test_config.sve_vl_in = sve_vls[i];
1368
1369		for (j = 0; j < sve_vl_count; j++) {
1370			test_config.sve_vl_expected = sve_vls[j];
1371
1372			run_tests(base_test_defs,
1373				  ARRAY_SIZE(base_test_defs),
1374				  &test_config);
1375			if (sve_supported())
1376				run_tests(sve_test_defs,
1377					  ARRAY_SIZE(sve_test_defs),
1378					  &test_config);
1379		}
1380	}
1381
1382}
1383
1384static void run_sme_tests(void)
1385{
1386	struct test_config test_config;
1387	int i, j, k;
1388
1389	if (!sme_supported())
1390		return;
1391
1392	test_config.sve_vl_in = sve_vls[0];
1393	test_config.sve_vl_expected = sve_vls[0];
1394
1395	/*
1396	 * Every SME VL/SVCR combination
1397	 */
1398	for (i = 0; i < sme_vl_count; i++) {
1399		test_config.sme_vl_in = sme_vls[i];
1400
1401		for (j = 0; j < sme_vl_count; j++) {
1402			test_config.sme_vl_expected = sme_vls[j];
1403
1404			for (k = 0; k < ARRAY_SIZE(svcr_combinations); k++) {
1405				test_config.svcr_in = svcr_combinations[k].svcr_in;
1406				test_config.svcr_expected = svcr_combinations[k].svcr_expected;
1407
1408				run_tests(base_test_defs,
1409					  ARRAY_SIZE(base_test_defs),
1410					  &test_config);
1411				run_tests(sve_test_defs,
1412					  ARRAY_SIZE(sve_test_defs),
1413					  &test_config);
1414				run_tests(za_test_defs,
1415					  ARRAY_SIZE(za_test_defs),
1416					  &test_config);
1417
1418				if (sme2_supported())
1419					run_tests(zt_test_defs,
1420						  ARRAY_SIZE(zt_test_defs),
1421						  &test_config);
1422			}
1423		}
1424	}
1425}
1426
1427int main(void)
1428{
1429	struct test_config test_config;
1430	struct sigaction sa;
1431	int tests, ret, tmp;
1432
1433	srandom(getpid());
1434
1435	ksft_print_header();
1436
1437	if (sve_supported()) {
1438		probe_vls("SVE", sve_vls, &sve_vl_count, PR_SVE_SET_VL);
1439
1440		tests = ARRAY_SIZE(base_test_defs) +
1441			ARRAY_SIZE(sve_test_defs);
1442		tests *= sve_vl_count * sve_vl_count;
1443	} else {
1444		/* Only run the FPSIMD tests */
1445		sve_vl_count = 1;
1446		tests = ARRAY_SIZE(base_test_defs);
1447	}
1448
1449	if (sme_supported()) {
1450		probe_vls("SME", sme_vls, &sme_vl_count, PR_SME_SET_VL);
1451
1452		tmp = ARRAY_SIZE(base_test_defs) + ARRAY_SIZE(sve_test_defs)
1453			+ ARRAY_SIZE(za_test_defs);
1454
1455		if (sme2_supported())
1456			tmp += ARRAY_SIZE(zt_test_defs);
1457
1458		tmp *= sme_vl_count * sme_vl_count;
1459		tmp *= ARRAY_SIZE(svcr_combinations);
1460		tests += tmp;
1461	} else {
1462		sme_vl_count = 1;
1463	}
1464
1465	if (sme2_supported())
1466		ksft_print_msg("SME2 supported\n");
1467
1468	if (fa64_supported())
1469		ksft_print_msg("FA64 supported\n");
1470
1471	ksft_set_plan(tests);
1472
1473	/* Get signal handers ready before we start any children */
1474	memset(&sa, 0, sizeof(sa));
1475	sa.sa_sigaction = handle_alarm;
1476	sa.sa_flags = SA_RESTART | SA_SIGINFO;
1477	sigemptyset(&sa.sa_mask);
1478	ret = sigaction(SIGALRM, &sa, NULL);
1479	if (ret < 0)
1480		ksft_print_msg("Failed to install SIGALRM handler: %s (%d)\n",
1481			       strerror(errno), errno);
1482
1483	/*
1484	 * Run the test set if there is no SVE or SME, with those we
1485	 * have to pick a VL for each run.
1486	 */
1487	if (!sve_supported()) {
1488		test_config.sve_vl_in = 0;
1489		test_config.sve_vl_expected = 0;
1490		test_config.sme_vl_in = 0;
1491		test_config.sme_vl_expected = 0;
1492		test_config.svcr_in = 0;
1493		test_config.svcr_expected = 0;
1494
1495		run_tests(base_test_defs, ARRAY_SIZE(base_test_defs),
1496			  &test_config);
1497	}
1498
1499	run_sve_tests();
1500	run_sme_tests();
1501
1502	ksft_finished();
1503}
1504