cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

hyperv_transport.c (24279B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * Hyper-V transport for vsock
      4 *
      5 * Hyper-V Sockets supplies a byte-stream based communication mechanism
      6 * between the host and the VM. This driver implements the necessary
      7 * support in the VM by introducing the new vsock transport.
      8 *
      9 * Copyright (c) 2017, Microsoft Corporation.
     10 */
     11#include <linux/module.h>
     12#include <linux/vmalloc.h>
     13#include <linux/hyperv.h>
     14#include <net/sock.h>
     15#include <net/af_vsock.h>
     16#include <asm/hyperv-tlfs.h>
     17
     18/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some
     19 * stricter requirements on the hv_sock ring buffer size of six 4K pages.
     20 * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this
     21 * limitation; but, keep the defaults the same for compat.
     22 */
     23#define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6)
     24#define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6)
     25#define RINGBUFFER_HVS_MAX_SIZE (HV_HYP_PAGE_SIZE * 64)
     26
     27/* The MTU is 16KB per the host side's design */
     28#define HVS_MTU_SIZE		(1024 * 16)
     29
     30/* How long to wait for graceful shutdown of a connection */
     31#define HVS_CLOSE_TIMEOUT (8 * HZ)
     32
     33struct vmpipe_proto_header {
     34	u32 pkt_type;
     35	u32 data_size;
     36};
     37
     38/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
     39 * data from the ringbuffer into the userspace buffer.
     40 */
     41struct hvs_recv_buf {
     42	/* The header before the payload data */
     43	struct vmpipe_proto_header hdr;
     44
     45	/* The payload */
     46	u8 data[HVS_MTU_SIZE];
     47};
     48
     49/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use
     50 * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the
     51 * guest and the host processing as one VMBUS packet is the smallest processing
     52 * unit.
     53 *
     54 * Note: the buffer can be eliminated in the future when we add new VMBus
     55 * ringbuffer APIs that allow us to directly copy data from userspace buffer
     56 * to VMBus ringbuffer.
     57 */
     58#define HVS_SEND_BUF_SIZE \
     59		(HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header))
     60
     61struct hvs_send_buf {
     62	/* The header before the payload data */
     63	struct vmpipe_proto_header hdr;
     64
     65	/* The payload */
     66	u8 data[HVS_SEND_BUF_SIZE];
     67};
     68
     69#define HVS_HEADER_LEN	(sizeof(struct vmpacket_descriptor) + \
     70			 sizeof(struct vmpipe_proto_header))
     71
     72/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and
     73 * __hv_pkt_iter_next().
     74 */
     75#define VMBUS_PKT_TRAILER_SIZE	(sizeof(u64))
     76
     77#define HVS_PKT_LEN(payload_len)	(HVS_HEADER_LEN + \
     78					 ALIGN((payload_len), 8) + \
     79					 VMBUS_PKT_TRAILER_SIZE)
     80
     81/* Upper bound on the size of a VMbus packet for hv_sock */
     82#define HVS_MAX_PKT_SIZE	HVS_PKT_LEN(HVS_MTU_SIZE)
     83
     84union hvs_service_id {
     85	guid_t	srv_id;
     86
     87	struct {
     88		unsigned int svm_port;
     89		unsigned char b[sizeof(guid_t) - sizeof(unsigned int)];
     90	};
     91};
     92
     93/* Per-socket state (accessed via vsk->trans) */
     94struct hvsock {
     95	struct vsock_sock *vsk;
     96
     97	guid_t vm_srv_id;
     98	guid_t host_srv_id;
     99
    100	struct vmbus_channel *chan;
    101	struct vmpacket_descriptor *recv_desc;
    102
    103	/* The length of the payload not delivered to userland yet */
    104	u32 recv_data_len;
    105	/* The offset of the payload */
    106	u32 recv_data_off;
    107
    108	/* Have we sent the zero-length packet (FIN)? */
    109	bool fin_sent;
    110};
    111
    112/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
    113 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
    114 * when we write apps to connect to the host, we can only use VMADDR_CID_ANY
    115 * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we
    116 * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY
    117 * as the local cid.
    118 *
    119 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
    120 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-
    121 * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with
    122 * the below sockaddr:
    123 *
    124 * struct SOCKADDR_HV
    125 * {
    126 *    ADDRESS_FAMILY Family;
    127 *    USHORT Reserved;
    128 *    GUID VmId;
    129 *    GUID ServiceId;
    130 * };
    131 * Note: VmID is not used by Linux VM and actually it isn't transmitted via
    132 * VMBus, because here it's obvious the host and the VM can easily identify
    133 * each other. Though the VmID is useful on the host, especially in the case
    134 * of Windows container, Linux VM doesn't need it at all.
    135 *
    136 * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit
    137 * the available GUID space of SOCKADDR_HV so that we can create a mapping
    138 * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing
    139 * Hyper-V Sockets apps on the host and in Linux VM is:
    140 *
    141 ****************************************************************************
    142 * The only valid Service GUIDs, from the perspectives of both the host and *
    143 * Linux VM, that can be connected by the other end, must conform to this   *
    144 * format: <port>-facb-11e6-bd58-64006a7986d3.                              *
    145 ****************************************************************************
    146 *
    147 * When we write apps on the host to connect(), the GUID ServiceID is used.
    148 * When we write apps in Linux VM to connect(), we only need to specify the
    149 * port and the driver will form the GUID and use that to request the host.
    150 *
    151 */
    152
    153/* 00000000-facb-11e6-bd58-64006a7986d3 */
    154static const guid_t srv_id_template =
    155	GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
    156		  0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
    157
    158static bool hvs_check_transport(struct vsock_sock *vsk);
    159
    160static bool is_valid_srv_id(const guid_t *id)
    161{
    162	return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
    163}
    164
    165static unsigned int get_port_by_srv_id(const guid_t *svr_id)
    166{
    167	return *((unsigned int *)svr_id);
    168}
    169
    170static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id)
    171{
    172	unsigned int port = get_port_by_srv_id(svr_id);
    173
    174	vsock_addr_init(addr, VMADDR_CID_ANY, port);
    175}
    176
    177static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
    178{
    179	set_channel_pending_send_size(chan,
    180				      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));
    181
    182	virt_mb();
    183}
    184
    185static bool hvs_channel_readable(struct vmbus_channel *chan)
    186{
    187	u32 readable = hv_get_bytes_to_read(&chan->inbound);
    188
    189	/* 0-size payload means FIN */
    190	return readable >= HVS_PKT_LEN(0);
    191}
    192
    193static int hvs_channel_readable_payload(struct vmbus_channel *chan)
    194{
    195	u32 readable = hv_get_bytes_to_read(&chan->inbound);
    196
    197	if (readable > HVS_PKT_LEN(0)) {
    198		/* At least we have 1 byte to read. We don't need to return
    199		 * the exact readable bytes: see vsock_stream_recvmsg() ->
    200		 * vsock_stream_has_data().
    201		 */
    202		return 1;
    203	}
    204
    205	if (readable == HVS_PKT_LEN(0)) {
    206		/* 0-size payload means FIN */
    207		return 0;
    208	}
    209
    210	/* No payload or FIN */
    211	return -1;
    212}
    213
    214static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
    215{
    216	u32 writeable = hv_get_bytes_to_write(&chan->outbound);
    217	size_t ret;
    218
    219	/* The ringbuffer mustn't be 100% full, and we should reserve a
    220	 * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
    221	 * and hvs_shutdown().
    222	 */
    223	if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0))
    224		return 0;
    225
    226	ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0);
    227
    228	return round_down(ret, 8);
    229}
    230
    231static int __hvs_send_data(struct vmbus_channel *chan,
    232			   struct vmpipe_proto_header *hdr,
    233			   size_t to_write)
    234{
    235	hdr->pkt_type = 1;
    236	hdr->data_size = to_write;
    237	return vmbus_sendpacket(chan, hdr, sizeof(*hdr) + to_write,
    238				0, VM_PKT_DATA_INBAND, 0);
    239}
    240
    241static int hvs_send_data(struct vmbus_channel *chan,
    242			 struct hvs_send_buf *send_buf, size_t to_write)
    243{
    244	return __hvs_send_data(chan, &send_buf->hdr, to_write);
    245}
    246
    247static void hvs_channel_cb(void *ctx)
    248{
    249	struct sock *sk = (struct sock *)ctx;
    250	struct vsock_sock *vsk = vsock_sk(sk);
    251	struct hvsock *hvs = vsk->trans;
    252	struct vmbus_channel *chan = hvs->chan;
    253
    254	if (hvs_channel_readable(chan))
    255		sk->sk_data_ready(sk);
    256
    257	if (hv_get_bytes_to_write(&chan->outbound) > 0)
    258		sk->sk_write_space(sk);
    259}
    260
    261static void hvs_do_close_lock_held(struct vsock_sock *vsk,
    262				   bool cancel_timeout)
    263{
    264	struct sock *sk = sk_vsock(vsk);
    265
    266	sock_set_flag(sk, SOCK_DONE);
    267	vsk->peer_shutdown = SHUTDOWN_MASK;
    268	if (vsock_stream_has_data(vsk) <= 0)
    269		sk->sk_state = TCP_CLOSING;
    270	sk->sk_state_change(sk);
    271	if (vsk->close_work_scheduled &&
    272	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
    273		vsk->close_work_scheduled = false;
    274		vsock_remove_sock(vsk);
    275
    276		/* Release the reference taken while scheduling the timeout */
    277		sock_put(sk);
    278	}
    279}
    280
    281static void hvs_close_connection(struct vmbus_channel *chan)
    282{
    283	struct sock *sk = get_per_channel_state(chan);
    284
    285	lock_sock(sk);
    286	hvs_do_close_lock_held(vsock_sk(sk), true);
    287	release_sock(sk);
    288
    289	/* Release the refcnt for the channel that's opened in
    290	 * hvs_open_connection().
    291	 */
    292	sock_put(sk);
    293}
    294
    295static void hvs_open_connection(struct vmbus_channel *chan)
    296{
    297	guid_t *if_instance, *if_type;
    298	unsigned char conn_from_host;
    299
    300	struct sockaddr_vm addr;
    301	struct sock *sk, *new = NULL;
    302	struct vsock_sock *vnew = NULL;
    303	struct hvsock *hvs = NULL;
    304	struct hvsock *hvs_new = NULL;
    305	int rcvbuf;
    306	int ret;
    307	int sndbuf;
    308
    309	if_type = &chan->offermsg.offer.if_type;
    310	if_instance = &chan->offermsg.offer.if_instance;
    311	conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];
    312	if (!is_valid_srv_id(if_type))
    313		return;
    314
    315	hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
    316	sk = vsock_find_bound_socket(&addr);
    317	if (!sk)
    318		return;
    319
    320	lock_sock(sk);
    321	if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
    322	    (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
    323		goto out;
    324
    325	if (conn_from_host) {
    326		if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
    327			goto out;
    328
    329		new = vsock_create_connected(sk);
    330		if (!new)
    331			goto out;
    332
    333		new->sk_state = TCP_SYN_SENT;
    334		vnew = vsock_sk(new);
    335
    336		hvs_addr_init(&vnew->local_addr, if_type);
    337
    338		/* Remote peer is always the host */
    339		vsock_addr_init(&vnew->remote_addr,
    340				VMADDR_CID_HOST, VMADDR_PORT_ANY);
    341		vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
    342		ret = vsock_assign_transport(vnew, vsock_sk(sk));
    343		/* Transport assigned (looking at remote_addr) must be the
    344		 * same where we received the request.
    345		 */
    346		if (ret || !hvs_check_transport(vnew)) {
    347			sock_put(new);
    348			goto out;
    349		}
    350		hvs_new = vnew->trans;
    351		hvs_new->chan = chan;
    352	} else {
    353		hvs = vsock_sk(sk)->trans;
    354		hvs->chan = chan;
    355	}
    356
    357	set_channel_read_mode(chan, HV_CALL_DIRECT);
    358
    359	/* Use the socket buffer sizes as hints for the VMBUS ring size. For
    360	 * server side sockets, 'sk' is the parent socket and thus, this will
    361	 * allow the child sockets to inherit the size from the parent. Keep
    362	 * the mins to the default value and align to page size as per VMBUS
    363	 * requirements.
    364	 * For the max, the socket core library will limit the socket buffer
    365	 * size that can be set by the user, but, since currently, the hv_sock
    366	 * VMBUS ring buffer is physically contiguous allocation, restrict it
    367	 * further.
    368	 * Older versions of hv_sock host side code cannot handle bigger VMBUS
    369	 * ring buffer size. Use the version number to limit the change to newer
    370	 * versions.
    371	 */
    372	if (vmbus_proto_version < VERSION_WIN10_V5) {
    373		sndbuf = RINGBUFFER_HVS_SND_SIZE;
    374		rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
    375	} else {
    376		sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
    377		sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
    378		sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE);
    379		rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
    380		rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
    381		rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE);
    382	}
    383
    384	chan->max_pkt_size = HVS_MAX_PKT_SIZE;
    385
    386	ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
    387			 conn_from_host ? new : sk);
    388	if (ret != 0) {
    389		if (conn_from_host) {
    390			hvs_new->chan = NULL;
    391			sock_put(new);
    392		} else {
    393			hvs->chan = NULL;
    394		}
    395		goto out;
    396	}
    397
    398	set_per_channel_state(chan, conn_from_host ? new : sk);
    399
    400	/* This reference will be dropped by hvs_close_connection(). */
    401	sock_hold(conn_from_host ? new : sk);
    402	vmbus_set_chn_rescind_callback(chan, hvs_close_connection);
    403
    404	/* Set the pending send size to max packet size to always get
    405	 * notifications from the host when there is enough writable space.
    406	 * The host is optimized to send notifications only when the pending
    407	 * size boundary is crossed, and not always.
    408	 */
    409	hvs_set_channel_pending_send_size(chan);
    410
    411	if (conn_from_host) {
    412		new->sk_state = TCP_ESTABLISHED;
    413		sk_acceptq_added(sk);
    414
    415		hvs_new->vm_srv_id = *if_type;
    416		hvs_new->host_srv_id = *if_instance;
    417
    418		vsock_insert_connected(vnew);
    419
    420		vsock_enqueue_accept(sk, new);
    421	} else {
    422		sk->sk_state = TCP_ESTABLISHED;
    423		sk->sk_socket->state = SS_CONNECTED;
    424
    425		vsock_insert_connected(vsock_sk(sk));
    426	}
    427
    428	sk->sk_state_change(sk);
    429
    430out:
    431	/* Release refcnt obtained when we called vsock_find_bound_socket() */
    432	sock_put(sk);
    433
    434	release_sock(sk);
    435}
    436
    437static u32 hvs_get_local_cid(void)
    438{
    439	return VMADDR_CID_ANY;
    440}
    441
    442static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
    443{
    444	struct hvsock *hvs;
    445	struct sock *sk = sk_vsock(vsk);
    446
    447	hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
    448	if (!hvs)
    449		return -ENOMEM;
    450
    451	vsk->trans = hvs;
    452	hvs->vsk = vsk;
    453	sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE;
    454	sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
    455	return 0;
    456}
    457
    458static int hvs_connect(struct vsock_sock *vsk)
    459{
    460	union hvs_service_id vm, host;
    461	struct hvsock *h = vsk->trans;
    462
    463	vm.srv_id = srv_id_template;
    464	vm.svm_port = vsk->local_addr.svm_port;
    465	h->vm_srv_id = vm.srv_id;
    466
    467	host.srv_id = srv_id_template;
    468	host.svm_port = vsk->remote_addr.svm_port;
    469	h->host_srv_id = host.srv_id;
    470
    471	return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
    472}
    473
    474static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
    475{
    476	struct vmpipe_proto_header hdr;
    477
    478	if (hvs->fin_sent || !hvs->chan)
    479		return;
    480
    481	/* It can't fail: see hvs_channel_writable_bytes(). */
    482	(void)__hvs_send_data(hvs->chan, &hdr, 0);
    483	hvs->fin_sent = true;
    484}
    485
    486static int hvs_shutdown(struct vsock_sock *vsk, int mode)
    487{
    488	if (!(mode & SEND_SHUTDOWN))
    489		return 0;
    490
    491	hvs_shutdown_lock_held(vsk->trans, mode);
    492	return 0;
    493}
    494
    495static void hvs_close_timeout(struct work_struct *work)
    496{
    497	struct vsock_sock *vsk =
    498		container_of(work, struct vsock_sock, close_work.work);
    499	struct sock *sk = sk_vsock(vsk);
    500
    501	sock_hold(sk);
    502	lock_sock(sk);
    503	if (!sock_flag(sk, SOCK_DONE))
    504		hvs_do_close_lock_held(vsk, false);
    505
    506	vsk->close_work_scheduled = false;
    507	release_sock(sk);
    508	sock_put(sk);
    509}
    510
    511/* Returns true, if it is safe to remove socket; false otherwise */
    512static bool hvs_close_lock_held(struct vsock_sock *vsk)
    513{
    514	struct sock *sk = sk_vsock(vsk);
    515
    516	if (!(sk->sk_state == TCP_ESTABLISHED ||
    517	      sk->sk_state == TCP_CLOSING))
    518		return true;
    519
    520	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
    521		hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);
    522
    523	if (sock_flag(sk, SOCK_DONE))
    524		return true;
    525
    526	/* This reference will be dropped by the delayed close routine */
    527	sock_hold(sk);
    528	INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
    529	vsk->close_work_scheduled = true;
    530	schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
    531	return false;
    532}
    533
    534static void hvs_release(struct vsock_sock *vsk)
    535{
    536	bool remove_sock;
    537
    538	remove_sock = hvs_close_lock_held(vsk);
    539	if (remove_sock)
    540		vsock_remove_sock(vsk);
    541}
    542
    543static void hvs_destruct(struct vsock_sock *vsk)
    544{
    545	struct hvsock *hvs = vsk->trans;
    546	struct vmbus_channel *chan = hvs->chan;
    547
    548	if (chan)
    549		vmbus_hvsock_device_unregister(chan);
    550
    551	kfree(hvs);
    552}
    553
    554static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
    555{
    556	return -EOPNOTSUPP;
    557}
    558
    559static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
    560			     size_t len, int flags)
    561{
    562	return -EOPNOTSUPP;
    563}
    564
    565static int hvs_dgram_enqueue(struct vsock_sock *vsk,
    566			     struct sockaddr_vm *remote, struct msghdr *msg,
    567			     size_t dgram_len)
    568{
    569	return -EOPNOTSUPP;
    570}
    571
    572static bool hvs_dgram_allow(u32 cid, u32 port)
    573{
    574	return false;
    575}
    576
    577static int hvs_update_recv_data(struct hvsock *hvs)
    578{
    579	struct hvs_recv_buf *recv_buf;
    580	u32 pkt_len, payload_len;
    581
    582	pkt_len = hv_pkt_len(hvs->recv_desc);
    583
    584	if (pkt_len < HVS_HEADER_LEN)
    585		return -EIO;
    586
    587	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
    588	payload_len = recv_buf->hdr.data_size;
    589
    590	if (payload_len > pkt_len - HVS_HEADER_LEN ||
    591	    payload_len > HVS_MTU_SIZE)
    592		return -EIO;
    593
    594	if (payload_len == 0)
    595		hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
    596
    597	hvs->recv_data_len = payload_len;
    598	hvs->recv_data_off = 0;
    599
    600	return 0;
    601}
    602
    603static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
    604				  size_t len, int flags)
    605{
    606	struct hvsock *hvs = vsk->trans;
    607	bool need_refill = !hvs->recv_desc;
    608	struct hvs_recv_buf *recv_buf;
    609	u32 to_read;
    610	int ret;
    611
    612	if (flags & MSG_PEEK)
    613		return -EOPNOTSUPP;
    614
    615	if (need_refill) {
    616		hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
    617		if (!hvs->recv_desc)
    618			return -ENOBUFS;
    619		ret = hvs_update_recv_data(hvs);
    620		if (ret)
    621			return ret;
    622	}
    623
    624	recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
    625	to_read = min_t(u32, len, hvs->recv_data_len);
    626	ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
    627	if (ret != 0)
    628		return ret;
    629
    630	hvs->recv_data_len -= to_read;
    631	if (hvs->recv_data_len == 0) {
    632		hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
    633		if (hvs->recv_desc) {
    634			ret = hvs_update_recv_data(hvs);
    635			if (ret)
    636				return ret;
    637		}
    638	} else {
    639		hvs->recv_data_off += to_read;
    640	}
    641
    642	return to_read;
    643}
    644
    645static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
    646				  size_t len)
    647{
    648	struct hvsock *hvs = vsk->trans;
    649	struct vmbus_channel *chan = hvs->chan;
    650	struct hvs_send_buf *send_buf;
    651	ssize_t to_write, max_writable;
    652	ssize_t ret = 0;
    653	ssize_t bytes_written = 0;
    654
    655	BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE);
    656
    657	send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
    658	if (!send_buf)
    659		return -ENOMEM;
    660
    661	/* Reader(s) could be draining data from the channel as we write.
    662	 * Maximize bandwidth, by iterating until the channel is found to be
    663	 * full.
    664	 */
    665	while (len) {
    666		max_writable = hvs_channel_writable_bytes(chan);
    667		if (!max_writable)
    668			break;
    669		to_write = min_t(ssize_t, len, max_writable);
    670		to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
    671		/* memcpy_from_msg is safe for loop as it advances the offsets
    672		 * within the message iterator.
    673		 */
    674		ret = memcpy_from_msg(send_buf->data, msg, to_write);
    675		if (ret < 0)
    676			goto out;
    677
    678		ret = hvs_send_data(hvs->chan, send_buf, to_write);
    679		if (ret < 0)
    680			goto out;
    681
    682		bytes_written += to_write;
    683		len -= to_write;
    684	}
    685out:
    686	/* If any data has been sent, return that */
    687	if (bytes_written)
    688		ret = bytes_written;
    689	kfree(send_buf);
    690	return ret;
    691}
    692
    693static s64 hvs_stream_has_data(struct vsock_sock *vsk)
    694{
    695	struct hvsock *hvs = vsk->trans;
    696	s64 ret;
    697
    698	if (hvs->recv_data_len > 0)
    699		return 1;
    700
    701	switch (hvs_channel_readable_payload(hvs->chan)) {
    702	case 1:
    703		ret = 1;
    704		break;
    705	case 0:
    706		vsk->peer_shutdown |= SEND_SHUTDOWN;
    707		ret = 0;
    708		break;
    709	default: /* -1 */
    710		ret = 0;
    711		break;
    712	}
    713
    714	return ret;
    715}
    716
    717static s64 hvs_stream_has_space(struct vsock_sock *vsk)
    718{
    719	struct hvsock *hvs = vsk->trans;
    720
    721	return hvs_channel_writable_bytes(hvs->chan);
    722}
    723
    724static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
    725{
    726	return HVS_MTU_SIZE + 1;
    727}
    728
    729static bool hvs_stream_is_active(struct vsock_sock *vsk)
    730{
    731	struct hvsock *hvs = vsk->trans;
    732
    733	return hvs->chan != NULL;
    734}
    735
    736static bool hvs_stream_allow(u32 cid, u32 port)
    737{
    738	if (cid == VMADDR_CID_HOST)
    739		return true;
    740
    741	return false;
    742}
    743
    744static
    745int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable)
    746{
    747	struct hvsock *hvs = vsk->trans;
    748
    749	*readable = hvs_channel_readable(hvs->chan);
    750	return 0;
    751}
    752
    753static
    754int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable)
    755{
    756	*writable = hvs_stream_has_space(vsk) > 0;
    757
    758	return 0;
    759}
    760
    761static
    762int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
    763			 struct vsock_transport_recv_notify_data *d)
    764{
    765	return 0;
    766}
    767
    768static
    769int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
    770			      struct vsock_transport_recv_notify_data *d)
    771{
    772	return 0;
    773}
    774
    775static
    776int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
    777				struct vsock_transport_recv_notify_data *d)
    778{
    779	return 0;
    780}
    781
    782static
    783int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
    784				 ssize_t copied, bool data_read,
    785				 struct vsock_transport_recv_notify_data *d)
    786{
    787	return 0;
    788}
    789
    790static
    791int hvs_notify_send_init(struct vsock_sock *vsk,
    792			 struct vsock_transport_send_notify_data *d)
    793{
    794	return 0;
    795}
    796
    797static
    798int hvs_notify_send_pre_block(struct vsock_sock *vsk,
    799			      struct vsock_transport_send_notify_data *d)
    800{
    801	return 0;
    802}
    803
    804static
    805int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
    806				struct vsock_transport_send_notify_data *d)
    807{
    808	return 0;
    809}
    810
    811static
    812int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
    813				 struct vsock_transport_send_notify_data *d)
    814{
    815	return 0;
    816}
    817
    818static struct vsock_transport hvs_transport = {
    819	.module                   = THIS_MODULE,
    820
    821	.get_local_cid            = hvs_get_local_cid,
    822
    823	.init                     = hvs_sock_init,
    824	.destruct                 = hvs_destruct,
    825	.release                  = hvs_release,
    826	.connect                  = hvs_connect,
    827	.shutdown                 = hvs_shutdown,
    828
    829	.dgram_bind               = hvs_dgram_bind,
    830	.dgram_dequeue            = hvs_dgram_dequeue,
    831	.dgram_enqueue            = hvs_dgram_enqueue,
    832	.dgram_allow              = hvs_dgram_allow,
    833
    834	.stream_dequeue           = hvs_stream_dequeue,
    835	.stream_enqueue           = hvs_stream_enqueue,
    836	.stream_has_data          = hvs_stream_has_data,
    837	.stream_has_space         = hvs_stream_has_space,
    838	.stream_rcvhiwat          = hvs_stream_rcvhiwat,
    839	.stream_is_active         = hvs_stream_is_active,
    840	.stream_allow             = hvs_stream_allow,
    841
    842	.notify_poll_in           = hvs_notify_poll_in,
    843	.notify_poll_out          = hvs_notify_poll_out,
    844	.notify_recv_init         = hvs_notify_recv_init,
    845	.notify_recv_pre_block    = hvs_notify_recv_pre_block,
    846	.notify_recv_pre_dequeue  = hvs_notify_recv_pre_dequeue,
    847	.notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
    848	.notify_send_init         = hvs_notify_send_init,
    849	.notify_send_pre_block    = hvs_notify_send_pre_block,
    850	.notify_send_pre_enqueue  = hvs_notify_send_pre_enqueue,
    851	.notify_send_post_enqueue = hvs_notify_send_post_enqueue,
    852
    853};
    854
    855static bool hvs_check_transport(struct vsock_sock *vsk)
    856{
    857	return vsk->transport == &hvs_transport;
    858}
    859
    860static int hvs_probe(struct hv_device *hdev,
    861		     const struct hv_vmbus_device_id *dev_id)
    862{
    863	struct vmbus_channel *chan = hdev->channel;
    864
    865	hvs_open_connection(chan);
    866
    867	/* Always return success to suppress the unnecessary error message
    868	 * in vmbus_probe(): on error the host will rescind the device in
    869	 * 30 seconds and we can do cleanup at that time in
    870	 * vmbus_onoffer_rescind().
    871	 */
    872	return 0;
    873}
    874
    875static int hvs_remove(struct hv_device *hdev)
    876{
    877	struct vmbus_channel *chan = hdev->channel;
    878
    879	vmbus_close(chan);
    880
    881	return 0;
    882}
    883
    884/* hv_sock connections can not persist across hibernation, and all the hv_sock
    885 * channels are forced to be rescinded before hibernation: see
    886 * vmbus_bus_suspend(). Here the dummy hvs_suspend() and hvs_resume()
    887 * are only needed because hibernation requires that every vmbus device's
    888 * driver should have a .suspend and .resume callback: see vmbus_suspend().
    889 */
    890static int hvs_suspend(struct hv_device *hv_dev)
    891{
    892	/* Dummy */
    893	return 0;
    894}
    895
    896static int hvs_resume(struct hv_device *dev)
    897{
    898	/* Dummy */
    899	return 0;
    900}
    901
    902/* This isn't really used. See vmbus_match() and vmbus_probe() */
    903static const struct hv_vmbus_device_id id_table[] = {
    904	{},
    905};
    906
    907static struct hv_driver hvs_drv = {
    908	.name		= "hv_sock",
    909	.hvsock		= true,
    910	.id_table	= id_table,
    911	.probe		= hvs_probe,
    912	.remove		= hvs_remove,
    913	.suspend	= hvs_suspend,
    914	.resume		= hvs_resume,
    915};
    916
    917static int __init hvs_init(void)
    918{
    919	int ret;
    920
    921	if (vmbus_proto_version < VERSION_WIN10)
    922		return -ENODEV;
    923
    924	ret = vmbus_driver_register(&hvs_drv);
    925	if (ret != 0)
    926		return ret;
    927
    928	ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H);
    929	if (ret) {
    930		vmbus_driver_unregister(&hvs_drv);
    931		return ret;
    932	}
    933
    934	return 0;
    935}
    936
    937static void __exit hvs_exit(void)
    938{
    939	vsock_core_unregister(&hvs_transport);
    940	vmbus_driver_unregister(&hvs_drv);
    941}
    942
    943module_init(hvs_init);
    944module_exit(hvs_exit);
    945
    946MODULE_DESCRIPTION("Hyper-V Sockets");
    947MODULE_VERSION("1.0.0");
    948MODULE_LICENSE("GPL");
    949MODULE_ALIAS_NETPROTO(PF_VSOCK);