cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

vsock.c (25220B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * vhost transport for vsock
      4 *
      5 * Copyright (C) 2013-2015 Red Hat, Inc.
      6 * Author: Asias He <asias@redhat.com>
      7 *         Stefan Hajnoczi <stefanha@redhat.com>
      8 */
      9#include <linux/miscdevice.h>
     10#include <linux/atomic.h>
     11#include <linux/module.h>
     12#include <linux/mutex.h>
     13#include <linux/vmalloc.h>
     14#include <net/sock.h>
     15#include <linux/virtio_vsock.h>
     16#include <linux/vhost.h>
     17#include <linux/hashtable.h>
     18
     19#include <net/af_vsock.h>
     20#include "vhost.h"
     21
     22#define VHOST_VSOCK_DEFAULT_HOST_CID	2
     23/* Max number of bytes transferred before requeueing the job.
     24 * Using this limit prevents one virtqueue from starving others. */
     25#define VHOST_VSOCK_WEIGHT 0x80000
     26/* Max number of packets transferred before requeueing the job.
     27 * Using this limit prevents one virtqueue from starving others with
     28 * small pkts.
     29 */
     30#define VHOST_VSOCK_PKT_WEIGHT 256
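        /* In concrete terms: 0x80000 bytes is 512 KiB, so assuming 4 KiB rx
         * buffers (a typical size, not something this code mandates) a single
         * worker pass is requeued after at most ~128 full buffers, and after
         * 256 packets when the packets are small.
         */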
     31
     32enum {
     33	VHOST_VSOCK_FEATURES = VHOST_FEATURES |
     34			       (1ULL << VIRTIO_F_ACCESS_PLATFORM) |
     35			       (1ULL << VIRTIO_VSOCK_F_SEQPACKET)
     36};
     37
     38enum {
     39	VHOST_VSOCK_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
     40};
     41
     42/* Used to track all the vhost_vsock instances on the system. */
     43static DEFINE_MUTEX(vhost_vsock_mutex);
     44static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
     45
     46struct vhost_vsock {
     47	struct vhost_dev dev;
     48	struct vhost_virtqueue vqs[2];
     49
     50	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
     51	struct hlist_node hash;
     52
     53	struct vhost_work send_pkt_work;
     54	spinlock_t send_pkt_list_lock;
     55	struct list_head send_pkt_list;	/* host->guest pending packets */
     56
     57	atomic_t queued_replies;
     58
     59	u32 guest_cid;
     60	bool seqpacket_allow;
     61};
     62
     63static u32 vhost_transport_get_local_cid(void)
     64{
     65	return VHOST_VSOCK_DEFAULT_HOST_CID;
     66}
     67
     68/* Callers that dereference the return value must hold vhost_vsock_mutex or the
     69 * RCU read lock.
     70 */
     71static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
     72{
     73	struct vhost_vsock *vsock;
     74
     75	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
     76		u32 other_cid = vsock->guest_cid;
     77
     78		/* Skip instances that have no CID yet */
     79		if (other_cid == 0)
     80			continue;
     81
     82		if (other_cid == guest_cid)
     83			return vsock;
     84
     85	}
     86
     87	return NULL;
     88}
     89
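        /* Drain vsock->send_pkt_list into the guest-visible virtqueue @vq,
         * splitting a packet across several rx buffers when it does not fit
         * into one.  The loop stops early when the vq runs out of buffers or
         * the vhost weight limits are hit, and the TX virtqueue is re-kicked
         * if freeing reply packets made room for more guest requests.
         */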
     90static void
     91vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
     92			    struct vhost_virtqueue *vq)
     93{
     94	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
     95	int pkts = 0, total_len = 0;
     96	bool added = false;
     97	bool restart_tx = false;
     98
     99	mutex_lock(&vq->mutex);
    100
    101	if (!vhost_vq_get_backend(vq))
    102		goto out;
    103
    104	if (!vq_meta_prefetch(vq))
    105		goto out;
    106
    107	/* Avoid further vmexits, we're already processing the virtqueue */
    108	vhost_disable_notify(&vsock->dev, vq);
    109
    110	do {
    111		struct virtio_vsock_pkt *pkt;
    112		struct iov_iter iov_iter;
    113		unsigned out, in;
    114		size_t nbytes;
    115		size_t iov_len, payload_len;
    116		int head;
    117		u32 flags_to_restore = 0;
    118
    119		spin_lock_bh(&vsock->send_pkt_list_lock);
    120		if (list_empty(&vsock->send_pkt_list)) {
    121			spin_unlock_bh(&vsock->send_pkt_list_lock);
    122			vhost_enable_notify(&vsock->dev, vq);
    123			break;
    124		}
    125
    126		pkt = list_first_entry(&vsock->send_pkt_list,
    127				       struct virtio_vsock_pkt, list);
    128		list_del_init(&pkt->list);
    129		spin_unlock_bh(&vsock->send_pkt_list_lock);
    130
    131		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
    132					 &out, &in, NULL, NULL);
    133		if (head < 0) {
    134			spin_lock_bh(&vsock->send_pkt_list_lock);
    135			list_add(&pkt->list, &vsock->send_pkt_list);
    136			spin_unlock_bh(&vsock->send_pkt_list_lock);
    137			break;
    138		}
    139
    140		if (head == vq->num) {
    141			spin_lock_bh(&vsock->send_pkt_list_lock);
    142			list_add(&pkt->list, &vsock->send_pkt_list);
    143			spin_unlock_bh(&vsock->send_pkt_list_lock);
    144
    145			/* We cannot finish yet if more buffers snuck in while
    146			 * re-enabling notify.
    147			 */
    148			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
    149				vhost_disable_notify(&vsock->dev, vq);
    150				continue;
    151			}
    152			break;
    153		}
    154
    155		if (out) {
    156			virtio_transport_free_pkt(pkt);
    157			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
    158			break;
    159		}
    160
    161		iov_len = iov_length(&vq->iov[out], in);
    162		if (iov_len < sizeof(pkt->hdr)) {
    163			virtio_transport_free_pkt(pkt);
    164			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
    165			break;
    166		}
    167
    168		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
    169		payload_len = pkt->len - pkt->off;
    170
    171		/* If the packet is greater than the space available in the
    172		 * buffer, we split it using multiple buffers.
    173		 */
    174		if (payload_len > iov_len - sizeof(pkt->hdr)) {
    175			payload_len = iov_len - sizeof(pkt->hdr);
    176
     177			/* As we are copying pieces of a large packet's buffer to
     178			 * small rx buffers, the headers of the packets in the rx
     179			 * queue are created dynamically and initialized with the
     180			 * header of the current packet (except for the length).
     181			 * For SOCK_SEQPACKET we must also clear the message
     182			 * delimiter bit (VIRTIO_VSOCK_SEQ_EOM) and the MSG_EOR bit
     183			 * (VIRTIO_VSOCK_SEQ_EOR) if set; otherwise we would produce
     184			 * a sequence of packets with these bits set. After the
     185			 * initialized header has been copied to the rx buffer,
     186			 * these bits are restored on the requeued packet.
     187			 */
    188			if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
    189				pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
    190				flags_to_restore |= VIRTIO_VSOCK_SEQ_EOM;
    191
    192				if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
    193					pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
    194					flags_to_restore |= VIRTIO_VSOCK_SEQ_EOR;
    195				}
    196			}
    197		}
    198
    199		/* Set the correct length in the header */
    200		pkt->hdr.len = cpu_to_le32(payload_len);
    201
    202		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
    203		if (nbytes != sizeof(pkt->hdr)) {
    204			virtio_transport_free_pkt(pkt);
    205			vq_err(vq, "Faulted on copying pkt hdr\n");
    206			break;
    207		}
    208
    209		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
    210				      &iov_iter);
    211		if (nbytes != payload_len) {
    212			virtio_transport_free_pkt(pkt);
    213			vq_err(vq, "Faulted on copying pkt buf\n");
    214			break;
    215		}
    216
    217		/* Deliver to monitoring devices all packets that we
    218		 * will transmit.
    219		 */
    220		virtio_transport_deliver_tap_pkt(pkt);
    221
    222		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
    223		added = true;
    224
    225		pkt->off += payload_len;
    226		total_len += payload_len;
    227
    228		/* If we didn't send all the payload we can requeue the packet
    229		 * to send it with the next available buffer.
    230		 */
    231		if (pkt->off < pkt->len) {
    232			pkt->hdr.flags |= cpu_to_le32(flags_to_restore);
    233
    234			/* We are queueing the same virtio_vsock_pkt to handle
    235			 * the remaining bytes, and we want to deliver it
    236			 * to monitoring devices in the next iteration.
    237			 */
    238			pkt->tap_delivered = false;
    239
    240			spin_lock_bh(&vsock->send_pkt_list_lock);
    241			list_add(&pkt->list, &vsock->send_pkt_list);
    242			spin_unlock_bh(&vsock->send_pkt_list_lock);
    243		} else {
    244			if (pkt->reply) {
    245				int val;
    246
    247				val = atomic_dec_return(&vsock->queued_replies);
    248
    249				/* Do we have resources to resume tx
    250				 * processing?
    251				 */
    252				if (val + 1 == tx_vq->num)
    253					restart_tx = true;
    254			}
    255
    256			virtio_transport_free_pkt(pkt);
    257		}
     258	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
    259	if (added)
    260		vhost_signal(&vsock->dev, vq);
    261
    262out:
    263	mutex_unlock(&vq->mutex);
    264
    265	if (restart_tx)
    266		vhost_poll_queue(&tx_vq->poll);
    267}
    268
    269static void vhost_transport_send_pkt_work(struct vhost_work *work)
    270{
    271	struct vhost_virtqueue *vq;
    272	struct vhost_vsock *vsock;
    273
    274	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
    275	vq = &vsock->vqs[VSOCK_VQ_RX];
    276
    277	vhost_transport_do_send_pkt(vsock, vq);
    278}
    279
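        /* .send_pkt callback of the transport: queue a host->guest packet on
         * the vhost_vsock instance matching the destination CID and defer the
         * virtqueue work to send_pkt_work.  Returns the packet length on
         * success, or -ENODEV (freeing the packet) if no such guest exists.
         */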
    280static int
    281vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
    282{
    283	struct vhost_vsock *vsock;
    284	int len = pkt->len;
    285
    286	rcu_read_lock();
    287
    288	/* Find the vhost_vsock according to guest context id  */
    289	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
    290	if (!vsock) {
    291		rcu_read_unlock();
    292		virtio_transport_free_pkt(pkt);
    293		return -ENODEV;
    294	}
    295
    296	if (pkt->reply)
    297		atomic_inc(&vsock->queued_replies);
    298
    299	spin_lock_bh(&vsock->send_pkt_list_lock);
    300	list_add_tail(&pkt->list, &vsock->send_pkt_list);
    301	spin_unlock_bh(&vsock->send_pkt_list_lock);
    302
    303	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
    304
    305	rcu_read_unlock();
    306	return len;
    307}
    308
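        /* Drop all not-yet-transmitted packets queued for socket @vsk, e.g.
         * when the connection is being torn down.  Replies that are dropped
         * are credited back to queued_replies, and the TX virtqueue is
         * re-kicked if that brought the count back below the throttling
         * threshold used by vhost_vsock_more_replies().
         */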
    309static int
    310vhost_transport_cancel_pkt(struct vsock_sock *vsk)
    311{
    312	struct vhost_vsock *vsock;
    313	struct virtio_vsock_pkt *pkt, *n;
    314	int cnt = 0;
    315	int ret = -ENODEV;
    316	LIST_HEAD(freeme);
    317
    318	rcu_read_lock();
    319
    320	/* Find the vhost_vsock according to guest context id  */
    321	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
    322	if (!vsock)
    323		goto out;
    324
    325	spin_lock_bh(&vsock->send_pkt_list_lock);
    326	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
    327		if (pkt->vsk != vsk)
    328			continue;
    329		list_move(&pkt->list, &freeme);
    330	}
    331	spin_unlock_bh(&vsock->send_pkt_list_lock);
    332
    333	list_for_each_entry_safe(pkt, n, &freeme, list) {
    334		if (pkt->reply)
    335			cnt++;
    336		list_del(&pkt->list);
    337		virtio_transport_free_pkt(pkt);
    338	}
    339
    340	if (cnt) {
    341		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
    342		int new_cnt;
    343
    344		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
    345		if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
    346			vhost_poll_queue(&tx_vq->poll);
    347	}
    348
    349	ret = 0;
    350out:
    351	rcu_read_unlock();
    352	return ret;
    353}
    354
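        /* Build a virtio_vsock_pkt from the guest-provided TX descriptors:
         * copy the header, validate the advertised payload length against
         * VIRTIO_VSOCK_MAX_PKT_BUF_SIZE and copy the payload.  Returns NULL
         * on malformed input or allocation failure.
         */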
    355static struct virtio_vsock_pkt *
    356vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
    357		      unsigned int out, unsigned int in)
    358{
    359	struct virtio_vsock_pkt *pkt;
    360	struct iov_iter iov_iter;
    361	size_t nbytes;
    362	size_t len;
    363
    364	if (in != 0) {
    365		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
    366		return NULL;
    367	}
    368
    369	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
    370	if (!pkt)
    371		return NULL;
    372
    373	len = iov_length(vq->iov, out);
    374	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
    375
    376	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
    377	if (nbytes != sizeof(pkt->hdr)) {
    378		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
    379		       sizeof(pkt->hdr), nbytes);
    380		kfree(pkt);
    381		return NULL;
    382	}
    383
    384	pkt->len = le32_to_cpu(pkt->hdr.len);
    385
    386	/* No payload */
    387	if (!pkt->len)
    388		return pkt;
    389
    390	/* The pkt is too big */
    391	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
    392		kfree(pkt);
    393		return NULL;
    394	}
    395
    396	pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
    397	if (!pkt->buf) {
    398		kfree(pkt);
    399		return NULL;
    400	}
    401
    402	pkt->buf_len = pkt->len;
    403
    404	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
    405	if (nbytes != pkt->len) {
    406		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
    407		       pkt->len, nbytes);
    408		virtio_transport_free_pkt(pkt);
    409		return NULL;
    410	}
    411
    412	return pkt;
    413}
    414
    415/* Is there space left for replies to rx packets? */
    416static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
    417{
    418	struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
    419	int val;
    420
    421	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
    422	val = atomic_read(&vsock->queued_replies);
    423
    424	return val < vq->num;
    425}
    426
    427static bool vhost_transport_seqpacket_allow(u32 remote_cid);
    428
    429static struct virtio_transport vhost_transport = {
    430	.transport = {
    431		.module                   = THIS_MODULE,
    432
    433		.get_local_cid            = vhost_transport_get_local_cid,
    434
    435		.init                     = virtio_transport_do_socket_init,
    436		.destruct                 = virtio_transport_destruct,
    437		.release                  = virtio_transport_release,
    438		.connect                  = virtio_transport_connect,
    439		.shutdown                 = virtio_transport_shutdown,
    440		.cancel_pkt               = vhost_transport_cancel_pkt,
    441
    442		.dgram_enqueue            = virtio_transport_dgram_enqueue,
    443		.dgram_dequeue            = virtio_transport_dgram_dequeue,
    444		.dgram_bind               = virtio_transport_dgram_bind,
    445		.dgram_allow              = virtio_transport_dgram_allow,
    446
    447		.stream_enqueue           = virtio_transport_stream_enqueue,
    448		.stream_dequeue           = virtio_transport_stream_dequeue,
    449		.stream_has_data          = virtio_transport_stream_has_data,
    450		.stream_has_space         = virtio_transport_stream_has_space,
    451		.stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
    452		.stream_is_active         = virtio_transport_stream_is_active,
    453		.stream_allow             = virtio_transport_stream_allow,
    454
    455		.seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
    456		.seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
    457		.seqpacket_allow          = vhost_transport_seqpacket_allow,
    458		.seqpacket_has_data       = virtio_transport_seqpacket_has_data,
    459
    460		.notify_poll_in           = virtio_transport_notify_poll_in,
    461		.notify_poll_out          = virtio_transport_notify_poll_out,
    462		.notify_recv_init         = virtio_transport_notify_recv_init,
    463		.notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
    464		.notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
    465		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
    466		.notify_send_init         = virtio_transport_notify_send_init,
    467		.notify_send_pre_block    = virtio_transport_notify_send_pre_block,
    468		.notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
    469		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
    470		.notify_buffer_size       = virtio_transport_notify_buffer_size,
    471
    472	},
    473
    474	.send_pkt = vhost_transport_send_pkt,
    475};
    476
    477static bool vhost_transport_seqpacket_allow(u32 remote_cid)
    478{
    479	struct vhost_vsock *vsock;
    480	bool seqpacket_allow = false;
    481
    482	rcu_read_lock();
    483	vsock = vhost_vsock_get(remote_cid);
    484
    485	if (vsock)
    486		seqpacket_allow = vsock->seqpacket_allow;
    487
    488	rcu_read_unlock();
    489
    490	return seqpacket_allow;
    491}
    492
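        /* Guest->host path, run when the guest kicks its TX virtqueue.
         * Packets are parsed with vhost_vsock_alloc_pkt() and handed to the
         * core virtio transport, but only if their source CID matches the
         * CID assigned to this guest and they are addressed to the host CID.
         */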
    493static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
    494{
    495	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
    496						  poll.work);
    497	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
    498						 dev);
    499	struct virtio_vsock_pkt *pkt;
    500	int head, pkts = 0, total_len = 0;
    501	unsigned int out, in;
    502	bool added = false;
    503
    504	mutex_lock(&vq->mutex);
    505
    506	if (!vhost_vq_get_backend(vq))
    507		goto out;
    508
    509	if (!vq_meta_prefetch(vq))
    510		goto out;
    511
    512	vhost_disable_notify(&vsock->dev, vq);
    513	do {
    514		if (!vhost_vsock_more_replies(vsock)) {
    515			/* Stop tx until the device processes already
    516			 * pending replies.  Leave tx virtqueue
    517			 * callbacks disabled.
    518			 */
    519			goto no_more_replies;
    520		}
    521
    522		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
    523					 &out, &in, NULL, NULL);
    524		if (head < 0)
    525			break;
    526
    527		if (head == vq->num) {
    528			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
    529				vhost_disable_notify(&vsock->dev, vq);
    530				continue;
    531			}
    532			break;
    533		}
    534
    535		pkt = vhost_vsock_alloc_pkt(vq, out, in);
    536		if (!pkt) {
    537			vq_err(vq, "Faulted on pkt\n");
    538			continue;
    539		}
    540
    541		total_len += sizeof(pkt->hdr) + pkt->len;
    542
    543		/* Deliver to monitoring devices all received packets */
    544		virtio_transport_deliver_tap_pkt(pkt);
    545
    546		/* Only accept correctly addressed packets */
    547		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
    548		    le64_to_cpu(pkt->hdr.dst_cid) ==
    549		    vhost_transport_get_local_cid())
    550			virtio_transport_recv_pkt(&vhost_transport, pkt);
    551		else
    552			virtio_transport_free_pkt(pkt);
    553
    554		vhost_add_used(vq, head, 0);
    555		added = true;
     556	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
    557
    558no_more_replies:
    559	if (added)
    560		vhost_signal(&vsock->dev, vq);
    561
    562out:
    563	mutex_unlock(&vq->mutex);
    564}
    565
    566static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
    567{
    568	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
    569						poll.work);
    570	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
    571						 dev);
    572
    573	vhost_transport_do_send_pkt(vsock, vq);
    574}
    575
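        /* VHOST_VSOCK_SET_RUNNING(1): attach this vhost_vsock as the backend
         * of both virtqueues and kick the send worker for packets that were
         * queued before the device was started.  On error the backend is
         * cleared again on every virtqueue.
         */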
    576static int vhost_vsock_start(struct vhost_vsock *vsock)
    577{
    578	struct vhost_virtqueue *vq;
    579	size_t i;
    580	int ret;
    581
    582	mutex_lock(&vsock->dev.mutex);
    583
    584	ret = vhost_dev_check_owner(&vsock->dev);
    585	if (ret)
    586		goto err;
    587
    588	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
    589		vq = &vsock->vqs[i];
    590
    591		mutex_lock(&vq->mutex);
    592
    593		if (!vhost_vq_access_ok(vq)) {
    594			ret = -EFAULT;
    595			goto err_vq;
    596		}
    597
    598		if (!vhost_vq_get_backend(vq)) {
    599			vhost_vq_set_backend(vq, vsock);
    600			ret = vhost_vq_init_access(vq);
    601			if (ret)
    602				goto err_vq;
    603		}
    604
    605		mutex_unlock(&vq->mutex);
    606	}
    607
    608	/* Some packets may have been queued before the device was started,
    609	 * let's kick the send worker to send them.
    610	 */
    611	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
    612
    613	mutex_unlock(&vsock->dev.mutex);
    614	return 0;
    615
    616err_vq:
    617	vhost_vq_set_backend(vq, NULL);
    618	mutex_unlock(&vq->mutex);
    619
    620	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
    621		vq = &vsock->vqs[i];
    622
    623		mutex_lock(&vq->mutex);
    624		vhost_vq_set_backend(vq, NULL);
    625		mutex_unlock(&vq->mutex);
    626	}
    627err:
    628	mutex_unlock(&vsock->dev.mutex);
    629	return ret;
    630}
    631
    632static int vhost_vsock_stop(struct vhost_vsock *vsock, bool check_owner)
    633{
    634	size_t i;
    635	int ret = 0;
    636
    637	mutex_lock(&vsock->dev.mutex);
    638
    639	if (check_owner) {
    640		ret = vhost_dev_check_owner(&vsock->dev);
    641		if (ret)
    642			goto err;
    643	}
    644
    645	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
    646		struct vhost_virtqueue *vq = &vsock->vqs[i];
    647
    648		mutex_lock(&vq->mutex);
    649		vhost_vq_set_backend(vq, NULL);
    650		mutex_unlock(&vq->mutex);
    651	}
    652
    653err:
    654	mutex_unlock(&vsock->dev.mutex);
    655	return ret;
    656}
    657
    658static void vhost_vsock_free(struct vhost_vsock *vsock)
    659{
    660	kvfree(vsock);
    661}
    662
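        /* open() of /dev/vhost-vsock: allocate a vhost_vsock with two
         * virtqueues (TX and RX from the guest's point of view) and no CID
         * assigned yet; the instance only becomes reachable once userspace
         * issues VHOST_VSOCK_SET_GUEST_CID.
         */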
    663static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
    664{
    665	struct vhost_virtqueue **vqs;
    666	struct vhost_vsock *vsock;
    667	int ret;
    668
    669	/* This struct is large and allocation could fail, fall back to vmalloc
    670	 * if there is no other way.
    671	 */
    672	vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
    673	if (!vsock)
    674		return -ENOMEM;
    675
    676	vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
    677	if (!vqs) {
    678		ret = -ENOMEM;
    679		goto out;
    680	}
    681
    682	vsock->guest_cid = 0; /* no CID assigned yet */
    683
    684	atomic_set(&vsock->queued_replies, 0);
    685
    686	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
    687	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
    688	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
    689	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
    690
    691	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
    692		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
    693		       VHOST_VSOCK_WEIGHT, true, NULL);
    694
    695	file->private_data = vsock;
    696	spin_lock_init(&vsock->send_pkt_list_lock);
    697	INIT_LIST_HEAD(&vsock->send_pkt_list);
    698	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
    699	return 0;
    700
    701out:
    702	vhost_vsock_free(vsock);
    703	return ret;
    704}
    705
    706static void vhost_vsock_flush(struct vhost_vsock *vsock)
    707{
    708	vhost_dev_flush(&vsock->dev);
    709}
    710
    711static void vhost_vsock_reset_orphans(struct sock *sk)
    712{
    713	struct vsock_sock *vsk = vsock_sk(sk);
    714
    715	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
    716	 * under vsock_table_lock so the sock cannot disappear while we're
    717	 * executing.
    718	 */
    719
    720	/* If the peer is still valid, no need to reset connection */
    721	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
    722		return;
    723
    724	/* If the close timeout is pending, let it expire.  This avoids races
    725	 * with the timeout callback.
    726	 */
    727	if (vsk->close_work_scheduled)
    728		return;
    729
    730	sock_set_flag(sk, SOCK_DONE);
    731	vsk->peer_shutdown = SHUTDOWN_MASK;
    732	sk->sk_state = SS_UNCONNECTED;
    733	sk->sk_err = ECONNRESET;
    734	sk_error_report(sk);
    735}
    736
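        /* release() tears the device down in a strict order: unhash the CID,
         * wait for RCU readers to finish, reset sockets orphaned by the
         * disappearing peer, stop and flush the virtqueues, then free any
         * still-pending host->guest packets and the device itself.
         */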
    737static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
    738{
    739	struct vhost_vsock *vsock = file->private_data;
    740
    741	mutex_lock(&vhost_vsock_mutex);
    742	if (vsock->guest_cid)
    743		hash_del_rcu(&vsock->hash);
    744	mutex_unlock(&vhost_vsock_mutex);
    745
    746	/* Wait for other CPUs to finish using vsock */
    747	synchronize_rcu();
    748
    749	/* Iterating over all connections for all CIDs to find orphans is
    750	 * inefficient.  Room for improvement here. */
    751	vsock_for_each_connected_socket(&vhost_transport.transport,
    752					vhost_vsock_reset_orphans);
    753
    754	/* Don't check the owner, because we are in the release path, so we
    755	 * need to stop the vsock device in any case.
    756	 * vhost_vsock_stop() can not fail in this case, so we don't need to
    757	 * check the return code.
    758	 */
    759	vhost_vsock_stop(vsock, false);
    760	vhost_vsock_flush(vsock);
    761	vhost_dev_stop(&vsock->dev);
    762
    763	spin_lock_bh(&vsock->send_pkt_list_lock);
    764	while (!list_empty(&vsock->send_pkt_list)) {
    765		struct virtio_vsock_pkt *pkt;
    766
    767		pkt = list_first_entry(&vsock->send_pkt_list,
    768				struct virtio_vsock_pkt, list);
    769		list_del_init(&pkt->list);
    770		virtio_transport_free_pkt(pkt);
    771	}
    772	spin_unlock_bh(&vsock->send_pkt_list_lock);
    773
    774	vhost_dev_cleanup(&vsock->dev);
    775	kfree(vsock->dev.vqs);
    776	vhost_vsock_free(vsock);
    777	return 0;
    778}
    779
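        /* VHOST_VSOCK_SET_GUEST_CID: refuse reserved, 64-bit or already-used
         * CIDs (whether taken by another transport or by another vhost_vsock
         * instance), then publish this instance in vhost_vsock_hash under the
         * new CID.
         */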
    780static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
    781{
    782	struct vhost_vsock *other;
    783
    784	/* Refuse reserved CIDs */
    785	if (guest_cid <= VMADDR_CID_HOST ||
    786	    guest_cid == U32_MAX)
    787		return -EINVAL;
    788
    789	/* 64-bit CIDs are not yet supported */
    790	if (guest_cid > U32_MAX)
    791		return -EINVAL;
    792
    793	/* Refuse if CID is assigned to the guest->host transport (i.e. nested
    794	 * VM), to make the loopback work.
    795	 */
    796	if (vsock_find_cid(guest_cid))
    797		return -EADDRINUSE;
    798
    799	/* Refuse if CID is already in use */
    800	mutex_lock(&vhost_vsock_mutex);
    801	other = vhost_vsock_get(guest_cid);
    802	if (other && other != vsock) {
    803		mutex_unlock(&vhost_vsock_mutex);
    804		return -EADDRINUSE;
    805	}
    806
    807	if (vsock->guest_cid)
    808		hash_del_rcu(&vsock->hash);
    809
    810	vsock->guest_cid = guest_cid;
    811	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
    812	mutex_unlock(&vhost_vsock_mutex);
    813
    814	return 0;
    815}
    816
    817static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
    818{
    819	struct vhost_virtqueue *vq;
    820	int i;
    821
    822	if (features & ~VHOST_VSOCK_FEATURES)
    823		return -EOPNOTSUPP;
    824
    825	mutex_lock(&vsock->dev.mutex);
    826	if ((features & (1 << VHOST_F_LOG_ALL)) &&
    827	    !vhost_log_access_ok(&vsock->dev)) {
    828		goto err;
    829	}
    830
    831	if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) {
    832		if (vhost_init_device_iotlb(&vsock->dev, true))
    833			goto err;
    834	}
    835
    836	if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
    837		vsock->seqpacket_allow = true;
    838
    839	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
    840		vq = &vsock->vqs[i];
    841		mutex_lock(&vq->mutex);
    842		vq->acked_features = features;
    843		mutex_unlock(&vq->mutex);
    844	}
    845	mutex_unlock(&vsock->dev.mutex);
    846	return 0;
    847
    848err:
    849	mutex_unlock(&vsock->dev.mutex);
    850	return -EFAULT;
    851}
    852
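        /* ioctl() entry point.  A typical userspace sequence (sketch only,
         * the exact order of the generic vhost setup is up to the VMM):
         * VHOST_SET_OWNER, VHOST_GET_FEATURES/VHOST_SET_FEATURES,
         * VHOST_SET_MEM_TABLE and VHOST_SET_VRING_* for both queues, then
         * VHOST_VSOCK_SET_GUEST_CID and VHOST_VSOCK_SET_RUNNING with 1.
         * Anything not handled here falls through to the generic vhost and
         * vring ioctls.
         */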
    853static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
    854				  unsigned long arg)
    855{
    856	struct vhost_vsock *vsock = f->private_data;
    857	void __user *argp = (void __user *)arg;
    858	u64 guest_cid;
    859	u64 features;
    860	int start;
    861	int r;
    862
    863	switch (ioctl) {
    864	case VHOST_VSOCK_SET_GUEST_CID:
    865		if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
    866			return -EFAULT;
    867		return vhost_vsock_set_cid(vsock, guest_cid);
    868	case VHOST_VSOCK_SET_RUNNING:
    869		if (copy_from_user(&start, argp, sizeof(start)))
    870			return -EFAULT;
    871		if (start)
    872			return vhost_vsock_start(vsock);
    873		else
    874			return vhost_vsock_stop(vsock, true);
    875	case VHOST_GET_FEATURES:
    876		features = VHOST_VSOCK_FEATURES;
    877		if (copy_to_user(argp, &features, sizeof(features)))
    878			return -EFAULT;
    879		return 0;
    880	case VHOST_SET_FEATURES:
    881		if (copy_from_user(&features, argp, sizeof(features)))
    882			return -EFAULT;
    883		return vhost_vsock_set_features(vsock, features);
    884	case VHOST_GET_BACKEND_FEATURES:
    885		features = VHOST_VSOCK_BACKEND_FEATURES;
    886		if (copy_to_user(argp, &features, sizeof(features)))
    887			return -EFAULT;
    888		return 0;
    889	case VHOST_SET_BACKEND_FEATURES:
    890		if (copy_from_user(&features, argp, sizeof(features)))
    891			return -EFAULT;
    892		if (features & ~VHOST_VSOCK_BACKEND_FEATURES)
    893			return -EOPNOTSUPP;
    894		vhost_set_backend_features(&vsock->dev, features);
    895		return 0;
    896	default:
    897		mutex_lock(&vsock->dev.mutex);
    898		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
    899		if (r == -ENOIOCTLCMD)
    900			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
    901		else
    902			vhost_vsock_flush(vsock);
    903		mutex_unlock(&vsock->dev.mutex);
    904		return r;
    905	}
    906}
    907
    908static ssize_t vhost_vsock_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
    909{
    910	struct file *file = iocb->ki_filp;
    911	struct vhost_vsock *vsock = file->private_data;
    912	struct vhost_dev *dev = &vsock->dev;
    913	int noblock = file->f_flags & O_NONBLOCK;
    914
    915	return vhost_chr_read_iter(dev, to, noblock);
    916}
    917
    918static ssize_t vhost_vsock_chr_write_iter(struct kiocb *iocb,
    919					struct iov_iter *from)
    920{
    921	struct file *file = iocb->ki_filp;
    922	struct vhost_vsock *vsock = file->private_data;
    923	struct vhost_dev *dev = &vsock->dev;
    924
    925	return vhost_chr_write_iter(dev, from);
    926}
    927
    928static __poll_t vhost_vsock_chr_poll(struct file *file, poll_table *wait)
    929{
    930	struct vhost_vsock *vsock = file->private_data;
    931	struct vhost_dev *dev = &vsock->dev;
    932
    933	return vhost_chr_poll(file, dev, wait);
    934}
    935
    936static const struct file_operations vhost_vsock_fops = {
    937	.owner          = THIS_MODULE,
    938	.open           = vhost_vsock_dev_open,
    939	.release        = vhost_vsock_dev_release,
    940	.llseek		= noop_llseek,
    941	.unlocked_ioctl = vhost_vsock_dev_ioctl,
    942	.compat_ioctl   = compat_ptr_ioctl,
    943	.read_iter      = vhost_vsock_chr_read_iter,
    944	.write_iter     = vhost_vsock_chr_write_iter,
    945	.poll           = vhost_vsock_chr_poll,
    946};
    947
    948static struct miscdevice vhost_vsock_misc = {
    949	.minor = VHOST_VSOCK_MINOR,
    950	.name = "vhost-vsock",
    951	.fops = &vhost_vsock_fops,
    952};
    953
    954static int __init vhost_vsock_init(void)
    955{
    956	int ret;
    957
    958	ret = vsock_core_register(&vhost_transport.transport,
    959				  VSOCK_TRANSPORT_F_H2G);
    960	if (ret < 0)
    961		return ret;
    962	return misc_register(&vhost_vsock_misc);
     963}
    964
    965static void __exit vhost_vsock_exit(void)
    966{
    967	misc_deregister(&vhost_vsock_misc);
    968	vsock_core_unregister(&vhost_transport.transport);
     969}
    970
    971module_init(vhost_vsock_init);
    972module_exit(vhost_vsock_exit);
    973MODULE_LICENSE("GPL v2");
    974MODULE_AUTHOR("Asias He");
     975MODULE_DESCRIPTION("vhost transport for vsock");
    976MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
    977MODULE_ALIAS("devname:vhost-vsock");
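
For orientation, a minimal userspace sketch of how a VMM-side process would
claim a guest CID through this driver. It only exercises the vsock-specific
ioctls visible in vhost_vsock_dev_ioctl() above; the CID value is arbitrary,
and the generic guest-memory and vring setup (VHOST_SET_MEM_TABLE,
VHOST_SET_VRING_*) that a real VMM performs before VHOST_VSOCK_SET_RUNNING is
deliberately omitted, so this is not a complete device bring-up.

/* vhost-vsock-cid.c: assign a guest CID to a fresh vhost-vsock instance.
 * Build: cc -o vhost-vsock-cid vhost-vsock-cid.c
 */
#include <stdio.h>
#include <stdint.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

int main(void)
{
	uint64_t cid = 3;	/* arbitrary; must be > VMADDR_CID_HOST (2) */
	int fd = open("/dev/vhost-vsock", O_RDWR);

	if (fd < 0) {
		perror("open /dev/vhost-vsock");
		return 1;
	}

	/* Become the owner of this vhost device (generic vhost ioctl). */
	if (ioctl(fd, VHOST_SET_OWNER) < 0) {
		perror("VHOST_SET_OWNER");
		return 1;
	}

	/* Handled by vhost_vsock_set_cid() above; fails with EADDRINUSE if
	 * another instance or transport already claimed this CID.
	 */
	if (ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &cid) < 0) {
		perror("VHOST_VSOCK_SET_GUEST_CID");
		return 1;
	}

	/* A real VMM would now set up guest memory and both vrings before
	 * issuing VHOST_VSOCK_SET_RUNNING with a non-zero int; without that
	 * setup, starting the device fails.
	 */
	printf("guest CID %llu assigned\n", (unsigned long long)cid);
	close(fd);
	return 0;
}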