1/*
2 * Copyright 2003-2009, Ingo Weinhold <ingo_weinhold@gmx.de>.
3 * Distributed under the terms of the MIT License.
4 */
5#ifndef _KERNEL_UTIL_AVL_TREE_H
6#define _KERNEL_UTIL_AVL_TREE_H
7
8
9#include <util/AVLTreeBase.h>
10
11
12/*
13	To be implemented by the definition:
14
15	typedef int	Key;
16	typedef Foo	Value;
17
18	AVLTreeNode*		GetAVLTreeNode(Value* value) const;
19	Value*				GetValue(AVLTreeNode* node) const;
20	int					Compare(const Key& a, const Value* b) const;
21	int					Compare(const Value* a, const Value* b) const;
22*/
23
24
25
26template<typename Definition>
27class AVLTree : protected AVLTreeCompare {
28private:
29	typedef typename Definition::Key	Key;
30	typedef typename Definition::Value	Value;
31
32public:
33	class Iterator;
34	class ConstIterator;
35
36public:
37								AVLTree();
38								AVLTree(const Definition& definition);
39	virtual						~AVLTree();
40
41	inline	int					Count() const	{ return fTree.Count(); }
42	inline	bool				IsEmpty() const	{ return fTree.IsEmpty(); }
43	inline	void				Clear();
44
45			Value*				RootNode() const;
46
47			Value*				Previous(Value* value) const;
48			Value*				Next(Value* value) const;
49
50	inline	Iterator			GetIterator();
51	inline	ConstIterator		GetIterator() const;
52
53	inline	Iterator			GetIterator(Value* value);
54	inline	ConstIterator		GetIterator(Value* value) const;
55
56			Value*				Find(const Key& key) const;
57			Value*				FindClosest(const Key& key, bool less) const;
58
59			status_t			Insert(Value* value, Iterator* iterator = NULL);
60			Value*				Remove(const Key& key);
61			bool				Remove(Value* key);
62
63			void				CheckTree() const	{ fTree.CheckTree(); }
64
65protected:
66	// AVLTreeCompare
67	virtual	int					CompareKeyNode(const void* key,
68									const AVLTreeNode* node);
69	virtual	int					CompareNodes(const AVLTreeNode* node1,
70									const AVLTreeNode* node2);
71
72	// definition shortcuts
73	inline	AVLTreeNode*		_GetAVLTreeNode(Value* value) const;
74	inline	Value*				_GetValue(const AVLTreeNode* node) const;
75	inline	int					_Compare(const Key& a, const Value* b);
76	inline	int					_Compare(const Value* a, const Value* b);
77
78protected:
79			friend class Iterator;
80			friend class ConstIterator;
81
82			AVLTreeBase			fTree;
83			Definition			fDefinition;
84
85public:
86	// (need to implement it here, otherwise gcc 2.95.3 chokes)
87	class Iterator : public ConstIterator {
88	public:
89		inline Iterator()
90			:
91			ConstIterator()
92		{
93		}
94
95		inline Iterator(const Iterator& other)
96			:
97			ConstIterator(other)
98		{
99		}
100
101		inline void Remove()
102		{
103			if (AVLTreeNode* node = ConstIterator::fTreeIterator.Remove()) {
104				AVLTree<Definition>* parent
105					= const_cast<AVLTree<Definition>*>(
106						ConstIterator::fParent);
107			}
108		}
109
110	private:
111		inline Iterator(AVLTree<Definition>* parent,
112			const AVLTreeIterator& treeIterator)
113			: ConstIterator(parent, treeIterator)
114		{
115		}
116
117		friend class AVLTree<Definition>;
118	};
119};
120
121
122template<typename Definition>
123class AVLTree<Definition>::ConstIterator {
124public:
125	inline ConstIterator()
126		:
127		fParent(NULL),
128		fTreeIterator()
129	{
130	}
131
132	inline ConstIterator(const ConstIterator& other)
133		:
134		fParent(other.fParent),
135		fTreeIterator(other.fTreeIterator)
136	{
137	}
138
139	inline bool HasCurrent() const
140	{
141		return fTreeIterator.Current();
142	}
143
144	inline Value* Current()
145	{
146		if (AVLTreeNode* node = fTreeIterator.Current())
147			return fParent->_GetValue(node);
148		return NULL;
149	}
150
151	inline bool HasNext() const
152	{
153		return fTreeIterator.HasNext();
154	}
155
156	inline Value* Next()
157	{
158		if (AVLTreeNode* node = fTreeIterator.Next())
159			return fParent->_GetValue(node);
160		return NULL;
161	}
162
163	inline Value* Previous()
164	{
165		if (AVLTreeNode* node = fTreeIterator.Previous())
166			return fParent->_GetValue(node);
167		return NULL;
168	}
169
170	inline ConstIterator& operator=(const ConstIterator& other)
171	{
172		fParent = other.fParent;
173		fTreeIterator = other.fTreeIterator;
174		return *this;
175	}
176
177protected:
178	inline ConstIterator(const AVLTree<Definition>* parent,
179		const AVLTreeIterator& treeIterator)
180	{
181		fParent = parent;
182		fTreeIterator = treeIterator;
183	}
184
185	friend class AVLTree<Definition>;
186
187	const AVLTree<Definition>*	fParent;
188	AVLTreeIterator				fTreeIterator;
189};
190
191
192template<typename Definition>
193AVLTree<Definition>::AVLTree()
194	:
195	fTree(this),
196	fDefinition()
197{
198}
199
200
201template<typename Definition>
202AVLTree<Definition>::AVLTree(const Definition& definition)
203	:
204	fTree(this),
205	fDefinition(definition)
206{
207}
208
209
210template<typename Definition>
211AVLTree<Definition>::~AVLTree()
212{
213}
214
215
216template<typename Definition>
217inline void
218AVLTree<Definition>::Clear()
219{
220	fTree.MakeEmpty();
221}
222
223
224template<typename Definition>
225inline typename AVLTree<Definition>::Value*
226AVLTree<Definition>::RootNode() const
227{
228	if (AVLTreeNode* root = fTree.Root())
229		return _GetValue(root);
230	return NULL;
231}
232
233
234template<typename Definition>
235inline typename AVLTree<Definition>::Value*
236AVLTree<Definition>::Previous(Value* value) const
237{
238	if (value == NULL)
239		return NULL;
240
241	AVLTreeNode* node = fTree.Previous(_GetAVLTreeNode(value));
242	return node != NULL ? _GetValue(node) : NULL;
243}
244
245
246template<typename Definition>
247inline typename AVLTree<Definition>::Value*
248AVLTree<Definition>::Next(Value* value) const
249{
250	if (value == NULL)
251		return NULL;
252
253	AVLTreeNode* node = fTree.Next(_GetAVLTreeNode(value));
254	return node != NULL ? _GetValue(node) : NULL;
255}
256
257
258template<typename Definition>
259inline typename AVLTree<Definition>::Iterator
260AVLTree<Definition>::GetIterator()
261{
262	return Iterator(this, fTree.GetIterator());
263}
264
265
266template<typename Definition>
267inline typename AVLTree<Definition>::ConstIterator
268AVLTree<Definition>::GetIterator() const
269{
270	return ConstIterator(this, fTree.GetIterator());
271}
272
273
274template<typename Definition>
275inline typename AVLTree<Definition>::Iterator
276AVLTree<Definition>::GetIterator(Value* value)
277{
278	return Iterator(this, fTree.GetIterator(_GetAVLTreeNode(value)));
279}
280
281
282template<typename Definition>
283inline typename AVLTree<Definition>::ConstIterator
284AVLTree<Definition>::GetIterator(Value* value) const
285{
286	return ConstIterator(this, fTree.GetIterator(_GetAVLTreeNode(value)));
287}
288
289
290template<typename Definition>
291typename AVLTree<Definition>::Value*
292AVLTree<Definition>::Find(const Key& key) const
293{
294	if (AVLTreeNode* node = fTree.Find(&key))
295		return _GetValue(node);
296	return NULL;
297}
298
299
300template<typename Definition>
301typename AVLTree<Definition>::Value*
302AVLTree<Definition>::FindClosest(const Key& key, bool less) const
303{
304	if (AVLTreeNode* node = fTree.FindClosest(&key, less))
305		return _GetValue(node);
306	return NULL;
307}
308
309
310template<typename Definition>
311status_t
312AVLTree<Definition>::Insert(Value* value, Iterator* iterator)
313{
314	AVLTreeNode* node = _GetAVLTreeNode(value);
315	status_t error = fTree.Insert(node);
316	if (error != B_OK)
317		return error;
318
319	if (iterator != NULL)
320		*iterator = Iterator(this, fTree.GetIterator(node));
321
322	return B_OK;
323}
324
325
326template<typename Definition>
327typename AVLTree<Definition>::Value*
328AVLTree<Definition>::Remove(const Key& key)
329{
330	AVLTreeNode* node = fTree.Remove(&key);
331	return node != NULL ? _GetValue(node) : NULL;
332}
333
334
335template<typename Definition>
336bool
337AVLTree<Definition>::Remove(Value* value)
338{
339	return fTree.Remove(_GetAVLTreeNode(value));
340}
341
342
343template<typename Definition>
344int
345AVLTree<Definition>::CompareKeyNode(const void* key,
346	const AVLTreeNode* node)
347{
348	return _Compare(*(const Key*)key, _GetValue(node));
349}
350
351
352template<typename Definition>
353int
354AVLTree<Definition>::CompareNodes(const AVLTreeNode* node1,
355	const AVLTreeNode* node2)
356{
357	return _Compare(_GetValue(node1), _GetValue(node2));
358}
359
360
361template<typename Definition>
362inline AVLTreeNode*
363AVLTree<Definition>::_GetAVLTreeNode(Value* value) const
364{
365	return fDefinition.GetAVLTreeNode(value);
366}
367
368
369template<typename Definition>
370inline typename AVLTree<Definition>::Value*
371AVLTree<Definition>::_GetValue(const AVLTreeNode* node) const
372{
373	return fDefinition.GetValue(const_cast<AVLTreeNode*>(node));
374}
375
376
377template<typename Definition>
378inline int
379AVLTree<Definition>::_Compare(const Key& a, const Value* b)
380{
381	return fDefinition.Compare(a, b);
382}
383
384
385template<typename Definition>
386inline int
387AVLTree<Definition>::_Compare(const Value* a, const Value* b)
388{
389	return fDefinition.Compare(a, b);
390}
391
392
393#endif	// _KERNEL_UTIL_AVL_TREE_H
394