cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

tls_main.c (26565B)


/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

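/* Swap in the TLS variants of sk_prot and the proto_ops that match the
 * context's current tx/rx configuration. WRITE_ONCE() pairs with lockless
 * readers of sk->sk_prot elsewhere in the stack.
 */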
void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

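/* Push a scatterlist of record data into the TCP layer via
 * do_tcp_sendpages(). On a partial send, the remaining scatterlist and
 * offset are parked in the context so transmission can be resumed later
 * by tls_push_partial_record().
 */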
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->in_tcp_sendpages = true;
	while (1) {
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->in_tcp_sendpages = false;

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

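/* Parse SOL_TLS control messages passed alongside sendmsg(). Only
 * TLS_SET_RECORD_TYPE is recognized; any pending open record is pushed
 * out before the record type changes.
 */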
int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
		      unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

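/* Resume transmission of a record that tls_push_sg() could only push
 * partially into the TCP layer.
 */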
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags)
{
	struct scatterlist *sg;
	u16 offset;

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
	struct scatterlist *sg;

	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
		put_page(sg_page(sg));
		sk_mem_uncharge(sk, sg->length);
	}
	ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If in_tcp_sendpages is set, call the lower protocol's write space
	 * handler to ensure any operations waiting there are woken up, for
	 * example if do_tcp_sendpages were to call sk_wait_event.
	 */
	if (ctx->in_tcp_sendpages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL, the caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

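/* Release tx/rx resources according to how each direction was configured
 * (software vs. device offload), first flushing any open record if a
 * writer is still pending.
 */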
static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}

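/* close() handler installed by the TLS protos: tear down TLS state,
 * restore the original sk_prot and write-space callback, then hand off
 * to the underlying transport's close().
 */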
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}

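/* TLS_TX/TLS_RX getsockopt(): copy the crypto parameters for one
 * direction back to userspace. A buffer of sizeof(struct tls_crypto_info)
 * yields just version and cipher; a full cipher-specific buffer also
 * receives the current IV and record sequence number.
 */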
static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *
		  crypto_info_aes_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
			container_of(crypto_info,
				struct tls12_crypto_info_aes_ccm_128, info);

		if (len != sizeof(*aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
			container_of(crypto_info,
				struct tls12_crypto_info_chacha20_poly1305,
				info);

		if (len != sizeof(*chacha20_poly1305)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(chacha20_poly1305->iv,
		       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
		       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
		memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, chacha20_poly1305,
				sizeof(*chacha20_poly1305)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
			container_of(crypto_info,
				struct tls12_crypto_info_sm4_gcm, info);

		if (len != sizeof(*sm4_gcm_info)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(sm4_gcm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
		       TLS_CIPHER_SM4_GCM_IV_SIZE);
		memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
			container_of(crypto_info,
				struct tls12_crypto_info_sm4_ccm, info);

		if (len != sizeof(*sm4_ccm_info)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(sm4_ccm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
		       TLS_CIPHER_SM4_CCM_IV_SIZE);
		memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}

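/* TLS_TX_ZEROCOPY_RO getsockopt(): report whether zerocopy sendfile is
 * enabled on this context.
 */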
static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
				   int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len != sizeof(value))
		return -EINVAL;

	value = ctx->zerocopy_sendfile;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		rc = do_tls_getsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		break;
	case TLS_TX_ZEROCOPY_RO:
		rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->getsockopt(sk, level,
						 optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

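/* TLS_TX/TLS_RX setsockopt(): validate the crypto info from userspace and
 * install it for one direction. Device offload is attempted first; if the
 * NIC cannot take it, the software implementation is used as a fallback.
 * On failure, the copied-in key material is wiped.
 */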
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info))
		return -EBUSY;

	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and cipher are the same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	case TLS_CIPHER_CHACHA20_POLY1305:
		optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
		break;
	case TLS_CIPHER_SM4_GCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
		break;
	case TLS_CIPHER_SM4_CCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	}
	return 0;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
	return rc;
}

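/* TLS_TX_ZEROCOPY_RO setsockopt(): accept a 0/1 flag enabling zerocopy
 * sendfile. The "RO" in the name reflects that the sent pages are then
 * expected to remain unmodified (read-only) while in flight.
 */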
static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
				   unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value > 1)
		return -EINVAL;

	ctx->zerocopy_sendfile = value;

	return 0;
}

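/* For reference, userspace enables kTLS roughly as follows (a sketch; the
 * key material comes from a TLS handshake performed in userspace):
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	(.iv, .key, .salt and .rec_seq are filled from the handshake)
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 */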
static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	case TLS_TX_ZEROCOPY_RO:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
		release_sock(sk);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
						 optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

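/* Allocate a TLS context and attach it to the socket's ULP data, saving
 * the original proto so it can be restored on close. GFP_ATOMIC because
 * tls_init() calls this under the socket callback lock.
 */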
struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	return ctx;
}

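/* Fill the [tx_conf][rx_conf] matrix of proto_ops: each cell starts from
 * the base TCP ops and overrides sendpage_locked/splice_read where a
 * software path has to intercept them.
 */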
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].sendpage_locked	= tls_sw_sendpage_locked;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_HW  ][TLS_BASE].sendpage_locked	= NULL;

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
	ops[TLS_HW  ][TLS_SW  ].sendpage_locked	= NULL;

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
	ops[TLS_HW  ][TLS_HW  ].sendpage_locked	= NULL;
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

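/* Lazily (re)build the per-family proto and proto_ops tables the first
 * time a TLS socket is created, and again whenever the underlying
 * tcpv4/tcpv6 proto address changes.
 */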
static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Rebuild the IPv6 TLS protos whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}

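/* Fill the [tx_conf][rx_conf] matrix of protos: each cell starts from the
 * base TCP proto and overrides sendmsg/sendpage/recvmsg with the software
 * or device-offload handlers that configuration requires.
 */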
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
#endif
}

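/* ULP init hook, invoked via setsockopt(SOL_TCP, TCP_ULP, "tls"):
 * allocate the context and install the TLS_BASE/TLS_BASE protos. Only
 * established TCP sockets are accepted.
 */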
static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}

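/* ULP update hook, called when another subsystem (e.g. sockmap) swaps
 * sk->sk_prot: record the new proto and write_space so close() can still
 * restore them.
 */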
static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}

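/* inet_diag hook: export TLS version, cipher and tx/rx configuration as a
 * nested INET_ULP_INFO_TLS attribute for socket diag consumers such as ss.
 */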
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
	u16 version, cipher_type;
	struct tls_context *ctx;
	struct nlattr *start;
	int err;

	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
	if (!start)
		return -EMSGSIZE;

	rcu_read_lock();
	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
	if (!ctx) {
		err = 0;
		goto nla_failure;
	}
	version = ctx->prot_info.version;
	if (version) {
		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
		if (err)
			goto nla_failure;
	}
	cipher_type = ctx->prot_info.cipher_type;
	if (cipher_type) {
		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
		if (err)
			goto nla_failure;
	}
	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
	if (err)
		goto nla_failure;

	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
	if (err)
		goto nla_failure;

	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
		if (err)
			goto nla_failure;
	}

	rcu_read_unlock();
	nla_nest_end(skb, start);
	return 0;

nla_failure:
	rcu_read_unlock();
	nla_nest_cancel(skb, start);
	return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
	size_t size = 0;

	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
		nla_total_size(0) +		/* TLS_INFO_ZC_RO_TX */
		0;

	return size;
}

static int __net_init tls_init_net(struct net *net)
{
	int err;

	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
	if (!net->mib.tls_statistics)
		return -ENOMEM;

	err = tls_proc_init(net);
	if (err)
		goto err_free_stats;

	return 0;
err_free_stats:
	free_percpu(net->mib.tls_statistics);
	return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};

static int __init tls_register(void)
{
	int err;

	err = register_pernet_subsys(&tls_proc_ops);
	if (err)
		return err;

	tls_device_init();
	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);