root/net/tls/tls_device_fallback.c

DEFINITIONS

This source file includes the following definitions.
  1. chain_to_walk
  2. tls_enc_record
  3. tls_init_aead_request
  4. tls_alloc_aead_request
  5. tls_enc_records
  6. update_chksum
  7. complete_skb
  8. fill_sg_in
  9. fill_sg_out
  10. tls_enc_skb
  11. tls_sw_fallback
  12. tls_validate_xmit_skb
  13. tls_encrypt_skb
  14. tls_sw_fallback_init

/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <net/tls.h>
#include <crypto/aead.h>
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>

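/* Point @sg at the not-yet-consumed remainder of the walk's current
 * scatterlist entry and chain it to the rest of the source list.
 */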
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
        struct scatterlist *src = walk->sg;
        int diff = walk->offset - src->offset;

        sg_set_page(sg, sg_page(src),
                    src->length - diff, walk->offset);

        scatterwalk_crypto_chain(sg, sg_next(src), 2);
}

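/* Encrypt a single TLS record pointed to by the @in walk and write the
 * result through the @out walk.  The TLS header and explicit nonce are
 * copied verbatim, the AAD and per-record IV are rebuilt from @rcd_sn,
 * and AES-GCM is run over the remaining payload.  *in_len is updated to
 * the amount of input left after this record.
 */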
static int tls_enc_record(struct aead_request *aead_req,
                          struct crypto_aead *aead, char *aad,
                          char *iv, __be64 rcd_sn,
                          struct scatter_walk *in,
                          struct scatter_walk *out, int *in_len)
{
        unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
        struct scatterlist sg_in[3];
        struct scatterlist sg_out[3];
        u16 len;
        int rc;

        len = min_t(int, *in_len, ARRAY_SIZE(buf));

        scatterwalk_copychunks(buf, in, len, 0);
        scatterwalk_copychunks(buf, out, len, 1);

        *in_len -= len;
        if (!*in_len)
                return 0;

        scatterwalk_pagedone(in, 0, 1);
        scatterwalk_pagedone(out, 1, 1);

        len = buf[4] | (buf[3] << 8);
        len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;

        tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
                (char *)&rcd_sn, sizeof(rcd_sn), buf[0],
                TLS_1_2_VERSION);

        memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
               TLS_CIPHER_AES_GCM_128_IV_SIZE);

        sg_init_table(sg_in, ARRAY_SIZE(sg_in));
        sg_init_table(sg_out, ARRAY_SIZE(sg_out));
        sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
        sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
        chain_to_walk(sg_in + 1, in);
        chain_to_walk(sg_out + 1, out);

        *in_len -= len;
        if (*in_len < 0) {
                *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
                /* The input buffer doesn't contain the entire record.
                 * Trim len accordingly. The resulting authentication tag
                 * will contain garbage, but we don't care, so we won't
                 * include any of it in the output skb.
                 * Note that we assume the output buffer length
                 * is larger than the input buffer length + tag size.
                 */
                if (*in_len < 0)
                        len += *in_len;

                *in_len = 0;
        }

        if (*in_len) {
                scatterwalk_copychunks(NULL, in, len, 2);
                scatterwalk_pagedone(in, 0, 1);
                scatterwalk_copychunks(NULL, out, len, 2);
                scatterwalk_pagedone(out, 1, 1);
        }

        len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
        aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);

        rc = crypto_aead_encrypt(aead_req);

        return rc;
}

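/* Bind the request to the AEAD transform and set the associated-data length */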
static void tls_init_aead_request(struct aead_request *aead_req,
                                  struct crypto_aead *aead)
{
        aead_request_set_tfm(aead_req, aead);
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
}

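/* Allocate an aead_request large enough for the transform's private context */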
static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
                                                   gfp_t flags)
{
        unsigned int req_size = sizeof(struct aead_request) +
                crypto_aead_reqsize(aead);
        struct aead_request *aead_req;

        aead_req = kzalloc(req_size, flags);
        if (aead_req)
                tls_init_aead_request(aead_req, aead);
        return aead_req;
}

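/* Walk @sg_in and encrypt one TLS record at a time into @sg_out until
 * @len bytes of input have been consumed or an encryption error occurs.
 * @rcd_sn is the record sequence number of the first record in @sg_in.
 */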
static int tls_enc_records(struct aead_request *aead_req,
                           struct crypto_aead *aead, struct scatterlist *sg_in,
                           struct scatterlist *sg_out, char *aad, char *iv,
                           u64 rcd_sn, int len)
{
        struct scatter_walk out, in;
        int rc;

        scatterwalk_start(&in, sg_in);
        scatterwalk_start(&out, sg_out);

        do {
                rc = tls_enc_record(aead_req, aead, aad, iv,
                                    cpu_to_be64(rcd_sn), &in, &out, &len);
                rcd_sn++;

        } while (rc == 0 && len);

        scatterwalk_done(&in, 0, 0);
        scatterwalk_done(&out, 1, 0);

        return rc;
}

/* Can't use icsk->icsk_af_ops->send_check here because the IP addresses
 * might have been changed by NAT.
 */
static void update_chksum(struct sk_buff *skb, int headln)
{
        struct tcphdr *th = tcp_hdr(skb);
        int datalen = skb->len - headln;
        const struct ipv6hdr *ipv6h;
        const struct iphdr *iph;

        /* We only changed the payload, so if we are using CHECKSUM_PARTIAL
         * we don't need to update anything.
         */
        if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
                return;

        skb->ip_summed = CHECKSUM_PARTIAL;
        skb->csum_start = skb_transport_header(skb) - skb->head;
        skb->csum_offset = offsetof(struct tcphdr, check);

        if (skb->sk->sk_family == AF_INET6) {
                ipv6h = ipv6_hdr(skb);
                th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
                                             datalen, IPPROTO_TCP, 0);
        } else {
                iph = ip_hdr(skb);
                th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
                                               IPPROTO_TCP, 0);
        }
}

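/* Finish building the fallback skb: copy the headers and skb metadata
 * from the original @skb, take over its socket ownership, recompute the
 * TCP checksum and adjust sk_wmem_alloc for any truesize difference
 * unless the skb was orphaned.
 */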
static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
{
        struct sock *sk = skb->sk;
        int delta;

        skb_copy_header(nskb, skb);

        skb_put(nskb, skb->len);
        memcpy(nskb->data, skb->data, headln);

        nskb->destructor = skb->destructor;
        nskb->sk = sk;
        skb->destructor = NULL;
        skb->sk = NULL;

        update_chksum(nskb, headln);

        /* sock_efree means the skb must have gone through skb_orphan_partial() */
        if (nskb->destructor == sock_efree)
                return;

        delta = nskb->truesize - skb->truesize;
        if (likely(delta < 0))
                WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc));
        else if (delta)
                refcount_add(delta, &sk->sk_wmem_alloc);
}

/* This function may be called after the user socket is already
 * closed, so make sure we don't use anything freed during
 * tls_sk_proto_close here.
 */

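/* Gather the plaintext that must be encrypted for @skb into @sg_in: the
 * part of the TLS record that precedes this packet (taken from the
 * offload context's record frags, with page references held), followed
 * by the packet's own TCP payload.  On success *rcd_sn holds the record
 * sequence number, *sync_size the number of preceding record bytes and
 * *resync_sgs the number of frag entries used.
 */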
static int fill_sg_in(struct scatterlist *sg_in,
                      struct sk_buff *skb,
                      struct tls_offload_context_tx *ctx,
                      u64 *rcd_sn,
                      s32 *sync_size,
                      int *resync_sgs)
{
        int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
        int payload_len = skb->len - tcp_payload_offset;
        u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
        struct tls_record_info *record;
        unsigned long flags;
        int remaining;
        int i;

        spin_lock_irqsave(&ctx->lock, flags);
        record = tls_get_record(ctx, tcp_seq, rcd_sn);
        if (!record) {
                spin_unlock_irqrestore(&ctx->lock, flags);
                return -EINVAL;
        }

        *sync_size = tcp_seq - tls_record_start_seq(record);
        if (*sync_size < 0) {
                int is_start_marker = tls_record_is_start_marker(record);

                spin_unlock_irqrestore(&ctx->lock, flags);
                /* This should only occur if the relevant record was
                 * already acked. In that case it should be ok
                 * to drop the packet and avoid retransmission.
                 *
                 * There is a corner case where the packet contains
                 * both an acked and a non-acked record.
                 * We currently don't handle that case and rely
                 * on TCP to retransmit a packet that doesn't contain
                 * already acked payload.
                 */
                if (!is_start_marker)
                        *sync_size = 0;
                return -EINVAL;
        }

        remaining = *sync_size;
        for (i = 0; remaining > 0; i++) {
                skb_frag_t *frag = &record->frags[i];

                __skb_frag_ref(frag);
                sg_set_page(sg_in + i, skb_frag_page(frag),
                            skb_frag_size(frag), skb_frag_off(frag));

                remaining -= skb_frag_size(frag);

                if (remaining < 0)
                        sg_in[i].length += remaining;
        }
        *resync_sgs = i;

        spin_unlock_irqrestore(&ctx->lock, flags);
        if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0)
                return -EINVAL;

        return 0;
}

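/* Lay out the encryption destination: a throw-away buffer that receives
 * the re-encrypted record bytes preceding this packet, the payload area
 * of the new skb, and space for the authentication tag produced by the
 * AEAD.
 */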
static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
                        struct tls_context *tls_ctx,
                        struct sk_buff *nskb,
                        int tcp_payload_offset,
                        int payload_len,
                        int sync_size,
                        void *dummy_buf)
{
        sg_set_buf(&sg_out[0], dummy_buf, sync_size);
        sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
        /* Add room for authentication tag produced by crypto */
        dummy_buf += sync_size;
        sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
}

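/* Allocate a new skb and encrypt the plaintext described by @sg_in into
 * it using the software AEAD.  On success the original headers are copied
 * over and socket ownership is transferred; on failure NULL is returned.
 */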
static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
                                   struct scatterlist sg_out[3],
                                   struct scatterlist *sg_in,
                                   struct sk_buff *skb,
                                   s32 sync_size, u64 rcd_sn)
{
        int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        int payload_len = skb->len - tcp_payload_offset;
        void *buf, *iv, *aad, *dummy_buf;
        struct aead_request *aead_req;
        struct sk_buff *nskb = NULL;
        int buf_len;

        aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
        if (!aead_req)
                return NULL;

        buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
                  TLS_CIPHER_AES_GCM_128_IV_SIZE +
                  TLS_AAD_SPACE_SIZE +
                  sync_size +
                  TLS_CIPHER_AES_GCM_128_TAG_SIZE;
        buf = kmalloc(buf_len, GFP_ATOMIC);
        if (!buf)
                goto free_req;

        iv = buf;
        memcpy(iv, tls_ctx->crypto_send.aes_gcm_128.salt,
               TLS_CIPHER_AES_GCM_128_SALT_SIZE);
        aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
              TLS_CIPHER_AES_GCM_128_IV_SIZE;
        dummy_buf = aad + TLS_AAD_SPACE_SIZE;

        nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
        if (!nskb)
                goto free_buf;

        skb_reserve(nskb, skb_headroom(skb));

        fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset,
                    payload_len, sync_size, dummy_buf);

        if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
                            rcd_sn, sync_size + payload_len) < 0)
                goto free_nskb;

        complete_skb(nskb, skb, tcp_payload_offset);

        /* validate_xmit_skb_list assumes that if the skb wasn't segmented
         * nskb->prev will point to the skb itself
         */
        nskb->prev = nskb;

free_buf:
        kfree(buf);
free_req:
        kfree(aead_req);
        return nskb;
free_nskb:
        kfree_skb(nskb);
        nskb = NULL;
        goto free_buf;
}

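/* Software fallback for TLS TX device offload: re-encrypt the payload of
 * an skb that cannot be handled by the device (e.g. after a
 * retransmission or a route change) and hand back a newly allocated,
 * encrypted skb.  The original skb is consumed either way.
 */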
static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
        int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        int payload_len = skb->len - tcp_payload_offset;
        struct scatterlist *sg_in, sg_out[3];
        struct sk_buff *nskb = NULL;
        int sg_in_max_elements;
        int resync_sgs = 0;
        s32 sync_size = 0;
        u64 rcd_sn;

        /* worst case is:
         * MAX_SKB_FRAGS in tls_record_info
         * MAX_SKB_FRAGS + 1 in SKB head and frags.
         */
        sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;

        if (!payload_len)
                return skb;

        sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
        if (!sg_in)
                goto free_orig;

        sg_init_table(sg_in, sg_in_max_elements);
        sg_init_table(sg_out, ARRAY_SIZE(sg_out));

        if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) {
                /* bypass packets before the kernel TLS socket option was set */
                if (sync_size < 0 && payload_len <= -sync_size)
                        nskb = skb_get(skb);
                goto put_sg;
        }

        nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn);

put_sg:
        while (resync_sgs)
                put_page(sg_page(&sg_in[--resync_sgs]));
        kfree(sg_in);
free_orig:
        if (nskb)
                consume_skb(skb);
        else
                kfree_skb(skb);
        return nskb;
}

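/* Transmit-path entry point (typically installed as
 * sk->sk_validate_xmit_skb by the TLS device offload code): packets that
 * leave through the offloading netdev pass through unchanged, anything
 * routed elsewhere is encrypted in software.
 */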
struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
                                      struct net_device *dev,
                                      struct sk_buff *skb)
{
        if (dev == tls_get_ctx(sk)->netdev)
                return skb;

        return tls_sw_fallback(sk, skb);
}
EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);

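/* Unconditionally run the software fallback on @skb, using the TLS state
 * of the skb's socket.
 */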
struct sk_buff *tls_encrypt_skb(struct sk_buff *skb)
{
        return tls_sw_fallback(skb->sk, skb);
}
EXPORT_SYMBOL_GPL(tls_encrypt_skb);

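/* Set up the synchronous gcm(aes) transform used by the software
 * fallback: allocate the AEAD, program the key from @crypto_info and set
 * the authentication tag size.
 */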
int tls_sw_fallback_init(struct sock *sk,
                         struct tls_offload_context_tx *offload_ctx,
                         struct tls_crypto_info *crypto_info)
{
        const u8 *key;
        int rc;

        offload_ctx->aead_send =
            crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
        if (IS_ERR(offload_ctx->aead_send)) {
                rc = PTR_ERR(offload_ctx->aead_send);
                pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
                offload_ctx->aead_send = NULL;
                goto err_out;
        }

        key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;

        rc = crypto_aead_setkey(offload_ctx->aead_send, key,
                                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
        if (rc)
                goto free_aead;

        rc = crypto_aead_setauthsize(offload_ctx->aead_send,
                                     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
        if (rc)
                goto free_aead;

        return 0;
free_aead:
        crypto_free_aead(offload_ctx->aead_send);
err_out:
        return rc;
}
