1254562Scy/*
2254562Scy * Copyright (C) 2012 by Darren Reed.
3254562Scy *
4254562Scy * See the IPFILTER.LICENCE file for details on licencing.
5254562Scy *
6254562Scy */
7254562Scytypedef enum rbcolour_e {
8254562Scy	C_BLACK = 0,
9254562Scy	C_RED = 1
10254562Scy} rbcolour_t;
11254562Scy
12254562Scy#define	RBI_LINK(_n, _t)							\
13254562Scy	struct _n##_rb_link {						\
14254562Scy		struct _t	*left;					\
15254562Scy		struct _t	*right;					\
16254562Scy		struct _t	*parent;				\
17254562Scy		rbcolour_t	colour;					\
18254562Scy	}
19254562Scy
20254562Scy#define	RBI_HEAD(_n, _t)						\
21254562Scystruct _n##_rb_head {							\
22254562Scy	struct _t	top;						\
23254562Scy	int		count;						\
24254562Scy	int		(* compare)(struct _t *, struct _t *);		\
25254562Scy}
26254562Scy
27254562Scy#define	RBI_CODE(_n, _t, _f, _cmp)					\
28254562Scy									\
29254562Scytypedef	void	(*_n##_rb_walker_t)(_t *, void *);			\
30254562Scy									\
31254562Scy_t *	_n##_rb_delete(struct _n##_rb_head *, _t *);			\
32254562Scyvoid	_n##_rb_init(struct _n##_rb_head *);				\
33254562Scyvoid	_n##_rb_insert(struct _n##_rb_head *, _t *);			\
34254562Scy_t *	_n##_rb_search(struct _n##_rb_head *, void *);			\
35254562Scyvoid	_n##_rb_walktree(struct _n##_rb_head *, _n##_rb_walker_t, void *);\
36254562Scy									\
37254562Scystatic void								\
38254562Scyrotate_left(struct _n##_rb_head *head, _t *node)			\
39254562Scy{									\
40254562Scy	_t *parent, *tmp1, *tmp2;					\
41254562Scy									\
42254562Scy	parent = node->_f.parent;					\
43254562Scy	tmp1 = node->_f.right;						\
44254562Scy	tmp2 = tmp1->_f.left;						\
45254562Scy	node->_f.right = tmp2;						\
46254562Scy	if (tmp2 != & _n##_rb_zero)					\
47254562Scy		tmp2->_f.parent = node;					\
48254562Scy	if (parent == & _n##_rb_zero)					\
49254562Scy		head->top._f.right = tmp1;				\
50254562Scy	else if (parent->_f.right == node)				\
51254562Scy		parent->_f.right = tmp1;				\
52254562Scy	else								\
53254562Scy		parent->_f.left = tmp1;					\
54254562Scy	tmp1->_f.left = node;						\
55254562Scy	tmp1->_f.parent = parent;					\
56254562Scy	node->_f.parent = tmp1;						\
57254562Scy}									\
58254562Scy									\
59254562Scystatic void								\
60254562Scyrotate_right(struct _n##_rb_head *head, _t *node)			\
61254562Scy{									\
62254562Scy	_t *parent, *tmp1, *tmp2;					\
63254562Scy									\
64254562Scy	parent = node->_f.parent;					\
65254562Scy	tmp1 = node->_f.left;						\
66254562Scy	tmp2 = tmp1->_f.right;						\
67254562Scy	node->_f.left = tmp2;						\
68254562Scy	if (tmp2 != &_n##_rb_zero)					\
69254562Scy		tmp2->_f.parent = node;					\
70254562Scy	if (parent == &_n##_rb_zero)					\
71254562Scy		head->top._f.right = tmp1;				\
72254562Scy	else if (parent->_f.right == node)				\
73254562Scy		parent->_f.right = tmp1;				\
74254562Scy	else								\
75254562Scy		parent->_f.left = tmp1;					\
76254562Scy	tmp1->_f.right = node;						\
77254562Scy	tmp1->_f.parent = parent;					\
78254562Scy	node->_f.parent = tmp1;						\
79254562Scy}									\
80254562Scy									\
81254562Scyvoid									\
82254562Scy_n##_rb_insert(struct _n##_rb_head *head, _t *node)			\
83254562Scy{									\
84254562Scy	_t *n, *parent, **p, *tmp1, *gparent;				\
85254562Scy									\
86254562Scy	parent = &head->top;						\
87254562Scy	node->_f.left = &_n##_rb_zero;					\
88254562Scy	node->_f.right = &_n##_rb_zero;					\
89254562Scy	p = &head->top._f.right;					\
90254562Scy	while ((n = *p) != &_n##_rb_zero) {				\
91254562Scy		if (_cmp(node, n) < 0)					\
92254562Scy			p = &n->_f.left;				\
93254562Scy		else							\
94254562Scy			p = &n->_f.right;				\
95254562Scy		parent = n;						\
96254562Scy	}								\
97254562Scy	*p = node;							\
98254562Scy	node->_f.colour = C_RED;					\
99254562Scy	node->_f.parent = parent;					\
100254562Scy									\
101254562Scy	while ((node != &_n##_rb_zero) && (parent->_f.colour == C_RED)){\
102254562Scy		gparent = parent->_f.parent;				\
103254562Scy		if (parent == gparent->_f.left) {			\
104254562Scy			tmp1 = gparent->_f.right;			\
105254562Scy			if (tmp1->_f.colour == C_RED) {			\
106254562Scy				parent->_f.colour = C_BLACK;		\
107254562Scy				tmp1->_f.colour = C_BLACK;		\
108254562Scy				gparent->_f.colour = C_RED;		\
109254562Scy				node = gparent;				\
110254562Scy			} else {					\
111254562Scy				if (node == parent->_f.right) {		\
112254562Scy					node = parent;			\
113254562Scy					rotate_left(head, node);	\
114254562Scy					parent = node->_f.parent;	\
115254562Scy				}					\
116254562Scy				parent->_f.colour = C_BLACK;		\
117254562Scy				gparent->_f.colour = C_RED;		\
118254562Scy				rotate_right(head, gparent);		\
119254562Scy			}						\
120254562Scy		} else {						\
121254562Scy			tmp1 = gparent->_f.left;			\
122254562Scy			if (tmp1->_f.colour == C_RED) {			\
123254562Scy				parent->_f.colour = C_BLACK;		\
124254562Scy				tmp1->_f.colour = C_BLACK;		\
125254562Scy				gparent->_f.colour = C_RED;		\
126254562Scy				node = gparent;				\
127254562Scy			} else {					\
128254562Scy				if (node == parent->_f.left) {		\
129254562Scy					node = parent;			\
130254562Scy					rotate_right(head, node);	\
131254562Scy					parent = node->_f.parent;	\
132254562Scy				}					\
133254562Scy				parent->_f.colour = C_BLACK;		\
134254562Scy				gparent->_f.colour = C_RED;		\
135254562Scy				rotate_left(head, parent->_f.parent);	\
136254562Scy			}						\
137254562Scy		}							\
138254562Scy		parent = node->_f.parent;				\
139254562Scy	}								\
140254562Scy	head->top._f.right->_f.colour = C_BLACK;			\
141254562Scy	head->count++;						\
142254562Scy}									\
143254562Scy									\
144254562Scystatic void								\
145254562Scydeleteblack(struct _n##_rb_head *head, _t *parent, _t *node)		\
146254562Scy{									\
147254562Scy	_t *tmp;							\
148254562Scy									\
149254562Scy	while ((node == &_n##_rb_zero || node->_f.colour == C_BLACK) &&	\
150254562Scy	       node != &head->top) {					\
151254562Scy		if (parent->_f.left == node) {				\
152254562Scy			tmp = parent->_f.right;				\
153254562Scy			if (tmp->_f.colour == C_RED) {			\
154254562Scy				tmp->_f.colour = C_BLACK;		\
155254562Scy				parent->_f.colour = C_RED;		\
156254562Scy				rotate_left(head, parent);		\
157254562Scy				tmp = parent->_f.right;			\
158254562Scy			}						\
159254562Scy			if ((tmp->_f.left == &_n##_rb_zero ||		\
160254562Scy			     tmp->_f.left->_f.colour == C_BLACK) &&	\
161254562Scy			    (tmp->_f.right == &_n##_rb_zero ||		\
162254562Scy			     tmp->_f.right->_f.colour == C_BLACK)) {	\
163254562Scy				tmp->_f.colour = C_RED;			\
164254562Scy				node = parent;				\
165254562Scy				parent = node->_f.parent;		\
166254562Scy			} else {					\
167254562Scy				if (tmp->_f.right == &_n##_rb_zero ||	\
168254562Scy				    tmp->_f.right->_f.colour == C_BLACK) {\
169254562Scy					_t *tmp2 = tmp->_f.left;	\
170254562Scy									\
171254562Scy					if (tmp2 != &_n##_rb_zero)	\
172254562Scy						tmp2->_f.colour = C_BLACK;\
173254562Scy					tmp->_f.colour = C_RED;		\
174254562Scy					rotate_right(head, tmp);	\
175254562Scy					tmp = parent->_f.right;		\
176254562Scy				}					\
177254562Scy				tmp->_f.colour = parent->_f.colour;	\
178254562Scy				parent->_f.colour = C_BLACK;		\
179254562Scy				if (tmp->_f.right != &_n##_rb_zero)	\
180254562Scy					tmp->_f.right->_f.colour = C_BLACK;\
181254562Scy				rotate_left(head, parent);		\
182254562Scy				node = head->top._f.right;		\
183254562Scy			}						\
184254562Scy		} else {						\
185254562Scy			tmp = parent->_f.left;				\
186254562Scy			if (tmp->_f.colour == C_RED) {			\
187254562Scy				tmp->_f.colour = C_BLACK;		\
188254562Scy				parent->_f.colour = C_RED;		\
189254562Scy				rotate_right(head, parent);		\
190254562Scy				tmp = parent->_f.left;			\
191254562Scy			}						\
192254562Scy			if ((tmp->_f.left == &_n##_rb_zero ||		\
193254562Scy			     tmp->_f.left->_f.colour == C_BLACK) &&	\
194254562Scy			    (tmp->_f.right == &_n##_rb_zero ||		\
195254562Scy			     tmp->_f.right->_f.colour == C_BLACK)) {	\
196254562Scy				tmp->_f.colour = C_RED;			\
197254562Scy				node = parent;				\
198254562Scy				parent = node->_f.parent;		\
199254562Scy			} else {					\
200254562Scy				if (tmp->_f.left == &_n##_rb_zero ||	\
201254562Scy				    tmp->_f.left->_f.colour == C_BLACK) {\
202254562Scy					_t *tmp2 = tmp->_f.right;	\
203254562Scy									\
204254562Scy					if (tmp2 != &_n##_rb_zero)	\
205254562Scy						tmp2->_f.colour = C_BLACK;\
206254562Scy					tmp->_f.colour = C_RED;		\
207254562Scy					rotate_left(head, tmp);		\
208254562Scy					tmp = parent->_f.left;		\
209254562Scy				}					\
210254562Scy				tmp->_f.colour = parent->_f.colour;	\
211254562Scy				parent->_f.colour = C_BLACK;		\
212254562Scy				if (tmp->_f.left != &_n##_rb_zero)	\
213254562Scy					tmp->_f.left->_f.colour = C_BLACK;\
214254562Scy				rotate_right(head, parent);		\
215254562Scy				node = head->top._f.right;		\
216254562Scy				break;					\
217254562Scy			}						\
218254562Scy		}							\
219254562Scy	}								\
220254562Scy	if (node != &_n##_rb_zero)					\
221254562Scy		node->_f.colour = C_BLACK;				\
222254562Scy}									\
223254562Scy									\
224254562Scy_t *									\
225254562Scy_n##_rb_delete(struct _n##_rb_head *head, _t *node)			\
226254562Scy{									\
227254562Scy	_t *child, *parent, *old = node, *left;				\
228254562Scy	rbcolour_t color;						\
229254562Scy									\
230254562Scy	if (node->_f.left == &_n##_rb_zero) {				\
231254562Scy		child = node->_f.right;					\
232254562Scy	} else if (node->_f.right == &_n##_rb_zero) {			\
233254562Scy		child = node->_f.left;					\
234254562Scy	} else {							\
235254562Scy		node = node->_f.right;					\
236254562Scy		while ((left = node->_f.left) != &_n##_rb_zero)		\
237254562Scy			node = left;					\
238254562Scy		child = node->_f.right;					\
239254562Scy		parent = node->_f.parent;				\
240254562Scy		color = node->_f.colour;				\
241254562Scy		if (child != &_n##_rb_zero)				\
242254562Scy			child->_f.parent = parent;			\
243254562Scy		if (parent != &_n##_rb_zero) {				\
244254562Scy			if (parent->_f.left == node)			\
245254562Scy				parent->_f.left = child;		\
246254562Scy			else						\
247254562Scy				parent->_f.right = child;		\
248254562Scy		} else {						\
249254562Scy			head->top._f.right = child;			\
250254562Scy		}							\
251254562Scy		if (node->_f.parent == old)				\
252254562Scy			parent = node;					\
253254562Scy		*node = *old;						\
254254562Scy		if (old->_f.parent != &_n##_rb_zero) {			\
255254562Scy			if (old->_f.parent->_f.left == old)		\
256254562Scy				old->_f.parent->_f.left = node;		\
257254562Scy			else						\
258254562Scy				old->_f.parent->_f.right = node;	\
259254562Scy		} else {						\
260254562Scy			head->top._f.right = child;			\
261254562Scy		}							\
262254562Scy		old->_f.left->_f.parent = node;				\
263254562Scy		if (old->_f.right != &_n##_rb_zero)			\
264254562Scy			old->_f.right->_f.parent = node;		\
265254562Scy		if (parent != &_n##_rb_zero) {				\
266254562Scy			left = parent;					\
267254562Scy		}							\
268254562Scy		goto colour;						\
269254562Scy	}								\
270254562Scy	parent = node->_f.parent;					\
271254562Scy	color= node->_f.colour;						\
272254562Scy	if (child != &_n##_rb_zero)					\
273254562Scy		child->_f.parent = parent;				\
274254562Scy	if (parent != &_n##_rb_zero) {					\
275254562Scy		if (parent->_f.left == node)				\
276254562Scy			parent->_f.left = child;			\
277254562Scy		else							\
278254562Scy			parent->_f.right = child;			\
279254562Scy	} else {							\
280254562Scy		head->top._f.right = child;				\
281254562Scy	}								\
282254562Scycolour:									\
283254562Scy	if (color == C_BLACK)						\
284254562Scy		deleteblack(head, parent, node);			\
285254562Scy	head->count--;							\
286254562Scy	return old;							\
287254562Scy}									\
288254562Scy									\
289254562Scyvoid									\
290254562Scy_n##_rb_init(struct _n##_rb_head *head)					\
291254562Scy{									\
292254562Scy	memset(head, 0, sizeof(*head));					\
293254562Scy	memset(&_n##_rb_zero, 0, sizeof(_n##_rb_zero));			\
294254562Scy	head->top._f.left = &_n##_rb_zero;				\
295254562Scy	head->top._f.right = &_n##_rb_zero;				\
296254562Scy	head->top._f.parent = &head->top;				\
297254562Scy	_n##_rb_zero._f.left = &_n##_rb_zero;				\
298254562Scy	_n##_rb_zero._f.right = &_n##_rb_zero;				\
299254562Scy	_n##_rb_zero._f.parent = &_n##_rb_zero;				\
300254562Scy}									\
301254562Scy									\
302254562Scyvoid									\
303254562Scy_n##_rb_walktree(struct _n##_rb_head *head, _n##_rb_walker_t func, void *arg)\
304254562Scy{									\
305254562Scy	_t *prev;							\
306254562Scy	_t *next;							\
307254562Scy	_t *node = head->top._f.right;					\
308254562Scy	_t *base;							\
309254562Scy									\
310254562Scy	while (node != &_n##_rb_zero)					\
311254562Scy		node = node->_f.left;					\
312254562Scy									\
313254562Scy	for (;;) {							\
314254562Scy		base = node;						\
315254562Scy		prev = node;						\
316254562Scy		while ((node->_f.parent->_f.right == node) &&		\
317254562Scy		       (node != &_n##_rb_zero))	{			\
318254562Scy			prev = node;					\
319254562Scy			node = node->_f.parent;				\
320254562Scy		}							\
321254562Scy									\
322254562Scy		node = prev;						\
323254562Scy		for (node = node->_f.parent->_f.right; node != &_n##_rb_zero;\
324254562Scy		     node = node->_f.left)				\
325254562Scy			prev = node;					\
326254562Scy		next = prev;						\
327254562Scy									\
328254562Scy		if (node != &_n##_rb_zero)				\
329254562Scy			func(node, arg);				\
330254562Scy									\
331254562Scy		node = next;						\
332254562Scy		if (node == &_n##_rb_zero)				\
333254562Scy			break;						\
334254562Scy	}								\
335254562Scy}									\
336254562Scy									\
337254562Scy_t *									\
338254562Scy_n##_rb_search(struct _n##_rb_head *head, void *key)			\
339254562Scy{									\
340254562Scy	int	match;							\
341254562Scy	_t	*node;							\
342254562Scy	node = head->top._f.right;					\
343254562Scy	while (node != &_n##_rb_zero) {					\
344254562Scy		match = _cmp(key, node);				\
345254562Scy		if (match == 0)						\
346254562Scy			break;						\
347254562Scy		if (match< 0)						\
348254562Scy			node = node->_f.left;				\
349254562Scy		else							\
350254562Scy			node = node->_f.right;				\
351254562Scy	}								\
352254562Scy	if (node == &_n##_rb_zero || match != 0)			\
353254562Scy		return (NULL);						\
354254562Scy	return (node);							\
355254562Scy}
356254562Scy
357254562Scy#define	RBI_DELETE(_n, _h, _v)		_n##_rb_delete(_h, _v)
358254562Scy#define	RBI_FIELD(_n)			struct _n##_rb_link
359254562Scy#define	RBI_INIT(_n, _h)		_n##_rb_init(_h)
360254562Scy#define	RBI_INSERT(_n, _h, _v)		_n##_rb_insert(_h, _v)
361254562Scy#define	RBI_ISEMPTY(_h)			((_h)->count == 0)
362254562Scy#define	RBI_SEARCH(_n, _h, _k)		_n##_rb_search(_h, _k)
363254562Scy#define	RBI_WALK(_n, _h, _w, _a)	_n##_rb_walktree(_h, _w, _a)
364254562Scy#define	RBI_ZERO(_n)			_n##_rb_zero
365