1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
3
4#include <linux/skmsg.h>
5#include <net/sock.h>
6#include <net/udp.h>
7#include <net/inet_common.h>
8
9#include "udp_impl.h"
10
11static struct proto *udpv6_prot_saved __read_mostly;
12
13static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
14			  int flags, int *addr_len)
15{
16#if IS_ENABLED(CONFIG_IPV6)
17	if (sk->sk_family == AF_INET6)
18		return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len);
19#endif
20	return udp_prot.recvmsg(sk, msg, len, flags, addr_len);
21}
22
23static bool udp_sk_has_data(struct sock *sk)
24{
25	return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
26	       !skb_queue_empty(&sk->sk_receive_queue);
27}
28
29static bool psock_has_data(struct sk_psock *psock)
30{
31	return !skb_queue_empty(&psock->ingress_skb) ||
32	       !sk_psock_queue_empty(psock);
33}
34
/* Non-zero when either the socket queues or the psock have data pending. */
#define udp_msg_has_data(_sk, _psock)	\
		(udp_sk_has_data(_sk) || psock_has_data(_psock))
37
38static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
39			     long timeo)
40{
41	DEFINE_WAIT_FUNC(wait, woken_wake_function);
42	int ret = 0;
43
44	if (sk->sk_shutdown & RCV_SHUTDOWN)
45		return 1;
46
47	if (!timeo)
48		return ret;
49
50	add_wait_queue(sk_sleep(sk), &wait);
51	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
52	ret = udp_msg_has_data(sk, psock);
53	if (!ret) {
54		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
55		ret = udp_msg_has_data(sk, psock);
56	}
57	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
58	remove_wait_queue(sk_sleep(sk), &wait);
59	return ret;
60}
61
62static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
63			   int flags, int *addr_len)
64{
65	struct sk_psock *psock;
66	int copied, ret;
67
68	if (unlikely(flags & MSG_ERRQUEUE))
69		return inet_recv_error(sk, msg, len, addr_len);
70
71	if (!len)
72		return 0;
73
74	psock = sk_psock_get(sk);
75	if (unlikely(!psock))
76		return sk_udp_recvmsg(sk, msg, len, flags, addr_len);
77
78	if (!psock_has_data(psock)) {
79		ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
80		goto out;
81	}
82
83msg_bytes_ready:
84	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
85	if (!copied) {
86		long timeo;
87		int data;
88
89		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
90		data = udp_msg_wait_data(sk, psock, timeo);
91		if (data) {
92			if (psock_has_data(psock))
93				goto msg_bytes_ready;
94			ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
95			goto out;
96		}
97		copied = -EAGAIN;
98	}
99	ret = copied;
100out:
101	sk_psock_put(sk, psock);
102	return ret;
103}
104
/* Indices into udp_bpf_prots[], one slot per address family. */
enum {
	UDP_BPF_IPV4 = 0,	/* slot built from udp_prot */
	UDP_BPF_IPV6,		/* slot built from the IPv6 base proto */
	UDP_BPF_NUM_PROTS,	/* number of slots; keep last */
};
110
111static DEFINE_SPINLOCK(udpv6_prot_lock);
112static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
113
114static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
115{
116	*prot        = *base;
117	prot->close  = sock_map_close;
118	prot->recvmsg = udp_bpf_recvmsg;
119	prot->sock_is_readable = sk_msg_is_readable;
120}
121
122static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
123{
124	if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
125		spin_lock_bh(&udpv6_prot_lock);
126		if (likely(ops != udpv6_prot_saved)) {
127			udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
128			smp_store_release(&udpv6_prot_saved, ops);
129		}
130		spin_unlock_bh(&udpv6_prot_lock);
131	}
132}
133
134static int __init udp_bpf_v4_build_proto(void)
135{
136	udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
137	return 0;
138}
139late_initcall(udp_bpf_v4_build_proto);
140
141int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
142{
143	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
144
145	if (restore) {
146		sk->sk_write_space = psock->saved_write_space;
147		sock_replace_proto(sk, psock->sk_proto);
148		return 0;
149	}
150
151	if (sk->sk_family == AF_INET6)
152		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
153
154	sock_replace_proto(sk, &udp_bpf_prots[family]);
155	return 0;
156}
157EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
158