1/*
2 * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
3 *
4 * This software is available to you under a choice of one of two
5 * licenses.  You may choose to be licensed under the terms of the GNU
6 * General Public License (GPL) Version 2, available from the file
7 * COPYING in the main directory of this source tree, or the
8 * BSD license below:
9 *
10 *     Redistribution and use in source and binary forms, with or
11 *     without modification, are permitted provided that the following
12 *     conditions are met:
13 *
14 *      - Redistributions of source code must retain the above
15 *        copyright notice, this list of conditions and the following
16 *        disclaimer.
17 *
18 *      - Redistributions in binary form must reproduce the above
19 *        copyright notice, this list of conditions and the following
20 *        disclaimer in the documentation and/or other materials
21 *        provided with the distribution.
22 *
23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 * SOFTWARE.
31 *
32 */
33
34#include <linux/init.h>
35#include <linux/list.h>
36#include <linux/slab.h>
37#include <linux/list_sort.h>
38
39#include <linux/interval_tree_generic.h>
40#include "usnic_uiom_interval_tree.h"
41
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * Allocate an interval node for [start, end] into @node.  On allocation
 * failure set @err to -ENOMEM and jump to the label named by @err_out.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
						(ref_cnt), (flags));	\
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* Queue @node on @list through its ->link member. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/* MAKE_NODE() followed by MARK_FOR_ADD() of the new node onto @list. */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
			  err_out);					\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* True when @flags1 and @flags2 agree on every bit set in @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
68
69static struct usnic_uiom_interval_node*
70usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
71				int flags)
72{
73	struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
74								GFP_ATOMIC);
75	if (!interval)
76		return NULL;
77
78	interval->start = start;
79	interval->last = last;
80	interval->flags = flags;
81	interval->ref_cnt = ref_cnt;
82
83	return interval;
84}
85
86static int interval_cmp(void *priv, const struct list_head *a,
87			const struct list_head *b)
88{
89	struct usnic_uiom_interval_node *node_a, *node_b;
90
91	node_a = list_entry(a, struct usnic_uiom_interval_node, link);
92	node_b = list_entry(b, struct usnic_uiom_interval_node, link);
93
94	/* long to int */
95	if (node_a->start < node_b->start)
96		return -1;
97	else if (node_a->start > node_b->start)
98		return 1;
99
100	return 0;
101}
102
103static void
104find_intervals_intersection_sorted(struct rb_root_cached *root,
105				   unsigned long start, unsigned long last,
106				   struct list_head *list)
107{
108	struct usnic_uiom_interval_node *node;
109
110	INIT_LIST_HEAD(list);
111
112	for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
113		node;
114		node = usnic_uiom_interval_tree_iter_next(node, start, last))
115		list_add_tail(&node->link, list);
116
117	list_sort(NULL, list, interval_cmp);
118}
119
120int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
121					int flags, int flag_mask,
122					struct rb_root_cached *root,
123					struct list_head *diff_set)
124{
125	struct usnic_uiom_interval_node *interval, *tmp;
126	int err = 0;
127	long int pivot = start;
128	LIST_HEAD(intersection_set);
129
130	INIT_LIST_HEAD(diff_set);
131
132	find_intervals_intersection_sorted(root, start, last,
133						&intersection_set);
134
135	list_for_each_entry(interval, &intersection_set, link) {
136		if (pivot < interval->start) {
137			MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
138						1, flags, err, err_out,
139						diff_set);
140			pivot = interval->start;
141		}
142
143		/*
144		 * Invariant: Set [start, pivot] is either in diff_set or root,
145		 * but not in both.
146		 */
147
148		if (pivot > interval->last) {
149			continue;
150		} else if (pivot <= interval->last &&
151				FLAGS_EQUAL(interval->flags, flags,
152				flag_mask)) {
153			pivot = interval->last + 1;
154		}
155	}
156
157	if (pivot <= last)
158		MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
159					diff_set);
160
161	return 0;
162
163err_out:
164	list_for_each_entry_safe(interval, tmp, diff_set, link) {
165		list_del(&interval->link);
166		kfree(interval);
167	}
168
169	return err;
170}
171
172void usnic_uiom_put_interval_set(struct list_head *intervals)
173{
174	struct usnic_uiom_interval_node *interval, *tmp;
175	list_for_each_entry_safe(interval, tmp, intervals, link)
176		kfree(interval);
177}
178
179int usnic_uiom_insert_interval(struct rb_root_cached *root, unsigned long start,
180				unsigned long last, int flags)
181{
182	struct usnic_uiom_interval_node *interval, *tmp;
183	unsigned long istart, ilast;
184	int iref_cnt, iflags;
185	unsigned long lpivot = start;
186	int err = 0;
187	LIST_HEAD(to_add);
188	LIST_HEAD(intersection_set);
189
190	find_intervals_intersection_sorted(root, start, last,
191						&intersection_set);
192
193	list_for_each_entry(interval, &intersection_set, link) {
194		/*
195		 * Invariant - lpivot is the left edge of next interval to be
196		 * inserted
197		 */
198		istart = interval->start;
199		ilast = interval->last;
200		iref_cnt = interval->ref_cnt;
201		iflags = interval->flags;
202
203		if (istart < lpivot) {
204			MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
205						iflags, err, err_out, &to_add);
206		} else if (istart > lpivot) {
207			MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
208						err, err_out, &to_add);
209			lpivot = istart;
210		} else {
211			lpivot = istart;
212		}
213
214		if (ilast > last) {
215			MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
216						iflags | flags, err, err_out,
217						&to_add);
218			MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
219						iflags, err, err_out, &to_add);
220		} else {
221			MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
222						iflags | flags, err, err_out,
223						&to_add);
224		}
225
226		lpivot = ilast + 1;
227	}
228
229	if (lpivot <= last)
230		MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
231					&to_add);
232
233	list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
234		usnic_uiom_interval_tree_remove(interval, root);
235		kfree(interval);
236	}
237
238	list_for_each_entry(interval, &to_add, link)
239		usnic_uiom_interval_tree_insert(interval, root);
240
241	return 0;
242
243err_out:
244	list_for_each_entry_safe(interval, tmp, &to_add, link)
245		kfree(interval);
246
247	return err;
248}
249
250void usnic_uiom_remove_interval(struct rb_root_cached *root,
251				unsigned long start, unsigned long last,
252				struct list_head *removed)
253{
254	struct usnic_uiom_interval_node *interval;
255
256	for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
257			interval;
258			interval = usnic_uiom_interval_tree_iter_next(interval,
259									start,
260									last)) {
261		if (--interval->ref_cnt == 0)
262			list_add_tail(&interval->link, removed);
263	}
264
265	list_for_each_entry(interval, removed, link)
266		usnic_uiom_interval_tree_remove(interval, root);
267}
268
269INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
270			unsigned long, __subtree_last,
271			START, LAST, , usnic_uiom_interval_tree)
272