cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

esp4.c (28116B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2#define pr_fmt(fmt) "IPsec: " fmt
      3
      4#include <crypto/aead.h>
      5#include <crypto/authenc.h>
      6#include <linux/err.h>
      7#include <linux/module.h>
      8#include <net/ip.h>
      9#include <net/xfrm.h>
     10#include <net/esp.h>
     11#include <linux/scatterlist.h>
     12#include <linux/kernel.h>
     13#include <linux/pfkeyv2.h>
     14#include <linux/rtnetlink.h>
     15#include <linux/slab.h>
     16#include <linux/spinlock.h>
     17#include <linux/in6.h>
     18#include <net/icmp.h>
     19#include <net/protocol.h>
     20#include <net/udp.h>
     21#include <net/tcp.h>
     22#include <net/espintcp.h>
     23
     24#include <linux/highmem.h>
     25
     26struct esp_skb_cb {
     27	struct xfrm_skb_cb xfrm;
     28	void *tmp;
     29};
     30
     31struct esp_output_extra {
     32	__be32 seqhi;
     33	u32 esphoff;
     34};
     35
     36#define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0]))
     37
     38/*
     39 * Allocate an AEAD request structure with extra space for SG and IV.
     40 *
     41 * For alignment considerations the IV is placed at the front, followed
     42 * by the request and finally the SG list.
     43 *
     44 * TODO: Use spare space in skb for this where possible.
     45 */
     46static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int extralen)
     47{
     48	unsigned int len;
     49
     50	len = extralen;
     51
     52	len += crypto_aead_ivsize(aead);
     53
     54	if (len) {
     55		len += crypto_aead_alignmask(aead) &
     56		       ~(crypto_tfm_ctx_alignment() - 1);
     57		len = ALIGN(len, crypto_tfm_ctx_alignment());
     58	}
     59
     60	len += sizeof(struct aead_request) + crypto_aead_reqsize(aead);
     61	len = ALIGN(len, __alignof__(struct scatterlist));
     62
     63	len += sizeof(struct scatterlist) * nfrags;
     64
     65	return kmalloc(len, GFP_ATOMIC);
     66}
     67
     68static inline void *esp_tmp_extra(void *tmp)
     69{
     70	return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra));
     71}
     72
     73static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int extralen)
     74{
     75	return crypto_aead_ivsize(aead) ?
     76	       PTR_ALIGN((u8 *)tmp + extralen,
     77			 crypto_aead_alignmask(aead) + 1) : tmp + extralen;
     78}
     79
     80static inline struct aead_request *esp_tmp_req(struct crypto_aead *aead, u8 *iv)
     81{
     82	struct aead_request *req;
     83
     84	req = (void *)PTR_ALIGN(iv + crypto_aead_ivsize(aead),
     85				crypto_tfm_ctx_alignment());
     86	aead_request_set_tfm(req, aead);
     87	return req;
     88}
     89
     90static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead,
     91					     struct aead_request *req)
     92{
     93	return (void *)ALIGN((unsigned long)(req + 1) +
     94			     crypto_aead_reqsize(aead),
     95			     __alignof__(struct scatterlist));
     96}
     97
     98static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
     99{
    100	struct crypto_aead *aead = x->data;
    101	int extralen = 0;
    102	u8 *iv;
    103	struct aead_request *req;
    104	struct scatterlist *sg;
    105
    106	if (x->props.flags & XFRM_STATE_ESN)
    107		extralen += sizeof(struct esp_output_extra);
    108
    109	iv = esp_tmp_iv(aead, tmp, extralen);
    110	req = esp_tmp_req(aead, iv);
    111
    112	/* Unref skb_frag_pages in the src scatterlist if necessary.
    113	 * Skip the first sg which comes from skb->data.
    114	 */
    115	if (req->src != req->dst)
    116		for (sg = sg_next(req->src); sg; sg = sg_next(sg))
    117			put_page(sg_page(sg));
    118}
    119
    120#ifdef CONFIG_INET_ESPINTCP
    121struct esp_tcp_sk {
    122	struct sock *sk;
    123	struct rcu_head rcu;
    124};
    125
    126static void esp_free_tcp_sk(struct rcu_head *head)
    127{
    128	struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu);
    129
    130	sock_put(esk->sk);
    131	kfree(esk);
    132}
    133
    134static struct sock *esp_find_tcp_sk(struct xfrm_state *x)
    135{
    136	struct xfrm_encap_tmpl *encap = x->encap;
    137	struct esp_tcp_sk *esk;
    138	__be16 sport, dport;
    139	struct sock *nsk;
    140	struct sock *sk;
    141
    142	sk = rcu_dereference(x->encap_sk);
    143	if (sk && sk->sk_state == TCP_ESTABLISHED)
    144		return sk;
    145
    146	spin_lock_bh(&x->lock);
    147	sport = encap->encap_sport;
    148	dport = encap->encap_dport;
    149	nsk = rcu_dereference_protected(x->encap_sk,
    150					lockdep_is_held(&x->lock));
    151	if (sk && sk == nsk) {
    152		esk = kmalloc(sizeof(*esk), GFP_ATOMIC);
    153		if (!esk) {
    154			spin_unlock_bh(&x->lock);
    155			return ERR_PTR(-ENOMEM);
    156		}
    157		RCU_INIT_POINTER(x->encap_sk, NULL);
    158		esk->sk = sk;
    159		call_rcu(&esk->rcu, esp_free_tcp_sk);
    160	}
    161	spin_unlock_bh(&x->lock);
    162
    163	sk = inet_lookup_established(xs_net(x), &tcp_hashinfo, x->id.daddr.a4,
    164				     dport, x->props.saddr.a4, sport, 0);
    165	if (!sk)
    166		return ERR_PTR(-ENOENT);
    167
    168	if (!tcp_is_ulp_esp(sk)) {
    169		sock_put(sk);
    170		return ERR_PTR(-EINVAL);
    171	}
    172
    173	spin_lock_bh(&x->lock);
    174	nsk = rcu_dereference_protected(x->encap_sk,
    175					lockdep_is_held(&x->lock));
    176	if (encap->encap_sport != sport ||
    177	    encap->encap_dport != dport) {
    178		sock_put(sk);
    179		sk = nsk ?: ERR_PTR(-EREMCHG);
    180	} else if (sk == nsk) {
    181		sock_put(sk);
    182	} else {
    183		rcu_assign_pointer(x->encap_sk, sk);
    184	}
    185	spin_unlock_bh(&x->lock);
    186
    187	return sk;
    188}
    189
    190static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb)
    191{
    192	struct sock *sk;
    193	int err;
    194
    195	rcu_read_lock();
    196
    197	sk = esp_find_tcp_sk(x);
    198	err = PTR_ERR_OR_ZERO(sk);
    199	if (err)
    200		goto out;
    201
    202	bh_lock_sock(sk);
    203	if (sock_owned_by_user(sk))
    204		err = espintcp_queue_out(sk, skb);
    205	else
    206		err = espintcp_push_skb(sk, skb);
    207	bh_unlock_sock(sk);
    208
    209out:
    210	rcu_read_unlock();
    211	return err;
    212}
    213
    214static int esp_output_tcp_encap_cb(struct net *net, struct sock *sk,
    215				   struct sk_buff *skb)
    216{
    217	struct dst_entry *dst = skb_dst(skb);
    218	struct xfrm_state *x = dst->xfrm;
    219
    220	return esp_output_tcp_finish(x, skb);
    221}
    222
    223static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb)
    224{
    225	int err;
    226
    227	local_bh_disable();
    228	err = xfrm_trans_queue_net(xs_net(x), skb, esp_output_tcp_encap_cb);
    229	local_bh_enable();
    230
    231	/* EINPROGRESS just happens to do the right thing.  It
    232	 * actually means that the skb has been consumed and
    233	 * isn't coming back.
    234	 */
    235	return err ?: -EINPROGRESS;
    236}
    237#else
    238static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb)
    239{
    240	kfree_skb(skb);
    241
    242	return -EOPNOTSUPP;
    243}
    244#endif
    245
    246static void esp_output_done(struct crypto_async_request *base, int err)
    247{
    248	struct sk_buff *skb = base->data;
    249	struct xfrm_offload *xo = xfrm_offload(skb);
    250	void *tmp;
    251	struct xfrm_state *x;
    252
    253	if (xo && (xo->flags & XFRM_DEV_RESUME)) {
    254		struct sec_path *sp = skb_sec_path(skb);
    255
    256		x = sp->xvec[sp->len - 1];
    257	} else {
    258		x = skb_dst(skb)->xfrm;
    259	}
    260
    261	tmp = ESP_SKB_CB(skb)->tmp;
    262	esp_ssg_unref(x, tmp);
    263	kfree(tmp);
    264
    265	if (xo && (xo->flags & XFRM_DEV_RESUME)) {
    266		if (err) {
    267			XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR);
    268			kfree_skb(skb);
    269			return;
    270		}
    271
    272		skb_push(skb, skb->data - skb_mac_header(skb));
    273		secpath_reset(skb);
    274		xfrm_dev_resume(skb);
    275	} else {
    276		if (!err &&
    277		    x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP)
    278			esp_output_tail_tcp(x, skb);
    279		else
    280			xfrm_output_resume(skb->sk, skb, err);
    281	}
    282}
    283
    284/* Move ESP header back into place. */
    285static void esp_restore_header(struct sk_buff *skb, unsigned int offset)
    286{
    287	struct ip_esp_hdr *esph = (void *)(skb->data + offset);
    288	void *tmp = ESP_SKB_CB(skb)->tmp;
    289	__be32 *seqhi = esp_tmp_extra(tmp);
    290
    291	esph->seq_no = esph->spi;
    292	esph->spi = *seqhi;
    293}
    294
    295static void esp_output_restore_header(struct sk_buff *skb)
    296{
    297	void *tmp = ESP_SKB_CB(skb)->tmp;
    298	struct esp_output_extra *extra = esp_tmp_extra(tmp);
    299
    300	esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff -
    301				sizeof(__be32));
    302}
    303
    304static struct ip_esp_hdr *esp_output_set_extra(struct sk_buff *skb,
    305					       struct xfrm_state *x,
    306					       struct ip_esp_hdr *esph,
    307					       struct esp_output_extra *extra)
    308{
    309	/* For ESN we move the header forward by 4 bytes to
    310	 * accommodate the high bits.  We will move it back after
    311	 * encryption.
    312	 */
    313	if ((x->props.flags & XFRM_STATE_ESN)) {
    314		__u32 seqhi;
    315		struct xfrm_offload *xo = xfrm_offload(skb);
    316
    317		if (xo)
    318			seqhi = xo->seq.hi;
    319		else
    320			seqhi = XFRM_SKB_CB(skb)->seq.output.hi;
    321
    322		extra->esphoff = (unsigned char *)esph -
    323				 skb_transport_header(skb);
    324		esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4);
    325		extra->seqhi = esph->spi;
    326		esph->seq_no = htonl(seqhi);
    327	}
    328
    329	esph->spi = x->id.spi;
    330
    331	return esph;
    332}
    333
    334static void esp_output_done_esn(struct crypto_async_request *base, int err)
    335{
    336	struct sk_buff *skb = base->data;
    337
    338	esp_output_restore_header(skb);
    339	esp_output_done(base, err);
    340}
    341
    342static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb,
    343					       int encap_type,
    344					       struct esp_info *esp,
    345					       __be16 sport,
    346					       __be16 dport)
    347{
    348	struct udphdr *uh;
    349	__be32 *udpdata32;
    350	unsigned int len;
    351
    352	len = skb->len + esp->tailen - skb_transport_offset(skb);
    353	if (len + sizeof(struct iphdr) > IP_MAX_MTU)
    354		return ERR_PTR(-EMSGSIZE);
    355
    356	uh = (struct udphdr *)esp->esph;
    357	uh->source = sport;
    358	uh->dest = dport;
    359	uh->len = htons(len);
    360	uh->check = 0;
    361
    362	*skb_mac_header(skb) = IPPROTO_UDP;
    363
    364	if (encap_type == UDP_ENCAP_ESPINUDP_NON_IKE) {
    365		udpdata32 = (__be32 *)(uh + 1);
    366		udpdata32[0] = udpdata32[1] = 0;
    367		return (struct ip_esp_hdr *)(udpdata32 + 2);
    368	}
    369
    370	return (struct ip_esp_hdr *)(uh + 1);
    371}
    372
    373#ifdef CONFIG_INET_ESPINTCP
    374static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x,
    375						    struct sk_buff *skb,
    376						    struct esp_info *esp)
    377{
    378	__be16 *lenp = (void *)esp->esph;
    379	struct ip_esp_hdr *esph;
    380	unsigned int len;
    381	struct sock *sk;
    382
    383	len = skb->len + esp->tailen - skb_transport_offset(skb);
    384	if (len > IP_MAX_MTU)
    385		return ERR_PTR(-EMSGSIZE);
    386
    387	rcu_read_lock();
    388	sk = esp_find_tcp_sk(x);
    389	rcu_read_unlock();
    390
    391	if (IS_ERR(sk))
    392		return ERR_CAST(sk);
    393
    394	*lenp = htons(len);
    395	esph = (struct ip_esp_hdr *)(lenp + 1);
    396
    397	return esph;
    398}
    399#else
    400static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x,
    401						    struct sk_buff *skb,
    402						    struct esp_info *esp)
    403{
    404	return ERR_PTR(-EOPNOTSUPP);
    405}
    406#endif
    407
    408static int esp_output_encap(struct xfrm_state *x, struct sk_buff *skb,
    409			    struct esp_info *esp)
    410{
    411	struct xfrm_encap_tmpl *encap = x->encap;
    412	struct ip_esp_hdr *esph;
    413	__be16 sport, dport;
    414	int encap_type;
    415
    416	spin_lock_bh(&x->lock);
    417	sport = encap->encap_sport;
    418	dport = encap->encap_dport;
    419	encap_type = encap->encap_type;
    420	spin_unlock_bh(&x->lock);
    421
    422	switch (encap_type) {
    423	default:
    424	case UDP_ENCAP_ESPINUDP:
    425	case UDP_ENCAP_ESPINUDP_NON_IKE:
    426		esph = esp_output_udp_encap(skb, encap_type, esp, sport, dport);
    427		break;
    428	case TCP_ENCAP_ESPINTCP:
    429		esph = esp_output_tcp_encap(x, skb, esp);
    430		break;
    431	}
    432
    433	if (IS_ERR(esph))
    434		return PTR_ERR(esph);
    435
    436	esp->esph = esph;
    437
    438	return 0;
    439}
    440
    441int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
    442{
    443	u8 *tail;
    444	int nfrags;
    445	int esph_offset;
    446	struct page *page;
    447	struct sk_buff *trailer;
    448	int tailen = esp->tailen;
    449
    450	/* this is non-NULL only with TCP/UDP Encapsulation */
    451	if (x->encap) {
    452		int err = esp_output_encap(x, skb, esp);
    453
    454		if (err < 0)
    455			return err;
    456	}
    457
    458	if (ALIGN(tailen, L1_CACHE_BYTES) > PAGE_SIZE ||
    459	    ALIGN(skb->data_len, L1_CACHE_BYTES) > PAGE_SIZE)
    460		goto cow;
    461
    462	if (!skb_cloned(skb)) {
    463		if (tailen <= skb_tailroom(skb)) {
    464			nfrags = 1;
    465			trailer = skb;
    466			tail = skb_tail_pointer(trailer);
    467
    468			goto skip_cow;
    469		} else if ((skb_shinfo(skb)->nr_frags < MAX_SKB_FRAGS)
    470			   && !skb_has_frag_list(skb)) {
    471			int allocsize;
    472			struct sock *sk = skb->sk;
    473			struct page_frag *pfrag = &x->xfrag;
    474
    475			esp->inplace = false;
    476
    477			allocsize = ALIGN(tailen, L1_CACHE_BYTES);
    478
    479			spin_lock_bh(&x->lock);
    480
    481			if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) {
    482				spin_unlock_bh(&x->lock);
    483				goto cow;
    484			}
    485
    486			page = pfrag->page;
    487			get_page(page);
    488
    489			tail = page_address(page) + pfrag->offset;
    490
    491			esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
    492
    493			nfrags = skb_shinfo(skb)->nr_frags;
    494
    495			__skb_fill_page_desc(skb, nfrags, page, pfrag->offset,
    496					     tailen);
    497			skb_shinfo(skb)->nr_frags = ++nfrags;
    498
    499			pfrag->offset = pfrag->offset + allocsize;
    500
    501			spin_unlock_bh(&x->lock);
    502
    503			nfrags++;
    504
    505			skb->len += tailen;
    506			skb->data_len += tailen;
    507			skb->truesize += tailen;
    508			if (sk && sk_fullsock(sk))
    509				refcount_add(tailen, &sk->sk_wmem_alloc);
    510
    511			goto out;
    512		}
    513	}
    514
    515cow:
    516	esph_offset = (unsigned char *)esp->esph - skb_transport_header(skb);
    517
    518	nfrags = skb_cow_data(skb, tailen, &trailer);
    519	if (nfrags < 0)
    520		goto out;
    521	tail = skb_tail_pointer(trailer);
    522	esp->esph = (struct ip_esp_hdr *)(skb_transport_header(skb) + esph_offset);
    523
    524skip_cow:
    525	esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto);
    526	pskb_put(skb, trailer, tailen);
    527
    528out:
    529	return nfrags;
    530}
    531EXPORT_SYMBOL_GPL(esp_output_head);
    532
    533int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp)
    534{
    535	u8 *iv;
    536	int alen;
    537	void *tmp;
    538	int ivlen;
    539	int assoclen;
    540	int extralen;
    541	struct page *page;
    542	struct ip_esp_hdr *esph;
    543	struct crypto_aead *aead;
    544	struct aead_request *req;
    545	struct scatterlist *sg, *dsg;
    546	struct esp_output_extra *extra;
    547	int err = -ENOMEM;
    548
    549	assoclen = sizeof(struct ip_esp_hdr);
    550	extralen = 0;
    551
    552	if (x->props.flags & XFRM_STATE_ESN) {
    553		extralen += sizeof(*extra);
    554		assoclen += sizeof(__be32);
    555	}
    556
    557	aead = x->data;
    558	alen = crypto_aead_authsize(aead);
    559	ivlen = crypto_aead_ivsize(aead);
    560
    561	tmp = esp_alloc_tmp(aead, esp->nfrags + 2, extralen);
    562	if (!tmp)
    563		goto error;
    564
    565	extra = esp_tmp_extra(tmp);
    566	iv = esp_tmp_iv(aead, tmp, extralen);
    567	req = esp_tmp_req(aead, iv);
    568	sg = esp_req_sg(aead, req);
    569
    570	if (esp->inplace)
    571		dsg = sg;
    572	else
    573		dsg = &sg[esp->nfrags];
    574
    575	esph = esp_output_set_extra(skb, x, esp->esph, extra);
    576	esp->esph = esph;
    577
    578	sg_init_table(sg, esp->nfrags);
    579	err = skb_to_sgvec(skb, sg,
    580		           (unsigned char *)esph - skb->data,
    581		           assoclen + ivlen + esp->clen + alen);
    582	if (unlikely(err < 0))
    583		goto error_free;
    584
    585	if (!esp->inplace) {
    586		int allocsize;
    587		struct page_frag *pfrag = &x->xfrag;
    588
    589		allocsize = ALIGN(skb->data_len, L1_CACHE_BYTES);
    590
    591		spin_lock_bh(&x->lock);
    592		if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) {
    593			spin_unlock_bh(&x->lock);
    594			goto error_free;
    595		}
    596
    597		skb_shinfo(skb)->nr_frags = 1;
    598
    599		page = pfrag->page;
    600		get_page(page);
    601		/* replace page frags in skb with new page */
    602		__skb_fill_page_desc(skb, 0, page, pfrag->offset, skb->data_len);
    603		pfrag->offset = pfrag->offset + allocsize;
    604		spin_unlock_bh(&x->lock);
    605
    606		sg_init_table(dsg, skb_shinfo(skb)->nr_frags + 1);
    607		err = skb_to_sgvec(skb, dsg,
    608			           (unsigned char *)esph - skb->data,
    609			           assoclen + ivlen + esp->clen + alen);
    610		if (unlikely(err < 0))
    611			goto error_free;
    612	}
    613
    614	if ((x->props.flags & XFRM_STATE_ESN))
    615		aead_request_set_callback(req, 0, esp_output_done_esn, skb);
    616	else
    617		aead_request_set_callback(req, 0, esp_output_done, skb);
    618
    619	aead_request_set_crypt(req, sg, dsg, ivlen + esp->clen, iv);
    620	aead_request_set_ad(req, assoclen);
    621
    622	memset(iv, 0, ivlen);
    623	memcpy(iv + ivlen - min(ivlen, 8), (u8 *)&esp->seqno + 8 - min(ivlen, 8),
    624	       min(ivlen, 8));
    625
    626	ESP_SKB_CB(skb)->tmp = tmp;
    627	err = crypto_aead_encrypt(req);
    628
    629	switch (err) {
    630	case -EINPROGRESS:
    631		goto error;
    632
    633	case -ENOSPC:
    634		err = NET_XMIT_DROP;
    635		break;
    636
    637	case 0:
    638		if ((x->props.flags & XFRM_STATE_ESN))
    639			esp_output_restore_header(skb);
    640	}
    641
    642	if (sg != dsg)
    643		esp_ssg_unref(x, tmp);
    644
    645	if (!err && x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP)
    646		err = esp_output_tail_tcp(x, skb);
    647
    648error_free:
    649	kfree(tmp);
    650error:
    651	return err;
    652}
    653EXPORT_SYMBOL_GPL(esp_output_tail);
    654
    655static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
    656{
    657	int alen;
    658	int blksize;
    659	struct ip_esp_hdr *esph;
    660	struct crypto_aead *aead;
    661	struct esp_info esp;
    662
    663	esp.inplace = true;
    664
    665	esp.proto = *skb_mac_header(skb);
    666	*skb_mac_header(skb) = IPPROTO_ESP;
    667
    668	/* skb is pure payload to encrypt */
    669
    670	aead = x->data;
    671	alen = crypto_aead_authsize(aead);
    672
    673	esp.tfclen = 0;
    674	if (x->tfcpad) {
    675		struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb);
    676		u32 padto;
    677
    678		padto = min(x->tfcpad, xfrm_state_mtu(x, dst->child_mtu_cached));
    679		if (skb->len < padto)
    680			esp.tfclen = padto - skb->len;
    681	}
    682	blksize = ALIGN(crypto_aead_blocksize(aead), 4);
    683	esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize);
    684	esp.plen = esp.clen - skb->len - esp.tfclen;
    685	esp.tailen = esp.tfclen + esp.plen + alen;
    686
    687	esp.esph = ip_esp_hdr(skb);
    688
    689	esp.nfrags = esp_output_head(x, skb, &esp);
    690	if (esp.nfrags < 0)
    691		return esp.nfrags;
    692
    693	esph = esp.esph;
    694	esph->spi = x->id.spi;
    695
    696	esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low);
    697	esp.seqno = cpu_to_be64(XFRM_SKB_CB(skb)->seq.output.low +
    698				 ((u64)XFRM_SKB_CB(skb)->seq.output.hi << 32));
    699
    700	skb_push(skb, -skb_network_offset(skb));
    701
    702	return esp_output_tail(x, skb, &esp);
    703}
    704
    705static inline int esp_remove_trailer(struct sk_buff *skb)
    706{
    707	struct xfrm_state *x = xfrm_input_state(skb);
    708	struct crypto_aead *aead = x->data;
    709	int alen, hlen, elen;
    710	int padlen, trimlen;
    711	__wsum csumdiff;
    712	u8 nexthdr[2];
    713	int ret;
    714
    715	alen = crypto_aead_authsize(aead);
    716	hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead);
    717	elen = skb->len - hlen;
    718
    719	if (skb_copy_bits(skb, skb->len - alen - 2, nexthdr, 2))
    720		BUG();
    721
    722	ret = -EINVAL;
    723	padlen = nexthdr[0];
    724	if (padlen + 2 + alen >= elen) {
    725		net_dbg_ratelimited("ipsec esp packet is garbage padlen=%d, elen=%d\n",
    726				    padlen + 2, elen - alen);
    727		goto out;
    728	}
    729
    730	trimlen = alen + padlen + 2;
    731	if (skb->ip_summed == CHECKSUM_COMPLETE) {
    732		csumdiff = skb_checksum(skb, skb->len - trimlen, trimlen, 0);
    733		skb->csum = csum_block_sub(skb->csum, csumdiff,
    734					   skb->len - trimlen);
    735	}
    736	pskb_trim(skb, skb->len - trimlen);
    737
    738	ret = nexthdr[1];
    739
    740out:
    741	return ret;
    742}
    743
    744int esp_input_done2(struct sk_buff *skb, int err)
    745{
    746	const struct iphdr *iph;
    747	struct xfrm_state *x = xfrm_input_state(skb);
    748	struct xfrm_offload *xo = xfrm_offload(skb);
    749	struct crypto_aead *aead = x->data;
    750	int hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead);
    751	int ihl;
    752
    753	if (!xo || !(xo->flags & CRYPTO_DONE))
    754		kfree(ESP_SKB_CB(skb)->tmp);
    755
    756	if (unlikely(err))
    757		goto out;
    758
    759	err = esp_remove_trailer(skb);
    760	if (unlikely(err < 0))
    761		goto out;
    762
    763	iph = ip_hdr(skb);
    764	ihl = iph->ihl * 4;
    765
    766	if (x->encap) {
    767		struct xfrm_encap_tmpl *encap = x->encap;
    768		struct tcphdr *th = (void *)(skb_network_header(skb) + ihl);
    769		struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
    770		__be16 source;
    771
    772		switch (x->encap->encap_type) {
    773		case TCP_ENCAP_ESPINTCP:
    774			source = th->source;
    775			break;
    776		case UDP_ENCAP_ESPINUDP:
    777		case UDP_ENCAP_ESPINUDP_NON_IKE:
    778			source = uh->source;
    779			break;
    780		default:
    781			WARN_ON_ONCE(1);
    782			err = -EINVAL;
    783			goto out;
    784		}
    785
    786		/*
    787		 * 1) if the NAT-T peer's IP or port changed then
     788		 *    advertise the change to the keying daemon.
    789		 *    This is an inbound SA, so just compare
    790		 *    SRC ports.
    791		 */
    792		if (iph->saddr != x->props.saddr.a4 ||
    793		    source != encap->encap_sport) {
    794			xfrm_address_t ipaddr;
    795
    796			ipaddr.a4 = iph->saddr;
    797			km_new_mapping(x, &ipaddr, source);
    798
    799			/* XXX: perhaps add an extra
    800			 * policy check here, to see
    801			 * if we should allow or
    802			 * reject a packet from a
    803			 * different source
    804			 * address/port.
    805			 */
    806		}
    807
    808		/*
    809		 * 2) ignore UDP/TCP checksums in case
    810		 *    of NAT-T in Transport Mode, or
    811		 *    perform other post-processing fixes
    812		 *    as per draft-ietf-ipsec-udp-encaps-06,
    813		 *    section 3.1.2
    814		 */
    815		if (x->props.mode == XFRM_MODE_TRANSPORT)
    816			skb->ip_summed = CHECKSUM_UNNECESSARY;
    817	}
    818
    819	skb_pull_rcsum(skb, hlen);
    820	if (x->props.mode == XFRM_MODE_TUNNEL)
    821		skb_reset_transport_header(skb);
    822	else
    823		skb_set_transport_header(skb, -ihl);
    824
    825	/* RFC4303: Drop dummy packets without any error */
    826	if (err == IPPROTO_NONE)
    827		err = -EINVAL;
    828
    829out:
    830	return err;
    831}
    832EXPORT_SYMBOL_GPL(esp_input_done2);
    833
    834static void esp_input_done(struct crypto_async_request *base, int err)
    835{
    836	struct sk_buff *skb = base->data;
    837
    838	xfrm_input_resume(skb, esp_input_done2(skb, err));
    839}
    840
    841static void esp_input_restore_header(struct sk_buff *skb)
    842{
    843	esp_restore_header(skb, 0);
    844	__skb_pull(skb, 4);
    845}
    846
    847static void esp_input_set_header(struct sk_buff *skb, __be32 *seqhi)
    848{
    849	struct xfrm_state *x = xfrm_input_state(skb);
    850	struct ip_esp_hdr *esph;
    851
    852	/* For ESN we move the header forward by 4 bytes to
    853	 * accommodate the high bits.  We will move it back after
    854	 * decryption.
    855	 */
    856	if ((x->props.flags & XFRM_STATE_ESN)) {
    857		esph = skb_push(skb, 4);
    858		*seqhi = esph->spi;
    859		esph->spi = esph->seq_no;
    860		esph->seq_no = XFRM_SKB_CB(skb)->seq.input.hi;
    861	}
    862}
    863
    864static void esp_input_done_esn(struct crypto_async_request *base, int err)
    865{
    866	struct sk_buff *skb = base->data;
    867
    868	esp_input_restore_header(skb);
    869	esp_input_done(base, err);
    870}
    871
    872/*
    873 * Note: detecting truncated vs. non-truncated authentication data is very
    874 * expensive, so we only support truncated data, which is the recommended
    875 * and common case.
    876 */
    877static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
    878{
    879	struct crypto_aead *aead = x->data;
    880	struct aead_request *req;
    881	struct sk_buff *trailer;
    882	int ivlen = crypto_aead_ivsize(aead);
    883	int elen = skb->len - sizeof(struct ip_esp_hdr) - ivlen;
    884	int nfrags;
    885	int assoclen;
    886	int seqhilen;
    887	__be32 *seqhi;
    888	void *tmp;
    889	u8 *iv;
    890	struct scatterlist *sg;
    891	int err = -EINVAL;
    892
    893	if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + ivlen))
    894		goto out;
    895
    896	if (elen <= 0)
    897		goto out;
    898
    899	assoclen = sizeof(struct ip_esp_hdr);
    900	seqhilen = 0;
    901
    902	if (x->props.flags & XFRM_STATE_ESN) {
    903		seqhilen += sizeof(__be32);
    904		assoclen += seqhilen;
    905	}
    906
    907	if (!skb_cloned(skb)) {
    908		if (!skb_is_nonlinear(skb)) {
    909			nfrags = 1;
    910
    911			goto skip_cow;
    912		} else if (!skb_has_frag_list(skb)) {
    913			nfrags = skb_shinfo(skb)->nr_frags;
    914			nfrags++;
    915
    916			goto skip_cow;
    917		}
    918	}
    919
    920	err = skb_cow_data(skb, 0, &trailer);
    921	if (err < 0)
    922		goto out;
    923
    924	nfrags = err;
    925
    926skip_cow:
    927	err = -ENOMEM;
    928	tmp = esp_alloc_tmp(aead, nfrags, seqhilen);
    929	if (!tmp)
    930		goto out;
    931
    932	ESP_SKB_CB(skb)->tmp = tmp;
    933	seqhi = esp_tmp_extra(tmp);
    934	iv = esp_tmp_iv(aead, tmp, seqhilen);
    935	req = esp_tmp_req(aead, iv);
    936	sg = esp_req_sg(aead, req);
    937
    938	esp_input_set_header(skb, seqhi);
    939
    940	sg_init_table(sg, nfrags);
    941	err = skb_to_sgvec(skb, sg, 0, skb->len);
    942	if (unlikely(err < 0)) {
    943		kfree(tmp);
    944		goto out;
    945	}
    946
    947	skb->ip_summed = CHECKSUM_NONE;
    948
    949	if ((x->props.flags & XFRM_STATE_ESN))
    950		aead_request_set_callback(req, 0, esp_input_done_esn, skb);
    951	else
    952		aead_request_set_callback(req, 0, esp_input_done, skb);
    953
    954	aead_request_set_crypt(req, sg, sg, elen + ivlen, iv);
    955	aead_request_set_ad(req, assoclen);
    956
    957	err = crypto_aead_decrypt(req);
    958	if (err == -EINPROGRESS)
    959		goto out;
    960
    961	if ((x->props.flags & XFRM_STATE_ESN))
    962		esp_input_restore_header(skb);
    963
    964	err = esp_input_done2(skb, err);
    965
    966out:
    967	return err;
    968}
    969
    970static int esp4_err(struct sk_buff *skb, u32 info)
    971{
    972	struct net *net = dev_net(skb->dev);
    973	const struct iphdr *iph = (const struct iphdr *)skb->data;
    974	struct ip_esp_hdr *esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2));
    975	struct xfrm_state *x;
    976
    977	switch (icmp_hdr(skb)->type) {
    978	case ICMP_DEST_UNREACH:
    979		if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED)
    980			return 0;
    981		break;
    982	case ICMP_REDIRECT:
    983		break;
    984	default:
    985		return 0;
    986	}
    987
    988	x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr,
    989			      esph->spi, IPPROTO_ESP, AF_INET);
    990	if (!x)
    991		return 0;
    992
    993	if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH)
    994		ipv4_update_pmtu(skb, net, info, 0, IPPROTO_ESP);
    995	else
    996		ipv4_redirect(skb, net, 0, IPPROTO_ESP);
    997	xfrm_state_put(x);
    998
    999	return 0;
   1000}
   1001
   1002static void esp_destroy(struct xfrm_state *x)
   1003{
   1004	struct crypto_aead *aead = x->data;
   1005
   1006	if (!aead)
   1007		return;
   1008
   1009	crypto_free_aead(aead);
   1010}
   1011
   1012static int esp_init_aead(struct xfrm_state *x)
   1013{
   1014	char aead_name[CRYPTO_MAX_ALG_NAME];
   1015	struct crypto_aead *aead;
   1016	int err;
   1017
   1018	err = -ENAMETOOLONG;
   1019	if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)",
   1020		     x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME)
   1021		goto error;
   1022
   1023	aead = crypto_alloc_aead(aead_name, 0, 0);
   1024	err = PTR_ERR(aead);
   1025	if (IS_ERR(aead))
   1026		goto error;
   1027
   1028	x->data = aead;
   1029
   1030	err = crypto_aead_setkey(aead, x->aead->alg_key,
   1031				 (x->aead->alg_key_len + 7) / 8);
   1032	if (err)
   1033		goto error;
   1034
   1035	err = crypto_aead_setauthsize(aead, x->aead->alg_icv_len / 8);
   1036	if (err)
   1037		goto error;
   1038
   1039error:
   1040	return err;
   1041}
   1042
   1043static int esp_init_authenc(struct xfrm_state *x)
   1044{
   1045	struct crypto_aead *aead;
   1046	struct crypto_authenc_key_param *param;
   1047	struct rtattr *rta;
   1048	char *key;
   1049	char *p;
   1050	char authenc_name[CRYPTO_MAX_ALG_NAME];
   1051	unsigned int keylen;
   1052	int err;
   1053
   1054	err = -EINVAL;
   1055	if (!x->ealg)
   1056		goto error;
   1057
   1058	err = -ENAMETOOLONG;
   1059
   1060	if ((x->props.flags & XFRM_STATE_ESN)) {
   1061		if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME,
   1062			     "%s%sauthencesn(%s,%s)%s",
   1063			     x->geniv ?: "", x->geniv ? "(" : "",
   1064			     x->aalg ? x->aalg->alg_name : "digest_null",
   1065			     x->ealg->alg_name,
   1066			     x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME)
   1067			goto error;
   1068	} else {
   1069		if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME,
   1070			     "%s%sauthenc(%s,%s)%s",
   1071			     x->geniv ?: "", x->geniv ? "(" : "",
   1072			     x->aalg ? x->aalg->alg_name : "digest_null",
   1073			     x->ealg->alg_name,
   1074			     x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME)
   1075			goto error;
   1076	}
   1077
   1078	aead = crypto_alloc_aead(authenc_name, 0, 0);
   1079	err = PTR_ERR(aead);
   1080	if (IS_ERR(aead))
   1081		goto error;
   1082
   1083	x->data = aead;
   1084
   1085	keylen = (x->aalg ? (x->aalg->alg_key_len + 7) / 8 : 0) +
   1086		 (x->ealg->alg_key_len + 7) / 8 + RTA_SPACE(sizeof(*param));
   1087	err = -ENOMEM;
   1088	key = kmalloc(keylen, GFP_KERNEL);
   1089	if (!key)
   1090		goto error;
   1091
   1092	p = key;
   1093	rta = (void *)p;
   1094	rta->rta_type = CRYPTO_AUTHENC_KEYA_PARAM;
   1095	rta->rta_len = RTA_LENGTH(sizeof(*param));
   1096	param = RTA_DATA(rta);
   1097	p += RTA_SPACE(sizeof(*param));
   1098
   1099	if (x->aalg) {
   1100		struct xfrm_algo_desc *aalg_desc;
   1101
   1102		memcpy(p, x->aalg->alg_key, (x->aalg->alg_key_len + 7) / 8);
   1103		p += (x->aalg->alg_key_len + 7) / 8;
   1104
   1105		aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0);
   1106		BUG_ON(!aalg_desc);
   1107
   1108		err = -EINVAL;
   1109		if (aalg_desc->uinfo.auth.icv_fullbits / 8 !=
   1110		    crypto_aead_authsize(aead)) {
   1111			pr_info("ESP: %s digestsize %u != %hu\n",
   1112				x->aalg->alg_name,
   1113				crypto_aead_authsize(aead),
   1114				aalg_desc->uinfo.auth.icv_fullbits / 8);
   1115			goto free_key;
   1116		}
   1117
   1118		err = crypto_aead_setauthsize(
   1119			aead, x->aalg->alg_trunc_len / 8);
   1120		if (err)
   1121			goto free_key;
   1122	}
   1123
   1124	param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8);
   1125	memcpy(p, x->ealg->alg_key, (x->ealg->alg_key_len + 7) / 8);
   1126
   1127	err = crypto_aead_setkey(aead, key, keylen);
   1128
   1129free_key:
   1130	kfree(key);
   1131
   1132error:
   1133	return err;
   1134}
   1135
   1136static int esp_init_state(struct xfrm_state *x)
   1137{
   1138	struct crypto_aead *aead;
   1139	u32 align;
   1140	int err;
   1141
   1142	x->data = NULL;
   1143
   1144	if (x->aead)
   1145		err = esp_init_aead(x);
   1146	else
   1147		err = esp_init_authenc(x);
   1148
   1149	if (err)
   1150		goto error;
   1151
   1152	aead = x->data;
   1153
   1154	x->props.header_len = sizeof(struct ip_esp_hdr) +
   1155			      crypto_aead_ivsize(aead);
   1156	if (x->props.mode == XFRM_MODE_TUNNEL)
   1157		x->props.header_len += sizeof(struct iphdr);
   1158	else if (x->props.mode == XFRM_MODE_BEET && x->sel.family != AF_INET6)
   1159		x->props.header_len += IPV4_BEET_PHMAXLEN;
   1160	if (x->encap) {
   1161		struct xfrm_encap_tmpl *encap = x->encap;
   1162
   1163		switch (encap->encap_type) {
   1164		default:
   1165			err = -EINVAL;
   1166			goto error;
   1167		case UDP_ENCAP_ESPINUDP:
   1168			x->props.header_len += sizeof(struct udphdr);
   1169			break;
   1170		case UDP_ENCAP_ESPINUDP_NON_IKE:
   1171			x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
   1172			break;
   1173#ifdef CONFIG_INET_ESPINTCP
   1174		case TCP_ENCAP_ESPINTCP:
   1175			/* only the length field, TCP encap is done by
   1176			 * the socket
   1177			 */
   1178			x->props.header_len += 2;
   1179			break;
   1180#endif
   1181		}
   1182	}
   1183
   1184	align = ALIGN(crypto_aead_blocksize(aead), 4);
   1185	x->props.trailer_len = align + 1 + crypto_aead_authsize(aead);
   1186
   1187error:
   1188	return err;
   1189}
   1190
   1191static int esp4_rcv_cb(struct sk_buff *skb, int err)
   1192{
   1193	return 0;
   1194}
   1195
   1196static const struct xfrm_type esp_type =
   1197{
   1198	.owner		= THIS_MODULE,
   1199	.proto	     	= IPPROTO_ESP,
   1200	.flags		= XFRM_TYPE_REPLAY_PROT,
   1201	.init_state	= esp_init_state,
   1202	.destructor	= esp_destroy,
   1203	.input		= esp_input,
   1204	.output		= esp_output,
   1205};
   1206
   1207static struct xfrm4_protocol esp4_protocol = {
   1208	.handler	=	xfrm4_rcv,
   1209	.input_handler	=	xfrm_input,
   1210	.cb_handler	=	esp4_rcv_cb,
   1211	.err_handler	=	esp4_err,
   1212	.priority	=	0,
   1213};
   1214
   1215static int __init esp4_init(void)
   1216{
   1217	if (xfrm_register_type(&esp_type, AF_INET) < 0) {
   1218		pr_info("%s: can't add xfrm type\n", __func__);
   1219		return -EAGAIN;
   1220	}
   1221	if (xfrm4_protocol_register(&esp4_protocol, IPPROTO_ESP) < 0) {
   1222		pr_info("%s: can't add protocol\n", __func__);
   1223		xfrm_unregister_type(&esp_type, AF_INET);
   1224		return -EAGAIN;
   1225	}
   1226	return 0;
   1227}
   1228
   1229static void __exit esp4_fini(void)
   1230{
   1231	if (xfrm4_protocol_deregister(&esp4_protocol, IPPROTO_ESP) < 0)
   1232		pr_info("%s: can't remove protocol\n", __func__);
   1233	xfrm_unregister_type(&esp_type, AF_INET);
   1234}
   1235
   1236module_init(esp4_init);
   1237module_exit(esp4_fini);
   1238MODULE_LICENSE("GPL");
   1239MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_ESP);