cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

iommu_v2.c (22119B)


// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>
#include <linux/cc_platform.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	refcount_t count;				/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;                 /* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	struct list_head list;
	u32 sbdf;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

static struct device_state *__get_device_state(u32 sbdf)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->sbdf == sbdf)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u32 sbdf)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(sbdf);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

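/*
 * The per-device PASID table is a sparse tree with 512 slots (9 PASID bits)
 * per level; intermediate levels are allocated on demand when 'alloc' is
 * true.  Returns a pointer to the leaf slot for 'pasid', or NULL if the
 * entry is not present (or an allocation failed).
 */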
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		refcount_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (refcount_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	refcount_dec(&pasid_state->count);
	wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid; no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

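/*
 * MMU notifier callback: the CPU page tables for [start, end) changed, so
 * flush the matching translations for this PASID on the IOMMU - a single
 * page if the range fits within one page, otherwise the whole per-PASID TLB.
 */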
static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release		= mn_release,
	.invalidate_range       = mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

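/*
 * Drop one in-flight reference for this PRI tag.  When the last fault for
 * the tag has been handled and the device requested a response (finish bit
 * set), report the accumulated status via amd_iommu_complete_ppr().
 */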
static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

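/*
 * Workqueue handler for one queued PRI fault: look up the VMA in the bound
 * mm, check access permissions and resolve the fault with handle_mm_fault(),
 * then complete the PRI tag with the result.
 */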
static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

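/*
 * Notifier invoked by the AMD IOMMU driver for each incoming peripheral
 * page request (PPR).  Requests for unknown devices or PASIDs are answered
 * with PPR_INVALID right away; everything else is queued as a struct fault
 * work item and handled by do_fault() on iommu_wq.
 */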
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid, seg_id;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	seg_id = PCI_SBDF_TO_SEGID(iommu_fault->sbdf);
	devid = PCI_SBDF_TO_DEVID(iommu_fault->sbdf);
	pdev = pci_get_domain_bus_and_slot(seg_id, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In a kdump kernel the pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(&pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->sbdf);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u32 sbdf;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf      = get_pci_sbdf_id(pdev);
	dev_state = get_device_state(sbdf);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;


	refcount_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u32 sbdf;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	sbdf = get_pci_sbdf_id(pdev);
	dev_state = get_device_state(sbdf);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

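/*
 * Enable IOMMUv2 (PASID/PRI) handling for a device: allocate the
 * device_state and PASID table, set up a direct-mapped v2 domain, attach
 * the device's IOMMU group to it and publish the state on the global list.
 */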
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u32 sbdf;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (cc_platform_has(CC_ATTR_MEM_ENCRYPT))
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	sbdf = get_pci_sbdf_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->sbdf = sbdf;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(sbdf) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;

	if (!amd_iommu_v2_supported())
		return;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	put_device_state(dev_state);
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	pr_info("AMD IOMMUv2 loaded and initialized\n");

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state, *next;
	unsigned long flags;
	LIST_HEAD(freelist);

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	spin_lock_irqsave(&state_lock, flags);

	list_for_each_entry_safe(dev_state, next, &state_list, list) {
		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		list_del(&dev_state->list);
		list_add_tail(&dev_state->list, &freelist);
	}

	spin_unlock_irqrestore(&state_lock, flags);

	/*
	 * Since free_device_state waits on the count to be zero,
	 * we need to free dev_state outside the spinlock.
	 */
	list_for_each_entry_safe(dev_state, next, &freelist, list) {
		list_del(&dev_state->list);
		free_device_state(dev_state);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);
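
For context, the exported entry points above (amd_iommu_init_device(), amd_iommu_bind_pasid(), amd_iommu_unbind_pasid(), amd_iommu_free_device() and the callback setters) are meant to be called by a device driver that wants PASID-based, faulting DMA. The sketch below is hypothetical driver glue, not part of this file; it assumes a PCIe device with PASID/PRI support and only illustrates the expected call order.

#include <linux/amd-iommu.h>
#include <linux/pci.h>
#include <linux/sched.h>

#define EXAMPLE_PASID	1	/* arbitrary PASID chosen for illustration */

/* Bind the calling task's address space to EXAMPLE_PASID on 'pdev'. */
static int example_enable_pasid(struct pci_dev *pdev)
{
	int ret;

	/* Allocate per-device IOMMUv2 state with room for 16 PASIDs. */
	ret = amd_iommu_init_device(pdev, 16);
	if (ret)
		return ret;

	/*
	 * Faults raised by the device on EXAMPLE_PASID are now resolved
	 * against current->mm by the PPR machinery in this file.
	 */
	ret = amd_iommu_bind_pasid(pdev, EXAMPLE_PASID, current);
	if (ret)
		amd_iommu_free_device(pdev);

	return ret;
}

static void example_disable_pasid(struct pci_dev *pdev)
{
	amd_iommu_unbind_pasid(pdev, EXAMPLE_PASID);
	amd_iommu_free_device(pdev);
}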