cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

mr.c (21356B)


      1// SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause
      2/*
      3 * Copyright(c) 2016 Intel Corporation.
      4 */
      5
      6#include <linux/slab.h>
      7#include <linux/vmalloc.h>
      8#include <rdma/ib_umem.h>
      9#include <rdma/rdma_vt.h>
     10#include "vt.h"
     11#include "mr.h"
     12#include "trace.h"
     13
     14/**
     15 * rvt_driver_mr_init - Init MR resources per driver
     16 * @rdi: rvt dev struct
     17 *
      18 * Do any initialization needed when a driver registers with rdmavt.
     19 *
     20 * Return: 0 on success or errno on failure
     21 */
     22int rvt_driver_mr_init(struct rvt_dev_info *rdi)
     23{
     24	unsigned int lkey_table_size = rdi->dparms.lkey_table_size;
     25	unsigned lk_tab_size;
     26	int i;
     27
     28	/*
     29	 * The top hfi1_lkey_table_size bits are used to index the
     30	 * table.  The lower 8 bits can be owned by the user (copied from
     31	 * the LKEY).  The remaining bits act as a generation number or tag.
     32	 */
     33	if (!lkey_table_size)
     34		return -EINVAL;
     35
     36	spin_lock_init(&rdi->lkey_table.lock);
     37
     38	/* ensure generation is at least 4 bits */
     39	if (lkey_table_size > RVT_MAX_LKEY_TABLE_BITS) {
     40		rvt_pr_warn(rdi, "lkey bits %u too large, reduced to %u\n",
     41			    lkey_table_size, RVT_MAX_LKEY_TABLE_BITS);
     42		rdi->dparms.lkey_table_size = RVT_MAX_LKEY_TABLE_BITS;
     43		lkey_table_size = rdi->dparms.lkey_table_size;
     44	}
     45	rdi->lkey_table.max = 1 << lkey_table_size;
     46	rdi->lkey_table.shift = 32 - lkey_table_size;
     47	lk_tab_size = rdi->lkey_table.max * sizeof(*rdi->lkey_table.table);
     48	rdi->lkey_table.table = (struct rvt_mregion __rcu **)
     49			       vmalloc_node(lk_tab_size, rdi->dparms.node);
     50	if (!rdi->lkey_table.table)
     51		return -ENOMEM;
     52
     53	RCU_INIT_POINTER(rdi->dma_mr, NULL);
     54	for (i = 0; i < rdi->lkey_table.max; i++)
     55		RCU_INIT_POINTER(rdi->lkey_table.table[i], NULL);
     56
     57	rdi->dparms.props.max_mr = rdi->lkey_table.max;
     58	return 0;
     59}
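
For illustration, a minimal stand-alone sketch (plain C, not part of the driver) of the sizing math above, assuming a hypothetical lkey_table_size of 16 bits: max is the number of table slots and shift recovers a slot index from the top bits of an lkey.

#include <assert.h>
#include <stdint.h>

/* Illustrative only: mirrors the max/shift computation in rvt_driver_mr_init().
 * lkey_table_size = 16 is an assumed example value, not a driver default. */
int main(void)
{
	unsigned int lkey_table_size = 16;
	uint32_t max = 1u << lkey_table_size;	/* 65536 table slots */
	uint32_t shift = 32 - lkey_table_size;	/* slot index lives in the top 16 bits */
	uint32_t lkey = 0x00052a00;		/* example lkey */

	assert((lkey >> shift) == 0x0005);	/* top bits select the slot */
	assert((lkey >> shift) < max);
	return 0;
}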
     60
     61/**
     62 * rvt_mr_exit - clean up MR
     63 * @rdi: rvt dev structure
     64 *
      65 * Called when a driver has unregistered or perhaps failed to register with us.
     66 */
     67void rvt_mr_exit(struct rvt_dev_info *rdi)
     68{
     69	if (rdi->dma_mr)
     70		rvt_pr_err(rdi, "DMA MR not null!\n");
     71
     72	vfree(rdi->lkey_table.table);
     73}
     74
     75static void rvt_deinit_mregion(struct rvt_mregion *mr)
     76{
     77	int i = mr->mapsz;
     78
     79	mr->mapsz = 0;
     80	while (i)
     81		kfree(mr->map[--i]);
     82	percpu_ref_exit(&mr->refcount);
     83}
     84
     85static void __rvt_mregion_complete(struct percpu_ref *ref)
     86{
     87	struct rvt_mregion *mr = container_of(ref, struct rvt_mregion,
     88					      refcount);
     89
     90	complete(&mr->comp);
     91}
     92
     93static int rvt_init_mregion(struct rvt_mregion *mr, struct ib_pd *pd,
     94			    int count, unsigned int percpu_flags)
     95{
     96	int m, i = 0;
     97	struct rvt_dev_info *dev = ib_to_rvt(pd->device);
     98
     99	mr->mapsz = 0;
    100	m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
    101	for (; i < m; i++) {
    102		mr->map[i] = kzalloc_node(sizeof(*mr->map[0]), GFP_KERNEL,
    103					  dev->dparms.node);
    104		if (!mr->map[i])
    105			goto bail;
    106		mr->mapsz++;
    107	}
    108	init_completion(&mr->comp);
     109	/* the initial reference counts the ptr returned to the user */
    110	if (percpu_ref_init(&mr->refcount, &__rvt_mregion_complete,
    111			    percpu_flags, GFP_KERNEL))
    112		goto bail;
    113
    114	atomic_set(&mr->lkey_invalid, 0);
    115	mr->pd = pd;
    116	mr->max_segs = count;
    117	return 0;
    118bail:
    119	rvt_deinit_mregion(mr);
    120	return -ENOMEM;
    121}
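
The map allocation above rounds the requested segment count up to whole chunks of RVT_SEGSZ segments. A stand-alone sketch of that arithmetic, with SEGSZ standing in for RVT_SEGSZ and an assumed value of 32: 70 segments need 3 chunks.

#include <assert.h>

#define SEGSZ 32	/* assumed stand-in for RVT_SEGSZ */

int main(void)
{
	int count = 70;				/* segments requested */
	int m = (count + SEGSZ - 1) / SEGSZ;	/* chunks to allocate */

	assert(m == 3);				/* 32 + 32 + 6 segments */
	return 0;
}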
    122
    123/**
    124 * rvt_alloc_lkey - allocate an lkey
    125 * @mr: memory region that this lkey protects
    126 * @dma_region: 0->normal key, 1->restricted DMA key
    127 *
    128 * Returns 0 if successful, otherwise returns -errno.
    129 *
    130 * Increments mr reference count as required.
    131 *
     132 * Sets the lkey field of mr for non-dma regions.
    133 *
    134 */
    135static int rvt_alloc_lkey(struct rvt_mregion *mr, int dma_region)
    136{
    137	unsigned long flags;
    138	u32 r;
    139	u32 n;
    140	int ret = 0;
    141	struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
    142	struct rvt_lkey_table *rkt = &dev->lkey_table;
    143
    144	rvt_get_mr(mr);
    145	spin_lock_irqsave(&rkt->lock, flags);
    146
    147	/* special case for dma_mr lkey == 0 */
    148	if (dma_region) {
    149		struct rvt_mregion *tmr;
    150
    151		tmr = rcu_access_pointer(dev->dma_mr);
    152		if (!tmr) {
    153			mr->lkey_published = 1;
     154			/* Ensure published is written first */
    155			rcu_assign_pointer(dev->dma_mr, mr);
    156			rvt_get_mr(mr);
    157		}
    158		goto success;
    159	}
    160
    161	/* Find the next available LKEY */
    162	r = rkt->next;
    163	n = r;
    164	for (;;) {
    165		if (!rcu_access_pointer(rkt->table[r]))
    166			break;
    167		r = (r + 1) & (rkt->max - 1);
    168		if (r == n)
    169			goto bail;
    170	}
    171	rkt->next = (r + 1) & (rkt->max - 1);
    172	/*
     173	 * Make sure lkey is never zero, which is reserved to indicate an
    174	 * unrestricted LKEY.
    175	 */
    176	rkt->gen++;
    177	/*
    178	 * bits are capped to ensure enough bits for generation number
    179	 */
    180	mr->lkey = (r << (32 - dev->dparms.lkey_table_size)) |
    181		((((1 << (24 - dev->dparms.lkey_table_size)) - 1) & rkt->gen)
    182		 << 8);
    183	if (mr->lkey == 0) {
    184		mr->lkey |= 1 << 8;
    185		rkt->gen++;
    186	}
    187	mr->lkey_published = 1;
     188	/* Ensure published is written first */
    189	rcu_assign_pointer(rkt->table[r], mr);
    190success:
    191	spin_unlock_irqrestore(&rkt->lock, flags);
    192out:
    193	return ret;
    194bail:
    195	rvt_put_mr(mr);
    196	spin_unlock_irqrestore(&rkt->lock, flags);
    197	ret = -ENOMEM;
    198	goto out;
    199}
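
As a worked example of the packing above (all values assumed for illustration, with lkey_table_size = 16): table slot r = 5 lands in the top 16 bits, generation 0x2a in bits 15..8, and the low 8 consumer-owned bits stay clear, giving an lkey of 0x00052a00.

#include <stdint.h>
#include <stdio.h>

/* Stand-alone sketch of the lkey packing done in rvt_alloc_lkey();
 * lkey_table_size, r and gen are assumed example values. */
int main(void)
{
	unsigned int lkey_table_size = 16;
	uint32_t r = 5;		/* free slot found by the search loop */
	uint32_t gen = 0x2a;	/* running generation counter */
	uint32_t lkey;

	lkey = (r << (32 - lkey_table_size)) |
	       ((((1u << (24 - lkey_table_size)) - 1) & gen) << 8);
	printf("lkey = 0x%08x\n", lkey);	/* prints lkey = 0x00052a00 */
	return 0;
}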
    200
    201/**
    202 * rvt_free_lkey - free an lkey
    203 * @mr: mr to free from tables
    204 */
    205static void rvt_free_lkey(struct rvt_mregion *mr)
    206{
    207	unsigned long flags;
    208	u32 lkey = mr->lkey;
    209	u32 r;
    210	struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
    211	struct rvt_lkey_table *rkt = &dev->lkey_table;
    212	int freed = 0;
    213
    214	spin_lock_irqsave(&rkt->lock, flags);
    215	if (!lkey) {
    216		if (mr->lkey_published) {
    217			mr->lkey_published = 0;
     218			/* ensure published is written before pointer */
    219			rcu_assign_pointer(dev->dma_mr, NULL);
    220			rvt_put_mr(mr);
    221		}
    222	} else {
    223		if (!mr->lkey_published)
    224			goto out;
    225		r = lkey >> (32 - dev->dparms.lkey_table_size);
    226		mr->lkey_published = 0;
     227		/* ensure published is written before pointer */
    228		rcu_assign_pointer(rkt->table[r], NULL);
    229	}
    230	freed++;
    231out:
    232	spin_unlock_irqrestore(&rkt->lock, flags);
    233	if (freed)
    234		percpu_ref_kill(&mr->refcount);
    235}
    236
    237static struct rvt_mr *__rvt_alloc_mr(int count, struct ib_pd *pd)
    238{
    239	struct rvt_mr *mr;
    240	int rval = -ENOMEM;
    241	int m;
    242
    243	/* Allocate struct plus pointers to first level page tables. */
    244	m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
    245	mr = kzalloc(struct_size(mr, mr.map, m), GFP_KERNEL);
    246	if (!mr)
    247		goto bail;
    248
    249	rval = rvt_init_mregion(&mr->mr, pd, count, 0);
    250	if (rval)
    251		goto bail;
    252	/*
    253	 * ib_reg_phys_mr() will initialize mr->ibmr except for
    254	 * lkey and rkey.
    255	 */
    256	rval = rvt_alloc_lkey(&mr->mr, 0);
    257	if (rval)
    258		goto bail_mregion;
    259	mr->ibmr.lkey = mr->mr.lkey;
    260	mr->ibmr.rkey = mr->mr.lkey;
    261done:
    262	return mr;
    263
    264bail_mregion:
    265	rvt_deinit_mregion(&mr->mr);
    266bail:
    267	kfree(mr);
    268	mr = ERR_PTR(rval);
    269	goto done;
    270}
    271
    272static void __rvt_free_mr(struct rvt_mr *mr)
    273{
    274	rvt_free_lkey(&mr->mr);
    275	rvt_deinit_mregion(&mr->mr);
    276	kfree(mr);
    277}
    278
    279/**
    280 * rvt_get_dma_mr - get a DMA memory region
    281 * @pd: protection domain for this memory region
    282 * @acc: access flags
    283 *
    284 * Return: the memory region on success, otherwise returns an errno.
    285 */
    286struct ib_mr *rvt_get_dma_mr(struct ib_pd *pd, int acc)
    287{
    288	struct rvt_mr *mr;
    289	struct ib_mr *ret;
    290	int rval;
    291
    292	if (ibpd_to_rvtpd(pd)->user)
    293		return ERR_PTR(-EPERM);
    294
    295	mr = kzalloc(sizeof(*mr), GFP_KERNEL);
    296	if (!mr) {
    297		ret = ERR_PTR(-ENOMEM);
    298		goto bail;
    299	}
    300
    301	rval = rvt_init_mregion(&mr->mr, pd, 0, 0);
    302	if (rval) {
    303		ret = ERR_PTR(rval);
    304		goto bail;
    305	}
    306
    307	rval = rvt_alloc_lkey(&mr->mr, 1);
    308	if (rval) {
    309		ret = ERR_PTR(rval);
    310		goto bail_mregion;
    311	}
    312
    313	mr->mr.access_flags = acc;
    314	ret = &mr->ibmr;
    315done:
    316	return ret;
    317
    318bail_mregion:
    319	rvt_deinit_mregion(&mr->mr);
    320bail:
    321	kfree(mr);
    322	goto done;
    323}
    324
    325/**
    326 * rvt_reg_user_mr - register a userspace memory region
    327 * @pd: protection domain for this memory region
    328 * @start: starting userspace address
    329 * @length: length of region to register
    330 * @virt_addr: associated virtual address
    331 * @mr_access_flags: access flags for this memory region
    332 * @udata: unused by the driver
    333 *
    334 * Return: the memory region on success, otherwise returns an errno.
    335 */
    336struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
    337			      u64 virt_addr, int mr_access_flags,
    338			      struct ib_udata *udata)
    339{
    340	struct rvt_mr *mr;
    341	struct ib_umem *umem;
    342	struct sg_page_iter sg_iter;
    343	int n, m;
    344	struct ib_mr *ret;
    345
    346	if (length == 0)
    347		return ERR_PTR(-EINVAL);
    348
    349	umem = ib_umem_get(pd->device, start, length, mr_access_flags);
    350	if (IS_ERR(umem))
    351		return (void *)umem;
    352
    353	n = ib_umem_num_pages(umem);
    354
    355	mr = __rvt_alloc_mr(n, pd);
    356	if (IS_ERR(mr)) {
    357		ret = (struct ib_mr *)mr;
    358		goto bail_umem;
    359	}
    360
    361	mr->mr.user_base = start;
    362	mr->mr.iova = virt_addr;
    363	mr->mr.length = length;
    364	mr->mr.offset = ib_umem_offset(umem);
    365	mr->mr.access_flags = mr_access_flags;
    366	mr->umem = umem;
    367
    368	mr->mr.page_shift = PAGE_SHIFT;
    369	m = 0;
    370	n = 0;
    371	for_each_sgtable_page (&umem->sgt_append.sgt, &sg_iter, 0) {
    372		void *vaddr;
    373
    374		vaddr = page_address(sg_page_iter_page(&sg_iter));
    375		if (!vaddr) {
    376			ret = ERR_PTR(-EINVAL);
    377			goto bail_inval;
    378		}
    379		mr->mr.map[m]->segs[n].vaddr = vaddr;
    380		mr->mr.map[m]->segs[n].length = PAGE_SIZE;
    381		trace_rvt_mr_user_seg(&mr->mr, m, n, vaddr, PAGE_SIZE);
    382		if (++n == RVT_SEGSZ) {
    383			m++;
    384			n = 0;
    385		}
    386	}
    387	return &mr->ibmr;
    388
    389bail_inval:
    390	__rvt_free_mr(mr);
    391
    392bail_umem:
    393	ib_umem_release(umem);
    394
    395	return ret;
    396}
    397
    398/**
    399 * rvt_dereg_clean_qp_cb - callback from iterator
    400 * @qp: the qp
    401 * @v: the mregion (as u64)
    402 *
     403 * This routine fields the callback for all QPs and, for QPs in
     404 * the same PD as the MR, calls rvt_qp_mr_clean() to potentially
     405 * clean up references.
    406 */
    407static void rvt_dereg_clean_qp_cb(struct rvt_qp *qp, u64 v)
    408{
    409	struct rvt_mregion *mr = (struct rvt_mregion *)v;
    410
    411	/* skip PDs that are not ours */
    412	if (mr->pd != qp->ibqp.pd)
    413		return;
    414	rvt_qp_mr_clean(qp, mr->lkey);
    415}
    416
    417/**
    418 * rvt_dereg_clean_qps - find QPs for reference cleanup
    419 * @mr: the MR that is being deregistered
    420 *
    421 * This routine iterates RC QPs looking for references
    422 * to the lkey noted in mr.
    423 */
    424static void rvt_dereg_clean_qps(struct rvt_mregion *mr)
    425{
    426	struct rvt_dev_info *rdi = ib_to_rvt(mr->pd->device);
    427
    428	rvt_qp_iter(rdi, (u64)mr, rvt_dereg_clean_qp_cb);
    429}
    430
    431/**
    432 * rvt_check_refs - check references
     433 * @mr: the mregion
    434 * @t: the caller identification
    435 *
     436 * This routine checks whether references to the MR are still
     437 * held while it is being de-registered.
     438 *
     439 * If the count is non-zero, the code calls a clean routine and
     440 * then waits up to a timeout for the count to reach zero.
    441 */
    442static int rvt_check_refs(struct rvt_mregion *mr, const char *t)
    443{
    444	unsigned long timeout;
    445	struct rvt_dev_info *rdi = ib_to_rvt(mr->pd->device);
    446
    447	if (mr->lkey) {
    448		/* avoid dma mr */
    449		rvt_dereg_clean_qps(mr);
    450		/* @mr was indexed on rcu protected @lkey_table */
    451		synchronize_rcu();
    452	}
    453
    454	timeout = wait_for_completion_timeout(&mr->comp, 5 * HZ);
    455	if (!timeout) {
    456		rvt_pr_err(rdi,
    457			   "%s timeout mr %p pd %p lkey %x refcount %ld\n",
    458			   t, mr, mr->pd, mr->lkey,
    459			   atomic_long_read(&mr->refcount.data->count));
    460		rvt_get_mr(mr);
    461		return -EBUSY;
    462	}
    463	return 0;
    464}
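
The refcount handling above follows the usual percpu_ref teardown pattern: the release callback fires a completion and the waiter gives outstanding holders a bounded amount of time to drop their references. A hedged, self-contained sketch of that pattern; the names demo_obj, demo_release, demo_init and demo_teardown are illustrative and not part of rdmavt.

#include <linux/completion.h>
#include <linux/errno.h>
#include <linux/gfp.h>
#include <linux/jiffies.h>
#include <linux/kernel.h>
#include <linux/percpu-refcount.h>

struct demo_obj {
	struct percpu_ref ref;
	struct completion comp;
};

static void demo_release(struct percpu_ref *ref)
{
	struct demo_obj *obj = container_of(ref, struct demo_obj, ref);

	complete(&obj->comp);			/* last reference dropped */
}

static int demo_init(struct demo_obj *obj)
{
	init_completion(&obj->comp);
	return percpu_ref_init(&obj->ref, demo_release, 0, GFP_KERNEL);
}

static int demo_teardown(struct demo_obj *obj)
{
	percpu_ref_kill(&obj->ref);		/* drop the initial reference */
	if (!wait_for_completion_timeout(&obj->comp, 5 * HZ))
		return -EBUSY;			/* references still outstanding */
	percpu_ref_exit(&obj->ref);
	return 0;
}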
    465
    466/**
     467 * rvt_mr_has_lkey - does the mregion match the lkey
    468 * @mr: the mregion
    469 * @lkey: the lkey
    470 */
    471bool rvt_mr_has_lkey(struct rvt_mregion *mr, u32 lkey)
    472{
    473	return mr && lkey == mr->lkey;
    474}
    475
    476/**
     477 * rvt_ss_has_lkey - is the lkey referenced by the sge state
    478 * @ss: the sge state
    479 * @lkey: the lkey
    480 *
    481 * This code tests for an MR in the indicated
    482 * sge state.
    483 */
    484bool rvt_ss_has_lkey(struct rvt_sge_state *ss, u32 lkey)
    485{
    486	int i;
    487	bool rval = false;
    488
    489	if (!ss->num_sge)
    490		return rval;
    491	/* first one */
    492	rval = rvt_mr_has_lkey(ss->sge.mr, lkey);
    493	/* any others */
    494	for (i = 0; !rval && i < ss->num_sge - 1; i++)
    495		rval = rvt_mr_has_lkey(ss->sg_list[i].mr, lkey);
    496	return rval;
    497}
    498
    499/**
    500 * rvt_dereg_mr - unregister and free a memory region
    501 * @ibmr: the memory region to free
    502 * @udata: unused by the driver
    503 *
    504 * Note that this is called to free MRs created by rvt_get_dma_mr()
    505 * or rvt_reg_user_mr().
    506 *
    507 * Returns 0 on success.
    508 */
    509int rvt_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
    510{
    511	struct rvt_mr *mr = to_imr(ibmr);
    512	int ret;
    513
    514	rvt_free_lkey(&mr->mr);
    515
    516	rvt_put_mr(&mr->mr); /* will set completion if last */
    517	ret = rvt_check_refs(&mr->mr, __func__);
    518	if (ret)
    519		goto out;
    520	rvt_deinit_mregion(&mr->mr);
    521	ib_umem_release(mr->umem);
    522	kfree(mr);
    523out:
    524	return ret;
    525}
    526
    527/**
     528 * rvt_alloc_mr - Allocate a memory region usable with the IB_WR_REG_MR send work request
    529 * @pd: protection domain for this memory region
    530 * @mr_type: mem region type
    531 * @max_num_sg: Max number of segments allowed
    532 *
    533 * Return: the memory region on success, otherwise return an errno.
    534 */
    535struct ib_mr *rvt_alloc_mr(struct ib_pd *pd, enum ib_mr_type mr_type,
    536			   u32 max_num_sg)
    537{
    538	struct rvt_mr *mr;
    539
    540	if (mr_type != IB_MR_TYPE_MEM_REG)
    541		return ERR_PTR(-EINVAL);
    542
    543	mr = __rvt_alloc_mr(max_num_sg, pd);
    544	if (IS_ERR(mr))
    545		return (struct ib_mr *)mr;
    546
    547	return &mr->ibmr;
    548}
    549
    550/**
    551 * rvt_set_page - page assignment function called by ib_sg_to_pages
    552 * @ibmr: memory region
    553 * @addr: dma address of mapped page
    554 *
    555 * Return: 0 on success
    556 */
    557static int rvt_set_page(struct ib_mr *ibmr, u64 addr)
    558{
    559	struct rvt_mr *mr = to_imr(ibmr);
    560	u32 ps = 1 << mr->mr.page_shift;
    561	u32 mapped_segs = mr->mr.length >> mr->mr.page_shift;
    562	int m, n;
    563
    564	if (unlikely(mapped_segs == mr->mr.max_segs))
    565		return -ENOMEM;
    566
    567	m = mapped_segs / RVT_SEGSZ;
    568	n = mapped_segs % RVT_SEGSZ;
    569	mr->mr.map[m]->segs[n].vaddr = (void *)addr;
    570	mr->mr.map[m]->segs[n].length = ps;
    571	mr->mr.length += ps;
    572	trace_rvt_mr_page_seg(&mr->mr, m, n, (void *)addr, ps);
    573
    574	return 0;
    575}
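
The mapped_segs index above is split into a first-level chunk index m and a within-chunk index n. A stand-alone sketch of where, say, the 71st page lands, with SEGSZ standing in for RVT_SEGSZ (assumed 32 here):

#include <stdio.h>

#define SEGSZ 32	/* assumed stand-in for RVT_SEGSZ */

int main(void)
{
	unsigned int mapped_segs = 70;		/* 71st page being added */
	unsigned int m = mapped_segs / SEGSZ;	/* 2: third map chunk */
	unsigned int n = mapped_segs % SEGSZ;	/* 6: seventh slot in that chunk */

	printf("map[%u]->segs[%u]\n", m, n);	/* prints map[2]->segs[6] */
	return 0;
}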
    576
    577/**
     578 * rvt_map_mr_sg - map an sg list into the memory region
    579 * @ibmr: memory region
    580 * @sg: dma mapped scatterlist
    581 * @sg_nents: number of entries in sg
    582 * @sg_offset: offset in bytes into sg
    583 *
    584 * Overwrite rvt_mr length with mr length calculated by ib_sg_to_pages.
    585 *
    586 * Return: number of sg elements mapped to the memory region
    587 */
    588int rvt_map_mr_sg(struct ib_mr *ibmr, struct scatterlist *sg,
    589		  int sg_nents, unsigned int *sg_offset)
    590{
    591	struct rvt_mr *mr = to_imr(ibmr);
    592	int ret;
    593
    594	mr->mr.length = 0;
    595	mr->mr.page_shift = PAGE_SHIFT;
    596	ret = ib_sg_to_pages(ibmr, sg, sg_nents, sg_offset, rvt_set_page);
    597	mr->mr.user_base = ibmr->iova;
    598	mr->mr.iova = ibmr->iova;
    599	mr->mr.offset = ibmr->iova - (u64)mr->mr.map[0]->segs[0].vaddr;
    600	mr->mr.length = (size_t)ibmr->length;
    601	trace_rvt_map_mr_sg(ibmr, sg_nents, sg_offset);
    602	return ret;
    603}
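
For context, a hedged sketch of how an upper-layer caller typically reaches this path: ib_map_mr_sg() dispatches to the provider's map_mr_sg verb (rvt_map_mr_sg here), which walks the pages through rvt_set_page(). The helper name demo_map and its error handling are assumptions for illustration.

#include <linux/errno.h>
#include <linux/scatterlist.h>
#include <rdma/ib_verbs.h>

static int demo_map(struct ib_mr *ibmr, struct scatterlist *sg, int nents)
{
	int mapped;

	/* NULL sg_offset: start at the beginning of the first entry */
	mapped = ib_map_mr_sg(ibmr, sg, nents, NULL, PAGE_SIZE);
	if (mapped < nents)
		return mapped < 0 ? mapped : -EINVAL;	/* partial mapping */
	return 0;
}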
    604
    605/**
    606 * rvt_fast_reg_mr - fast register physical MR
    607 * @qp: the queue pair where the work request comes from
    608 * @ibmr: the memory region to be registered
    609 * @key: updated key for this memory region
    610 * @access: access flags for this memory region
    611 *
    612 * Returns 0 on success.
    613 */
    614int rvt_fast_reg_mr(struct rvt_qp *qp, struct ib_mr *ibmr, u32 key,
    615		    int access)
    616{
    617	struct rvt_mr *mr = to_imr(ibmr);
    618
    619	if (qp->ibqp.pd != mr->mr.pd)
    620		return -EACCES;
    621
    622	/* not applicable to dma MR or user MR */
    623	if (!mr->mr.lkey || mr->umem)
    624		return -EINVAL;
    625
    626	if ((key & 0xFFFFFF00) != (mr->mr.lkey & 0xFFFFFF00))
    627		return -EINVAL;
    628
    629	ibmr->lkey = key;
    630	ibmr->rkey = key;
    631	mr->mr.lkey = key;
    632	mr->mr.access_flags = access;
    633	mr->mr.iova = ibmr->iova;
    634	atomic_set(&mr->mr.lkey_invalid, 0);
    635
    636	return 0;
    637}
    638EXPORT_SYMBOL(rvt_fast_reg_mr);
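
The key check in rvt_fast_reg_mr() only lets the low 8, consumer-owned bits of the key change; the table-index and generation bits must match the current lkey. A stand-alone illustration with assumed values:

#include <assert.h>
#include <stdint.h>

int main(void)
{
	uint32_t lkey    = 0x00052a00;	/* current lkey (example value) */
	uint32_t new_key = 0x00052a7f;	/* same upper 24 bits: accepted */
	uint32_t bad_key = 0x00062a7f;	/* different table index: rejected */

	assert((new_key & 0xFFFFFF00) == (lkey & 0xFFFFFF00));
	assert((bad_key & 0xFFFFFF00) != (lkey & 0xFFFFFF00));
	return 0;
}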
    639
    640/**
    641 * rvt_invalidate_rkey - invalidate an MR rkey
    642 * @qp: queue pair associated with the invalidate op
    643 * @rkey: rkey to invalidate
    644 *
    645 * Returns 0 on success.
    646 */
    647int rvt_invalidate_rkey(struct rvt_qp *qp, u32 rkey)
    648{
    649	struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
    650	struct rvt_lkey_table *rkt = &dev->lkey_table;
    651	struct rvt_mregion *mr;
    652
    653	if (rkey == 0)
    654		return -EINVAL;
    655
    656	rcu_read_lock();
    657	mr = rcu_dereference(
    658		rkt->table[(rkey >> (32 - dev->dparms.lkey_table_size))]);
    659	if (unlikely(!mr || mr->lkey != rkey || qp->ibqp.pd != mr->pd))
    660		goto bail;
    661
    662	atomic_set(&mr->lkey_invalid, 1);
    663	rcu_read_unlock();
    664	return 0;
    665
    666bail:
    667	rcu_read_unlock();
    668	return -EINVAL;
    669}
    670EXPORT_SYMBOL(rvt_invalidate_rkey);
    671
    672/**
    673 * rvt_sge_adjacent - is isge compressible
    674 * @last_sge: last outgoing SGE written
    675 * @sge: SGE to check
    676 *
     677 * If adjacent, last_sge is updated to add the new length.
    678 *
    679 * Return: true if isge is adjacent to last sge
    680 */
    681static inline bool rvt_sge_adjacent(struct rvt_sge *last_sge,
    682				    struct ib_sge *sge)
    683{
    684	if (last_sge && sge->lkey == last_sge->mr->lkey &&
    685	    ((uint64_t)(last_sge->vaddr + last_sge->length) == sge->addr)) {
    686		if (sge->lkey) {
    687			if (unlikely((sge->addr - last_sge->mr->user_base +
    688			      sge->length > last_sge->mr->length)))
    689				return false; /* overrun, caller will catch */
    690		} else {
    691			last_sge->length += sge->length;
    692		}
    693		last_sge->sge_length += sge->length;
    694		trace_rvt_sge_adjacent(last_sge, sge);
    695		return true;
    696	}
    697	return false;
    698}
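
A small worked example of the coalescing condition above, with made-up addresses: the second SGE starts exactly where the first one ends, so the two are folded into a single internal SGE.

#include <assert.h>
#include <stdint.h>

int main(void)
{
	uint64_t first_addr = 0x1000, first_len = 0x200;
	uint64_t second_addr = 0x1200, second_len = 0x100;

	if (second_addr == first_addr + first_len)
		first_len += second_len;	/* merged into one 0x300-byte SGE */

	assert(first_len == 0x300);
	return 0;
}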
    699
    700/**
    701 * rvt_lkey_ok - check IB SGE for validity and initialize
    702 * @rkt: table containing lkey to check SGE against
    703 * @pd: protection domain
    704 * @isge: outgoing internal SGE
    705 * @last_sge: last outgoing SGE written
    706 * @sge: SGE to check
    707 * @acc: access flags
    708 *
    709 * Check the IB SGE for validity and initialize our internal version
    710 * of it.
    711 *
    712 * Increments the reference count when a new sge is stored.
    713 *
     714 * Return: 0 if compressed, 1 if added, otherwise returns -errno.
    715 */
    716int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd,
    717		struct rvt_sge *isge, struct rvt_sge *last_sge,
    718		struct ib_sge *sge, int acc)
    719{
    720	struct rvt_mregion *mr;
    721	unsigned n, m;
    722	size_t off;
    723
    724	/*
    725	 * We use LKEY == zero for kernel virtual addresses
    726	 * (see rvt_get_dma_mr()).
    727	 */
    728	if (sge->lkey == 0) {
    729		struct rvt_dev_info *dev = ib_to_rvt(pd->ibpd.device);
    730
    731		if (pd->user)
    732			return -EINVAL;
    733		if (rvt_sge_adjacent(last_sge, sge))
    734			return 0;
    735		rcu_read_lock();
    736		mr = rcu_dereference(dev->dma_mr);
    737		if (!mr)
    738			goto bail;
    739		rvt_get_mr(mr);
    740		rcu_read_unlock();
    741
    742		isge->mr = mr;
    743		isge->vaddr = (void *)sge->addr;
    744		isge->length = sge->length;
    745		isge->sge_length = sge->length;
    746		isge->m = 0;
    747		isge->n = 0;
    748		goto ok;
    749	}
    750	if (rvt_sge_adjacent(last_sge, sge))
    751		return 0;
    752	rcu_read_lock();
    753	mr = rcu_dereference(rkt->table[sge->lkey >> rkt->shift]);
    754	if (!mr)
    755		goto bail;
    756	rvt_get_mr(mr);
    757	if (!READ_ONCE(mr->lkey_published))
    758		goto bail_unref;
    759
    760	if (unlikely(atomic_read(&mr->lkey_invalid) ||
    761		     mr->lkey != sge->lkey || mr->pd != &pd->ibpd))
    762		goto bail_unref;
    763
    764	off = sge->addr - mr->user_base;
    765	if (unlikely(sge->addr < mr->user_base ||
    766		     off + sge->length > mr->length ||
    767		     (mr->access_flags & acc) != acc))
    768		goto bail_unref;
    769	rcu_read_unlock();
    770
    771	off += mr->offset;
    772	if (mr->page_shift) {
    773		/*
     774		 * Page sizes are a uniform power of 2, so no loop is needed.
     775		 * entries_spanned_by_off is the number of times the loop
     776		 * below would have executed.
     777		 */
    778		size_t entries_spanned_by_off;
    779
    780		entries_spanned_by_off = off >> mr->page_shift;
    781		off -= (entries_spanned_by_off << mr->page_shift);
    782		m = entries_spanned_by_off / RVT_SEGSZ;
    783		n = entries_spanned_by_off % RVT_SEGSZ;
    784	} else {
    785		m = 0;
    786		n = 0;
    787		while (off >= mr->map[m]->segs[n].length) {
    788			off -= mr->map[m]->segs[n].length;
    789			n++;
    790			if (n >= RVT_SEGSZ) {
    791				m++;
    792				n = 0;
    793			}
    794		}
    795	}
    796	isge->mr = mr;
    797	isge->vaddr = mr->map[m]->segs[n].vaddr + off;
    798	isge->length = mr->map[m]->segs[n].length - off;
    799	isge->sge_length = sge->length;
    800	isge->m = m;
    801	isge->n = n;
    802ok:
    803	trace_rvt_sge_new(isge, sge);
    804	return 1;
    805bail_unref:
    806	rvt_put_mr(mr);
    807bail:
    808	rcu_read_unlock();
    809	return -EINVAL;
    810}
    811EXPORT_SYMBOL(rvt_lkey_ok);
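
A worked example of the page_shift fast path above, assuming 4 KiB pages (page_shift = 12) and, as before, 32 segments per map chunk: a byte offset of 0x13456 into the region spans 19 whole pages, so the SGE starts at map[0]->segs[19] with an in-page offset of 0x456. rvt_rkey_ok() below uses the same arithmetic.

#include <stdio.h>

#define SEGSZ 32	/* assumed stand-in for RVT_SEGSZ */

int main(void)
{
	unsigned int page_shift = 12;			/* assumed 4 KiB pages */
	unsigned long off = 0x13456;			/* offset into the region */
	unsigned long spanned = off >> page_shift;	/* 19 whole pages */

	off -= spanned << page_shift;			/* 0x456 left within a page */
	printf("map[%lu]->segs[%lu] + 0x%lx\n",
	       spanned / SEGSZ, spanned % SEGSZ, off);	/* map[0]->segs[19] + 0x456 */
	return 0;
}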
    812
    813/**
    814 * rvt_rkey_ok - check the IB virtual address, length, and RKEY
    815 * @qp: qp for validation
    816 * @sge: SGE state
    817 * @len: length of data
    818 * @vaddr: virtual address to place data
    819 * @rkey: rkey to check
    820 * @acc: access flags
    821 *
    822 * Return: 1 if successful, otherwise 0.
    823 *
    824 * increments the reference count upon success
    825 */
    826int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge,
    827		u32 len, u64 vaddr, u32 rkey, int acc)
    828{
    829	struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
    830	struct rvt_lkey_table *rkt = &dev->lkey_table;
    831	struct rvt_mregion *mr;
    832	unsigned n, m;
    833	size_t off;
    834
    835	/*
    836	 * We use RKEY == zero for kernel virtual addresses
    837	 * (see rvt_get_dma_mr()).
    838	 */
    839	rcu_read_lock();
    840	if (rkey == 0) {
    841		struct rvt_pd *pd = ibpd_to_rvtpd(qp->ibqp.pd);
    842		struct rvt_dev_info *rdi = ib_to_rvt(pd->ibpd.device);
    843
    844		if (pd->user)
    845			goto bail;
    846		mr = rcu_dereference(rdi->dma_mr);
    847		if (!mr)
    848			goto bail;
    849		rvt_get_mr(mr);
    850		rcu_read_unlock();
    851
    852		sge->mr = mr;
    853		sge->vaddr = (void *)vaddr;
    854		sge->length = len;
    855		sge->sge_length = len;
    856		sge->m = 0;
    857		sge->n = 0;
    858		goto ok;
    859	}
    860
    861	mr = rcu_dereference(rkt->table[rkey >> rkt->shift]);
    862	if (!mr)
    863		goto bail;
    864	rvt_get_mr(mr);
     865	/* ensure mr is read before the test */
    866	if (!READ_ONCE(mr->lkey_published))
    867		goto bail_unref;
    868	if (unlikely(atomic_read(&mr->lkey_invalid) ||
    869		     mr->lkey != rkey || qp->ibqp.pd != mr->pd))
    870		goto bail_unref;
    871
    872	off = vaddr - mr->iova;
    873	if (unlikely(vaddr < mr->iova || off + len > mr->length ||
    874		     (mr->access_flags & acc) == 0))
    875		goto bail_unref;
    876	rcu_read_unlock();
    877
    878	off += mr->offset;
    879	if (mr->page_shift) {
    880		/*
     881		 * Page sizes are a uniform power of 2, so no loop is needed.
     882		 * entries_spanned_by_off is the number of times the loop
     883		 * below would have executed.
     884		 */
    885		size_t entries_spanned_by_off;
    886
    887		entries_spanned_by_off = off >> mr->page_shift;
    888		off -= (entries_spanned_by_off << mr->page_shift);
    889		m = entries_spanned_by_off / RVT_SEGSZ;
    890		n = entries_spanned_by_off % RVT_SEGSZ;
    891	} else {
    892		m = 0;
    893		n = 0;
    894		while (off >= mr->map[m]->segs[n].length) {
    895			off -= mr->map[m]->segs[n].length;
    896			n++;
    897			if (n >= RVT_SEGSZ) {
    898				m++;
    899				n = 0;
    900			}
    901		}
    902	}
    903	sge->mr = mr;
    904	sge->vaddr = mr->map[m]->segs[n].vaddr + off;
    905	sge->length = mr->map[m]->segs[n].length - off;
    906	sge->sge_length = len;
    907	sge->m = m;
    908	sge->n = n;
    909ok:
    910	return 1;
    911bail_unref:
    912	rvt_put_mr(mr);
    913bail:
    914	rcu_read_unlock();
    915	return 0;
    916}
    917EXPORT_SYMBOL(rvt_rkey_ok);