cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

virtio-iommu.c (31147B)


      1// SPDX-License-Identifier: GPL-2.0
      2/*
      3 * Virtio driver for the paravirtualized IOMMU
      4 *
      5 * Copyright (C) 2019 Arm Limited
      6 */
      7
      8#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
      9
     10#include <linux/amba/bus.h>
     11#include <linux/delay.h>
     12#include <linux/dma-iommu.h>
     13#include <linux/dma-map-ops.h>
     14#include <linux/freezer.h>
     15#include <linux/interval_tree.h>
     16#include <linux/iommu.h>
     17#include <linux/module.h>
     18#include <linux/of_platform.h>
     19#include <linux/pci.h>
     20#include <linux/platform_device.h>
     21#include <linux/virtio.h>
     22#include <linux/virtio_config.h>
     23#include <linux/virtio_ids.h>
     24#include <linux/wait.h>
     25
     26#include <uapi/linux/virtio_iommu.h>
     27
     28#define MSI_IOVA_BASE			0x8000000
     29#define MSI_IOVA_LENGTH			0x100000
     30
     31#define VIOMMU_REQUEST_VQ		0
     32#define VIOMMU_EVENT_VQ			1
     33#define VIOMMU_NR_VQS			2
     34
     35struct viommu_dev {
     36	struct iommu_device		iommu;
     37	struct device			*dev;
     38	struct virtio_device		*vdev;
     39
     40	struct ida			domain_ids;
     41
     42	struct virtqueue		*vqs[VIOMMU_NR_VQS];
     43	spinlock_t			request_lock;
     44	struct list_head		requests;
     45	void				*evts;
     46
     47	/* Device configuration */
     48	struct iommu_domain_geometry	geometry;
     49	u64				pgsize_bitmap;
     50	u32				first_domain;
     51	u32				last_domain;
     52	/* Supported MAP flags */
     53	u32				map_flags;
     54	u32				probe_size;
     55};
     56
     57struct viommu_mapping {
     58	phys_addr_t			paddr;
     59	struct interval_tree_node	iova;
     60	u32				flags;
     61};
     62
     63struct viommu_domain {
     64	struct iommu_domain		domain;
     65	struct viommu_dev		*viommu;
     66	struct mutex			mutex; /* protects viommu pointer */
     67	unsigned int			id;
     68	u32				map_flags;
     69
     70	spinlock_t			mappings_lock;
     71	struct rb_root_cached		mappings;
     72
     73	unsigned long			nr_endpoints;
     74	bool				bypass;
     75};
     76
     77struct viommu_endpoint {
     78	struct device			*dev;
     79	struct viommu_dev		*viommu;
     80	struct viommu_domain		*vdomain;
     81	struct list_head		resv_regions;
     82};
     83
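/*
 * A queued request. buf[] mirrors the caller's buffer: bytes
 * [0, write_offset) are copied in and mapped device-readable (the request
 * head and payload), while bytes [write_offset, len) are left for the
 * device to write (the status tail, plus probe properties for a PROBE
 * request). When @writeback is set, __viommu_sync_req() copies the
 * device-written bytes back into the caller's buffer on completion.
 */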
     84struct viommu_request {
     85	struct list_head		list;
     86	void				*writeback;
     87	unsigned int			write_offset;
     88	unsigned int			len;
     89	char				buf[];
     90};
     91
     92#define VIOMMU_FAULT_RESV_MASK		0xffffff00
     93
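/*
 * An event buffer, sized to receive a fault report. The union overlays the
 * first 32-bit word of the report: viommu_event_handler() only treats an
 * event as a fault when none of the VIOMMU_FAULT_RESV_MASK bits (the
 * reserved bytes following the 8-bit fault reason) are set.
 */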
     94struct viommu_event {
     95	union {
     96		u32			head;
     97		struct virtio_iommu_fault fault;
     98	};
     99};
    100
    101#define to_viommu_domain(domain)	\
    102	container_of(domain, struct viommu_domain, domain)
    103
    104static int viommu_get_req_errno(void *buf, size_t len)
    105{
    106	struct virtio_iommu_req_tail *tail = buf + len - sizeof(*tail);
    107
    108	switch (tail->status) {
    109	case VIRTIO_IOMMU_S_OK:
    110		return 0;
    111	case VIRTIO_IOMMU_S_UNSUPP:
    112		return -ENOSYS;
    113	case VIRTIO_IOMMU_S_INVAL:
    114		return -EINVAL;
    115	case VIRTIO_IOMMU_S_RANGE:
    116		return -ERANGE;
    117	case VIRTIO_IOMMU_S_NOENT:
    118		return -ENOENT;
    119	case VIRTIO_IOMMU_S_FAULT:
    120		return -EFAULT;
    121	case VIRTIO_IOMMU_S_NOMEM:
    122		return -ENOMEM;
    123	case VIRTIO_IOMMU_S_IOERR:
    124	case VIRTIO_IOMMU_S_DEVERR:
    125	default:
    126		return -EIO;
    127	}
    128}
    129
    130static void viommu_set_req_status(void *buf, size_t len, int status)
    131{
    132	struct virtio_iommu_req_tail *tail = buf + len - sizeof(*tail);
    133
    134	tail->status = status;
    135}
    136
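/*
 * Offset of the first device-writable byte in a request. For most request
 * types only the trailing virtio_iommu_req_tail is written by the device;
 * a PROBE request additionally receives probe_size bytes of properties
 * just before the tail, so its write descriptor starts earlier.
 */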
    137static off_t viommu_get_write_desc_offset(struct viommu_dev *viommu,
    138					  struct virtio_iommu_req_head *req,
    139					  size_t len)
    140{
    141	size_t tail_size = sizeof(struct virtio_iommu_req_tail);
    142
    143	if (req->type == VIRTIO_IOMMU_T_PROBE)
    144		return len - viommu->probe_size - tail_size;
    145
    146	return len - tail_size;
    147}
    148
    149/*
    150 * __viommu_sync_req - Complete all in-flight requests
    151 *
    152 * Wait for all added requests to complete. When this function returns, all
    153 * requests that were in-flight at the time of the call have completed.
    154 */
    155static int __viommu_sync_req(struct viommu_dev *viommu)
    156{
    157	unsigned int len;
    158	size_t write_len;
    159	struct viommu_request *req;
    160	struct virtqueue *vq = viommu->vqs[VIOMMU_REQUEST_VQ];
    161
    162	assert_spin_locked(&viommu->request_lock);
    163
    164	virtqueue_kick(vq);
    165
    166	while (!list_empty(&viommu->requests)) {
    167		len = 0;
    168		req = virtqueue_get_buf(vq, &len);
    169		if (!req)
    170			continue;
    171
    172		if (!len)
    173			viommu_set_req_status(req->buf, req->len,
    174					      VIRTIO_IOMMU_S_IOERR);
    175
    176		write_len = req->len - req->write_offset;
    177		if (req->writeback && len == write_len)
    178			memcpy(req->writeback, req->buf + req->write_offset,
    179			       write_len);
    180
    181		list_del(&req->list);
    182		kfree(req);
    183	}
    184
    185	return 0;
    186}
    187
    188static int viommu_sync_req(struct viommu_dev *viommu)
    189{
    190	int ret;
    191	unsigned long flags;
    192
    193	spin_lock_irqsave(&viommu->request_lock, flags);
    194	ret = __viommu_sync_req(viommu);
    195	if (ret)
    196		dev_dbg(viommu->dev, "could not sync requests (%d)\n", ret);
    197	spin_unlock_irqrestore(&viommu->request_lock, flags);
    198
    199	return ret;
    200}
    201
    202/*
     203 * __viommu_add_req - Add one request to the queue
    204 * @buf: pointer to the request buffer
    205 * @len: length of the request buffer
    206 * @writeback: copy data back to the buffer when the request completes.
    207 *
    208 * Add a request to the queue. Only synchronize the queue if it's already full.
     209 * Otherwise don't kick the queue or wait for requests to complete.
    210 *
    211 * When @writeback is true, data written by the device, including the request
    212 * status, is copied into @buf after the request completes. This is unsafe if
    213 * the caller allocates @buf on stack and drops the lock between add_req() and
    214 * sync_req().
    215 *
    216 * Return 0 if the request was successfully added to the queue.
    217 */
    218static int __viommu_add_req(struct viommu_dev *viommu, void *buf, size_t len,
    219			    bool writeback)
    220{
    221	int ret;
    222	off_t write_offset;
    223	struct viommu_request *req;
    224	struct scatterlist top_sg, bottom_sg;
    225	struct scatterlist *sg[2] = { &top_sg, &bottom_sg };
    226	struct virtqueue *vq = viommu->vqs[VIOMMU_REQUEST_VQ];
    227
    228	assert_spin_locked(&viommu->request_lock);
    229
    230	write_offset = viommu_get_write_desc_offset(viommu, buf, len);
    231	if (write_offset <= 0)
    232		return -EINVAL;
    233
    234	req = kzalloc(sizeof(*req) + len, GFP_ATOMIC);
    235	if (!req)
    236		return -ENOMEM;
    237
    238	req->len = len;
    239	if (writeback) {
    240		req->writeback = buf + write_offset;
    241		req->write_offset = write_offset;
    242	}
    243	memcpy(&req->buf, buf, write_offset);
    244
    245	sg_init_one(&top_sg, req->buf, write_offset);
    246	sg_init_one(&bottom_sg, req->buf + write_offset, len - write_offset);
    247
    248	ret = virtqueue_add_sgs(vq, sg, 1, 1, req, GFP_ATOMIC);
    249	if (ret == -ENOSPC) {
    250		/* If the queue is full, sync and retry */
    251		if (!__viommu_sync_req(viommu))
    252			ret = virtqueue_add_sgs(vq, sg, 1, 1, req, GFP_ATOMIC);
    253	}
    254	if (ret)
    255		goto err_free;
    256
    257	list_add_tail(&req->list, &viommu->requests);
    258	return 0;
    259
    260err_free:
    261	kfree(req);
    262	return ret;
    263}
    264
    265static int viommu_add_req(struct viommu_dev *viommu, void *buf, size_t len)
    266{
    267	int ret;
    268	unsigned long flags;
    269
    270	spin_lock_irqsave(&viommu->request_lock, flags);
    271	ret = __viommu_add_req(viommu, buf, len, false);
    272	if (ret)
    273		dev_dbg(viommu->dev, "could not add request: %d\n", ret);
    274	spin_unlock_irqrestore(&viommu->request_lock, flags);
    275
    276	return ret;
    277}
    278
    279/*
    280 * Send a request and wait for it to complete. Return the request status (as an
     281 * errno).
    282 */
    283static int viommu_send_req_sync(struct viommu_dev *viommu, void *buf,
    284				size_t len)
    285{
    286	int ret;
    287	unsigned long flags;
    288
    289	spin_lock_irqsave(&viommu->request_lock, flags);
    290
    291	ret = __viommu_add_req(viommu, buf, len, true);
    292	if (ret) {
    293		dev_dbg(viommu->dev, "could not add request (%d)\n", ret);
    294		goto out_unlock;
    295	}
    296
    297	ret = __viommu_sync_req(viommu);
    298	if (ret) {
    299		dev_dbg(viommu->dev, "could not sync requests (%d)\n", ret);
    300		/* Fall-through (get the actual request status) */
    301	}
    302
    303	ret = viommu_get_req_errno(buf, len);
    304out_unlock:
    305	spin_unlock_irqrestore(&viommu->request_lock, flags);
    306	return ret;
    307}
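
/*
 * Illustrative use of viommu_send_req_sync() above, mirroring what
 * viommu_attach_dev() does later in this file (vdomain and fwspec are
 * assumed to be in scope):
 *
 *	struct virtio_iommu_req_attach req = {
 *		.head.type	= VIRTIO_IOMMU_T_ATTACH,
 *		.domain		= cpu_to_le32(vdomain->id),
 *		.endpoint	= cpu_to_le32(fwspec->ids[i]),
 *	};
 *
 *	ret = viommu_send_req_sync(vdomain->viommu, &req, sizeof(req));
 *
 * On return, ret holds the status written by the device, converted to an
 * errno by viommu_get_req_errno().
 */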
    308
    309/*
    310 * viommu_add_mapping - add a mapping to the internal tree
    311 *
     312 * On success, return 0. Otherwise return -ENOMEM.
    313 */
    314static int viommu_add_mapping(struct viommu_domain *vdomain, u64 iova, u64 end,
    315			      phys_addr_t paddr, u32 flags)
    316{
    317	unsigned long irqflags;
    318	struct viommu_mapping *mapping;
    319
    320	mapping = kzalloc(sizeof(*mapping), GFP_ATOMIC);
    321	if (!mapping)
    322		return -ENOMEM;
    323
    324	mapping->paddr		= paddr;
    325	mapping->iova.start	= iova;
    326	mapping->iova.last	= end;
    327	mapping->flags		= flags;
    328
    329	spin_lock_irqsave(&vdomain->mappings_lock, irqflags);
    330	interval_tree_insert(&mapping->iova, &vdomain->mappings);
    331	spin_unlock_irqrestore(&vdomain->mappings_lock, irqflags);
    332
    333	return 0;
    334}
    335
    336/*
    337 * viommu_del_mappings - remove mappings from the internal tree
    338 *
    339 * @vdomain: the domain
    340 * @iova: start of the range
    341 * @end: end of the range
    342 *
    343 * On success, returns the number of unmapped bytes
    344 */
    345static size_t viommu_del_mappings(struct viommu_domain *vdomain,
    346				  u64 iova, u64 end)
    347{
    348	size_t unmapped = 0;
    349	unsigned long flags;
    350	struct viommu_mapping *mapping = NULL;
    351	struct interval_tree_node *node, *next;
    352
    353	spin_lock_irqsave(&vdomain->mappings_lock, flags);
    354	next = interval_tree_iter_first(&vdomain->mappings, iova, end);
    355	while (next) {
    356		node = next;
    357		mapping = container_of(node, struct viommu_mapping, iova);
    358		next = interval_tree_iter_next(node, iova, end);
    359
    360		/* Trying to split a mapping? */
    361		if (mapping->iova.start < iova)
    362			break;
    363
    364		/*
    365		 * Virtio-iommu doesn't allow UNMAP to split a mapping created
    366		 * with a single MAP request, so remove the full mapping.
    367		 */
    368		unmapped += mapping->iova.last - mapping->iova.start + 1;
    369
    370		interval_tree_remove(node, &vdomain->mappings);
    371		kfree(mapping);
    372	}
    373	spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
    374
    375	return unmapped;
    376}
    377
    378/*
    379 * Fill the domain with identity mappings, skipping the device's reserved
    380 * regions.
    381 */
    382static int viommu_domain_map_identity(struct viommu_endpoint *vdev,
    383				      struct viommu_domain *vdomain)
    384{
    385	int ret;
    386	struct iommu_resv_region *resv;
    387	u64 iova = vdomain->domain.geometry.aperture_start;
    388	u64 limit = vdomain->domain.geometry.aperture_end;
    389	u32 flags = VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE;
    390	unsigned long granule = 1UL << __ffs(vdomain->domain.pgsize_bitmap);
    391
    392	iova = ALIGN(iova, granule);
    393	limit = ALIGN_DOWN(limit + 1, granule) - 1;
    394
    395	list_for_each_entry(resv, &vdev->resv_regions, list) {
    396		u64 resv_start = ALIGN_DOWN(resv->start, granule);
    397		u64 resv_end = ALIGN(resv->start + resv->length, granule) - 1;
    398
    399		if (resv_end < iova || resv_start > limit)
    400			/* No overlap */
    401			continue;
    402
    403		if (resv_start > iova) {
    404			ret = viommu_add_mapping(vdomain, iova, resv_start - 1,
    405						 (phys_addr_t)iova, flags);
    406			if (ret)
    407				goto err_unmap;
    408		}
    409
    410		if (resv_end >= limit)
    411			return 0;
    412
    413		iova = resv_end + 1;
    414	}
    415
    416	ret = viommu_add_mapping(vdomain, iova, limit, (phys_addr_t)iova,
    417				 flags);
    418	if (ret)
    419		goto err_unmap;
    420	return 0;
    421
    422err_unmap:
    423	viommu_del_mappings(vdomain, 0, iova);
    424	return ret;
    425}
    426
    427/*
    428 * viommu_replay_mappings - re-send MAP requests
    429 *
    430 * When reattaching a domain that was previously detached from all endpoints,
    431 * mappings were deleted from the device. Re-create the mappings available in
    432 * the internal tree.
    433 */
    434static int viommu_replay_mappings(struct viommu_domain *vdomain)
    435{
    436	int ret = 0;
    437	unsigned long flags;
    438	struct viommu_mapping *mapping;
    439	struct interval_tree_node *node;
    440	struct virtio_iommu_req_map map;
    441
    442	spin_lock_irqsave(&vdomain->mappings_lock, flags);
    443	node = interval_tree_iter_first(&vdomain->mappings, 0, -1UL);
    444	while (node) {
    445		mapping = container_of(node, struct viommu_mapping, iova);
    446		map = (struct virtio_iommu_req_map) {
    447			.head.type	= VIRTIO_IOMMU_T_MAP,
    448			.domain		= cpu_to_le32(vdomain->id),
    449			.virt_start	= cpu_to_le64(mapping->iova.start),
    450			.virt_end	= cpu_to_le64(mapping->iova.last),
    451			.phys_start	= cpu_to_le64(mapping->paddr),
    452			.flags		= cpu_to_le32(mapping->flags),
    453		};
    454
    455		ret = viommu_send_req_sync(vdomain->viommu, &map, sizeof(map));
    456		if (ret)
    457			break;
    458
    459		node = interval_tree_iter_next(node, 0, -1UL);
    460	}
    461	spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
    462
    463	return ret;
    464}
    465
    466static int viommu_add_resv_mem(struct viommu_endpoint *vdev,
    467			       struct virtio_iommu_probe_resv_mem *mem,
    468			       size_t len)
    469{
    470	size_t size;
    471	u64 start64, end64;
    472	phys_addr_t start, end;
    473	struct iommu_resv_region *region = NULL, *next;
    474	unsigned long prot = IOMMU_WRITE | IOMMU_NOEXEC | IOMMU_MMIO;
    475
    476	start = start64 = le64_to_cpu(mem->start);
    477	end = end64 = le64_to_cpu(mem->end);
    478	size = end64 - start64 + 1;
    479
    480	/* Catch any overflow, including the unlikely end64 - start64 + 1 = 0 */
    481	if (start != start64 || end != end64 || size < end64 - start64)
    482		return -EOVERFLOW;
    483
    484	if (len < sizeof(*mem))
    485		return -EINVAL;
    486
    487	switch (mem->subtype) {
    488	default:
    489		dev_warn(vdev->dev, "unknown resv mem subtype 0x%x\n",
    490			 mem->subtype);
    491		fallthrough;
    492	case VIRTIO_IOMMU_RESV_MEM_T_RESERVED:
    493		region = iommu_alloc_resv_region(start, size, 0,
    494						 IOMMU_RESV_RESERVED);
    495		break;
    496	case VIRTIO_IOMMU_RESV_MEM_T_MSI:
    497		region = iommu_alloc_resv_region(start, size, prot,
    498						 IOMMU_RESV_MSI);
    499		break;
    500	}
    501	if (!region)
    502		return -ENOMEM;
    503
    504	/* Keep the list sorted */
    505	list_for_each_entry(next, &vdev->resv_regions, list) {
    506		if (next->start > region->start)
    507			break;
    508	}
    509	list_add_tail(&region->list, &next->list);
    510	return 0;
    511}
    512
    513static int viommu_probe_endpoint(struct viommu_dev *viommu, struct device *dev)
    514{
    515	int ret;
    516	u16 type, len;
    517	size_t cur = 0;
    518	size_t probe_len;
    519	struct virtio_iommu_req_probe *probe;
    520	struct virtio_iommu_probe_property *prop;
    521	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
    522	struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
    523
    524	if (!fwspec->num_ids)
    525		return -EINVAL;
    526
    527	probe_len = sizeof(*probe) + viommu->probe_size +
    528		    sizeof(struct virtio_iommu_req_tail);
    529	probe = kzalloc(probe_len, GFP_KERNEL);
    530	if (!probe)
    531		return -ENOMEM;
    532
    533	probe->head.type = VIRTIO_IOMMU_T_PROBE;
    534	/*
    535	 * For now, assume that properties of an endpoint that outputs multiple
    536	 * IDs are consistent. Only probe the first one.
    537	 */
    538	probe->endpoint = cpu_to_le32(fwspec->ids[0]);
    539
    540	ret = viommu_send_req_sync(viommu, probe, probe_len);
    541	if (ret)
    542		goto out_free;
    543
    544	prop = (void *)probe->properties;
    545	type = le16_to_cpu(prop->type) & VIRTIO_IOMMU_PROBE_T_MASK;
    546
    547	while (type != VIRTIO_IOMMU_PROBE_T_NONE &&
    548	       cur < viommu->probe_size) {
    549		len = le16_to_cpu(prop->length) + sizeof(*prop);
    550
    551		switch (type) {
    552		case VIRTIO_IOMMU_PROBE_T_RESV_MEM:
    553			ret = viommu_add_resv_mem(vdev, (void *)prop, len);
    554			break;
    555		default:
    556			dev_err(dev, "unknown viommu prop 0x%x\n", type);
    557		}
    558
    559		if (ret)
    560			dev_err(dev, "failed to parse viommu prop 0x%x\n", type);
    561
    562		cur += len;
    563		if (cur >= viommu->probe_size)
    564			break;
    565
    566		prop = (void *)probe->properties + cur;
    567		type = le16_to_cpu(prop->type) & VIRTIO_IOMMU_PROBE_T_MASK;
    568	}
    569
    570out_free:
    571	kfree(probe);
    572	return ret;
    573}
    574
    575static int viommu_fault_handler(struct viommu_dev *viommu,
    576				struct virtio_iommu_fault *fault)
    577{
    578	char *reason_str;
    579
    580	u8 reason	= fault->reason;
    581	u32 flags	= le32_to_cpu(fault->flags);
    582	u32 endpoint	= le32_to_cpu(fault->endpoint);
    583	u64 address	= le64_to_cpu(fault->address);
    584
    585	switch (reason) {
    586	case VIRTIO_IOMMU_FAULT_R_DOMAIN:
    587		reason_str = "domain";
    588		break;
    589	case VIRTIO_IOMMU_FAULT_R_MAPPING:
    590		reason_str = "page";
    591		break;
    592	case VIRTIO_IOMMU_FAULT_R_UNKNOWN:
    593	default:
    594		reason_str = "unknown";
    595		break;
    596	}
    597
    598	/* TODO: find EP by ID and report_iommu_fault */
    599	if (flags & VIRTIO_IOMMU_FAULT_F_ADDRESS)
    600		dev_err_ratelimited(viommu->dev, "%s fault from EP %u at %#llx [%s%s%s]\n",
    601				    reason_str, endpoint, address,
    602				    flags & VIRTIO_IOMMU_FAULT_F_READ ? "R" : "",
    603				    flags & VIRTIO_IOMMU_FAULT_F_WRITE ? "W" : "",
    604				    flags & VIRTIO_IOMMU_FAULT_F_EXEC ? "X" : "");
    605	else
    606		dev_err_ratelimited(viommu->dev, "%s fault from EP %u\n",
    607				    reason_str, endpoint);
    608	return 0;
    609}
    610
    611static void viommu_event_handler(struct virtqueue *vq)
    612{
    613	int ret;
    614	unsigned int len;
    615	struct scatterlist sg[1];
    616	struct viommu_event *evt;
    617	struct viommu_dev *viommu = vq->vdev->priv;
    618
    619	while ((evt = virtqueue_get_buf(vq, &len)) != NULL) {
    620		if (len > sizeof(*evt)) {
    621			dev_err(viommu->dev,
    622				"invalid event buffer (len %u != %zu)\n",
    623				len, sizeof(*evt));
    624		} else if (!(evt->head & VIOMMU_FAULT_RESV_MASK)) {
    625			viommu_fault_handler(viommu, &evt->fault);
    626		}
    627
    628		sg_init_one(sg, evt, sizeof(*evt));
    629		ret = virtqueue_add_inbuf(vq, sg, 1, evt, GFP_ATOMIC);
    630		if (ret)
    631			dev_err(viommu->dev, "could not add event buffer\n");
    632	}
    633
    634	virtqueue_kick(vq);
    635}
    636
    637/* IOMMU API */
    638
    639static struct iommu_domain *viommu_domain_alloc(unsigned type)
    640{
    641	struct viommu_domain *vdomain;
    642
    643	if (type != IOMMU_DOMAIN_UNMANAGED &&
    644	    type != IOMMU_DOMAIN_DMA &&
    645	    type != IOMMU_DOMAIN_IDENTITY)
    646		return NULL;
    647
    648	vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
    649	if (!vdomain)
    650		return NULL;
    651
    652	mutex_init(&vdomain->mutex);
    653	spin_lock_init(&vdomain->mappings_lock);
    654	vdomain->mappings = RB_ROOT_CACHED;
    655
    656	return &vdomain->domain;
    657}
    658
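/*
 * A domain is allocated before we know which viommu instance will back it,
 * so viommu_domain_alloc() leaves it unconfigured. The first attach
 * finalises it here: the domain inherits the viommu's page size bitmap,
 * aperture and map flags, and gets a device domain ID from the viommu's
 * ID range.
 */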
    659static int viommu_domain_finalise(struct viommu_endpoint *vdev,
    660				  struct iommu_domain *domain)
    661{
    662	int ret;
    663	unsigned long viommu_page_size;
    664	struct viommu_dev *viommu = vdev->viommu;
    665	struct viommu_domain *vdomain = to_viommu_domain(domain);
    666
    667	viommu_page_size = 1UL << __ffs(viommu->pgsize_bitmap);
    668	if (viommu_page_size > PAGE_SIZE) {
    669		dev_err(vdev->dev,
    670			"granule 0x%lx larger than system page size 0x%lx\n",
    671			viommu_page_size, PAGE_SIZE);
    672		return -EINVAL;
    673	}
    674
    675	ret = ida_alloc_range(&viommu->domain_ids, viommu->first_domain,
    676			      viommu->last_domain, GFP_KERNEL);
    677	if (ret < 0)
    678		return ret;
    679
    680	vdomain->id		= (unsigned int)ret;
    681
    682	domain->pgsize_bitmap	= viommu->pgsize_bitmap;
    683	domain->geometry	= viommu->geometry;
    684
    685	vdomain->map_flags	= viommu->map_flags;
    686	vdomain->viommu		= viommu;
    687
    688	if (domain->type == IOMMU_DOMAIN_IDENTITY) {
    689		if (virtio_has_feature(viommu->vdev,
    690				       VIRTIO_IOMMU_F_BYPASS_CONFIG)) {
    691			vdomain->bypass = true;
    692			return 0;
    693		}
    694
    695		ret = viommu_domain_map_identity(vdev, vdomain);
    696		if (ret) {
    697			ida_free(&viommu->domain_ids, vdomain->id);
    698			vdomain->viommu = NULL;
    699			return -EOPNOTSUPP;
    700		}
    701	}
    702
    703	return 0;
    704}
    705
    706static void viommu_domain_free(struct iommu_domain *domain)
    707{
    708	struct viommu_domain *vdomain = to_viommu_domain(domain);
    709
    710	/* Free all remaining mappings */
    711	viommu_del_mappings(vdomain, 0, ULLONG_MAX);
    712
    713	if (vdomain->viommu)
    714		ida_free(&vdomain->viommu->domain_ids, vdomain->id);
    715
    716	kfree(vdomain);
    717}
    718
    719static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
    720{
    721	int i;
    722	int ret = 0;
    723	struct virtio_iommu_req_attach req;
    724	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
    725	struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
    726	struct viommu_domain *vdomain = to_viommu_domain(domain);
    727
    728	mutex_lock(&vdomain->mutex);
    729	if (!vdomain->viommu) {
    730		/*
    731		 * Properly initialize the domain now that we know which viommu
    732		 * owns it.
    733		 */
    734		ret = viommu_domain_finalise(vdev, domain);
    735	} else if (vdomain->viommu != vdev->viommu) {
    736		dev_err(dev, "cannot attach to foreign vIOMMU\n");
    737		ret = -EXDEV;
    738	}
    739	mutex_unlock(&vdomain->mutex);
    740
    741	if (ret)
    742		return ret;
    743
    744	/*
     745	 * In the virtio-iommu device, attaching an endpoint to a new domain
     746	 * detaches it from its old domain. If the old domain is then left
     747	 * without any endpoint, all of its mappings are removed and the
     748	 * domain is freed.
    749	 *
    750	 * In the driver the old domain still exists, and its mappings will be
    751	 * recreated if it gets reattached to an endpoint. Otherwise it will be
    752	 * freed explicitly.
    753	 *
    754	 * vdev->vdomain is protected by group->mutex
    755	 */
    756	if (vdev->vdomain)
    757		vdev->vdomain->nr_endpoints--;
    758
    759	req = (struct virtio_iommu_req_attach) {
    760		.head.type	= VIRTIO_IOMMU_T_ATTACH,
    761		.domain		= cpu_to_le32(vdomain->id),
    762	};
    763
    764	if (vdomain->bypass)
    765		req.flags |= cpu_to_le32(VIRTIO_IOMMU_ATTACH_F_BYPASS);
    766
    767	for (i = 0; i < fwspec->num_ids; i++) {
    768		req.endpoint = cpu_to_le32(fwspec->ids[i]);
    769
    770		ret = viommu_send_req_sync(vdomain->viommu, &req, sizeof(req));
    771		if (ret)
    772			return ret;
    773	}
    774
    775	if (!vdomain->nr_endpoints) {
    776		/*
    777		 * This endpoint is the first to be attached to the domain.
    778		 * Replay existing mappings (e.g. SW MSI).
    779		 */
    780		ret = viommu_replay_mappings(vdomain);
    781		if (ret)
    782			return ret;
    783	}
    784
    785	vdomain->nr_endpoints++;
    786	vdev->vdomain = vdomain;
    787
    788	return 0;
    789}
    790
    791static int viommu_map(struct iommu_domain *domain, unsigned long iova,
    792		      phys_addr_t paddr, size_t size, int prot, gfp_t gfp)
    793{
    794	int ret;
    795	u32 flags;
    796	u64 end = iova + size - 1;
    797	struct virtio_iommu_req_map map;
    798	struct viommu_domain *vdomain = to_viommu_domain(domain);
    799
    800	flags = (prot & IOMMU_READ ? VIRTIO_IOMMU_MAP_F_READ : 0) |
    801		(prot & IOMMU_WRITE ? VIRTIO_IOMMU_MAP_F_WRITE : 0) |
    802		(prot & IOMMU_MMIO ? VIRTIO_IOMMU_MAP_F_MMIO : 0);
    803
    804	if (flags & ~vdomain->map_flags)
    805		return -EINVAL;
    806
    807	ret = viommu_add_mapping(vdomain, iova, end, paddr, flags);
    808	if (ret)
    809		return ret;
    810
    811	map = (struct virtio_iommu_req_map) {
    812		.head.type	= VIRTIO_IOMMU_T_MAP,
    813		.domain		= cpu_to_le32(vdomain->id),
    814		.virt_start	= cpu_to_le64(iova),
    815		.phys_start	= cpu_to_le64(paddr),
    816		.virt_end	= cpu_to_le64(end),
    817		.flags		= cpu_to_le32(flags),
    818	};
    819
    820	if (!vdomain->nr_endpoints)
    821		return 0;
    822
    823	ret = viommu_send_req_sync(vdomain->viommu, &map, sizeof(map));
    824	if (ret)
    825		viommu_del_mappings(vdomain, iova, end);
    826
    827	return ret;
    828}
    829
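/*
 * Unmap requests are only queued here; they are normally kicked and waited
 * for in viommu_iotlb_sync(), so multiple unmaps within one iotlb_gather
 * window share a single request-queue sync.
 */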
    830static size_t viommu_unmap(struct iommu_domain *domain, unsigned long iova,
    831			   size_t size, struct iommu_iotlb_gather *gather)
    832{
    833	int ret = 0;
    834	size_t unmapped;
    835	struct virtio_iommu_req_unmap unmap;
    836	struct viommu_domain *vdomain = to_viommu_domain(domain);
    837
    838	unmapped = viommu_del_mappings(vdomain, iova, iova + size - 1);
    839	if (unmapped < size)
    840		return 0;
    841
    842	/* Device already removed all mappings after detach. */
    843	if (!vdomain->nr_endpoints)
    844		return unmapped;
    845
    846	unmap = (struct virtio_iommu_req_unmap) {
    847		.head.type	= VIRTIO_IOMMU_T_UNMAP,
    848		.domain		= cpu_to_le32(vdomain->id),
    849		.virt_start	= cpu_to_le64(iova),
    850		.virt_end	= cpu_to_le64(iova + unmapped - 1),
    851	};
    852
    853	ret = viommu_add_req(vdomain->viommu, &unmap, sizeof(unmap));
    854	return ret ? 0 : unmapped;
    855}
    856
    857static phys_addr_t viommu_iova_to_phys(struct iommu_domain *domain,
    858				       dma_addr_t iova)
    859{
    860	u64 paddr = 0;
    861	unsigned long flags;
    862	struct viommu_mapping *mapping;
    863	struct interval_tree_node *node;
    864	struct viommu_domain *vdomain = to_viommu_domain(domain);
    865
    866	spin_lock_irqsave(&vdomain->mappings_lock, flags);
    867	node = interval_tree_iter_first(&vdomain->mappings, iova, iova);
    868	if (node) {
    869		mapping = container_of(node, struct viommu_mapping, iova);
    870		paddr = mapping->paddr + (iova - mapping->iova.start);
    871	}
    872	spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
    873
    874	return paddr;
    875}
    876
    877static void viommu_iotlb_sync(struct iommu_domain *domain,
    878			      struct iommu_iotlb_gather *gather)
    879{
    880	struct viommu_domain *vdomain = to_viommu_domain(domain);
    881
    882	viommu_sync_req(vdomain->viommu);
    883}
    884
    885static void viommu_get_resv_regions(struct device *dev, struct list_head *head)
    886{
    887	struct iommu_resv_region *entry, *new_entry, *msi = NULL;
    888	struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
    889	int prot = IOMMU_WRITE | IOMMU_NOEXEC | IOMMU_MMIO;
    890
    891	list_for_each_entry(entry, &vdev->resv_regions, list) {
    892		if (entry->type == IOMMU_RESV_MSI)
    893			msi = entry;
    894
    895		new_entry = kmemdup(entry, sizeof(*entry), GFP_KERNEL);
    896		if (!new_entry)
    897			return;
    898		list_add_tail(&new_entry->list, head);
    899	}
    900
    901	/*
    902	 * If the device didn't register any bypass MSI window, add a
    903	 * software-mapped region.
    904	 */
    905	if (!msi) {
    906		msi = iommu_alloc_resv_region(MSI_IOVA_BASE, MSI_IOVA_LENGTH,
    907					      prot, IOMMU_RESV_SW_MSI);
    908		if (!msi)
    909			return;
    910
    911		list_add_tail(&msi->list, head);
    912	}
    913
    914	iommu_dma_get_resv_regions(dev, head);
    915}
    916
    917static struct iommu_ops viommu_ops;
    918static struct virtio_driver virtio_iommu_drv;
    919
    920static int viommu_match_node(struct device *dev, const void *data)
    921{
    922	return dev->parent->fwnode == data;
    923}
    924
    925static struct viommu_dev *viommu_get_by_fwnode(struct fwnode_handle *fwnode)
    926{
    927	struct device *dev = driver_find_device(&virtio_iommu_drv.driver, NULL,
    928						fwnode, viommu_match_node);
    929	put_device(dev);
    930
    931	return dev ? dev_to_virtio(dev)->priv : NULL;
    932}
    933
    934static struct iommu_device *viommu_probe_device(struct device *dev)
    935{
    936	int ret;
    937	struct viommu_endpoint *vdev;
    938	struct viommu_dev *viommu = NULL;
    939	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
    940
    941	if (!fwspec || fwspec->ops != &viommu_ops)
    942		return ERR_PTR(-ENODEV);
    943
    944	viommu = viommu_get_by_fwnode(fwspec->iommu_fwnode);
    945	if (!viommu)
    946		return ERR_PTR(-ENODEV);
    947
    948	vdev = kzalloc(sizeof(*vdev), GFP_KERNEL);
    949	if (!vdev)
    950		return ERR_PTR(-ENOMEM);
    951
    952	vdev->dev = dev;
    953	vdev->viommu = viommu;
    954	INIT_LIST_HEAD(&vdev->resv_regions);
    955	dev_iommu_priv_set(dev, vdev);
    956
    957	if (viommu->probe_size) {
    958		/* Get additional information for this endpoint */
    959		ret = viommu_probe_endpoint(viommu, dev);
    960		if (ret)
    961			goto err_free_dev;
    962	}
    963
    964	return &viommu->iommu;
    965
    966err_free_dev:
    967	generic_iommu_put_resv_regions(dev, &vdev->resv_regions);
    968	kfree(vdev);
    969
    970	return ERR_PTR(ret);
    971}
    972
    973static void viommu_probe_finalize(struct device *dev)
    974{
    975#ifndef CONFIG_ARCH_HAS_SETUP_DMA_OPS
    976	/* First clear the DMA ops in case we're switching from a DMA domain */
    977	set_dma_ops(dev, NULL);
    978	iommu_setup_dma_ops(dev, 0, U64_MAX);
    979#endif
    980}
    981
    982static void viommu_release_device(struct device *dev)
    983{
    984	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
    985	struct viommu_endpoint *vdev;
    986
    987	if (!fwspec || fwspec->ops != &viommu_ops)
    988		return;
    989
    990	vdev = dev_iommu_priv_get(dev);
    991
    992	generic_iommu_put_resv_regions(dev, &vdev->resv_regions);
    993	kfree(vdev);
    994}
    995
    996static struct iommu_group *viommu_device_group(struct device *dev)
    997{
    998	if (dev_is_pci(dev))
    999		return pci_device_group(dev);
   1000	else
   1001		return generic_device_group(dev);
   1002}
   1003
   1004static int viommu_of_xlate(struct device *dev, struct of_phandle_args *args)
   1005{
   1006	return iommu_fwspec_add_ids(dev, args->args, 1);
   1007}
   1008
   1009static struct iommu_ops viommu_ops = {
   1010	.domain_alloc		= viommu_domain_alloc,
   1011	.probe_device		= viommu_probe_device,
   1012	.probe_finalize		= viommu_probe_finalize,
   1013	.release_device		= viommu_release_device,
   1014	.device_group		= viommu_device_group,
   1015	.get_resv_regions	= viommu_get_resv_regions,
   1016	.put_resv_regions	= generic_iommu_put_resv_regions,
   1017	.of_xlate		= viommu_of_xlate,
   1018	.owner			= THIS_MODULE,
   1019	.default_domain_ops = &(const struct iommu_domain_ops) {
   1020		.attach_dev		= viommu_attach_dev,
   1021		.map			= viommu_map,
   1022		.unmap			= viommu_unmap,
   1023		.iova_to_phys		= viommu_iova_to_phys,
   1024		.iotlb_sync		= viommu_iotlb_sync,
   1025		.free			= viommu_domain_free,
   1026	}
   1027};
   1028
   1029static int viommu_init_vqs(struct viommu_dev *viommu)
   1030{
   1031	struct virtio_device *vdev = dev_to_virtio(viommu->dev);
   1032	const char *names[] = { "request", "event" };
   1033	vq_callback_t *callbacks[] = {
   1034		NULL, /* No async requests */
   1035		viommu_event_handler,
   1036	};
   1037
   1038	return virtio_find_vqs(vdev, VIOMMU_NR_VQS, viommu->vqs, callbacks,
   1039			       names, NULL);
   1040}
   1041
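/*
 * Fill the event queue with receive buffers. Each buffer handed to the
 * device here is re-queued by viommu_event_handler() after the device
 * reports an event, so the queue stays populated for the device's
 * lifetime.
 */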
   1042static int viommu_fill_evtq(struct viommu_dev *viommu)
   1043{
   1044	int i, ret;
   1045	struct scatterlist sg[1];
   1046	struct viommu_event *evts;
   1047	struct virtqueue *vq = viommu->vqs[VIOMMU_EVENT_VQ];
   1048	size_t nr_evts = vq->num_free;
   1049
   1050	viommu->evts = evts = devm_kmalloc_array(viommu->dev, nr_evts,
   1051						 sizeof(*evts), GFP_KERNEL);
   1052	if (!evts)
   1053		return -ENOMEM;
   1054
   1055	for (i = 0; i < nr_evts; i++) {
   1056		sg_init_one(sg, &evts[i], sizeof(*evts));
   1057		ret = virtqueue_add_inbuf(vq, sg, 1, &evts[i], GFP_KERNEL);
   1058		if (ret)
   1059			return ret;
   1060	}
   1061
   1062	return 0;
   1063}
   1064
   1065static int viommu_probe(struct virtio_device *vdev)
   1066{
   1067	struct device *parent_dev = vdev->dev.parent;
   1068	struct viommu_dev *viommu = NULL;
   1069	struct device *dev = &vdev->dev;
   1070	u64 input_start = 0;
   1071	u64 input_end = -1UL;
   1072	int ret;
   1073
   1074	if (!virtio_has_feature(vdev, VIRTIO_F_VERSION_1) ||
   1075	    !virtio_has_feature(vdev, VIRTIO_IOMMU_F_MAP_UNMAP))
   1076		return -ENODEV;
   1077
   1078	viommu = devm_kzalloc(dev, sizeof(*viommu), GFP_KERNEL);
   1079	if (!viommu)
   1080		return -ENOMEM;
   1081
   1082	spin_lock_init(&viommu->request_lock);
   1083	ida_init(&viommu->domain_ids);
   1084	viommu->dev = dev;
   1085	viommu->vdev = vdev;
   1086	INIT_LIST_HEAD(&viommu->requests);
   1087
   1088	ret = viommu_init_vqs(viommu);
   1089	if (ret)
   1090		return ret;
   1091
   1092	virtio_cread_le(vdev, struct virtio_iommu_config, page_size_mask,
   1093			&viommu->pgsize_bitmap);
   1094
   1095	if (!viommu->pgsize_bitmap) {
   1096		ret = -EINVAL;
   1097		goto err_free_vqs;
   1098	}
   1099
   1100	viommu->map_flags = VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE;
   1101	viommu->last_domain = ~0U;
   1102
   1103	/* Optional features */
   1104	virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_INPUT_RANGE,
   1105				struct virtio_iommu_config, input_range.start,
   1106				&input_start);
   1107
   1108	virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_INPUT_RANGE,
   1109				struct virtio_iommu_config, input_range.end,
   1110				&input_end);
   1111
   1112	virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_DOMAIN_RANGE,
   1113				struct virtio_iommu_config, domain_range.start,
   1114				&viommu->first_domain);
   1115
   1116	virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_DOMAIN_RANGE,
   1117				struct virtio_iommu_config, domain_range.end,
   1118				&viommu->last_domain);
   1119
   1120	virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_PROBE,
   1121				struct virtio_iommu_config, probe_size,
   1122				&viommu->probe_size);
   1123
   1124	viommu->geometry = (struct iommu_domain_geometry) {
   1125		.aperture_start	= input_start,
   1126		.aperture_end	= input_end,
   1127		.force_aperture	= true,
   1128	};
   1129
   1130	if (virtio_has_feature(vdev, VIRTIO_IOMMU_F_MMIO))
   1131		viommu->map_flags |= VIRTIO_IOMMU_MAP_F_MMIO;
   1132
   1133	viommu_ops.pgsize_bitmap = viommu->pgsize_bitmap;
   1134
   1135	virtio_device_ready(vdev);
   1136
   1137	/* Populate the event queue with buffers */
   1138	ret = viommu_fill_evtq(viommu);
   1139	if (ret)
   1140		goto err_free_vqs;
   1141
   1142	ret = iommu_device_sysfs_add(&viommu->iommu, dev, NULL, "%s",
   1143				     virtio_bus_name(vdev));
   1144	if (ret)
   1145		goto err_free_vqs;
   1146
   1147	iommu_device_register(&viommu->iommu, &viommu_ops, parent_dev);
   1148
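	/*
	 * Hook this driver's ops into each bus type we may serve, unless
	 * another IOMMU driver (or a previous viommu instance) already did.
	 * Registering the ops lets endpoints on those buses be probed
	 * through viommu_probe_device().
	 */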
   1149#ifdef CONFIG_PCI
   1150	if (pci_bus_type.iommu_ops != &viommu_ops) {
   1151		ret = bus_set_iommu(&pci_bus_type, &viommu_ops);
   1152		if (ret)
   1153			goto err_unregister;
   1154	}
   1155#endif
   1156#ifdef CONFIG_ARM_AMBA
   1157	if (amba_bustype.iommu_ops != &viommu_ops) {
   1158		ret = bus_set_iommu(&amba_bustype, &viommu_ops);
   1159		if (ret)
   1160			goto err_unregister;
   1161	}
   1162#endif
   1163	if (platform_bus_type.iommu_ops != &viommu_ops) {
   1164		ret = bus_set_iommu(&platform_bus_type, &viommu_ops);
   1165		if (ret)
   1166			goto err_unregister;
   1167	}
   1168
   1169	vdev->priv = viommu;
   1170
   1171	dev_info(dev, "input address: %u bits\n",
   1172		 order_base_2(viommu->geometry.aperture_end));
   1173	dev_info(dev, "page mask: %#llx\n", viommu->pgsize_bitmap);
   1174
   1175	return 0;
   1176
   1177err_unregister:
   1178	iommu_device_sysfs_remove(&viommu->iommu);
   1179	iommu_device_unregister(&viommu->iommu);
   1180err_free_vqs:
   1181	vdev->config->del_vqs(vdev);
   1182
   1183	return ret;
   1184}
   1185
   1186static void viommu_remove(struct virtio_device *vdev)
   1187{
   1188	struct viommu_dev *viommu = vdev->priv;
   1189
   1190	iommu_device_sysfs_remove(&viommu->iommu);
   1191	iommu_device_unregister(&viommu->iommu);
   1192
   1193	/* Stop all virtqueues */
   1194	virtio_reset_device(vdev);
   1195	vdev->config->del_vqs(vdev);
   1196
   1197	dev_info(&vdev->dev, "device removed\n");
   1198}
   1199
   1200static void viommu_config_changed(struct virtio_device *vdev)
   1201{
   1202	dev_warn(&vdev->dev, "config changed\n");
   1203}
   1204
   1205static unsigned int features[] = {
   1206	VIRTIO_IOMMU_F_MAP_UNMAP,
   1207	VIRTIO_IOMMU_F_INPUT_RANGE,
   1208	VIRTIO_IOMMU_F_DOMAIN_RANGE,
   1209	VIRTIO_IOMMU_F_PROBE,
   1210	VIRTIO_IOMMU_F_MMIO,
   1211	VIRTIO_IOMMU_F_BYPASS_CONFIG,
   1212};
   1213
   1214static struct virtio_device_id id_table[] = {
   1215	{ VIRTIO_ID_IOMMU, VIRTIO_DEV_ANY_ID },
   1216	{ 0 },
   1217};
   1218MODULE_DEVICE_TABLE(virtio, id_table);
   1219
   1220static struct virtio_driver virtio_iommu_drv = {
   1221	.driver.name		= KBUILD_MODNAME,
   1222	.driver.owner		= THIS_MODULE,
   1223	.id_table		= id_table,
   1224	.feature_table		= features,
   1225	.feature_table_size	= ARRAY_SIZE(features),
   1226	.probe			= viommu_probe,
   1227	.remove			= viommu_remove,
   1228	.config_changed		= viommu_config_changed,
   1229};
   1230
   1231module_virtio_driver(virtio_iommu_drv);
   1232
   1233MODULE_DESCRIPTION("Virtio IOMMU driver");
   1234MODULE_AUTHOR("Jean-Philippe Brucker <jean-philippe.brucker@arm.com>");
   1235MODULE_LICENSE("GPL v2");