1#include <stdlib.h>
2#include <search.h>
3
/*
AVL tree implementation using recursive functions.
The height of an AVL tree with n nodes is less than 1.4405*log2(n+2) - 0.3277
(Knuth, TAOCP vol. 3, 6.2.3), so a tree with 2^32 nodes has height at most 45
and the recursive helpers below nest at most about 46 calls deep.
*/
9
/* One tree node.  tsearch/tfind/tdelete/twalk hand pointers to these nodes
   back to the caller, who (by the POSIX <search.h> convention) reads the key
   as *(void **)node -- so key must remain the first member. */
struct node {
	const void *key;     /* caller's key; only ever passed to the comparison function */
	struct node *left;   /* subtree whose keys compare less than key */
	struct node *right;  /* subtree whose keys compare greater than key */
	int height;          /* height of this subtree; a lone node has height 1 */
};
16
17static int delta(struct node *n) {
18	return (n->left ? n->left->height:0) - (n->right ? n->right->height:0);
19}
20
21static void updateheight(struct node *n) {
22	n->height = 0;
23	if (n->left && n->left->height > n->height)
24		n->height = n->left->height;
25	if (n->right && n->right->height > n->height)
26		n->height = n->right->height;
27	n->height++;
28}
29
30static struct node *rotl(struct node *n) {
31	struct node *r = n->right;
32	n->right = r->left;
33	r->left = n;
34	updateheight(n);
35	updateheight(r);
36	return r;
37}
38
39static struct node *rotr(struct node *n) {
40	struct node *l = n->left;
41	n->left = l->right;
42	l->right = n;
43	updateheight(n);
44	updateheight(l);
45	return l;
46}
47
48static struct node *balance(struct node *n) {
49	int d = delta(n);
50
51	if (d < -1) {
52		if (delta(n->right) > 0)
53			n->right = rotr(n->right);
54		return rotl(n);
55	} else if (d > 1) {
56		if (delta(n->left) < 0)
57			n->left = rotl(n->left);
58		return rotr(n);
59	}
60	updateheight(n);
61	return n;
62}
63
64static struct node *find(struct node *n, const void *k,
65	int (*cmp)(const void *, const void *))
66{
67	int c;
68
69	if (!n)
70		return 0;
71	c = cmp(k, n->key);
72	if (c == 0)
73		return n;
74	if (c < 0)
75		return find(n->left, k, cmp);
76	else
77		return find(n->right, k, cmp);
78}
79
80static struct node *insert(struct node *n, const void *k,
81	int (*cmp)(const void *, const void *), struct node **found)
82{
83	struct node *r;
84	int c;
85
86	if (!n) {
87		n = malloc(sizeof *n);
88		if (n) {
89			n->key = k;
90			n->left = n->right = 0;
91			n->height = 1;
92		}
93		*found = n;
94		return n;
95	}
96	c = cmp(k, n->key);
97	if (c == 0) {
98		*found = n;
99		return 0;
100	}
101	r = insert(c < 0 ? n->left : n->right, k, cmp, found);
102	if (r) {
103		if (c < 0)
104			n->left = r;
105		else
106			n->right = r;
107		r = balance(n);
108	}
109	return r;
110}
111
112static struct node *remove_rightmost(struct node *n, struct node **rightmost)
113{
114	if (!n->right) {
115		*rightmost = n;
116		return n->left;
117	}
118	n->right = remove_rightmost(n->right, rightmost);
119	return balance(n);
120}
121
122static struct node *remove(struct node **n, const void *k,
123	int (*cmp)(const void *, const void *), struct node *parent)
124{
125	int c;
126
127	if (!*n)
128		return 0;
129	c = cmp(k, (*n)->key);
130	if (c == 0) {
131		struct node *r = *n;
132		if (r->left) {
133			r->left = remove_rightmost(r->left, n);
134			(*n)->left = r->left;
135			(*n)->right = r->right;
136			*n = balance(*n);
137		} else
138			*n = r->right;
139		free(r);
140		return parent;
141	}
142	if (c < 0)
143		parent = remove(&(*n)->left, k, cmp, *n);
144	else
145		parent = remove(&(*n)->right, k, cmp, *n);
146	if (parent)
147		*n = balance(*n);
148	return parent;
149}
150
151void *tdelete(const void *restrict key, void **restrict rootp,
152	int(*compar)(const void *, const void *))
153{
154	if (!rootp)
155		return 0;
156	struct node *n = *rootp;
157	struct node *ret;
158	/* last argument is arbitrary non-null pointer
159	   which is returned when the root node is deleted */
160	ret = remove(&n, key, compar, n);
161	*rootp = n;
162	return ret;
163}
164
/* POSIX tfind: look up key in the tree at *rootp without modifying it.
   Returns the matching node, or 0 if rootp is null or key is absent. */
void *tfind(const void *key, void *const *rootp,
	int(*compar)(const void *, const void *))
{
	return rootp ? find(*rootp, key, compar) : 0;
}
172
/*
 * POSIX tsearch: find key in the tree at *rootp, inserting it if absent.
 * Returns the node holding key (existing or newly added), or 0 if rootp is
 * null or a new node could not be allocated.  *rootp is rewritten only
 * when an insertion actually took place.
 */
void *tsearch(const void *key, void **rootp,
	int (*compar)(const void *, const void *))
{
	struct node *hit;
	struct node *newroot;

	if (!rootp)
		return 0;
	newroot = insert(*rootp, key, compar, &hit);
	if (newroot)
		*rootp = newroot;
	return hit;
}
185
186static void walk(const struct node *r, void (*action)(const void *, VISIT, int), int d)
187{
188	if (r == 0)
189		return;
190	if (r->left == 0 && r->right == 0)
191		action(r, leaf, d);
192	else {
193		action(r, preorder, d);
194		walk(r->left, action, d+1);
195		action(r, postorder, d);
196		walk(r->right, action, d+1);
197		action(r, endorder, d);
198	}
199}
200
/* POSIX twalk: depth-first traversal of the tree rooted at root (which may
   be null), calling action once per leaf node and three times
   (preorder/postorder/endorder) per inner node, with root at depth 0. */
void twalk(const void *root, void (*action)(const void *, VISIT, int))
{
	const struct node *top = root;

	walk(top, action, 0);
}
205