1/*
2 * Copyright (c) 2004, 2005 Topspin Communications.  All rights reserved.
3 * Copyright (c) 2006 Cisco Systems, Inc.  All rights reserved.
4 *
5 * This software is available to you under a choice of one of two
6 * licenses.  You may choose to be licensed under the terms of the GNU
7 * General Public License (GPL) Version 2, available from the file
8 * COPYING in the main directory of this source tree, or the
9 * OpenIB.org BSD license below:
10 *
11 *     Redistribution and use in source and binary forms, with or
12 *     without modification, are permitted provided that the following
13 *     conditions are met:
14 *
15 *      - Redistributions of source code must retain the above
16 *        copyright notice, this list of conditions and the following
17 *        disclaimer.
18 *
19 *      - Redistributions in binary form must reproduce the above
20 *        copyright notice, this list of conditions and the following
21 *        disclaimer in the documentation and/or other materials
22 *        provided with the distribution.
23 *
24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31 * SOFTWARE.
32 */
33
34#if HAVE_CONFIG_H
35#  include <config.h>
36#endif /* HAVE_CONFIG_H */
37
38#include <errno.h>
39#include <sys/mman.h>
40#include <unistd.h>
41#include <stdlib.h>
42#include <stdint.h>
43
44#include "ibverbs.h"
45
/*
 * Most distros' headers don't define these yet, so provide fallback
 * values for the fork-inheritance advice constants.
 */
#if defined(__linux__)

#  if !defined(MADV_DONTFORK)
#    define MADV_DONTFORK	10
#  endif

#  if !defined(MADV_DOFORK)
#    define MADV_DOFORK	11
#  endif

#else /* !__linux__ */

/* Elsewhere the advice names are mapped onto the INHERIT_* constants. */
#  define MADV_DONTFORK	INHERIT_NONE
#  define MADV_DOFORK	INHERIT_SHARE

#endif /* __linux__ */
61
/*
 * One node of the red-black tree that tracks address ranges.
 *
 * Each node describes the inclusive range [start, end] together with a
 * reference count: ibv_madvise_range() raises it by one for
 * MADV_DONTFORK and lowers it by one for MADV_DOFORK.
 */
struct ibv_mem_node {
	/* Red-black color of this node. */
	enum {
		IBV_RED,
		IBV_BLACK
	} color;

	/* Tree links; parent is NULL for the root. */
	struct ibv_mem_node *parent;
	struct ibv_mem_node *left;
	struct ibv_mem_node *right;

	/* First and last address covered (both inclusive). */
	uintptr_t start;
	uintptr_t end;

	/* Reference count for the covered range. */
	int refcnt;
};
72
/* Root of the range tree; stays NULL until ibv_fork_init() succeeds. */
static struct ibv_mem_node *mm_root = NULL;
/* Held by ibv_madvise_range() while it walks and edits the tree. */
static pthread_mutex_t mm_mutex = PTHREAD_MUTEX_INITIALIZER;
/* Value of sysconf(_SC_PAGESIZE), stored by ibv_fork_init(). */
static int page_size = 0;
/*
 * Set when a dontfork/dofork request arrives while mm_root is still
 * NULL; ibv_fork_init() returns EINVAL once this is set.
 */
static int too_late = 0;
77
78int ibv_fork_init(void)
79{
80#ifdef __linux__
81	void *tmp;
82	int ret;
83#endif
84
85	if (mm_root)
86		return 0;
87
88	if (too_late)
89		return EINVAL;
90
91	page_size = sysconf(_SC_PAGESIZE);
92	if (page_size < 0)
93		return errno;
94
95#ifdef __linux__
96	if (posix_memalign(&tmp, page_size, page_size))
97		return ENOMEM;
98
99	ret = madvise(tmp, page_size, MADV_DONTFORK) ||
100	      madvise(tmp, page_size, MADV_DOFORK);
101
102	free(tmp);
103
104	if (ret)
105		return ENOSYS;
106#endif
107
108	mm_root = malloc(sizeof *mm_root);
109	if (!mm_root)
110		return ENOMEM;
111
112	mm_root->parent = NULL;
113	mm_root->left   = NULL;
114	mm_root->right  = NULL;
115	mm_root->color  = IBV_BLACK;
116	mm_root->start  = 0;
117	mm_root->end    = UINTPTR_MAX;
118	mm_root->refcnt = 0;
119
120	return 0;
121}
122
123static struct ibv_mem_node *__mm_prev(struct ibv_mem_node *node)
124{
125	if (node->left) {
126		node = node->left;
127		while (node->right)
128			node = node->right;
129	} else {
130		while (node->parent && node == node->parent->left)
131			node = node->parent;
132
133		node = node->parent;
134	}
135
136	return node;
137}
138
139static struct ibv_mem_node *__mm_next(struct ibv_mem_node *node)
140{
141	if (node->right) {
142		node = node->right;
143		while (node->left)
144			node = node->left;
145	} else {
146		while (node->parent && node == node->parent->right)
147			node = node->parent;
148
149		node = node->parent;
150	}
151
152	return node;
153}
154
155static void __mm_rotate_right(struct ibv_mem_node *node)
156{
157	struct ibv_mem_node *tmp;
158
159	tmp = node->left;
160
161	node->left = tmp->right;
162	if (node->left)
163		node->left->parent = node;
164
165	if (node->parent) {
166		if (node->parent->right == node)
167			node->parent->right = tmp;
168		else
169			node->parent->left = tmp;
170	} else
171		mm_root = tmp;
172
173	tmp->parent = node->parent;
174
175	tmp->right = node;
176	node->parent = tmp;
177}
178
179static void __mm_rotate_left(struct ibv_mem_node *node)
180{
181	struct ibv_mem_node *tmp;
182
183	tmp = node->right;
184
185	node->right = tmp->left;
186	if (node->right)
187		node->right->parent = node;
188
189	if (node->parent) {
190		if (node->parent->right == node)
191			node->parent->right = tmp;
192		else
193			node->parent->left = tmp;
194	} else
195		mm_root = tmp;
196
197	tmp->parent = node->parent;
198
199	tmp->left = node;
200	node->parent = tmp;
201}
202
203static int verify(struct ibv_mem_node *node)
204{
205	int hl, hr;
206
207	if (!node)
208		return 1;
209
210	hl = verify(node->left);
211	hr = verify(node->left);
212
213	if (!hl || !hr)
214		return 0;
215	if (hl != hr)
216		return 0;
217
218	if (node->color == IBV_RED) {
219		if (node->left && node->left->color != IBV_BLACK)
220			return 0;
221		if (node->right && node->right->color != IBV_BLACK)
222			return 0;
223		return hl;
224	}
225
226	return hl + 1;
227}
228
229static void __mm_add_rebalance(struct ibv_mem_node *node)
230{
231	struct ibv_mem_node *parent, *gp, *uncle;
232
233	while (node->parent && node->parent->color == IBV_RED) {
234		parent = node->parent;
235		gp     = node->parent->parent;
236
237		if (parent == gp->left) {
238			uncle = gp->right;
239
240			if (uncle && uncle->color == IBV_RED) {
241				parent->color = IBV_BLACK;
242				uncle->color  = IBV_BLACK;
243				gp->color     = IBV_RED;
244
245				node = gp;
246			} else {
247				if (node == parent->right) {
248					__mm_rotate_left(parent);
249					node   = parent;
250					parent = node->parent;
251				}
252
253				parent->color = IBV_BLACK;
254				gp->color     = IBV_RED;
255
256				__mm_rotate_right(gp);
257			}
258		} else {
259			uncle = gp->left;
260
261			if (uncle && uncle->color == IBV_RED) {
262				parent->color = IBV_BLACK;
263				uncle->color  = IBV_BLACK;
264				gp->color     = IBV_RED;
265
266				node = gp;
267			} else {
268				if (node == parent->left) {
269					__mm_rotate_right(parent);
270					node   = parent;
271					parent = node->parent;
272				}
273
274				parent->color = IBV_BLACK;
275				gp->color     = IBV_RED;
276
277				__mm_rotate_left(gp);
278			}
279		}
280	}
281
282	mm_root->color = IBV_BLACK;
283}
284
285static void __mm_add(struct ibv_mem_node *new)
286{
287	struct ibv_mem_node *node, *parent = NULL;
288
289	node = mm_root;
290	while (node) {
291		parent = node;
292		if (node->start < new->start)
293			node = node->right;
294		else
295			node = node->left;
296	}
297
298	if (parent->start < new->start)
299		parent->right = new;
300	else
301		parent->left = new;
302
303	new->parent = parent;
304	new->left   = NULL;
305	new->right  = NULL;
306
307	new->color = IBV_RED;
308	__mm_add_rebalance(new);
309}
310
311static void __mm_remove(struct ibv_mem_node *node)
312{
313	struct ibv_mem_node *child, *parent, *sib, *tmp;
314	int nodecol;
315
316	if (node->left && node->right) {
317		tmp = node->left;
318		while (tmp->right)
319			tmp = tmp->right;
320
321		nodecol    = tmp->color;
322		child      = tmp->left;
323		tmp->color = node->color;
324
325		if (tmp->parent != node) {
326			parent        = tmp->parent;
327			parent->right = tmp->left;
328			if (tmp->left)
329				tmp->left->parent = parent;
330
331			tmp->left   	   = node->left;
332			node->left->parent = tmp;
333		} else
334			parent = tmp;
335
336		tmp->right          = node->right;
337		node->right->parent = tmp;
338
339		tmp->parent = node->parent;
340		if (node->parent) {
341			if (node->parent->left == node)
342				node->parent->left = tmp;
343			else
344				node->parent->right = tmp;
345		} else
346			mm_root = tmp;
347	} else {
348		nodecol = node->color;
349
350		child  = node->left ? node->left : node->right;
351		parent = node->parent;
352
353		if (child)
354			child->parent = parent;
355		if (parent) {
356			if (parent->left == node)
357				parent->left = child;
358			else
359				parent->right = child;
360		} else
361			mm_root = child;
362	}
363
364	free(node);
365
366	if (nodecol == IBV_RED)
367		return;
368
369	while ((!child || child->color == IBV_BLACK) && child != mm_root) {
370		if (parent->left == child) {
371			sib = parent->right;
372
373			if (sib->color == IBV_RED) {
374				parent->color = IBV_RED;
375				sib->color    = IBV_BLACK;
376				__mm_rotate_left(parent);
377				sib = parent->right;
378			}
379
380			if ((!sib->left  || sib->left->color  == IBV_BLACK) &&
381			    (!sib->right || sib->right->color == IBV_BLACK)) {
382				sib->color = IBV_RED;
383				child  = parent;
384				parent = child->parent;
385			} else {
386				if (!sib->right || sib->right->color == IBV_BLACK) {
387					if (sib->left)
388						sib->left->color = IBV_BLACK;
389					sib->color = IBV_RED;
390					__mm_rotate_right(sib);
391					sib = parent->right;
392				}
393
394				sib->color    = parent->color;
395				parent->color = IBV_BLACK;
396				if (sib->right)
397					sib->right->color = IBV_BLACK;
398				__mm_rotate_left(parent);
399				child = mm_root;
400				break;
401			}
402		} else {
403			sib = parent->left;
404
405			if (sib->color == IBV_RED) {
406				parent->color = IBV_RED;
407				sib->color    = IBV_BLACK;
408				__mm_rotate_right(parent);
409				sib = parent->left;
410			}
411
412			if ((!sib->left  || sib->left->color  == IBV_BLACK) &&
413			    (!sib->right || sib->right->color == IBV_BLACK)) {
414				sib->color = IBV_RED;
415				child  = parent;
416				parent = child->parent;
417			} else {
418				if (!sib->left || sib->left->color == IBV_BLACK) {
419					if (sib->right)
420						sib->right->color = IBV_BLACK;
421					sib->color = IBV_RED;
422					__mm_rotate_left(sib);
423					sib = parent->left;
424				}
425
426				sib->color    = parent->color;
427				parent->color = IBV_BLACK;
428				if (sib->left)
429					sib->left->color = IBV_BLACK;
430				__mm_rotate_right(parent);
431				child = mm_root;
432				break;
433			}
434		}
435	}
436
437	if (child)
438		child->color = IBV_BLACK;
439}
440
441static struct ibv_mem_node *__mm_find_start(uintptr_t start, uintptr_t end)
442{
443	struct ibv_mem_node *node = mm_root;
444
445	while (node) {
446		if (node->start <= start && node->end >= start)
447			break;
448
449		if (node->start < start)
450			node = node->right;
451		else
452			node = node->left;
453	}
454
455	return node;
456}
457
458static struct ibv_mem_node *merge_ranges(struct ibv_mem_node *node,
459					 struct ibv_mem_node *prev)
460{
461	prev->end = node->end;
462	prev->refcnt = node->refcnt;
463	__mm_remove(node);
464
465	return prev;
466}
467
468static struct ibv_mem_node *split_range(struct ibv_mem_node *node,
469					uintptr_t cut_line)
470{
471	struct ibv_mem_node *new_node = NULL;
472
473	new_node = malloc(sizeof *new_node);
474	if (!new_node)
475		return NULL;
476	new_node->start  = cut_line;
477	new_node->end    = node->end;
478	new_node->refcnt = node->refcnt;
479	node->end  = cut_line - 1;
480	__mm_add(new_node);
481
482	return new_node;
483}
484
485static struct ibv_mem_node *get_start_node(uintptr_t start, uintptr_t end,
486					   int inc)
487{
488	struct ibv_mem_node *node, *tmp = NULL;
489
490	node = __mm_find_start(start, end);
491	if (node->start < start)
492		node = split_range(node, start);
493	else {
494		tmp = __mm_prev(node);
495		if (tmp && tmp->refcnt == node->refcnt + inc)
496			node = merge_ranges(node, tmp);
497	}
498	return node;
499}
500
501/*
502 * This function is called if madvise() fails to undo merging/splitting
503 * operations performed on the node.
504 */
505static struct ibv_mem_node *undo_node(struct ibv_mem_node *node,
506				      uintptr_t start, int inc)
507{
508	struct ibv_mem_node *tmp = NULL;
509
510	/*
511	 * This condition can be true only if we merged this
512	 * node with the previous one, so we need to split them.
513	*/
514	if (start > node->start) {
515		tmp = split_range(node, start);
516		if (tmp) {
517			node->refcnt += inc;
518			node = tmp;
519		} else
520			return NULL;
521	}
522
523	tmp  =  __mm_prev(node);
524	if (tmp && tmp->refcnt == node->refcnt)
525		node = merge_ranges(node, tmp);
526
527	tmp  =  __mm_next(node);
528	if (tmp && tmp->refcnt == node->refcnt)
529		node = merge_ranges(tmp, node);
530
531	return node;
532}
533
534static int ibv_madvise_range(void *base, size_t size, int advice)
535{
536	uintptr_t start, end;
537	struct ibv_mem_node *node, *tmp;
538	int inc;
539	int rolling_back = 0;
540	int ret = 0;
541
542	if (!size)
543		return 0;
544
545	start = (uintptr_t) base & ~(page_size - 1);
546	end   = ((uintptr_t) (base + size + page_size - 1) &
547		 ~(page_size - 1)) - 1;
548
549	pthread_mutex_lock(&mm_mutex);
550again:
551	inc = advice == MADV_DONTFORK ? 1 : -1;
552
553	node = get_start_node(start, end, inc);
554	if (!node) {
555		ret = -1;
556		goto out;
557	}
558
559	while (node && node->start <= end) {
560		if (node->end > end) {
561			if (!split_range(node, end + 1)) {
562				ret = -1;
563				goto out;
564			}
565		}
566
567		if ((inc == -1 && node->refcnt == 1) ||
568		    (inc ==  1 && node->refcnt == 0)) {
569			/*
570			 * If this is the first time through the loop,
571			 * and we merged this node with the previous
572			 * one, then we only want to do the madvise()
573			 * on start ... node->end (rather than
574			 * starting at node->start).
575			 *
576			 * Otherwise we end up doing madvise() on
577			 * bigger region than we're being asked to,
578			 * and that may lead to a spurious failure.
579			 */
580			if (start > node->start)
581				ret = minherit((void *) start, node->end - start + 1,
582					      advice);
583			else
584				ret = minherit((void *) node->start,
585					      node->end - node->start + 1,
586					      advice);
587			if (ret) {
588				node = undo_node(node, start, inc);
589
590				if (rolling_back || !node)
591					goto out;
592
593				/* madvise failed, roll back previous changes */
594				rolling_back = 1;
595				advice = advice == MADV_DONTFORK ?
596					MADV_DOFORK : MADV_DONTFORK;
597				tmp = __mm_prev(node);
598				if (!tmp || start > tmp->end)
599					goto out;
600				end = tmp->end;
601				goto again;
602			}
603		}
604
605		node->refcnt += inc;
606		node = __mm_next(node);
607	}
608
609	if (node) {
610		tmp = __mm_prev(node);
611		if (tmp && node->refcnt == tmp->refcnt)
612			node = merge_ranges(node, tmp);
613	}
614
615out:
616	if (rolling_back)
617		ret = -1;
618
619	pthread_mutex_unlock(&mm_mutex);
620
621	return ret;
622}
623
624int ibv_dontfork_range(void *base, size_t size)
625{
626	if (mm_root)
627		return ibv_madvise_range(base, size, MADV_DONTFORK);
628	else {
629		too_late = 1;
630		return 0;
631	}
632}
633
634int ibv_dofork_range(void *base, size_t size)
635{
636	if (mm_root)
637		return ibv_madvise_range(base, size, MADV_DOFORK);
638	else {
639		too_late = 1;
640		return 0;
641	}
642}
643