1/*
2 * Copyright 2008-2009, Ingo Weinhold <ingo_weinhold@gmx.de>.
3 * Distributed under the terms of the MIT License.
4 *
5 * Original Java implementation:
6 * Available at http://www.link.cs.cmu.edu/splay/
7 * Author: Danny Sleator <sleator@cs.cmu.edu>
8 * This code is in the public domain.
9 */
10#ifndef KERNEL_UTIL_SPLAY_TREE_H
11#define KERNEL_UTIL_SPLAY_TREE_H
12
/*!	Implements two classes:

	SplayTree: A top-down splay tree.

	IteratableSplayTree: Wraps a SplayTree and additionally threads all nodes
	onto a singly-linked list sorted by ascending key, which makes the tree
	cheaply iteratable (requires another pointer per node).

	Both classes are templatized over a definition parameter with the following
	(or a compatible) interface:

	struct SplayTreeDefinition {
		typedef xxx KeyType;
		typedef yyy NodeType;

		static const KeyType& GetKey(const NodeType* node);
		static SplayTreeLink<NodeType>* GetLink(NodeType* node);

		// Must return a value < 0, == 0, or > 0, if key is less than, equal
		// to, or greater than the key of node, respectively.
		static int Compare(const KeyType& key, const NodeType* node);

		// for IteratableSplayTree only
		static NodeType** GetListLink(NodeType* node);
	};
*/
36
37
/*!	The pair of child pointers a SplayTree keeps for each node.

	The tree never allocates links itself; it obtains the link of a node via
	Definition::GetLink(node). The struct is deliberately a plain aggregate
	without constructors: the tree assigns both members before it reads them.
*/
template<typename Node>
struct SplayTreeLink {
	Node*	left;	// left child, NULL if there is none
	Node*	right;	// right child, NULL if there is none
};
43
44
45template<typename Definition>
46class SplayTree {
47protected:
48	typedef typename Definition::KeyType	Key;
49	typedef typename Definition::NodeType	Node;
50	typedef SplayTreeLink<Node>				Link;
51
52public:
53	SplayTree()
54		:
55		fRoot(NULL)
56	{
57	}
58
59	/*!
60		Insert into the tree.
61		\param node the item to insert.
62	*/
63	bool Insert(Node* node)
64	{
65		Link* nodeLink = Definition::GetLink(node);
66
67		if (fRoot == NULL) {
68			fRoot = node;
69			nodeLink->left = NULL;
70			nodeLink->right = NULL;
71			return true;
72		}
73
74		Key key = Definition::GetKey(node);
75		_Splay(key);
76
77		int c = Definition::Compare(key, fRoot);
78		if (c == 0)
79			return false;
80
81		Link* rootLink = Definition::GetLink(fRoot);
82
83		if (c < 0) {
84			nodeLink->left = rootLink->left;
85			nodeLink->right = fRoot;
86			rootLink->left = NULL;
87		} else {
88			nodeLink->right = rootLink->right;
89			nodeLink->left = fRoot;
90			rootLink->right = NULL;
91		}
92
93		fRoot = node;
94		return true;
95	}
96
97	Node* Remove(const Key& key)
98	{
99		if (fRoot == NULL)
100			return NULL;
101
102		_Splay(key);
103
104		if (Definition::Compare(key, fRoot) != 0)
105			return NULL;
106
107		// Now delete the root
108		Node* node = fRoot;
109		Link* rootLink = Definition::GetLink(fRoot);
110		if (rootLink->left == NULL) {
111			fRoot = rootLink->right;
112		} else {
113			Node* temp = rootLink->right;
114			fRoot = rootLink->left;
115			_Splay(key);
116			Definition::GetLink(fRoot)->right = temp;
117		}
118
119		return node;
120	}
121
122	/*!
123		Remove from the tree.
124		\param node the item to remove.
125	*/
126	bool Remove(Node* node)
127	{
128		Key key = Definition::GetKey(node);
129		_Splay(key);
130
131		if (node != fRoot)
132			return false;
133
134		// Now delete the root
135		Link* rootLink = Definition::GetLink(fRoot);
136		if (rootLink->left == NULL) {
137			fRoot = rootLink->right;
138		} else {
139			Node* temp = rootLink->right;
140			fRoot = rootLink->left;
141			_Splay(key);
142			Definition::GetLink(fRoot)->right = temp;
143		}
144
145		return true;
146	}
147
148	/*!
149		Find the smallest item in the tree.
150	*/
151	Node* FindMin()
152	{
153		if (fRoot == NULL)
154			return NULL;
155
156		Node* node = fRoot;
157
158		while (Node* left = Definition::GetLink(node)->left)
159			node = left;
160
161		_Splay(Definition::GetKey(node));
162
163		return node;
164	}
165
166	/*!
167		Find the largest item in the tree.
168	*/
169	Node* FindMax()
170	{
171		if (fRoot == NULL)
172			return NULL;
173
174		Node* node = fRoot;
175
176		while (Node* right = Definition::GetLink(node)->right)
177			node = right;
178
179		_Splay(Definition::GetKey(node));
180
181		return node;
182	}
183
184	/*!
185		Find an item in the tree.
186	*/
187	Node* Lookup(const Key& key)
188	{
189		if (fRoot == NULL)
190			return NULL;
191
192		_Splay(key);
193
194		return Definition::Compare(key, fRoot) == 0 ? fRoot : NULL;
195	}
196
197	Node* Root() const
198	{
199		return fRoot;
200	}
201
202	/*!
203		Test if the tree is logically empty.
204		\return true if empty, false otherwise.
205	*/
206	bool IsEmpty() const
207	{
208		return fRoot == NULL;
209	}
210
211	Node* PreviousDontSplay(const Key& key) const
212	{
213		Node* closestNode = NULL;
214		Node* node = fRoot;
215		while (node != NULL) {
216			if (Definition::Compare(key, node) > 0) {
217				closestNode = node;
218				node = Definition::GetLink(node)->right;
219			} else
220				node = Definition::GetLink(node)->left;
221		}
222
223		return closestNode;
224	}
225
226	Node* FindClosest(const Key& key, bool greater, bool orEqual)
227	{
228		if (fRoot == NULL)
229			return NULL;
230
231		_Splay(key);
232
233		Node* closestNode = NULL;
234		Node* node = fRoot;
235		while (node != NULL) {
236			int compare = Definition::Compare(key, node);
237			if (compare == 0 && orEqual)
238				return node;
239
240			if (greater) {
241				if (compare < 0) {
242					closestNode = node;
243					node = Definition::GetLink(node)->left;
244				} else
245					node = Definition::GetLink(node)->right;
246			} else {
247				if (compare > 0) {
248					closestNode = node;
249					node = Definition::GetLink(node)->right;
250				} else
251					node = Definition::GetLink(node)->left;
252			}
253		}
254
255		return closestNode;
256	}
257
258	SplayTree& operator=(const SplayTree& other)
259	{
260		fRoot = other.fRoot;
261		return *this;
262	}
263
264private:
265	/*!
266		Internal method to perform a top-down splay.
267
268		_Splay(key) does the splay operation on the given key.
269		If key is in the tree, then the node containing
270		that key becomes the root. If key is not in the tree,
271		then after the splay, key.root is either the greatest key
272		< key in the tree, or the least key > key in the tree.
273
274		This means, among other things, that if you splay with
275		a key that's larger than any in the tree, the rightmost
276		node of the tree becomes the root. This property is used
277		in the Remove() method.
278	*/
279	void _Splay(const Key& key) {
280		Link headerLink;
281		headerLink.left = headerLink.right = NULL;
282
283		Link* lLink = &headerLink;
284		Link* rLink = &headerLink;
285
286		Node* l = NULL;
287		Node* r = NULL;
288		Node* t = fRoot;
289
290		for (;;) {
291			int c = Definition::Compare(key, t);
292			if (c < 0) {
293				Node*& left = Definition::GetLink(t)->left;
294				if (left == NULL)
295					break;
296
297				if (Definition::Compare(key, left) < 0) {
298					// rotate right
299					Node* y = left;
300					Link* yLink = Definition::GetLink(y);
301					left = yLink->right;
302					yLink->right = t;
303					t = y;
304					if (yLink->left == NULL)
305						break;
306				}
307
308				// link right
309				rLink->left = t;
310				r = t;
311				rLink = Definition::GetLink(r);
312				t = rLink->left;
313			} else if (c > 0) {
314				Node*& right = Definition::GetLink(t)->right;
315				if (right == NULL)
316					break;
317
318				if (Definition::Compare(key, right) > 0) {
319					// rotate left
320					Node* y = right;
321					Link* yLink = Definition::GetLink(y);
322					right = yLink->left;
323					yLink->left = t;
324					t = y;
325					if (yLink->right == NULL)
326						break;
327				}
328
329				// link left
330				lLink->right = t;
331				l = t;
332				lLink = Definition::GetLink(l);
333				t = lLink->right;
334			} else
335				break;
336		}
337
338		// assemble
339		Link* tLink = Definition::GetLink(t);
340		lLink->right = tLink->left;
341		rLink->left = tLink->right;
342		tLink->left = headerLink.right;
343		tLink->right = headerLink.left;
344		fRoot = t;
345	}
346
347protected:
348	Node*	fRoot;
349};
350
351
352template<typename Definition>
353class IteratableSplayTree {
354protected:
355	typedef typename Definition::KeyType	Key;
356	typedef typename Definition::NodeType	Node;
357	typedef SplayTreeLink<Node>				Link;
358	typedef IteratableSplayTree<Definition>	Tree;
359
360public:
361	class Iterator {
362	public:
363		Iterator()
364		{
365		}
366
367		Iterator(const Iterator& other)
368		{
369			*this = other;
370		}
371
372		Iterator(Tree* tree)
373			:
374			fTree(tree)
375		{
376			Rewind();
377		}
378
379		Iterator(Tree* tree, Node* next)
380			:
381			fTree(tree),
382			fCurrent(NULL),
383			fNext(next)
384		{
385		}
386
387		bool HasNext() const
388		{
389			return fNext != NULL;
390		}
391
392		Node* Next()
393		{
394			fCurrent = fNext;
395			if (fNext != NULL)
396				fNext = *Definition::GetListLink(fNext);
397			return fCurrent;
398		}
399
400		Node* Current()
401		{
402			return fCurrent;
403		}
404
405		Node* Remove()
406		{
407			Node* element = fCurrent;
408			if (fCurrent) {
409				fTree->Remove(fCurrent);
410				fCurrent = NULL;
411			}
412			return element;
413		}
414
415		Iterator &operator=(const Iterator &other)
416		{
417			fTree = other.fTree;
418			fCurrent = other.fCurrent;
419			fNext = other.fNext;
420			return *this;
421		}
422
423		void Rewind()
424		{
425			fCurrent = NULL;
426			fNext = fTree->fFirst;
427		}
428
429	private:
430		Tree*	fTree;
431		Node*	fCurrent;
432		Node*	fNext;
433	};
434
435	class ConstIterator {
436	public:
437		ConstIterator()
438		{
439		}
440
441		ConstIterator(const ConstIterator& other)
442		{
443			*this = other;
444		}
445
446		ConstIterator(const Tree* tree)
447			:
448			fTree(tree)
449		{
450			Rewind();
451		}
452
453		ConstIterator(const Tree* tree, Node* next)
454			:
455			fTree(tree),
456			fNext(next)
457		{
458		}
459
460		bool HasNext() const
461		{
462			return fNext != NULL;
463		}
464
465		Node* Next()
466		{
467			Node* node = fNext;
468			if (fNext != NULL)
469				fNext = *Definition::GetListLink(fNext);
470			return node;
471		}
472
473		ConstIterator &operator=(const ConstIterator &other)
474		{
475			fTree = other.fTree;
476			fNext = other.fNext;
477			return *this;
478		}
479
480		void Rewind()
481		{
482			fNext = fTree->fFirst;
483		}
484
485	private:
486		const Tree*	fTree;
487		Node*		fNext;
488	};
489
490	IteratableSplayTree()
491		:
492		fTree(),
493		fFirst(NULL)
494	{
495	}
496
497	bool Insert(Node* node)
498	{
499		if (!fTree.Insert(node))
500			return false;
501
502		Node** previousNext;
503		if (Node* previous = fTree.PreviousDontSplay(Definition::GetKey(node)))
504			previousNext = Definition::GetListLink(previous);
505		else
506			previousNext = &fFirst;
507
508		*Definition::GetListLink(node) = *previousNext;
509		*previousNext = node;
510
511		return true;
512	}
513
514	Node* Remove(const Key& key)
515	{
516		Node* node = fTree.Remove(key);
517		if (node == NULL)
518			return NULL;
519
520		Node** previousNext;
521		if (Node* previous = fTree.PreviousDontSplay(key))
522			previousNext = Definition::GetListLink(previous);
523		else
524			previousNext = &fFirst;
525
526		*previousNext = *Definition::GetListLink(node);
527
528		return node;
529	}
530
531	bool Remove(Node* node)
532	{
533		if (!fTree.Remove(node))
534			return false;
535
536		Node** previousNext;
537		if (Node* previous = fTree.PreviousDontSplay(Definition::GetKey(node)))
538			previousNext = Definition::GetListLink(previous);
539		else
540			previousNext = &fFirst;
541
542		*previousNext = *Definition::GetListLink(node);
543
544		return true;
545	}
546
547	Node* Lookup(const Key& key)
548	{
549		return fTree.Lookup(key);
550	}
551
552	Node* Root() const
553	{
554		return fTree.Root();
555	}
556
557	/*!
558		Test if the tree is logically empty.
559		\return true if empty, false otherwise.
560	*/
561	bool IsEmpty() const
562	{
563		return fTree.IsEmpty();
564	}
565
566	Node* FindClosest(const Key& key, bool greater, bool orEqual)
567	{
568		return fTree.FindClosest(key, greater, orEqual);
569	}
570
571	Node* FindMin()
572	{
573		return fTree.FindMin();
574	}
575
576	Node* FindMax()
577	{
578		return fTree.FindMax();
579	}
580
581	Iterator GetIterator()
582	{
583		return Iterator(this);
584	}
585
586	ConstIterator GetIterator() const
587	{
588		return ConstIterator(this);
589	}
590
591	Iterator GetIterator(const Key& key, bool greater, bool orEqual)
592	{
593		return Iterator(this, fTree.FindClosest(key, greater, orEqual));
594	}
595
596	ConstIterator GetIterator(const Key& key, bool greater, bool orEqual) const
597	{
598		return ConstIterator(this, FindClosest(key, greater, orEqual));
599	}
600
601	IteratableSplayTree& operator=(const IteratableSplayTree& other)
602	{
603		fTree = other.fTree;
604		fFirst = other.fFirst;
605		return *this;
606	}
607
608protected:
609	friend class Iterator;
610	friend class ConstIterator;
611		// needed for gcc 2.95.3 only
612
613	SplayTree<Definition>	fTree;
614	Node*					fFirst;
615};
616
617
618#endif	// KERNEL_UTIL_SPLAY_TREE_H
619