root/tools/testing/selftests/net/psock_snd.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. add_csum_hword
  2. build_ip_csum
  3. build_vnet_header
  4. build_eth_header
  5. build_ipv4_header
  6. build_udp_header
  7. build_packet
  8. do_bind
  9. do_send
  10. do_tx
  11. setup_rx
  12. do_rx
  13. setup_sniffer
  14. parse_opts
  15. run_test
  16. main

   1 // SPDX-License-Identifier: GPL-2.0
   2 
   3 #define _GNU_SOURCE
   4 
   5 #include <arpa/inet.h>
   6 #include <errno.h>
   7 #include <error.h>
   8 #include <fcntl.h>
   9 #include <limits.h>
  10 #include <linux/filter.h>
  11 #include <linux/bpf.h>
  12 #include <linux/if_packet.h>
  13 #include <linux/if_vlan.h>
  14 #include <linux/virtio_net.h>
  15 #include <net/if.h>
  16 #include <net/ethernet.h>
  17 #include <netinet/ip.h>
  18 #include <netinet/udp.h>
  19 #include <poll.h>
  20 #include <sched.h>
  21 #include <stdbool.h>
  22 #include <stdint.h>
  23 #include <stdio.h>
  24 #include <stdlib.h>
  25 #include <string.h>
  26 #include <sys/mman.h>
  27 #include <sys/socket.h>
  28 #include <sys/stat.h>
  29 #include <sys/types.h>
  30 #include <unistd.h>
  31 
  32 #include "psock_lib.h"
  33 
  34 static bool     cfg_use_bind;
  35 static bool     cfg_use_csum_off;
  36 static bool     cfg_use_csum_off_bad;
  37 static bool     cfg_use_dgram;
  38 static bool     cfg_use_gso;
  39 static bool     cfg_use_qdisc_bypass;
  40 static bool     cfg_use_vlan;
  41 static bool     cfg_use_vnet;
  42 
  43 static char     *cfg_ifname = "lo";
  44 static int      cfg_mtu = 1500;
  45 static int      cfg_payload_len = DATA_LEN;
  46 static int      cfg_truncate_len = INT_MAX;
  47 static uint16_t cfg_port = 8000;
  48 
  49 /* test sending up to max mtu + 1 */
  50 #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
  51 
  52 static char tbuf[TEST_SZ], rbuf[TEST_SZ];
  53 
  54 static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
  55 {
  56         unsigned long sum = 0;
  57         int i;
  58 
  59         for (i = 0; i < num_u16; i++)
  60                 sum += start[i];
  61 
  62         return sum;
  63 }
  64 
  65 static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
  66                               unsigned long sum)
  67 {
  68         sum += add_csum_hword(start, num_u16);
  69 
  70         while (sum >> 16)
  71                 sum = (sum & 0xffff) + (sum >> 16);
  72 
  73         return ~sum;
  74 }
  75 
  76 static int build_vnet_header(void *header)
  77 {
  78         struct virtio_net_hdr *vh = header;
  79 
  80         vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
  81 
  82         if (cfg_use_csum_off) {
  83                 vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
  84                 vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
  85                 vh->csum_offset = __builtin_offsetof(struct udphdr, check);
  86 
  87                 /* position check field exactly one byte beyond end of packet */
  88                 if (cfg_use_csum_off_bad)
  89                         vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
  90                                           vh->csum_offset - 1;
  91         }
  92 
  93         if (cfg_use_gso) {
  94                 vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
  95                 vh->gso_size = cfg_mtu - sizeof(struct iphdr);
  96         }
  97 
  98         return sizeof(*vh);
  99 }
 100 
 101 static int build_eth_header(void *header)
 102 {
 103         struct ethhdr *eth = header;
 104 
 105         if (cfg_use_vlan) {
 106                 uint16_t *tag = header + ETH_HLEN;
 107 
 108                 eth->h_proto = htons(ETH_P_8021Q);
 109                 tag[1] = htons(ETH_P_IP);
 110                 return ETH_HLEN + 4;
 111         }
 112 
 113         eth->h_proto = htons(ETH_P_IP);
 114         return ETH_HLEN;
 115 }
 116 
 117 static int build_ipv4_header(void *header, int payload_len)
 118 {
 119         struct iphdr *iph = header;
 120 
 121         iph->ihl = 5;
 122         iph->version = 4;
 123         iph->ttl = 8;
 124         iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
 125         iph->id = htons(1337);
 126         iph->protocol = IPPROTO_UDP;
 127         iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
 128         iph->daddr = htonl((172 << 24) | (17 << 16) | 1);
 129         iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);
 130 
 131         return iph->ihl << 2;
 132 }
 133 
 134 static int build_udp_header(void *header, int payload_len)
 135 {
 136         const int alen = sizeof(uint32_t);
 137         struct udphdr *udph = header;
 138         int len = sizeof(*udph) + payload_len;
 139 
 140         udph->source = htons(9);
 141         udph->dest = htons(cfg_port);
 142         udph->len = htons(len);
 143 
 144         if (cfg_use_csum_off)
 145                 udph->check = build_ip_csum(header - (2 * alen), alen,
 146                                             htons(IPPROTO_UDP) + udph->len);
 147         else
 148                 udph->check = 0;
 149 
 150         return sizeof(*udph);
 151 }
 152 
 153 static int build_packet(int payload_len)
 154 {
 155         int off = 0;
 156 
 157         off += build_vnet_header(tbuf);
 158         off += build_eth_header(tbuf + off);
 159         off += build_ipv4_header(tbuf + off, payload_len);
 160         off += build_udp_header(tbuf + off, payload_len);
 161 
 162         if (off + payload_len > sizeof(tbuf))
 163                 error(1, 0, "payload length exceeds max");
 164 
 165         memset(tbuf + off, DATA_CHAR, payload_len);
 166 
 167         return off + payload_len;
 168 }
 169 
 170 static void do_bind(int fd)
 171 {
 172         struct sockaddr_ll laddr = {0};
 173 
 174         laddr.sll_family = AF_PACKET;
 175         laddr.sll_protocol = htons(ETH_P_IP);
 176         laddr.sll_ifindex = if_nametoindex(cfg_ifname);
 177         if (!laddr.sll_ifindex)
 178                 error(1, errno, "if_nametoindex");
 179 
 180         if (bind(fd, (void *)&laddr, sizeof(laddr)))
 181                 error(1, errno, "bind");
 182 }
 183 
 184 static void do_send(int fd, char *buf, int len)
 185 {
 186         int ret;
 187 
 188         if (!cfg_use_vnet) {
 189                 buf += sizeof(struct virtio_net_hdr);
 190                 len -= sizeof(struct virtio_net_hdr);
 191         }
 192         if (cfg_use_dgram) {
 193                 buf += ETH_HLEN;
 194                 len -= ETH_HLEN;
 195         }
 196 
 197         if (cfg_use_bind) {
 198                 ret = write(fd, buf, len);
 199         } else {
 200                 struct sockaddr_ll laddr = {0};
 201 
 202                 laddr.sll_protocol = htons(ETH_P_IP);
 203                 laddr.sll_ifindex = if_nametoindex(cfg_ifname);
 204                 if (!laddr.sll_ifindex)
 205                         error(1, errno, "if_nametoindex");
 206 
 207                 ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
 208         }
 209 
 210         if (ret == -1)
 211                 error(1, errno, "write");
 212         if (ret != len)
 213                 error(1, 0, "write: %u %u", ret, len);
 214 
 215         fprintf(stderr, "tx: %u\n", ret);
 216 }
 217 
 218 static int do_tx(void)
 219 {
 220         const int one = 1;
 221         int fd, len;
 222 
 223         fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
 224         if (fd == -1)
 225                 error(1, errno, "socket t");
 226 
 227         if (cfg_use_bind)
 228                 do_bind(fd);
 229 
 230         if (cfg_use_qdisc_bypass &&
 231             setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
 232                 error(1, errno, "setsockopt qdisc bypass");
 233 
 234         if (cfg_use_vnet &&
 235             setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
 236                 error(1, errno, "setsockopt vnet");
 237 
 238         len = build_packet(cfg_payload_len);
 239 
 240         if (cfg_truncate_len < len)
 241                 len = cfg_truncate_len;
 242 
 243         do_send(fd, tbuf, len);
 244 
 245         if (close(fd))
 246                 error(1, errno, "close t");
 247 
 248         return len;
 249 }
 250 
 251 static int setup_rx(void)
 252 {
 253         struct timeval tv = { .tv_usec = 100 * 1000 };
 254         struct sockaddr_in raddr = {0};
 255         int fd;
 256 
 257         fd = socket(PF_INET, SOCK_DGRAM, 0);
 258         if (fd == -1)
 259                 error(1, errno, "socket r");
 260 
 261         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
 262                 error(1, errno, "setsockopt rcv timeout");
 263 
 264         raddr.sin_family = AF_INET;
 265         raddr.sin_port = htons(cfg_port);
 266         raddr.sin_addr.s_addr = htonl(INADDR_ANY);
 267 
 268         if (bind(fd, (void *)&raddr, sizeof(raddr)))
 269                 error(1, errno, "bind r");
 270 
 271         return fd;
 272 }
 273 
 274 static void do_rx(int fd, int expected_len, char *expected)
 275 {
 276         int ret;
 277 
 278         ret = recv(fd, rbuf, sizeof(rbuf), 0);
 279         if (ret == -1)
 280                 error(1, errno, "recv");
 281         if (ret != expected_len)
 282                 error(1, 0, "recv: %u != %u", ret, expected_len);
 283 
 284         if (memcmp(rbuf, expected, ret))
 285                 error(1, 0, "recv: data mismatch");
 286 
 287         fprintf(stderr, "rx: %u\n", ret);
 288 }
 289 
 290 static int setup_sniffer(void)
 291 {
 292         struct timeval tv = { .tv_usec = 100 * 1000 };
 293         int fd;
 294 
 295         fd = socket(PF_PACKET, SOCK_RAW, 0);
 296         if (fd == -1)
 297                 error(1, errno, "socket p");
 298 
 299         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
 300                 error(1, errno, "setsockopt rcv timeout");
 301 
 302         pair_udp_setfilter(fd);
 303         do_bind(fd);
 304 
 305         return fd;
 306 }
 307 
 308 static void parse_opts(int argc, char **argv)
 309 {
 310         int c;
 311 
 312         while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
 313                 switch (c) {
 314                 case 'b':
 315                         cfg_use_bind = true;
 316                         break;
 317                 case 'c':
 318                         cfg_use_csum_off = true;
 319                         break;
 320                 case 'C':
 321                         cfg_use_csum_off_bad = true;
 322                         break;
 323                 case 'd':
 324                         cfg_use_dgram = true;
 325                         break;
 326                 case 'g':
 327                         cfg_use_gso = true;
 328                         break;
 329                 case 'l':
 330                         cfg_payload_len = strtoul(optarg, NULL, 0);
 331                         break;
 332                 case 'q':
 333                         cfg_use_qdisc_bypass = true;
 334                         break;
 335                 case 't':
 336                         cfg_truncate_len = strtoul(optarg, NULL, 0);
 337                         break;
 338                 case 'v':
 339                         cfg_use_vnet = true;
 340                         break;
 341                 case 'V':
 342                         cfg_use_vlan = true;
 343                         break;
 344                 default:
 345                         error(1, 0, "%s: parse error", argv[0]);
 346                 }
 347         }
 348 
 349         if (cfg_use_vlan && cfg_use_dgram)
 350                 error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
 351 
 352         if (cfg_use_csum_off && !cfg_use_vnet)
 353                 error(1, 0, "option csum offload (-c) requires vnet (-v)");
 354 
 355         if (cfg_use_csum_off_bad && !cfg_use_csum_off)
 356                 error(1, 0, "option csum bad (-C) requires csum offload (-c)");
 357 
 358         if (cfg_use_gso && !cfg_use_csum_off)
 359                 error(1, 0, "option gso (-g) requires csum offload (-c)");
 360 }
 361 
 362 static void run_test(void)
 363 {
 364         int fdr, fds, total_len;
 365 
 366         fdr = setup_rx();
 367         fds = setup_sniffer();
 368 
 369         total_len = do_tx();
 370 
 371         /* BPF filter accepts only this length, vlan changes MAC */
 372         if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
 373                 do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
 374                       tbuf + sizeof(struct virtio_net_hdr));
 375 
 376         do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
 377 
 378         if (close(fds))
 379                 error(1, errno, "close s");
 380         if (close(fdr))
 381                 error(1, errno, "close r");
 382 }
 383 
 384 int main(int argc, char **argv)
 385 {
 386         parse_opts(argc, argv);
 387 
 388         if (system("ip link set dev lo mtu 1500"))
 389                 error(1, errno, "ip link set mtu");
 390         if (system("ip addr add dev lo 172.17.0.1/24"))
 391                 error(1, errno, "ip addr add");
 392 
 393         run_test();
 394 
 395         fprintf(stderr, "OK\n\n");
 396         return 0;
 397 }

/* [<][>][^][v][top][bottom][index][help] */