1// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
2/* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
3
4#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6#include <linux/kernel.h>
7#include <linux/module.h>
8#include <linux/slab.h>
9#include <linux/random.h>
10#include <linux/objagg.h>
11
/* Test object key. objagg is told that objects are sizeof(struct tokey) bytes
 * (see the .obj_size of the ops tables), so this struct is the whole object.
 */
struct tokey {
	unsigned int id;	/* key id; also selects the per-key slot in struct world */
};
15
/* Number of key ids that the per-key arrays in struct world can track. */
#define NUM_KEYS 32

/*
 * Map a key id onto a slot of the per-key arrays.
 *
 * Ids below NUM_KEYS map to themselves. Anything else raises a warning and
 * is redirected to slot 0 so that callers never index out of bounds.
 */
static int key_id_index(unsigned int key_id)
{
	if (key_id < NUM_KEYS)
		return key_id;

	WARN_ON(1);
	return 0;
}
26
/* Size in bytes of the payload buffer carried by every root. */
#define BUF_LEN 128

/* Test-side bookkeeping. A pointer to it is handed to objagg_create() and
 * comes back as the "priv" argument of the ops callbacks.
 */
struct world {
	unsigned int root_count;	/* ++ in root_create(), -- in root_destroy() */
	unsigned int delta_count;	/* ++ in delta_create(), -- in delta_destroy() */
	char next_root_buf[BUF_LEN];	/* copied into the next root by root_create() */
	struct objagg_obj *objagg_objs[NUM_KEYS]; /* object remembered for each key slot */
	unsigned int key_refs[NUM_KEYS]; /* references taken via world_obj_get() per slot */
};
36
/* Private data attached to a root object by root_create(). The key must stay
 * the first member: obj_to_key_id() reads the root private data through a
 * struct tokey pointer.
 */
struct root {
	struct tokey key;	/* copy of the key the root was created for */
	char buf[BUF_LEN];	/* snapshot of world->next_root_buf at creation time */
};
41
/* Private data attached to a delta object by delta_create(). */
struct delta {
	unsigned int key_id_diff;	/* object key id minus parent key id */
};
45
46static struct objagg_obj *world_obj_get(struct world *world,
47					struct objagg *objagg,
48					unsigned int key_id)
49{
50	struct objagg_obj *objagg_obj;
51	struct tokey key;
52	int err;
53
54	key.id = key_id;
55	objagg_obj = objagg_obj_get(objagg, &key);
56	if (IS_ERR(objagg_obj)) {
57		pr_err("Key %u: Failed to get object.\n", key_id);
58		return objagg_obj;
59	}
60	if (!world->key_refs[key_id_index(key_id)]) {
61		world->objagg_objs[key_id_index(key_id)] = objagg_obj;
62	} else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
63		pr_err("Key %u: God another object for the same key.\n",
64		       key_id);
65		err = -EINVAL;
66		goto err_key_id_check;
67	}
68	world->key_refs[key_id_index(key_id)]++;
69	return objagg_obj;
70
71err_key_id_check:
72	objagg_obj_put(objagg, objagg_obj);
73	return ERR_PTR(err);
74}
75
76static void world_obj_put(struct world *world, struct objagg *objagg,
77			  unsigned int key_id)
78{
79	struct objagg_obj *objagg_obj;
80
81	if (!world->key_refs[key_id_index(key_id)])
82		return;
83	objagg_obj = world->objagg_objs[key_id_index(key_id)];
84	objagg_obj_put(objagg, objagg_obj);
85	world->key_refs[key_id_index(key_id)]--;
86}
87
88#define MAX_KEY_ID_DIFF 5
89
90static bool delta_check(void *priv, const void *parent_obj, const void *obj)
91{
92	const struct tokey *parent_key = parent_obj;
93	const struct tokey *key = obj;
94	int diff = key->id - parent_key->id;
95
96	return diff >= 0 && diff <= MAX_KEY_ID_DIFF;
97}
98
99static void *delta_create(void *priv, void *parent_obj, void *obj)
100{
101	struct tokey *parent_key = parent_obj;
102	struct world *world = priv;
103	struct tokey *key = obj;
104	int diff = key->id - parent_key->id;
105	struct delta *delta;
106
107	if (!delta_check(priv, parent_obj, obj))
108		return ERR_PTR(-EINVAL);
109
110	delta = kzalloc(sizeof(*delta), GFP_KERNEL);
111	if (!delta)
112		return ERR_PTR(-ENOMEM);
113	delta->key_id_diff = diff;
114	world->delta_count++;
115	return delta;
116}
117
118static void delta_destroy(void *priv, void *delta_priv)
119{
120	struct delta *delta = delta_priv;
121	struct world *world = priv;
122
123	world->delta_count--;
124	kfree(delta);
125}
126
127static void *root_create(void *priv, void *obj, unsigned int id)
128{
129	struct world *world = priv;
130	struct tokey *key = obj;
131	struct root *root;
132
133	root = kzalloc(sizeof(*root), GFP_KERNEL);
134	if (!root)
135		return ERR_PTR(-ENOMEM);
136	memcpy(&root->key, key, sizeof(root->key));
137	memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
138	world->root_count++;
139	return root;
140}
141
142static void root_destroy(void *priv, void *root_priv)
143{
144	struct root *root = root_priv;
145	struct world *world = priv;
146
147	world->root_count--;
148	kfree(root);
149}
150
151static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
152				unsigned int key_id, bool should_create_root)
153{
154	unsigned int orig_root_count = world->root_count;
155	struct objagg_obj *objagg_obj;
156	const struct root *root;
157	int err;
158
159	if (should_create_root)
160		get_random_bytes(world->next_root_buf,
161			      sizeof(world->next_root_buf));
162
163	objagg_obj = world_obj_get(world, objagg, key_id);
164	if (IS_ERR(objagg_obj)) {
165		pr_err("Key %u: Failed to get object.\n", key_id);
166		return PTR_ERR(objagg_obj);
167	}
168	if (should_create_root) {
169		if (world->root_count != orig_root_count + 1) {
170			pr_err("Key %u: Root was not created\n", key_id);
171			err = -EINVAL;
172			goto err_check_root_count;
173		}
174	} else {
175		if (world->root_count != orig_root_count) {
176			pr_err("Key %u: Root was incorrectly created\n",
177			       key_id);
178			err = -EINVAL;
179			goto err_check_root_count;
180		}
181	}
182	root = objagg_obj_root_priv(objagg_obj);
183	if (root->key.id != key_id) {
184		pr_err("Key %u: Root has unexpected key id\n", key_id);
185		err = -EINVAL;
186		goto err_check_key_id;
187	}
188	if (should_create_root &&
189	    memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
190		pr_err("Key %u: Buffer does not match the expected content\n",
191		       key_id);
192		err = -EINVAL;
193		goto err_check_buf;
194	}
195	return 0;
196
197err_check_buf:
198err_check_key_id:
199err_check_root_count:
200	objagg_obj_put(objagg, objagg_obj);
201	return err;
202}
203
204static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
205				unsigned int key_id, bool should_destroy_root)
206{
207	unsigned int orig_root_count = world->root_count;
208
209	world_obj_put(world, objagg, key_id);
210
211	if (should_destroy_root) {
212		if (world->root_count != orig_root_count - 1) {
213			pr_err("Key %u: Root was not destroyed\n", key_id);
214			return -EINVAL;
215		}
216	} else {
217		if (world->root_count != orig_root_count) {
218			pr_err("Key %u: Root was incorrectly destroyed\n",
219			       key_id);
220			return -EINVAL;
221		}
222	}
223	return 0;
224}
225
226static int check_stats_zero(struct objagg *objagg)
227{
228	const struct objagg_stats *stats;
229	int err = 0;
230
231	stats = objagg_stats_get(objagg);
232	if (IS_ERR(stats))
233		return PTR_ERR(stats);
234
235	if (stats->stats_info_count != 0) {
236		pr_err("Stats: Object count is not zero while it should be\n");
237		err = -EINVAL;
238	}
239
240	objagg_stats_put(stats);
241	return err;
242}
243
244static int check_stats_nodelta(struct objagg *objagg)
245{
246	const struct objagg_stats *stats;
247	int i;
248	int err;
249
250	stats = objagg_stats_get(objagg);
251	if (IS_ERR(stats))
252		return PTR_ERR(stats);
253
254	if (stats->stats_info_count != NUM_KEYS) {
255		pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
256		       NUM_KEYS, stats->stats_info_count);
257		err = -EINVAL;
258		goto stats_put;
259	}
260
261	for (i = 0; i < stats->stats_info_count; i++) {
262		if (stats->stats_info[i].stats.user_count != 2) {
263			pr_err("Stats: incorrect user count\n");
264			err = -EINVAL;
265			goto stats_put;
266		}
267		if (stats->stats_info[i].stats.delta_user_count != 2) {
268			pr_err("Stats: incorrect delta user count\n");
269			err = -EINVAL;
270			goto stats_put;
271		}
272	}
273	err = 0;
274
275stats_put:
276	objagg_stats_put(stats);
277	return err;
278}
279
/*
 * delta_check callback of nodelta_ops: refuses every parent/object pair.
 * All arguments are ignored.
 */
static bool delta_check_dummy(void *priv, const void *parent_obj,
			      const void *obj)
{
	return false;
}
285
286static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
287{
288	return ERR_PTR(-EOPNOTSUPP);
289}
290
/*
 * delta_destroy callback of nodelta_ops: intentionally empty, because
 * delta_create_dummy() never hands out anything that would need freeing.
 */
static void delta_destroy_dummy(void *priv, void *delta_priv)
{
}
294
/* Ops used by test_nodelta(): the delta callbacks are stubs (delta_check
 * always refuses), while roots are created and destroyed for real.
 */
static const struct objagg_ops nodelta_ops = {
	.obj_size = sizeof(struct tokey),
	.delta_check = delta_check_dummy,
	.delta_create = delta_create_dummy,
	.delta_destroy = delta_destroy_dummy,
	.root_create = root_create,
	.root_destroy = root_destroy,
};
303
/*
 * Exercise objagg with nodelta_ops: get every key twice, check the stats,
 * then put every key twice, checking root creation/destruction and the
 * stats along the way.
 *
 * Return: 0 on success or a negative errno from the first failing step.
 * On failure all references still held are dropped and the objagg instance
 * is destroyed before returning.
 */
static int test_nodelta(void)
{
	struct world world = {};
	struct objagg *objagg;
	int i;
	int err;

	objagg = objagg_create(&nodelta_ops, NULL, &world);
	if (IS_ERR(objagg))
		return PTR_ERR(objagg);

	/* A fresh instance must not report any object */
	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_first_zero;

	/* First round of gets, the root objects should be created */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, true);
		if (err)
			goto err_obj_first_get;
	}

	/* Do the second round of gets, all roots are already created,
	 * make sure that no new root is created
	 */
	for (i = 0; i < NUM_KEYS; i++) {
		err = test_nodelta_obj_get(&world, objagg, i, false);
		if (err)
			goto err_obj_second_get;
	}

	err = check_stats_nodelta(objagg);
	if (err)
		goto err_stats_nodelta;

	/* First round of puts: every key still has one reference left, so
	 * no root may be destroyed yet
	 */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, false);
		if (err)
			goto err_obj_first_put;
	}
	/* Second round of puts drops the last references, so every root
	 * must be destroyed
	 */
	for (i = NUM_KEYS - 1; i >= 0; i--) {
		err = test_nodelta_obj_put(&world, objagg, i, true);
		if (err)
			goto err_obj_second_put;
	}

	err = check_stats_zero(objagg);
	if (err)
		goto err_stats_second_zero;

	objagg_destroy(objagg);
	return 0;

	/* Unwinding relies on the value of i at the point of failure. Key i
	 * itself is skipped by the leading i--: a failed get has already
	 * dropped its reference, and a failed put has already done the put.
	 */
err_stats_nodelta:
err_obj_first_put:
err_obj_second_get:
	/* Drop one reference from each of keys i-1..0 */
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);

	/* Then fall through to drop the remaining reference of every key */
	i = NUM_KEYS;
err_obj_first_get:
err_obj_second_put:
	for (i--; i >= 0; i--)
		world_obj_put(&world, objagg, i);
err_stats_first_zero:
err_stats_second_zero:
	objagg_destroy(objagg);
	return err;
}
373
/* Ops used by test_delta() and test_hints_case(): both the delta and the
 * root callbacks are the real implementations that update struct world.
 */
static const struct objagg_ops delta_ops = {
	.obj_size = sizeof(struct tokey),
	.delta_check = delta_check,
	.delta_create = delta_create,
	.delta_destroy = delta_destroy,
	.root_create = root_create,
	.root_destroy = root_destroy,
};
382
/* Operation performed by one step of the delta test. */
enum action {
	ACTION_GET,	/* world_obj_get() on the step's key */
	ACTION_PUT,	/* world_obj_put() on the step's key */
};
387
/* Expected change of world->delta_count caused by a step; see check_expect(). */
enum expect_delta {
	EXPECT_DELTA_SAME,	/* unchanged */
	EXPECT_DELTA_INC,	/* grew by one; only valid for ACTION_GET */
	EXPECT_DELTA_DEC,	/* shrank by one; only valid for ACTION_PUT */
};
393
/* Expected change of world->root_count caused by a step; see check_expect(). */
enum expect_root {
	EXPECT_ROOT_SAME,	/* unchanged */
	EXPECT_ROOT_INC,	/* grew by one; only valid for ACTION_GET */
	EXPECT_ROOT_DEC,	/* shrank by one; only valid for ACTION_PUT */
};
399
/* Expected content of one entry of the stats returned by objagg. The member
 * order matters: the ROOT() and DELTA() macros initialize it positionally.
 */
struct expect_stats_info {
	struct objagg_obj_stats stats;	/* expected user and delta user counts */
	bool is_root;			/* expected root/delta indication */
	unsigned int key_id;		/* key id the entry's object must resolve to */
};
405
/* Expected stats as a whole: the first info_count entries of info[] are
 * compared against the stats entries at the same positions.
 */
struct expect_stats {
	unsigned int info_count;
	struct expect_stats_info info[NUM_KEYS];
};
410
/* One step of the delta test. The member order matters: action_items[] is
 * initialized positionally.
 */
struct action_item {
	unsigned int key_id;			/* key the action is applied to */
	enum action action;			/* get or put */
	enum expect_delta expect_delta;		/* expected delta_count change */
	enum expect_root expect_root;		/* expected root_count change */
	struct expect_stats expect_stats;	/* expected stats after the step */
};
418
/* Initializer for struct expect_stats: @count entries given as a list of
 * ROOT()/DELTA() initializers.
 */
#define EXPECT_STATS(count, ...)		\
{						\
	.info_count = count,			\
	.info = { __VA_ARGS__ }			\
}

/* Initializer for a struct expect_stats_info describing a root object. */
#define ROOT(key_id, user_count, delta_user_count)	\
	{{user_count, delta_user_count}, true, key_id}

/* Initializer for a struct expect_stats_info describing a delta object; the
 * same @user_count is used for both the user and the delta user count.
 */
#define DELTA(key_id, user_count)			\
	{{user_count, user_count}, false, key_id}
430
/* Scenario executed step by step by test_delta().
 *
 * The trailing comment of each step sketches the references held after the
 * step: "r:" lists keys referenced as roots and "d:" lists keys referenced
 * as deltas, written X^Y for key X being a delta of root Y (one entry per
 * reference held).
 */
static const struct action_item action_items[] = {
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(1, ROOT(1, 1, 1)),
	},	/* r: 1			d: */
	{
		7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
	},	/* r: 1, 7		d: */
	{
		3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
				DELTA(3, 1)),
	},	/* r: 1, 7		d: 3^1 */
	{
		5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
				DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 5^1 */
	{
		3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 7		d: 3^1, 3^1, 5^1 */
	{
		1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7		d: 3^1, 3^1, 5^1 */
	{
		30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 3^1, 5^1, 8^7, 8^7 */
	{
		3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 1, 7, 30	d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 1, 7, 30		d: 5^1, 8^7, 8^7 */
	{
		1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
				DELTA(8, 2), DELTA(5, 1)),
	},	/* r: 7, 30		d: 5^1, 8^7, 8^7 */
	{
		5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30		d: 8^7, 8^7 */
	{
		5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
		EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
				DELTA(8, 2)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7 */
	{
		6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 2), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
				DELTA(8, 1), DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 8^7, 6^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1)),
	},	/* r: 7, 30, 5		d: 6^5 */
	{
		8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
		EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 7, 30, 5		d: 6^5, 8^5 */
	{
		7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 30, 5		d: 6^5, 8^5 */
	{
		30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
		EXPECT_STATS(3, ROOT(5, 1, 3),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r: 5			d: 6^5, 8^5 */
	{
		5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
		EXPECT_STATS(3, ROOT(5, 0, 2),
				DELTA(6, 1), DELTA(8, 1)),
	},	/* r:			d: 6^5, 8^5 */
	{
		6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
		EXPECT_STATS(2, ROOT(5, 0, 1),
				DELTA(8, 1)),
	},	/* r:			d: 8^5 */
	{
		8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
		EXPECT_STATS(0, ),
	},	/* r:			d: */
};
560
561static int check_expect(struct world *world,
562			const struct action_item *action_item,
563			unsigned int orig_delta_count,
564			unsigned int orig_root_count)
565{
566	unsigned int key_id = action_item->key_id;
567
568	switch (action_item->expect_delta) {
569	case EXPECT_DELTA_SAME:
570		if (orig_delta_count != world->delta_count) {
571			pr_err("Key %u: Delta count changed while expected to remain the same.\n",
572			       key_id);
573			return -EINVAL;
574		}
575		break;
576	case EXPECT_DELTA_INC:
577		if (WARN_ON(action_item->action == ACTION_PUT))
578			return -EINVAL;
579		if (orig_delta_count + 1 != world->delta_count) {
580			pr_err("Key %u: Delta count was not incremented.\n",
581			       key_id);
582			return -EINVAL;
583		}
584		break;
585	case EXPECT_DELTA_DEC:
586		if (WARN_ON(action_item->action == ACTION_GET))
587			return -EINVAL;
588		if (orig_delta_count - 1 != world->delta_count) {
589			pr_err("Key %u: Delta count was not decremented.\n",
590			       key_id);
591			return -EINVAL;
592		}
593		break;
594	}
595
596	switch (action_item->expect_root) {
597	case EXPECT_ROOT_SAME:
598		if (orig_root_count != world->root_count) {
599			pr_err("Key %u: Root count changed while expected to remain the same.\n",
600			       key_id);
601			return -EINVAL;
602		}
603		break;
604	case EXPECT_ROOT_INC:
605		if (WARN_ON(action_item->action == ACTION_PUT))
606			return -EINVAL;
607		if (orig_root_count + 1 != world->root_count) {
608			pr_err("Key %u: Root count was not incremented.\n",
609			       key_id);
610			return -EINVAL;
611		}
612		break;
613	case EXPECT_ROOT_DEC:
614		if (WARN_ON(action_item->action == ACTION_GET))
615			return -EINVAL;
616		if (orig_root_count - 1 != world->root_count) {
617			pr_err("Key %u: Root count was not decremented.\n",
618			       key_id);
619			return -EINVAL;
620		}
621	}
622
623	return 0;
624}
625
626static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
627{
628	const struct tokey *root_key;
629	const struct delta *delta;
630	unsigned int key_id;
631
632	root_key = objagg_obj_root_priv(objagg_obj);
633	key_id = root_key->id;
634	delta = objagg_obj_delta_priv(objagg_obj);
635	if (delta)
636		key_id += delta->key_id_diff;
637	return key_id;
638}
639
640static int
641check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
642			const struct expect_stats_info *expect_stats_info,
643			const char **errmsg)
644{
645	if (stats_info->is_root != expect_stats_info->is_root) {
646		if (errmsg)
647			*errmsg = "Incorrect root/delta indication";
648		return -EINVAL;
649	}
650	if (stats_info->stats.user_count !=
651	    expect_stats_info->stats.user_count) {
652		if (errmsg)
653			*errmsg = "Incorrect user count";
654		return -EINVAL;
655	}
656	if (stats_info->stats.delta_user_count !=
657	    expect_stats_info->stats.delta_user_count) {
658		if (errmsg)
659			*errmsg = "Incorrect delta user count";
660		return -EINVAL;
661	}
662	return 0;
663}
664
665static int
666check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
667			  const struct expect_stats_info *expect_stats_info,
668			  const char **errmsg)
669{
670	if (obj_to_key_id(stats_info->objagg_obj) !=
671	    expect_stats_info->key_id) {
672		if (errmsg)
673			*errmsg = "incorrect key id";
674		return -EINVAL;
675	}
676	return 0;
677}
678
679static int check_expect_stats_neigh(const struct objagg_stats *stats,
680				    const struct expect_stats *expect_stats,
681				    int pos)
682{
683	int i;
684	int err;
685
686	for (i = pos - 1; i >= 0; i--) {
687		err = check_expect_stats_nums(&stats->stats_info[i],
688					      &expect_stats->info[pos], NULL);
689		if (err)
690			break;
691		err = check_expect_stats_key_id(&stats->stats_info[i],
692						&expect_stats->info[pos], NULL);
693		if (!err)
694			return 0;
695	}
696	for (i = pos + 1; i < stats->stats_info_count; i++) {
697		err = check_expect_stats_nums(&stats->stats_info[i],
698					      &expect_stats->info[pos], NULL);
699		if (err)
700			break;
701		err = check_expect_stats_key_id(&stats->stats_info[i],
702						&expect_stats->info[pos], NULL);
703		if (!err)
704			return 0;
705	}
706	return -EINVAL;
707}
708
709static int __check_expect_stats(const struct objagg_stats *stats,
710				const struct expect_stats *expect_stats,
711				const char **errmsg)
712{
713	int i;
714	int err;
715
716	if (stats->stats_info_count != expect_stats->info_count) {
717		*errmsg = "Unexpected object count";
718		return -EINVAL;
719	}
720
721	for (i = 0; i < stats->stats_info_count; i++) {
722		err = check_expect_stats_nums(&stats->stats_info[i],
723					      &expect_stats->info[i], errmsg);
724		if (err)
725			return err;
726		err = check_expect_stats_key_id(&stats->stats_info[i],
727						&expect_stats->info[i], errmsg);
728		if (err) {
729			/* It is possible that one of the neighbor stats with
730			 * same numbers have the correct key id, so check it
731			 */
732			err = check_expect_stats_neigh(stats, expect_stats, i);
733			if (err)
734				return err;
735		}
736	}
737	return 0;
738}
739
/*
 * Fetch the current stats of @objagg and compare them with @expect_stats.
 *
 * @errmsg: must not be NULL; receives a description of the failure.
 *
 * Return: 0 on match, the error from objagg_stats_get(), or -EINVAL on a
 * mismatch.
 */
static int check_expect_stats(struct objagg *objagg,
			      const struct expect_stats *expect_stats,
			      const char **errmsg)
{
	const struct objagg_stats *stats = objagg_stats_get(objagg);
	int ret;

	if (IS_ERR(stats)) {
		*errmsg = "objagg_stats_get() failed.";
		return PTR_ERR(stats);
	}

	ret = __check_expect_stats(stats, expect_stats, errmsg);
	objagg_stats_put(stats);
	return ret;
}
756
757static int test_delta_action_item(struct world *world,
758				  struct objagg *objagg,
759				  const struct action_item *action_item,
760				  bool inverse)
761{
762	unsigned int orig_delta_count = world->delta_count;
763	unsigned int orig_root_count = world->root_count;
764	unsigned int key_id = action_item->key_id;
765	enum action action = action_item->action;
766	struct objagg_obj *objagg_obj;
767	const char *errmsg;
768	int err;
769
770	if (inverse)
771		action = action == ACTION_GET ? ACTION_PUT : ACTION_GET;
772
773	switch (action) {
774	case ACTION_GET:
775		objagg_obj = world_obj_get(world, objagg, key_id);
776		if (IS_ERR(objagg_obj))
777			return PTR_ERR(objagg_obj);
778		break;
779	case ACTION_PUT:
780		world_obj_put(world, objagg, key_id);
781		break;
782	}
783
784	if (inverse)
785		return 0;
786	err = check_expect(world, action_item,
787			   orig_delta_count, orig_root_count);
788	if (err)
789		goto errout;
790
791	err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg);
792	if (err) {
793		pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg);
794		goto errout;
795	}
796
797	return 0;
798
799errout:
800	/* This can only happen when action is not inversed.
801	 * So in case of an error, cleanup by doing inverse action.
802	 */
803	test_delta_action_item(world, objagg, action_item, true);
804	return err;
805}
806
807static int test_delta(void)
808{
809	struct world world = {};
810	struct objagg *objagg;
811	int i;
812	int err;
813
814	objagg = objagg_create(&delta_ops, NULL, &world);
815	if (IS_ERR(objagg))
816		return PTR_ERR(objagg);
817
818	for (i = 0; i < ARRAY_SIZE(action_items); i++) {
819		err = test_delta_action_item(&world, objagg,
820					     &action_items[i], false);
821		if (err)
822			goto err_do_action_item;
823	}
824
825	objagg_destroy(objagg);
826	return 0;
827
828err_do_action_item:
829	for (i--; i >= 0; i--)
830		test_delta_action_item(&world, objagg, &action_items[i], true);
831
832	objagg_destroy(objagg);
833	return err;
834}
835
/* Description of one hints scenario; see test_hints_case(). */
struct hints_case {
	const unsigned int *key_ids;		/* keys to get, in this order */
	size_t key_ids_count;			/* number of entries in key_ids */
	struct expect_stats expect_stats;	/* stats expected without hints */
	struct expect_stats expect_stats_hints;	/* stats expected of and with the hints */
};
842
/* Sequence of gets performed by the hints scenario; keys repeat on purpose,
 * and each occurrence takes one more reference on that key.
 */
static const unsigned int hints_case_key_ids[] = {
	1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8,
};
846
/* The single hints scenario run by test_hints(): the same gets are expected
 * to produce .expect_stats on a plain instance, and .expect_stats_hints both
 * in the hints themselves and on an instance created from the hints.
 */
static const struct hints_case hints_case = {
	.key_ids = hints_case_key_ids,
	.key_ids_count = ARRAY_SIZE(hints_case_key_ids),
	.expect_stats =
		EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(3, 2),
				DELTA(5, 2), DELTA(6, 1)),
	.expect_stats_hints =
		EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1),
				DELTA(8, 3), DELTA(5, 2),
				DELTA(6, 1), DELTA(7, 1)),
};
859
860static void __pr_debug_stats(const struct objagg_stats *stats)
861{
862	int i;
863
864	for (i = 0; i < stats->stats_info_count; i++)
865		pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i,
866			 obj_to_key_id(stats->stats_info[i].objagg_obj),
867			 stats->stats_info[i].stats.user_count,
868			 stats->stats_info[i].stats.delta_user_count,
869			 stats->stats_info[i].is_root ? "root" : "noroot");
870}
871
/*
 * Dump the current stats of @objagg with pr_debug(). Silently does nothing
 * when the stats cannot be obtained.
 */
static void pr_debug_stats(struct objagg *objagg)
{
	const struct objagg_stats *stats = objagg_stats_get(objagg);

	if (!IS_ERR(stats)) {
		__pr_debug_stats(stats);
		objagg_stats_put(stats);
	}
}
882
/*
 * Dump the stats of @objagg_hints with pr_debug(). Silently does nothing
 * when the stats cannot be obtained.
 */
static void pr_debug_hints_stats(struct objagg_hints *objagg_hints)
{
	const struct objagg_stats *stats;

	stats = objagg_hints_stats_get(objagg_hints);
	if (!IS_ERR(stats)) {
		__pr_debug_stats(stats);
		objagg_stats_put(stats);
	}
}
893
/*
 * Fetch the stats of @objagg_hints and compare them with @expect_stats.
 *
 * @errmsg: must not be NULL; receives a description of the failure. It is
 * set on every error return, including a failing objagg_hints_stats_get(),
 * because the caller prints it unconditionally on error.
 *
 * Return: 0 on match, the error from objagg_hints_stats_get(), or -EINVAL
 * on a mismatch.
 */
static int check_expect_hints_stats(struct objagg_hints *objagg_hints,
				    const struct expect_stats *expect_stats,
				    const char **errmsg)
{
	const struct objagg_stats *stats;
	int err;

	stats = objagg_hints_stats_get(objagg_hints);
	if (IS_ERR(stats)) {
		*errmsg = "objagg_hints_stats_get() failed.";
		return PTR_ERR(stats);
	}
	err = __check_expect_stats(stats, expect_stats, errmsg);
	objagg_stats_put(stats);
	return err;
}
908
909static int test_hints_case(const struct hints_case *hints_case)
910{
911	struct objagg_obj *objagg_obj;
912	struct objagg_hints *hints;
913	struct world world2 = {};
914	struct world world = {};
915	struct objagg *objagg2;
916	struct objagg *objagg;
917	const char *errmsg;
918	int i;
919	int err;
920
921	objagg = objagg_create(&delta_ops, NULL, &world);
922	if (IS_ERR(objagg))
923		return PTR_ERR(objagg);
924
925	for (i = 0; i < hints_case->key_ids_count; i++) {
926		objagg_obj = world_obj_get(&world, objagg,
927					   hints_case->key_ids[i]);
928		if (IS_ERR(objagg_obj)) {
929			err = PTR_ERR(objagg_obj);
930			goto err_world_obj_get;
931		}
932	}
933
934	pr_debug_stats(objagg);
935	err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg);
936	if (err) {
937		pr_err("Stats: %s\n", errmsg);
938		goto err_check_expect_stats;
939	}
940
941	hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY);
942	if (IS_ERR(hints)) {
943		err = PTR_ERR(hints);
944		goto err_hints_get;
945	}
946
947	pr_debug_hints_stats(hints);
948	err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints,
949				       &errmsg);
950	if (err) {
951		pr_err("Hints stats: %s\n", errmsg);
952		goto err_check_expect_hints_stats;
953	}
954
955	objagg2 = objagg_create(&delta_ops, hints, &world2);
956	if (IS_ERR(objagg2))
957		return PTR_ERR(objagg2);
958
959	for (i = 0; i < hints_case->key_ids_count; i++) {
960		objagg_obj = world_obj_get(&world2, objagg2,
961					   hints_case->key_ids[i]);
962		if (IS_ERR(objagg_obj)) {
963			err = PTR_ERR(objagg_obj);
964			goto err_world2_obj_get;
965		}
966	}
967
968	pr_debug_stats(objagg2);
969	err = check_expect_stats(objagg2, &hints_case->expect_stats_hints,
970				 &errmsg);
971	if (err) {
972		pr_err("Stats2: %s\n", errmsg);
973		goto err_check_expect_stats2;
974	}
975
976	err = 0;
977
978err_check_expect_stats2:
979err_world2_obj_get:
980	for (i--; i >= 0; i--)
981		world_obj_put(&world2, objagg, hints_case->key_ids[i]);
982	i = hints_case->key_ids_count;
983	objagg_destroy(objagg2);
984err_check_expect_hints_stats:
985	objagg_hints_put(hints);
986err_hints_get:
987err_check_expect_stats:
988err_world_obj_get:
989	for (i--; i >= 0; i--)
990		world_obj_put(&world, objagg, hints_case->key_ids[i]);
991
992	objagg_destroy(objagg);
993	return err;
994}
995static int test_hints(void)
996{
997	return test_hints_case(&hints_case);
998}
999
1000static int __init test_objagg_init(void)
1001{
1002	int err;
1003
1004	err = test_nodelta();
1005	if (err)
1006		return err;
1007	err = test_delta();
1008	if (err)
1009		return err;
1010	return test_hints();
1011}
1012
1013static void __exit test_objagg_exit(void)
1014{
1015}
1016
/* The tests run from the init hook; the exit hook is an empty function. */
module_init(test_objagg_init);
module_exit(test_objagg_exit);
MODULE_LICENSE("Dual BSD/GPL");
MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
MODULE_DESCRIPTION("Test module for objagg");
1022