cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

tls_sw.c (64815B)


      1/*
      2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
      3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
      4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
      5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
      6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
      7 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
      8 *
      9 * This software is available to you under a choice of one of two
     10 * licenses.  You may choose to be licensed under the terms of the GNU
     11 * General Public License (GPL) Version 2, available from the file
     12 * COPYING in the main directory of this source tree, or the
     13 * OpenIB.org BSD license below:
     14 *
     15 *     Redistribution and use in source and binary forms, with or
     16 *     without modification, are permitted provided that the following
     17 *     conditions are met:
     18 *
     19 *      - Redistributions of source code must retain the above
     20 *        copyright notice, this list of conditions and the following
     21 *        disclaimer.
     22 *
     23 *      - Redistributions in binary form must reproduce the above
     24 *        copyright notice, this list of conditions and the following
     25 *        disclaimer in the documentation and/or other materials
     26 *        provided with the distribution.
     27 *
     28 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     29 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     30 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     31 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     32 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     33 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     34 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     35 * SOFTWARE.
     36 */
     37
     38#include <linux/bug.h>
     39#include <linux/sched/signal.h>
     40#include <linux/module.h>
     41#include <linux/splice.h>
     42#include <crypto/aead.h>
     43
     44#include <net/strparser.h>
     45#include <net/tls.h>
     46
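/* Per-record decryption options: @zc requests zero-copy decryption into the
 * caller-supplied buffer, @async allows the AEAD decrypt to complete
 * asynchronously. The decrypt path clears either flag when that mode cannot
 * be used for the record at hand.
 */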
     47struct tls_decrypt_arg {
     48	bool zc;
     49	bool async;
     50};
     51
     52noinline void tls_err_abort(struct sock *sk, int err)
     53{
     54	WARN_ON_ONCE(err >= 0);
     55	/* sk->sk_err should contain a positive error code. */
     56	sk->sk_err = -err;
     57	sk_error_report(sk);
     58}
     59
     60static int __skb_nsg(struct sk_buff *skb, int offset, int len,
     61                     unsigned int recursion_level)
     62{
     63        int start = skb_headlen(skb);
     64        int i, chunk = start - offset;
     65        struct sk_buff *frag_iter;
     66        int elt = 0;
     67
     68        if (unlikely(recursion_level >= 24))
     69                return -EMSGSIZE;
     70
     71        if (chunk > 0) {
     72                if (chunk > len)
     73                        chunk = len;
     74                elt++;
     75                len -= chunk;
     76                if (len == 0)
     77                        return elt;
     78                offset += chunk;
     79        }
     80
     81        for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
     82                int end;
     83
     84                WARN_ON(start > offset + len);
     85
     86                end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
     87                chunk = end - offset;
     88                if (chunk > 0) {
     89                        if (chunk > len)
     90                                chunk = len;
     91                        elt++;
     92                        len -= chunk;
     93                        if (len == 0)
     94                                return elt;
     95                        offset += chunk;
     96                }
     97                start = end;
     98        }
     99
    100        if (unlikely(skb_has_frag_list(skb))) {
    101                skb_walk_frags(skb, frag_iter) {
    102                        int end, ret;
    103
    104                        WARN_ON(start > offset + len);
    105
    106                        end = start + frag_iter->len;
    107                        chunk = end - offset;
    108                        if (chunk > 0) {
    109                                if (chunk > len)
    110                                        chunk = len;
    111                                ret = __skb_nsg(frag_iter, offset - start, chunk,
    112                                                recursion_level + 1);
    113                                if (unlikely(ret < 0))
    114                                        return ret;
    115                                elt += ret;
    116                                len -= chunk;
    117                                if (len == 0)
    118                                        return elt;
    119                                offset += chunk;
    120                        }
    121                        start = end;
    122                }
    123        }
    124        BUG_ON(len);
    125        return elt;
    126}
    127
    128/* Return the number of scatterlist elements required to completely map the
    129 * skb, or -EMSGSIZE if the recursion depth is exceeded.
    130 */
    131static int skb_nsg(struct sk_buff *skb, int offset, int len)
    132{
    133        return __skb_nsg(skb, offset, len, 0);
    134}
    135
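/* TLS 1.3 carries the real content type as the last non-zero byte of the
 * inner plaintext (content || content_type || zero padding). Walk backwards
 * from just before the authentication tag to find it, record it in the
 * tls_msg, and return the number of zero padding bytes to strip, or -EBADMSG
 * if no content type byte is found before the record header.
 */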
    136static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
    137{
    138	struct strp_msg *rxm = strp_msg(skb);
    139	struct tls_msg *tlm = tls_msg(skb);
    140	int sub = 0;
    141
    142	/* Determine zero-padding length */
    143	if (prot->version == TLS_1_3_VERSION) {
    144		int offset = rxm->full_len - TLS_TAG_SIZE - 1;
    145		char content_type = 0;
    146		int err;
    147
    148		while (content_type == 0) {
    149			if (offset < prot->prepend_size)
    150				return -EBADMSG;
    151			err = skb_copy_bits(skb, rxm->offset + offset,
    152					    &content_type, 1);
    153			if (err)
    154				return err;
    155			if (content_type)
    156				break;
    157			sub++;
    158			offset--;
    159		}
    160		tlm->control = content_type;
    161	}
    162	return sub;
    163}
    164
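/* Completion callback for asynchronous decryption. On success the record is
 * left decrypted in place with the TLS header and overhead stripped from the
 * strp_msg; on error the socket is marked broken. The callback also releases
 * any destination pages that were mapped for zero-copy and wakes the waiter
 * once the last pending decrypt has finished.
 */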
    165static void tls_decrypt_done(struct crypto_async_request *req, int err)
    166{
    167	struct aead_request *aead_req = (struct aead_request *)req;
    168	struct scatterlist *sgout = aead_req->dst;
    169	struct scatterlist *sgin = aead_req->src;
    170	struct tls_sw_context_rx *ctx;
    171	struct tls_context *tls_ctx;
    172	struct tls_prot_info *prot;
    173	struct scatterlist *sg;
    174	struct sk_buff *skb;
    175	unsigned int pages;
    176
    177	skb = (struct sk_buff *)req->data;
    178	tls_ctx = tls_get_ctx(skb->sk);
    179	ctx = tls_sw_ctx_rx(tls_ctx);
    180	prot = &tls_ctx->prot_info;
    181
    182	/* Propagate if there was an err */
    183	if (err) {
    184		if (err == -EBADMSG)
    185			TLS_INC_STATS(sock_net(skb->sk),
    186				      LINUX_MIB_TLSDECRYPTERROR);
    187		ctx->async_wait.err = err;
    188		tls_err_abort(skb->sk, err);
    189	} else {
    190		struct strp_msg *rxm = strp_msg(skb);
    191
    192		/* No TLS 1.3 support with async crypto */
    193		WARN_ON(prot->tail_size);
    194
    195		rxm->offset += prot->prepend_size;
    196		rxm->full_len -= prot->overhead_size;
    197	}
    198
    199	/* After using skb->sk to propagate sk through crypto async callback
    200	 * we need to NULL it again.
    201	 */
    202	skb->sk = NULL;
    203
    204
     205	/* Free the destination pages if the skb was not decrypted in place */
    206	if (sgout != sgin) {
    207		/* Skip the first S/G entry as it points to AAD */
    208		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
    209			if (!sg)
    210				break;
    211			put_page(sg_page(sg));
    212		}
    213	}
    214
    215	kfree(aead_req);
    216
    217	spin_lock_bh(&ctx->decrypt_compl_lock);
    218	if (!atomic_dec_return(&ctx->decrypt_pending))
    219		complete(&ctx->async_wait.completion);
    220	spin_unlock_bh(&ctx->decrypt_compl_lock);
    221}
    222
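/* Build and submit the AEAD decrypt request. In async mode, if the cipher
 * kicks off asynchronously the function returns 0 and tls_decrypt_done()
 * finishes the bookkeeping; otherwise the result is waited for (or returned
 * directly) and darg->async is cleared.
 */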
    223static int tls_do_decryption(struct sock *sk,
    224			     struct sk_buff *skb,
    225			     struct scatterlist *sgin,
    226			     struct scatterlist *sgout,
    227			     char *iv_recv,
    228			     size_t data_len,
    229			     struct aead_request *aead_req,
    230			     struct tls_decrypt_arg *darg)
    231{
    232	struct tls_context *tls_ctx = tls_get_ctx(sk);
    233	struct tls_prot_info *prot = &tls_ctx->prot_info;
    234	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
    235	int ret;
    236
    237	aead_request_set_tfm(aead_req, ctx->aead_recv);
    238	aead_request_set_ad(aead_req, prot->aad_size);
    239	aead_request_set_crypt(aead_req, sgin, sgout,
    240			       data_len + prot->tag_size,
    241			       (u8 *)iv_recv);
    242
    243	if (darg->async) {
    244		/* Using skb->sk to push sk through to crypto async callback
    245		 * handler. This allows propagating errors up to the socket
    246		 * if needed. It _must_ be cleared in the async handler
    247		 * before consume_skb is called. We _know_ skb->sk is NULL
    248		 * because it is a clone from strparser.
    249		 */
    250		skb->sk = sk;
    251		aead_request_set_callback(aead_req,
    252					  CRYPTO_TFM_REQ_MAY_BACKLOG,
    253					  tls_decrypt_done, skb);
    254		atomic_inc(&ctx->decrypt_pending);
    255	} else {
    256		aead_request_set_callback(aead_req,
    257					  CRYPTO_TFM_REQ_MAY_BACKLOG,
    258					  crypto_req_done, &ctx->async_wait);
    259	}
    260
    261	ret = crypto_aead_decrypt(aead_req);
    262	if (ret == -EINPROGRESS) {
    263		if (darg->async)
    264			return 0;
    265
    266		ret = crypto_wait_req(ret, &ctx->async_wait);
    267	}
    268	darg->async = false;
    269
    270	return ret;
    271}
    272
    273static void tls_trim_both_msgs(struct sock *sk, int target_size)
    274{
    275	struct tls_context *tls_ctx = tls_get_ctx(sk);
    276	struct tls_prot_info *prot = &tls_ctx->prot_info;
    277	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    278	struct tls_rec *rec = ctx->open_rec;
    279
    280	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
    281	if (target_size > 0)
    282		target_size += prot->overhead_size;
    283	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
    284}
    285
    286static int tls_alloc_encrypted_msg(struct sock *sk, int len)
    287{
    288	struct tls_context *tls_ctx = tls_get_ctx(sk);
    289	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    290	struct tls_rec *rec = ctx->open_rec;
    291	struct sk_msg *msg_en = &rec->msg_encrypted;
    292
    293	return sk_msg_alloc(sk, msg_en, len, 0);
    294}
    295
    296static int tls_clone_plaintext_msg(struct sock *sk, int required)
    297{
    298	struct tls_context *tls_ctx = tls_get_ctx(sk);
    299	struct tls_prot_info *prot = &tls_ctx->prot_info;
    300	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    301	struct tls_rec *rec = ctx->open_rec;
    302	struct sk_msg *msg_pl = &rec->msg_plaintext;
    303	struct sk_msg *msg_en = &rec->msg_encrypted;
    304	int skip, len;
    305
     306	/* We add page references worth len bytes from the encrypted sg
     307	 * at the end of the plaintext sg. It is guaranteed that msg_en
     308	 * has enough room for them (ensured by the caller).
     309	 */
    310	len = required - msg_pl->sg.size;
    311
    312	/* Skip initial bytes in msg_en's data to be able to use
    313	 * same offset of both plain and encrypted data.
    314	 */
    315	skip = prot->prepend_size + msg_pl->sg.size;
    316
    317	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
    318}
    319
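/* Allocate a new TX record together with room for its AEAD request and
 * initialise the plaintext/encrypted sk_msgs. The two-entry sg tables start
 * with the AAD; their second entry is chained to the sk_msg data later, in
 * tls_push_record().
 */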
    320static struct tls_rec *tls_get_rec(struct sock *sk)
    321{
    322	struct tls_context *tls_ctx = tls_get_ctx(sk);
    323	struct tls_prot_info *prot = &tls_ctx->prot_info;
    324	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    325	struct sk_msg *msg_pl, *msg_en;
    326	struct tls_rec *rec;
    327	int mem_size;
    328
    329	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
    330
    331	rec = kzalloc(mem_size, sk->sk_allocation);
    332	if (!rec)
    333		return NULL;
    334
    335	msg_pl = &rec->msg_plaintext;
    336	msg_en = &rec->msg_encrypted;
    337
    338	sk_msg_init(msg_pl);
    339	sk_msg_init(msg_en);
    340
    341	sg_init_table(rec->sg_aead_in, 2);
    342	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
    343	sg_unmark_end(&rec->sg_aead_in[1]);
    344
    345	sg_init_table(rec->sg_aead_out, 2);
    346	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
    347	sg_unmark_end(&rec->sg_aead_out[1]);
    348
    349	return rec;
    350}
    351
    352static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
    353{
    354	sk_msg_free(sk, &rec->msg_encrypted);
    355	sk_msg_free(sk, &rec->msg_plaintext);
    356	kfree(rec);
    357}
    358
    359static void tls_free_open_rec(struct sock *sk)
    360{
    361	struct tls_context *tls_ctx = tls_get_ctx(sk);
    362	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    363	struct tls_rec *rec = ctx->open_rec;
    364
    365	if (rec) {
    366		tls_free_rec(sk, rec);
    367		ctx->open_rec = NULL;
    368	}
    369}
    370
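/* Push encrypted records out to the transport. A partially sent record is
 * retried first; after that every record on tx_list whose encryption has
 * completed (tx_ready) is transmitted in order, stopping at the first record
 * that is still pending.
 */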
    371int tls_tx_records(struct sock *sk, int flags)
    372{
    373	struct tls_context *tls_ctx = tls_get_ctx(sk);
    374	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    375	struct tls_rec *rec, *tmp;
    376	struct sk_msg *msg_en;
    377	int tx_flags, rc = 0;
    378
    379	if (tls_is_partially_sent_record(tls_ctx)) {
    380		rec = list_first_entry(&ctx->tx_list,
    381				       struct tls_rec, list);
    382
    383		if (flags == -1)
    384			tx_flags = rec->tx_flags;
    385		else
    386			tx_flags = flags;
    387
    388		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
    389		if (rc)
    390			goto tx_err;
    391
    392		/* Full record has been transmitted.
    393		 * Remove the head of tx_list
    394		 */
    395		list_del(&rec->list);
    396		sk_msg_free(sk, &rec->msg_plaintext);
    397		kfree(rec);
    398	}
    399
    400	/* Tx all ready records */
    401	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
    402		if (READ_ONCE(rec->tx_ready)) {
    403			if (flags == -1)
    404				tx_flags = rec->tx_flags;
    405			else
    406				tx_flags = flags;
    407
    408			msg_en = &rec->msg_encrypted;
    409			rc = tls_push_sg(sk, tls_ctx,
    410					 &msg_en->sg.data[msg_en->sg.curr],
    411					 0, tx_flags);
    412			if (rc)
    413				goto tx_err;
    414
    415			list_del(&rec->list);
    416			sk_msg_free(sk, &rec->msg_plaintext);
    417			kfree(rec);
    418		} else {
    419			break;
    420		}
    421	}
    422
    423tx_err:
    424	if (rc < 0 && rc != -EAGAIN)
    425		tls_err_abort(sk, -EBADMSG);
    426
    427	return rc;
    428}
    429
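/* Completion callback for asynchronous encryption. It restores the TLS
 * header region into the first ciphertext sg entry (skipped while the cipher
 * wrote its output), records any error on the socket, marks the record as
 * ready for transmission and, if the record sits at the head of tx_list,
 * schedules the TX work to actually send it.
 */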
    430static void tls_encrypt_done(struct crypto_async_request *req, int err)
    431{
    432	struct aead_request *aead_req = (struct aead_request *)req;
    433	struct sock *sk = req->data;
    434	struct tls_context *tls_ctx = tls_get_ctx(sk);
    435	struct tls_prot_info *prot = &tls_ctx->prot_info;
    436	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    437	struct scatterlist *sge;
    438	struct sk_msg *msg_en;
    439	struct tls_rec *rec;
    440	bool ready = false;
    441	int pending;
    442
    443	rec = container_of(aead_req, struct tls_rec, aead_req);
    444	msg_en = &rec->msg_encrypted;
    445
    446	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
    447	sge->offset -= prot->prepend_size;
    448	sge->length += prot->prepend_size;
    449
     450	/* Check if an error was previously set on the socket */
    451	if (err || sk->sk_err) {
    452		rec = NULL;
    453
    454		/* If err is already set on socket, return the same code */
    455		if (sk->sk_err) {
    456			ctx->async_wait.err = -sk->sk_err;
    457		} else {
    458			ctx->async_wait.err = err;
    459			tls_err_abort(sk, err);
    460		}
    461	}
    462
    463	if (rec) {
    464		struct tls_rec *first_rec;
    465
    466		/* Mark the record as ready for transmission */
    467		smp_store_mb(rec->tx_ready, true);
    468
    469		/* If received record is at head of tx_list, schedule tx */
    470		first_rec = list_first_entry(&ctx->tx_list,
    471					     struct tls_rec, list);
    472		if (rec == first_rec)
    473			ready = true;
    474	}
    475
    476	spin_lock_bh(&ctx->encrypt_compl_lock);
    477	pending = atomic_dec_return(&ctx->encrypt_pending);
    478
    479	if (!pending && ctx->async_notify)
    480		complete(&ctx->async_wait.completion);
    481	spin_unlock_bh(&ctx->encrypt_compl_lock);
    482
    483	if (!ready)
    484		return;
    485
    486	/* Schedule the transmission */
    487	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
    488		schedule_delayed_work(&ctx->tx_work.work, 1);
    489}
    490
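/* Build the per-record nonce (salt/IV XORed with the record sequence number,
 * with a constant B0 byte prepended for CCM ciphers), queue the record on
 * tx_list and submit the AEAD encrypt. A return value of -EINPROGRESS means
 * tls_encrypt_done() will complete the record asynchronously.
 */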
    491static int tls_do_encryption(struct sock *sk,
    492			     struct tls_context *tls_ctx,
    493			     struct tls_sw_context_tx *ctx,
    494			     struct aead_request *aead_req,
    495			     size_t data_len, u32 start)
    496{
    497	struct tls_prot_info *prot = &tls_ctx->prot_info;
    498	struct tls_rec *rec = ctx->open_rec;
    499	struct sk_msg *msg_en = &rec->msg_encrypted;
    500	struct scatterlist *sge = sk_msg_elem(msg_en, start);
    501	int rc, iv_offset = 0;
    502
    503	/* For CCM based ciphers, first byte of IV is a constant */
    504	switch (prot->cipher_type) {
    505	case TLS_CIPHER_AES_CCM_128:
    506		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
    507		iv_offset = 1;
    508		break;
    509	case TLS_CIPHER_SM4_CCM:
    510		rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
    511		iv_offset = 1;
    512		break;
    513	}
    514
    515	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
    516	       prot->iv_size + prot->salt_size);
    517
    518	xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
    519
    520	sge->offset += prot->prepend_size;
    521	sge->length -= prot->prepend_size;
    522
    523	msg_en->sg.curr = start;
    524
    525	aead_request_set_tfm(aead_req, ctx->aead_send);
    526	aead_request_set_ad(aead_req, prot->aad_size);
    527	aead_request_set_crypt(aead_req, rec->sg_aead_in,
    528			       rec->sg_aead_out,
    529			       data_len, rec->iv_data);
    530
    531	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
    532				  tls_encrypt_done, sk);
    533
     534	/* Add the record to tx_list */
    535	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
    536	atomic_inc(&ctx->encrypt_pending);
    537
    538	rc = crypto_aead_encrypt(aead_req);
    539	if (!rc || rc != -EINPROGRESS) {
    540		atomic_dec(&ctx->encrypt_pending);
    541		sge->offset -= prot->prepend_size;
    542		sge->length += prot->prepend_size;
    543	}
    544
    545	if (!rc) {
    546		WRITE_ONCE(rec->tx_ready, true);
    547	} else if (rc != -EINPROGRESS) {
    548		list_del(&rec->list);
    549		return rc;
    550	}
    551
     552	/* Unhook the record from the context unless encryption failed */
    553	ctx->open_rec = NULL;
    554	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
    555	return rc;
    556}
    557
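/* Split the open record: plaintext up to msg_opl->apply_bytes stays in
 * 'from', the remaining sg entries (together with a freshly allocated
 * encrypted buffer) move to a new record returned via 'to'. The original end
 * index is saved in *orig_end so tls_merge_open_record() can undo the split.
 */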
    558static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
    559				 struct tls_rec **to, struct sk_msg *msg_opl,
    560				 struct sk_msg *msg_oen, u32 split_point,
    561				 u32 tx_overhead_size, u32 *orig_end)
    562{
    563	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
    564	struct scatterlist *sge, *osge, *nsge;
    565	u32 orig_size = msg_opl->sg.size;
    566	struct scatterlist tmp = { };
    567	struct sk_msg *msg_npl;
    568	struct tls_rec *new;
    569	int ret;
    570
    571	new = tls_get_rec(sk);
    572	if (!new)
    573		return -ENOMEM;
    574	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
    575			   tx_overhead_size, 0);
    576	if (ret < 0) {
    577		tls_free_rec(sk, new);
    578		return ret;
    579	}
    580
    581	*orig_end = msg_opl->sg.end;
    582	i = msg_opl->sg.start;
    583	sge = sk_msg_elem(msg_opl, i);
    584	while (apply && sge->length) {
    585		if (sge->length > apply) {
    586			u32 len = sge->length - apply;
    587
    588			get_page(sg_page(sge));
    589			sg_set_page(&tmp, sg_page(sge), len,
    590				    sge->offset + apply);
    591			sge->length = apply;
    592			bytes += apply;
    593			apply = 0;
    594		} else {
    595			apply -= sge->length;
    596			bytes += sge->length;
    597		}
    598
    599		sk_msg_iter_var_next(i);
    600		if (i == msg_opl->sg.end)
    601			break;
    602		sge = sk_msg_elem(msg_opl, i);
    603	}
    604
    605	msg_opl->sg.end = i;
    606	msg_opl->sg.curr = i;
    607	msg_opl->sg.copybreak = 0;
    608	msg_opl->apply_bytes = 0;
    609	msg_opl->sg.size = bytes;
    610
    611	msg_npl = &new->msg_plaintext;
    612	msg_npl->apply_bytes = apply;
    613	msg_npl->sg.size = orig_size - bytes;
    614
    615	j = msg_npl->sg.start;
    616	nsge = sk_msg_elem(msg_npl, j);
    617	if (tmp.length) {
    618		memcpy(nsge, &tmp, sizeof(*nsge));
    619		sk_msg_iter_var_next(j);
    620		nsge = sk_msg_elem(msg_npl, j);
    621	}
    622
    623	osge = sk_msg_elem(msg_opl, i);
    624	while (osge->length) {
    625		memcpy(nsge, osge, sizeof(*nsge));
    626		sg_unmark_end(nsge);
    627		sk_msg_iter_var_next(i);
    628		sk_msg_iter_var_next(j);
    629		if (i == *orig_end)
    630			break;
    631		osge = sk_msg_elem(msg_opl, i);
    632		nsge = sk_msg_elem(msg_npl, j);
    633	}
    634
    635	msg_npl->sg.end = j;
    636	msg_npl->sg.curr = j;
    637	msg_npl->sg.copybreak = 0;
    638
    639	*to = new;
    640	return 0;
    641}
    642
    643static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
    644				  struct tls_rec *from, u32 orig_end)
    645{
    646	struct sk_msg *msg_npl = &from->msg_plaintext;
    647	struct sk_msg *msg_opl = &to->msg_plaintext;
    648	struct scatterlist *osge, *nsge;
    649	u32 i, j;
    650
    651	i = msg_opl->sg.end;
    652	sk_msg_iter_var_prev(i);
    653	j = msg_npl->sg.start;
    654
    655	osge = sk_msg_elem(msg_opl, i);
    656	nsge = sk_msg_elem(msg_npl, j);
    657
    658	if (sg_page(osge) == sg_page(nsge) &&
    659	    osge->offset + osge->length == nsge->offset) {
    660		osge->length += nsge->length;
    661		put_page(sg_page(nsge));
    662	}
    663
    664	msg_opl->sg.end = orig_end;
    665	msg_opl->sg.curr = orig_end;
    666	msg_opl->sg.copybreak = 0;
    667	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
    668	msg_opl->sg.size += msg_npl->sg.size;
    669
    670	sk_msg_free(sk, &to->msg_encrypted);
    671	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
    672
    673	kfree(from);
    674}
    675
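/* Encrypt and transmit the currently open record. The record is split first
 * if apply_bytes (or the size of the encrypted buffer) requires it; the AAD
 * and the plaintext/ciphertext sk_msgs are then chained into the AEAD
 * scatterlists, the TLS header is written, and the record is handed to
 * tls_do_encryption(). For TLS 1.3 the content type byte is appended to the
 * end of the plaintext. Completed records are finally sent via
 * tls_tx_records().
 */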
    676static int tls_push_record(struct sock *sk, int flags,
    677			   unsigned char record_type)
    678{
    679	struct tls_context *tls_ctx = tls_get_ctx(sk);
    680	struct tls_prot_info *prot = &tls_ctx->prot_info;
    681	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    682	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
    683	u32 i, split_point, orig_end;
    684	struct sk_msg *msg_pl, *msg_en;
    685	struct aead_request *req;
    686	bool split;
    687	int rc;
    688
    689	if (!rec)
    690		return 0;
    691
    692	msg_pl = &rec->msg_plaintext;
    693	msg_en = &rec->msg_encrypted;
    694
    695	split_point = msg_pl->apply_bytes;
    696	split = split_point && split_point < msg_pl->sg.size;
    697	if (unlikely((!split &&
    698		      msg_pl->sg.size +
    699		      prot->overhead_size > msg_en->sg.size) ||
    700		     (split &&
    701		      split_point +
    702		      prot->overhead_size > msg_en->sg.size))) {
    703		split = true;
    704		split_point = msg_en->sg.size;
    705	}
    706	if (split) {
    707		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
    708					   split_point, prot->overhead_size,
    709					   &orig_end);
    710		if (rc < 0)
    711			return rc;
    712		/* This can happen if above tls_split_open_record allocates
    713		 * a single large encryption buffer instead of two smaller
    714		 * ones. In this case adjust pointers and continue without
    715		 * split.
    716		 */
    717		if (!msg_pl->sg.size) {
    718			tls_merge_open_record(sk, rec, tmp, orig_end);
    719			msg_pl = &rec->msg_plaintext;
    720			msg_en = &rec->msg_encrypted;
    721			split = false;
    722		}
    723		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
    724			    prot->overhead_size);
    725	}
    726
    727	rec->tx_flags = flags;
    728	req = &rec->aead_req;
    729
    730	i = msg_pl->sg.end;
    731	sk_msg_iter_var_prev(i);
    732
    733	rec->content_type = record_type;
    734	if (prot->version == TLS_1_3_VERSION) {
    735		/* Add content type to end of message.  No padding added */
    736		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
    737		sg_mark_end(&rec->sg_content_type);
    738		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
    739			 &rec->sg_content_type);
    740	} else {
    741		sg_mark_end(sk_msg_elem(msg_pl, i));
    742	}
    743
    744	if (msg_pl->sg.end < msg_pl->sg.start) {
    745		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
    746			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
    747			 msg_pl->sg.data);
    748	}
    749
    750	i = msg_pl->sg.start;
    751	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
    752
    753	i = msg_en->sg.end;
    754	sk_msg_iter_var_prev(i);
    755	sg_mark_end(sk_msg_elem(msg_en, i));
    756
    757	i = msg_en->sg.start;
    758	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
    759
    760	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
    761		     tls_ctx->tx.rec_seq, record_type, prot);
    762
    763	tls_fill_prepend(tls_ctx,
    764			 page_address(sg_page(&msg_en->sg.data[i])) +
    765			 msg_en->sg.data[i].offset,
    766			 msg_pl->sg.size + prot->tail_size,
    767			 record_type);
    768
    769	tls_ctx->pending_open_record_frags = false;
    770
    771	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
    772			       msg_pl->sg.size + prot->tail_size, i);
    773	if (rc < 0) {
    774		if (rc != -EINPROGRESS) {
    775			tls_err_abort(sk, -EBADMSG);
    776			if (split) {
    777				tls_ctx->pending_open_record_frags = true;
    778				tls_merge_open_record(sk, rec, tmp, orig_end);
    779			}
    780		}
    781		ctx->async_capable = 1;
    782		return rc;
    783	} else if (split) {
    784		msg_pl = &tmp->msg_plaintext;
    785		msg_en = &tmp->msg_encrypted;
    786		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
    787		tls_ctx->pending_open_record_frags = true;
    788		ctx->open_rec = tmp;
    789	}
    790
    791	return tls_tx_records(sk, flags);
    792}
    793
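/* Run the psock's BPF msg verdict program (if one is attached) on the open
 * record and act on the result: __SK_PASS pushes the TLS record as usual,
 * __SK_REDIRECT hands the plaintext to another socket and __SK_DROP frees
 * it. Without a psock, or with MSG_SENDPAGE_NOPOLICY, the record is simply
 * pushed.
 */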
    794static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
    795			       bool full_record, u8 record_type,
    796			       ssize_t *copied, int flags)
    797{
    798	struct tls_context *tls_ctx = tls_get_ctx(sk);
    799	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    800	struct sk_msg msg_redir = { };
    801	struct sk_psock *psock;
    802	struct sock *sk_redir;
    803	struct tls_rec *rec;
    804	bool enospc, policy;
    805	int err = 0, send;
    806	u32 delta = 0;
    807
    808	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
    809	psock = sk_psock_get(sk);
    810	if (!psock || !policy) {
    811		err = tls_push_record(sk, flags, record_type);
    812		if (err && sk->sk_err == EBADMSG) {
    813			*copied -= sk_msg_free(sk, msg);
    814			tls_free_open_rec(sk);
    815			err = -sk->sk_err;
    816		}
    817		if (psock)
    818			sk_psock_put(sk, psock);
    819		return err;
    820	}
    821more_data:
    822	enospc = sk_msg_full(msg);
    823	if (psock->eval == __SK_NONE) {
    824		delta = msg->sg.size;
    825		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
    826		delta -= msg->sg.size;
    827	}
    828	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
    829	    !enospc && !full_record) {
    830		err = -ENOSPC;
    831		goto out_err;
    832	}
    833	msg->cork_bytes = 0;
    834	send = msg->sg.size;
    835	if (msg->apply_bytes && msg->apply_bytes < send)
    836		send = msg->apply_bytes;
    837
    838	switch (psock->eval) {
    839	case __SK_PASS:
    840		err = tls_push_record(sk, flags, record_type);
    841		if (err && sk->sk_err == EBADMSG) {
    842			*copied -= sk_msg_free(sk, msg);
    843			tls_free_open_rec(sk);
    844			err = -sk->sk_err;
    845			goto out_err;
    846		}
    847		break;
    848	case __SK_REDIRECT:
    849		sk_redir = psock->sk_redir;
    850		memcpy(&msg_redir, msg, sizeof(*msg));
    851		if (msg->apply_bytes < send)
    852			msg->apply_bytes = 0;
    853		else
    854			msg->apply_bytes -= send;
    855		sk_msg_return_zero(sk, msg, send);
    856		msg->sg.size -= send;
    857		release_sock(sk);
    858		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
    859		lock_sock(sk);
    860		if (err < 0) {
    861			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
    862			msg->sg.size = 0;
    863		}
    864		if (msg->sg.size == 0)
    865			tls_free_open_rec(sk);
    866		break;
    867	case __SK_DROP:
    868	default:
    869		sk_msg_free_partial(sk, msg, send);
    870		if (msg->apply_bytes < send)
    871			msg->apply_bytes = 0;
    872		else
    873			msg->apply_bytes -= send;
    874		if (msg->sg.size == 0)
    875			tls_free_open_rec(sk);
    876		*copied -= (send + delta);
    877		err = -EACCES;
    878	}
    879
    880	if (likely(!err)) {
    881		bool reset_eval = !ctx->open_rec;
    882
    883		rec = ctx->open_rec;
    884		if (rec) {
    885			msg = &rec->msg_plaintext;
    886			if (!msg->apply_bytes)
    887				reset_eval = true;
    888		}
    889		if (reset_eval) {
    890			psock->eval = __SK_NONE;
    891			if (psock->sk_redir) {
    892				sock_put(psock->sk_redir);
    893				psock->sk_redir = NULL;
    894			}
    895		}
    896		if (rec)
    897			goto more_data;
    898	}
    899 out_err:
    900	sk_psock_put(sk, psock);
    901	return err;
    902}
    903
    904static int tls_sw_push_pending_record(struct sock *sk, int flags)
    905{
    906	struct tls_context *tls_ctx = tls_get_ctx(sk);
    907	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    908	struct tls_rec *rec = ctx->open_rec;
    909	struct sk_msg *msg_pl;
    910	size_t copied;
    911
    912	if (!rec)
    913		return 0;
    914
    915	msg_pl = &rec->msg_plaintext;
    916	copied = msg_pl->sg.size;
    917	if (!copied)
    918		return 0;
    919
    920	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
    921				   &copied, flags);
    922}
    923
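/* sendmsg() for the TLS software data path. Data is accumulated into the
 * open record up to TLS_MAX_PAYLOAD_SIZE; once the record is full, or
 * MSG_MORE is not set, it is passed through bpf_exec_tx_verdict() and
 * encrypted. Where possible the plaintext is referenced zero-copy from user
 * pages, otherwise it is copied into the record's sk_msg.
 */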
    924int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
    925{
    926	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
    927	struct tls_context *tls_ctx = tls_get_ctx(sk);
    928	struct tls_prot_info *prot = &tls_ctx->prot_info;
    929	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
    930	bool async_capable = ctx->async_capable;
    931	unsigned char record_type = TLS_RECORD_TYPE_DATA;
    932	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
    933	bool eor = !(msg->msg_flags & MSG_MORE);
    934	size_t try_to_copy;
    935	ssize_t copied = 0;
    936	struct sk_msg *msg_pl, *msg_en;
    937	struct tls_rec *rec;
    938	int required_size;
    939	int num_async = 0;
    940	bool full_record;
    941	int record_room;
    942	int num_zc = 0;
    943	int orig_size;
    944	int ret = 0;
    945	int pending;
    946
    947	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
    948			       MSG_CMSG_COMPAT))
    949		return -EOPNOTSUPP;
    950
    951	mutex_lock(&tls_ctx->tx_lock);
    952	lock_sock(sk);
    953
    954	if (unlikely(msg->msg_controllen)) {
    955		ret = tls_proccess_cmsg(sk, msg, &record_type);
    956		if (ret) {
    957			if (ret == -EINPROGRESS)
    958				num_async++;
    959			else if (ret != -EAGAIN)
    960				goto send_end;
    961		}
    962	}
    963
    964	while (msg_data_left(msg)) {
    965		if (sk->sk_err) {
    966			ret = -sk->sk_err;
    967			goto send_end;
    968		}
    969
    970		if (ctx->open_rec)
    971			rec = ctx->open_rec;
    972		else
    973			rec = ctx->open_rec = tls_get_rec(sk);
    974		if (!rec) {
    975			ret = -ENOMEM;
    976			goto send_end;
    977		}
    978
    979		msg_pl = &rec->msg_plaintext;
    980		msg_en = &rec->msg_encrypted;
    981
    982		orig_size = msg_pl->sg.size;
    983		full_record = false;
    984		try_to_copy = msg_data_left(msg);
    985		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
    986		if (try_to_copy >= record_room) {
    987			try_to_copy = record_room;
    988			full_record = true;
    989		}
    990
    991		required_size = msg_pl->sg.size + try_to_copy +
    992				prot->overhead_size;
    993
    994		if (!sk_stream_memory_free(sk))
    995			goto wait_for_sndbuf;
    996
    997alloc_encrypted:
    998		ret = tls_alloc_encrypted_msg(sk, required_size);
    999		if (ret) {
   1000			if (ret != -ENOSPC)
   1001				goto wait_for_memory;
   1002
   1003			/* Adjust try_to_copy according to the amount that was
   1004			 * actually allocated. The difference is due
   1005			 * to max sg elements limit
   1006			 */
   1007			try_to_copy -= required_size - msg_en->sg.size;
   1008			full_record = true;
   1009		}
   1010
   1011		if (!is_kvec && (full_record || eor) && !async_capable) {
   1012			u32 first = msg_pl->sg.end;
   1013
   1014			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
   1015							msg_pl, try_to_copy);
   1016			if (ret)
   1017				goto fallback_to_reg_send;
   1018
   1019			num_zc++;
   1020			copied += try_to_copy;
   1021
   1022			sk_msg_sg_copy_set(msg_pl, first);
   1023			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
   1024						  record_type, &copied,
   1025						  msg->msg_flags);
   1026			if (ret) {
   1027				if (ret == -EINPROGRESS)
   1028					num_async++;
   1029				else if (ret == -ENOMEM)
   1030					goto wait_for_memory;
   1031				else if (ctx->open_rec && ret == -ENOSPC)
   1032					goto rollback_iter;
   1033				else if (ret != -EAGAIN)
   1034					goto send_end;
   1035			}
   1036			continue;
   1037rollback_iter:
   1038			copied -= try_to_copy;
   1039			sk_msg_sg_copy_clear(msg_pl, first);
   1040			iov_iter_revert(&msg->msg_iter,
   1041					msg_pl->sg.size - orig_size);
   1042fallback_to_reg_send:
   1043			sk_msg_trim(sk, msg_pl, orig_size);
   1044		}
   1045
   1046		required_size = msg_pl->sg.size + try_to_copy;
   1047
   1048		ret = tls_clone_plaintext_msg(sk, required_size);
   1049		if (ret) {
   1050			if (ret != -ENOSPC)
   1051				goto send_end;
   1052
   1053			/* Adjust try_to_copy according to the amount that was
   1054			 * actually allocated. The difference is due
   1055			 * to max sg elements limit
   1056			 */
   1057			try_to_copy -= required_size - msg_pl->sg.size;
   1058			full_record = true;
   1059			sk_msg_trim(sk, msg_en,
   1060				    msg_pl->sg.size + prot->overhead_size);
   1061		}
   1062
   1063		if (try_to_copy) {
   1064			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
   1065						       msg_pl, try_to_copy);
   1066			if (ret < 0)
   1067				goto trim_sgl;
   1068		}
   1069
    1070		/* Mark pending open record frags only if the copy succeeded,
    1071		 * otherwise we would trim the sg but not reset the open record frags.
    1072		 */
   1073		tls_ctx->pending_open_record_frags = true;
   1074		copied += try_to_copy;
   1075		if (full_record || eor) {
   1076			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
   1077						  record_type, &copied,
   1078						  msg->msg_flags);
   1079			if (ret) {
   1080				if (ret == -EINPROGRESS)
   1081					num_async++;
   1082				else if (ret == -ENOMEM)
   1083					goto wait_for_memory;
   1084				else if (ret != -EAGAIN) {
   1085					if (ret == -ENOSPC)
   1086						ret = 0;
   1087					goto send_end;
   1088				}
   1089			}
   1090		}
   1091
   1092		continue;
   1093
   1094wait_for_sndbuf:
   1095		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
   1096wait_for_memory:
   1097		ret = sk_stream_wait_memory(sk, &timeo);
   1098		if (ret) {
   1099trim_sgl:
   1100			if (ctx->open_rec)
   1101				tls_trim_both_msgs(sk, orig_size);
   1102			goto send_end;
   1103		}
   1104
   1105		if (ctx->open_rec && msg_en->sg.size < required_size)
   1106			goto alloc_encrypted;
   1107	}
   1108
   1109	if (!num_async) {
   1110		goto send_end;
   1111	} else if (num_zc) {
   1112		/* Wait for pending encryptions to get completed */
   1113		spin_lock_bh(&ctx->encrypt_compl_lock);
   1114		ctx->async_notify = true;
   1115
   1116		pending = atomic_read(&ctx->encrypt_pending);
   1117		spin_unlock_bh(&ctx->encrypt_compl_lock);
   1118		if (pending)
   1119			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
   1120		else
   1121			reinit_completion(&ctx->async_wait.completion);
   1122
   1123		/* There can be no concurrent accesses, since we have no
   1124		 * pending encrypt operations
   1125		 */
   1126		WRITE_ONCE(ctx->async_notify, false);
   1127
   1128		if (ctx->async_wait.err) {
   1129			ret = ctx->async_wait.err;
   1130			copied = 0;
   1131		}
   1132	}
   1133
   1134	/* Transmit if any encryptions have completed */
   1135	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
   1136		cancel_delayed_work(&ctx->tx_work.work);
   1137		tls_tx_records(sk, msg->msg_flags);
   1138	}
   1139
   1140send_end:
   1141	ret = sk_stream_error(sk, msg->msg_flags, ret);
   1142
   1143	release_sock(sk);
   1144	mutex_unlock(&tls_ctx->tx_lock);
   1145	return copied > 0 ? copied : ret;
   1146}
   1147
   1148static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
   1149			      int offset, size_t size, int flags)
   1150{
   1151	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
   1152	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1153	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
   1154	struct tls_prot_info *prot = &tls_ctx->prot_info;
   1155	unsigned char record_type = TLS_RECORD_TYPE_DATA;
   1156	struct sk_msg *msg_pl;
   1157	struct tls_rec *rec;
   1158	int num_async = 0;
   1159	ssize_t copied = 0;
   1160	bool full_record;
   1161	int record_room;
   1162	int ret = 0;
   1163	bool eor;
   1164
   1165	eor = !(flags & MSG_SENDPAGE_NOTLAST);
   1166	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
   1167
   1168	/* Call the sk_stream functions to manage the sndbuf mem. */
   1169	while (size > 0) {
   1170		size_t copy, required_size;
   1171
   1172		if (sk->sk_err) {
   1173			ret = -sk->sk_err;
   1174			goto sendpage_end;
   1175		}
   1176
   1177		if (ctx->open_rec)
   1178			rec = ctx->open_rec;
   1179		else
   1180			rec = ctx->open_rec = tls_get_rec(sk);
   1181		if (!rec) {
   1182			ret = -ENOMEM;
   1183			goto sendpage_end;
   1184		}
   1185
   1186		msg_pl = &rec->msg_plaintext;
   1187
   1188		full_record = false;
   1189		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
   1190		copy = size;
   1191		if (copy >= record_room) {
   1192			copy = record_room;
   1193			full_record = true;
   1194		}
   1195
   1196		required_size = msg_pl->sg.size + copy + prot->overhead_size;
   1197
   1198		if (!sk_stream_memory_free(sk))
   1199			goto wait_for_sndbuf;
   1200alloc_payload:
   1201		ret = tls_alloc_encrypted_msg(sk, required_size);
   1202		if (ret) {
   1203			if (ret != -ENOSPC)
   1204				goto wait_for_memory;
   1205
   1206			/* Adjust copy according to the amount that was
   1207			 * actually allocated. The difference is due
   1208			 * to max sg elements limit
   1209			 */
   1210			copy -= required_size - msg_pl->sg.size;
   1211			full_record = true;
   1212		}
   1213
   1214		sk_msg_page_add(msg_pl, page, copy, offset);
   1215		sk_mem_charge(sk, copy);
   1216
   1217		offset += copy;
   1218		size -= copy;
   1219		copied += copy;
   1220
   1221		tls_ctx->pending_open_record_frags = true;
   1222		if (full_record || eor || sk_msg_full(msg_pl)) {
   1223			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
   1224						  record_type, &copied, flags);
   1225			if (ret) {
   1226				if (ret == -EINPROGRESS)
   1227					num_async++;
   1228				else if (ret == -ENOMEM)
   1229					goto wait_for_memory;
   1230				else if (ret != -EAGAIN) {
   1231					if (ret == -ENOSPC)
   1232						ret = 0;
   1233					goto sendpage_end;
   1234				}
   1235			}
   1236		}
   1237		continue;
   1238wait_for_sndbuf:
   1239		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
   1240wait_for_memory:
   1241		ret = sk_stream_wait_memory(sk, &timeo);
   1242		if (ret) {
   1243			if (ctx->open_rec)
   1244				tls_trim_both_msgs(sk, msg_pl->sg.size);
   1245			goto sendpage_end;
   1246		}
   1247
   1248		if (ctx->open_rec)
   1249			goto alloc_payload;
   1250	}
   1251
   1252	if (num_async) {
   1253		/* Transmit if any encryptions have completed */
   1254		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
   1255			cancel_delayed_work(&ctx->tx_work.work);
   1256			tls_tx_records(sk, flags);
   1257		}
   1258	}
   1259sendpage_end:
   1260	ret = sk_stream_error(sk, flags, ret);
   1261	return copied > 0 ? copied : ret;
   1262}
   1263
   1264int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
   1265			   int offset, size_t size, int flags)
   1266{
   1267	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
   1268		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
   1269		      MSG_NO_SHARED_FRAGS))
   1270		return -EOPNOTSUPP;
   1271
   1272	return tls_sw_do_sendpage(sk, page, offset, size, flags);
   1273}
   1274
   1275int tls_sw_sendpage(struct sock *sk, struct page *page,
   1276		    int offset, size_t size, int flags)
   1277{
   1278	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1279	int ret;
   1280
   1281	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
   1282		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
   1283		return -EOPNOTSUPP;
   1284
   1285	mutex_lock(&tls_ctx->tx_lock);
   1286	lock_sock(sk);
   1287	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
   1288	release_sock(sk);
   1289	mutex_unlock(&tls_ctx->tx_lock);
   1290	return ret;
   1291}
   1292
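/* Wait until the strparser has produced a complete TLS record (ctx->recv_pkt)
 * or the psock ingress queue has data. Returns NULL (setting *err where
 * applicable) on error, shutdown, timeout or signal.
 */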
   1293static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
   1294				     bool nonblock, long timeo, int *err)
   1295{
   1296	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1297	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   1298	struct sk_buff *skb;
   1299	DEFINE_WAIT_FUNC(wait, woken_wake_function);
   1300
   1301	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
   1302		if (sk->sk_err) {
   1303			*err = sock_error(sk);
   1304			return NULL;
   1305		}
   1306
   1307		if (!skb_queue_empty(&sk->sk_receive_queue)) {
   1308			__strp_unpause(&ctx->strp);
   1309			if (ctx->recv_pkt)
   1310				return ctx->recv_pkt;
   1311		}
   1312
   1313		if (sk->sk_shutdown & RCV_SHUTDOWN)
   1314			return NULL;
   1315
   1316		if (sock_flag(sk, SOCK_DONE))
   1317			return NULL;
   1318
   1319		if (nonblock || !timeo) {
   1320			*err = -EAGAIN;
   1321			return NULL;
   1322		}
   1323
   1324		add_wait_queue(sk_sleep(sk), &wait);
   1325		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
   1326		sk_wait_event(sk, &timeo,
   1327			      ctx->recv_pkt != skb ||
   1328			      !sk_psock_queue_empty(psock),
   1329			      &wait);
   1330		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
   1331		remove_wait_queue(sk_sleep(sk), &wait);
   1332
   1333		/* Handle signals */
   1334		if (signal_pending(current)) {
   1335			*err = sock_intr_errno(timeo);
   1336			return NULL;
   1337		}
   1338	}
   1339
   1340	return skb;
   1341}
   1342
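/* Pin the user pages backing @from and map up to @length bytes of them into
 * the scatterlist @to for zero-copy decryption, using at most @to_max_pages
 * entries. On failure the iterator is rewound; *pages_used is updated with
 * the number of sg entries filled in.
 */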
   1343static int tls_setup_from_iter(struct iov_iter *from,
   1344			       int length, int *pages_used,
   1345			       struct scatterlist *to,
   1346			       int to_max_pages)
   1347{
   1348	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
   1349	struct page *pages[MAX_SKB_FRAGS];
   1350	unsigned int size = 0;
   1351	ssize_t copied, use;
   1352	size_t offset;
   1353
   1354	while (length > 0) {
   1355		i = 0;
   1356		maxpages = to_max_pages - num_elem;
   1357		if (maxpages == 0) {
   1358			rc = -EFAULT;
   1359			goto out;
   1360		}
   1361		copied = iov_iter_get_pages(from, pages,
   1362					    length,
   1363					    maxpages, &offset);
   1364		if (copied <= 0) {
   1365			rc = -EFAULT;
   1366			goto out;
   1367		}
   1368
   1369		iov_iter_advance(from, copied);
   1370
   1371		length -= copied;
   1372		size += copied;
   1373		while (copied) {
   1374			use = min_t(int, copied, PAGE_SIZE - offset);
   1375
   1376			sg_set_page(&to[num_elem],
   1377				    pages[i], use, offset);
   1378			sg_unmark_end(&to[num_elem]);
   1379			/* We do not uncharge memory from this API */
   1380
   1381			offset = 0;
   1382			copied -= use;
   1383
   1384			i++;
   1385			num_elem++;
   1386		}
   1387	}
   1388	/* Mark the end in the last sg entry if newly added */
   1389	if (num_elem > *pages_used)
   1390		sg_mark_end(&to[num_elem - 1]);
   1391out:
   1392	if (rc)
   1393		iov_iter_revert(from, size);
   1394	*pages_used = num_elem;
   1395
   1396	return rc;
   1397}
   1398
    1399/* This function decrypts the input skb either into out_iov, into out_sg,
    1400 * or in place in the skb's own buffers. The 'zc' flag in darg indicates
    1401 * whether zero-copy mode should be tried. With zero-copy mode, either
    1402 * out_iov or out_sg must be non-NULL. If both out_iov and out_sg are
    1403 * NULL, the decryption happens in place in the skb's buffers, i.e.
    1404 * zero-copy is disabled and 'zc' is updated accordingly.
    1405 */
   1406
   1407static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
   1408			    struct iov_iter *out_iov,
   1409			    struct scatterlist *out_sg,
   1410			    struct tls_decrypt_arg *darg)
   1411{
   1412	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1413	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   1414	struct tls_prot_info *prot = &tls_ctx->prot_info;
   1415	struct strp_msg *rxm = strp_msg(skb);
   1416	struct tls_msg *tlm = tls_msg(skb);
   1417	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
   1418	struct aead_request *aead_req;
   1419	struct sk_buff *unused;
   1420	u8 *aad, *iv, *mem = NULL;
   1421	struct scatterlist *sgin = NULL;
   1422	struct scatterlist *sgout = NULL;
   1423	const int data_len = rxm->full_len - prot->overhead_size +
   1424			     prot->tail_size;
   1425	int iv_offset = 0;
   1426
   1427	if (darg->zc && (out_iov || out_sg)) {
   1428		if (out_iov)
   1429			n_sgout = 1 +
   1430				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
   1431		else
   1432			n_sgout = sg_nents(out_sg);
   1433		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
   1434				 rxm->full_len - prot->prepend_size);
   1435	} else {
   1436		n_sgout = 0;
   1437		darg->zc = false;
   1438		n_sgin = skb_cow_data(skb, 0, &unused);
   1439	}
   1440
   1441	if (n_sgin < 1)
   1442		return -EBADMSG;
   1443
   1444	/* Increment to accommodate AAD */
   1445	n_sgin = n_sgin + 1;
   1446
   1447	nsg = n_sgin + n_sgout;
   1448
   1449	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
   1450	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
   1451	mem_size = mem_size + prot->aad_size;
   1452	mem_size = mem_size + MAX_IV_SIZE;
   1453
   1454	/* Allocate a single block of memory which contains
   1455	 * aead_req || sgin[] || sgout[] || aad || iv.
   1456	 * This order achieves correct alignment for aead_req, sgin, sgout.
   1457	 */
   1458	mem = kmalloc(mem_size, sk->sk_allocation);
   1459	if (!mem)
   1460		return -ENOMEM;
   1461
   1462	/* Segment the allocated memory */
   1463	aead_req = (struct aead_request *)mem;
   1464	sgin = (struct scatterlist *)(mem + aead_size);
   1465	sgout = sgin + n_sgin;
   1466	aad = (u8 *)(sgout + n_sgout);
   1467	iv = aad + prot->aad_size;
   1468
   1469	/* For CCM based ciphers, first byte of nonce+iv is a constant */
   1470	switch (prot->cipher_type) {
   1471	case TLS_CIPHER_AES_CCM_128:
   1472		iv[0] = TLS_AES_CCM_IV_B0_BYTE;
   1473		iv_offset = 1;
   1474		break;
   1475	case TLS_CIPHER_SM4_CCM:
   1476		iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
   1477		iv_offset = 1;
   1478		break;
   1479	}
   1480
   1481	/* Prepare IV */
   1482	if (prot->version == TLS_1_3_VERSION ||
   1483	    prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
   1484		memcpy(iv + iv_offset, tls_ctx->rx.iv,
   1485		       prot->iv_size + prot->salt_size);
   1486	} else {
   1487		err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
   1488				    iv + iv_offset + prot->salt_size,
   1489				    prot->iv_size);
   1490		if (err < 0) {
   1491			kfree(mem);
   1492			return err;
   1493		}
   1494		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
   1495	}
   1496	xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
   1497
   1498	/* Prepare AAD */
   1499	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
   1500		     prot->tail_size,
   1501		     tls_ctx->rx.rec_seq, tlm->control, prot);
   1502
   1503	/* Prepare sgin */
   1504	sg_init_table(sgin, n_sgin);
   1505	sg_set_buf(&sgin[0], aad, prot->aad_size);
   1506	err = skb_to_sgvec(skb, &sgin[1],
   1507			   rxm->offset + prot->prepend_size,
   1508			   rxm->full_len - prot->prepend_size);
   1509	if (err < 0) {
   1510		kfree(mem);
   1511		return err;
   1512	}
   1513
   1514	if (n_sgout) {
   1515		if (out_iov) {
   1516			sg_init_table(sgout, n_sgout);
   1517			sg_set_buf(&sgout[0], aad, prot->aad_size);
   1518
   1519			err = tls_setup_from_iter(out_iov, data_len,
   1520						  &pages, &sgout[1],
   1521						  (n_sgout - 1));
   1522			if (err < 0)
   1523				goto fallback_to_reg_recv;
   1524		} else if (out_sg) {
   1525			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
   1526		} else {
   1527			goto fallback_to_reg_recv;
   1528		}
   1529	} else {
   1530fallback_to_reg_recv:
   1531		sgout = sgin;
   1532		pages = 0;
   1533		darg->zc = false;
   1534	}
   1535
   1536	/* Prepare and submit AEAD request */
   1537	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
   1538				data_len, aead_req, darg);
   1539	if (darg->async)
   1540		return 0;
   1541
   1542	/* Release the pages in case iov was mapped to pages */
   1543	for (; pages > 0; pages--)
   1544		put_page(sg_page(&sgout[pages]));
   1545
   1546	kfree(mem);
   1547	return err;
   1548}
   1549
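/* Decrypt a record unless it has already been handled (previously decrypted,
 * or decrypted by the device offload). On synchronous completion the TLS 1.3
 * padding, the header and the rest of the overhead are stripped from the
 * strp_msg; for async requests tls_decrypt_done() does that once the cipher
 * finishes. The RX record sequence number is advanced in both cases.
 */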
   1550static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
   1551			      struct iov_iter *dest,
   1552			      struct tls_decrypt_arg *darg)
   1553{
   1554	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1555	struct tls_prot_info *prot = &tls_ctx->prot_info;
   1556	struct strp_msg *rxm = strp_msg(skb);
   1557	struct tls_msg *tlm = tls_msg(skb);
   1558	int pad, err;
   1559
   1560	if (tlm->decrypted) {
   1561		darg->zc = false;
   1562		darg->async = false;
   1563		return 0;
   1564	}
   1565
   1566	if (tls_ctx->rx_conf == TLS_HW) {
   1567		err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
   1568		if (err < 0)
   1569			return err;
   1570		if (err > 0) {
   1571			tlm->decrypted = 1;
   1572			darg->zc = false;
   1573			darg->async = false;
   1574			goto decrypt_done;
   1575		}
   1576	}
   1577
   1578	err = decrypt_internal(sk, skb, dest, NULL, darg);
   1579	if (err < 0) {
   1580		if (err == -EBADMSG)
   1581			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
   1582		return err;
   1583	}
   1584	if (darg->async)
   1585		goto decrypt_next;
   1586
   1587decrypt_done:
   1588	pad = padding_length(prot, skb);
   1589	if (pad < 0)
   1590		return pad;
   1591
   1592	rxm->full_len -= pad;
   1593	rxm->offset += prot->prepend_size;
   1594	rxm->full_len -= prot->overhead_size;
   1595	tlm->decrypted = 1;
   1596decrypt_next:
   1597	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
   1598
   1599	return 0;
   1600}
   1601
   1602int decrypt_skb(struct sock *sk, struct sk_buff *skb,
   1603		struct scatterlist *sgout)
   1604{
   1605	struct tls_decrypt_arg darg = { .zc = true, };
   1606
   1607	return decrypt_internal(sk, skb, NULL, sgout, &darg);
   1608}
   1609
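/* Check the record type against the type the caller is currently returning.
 * On the first record the type is latched and delivered to userspace via a
 * TLS_GET_RECORD_TYPE cmsg. Returns 1 if the record may be consumed, 0 if
 * the type changed and the caller must stop, or a negative error.
 */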
   1610static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
   1611				   u8 *control)
   1612{
   1613	int err;
   1614
   1615	if (!*control) {
   1616		*control = tlm->control;
   1617		if (!*control)
   1618			return -EBADMSG;
   1619
   1620		err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
   1621			       sizeof(*control), control);
   1622		if (*control != TLS_RECORD_TYPE_DATA) {
   1623			if (err || msg->msg_flags & MSG_CTRUNC)
   1624				return -EIO;
   1625		}
   1626	} else if (*control != tlm->control) {
   1627		return 0;
   1628	}
   1629
   1630	return 1;
   1631}
   1632
    1633/* This function traverses the rx_list in the tls receive context and copies
    1634 * the decrypted records into the buffer provided by the caller when zero
    1635 * copy is not true. Further, a record is removed from the rx_list if it is
    1636 * not a peek case and has been consumed completely.
    1637 */
   1638static int process_rx_list(struct tls_sw_context_rx *ctx,
   1639			   struct msghdr *msg,
   1640			   u8 *control,
   1641			   size_t skip,
   1642			   size_t len,
   1643			   bool zc,
   1644			   bool is_peek)
   1645{
   1646	struct sk_buff *skb = skb_peek(&ctx->rx_list);
   1647	struct tls_msg *tlm;
   1648	ssize_t copied = 0;
   1649	int err;
   1650
   1651	while (skip && skb) {
   1652		struct strp_msg *rxm = strp_msg(skb);
   1653		tlm = tls_msg(skb);
   1654
   1655		err = tls_record_content_type(msg, tlm, control);
   1656		if (err <= 0)
   1657			goto out;
   1658
   1659		if (skip < rxm->full_len)
   1660			break;
   1661
   1662		skip = skip - rxm->full_len;
   1663		skb = skb_peek_next(skb, &ctx->rx_list);
   1664	}
   1665
   1666	while (len && skb) {
   1667		struct sk_buff *next_skb;
   1668		struct strp_msg *rxm = strp_msg(skb);
   1669		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
   1670
   1671		tlm = tls_msg(skb);
   1672
   1673		err = tls_record_content_type(msg, tlm, control);
   1674		if (err <= 0)
   1675			goto out;
   1676
   1677		if (!zc || (rxm->full_len - skip) > len) {
   1678			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
   1679						    msg, chunk);
   1680			if (err < 0)
   1681				goto out;
   1682		}
   1683
   1684		len = len - chunk;
   1685		copied = copied + chunk;
   1686
    1687		/* Consume the data from the record in the non-peek case */
   1688		if (!is_peek) {
   1689			rxm->offset = rxm->offset + chunk;
   1690			rxm->full_len = rxm->full_len - chunk;
   1691
   1692			/* Return if there is unconsumed data in the record */
   1693			if (rxm->full_len - skip)
   1694				break;
   1695		}
   1696
   1697		/* The remaining skip-bytes must lie in 1st record in rx_list.
   1698		 * So from the 2nd record, 'skip' should be 0.
   1699		 */
   1700		skip = 0;
   1701
   1702		if (msg)
   1703			msg->msg_flags |= MSG_EOR;
   1704
   1705		next_skb = skb_peek_next(skb, &ctx->rx_list);
   1706
   1707		if (!is_peek) {
   1708			__skb_unlink(skb, &ctx->rx_list);
   1709			consume_skb(skb);
   1710		}
   1711
   1712		skb = next_skb;
   1713	}
   1714	err = 0;
   1715
   1716out:
   1717	return copied ? : err;
   1718}
   1719
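/* recvmsg() for the TLS software data path. Records already decrypted on
 * rx_list are copied out first; further records are then pulled from the
 * strparser and decrypted, zero-copy into the user buffer where possible and
 * asynchronously for data records when the cipher allows it (TLS 1.2 only).
 * Control record types are surfaced through the TLS_GET_RECORD_TYPE cmsg.
 */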
   1720int tls_sw_recvmsg(struct sock *sk,
   1721		   struct msghdr *msg,
   1722		   size_t len,
   1723		   int flags,
   1724		   int *addr_len)
   1725{
   1726	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1727	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   1728	struct tls_prot_info *prot = &tls_ctx->prot_info;
   1729	struct sk_psock *psock;
   1730	unsigned char control = 0;
   1731	ssize_t decrypted = 0;
   1732	struct strp_msg *rxm;
   1733	struct tls_msg *tlm;
   1734	struct sk_buff *skb;
   1735	ssize_t copied = 0;
   1736	bool async = false;
   1737	int target, err = 0;
   1738	long timeo;
   1739	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
   1740	bool is_peek = flags & MSG_PEEK;
   1741	bool bpf_strp_enabled;
   1742	bool zc_capable;
   1743
   1744	if (unlikely(flags & MSG_ERRQUEUE))
   1745		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
   1746
   1747	psock = sk_psock_get(sk);
   1748	lock_sock(sk);
   1749	bpf_strp_enabled = sk_psock_strp_enabled(psock);
   1750
   1751	/* If crypto failed the connection is broken */
   1752	err = ctx->async_wait.err;
   1753	if (err)
   1754		goto end;
   1755
    1756	/* Process pending already-decrypted records; this path is never zero-copy */
   1757	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
   1758	if (err < 0)
   1759		goto end;
   1760
   1761	copied = err;
   1762	if (len <= copied)
   1763		goto end;
   1764
   1765	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
   1766	len = len - copied;
   1767	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
   1768
   1769	zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
   1770		     prot->version != TLS_1_3_VERSION;
   1771	decrypted = 0;
   1772	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
   1773		struct tls_decrypt_arg darg = {};
   1774		int to_decrypt, chunk;
   1775
   1776		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
   1777		if (!skb) {
   1778			if (psock) {
   1779				chunk = sk_msg_recvmsg(sk, psock, msg, len,
   1780						       flags);
   1781				if (chunk > 0)
   1782					goto leave_on_list;
   1783			}
   1784			goto recv_end;
   1785		}
   1786
   1787		rxm = strp_msg(skb);
   1788		tlm = tls_msg(skb);
   1789
   1790		to_decrypt = rxm->full_len - prot->overhead_size;
   1791
   1792		if (zc_capable && to_decrypt <= len &&
   1793		    tlm->control == TLS_RECORD_TYPE_DATA)
   1794			darg.zc = true;
   1795
    1796		/* Do not use async mode if the record is not application data */
   1797		if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
   1798			darg.async = ctx->async_capable;
   1799		else
   1800			darg.async = false;
   1801
   1802		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
   1803		if (err < 0) {
   1804			tls_err_abort(sk, -EBADMSG);
   1805			goto recv_end;
   1806		}
   1807
   1808		async |= darg.async;
   1809
    1810		/* If the type of the records being processed is not yet known,
    1811		 * set it to the type of the record just dequeued. If it is
    1812		 * already known but does not match the record just dequeued,
    1813		 * stop here. The record type is always available at this point:
    1814		 * for TLS 1.2 it is known as soon as the record is dequeued from
    1815		 * the stream parser, and for TLS 1.3 async mode is disabled.
    1816		 */
   1817		err = tls_record_content_type(msg, tlm, &control);
   1818		if (err <= 0)
   1819			goto recv_end;
   1820
   1821		ctx->recv_pkt = NULL;
   1822		__strp_unpause(&ctx->strp);
   1823		__skb_queue_tail(&ctx->rx_list, skb);
   1824
   1825		if (async) {
    1826			/* TLS 1.2 only: to_decrypt is already the plaintext length */
   1827			chunk = min_t(int, to_decrypt, len);
   1828leave_on_list:
   1829			decrypted += chunk;
   1830			len -= chunk;
   1831			continue;
   1832		}
   1833		/* TLS 1.3 may have updated the length by more than overhead */
   1834		chunk = rxm->full_len;
   1835
   1836		if (!darg.zc) {
   1837			bool partially_consumed = chunk > len;
   1838
   1839			if (bpf_strp_enabled) {
   1840				/* BPF may try to queue the skb */
   1841				__skb_unlink(skb, &ctx->rx_list);
   1842				err = sk_psock_tls_strp_read(psock, skb);
   1843				if (err != __SK_PASS) {
   1844					rxm->offset = rxm->offset + rxm->full_len;
   1845					rxm->full_len = 0;
   1846					if (err == __SK_DROP)
   1847						consume_skb(skb);
   1848					continue;
   1849				}
   1850				__skb_queue_tail(&ctx->rx_list, skb);
   1851			}
   1852
   1853			if (partially_consumed)
   1854				chunk = len;
   1855
   1856			err = skb_copy_datagram_msg(skb, rxm->offset,
   1857						    msg, chunk);
   1858			if (err < 0)
   1859				goto recv_end;
   1860
   1861			if (is_peek)
   1862				goto leave_on_list;
   1863
   1864			if (partially_consumed) {
   1865				rxm->offset += chunk;
   1866				rxm->full_len -= chunk;
   1867				goto leave_on_list;
   1868			}
   1869		}
   1870
   1871		decrypted += chunk;
   1872		len -= chunk;
   1873
   1874		__skb_unlink(skb, &ctx->rx_list);
   1875		consume_skb(skb);
   1876
   1877		/* Return full control message to userspace before trying
   1878		 * to parse another message type
   1879		 */
   1880		msg->msg_flags |= MSG_EOR;
   1881		if (control != TLS_RECORD_TYPE_DATA)
   1882			break;
   1883	}
   1884
   1885recv_end:
   1886	if (async) {
   1887		int ret, pending;
   1888
   1889		/* Wait for all previously submitted records to be decrypted */
   1890		spin_lock_bh(&ctx->decrypt_compl_lock);
   1891		reinit_completion(&ctx->async_wait.completion);
   1892		pending = atomic_read(&ctx->decrypt_pending);
   1893		spin_unlock_bh(&ctx->decrypt_compl_lock);
   1894		if (pending) {
   1895			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
   1896			if (ret) {
   1897				if (err >= 0 || err == -EINPROGRESS)
   1898					err = ret;
   1899				decrypted = 0;
   1900				goto end;
   1901			}
   1902		}
   1903
   1904		/* Drain records from the rx_list & copy if required */
   1905		if (is_peek || is_kvec)
   1906			err = process_rx_list(ctx, msg, &control, copied,
   1907					      decrypted, false, is_peek);
   1908		else
   1909			err = process_rx_list(ctx, msg, &control, 0,
   1910					      decrypted, true, is_peek);
   1911		decrypted = max(err, 0);
   1912	}
   1913
   1914	copied += decrypted;
   1915
   1916end:
   1917	release_sock(sk);
   1918	if (psock)
   1919		sk_psock_put(sk, psock);
   1920	return copied ? : err;
   1921}
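
/* Illustrative sketch, not part of this file's build: the userspace path
 * that ends up in tls_sw_recvmsg(). After the "tls" ULP is attached and RX
 * key material is installed, ordinary recv() calls are serviced here. The
 * names fd and buf, and the key material source, are hypothetical.
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	char buf[16384];
 *	ssize_t n;
 *
 *	(fill ci.key, ci.iv, ci.salt and ci.rec_seq from the handshake)
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));
 *	n = recv(fd, buf, sizeof(buf), 0);
 */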
   1922
   1923ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
   1924			   struct pipe_inode_info *pipe,
   1925			   size_t len, unsigned int flags)
   1926{
   1927	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
   1928	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   1929	struct strp_msg *rxm = NULL;
   1930	struct sock *sk = sock->sk;
   1931	struct tls_msg *tlm;
   1932	struct sk_buff *skb;
   1933	ssize_t copied = 0;
   1934	bool from_queue;
   1935	int err = 0;
   1936	long timeo;
   1937	int chunk;
   1938
   1939	lock_sock(sk);
   1940
   1941	timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
   1942
   1943	from_queue = !skb_queue_empty(&ctx->rx_list);
   1944	if (from_queue) {
   1945		skb = __skb_dequeue(&ctx->rx_list);
   1946	} else {
   1947		struct tls_decrypt_arg darg = {};
   1948
   1949		skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
   1950				    &err);
   1951		if (!skb)
   1952			goto splice_read_end;
   1953
   1954		err = decrypt_skb_update(sk, skb, NULL, &darg);
   1955		if (err < 0) {
   1956			tls_err_abort(sk, -EBADMSG);
   1957			goto splice_read_end;
   1958		}
   1959	}
   1960
   1961	rxm = strp_msg(skb);
   1962	tlm = tls_msg(skb);
   1963
   1964	/* splice does not support reading control messages */
   1965	if (tlm->control != TLS_RECORD_TYPE_DATA) {
   1966		err = -EINVAL;
   1967		goto splice_read_end;
   1968	}
   1969
   1970	chunk = min_t(unsigned int, rxm->full_len, len);
   1971	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
   1972	if (copied < 0)
   1973		goto splice_read_end;
   1974
   1975	if (!from_queue) {
   1976		ctx->recv_pkt = NULL;
   1977		__strp_unpause(&ctx->strp);
   1978	}
   1979	if (chunk < rxm->full_len) {
   1980		__skb_queue_head(&ctx->rx_list, skb);
   1981		rxm->offset += len;
   1982		rxm->full_len -= len;
   1983	} else {
   1984		consume_skb(skb);
   1985	}
   1986
   1987splice_read_end:
   1988	release_sock(sk);
   1989	return copied ? : err;
   1990}
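
/* Illustrative sketch, not part of this file's build: a splice(2) call that
 * is served by tls_sw_splice_read(), moving one decrypted data record (or
 * part of it) into a pipe. tls_fd and p are hypothetical.
 *
 *	int p[2];
 *	ssize_t n;
 *
 *	pipe(p);
 *	n = splice(tls_fd, NULL, p[1], NULL, 4096, 0);
 */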
   1991
   1992bool tls_sw_sock_is_readable(struct sock *sk)
   1993{
   1994	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1995	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   1996	bool ingress_empty = true;
   1997	struct sk_psock *psock;
   1998
   1999	rcu_read_lock();
   2000	psock = sk_psock(sk);
   2001	if (psock)
   2002		ingress_empty = list_empty(&psock->ingress_msg);
   2003	rcu_read_unlock();
   2004
   2005	return !ingress_empty || ctx->recv_pkt ||
   2006		!skb_queue_empty(&ctx->rx_list);
   2007}
   2008
   2009static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
   2010{
   2011	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
   2012	struct tls_prot_info *prot = &tls_ctx->prot_info;
   2013	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
   2014	struct strp_msg *rxm = strp_msg(skb);
   2015	struct tls_msg *tlm = tls_msg(skb);
   2016	size_t cipher_overhead;
   2017	size_t data_len = 0;
   2018	int ret;
   2019
   2020	/* Verify that we have a full TLS header, or wait for more data */
   2021	if (rxm->offset + prot->prepend_size > skb->len)
   2022		return 0;
   2023
   2024	/* Sanity-check size of on-stack buffer. */
   2025	if (WARN_ON(prot->prepend_size > sizeof(header))) {
   2026		ret = -EINVAL;
   2027		goto read_failure;
   2028	}
   2029
   2030	/* Linearize header to local buffer */
   2031	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
   2032	if (ret < 0)
   2033		goto read_failure;
   2034
   2035	tlm->decrypted = 0;
   2036	tlm->control = header[0];
   2037
   2038	data_len = ((header[4] & 0xFF) | (header[3] << 8));
   2039
   2040	cipher_overhead = prot->tag_size;
   2041	if (prot->version != TLS_1_3_VERSION &&
   2042	    prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
   2043		cipher_overhead += prot->iv_size;
   2044
   2045	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
   2046	    prot->tail_size) {
   2047		ret = -EMSGSIZE;
   2048		goto read_failure;
   2049	}
   2050	if (data_len < cipher_overhead) {
   2051		ret = -EBADMSG;
   2052		goto read_failure;
   2053	}
   2054
   2055	/* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
   2056	if (header[1] != TLS_1_2_VERSION_MINOR ||
   2057	    header[2] != TLS_1_2_VERSION_MAJOR) {
   2058		ret = -EINVAL;
   2059		goto read_failure;
   2060	}
   2061
   2062	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
   2063				     TCP_SKB_CB(skb)->seq + rxm->offset);
   2064	return data_len + TLS_HEADER_SIZE;
   2065
   2066read_failure:
   2067	tls_err_abort(strp->sk, ret);
   2068
   2069	return ret;
   2070}
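
/* A worked example of the header parsing above, with illustrative bytes:
 *
 *	17 03 03 00 30
 *
 * header[0] = 0x17 (23, application_data) becomes tlm->control;
 * header[1..2] = 0x03 0x03 is the legacy record version that both TLS 1.2
 * and TLS 1.3 put on the wire; header[3..4] give data_len = 0x0030 = 48
 * bytes of ciphertext, so tls_read_size() returns
 * data_len + TLS_HEADER_SIZE = 53.
 */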
   2071
   2072static void tls_queue(struct strparser *strp, struct sk_buff *skb)
   2073{
   2074	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
   2075	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   2076
   2077	ctx->recv_pkt = skb;
   2078	strp_pause(strp);
   2079
   2080	ctx->saved_data_ready(strp->sk);
   2081}
   2082
   2083static void tls_data_ready(struct sock *sk)
   2084{
   2085	struct tls_context *tls_ctx = tls_get_ctx(sk);
   2086	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   2087	struct sk_psock *psock;
   2088
   2089	strp_data_ready(&ctx->strp);
   2090
   2091	psock = sk_psock_get(sk);
   2092	if (psock) {
   2093		if (!list_empty(&psock->ingress_msg))
   2094			ctx->saved_data_ready(sk);
   2095		sk_psock_put(sk, psock);
   2096	}
   2097}
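
/* Overview of the RX path formed by the three helpers above:
 * tls_data_ready() replaces sk->sk_data_ready and feeds the strparser,
 * tls_read_size() sizes one record from its 5-byte header, and tls_queue()
 * parks the complete record as ctx->recv_pkt and pauses the parser until
 * tls_sw_recvmsg() or tls_sw_splice_read() consumes it.
 */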
   2098
   2099void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
   2100{
   2101	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
   2102
   2103	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
   2104	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
   2105	cancel_delayed_work_sync(&ctx->tx_work.work);
   2106}
   2107
   2108void tls_sw_release_resources_tx(struct sock *sk)
   2109{
   2110	struct tls_context *tls_ctx = tls_get_ctx(sk);
   2111	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
   2112	struct tls_rec *rec, *tmp;
   2113	int pending;
   2114
   2115	/* Wait for any pending async encryptions to complete */
   2116	spin_lock_bh(&ctx->encrypt_compl_lock);
   2117	ctx->async_notify = true;
   2118	pending = atomic_read(&ctx->encrypt_pending);
   2119	spin_unlock_bh(&ctx->encrypt_compl_lock);
   2120
   2121	if (pending)
   2122		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
   2123
   2124	tls_tx_records(sk, -1);
   2125
    2126	/* Free up unsent records in tx_list. First free the partially sent
    2127	 * record, if any, at the head of tx_list.
    2128	 */
   2129	if (tls_ctx->partially_sent_record) {
   2130		tls_free_partial_record(sk, tls_ctx);
   2131		rec = list_first_entry(&ctx->tx_list,
   2132				       struct tls_rec, list);
   2133		list_del(&rec->list);
   2134		sk_msg_free(sk, &rec->msg_plaintext);
   2135		kfree(rec);
   2136	}
   2137
   2138	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
   2139		list_del(&rec->list);
   2140		sk_msg_free(sk, &rec->msg_encrypted);
   2141		sk_msg_free(sk, &rec->msg_plaintext);
   2142		kfree(rec);
   2143	}
   2144
   2145	crypto_free_aead(ctx->aead_send);
   2146	tls_free_open_rec(sk);
   2147}
   2148
   2149void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
   2150{
   2151	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
   2152
   2153	kfree(ctx);
   2154}
   2155
   2156void tls_sw_release_resources_rx(struct sock *sk)
   2157{
   2158	struct tls_context *tls_ctx = tls_get_ctx(sk);
   2159	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   2160
   2161	kfree(tls_ctx->rx.rec_seq);
   2162	kfree(tls_ctx->rx.iv);
   2163
   2164	if (ctx->aead_recv) {
   2165		kfree_skb(ctx->recv_pkt);
   2166		ctx->recv_pkt = NULL;
   2167		__skb_queue_purge(&ctx->rx_list);
   2168		crypto_free_aead(ctx->aead_recv);
   2169		strp_stop(&ctx->strp);
   2170		/* If tls_sw_strparser_arm() was not called (cleanup paths)
   2171		 * we still want to strp_stop(), but sk->sk_data_ready was
   2172		 * never swapped.
   2173		 */
   2174		if (ctx->saved_data_ready) {
   2175			write_lock_bh(&sk->sk_callback_lock);
   2176			sk->sk_data_ready = ctx->saved_data_ready;
   2177			write_unlock_bh(&sk->sk_callback_lock);
   2178		}
   2179	}
   2180}
   2181
   2182void tls_sw_strparser_done(struct tls_context *tls_ctx)
   2183{
   2184	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   2185
   2186	strp_done(&ctx->strp);
   2187}
   2188
   2189void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
   2190{
   2191	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
   2192
   2193	kfree(ctx);
   2194}
   2195
   2196void tls_sw_free_resources_rx(struct sock *sk)
   2197{
   2198	struct tls_context *tls_ctx = tls_get_ctx(sk);
   2199
   2200	tls_sw_release_resources_rx(sk);
   2201	tls_sw_free_ctx_rx(tls_ctx);
   2202}
   2203
    2204	/* The work handler that transmits the encrypted records in tx_list */
   2205static void tx_work_handler(struct work_struct *work)
   2206{
   2207	struct delayed_work *delayed_work = to_delayed_work(work);
   2208	struct tx_work *tx_work = container_of(delayed_work,
   2209					       struct tx_work, work);
   2210	struct sock *sk = tx_work->sk;
   2211	struct tls_context *tls_ctx = tls_get_ctx(sk);
   2212	struct tls_sw_context_tx *ctx;
   2213
   2214	if (unlikely(!tls_ctx))
   2215		return;
   2216
   2217	ctx = tls_sw_ctx_tx(tls_ctx);
   2218	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
   2219		return;
   2220
   2221	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
   2222		return;
   2223	mutex_lock(&tls_ctx->tx_lock);
   2224	lock_sock(sk);
   2225	tls_tx_records(sk, -1);
   2226	release_sock(sk);
   2227	mutex_unlock(&tls_ctx->tx_lock);
   2228}
   2229
   2230void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
   2231{
   2232	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
   2233
   2234	/* Schedule the transmission if tx list is ready */
   2235	if (is_tx_ready(tx_ctx) &&
   2236	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
   2237		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
   2238}
   2239
   2240void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
   2241{
   2242	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
   2243
   2244	write_lock_bh(&sk->sk_callback_lock);
   2245	rx_ctx->saved_data_ready = sk->sk_data_ready;
   2246	sk->sk_data_ready = tls_data_ready;
   2247	write_unlock_bh(&sk->sk_callback_lock);
   2248
   2249	strp_check_rcv(&rx_ctx->strp);
   2250}
   2251
   2252int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
   2253{
   2254	struct tls_context *tls_ctx = tls_get_ctx(sk);
   2255	struct tls_prot_info *prot = &tls_ctx->prot_info;
   2256	struct tls_crypto_info *crypto_info;
   2257	struct tls_sw_context_tx *sw_ctx_tx = NULL;
   2258	struct tls_sw_context_rx *sw_ctx_rx = NULL;
   2259	struct cipher_context *cctx;
   2260	struct crypto_aead **aead;
   2261	struct strp_callbacks cb;
   2262	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
   2263	struct crypto_tfm *tfm;
   2264	char *iv, *rec_seq, *key, *salt, *cipher_name;
   2265	size_t keysize;
   2266	int rc = 0;
   2267
   2268	if (!ctx) {
   2269		rc = -EINVAL;
   2270		goto out;
   2271	}
   2272
   2273	if (tx) {
   2274		if (!ctx->priv_ctx_tx) {
   2275			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
   2276			if (!sw_ctx_tx) {
   2277				rc = -ENOMEM;
   2278				goto out;
   2279			}
   2280			ctx->priv_ctx_tx = sw_ctx_tx;
   2281		} else {
   2282			sw_ctx_tx =
   2283				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
   2284		}
   2285	} else {
   2286		if (!ctx->priv_ctx_rx) {
   2287			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
   2288			if (!sw_ctx_rx) {
   2289				rc = -ENOMEM;
   2290				goto out;
   2291			}
   2292			ctx->priv_ctx_rx = sw_ctx_rx;
   2293		} else {
   2294			sw_ctx_rx =
   2295				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
   2296		}
   2297	}
   2298
   2299	if (tx) {
   2300		crypto_init_wait(&sw_ctx_tx->async_wait);
   2301		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
   2302		crypto_info = &ctx->crypto_send.info;
   2303		cctx = &ctx->tx;
   2304		aead = &sw_ctx_tx->aead_send;
   2305		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
   2306		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
   2307		sw_ctx_tx->tx_work.sk = sk;
   2308	} else {
   2309		crypto_init_wait(&sw_ctx_rx->async_wait);
   2310		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
   2311		crypto_info = &ctx->crypto_recv.info;
   2312		cctx = &ctx->rx;
   2313		skb_queue_head_init(&sw_ctx_rx->rx_list);
   2314		aead = &sw_ctx_rx->aead_recv;
   2315	}
   2316
   2317	switch (crypto_info->cipher_type) {
   2318	case TLS_CIPHER_AES_GCM_128: {
   2319		struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
   2320
   2321		gcm_128_info = (void *)crypto_info;
   2322		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
   2323		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
   2324		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
   2325		iv = gcm_128_info->iv;
   2326		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
   2327		rec_seq = gcm_128_info->rec_seq;
   2328		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
   2329		key = gcm_128_info->key;
   2330		salt = gcm_128_info->salt;
   2331		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
   2332		cipher_name = "gcm(aes)";
   2333		break;
   2334	}
   2335	case TLS_CIPHER_AES_GCM_256: {
   2336		struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
   2337
   2338		gcm_256_info = (void *)crypto_info;
   2339		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
   2340		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
   2341		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
   2342		iv = gcm_256_info->iv;
   2343		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
   2344		rec_seq = gcm_256_info->rec_seq;
   2345		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
   2346		key = gcm_256_info->key;
   2347		salt = gcm_256_info->salt;
   2348		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
   2349		cipher_name = "gcm(aes)";
   2350		break;
   2351	}
   2352	case TLS_CIPHER_AES_CCM_128: {
   2353		struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
   2354
   2355		ccm_128_info = (void *)crypto_info;
   2356		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
   2357		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
   2358		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
   2359		iv = ccm_128_info->iv;
   2360		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
   2361		rec_seq = ccm_128_info->rec_seq;
   2362		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
   2363		key = ccm_128_info->key;
   2364		salt = ccm_128_info->salt;
   2365		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
   2366		cipher_name = "ccm(aes)";
   2367		break;
   2368	}
   2369	case TLS_CIPHER_CHACHA20_POLY1305: {
   2370		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
   2371
   2372		chacha20_poly1305_info = (void *)crypto_info;
   2373		nonce_size = 0;
   2374		tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
   2375		iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
   2376		iv = chacha20_poly1305_info->iv;
   2377		rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
   2378		rec_seq = chacha20_poly1305_info->rec_seq;
   2379		keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
   2380		key = chacha20_poly1305_info->key;
   2381		salt = chacha20_poly1305_info->salt;
   2382		salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
   2383		cipher_name = "rfc7539(chacha20,poly1305)";
   2384		break;
   2385	}
   2386	case TLS_CIPHER_SM4_GCM: {
   2387		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;
   2388
   2389		sm4_gcm_info = (void *)crypto_info;
   2390		nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
   2391		tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
   2392		iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
   2393		iv = sm4_gcm_info->iv;
   2394		rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
   2395		rec_seq = sm4_gcm_info->rec_seq;
   2396		keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
   2397		key = sm4_gcm_info->key;
   2398		salt = sm4_gcm_info->salt;
   2399		salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
   2400		cipher_name = "gcm(sm4)";
   2401		break;
   2402	}
   2403	case TLS_CIPHER_SM4_CCM: {
   2404		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;
   2405
   2406		sm4_ccm_info = (void *)crypto_info;
   2407		nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
   2408		tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
   2409		iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
   2410		iv = sm4_ccm_info->iv;
   2411		rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
   2412		rec_seq = sm4_ccm_info->rec_seq;
   2413		keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
   2414		key = sm4_ccm_info->key;
   2415		salt = sm4_ccm_info->salt;
   2416		salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
   2417		cipher_name = "ccm(sm4)";
   2418		break;
   2419	}
   2420	default:
   2421		rc = -EINVAL;
   2422		goto free_priv;
   2423	}
   2424
   2425	/* Sanity-check the sizes for stack allocations. */
   2426	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
   2427	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
   2428		rc = -EINVAL;
   2429		goto free_priv;
   2430	}
   2431
   2432	if (crypto_info->version == TLS_1_3_VERSION) {
   2433		nonce_size = 0;
   2434		prot->aad_size = TLS_HEADER_SIZE;
   2435		prot->tail_size = 1;
   2436	} else {
   2437		prot->aad_size = TLS_AAD_SPACE_SIZE;
   2438		prot->tail_size = 0;
   2439	}
   2440
   2441	prot->version = crypto_info->version;
   2442	prot->cipher_type = crypto_info->cipher_type;
   2443	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
   2444	prot->tag_size = tag_size;
   2445	prot->overhead_size = prot->prepend_size +
   2446			      prot->tag_size + prot->tail_size;
   2447	prot->iv_size = iv_size;
   2448	prot->salt_size = salt_size;
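	/* Worked numbers for AES-GCM-128 as a sanity check (taken from the
	 * case above): salt_size = 4 and iv_size = 8, so cctx->iv below holds
	 * the 12-byte salt||iv nonce base. For TLS 1.2, nonce_size = 8 gives
	 * prepend_size = 5 + 8 = 13 and overhead_size = 13 + 16 + 0 = 29;
	 * for TLS 1.3, nonce_size is forced to 0, so prepend_size = 5 and
	 * overhead_size = 5 + 16 + 1 = 22.
	 */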
   2449	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
   2450	if (!cctx->iv) {
   2451		rc = -ENOMEM;
   2452		goto free_priv;
   2453	}
   2454	/* Note: 128 & 256 bit salt are the same size */
   2455	prot->rec_seq_size = rec_seq_size;
   2456	memcpy(cctx->iv, salt, salt_size);
   2457	memcpy(cctx->iv + salt_size, iv, iv_size);
   2458	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
   2459	if (!cctx->rec_seq) {
   2460		rc = -ENOMEM;
   2461		goto free_iv;
   2462	}
   2463
   2464	if (!*aead) {
   2465		*aead = crypto_alloc_aead(cipher_name, 0, 0);
   2466		if (IS_ERR(*aead)) {
   2467			rc = PTR_ERR(*aead);
   2468			*aead = NULL;
   2469			goto free_rec_seq;
   2470		}
   2471	}
   2472
   2473	ctx->push_pending_record = tls_sw_push_pending_record;
   2474
   2475	rc = crypto_aead_setkey(*aead, key, keysize);
   2476
   2477	if (rc)
   2478		goto free_aead;
   2479
   2480	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
   2481	if (rc)
   2482		goto free_aead;
   2483
   2484	if (sw_ctx_rx) {
   2485		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
   2486
   2487		if (crypto_info->version == TLS_1_3_VERSION)
   2488			sw_ctx_rx->async_capable = 0;
   2489		else
   2490			sw_ctx_rx->async_capable =
   2491				!!(tfm->__crt_alg->cra_flags &
   2492				   CRYPTO_ALG_ASYNC);
   2493
   2494		/* Set up strparser */
   2495		memset(&cb, 0, sizeof(cb));
   2496		cb.rcv_msg = tls_queue;
   2497		cb.parse_msg = tls_read_size;
   2498
   2499		strp_init(&sw_ctx_rx->strp, sk, &cb);
   2500	}
   2501
   2502	goto out;
   2503
   2504free_aead:
   2505	crypto_free_aead(*aead);
   2506	*aead = NULL;
   2507free_rec_seq:
   2508	kfree(cctx->rec_seq);
   2509	cctx->rec_seq = NULL;
   2510free_iv:
   2511	kfree(cctx->iv);
   2512	cctx->iv = NULL;
   2513free_priv:
   2514	if (tx) {
   2515		kfree(ctx->priv_ctx_tx);
   2516		ctx->priv_ctx_tx = NULL;
   2517	} else {
   2518		kfree(ctx->priv_ctx_rx);
   2519		ctx->priv_ctx_rx = NULL;
   2520	}
   2521out:
   2522	return rc;
   2523}