1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Copyright (C) 2021 ARM Limited.
4 */
5
6#include <errno.h>
7#include <stdbool.h>
8#include <stddef.h>
9#include <stdio.h>
10#include <stdlib.h>
11#include <string.h>
12#include <unistd.h>
13#include <sys/auxv.h>
14#include <sys/prctl.h>
15#include <asm/hwcap.h>
16#include <asm/sigcontext.h>
17#include <asm/unistd.h>
18
19#include "../../kselftest.h"
20
21#include "syscall-abi.h"
22
/*
 * The kernel defines a much larger SVE_VQ_MAX than is expressible in
 * the architecture. Using it creates a *lot* of overhead filling the
 * buffers (especially ZA) on emulated platforms, so use the actual
 * architectural maximum instead.
 */
#define ARCH_SVE_VQ_MAX 16

/*
 * SME VL handed to tests that do not iterate over SME VLs. It stays 0
 * when the system has no SME (sme_count_vls() returns early).
 */
static int default_sme_vl;

/* Number of valid entries in sve_vls[] / sme_vls[]. */
static int sve_vl_count, sme_vl_count;

/* Vector lengths in bytes, as enumerated by sve_count_vls()/sme_count_vls(). */
static unsigned int sve_vls[ARCH_SVE_VQ_MAX];
static unsigned int sme_vls[ARCH_SVE_VQ_MAX];

/*
 * Performs the syscall with the register images prepared by the setup
 * functions. It is not defined in this file (presumably an assembly
 * helper; confirm against the build).
 */
extern void do_syscall(int sve_vl, int sme_vl);
39
/*
 * Fill @buf with pseudo-random data from random().
 *
 * POSIX random() returns values in the range 0..2^31-1, so a single call
 * only supplies 31 random bits. Two calls are combined per 32-bit word so
 * that every bit of the register images can be set, including bit 31 of
 * each word. Otherwise corruption of those bits would go unnoticed.
 *
 * Only whole 32-bit words are written. Any trailing (size % 4) bytes are
 * left untouched.
 */
static void fill_random(void *buf, size_t size)
{
	uint32_t *lbuf = buf;
	size_t words = size / sizeof(uint32_t);
	size_t i;

	for (i = 0; i < words; i++) {
		/* Separate statements keep the call order well defined. */
		uint32_t hi = (uint32_t)random();

		lbuf[i] = (hi << 1) ^ (uint32_t)random();
	}
}
49
/*
 * The register checks are repeated for each of these syscalls, in case
 * different syscalls expose different behaviour.
 */
static struct syscall_cfg {
	int syscall_nr;		/* stored into gpr_in[8] (x8) by setup_gpr() */
	const char *name;	/* label used in every test/diagnostic message */
} syscalls[] = {
	{ .syscall_nr = __NR_getpid,		.name = "getpid()" },
	{ .syscall_nr = __NR_sched_yield,	.name = "sched_yield()" },
};
61
62#define NUM_GPR 31
63uint64_t gpr_in[NUM_GPR];
64uint64_t gpr_out[NUM_GPR];
65
66static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
67		      uint64_t svcr)
68{
69	fill_random(gpr_in, sizeof(gpr_in));
70	gpr_in[8] = cfg->syscall_nr;
71	memset(gpr_out, 0, sizeof(gpr_out));
72}
73
74static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
75{
76	int errors = 0;
77	int i;
78
79	/*
80	 * GPR x0-x7 may be clobbered, and all others should be preserved.
81	 */
82	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
83		if (gpr_in[i] != gpr_out[i]) {
84			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
85				       cfg->name, sve_vl, i,
86				       gpr_in[i], gpr_out[i]);
87			errors++;
88		}
89	}
90
91	return errors;
92}
93
#define NUM_FPR 32
/*
 * Each 128-bit V register is held as two 64-bit words. fpr_zero is never
 * written in this file and serves as an all-zero reference for check_fpr().
 */
uint64_t fpr_in[NUM_FPR * 2];
uint64_t fpr_out[NUM_FPR * 2];
uint64_t fpr_zero[NUM_FPR * 2];

/* Clear the FPSIMD output image and randomise the input image. */
static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	memset(fpr_out, 0, sizeof fpr_out);
	fill_random(fpr_in, sizeof fpr_in);
}
105
106static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
107		     uint64_t svcr)
108{
109	int errors = 0;
110	int i;
111
112	if (!sve_vl && !(svcr & SVCR_SM_MASK)) {
113		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
114			if (fpr_in[i] != fpr_out[i]) {
115				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
116					       cfg->name,
117					       i / 2, i % 2,
118					       fpr_in[i], fpr_out[i]);
119				errors++;
120			}
121		}
122	}
123
124	/*
125	 * In streaming mode the whole register set should be cleared
126	 * by the transition out of streaming mode.
127	 */
128	if (svcr & SVCR_SM_MASK) {
129		if (memcmp(fpr_zero, fpr_out, sizeof(fpr_out)) != 0) {
130			ksft_print_msg("%s FPSIMD registers non-zero exiting SM\n",
131				       cfg->name);
132			errors++;
133		}
134	}
135
136	return errors;
137}
138
139#define SVE_Z_SHARED_BYTES (128 / 8)
140
141static uint8_t z_zero[__SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
142uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
143uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
144
145static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
146		    uint64_t svcr)
147{
148	fill_random(z_in, sizeof(z_in));
149	fill_random(z_out, sizeof(z_out));
150}
151
152static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
153		   uint64_t svcr)
154{
155	size_t reg_size = sve_vl;
156	int errors = 0;
157	int i;
158
159	if (!sve_vl)
160		return 0;
161
162	for (i = 0; i < SVE_NUM_ZREGS; i++) {
163		uint8_t *in = &z_in[reg_size * i];
164		uint8_t *out = &z_out[reg_size * i];
165
166		if (svcr & SVCR_SM_MASK) {
167			/*
168			 * In streaming mode the whole register should
169			 * be cleared by the transition out of
170			 * streaming mode.
171			 */
172			if (memcmp(z_zero, out, reg_size) != 0) {
173				ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
174					       cfg->name, sve_vl, i);
175				errors++;
176			}
177		} else {
178			/*
179			 * For standard SVE the low 128 bits should be
180			 * preserved and any additional bits cleared.
181			 */
182			if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
183				ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
184					       cfg->name, sve_vl, i);
185				errors++;
186			}
187
188			if (reg_size > SVE_Z_SHARED_BYTES &&
189			    (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
190				    reg_size - SVE_Z_SHARED_BYTES) != 0)) {
191				ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
192					       cfg->name, sve_vl, i);
193				errors++;
194			}
195		}
196	}
197
198	return errors;
199}
200
201uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
202uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
203
204static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
205		    uint64_t svcr)
206{
207	fill_random(p_in, sizeof(p_in));
208	fill_random(p_out, sizeof(p_out));
209}
210
211static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
212		   uint64_t svcr)
213{
214	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
215
216	int errors = 0;
217	int i;
218
219	if (!sve_vl)
220		return 0;
221
222	/* After a syscall the P registers should be zeroed */
223	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
224		if (p_out[i])
225			errors++;
226	if (errors)
227		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
228			       cfg->name, sve_vl);
229
230	return errors;
231}
232
233uint8_t ffr_in[__SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
234uint8_t ffr_out[__SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
235
236static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
237		      uint64_t svcr)
238{
239	/*
240	 * If we are in streaming mode and do not have FA64 then FFR
241	 * is unavailable.
242	 */
243	if ((svcr & SVCR_SM_MASK) &&
244	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
245		memset(&ffr_in, 0, sizeof(ffr_in));
246		return;
247	}
248
249	/*
250	 * It is only valid to set a contiguous set of bits starting
251	 * at 0.  For now since we're expecting this to be cleared by
252	 * a syscall just set all bits.
253	 */
254	memset(ffr_in, 0xff, sizeof(ffr_in));
255	fill_random(ffr_out, sizeof(ffr_out));
256}
257
258static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
259		     uint64_t svcr)
260{
261	size_t reg_size = sve_vq_from_vl(sve_vl) * 2;  /* 1 bit per VL byte */
262	int errors = 0;
263	int i;
264
265	if (!sve_vl)
266		return 0;
267
268	if ((svcr & SVCR_SM_MASK) &&
269	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
270		return 0;
271
272	/* After a syscall FFR should be zeroed */
273	for (i = 0; i < reg_size; i++)
274		if (ffr_out[i])
275			errors++;
276	if (errors)
277		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
278			       cfg->name, sve_vl);
279
280	return errors;
281}
282
/*
 * SVCR requested for the test (svcr_in) and observed after the syscall
 * (svcr_out). svcr_out is not written in this file; presumably the
 * do_syscall() implementation stores it. Verify against that code.
 */
uint64_t svcr_in;
uint64_t svcr_out;

/* Record the SVCR mode (SM and/or ZA bits) this test should run with. */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	svcr_in = svcr;
}
290
291static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
292		      uint64_t svcr)
293{
294	int errors = 0;
295
296	if (svcr_out & SVCR_SM_MASK) {
297		ksft_print_msg("%s Still in SM, SVCR %llx\n",
298			       cfg->name, svcr_out);
299		errors++;
300	}
301
302	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
303		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
304			       cfg->name, svcr_in, svcr_out);
305		errors++;
306	}
307
308	return errors;
309}
310
311uint8_t za_in[ZA_SIG_REGS_SIZE(ARCH_SVE_VQ_MAX)];
312uint8_t za_out[ZA_SIG_REGS_SIZE(ARCH_SVE_VQ_MAX)];
313
314static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
315		     uint64_t svcr)
316{
317	fill_random(za_in, sizeof(za_in));
318	memset(za_out, 0, sizeof(za_out));
319}
320
321static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
322		    uint64_t svcr)
323{
324	size_t reg_size = sme_vl * sme_vl;
325	int errors = 0;
326
327	if (!(svcr & SVCR_ZA_MASK))
328		return 0;
329
330	if (memcmp(za_in, za_out, reg_size) != 0) {
331		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
332		errors++;
333	}
334
335	return errors;
336}
337
338uint8_t zt_in[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
339uint8_t zt_out[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
340
341static void setup_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
342		     uint64_t svcr)
343{
344	fill_random(zt_in, sizeof(zt_in));
345	memset(zt_out, 0, sizeof(zt_out));
346}
347
348static int check_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
349		    uint64_t svcr)
350{
351	int errors = 0;
352
353	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2))
354		return 0;
355
356	if (!(svcr & SVCR_ZA_MASK))
357		return 0;
358
359	if (memcmp(zt_in, zt_out, sizeof(zt_in)) != 0) {
360		ksft_print_msg("SME VL %d ZT does not match\n", sme_vl);
361		errors++;
362	}
363
364	return errors;
365}
366
367typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
368			 uint64_t svcr);
369typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
370			uint64_t svcr);
371
372/*
373 * Each set of registers has a setup function which is called before
374 * the syscall to fill values in a global variable for loading by the
375 * test code and a check function which validates that the results are
376 * as expected.  Vector lengths are passed everywhere, a vector length
377 * of 0 should be treated as do not test.
378 */
379static struct {
380	setup_fn setup;
381	check_fn check;
382} regset[] = {
383	{ setup_gpr, check_gpr },
384	{ setup_fpr, check_fpr },
385	{ setup_z, check_z },
386	{ setup_p, check_p },
387	{ setup_ffr, check_ffr },
388	{ setup_svcr, check_svcr },
389	{ setup_za, check_za },
390	{ setup_zt, check_zt },
391};
392
393static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
394		    uint64_t svcr)
395{
396	int errors = 0;
397	int i;
398
399	for (i = 0; i < ARRAY_SIZE(regset); i++)
400		regset[i].setup(cfg, sve_vl, sme_vl, svcr);
401
402	do_syscall(sve_vl, sme_vl);
403
404	for (i = 0; i < ARRAY_SIZE(regset); i++)
405		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
406
407	return errors == 0;
408}
409
410static void test_one_syscall(struct syscall_cfg *cfg)
411{
412	int sve, sme;
413	int ret;
414
415	/* FPSIMD only case */
416	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
417			 "%s FPSIMD\n", cfg->name);
418
419	for (sve = 0; sve < sve_vl_count; sve++) {
420		ret = prctl(PR_SVE_SET_VL, sve_vls[sve]);
421		if (ret == -1)
422			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
423					   strerror(errno), errno);
424
425		ksft_test_result(do_test(cfg, sve_vls[sve], default_sme_vl, 0),
426				 "%s SVE VL %d\n", cfg->name, sve_vls[sve]);
427
428		for (sme = 0; sme < sme_vl_count; sme++) {
429			ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
430			if (ret == -1)
431				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
432						   strerror(errno), errno);
433
434			ksft_test_result(do_test(cfg, sve_vls[sve],
435						 sme_vls[sme],
436						 SVCR_ZA_MASK | SVCR_SM_MASK),
437					 "%s SVE VL %d/SME VL %d SM+ZA\n",
438					 cfg->name, sve_vls[sve],
439					 sme_vls[sme]);
440			ksft_test_result(do_test(cfg, sve_vls[sve],
441						 sme_vls[sme], SVCR_SM_MASK),
442					 "%s SVE VL %d/SME VL %d SM\n",
443					 cfg->name, sve_vls[sve],
444					 sme_vls[sme]);
445			ksft_test_result(do_test(cfg, sve_vls[sve],
446						 sme_vls[sme], SVCR_ZA_MASK),
447					 "%s SVE VL %d/SME VL %d ZA\n",
448					 cfg->name, sve_vls[sve],
449					 sme_vls[sme]);
450		}
451	}
452
453	for (sme = 0; sme < sme_vl_count; sme++) {
454		ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
455		if (ret == -1)
456			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
457						   strerror(errno), errno);
458
459		ksft_test_result(do_test(cfg, 0, sme_vls[sme],
460					 SVCR_ZA_MASK | SVCR_SM_MASK),
461				 "%s SME VL %d SM+ZA\n",
462				 cfg->name, sme_vls[sme]);
463		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_SM_MASK),
464				 "%s SME VL %d SM\n",
465				 cfg->name, sme_vls[sme]);
466		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_ZA_MASK),
467				 "%s SME VL %d ZA\n",
468				 cfg->name, sme_vls[sme]);
469	}
470}
471
472void sve_count_vls(void)
473{
474	unsigned int vq;
475	int vl;
476
477	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
478		return;
479
480	/*
481	 * Enumerate up to ARCH_SVE_VQ_MAX vector lengths
482	 */
483	for (vq = ARCH_SVE_VQ_MAX; vq > 0; vq /= 2) {
484		vl = prctl(PR_SVE_SET_VL, vq * 16);
485		if (vl == -1)
486			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
487					   strerror(errno), errno);
488
489		vl &= PR_SVE_VL_LEN_MASK;
490
491		if (vq != sve_vq_from_vl(vl))
492			vq = sve_vq_from_vl(vl);
493
494		sve_vls[sve_vl_count++] = vl;
495	}
496}
497
498void sme_count_vls(void)
499{
500	unsigned int vq;
501	int vl;
502
503	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
504		return;
505
506	/*
507	 * Enumerate up to ARCH_SVE_VQ_MAX vector lengths
508	 */
509	for (vq = ARCH_SVE_VQ_MAX; vq > 0; vq /= 2) {
510		vl = prctl(PR_SME_SET_VL, vq * 16);
511		if (vl == -1)
512			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
513					   strerror(errno), errno);
514
515		vl &= PR_SME_VL_LEN_MASK;
516
517		/* Found lowest VL */
518		if (sve_vq_from_vl(vl) > vq)
519			break;
520
521		if (vq != sve_vq_from_vl(vl))
522			vq = sve_vq_from_vl(vl);
523
524		sme_vls[sme_vl_count++] = vl;
525	}
526
527	/* Ensure we configure a SME VL, used to flag if SVCR is set */
528	default_sme_vl = sme_vls[0];
529}
530
531int main(void)
532{
533	int i;
534	int tests = 1;  /* FPSIMD */
535	int sme_ver;
536
537	srandom(getpid());
538
539	ksft_print_header();
540
541	sve_count_vls();
542	sme_count_vls();
543
544	tests += sve_vl_count;
545	tests += sme_vl_count * 3;
546	tests += (sve_vl_count * sme_vl_count) * 3;
547	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
548
549	if (getauxval(AT_HWCAP2) & HWCAP2_SME2)
550		sme_ver = 2;
551	else
552		sme_ver = 1;
553
554	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
555		ksft_print_msg("SME%d with FA64\n", sme_ver);
556	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
557		ksft_print_msg("SME%d without FA64\n", sme_ver);
558
559	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
560		test_one_syscall(&syscalls[i]);
561
562	ksft_print_cnts();
563
564	return 0;
565}
566