/*
 * Copyright (c) 2014, Cisco Systems, Inc. All rights reserved.
 *
 * This program is free software; you may redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2 of the License.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 */
18
19 #include <linux/init.h>
20 #include <linux/list.h>
21 #include <linux/slab.h>
22 #include <linux/list_sort.h>
23
24 #include <linux/interval_tree_generic.h>
25 #include "usnic_uiom_interval_tree.h"
26
/* Interval bound accessors, consumed by INTERVAL_TREE_DEFINE() below. */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * Allocate a node covering [start, end] into @node. If the allocation
 * fails, set @err to -ENOMEM and jump to the caller's @err_out label.
 * @node must be a plain lvalue; it is evaluated more than once.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
							(ref_cnt),	\
							(flags));	\
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* Append @node to the tail of @list through its ->link member. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/* MAKE_NODE() followed by MARK_FOR_ADD(); same error behaviour. */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
			     err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
			  err_out);					\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* True when @flags1 and @flags2 agree on every bit that is set in @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
53
54 static struct usnic_uiom_interval_node*
usnic_uiom_interval_node_alloc(long int start,long int last,int ref_cnt,int flags)55 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
56 int flags)
57 {
58 struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
59 GFP_ATOMIC);
60 if (!interval)
61 return NULL;
62
63 interval->start = start;
64 interval->last = last;
65 interval->flags = flags;
66 interval->ref_cnt = ref_cnt;
67
68 return interval;
69 }
70
interval_cmp(void * priv,struct list_head * a,struct list_head * b)71 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
72 {
73 struct usnic_uiom_interval_node *node_a, *node_b;
74
75 node_a = list_entry(a, struct usnic_uiom_interval_node, link);
76 node_b = list_entry(b, struct usnic_uiom_interval_node, link);
77
78 /* long to int */
79 if (node_a->start < node_b->start)
80 return -1;
81 else if (node_a->start > node_b->start)
82 return 1;
83
84 return 0;
85 }
86
87 static void
find_intervals_intersection_sorted(struct rb_root * root,unsigned long start,unsigned long last,struct list_head * list)88 find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
89 unsigned long last,
90 struct list_head *list)
91 {
92 struct usnic_uiom_interval_node *node;
93
94 INIT_LIST_HEAD(list);
95
96 for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
97 node;
98 node = usnic_uiom_interval_tree_iter_next(node, start, last))
99 list_add_tail(&node->link, list);
100
101 list_sort(NULL, list, interval_cmp);
102 }
103
usnic_uiom_get_intervals_diff(unsigned long start,unsigned long last,int flags,int flag_mask,struct rb_root * root,struct list_head * diff_set)104 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
105 int flags, int flag_mask,
106 struct rb_root *root,
107 struct list_head *diff_set)
108 {
109 struct usnic_uiom_interval_node *interval, *tmp;
110 int err = 0;
111 long int pivot = start;
112 LIST_HEAD(intersection_set);
113
114 INIT_LIST_HEAD(diff_set);
115
116 find_intervals_intersection_sorted(root, start, last,
117 &intersection_set);
118
119 list_for_each_entry(interval, &intersection_set, link) {
120 if (pivot < interval->start) {
121 MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
122 1, flags, err, err_out,
123 diff_set);
124 pivot = interval->start;
125 }
126
127 /*
128 * Invariant: Set [start, pivot] is either in diff_set or root,
129 * but not in both.
130 */
131
132 if (pivot > interval->last) {
133 continue;
134 } else if (pivot <= interval->last &&
135 FLAGS_EQUAL(interval->flags, flags,
136 flag_mask)) {
137 pivot = interval->last + 1;
138 }
139 }
140
141 if (pivot <= last)
142 MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
143 diff_set);
144
145 return 0;
146
147 err_out:
148 list_for_each_entry_safe(interval, tmp, diff_set, link) {
149 list_del(&interval->link);
150 kfree(interval);
151 }
152
153 return err;
154 }
155
usnic_uiom_put_interval_set(struct list_head * intervals)156 void usnic_uiom_put_interval_set(struct list_head *intervals)
157 {
158 struct usnic_uiom_interval_node *interval, *tmp;
159 list_for_each_entry_safe(interval, tmp, intervals, link)
160 kfree(interval);
161 }
162
usnic_uiom_insert_interval(struct rb_root * root,unsigned long start,unsigned long last,int flags)163 int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
164 unsigned long last, int flags)
165 {
166 struct usnic_uiom_interval_node *interval, *tmp;
167 unsigned long istart, ilast;
168 int iref_cnt, iflags;
169 unsigned long lpivot = start;
170 int err = 0;
171 LIST_HEAD(to_add);
172 LIST_HEAD(intersection_set);
173
174 find_intervals_intersection_sorted(root, start, last,
175 &intersection_set);
176
177 list_for_each_entry(interval, &intersection_set, link) {
178 /*
179 * Invariant - lpivot is the left edge of next interval to be
180 * inserted
181 */
182 istart = interval->start;
183 ilast = interval->last;
184 iref_cnt = interval->ref_cnt;
185 iflags = interval->flags;
186
187 if (istart < lpivot) {
188 MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
189 iflags, err, err_out, &to_add);
190 } else if (istart > lpivot) {
191 MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
192 err, err_out, &to_add);
193 lpivot = istart;
194 } else {
195 lpivot = istart;
196 }
197
198 if (ilast > last) {
199 MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
200 iflags | flags, err, err_out,
201 &to_add);
202 MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
203 iflags, err, err_out, &to_add);
204 } else {
205 MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
206 iflags | flags, err, err_out,
207 &to_add);
208 }
209
210 lpivot = ilast + 1;
211 }
212
213 if (lpivot <= last)
214 MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
215 &to_add);
216
217 list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
218 usnic_uiom_interval_tree_remove(interval, root);
219 kfree(interval);
220 }
221
222 list_for_each_entry(interval, &to_add, link)
223 usnic_uiom_interval_tree_insert(interval, root);
224
225 return 0;
226
227 err_out:
228 list_for_each_entry_safe(interval, tmp, &to_add, link)
229 kfree(interval);
230
231 return err;
232 }
233
usnic_uiom_remove_interval(struct rb_root * root,unsigned long start,unsigned long last,struct list_head * removed)234 void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
235 unsigned long last, struct list_head *removed)
236 {
237 struct usnic_uiom_interval_node *interval;
238
239 for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
240 interval;
241 interval = usnic_uiom_interval_tree_iter_next(interval,
242 start,
243 last)) {
244 if (--interval->ref_cnt == 0)
245 list_add_tail(&interval->link, removed);
246 }
247
248 list_for_each_entry(interval, removed, link)
249 usnic_uiom_interval_tree_remove(interval, root);
250 }
251
252 INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
253 unsigned long, __subtree_last,
254 START, LAST, , usnic_uiom_interval_tree)
255