cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

rbtree_test.c (9586B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2#include <linux/module.h>
      3#include <linux/moduleparam.h>
      4#include <linux/rbtree_augmented.h>
      5#include <linux/random.h>
      6#include <linux/slab.h>
      7#include <asm/timex.h>
      8
      9#define __param(type, name, init, msg)		\
     10	static type name = init;		\
     11	module_param(name, type, 0444);		\
     12	MODULE_PARM_DESC(name, msg);
     13
     14__param(int, nnodes, 100, "Number of nodes in the rb-tree");
     15__param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree");
     16__param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree");
     17
     18struct test_node {
     19	u32 key;
     20	struct rb_node rb;
     21
     22	/* following fields used for testing augmented rbtree functionality */
     23	u32 val;
     24	u32 augmented;
     25};
     26
     27static struct rb_root_cached root = RB_ROOT_CACHED;
     28static struct test_node *nodes = NULL;
     29
     30static struct rnd_state rnd;
     31
     32static void insert(struct test_node *node, struct rb_root_cached *root)
     33{
     34	struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
     35	u32 key = node->key;
     36
     37	while (*new) {
     38		parent = *new;
     39		if (key < rb_entry(parent, struct test_node, rb)->key)
     40			new = &parent->rb_left;
     41		else
     42			new = &parent->rb_right;
     43	}
     44
     45	rb_link_node(&node->rb, parent, new);
     46	rb_insert_color(&node->rb, &root->rb_root);
     47}
     48
     49static void insert_cached(struct test_node *node, struct rb_root_cached *root)
     50{
     51	struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
     52	u32 key = node->key;
     53	bool leftmost = true;
     54
     55	while (*new) {
     56		parent = *new;
     57		if (key < rb_entry(parent, struct test_node, rb)->key)
     58			new = &parent->rb_left;
     59		else {
     60			new = &parent->rb_right;
     61			leftmost = false;
     62		}
     63	}
     64
     65	rb_link_node(&node->rb, parent, new);
     66	rb_insert_color_cached(&node->rb, root, leftmost);
     67}
     68
     69static inline void erase(struct test_node *node, struct rb_root_cached *root)
     70{
     71	rb_erase(&node->rb, &root->rb_root);
     72}
     73
     74static inline void erase_cached(struct test_node *node, struct rb_root_cached *root)
     75{
     76	rb_erase_cached(&node->rb, root);
     77}
     78
     79
/* The per-node value aggregated by the augmented tree: just ->val. */
#define NODE_VAL(node) ((node)->val)

/*
 * Generate the static augment_callbacks passed to the *_augmented insert
 * and erase helpers. As the _MAX variant name suggests, ->augmented is kept
 * equal to the largest NODE_VAL() in the node's subtree; check_augmented()
 * verifies exactly that invariant.
 */
RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
			 struct test_node, rb, u32, augmented, NODE_VAL)
     84
     85static void insert_augmented(struct test_node *node,
     86			     struct rb_root_cached *root)
     87{
     88	struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
     89	u32 key = node->key;
     90	u32 val = node->val;
     91	struct test_node *parent;
     92
     93	while (*new) {
     94		rb_parent = *new;
     95		parent = rb_entry(rb_parent, struct test_node, rb);
     96		if (parent->augmented < val)
     97			parent->augmented = val;
     98		if (key < parent->key)
     99			new = &parent->rb.rb_left;
    100		else
    101			new = &parent->rb.rb_right;
    102	}
    103
    104	node->augmented = val;
    105	rb_link_node(&node->rb, rb_parent, new);
    106	rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks);
    107}
    108
    109static void insert_augmented_cached(struct test_node *node,
    110				    struct rb_root_cached *root)
    111{
    112	struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
    113	u32 key = node->key;
    114	u32 val = node->val;
    115	struct test_node *parent;
    116	bool leftmost = true;
    117
    118	while (*new) {
    119		rb_parent = *new;
    120		parent = rb_entry(rb_parent, struct test_node, rb);
    121		if (parent->augmented < val)
    122			parent->augmented = val;
    123		if (key < parent->key)
    124			new = &parent->rb.rb_left;
    125		else {
    126			new = &parent->rb.rb_right;
    127			leftmost = false;
    128		}
    129	}
    130
    131	node->augmented = val;
    132	rb_link_node(&node->rb, rb_parent, new);
    133	rb_insert_augmented_cached(&node->rb, root,
    134				   leftmost, &augment_callbacks);
    135}
    136
    137
    138static void erase_augmented(struct test_node *node, struct rb_root_cached *root)
    139{
    140	rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks);
    141}
    142
    143static void erase_augmented_cached(struct test_node *node,
    144				   struct rb_root_cached *root)
    145{
    146	rb_erase_augmented_cached(&node->rb, root, &augment_callbacks);
    147}
    148
    149static void init(void)
    150{
    151	int i;
    152	for (i = 0; i < nnodes; i++) {
    153		nodes[i].key = prandom_u32_state(&rnd);
    154		nodes[i].val = prandom_u32_state(&rnd);
    155	}
    156}
    157
    158static bool is_red(struct rb_node *rb)
    159{
    160	return !(rb->__rb_parent_color & 1);
    161}
    162
    163static int black_path_count(struct rb_node *rb)
    164{
    165	int count;
    166	for (count = 0; rb; rb = rb_parent(rb))
    167		count += !is_red(rb);
    168	return count;
    169}
    170
    171static void check_postorder_foreach(int nr_nodes)
    172{
    173	struct test_node *cur, *n;
    174	int count = 0;
    175	rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb)
    176		count++;
    177
    178	WARN_ON_ONCE(count != nr_nodes);
    179}
    180
    181static void check_postorder(int nr_nodes)
    182{
    183	struct rb_node *rb;
    184	int count = 0;
    185	for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb))
    186		count++;
    187
    188	WARN_ON_ONCE(count != nr_nodes);
    189}
    190
/*
 * Walk the global tree in order and WARN_ON_ONCE() on any broken invariant:
 *  - keys are non-decreasing (equal keys are allowed),
 *  - a red node is never the root and never has a red parent,
 *  - every node with at least one NULL child has the same black path count
 *    to the root as the first in-order node,
 *  - the node count equals @nr_nodes and is at least 2^bh - 1, where bh is
 *    the black path count of the last node.
 * Finally, re-count the nodes with both postorder iterators.
 *
 * @nr_nodes: number of nodes the caller expects to be in the tree.
 */
static void check(int nr_nodes)
{
	struct rb_node *rb;
	int count = 0, blacks = 0;
	u32 prev_key = 0;

	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
		struct test_node *node = rb_entry(rb, struct test_node, rb);
		WARN_ON_ONCE(node->key < prev_key);
		WARN_ON_ONCE(is_red(rb) &&
			     (!rb_parent(rb) || is_red(rb_parent(rb))));
		/*
		 * The first in-order node is the minimum and so has no left
		 * child; its black path count is the reference value that every
		 * other node with a NULL child must match.
		 */
		if (!count)
			blacks = black_path_count(rb);
		else
			WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
				     blacks != black_path_count(rb));
		prev_key = node->key;
		count++;
	}

	WARN_ON_ONCE(count != nr_nodes);
	/* A red-black tree of black height bh holds at least 2^bh - 1 nodes. */
	WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1);

	check_postorder(nr_nodes);
	check_postorder_foreach(nr_nodes);
}
    217
    218static void check_augmented(int nr_nodes)
    219{
    220	struct rb_node *rb;
    221
    222	check(nr_nodes);
    223	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
    224		struct test_node *node = rb_entry(rb, struct test_node, rb);
    225		u32 subtree, max = node->val;
    226		if (node->rb.rb_left) {
    227			subtree = rb_entry(node->rb.rb_left, struct test_node,
    228					   rb)->augmented;
    229			if (max < subtree)
    230				max = subtree;
    231		}
    232		if (node->rb.rb_right) {
    233			subtree = rb_entry(node->rb.rb_right, struct test_node,
    234					   rb)->augmented;
    235			if (max < subtree)
    236				max = subtree;
    237		}
    238		WARN_ON_ONCE(node->augmented != max);
    239	}
    240}
    241
    242static int __init rbtree_test_init(void)
    243{
    244	int i, j;
    245	cycles_t time1, time2, time;
    246	struct rb_node *node;
    247
    248	nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL);
    249	if (!nodes)
    250		return -ENOMEM;
    251
    252	printk(KERN_ALERT "rbtree testing");
    253
    254	prandom_seed_state(&rnd, 3141592653589793238ULL);
    255	init();
    256
    257	time1 = get_cycles();
    258
    259	for (i = 0; i < perf_loops; i++) {
    260		for (j = 0; j < nnodes; j++)
    261			insert(nodes + j, &root);
    262		for (j = 0; j < nnodes; j++)
    263			erase(nodes + j, &root);
    264	}
    265
    266	time2 = get_cycles();
    267	time = time2 - time1;
    268
    269	time = div_u64(time, perf_loops);
    270	printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
    271	       (unsigned long long)time);
    272
    273	time1 = get_cycles();
    274
    275	for (i = 0; i < perf_loops; i++) {
    276		for (j = 0; j < nnodes; j++)
    277			insert_cached(nodes + j, &root);
    278		for (j = 0; j < nnodes; j++)
    279			erase_cached(nodes + j, &root);
    280	}
    281
    282	time2 = get_cycles();
    283	time = time2 - time1;
    284
    285	time = div_u64(time, perf_loops);
    286	printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
    287	       (unsigned long long)time);
    288
    289	for (i = 0; i < nnodes; i++)
    290		insert(nodes + i, &root);
    291
    292	time1 = get_cycles();
    293
    294	for (i = 0; i < perf_loops; i++) {
    295		for (node = rb_first(&root.rb_root); node; node = rb_next(node))
    296			;
    297	}
    298
    299	time2 = get_cycles();
    300	time = time2 - time1;
    301
    302	time = div_u64(time, perf_loops);
    303	printk(" -> test 3 (latency of inorder traversal): %llu cycles\n",
    304	       (unsigned long long)time);
    305
    306	time1 = get_cycles();
    307
    308	for (i = 0; i < perf_loops; i++)
    309		node = rb_first(&root.rb_root);
    310
    311	time2 = get_cycles();
    312	time = time2 - time1;
    313
    314	time = div_u64(time, perf_loops);
    315	printk(" -> test 4 (latency to fetch first node)\n");
    316	printk("        non-cached: %llu cycles\n", (unsigned long long)time);
    317
    318	time1 = get_cycles();
    319
    320	for (i = 0; i < perf_loops; i++)
    321		node = rb_first_cached(&root);
    322
    323	time2 = get_cycles();
    324	time = time2 - time1;
    325
    326	time = div_u64(time, perf_loops);
    327	printk("        cached: %llu cycles\n", (unsigned long long)time);
    328
    329	for (i = 0; i < nnodes; i++)
    330		erase(nodes + i, &root);
    331
    332	/* run checks */
    333	for (i = 0; i < check_loops; i++) {
    334		init();
    335		for (j = 0; j < nnodes; j++) {
    336			check(j);
    337			insert(nodes + j, &root);
    338		}
    339		for (j = 0; j < nnodes; j++) {
    340			check(nnodes - j);
    341			erase(nodes + j, &root);
    342		}
    343		check(0);
    344	}
    345
    346	printk(KERN_ALERT "augmented rbtree testing");
    347
    348	init();
    349
    350	time1 = get_cycles();
    351
    352	for (i = 0; i < perf_loops; i++) {
    353		for (j = 0; j < nnodes; j++)
    354			insert_augmented(nodes + j, &root);
    355		for (j = 0; j < nnodes; j++)
    356			erase_augmented(nodes + j, &root);
    357	}
    358
    359	time2 = get_cycles();
    360	time = time2 - time1;
    361
    362	time = div_u64(time, perf_loops);
    363	printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time);
    364
    365	time1 = get_cycles();
    366
    367	for (i = 0; i < perf_loops; i++) {
    368		for (j = 0; j < nnodes; j++)
    369			insert_augmented_cached(nodes + j, &root);
    370		for (j = 0; j < nnodes; j++)
    371			erase_augmented_cached(nodes + j, &root);
    372	}
    373
    374	time2 = get_cycles();
    375	time = time2 - time1;
    376
    377	time = div_u64(time, perf_loops);
    378	printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time);
    379
    380	for (i = 0; i < check_loops; i++) {
    381		init();
    382		for (j = 0; j < nnodes; j++) {
    383			check_augmented(j);
    384			insert_augmented(nodes + j, &root);
    385		}
    386		for (j = 0; j < nnodes; j++) {
    387			check_augmented(nnodes - j);
    388			erase_augmented(nodes + j, &root);
    389		}
    390		check_augmented(0);
    391	}
    392
    393	kfree(nodes);
    394
    395	return -EAGAIN; /* Fail will directly unload the module */
    396}
    397
/*
 * Module exit hook: logs "test exit" at KERN_ALERT.
 * NOTE(review): rbtree_test_init() returns a negative error code on every
 * path, so the module presumably never stays loaded and this hook is not
 * expected to run in practice — confirm before relying on it.
 */
static void __exit rbtree_test_exit(void)
{
	printk(KERN_ALERT "test exit\n");
}
    402
/* Register the module's entry/exit hooks and its metadata. */
module_init(rbtree_test_init)
module_exit(rbtree_test_exit)

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Michel Lespinasse");
MODULE_DESCRIPTION("Red Black Tree test");