cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

mte_common_util.c (9898B)


      1// SPDX-License-Identifier: GPL-2.0
      2// Copyright (C) 2020 ARM Limited
      3
      4#include <fcntl.h>
      5#include <sched.h>
      6#include <signal.h>
      7#include <stdio.h>
      8#include <stdlib.h>
      9#include <unistd.h>
     10
     11#include <linux/auxvec.h>
     12#include <sys/auxv.h>
     13#include <sys/mman.h>
     14#include <sys/prctl.h>
     15
     16#include <asm/hwcap.h>
     17
     18#include "kselftest.h"
     19#include "mte_common_util.h"
     20#include "mte_def.h"
     21
     22#define INIT_BUFFER_SIZE       256
     23
     24struct mte_fault_cxt cur_mte_cxt;
     25static unsigned int mte_cur_mode;
     26static unsigned int mte_cur_pstate_tco;
     27
     28void mte_default_handler(int signum, siginfo_t *si, void *uc)
     29{
     30	unsigned long addr = (unsigned long)si->si_addr;
     31
     32	if (signum == SIGSEGV) {
     33#ifdef DEBUG
     34		ksft_print_msg("INFO: SIGSEGV signal at pc=%lx, fault addr=%lx, si_code=%lx\n",
     35				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
     36#endif
     37		if (si->si_code == SEGV_MTEAERR) {
     38			if (cur_mte_cxt.trig_si_code == si->si_code)
     39				cur_mte_cxt.fault_valid = true;
     40			else
     41				ksft_print_msg("Got unexpected SEGV_MTEAERR at pc=$lx, fault addr=%lx\n",
     42					       ((ucontext_t *)uc)->uc_mcontext.pc,
     43					       addr);
     44			return;
     45		}
     46		/* Compare the context for precise error */
     47		else if (si->si_code == SEGV_MTESERR) {
     48			if (cur_mte_cxt.trig_si_code == si->si_code &&
     49			    ((cur_mte_cxt.trig_range >= 0 &&
     50			      addr >= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
     51			      addr <= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
     52			     (cur_mte_cxt.trig_range < 0 &&
     53			      addr <= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
     54			      addr >= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)))) {
     55				cur_mte_cxt.fault_valid = true;
     56				/* Adjust the pc by 4 */
     57				((ucontext_t *)uc)->uc_mcontext.pc += 4;
     58			} else {
     59				ksft_print_msg("Invalid MTE synchronous exception caught!\n");
     60				exit(1);
     61			}
     62		} else {
     63			ksft_print_msg("Unknown SIGSEGV exception caught!\n");
     64			exit(1);
     65		}
     66	} else if (signum == SIGBUS) {
     67		ksft_print_msg("INFO: SIGBUS signal at pc=%lx, fault addr=%lx, si_code=%lx\n",
     68				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
     69		if ((cur_mte_cxt.trig_range >= 0 &&
     70		     addr >= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
     71		     addr <= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
     72		    (cur_mte_cxt.trig_range < 0 &&
     73		     addr <= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
     74		     addr >= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range))) {
     75			cur_mte_cxt.fault_valid = true;
     76			/* Adjust the pc by 4 */
     77			((ucontext_t *)uc)->uc_mcontext.pc += 4;
     78		}
     79	}
     80}
     81
     82void mte_register_signal(int signal, void (*handler)(int, siginfo_t *, void *))
     83{
     84	struct sigaction sa;
     85
     86	sa.sa_sigaction = handler;
     87	sa.sa_flags = SA_SIGINFO;
     88	sigemptyset(&sa.sa_mask);
     89	sigaction(signal, &sa, NULL);
     90}
     91
     92void mte_wait_after_trig(void)
     93{
     94	sched_yield();
     95}
     96
     97void *mte_insert_tags(void *ptr, size_t size)
     98{
     99	void *tag_ptr;
    100	int align_size;
    101
    102	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
    103		ksft_print_msg("FAIL: Addr=%lx: invalid\n", ptr);
    104		return NULL;
    105	}
    106	align_size = MT_ALIGN_UP(size);
    107	tag_ptr = mte_insert_random_tag(ptr);
    108	mte_set_tag_address_range(tag_ptr, align_size);
    109	return tag_ptr;
    110}
    111
    112void mte_clear_tags(void *ptr, size_t size)
    113{
    114	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
    115		ksft_print_msg("FAIL: Addr=%lx: invalid\n", ptr);
    116		return;
    117	}
    118	size = MT_ALIGN_UP(size);
    119	ptr = (void *)MT_CLEAR_TAG((unsigned long)ptr);
    120	mte_clear_tag_address_range(ptr, size);
    121}
    122
    123static void *__mte_allocate_memory_range(size_t size, int mem_type, int mapping,
    124					 size_t range_before, size_t range_after,
    125					 bool tags, int fd)
    126{
    127	void *ptr;
    128	int prot_flag, map_flag;
    129	size_t entire_size = size + range_before + range_after;
    130
    131	switch (mem_type) {
    132	case USE_MALLOC:
    133		return malloc(entire_size) + range_before;
    134	case USE_MMAP:
    135	case USE_MPROTECT:
    136		break;
    137	default:
    138		ksft_print_msg("FAIL: Invalid allocate request\n");
    139		return NULL;
    140	}
    141
    142	prot_flag = PROT_READ | PROT_WRITE;
    143	if (mem_type == USE_MMAP)
    144		prot_flag |= PROT_MTE;
    145
    146	map_flag = mapping;
    147	if (fd == -1)
    148		map_flag = MAP_ANONYMOUS | map_flag;
    149	if (!(mapping & MAP_SHARED))
    150		map_flag |= MAP_PRIVATE;
    151	ptr = mmap(NULL, entire_size, prot_flag, map_flag, fd, 0);
    152	if (ptr == MAP_FAILED) {
    153		ksft_print_msg("FAIL: mmap allocation\n");
    154		return NULL;
    155	}
    156	if (mem_type == USE_MPROTECT) {
    157		if (mprotect(ptr, entire_size, prot_flag | PROT_MTE)) {
    158			munmap(ptr, size);
    159			ksft_print_msg("FAIL: mprotect PROT_MTE property\n");
    160			return NULL;
    161		}
    162	}
    163	if (tags)
    164		ptr = mte_insert_tags(ptr + range_before, size);
    165	return ptr;
    166}
    167
    168void *mte_allocate_memory_tag_range(size_t size, int mem_type, int mapping,
    169				    size_t range_before, size_t range_after)
    170{
    171	return __mte_allocate_memory_range(size, mem_type, mapping, range_before,
    172					   range_after, true, -1);
    173}
    174
    175void *mte_allocate_memory(size_t size, int mem_type, int mapping, bool tags)
    176{
    177	return __mte_allocate_memory_range(size, mem_type, mapping, 0, 0, tags, -1);
    178}
    179
    180void *mte_allocate_file_memory(size_t size, int mem_type, int mapping, bool tags, int fd)
    181{
    182	int index;
    183	char buffer[INIT_BUFFER_SIZE];
    184
    185	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
    186		ksft_print_msg("FAIL: Invalid mmap file request\n");
    187		return NULL;
    188	}
    189	/* Initialize the file for mappable size */
    190	lseek(fd, 0, SEEK_SET);
    191	for (index = INIT_BUFFER_SIZE; index < size; index += INIT_BUFFER_SIZE) {
    192		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
    193			perror("initialising buffer");
    194			return NULL;
    195		}
    196	}
    197	index -= INIT_BUFFER_SIZE;
    198	if (write(fd, buffer, size - index) != size - index) {
    199		perror("initialising buffer");
    200		return NULL;
    201	}
    202	return __mte_allocate_memory_range(size, mem_type, mapping, 0, 0, tags, fd);
    203}
    204
    205void *mte_allocate_file_memory_tag_range(size_t size, int mem_type, int mapping,
    206					 size_t range_before, size_t range_after, int fd)
    207{
    208	int index;
    209	char buffer[INIT_BUFFER_SIZE];
    210	int map_size = size + range_before + range_after;
    211
    212	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
    213		ksft_print_msg("FAIL: Invalid mmap file request\n");
    214		return NULL;
    215	}
    216	/* Initialize the file for mappable size */
    217	lseek(fd, 0, SEEK_SET);
    218	for (index = INIT_BUFFER_SIZE; index < map_size; index += INIT_BUFFER_SIZE)
    219		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
    220			perror("initialising buffer");
    221			return NULL;
    222		}
    223	index -= INIT_BUFFER_SIZE;
    224	if (write(fd, buffer, map_size - index) != map_size - index) {
    225		perror("initialising buffer");
    226		return NULL;
    227	}
    228	return __mte_allocate_memory_range(size, mem_type, mapping, range_before,
    229					   range_after, true, fd);
    230}
    231
    232static void __mte_free_memory_range(void *ptr, size_t size, int mem_type,
    233				    size_t range_before, size_t range_after, bool tags)
    234{
    235	switch (mem_type) {
    236	case USE_MALLOC:
    237		free(ptr - range_before);
    238		break;
    239	case USE_MMAP:
    240	case USE_MPROTECT:
    241		if (tags)
    242			mte_clear_tags(ptr, size);
    243		munmap(ptr - range_before, size + range_before + range_after);
    244		break;
    245	default:
    246		ksft_print_msg("FAIL: Invalid free request\n");
    247		break;
    248	}
    249}
    250
    251void mte_free_memory_tag_range(void *ptr, size_t size, int mem_type,
    252			       size_t range_before, size_t range_after)
    253{
    254	__mte_free_memory_range(ptr, size, mem_type, range_before, range_after, true);
    255}
    256
    257void mte_free_memory(void *ptr, size_t size, int mem_type, bool tags)
    258{
    259	__mte_free_memory_range(ptr, size, mem_type, 0, 0, tags);
    260}
    261
    262void mte_initialize_current_context(int mode, uintptr_t ptr, ssize_t range)
    263{
    264	cur_mte_cxt.fault_valid = false;
    265	cur_mte_cxt.trig_addr = ptr;
    266	cur_mte_cxt.trig_range = range;
    267	if (mode == MTE_SYNC_ERR)
    268		cur_mte_cxt.trig_si_code = SEGV_MTESERR;
    269	else if (mode == MTE_ASYNC_ERR)
    270		cur_mte_cxt.trig_si_code = SEGV_MTEAERR;
    271	else
    272		cur_mte_cxt.trig_si_code = 0;
    273}
    274
    275int mte_switch_mode(int mte_option, unsigned long incl_mask)
    276{
    277	unsigned long en = 0;
    278
    279	switch (mte_option) {
    280	case MTE_NONE_ERR:
    281	case MTE_SYNC_ERR:
    282	case MTE_ASYNC_ERR:
    283		break;
    284	default:
    285		ksft_print_msg("FAIL: Invalid MTE option %x\n", mte_option);
    286		return -EINVAL;
    287	}
    288
    289	if (incl_mask & ~MT_INCLUDE_TAG_MASK) {
    290		ksft_print_msg("FAIL: Invalid incl_mask %lx\n", incl_mask);
    291		return -EINVAL;
    292	}
    293
    294	en = PR_TAGGED_ADDR_ENABLE;
    295	switch (mte_option) {
    296	case MTE_SYNC_ERR:
    297		en |= PR_MTE_TCF_SYNC;
    298		break;
    299	case MTE_ASYNC_ERR:
    300		en |= PR_MTE_TCF_ASYNC;
    301		break;
    302	case MTE_NONE_ERR:
    303		en |= PR_MTE_TCF_NONE;
    304		break;
    305	}
    306
    307	en |= (incl_mask << PR_MTE_TAG_SHIFT);
    308	/* Enable address tagging ABI, mte error reporting mode and tag inclusion mask. */
    309	if (prctl(PR_SET_TAGGED_ADDR_CTRL, en, 0, 0, 0) != 0) {
    310		ksft_print_msg("FAIL:prctl PR_SET_TAGGED_ADDR_CTRL for mte mode\n");
    311		return -EINVAL;
    312	}
    313	return 0;
    314}
    315
    316int mte_default_setup(void)
    317{
    318	unsigned long hwcaps2 = getauxval(AT_HWCAP2);
    319	unsigned long en = 0;
    320	int ret;
    321
    322	if (!(hwcaps2 & HWCAP2_MTE)) {
    323		ksft_print_msg("SKIP: MTE features unavailable\n");
    324		return KSFT_SKIP;
    325	}
    326	/* Get current mte mode */
    327	ret = prctl(PR_GET_TAGGED_ADDR_CTRL, en, 0, 0, 0);
    328	if (ret < 0) {
    329		ksft_print_msg("FAIL:prctl PR_GET_TAGGED_ADDR_CTRL with error =%d\n", ret);
    330		return KSFT_FAIL;
    331	}
    332	if (ret & PR_MTE_TCF_SYNC)
    333		mte_cur_mode = MTE_SYNC_ERR;
    334	else if (ret & PR_MTE_TCF_ASYNC)
    335		mte_cur_mode = MTE_ASYNC_ERR;
    336	else if (ret & PR_MTE_TCF_NONE)
    337		mte_cur_mode = MTE_NONE_ERR;
    338
    339	mte_cur_pstate_tco = mte_get_pstate_tco();
    340	/* Disable PSTATE.TCO */
    341	mte_disable_pstate_tco();
    342	return 0;
    343}
    344
    345void mte_restore_setup(void)
    346{
    347	mte_switch_mode(mte_cur_mode, MTE_ALLOW_NON_ZERO_TAG);
    348	if (mte_cur_pstate_tco == MT_PSTATE_TCO_EN)
    349		mte_enable_pstate_tco();
    350	else if (mte_cur_pstate_tco == MT_PSTATE_TCO_DIS)
    351		mte_disable_pstate_tco();
    352}
    353
    354int create_temp_file(void)
    355{
    356	int fd;
    357	char filename[] = "/dev/shm/tmp_XXXXXX";
    358
    359	/* Create a file in the tmpfs filesystem */
    360	fd = mkstemp(&filename[0]);
    361	if (fd == -1) {
    362		perror(filename);
    363		ksft_print_msg("FAIL: Unable to open temporary file\n");
    364		return 0;
    365	}
    366	unlink(&filename[0]);
    367	return fd;
    368}