cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

tls_device.c (37508B)


      1/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
      2 *
      3 * This software is available to you under a choice of one of two
      4 * licenses.  You may choose to be licensed under the terms of the GNU
      5 * General Public License (GPL) Version 2, available from the file
      6 * COPYING in the main directory of this source tree, or the
      7 * OpenIB.org BSD license below:
      8 *
      9 *     Redistribution and use in source and binary forms, with or
     10 *     without modification, are permitted provided that the following
     11 *     conditions are met:
     12 *
     13 *      - Redistributions of source code must retain the above
     14 *        copyright notice, this list of conditions and the following
     15 *        disclaimer.
     16 *
     17 *      - Redistributions in binary form must reproduce the above
     18 *        copyright notice, this list of conditions and the following
     19 *        disclaimer in the documentation and/or other materials
     20 *        provided with the distribution.
     21 *
     22 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     23 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     24 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     25 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     26 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     27 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     28 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     29 * SOFTWARE.
     30 */
     31
     32#include <crypto/aead.h>
     33#include <linux/highmem.h>
     34#include <linux/module.h>
     35#include <linux/netdevice.h>
     36#include <net/dst.h>
     37#include <net/inet_connection_sock.h>
     38#include <net/tcp.h>
     39#include <net/tls.h>
     40
     41#include "trace.h"
     42
     43/* device_offload_lock is used to synchronize tls_dev_add
     44 * against NETDEV_DOWN notifications.
     45 */
     46static DECLARE_RWSEM(device_offload_lock);
     47
     48static void tls_device_gc_task(struct work_struct *work);
     49
     50static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
     51static LIST_HEAD(tls_device_gc_list);
     52static LIST_HEAD(tls_device_list);
     53static LIST_HEAD(tls_device_down_list);
     54static DEFINE_SPINLOCK(tls_device_lock);
     55
     56static void tls_device_free_ctx(struct tls_context *ctx)
     57{
     58	if (ctx->tx_conf == TLS_HW) {
     59		kfree(tls_offload_ctx_tx(ctx));
     60		kfree(ctx->tx.rec_seq);
     61		kfree(ctx->tx.iv);
     62	}
     63
     64	if (ctx->rx_conf == TLS_HW)
     65		kfree(tls_offload_ctx_rx(ctx));
     66
     67	tls_ctx_free(NULL, ctx);
     68}
     69
     70static void tls_device_gc_task(struct work_struct *work)
     71{
     72	struct tls_context *ctx, *tmp;
     73	unsigned long flags;
     74	LIST_HEAD(gc_list);
     75
     76	spin_lock_irqsave(&tls_device_lock, flags);
     77	list_splice_init(&tls_device_gc_list, &gc_list);
     78	spin_unlock_irqrestore(&tls_device_lock, flags);
     79
     80	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
     81		struct net_device *netdev = ctx->netdev;
     82
     83		if (netdev && ctx->tx_conf == TLS_HW) {
     84			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
     85							TLS_OFFLOAD_CTX_DIR_TX);
     86			dev_put(netdev);
     87			ctx->netdev = NULL;
     88		}
     89
     90		list_del(&ctx->list);
     91		tls_device_free_ctx(ctx);
     92	}
     93}
     94
     95static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
     96{
     97	unsigned long flags;
     98
     99	spin_lock_irqsave(&tls_device_lock, flags);
    100	list_move_tail(&ctx->list, &tls_device_gc_list);
    101
    102	/* schedule_work inside the spinlock
    103	 * to make sure tls_device_down waits for that work.
    104	 */
    105	schedule_work(&tls_device_gc_work);
    106
    107	spin_unlock_irqrestore(&tls_device_lock, flags);
    108}
    109
    110/* We assume that the socket is already connected */
    111static struct net_device *get_netdev_for_sock(struct sock *sk)
    112{
    113	struct dst_entry *dst = sk_dst_get(sk);
    114	struct net_device *netdev = NULL;
    115
    116	if (likely(dst)) {
    117		netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
    118		dev_hold(netdev);
    119	}
    120
    121	dst_release(dst);
    122
    123	return netdev;
    124}
    125
    126static void destroy_record(struct tls_record_info *record)
    127{
    128	int i;
    129
    130	for (i = 0; i < record->num_frags; i++)
    131		__skb_frag_unref(&record->frags[i], false);
    132	kfree(record);
    133}
    134
    135static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
    136{
    137	struct tls_record_info *info, *temp;
    138
    139	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
    140		list_del(&info->list);
    141		destroy_record(info);
    142	}
    143
    144	offload_ctx->retransmit_hint = NULL;
    145}
    146
    147static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
    148{
    149	struct tls_context *tls_ctx = tls_get_ctx(sk);
    150	struct tls_record_info *info, *temp;
    151	struct tls_offload_context_tx *ctx;
    152	u64 deleted_records = 0;
    153	unsigned long flags;
    154
    155	if (!tls_ctx)
    156		return;
    157
    158	ctx = tls_offload_ctx_tx(tls_ctx);
    159
    160	spin_lock_irqsave(&ctx->lock, flags);
    161	info = ctx->retransmit_hint;
    162	if (info && !before(acked_seq, info->end_seq))
    163		ctx->retransmit_hint = NULL;
    164
    165	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
    166		if (before(acked_seq, info->end_seq))
    167			break;
    168		list_del(&info->list);
    169
    170		destroy_record(info);
    171		deleted_records++;
    172	}
    173
    174	ctx->unacked_record_sn += deleted_records;
    175	spin_unlock_irqrestore(&ctx->lock, flags);
    176}
    177
    178/* At this point, there should be no references on this
    179 * socket and no in-flight SKBs associated with this
    180 * socket, so it is safe to free all the resources.
    181 */
    182void tls_device_sk_destruct(struct sock *sk)
    183{
    184	struct tls_context *tls_ctx = tls_get_ctx(sk);
    185	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
    186
    187	tls_ctx->sk_destruct(sk);
    188
    189	if (tls_ctx->tx_conf == TLS_HW) {
    190		if (ctx->open_record)
    191			destroy_record(ctx->open_record);
    192		delete_all_records(ctx);
    193		crypto_free_aead(ctx->aead_send);
    194		clean_acked_data_disable(inet_csk(sk));
    195	}
    196
    197	if (refcount_dec_and_test(&tls_ctx->refcount))
    198		tls_device_queue_ctx_destruction(tls_ctx);
    199}
    200EXPORT_SYMBOL_GPL(tls_device_sk_destruct);
    201
    202void tls_device_free_resources_tx(struct sock *sk)
    203{
    204	struct tls_context *tls_ctx = tls_get_ctx(sk);
    205
    206	tls_free_partial_record(sk, tls_ctx);
    207}
    208
    209void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq)
    210{
    211	struct tls_context *tls_ctx = tls_get_ctx(sk);
    212
    213	trace_tls_device_tx_resync_req(sk, got_seq, exp_seq);
    214	WARN_ON(test_and_set_bit(TLS_TX_SYNC_SCHED, &tls_ctx->flags));
    215}
    216EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request);
    217
    218static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
    219				 u32 seq)
    220{
    221	struct net_device *netdev;
    222	struct sk_buff *skb;
    223	int err = 0;
    224	u8 *rcd_sn;
    225
    226	skb = tcp_write_queue_tail(sk);
    227	if (skb)
    228		TCP_SKB_CB(skb)->eor = 1;
    229
    230	rcd_sn = tls_ctx->tx.rec_seq;
    231
    232	trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
    233	down_read(&device_offload_lock);
    234	netdev = tls_ctx->netdev;
    235	if (netdev)
    236		err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
    237							 rcd_sn,
    238							 TLS_OFFLOAD_CTX_DIR_TX);
    239	up_read(&device_offload_lock);
    240	if (err)
    241		return;
    242
    243	clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
    244}
    245
    246static void tls_append_frag(struct tls_record_info *record,
    247			    struct page_frag *pfrag,
    248			    int size)
    249{
    250	skb_frag_t *frag;
    251
    252	frag = &record->frags[record->num_frags - 1];
    253	if (skb_frag_page(frag) == pfrag->page &&
    254	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
    255		skb_frag_size_add(frag, size);
    256	} else {
    257		++frag;
    258		__skb_frag_set_page(frag, pfrag->page);
    259		skb_frag_off_set(frag, pfrag->offset);
    260		skb_frag_size_set(frag, size);
    261		++record->num_frags;
    262		get_page(pfrag->page);
    263	}
    264
    265	pfrag->offset += size;
    266	record->len += size;
    267}
    268
    269static int tls_push_record(struct sock *sk,
    270			   struct tls_context *ctx,
    271			   struct tls_offload_context_tx *offload_ctx,
    272			   struct tls_record_info *record,
    273			   int flags)
    274{
    275	struct tls_prot_info *prot = &ctx->prot_info;
    276	struct tcp_sock *tp = tcp_sk(sk);
    277	skb_frag_t *frag;
    278	int i;
    279
    280	record->end_seq = tp->write_seq + record->len;
    281	list_add_tail_rcu(&record->list, &offload_ctx->records_list);
    282	offload_ctx->open_record = NULL;
    283
    284	if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
    285		tls_device_resync_tx(sk, ctx, tp->write_seq);
    286
    287	tls_advance_record_sn(sk, prot, &ctx->tx);
    288
    289	for (i = 0; i < record->num_frags; i++) {
    290		frag = &record->frags[i];
    291		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
    292		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
    293			    skb_frag_size(frag), skb_frag_off(frag));
    294		sk_mem_charge(sk, skb_frag_size(frag));
    295		get_page(skb_frag_page(frag));
    296	}
    297	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
    298
    299	/* all ready, send */
    300	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
    301}
    302
    303static int tls_device_record_close(struct sock *sk,
    304				   struct tls_context *ctx,
    305				   struct tls_record_info *record,
    306				   struct page_frag *pfrag,
    307				   unsigned char record_type)
    308{
    309	struct tls_prot_info *prot = &ctx->prot_info;
    310	int ret;
    311
    312	/* append tag
    313	 * device will fill in the tag, we just need to append a placeholder
    314	 * use socket memory to improve coalescing (re-using a single buffer
    315	 * increases frag count)
    316	 * if we can't allocate memory now, steal some back from data
    317	 */
    318	if (likely(skb_page_frag_refill(prot->tag_size, pfrag,
    319					sk->sk_allocation))) {
    320		ret = 0;
    321		tls_append_frag(record, pfrag, prot->tag_size);
    322	} else {
    323		ret = prot->tag_size;
    324		if (record->len <= prot->overhead_size)
    325			return -ENOMEM;
    326	}
    327
    328	/* fill prepend */
    329	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
    330			 record->len - prot->overhead_size,
    331			 record_type);
    332	return ret;
    333}
    334
    335static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
    336				 struct page_frag *pfrag,
    337				 size_t prepend_size)
    338{
    339	struct tls_record_info *record;
    340	skb_frag_t *frag;
    341
    342	record = kmalloc(sizeof(*record), GFP_KERNEL);
    343	if (!record)
    344		return -ENOMEM;
    345
    346	frag = &record->frags[0];
    347	__skb_frag_set_page(frag, pfrag->page);
    348	skb_frag_off_set(frag, pfrag->offset);
    349	skb_frag_size_set(frag, prepend_size);
    350
    351	get_page(pfrag->page);
    352	pfrag->offset += prepend_size;
    353
    354	record->num_frags = 1;
    355	record->len = prepend_size;
    356	offload_ctx->open_record = record;
    357	return 0;
    358}
    359
    360static int tls_do_allocation(struct sock *sk,
    361			     struct tls_offload_context_tx *offload_ctx,
    362			     struct page_frag *pfrag,
    363			     size_t prepend_size)
    364{
    365	int ret;
    366
    367	if (!offload_ctx->open_record) {
    368		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
    369						   sk->sk_allocation))) {
    370			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
    371			sk_stream_moderate_sndbuf(sk);
    372			return -ENOMEM;
    373		}
    374
    375		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
    376		if (ret)
    377			return ret;
    378
    379		if (pfrag->size > pfrag->offset)
    380			return 0;
    381	}
    382
    383	if (!sk_page_frag_refill(sk, pfrag))
    384		return -ENOMEM;
    385
    386	return 0;
    387}
    388
    389static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
    390{
    391	size_t pre_copy, nocache;
    392
    393	pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
    394	if (pre_copy) {
    395		pre_copy = min(pre_copy, bytes);
    396		if (copy_from_iter(addr, pre_copy, i) != pre_copy)
    397			return -EFAULT;
    398		bytes -= pre_copy;
    399		addr += pre_copy;
    400	}
    401
    402	nocache = round_down(bytes, SMP_CACHE_BYTES);
    403	if (copy_from_iter_nocache(addr, nocache, i) != nocache)
    404		return -EFAULT;
    405	bytes -= nocache;
    406	addr += nocache;
    407
    408	if (bytes && copy_from_iter(addr, bytes, i) != bytes)
    409		return -EFAULT;
    410
    411	return 0;
    412}
    413
    414union tls_iter_offset {
    415	struct iov_iter *msg_iter;
    416	int offset;
    417};
    418
    419static int tls_push_data(struct sock *sk,
    420			 union tls_iter_offset iter_offset,
    421			 size_t size, int flags,
    422			 unsigned char record_type,
    423			 struct page *zc_page)
    424{
    425	struct tls_context *tls_ctx = tls_get_ctx(sk);
    426	struct tls_prot_info *prot = &tls_ctx->prot_info;
    427	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
    428	struct tls_record_info *record;
    429	int tls_push_record_flags;
    430	struct page_frag *pfrag;
    431	size_t orig_size = size;
    432	u32 max_open_record_len;
    433	bool more = false;
    434	bool done = false;
    435	int copy, rc = 0;
    436	long timeo;
    437
    438	if (flags &
    439	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
    440		return -EOPNOTSUPP;
    441
    442	if (unlikely(sk->sk_err))
    443		return -sk->sk_err;
    444
    445	flags |= MSG_SENDPAGE_DECRYPTED;
    446	tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
    447
    448	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
    449	if (tls_is_partially_sent_record(tls_ctx)) {
    450		rc = tls_push_partial_record(sk, tls_ctx, flags);
    451		if (rc < 0)
    452			return rc;
    453	}
    454
    455	pfrag = sk_page_frag(sk);
    456
    457	/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
    458	 * we need to leave room for an authentication tag.
    459	 */
    460	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
    461			      prot->prepend_size;
    462	do {
    463		rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
    464		if (unlikely(rc)) {
    465			rc = sk_stream_wait_memory(sk, &timeo);
    466			if (!rc)
    467				continue;
    468
    469			record = ctx->open_record;
    470			if (!record)
    471				break;
    472handle_error:
    473			if (record_type != TLS_RECORD_TYPE_DATA) {
    474				/* avoid sending partial
    475				 * record with type !=
    476				 * application_data
    477				 */
    478				size = orig_size;
    479				destroy_record(record);
    480				ctx->open_record = NULL;
    481			} else if (record->len > prot->prepend_size) {
    482				goto last_record;
    483			}
    484
    485			break;
    486		}
    487
    488		record = ctx->open_record;
    489
    490		copy = min_t(size_t, size, max_open_record_len - record->len);
    491		if (copy && zc_page) {
    492			struct page_frag zc_pfrag;
    493
    494			zc_pfrag.page = zc_page;
    495			zc_pfrag.offset = iter_offset.offset;
    496			zc_pfrag.size = copy;
    497			tls_append_frag(record, &zc_pfrag, copy);
    498		} else if (copy) {
    499			copy = min_t(size_t, copy, pfrag->size - pfrag->offset);
    500
    501			rc = tls_device_copy_data(page_address(pfrag->page) +
    502						  pfrag->offset, copy,
    503						  iter_offset.msg_iter);
    504			if (rc)
    505				goto handle_error;
    506			tls_append_frag(record, pfrag, copy);
    507		}
    508
    509		size -= copy;
    510		if (!size) {
    511last_record:
    512			tls_push_record_flags = flags;
    513			if (flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE)) {
    514				more = true;
    515				break;
    516			}
    517
    518			done = true;
    519		}
    520
    521		if (done || record->len >= max_open_record_len ||
    522		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
    523			rc = tls_device_record_close(sk, tls_ctx, record,
    524						     pfrag, record_type);
    525			if (rc) {
    526				if (rc > 0) {
    527					size += rc;
    528				} else {
    529					size = orig_size;
    530					destroy_record(record);
    531					ctx->open_record = NULL;
    532					break;
    533				}
    534			}
    535
    536			rc = tls_push_record(sk,
    537					     tls_ctx,
    538					     ctx,
    539					     record,
    540					     tls_push_record_flags);
    541			if (rc < 0)
    542				break;
    543		}
    544	} while (!done);
    545
    546	tls_ctx->pending_open_record_frags = more;
    547
    548	if (orig_size - size > 0)
    549		rc = orig_size - size;
    550
    551	return rc;
    552}
    553
    554int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
    555{
    556	unsigned char record_type = TLS_RECORD_TYPE_DATA;
    557	struct tls_context *tls_ctx = tls_get_ctx(sk);
    558	union tls_iter_offset iter;
    559	int rc;
    560
    561	mutex_lock(&tls_ctx->tx_lock);
    562	lock_sock(sk);
    563
    564	if (unlikely(msg->msg_controllen)) {
    565		rc = tls_proccess_cmsg(sk, msg, &record_type);
    566		if (rc)
    567			goto out;
    568	}
    569
    570	iter.msg_iter = &msg->msg_iter;
    571	rc = tls_push_data(sk, iter, size, msg->msg_flags, record_type, NULL);
    572
    573out:
    574	release_sock(sk);
    575	mutex_unlock(&tls_ctx->tx_lock);
    576	return rc;
    577}
    578
    579int tls_device_sendpage(struct sock *sk, struct page *page,
    580			int offset, size_t size, int flags)
    581{
    582	struct tls_context *tls_ctx = tls_get_ctx(sk);
    583	union tls_iter_offset iter_offset;
    584	struct iov_iter msg_iter;
    585	char *kaddr;
    586	struct kvec iov;
    587	int rc;
    588
    589	if (flags & MSG_SENDPAGE_NOTLAST)
    590		flags |= MSG_MORE;
    591
    592	mutex_lock(&tls_ctx->tx_lock);
    593	lock_sock(sk);
    594
    595	if (flags & MSG_OOB) {
    596		rc = -EOPNOTSUPP;
    597		goto out;
    598	}
    599
    600	if (tls_ctx->zerocopy_sendfile) {
    601		iter_offset.offset = offset;
    602		rc = tls_push_data(sk, iter_offset, size,
    603				   flags, TLS_RECORD_TYPE_DATA, page);
    604		goto out;
    605	}
    606
    607	kaddr = kmap(page);
    608	iov.iov_base = kaddr + offset;
    609	iov.iov_len = size;
    610	iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
    611	iter_offset.msg_iter = &msg_iter;
    612	rc = tls_push_data(sk, iter_offset, size, flags, TLS_RECORD_TYPE_DATA,
    613			   NULL);
    614	kunmap(page);
    615
    616out:
    617	release_sock(sk);
    618	mutex_unlock(&tls_ctx->tx_lock);
    619	return rc;
    620}
    621
    622struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
    623				       u32 seq, u64 *p_record_sn)
    624{
    625	u64 record_sn = context->hint_record_sn;
    626	struct tls_record_info *info, *last;
    627
    628	info = context->retransmit_hint;
    629	if (!info ||
    630	    before(seq, info->end_seq - info->len)) {
    631		/* if retransmit_hint is irrelevant start
    632		 * from the beginning of the list
    633		 */
    634		info = list_first_entry_or_null(&context->records_list,
    635						struct tls_record_info, list);
    636		if (!info)
    637			return NULL;
    638		/* send the start_marker record if seq number is before the
    639		 * tls offload start marker sequence number. This record is
    640		 * required to handle TCP packets which are before TLS offload
    641		 * started.
    642		 *  And if it's not start marker, look if this seq number
    643		 * belongs to the list.
    644		 */
    645		if (likely(!tls_record_is_start_marker(info))) {
    646			/* we have the first record, get the last record to see
    647			 * if this seq number belongs to the list.
    648			 */
    649			last = list_last_entry(&context->records_list,
    650					       struct tls_record_info, list);
    651
    652			if (!between(seq, tls_record_start_seq(info),
    653				     last->end_seq))
    654				return NULL;
    655		}
    656		record_sn = context->unacked_record_sn;
    657	}
    658
    659	/* We just need the _rcu for the READ_ONCE() */
    660	rcu_read_lock();
    661	list_for_each_entry_from_rcu(info, &context->records_list, list) {
    662		if (before(seq, info->end_seq)) {
    663			if (!context->retransmit_hint ||
    664			    after(info->end_seq,
    665				  context->retransmit_hint->end_seq)) {
    666				context->hint_record_sn = record_sn;
    667				context->retransmit_hint = info;
    668			}
    669			*p_record_sn = record_sn;
    670			goto exit_rcu_unlock;
    671		}
    672		record_sn++;
    673	}
    674	info = NULL;
    675
    676exit_rcu_unlock:
    677	rcu_read_unlock();
    678	return info;
    679}
    680EXPORT_SYMBOL(tls_get_record);
    681
    682static int tls_device_push_pending_record(struct sock *sk, int flags)
    683{
    684	union tls_iter_offset iter;
    685	struct iov_iter msg_iter;
    686
    687	iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
    688	iter.msg_iter = &msg_iter;
    689	return tls_push_data(sk, iter, 0, flags, TLS_RECORD_TYPE_DATA, NULL);
    690}
    691
    692void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
    693{
    694	if (tls_is_partially_sent_record(ctx)) {
    695		gfp_t sk_allocation = sk->sk_allocation;
    696
    697		WARN_ON_ONCE(sk->sk_write_pending);
    698
    699		sk->sk_allocation = GFP_ATOMIC;
    700		tls_push_partial_record(sk, ctx,
    701					MSG_DONTWAIT | MSG_NOSIGNAL |
    702					MSG_SENDPAGE_DECRYPTED);
    703		sk->sk_allocation = sk_allocation;
    704	}
    705}
    706
    707static void tls_device_resync_rx(struct tls_context *tls_ctx,
    708				 struct sock *sk, u32 seq, u8 *rcd_sn)
    709{
    710	struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
    711	struct net_device *netdev;
    712
    713	trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
    714	rcu_read_lock();
    715	netdev = READ_ONCE(tls_ctx->netdev);
    716	if (netdev)
    717		netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
    718						   TLS_OFFLOAD_CTX_DIR_RX);
    719	rcu_read_unlock();
    720	TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
    721}
    722
    723static bool
    724tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
    725			   s64 resync_req, u32 *seq, u16 *rcd_delta)
    726{
    727	u32 is_async = resync_req & RESYNC_REQ_ASYNC;
    728	u32 req_seq = resync_req >> 32;
    729	u32 req_end = req_seq + ((resync_req >> 16) & 0xffff);
    730	u16 i;
    731
    732	*rcd_delta = 0;
    733
    734	if (is_async) {
    735		/* shouldn't get to wraparound:
    736		 * too long in async stage, something bad happened
    737		 */
    738		if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
    739			return false;
    740
    741		/* asynchronous stage: log all headers seq such that
    742		 * req_seq <= seq <= end_seq, and wait for real resync request
    743		 */
    744		if (before(*seq, req_seq))
    745			return false;
    746		if (!after(*seq, req_end) &&
    747		    resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX)
    748			resync_async->log[resync_async->loglen++] = *seq;
    749
    750		resync_async->rcd_delta++;
    751
    752		return false;
    753	}
    754
    755	/* synchronous stage: check against the logged entries and
    756	 * proceed to check the next entries if no match was found
    757	 */
    758	for (i = 0; i < resync_async->loglen; i++)
    759		if (req_seq == resync_async->log[i] &&
    760		    atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) {
    761			*rcd_delta = resync_async->rcd_delta - i;
    762			*seq = req_seq;
    763			resync_async->loglen = 0;
    764			resync_async->rcd_delta = 0;
    765			return true;
    766		}
    767
    768	resync_async->loglen = 0;
    769	resync_async->rcd_delta = 0;
    770
    771	if (req_seq == *seq &&
    772	    atomic64_try_cmpxchg(&resync_async->req,
    773				 &resync_req, 0))
    774		return true;
    775
    776	return false;
    777}
    778
    779void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
    780{
    781	struct tls_context *tls_ctx = tls_get_ctx(sk);
    782	struct tls_offload_context_rx *rx_ctx;
    783	u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
    784	u32 sock_data, is_req_pending;
    785	struct tls_prot_info *prot;
    786	s64 resync_req;
    787	u16 rcd_delta;
    788	u32 req_seq;
    789
    790	if (tls_ctx->rx_conf != TLS_HW)
    791		return;
    792	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
    793		return;
    794
    795	prot = &tls_ctx->prot_info;
    796	rx_ctx = tls_offload_ctx_rx(tls_ctx);
    797	memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
    798
    799	switch (rx_ctx->resync_type) {
    800	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
    801		resync_req = atomic64_read(&rx_ctx->resync_req);
    802		req_seq = resync_req >> 32;
    803		seq += TLS_HEADER_SIZE - 1;
    804		is_req_pending = resync_req;
    805
    806		if (likely(!is_req_pending) || req_seq != seq ||
    807		    !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
    808			return;
    809		break;
    810	case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
    811		if (likely(!rx_ctx->resync_nh_do_now))
    812			return;
    813
    814		/* head of next rec is already in, note that the sock_inq will
    815		 * include the currently parsed message when called from parser
    816		 */
    817		sock_data = tcp_inq(sk);
    818		if (sock_data > rcd_len) {
    819			trace_tls_device_rx_resync_nh_delay(sk, sock_data,
    820							    rcd_len);
    821			return;
    822		}
    823
    824		rx_ctx->resync_nh_do_now = 0;
    825		seq += rcd_len;
    826		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
    827		break;
    828	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC:
    829		resync_req = atomic64_read(&rx_ctx->resync_async->req);
    830		is_req_pending = resync_req;
    831		if (likely(!is_req_pending))
    832			return;
    833
    834		if (!tls_device_rx_resync_async(rx_ctx->resync_async,
    835						resync_req, &seq, &rcd_delta))
    836			return;
    837		tls_bigint_subtract(rcd_sn, rcd_delta);
    838		break;
    839	}
    840
    841	tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
    842}
    843
    844static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
    845					   struct tls_offload_context_rx *ctx,
    846					   struct sock *sk, struct sk_buff *skb)
    847{
    848	struct strp_msg *rxm;
    849
    850	/* device will request resyncs by itself based on stream scan */
    851	if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
    852		return;
    853	/* already scheduled */
    854	if (ctx->resync_nh_do_now)
    855		return;
    856	/* seen decrypted fragments since last fully-failed record */
    857	if (ctx->resync_nh_reset) {
    858		ctx->resync_nh_reset = 0;
    859		ctx->resync_nh.decrypted_failed = 1;
    860		ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
    861		return;
    862	}
    863
    864	if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
    865		return;
    866
    867	/* doing resync, bump the next target in case it fails */
    868	if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
    869		ctx->resync_nh.decrypted_tgt *= 2;
    870	else
    871		ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;
    872
    873	rxm = strp_msg(skb);
    874
    875	/* head of next rec is already in, parser will sync for us */
    876	if (tcp_inq(sk) > rxm->full_len) {
    877		trace_tls_device_rx_resync_nh_schedule(sk);
    878		ctx->resync_nh_do_now = 1;
    879	} else {
    880		struct tls_prot_info *prot = &tls_ctx->prot_info;
    881		u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
    882
    883		memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
    884		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
    885
    886		tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
    887				     rcd_sn);
    888	}
    889}
    890
    891static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
    892{
    893	struct strp_msg *rxm = strp_msg(skb);
    894	int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
    895	struct sk_buff *skb_iter, *unused;
    896	struct scatterlist sg[1];
    897	char *orig_buf, *buf;
    898
    899	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
    900			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
    901	if (!orig_buf)
    902		return -ENOMEM;
    903	buf = orig_buf;
    904
    905	nsg = skb_cow_data(skb, 0, &unused);
    906	if (unlikely(nsg < 0)) {
    907		err = nsg;
    908		goto free_buf;
    909	}
    910
    911	sg_init_table(sg, 1);
    912	sg_set_buf(&sg[0], buf,
    913		   rxm->full_len + TLS_HEADER_SIZE +
    914		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
    915	err = skb_copy_bits(skb, offset, buf,
    916			    TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
    917	if (err)
    918		goto free_buf;
    919
    920	/* We are interested only in the decrypted data not the auth */
    921	err = decrypt_skb(sk, skb, sg);
    922	if (err != -EBADMSG)
    923		goto free_buf;
    924	else
    925		err = 0;
    926
    927	data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
    928
    929	if (skb_pagelen(skb) > offset) {
    930		copy = min_t(int, skb_pagelen(skb) - offset, data_len);
    931
    932		if (skb->decrypted) {
    933			err = skb_store_bits(skb, offset, buf, copy);
    934			if (err)
    935				goto free_buf;
    936		}
    937
    938		offset += copy;
    939		buf += copy;
    940	}
    941
    942	pos = skb_pagelen(skb);
    943	skb_walk_frags(skb, skb_iter) {
    944		int frag_pos;
    945
    946		/* Practically all frags must belong to msg if reencrypt
    947		 * is needed with current strparser and coalescing logic,
    948		 * but strparser may "get optimized", so let's be safe.
    949		 */
    950		if (pos + skb_iter->len <= offset)
    951			goto done_with_frag;
    952		if (pos >= data_len + rxm->offset)
    953			break;
    954
    955		frag_pos = offset - pos;
    956		copy = min_t(int, skb_iter->len - frag_pos,
    957			     data_len + rxm->offset - offset);
    958
    959		if (skb_iter->decrypted) {
    960			err = skb_store_bits(skb_iter, frag_pos, buf, copy);
    961			if (err)
    962				goto free_buf;
    963		}
    964
    965		offset += copy;
    966		buf += copy;
    967done_with_frag:
    968		pos += skb_iter->len;
    969	}
    970
    971free_buf:
    972	kfree(orig_buf);
    973	return err;
    974}
    975
    976int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
    977			 struct sk_buff *skb, struct strp_msg *rxm)
    978{
    979	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
    980	int is_decrypted = skb->decrypted;
    981	int is_encrypted = !is_decrypted;
    982	struct sk_buff *skb_iter;
    983
    984	/* Check if all the data is decrypted already */
    985	skb_walk_frags(skb, skb_iter) {
    986		is_decrypted &= skb_iter->decrypted;
    987		is_encrypted &= !skb_iter->decrypted;
    988	}
    989
    990	trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
    991				   tls_ctx->rx.rec_seq, rxm->full_len,
    992				   is_encrypted, is_decrypted);
    993
    994	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
    995		if (likely(is_encrypted || is_decrypted))
    996			return is_decrypted;
    997
    998		/* After tls_device_down disables the offload, the next SKB will
    999		 * likely have initial fragments decrypted, and final ones not
   1000		 * decrypted. We need to reencrypt that single SKB.
   1001		 */
   1002		return tls_device_reencrypt(sk, skb);
   1003	}
   1004
   1005	/* Return immediately if the record is either entirely plaintext or
    1006	 * entirely ciphertext. Otherwise reencrypt the partially decrypted
    1007	 * record.
   1008	 */
   1009	if (is_decrypted) {
   1010		ctx->resync_nh_reset = 1;
   1011		return is_decrypted;
   1012	}
   1013	if (is_encrypted) {
   1014		tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
   1015		return 0;
   1016	}
   1017
   1018	ctx->resync_nh_reset = 1;
   1019	return tls_device_reencrypt(sk, skb);
   1020}
   1021
   1022static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
   1023			      struct net_device *netdev)
   1024{
   1025	if (sk->sk_destruct != tls_device_sk_destruct) {
   1026		refcount_set(&ctx->refcount, 1);
   1027		dev_hold(netdev);
   1028		ctx->netdev = netdev;
   1029		spin_lock_irq(&tls_device_lock);
   1030		list_add_tail(&ctx->list, &tls_device_list);
   1031		spin_unlock_irq(&tls_device_lock);
   1032
   1033		ctx->sk_destruct = sk->sk_destruct;
   1034		smp_store_release(&sk->sk_destruct, tls_device_sk_destruct);
   1035	}
   1036}
   1037
   1038int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
   1039{
   1040	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
   1041	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1042	struct tls_prot_info *prot = &tls_ctx->prot_info;
   1043	struct tls_record_info *start_marker_record;
   1044	struct tls_offload_context_tx *offload_ctx;
   1045	struct tls_crypto_info *crypto_info;
   1046	struct net_device *netdev;
   1047	char *iv, *rec_seq;
   1048	struct sk_buff *skb;
   1049	__be64 rcd_sn;
   1050	int rc;
   1051
   1052	if (!ctx)
   1053		return -EINVAL;
   1054
   1055	if (ctx->priv_ctx_tx)
   1056		return -EEXIST;
   1057
   1058	netdev = get_netdev_for_sock(sk);
   1059	if (!netdev) {
   1060		pr_err_ratelimited("%s: netdev not found\n", __func__);
   1061		return -EINVAL;
   1062	}
   1063
   1064	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
   1065		rc = -EOPNOTSUPP;
   1066		goto release_netdev;
   1067	}
   1068
   1069	crypto_info = &ctx->crypto_send.info;
   1070	if (crypto_info->version != TLS_1_2_VERSION) {
   1071		rc = -EOPNOTSUPP;
   1072		goto release_netdev;
   1073	}
   1074
   1075	switch (crypto_info->cipher_type) {
   1076	case TLS_CIPHER_AES_GCM_128:
   1077		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
   1078		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
   1079		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
   1080		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
   1081		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
   1082		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
   1083		rec_seq =
   1084		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
   1085		break;
   1086	default:
   1087		rc = -EINVAL;
   1088		goto release_netdev;
   1089	}
   1090
   1091	/* Sanity-check the rec_seq_size for stack allocations */
   1092	if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
   1093		rc = -EINVAL;
   1094		goto release_netdev;
   1095	}
   1096
   1097	prot->version = crypto_info->version;
   1098	prot->cipher_type = crypto_info->cipher_type;
   1099	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
   1100	prot->tag_size = tag_size;
   1101	prot->overhead_size = prot->prepend_size + prot->tag_size;
   1102	prot->iv_size = iv_size;
   1103	prot->salt_size = salt_size;
   1104	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
   1105			     GFP_KERNEL);
   1106	if (!ctx->tx.iv) {
   1107		rc = -ENOMEM;
   1108		goto release_netdev;
   1109	}
   1110
   1111	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
   1112
   1113	prot->rec_seq_size = rec_seq_size;
   1114	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
   1115	if (!ctx->tx.rec_seq) {
   1116		rc = -ENOMEM;
   1117		goto free_iv;
   1118	}
   1119
   1120	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
   1121	if (!start_marker_record) {
   1122		rc = -ENOMEM;
   1123		goto free_rec_seq;
   1124	}
   1125
   1126	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
   1127	if (!offload_ctx) {
   1128		rc = -ENOMEM;
   1129		goto free_marker_record;
   1130	}
   1131
   1132	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
   1133	if (rc)
   1134		goto free_offload_ctx;
   1135
   1136	/* start at rec_seq - 1 to account for the start marker record */
   1137	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
   1138	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
   1139
   1140	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
   1141	start_marker_record->len = 0;
   1142	start_marker_record->num_frags = 0;
   1143
   1144	INIT_LIST_HEAD(&offload_ctx->records_list);
   1145	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
   1146	spin_lock_init(&offload_ctx->lock);
   1147	sg_init_table(offload_ctx->sg_tx_data,
   1148		      ARRAY_SIZE(offload_ctx->sg_tx_data));
   1149
   1150	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
   1151	ctx->push_pending_record = tls_device_push_pending_record;
   1152
   1153	/* TLS offload is greatly simplified if we don't send
   1154	 * SKBs where only part of the payload needs to be encrypted.
   1155	 * So mark the last skb in the write queue as end of record.
   1156	 */
   1157	skb = tcp_write_queue_tail(sk);
   1158	if (skb)
   1159		TCP_SKB_CB(skb)->eor = 1;
   1160
   1161	/* Avoid offloading if the device is down
   1162	 * We don't want to offload new flows after
   1163	 * the NETDEV_DOWN event
   1164	 *
    1165	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
   1166	 * handler thus protecting from the device going down before
   1167	 * ctx was added to tls_device_list.
   1168	 */
   1169	down_read(&device_offload_lock);
   1170	if (!(netdev->flags & IFF_UP)) {
   1171		rc = -EINVAL;
   1172		goto release_lock;
   1173	}
   1174
   1175	ctx->priv_ctx_tx = offload_ctx;
   1176	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
   1177					     &ctx->crypto_send.info,
   1178					     tcp_sk(sk)->write_seq);
   1179	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX,
   1180				     tcp_sk(sk)->write_seq, rec_seq, rc);
   1181	if (rc)
   1182		goto release_lock;
   1183
   1184	tls_device_attach(ctx, sk, netdev);
   1185	up_read(&device_offload_lock);
   1186
   1187	/* following this assignment tls_is_sk_tx_device_offloaded
   1188	 * will return true and the context might be accessed
   1189	 * by the netdev's xmit function.
   1190	 */
   1191	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
   1192	dev_put(netdev);
   1193
   1194	return 0;
   1195
   1196release_lock:
   1197	up_read(&device_offload_lock);
   1198	clean_acked_data_disable(inet_csk(sk));
   1199	crypto_free_aead(offload_ctx->aead_send);
   1200free_offload_ctx:
   1201	kfree(offload_ctx);
   1202	ctx->priv_ctx_tx = NULL;
   1203free_marker_record:
   1204	kfree(start_marker_record);
   1205free_rec_seq:
   1206	kfree(ctx->tx.rec_seq);
   1207free_iv:
   1208	kfree(ctx->tx.iv);
   1209release_netdev:
   1210	dev_put(netdev);
   1211	return rc;
   1212}
   1213
   1214int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
   1215{
   1216	struct tls12_crypto_info_aes_gcm_128 *info;
   1217	struct tls_offload_context_rx *context;
   1218	struct net_device *netdev;
   1219	int rc = 0;
   1220
   1221	if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
   1222		return -EOPNOTSUPP;
   1223
   1224	netdev = get_netdev_for_sock(sk);
   1225	if (!netdev) {
   1226		pr_err_ratelimited("%s: netdev not found\n", __func__);
   1227		return -EINVAL;
   1228	}
   1229
   1230	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
   1231		rc = -EOPNOTSUPP;
   1232		goto release_netdev;
   1233	}
   1234
   1235	/* Avoid offloading if the device is down
   1236	 * We don't want to offload new flows after
   1237	 * the NETDEV_DOWN event
   1238	 *
    1239	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
   1240	 * handler thus protecting from the device going down before
   1241	 * ctx was added to tls_device_list.
   1242	 */
   1243	down_read(&device_offload_lock);
   1244	if (!(netdev->flags & IFF_UP)) {
   1245		rc = -EINVAL;
   1246		goto release_lock;
   1247	}
   1248
   1249	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
   1250	if (!context) {
   1251		rc = -ENOMEM;
   1252		goto release_lock;
   1253	}
   1254	context->resync_nh_reset = 1;
   1255
   1256	ctx->priv_ctx_rx = context;
   1257	rc = tls_set_sw_offload(sk, ctx, 0);
   1258	if (rc)
   1259		goto release_ctx;
   1260
   1261	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
   1262					     &ctx->crypto_recv.info,
   1263					     tcp_sk(sk)->copied_seq);
   1264	info = (void *)&ctx->crypto_recv.info;
   1265	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX,
   1266				     tcp_sk(sk)->copied_seq, info->rec_seq, rc);
   1267	if (rc)
   1268		goto free_sw_resources;
   1269
   1270	tls_device_attach(ctx, sk, netdev);
   1271	up_read(&device_offload_lock);
   1272
   1273	dev_put(netdev);
   1274
   1275	return 0;
   1276
   1277free_sw_resources:
   1278	up_read(&device_offload_lock);
   1279	tls_sw_free_resources_rx(sk);
   1280	down_read(&device_offload_lock);
   1281release_ctx:
   1282	ctx->priv_ctx_rx = NULL;
   1283release_lock:
   1284	up_read(&device_offload_lock);
   1285release_netdev:
   1286	dev_put(netdev);
   1287	return rc;
   1288}
   1289
   1290void tls_device_offload_cleanup_rx(struct sock *sk)
   1291{
   1292	struct tls_context *tls_ctx = tls_get_ctx(sk);
   1293	struct net_device *netdev;
   1294
   1295	down_read(&device_offload_lock);
   1296	netdev = tls_ctx->netdev;
   1297	if (!netdev)
   1298		goto out;
   1299
   1300	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
   1301					TLS_OFFLOAD_CTX_DIR_RX);
   1302
   1303	if (tls_ctx->tx_conf != TLS_HW) {
   1304		dev_put(netdev);
   1305		tls_ctx->netdev = NULL;
   1306	} else {
   1307		set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
   1308	}
   1309out:
   1310	up_read(&device_offload_lock);
   1311	tls_sw_release_resources_rx(sk);
   1312}
   1313
   1314static int tls_device_down(struct net_device *netdev)
   1315{
   1316	struct tls_context *ctx, *tmp;
   1317	unsigned long flags;
   1318	LIST_HEAD(list);
   1319
   1320	/* Request a write lock to block new offload attempts */
   1321	down_write(&device_offload_lock);
   1322
   1323	spin_lock_irqsave(&tls_device_lock, flags);
   1324	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
   1325		if (ctx->netdev != netdev ||
   1326		    !refcount_inc_not_zero(&ctx->refcount))
   1327			continue;
   1328
   1329		list_move(&ctx->list, &list);
   1330	}
   1331	spin_unlock_irqrestore(&tls_device_lock, flags);
   1332
   1333	list_for_each_entry_safe(ctx, tmp, &list, list)	{
   1334		/* Stop offloaded TX and switch to the fallback.
   1335		 * tls_is_sk_tx_device_offloaded will return false.
   1336		 */
   1337		WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);
   1338
   1339		/* Stop the RX and TX resync.
   1340		 * tls_dev_resync must not be called after tls_dev_del.
   1341		 */
   1342		WRITE_ONCE(ctx->netdev, NULL);
   1343
   1344		/* Start skipping the RX resync logic completely. */
   1345		set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
   1346
   1347		/* Sync with inflight packets. After this point:
   1348		 * TX: no non-encrypted packets will be passed to the driver.
   1349		 * RX: resync requests from the driver will be ignored.
   1350		 */
   1351		synchronize_net();
   1352
   1353		/* Release the offload context on the driver side. */
   1354		if (ctx->tx_conf == TLS_HW)
   1355			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
   1356							TLS_OFFLOAD_CTX_DIR_TX);
   1357		if (ctx->rx_conf == TLS_HW &&
   1358		    !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
   1359			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
   1360							TLS_OFFLOAD_CTX_DIR_RX);
   1361
   1362		dev_put(netdev);
   1363
   1364		/* Move the context to a separate list for two reasons:
   1365		 * 1. When the context is deallocated, list_del is called.
   1366		 * 2. It's no longer an offloaded context, so we don't want to
   1367		 *    run offload-specific code on this context.
   1368		 */
   1369		spin_lock_irqsave(&tls_device_lock, flags);
   1370		list_move_tail(&ctx->list, &tls_device_down_list);
   1371		spin_unlock_irqrestore(&tls_device_lock, flags);
   1372
    1373	/* Device contexts for RX and TX will be freed on sk_destruct
   1374		 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
   1375		 * Now release the ref taken above.
   1376		 */
   1377		if (refcount_dec_and_test(&ctx->refcount))
   1378			tls_device_free_ctx(ctx);
   1379	}
   1380
   1381	up_write(&device_offload_lock);
   1382
   1383	flush_work(&tls_device_gc_work);
   1384
   1385	return NOTIFY_DONE;
   1386}
   1387
   1388static int tls_dev_event(struct notifier_block *this, unsigned long event,
   1389			 void *ptr)
   1390{
   1391	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
   1392
   1393	if (!dev->tlsdev_ops &&
   1394	    !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
   1395		return NOTIFY_DONE;
   1396
   1397	switch (event) {
   1398	case NETDEV_REGISTER:
   1399	case NETDEV_FEAT_CHANGE:
   1400		if (netif_is_bond_master(dev))
   1401			return NOTIFY_DONE;
   1402		if ((dev->features & NETIF_F_HW_TLS_RX) &&
   1403		    !dev->tlsdev_ops->tls_dev_resync)
   1404			return NOTIFY_BAD;
   1405
   1406		if  (dev->tlsdev_ops &&
   1407		     dev->tlsdev_ops->tls_dev_add &&
   1408		     dev->tlsdev_ops->tls_dev_del)
   1409			return NOTIFY_DONE;
   1410		else
   1411			return NOTIFY_BAD;
   1412	case NETDEV_DOWN:
   1413		return tls_device_down(dev);
   1414	}
   1415	return NOTIFY_DONE;
   1416}
   1417
   1418static struct notifier_block tls_dev_notifier = {
   1419	.notifier_call	= tls_dev_event,
   1420};
   1421
   1422void __init tls_device_init(void)
   1423{
   1424	register_netdevice_notifier(&tls_dev_notifier);
   1425}
   1426
   1427void __exit tls_device_cleanup(void)
   1428{
   1429	unregister_netdevice_notifier(&tls_dev_notifier);
   1430	flush_work(&tls_device_gc_work);
   1431	clean_acked_data_flush();
   1432}