root/tools/testing/vsock/vsock_diag_test.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. sock_type_str
  2. sock_state_str
  3. sock_shutdown_str
  4. print_vsock_addr
  5. print_vsock_stat
  6. print_vsock_stats
  7. find_vsock_stat
  8. check_no_sockets
  9. check_num_sockets
  10. check_socket_state
  11. send_req
  12. recv_resp
  13. add_vsock_stat
  14. read_vsock_stat
  15. free_sock_stat
  16. test_no_sockets
  17. test_listen_socket_server
  18. test_connect_client
  19. test_connect_server
  20. init_signals
  21. parse_cid
  22. usage
  23. main

   1 // SPDX-License-Identifier: GPL-2.0-only
   2 /*
   3  * vsock_diag_test - vsock_diag.ko test suite
   4  *
   5  * Copyright (C) 2017 Red Hat, Inc.
   6  *
   7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
   8  */
   9 
  10 #include <getopt.h>
  11 #include <stdio.h>
  12 #include <stdbool.h>
  13 #include <stdlib.h>
  14 #include <string.h>
  15 #include <errno.h>
  16 #include <unistd.h>
  17 #include <signal.h>
  18 #include <sys/socket.h>
  19 #include <sys/stat.h>
  20 #include <sys/types.h>
  21 #include <linux/list.h>
  22 #include <linux/net.h>
  23 #include <linux/netlink.h>
  24 #include <linux/sock_diag.h>
  25 #include <netinet/tcp.h>
  26 
  27 #include "../../../include/uapi/linux/vm_sockets.h"
  28 #include "../../../include/uapi/linux/vm_sockets_diag.h"
  29 
  30 #include "timeout.h"
  31 #include "control.h"
  32 
  33 enum test_mode {
  34         TEST_MODE_UNSET,
  35         TEST_MODE_CLIENT,
  36         TEST_MODE_SERVER
  37 };
  38 
  39 /* Per-socket status */
  40 struct vsock_stat {
  41         struct list_head list;
  42         struct vsock_diag_msg msg;
  43 };
  44 
  45 static const char *sock_type_str(int type)
  46 {
  47         switch (type) {
  48         case SOCK_DGRAM:
  49                 return "DGRAM";
  50         case SOCK_STREAM:
  51                 return "STREAM";
  52         default:
  53                 return "INVALID TYPE";
  54         }
  55 }
  56 
  57 static const char *sock_state_str(int state)
  58 {
  59         switch (state) {
  60         case TCP_CLOSE:
  61                 return "UNCONNECTED";
  62         case TCP_SYN_SENT:
  63                 return "CONNECTING";
  64         case TCP_ESTABLISHED:
  65                 return "CONNECTED";
  66         case TCP_CLOSING:
  67                 return "DISCONNECTING";
  68         case TCP_LISTEN:
  69                 return "LISTEN";
  70         default:
  71                 return "INVALID STATE";
  72         }
  73 }
  74 
  75 static const char *sock_shutdown_str(int shutdown)
  76 {
  77         switch (shutdown) {
  78         case 1:
  79                 return "RCV_SHUTDOWN";
  80         case 2:
  81                 return "SEND_SHUTDOWN";
  82         case 3:
  83                 return "RCV_SHUTDOWN | SEND_SHUTDOWN";
  84         default:
  85                 return "0";
  86         }
  87 }
  88 
  89 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
  90 {
  91         if (cid == VMADDR_CID_ANY)
  92                 fprintf(fp, "*:");
  93         else
  94                 fprintf(fp, "%u:", cid);
  95 
  96         if (port == VMADDR_PORT_ANY)
  97                 fprintf(fp, "*");
  98         else
  99                 fprintf(fp, "%u", port);
 100 }
 101 
 102 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
 103 {
 104         print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
 105         fprintf(fp, " ");
 106         print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
 107         fprintf(fp, " %s %s %s %u\n",
 108                 sock_type_str(st->msg.vdiag_type),
 109                 sock_state_str(st->msg.vdiag_state),
 110                 sock_shutdown_str(st->msg.vdiag_shutdown),
 111                 st->msg.vdiag_ino);
 112 }
 113 
 114 static void print_vsock_stats(FILE *fp, struct list_head *head)
 115 {
 116         struct vsock_stat *st;
 117 
 118         list_for_each_entry(st, head, list)
 119                 print_vsock_stat(fp, st);
 120 }
 121 
 122 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
 123 {
 124         struct vsock_stat *st;
 125         struct stat stat;
 126 
 127         if (fstat(fd, &stat) < 0) {
 128                 perror("fstat");
 129                 exit(EXIT_FAILURE);
 130         }
 131 
 132         list_for_each_entry(st, head, list)
 133                 if (st->msg.vdiag_ino == stat.st_ino)
 134                         return st;
 135 
 136         fprintf(stderr, "cannot find fd %d\n", fd);
 137         exit(EXIT_FAILURE);
 138 }
 139 
 140 static void check_no_sockets(struct list_head *head)
 141 {
 142         if (!list_empty(head)) {
 143                 fprintf(stderr, "expected no sockets\n");
 144                 print_vsock_stats(stderr, head);
 145                 exit(1);
 146         }
 147 }
 148 
 149 static void check_num_sockets(struct list_head *head, int expected)
 150 {
 151         struct list_head *node;
 152         int n = 0;
 153 
 154         list_for_each(node, head)
 155                 n++;
 156 
 157         if (n != expected) {
 158                 fprintf(stderr, "expected %d sockets, found %d\n",
 159                         expected, n);
 160                 print_vsock_stats(stderr, head);
 161                 exit(EXIT_FAILURE);
 162         }
 163 }
 164 
 165 static void check_socket_state(struct vsock_stat *st, __u8 state)
 166 {
 167         if (st->msg.vdiag_state != state) {
 168                 fprintf(stderr, "expected socket state %#x, got %#x\n",
 169                         state, st->msg.vdiag_state);
 170                 exit(EXIT_FAILURE);
 171         }
 172 }
 173 
 174 static void send_req(int fd)
 175 {
 176         struct sockaddr_nl nladdr = {
 177                 .nl_family = AF_NETLINK,
 178         };
 179         struct {
 180                 struct nlmsghdr nlh;
 181                 struct vsock_diag_req vreq;
 182         } req = {
 183                 .nlh = {
 184                         .nlmsg_len = sizeof(req),
 185                         .nlmsg_type = SOCK_DIAG_BY_FAMILY,
 186                         .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
 187                 },
 188                 .vreq = {
 189                         .sdiag_family = AF_VSOCK,
 190                         .vdiag_states = ~(__u32)0,
 191                 },
 192         };
 193         struct iovec iov = {
 194                 .iov_base = &req,
 195                 .iov_len = sizeof(req),
 196         };
 197         struct msghdr msg = {
 198                 .msg_name = &nladdr,
 199                 .msg_namelen = sizeof(nladdr),
 200                 .msg_iov = &iov,
 201                 .msg_iovlen = 1,
 202         };
 203 
 204         for (;;) {
 205                 if (sendmsg(fd, &msg, 0) < 0) {
 206                         if (errno == EINTR)
 207                                 continue;
 208 
 209                         perror("sendmsg");
 210                         exit(EXIT_FAILURE);
 211                 }
 212 
 213                 return;
 214         }
 215 }
 216 
 217 static ssize_t recv_resp(int fd, void *buf, size_t len)
 218 {
 219         struct sockaddr_nl nladdr = {
 220                 .nl_family = AF_NETLINK,
 221         };
 222         struct iovec iov = {
 223                 .iov_base = buf,
 224                 .iov_len = len,
 225         };
 226         struct msghdr msg = {
 227                 .msg_name = &nladdr,
 228                 .msg_namelen = sizeof(nladdr),
 229                 .msg_iov = &iov,
 230                 .msg_iovlen = 1,
 231         };
 232         ssize_t ret;
 233 
 234         do {
 235                 ret = recvmsg(fd, &msg, 0);
 236         } while (ret < 0 && errno == EINTR);
 237 
 238         if (ret < 0) {
 239                 perror("recvmsg");
 240                 exit(EXIT_FAILURE);
 241         }
 242 
 243         return ret;
 244 }
 245 
 246 static void add_vsock_stat(struct list_head *sockets,
 247                            const struct vsock_diag_msg *resp)
 248 {
 249         struct vsock_stat *st;
 250 
 251         st = malloc(sizeof(*st));
 252         if (!st) {
 253                 perror("malloc");
 254                 exit(EXIT_FAILURE);
 255         }
 256 
 257         st->msg = *resp;
 258         list_add_tail(&st->list, sockets);
 259 }
 260 
 261 /*
 262  * Read vsock stats into a list.
 263  */
 264 static void read_vsock_stat(struct list_head *sockets)
 265 {
 266         long buf[8192 / sizeof(long)];
 267         int fd;
 268 
 269         fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
 270         if (fd < 0) {
 271                 perror("socket");
 272                 exit(EXIT_FAILURE);
 273         }
 274 
 275         send_req(fd);
 276 
 277         for (;;) {
 278                 const struct nlmsghdr *h;
 279                 ssize_t ret;
 280 
 281                 ret = recv_resp(fd, buf, sizeof(buf));
 282                 if (ret == 0)
 283                         goto done;
 284                 if (ret < sizeof(*h)) {
 285                         fprintf(stderr, "short read of %zd bytes\n", ret);
 286                         exit(EXIT_FAILURE);
 287                 }
 288 
 289                 h = (struct nlmsghdr *)buf;
 290 
 291                 while (NLMSG_OK(h, ret)) {
 292                         if (h->nlmsg_type == NLMSG_DONE)
 293                                 goto done;
 294 
 295                         if (h->nlmsg_type == NLMSG_ERROR) {
 296                                 const struct nlmsgerr *err = NLMSG_DATA(h);
 297 
 298                                 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
 299                                         fprintf(stderr, "NLMSG_ERROR\n");
 300                                 else {
 301                                         errno = -err->error;
 302                                         perror("NLMSG_ERROR");
 303                                 }
 304 
 305                                 exit(EXIT_FAILURE);
 306                         }
 307 
 308                         if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
 309                                 fprintf(stderr, "unexpected nlmsg_type %#x\n",
 310                                         h->nlmsg_type);
 311                                 exit(EXIT_FAILURE);
 312                         }
 313                         if (h->nlmsg_len <
 314                             NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
 315                                 fprintf(stderr, "short vsock_diag_msg\n");
 316                                 exit(EXIT_FAILURE);
 317                         }
 318 
 319                         add_vsock_stat(sockets, NLMSG_DATA(h));
 320 
 321                         h = NLMSG_NEXT(h, ret);
 322                 }
 323         }
 324 
 325 done:
 326         close(fd);
 327 }
 328 
 329 static void free_sock_stat(struct list_head *sockets)
 330 {
 331         struct vsock_stat *st;
 332         struct vsock_stat *next;
 333 
 334         list_for_each_entry_safe(st, next, sockets, list)
 335                 free(st);
 336 }
 337 
 338 static void test_no_sockets(unsigned int peer_cid)
 339 {
 340         LIST_HEAD(sockets);
 341 
 342         read_vsock_stat(&sockets);
 343 
 344         check_no_sockets(&sockets);
 345 
 346         free_sock_stat(&sockets);
 347 }
 348 
 349 static void test_listen_socket_server(unsigned int peer_cid)
 350 {
 351         union {
 352                 struct sockaddr sa;
 353                 struct sockaddr_vm svm;
 354         } addr = {
 355                 .svm = {
 356                         .svm_family = AF_VSOCK,
 357                         .svm_port = 1234,
 358                         .svm_cid = VMADDR_CID_ANY,
 359                 },
 360         };
 361         LIST_HEAD(sockets);
 362         struct vsock_stat *st;
 363         int fd;
 364 
 365         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 366 
 367         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 368                 perror("bind");
 369                 exit(EXIT_FAILURE);
 370         }
 371 
 372         if (listen(fd, 1) < 0) {
 373                 perror("listen");
 374                 exit(EXIT_FAILURE);
 375         }
 376 
 377         read_vsock_stat(&sockets);
 378 
 379         check_num_sockets(&sockets, 1);
 380         st = find_vsock_stat(&sockets, fd);
 381         check_socket_state(st, TCP_LISTEN);
 382 
 383         close(fd);
 384         free_sock_stat(&sockets);
 385 }
 386 
 387 static void test_connect_client(unsigned int peer_cid)
 388 {
 389         union {
 390                 struct sockaddr sa;
 391                 struct sockaddr_vm svm;
 392         } addr = {
 393                 .svm = {
 394                         .svm_family = AF_VSOCK,
 395                         .svm_port = 1234,
 396                         .svm_cid = peer_cid,
 397                 },
 398         };
 399         int fd;
 400         int ret;
 401         LIST_HEAD(sockets);
 402         struct vsock_stat *st;
 403 
 404         control_expectln("LISTENING");
 405 
 406         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 407 
 408         timeout_begin(TIMEOUT);
 409         do {
 410                 ret = connect(fd, &addr.sa, sizeof(addr.svm));
 411                 timeout_check("connect");
 412         } while (ret < 0 && errno == EINTR);
 413         timeout_end();
 414 
 415         if (ret < 0) {
 416                 perror("connect");
 417                 exit(EXIT_FAILURE);
 418         }
 419 
 420         read_vsock_stat(&sockets);
 421 
 422         check_num_sockets(&sockets, 1);
 423         st = find_vsock_stat(&sockets, fd);
 424         check_socket_state(st, TCP_ESTABLISHED);
 425 
 426         control_expectln("DONE");
 427         control_writeln("DONE");
 428 
 429         close(fd);
 430         free_sock_stat(&sockets);
 431 }
 432 
 433 static void test_connect_server(unsigned int peer_cid)
 434 {
 435         union {
 436                 struct sockaddr sa;
 437                 struct sockaddr_vm svm;
 438         } addr = {
 439                 .svm = {
 440                         .svm_family = AF_VSOCK,
 441                         .svm_port = 1234,
 442                         .svm_cid = VMADDR_CID_ANY,
 443                 },
 444         };
 445         union {
 446                 struct sockaddr sa;
 447                 struct sockaddr_vm svm;
 448         } clientaddr;
 449         socklen_t clientaddr_len = sizeof(clientaddr.svm);
 450         LIST_HEAD(sockets);
 451         struct vsock_stat *st;
 452         int fd;
 453         int client_fd;
 454 
 455         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 456 
 457         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
 458                 perror("bind");
 459                 exit(EXIT_FAILURE);
 460         }
 461 
 462         if (listen(fd, 1) < 0) {
 463                 perror("listen");
 464                 exit(EXIT_FAILURE);
 465         }
 466 
 467         control_writeln("LISTENING");
 468 
 469         timeout_begin(TIMEOUT);
 470         do {
 471                 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
 472                 timeout_check("accept");
 473         } while (client_fd < 0 && errno == EINTR);
 474         timeout_end();
 475 
 476         if (client_fd < 0) {
 477                 perror("accept");
 478                 exit(EXIT_FAILURE);
 479         }
 480         if (clientaddr.sa.sa_family != AF_VSOCK) {
 481                 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
 482                         clientaddr.sa.sa_family);
 483                 exit(EXIT_FAILURE);
 484         }
 485         if (clientaddr.svm.svm_cid != peer_cid) {
 486                 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
 487                         peer_cid, clientaddr.svm.svm_cid);
 488                 exit(EXIT_FAILURE);
 489         }
 490 
 491         read_vsock_stat(&sockets);
 492 
 493         check_num_sockets(&sockets, 2);
 494         find_vsock_stat(&sockets, fd);
 495         st = find_vsock_stat(&sockets, client_fd);
 496         check_socket_state(st, TCP_ESTABLISHED);
 497 
 498         control_writeln("DONE");
 499         control_expectln("DONE");
 500 
 501         close(client_fd);
 502         close(fd);
 503         free_sock_stat(&sockets);
 504 }
 505 
 506 static struct {
 507         const char *name;
 508         void (*run_client)(unsigned int peer_cid);
 509         void (*run_server)(unsigned int peer_cid);
 510 } test_cases[] = {
 511         {
 512                 .name = "No sockets",
 513                 .run_server = test_no_sockets,
 514         },
 515         {
 516                 .name = "Listen socket",
 517                 .run_server = test_listen_socket_server,
 518         },
 519         {
 520                 .name = "Connect",
 521                 .run_client = test_connect_client,
 522                 .run_server = test_connect_server,
 523         },
 524         {},
 525 };
 526 
 527 static void init_signals(void)
 528 {
 529         struct sigaction act = {
 530                 .sa_handler = sigalrm,
 531         };
 532 
 533         sigaction(SIGALRM, &act, NULL);
 534         signal(SIGPIPE, SIG_IGN);
 535 }
 536 
 537 static unsigned int parse_cid(const char *str)
 538 {
 539         char *endptr = NULL;
 540         unsigned long int n;
 541 
 542         errno = 0;
 543         n = strtoul(str, &endptr, 10);
 544         if (errno || *endptr != '\0') {
 545                 fprintf(stderr, "malformed CID \"%s\"\n", str);
 546                 exit(EXIT_FAILURE);
 547         }
 548         return n;
 549 }
 550 
 551 static const char optstring[] = "";
 552 static const struct option longopts[] = {
 553         {
 554                 .name = "control-host",
 555                 .has_arg = required_argument,
 556                 .val = 'H',
 557         },
 558         {
 559                 .name = "control-port",
 560                 .has_arg = required_argument,
 561                 .val = 'P',
 562         },
 563         {
 564                 .name = "mode",
 565                 .has_arg = required_argument,
 566                 .val = 'm',
 567         },
 568         {
 569                 .name = "peer-cid",
 570                 .has_arg = required_argument,
 571                 .val = 'p',
 572         },
 573         {
 574                 .name = "help",
 575                 .has_arg = no_argument,
 576                 .val = '?',
 577         },
 578         {},
 579 };
 580 
 581 static void usage(void)
 582 {
 583         fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
 584                 "\n"
 585                 "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
 586                 "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
 587                 "\n"
 588                 "Run vsock_diag.ko tests.  Must be launched in both\n"
 589                 "guest and host.  One side must use --mode=client and\n"
 590                 "the other side must use --mode=server.\n"
 591                 "\n"
 592                 "A TCP control socket connection is used to coordinate tests\n"
 593                 "between the client and the server.  The server requires a\n"
 594                 "listen address and the client requires an address to\n"
 595                 "connect to.\n"
 596                 "\n"
 597                 "The CID of the other side must be given with --peer-cid=<cid>.\n");
 598         exit(EXIT_FAILURE);
 599 }
 600 
 601 int main(int argc, char **argv)
 602 {
 603         const char *control_host = NULL;
 604         const char *control_port = NULL;
 605         int mode = TEST_MODE_UNSET;
 606         unsigned int peer_cid = VMADDR_CID_ANY;
 607         int i;
 608 
 609         init_signals();
 610 
 611         for (;;) {
 612                 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
 613 
 614                 if (opt == -1)
 615                         break;
 616 
 617                 switch (opt) {
 618                 case 'H':
 619                         control_host = optarg;
 620                         break;
 621                 case 'm':
 622                         if (strcmp(optarg, "client") == 0)
 623                                 mode = TEST_MODE_CLIENT;
 624                         else if (strcmp(optarg, "server") == 0)
 625                                 mode = TEST_MODE_SERVER;
 626                         else {
 627                                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
 628                                 return EXIT_FAILURE;
 629                         }
 630                         break;
 631                 case 'p':
 632                         peer_cid = parse_cid(optarg);
 633                         break;
 634                 case 'P':
 635                         control_port = optarg;
 636                         break;
 637                 case '?':
 638                 default:
 639                         usage();
 640                 }
 641         }
 642 
 643         if (!control_port)
 644                 usage();
 645         if (mode == TEST_MODE_UNSET)
 646                 usage();
 647         if (peer_cid == VMADDR_CID_ANY)
 648                 usage();
 649 
 650         if (!control_host) {
 651                 if (mode != TEST_MODE_SERVER)
 652                         usage();
 653                 control_host = "0.0.0.0";
 654         }
 655 
 656         control_init(control_host, control_port, mode == TEST_MODE_SERVER);
 657 
 658         for (i = 0; test_cases[i].name; i++) {
 659                 void (*run)(unsigned int peer_cid);
 660 
 661                 printf("%s...", test_cases[i].name);
 662                 fflush(stdout);
 663 
 664                 if (mode == TEST_MODE_CLIENT)
 665                         run = test_cases[i].run_client;
 666                 else
 667                         run = test_cases[i].run_server;
 668 
 669                 if (run)
 670                         run(peer_cid);
 671 
 672                 printf("ok\n");
 673         }
 674 
 675         control_cleanup();
 676         return EXIT_SUCCESS;
 677 }

/* [<][>][^][v][top][bottom][index][help] */