1/*	$NetBSD: ipf_rb.h,v 1.4 2012/12/20 21:42:28 christos Exp $	*/
2
3/*
4 * Copyright (C) 2012 by Darren Reed.
5 *
6 * See the IPFILTER.LICENCE file for details on licencing.
7 *
8 */
9
/*
 * If the OS provides a red-black tree implementation (<sys/rbtree.h>), use
 * it; otherwise fall back to the private implementation further below.
 */
13#ifdef HAVE_RBTREE
14
15#include <sys/rbtree.h>
16
/*
 * With the system tree the per-element linkage is an rb_node_t and the
 * tree head is an rb_tree_t, so RBI_LINK has nothing of its own to declare.
 */
# define	RBI_LINK(_n, _t)		/* expands to nothing */
# define	RBI_FIELD(_n)			rb_node_t
# define	RBI_HEAD(_n, _t)		rb_tree_t
20
21/* Define adapter code between the ipf-specific and the system rb impls. */
/*
 * RBI_CODE(_n, _t, _f, _cmp): emit the glue between the ipf tree interface
 * and <sys/rbtree.h> for elements of type _t whose rb_node_t member is _f
 * and whose ordering function is _cmp.
 */
# define	RBI_CODE(_n, _t, _f, _cmp)				\
typedef	void	(*_n##_rb_walker_t)(_t *, void *);			\
int	_n##_compare_nodes(void *ctx, const void *n1, const void *n2);	\
int	_n##_compare_key(void *ctx, const void *n1, const void *key);	\
void	_n##_rb_walktree(rb_tree_t *, _n##_rb_walker_t, void *);	\
									\
/* Order two elements with _cmp; the context argument is not used. */	\
int									\
_n##_compare_nodes(void *ctx, const void *n1, const void *n2)		\
{									\
	return _cmp(n1, n2);						\
}									\
									\
/* Order an element against a key with _cmp; ctx is not used. */	\
int									\
_n##_compare_key(void *ctx, const void *n1, const void *key)		\
{									\
	return _cmp(n1, key);						\
}									\
									\
/* Operations vector handed to rb_tree_init() by RBI_INIT. */		\
static const rb_tree_ops_t _n##_tree_ops = {				\
	.rbto_compare_nodes = _n##_compare_nodes,			\
	.rbto_compare_key = _n##_compare_key,				\
	.rbto_node_offset = offsetof(_t, _f),				\
	.rbto_context = NULL						\
};									\
									\
/*									\
 * The ipf code only walks a tree in order to clear it, so rather than	\
 * iterating in place this repeatedly fetches a node, removes it from	\
 * the tree and only then hands it to func.				\
 */									\
void									\
_n##_rb_walktree(rb_tree_t *head, _n##_rb_walker_t func, void *arg)	\
{									\
	_t *victim;							\
									\
	for (;;) {							\
		victim = rb_tree_iterate(head, NULL, RB_DIR_RIGHT);	\
		if (victim == NULL)					\
			break;						\
		rb_tree_remove_node(head, victim);			\
		func(victim, arg);					\
	}								\
}
56
/*
 * Tree operations, mapped onto the <sys/rbtree.h> entry points.  RBI_WALK
 * goes through the adapter generated by RBI_CODE, which removes each
 * element from the tree before passing it to the callback.
 */
# define	RBI_DELETE(_n, _h, _v)		rb_tree_remove_node((_h), (_v))
# define	RBI_INIT(_n, _h)		rb_tree_init((_h), &_n##_tree_ops)
# define	RBI_INSERT(_n, _h, _v)		rb_tree_insert_node((_h), (_v))
/* Empty when iterating from NULL yields no node at all. */
# define	RBI_ISEMPTY(_h)			(NULL == rb_tree_iterate((_h), NULL, RB_DIR_RIGHT))
# define	RBI_SEARCH(_n, _h, _k)		rb_tree_find_node((_h), (_k))
# define	RBI_WALK(_n, _h, _w, _a)	_n##_rb_walktree((_h), (_w), (_a))
63
64#else
65
/* Node colours used by the private red-black tree implementation. */
enum rbcolour_e {
	C_BLACK = 0,
	C_RED = 1
};
typedef enum rbcolour_e rbcolour_t;
70
/*
 * RBI_LINK(_n, _t): declare struct _n##_rb_link, the per-element linkage
 * (two children, parent and colour) for a tree of struct _t elements.
 */
#define	RBI_LINK(_n, _t)						\
	struct _n##_rb_link {						\
		struct _t	*left;		/* left child */	\
		struct _t	*right;		/* right child */	\
		struct _t	*parent;	/* parent element */	\
		rbcolour_t	colour;		/* C_RED or C_BLACK */	\
	}

/*
 * RBI_HEAD(_n, _t): declare struct _n##_rb_head.  "top" is a whole element
 * embedded in the head; the generated code hangs the root of the tree off
 * its right link.  "count" tracks the number of elements in the tree.
 */
#define	RBI_HEAD(_n, _t)						\
struct _n##_rb_head {							\
	struct _t	top;						\
	int		count;						\
	int		(* compare)(struct _t *, struct _t *);		\
}

/* RBI_FIELD(_n): type of the linkage member to embed in each element. */
#define	RBI_FIELD(_n)			struct _n##_rb_link
87
/*
 * RBI_CODE(_n, _t, _f, _cmp): generate the private red-black tree code for
 * elements of type _t that embed their linkage in member _f and are ordered
 * by _cmp.  All generated public names are prefixed with _n.
 */
#define	RBI_CODE(_n, _t, _f, _cmp)					\
									\
/* Sentinel element that stands in for every absent child. */		\
_t RBI_ZERO(_n);							\
									\
/* Callback type accepted by the tree walker. */			\
typedef	void	(*_n##_rb_walker_t)(_t *, void *);			\
									\
void	_n##_rb_init(struct _n##_rb_head *);				\
void	_n##_rb_insert(struct _n##_rb_head *, _t *);			\
_t *	_n##_rb_delete(struct _n##_rb_head *, _t *);			\
_t *	_n##_rb_search(struct _n##_rb_head *, void *);			\
void	_n##_rb_walktree(struct _n##_rb_head *, _n##_rb_walker_t, void *);\
									\
100static void								\
101rotate_left(struct _n##_rb_head *head, _t *node)			\
102{									\
103	_t *parent, *tmp1, *tmp2;					\
104									\
105	parent = node->_f.parent;					\
106	tmp1 = node->_f.right;						\
107	tmp2 = tmp1->_f.left;						\
108	node->_f.right = tmp2;						\
109	if (tmp2 != & _n##_rb_zero)					\
110		tmp2->_f.parent = node;					\
111	if (parent == & _n##_rb_zero)					\
112		head->top._f.right = tmp1;				\
113	else if (parent->_f.right == node)				\
114		parent->_f.right = tmp1;				\
115	else								\
116		parent->_f.left = tmp1;					\
117	tmp1->_f.left = node;						\
118	tmp1->_f.parent = parent;					\
119	node->_f.parent = tmp1;						\
120}									\
121									\
122static void								\
123rotate_right(struct _n##_rb_head *head, _t *node)			\
124{									\
125	_t *parent, *tmp1, *tmp2;					\
126									\
127	parent = node->_f.parent;					\
128	tmp1 = node->_f.left;						\
129	tmp2 = tmp1->_f.right;						\
130	node->_f.left = tmp2;						\
131	if (tmp2 != &_n##_rb_zero)					\
132		tmp2->_f.parent = node;					\
133	if (parent == &_n##_rb_zero)					\
134		head->top._f.right = tmp1;				\
135	else if (parent->_f.right == node)				\
136		parent->_f.right = tmp1;				\
137	else								\
138		parent->_f.left = tmp1;					\
139	tmp1->_f.right = node;						\
140	tmp1->_f.parent = parent;					\
141	node->_f.parent = tmp1;						\
142}									\
143									\
144void									\
145_n##_rb_insert(struct _n##_rb_head *head, _t *node)			\
146{									\
147	_t *n, *parent, **p, *tmp1, *gparent;				\
148									\
149	parent = &head->top;						\
150	node->_f.left = &_n##_rb_zero;					\
151	node->_f.right = &_n##_rb_zero;					\
152	p = &head->top._f.right;					\
153	while ((n = *p) != &_n##_rb_zero) {				\
154		if (_cmp(node, n) < 0)					\
155			p = &n->_f.left;				\
156		else							\
157			p = &n->_f.right;				\
158		parent = n;						\
159	}								\
160	*p = node;							\
161	node->_f.colour = C_RED;					\
162	node->_f.parent = parent;					\
163									\
164	while ((node != &_n##_rb_zero) && (parent->_f.colour == C_RED)){\
165		gparent = parent->_f.parent;				\
166		if (parent == gparent->_f.left) {			\
167			tmp1 = gparent->_f.right;			\
168			if (tmp1->_f.colour == C_RED) {			\
169				parent->_f.colour = C_BLACK;		\
170				tmp1->_f.colour = C_BLACK;		\
171				gparent->_f.colour = C_RED;		\
172				node = gparent;				\
173			} else {					\
174				if (node == parent->_f.right) {		\
175					node = parent;			\
176					rotate_left(head, node);	\
177					parent = node->_f.parent;	\
178				}					\
179				parent->_f.colour = C_BLACK;		\
180				gparent->_f.colour = C_RED;		\
181				rotate_right(head, gparent);		\
182			}						\
183		} else {						\
184			tmp1 = gparent->_f.left;			\
185			if (tmp1->_f.colour == C_RED) {			\
186				parent->_f.colour = C_BLACK;		\
187				tmp1->_f.colour = C_BLACK;		\
188				gparent->_f.colour = C_RED;		\
189				node = gparent;				\
190			} else {					\
191				if (node == parent->_f.left) {		\
192					node = parent;			\
193					rotate_right(head, node);	\
194					parent = node->_f.parent;	\
195				}					\
196				parent->_f.colour = C_BLACK;		\
197				gparent->_f.colour = C_RED;		\
198				rotate_left(head, parent->_f.parent);	\
199			}						\
200		}							\
201		parent = node->_f.parent;				\
202	}								\
203	head->top._f.right->_f.colour = C_BLACK;			\
204	head->count++;						\
205}									\
206									\
207static void								\
208deleteblack(struct _n##_rb_head *head, _t *parent, _t *node)		\
209{									\
210	_t *tmp;							\
211									\
212	while ((node == &_n##_rb_zero || node->_f.colour == C_BLACK) &&	\
213	       node != &head->top) {					\
214		if (parent->_f.left == node) {				\
215			tmp = parent->_f.right;				\
216			if (tmp->_f.colour == C_RED) {			\
217				tmp->_f.colour = C_BLACK;		\
218				parent->_f.colour = C_RED;		\
219				rotate_left(head, parent);		\
220				tmp = parent->_f.right;			\
221			}						\
222			if ((tmp->_f.left == &_n##_rb_zero ||		\
223			     tmp->_f.left->_f.colour == C_BLACK) &&	\
224			    (tmp->_f.right == &_n##_rb_zero ||		\
225			     tmp->_f.right->_f.colour == C_BLACK)) {	\
226				tmp->_f.colour = C_RED;			\
227				node = parent;				\
228				parent = node->_f.parent;		\
229			} else {					\
230				if (tmp->_f.right == &_n##_rb_zero ||	\
231				    tmp->_f.right->_f.colour == C_BLACK) {\
232					_t *tmp2 = tmp->_f.left;	\
233									\
234					if (tmp2 != &_n##_rb_zero)	\
235						tmp2->_f.colour = C_BLACK;\
236					tmp->_f.colour = C_RED;		\
237					rotate_right(head, tmp);	\
238					tmp = parent->_f.right;		\
239				}					\
240				tmp->_f.colour = parent->_f.colour;	\
241				parent->_f.colour = C_BLACK;		\
242				if (tmp->_f.right != &_n##_rb_zero)	\
243					tmp->_f.right->_f.colour = C_BLACK;\
244				rotate_left(head, parent);		\
245				node = head->top._f.right;		\
246			}						\
247		} else {						\
248			tmp = parent->_f.left;				\
249			if (tmp->_f.colour == C_RED) {			\
250				tmp->_f.colour = C_BLACK;		\
251				parent->_f.colour = C_RED;		\
252				rotate_right(head, parent);		\
253				tmp = parent->_f.left;			\
254			}						\
255			if ((tmp->_f.left == &_n##_rb_zero ||		\
256			     tmp->_f.left->_f.colour == C_BLACK) &&	\
257			    (tmp->_f.right == &_n##_rb_zero ||		\
258			     tmp->_f.right->_f.colour == C_BLACK)) {	\
259				tmp->_f.colour = C_RED;			\
260				node = parent;				\
261				parent = node->_f.parent;		\
262			} else {					\
263				if (tmp->_f.left == &_n##_rb_zero ||	\
264				    tmp->_f.left->_f.colour == C_BLACK) {\
265					_t *tmp2 = tmp->_f.right;	\
266									\
267					if (tmp2 != &_n##_rb_zero)	\
268						tmp2->_f.colour = C_BLACK;\
269					tmp->_f.colour = C_RED;		\
270					rotate_left(head, tmp);		\
271					tmp = parent->_f.left;		\
272				}					\
273				tmp->_f.colour = parent->_f.colour;	\
274				parent->_f.colour = C_BLACK;		\
275				if (tmp->_f.left != &_n##_rb_zero)	\
276					tmp->_f.left->_f.colour = C_BLACK;\
277				rotate_right(head, parent);		\
278				node = head->top._f.right;		\
279				break;					\
280			}						\
281		}							\
282	}								\
283	if (node != &_n##_rb_zero)					\
284		node->_f.colour = C_BLACK;				\
285}									\
286									\
287_t *									\
288_n##_rb_delete(struct _n##_rb_head *head, _t *node)			\
289{									\
290	_t *child, *parent, *old = node, *left;				\
291	rbcolour_t color;						\
292									\
293	if (node->_f.left == &_n##_rb_zero) {				\
294		child = node->_f.right;					\
295	} else if (node->_f.right == &_n##_rb_zero) {			\
296		child = node->_f.left;					\
297	} else {							\
298		node = node->_f.right;					\
299		while ((left = node->_f.left) != &_n##_rb_zero)		\
300			node = left;					\
301		child = node->_f.right;					\
302		parent = node->_f.parent;				\
303		color = node->_f.colour;				\
304		if (child != &_n##_rb_zero)				\
305			child->_f.parent = parent;			\
306		if (parent != &_n##_rb_zero) {				\
307			if (parent->_f.left == node)			\
308				parent->_f.left = child;		\
309			else						\
310				parent->_f.right = child;		\
311		} else {						\
312			head->top._f.right = child;			\
313		}							\
314		if (node->_f.parent == old)				\
315			parent = node;					\
316		*node = *old;						\
317		if (old->_f.parent != &_n##_rb_zero) {			\
318			if (old->_f.parent->_f.left == old)		\
319				old->_f.parent->_f.left = node;		\
320			else						\
321				old->_f.parent->_f.right = node;	\
322		} else {						\
323			head->top._f.right = child;			\
324		}							\
325		old->_f.left->_f.parent = node;				\
326		if (old->_f.right != &_n##_rb_zero)			\
327			old->_f.right->_f.parent = node;		\
328		if (parent != &_n##_rb_zero) {				\
329			left = parent;					\
330		}							\
331		goto colour;						\
332	}								\
333	parent = node->_f.parent;					\
334	color= node->_f.colour;						\
335	if (child != &_n##_rb_zero)					\
336		child->_f.parent = parent;				\
337	if (parent != &_n##_rb_zero) {					\
338		if (parent->_f.left == node)				\
339			parent->_f.left = child;			\
340		else							\
341			parent->_f.right = child;			\
342	} else {							\
343		head->top._f.right = child;				\
344	}								\
345colour:									\
346	if (color == C_BLACK)						\
347		deleteblack(head, parent, node);			\
348	head->count--;							\
349	return old;							\
350}									\
351									\
352void									\
353_n##_rb_init(struct _n##_rb_head *head)					\
354{									\
355	memset(&_n##_rb_zero, 0, sizeof(_n##_rb_zero));			\
356	head->top._f.left = &_n##_rb_zero;				\
357	head->top._f.right = &_n##_rb_zero;				\
358	head->top._f.parent = &head->top;				\
359	_n##_rb_zero._f.left = &_n##_rb_zero;				\
360	_n##_rb_zero._f.right = &_n##_rb_zero;				\
361	_n##_rb_zero._f.parent = &_n##_rb_zero;				\
362}									\
363									\
364void									\
365_n##_rb_walktree(struct _n##_rb_head *head, _n##_rb_walker_t func, void *arg)\
366{									\
367	_t *prev;							\
368	_t *next;							\
369	_t *node = head->top._f.right;					\
370	_t *base;							\
371									\
372	while (node != &_n##_rb_zero)					\
373		node = node->_f.left;					\
374									\
375	for (;;) {							\
376		base = node;						\
377		prev = node;						\
378		while ((node->_f.parent->_f.right == node) &&		\
379		       (node != &_n##_rb_zero))	{			\
380			prev = node;					\
381			node = node->_f.parent;				\
382		}							\
383									\
384		node = prev;						\
385		for (node = node->_f.parent->_f.right; node != &_n##_rb_zero;\
386		     node = node->_f.left)				\
387			prev = node;					\
388		next = prev;						\
389									\
390		if (node != &_n##_rb_zero)				\
391			func(node, arg);				\
392									\
393		node = next;						\
394		if (node == &_n##_rb_zero)				\
395			break;						\
396	}								\
397}									\
398									\
399_t *									\
400_n##_rb_search(struct _n##_rb_head *head, void *key)			\
401{									\
402	int	match = 0;						\
403	_t	*node;							\
404	node = head->top._f.right;					\
405	while (node != &_n##_rb_zero) {					\
406		match = _cmp(key, node);				\
407		if (match == 0)						\
408			break;						\
409		if (match< 0)						\
410			node = node->_f.left;				\
411		else							\
412			node = node->_f.right;				\
413	}								\
414	if (node == &_n##_rb_zero || match != 0)			\
415		return (NULL);						\
416	return (node);							\
417}
418
/*
 * Tree operations, mapped onto the functions generated by RBI_CODE.
 * RBI_ISEMPTY relies on the element count kept in the tree head, and
 * RBI_ZERO names the sentinel element of tree type _n.
 */
#define	RBI_DELETE(_n, _h, _v)		_n##_rb_delete((_h), (_v))
#define	RBI_INIT(_n, _h)		_n##_rb_init((_h))
#define	RBI_INSERT(_n, _h, _v)		_n##_rb_insert((_h), (_v))
#define	RBI_ISEMPTY(_h)			(0 == (_h)->count)
#define	RBI_SEARCH(_n, _h, _k)		_n##_rb_search((_h), (_k))
#define	RBI_WALK(_n, _h, _w, _a)	_n##_rb_walktree((_h), (_w), (_a))
#define	RBI_ZERO(_n)			_n##_rb_zero
426
427#endif
428