cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

st-dma-fence-chain.c (13686B)


      1// SPDX-License-Identifier: MIT
      2
      3/*
      4 * Copyright © 2019 Intel Corporation
      5 */
      6
      7#include <linux/delay.h>
      8#include <linux/dma-fence.h>
      9#include <linux/dma-fence-chain.h>
     10#include <linux/kernel.h>
     11#include <linux/kthread.h>
     12#include <linux/mm.h>
     13#include <linux/sched/signal.h>
     14#include <linux/slab.h>
     15#include <linux/spinlock.h>
     16#include <linux/random.h>
     17
     18#include "selftest.h"
     19
     20#define CHAIN_SZ (4 << 10)
     21
     22static struct kmem_cache *slab_fences;
     23
     24static inline struct mock_fence {
     25	struct dma_fence base;
     26	spinlock_t lock;
     27} *to_mock_fence(struct dma_fence *f) {
     28	return container_of(f, struct mock_fence, base);
     29}
     30
     31static const char *mock_name(struct dma_fence *f)
     32{
     33	return "mock";
     34}
     35
     36static void mock_fence_release(struct dma_fence *f)
     37{
     38	kmem_cache_free(slab_fences, to_mock_fence(f));
     39}
     40
     41static const struct dma_fence_ops mock_ops = {
     42	.get_driver_name = mock_name,
     43	.get_timeline_name = mock_name,
     44	.release = mock_fence_release,
     45};
     46
     47static struct dma_fence *mock_fence(void)
     48{
     49	struct mock_fence *f;
     50
     51	f = kmem_cache_alloc(slab_fences, GFP_KERNEL);
     52	if (!f)
     53		return NULL;
     54
     55	spin_lock_init(&f->lock);
     56	dma_fence_init(&f->base, &mock_ops, &f->lock, 0, 0);
     57
     58	return &f->base;
     59}
     60
     61static struct dma_fence *mock_chain(struct dma_fence *prev,
     62				    struct dma_fence *fence,
     63				    u64 seqno)
     64{
     65	struct dma_fence_chain *f;
     66
     67	f = dma_fence_chain_alloc();
     68	if (!f)
     69		return NULL;
     70
     71	dma_fence_chain_init(f, dma_fence_get(prev), dma_fence_get(fence),
     72			     seqno);
     73
     74	return &f->base;
     75}
     76
     77static int sanitycheck(void *arg)
     78{
     79	struct dma_fence *f, *chain;
     80	int err = 0;
     81
     82	f = mock_fence();
     83	if (!f)
     84		return -ENOMEM;
     85
     86	chain = mock_chain(NULL, f, 1);
     87	if (!chain)
     88		err = -ENOMEM;
     89
     90	dma_fence_signal(f);
     91	dma_fence_put(f);
     92
     93	dma_fence_put(chain);
     94
     95	return err;
     96}
     97
     98struct fence_chains {
     99	unsigned int chain_length;
    100	struct dma_fence **fences;
    101	struct dma_fence **chains;
    102
    103	struct dma_fence *tail;
    104};
    105
    106static uint64_t seqno_inc(unsigned int i)
    107{
    108	return i + 1;
    109}
    110
    111static int fence_chains_init(struct fence_chains *fc, unsigned int count,
    112			     uint64_t (*seqno_fn)(unsigned int))
    113{
    114	unsigned int i;
    115	int err = 0;
    116
    117	fc->chains = kvmalloc_array(count, sizeof(*fc->chains),
    118				    GFP_KERNEL | __GFP_ZERO);
    119	if (!fc->chains)
    120		return -ENOMEM;
    121
    122	fc->fences = kvmalloc_array(count, sizeof(*fc->fences),
    123				    GFP_KERNEL | __GFP_ZERO);
    124	if (!fc->fences) {
    125		err = -ENOMEM;
    126		goto err_chains;
    127	}
    128
    129	fc->tail = NULL;
    130	for (i = 0; i < count; i++) {
    131		fc->fences[i] = mock_fence();
    132		if (!fc->fences[i]) {
    133			err = -ENOMEM;
    134			goto unwind;
    135		}
    136
    137		fc->chains[i] = mock_chain(fc->tail,
    138					   fc->fences[i],
    139					   seqno_fn(i));
    140		if (!fc->chains[i]) {
    141			err = -ENOMEM;
    142			goto unwind;
    143		}
    144
    145		fc->tail = fc->chains[i];
    146	}
    147
    148	fc->chain_length = i;
    149	return 0;
    150
    151unwind:
    152	for (i = 0; i < count; i++) {
    153		dma_fence_put(fc->fences[i]);
    154		dma_fence_put(fc->chains[i]);
    155	}
    156	kvfree(fc->fences);
    157err_chains:
    158	kvfree(fc->chains);
    159	return err;
    160}
    161
    162static void fence_chains_fini(struct fence_chains *fc)
    163{
    164	unsigned int i;
    165
    166	for (i = 0; i < fc->chain_length; i++) {
    167		dma_fence_signal(fc->fences[i]);
    168		dma_fence_put(fc->fences[i]);
    169	}
    170	kvfree(fc->fences);
    171
    172	for (i = 0; i < fc->chain_length; i++)
    173		dma_fence_put(fc->chains[i]);
    174	kvfree(fc->chains);
    175}
    176
    177static int find_seqno(void *arg)
    178{
    179	struct fence_chains fc;
    180	struct dma_fence *fence;
    181	int err;
    182	int i;
    183
    184	err = fence_chains_init(&fc, 64, seqno_inc);
    185	if (err)
    186		return err;
    187
    188	fence = dma_fence_get(fc.tail);
    189	err = dma_fence_chain_find_seqno(&fence, 0);
    190	dma_fence_put(fence);
    191	if (err) {
    192		pr_err("Reported %d for find_seqno(0)!\n", err);
    193		goto err;
    194	}
    195
    196	for (i = 0; i < fc.chain_length; i++) {
    197		fence = dma_fence_get(fc.tail);
    198		err = dma_fence_chain_find_seqno(&fence, i + 1);
    199		dma_fence_put(fence);
    200		if (err) {
    201			pr_err("Reported %d for find_seqno(%d:%d)!\n",
    202			       err, fc.chain_length + 1, i + 1);
    203			goto err;
    204		}
    205		if (fence != fc.chains[i]) {
    206			pr_err("Incorrect fence reported by find_seqno(%d:%d)\n",
    207			       fc.chain_length + 1, i + 1);
    208			err = -EINVAL;
    209			goto err;
    210		}
    211
    212		dma_fence_get(fence);
    213		err = dma_fence_chain_find_seqno(&fence, i + 1);
    214		dma_fence_put(fence);
    215		if (err) {
    216			pr_err("Error reported for finding self\n");
    217			goto err;
    218		}
    219		if (fence != fc.chains[i]) {
    220			pr_err("Incorrect fence reported by find self\n");
    221			err = -EINVAL;
    222			goto err;
    223		}
    224
    225		dma_fence_get(fence);
    226		err = dma_fence_chain_find_seqno(&fence, i + 2);
    227		dma_fence_put(fence);
    228		if (!err) {
    229			pr_err("Error not reported for future fence: find_seqno(%d:%d)!\n",
    230			       i + 1, i + 2);
    231			err = -EINVAL;
    232			goto err;
    233		}
    234
    235		dma_fence_get(fence);
    236		err = dma_fence_chain_find_seqno(&fence, i);
    237		dma_fence_put(fence);
    238		if (err) {
    239			pr_err("Error reported for previous fence!\n");
    240			goto err;
    241		}
    242		if (i > 0 && fence != fc.chains[i - 1]) {
    243			pr_err("Incorrect fence reported by find_seqno(%d:%d)\n",
    244			       i + 1, i);
    245			err = -EINVAL;
    246			goto err;
    247		}
    248	}
    249
    250err:
    251	fence_chains_fini(&fc);
    252	return err;
    253}
    254
    255static int find_signaled(void *arg)
    256{
    257	struct fence_chains fc;
    258	struct dma_fence *fence;
    259	int err;
    260
    261	err = fence_chains_init(&fc, 2, seqno_inc);
    262	if (err)
    263		return err;
    264
    265	dma_fence_signal(fc.fences[0]);
    266
    267	fence = dma_fence_get(fc.tail);
    268	err = dma_fence_chain_find_seqno(&fence, 1);
    269	dma_fence_put(fence);
    270	if (err) {
    271		pr_err("Reported %d for find_seqno()!\n", err);
    272		goto err;
    273	}
    274
    275	if (fence && fence != fc.chains[0]) {
    276		pr_err("Incorrect chain-fence.seqno:%lld reported for completed seqno:1\n",
    277		       fence->seqno);
    278
    279		dma_fence_get(fence);
    280		err = dma_fence_chain_find_seqno(&fence, 1);
    281		dma_fence_put(fence);
    282		if (err)
    283			pr_err("Reported %d for finding self!\n", err);
    284
    285		err = -EINVAL;
    286	}
    287
    288err:
    289	fence_chains_fini(&fc);
    290	return err;
    291}
    292
    293static int find_out_of_order(void *arg)
    294{
    295	struct fence_chains fc;
    296	struct dma_fence *fence;
    297	int err;
    298
    299	err = fence_chains_init(&fc, 3, seqno_inc);
    300	if (err)
    301		return err;
    302
    303	dma_fence_signal(fc.fences[1]);
    304
    305	fence = dma_fence_get(fc.tail);
    306	err = dma_fence_chain_find_seqno(&fence, 2);
    307	dma_fence_put(fence);
    308	if (err) {
    309		pr_err("Reported %d for find_seqno()!\n", err);
    310		goto err;
    311	}
    312
    313	/*
    314	 * We signaled the middle fence (2) of the 1-2-3 chain. The behavior
    315	 * of the dma-fence-chain is to make us wait for all the fences up to
    316	 * the point we want. Since fence 1 is still not signaled, this what
    317	 * we should get as fence to wait upon (fence 2 being garbage
    318	 * collected during the traversal of the chain).
    319	 */
    320	if (fence != fc.chains[0]) {
    321		pr_err("Incorrect chain-fence.seqno:%lld reported for completed seqno:2\n",
    322		       fence ? fence->seqno : 0);
    323
    324		err = -EINVAL;
    325	}
    326
    327err:
    328	fence_chains_fini(&fc);
    329	return err;
    330}
    331
    332static uint64_t seqno_inc2(unsigned int i)
    333{
    334	return 2 * i + 2;
    335}
    336
    337static int find_gap(void *arg)
    338{
    339	struct fence_chains fc;
    340	struct dma_fence *fence;
    341	int err;
    342	int i;
    343
    344	err = fence_chains_init(&fc, 64, seqno_inc2);
    345	if (err)
    346		return err;
    347
    348	for (i = 0; i < fc.chain_length; i++) {
    349		fence = dma_fence_get(fc.tail);
    350		err = dma_fence_chain_find_seqno(&fence, 2 * i + 1);
    351		dma_fence_put(fence);
    352		if (err) {
    353			pr_err("Reported %d for find_seqno(%d:%d)!\n",
    354			       err, fc.chain_length + 1, 2 * i + 1);
    355			goto err;
    356		}
    357		if (fence != fc.chains[i]) {
    358			pr_err("Incorrect fence.seqno:%lld reported by find_seqno(%d:%d)\n",
    359			       fence->seqno,
    360			       fc.chain_length + 1,
    361			       2 * i + 1);
    362			err = -EINVAL;
    363			goto err;
    364		}
    365
    366		dma_fence_get(fence);
    367		err = dma_fence_chain_find_seqno(&fence, 2 * i + 2);
    368		dma_fence_put(fence);
    369		if (err) {
    370			pr_err("Error reported for finding self\n");
    371			goto err;
    372		}
    373		if (fence != fc.chains[i]) {
    374			pr_err("Incorrect fence reported by find self\n");
    375			err = -EINVAL;
    376			goto err;
    377		}
    378	}
    379
    380err:
    381	fence_chains_fini(&fc);
    382	return err;
    383}
    384
    385struct find_race {
    386	struct fence_chains fc;
    387	atomic_t children;
    388};
    389
    390static int __find_race(void *arg)
    391{
    392	struct find_race *data = arg;
    393	int err = 0;
    394
    395	while (!kthread_should_stop()) {
    396		struct dma_fence *fence = dma_fence_get(data->fc.tail);
    397		int seqno;
    398
    399		seqno = prandom_u32_max(data->fc.chain_length) + 1;
    400
    401		err = dma_fence_chain_find_seqno(&fence, seqno);
    402		if (err) {
    403			pr_err("Failed to find fence seqno:%d\n",
    404			       seqno);
    405			dma_fence_put(fence);
    406			break;
    407		}
    408		if (!fence)
    409			goto signal;
    410
    411		/*
    412		 * We can only find ourselves if we are on fence we were
    413		 * looking for.
    414		 */
    415		if (fence->seqno == seqno) {
    416			err = dma_fence_chain_find_seqno(&fence, seqno);
    417			if (err) {
    418				pr_err("Reported an invalid fence for find-self:%d\n",
    419				       seqno);
    420				dma_fence_put(fence);
    421				break;
    422			}
    423		}
    424
    425		dma_fence_put(fence);
    426
    427signal:
    428		seqno = prandom_u32_max(data->fc.chain_length - 1);
    429		dma_fence_signal(data->fc.fences[seqno]);
    430		cond_resched();
    431	}
    432
    433	if (atomic_dec_and_test(&data->children))
    434		wake_up_var(&data->children);
    435	return err;
    436}
    437
    438static int find_race(void *arg)
    439{
    440	struct find_race data;
    441	int ncpus = num_online_cpus();
    442	struct task_struct **threads;
    443	unsigned long count;
    444	int err;
    445	int i;
    446
    447	err = fence_chains_init(&data.fc, CHAIN_SZ, seqno_inc);
    448	if (err)
    449		return err;
    450
    451	threads = kmalloc_array(ncpus, sizeof(*threads), GFP_KERNEL);
    452	if (!threads) {
    453		err = -ENOMEM;
    454		goto err;
    455	}
    456
    457	atomic_set(&data.children, 0);
    458	for (i = 0; i < ncpus; i++) {
    459		threads[i] = kthread_run(__find_race, &data, "dmabuf/%d", i);
    460		if (IS_ERR(threads[i])) {
    461			ncpus = i;
    462			break;
    463		}
    464		atomic_inc(&data.children);
    465		get_task_struct(threads[i]);
    466	}
    467
    468	wait_var_event_timeout(&data.children,
    469			       !atomic_read(&data.children),
    470			       5 * HZ);
    471
    472	for (i = 0; i < ncpus; i++) {
    473		int ret;
    474
    475		ret = kthread_stop(threads[i]);
    476		if (ret && !err)
    477			err = ret;
    478		put_task_struct(threads[i]);
    479	}
    480	kfree(threads);
    481
    482	count = 0;
    483	for (i = 0; i < data.fc.chain_length; i++)
    484		if (dma_fence_is_signaled(data.fc.fences[i]))
    485			count++;
    486	pr_info("Completed %lu cycles\n", count);
    487
    488err:
    489	fence_chains_fini(&data.fc);
    490	return err;
    491}
    492
    493static int signal_forward(void *arg)
    494{
    495	struct fence_chains fc;
    496	int err;
    497	int i;
    498
    499	err = fence_chains_init(&fc, 64, seqno_inc);
    500	if (err)
    501		return err;
    502
    503	for (i = 0; i < fc.chain_length; i++) {
    504		dma_fence_signal(fc.fences[i]);
    505
    506		if (!dma_fence_is_signaled(fc.chains[i])) {
    507			pr_err("chain[%d] not signaled!\n", i);
    508			err = -EINVAL;
    509			goto err;
    510		}
    511
    512		if (i + 1 < fc.chain_length &&
    513		    dma_fence_is_signaled(fc.chains[i + 1])) {
    514			pr_err("chain[%d] is signaled!\n", i);
    515			err = -EINVAL;
    516			goto err;
    517		}
    518	}
    519
    520err:
    521	fence_chains_fini(&fc);
    522	return err;
    523}
    524
    525static int signal_backward(void *arg)
    526{
    527	struct fence_chains fc;
    528	int err;
    529	int i;
    530
    531	err = fence_chains_init(&fc, 64, seqno_inc);
    532	if (err)
    533		return err;
    534
    535	for (i = fc.chain_length; i--; ) {
    536		dma_fence_signal(fc.fences[i]);
    537
    538		if (i > 0 && dma_fence_is_signaled(fc.chains[i])) {
    539			pr_err("chain[%d] is signaled!\n", i);
    540			err = -EINVAL;
    541			goto err;
    542		}
    543	}
    544
    545	for (i = 0; i < fc.chain_length; i++) {
    546		if (!dma_fence_is_signaled(fc.chains[i])) {
    547			pr_err("chain[%d] was not signaled!\n", i);
    548			err = -EINVAL;
    549			goto err;
    550		}
    551	}
    552
    553err:
    554	fence_chains_fini(&fc);
    555	return err;
    556}
    557
    558static int __wait_fence_chains(void *arg)
    559{
    560	struct fence_chains *fc = arg;
    561
    562	if (dma_fence_wait(fc->tail, false))
    563		return -EIO;
    564
    565	return 0;
    566}
    567
    568static int wait_forward(void *arg)
    569{
    570	struct fence_chains fc;
    571	struct task_struct *tsk;
    572	int err;
    573	int i;
    574
    575	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
    576	if (err)
    577		return err;
    578
    579	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
    580	if (IS_ERR(tsk)) {
    581		err = PTR_ERR(tsk);
    582		goto err;
    583	}
    584	get_task_struct(tsk);
    585	yield_to(tsk, true);
    586
    587	for (i = 0; i < fc.chain_length; i++)
    588		dma_fence_signal(fc.fences[i]);
    589
    590	err = kthread_stop(tsk);
    591	put_task_struct(tsk);
    592
    593err:
    594	fence_chains_fini(&fc);
    595	return err;
    596}
    597
    598static int wait_backward(void *arg)
    599{
    600	struct fence_chains fc;
    601	struct task_struct *tsk;
    602	int err;
    603	int i;
    604
    605	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
    606	if (err)
    607		return err;
    608
    609	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
    610	if (IS_ERR(tsk)) {
    611		err = PTR_ERR(tsk);
    612		goto err;
    613	}
    614	get_task_struct(tsk);
    615	yield_to(tsk, true);
    616
    617	for (i = fc.chain_length; i--; )
    618		dma_fence_signal(fc.fences[i]);
    619
    620	err = kthread_stop(tsk);
    621	put_task_struct(tsk);
    622
    623err:
    624	fence_chains_fini(&fc);
    625	return err;
    626}
    627
    628static void randomise_fences(struct fence_chains *fc)
    629{
    630	unsigned int count = fc->chain_length;
    631
    632	/* Fisher-Yates shuffle courtesy of Knuth */
    633	while (--count) {
    634		unsigned int swp;
    635
    636		swp = prandom_u32_max(count + 1);
    637		if (swp == count)
    638			continue;
    639
    640		swap(fc->fences[count], fc->fences[swp]);
    641	}
    642}
    643
    644static int wait_random(void *arg)
    645{
    646	struct fence_chains fc;
    647	struct task_struct *tsk;
    648	int err;
    649	int i;
    650
    651	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
    652	if (err)
    653		return err;
    654
    655	randomise_fences(&fc);
    656
    657	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
    658	if (IS_ERR(tsk)) {
    659		err = PTR_ERR(tsk);
    660		goto err;
    661	}
    662	get_task_struct(tsk);
    663	yield_to(tsk, true);
    664
    665	for (i = 0; i < fc.chain_length; i++)
    666		dma_fence_signal(fc.fences[i]);
    667
    668	err = kthread_stop(tsk);
    669	put_task_struct(tsk);
    670
    671err:
    672	fence_chains_fini(&fc);
    673	return err;
    674}
    675
    676int dma_fence_chain(void)
    677{
    678	static const struct subtest tests[] = {
    679		SUBTEST(sanitycheck),
    680		SUBTEST(find_seqno),
    681		SUBTEST(find_signaled),
    682		SUBTEST(find_out_of_order),
    683		SUBTEST(find_gap),
    684		SUBTEST(find_race),
    685		SUBTEST(signal_forward),
    686		SUBTEST(signal_backward),
    687		SUBTEST(wait_forward),
    688		SUBTEST(wait_backward),
    689		SUBTEST(wait_random),
    690	};
    691	int ret;
    692
    693	pr_info("sizeof(dma_fence_chain)=%zu\n",
    694		sizeof(struct dma_fence_chain));
    695
    696	slab_fences = KMEM_CACHE(mock_fence,
    697				 SLAB_TYPESAFE_BY_RCU |
    698				 SLAB_HWCACHE_ALIGN);
    699	if (!slab_fences)
    700		return -ENOMEM;
    701
    702	ret = subtests(tests, NULL);
    703
    704	kmem_cache_destroy(slab_fences);
    705	return ret;
    706}