cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

test_l4lb_noinline.c (10814B)


      1// SPDX-License-Identifier: GPL-2.0
      2// Copyright (c) 2017 Facebook
      3#include <stddef.h>
      4#include <stdbool.h>
      5#include <string.h>
      6#include <linux/pkt_cls.h>
      7#include <linux/bpf.h>
      8#include <linux/in.h>
      9#include <linux/if_ether.h>
     10#include <linux/ip.h>
     11#include <linux/ipv6.h>
     12#include <linux/icmp.h>
     13#include <linux/icmpv6.h>
     14#include <linux/tcp.h>
     15#include <linux/udp.h>
     16#include <bpf/bpf_helpers.h>
     17#include "test_iptunnel_common.h"
     18#include <bpf/bpf_endian.h>
     19
     20static __always_inline __u32 rol32(__u32 word, unsigned int shift)
     21{
     22	return (word << shift) | (word >> ((-shift) & 31));
     23}
     24
     25/* copy paste of jhash from kernel sources to make sure llvm
     26 * can compile it into valid sequence of bpf instructions
     27 */
     28#define __jhash_mix(a, b, c)			\
     29{						\
     30	a -= c;  a ^= rol32(c, 4);  c += b;	\
     31	b -= a;  b ^= rol32(a, 6);  a += c;	\
     32	c -= b;  c ^= rol32(b, 8);  b += a;	\
     33	a -= c;  a ^= rol32(c, 16); c += b;	\
     34	b -= a;  b ^= rol32(a, 19); a += c;	\
     35	c -= b;  c ^= rol32(b, 4);  b += a;	\
     36}
     37
     38#define __jhash_final(a, b, c)			\
     39{						\
     40	c ^= b; c -= rol32(b, 14);		\
     41	a ^= c; a -= rol32(c, 11);		\
     42	b ^= a; b -= rol32(a, 25);		\
     43	c ^= b; c -= rol32(b, 16);		\
     44	a ^= c; a -= rol32(c, 4);		\
     45	b ^= a; b -= rol32(a, 14);		\
     46	c ^= b; c -= rol32(b, 24);		\
     47}
     48
     49#define JHASH_INITVAL		0xdeadbeef
     50
     51typedef unsigned int u32;
     52
/*
 * jhash - hash @length bytes at @key, seeded with @initval.
 * @key:     bytes to hash
 * @length:  number of bytes at @key
 * @initval: caller-chosen seed
 *
 * Returns the final value of the c state word.
 *
 * The loop consumes 12-byte blocks while more than 12 bytes remain, so the
 * switch sees a 1..12 byte tail. It sees 0 only when @length was 0 to begin
 * with; in that case the final mix is skipped and the unmixed seed is
 * returned.
 *
 * NOTE(review): the loop reads 32-bit words via *(u32 *), i.e. in host byte
 * order and assuming @key is suitably aligned for such loads. The tail is
 * read byte by byte. The hash constants checked in get_packet_dst() depend
 * on this exact code sequence, so do not restructure it.
 */
static __noinline u32 jhash(const void *key, u32 length, u32 initval)
{
	u32 a, b, c;
	const unsigned char *k = key;

	/* All three state words start from the same length-dependent seed. */
	a = b = c = JHASH_INITVAL + length + initval;

	while (length > 12) {
		a += *(u32 *)(k);
		b += *(u32 *)(k + 4);
		c += *(u32 *)(k + 8);
		__jhash_mix(a, b, c);
		length -= 12;
		k += 12;
	}
	/* Every case deliberately falls through to the next one. */
	switch (length) {
	case 12: c += (u32)k[11]<<24;	/* fallthrough */
	case 11: c += (u32)k[10]<<16;	/* fallthrough */
	case 10: c += (u32)k[9]<<8;	/* fallthrough */
	case 9:  c += k[8];		/* fallthrough */
	case 8:  b += (u32)k[7]<<24;	/* fallthrough */
	case 7:  b += (u32)k[6]<<16;	/* fallthrough */
	case 6:  b += (u32)k[5]<<8;	/* fallthrough */
	case 5:  b += k[4];		/* fallthrough */
	case 4:  a += (u32)k[3]<<24;	/* fallthrough */
	case 3:  a += (u32)k[2]<<16;	/* fallthrough */
	case 2:  a += (u32)k[1]<<8;	/* fallthrough */
	case 1:  a += k[0];
		 __jhash_final(a, b, c);
		 /* fallthrough */
	case 0: /* Nothing left to add */
		break;
	}

	return c;
}
     88
     89static __noinline u32 __jhash_nwords(u32 a, u32 b, u32 c, u32 initval)
     90{
     91	a += initval;
     92	b += initval;
     93	c += initval;
     94	__jhash_final(a, b, c);
     95	return c;
     96}
     97
     98static __noinline u32 jhash_2words(u32 a, u32 b, u32 initval)
     99{
    100	return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2));
    101}
    102
    103#define PCKT_FRAGMENTED 65343
    104#define IPV4_HDR_LEN_NO_OPT 20
    105#define IPV4_PLUS_ICMP_HDR 28
    106#define IPV6_PLUS_ICMP_HDR 48
    107#define RING_SIZE 2
    108#define MAX_VIPS 12
    109#define MAX_REALS 5
    110#define CTL_MAP_SIZE 16
    111#define CH_RINGS_SIZE (MAX_VIPS * RING_SIZE)
    112#define F_IPV6 (1 << 0)
    113#define F_HASH_NO_SRC_PORT (1 << 0)
    114#define F_ICMP (1 << 0)
    115#define F_SYN_SET (1 << 1)
    116
/*
 * Flow key extracted from a packet, or from the packet quoted inside an
 * ICMP error. IPv4 uses src/dst; IPv6 uses the 16-byte srcv6/dstv6 views of
 * the same storage.
 */
struct packet_description {
	union {
		__be32 src;
		__be32 srcv6[4];
	};
	union {
		__be32 dst;
		__be32 dstv6[4];
	};
	/*
	 * port16[0] is the source port and port16[1] the destination port, both
	 * copied unconverted from the L4 header. ports is the combined 32-bit
	 * view that get_packet_hash() feeds to the hash.
	 */
	union {
		__u32 ports;
		__u16 port16[2];
	};
	__u8 proto;	/* L4 protocol (IPPROTO_*) */
	__u8 flags;	/* F_ICMP / F_SYN_SET */
};
    133
/*
 * Value type of ctl_array. Which member is meaningful depends on the slot.
 * In this file only ifindex is read (slots 1 and 2, in process_packet());
 * value and mac are not used here.
 */
struct ctl_value {
	union {
		__u64 value;
		__u32 ifindex;
		__u8 mac[6];
	};
};
    141
/* Per-VIP metadata stored in vip_map. */
struct vip_meta {
	__u32 flags;	/* F_HASH_NO_SRC_PORT */
	__u32 vip_num;	/* selects this VIP's ch_rings slots and its stats entry */
};
    146
/* Backend ("real") stored in the reals map: tunnel endpoint address plus flags. */
struct real_definition {
	union {
		__be32 dst;		/* IPv4 tunnel remote */
		__be32 dstv6[4];	/* IPv6 tunnel remote, used when F_IPV6 is set */
	};
	__u8 flags;	/* F_IPV6 */
};
    154
/* Per-VIP counters kept in the per-CPU stats map. */
struct vip_stats {
	__u64 bytes;	/* summed from the IP length field (see process_packet()) */
	__u64 pkts;
};
    159
/* Minimal Ethernet header as parsed by this program. */
struct eth_hdr {
	unsigned char eth_dest[ETH_ALEN];
	unsigned char eth_source[ETH_ALEN];
	/* EtherType in network byte order; compared against bpf_htons(ETH_P_*). */
	unsigned short eth_proto;
};
    165
/*
 * VIP table. The key is struct vip (destination address, dport, protocol),
 * which is declared outside this file, presumably in test_iptunnel_common.h.
 * The value is the VIP's vip_meta.
 */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, MAX_VIPS);
	__type(key, struct vip);
	__type(value, struct vip_meta);
} vip_map SEC(".maps");

/*
 * Consistent-hash rings: RING_SIZE consecutive slots per VIP, indexed by
 * RING_SIZE * vip_num + hash % RING_SIZE. Each value is a key into reals.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, CH_RINGS_SIZE);
	__type(key, __u32);
	__type(value, __u32);
} ch_rings SEC(".maps");

/* Backend definitions, keyed by the index read from ch_rings. */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, MAX_REALS);
	__type(key, __u32);
	__type(value, struct real_definition);
} reals SEC(".maps");

/* Per-CPU packet/byte counters, indexed by vip_meta.vip_num. */
struct {
	__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
	__uint(max_entries, MAX_VIPS);
	__type(key, __u32);
	__type(value, struct vip_stats);
} stats SEC(".maps");

/*
 * Control values. process_packet() reads the egress ifindex for IPv4 reals
 * from key 1 and for IPv6 reals from key 2.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, CTL_MAP_SIZE);
	__type(key, __u32);
	__type(value, struct ctl_value);
} ctl_array SEC(".maps");
    200
    201static __noinline __u32 get_packet_hash(struct packet_description *pckt, bool ipv6)
    202{
    203	if (ipv6)
    204		return jhash_2words(jhash(pckt->srcv6, 16, MAX_VIPS),
    205				    pckt->ports, CH_RINGS_SIZE);
    206	else
    207		return jhash_2words(pckt->src, pckt->ports, CH_RINGS_SIZE);
    208}
    209
    210static __noinline bool get_packet_dst(struct real_definition **real,
    211				      struct packet_description *pckt,
    212				      struct vip_meta *vip_info,
    213				      bool is_ipv6)
    214{
    215	__u32 hash = get_packet_hash(pckt, is_ipv6);
    216	__u32 key = RING_SIZE * vip_info->vip_num + hash % RING_SIZE;
    217	__u32 *real_pos;
    218
    219	if (hash != 0x358459b7 /* jhash of ipv4 packet */  &&
    220	    hash != 0x2f4bc6bb /* jhash of ipv6 packet */)
    221		return false;
    222
    223	real_pos = bpf_map_lookup_elem(&ch_rings, &key);
    224	if (!real_pos)
    225		return false;
    226	key = *real_pos;
    227	*real = bpf_map_lookup_elem(&reals, &key);
    228	if (!(*real))
    229		return false;
    230	return true;
    231}
    232
    233static __noinline int parse_icmpv6(void *data, void *data_end, __u64 off,
    234				   struct packet_description *pckt)
    235{
    236	struct icmp6hdr *icmp_hdr;
    237	struct ipv6hdr *ip6h;
    238
    239	icmp_hdr = data + off;
    240	if (icmp_hdr + 1 > data_end)
    241		return TC_ACT_SHOT;
    242	if (icmp_hdr->icmp6_type != ICMPV6_PKT_TOOBIG)
    243		return TC_ACT_OK;
    244	off += sizeof(struct icmp6hdr);
    245	ip6h = data + off;
    246	if (ip6h + 1 > data_end)
    247		return TC_ACT_SHOT;
    248	pckt->proto = ip6h->nexthdr;
    249	pckt->flags |= F_ICMP;
    250	memcpy(pckt->srcv6, ip6h->daddr.s6_addr32, 16);
    251	memcpy(pckt->dstv6, ip6h->saddr.s6_addr32, 16);
    252	return TC_ACT_UNSPEC;
    253}
    254
    255static __noinline int parse_icmp(void *data, void *data_end, __u64 off,
    256				 struct packet_description *pckt)
    257{
    258	struct icmphdr *icmp_hdr;
    259	struct iphdr *iph;
    260
    261	icmp_hdr = data + off;
    262	if (icmp_hdr + 1 > data_end)
    263		return TC_ACT_SHOT;
    264	if (icmp_hdr->type != ICMP_DEST_UNREACH ||
    265	    icmp_hdr->code != ICMP_FRAG_NEEDED)
    266		return TC_ACT_OK;
    267	off += sizeof(struct icmphdr);
    268	iph = data + off;
    269	if (iph + 1 > data_end)
    270		return TC_ACT_SHOT;
    271	if (iph->ihl != 5)
    272		return TC_ACT_SHOT;
    273	pckt->proto = iph->protocol;
    274	pckt->flags |= F_ICMP;
    275	pckt->src = iph->daddr;
    276	pckt->dst = iph->saddr;
    277	return TC_ACT_UNSPEC;
    278}
    279
    280static __noinline bool parse_udp(void *data, __u64 off, void *data_end,
    281				 struct packet_description *pckt)
    282{
    283	struct udphdr *udp;
    284	udp = data + off;
    285
    286	if (udp + 1 > data_end)
    287		return false;
    288
    289	if (!(pckt->flags & F_ICMP)) {
    290		pckt->port16[0] = udp->source;
    291		pckt->port16[1] = udp->dest;
    292	} else {
    293		pckt->port16[0] = udp->dest;
    294		pckt->port16[1] = udp->source;
    295	}
    296	return true;
    297}
    298
    299static __noinline bool parse_tcp(void *data, __u64 off, void *data_end,
    300				 struct packet_description *pckt)
    301{
    302	struct tcphdr *tcp;
    303
    304	tcp = data + off;
    305	if (tcp + 1 > data_end)
    306		return false;
    307
    308	if (tcp->syn)
    309		pckt->flags |= F_SYN_SET;
    310
    311	if (!(pckt->flags & F_ICMP)) {
    312		pckt->port16[0] = tcp->source;
    313		pckt->port16[1] = tcp->dest;
    314	} else {
    315		pckt->port16[0] = tcp->dest;
    316		pckt->port16[1] = tcp->source;
    317	}
    318	return true;
    319}
    320
/*
 * process_packet - parse an IPv4/IPv6 TCP or UDP packet, choose a backend
 * ("real") for it, account it and redirect it through a tunnel.
 * @data, @data_end: packet bounds; @data + @off is the L3 header
 * @off:     offset of the L3 header; advanced to the L4 header while parsing
 * @is_ipv6: selects the IPv6 or the IPv4 parsing path
 * @skb:     program context; used for the tunnel key and to rewrite the
 *           start of the packet
 *
 * Returns TC_ACT_SHOT for truncated, fragmented, IPv4-with-options or
 * non-TCP/UDP packets, unknown VIPs and missing map entries. For ICMP
 * input it returns the verdict from parse_icmp()/parse_icmpv6() unless
 * they return TC_ACT_UNSPEC. Otherwise it returns bpf_redirect() to the
 * ifindex stored in ctl_array.
 */
static __noinline int process_packet(void *data, __u64 off, void *data_end,
				     bool is_ipv6, struct __sk_buff *skb)
{
	void *pkt_start = (void *)(long)skb->data;
	struct packet_description pckt = {};
	struct eth_hdr *eth = pkt_start;
	struct bpf_tunnel_key tkey = {};
	struct vip_stats *data_stats;
	struct real_definition *dst;
	struct vip_meta *vip_info;
	struct ctl_value *cval;
	__u32 v4_intf_pos = 1;	/* ctl_array slot holding the IPv4 egress ifindex */
	__u32 v6_intf_pos = 2;	/* ctl_array slot holding the IPv6 egress ifindex */
	struct ipv6hdr *ip6h;
	struct vip vip = {};
	struct iphdr *iph;
	int tun_flag = 0;
	__u16 pkt_bytes;
	__u64 iph_len;
	__u32 ifindex;
	__u8 protocol;
	__u32 vip_num;
	int action;

	tkey.tunnel_ttl = 64;
	if (is_ipv6) {
		ip6h = data + off;
		if (ip6h + 1 > data_end)
			return TC_ACT_SHOT;

		iph_len = sizeof(struct ipv6hdr);
		protocol = ip6h->nexthdr;
		pckt.proto = protocol;
		/* payload_len excludes the fixed IPv6 header (unlike IPv4 tot_len). */
		pkt_bytes = bpf_ntohs(ip6h->payload_len);
		off += iph_len;
		if (protocol == IPPROTO_FRAGMENT) {
			return TC_ACT_SHOT;
		} else if (protocol == IPPROTO_ICMPV6) {
			/* Only TC_ACT_UNSPEC (-1) means "keep processing". */
			action = parse_icmpv6(data, data_end, off, &pckt);
			if (action >= 0)
				return action;
			/* Skip the ICMPv6 header and the quoted IPv6 header. */
			off += IPV6_PLUS_ICMP_HDR;
		} else {
			memcpy(pckt.srcv6, ip6h->saddr.s6_addr32, 16);
			memcpy(pckt.dstv6, ip6h->daddr.s6_addr32, 16);
		}
	} else {
		iph = data + off;
		if (iph + 1 > data_end)
			return TC_ACT_SHOT;
		/* IPv4 options are not supported. */
		if (iph->ihl != 5)
			return TC_ACT_SHOT;

		protocol = iph->protocol;
		pckt.proto = protocol;
		/* tot_len includes the IPv4 header itself. */
		pkt_bytes = bpf_ntohs(iph->tot_len);
		off += IPV4_HDR_LEN_NO_OPT;

		/* frag_off is __be16, so PCKT_FRAGMENTED must be a network-order mask. */
		if (iph->frag_off & PCKT_FRAGMENTED)
			return TC_ACT_SHOT;
		if (protocol == IPPROTO_ICMP) {
			/* Only TC_ACT_UNSPEC (-1) means "keep processing". */
			action = parse_icmp(data, data_end, off, &pckt);
			if (action >= 0)
				return action;
			/* Skip the ICMP header and the quoted IPv4 header. */
			off += IPV4_PLUS_ICMP_HDR;
		} else {
			pckt.src = iph->saddr;
			pckt.dst = iph->daddr;
		}
	}
	/* Re-read: the ICMP parsers replace proto with the quoted packet's. */
	protocol = pckt.proto;

	if (protocol == IPPROTO_TCP) {
		if (!parse_tcp(data, off, data_end, &pckt))
			return TC_ACT_SHOT;
	} else if (protocol == IPPROTO_UDP) {
		if (!parse_udp(data, off, data_end, &pckt))
			return TC_ACT_SHOT;
	} else {
		return TC_ACT_SHOT;
	}

	if (is_ipv6)
		memcpy(vip.daddr.v6, pckt.dstv6, 16);
	else
		vip.daddr.v4 = pckt.dst;

	/*
	 * Look up the exact (address, port, protocol) VIP first, then fall back
	 * to a port-0 entry. On fallback the destination port is also zeroed in
	 * the flow key before hashing.
	 */
	vip.dport = pckt.port16[1];
	vip.protocol = pckt.proto;
	vip_info = bpf_map_lookup_elem(&vip_map, &vip);
	if (!vip_info) {
		vip.dport = 0;
		vip_info = bpf_map_lookup_elem(&vip_map, &vip);
		if (!vip_info)
			return TC_ACT_SHOT;
		pckt.port16[1] = 0;
	}

	if (vip_info->flags & F_HASH_NO_SRC_PORT)
		pckt.port16[0] = 0;

	if (!get_packet_dst(&dst, &pckt, vip_info, is_ipv6))
		return TC_ACT_SHOT;

	/* Choose the egress interface and tunnel remote by the real's family. */
	if (dst->flags & F_IPV6) {
		cval = bpf_map_lookup_elem(&ctl_array, &v6_intf_pos);
		if (!cval)
			return TC_ACT_SHOT;
		ifindex = cval->ifindex;
		memcpy(tkey.remote_ipv6, dst->dstv6, 16);
		tun_flag = BPF_F_TUNINFO_IPV6;
	} else {
		cval = bpf_map_lookup_elem(&ctl_array, &v4_intf_pos);
		if (!cval)
			return TC_ACT_SHOT;
		ifindex = cval->ifindex;
		tkey.remote_ipv4 = dst->dst;
	}
	vip_num = vip_info->vip_num;
	data_stats = bpf_map_lookup_elem(&stats, &vip_num);
	if (!data_stats)
		return TC_ACT_SHOT;
	data_stats->pkts++;
	/*
	 * NOTE(review): IPv4 counts tot_len (header included) while IPv6 counts
	 * payload_len (header excluded). Left as-is; the test harness
	 * presumably expects these exact values.
	 */
	data_stats->bytes += pkt_bytes;
	/* NOTE(review): the helper's return value is ignored. */
	bpf_skb_set_tunnel_key(skb, &tkey, sizeof(tkey), tun_flag);
	/*
	 * Overwrite the first 4 bytes of the destination MAC with
	 * tkey.remote_ipv4. For IPv6 reals this presumably reads the first word
	 * of remote_ipv6 through the bpf_tunnel_key union. This appears to let
	 * the test observe which real was chosen — TODO confirm.
	 */
	*(u32 *)eth->eth_dest = tkey.remote_ipv4;
	return bpf_redirect(ifindex, 0);
}
    449
    450SEC("tc")
    451int balancer_ingress(struct __sk_buff *ctx)
    452{
    453	void *data_end = (void *)(long)ctx->data_end;
    454	void *data = (void *)(long)ctx->data;
    455	struct eth_hdr *eth = data;
    456	__u32 eth_proto;
    457	__u32 nh_off;
    458
    459	nh_off = sizeof(struct eth_hdr);
    460	if (data + nh_off > data_end)
    461		return TC_ACT_SHOT;
    462	eth_proto = eth->eth_proto;
    463	if (eth_proto == bpf_htons(ETH_P_IP))
    464		return process_packet(data, nh_off, data_end, false, ctx);
    465	else if (eth_proto == bpf_htons(ETH_P_IPV6))
    466		return process_packet(data, nh_off, data_end, true, ctx);
    467	else
    468		return TC_ACT_SHOT;
    469}
/* Program license, emitted into the "license" ELF section for the loader. */
char _license[] SEC("license") = "GPL";