root/tools/testing/selftests/bpf/progs/test_l4lb.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. rol32
  2. jhash
  3. __jhash_nwords
  4. jhash_2words
  5. get_packet_hash
  6. get_packet_dst
  7. parse_icmpv6
  8. parse_icmp
  9. parse_udp
  10. parse_tcp
  11. process_packet
  12. balancer_ingress (declared with SEC("l4lb-demo"))

   1 /* Copyright (c) 2017 Facebook
   2  *
   3  * This program is free software; you can redistribute it and/or
   4  * modify it under the terms of version 2 of the GNU General Public
   5  * License as published by the Free Software Foundation.
   6  */
   7 #include <stddef.h>
   8 #include <stdbool.h>
   9 #include <string.h>
  10 #include <linux/pkt_cls.h>
  11 #include <linux/bpf.h>
  12 #include <linux/in.h>
  13 #include <linux/if_ether.h>
  14 #include <linux/ip.h>
  15 #include <linux/ipv6.h>
  16 #include <linux/icmp.h>
  17 #include <linux/icmpv6.h>
  18 #include <linux/tcp.h>
  19 #include <linux/udp.h>
  20 #include "bpf_helpers.h"
  21 #include "test_iptunnel_common.h"
  22 #include "bpf_endian.h"
  23 
  24 int _version SEC("version") = 1;
  25 
  26 static inline __u32 rol32(__u32 word, unsigned int shift)
  27 {
  28         return (word << shift) | (word >> ((-shift) & 31));
  29 }
  30 
  31 /* copy paste of jhash from kernel sources to make sure llvm
  32  * can compile it into valid sequence of bpf instructions
  33  */
  34 #define __jhash_mix(a, b, c)                    \
  35 {                                               \
  36         a -= c;  a ^= rol32(c, 4);  c += b;     \
  37         b -= a;  b ^= rol32(a, 6);  a += c;     \
  38         c -= b;  c ^= rol32(b, 8);  b += a;     \
  39         a -= c;  a ^= rol32(c, 16); c += b;     \
  40         b -= a;  b ^= rol32(a, 19); a += c;     \
  41         c -= b;  c ^= rol32(b, 4);  b += a;     \
  42 }
  43 
  44 #define __jhash_final(a, b, c)                  \
  45 {                                               \
  46         c ^= b; c -= rol32(b, 14);              \
  47         a ^= c; a -= rol32(c, 11);              \
  48         b ^= a; b -= rol32(a, 25);              \
  49         c ^= b; c -= rol32(b, 16);              \
  50         a ^= c; a -= rol32(c, 4);               \
  51         b ^= a; b -= rol32(a, 14);              \
  52         c ^= b; c -= rol32(b, 24);              \
  53 }
  54 
  55 #define JHASH_INITVAL           0xdeadbeef
  56 
  57 typedef unsigned int u32;
  58 
  59 static inline u32 jhash(const void *key, u32 length, u32 initval)
  60 {
  61         u32 a, b, c;
  62         const unsigned char *k = key;
  63 
  64         a = b = c = JHASH_INITVAL + length + initval;
  65 
  66         while (length > 12) {
  67                 a += *(u32 *)(k);
  68                 b += *(u32 *)(k + 4);
  69                 c += *(u32 *)(k + 8);
  70                 __jhash_mix(a, b, c);
  71                 length -= 12;
  72                 k += 12;
  73         }
  74         switch (length) {
  75         case 12: c += (u32)k[11]<<24;
  76         case 11: c += (u32)k[10]<<16;
  77         case 10: c += (u32)k[9]<<8;
  78         case 9:  c += k[8];
  79         case 8:  b += (u32)k[7]<<24;
  80         case 7:  b += (u32)k[6]<<16;
  81         case 6:  b += (u32)k[5]<<8;
  82         case 5:  b += k[4];
  83         case 4:  a += (u32)k[3]<<24;
  84         case 3:  a += (u32)k[2]<<16;
  85         case 2:  a += (u32)k[1]<<8;
  86         case 1:  a += k[0];
  87                  __jhash_final(a, b, c);
  88         case 0: /* Nothing left to add */
  89                 break;
  90         }
  91 
  92         return c;
  93 }
  94 
  95 static inline u32 __jhash_nwords(u32 a, u32 b, u32 c, u32 initval)
  96 {
  97         a += initval;
  98         b += initval;
  99         c += initval;
 100         __jhash_final(a, b, c);
 101         return c;
 102 }
 103 
 104 static inline u32 jhash_2words(u32 a, u32 b, u32 initval)
 105 {
 106         return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2));
 107 }
 108 
 109 #define PCKT_FRAGMENTED 65343
 110 #define IPV4_HDR_LEN_NO_OPT 20
 111 #define IPV4_PLUS_ICMP_HDR 28
 112 #define IPV6_PLUS_ICMP_HDR 48
 113 #define RING_SIZE 2
 114 #define MAX_VIPS 12
 115 #define MAX_REALS 5
 116 #define CTL_MAP_SIZE 16
 117 #define CH_RINGS_SIZE (MAX_VIPS * RING_SIZE)
 118 #define F_IPV6 (1 << 0)
 119 #define F_HASH_NO_SRC_PORT (1 << 0)
 120 #define F_ICMP (1 << 0)
 121 #define F_SYN_SET (1 << 1)
 122 
/* Flow key built from a packet by the parse_* helpers.
 *
 * IPv4 uses src/dst and IPv6 uses srcv6/dstv6; the unions overlay them.
 * For ICMP errors (F_ICMP in flags), addresses and ports come from the
 * header embedded in the ICMP payload and are stored swapped relative to
 * that header.
 */
struct packet_description {
        union {
                __be32 src;
                __be32 srcv6[4];
        };
        union {
                __be32 dst;
                __be32 dstv6[4];
        };
        /* port16[0] = source port, port16[1] = destination port, copied raw
         * from the L4 header; "ports" views both as one word for hashing.
         */
        union {
                __u32 ports;
                __u16 port16[2];
        };
        __u8 proto;     /* IP protocol of the (possibly embedded) L4 header */
        __u8 flags;     /* F_ICMP / F_SYN_SET */
};
 139 
/* Value type of ctl_array. Only ifindex is read in this file (egress
 * interface for the redirect). NOTE(review): value/mac are unused here;
 * mac is presumably a 6-byte MAC address — confirm.
 */
struct ctl_value {
        union {
                __u64 value;
                __u32 ifindex;
                __u8 mac[6];
        };
};
 147 
/* Value type of vip_map. */
struct vip_meta {
        __u32 flags;    /* F_HASH_NO_SRC_PORT */
        __u32 vip_num;  /* selects this VIP's ch_rings slots and stats entry */
};
 152 
/* Value type of the reals map: a backend address used as the tunnel
 * remote endpoint. F_IPV6 in flags selects dstv6 over dst.
 */
struct real_definition {
        union {
                __be32 dst;
                __be32 dstv6[4];
        };
        __u8 flags;
};
 160 
/* Per-VIP counters (value type of the per-CPU stats map). */
struct vip_stats {
        __u64 bytes;
        __u64 pkts;
};
 165 
/* Ethernet header; eth_proto is in network byte order. */
struct eth_hdr {
        unsigned char eth_dest[ETH_ALEN];
        unsigned char eth_source[ETH_ALEN];
        unsigned short eth_proto;
};
 171 
/* VIP lookup: (daddr, dport, protocol) -> vip_meta. A dport of 0 is used as
 * a port-agnostic fallback key by process_packet().
 * NOTE(review): struct vip is declared outside this file, presumably in
 * test_iptunnel_common.h.
 */
struct {
        __uint(type, BPF_MAP_TYPE_HASH);
        __uint(max_entries, MAX_VIPS);
        __type(key, struct vip);
        __type(value, struct vip_meta);
} vip_map SEC(".maps");

/* RING_SIZE slots per VIP (index RING_SIZE * vip_num + hash % RING_SIZE);
 * each value is a key into the reals map.
 */
struct {
        __uint(type, BPF_MAP_TYPE_ARRAY);
        __uint(max_entries, CH_RINGS_SIZE);
        __type(key, __u32);
        __type(value, __u32);
} ch_rings SEC(".maps");

/* Backend ("real") definitions, indexed by the value found in ch_rings. */
struct {
        __uint(type, BPF_MAP_TYPE_ARRAY);
        __uint(max_entries, MAX_REALS);
        __type(key, __u32);
        __type(value, struct real_definition);
} reals SEC(".maps");

/* Per-CPU packet/byte counters, indexed by vip_meta.vip_num. */
struct {
        __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
        __uint(max_entries, MAX_VIPS);
        __type(key, __u32);
        __type(value, struct vip_stats);
} stats SEC(".maps");

/* Control values; process_packet() reads the IPv4 egress ifindex from
 * slot 1 and the IPv6 egress ifindex from slot 2.
 */
struct {
        __uint(type, BPF_MAP_TYPE_ARRAY);
        __uint(max_entries, CTL_MAP_SIZE);
        __type(key, __u32);
        __type(value, struct ctl_value);
} ctl_array SEC(".maps");
 206 
 207 static __always_inline __u32 get_packet_hash(struct packet_description *pckt,
 208                                              bool ipv6)
 209 {
 210         if (ipv6)
 211                 return jhash_2words(jhash(pckt->srcv6, 16, MAX_VIPS),
 212                                     pckt->ports, CH_RINGS_SIZE);
 213         else
 214                 return jhash_2words(pckt->src, pckt->ports, CH_RINGS_SIZE);
 215 }
 216 
 217 static __always_inline bool get_packet_dst(struct real_definition **real,
 218                                            struct packet_description *pckt,
 219                                            struct vip_meta *vip_info,
 220                                            bool is_ipv6)
 221 {
 222         __u32 hash = get_packet_hash(pckt, is_ipv6) % RING_SIZE;
 223         __u32 key = RING_SIZE * vip_info->vip_num + hash;
 224         __u32 *real_pos;
 225 
 226         real_pos = bpf_map_lookup_elem(&ch_rings, &key);
 227         if (!real_pos)
 228                 return false;
 229         key = *real_pos;
 230         *real = bpf_map_lookup_elem(&reals, &key);
 231         if (!(*real))
 232                 return false;
 233         return true;
 234 }
 235 
 236 static __always_inline int parse_icmpv6(void *data, void *data_end, __u64 off,
 237                                         struct packet_description *pckt)
 238 {
 239         struct icmp6hdr *icmp_hdr;
 240         struct ipv6hdr *ip6h;
 241 
 242         icmp_hdr = data + off;
 243         if (icmp_hdr + 1 > data_end)
 244                 return TC_ACT_SHOT;
 245         if (icmp_hdr->icmp6_type != ICMPV6_PKT_TOOBIG)
 246                 return TC_ACT_OK;
 247         off += sizeof(struct icmp6hdr);
 248         ip6h = data + off;
 249         if (ip6h + 1 > data_end)
 250                 return TC_ACT_SHOT;
 251         pckt->proto = ip6h->nexthdr;
 252         pckt->flags |= F_ICMP;
 253         memcpy(pckt->srcv6, ip6h->daddr.s6_addr32, 16);
 254         memcpy(pckt->dstv6, ip6h->saddr.s6_addr32, 16);
 255         return TC_ACT_UNSPEC;
 256 }
 257 
 258 static __always_inline int parse_icmp(void *data, void *data_end, __u64 off,
 259                                       struct packet_description *pckt)
 260 {
 261         struct icmphdr *icmp_hdr;
 262         struct iphdr *iph;
 263 
 264         icmp_hdr = data + off;
 265         if (icmp_hdr + 1 > data_end)
 266                 return TC_ACT_SHOT;
 267         if (icmp_hdr->type != ICMP_DEST_UNREACH ||
 268             icmp_hdr->code != ICMP_FRAG_NEEDED)
 269                 return TC_ACT_OK;
 270         off += sizeof(struct icmphdr);
 271         iph = data + off;
 272         if (iph + 1 > data_end)
 273                 return TC_ACT_SHOT;
 274         if (iph->ihl != 5)
 275                 return TC_ACT_SHOT;
 276         pckt->proto = iph->protocol;
 277         pckt->flags |= F_ICMP;
 278         pckt->src = iph->daddr;
 279         pckt->dst = iph->saddr;
 280         return TC_ACT_UNSPEC;
 281 }
 282 
 283 static __always_inline bool parse_udp(void *data, __u64 off, void *data_end,
 284                                       struct packet_description *pckt)
 285 {
 286         struct udphdr *udp;
 287         udp = data + off;
 288 
 289         if (udp + 1 > data_end)
 290                 return false;
 291 
 292         if (!(pckt->flags & F_ICMP)) {
 293                 pckt->port16[0] = udp->source;
 294                 pckt->port16[1] = udp->dest;
 295         } else {
 296                 pckt->port16[0] = udp->dest;
 297                 pckt->port16[1] = udp->source;
 298         }
 299         return true;
 300 }
 301 
 302 static __always_inline bool parse_tcp(void *data, __u64 off, void *data_end,
 303                                       struct packet_description *pckt)
 304 {
 305         struct tcphdr *tcp;
 306 
 307         tcp = data + off;
 308         if (tcp + 1 > data_end)
 309                 return false;
 310 
 311         if (tcp->syn)
 312                 pckt->flags |= F_SYN_SET;
 313 
 314         if (!(pckt->flags & F_ICMP)) {
 315                 pckt->port16[0] = tcp->source;
 316                 pckt->port16[1] = tcp->dest;
 317         } else {
 318                 pckt->port16[0] = tcp->dest;
 319                 pckt->port16[1] = tcp->source;
 320         }
 321         return true;
 322 }
 323 
/* Load-balance one packet whose L3 header starts at data + @off.
 *
 * Steps:
 * 1. Parse the IPv4/IPv6 header. For ICMP "fragmentation needed" and ICMPv6
 *    "packet too big", parse the embedded header instead.
 * 2. Parse the TCP/UDP ports.
 * 3. Look up the VIP, first with its destination port, then with port 0.
 * 4. Pick a real with get_packet_dst() and bump the per-VIP stats.
 * 5. Set the tunnel key to the real's address and redirect to the egress
 *    ifindex from ctl_array.
 *
 * Returns TC_ACT_SHOT on any parse or lookup failure. Returns the ICMP
 * parser's verdict when it is >= 0. Otherwise returns bpf_redirect()'s
 * result.
 */
static __always_inline int process_packet(void *data, __u64 off, void *data_end,
                                          bool is_ipv6, struct __sk_buff *skb)
{
        void *pkt_start = (void *)(long)skb->data;
        struct packet_description pckt = {};
        struct eth_hdr *eth = pkt_start;
        struct bpf_tunnel_key tkey = {};
        struct vip_stats *data_stats;
        struct real_definition *dst;
        struct vip_meta *vip_info;
        struct ctl_value *cval;
        /* ctl_array slots holding the egress ifindex for IPv4 / IPv6 reals. */
        __u32 v4_intf_pos = 1;
        __u32 v6_intf_pos = 2;
        struct ipv6hdr *ip6h;
        struct vip vip = {};
        struct iphdr *iph;
        int tun_flag = 0;
        __u16 pkt_bytes;
        __u64 iph_len;
        __u32 ifindex;
        __u8 protocol;
        __u32 vip_num;
        int action;

        tkey.tunnel_ttl = 64;
        if (is_ipv6) {
                ip6h = data + off;
                if (ip6h + 1 > data_end)
                        return TC_ACT_SHOT;

                iph_len = sizeof(struct ipv6hdr);
                protocol = ip6h->nexthdr;
                pckt.proto = protocol;
                /* NOTE(review): payload_len excludes the 40-byte IPv6 header,
                 * while the IPv4 branch counts tot_len (header included). Byte
                 * stats are therefore not computed the same way for both
                 * families; confirm this is intended before changing it.
                 */
                pkt_bytes = bpf_ntohs(ip6h->payload_len);
                off += iph_len;
                /* Only the fixed header is parsed. A fragment header is dropped
                 * here; any other non-TCP/UDP next header is dropped below.
                 */
                if (protocol == IPPROTO_FRAGMENT) {
                        return TC_ACT_SHOT;
                } else if (protocol == IPPROTO_ICMPV6) {
                        action = parse_icmpv6(data, data_end, off, &pckt);
                        /* TC_ACT_UNSPEC (-1) means "continue processing". */
                        if (action >= 0)
                                return action;
                        /* Skip the ICMPv6 header and the embedded IPv6 header. */
                        off += IPV6_PLUS_ICMP_HDR;
                } else {
                        memcpy(pckt.srcv6, ip6h->saddr.s6_addr32, 16);
                        memcpy(pckt.dstv6, ip6h->daddr.s6_addr32, 16);
                }
        } else {
                iph = data + off;
                if (iph + 1 > data_end)
                        return TC_ACT_SHOT;
                /* IPv4 options are not supported: require a 20-byte header. */
                if (iph->ihl != 5)
                        return TC_ACT_SHOT;

                protocol = iph->protocol;
                pckt.proto = protocol;
                pkt_bytes = bpf_ntohs(iph->tot_len);
                off += IPV4_HDR_LEN_NO_OPT;

                /* Drop fragments; frag_off is tested raw (network byte order). */
                if (iph->frag_off & PCKT_FRAGMENTED)
                        return TC_ACT_SHOT;
                if (protocol == IPPROTO_ICMP) {
                        action = parse_icmp(data, data_end, off, &pckt);
                        /* TC_ACT_UNSPEC (-1) means "continue processing". */
                        if (action >= 0)
                                return action;
                        /* Skip the ICMP header and the embedded IPv4 header. */
                        off += IPV4_PLUS_ICMP_HDR;
                } else {
                        pckt.src = iph->saddr;
                        pckt.dst = iph->daddr;
                }
        }
        /* Re-read: ICMP parsing replaces proto with the embedded header's. */
        protocol = pckt.proto;

        if (protocol == IPPROTO_TCP) {
                if (!parse_tcp(data, off, data_end, &pckt))
                        return TC_ACT_SHOT;
        } else if (protocol == IPPROTO_UDP) {
                if (!parse_udp(data, off, data_end, &pckt))
                        return TC_ACT_SHOT;
        } else {
                return TC_ACT_SHOT;
        }

        if (is_ipv6)
                memcpy(vip.daddr.v6, pckt.dstv6, 16);
        else
                vip.daddr.v4 = pckt.dst;

        /* Exact (daddr, dport, protocol) VIP first. Otherwise fall back to a
         * port-agnostic VIP (dport 0), and hash with dst port 0 as well.
         */
        vip.dport = pckt.port16[1];
        vip.protocol = pckt.proto;
        vip_info = bpf_map_lookup_elem(&vip_map, &vip);
        if (!vip_info) {
                vip.dport = 0;
                vip_info = bpf_map_lookup_elem(&vip_map, &vip);
                if (!vip_info)
                        return TC_ACT_SHOT;
                pckt.port16[1] = 0;
        }

        /* Leave the source port out of the flow hash for this VIP. */
        if (vip_info->flags & F_HASH_NO_SRC_PORT)
                pckt.port16[0] = 0;

        if (!get_packet_dst(&dst, &pckt, vip_info, is_ipv6))
                return TC_ACT_SHOT;

        /* Tunnel to the real; the egress interface depends on its family. */
        if (dst->flags & F_IPV6) {
                cval = bpf_map_lookup_elem(&ctl_array, &v6_intf_pos);
                if (!cval)
                        return TC_ACT_SHOT;
                ifindex = cval->ifindex;
                memcpy(tkey.remote_ipv6, dst->dstv6, 16);
                tun_flag = BPF_F_TUNINFO_IPV6;
        } else {
                cval = bpf_map_lookup_elem(&ctl_array, &v4_intf_pos);
                if (!cval)
                        return TC_ACT_SHOT;
                ifindex = cval->ifindex;
                tkey.remote_ipv4 = dst->dst;
        }
        /* stats is a per-CPU array, so plain increments are used. */
        vip_num = vip_info->vip_num;
        data_stats = bpf_map_lookup_elem(&stats, &vip_num);
        if (!data_stats)
                return TC_ACT_SHOT;
        data_stats->pkts++;
        data_stats->bytes += pkt_bytes;
        /* NOTE(review): the return value of bpf_skb_set_tunnel_key() is
         * ignored; the packet is redirected even if setting the key fails.
         */
        bpf_skb_set_tunnel_key(skb, &tkey, sizeof(tkey), tun_flag);
        /* Overwrite the first 4 bytes of the destination MAC with the first
         * 32-bit word of the tunnel remote address. NOTE(review): for IPv6
         * reals this presumably reads remote_ipv6[0] through a union in
         * struct bpf_tunnel_key; it looks like a marker for the test harness
         * to observe the chosen real. It relies on the caller having
         * bounds-checked the Ethernet header.
         */
        *(u32 *)eth->eth_dest = tkey.remote_ipv4;
        return bpf_redirect(ifindex, 0);
}
 452 
 453 SEC("l4lb-demo")
 454 int balancer_ingress(struct __sk_buff *ctx)
 455 {
 456         void *data_end = (void *)(long)ctx->data_end;
 457         void *data = (void *)(long)ctx->data;
 458         struct eth_hdr *eth = data;
 459         __u32 eth_proto;
 460         __u32 nh_off;
 461 
 462         nh_off = sizeof(struct eth_hdr);
 463         if (data + nh_off > data_end)
 464                 return TC_ACT_SHOT;
 465         eth_proto = eth->eth_proto;
 466         if (eth_proto == bpf_htons(ETH_P_IP))
 467                 return process_packet(data, nh_off, data_end, false, ctx);
 468         else if (eth_proto == bpf_htons(ETH_P_IPV6))
 469                 return process_packet(data, nh_off, data_end, true, ctx);
 470         else
 471                 return TC_ACT_SHOT;
 472 }
 473 char _license[] SEC("license") = "GPL";

/* [<][>][^][v][top][bottom][index][help] */