root/include/linux/skmsg.h

/* [<][>][^][v][top][bottom][index][help] */

INCLUDED FROM


DEFINITIONS

This source file includes the following definitions.
  1. sk_msg_check_to_free
  2. sk_msg_apply_bytes
  3. sk_msg_iter_dist
  4. sk_msg_clear_meta
  5. sk_msg_init
  6. sk_msg_xfer
  7. sk_msg_xfer_full
  8. sk_msg_full
  9. sk_msg_elem_used
  10. sk_msg_elem
  11. sk_msg_elem_cpy
  12. sk_msg_page
  13. sk_msg_to_ingress
  14. sk_msg_compute_data_pointers
  15. sk_msg_page_add
  16. sk_msg_sg_copy
  17. sk_msg_sg_copy_set
  18. sk_msg_sg_copy_clear
  19. sk_psock_queue_msg
  20. sk_psock_queue_empty
  21. sk_psock_report_error
  22. sk_psock_init_link
  23. sk_psock_free_link
  24. sk_psock_unlink
  25. sk_psock_cork_free
  26. sk_psock_update_proto
  27. sk_psock_restore_proto
  28. sk_psock_set_state
  29. sk_psock_clear_state
  30. sk_psock_test_state
  31. sk_psock_get_checked
  32. sk_psock_get
  33. sk_psock_put
  34. sk_psock_data_ready
  35. psock_set_prog
  36. psock_progs_drop

   1 /* SPDX-License-Identifier: GPL-2.0 */
   2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3 
   4 #ifndef _LINUX_SKMSG_H
   5 #define _LINUX_SKMSG_H
   6 
   7 #include <linux/bpf.h>
   8 #include <linux/filter.h>
   9 #include <linux/scatterlist.h>
  10 #include <linux/skbuff.h>
  11 
  12 #include <net/sock.h>
  13 #include <net/tcp.h>
  14 #include <net/strparser.h>
  15 
/* Largest number of scatterlist fragments one sk_msg reports as "full"
 * (see sk_msg_full()); aliased to the skb fragment limit.
 */
#define MAX_MSG_FRAGS                   MAX_SKB_FRAGS
/* Number of distinct ring index values. The sk_msg_iter_* helpers wrap an
 * index back to 0 when it reaches this value, so valid indices run from 0
 * to MAX_MSG_FRAGS inclusive.
 */
#define NR_MSG_FRAG_IDS                 (MAX_MSG_FRAGS + 1)
  18 
/* Action codes, numbered consecutively from __SK_DROP == 0.
 * NOTE(review): presumably the internal verdict produced for a message or
 * skb by the attached programs - confirm against the users of this enum.
 */
enum __sk_action {
        __SK_DROP = 0,
        __SK_PASS,
        __SK_REDIRECT,
        __SK_NONE,
};
  25 
/* Scatter-gather state of an sk_msg.
 *
 * start, curr and end are ring indices into data[]; they are advanced and
 * wrapped with the sk_msg_iter_*() helpers and take values in the range
 * 0..NR_MSG_FRAG_IDS - 1. size is the running byte total maintained by
 * helpers such as sk_msg_page_add() and sk_msg_xfer().
 *
 * sk_msg_clear_meta() zeroes everything from the beginning of this struct
 * up to and including copy[], leaving data[] untouched.
 *
 * NOTE(review): ring indices can reach MAX_MSG_FRAGS (== NR_MSG_FRAG_IDS - 1)
 * while copy[] has only MAX_MSG_FRAGS slots, so copy[index] with the largest
 * index would be one past the end of the array - confirm whether callers
 * can actually hit that index.
 */
struct sk_msg_sg {
        u32                             start;
        u32                             curr;
        u32                             end;
        u32                             size;
        u32                             copybreak;
        /* Per-element flag, indexed like data[]. */
        bool                            copy[MAX_MSG_FRAGS];
        /* The extra two elements:
         * 1) used for chaining the front and sections when the list becomes
         *    partitioned (e.g. end < start). The crypto APIs require the
         *    chaining;
         * 2) to chain tailer SG entries after the message.
         */
        struct scatterlist              data[MAX_MSG_FRAGS + 2];
};
  41 
/* A message described by a ring of scatterlist elements plus metadata.
 *
 * UAPI in filter.c depends on struct sk_msg_sg being first element.
 */
struct sk_msg {
        struct sk_msg_sg                sg;
        /* Filled in by sk_msg_compute_data_pointers(): bounds of the first
         * element's data, or both NULL when that element is flagged in
         * sg.copy[].
         */
        void                            *data;
        void                            *data_end;
        u32                             apply_bytes;
        u32                             cork_bytes;
        /* Tested against BPF_F_INGRESS by sk_msg_to_ingress(). */
        u32                             flags;
        struct sk_buff                  *skb;
        struct sock                     *sk_redir;
        struct sock                     *sk;
        /* Linkage used by sk_psock_queue_msg() to put the message on a
         * psock's ingress_msg list.
         */
        struct list_head                list;
};
  55 
/* The BPF programs attached to a psock. Each slot is replaced through
 * psock_set_prog() and all three are cleared by psock_progs_drop().
 */
struct sk_psock_progs {
        struct bpf_prog                 *msg_parser;
        struct bpf_prog                 *skb_parser;
        struct bpf_prog                 *skb_verdict;
};
  61 
/* Bit numbers within sk_psock::state, manipulated with
 * sk_psock_set_state(), sk_psock_clear_state() and sk_psock_test_state().
 */
enum sk_psock_state_bits {
        SK_PSOCK_TX_ENABLED,
};
  65 
/* One list node tying a psock to a BPF map. Allocated zeroed by
 * sk_psock_init_link() and released by sk_psock_free_link().
 * NOTE(review): link_raw presumably points at the map's own slot for this
 * socket - confirm against the map implementations.
 */
struct sk_psock_link {
        struct list_head                list;
        struct bpf_map                  *map;
        void                            *link_raw;
};
  71 
/* Stream parser state embedded in a psock. */
struct sk_psock_parser {
        struct strparser                strp;
        /* When true, sk_psock_data_ready() calls saved_data_ready instead
         * of the socket's current sk_data_ready callback.
         */
        bool                            enabled;
        void (*saved_data_ready)(struct sock *sk);
};
  77 
/* State kept alongside sk_psock::work.
 * NOTE(review): presumably the skb being processed together with the
 * remaining length and current offset, so the work item can resume -
 * confirm against the work handler.
 */
struct sk_psock_work_state {
        struct sk_buff                  *skb;
        u32                             len;
        u32                             off;
};
  83 
/* Per-socket state object; sk_psock() retrieves it from a socket's
 * RCU-protected user data.
 */
struct sk_psock {
        /* Socket used by sk_psock_report_error() and sk_psock_cork_free(). */
        struct sock                     *sk;
        struct sock                     *sk_redir;
        /* Counted down (saturating at 0) by sk_msg_apply_bytes(). */
        u32                             apply_bytes;
        u32                             cork_bytes;
        u32                             eval;
        /* Optional message; freed and reset by sk_psock_cork_free(). */
        struct sk_msg                   *cork;
        struct sk_psock_progs           progs;
        struct sk_psock_parser          parser;
        struct sk_buff_head             ingress_skb;
        /* List of sk_msg, appended to by sk_psock_queue_msg(). */
        struct list_head                ingress_msg;
        /* Bit field indexed by enum sk_psock_state_bits. */
        unsigned long                   state;
        struct list_head                link;
        spinlock_t                      link_lock;
        /* Taken with refcount_inc_not_zero() in sk_psock_get*(), dropped in
         * sk_psock_put().
         */
        refcount_t                      refcnt;
        /* Callbacks and proto captured by sk_psock_update_proto() and put
         * back by sk_psock_restore_proto().
         */
        void (*saved_unhash)(struct sock *sk);
        void (*saved_close)(struct sock *sk, long timeout);
        void (*saved_write_space)(struct sock *sk);
        struct proto                    *sk_proto;
        struct sk_psock_work_state      work_state;
        struct work_struct              work;
        union {
                struct rcu_head         rcu;
                struct work_struct      gc;
        };
};
 110 
/* Out-of-line sk_msg helpers. Only the prototypes live in this header; the
 * behaviour notes below are inferred from the names and signatures and
 * should be confirmed against the implementation.
 */

/* NOTE(review): allocation, cloning, trimming and freeing of message data;
 * the *_nocharge variants presumably skip socket memory accounting.
 */
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
                 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
                 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes);

/* NOTE(review): presumably hand back 'bytes' of accounting for msg to sk. */
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);

/* Fill msg with up to 'bytes' taken from the iov_iter 'from'.
 * NOTE(review): the zerocopy/memcopy distinction is implied by the names
 * only - confirm.
 */
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes);
 129 
 130 static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
 131 {
 132         WARN_ON(i == msg->sg.end && bytes);
 133 }
 134 
 135 static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
 136 {
 137         if (psock->apply_bytes) {
 138                 if (psock->apply_bytes < bytes)
 139                         psock->apply_bytes = 0;
 140                 else
 141                         psock->apply_bytes -= bytes;
 142         }
 143 }
 144 
 145 static inline u32 sk_msg_iter_dist(u32 start, u32 end)
 146 {
 147         return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
 148 }
 149 
 150 #define sk_msg_iter_var_prev(var)                       \
 151         do {                                            \
 152                 if (var == 0)                           \
 153                         var = NR_MSG_FRAG_IDS - 1;      \
 154                 else                                    \
 155                         var--;                          \
 156         } while (0)
 157 
 158 #define sk_msg_iter_var_next(var)                       \
 159         do {                                            \
 160                 var++;                                  \
 161                 if (var == NR_MSG_FRAG_IDS)             \
 162                         var = 0;                        \
 163         } while (0)
 164 
 165 #define sk_msg_iter_prev(msg, which)                    \
 166         sk_msg_iter_var_prev(msg->sg.which)
 167 
 168 #define sk_msg_iter_next(msg, which)                    \
 169         sk_msg_iter_var_next(msg->sg.which)
 170 
 171 static inline void sk_msg_clear_meta(struct sk_msg *msg)
 172 {
 173         memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
 174 }
 175 
 176 static inline void sk_msg_init(struct sk_msg *msg)
 177 {
 178         BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
 179         memset(msg, 0, sizeof(*msg));
 180         sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
 181 }
 182 
 183 static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
 184                                int which, u32 size)
 185 {
 186         dst->sg.data[which] = src->sg.data[which];
 187         dst->sg.data[which].length  = size;
 188         dst->sg.size               += size;
 189         src->sg.size               -= size;
 190         src->sg.data[which].length -= size;
 191         src->sg.data[which].offset += size;
 192 }
 193 
 194 static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
 195 {
 196         memcpy(dst, src, sizeof(*src));
 197         sk_msg_init(src);
 198 }
 199 
 200 static inline bool sk_msg_full(const struct sk_msg *msg)
 201 {
 202         return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
 203 }
 204 
 205 static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
 206 {
 207         return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
 208 }
 209 
 210 static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
 211 {
 212         return &msg->sg.data[which];
 213 }
 214 
 215 static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
 216 {
 217         return msg->sg.data[which];
 218 }
 219 
 220 static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
 221 {
 222         return sg_page(sk_msg_elem(msg, which));
 223 }
 224 
 225 static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
 226 {
 227         return msg->flags & BPF_F_INGRESS;
 228 }
 229 
 230 static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
 231 {
 232         struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
 233 
 234         if (msg->sg.copy[msg->sg.start]) {
 235                 msg->data = NULL;
 236                 msg->data_end = NULL;
 237         } else {
 238                 msg->data = sg_virt(sge);
 239                 msg->data_end = msg->data + sge->length;
 240         }
 241 }
 242 
 243 static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
 244                                    u32 len, u32 offset)
 245 {
 246         struct scatterlist *sge;
 247 
 248         get_page(page);
 249         sge = sk_msg_elem(msg, msg->sg.end);
 250         sg_set_page(sge, page, len, offset);
 251         sg_unmark_end(sge);
 252 
 253         msg->sg.copy[msg->sg.end] = true;
 254         msg->sg.size += len;
 255         sk_msg_iter_next(msg, end);
 256 }
 257 
 258 static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
 259 {
 260         do {
 261                 msg->sg.copy[i] = copy_state;
 262                 sk_msg_iter_var_next(i);
 263                 if (i == msg->sg.end)
 264                         break;
 265         } while (1);
 266 }
 267 
 268 static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
 269 {
 270         sk_msg_sg_copy(msg, start, true);
 271 }
 272 
 273 static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
 274 {
 275         sk_msg_sg_copy(msg, start, false);
 276 }
 277 
 278 static inline struct sk_psock *sk_psock(const struct sock *sk)
 279 {
 280         return rcu_dereference_sk_user_data(sk);
 281 }
 282 
 283 static inline void sk_psock_queue_msg(struct sk_psock *psock,
 284                                       struct sk_msg *msg)
 285 {
 286         list_add_tail(&msg->list, &psock->ingress_msg);
 287 }
 288 
 289 static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
 290 {
 291         return psock ? list_empty(&psock->ingress_msg) : true;
 292 }
 293 
 294 static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 295 {
 296         struct sock *sk = psock->sk;
 297 
 298         sk->sk_err = err;
 299         sk->sk_error_report(sk);
 300 }
 301 
/* Out-of-line psock entry points; implementations are not in this header.
 * NOTE(review): @node is presumably the NUMA node to allocate on - confirm.
 */
struct sk_psock *sk_psock_init(struct sock *sk, int node);

/* Stream parser setup, start and stop for a socket/psock pair. */
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);

/* NOTE(review): presumably runs the msg program on @msg and returns an
 * enum __sk_action value - confirm against the implementation.
 */
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
                         struct sk_msg *msg);
 310 
 311 static inline struct sk_psock_link *sk_psock_init_link(void)
 312 {
 313         return kzalloc(sizeof(struct sk_psock_link),
 314                        GFP_ATOMIC | __GFP_NOWARN);
 315 }
 316 
/* Release a link obtained from sk_psock_init_link(). The pointer is passed
 * straight to kfree().
 */
static inline void sk_psock_free_link(struct sk_psock_link *link)
{
        kfree(link);
}
 321 
/* NOTE(review): presumably removes and returns one entry of psock->link,
 * or NULL when none is left - confirm against the implementation.
 */
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
/* sk_psock_unlink() is implemented out of line when
 * CONFIG_BPF_STREAM_PARSER is defined; otherwise it is an empty stub.
 */
#if defined(CONFIG_BPF_STREAM_PARSER)
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
#else
static inline void sk_psock_unlink(struct sock *sk,
                                   struct sk_psock_link *link)
{
}
#endif
 331 
/* Out-of-line helper.
 * NOTE(review): presumably drops every message queued on
 * psock->ingress_msg - confirm against the implementation.
 */
void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
 333 
 334 static inline void sk_psock_cork_free(struct sk_psock *psock)
 335 {
 336         if (psock->cork) {
 337                 sk_msg_free(psock->sk, psock->cork);
 338                 kfree(psock->cork);
 339                 psock->cork = NULL;
 340         }
 341 }
 342 
 343 static inline void sk_psock_update_proto(struct sock *sk,
 344                                          struct sk_psock *psock,
 345                                          struct proto *ops)
 346 {
 347         psock->saved_unhash = sk->sk_prot->unhash;
 348         psock->saved_close = sk->sk_prot->close;
 349         psock->saved_write_space = sk->sk_write_space;
 350 
 351         psock->sk_proto = sk->sk_prot;
 352         sk->sk_prot = ops;
 353 }
 354 
 355 static inline void sk_psock_restore_proto(struct sock *sk,
 356                                           struct sk_psock *psock)
 357 {
 358         sk->sk_prot->unhash = psock->saved_unhash;
 359 
 360         if (psock->sk_proto) {
 361                 struct inet_connection_sock *icsk = inet_csk(sk);
 362                 bool has_ulp = !!icsk->icsk_ulp_data;
 363 
 364                 if (has_ulp) {
 365                         tcp_update_ulp(sk, psock->sk_proto,
 366                                        psock->saved_write_space);
 367                 } else {
 368                         sk->sk_prot = psock->sk_proto;
 369                         sk->sk_write_space = psock->saved_write_space;
 370                 }
 371                 psock->sk_proto = NULL;
 372         } else {
 373                 sk->sk_write_space = psock->saved_write_space;
 374         }
 375 }
 376 
 377 static inline void sk_psock_set_state(struct sk_psock *psock,
 378                                       enum sk_psock_state_bits bit)
 379 {
 380         set_bit(bit, &psock->state);
 381 }
 382 
 383 static inline void sk_psock_clear_state(struct sk_psock *psock,
 384                                         enum sk_psock_state_bits bit)
 385 {
 386         clear_bit(bit, &psock->state);
 387 }
 388 
 389 static inline bool sk_psock_test_state(const struct sk_psock *psock,
 390                                        enum sk_psock_state_bits bit)
 391 {
 392         return test_bit(bit, &psock->state);
 393 }
 394 
 395 static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
 396 {
 397         struct sk_psock *psock;
 398 
 399         rcu_read_lock();
 400         psock = sk_psock(sk);
 401         if (psock) {
 402                 if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
 403                         psock = ERR_PTR(-EBUSY);
 404                         goto out;
 405                 }
 406 
 407                 if (!refcount_inc_not_zero(&psock->refcnt))
 408                         psock = ERR_PTR(-EBUSY);
 409         }
 410 out:
 411         rcu_read_unlock();
 412         return psock;
 413 }
 414 
 415 static inline struct sk_psock *sk_psock_get(struct sock *sk)
 416 {
 417         struct sk_psock *psock;
 418 
 419         rcu_read_lock();
 420         psock = sk_psock(sk);
 421         if (psock && !refcount_inc_not_zero(&psock->refcnt))
 422                 psock = NULL;
 423         rcu_read_unlock();
 424         return psock;
 425 }
 426 
/* Out-of-line teardown entry points; implementations are not in this
 * header. sk_psock_drop() is what sk_psock_put() calls when the last
 * reference goes away.
 * NOTE(review): sk_psock_destroy() takes an rcu_head, so it is presumably
 * used as an RCU callback - confirm.
 */
void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
void sk_psock_destroy(struct rcu_head *rcu);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
 430 
 431 static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
 432 {
 433         if (refcount_dec_and_test(&psock->refcnt))
 434                 sk_psock_drop(sk, psock);
 435 }
 436 
 437 static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
 438 {
 439         if (psock->parser.enabled)
 440                 psock->parser.saved_data_ready(sk);
 441         else
 442                 sk->sk_data_ready(sk);
 443 }
 444 
 445 static inline void psock_set_prog(struct bpf_prog **pprog,
 446                                   struct bpf_prog *prog)
 447 {
 448         prog = xchg(pprog, prog);
 449         if (prog)
 450                 bpf_prog_put(prog);
 451 }
 452 
 453 static inline void psock_progs_drop(struct sk_psock_progs *progs)
 454 {
 455         psock_set_prog(&progs->msg_parser, NULL);
 456         psock_set_prog(&progs->skb_parser, NULL);
 457         psock_set_prog(&progs->skb_verdict, NULL);
 458 }
 459 
 460 #endif /* _LINUX_SKMSG_H */

/* [<][>][^][v][top][bottom][index][help] */