1/*	$NetBSD: rpst.c,v 1.11 2011/04/26 20:53:34 yamt Exp $	*/
2
3/*-
4 * Copyright (c)2009 YAMAMOTO Takashi,
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
29/*
30 * radix priority search tree
31 *
32 * described in:
33 *	SIAM J. COMPUT.
34 *	Vol. 14, No. 2, May 1985
35 *	PRIORITY SEARCH TREES
36 *	EDWARD M. McCREIGHT
37 *
38 * ideas from linux:
39 *	- grow tree height on-demand.
40 *	- allow duplicated X values.  in that case, we act as a heap.
41 */
42
43#include <sys/cdefs.h>
44
45#if defined(_KERNEL) || defined(_STANDALONE)
46__KERNEL_RCSID(0, "$NetBSD: rpst.c,v 1.11 2011/04/26 20:53:34 yamt Exp $");
47#include <sys/param.h>
48#include <lib/libkern/libkern.h>
49#if defined(_STANDALONE)
50#include <lib/libsa/stand.h>
51#endif /* defined(_STANDALONE) */
52#else /* defined(_KERNEL) || defined(_STANDALONE) */
53__RCSID("$NetBSD: rpst.c,v 1.11 2011/04/26 20:53:34 yamt Exp $");
54#include <assert.h>
55#include <stdbool.h>
56#include <string.h>
57#if 1
58#define	KASSERT	assert
59#else
60#define	KASSERT(a)
61#endif
62#endif /* defined(_KERNEL) || defined(_STANDALONE) */
63
64#include <sys/rpst.h>
65
66/*
67 * rpst_init_tree: initialize a tree.
68 */
69
70void
71rpst_init_tree(struct rpst_tree *t)
72{
73
74	t->t_root = NULL;
75	t->t_height = 0;
76}
77
/*
 * rpst_height2max: return the largest index that fits in a tree of the
 * given height, ie. a value whose low (height + 1) bits are all set.
 *
 *	height	max index
 *	0	0x0000000000000001
 *	1	0x0000000000000003
 *	2	0x0000000000000007
 *	3	0x000000000000000f
 *	31	0x00000000ffffffff
 *	63	0xffffffffffffffff
 *
 * => height must be less than 64.
 */

static uint64_t
rpst_height2max(unsigned int height)
{

	KASSERT(height < 64);
	/*
	 * shift the all-ones value down instead of shifting 1 up, so that
	 * height 63 needs no special case (the shift count stays in 0..63).
	 */
	return UINT64_MAX >> (63 - height);
}
102
103/*
104 * rpst_level2mask: calculate the mask for the given level in the tree.
105 *
106 * the mask used to index root's children is level 0.
107 */
108
109static uint64_t
110rpst_level2mask(const struct rpst_tree *t, unsigned int level)
111{
112	uint64_t mask;
113
114	if (t->t_height < level) {
115		mask = 0;
116	} else {
117		mask = UINT64_C(1) << (t->t_height - level);
118	}
119	return mask;
120}
121
122/*
123 * rpst_startmask: calculate the mask for the start of a search.
124 * (ie. the mask for the top-most bit)
125 */
126
127static uint64_t
128rpst_startmask(const struct rpst_tree *t)
129{
130	const uint64_t mask = rpst_level2mask(t, 0);
131
132	KASSERT((mask | (mask - 1)) == rpst_height2max(t->t_height));
133	return mask;
134}
135
136/*
137 * rpst_update_parents: update n_parent of children
138 */
139
140static inline void
141rpst_update_parents(struct rpst_node *n)
142{
143	int i;
144
145	for (i = 0; i < 2; i++) {
146		if (n->n_children[i] != NULL) {
147			n->n_children[i]->n_parent = n;
148		}
149	}
150}
151
152/*
153 * rpst_enlarge_tree: enlarge tree so that 'idx' can be stored
154 */
155
156static void
157rpst_enlarge_tree(struct rpst_tree *t, uint64_t idx)
158{
159
160	while (idx > rpst_height2max(t->t_height)) {
161		struct rpst_node *n = t->t_root;
162
163		if (n != NULL) {
164			rpst_remove_node(t, n);
165			memset(&n->n_children, 0, sizeof(n->n_children));
166			n->n_children[0] = t->t_root;
167			t->t_root->n_parent = n;
168			t->t_root = n;
169			n->n_parent = NULL;
170		}
171		t->t_height++;
172	}
173}
174
175/*
176 * rpst_insert_node1: a helper for rpst_insert_node.
177 */
178
179static struct rpst_node *
180rpst_insert_node1(struct rpst_node **where, struct rpst_node *n, uint64_t mask)
181{
182	struct rpst_node *parent;
183	struct rpst_node *cur;
184	unsigned int idx;
185
186	KASSERT((n->n_x & ((-mask) << 1)) == 0);
187	parent = NULL;
188next:
189	cur = *where;
190	if (cur == NULL) {
191		n->n_parent = parent;
192		memset(&n->n_children, 0, sizeof(n->n_children));
193		*where = n;
194		return NULL;
195	}
196	KASSERT(cur->n_parent == parent);
197	if (n->n_y == cur->n_y && n->n_x == cur->n_x) {
198		return cur;
199	}
200	if (n->n_y < cur->n_y) {
201		/*
202		 * swap cur and n.
203		 * note that n is not in tree.
204		 */
205		memcpy(n->n_children, cur->n_children, sizeof(n->n_children));
206		n->n_parent = cur->n_parent;
207		rpst_update_parents(n);
208		*where = n;
209		n = cur;
210		cur = *where;
211	}
212	KASSERT(*where == cur);
213	idx = (n->n_x & mask) != 0;
214	where = &cur->n_children[idx];
215	parent = cur;
216	KASSERT((*where) == NULL || ((((*where)->n_x & mask) != 0) == idx));
217	KASSERT((*where) == NULL || (*where)->n_y >= cur->n_y);
218	mask >>= 1;
219	goto next;
220}
221
222/*
223 * rpst_insert_node: insert a node into the tree.
224 *
225 * => return NULL on success.
226 * => if a duplicated node (a node with the same X,Y pair as ours) is found,
227 *    return the node.  in that case, the tree is intact.
228 */
229
230struct rpst_node *
231rpst_insert_node(struct rpst_tree *t, struct rpst_node *n)
232{
233
234	rpst_enlarge_tree(t, n->n_x);
235	return rpst_insert_node1(&t->t_root, n, rpst_startmask(t));
236}
237
238/*
239 * rpst_find_pptr: find a pointer to the given node.
240 *
241 * also, return the parent node via parentp.  (NULL for the root node.)
242 */
243
244static inline struct rpst_node **
245rpst_find_pptr(struct rpst_tree *t, struct rpst_node *n,
246    struct rpst_node **parentp)
247{
248	struct rpst_node * const parent = n->n_parent;
249	unsigned int i;
250
251	*parentp = parent;
252	if (parent == NULL) {
253		return &t->t_root;
254	}
255	for (i = 0; i < 2 - 1; i++) {
256		if (parent->n_children[i] == n) {
257			break;
258		}
259	}
260	KASSERT(parent->n_children[i] == n);
261	return &parent->n_children[i];
262}
263
264/*
265 * rpst_remove_node_at: remove a node at *where.
266 */
267
268static void
269rpst_remove_node_at(struct rpst_node *parent, struct rpst_node **where,
270    struct rpst_node *cur)
271{
272	struct rpst_node *tmp[2];
273	struct rpst_node *selected;
274	unsigned int selected_idx = 0; /* XXX gcc */
275	unsigned int i;
276
277	KASSERT(cur != NULL);
278	KASSERT(parent == cur->n_parent);
279next:
280	selected = NULL;
281	for (i = 0; i < 2; i++) {
282		struct rpst_node *c;
283
284		c = cur->n_children[i];
285		KASSERT(c == NULL || c->n_parent == cur);
286		if (selected == NULL || (c != NULL && c->n_y < selected->n_y)) {
287			selected = c;
288			selected_idx = i;
289		}
290	}
291	/*
292	 * now we have:
293	 *
294	 *      parent
295	 *          \ <- where
296	 *           cur
297	 *           / \
298	 *          A  selected
299	 *              / \
300	 *             B   C
301	 */
302	*where = selected;
303	if (selected == NULL) {
304		return;
305	}
306	/*
307	 * swap selected->n_children and cur->n_children.
308	 */
309	memcpy(tmp, selected->n_children, sizeof(tmp));
310	memcpy(selected->n_children, cur->n_children, sizeof(tmp));
311	memcpy(cur->n_children, tmp, sizeof(tmp));
312	rpst_update_parents(cur);
313	rpst_update_parents(selected);
314	selected->n_parent = parent;
315	/*
316	 *      parent
317	 *          \ <- where
318	 *          selected
319	 *           / \
320	 *          A  selected
321	 *
322	 *              cur
323	 *              / \
324	 *             B   C
325	 */
326	where = &selected->n_children[selected_idx];
327	/*
328	 *      parent
329	 *          \
330	 *          selected
331	 *           / \ <- where
332	 *          A  selected (*)
333	 *
334	 *              cur (**)
335	 *              / \
336	 *             B   C
337	 *
338	 * (*) this 'selected' will be overwritten in the next iteration.
339	 * (**) cur->n_parent is bogus.
340	 */
341	parent = selected;
342	goto next;
343}
344
/*
 * rpst_remove_node: unlink the given node from the tree.
 *
 * => the node is located through its n_parent pointer.
 */

void
rpst_remove_node(struct rpst_tree *t, struct rpst_node *n)
{
	struct rpst_node *up;
	struct rpst_node ** const link = rpst_find_pptr(t, n, &up);

	rpst_remove_node_at(up, link, n);
}
358
359static bool __unused
360rpst_iterator_match_p(const struct rpst_node *n, const struct rpst_iterator *it)
361{
362
363	if (n->n_y > it->it_max_y) {
364		return false;
365	}
366	if (n->n_x < it->it_min_x) {
367		return false;
368	}
369	if (n->n_x > it->it_max_x) {
370		return false;
371	}
372	return true;
373}
374
375struct rpst_node *
376rpst_iterate_first(struct rpst_tree *t, uint64_t max_y, uint64_t min_x,
377    uint64_t max_x, struct rpst_iterator *it)
378{
379	struct rpst_node *n;
380
381	KASSERT(min_x <= max_x);
382	n = t->t_root;
383	if (n == NULL || n->n_y > max_y) {
384		return NULL;
385	}
386	if (rpst_height2max(t->t_height) < min_x) {
387		return NULL;
388	}
389	it->it_tree = t;
390	it->it_cur = n;
391	it->it_idx = (min_x & rpst_startmask(t)) != 0;
392	it->it_level = 0;
393	it->it_max_y = max_y;
394	it->it_min_x = min_x;
395	it->it_max_x = max_x;
396	return rpst_iterate_next(it);
397}
398
399static inline unsigned int
400rpst_node_on_edge_p(const struct rpst_node *n, uint64_t val, uint64_t mask)
401{
402
403	return ((n->n_x ^ val) & ((-mask) << 1)) == 0;
404}
405
/*
 * rpst_maxidx: return the index of the last child of 'n' worth visiting
 * for the upper bound max_x.
 *
 * => if the node is on the max_x edge, that is the child selected by the
 *    'mask' bit of max_x.  otherwise it is child 1.
 */

static inline uint64_t
rpst_maxidx(const struct rpst_node *n, uint64_t max_x, uint64_t mask)
{

	if (!rpst_node_on_edge_p(n, max_x, mask)) {
		return 1;
	}
	return (max_x & mask) != 0;
}
416
/*
 * rpst_minidx: return the index of the first child of 'n' worth visiting
 * for the lower bound min_x.
 *
 * => if the node is on the min_x edge, that is the child selected by the
 *    'mask' bit of min_x.  otherwise it is child 0.
 */

static inline uint64_t
rpst_minidx(const struct rpst_node *n, uint64_t min_x, uint64_t mask)
{

	if (!rpst_node_on_edge_p(n, min_x, mask)) {
		return 0;
	}
	return (min_x & mask) != 0;
}
427
428struct rpst_node *
429rpst_iterate_next(struct rpst_iterator *it)
430{
431	struct rpst_tree *t;
432	struct rpst_node *n;
433	struct rpst_node *next;
434	const uint64_t max_y = it->it_max_y;
435	const uint64_t min_x = it->it_min_x;
436	const uint64_t max_x = it->it_max_x;
437	unsigned int idx;
438	unsigned int maxidx;
439	unsigned int level;
440	uint64_t mask;
441
442	t = it->it_tree;
443	n = it->it_cur;
444	idx = it->it_idx;
445	level = it->it_level;
446	mask = rpst_level2mask(t, level);
447	maxidx = rpst_maxidx(n, max_x, mask);
448	KASSERT(n == t->t_root || rpst_iterator_match_p(n, it));
449next:
450	KASSERT(mask == rpst_level2mask(t, level));
451	KASSERT(idx >= rpst_minidx(n, min_x, mask));
452	KASSERT(maxidx == rpst_maxidx(n, max_x, mask));
453	KASSERT(idx <= maxidx + 2);
454	KASSERT(n != NULL);
455#if 0
456	printf("%s: cur=%p, idx=%u maxidx=%u level=%u mask=%" PRIx64 "\n",
457	    __func__, (void *)n, idx, maxidx, level, mask);
458#endif
459	if (idx == maxidx + 1) { /* visit the current node */
460		idx++;
461		if (min_x <= n->n_x && n->n_x <= max_x) {
462			it->it_cur = n;
463			it->it_idx = idx;
464			it->it_level = level;
465			KASSERT(rpst_iterator_match_p(n, it));
466			return n; /* report */
467		}
468		goto next;
469	} else if (idx == maxidx + 2) { /* back to the parent */
470		struct rpst_node **where;
471
472		where = rpst_find_pptr(t, n, &next);
473		if (next == NULL) {
474			KASSERT(level == 0);
475			KASSERT(t->t_root == n);
476			KASSERT(&t->t_root == where);
477			return NULL; /* done */
478		}
479		KASSERT(level > 0);
480		level--;
481		n = next;
482		mask = rpst_level2mask(t, level);
483		maxidx = rpst_maxidx(n, max_x, mask);
484		idx = where - n->n_children + 1;
485		KASSERT(idx < 2 + 1);
486		goto next;
487	}
488	/* go to a child */
489	KASSERT(idx < 2);
490	next = n->n_children[idx];
491	if (next == NULL || next->n_y > max_y) {
492		idx++;
493		goto next;
494	}
495	KASSERT(next->n_parent == n);
496	KASSERT(next->n_y >= n->n_y);
497	level++;
498	mask >>= 1;
499	n = next;
500	idx = rpst_minidx(n, min_x, mask);
501	maxidx = rpst_maxidx(n, max_x, mask);
502#if 0
503	printf("%s: visit %p idx=%u level=%u mask=%llx\n",
504	    __func__, n, idx, level, mask);
505#endif
506	goto next;
507}
508
509#if defined(UNITTEST)
510#include <sys/time.h>
511
512#include <inttypes.h>
513#include <stdio.h>
514#include <stdlib.h>
515
516static void
517rpst_dump_node(const struct rpst_node *n, unsigned int depth)
518{
519	unsigned int i;
520
521	for (i = 0; i < depth; i++) {
522		printf("  ");
523	}
524	printf("[%u]", depth);
525	if (n == NULL) {
526		printf("NULL\n");
527		return;
528	}
529	printf("%p x=%" PRIx64 "(%" PRIu64 ") y=%" PRIx64 "(%" PRIu64 ")\n",
530	    (const void *)n, n->n_x, n->n_x, n->n_y, n->n_y);
531	for (i = 0; i < 2; i++) {
532		rpst_dump_node(n->n_children[i], depth + 1);
533	}
534}
535
536static void
537rpst_dump_tree(const struct rpst_tree *t)
538{
539
540	printf("pst %p height=%u\n", (const void *)t, t->t_height);
541	rpst_dump_node(t->t_root, 0);
542}
543
544struct testnode {
545	struct rpst_node n;
546	struct testnode *next;
547	bool failed;
548	bool found;
549};
550
551struct rpst_tree t;
552struct testnode *h = NULL;
553
/*
 * tvdiff: return (*tv1 - *tv2) in microseconds.
 *
 * => both timestamps are widened to uintmax_t before the seconds are
 *    scaled, so that the multiplication is never carried out in time_t
 *    (which would overflow where time_t is a 32-bit signed type).
 * => the subtraction is done in uintmax_t; tv1 is expected not to be
 *    earlier than tv2.
 */

static uintmax_t
tvdiff(const struct timeval *tv1, const struct timeval *tv2)
{
	const uintmax_t usec1 =
	    (uintmax_t)tv1->tv_sec * 1000000 + (uintmax_t)tv1->tv_usec;
	const uintmax_t usec2 =
	    (uintmax_t)tv2->tv_sec * 1000000 + (uintmax_t)tv2->tv_usec;

	return usec1 - usec2;
}
561
562static unsigned int
563query(uint64_t max_y, uint64_t min_x, uint64_t max_x)
564{
565	struct testnode *n;
566	struct rpst_node *rn;
567	struct rpst_iterator it;
568	struct timeval start;
569	struct timeval end;
570	unsigned int done;
571
572	printf("quering max_y=%" PRIu64 " min_x=%" PRIu64 " max_x=%" PRIu64
573	    "\n",
574	    max_y, min_x, max_x);
575	done = 0;
576	gettimeofday(&start, NULL);
577	for (rn = rpst_iterate_first(&t, max_y, min_x, max_x, &it);
578	    rn != NULL;
579	    rn = rpst_iterate_next(&it)) {
580		done++;
581#if 0
582		printf("found %p x=%" PRIu64 " y=%" PRIu64 "\n",
583		    (void *)rn, rn->n_x, rn->n_y);
584#endif
585		n = (void *)rn;
586		assert(!n->found);
587		n->found = true;
588	}
589	gettimeofday(&end, NULL);
590	printf("%u nodes found in %ju usecs\n", done,
591	    tvdiff(&end, &start));
592
593	gettimeofday(&start, NULL);
594	for (n = h; n != NULL; n = n->next) {
595		assert(n->failed ||
596		    n->found == rpst_iterator_match_p(&n->n, &it));
597		n->found = false;
598	}
599	gettimeofday(&end, NULL);
600	printf("(linear search took %ju usecs)\n", tvdiff(&end, &start));
601	return done;
602}
603
604int
605main(int argc, char *argv[])
606{
607	struct testnode *n;
608	unsigned int i;
609	struct rpst_iterator it;
610	struct timeval start;
611	struct timeval end;
612	uint64_t min_y = UINT64_MAX;
613	uint64_t max_y = 0;
614	uint64_t min_x = UINT64_MAX;
615	uint64_t max_x = 0;
616	uint64_t w;
617	unsigned int done;
618	unsigned int fail;
619	unsigned int num = 500000;
620
621	rpst_init_tree(&t);
622	rpst_dump_tree(&t);
623	assert(NULL == rpst_iterate_first(&t, UINT64_MAX, 0, UINT64_MAX, &it));
624
625	for (i = 0; i < num; i++) {
626		n = malloc(sizeof(*n));
627		if (i > 499000) {
628			n->n.n_x = 10;
629			n->n.n_y = random();
630		} else if (i > 400000) {
631			n->n.n_x = i;
632			n->n.n_y = random();
633		} else {
634			n->n.n_x = random();
635			n->n.n_y = random();
636		}
637		if (n->n.n_y < min_y) {
638			min_y = n->n.n_y;
639		}
640		if (n->n.n_y > max_y) {
641			max_y = n->n.n_y;
642		}
643		if (n->n.n_x < min_x) {
644			min_x = n->n.n_x;
645		}
646		if (n->n.n_x > max_x) {
647			max_x = n->n.n_x;
648		}
649		n->found = false;
650		n->failed = false;
651		n->next = h;
652		h = n;
653	}
654
655	done = 0;
656	fail = 0;
657	gettimeofday(&start, NULL);
658	for (n = h; n != NULL; n = n->next) {
659		struct rpst_node *o;
660#if 0
661		printf("insert %p x=%" PRIu64 " y=%" PRIu64 "\n",
662		    n, n->n.n_x, n->n.n_y);
663#endif
664		o = rpst_insert_node(&t, &n->n);
665		if (o == NULL) {
666			done++;
667		} else {
668			n->failed = true;
669			fail++;
670		}
671	}
672	gettimeofday(&end, NULL);
673	printf("%u nodes inserted and %u insertion failed in %ju usecs\n",
674	    done, fail,
675	    tvdiff(&end, &start));
676
677	assert(min_y == 0 || 0 == query(min_y - 1, 0, UINT64_MAX));
678	assert(max_x == UINT64_MAX ||
679	    0 == query(UINT64_MAX, max_x + 1, UINT64_MAX));
680	assert(min_x == 0 || 0 == query(UINT64_MAX, 0, min_x - 1));
681
682	done = query(max_y, min_x, max_x);
683	assert(done == num - fail);
684
685	done = query(UINT64_MAX, 0, UINT64_MAX);
686	assert(done == num - fail);
687
688	w = max_x - min_x;
689	query(max_y / 2, min_x, max_x);
690	query(max_y, min_x + w / 2, max_x);
691	query(max_y / 2, min_x + w / 2, max_x);
692	query(max_y / 2, min_x, max_x - w / 2);
693	query(max_y / 2, min_x + w / 3, max_x - w / 3);
694	query(max_y - 1, min_x + 1, max_x - 1);
695	query(UINT64_MAX, 10, 10);
696
697	done = 0;
698	gettimeofday(&start, NULL);
699	for (n = h; n != NULL; n = n->next) {
700		if (n->failed) {
701			continue;
702		}
703#if 0
704		printf("remove %p x=%" PRIu64 " y=%" PRIu64 "\n",
705		    n, n->n.n_x, n->n.n_y);
706#endif
707		rpst_remove_node(&t, &n->n);
708		done++;
709	}
710	gettimeofday(&end, NULL);
711	printf("%u nodes removed in %ju usecs\n", done,
712	    tvdiff(&end, &start));
713
714	rpst_dump_tree(&t);
715}
716#endif /* defined(UNITTEST) */
717