cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

espintcp.c (12804B)


      1// SPDX-License-Identifier: GPL-2.0
      2#include <net/tcp.h>
      3#include <net/strparser.h>
      4#include <net/xfrm.h>
      5#include <net/esp.h>
      6#include <net/espintcp.h>
      7#include <linux/skmsg.h>
      8#include <net/inet_common.h>
      9#if IS_ENABLED(CONFIG_IPV6)
     10#include <net/ipv6_stubs.h>
     11#endif
     12
     13static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
     14			  struct sock *sk)
     15{
     16	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
     17	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
     18		XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
     19		kfree_skb(skb);
     20		return;
     21	}
     22
     23	skb_set_owner_r(skb, sk);
     24
     25	memset(skb->cb, 0, sizeof(skb->cb));
     26	skb_queue_tail(&ctx->ike_queue, skb);
     27	ctx->saved_data_ready(sk);
     28}
     29
     30static void handle_esp(struct sk_buff *skb, struct sock *sk)
     31{
     32	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;
     33
     34	skb_reset_transport_header(skb);
     35
     36	/* restore IP CB, we need at least IP6CB->nhoff */
     37	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));
     38
     39	rcu_read_lock();
     40	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
     41	local_bh_disable();
     42#if IS_ENABLED(CONFIG_IPV6)
     43	if (sk->sk_family == AF_INET6)
     44		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
     45	else
     46#endif
     47		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
     48	local_bh_enable();
     49	rcu_read_unlock();
     50}
     51
     52static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
     53{
     54	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
     55						strp);
     56	struct strp_msg *rxm = strp_msg(skb);
     57	int len = rxm->full_len - 2;
     58	u32 nonesp_marker;
     59	int err;
     60
     61	/* keepalive packet? */
     62	if (unlikely(len == 1)) {
     63		u8 data;
     64
     65		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
     66		if (err < 0) {
     67			XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
     68			kfree_skb(skb);
     69			return;
     70		}
     71
     72		if (data == 0xff) {
     73			kfree_skb(skb);
     74			return;
     75		}
     76	}
     77
     78	/* drop other short messages */
     79	if (unlikely(len <= sizeof(nonesp_marker))) {
     80		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
     81		kfree_skb(skb);
     82		return;
     83	}
     84
     85	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
     86			    sizeof(nonesp_marker));
     87	if (err < 0) {
     88		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
     89		kfree_skb(skb);
     90		return;
     91	}
     92
     93	/* remove header, leave non-ESP marker/SPI */
     94	if (!__pskb_pull(skb, rxm->offset + 2)) {
     95		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
     96		kfree_skb(skb);
     97		return;
     98	}
     99
    100	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
    101		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
    102		kfree_skb(skb);
    103		return;
    104	}
    105
    106	if (nonesp_marker == 0)
    107		handle_nonesp(ctx, skb, strp->sk);
    108	else
    109		handle_esp(skb, strp->sk);
    110}
    111
    112static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
    113{
    114	struct strp_msg *rxm = strp_msg(skb);
    115	__be16 blen;
    116	u16 len;
    117	int err;
    118
    119	if (skb->len < rxm->offset + 2)
    120		return 0;
    121
    122	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
    123	if (err < 0)
    124		return err;
    125
    126	len = be16_to_cpu(blen);
    127	if (len < 2)
    128		return -EINVAL;
    129
    130	return len;
    131}
    132
    133static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
    134			    int flags, int *addr_len)
    135{
    136	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    137	struct sk_buff *skb;
    138	int err = 0;
    139	int copied;
    140	int off = 0;
    141
    142	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
    143	if (!skb) {
    144		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
    145			return 0;
    146		return err;
    147	}
    148
    149	copied = len;
    150	if (copied > skb->len)
    151		copied = skb->len;
    152	else if (copied < skb->len)
    153		msg->msg_flags |= MSG_TRUNC;
    154
    155	err = skb_copy_datagram_msg(skb, 0, msg, copied);
    156	if (unlikely(err)) {
    157		kfree_skb(skb);
    158		return err;
    159	}
    160
    161	if (flags & MSG_TRUNC)
    162		copied = skb->len;
    163	kfree_skb(skb);
    164	return copied;
    165}
    166
    167int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
    168{
    169	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    170
    171	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
    172		return -ENOBUFS;
    173
    174	__skb_queue_tail(&ctx->out_queue, skb);
    175
    176	return 0;
    177}
    178EXPORT_SYMBOL_GPL(espintcp_queue_out);
    179
    180/* espintcp length field is 2B and length includes the length field's size */
    181#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
    182
    183static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
    184				   int flags)
    185{
    186	do {
    187		int ret;
    188
    189		ret = skb_send_sock_locked(sk, emsg->skb,
    190					   emsg->offset, emsg->len);
    191		if (ret < 0)
    192			return ret;
    193
    194		emsg->len -= ret;
    195		emsg->offset += ret;
    196	} while (emsg->len > 0);
    197
    198	kfree_skb(emsg->skb);
    199	memset(emsg, 0, sizeof(*emsg));
    200
    201	return 0;
    202}
    203
    204static int espintcp_sendskmsg_locked(struct sock *sk,
    205				     struct espintcp_msg *emsg, int flags)
    206{
    207	struct sk_msg *skmsg = &emsg->skmsg;
    208	struct scatterlist *sg;
    209	int done = 0;
    210	int ret;
    211
    212	flags |= MSG_SENDPAGE_NOTLAST;
    213	sg = &skmsg->sg.data[skmsg->sg.start];
    214	do {
    215		size_t size = sg->length - emsg->offset;
    216		int offset = sg->offset + emsg->offset;
    217		struct page *p;
    218
    219		emsg->offset = 0;
    220
    221		if (sg_is_last(sg))
    222			flags &= ~MSG_SENDPAGE_NOTLAST;
    223
    224		p = sg_page(sg);
    225retry:
    226		ret = do_tcp_sendpages(sk, p, offset, size, flags);
    227		if (ret < 0) {
    228			emsg->offset = offset - sg->offset;
    229			skmsg->sg.start += done;
    230			return ret;
    231		}
    232
    233		if (ret != size) {
    234			offset += ret;
    235			size -= ret;
    236			goto retry;
    237		}
    238
    239		done++;
    240		put_page(p);
    241		sk_mem_uncharge(sk, sg->length);
    242		sg = sg_next(sg);
    243	} while (sg);
    244
    245	memset(emsg, 0, sizeof(*emsg));
    246
    247	return 0;
    248}
    249
    250static int espintcp_push_msgs(struct sock *sk, int flags)
    251{
    252	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    253	struct espintcp_msg *emsg = &ctx->partial;
    254	int err;
    255
    256	if (!emsg->len)
    257		return 0;
    258
    259	if (ctx->tx_running)
    260		return -EAGAIN;
    261	ctx->tx_running = 1;
    262
    263	if (emsg->skb)
    264		err = espintcp_sendskb_locked(sk, emsg, flags);
    265	else
    266		err = espintcp_sendskmsg_locked(sk, emsg, flags);
    267	if (err == -EAGAIN) {
    268		ctx->tx_running = 0;
    269		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
    270	}
    271	if (!err)
    272		memset(emsg, 0, sizeof(*emsg));
    273
    274	ctx->tx_running = 0;
    275
    276	return err;
    277}
    278
    279int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
    280{
    281	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    282	struct espintcp_msg *emsg = &ctx->partial;
    283	unsigned int len;
    284	int offset;
    285
    286	if (sk->sk_state != TCP_ESTABLISHED) {
    287		kfree_skb(skb);
    288		return -ECONNRESET;
    289	}
    290
    291	offset = skb_transport_offset(skb);
    292	len = skb->len - offset;
    293
    294	espintcp_push_msgs(sk, 0);
    295
    296	if (emsg->len) {
    297		kfree_skb(skb);
    298		return -ENOBUFS;
    299	}
    300
    301	skb_set_owner_w(skb, sk);
    302
    303	emsg->offset = offset;
    304	emsg->len = len;
    305	emsg->skb = skb;
    306
    307	espintcp_push_msgs(sk, 0);
    308
    309	return 0;
    310}
    311EXPORT_SYMBOL_GPL(espintcp_push_skb);
    312
    313static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
    314{
    315	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
    316	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    317	struct espintcp_msg *emsg = &ctx->partial;
    318	struct iov_iter pfx_iter;
    319	struct kvec pfx_iov = {};
    320	size_t msglen = size + 2;
    321	char buf[2] = {0};
    322	int err, end;
    323
    324	if (msg->msg_flags & ~MSG_DONTWAIT)
    325		return -EOPNOTSUPP;
    326
    327	if (size > MAX_ESPINTCP_MSG)
    328		return -EMSGSIZE;
    329
    330	if (msg->msg_controllen)
    331		return -EOPNOTSUPP;
    332
    333	lock_sock(sk);
    334
    335	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
    336	if (err < 0) {
    337		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
    338			err = -ENOBUFS;
    339		goto unlock;
    340	}
    341
    342	sk_msg_init(&emsg->skmsg);
    343	while (1) {
    344		/* only -ENOMEM is possible since we don't coalesce */
    345		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
    346		if (!err)
    347			break;
    348
    349		err = sk_stream_wait_memory(sk, &timeo);
    350		if (err)
    351			goto fail;
    352	}
    353
    354	*((__be16 *)buf) = cpu_to_be16(msglen);
    355	pfx_iov.iov_base = buf;
    356	pfx_iov.iov_len = sizeof(buf);
    357	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
    358
    359	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
    360				       pfx_iov.iov_len);
    361	if (err < 0)
    362		goto fail;
    363
    364	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
    365	if (err < 0)
    366		goto fail;
    367
    368	end = emsg->skmsg.sg.end;
    369	emsg->len = size;
    370	sk_msg_iter_var_prev(end);
    371	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
    372
    373	tcp_rate_check_app_limited(sk);
    374
    375	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
    376	/* this message could be partially sent, keep it */
    377
    378	release_sock(sk);
    379
    380	return size;
    381
    382fail:
    383	sk_msg_free(sk, &emsg->skmsg);
    384	memset(emsg, 0, sizeof(*emsg));
    385unlock:
    386	release_sock(sk);
    387	return err;
    388}
    389
    390static struct proto espintcp_prot __ro_after_init;
    391static struct proto_ops espintcp_ops __ro_after_init;
    392static struct proto espintcp6_prot;
    393static struct proto_ops espintcp6_ops;
    394static DEFINE_MUTEX(tcpv6_prot_mutex);
    395
    396static void espintcp_data_ready(struct sock *sk)
    397{
    398	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    399
    400	strp_data_ready(&ctx->strp);
    401}
    402
    403static void espintcp_tx_work(struct work_struct *work)
    404{
    405	struct espintcp_ctx *ctx = container_of(work,
    406						struct espintcp_ctx, work);
    407	struct sock *sk = ctx->strp.sk;
    408
    409	lock_sock(sk);
    410	if (!ctx->tx_running)
    411		espintcp_push_msgs(sk, 0);
    412	release_sock(sk);
    413}
    414
    415static void espintcp_write_space(struct sock *sk)
    416{
    417	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    418
    419	schedule_work(&ctx->work);
    420	ctx->saved_write_space(sk);
    421}
    422
    423static void espintcp_destruct(struct sock *sk)
    424{
    425	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    426
    427	ctx->saved_destruct(sk);
    428	kfree(ctx);
    429}
    430
    431bool tcp_is_ulp_esp(struct sock *sk)
    432{
    433	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
    434}
    435EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
    436
    437static void build_protos(struct proto *espintcp_prot,
    438			 struct proto_ops *espintcp_ops,
    439			 const struct proto *orig_prot,
    440			 const struct proto_ops *orig_ops);
    441static int espintcp_init_sk(struct sock *sk)
    442{
    443	struct inet_connection_sock *icsk = inet_csk(sk);
    444	struct strp_callbacks cb = {
    445		.rcv_msg = espintcp_rcv,
    446		.parse_msg = espintcp_parse,
    447	};
    448	struct espintcp_ctx *ctx;
    449	int err;
    450
    451	/* sockmap is not compatible with espintcp */
    452	if (sk->sk_user_data)
    453		return -EBUSY;
    454
    455	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
    456	if (!ctx)
    457		return -ENOMEM;
    458
    459	err = strp_init(&ctx->strp, sk, &cb);
    460	if (err)
    461		goto free;
    462
    463	__sk_dst_reset(sk);
    464
    465	strp_check_rcv(&ctx->strp);
    466	skb_queue_head_init(&ctx->ike_queue);
    467	skb_queue_head_init(&ctx->out_queue);
    468
    469	if (sk->sk_family == AF_INET) {
    470		sk->sk_prot = &espintcp_prot;
    471		sk->sk_socket->ops = &espintcp_ops;
    472	} else {
    473		mutex_lock(&tcpv6_prot_mutex);
    474		if (!espintcp6_prot.recvmsg)
    475			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
    476		mutex_unlock(&tcpv6_prot_mutex);
    477
    478		sk->sk_prot = &espintcp6_prot;
    479		sk->sk_socket->ops = &espintcp6_ops;
    480	}
    481	ctx->saved_data_ready = sk->sk_data_ready;
    482	ctx->saved_write_space = sk->sk_write_space;
    483	ctx->saved_destruct = sk->sk_destruct;
    484	sk->sk_data_ready = espintcp_data_ready;
    485	sk->sk_write_space = espintcp_write_space;
    486	sk->sk_destruct = espintcp_destruct;
    487	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
    488	INIT_WORK(&ctx->work, espintcp_tx_work);
    489
    490	/* avoid using task_frag */
    491	sk->sk_allocation = GFP_ATOMIC;
    492
    493	return 0;
    494
    495free:
    496	kfree(ctx);
    497	return err;
    498}
    499
    500static void espintcp_release(struct sock *sk)
    501{
    502	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    503	struct sk_buff_head queue;
    504	struct sk_buff *skb;
    505
    506	__skb_queue_head_init(&queue);
    507	skb_queue_splice_init(&ctx->out_queue, &queue);
    508
    509	while ((skb = __skb_dequeue(&queue)))
    510		espintcp_push_skb(sk, skb);
    511
    512	tcp_release_cb(sk);
    513}
    514
    515static void espintcp_close(struct sock *sk, long timeout)
    516{
    517	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    518	struct espintcp_msg *emsg = &ctx->partial;
    519
    520	strp_stop(&ctx->strp);
    521
    522	sk->sk_prot = &tcp_prot;
    523	barrier();
    524
    525	cancel_work_sync(&ctx->work);
    526	strp_done(&ctx->strp);
    527
    528	skb_queue_purge(&ctx->out_queue);
    529	skb_queue_purge(&ctx->ike_queue);
    530
    531	if (emsg->len) {
    532		if (emsg->skb)
    533			kfree_skb(emsg->skb);
    534		else
    535			sk_msg_free(sk, &emsg->skmsg);
    536	}
    537
    538	tcp_close(sk, timeout);
    539}
    540
    541static __poll_t espintcp_poll(struct file *file, struct socket *sock,
    542			      poll_table *wait)
    543{
    544	__poll_t mask = datagram_poll(file, sock, wait);
    545	struct sock *sk = sock->sk;
    546	struct espintcp_ctx *ctx = espintcp_getctx(sk);
    547
    548	if (!skb_queue_empty(&ctx->ike_queue))
    549		mask |= EPOLLIN | EPOLLRDNORM;
    550
    551	return mask;
    552}
    553
    554static void build_protos(struct proto *espintcp_prot,
    555			 struct proto_ops *espintcp_ops,
    556			 const struct proto *orig_prot,
    557			 const struct proto_ops *orig_ops)
    558{
    559	memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
    560	memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
    561	espintcp_prot->sendmsg = espintcp_sendmsg;
    562	espintcp_prot->recvmsg = espintcp_recvmsg;
    563	espintcp_prot->close = espintcp_close;
    564	espintcp_prot->release_cb = espintcp_release;
    565	espintcp_ops->poll = espintcp_poll;
    566}
    567
    568static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
    569	.name = "espintcp",
    570	.owner = THIS_MODULE,
    571	.init = espintcp_init_sk,
    572};
    573
    574void __init espintcp_init(void)
    575{
    576	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
    577
    578	tcp_register_ulp(&espintcp_ulp);
    579}