cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

util.c (8154B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * vsock test utilities
      4 *
      5 * Copyright (C) 2017 Red Hat, Inc.
      6 *
      7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
      8 */
      9
     10#include <errno.h>
     11#include <stdio.h>
     12#include <stdint.h>
     13#include <stdlib.h>
     14#include <signal.h>
     15#include <unistd.h>
     16#include <assert.h>
     17#include <sys/epoll.h>
     18
     19#include "timeout.h"
     20#include "control.h"
     21#include "util.h"
     22
     23/* Install signal handlers */
     24void init_signals(void)
     25{
     26	struct sigaction act = {
     27		.sa_handler = sigalrm,
     28	};
     29
     30	sigaction(SIGALRM, &act, NULL);
     31	signal(SIGPIPE, SIG_IGN);
     32}
     33
     34/* Parse a CID in string representation */
     35unsigned int parse_cid(const char *str)
     36{
     37	char *endptr = NULL;
     38	unsigned long n;
     39
     40	errno = 0;
     41	n = strtoul(str, &endptr, 10);
     42	if (errno || *endptr != '\0') {
     43		fprintf(stderr, "malformed CID \"%s\"\n", str);
     44		exit(EXIT_FAILURE);
     45	}
     46	return n;
     47}
     48
     49/* Wait for the remote to close the connection */
     50void vsock_wait_remote_close(int fd)
     51{
     52	struct epoll_event ev;
     53	int epollfd, nfds;
     54
     55	epollfd = epoll_create1(0);
     56	if (epollfd == -1) {
     57		perror("epoll_create1");
     58		exit(EXIT_FAILURE);
     59	}
     60
     61	ev.events = EPOLLRDHUP | EPOLLHUP;
     62	ev.data.fd = fd;
     63	if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) {
     64		perror("epoll_ctl");
     65		exit(EXIT_FAILURE);
     66	}
     67
     68	nfds = epoll_wait(epollfd, &ev, 1, TIMEOUT * 1000);
     69	if (nfds == -1) {
     70		perror("epoll_wait");
     71		exit(EXIT_FAILURE);
     72	}
     73
     74	if (nfds == 0) {
     75		fprintf(stderr, "epoll_wait timed out\n");
     76		exit(EXIT_FAILURE);
     77	}
     78
     79	assert(nfds == 1);
     80	assert(ev.events & (EPOLLRDHUP | EPOLLHUP));
     81	assert(ev.data.fd == fd);
     82
     83	close(epollfd);
     84}
     85
     86/* Connect to <cid, port> and return the file descriptor. */
     87static int vsock_connect(unsigned int cid, unsigned int port, int type)
     88{
     89	union {
     90		struct sockaddr sa;
     91		struct sockaddr_vm svm;
     92	} addr = {
     93		.svm = {
     94			.svm_family = AF_VSOCK,
     95			.svm_port = port,
     96			.svm_cid = cid,
     97		},
     98	};
     99	int ret;
    100	int fd;
    101
    102	control_expectln("LISTENING");
    103
    104	fd = socket(AF_VSOCK, type, 0);
    105
    106	timeout_begin(TIMEOUT);
    107	do {
    108		ret = connect(fd, &addr.sa, sizeof(addr.svm));
    109		timeout_check("connect");
    110	} while (ret < 0 && errno == EINTR);
    111	timeout_end();
    112
    113	if (ret < 0) {
    114		int old_errno = errno;
    115
    116		close(fd);
    117		fd = -1;
    118		errno = old_errno;
    119	}
    120	return fd;
    121}
    122
    123int vsock_stream_connect(unsigned int cid, unsigned int port)
    124{
    125	return vsock_connect(cid, port, SOCK_STREAM);
    126}
    127
    128int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
    129{
    130	return vsock_connect(cid, port, SOCK_SEQPACKET);
    131}
    132
    133/* Listen on <cid, port> and return the first incoming connection.  The remote
    134 * address is stored to clientaddrp.  clientaddrp may be NULL.
    135 */
    136static int vsock_accept(unsigned int cid, unsigned int port,
    137			struct sockaddr_vm *clientaddrp, int type)
    138{
    139	union {
    140		struct sockaddr sa;
    141		struct sockaddr_vm svm;
    142	} addr = {
    143		.svm = {
    144			.svm_family = AF_VSOCK,
    145			.svm_port = port,
    146			.svm_cid = cid,
    147		},
    148	};
    149	union {
    150		struct sockaddr sa;
    151		struct sockaddr_vm svm;
    152	} clientaddr;
    153	socklen_t clientaddr_len = sizeof(clientaddr.svm);
    154	int fd;
    155	int client_fd;
    156	int old_errno;
    157
    158	fd = socket(AF_VSOCK, type, 0);
    159
    160	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
    161		perror("bind");
    162		exit(EXIT_FAILURE);
    163	}
    164
    165	if (listen(fd, 1) < 0) {
    166		perror("listen");
    167		exit(EXIT_FAILURE);
    168	}
    169
    170	control_writeln("LISTENING");
    171
    172	timeout_begin(TIMEOUT);
    173	do {
    174		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
    175		timeout_check("accept");
    176	} while (client_fd < 0 && errno == EINTR);
    177	timeout_end();
    178
    179	old_errno = errno;
    180	close(fd);
    181	errno = old_errno;
    182
    183	if (client_fd < 0)
    184		return client_fd;
    185
    186	if (clientaddr_len != sizeof(clientaddr.svm)) {
    187		fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
    188			(size_t)clientaddr_len);
    189		exit(EXIT_FAILURE);
    190	}
    191	if (clientaddr.sa.sa_family != AF_VSOCK) {
    192		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
    193			clientaddr.sa.sa_family);
    194		exit(EXIT_FAILURE);
    195	}
    196
    197	if (clientaddrp)
    198		*clientaddrp = clientaddr.svm;
    199	return client_fd;
    200}
    201
    202int vsock_stream_accept(unsigned int cid, unsigned int port,
    203			struct sockaddr_vm *clientaddrp)
    204{
    205	return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
    206}
    207
    208int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
    209			   struct sockaddr_vm *clientaddrp)
    210{
    211	return vsock_accept(cid, port, clientaddrp, SOCK_SEQPACKET);
    212}
    213
    214/* Transmit one byte and check the return value.
    215 *
    216 * expected_ret:
    217 *  <0 Negative errno (for testing errors)
    218 *   0 End-of-file
    219 *   1 Success
    220 */
    221void send_byte(int fd, int expected_ret, int flags)
    222{
    223	const uint8_t byte = 'A';
    224	ssize_t nwritten;
    225
    226	timeout_begin(TIMEOUT);
    227	do {
    228		nwritten = send(fd, &byte, sizeof(byte), flags);
    229		timeout_check("write");
    230	} while (nwritten < 0 && errno == EINTR);
    231	timeout_end();
    232
    233	if (expected_ret < 0) {
    234		if (nwritten != -1) {
    235			fprintf(stderr, "bogus send(2) return value %zd\n",
    236				nwritten);
    237			exit(EXIT_FAILURE);
    238		}
    239		if (errno != -expected_ret) {
    240			perror("write");
    241			exit(EXIT_FAILURE);
    242		}
    243		return;
    244	}
    245
    246	if (nwritten < 0) {
    247		perror("write");
    248		exit(EXIT_FAILURE);
    249	}
    250	if (nwritten == 0) {
    251		if (expected_ret == 0)
    252			return;
    253
    254		fprintf(stderr, "unexpected EOF while sending byte\n");
    255		exit(EXIT_FAILURE);
    256	}
    257	if (nwritten != sizeof(byte)) {
    258		fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
    259		exit(EXIT_FAILURE);
    260	}
    261}
    262
    263/* Receive one byte and check the return value.
    264 *
    265 * expected_ret:
    266 *  <0 Negative errno (for testing errors)
    267 *   0 End-of-file
    268 *   1 Success
    269 */
    270void recv_byte(int fd, int expected_ret, int flags)
    271{
    272	uint8_t byte;
    273	ssize_t nread;
    274
    275	timeout_begin(TIMEOUT);
    276	do {
    277		nread = recv(fd, &byte, sizeof(byte), flags);
    278		timeout_check("read");
    279	} while (nread < 0 && errno == EINTR);
    280	timeout_end();
    281
    282	if (expected_ret < 0) {
    283		if (nread != -1) {
    284			fprintf(stderr, "bogus recv(2) return value %zd\n",
    285				nread);
    286			exit(EXIT_FAILURE);
    287		}
    288		if (errno != -expected_ret) {
    289			perror("read");
    290			exit(EXIT_FAILURE);
    291		}
    292		return;
    293	}
    294
    295	if (nread < 0) {
    296		perror("read");
    297		exit(EXIT_FAILURE);
    298	}
    299	if (nread == 0) {
    300		if (expected_ret == 0)
    301			return;
    302
    303		fprintf(stderr, "unexpected EOF while receiving byte\n");
    304		exit(EXIT_FAILURE);
    305	}
    306	if (nread != sizeof(byte)) {
    307		fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
    308		exit(EXIT_FAILURE);
    309	}
    310	if (byte != 'A') {
    311		fprintf(stderr, "unexpected byte read %c\n", byte);
    312		exit(EXIT_FAILURE);
    313	}
    314}
    315
    316/* Run test cases.  The program terminates if a failure occurs. */
    317void run_tests(const struct test_case *test_cases,
    318	       const struct test_opts *opts)
    319{
    320	int i;
    321
    322	for (i = 0; test_cases[i].name; i++) {
    323		void (*run)(const struct test_opts *opts);
    324		char *line;
    325
    326		printf("%d - %s...", i, test_cases[i].name);
    327		fflush(stdout);
    328
    329		/* Full barrier before executing the next test.  This
    330		 * ensures that client and server are executing the
    331		 * same test case.  In particular, it means whoever is
    332		 * faster will not see the peer still executing the
    333		 * last test.  This is important because port numbers
    334		 * can be used by multiple test cases.
    335		 */
    336		if (test_cases[i].skip)
    337			control_writeln("SKIP");
    338		else
    339			control_writeln("NEXT");
    340
    341		line = control_readln();
    342		if (control_cmpln(line, "SKIP", false) || test_cases[i].skip) {
    343
    344			printf("skipped\n");
    345
    346			free(line);
    347			continue;
    348		}
    349
    350		control_cmpln(line, "NEXT", true);
    351		free(line);
    352
    353		if (opts->mode == TEST_MODE_CLIENT)
    354			run = test_cases[i].run_client;
    355		else
    356			run = test_cases[i].run_server;
    357
    358		if (run)
    359			run(opts);
    360
    361		printf("ok\n");
    362	}
    363}
    364
    365void list_tests(const struct test_case *test_cases)
    366{
    367	int i;
    368
    369	printf("ID\tTest name\n");
    370
    371	for (i = 0; test_cases[i].name; i++)
    372		printf("%d\t%s\n", i, test_cases[i].name);
    373
    374	exit(EXIT_FAILURE);
    375}
    376
    377void skip_test(struct test_case *test_cases, size_t test_cases_len,
    378	       const char *test_id_str)
    379{
    380	unsigned long test_id;
    381	char *endptr = NULL;
    382
    383	errno = 0;
    384	test_id = strtoul(test_id_str, &endptr, 10);
    385	if (errno || *endptr != '\0') {
    386		fprintf(stderr, "malformed test ID \"%s\"\n", test_id_str);
    387		exit(EXIT_FAILURE);
    388	}
    389
    390	if (test_id >= test_cases_len) {
    391		fprintf(stderr, "test ID (%lu) larger than the max allowed (%lu)\n",
    392			test_id, test_cases_len - 1);
    393		exit(EXIT_FAILURE);
    394	}
    395
    396	test_cases[test_id].skip = true;
    397}