cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

fou.c (28610B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2#include <linux/module.h>
      3#include <linux/errno.h>
      4#include <linux/socket.h>
      5#include <linux/skbuff.h>
      6#include <linux/ip.h>
      7#include <linux/icmp.h>
      8#include <linux/udp.h>
      9#include <linux/types.h>
     10#include <linux/kernel.h>
     11#include <net/genetlink.h>
     12#include <net/gro.h>
     13#include <net/gue.h>
     14#include <net/fou.h>
     15#include <net/ip.h>
     16#include <net/protocol.h>
     17#include <net/udp.h>
     18#include <net/udp_tunnel.h>
     19#include <uapi/linux/fou.h>
     20#include <uapi/linux/genetlink.h>
     21
/* Per-port FOU/GUE state.  One instance is attached to each UDP encap
 * socket via sk_user_data (see fou_create()) and linked into the
 * per-netns fou_list.
 */
struct fou {
	struct socket *sock;	/* underlying UDP tunnel socket */
	u8 protocol;		/* inner IP protocol; set for FOU_ENCAP_DIRECT only */
	u8 flags;		/* FOU_F_* flags */
	__be16 port;		/* local UDP port (network byte order) */
	u8 family;		/* AF_INET or AF_INET6 */
	u16 type;		/* FOU_ENCAP_DIRECT or FOU_ENCAP_GUE */
	struct list_head list;	/* entry in fou_net->fou_list */
	struct rcu_head rcu;	/* deferred free via kfree_rcu() */
};

/* Never use partial remote checksum offload (passed as 'nopartial' to
 * the remcsum helpers below).
 */
#define FOU_F_REMCSUM_NOPARTIAL BIT(0)
     34
/* Configuration for one FOU port, filled in from netlink attributes by
 * parse_nl_config().
 */
struct fou_cfg {
	u16 type;			/* FOU_ENCAP_* encapsulation type */
	u8 protocol;			/* inner IP protocol (direct encap) */
	u8 flags;			/* FOU_F_* flags */
	struct udp_port_cfg udp_config;	/* family, addresses, ports, ifindex */
};

/* pernet id; used with net_generic() to reach this netns' fou_net */
static unsigned int fou_net_id;

/* Per-network-namespace state: the list of configured FOU ports. */
struct fou_net {
	struct list_head fou_list;	/* all struct fou in this netns */
	struct mutex fou_lock;		/* guards fou_list */
};
     48
     49static inline struct fou *fou_from_sock(struct sock *sk)
     50{
     51	return sk->sk_user_data;
     52}
     53
/* Strip 'len' bytes of encapsulation from the front of the transport
 * payload, shrinking the outer IP length field and updating the skb
 * checksum state to match.
 *
 * Returns the result of iptunnel_pull_offloads() (0 on success,
 * non-zero on failure).
 */
static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len)
{
	/* Remove 'len' bytes from the packet (UDP header and
	 * FOU header if present).
	 */
	if (fou->family == AF_INET)
		ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
	else
		ipv6_hdr(skb)->payload_len =
		    htons(ntohs(ipv6_hdr(skb)->payload_len) - len);

	__skb_pull(skb, len);
	skb_postpull_rcsum(skb, udp_hdr(skb), len);
	skb_reset_transport_header(skb);
	return iptunnel_pull_offloads(skb);
}
     70
/* UDP encap_rcv callback for plain FOU sockets.
 *
 * Returns 1 to hand the packet back to UDP, 0 when the packet was
 * consumed (dropped), or a negative value to resubmit it to the IP
 * stack as protocol -retval.
 */
static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
{
	struct fou *fou = fou_from_sock(sk);

	/* No FOU state on this socket: let UDP process the packet. */
	if (!fou)
		return 1;

	/* Direct encap has no FOU header; strip only the UDP header. */
	if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
		goto drop;

	/* Resubmit as the configured inner protocol. */
	return -fou->protocol;

drop:
	kfree_skb(skb);
	return 0;
}
     87
/* Apply GUE remote checksum offload on the receive path.
 *
 * @data points at the remcsum option: two big-endian u16 values giving
 * the checksum start and the checksum field offset, both relative to
 * the end of the GUE header (that is the base passed to
 * skb_remcsum_process() below).
 *
 * Returns the (possibly relocated) GUE header pointer, or NULL if the
 * required bytes could not be pulled into the linear area.
 */
static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
				  void *data, size_t hdrlen, u8 ipproto,
				  bool nopartial)
{
	__be16 *pd = data;
	size_t start = ntohs(pd[0]);
	size_t offset = ntohs(pd[1]);
	/* Everything up to the checksum field (or checksum start,
	 * whichever is further) must be linear.
	 */
	size_t plen = sizeof(struct udphdr) + hdrlen +
	    max_t(size_t, offset + sizeof(u16), start);

	/* Remote checksum was already handled earlier (e.g. in GRO). */
	if (skb->remcsum_offload)
		return guehdr;

	if (!pskb_may_pull(skb, plen))
		return NULL;
	/* pskb_may_pull() may have reallocated skb data; recompute. */
	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	skb_remcsum_process(skb, (void *)guehdr + hdrlen,
			    start, offset, nopartial);

	return guehdr;
}
    110
/* Handle a GUE control message (guehdr->control set).  No control
 * message types are implemented, so the packet is dropped and
 * reported as consumed.
 */
static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
{
	/* No support yet */
	kfree_skb(skb);
	return 0;
}
    117
/* UDP encap_rcv callback for GUE sockets.
 *
 * Parses either a full GUE header (version 0) or a directly
 * encapsulated IPv4/IPv6 header (version 1), strips the encapsulation
 * and resubmits the packet to the IP stack as the inner protocol.
 *
 * Returns 1 to hand the packet back to UDP, 0 when consumed/dropped,
 * or a negative value (-protocol) to resubmit.
 */
static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
{
	struct fou *fou = fou_from_sock(sk);
	size_t len, optlen, hdrlen;
	struct guehdr *guehdr;
	void *data;
	u16 doffset = 0;
	u8 proto_ctype;

	if (!fou)
		return 1;

	len = sizeof(struct udphdr) + sizeof(struct guehdr);
	if (!pskb_may_pull(skb, len))
		goto drop;

	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	switch (guehdr->version) {
	case 0: /* Full GUE header present */
		break;

	case 1: {
		/* Direct encapsulation of IPv4 or IPv6 */

		int prot;

		/* The first nibble after the UDP header is the IP
		 * version of the inner packet.
		 */
		switch (((struct iphdr *)guehdr)->version) {
		case 4:
			prot = IPPROTO_IPIP;
			break;
		case 6:
			prot = IPPROTO_IPV6;
			break;
		default:
			goto drop;
		}

		if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
			goto drop;

		return -prot;
	}

	default: /* Undefined version */
		goto drop;
	}

	/* hlen counts optional fields in 32-bit words */
	optlen = guehdr->hlen << 2;
	len += optlen;

	if (!pskb_may_pull(skb, len))
		goto drop;

	/* guehdr may change after pull */
	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	if (validate_gue_flags(guehdr, optlen))
		goto drop;

	hdrlen = sizeof(struct guehdr) + optlen;

	/* Shrink the outer IP length field by the bytes being stripped. */
	if (fou->family == AF_INET)
		ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
	else
		ipv6_hdr(skb)->payload_len =
		    htons(ntohs(ipv6_hdr(skb)->payload_len) - len);

	/* Pull csum through the guehdr now . This can be used if
	 * there is a remote checksum offload.
	 */
	skb_postpull_rcsum(skb, udp_hdr(skb), len);

	data = &guehdr[1];

	if (guehdr->flags & GUE_FLAG_PRIV) {
		__be32 flags = *(__be32 *)(data + doffset);

		doffset += GUE_LEN_PRIV;

		if (flags & GUE_PFLAG_REMCSUM) {
			guehdr = gue_remcsum(skb, guehdr, data + doffset,
					     hdrlen, guehdr->proto_ctype,
					     !!(fou->flags &
						FOU_F_REMCSUM_NOPARTIAL));
			if (!guehdr)
				goto drop;

			/* gue_remcsum() may have moved the header */
			data = &guehdr[1];

			doffset += GUE_PLEN_REMCSUM;
		}
	}

	if (unlikely(guehdr->control))
		return gue_control_message(skb, guehdr);

	/* Save the inner protocol before the header is pulled away. */
	proto_ctype = guehdr->proto_ctype;
	__skb_pull(skb, sizeof(struct udphdr) + hdrlen);
	skb_reset_transport_header(skb);

	if (iptunnel_pull_offloads(skb))
		goto drop;

	return -proto_ctype;

drop:
	kfree_skb(skb);
	return 0;
}
    228
/* GRO receive callback for plain FOU sockets: delegate to the inner
 * protocol's gro_receive handler, choosing the IPv4 or IPv6 offload
 * table based on the outer family recorded in the GRO control block.
 * Returns the handler's result, or NULL if no handler exists.
 */
static struct sk_buff *fou_gro_receive(struct sock *sk,
				       struct list_head *head,
				       struct sk_buff *skb)
{
	const struct net_offload __rcu **offloads;
	u8 proto = fou_from_sock(sk)->protocol;
	const struct net_offload *ops;
	struct sk_buff *pp = NULL;

	/* We can clear the encap_mark for FOU as we are essentially doing
	 * one of two possible things.  We are either adding an L4 tunnel
	 * header to the outer L3 tunnel header, or we are simply
	 * treating the GRE tunnel header as though it is a UDP protocol
	 * specific header such as VXLAN or GENEVE.
	 */
	NAPI_GRO_CB(skb)->encap_mark = 0;

	/* Flag this frame as already having an outer encap header */
	NAPI_GRO_CB(skb)->is_fou = 1;

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	if (!ops || !ops->callbacks.gro_receive)
		goto out;

	pp = call_gro_receive(ops->callbacks.gro_receive, head, skb);

out:
	return pp;
}
    259
/* GRO complete callback for plain FOU sockets: finish aggregation by
 * calling the inner protocol's gro_complete handler and record the
 * inner MAC header offset.  Returns the handler's result, or -ENOSYS
 * if no handler is registered.
 */
static int fou_gro_complete(struct sock *sk, struct sk_buff *skb,
			    int nhoff)
{
	const struct net_offload __rcu **offloads;
	u8 proto = fou_from_sock(sk)->protocol;
	const struct net_offload *ops;
	int err = -ENOSYS;

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	if (WARN_ON(!ops || !ops->callbacks.gro_complete))
		goto out;

	err = ops->callbacks.gro_complete(skb, nhoff);

	skb_set_inner_mac_header(skb, nhoff);

out:
	return err;
}
    280
/* GRO-path counterpart of gue_remcsum(): apply the GUE remote checksum
 * offload option while the packet is held by GRO.  Requires a valid
 * checksum in the GRO control block; returns NULL to abort GRO when it
 * is not available, otherwise the (possibly relocated) GUE header.
 */
static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
				      struct guehdr *guehdr, void *data,
				      size_t hdrlen, struct gro_remcsum *grc,
				      bool nopartial)
{
	__be16 *pd = data;
	size_t start = ntohs(pd[0]);
	size_t offset = ntohs(pd[1]);

	/* Already processed on an earlier pass. */
	if (skb->remcsum_offload)
		return guehdr;

	if (!NAPI_GRO_CB(skb)->csum_valid)
		return NULL;

	guehdr = skb_gro_remcsum_process(skb, (void *)guehdr, off, hdrlen,
					 start, offset, grc, nopartial);

	/* Mark so the non-GRO path does not redo the work. */
	skb->remcsum_offload = 1;

	return guehdr;
}
    303
/* GRO receive callback for GUE sockets.
 *
 * Parses the GUE header (or, for version 1, the directly encapsulated
 * IP header), handles the remote checksum offload option, matches the
 * GUE header against packets already held by GRO, and then delegates
 * to the inner protocol's gro_receive handler.
 */
static struct sk_buff *gue_gro_receive(struct sock *sk,
				       struct list_head *head,
				       struct sk_buff *skb)
{
	const struct net_offload __rcu **offloads;
	const struct net_offload *ops;
	struct sk_buff *pp = NULL;
	struct sk_buff *p;
	struct guehdr *guehdr;
	size_t len, optlen, hdrlen, off;
	void *data;
	u16 doffset = 0;
	int flush = 1;
	struct fou *fou = fou_from_sock(sk);
	struct gro_remcsum grc;
	u8 proto;

	/* Initialize before any goto out so the final fixup is safe. */
	skb_gro_remcsum_init(&grc);

	off = skb_gro_offset(skb);
	len = off + sizeof(*guehdr);

	guehdr = skb_gro_header_fast(skb, off);
	if (skb_gro_header_hard(skb, len)) {
		guehdr = skb_gro_header_slow(skb, len, off);
		if (unlikely(!guehdr))
			goto out;
	}

	switch (guehdr->version) {
	case 0:
		break;
	case 1:
		/* Direct IP encapsulation: no GUE options to parse. */
		switch (((struct iphdr *)guehdr)->version) {
		case 4:
			proto = IPPROTO_IPIP;
			break;
		case 6:
			proto = IPPROTO_IPV6;
			break;
		default:
			goto out;
		}
		goto next_proto;
	default:
		goto out;
	}

	/* hlen counts optional fields in 32-bit words */
	optlen = guehdr->hlen << 2;
	len += optlen;

	if (skb_gro_header_hard(skb, len)) {
		guehdr = skb_gro_header_slow(skb, len, off);
		if (unlikely(!guehdr))
			goto out;
	}

	if (unlikely(guehdr->control) || guehdr->version != 0 ||
	    validate_gue_flags(guehdr, optlen))
		goto out;

	hdrlen = sizeof(*guehdr) + optlen;

	/* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
	 * this is needed if there is a remote checkcsum offload.
	 */
	skb_gro_postpull_rcsum(skb, guehdr, hdrlen);

	data = &guehdr[1];

	if (guehdr->flags & GUE_FLAG_PRIV) {
		__be32 flags = *(__be32 *)(data + doffset);

		doffset += GUE_LEN_PRIV;

		if (flags & GUE_PFLAG_REMCSUM) {
			guehdr = gue_gro_remcsum(skb, off, guehdr,
						 data + doffset, hdrlen, &grc,
						 !!(fou->flags &
						    FOU_F_REMCSUM_NOPARTIAL));

			if (!guehdr)
				goto out;

			/* gue_gro_remcsum() may relocate the header */
			data = &guehdr[1];

			doffset += GUE_PLEN_REMCSUM;
		}
	}

	skb_gro_pull(skb, hdrlen);

	/* Drop held packets out of this flow if their GUE header (base
	 * word or options) differs from the current packet's.
	 */
	list_for_each_entry(p, head, list) {
		const struct guehdr *guehdr2;

		if (!NAPI_GRO_CB(p)->same_flow)
			continue;

		guehdr2 = (struct guehdr *)(p->data + off);

		/* Compare base GUE header to be equal (covers
		 * hlen, version, proto_ctype, and flags.
		 */
		if (guehdr->word != guehdr2->word) {
			NAPI_GRO_CB(p)->same_flow = 0;
			continue;
		}

		/* Compare optional fields are the same. */
		if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
					   guehdr->hlen << 2)) {
			NAPI_GRO_CB(p)->same_flow = 0;
			continue;
		}
	}

	proto = guehdr->proto_ctype;

next_proto:

	/* We can clear the encap_mark for GUE as we are essentially doing
	 * one of two possible things.  We are either adding an L4 tunnel
	 * header to the outer L3 tunnel header, or we are simply
	 * treating the GRE tunnel header as though it is a UDP protocol
	 * specific header such as VXLAN or GENEVE.
	 */
	NAPI_GRO_CB(skb)->encap_mark = 0;

	/* Flag this frame as already having an outer encap header */
	NAPI_GRO_CB(skb)->is_fou = 1;

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	if (WARN_ON_ONCE(!ops || !ops->callbacks.gro_receive))
		goto out;

	pp = call_gro_receive(ops->callbacks.gro_receive, head, skb);
	flush = 0;

out:
	skb_gro_flush_final_remcsum(skb, pp, flush, &grc);

	return pp;
}
    448
/* GRO complete callback for GUE sockets: work out the inner protocol
 * and GUE header length from the header at 'nhoff', then let the inner
 * protocol's gro_complete handler finish the aggregated packet.
 * Returns the handler's result or a negative errno.
 */
static int gue_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff)
{
	struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
	const struct net_offload __rcu **offloads;
	const struct net_offload *ops;
	unsigned int guehlen = 0;
	u8 proto;
	int err = -ENOENT;

	switch (guehdr->version) {
	case 0:
		proto = guehdr->proto_ctype;
		guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
		break;
	case 1:
		/* Direct IP encapsulation: guehlen stays 0. */
		switch (((struct iphdr *)guehdr)->version) {
		case 4:
			proto = IPPROTO_IPIP;
			break;
		case 6:
			proto = IPPROTO_IPV6;
			break;
		default:
			return err;
		}
		break;
	default:
		return err;
	}

	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
	ops = rcu_dereference(offloads[proto]);
	if (WARN_ON(!ops || !ops->callbacks.gro_complete))
		goto out;

	err = ops->callbacks.gro_complete(skb, nhoff + guehlen);

	skb_set_inner_mac_header(skb, nhoff + guehlen);

out:
	return err;
}
    491
/* Compare an existing FOU port against a netlink-supplied config.
 * Returns true only when family, local/peer UDP ports, bound device
 * and local/peer addresses all match.
 */
static bool fou_cfg_cmp(struct fou *fou, struct fou_cfg *cfg)
{
	struct sock *sk = fou->sock->sk;
	struct udp_port_cfg *udp_cfg = &cfg->udp_config;

	if (fou->family != udp_cfg->family ||
	    fou->port != udp_cfg->local_udp_port ||
	    sk->sk_dport != udp_cfg->peer_udp_port ||
	    sk->sk_bound_dev_if != udp_cfg->bind_ifindex)
		return false;

	if (fou->family == AF_INET) {
		if (sk->sk_rcv_saddr != udp_cfg->local_ip.s_addr ||
		    sk->sk_daddr != udp_cfg->peer_ip.s_addr)
			return false;
		else
			return true;
#if IS_ENABLED(CONFIG_IPV6)
	} else {
		if (ipv6_addr_cmp(&sk->sk_v6_rcv_saddr, &udp_cfg->local_ip6) ||
		    ipv6_addr_cmp(&sk->sk_v6_daddr, &udp_cfg->peer_ip6))
			return false;
		else
			return true;
#endif
	}

	/* Reached only without IPv6 support: non-AF_INET never matches. */
	return false;
}
    521
    522static int fou_add_to_port_list(struct net *net, struct fou *fou,
    523				struct fou_cfg *cfg)
    524{
    525	struct fou_net *fn = net_generic(net, fou_net_id);
    526	struct fou *fout;
    527
    528	mutex_lock(&fn->fou_lock);
    529	list_for_each_entry(fout, &fn->fou_list, list) {
    530		if (fou_cfg_cmp(fout, cfg)) {
    531			mutex_unlock(&fn->fou_lock);
    532			return -EALREADY;
    533		}
    534	}
    535
    536	list_add(&fou->list, &fn->fou_list);
    537	mutex_unlock(&fn->fou_lock);
    538
    539	return 0;
    540}
    541
/* Tear down one FOU port: unlink it, release the UDP tunnel socket and
 * free the state after an RCU grace period.  Callers in this file
 * (fou_destroy()) invoke this with fou_lock held to protect list_del().
 */
static void fou_release(struct fou *fou)
{
	struct socket *sock = fou->sock;

	list_del(&fou->list);
	udp_tunnel_sock_release(sock);

	/* Readers may still hold an RCU reference via sk_user_data. */
	kfree_rcu(fou, rcu);
}
    551
/* Create a FOU port from a parsed configuration: open the UDP socket,
 * allocate and fill the struct fou, wire up the encap/GRO callbacks
 * for the requested encapsulation type, and add the port to the
 * per-netns list.
 *
 * On success returns 0 and, if 'sockp' is non-NULL, stores the new
 * socket there.  On failure returns a negative errno and releases
 * whatever was allocated.
 */
static int fou_create(struct net *net, struct fou_cfg *cfg,
		      struct socket **sockp)
{
	struct socket *sock = NULL;
	struct fou *fou = NULL;
	struct sock *sk;
	struct udp_tunnel_sock_cfg tunnel_cfg;
	int err;

	/* Open UDP socket */
	err = udp_sock_create(net, &cfg->udp_config, &sock);
	if (err < 0)
		goto error;

	/* Allocate FOU port structure */
	fou = kzalloc(sizeof(*fou), GFP_KERNEL);
	if (!fou) {
		err = -ENOMEM;
		goto error;
	}

	sk = sock->sk;

	fou->port = cfg->udp_config.local_udp_port;
	fou->family = cfg->udp_config.family;
	fou->flags = cfg->flags;
	fou->type = cfg->type;
	fou->sock = sock;

	memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
	tunnel_cfg.encap_type = 1;
	tunnel_cfg.sk_user_data = fou;
	tunnel_cfg.encap_destroy = NULL;

	/* Initial for fou type */
	switch (cfg->type) {
	case FOU_ENCAP_DIRECT:
		tunnel_cfg.encap_rcv = fou_udp_recv;
		tunnel_cfg.gro_receive = fou_gro_receive;
		tunnel_cfg.gro_complete = fou_gro_complete;
		/* Only direct encap carries a fixed inner protocol;
		 * GUE reads it from the GUE header per packet.
		 */
		fou->protocol = cfg->protocol;
		break;
	case FOU_ENCAP_GUE:
		tunnel_cfg.encap_rcv = gue_udp_recv;
		tunnel_cfg.gro_receive = gue_gro_receive;
		tunnel_cfg.gro_complete = gue_gro_complete;
		break;
	default:
		err = -EINVAL;
		goto error;
	}

	setup_udp_tunnel_sock(net, sock, &tunnel_cfg);

	sk->sk_allocation = GFP_ATOMIC;

	err = fou_add_to_port_list(net, fou, cfg);
	if (err)
		goto error;

	if (sockp)
		*sockp = sock;

	return 0;

error:
	/* fou was never added to the port list on these paths, so a
	 * plain kfree (no RCU grace period) is used.
	 */
	kfree(fou);
	if (sock)
		udp_tunnel_sock_release(sock);

	return err;
}
    624
/* Find the FOU port matching 'cfg' in this netns and release it.
 * Returns 0 on success or -EINVAL if no matching port exists.
 */
static int fou_destroy(struct net *net, struct fou_cfg *cfg)
{
	struct fou_net *fn = net_generic(net, fou_net_id);
	int err = -EINVAL;
	struct fou *fou;

	mutex_lock(&fn->fou_lock);
	list_for_each_entry(fou, &fn->fou_list, list) {
		if (fou_cfg_cmp(fou, cfg)) {
			fou_release(fou);
			err = 0;
			break;
		}
	}
	mutex_unlock(&fn->fou_lock);

	return err;
}
    643
/* Forward declaration; the family itself is defined after the ops. */
static struct genl_family fou_nl_family;

/* Netlink attribute validation policy for all FOU commands. */
static const struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
	[FOU_ATTR_PORT]			= { .type = NLA_U16, },
	[FOU_ATTR_AF]			= { .type = NLA_U8, },
	[FOU_ATTR_IPPROTO]		= { .type = NLA_U8, },
	[FOU_ATTR_TYPE]			= { .type = NLA_U8, },
	[FOU_ATTR_REMCSUM_NOPARTIAL]	= { .type = NLA_FLAG, },
	[FOU_ATTR_LOCAL_V4]		= { .type = NLA_U32, },
	[FOU_ATTR_PEER_V4]		= { .type = NLA_U32, },
	[FOU_ATTR_LOCAL_V6]		= { .len = sizeof(struct in6_addr), },
	[FOU_ATTR_PEER_V6]		= { .len = sizeof(struct in6_addr), },
	[FOU_ATTR_PEER_PORT]		= { .type = NLA_U16, },
	[FOU_ATTR_IFINDEX]		= { .type = NLA_S32, },
};
    659
/* Translate FOU netlink attributes into a struct fou_cfg.
 *
 * Returns 0 on success, -EAFNOSUPPORT for an address family other than
 * AF_INET/AF_INET6, or -EINVAL when attributes are inconsistent (peer
 * address without a peer port, or an ifindex without a local address).
 */
static int parse_nl_config(struct genl_info *info,
			   struct fou_cfg *cfg)
{
	bool has_local = false, has_peer = false;
	struct nlattr *attr;
	int ifindex;
	__be16 port;

	memset(cfg, 0, sizeof(*cfg));

	/* Default to IPv4 unless FOU_ATTR_AF says otherwise. */
	cfg->udp_config.family = AF_INET;

	if (info->attrs[FOU_ATTR_AF]) {
		u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);

		switch (family) {
		case AF_INET:
			break;
		case AF_INET6:
			cfg->udp_config.ipv6_v6only = 1;
			break;
		default:
			return -EAFNOSUPPORT;
		}

		cfg->udp_config.family = family;
	}

	if (info->attrs[FOU_ATTR_PORT]) {
		port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
		cfg->udp_config.local_udp_port = port;
	}

	if (info->attrs[FOU_ATTR_IPPROTO])
		cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);

	if (info->attrs[FOU_ATTR_TYPE])
		cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);

	if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL])
		cfg->flags |= FOU_F_REMCSUM_NOPARTIAL;

	/* Addresses are read from the attribute set matching the
	 * chosen family; the other family's attributes are ignored.
	 */
	if (cfg->udp_config.family == AF_INET) {
		if (info->attrs[FOU_ATTR_LOCAL_V4]) {
			attr = info->attrs[FOU_ATTR_LOCAL_V4];
			cfg->udp_config.local_ip.s_addr = nla_get_in_addr(attr);
			has_local = true;
		}

		if (info->attrs[FOU_ATTR_PEER_V4]) {
			attr = info->attrs[FOU_ATTR_PEER_V4];
			cfg->udp_config.peer_ip.s_addr = nla_get_in_addr(attr);
			has_peer = true;
		}
#if IS_ENABLED(CONFIG_IPV6)
	} else {
		if (info->attrs[FOU_ATTR_LOCAL_V6]) {
			attr = info->attrs[FOU_ATTR_LOCAL_V6];
			cfg->udp_config.local_ip6 = nla_get_in6_addr(attr);
			has_local = true;
		}

		if (info->attrs[FOU_ATTR_PEER_V6]) {
			attr = info->attrs[FOU_ATTR_PEER_V6];
			cfg->udp_config.peer_ip6 = nla_get_in6_addr(attr);
			has_peer = true;
		}
#endif
	}

	/* A peer address is only meaningful together with a peer port. */
	if (has_peer) {
		if (info->attrs[FOU_ATTR_PEER_PORT]) {
			port = nla_get_be16(info->attrs[FOU_ATTR_PEER_PORT]);
			cfg->udp_config.peer_udp_port = port;
		} else {
			return -EINVAL;
		}
	}

	/* Binding to a device requires a local address as well. */
	if (info->attrs[FOU_ATTR_IFINDEX]) {
		if (!has_local)
			return -EINVAL;

		ifindex = nla_get_s32(info->attrs[FOU_ATTR_IFINDEX]);

		cfg->udp_config.bind_ifindex = ifindex;
	}

	return 0;
}
    750
    751static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
    752{
    753	struct net *net = genl_info_net(info);
    754	struct fou_cfg cfg;
    755	int err;
    756
    757	err = parse_nl_config(info, &cfg);
    758	if (err)
    759		return err;
    760
    761	return fou_create(net, &cfg, NULL);
    762}
    763
    764static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
    765{
    766	struct net *net = genl_info_net(info);
    767	struct fou_cfg cfg;
    768	int err;
    769
    770	err = parse_nl_config(info, &cfg);
    771	if (err)
    772		return err;
    773
    774	return fou_destroy(net, &cfg);
    775}
    776
/* Emit the netlink attributes describing one FOU port into 'msg'.
 * Returns 0 on success, -1 if the message ran out of space.
 */
static int fou_fill_info(struct fou *fou, struct sk_buff *msg)
{
	struct sock *sk = fou->sock->sk;

	if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) ||
	    nla_put_be16(msg, FOU_ATTR_PORT, fou->port) ||
	    nla_put_be16(msg, FOU_ATTR_PEER_PORT, sk->sk_dport) ||
	    nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) ||
	    nla_put_u8(msg, FOU_ATTR_TYPE, fou->type) ||
	    nla_put_s32(msg, FOU_ATTR_IFINDEX, sk->sk_bound_dev_if))
		return -1;

	if (fou->flags & FOU_F_REMCSUM_NOPARTIAL)
		if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL))
			return -1;

	/* Address attributes depend on the socket's family. */
	if (fou->sock->sk->sk_family == AF_INET) {
		if (nla_put_in_addr(msg, FOU_ATTR_LOCAL_V4, sk->sk_rcv_saddr))
			return -1;

		if (nla_put_in_addr(msg, FOU_ATTR_PEER_V4, sk->sk_daddr))
			return -1;
#if IS_ENABLED(CONFIG_IPV6)
	} else {
		if (nla_put_in6_addr(msg, FOU_ATTR_LOCAL_V6,
				     &sk->sk_v6_rcv_saddr))
			return -1;

		if (nla_put_in6_addr(msg, FOU_ATTR_PEER_V6, &sk->sk_v6_daddr))
			return -1;
#endif
	}

	return 0;
}
    812
/* Build one complete generic netlink message for a FOU port: reserve
 * the genl header, fill in the attributes, and finalize.
 * Returns 0 on success, -ENOMEM if no header could be reserved, or
 * -EMSGSIZE if the attributes did not fit.
 */
static int fou_dump_info(struct fou *fou, u32 portid, u32 seq,
			 u32 flags, struct sk_buff *skb, u8 cmd)
{
	void *hdr;

	hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd);
	if (!hdr)
		return -ENOMEM;

	if (fou_fill_info(fou, skb) < 0)
		goto nla_put_failure;

	genlmsg_end(skb, hdr);
	return 0;

nla_put_failure:
	/* Roll back the partially written message. */
	genlmsg_cancel(skb, hdr);
	return -EMSGSIZE;
}
    832
/* FOU_CMD_GET (doit) handler: look up a single port by the supplied
 * config and reply with its description.
 * Returns 0 on success, -EINVAL for a missing port/family, -ENOMEM on
 * allocation failure, or -ESRCH if no matching port exists.
 */
static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
{
	struct net *net = genl_info_net(info);
	struct fou_net *fn = net_generic(net, fou_net_id);
	struct sk_buff *msg;
	struct fou_cfg cfg;
	struct fou *fout;
	__be16 port;
	u8 family;
	int ret;

	ret = parse_nl_config(info, &cfg);
	if (ret)
		return ret;
	/* A specific local port is required to identify the entry. */
	port = cfg.udp_config.local_udp_port;
	if (port == 0)
		return -EINVAL;

	family = cfg.udp_config.family;
	if (family != AF_INET && family != AF_INET6)
		return -EINVAL;

	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
	if (!msg)
		return -ENOMEM;

	ret = -ESRCH;
	mutex_lock(&fn->fou_lock);
	list_for_each_entry(fout, &fn->fou_list, list) {
		if (fou_cfg_cmp(fout, &cfg)) {
			ret = fou_dump_info(fout, info->snd_portid,
					    info->snd_seq, 0, msg,
					    info->genlhdr->cmd);
			break;
		}
	}
	mutex_unlock(&fn->fou_lock);
	if (ret < 0)
		goto out_free;

	return genlmsg_reply(msg, info);

out_free:
	nlmsg_free(msg);
	return ret;
}
    879
/* FOU_CMD_GET (dumpit) handler: emit one message per configured port,
 * resuming from cb->args[0] across invocations.  Returns skb->len per
 * the netlink dump convention.
 */
static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct net *net = sock_net(skb->sk);
	struct fou_net *fn = net_generic(net, fou_net_id);
	struct fou *fout;
	int idx = 0, ret;

	mutex_lock(&fn->fou_lock);
	list_for_each_entry(fout, &fn->fou_list, list) {
		/* Skip entries already dumped in a previous call. */
		if (idx++ < cb->args[0])
			continue;
		ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid,
				    cb->nlh->nlmsg_seq, NLM_F_MULTI,
				    skb, FOU_CMD_GET);
		if (ret)
			break;
	}
	mutex_unlock(&fn->fou_lock);

	/* Remember how far we got for the next invocation. */
	cb->args[0] = idx;
	return skb->len;
}
    902
/* Generic netlink operations: add, delete, and get/dump FOU ports. */
static const struct genl_small_ops fou_nl_ops[] = {
	{
		.cmd = FOU_CMD_ADD,
		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
		.doit = fou_nl_cmd_add_port,
		.flags = GENL_ADMIN_PERM,
	},
	{
		.cmd = FOU_CMD_DEL,
		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
		.doit = fou_nl_cmd_rm_port,
		.flags = GENL_ADMIN_PERM,
	},
	{
		.cmd = FOU_CMD_GET,
		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
		.doit = fou_nl_cmd_get_port,
		.dumpit = fou_nl_dump,
	},
};
    923
    924static struct genl_family fou_nl_family __ro_after_init = {
    925	.hdrsize	= 0,
    926	.name		= FOU_GENL_NAME,
    927	.version	= FOU_GENL_VERSION,
    928	.maxattr	= FOU_ATTR_MAX,
    929	.policy = fou_nl_policy,
    930	.netnsok	= true,
    931	.module		= THIS_MODULE,
    932	.small_ops	= fou_nl_ops,
    933	.n_small_ops	= ARRAY_SIZE(fou_nl_ops),
    934};
    935
/* Encapsulation overhead for plain FOU: just the outer UDP header
 * (no FOU-specific header exists on the wire).
 */
size_t fou_encap_hlen(struct ip_tunnel_encap *e)
{
	return sizeof(struct udphdr);
}
EXPORT_SYMBOL(fou_encap_hlen);
    941
    942size_t gue_encap_hlen(struct ip_tunnel_encap *e)
    943{
    944	size_t len;
    945	bool need_priv = false;
    946
    947	len = sizeof(struct udphdr) + sizeof(struct guehdr);
    948
    949	if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
    950		len += GUE_PLEN_REMCSUM;
    951		need_priv = true;
    952	}
    953
    954	len += need_priv ? GUE_LEN_PRIV : 0;
    955
    956	return len;
    957}
    958EXPORT_SYMBOL(gue_encap_hlen);
    959
/* Common transmit-side preparation for FOU encapsulation: set up the
 * skb for tunnel offloads and pick a UDP source port — the configured
 * one, or a flow-hash derived port when none was set.
 * Returns 0 on success or a negative errno from offload handling.
 */
int __fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
		       u8 *protocol, __be16 *sport, int type)
{
	int err;

	err = iptunnel_handle_offloads(skb, type);
	if (err)
		return err;

	*sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
						skb, 0, 0, false);

	return 0;
}
EXPORT_SYMBOL(__fou_build_header);
    975
/* Transmit-side GUE header construction: configure offloads, pick the
 * UDP source port, and push a GUE header in front of the payload —
 * including, when remote checksum offload applies, a private flags
 * block carrying the remcsum option.
 *
 * *protocol supplies the inner protocol written into the GUE header;
 * in this file the callers then invoke fou_build_udp(), which rewrites
 * *protocol to IPPROTO_UDP for the outer header.
 * Returns 0 on success or a negative errno.
 */
int __gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
		       u8 *protocol, __be16 *sport, int type)
{
	struct guehdr *guehdr;
	size_t hdrlen, optlen = 0;
	void *data;
	bool need_priv = false;
	int err;

	/* Remcsum only applies when the checksum is still pending
	 * (CHECKSUM_PARTIAL); it needs the private flags block.
	 */
	if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
	    skb->ip_summed == CHECKSUM_PARTIAL) {
		optlen += GUE_PLEN_REMCSUM;
		type |= SKB_GSO_TUNNEL_REMCSUM;
		need_priv = true;
	}

	optlen += need_priv ? GUE_LEN_PRIV : 0;

	err = iptunnel_handle_offloads(skb, type);
	if (err)
		return err;

	/* Get source port (based on flow hash) before skb_push */
	*sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
						skb, 0, 0, false);

	hdrlen = sizeof(struct guehdr) + optlen;

	skb_push(skb, hdrlen);

	guehdr = (struct guehdr *)skb->data;

	guehdr->control = 0;
	guehdr->version = 0;
	guehdr->hlen = optlen >> 2;	/* options in 32-bit words */
	guehdr->flags = 0;
	guehdr->proto_ctype = *protocol;

	data = &guehdr[1];

	if (need_priv) {
		__be32 *flags = data;

		guehdr->flags |= GUE_FLAG_PRIV;
		*flags = 0;
		data += GUE_LEN_PRIV;

		if (type & SKB_GSO_TUNNEL_REMCSUM) {
			u16 csum_start = skb_checksum_start_offset(skb);
			__be16 *pd = data;

			/* Checksum must start beyond the pushed header
			 * for the (start, offset) pair to be encodable.
			 */
			if (csum_start < hdrlen)
				return -EINVAL;

			csum_start -= hdrlen;
			pd[0] = htons(csum_start);
			pd[1] = htons(csum_start + skb->csum_offset);

			if (!skb_is_gso(skb)) {
				skb->ip_summed = CHECKSUM_NONE;
				skb->encapsulation = 0;
			}

			*flags |= GUE_PFLAG_REMCSUM;
			data += GUE_PLEN_REMCSUM;
		}

	}

	return 0;
}
EXPORT_SYMBOL(__gue_build_header);
   1048
   1049#ifdef CONFIG_NET_FOU_IP_TUNNELS
   1050
/* Push the outer UDP header for FOU/GUE transmit, fill it in from the
 * encap config and flow, compute/offload the UDP checksum, and set
 * *protocol to IPPROTO_UDP so the outer IP header is built accordingly.
 */
static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
			  struct flowi4 *fl4, u8 *protocol, __be16 sport)
{
	struct udphdr *uh;

	skb_push(skb, sizeof(struct udphdr));
	skb_reset_transport_header(skb);

	uh = udp_hdr(skb);

	uh->dest = e->dport;
	uh->source = sport;
	uh->len = htons(skb->len);
	/* First argument: no checksum offload requested -> compute now */
	udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
		     fl4->saddr, fl4->daddr, skb->len);

	*protocol = IPPROTO_UDP;
}
   1069
   1070static int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
   1071			    u8 *protocol, struct flowi4 *fl4)
   1072{
   1073	int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
   1074						       SKB_GSO_UDP_TUNNEL;
   1075	__be16 sport;
   1076	int err;
   1077
   1078	err = __fou_build_header(skb, e, protocol, &sport, type);
   1079	if (err)
   1080		return err;
   1081
   1082	fou_build_udp(skb, e, fl4, protocol, sport);
   1083
   1084	return 0;
   1085}
   1086
/* ip_tunnel_encap build_header hook for GUE: push the GUE header, then
 * the outer UDP header.  Returns 0 on success or a negative errno from
 * __gue_build_header().
 */
static int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
			    u8 *protocol, struct flowi4 *fl4)
{
	int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
						       SKB_GSO_UDP_TUNNEL;
	__be16 sport;
	int err;

	err = __gue_build_header(skb, e, protocol, &sport, type);
	if (err)
		return err;

	fou_build_udp(skb, e, fl4, protocol, sport);

	return 0;
}
   1103
   1104static int gue_err_proto_handler(int proto, struct sk_buff *skb, u32 info)
   1105{
   1106	const struct net_protocol *ipprot = rcu_dereference(inet_protos[proto]);
   1107
   1108	if (ipprot && ipprot->err_handler) {
   1109		if (!ipprot->err_handler(skb, info))
   1110			return 0;
   1111	}
   1112
   1113	return -ENOENT;
   1114}
   1115
/* ICMP error handler for GUE-encapsulated packets.
 *
 * Entered with skb's transport header at the UDP header that triggered
 * the error.  Parses the GUE header behind it and forwards the error to
 * the inner protocol's err_handler via gue_err_proto_handler().
 * Returns 0 if the error was handled, a negative errno otherwise.
 * The transport header offset is restored before returning on paths
 * that modified it.
 */
static int gue_err(struct sk_buff *skb, u32 info)
{
	int transport_offset = skb_transport_offset(skb);
	struct guehdr *guehdr;
	size_t len, optlen;
	int ret;

	/* Ensure at least the UDP header plus the fixed GUE header is linear */
	len = sizeof(struct udphdr) + sizeof(struct guehdr);
	if (!pskb_may_pull(skb, transport_offset + len))
		return -EINVAL;

	guehdr = (struct guehdr *)&udp_hdr(skb)[1];

	switch (guehdr->version) {
	case 0: /* Full GUE header present */
		break;
	case 1: {
		/* Direct encapsulation of IPv4 or IPv6 */
		skb_set_transport_header(skb, -(int)sizeof(struct icmphdr));

		/* Version 1 carries the raw IP header where the GUE header
		 * would be; dispatch on the inner IP version nibble.
		 */
		switch (((struct iphdr *)guehdr)->version) {
		case 4:
			ret = gue_err_proto_handler(IPPROTO_IPIP, skb, info);
			goto out;
#if IS_ENABLED(CONFIG_IPV6)
		case 6:
			ret = gue_err_proto_handler(IPPROTO_IPV6, skb, info);
			goto out;
#endif
		default:
			ret = -EOPNOTSUPP;
			goto out;
		}
	}
	default: /* Undefined version */
		return -EOPNOTSUPP;
	}

	/* Control messages have no inner protocol to notify */
	if (guehdr->control)
		return -ENOENT;

	optlen = guehdr->hlen << 2;

	/* Pull in the variable-length GUE options as well */
	if (!pskb_may_pull(skb, transport_offset + len + optlen))
		return -EINVAL;

	/* pskb_may_pull() may have reallocated skb data; recompute pointer */
	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
	if (validate_gue_flags(guehdr, optlen))
		return -EINVAL;

	/* Handling exceptions for direct UDP encapsulation in GUE would lead to
	 * recursion. Besides, this kind of encapsulation can't even be
	 * configured currently. Discard this.
	 */
	if (guehdr->proto_ctype == IPPROTO_UDP ||
	    guehdr->proto_ctype == IPPROTO_UDPLITE)
		return -EOPNOTSUPP;

	/* Inner err_handlers expect the transport header just past an
	 * ICMP header, hence the negative offset.
	 */
	skb_set_transport_header(skb, -(int)sizeof(struct icmphdr));
	ret = gue_err_proto_handler(guehdr->proto_ctype, skb, info);

out:
	/* Restore the transport header offset for our caller */
	skb_set_transport_header(skb, transport_offset);
	return ret;
}
   1181
   1182
/* IP tunnel encap ops for plain FOU (UDP encapsulation without GUE) */
static const struct ip_tunnel_encap_ops fou_iptun_ops = {
	.encap_hlen = fou_encap_hlen,
	.build_header = fou_build_header,
	.err_handler = gue_err,
};
   1188
/* IP tunnel encap ops for GUE (Generic UDP Encapsulation) */
static const struct ip_tunnel_encap_ops gue_iptun_ops = {
	.encap_hlen = gue_encap_hlen,
	.build_header = gue_build_header,
	.err_handler = gue_err,
};
   1194
   1195static int ip_tunnel_encap_add_fou_ops(void)
   1196{
   1197	int ret;
   1198
   1199	ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
   1200	if (ret < 0) {
   1201		pr_err("can't add fou ops\n");
   1202		return ret;
   1203	}
   1204
   1205	ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
   1206	if (ret < 0) {
   1207		pr_err("can't add gue ops\n");
   1208		ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
   1209		return ret;
   1210	}
   1211
   1212	return 0;
   1213}
   1214
/* Unregister both FOU and GUE encap ops from the IP tunnel core */
static void ip_tunnel_encap_del_fou_ops(void)
{
	ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
	ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
}
   1220
   1221#else
   1222
/* No-op stub when CONFIG_NET_FOU_IP_TUNNELS is disabled */
static int ip_tunnel_encap_add_fou_ops(void)
{
	return 0;
}
   1227
/* No-op stub when CONFIG_NET_FOU_IP_TUNNELS is disabled */
static void ip_tunnel_encap_del_fou_ops(void)
{
}
   1231
   1232#endif
   1233
   1234static __net_init int fou_init_net(struct net *net)
   1235{
   1236	struct fou_net *fn = net_generic(net, fou_net_id);
   1237
   1238	INIT_LIST_HEAD(&fn->fou_list);
   1239	mutex_init(&fn->fou_lock);
   1240	return 0;
   1241}
   1242
/* Per-netns teardown: release every FOU socket left in this namespace.
 * Uses the _safe iterator because fou_release() unlinks the entry.
 */
static __net_exit void fou_exit_net(struct net *net)
{
	struct fou_net *fn = net_generic(net, fou_net_id);
	struct fou *fou, *next;

	/* Close all the FOU sockets */
	mutex_lock(&fn->fou_lock);
	list_for_each_entry_safe(fou, next, &fn->fou_list, list)
		fou_release(fou);
	mutex_unlock(&fn->fou_lock);
}
   1254
/* Per-network-namespace state hooks; .size lets the core allocate a
 * struct fou_net per netns, reachable via net_generic(net, fou_net_id).
 */
static struct pernet_operations fou_net_ops = {
	.init = fou_init_net,
	.exit = fou_exit_net,
	.id   = &fou_net_id,
	.size = sizeof(struct fou_net),
};
   1261
   1262static int __init fou_init(void)
   1263{
   1264	int ret;
   1265
   1266	ret = register_pernet_device(&fou_net_ops);
   1267	if (ret)
   1268		goto exit;
   1269
   1270	ret = genl_register_family(&fou_nl_family);
   1271	if (ret < 0)
   1272		goto unregister;
   1273
   1274	ret = ip_tunnel_encap_add_fou_ops();
   1275	if (ret == 0)
   1276		return 0;
   1277
   1278	genl_unregister_family(&fou_nl_family);
   1279unregister:
   1280	unregister_pernet_device(&fou_net_ops);
   1281exit:
   1282	return ret;
   1283}
   1284
/* Module exit: tear down in the reverse order of fou_init() */
static void __exit fou_fini(void)
{
	ip_tunnel_encap_del_fou_ops();
	genl_unregister_family(&fou_nl_family);
	unregister_pernet_device(&fou_net_ops);
}
   1291
/* Module registration and metadata */
module_init(fou_init);
module_exit(fou_fini);
MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Foo over UDP");