cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

test_objagg.c (25159B)


      1// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
      2/* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
      3
      4#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
      5
      6#include <linux/kernel.h>
      7#include <linux/module.h>
      8#include <linux/slab.h>
      9#include <linux/random.h>
     10#include <linux/objagg.h>
     11
/*
 * Key object handed to objagg in this test. objagg is told to treat
 * objects as sizeof(struct tokey) bytes (see .obj_size in the ops tables).
 */
struct tokey {
	unsigned int id;	/* numeric key id; deltas are derived from id differences */
};
     15
     16#define NUM_KEYS 32
     17
     18static int key_id_index(unsigned int key_id)
     19{
     20	if (key_id >= NUM_KEYS) {
     21		WARN_ON(1);
     22		return 0;
     23	}
     24	return key_id;
     25}
     26
/* Size in bytes of the payload buffer stored in every root. */
#define BUF_LEN 128

/*
 * Test-side bookkeeping. A pointer to this structure is passed to
 * objagg_create() and comes back as @priv in the ops callbacks.
 */
struct world {
	/* Incremented by root_create(), decremented by root_destroy(). */
	unsigned int root_count;
	/* Incremented by delta_create(), decremented by delta_destroy(). */
	unsigned int delta_count;
	/* Payload that root_create() copies into the next root it builds. */
	char next_root_buf[BUF_LEN];
	/* Object last returned for each key slot (see key_id_index()). */
	struct objagg_obj *objagg_objs[NUM_KEYS];
	/* Number of references the test currently holds per key slot. */
	unsigned int key_refs[NUM_KEYS];
};
     36
/*
 * Private data returned by root_create().
 *
 * @key has to remain the first member: the pointer obtained from
 * objagg_obj_root_priv() is read as a struct root in
 * test_nodelta_obj_get() and as a struct tokey in obj_to_key_id().
 */
struct root {
	struct tokey key;	/* copy of the key the root was created for */
	char buf[BUF_LEN];	/* snapshot of world->next_root_buf at creation time */
};
     41
/* Private data returned by delta_create(). */
struct delta {
	unsigned int key_id_diff;	/* key id minus parent key id */
};
     45
     46static struct objagg_obj *world_obj_get(struct world *world,
     47					struct objagg *objagg,
     48					unsigned int key_id)
     49{
     50	struct objagg_obj *objagg_obj;
     51	struct tokey key;
     52	int err;
     53
     54	key.id = key_id;
     55	objagg_obj = objagg_obj_get(objagg, &key);
     56	if (IS_ERR(objagg_obj)) {
     57		pr_err("Key %u: Failed to get object.\n", key_id);
     58		return objagg_obj;
     59	}
     60	if (!world->key_refs[key_id_index(key_id)]) {
     61		world->objagg_objs[key_id_index(key_id)] = objagg_obj;
     62	} else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
     63		pr_err("Key %u: God another object for the same key.\n",
     64		       key_id);
     65		err = -EINVAL;
     66		goto err_key_id_check;
     67	}
     68	world->key_refs[key_id_index(key_id)]++;
     69	return objagg_obj;
     70
     71err_key_id_check:
     72	objagg_obj_put(objagg, objagg_obj);
     73	return ERR_PTR(err);
     74}
     75
     76static void world_obj_put(struct world *world, struct objagg *objagg,
     77			  unsigned int key_id)
     78{
     79	struct objagg_obj *objagg_obj;
     80
     81	if (!world->key_refs[key_id_index(key_id)])
     82		return;
     83	objagg_obj = world->objagg_objs[key_id_index(key_id)];
     84	objagg_obj_put(objagg, objagg_obj);
     85	world->key_refs[key_id_index(key_id)]--;
     86}
     87
     88#define MAX_KEY_ID_DIFF 5
     89
     90static bool delta_check(void *priv, const void *parent_obj, const void *obj)
     91{
     92	const struct tokey *parent_key = parent_obj;
     93	const struct tokey *key = obj;
     94	int diff = key->id - parent_key->id;
     95
     96	return diff >= 0 && diff <= MAX_KEY_ID_DIFF;
     97}
     98
     99static void *delta_create(void *priv, void *parent_obj, void *obj)
    100{
    101	struct tokey *parent_key = parent_obj;
    102	struct world *world = priv;
    103	struct tokey *key = obj;
    104	int diff = key->id - parent_key->id;
    105	struct delta *delta;
    106
    107	if (!delta_check(priv, parent_obj, obj))
    108		return ERR_PTR(-EINVAL);
    109
    110	delta = kzalloc(sizeof(*delta), GFP_KERNEL);
    111	if (!delta)
    112		return ERR_PTR(-ENOMEM);
    113	delta->key_id_diff = diff;
    114	world->delta_count++;
    115	return delta;
    116}
    117
    118static void delta_destroy(void *priv, void *delta_priv)
    119{
    120	struct delta *delta = delta_priv;
    121	struct world *world = priv;
    122
    123	world->delta_count--;
    124	kfree(delta);
    125}
    126
    127static void *root_create(void *priv, void *obj, unsigned int id)
    128{
    129	struct world *world = priv;
    130	struct tokey *key = obj;
    131	struct root *root;
    132
    133	root = kzalloc(sizeof(*root), GFP_KERNEL);
    134	if (!root)
    135		return ERR_PTR(-ENOMEM);
    136	memcpy(&root->key, key, sizeof(root->key));
    137	memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
    138	world->root_count++;
    139	return root;
    140}
    141
    142static void root_destroy(void *priv, void *root_priv)
    143{
    144	struct root *root = root_priv;
    145	struct world *world = priv;
    146
    147	world->root_count--;
    148	kfree(root);
    149}
    150
    151static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
    152				unsigned int key_id, bool should_create_root)
    153{
    154	unsigned int orig_root_count = world->root_count;
    155	struct objagg_obj *objagg_obj;
    156	const struct root *root;
    157	int err;
    158
    159	if (should_create_root)
    160		prandom_bytes(world->next_root_buf,
    161			      sizeof(world->next_root_buf));
    162
    163	objagg_obj = world_obj_get(world, objagg, key_id);
    164	if (IS_ERR(objagg_obj)) {
    165		pr_err("Key %u: Failed to get object.\n", key_id);
    166		return PTR_ERR(objagg_obj);
    167	}
    168	if (should_create_root) {
    169		if (world->root_count != orig_root_count + 1) {
    170			pr_err("Key %u: Root was not created\n", key_id);
    171			err = -EINVAL;
    172			goto err_check_root_count;
    173		}
    174	} else {
    175		if (world->root_count != orig_root_count) {
    176			pr_err("Key %u: Root was incorrectly created\n",
    177			       key_id);
    178			err = -EINVAL;
    179			goto err_check_root_count;
    180		}
    181	}
    182	root = objagg_obj_root_priv(objagg_obj);
    183	if (root->key.id != key_id) {
    184		pr_err("Key %u: Root has unexpected key id\n", key_id);
    185		err = -EINVAL;
    186		goto err_check_key_id;
    187	}
    188	if (should_create_root &&
    189	    memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
    190		pr_err("Key %u: Buffer does not match the expected content\n",
    191		       key_id);
    192		err = -EINVAL;
    193		goto err_check_buf;
    194	}
    195	return 0;
    196
    197err_check_buf:
    198err_check_key_id:
    199err_check_root_count:
    200	objagg_obj_put(objagg, objagg_obj);
    201	return err;
    202}
    203
    204static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
    205				unsigned int key_id, bool should_destroy_root)
    206{
    207	unsigned int orig_root_count = world->root_count;
    208
    209	world_obj_put(world, objagg, key_id);
    210
    211	if (should_destroy_root) {
    212		if (world->root_count != orig_root_count - 1) {
    213			pr_err("Key %u: Root was not destroyed\n", key_id);
    214			return -EINVAL;
    215		}
    216	} else {
    217		if (world->root_count != orig_root_count) {
    218			pr_err("Key %u: Root was incorrectly destroyed\n",
    219			       key_id);
    220			return -EINVAL;
    221		}
    222	}
    223	return 0;
    224}
    225
    226static int check_stats_zero(struct objagg *objagg)
    227{
    228	const struct objagg_stats *stats;
    229	int err = 0;
    230
    231	stats = objagg_stats_get(objagg);
    232	if (IS_ERR(stats))
    233		return PTR_ERR(stats);
    234
    235	if (stats->stats_info_count != 0) {
    236		pr_err("Stats: Object count is not zero while it should be\n");
    237		err = -EINVAL;
    238	}
    239
    240	objagg_stats_put(stats);
    241	return err;
    242}
    243
    244static int check_stats_nodelta(struct objagg *objagg)
    245{
    246	const struct objagg_stats *stats;
    247	int i;
    248	int err;
    249
    250	stats = objagg_stats_get(objagg);
    251	if (IS_ERR(stats))
    252		return PTR_ERR(stats);
    253
    254	if (stats->stats_info_count != NUM_KEYS) {
    255		pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
    256		       NUM_KEYS, stats->stats_info_count);
    257		err = -EINVAL;
    258		goto stats_put;
    259	}
    260
    261	for (i = 0; i < stats->stats_info_count; i++) {
    262		if (stats->stats_info[i].stats.user_count != 2) {
    263			pr_err("Stats: incorrect user count\n");
    264			err = -EINVAL;
    265			goto stats_put;
    266		}
    267		if (stats->stats_info[i].stats.delta_user_count != 2) {
    268			pr_err("Stats: incorrect delta user count\n");
    269			err = -EINVAL;
    270			goto stats_put;
    271		}
    272	}
    273	err = 0;
    274
    275stats_put:
    276	objagg_stats_put(stats);
    277	return err;
    278}
    279
/*
 * delta_check callback for the no-delta test: never allows an object to
 * become a delta, whatever the arguments are.
 */
static bool delta_check_dummy(void *priv, const void *parent_obj,
			      const void *obj)
{
	return false;
}
    285
/*
 * delta_create callback for the no-delta test: always fails with
 * -EOPNOTSUPP.
 */
static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
{
	return ERR_PTR(-EOPNOTSUPP);
}
    290
/* delta_destroy callback for the no-delta test: intentionally a no-op. */
static void delta_destroy_dummy(void *priv, void *delta_priv)
{
}
    294
    295static const struct objagg_ops nodelta_ops = {
    296	.obj_size = sizeof(struct tokey),
    297	.delta_check = delta_check_dummy,
    298	.delta_create = delta_create_dummy,
    299	.delta_destroy = delta_destroy_dummy,
    300	.root_create = root_create,
    301	.root_destroy = root_destroy,
    302};
    303
/*
 * No-delta scenario: with nodelta_ops every key must get its own root.
 *
 * Gets every key twice (roots created only in the first round), checks
 * the stats, then puts every key twice (roots destroyed only in the
 * second round) and checks that the stats are empty again.
 *
 * Returns 0 on success or the first error encountered. The error labels
 * rely on the value @i has when the failure is detected, so the order of
 * the unwinding code below matters.
 */
static int test_nodelta(void)
{
	struct world world = {};
	struct objagg *objagg;
	int i;
	int err;

	objagg = objagg_create(&nodelta_ops, NULL, &world);
	if (IS_ERR(objagg))
		return PTR_ERR(objagg);

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_first_zero;

	/* First round of gets, the root objects should be created */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, true);
		if (err)
			goto err_obj_first_get;
	}

	/* Do the second round of gets, all roots are already created,
	 * make sure that no new root is created
	 */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, false);
		if (err)
			goto err_obj_second_get;
	}

	err = check_stats_nodelta(objagg);
	if (err)
		goto err_stats_nodelta;

	/* First round of puts: every key still has one user left. */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, false);
		if (err)
			goto err_obj_first_put;
	}
	/* Second round of puts: the last user goes away, roots must too. */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, true);
		if (err)
			goto err_obj_second_put;
	}

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_second_zero;

	objagg_destroy(objagg);
	return 0;

err_stats_nodelta:
err_obj_first_put:
err_obj_second_get:
	/* Drop one reference for every key below the current index ... */
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);

	/* ... then fall through to drop one more reference for all keys. */
	i = NUM_KEYS;
err_obj_first_get:
err_obj_second_put:
	/* world_obj_put() ignores keys that hold no reference any more. */
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);
err_stats_first_zero:
err_stats_second_zero:
	objagg_destroy(objagg);
	return err;
}
    373
    374static const struct objagg_ops delta_ops = {
    375	.obj_size = sizeof(struct tokey),
    376	.delta_check = delta_check,
    377	.delta_create = delta_create,
    378	.delta_destroy = delta_destroy,
    379	.root_create = root_create,
    380	.root_destroy = root_destroy,
    381};
    382
/* Operation performed by one step of the delta test. */
enum action {
	ACTION_GET,	/* take a reference via world_obj_get() */
	ACTION_PUT,	/* drop a reference via world_obj_put() */
};
    387
/* Expected change of world->delta_count caused by one step. */
enum expect_delta {
	EXPECT_DELTA_SAME,	/* count must not change */
	EXPECT_DELTA_INC,	/* count must grow by one; only valid for a get */
	EXPECT_DELTA_DEC,	/* count must shrink by one; only valid for a put */
};
    393
/* Expected change of world->root_count caused by one step. */
enum expect_root {
	EXPECT_ROOT_SAME,	/* count must not change */
	EXPECT_ROOT_INC,	/* count must grow by one; only valid for a get */
	EXPECT_ROOT_DEC,	/* count must shrink by one; only valid for a put */
};
    399
/*
 * Expected content of one stats entry.
 *
 * Field order matters: the ROOT() and DELTA() macros fill this structure
 * with positional initializers.
 */
struct expect_stats_info {
	struct objagg_obj_stats stats;	/* expected user counters */
	bool is_root;			/* expected root/delta indication */
	unsigned int key_id;		/* expected key id of the object */
};
    405
/* Expected stats: @info_count valid entries at the start of @info. */
struct expect_stats {
	unsigned int info_count;
	struct expect_stats_info info[NUM_KEYS];
};
    410
/*
 * One step of the delta test: what to do, how the root/delta counters
 * should react and what the stats must look like afterwards.
 *
 * Field order matters: action_items[] uses positional initializers.
 */
struct action_item {
	unsigned int key_id;
	enum action action;
	enum expect_delta expect_delta;
	enum expect_root expect_root;
	struct expect_stats expect_stats;
};
    418
/*
 * Build a struct expect_stats initializer: @count entries followed by the
 * entries themselves (ROOT()/DELTA() items).
 */
#define EXPECT_STATS(count, ...)		\
{						\
	.info_count = count,			\
	.info = { __VA_ARGS__ }			\
}

/*
 * Positional struct expect_stats_info initializer for a root entry.
 * NOTE(review): the inner braces rely on struct objagg_obj_stats listing
 * user_count before delta_user_count - confirm against its declaration.
 */
#define ROOT(key_id, user_count, delta_user_count)	\
	{{user_count, delta_user_count}, true, key_id}

/*
 * Positional struct expect_stats_info initializer for a delta entry; both
 * counters are set to @user_count.
 */
#define DELTA(key_id, user_count)			\
	{{user_count, user_count}, false, key_id}
    430
/*
 * Script executed by test_delta(), one entry per step.
 *
 * The trailing comment of each entry sketches the state after the step:
 * "r:" appears to list root key ids once per user, "d:" lists deltas as
 * <key>^<parent key>, again once per user.
 */
static const struct action_item action_items[] = {
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(1, ROOT(1, 1, 1)),
	},	/* r: 1			d: */
	{
		7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
	},	/* r: 1, 7		d: */
	{
		3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
				DELTA(3, 1)),
	},	/* r: 1, 7		d: 3^1 */
	{
		5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
				DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 5^1 */
	{
		3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 3^1, 5^1 */
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7		d: 3^1, 3^1, 5^1 */
	{
		30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 7, 30		d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 7, 30		d: 5^1, 8^7, 8^7 */
	{
		5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30		d: 8^7, 8^7 */
	{
		5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7 */
	{
		6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 1), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 7, 30, 5		d: 6^5, 8^5 */
	{
		7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 30, 5		d: 6^5, 8^5 */
	{
		30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(5, 1, 3),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 5			d: 6^5, 8^5 */
	{
		5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(5, 0, 2),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r:			d: 6^5, 8^5 */
	{
		6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(2, ROOT(5, 0, 1),
				DELTA(8, 1)),
	},	/* r:			d: 8^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(0, ),
	},	/* r:			d: */
};
    560
    561static int check_expect(struct world *world,
    562			const struct action_item *action_item,
    563			unsigned int orig_delta_count,
    564			unsigned int orig_root_count)
    565{
    566	unsigned int key_id = action_item->key_id;
    567
    568	switch (action_item->expect_delta) {
    569	case EXPECT_DELTA_SAME:
    570		if (orig_delta_count != world->delta_count) {
    571			pr_err("Key %u: Delta count changed while expected to remain the same.\n",
    572			       key_id);
    573			return -EINVAL;
    574		}
    575		break;
    576	case EXPECT_DELTA_INC:
    577		if (WARN_ON(action_item->action == ACTION_PUT))
    578			return -EINVAL;
    579		if (orig_delta_count + 1 != world->delta_count) {
    580			pr_err("Key %u: Delta count was not incremented.\n",
    581			       key_id);
    582			return -EINVAL;
    583		}
    584		break;
    585	case EXPECT_DELTA_DEC:
    586		if (WARN_ON(action_item->action == ACTION_GET))
    587			return -EINVAL;
    588		if (orig_delta_count - 1 != world->delta_count) {
    589			pr_err("Key %u: Delta count was not decremented.\n",
    590			       key_id);
    591			return -EINVAL;
    592		}
    593		break;
    594	}
    595
    596	switch (action_item->expect_root) {
    597	case EXPECT_ROOT_SAME:
    598		if (orig_root_count != world->root_count) {
    599			pr_err("Key %u: Root count changed while expected to remain the same.\n",
    600			       key_id);
    601			return -EINVAL;
    602		}
    603		break;
    604	case EXPECT_ROOT_INC:
    605		if (WARN_ON(action_item->action == ACTION_PUT))
    606			return -EINVAL;
    607		if (orig_root_count + 1 != world->root_count) {
    608			pr_err("Key %u: Root count was not incremented.\n",
    609			       key_id);
    610			return -EINVAL;
    611		}
    612		break;
    613	case EXPECT_ROOT_DEC:
    614		if (WARN_ON(action_item->action == ACTION_GET))
    615			return -EINVAL;
    616		if (orig_root_count - 1 != world->root_count) {
    617			pr_err("Key %u: Root count was not decremented.\n",
    618			       key_id);
    619			return -EINVAL;
    620		}
    621	}
    622
    623	return 0;
    624}
    625
    626static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
    627{
    628	const struct tokey *root_key;
    629	const struct delta *delta;
    630	unsigned int key_id;
    631
    632	root_key = objagg_obj_root_priv(objagg_obj);
    633	key_id = root_key->id;
    634	delta = objagg_obj_delta_priv(objagg_obj);
    635	if (delta)
    636		key_id += delta->key_id_diff;
    637	return key_id;
    638}
    639
    640static int
    641check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
    642			const struct expect_stats_info *expect_stats_info,
    643			const char **errmsg)
    644{
    645	if (stats_info->is_root != expect_stats_info->is_root) {
    646		if (errmsg)
    647			*errmsg = "Incorrect root/delta indication";
    648		return -EINVAL;
    649	}
    650	if (stats_info->stats.user_count !=
    651	    expect_stats_info->stats.user_count) {
    652		if (errmsg)
    653			*errmsg = "Incorrect user count";
    654		return -EINVAL;
    655	}
    656	if (stats_info->stats.delta_user_count !=
    657	    expect_stats_info->stats.delta_user_count) {
    658		if (errmsg)
    659			*errmsg = "Incorrect delta user count";
    660		return -EINVAL;
    661	}
    662	return 0;
    663}
    664
    665static int
    666check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
    667			  const struct expect_stats_info *expect_stats_info,
    668			  const char **errmsg)
    669{
    670	if (obj_to_key_id(stats_info->objagg_obj) !=
    671	    expect_stats_info->key_id) {
    672		if (errmsg)
    673			*errmsg = "incorrect key id";
    674		return -EINVAL;
    675	}
    676	return 0;
    677}
    678
    679static int check_expect_stats_neigh(const struct objagg_stats *stats,
    680				    const struct expect_stats *expect_stats,
    681				    int pos)
    682{
    683	int i;
    684	int err;
    685
    686	for (i = pos - 1; i >= 0; i--) {
    687		err = check_expect_stats_nums(&stats->stats_info[i],
    688					      &expect_stats->info[pos], NULL);
    689		if (err)
    690			break;
    691		err = check_expect_stats_key_id(&stats->stats_info[i],
    692						&expect_stats->info[pos], NULL);
    693		if (!err)
    694			return 0;
    695	}
    696	for (i = pos + 1; i < stats->stats_info_count; i++) {
    697		err = check_expect_stats_nums(&stats->stats_info[i],
    698					      &expect_stats->info[pos], NULL);
    699		if (err)
    700			break;
    701		err = check_expect_stats_key_id(&stats->stats_info[i],
    702						&expect_stats->info[pos], NULL);
    703		if (!err)
    704			return 0;
    705	}
    706	return -EINVAL;
    707}
    708
    709static int __check_expect_stats(const struct objagg_stats *stats,
    710				const struct expect_stats *expect_stats,
    711				const char **errmsg)
    712{
    713	int i;
    714	int err;
    715
    716	if (stats->stats_info_count != expect_stats->info_count) {
    717		*errmsg = "Unexpected object count";
    718		return -EINVAL;
    719	}
    720
    721	for (i = 0; i < stats->stats_info_count; i++) {
    722		err = check_expect_stats_nums(&stats->stats_info[i],
    723					      &expect_stats->info[i], errmsg);
    724		if (err)
    725			return err;
    726		err = check_expect_stats_key_id(&stats->stats_info[i],
    727						&expect_stats->info[i], errmsg);
    728		if (err) {
    729			/* It is possible that one of the neighbor stats with
    730			 * same numbers have the correct key id, so check it
    731			 */
    732			err = check_expect_stats_neigh(stats, expect_stats, i);
    733			if (err)
    734				return err;
    735		}
    736	}
    737	return 0;
    738}
    739
    740static int check_expect_stats(struct objagg *objagg,
    741			      const struct expect_stats *expect_stats,
    742			      const char **errmsg)
    743{
    744	const struct objagg_stats *stats;
    745	int err;
    746
    747	stats = objagg_stats_get(objagg);
    748	if (IS_ERR(stats)) {
    749		*errmsg = "objagg_stats_get() failed.";
    750		return PTR_ERR(stats);
    751	}
    752	err = __check_expect_stats(stats, expect_stats, errmsg);
    753	objagg_stats_put(stats);
    754	return err;
    755}
    756
    757static int test_delta_action_item(struct world *world,
    758				  struct objagg *objagg,
    759				  const struct action_item *action_item,
    760				  bool inverse)
    761{
    762	unsigned int orig_delta_count = world->delta_count;
    763	unsigned int orig_root_count = world->root_count;
    764	unsigned int key_id = action_item->key_id;
    765	enum action action = action_item->action;
    766	struct objagg_obj *objagg_obj;
    767	const char *errmsg;
    768	int err;
    769
    770	if (inverse)
    771		action = action == ACTION_GET ? ACTION_PUT : ACTION_GET;
    772
    773	switch (action) {
    774	case ACTION_GET:
    775		objagg_obj = world_obj_get(world, objagg, key_id);
    776		if (IS_ERR(objagg_obj))
    777			return PTR_ERR(objagg_obj);
    778		break;
    779	case ACTION_PUT:
    780		world_obj_put(world, objagg, key_id);
    781		break;
    782	}
    783
    784	if (inverse)
    785		return 0;
    786	err = check_expect(world, action_item,
    787			   orig_delta_count, orig_root_count);
    788	if (err)
    789		goto errout;
    790
    791	err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg);
    792	if (err) {
    793		pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg);
    794		goto errout;
    795	}
    796
    797	return 0;
    798
    799errout:
    800	/* This can only happen when action is not inversed.
    801	 * So in case of an error, cleanup by doing inverse action.
    802	 */
    803	test_delta_action_item(world, objagg, action_item, true);
    804	return err;
    805}
    806
    807static int test_delta(void)
    808{
    809	struct world world = {};
    810	struct objagg *objagg;
    811	int i;
    812	int err;
    813
    814	objagg = objagg_create(&delta_ops, NULL, &world);
    815	if (IS_ERR(objagg))
    816		return PTR_ERR(objagg);
    817
    818	for (i = 0; i < ARRAY_SIZE(action_items); i++) {
    819		err = test_delta_action_item(&world, objagg,
    820					     &action_items[i], false);
    821		if (err)
    822			goto err_do_action_item;
    823	}
    824
    825	objagg_destroy(objagg);
    826	return 0;
    827
    828err_do_action_item:
    829	for (i--; i >= 0; i--)
    830		test_delta_action_item(&world, objagg, &action_items[i], true);
    831
    832	objagg_destroy(objagg);
    833	return err;
    834}
    835
/* Description of one hints test scenario. */
struct hints_case {
	const unsigned int *key_ids;		/* key ids to get, in order */
	size_t key_ids_count;			/* number of entries in @key_ids */
	struct expect_stats expect_stats;	/* stats expected without hints */
	/* stats expected from the hints and from an objagg built with them */
	struct expect_stats expect_stats_hints;
};
    842
/* Sequence of key ids requested by the hints scenario (repeats intended). */
static const unsigned int hints_case_key_ids[] = {
	1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8,
};
    846
/*
 * The single hints scenario run by test_hints(): the stats expected after
 * getting hints_case_key_ids[] without hints, and the stats expected both
 * from the computed hints and from a second objagg created with them.
 */
static const struct hints_case hints_case = {
	.key_ids = hints_case_key_ids,
	.key_ids_count = ARRAY_SIZE(hints_case_key_ids),
	.expect_stats =
		EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(3, 2),
				DELTA(5, 2), DELTA(6, 1)),
	.expect_stats_hints =
		EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(5, 2),
				DELTA(6, 1), DELTA(7, 1)),
};
    859
    860static void __pr_debug_stats(const struct objagg_stats *stats)
    861{
    862	int i;
    863
    864	for (i = 0; i < stats->stats_info_count; i++)
    865		pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i,
    866			 obj_to_key_id(stats->stats_info[i].objagg_obj),
    867			 stats->stats_info[i].stats.user_count,
    868			 stats->stats_info[i].stats.delta_user_count,
    869			 stats->stats_info[i].is_root ? "root" : "noroot");
    870}
    871
    872static void pr_debug_stats(struct objagg *objagg)
    873{
    874	const struct objagg_stats *stats;
    875
    876	stats = objagg_stats_get(objagg);
    877	if (IS_ERR(stats))
    878		return;
    879	__pr_debug_stats(stats);
    880	objagg_stats_put(stats);
    881}
    882
    883static void pr_debug_hints_stats(struct objagg_hints *objagg_hints)
    884{
    885	const struct objagg_stats *stats;
    886
    887	stats = objagg_hints_stats_get(objagg_hints);
    888	if (IS_ERR(stats))
    889		return;
    890	__pr_debug_stats(stats);
    891	objagg_stats_put(stats);
    892}
    893
    894static int check_expect_hints_stats(struct objagg_hints *objagg_hints,
    895				    const struct expect_stats *expect_stats,
    896				    const char **errmsg)
    897{
    898	const struct objagg_stats *stats;
    899	int err;
    900
    901	stats = objagg_hints_stats_get(objagg_hints);
    902	if (IS_ERR(stats))
    903		return PTR_ERR(stats);
    904	err = __check_expect_stats(stats, expect_stats, errmsg);
    905	objagg_stats_put(stats);
    906	return err;
    907}
    908
    909static int test_hints_case(const struct hints_case *hints_case)
    910{
    911	struct objagg_obj *objagg_obj;
    912	struct objagg_hints *hints;
    913	struct world world2 = {};
    914	struct world world = {};
    915	struct objagg *objagg2;
    916	struct objagg *objagg;
    917	const char *errmsg;
    918	int i;
    919	int err;
    920
    921	objagg = objagg_create(&delta_ops, NULL, &world);
    922	if (IS_ERR(objagg))
    923		return PTR_ERR(objagg);
    924
    925	for (i = 0; i < hints_case->key_ids_count; i++) {
    926		objagg_obj = world_obj_get(&world, objagg,
    927					   hints_case->key_ids[i]);
    928		if (IS_ERR(objagg_obj)) {
    929			err = PTR_ERR(objagg_obj);
    930			goto err_world_obj_get;
    931		}
    932	}
    933
    934	pr_debug_stats(objagg);
    935	err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg);
    936	if (err) {
    937		pr_err("Stats: %s\n", errmsg);
    938		goto err_check_expect_stats;
    939	}
    940
    941	hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY);
    942	if (IS_ERR(hints)) {
    943		err = PTR_ERR(hints);
    944		goto err_hints_get;
    945	}
    946
    947	pr_debug_hints_stats(hints);
    948	err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints,
    949				       &errmsg);
    950	if (err) {
    951		pr_err("Hints stats: %s\n", errmsg);
    952		goto err_check_expect_hints_stats;
    953	}
    954
    955	objagg2 = objagg_create(&delta_ops, hints, &world2);
    956	if (IS_ERR(objagg2))
    957		return PTR_ERR(objagg2);
    958
    959	for (i = 0; i < hints_case->key_ids_count; i++) {
    960		objagg_obj = world_obj_get(&world2, objagg2,
    961					   hints_case->key_ids[i]);
    962		if (IS_ERR(objagg_obj)) {
    963			err = PTR_ERR(objagg_obj);
    964			goto err_world2_obj_get;
    965		}
    966	}
    967
    968	pr_debug_stats(objagg2);
    969	err = check_expect_stats(objagg2, &hints_case->expect_stats_hints,
    970				 &errmsg);
    971	if (err) {
    972		pr_err("Stats2: %s\n", errmsg);
    973		goto err_check_expect_stats2;
    974	}
    975
    976	err = 0;
    977
    978err_check_expect_stats2:
    979err_world2_obj_get:
    980	for (i--; i >= 0; i--)
    981		world_obj_put(&world2, objagg, hints_case->key_ids[i]);
    982	i = hints_case->key_ids_count;
    983	objagg_destroy(objagg2);
    984err_check_expect_hints_stats:
    985	objagg_hints_put(hints);
    986err_hints_get:
    987err_check_expect_stats:
    988err_world_obj_get:
    989	for (i--; i >= 0; i--)
    990		world_obj_put(&world, objagg, hints_case->key_ids[i]);
    991
    992	objagg_destroy(objagg);
    993	return err;
    994}
/* Run the hints test; currently a single scenario, hints_case. */
static int test_hints(void)
{
	return test_hints_case(&hints_case);
}
    999
   1000static int __init test_objagg_init(void)
   1001{
   1002	int err;
   1003
   1004	err = test_nodelta();
   1005	if (err)
   1006		return err;
   1007	err = test_delta();
   1008	if (err)
   1009		return err;
   1010	return test_hints();
   1011}
   1012
/* Module exit: nothing to undo, every test cleans up after itself. */
static void __exit test_objagg_exit(void)
{
}
   1016
/* Module registration: the tests run from the init hook at load time. */
module_init(test_objagg_init);
module_exit(test_objagg_exit);
MODULE_LICENSE("Dual BSD/GPL");
MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
MODULE_DESCRIPTION("Test module for objagg");