cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

mptcp_inq.c (11383B)


      1// SPDX-License-Identifier: GPL-2.0
      2
      3#define _GNU_SOURCE
      4
      5#include <assert.h>
      6#include <errno.h>
      7#include <fcntl.h>
      8#include <limits.h>
      9#include <string.h>
     10#include <stdarg.h>
     11#include <stdbool.h>
     12#include <stdint.h>
     13#include <inttypes.h>
     14#include <stdio.h>
     15#include <stdlib.h>
     16#include <strings.h>
     17#include <unistd.h>
     18#include <time.h>
     19
     20#include <sys/ioctl.h>
     21#include <sys/socket.h>
     22#include <sys/types.h>
     23#include <sys/wait.h>
     24
     25#include <netdb.h>
     26#include <netinet/in.h>
     27
     28#include <linux/tcp.h>
     29#include <linux/sockios.h>
     30
     31#ifndef IPPROTO_MPTCP
     32#define IPPROTO_MPTCP 262
     33#endif
     34#ifndef SOL_MPTCP
     35#define SOL_MPTCP 284
     36#endif
     37
     38static int pf = AF_INET;
     39static int proto_tx = IPPROTO_MPTCP;
     40static int proto_rx = IPPROTO_MPTCP;
     41
     42static void die_perror(const char *msg)
     43{
     44	perror(msg);
     45	exit(1);
     46}
     47
     48static void die_usage(int r)
     49{
     50	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
     51	exit(r);
     52}
     53
     54static void xerror(const char *fmt, ...)
     55{
     56	va_list ap;
     57
     58	va_start(ap, fmt);
     59	vfprintf(stderr, fmt, ap);
     60	va_end(ap);
     61	fputc('\n', stderr);
     62	exit(1);
     63}
     64
     65static const char *getxinfo_strerr(int err)
     66{
     67	if (err == EAI_SYSTEM)
     68		return strerror(errno);
     69
     70	return gai_strerror(err);
     71}
     72
     73static void xgetaddrinfo(const char *node, const char *service,
     74			 const struct addrinfo *hints,
     75			 struct addrinfo **res)
     76{
     77	int err = getaddrinfo(node, service, hints, res);
     78
     79	if (err) {
     80		const char *errstr = getxinfo_strerr(err);
     81
     82		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
     83			node ? node : "", service ? service : "", errstr);
     84		exit(1);
     85	}
     86}
     87
     88static int sock_listen_mptcp(const char * const listenaddr,
     89			     const char * const port)
     90{
     91	int sock = -1;
     92	struct addrinfo hints = {
     93		.ai_protocol = IPPROTO_TCP,
     94		.ai_socktype = SOCK_STREAM,
     95		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
     96	};
     97
     98	hints.ai_family = pf;
     99
    100	struct addrinfo *a, *addr;
    101	int one = 1;
    102
    103	xgetaddrinfo(listenaddr, port, &hints, &addr);
    104	hints.ai_family = pf;
    105
    106	for (a = addr; a; a = a->ai_next) {
    107		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
    108		if (sock < 0)
    109			continue;
    110
    111		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
    112				     sizeof(one)))
    113			perror("setsockopt");
    114
    115		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
    116			break; /* success */
    117
    118		perror("bind");
    119		close(sock);
    120		sock = -1;
    121	}
    122
    123	freeaddrinfo(addr);
    124
    125	if (sock < 0)
    126		xerror("could not create listen socket");
    127
    128	if (listen(sock, 20))
    129		die_perror("listen");
    130
    131	return sock;
    132}
    133
    134static int sock_connect_mptcp(const char * const remoteaddr,
    135			      const char * const port, int proto)
    136{
    137	struct addrinfo hints = {
    138		.ai_protocol = IPPROTO_TCP,
    139		.ai_socktype = SOCK_STREAM,
    140	};
    141	struct addrinfo *a, *addr;
    142	int sock = -1;
    143
    144	hints.ai_family = pf;
    145
    146	xgetaddrinfo(remoteaddr, port, &hints, &addr);
    147	for (a = addr; a; a = a->ai_next) {
    148		sock = socket(a->ai_family, a->ai_socktype, proto);
    149		if (sock < 0)
    150			continue;
    151
    152		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
    153			break; /* success */
    154
    155		die_perror("connect");
    156	}
    157
    158	if (sock < 0)
    159		xerror("could not create connect socket");
    160
    161	freeaddrinfo(addr);
    162	return sock;
    163}
    164
    165static int protostr_to_num(const char *s)
    166{
    167	if (strcasecmp(s, "tcp") == 0)
    168		return IPPROTO_TCP;
    169	if (strcasecmp(s, "mptcp") == 0)
    170		return IPPROTO_MPTCP;
    171
    172	die_usage(1);
    173	return 0;
    174}
    175
    176static void parse_opts(int argc, char **argv)
    177{
    178	int c;
    179
    180	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
    181		switch (c) {
    182		case 'h':
    183			die_usage(0);
    184			break;
    185		case '6':
    186			pf = AF_INET6;
    187			break;
    188		case 't':
    189			proto_tx = protostr_to_num(optarg);
    190			break;
    191		case 'r':
    192			proto_rx = protostr_to_num(optarg);
    193			break;
    194		default:
    195			die_usage(1);
    196			break;
    197		}
    198	}
    199}
    200
    201/* wait up to timeout milliseconds */
    202static void wait_for_ack(int fd, int timeout, size_t total)
    203{
    204	int i;
    205
    206	for (i = 0; i < timeout; i++) {
    207		int nsd, ret, queued = -1;
    208		struct timespec req;
    209
    210		ret = ioctl(fd, TIOCOUTQ, &queued);
    211		if (ret < 0)
    212			die_perror("TIOCOUTQ");
    213
    214		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
    215		if (ret < 0)
    216			die_perror("SIOCOUTQNSD");
    217
    218		if ((size_t)queued > total)
    219			xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total);
    220		assert(nsd <= queued);
    221
    222		if (queued == 0)
    223			return;
    224
    225		/* wait for peer to ack rx of all data */
    226		req.tv_sec = 0;
    227		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
    228		nanosleep(&req, NULL);
    229	}
    230
    231	xerror("still tx data queued after %u ms\n", timeout);
    232}
    233
    234static void connect_one_server(int fd, int unixfd)
    235{
    236	size_t len, i, total, sent;
    237	char buf[4096], buf2[4096];
    238	ssize_t ret;
    239
    240	len = rand() % (sizeof(buf) - 1);
    241
    242	if (len < 128)
    243		len = 128;
    244
    245	for (i = 0; i < len ; i++) {
    246		buf[i] = rand() % 26;
    247		buf[i] += 'A';
    248	}
    249
    250	buf[i] = '\n';
    251
    252	/* un-block server */
    253	ret = read(unixfd, buf2, 4);
    254	assert(ret == 4);
    255
    256	assert(strncmp(buf2, "xmit", 4) == 0);
    257
    258	ret = write(unixfd, &len, sizeof(len));
    259	assert(ret == (ssize_t)sizeof(len));
    260
    261	ret = write(fd, buf, len);
    262	if (ret < 0)
    263		die_perror("write");
    264
    265	if (ret != (ssize_t)len)
    266		xerror("short write");
    267
    268	ret = read(unixfd, buf2, 4);
    269	assert(strncmp(buf2, "huge", 4) == 0);
    270
    271	total = rand() % (16 * 1024 * 1024);
    272	total += (1 * 1024 * 1024);
    273	sent = total;
    274
    275	ret = write(unixfd, &total, sizeof(total));
    276	assert(ret == (ssize_t)sizeof(total));
    277
    278	wait_for_ack(fd, 5000, len);
    279
    280	while (total > 0) {
    281		if (total > sizeof(buf))
    282			len = sizeof(buf);
    283		else
    284			len = total;
    285
    286		ret = write(fd, buf, len);
    287		if (ret < 0)
    288			die_perror("write");
    289		total -= ret;
    290
    291		/* we don't have to care about buf content, only
    292		 * number of total bytes sent
    293		 */
    294	}
    295
    296	ret = read(unixfd, buf2, 4);
    297	assert(ret == 4);
    298	assert(strncmp(buf2, "shut", 4) == 0);
    299
    300	wait_for_ack(fd, 5000, sent);
    301
    302	ret = write(fd, buf, 1);
    303	assert(ret == 1);
    304	close(fd);
    305	ret = write(unixfd, "closed", 6);
    306	assert(ret == 6);
    307
    308	close(unixfd);
    309}
    310
    311static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
    312{
    313	struct cmsghdr *cmsg;
    314
    315	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
    316		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
    317			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
    318			return;
    319		}
    320	}
    321
    322	xerror("could not find TCP_CM_INQ cmsg type");
    323}
    324
    325static void process_one_client(int fd, int unixfd)
    326{
    327	unsigned int tcp_inq;
    328	size_t expect_len;
    329	char msg_buf[4096];
    330	char buf[4096];
    331	char tmp[16];
    332	struct iovec iov = {
    333		.iov_base = buf,
    334		.iov_len = 1,
    335	};
    336	struct msghdr msg = {
    337		.msg_iov = &iov,
    338		.msg_iovlen = 1,
    339		.msg_control = msg_buf,
    340		.msg_controllen = sizeof(msg_buf),
    341	};
    342	ssize_t ret, tot;
    343
    344	ret = write(unixfd, "xmit", 4);
    345	assert(ret == 4);
    346
    347	ret = read(unixfd, &expect_len, sizeof(expect_len));
    348	assert(ret == (ssize_t)sizeof(expect_len));
    349
    350	if (expect_len > sizeof(buf))
    351		xerror("expect len %zu exceeds buffer size", expect_len);
    352
    353	for (;;) {
    354		struct timespec req;
    355		unsigned int queued;
    356
    357		ret = ioctl(fd, FIONREAD, &queued);
    358		if (ret < 0)
    359			die_perror("FIONREAD");
    360		if (queued > expect_len)
    361			xerror("FIONREAD returned %u, but only %zu expected\n",
    362			       queued, expect_len);
    363		if (queued == expect_len)
    364			break;
    365
    366		req.tv_sec = 0;
    367		req.tv_nsec = 1000 * 1000ul;
    368		nanosleep(&req, NULL);
    369	}
    370
    371	/* read one byte, expect cmsg to return expected - 1 */
    372	ret = recvmsg(fd, &msg, 0);
    373	if (ret < 0)
    374		die_perror("recvmsg");
    375
    376	if (msg.msg_controllen == 0)
    377		xerror("msg_controllen is 0");
    378
    379	get_tcp_inq(&msg, &tcp_inq);
    380
    381	assert((size_t)tcp_inq == (expect_len - 1));
    382
    383	iov.iov_len = sizeof(buf);
    384	ret = recvmsg(fd, &msg, 0);
    385	if (ret < 0)
    386		die_perror("recvmsg");
    387
    388	/* should have gotten exact remainder of all pending data */
    389	assert(ret == (ssize_t)tcp_inq);
    390
    391	/* should be 0, all drained */
    392	get_tcp_inq(&msg, &tcp_inq);
    393	assert(tcp_inq == 0);
    394
    395	/* request a large swath of data. */
    396	ret = write(unixfd, "huge", 4);
    397	assert(ret == 4);
    398
    399	ret = read(unixfd, &expect_len, sizeof(expect_len));
    400	assert(ret == (ssize_t)sizeof(expect_len));
    401
    402	/* peer should send us a few mb of data */
    403	if (expect_len <= sizeof(buf))
    404		xerror("expect len %zu too small\n", expect_len);
    405
    406	tot = 0;
    407	do {
    408		iov.iov_len = sizeof(buf);
    409		ret = recvmsg(fd, &msg, 0);
    410		if (ret < 0)
    411			die_perror("recvmsg");
    412
    413		tot += ret;
    414
    415		get_tcp_inq(&msg, &tcp_inq);
    416
    417		if (tcp_inq > expect_len - tot)
    418			xerror("inq %d, remaining %d total_len %d\n",
    419			       tcp_inq, expect_len - tot, (int)expect_len);
    420
    421		assert(tcp_inq <= expect_len - tot);
    422	} while ((size_t)tot < expect_len);
    423
    424	ret = write(unixfd, "shut", 4);
    425	assert(ret == 4);
    426
    427	/* wait for hangup. Should have received one more byte of data. */
    428	ret = read(unixfd, tmp, sizeof(tmp));
    429	assert(ret == 6);
    430	assert(strncmp(tmp, "closed", 6) == 0);
    431
    432	sleep(1);
    433
    434	iov.iov_len = 1;
    435	ret = recvmsg(fd, &msg, 0);
    436	if (ret < 0)
    437		die_perror("recvmsg");
    438	assert(ret == 1);
    439
    440	get_tcp_inq(&msg, &tcp_inq);
    441
    442	/* tcp_inq should be 1 due to received fin. */
    443	assert(tcp_inq == 1);
    444
    445	iov.iov_len = 1;
    446	ret = recvmsg(fd, &msg, 0);
    447	if (ret < 0)
    448		die_perror("recvmsg");
    449
    450	/* expect EOF */
    451	assert(ret == 0);
    452	get_tcp_inq(&msg, &tcp_inq);
    453	assert(tcp_inq == 1);
    454
    455	close(fd);
    456}
    457
    458static int xaccept(int s)
    459{
    460	int fd = accept(s, NULL, 0);
    461
    462	if (fd < 0)
    463		die_perror("accept");
    464
    465	return fd;
    466}
    467
    468static int server(int unixfd)
    469{
    470	int fd = -1, r, on = 1;
    471
    472	switch (pf) {
    473	case AF_INET:
    474		fd = sock_listen_mptcp("127.0.0.1", "15432");
    475		break;
    476	case AF_INET6:
    477		fd = sock_listen_mptcp("::1", "15432");
    478		break;
    479	default:
    480		xerror("Unknown pf %d\n", pf);
    481		break;
    482	}
    483
    484	r = write(unixfd, "conn", 4);
    485	assert(r == 4);
    486
    487	alarm(15);
    488	r = xaccept(fd);
    489
    490	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
    491		die_perror("setsockopt");
    492
    493	process_one_client(r, unixfd);
    494
    495	return 0;
    496}
    497
    498static int client(int unixfd)
    499{
    500	int fd = -1;
    501
    502	alarm(15);
    503
    504	switch (pf) {
    505	case AF_INET:
    506		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
    507		break;
    508	case AF_INET6:
    509		fd = sock_connect_mptcp("::1", "15432", proto_tx);
    510		break;
    511	default:
    512		xerror("Unknown pf %d\n", pf);
    513	}
    514
    515	connect_one_server(fd, unixfd);
    516
    517	return 0;
    518}
    519
    520static void init_rng(void)
    521{
    522	int fd = open("/dev/urandom", O_RDONLY);
    523	unsigned int foo;
    524
    525	if (fd > 0) {
    526		int ret = read(fd, &foo, sizeof(foo));
    527
    528		if (ret < 0)
    529			srand(fd + foo);
    530		close(fd);
    531	}
    532
    533	srand(foo);
    534}
    535
    536static pid_t xfork(void)
    537{
    538	pid_t p = fork();
    539
    540	if (p < 0)
    541		die_perror("fork");
    542	else if (p == 0)
    543		init_rng();
    544
    545	return p;
    546}
    547
    548static int rcheck(int wstatus, const char *what)
    549{
    550	if (WIFEXITED(wstatus)) {
    551		if (WEXITSTATUS(wstatus) == 0)
    552			return 0;
    553		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
    554		return WEXITSTATUS(wstatus);
    555	} else if (WIFSIGNALED(wstatus)) {
    556		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
    557	} else if (WIFSTOPPED(wstatus)) {
    558		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
    559	}
    560
    561	return 111;
    562}
    563
    564int main(int argc, char *argv[])
    565{
    566	int e1, e2, wstatus;
    567	pid_t s, c, ret;
    568	int unixfds[2];
    569
    570	parse_opts(argc, argv);
    571
    572	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
    573	if (e1 < 0)
    574		die_perror("pipe");
    575
    576	s = xfork();
    577	if (s == 0)
    578		return server(unixfds[1]);
    579
    580	close(unixfds[1]);
    581
    582	/* wait until server bound a socket */
    583	e1 = read(unixfds[0], &e1, 4);
    584	assert(e1 == 4);
    585
    586	c = xfork();
    587	if (c == 0)
    588		return client(unixfds[0]);
    589
    590	close(unixfds[0]);
    591
    592	ret = waitpid(s, &wstatus, 0);
    593	if (ret == -1)
    594		die_perror("waitpid");
    595	e1 = rcheck(wstatus, "server");
    596	ret = waitpid(c, &wstatus, 0);
    597	if (ret == -1)
    598		die_perror("waitpid");
    599	e2 = rcheck(wstatus, "client");
    600
    601	return e1 ? e1 : e2;
    602}