cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

syscall-abi.c (11003B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * Copyright (C) 2021 ARM Limited.
      4 */
      5
      6#include <errno.h>
      7#include <stdbool.h>
      8#include <stddef.h>
      9#include <stdio.h>
     10#include <stdlib.h>
     11#include <string.h>
     12#include <unistd.h>
     13#include <sys/auxv.h>
     14#include <sys/prctl.h>
     15#include <asm/hwcap.h>
     16#include <asm/sigcontext.h>
     17#include <asm/unistd.h>
     18
     19#include "../../kselftest.h"
     20
     21#include "syscall-abi.h"
     22
     23#define NUM_VL ((SVE_VQ_MAX - SVE_VQ_MIN) + 1)
     24
     25static int default_sme_vl;
     26
     27extern void do_syscall(int sve_vl, int sme_vl);
     28
     29static void fill_random(void *buf, size_t size)
     30{
     31	int i;
     32	uint32_t *lbuf = buf;
     33
     34	/* random() returns a 32 bit number regardless of the size of long */
     35	for (i = 0; i < size / sizeof(uint32_t); i++)
     36		lbuf[i] = random();
     37}
     38
     39/*
     40 * We also repeat the test for several syscalls to try to expose different
     41 * behaviour.
     42 */
     43static struct syscall_cfg {
     44	int syscall_nr;
     45	const char *name;
     46} syscalls[] = {
     47	{ __NR_getpid,		"getpid()" },
     48	{ __NR_sched_yield,	"sched_yield()" },
     49};
     50
     51#define NUM_GPR 31
     52uint64_t gpr_in[NUM_GPR];
     53uint64_t gpr_out[NUM_GPR];
     54
     55static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
     56		      uint64_t svcr)
     57{
     58	fill_random(gpr_in, sizeof(gpr_in));
     59	gpr_in[8] = cfg->syscall_nr;
     60	memset(gpr_out, 0, sizeof(gpr_out));
     61}
     62
     63static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
     64{
     65	int errors = 0;
     66	int i;
     67
     68	/*
     69	 * GPR x0-x7 may be clobbered, and all others should be preserved.
     70	 */
     71	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
     72		if (gpr_in[i] != gpr_out[i]) {
     73			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
     74				       cfg->name, sve_vl, i,
     75				       gpr_in[i], gpr_out[i]);
     76			errors++;
     77		}
     78	}
     79
     80	return errors;
     81}
     82
     83#define NUM_FPR 32
     84uint64_t fpr_in[NUM_FPR * 2];
     85uint64_t fpr_out[NUM_FPR * 2];
     86
     87static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
     88		      uint64_t svcr)
     89{
     90	fill_random(fpr_in, sizeof(fpr_in));
     91	memset(fpr_out, 0, sizeof(fpr_out));
     92}
     93
     94static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
     95		     uint64_t svcr)
     96{
     97	int errors = 0;
     98	int i;
     99
    100	if (!sve_vl) {
    101		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
    102			if (fpr_in[i] != fpr_out[i]) {
    103				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
    104					       cfg->name,
    105					       i / 2, i % 2,
    106					       fpr_in[i], fpr_out[i]);
    107				errors++;
    108			}
    109		}
    110	}
    111
    112	return errors;
    113}
    114
    115static uint8_t z_zero[__SVE_ZREG_SIZE(SVE_VQ_MAX)];
    116uint8_t z_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
    117uint8_t z_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
    118
    119static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    120		    uint64_t svcr)
    121{
    122	fill_random(z_in, sizeof(z_in));
    123	fill_random(z_out, sizeof(z_out));
    124}
    125
    126static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    127		   uint64_t svcr)
    128{
    129	size_t reg_size = sve_vl;
    130	int errors = 0;
    131	int i;
    132
    133	if (!sve_vl)
    134		return 0;
    135
    136	/*
    137	 * After a syscall the low 128 bits of the Z registers should
    138	 * be preserved and the rest be zeroed or preserved, except if
    139	 * we were in streaming mode in which case the low 128 bits may
    140	 * also be cleared by the transition out of streaming mode.
    141	 */
    142	for (i = 0; i < SVE_NUM_ZREGS; i++) {
    143		void *in = &z_in[reg_size * i];
    144		void *out = &z_out[reg_size * i];
    145
    146		if ((memcmp(in, out, SVE_VQ_BYTES) != 0) &&
    147		    !((svcr & SVCR_SM_MASK) &&
    148		      memcmp(z_zero, out, SVE_VQ_BYTES) == 0)) {
    149			ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
    150				       cfg->name, sve_vl, i);
    151			errors++;
    152		}
    153	}
    154
    155	return errors;
    156}
    157
    158uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
    159uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
    160
    161static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    162		    uint64_t svcr)
    163{
    164	fill_random(p_in, sizeof(p_in));
    165	fill_random(p_out, sizeof(p_out));
    166}
    167
    168static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    169		   uint64_t svcr)
    170{
    171	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
    172
    173	int errors = 0;
    174	int i;
    175
    176	if (!sve_vl)
    177		return 0;
    178
    179	/* After a syscall the P registers should be preserved or zeroed */
    180	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
    181		if (p_out[i] && (p_in[i] != p_out[i]))
    182			errors++;
    183	if (errors)
    184		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
    185			       cfg->name, sve_vl);
    186
    187	return errors;
    188}
    189
    190uint8_t ffr_in[__SVE_PREG_SIZE(SVE_VQ_MAX)];
    191uint8_t ffr_out[__SVE_PREG_SIZE(SVE_VQ_MAX)];
    192
    193static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    194		      uint64_t svcr)
    195{
    196	/*
    197	 * If we are in streaming mode and do not have FA64 then FFR
    198	 * is unavailable.
    199	 */
    200	if ((svcr & SVCR_SM_MASK) &&
    201	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
    202		memset(&ffr_in, 0, sizeof(ffr_in));
    203		return;
    204	}
    205
    206	/*
    207	 * It is only valid to set a contiguous set of bits starting
    208	 * at 0.  For now since we're expecting this to be cleared by
    209	 * a syscall just set all bits.
    210	 */
    211	memset(ffr_in, 0xff, sizeof(ffr_in));
    212	fill_random(ffr_out, sizeof(ffr_out));
    213}
    214
    215static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    216		     uint64_t svcr)
    217{
    218	size_t reg_size = sve_vq_from_vl(sve_vl) * 2;  /* 1 bit per VL byte */
    219	int errors = 0;
    220	int i;
    221
    222	if (!sve_vl)
    223		return 0;
    224
    225	if ((svcr & SVCR_SM_MASK) &&
    226	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
    227		return 0;
    228
    229	/* After a syscall the P registers should be preserved or zeroed */
    230	for (i = 0; i < reg_size; i++)
    231		if (ffr_out[i] && (ffr_in[i] != ffr_out[i]))
    232			errors++;
    233	if (errors)
    234		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
    235			       cfg->name, sve_vl);
    236
    237	return errors;
    238}
    239
    240uint64_t svcr_in, svcr_out;
    241
    242static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    243		    uint64_t svcr)
    244{
    245	svcr_in = svcr;
    246}
    247
    248static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    249		      uint64_t svcr)
    250{
    251	int errors = 0;
    252
    253	if (svcr_out & SVCR_SM_MASK) {
    254		ksft_print_msg("%s Still in SM, SVCR %llx\n",
    255			       cfg->name, svcr_out);
    256		errors++;
    257	}
    258
    259	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
    260		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
    261			       cfg->name, svcr_in, svcr_out);
    262		errors++;
    263	}
    264
    265	return errors;
    266}
    267
    268uint8_t za_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
    269uint8_t za_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
    270
    271static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    272		     uint64_t svcr)
    273{
    274	fill_random(za_in, sizeof(za_in));
    275	memset(za_out, 0, sizeof(za_out));
    276}
    277
    278static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    279		    uint64_t svcr)
    280{
    281	size_t reg_size = sme_vl * sme_vl;
    282	int errors = 0;
    283
    284	if (!(svcr & SVCR_ZA_MASK))
    285		return 0;
    286
    287	if (memcmp(za_in, za_out, reg_size) != 0) {
    288		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
    289		errors++;
    290	}
    291
    292	return errors;
    293}
    294
    295typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    296			 uint64_t svcr);
    297typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    298			uint64_t svcr);
    299
    300/*
    301 * Each set of registers has a setup function which is called before
    302 * the syscall to fill values in a global variable for loading by the
    303 * test code and a check function which validates that the results are
    304 * as expected.  Vector lengths are passed everywhere, a vector length
    305 * of 0 should be treated as do not test.
    306 */
    307static struct {
    308	setup_fn setup;
    309	check_fn check;
    310} regset[] = {
    311	{ setup_gpr, check_gpr },
    312	{ setup_fpr, check_fpr },
    313	{ setup_z, check_z },
    314	{ setup_p, check_p },
    315	{ setup_ffr, check_ffr },
    316	{ setup_svcr, check_svcr },
    317	{ setup_za, check_za },
    318};
    319
    320static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
    321		    uint64_t svcr)
    322{
    323	int errors = 0;
    324	int i;
    325
    326	for (i = 0; i < ARRAY_SIZE(regset); i++)
    327		regset[i].setup(cfg, sve_vl, sme_vl, svcr);
    328
    329	do_syscall(sve_vl, sme_vl);
    330
    331	for (i = 0; i < ARRAY_SIZE(regset); i++)
    332		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
    333
    334	return errors == 0;
    335}
    336
    337static void test_one_syscall(struct syscall_cfg *cfg)
    338{
    339	int sve_vq, sve_vl;
    340	int sme_vq, sme_vl;
    341
    342	/* FPSIMD only case */
    343	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
    344			 "%s FPSIMD\n", cfg->name);
    345
    346	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
    347		return;
    348
    349	for (sve_vq = SVE_VQ_MAX; sve_vq > 0; --sve_vq) {
    350		sve_vl = prctl(PR_SVE_SET_VL, sve_vq * 16);
    351		if (sve_vl == -1)
    352			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
    353					   strerror(errno), errno);
    354
    355		sve_vl &= PR_SVE_VL_LEN_MASK;
    356
    357		if (sve_vq != sve_vq_from_vl(sve_vl))
    358			sve_vq = sve_vq_from_vl(sve_vl);
    359
    360		ksft_test_result(do_test(cfg, sve_vl, default_sme_vl, 0),
    361				 "%s SVE VL %d\n", cfg->name, sve_vl);
    362
    363		if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
    364			continue;
    365
    366		for (sme_vq = SVE_VQ_MAX; sme_vq > 0; --sme_vq) {
    367			sme_vl = prctl(PR_SME_SET_VL, sme_vq * 16);
    368			if (sme_vl == -1)
    369				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
    370						   strerror(errno), errno);
    371
    372			sme_vl &= PR_SME_VL_LEN_MASK;
    373
    374			if (sme_vq != sve_vq_from_vl(sme_vl))
    375				sme_vq = sve_vq_from_vl(sme_vl);
    376
    377			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
    378						 SVCR_ZA_MASK | SVCR_SM_MASK),
    379					 "%s SVE VL %d/SME VL %d SM+ZA\n",
    380					 cfg->name, sve_vl, sme_vl);
    381			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
    382						 SVCR_SM_MASK),
    383					 "%s SVE VL %d/SME VL %d SM\n",
    384					 cfg->name, sve_vl, sme_vl);
    385			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
    386						 SVCR_ZA_MASK),
    387					 "%s SVE VL %d/SME VL %d ZA\n",
    388					 cfg->name, sve_vl, sme_vl);
    389		}
    390	}
    391}
    392
    393int sve_count_vls(void)
    394{
    395	unsigned int vq;
    396	int vl_count = 0;
    397	int vl;
    398
    399	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
    400		return 0;
    401
    402	/*
    403	 * Enumerate up to SVE_VQ_MAX vector lengths
    404	 */
    405	for (vq = SVE_VQ_MAX; vq > 0; --vq) {
    406		vl = prctl(PR_SVE_SET_VL, vq * 16);
    407		if (vl == -1)
    408			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
    409					   strerror(errno), errno);
    410
    411		vl &= PR_SVE_VL_LEN_MASK;
    412
    413		if (vq != sve_vq_from_vl(vl))
    414			vq = sve_vq_from_vl(vl);
    415
    416		vl_count++;
    417	}
    418
    419	return vl_count;
    420}
    421
    422int sme_count_vls(void)
    423{
    424	unsigned int vq;
    425	int vl_count = 0;
    426	int vl;
    427
    428	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
    429		return 0;
    430
    431	/* Ensure we configure a SME VL, used to flag if SVCR is set */
    432	default_sme_vl = 16;
    433
    434	/*
    435	 * Enumerate up to SVE_VQ_MAX vector lengths
    436	 */
    437	for (vq = SVE_VQ_MAX; vq > 0; --vq) {
    438		vl = prctl(PR_SME_SET_VL, vq * 16);
    439		if (vl == -1)
    440			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
    441					   strerror(errno), errno);
    442
    443		vl &= PR_SME_VL_LEN_MASK;
    444
    445		if (vq != sve_vq_from_vl(vl))
    446			vq = sve_vq_from_vl(vl);
    447
    448		vl_count++;
    449	}
    450
    451	return vl_count;
    452}
    453
    454int main(void)
    455{
    456	int i;
    457	int tests = 1;  /* FPSIMD */
    458
    459	srandom(getpid());
    460
    461	ksft_print_header();
    462	tests += sve_count_vls();
    463	tests += (sve_count_vls() * sme_count_vls()) * 3;
    464	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
    465
    466	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
    467		ksft_print_msg("SME with FA64\n");
    468	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
    469		ksft_print_msg("SME without FA64\n");
    470
    471	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
    472		test_one_syscall(&syscalls[i]);
    473
    474	ksft_print_cnts();
    475
    476	return 0;
    477}