cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

pvcalls-back.c (30504B)


// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/inet.h>
#include <linux/kthread.h>
#include <linux/list.h>
#include <linux/radix-tree.h>
#include <linux/module.h>
#include <linux/semaphore.h>
#include <linux/wait.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/inet_connection_sock.h>
#include <net/request_sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#define PVCALLS_VERSIONS "1"
#define MAX_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER

static struct pvcalls_back_global {
	struct list_head frontends;
	struct semaphore frontends_lock;
} pvcalls_back_global;

/*
 * Per-frontend data structure. It contains pointers to the command
 * ring, its event channel, a list of active sockets and a tree of
 * passive sockets.
 */
struct pvcalls_fedata {
	struct list_head list;
	struct xenbus_device *dev;
	struct xen_pvcalls_sring *sring;
	struct xen_pvcalls_back_ring ring;
	int irq;
	struct list_head socket_mappings;
	struct radix_tree_root socketpass_mappings;
	struct semaphore socket_lock;
};

struct pvcalls_ioworker {
	struct work_struct register_work;
	struct workqueue_struct *wq;
};

struct sock_mapping {
	struct list_head list;
	struct pvcalls_fedata *fedata;
	struct sockpass_mapping *sockpass;
	struct socket *sock;
	uint64_t id;
	grant_ref_t ref;
	struct pvcalls_data_intf *ring;
	void *bytes;
	struct pvcalls_data data;
	uint32_t ring_order;
	int irq;
	atomic_t read;
	atomic_t write;
	atomic_t io;
	atomic_t release;
	atomic_t eoi;
	void (*saved_data_ready)(struct sock *sk);
	struct pvcalls_ioworker ioworker;
};

struct sockpass_mapping {
	struct list_head list;
	struct pvcalls_fedata *fedata;
	struct socket *sock;
	uint64_t id;
	struct xen_pvcalls_request reqcopy;
	spinlock_t copy_lock;
	struct workqueue_struct *wq;
	struct work_struct register_work;
	void (*saved_data_ready)(struct sock *sk);
};

static irqreturn_t pvcalls_back_conn_event(int irq, void *sock_map);
static int pvcalls_back_release_active(struct xenbus_device *dev,
				       struct pvcalls_fedata *fedata,
				       struct sock_mapping *map);

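/*
 * Drain data from the kernel socket into the shared "in" data ring,
 * handling ring wrap-around with a two-element kvec, then notify the
 * frontend. Returns false if the ring is full or an error has already
 * been reported; the ioworker uses the result for spurious-event
 * accounting.
 */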
static bool pvcalls_conn_back_read(void *opaque)
{
	struct sock_mapping *map = (struct sock_mapping *)opaque;
	struct msghdr msg;
	struct kvec vec[2];
	RING_IDX cons, prod, size, wanted, array_size, masked_prod, masked_cons;
	int32_t error;
	struct pvcalls_data_intf *intf = map->ring;
	struct pvcalls_data *data = &map->data;
	unsigned long flags;
	int ret;

	array_size = XEN_FLEX_RING_SIZE(map->ring_order);
	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	/* read the indexes first, then deal with the data */
	virt_mb();

	if (error)
		return false;

	size = pvcalls_queued(prod, cons, array_size);
	if (size >= array_size)
		return false;
	spin_lock_irqsave(&map->sock->sk->sk_receive_queue.lock, flags);
	if (skb_queue_empty(&map->sock->sk->sk_receive_queue)) {
		atomic_set(&map->read, 0);
		spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock,
				flags);
		return true;
	}
	spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock, flags);
	wanted = array_size - size;
	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	memset(&msg, 0, sizeof(msg));
	if (masked_prod < masked_cons) {
		vec[0].iov_base = data->in + masked_prod;
		vec[0].iov_len = wanted;
		iov_iter_kvec(&msg.msg_iter, WRITE, vec, 1, wanted);
	} else {
		vec[0].iov_base = data->in + masked_prod;
		vec[0].iov_len = array_size - masked_prod;
		vec[1].iov_base = data->in;
		vec[1].iov_len = wanted - vec[0].iov_len;
		iov_iter_kvec(&msg.msg_iter, WRITE, vec, 2, wanted);
	}

	atomic_set(&map->read, 0);
	ret = inet_recvmsg(map->sock, &msg, wanted, MSG_DONTWAIT);
	WARN_ON(ret > wanted);
	if (ret == -EAGAIN) /* shouldn't happen */
		return true;
	if (!ret)
		ret = -ENOTCONN;
	spin_lock_irqsave(&map->sock->sk->sk_receive_queue.lock, flags);
	if (ret > 0 && !skb_queue_empty(&map->sock->sk->sk_receive_queue))
		atomic_inc(&map->read);
	spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock, flags);

	/* write the data, then modify the indexes */
	virt_wmb();
	if (ret < 0) {
		atomic_set(&map->read, 0);
		intf->in_error = ret;
	} else
		intf->in_prod = prod + ret;
	/* update the indexes, then notify the other end */
	virt_wmb();
	notify_remote_via_irq(map->irq);

	return true;
}

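/*
 * Push data queued by the frontend in the shared "out" data ring to the
 * kernel socket without blocking, advance out_cons and notify the
 * frontend. If the socket would block or data is left over, re-arm the
 * write/io counters so the ioworker tries again.
 */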
static bool pvcalls_conn_back_write(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->ring;
	struct pvcalls_data *data = &map->data;
	struct msghdr msg;
	struct kvec vec[2];
	RING_IDX cons, prod, size, array_size;
	int ret;

	cons = intf->out_cons;
	prod = intf->out_prod;
	/* read the indexes before dealing with the data */
	virt_mb();

	array_size = XEN_FLEX_RING_SIZE(map->ring_order);
	size = pvcalls_queued(prod, cons, array_size);
	if (size == 0)
		return false;

	memset(&msg, 0, sizeof(msg));
	msg.msg_flags |= MSG_DONTWAIT;
	if (pvcalls_mask(prod, array_size) > pvcalls_mask(cons, array_size)) {
		vec[0].iov_base = data->out + pvcalls_mask(cons, array_size);
		vec[0].iov_len = size;
		iov_iter_kvec(&msg.msg_iter, READ, vec, 1, size);
	} else {
		vec[0].iov_base = data->out + pvcalls_mask(cons, array_size);
		vec[0].iov_len = array_size - pvcalls_mask(cons, array_size);
		vec[1].iov_base = data->out;
		vec[1].iov_len = size - vec[0].iov_len;
		iov_iter_kvec(&msg.msg_iter, READ, vec, 2, size);
	}

	atomic_set(&map->write, 0);
	ret = inet_sendmsg(map->sock, &msg, size);
	if (ret == -EAGAIN) {
		atomic_inc(&map->write);
		atomic_inc(&map->io);
		return true;
	}

	/* write the data, then update the indexes */
	virt_wmb();
	if (ret < 0) {
		intf->out_error = ret;
	} else {
		intf->out_error = 0;
		intf->out_cons = cons + ret;
		prod = intf->out_prod;
	}
	/* update the indexes, then notify the other end */
	virt_wmb();
	if (prod != cons + ret) {
		atomic_inc(&map->write);
		atomic_inc(&map->io);
	}
	notify_remote_via_irq(map->irq);

	return true;
}

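/*
 * Per-connection work item: run read/write passes while I/O is pending
 * on the active socket and issue the delayed EOI for its event channel
 * once no writes are outstanding.
 */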
static void pvcalls_back_ioworker(struct work_struct *work)
{
	struct pvcalls_ioworker *ioworker = container_of(work,
		struct pvcalls_ioworker, register_work);
	struct sock_mapping *map = container_of(ioworker, struct sock_mapping,
		ioworker);
	unsigned int eoi_flags = XEN_EOI_FLAG_SPURIOUS;

	while (atomic_read(&map->io) > 0) {
		if (atomic_read(&map->release) > 0) {
			atomic_set(&map->release, 0);
			return;
		}

		if (atomic_read(&map->read) > 0 &&
		    pvcalls_conn_back_read(map))
			eoi_flags = 0;
		if (atomic_read(&map->write) > 0 &&
		    pvcalls_conn_back_write(map))
			eoi_flags = 0;

		if (atomic_read(&map->eoi) > 0 && !atomic_read(&map->write)) {
			atomic_set(&map->eoi, 0);
			xen_irq_lateeoi(map->irq, eoi_flags);
			eoi_flags = XEN_EOI_FLAG_SPURIOUS;
		}

		atomic_dec(&map->io);
	}
}

static int pvcalls_back_socket(struct xenbus_device *dev,
		struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	int ret;
	struct xen_pvcalls_response *rsp;

	fedata = dev_get_drvdata(&dev->dev);

	if (req->u.socket.domain != AF_INET ||
	    req->u.socket.type != SOCK_STREAM ||
	    (req->u.socket.protocol != IPPROTO_IP &&
	     req->u.socket.protocol != AF_INET))
		ret = -EAFNOSUPPORT;
	else
		ret = 0;

	/* leave the actual socket allocation for later */

	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.socket.id = req->u.socket.id;
	rsp->ret = ret;

	return 0;
}

static void pvcalls_sk_state_change(struct sock *sock)
{
	struct sock_mapping *map = sock->sk_user_data;

	if (map == NULL)
		return;

	atomic_inc(&map->read);
	notify_remote_via_irq(map->irq);
}

static void pvcalls_sk_data_ready(struct sock *sock)
{
	struct sock_mapping *map = sock->sk_user_data;
	struct pvcalls_ioworker *iow;

	if (map == NULL)
		return;

	iow = &map->ioworker;
	atomic_inc(&map->read);
	atomic_inc(&map->io);
	queue_work(iow->wq, &iow->register_work);
}

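/*
 * Set up an active connection: map the indexes page and the data ring
 * granted by the frontend, bind the per-connection event channel,
 * allocate the ioworker workqueue and install the sk callbacks.
 */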
static struct sock_mapping *pvcalls_new_active_socket(
		struct pvcalls_fedata *fedata,
		uint64_t id,
		grant_ref_t ref,
		evtchn_port_t evtchn,
		struct socket *sock)
{
	int ret;
	struct sock_mapping *map;
	void *page;

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL)
		return NULL;

	map->fedata = fedata;
	map->sock = sock;
	map->id = id;
	map->ref = ref;

	ret = xenbus_map_ring_valloc(fedata->dev, &ref, 1, &page);
	if (ret < 0)
		goto out;
	map->ring = page;
	map->ring_order = map->ring->ring_order;
	/* first read the order, then map the data ring */
	virt_rmb();
	if (map->ring_order > MAX_RING_ORDER) {
		pr_warn("%s frontend requested ring_order %u, which is > MAX (%u)\n",
				__func__, map->ring_order, MAX_RING_ORDER);
		goto out;
	}
	ret = xenbus_map_ring_valloc(fedata->dev, map->ring->ref,
				     (1 << map->ring_order), &page);
	if (ret < 0)
		goto out;
	map->bytes = page;

	ret = bind_interdomain_evtchn_to_irqhandler_lateeoi(
			fedata->dev, evtchn,
			pvcalls_back_conn_event, 0, "pvcalls-backend", map);
	if (ret < 0)
		goto out;
	map->irq = ret;

	map->data.in = map->bytes;
	map->data.out = map->bytes + XEN_FLEX_RING_SIZE(map->ring_order);

	map->ioworker.wq = alloc_workqueue("pvcalls_io", WQ_UNBOUND, 1);
	if (!map->ioworker.wq)
		goto out;
	atomic_set(&map->io, 1);
	INIT_WORK(&map->ioworker.register_work,	pvcalls_back_ioworker);

	down(&fedata->socket_lock);
	list_add_tail(&map->list, &fedata->socket_mappings);
	up(&fedata->socket_lock);

	write_lock_bh(&map->sock->sk->sk_callback_lock);
	map->saved_data_ready = map->sock->sk->sk_data_ready;
	map->sock->sk->sk_user_data = map;
	map->sock->sk->sk_data_ready = pvcalls_sk_data_ready;
	map->sock->sk->sk_state_change = pvcalls_sk_state_change;
	write_unlock_bh(&map->sock->sk->sk_callback_lock);

	return map;
out:
	down(&fedata->socket_lock);
	list_del(&map->list);
	pvcalls_back_release_active(fedata->dev, fedata, map);
	up(&fedata->socket_lock);
	return NULL;
}

static int pvcalls_back_connect(struct xenbus_device *dev,
				struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	int ret = -EINVAL;
	struct socket *sock;
	struct sock_mapping *map;
	struct xen_pvcalls_response *rsp;
	struct sockaddr *sa = (struct sockaddr *)&req->u.connect.addr;

	fedata = dev_get_drvdata(&dev->dev);

	if (req->u.connect.len < sizeof(sa->sa_family) ||
	    req->u.connect.len > sizeof(req->u.connect.addr) ||
	    sa->sa_family != AF_INET)
		goto out;

	ret = sock_create(AF_INET, SOCK_STREAM, 0, &sock);
	if (ret < 0)
		goto out;
	ret = inet_stream_connect(sock, sa, req->u.connect.len, 0);
	if (ret < 0) {
		sock_release(sock);
		goto out;
	}

	map = pvcalls_new_active_socket(fedata,
					req->u.connect.id,
					req->u.connect.ref,
					req->u.connect.evtchn,
					sock);
	if (!map) {
		ret = -EFAULT;
		sock_release(sock);
	}

out:
	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.connect.id = req->u.connect.id;
	rsp->ret = ret;

	return 0;
}

static int pvcalls_back_release_active(struct xenbus_device *dev,
				       struct pvcalls_fedata *fedata,
				       struct sock_mapping *map)
{
	disable_irq(map->irq);
	if (map->sock->sk != NULL) {
		write_lock_bh(&map->sock->sk->sk_callback_lock);
		map->sock->sk->sk_user_data = NULL;
		map->sock->sk->sk_data_ready = map->saved_data_ready;
		write_unlock_bh(&map->sock->sk->sk_callback_lock);
	}

	atomic_set(&map->release, 1);
	flush_work(&map->ioworker.register_work);

	xenbus_unmap_ring_vfree(dev, map->bytes);
	xenbus_unmap_ring_vfree(dev, (void *)map->ring);
	unbind_from_irqhandler(map->irq, map);

	sock_release(map->sock);
	kfree(map);

	return 0;
}

static int pvcalls_back_release_passive(struct xenbus_device *dev,
					struct pvcalls_fedata *fedata,
					struct sockpass_mapping *mappass)
{
	if (mappass->sock->sk != NULL) {
		write_lock_bh(&mappass->sock->sk->sk_callback_lock);
		mappass->sock->sk->sk_user_data = NULL;
		mappass->sock->sk->sk_data_ready = mappass->saved_data_ready;
		write_unlock_bh(&mappass->sock->sk->sk_callback_lock);
	}
	sock_release(mappass->sock);
	destroy_workqueue(mappass->wq);
	kfree(mappass);

	return 0;
}

static int pvcalls_back_release(struct xenbus_device *dev,
				struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	struct sock_mapping *map, *n;
	struct sockpass_mapping *mappass;
	int ret = 0;
	struct xen_pvcalls_response *rsp;

	fedata = dev_get_drvdata(&dev->dev);

	down(&fedata->socket_lock);
	list_for_each_entry_safe(map, n, &fedata->socket_mappings, list) {
		if (map->id == req->u.release.id) {
			list_del(&map->list);
			up(&fedata->socket_lock);
			ret = pvcalls_back_release_active(dev, fedata, map);
			goto out;
		}
	}
	mappass = radix_tree_lookup(&fedata->socketpass_mappings,
				    req->u.release.id);
	if (mappass != NULL) {
		radix_tree_delete(&fedata->socketpass_mappings, mappass->id);
		up(&fedata->socket_lock);
		ret = pvcalls_back_release_passive(dev, fedata, mappass);
	} else
		up(&fedata->socket_lock);

out:
	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->u.release.id = req->u.release.id;
	rsp->cmd = req->cmd;
	rsp->ret = ret;
	return 0;
}

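/*
 * Deferred accept work: pick up the request copied by
 * pvcalls_back_accept, accept a pending connection without blocking,
 * turn it into a new active socket and push the response to the
 * command ring.
 */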
static void __pvcalls_back_accept(struct work_struct *work)
{
	struct sockpass_mapping *mappass = container_of(
		work, struct sockpass_mapping, register_work);
	struct sock_mapping *map;
	struct pvcalls_ioworker *iow;
	struct pvcalls_fedata *fedata;
	struct socket *sock;
	struct xen_pvcalls_response *rsp;
	struct xen_pvcalls_request *req;
	int notify;
	int ret = -EINVAL;
	unsigned long flags;

	fedata = mappass->fedata;
	/*
	 * __pvcalls_back_accept can race against pvcalls_back_accept.
	 * We only need to check the value of "cmd" on read. It could be
	 * done atomically, but to simplify the code on the write side, we
	 * use a spinlock.
	 */
	spin_lock_irqsave(&mappass->copy_lock, flags);
	req = &mappass->reqcopy;
	if (req->cmd != PVCALLS_ACCEPT) {
		spin_unlock_irqrestore(&mappass->copy_lock, flags);
		return;
	}
	spin_unlock_irqrestore(&mappass->copy_lock, flags);

	sock = sock_alloc();
	if (sock == NULL)
		goto out_error;
	sock->type = mappass->sock->type;
	sock->ops = mappass->sock->ops;

	ret = inet_accept(mappass->sock, sock, O_NONBLOCK, true);
	if (ret == -EAGAIN) {
		sock_release(sock);
		return;
	}

	map = pvcalls_new_active_socket(fedata,
					req->u.accept.id_new,
					req->u.accept.ref,
					req->u.accept.evtchn,
					sock);
	if (!map) {
		ret = -EFAULT;
		sock_release(sock);
		goto out_error;
	}

	map->sockpass = mappass;
	iow = &map->ioworker;
	atomic_inc(&map->read);
	atomic_inc(&map->io);
	queue_work(iow->wq, &iow->register_work);

out_error:
	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.accept.id = req->u.accept.id;
	rsp->ret = ret;
	RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(&fedata->ring, notify);
	if (notify)
		notify_remote_via_irq(fedata->irq);

	mappass->reqcopy.cmd = 0;
}

static void pvcalls_pass_sk_data_ready(struct sock *sock)
{
	struct sockpass_mapping *mappass = sock->sk_user_data;
	struct pvcalls_fedata *fedata;
	struct xen_pvcalls_response *rsp;
	unsigned long flags;
	int notify;

	if (mappass == NULL)
		return;

	fedata = mappass->fedata;
	spin_lock_irqsave(&mappass->copy_lock, flags);
	if (mappass->reqcopy.cmd == PVCALLS_POLL) {
		rsp = RING_GET_RESPONSE(&fedata->ring,
					fedata->ring.rsp_prod_pvt++);
		rsp->req_id = mappass->reqcopy.req_id;
		rsp->u.poll.id = mappass->reqcopy.u.poll.id;
		rsp->cmd = mappass->reqcopy.cmd;
		rsp->ret = 0;

		mappass->reqcopy.cmd = 0;
		spin_unlock_irqrestore(&mappass->copy_lock, flags);

		RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(&fedata->ring, notify);
		if (notify)
			notify_remote_via_irq(mappass->fedata->irq);
	} else {
		spin_unlock_irqrestore(&mappass->copy_lock, flags);
		queue_work(mappass->wq, &mappass->register_work);
	}
}

static int pvcalls_back_bind(struct xenbus_device *dev,
			     struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	int ret;
	struct sockpass_mapping *map;
	struct xen_pvcalls_response *rsp;

	fedata = dev_get_drvdata(&dev->dev);

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL) {
		ret = -ENOMEM;
		goto out;
	}

	INIT_WORK(&map->register_work, __pvcalls_back_accept);
	spin_lock_init(&map->copy_lock);
	map->wq = alloc_workqueue("pvcalls_wq", WQ_UNBOUND, 1);
	if (!map->wq) {
		ret = -ENOMEM;
		goto out;
	}

	ret = sock_create(AF_INET, SOCK_STREAM, 0, &map->sock);
	if (ret < 0)
		goto out;

	ret = inet_bind(map->sock, (struct sockaddr *)&req->u.bind.addr,
			req->u.bind.len);
	if (ret < 0)
		goto out;

	map->fedata = fedata;
	map->id = req->u.bind.id;

	down(&fedata->socket_lock);
	ret = radix_tree_insert(&fedata->socketpass_mappings, map->id,
				map);
	up(&fedata->socket_lock);
	if (ret)
		goto out;

	write_lock_bh(&map->sock->sk->sk_callback_lock);
	map->saved_data_ready = map->sock->sk->sk_data_ready;
	map->sock->sk->sk_user_data = map;
	map->sock->sk->sk_data_ready = pvcalls_pass_sk_data_ready;
	write_unlock_bh(&map->sock->sk->sk_callback_lock);

out:
	if (ret) {
		if (map && map->sock)
			sock_release(map->sock);
		if (map && map->wq)
			destroy_workqueue(map->wq);
		kfree(map);
	}
	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.bind.id = req->u.bind.id;
	rsp->ret = ret;
	return 0;
}

static int pvcalls_back_listen(struct xenbus_device *dev,
			       struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	int ret = -EINVAL;
	struct sockpass_mapping *map;
	struct xen_pvcalls_response *rsp;

	fedata = dev_get_drvdata(&dev->dev);

	down(&fedata->socket_lock);
	map = radix_tree_lookup(&fedata->socketpass_mappings, req->u.listen.id);
	up(&fedata->socket_lock);
	if (map == NULL)
		goto out;

	ret = inet_listen(map->sock, req->u.listen.backlog);

out:
	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.listen.id = req->u.listen.id;
	rsp->ret = ret;
	return 0;
}

static int pvcalls_back_accept(struct xenbus_device *dev,
			       struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	struct sockpass_mapping *mappass;
	int ret = -EINVAL;
	struct xen_pvcalls_response *rsp;
	unsigned long flags;

	fedata = dev_get_drvdata(&dev->dev);

	down(&fedata->socket_lock);
	mappass = radix_tree_lookup(&fedata->socketpass_mappings,
		req->u.accept.id);
	up(&fedata->socket_lock);
	if (mappass == NULL)
		goto out_error;

	/*
	 * Limitation of the current implementation: only support one
	 * concurrent accept or poll call on one socket.
	 */
	spin_lock_irqsave(&mappass->copy_lock, flags);
	if (mappass->reqcopy.cmd != 0) {
		spin_unlock_irqrestore(&mappass->copy_lock, flags);
		ret = -EINTR;
		goto out_error;
	}

	mappass->reqcopy = *req;
	spin_unlock_irqrestore(&mappass->copy_lock, flags);
	queue_work(mappass->wq, &mappass->register_work);

	/* Tell the caller we don't need to send back a notification yet */
	return -1;

out_error:
	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.accept.id = req->u.accept.id;
	rsp->ret = ret;
	return 0;
}

static int pvcalls_back_poll(struct xenbus_device *dev,
			     struct xen_pvcalls_request *req)
{
	struct pvcalls_fedata *fedata;
	struct sockpass_mapping *mappass;
	struct xen_pvcalls_response *rsp;
	struct inet_connection_sock *icsk;
	struct request_sock_queue *queue;
	unsigned long flags;
	int ret;
	bool data;

	fedata = dev_get_drvdata(&dev->dev);

	down(&fedata->socket_lock);
	mappass = radix_tree_lookup(&fedata->socketpass_mappings,
				    req->u.poll.id);
	up(&fedata->socket_lock);
	if (mappass == NULL)
		return -EINVAL;

	/*
	 * Limitation of the current implementation: only support one
	 * concurrent accept or poll call on one socket.
	 */
	spin_lock_irqsave(&mappass->copy_lock, flags);
	if (mappass->reqcopy.cmd != 0) {
		ret = -EINTR;
		goto out;
	}

	mappass->reqcopy = *req;
	icsk = inet_csk(mappass->sock->sk);
	queue = &icsk->icsk_accept_queue;
	data = READ_ONCE(queue->rskq_accept_head) != NULL;
	if (data) {
		mappass->reqcopy.cmd = 0;
		ret = 0;
		goto out;
	}
	spin_unlock_irqrestore(&mappass->copy_lock, flags);

	/* Tell the caller we don't need to send back a notification yet */
	return -1;

out:
	spin_unlock_irqrestore(&mappass->copy_lock, flags);

	rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
	rsp->req_id = req->req_id;
	rsp->cmd = req->cmd;
	rsp->u.poll.id = req->u.poll.id;
	rsp->ret = ret;
	return 0;
}

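/*
 * Dispatch one request from the command ring. Handlers fill in their
 * own response; a non-zero return value tells pvcalls_back_work not to
 * push responses for this request (accept and poll reply
 * asynchronously).
 */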
static int pvcalls_back_handle_cmd(struct xenbus_device *dev,
				   struct xen_pvcalls_request *req)
{
	int ret = 0;

	switch (req->cmd) {
	case PVCALLS_SOCKET:
		ret = pvcalls_back_socket(dev, req);
		break;
	case PVCALLS_CONNECT:
		ret = pvcalls_back_connect(dev, req);
		break;
	case PVCALLS_RELEASE:
		ret = pvcalls_back_release(dev, req);
		break;
	case PVCALLS_BIND:
		ret = pvcalls_back_bind(dev, req);
		break;
	case PVCALLS_LISTEN:
		ret = pvcalls_back_listen(dev, req);
		break;
	case PVCALLS_ACCEPT:
		ret = pvcalls_back_accept(dev, req);
		break;
	case PVCALLS_POLL:
		ret = pvcalls_back_poll(dev, req);
		break;
	default:
	{
		struct pvcalls_fedata *fedata;
		struct xen_pvcalls_response *rsp;

		fedata = dev_get_drvdata(&dev->dev);
		rsp = RING_GET_RESPONSE(
				&fedata->ring, fedata->ring.rsp_prod_pvt++);
		rsp->req_id = req->req_id;
		rsp->cmd = req->cmd;
		rsp->ret = -ENOTSUPP;
		break;
	}
	}
	return ret;
}

static void pvcalls_back_work(struct pvcalls_fedata *fedata)
{
	int notify, notify_all = 0, more = 1;
	struct xen_pvcalls_request req;
	struct xenbus_device *dev = fedata->dev;

	while (more) {
		while (RING_HAS_UNCONSUMED_REQUESTS(&fedata->ring)) {
			RING_COPY_REQUEST(&fedata->ring,
					  fedata->ring.req_cons++,
					  &req);

			if (!pvcalls_back_handle_cmd(dev, &req)) {
				RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(
					&fedata->ring, notify);
				notify_all += notify;
			}
		}

		if (notify_all) {
			notify_remote_via_irq(fedata->irq);
			notify_all = 0;
		}

		RING_FINAL_CHECK_FOR_REQUESTS(&fedata->ring, more);
	}
}

static irqreturn_t pvcalls_back_event(int irq, void *dev_id)
{
	struct xenbus_device *dev = dev_id;
	struct pvcalls_fedata *fedata = NULL;
	unsigned int eoi_flags = XEN_EOI_FLAG_SPURIOUS;

	if (dev) {
		fedata = dev_get_drvdata(&dev->dev);
		if (fedata) {
			pvcalls_back_work(fedata);
			eoi_flags = 0;
		}
	}

	xen_irq_lateeoi(irq, eoi_flags);

	return IRQ_HANDLED;
}

static irqreturn_t pvcalls_back_conn_event(int irq, void *sock_map)
{
	struct sock_mapping *map = sock_map;
	struct pvcalls_ioworker *iow;

	if (map == NULL || map->sock == NULL || map->sock->sk == NULL ||
		map->sock->sk->sk_user_data != map) {
		xen_irq_lateeoi(irq, 0);
		return IRQ_HANDLED;
	}

	iow = &map->ioworker;

	atomic_inc(&map->write);
	atomic_inc(&map->eoi);
	atomic_inc(&map->io);
	queue_work(iow->wq, &iow->register_work);

	return IRQ_HANDLED;
}

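/*
 * Connect to a new frontend: read the event channel and ring reference
 * it published in xenstore, map the command ring, bind the event
 * channel and add the frontend to the global list.
 */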
static int backend_connect(struct xenbus_device *dev)
{
	int err;
	evtchn_port_t evtchn;
	grant_ref_t ring_ref;
	struct pvcalls_fedata *fedata = NULL;

	fedata = kzalloc(sizeof(struct pvcalls_fedata), GFP_KERNEL);
	if (!fedata)
		return -ENOMEM;

	fedata->irq = -1;
	err = xenbus_scanf(XBT_NIL, dev->otherend, "port", "%u",
			   &evtchn);
	if (err != 1) {
		err = -EINVAL;
		xenbus_dev_fatal(dev, err, "reading %s/event-channel",
				 dev->otherend);
		goto error;
	}

	err = xenbus_scanf(XBT_NIL, dev->otherend, "ring-ref", "%u", &ring_ref);
	if (err != 1) {
		err = -EINVAL;
		xenbus_dev_fatal(dev, err, "reading %s/ring-ref",
				 dev->otherend);
		goto error;
	}

	err = bind_interdomain_evtchn_to_irq_lateeoi(dev, evtchn);
	if (err < 0)
		goto error;
	fedata->irq = err;

	err = request_threaded_irq(fedata->irq, NULL, pvcalls_back_event,
				   IRQF_ONESHOT, "pvcalls-back", dev);
	if (err < 0)
		goto error;

	err = xenbus_map_ring_valloc(dev, &ring_ref, 1,
				     (void **)&fedata->sring);
	if (err < 0)
		goto error;

	BACK_RING_INIT(&fedata->ring, fedata->sring, XEN_PAGE_SIZE * 1);
	fedata->dev = dev;

	INIT_LIST_HEAD(&fedata->socket_mappings);
	INIT_RADIX_TREE(&fedata->socketpass_mappings, GFP_KERNEL);
	sema_init(&fedata->socket_lock, 1);
	dev_set_drvdata(&dev->dev, fedata);

	down(&pvcalls_back_global.frontends_lock);
	list_add_tail(&fedata->list, &pvcalls_back_global.frontends);
	up(&pvcalls_back_global.frontends_lock);

	return 0;

 error:
	if (fedata->irq >= 0)
		unbind_from_irqhandler(fedata->irq, dev);
	if (fedata->sring != NULL)
		xenbus_unmap_ring_vfree(dev, fedata->sring);
	kfree(fedata);
	return err;
}

static int backend_disconnect(struct xenbus_device *dev)
{
	struct pvcalls_fedata *fedata;
	struct sock_mapping *map, *n;
	struct sockpass_mapping *mappass;
	struct radix_tree_iter iter;
	void **slot;


	fedata = dev_get_drvdata(&dev->dev);

	down(&fedata->socket_lock);
	list_for_each_entry_safe(map, n, &fedata->socket_mappings, list) {
		list_del(&map->list);
		pvcalls_back_release_active(dev, fedata, map);
	}

	radix_tree_for_each_slot(slot, &fedata->socketpass_mappings, &iter, 0) {
		mappass = radix_tree_deref_slot(slot);
		if (!mappass)
			continue;
		if (radix_tree_exception(mappass)) {
			if (radix_tree_deref_retry(mappass))
				slot = radix_tree_iter_retry(&iter);
		} else {
			radix_tree_delete(&fedata->socketpass_mappings,
					  mappass->id);
			pvcalls_back_release_passive(dev, fedata, mappass);
		}
	}
	up(&fedata->socket_lock);

	unbind_from_irqhandler(fedata->irq, dev);
	xenbus_unmap_ring_vfree(dev, fedata->sring);

	list_del(&fedata->list);
	kfree(fedata);
	dev_set_drvdata(&dev->dev, NULL);

	return 0;
}

static int pvcalls_back_probe(struct xenbus_device *dev,
			      const struct xenbus_device_id *id)
{
	int err, abort;
	struct xenbus_transaction xbt;

again:
	abort = 1;

	err = xenbus_transaction_start(&xbt);
	if (err) {
		pr_warn("%s cannot create xenstore transaction\n", __func__);
		return err;
	}

	err = xenbus_printf(xbt, dev->nodename, "versions", "%s",
			    PVCALLS_VERSIONS);
	if (err) {
		pr_warn("%s write out 'versions' failed\n", __func__);
		goto abort;
	}

	err = xenbus_printf(xbt, dev->nodename, "max-page-order", "%u",
			    MAX_RING_ORDER);
	if (err) {
		pr_warn("%s write out 'max-page-order' failed\n", __func__);
		goto abort;
	}

	err = xenbus_printf(xbt, dev->nodename, "function-calls",
			    XENBUS_FUNCTIONS_CALLS);
	if (err) {
		pr_warn("%s write out 'function-calls' failed\n", __func__);
		goto abort;
	}

	abort = 0;
abort:
	err = xenbus_transaction_end(xbt, abort);
	if (err) {
		if (err == -EAGAIN && !abort)
			goto again;
		pr_warn("%s cannot complete xenstore transaction\n", __func__);
		return err;
	}

	if (abort)
		return -EFAULT;

	xenbus_switch_state(dev, XenbusStateInitWait);

	return 0;
}

static void set_backend_state(struct xenbus_device *dev,
			      enum xenbus_state state)
{
	while (dev->state != state) {
		switch (dev->state) {
		case XenbusStateClosed:
			switch (state) {
			case XenbusStateInitWait:
			case XenbusStateConnected:
				xenbus_switch_state(dev, XenbusStateInitWait);
				break;
			case XenbusStateClosing:
				xenbus_switch_state(dev, XenbusStateClosing);
				break;
			default:
				WARN_ON(1);
			}
			break;
		case XenbusStateInitWait:
		case XenbusStateInitialised:
			switch (state) {
			case XenbusStateConnected:
				if (backend_connect(dev))
					return;
				xenbus_switch_state(dev, XenbusStateConnected);
				break;
			case XenbusStateClosing:
			case XenbusStateClosed:
				xenbus_switch_state(dev, XenbusStateClosing);
				break;
			default:
				WARN_ON(1);
			}
			break;
		case XenbusStateConnected:
			switch (state) {
			case XenbusStateInitWait:
			case XenbusStateClosing:
			case XenbusStateClosed:
				down(&pvcalls_back_global.frontends_lock);
				backend_disconnect(dev);
				up(&pvcalls_back_global.frontends_lock);
				xenbus_switch_state(dev, XenbusStateClosing);
				break;
			default:
				WARN_ON(1);
			}
			break;
		case XenbusStateClosing:
			switch (state) {
			case XenbusStateInitWait:
			case XenbusStateConnected:
			case XenbusStateClosed:
				xenbus_switch_state(dev, XenbusStateClosed);
				break;
			default:
				WARN_ON(1);
			}
			break;
		default:
			WARN_ON(1);
		}
	}
}

static void pvcalls_back_changed(struct xenbus_device *dev,
				 enum xenbus_state frontend_state)
{
	switch (frontend_state) {
	case XenbusStateInitialising:
		set_backend_state(dev, XenbusStateInitWait);
		break;

	case XenbusStateInitialised:
	case XenbusStateConnected:
		set_backend_state(dev, XenbusStateConnected);
		break;

	case XenbusStateClosing:
		set_backend_state(dev, XenbusStateClosing);
		break;

	case XenbusStateClosed:
		set_backend_state(dev, XenbusStateClosed);
		if (xenbus_dev_is_online(dev))
			break;
		device_unregister(&dev->dev);
		break;
	case XenbusStateUnknown:
		set_backend_state(dev, XenbusStateClosed);
		device_unregister(&dev->dev);
		break;

	default:
		xenbus_dev_fatal(dev, -EINVAL, "saw state %d at frontend",
				 frontend_state);
		break;
	}
}

static int pvcalls_back_remove(struct xenbus_device *dev)
{
	return 0;
}

static int pvcalls_back_uevent(struct xenbus_device *xdev,
			       struct kobj_uevent_env *env)
{
	return 0;
}

static const struct xenbus_device_id pvcalls_back_ids[] = {
	{ "pvcalls" },
	{ "" }
};

static struct xenbus_driver pvcalls_back_driver = {
	.ids = pvcalls_back_ids,
	.probe = pvcalls_back_probe,
	.remove = pvcalls_back_remove,
	.uevent = pvcalls_back_uevent,
	.otherend_changed = pvcalls_back_changed,
};

static int __init pvcalls_back_init(void)
{
	int ret;

	if (!xen_domain())
		return -ENODEV;

	ret = xenbus_register_backend(&pvcalls_back_driver);
	if (ret < 0)
		return ret;

	sema_init(&pvcalls_back_global.frontends_lock, 1);
	INIT_LIST_HEAD(&pvcalls_back_global.frontends);
	return 0;
}
module_init(pvcalls_back_init);

static void __exit pvcalls_back_fin(void)
{
	struct pvcalls_fedata *fedata, *nfedata;

	down(&pvcalls_back_global.frontends_lock);
	list_for_each_entry_safe(fedata, nfedata,
				 &pvcalls_back_global.frontends, list) {
		backend_disconnect(fedata->dev);
	}
	up(&pvcalls_back_global.frontends_lock);

	xenbus_unregister_driver(&pvcalls_back_driver);
}

module_exit(pvcalls_back_fin);

MODULE_DESCRIPTION("Xen PV Calls backend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");