This source file includes the following definitions:
- extract_bit
- longest_prefix_match
- trie_lookup_elem
- lpm_trie_node_alloc
- trie_update_elem
- trie_delete_elem
- trie_alloc
- trie_free
- trie_get_next_key
- trie_check_btf

#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/err.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/vmalloc.h>
#include <net/ipv6.h>
#include <uapi/linux/btf.h>

/* Intermediate node */
#define LPM_TREE_NODE_FLAG_IM BIT(0)

struct lpm_trie_node;

struct lpm_trie_node {
	struct rcu_head rcu;
	struct lpm_trie_node __rcu	*child[2];
	u32				prefixlen;
	u32				flags;
	u8				data[0];
};

struct lpm_trie {
	struct bpf_map			map;
	struct lpm_trie_node __rcu	*root;
	size_t				n_entries;
	size_t				max_prefixlen;
	size_t				data_size;
	raw_spinlock_t			lock;
};

/*
 * This trie implements a longest-prefix-match algorithm over big-endian
 * byte strings, such as IP addresses. Entries are kept in an unbalanced
 * binary trie: each node stores a prefix of up to max_prefixlen bits and,
 * for real entries, the mapped value directly behind its prefix data.
 * Where two stored prefixes diverge, an intermediate node flagged with
 * LPM_TREE_NODE_FLAG_IM is inserted to fork the trie; such nodes carry
 * no value and are never returned by lookups. A walk descends from the
 * root by testing, at each node, the key bit at that node's prefixlen.
 */

/* extract_bit() - extract the n-th bit of @data, MSB-first per byte */
static inline int extract_bit(const u8 *data, size_t index)
{
	return !!(data[index / 8] & (1 << (7 - (index % 8))));
}
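
/* Example (illustrative, not from the original source): for
 * data = { 0xa5 } (binary 1010 0101), extract_bit(data, 0) == 1,
 * extract_bit(data, 1) == 0 and extract_bit(data, 7) == 1. Bits are
 * numbered MSB-first within each byte, matching the big-endian layout
 * of the stored prefixes.
 */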

/*
 * longest_prefix_match() - determine the longest prefix of @node that
 * matches the bits in @key, capped at the shorter of the two prefix
 * lengths. Data is compared in 8-, 4-, 2- and 1-byte chunks, widest
 * first, so most mismatches are found within the first comparison.
 */
static size_t longest_prefix_match(const struct lpm_trie *trie,
				   const struct lpm_trie_node *node,
				   const struct bpf_lpm_trie_key *key)
{
	u32 limit = min(node->prefixlen, key->prefixlen);
	u32 prefixlen = 0, i = 0;

	BUILD_BUG_ON(offsetof(struct lpm_trie_node, data) % sizeof(u32));
	BUILD_BUG_ON(offsetof(struct bpf_lpm_trie_key, data) % sizeof(u32));

#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && defined(CONFIG_64BIT)

	/* data_size >= 8 allows the first chunk to be compared as a single
	 * 64-bit load on platforms where unaligned accesses are cheap.
	 */
	if (trie->data_size >= 8) {
		u64 diff = be64_to_cpu(*(__be64 *)node->data ^
				       *(__be64 *)key->data);

		prefixlen = 64 - fls64(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i = 8;
	}
#endif

	while (trie->data_size >= i + 4) {
		u32 diff = be32_to_cpu(*(__be32 *)&node->data[i] ^
				       *(__be32 *)&key->data[i]);

		prefixlen += 32 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 4;
	}

	if (trie->data_size >= i + 2) {
		u16 diff = be16_to_cpu(*(__be16 *)&node->data[i] ^
				       *(__be16 *)&key->data[i]);

		prefixlen += 16 - fls(diff);
		if (prefixlen >= limit)
			return limit;
		if (diff)
			return prefixlen;
		i += 2;
	}

	if (trie->data_size >= i + 1) {
		prefixlen += 8 - fls(node->data[i] ^ key->data[i]);

		if (prefixlen >= limit)
			return limit;
	}

	return prefixlen;
}
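
/* Worked example (illustrative): with data_size == 4 (IPv4), a node
 * holding 192.168.0.0/16 compared against a key of 192.168.1.1/32
 * first differs in the third byte (0x00 ^ 0x01), a raw match of 23
 * bits; the result is then clamped to limit = min(16, 32) = 16.
 */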

/* Called from syscall or from eBPF program */
static void *trie_lookup_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *found = NULL;
	struct bpf_lpm_trie_key *key = _key;

	/* Start walking the trie from the root node ... */

	for (node = rcu_dereference(trie->root); node;) {
		unsigned int next_bit;
		size_t matchlen;

		/* Determine the longest prefix of @node that matches @key.
		 * If it's the maximum possible prefix for this trie, we have
		 * an exact match and can return it directly.
		 */
		matchlen = longest_prefix_match(trie, node, key);
		if (matchlen == trie->max_prefixlen) {
			found = node;
			break;
		}

		/* If the number of bits that match is smaller than the
		 * prefix length of @node, bail out and return the node we
		 * have seen last in the traversal (ie, the parent).
		 */
		if (matchlen < node->prefixlen)
			break;

		/* Consider this node as return candidate unless it is an
		 * artificially added intermediate one.
		 */
		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			found = node;

		/* The node match is fully satisfied, so descend into it and
		 * look for a more specific match.
		 */
		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}

	if (!found)
		return NULL;

	return found->data + trie->data_size;
}
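
/* Usage sketch (hypothetical, user space): a lookup passes a struct
 * bpf_lpm_trie_key whose prefixlen is typically the full data width,
 * e.g. for an IPv4 trie:
 *
 *	struct {
 *		__u32 prefixlen;
 *		__u8  data[4];
 *	} key = { .prefixlen = 32, .data = { 192, 168, 1, 1 } };
 *
 *	err = bpf_map_lookup_elem(map_fd, &key, &value);
 *
 * On success, @value holds the value of the most specific stored prefix
 * covering 192.168.1.1. bpf_map_lookup_elem() here is the libbpf
 * wrapper, not this file's trie_lookup_elem().
 */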

static struct lpm_trie_node *lpm_trie_node_alloc(const struct lpm_trie *trie,
						 const void *value)
{
	struct lpm_trie_node *node;
	size_t size = sizeof(struct lpm_trie_node) + trie->data_size;

	if (value)
		size += trie->map.value_size;

	node = kmalloc_node(size, GFP_ATOMIC | __GFP_NOWARN,
			    trie->map.numa_node);
	if (!node)
		return NULL;

	node->flags = 0;

	if (value)
		memcpy(node->data + trie->data_size, value,
		       trie->map.value_size);

	return node;
}

/* Called from syscall or from eBPF program */
static int trie_update_elem(struct bpf_map *map,
			    void *_key, void *value, u64 flags)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node *node, *im_node = NULL, *new_node = NULL;
	struct lpm_trie_node __rcu **slot;
	struct bpf_lpm_trie_key *key = _key;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Allocate and fill a new node */

	if (trie->n_entries == trie->map.max_entries) {
		ret = -ENOSPC;
		goto out;
	}

	new_node = lpm_trie_node_alloc(trie, value);
	if (!new_node) {
		ret = -ENOMEM;
		goto out;
	}

	trie->n_entries++;

	new_node->prefixlen = key->prefixlen;
	RCU_INIT_POINTER(new_node->child[0], NULL);
	RCU_INIT_POINTER(new_node->child[1], NULL);
	memcpy(new_node->data, key->data, trie->data_size);

	/* Now find a slot to attach the new node. To do that, walk the tree
	 * from the root and match as many bits as possible for each node
	 * until we either find an empty slot or a slot that needs to be
	 * replaced by an intermediate node.
	 */
	slot = &trie->root;

	while ((node = rcu_dereference_protected(*slot,
					lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen ||
		    node->prefixlen == trie->max_prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		slot = &node->child[next_bit];
	}

	/* If the slot is empty (a free child pointer or an empty root),
	 * simply assign the @new_node to that slot and be done.
	 */
	if (!node) {
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* If the slot we picked already exists, replace it with @new_node
	 * which already has the correct data and child pointers.
	 */
	if (node->prefixlen == matchlen) {
		new_node->child[0] = node->child[0];
		new_node->child[1] = node->child[1];

		if (!(node->flags & LPM_TREE_NODE_FLAG_IM))
			trie->n_entries--;

		rcu_assign_pointer(*slot, new_node);
		kfree_rcu(node, rcu);

		goto out;
	}

	/* If the new node matches the prefix completely, it must be inserted
	 * as an ancestor. Simply insert it between @node and *@slot.
	 */
	if (matchlen == key->prefixlen) {
		next_bit = extract_bit(node->data, matchlen);
		rcu_assign_pointer(new_node->child[next_bit], node);
		rcu_assign_pointer(*slot, new_node);
		goto out;
	}

	/* Neither prefix contains the other: split the slot with an
	 * intermediate node carrying the common prefix of both.
	 */
	im_node = lpm_trie_node_alloc(trie, NULL);
	if (!im_node) {
		ret = -ENOMEM;
		goto out;
	}

	im_node->prefixlen = matchlen;
	im_node->flags |= LPM_TREE_NODE_FLAG_IM;
	memcpy(im_node->data, node->data, trie->data_size);

	/* Now determine which child to install in which slot */
	if (extract_bit(key->data, matchlen)) {
		rcu_assign_pointer(im_node->child[0], node);
		rcu_assign_pointer(im_node->child[1], new_node);
	} else {
		rcu_assign_pointer(im_node->child[0], new_node);
		rcu_assign_pointer(im_node->child[1], node);
	}

	/* Finally, assign the intermediate node to the determined slot */
	rcu_assign_pointer(*slot, im_node);

out:
	if (ret) {
		if (new_node)
			trie->n_entries--;

		kfree(new_node);
		kfree(im_node);
	}

	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
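
/* Example of the insertion cases (illustrative): in an empty IPv4 trie,
 * inserting 192.168.0.0/24 fills the empty root slot. Inserting
 * 192.168.0.0/16 afterwards takes the matchlen == key->prefixlen path
 * and becomes the /24's ancestor. Inserting 192.168.128.0/24 into a
 * trie holding only 192.168.0.0/24 shares just 16 bits with it, so an
 * intermediate /16 node (flagged LPM_TREE_NODE_FLAG_IM) is created with
 * the two /24 entries as children.
 */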

/* Called from syscall or from eBPF program */
static int trie_delete_elem(struct bpf_map *map, void *_key)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key;
	struct lpm_trie_node __rcu **trim, **trim2;
	struct lpm_trie_node *node, *parent;
	unsigned long irq_flags;
	unsigned int next_bit;
	size_t matchlen = 0;
	int ret = 0;

	if (key->prefixlen > trie->max_prefixlen)
		return -EINVAL;

	raw_spin_lock_irqsave(&trie->lock, irq_flags);

	/* Walk the tree looking for an exact key/length match and keeping
	 * track of the path we traverse. We will need to know the node we
	 * wish to delete and the slot that points to it. We may also need
	 * to know the node's parent and the slot that contains it.
	 */
	trim = &trie->root;
	trim2 = trim;
	parent = NULL;
	while ((node = rcu_dereference_protected(
		       *trim, lockdep_is_held(&trie->lock)))) {
		matchlen = longest_prefix_match(trie, node, key);

		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		parent = node;
		trim2 = trim;
		next_bit = extract_bit(key->data, node->prefixlen);
		trim = &node->child[next_bit];
	}

	if (!node || node->prefixlen != key->prefixlen ||
	    node->prefixlen != matchlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM)) {
		ret = -ENOENT;
		goto out;
	}

	trie->n_entries--;

	/* If the node we are removing has two children, simply mark it
	 * as intermediate and we are done.
	 */
	if (rcu_access_pointer(node->child[0]) &&
	    rcu_access_pointer(node->child[1])) {
		node->flags |= LPM_TREE_NODE_FLAG_IM;
		goto out;
	}

	/* If the parent of the node we are about to delete is an
	 * intermediate node and the node has no children, the parent's
	 * fork is no longer needed: remove both nodes and promote the
	 * node's sibling into the parent's slot.
	 */
	if (parent && (parent->flags & LPM_TREE_NODE_FLAG_IM) &&
	    !node->child[0] && !node->child[1]) {
		if (node == rcu_access_pointer(parent->child[0]))
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[1]));
		else
			rcu_assign_pointer(
				*trim2, rcu_access_pointer(parent->child[0]));
		kfree_rcu(parent, rcu);
		kfree_rcu(node, rcu);
		goto out;
	}

	/* The node being removed has at most one child. If there is a
	 * child, move it into the removed node's slot, then delete the
	 * node. Otherwise just clear the slot and delete the node.
	 */
	if (node->child[0])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[0]));
	else if (node->child[1])
		rcu_assign_pointer(*trim, rcu_access_pointer(node->child[1]));
	else
		RCU_INIT_POINTER(*trim, NULL);
	kfree_rcu(node, rcu);

out:
	raw_spin_unlock_irqrestore(&trie->lock, irq_flags);

	return ret;
}
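
/* Example (illustrative): deleting a node that still has two children,
 * say 192.168.0.0/16 sitting above 192.168.0.0/24 and 192.168.128.0/24,
 * unlinks nothing; the node is merely re-flagged as intermediate so it
 * keeps forking the two remaining prefixes.
 */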

#define LPM_DATA_SIZE_MAX	256
#define LPM_DATA_SIZE_MIN	1

#define LPM_VAL_SIZE_MAX	(KMALLOC_MAX_SIZE - LPM_DATA_SIZE_MAX - \
				 sizeof(struct lpm_trie_node))
#define LPM_VAL_SIZE_MIN	1

#define LPM_KEY_SIZE(X)		(sizeof(struct bpf_lpm_trie_key) + (X))
#define LPM_KEY_SIZE_MAX	LPM_KEY_SIZE(LPM_DATA_SIZE_MAX)
#define LPM_KEY_SIZE_MIN	LPM_KEY_SIZE(LPM_DATA_SIZE_MIN)

#define LPM_CREATE_FLAG_MASK	(BPF_F_NO_PREALLOC | BPF_F_NUMA_NODE |	\
				 BPF_F_ACCESS_MASK)
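
/* Size arithmetic (illustrative): sizeof(struct bpf_lpm_trie_key) is 4
 * (a __u32 prefixlen; the flexible data[] member adds nothing), so
 * LPM_KEY_SIZE(4) == 8 for IPv4 keys and LPM_KEY_SIZE(16) == 20 for
 * IPv6 keys.
 */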

static struct bpf_map *trie_alloc(union bpf_attr *attr)
{
	struct lpm_trie *trie;
	u64 cost = sizeof(*trie), cost_per_node;
	int ret;

	if (!capable(CAP_SYS_ADMIN))
		return ERR_PTR(-EPERM);

	/* check sanity of attributes */
	if (attr->max_entries == 0 ||
	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
	    attr->map_flags & ~LPM_CREATE_FLAG_MASK ||
	    !bpf_map_flags_access_ok(attr->map_flags) ||
	    attr->key_size < LPM_KEY_SIZE_MIN ||
	    attr->key_size > LPM_KEY_SIZE_MAX ||
	    attr->value_size < LPM_VAL_SIZE_MIN ||
	    attr->value_size > LPM_VAL_SIZE_MAX)
		return ERR_PTR(-EINVAL);

	trie = kzalloc(sizeof(*trie), GFP_USER | __GFP_NOWARN);
	if (!trie)
		return ERR_PTR(-ENOMEM);

	/* copy mandatory map attributes */
	bpf_map_init_from_attr(&trie->map, attr);
	trie->data_size = attr->key_size -
			  offsetof(struct bpf_lpm_trie_key, data);
	trie->max_prefixlen = trie->data_size * 8;

	cost_per_node = sizeof(struct lpm_trie_node) +
			attr->value_size + trie->data_size;
	cost += (u64) attr->max_entries * cost_per_node;

	ret = bpf_map_charge_init(&trie->map.memory, cost);
	if (ret)
		goto out_err;

	raw_spin_lock_init(&trie->lock);

	return &trie->map;
out_err:
	kfree(trie);
	return ERR_PTR(ret);
}
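
/* Creation sketch (hypothetical, user space, via a recent libbpf): LPM
 * tries must be created with BPF_F_NO_PREALLOC, e.g. for IPv4 keys and
 * u64 values:
 *
 *	LIBBPF_OPTS(bpf_map_create_opts, opts,
 *		    .map_flags = BPF_F_NO_PREALLOC);
 *	int fd = bpf_map_create(BPF_MAP_TYPE_LPM_TRIE, "lpm_example",
 *				8, sizeof(__u64), 256, &opts);
 *
 * where 8 == sizeof(__u32 prefixlen) + 4 data bytes. bpf_map_create()
 * and LIBBPF_OPTS() are libbpf interfaces, not part of this file.
 */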

static void trie_free(struct bpf_map *map)
{
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct lpm_trie_node __rcu **slot;
	struct lpm_trie_node *node;

	/* Wait for outstanding programs to complete
	 * update/lookup/delete/get_next_key and free the trie.
	 */
	synchronize_rcu();

	/* Always start at the root and walk down to a node that has no
	 * children. Then free that node, nullify its reference in the
	 * parent and start over.
	 */
	for (;;) {
		slot = &trie->root;

		for (;;) {
			node = rcu_dereference_protected(*slot, 1);
			if (!node)
				goto out;

			if (rcu_access_pointer(node->child[0])) {
				slot = &node->child[0];
				continue;
			}

			if (rcu_access_pointer(node->child[1])) {
				slot = &node->child[1];
				continue;
			}

			kfree(node);
			RCU_INIT_POINTER(*slot, NULL);
			break;
		}
	}

out:
	kfree(trie);
}

static int trie_get_next_key(struct bpf_map *map, void *_key, void *_next_key)
{
	struct lpm_trie_node *node, *next_node = NULL, *parent, *search_root;
	struct lpm_trie *trie = container_of(map, struct lpm_trie, map);
	struct bpf_lpm_trie_key *key = _key, *next_key = _next_key;
	struct lpm_trie_node **node_stack = NULL;
	int err = 0, stack_ptr = -1;
	unsigned int next_bit;
	size_t matchlen;

	/* The iteration follows postorder: more specific (longer) prefixes
	 * are returned before the less specific prefixes that cover them.
	 * To find the next key after @key, locate the node for @key while
	 * recording the path on a stack, then backtrack to the first
	 * ancestor whose other subtree (or the ancestor itself) yields a
	 * node that has not been visited yet.
	 */
	search_root = rcu_dereference(trie->root);
	if (!search_root)
		return -ENOENT;

	/* For an invalid key, find the leftmost node in the trie */
	if (!key || key->prefixlen > trie->max_prefixlen)
		goto find_leftmost;

	node_stack = kmalloc_array(trie->max_prefixlen,
				   sizeof(struct lpm_trie_node *),
				   GFP_ATOMIC | __GFP_NOWARN);
	if (!node_stack)
		return -ENOMEM;

	/* Try to find the node with the given key */
	for (node = search_root; node;) {
		node_stack[++stack_ptr] = node;
		matchlen = longest_prefix_match(trie, node, key);
		if (node->prefixlen != matchlen ||
		    node->prefixlen == key->prefixlen)
			break;

		next_bit = extract_bit(key->data, node->prefixlen);
		node = rcu_dereference(node->child[next_bit]);
	}
	if (!node || node->prefixlen != key->prefixlen ||
	    (node->flags & LPM_TREE_NODE_FLAG_IM))
		goto find_leftmost;

	/* The node with the exactly-matching key has been found; find the
	 * first node in postorder after it.
	 */
	node = node_stack[stack_ptr];
	while (stack_ptr > 0) {
		parent = node_stack[stack_ptr - 1];
		if (rcu_dereference(parent->child[0]) == node) {
			search_root = rcu_dereference(parent->child[1]);
			if (search_root)
				goto find_leftmost;
		}
		if (!(parent->flags & LPM_TREE_NODE_FLAG_IM)) {
			next_node = parent;
			goto do_copy;
		}

		node = parent;
		stack_ptr--;
	}

	/* did not find anything */
	err = -ENOENT;
	goto free_stack;

find_leftmost:
	/* Find the leftmost non-intermediate node; all intermediate nodes
	 * have exactly two children.
	 */
	for (node = search_root; node;) {
		if (node->flags & LPM_TREE_NODE_FLAG_IM) {
			node = rcu_dereference(node->child[0]);
		} else {
			next_node = node;
			node = rcu_dereference(node->child[0]);
			if (!node)
				node = rcu_dereference(next_node->child[1]);
		}
	}
do_copy:
	next_key->prefixlen = next_node->prefixlen;
	memcpy((void *)next_key + offsetof(struct bpf_lpm_trie_key, data),
	       next_node->data, trie->data_size);
free_stack:
	kfree(node_stack);
	return err;
}
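
/* Iteration sketch (hypothetical, user space): all entries can be
 * enumerated by chaining bpf_map_get_next_key(), starting from a NULL
 * key:
 *
 *	err = bpf_map_get_next_key(map_fd, NULL, &key);
 *	while (!err) {
 *		... use key ...
 *		err = bpf_map_get_next_key(map_fd, &key, &next);
 *		key = next;
 *	}
 *
 * Keys are returned in postorder, i.e. more specific prefixes before
 * the broader prefixes that contain them.
 */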

static int trie_check_btf(const struct bpf_map *map,
			  const struct btf *btf,
			  const struct btf_type *key_type,
			  const struct btf_type *value_type)
{
	/* Keys must have struct bpf_lpm_trie_key embedded. */
	return BTF_INFO_KIND(key_type->info) != BTF_KIND_STRUCT ?
	       -EINVAL : 0;
}

const struct bpf_map_ops trie_map_ops = {
	.map_alloc = trie_alloc,
	.map_free = trie_free,
	.map_get_next_key = trie_get_next_key,
	.map_lookup_elem = trie_lookup_elem,
	.map_update_elem = trie_update_elem,
	.map_delete_elem = trie_delete_elem,
	.map_check_btf = trie_check_btf,
};