/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");
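
/* A rough userspace sketch of how this module is driven, following
 * Documentation/networking/tls.rst (the key-material variables are
 * placeholders filled in from a completed handshake):
 *
 *	setsockopt(sock, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 *	memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
 *	memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 *	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 *	setsockopt(sock, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *
 * The TCP_ULP call lands in tls_init() below; the TLS_TX/TLS_RX calls
 * land in do_tls_setsockopt_conf().
 */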

enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

#define CHECK_CIPHER_DESC(cipher,ci)				\
	static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE);		\
	static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE);		\
	static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE);	\
	static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE);		\
	static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE);	\
	static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE);	\
	static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE);	\
	static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE);

#define __CIPHER_DESC(ci) \
	.iv_offset = offsetof(struct ci, iv), \
	.key_offset = offsetof(struct ci, key), \
	.salt_offset = offsetof(struct ci, salt), \
	.rec_seq_offset = offsetof(struct ci, rec_seq), \
	.crypto_info = sizeof(struct ci)

#define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = {	\
	.nonce = cipher ## _IV_SIZE, \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
	.cipher_name = algname,	\
	.offloadable = _offloadable, \
	__CIPHER_DESC(ci), \
}

#define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \
	.nonce = 0, \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
	.cipher_name = algname,	\
	.offloadable = _offloadable, \
	__CIPHER_DESC(ci), \
}

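/* One descriptor per supported cipher, indexed by the uapi cipher id
 * relative to TLS_CIPHER_MIN. ->offloadable marks ciphers the TLS_HW
 * device path may accept; the CHECK_CIPHER_DESC() invocations below
 * verify at compile time that the sizes and offsets match the uapi
 * tls12_crypto_info_* layouts.
 */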
const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = {
	CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true),
	CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true),
	CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false),
	CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false),
	CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false),
	CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false),
	CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false),
	CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false),
};

CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256);
CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305);
CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm);
CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm);
CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256);

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

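/* Swap the socket's proto and proto_ops to the variants matching the
 * current TX/RX configuration. WRITE_ONCE() pairs with lockless
 * readers of sk->sk_prot, e.g. sk_clone_lock().
 */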
void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret, rc = 0;

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		ret = sk_wait_event(sk, timeo,
				    !READ_ONCE(sk->sk_write_pending), &wait);
		if (ret) {
			if (ret < 0)
				rc = ret;
			break;
		}
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

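/* Transmit a record's scatterlist on the underlying TCP socket using
 * MSG_SPLICE_PAGES. On a partial send, progress is parked in
 * ctx->partially_sent_record / ctx->partially_sent_offset so that
 * tls_push_partial_record() can resume the record later.
 */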
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	struct bio_vec bvec;
	struct msghdr msg = {
		.msg_flags = MSG_SPLICE_PAGES | flags,
	};
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->splicing_pages = true;
	while (1) {
		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);

		ret = tcp_sendmsg_locked(sk, &msg, size);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->splicing_pages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->splicing_pages = false;

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

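/* Parse SOL_TLS control messages attached to a sendmsg() call. A
 * userspace sender picks the record type of the next record (e.g.
 * content type 21, a TLS alert) roughly as in
 * Documentation/networking/tls.rst (sketch, error handling omitted):
 *
 *	char buf[CMSG_SPACE(sizeof(unsigned char))];
 *	struct msghdr msg = { .msg_control = buf,
 *			      .msg_controllen = sizeof(buf) };
 *	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
 *
 *	cmsg->cmsg_level = SOL_TLS;
 *	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *	cmsg->cmsg_len = CMSG_LEN(sizeof(unsigned char));
 *	*CMSG_DATA(cmsg) = record_type;
 *	sendmsg(sock, &msg, 0);
 *
 * Any pending open record is pushed out first via
 * tls_handle_open_record().
 */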
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
		     unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags)
{
	struct scatterlist *sg;
	u16 offset;

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
	struct scatterlist *sg;

	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
		put_page(sg_page(sg));
		sk_mem_uncharge(sk, sg->length);
	}
	ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If splicing_pages is set, call the lower protocol's write space
	 * handler to ensure we wake up any operations waiting there; for
	 * example, if splicing pages were to call sk_wait_event.
	 */
	if (ctx->splicing_pages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}

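/* Wrap tcp_poll(): TCP-level readability does not imply a decrypted
 * record is ready, so clear EPOLLIN/EPOLLRDNORM unless rx_list, the
 * strparser, or the psock queue actually holds data for the reader.
 */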
static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
			    struct poll_table_struct *wait)
{
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct sock *sk = sock->sk;
	struct sk_psock *psock;
	__poll_t mask = 0;
	u8 shutdown;
	int state;

	mask = tcp_poll(file, sock, wait);

	state = inet_sk_state_load(sk);
	shutdown = READ_ONCE(sk->sk_shutdown);
	if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN))
		return mask;

	tls_ctx = tls_get_ctx(sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	psock = sk_psock_get(sk);

	if (skb_queue_empty_lockless(&ctx->rx_list) &&
	    !tls_strp_msg_ready(ctx) &&
	    sk_psock_queue_empty(psock))
		mask &= ~(EPOLLIN | EPOLLRDNORM);

	if (psock)
		sk_psock_put(sk, psock);

	return mask;
}

static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	const struct tls_cipher_desc *cipher_desc;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
	if (!cipher_desc || len != cipher_desc->crypto_info) {
		rc = -EINVAL;
		goto out;
	}

	memcpy(crypto_info_iv(crypto_info, cipher_desc),
	       cctx->iv + cipher_desc->salt, cipher_desc->iv);
	memcpy(crypto_info_rec_seq(crypto_info, cipher_desc),
	       cctx->rec_seq, cipher_desc->rec_seq);

	if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info))
		rc = -EFAULT;

out:
	return rc;
}

static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
				   int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len != sizeof(value))
		return -EINVAL;

	value = ctx->zerocopy_sendfile;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
				    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	int value, len;

	if (ctx->prot_info.version != TLS_1_3_VERSION)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;
	if (len < sizeof(value))
		return -EINVAL;

	value = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		value = ctx->rx_no_pad;
	if (value < 0)
		return value;

	if (put_user(sizeof(value), optlen))
		return -EFAULT;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	lock_sock(sk);

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		rc = do_tls_getsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		break;
	case TLS_TX_ZEROCOPY_RO:
		rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}

	release_sock(sk);

	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->getsockopt(sk, level,
						 optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

static int validate_crypto_info(const struct tls_crypto_info *crypto_info,
				const struct tls_crypto_info *alt_crypto_info)
{
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION)
		return -EINVAL;

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_ARIA_GCM_128:
	case TLS_CIPHER_ARIA_GCM_256:
		if (crypto_info->version != TLS_1_2_VERSION)
			return -EINVAL;
		break;
	}

	/* Ensure that the TLS version and cipher are the same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type)
			return -EINVAL;
	}

	return 0;
}

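/* Install key material for one direction (TLS_TX or TLS_RX). Device
 * (TLS_HW) offload is attempted first; if the NIC cannot take the
 * connection, we fall back to the software (TLS_SW) implementation
 * before giving up. On any failure the copied-in key material is
 * wiped with memzero_explicit().
 */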
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	const struct tls_cipher_desc *cipher_desc;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support setting crypto info more than one time */
	if (TLS_CRYPTO_INFO_READY(crypto_info))
		return -EBUSY;

	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	rc = validate_crypto_info(crypto_info, alt_crypto_info);
	if (rc)
		goto err_crypto_info;

	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
	if (!cipher_desc) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != cipher_desc->crypto_info) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

		tls_strp_check_rcv(&rx_ctx->strp);
	}
	return 0;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
	return rc;
}

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
				   unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value > 1)
		return -EINVAL;

	ctx->zerocopy_sendfile = value;

	return 0;
}

static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
				    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u32 val;
	int rc;

	if (ctx->prot_info.version != TLS_1_3_VERSION ||
	    sockptr_is_null(optval) || optlen < sizeof(val))
		return -EINVAL;

	rc = copy_from_sockptr(&val, optval, sizeof(val));
	if (rc)
		return -EFAULT;
	if (val > 1)
		return -EINVAL;
	rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
	if (rc < 1)
		return rc == 0 ? -EINVAL : rc;

	lock_sock(sk);
	rc = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
		ctx->rx_no_pad = val;
		tls_update_rx_zc_capable(ctx);
		rc = 0;
	}
	release_sock(sk);

	return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	case TLS_TX_ZEROCOPY_RO:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
		release_sock(sk);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
						 optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	return ctx;
}

static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].splice_eof	= tls_sw_splice_eof;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;
	ops[TLS_BASE][TLS_SW  ].poll		= tls_sk_poll;
	ops[TLS_BASE][TLS_SW  ].read_sock	= tls_sw_read_sock;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;
	ops[TLS_SW  ][TLS_SW  ].poll		= tls_sk_poll;
	ops[TLS_SW  ][TLS_SW  ].read_sock	= tls_sw_read_sock;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}

static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].splice_eof	= tls_sw_splice_eof;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].splice_eof	= tls_device_splice_eof;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].splice_eof		= tls_device_splice_eof;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
#endif
}

static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}

static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
	u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

	switch (config) {
	case TLS_BASE:
		return TLS_CONF_BASE;
	case TLS_SW:
		return TLS_CONF_SW;
	case TLS_HW:
		return TLS_CONF_HW;
	case TLS_HW_RECORD:
		return TLS_CONF_HW_RECORD;
	}
	return 0;
}

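/* Fill the INET_ULP_INFO_TLS attribute for an inet_diag dump; this is
 * what lets userspace diagnostics (e.g. ss from iproute2) report
 * per-socket TLS version, cipher and TX/RX configuration.
 */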
static int tls_get_info(struct sock *sk, struct sk_buff *skb)
{
	u16 version, cipher_type;
	struct tls_context *ctx;
	struct nlattr *start;
	int err;

	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
	if (!start)
		return -EMSGSIZE;

	rcu_read_lock();
	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
	if (!ctx) {
		err = 0;
		goto nla_failure;
	}
	version = ctx->prot_info.version;
	if (version) {
		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
		if (err)
			goto nla_failure;
	}
	cipher_type = ctx->prot_info.cipher_type;
	if (cipher_type) {
		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
		if (err)
			goto nla_failure;
	}
	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
	if (err)
		goto nla_failure;

	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
	if (err)
		goto nla_failure;

	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
		if (err)
			goto nla_failure;
	}
	if (ctx->rx_no_pad) {
		err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
		if (err)
			goto nla_failure;
	}

	rcu_read_unlock();
	nla_nest_end(skb, start);
	return 0;

nla_failure:
	rcu_read_unlock();
	nla_nest_cancel(skb, start);
	return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
	size_t size = 0;

	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
		nla_total_size(0) +		/* TLS_INFO_ZC_RO_TX */
		nla_total_size(0) +		/* TLS_INFO_RX_NO_PAD */
		0;

	return size;
}

static int __net_init tls_init_net(struct net *net)
{
	int err;

	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
	if (!net->mib.tls_statistics)
		return -ENOMEM;

	err = tls_proc_init(net);
	if (err)
		goto err_free_stats;

	return 0;
err_free_stats:
	free_percpu(net->mib.tls_statistics);
	return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};

static int __init tls_register(void)
{
	int err;

	err = register_pernet_subsys(&tls_proc_ops);
	if (err)
		return err;

	err = tls_strp_dev_init();
	if (err)
		goto err_pernet;

	err = tls_device_init();
	if (err)
		goto err_strp;

	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
err_strp:
	tls_strp_dev_exit();
err_pernet:
	unregister_pernet_subsys(&tls_proc_ops);
	return err;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_strp_dev_exit();
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);