cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

test_hmm.c (30484B)


      1// SPDX-License-Identifier: GPL-2.0
      2/*
      3 * This is a module to test the HMM (Heterogeneous Memory Management)
      4 * mirror and zone device private memory migration APIs of the kernel.
      5 * Userspace programs can register with the driver to mirror their own address
      6 * space and can use the device to read/write any valid virtual address.
      7 */
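
For orientation, here is a minimal userspace sketch of how a program drives
this module. It is illustrative only: it assumes a /dev/hmm_dmirror0 node has
been created for the first minor (the accompanying selftest script does this)
and relies on the struct hmm_dmirror_cmd layout and HMM_DMIRROR_READ request
declared in test_hmm_uapi.h.

	#include <fcntl.h>
	#include <stdint.h>
	#include <stdlib.h>
	#include <string.h>
	#include <unistd.h>
	#include <sys/ioctl.h>
	#include "test_hmm_uapi.h"

	int mirror_read_example(void)
	{
		long psz = sysconf(_SC_PAGESIZE);
		struct hmm_dmirror_cmd cmd = { 0 };
		char *src, *dst;
		int fd;

		/* open() registers a mirror of the calling process's mm */
		fd = open("/dev/hmm_dmirror0", O_RDWR);
		if (fd < 0)
			return -1;

		src = aligned_alloc(psz, psz);	/* cmd.addr must be page aligned */
		dst = aligned_alloc(psz, psz);
		if (!src || !dst)
			return -1;
		memset(src, 0xaa, psz);

		cmd.addr = (uintptr_t)src;	/* virtual address to mirror */
		cmd.ptr = (uintptr_t)dst;	/* buffer the driver copies into */
		cmd.npages = 1;

		/* faults the page via hmm_range_fault() and copies it out;
		 * cmd.cpages and cmd.faults report what the driver did */
		if (ioctl(fd, HMM_DMIRROR_READ, &cmd) < 0)
			return -1;

		close(fd);
		return memcmp(src, dst, psz);	/* 0 on success */
	}
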
      8#include <linux/init.h>
      9#include <linux/fs.h>
     10#include <linux/mm.h>
     11#include <linux/module.h>
     12#include <linux/kernel.h>
     13#include <linux/cdev.h>
     14#include <linux/device.h>
     15#include <linux/memremap.h>
     16#include <linux/mutex.h>
     17#include <linux/rwsem.h>
     18#include <linux/sched.h>
     19#include <linux/slab.h>
     20#include <linux/highmem.h>
     21#include <linux/delay.h>
     22#include <linux/pagemap.h>
     23#include <linux/hmm.h>
     24#include <linux/vmalloc.h>
     25#include <linux/swap.h>
     26#include <linux/swapops.h>
     27#include <linux/sched/mm.h>
     28#include <linux/platform_device.h>
     29#include <linux/rmap.h>
     30#include <linux/mmu_notifier.h>
     31#include <linux/migrate.h>
     32
     33#include "test_hmm_uapi.h"
     34
     35#define DMIRROR_NDEVICES		2
     36#define DMIRROR_RANGE_FAULT_TIMEOUT	1000
     37#define DEVMEM_CHUNK_SIZE		(256 * 1024 * 1024U)
     38#define DEVMEM_CHUNKS_RESERVE		16
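
With 4 KiB pages, each 256 MiB DEVMEM_CHUNK_SIZE chunk of simulated device
memory supplies 65536 ZONE_DEVICE pages, and the per-device chunk pointer
array grows DEVMEM_CHUNKS_RESERVE entries at a time when it fills up (see
dmirror_allocate_chunk() below).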
     39
     40static const struct dev_pagemap_ops dmirror_devmem_ops;
     41static const struct mmu_interval_notifier_ops dmirror_min_ops;
     42static dev_t dmirror_dev;
     43
     44struct dmirror_device;
     45
     46struct dmirror_bounce {
     47	void			*ptr;
     48	unsigned long		size;
     49	unsigned long		addr;
     50	unsigned long		cpages;
     51};
     52
     53#define DPT_XA_TAG_ATOMIC 1UL
     54#define DPT_XA_TAG_WRITE 3UL
     55
     56/*
     57 * Data structure to track address ranges and register for mmu interval
     58 * notifier updates.
     59 */
     60struct dmirror_interval {
     61	struct mmu_interval_notifier	notifier;
     62	struct dmirror			*dmirror;
     63};
     64
     65/*
     66 * Data attached to the open device file.
     67 * Note that it might be shared after a fork().
     68 */
     69struct dmirror {
     70	struct dmirror_device		*mdevice;
     71	struct xarray			pt;
     72	struct mmu_interval_notifier	notifier;
     73	struct mutex			mutex;
     74};
     75
     76/*
     77 * ZONE_DEVICE pages for migration and simulating device memory.
     78 */
     79struct dmirror_chunk {
     80	struct dev_pagemap	pagemap;
     81	struct dmirror_device	*mdevice;
     82};
     83
     84/*
     85 * Per device data.
     86 */
     87struct dmirror_device {
     88	struct cdev		cdevice;
     89	struct hmm_devmem	*devmem;
     90
     91	unsigned int		devmem_capacity;
     92	unsigned int		devmem_count;
     93	struct dmirror_chunk	**devmem_chunks;
     94	struct mutex		devmem_lock;	/* protects the above */
     95
     96	unsigned long		calloc;
     97	unsigned long		cfree;
     98	struct page		*free_pages;
     99	spinlock_t		lock;		/* protects the above */
    100};
    101
    102static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];
    103
    104static int dmirror_bounce_init(struct dmirror_bounce *bounce,
    105			       unsigned long addr,
    106			       unsigned long size)
    107{
    108	bounce->addr = addr;
    109	bounce->size = size;
    110	bounce->cpages = 0;
    111	bounce->ptr = vmalloc(size);
    112	if (!bounce->ptr)
    113		return -ENOMEM;
    114	return 0;
    115}
    116
    117static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
    118{
    119	vfree(bounce->ptr);
    120}
    121
    122static int dmirror_fops_open(struct inode *inode, struct file *filp)
    123{
    124	struct cdev *cdev = inode->i_cdev;
    125	struct dmirror *dmirror;
    126	int ret;
    127
    128	/* Mirror this process address space */
    129	dmirror = kzalloc(sizeof(*dmirror), GFP_KERNEL);
    130	if (dmirror == NULL)
    131		return -ENOMEM;
    132
    133	dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
    134	mutex_init(&dmirror->mutex);
    135	xa_init(&dmirror->pt);
    136
    137	ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
    138				0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
    139	if (ret) {
    140		kfree(dmirror);
    141		return ret;
    142	}
    143
    144	filp->private_data = dmirror;
    145	return 0;
    146}
    147
    148static int dmirror_fops_release(struct inode *inode, struct file *filp)
    149{
    150	struct dmirror *dmirror = filp->private_data;
    151
    152	mmu_interval_notifier_remove(&dmirror->notifier);
    153	xa_destroy(&dmirror->pt);
    154	kfree(dmirror);
    155	return 0;
    156}
    157
    158static struct dmirror_device *dmirror_page_to_device(struct page *page)
    159
    160{
    161	return container_of(page->pgmap, struct dmirror_chunk,
    162			    pagemap)->mdevice;
    163}
    164
    165static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
    166{
    167	unsigned long *pfns = range->hmm_pfns;
    168	unsigned long pfn;
    169
    170	for (pfn = (range->start >> PAGE_SHIFT);
    171	     pfn < (range->end >> PAGE_SHIFT);
    172	     pfn++, pfns++) {
    173		struct page *page;
    174		void *entry;
    175
    176		/*
    177		 * Since we asked for hmm_range_fault() to populate pages,
    178		 * it shouldn't return an error entry on success.
    179		 */
    180		WARN_ON(*pfns & HMM_PFN_ERROR);
    181		WARN_ON(!(*pfns & HMM_PFN_VALID));
    182
    183		page = hmm_pfn_to_page(*pfns);
    184		WARN_ON(!page);
    185
    186		entry = page;
    187		if (*pfns & HMM_PFN_WRITE)
    188			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
    189		else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
    190			return -EFAULT;
    191		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
    192		if (xa_is_err(entry))
    193			return xa_err(entry);
    194	}
    195
    196	return 0;
    197}
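
Each entry stored above doubles as the mirror's page-table entry: the XArray
is indexed by the virtual PFN and holds the struct page pointer with the write
permission folded into the entry's low tag bits. A minimal sketch of how the
tag helpers used here round-trip:

	/* given some struct page *page: */
	void *entry = xa_tag_pointer(page, DPT_XA_TAG_WRITE);
	struct page *p = xa_untag_pointer(entry);	/* the original page */
	bool writable = xa_pointer_tag(entry) == DPT_XA_TAG_WRITE;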
    198
    199static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
    200			      unsigned long end)
    201{
    202	unsigned long pfn;
    203	void *entry;
    204
    205	/*
    206	 * The XArray doesn't hold references to pages since it relies on
    207	 * the mmu notifier to clear page pointers when they become stale.
    208	 * Therefore, it is OK to just clear the entry.
    209	 */
    210	xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
    211			  end >> PAGE_SHIFT)
    212		xa_erase(&dmirror->pt, pfn);
    213}
    214
    215static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
    216				const struct mmu_notifier_range *range,
    217				unsigned long cur_seq)
    218{
    219	struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);
    220
    221	/*
    222	 * Ignore invalidation callbacks for device private pages since
    223	 * the invalidation is handled as part of the migration process.
    224	 */
    225	if (range->event == MMU_NOTIFY_MIGRATE &&
    226	    range->owner == dmirror->mdevice)
    227		return true;
    228
    229	if (mmu_notifier_range_blockable(range))
    230		mutex_lock(&dmirror->mutex);
    231	else if (!mutex_trylock(&dmirror->mutex))
    232		return false;
    233
    234	mmu_interval_set_seq(mni, cur_seq);
    235	dmirror_do_update(dmirror, range->start, range->end);
    236
    237	mutex_unlock(&dmirror->mutex);
    238	return true;
    239}
    240
    241static const struct mmu_interval_notifier_ops dmirror_min_ops = {
    242	.invalidate = dmirror_interval_invalidate,
    243};
    244
    245static int dmirror_range_fault(struct dmirror *dmirror,
    246				struct hmm_range *range)
    247{
    248	struct mm_struct *mm = dmirror->notifier.mm;
    249	unsigned long timeout =
    250		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
    251	int ret;
    252
    253	while (true) {
    254		if (time_after(jiffies, timeout)) {
    255			ret = -EBUSY;
    256			goto out;
    257		}
    258
    259		range->notifier_seq = mmu_interval_read_begin(range->notifier);
    260		mmap_read_lock(mm);
    261		ret = hmm_range_fault(range);
    262		mmap_read_unlock(mm);
    263		if (ret) {
    264			if (ret == -EBUSY)
    265				continue;
    266			goto out;
    267		}
    268
    269		mutex_lock(&dmirror->mutex);
    270		if (mmu_interval_read_retry(range->notifier,
    271					    range->notifier_seq)) {
    272			mutex_unlock(&dmirror->mutex);
    273			continue;
    274		}
    275		break;
    276	}
    277
    278	ret = dmirror_do_fault(dmirror, range);
    279
    280	mutex_unlock(&dmirror->mutex);
    281out:
    282	return ret;
    283}
    284
    285static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
    286			 unsigned long end, bool write)
    287{
    288	struct mm_struct *mm = dmirror->notifier.mm;
    289	unsigned long addr;
    290	unsigned long pfns[64];
    291	struct hmm_range range = {
    292		.notifier = &dmirror->notifier,
    293		.hmm_pfns = pfns,
    294		.pfn_flags_mask = 0,
    295		.default_flags =
    296			HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
    297		.dev_private_owner = dmirror->mdevice,
    298	};
    299	int ret = 0;
    300
    301	/* Since the mm is for the mirrored process, get a reference first. */
    302	if (!mmget_not_zero(mm))
    303		return 0;
    304
    305	for (addr = start; addr < end; addr = range.end) {
    306		range.start = addr;
    307		range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
    308
    309		ret = dmirror_range_fault(dmirror, &range);
    310		if (ret)
    311			break;
    312	}
    313
    314	mmput(mm);
    315	return ret;
    316}
    317
    318static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
    319			   unsigned long end, struct dmirror_bounce *bounce)
    320{
    321	unsigned long pfn;
    322	void *ptr;
    323
    324	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
    325
    326	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
    327		void *entry;
    328		struct page *page;
    329		void *tmp;
    330
    331		entry = xa_load(&dmirror->pt, pfn);
    332		page = xa_untag_pointer(entry);
    333		if (!page)
    334			return -ENOENT;
    335
    336		tmp = kmap(page);
    337		memcpy(ptr, tmp, PAGE_SIZE);
    338		kunmap(page);
    339
    340		ptr += PAGE_SIZE;
    341		bounce->cpages++;
    342	}
    343
    344	return 0;
    345}
    346
    347static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
    348{
    349	struct dmirror_bounce bounce;
    350	unsigned long start, end;
    351	unsigned long size = cmd->npages << PAGE_SHIFT;
    352	int ret;
    353
    354	start = cmd->addr;
    355	end = start + size;
    356	if (end < start)
    357		return -EINVAL;
    358
    359	ret = dmirror_bounce_init(&bounce, start, size);
    360	if (ret)
    361		return ret;
    362
    363	while (1) {
    364		mutex_lock(&dmirror->mutex);
    365		ret = dmirror_do_read(dmirror, start, end, &bounce);
    366		mutex_unlock(&dmirror->mutex);
    367		if (ret != -ENOENT)
    368			break;
    369
    370		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
    371		ret = dmirror_fault(dmirror, start, end, false);
    372		if (ret)
    373			break;
    374		cmd->faults++;
    375	}
    376
    377	if (ret == 0) {
    378		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
    379				 bounce.size))
    380			ret = -EFAULT;
    381	}
    382	cmd->cpages = bounce.cpages;
    383	dmirror_bounce_fini(&bounce);
    384	return ret;
    385}
    386
    387static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
    388			    unsigned long end, struct dmirror_bounce *bounce)
    389{
    390	unsigned long pfn;
    391	void *ptr;
    392
    393	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
    394
    395	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
    396		void *entry;
    397		struct page *page;
    398		void *tmp;
    399
    400		entry = xa_load(&dmirror->pt, pfn);
    401		page = xa_untag_pointer(entry);
    402		if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
    403			return -ENOENT;
    404
    405		tmp = kmap(page);
    406		memcpy(tmp, ptr, PAGE_SIZE);
    407		kunmap(page);
    408
    409		ptr += PAGE_SIZE;
    410		bounce->cpages++;
    411	}
    412
    413	return 0;
    414}
    415
    416static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
    417{
    418	struct dmirror_bounce bounce;
    419	unsigned long start, end;
    420	unsigned long size = cmd->npages << PAGE_SHIFT;
    421	int ret;
    422
    423	start = cmd->addr;
    424	end = start + size;
    425	if (end < start)
    426		return -EINVAL;
    427
    428	ret = dmirror_bounce_init(&bounce, start, size);
    429	if (ret)
    430		return ret;
    431	if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
    432			   bounce.size)) {
    433		ret = -EFAULT;
    434		goto fini;
    435	}
    436
    437	while (1) {
    438		mutex_lock(&dmirror->mutex);
    439		ret = dmirror_do_write(dmirror, start, end, &bounce);
    440		mutex_unlock(&dmirror->mutex);
    441		if (ret != -ENOENT)
    442			break;
    443
    444		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
    445		ret = dmirror_fault(dmirror, start, end, true);
    446		if (ret)
    447			break;
    448		cmd->faults++;
    449	}
    450
    451fini:
    452	cmd->cpages = bounce.cpages;
    453	dmirror_bounce_fini(&bounce);
    454	return ret;
    455}
    456
    457static bool dmirror_allocate_chunk(struct dmirror_device *mdevice,
    458				   struct page **ppage)
    459{
    460	struct dmirror_chunk *devmem;
    461	struct resource *res;
    462	unsigned long pfn;
    463	unsigned long pfn_first;
    464	unsigned long pfn_last;
    465	void *ptr;
    466
    467	devmem = kzalloc(sizeof(*devmem), GFP_KERNEL);
    468	if (!devmem)
    469		return false;
    470
    471	res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
    472				      "hmm_dmirror");
    473	if (IS_ERR(res))
    474		goto err_devmem;
    475
    476	devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
    477	devmem->pagemap.range.start = res->start;
    478	devmem->pagemap.range.end = res->end;
    479	devmem->pagemap.nr_range = 1;
    480	devmem->pagemap.ops = &dmirror_devmem_ops;
    481	devmem->pagemap.owner = mdevice;
    482
    483	mutex_lock(&mdevice->devmem_lock);
    484
    485	if (mdevice->devmem_count == mdevice->devmem_capacity) {
    486		struct dmirror_chunk **new_chunks;
    487		unsigned int new_capacity;
    488
    489		new_capacity = mdevice->devmem_capacity +
    490				DEVMEM_CHUNKS_RESERVE;
    491		new_chunks = krealloc(mdevice->devmem_chunks,
    492				sizeof(new_chunks[0]) * new_capacity,
    493				GFP_KERNEL);
    494		if (!new_chunks)
    495			goto err_release;
    496		mdevice->devmem_capacity = new_capacity;
    497		mdevice->devmem_chunks = new_chunks;
    498	}
    499
    500	ptr = memremap_pages(&devmem->pagemap, numa_node_id());
    501	if (IS_ERR(ptr))
    502		goto err_release;
    503
    504	devmem->mdevice = mdevice;
    505	pfn_first = devmem->pagemap.range.start >> PAGE_SHIFT;
    506	pfn_last = pfn_first + (range_len(&devmem->pagemap.range) >> PAGE_SHIFT);
    507	mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;
    508
    509	mutex_unlock(&mdevice->devmem_lock);
    510
    511	pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
    512		DEVMEM_CHUNK_SIZE / (1024 * 1024),
    513		mdevice->devmem_count,
    514		mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
    515		pfn_first, pfn_last);
    516
    517	spin_lock(&mdevice->lock);
    518	for (pfn = pfn_first; pfn < pfn_last; pfn++) {
    519		struct page *page = pfn_to_page(pfn);
    520
    521		page->zone_device_data = mdevice->free_pages;
    522		mdevice->free_pages = page;
    523	}
    524	if (ppage) {
    525		*ppage = mdevice->free_pages;
    526		mdevice->free_pages = (*ppage)->zone_device_data;
    527		mdevice->calloc++;
    528	}
    529	spin_unlock(&mdevice->lock);
    530
    531	return true;
    532
    533err_release:
    534	mutex_unlock(&mdevice->devmem_lock);
    535	release_mem_region(devmem->pagemap.range.start, range_len(&devmem->pagemap.range));
    536err_devmem:
    537	kfree(devmem);
    538
    539	return false;
    540}
    541
    542static struct page *dmirror_devmem_alloc_page(struct dmirror_device *mdevice)
    543{
    544	struct page *dpage = NULL;
    545	struct page *rpage;
    546
    547	/*
    548	 * This is a fake device so we alloc real system memory to store
    549	 * our device memory.
    550	 */
    551	rpage = alloc_page(GFP_HIGHUSER);
    552	if (!rpage)
    553		return NULL;
    554
    555	spin_lock(&mdevice->lock);
    556
    557	if (mdevice->free_pages) {
    558		dpage = mdevice->free_pages;
    559		mdevice->free_pages = dpage->zone_device_data;
    560		mdevice->calloc++;
    561		spin_unlock(&mdevice->lock);
    562	} else {
    563		spin_unlock(&mdevice->lock);
    564		if (!dmirror_allocate_chunk(mdevice, &dpage))
    565			goto error;
    566	}
    567
    568	dpage->zone_device_data = rpage;
    569	lock_page(dpage);
    570	return dpage;
    571
    572error:
    573	__free_page(rpage);
    574	return NULL;
    575}
    576
    577static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
    578					   struct dmirror *dmirror)
    579{
    580	struct dmirror_device *mdevice = dmirror->mdevice;
    581	const unsigned long *src = args->src;
    582	unsigned long *dst = args->dst;
    583	unsigned long addr;
    584
    585	for (addr = args->start; addr < args->end; addr += PAGE_SIZE,
    586						   src++, dst++) {
    587		struct page *spage;
    588		struct page *dpage;
    589		struct page *rpage;
    590
    591		if (!(*src & MIGRATE_PFN_MIGRATE))
    592			continue;
    593
    594		/*
    595		 * Note that spage might be NULL which is OK since it is an
    596		 * unallocated pte_none() or read-only zero page.
    597		 */
    598		spage = migrate_pfn_to_page(*src);
    599
    600		dpage = dmirror_devmem_alloc_page(mdevice);
    601		if (!dpage)
    602			continue;
    603
    604		rpage = dpage->zone_device_data;
    605		if (spage)
    606			copy_highpage(rpage, spage);
    607		else
    608			clear_highpage(rpage);
    609
    610		/*
    611		 * Normally, a device would use the page->zone_device_data to
    612		 * point to the mirror but here we use it to hold the page for
    613		 * the simulated device memory and that page holds the pointer
    614		 * to the mirror.
    615		 */
    616		rpage->zone_device_data = dmirror;
    617
    618		*dst = migrate_pfn(page_to_pfn(dpage));
    619		if ((*src & MIGRATE_PFN_WRITE) ||
    620		    (!spage && args->vma->vm_flags & VM_WRITE))
    621			*dst |= MIGRATE_PFN_WRITE;
    622	}
    623}
    624
    625static int dmirror_check_atomic(struct dmirror *dmirror, unsigned long start,
    626			     unsigned long end)
    627{
    628	unsigned long pfn;
    629
    630	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
    631		void *entry;
    632
    633		entry = xa_load(&dmirror->pt, pfn);
    634		if (xa_pointer_tag(entry) == DPT_XA_TAG_ATOMIC)
    635			return -EPERM;
    636	}
    637
    638	return 0;
    639}
    640
    641static int dmirror_atomic_map(unsigned long start, unsigned long end,
    642			      struct page **pages, struct dmirror *dmirror)
    643{
    644	unsigned long pfn, mapped = 0;
    645	int i;
    646
    647	/* Map the migrated pages into the device's page tables. */
    648	mutex_lock(&dmirror->mutex);
    649
    650	for (i = 0, pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++, i++) {
    651		void *entry;
    652
    653		if (!pages[i])
    654			continue;
    655
    656		entry = pages[i];
    657		entry = xa_tag_pointer(entry, DPT_XA_TAG_ATOMIC);
    658		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
    659		if (xa_is_err(entry)) {
    660			mutex_unlock(&dmirror->mutex);
    661			return xa_err(entry);
    662		}
    663
    664		mapped++;
    665	}
    666
    667	mutex_unlock(&dmirror->mutex);
    668	return mapped;
    669}
    670
    671static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
    672					    struct dmirror *dmirror)
    673{
    674	unsigned long start = args->start;
    675	unsigned long end = args->end;
    676	const unsigned long *src = args->src;
    677	const unsigned long *dst = args->dst;
    678	unsigned long pfn;
    679
    680	/* Map the migrated pages into the device's page tables. */
    681	mutex_lock(&dmirror->mutex);
    682
    683	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++,
    684								src++, dst++) {
    685		struct page *dpage;
    686		void *entry;
    687
    688		if (!(*src & MIGRATE_PFN_MIGRATE))
    689			continue;
    690
    691		dpage = migrate_pfn_to_page(*dst);
    692		if (!dpage)
    693			continue;
    694
    695		/*
    696		 * Store the page that holds the data so the page table
    697		 * doesn't have to deal with ZONE_DEVICE private pages.
    698		 */
    699		entry = dpage->zone_device_data;
    700		if (*dst & MIGRATE_PFN_WRITE)
    701			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
    702		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
    703		if (xa_is_err(entry)) {
    704			mutex_unlock(&dmirror->mutex);
    705			return xa_err(entry);
    706		}
    707	}
    708
    709	mutex_unlock(&dmirror->mutex);
    710	return 0;
    711}
    712
    713static int dmirror_exclusive(struct dmirror *dmirror,
    714			     struct hmm_dmirror_cmd *cmd)
    715{
    716	unsigned long start, end, addr;
    717	unsigned long size = cmd->npages << PAGE_SHIFT;
    718	struct mm_struct *mm = dmirror->notifier.mm;
    719	struct page *pages[64];
    720	struct dmirror_bounce bounce;
    721	unsigned long next;
    722	int ret;
    723
    724	start = cmd->addr;
    725	end = start + size;
    726	if (end < start)
    727		return -EINVAL;
    728
    729	/* Since the mm is for the mirrored process, get a reference first. */
    730	if (!mmget_not_zero(mm))
    731		return -EINVAL;
    732
    733	mmap_read_lock(mm);
    734	for (addr = start; addr < end; addr = next) {
    735		unsigned long mapped;
    736		int i;
    737
    738		if (end < addr + (ARRAY_SIZE(pages) << PAGE_SHIFT))
    739			next = end;
    740		else
    741			next = addr + (ARRAY_SIZE(pages) << PAGE_SHIFT);
    742
    743		ret = make_device_exclusive_range(mm, addr, next, pages, NULL);
    744		mapped = dmirror_atomic_map(addr, next, pages, dmirror);
    745		for (i = 0; i < ret; i++) {
    746			if (pages[i]) {
    747				unlock_page(pages[i]);
    748				put_page(pages[i]);
    749			}
    750		}
    751
    752		if (addr + (mapped << PAGE_SHIFT) < next) {
    753			mmap_read_unlock(mm);
    754			mmput(mm);
    755			return -EBUSY;
    756		}
    757	}
    758	mmap_read_unlock(mm);
    759	mmput(mm);
    760
    761	/* Return the migrated data for verification. */
    762	ret = dmirror_bounce_init(&bounce, start, size);
    763	if (ret)
    764		return ret;
    765	mutex_lock(&dmirror->mutex);
    766	ret = dmirror_do_read(dmirror, start, end, &bounce);
    767	mutex_unlock(&dmirror->mutex);
    768	if (ret == 0) {
    769		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
    770				 bounce.size))
    771			ret = -EFAULT;
    772	}
    773
    774	cmd->cpages = bounce.cpages;
    775	dmirror_bounce_fini(&bounce);
    776	return ret;
    777}
    778
    779static int dmirror_migrate(struct dmirror *dmirror,
    780			   struct hmm_dmirror_cmd *cmd)
    781{
    782	unsigned long start, end, addr;
    783	unsigned long size = cmd->npages << PAGE_SHIFT;
    784	struct mm_struct *mm = dmirror->notifier.mm;
    785	struct vm_area_struct *vma;
    786	unsigned long src_pfns[64];
    787	unsigned long dst_pfns[64];
    788	struct dmirror_bounce bounce;
    789	struct migrate_vma args;
    790	unsigned long next;
    791	int ret;
    792
    793	start = cmd->addr;
    794	end = start + size;
    795	if (end < start)
    796		return -EINVAL;
    797
    798	/* Since the mm is for the mirrored process, get a reference first. */
    799	if (!mmget_not_zero(mm))
    800		return -EINVAL;
    801
    802	mmap_read_lock(mm);
    803	for (addr = start; addr < end; addr = next) {
    804		vma = vma_lookup(mm, addr);
    805		if (!vma || !(vma->vm_flags & VM_READ)) {
    806			ret = -EINVAL;
    807			goto out;
    808		}
    809		next = min(end, addr + (ARRAY_SIZE(src_pfns) << PAGE_SHIFT));
    810		if (next > vma->vm_end)
    811			next = vma->vm_end;
    812
    813		args.vma = vma;
    814		args.src = src_pfns;
    815		args.dst = dst_pfns;
    816		args.start = addr;
    817		args.end = next;
    818		args.pgmap_owner = dmirror->mdevice;
    819		args.flags = MIGRATE_VMA_SELECT_SYSTEM;
    820		ret = migrate_vma_setup(&args);
    821		if (ret)
    822			goto out;
    823
    824		dmirror_migrate_alloc_and_copy(&args, dmirror);
    825		migrate_vma_pages(&args);
    826		dmirror_migrate_finalize_and_map(&args, dmirror);
    827		migrate_vma_finalize(&args);
    828	}
    829	mmap_read_unlock(mm);
    830	mmput(mm);
    831
    832	/* Return the migrated data for verification. */
    833	ret = dmirror_bounce_init(&bounce, start, size);
    834	if (ret)
    835		return ret;
    836	mutex_lock(&dmirror->mutex);
    837	ret = dmirror_do_read(dmirror, start, end, &bounce);
    838	mutex_unlock(&dmirror->mutex);
    839	if (ret == 0) {
    840		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
    841				 bounce.size))
    842			ret = -EFAULT;
    843	}
    844	cmd->cpages = bounce.cpages;
    845	dmirror_bounce_fini(&bounce);
    846	return ret;
    847
    848out:
    849	mmap_read_unlock(mm);
    850	mmput(mm);
    851	return ret;
    852}
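
From userspace, HMM_DMIRROR_MIGRATE moves a range of anonymous memory into the
simulated device memory and hands back a copy of the migrated contents through
cmd->ptr for verification. A short sketch, reusing the hypothetical fd, src and
dst from the example near the top of the file:

	cmd.addr = (uintptr_t)src;	/* page-aligned range to migrate */
	cmd.ptr = (uintptr_t)dst;	/* receives a copy of the migrated data */
	cmd.npages = 1;
	/* on success, src is now backed by ZONE_DEVICE private pages and a
	 * later CPU touch of src triggers dmirror_devmem_fault() below */
	int migrated = ioctl(fd, HMM_DMIRROR_MIGRATE, &cmd) == 0;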
    853
    854static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
    855			    unsigned char *perm, unsigned long entry)
    856{
    857	struct page *page;
    858
    859	if (entry & HMM_PFN_ERROR) {
    860		*perm = HMM_DMIRROR_PROT_ERROR;
    861		return;
    862	}
    863	if (!(entry & HMM_PFN_VALID)) {
    864		*perm = HMM_DMIRROR_PROT_NONE;
    865		return;
    866	}
    867
    868	page = hmm_pfn_to_page(entry);
    869	if (is_device_private_page(page)) {
    870		/* Is the page migrated to this device or some other? */
    871		if (dmirror->mdevice == dmirror_page_to_device(page))
    872			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
    873		else
    874			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
    875	} else if (is_zero_pfn(page_to_pfn(page)))
    876		*perm = HMM_DMIRROR_PROT_ZERO;
    877	else
    878		*perm = HMM_DMIRROR_PROT_NONE;
    879	if (entry & HMM_PFN_WRITE)
    880		*perm |= HMM_DMIRROR_PROT_WRITE;
    881	else
    882		*perm |= HMM_DMIRROR_PROT_READ;
    883	if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PMD_SHIFT)
    884		*perm |= HMM_DMIRROR_PROT_PMD;
    885	else if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PUD_SHIFT)
    886		*perm |= HMM_DMIRROR_PROT_PUD;
    887}
    888
    889static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
    890				const struct mmu_notifier_range *range,
    891				unsigned long cur_seq)
    892{
    893	struct dmirror_interval *dmi =
    894		container_of(mni, struct dmirror_interval, notifier);
    895	struct dmirror *dmirror = dmi->dmirror;
    896
    897	if (mmu_notifier_range_blockable(range))
    898		mutex_lock(&dmirror->mutex);
    899	else if (!mutex_trylock(&dmirror->mutex))
    900		return false;
    901
    902	/*
    903	 * Snapshots only need to set the sequence number since any
    904	 * invalidation in the interval invalidates the whole snapshot.
    905	 */
    906	mmu_interval_set_seq(mni, cur_seq);
    907
    908	mutex_unlock(&dmirror->mutex);
    909	return true;
    910}
    911
    912static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
    913	.invalidate = dmirror_snapshot_invalidate,
    914};
    915
    916static int dmirror_range_snapshot(struct dmirror *dmirror,
    917				  struct hmm_range *range,
    918				  unsigned char *perm)
    919{
    920	struct mm_struct *mm = dmirror->notifier.mm;
    921	struct dmirror_interval notifier;
    922	unsigned long timeout =
    923		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
    924	unsigned long i;
    925	unsigned long n;
    926	int ret = 0;
    927
    928	notifier.dmirror = dmirror;
    929	range->notifier = &notifier.notifier;
    930
    931	ret = mmu_interval_notifier_insert(range->notifier, mm,
    932			range->start, range->end - range->start,
    933			&dmirror_mrn_ops);
    934	if (ret)
    935		return ret;
    936
    937	while (true) {
    938		if (time_after(jiffies, timeout)) {
    939			ret = -EBUSY;
    940			goto out;
    941		}
    942
    943		range->notifier_seq = mmu_interval_read_begin(range->notifier);
    944
    945		mmap_read_lock(mm);
    946		ret = hmm_range_fault(range);
    947		mmap_read_unlock(mm);
    948		if (ret) {
    949			if (ret == -EBUSY)
    950				continue;
    951			goto out;
    952		}
    953
    954		mutex_lock(&dmirror->mutex);
    955		if (mmu_interval_read_retry(range->notifier,
    956					    range->notifier_seq)) {
    957			mutex_unlock(&dmirror->mutex);
    958			continue;
    959		}
    960		break;
    961	}
    962
    963	n = (range->end - range->start) >> PAGE_SHIFT;
    964	for (i = 0; i < n; i++)
    965		dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);
    966
    967	mutex_unlock(&dmirror->mutex);
    968out:
    969	mmu_interval_notifier_remove(range->notifier);
    970	return ret;
    971}
    972
    973static int dmirror_snapshot(struct dmirror *dmirror,
    974			    struct hmm_dmirror_cmd *cmd)
    975{
    976	struct mm_struct *mm = dmirror->notifier.mm;
    977	unsigned long start, end;
    978	unsigned long size = cmd->npages << PAGE_SHIFT;
    979	unsigned long addr;
    980	unsigned long next;
    981	unsigned long pfns[64];
    982	unsigned char perm[64];
    983	char __user *uptr;
    984	struct hmm_range range = {
    985		.hmm_pfns = pfns,
    986		.dev_private_owner = dmirror->mdevice,
    987	};
    988	int ret = 0;
    989
    990	start = cmd->addr;
    991	end = start + size;
    992	if (end < start)
    993		return -EINVAL;
    994
    995	/* Since the mm is for the mirrored process, get a reference first. */
    996	if (!mmget_not_zero(mm))
    997		return -EINVAL;
    998
    999	/*
   1000	 * Register a temporary notifier to detect invalidations even if it
   1001	 * overlaps with other mmu_interval_notifiers.
   1002	 */
   1003	uptr = u64_to_user_ptr(cmd->ptr);
   1004	for (addr = start; addr < end; addr = next) {
   1005		unsigned long n;
   1006
   1007		next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
   1008		range.start = addr;
   1009		range.end = next;
   1010
   1011		ret = dmirror_range_snapshot(dmirror, &range, perm);
   1012		if (ret)
   1013			break;
   1014
   1015		n = (range.end - range.start) >> PAGE_SHIFT;
   1016		if (copy_to_user(uptr, perm, n)) {
   1017			ret = -EFAULT;
   1018			break;
   1019		}
   1020
   1021		cmd->cpages += n;
   1022		uptr += n;
   1023	}
   1024	mmput(mm);
   1025
   1026	return ret;
   1027}
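
From userspace, HMM_DMIRROR_SNAPSHOT fills one permission byte per page using
the HMM_DMIRROR_PROT_* values from test_hmm_uapi.h. A sketch of decoding it,
again with the hypothetical fd and src from the earlier example:

	unsigned char perm[1];
	struct hmm_dmirror_cmd cmd = {
		.addr = (uintptr_t)src,		/* page-aligned address to inspect */
		.ptr = (uintptr_t)perm,		/* one permission byte per page */
		.npages = 1,
	};
	int writable = ioctl(fd, HMM_DMIRROR_SNAPSHOT, &cmd) == 0 &&
		       (perm[0] & HMM_DMIRROR_PROT_WRITE);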
   1028
   1029static long dmirror_fops_unlocked_ioctl(struct file *filp,
   1030					unsigned int command,
   1031					unsigned long arg)
   1032{
   1033	void __user *uarg = (void __user *)arg;
   1034	struct hmm_dmirror_cmd cmd;
   1035	struct dmirror *dmirror;
   1036	int ret;
   1037
   1038	dmirror = filp->private_data;
   1039	if (!dmirror)
   1040		return -EINVAL;
   1041
   1042	if (copy_from_user(&cmd, uarg, sizeof(cmd)))
   1043		return -EFAULT;
   1044
   1045	if (cmd.addr & ~PAGE_MASK)
   1046		return -EINVAL;
   1047	if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
   1048		return -EINVAL;
   1049
   1050	cmd.cpages = 0;
   1051	cmd.faults = 0;
   1052
   1053	switch (command) {
   1054	case HMM_DMIRROR_READ:
   1055		ret = dmirror_read(dmirror, &cmd);
   1056		break;
   1057
   1058	case HMM_DMIRROR_WRITE:
   1059		ret = dmirror_write(dmirror, &cmd);
   1060		break;
   1061
   1062	case HMM_DMIRROR_MIGRATE:
   1063		ret = dmirror_migrate(dmirror, &cmd);
   1064		break;
   1065
   1066	case HMM_DMIRROR_EXCLUSIVE:
   1067		ret = dmirror_exclusive(dmirror, &cmd);
   1068		break;
   1069
   1070	case HMM_DMIRROR_CHECK_EXCLUSIVE:
   1071		ret = dmirror_check_atomic(dmirror, cmd.addr,
   1072					cmd.addr + (cmd.npages << PAGE_SHIFT));
   1073		break;
   1074
   1075	case HMM_DMIRROR_SNAPSHOT:
   1076		ret = dmirror_snapshot(dmirror, &cmd);
   1077		break;
   1078
   1079	default:
   1080		return -EINVAL;
   1081	}
   1082	if (ret)
   1083		return ret;
   1084
   1085	if (copy_to_user(uarg, &cmd, sizeof(cmd)))
   1086		return -EFAULT;
   1087
   1088	return 0;
   1089}
   1090
   1091static int dmirror_fops_mmap(struct file *file, struct vm_area_struct *vma)
   1092{
   1093	unsigned long addr;
   1094
   1095	for (addr = vma->vm_start; addr < vma->vm_end; addr += PAGE_SIZE) {
   1096		struct page *page;
   1097		int ret;
   1098
   1099		page = alloc_page(GFP_KERNEL | __GFP_ZERO);
   1100		if (!page)
   1101			return -ENOMEM;
   1102
   1103		ret = vm_insert_page(vma, addr, page);
   1104		if (ret) {
   1105			__free_page(page);
   1106			return ret;
   1107		}
   1108		put_page(page);
   1109	}
   1110
   1111	return 0;
   1112}
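
The mmap handler above pre-populates the whole VMA with zeroed kernel pages,
so a caller sees a ready-to-use shared buffer. A one-line sketch of mapping it
(needs <sys/mman.h>; fd and psz as in the earlier example):

	void *win = mmap(NULL, 16 * psz, PROT_READ | PROT_WRITE,
			 MAP_SHARED, fd, 0);	/* each page already zero-filled */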
   1113
   1114static const struct file_operations dmirror_fops = {
   1115	.open		= dmirror_fops_open,
   1116	.release	= dmirror_fops_release,
   1117	.mmap		= dmirror_fops_mmap,
   1118	.unlocked_ioctl = dmirror_fops_unlocked_ioctl,
   1119	.llseek		= default_llseek,
   1120	.owner		= THIS_MODULE,
   1121};
   1122
   1123static void dmirror_devmem_free(struct page *page)
   1124{
   1125	struct page *rpage = page->zone_device_data;
   1126	struct dmirror_device *mdevice;
   1127
   1128	if (rpage)
   1129		__free_page(rpage);
   1130
   1131	mdevice = dmirror_page_to_device(page);
   1132
   1133	spin_lock(&mdevice->lock);
   1134	mdevice->cfree++;
   1135	page->zone_device_data = mdevice->free_pages;
   1136	mdevice->free_pages = page;
   1137	spin_unlock(&mdevice->lock);
   1138}
   1139
   1140static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
   1141						      struct dmirror *dmirror)
   1142{
   1143	const unsigned long *src = args->src;
   1144	unsigned long *dst = args->dst;
   1145	unsigned long start = args->start;
   1146	unsigned long end = args->end;
   1147	unsigned long addr;
   1148
   1149	for (addr = start; addr < end; addr += PAGE_SIZE,
   1150				       src++, dst++) {
   1151		struct page *dpage, *spage;
   1152
   1153		spage = migrate_pfn_to_page(*src);
   1154		if (!spage || !(*src & MIGRATE_PFN_MIGRATE))
   1155			continue;
   1156		spage = spage->zone_device_data;
   1157
   1158		dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
   1159		if (!dpage)
   1160			continue;
   1161
   1162		lock_page(dpage);
   1163		xa_erase(&dmirror->pt, addr >> PAGE_SHIFT);
   1164		copy_highpage(dpage, spage);
   1165		*dst = migrate_pfn(page_to_pfn(dpage));
   1166		if (*src & MIGRATE_PFN_WRITE)
   1167			*dst |= MIGRATE_PFN_WRITE;
   1168	}
   1169	return 0;
   1170}
   1171
   1172static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
   1173{
   1174	struct migrate_vma args;
   1175	unsigned long src_pfns;
   1176	unsigned long dst_pfns;
   1177	struct page *rpage;
   1178	struct dmirror *dmirror;
   1179	vm_fault_t ret;
   1180
   1181	/*
   1182	 * Normally, a device would use the page->zone_device_data to point to
   1183	 * the mirror but here we use it to hold the page for the simulated
   1184	 * device memory and that page holds the pointer to the mirror.
   1185	 */
   1186	rpage = vmf->page->zone_device_data;
   1187	dmirror = rpage->zone_device_data;
   1188
   1189	/* FIXME demonstrate how we can adjust migrate range */
   1190	args.vma = vmf->vma;
   1191	args.start = vmf->address;
   1192	args.end = args.start + PAGE_SIZE;
   1193	args.src = &src_pfns;
   1194	args.dst = &dst_pfns;
   1195	args.pgmap_owner = dmirror->mdevice;
   1196	args.flags = MIGRATE_VMA_SELECT_DEVICE_PRIVATE;
   1197
   1198	if (migrate_vma_setup(&args))
   1199		return VM_FAULT_SIGBUS;
   1200
   1201	ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
   1202	if (ret)
   1203		return ret;
   1204	migrate_vma_pages(&args);
   1205	/*
   1206	 * No device finalize step is needed since
   1207	 * dmirror_devmem_fault_alloc_and_copy() will have already
   1208	 * invalidated the device page table.
   1209	 */
   1210	migrate_vma_finalize(&args);
   1211	return 0;
   1212}
   1213
   1214static const struct dev_pagemap_ops dmirror_devmem_ops = {
   1215	.page_free	= dmirror_devmem_free,
   1216	.migrate_to_ram	= dmirror_devmem_fault,
   1217};
   1218
   1219static int dmirror_device_init(struct dmirror_device *mdevice, int id)
   1220{
   1221	dev_t dev;
   1222	int ret;
   1223
   1224	dev = MKDEV(MAJOR(dmirror_dev), id);
   1225	mutex_init(&mdevice->devmem_lock);
   1226	spin_lock_init(&mdevice->lock);
   1227
   1228	cdev_init(&mdevice->cdevice, &dmirror_fops);
   1229	mdevice->cdevice.owner = THIS_MODULE;
   1230	ret = cdev_add(&mdevice->cdevice, dev, 1);
   1231	if (ret)
   1232		return ret;
   1233
   1234	/* Build a list of free ZONE_DEVICE private struct pages */
   1235	dmirror_allocate_chunk(mdevice, NULL);
   1236
   1237	return 0;
   1238}
   1239
   1240static void dmirror_device_remove(struct dmirror_device *mdevice)
   1241{
   1242	unsigned int i;
   1243
   1244	if (mdevice->devmem_chunks) {
   1245		for (i = 0; i < mdevice->devmem_count; i++) {
   1246			struct dmirror_chunk *devmem =
   1247				mdevice->devmem_chunks[i];
   1248
   1249			memunmap_pages(&devmem->pagemap);
   1250			release_mem_region(devmem->pagemap.range.start,
   1251					   range_len(&devmem->pagemap.range));
   1252			kfree(devmem);
   1253		}
   1254		kfree(mdevice->devmem_chunks);
   1255	}
   1256
   1257	cdev_del(&mdevice->cdevice);
   1258}
   1259
   1260static int __init hmm_dmirror_init(void)
   1261{
   1262	int ret;
   1263	int id;
   1264
   1265	ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
   1266				  "HMM_DMIRROR");
   1267	if (ret)
   1268		goto err_unreg;
   1269
   1270	for (id = 0; id < DMIRROR_NDEVICES; id++) {
   1271		ret = dmirror_device_init(dmirror_devices + id, id);
   1272		if (ret)
   1273			goto err_chrdev;
   1274	}
   1275
   1276	pr_info("HMM test module loaded. This is only for testing HMM.\n");
   1277	return 0;
   1278
   1279err_chrdev:
   1280	while (--id >= 0)
   1281		dmirror_device_remove(dmirror_devices + id);
   1282	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
   1283err_unreg:
   1284	return ret;
   1285}
   1286
   1287static void __exit hmm_dmirror_exit(void)
   1288{
   1289	int id;
   1290
   1291	for (id = 0; id < DMIRROR_NDEVICES; id++)
   1292		dmirror_device_remove(dmirror_devices + id);
   1293	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
   1294}
   1295
   1296module_init(hmm_dmirror_init);
   1297module_exit(hmm_dmirror_exit);
   1298MODULE_LICENSE("GPL");