cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

vsock_diag_test.c (11566B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * vsock_diag_test - vsock_diag.ko test suite
      4 *
      5 * Copyright (C) 2017 Red Hat, Inc.
      6 *
      7 * Author: Stefan Hajnoczi <stefanha@redhat.com>
      8 */
      9
     10#include <getopt.h>
     11#include <stdio.h>
     12#include <stdlib.h>
     13#include <string.h>
     14#include <errno.h>
     15#include <unistd.h>
     16#include <sys/stat.h>
     17#include <sys/types.h>
     18#include <linux/list.h>
     19#include <linux/net.h>
     20#include <linux/netlink.h>
     21#include <linux/sock_diag.h>
     22#include <linux/vm_sockets_diag.h>
     23#include <netinet/tcp.h>
     24
     25#include "timeout.h"
     26#include "control.h"
     27#include "util.h"
     28
     /*
 * Per-socket status: one record of the sock_diag dump, kept on a list
 * that read_vsock_stat() fills and free_sock_stat() releases.
 */
struct vsock_stat {
	struct list_head list;		/* link in the caller's socket list */
	struct vsock_diag_msg msg;	/* copy of the reply payload (see add_vsock_stat()) */
};
     34
     /*
 * Return the short label printed for a socket type.  SOCK_DGRAM and
 * SOCK_STREAM have names; every other value is reported as "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	const char *label = "INVALID TYPE";

	if (type == SOCK_STREAM)
		label = "STREAM";
	else if (type == SOCK_DGRAM)
		label = "DGRAM";

	return label;
}
     46
     47static const char *sock_state_str(int state)
     48{
     49	switch (state) {
     50	case TCP_CLOSE:
     51		return "UNCONNECTED";
     52	case TCP_SYN_SENT:
     53		return "CONNECTING";
     54	case TCP_ESTABLISHED:
     55		return "CONNECTED";
     56	case TCP_CLOSING:
     57		return "DISCONNECTING";
     58	case TCP_LISTEN:
     59		return "LISTEN";
     60	default:
     61		return "INVALID STATE";
     62	}
     63}
     64
     /*
 * Describe a shutdown mask: 1 is RCV_SHUTDOWN, 2 is SEND_SHUTDOWN and 3 is
 * both.  Zero, and any value outside 1..3, is printed as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const labels[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return labels[0];

	return labels[shutdown];
}
     78
     /*
 * Write "cid:port" to @fp.  VMADDR_CID_ANY and VMADDR_PORT_ANY are shown
 * as "*".  Nothing else (no separator, no newline) is written.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
     91
     92static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
     93{
     94	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
     95	fprintf(fp, " ");
     96	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
     97	fprintf(fp, " %s %s %s %u\n",
     98		sock_type_str(st->msg.vdiag_type),
     99		sock_state_str(st->msg.vdiag_state),
    100		sock_shutdown_str(st->msg.vdiag_shutdown),
    101		st->msg.vdiag_ino);
    102}
    103
    104static void print_vsock_stats(FILE *fp, struct list_head *head)
    105{
    106	struct vsock_stat *st;
    107
    108	list_for_each_entry(st, head, list)
    109		print_vsock_stat(fp, st);
    110}
    111
    112static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
    113{
    114	struct vsock_stat *st;
    115	struct stat stat;
    116
    117	if (fstat(fd, &stat) < 0) {
    118		perror("fstat");
    119		exit(EXIT_FAILURE);
    120	}
    121
    122	list_for_each_entry(st, head, list)
    123		if (st->msg.vdiag_ino == stat.st_ino)
    124			return st;
    125
    126	fprintf(stderr, "cannot find fd %d\n", fd);
    127	exit(EXIT_FAILURE);
    128}
    129
    /*
 * Fail the test unless @head is empty.  On failure the unexpected records
 * are dumped to stderr and the process exits with EXIT_FAILURE.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	/* Use the same failure status as every other check in this file. */
	exit(EXIT_FAILURE);
}
    138
    139static void check_num_sockets(struct list_head *head, int expected)
    140{
    141	struct list_head *node;
    142	int n = 0;
    143
    144	list_for_each(node, head)
    145		n++;
    146
    147	if (n != expected) {
    148		fprintf(stderr, "expected %d sockets, found %d\n",
    149			expected, n);
    150		print_vsock_stats(stderr, head);
    151		exit(EXIT_FAILURE);
    152	}
    153}
    154
    155static void check_socket_state(struct vsock_stat *st, __u8 state)
    156{
    157	if (st->msg.vdiag_state != state) {
    158		fprintf(stderr, "expected socket state %#x, got %#x\n",
    159			state, st->msg.vdiag_state);
    160		exit(EXIT_FAILURE);
    161	}
    162}
    163
    /*
 * Send one SOCK_DIAG_BY_FAMILY dump request (NLM_F_REQUEST | NLM_F_DUMP)
 * for AF_VSOCK over netlink socket @fd, with every bit of vdiag_states set.
 * sendmsg() is retried on EINTR; any other failure exits the process.
 */
static void send_req(int fd)
{
	struct sockaddr_nl kernel = { .nl_family = AF_NETLINK };
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec vec = { .iov_base = &req, .iov_len = sizeof(req) };
	struct msghdr hdr = {
		.msg_name = &kernel,
		.msg_namelen = sizeof(kernel),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &hdr, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
    206
    /*
 * Receive from netlink socket @fd into @buf (at most @len bytes) with a
 * single successful recvmsg() call.  EINTR is retried; any other error
 * exits the process.  Returns the number of bytes received (0 or more).
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl peer = { .nl_family = AF_NETLINK };
	struct iovec vec = { .iov_base = buf, .iov_len = len };
	struct msghdr hdr = {
		.msg_name = &peer,
		.msg_namelen = sizeof(peer),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t got = recvmsg(fd, &hdr, 0);

		if (got >= 0)
			return got;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
    235
    236static void add_vsock_stat(struct list_head *sockets,
    237			   const struct vsock_diag_msg *resp)
    238{
    239	struct vsock_stat *st;
    240
    241	st = malloc(sizeof(*st));
    242	if (!st) {
    243		perror("malloc");
    244		exit(EXIT_FAILURE);
    245	}
    246
    247	st->msg = *resp;
    248	list_add_tail(&st->list, sockets);
    249}
    250
    /*
 * Validate one netlink message from the sock_diag dump and, if it carries
 * a vsock_diag_msg, append a copy of it to @sockets.
 *
 * Returns 1 for NLMSG_DONE (end of dump) and 0 after storing a record.
 * Exits the process on NLMSG_ERROR, on any other unexpected nlmsg_type,
 * and on a payload shorter than struct vsock_diag_msg.
 */
static int handle_vsock_diag_nlmsg(struct list_head *sockets,
				   const struct nlmsghdr *h)
{
	if (h->nlmsg_type == NLMSG_DONE)
		return 1;

	if (h->nlmsg_type == NLMSG_ERROR) {
		const struct nlmsgerr *err = NLMSG_DATA(h);

		if (h->nlmsg_len >= NLMSG_LENGTH(sizeof(*err))) {
			/* err->error is negated to turn it into an errno for perror(). */
			errno = -err->error;
			perror("NLMSG_ERROR");
		} else {
			fprintf(stderr, "NLMSG_ERROR\n");
		}
		exit(EXIT_FAILURE);
	}

	if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
		fprintf(stderr, "unexpected nlmsg_type %#x\n", h->nlmsg_type);
		exit(EXIT_FAILURE);
	}

	if (h->nlmsg_len < NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
		fprintf(stderr, "short vsock_diag_msg\n");
		exit(EXIT_FAILURE);
	}

	add_vsock_stat(sockets, NLMSG_DATA(h));
	return 0;
}

/*
 * Read vsock stats into a list: open a NETLINK_SOCK_DIAG socket, send the
 * dump request built by send_req(), and append one struct vsock_stat per
 * reported socket to @sockets.  Stops at NLMSG_DONE or when recv returns
 * 0 bytes; every error exits the process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	long buf[8192 / sizeof(long)];	/* long[], presumably for nlmsghdr alignment */
	int finished = 0;
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	while (!finished) {
		const struct nlmsghdr *h = (const struct nlmsghdr *)buf;
		ssize_t len = recv_resp(fd, buf, sizeof(buf));

		if (len == 0)
			break;
		/* recv_resp() never returns a negative value, so the cast is safe. */
		if ((size_t)len < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", len);
			exit(EXIT_FAILURE);
		}

		/* NLMSG_NEXT() advances h and subtracts the consumed bytes from len. */
		for (; NLMSG_OK(h, len); h = NLMSG_NEXT(h, len)) {
			if (handle_vsock_diag_nlmsg(sockets, h)) {
				finished = 1;
				break;
			}
		}
	}

	close(fd);
}
    318
    319static void free_sock_stat(struct list_head *sockets)
    320{
    321	struct vsock_stat *st;
    322	struct vsock_stat *next;
    323
    324	list_for_each_entry_safe(st, next, sockets, list)
    325		free(st);
    326}
    327
    328static void test_no_sockets(const struct test_opts *opts)
    329{
    330	LIST_HEAD(sockets);
    331
    332	read_vsock_stat(&sockets);
    333
    334	check_no_sockets(&sockets);
    335}
    336
    337static void test_listen_socket_server(const struct test_opts *opts)
    338{
    339	union {
    340		struct sockaddr sa;
    341		struct sockaddr_vm svm;
    342	} addr = {
    343		.svm = {
    344			.svm_family = AF_VSOCK,
    345			.svm_port = 1234,
    346			.svm_cid = VMADDR_CID_ANY,
    347		},
    348	};
    349	LIST_HEAD(sockets);
    350	struct vsock_stat *st;
    351	int fd;
    352
    353	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
    354
    355	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
    356		perror("bind");
    357		exit(EXIT_FAILURE);
    358	}
    359
    360	if (listen(fd, 1) < 0) {
    361		perror("listen");
    362		exit(EXIT_FAILURE);
    363	}
    364
    365	read_vsock_stat(&sockets);
    366
    367	check_num_sockets(&sockets, 1);
    368	st = find_vsock_stat(&sockets, fd);
    369	check_socket_state(st, TCP_LISTEN);
    370
    371	close(fd);
    372	free_sock_stat(&sockets);
    373}
    374
    375static void test_connect_client(const struct test_opts *opts)
    376{
    377	int fd;
    378	LIST_HEAD(sockets);
    379	struct vsock_stat *st;
    380
    381	fd = vsock_stream_connect(opts->peer_cid, 1234);
    382	if (fd < 0) {
    383		perror("connect");
    384		exit(EXIT_FAILURE);
    385	}
    386
    387	read_vsock_stat(&sockets);
    388
    389	check_num_sockets(&sockets, 1);
    390	st = find_vsock_stat(&sockets, fd);
    391	check_socket_state(st, TCP_ESTABLISHED);
    392
    393	control_expectln("DONE");
    394	control_writeln("DONE");
    395
    396	close(fd);
    397	free_sock_stat(&sockets);
    398}
    399
    400static void test_connect_server(const struct test_opts *opts)
    401{
    402	struct vsock_stat *st;
    403	LIST_HEAD(sockets);
    404	int client_fd;
    405
    406	client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
    407	if (client_fd < 0) {
    408		perror("accept");
    409		exit(EXIT_FAILURE);
    410	}
    411
    412	read_vsock_stat(&sockets);
    413
    414	check_num_sockets(&sockets, 1);
    415	st = find_vsock_stat(&sockets, client_fd);
    416	check_socket_state(st, TCP_ESTABLISHED);
    417
    418	control_writeln("DONE");
    419	control_expectln("DONE");
    420
    421	close(client_fd);
    422	free_sock_stat(&sockets);
    423}
    424
    /*
 * Test table handed to run_tests(), list_tests() and skip_test() in main().
 * Entries that set only .run_server have no client-side step (presumably
 * run_tests() skips a NULL callback; it is defined outside this file).
 * The trailing empty entry terminates the table.  main() passes
 * ARRAY_SIZE(test_cases) - 1 to skip_test() so the terminator is not
 * counted.
 */
static struct test_case test_cases[] = {
	{
		.name = "No sockets",
		.run_server = test_no_sockets,
	},
	{
		.name = "Listen socket",
		.run_server = test_listen_socket_server,
	},
	{
		/* Client connects to the peer's port 1234; server accepts on it. */
		.name = "Connect",
		.run_client = test_connect_client,
		.run_server = test_connect_server,
	},
	{},
};
    441
    /* Empty: no short options; only the long options below are accepted. */
static const char optstring[] = "";
/*
 * Long options.  Each .val is the code getopt_long() returns and that
 * main() switches on.  --help returns '?', which main() sends to usage()
 * along with any unhandled code.
 */
static const struct option longopts[] = {
	{
		.name = "control-host",
		.has_arg = required_argument,
		.val = 'H',
	},
	{
		.name = "control-port",
		.has_arg = required_argument,
		.val = 'P',
	},
	{
		.name = "mode",
		.has_arg = required_argument,
		.val = 'm',
	},
	{
		.name = "peer-cid",
		.has_arg = required_argument,
		.val = 'p',
	},
	{
		.name = "list",
		.has_arg = no_argument,
		.val = 'l',
	},
	{
		/* May be repeated; each occurrence skips one test ID. */
		.name = "skip",
		.has_arg = required_argument,
		.val = 's',
	},
	{
		.name = "help",
		.has_arg = no_argument,
		.val = '?',
	},
	{},	/* terminator required by getopt_long() */
};
    481
    /*
 * Print command-line help to stderr and exit with EXIT_FAILURE.  main()
 * calls this for --help, for unhandled getopt codes, and for missing or
 * inconsistent arguments.
 */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock_diag.ko tests.  Must be launched in both\n"
		"guest and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	/* The text contains no conversion specifiers, so fputs() prints it verbatim. */
	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
    512
    513int main(int argc, char **argv)
    514{
    515	const char *control_host = NULL;
    516	const char *control_port = NULL;
    517	struct test_opts opts = {
    518		.mode = TEST_MODE_UNSET,
    519		.peer_cid = VMADDR_CID_ANY,
    520	};
    521
    522	init_signals();
    523
    524	for (;;) {
    525		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
    526
    527		if (opt == -1)
    528			break;
    529
    530		switch (opt) {
    531		case 'H':
    532			control_host = optarg;
    533			break;
    534		case 'm':
    535			if (strcmp(optarg, "client") == 0)
    536				opts.mode = TEST_MODE_CLIENT;
    537			else if (strcmp(optarg, "server") == 0)
    538				opts.mode = TEST_MODE_SERVER;
    539			else {
    540				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
    541				return EXIT_FAILURE;
    542			}
    543			break;
    544		case 'p':
    545			opts.peer_cid = parse_cid(optarg);
    546			break;
    547		case 'P':
    548			control_port = optarg;
    549			break;
    550		case 'l':
    551			list_tests(test_cases);
    552			break;
    553		case 's':
    554			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
    555				  optarg);
    556			break;
    557		case '?':
    558		default:
    559			usage();
    560		}
    561	}
    562
    563	if (!control_port)
    564		usage();
    565	if (opts.mode == TEST_MODE_UNSET)
    566		usage();
    567	if (opts.peer_cid == VMADDR_CID_ANY)
    568		usage();
    569
    570	if (!control_host) {
    571		if (opts.mode != TEST_MODE_SERVER)
    572			usage();
    573		control_host = "0.0.0.0";
    574	}
    575
    576	control_init(control_host, control_port,
    577		     opts.mode == TEST_MODE_SERVER);
    578
    579	run_tests(test_cases, &opts);
    580
    581	control_cleanup();
    582	return EXIT_SUCCESS;
    583}