1#include <stdlib.h>
2#include <search.h>
3#include "tsearch.h"
4
5static inline int height(struct node *n) { return n ? n->h : 0; }
6
/*
 * Rotate the subtree rooted at x (whose parent link is *p) to repair an
 * imbalance on side dir, the deeper side (an index into x->a[]).
 *
 * y is x's child on the deeper side and z is y's inner child (y->a[!dir]).
 * If z is taller than y's outer child, a double rotation is done and z
 * becomes the new subtree root; otherwise a single rotation makes y the
 * new root.  Either way the new root is stored through p.
 *
 * __tsearch_balance calls this only when x's children differ in height by
 * 2 or more.  NOTE(review): the heights assigned below are only correct if
 * that difference is exactly 2 and the children are themselves balanced;
 * presumably that is guaranteed by rebalancing after every single
 * insertion or removal.  Confirm this before reusing rot elsewhere.
 *
 * Returns the new height of the subtree minus x's height on entry.
 */
static int rot(void **p, struct node *x, int dir /* deeper side */)
{
	struct node *y = x->a[dir];
	struct node *z = y->a[!dir];
	int hx = x->h;	/* x's height before the rotation, for the return value */
	int hz = height(z);
	if (hz > height(y->a[dir])) {
		/*
		 * Double rotation: the inner grandchild z is the taller one,
		 * so z is lifted above both x and y.
		 *
		 *   x
		 *  / \ dir          z
		 * A   y            / \
		 *    / \   -->    x   y
		 *   z   D        /|   |\
		 *  / \          A B   C D
		 * B   C
		 */
		x->a[dir] = z->a[!dir];
		y->a[!dir] = z->a[dir];
		z->a[!dir] = x;
		z->a[dir] = y;
		/* x and y each sit one level below z, whose height was hz. */
		x->h = hz;
		y->h = hz;
		z->h = hz+1;
	} else {
		/*
		 * Single rotation: y's outer child is at least as tall as z,
		 * so y is lifted above x and z moves across to x.
		 *
		 *   x               y
		 *  / \             / \
		 * A   y    -->    x   D
		 *    / \         / \
		 *   z   D       A   z
		 */
		x->a[dir] = z;
		y->a[!dir] = x;
		x->h = hz+1;
		y->h = hz+2;
		z = y;	/* z now names the new subtree root for the code below */
	}
	*p = z;
	return z->h - hx;
}
47
48/* balance *p, return 0 if height is unchanged.  */
49int __tsearch_balance(void **p)
50{
51	struct node *n = *p;
52	int h0 = height(n->a[0]);
53	int h1 = height(n->a[1]);
54	if (h0 - h1 + 1u < 3u) {
55		int old = n->h;
56		n->h = h0<h1 ? h1+1 : h0+1;
57		return n->h - old;
58	}
59	return rot(p, n, h0<h1);
60}
61
62void *tsearch(const void *key, void **rootp,
63	int (*cmp)(const void *, const void *))
64{
65	if (!rootp)
66		return 0;
67
68	{
69	void **a[MAXH];
70	struct node *n = *rootp;
71	struct node *r;
72	int i=0;
73	a[i++] = rootp;
74	for (;;) {
75		if (!n)
76			break;
77		{
78		int c = cmp(key, n->key);
79		if (!c)
80			return n;
81		a[i++] = &n->a[c>0];
82		n = n->a[c>0];
83		}
84	}
85	r = malloc(sizeof *r);
86	if (!r)
87		return 0;
88	r->key = key;
89	r->a[0] = r->a[1] = 0;
90	r->h = 1;
91	/* insert new node, rebalance ancestors.  */
92	*a[--i] = r;
93	while (i && __tsearch_balance(a[--i]));
94	return r;
95	}
96}
97