root/net/ipv4/tcp_bpf.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. tcp_bpf_stream_read
  2. tcp_bpf_wait_data
  3. __tcp_bpf_recvmsg
  4. tcp_bpf_recvmsg
  5. bpf_tcp_ingress
  6. tcp_bpf_push
  7. tcp_bpf_push_locked
  8. tcp_bpf_sendmsg_redir
  9. tcp_bpf_send_verdict
  10. tcp_bpf_sendmsg
  11. tcp_bpf_sendpage
  12. tcp_bpf_remove
  13. tcp_bpf_unhash
  14. tcp_bpf_close
  15. tcp_bpf_rebuild_protos
  16. tcp_bpf_check_v6_needs_rebuild
  17. tcp_bpf_v4_build_proto
  18. tcp_bpf_update_sk_prot
  19. tcp_bpf_reinit_sk_prot
  20. tcp_bpf_assert_proto_ops
  21. tcp_bpf_reinit
  22. tcp_bpf_init

   1 // SPDX-License-Identifier: GPL-2.0
   2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3 
   4 #include <linux/skmsg.h>
   5 #include <linux/filter.h>
   6 #include <linux/bpf.h>
   7 #include <linux/init.h>
   8 #include <linux/wait.h>
   9 
  10 #include <net/inet_common.h>
  11 #include <net/tls.h>
  12 
  13 static bool tcp_bpf_stream_read(const struct sock *sk)
  14 {
  15         struct sk_psock *psock;
  16         bool empty = true;
  17 
  18         rcu_read_lock();
  19         psock = sk_psock(sk);
  20         if (likely(psock))
  21                 empty = list_empty(&psock->ingress_msg);
  22         rcu_read_unlock();
  23         return !empty;
  24 }
  25 
  26 static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
  27                              int flags, long timeo, int *err)
  28 {
  29         DEFINE_WAIT_FUNC(wait, woken_wake_function);
  30         int ret = 0;
  31 
  32         if (!timeo)
  33                 return ret;
  34 
  35         add_wait_queue(sk_sleep(sk), &wait);
  36         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  37         ret = sk_wait_event(sk, &timeo,
  38                             !list_empty(&psock->ingress_msg) ||
  39                             !skb_queue_empty(&sk->sk_receive_queue), &wait);
  40         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  41         remove_wait_queue(sk_sleep(sk), &wait);
  42         return ret;
  43 }
  44 
  45 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
  46                       struct msghdr *msg, int len, int flags)
  47 {
  48         struct iov_iter *iter = &msg->msg_iter;
  49         int peek = flags & MSG_PEEK;
  50         int i, ret, copied = 0;
  51         struct sk_msg *msg_rx;
  52 
  53         msg_rx = list_first_entry_or_null(&psock->ingress_msg,
  54                                           struct sk_msg, list);
  55 
  56         while (copied != len) {
  57                 struct scatterlist *sge;
  58 
  59                 if (unlikely(!msg_rx))
  60                         break;
  61 
  62                 i = msg_rx->sg.start;
  63                 do {
  64                         struct page *page;
  65                         int copy;
  66 
  67                         sge = sk_msg_elem(msg_rx, i);
  68                         copy = sge->length;
  69                         page = sg_page(sge);
  70                         if (copied + copy > len)
  71                                 copy = len - copied;
  72                         ret = copy_page_to_iter(page, sge->offset, copy, iter);
  73                         if (ret != copy) {
  74                                 msg_rx->sg.start = i;
  75                                 return -EFAULT;
  76                         }
  77 
  78                         copied += copy;
  79                         if (likely(!peek)) {
  80                                 sge->offset += copy;
  81                                 sge->length -= copy;
  82                                 sk_mem_uncharge(sk, copy);
  83                                 msg_rx->sg.size -= copy;
  84 
  85                                 if (!sge->length) {
  86                                         sk_msg_iter_var_next(i);
  87                                         if (!msg_rx->skb)
  88                                                 put_page(page);
  89                                 }
  90                         } else {
  91                                 sk_msg_iter_var_next(i);
  92                         }
  93 
  94                         if (copied == len)
  95                                 break;
  96                 } while (i != msg_rx->sg.end);
  97 
  98                 if (unlikely(peek)) {
  99                         msg_rx = list_next_entry(msg_rx, list);
 100                         continue;
 101                 }
 102 
 103                 msg_rx->sg.start = i;
 104                 if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
 105                         list_del(&msg_rx->list);
 106                         if (msg_rx->skb)
 107                                 consume_skb(msg_rx->skb);
 108                         kfree(msg_rx);
 109                 }
 110                 msg_rx = list_first_entry_or_null(&psock->ingress_msg,
 111                                                   struct sk_msg, list);
 112         }
 113 
 114         return copied;
 115 }
 116 EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
 117 
 118 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 119                     int nonblock, int flags, int *addr_len)
 120 {
 121         struct sk_psock *psock;
 122         int copied, ret;
 123 
 124         if (unlikely(flags & MSG_ERRQUEUE))
 125                 return inet_recv_error(sk, msg, len, addr_len);
 126 
 127         psock = sk_psock_get(sk);
 128         if (unlikely(!psock))
 129                 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 130         if (!skb_queue_empty(&sk->sk_receive_queue) &&
 131             sk_psock_queue_empty(psock)) {
 132                 sk_psock_put(sk, psock);
 133                 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 134         }
 135         lock_sock(sk);
 136 msg_bytes_ready:
 137         copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
 138         if (!copied) {
 139                 int data, err = 0;
 140                 long timeo;
 141 
 142                 timeo = sock_rcvtimeo(sk, nonblock);
 143                 data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
 144                 if (data) {
 145                         if (!sk_psock_queue_empty(psock))
 146                                 goto msg_bytes_ready;
 147                         release_sock(sk);
 148                         sk_psock_put(sk, psock);
 149                         return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 150                 }
 151                 if (err) {
 152                         ret = err;
 153                         goto out;
 154                 }
 155                 copied = -EAGAIN;
 156         }
 157         ret = copied;
 158 out:
 159         release_sock(sk);
 160         sk_psock_put(sk, psock);
 161         return ret;
 162 }
 163 
 164 static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
 165                            struct sk_msg *msg, u32 apply_bytes, int flags)
 166 {
 167         bool apply = apply_bytes;
 168         struct scatterlist *sge;
 169         u32 size, copied = 0;
 170         struct sk_msg *tmp;
 171         int i, ret = 0;
 172 
 173         tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
 174         if (unlikely(!tmp))
 175                 return -ENOMEM;
 176 
 177         lock_sock(sk);
 178         tmp->sg.start = msg->sg.start;
 179         i = msg->sg.start;
 180         do {
 181                 sge = sk_msg_elem(msg, i);
 182                 size = (apply && apply_bytes < sge->length) ?
 183                         apply_bytes : sge->length;
 184                 if (!sk_wmem_schedule(sk, size)) {
 185                         if (!copied)
 186                                 ret = -ENOMEM;
 187                         break;
 188                 }
 189 
 190                 sk_mem_charge(sk, size);
 191                 sk_msg_xfer(tmp, msg, i, size);
 192                 copied += size;
 193                 if (sge->length)
 194                         get_page(sk_msg_page(tmp, i));
 195                 sk_msg_iter_var_next(i);
 196                 tmp->sg.end = i;
 197                 if (apply) {
 198                         apply_bytes -= size;
 199                         if (!apply_bytes)
 200                                 break;
 201                 }
 202         } while (i != msg->sg.end);
 203 
 204         if (!ret) {
 205                 msg->sg.start = i;
 206                 sk_psock_queue_msg(psock, tmp);
 207                 sk_psock_data_ready(sk, psock);
 208         } else {
 209                 sk_msg_free(sk, tmp);
 210                 kfree(tmp);
 211         }
 212 
 213         release_sock(sk);
 214         return ret;
 215 }
 216 
 217 static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
 218                         int flags, bool uncharge)
 219 {
 220         bool apply = apply_bytes;
 221         struct scatterlist *sge;
 222         struct page *page;
 223         int size, ret = 0;
 224         u32 off;
 225 
 226         while (1) {
 227                 bool has_tx_ulp;
 228 
 229                 sge = sk_msg_elem(msg, msg->sg.start);
 230                 size = (apply && apply_bytes < sge->length) ?
 231                         apply_bytes : sge->length;
 232                 off  = sge->offset;
 233                 page = sg_page(sge);
 234 
 235                 tcp_rate_check_app_limited(sk);
 236 retry:
 237                 has_tx_ulp = tls_sw_has_ctx_tx(sk);
 238                 if (has_tx_ulp) {
 239                         flags |= MSG_SENDPAGE_NOPOLICY;
 240                         ret = kernel_sendpage_locked(sk,
 241                                                      page, off, size, flags);
 242                 } else {
 243                         ret = do_tcp_sendpages(sk, page, off, size, flags);
 244                 }
 245 
 246                 if (ret <= 0)
 247                         return ret;
 248                 if (apply)
 249                         apply_bytes -= ret;
 250                 msg->sg.size -= ret;
 251                 sge->offset += ret;
 252                 sge->length -= ret;
 253                 if (uncharge)
 254                         sk_mem_uncharge(sk, ret);
 255                 if (ret != size) {
 256                         size -= ret;
 257                         off  += ret;
 258                         goto retry;
 259                 }
 260                 if (!sge->length) {
 261                         put_page(page);
 262                         sk_msg_iter_next(msg, start);
 263                         sg_init_table(sge, 1);
 264                         if (msg->sg.start == msg->sg.end)
 265                                 break;
 266                 }
 267                 if (apply && !apply_bytes)
 268                         break;
 269         }
 270 
 271         return 0;
 272 }
 273 
 274 static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
 275                                u32 apply_bytes, int flags, bool uncharge)
 276 {
 277         int ret;
 278 
 279         lock_sock(sk);
 280         ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
 281         release_sock(sk);
 282         return ret;
 283 }
 284 
 285 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
 286                           u32 bytes, int flags)
 287 {
 288         bool ingress = sk_msg_to_ingress(msg);
 289         struct sk_psock *psock = sk_psock_get(sk);
 290         int ret;
 291 
 292         if (unlikely(!psock)) {
 293                 sk_msg_free(sk, msg);
 294                 return 0;
 295         }
 296         ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
 297                         tcp_bpf_push_locked(sk, msg, bytes, flags, false);
 298         sk_psock_put(sk, psock);
 299         return ret;
 300 }
 301 EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
 302 
 303 static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 304                                 struct sk_msg *msg, int *copied, int flags)
 305 {
 306         bool cork = false, enospc = sk_msg_full(msg);
 307         struct sock *sk_redir;
 308         u32 tosend, delta = 0;
 309         int ret;
 310 
 311 more_data:
 312         if (psock->eval == __SK_NONE) {
 313                 /* Track delta in msg size to add/subtract it on SK_DROP from
 314                  * returned to user copied size. This ensures user doesn't
 315                  * get a positive return code with msg_cut_data and SK_DROP
 316                  * verdict.
 317                  */
 318                 delta = msg->sg.size;
 319                 psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 320                 delta -= msg->sg.size;
 321         }
 322 
 323         if (msg->cork_bytes &&
 324             msg->cork_bytes > msg->sg.size && !enospc) {
 325                 psock->cork_bytes = msg->cork_bytes - msg->sg.size;
 326                 if (!psock->cork) {
 327                         psock->cork = kzalloc(sizeof(*psock->cork),
 328                                               GFP_ATOMIC | __GFP_NOWARN);
 329                         if (!psock->cork)
 330                                 return -ENOMEM;
 331                 }
 332                 memcpy(psock->cork, msg, sizeof(*msg));
 333                 return 0;
 334         }
 335 
 336         tosend = msg->sg.size;
 337         if (psock->apply_bytes && psock->apply_bytes < tosend)
 338                 tosend = psock->apply_bytes;
 339 
 340         switch (psock->eval) {
 341         case __SK_PASS:
 342                 ret = tcp_bpf_push(sk, msg, tosend, flags, true);
 343                 if (unlikely(ret)) {
 344                         *copied -= sk_msg_free(sk, msg);
 345                         break;
 346                 }
 347                 sk_msg_apply_bytes(psock, tosend);
 348                 break;
 349         case __SK_REDIRECT:
 350                 sk_redir = psock->sk_redir;
 351                 sk_msg_apply_bytes(psock, tosend);
 352                 if (psock->cork) {
 353                         cork = true;
 354                         psock->cork = NULL;
 355                 }
 356                 sk_msg_return(sk, msg, tosend);
 357                 release_sock(sk);
 358                 ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
 359                 lock_sock(sk);
 360                 if (unlikely(ret < 0)) {
 361                         int free = sk_msg_free_nocharge(sk, msg);
 362 
 363                         if (!cork)
 364                                 *copied -= free;
 365                 }
 366                 if (cork) {
 367                         sk_msg_free(sk, msg);
 368                         kfree(msg);
 369                         msg = NULL;
 370                         ret = 0;
 371                 }
 372                 break;
 373         case __SK_DROP:
 374         default:
 375                 sk_msg_free_partial(sk, msg, tosend);
 376                 sk_msg_apply_bytes(psock, tosend);
 377                 *copied -= (tosend + delta);
 378                 return -EACCES;
 379         }
 380 
 381         if (likely(!ret)) {
 382                 if (!psock->apply_bytes) {
 383                         psock->eval =  __SK_NONE;
 384                         if (psock->sk_redir) {
 385                                 sock_put(psock->sk_redir);
 386                                 psock->sk_redir = NULL;
 387                         }
 388                 }
 389                 if (msg &&
 390                     msg->sg.data[msg->sg.start].page_link &&
 391                     msg->sg.data[msg->sg.start].length)
 392                         goto more_data;
 393         }
 394         return ret;
 395 }
 396 
 397 static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 398 {
 399         struct sk_msg tmp, *msg_tx = NULL;
 400         int copied = 0, err = 0;
 401         struct sk_psock *psock;
 402         long timeo;
 403         int flags;
 404 
 405         /* Don't let internal do_tcp_sendpages() flags through */
 406         flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
 407         flags |= MSG_NO_SHARED_FRAGS;
 408 
 409         psock = sk_psock_get(sk);
 410         if (unlikely(!psock))
 411                 return tcp_sendmsg(sk, msg, size);
 412 
 413         lock_sock(sk);
 414         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 415         while (msg_data_left(msg)) {
 416                 bool enospc = false;
 417                 u32 copy, osize;
 418 
 419                 if (sk->sk_err) {
 420                         err = -sk->sk_err;
 421                         goto out_err;
 422                 }
 423 
 424                 copy = msg_data_left(msg);
 425                 if (!sk_stream_memory_free(sk))
 426                         goto wait_for_sndbuf;
 427                 if (psock->cork) {
 428                         msg_tx = psock->cork;
 429                 } else {
 430                         msg_tx = &tmp;
 431                         sk_msg_init(msg_tx);
 432                 }
 433 
 434                 osize = msg_tx->sg.size;
 435                 err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
 436                 if (err) {
 437                         if (err != -ENOSPC)
 438                                 goto wait_for_memory;
 439                         enospc = true;
 440                         copy = msg_tx->sg.size - osize;
 441                 }
 442 
 443                 err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
 444                                                copy);
 445                 if (err < 0) {
 446                         sk_msg_trim(sk, msg_tx, osize);
 447                         goto out_err;
 448                 }
 449 
 450                 copied += copy;
 451                 if (psock->cork_bytes) {
 452                         if (size > psock->cork_bytes)
 453                                 psock->cork_bytes = 0;
 454                         else
 455                                 psock->cork_bytes -= size;
 456                         if (psock->cork_bytes && !enospc)
 457                                 goto out_err;
 458                         /* All cork bytes are accounted, rerun the prog. */
 459                         psock->eval = __SK_NONE;
 460                         psock->cork_bytes = 0;
 461                 }
 462 
 463                 err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
 464                 if (unlikely(err < 0))
 465                         goto out_err;
 466                 continue;
 467 wait_for_sndbuf:
 468                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 469 wait_for_memory:
 470                 err = sk_stream_wait_memory(sk, &timeo);
 471                 if (err) {
 472                         if (msg_tx && msg_tx != psock->cork)
 473                                 sk_msg_free(sk, msg_tx);
 474                         goto out_err;
 475                 }
 476         }
 477 out_err:
 478         if (err < 0)
 479                 err = sk_stream_error(sk, msg->msg_flags, err);
 480         release_sock(sk);
 481         sk_psock_put(sk, psock);
 482         return copied ? copied : err;
 483 }
 484 
 485 static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
 486                             size_t size, int flags)
 487 {
 488         struct sk_msg tmp, *msg = NULL;
 489         int err = 0, copied = 0;
 490         struct sk_psock *psock;
 491         bool enospc = false;
 492 
 493         psock = sk_psock_get(sk);
 494         if (unlikely(!psock))
 495                 return tcp_sendpage(sk, page, offset, size, flags);
 496 
 497         lock_sock(sk);
 498         if (psock->cork) {
 499                 msg = psock->cork;
 500         } else {
 501                 msg = &tmp;
 502                 sk_msg_init(msg);
 503         }
 504 
 505         /* Catch case where ring is full and sendpage is stalled. */
 506         if (unlikely(sk_msg_full(msg)))
 507                 goto out_err;
 508 
 509         sk_msg_page_add(msg, page, size, offset);
 510         sk_mem_charge(sk, size);
 511         copied = size;
 512         if (sk_msg_full(msg))
 513                 enospc = true;
 514         if (psock->cork_bytes) {
 515                 if (size > psock->cork_bytes)
 516                         psock->cork_bytes = 0;
 517                 else
 518                         psock->cork_bytes -= size;
 519                 if (psock->cork_bytes && !enospc)
 520                         goto out_err;
 521                 /* All cork bytes are accounted, rerun the prog. */
 522                 psock->eval = __SK_NONE;
 523                 psock->cork_bytes = 0;
 524         }
 525 
 526         err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
 527 out_err:
 528         release_sock(sk);
 529         sk_psock_put(sk, psock);
 530         return copied ? copied : err;
 531 }
 532 
 533 static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
 534 {
 535         struct sk_psock_link *link;
 536 
 537         while ((link = sk_psock_link_pop(psock))) {
 538                 sk_psock_unlink(sk, link);
 539                 sk_psock_free_link(link);
 540         }
 541 }
 542 
 543 static void tcp_bpf_unhash(struct sock *sk)
 544 {
 545         void (*saved_unhash)(struct sock *sk);
 546         struct sk_psock *psock;
 547 
 548         rcu_read_lock();
 549         psock = sk_psock(sk);
 550         if (unlikely(!psock)) {
 551                 rcu_read_unlock();
 552                 if (sk->sk_prot->unhash)
 553                         sk->sk_prot->unhash(sk);
 554                 return;
 555         }
 556 
 557         saved_unhash = psock->saved_unhash;
 558         tcp_bpf_remove(sk, psock);
 559         rcu_read_unlock();
 560         saved_unhash(sk);
 561 }
 562 
 563 static void tcp_bpf_close(struct sock *sk, long timeout)
 564 {
 565         void (*saved_close)(struct sock *sk, long timeout);
 566         struct sk_psock *psock;
 567 
 568         lock_sock(sk);
 569         rcu_read_lock();
 570         psock = sk_psock(sk);
 571         if (unlikely(!psock)) {
 572                 rcu_read_unlock();
 573                 release_sock(sk);
 574                 return sk->sk_prot->close(sk, timeout);
 575         }
 576 
 577         saved_close = psock->saved_close;
 578         tcp_bpf_remove(sk, psock);
 579         rcu_read_unlock();
 580         release_sock(sk);
 581         saved_close(sk, timeout);
 582 }
 583 
 584 enum {
 585         TCP_BPF_IPV4,
 586         TCP_BPF_IPV6,
 587         TCP_BPF_NUM_PROTS,
 588 };
 589 
 590 enum {
 591         TCP_BPF_BASE,
 592         TCP_BPF_TX,
 593         TCP_BPF_NUM_CFGS,
 594 };
 595 
 596 static struct proto *tcpv6_prot_saved __read_mostly;
 597 static DEFINE_SPINLOCK(tcpv6_prot_lock);
 598 static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
 599 
 600 static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 601                                    struct proto *base)
 602 {
 603         prot[TCP_BPF_BASE]                      = *base;
 604         prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
 605         prot[TCP_BPF_BASE].close                = tcp_bpf_close;
 606         prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
 607         prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;
 608 
 609         prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
 610         prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
 611         prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
 612 }
 613 
 614 static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
 615 {
 616         if (sk->sk_family == AF_INET6 &&
 617             unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
 618                 spin_lock_bh(&tcpv6_prot_lock);
 619                 if (likely(ops != tcpv6_prot_saved)) {
 620                         tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
 621                         smp_store_release(&tcpv6_prot_saved, ops);
 622                 }
 623                 spin_unlock_bh(&tcpv6_prot_lock);
 624         }
 625 }
 626 
 627 static int __init tcp_bpf_v4_build_proto(void)
 628 {
 629         tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
 630         return 0;
 631 }
 632 core_initcall(tcp_bpf_v4_build_proto);
 633 
 634 static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
 635 {
 636         int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 637         int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 638 
 639         sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
 640 }
 641 
 642 static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
 643 {
 644         int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 645         int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 646 
 647         /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
 648          * or added requiring sk_prot hook updates. We keep original saved
 649          * hooks in this case.
 650          */
 651         sk->sk_prot = &tcp_bpf_prots[family][config];
 652 }
 653 
 654 static int tcp_bpf_assert_proto_ops(struct proto *ops)
 655 {
 656         /* In order to avoid retpoline, we make assumptions when we call
 657          * into ops if e.g. a psock is not present. Make sure they are
 658          * indeed valid assumptions.
 659          */
 660         return ops->recvmsg  == tcp_recvmsg &&
 661                ops->sendmsg  == tcp_sendmsg &&
 662                ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 663 }
 664 
 665 void tcp_bpf_reinit(struct sock *sk)
 666 {
 667         struct sk_psock *psock;
 668 
 669         sock_owned_by_me(sk);
 670 
 671         rcu_read_lock();
 672         psock = sk_psock(sk);
 673         tcp_bpf_reinit_sk_prot(sk, psock);
 674         rcu_read_unlock();
 675 }
 676 
 677 int tcp_bpf_init(struct sock *sk)
 678 {
 679         struct proto *ops = READ_ONCE(sk->sk_prot);
 680         struct sk_psock *psock;
 681 
 682         sock_owned_by_me(sk);
 683 
 684         rcu_read_lock();
 685         psock = sk_psock(sk);
 686         if (unlikely(!psock || psock->sk_proto ||
 687                      tcp_bpf_assert_proto_ops(ops))) {
 688                 rcu_read_unlock();
 689                 return -EINVAL;
 690         }
 691         tcp_bpf_check_v6_needs_rebuild(sk, ops);
 692         tcp_bpf_update_sk_prot(sk, psock);
 693         rcu_read_unlock();
 694         return 0;
 695 }

/* [<][>][^][v][top][bottom][index][help] */