cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

skmsg.c (29254B)


      1// SPDX-License-Identifier: GPL-2.0
      2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
      3
      4#include <linux/skmsg.h>
      5#include <linux/skbuff.h>
      6#include <linux/scatterlist.h>
      7
      8#include <net/sock.h>
      9#include <net/tcp.h>
     10#include <net/tls.h>
     11
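        /* Decide whether new data may be coalesced into the last sg element.
         * elem_first_coalesce must lie inside the occupied part of the ring,
         * which is checked for both the unwrapped (end > start) and the
         * wrapped (end < start) case.
         */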
     12static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
     13{
     14	if (msg->sg.end > msg->sg.start &&
     15	    elem_first_coalesce < msg->sg.end)
     16		return true;
     17
     18	if (msg->sg.end < msg->sg.start &&
     19	    (elem_first_coalesce > msg->sg.start ||
     20	     elem_first_coalesce < msg->sg.end))
     21		return true;
     22
     23	return false;
     24}
     25
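        /* Grow @msg until it holds @len bytes in total, allocating from the
         * socket's page_frag. New data is coalesced into the last sg element
         * when permitted, otherwise a fresh entry is appended; if memory runs
         * out the msg is trimmed back to its original size.
         */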
     26int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
     27		 int elem_first_coalesce)
     28{
     29	struct page_frag *pfrag = sk_page_frag(sk);
     30	u32 osize = msg->sg.size;
     31	int ret = 0;
     32
     33	len -= msg->sg.size;
     34	while (len > 0) {
     35		struct scatterlist *sge;
     36		u32 orig_offset;
     37		int use, i;
     38
     39		if (!sk_page_frag_refill(sk, pfrag)) {
     40			ret = -ENOMEM;
     41			goto msg_trim;
     42		}
     43
     44		orig_offset = pfrag->offset;
     45		use = min_t(int, len, pfrag->size - orig_offset);
     46		if (!sk_wmem_schedule(sk, use)) {
     47			ret = -ENOMEM;
     48			goto msg_trim;
     49		}
     50
     51		i = msg->sg.end;
     52		sk_msg_iter_var_prev(i);
     53		sge = &msg->sg.data[i];
     54
     55		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
     56		    sg_page(sge) == pfrag->page &&
     57		    sge->offset + sge->length == orig_offset) {
     58			sge->length += use;
     59		} else {
     60			if (sk_msg_full(msg)) {
     61				ret = -ENOSPC;
     62				break;
     63			}
     64
     65			sge = &msg->sg.data[msg->sg.end];
     66			sg_unmark_end(sge);
     67			sg_set_page(sge, pfrag->page, use, orig_offset);
     68			get_page(pfrag->page);
     69			sk_msg_iter_next(msg, end);
     70		}
     71
     72		sk_mem_charge(sk, use);
     73		msg->sg.size += use;
     74		pfrag->offset += use;
     75		len -= use;
     76	}
     77
     78	return ret;
     79
     80msg_trim:
     81	sk_msg_trim(sk, msg, osize);
     82	return ret;
     83}
     84EXPORT_SYMBOL_GPL(sk_msg_alloc);
     85
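        /* Copy the sg references covering @len bytes at offset @off of @src
         * into @dst, merging with dst's last element when the memory is
         * contiguous. The bytes are charged to @sk; -ENOSPC is returned if
         * @src runs out of data or @dst runs out of elements.
         */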
     86int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
     87		 u32 off, u32 len)
     88{
     89	int i = src->sg.start;
     90	struct scatterlist *sge = sk_msg_elem(src, i);
     91	struct scatterlist *sgd = NULL;
     92	u32 sge_len, sge_off;
     93
     94	while (off) {
     95		if (sge->length > off)
     96			break;
     97		off -= sge->length;
     98		sk_msg_iter_var_next(i);
     99		if (i == src->sg.end && off)
    100			return -ENOSPC;
    101		sge = sk_msg_elem(src, i);
    102	}
    103
    104	while (len) {
    105		sge_len = sge->length - off;
    106		if (sge_len > len)
    107			sge_len = len;
    108
    109		if (dst->sg.end)
    110			sgd = sk_msg_elem(dst, dst->sg.end - 1);
    111
    112		if (sgd &&
    113		    (sg_page(sge) == sg_page(sgd)) &&
    114		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
    115			sgd->length += sge_len;
    116			dst->sg.size += sge_len;
    117		} else if (!sk_msg_full(dst)) {
    118			sge_off = sge->offset + off;
    119			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
    120		} else {
    121			return -ENOSPC;
    122		}
    123
    124		off = 0;
    125		len -= sge_len;
    126		sk_mem_charge(sk, sge_len);
    127		sk_msg_iter_var_next(i);
    128		if (i == src->sg.end && len)
    129			return -ENOSPC;
    130		sge = sk_msg_elem(src, i);
    131	}
    132
    133	return 0;
    134}
    135EXPORT_SYMBOL_GPL(sk_msg_clone);
    136
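        /* Uncharge @bytes from @sk and consume them from the front of @msg,
         * zeroing fully consumed elements and advancing sg.start past them.
         */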
    137void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
    138{
    139	int i = msg->sg.start;
    140
    141	do {
    142		struct scatterlist *sge = sk_msg_elem(msg, i);
    143
    144		if (bytes < sge->length) {
    145			sge->length -= bytes;
    146			sge->offset += bytes;
    147			sk_mem_uncharge(sk, bytes);
    148			break;
    149		}
    150
    151		sk_mem_uncharge(sk, sge->length);
    152		bytes -= sge->length;
    153		sge->length = 0;
    154		sge->offset = 0;
    155		sk_msg_iter_var_next(i);
    156	} while (bytes && i != msg->sg.end);
    157	msg->sg.start = i;
    158}
    159EXPORT_SYMBOL_GPL(sk_msg_return_zero);
    160
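        /* Uncharge up to @bytes from @sk, walking every element from sg.start
         * to sg.end but leaving the element state itself untouched.
         */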
    161void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
    162{
    163	int i = msg->sg.start;
    164
    165	do {
    166		struct scatterlist *sge = &msg->sg.data[i];
    167		int uncharge = (bytes < sge->length) ? bytes : sge->length;
    168
    169		sk_mem_uncharge(sk, uncharge);
    170		bytes -= uncharge;
    171		sk_msg_iter_var_next(i);
    172	} while (i != msg->sg.end);
    173}
    174EXPORT_SYMBOL_GPL(sk_msg_return);
    175
    176static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
    177			    bool charge)
    178{
    179	struct scatterlist *sge = sk_msg_elem(msg, i);
    180	u32 len = sge->length;
    181
     182	/* When the skb owns the memory we free it from the consume_skb path. */
    183	if (!msg->skb) {
    184		if (charge)
    185			sk_mem_uncharge(sk, len);
    186		put_page(sg_page(sge));
    187	}
    188	memset(sge, 0, sizeof(*sge));
    189	return len;
    190}
    191
    192static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
    193			 bool charge)
    194{
    195	struct scatterlist *sge = sk_msg_elem(msg, i);
    196	int freed = 0;
    197
    198	while (msg->sg.size) {
    199		msg->sg.size -= sge->length;
    200		freed += sk_msg_free_elem(sk, msg, i, charge);
    201		sk_msg_iter_var_next(i);
    202		sk_msg_check_to_free(msg, i, msg->sg.size);
    203		sge = sk_msg_elem(msg, i);
    204	}
    205	consume_skb(msg->skb);
    206	sk_msg_init(msg);
    207	return freed;
    208}
    209
    210int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
    211{
    212	return __sk_msg_free(sk, msg, msg->sg.start, false);
    213}
    214EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
    215
    216int sk_msg_free(struct sock *sk, struct sk_msg *msg)
    217{
    218	return __sk_msg_free(sk, msg, msg->sg.start, true);
    219}
    220EXPORT_SYMBOL_GPL(sk_msg_free);
    221
    222static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
    223				  u32 bytes, bool charge)
    224{
    225	struct scatterlist *sge;
    226	u32 i = msg->sg.start;
    227
    228	while (bytes) {
    229		sge = sk_msg_elem(msg, i);
    230		if (!sge->length)
    231			break;
    232		if (bytes < sge->length) {
    233			if (charge)
    234				sk_mem_uncharge(sk, bytes);
    235			sge->length -= bytes;
    236			sge->offset += bytes;
    237			msg->sg.size -= bytes;
    238			break;
    239		}
    240
    241		msg->sg.size -= sge->length;
    242		bytes -= sge->length;
    243		sk_msg_free_elem(sk, msg, i, charge);
    244		sk_msg_iter_var_next(i);
    245		sk_msg_check_to_free(msg, i, bytes);
    246	}
    247	msg->sg.start = i;
    248}
    249
    250void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
    251{
    252	__sk_msg_free_partial(sk, msg, bytes, true);
    253}
    254EXPORT_SYMBOL_GPL(sk_msg_free_partial);
    255
    256void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
    257				  u32 bytes)
    258{
    259	__sk_msg_free_partial(sk, msg, bytes, false);
    260}
    261
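        /* Trim @msg down to @len bytes: whole tail elements are freed, the
         * last remaining element is shortened, and curr/copybreak are moved
         * so that future copy operations stay inside the valid region.
         */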
    262void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
    263{
    264	int trim = msg->sg.size - len;
    265	u32 i = msg->sg.end;
    266
    267	if (trim <= 0) {
    268		WARN_ON(trim < 0);
    269		return;
    270	}
    271
    272	sk_msg_iter_var_prev(i);
    273	msg->sg.size = len;
    274	while (msg->sg.data[i].length &&
    275	       trim >= msg->sg.data[i].length) {
    276		trim -= msg->sg.data[i].length;
    277		sk_msg_free_elem(sk, msg, i, true);
    278		sk_msg_iter_var_prev(i);
    279		if (!trim)
    280			goto out;
    281	}
    282
    283	msg->sg.data[i].length -= trim;
    284	sk_mem_uncharge(sk, trim);
    285	/* Adjust copybreak if it falls into the trimmed part of last buf */
    286	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
    287		msg->sg.copybreak = msg->sg.data[i].length;
    288out:
    289	sk_msg_iter_var_next(i);
    290	msg->sg.end = i;
    291
     292	/* If we trim data a full sg elem before the curr pointer, update
     293	 * copybreak and curr so that any future copy operations
     294	 * start at the new copy location.
     295	 * However, trimmed data that has not yet been used in a copy op
     296	 * does not require an update.
     297	 */
    298	if (!msg->sg.size) {
    299		msg->sg.curr = msg->sg.start;
    300		msg->sg.copybreak = 0;
    301	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
    302		   sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
    303		sk_msg_iter_var_prev(i);
    304		msg->sg.curr = i;
    305		msg->sg.copybreak = msg->sg.data[i].length;
    306	}
    307}
    308EXPORT_SYMBOL_GPL(sk_msg_trim);
    309
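        /* Pin up to @bytes of user memory from @from directly into @msg's sg
         * ring without copying. The socket is charged for each chunk and the
         * iterator is reverted if an error occurs.
         */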
    310int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
    311			      struct sk_msg *msg, u32 bytes)
    312{
    313	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
    314	const int to_max_pages = MAX_MSG_FRAGS;
    315	struct page *pages[MAX_MSG_FRAGS];
    316	ssize_t orig, copied, use, offset;
    317
    318	orig = msg->sg.size;
    319	while (bytes > 0) {
    320		i = 0;
    321		maxpages = to_max_pages - num_elems;
    322		if (maxpages == 0) {
    323			ret = -EFAULT;
    324			goto out;
    325		}
    326
    327		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
    328					    &offset);
    329		if (copied <= 0) {
    330			ret = -EFAULT;
    331			goto out;
    332		}
    333
    334		iov_iter_advance(from, copied);
    335		bytes -= copied;
    336		msg->sg.size += copied;
    337
    338		while (copied) {
    339			use = min_t(int, copied, PAGE_SIZE - offset);
    340			sg_set_page(&msg->sg.data[msg->sg.end],
    341				    pages[i], use, offset);
    342			sg_unmark_end(&msg->sg.data[msg->sg.end]);
    343			sk_mem_charge(sk, use);
    344
    345			offset = 0;
    346			copied -= use;
    347			sk_msg_iter_next(msg, end);
    348			num_elems++;
    349			i++;
    350		}
     351		/* When zerocopy is mixed with sk_msg_*copy* operations we
     352		 * may have a copybreak set; in this case clear it and prefer
     353		 * the zerocopy remainder when possible.
     354		 */
    355		msg->sg.copybreak = 0;
    356		msg->sg.curr = msg->sg.end;
    357	}
    358out:
    359	/* Revert iov_iter updates, msg will need to use 'trim' later if it
    360	 * also needs to be cleared.
    361	 */
    362	if (ret)
    363		iov_iter_revert(from, msg->sg.size - orig);
    364	return ret;
    365}
    366EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
    367
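        /* Copy @bytes from @from into buffer space already reserved in @msg,
         * starting at the sg.curr element and its copybreak offset.
         */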
    368int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
    369			     struct sk_msg *msg, u32 bytes)
    370{
    371	int ret = -ENOSPC, i = msg->sg.curr;
    372	struct scatterlist *sge;
    373	u32 copy, buf_size;
    374	void *to;
    375
    376	do {
    377		sge = sk_msg_elem(msg, i);
    378		/* This is possible if a trim operation shrunk the buffer */
    379		if (msg->sg.copybreak >= sge->length) {
    380			msg->sg.copybreak = 0;
    381			sk_msg_iter_var_next(i);
    382			if (i == msg->sg.end)
    383				break;
    384			sge = sk_msg_elem(msg, i);
    385		}
    386
    387		buf_size = sge->length - msg->sg.copybreak;
    388		copy = (buf_size > bytes) ? bytes : buf_size;
    389		to = sg_virt(sge) + msg->sg.copybreak;
    390		msg->sg.copybreak += copy;
    391		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
    392			ret = copy_from_iter_nocache(to, copy, from);
    393		else
    394			ret = copy_from_iter(to, copy, from);
    395		if (ret != copy) {
    396			ret = -EFAULT;
    397			goto out;
    398		}
    399		bytes -= copy;
    400		if (!bytes)
    401			break;
    402		msg->sg.copybreak = 0;
    403		sk_msg_iter_var_next(i);
    404	} while (i != msg->sg.end);
    405out:
    406	msg->sg.curr = i;
    407	return ret;
    408}
    409EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
    410
    411/* Receive sk_msg from psock->ingress_msg to @msg. */
    412int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
    413		   int len, int flags)
    414{
    415	struct iov_iter *iter = &msg->msg_iter;
    416	int peek = flags & MSG_PEEK;
    417	struct sk_msg *msg_rx;
    418	int i, copied = 0;
    419
    420	msg_rx = sk_psock_peek_msg(psock);
    421	while (copied != len) {
    422		struct scatterlist *sge;
    423
    424		if (unlikely(!msg_rx))
    425			break;
    426
    427		i = msg_rx->sg.start;
    428		do {
    429			struct page *page;
    430			int copy;
    431
    432			sge = sk_msg_elem(msg_rx, i);
    433			copy = sge->length;
    434			page = sg_page(sge);
    435			if (copied + copy > len)
    436				copy = len - copied;
    437			copy = copy_page_to_iter(page, sge->offset, copy, iter);
    438			if (!copy)
    439				return copied ? copied : -EFAULT;
    440
    441			copied += copy;
    442			if (likely(!peek)) {
    443				sge->offset += copy;
    444				sge->length -= copy;
    445				if (!msg_rx->skb)
    446					sk_mem_uncharge(sk, copy);
    447				msg_rx->sg.size -= copy;
    448
    449				if (!sge->length) {
    450					sk_msg_iter_var_next(i);
    451					if (!msg_rx->skb)
    452						put_page(page);
    453				}
    454			} else {
     455				/* Let's not optimize the peek case: if copy_page_to_iter
     456				 * didn't copy the entire length, just break.
     457				 */
    458				if (copy != sge->length)
    459					return copied;
    460				sk_msg_iter_var_next(i);
    461			}
    462
    463			if (copied == len)
    464				break;
    465		} while (i != msg_rx->sg.end);
    466
    467		if (unlikely(peek)) {
    468			msg_rx = sk_psock_next_msg(psock, msg_rx);
    469			if (!msg_rx)
    470				break;
    471			continue;
    472		}
    473
    474		msg_rx->sg.start = i;
    475		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
    476			msg_rx = sk_psock_dequeue_msg(psock);
    477			kfree_sk_msg(msg_rx);
    478		}
    479		msg_rx = sk_psock_peek_msg(psock);
    480	}
    481
    482	return copied;
    483}
    484EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
    485
    486bool sk_msg_is_readable(struct sock *sk)
    487{
    488	struct sk_psock *psock;
    489	bool empty = true;
    490
    491	rcu_read_lock();
    492	psock = sk_psock(sk);
    493	if (likely(psock))
    494		empty = list_empty(&psock->ingress_msg);
    495	rcu_read_unlock();
    496	return !empty;
    497}
    498EXPORT_SYMBOL_GPL(sk_msg_is_readable);
    499
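        /* Allocate an empty sk_msg for the ingress path, but only if @sk's
         * receive buffer and memory accounting still have room for @skb.
         */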
    500static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
    501						  struct sk_buff *skb)
    502{
    503	struct sk_msg *msg;
    504
    505	if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf)
    506		return NULL;
    507
    508	if (!sk_rmem_schedule(sk, skb, skb->truesize))
    509		return NULL;
    510
    511	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_KERNEL);
    512	if (unlikely(!msg))
    513		return NULL;
    514
    515	sk_msg_init(msg);
    516	return msg;
    517}
    518
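        /* Map @len bytes of @skb starting at @off into @msg's sg ring, queue
         * the msg on the psock ingress list and wake up the receiver.
         */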
    519static int sk_psock_skb_ingress_enqueue(struct sk_buff *skb,
    520					u32 off, u32 len,
    521					struct sk_psock *psock,
    522					struct sock *sk,
    523					struct sk_msg *msg)
    524{
    525	int num_sge, copied;
    526
    527	num_sge = skb_to_sgvec(skb, msg->sg.data, off, len);
    528	if (num_sge < 0) {
     529			/* skb_linearize may fail with ENOMEM, but let's simply try again
     530			 * later if this happens. Under memory pressure we don't want to
     531			 * drop the skb. We need to linearize the skb so that the mapping
     532			 * in skb_to_sgvec cannot error.
     533			 */
    534		if (skb_linearize(skb))
    535			return -EAGAIN;
    536
    537		num_sge = skb_to_sgvec(skb, msg->sg.data, off, len);
    538		if (unlikely(num_sge < 0))
    539			return num_sge;
    540	}
    541
    542	copied = len;
    543	msg->sg.start = 0;
    544	msg->sg.size = copied;
    545	msg->sg.end = num_sge;
    546	msg->skb = skb;
    547
    548	sk_psock_queue_msg(psock, msg);
    549	sk_psock_data_ready(sk, psock);
    550	return copied;
    551}
    552
    553static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb,
    554				     u32 off, u32 len);
    555
    556static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb,
    557				u32 off, u32 len)
    558{
    559	struct sock *sk = psock->sk;
    560	struct sk_msg *msg;
    561	int err;
    562
     563	/* If we are receiving on the same sock, skb->sk is already assigned;
     564	 * skip memory accounting and the owner transition since it is already
     565	 * set correctly.
     566	 */
    567	if (unlikely(skb->sk == sk))
    568		return sk_psock_skb_ingress_self(psock, skb, off, len);
    569	msg = sk_psock_create_ingress_msg(sk, skb);
    570	if (!msg)
    571		return -EAGAIN;
    572
     573	/* This will transition ownership of the data from the socket where
     574	 * the BPF program was run initiating the redirect to the socket
     575	 * we will eventually receive this data on. The data will be released
     576	 * from consume_skb() found in __tcp_bpf_recvmsg() after it has been
     577	 * copied into user buffers.
     578	 */
    579	skb_set_owner_r(skb, sk);
    580	err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg);
    581	if (err < 0)
    582		kfree(msg);
    583	return err;
    584}
    585
     586/* Puts an skb on the ingress queue of the socket already assigned to the
     587 * skb. In this case we do not need to check memory limits or call
     588 * skb_set_owner_r() because the skb is already accounted for here.
     589 */
    590static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb,
    591				     u32 off, u32 len)
    592{
    593	struct sk_msg *msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
    594	struct sock *sk = psock->sk;
    595	int err;
    596
    597	if (unlikely(!msg))
    598		return -EAGAIN;
    599	sk_msg_init(msg);
    600	skb_set_owner_r(skb, sk);
    601	err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg);
    602	if (err < 0)
    603		kfree(msg);
    604	return err;
    605}
    606
    607static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
    608			       u32 off, u32 len, bool ingress)
    609{
    610	if (!ingress) {
    611		if (!sock_writeable(psock->sk))
    612			return -EAGAIN;
    613		return skb_send_sock(psock->sk, skb, off, len);
    614	}
    615	return sk_psock_skb_ingress(psock, skb, off, len);
    616}
    617
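        /* Stash a partially handled skb so the backlog worker can resume it
         * later, or drop it if TX has been disabled in the meantime.
         */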
    618static void sk_psock_skb_state(struct sk_psock *psock,
    619			       struct sk_psock_work_state *state,
    620			       struct sk_buff *skb,
    621			       int len, int off)
    622{
    623	spin_lock_bh(&psock->ingress_lock);
    624	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
    625		state->skb = skb;
    626		state->len = len;
    627		state->off = off;
    628	} else {
    629		sock_drop(psock->sk, skb);
    630	}
    631	spin_unlock_bh(&psock->ingress_lock);
    632}
    633
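        /* Work handler draining psock->ingress_skb: each skb is either sent
         * out via skb_send_sock() or fed to the local ingress path, resuming
         * any skb that was saved after a previous -EAGAIN.
         */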
    634static void sk_psock_backlog(struct work_struct *work)
    635{
    636	struct sk_psock *psock = container_of(work, struct sk_psock, work);
    637	struct sk_psock_work_state *state = &psock->work_state;
    638	struct sk_buff *skb = NULL;
    639	bool ingress;
    640	u32 len, off;
    641	int ret;
    642
    643	mutex_lock(&psock->work_mutex);
    644	if (unlikely(state->skb)) {
    645		spin_lock_bh(&psock->ingress_lock);
    646		skb = state->skb;
    647		len = state->len;
    648		off = state->off;
    649		state->skb = NULL;
    650		spin_unlock_bh(&psock->ingress_lock);
    651	}
    652	if (skb)
    653		goto start;
    654
    655	while ((skb = skb_dequeue(&psock->ingress_skb))) {
    656		len = skb->len;
    657		off = 0;
    658		if (skb_bpf_strparser(skb)) {
    659			struct strp_msg *stm = strp_msg(skb);
    660
    661			off = stm->offset;
    662			len = stm->full_len;
    663		}
    664start:
    665		ingress = skb_bpf_ingress(skb);
    666		skb_bpf_redirect_clear(skb);
    667		do {
    668			ret = -EIO;
    669			if (!sock_flag(psock->sk, SOCK_DEAD))
    670				ret = sk_psock_handle_skb(psock, skb, off,
    671							  len, ingress);
    672			if (ret <= 0) {
    673				if (ret == -EAGAIN) {
    674					sk_psock_skb_state(psock, state, skb,
    675							   len, off);
    676					goto end;
    677				}
    678				/* Hard errors break pipe and stop xmit. */
    679				sk_psock_report_error(psock, ret ? -ret : EPIPE);
    680				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
    681				sock_drop(psock->sk, skb);
    682				goto end;
    683			}
    684			off += ret;
    685			len -= ret;
    686		} while (len);
    687
    688		if (!ingress)
    689			kfree_skb(skb);
    690	}
    691end:
    692	mutex_unlock(&psock->work_mutex);
    693}
    694
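        /* Allocate a psock and attach it to @sk as user data, saving the
         * original proto callbacks so they can be restored when the psock is
         * dropped again.
         */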
    695struct sk_psock *sk_psock_init(struct sock *sk, int node)
    696{
    697	struct sk_psock *psock;
    698	struct proto *prot;
    699
    700	write_lock_bh(&sk->sk_callback_lock);
    701
    702	if (sk_is_inet(sk) && inet_csk_has_ulp(sk)) {
    703		psock = ERR_PTR(-EINVAL);
    704		goto out;
    705	}
    706
    707	if (sk->sk_user_data) {
    708		psock = ERR_PTR(-EBUSY);
    709		goto out;
    710	}
    711
    712	psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
    713	if (!psock) {
    714		psock = ERR_PTR(-ENOMEM);
    715		goto out;
    716	}
    717
    718	prot = READ_ONCE(sk->sk_prot);
    719	psock->sk = sk;
    720	psock->eval = __SK_NONE;
    721	psock->sk_proto = prot;
    722	psock->saved_unhash = prot->unhash;
    723	psock->saved_close = prot->close;
    724	psock->saved_write_space = sk->sk_write_space;
    725
    726	INIT_LIST_HEAD(&psock->link);
    727	spin_lock_init(&psock->link_lock);
    728
    729	INIT_WORK(&psock->work, sk_psock_backlog);
    730	mutex_init(&psock->work_mutex);
    731	INIT_LIST_HEAD(&psock->ingress_msg);
    732	spin_lock_init(&psock->ingress_lock);
    733	skb_queue_head_init(&psock->ingress_skb);
    734
    735	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
    736	refcount_set(&psock->refcnt, 1);
    737
    738	rcu_assign_sk_user_data_nocopy(sk, psock);
    739	sock_hold(sk);
    740
    741out:
    742	write_unlock_bh(&sk->sk_callback_lock);
    743	return psock;
    744}
    745EXPORT_SYMBOL_GPL(sk_psock_init);
    746
    747struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
    748{
    749	struct sk_psock_link *link;
    750
    751	spin_lock_bh(&psock->link_lock);
    752	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
    753					list);
    754	if (link)
    755		list_del(&link->list);
    756	spin_unlock_bh(&psock->link_lock);
    757	return link;
    758}
    759
    760static void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
    761{
    762	struct sk_msg *msg, *tmp;
    763
    764	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
    765		list_del(&msg->list);
    766		sk_msg_free(psock->sk, msg);
    767		kfree(msg);
    768	}
    769}
    770
    771static void __sk_psock_zap_ingress(struct sk_psock *psock)
    772{
    773	struct sk_buff *skb;
    774
    775	while ((skb = skb_dequeue(&psock->ingress_skb)) != NULL) {
    776		skb_bpf_redirect_clear(skb);
    777		sock_drop(psock->sk, skb);
    778	}
    779	kfree_skb(psock->work_state.skb);
     780	/* We null the skb here to ensure that calls to sk_psock_backlog
     781	 * do not pick up the freed skb.
     782	 */
    783	psock->work_state.skb = NULL;
    784	__sk_psock_purge_ingress_msg(psock);
    785}
    786
    787static void sk_psock_link_destroy(struct sk_psock *psock)
    788{
    789	struct sk_psock_link *link, *tmp;
    790
    791	list_for_each_entry_safe(link, tmp, &psock->link, list) {
    792		list_del(&link->list);
    793		sk_psock_free_link(link);
    794	}
    795}
    796
    797void sk_psock_stop(struct sk_psock *psock, bool wait)
    798{
    799	spin_lock_bh(&psock->ingress_lock);
    800	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
    801	sk_psock_cork_free(psock);
    802	__sk_psock_zap_ingress(psock);
    803	spin_unlock_bh(&psock->ingress_lock);
    804
    805	if (wait)
    806		cancel_work_sync(&psock->work);
    807}
    808
    809static void sk_psock_done_strp(struct sk_psock *psock);
    810
    811static void sk_psock_destroy(struct work_struct *work)
    812{
    813	struct sk_psock *psock = container_of(to_rcu_work(work),
    814					      struct sk_psock, rwork);
    815	/* No sk_callback_lock since already detached. */
    816
    817	sk_psock_done_strp(psock);
    818
    819	cancel_work_sync(&psock->work);
    820	mutex_destroy(&psock->work_mutex);
    821
    822	psock_progs_drop(&psock->progs);
    823
    824	sk_psock_link_destroy(psock);
    825	sk_psock_cork_free(psock);
    826
    827	if (psock->sk_redir)
    828		sock_put(psock->sk_redir);
    829	sock_put(psock->sk);
    830	kfree(psock);
    831}
    832
    833void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
    834{
    835	write_lock_bh(&sk->sk_callback_lock);
    836	sk_psock_restore_proto(sk, psock);
    837	rcu_assign_sk_user_data(sk, NULL);
    838	if (psock->progs.stream_parser)
    839		sk_psock_stop_strp(sk, psock);
    840	else if (psock->progs.stream_verdict || psock->progs.skb_verdict)
    841		sk_psock_stop_verdict(sk, psock);
    842	write_unlock_bh(&sk->sk_callback_lock);
    843
    844	sk_psock_stop(psock, false);
    845
    846	INIT_RCU_WORK(&psock->rwork, sk_psock_destroy);
    847	queue_rcu_work(system_wq, &psock->rwork);
    848}
    849EXPORT_SYMBOL_GPL(sk_psock_drop);
    850
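        /* Map a BPF verdict (SK_PASS/SK_DROP) plus the redirect flag onto the
         * internal __SK_* action codes.
         */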
    851static int sk_psock_map_verd(int verdict, bool redir)
    852{
    853	switch (verdict) {
    854	case SK_PASS:
    855		return redir ? __SK_REDIRECT : __SK_PASS;
    856	case SK_DROP:
    857	default:
    858		break;
    859	}
    860
    861	return __SK_DROP;
    862}
    863
    864int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
    865			 struct sk_msg *msg)
    866{
    867	struct bpf_prog *prog;
    868	int ret;
    869
    870	rcu_read_lock();
    871	prog = READ_ONCE(psock->progs.msg_parser);
    872	if (unlikely(!prog)) {
    873		ret = __SK_PASS;
    874		goto out;
    875	}
    876
    877	sk_msg_compute_data_pointers(msg);
    878	msg->sk = sk;
    879	ret = bpf_prog_run_pin_on_cpu(prog, msg);
    880	ret = sk_psock_map_verd(ret, msg->sk_redir);
    881	psock->apply_bytes = msg->apply_bytes;
    882	if (ret == __SK_REDIRECT) {
    883		if (psock->sk_redir)
    884			sock_put(psock->sk_redir);
    885		psock->sk_redir = msg->sk_redir;
    886		if (!psock->sk_redir) {
    887			ret = __SK_DROP;
    888			goto out;
    889		}
    890		sock_hold(psock->sk_redir);
    891	}
    892out:
    893	rcu_read_unlock();
    894	return ret;
    895}
    896EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
    897
    898static int sk_psock_skb_redirect(struct sk_psock *from, struct sk_buff *skb)
    899{
    900	struct sk_psock *psock_other;
    901	struct sock *sk_other;
    902
    903	sk_other = skb_bpf_redirect_fetch(skb);
     904	/* This error indicates a buggy BPF program: it returned a redirect
     905	 * return code, but then didn't set a redirect interface.
     906	 */
    907	if (unlikely(!sk_other)) {
    908		skb_bpf_redirect_clear(skb);
    909		sock_drop(from->sk, skb);
    910		return -EIO;
    911	}
    912	psock_other = sk_psock(sk_other);
    913	/* This error indicates the socket is being torn down or had another
    914	 * error that caused the pipe to break. We can't send a packet on
    915	 * a socket that is in this state so we drop the skb.
    916	 */
    917	if (!psock_other || sock_flag(sk_other, SOCK_DEAD)) {
    918		skb_bpf_redirect_clear(skb);
    919		sock_drop(from->sk, skb);
    920		return -EIO;
    921	}
    922	spin_lock_bh(&psock_other->ingress_lock);
    923	if (!sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
    924		spin_unlock_bh(&psock_other->ingress_lock);
    925		skb_bpf_redirect_clear(skb);
    926		sock_drop(from->sk, skb);
    927		return -EIO;
    928	}
    929
    930	skb_queue_tail(&psock_other->ingress_skb, skb);
    931	schedule_work(&psock_other->work);
    932	spin_unlock_bh(&psock_other->ingress_lock);
    933	return 0;
    934}
    935
    936static void sk_psock_tls_verdict_apply(struct sk_buff *skb,
    937				       struct sk_psock *from, int verdict)
    938{
    939	switch (verdict) {
    940	case __SK_REDIRECT:
    941		sk_psock_skb_redirect(from, skb);
    942		break;
    943	case __SK_PASS:
    944	case __SK_DROP:
    945	default:
    946		break;
    947	}
    948}
    949
    950int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)
    951{
    952	struct bpf_prog *prog;
    953	int ret = __SK_PASS;
    954
    955	rcu_read_lock();
    956	prog = READ_ONCE(psock->progs.stream_verdict);
    957	if (likely(prog)) {
    958		skb->sk = psock->sk;
    959		skb_dst_drop(skb);
    960		skb_bpf_redirect_clear(skb);
    961		ret = bpf_prog_run_pin_on_cpu(prog, skb);
    962		ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb));
    963		skb->sk = NULL;
    964	}
    965	sk_psock_tls_verdict_apply(skb, psock, ret);
    966	rcu_read_unlock();
    967	return ret;
    968}
    969EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read);
    970
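        /* Apply a verdict to @skb: queue it for local ingress, redirect it to
         * another psock, or drop it.
         */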
    971static int sk_psock_verdict_apply(struct sk_psock *psock, struct sk_buff *skb,
    972				  int verdict)
    973{
    974	struct sock *sk_other;
    975	int err = 0;
    976	u32 len, off;
    977
    978	switch (verdict) {
    979	case __SK_PASS:
    980		err = -EIO;
    981		sk_other = psock->sk;
    982		if (sock_flag(sk_other, SOCK_DEAD) ||
    983		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
    984			skb_bpf_redirect_clear(skb);
    985			goto out_free;
    986		}
    987
    988		skb_bpf_set_ingress(skb);
    989
     990		/* If the queue is empty then we can submit directly
     991		 * into the msg queue. If it's not empty we have to
     992		 * queue work, otherwise we may get OOO data.
     993		 * Errors from sk_psock_skb_ingress will be handled by
     994		 * retrying later from the workqueue.
     995		 */
    996		if (skb_queue_empty(&psock->ingress_skb)) {
    997			len = skb->len;
    998			off = 0;
    999			if (skb_bpf_strparser(skb)) {
   1000				struct strp_msg *stm = strp_msg(skb);
   1001
   1002				off = stm->offset;
   1003				len = stm->full_len;
   1004			}
   1005			err = sk_psock_skb_ingress_self(psock, skb, off, len);
   1006		}
   1007		if (err < 0) {
   1008			spin_lock_bh(&psock->ingress_lock);
   1009			if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
   1010				skb_queue_tail(&psock->ingress_skb, skb);
   1011				schedule_work(&psock->work);
   1012				err = 0;
   1013			}
   1014			spin_unlock_bh(&psock->ingress_lock);
   1015			if (err < 0) {
   1016				skb_bpf_redirect_clear(skb);
   1017				goto out_free;
   1018			}
   1019		}
   1020		break;
   1021	case __SK_REDIRECT:
   1022		err = sk_psock_skb_redirect(psock, skb);
   1023		break;
   1024	case __SK_DROP:
   1025	default:
   1026out_free:
   1027		sock_drop(psock->sk, skb);
   1028	}
   1029
   1030	return err;
   1031}
   1032
   1033static void sk_psock_write_space(struct sock *sk)
   1034{
   1035	struct sk_psock *psock;
   1036	void (*write_space)(struct sock *sk) = NULL;
   1037
   1038	rcu_read_lock();
   1039	psock = sk_psock(sk);
   1040	if (likely(psock)) {
   1041		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
   1042			schedule_work(&psock->work);
   1043		write_space = psock->saved_write_space;
   1044	}
   1045	rcu_read_unlock();
   1046	if (write_space)
   1047		write_space(sk);
   1048}
   1049
   1050#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
   1051static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
   1052{
   1053	struct sk_psock *psock;
   1054	struct bpf_prog *prog;
   1055	int ret = __SK_DROP;
   1056	struct sock *sk;
   1057
   1058	rcu_read_lock();
   1059	sk = strp->sk;
   1060	psock = sk_psock(sk);
   1061	if (unlikely(!psock)) {
   1062		sock_drop(sk, skb);
   1063		goto out;
   1064	}
   1065	prog = READ_ONCE(psock->progs.stream_verdict);
   1066	if (likely(prog)) {
   1067		skb->sk = sk;
   1068		skb_dst_drop(skb);
   1069		skb_bpf_redirect_clear(skb);
   1070		ret = bpf_prog_run_pin_on_cpu(prog, skb);
   1071		if (ret == SK_PASS)
   1072			skb_bpf_set_strparser(skb);
   1073		ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb));
   1074		skb->sk = NULL;
   1075	}
   1076	sk_psock_verdict_apply(psock, skb, ret);
   1077out:
   1078	rcu_read_unlock();
   1079}
   1080
   1081static int sk_psock_strp_read_done(struct strparser *strp, int err)
   1082{
   1083	return err;
   1084}
   1085
   1086static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
   1087{
   1088	struct sk_psock *psock = container_of(strp, struct sk_psock, strp);
   1089	struct bpf_prog *prog;
   1090	int ret = skb->len;
   1091
   1092	rcu_read_lock();
   1093	prog = READ_ONCE(psock->progs.stream_parser);
   1094	if (likely(prog)) {
   1095		skb->sk = psock->sk;
   1096		ret = bpf_prog_run_pin_on_cpu(prog, skb);
   1097		skb->sk = NULL;
   1098	}
   1099	rcu_read_unlock();
   1100	return ret;
   1101}
   1102
   1103/* Called with socket lock held. */
   1104static void sk_psock_strp_data_ready(struct sock *sk)
   1105{
   1106	struct sk_psock *psock;
   1107
   1108	rcu_read_lock();
   1109	psock = sk_psock(sk);
   1110	if (likely(psock)) {
   1111		if (tls_sw_has_ctx_rx(sk)) {
   1112			psock->saved_data_ready(sk);
   1113		} else {
   1114			write_lock_bh(&sk->sk_callback_lock);
   1115			strp_data_ready(&psock->strp);
   1116			write_unlock_bh(&sk->sk_callback_lock);
   1117		}
   1118	}
   1119	rcu_read_unlock();
   1120}
   1121
   1122int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
   1123{
   1124	static const struct strp_callbacks cb = {
   1125		.rcv_msg	= sk_psock_strp_read,
   1126		.read_sock_done	= sk_psock_strp_read_done,
   1127		.parse_msg	= sk_psock_strp_parse,
   1128	};
   1129
   1130	return strp_init(&psock->strp, sk, &cb);
   1131}
   1132
   1133void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
   1134{
   1135	if (psock->saved_data_ready)
   1136		return;
   1137
   1138	psock->saved_data_ready = sk->sk_data_ready;
   1139	sk->sk_data_ready = sk_psock_strp_data_ready;
   1140	sk->sk_write_space = sk_psock_write_space;
   1141}
   1142
   1143void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
   1144{
   1145	psock_set_prog(&psock->progs.stream_parser, NULL);
   1146
   1147	if (!psock->saved_data_ready)
   1148		return;
   1149
   1150	sk->sk_data_ready = psock->saved_data_ready;
   1151	psock->saved_data_ready = NULL;
   1152	strp_stop(&psock->strp);
   1153}
   1154
   1155static void sk_psock_done_strp(struct sk_psock *psock)
   1156{
   1157	/* Parser has been stopped */
   1158	if (psock->progs.stream_parser)
   1159		strp_done(&psock->strp);
   1160}
   1161#else
   1162static void sk_psock_done_strp(struct sk_psock *psock)
   1163{
   1164}
   1165#endif /* CONFIG_BPF_STREAM_PARSER */
   1166
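        /* read_sock() callback used in verdict-only mode (no stream parser):
         * clone the skb, run the verdict program and apply the result.
         */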
   1167static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb,
   1168				 unsigned int offset, size_t orig_len)
   1169{
   1170	struct sock *sk = (struct sock *)desc->arg.data;
   1171	struct sk_psock *psock;
   1172	struct bpf_prog *prog;
   1173	int ret = __SK_DROP;
   1174	int len = orig_len;
   1175
   1176	/* clone here so sk_eat_skb() in tcp_read_sock does not drop our data */
   1177	skb = skb_clone(skb, GFP_ATOMIC);
   1178	if (!skb) {
   1179		desc->error = -ENOMEM;
   1180		return 0;
   1181	}
   1182
   1183	rcu_read_lock();
   1184	psock = sk_psock(sk);
   1185	if (unlikely(!psock)) {
   1186		len = 0;
   1187		sock_drop(sk, skb);
   1188		goto out;
   1189	}
   1190	prog = READ_ONCE(psock->progs.stream_verdict);
   1191	if (!prog)
   1192		prog = READ_ONCE(psock->progs.skb_verdict);
   1193	if (likely(prog)) {
   1194		skb->sk = sk;
   1195		skb_dst_drop(skb);
   1196		skb_bpf_redirect_clear(skb);
   1197		ret = bpf_prog_run_pin_on_cpu(prog, skb);
   1198		ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb));
   1199		skb->sk = NULL;
   1200	}
   1201	if (sk_psock_verdict_apply(psock, skb, ret) < 0)
   1202		len = 0;
   1203out:
   1204	rcu_read_unlock();
   1205	return len;
   1206}
   1207
   1208static void sk_psock_verdict_data_ready(struct sock *sk)
   1209{
   1210	struct socket *sock = sk->sk_socket;
   1211	read_descriptor_t desc;
   1212
   1213	if (unlikely(!sock || !sock->ops || !sock->ops->read_sock))
   1214		return;
   1215
   1216	desc.arg.data = sk;
   1217	desc.error = 0;
   1218	desc.count = 1;
   1219
   1220	sock->ops->read_sock(sk, &desc, sk_psock_verdict_recv);
   1221}
   1222
   1223void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock)
   1224{
   1225	if (psock->saved_data_ready)
   1226		return;
   1227
   1228	psock->saved_data_ready = sk->sk_data_ready;
   1229	sk->sk_data_ready = sk_psock_verdict_data_ready;
   1230	sk->sk_write_space = sk_psock_write_space;
   1231}
   1232
   1233void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock)
   1234{
   1235	psock_set_prog(&psock->progs.stream_verdict, NULL);
   1236	psock_set_prog(&psock->progs.skb_verdict, NULL);
   1237
   1238	if (!psock->saved_data_ready)
   1239		return;
   1240
   1241	sk->sk_data_ready = psock->saved_data_ready;
   1242	psock->saved_data_ready = NULL;
   1243}