This source file includes the following definitions.
- __skb_nsg
- skb_nsg
- padding_length
- tls_decrypt_done
- tls_do_decryption
- tls_trim_both_msgs
- tls_alloc_encrypted_msg
- tls_clone_plaintext_msg
- tls_get_rec
- tls_free_rec
- tls_free_open_rec
- tls_tx_records
- tls_encrypt_done
- tls_do_encryption
- tls_split_open_record
- tls_merge_open_record
- tls_push_record
- bpf_exec_tx_verdict
- tls_sw_push_pending_record
- tls_sw_sendmsg
- tls_sw_do_sendpage
- tls_sw_sendpage_locked
- tls_sw_sendpage
- tls_wait_data
- tls_setup_from_iter
- decrypt_internal
- decrypt_skb_update
- decrypt_skb
- tls_sw_advance_skb
- process_rx_list
- tls_sw_recvmsg
- tls_sw_splice_read
- tls_sw_stream_read
- tls_read_size
- tls_queue
- tls_data_ready
- tls_sw_cancel_work_tx
- tls_sw_release_resources_tx
- tls_sw_free_ctx_tx
- tls_sw_release_resources_rx
- tls_sw_strparser_done
- tls_sw_free_ctx_rx
- tls_sw_free_resources_rx
- tx_work_handler
- tls_sw_write_space
- tls_sw_strparser_arm
- tls_set_sw_offload
38 #include <linux/sched/signal.h>
39 #include <linux/module.h>
40 #include <crypto/aead.h>
41
42 #include <net/strparser.h>
43 #include <net/tls.h>
44
45 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
46 unsigned int recursion_level)
47 {
48 int start = skb_headlen(skb);
49 int i, chunk = start - offset;
50 struct sk_buff *frag_iter;
51 int elt = 0;
52
53 if (unlikely(recursion_level >= 24))
54 return -EMSGSIZE;
55
56 if (chunk > 0) {
57 if (chunk > len)
58 chunk = len;
59 elt++;
60 len -= chunk;
61 if (len == 0)
62 return elt;
63 offset += chunk;
64 }
65
66 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
67 int end;
68
69 WARN_ON(start > offset + len);
70
71 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
72 chunk = end - offset;
73 if (chunk > 0) {
74 if (chunk > len)
75 chunk = len;
76 elt++;
77 len -= chunk;
78 if (len == 0)
79 return elt;
80 offset += chunk;
81 }
82 start = end;
83 }
84
85 if (unlikely(skb_has_frag_list(skb))) {
86 skb_walk_frags(skb, frag_iter) {
87 int end, ret;
88
89 WARN_ON(start > offset + len);
90
91 end = start + frag_iter->len;
92 chunk = end - offset;
93 if (chunk > 0) {
94 if (chunk > len)
95 chunk = len;
96 ret = __skb_nsg(frag_iter, offset - start, chunk,
97 recursion_level + 1);
98 if (unlikely(ret < 0))
99 return ret;
100 elt += ret;
101 len -= chunk;
102 if (len == 0)
103 return elt;
104 offset += chunk;
105 }
106 start = end;
107 }
108 }
109 BUG_ON(len);
110 return elt;
111 }
112
113
114
115
116 static int skb_nsg(struct sk_buff *skb, int offset, int len)
117 {
118 return __skb_nsg(skb, offset, len, 0);
119 }
120
121 static int padding_length(struct tls_sw_context_rx *ctx,
122 struct tls_prot_info *prot, struct sk_buff *skb)
123 {
124 struct strp_msg *rxm = strp_msg(skb);
125 int sub = 0;
126
127
128 if (prot->version == TLS_1_3_VERSION) {
129 char content_type = 0;
130 int err;
131 int back = 17;
132
133 while (content_type == 0) {
134 if (back > rxm->full_len - prot->prepend_size)
135 return -EBADMSG;
136 err = skb_copy_bits(skb,
137 rxm->offset + rxm->full_len - back,
138 &content_type, 1);
139 if (err)
140 return err;
141 if (content_type)
142 break;
143 sub++;
144 back++;
145 }
146 ctx->control = content_type;
147 }
148 return sub;
149 }
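
For reference, a minimal standalone sketch of the scan that padding_length() performs for TLS 1.3 (this helper is not part of the file; it works on an already-stripped plaintext buffer rather than on the skb, but the logic is the same): a TLS 1.3 record body is laid out as content || ContentType || zero padding, so the real content type is the last non-zero byte and everything after it is padding.

#include <stddef.h>

/* Hypothetical helper, shown only to illustrate the backward scan. */
static int inner_plaintext_padding(const unsigned char *rec, size_t len,
				   unsigned char *type)
{
	size_t i = len;

	while (i > 0) {
		if (rec[i - 1] != 0) {
			*type = rec[i - 1];	/* real content type */
			return (int)(len - i);	/* number of padding bytes */
		}
		i--;
	}
	return -1;				/* all zeros: malformed record */
}
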
150
151 static void tls_decrypt_done(struct crypto_async_request *req, int err)
152 {
153 struct aead_request *aead_req = (struct aead_request *)req;
154 struct scatterlist *sgout = aead_req->dst;
155 struct scatterlist *sgin = aead_req->src;
156 struct tls_sw_context_rx *ctx;
157 struct tls_context *tls_ctx;
158 struct tls_prot_info *prot;
159 struct scatterlist *sg;
160 struct sk_buff *skb;
161 unsigned int pages;
162 int pending;
163
164 skb = (struct sk_buff *)req->data;
165 tls_ctx = tls_get_ctx(skb->sk);
166 ctx = tls_sw_ctx_rx(tls_ctx);
167 prot = &tls_ctx->prot_info;
168
169
170 if (err) {
171 ctx->async_wait.err = err;
172 tls_err_abort(skb->sk, err);
173 } else {
174 struct strp_msg *rxm = strp_msg(skb);
175 int pad;
176
177 pad = padding_length(ctx, prot, skb);
178 if (pad < 0) {
179 ctx->async_wait.err = pad;
180 tls_err_abort(skb->sk, pad);
181 } else {
182 rxm->full_len -= pad;
183 rxm->offset += prot->prepend_size;
184 rxm->full_len -= prot->overhead_size;
185 }
186 }
187
188
189
190
191 skb->sk = NULL;
192
193
194
195 if (sgout != sgin) {
196
197 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
198 if (!sg)
199 break;
200 put_page(sg_page(sg));
201 }
202 }
203
204 kfree(aead_req);
205
206 spin_lock_bh(&ctx->decrypt_compl_lock);
207 pending = atomic_dec_return(&ctx->decrypt_pending);
208
209 if (!pending && ctx->async_notify)
210 complete(&ctx->async_wait.completion);
211 spin_unlock_bh(&ctx->decrypt_compl_lock);
212 }
213
214 static int tls_do_decryption(struct sock *sk,
215 struct sk_buff *skb,
216 struct scatterlist *sgin,
217 struct scatterlist *sgout,
218 char *iv_recv,
219 size_t data_len,
220 struct aead_request *aead_req,
221 bool async)
222 {
223 struct tls_context *tls_ctx = tls_get_ctx(sk);
224 struct tls_prot_info *prot = &tls_ctx->prot_info;
225 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
226 int ret;
227
228 aead_request_set_tfm(aead_req, ctx->aead_recv);
229 aead_request_set_ad(aead_req, prot->aad_size);
230 aead_request_set_crypt(aead_req, sgin, sgout,
231 data_len + prot->tag_size,
232 (u8 *)iv_recv);
233
234 if (async) {
235
236
237
238
239
240
241 skb->sk = sk;
242 aead_request_set_callback(aead_req,
243 CRYPTO_TFM_REQ_MAY_BACKLOG,
244 tls_decrypt_done, skb);
245 atomic_inc(&ctx->decrypt_pending);
246 } else {
247 aead_request_set_callback(aead_req,
248 CRYPTO_TFM_REQ_MAY_BACKLOG,
249 crypto_req_done, &ctx->async_wait);
250 }
251
252 ret = crypto_aead_decrypt(aead_req);
253 if (ret == -EINPROGRESS) {
254 if (async)
255 return ret;
256
257 ret = crypto_wait_req(ret, &ctx->async_wait);
258 }
259
260 if (async)
261 atomic_dec(&ctx->decrypt_pending);
262
263 return ret;
264 }
265
266 static void tls_trim_both_msgs(struct sock *sk, int target_size)
267 {
268 struct tls_context *tls_ctx = tls_get_ctx(sk);
269 struct tls_prot_info *prot = &tls_ctx->prot_info;
270 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
271 struct tls_rec *rec = ctx->open_rec;
272
273 sk_msg_trim(sk, &rec->msg_plaintext, target_size);
274 if (target_size > 0)
275 target_size += prot->overhead_size;
276 sk_msg_trim(sk, &rec->msg_encrypted, target_size);
277 }
278
279 static int tls_alloc_encrypted_msg(struct sock *sk, int len)
280 {
281 struct tls_context *tls_ctx = tls_get_ctx(sk);
282 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
283 struct tls_rec *rec = ctx->open_rec;
284 struct sk_msg *msg_en = &rec->msg_encrypted;
285
286 return sk_msg_alloc(sk, msg_en, len, 0);
287 }
288
289 static int tls_clone_plaintext_msg(struct sock *sk, int required)
290 {
291 struct tls_context *tls_ctx = tls_get_ctx(sk);
292 struct tls_prot_info *prot = &tls_ctx->prot_info;
293 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
294 struct tls_rec *rec = ctx->open_rec;
295 struct sk_msg *msg_pl = &rec->msg_plaintext;
296 struct sk_msg *msg_en = &rec->msg_encrypted;
297 int skip, len;
298
299
300
301
302
303 len = required - msg_pl->sg.size;
304
305
306
307
308 skip = prot->prepend_size + msg_pl->sg.size;
309
310 return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
311 }
312
313 static struct tls_rec *tls_get_rec(struct sock *sk)
314 {
315 struct tls_context *tls_ctx = tls_get_ctx(sk);
316 struct tls_prot_info *prot = &tls_ctx->prot_info;
317 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
318 struct sk_msg *msg_pl, *msg_en;
319 struct tls_rec *rec;
320 int mem_size;
321
322 mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
323
324 rec = kzalloc(mem_size, sk->sk_allocation);
325 if (!rec)
326 return NULL;
327
328 msg_pl = &rec->msg_plaintext;
329 msg_en = &rec->msg_encrypted;
330
331 sk_msg_init(msg_pl);
332 sk_msg_init(msg_en);
333
334 sg_init_table(rec->sg_aead_in, 2);
335 sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
336 sg_unmark_end(&rec->sg_aead_in[1]);
337
338 sg_init_table(rec->sg_aead_out, 2);
339 sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
340 sg_unmark_end(&rec->sg_aead_out[1]);
341
342 return rec;
343 }
344
345 static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
346 {
347 sk_msg_free(sk, &rec->msg_encrypted);
348 sk_msg_free(sk, &rec->msg_plaintext);
349 kfree(rec);
350 }
351
352 static void tls_free_open_rec(struct sock *sk)
353 {
354 struct tls_context *tls_ctx = tls_get_ctx(sk);
355 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
356 struct tls_rec *rec = ctx->open_rec;
357
358 if (rec) {
359 tls_free_rec(sk, rec);
360 ctx->open_rec = NULL;
361 }
362 }
363
364 int tls_tx_records(struct sock *sk, int flags)
365 {
366 struct tls_context *tls_ctx = tls_get_ctx(sk);
367 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
368 struct tls_rec *rec, *tmp;
369 struct sk_msg *msg_en;
370 int tx_flags, rc = 0;
371
372 if (tls_is_partially_sent_record(tls_ctx)) {
373 rec = list_first_entry(&ctx->tx_list,
374 struct tls_rec, list);
375
376 if (flags == -1)
377 tx_flags = rec->tx_flags;
378 else
379 tx_flags = flags;
380
381 rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
382 if (rc)
383 goto tx_err;
384
385
386
387
388 list_del(&rec->list);
389 sk_msg_free(sk, &rec->msg_plaintext);
390 kfree(rec);
391 }
392
393
394 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
395 if (READ_ONCE(rec->tx_ready)) {
396 if (flags == -1)
397 tx_flags = rec->tx_flags;
398 else
399 tx_flags = flags;
400
401 msg_en = &rec->msg_encrypted;
402 rc = tls_push_sg(sk, tls_ctx,
403 &msg_en->sg.data[msg_en->sg.curr],
404 0, tx_flags);
405 if (rc)
406 goto tx_err;
407
408 list_del(&rec->list);
409 sk_msg_free(sk, &rec->msg_plaintext);
410 kfree(rec);
411 } else {
412 break;
413 }
414 }
415
416 tx_err:
417 if (rc < 0 && rc != -EAGAIN)
418 tls_err_abort(sk, EBADMSG);
419
420 return rc;
421 }
422
423 static void tls_encrypt_done(struct crypto_async_request *req, int err)
424 {
425 struct aead_request *aead_req = (struct aead_request *)req;
426 struct sock *sk = req->data;
427 struct tls_context *tls_ctx = tls_get_ctx(sk);
428 struct tls_prot_info *prot = &tls_ctx->prot_info;
429 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
430 struct scatterlist *sge;
431 struct sk_msg *msg_en;
432 struct tls_rec *rec;
433 bool ready = false;
434 int pending;
435
436 rec = container_of(aead_req, struct tls_rec, aead_req);
437 msg_en = &rec->msg_encrypted;
438
439 sge = sk_msg_elem(msg_en, msg_en->sg.curr);
440 sge->offset -= prot->prepend_size;
441 sge->length += prot->prepend_size;
442
443
444 if (err || sk->sk_err) {
445 rec = NULL;
446
447
448 if (sk->sk_err) {
449 ctx->async_wait.err = sk->sk_err;
450 } else {
451 ctx->async_wait.err = err;
452 tls_err_abort(sk, err);
453 }
454 }
455
456 if (rec) {
457 struct tls_rec *first_rec;
458
459
460 smp_store_mb(rec->tx_ready, true);
461
462
463 first_rec = list_first_entry(&ctx->tx_list,
464 struct tls_rec, list);
465 if (rec == first_rec)
466 ready = true;
467 }
468
469 spin_lock_bh(&ctx->encrypt_compl_lock);
470 pending = atomic_dec_return(&ctx->encrypt_pending);
471
472 if (!pending && ctx->async_notify)
473 complete(&ctx->async_wait.completion);
474 spin_unlock_bh(&ctx->encrypt_compl_lock);
475
476 if (!ready)
477 return;
478
479
480 if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
481 schedule_delayed_work(&ctx->tx_work.work, 1);
482 }
483
484 static int tls_do_encryption(struct sock *sk,
485 struct tls_context *tls_ctx,
486 struct tls_sw_context_tx *ctx,
487 struct aead_request *aead_req,
488 size_t data_len, u32 start)
489 {
490 struct tls_prot_info *prot = &tls_ctx->prot_info;
491 struct tls_rec *rec = ctx->open_rec;
492 struct sk_msg *msg_en = &rec->msg_encrypted;
493 struct scatterlist *sge = sk_msg_elem(msg_en, start);
494 int rc, iv_offset = 0;
495
496
497 if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
498 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
499 iv_offset = 1;
500 }
501
502 memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
503 prot->iv_size + prot->salt_size);
504
505 xor_iv_with_seq(prot->version, rec->iv_data, tls_ctx->tx.rec_seq);
506
507 sge->offset += prot->prepend_size;
508 sge->length -= prot->prepend_size;
509
510 msg_en->sg.curr = start;
511
512 aead_request_set_tfm(aead_req, ctx->aead_send);
513 aead_request_set_ad(aead_req, prot->aad_size);
514 aead_request_set_crypt(aead_req, rec->sg_aead_in,
515 rec->sg_aead_out,
516 data_len, rec->iv_data);
517
518 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
519 tls_encrypt_done, sk);
520
521
522 list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
523 atomic_inc(&ctx->encrypt_pending);
524
525 rc = crypto_aead_encrypt(aead_req);
526 if (!rc || rc != -EINPROGRESS) {
527 atomic_dec(&ctx->encrypt_pending);
528 sge->offset -= prot->prepend_size;
529 sge->length += prot->prepend_size;
530 }
531
532 if (!rc) {
533 WRITE_ONCE(rec->tx_ready, true);
534 } else if (rc != -EINPROGRESS) {
535 list_del(&rec->list);
536 return rc;
537 }
538
539
540 ctx->open_rec = NULL;
541 tls_advance_record_sn(sk, prot, &tls_ctx->tx);
542 return rc;
543 }
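
For TLS 1.3, the xor_iv_with_seq() call above is what turns the connection's static IV into a per-record nonce. Below is a hedged sketch of that derivation, written against RFC 8446 rather than copied from the kernel helper (the 12-byte GCM nonce size is an assumption of the example):

/* Illustrative only: TLS 1.3 per-record AEAD nonce = static IV with the
 * 8-byte big-endian record sequence number XORed into its low-order bytes.
 */
static void tls13_build_nonce(unsigned char nonce[12],
			      const unsigned char iv[12],
			      const unsigned char seq[8])
{
	int i;

	for (i = 0; i < 12; i++)
		nonce[i] = iv[i];
	for (i = 0; i < 8; i++)
		nonce[4 + i] ^= seq[i];
}
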
544
545 static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
546 struct tls_rec **to, struct sk_msg *msg_opl,
547 struct sk_msg *msg_oen, u32 split_point,
548 u32 tx_overhead_size, u32 *orig_end)
549 {
550 u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
551 struct scatterlist *sge, *osge, *nsge;
552 u32 orig_size = msg_opl->sg.size;
553 struct scatterlist tmp = { };
554 struct sk_msg *msg_npl;
555 struct tls_rec *new;
556 int ret;
557
558 new = tls_get_rec(sk);
559 if (!new)
560 return -ENOMEM;
561 ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
562 tx_overhead_size, 0);
563 if (ret < 0) {
564 tls_free_rec(sk, new);
565 return ret;
566 }
567
568 *orig_end = msg_opl->sg.end;
569 i = msg_opl->sg.start;
570 sge = sk_msg_elem(msg_opl, i);
571 while (apply && sge->length) {
572 if (sge->length > apply) {
573 u32 len = sge->length - apply;
574
575 get_page(sg_page(sge));
576 sg_set_page(&tmp, sg_page(sge), len,
577 sge->offset + apply);
578 sge->length = apply;
579 bytes += apply;
580 apply = 0;
581 } else {
582 apply -= sge->length;
583 bytes += sge->length;
584 }
585
586 sk_msg_iter_var_next(i);
587 if (i == msg_opl->sg.end)
588 break;
589 sge = sk_msg_elem(msg_opl, i);
590 }
591
592 msg_opl->sg.end = i;
593 msg_opl->sg.curr = i;
594 msg_opl->sg.copybreak = 0;
595 msg_opl->apply_bytes = 0;
596 msg_opl->sg.size = bytes;
597
598 msg_npl = &new->msg_plaintext;
599 msg_npl->apply_bytes = apply;
600 msg_npl->sg.size = orig_size - bytes;
601
602 j = msg_npl->sg.start;
603 nsge = sk_msg_elem(msg_npl, j);
604 if (tmp.length) {
605 memcpy(nsge, &tmp, sizeof(*nsge));
606 sk_msg_iter_var_next(j);
607 nsge = sk_msg_elem(msg_npl, j);
608 }
609
610 osge = sk_msg_elem(msg_opl, i);
611 while (osge->length) {
612 memcpy(nsge, osge, sizeof(*nsge));
613 sg_unmark_end(nsge);
614 sk_msg_iter_var_next(i);
615 sk_msg_iter_var_next(j);
616 if (i == *orig_end)
617 break;
618 osge = sk_msg_elem(msg_opl, i);
619 nsge = sk_msg_elem(msg_npl, j);
620 }
621
622 msg_npl->sg.end = j;
623 msg_npl->sg.curr = j;
624 msg_npl->sg.copybreak = 0;
625
626 *to = new;
627 return 0;
628 }
629
630 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
631 struct tls_rec *from, u32 orig_end)
632 {
633 struct sk_msg *msg_npl = &from->msg_plaintext;
634 struct sk_msg *msg_opl = &to->msg_plaintext;
635 struct scatterlist *osge, *nsge;
636 u32 i, j;
637
638 i = msg_opl->sg.end;
639 sk_msg_iter_var_prev(i);
640 j = msg_npl->sg.start;
641
642 osge = sk_msg_elem(msg_opl, i);
643 nsge = sk_msg_elem(msg_npl, j);
644
645 if (sg_page(osge) == sg_page(nsge) &&
646 osge->offset + osge->length == nsge->offset) {
647 osge->length += nsge->length;
648 put_page(sg_page(nsge));
649 }
650
651 msg_opl->sg.end = orig_end;
652 msg_opl->sg.curr = orig_end;
653 msg_opl->sg.copybreak = 0;
654 msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
655 msg_opl->sg.size += msg_npl->sg.size;
656
657 sk_msg_free(sk, &to->msg_encrypted);
658 sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
659
660 kfree(from);
661 }
662
663 static int tls_push_record(struct sock *sk, int flags,
664 unsigned char record_type)
665 {
666 struct tls_context *tls_ctx = tls_get_ctx(sk);
667 struct tls_prot_info *prot = &tls_ctx->prot_info;
668 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
669 struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
670 u32 i, split_point, uninitialized_var(orig_end);
671 struct sk_msg *msg_pl, *msg_en;
672 struct aead_request *req;
673 bool split;
674 int rc;
675
676 if (!rec)
677 return 0;
678
679 msg_pl = &rec->msg_plaintext;
680 msg_en = &rec->msg_encrypted;
681
682 split_point = msg_pl->apply_bytes;
683 split = split_point && split_point < msg_pl->sg.size;
684 if (unlikely((!split &&
685 msg_pl->sg.size +
686 prot->overhead_size > msg_en->sg.size) ||
687 (split &&
688 split_point +
689 prot->overhead_size > msg_en->sg.size))) {
690 split = true;
691 split_point = msg_en->sg.size;
692 }
693 if (split) {
694 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
695 split_point, prot->overhead_size,
696 &orig_end);
697 if (rc < 0)
698 return rc;
699
700
701
702
703
704 if (!msg_pl->sg.size) {
705 tls_merge_open_record(sk, rec, tmp, orig_end);
706 msg_pl = &rec->msg_plaintext;
707 msg_en = &rec->msg_encrypted;
708 split = false;
709 }
710 sk_msg_trim(sk, msg_en, msg_pl->sg.size +
711 prot->overhead_size);
712 }
713
714 rec->tx_flags = flags;
715 req = &rec->aead_req;
716
717 i = msg_pl->sg.end;
718 sk_msg_iter_var_prev(i);
719
720 rec->content_type = record_type;
721 if (prot->version == TLS_1_3_VERSION) {
722
723 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
724 sg_mark_end(&rec->sg_content_type);
725 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
726 &rec->sg_content_type);
727 } else {
728 sg_mark_end(sk_msg_elem(msg_pl, i));
729 }
730
731 if (msg_pl->sg.end < msg_pl->sg.start) {
732 sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
733 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
734 msg_pl->sg.data);
735 }
736
737 i = msg_pl->sg.start;
738 sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
739
740 i = msg_en->sg.end;
741 sk_msg_iter_var_prev(i);
742 sg_mark_end(sk_msg_elem(msg_en, i));
743
744 i = msg_en->sg.start;
745 sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
746
747 tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
748 tls_ctx->tx.rec_seq, prot->rec_seq_size,
749 record_type, prot->version);
750
751 tls_fill_prepend(tls_ctx,
752 page_address(sg_page(&msg_en->sg.data[i])) +
753 msg_en->sg.data[i].offset,
754 msg_pl->sg.size + prot->tail_size,
755 record_type, prot->version);
756
757 tls_ctx->pending_open_record_frags = false;
758
759 rc = tls_do_encryption(sk, tls_ctx, ctx, req,
760 msg_pl->sg.size + prot->tail_size, i);
761 if (rc < 0) {
762 if (rc != -EINPROGRESS) {
763 tls_err_abort(sk, EBADMSG);
764 if (split) {
765 tls_ctx->pending_open_record_frags = true;
766 tls_merge_open_record(sk, rec, tmp, orig_end);
767 }
768 }
769 ctx->async_capable = 1;
770 return rc;
771 } else if (split) {
772 msg_pl = &tmp->msg_plaintext;
773 msg_en = &tmp->msg_encrypted;
774 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
775 tls_ctx->pending_open_record_frags = true;
776 ctx->open_rec = tmp;
777 }
778
779 return tls_tx_records(sk, flags);
780 }
781
782 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
783 bool full_record, u8 record_type,
784 ssize_t *copied, int flags)
785 {
786 struct tls_context *tls_ctx = tls_get_ctx(sk);
787 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
788 struct sk_msg msg_redir = { };
789 struct sk_psock *psock;
790 struct sock *sk_redir;
791 struct tls_rec *rec;
792 bool enospc, policy;
793 int err = 0, send;
794 u32 delta = 0;
795
796 policy = !(flags & MSG_SENDPAGE_NOPOLICY);
797 psock = sk_psock_get(sk);
798 if (!psock || !policy) {
799 err = tls_push_record(sk, flags, record_type);
800 if (err && sk->sk_err == EBADMSG) {
801 *copied -= sk_msg_free(sk, msg);
802 tls_free_open_rec(sk);
803 err = -sk->sk_err;
804 }
805 if (psock)
806 sk_psock_put(sk, psock);
807 return err;
808 }
809 more_data:
810 enospc = sk_msg_full(msg);
811 if (psock->eval == __SK_NONE) {
812 delta = msg->sg.size;
813 psock->eval = sk_psock_msg_verdict(sk, psock, msg);
814 delta -= msg->sg.size;
815 }
816 if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
817 !enospc && !full_record) {
818 err = -ENOSPC;
819 goto out_err;
820 }
821 msg->cork_bytes = 0;
822 send = msg->sg.size;
823 if (msg->apply_bytes && msg->apply_bytes < send)
824 send = msg->apply_bytes;
825
826 switch (psock->eval) {
827 case __SK_PASS:
828 err = tls_push_record(sk, flags, record_type);
829 if (err && sk->sk_err == EBADMSG) {
830 *copied -= sk_msg_free(sk, msg);
831 tls_free_open_rec(sk);
832 err = -sk->sk_err;
833 goto out_err;
834 }
835 break;
836 case __SK_REDIRECT:
837 sk_redir = psock->sk_redir;
838 memcpy(&msg_redir, msg, sizeof(*msg));
839 if (msg->apply_bytes < send)
840 msg->apply_bytes = 0;
841 else
842 msg->apply_bytes -= send;
843 sk_msg_return_zero(sk, msg, send);
844 msg->sg.size -= send;
845 release_sock(sk);
846 err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
847 lock_sock(sk);
848 if (err < 0) {
849 *copied -= sk_msg_free_nocharge(sk, &msg_redir);
850 msg->sg.size = 0;
851 }
852 if (msg->sg.size == 0)
853 tls_free_open_rec(sk);
854 break;
855 case __SK_DROP:
856 default:
857 sk_msg_free_partial(sk, msg, send);
858 if (msg->apply_bytes < send)
859 msg->apply_bytes = 0;
860 else
861 msg->apply_bytes -= send;
862 if (msg->sg.size == 0)
863 tls_free_open_rec(sk);
864 *copied -= (send + delta);
865 err = -EACCES;
866 }
867
868 if (likely(!err)) {
869 bool reset_eval = !ctx->open_rec;
870
871 rec = ctx->open_rec;
872 if (rec) {
873 msg = &rec->msg_plaintext;
874 if (!msg->apply_bytes)
875 reset_eval = true;
876 }
877 if (reset_eval) {
878 psock->eval = __SK_NONE;
879 if (psock->sk_redir) {
880 sock_put(psock->sk_redir);
881 psock->sk_redir = NULL;
882 }
883 }
884 if (rec)
885 goto more_data;
886 }
887 out_err:
888 sk_psock_put(sk, psock);
889 return err;
890 }
891
892 static int tls_sw_push_pending_record(struct sock *sk, int flags)
893 {
894 struct tls_context *tls_ctx = tls_get_ctx(sk);
895 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
896 struct tls_rec *rec = ctx->open_rec;
897 struct sk_msg *msg_pl;
898 size_t copied;
899
900 if (!rec)
901 return 0;
902
903 msg_pl = &rec->msg_plaintext;
904 copied = msg_pl->sg.size;
905 if (!copied)
906 return 0;
907
908 return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
909 &copied, flags);
910 }
911
912 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
913 {
914 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
915 struct tls_context *tls_ctx = tls_get_ctx(sk);
916 struct tls_prot_info *prot = &tls_ctx->prot_info;
917 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
918 bool async_capable = ctx->async_capable;
919 unsigned char record_type = TLS_RECORD_TYPE_DATA;
920 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
921 bool eor = !(msg->msg_flags & MSG_MORE);
922 size_t try_to_copy;
923 ssize_t copied = 0;
924 struct sk_msg *msg_pl, *msg_en;
925 struct tls_rec *rec;
926 int required_size;
927 int num_async = 0;
928 bool full_record;
929 int record_room;
930 int num_zc = 0;
931 int orig_size;
932 int ret = 0;
933 int pending;
934
935 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
936 return -EOPNOTSUPP;
937
938 mutex_lock(&tls_ctx->tx_lock);
939 lock_sock(sk);
940
941 if (unlikely(msg->msg_controllen)) {
942 ret = tls_proccess_cmsg(sk, msg, &record_type);
943 if (ret) {
944 if (ret == -EINPROGRESS)
945 num_async++;
946 else if (ret != -EAGAIN)
947 goto send_end;
948 }
949 }
950
951 while (msg_data_left(msg)) {
952 if (sk->sk_err) {
953 ret = -sk->sk_err;
954 goto send_end;
955 }
956
957 if (ctx->open_rec)
958 rec = ctx->open_rec;
959 else
960 rec = ctx->open_rec = tls_get_rec(sk);
961 if (!rec) {
962 ret = -ENOMEM;
963 goto send_end;
964 }
965
966 msg_pl = &rec->msg_plaintext;
967 msg_en = &rec->msg_encrypted;
968
969 orig_size = msg_pl->sg.size;
970 full_record = false;
971 try_to_copy = msg_data_left(msg);
972 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
973 if (try_to_copy >= record_room) {
974 try_to_copy = record_room;
975 full_record = true;
976 }
977
978 required_size = msg_pl->sg.size + try_to_copy +
979 prot->overhead_size;
980
981 if (!sk_stream_memory_free(sk))
982 goto wait_for_sndbuf;
983
984 alloc_encrypted:
985 ret = tls_alloc_encrypted_msg(sk, required_size);
986 if (ret) {
987 if (ret != -ENOSPC)
988 goto wait_for_memory;
989
990
991
992
993
994 try_to_copy -= required_size - msg_en->sg.size;
995 full_record = true;
996 }
997
998 if (!is_kvec && (full_record || eor) && !async_capable) {
999 u32 first = msg_pl->sg.end;
1000
1001 ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
1002 msg_pl, try_to_copy);
1003 if (ret)
1004 goto fallback_to_reg_send;
1005
1006 num_zc++;
1007 copied += try_to_copy;
1008
1009 sk_msg_sg_copy_set(msg_pl, first);
1010 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1011 record_type, &copied,
1012 msg->msg_flags);
1013 if (ret) {
1014 if (ret == -EINPROGRESS)
1015 num_async++;
1016 else if (ret == -ENOMEM)
1017 goto wait_for_memory;
1018 else if (ctx->open_rec && ret == -ENOSPC)
1019 goto rollback_iter;
1020 else if (ret != -EAGAIN)
1021 goto send_end;
1022 }
1023 continue;
1024 rollback_iter:
1025 copied -= try_to_copy;
1026 sk_msg_sg_copy_clear(msg_pl, first);
1027 iov_iter_revert(&msg->msg_iter,
1028 msg_pl->sg.size - orig_size);
1029 fallback_to_reg_send:
1030 sk_msg_trim(sk, msg_pl, orig_size);
1031 }
1032
1033 required_size = msg_pl->sg.size + try_to_copy;
1034
1035 ret = tls_clone_plaintext_msg(sk, required_size);
1036 if (ret) {
1037 if (ret != -ENOSPC)
1038 goto send_end;
1039
1040
1041
1042
1043
1044 try_to_copy -= required_size - msg_pl->sg.size;
1045 full_record = true;
1046 sk_msg_trim(sk, msg_en,
1047 msg_pl->sg.size + prot->overhead_size);
1048 }
1049
1050 if (try_to_copy) {
1051 ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
1052 msg_pl, try_to_copy);
1053 if (ret < 0)
1054 goto trim_sgl;
1055 }
1056
1057
1058
1059
1060 tls_ctx->pending_open_record_frags = true;
1061 copied += try_to_copy;
1062 if (full_record || eor) {
1063 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1064 record_type, &copied,
1065 msg->msg_flags);
1066 if (ret) {
1067 if (ret == -EINPROGRESS)
1068 num_async++;
1069 else if (ret == -ENOMEM)
1070 goto wait_for_memory;
1071 else if (ret != -EAGAIN) {
1072 if (ret == -ENOSPC)
1073 ret = 0;
1074 goto send_end;
1075 }
1076 }
1077 }
1078
1079 continue;
1080
1081 wait_for_sndbuf:
1082 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1083 wait_for_memory:
1084 ret = sk_stream_wait_memory(sk, &timeo);
1085 if (ret) {
1086 trim_sgl:
1087 if (ctx->open_rec)
1088 tls_trim_both_msgs(sk, orig_size);
1089 goto send_end;
1090 }
1091
1092 if (ctx->open_rec && msg_en->sg.size < required_size)
1093 goto alloc_encrypted;
1094 }
1095
1096 if (!num_async) {
1097 goto send_end;
1098 } else if (num_zc) {
1099
1100 spin_lock_bh(&ctx->encrypt_compl_lock);
1101 ctx->async_notify = true;
1102
1103 pending = atomic_read(&ctx->encrypt_pending);
1104 spin_unlock_bh(&ctx->encrypt_compl_lock);
1105 if (pending)
1106 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1107 else
1108 reinit_completion(&ctx->async_wait.completion);
1109
1110
1111
1112
1113 WRITE_ONCE(ctx->async_notify, false);
1114
1115 if (ctx->async_wait.err) {
1116 ret = ctx->async_wait.err;
1117 copied = 0;
1118 }
1119 }
1120
1121
1122 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1123 cancel_delayed_work(&ctx->tx_work.work);
1124 tls_tx_records(sk, msg->msg_flags);
1125 }
1126
1127 send_end:
1128 ret = sk_stream_error(sk, msg->msg_flags, ret);
1129
1130 release_sock(sk);
1131 mutex_unlock(&tls_ctx->tx_lock);
1132 return copied > 0 ? copied : ret;
1133 }
1134
1135 static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1136 int offset, size_t size, int flags)
1137 {
1138 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1139 struct tls_context *tls_ctx = tls_get_ctx(sk);
1140 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1141 struct tls_prot_info *prot = &tls_ctx->prot_info;
1142 unsigned char record_type = TLS_RECORD_TYPE_DATA;
1143 struct sk_msg *msg_pl;
1144 struct tls_rec *rec;
1145 int num_async = 0;
1146 ssize_t copied = 0;
1147 bool full_record;
1148 int record_room;
1149 int ret = 0;
1150 bool eor;
1151
1152 eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
1153 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1154
1155
1156 while (size > 0) {
1157 size_t copy, required_size;
1158
1159 if (sk->sk_err) {
1160 ret = -sk->sk_err;
1161 goto sendpage_end;
1162 }
1163
1164 if (ctx->open_rec)
1165 rec = ctx->open_rec;
1166 else
1167 rec = ctx->open_rec = tls_get_rec(sk);
1168 if (!rec) {
1169 ret = -ENOMEM;
1170 goto sendpage_end;
1171 }
1172
1173 msg_pl = &rec->msg_plaintext;
1174
1175 full_record = false;
1176 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1177 copy = size;
1178 if (copy >= record_room) {
1179 copy = record_room;
1180 full_record = true;
1181 }
1182
1183 required_size = msg_pl->sg.size + copy + prot->overhead_size;
1184
1185 if (!sk_stream_memory_free(sk))
1186 goto wait_for_sndbuf;
1187 alloc_payload:
1188 ret = tls_alloc_encrypted_msg(sk, required_size);
1189 if (ret) {
1190 if (ret != -ENOSPC)
1191 goto wait_for_memory;
1192
1193
1194
1195
1196
1197 copy -= required_size - msg_pl->sg.size;
1198 full_record = true;
1199 }
1200
1201 sk_msg_page_add(msg_pl, page, copy, offset);
1202 sk_mem_charge(sk, copy);
1203
1204 offset += copy;
1205 size -= copy;
1206 copied += copy;
1207
1208 tls_ctx->pending_open_record_frags = true;
1209 if (full_record || eor || sk_msg_full(msg_pl)) {
1210 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1211 record_type, &copied, flags);
1212 if (ret) {
1213 if (ret == -EINPROGRESS)
1214 num_async++;
1215 else if (ret == -ENOMEM)
1216 goto wait_for_memory;
1217 else if (ret != -EAGAIN) {
1218 if (ret == -ENOSPC)
1219 ret = 0;
1220 goto sendpage_end;
1221 }
1222 }
1223 }
1224 continue;
1225 wait_for_sndbuf:
1226 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1227 wait_for_memory:
1228 ret = sk_stream_wait_memory(sk, &timeo);
1229 if (ret) {
1230 if (ctx->open_rec)
1231 tls_trim_both_msgs(sk, msg_pl->sg.size);
1232 goto sendpage_end;
1233 }
1234
1235 if (ctx->open_rec)
1236 goto alloc_payload;
1237 }
1238
1239 if (num_async) {
1240
1241 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1242 cancel_delayed_work(&ctx->tx_work.work);
1243 tls_tx_records(sk, flags);
1244 }
1245 }
1246 sendpage_end:
1247 ret = sk_stream_error(sk, flags, ret);
1248 return copied > 0 ? copied : ret;
1249 }
1250
1251 int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1252 int offset, size_t size, int flags)
1253 {
1254 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1255 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1256 MSG_NO_SHARED_FRAGS))
1257 return -EOPNOTSUPP;
1258
1259 return tls_sw_do_sendpage(sk, page, offset, size, flags);
1260 }
1261
1262 int tls_sw_sendpage(struct sock *sk, struct page *page,
1263 int offset, size_t size, int flags)
1264 {
1265 struct tls_context *tls_ctx = tls_get_ctx(sk);
1266 int ret;
1267
1268 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1269 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1270 return -EOPNOTSUPP;
1271
1272 mutex_lock(&tls_ctx->tx_lock);
1273 lock_sock(sk);
1274 ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1275 release_sock(sk);
1276 mutex_unlock(&tls_ctx->tx_lock);
1277 return ret;
1278 }
1279
1280 static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
1281 int flags, long timeo, int *err)
1282 {
1283 struct tls_context *tls_ctx = tls_get_ctx(sk);
1284 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1285 struct sk_buff *skb;
1286 DEFINE_WAIT_FUNC(wait, woken_wake_function);
1287
1288 while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
1289 if (sk->sk_err) {
1290 *err = sock_error(sk);
1291 return NULL;
1292 }
1293
1294 if (sk->sk_shutdown & RCV_SHUTDOWN)
1295 return NULL;
1296
1297 if (sock_flag(sk, SOCK_DONE))
1298 return NULL;
1299
1300 if ((flags & MSG_DONTWAIT) || !timeo) {
1301 *err = -EAGAIN;
1302 return NULL;
1303 }
1304
1305 add_wait_queue(sk_sleep(sk), &wait);
1306 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1307 sk_wait_event(sk, &timeo,
1308 ctx->recv_pkt != skb ||
1309 !sk_psock_queue_empty(psock),
1310 &wait);
1311 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1312 remove_wait_queue(sk_sleep(sk), &wait);
1313
1314
1315 if (signal_pending(current)) {
1316 *err = sock_intr_errno(timeo);
1317 return NULL;
1318 }
1319 }
1320
1321 return skb;
1322 }
1323
1324 static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
1325 int length, int *pages_used,
1326 unsigned int *size_used,
1327 struct scatterlist *to,
1328 int to_max_pages)
1329 {
1330 int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1331 struct page *pages[MAX_SKB_FRAGS];
1332 unsigned int size = *size_used;
1333 ssize_t copied, use;
1334 size_t offset;
1335
1336 while (length > 0) {
1337 i = 0;
1338 maxpages = to_max_pages - num_elem;
1339 if (maxpages == 0) {
1340 rc = -EFAULT;
1341 goto out;
1342 }
1343 copied = iov_iter_get_pages(from, pages,
1344 length,
1345 maxpages, &offset);
1346 if (copied <= 0) {
1347 rc = -EFAULT;
1348 goto out;
1349 }
1350
1351 iov_iter_advance(from, copied);
1352
1353 length -= copied;
1354 size += copied;
1355 while (copied) {
1356 use = min_t(int, copied, PAGE_SIZE - offset);
1357
1358 sg_set_page(&to[num_elem],
1359 pages[i], use, offset);
1360 sg_unmark_end(&to[num_elem]);
1361
1362
1363 offset = 0;
1364 copied -= use;
1365
1366 i++;
1367 num_elem++;
1368 }
1369 }
1370
1371 if (num_elem > *pages_used)
1372 sg_mark_end(&to[num_elem - 1]);
1373 out:
1374 if (rc)
1375 iov_iter_revert(from, size - *size_used);
1376 *size_used = size;
1377 *pages_used = num_elem;
1378
1379 return rc;
1380 }
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1391 struct iov_iter *out_iov,
1392 struct scatterlist *out_sg,
1393 int *chunk, bool *zc, bool async)
1394 {
1395 struct tls_context *tls_ctx = tls_get_ctx(sk);
1396 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1397 struct tls_prot_info *prot = &tls_ctx->prot_info;
1398 struct strp_msg *rxm = strp_msg(skb);
1399 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
1400 struct aead_request *aead_req;
1401 struct sk_buff *unused;
1402 u8 *aad, *iv, *mem = NULL;
1403 struct scatterlist *sgin = NULL;
1404 struct scatterlist *sgout = NULL;
1405 const int data_len = rxm->full_len - prot->overhead_size +
1406 prot->tail_size;
1407 int iv_offset = 0;
1408
1409 if (*zc && (out_iov || out_sg)) {
1410 if (out_iov)
1411 n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
1412 else
1413 n_sgout = sg_nents(out_sg);
1414 n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1415 rxm->full_len - prot->prepend_size);
1416 } else {
1417 n_sgout = 0;
1418 *zc = false;
1419 n_sgin = skb_cow_data(skb, 0, &unused);
1420 }
1421
1422 if (n_sgin < 1)
1423 return -EBADMSG;
1424
1425
1426 n_sgin = n_sgin + 1;
1427
1428 nsg = n_sgin + n_sgout;
1429
1430 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1431 mem_size = aead_size + (nsg * sizeof(struct scatterlist));
1432 mem_size = mem_size + prot->aad_size;
1433 mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
1434
1435
1436
1437
1438
1439 mem = kmalloc(mem_size, sk->sk_allocation);
1440 if (!mem)
1441 return -ENOMEM;
1442
1443
1444 aead_req = (struct aead_request *)mem;
1445 sgin = (struct scatterlist *)(mem + aead_size);
1446 sgout = sgin + n_sgin;
1447 aad = (u8 *)(sgout + n_sgout);
1448 iv = aad + prot->aad_size;
1449
1450
1451 if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
1452 iv[0] = 2;
1453 iv_offset = 1;
1454 }
1455
1456
1457 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1458 iv + iv_offset + prot->salt_size,
1459 prot->iv_size);
1460 if (err < 0) {
1461 kfree(mem);
1462 return err;
1463 }
1464 if (prot->version == TLS_1_3_VERSION)
1465 memcpy(iv + iv_offset, tls_ctx->rx.iv,
1466 crypto_aead_ivsize(ctx->aead_recv));
1467 else
1468 memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
1469
1470 xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
1471
1472
1473 tls_make_aad(aad, rxm->full_len - prot->overhead_size +
1474 prot->tail_size,
1475 tls_ctx->rx.rec_seq, prot->rec_seq_size,
1476 ctx->control, prot->version);
1477
1478
1479 sg_init_table(sgin, n_sgin);
1480 sg_set_buf(&sgin[0], aad, prot->aad_size);
1481 err = skb_to_sgvec(skb, &sgin[1],
1482 rxm->offset + prot->prepend_size,
1483 rxm->full_len - prot->prepend_size);
1484 if (err < 0) {
1485 kfree(mem);
1486 return err;
1487 }
1488
1489 if (n_sgout) {
1490 if (out_iov) {
1491 sg_init_table(sgout, n_sgout);
1492 sg_set_buf(&sgout[0], aad, prot->aad_size);
1493
1494 *chunk = 0;
1495 err = tls_setup_from_iter(sk, out_iov, data_len,
1496 &pages, chunk, &sgout[1],
1497 (n_sgout - 1));
1498 if (err < 0)
1499 goto fallback_to_reg_recv;
1500 } else if (out_sg) {
1501 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1502 } else {
1503 goto fallback_to_reg_recv;
1504 }
1505 } else {
1506 fallback_to_reg_recv:
1507 sgout = sgin;
1508 pages = 0;
1509 *chunk = data_len;
1510 *zc = false;
1511 }
1512
1513
1514 err = tls_do_decryption(sk, skb, sgin, sgout, iv,
1515 data_len, aead_req, async);
1516 if (err == -EINPROGRESS)
1517 return err;
1518
1519
1520 for (; pages > 0; pages--)
1521 put_page(sg_page(&sgout[pages]));
1522
1523 kfree(mem);
1524 return err;
1525 }
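
A note on the single kmalloc() in decrypt_internal() above: the pointer arithmetic carves one contiguous block into the pieces sketched below (this is a restatement of the code, not an additional structure):

/*
 * mem ->  +-----------------------+
 *         | struct aead_request   |  aead_size bytes (incl. tfm reqsize)
 *         +-----------------------+
 *         | sgin[n_sgin]          |  scatterlist for AAD + ciphertext
 *         +-----------------------+
 *         | sgout[n_sgout]        |  scatterlist for AAD + plaintext
 *         +-----------------------+
 *         | aad                   |  prot->aad_size bytes
 *         +-----------------------+
 *         | iv                    |  crypto_aead_ivsize(ctx->aead_recv) bytes
 *         +-----------------------+
 */
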
1526
1527 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
1528 struct iov_iter *dest, int *chunk, bool *zc,
1529 bool async)
1530 {
1531 struct tls_context *tls_ctx = tls_get_ctx(sk);
1532 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1533 struct tls_prot_info *prot = &tls_ctx->prot_info;
1534 struct strp_msg *rxm = strp_msg(skb);
1535 int pad, err = 0;
1536
1537 if (!ctx->decrypted) {
1538 if (tls_ctx->rx_conf == TLS_HW) {
1539 err = tls_device_decrypted(sk, skb);
1540 if (err < 0)
1541 return err;
1542 }
1543
1544
1545 if (!ctx->decrypted) {
1546 err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
1547 async);
1548 if (err < 0) {
1549 if (err == -EINPROGRESS)
1550 tls_advance_record_sn(sk, prot,
1551 &tls_ctx->rx);
1552
1553 return err;
1554 }
1555 } else {
1556 *zc = false;
1557 }
1558
1559 pad = padding_length(ctx, prot, skb);
1560 if (pad < 0)
1561 return pad;
1562
1563 rxm->full_len -= pad;
1564 rxm->offset += prot->prepend_size;
1565 rxm->full_len -= prot->overhead_size;
1566 tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1567 ctx->decrypted = true;
1568 ctx->saved_data_ready(sk);
1569 } else {
1570 *zc = false;
1571 }
1572
1573 return err;
1574 }
1575
1576 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
1577 struct scatterlist *sgout)
1578 {
1579 bool zc = true;
1580 int chunk;
1581
1582 return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
1583 }
1584
1585 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
1586 unsigned int len)
1587 {
1588 struct tls_context *tls_ctx = tls_get_ctx(sk);
1589 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1590
1591 if (skb) {
1592 struct strp_msg *rxm = strp_msg(skb);
1593
1594 if (len < rxm->full_len) {
1595 rxm->offset += len;
1596 rxm->full_len -= len;
1597 return false;
1598 }
1599 consume_skb(skb);
1600 }
1601
1602
1603 ctx->recv_pkt = NULL;
1604 __strp_unpause(&ctx->strp);
1605
1606 return true;
1607 }
1608
1609
1610
1611
1612
1613
1614 static int process_rx_list(struct tls_sw_context_rx *ctx,
1615 struct msghdr *msg,
1616 u8 *control,
1617 bool *cmsg,
1618 size_t skip,
1619 size_t len,
1620 bool zc,
1621 bool is_peek)
1622 {
1623 struct sk_buff *skb = skb_peek(&ctx->rx_list);
1624 u8 ctrl = *control;
1625 u8 msgc = *cmsg;
1626 struct tls_msg *tlm;
1627 ssize_t copied = 0;
1628
1629
1630 if (!ctrl && skb) {
1631 tlm = tls_msg(skb);
1632 ctrl = tlm->control;
1633 }
1634
1635 while (skip && skb) {
1636 struct strp_msg *rxm = strp_msg(skb);
1637 tlm = tls_msg(skb);
1638
1639
1640 if (ctrl != tlm->control)
1641 return 0;
1642
1643 if (skip < rxm->full_len)
1644 break;
1645
1646 skip = skip - rxm->full_len;
1647 skb = skb_peek_next(skb, &ctx->rx_list);
1648 }
1649
1650 while (len && skb) {
1651 struct sk_buff *next_skb;
1652 struct strp_msg *rxm = strp_msg(skb);
1653 int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1654
1655 tlm = tls_msg(skb);
1656
1657
1658 if (ctrl != tlm->control)
1659 return 0;
1660
1661
1662
1663
1664 if (!msgc) {
1665 int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1666 sizeof(ctrl), &ctrl);
1667 msgc = true;
1668 if (ctrl != TLS_RECORD_TYPE_DATA) {
1669 if (cerr || msg->msg_flags & MSG_CTRUNC)
1670 return -EIO;
1671
1672 *cmsg = msgc;
1673 }
1674 }
1675
1676 if (!zc || (rxm->full_len - skip) > len) {
1677 int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1678 msg, chunk);
1679 if (err < 0)
1680 return err;
1681 }
1682
1683 len = len - chunk;
1684 copied = copied + chunk;
1685
1686
1687 if (!is_peek) {
1688 rxm->offset = rxm->offset + chunk;
1689 rxm->full_len = rxm->full_len - chunk;
1690
1691
1692 if (rxm->full_len - skip)
1693 break;
1694 }
1695
1696
1697
1698
1699 skip = 0;
1700
1701 if (msg)
1702 msg->msg_flags |= MSG_EOR;
1703
1704 next_skb = skb_peek_next(skb, &ctx->rx_list);
1705
1706 if (!is_peek) {
1707 skb_unlink(skb, &ctx->rx_list);
1708 consume_skb(skb);
1709 }
1710
1711 skb = next_skb;
1712 }
1713
1714 *control = ctrl;
1715 return copied;
1716 }
1717
1718 int tls_sw_recvmsg(struct sock *sk,
1719 struct msghdr *msg,
1720 size_t len,
1721 int nonblock,
1722 int flags,
1723 int *addr_len)
1724 {
1725 struct tls_context *tls_ctx = tls_get_ctx(sk);
1726 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1727 struct tls_prot_info *prot = &tls_ctx->prot_info;
1728 struct sk_psock *psock;
1729 unsigned char control = 0;
1730 ssize_t decrypted = 0;
1731 struct strp_msg *rxm;
1732 struct tls_msg *tlm;
1733 struct sk_buff *skb;
1734 ssize_t copied = 0;
1735 bool cmsg = false;
1736 int target, err = 0;
1737 long timeo;
1738 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1739 bool is_peek = flags & MSG_PEEK;
1740 int num_async = 0;
1741 int pending;
1742
1743 flags |= nonblock;
1744
1745 if (unlikely(flags & MSG_ERRQUEUE))
1746 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1747
1748 psock = sk_psock_get(sk);
1749 lock_sock(sk);
1750
1751
1752 err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
1753 is_peek);
1754 if (err < 0) {
1755 tls_err_abort(sk, err);
1756 goto end;
1757 } else {
1758 copied = err;
1759 }
1760
1761 if (len <= copied)
1762 goto recv_end;
1763
1764 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1765 len = len - copied;
1766 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1767
1768 while (len && (decrypted + copied < target || ctx->recv_pkt)) {
1769 bool retain_skb = false;
1770 bool zc = false;
1771 int to_decrypt;
1772 int chunk = 0;
1773 bool async_capable;
1774 bool async = false;
1775
1776 skb = tls_wait_data(sk, psock, flags, timeo, &err);
1777 if (!skb) {
1778 if (psock) {
1779 int ret = __tcp_bpf_recvmsg(sk, psock,
1780 msg, len, flags);
1781
1782 if (ret > 0) {
1783 decrypted += ret;
1784 len -= ret;
1785 continue;
1786 }
1787 }
1788 goto recv_end;
1789 } else {
1790 tlm = tls_msg(skb);
1791 if (prot->version == TLS_1_3_VERSION)
1792 tlm->control = 0;
1793 else
1794 tlm->control = ctx->control;
1795 }
1796
1797 rxm = strp_msg(skb);
1798
1799 to_decrypt = rxm->full_len - prot->overhead_size;
1800
1801 if (to_decrypt <= len && !is_kvec && !is_peek &&
1802 ctx->control == TLS_RECORD_TYPE_DATA &&
1803 prot->version != TLS_1_3_VERSION)
1804 zc = true;
1805
1806
1807 if (ctx->control == TLS_RECORD_TYPE_DATA)
1808 async_capable = ctx->async_capable;
1809 else
1810 async_capable = false;
1811
1812 err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1813 &chunk, &zc, async_capable);
1814 if (err < 0 && err != -EINPROGRESS) {
1815 tls_err_abort(sk, EBADMSG);
1816 goto recv_end;
1817 }
1818
1819 if (err == -EINPROGRESS) {
1820 async = true;
1821 num_async++;
1822 } else if (prot->version == TLS_1_3_VERSION) {
1823 tlm->control = ctx->control;
1824 }
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834 if (!control)
1835 control = tlm->control;
1836 else if (control != tlm->control)
1837 goto recv_end;
1838
1839 if (!cmsg) {
1840 int cerr;
1841
1842 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1843 sizeof(control), &control);
1844 cmsg = true;
1845 if (control != TLS_RECORD_TYPE_DATA) {
1846 if (cerr || msg->msg_flags & MSG_CTRUNC) {
1847 err = -EIO;
1848 goto recv_end;
1849 }
1850 }
1851 }
1852
1853 if (async)
1854 goto pick_next_record;
1855
1856 if (!zc) {
1857 if (rxm->full_len > len) {
1858 retain_skb = true;
1859 chunk = len;
1860 } else {
1861 chunk = rxm->full_len;
1862 }
1863
1864 err = skb_copy_datagram_msg(skb, rxm->offset,
1865 msg, chunk);
1866 if (err < 0)
1867 goto recv_end;
1868
1869 if (!is_peek) {
1870 rxm->offset = rxm->offset + chunk;
1871 rxm->full_len = rxm->full_len - chunk;
1872 }
1873 }
1874
1875 pick_next_record:
1876 if (chunk > len)
1877 chunk = len;
1878
1879 decrypted += chunk;
1880 len -= chunk;
1881
1882
1883 if (async || is_peek || retain_skb) {
1884 skb_queue_tail(&ctx->rx_list, skb);
1885 skb = NULL;
1886 }
1887
1888 if (tls_sw_advance_skb(sk, skb, chunk)) {
1889
1890
1891
1892
1893 msg->msg_flags |= MSG_EOR;
1894 if (ctx->control != TLS_RECORD_TYPE_DATA)
1895 goto recv_end;
1896 } else {
1897 break;
1898 }
1899 }
1900
1901 recv_end:
1902 if (num_async) {
1903
1904 spin_lock_bh(&ctx->decrypt_compl_lock);
1905 ctx->async_notify = true;
1906 pending = atomic_read(&ctx->decrypt_pending);
1907 spin_unlock_bh(&ctx->decrypt_compl_lock);
1908 if (pending) {
1909 err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1910 if (err) {
1911
1912 tls_err_abort(sk, err);
1913 copied = 0;
1914 decrypted = 0;
1915 goto end;
1916 }
1917 } else {
1918 reinit_completion(&ctx->async_wait.completion);
1919 }
1920
1921
1922
1923
1924 WRITE_ONCE(ctx->async_notify, false);
1925
1926
1927 if (is_peek || is_kvec)
1928 err = process_rx_list(ctx, msg, &control, &cmsg, copied,
1929 decrypted, false, is_peek);
1930 else
1931 err = process_rx_list(ctx, msg, &control, &cmsg, 0,
1932 decrypted, true, is_peek);
1933 if (err < 0) {
1934 tls_err_abort(sk, err);
1935 copied = 0;
1936 goto end;
1937 }
1938 }
1939
1940 copied += decrypted;
1941
1942 end:
1943 release_sock(sk);
1944 if (psock)
1945 sk_psock_put(sk, psock);
1946 return copied ? : err;
1947 }
1948
1949 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
1950 struct pipe_inode_info *pipe,
1951 size_t len, unsigned int flags)
1952 {
1953 struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
1954 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1955 struct strp_msg *rxm = NULL;
1956 struct sock *sk = sock->sk;
1957 struct sk_buff *skb;
1958 ssize_t copied = 0;
1959 int err = 0;
1960 long timeo;
1961 int chunk;
1962 bool zc = false;
1963
1964 lock_sock(sk);
1965
1966 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1967
1968 skb = tls_wait_data(sk, NULL, flags, timeo, &err);
1969 if (!skb)
1970 goto splice_read_end;
1971
1972 if (!ctx->decrypted) {
1973 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
1974
1975
1976 if (ctx->control != TLS_RECORD_TYPE_DATA) {
1977 err = -EINVAL;
1978 goto splice_read_end;
1979 }
1980
1981 if (err < 0) {
1982 tls_err_abort(sk, EBADMSG);
1983 goto splice_read_end;
1984 }
1985 ctx->decrypted = true;
1986 }
1987 rxm = strp_msg(skb);
1988
1989 chunk = min_t(unsigned int, rxm->full_len, len);
1990 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
1991 if (copied < 0)
1992 goto splice_read_end;
1993
1994 if (likely(!(flags & MSG_PEEK)))
1995 tls_sw_advance_skb(sk, skb, copied);
1996
1997 splice_read_end:
1998 release_sock(sk);
1999 return copied ? : err;
2000 }
2001
2002 bool tls_sw_stream_read(const struct sock *sk)
2003 {
2004 struct tls_context *tls_ctx = tls_get_ctx(sk);
2005 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2006 bool ingress_empty = true;
2007 struct sk_psock *psock;
2008
2009 rcu_read_lock();
2010 psock = sk_psock(sk);
2011 if (psock)
2012 ingress_empty = list_empty(&psock->ingress_msg);
2013 rcu_read_unlock();
2014
2015 return !ingress_empty || ctx->recv_pkt ||
2016 !skb_queue_empty(&ctx->rx_list);
2017 }
2018
2019 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
2020 {
2021 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2022 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2023 struct tls_prot_info *prot = &tls_ctx->prot_info;
2024 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2025 struct strp_msg *rxm = strp_msg(skb);
2026 size_t cipher_overhead;
2027 size_t data_len = 0;
2028 int ret;
2029
2030
2031 if (rxm->offset + prot->prepend_size > skb->len)
2032 return 0;
2033
2034
2035 if (WARN_ON(prot->prepend_size > sizeof(header))) {
2036 ret = -EINVAL;
2037 goto read_failure;
2038 }
2039
2040
2041 ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
2042
2043 if (ret < 0)
2044 goto read_failure;
2045
2046 ctx->control = header[0];
2047
2048 data_len = ((header[4] & 0xFF) | (header[3] << 8));
2049
2050 cipher_overhead = prot->tag_size;
2051 if (prot->version != TLS_1_3_VERSION)
2052 cipher_overhead += prot->iv_size;
2053
2054 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2055 prot->tail_size) {
2056 ret = -EMSGSIZE;
2057 goto read_failure;
2058 }
2059 if (data_len < cipher_overhead) {
2060 ret = -EBADMSG;
2061 goto read_failure;
2062 }
2063
2064
2065 if (header[1] != TLS_1_2_VERSION_MINOR ||
2066 header[2] != TLS_1_2_VERSION_MAJOR) {
2067 ret = -EINVAL;
2068 goto read_failure;
2069 }
2070
2071 tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2072 TCP_SKB_CB(skb)->seq + rxm->offset);
2073 return data_len + TLS_HEADER_SIZE;
2074
2075 read_failure:
2076 tls_err_abort(strp->sk, ret);
2077
2078 return ret;
2079 }
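
For reference, the 5-byte record header that tls_read_size() parses has the layout noted below; the helper is a hedged, standalone restatement of the length calculation, not a kernel function.

/* TLS record header on the wire:
 *   byte 0     ContentType                   -> stored in ctx->control
 *   bytes 1-2  legacy record version         -> must read 0x03 0x03 (TLS 1.2)
 *   bytes 3-4  ciphertext length, big endian
 */
static unsigned int tls_record_payload_len(const unsigned char hdr[5])
{
	return ((unsigned int)hdr[3] << 8) | hdr[4];
}
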
2080
2081 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
2082 {
2083 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2084 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2085
2086 ctx->decrypted = false;
2087
2088 ctx->recv_pkt = skb;
2089 strp_pause(strp);
2090
2091 ctx->saved_data_ready(strp->sk);
2092 }
2093
2094 static void tls_data_ready(struct sock *sk)
2095 {
2096 struct tls_context *tls_ctx = tls_get_ctx(sk);
2097 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2098 struct sk_psock *psock;
2099
2100 strp_data_ready(&ctx->strp);
2101
2102 psock = sk_psock_get(sk);
2103 if (psock) {
2104 if (!list_empty(&psock->ingress_msg))
2105 ctx->saved_data_ready(sk);
2106 sk_psock_put(sk, psock);
2107 }
2108 }
2109
2110 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2111 {
2112 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2113
2114 set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2115 set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2116 cancel_delayed_work_sync(&ctx->tx_work.work);
2117 }
2118
2119 void tls_sw_release_resources_tx(struct sock *sk)
2120 {
2121 struct tls_context *tls_ctx = tls_get_ctx(sk);
2122 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2123 struct tls_rec *rec, *tmp;
2124
2125
2126 smp_store_mb(ctx->async_notify, true);
2127 if (atomic_read(&ctx->encrypt_pending))
2128 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2129
2130 tls_tx_records(sk, -1);
2131
2132
2133
2134
2135 if (tls_ctx->partially_sent_record) {
2136 tls_free_partial_record(sk, tls_ctx);
2137 rec = list_first_entry(&ctx->tx_list,
2138 struct tls_rec, list);
2139 list_del(&rec->list);
2140 sk_msg_free(sk, &rec->msg_plaintext);
2141 kfree(rec);
2142 }
2143
2144 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2145 list_del(&rec->list);
2146 sk_msg_free(sk, &rec->msg_encrypted);
2147 sk_msg_free(sk, &rec->msg_plaintext);
2148 kfree(rec);
2149 }
2150
2151 crypto_free_aead(ctx->aead_send);
2152 tls_free_open_rec(sk);
2153 }
2154
2155 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2156 {
2157 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2158
2159 kfree(ctx);
2160 }
2161
2162 void tls_sw_release_resources_rx(struct sock *sk)
2163 {
2164 struct tls_context *tls_ctx = tls_get_ctx(sk);
2165 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2166
2167 kfree(tls_ctx->rx.rec_seq);
2168 kfree(tls_ctx->rx.iv);
2169
2170 if (ctx->aead_recv) {
2171 kfree_skb(ctx->recv_pkt);
2172 ctx->recv_pkt = NULL;
2173 skb_queue_purge(&ctx->rx_list);
2174 crypto_free_aead(ctx->aead_recv);
2175 strp_stop(&ctx->strp);
2176
2177
2178
2179
2180 if (ctx->saved_data_ready) {
2181 write_lock_bh(&sk->sk_callback_lock);
2182 sk->sk_data_ready = ctx->saved_data_ready;
2183 write_unlock_bh(&sk->sk_callback_lock);
2184 }
2185 }
2186 }
2187
2188 void tls_sw_strparser_done(struct tls_context *tls_ctx)
2189 {
2190 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2191
2192 strp_done(&ctx->strp);
2193 }
2194
2195 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2196 {
2197 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2198
2199 kfree(ctx);
2200 }
2201
2202 void tls_sw_free_resources_rx(struct sock *sk)
2203 {
2204 struct tls_context *tls_ctx = tls_get_ctx(sk);
2205
2206 tls_sw_release_resources_rx(sk);
2207 tls_sw_free_ctx_rx(tls_ctx);
2208 }
2209
2210
2211 static void tx_work_handler(struct work_struct *work)
2212 {
2213 struct delayed_work *delayed_work = to_delayed_work(work);
2214 struct tx_work *tx_work = container_of(delayed_work,
2215 struct tx_work, work);
2216 struct sock *sk = tx_work->sk;
2217 struct tls_context *tls_ctx = tls_get_ctx(sk);
2218 struct tls_sw_context_tx *ctx;
2219
2220 if (unlikely(!tls_ctx))
2221 return;
2222
2223 ctx = tls_sw_ctx_tx(tls_ctx);
2224 if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2225 return;
2226
2227 if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2228 return;
2229 mutex_lock(&tls_ctx->tx_lock);
2230 lock_sock(sk);
2231 tls_tx_records(sk, -1);
2232 release_sock(sk);
2233 mutex_unlock(&tls_ctx->tx_lock);
2234 }
2235
2236 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2237 {
2238 struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2239
2240
2241 if (is_tx_ready(tx_ctx) &&
2242 !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2243 schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2244 }
2245
2246 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2247 {
2248 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2249
2250 write_lock_bh(&sk->sk_callback_lock);
2251 rx_ctx->saved_data_ready = sk->sk_data_ready;
2252 sk->sk_data_ready = tls_data_ready;
2253 write_unlock_bh(&sk->sk_callback_lock);
2254
2255 strp_check_rcv(&rx_ctx->strp);
2256 }
2257
2258 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2259 {
2260 struct tls_context *tls_ctx = tls_get_ctx(sk);
2261 struct tls_prot_info *prot = &tls_ctx->prot_info;
2262 struct tls_crypto_info *crypto_info;
2263 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2264 struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2265 struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2266 struct tls_sw_context_tx *sw_ctx_tx = NULL;
2267 struct tls_sw_context_rx *sw_ctx_rx = NULL;
2268 struct cipher_context *cctx;
2269 struct crypto_aead **aead;
2270 struct strp_callbacks cb;
2271 u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2272 struct crypto_tfm *tfm;
2273 char *iv, *rec_seq, *key, *salt, *cipher_name;
2274 size_t keysize;
2275 int rc = 0;
2276
2277 if (!ctx) {
2278 rc = -EINVAL;
2279 goto out;
2280 }
2281
2282 if (tx) {
2283 if (!ctx->priv_ctx_tx) {
2284 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2285 if (!sw_ctx_tx) {
2286 rc = -ENOMEM;
2287 goto out;
2288 }
2289 ctx->priv_ctx_tx = sw_ctx_tx;
2290 } else {
2291 sw_ctx_tx =
2292 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2293 }
2294 } else {
2295 if (!ctx->priv_ctx_rx) {
2296 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2297 if (!sw_ctx_rx) {
2298 rc = -ENOMEM;
2299 goto out;
2300 }
2301 ctx->priv_ctx_rx = sw_ctx_rx;
2302 } else {
2303 sw_ctx_rx =
2304 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2305 }
2306 }
2307
2308 if (tx) {
2309 crypto_init_wait(&sw_ctx_tx->async_wait);
2310 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2311 crypto_info = &ctx->crypto_send.info;
2312 cctx = &ctx->tx;
2313 aead = &sw_ctx_tx->aead_send;
2314 INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2315 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2316 sw_ctx_tx->tx_work.sk = sk;
2317 } else {
2318 crypto_init_wait(&sw_ctx_rx->async_wait);
2319 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2320 crypto_info = &ctx->crypto_recv.info;
2321 cctx = &ctx->rx;
2322 skb_queue_head_init(&sw_ctx_rx->rx_list);
2323 aead = &sw_ctx_rx->aead_recv;
2324 }
2325
2326 switch (crypto_info->cipher_type) {
2327 case TLS_CIPHER_AES_GCM_128: {
2328 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2329 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2330 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2331 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
2332 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2333 rec_seq =
2334 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
2335 gcm_128_info =
2336 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
2337 keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2338 key = gcm_128_info->key;
2339 salt = gcm_128_info->salt;
2340 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2341 cipher_name = "gcm(aes)";
2342 break;
2343 }
2344 case TLS_CIPHER_AES_GCM_256: {
2345 nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2346 tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2347 iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2348 iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
2349 rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2350 rec_seq =
2351 ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
2352 gcm_256_info =
2353 (struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
2354 keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2355 key = gcm_256_info->key;
2356 salt = gcm_256_info->salt;
2357 salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2358 cipher_name = "gcm(aes)";
2359 break;
2360 }
2361 case TLS_CIPHER_AES_CCM_128: {
2362 nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2363 tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2364 iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2365 iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
2366 rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2367 rec_seq =
2368 ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
2369 ccm_128_info =
2370 (struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
2371 keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2372 key = ccm_128_info->key;
2373 salt = ccm_128_info->salt;
2374 salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2375 cipher_name = "ccm(aes)";
2376 break;
2377 }
2378 default:
2379 rc = -EINVAL;
2380 goto free_priv;
2381 }
2382
2383
2384 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2385 rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
2386 rc = -EINVAL;
2387 goto free_priv;
2388 }
2389
2390 if (crypto_info->version == TLS_1_3_VERSION) {
2391 nonce_size = 0;
2392 prot->aad_size = TLS_HEADER_SIZE;
2393 prot->tail_size = 1;
2394 } else {
2395 prot->aad_size = TLS_AAD_SPACE_SIZE;
2396 prot->tail_size = 0;
2397 }
2398
2399 prot->version = crypto_info->version;
2400 prot->cipher_type = crypto_info->cipher_type;
2401 prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2402 prot->tag_size = tag_size;
2403 prot->overhead_size = prot->prepend_size +
2404 prot->tag_size + prot->tail_size;
2405 prot->iv_size = iv_size;
2406 prot->salt_size = salt_size;
2407 cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2408 if (!cctx->iv) {
2409 rc = -ENOMEM;
2410 goto free_priv;
2411 }
2412
2413 prot->rec_seq_size = rec_seq_size;
2414 memcpy(cctx->iv, salt, salt_size);
2415 memcpy(cctx->iv + salt_size, iv, iv_size);
2416 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2417 if (!cctx->rec_seq) {
2418 rc = -ENOMEM;
2419 goto free_iv;
2420 }
2421
2422 if (!*aead) {
2423 *aead = crypto_alloc_aead(cipher_name, 0, 0);
2424 if (IS_ERR(*aead)) {
2425 rc = PTR_ERR(*aead);
2426 *aead = NULL;
2427 goto free_rec_seq;
2428 }
2429 }
2430
2431 ctx->push_pending_record = tls_sw_push_pending_record;
2432
2433 rc = crypto_aead_setkey(*aead, key, keysize);
2434
2435 if (rc)
2436 goto free_aead;
2437
2438 rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2439 if (rc)
2440 goto free_aead;
2441
2442 if (sw_ctx_rx) {
2443 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2444
2445 if (crypto_info->version == TLS_1_3_VERSION)
2446 sw_ctx_rx->async_capable = false;
2447 else
2448 sw_ctx_rx->async_capable =
2449 tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
2450
2451
2452 memset(&cb, 0, sizeof(cb));
2453 cb.rcv_msg = tls_queue;
2454 cb.parse_msg = tls_read_size;
2455
2456 strp_init(&sw_ctx_rx->strp, sk, &cb);
2457 }
2458
2459 goto out;
2460
2461 free_aead:
2462 crypto_free_aead(*aead);
2463 *aead = NULL;
2464 free_rec_seq:
2465 kfree(cctx->rec_seq);
2466 cctx->rec_seq = NULL;
2467 free_iv:
2468 kfree(cctx->iv);
2469 cctx->iv = NULL;
2470 free_priv:
2471 if (tx) {
2472 kfree(ctx->priv_ctx_tx);
2473 ctx->priv_ctx_tx = NULL;
2474 } else {
2475 kfree(ctx->priv_ctx_rx);
2476 ctx->priv_ctx_rx = NULL;
2477 }
2478 out:
2479 return rc;
2480 }
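
For context, tls_set_sw_offload() is reached when userspace installs the "tls" ULP on a connected TCP socket and then sets TLS_TX/TLS_RX key material on it. Below is a minimal userspace sketch of that sequence, following the pattern in the kernel's TLS documentation; the helper name enable_ktls_tx() and its argument layout are invented for the example, error handling is minimal, and header availability (TCP_ULP, SOL_TLS) depends on the installed libc and kernel headers.

#include <string.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

/* Hand an already-handshaken TCP socket to kernel TLS for transmit.
 * The key material must come from the userspace TLS handshake; the
 * TLS_RX direction is set up the same way and reaches
 * tls_set_sw_offload(sk, ctx, 0) instead of (sk, ctx, 1).
 */
static int enable_ktls_tx(int fd, const unsigned char key[16],
			  const unsigned char iv[8],
			  const unsigned char salt[4],
			  const unsigned char rec_seq[8])
{
	struct tls12_crypto_info_aes_gcm_128 ci;

	memset(&ci, 0, sizeof(ci));
	ci.info.version = TLS_1_2_VERSION;
	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

	if (setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")))
		return -1;
	/* With no NIC offload this lands in tls_set_sw_offload(sk, ctx, 1). */
	return setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
}
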