/*
 * Copyright (c) 2004, 2005 Topspin Communications.  All rights reserved.
 * Copyright (c) 2006 Cisco Systems, Inc.  All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
33123474Swpaul
34123474Swpaul#include <config.h>
35123474Swpaul
36123474Swpaul#include <errno.h>
37123474Swpaul#include <sys/mman.h>
38123474Swpaul#include <unistd.h>
39123474Swpaul#include <stdlib.h>
40123474Swpaul#include <stdint.h>
41123474Swpaul#include <stdio.h>
42123474Swpaul#include <string.h>
43123474Swpaul#include <dirent.h>
44123474Swpaul#include <limits.h>
45123474Swpaul#include <inttypes.h>
46123474Swpaul
47123474Swpaul#include "ibverbs.h"
48123474Swpaul
/*
 * One node of the red-black tree that tracks address ranges.  Each
 * node describes the inclusive range [start, end] together with the
 * number of outstanding "don't fork" requests covering it.
 */
struct ibv_mem_node {
	/* red-black colour used by the rebalancing code */
	enum {
		IBV_RED,
		IBV_BLACK
	}			color;
	/* tree links; parent is NULL for the root */
	struct ibv_mem_node    *parent;
	struct ibv_mem_node    *left;
	struct ibv_mem_node    *right;
	/* first and last address of the range (both inclusive) */
	uintptr_t		start;
	uintptr_t		end;
	/* +1 per MADV_DONTFORK request, -1 per MADV_DOFORK request */
	int			refcnt;
};
59123474Swpaul
/* Root of the range tree; NULL until ibv_fork_init() has succeeded. */
static struct ibv_mem_node *mm_root;
/* Held by ibv_madvise_range() while it walks and modifies the tree. */
static pthread_mutex_t mm_mutex = PTHREAD_MUTEX_INITIALIZER;
/* Base page size obtained from sysconf(_SC_PAGESIZE) in ibv_fork_init(). */
static int page_size;
/* Non-zero once RDMAV_HUGEPAGES_SAFE has been seen in the environment. */
static int huge_page_enabled;
/*
 * Set when a dontfork/dofork range request arrives before the tree
 * exists; ibv_fork_init() then refuses to initialize.
 */
static int too_late;
65123474Swpaul
66124502Sobrienstatic unsigned long smaps_page_size(FILE *file)
67123474Swpaul{
68123474Swpaul	int n;
69123474Swpaul	unsigned long size = page_size;
70123474Swpaul	char buf[1024];
71123474Swpaul
72123474Swpaul	while (fgets(buf, sizeof(buf), file) != NULL) {
73123474Swpaul		if (!strstr(buf, "KernelPageSize:"))
74123474Swpaul			continue;
75123474Swpaul
76123474Swpaul		n = sscanf(buf, "%*s %lu", &size);
77123474Swpaul		if (n < 1)
78123474Swpaul			continue;
79123474Swpaul
80123474Swpaul		/* page size is printed in Kb */
81123474Swpaul		size = size * 1024;
82123474Swpaul
83123821Swpaul		break;
84123474Swpaul	}
85123474Swpaul
86123474Swpaul	return size;
87123474Swpaul}
88123474Swpaul
89123474Swpaulstatic unsigned long get_page_size(void *base)
90123474Swpaul{
91123474Swpaul	unsigned long ret = page_size;
92198786Srpaulo	pid_t pid;
93123474Swpaul	FILE *file;
94123474Swpaul	char buf[1024];
95123474Swpaul
96123474Swpaul	pid = getpid();
97123474Swpaul	snprintf(buf, sizeof(buf), "/proc/%d/smaps", pid);
98123474Swpaul
99123474Swpaul	file = fopen(buf, "r" STREAM_CLOEXEC);
100123474Swpaul	if (!file)
101123474Swpaul		goto out;
102123474Swpaul
103123474Swpaul	while (fgets(buf, sizeof(buf), file) != NULL) {
104123474Swpaul		int n;
105123474Swpaul		uintptr_t range_start, range_end;
106123821Swpaul
107123474Swpaul		n = sscanf(buf, "%" SCNxPTR "-%" SCNxPTR, &range_start, &range_end);
108123474Swpaul
109123474Swpaul		if (n < 2)
110123474Swpaul			continue;
111123474Swpaul
112123474Swpaul		if ((uintptr_t) base >= range_start && (uintptr_t) base < range_end) {
113123474Swpaul			ret = smaps_page_size(file);
114198786Srpaulo			break;
115123474Swpaul		}
116123474Swpaul	}
117198786Srpaulo
118123474Swpaul	fclose(file);
119123474Swpaul
120123474Swpaulout:
121123474Swpaul	return ret;
122123474Swpaul}
123123474Swpaul
124123474Swpaulint ibv_fork_init(void)
125123474Swpaul{
126123474Swpaul	void *tmp, *tmp_aligned;
127123474Swpaul	int ret;
128123474Swpaul	unsigned long size;
129123474Swpaul
130123474Swpaul	if (getenv("RDMAV_HUGEPAGES_SAFE"))
131123474Swpaul		huge_page_enabled = 1;
132123474Swpaul
133123474Swpaul	if (mm_root)
134123821Swpaul		return 0;
135198786Srpaulo
136123474Swpaul	if (too_late)
137123474Swpaul		return EINVAL;
138123474Swpaul
139123474Swpaul	page_size = sysconf(_SC_PAGESIZE);
140123474Swpaul	if (page_size < 0)
141123474Swpaul		return errno;
142123474Swpaul
143123474Swpaul	if (posix_memalign(&tmp, page_size, page_size))
144151703Swpaul		return ENOMEM;
145123474Swpaul
146198786Srpaulo	if (huge_page_enabled) {
147123474Swpaul		size = get_page_size(tmp);
148123474Swpaul		tmp_aligned = (void *) ((uintptr_t) tmp & ~(size - 1));
149123474Swpaul	} else {
150123474Swpaul		size = page_size;
151123474Swpaul		tmp_aligned = tmp;
152123474Swpaul	}
153123474Swpaul
154123474Swpaul	ret = madvise(tmp_aligned, size, MADV_DONTFORK) ||
155123474Swpaul	      madvise(tmp_aligned, size, MADV_DOFORK);
156123474Swpaul
157123474Swpaul	free(tmp);
158123474Swpaul
159123474Swpaul	if (ret)
160123474Swpaul		return ENOSYS;
161123474Swpaul
162123821Swpaul	mm_root = malloc(sizeof *mm_root);
163198786Srpaulo	if (!mm_root)
164123474Swpaul		return ENOMEM;
165123474Swpaul
166123474Swpaul	mm_root->parent = NULL;
167123474Swpaul	mm_root->left   = NULL;
168123474Swpaul	mm_root->right  = NULL;
169123474Swpaul	mm_root->color  = IBV_BLACK;
170123474Swpaul	mm_root->start  = 0;
171151703Swpaul	mm_root->end    = UINTPTR_MAX;
172151703Swpaul	mm_root->refcnt = 0;
173151703Swpaul
174151703Swpaul	return 0;
175151703Swpaul}
176151703Swpaul
177151703Swpaulstatic struct ibv_mem_node *__mm_prev(struct ibv_mem_node *node)
178151703Swpaul{
179123474Swpaul	if (node->left) {
180123474Swpaul		node = node->left;
181123474Swpaul		while (node->right)
182198786Srpaulo			node = node->right;
183123474Swpaul	} else {
184123474Swpaul		while (node->parent && node == node->parent->left)
185123474Swpaul			node = node->parent;
186123474Swpaul
187123474Swpaul		node = node->parent;
188123474Swpaul	}
189123474Swpaul
190123474Swpaul	return node;
191123474Swpaul}
192123474Swpaul
193123474Swpaulstatic struct ibv_mem_node *__mm_next(struct ibv_mem_node *node)
194123474Swpaul{
195123474Swpaul	if (node->right) {
196123474Swpaul		node = node->right;
197123474Swpaul		while (node->left)
198123474Swpaul			node = node->left;
199123821Swpaul	} else {
200198786Srpaulo		while (node->parent && node == node->parent->right)
201123474Swpaul			node = node->parent;
202123474Swpaul
203123474Swpaul		node = node->parent;
204123474Swpaul	}
205123474Swpaul
206123474Swpaul	return node;
207151703Swpaul}
208123474Swpaul
209123474Swpaulstatic void __mm_rotate_right(struct ibv_mem_node *node)
210123474Swpaul{
211198786Srpaulo	struct ibv_mem_node *tmp;
212123474Swpaul
213123474Swpaul	tmp = node->left;
214123474Swpaul
215123474Swpaul	node->left = tmp->right;
216123474Swpaul	if (node->left)
217123474Swpaul		node->left->parent = node;
218123474Swpaul
219123474Swpaul	if (node->parent) {
220123474Swpaul		if (node->parent->right == node)
221123474Swpaul			node->parent->right = tmp;
222123474Swpaul		else
223123474Swpaul			node->parent->left = tmp;
224123474Swpaul	} else
225198786Srpaulo		mm_root = tmp;
226123474Swpaul
227123474Swpaul	tmp->parent = node->parent;
228123474Swpaul
229123474Swpaul	tmp->right = node;
230123474Swpaul	node->parent = tmp;
231123474Swpaul}
232123474Swpaul
233123474Swpaulstatic void __mm_rotate_left(struct ibv_mem_node *node)
234123474Swpaul{
235123474Swpaul	struct ibv_mem_node *tmp;
236123474Swpaul
237123474Swpaul	tmp = node->right;
238123474Swpaul
239123474Swpaul	node->right = tmp->left;
240123474Swpaul	if (node->right)
241123474Swpaul		node->right->parent = node;
242198786Srpaulo
243123474Swpaul	if (node->parent) {
244123474Swpaul		if (node->parent->right == node)
245123474Swpaul			node->parent->right = tmp;
246123474Swpaul		else
247123474Swpaul			node->parent->left = tmp;
248123474Swpaul	} else
249123474Swpaul		mm_root = tmp;
250123474Swpaul
251123474Swpaul	tmp->parent = node->parent;
252123474Swpaul
253123474Swpaul	tmp->left = node;
254123474Swpaul	node->parent = tmp;
255123474Swpaul}
256123474Swpaul
#if 0
/*
 * Debug helper (compiled out): check the red-black invariants of the
 * subtree rooted at @node.
 *
 * Returns 0 if a violation is found (unequal black heights in the two
 * subtrees, or a red node with a red child).  Otherwise returns the
 * black height of the subtree, where an empty subtree counts as 1.
 */
static int verify(struct ibv_mem_node *node)
{
	int hl, hr;

	if (!node)
		return 1;

	hl = verify(node->left);
	/* must descend into the right subtree, not the left one again */
	hr = verify(node->right);

	if (!hl || !hr)
		return 0;
	if (hl != hr)
		return 0;

	if (node->color == IBV_RED) {
		if (node->left && node->left->color != IBV_BLACK)
			return 0;
		if (node->right && node->right->color != IBV_BLACK)
			return 0;
		return hl;
	}

	return hl + 1;
}
#endif
284123474Swpaul
285123474Swpaulstatic void __mm_add_rebalance(struct ibv_mem_node *node)
286123474Swpaul{
287123474Swpaul	struct ibv_mem_node *parent, *gp, *uncle;
288123474Swpaul
289151703Swpaul	while (node->parent && node->parent->color == IBV_RED) {
290123474Swpaul		parent = node->parent;
291123474Swpaul		gp     = node->parent->parent;
292123474Swpaul
293123474Swpaul		if (parent == gp->left) {
294123474Swpaul			uncle = gp->right;
295123474Swpaul
296123474Swpaul			if (uncle && uncle->color == IBV_RED) {
297123474Swpaul				parent->color = IBV_BLACK;
298123474Swpaul				uncle->color  = IBV_BLACK;
299123474Swpaul				gp->color     = IBV_RED;
300123474Swpaul
301123474Swpaul				node = gp;
302123474Swpaul			} else {
303123474Swpaul				if (node == parent->right) {
304123474Swpaul					__mm_rotate_left(parent);
305142037Swpaul					node   = parent;
306123474Swpaul					parent = node->parent;
307142037Swpaul				}
308123474Swpaul
309123474Swpaul				parent->color = IBV_BLACK;
310123474Swpaul				gp->color     = IBV_RED;
311123474Swpaul
312123474Swpaul				__mm_rotate_right(gp);
313198786Srpaulo			}
314123474Swpaul		} else {
315198786Srpaulo			uncle = gp->left;
316123474Swpaul
317123474Swpaul			if (uncle && uncle->color == IBV_RED) {
318123474Swpaul				parent->color = IBV_BLACK;
319123474Swpaul				uncle->color  = IBV_BLACK;
320123474Swpaul				gp->color     = IBV_RED;
321123474Swpaul
322123474Swpaul				node = gp;
323123474Swpaul			} else {
324123474Swpaul				if (node == parent->left) {
325123474Swpaul					__mm_rotate_right(parent);
326123474Swpaul					node   = parent;
327123474Swpaul					parent = node->parent;
328123474Swpaul				}
329123474Swpaul
330123474Swpaul				parent->color = IBV_BLACK;
331123474Swpaul				gp->color     = IBV_RED;
332123474Swpaul
333123474Swpaul				__mm_rotate_left(gp);
334123474Swpaul			}
335123474Swpaul		}
336123474Swpaul	}
337123821Swpaul
338198786Srpaulo	mm_root->color = IBV_BLACK;
339123474Swpaul}
340123474Swpaul
341123474Swpaulstatic void __mm_add(struct ibv_mem_node *new)
342123474Swpaul{
343123474Swpaul	struct ibv_mem_node *node, *parent = NULL;
344123474Swpaul
345123474Swpaul	node = mm_root;
346123474Swpaul	while (node) {
347151703Swpaul		parent = node;
348123474Swpaul		if (node->start < new->start)
349123474Swpaul			node = node->right;
350123474Swpaul		else
351123474Swpaul			node = node->left;
352123474Swpaul	}
353198786Srpaulo
354123474Swpaul	if (parent->start < new->start)
355123474Swpaul		parent->right = new;
356123474Swpaul	else
357123474Swpaul		parent->left = new;
358123474Swpaul
359123474Swpaul	new->parent = parent;
360123474Swpaul	new->left   = NULL;
361123474Swpaul	new->right  = NULL;
362123474Swpaul
363123474Swpaul	new->color = IBV_RED;
364123474Swpaul	__mm_add_rebalance(new);
365123474Swpaul}
366123474Swpaul
367123474Swpaulstatic void __mm_remove(struct ibv_mem_node *node)
368123474Swpaul{
369123474Swpaul	struct ibv_mem_node *child, *parent, *sib, *tmp;
370123474Swpaul	int nodecol;
371123474Swpaul
372123474Swpaul	if (node->left && node->right) {
373123474Swpaul		tmp = node->left;
374123474Swpaul		while (tmp->right)
375141963Swpaul			tmp = tmp->right;
376141963Swpaul
377141963Swpaul		nodecol    = tmp->color;
378141963Swpaul		child      = tmp->left;
379123474Swpaul		tmp->color = node->color;
380123474Swpaul
381123474Swpaul		if (tmp->parent != node) {
382123474Swpaul			parent        = tmp->parent;
383123474Swpaul			parent->right = tmp->left;
384123474Swpaul			if (tmp->left)
385123474Swpaul				tmp->left->parent = parent;
386123474Swpaul
387123474Swpaul			tmp->left   	   = node->left;
388123474Swpaul			node->left->parent = tmp;
389123474Swpaul		} else
390123474Swpaul			parent = tmp;
391123474Swpaul
392123474Swpaul		tmp->right          = node->right;
393123474Swpaul		node->right->parent = tmp;
394123474Swpaul
395123474Swpaul		tmp->parent = node->parent;
396123474Swpaul		if (node->parent) {
397123474Swpaul			if (node->parent->left == node)
398123474Swpaul				node->parent->left = tmp;
399123474Swpaul			else
400123474Swpaul				node->parent->right = tmp;
401123474Swpaul		} else
402123474Swpaul			mm_root = tmp;
403123474Swpaul	} else {
404123474Swpaul		nodecol = node->color;
405123474Swpaul
406123474Swpaul		child  = node->left ? node->left : node->right;
407123474Swpaul		parent = node->parent;
408123474Swpaul
409123474Swpaul		if (child)
410123474Swpaul			child->parent = parent;
411123474Swpaul		if (parent) {
412123474Swpaul			if (parent->left == node)
413123474Swpaul				parent->left = child;
414123474Swpaul			else
415141963Swpaul				parent->right = child;
416141963Swpaul		} else
417141963Swpaul			mm_root = child;
418141963Swpaul	}
419141963Swpaul
420189488Sweongyo	free(node);
421141963Swpaul
422123474Swpaul	if (nodecol == IBV_RED)
423198786Srpaulo		return;
424123474Swpaul
425123474Swpaul	while ((!child || child->color == IBV_BLACK) && child != mm_root) {
426123474Swpaul		if (parent->left == child) {
427123474Swpaul			sib = parent->right;
428123474Swpaul
429123474Swpaul			if (sib->color == IBV_RED) {
430123474Swpaul				parent->color = IBV_RED;
431123474Swpaul				sib->color    = IBV_BLACK;
432198786Srpaulo				__mm_rotate_left(parent);
433123474Swpaul				sib = parent->right;
434123474Swpaul			}
435123474Swpaul
436123474Swpaul			if ((!sib->left  || sib->left->color  == IBV_BLACK) &&
437123474Swpaul			    (!sib->right || sib->right->color == IBV_BLACK)) {
438123474Swpaul				sib->color = IBV_RED;
439123474Swpaul				child  = parent;
440142387Swpaul				parent = child->parent;
441142387Swpaul			} else {
442123474Swpaul				if (!sib->right || sib->right->color == IBV_BLACK) {
443123474Swpaul					if (sib->left)
444123474Swpaul						sib->left->color = IBV_BLACK;
445123474Swpaul					sib->color = IBV_RED;
446123474Swpaul					__mm_rotate_right(sib);
447123474Swpaul					sib = parent->right;
448123474Swpaul				}
449189488Sweongyo
450123474Swpaul				sib->color    = parent->color;
451123474Swpaul				parent->color = IBV_BLACK;
452123474Swpaul				if (sib->right)
453123474Swpaul					sib->right->color = IBV_BLACK;
454123821Swpaul				__mm_rotate_left(parent);
455198786Srpaulo				child = mm_root;
456123474Swpaul				break;
457123474Swpaul			}
458123474Swpaul		} else {
459123474Swpaul			sib = parent->left;
460123474Swpaul
461123474Swpaul			if (sib->color == IBV_RED) {
462123474Swpaul				parent->color = IBV_RED;
463123474Swpaul				sib->color    = IBV_BLACK;
464123474Swpaul				__mm_rotate_right(parent);
465123474Swpaul				sib = parent->left;
466142387Swpaul			}
467123474Swpaul
468123474Swpaul			if ((!sib->left  || sib->left->color  == IBV_BLACK) &&
469198786Srpaulo			    (!sib->right || sib->right->color == IBV_BLACK)) {
470123474Swpaul				sib->color = IBV_RED;
471123474Swpaul				child  = parent;
472123474Swpaul				parent = child->parent;
473123474Swpaul			} else {
474123474Swpaul				if (!sib->left || sib->left->color == IBV_BLACK) {
475123474Swpaul					if (sib->right)
476123474Swpaul						sib->right->color = IBV_BLACK;
477124165Swpaul					sib->color = IBV_RED;
478124165Swpaul					__mm_rotate_left(sib);
479124165Swpaul					sib = parent->left;
480124165Swpaul				}
481124165Swpaul
482124165Swpaul				sib->color    = parent->color;
483124165Swpaul				parent->color = IBV_BLACK;
484124165Swpaul				if (sib->left)
485124165Swpaul					sib->left->color = IBV_BLACK;
486124165Swpaul				__mm_rotate_right(parent);
487124165Swpaul				child = mm_root;
488124165Swpaul				break;
489198786Srpaulo			}
490124165Swpaul		}
491124165Swpaul	}
492124165Swpaul
493124165Swpaul	if (child)
494124165Swpaul		child->color = IBV_BLACK;
495124165Swpaul}
496124165Swpaul
497124165Swpaulstatic struct ibv_mem_node *__mm_find_start(uintptr_t start, uintptr_t end)
498124165Swpaul{
499124165Swpaul	struct ibv_mem_node *node = mm_root;
500124165Swpaul
501124165Swpaul	while (node) {
502124165Swpaul		if (node->start <= start && node->end >= start)
503124165Swpaul			break;
504124165Swpaul
505124165Swpaul		if (node->start < start)
506124165Swpaul			node = node->right;
507124165Swpaul		else
508124165Swpaul			node = node->left;
509124165Swpaul	}
510124165Swpaul
511124165Swpaul	return node;
512124165Swpaul}
513124165Swpaul
514124165Swpaulstatic struct ibv_mem_node *merge_ranges(struct ibv_mem_node *node,
515124165Swpaul					 struct ibv_mem_node *prev)
516124165Swpaul{
517198786Srpaulo	prev->end = node->end;
518124165Swpaul	prev->refcnt = node->refcnt;
519124165Swpaul	__mm_remove(node);
520198786Srpaulo
521124165Swpaul	return prev;
522124165Swpaul}
523124165Swpaul
524124173Swpaulstatic struct ibv_mem_node *split_range(struct ibv_mem_node *node,
525124165Swpaul					uintptr_t cut_line)
526124165Swpaul{
527124165Swpaul	struct ibv_mem_node *new_node = NULL;
528124165Swpaul
529124173Swpaul	new_node = malloc(sizeof *new_node);
530124165Swpaul	if (!new_node)
531124165Swpaul		return NULL;
532124165Swpaul	new_node->start  = cut_line;
533124165Swpaul	new_node->end    = node->end;
534124173Swpaul	new_node->refcnt = node->refcnt;
535124165Swpaul	node->end  = cut_line - 1;
536124165Swpaul	__mm_add(new_node);
537124165Swpaul
538124165Swpaul	return new_node;
539198786Srpaulo}
540124165Swpaul
541124165Swpaulstatic struct ibv_mem_node *get_start_node(uintptr_t start, uintptr_t end,
542124165Swpaul					   int inc)
543124165Swpaul{
544124165Swpaul	struct ibv_mem_node *node, *tmp = NULL;
545124165Swpaul
546124165Swpaul	node = __mm_find_start(start, end);
547124165Swpaul	if (node->start < start)
548124165Swpaul		node = split_range(node, start);
549124165Swpaul	else {
550124165Swpaul		tmp = __mm_prev(node);
551124173Swpaul		if (tmp && tmp->refcnt == node->refcnt + inc)
552124165Swpaul			node = merge_ranges(node, tmp);
553124173Swpaul	}
554198786Srpaulo	return node;
555124165Swpaul}
556124165Swpaul
557124165Swpaul/*
558124165Swpaul * This function is called if madvise() fails to undo merging/splitting
559198786Srpaulo * operations performed on the node.
560124165Swpaul */
561124165Swpaulstatic struct ibv_mem_node *undo_node(struct ibv_mem_node *node,
562123474Swpaul				      uintptr_t start, int inc)
563123474Swpaul{
564123474Swpaul	struct ibv_mem_node *tmp = NULL;
565123474Swpaul
566123474Swpaul	/*
567123474Swpaul	 * This condition can be true only if we merged this
568123474Swpaul	 * node with the previous one, so we need to split them.
569123474Swpaul	*/
570123474Swpaul	if (start > node->start) {
571123474Swpaul		tmp = split_range(node, start);
572123474Swpaul		if (tmp) {
573123474Swpaul			node->refcnt += inc;
574123474Swpaul			node = tmp;
575123474Swpaul		} else
576198786Srpaulo			return NULL;
577123474Swpaul	}
578123474Swpaul
579123474Swpaul	tmp  =  __mm_prev(node);
580123474Swpaul	if (tmp && tmp->refcnt == node->refcnt)
581123474Swpaul		node = merge_ranges(node, tmp);
582198786Srpaulo
583123474Swpaul	tmp  =  __mm_next(node);
584123474Swpaul	if (tmp && tmp->refcnt == node->refcnt)
585198786Srpaulo		node = merge_ranges(tmp, node);
586141963Swpaul
587141963Swpaul	return node;
588141963Swpaul}
589141963Swpaul
590141963Swpaulstatic int ibv_madvise_range(void *base, size_t size, int advice)
591141963Swpaul{
592141963Swpaul	uintptr_t start, end;
593141963Swpaul	struct ibv_mem_node *node, *tmp;
594198786Srpaulo	int inc;
595123474Swpaul	int rolling_back = 0;
596123474Swpaul	int ret = 0;
597123474Swpaul	unsigned long range_page_size;
598123474Swpaul
599123474Swpaul	if (!size)
600123474Swpaul		return 0;
601123474Swpaul
602123474Swpaul	if (huge_page_enabled)
603123474Swpaul		range_page_size = get_page_size(base);
604123474Swpaul	else
605123474Swpaul		range_page_size = page_size;
606123474Swpaul
607123474Swpaul	start = (uintptr_t) base & ~(range_page_size - 1);
608123474Swpaul	end   = ((uintptr_t) (base + size + range_page_size - 1) &
609123474Swpaul		 ~(range_page_size - 1)) - 1;
610123474Swpaul
611123474Swpaul	pthread_mutex_lock(&mm_mutex);
612123474Swpaulagain:
613123474Swpaul	inc = advice == MADV_DONTFORK ? 1 : -1;
614123474Swpaul
615123474Swpaul	node = get_start_node(start, end, inc);
616123474Swpaul	if (!node) {
617123821Swpaul		ret = -1;
618198786Srpaulo		goto out;
619123474Swpaul	}
620123474Swpaul
621198786Srpaulo	while (node && node->start <= end) {
622123474Swpaul		if (node->end > end) {
623123474Swpaul			if (!split_range(node, end + 1)) {
624123474Swpaul				ret = -1;
625123474Swpaul				goto out;
626123474Swpaul			}
627123474Swpaul		}
628123848Swpaul
629123474Swpaul		if ((inc == -1 && node->refcnt == 1) ||
630123474Swpaul		    (inc ==  1 && node->refcnt == 0)) {
631123474Swpaul			/*
632123474Swpaul			 * If this is the first time through the loop,
633123474Swpaul			 * and we merged this node with the previous
634123474Swpaul			 * one, then we only want to do the madvise()
635198786Srpaulo			 * on start ... node->end (rather than
636123474Swpaul			 * starting at node->start).
637123474Swpaul			 *
638123474Swpaul			 * Otherwise we end up doing madvise() on
639123474Swpaul			 * bigger region than we're being asked to,
640123474Swpaul			 * and that may lead to a spurious failure.
641198786Srpaulo			 */
642123474Swpaul			if (start > node->start)
643				ret = madvise((void *) start, node->end - start + 1,
644					      advice);
645			else
646				ret = madvise((void *) node->start,
647					      node->end - node->start + 1,
648					      advice);
649			if (ret) {
650				node = undo_node(node, start, inc);
651
652				if (rolling_back || !node)
653					goto out;
654
655				/* madvise failed, roll back previous changes */
656				rolling_back = 1;
657				advice = advice == MADV_DONTFORK ?
658					MADV_DOFORK : MADV_DONTFORK;
659				tmp = __mm_prev(node);
660				if (!tmp || start > tmp->end)
661					goto out;
662				end = tmp->end;
663				goto again;
664			}
665		}
666
667		node->refcnt += inc;
668		node = __mm_next(node);
669	}
670
671	if (node) {
672		tmp = __mm_prev(node);
673		if (tmp && node->refcnt == tmp->refcnt)
674			node = merge_ranges(node, tmp);
675	}
676
677out:
678	if (rolling_back)
679		ret = -1;
680
681	pthread_mutex_unlock(&mm_mutex);
682
683	return ret;
684}
685
686int ibv_dontfork_range(void *base, size_t size)
687{
688	if (mm_root)
689		return ibv_madvise_range(base, size, MADV_DONTFORK);
690	else {
691		too_late = 1;
692		return 0;
693	}
694}
695
696int ibv_dofork_range(void *base, size_t size)
697{
698	if (mm_root)
699		return ibv_madvise_range(base, size, MADV_DOFORK);
700	else {
701		too_late = 1;
702		return 0;
703	}
704}
705