cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

psock_snd.c (8740B)


      1// SPDX-License-Identifier: GPL-2.0
      2
      3#define _GNU_SOURCE
      4
      5#include <arpa/inet.h>
      6#include <errno.h>
      7#include <error.h>
      8#include <fcntl.h>
      9#include <limits.h>
     10#include <linux/filter.h>
     11#include <linux/bpf.h>
     12#include <linux/if_packet.h>
     13#include <linux/if_vlan.h>
     14#include <linux/virtio_net.h>
     15#include <net/if.h>
     16#include <net/ethernet.h>
     17#include <netinet/ip.h>
     18#include <netinet/udp.h>
     19#include <poll.h>
     20#include <sched.h>
     21#include <stdbool.h>
     22#include <stdint.h>
     23#include <stdio.h>
     24#include <stdlib.h>
     25#include <string.h>
     26#include <sys/mman.h>
     27#include <sys/socket.h>
     28#include <sys/stat.h>
     29#include <sys/types.h>
     30#include <unistd.h>
     31
     32#include "psock_lib.h"
     33
     34static bool	cfg_use_bind;
     35static bool	cfg_use_csum_off;
     36static bool	cfg_use_csum_off_bad;
     37static bool	cfg_use_dgram;
     38static bool	cfg_use_gso;
     39static bool	cfg_use_qdisc_bypass;
     40static bool	cfg_use_vlan;
     41static bool	cfg_use_vnet;
     42
     43static char	*cfg_ifname = "lo";
     44static int	cfg_mtu	= 1500;
     45static int	cfg_payload_len = DATA_LEN;
     46static int	cfg_truncate_len = INT_MAX;
     47static uint16_t	cfg_port = 8000;
     48
     49/* test sending up to max mtu + 1 */
     50#define TEST_SZ	(sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
     51
     52static char tbuf[TEST_SZ], rbuf[TEST_SZ];
     53
     54static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
     55{
     56	unsigned long sum = 0;
     57	int i;
     58
     59	for (i = 0; i < num_u16; i++)
     60		sum += start[i];
     61
     62	return sum;
     63}
     64
     65static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
     66			      unsigned long sum)
     67{
     68	sum += add_csum_hword(start, num_u16);
     69
     70	while (sum >> 16)
     71		sum = (sum & 0xffff) + (sum >> 16);
     72
     73	return ~sum;
     74}
     75
     76static int build_vnet_header(void *header)
     77{
     78	struct virtio_net_hdr *vh = header;
     79
     80	vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
     81
     82	if (cfg_use_csum_off) {
     83		vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
     84		vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
     85		vh->csum_offset = __builtin_offsetof(struct udphdr, check);
     86
     87		/* position check field exactly one byte beyond end of packet */
     88		if (cfg_use_csum_off_bad)
     89			vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
     90					  vh->csum_offset - 1;
     91	}
     92
     93	if (cfg_use_gso) {
     94		vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
     95		vh->gso_size = cfg_mtu - sizeof(struct iphdr);
     96	}
     97
     98	return sizeof(*vh);
     99}
    100
    101static int build_eth_header(void *header)
    102{
    103	struct ethhdr *eth = header;
    104
    105	if (cfg_use_vlan) {
    106		uint16_t *tag = header + ETH_HLEN;
    107
    108		eth->h_proto = htons(ETH_P_8021Q);
    109		tag[1] = htons(ETH_P_IP);
    110		return ETH_HLEN + 4;
    111	}
    112
    113	eth->h_proto = htons(ETH_P_IP);
    114	return ETH_HLEN;
    115}
    116
    117static int build_ipv4_header(void *header, int payload_len)
    118{
    119	struct iphdr *iph = header;
    120
    121	iph->ihl = 5;
    122	iph->version = 4;
    123	iph->ttl = 8;
    124	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
    125	iph->id = htons(1337);
    126	iph->protocol = IPPROTO_UDP;
    127	iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
    128	iph->daddr = htonl((172 << 24) | (17 << 16) | 1);
    129	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);
    130
    131	return iph->ihl << 2;
    132}
    133
    134static int build_udp_header(void *header, int payload_len)
    135{
    136	const int alen = sizeof(uint32_t);
    137	struct udphdr *udph = header;
    138	int len = sizeof(*udph) + payload_len;
    139
    140	udph->source = htons(9);
    141	udph->dest = htons(cfg_port);
    142	udph->len = htons(len);
    143
    144	if (cfg_use_csum_off)
    145		udph->check = build_ip_csum(header - (2 * alen), alen,
    146					    htons(IPPROTO_UDP) + udph->len);
    147	else
    148		udph->check = 0;
    149
    150	return sizeof(*udph);
    151}
    152
    153static int build_packet(int payload_len)
    154{
    155	int off = 0;
    156
    157	off += build_vnet_header(tbuf);
    158	off += build_eth_header(tbuf + off);
    159	off += build_ipv4_header(tbuf + off, payload_len);
    160	off += build_udp_header(tbuf + off, payload_len);
    161
    162	if (off + payload_len > sizeof(tbuf))
    163		error(1, 0, "payload length exceeds max");
    164
    165	memset(tbuf + off, DATA_CHAR, payload_len);
    166
    167	return off + payload_len;
    168}
    169
    170static void do_bind(int fd)
    171{
    172	struct sockaddr_ll laddr = {0};
    173
    174	laddr.sll_family = AF_PACKET;
    175	laddr.sll_protocol = htons(ETH_P_IP);
    176	laddr.sll_ifindex = if_nametoindex(cfg_ifname);
    177	if (!laddr.sll_ifindex)
    178		error(1, errno, "if_nametoindex");
    179
    180	if (bind(fd, (void *)&laddr, sizeof(laddr)))
    181		error(1, errno, "bind");
    182}
    183
    184static void do_send(int fd, char *buf, int len)
    185{
    186	int ret;
    187
    188	if (!cfg_use_vnet) {
    189		buf += sizeof(struct virtio_net_hdr);
    190		len -= sizeof(struct virtio_net_hdr);
    191	}
    192	if (cfg_use_dgram) {
    193		buf += ETH_HLEN;
    194		len -= ETH_HLEN;
    195	}
    196
    197	if (cfg_use_bind) {
    198		ret = write(fd, buf, len);
    199	} else {
    200		struct sockaddr_ll laddr = {0};
    201
    202		laddr.sll_protocol = htons(ETH_P_IP);
    203		laddr.sll_ifindex = if_nametoindex(cfg_ifname);
    204		if (!laddr.sll_ifindex)
    205			error(1, errno, "if_nametoindex");
    206
    207		ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
    208	}
    209
    210	if (ret == -1)
    211		error(1, errno, "write");
    212	if (ret != len)
    213		error(1, 0, "write: %u %u", ret, len);
    214
    215	fprintf(stderr, "tx: %u\n", ret);
    216}
    217
    218static int do_tx(void)
    219{
    220	const int one = 1;
    221	int fd, len;
    222
    223	fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
    224	if (fd == -1)
    225		error(1, errno, "socket t");
    226
    227	if (cfg_use_bind)
    228		do_bind(fd);
    229
    230	if (cfg_use_qdisc_bypass &&
    231	    setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
    232		error(1, errno, "setsockopt qdisc bypass");
    233
    234	if (cfg_use_vnet &&
    235	    setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
    236		error(1, errno, "setsockopt vnet");
    237
    238	len = build_packet(cfg_payload_len);
    239
    240	if (cfg_truncate_len < len)
    241		len = cfg_truncate_len;
    242
    243	do_send(fd, tbuf, len);
    244
    245	if (close(fd))
    246		error(1, errno, "close t");
    247
    248	return len;
    249}
    250
    251static int setup_rx(void)
    252{
    253	struct timeval tv = { .tv_usec = 100 * 1000 };
    254	struct sockaddr_in raddr = {0};
    255	int fd;
    256
    257	fd = socket(PF_INET, SOCK_DGRAM, 0);
    258	if (fd == -1)
    259		error(1, errno, "socket r");
    260
    261	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
    262		error(1, errno, "setsockopt rcv timeout");
    263
    264	raddr.sin_family = AF_INET;
    265	raddr.sin_port = htons(cfg_port);
    266	raddr.sin_addr.s_addr = htonl(INADDR_ANY);
    267
    268	if (bind(fd, (void *)&raddr, sizeof(raddr)))
    269		error(1, errno, "bind r");
    270
    271	return fd;
    272}
    273
    274static void do_rx(int fd, int expected_len, char *expected)
    275{
    276	int ret;
    277
    278	ret = recv(fd, rbuf, sizeof(rbuf), 0);
    279	if (ret == -1)
    280		error(1, errno, "recv");
    281	if (ret != expected_len)
    282		error(1, 0, "recv: %u != %u", ret, expected_len);
    283
    284	if (memcmp(rbuf, expected, ret))
    285		error(1, 0, "recv: data mismatch");
    286
    287	fprintf(stderr, "rx: %u\n", ret);
    288}
    289
    290static int setup_sniffer(void)
    291{
    292	struct timeval tv = { .tv_usec = 100 * 1000 };
    293	int fd;
    294
    295	fd = socket(PF_PACKET, SOCK_RAW, 0);
    296	if (fd == -1)
    297		error(1, errno, "socket p");
    298
    299	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
    300		error(1, errno, "setsockopt rcv timeout");
    301
    302	pair_udp_setfilter(fd);
    303	do_bind(fd);
    304
    305	return fd;
    306}
    307
    308static void parse_opts(int argc, char **argv)
    309{
    310	int c;
    311
    312	while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
    313		switch (c) {
    314		case 'b':
    315			cfg_use_bind = true;
    316			break;
    317		case 'c':
    318			cfg_use_csum_off = true;
    319			break;
    320		case 'C':
    321			cfg_use_csum_off_bad = true;
    322			break;
    323		case 'd':
    324			cfg_use_dgram = true;
    325			break;
    326		case 'g':
    327			cfg_use_gso = true;
    328			break;
    329		case 'l':
    330			cfg_payload_len = strtoul(optarg, NULL, 0);
    331			break;
    332		case 'q':
    333			cfg_use_qdisc_bypass = true;
    334			break;
    335		case 't':
    336			cfg_truncate_len = strtoul(optarg, NULL, 0);
    337			break;
    338		case 'v':
    339			cfg_use_vnet = true;
    340			break;
    341		case 'V':
    342			cfg_use_vlan = true;
    343			break;
    344		default:
    345			error(1, 0, "%s: parse error", argv[0]);
    346		}
    347	}
    348
    349	if (cfg_use_vlan && cfg_use_dgram)
    350		error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
    351
    352	if (cfg_use_csum_off && !cfg_use_vnet)
    353		error(1, 0, "option csum offload (-c) requires vnet (-v)");
    354
    355	if (cfg_use_csum_off_bad && !cfg_use_csum_off)
    356		error(1, 0, "option csum bad (-C) requires csum offload (-c)");
    357
    358	if (cfg_use_gso && !cfg_use_csum_off)
    359		error(1, 0, "option gso (-g) requires csum offload (-c)");
    360}
    361
    362static void run_test(void)
    363{
    364	int fdr, fds, total_len;
    365
    366	fdr = setup_rx();
    367	fds = setup_sniffer();
    368
    369	total_len = do_tx();
    370
    371	/* BPF filter accepts only this length, vlan changes MAC */
    372	if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
    373		do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
    374		      tbuf + sizeof(struct virtio_net_hdr));
    375
    376	do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
    377
    378	if (close(fds))
    379		error(1, errno, "close s");
    380	if (close(fdr))
    381		error(1, errno, "close r");
    382}
    383
    384int main(int argc, char **argv)
    385{
    386	parse_opts(argc, argv);
    387
    388	if (system("ip link set dev lo mtu 1500"))
    389		error(1, errno, "ip link set mtu");
    390	if (system("ip addr add dev lo 172.17.0.1/24"))
    391		error(1, errno, "ip addr add");
    392	if (system("sysctl -w net.ipv4.conf.lo.accept_local=1"))
    393		error(1, errno, "sysctl lo.accept_local");
    394
    395	run_test();
    396
    397	fprintf(stderr, "OK\n\n");
    398	return 0;
    399}