/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <crypto/aead.h>
#include <linux/highmem.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/dst.h>
#include <net/inet_connection_sock.h>
#include <net/tcp.h>
#include <net/tls.h>
#include <linux/skbuff_ref.h>

#include "tls.h"
#include "trace.h"

/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications.
 */
static DECLARE_RWSEM(device_offload_lock);

static struct workqueue_struct *destruct_wq __read_mostly;

static LIST_HEAD(tls_device_list);
static LIST_HEAD(tls_device_down_list);
static DEFINE_SPINLOCK(tls_device_lock);

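/* Placeholder page for the authentication tag: when the socket's page_frag
 * cannot be refilled in tls_device_record_close(), the tag bytes point here.
 * The device overwrites the tag on transmit, so the contents don't matter.
 */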
static struct page *dummy_page;

static void tls_device_free_ctx(struct tls_context *ctx)
{
	if (ctx->tx_conf == TLS_HW)
		kfree(tls_offload_ctx_tx(ctx));

	if (ctx->rx_conf == TLS_HW)
		kfree(tls_offload_ctx_rx(ctx));

	tls_ctx_free(NULL, ctx);
}

static void tls_device_tx_del_task(struct work_struct *work)
{
	struct tls_offload_context_tx *offload_ctx =
		container_of(work, struct tls_offload_context_tx, destruct_work);
	struct tls_context *ctx = offload_ctx->ctx;
	struct net_device *netdev;

	/* Safe, because this is the destroy flow, refcount is 0, so
	 * tls_device_down can't store this field in parallel.
	 */
	netdev = rcu_dereference_protected(ctx->netdev,
					   !refcount_read(&ctx->refcount));

	netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
	dev_put(netdev);
	ctx->netdev = NULL;
	tls_device_free_ctx(ctx);
}

static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
	struct net_device *netdev;
	unsigned long flags;
	bool async_cleanup;

	spin_lock_irqsave(&tls_device_lock, flags);
	if (unlikely(!refcount_dec_and_test(&ctx->refcount))) {
		spin_unlock_irqrestore(&tls_device_lock, flags);
		return;
	}

	list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */

	/* Safe, because this is the destroy flow, refcount is 0, so
	 * tls_device_down can't store this field in parallel.
	 */
	netdev = rcu_dereference_protected(ctx->netdev,
					   !refcount_read(&ctx->refcount));

	async_cleanup = netdev && ctx->tx_conf == TLS_HW;
	if (async_cleanup) {
		struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx);

		/* queue_work inside the spinlock
		 * to make sure tls_device_down waits for that work.
		 */
		queue_work(destruct_wq, &offload_ctx->destruct_work);
	}
	spin_unlock_irqrestore(&tls_device_lock, flags);

	if (!async_cleanup)
		tls_device_free_ctx(ctx);
}

/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
	struct dst_entry *dst = sk_dst_get(sk);
	struct net_device *netdev = NULL;

	if (likely(dst)) {
		netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
		dev_hold(netdev);
	}

	dst_release(dst);

	return netdev;
}

static void destroy_record(struct tls_record_info *record)
{
	int i;

	for (i = 0; i < record->num_frags; i++)
		__skb_frag_unref(&record->frags[i], false);
	kfree(record);
}

static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
{
	struct tls_record_info *info, *temp;

	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
		list_del(&info->list);
		destroy_record(info);
	}

	offload_ctx->retransmit_hint = NULL;
}

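/* clean_acked_data hook, invoked from TCP ACK processing: records whose
 * end_seq is now fully acknowledged can no longer be retransmitted, so free
 * them and advance unacked_record_sn accordingly.
 */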
static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_record_info *info, *temp;
	struct tls_offload_context_tx *ctx;
	u64 deleted_records = 0;
	unsigned long flags;

	if (!tls_ctx)
		return;

	ctx = tls_offload_ctx_tx(tls_ctx);

	spin_lock_irqsave(&ctx->lock, flags);
	info = ctx->retransmit_hint;
	if (info && !before(acked_seq, info->end_seq))
		ctx->retransmit_hint = NULL;

	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
		if (before(acked_seq, info->end_seq))
			break;
		list_del(&info->list);

		destroy_record(info);
		deleted_records++;
	}

	ctx->unacked_record_sn += deleted_records;
	spin_unlock_irqrestore(&ctx->lock, flags);
}

/* At this point, there should be no references on this
 * socket and no in-flight SKBs associated with this
 * socket, so it is safe to free all the resources.
 */
void tls_device_sk_destruct(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);

	tls_ctx->sk_destruct(sk);

	if (tls_ctx->tx_conf == TLS_HW) {
		if (ctx->open_record)
			destroy_record(ctx->open_record);
		delete_all_records(ctx);
		crypto_free_aead(ctx->aead_send);
		clean_acked_data_disable(inet_csk(sk));
	}

	tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL_GPL(tls_device_sk_destruct);

void tls_device_free_resources_tx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);

	tls_free_partial_record(sk, tls_ctx);
}

void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);

	trace_tls_device_tx_resync_req(sk, got_seq, exp_seq);
	WARN_ON(test_and_set_bit(TLS_TX_SYNC_SCHED, &tls_ctx->flags));
}
EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request);

static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
				 u32 seq)
{
	struct net_device *netdev;
	int err = 0;
	u8 *rcd_sn;

	tcp_write_collapse_fence(sk);
	rcd_sn = tls_ctx->tx.rec_seq;

	trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
	down_read(&device_offload_lock);
	netdev = rcu_dereference_protected(tls_ctx->netdev,
					   lockdep_is_held(&device_offload_lock));
	if (netdev)
		err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
							 rcd_sn,
							 TLS_OFFLOAD_CTX_DIR_TX);
	up_read(&device_offload_lock);
	if (err)
		return;

	clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
}

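/* Append @size bytes from @pfrag to the open record: extend the last frag
 * when the new chunk is contiguous with it in the same page, otherwise start
 * a new frag and take a reference on the page.
 */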
static void tls_append_frag(struct tls_record_info *record,
			    struct page_frag *pfrag,
			    int size)
{
	skb_frag_t *frag;

	frag = &record->frags[record->num_frags - 1];
	if (skb_frag_page(frag) == pfrag->page &&
	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
		skb_frag_size_add(frag, size);
	} else {
		++frag;
		skb_frag_fill_page_desc(frag, pfrag->page, pfrag->offset,
					size);
		++record->num_frags;
		get_page(pfrag->page);
	}

	pfrag->offset += size;
	record->len += size;
}

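/* Close out a record: link it into records_list, advance the record sequence
 * number, mirror its frags into the shared sg_tx_data table and hand them to
 * tls_push_sg() for transmission (the device performs the actual encryption).
 */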
static int tls_push_record(struct sock *sk,
			   struct tls_context *ctx,
			   struct tls_offload_context_tx *offload_ctx,
			   struct tls_record_info *record,
			   int flags)
{
	struct tls_prot_info *prot = &ctx->prot_info;
	struct tcp_sock *tp = tcp_sk(sk);
	skb_frag_t *frag;
	int i;

	record->end_seq = tp->write_seq + record->len;
	list_add_tail_rcu(&record->list, &offload_ctx->records_list);
	offload_ctx->open_record = NULL;

	if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
		tls_device_resync_tx(sk, ctx, tp->write_seq);

	tls_advance_record_sn(sk, prot, &ctx->tx);

	for (i = 0; i < record->num_frags; i++) {
		frag = &record->frags[i];
		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
			    skb_frag_size(frag), skb_frag_off(frag));
		sk_mem_charge(sk, skb_frag_size(frag));
		get_page(skb_frag_page(frag));
	}
	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);

	/* all ready, send */
	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}

static void tls_device_record_close(struct sock *sk,
				    struct tls_context *ctx,
				    struct tls_record_info *record,
				    struct page_frag *pfrag,
				    unsigned char record_type)
{
	struct tls_prot_info *prot = &ctx->prot_info;
	struct page_frag dummy_tag_frag;

	/* append tag
	 * device will fill in the tag, we just need to append a placeholder
	 * use socket memory to improve coalescing (re-using a single buffer
	 * increases frag count)
	 * if we can't allocate memory now use the dummy page
	 */
	if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) &&
	    !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) {
		dummy_tag_frag.page = dummy_page;
		dummy_tag_frag.offset = 0;
		pfrag = &dummy_tag_frag;
	}
	tls_append_frag(record, pfrag, prot->tag_size);

	/* fill prepend */
	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
			 record->len - prot->overhead_size,
			 record_type);
}

static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
				 struct page_frag *pfrag,
				 size_t prepend_size)
{
	struct tls_record_info *record;
	skb_frag_t *frag;

	record = kmalloc(sizeof(*record), GFP_KERNEL);
	if (!record)
		return -ENOMEM;

	frag = &record->frags[0];
	skb_frag_fill_page_desc(frag, pfrag->page, pfrag->offset,
				prepend_size);

	get_page(pfrag->page);
	pfrag->offset += prepend_size;

	record->num_frags = 1;
	record->len = prepend_size;
	offload_ctx->open_record = record;
	return 0;
}

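/* Make sure there is an open record and that the socket's page_frag has room
 * for payload: first reserve just prepend_size bytes (TLS header plus
 * explicit IV) for a new record, then refill the page_frag for data.
 */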
static int tls_do_allocation(struct sock *sk,
			     struct tls_offload_context_tx *offload_ctx,
			     struct page_frag *pfrag,
			     size_t prepend_size)
{
	int ret;

	if (!offload_ctx->open_record) {
		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
						   sk->sk_allocation))) {
			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
			sk_stream_moderate_sndbuf(sk);
			return -ENOMEM;
		}

		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
		if (ret)
			return ret;

		if (pfrag->size > pfrag->offset)
			return 0;
	}

	if (!sk_page_frag_refill(sk, pfrag))
		return -ENOMEM;

	return 0;
}

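/* Copy @bytes from the iterator into @addr in three steps: a regular copy of
 * the unaligned head up to the next cache line, a non-temporal (nocache) copy
 * of the cache-line-aligned bulk since the CPU won't read this plaintext
 * again (the NIC fetches it via DMA), and a regular copy of the tail.
 */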
static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
{
	size_t pre_copy, nocache;

	pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
	if (pre_copy) {
		pre_copy = min(pre_copy, bytes);
		if (copy_from_iter(addr, pre_copy, i) != pre_copy)
			return -EFAULT;
		bytes -= pre_copy;
		addr += pre_copy;
	}

	nocache = round_down(bytes, SMP_CACHE_BYTES);
	if (copy_from_iter_nocache(addr, nocache, i) != nocache)
		return -EFAULT;
	bytes -= nocache;
	addr += nocache;

	if (bytes && copy_from_iter(addr, bytes, i) != bytes)
		return -EFAULT;

	return 0;
}

static int tls_push_data(struct sock *sk,
			 struct iov_iter *iter,
			 size_t size, int flags,
			 unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	struct tls_record_info *record;
	int tls_push_record_flags;
	struct page_frag *pfrag;
	size_t orig_size = size;
	u32 max_open_record_len;
	bool more = false;
	bool done = false;
	int copy, rc = 0;
	long timeo;

	if (flags &
	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
	      MSG_SPLICE_PAGES | MSG_EOR))
		return -EOPNOTSUPP;

	if ((flags & (MSG_MORE | MSG_EOR)) == (MSG_MORE | MSG_EOR))
		return -EINVAL;

	if (unlikely(sk->sk_err))
		return -sk->sk_err;

	flags |= MSG_SENDPAGE_DECRYPTED;
	tls_push_record_flags = flags | MSG_MORE;

	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	if (tls_is_partially_sent_record(tls_ctx)) {
		rc = tls_push_partial_record(sk, tls_ctx, flags);
		if (rc < 0)
			return rc;
	}

	pfrag = sk_page_frag(sk);

	/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
	 * we need to leave room for an authentication tag.
	 */
	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
			      prot->prepend_size;
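	/* Fill the open record from the iterator, either by extracting pages
	 * directly (MSG_SPLICE_PAGES) or by copying into the socket's
	 * page_frag. A record is closed and pushed once the data runs out,
	 * it reaches max_open_record_len, or it is about to exceed the frag
	 * limit of a single skb.
	 */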
	do {
		rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
		if (unlikely(rc)) {
			rc = sk_stream_wait_memory(sk, &timeo);
			if (!rc)
				continue;

			record = ctx->open_record;
			if (!record)
				break;
handle_error:
			if (record_type != TLS_RECORD_TYPE_DATA) {
				/* avoid sending partial
				 * record with type !=
				 * application_data
				 */
				size = orig_size;
				destroy_record(record);
				ctx->open_record = NULL;
			} else if (record->len > prot->prepend_size) {
				goto last_record;
			}

			break;
		}

		record = ctx->open_record;

		copy = min_t(size_t, size, max_open_record_len - record->len);
		if (copy && (flags & MSG_SPLICE_PAGES)) {
			struct page_frag zc_pfrag;
			struct page **pages = &zc_pfrag.page;
			size_t off;

			rc = iov_iter_extract_pages(iter, &pages,
						    copy, 1, 0, &off);
			if (rc <= 0) {
				if (rc == 0)
					rc = -EIO;
				goto handle_error;
			}
			copy = rc;

			if (WARN_ON_ONCE(!sendpage_ok(zc_pfrag.page))) {
				iov_iter_revert(iter, copy);
				rc = -EIO;
				goto handle_error;
			}

			zc_pfrag.offset = off;
			zc_pfrag.size = copy;
			tls_append_frag(record, &zc_pfrag, copy);
		} else if (copy) {
			copy = min_t(size_t, copy, pfrag->size - pfrag->offset);

			rc = tls_device_copy_data(page_address(pfrag->page) +
						  pfrag->offset, copy,
						  iter);
			if (rc)
				goto handle_error;
			tls_append_frag(record, pfrag, copy);
		}

		size -= copy;
		if (!size) {
last_record:
			tls_push_record_flags = flags;
			if (flags & MSG_MORE) {
				more = true;
				break;
			}

			done = true;
		}

		if (done || record->len >= max_open_record_len ||
		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
			tls_device_record_close(sk, tls_ctx, record,
						pfrag, record_type);

			rc = tls_push_record(sk,
					     tls_ctx,
					     ctx,
					     record,
					     tls_push_record_flags);
			if (rc < 0)
				break;
		}
	} while (!done);

	tls_ctx->pending_open_record_frags = more;

	if (orig_size - size > 0)
		rc = orig_size - size;

	return rc;
}

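/* sendmsg() entry point for a TX-offloaded TLS socket. A non-data record
 * type can be requested through a cmsg, roughly (userspace sketch, see
 * Documentation/networking/tls.rst):
 *
 *	cmsg->cmsg_level = SOL_TLS;
 *	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *	cmsg->cmsg_len = CMSG_LEN(sizeof(unsigned char));
 *	*CMSG_DATA(cmsg) = TLS_RECORD_TYPE_ALERT;
 *
 * tls_process_cmsg() below parses that into record_type.
 */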
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	int rc;

	if (!tls_ctx->zerocopy_sendfile)
		msg->msg_flags &= ~MSG_SPLICE_PAGES;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (unlikely(msg->msg_controllen)) {
		rc = tls_process_cmsg(sk, msg, &record_type);
		if (rc)
			goto out;
	}

	rc = tls_push_data(sk, &msg->msg_iter, size, msg->msg_flags,
			   record_type);

out:
	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return rc;
}

void tls_device_splice_eof(struct socket *sock)
{
	struct sock *sk = sock->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct iov_iter iter = {};

	if (!tls_is_partially_sent_record(tls_ctx))
		return;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (tls_is_partially_sent_record(tls_ctx)) {
		iov_iter_bvec(&iter, ITER_SOURCE, NULL, 0, 0);
		tls_push_data(sk, &iter, 0, 0, TLS_RECORD_TYPE_DATA);
	}

	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
}

struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
				       u32 seq, u64 *p_record_sn)
{
	u64 record_sn = context->hint_record_sn;
	struct tls_record_info *info, *last;

	info = context->retransmit_hint;
	if (!info ||
	    before(seq, info->end_seq - info->len)) {
		/* if retransmit_hint is irrelevant start
		 * from the beginning of the list
		 */
		info = list_first_entry_or_null(&context->records_list,
						struct tls_record_info, list);
		if (!info)
			return NULL;
		/* send the start_marker record if seq number is before the
		 * tls offload start marker sequence number. This record is
		 * required to handle TCP packets which are before TLS offload
		 * started.
		 *  And if it's not start marker, look if this seq number
		 * belongs to the list.
		 */
		if (likely(!tls_record_is_start_marker(info))) {
			/* we have the first record, get the last record to see
			 * if this seq number belongs to the list.
			 */
			last = list_last_entry(&context->records_list,
					       struct tls_record_info, list);

			if (!between(seq, tls_record_start_seq(info),
				     last->end_seq))
				return NULL;
		}
		record_sn = context->unacked_record_sn;
	}

	/* We just need the _rcu for the READ_ONCE() */
	rcu_read_lock();
	list_for_each_entry_from_rcu(info, &context->records_list, list) {
		if (before(seq, info->end_seq)) {
			if (!context->retransmit_hint ||
			    after(info->end_seq,
				  context->retransmit_hint->end_seq)) {
				context->hint_record_sn = record_sn;
				context->retransmit_hint = info;
			}
			*p_record_sn = record_sn;
			goto exit_rcu_unlock;
		}
		record_sn++;
	}
	info = NULL;

exit_rcu_unlock:
	rcu_read_unlock();
	return info;
}
EXPORT_SYMBOL(tls_get_record);

static int tls_device_push_pending_record(struct sock *sk, int flags)
{
	struct iov_iter iter;

	iov_iter_kvec(&iter, ITER_SOURCE, NULL, 0, 0);
	return tls_push_data(sk, &iter, 0, flags, TLS_RECORD_TYPE_DATA);
}

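/* write_space callback for device offload: retry pushing the partially sent
 * record. This may run from contexts that must not sleep, so allocations are
 * temporarily forced to GFP_ATOMIC.
 */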
void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
{
	if (tls_is_partially_sent_record(ctx)) {
		gfp_t sk_allocation = sk->sk_allocation;

		WARN_ON_ONCE(sk->sk_write_pending);

		sk->sk_allocation = GFP_ATOMIC;
		tls_push_partial_record(sk, ctx,
					MSG_DONTWAIT | MSG_NOSIGNAL |
					MSG_SENDPAGE_DECRYPTED);
		sk->sk_allocation = sk_allocation;
	}
}

static void tls_device_resync_rx(struct tls_context *tls_ctx,
				 struct sock *sk, u32 seq, u8 *rcd_sn)
{
	struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
	struct net_device *netdev;

	trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
	rcu_read_lock();
	netdev = rcu_dereference(tls_ctx->netdev);
	if (netdev)
		netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
						   TLS_OFFLOAD_CTX_DIR_RX);
	rcu_read_unlock();
	TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
}

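/* Handle TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC requests. While the driver's
 * request is pending (async stage) we only log record-header sequence numbers
 * and count records in rcd_delta; once the driver confirms the resync point
 * (sync stage) we match it against the log so the record sequence number can
 * be rolled back by the number of records the device fell behind.
 */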
static bool
tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
			   s64 resync_req, u32 *seq, u16 *rcd_delta)
{
	u32 is_async = resync_req & RESYNC_REQ_ASYNC;
	u32 req_seq = resync_req >> 32;
	u32 req_end = req_seq + ((resync_req >> 16) & 0xffff);
	u16 i;

	*rcd_delta = 0;

	if (is_async) {
		/* shouldn't get to wraparound:
		 * too long in async stage, something bad happened
		 */
		if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
			return false;

		/* asynchronous stage: log all headers seq such that
		 * req_seq <= seq <= end_seq, and wait for real resync request
		 */
		if (before(*seq, req_seq))
			return false;
		if (!after(*seq, req_end) &&
		    resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX)
			resync_async->log[resync_async->loglen++] = *seq;

		resync_async->rcd_delta++;

		return false;
	}

	/* synchronous stage: check against the logged entries and
	 * proceed to check the next entries if no match was found
	 */
	for (i = 0; i < resync_async->loglen; i++)
		if (req_seq == resync_async->log[i] &&
		    atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) {
			*rcd_delta = resync_async->rcd_delta - i;
			*seq = req_seq;
			resync_async->loglen = 0;
			resync_async->rcd_delta = 0;
			return true;
		}

	resync_async->loglen = 0;
	resync_async->rcd_delta = 0;

	if (req_seq == *seq &&
	    atomic64_try_cmpxchg(&resync_async->req,
				 &resync_req, 0))
		return true;

	return false;
}

void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_rx *rx_ctx;
	u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
	u32 sock_data, is_req_pending;
	struct tls_prot_info *prot;
	s64 resync_req;
	u16 rcd_delta;
	u32 req_seq;

	if (tls_ctx->rx_conf != TLS_HW)
		return;
	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
		return;

	prot = &tls_ctx->prot_info;
	rx_ctx = tls_offload_ctx_rx(tls_ctx);
	memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);

	switch (rx_ctx->resync_type) {
	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
		resync_req = atomic64_read(&rx_ctx->resync_req);
		req_seq = resync_req >> 32;
		seq += TLS_HEADER_SIZE - 1;
		is_req_pending = resync_req;

		if (likely(!is_req_pending) || req_seq != seq ||
		    !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
			return;
		break;
	case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
		if (likely(!rx_ctx->resync_nh_do_now))
			return;

		/* head of next rec is already in, note that the sock_inq will
		 * include the currently parsed message when called from parser
		 */
		sock_data = tcp_inq(sk);
		if (sock_data > rcd_len) {
			trace_tls_device_rx_resync_nh_delay(sk, sock_data,
							    rcd_len);
			return;
		}

		rx_ctx->resync_nh_do_now = 0;
		seq += rcd_len;
		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
		break;
	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC:
		resync_req = atomic64_read(&rx_ctx->resync_async->req);
		is_req_pending = resync_req;
		if (likely(!is_req_pending))
			return;

		if (!tls_device_rx_resync_async(rx_ctx->resync_async,
						resync_req, &seq, &rcd_delta))
			return;
		tls_bigint_subtract(rcd_sn, rcd_delta);
		break;
	}

	tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
}

static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
					   struct tls_offload_context_rx *ctx,
					   struct sock *sk, struct sk_buff *skb)
{
	struct strp_msg *rxm;

	/* device will request resyncs by itself based on stream scan */
	if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
		return;
	/* already scheduled */
	if (ctx->resync_nh_do_now)
		return;
	/* seen decrypted fragments since last fully-failed record */
	if (ctx->resync_nh_reset) {
		ctx->resync_nh_reset = 0;
		ctx->resync_nh.decrypted_failed = 1;
		ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
		return;
	}

	if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
		return;

	/* doing resync, bump the next target in case it fails */
	if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
		ctx->resync_nh.decrypted_tgt *= 2;
	else
		ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;

	rxm = strp_msg(skb);

	/* head of next rec is already in, parser will sync for us */
	if (tcp_inq(sk) > rxm->full_len) {
		trace_tls_device_rx_resync_nh_schedule(sk);
		ctx->resync_nh_do_now = 1;
	} else {
		struct tls_prot_info *prot = &tls_ctx->prot_info;
		u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];

		memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
		tls_bigint_increment(rcd_sn, prot->rec_seq_size);

		tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
				     rcd_sn);
	}
}

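/* The device decrypted only part of this record. Run the software decryption
 * over the mixed plaintext/ciphertext record into a side buffer; since the
 * offloadable ciphers are counter-mode (AES-GCM), the output holds valid
 * ciphertext exactly where the skb already contains plaintext. Copy those
 * ranges back so the record is fully ciphertext again and the regular
 * software RX path can decrypt it.
 */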
static int
tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
	const struct tls_cipher_desc *cipher_desc;
	int err, offset, copy, data_len, pos;
	struct sk_buff *skb, *skb_iter;
	struct scatterlist sg[1];
	struct strp_msg *rxm;
	char *orig_buf, *buf;

	cipher_desc = get_cipher_desc(tls_ctx->crypto_recv.info.cipher_type);
	DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable);

	rxm = strp_msg(tls_strp_msg(sw_ctx));
	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv,
			   sk->sk_allocation);
	if (!orig_buf)
		return -ENOMEM;
	buf = orig_buf;

	err = tls_strp_msg_cow(sw_ctx);
	if (unlikely(err))
		goto free_buf;

	skb = tls_strp_msg(sw_ctx);
	rxm = strp_msg(skb);
	offset = rxm->offset;

	sg_init_table(sg, 1);
	sg_set_buf(&sg[0], buf,
		   rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv);
	err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_desc->iv);
	if (err)
		goto free_buf;

	/* We are interested only in the decrypted data not the auth */
	err = decrypt_skb(sk, sg);
	if (err != -EBADMSG)
		goto free_buf;
	else
		err = 0;

	data_len = rxm->full_len - cipher_desc->tag;

	if (skb_pagelen(skb) > offset) {
		copy = min_t(int, skb_pagelen(skb) - offset, data_len);

		if (skb->decrypted) {
			err = skb_store_bits(skb, offset, buf, copy);
			if (err)
				goto free_buf;
		}

		offset += copy;
		buf += copy;
	}

	pos = skb_pagelen(skb);
	skb_walk_frags(skb, skb_iter) {
		int frag_pos;

		/* Practically all frags must belong to msg if reencrypt
		 * is needed with current strparser and coalescing logic,
		 * but strparser may "get optimized", so let's be safe.
		 */
		if (pos + skb_iter->len <= offset)
			goto done_with_frag;
		if (pos >= data_len + rxm->offset)
			break;

		frag_pos = offset - pos;
		copy = min_t(int, skb_iter->len - frag_pos,
			     data_len + rxm->offset - offset);

		if (skb_iter->decrypted) {
			err = skb_store_bits(skb_iter, frag_pos, buf, copy);
			if (err)
				goto free_buf;
		}

		offset += copy;
		buf += copy;
done_with_frag:
		pos += skb_iter->len;
	}

free_buf:
	kfree(orig_buf);
	return err;
}

int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
{
	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
	struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
	struct sk_buff *skb = tls_strp_msg(sw_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int is_decrypted, is_encrypted;

	if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
		is_decrypted = skb->decrypted;
		is_encrypted = !is_decrypted;
	} else {
		is_decrypted = 0;
		is_encrypted = 0;
	}

	trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
				   tls_ctx->rx.rec_seq, rxm->full_len,
				   is_encrypted, is_decrypted);

	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
		if (likely(is_encrypted || is_decrypted))
			return is_decrypted;

		/* After tls_device_down disables the offload, the next SKB will
		 * likely have initial fragments decrypted, and final ones not
		 * decrypted. We need to reencrypt that single SKB.
		 */
		return tls_device_reencrypt(sk, tls_ctx);
	}

	/* Return immediately if the record is either entirely plaintext or
	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
	 * record.
	 */
	if (is_decrypted) {
		ctx->resync_nh_reset = 1;
		return is_decrypted;
	}
	if (is_encrypted) {
		tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
		return 0;
	}

	ctx->resync_nh_reset = 1;
	return tls_device_reencrypt(sk, tls_ctx);
}

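/* First-time attach for this socket: take a reference on the device and the
 * context, add the context to tls_device_list, and replace sk_destruct so
 * that teardown goes through tls_device_sk_destruct() above.
 */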
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
			      struct net_device *netdev)
{
	if (sk->sk_destruct != tls_device_sk_destruct) {
		refcount_set(&ctx->refcount, 1);
		dev_hold(netdev);
		RCU_INIT_POINTER(ctx->netdev, netdev);
		spin_lock_irq(&tls_device_lock);
		list_add_tail(&ctx->list, &tls_device_list);
		spin_unlock_irq(&tls_device_lock);

		ctx->sk_destruct = sk->sk_destruct;
		smp_store_release(&sk->sk_destruct, tls_device_sk_destruct);
	}
}

static struct tls_offload_context_tx *alloc_offload_ctx_tx(struct tls_context *ctx)
{
	struct tls_offload_context_tx *offload_ctx;
	__be64 rcd_sn;

	offload_ctx = kzalloc(sizeof(*offload_ctx), GFP_KERNEL);
	if (!offload_ctx)
		return NULL;

	INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task);
	INIT_LIST_HEAD(&offload_ctx->records_list);
	spin_lock_init(&offload_ctx->lock);
	sg_init_table(offload_ctx->sg_tx_data,
		      ARRAY_SIZE(offload_ctx->sg_tx_data));

	/* start at rec_seq - 1 to account for the start marker record */
	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;

	offload_ctx->ctx = ctx;

	return offload_ctx;
}

int tls_set_device_offload(struct sock *sk)
{
	struct tls_record_info *start_marker_record;
	struct tls_offload_context_tx *offload_ctx;
	const struct tls_cipher_desc *cipher_desc;
	struct tls_crypto_info *crypto_info;
	struct tls_prot_info *prot;
	struct net_device *netdev;
	struct tls_context *ctx;
	char *iv, *rec_seq;
	int rc;

	ctx = tls_get_ctx(sk);
	prot = &ctx->prot_info;

	if (ctx->priv_ctx_tx)
		return -EEXIST;

	netdev = get_netdev_for_sock(sk);
	if (!netdev) {
		pr_err_ratelimited("%s: netdev not found\n", __func__);
		return -EINVAL;
	}

	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
		rc = -EOPNOTSUPP;
		goto release_netdev;
	}

	crypto_info = &ctx->crypto_send.info;
	if (crypto_info->version != TLS_1_2_VERSION) {
		rc = -EOPNOTSUPP;
		goto release_netdev;
	}

	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
	if (!cipher_desc || !cipher_desc->offloadable) {
		rc = -EINVAL;
		goto release_netdev;
	}

	rc = init_prot_info(prot, crypto_info, cipher_desc);
	if (rc)
		goto release_netdev;

	iv = crypto_info_iv(crypto_info, cipher_desc);
	rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);

	memcpy(ctx->tx.iv + cipher_desc->salt, iv, cipher_desc->iv);
	memcpy(ctx->tx.rec_seq, rec_seq, cipher_desc->rec_seq);

	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
	if (!start_marker_record) {
		rc = -ENOMEM;
		goto release_netdev;
	}

	offload_ctx = alloc_offload_ctx_tx(ctx);
	if (!offload_ctx) {
		rc = -ENOMEM;
		goto free_marker_record;
	}

	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
	if (rc)
		goto free_offload_ctx;

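	/* The start marker is a zero-length record covering everything sent
	 * before offload was enabled; tls_get_record() relies on it to
	 * recognise retransmissions of pre-offload data.
	 */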
	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
	start_marker_record->len = 0;
	start_marker_record->num_frags = 0;
	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);

	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
	ctx->push_pending_record = tls_device_push_pending_record;

	/* TLS offload is greatly simplified if we don't send
	 * SKBs where only part of the payload needs to be encrypted.
	 * So mark the last skb in the write queue as end of record.
	 */
	tcp_write_collapse_fence(sk);

	/* Avoid offloading if the device is down
	 * We don't want to offload new flows after
	 * the NETDEV_DOWN event
	 *
	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
	 * handler thus protecting from the device going down before
	 * ctx was added to tls_device_list.
	 */
	down_read(&device_offload_lock);
	if (!(netdev->flags & IFF_UP)) {
		rc = -EINVAL;
		goto release_lock;
	}

	ctx->priv_ctx_tx = offload_ctx;
	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
					     &ctx->crypto_send.info,
					     tcp_sk(sk)->write_seq);
	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX,
				     tcp_sk(sk)->write_seq, rec_seq, rc);
	if (rc)
		goto release_lock;

	tls_device_attach(ctx, sk, netdev);
	up_read(&device_offload_lock);

	/* following this assignment tls_is_skb_tx_device_offloaded
	 * will return true and the context might be accessed
	 * by the netdev's xmit function.
	 */
	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
	dev_put(netdev);

	return 0;

release_lock:
	up_read(&device_offload_lock);
	clean_acked_data_disable(inet_csk(sk));
	crypto_free_aead(offload_ctx->aead_send);
free_offload_ctx:
	kfree(offload_ctx);
	ctx->priv_ctx_tx = NULL;
free_marker_record:
	kfree(start_marker_record);
release_netdev:
	dev_put(netdev);
	return rc;
}

int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
	struct tls12_crypto_info_aes_gcm_128 *info;
	struct tls_offload_context_rx *context;
	struct net_device *netdev;
	int rc = 0;

	if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
		return -EOPNOTSUPP;

	netdev = get_netdev_for_sock(sk);
	if (!netdev) {
		pr_err_ratelimited("%s: netdev not found\n", __func__);
		return -EINVAL;
	}

	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
		rc = -EOPNOTSUPP;
		goto release_netdev;
	}

	/* Avoid offloading if the device is down
	 * We don't want to offload new flows after
	 * the NETDEV_DOWN event
	 *
	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
	 * handler thus protecting from the device going down before
	 * ctx was added to tls_device_list.
	 */
	down_read(&device_offload_lock);
	if (!(netdev->flags & IFF_UP)) {
		rc = -EINVAL;
		goto release_lock;
	}

	context = kzalloc(sizeof(*context), GFP_KERNEL);
	if (!context) {
		rc = -ENOMEM;
		goto release_lock;
	}
	context->resync_nh_reset = 1;

	ctx->priv_ctx_rx = context;
	rc = tls_set_sw_offload(sk, 0);
	if (rc)
		goto release_ctx;

	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
					     &ctx->crypto_recv.info,
					     tcp_sk(sk)->copied_seq);
	info = (void *)&ctx->crypto_recv.info;
	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX,
				     tcp_sk(sk)->copied_seq, info->rec_seq, rc);
	if (rc)
		goto free_sw_resources;

	tls_device_attach(ctx, sk, netdev);
	up_read(&device_offload_lock);

	dev_put(netdev);

	return 0;

free_sw_resources:
	up_read(&device_offload_lock);
	tls_sw_free_resources_rx(sk);
	down_read(&device_offload_lock);
release_ctx:
	ctx->priv_ctx_rx = NULL;
release_lock:
	up_read(&device_offload_lock);
release_netdev:
	dev_put(netdev);
	return rc;
}

void tls_device_offload_cleanup_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct net_device *netdev;

	down_read(&device_offload_lock);
	netdev = rcu_dereference_protected(tls_ctx->netdev,
					   lockdep_is_held(&device_offload_lock));
	if (!netdev)
		goto out;

	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
					TLS_OFFLOAD_CTX_DIR_RX);

	if (tls_ctx->tx_conf != TLS_HW) {
		dev_put(netdev);
		rcu_assign_pointer(tls_ctx->netdev, NULL);
	} else {
		set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
	}
out:
	up_read(&device_offload_lock);
	tls_sw_release_resources_rx(sk);
}

static int tls_device_down(struct net_device *netdev)
{
	struct tls_context *ctx, *tmp;
	unsigned long flags;
	LIST_HEAD(list);

	/* Request a write lock to block new offload attempts */
	down_write(&device_offload_lock);

	spin_lock_irqsave(&tls_device_lock, flags);
	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
		struct net_device *ctx_netdev =
			rcu_dereference_protected(ctx->netdev,
						  lockdep_is_held(&device_offload_lock));

		if (ctx_netdev != netdev ||
		    !refcount_inc_not_zero(&ctx->refcount))
			continue;

		list_move(&ctx->list, &list);
	}
	spin_unlock_irqrestore(&tls_device_lock, flags);

	list_for_each_entry_safe(ctx, tmp, &list, list) {
		/* Stop offloaded TX and switch to the fallback.
		 * tls_is_skb_tx_device_offloaded will return false.
		 */
		WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);

		/* Stop the RX and TX resync.
		 * tls_dev_resync must not be called after tls_dev_del.
		 */
		rcu_assign_pointer(ctx->netdev, NULL);

		/* Start skipping the RX resync logic completely. */
		set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);

		/* Sync with inflight packets. After this point:
		 * TX: no non-encrypted packets will be passed to the driver.
		 * RX: resync requests from the driver will be ignored.
		 */
		synchronize_net();

		/* Release the offload context on the driver side. */
		if (ctx->tx_conf == TLS_HW)
			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
							TLS_OFFLOAD_CTX_DIR_TX);
		if (ctx->rx_conf == TLS_HW &&
		    !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
							TLS_OFFLOAD_CTX_DIR_RX);

		dev_put(netdev);

		/* Move the context to a separate list for two reasons:
		 * 1. When the context is deallocated, list_del is called.
		 * 2. It's no longer an offloaded context, so we don't want to
		 *    run offload-specific code on this context.
		 */
		spin_lock_irqsave(&tls_device_lock, flags);
		list_move_tail(&ctx->list, &tls_device_down_list);
		spin_unlock_irqrestore(&tls_device_lock, flags);

		/* Device contexts for RX and TX will be freed on sk_destruct
		 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
		 * Now release the ref taken above.
		 */
		if (refcount_dec_and_test(&ctx->refcount)) {
			/* sk_destruct ran after tls_device_down took a ref, and
			 * it returned early. Complete the destruction here.
			 */
			list_del(&ctx->list);
			tls_device_free_ctx(ctx);
		}
	}

	up_write(&device_offload_lock);

	flush_workqueue(destruct_wq);

	return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event,
			 void *ptr)
{
	struct net_device *dev = netdev_notifier_info_to_dev(ptr);

	if (!dev->tlsdev_ops &&
	    !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
		return NOTIFY_DONE;

	switch (event) {
	case NETDEV_REGISTER:
	case NETDEV_FEAT_CHANGE:
		if (netif_is_bond_master(dev))
			return NOTIFY_DONE;
		if ((dev->features & NETIF_F_HW_TLS_RX) &&
		    !dev->tlsdev_ops->tls_dev_resync)
			return NOTIFY_BAD;

		if (dev->tlsdev_ops &&
		    dev->tlsdev_ops->tls_dev_add &&
		    dev->tlsdev_ops->tls_dev_del)
			return NOTIFY_DONE;
		else
			return NOTIFY_BAD;
	case NETDEV_DOWN:
		return tls_device_down(dev);
	}
	return NOTIFY_DONE;
}

static struct notifier_block tls_dev_notifier = {
	.notifier_call	= tls_dev_event,
};

int __init tls_device_init(void)
{
	int err;

	dummy_page = alloc_page(GFP_KERNEL);
	if (!dummy_page)
		return -ENOMEM;

	destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0);
	if (!destruct_wq) {
		err = -ENOMEM;
		goto err_free_dummy;
	}

	err = register_netdevice_notifier(&tls_dev_notifier);
	if (err)
		goto err_destroy_wq;

	return 0;

err_destroy_wq:
	destroy_workqueue(destruct_wq);
err_free_dummy:
	put_page(dummy_page);
	return err;
}

void __exit tls_device_cleanup(void)
{
	unregister_netdevice_notifier(&tls_dev_notifier);
	destroy_workqueue(destruct_wq);
	clean_acked_data_flush();
	put_page(dummy_page);
}