1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Generic netlink handshake service
4 *
5 * Author: Chuck Lever <chuck.lever@oracle.com>
6 *
7 * Copyright (c) 2023, Oracle and/or its affiliates.
8 */
9
10#include <linux/types.h>
11#include <linux/socket.h>
12#include <linux/kernel.h>
13#include <linux/module.h>
14#include <linux/skbuff.h>
15#include <linux/mm.h>
16
17#include <net/sock.h>
18#include <net/genetlink.h>
19#include <net/netns/generic.h>
20
21#include <kunit/visibility.h>
22
23#include <uapi/linux/handshake.h>
24#include "handshake.h"
25#include "genl.h"
26
27#include <trace/events/handshake.h>
28
29/**
30 * handshake_genl_notify - Notify handlers that a request is waiting
31 * @net: target network namespace
32 * @proto: handshake protocol
33 * @flags: memory allocation control flags
34 *
35 * Returns zero on success or a negative errno if notification failed.
36 */
37int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
38			  gfp_t flags)
39{
40	struct sk_buff *msg;
41	void *hdr;
42
43	/* Disable notifications during unit testing */
44	if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
45		return 0;
46
47	if (!genl_has_listeners(&handshake_nl_family, net,
48				proto->hp_handler_class))
49		return -ESRCH;
50
51	msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, flags);
52	if (!msg)
53		return -ENOMEM;
54
55	hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
56			  HANDSHAKE_CMD_READY);
57	if (!hdr)
58		goto out_free;
59
60	if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
61			proto->hp_handler_class) < 0) {
62		genlmsg_cancel(msg, hdr);
63		goto out_free;
64	}
65
66	genlmsg_end(msg, hdr);
67	return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
68				       0, proto->hp_handler_class, flags);
69
70out_free:
71	nlmsg_free(msg);
72	return -EMSGSIZE;
73}
74
75/**
76 * handshake_genl_put - Create a generic netlink message header
77 * @msg: buffer in which to create the header
78 * @info: generic netlink message context
79 *
80 * Returns a ready-to-use header, or NULL.
81 */
82struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
83				    struct genl_info *info)
84{
85	return genlmsg_put(msg, info->snd_portid, info->snd_seq,
86			   &handshake_nl_family, 0, info->genlhdr->cmd);
87}
88EXPORT_SYMBOL(handshake_genl_put);
89
90int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
91{
92	struct net *net = sock_net(skb->sk);
93	struct handshake_net *hn = handshake_pernet(net);
94	struct handshake_req *req = NULL;
95	struct socket *sock;
96	int class, fd, err;
97
98	err = -EOPNOTSUPP;
99	if (!hn)
100		goto out_status;
101
102	err = -EINVAL;
103	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
104		goto out_status;
105	class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);
106
107	err = -EAGAIN;
108	req = handshake_req_next(hn, class);
109	if (!req)
110		goto out_status;
111
112	sock = req->hr_sk->sk_socket;
113	fd = get_unused_fd_flags(O_CLOEXEC);
114	if (fd < 0) {
115		err = fd;
116		goto out_complete;
117	}
118
119	err = req->hr_proto->hp_accept(req, info, fd);
120	if (err) {
121		put_unused_fd(fd);
122		goto out_complete;
123	}
124
125	fd_install(fd, get_file(sock->file));
126
127	trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
128	return 0;
129
130out_complete:
131	handshake_complete(req, -EIO, NULL);
132out_status:
133	trace_handshake_cmd_accept_err(net, req, NULL, err);
134	return err;
135}
136
137int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
138{
139	struct net *net = sock_net(skb->sk);
140	struct handshake_req *req;
141	struct socket *sock;
142	int fd, status, err;
143
144	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
145		return -EINVAL;
146	fd = nla_get_s32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
147
148	sock = sockfd_lookup(fd, &err);
149	if (!sock)
150		return err;
151
152	req = handshake_req_hash_lookup(sock->sk);
153	if (!req) {
154		err = -EBUSY;
155		trace_handshake_cmd_done_err(net, req, sock->sk, err);
156		fput(sock->file);
157		return err;
158	}
159
160	trace_handshake_cmd_done(net, req, sock->sk, fd);
161
162	status = -EIO;
163	if (info->attrs[HANDSHAKE_A_DONE_STATUS])
164		status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
165
166	handshake_complete(req, status, info);
167	fput(sock->file);
168	return 0;
169}
170
/*
 * Per-net storage ID, filled in through handshake_genl_net_ops.id.
 * Zero is treated as "not registered": handshake_pernet() returns
 * NULL while this is zero, and handshake_exit() resets it to zero.
 */
static unsigned int handshake_net_id;
172
173static int __net_init handshake_net_init(struct net *net)
174{
175	struct handshake_net *hn = net_generic(net, handshake_net_id);
176	unsigned long tmp;
177	struct sysinfo si;
178
179	/*
180	 * Arbitrary limit to prevent handshakes that do not make
181	 * progress from clogging up the system. The cap scales up
182	 * with the amount of physical memory on the system.
183	 */
184	si_meminfo(&si);
185	tmp = si.totalram / (25 * si.mem_unit);
186	hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
187
188	spin_lock_init(&hn->hn_lock);
189	hn->hn_pending = 0;
190	hn->hn_flags = 0;
191	INIT_LIST_HEAD(&hn->hn_requests);
192	return 0;
193}
194
195static void __net_exit handshake_net_exit(struct net *net)
196{
197	struct handshake_net *hn = net_generic(net, handshake_net_id);
198	struct handshake_req *req;
199	LIST_HEAD(requests);
200
201	/*
202	 * Drain the net's pending list. Requests that have been
203	 * accepted and are in progress will be destroyed when
204	 * the socket is closed.
205	 */
206	spin_lock(&hn->hn_lock);
207	set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
208	list_splice_init(&requests, &hn->hn_requests);
209	spin_unlock(&hn->hn_lock);
210
211	while (!list_empty(&requests)) {
212		req = list_first_entry(&requests, struct handshake_req, hr_list);
213		list_del(&req->hr_list);
214
215		/*
216		 * Requests on this list have not yet been
217		 * accepted, so they do not have an fd to put.
218		 */
219
220		handshake_complete(req, -ETIMEDOUT, NULL);
221	}
222}
223
224static struct pernet_operations handshake_genl_net_ops = {
225	.init		= handshake_net_init,
226	.exit		= handshake_net_exit,
227	.id		= &handshake_net_id,
228	.size		= sizeof(struct handshake_net),
229};
230
231/**
232 * handshake_pernet - Get the handshake private per-net structure
233 * @net: network namespace
234 *
235 * Returns a pointer to the net's private per-net structure for the
236 * handshake module, or NULL if handshake_init() failed.
237 */
238struct handshake_net *handshake_pernet(struct net *net)
239{
240	return handshake_net_id ?
241		net_generic(net, handshake_net_id) : NULL;
242}
243EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
244
245static int __init handshake_init(void)
246{
247	int ret;
248
249	ret = handshake_req_hash_init();
250	if (ret) {
251		pr_warn("handshake: hash initialization failed (%d)\n", ret);
252		return ret;
253	}
254
255	ret = genl_register_family(&handshake_nl_family);
256	if (ret) {
257		pr_warn("handshake: netlink registration failed (%d)\n", ret);
258		handshake_req_hash_destroy();
259		return ret;
260	}
261
262	/*
263	 * ORDER: register_pernet_subsys must be done last.
264	 *
265	 *	If initialization does not make it past pernet_subsys
266	 *	registration, then handshake_net_id will remain 0. That
267	 *	shunts the handshake consumer API to return ENOTSUPP
268	 *	to prevent it from dereferencing something that hasn't
269	 *	been allocated.
270	 */
271	ret = register_pernet_subsys(&handshake_genl_net_ops);
272	if (ret) {
273		pr_warn("handshake: pernet registration failed (%d)\n", ret);
274		genl_unregister_family(&handshake_nl_family);
275		handshake_req_hash_destroy();
276	}
277
278	return ret;
279}
280
281static void __exit handshake_exit(void)
282{
283	unregister_pernet_subsys(&handshake_genl_net_ops);
284	handshake_net_id = 0;
285
286	handshake_req_hash_destroy();
287	genl_unregister_family(&handshake_nl_family);
288}
289
/* Module entry points: handshake_init() on load, handshake_exit() on unload */
module_init(handshake_init);
module_exit(handshake_exit);
292