cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

mptcp_connect.c (28545B)


      1// SPDX-License-Identifier: GPL-2.0
      2
      3#define _GNU_SOURCE
      4
      5#include <errno.h>
      6#include <limits.h>
      7#include <fcntl.h>
      8#include <string.h>
      9#include <stdarg.h>
     10#include <stdbool.h>
     11#include <stdint.h>
     12#include <stdio.h>
     13#include <stdlib.h>
     14#include <strings.h>
     15#include <signal.h>
     16#include <unistd.h>
     17#include <time.h>
     18
     19#include <sys/ioctl.h>
     20#include <sys/poll.h>
     21#include <sys/sendfile.h>
     22#include <sys/stat.h>
     23#include <sys/socket.h>
     24#include <sys/types.h>
     25#include <sys/mman.h>
     26
     27#include <netdb.h>
     28#include <netinet/in.h>
     29
     30#include <linux/tcp.h>
     31#include <linux/time_types.h>
     32#include <linux/sockios.h>
     33
     34extern int optind;
     35
     36#ifndef IPPROTO_MPTCP
     37#define IPPROTO_MPTCP 262
     38#endif
     39#ifndef TCP_ULP
     40#define TCP_ULP 31
     41#endif
     42
     43static int  poll_timeout = 10 * 1000;
     44static bool listen_mode;
     45static bool quit;
     46
     47enum cfg_mode {
     48	CFG_MODE_POLL,
     49	CFG_MODE_MMAP,
     50	CFG_MODE_SENDFILE,
     51};
     52
     53enum cfg_peek {
     54	CFG_NONE_PEEK,
     55	CFG_WITH_PEEK,
     56	CFG_AFTER_PEEK,
     57};
     58
     59static enum cfg_mode cfg_mode = CFG_MODE_POLL;
     60static enum cfg_peek cfg_peek = CFG_NONE_PEEK;
     61static const char *cfg_host;
     62static const char *cfg_port	= "12000";
     63static int cfg_sock_proto	= IPPROTO_MPTCP;
     64static int pf = AF_INET;
     65static int cfg_sndbuf;
     66static int cfg_rcvbuf;
     67static bool cfg_join;
     68static bool cfg_remove;
     69static unsigned int cfg_time;
     70static unsigned int cfg_do_w;
     71static int cfg_wait;
     72static uint32_t cfg_mark;
     73static char *cfg_input;
     74static int cfg_repeat = 1;
     75
     76struct cfg_cmsg_types {
     77	unsigned int cmsg_enabled:1;
     78	unsigned int timestampns:1;
     79	unsigned int tcp_inq:1;
     80};
     81
     82struct cfg_sockopt_types {
     83	unsigned int transparent:1;
     84};
     85
     86struct tcp_inq_state {
     87	unsigned int last;
     88	bool expect_eof;
     89};
     90
     91static struct tcp_inq_state tcp_inq;
     92
     93static struct cfg_cmsg_types cfg_cmsg_types;
     94static struct cfg_sockopt_types cfg_sockopt_types;
     95
     96static void die_usage(void)
     97{
     98	fprintf(stderr, "Usage: mptcp_connect [-6] [-c cmsg] [-i file] [-I num] [-j] [-l] "
     99		"[-m mode] [-M mark] [-o option] [-p port] [-P mode] [-j] [-l] [-r num] "
    100		"[-s MPTCP|TCP] [-S num] [-r num] [-t num] [-T num] [-u] [-w sec] connect_address\n");
    101	fprintf(stderr, "\t-6 use ipv6\n");
    102	fprintf(stderr, "\t-c cmsg -- test cmsg type <cmsg>\n");
    103	fprintf(stderr, "\t-i file -- read the data to send from the given file instead of stdin");
    104	fprintf(stderr, "\t-I num -- repeat the transfer 'num' times. In listen mode accepts num "
    105		"incoming connections, in client mode, disconnect and reconnect to the server\n");
    106	fprintf(stderr, "\t-j     -- add additional sleep at connection start and tear down "
    107		"-- for MPJ tests\n");
    108	fprintf(stderr, "\t-l     -- listens mode, accepts incoming connection\n");
    109	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
    110	fprintf(stderr, "\t-M mark -- set socket packet mark\n");
    111	fprintf(stderr, "\t-o option -- test sockopt <option>\n");
    112	fprintf(stderr, "\t-p num -- use port num\n");
    113	fprintf(stderr,
    114		"\t-P [saveWithPeek|saveAfterPeek] -- save data with/after MSG_PEEK form tcp socket\n");
    115	fprintf(stderr, "\t-t num -- set poll timeout to num\n");
    116	fprintf(stderr, "\t-T num -- set expected runtime to num ms\n");
    117	fprintf(stderr, "\t-r num -- enable slow mode, limiting each write to num bytes "
    118		"-- for remove addr tests\n");
    119	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
    120	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
    121	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
    122	fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
    123	exit(1);
    124}
    125
    126static void xerror(const char *fmt, ...)
    127{
    128	va_list ap;
    129
    130	va_start(ap, fmt);
    131	vfprintf(stderr, fmt, ap);
    132	va_end(ap);
    133	exit(1);
    134}
    135
    136static void handle_signal(int nr)
    137{
    138	quit = true;
    139}
    140
    141static const char *getxinfo_strerr(int err)
    142{
    143	if (err == EAI_SYSTEM)
    144		return strerror(errno);
    145
    146	return gai_strerror(err);
    147}
    148
    149static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
    150			 char *host, socklen_t hostlen,
    151			 char *serv, socklen_t servlen)
    152{
    153	int flags = NI_NUMERICHOST | NI_NUMERICSERV;
    154	int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
    155			      flags);
    156
    157	if (err) {
    158		const char *errstr = getxinfo_strerr(err);
    159
    160		fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
    161		exit(1);
    162	}
    163}
    164
    165static void xgetaddrinfo(const char *node, const char *service,
    166			 const struct addrinfo *hints,
    167			 struct addrinfo **res)
    168{
    169	int err = getaddrinfo(node, service, hints, res);
    170
    171	if (err) {
    172		const char *errstr = getxinfo_strerr(err);
    173
    174		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
    175			node ? node : "", service ? service : "", errstr);
    176		exit(1);
    177	}
    178}
    179
    180static void set_rcvbuf(int fd, unsigned int size)
    181{
    182	int err;
    183
    184	err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size));
    185	if (err) {
    186		perror("set SO_RCVBUF");
    187		exit(1);
    188	}
    189}
    190
    191static void set_sndbuf(int fd, unsigned int size)
    192{
    193	int err;
    194
    195	err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
    196	if (err) {
    197		perror("set SO_SNDBUF");
    198		exit(1);
    199	}
    200}
    201
    202static void set_mark(int fd, uint32_t mark)
    203{
    204	int err;
    205
    206	err = setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark));
    207	if (err) {
    208		perror("set SO_MARK");
    209		exit(1);
    210	}
    211}
    212
    213static void set_transparent(int fd, int pf)
    214{
    215	int one = 1;
    216
    217	switch (pf) {
    218	case AF_INET:
    219		if (-1 == setsockopt(fd, SOL_IP, IP_TRANSPARENT, &one, sizeof(one)))
    220			perror("IP_TRANSPARENT");
    221		break;
    222	case AF_INET6:
    223		if (-1 == setsockopt(fd, IPPROTO_IPV6, IPV6_TRANSPARENT, &one, sizeof(one)))
    224			perror("IPV6_TRANSPARENT");
    225		break;
    226	}
    227}
    228
    229static int do_ulp_so(int sock, const char *name)
    230{
    231	return setsockopt(sock, IPPROTO_TCP, TCP_ULP, name, strlen(name));
    232}
    233
    234#define X(m)	xerror("%s:%u: %s: failed for proto %d at line %u", __FILE__, __LINE__, (m), proto, line)
    235static void sock_test_tcpulp(int sock, int proto, unsigned int line)
    236{
    237	socklen_t buflen = 8;
    238	char buf[8] = "";
    239	int ret = getsockopt(sock, IPPROTO_TCP, TCP_ULP, buf, &buflen);
    240
    241	if (ret != 0)
    242		X("getsockopt");
    243
    244	if (buflen > 0) {
    245		if (strcmp(buf, "mptcp") != 0)
    246			xerror("unexpected ULP '%s' for proto %d at line %u", buf, proto, line);
    247		ret = do_ulp_so(sock, "tls");
    248		if (ret == 0)
    249			X("setsockopt");
    250	} else if (proto == IPPROTO_MPTCP) {
    251		ret = do_ulp_so(sock, "tls");
    252		if (ret != -1)
    253			X("setsockopt");
    254	}
    255
    256	ret = do_ulp_so(sock, "mptcp");
    257	if (ret != -1)
    258		X("setsockopt");
    259
    260#undef X
    261}
    262
    263#define SOCK_TEST_TCPULP(s, p) sock_test_tcpulp((s), (p), __LINE__)
    264
    265static int sock_listen_mptcp(const char * const listenaddr,
    266			     const char * const port)
    267{
    268	int sock = -1;
    269	struct addrinfo hints = {
    270		.ai_protocol = IPPROTO_TCP,
    271		.ai_socktype = SOCK_STREAM,
    272		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
    273	};
    274
    275	hints.ai_family = pf;
    276
    277	struct addrinfo *a, *addr;
    278	int one = 1;
    279
    280	xgetaddrinfo(listenaddr, port, &hints, &addr);
    281	hints.ai_family = pf;
    282
    283	for (a = addr; a; a = a->ai_next) {
    284		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
    285		if (sock < 0)
    286			continue;
    287
    288		SOCK_TEST_TCPULP(sock, cfg_sock_proto);
    289
    290		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
    291				     sizeof(one)))
    292			perror("setsockopt");
    293
    294		if (cfg_sockopt_types.transparent)
    295			set_transparent(sock, pf);
    296
    297		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
    298			break; /* success */
    299
    300		perror("bind");
    301		close(sock);
    302		sock = -1;
    303	}
    304
    305	freeaddrinfo(addr);
    306
    307	if (sock < 0) {
    308		fprintf(stderr, "Could not create listen socket\n");
    309		return sock;
    310	}
    311
    312	SOCK_TEST_TCPULP(sock, cfg_sock_proto);
    313
    314	if (listen(sock, 20)) {
    315		perror("listen");
    316		close(sock);
    317		return -1;
    318	}
    319
    320	SOCK_TEST_TCPULP(sock, cfg_sock_proto);
    321
    322	return sock;
    323}
    324
    325static int sock_connect_mptcp(const char * const remoteaddr,
    326			      const char * const port, int proto,
    327			      struct addrinfo **peer)
    328{
    329	struct addrinfo hints = {
    330		.ai_protocol = IPPROTO_TCP,
    331		.ai_socktype = SOCK_STREAM,
    332	};
    333	struct addrinfo *a, *addr;
    334	int sock = -1;
    335
    336	hints.ai_family = pf;
    337
    338	xgetaddrinfo(remoteaddr, port, &hints, &addr);
    339	for (a = addr; a; a = a->ai_next) {
    340		sock = socket(a->ai_family, a->ai_socktype, proto);
    341		if (sock < 0) {
    342			perror("socket");
    343			continue;
    344		}
    345
    346		SOCK_TEST_TCPULP(sock, proto);
    347
    348		if (cfg_mark)
    349			set_mark(sock, cfg_mark);
    350
    351		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0) {
    352			*peer = a;
    353			break; /* success */
    354		}
    355
    356		perror("connect()");
    357		close(sock);
    358		sock = -1;
    359	}
    360
    361	freeaddrinfo(addr);
    362	if (sock != -1)
    363		SOCK_TEST_TCPULP(sock, proto);
    364	return sock;
    365}
    366
    367static size_t do_rnd_write(const int fd, char *buf, const size_t len)
    368{
    369	static bool first = true;
    370	unsigned int do_w;
    371	ssize_t bw;
    372
    373	do_w = rand() & 0xffff;
    374	if (do_w == 0 || do_w > len)
    375		do_w = len;
    376
    377	if (cfg_join && first && do_w > 100)
    378		do_w = 100;
    379
    380	if (cfg_remove && do_w > cfg_do_w)
    381		do_w = cfg_do_w;
    382
    383	bw = write(fd, buf, do_w);
    384	if (bw < 0)
    385		perror("write");
    386
    387	/* let the join handshake complete, before going on */
    388	if (cfg_join && first) {
    389		usleep(200000);
    390		first = false;
    391	}
    392
    393	if (cfg_remove)
    394		usleep(200000);
    395
    396	return bw;
    397}
    398
    399static size_t do_write(const int fd, char *buf, const size_t len)
    400{
    401	size_t offset = 0;
    402
    403	while (offset < len) {
    404		size_t written;
    405		ssize_t bw;
    406
    407		bw = write(fd, buf + offset, len - offset);
    408		if (bw < 0) {
    409			perror("write");
    410			return 0;
    411		}
    412
    413		written = (size_t)bw;
    414		offset += written;
    415	}
    416
    417	return offset;
    418}
    419
    420static void process_cmsg(struct msghdr *msgh)
    421{
    422	struct __kernel_timespec ts;
    423	bool inq_found = false;
    424	bool ts_found = false;
    425	unsigned int inq = 0;
    426	struct cmsghdr *cmsg;
    427
    428	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
    429		if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SO_TIMESTAMPNS_NEW) {
    430			memcpy(&ts, CMSG_DATA(cmsg), sizeof(ts));
    431			ts_found = true;
    432			continue;
    433		}
    434		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
    435			memcpy(&inq, CMSG_DATA(cmsg), sizeof(inq));
    436			inq_found = true;
    437			continue;
    438		}
    439
    440	}
    441
    442	if (cfg_cmsg_types.timestampns) {
    443		if (!ts_found)
    444			xerror("TIMESTAMPNS not present\n");
    445	}
    446
    447	if (cfg_cmsg_types.tcp_inq) {
    448		if (!inq_found)
    449			xerror("TCP_INQ not present\n");
    450
    451		if (inq > 1024)
    452			xerror("tcp_inq %u is larger than one kbyte\n", inq);
    453		tcp_inq.last = inq;
    454	}
    455}
    456
    457static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
    458{
    459	char msg_buf[8192];
    460	struct iovec iov = {
    461		.iov_base = buf,
    462		.iov_len = len,
    463	};
    464	struct msghdr msg = {
    465		.msg_iov = &iov,
    466		.msg_iovlen = 1,
    467		.msg_control = msg_buf,
    468		.msg_controllen = sizeof(msg_buf),
    469	};
    470	int flags = 0;
    471	unsigned int last_hint = tcp_inq.last;
    472	int ret = recvmsg(fd, &msg, flags);
    473
    474	if (ret <= 0) {
    475		if (ret == 0 && tcp_inq.expect_eof)
    476			return ret;
    477
    478		if (ret == 0 && cfg_cmsg_types.tcp_inq)
    479			if (last_hint != 1 && last_hint != 0)
    480				xerror("EOF but last tcp_inq hint was %u\n", last_hint);
    481
    482		return ret;
    483	}
    484
    485	if (tcp_inq.expect_eof)
    486		xerror("expected EOF, last_hint %u, now %u\n",
    487		       last_hint, tcp_inq.last);
    488
    489	if (msg.msg_controllen && !cfg_cmsg_types.cmsg_enabled)
    490		xerror("got %lu bytes of cmsg data, expected 0\n",
    491		       (unsigned long)msg.msg_controllen);
    492
    493	if (msg.msg_controllen == 0 && cfg_cmsg_types.cmsg_enabled)
    494		xerror("%s\n", "got no cmsg data");
    495
    496	if (msg.msg_controllen)
    497		process_cmsg(&msg);
    498
    499	if (cfg_cmsg_types.tcp_inq) {
    500		if ((size_t)ret < len && last_hint > (unsigned int)ret) {
    501			if (ret + 1 != (int)last_hint) {
    502				int next = read(fd, msg_buf, sizeof(msg_buf));
    503
    504				xerror("read %u of %u, last_hint was %u tcp_inq hint now %u next_read returned %d/%m\n",
    505				       ret, (unsigned int)len, last_hint, tcp_inq.last, next);
    506			} else {
    507				tcp_inq.expect_eof = true;
    508			}
    509		}
    510	}
    511
    512	return ret;
    513}
    514
    515static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
    516{
    517	int ret = 0;
    518	char tmp[16384];
    519	size_t cap = rand();
    520
    521	cap &= 0xffff;
    522
    523	if (cap == 0)
    524		cap = 1;
    525	else if (cap > len)
    526		cap = len;
    527
    528	if (cfg_peek == CFG_WITH_PEEK) {
    529		ret = recv(fd, buf, cap, MSG_PEEK);
    530		ret = (ret < 0) ? ret : read(fd, tmp, ret);
    531	} else if (cfg_peek == CFG_AFTER_PEEK) {
    532		ret = recv(fd, buf, cap, MSG_PEEK);
    533		ret = (ret < 0) ? ret : read(fd, buf, cap);
    534	} else if (cfg_cmsg_types.cmsg_enabled) {
    535		ret = do_recvmsg_cmsg(fd, buf, cap);
    536	} else {
    537		ret = read(fd, buf, cap);
    538	}
    539
    540	return ret;
    541}
    542
    543static void set_nonblock(int fd, bool nonblock)
    544{
    545	int flags = fcntl(fd, F_GETFL);
    546
    547	if (flags == -1)
    548		return;
    549
    550	if (nonblock)
    551		fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    552	else
    553		fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);
    554}
    555
    556static int copyfd_io_poll(int infd, int peerfd, int outfd, bool *in_closed_after_out)
    557{
    558	struct pollfd fds = {
    559		.fd = peerfd,
    560		.events = POLLIN | POLLOUT,
    561	};
    562	unsigned int woff = 0, wlen = 0;
    563	char wbuf[8192];
    564
    565	set_nonblock(peerfd, true);
    566
    567	for (;;) {
    568		char rbuf[8192];
    569		ssize_t len;
    570
    571		if (fds.events == 0)
    572			break;
    573
    574		switch (poll(&fds, 1, poll_timeout)) {
    575		case -1:
    576			if (errno == EINTR)
    577				continue;
    578			perror("poll");
    579			return 1;
    580		case 0:
    581			fprintf(stderr, "%s: poll timed out (events: "
    582				"POLLIN %u, POLLOUT %u)\n", __func__,
    583				fds.events & POLLIN, fds.events & POLLOUT);
    584			return 2;
    585		}
    586
    587		if (fds.revents & POLLIN) {
    588			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
    589			if (len == 0) {
    590				/* no more data to receive:
    591				 * peer has closed its write side
    592				 */
    593				fds.events &= ~POLLIN;
    594
    595				if ((fds.events & POLLOUT) == 0) {
    596					*in_closed_after_out = true;
    597					/* and nothing more to send */
    598					break;
    599				}
    600
    601			/* Else, still have data to transmit */
    602			} else if (len < 0) {
    603				perror("read");
    604				return 3;
    605			}
    606
    607			do_write(outfd, rbuf, len);
    608		}
    609
    610		if (fds.revents & POLLOUT) {
    611			if (wlen == 0) {
    612				woff = 0;
    613				wlen = read(infd, wbuf, sizeof(wbuf));
    614			}
    615
    616			if (wlen > 0) {
    617				ssize_t bw;
    618
    619				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
    620				if (bw < 0)
    621					return 111;
    622
    623				woff += bw;
    624				wlen -= bw;
    625			} else if (wlen == 0) {
    626				/* We have no more data to send. */
    627				fds.events &= ~POLLOUT;
    628
    629				if ((fds.events & POLLIN) == 0)
    630					/* ... and peer also closed already */
    631					break;
    632
    633				/* ... but we still receive.
    634				 * Close our write side, ev. give some time
    635				 * for address notification and/or checking
    636				 * the current status
    637				 */
    638				if (cfg_wait)
    639					usleep(cfg_wait);
    640				shutdown(peerfd, SHUT_WR);
    641			} else {
    642				if (errno == EINTR)
    643					continue;
    644				perror("read");
    645				return 4;
    646			}
    647		}
    648
    649		if (fds.revents & (POLLERR | POLLNVAL)) {
    650			fprintf(stderr, "Unexpected revents: "
    651				"POLLERR/POLLNVAL(%x)\n", fds.revents);
    652			return 5;
    653		}
    654	}
    655
    656	/* leave some time for late join/announce */
    657	if (cfg_remove)
    658		usleep(cfg_wait);
    659
    660	return 0;
    661}
    662
    663static int do_recvfile(int infd, int outfd)
    664{
    665	ssize_t r;
    666
    667	do {
    668		char buf[16384];
    669
    670		r = do_rnd_read(infd, buf, sizeof(buf));
    671		if (r > 0) {
    672			if (write(outfd, buf, r) != r)
    673				break;
    674		} else if (r < 0) {
    675			perror("read");
    676		}
    677	} while (r > 0);
    678
    679	return (int)r;
    680}
    681
    682static int do_mmap(int infd, int outfd, unsigned int size)
    683{
    684	char *inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
    685	ssize_t ret = 0, off = 0;
    686	size_t rem;
    687
    688	if (inbuf == MAP_FAILED) {
    689		perror("mmap");
    690		return 1;
    691	}
    692
    693	rem = size;
    694
    695	while (rem > 0) {
    696		ret = write(outfd, inbuf + off, rem);
    697
    698		if (ret < 0) {
    699			perror("write");
    700			break;
    701		}
    702
    703		off += ret;
    704		rem -= ret;
    705	}
    706
    707	munmap(inbuf, size);
    708	return rem;
    709}
    710
    711static int get_infd_size(int fd)
    712{
    713	struct stat sb;
    714	ssize_t count;
    715	int err;
    716
    717	err = fstat(fd, &sb);
    718	if (err < 0) {
    719		perror("fstat");
    720		return -1;
    721	}
    722
    723	if ((sb.st_mode & S_IFMT) != S_IFREG) {
    724		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
    725		return -2;
    726	}
    727
    728	count = sb.st_size;
    729	if (count > INT_MAX) {
    730		fprintf(stderr, "File too large: %zu\n", count);
    731		return -3;
    732	}
    733
    734	return (int)count;
    735}
    736
    737static int do_sendfile(int infd, int outfd, unsigned int count)
    738{
    739	while (count > 0) {
    740		ssize_t r;
    741
    742		r = sendfile(outfd, infd, NULL, count);
    743		if (r < 0) {
    744			perror("sendfile");
    745			return 3;
    746		}
    747
    748		count -= r;
    749	}
    750
    751	return 0;
    752}
    753
    754static int copyfd_io_mmap(int infd, int peerfd, int outfd,
    755			  unsigned int size, bool *in_closed_after_out)
    756{
    757	int err;
    758
    759	if (listen_mode) {
    760		err = do_recvfile(peerfd, outfd);
    761		if (err)
    762			return err;
    763
    764		err = do_mmap(infd, peerfd, size);
    765	} else {
    766		err = do_mmap(infd, peerfd, size);
    767		if (err)
    768			return err;
    769
    770		shutdown(peerfd, SHUT_WR);
    771
    772		err = do_recvfile(peerfd, outfd);
    773		*in_closed_after_out = true;
    774	}
    775
    776	return err;
    777}
    778
    779static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
    780			      unsigned int size, bool *in_closed_after_out)
    781{
    782	int err;
    783
    784	if (listen_mode) {
    785		err = do_recvfile(peerfd, outfd);
    786		if (err)
    787			return err;
    788
    789		err = do_sendfile(infd, peerfd, size);
    790	} else {
    791		err = do_sendfile(infd, peerfd, size);
    792		if (err)
    793			return err;
    794		err = do_recvfile(peerfd, outfd);
    795		*in_closed_after_out = true;
    796	}
    797
    798	return err;
    799}
    800
    801static int copyfd_io(int infd, int peerfd, int outfd, bool close_peerfd)
    802{
    803	bool in_closed_after_out = false;
    804	struct timespec start, end;
    805	int file_size;
    806	int ret;
    807
    808	if (cfg_time && (clock_gettime(CLOCK_MONOTONIC, &start) < 0))
    809		xerror("can not fetch start time %d", errno);
    810
    811	switch (cfg_mode) {
    812	case CFG_MODE_POLL:
    813		ret = copyfd_io_poll(infd, peerfd, outfd, &in_closed_after_out);
    814		break;
    815
    816	case CFG_MODE_MMAP:
    817		file_size = get_infd_size(infd);
    818		if (file_size < 0)
    819			return file_size;
    820		ret = copyfd_io_mmap(infd, peerfd, outfd, file_size, &in_closed_after_out);
    821		break;
    822
    823	case CFG_MODE_SENDFILE:
    824		file_size = get_infd_size(infd);
    825		if (file_size < 0)
    826			return file_size;
    827		ret = copyfd_io_sendfile(infd, peerfd, outfd, file_size, &in_closed_after_out);
    828		break;
    829
    830	default:
    831		fprintf(stderr, "Invalid mode %d\n", cfg_mode);
    832
    833		die_usage();
    834		return 1;
    835	}
    836
    837	if (ret)
    838		return ret;
    839
    840	if (close_peerfd)
    841		close(peerfd);
    842
    843	if (cfg_time) {
    844		unsigned int delta_ms;
    845
    846		if (clock_gettime(CLOCK_MONOTONIC, &end) < 0)
    847			xerror("can not fetch end time %d", errno);
    848		delta_ms = (end.tv_sec - start.tv_sec) * 1000 + (end.tv_nsec - start.tv_nsec) / 1000000;
    849		if (delta_ms > cfg_time) {
    850			xerror("transfer slower than expected! runtime %d ms, expected %d ms",
    851			       delta_ms, cfg_time);
    852		}
    853
    854		/* show the runtime only if this end shutdown(wr) before receiving the EOF,
    855		 * (that is, if this end got the longer runtime)
    856		 */
    857		if (in_closed_after_out)
    858			fprintf(stderr, "%d", delta_ms);
    859	}
    860
    861	return 0;
    862}
    863
    864static void check_sockaddr(int pf, struct sockaddr_storage *ss,
    865			   socklen_t salen)
    866{
    867	struct sockaddr_in6 *sin6;
    868	struct sockaddr_in *sin;
    869	socklen_t wanted_size = 0;
    870
    871	switch (pf) {
    872	case AF_INET:
    873		wanted_size = sizeof(*sin);
    874		sin = (void *)ss;
    875		if (!sin->sin_port)
    876			fprintf(stderr, "accept: something wrong: ip connection from port 0");
    877		break;
    878	case AF_INET6:
    879		wanted_size = sizeof(*sin6);
    880		sin6 = (void *)ss;
    881		if (!sin6->sin6_port)
    882			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0");
    883		break;
    884	default:
    885		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
    886		return;
    887	}
    888
    889	if (salen != wanted_size)
    890		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
    891			(int)salen, wanted_size);
    892
    893	if (ss->ss_family != pf)
    894		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
    895			(int)ss->ss_family, pf);
    896}
    897
    898static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
    899{
    900	struct sockaddr_storage peerss;
    901	socklen_t peersalen = sizeof(peerss);
    902
    903	if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
    904		perror("getpeername");
    905		return;
    906	}
    907
    908	if (peersalen != salen) {
    909		fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
    910		return;
    911	}
    912
    913	if (memcmp(ss, &peerss, peersalen)) {
    914		char a[INET6_ADDRSTRLEN];
    915		char b[INET6_ADDRSTRLEN];
    916		char c[INET6_ADDRSTRLEN];
    917		char d[INET6_ADDRSTRLEN];
    918
    919		xgetnameinfo((struct sockaddr *)ss, salen,
    920			     a, sizeof(a), b, sizeof(b));
    921
    922		xgetnameinfo((struct sockaddr *)&peerss, peersalen,
    923			     c, sizeof(c), d, sizeof(d));
    924
    925		fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
    926			__func__, a, c, b, d, peersalen, salen);
    927	}
    928}
    929
    930static void check_getpeername_connect(int fd)
    931{
    932	struct sockaddr_storage ss;
    933	socklen_t salen = sizeof(ss);
    934	char a[INET6_ADDRSTRLEN];
    935	char b[INET6_ADDRSTRLEN];
    936
    937	if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
    938		perror("getpeername");
    939		return;
    940	}
    941
    942	xgetnameinfo((struct sockaddr *)&ss, salen,
    943		     a, sizeof(a), b, sizeof(b));
    944
    945	if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
    946		fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
    947			cfg_host, a, cfg_port, b);
    948}
    949
    950static void maybe_close(int fd)
    951{
    952	unsigned int r = rand();
    953
    954	if (!(cfg_join || cfg_remove || cfg_repeat > 1) && (r & 1))
    955		close(fd);
    956}
    957
    958int main_loop_s(int listensock)
    959{
    960	struct sockaddr_storage ss;
    961	struct pollfd polls;
    962	socklen_t salen;
    963	int remotesock;
    964	int fd = 0;
    965
    966again:
    967	polls.fd = listensock;
    968	polls.events = POLLIN;
    969
    970	switch (poll(&polls, 1, poll_timeout)) {
    971	case -1:
    972		perror("poll");
    973		return 1;
    974	case 0:
    975		fprintf(stderr, "%s: timed out\n", __func__);
    976		close(listensock);
    977		return 2;
    978	}
    979
    980	salen = sizeof(ss);
    981	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
    982	if (remotesock >= 0) {
    983		maybe_close(listensock);
    984		check_sockaddr(pf, &ss, salen);
    985		check_getpeername(remotesock, &ss, salen);
    986
    987		if (cfg_input) {
    988			fd = open(cfg_input, O_RDONLY);
    989			if (fd < 0)
    990				xerror("can't open %s: %d", cfg_input, errno);
    991		}
    992
    993		SOCK_TEST_TCPULP(remotesock, 0);
    994
    995		copyfd_io(fd, remotesock, 1, true);
    996	} else {
    997		perror("accept");
    998		return 1;
    999	}
   1000
   1001	if (--cfg_repeat > 0) {
   1002		if (cfg_input)
   1003			close(fd);
   1004		goto again;
   1005	}
   1006
   1007	return 0;
   1008}
   1009
   1010static void init_rng(void)
   1011{
   1012	int fd = open("/dev/urandom", O_RDONLY);
   1013	unsigned int foo;
   1014
   1015	if (fd > 0) {
   1016		int ret = read(fd, &foo, sizeof(foo));
   1017
   1018		if (ret < 0)
   1019			srand(fd + foo);
   1020		close(fd);
   1021	}
   1022
   1023	srand(foo);
   1024}
   1025
   1026static void xsetsockopt(int fd, int level, int optname, const void *optval, socklen_t optlen)
   1027{
   1028	int err;
   1029
   1030	err = setsockopt(fd, level, optname, optval, optlen);
   1031	if (err) {
   1032		perror("setsockopt");
   1033		exit(1);
   1034	}
   1035}
   1036
   1037static void apply_cmsg_types(int fd, const struct cfg_cmsg_types *cmsg)
   1038{
   1039	static const unsigned int on = 1;
   1040
   1041	if (cmsg->timestampns)
   1042		xsetsockopt(fd, SOL_SOCKET, SO_TIMESTAMPNS_NEW, &on, sizeof(on));
   1043	if (cmsg->tcp_inq)
   1044		xsetsockopt(fd, IPPROTO_TCP, TCP_INQ, &on, sizeof(on));
   1045}
   1046
   1047static void parse_cmsg_types(const char *type)
   1048{
   1049	char *next = strchr(type, ',');
   1050	unsigned int len = 0;
   1051
   1052	cfg_cmsg_types.cmsg_enabled = 1;
   1053
   1054	if (next) {
   1055		parse_cmsg_types(next + 1);
   1056		len = next - type;
   1057	} else {
   1058		len = strlen(type);
   1059	}
   1060
   1061	if (strncmp(type, "TIMESTAMPNS", len) == 0) {
   1062		cfg_cmsg_types.timestampns = 1;
   1063		return;
   1064	}
   1065
   1066	if (strncmp(type, "TCPINQ", len) == 0) {
   1067		cfg_cmsg_types.tcp_inq = 1;
   1068		return;
   1069	}
   1070
   1071	fprintf(stderr, "Unrecognized cmsg option %s\n", type);
   1072	exit(1);
   1073}
   1074
   1075static void parse_setsock_options(const char *name)
   1076{
   1077	char *next = strchr(name, ',');
   1078	unsigned int len = 0;
   1079
   1080	if (next) {
   1081		parse_setsock_options(next + 1);
   1082		len = next - name;
   1083	} else {
   1084		len = strlen(name);
   1085	}
   1086
   1087	if (strncmp(name, "TRANSPARENT", len) == 0) {
   1088		cfg_sockopt_types.transparent = 1;
   1089		return;
   1090	}
   1091
   1092	fprintf(stderr, "Unrecognized setsockopt option %s\n", name);
   1093	exit(1);
   1094}
   1095
   1096void xdisconnect(int fd, int addrlen)
   1097{
   1098	struct sockaddr_storage empty;
   1099	int msec_sleep = 10;
   1100	int queued = 1;
   1101	int i;
   1102
   1103	shutdown(fd, SHUT_WR);
   1104
   1105	/* while until the pending data is completely flushed, the later
   1106	 * disconnect will bypass/ignore/drop any pending data.
   1107	 */
   1108	for (i = 0; ; i += msec_sleep) {
   1109		if (ioctl(fd, SIOCOUTQ, &queued) < 0)
   1110			xerror("can't query out socket queue: %d", errno);
   1111
   1112		if (!queued)
   1113			break;
   1114
   1115		if (i > poll_timeout)
   1116			xerror("timeout while waiting for spool to complete");
   1117		usleep(msec_sleep * 1000);
   1118	}
   1119
   1120	memset(&empty, 0, sizeof(empty));
   1121	empty.ss_family = AF_UNSPEC;
   1122	if (connect(fd, (struct sockaddr *)&empty, addrlen) < 0)
   1123		xerror("can't disconnect: %d", errno);
   1124}
   1125
   1126int main_loop(void)
   1127{
   1128	int fd, ret, fd_in = 0;
   1129	struct addrinfo *peer;
   1130
   1131	/* listener is ready. */
   1132	fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto, &peer);
   1133	if (fd < 0)
   1134		return 2;
   1135
   1136again:
   1137	check_getpeername_connect(fd);
   1138
   1139	SOCK_TEST_TCPULP(fd, cfg_sock_proto);
   1140
   1141	if (cfg_rcvbuf)
   1142		set_rcvbuf(fd, cfg_rcvbuf);
   1143	if (cfg_sndbuf)
   1144		set_sndbuf(fd, cfg_sndbuf);
   1145	if (cfg_cmsg_types.cmsg_enabled)
   1146		apply_cmsg_types(fd, &cfg_cmsg_types);
   1147
   1148	if (cfg_input) {
   1149		fd_in = open(cfg_input, O_RDONLY);
   1150		if (fd < 0)
   1151			xerror("can't open %s:%d", cfg_input, errno);
   1152	}
   1153
   1154	/* close the client socket open only if we are not going to reconnect */
   1155	ret = copyfd_io(fd_in, fd, 1, cfg_repeat == 1);
   1156	if (ret)
   1157		return ret;
   1158
   1159	if (--cfg_repeat > 0) {
   1160		xdisconnect(fd, peer->ai_addrlen);
   1161
   1162		/* the socket could be unblocking at this point, we need the
   1163		 * connect to be blocking
   1164		 */
   1165		set_nonblock(fd, false);
   1166		if (connect(fd, peer->ai_addr, peer->ai_addrlen))
   1167			xerror("can't reconnect: %d", errno);
   1168		if (cfg_input)
   1169			close(fd_in);
   1170		goto again;
   1171	}
   1172	return 0;
   1173}
   1174
   1175int parse_proto(const char *proto)
   1176{
   1177	if (!strcasecmp(proto, "MPTCP"))
   1178		return IPPROTO_MPTCP;
   1179	if (!strcasecmp(proto, "TCP"))
   1180		return IPPROTO_TCP;
   1181
   1182	fprintf(stderr, "Unknown protocol: %s\n.", proto);
   1183	die_usage();
   1184
   1185	/* silence compiler warning */
   1186	return 0;
   1187}
   1188
   1189int parse_mode(const char *mode)
   1190{
   1191	if (!strcasecmp(mode, "poll"))
   1192		return CFG_MODE_POLL;
   1193	if (!strcasecmp(mode, "mmap"))
   1194		return CFG_MODE_MMAP;
   1195	if (!strcasecmp(mode, "sendfile"))
   1196		return CFG_MODE_SENDFILE;
   1197
   1198	fprintf(stderr, "Unknown test mode: %s\n", mode);
   1199	fprintf(stderr, "Supported modes are:\n");
   1200	fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
   1201	fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
   1202	fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
   1203
   1204	die_usage();
   1205
   1206	/* silence compiler warning */
   1207	return 0;
   1208}
   1209
   1210int parse_peek(const char *mode)
   1211{
   1212	if (!strcasecmp(mode, "saveWithPeek"))
   1213		return CFG_WITH_PEEK;
   1214	if (!strcasecmp(mode, "saveAfterPeek"))
   1215		return CFG_AFTER_PEEK;
   1216
   1217	fprintf(stderr, "Unknown: %s\n", mode);
   1218	fprintf(stderr, "Supported MSG_PEEK mode are:\n");
   1219	fprintf(stderr,
   1220		"\t\t\"saveWithPeek\" - recv data with flags 'MSG_PEEK' and save the peek data into file\n");
   1221	fprintf(stderr,
   1222		"\t\t\"saveAfterPeek\" - read and save data into file after recv with flags 'MSG_PEEK'\n");
   1223
   1224	die_usage();
   1225
   1226	/* silence compiler warning */
   1227	return 0;
   1228}
   1229
   1230static int parse_int(const char *size)
   1231{
   1232	unsigned long s;
   1233
   1234	errno = 0;
   1235
   1236	s = strtoul(size, NULL, 0);
   1237
   1238	if (errno) {
   1239		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
   1240			size, strerror(errno));
   1241		die_usage();
   1242	}
   1243
   1244	if (s > INT_MAX) {
   1245		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
   1246			size, strerror(ERANGE));
   1247		die_usage();
   1248	}
   1249
   1250	return (int)s;
   1251}
   1252
   1253static void parse_opts(int argc, char **argv)
   1254{
   1255	int c;
   1256
   1257	while ((c = getopt(argc, argv, "6c:hi:I:jlm:M:o:p:P:r:R:s:S:t:T:w:")) != -1) {
   1258		switch (c) {
   1259		case 'j':
   1260			cfg_join = true;
   1261			cfg_mode = CFG_MODE_POLL;
   1262			break;
   1263		case 'r':
   1264			cfg_remove = true;
   1265			cfg_mode = CFG_MODE_POLL;
   1266			cfg_wait = 400000;
   1267			cfg_do_w = atoi(optarg);
   1268			if (cfg_do_w <= 0)
   1269				cfg_do_w = 50;
   1270			break;
   1271		case 'i':
   1272			cfg_input = optarg;
   1273			break;
   1274		case 'I':
   1275			cfg_repeat = atoi(optarg);
   1276			break;
   1277		case 'l':
   1278			listen_mode = true;
   1279			break;
   1280		case 'p':
   1281			cfg_port = optarg;
   1282			break;
   1283		case 's':
   1284			cfg_sock_proto = parse_proto(optarg);
   1285			break;
   1286		case 'h':
   1287			die_usage();
   1288			break;
   1289		case '6':
   1290			pf = AF_INET6;
   1291			break;
   1292		case 't':
   1293			poll_timeout = atoi(optarg) * 1000;
   1294			if (poll_timeout <= 0)
   1295				poll_timeout = -1;
   1296			break;
   1297		case 'T':
   1298			cfg_time = atoi(optarg);
   1299			break;
   1300		case 'm':
   1301			cfg_mode = parse_mode(optarg);
   1302			break;
   1303		case 'S':
   1304			cfg_sndbuf = parse_int(optarg);
   1305			break;
   1306		case 'R':
   1307			cfg_rcvbuf = parse_int(optarg);
   1308			break;
   1309		case 'w':
   1310			cfg_wait = atoi(optarg)*1000000;
   1311			break;
   1312		case 'M':
   1313			cfg_mark = strtol(optarg, NULL, 0);
   1314			break;
   1315		case 'P':
   1316			cfg_peek = parse_peek(optarg);
   1317			break;
   1318		case 'c':
   1319			parse_cmsg_types(optarg);
   1320			break;
   1321		case 'o':
   1322			parse_setsock_options(optarg);
   1323			break;
   1324		}
   1325	}
   1326
   1327	if (optind + 1 != argc)
   1328		die_usage();
   1329	cfg_host = argv[optind];
   1330
   1331	if (strchr(cfg_host, ':'))
   1332		pf = AF_INET6;
   1333}
   1334
   1335int main(int argc, char *argv[])
   1336{
   1337	init_rng();
   1338
   1339	signal(SIGUSR1, handle_signal);
   1340	parse_opts(argc, argv);
   1341
   1342	if (listen_mode) {
   1343		int fd = sock_listen_mptcp(cfg_host, cfg_port);
   1344
   1345		if (fd < 0)
   1346			return 1;
   1347
   1348		if (cfg_rcvbuf)
   1349			set_rcvbuf(fd, cfg_rcvbuf);
   1350		if (cfg_sndbuf)
   1351			set_sndbuf(fd, cfg_sndbuf);
   1352		if (cfg_mark)
   1353			set_mark(fd, cfg_mark);
   1354		if (cfg_cmsg_types.cmsg_enabled)
   1355			apply_cmsg_types(fd, &cfg_cmsg_types);
   1356
   1357		return main_loop_s(fd);
   1358	}
   1359
   1360	return main_loop();
   1361}