net/core/skmsg.c


DEFINITIONS

This source file includes the following definitions.
  1. sk_msg_try_coalesce_ok
  2. sk_msg_alloc
  3. sk_msg_clone
  4. sk_msg_return_zero
  5. sk_msg_return
  6. sk_msg_free_elem
  7. __sk_msg_free
  8. sk_msg_free_nocharge
  9. sk_msg_free
  10. __sk_msg_free_partial
  11. sk_msg_free_partial
  12. sk_msg_free_partial_nocharge
  13. sk_msg_trim
  14. sk_msg_zerocopy_from_iter
  15. sk_msg_memcopy_from_iter
  16. sk_psock_skb_ingress
  17. sk_psock_handle_skb
  18. sk_psock_backlog
  19. sk_psock_init
  20. sk_psock_link_pop
  21. __sk_psock_purge_ingress_msg
  22. sk_psock_zap_ingress
  23. sk_psock_link_destroy
  24. sk_psock_destroy_deferred
  25. sk_psock_destroy
  26. sk_psock_drop
  27. sk_psock_map_verd
  28. sk_psock_msg_verdict
  29. sk_psock_bpf_run
  30. sk_psock_from_strp
  31. sk_psock_verdict_apply
  32. sk_psock_strp_read
  33. sk_psock_strp_read_done
  34. sk_psock_strp_parse
  35. sk_psock_strp_data_ready
  36. sk_psock_write_space
  37. sk_psock_init_strp
  38. sk_psock_start_strp
  39. sk_psock_stop_strp

   1 // SPDX-License-Identifier: GPL-2.0
   2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3 
   4 #include <linux/skmsg.h>
   5 #include <linux/skbuff.h>
   6 #include <linux/scatterlist.h>
   7 
   8 #include <net/sock.h>
   9 #include <net/tcp.h>
  10 
  11 static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
  12 {
  13         if (msg->sg.end > msg->sg.start &&
  14             elem_first_coalesce < msg->sg.end)
  15                 return true;
  16 
  17         if (msg->sg.end < msg->sg.start &&
  18             (elem_first_coalesce > msg->sg.start ||
  19              elem_first_coalesce < msg->sg.end))
  20                 return true;
  21 
  22         return false;
  23 }
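/* For illustration only (hypothetical index values): the sg ring wraps at
 * MAX_MSG_FRAGS, hence the two cases above. Non-wrapped ring with
 * start = 1, end = 4: elem_first_coalesce = 3 (< end) may coalesce.
 * Wrapped ring with start = 5, end = 2: elem_first_coalesce = 6 (> start)
 * or 1 (< end) may coalesce, while 3 lies outside the occupied region
 * and may not.
 */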
  24 
  25 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
  26                  int elem_first_coalesce)
  27 {
  28         struct page_frag *pfrag = sk_page_frag(sk);
  29         int ret = 0;
  30 
  31         len -= msg->sg.size;
  32         while (len > 0) {
  33                 struct scatterlist *sge;
  34                 u32 orig_offset;
  35                 int use, i;
  36 
  37                 if (!sk_page_frag_refill(sk, pfrag))
  38                         return -ENOMEM;
  39 
  40                 orig_offset = pfrag->offset;
  41                 use = min_t(int, len, pfrag->size - orig_offset);
  42                 if (!sk_wmem_schedule(sk, use))
  43                         return -ENOMEM;
  44 
  45                 i = msg->sg.end;
  46                 sk_msg_iter_var_prev(i);
  47                 sge = &msg->sg.data[i];
  48 
  49                 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
  50                     sg_page(sge) == pfrag->page &&
  51                     sge->offset + sge->length == orig_offset) {
  52                         sge->length += use;
  53                 } else {
  54                         if (sk_msg_full(msg)) {
  55                                 ret = -ENOSPC;
  56                                 break;
  57                         }
  58 
  59                         sge = &msg->sg.data[msg->sg.end];
  60                         sg_unmark_end(sge);
  61                         sg_set_page(sge, pfrag->page, use, orig_offset);
  62                         get_page(pfrag->page);
  63                         sk_msg_iter_next(msg, end);
  64                 }
  65 
  66                 sk_mem_charge(sk, use);
  67                 msg->sg.size += use;
  68                 pfrag->offset += use;
  69                 len -= use;
  70         }
  71 
  72         return ret;
  73 }
  74 EXPORT_SYMBOL_GPL(sk_msg_alloc);
  75 
  76 int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
  77                  u32 off, u32 len)
  78 {
  79         int i = src->sg.start;
  80         struct scatterlist *sge = sk_msg_elem(src, i);
  81         struct scatterlist *sgd = NULL;
  82         u32 sge_len, sge_off;
  83 
  84         while (off) {
  85                 if (sge->length > off)
  86                         break;
  87                 off -= sge->length;
  88                 sk_msg_iter_var_next(i);
  89                 if (i == src->sg.end && off)
  90                         return -ENOSPC;
  91                 sge = sk_msg_elem(src, i);
  92         }
  93 
  94         while (len) {
  95                 sge_len = sge->length - off;
  96                 if (sge_len > len)
  97                         sge_len = len;
  98 
  99                 if (dst->sg.end)
 100                         sgd = sk_msg_elem(dst, dst->sg.end - 1);
 101 
 102                 if (sgd &&
 103                     (sg_page(sge) == sg_page(sgd)) &&
 104                     (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
 105                         sgd->length += sge_len;
 106                         dst->sg.size += sge_len;
 107                 } else if (!sk_msg_full(dst)) {
 108                         sge_off = sge->offset + off;
 109                         sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
 110                 } else {
 111                         return -ENOSPC;
 112                 }
 113 
 114                 off = 0;
 115                 len -= sge_len;
 116                 sk_mem_charge(sk, sge_len);
 117                 sk_msg_iter_var_next(i);
 118                 if (i == src->sg.end && len)
 119                         return -ENOSPC;
 120                 sge = sk_msg_elem(src, i);
 121         }
 122 
 123         return 0;
 124 }
 125 EXPORT_SYMBOL_GPL(sk_msg_clone);
 126 
 127 void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
 128 {
 129         int i = msg->sg.start;
 130 
 131         do {
 132                 struct scatterlist *sge = sk_msg_elem(msg, i);
 133 
 134                 if (bytes < sge->length) {
 135                         sge->length -= bytes;
 136                         sge->offset += bytes;
 137                         sk_mem_uncharge(sk, bytes);
 138                         break;
 139                 }
 140 
 141                 sk_mem_uncharge(sk, sge->length);
 142                 bytes -= sge->length;
 143                 sge->length = 0;
 144                 sge->offset = 0;
 145                 sk_msg_iter_var_next(i);
 146         } while (bytes && i != msg->sg.end);
 147         msg->sg.start = i;
 148 }
 149 EXPORT_SYMBOL_GPL(sk_msg_return_zero);
 150 
 151 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
 152 {
 153         int i = msg->sg.start;
 154 
 155         do {
 156                 struct scatterlist *sge = &msg->sg.data[i];
 157                 int uncharge = (bytes < sge->length) ? bytes : sge->length;
 158 
 159                 sk_mem_uncharge(sk, uncharge);
 160                 bytes -= uncharge;
 161                 sk_msg_iter_var_next(i);
 162         } while (i != msg->sg.end);
 163 }
 164 EXPORT_SYMBOL_GPL(sk_msg_return);
 165 
 166 static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
 167                             bool charge)
 168 {
 169         struct scatterlist *sge = sk_msg_elem(msg, i);
 170         u32 len = sge->length;
 171 
 172         if (charge)
 173                 sk_mem_uncharge(sk, len);
 174         if (!msg->skb)
 175                 put_page(sg_page(sge));
 176         memset(sge, 0, sizeof(*sge));
 177         return len;
 178 }
 179 
 180 static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
 181                          bool charge)
 182 {
 183         struct scatterlist *sge = sk_msg_elem(msg, i);
 184         int freed = 0;
 185 
 186         while (msg->sg.size) {
 187                 msg->sg.size -= sge->length;
 188                 freed += sk_msg_free_elem(sk, msg, i, charge);
 189                 sk_msg_iter_var_next(i);
 190                 sk_msg_check_to_free(msg, i, msg->sg.size);
 191                 sge = sk_msg_elem(msg, i);
 192         }
 193         consume_skb(msg->skb);
 194         sk_msg_init(msg);
 195         return freed;
 196 }
 197 
 198 int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
 199 {
 200         return __sk_msg_free(sk, msg, msg->sg.start, false);
 201 }
 202 EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
 203 
 204 int sk_msg_free(struct sock *sk, struct sk_msg *msg)
 205 {
 206         return __sk_msg_free(sk, msg, msg->sg.start, true);
 207 }
 208 EXPORT_SYMBOL_GPL(sk_msg_free);
 209 
 210 static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
 211                                   u32 bytes, bool charge)
 212 {
 213         struct scatterlist *sge;
 214         u32 i = msg->sg.start;
 215 
 216         while (bytes) {
 217                 sge = sk_msg_elem(msg, i);
 218                 if (!sge->length)
 219                         break;
 220                 if (bytes < sge->length) {
 221                         if (charge)
 222                                 sk_mem_uncharge(sk, bytes);
 223                         sge->length -= bytes;
 224                         sge->offset += bytes;
 225                         msg->sg.size -= bytes;
 226                         break;
 227                 }
 228 
 229                 msg->sg.size -= sge->length;
 230                 bytes -= sge->length;
 231                 sk_msg_free_elem(sk, msg, i, charge);
 232                 sk_msg_iter_var_next(i);
 233                 sk_msg_check_to_free(msg, i, bytes);
 234         }
 235         msg->sg.start = i;
 236 }
 237 
 238 void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
 239 {
 240         __sk_msg_free_partial(sk, msg, bytes, true);
 241 }
 242 EXPORT_SYMBOL_GPL(sk_msg_free_partial);
 243 
 244 void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
 245                                   u32 bytes)
 246 {
 247         __sk_msg_free_partial(sk, msg, bytes, false);
 248 }
 249 
 250 void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
 251 {
 252         int trim = msg->sg.size - len;
 253         u32 i = msg->sg.end;
 254 
 255         if (trim <= 0) {
 256                 WARN_ON(trim < 0);
 257                 return;
 258         }
 259 
 260         sk_msg_iter_var_prev(i);
 261         msg->sg.size = len;
 262         while (msg->sg.data[i].length &&
 263                trim >= msg->sg.data[i].length) {
 264                 trim -= msg->sg.data[i].length;
 265                 sk_msg_free_elem(sk, msg, i, true);
 266                 sk_msg_iter_var_prev(i);
 267                 if (!trim)
 268                         goto out;
 269         }
 270 
 271         msg->sg.data[i].length -= trim;
 272         sk_mem_uncharge(sk, trim);
 273         /* Adjust copybreak if it falls into the trimmed part of last buf */
 274         if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
 275                 msg->sg.copybreak = msg->sg.data[i].length;
 276 out:
 277         sk_msg_iter_var_next(i);
 278         msg->sg.end = i;
 279 
  280         /* If we trim data a full sg elem before the curr pointer, update
  281          * copybreak and curr so that any future copy operations
  282          * start at the new copy location.
  283          * However, trimmed data that has not yet been used in a copy op
 284          * does not require an update.
 285          */
 286         if (!msg->sg.size) {
 287                 msg->sg.curr = msg->sg.start;
 288                 msg->sg.copybreak = 0;
 289         } else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
 290                    sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
 291                 sk_msg_iter_var_prev(i);
 292                 msg->sg.curr = i;
 293                 msg->sg.copybreak = msg->sg.data[i].length;
 294         }
 295 }
 296 EXPORT_SYMBOL_GPL(sk_msg_trim);
 297 
 298 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
 299                               struct sk_msg *msg, u32 bytes)
 300 {
 301         int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
 302         const int to_max_pages = MAX_MSG_FRAGS;
 303         struct page *pages[MAX_MSG_FRAGS];
 304         ssize_t orig, copied, use, offset;
 305 
 306         orig = msg->sg.size;
 307         while (bytes > 0) {
 308                 i = 0;
 309                 maxpages = to_max_pages - num_elems;
 310                 if (maxpages == 0) {
 311                         ret = -EFAULT;
 312                         goto out;
 313                 }
 314 
 315                 copied = iov_iter_get_pages(from, pages, bytes, maxpages,
 316                                             &offset);
 317                 if (copied <= 0) {
 318                         ret = -EFAULT;
 319                         goto out;
 320                 }
 321 
 322                 iov_iter_advance(from, copied);
 323                 bytes -= copied;
 324                 msg->sg.size += copied;
 325 
 326                 while (copied) {
 327                         use = min_t(int, copied, PAGE_SIZE - offset);
 328                         sg_set_page(&msg->sg.data[msg->sg.end],
 329                                     pages[i], use, offset);
 330                         sg_unmark_end(&msg->sg.data[msg->sg.end]);
 331                         sk_mem_charge(sk, use);
 332 
 333                         offset = 0;
 334                         copied -= use;
 335                         sk_msg_iter_next(msg, end);
 336                         num_elems++;
 337                         i++;
 338                 }
 339                 /* When zerocopy is mixed with sk_msg_*copy* operations we
  340                  * may have a copybreak set; in that case clear it and prefer
  341                  * the zerocopy remainder when possible.
 342                  */
 343                 msg->sg.copybreak = 0;
 344                 msg->sg.curr = msg->sg.end;
 345         }
 346 out:
  347         /* Revert iov_iter updates; msg will need to use 'trim' later if it
 348          * also needs to be cleared.
 349          */
 350         if (ret)
 351                 iov_iter_revert(from, msg->sg.size - orig);
 352         return ret;
 353 }
 354 EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
 355 
 356 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
 357                              struct sk_msg *msg, u32 bytes)
 358 {
 359         int ret = -ENOSPC, i = msg->sg.curr;
 360         struct scatterlist *sge;
 361         u32 copy, buf_size;
 362         void *to;
 363 
 364         do {
 365                 sge = sk_msg_elem(msg, i);
 366                 /* This is possible if a trim operation shrunk the buffer */
 367                 if (msg->sg.copybreak >= sge->length) {
 368                         msg->sg.copybreak = 0;
 369                         sk_msg_iter_var_next(i);
 370                         if (i == msg->sg.end)
 371                                 break;
 372                         sge = sk_msg_elem(msg, i);
 373                 }
 374 
 375                 buf_size = sge->length - msg->sg.copybreak;
 376                 copy = (buf_size > bytes) ? bytes : buf_size;
 377                 to = sg_virt(sge) + msg->sg.copybreak;
 378                 msg->sg.copybreak += copy;
 379                 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
 380                         ret = copy_from_iter_nocache(to, copy, from);
 381                 else
 382                         ret = copy_from_iter(to, copy, from);
 383                 if (ret != copy) {
 384                         ret = -EFAULT;
 385                         goto out;
 386                 }
 387                 bytes -= copy;
 388                 if (!bytes)
 389                         break;
 390                 msg->sg.copybreak = 0;
 391                 sk_msg_iter_var_next(i);
 392         } while (i != msg->sg.end);
 393 out:
 394         msg->sg.curr = i;
 395         return ret;
 396 }
 397 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
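/* A minimal caller sketch (example_msg_append is hypothetical; locking,
 * apply/cork handling and memory-accounting edge cases are omitted):
 * grow the message by 'copy' bytes, copy from the iterator, and trim
 * back to the original size if the copy fails, which is how the helpers
 * above are meant to be combined.
 */
static int example_msg_append(struct sock *sk, struct sk_msg *msg,
			      struct iov_iter *from, u32 copy)
{
	u32 osize = msg->sg.size;
	int err;

	err = sk_msg_alloc(sk, msg, osize + copy, msg->sg.end - 1);
	if (err)
		return err;

	err = sk_msg_memcopy_from_iter(sk, from, msg, copy);
	if (err < 0) {
		sk_msg_trim(sk, msg, osize);
		return err;
	}
	return 0;
}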
 398 
 399 static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
 400 {
 401         struct sock *sk = psock->sk;
 402         int copied = 0, num_sge;
 403         struct sk_msg *msg;
 404 
 405         msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
 406         if (unlikely(!msg))
 407                 return -EAGAIN;
 408         if (!sk_rmem_schedule(sk, skb, skb->len)) {
 409                 kfree(msg);
 410                 return -EAGAIN;
 411         }
 412 
 413         sk_msg_init(msg);
 414         num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
 415         if (unlikely(num_sge < 0)) {
 416                 kfree(msg);
 417                 return num_sge;
 418         }
 419 
 420         sk_mem_charge(sk, skb->len);
 421         copied = skb->len;
 422         msg->sg.start = 0;
 423         msg->sg.size = copied;
 424         msg->sg.end = num_sge;
 425         msg->skb = skb;
 426 
 427         sk_psock_queue_msg(psock, msg);
 428         sk_psock_data_ready(sk, psock);
 429         return copied;
 430 }
 431 
 432 static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
 433                                u32 off, u32 len, bool ingress)
 434 {
 435         if (ingress)
 436                 return sk_psock_skb_ingress(psock, skb);
 437         else
 438                 return skb_send_sock_locked(psock->sk, skb, off, len);
 439 }
 440 
 441 static void sk_psock_backlog(struct work_struct *work)
 442 {
 443         struct sk_psock *psock = container_of(work, struct sk_psock, work);
 444         struct sk_psock_work_state *state = &psock->work_state;
 445         struct sk_buff *skb;
 446         bool ingress;
 447         u32 len, off;
 448         int ret;
 449 
 450         /* Lock sock to avoid losing sk_socket during loop. */
 451         lock_sock(psock->sk);
 452         if (state->skb) {
 453                 skb = state->skb;
 454                 len = state->len;
 455                 off = state->off;
 456                 state->skb = NULL;
 457                 goto start;
 458         }
 459 
 460         while ((skb = skb_dequeue(&psock->ingress_skb))) {
 461                 len = skb->len;
 462                 off = 0;
 463 start:
 464                 ingress = tcp_skb_bpf_ingress(skb);
 465                 do {
 466                         ret = -EIO;
 467                         if (likely(psock->sk->sk_socket))
 468                                 ret = sk_psock_handle_skb(psock, skb, off,
 469                                                           len, ingress);
 470                         if (ret <= 0) {
 471                                 if (ret == -EAGAIN) {
 472                                         state->skb = skb;
 473                                         state->len = len;
 474                                         state->off = off;
 475                                         goto end;
 476                                 }
 477                                 /* Hard errors break pipe and stop xmit. */
 478                                 sk_psock_report_error(psock, ret ? -ret : EPIPE);
 479                                 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
 480                                 kfree_skb(skb);
 481                                 goto end;
 482                         }
 483                         off += ret;
 484                         len -= ret;
 485                 } while (len);
 486 
 487                 if (!ingress)
 488                         kfree_skb(skb);
 489         }
 490 end:
 491         release_sock(psock->sk);
 492 }
 493 
 494 struct sk_psock *sk_psock_init(struct sock *sk, int node)
 495 {
 496         struct sk_psock *psock = kzalloc_node(sizeof(*psock),
 497                                               GFP_ATOMIC | __GFP_NOWARN,
 498                                               node);
 499         if (!psock)
 500                 return NULL;
 501 
 502         psock->sk = sk;
  503         psock->eval = __SK_NONE;
 504 
 505         INIT_LIST_HEAD(&psock->link);
 506         spin_lock_init(&psock->link_lock);
 507 
 508         INIT_WORK(&psock->work, sk_psock_backlog);
 509         INIT_LIST_HEAD(&psock->ingress_msg);
 510         skb_queue_head_init(&psock->ingress_skb);
 511 
 512         sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
 513         refcount_set(&psock->refcnt, 1);
 514 
 515         rcu_assign_sk_user_data(sk, psock);
 516         sock_hold(sk);
 517 
 518         return psock;
 519 }
 520 EXPORT_SYMBOL_GPL(sk_psock_init);
 521 
 522 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
 523 {
 524         struct sk_psock_link *link;
 525 
 526         spin_lock_bh(&psock->link_lock);
 527         link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
 528                                         list);
 529         if (link)
 530                 list_del(&link->list);
 531         spin_unlock_bh(&psock->link_lock);
 532         return link;
 533 }
 534 
 535 void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
 536 {
 537         struct sk_msg *msg, *tmp;
 538 
 539         list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
 540                 list_del(&msg->list);
 541                 sk_msg_free(psock->sk, msg);
 542                 kfree(msg);
 543         }
 544 }
 545 
 546 static void sk_psock_zap_ingress(struct sk_psock *psock)
 547 {
 548         __skb_queue_purge(&psock->ingress_skb);
 549         __sk_psock_purge_ingress_msg(psock);
 550 }
 551 
 552 static void sk_psock_link_destroy(struct sk_psock *psock)
 553 {
 554         struct sk_psock_link *link, *tmp;
 555 
 556         list_for_each_entry_safe(link, tmp, &psock->link, list) {
 557                 list_del(&link->list);
 558                 sk_psock_free_link(link);
 559         }
 560 }
 561 
 562 static void sk_psock_destroy_deferred(struct work_struct *gc)
 563 {
 564         struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
 565 
 566         /* No sk_callback_lock since already detached. */
 567 
 568         /* Parser has been stopped */
 569         if (psock->progs.skb_parser)
 570                 strp_done(&psock->parser.strp);
 571 
 572         cancel_work_sync(&psock->work);
 573 
 574         psock_progs_drop(&psock->progs);
 575 
 576         sk_psock_link_destroy(psock);
 577         sk_psock_cork_free(psock);
 578         sk_psock_zap_ingress(psock);
 579 
 580         if (psock->sk_redir)
 581                 sock_put(psock->sk_redir);
 582         sock_put(psock->sk);
 583         kfree(psock);
 584 }
 585 
 586 void sk_psock_destroy(struct rcu_head *rcu)
 587 {
 588         struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
 589 
 590         INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
 591         schedule_work(&psock->gc);
 592 }
 593 EXPORT_SYMBOL_GPL(sk_psock_destroy);
 594 
 595 void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
 596 {
 597         sk_psock_cork_free(psock);
 598         sk_psock_zap_ingress(psock);
 599 
 600         write_lock_bh(&sk->sk_callback_lock);
 601         sk_psock_restore_proto(sk, psock);
 602         rcu_assign_sk_user_data(sk, NULL);
 603         if (psock->progs.skb_parser)
 604                 sk_psock_stop_strp(sk, psock);
 605         write_unlock_bh(&sk->sk_callback_lock);
 606         sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
 607 
 608         call_rcu(&psock->rcu, sk_psock_destroy);
 609 }
 610 EXPORT_SYMBOL_GPL(sk_psock_drop);
 611 
 612 static int sk_psock_map_verd(int verdict, bool redir)
 613 {
 614         switch (verdict) {
 615         case SK_PASS:
 616                 return redir ? __SK_REDIRECT : __SK_PASS;
 617         case SK_DROP:
 618         default:
 619                 break;
 620         }
 621 
 622         return __SK_DROP;
 623 }
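/* The mapping implemented above, spelled out:
 *   SK_PASS with no redirect target selected  -> __SK_PASS
 *   SK_PASS with a redirect target selected   -> __SK_REDIRECT
 *   SK_DROP or any other verdict              -> __SK_DROP
 */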
 624 
 625 int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
 626                          struct sk_msg *msg)
 627 {
 628         struct bpf_prog *prog;
 629         int ret;
 630 
 631         preempt_disable();
 632         rcu_read_lock();
 633         prog = READ_ONCE(psock->progs.msg_parser);
 634         if (unlikely(!prog)) {
 635                 ret = __SK_PASS;
 636                 goto out;
 637         }
 638 
 639         sk_msg_compute_data_pointers(msg);
 640         msg->sk = sk;
 641         ret = BPF_PROG_RUN(prog, msg);
 642         ret = sk_psock_map_verd(ret, msg->sk_redir);
 643         psock->apply_bytes = msg->apply_bytes;
 644         if (ret == __SK_REDIRECT) {
 645                 if (psock->sk_redir)
 646                         sock_put(psock->sk_redir);
 647                 psock->sk_redir = msg->sk_redir;
 648                 if (!psock->sk_redir) {
 649                         ret = __SK_DROP;
 650                         goto out;
 651                 }
 652                 sock_hold(psock->sk_redir);
 653         }
 654 out:
 655         rcu_read_unlock();
 656         preempt_enable();
 657         return ret;
 658 }
 659 EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
 660 
 661 static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
 662                             struct sk_buff *skb)
 663 {
 664         int ret;
 665 
 666         skb->sk = psock->sk;
 667         bpf_compute_data_end_sk_skb(skb);
 668         preempt_disable();
 669         ret = BPF_PROG_RUN(prog, skb);
 670         preempt_enable();
  671         /* strparser clones the skb before handing it to an upper layer,
 672          * meaning skb_orphan has been called. We NULL sk on the way out
 673          * to ensure we don't trigger a BUG_ON() in skb/sk operations
 674          * later and because we are not charging the memory of this skb
 675          * to any socket yet.
 676          */
 677         skb->sk = NULL;
 678         return ret;
 679 }
 680 
 681 static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
 682 {
 683         struct sk_psock_parser *parser;
 684 
 685         parser = container_of(strp, struct sk_psock_parser, strp);
 686         return container_of(parser, struct sk_psock, parser);
 687 }
 688 
 689 static void sk_psock_verdict_apply(struct sk_psock *psock,
 690                                    struct sk_buff *skb, int verdict)
 691 {
 692         struct sk_psock *psock_other;
 693         struct sock *sk_other;
 694         bool ingress;
 695 
 696         switch (verdict) {
 697         case __SK_PASS:
 698                 sk_other = psock->sk;
 699                 if (sock_flag(sk_other, SOCK_DEAD) ||
 700                     !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
 701                         goto out_free;
 702                 }
 703                 if (atomic_read(&sk_other->sk_rmem_alloc) <=
 704                     sk_other->sk_rcvbuf) {
 705                         struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);
 706 
 707                         tcp->bpf.flags |= BPF_F_INGRESS;
 708                         skb_queue_tail(&psock->ingress_skb, skb);
 709                         schedule_work(&psock->work);
 710                         break;
 711                 }
 712                 goto out_free;
 713         case __SK_REDIRECT:
 714                 sk_other = tcp_skb_bpf_redirect_fetch(skb);
 715                 if (unlikely(!sk_other))
 716                         goto out_free;
 717                 psock_other = sk_psock(sk_other);
 718                 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
 719                     !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
 720                         goto out_free;
 721                 ingress = tcp_skb_bpf_ingress(skb);
 722                 if ((!ingress && sock_writeable(sk_other)) ||
 723                     (ingress &&
 724                      atomic_read(&sk_other->sk_rmem_alloc) <=
 725                      sk_other->sk_rcvbuf)) {
 726                         if (!ingress)
 727                                 skb_set_owner_w(skb, sk_other);
 728                         skb_queue_tail(&psock_other->ingress_skb, skb);
 729                         schedule_work(&psock_other->work);
 730                         break;
 731                 }
 732                 /* fall-through */
 733         case __SK_DROP:
 734                 /* fall-through */
 735         default:
 736 out_free:
 737                 kfree_skb(skb);
 738         }
 739 }
 740 
 741 static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
 742 {
 743         struct sk_psock *psock = sk_psock_from_strp(strp);
 744         struct bpf_prog *prog;
 745         int ret = __SK_DROP;
 746 
 747         rcu_read_lock();
 748         prog = READ_ONCE(psock->progs.skb_verdict);
 749         if (likely(prog)) {
 750                 skb_orphan(skb);
 751                 tcp_skb_bpf_redirect_clear(skb);
 752                 ret = sk_psock_bpf_run(psock, prog, skb);
 753                 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
 754         }
 755         rcu_read_unlock();
 756         sk_psock_verdict_apply(psock, skb, ret);
 757 }
 758 
 759 static int sk_psock_strp_read_done(struct strparser *strp, int err)
 760 {
 761         return err;
 762 }
 763 
 764 static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
 765 {
 766         struct sk_psock *psock = sk_psock_from_strp(strp);
 767         struct bpf_prog *prog;
 768         int ret = skb->len;
 769 
 770         rcu_read_lock();
 771         prog = READ_ONCE(psock->progs.skb_parser);
 772         if (likely(prog))
 773                 ret = sk_psock_bpf_run(psock, prog, skb);
 774         rcu_read_unlock();
 775         return ret;
 776 }
 777 
 778 /* Called with socket lock held. */
 779 static void sk_psock_strp_data_ready(struct sock *sk)
 780 {
 781         struct sk_psock *psock;
 782 
 783         rcu_read_lock();
 784         psock = sk_psock(sk);
 785         if (likely(psock)) {
 786                 write_lock_bh(&sk->sk_callback_lock);
 787                 strp_data_ready(&psock->parser.strp);
 788                 write_unlock_bh(&sk->sk_callback_lock);
 789         }
 790         rcu_read_unlock();
 791 }
 792 
 793 static void sk_psock_write_space(struct sock *sk)
 794 {
 795         struct sk_psock *psock;
 796         void (*write_space)(struct sock *sk) = NULL;
 797 
 798         rcu_read_lock();
 799         psock = sk_psock(sk);
 800         if (likely(psock)) {
 801                 if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
 802                         schedule_work(&psock->work);
 803                 write_space = psock->saved_write_space;
 804         }
 805         rcu_read_unlock();
 806         if (write_space)
 807                 write_space(sk);
 808 }
 809 
 810 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
 811 {
 812         static const struct strp_callbacks cb = {
 813                 .rcv_msg        = sk_psock_strp_read,
 814                 .read_sock_done = sk_psock_strp_read_done,
 815                 .parse_msg      = sk_psock_strp_parse,
 816         };
 817 
 818         psock->parser.enabled = false;
 819         return strp_init(&psock->parser.strp, sk, &cb);
 820 }
 821 
 822 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
 823 {
 824         struct sk_psock_parser *parser = &psock->parser;
 825 
 826         if (parser->enabled)
 827                 return;
 828 
 829         parser->saved_data_ready = sk->sk_data_ready;
 830         sk->sk_data_ready = sk_psock_strp_data_ready;
 831         sk->sk_write_space = sk_psock_write_space;
 832         parser->enabled = true;
 833 }
 834 
 835 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
 836 {
 837         struct sk_psock_parser *parser = &psock->parser;
 838 
 839         if (!parser->enabled)
 840                 return;
 841 
 842         sk->sk_data_ready = parser->saved_data_ready;
 843         parser->saved_data_ready = NULL;
 844         strp_stop(&parser->strp);
 845         parser->enabled = false;
 846 }
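/* A minimal sketch of how a psock user (a sockmap-style caller, hypothetical
 * here as example_psock_attach/_detach) might drive the lifecycle defined in
 * this file; error handling and map bookkeeping are abbreviated.
 */
static int example_psock_attach(struct sock *sk)
{
	struct sk_psock *psock;
	int err;

	psock = sk_psock_init(sk, NUMA_NO_NODE);
	if (!psock)
		return -ENOMEM;

	err = sk_psock_init_strp(sk, psock);
	if (err) {
		sk_psock_put(sk, psock);
		return err;
	}

	write_lock_bh(&sk->sk_callback_lock);
	sk_psock_start_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
}

static void example_psock_detach(struct sock *sk, struct sk_psock *psock)
{
	/* Dropping the last reference ends up in sk_psock_drop(), which
	 * stops the parser, restores the socket callbacks and frees the
	 * psock via RCU.
	 */
	sk_psock_put(sk, psock);
}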
